Merge branch 'master' into python3
[wrw.git] / wrw / sp / util.py
1 import io
2 from .. import dispatch
3 from . import cons
4
def findnsnames(el):
    """Build a namespace -> prefix mapping for the tree rooted at *el*.

    The tree is walked depth-first, parents before children; only
    `cons.element` nodes are inspected.  Each namespace gets a generated
    prefix "n1", "n2", ... in order of first appearance.  Afterwards one
    entry is reset to None (meaning "no prefix"): the None namespace if
    any element uses it, otherwise the namespace of *el* itself.  The
    generated number of the reset entry is not reused.

    Returns the dict mapping each namespace to its prefix string or None.
    """
    names = {}

    def visit(node):
        # Anything that is not an element carries no namespace.
        if not isinstance(node, cons.element):
            return
        if node.ns not in names:
            # One entry is added per new namespace, so the mapping's size
            # doubles as the running counter.
            names[node.ns] = "n" + str(len(names) + 1)
        for child in node.children:
            visit(child)

    visit(el)
    # Pick the namespace that is written without a prefix.
    unprefixed = None if None in names else el.ns
    names[unprefixed] = None
    return names
21
class formatter(object):
    """Serialize an element tree as XML onto a binary output stream.

    Parameters:
      out     -- object with a write() method accepting bytes
      root    -- root node of the tree to serialize
      nsnames -- mapping of namespace -> prefix (None meaning "no
                 prefix"); computed with findnsnames(root) when omitted
      charset -- encoding used for all output and named in the XML
                 declaration
      doctype -- optional pair (public id, system id) for the DOCTYPE
                 declaration written by start()
    """

    # Replacements applied to character data; attrval() extends this with
    # the quote character it selects.
    _textquotes = {'&': "&amp;", '<': "&lt;", '>': "&gt;"}

    def __init__(self, out, root, nsnames=None, charset="utf-8", doctype=None):
        self.root = root
        if nsnames is None:
            nsnames = findnsnames(root)
        self.nsnames = nsnames
        self.out = out
        self.charset = charset
        self.doctype = doctype

    def write(self, text):
        """Encode *text* with the configured charset and write it out."""
        self.out.write(text.encode(self.charset))

    def quotewrite(self, buf):
        """Write *buf* as character data, escaping '&', '<' and '>'."""
        quotes = self._textquotes
        quoted = "".join(quotes.get(ch, ch) for ch in buf)
        # Nothing is written for empty input, so that a writer tracking
        # line state (such as iwriter) is not touched by an empty write.
        if quoted:
            self.write(quoted)

    def text(self, el):
        """Write a text node with markup characters escaped."""
        self.quotewrite(el)

    def rawcode(self, el):
        """Write a raw node verbatim, without any escaping."""
        self.write(el)

    def attrval(self, buf):
        """Write *buf* as a quoted attribute value.

        Single quotes delimit the value when it contains a double quote,
        double quotes otherwise; occurrences of the chosen delimiter
        inside the value are escaped along with '&', '<' and '>'.
        """
        qc, qt = ("'", "&apos;") if '"' in buf else ('"', "&quot;")
        quotes = dict(self._textquotes)
        quotes[qc] = qt
        self.write(qc + "".join(quotes.get(ch, ch) for ch in buf) + qc)

    def attr(self, k, v):
        """Write a single attribute as k=<quoted v>."""
        self.write(k)
        self.write('=')
        self.attrval(v)

    def _opentag(self, el, extra):
        """Write '<name' and all attributes, leaving the tag unterminated."""
        self.write('<' + self.elname(el))
        # The element's own attributes come first, then the extra ones.
        for attrs in (el.attrs, extra):
            for k, v in attrs.items():
                self.write(' ')
                self.attr(k, v)

    def shorttag(self, el, **extra):
        """Write *el* as a self-closing tag with *extra* attributes added."""
        self._opentag(el, extra)
        self.write(" />")

    def elname(self, el):
        """Return the element name, qualified with its namespace prefix.

        Raises KeyError when the element's namespace is not in nsnames.
        """
        ns = self.nsnames[el.ns]
        if ns is None:
            return el.name
        else:
            return ns + ':' + el.name

    def starttag(self, el, **extra):
        """Write the opening tag of *el* with *extra* attributes added."""
        self._opentag(el, extra)
        self.write('>')

    def endtag(self, el):
        """Write the closing tag of *el*."""
        self.write('</' + self.elname(el) + '>')

    def longtag(self, el, **extra):
        """Write *el* as opening tag, all children, and closing tag."""
        self.starttag(el, **extra)
        for ch in el.children:
            self.node(ch)
        self.endtag(el)

    def element(self, el, **extra):
        """Write *el*, self-closing when it has no children."""
        if len(el.children) == 0:
            self.shorttag(el, **extra)
        else:
            self.longtag(el, **extra)

    def node(self, el):
        """Dispatch on the node type: element, text or raw.

        Raises Exception for any other kind of object.
        """
        if isinstance(el, cons.element):
            self.element(el)
        elif isinstance(el, cons.text):
            self.text(el)
        elif isinstance(el, cons.raw):
            self.rawcode(el)
        else:
            # %r, because concatenating an arbitrary object onto the
            # message would itself fail with TypeError.
            raise Exception("Unknown object in element tree: %r" % (el,))

    def start(self):
        """Write a complete document.

        Output consists of the XML declaration, the DOCTYPE declaration
        when a doctype was given, and the root element carrying an xmlns
        attribute for every non-None namespace in nsnames.
        """
        self.write('<?xml version="1.0" encoding="' + self.charset + '" ?>\n')
        if self.doctype:
            self.write('<!DOCTYPE %s PUBLIC "%s" "%s">\n' % (self.root.name,
                                                             self.doctype[0],
                                                             self.doctype[1]))
        extra = {}
        for uri, nm in self.nsnames.items():
            if uri is None:
                # The None namespace has no URI to declare.
                continue
            if nm is None:
                extra["xmlns"] = uri
            else:
                extra["xmlns:" + nm] = uri
        self.element(self.root, **extra)

    @classmethod
    def output(cls, out, el, *args, **kw):
        """Write *el* to *out* as a complete document.

        Remaining arguments are forwarded to the constructor alongside
        the out= and root= keywords.
        """
        cls(out=out, root=el, *args, **kw).start()

    @classmethod
    def fragment(cls, out, el, *args, **kw):
        """Write *el* to *out* as a bare node, without any prologue.

        Remaining arguments are forwarded to the constructor alongside
        the out= and root= keywords.
        """
        cls(out=out, root=el, *args, **kw).node(el)

    @classmethod
    def format(cls, el, *args, **kw):
        """Return the complete document for *el* as a bytes object."""
        buf = io.BytesIO()
        cls.output(buf, el, *args, **kw)
        return buf.getvalue()

    def update(self, **ch):
        """Return a shallow copy of this formatter with attributes replaced.

        The copy is made without calling __init__ and refers to the same
        attribute values (including the output stream), except for those
        overridden by the keyword arguments.
        """
        ret = type(self).__new__(type(self))
        ret.__dict__.update(self.__dict__)
        ret.__dict__.update(ch)
        return ret
160
class iwriter(object):
    """Wrapper around a binary stream that tracks the output column.

    Attributes:
      out   -- the wrapped stream
      col   -- number of units written since the last newline
      atbol -- True initially and right after indent(); any write()
               resets it to False
    """

    def __init__(self, out):
        self.out = out
        self.atbol = True
        self.col = 0

    def write(self, buf):
        """Pass *buf* on to the wrapped stream one unit at a time."""
        pos = 0
        end = len(buf)
        while pos < end:
            # Slice rather than index, so each piece keeps the type of buf.
            unit = buf[pos:pos + 1]
            pos += 1
            self.col = 0 if unit == b'\n' else self.col + 1
            self.out.write(unit)
        self.atbol = False

    def indent(self, indent):
        """Begin a fresh line prefixed by *indent*.

        Does nothing when the writer is already positioned at an
        indented line start.  A newline is emitted first only if the
        current line has content.
        """
        if not self.atbol:
            if self.col != 0:
                self.write(b'\n')
            self.write(indent)
            # write() cleared the flag; set it again now that we sit
            # right after the indentation.
            self.atbol = True
184
class indenter(formatter):
    """Formatter variant that breaks lines and indents nested elements.

    Parameters:
      indent -- string added to the indentation for each nesting level
    All remaining arguments are handed to the formatter constructor.
    The output stream is wrapped in an iwriter so that the current
    column is known.
    """

    def __init__(self, indent="  ", *args, **kw):
        super().__init__(*args, **kw)
        self.out = iwriter(self.out)
        self.indent = indent
        self.curind = ""

    def simple(self, el):
        """Tell whether every child of *el* is a text node.

        An element without children counts as simple.
        """
        return all(isinstance(ch, cons.text) for ch in el.children)

    def longtag(self, el, **extra):
        """Write *el* with children; non-simple content is indented."""
        self.starttag(el, **extra)
        if self.simple(el):
            # Text-only content stays inline with its tags.
            for ch in el.children:
                self.node(ch)
        else:
            # Children are written by a copy of this formatter that is
            # one indentation level deeper.
            inner = self.update(curind=self.curind + self.indent)
            inner.reindent()
            for ch in el.children:
                inner.node(ch)
            # Return to this level before the closing tag.
            self.reindent()
        self.endtag(el)

    def element(self, el, **extra):
        """Write *el*, then break the line if it has run past column 80.

        The break only happens after simple elements.
        """
        super().element(el, **extra)
        if self.out.col <= 80 or not self.simple(el):
            return
        self.reindent()

    def reindent(self):
        """Ask the writer for a new line at the current indentation."""
        prefix = self.curind.encode(self.charset)
        self.out.indent(prefix)

    def start(self):
        """Write the complete document followed by a newline."""
        super().start()
        self.write('\n')
221         super(indenter, self).start()
222         self.write('\n')
223
class response(dispatch.restart):
    """Response object that replies with a serialized element tree.

    Subclasses must override ctype and may override the class attributes:
      charset   -- encoding passed to the formatter
      doctype   -- doctype passed to the formatter
      formatter -- formatter class whose format() produces the body
    """

    charset = "utf-8"
    doctype = None
    formatter = indenter

    def __init__(self, root):
        super().__init__()
        self.root = root

    @property
    def ctype(self):
        """Value for the Content-Type header; must be overridden."""
        raise Exception("a subclass of wrw.sp.util.response must override ctype")

    def handle(self, req):
        """Set the Content-Type header on *req* and return the body.

        The body is returned as a one-element list holding the
        formatted document.
        """
        # The header is assigned first, so a missing ctype override fails
        # before any formatting is done.
        req.ohead["Content-Type"] = self.ctype
        options = {"doctype": self.doctype, "charset": self.charset}
        body = self.formatter.format(self.root, **options)
        return [body]