bin: Fixed context-less load().
[coe.git] / coe / bin.py
index fae69b9..e603ff8 100644 (file)
@@ -15,7 +15,7 @@ STR_SYM = 1
 BIT_BFLOAT = 1
 BIT_DFLOAT = 2
 
-CON_LIST = 0
+CON_SEQ = 0
 CON_SET = 1
 CON_MAP = 2
 CON_OBJ = 3
@@ -23,6 +23,124 @@ CON_OBJ = 3
 NIL_FALSE = 1
 NIL_TRUE = 2
 
+class encoder(object):  # Writes Python values to a binary stream as tagged records.
+    def __init__(self, *, backrefs=True):  # backrefs: number each tag so repeated objects can be referenced
+        self.backrefs = backrefs
+        self.reftab = {}  # id(object) -> reference number of the tag that wrote it
+        self.nextref = 0  # reference number the next written tag will receive
+        self.nstab = {}  # symbol namespace -> reference number of the tag that first spelled it out
+
+    @staticmethod
+    def enctag(pri, sec):  # One tag byte: sec shifted left three bits, OR'ed with pri.
+        return bytes([(sec << 3) | pri])
+
+    def writetag(self, dst, pri, sec, datum):  # Write a tag; return its reference number (None if backrefs off).
+        dst.write(self.enctag(pri, sec))
+        if self.backrefs:
+            ref = self.nextref
+            self.nextref += 1  # every tag consumes a number, even when datum is None
+            if datum is not None and id(datum) not in self.reftab:
+                self.reftab[id(datum)] = ref  # only the first tag written for an object is remembered
+            return ref
+        return None
+
+    @staticmethod
+    def encint(x):  # Variable-length integer: 7 bits per byte, low group first, 0x80 = more bytes follow.
+        ret = bytearray()
+        if x >= 0:
+            b = x & 0x7f
+            x >>= 7
+            while (x > 0) or (b & 0x40) != 0:  # final byte of a non-negative value must have bit 0x40 clear
+                ret.append(0x80 | b)
+                b = x & 0x7f
+                x >>= 7
+            ret.append(b)
+        elif x < 0:
+            b = x & 0x7f
+            x >>= 7
+            while x < -1 or (b & 0x40) == 0:  # final byte of a negative value must have bit 0x40 set
+                ret.append(0x80 | b)
+                b = x & 0x7f
+                x >>= 7
+            ret.append(b)
+        return ret
+
+    @staticmethod
+    def writestr(dst, text):  # UTF-8 bytes of text followed by a single NUL byte.
+        dst.write(text.encode("utf-8"))
+        dst.write(b'\0')
+
+    def dumpseq(self, dst, seq):  # Dump each element in iteration order, then a T_END tag.
+        for v in seq:
+            self.dump(dst, v)
+        dst.write(self.enctag(T_END, 0))  # written directly, so it takes no reference number
+
+    def dumpmap(self, dst, val):  # Dump alternating keys and values, then a T_END tag.
+        for k, v in val.items():
+            self.dump(dst, k)
+            self.dump(dst, v)
+        dst.write(self.enctag(T_END, 0))  # written directly, so it takes no reference number
+
+    def dump(self, dst, datum):  # Write one value to dst; raises ValueError for unsupported types.
+        ref = self.reftab.get(id(datum))
+        if ref is not None:  # object already written: emit a reference to its tag instead
+            dst.write(self.enctag(T_INT, INT_REF))
+            dst.write(self.encint(ref))
+            return
+        if datum is None:  # identity test: never invokes a custom __eq__
+            self.writetag(dst, T_NIL, 0, None)
+        elif datum is False:  # identity, not ==: 0 and 0.0 compare equal to False
+            self.writetag(dst, T_NIL, NIL_FALSE, None)
+        elif datum is True:  # identity, not ==: 1 compares equal to True and must reach the int branch
+            self.writetag(dst, T_NIL, NIL_TRUE, None)
+        elif isinstance(datum, int):
+            self.writetag(dst, T_INT, 0, None)  # datum=None: integers are never entered into reftab
+            dst.write(self.encint(datum))
+        elif isinstance(datum, str):
+            self.writetag(dst, T_STR, 0, datum)
+            self.writestr(dst, datum)
+        elif isinstance(datum, (bytes, bytearray)):
+            self.writetag(dst, T_BIT, 0, datum)
+            dst.write(self.encint(len(datum)))  # length prefix, then the raw bytes
+            dst.write(datum)
+        elif isinstance(datum, data.symbol):
+            if datum.ns == "":  # symbol without a namespace: string tag with the STR_SYM secondary
+                self.writetag(dst, T_STR, STR_SYM, datum)
+                self.writestr(dst, datum.name)
+            else:
+                nsref = self.nstab.get(datum.ns)
+                if nsref is None:  # namespace not seen yet: 0 byte, then namespace and name in full
+                    nsref = self.writetag(dst, T_SYM, 0, datum)
+                    dst.write(b'\0')
+                    self.writestr(dst, datum.ns)
+                    self.writestr(dst, datum.name)
+                    if nsref is not None:  # None when backrefs are off; then nothing is remembered
+                        self.nstab[datum.ns] = nsref
+                else:  # namespace seen before: 1 byte, reference to the earlier tag, then the name
+                    self.writetag(dst, T_SYM, 0, datum)
+                    dst.write(b'\x01')
+                    dst.write(self.encint(nsref))
+                    self.writestr(dst, datum.name)
+        elif isinstance(datum, list):
+            self.writetag(dst, T_CON, CON_SEQ, datum)
+            self.dumpseq(dst, datum)
+        elif isinstance(datum, set):
+            self.writetag(dst, T_CON, CON_SET, datum)
+            self.dumpseq(dst, datum)
+        elif isinstance(datum, dict):
+            self.writetag(dst, T_CON, CON_MAP, datum)
+            self.dumpmap(dst, datum)
+        elif isinstance(datum, data.obj):
+            self.writetag(dst, T_CON, CON_OBJ, datum)
+            self.dump(dst, getattr(type(datum), "typename", None))  # type name first (None if the class has none)
+            self.dumpmap(dst, datum.__dict__)  # then the instance attributes as a map
+        else:
+            raise ValueError("unsupported object type: " + repr(datum))
+
+def dump(dst, datum):  # Encode datum onto dst using a fresh encoder with default settings.
+    encoder().dump(dst, datum)
+    return dst  # the same stream object that was passed in
+
 class fmterror(Exception):
     pass
 
@@ -34,9 +152,6 @@ class referror(fmterror):
     def __init__(self):
         super().__init__("bad backref")
 
-class namedtype(type):
-    pass
-
 class decoder(object):
     def __init__(self):
         self.reftab = []
@@ -107,6 +222,9 @@ class decoder(object):
                 return buf
             buf[key] = self.loadtagged(fp, tag)
 
+    def makeobjtype(self, nm):  # Build a new object class for type name nm via data.namedtype.make.
+        return data.namedtype.make(str(nm), (data.obj, object), {}, typename=nm)  # NOTE(review): positional args look like (name, bases, namespace) — confirm against data.namedtype
+
     def loadobj(self, fp, ref=False):
         if ref:
             refid = len(self.reftab)
@@ -114,8 +232,7 @@ class decoder(object):
         nm = self.load(fp)
         typ = self.namedtypes.get(nm)
         if typ is None:
-            typ = self.namedtypes[nm] = namedtype(str(nm), (data.obj, object), {})
-            typ.typename = nm
+            typ = self.namedtypes[nm] = self.makeobjtype(nm)
         ret = typ()
         if ref:
             self.reftab[refid] = ret
@@ -142,10 +259,10 @@ class decoder(object):
                 return self.reftab[idx]
             return self.addref(self.loadint(fp))
         elif pri == T_STR:
-            ret = self.addref(self.loadstr(fp))
+            ret = self.loadstr(fp)
             if sec == STR_SYM:
-                return data.symbol.get("", ret)
-            return ret
+                return self.addref(data.symbol.get("", ret))
+            return self.addref(ret)
         elif pri == T_BIT:
             ln = self.loadint(fp)
             ret = self.addref(fp.read(ln))
@@ -175,4 +292,4 @@ class decoder(object):
         return self.loadtagged(fp, tag)
 
 def load(fp):
-    decoder().load(fp)
+    return decoder().load(fp)