acmecert: Fix cryptography bugs.
[utils.git] / acmecert
index 91fcece..14d0f00 100755 (executable)
--- a/acmecert
+++ b/acmecert
@@ -1,17 +1,15 @@
 #!/usr/bin/python3
 
-import sys, os, getopt, binascii, json, pprint, signal, time
+#### ACME client (only http-01 challenges supported thus far)
+
+import sys, os, getopt, binascii, json, pprint, signal, time, calendar, threading
 import urllib.request
-import Crypto.PublicKey.RSA, Crypto.Random, Crypto.Hash.SHA256, Crypto.Signature.PKCS1_v1_5
 
-service = "https://acme-v02.api.letsencrypt.org/directory"
-_directory = None
-def directory():
-    global _directory
-    if _directory is None:
-        with req(service) as resp:
-            _directory = json.loads(resp.read().decode("utf-8"))
-    return _directory
+### General utilities
+
+class msgerror(Exception):
+    def report(self, out):
+        out.write("acmecert: undefined error\n")
 
 def base64url(dat):
     return binascii.b2a_base64(dat).decode("us-ascii").translate({43: 45, 47: 95, 61: None}).strip()
@@ -21,6 +19,293 @@ def ebignum(num):
     if len(h) % 2 == 1: h = "0" + h
     return base64url(binascii.a2b_hex(h))
 
+class maybeopen(object):
+    def __init__(self, name, mode):
+        if name == "-":
+            self.opened = False
+            if mode == "r":
+                self.fp = sys.stdin
+            elif mode == "w":
+                self.fp = sys.stdout
+            else:
+                raise ValueError(mode)
+        else:
+            self.opened = True
+            self.fp = open(name, mode)
+
+    def __enter__(self):
+        return self.fp
+
+    def __exit__(self, *excinfo):
+        if self.opened:
+            self.fp.close()
+        return False
+
+### Crypto utilities
+
+_cryptobke = None
+def cryptobke():
+    global _cryptobke
+    if _cryptobke is None:
+        from cryptography.hazmat import backends
+        _cryptobke = backends.default_backend()
+    return _cryptobke
+
+class dererror(Exception):
+    pass
+
+class pemerror(Exception):
+    pass
+
+def pemdec(pem, ptypes):
+    if isinstance(ptypes, str):
+        ptypes = [ptypes]
+    p = 0
+    while True:
+        p = pem.find("-----BEGIN ", p)
+        if p < 0:
+            raise pemerror("could not find any %s in PEM-encoded data" % (ptypes,))
+        p2 = pem.find("-----", p + 11)
+        if p2 < 0:
+            raise pemerror("incomplete PEM header")
+        ptype = pem[p + 11 : p2]
+        if ptype not in ptypes:
+            p = p2 + 5
+            continue
+        p3 = pem.find("-----END " + ptype + "-----", p2 + 5)
+        if p3 < 0:
+            raise pemerror("incomplete PEM data")
+        pem = pem[p2 + 5 : p3]
+        return binascii.a2b_base64(pem)
+
+class derdecoder(object):
+    def __init__(self, data, offset=0, size=None):
+        self.data = data
+        self.offset = offset
+        self.size = len(data) if size is None else size
+
+    def end(self):
+        return self.offset >= self.size
+
+    def byte(self):
+        if self.offset >= self.size:
+            raise dererror("unexpected end-of-data")
+        ret = self.data[self.offset]
+        self.offset += 1
+        return ret
+
+    def splice(self, ln):
+        if self.offset + ln > self.size:
+            raise dererror("unexpected end-of-data")
+        ret = self.data[self.offset : self.offset + ln]
+        self.offset += ln
+        return ret
+
+    def dectag(self):
+        h = self.byte()
+        cl = (h & 0xc0) >> 6
+        cons = (h & 0x20) != 0
+        tag = h & 0x1f
+        if tag == 0x1f:
+            raise dererror("extended type tags not supported")
+        return cl, cons, tag
+
+    def declen(self):
+        h = self.byte()
+        if (h & 0x80) == 0:
+            return h
+        if h == 0x80:
+            raise dererror("indefinite lengths not supported in DER")
+        if h == 0xff:
+            raise dererror("invalid length byte")
+        n = h & 0x7f
+        ret = 0
+        for i in range(n):
+            ret = (ret << 8) + self.byte()
+        return ret
+
+    def get(self):
+        cl, cons, tag = self.dectag()
+        ln = self.declen()
+        return cons, cl, tag, self.splice(ln)
+
+    def getcons(self, ckcl, cktag):
+        cons, cl, tag, data = self.get()
+        if not cons:
+            raise dererror("expected constructed value")
+        if (ckcl != None and ckcl != cl) or (cktag != None and cktag != tag):
+            raise dererror("unexpected value tag: got (%d, %d), expected (%d, %d)" % (cl, tag, ckcl, cktag))
+        return derdecoder(data)
+
+    def getint(self):
+        cons, cl, tag, data = self.get()
+        if (cons, cl, tag) == (False, 0, 2):
+            ret = 0
+            for b in data:
+                ret = (ret << 8) + b
+            return ret
+        raise dererror("unexpected integer type: (%s, %d, %d)" % (cons, cl, tag))
+
+    def getstr(self):
+        cons, cl, tag, data = self.get()
+        if (cons, cl, tag) == (False, 0, 12):
+            return data.decode("utf-8")
+        if (cons, cl, tag) == (False, 0, 13):
+            return data.decode("us-ascii")
+        if (cons, cl, tag) == (False, 0, 22):
+            return data.decode("us-ascii")
+        if (cons, cl, tag) == (False, 0, 30):
+            return data.decode("utf-16-be")
+        raise dererror("unexpected string type: (%s, %d, %d)" % (cons, cl, tag))
+
+    def getbytes(self):
+        cons, cl, tag, data = self.get()
+        if (cons, cl, tag) == (False, 0, 4):
+            return data
+        raise dererror("unexpected byte-string type: (%s, %d, %d)" % (cons, cl, tag))
+
+    def getoid(self):
+        cons, cl, tag, data = self.get()
+        if (cons, cl, tag) == (False, 0, 6):
+            ret = []
+            ret.append(data[0] // 40)
+            ret.append(data[0] % 40)
+            p = 1
+            while p < len(data):
+                n = 0
+                v = data[p]
+                p += 1
+                while v & 0x80:
+                    n = (n + (v & 0x7f)) * 128
+                    v = data[p]
+                    p += 1
+                n += v
+                ret.append(n)
+            return tuple(ret)
+        raise dererror("unexpected object-id type: (%s, %d, %d)" % (cons, cl, tag))
+
+    @staticmethod
+    def parsetime(data, c):
+        if c:
+            y = int(data[0:4])
+            data = data[4:]
+        else:
+            y = int(data[0:2])
+            y += 1900 if y > 50 else 2000
+            data = data[2:]
+        m = int(data[0:2])
+        d = int(data[2:4])
+        H = int(data[4:6])
+        data = data[6:]
+        if data[:1].isdigit():
+            M = int(data[0:2])
+            data = data[2:]
+        else:
+            M = 0
+        if data[:1].isdigit():
+            S = int(data[0:2])
+            data = data[2:]
+        else:
+            S = 0
+        if data[:1] == '.':
+            p = 1
+            while len(data) < p and data[p].isdigit():
+                p += 1
+            S += float("0." + data[1:p])
+            data = data[p:]
+        if len(data) < 1:
+            raise dererror("unspecified local time not supported for decoding")
+        if data[0] == 'Z':
+            tz = 0
+        elif data[0] == '+':
+            tz = (int(data[1:3]) * 60) + int(data[3:5])
+        elif data[0] == '-':
+            tz = -((int(data[1:3]) * 60) + int(data[3:5]))
+        else:
+            raise dererror("cannot parse X.690 timestamp")
+        return calendar.timegm((y, m, d, H, M, S)) - (tz * 60)
+
+    def gettime(self):
+        cons, cl, tag, data = self.get()
+        if (cons, cl, tag) == (False, 0, 23):
+            return self.parsetime(data.decode("us-ascii"), False)
+        if (cons, cl, tag) == (False, 0, 24):
+            return self.parsetime(data.decode("us-ascii"), True)
+        raise dererror("unexpected time type: (%s, %d, %d)" % (cons, cl, tag))
+
+    @classmethod
+    def frompem(cls, pem, ptypes):
+        return cls(pemdec(pem, ptypes))
+
+class certificate(object):
+    def __init__(self, der):
+        ci = der.getcons(0, 16).getcons(0, 16)
+        self.ver = ci.getcons(2, 0).getint()
+        self.serial = ci.getint()
+        ci.getcons(0, 16)       # Signature algorithm
+        ci.getcons(0, 16)       # Issuer
+        vl = ci.getcons(0, 16)
+        self.startdate = vl.gettime()
+        self.enddate = vl.gettime()
+
+    def expiring(self, timespec):
+        if timespec.endswith("y"):
+            timespec = int(timespec[:-1]) * 365 * 86400
+        elif timespec.endswith("m"):
+            timespec = int(timespec[:-1]) * 30 * 86400
+        elif timespec.endswith("w"):
+            timespec = int(timespec[:-1]) * 7 * 86400
+        elif timespec.endswith("d"):
+            timespec = int(timespec[:-1]) * 86400
+        elif timespec.endswith("h"):
+            timespec = int(timespec[:-1]) * 3600
+        else:
+            timespec = int(timespec)
+        return (self.enddate - time.time()) < timespec
+
+    @classmethod
+    def read(cls, fp):
+        return cls(derdecoder.frompem(fp.read(), {"CERTIFICATE", "X509 CERTIFICATE"}))
+
+class signreq(object):
+    def __init__(self, der):
+        self.raw = der
+        req = derdecoder(der).getcons(0, 16).getcons(0, 16)
+        self.ver = req.getint()
+        req.getcons(0, 16)      # Subject
+        req.getcons(0, 16)      # Public key
+        self.altnames = []
+        if not req.end():
+            attrs = req.getcons(2, 0)
+            while not attrs.end():
+                attr = attrs.getcons(0, 16)
+                anm = attr.getoid()
+                if anm == (1, 2, 840, 113549, 1, 9, 14):
+                    # Certificate extension request
+                    exts = attr.getcons(0, 17).getcons(0, 16)
+                    while not exts.end():
+                        ext = exts.getcons(0, 16)
+                        extnm = ext.getoid()
+                        if extnm == (2, 5, 29, 17):
+                            # Subject alternative names
+                            names = derdecoder(ext.getbytes()).getcons(0, 16)
+                            while not names.end():
+                                cons, cl, tag, data = names.get()
+                                if (cons, cl, tag) == (False, 2, 2):
+                                    self.altnames.append(("DNS", data.decode("us-ascii")))
+
+    def domains(self):
+        return [nm[1] for nm in self.altnames if nm[0] == "DNS"]
+
+    def der(self):
+        return self.raw
+
+    @classmethod
+    def read(cls, fp):
+        return cls(pemdec(fp.read(), {"CERTIFICATE REQUEST"}))
+
+### Somewhat general request utilities
+
 def getnonce():
     with urllib.request.urlopen(directory()["newNonce"]) as resp:
         resp.read()
@@ -37,6 +322,48 @@ def req(url, data=None, ctype=None, headers={}, method=None, **kws):
         req.add_header("Content-Type", ctype)
     return urllib.request.urlopen(req)
 
+class problem(msgerror):
+    def __init__(self, code, data, *args, url=None, **kw):
+        super().__init__(*args, **kw)
+        self.code = code
+        self.data = data
+        self.url = url
+        if not isinstance(data, dict):
+            raise ValueError("unexpected problem object type: %r" % (data,))
+
+    @property
+    def type(self):
+        return self.data.get("type", "about:blank")
+    @property
+    def title(self):
+        return self.data.get("title")
+    @property
+    def detail(self):
+        return self.data.get("detail")
+
+    def report(self, out):
+        extra = None
+        if self.title is None:
+            msg = self.detail
+            if "\n" in msg:
+                extra, msg = msg, None
+        else:
+            msg = self.title
+            extra = self.detail
+        if msg is None:
+            msg = self.data.get("type")
+        if msg is not None:
+            out.write("acemcert: %s: %s\n" % (
+                ("remote service error" if self.url is None else self.url),
+                ("unspecified error" if msg is None else msg)))
+        if extra is not None:
+            out.write("%s\n" % (extra,))
+
+    @classmethod
+    def read(cls, err, **kw):
+        self = cls(err.code, json.loads(err.read().decode("utf-8")), **kw)
+        return self
+
 def jreq(url, data, auth):
     authdata = {"alg": "RS256", "url": url, "nonce": getnonce()}
     authdata.update(auth.authdata())
@@ -47,57 +374,28 @@ def jreq(url, data, auth):
         data = base64url(json.dumps(data).encode("us-ascii"))
     seal = base64url(auth.sign(("%s.%s" % (authdata, data)).encode("us-ascii")))
     enc = {"protected": authdata, "payload": data, "signature": seal}
-    with req(url, data=enc) as resp:
-        return json.loads(resp.read().decode("utf-8")), resp.headers
-
-class signreq(object):
-    def domains(self):
-        # No PCKS10 parser for Python?
-        import subprocess, re
-        with subprocess.Popen(["openssl", "req", "-noout", "-text"], stdin=subprocess.PIPE, stdout=subprocess.PIPE) as openssl:
-            openssl.stdin.write(self.data.encode("us-ascii"))
-            openssl.stdin.close()
-            resp = openssl.stdout.read().decode("utf8")
-            if openssl.wait() != 0:
-                raise Exception("openssl error")
-        m = re.search(r"X509v3 Subject Alternative Name:[^\n]*\n\s*((\w+:\S+,\s*)*\w+:\S+)\s*\n", resp)
-        if m is None:
-            return []
-        ret = []
-        for nm in m.group(1).split(","):
-            nm = nm.strip()
-            typ, nm = nm.split(":", 1)
-            if typ == "DNS":
-                ret.append(nm)
-        return ret
+    try:
+        with req(url, data=enc) as resp:
+            return json.loads(resp.read().decode("utf-8")), resp.headers
+    except urllib.error.HTTPError as exc:
+        if exc.headers["Content-Type"] == "application/problem+json":
+            raise problem.read(exc, url=url)
+        raise
 
-    def der(self):
-        import subprocess
-        with subprocess.Popen(["openssl", "req", "-outform", "der"], stdin=subprocess.PIPE, stdout=subprocess.PIPE) as openssl:
-            openssl.stdin.write(self.data.encode("us-ascii"))
-            openssl.stdin.close()
-            resp = openssl.stdout.read()
-            if openssl.wait() != 0:
-                raise Exception("openssl error")
-        return resp
-
-    @classmethod
-    def read(cls, fp):
-        self = cls()
-        self.data = fp.read()
-        return self
+## Authentication
 
 class jwkauth(object):
     def __init__(self, key):
         self.key = key
 
     def authdata(self):
-        return {"jwk": {"kty": "RSA", "e": ebignum(self.key.e), "n": ebignum(self.key.n)}}
+        pub = self.key.public_key().public_numbers()
+        return {"jwk": {"kty": "RSA", "e": ebignum(pub.e), "n": ebignum(pub.n)}}
 
     def sign(self, data):
-        dig = Crypto.Hash.SHA256.new()
-        dig.update(data)
-        return Crypto.Signature.PKCS1_v1_5.new(self.key).sign(dig)
+        from cryptography.hazmat.primitives import hashes
+        from cryptography.hazmat.primitives.asymmetric import padding
+        return self.key.sign(data, padding.PKCS1v15(), hashes.SHA256())
 
 class account(object):
     def __init__(self, uri, key):
@@ -108,9 +406,9 @@ class account(object):
         return {"kid": self.uri}
 
     def sign(self, data):
-        dig = Crypto.Hash.SHA256.new()
-        dig.update(data)
-        return Crypto.Signature.PKCS1_v1_5.new(self.key).sign(dig)
+        from cryptography.hazmat.primitives import hashes
+        from cryptography.hazmat.primitives.asymmetric import padding
+        return self.key.sign(data, padding.PKCS1v15(), hashes.SHA256())
 
     def getinfo(self):
         data, headers = jreq(self.uri, None, self)
@@ -122,18 +420,79 @@ class account(object):
             raise Exception("account is not valid: %s" % (data.get("status", "\"\"")))
 
     def write(self, out):
+        from cryptography.hazmat.primitives import serialization
         out.write("%s\n" % (self.uri,))
-        out.write("%s\n" % (self.key.exportKey().decode("us-ascii"),))
+        out.write("%s\n" % (self.key.private_bytes(
+            encoding=serialization.Encoding.PEM,
+            format=serialization.PrivateFormat.TraditionalOpenSSL,
+            encryption_algorithm=serialization.NoEncryption()
+        ).decode("us-ascii"),))
 
     @classmethod
     def read(cls, fp):
+        from cryptography.hazmat.primitives import serialization
         uri = fp.readline()
         if uri == "":
             raise Exception("missing account URI")
         uri = uri.strip()
-        key = Crypto.PublicKey.RSA.importKey(fp.read())
+        key = serialization.load_pem_private_key(fp.read().encode("us-ascii"), password=None, backend=cryptobke())
         return cls(uri, key)
 
+### ACME protocol
+
+service = "https://acme-v02.api.letsencrypt.org/directory"
+_directory = None
+def directory():
+    global _directory
+    if _directory is None:
+        with req(service) as resp:
+            _directory = json.loads(resp.read().decode("utf-8"))
+    return _directory
+
+def register(keysize=4096):
+    from cryptography.hazmat.primitives.asymmetric import rsa
+    key = rsa.generate_private_key(public_exponent=65537, key_size=keysize, backend=cryptobke())
+    data, headers = jreq(directory()["newAccount"], {"termsOfServiceAgreed": True}, jwkauth(key))
+    return account(headers["Location"], key)
+    
+def mkorder(acct, csr):
+    data, headers = jreq(directory()["newOrder"], {"identifiers": [{"type": "dns", "value": dn} for dn in csr.domains()]}, acct)
+    data["acmecert.location"] = headers["Location"]
+    return data
+
+def httptoken(acct, ch):
+    from cryptography.hazmat.primitives import hashes
+    pub = acct.key.public_key().public_numbers()
+    jwk = {"kty": "RSA", "e": ebignum(pub.e), "n": ebignum(pub.n)}
+    dig = hashes.Hash(hashes.SHA256(), backend=cryptobke())
+    dig.update(json.dumps(jwk, separators=(',', ':'), sort_keys=True).encode("us-ascii"))
+    khash = base64url(dig.finalize())
+    return ch["token"], ("%s.%s" % (ch["token"], khash))
+
+def finalize(acct, csr, orderid):
+    order, headers = jreq(orderid, None, acct)
+    if order["status"] == "valid":
+        pass
+    elif order["status"] == "ready":
+        jreq(order["finalize"], {"csr": base64url(csr.der())}, acct)
+        for n in range(30):
+            resp, headers = jreq(orderid, None, acct)
+            if resp["status"] == "processing":
+                time.sleep(2)
+            elif resp["status"] == "valid":
+                order = resp
+                break
+            else:
+                raise Exception("unexpected order status when finalizing: %s" % resp["status"])
+        else:
+            raise Exception("order finalization timed out")
+    else:
+        raise Exception("unexpected order state when finalizing: %s" % (order["status"],))
+    with req(order["certificate"]) as resp:
+        return resp.read().decode("us-ascii")
+
+## http-01 challenge
+
 class htconfig(object):
     def __init__(self):
         self.roots = {}
@@ -151,25 +510,6 @@ class htconfig(object):
                 sys.stderr.write("acmecert: warning: unknown htconfig directive: %s\n" % (words[0]))
         return self
 
-def register(keysize=4096):
-    key = Crypto.PublicKey.RSA.generate(keysize, Crypto.Random.new().read)
-    # jwk = {"kty": "RSA", "e": ebignum(key.e), "n": ebignum(key.n)}
-    # cjwk = json.dumps(jwk, separators=(',', ':'), sort_keys=True)
-    data, headers = jreq(directory()["newAccount"], {"termsOfServiceAgreed": True}, jwkauth(key))
-    return account(headers["Location"], key)
-    
-def mkorder(acct, csr):
-    data, headers = jreq(directory()["newOrder"], {"identifiers": [{"type": "dns", "value": dn} for dn in csr.domains()]}, acct)
-    data["acmecert.location"] = headers["Location"]
-    return data
-
-def httptoken(acct, ch):
-    jwk = {"kty": "RSA", "e": ebignum(acct.key.e), "n": ebignum(acct.key.n)}
-    dig = Crypto.Hash.SHA256.new()
-    dig.update(json.dumps(jwk, separators=(',', ':'), sort_keys=True).encode("us-ascii"))
-    khash = base64url(dig.digest())
-    return ch["token"], ("%s.%s" % (ch["token"], khash))
-
 def authorder(acct, htconf, orderid):
     order, headers = jreq(orderid, None, acct)
     valid = False
@@ -212,6 +552,13 @@ def authorder(acct, htconf, orderid):
                     resp, headers = jreq(ch["url"], {}, acct)
                     if resp["status"] == "processing":
                         time.sleep(2)
+                    elif resp["status"] == "pending":
+                        # I don't think this should happen, but it
+                        # does. LE bug? Anyway, just retry.
+                        if n < 5:
+                            time.sleep(2)
+                        else:
+                            break
                     elif resp["status"] == "valid":
                         break
                     else:
@@ -221,91 +568,164 @@ def authorder(acct, htconf, orderid):
             finally:
                 os.unlink(tokpath)
 
-def finalize(acct, csr, orderid):
-    order, headers = jreq(orderid, None, acct)
-    if order["status"] == "valid":
-        pass
-    elif order["status"] == "ready":
-        jreq(order["finalize"], {"csr": base64url(csr.der())}, acct)
-        for n in range(30):
-            resp, headers = jreq(orderid, None, acct)
-            if resp["status"] == "processing":
-                time.sleep(2)
-            elif resp["status"] == "valid":
-                order = resp
-                break
-            else:
-                raise Exception("unexpected order status when finalizing: %s" % resp["status"])
-        else:
-            raise Exception("order finalization timed out")
+### Invocation and commands
+
+invdata = threading.local()
+commands = {}
+
+class usageerr(msgerror):
+    def __init__(self):
+        self.cmd = invdata.cmd
+
+    def report(self, out):
+        out.write("%s\n" % (self.cmd.__doc__,))
+
+## User commands
+
+def cmd_reg(args):
+    "usage: acmecert reg [OUTPUT-FILE]"
+    acct = register()
+    os.umask(0o077)
+    with maybeopen(args[1] if len(args) > 1 else "-", "w") as fp:
+        acct.write(fp)
+commands["reg"] = cmd_reg
+
+def cmd_validate_acct(args):
+    "usage: acmecert validate-acct ACCOUNT-FILE"
+    if len(args) < 2: raise usageerr()
+    with maybeopen(args[1], "r") as fp:
+        account.read(fp).validate()
+commands["validate-acct"] = cmd_validate_acct
+
+def cmd_acct_info(args):
+    "usage: acmecert acct-info ACCOUNT-FILE"
+    if len(args) < 2: raise usageerr()
+    with maybeopen(args[1], "r") as fp:
+        pprint.pprint(account.read(fp).getinfo())
+commands["acct-info"] = cmd_acct_info
+
+def cmd_order(args):
+    "usage: acmecert order ACCOUNT-FILE CSR [OUTPUT-FILE]"
+    if len(args) < 3: raise usageerr()
+    with maybeopen(args[1], "r") as fp:
+        acct = account.read(fp)
+    with maybeopen(args[2], "r") as fp:
+        csr = signreq.read(fp)
+    order = mkorder(acct, csr)
+    with maybeopen(args[3] if len(args) > 3 else "-", "w") as fp:
+        fp.write("%s\n" % (order["acmecert.location"]))
+commands["order"] = cmd_order
+
+def cmd_http_auth(args):
+    "usage: acmecert http-auth ACCOUNT-FILE HTTP-CONFIG {ORDER-ID|ORDER-FILE}"
+    if len(args) < 4: raise usageerr()
+    with maybeopen(args[1], "r") as fp:
+        acct = account.read(fp)
+    with maybeopen(args[2], "r") as fp:
+        htconf = htconfig.read(fp)
+    if "://" in args[3]:
+        orderid = args[3]
     else:
-        raise Exception("unexpected order state when finalizing: %s" % (order["status"],))
-    with req(order["certificate"]) as resp:
-        return resp.read().decode("us-ascii")
+        with maybeopen(args[3], "r") as fp:
+            orderid = fp.readline().strip()
+    authorder(acct, htconf, orderid)
+commands["http-auth"] = cmd_http_auth
+
+def cmd_get(args):
+    "usage: acmecert get ACCOUNT-FILE CSR {ORDER-ID|ORDER-FILE}"
+    if len(args) < 4: raise usageerr()
+    with maybeopen(args[1], "r") as fp:
+        acct = account.read(fp)
+    with maybeopen(args[2], "r") as fp:
+        csr = signreq.read(fp)
+    if "://" in args[3]:
+        orderid = args[3]
+    else:
+        with maybeopen(args[3], "r") as fp:
+            orderid = fp.readline().strip()
+    sys.stdout.write(finalize(acct, csr, orderid))
+commands["get"] = cmd_get
+
+def cmd_http_order(args):
+    "usage: acmecert http-order ACCOUNT-FILE CSR HTTP-CONFIG [OUTPUT-FILE]"
+    if len(args) < 4: raise usageerr()
+    with maybeopen(args[1], "r") as fp:
+        acct = account.read(fp)
+    with maybeopen(args[2], "r") as fp:
+        csr = signreq.read(fp)
+    with maybeopen(args[3], "r") as fp:
+        htconf = htconfig.read(fp)
+    orderid = mkorder(acct, csr)["acmecert.location"]
+    authorder(acct, htconf, orderid)
+    with maybeopen(args[4] if len(args) > 4 else "-", "w") as fp:
+        fp.write(finalize(acct, csr, orderid))
+commands["http-order"] = cmd_http_order
+
+def cmd_check_cert(args):
+    "usage: acmecert check-cert CERT-FILE TIME-SPEC"
+    if len(args) < 3: raise usageerr()
+    with maybeopen(args[1], "r") as fp:
+        crt = certificate.read(fp)
+    sys.exit(1 if crt.expiring(args[2]) else 0)
+commands["check-cert"] = cmd_check_cert
+
+def cmd_directory(args):
+    "usage: acmecert directory"
+    pprint.pprint(directory())
+commands["directory"] = cmd_directory
+
+## Main invocation
 
 def usage(out):
-    out.write("usage: acmecert [-h] [-D SERVICE]\n")
+    out.write("usage: acmecert [-D SERVICE] COMMAND [ARGS...]\n")
+    out.write("       acmecert -h [COMMAND]\n")
+    buf =     "       COMMAND is any of: "
+    f = True
+    for cmd in commands:
+        if len(buf) + len(cmd) > 70:
+            out.write("%s\n" % (buf,))
+            buf =     "           "
+            f = True
+        if not f:
+            buf += ", "
+        buf += cmd
+        f = False
+    if not f:
+        out.write("%s\n" % (buf,))
 
 def main(argv):
     global service
     opts, args = getopt.getopt(argv[1:], "hD:")
     for o, a in opts:
         if o == "-h":
-            usage(sys.stdout)
+            if len(args) > 0:
+                cmd = commands.get(args[0])
+                if cmd is None:
+                    sys.stderr.write("acmecert: unknown command: %s\n" % (args[0],))
+                    sys.exit(1)
+                sys.stdout.write("%s\n" % (cmd.__doc__,))
+            else:
+                usage(sys.stdout)
             sys.exit(0)
         elif o == "-D":
             service = a
     if len(args) < 1:
         usage(sys.stderr)
         sys.exit(1)
-    if args[0] == "reg":
-        register().write(sys.stdout)
-    elif args[0] == "validate-acct":
-        with open(args[1], "r") as fp:
-            account.read(fp).validate()
-    elif args[0] == "acctinfo":
-        with open(args[1], "r") as fp:
-            pprint.pprint(account.read(fp).getinfo())
-    elif args[0] == "order":
-        with open(args[1], "r") as fp:
-            acct = account.read(fp)
-        with open(args[2], "r") as fp:
-            csr = signreq.read(fp)
-        order = mkorder(acct, csr)
-        with open(args[3], "w") as fp:
-            fp.write("%s\n" % (order["acmecert.location"]))
-    elif args[0] == "http-auth":
-        with open(args[1], "r") as fp:
-            acct = account.read(fp)
-        with open(args[2], "r") as fp:
-            htconf = htconfig.read(fp)
-        with open(args[3], "r") as fp:
-            orderid = fp.readline().strip()
-        authorder(acct, htconf, orderid)
-    elif args[0] == "get":
-        with open(args[1], "r") as fp:
-            acct = account.read(fp)
-        with open(args[2], "r") as fp:
-            csr = signreq.read(fp)
-        with open(args[3], "r") as fp:
-            orderid = fp.readline().strip()
-        sys.stdout.write(finalize(acct, csr, orderid))
-    elif args[0] == "http-order":
-        with open(args[1], "r") as fp:
-            acct = account.read(fp)
-        with open(args[2], "r") as fp:
-            csr = signreq.read(fp)
-        with open(args[3], "r") as fp:
-            htconf = htconfig.read(fp)
-        orderid = mkorder(acct, csr)["acmecert.location"]
-        authorder(acct, htconf, orderid)
-        sys.stdout.write(finalize(acct, csr, orderid))
-    elif args[0] == "directory":
-        pprint.pprint(directory())
-    else:
+    cmd = commands.get(args[0])
+    if cmd is None:
         sys.stderr.write("acmecert: unknown command: %s\n" % (args[0],))
         usage(sys.stderr)
         sys.exit(1)
+    try:
+        try:
+            invdata.cmd = cmd
+            cmd(args)
+        finally:
+            invdata.cmd = None
+    except msgerror as exc:
+        exc.report(sys.stderr)
+        sys.exit(1)
 
 if __name__ == "__main__":
     try: