f387f1c09298d948ca115b5fd7c24466b91ee07f
[wrw.git] / wrw / sp / util.py
1 import cons
2
def findnsnames(el):
    """Map every namespace used in the tree under ``el`` to a name prefix.

    ``cons.element`` nodes are walked depth-first starting at ``el``; each
    namespace is assigned u"n1", u"n2", ... in the order it is first seen
    (non-element children are ignored).  Afterwards exactly one namespace is
    mapped to None, meaning "write without a prefix": the None namespace if
    some element has no namespace, otherwise the root element's namespace.

    Returns a dict of namespace -> prefix (or None).
    """
    names = {}

    def walk(node):
        if not isinstance(node, cons.element):
            return
        if node.ns not in names:
            # Each new namespace adds one entry, so len(names) + 1 is the
            # next unused sequence number.
            names[node.ns] = u"n" + unicode(len(names) + 1)
        for child in node.children:
            walk(child)

    walk(el)
    # Choose the single namespace that is written without a prefix.
    names[None if None in names else el.ns] = None
    return names
19
class formatter(object):
    """Serialize a ``cons`` element tree as XML onto a byte stream.

    All output passes through write(), which encodes text with ``charset``
    before handing it to ``out.write``.  ``nsnames`` maps each namespace used
    in the tree to the prefix used in element names (None meaning "no
    prefix"); when omitted it is computed with findnsnames(root).
    ``doctype``, if given, is a (public id, system id) pair written in a
    DOCTYPE declaration by start().
    """

    # Replacements applied to character data and attribute values.
    _escapes = {u"&": u"&amp;", u"<": u"&lt;", u">": u"&gt;"}

    def __init__(self, out, root, nsnames=None, charset="utf-8", doctype=None):
        self.root = root
        if nsnames is None:
            nsnames = findnsnames(root)
        self.nsnames = nsnames
        self.out = out
        self.charset = charset
        self.doctype = doctype

    def write(self, text):
        """Encode ``text`` in the output charset and write it to ``out``.

        Characters the charset cannot represent are emitted as numeric
        character references (``&#NNN;``) instead of raising
        UnicodeEncodeError.
        """
        self.out.write(text.encode(self.charset, "xmlcharrefreplace"))

    def _escape(self, buf, extra=None):
        """Return ``buf`` with &, < and > (plus any ``extra`` mapping) replaced."""
        table = self._escapes
        if extra:
            table = dict(table)
            table.update(extra)
        return u"".join(table.get(ch, ch) for ch in buf)

    def quotewrite(self, buf):
        """Write ``buf`` as character data, escaping &, < and >."""
        escaped = self._escape(buf)
        # An empty buffer produces no write call at all, so a wrapped stream
        # sees exactly the calls it saw when characters were written singly.
        if escaped:
            self.write(escaped)

    def text(self, el):
        """Write a text node as escaped character data."""
        self.quotewrite(el)

    def attrval(self, buf):
        """Write ``buf`` as a quoted, escaped attribute value.

        Double quotes delimit the value unless it contains a double quote,
        in which case single quotes are used; occurrences of the chosen
        delimiter inside the value are written as an entity reference.
        """
        if u'"' in buf:
            qc, qt = u"'", u"&apos;"
        else:
            qc, qt = u'"', u"&quot;"
        self.write(qc + self._escape(buf, {qc: qt}) + qc)

    def attr(self, k, v):
        """Write a single ``name=value`` attribute."""
        self.write(k)
        self.write(u'=')
        self.attrval(v)

    def _attrs(self, el, extra):
        """Write the attributes of ``el``, then those given in ``extra``."""
        for attrs in (el.attrs, extra):
            for k, v in attrs.items():
                self.write(u' ')
                self.attr(k, v)

    def shorttag(self, el, **extra):
        """Write ``el`` as a self-closing tag; ``extra`` adds attributes."""
        self.write(u'<' + self.elname(el))
        self._attrs(el, extra)
        self.write(u" />")

    def elname(self, el):
        """Return the name of ``el``, prefixed if its namespace has a prefix."""
        ns = self.nsnames[el.ns]
        if ns is None:
            return el.name
        return ns + u':' + el.name

    def starttag(self, el, **extra):
        """Write the opening tag of ``el``; ``extra`` adds attributes."""
        self.write(u'<' + self.elname(el))
        self._attrs(el, extra)
        self.write(u'>')

    def endtag(self, el):
        """Write the closing tag of ``el``."""
        self.write(u'</' + self.elname(el) + u'>')

    def longtag(self, el, **extra):
        """Write ``el`` with its children between start and end tags.

        ``extra`` supplies additional attributes for the start tag, as
        element() passes them.
        """
        self.starttag(el, **extra)
        for ch in el.children:
            self.node(ch)
        self.endtag(el)

    def element(self, el, **extra):
        """Write ``el``, self-closing when it has no children."""
        if not el.children:
            self.shorttag(el, **extra)
        else:
            self.longtag(el, **extra)

    def node(self, el):
        """Write a tree node, which must be a cons element or text node.

        Raises Exception for any other kind of object.
        """
        if isinstance(el, cons.element):
            self.element(el)
        elif isinstance(el, cons.text):
            self.text(el)
        else:
            # repr() keeps non-string objects from turning this into a
            # TypeError raised by the string concatenation itself.
            raise Exception("Unknown object in element tree: " + repr(el))

    def start(self):
        """Write the XML declaration, optional DOCTYPE and the whole tree."""
        self.write(u'<?xml version="1.0" encoding="' + self.charset + u'" ?>\n')
        if self.doctype:
            # The DOCTYPE name must match the root element's name as it is
            # actually written, including any namespace prefix.
            self.write(u'<!DOCTYPE %s PUBLIC "%s" "%s">\n' % (self.elname(self.root),
                                                              self.doctype[0],
                                                              self.doctype[1]))
        extra = {}
        for uri, nm in self.nsnames.items():
            if uri is None:
                # Elements without a namespace need no declaration.
                continue
            extra[u"xmlns" if nm is None else u"xmlns:" + nm] = uri
        # Namespace declarations are attached to the root element.
        self.element(self.root, **extra)

    @classmethod
    def output(cls, out, el, *args, **kw):
        """Construct a formatter for tree ``el`` on ``out`` and write it."""
        cls(out=out, root=el, *args, **kw).start()

    def update(self, **ch):
        """Return a shallow copy of this formatter with attributes in ``ch`` replaced."""
        ret = type(self).__new__(type(self))
        ret.__dict__.update(self.__dict__)
        ret.__dict__.update(ch)
        return ret
143
class iwriter(object):
    """Stream wrapper that tracks the output column to support indentation.

    ``col`` counts the items written since the last newline.  When fed
    encoded text this counts bytes, not characters — NOTE(review): multi-byte
    characters therefore advance it by more than one.  ``atbol`` is True
    initially and right after indent(), and False after any write().
    """

    def __init__(self, out):
        self.out = out
        self.atbol = True
        self.col = 0

    def write(self, buf):
        """Write ``buf`` to the wrapped stream in one call and update ``col``."""
        nl = buf.rfind('\n')
        if nl < 0:
            self.col += len(buf)
        else:
            # Only what follows the last newline is on the current line.
            self.col = len(buf) - nl - 1
        self.out.write(buf)
        self.atbol = False

    def indent(self, indent):
        """Start a fresh line prefixed with ``indent``.

        Does nothing if atbol is already set; a newline is written first only
        when the current column is nonzero.
        """
        if self.atbol:
            return
        if self.col != 0:
            self.write('\n')
        self.write(indent)
        self.atbol = True
166
class indenter(formatter):
    """Formatter variant that pretty-prints the tree with indentation.

    The content of an element that has any non-text child is started on a
    new line indented one ``indent`` step deeper than the element, and the
    closing tag is put back on a line at the element's own indentation.
    After a text-only (or empty) element, reindent() is called once the
    output column exceeds 80.
    """

    def __init__(self, indent=u"  ", *args, **kw):
        super(indenter, self).__init__(*args, **kw)
        # Wrap the raw stream so the current column can be tracked.
        self.out = iwriter(self.out)
        self.indent = indent   # one level of indentation
        self.curind = u""      # indentation of the current nesting level

    def simple(self, el):
        """Return True if every child of ``el`` is a text node."""
        return all(isinstance(ch, cons.text) for ch in el.children)

    def longtag(self, el, **extra):
        """Write ``el`` and its children, indenting non-text-only content."""
        self.starttag(el, **extra)
        if self.simple(el):
            for ch in el.children:
                self.node(ch)
        else:
            # Children are written by a copy whose indentation is one level
            # deeper; the copy shares the same wrapped output stream.
            inner = self.update(curind=self.curind + self.indent)
            inner.reindent()
            for ch in el.children:
                inner.node(ch)
            self.reindent()
        self.endtag(el)

    def element(self, el, **extra):
        """Write ``el``, then reindent if a text-only element ran past column 80."""
        super(indenter, self).element(el, **extra)
        if self.simple(el) and self.out.col > 80:
            self.reindent()

    def reindent(self):
        """Move to a fresh line at the current indentation level."""
        prefix = self.curind.encode(self.charset)
        self.out.indent(prefix)

    def start(self):
        """Write the whole document followed by a final newline."""
        super(indenter, self).start()
        self.write('\n')