Add Linux/BSD switch for Unix auth.
[doldaconnect.git] / lib / uilib.c
index 65215b7..973a2b3 100644 (file)
@@ -102,6 +102,7 @@ struct {
     int family;
     int sentcreds;
 } servinfo;
+char *dc_srv_local;
 
 static struct dc_response *makeresp(void)
 {
@@ -258,6 +259,7 @@ int dc_init(void)
 {
     if((ichandle = iconv_open("wchar_t", "utf-8")) == (iconv_t)-1)
        return(-1);
+    dc_srv_local = sstrdup("");
     initcmds();
     return(0);
 }
@@ -378,7 +380,7 @@ int dc_queuecmd(int (*callback)(struct dc_response *), void *data, ...)
     struct qcmd *qcmd;
     int num, freepart;
     va_list al;
-    char *final;
+    char *final, *sarg;
     wchar_t **toks;
     wchar_t *buf;
     wchar_t *part, *tpart;
@@ -390,7 +392,7 @@ int dc_queuecmd(int (*callback)(struct dc_response *), void *data, ...)
     va_start(al, data);
     while((part = va_arg(al, wchar_t *)) != NULL)
     {
-       if(!wcscmp(part, L"%%a"))
+       if(!wcscmp(part, L"%a"))
        {
            for(toks = va_arg(al, wchar_t **); *toks != NULL; toks++)
            {
@@ -410,25 +412,36 @@ int dc_queuecmd(int (*callback)(struct dc_response *), void *data, ...)
        } else {
            if(*part == L'%')
            {
-               /* This demands that all arguments that are passed to the
-                * function are of equal length, that of an int. I know
-                * that GCC does that on IA32 platforms, but I do not know
-                * which other platforms and compilers that it applies
-                * to. If this breaks your platform, please mail me about
-                * it.
-                */
-               part = vswprintf2(tpart = (part + 1), al);
-               for(; *tpart != L'\0'; tpart++)
+               tpart = part + 1;
+               if(!wcscmp(tpart, L"i"))
                {
-                   if(*tpart == L'%')
+                   freepart = 1;
+                   part = swprintf2(L"%i", va_arg(al, int));
+               } else if(!wcscmp(tpart, L"s")) {
+                   freepart = 1;
+                   part = icmbstowcs(sarg = va_arg(al, char *), NULL);
+                   if(part == NULL)
                    {
-                       if(tpart[1] == L'%')
-                           tpart++;
-                       else
-                           va_arg(al, int);
+                       if(buf != NULL)
+                           free(buf);
+                       return(-1);
                    }
+               } else if(!wcscmp(tpart, L"ls")) {
+                   part = va_arg(al, wchar_t *);
+               } else if(!wcscmp(tpart, L"ll")) {
+                   freepart = 1;
+                   part = swprintf2(L"%lli", va_arg(al, long long));
+               } else if(!wcscmp(tpart, L"f")) {
+                   freepart = 1;
+                   part = swprintf2(L"%f", va_arg(al, double));
+               } else if(!wcscmp(tpart, L"x")) {
+                   freepart = 1;
+                   part = swprintf2(L"%x", va_arg(al, int));
+               } else {
+                   if(buf != NULL)
+                       free(buf);
+                   return(-1);
                }
-               freepart = 1;
            } else {
                freepart = 0;
            }
@@ -749,6 +762,7 @@ int dc_handleread(void)
     return(0);
 }
 
+#if UNIX_AUTH_STYLE == 1
 static void mkcreds(struct msghdr *msg)
 {
     struct ucred *ucred;
@@ -767,6 +781,7 @@ static void mkcreds(struct msghdr *msg)
     ucred->gid = getgid();
     msg->msg_controllen = cmsg->cmsg_len;
 }
+#endif
 
 int dc_handlewrite(void)
 {
@@ -785,11 +800,13 @@ int dc_handlewrite(void)
            msg.msg_iovlen = 1;
            bufvec.iov_base = queue->buf;
            bufvec.iov_len = queue->buflen;
+#if UNIX_AUTH_STYLE == 1
            if((servinfo.family == PF_UNIX) && !servinfo.sentcreds)
            {
                mkcreds(&msg);
                servinfo.sentcreds = 1;
            }
+#endif
            ret = sendmsg(fd, &msg, MSG_NOSIGNAL | MSG_DONTWAIT);
            if(ret < 0)
            {
@@ -1085,30 +1102,39 @@ static struct addrinfo *resolvhost(char *host)
     return(NULL);
 }
 
-static struct addrinfo *defaulthost(void)
+static struct addrinfo *getlocalai(void)
 {
     struct addrinfo *ret;
     struct passwd *pwd;
     char *tmp;
-    char dn[1024];
-    
-    if(((tmp = getenv("DCSERVER")) != NULL) && *tmp)
-       return(resolvhost(tmp));
+
     ret = NULL;
     if((getuid() != 0) && ((pwd = getpwuid(getuid())) != NULL))
     {
        tmp = sprintf2("/tmp/doldacond-%s", pwd->pw_name);
-       ret = gaicat(ret, unixgai(SOCK_STREAM, tmp));
+       ret = unixgai(SOCK_STREAM, tmp);
        free(tmp);
     }
     ret = gaicat(ret, unixgai(SOCK_STREAM, "/var/run/doldacond.sock"));
+    return(ret);
+}
+
+static struct addrinfo *defaulthost(void)
+{
+    struct addrinfo *ret;
+    char *tmp;
+    char dn[1024];
+    
+    if(((tmp = getenv("DCSERVER")) != NULL) && *tmp)
+       return(resolvhost(tmp));
+    ret = getlocalai();
     ret = gaicat(ret, resolvtcp("localhost", 1500));
     if(!getdomainname(dn, sizeof(dn)) && *dn && strcmp(dn, "(none)"))
        ret = gaicat(ret, resolvsrv(dn));
     return(ret);
 }
 
-int dc_connect(char *host)
+static int dc_connectai(struct addrinfo *hosts, struct qcmd **cnctcmd)
 {
     struct qcmd *qcmd;
     int errnobak;
@@ -1118,17 +1144,14 @@ int dc_connect(char *host)
     state = -1;
     if(hostlist != NULL)
        freeaddrinfo(hostlist);
-    if(!host || !*host)
-       hostlist = defaulthost();
-    else
-       hostlist = resolvhost(host);
-    if(hostlist == NULL)
-       return(-1);
+    hostlist = hosts;
     for(curhost = hostlist; curhost != NULL; curhost = curhost->ai_next)
     {
        if((fd = socket(curhost->ai_family, curhost->ai_socktype, curhost->ai_protocol)) < 0)
        {
            errnobak = errno;
+           freeaddrinfo(hostlist);
+           hostlist = NULL;
            errno = errnobak;
            return(-1);
        }
@@ -1150,11 +1173,82 @@ int dc_connect(char *host)
            break;
        }
     }
-    qcmd = makeqcmd(NULL);
-    resetreader = 1;
+    if(state != -1)
+    {
+       qcmd = makeqcmd(NULL);
+       if(cnctcmd != NULL)
+           *cnctcmd = qcmd;
+       resetreader = 1;
+    } else {
+       free(hostlist);
+       hostlist = NULL;
+    }
     return(fd);
 }
 
+static int dc_connect2(char *host, struct qcmd **cnctcmd)
+{
+    struct addrinfo *ai;
+    struct qcmd *qcmd;
+    int ret;
+    
+    if(host == dc_srv_local)
+       ai = getlocalai();
+    else if(!host || !*host)
+       ai = defaulthost();
+    else
+       ai = resolvhost(host);
+    if(ai == NULL)
+       return(-1);
+    ret = dc_connectai(ai, &qcmd);
+    if((ret >= 0) && (cnctcmd != NULL))
+       *cnctcmd = qcmd;
+    return(ret);
+}
+
+int dc_connect(char *host)
+{
+    return(dc_connect2(host, NULL));
+}
+
+int dc_connectsync(char *host, struct dc_response **respbuf)
+{
+    int ret;
+    struct qcmd *cc;
+    struct dc_response *resp;
+    
+    if((ret = dc_connect2(host, &cc)) < 0)
+       return(-1);
+    resp = dc_gettaggedrespsync(cc->tag);
+    if(resp == NULL) {
+       dc_disconnect();
+       return(-1);
+    }
+    if(respbuf == NULL)
+       dc_freeresp(resp);
+    else
+       *respbuf = resp;
+    return(ret);
+}
+
+int dc_connectsync2(char *host, int rev)
+{
+    int ret;
+    struct dc_response *resp;
+    
+    if((ret = dc_connectsync(host, &resp)) < 0)
+       return(-1);
+    if(dc_checkprotocol(resp, rev))
+    {
+       dc_freeresp(resp);
+       dc_disconnect();
+       errno = EPROTONOSUPPORT;
+       return(-1);
+    }
+    dc_freeresp(resp);
+    return(ret);
+}
+
 struct dc_intresp *dc_interpret(struct dc_response *resp)
 {
     int i;
@@ -1225,7 +1319,30 @@ void dc_freeires(struct dc_intresp *ires)
     free(ires);
 }
 
+int dc_checkprotocol(struct dc_response *resp, int revision)
+{
+    struct dc_intresp *ires;
+    int low, high;
+    
+    if(resp->code != 201)
+       return(-1);
+    resp->curline = 0;
+    if((ires = dc_interpret(resp)) == NULL)
+       return(-1);
+    low = ires->argv[0].val.num;
+    high = ires->argv[1].val.num;
+    dc_freeires(ires);
+    if((revision < low) || (revision > high))
+       return(-1);
+    return(0);
+}
+
 const char *dc_gethostname(void)
 {
     return(servinfo.hostname);
 }
+
+int dc_getfd(void)
+{
+    return(fd);
+}