Merge branch 'master' into python3
[wrw.git] / wrw / sp / util.py
index 0728a7e..e7d9c29 100644 (file)
@@ -1,4 +1,4 @@
-import io
+import itertools, io
 from .. import dispatch
 from . import cons
 
@@ -19,207 +19,298 @@ def findnsnames(el):
         names[el.ns] = None
     return names
 
+def flatiter(root, short=True):
+    """Walk the tree under `root` without recursion, yielding (event, node).
+
+    Events: ">" opens an element, "<" closes it, "/" is a childless
+    element (only emitted when `short` is true), "" is a text node and
+    "!" is a raw node. `root` itself is always reported with ">" and
+    "<", even when it has no children.
+
+    Raises Exception if a child is neither a cons.element, a cons.text
+    nor a cons.raw instance.
+    """
+    yield ">", root
+    # Each stack entry is (element, index of its next unvisited child).
+    stack = [(root, 0)]
+    while stack:
+        el, i = stack[-1]
+        if i >= len(el.children):
+            yield "<", el
+            stack.pop()
+            continue
+        ch = el.children[i]
+        stack[-1] = el, i + 1
+        if isinstance(ch, cons.element):
+            if short and len(ch.children) == 0:
+                yield "/", ch
+            else:
+                yield ">", ch
+                stack.append((ch, 0))
+        elif isinstance(ch, cons.text):
+            yield "", ch
+        elif isinstance(ch, cons.raw):
+            yield "!", ch
+        else:
+            # Report the offending child rather than its parent; repr()
+            # keeps the concatenation valid for arbitrary objects.
+            raise Exception("Unknown object in element tree: " + repr(ch))
+
 class formatter(object):
-    def __init__(self, out, root, nsnames=None, charset="utf-8", doctype=None):
-        self.root = root
-        if nsnames is None:
-            nsnames = findnsnames(root)
-        self.nsnames = nsnames
-        self.out = out
+    def __init__(self, src, nsnames=None, charset="utf-8"):
+        self.src = src
+        self.nsnames = nsnames or {}
+        self.nextns = 1
+        self.first = False
+        self.buf = bytearray()
         self.charset = charset
-        self.doctype = doctype
 
     def write(self, text):
-        self.out.write(text.encode(self.charset))
+        self.buf.extend(text.encode(self.charset))
 
     def quotewrite(self, buf):
-        for ch in buf:
-            if ch == '&':
-                self.write("&amp;")
-            elif ch == '<':
-                self.write("&lt;")
-            elif ch == '>':
-                self.write("&gt;")
-            else:
-                self.write(ch)
+        buf = buf.replace('&', "&amp;")
+        buf = buf.replace('<', "&lt;")
+        buf = buf.replace('>', "&gt;")
+        self.write(buf)
 
-    def text(self, el):
-        self.quotewrite(el)
+    def __iter__(self):
+        """Return self; the formatter is its own iterator (see __next__)."""
+        return self
 
-    def rawcode(self, el):
-        self.write(el)
+    def elname(self, el):
+        """Return the tag name of `el`, prefixed with its namespace name.
+
+        The prefix is looked up in self.nsnames by el.ns; a None entry
+        means the bare name is used. Raises KeyError if el.ns has no
+        entry in self.nsnames.
+        """
+        prefix = self.nsnames[el.ns]
+        return el.name if prefix is None else prefix + ":" + el.name
 
-    def attrval(self, buf):
-        qc, qt = ("'", "&apos;") if '"' in buf else ('"', "&quot;")
+    def attrval(self, v):
+        qc, qt = ("'", "&apos;") if '"' in v else ('"', "&quot;")
         self.write(qc)
-        for ch in buf:
-            if ch == '&':
-                self.write("&amp;")
-            elif ch == '<':
-                self.write("&lt;")
-            elif ch == '>':
-                self.write("&gt;")
-            elif ch == qc:
-                self.write(qt)
-            else:
-                self.write(ch)
+        v = v.replace('&', "&amp;")
+        v = v.replace('<', "&lt;")
+        v = v.replace('>', "&gt;")
+        v = v.replace(qc, qt)
+        self.write(v)
         self.write(qc)
 
     def attr(self, k, v):
         self.write(k)
-        self.write('=')
+        self.write("=")
         self.attrval(v)
 
-    def shorttag(self, el, **extra):
-        self.write('<' + self.elname(el))
-        for k, v in el.attrs.items():
-            self.write(' ')
-            self.attr(k, v)
-        for k, v in extra.items():
-            self.write(' ')
+    def attrs(self, attrs):
+        for k, v in attrs:
+            self.write(" ")
             self.attr(k, v)
-        self.write(" />")
 
-    def elname(self, el):
-        ns = self.nsnames[el.ns]
-        if ns is None:
-            return el.name
-        else:
-            return ns + ':' + el.name
-
-    def starttag(self, el, **extra):
-        self.write('<' + self.elname(el))
-        for k, v in el.attrs.items():
-            self.write(' ')
-            self.attr(k, v)
-        for k, v in extra.items():
-            self.write(' ')
-            self.attr(k, v)
-        self.write('>')
+    def inittag(self, el):
+        """Write "<name" plus the attributes of `el`, leaving the tag open.
+
+        While self.first is set, the xmlns declarations derived from
+        self.nsnames are appended after the element's own attributes and
+        the flag is cleared, so that later tags do not repeat them.
+
+        Raises Exception if self.nsnames maps the null namespace (None)
+        to a non-null name.
+        """
+        self.write("<" + self.elname(el))
+        attrs = el.attrs.items()
+        if self.first:
+            nsnames = []
+            for ns, name in self.nsnames.items():
+                if ns is None:
+                    if name is not None:
+                        raise Exception("null namespace must have null name, not " + name)
+                    # The null namespace needs no declaration.
+                    continue
+                nsnames.append(("xmlns" if name is None else ("xmlns:" + name), ns))
+            attrs = itertools.chain(attrs, nsnames)
+            self.first = False
+        self.attrs(attrs)
+
+    def starttag(self, el):
+        """Write the complete opening tag for `el`."""
+        self.inittag(el)
+        self.write(">")
+
+    def shorttag(self, el):
+        """Write `el` as a self-closing tag ("<name ... />")."""
+        self.inittag(el)
+        self.write(" />")
 
     def endtag(self, el):
-        self.write('</' + self.elname(el) + '>')
+        self.write("</" + self.elname(el) + ">")
 
-    def longtag(self, el, **extra):
-        self.starttag(el, **extra)
-        for ch in el.children:
-            self.node(ch)
-        self.endtag(el)
+    def text(self, el):
+        """Write the text node `el` through quotewrite(), which escapes it."""
+        self.quotewrite(el)
 
-    def element(self, el, **extra):
-        if len(el.children) == 0:
-            self.shorttag(el, **extra)
-        else:
-            self.longtag(el, **extra)
+    def rawcode(self, el):
+        """Write the raw node `el` as-is, without any escaping."""
+        self.write(el)
 
-    def node(self, el):
-        if isinstance(el, cons.element):
-            self.element(el)
-        elif isinstance(el, cons.text):
+    def start(self, el):
+        """Write the XML declaration and, for a cons.doctype `el`, a DOCTYPE.
+
+        Also sets self.first, so that the next tag written carries the
+        namespace declarations.
+        """
+        decl = '<?xml version="1.0" encoding="' + self.charset + '" ?>\n'
+        self.write(decl)
+        if isinstance(el, cons.doctype):
+            ids = (el.rootname, el.pubid, el.dtdid)
+            self.write('<!DOCTYPE %s PUBLIC "%s" "%s">\n' % ids)
+        self.first = True
+
+    def end(self, el):
+        """Hook run by __next__ once the source is exhausted; writes nothing."""
+        pass
+
+    def __next__(self):
+        if self.src is None:
+            raise StopIteration()
+        try:
+            ev, el = next(self.src)
+        except StopIteration:
+            self.src = None
+            ev, el = "$", None
+        if ev == ">":
+            self.starttag(el)
+        elif ev == "/":
+            self.shorttag(el)
+        elif ev == "<":
+            self.endtag(el)
+        elif ev == "":
             self.text(el)
-        elif isinstance(el, cons.raw):
+        elif ev == "!":
             self.rawcode(el)
-        else:
-            raise Exception("Unknown object in element tree: " + el)
+        elif ev == "^":
+            self.start(el)
+        elif ev == "$":
+            self.end(el)
+        ret = bytes(self.buf)
+        self.buf[:] = b""
+        return ret
 
-    def start(self):
-        self.write('<?xml version="1.0" encoding="' + self.charset + '" ?>\n')
-        if self.doctype:
-            self.write('<!DOCTYPE %s PUBLIC "%s" "%s">\n' % (self.root.name,
-                                                             self.doctype[0],
-                                                             self.doctype[1]))
-        extra = {}
-        for uri, nm in self.nsnames.items():
-            if uri is None:
-                continue
-            if nm is None:
-                extra["xmlns"] = uri
-            else:
-                extra["xmlns:" + nm] = uri
-        self.element(self.root, **extra)
+    def nsname(self, el):
+        """Choose a namespace name (prefix) for the namespace of `el`.
+
+        The `defns` dicts of the formatter's class and its bases are
+        searched in MRO order and the first non-None entry for el.ns is
+        returned. Failing that, the null namespace gets None and any
+        other namespace gets a generated name, "n" followed by the
+        current value of self.nextns, which is then incremented.
+        """
+        for klass in type(self).__mro__:
+            preset = getattr(klass, "defns", {}).get(el.ns)
+            if preset is not None:
+                return preset
+        if el.ns is None:
+            return None
+        num = self.nextns
+        self.nextns = num + 1
+        return "n" + str(num)
+
+    def findnsnames(self, root):
+        """Assign a name to every namespace used by elements under `root`.
+
+        Elements are visited parent-first, children in order, and each
+        namespace not seen before is named by self.nsname(). If no
+        namespace ended up with the None (default) name, the namespace
+        of `root` is given it. The resulting namespace -> name mapping
+        replaces self.nsnames.
+        """
+        byns = {}
+        byname = {}
+        def visit(node):
+            # Only elements carry a namespace; other nodes are ignored.
+            if not isinstance(node, cons.element):
+                return
+            if node.ns not in byns:
+                name = self.nsname(node)
+                byns[node.ns] = name
+                byname[name] = node.ns
+            for child in node.children:
+                visit(child)
+        visit(root)
+        if None not in byname:
+            byns[root.ns] = None
+            byname[None] = root.ns
+        self.nsnames = byns
 
     @classmethod
-    def output(cls, out, el, *args, **kw):
-        cls(out=out, root=el, *args, **kw).start()
+    def output(cls, out, root, nsnames=None, doctype=None, **kw):
+        """Format `root` as a whole document and write the bytes to `out`.
+
+        `doctype` may be None, a cons.doctype, or an indexable whose
+        first two items are passed to cons.doctype after root.name.
+        When `nsnames` is None the names are computed with
+        findnsnames(). Remaining keyword arguments are passed to the
+        constructor.
+        """
+        if doctype is not None and not isinstance(doctype, cons.doctype):
+            doctype = cons.doctype(root.name, doctype[0], doctype[1])
+        # The "^" event makes the formatter emit the document prologue
+        # before the tree events.
+        events = itertools.chain([("^", doctype)], flatiter(root))
+        fmt = cls(src=events, nsnames=nsnames, **kw)
+        if nsnames is None:
+            fmt.findnsnames(root)
+        fmt.first = True
+        for chunk in fmt:
+            out.write(chunk)
 
     @classmethod
-    def fragment(cls, out, el, *args, **kw):
-        cls(out=out, root=el, *args, **kw).node(el)
+    def fragment(cls, out, root, nsnames=None, **kw):
+        """Format the tree under `root` and write the bytes to `out`.
+
+        Unlike output(), no "^" event is fed to the formatter, so
+        start() is not run. When `nsnames` is None the names are
+        computed with findnsnames(). Remaining keyword arguments are
+        passed to the constructor.
+        """
+        self = cls(src=flatiter(root), nsnames=nsnames, **kw)
+        if nsnames is None:
+            self.findnsnames(root)
+        for piece in self:
+            out.write(piece)
 
     @classmethod
-    def format(cls, el, *args, **kw):
+    def format(cls, root, **kw):
         buf = io.BytesIO()
-        cls.output(buf, el, *args, **kw)
+        cls.output(buf, root, **kw)
         return buf.getvalue()
 
-    def update(self, **ch):
-        ret = type(self).__new__(type(self))
-        ret.__dict__.update(self.__dict__)
-        ret.__dict__.update(ch)
-        return ret
-
-class iwriter(object):
-    def __init__(self, out):
-        self.out = out
-        self.atbol = True
-        self.col = 0
-
-    def write(self, buf):
-        for i in range(len(buf)):
-            c = buf[i:i + 1]
-            if c == b'\n':
-                self.col = 0
-            else:
-                self.col += 1
-            self.out.write(c)
-        self.atbol = False
-
-    def indent(self, indent):
-        if self.atbol:
-            return
-        if self.col != 0:
-            self.write(b'\n')
-        self.write(indent)
-        self.atbol = True
-
 class indenter(formatter):
     def __init__(self, indent="  ", *args, **kw):
-        super(indenter, self).__init__(*args, **kw)
-        self.out = iwriter(self.out)
+        super().__init__(*args, **kw)
         self.indent = indent
+        self.col = 0
         self.curind = ""
+        self.atbreak = True
+        self.inline = False
+        self.stack = []
 
-    def simple(self, el):
-        for ch in el.children:
-            if not isinstance(ch, cons.text):
-                return False
-        return True
-
-    def longtag(self, el, **extra):
-        self.starttag(el, **extra)
-        sub = self
-        reind = False
-        if not self.simple(el):
-            sub = self.update(curind=self.curind + self.indent)
-            sub.reindent()
-            reind = True
+    def write(self, text):
+        """Append `text` to the buffer while tracking the output column.
+
+        self.col is reset to 0 whenever `text` contains a newline and is
+        then advanced by the length of the part after the last newline.
+        Any write clears self.atbreak.
+        """
+        *full, tail = text.split("\n")
+        for ln in full:
+            self.buf.extend(ln.encode(self.charset))
+            self.buf.extend(b"\n")
+        if full:
+            self.col = 0
+        self.buf.extend(tail.encode(self.charset))
+        self.col += len(tail)
+        self.atbreak = False
+
+    def br(self):
+        """Start a new line at the current indentation, unless at a break.
+
+        The column counter is reset to 0; the indentation written here
+        is not counted in it.
+        """
+        if self.atbreak:
+            return
+        self.buf.extend(("\n" + self.curind).encode(self.charset))
+        self.col = 0
+        self.atbreak = True
+
+    def inlinep(self, el):
         for ch in el.children:
-            sub.node(ch)
-        if reind:
-            self.reindent()
-        self.endtag(el)
-
-    def element(self, el, **extra):
-        super(indenter, self).element(el, **extra)
-        if self.out.col > 80 and self.simple(el):
-            self.reindent()
-
-    def reindent(self):
-        self.out.indent(self.curind.encode(self.charset))
-
-    def start(self):
-        super(indenter, self).start()
-        self.write('\n')
+            if isinstance(ch, cons.text):
+                return True
+        return False
+
+    def push(self, el):
+        """Save `el` with the current indentation and inline flag on the stack."""
+        self.stack.append((el, self.curind, self.inline))
+
+    def pop(self):
+        """Restore the indentation and inline flag saved by push(); return its element."""
+        el, self.curind, self.inline = self.stack.pop()
+        return el
+
+    def starttag(self, el):
+        """Write the opening tag of `el`, on a new line unless in inline mode.
+
+        The current indentation and inline flag are saved first; the
+        content of `el` is then treated as inline if the surroundings
+        already were or inlinep(el) is true, and the indentation is
+        deepened by one step.
+        """
+        if not self.inline:
+            self.br()
+        # Save the outer state before changing it for the element's content.
+        self.push(el)
+        self.inline = self.inline or self.inlinep(el)
+        self.curind += self.indent
+        super().starttag(el)
+
+    def shorttag(self, el):
+        """Write `el` as a self-closing tag, on a new line unless in inline mode."""
+        if not self.inline:
+            self.br()
+        super().shorttag(el)
+
+    def endtag(self, el):
+        """Write the closing tag of `el` and restore the enclosing state.
+
+        The tag goes on a new line, at the restored indentation, unless
+        the element's content was inline.
+        """
+        # Remember the inline flag of the content before pop() replaces it.
+        il = self.inline
+        self.pop()
+        if not il:
+            self.br()
+        super().endtag(el)
+
+    def start(self, el):
+        """Write the document prologue, then mark the output as being at a break."""
+        super().start(el)
+        # Prevents the next br() from adding another newline.
+        self.atbreak = True
+
+    def end(self, el):
+        """Finish the output with a line break, unless already at one."""
+        self.br()
+
+class textindenter(indenter):
+    """Indenter that also wraps text nodes at whitespace around `maxcol`."""
+
+    # Column past which text is broken onto a new line.
+    maxcol = 70
+
+    def text(self, el):
+        """Write the text node `el`, breaking lines at whitespace.
+
+        While the remaining text would run past maxcol, it is split at
+        the last whitespace at or before the column limit, or else at
+        the first whitespace after it. The whitespace at the split is
+        dropped and a line break written in its place. Text with no
+        whitespace left is written unbroken.
+        """
+        left = str(el)
+        while True:
+            # Empty text has no character to inspect; without the `left`
+            # guard, left[0] raised IndexError when the column was
+            # already past maxcol.
+            if left and len(left) + self.col > self.maxcol:
+                bp = max(self.maxcol - self.col, 0)
+                # Search backwards from the limit for whitespace, then
+                # move to the start of that whitespace run.
+                for i in range(bp, -1, -1):
+                    if left[i].isspace():
+                        while i > 0 and left[i - 1].isspace(): i -= 1
+                        break
+                else:
+                    # Nothing before the limit: take the first
+                    # whitespace after it, if any.
+                    for i in range(bp + 1, len(left)):
+                        if left[i].isspace():
+                            break
+                    else:
+                        i = None
+                if i is None:
+                    self.quotewrite(left)
+                    break
+                else:
+                    self.quotewrite(left[:i])
+                    self.br()
+                    left = left[i + 1:].lstrip()
+            else:
+                self.quotewrite(left)
+                break
 
 class response(dispatch.restart):
     charset = "utf-8"