Reimplemented the SP output formatters to hopefully work better and faster.
author Fredrik Tolf <fredrik@dolda2000.com>
Fri, 27 Dec 2013 09:40:35 +0000 (10:40 +0100)
committer Fredrik Tolf <fredrik@dolda2000.com>
Fri, 27 Dec 2013 09:40:35 +0000 (10:40 +0100)
wrw/sp/cons.py
wrw/sp/util.py
wrw/sp/xhtml.py

diff --git a/wrw/sp/cons.py b/wrw/sp/cons.py
index ff6a1b1..bc9bfa2 100644 (file)
@@ -77,3 +77,9 @@ class constructor(object):
 
     def __getattr__(self, name):
         return self._elcls(self._ns, name, self._ctx)
+
+class doctype(node):
+    def __init__(self, rootname, pubid, dtdid):
+        self.rootname = rootname
+        self.pubid = pubid
+        self.dtdid = dtdid
diff --git a/wrw/sp/util.py b/wrw/sp/util.py
index e15b947..deec92d 100644 (file)
@@ -1,4 +1,4 @@
-import StringIO
+import itertools, StringIO
 from wrw import dispatch
 import cons
 
@@ -19,206 +19,298 @@ def findnsnames(el):
         names[el.ns] = None
     return names
 
+def flatiter(root, short=True):
+    yield ">", root
+    stack = [(root, 0)]
+    while len(stack) > 0:
+        el, i = stack[-1]
+        if i >= len(el.children):
+            yield "<", el
+            stack.pop()
+        else:
+            ch = el.children[i]
+            stack[-1] = el, i + 1
+            if isinstance(ch, cons.element):
+                if short and len(ch.children) == 0:
+                    yield "/", ch
+                else:
+                    yield ">", ch
+                    stack.append((ch, 0))
+            elif isinstance(ch, cons.text):
+                yield "", ch
+            elif isinstance(ch, cons.raw):
+                yield "!", ch
+            else:
+                raise Exception("Unknown object in element tree: " + el)
+
 class formatter(object):
-    def __init__(self, out, root, nsnames=None, charset="utf-8", doctype=None):
-        self.root = root
-        if nsnames is None:
-            nsnames = findnsnames(root)
-        self.nsnames = nsnames
-        self.out = out
+    def __init__(self, src, nsnames=None, charset="utf-8"):
+        self.src = src
+        self.nsnames = nsnames or {}
+        self.nextns = 1
+        self.first = False
+        self.buf = bytearray()
         self.charset = charset
-        self.doctype = doctype
 
     def write(self, text):
-        self.out.write(text.encode(self.charset))
+        self.buf.extend(text.encode(self.charset))
 
     def quotewrite(self, buf):
-        for ch in buf:
-            if ch == u'&':
-                self.write(u"&amp;")
-            elif ch == u'<':
-                self.write(u"&lt;")
-            elif ch == u'>':
-                self.write(u"&gt;")
-            else:
-                self.write(ch)
+        buf = buf.replace(u'&', u"&amp;")
+        buf = buf.replace(u'<', u"&lt;")
+        buf = buf.replace(u'>', u"&gt;")
+        self.write(buf)
 
-    def text(self, el):
-        self.quotewrite(el)
+    def __iter__(self):
+        return self
 
-    def rawcode(self, el):
-        self.write(el)
+    def elname(self, el):
+        ns = self.nsnames[el.ns]
+        if ns is None:
+            return el.name
+        else:
+            return ns + u":" + el.name
 
-    def attrval(self, buf):
-        qc, qt = (u"'", u"&apos;") if u'"' in buf else (u'"', u"&quot;")
+    def attrval(self, v):
+        qc, qt = (u"'", u"&apos;") if u'"' in v else (u'"', u"&quot;")
         self.write(qc)
-        for ch in buf:
-            if ch == u'&':
-                self.write(u"&amp;")
-            elif ch == u'<':
-                self.write(u"&lt;")
-            elif ch == u'>':
-                self.write(u"&gt;")
-            elif ch == qc:
-                self.write(qt)
-            else:
-                self.write(ch)
+        v = v.replace(u'&', u"&amp;")
+        v = v.replace(u'<', u"&lt;")
+        v = v.replace(u'>', u"&gt;")
+        v = v.replace(qc, qt)
+        self.write(v)
         self.write(qc)
 
     def attr(self, k, v):
         self.write(k)
-        self.write(u'=')
+        self.write(u"=")
         self.attrval(v)
 
-    def shorttag(self, el, **extra):
-        self.write(u'<' + self.elname(el))
-        for k, v in el.attrs.iteritems():
-            self.write(u' ')
-            self.attr(k, v)
-        for k, v in extra.iteritems():
-            self.write(u' ')
+    def attrs(self, attrs):
+        for k, v in attrs:
+            self.write(u" ")
             self.attr(k, v)
-        self.write(u" />")
 
-    def elname(self, el):
-        ns = self.nsnames[el.ns]
-        if ns is None:
-            return el.name
-        else:
-            return ns + u':' + el.name
-
-    def starttag(self, el, **extra):
-        self.write(u'<' + self.elname(el))
-        for k, v in el.attrs.iteritems():
-            self.write(u' ')
-            self.attr(k, v)
-        for k, v in extra.iteritems():
-            self.write(u' ')
-            self.attr(k, v)
-        self.write(u'>')
+    def inittag(self, el):
+        self.write(u"<" + self.elname(el))
+        attrs = el.attrs.iteritems()
+        if self.first:
+            nsnames = []
+            for ns, name in self.nsnames.iteritems():
+                if ns is None:
+                    if name is not None:
+                        raise Exception("null namespace must have null name, not" + name)
+                    continue
+                nsnames.append((u"xmlns" if name is None else (u"xmlns:" + name), ns))
+            attrs = itertools.chain(attrs, iter(nsnames))
+            self.first = False
+        self.attrs(attrs)
+
+    def starttag(self, el):
+        self.inittag(el)
+        self.write(u">")
+
+    def shorttag(self, el):
+        self.inittag(el)
+        self.write(u" />")
 
     def endtag(self, el):
-        self.write(u'</' + self.elname(el) + u'>')
+        self.write(u"</" + self.elname(el) + u">")
 
-    def longtag(self, el, **extra):
-        self.starttag(el, **extra)
-        for ch in el.children:
-            self.node(ch)
-        self.endtag(el)
+    def text(self, el):
+        self.quotewrite(el)
 
-    def element(self, el, **extra):
-        if len(el.children) == 0:
-            self.shorttag(el, **extra)
-        else:
-            self.longtag(el, **extra)
+    def rawcode(self, el):
+        self.write(el)
 
-    def node(self, el):
-        if isinstance(el, cons.element):
-            self.element(el)
-        elif isinstance(el, cons.text):
+    def start(self, el):
+        self.write(u'<?xml version="1.0" encoding="' + self.charset + u'" ?>\n')
+        if isinstance(el, cons.doctype):
+            self.write(u'<!DOCTYPE %s PUBLIC "%s" "%s">\n' % (el.rootname,
+                                                              el.pubid,
+                                                              el.dtdid))
+        self.first = True
+
+    def end(self, el):
+        pass
+
+    def next(self):
+        if self.src is None:
+            raise StopIteration()
+        try:
+            ev, el = next(self.src)
+        except StopIteration:
+            self.src = None
+            ev, el = "$", None
+        if ev == ">":
+            self.starttag(el)
+        elif ev == "/":
+            self.shorttag(el)
+        elif ev == "<":
+            self.endtag(el)
+        elif ev == "":
             self.text(el)
-        elif isinstance(el, cons.raw):
+        elif ev == "!":
             self.rawcode(el)
-        else:
-            raise Exception("Unknown object in element tree: " + el)
+        elif ev == "^":
+            self.start(el)
+        elif ev == "$":
+            self.end(el)
+        ret = str(self.buf)
+        self.buf[:] = ""
+        return ret
 
-    def start(self):
-        self.write(u'<?xml version="1.0" encoding="' + self.charset + u'" ?>\n')
-        if self.doctype:
-            self.write(u'<!DOCTYPE %s PUBLIC "%s" "%s">\n' % (self.root.name,
-                                                              self.doctype[0],
-                                                              self.doctype[1]))
-        extra = {}
-        for uri, nm in self.nsnames.iteritems():
-            if uri is None:
-                continue
-            if nm is None:
-                extra[u"xmlns"] = uri
-            else:
-                extra[u"xmlns:" + nm] = uri
-        self.element(self.root, **extra)
+    def nsname(self, el):
+        for t in type(self).__mro__:
+            ret = getattr(t, "defns", {}).get(el.ns, None)
+            if ret is not None:
+                return ret
+        if el.ns is None:
+            return None
+        ret = u"n" + unicode(self.nextns)
+        self.nextns += 1
+        return ret
+
+    def findnsnames(self, root):
+        fnames = {}
+        rnames = {}
+        def proc(el):
+            if isinstance(el, cons.element):
+                if el.ns not in fnames:
+                    nm = self.nsname(el)
+                    fnames[el.ns] = nm
+                    rnames[nm] = el.ns
+                for ch in el.children:
+                    proc(ch)
+        proc(root)
+        if None not in rnames:
+            fnames[root.ns] = None
+            rnames[None] = root.ns
+        self.nsnames = fnames
 
     @classmethod
-    def output(cls, out, el, *args, **kw):
-        cls(out=out, root=el, *args, **kw).start()
+    def output(cls, out, root, nsnames=None, doctype=None, **kw):
+        if isinstance(doctype, cons.doctype):
+            pass
+        elif doctype is not None:
+            doctype = cons.doctype(root.name, doctype[0], doctype[1])
+        src = itertools.chain(iter([("^", doctype)]), flatiter(root))
+        self = cls(src=src, nsnames=nsnames, **kw)
+        if nsnames is None:
+            self.findnsnames(root)
+        self.first = True
+        for piece in self:
+            out.write(piece)
 
     @classmethod
-    def fragment(cls, out, el, *args, **kw):
-        cls(out=out, root=el, *args, **kw).node(el)
+    def fragment(cls, out, root, nsnames=None, **kw):
+        self = cls(src=flatiter(root), nsnames=nsnames, **kw)
+        if nsnames is None:
+            self.findnsnames(root)
+        for piece in self:
+            out.write(piece)
 
     @classmethod
-    def format(cls, el, *args, **kw):
+    def format(cls, root, **kw):
         buf = StringIO.StringIO()
-        cls.output(buf, el, *args, **kw)
+        cls.output(buf, root, **kw)
         return buf.getvalue()
 
-    def update(self, **ch):
-        ret = type(self).__new__(type(self))
-        ret.__dict__.update(self.__dict__)
-        ret.__dict__.update(ch)
-        return ret
-
-class iwriter(object):
-    def __init__(self, out):
-        self.out = out
-        self.atbol = True
-        self.col = 0
-
-    def write(self, buf):
-        for c in buf:
-            if c == '\n':
-                self.col = 0
-            else:
-                self.col += 1
-            self.out.write(c)
-        self.atbol = False
-
-    def indent(self, indent):
-        if self.atbol:
-            return
-        if self.col != 0:
-            self.write('\n')
-        self.write(indent)
-        self.atbol = True
-
 class indenter(formatter):
     def __init__(self, indent=u"  ", *args, **kw):
         super(indenter, self).__init__(*args, **kw)
-        self.out = iwriter(self.out)
         self.indent = indent
+        self.col = 0
         self.curind = u""
+        self.atbreak = True
+        self.inline = False
+        self.stack = []
 
-    def simple(self, el):
-        for ch in el.children:
-            if not isinstance(ch, cons.text):
-                return False
-        return True
-
-    def longtag(self, el, **extra):
-        self.starttag(el, **extra)
-        sub = self
-        reind = False
-        if not self.simple(el):
-            sub = self.update(curind=self.curind + self.indent)
-            sub.reindent()
-            reind = True
+    def write(self, text):
+        lines = text.split(u"\n")
+        if len(lines) > 1:
+            for ln in lines[:-1]:
+                self.buf.extend(ln.encode(self.charset))
+                self.buf.extend("\n")
+            self.col = 0
+        self.buf.extend(lines[-1].encode(self.charset))
+        self.col += len(lines[-1])
+        self.atbreak = False
+
+    def br(self):
+        if not self.atbreak:
+            self.buf.extend((u"\n" + self.curind).encode(self.charset))
+            self.col = 0
+            self.atbreak = True
+
+    def inlinep(self, el):
         for ch in el.children:
-            sub.node(ch)
-        if reind:
-            self.reindent()
-        self.endtag(el)
-
-    def element(self, el, **extra):
-        super(indenter, self).element(el, **extra)
-        if self.out.col > 80 and self.simple(el):
-            self.reindent()
-
-    def reindent(self):
-        self.out.indent(self.curind.encode(self.charset))
-
-    def start(self):
-        super(indenter, self).start()
-        self.write('\n')
+            if isinstance(ch, cons.text):
+                return True
+        return False
+
+    def push(self, el):
+        self.stack.append((el, self.curind, self.inline))
+
+    def pop(self):
+        el, self.curind, self.inline = self.stack.pop()
+        return el
+
+    def starttag(self, el):
+        if not self.inline:
+            self.br()
+        self.push(el)
+        self.inline = self.inline or self.inlinep(el)
+        self.curind += self.indent
+        super(indenter, self).starttag(el)
+
+    def shorttag(self, el):
+        if not self.inline:
+            self.br()
+        super(indenter, self).shorttag(el)
+
+    def endtag(self, el):
+        il = self.inline
+        self.pop()
+        if not il:
+            self.br()
+        super(indenter, self).endtag(el)
+
+    def start(self, el):
+        super(indenter, self).start(el)
+        self.atbreak = True
+
+    def end(self, el):
+        self.br()
+
+class textindenter(indenter):
+    maxcol = 70
+
+    def text(self, el):
+        left = unicode(el)
+        while True:
+            if len(left) + self.col > self.maxcol:
+                bp = max(self.maxcol - self.col, 0)
+                for i in xrange(bp, -1, -1):
+                    if left[i].isspace():
+                        while i > 0 and left[i - 1].isspace(): i -= 1
+                        break
+                else:
+                    for i in xrange(bp + 1, len(left)):
+                        if left[i].isspace():
+                            break
+                    else:
+                        i = None
+                if i is None:
+                    self.quotewrite(left)
+                    break
+                else:
+                    self.quotewrite(left[:i])
+                    self.br()
+                    left = left[i + 1:].lstrip()
+            else:
+                self.quotewrite(left)
+                break
 
 class response(dispatch.restart):
     charset = "utf-8"
diff --git a/wrw/sp/xhtml.py b/wrw/sp/xhtml.py
index f10d315..48ed41d 100644 (file)
@@ -39,13 +39,14 @@ def head(title=None, css=None):
 
 class htmlformatter(util.formatter):
     allowshort = set([u"br", u"hr", u"img", u"input", u"meta", u"link"])
-    def element(self, el, **extra):
+    def shorttag(self, el):
         if el.name in self.allowshort:
-            super(htmlformatter, self).element(el, **extra)
+            super(htmlformatter, self).shorttag(el)
         else:
-            self.longtag(el, **extra)
+            self.starttag(el)
+            self.endtag(el)
 
-class htmlindenter(util.indenter, htmlformatter):
+class htmlindenter(util.textindenter, htmlformatter):
     pass
 
 def forreq(req, tree):