3ea7a8aa91d6133321cf640da735ca87d5d1588c
[wrw.git] / wrw / sp / util.py
1 import StringIO
2 from wrw import dispatch
3 import cons
4
def findnsnames(el):
    """Build a namespace-URI -> prefix mapping for the tree rooted at `el`.

    Every distinct `ns` value found on a `cons.element` in the tree is
    first given a generated prefix (u"n1", u"n2", ...) in the order the
    namespaces are first encountered during a depth-first walk.  One
    namespace is then mapped to None, marking it as the default
    (prefix-less) namespace: the None namespace if any element uses it,
    otherwise the namespace of `el` itself.
    """
    names = {}

    def visit(node):
        # Only element nodes carry a namespace or children.
        if not isinstance(node, cons.element):
            return
        if node.ns not in names:
            # One prefix is handed out per new namespace, so the next
            # free number is always one more than the current count.
            names[node.ns] = u"n" + unicode(len(names) + 1)
        for child in node.children:
            visit(child)

    visit(el)
    default = None if None in names else el.ns
    names[default] = None
    return names
21
class formatter(object):
    """Serialize a `cons` node tree as XML onto a writable output object.

    All text is handled as unicode internally and encoded with `charset`
    just before being handed to `out.write`.
    """

    def __init__(self, out, root, nsnames=None, charset="utf-8", doctype=None):
        """Set up a formatter.

        out     -- object with a write() method receiving encoded text
        root    -- root node of the tree to serialize
        nsnames -- mapping from namespace URI to prefix (None meaning no
                   prefix); computed with findnsnames(root) when omitted
        charset -- encoding used for output and named in the XML
                   declaration
        doctype -- optional pair whose two items are inserted as the
                   public and system identifiers of a DOCTYPE declaration
        """
        self.root = root
        if nsnames is None:
            nsnames = findnsnames(root)
        self.nsnames = nsnames
        self.out = out
        self.charset = charset
        self.doctype = doctype

    def write(self, text):
        """Encode `text` with the configured charset and write it out."""
        self.out.write(text.encode(self.charset))

    def quotewrite(self, buf):
        """Write `buf` as character data, escaping `&`, `<` and `>`."""
        for ch in buf:
            if ch == u'&':
                # An unescaped ampersand would start an entity reference
                # and make the output malformed.
                self.write(u"&amp;")
            elif ch == u'<':
                self.write(u"&lt;")
            elif ch == u'>':
                self.write(u"&gt;")
            else:
                self.write(ch)

    def text(self, el):
        """Write a text node with markup characters escaped."""
        self.quotewrite(el)

    def rawcode(self, el):
        """Write a raw node verbatim, without any escaping."""
        self.write(el)

    def attrval(self, buf):
        """Write `buf` as a quoted, escaped attribute value.

        Double quotes are used as delimiters unless the value itself
        contains a double quote, in which case single quotes are used.
        Occurrences of the chosen delimiter inside the value are escaped.
        """
        qc, qt = (u"'", u"&apos;") if u'"' in buf else (u'"', u"&quot;")
        self.write(qc)
        for ch in buf:
            if ch == u'&':
                self.write(u"&amp;")
            elif ch == u'<':
                self.write(u"&lt;")
            elif ch == u'>':
                self.write(u"&gt;")
            elif ch == qc:
                self.write(qt)
            else:
                self.write(ch)
        self.write(qc)

    def attr(self, k, v):
        """Write a single `name=value` attribute (no leading space)."""
        self.write(k)
        self.write(u'=')
        self.attrval(v)

    def shorttag(self, el, **extra):
        """Write `el` as a self-closing tag.

        The element's own attributes are written first, followed by any
        attributes passed in `extra`.
        """
        self.write(u'<' + self.elname(el))
        for k, v in el.attrs.iteritems():
            self.write(u' ')
            self.attr(k, v)
        for k, v in extra.iteritems():
            self.write(u' ')
            self.attr(k, v)
        self.write(u" />")

    def elname(self, el):
        """Return the tag name of `el`, prefixed if its namespace has one."""
        ns = self.nsnames[el.ns]
        if ns is None:
            return el.name
        else:
            return ns + u':' + el.name

    def starttag(self, el, **extra):
        """Write the opening tag of `el`, including `extra` attributes."""
        self.write(u'<' + self.elname(el))
        for k, v in el.attrs.iteritems():
            self.write(u' ')
            self.attr(k, v)
        for k, v in extra.iteritems():
            self.write(u' ')
            self.attr(k, v)
        self.write(u'>')

    def endtag(self, el):
        """Write the closing tag of `el`."""
        self.write(u'</' + self.elname(el) + u'>')

    def longtag(self, el, **extra):
        """Write `el` as an opening tag, its children, and a closing tag."""
        self.starttag(el, **extra)
        for ch in el.children:
            self.node(ch)
        self.endtag(el)

    def element(self, el, **extra):
        """Write an element, self-closing when it has no children."""
        if len(el.children) == 0:
            self.shorttag(el, **extra)
        else:
            self.longtag(el, **extra)

    def node(self, el):
        """Write any tree node, dispatching on its `cons` type.

        Raises Exception if `el` is not an element, text or raw node.
        """
        if isinstance(el, cons.element):
            self.element(el)
        elif isinstance(el, cons.text):
            self.text(el)
        elif isinstance(el, cons.raw):
            self.rawcode(el)
        else:
            # %r formatting works for arbitrary objects; plain string
            # concatenation would itself fail with a TypeError here.
            raise Exception("Unknown object in element tree: %r" % (el,))

    def start(self):
        """Write a complete document: XML declaration, DOCTYPE and root.

        The namespace declarations from `nsnames` are added as extra
        attributes on the root element.
        """
        self.write(u'<?xml version="1.0" encoding="' + self.charset + u'" ?>\n')
        if self.doctype:
            self.write(u'<!DOCTYPE %s PUBLIC "%s" "%s">\n' % (self.root.name,
                                                              self.doctype[0],
                                                              self.doctype[1]))
        extra = {}
        for uri, nm in self.nsnames.iteritems():
            if uri is None:
                # The None namespace needs no declaration.
                continue
            if nm is None:
                extra[u"xmlns"] = uri
            else:
                extra[u"xmlns:" + nm] = uri
        self.element(self.root, **extra)

    @classmethod
    def output(cls, out, el, *args, **kw):
        """Write `el` to `out` as a complete document."""
        cls(out=out, root=el, *args, **kw).start()

    @classmethod
    def fragment(cls, out, el, *args, **kw):
        """Write `el` to `out` as a bare node, without the document prologue."""
        cls(out=out, root=el, *args, **kw).node(el)

    @classmethod
    def format(cls, el, *args, **kw):
        """Return `el` serialized as a complete document."""
        buf = StringIO.StringIO()
        cls.output(buf, el, *args, **kw)
        return buf.getvalue()

    def update(self, **ch):
        """Return a shallow copy of this formatter with attributes overridden.

        The copy is created without running __init__; its attributes are
        those of `self`, replaced by the keyword arguments in `ch`.
        """
        ret = type(self).__new__(type(self))
        ret.__dict__.update(self.__dict__)
        ret.__dict__.update(ch)
        return ret
160
class iwriter(object):
    """Output wrapper that tracks the current column for indentation.

    `col` counts the characters written since the last newline and
    `atbol` records that an indentation was just emitted and nothing has
    been written after it.
    """

    def __init__(self, out):
        self.out = out
        self.atbol = True
        self.col = 0

    def write(self, buf):
        """Pass `buf` on to the wrapped output one character at a time."""
        for char in buf:
            # A newline resets the column; anything else advances it.
            self.col = 0 if char == '\n' else self.col + 1
            self.out.write(char)
        self.atbol = False

    def indent(self, indent):
        """Start a fresh line prefixed by `indent`, unless already on one."""
        if not self.atbol:
            # Break the current line only if something is on it.
            if self.col != 0:
                self.write('\n')
            self.write(indent)
            # write() clears the flag, so it is set after the writes.
            self.atbol = True
183
class indenter(formatter):
    """Formatter variant that breaks and indents nested elements.

    The output object is wrapped in an `iwriter` so that the current
    column is known.  `indent` is the string added per nesting level and
    `curind` is the indentation of the level currently being written.
    """

    def __init__(self, indent=u"  ", *args, **kw):
        super(indenter, self).__init__(*args, **kw)
        self.out = iwriter(self.out)
        self.indent = indent
        self.curind = u""

    def simple(self, el):
        """Return True if every child of `el` is a text node."""
        return all(isinstance(child, cons.text) for child in el.children)

    def longtag(self, el, **extra):
        """Write `el` with children, indenting them unless all are text."""
        self.starttag(el, **extra)
        if self.simple(el):
            # Text-only content stays inline with the surrounding tags.
            for child in el.children:
                self.node(child)
        else:
            # Children are written by a copy that is one level deeper.
            inner = self.update(curind=self.curind + self.indent)
            inner.reindent()
            for child in el.children:
                inner.node(child)
            # Return to this level's indentation before the closing tag.
            self.reindent()
        self.endtag(el)

    def element(self, el, **extra):
        """Write an element, breaking the line after long simple ones."""
        super(indenter, self).element(el, **extra)
        overlong = self.out.col > 80
        if overlong and self.simple(el):
            self.reindent()

    def reindent(self):
        """Ask the writer to start a new line at the current indentation."""
        self.out.indent(self.curind.encode(self.charset))

    def start(self):
        """Write the complete document followed by a final newline."""
        super(indenter, self).start()
        self.write('\n')
222
class response(dispatch.restart):
    """Restart carrying an element tree to be sent as the response body.

    Subclasses must override `ctype`; they may also override `charset`,
    `doctype` and `formatter` to control how the tree is serialized.
    """
    charset = "utf-8"
    doctype = None
    formatter = indenter

    def __init__(self, root):
        super(response, self).__init__()
        self.root = root

    @property
    def ctype(self):
        """Value for the Content-Type header; must be overridden."""
        raise Exception("a subclass of wrw.sp.util.response must override ctype")

    def handle(self, req):
        """Set the Content-Type header and return the formatted document.

        The document is returned as a single-item list.
        """
        req.ohead["Content-Type"] = self.ctype
        body = self.formatter.format(self.root,
                                     doctype=self.doctype,
                                     charset=self.charset)
        return [body]