Merge branch 'master' into python3
[wrw.git] / wrw / sp / util.py
1 import itertools, io
2 from .. import dispatch
3 from . import cons
4
def findnsnames(el):
    """Build a namespace -> prefix mapping for the tree rooted at *el*.

    Every distinct ``ns`` found on a ``cons.element`` in the tree is first
    given a generated prefix ("n1", "n2", ...) in depth-first, parent-first
    order.  Afterwards exactly one namespace is remapped to the ``None``
    prefix: the null namespace if it occurs in the tree, otherwise the
    namespace of *el* itself.

    Returns the resulting dict.
    """
    prefixes = {}
    serial = itertools.count(1)

    def visit(node):
        # Anything that is not an element carries no namespace and has no
        # children that are walked.
        if not isinstance(node, cons.element):
            return
        if node.ns not in prefixes:
            prefixes[node.ns] = "n" + str(next(serial))
        for child in node.children:
            visit(child)

    visit(el)
    # Pick the namespace that gets the None prefix.
    unprefixed = None if None in prefixes else el.ns
    prefixes[unprefixed] = None
    return prefixes
21
def flatiter(root, short=True):
    """Flatten the tree under *root* into a stream of (event, node) pairs.

    The traversal is iterative (explicit stack) and depth-first.  Events:

      ">"  an element is opened (always emitted for *root* itself)
      "<"  a previously opened element is closed
      "/"  a child element with no children, emitted as a single event
           instead of a ">"/"<" pair when *short* is true
      ""   a ``cons.text`` child
      "!"  a ``cons.raw`` child

    Raises Exception if a child is none of ``cons.element``, ``cons.text``
    or ``cons.raw``.
    """
    yield ">", root
    # Each stack entry is (element, index of the next child to visit).
    stack = [(root, 0)]
    while stack:
        el, i = stack[-1]
        if i >= len(el.children):
            yield "<", el
            stack.pop()
            continue
        ch = el.children[i]
        stack[-1] = el, i + 1
        if isinstance(ch, cons.element):
            if short and len(ch.children) == 0:
                yield "/", ch
            else:
                yield ">", ch
                stack.append((ch, 0))
        elif isinstance(ch, cons.text):
            yield "", ch
        elif isinstance(ch, cons.raw):
            yield "!", ch
        else:
            # Report the offending child, and format it with repr() so that
            # building the message cannot itself fail with a TypeError.
            raise Exception("Unknown object in element tree: " + repr(ch))
45
class formatter(object):
    """Iterator that serializes (event, node) pairs into encoded XML chunks.

    *src* is an iterator of (event, node) pairs; each call to ``__next__``
    consumes one pair, dispatches it through ``handle`` and returns the bytes
    that were written for it.  Once *src* is exhausted one final "$" event is
    handled before iteration stops.

    *nsnames* maps namespace -> prefix, where a ``None`` prefix means the
    element name is written without prefix.  *charset* is the encoding used
    for all output and is also written into the XML declaration.

    Subclasses may define a ``defns`` class attribute (namespace -> prefix
    dict) which ``nsname`` consults before generating a prefix.
    """

    def __init__(self, src, nsnames=None, charset="utf-8"):
        self.src = src
        self.nsnames = nsnames or {}
        self.nextns = 1            # serial for generated "n<k>" prefixes
        self.first = False         # True until the first tag has been written
        self.buf = bytearray()     # output collected for the current event
        self.charset = charset

    def write(self, text):
        """Encode *text* with the charset and append it to the buffer."""
        self.buf.extend(text.encode(self.charset))

    def quotewrite(self, buf):
        """Write *buf* with ``&``, ``<`` and ``>`` replaced by entities."""
        # '&' must be replaced first so the entities added below survive.
        buf = buf.replace('&', "&amp;")
        buf = buf.replace('<', "&lt;")
        buf = buf.replace('>', "&gt;")
        self.write(buf)

    def __iter__(self):
        return self

    def elname(self, el):
        """Return the (possibly prefixed) tag name of *el*.

        Raises KeyError if the namespace of *el* is not in ``self.nsnames``.
        """
        ns = self.nsnames[el.ns]
        if ns is None:
            return el.name
        else:
            return ns + ":" + el.name

    def attrval(self, v):
        """Write *v* as a quoted, escaped attribute value.

        Single quotes are used as delimiter when *v* contains a double
        quote, double quotes otherwise; occurrences of the chosen delimiter
        inside *v* are replaced by the matching entity.
        """
        qc, qt = ("'", "&apos;") if '"' in v else ('"', "&quot;")
        self.write(qc)
        v = v.replace('&', "&amp;")
        v = v.replace('<', "&lt;")
        v = v.replace('>', "&gt;")
        v = v.replace(qc, qt)
        self.write(v)
        self.write(qc)

    def attr(self, k, v):
        """Write a single ``k=<quoted v>`` attribute."""
        self.write(k)
        self.write("=")
        self.attrval(v)

    def attrs(self, attrs):
        """Write every (key, value) pair in *attrs*, each preceded by a space."""
        for k, v in attrs:
            self.write(" ")
            self.attr(k, v)

    def inittag(self, el):
        """Write ``<name`` and the attributes of *el*, without closing the tag.

        While ``self.first`` is set, ``xmlns``/``xmlns:prefix`` declarations
        for every non-null namespace in ``self.nsnames`` are appended to the
        attributes and the flag is cleared.

        Raises Exception if the null namespace is mapped to a prefix.
        """
        self.write("<" + self.elname(el))
        attrs = el.attrs.items()
        if self.first:
            nsnames = []
            for ns, name in self.nsnames.items():
                if ns is None:
                    if name is not None:
                        raise Exception("null namespace must have null name, not " + name)
                    # The null namespace needs no declaration.
                    continue
                nsnames.append(("xmlns" if name is None else ("xmlns:" + name), ns))
            attrs = itertools.chain(attrs, iter(nsnames))
            self.first = False
        self.attrs(attrs)

    def starttag(self, el):
        """Write the opening tag of *el*."""
        self.inittag(el)
        self.write(">")

    def shorttag(self, el):
        """Write *el* as a self-closing tag."""
        self.inittag(el)
        self.write(" />")

    def endtag(self, el):
        """Write the closing tag of *el*."""
        self.write("</" + self.elname(el) + ">")

    def text(self, el):
        """Write the text node *el* with markup characters escaped."""
        self.quotewrite(el)

    def rawcode(self, el):
        """Write the raw node *el* verbatim, without escaping."""
        self.write(el)

    def start(self, el):
        """Write the XML declaration and, if *el* is a doctype, the DOCTYPE.

        Also arms ``self.first`` so that the next tag carries the namespace
        declarations.
        """
        self.write('<?xml version="1.0" encoding="' + self.charset + '" ?>\n')
        if isinstance(el, cons.doctype):
            self.write('<!DOCTYPE %s PUBLIC "%s" "%s">\n' % (el.rootname,
                                                              el.pubid,
                                                              el.dtdid))
        self.first = True

    def end(self, el):
        """Hook called for the final "$" event; does nothing here."""
        pass

    def handle(self, ev, el):
        """Dispatch event *ev* for node *el* to the matching writer method.

        ">" starttag, "/" shorttag, "<" endtag, "" text, "!" rawcode,
        "^" start, "$" end.  Any other event code is ignored.
        """
        if ev == ">":
            self.starttag(el)
        elif ev == "/":
            self.shorttag(el)
        elif ev == "<":
            self.endtag(el)
        elif ev == "":
            self.text(el)
        elif ev == "!":
            self.rawcode(el)
        elif ev == "^":
            self.start(el)
        elif ev == "$":
            self.end(el)

    def __next__(self):
        """Handle the next source event and return the bytes it produced."""
        if self.src is None:
            raise StopIteration()
        try:
            ev, el = next(self.src)
        except StopIteration:
            # Source exhausted: emit one terminating "$" event, after which
            # the next call stops the iteration.
            self.src = None
            ev, el = "$", None
        self.handle(ev, el)
        ret = bytes(self.buf)
        self.buf[:] = b""
        return ret

    def nsname(self, el):
        """Choose a prefix for the namespace of *el*.

        The ``defns`` dicts of the classes in this instance's MRO are
        searched first.  Failing that, the null namespace gets ``None`` and
        any other namespace a freshly generated "n<k>" prefix.
        """
        for t in type(self).__mro__:
            ret = getattr(t, "defns", {}).get(el.ns, None)
            if ret is not None:
                return ret
        if el.ns is None:
            return None
        ret = "n" + str(self.nextns)
        self.nextns += 1
        return ret

    def findnsnames(self, root):
        """Assign prefixes to all namespaces under *root*.

        The result replaces ``self.nsnames``.  If no namespace ended up with
        the ``None`` prefix, the namespace of *root* is remapped to it.
        """
        fnames = {}     # namespace -> prefix
        rnames = {}     # prefix -> namespace
        def proc(el):
            if isinstance(el, cons.element):
                if el.ns not in fnames:
                    nm = self.nsname(el)
                    fnames[el.ns] = nm
                    rnames[nm] = el.ns
                for ch in el.children:
                    proc(ch)
        proc(root)
        if None not in rnames:
            fnames[root.ns] = None
            rnames[None] = root.ns
        self.nsnames = fnames

    @classmethod
    def output(cls, out, root, nsnames=None, doctype=None, **kw):
        """Serialize *root* as a complete document to *out*.

        *out* needs a ``write`` method accepting bytes.  *doctype* may be
        ``None``, a ``cons.doctype``, or an indexable whose items 0 and 1 are
        passed to ``cons.doctype`` together with ``root.name``.  When
        *nsnames* is ``None`` the prefixes are derived from the tree.
        Remaining keyword arguments go to the constructor.
        """
        if isinstance(doctype, cons.doctype):
            pass
        elif doctype is not None:
            doctype = cons.doctype(root.name, doctype[0], doctype[1])
        # "^" makes the formatter write the declaration before the tree.
        src = itertools.chain(iter([("^", doctype)]), flatiter(root))
        self = cls(src=src, nsnames=nsnames, **kw)
        if nsnames is None:
            self.findnsnames(root)
        self.first = True
        for piece in self:
            out.write(piece)

    @classmethod
    def fragment(cls, out, root, nsnames=None, **kw):
        """Serialize *root* to *out* without an XML declaration.

        No "^" event is generated, so ``self.first`` stays unset and no
        namespace declarations are written either.
        """
        self = cls(src=flatiter(root), nsnames=nsnames, **kw)
        if nsnames is None:
            self.findnsnames(root)
        for piece in self:
            out.write(piece)

    @classmethod
    def format(cls, root, **kw):
        """Serialize *root* as a complete document and return the bytes."""
        buf = io.BytesIO()
        cls.output(buf, root, **kw)
        return buf.getvalue()
221
class indenter(formatter):
    """Formatter variant that puts tags on separate, indented lines.

    *indent* is the string added to the indentation for each nesting level.
    An element is treated as inline when it, or an enclosing element, has
    a ``cons.text`` child; no line breaks are inserted around tags while
    inline.
    """

    def __init__(self, indent="  ", *args, **kw):
        super().__init__(*args, **kw)
        self.indent = indent
        self.curind = ""           # indentation string of the current level
        self.col = 0               # characters written since the last newline
        self.atbreak = True        # nothing written since the last break
        self.inline = False
        self.stack = []            # saved (element, curind, inline) per open tag
        self.last = None, None     # most recently handled (event, node)

    def write(self, text):
        """Append *text* to the buffer while tracking the output column."""
        *complete, tail = text.split("\n")
        if complete:
            for line in complete:
                self.buf.extend(line.encode(self.charset))
                self.buf.extend(b"\n")
            # A newline was written, so counting starts over.
            self.col = 0
        self.buf.extend(tail.encode(self.charset))
        self.col += len(tail)
        self.atbreak = False

    def br(self):
        """Start a new indented line unless already positioned at one."""
        if self.atbreak:
            return
        self.buf.extend(("\n" + self.curind).encode(self.charset))
        self.col = 0
        self.atbreak = True

    def inlinep(self, el):
        """Return True if *el* has at least one ``cons.text`` child."""
        return any(isinstance(ch, cons.text) for ch in el.children)

    def push(self, el):
        """Save the indentation state that applies outside of *el*."""
        self.stack.append((el, self.curind, self.inline))

    def pop(self):
        """Restore the most recently saved state and return its element."""
        el, self.curind, self.inline = self.stack.pop()
        return el

    def starttag(self, el):
        prevev, prevel = self.last
        # Break before the tag, except when inline or when it directly
        # follows the end tag of an element with the same name.
        if not self.inline and not (prevev == "<" and prevel.name == el.name):
            self.br()
        self.push(el)
        self.inline = self.inline or self.inlinep(el)
        self.curind += self.indent
        super().starttag(el)

    def shorttag(self, el):
        if not self.inline:
            self.br()
        super().shorttag(el)

    def endtag(self, el):
        # Whether to break is decided by the state inside the element, i.e.
        # before the enclosing state is restored.
        wasinline = self.inline
        self.pop()
        if not wasinline:
            self.br()
        super().endtag(el)

    def start(self, el):
        super().start(el)
        # What start() wrote ends in a newline; count that as a break.
        self.atbreak = True

    def end(self, el):
        self.br()

    def handle(self, ev, el):
        """Handle the event, then remember it for the next ``starttag``."""
        super().handle(ev, el)
        self.last = ev, el
296
class textindenter(indenter):
    """Indenter that also wraps text nodes at ``maxcol`` columns."""

    maxcol = 70

    @staticmethod
    def _breakpos(text, bp):
        """Return the index in *text* at which to break a line, or None.

        The nearest whitespace at or before index *bp* is preferred, moved
        back to the start of its whitespace run; failing that, the first
        whitespace after *bp* is used.  None means *text* has no whitespace
        to break at.  *bp* must be a valid index into *text*.
        """
        for i in range(bp, -1, -1):
            if text[i].isspace():
                while i > 0 and text[i - 1].isspace():
                    i -= 1
                return i
        for i in range(bp + 1, len(text)):
            if text[i].isspace():
                return i
        return None

    def text(self, el):
        """Write the text node *el*, breaking lines that would pass maxcol.

        Lines are only broken at whitespace; a stretch without whitespace is
        written as is even if it is too long.
        """
        left = str(el)
        # The emptiness check matters when self.col is already past maxcol:
        # an empty string would otherwise be indexed at position 0 and raise
        # IndexError.
        while left and len(left) + self.col > self.maxcol:
            bp = max(self.maxcol - self.col, 0)
            i = self._breakpos(left, bp)
            if i is None:
                break
            self.quotewrite(left[:i])
            self.br()
            # Drop the whitespace broken at, and any that follows it.
            left = left[i + 1:].lstrip()
        self.quotewrite(left)
325
class response(dispatch.restart):
    """Restart that answers a request with a serialized element tree.

    Class attributes that subclasses may override:

      charset    encoding handed to the formatter
      doctype    doctype handed to the formatter (None for no DOCTYPE)
      formatter  formatter class whose ``format`` produces the body

    ``ctype`` has no default and must be overridden.
    """

    charset = "utf-8"
    doctype = None
    formatter = indenter

    def __init__(self, root):
        super().__init__()
        self.root = root

    @property
    def ctype(self):
        """Value for the Content-Type header; raises unless overridden."""
        raise Exception("a subclass of wrw.sp.util.response must override ctype")

    def handle(self, req):
        """Format the tree, set the content headers on *req*, return the body.

        The body is returned as a one-item list of bytes.
        """
        body = self.formatter.format(
            self.root,
            doctype=self.doctype,
            charset=self.charset,
        )
        headers = req.ohead
        headers["Content-Type"] = self.ctype
        # NOTE(review): the length is stored as an int, not a str -- confirm
        # that ohead or the layer emitting the headers stringifies it.
        headers["Content-Length"] = len(body)
        return [body]