Merge branch 'master' into python3
[wrw.git] / wrw / sp / util.py
1 import itertools, io
2 from .. import dispatch
3 from . import cons
4
def findnsnames(el):
    """Build a namespace -> prefix table for the tree rooted at *el*.

    Every distinct ``ns`` found on a ``cons.element`` in the tree is given
    a generated prefix ``"n1"``, ``"n2"``, ... in order of first
    appearance (depth-first, parents before children).  Afterwards one
    namespace is mapped to ``None`` instead of a generated prefix: the
    ``None`` namespace if it occurs anywhere in the tree, otherwise the
    namespace of *el* itself.

    Returns the resulting dict.
    """
    prefixes = {}
    serial = itertools.count(1)

    def visit(node):
        # Only elements carry a namespace and children; anything else is
        # skipped silently.
        if not isinstance(node, cons.element):
            return
        if node.ns not in prefixes:
            prefixes[node.ns] = "n%d" % next(serial)
        for child in node.children:
            visit(child)

    visit(el)
    # Exactly one namespace gets the None prefix: the null namespace when
    # it is in use, the root's namespace otherwise.
    unprefixed = None if None in prefixes else el.ns
    prefixes[unprefixed] = None
    return prefixes
21
def flatiter(root, short=True):
    """Walk the tree under *root* and yield it as flat ``(event, node)`` pairs.

    The walk is depth-first and uses an explicit stack, not recursion.
    Event codes:

    * ``">"`` -- an element is opened (always emitted for *root* first)
    * ``"<"`` -- an element opened with ``">"`` is closed
    * ``"/"`` -- a child element without children, emitted as one event
      instead of a ``">"``/``"<"`` pair; only when *short* is true
    * ``""``  -- a ``cons.text`` child
    * ``"!"`` -- a ``cons.raw`` child

    Raises ``Exception`` when a child is none of ``cons.element``,
    ``cons.text`` or ``cons.raw``.
    """
    yield ">", root
    # Each stack entry is (element, index of the next child to visit).
    stack = [(root, 0)]
    while stack:
        el, i = stack[-1]
        if i >= len(el.children):
            # All children handled: close this element.
            yield "<", el
            stack.pop()
            continue
        ch = el.children[i]
        stack[-1] = el, i + 1
        if isinstance(ch, cons.element):
            if short and len(ch.children) == 0:
                yield "/", ch
            else:
                yield ">", ch
                stack.append((ch, 0))
        elif isinstance(ch, cons.text):
            yield "", ch
        elif isinstance(ch, cons.raw):
            yield "!", ch
        else:
            # Report the offending child; repr() keeps the message
            # construction itself from failing on non-string objects.
            raise Exception("Unknown object in element tree: " + repr(ch))
45
class formatter(object):
    """Serialize a stream of ``(event, node)`` pairs into encoded XML chunks.

    A formatter is its own iterator: each step pulls one pair from *src*,
    dispatches on the event code and returns the bytes produced for it.
    Recognized event codes are ``">"`` (start tag), ``"/"`` (self-closing
    tag), ``"<"`` (end tag), ``""`` (text), ``"!"`` (raw markup) and
    ``"^"`` (document prologue).  When *src* is exhausted a final ``"$"``
    event is synthesized and routed to `end`.  Any other event code is
    ignored and yields whatever is in the buffer (normally nothing).

    A class (or any class in its MRO) may define a ``defns`` mapping from
    namespace to preferred prefix; `nsname` consults it before generating
    a prefix.
    """

    def __init__(self, src, nsnames=None, charset="utf-8"):
        """Set up a formatter.

        src -- iterator of ``(event, node)`` pairs.
        nsnames -- mapping from namespace to prefix, where a ``None``
            prefix means "no prefix"; a falsy value gives an empty table.
        charset -- encoding used for all output and named in the XML
            declaration.
        """
        self.src = src
        self.nsnames = nsnames or {}
        self.nextns = 1          # counter for generated "nN" prefixes
        self.first = False       # True until the first tag has been written
        self.buf = bytearray()   # output collected for the current event
        self.charset = charset

    def write(self, text):
        """Encode *text* with the configured charset and buffer it."""
        self.buf.extend(text.encode(self.charset))

    def quotewrite(self, buf):
        """Write *buf* as character data, escaping ``&``, ``<`` and ``>``."""
        # '&' must be replaced first so the entities added below are not
        # escaped a second time.
        buf = buf.replace('&', "&amp;")
        buf = buf.replace('<', "&lt;")
        buf = buf.replace('>', "&gt;")
        self.write(buf)

    def __iter__(self):
        return self

    def elname(self, el):
        """Return the (possibly prefixed) tag name of *el*.

        Raises ``KeyError`` if the namespace of *el* is not in the
        namespace table.
        """
        ns = self.nsnames[el.ns]
        if ns is None:
            return el.name
        else:
            return ns + ":" + el.name

    def attrval(self, v):
        """Write *v* as a quoted, escaped attribute value.

        Double quotes are used unless *v* itself contains one, in which
        case single quotes are used; occurrences of the chosen quote
        character inside *v* are replaced by the matching entity.
        """
        qc, qt = ("'", "&apos;") if '"' in v else ('"', "&quot;")
        self.write(qc)
        # '&' first, so later entities are not double-escaped.
        v = v.replace('&', "&amp;")
        v = v.replace('<', "&lt;")
        v = v.replace('>', "&gt;")
        v = v.replace(qc, qt)
        self.write(v)
        self.write(qc)

    def attr(self, k, v):
        """Write a single ``k=value`` attribute (without leading space)."""
        self.write(k)
        self.write("=")
        self.attrval(v)

    def attrs(self, attrs):
        """Write every ``(key, value)`` pair in *attrs*, each preceded by a space."""
        for k, v in attrs:
            self.write(" ")
            self.attr(k, v)

    def inittag(self, el):
        """Write ``<name`` plus the attributes of *el*, leaving the tag open.

        While ``self.first`` is set, the namespace declarations for the
        whole namespace table are appended after the element's own
        attributes, and the flag is cleared.

        Raises ``Exception`` if the table maps the ``None`` namespace to
        a non-``None`` prefix.
        """
        self.write("<" + self.elname(el))
        attrs = el.attrs.items()
        if self.first:
            nsnames = []
            for ns, name in self.nsnames.items():
                if ns is None:
                    # The null namespace cannot be declared, so it must
                    # not have been given a prefix.
                    if name is not None:
                        raise Exception("null namespace must have null name, not " + name)
                    continue
                nsnames.append(("xmlns" if name is None else ("xmlns:" + name), ns))
            attrs = itertools.chain(attrs, iter(nsnames))
            self.first = False
        self.attrs(attrs)

    def starttag(self, el):
        """Write the start tag of *el*."""
        self.inittag(el)
        self.write(">")

    def shorttag(self, el):
        """Write *el* as a self-closing tag."""
        self.inittag(el)
        self.write(" />")

    def endtag(self, el):
        """Write the end tag of *el*."""
        self.write("</" + self.elname(el) + ">")

    def text(self, el):
        """Write the text node *el* with markup characters escaped."""
        self.quotewrite(el)

    def rawcode(self, el):
        """Write the raw node *el* verbatim, without any escaping."""
        self.write(el)

    def start(self, el):
        """Write the XML declaration and, if *el* is a ``cons.doctype``, a DOCTYPE.

        Also arms ``self.first`` so that the next tag written carries the
        namespace declarations.
        """
        self.write('<?xml version="1.0" encoding="' + self.charset + '" ?>\n')
        if isinstance(el, cons.doctype):
            self.write('<!DOCTYPE %s PUBLIC "%s" "%s">\n' % (el.rootname,
                                                              el.pubid,
                                                              el.dtdid))
        self.first = True

    def end(self, el):
        """Hook invoked once after the source is exhausted; does nothing here."""
        pass

    def __next__(self):
        """Process one event and return the bytes it produced.

        The step on which the source runs dry still returns a chunk (the
        output of `end`); ``StopIteration`` is raised on the step after.
        """
        if self.src is None:
            raise StopIteration()
        try:
            ev, el = next(self.src)
        except StopIteration:
            # Remember that the source is done and emit the closing event.
            self.src = None
            ev, el = "$", None
        if ev == ">":
            self.starttag(el)
        elif ev == "/":
            self.shorttag(el)
        elif ev == "<":
            self.endtag(el)
        elif ev == "":
            self.text(el)
        elif ev == "!":
            self.rawcode(el)
        elif ev == "^":
            self.start(el)
        elif ev == "$":
            self.end(el)
        ret = bytes(self.buf)
        self.buf.clear()
        return ret

    def nsname(self, el):
        """Choose a prefix for the namespace of *el*.

        A non-``None`` entry in a ``defns`` mapping found on any class in
        this instance's MRO wins.  Otherwise the ``None`` namespace gets
        ``None``, and any other namespace gets a freshly generated
        ``"nN"`` prefix.
        """
        for t in type(self).__mro__:
            ret = getattr(t, "defns", {}).get(el.ns, None)
            if ret is not None:
                return ret
        if el.ns is None:
            return None
        ret = "n" + str(self.nextns)
        self.nextns += 1
        return ret

    def findnsnames(self, root):
        """Populate ``self.nsnames`` from the namespaces used under *root*.

        Each distinct namespace on a ``cons.element`` gets a prefix from
        `nsname`.  If no namespace ended up with the ``None`` prefix, the
        namespace of *root* is switched to it.
        """
        fnames = {}   # namespace -> prefix
        rnames = {}   # prefix -> namespace
        def proc(el):
            if isinstance(el, cons.element):
                if el.ns not in fnames:
                    nm = self.nsname(el)
                    fnames[el.ns] = nm
                    rnames[nm] = el.ns
                for ch in el.children:
                    proc(ch)
        proc(root)
        if None not in rnames:
            fnames[root.ns] = None
            rnames[None] = root.ns
        self.nsnames = fnames

    @classmethod
    def output(cls, out, root, nsnames=None, doctype=None, **kw):
        """Write *root* as a complete document to *out*.

        out -- object with a ``write`` method accepting bytes.
        root -- root of the tree to serialize.
        nsnames -- namespace table; computed from the tree when ``None``.
        doctype -- ``None``, a ``cons.doctype``, or an indexable whose
            first two items are passed to ``cons.doctype`` after
            ``root.name``.
        kw -- forwarded to the constructor.
        """
        if isinstance(doctype, cons.doctype):
            pass
        elif doctype is not None:
            doctype = cons.doctype(root.name, doctype[0], doctype[1])
        # Prepend the prologue event to the flattened tree.
        src = itertools.chain(iter([("^", doctype)]), flatiter(root))
        self = cls(src=src, nsnames=nsnames, **kw)
        if nsnames is None:
            self.findnsnames(root)
        self.first = True
        for piece in self:
            out.write(piece)

    @classmethod
    def fragment(cls, out, root, nsnames=None, **kw):
        """Write *root* to *out* without a prologue or namespace declarations.

        Arguments are as for `output`, minus *doctype*.
        """
        self = cls(src=flatiter(root), nsnames=nsnames, **kw)
        if nsnames is None:
            self.findnsnames(root)
        for piece in self:
            out.write(piece)

    @classmethod
    def format(cls, root, **kw):
        """Return *root* serialized as a complete document, as bytes.

        Keyword arguments are forwarded to `output`.
        """
        buf = io.BytesIO()
        cls.output(buf, root, **kw)
        return buf.getvalue()
218
class indenter(formatter):
    """Formatter that places block-level tags on their own indented lines.

    An element is treated as inline when it has a ``cons.text`` child or
    when it sits inside an inline element; no line breaks are inserted
    around tags while in inline mode.
    """

    def __init__(self, indent="  ", *args, **kw):
        """Set up an indenter.

        indent -- string added to the indentation for each nesting level.
        All remaining arguments are passed to the base constructor.
        """
        super().__init__(*args, **kw)
        self.indent = indent
        self.col = 0          # characters written since the last line start
        self.curind = ""      # indentation string used by br()
        self.atbreak = True   # True when nothing follows the last br()
        self.inline = False   # True while inside an inline element
        self.stack = []       # saved (element, curind, inline) per open tag

    def write(self, text):
        """Buffer *text* encoded, tracking the column across newlines."""
        *finished, tail = text.split("\n")
        for line in finished:
            self.buf.extend(line.encode(self.charset))
            self.buf.extend(b"\n")
        if finished:
            # At least one newline was written, so the column restarts.
            self.col = 0
        self.buf.extend(tail.encode(self.charset))
        self.col += len(tail)
        self.atbreak = False

    def br(self):
        """Start a new indented line unless already positioned at one."""
        if self.atbreak:
            return
        self.buf.extend(("\n" + self.curind).encode(self.charset))
        self.col = 0
        self.atbreak = True

    def inlinep(self, el):
        """Return True if *el* has at least one ``cons.text`` child."""
        return any(isinstance(child, cons.text) for child in el.children)

    def push(self, el):
        """Save the current indentation state together with *el*."""
        self.stack.append((el, self.curind, self.inline))

    def pop(self):
        """Restore the most recently saved state and return its element."""
        el, self.curind, self.inline = self.stack.pop()
        return el

    def starttag(self, el):
        """Write the start tag of *el*, on a new line unless inline."""
        if not self.inline:
            self.br()
        self.push(el)
        # Inline mode is inherited from the parent; otherwise it is
        # decided by the element's own children.
        if not self.inline:
            self.inline = self.inlinep(el)
        self.curind += self.indent
        super().starttag(el)

    def shorttag(self, el):
        """Write *el* as a self-closing tag, on a new line unless inline."""
        if not self.inline:
            self.br()
        super().shorttag(el)

    def endtag(self, el):
        """Write the end tag of *el*, on a new line unless *el* was inline."""
        # The decision uses the mode that was active inside the element,
        # so it has to be read before the outer state is restored.
        was_inline = self.inline
        self.pop()
        if not was_inline:
            self.br()
        super().endtag(el)

    def start(self, el):
        """Write the prologue and note that output is at a line start."""
        super().start(el)
        self.atbreak = True

    def end(self, el):
        """Finish the output with a line break if one is pending."""
        self.br()
285
class textindenter(indenter):
    """Indenter that also wraps text nodes at whitespace.

    ``maxcol`` is the column limit compared against ``self.col``; text
    that would pass it is broken at whitespace and continued on a new
    indented line.
    """
    maxcol = 70

    def text(self, el):
        """Write the text node *el*, breaking lines at whitespace.

        The break point is the last whitespace at or before the column
        limit; if there is none, the first whitespace after it.  A run
        with no whitespace at all is written unbroken.  Whitespace at a
        break point is dropped.
        """
        left = str(el)
        while True:
            # An empty remainder is written as-is.  Without this check an
            # empty text node reached with self.col already past maxcol
            # would index left[0] below and raise IndexError.
            if not left or len(left) + self.col <= self.maxcol:
                self.quotewrite(left)
                break
            # Index of the last character that still fits on this line.
            bp = max(self.maxcol - self.col, 0)
            for i in range(bp, -1, -1):
                if left[i].isspace():
                    # Back up to the start of the whitespace run.
                    while i > 0 and left[i - 1].isspace():
                        i -= 1
                    break
            else:
                # Nothing to break at before the limit; accept an
                # overlong line and break at the next whitespace.
                for i in range(bp + 1, len(left)):
                    if left[i].isspace():
                        break
                else:
                    i = None
            if i is None:
                self.quotewrite(left)
                break
            self.quotewrite(left[:i])
            self.br()
            left = left[i + 1:].lstrip()
314
class response(dispatch.restart):
    """Restart object that answers a request with a serialized element tree.

    Class attributes:
      formatter -- class whose ``format`` classmethod serializes the tree.
      doctype -- passed to the formatter as ``doctype``.
      charset -- passed to the formatter as ``charset``.

    Subclasses must override `ctype`.
    """
    formatter = indenter
    doctype = None
    charset = "utf-8"

    def __init__(self, root):
        """Remember *root*, the tree to serialize when handled."""
        super().__init__()
        self.root = root

    @property
    def ctype(self):
        """Value for the Content-Type header; always raises in this base class."""
        raise Exception("a subclass of wrw.sp.util.response must override ctype")

    def handle(self, req):
        """Serialize the tree, set the response headers on *req*, return the body.

        The body is returned as a one-element list holding the bytes
        produced by the formatter.
        """
        payload = self.formatter.format(
            self.root,
            doctype=self.doctype,
            charset=self.charset,
        )
        req.ohead["Content-Type"] = self.ctype
        req.ohead["Content-Length"] = len(payload)
        return [payload]