Improved indentation a bit further.
[wrw.git] / wrw / sp / util.py
1 import itertools, StringIO
2 from wrw import dispatch
3 import cons
4
def findnsnames(el):
    """Build a mapping from namespace to prefix for the tree rooted at el.

    Every distinct namespace found on a cons.element in the tree is
    given a generated prefix u"n1", u"n2", ... in document (pre-order)
    order.  Afterwards one namespace is made the default by mapping it
    to None: the null namespace if it occurs in the tree, otherwise the
    namespace of `el` itself.
    """
    names = {}
    # Explicit stack instead of recursion; children are pushed in
    # reverse so that they are popped in document order.
    pending = [el]
    while pending:
        node = pending.pop()
        if not isinstance(node, cons.element):
            continue
        if node.ns not in names:
            # One prefix per namespace, numbered from 1 in order of
            # first appearance.
            names[node.ns] = u"n" + unicode(len(names) + 1)
        pending.extend(reversed(list(node.children)))
    # el.ns is only consulted when the null namespace is not in use.
    default = None if None in names else el.ns
    names[default] = None
    return names
21
def flatiter(root, short=True):
    """Flatten an element tree into a stream of (event, node) pairs.

    The events yielded are:
      ">"  start of an element
      "<"  end of an element
      "/"  an element without children (only when `short` is true)
      ""   a cons.text node
      "!"  a cons.raw node

    The root is always reported with a ">"/"<" pair, even when it has
    no children.  The tree is walked iteratively, so its depth is not
    limited by the recursion limit.

    Raises Exception if a child is neither a cons.element, a cons.text
    nor a cons.raw.
    """
    yield ">", root
    # Each stack entry is (element, index of its next unvisited child).
    stack = [(root, 0)]
    while stack:
        el, i = stack[-1]
        if i >= len(el.children):
            yield "<", el
            stack.pop()
            continue
        ch = el.children[i]
        stack[-1] = el, i + 1
        if isinstance(ch, cons.element):
            if short and len(ch.children) == 0:
                yield "/", ch
            else:
                yield ">", ch
                stack.append((ch, 0))
        elif isinstance(ch, cons.text):
            yield "", ch
        elif isinstance(ch, cons.raw):
            yield "!", ch
        else:
            # Report the offending child itself; %r is used because it
            # may be an object of any type.
            raise Exception("Unknown object in element tree: %r" % (ch,))
45
class formatter(object):
    """Serialize a stream of (event, node) pairs into encoded XML.

    `src` is an iterator of events such as the one produced by
    flatiter().  The formatter is itself an iterator: every step
    consumes one event from `src` and returns the bytes that event
    produced, encoded in `charset`.

    `nsnames` maps each namespace to the prefix to use for it, where
    the prefix None denotes the default namespace.  Subclasses may
    define a class attribute `defns`, a dictionary from namespace to
    preferred prefix, which is consulted by nsname().
    """

    def __init__(self, src, nsnames=None, charset="utf-8"):
        self.src = src
        self.nsnames = nsnames or {}
        # Counter for the generated "n<k>" prefixes, see nsname().
        self.nextns = 1
        # True while the next tag written should carry the xmlns
        # declarations.
        self.first = False
        self.buf = bytearray()
        self.charset = charset

    def write(self, text):
        """Encode text and append it to the output buffer."""
        self.buf.extend(text.encode(self.charset))

    def quotewrite(self, buf):
        """Write buf as character data, escaping &, < and >."""
        # The ampersand has to be replaced first, so that the entities
        # inserted below are not escaped again.
        buf = buf.replace(u'&', u"&amp;")
        buf = buf.replace(u'<', u"&lt;")
        buf = buf.replace(u'>', u"&gt;")
        self.write(buf)

    def __iter__(self):
        return self

    def elname(self, el):
        """Return the name of el, qualified with its prefix, if any."""
        ns = self.nsnames[el.ns]
        if ns is None:
            return el.name
        else:
            return ns + u":" + el.name

    def attrval(self, v):
        """Write v as a quoted and escaped attribute value."""
        # Prefer double quotes, but switch to single quotes if the
        # value contains double quotes.  Only the chosen quote
        # character then needs to be escaped.
        qc, qt = (u"'", u"&apos;") if u'"' in v else (u'"', u"&quot;")
        self.write(qc)
        v = v.replace(u'&', u"&amp;")
        v = v.replace(u'<', u"&lt;")
        v = v.replace(u'>', u"&gt;")
        v = v.replace(qc, qt)
        self.write(v)
        self.write(qc)

    def attr(self, k, v):
        """Write a single attribute as k=<quoted v>."""
        self.write(k)
        self.write(u"=")
        self.attrval(v)

    def attrs(self, attrs):
        """Write each (name, value) pair in attrs, preceded by a space."""
        for k, v in attrs:
            self.write(u" ")
            self.attr(k, v)

    def inittag(self, el):
        """Write the opening of a tag for el: "<", its name and attributes.

        For the first tag of a document, the namespace declarations
        from self.nsnames are written after the element's own
        attributes.

        Raises Exception if the null namespace has been given a prefix.
        """
        self.write(u"<" + self.elname(el))
        attrs = el.attrs.iteritems()
        if self.first:
            nsnames = []
            for ns, name in self.nsnames.iteritems():
                if ns is None:
                    # The null namespace cannot be declared, so it
                    # can only ever be the default one.
                    if name is not None:
                        raise Exception("null namespace must have null name, not " + name)
                    continue
                nsnames.append((u"xmlns" if name is None else (u"xmlns:" + name), ns))
            attrs = itertools.chain(attrs, iter(nsnames))
            self.first = False
        self.attrs(attrs)

    def starttag(self, el):
        """Write the start tag of el."""
        self.inittag(el)
        self.write(u">")

    def shorttag(self, el):
        """Write el as a self-closing tag."""
        self.inittag(el)
        self.write(u" />")

    def endtag(self, el):
        """Write the end tag of el."""
        self.write(u"</" + self.elname(el) + u">")

    def text(self, el):
        """Write a text node, escaped as character data."""
        self.quotewrite(el)

    def rawcode(self, el):
        """Write a raw node verbatim, without any escaping."""
        self.write(el)

    def start(self, el):
        """Write the XML declaration and, if el is a cons.doctype, the
        doctype declaration.

        Also arranges for the next tag to carry the namespace
        declarations.
        """
        self.write(u'<?xml version="1.0" encoding="' + self.charset + u'" ?>\n')
        if isinstance(el, cons.doctype):
            self.write(u'<!DOCTYPE %s PUBLIC "%s" "%s">\n' % (el.rootname,
                                                              el.pubid,
                                                              el.dtdid))
        self.first = True

    def end(self, el):
        """Called when the event stream is exhausted; writes nothing."""
        pass

    def handle(self, ev, el):
        """Dispatch one event to the method that writes it.

        Apart from the events of flatiter(), "^" and "$" mark the start
        and the end of the document.  Unknown events are ignored.
        """
        if ev == ">":
            self.starttag(el)
        elif ev == "/":
            self.shorttag(el)
        elif ev == "<":
            self.endtag(el)
        elif ev == "":
            self.text(el)
        elif ev == "!":
            self.rawcode(el)
        elif ev == "^":
            self.start(el)
        elif ev == "$":
            self.end(el)

    def next(self):
        """Process the next event and return the bytes written for it.

        When the source is exhausted, a final "$" event is processed
        and its output returned; the call after that raises
        StopIteration.
        """
        if self.src is None:
            raise StopIteration()
        try:
            ev, el = next(self.src)
        except StopIteration:
            self.src = None
            ev, el = "$", None
        self.handle(ev, el)
        # bytes() and del behave the same as str() and assigning an
        # empty string on Python 2, but do not depend on bytes and
        # text strings being interchangeable.
        ret = bytes(self.buf)
        del self.buf[:]
        return ret

    # Name of the same method in the newer iterator protocol.
    __next__ = next

    def nsname(self, el):
        """Choose a prefix for the namespace of el.

        A prefix listed for the namespace in the `defns` attribute of
        this class or any of its base classes takes precedence.
        Otherwise the null namespace gets no prefix, and any other
        namespace gets a newly generated one of the form "n<k>".
        """
        for t in type(self).__mro__:
            ret = getattr(t, "defns", {}).get(el.ns, None)
            if ret is not None:
                return ret
        if el.ns is None:
            return None
        ret = u"n" + unicode(self.nextns)
        self.nextns += 1
        return ret

    def findnsnames(self, root):
        """Set self.nsnames to cover every namespace in the tree at root.

        Prefixes are chosen by nsname().  If no namespace ended up as
        the default one, the namespace of root is made the default.
        """
        fnames = {}
        # Reverse mapping, from prefix to namespace.  Only used to
        # tell whether the default prefix has been taken.
        rnames = {}
        def proc(el):
            if isinstance(el, cons.element):
                if el.ns not in fnames:
                    nm = self.nsname(el)
                    fnames[el.ns] = nm
                    rnames[nm] = el.ns
                for ch in el.children:
                    proc(ch)
        proc(root)
        if None not in rnames:
            fnames[root.ns] = None
            rnames[None] = root.ns
        self.nsnames = fnames

    @classmethod
    def output(cls, out, root, nsnames=None, doctype=None, **kw):
        """Write the tree at root to out as a complete XML document.

        `out` only needs a write() method.  `doctype` is either a
        cons.doctype, or a sequence whose first two items are passed to
        cons.doctype() after the name of root (presumably the public
        and system identifiers -- verify against cons.doctype).  If
        `nsnames` is not given, it is derived from the tree.  Remaining
        keyword arguments are passed to the constructor.
        """
        if isinstance(doctype, cons.doctype):
            pass
        elif doctype is not None:
            doctype = cons.doctype(root.name, doctype[0], doctype[1])
        src = itertools.chain(iter([("^", doctype)]), flatiter(root))
        self = cls(src=src, nsnames=nsnames, **kw)
        if nsnames is None:
            self.findnsnames(root)
        self.first = True
        for piece in self:
            out.write(piece)

    @classmethod
    def fragment(cls, out, root, nsnames=None, **kw):
        """Write the tree at root to out without any document prologue.

        No "^" event is generated, so neither the XML declaration nor
        any namespace declarations are written.
        """
        self = cls(src=flatiter(root), nsnames=nsnames, **kw)
        if nsnames is None:
            self.findnsnames(root)
        for piece in self:
            out.write(piece)

    @classmethod
    def format(cls, root, **kw):
        """Return the tree at root as a complete, encoded XML document.

        Keyword arguments are passed on to output().
        """
        buf = StringIO.StringIO()
        cls.output(buf, root, **kw)
        return buf.getvalue()
221
class indenter(formatter):
    """Formatter that breaks lines between elements and indents them.

    `indent` is the string added to the indentation for each level of
    nesting.  Elements that have text among their direct children, and
    everything inside them, are written inline, without added breaks.
    """

    def __init__(self, indent=u"  ", *args, **kw):
        super(indenter, self).__init__(*args, **kw)
        self.indent = indent
        # Number of characters written since the last line break.
        self.col = 0
        # Indentation string for the current nesting level.
        self.curind = u""
        # True if nothing has been written since the last br().
        self.atbreak = True
        # True while inside an element that is written inline.
        self.inline = False
        # Saved (element, curind, inline) for each open element.
        self.stack = []
        # The previously handled (event, node) pair.
        self.last = None, None
        # Whether the latest end tag was put on a line of its own.
        self.lastendbr = True

    def write(self, text):
        """Write text and keep track of the current column."""
        # Let the base class encode the text as a whole, so that line
        # breaks are encoded in the output charset like everything else.
        super(indenter, self).write(text)
        head, nl, tail = text.rpartition(u"\n")
        if nl:
            self.col = len(tail)
        else:
            self.col += len(text)
        self.atbreak = False

    def br(self):
        """Start a new, indented line, unless already at the start of one."""
        if not self.atbreak:
            self.buf.extend((u"\n" + self.curind).encode(self.charset))
            # Note that the indentation just written is not counted
            # in the column.
            self.col = 0
            self.atbreak = True

    def inlinep(self, el):
        """Return True if el has a text node among its direct children."""
        return any(isinstance(ch, cons.text) for ch in el.children)

    def push(self, el):
        """Save the layout state in effect when entering el."""
        self.stack.append((el, self.curind, self.inline))

    def pop(self):
        """Restore the layout state saved by the latest push().

        Returns the element that was pushed.
        """
        el, self.curind, self.inline = self.stack.pop()
        return el

    def starttag(self, el):
        """Write the start tag of el and enter a new nesting level."""
        if not self.inline:
            # Keep the tag on the same line as an end tag of an
            # element of the same name, if that end tag was put on a
            # line of its own.
            if self.last[0] == "<" and self.last[1].name == el.name and self.lastendbr:
                pass
            else:
                self.br()
        self.push(el)
        self.inline = self.inline or self.inlinep(el)
        self.curind += self.indent
        super(indenter, self).starttag(el)

    def shorttag(self, el):
        """Write el as a self-closing tag, on a new line unless inline."""
        if not self.inline:
            self.br()
        super(indenter, self).shorttag(el)

    def endtag(self, el):
        """Write the end tag of el and leave its nesting level."""
        # Whether the contents of el were inline has to be checked
        # before pop() restores the state of the parent.
        il = self.inline
        self.pop()
        # An element whose start tag was just written has no contents,
        # and gets its end tag right after it.
        if il or (self.last[0] == ">" and self.last[1] == el):
            self.lastendbr = False
        else:
            self.br()
            self.lastendbr = True
        super(indenter, self).endtag(el)

    def start(self, el):
        """Write the document prologue."""
        super(indenter, self).start(el)
        # The prologue ends with a line break of its own.
        self.atbreak = True

    def end(self, el):
        """Finish the document with a line break."""
        self.br()

    def handle(self, ev, el):
        """Dispatch an event, and remember it for the following one."""
        super(indenter, self).handle(ev, el)
        self.last = ev, el
300
class textindenter(indenter):
    """Indenter that also wraps text nodes at whitespace.

    Text is broken so that lines do not grow past `maxcol` columns as
    counted by self.col, when there is whitespace to break at.
    """
    maxcol = 70

    def _breakpoint(self, left):
        """Return the index in left to break the line at, or None.

        The last run of whitespace starting within the remaining
        width is preferred.  If there is none, the first whitespace
        after it is used, and if there is none at all, None is
        returned.
        """
        bp = max(self.maxcol - self.col, 0)
        # bp lies outside of the string only when the string is empty
        # and the line is already too long; there is nothing to scan
        # then.
        for i in xrange(min(bp, len(left) - 1), -1, -1):
            if left[i].isspace():
                # Back up to the beginning of the run of whitespace.
                while i > 0 and left[i - 1].isspace():
                    i -= 1
                return i
        for i in xrange(bp + 1, len(left)):
            if left[i].isspace():
                return i
        return None

    def text(self, el):
        """Write a text node, breaking lines at whitespace."""
        left = unicode(el)
        while len(left) + self.col > self.maxcol:
            i = self._breakpoint(left)
            if i is None:
                # Nowhere to break; write the rest as it is.
                break
            self.quotewrite(left[:i])
            self.br()
            # The whitespace at the break is replaced by the break.
            left = left[i + 1:].lstrip()
        self.quotewrite(left)
329
class response(dispatch.restart):
    """Restart that responds with an element tree formatted as XML.

    Subclasses must override `ctype` with the content type to send,
    and may override `charset`, `doctype` and `formatter`, all of which
    are used when formatting the tree.
    """
    charset = "utf-8"
    doctype = None
    formatter = indenter

    def __init__(self, root):
        super(response, self).__init__()
        self.root = root

    @property
    def ctype(self):
        raise Exception("a subclass of wrw.sp.util.response must override ctype")

    def handle(self, req):
        """Format the tree, set the entity headers on req and return
        the body as a list with a single string.
        """
        ret = self.formatter.format(self.root, doctype=self.doctype, charset=self.charset)
        req.ohead["Content-Type"] = self.ctype
        # Header values must be strings, not integers.
        req.ohead["Content-Length"] = str(len(ret))
        return [ret]