Major rework to use cbchains on sockets.
[doldaconnect.git] / daemon / net.c
index 0f7be54..4054518 100644 (file)
@@ -33,7 +33,6 @@
 #include <netinet/in.h>
 #include <netdb.h>
 #include <sys/signal.h>
-#include <printf.h>
 #ifdef HAVE_LINUX_SOCKIOS_H
 #include <linux/sockios.h>
 #endif
@@ -95,6 +94,8 @@ int getpublicaddr(int af, struct sockaddr **addr, socklen_t *lenbuf)
            flog(LOG_ERR, "could not convert net.publicif into local charset: %s", strerror(errno));
            return(-1);
        }
+       if(!strcmp(pif, ""))
+           return(1);
        if((sock = socket(AF_INET, SOCK_DGRAM, 0)) < 0)
            return(-1);
        conf.ifc_buf = smalloc(conf.ifc_len = 65536);
@@ -108,32 +109,25 @@ int getpublicaddr(int af, struct sockaddr **addr, socklen_t *lenbuf)
        ipv4 = NULL;
        for(ifr = conf.ifc_ifcu.ifcu_req; (void *)ifr < bufend; ifr++)
        {
+           if(strcmp(ifr->ifr_name, pif))
+               continue;
            memset(&req, 0, sizeof(req));
            memcpy(req.ifr_name, ifr->ifr_name, sizeof(ifr->ifr_name));
            if(ioctl(sock, SIOCGIFFLAGS, &req) < 0)
+               break;
+           if(!(req.ifr_flags & IFF_UP))
            {
-               free(conf.ifc_buf);
-               close(sock);
-               return(-1);
+               flog(LOG_WARNING, "public interface is down");
+               break;
            }
-           if(!(req.ifr_flags & IFF_UP))
-               continue;
-           if(ifr->ifr_addr.sa_family == AF_INET)
+           if(ifr->ifr_addr.sa_family != AF_INET)
            {
-               if(ntohl(((struct sockaddr_in *)&ifr->ifr_addr)->sin_addr.s_addr) == 0x7f000001)
-                   continue;
-               if(ipv4 == NULL)
-               {
-                   ipv4 = smalloc(sizeof(*ipv4));
-                   memcpy(ipv4, &ifr->ifr_addr, sizeof(ifr->ifr_addr));
-               } else {
-                   free(ipv4);
-                   free(conf.ifc_buf);
-                   flog(LOG_WARNING, "could not locate an unambiguous interface for determining your public IP address - set net.publicif");
-                   errno = ENFILE; /* XXX: There's no appropriate one for this... */
-                   return(-1);
-               }
+               flog(LOG_WARNING, "address of the public interface is not AF_INET");
+               break;
            }
+           ipv4 = smalloc(sizeof(*ipv4));
+           memcpy(ipv4, &ifr->ifr_addr, sizeof(ifr->ifr_addr));
+           break;
        }
        free(conf.ifc_buf);
        close(sock);
@@ -146,8 +140,7 @@ int getpublicaddr(int af, struct sockaddr **addr, socklen_t *lenbuf)
        errno = ENETDOWN;
        return(-1);
     }
-    errno = EPFNOSUPPORT;
-    return(-1);
+    return(1);
 }
 
 static struct socket *newsock(int type)
@@ -181,11 +174,11 @@ static struct socket *newsock(int type)
        new->inbuf.d.f = new->inbuf.d.l = NULL;
        break;
     }
-    new->conncb = NULL;
-    new->errcb = NULL;
-    new->readcb = NULL;
-    new->writecb = NULL;
-    new->acceptcb = NULL;
+    CBCHAININIT(new, socket_conn);
+    CBCHAININIT(new, socket_err);
+    CBCHAININIT(new, socket_read);
+    CBCHAININIT(new, socket_write);
+    CBCHAININIT(new, socket_accept);
     new->next = sockets;
     new->prev = NULL;
     if(sockets != NULL)
@@ -247,6 +240,11 @@ void putsock(struct socket *sk)
     
     if(--(sk->refcount) == 0)
     {
+       CBCHAINFREE(sk, socket_conn);
+       CBCHAINFREE(sk, socket_err);
+       CBCHAINFREE(sk, socket_read);
+       CBCHAINFREE(sk, socket_write);
+       CBCHAINFREE(sk, socket_accept);
        switch(sk->type)
        {
        case SOCK_STREAM:
@@ -260,12 +258,14 @@ void putsock(struct socket *sk)
            {
                sk->outbuf.d.f = buf->next;
                free(buf->data);
+               free(buf->addr);
                free(buf);
            }
            while((buf = sk->inbuf.d.f) != NULL)
            {
                sk->inbuf.d.f = buf->next;
                free(buf->data);
+               free(buf->addr);
                free(buf);
            }
            break;
@@ -357,21 +357,18 @@ static void sockrecv(struct socket *sk)
        {
            if((errno == EINTR) || (errno == EAGAIN))
                return;
-           if(sk->errcb != NULL)
-               sk->errcb(sk, errno, sk->data);
+           CBCHAINDOCB(sk, socket_err, sk, errno);
            closesock(sk);
            return;
        }
        if(ret == 0)
        {
-           if(sk->errcb != NULL)
-               sk->errcb(sk, 0, sk->data);
+           CBCHAINDOCB(sk, socket_err, sk, 0);
            closesock(sk);
            return;
        }
        sk->inbuf.s.datasize += ret;
-       if(sk->readcb != NULL)
-           sk->readcb(sk, sk->data);
+       CBCHAINDOCB(sk, socket_read, sk);
        break;
     case SOCK_DGRAM:
        if(ioctl(sk->fd, SIOCINQ, &inq))
@@ -392,8 +389,7 @@ static void sockrecv(struct socket *sk)
            free(dbuf);
            if((errno == EINTR) || (errno == EAGAIN))
                return;
-           if(sk->errcb != NULL)
-               sk->errcb(sk, errno, sk->data);
+           CBCHAINDOCB(sk, socket_err, sk, errno);
            closesock(sk);
            return;
        }
@@ -408,8 +404,7 @@ static void sockrecv(struct socket *sk)
            free(dbuf);
            if(!((sk->family == AF_INET) || (sk->family == AF_INET6)))
            {
-               if(sk->errcb != NULL)
-                   sk->errcb(sk, 0, sk->data);
+               CBCHAINDOCB(sk, socket_err, sk, 0);
                closesock(sk);
            }
            return;
@@ -422,8 +417,7 @@ static void sockrecv(struct socket *sk)
        else
            sk->inbuf.d.f = dbuf;
        sk->inbuf.d.l = dbuf;
-       if(sk->readcb != NULL)
-           sk->readcb(sk, sk->data);
+       CBCHAINDOCB(sk, socket_read, sk);
        break;
     }
 }
@@ -449,8 +443,7 @@ static void sockflush(struct socket *sk)
        if(ret > 0)
        {
            memmove(sk->outbuf.s.buf, ((char *)sk->outbuf.s.buf) + ret, sk->outbuf.s.datasize -= ret);
-           if(sk->writecb != NULL)
-               sk->writecb(sk, sk->data);
+           CBCHAINDOCB(sk, socket_write, sk);
        }
        break;
     case SOCK_DGRAM:
@@ -461,8 +454,7 @@ static void sockflush(struct socket *sk)
        free(dbuf->data);
        free(dbuf->addr);
        free(dbuf);
-       if(sk->writecb != NULL)
-           sk->writecb(sk, sk->data);
+       CBCHAINDOCB(sk, socket_write, sk);
        break;
     }
 }
@@ -552,7 +544,7 @@ size_t sockqueuesize(struct socket *sk)
  * netcslisten() instead.
 */
 
-struct socket *netcslistenlocal(int type, struct sockaddr *name, socklen_t namelen, void (*func)(struct socket *, struct socket *, void *), void *data)
+struct socket *netcslistenlocal(int type, struct sockaddr *name, socklen_t namelen, int (*func)(struct socket *, struct socket *, void *), void *data)
 {
     struct socket *sk;
     int intbuf;
@@ -581,37 +573,31 @@ struct socket *netcslistenlocal(int type, struct sockaddr *name, socklen_t namel
        putsock(sk);
        return(NULL);
     }
-    sk->acceptcb = func;
-    sk->data = data;
+    if(func != NULL)
+       CBREG(sk, socket_accept, func, NULL, data);
     return(sk);
 }
 
-struct socket *netcslisten(int type, struct sockaddr *name, socklen_t namelen, void (*func)(struct socket *, struct socket *, void *), void *data)
+struct socket *netcslisten(int type, struct sockaddr *name, socklen_t namelen, int (*func)(struct socket *, struct socket *, void *), void *data)
 {
     if(confgetint("net", "mode") == 1)
     {
        errno = EOPNOTSUPP;
        return(NULL);
     }
-    /* I don't know if this is actually correct (it probably isn't),
-     * but since, at on least Linux systems, PF_* are specifically
-     * #define'd to their AF_* counterparts, it allows for a severely
-     * smoother implementation. If it breaks something on your
-     * platform, please tell me so.
-     */
     if(confgetint("net", "mode") == 0)
        return(netcslistenlocal(type, name, namelen, func, data));
     errno = EOPNOTSUPP;
     return(NULL);
 }
 
-struct socket *netcstcplisten(int port, int local, void (*func)(struct socket *, struct socket *, void *), void *data)
+struct socket *netcstcplisten(int port, int local, int (*func)(struct socket *, struct socket *, void *), void *data)
 {
     struct sockaddr_in addr;
 #ifdef HAVE_IPV6
     struct sockaddr_in6 addr6;
 #endif
-    struct socket *(*csfunc)(int, struct sockaddr *, socklen_t, void (*)(struct socket *, struct socket *, void *), void *);
+    struct socket *(*csfunc)(int, struct sockaddr *, socklen_t, int (*)(struct socket *, struct socket *, void *), void *);
     struct socket *ret;
     
     if(local)
@@ -682,7 +668,7 @@ void netdgramconn(struct socket *sk, struct sockaddr *addr, socklen_t addrlen)
     sk->ignread = 1;
 }
 
-struct socket *netcsconn(struct sockaddr *addr, socklen_t addrlen, void (*func)(struct socket *, int, void *), void *data)
+struct socket *netcsconn(struct sockaddr *addr, socklen_t addrlen, int (*func)(struct socket *, int, void *), void *data)
 {
     struct socket *sk;
     int mode;
@@ -702,8 +688,8 @@ struct socket *netcsconn(struct sockaddr *addr, socklen_t addrlen, void (*func)(
        if(errno == EINPROGRESS)
        {
            sk->state = SOCK_SYN;
-           sk->conncb = func;
-           sk->data = data;
+           if(func != NULL)
+               CBREG(sk, socket_conn, func, NULL, data);
            return(sk);
        }
        putsock(sk);
@@ -715,7 +701,8 @@ struct socket *netcsconn(struct sockaddr *addr, socklen_t addrlen, void (*func)(
 
 int pollsocks(int timeout)
 {
-    int i, num, ret, retlen;
+    int i, num, ret;
+    socklen_t retlen;
     int newfd;
     struct pollfd *pfds;
     struct socket *sk, *next, *newsk;
@@ -769,10 +756,7 @@ int pollsocks(int timeout)
            {
                sslen = sizeof(ss);
                if((newfd = accept(sk->fd, (struct sockaddr *)&ss, &sslen)) < 0)
-               {
-                   if(sk->errcb != NULL)
-                       sk->errcb(sk, errno, sk->data);
-               }
+                   CBCHAINDOCB(sk, socket_err, sk, errno);
                newsk = newsock(sk->type);
                newsk->fd = newfd;
                newsk->family = sk->family;
@@ -780,15 +764,13 @@ int pollsocks(int timeout)
                memcpy(newsk->remote = smalloc(sslen), &ss, sslen);
                newsk->remotelen = sslen;
                putsock(newsk);
-               if(sk->acceptcb != NULL)
-                   sk->acceptcb(sk, newsk, sk->data);
+               CBCHAINDOCB(sk, socket_accept, sk, newsk);
            }
            if(pfds[i].revents & POLLERR)
            {
                retlen = sizeof(ret);
                getsockopt(sk->fd, SOL_SOCKET, SO_ERROR, &ret, &retlen);
-               if(sk->errcb != NULL)
-                   sk->errcb(sk, ret, sk->data);
+               CBCHAINDOCB(sk, socket_err, sk, ret);
                continue;
            }
            break;
@@ -797,16 +779,14 @@ int pollsocks(int timeout)
            {
                retlen = sizeof(ret);
                getsockopt(sk->fd, SOL_SOCKET, SO_ERROR, &ret, &retlen);
-               if(sk->conncb != NULL)
-                   sk->conncb(sk, ret, sk->data);
+               CBCHAINDOCB(sk, socket_conn, sk, ret);
                closesock(sk);
                continue;
            }
            if(pfds[i].revents & (POLLIN | POLLOUT))
            {
                sk->state = SOCK_EST;
-               if(sk->conncb != NULL)
-                   sk->conncb(sk, 0, sk->data);
+               CBCHAINDOCB(sk, socket_conn, sk, 0);
            }
            break;
        case SOCK_EST:
@@ -814,8 +794,7 @@ int pollsocks(int timeout)
            {
                retlen = sizeof(ret);
                getsockopt(sk->fd, SOL_SOCKET, SO_ERROR, &ret, &retlen);
-               if(sk->errcb != NULL)
-                   sk->errcb(sk, ret, sk->data);
+               CBCHAINDOCB(sk, socket_err, sk, ret);
                closesock(sk);
                continue;
            }
@@ -837,8 +816,7 @@ int pollsocks(int timeout)
        }
        if(pfds[i].revents & POLLHUP)
        {
-           if(sk->errcb != NULL)
-               sk->errcb(sk, 0, sk->data);
+           CBCHAINDOCB(sk, socket_err, sk, 0);
            closesock(sk);
            unlinksock(sk);
            continue;
@@ -1044,7 +1022,7 @@ int sockgetlocalname(struct socket *sk, struct sockaddr **namebuf, socklen_t *le
     len = sizeof(name);
     if(getsockname(sk->fd, (struct sockaddr *)&name, &len) < 0)
     {
-       flog(LOG_ERR, "BUG: alive socket with dead fd in sockgetlocalname");
+       flog(LOG_ERR, "BUG: alive socket with dead fd in sockgetlocalname (%s)", strerror(errno));
        return(-1);
     }
     *namebuf = memcpy(smalloc(len), &name, len);
@@ -1052,40 +1030,68 @@ int sockgetlocalname(struct socket *sk, struct sockaddr **namebuf, socklen_t *le
     return(0);
 }
 
+static void sethostaddr(struct sockaddr *dst, struct sockaddr *src)
+{
+    if(dst->sa_family != src->sa_family)
+    {
+       flog(LOG_ERR, "BUG: non-matching socket families in sethostaddr (%i -> %i)", src->sa_family, dst->sa_family);
+       return;
+    }
+    switch(src->sa_family)
+    {
+    case AF_INET:
+       ((struct sockaddr_in *)dst)->sin_addr = ((struct sockaddr_in *)src)->sin_addr;
+       break;
+    case AF_INET6:
+       ((struct sockaddr_in6 *)dst)->sin6_addr = ((struct sockaddr_in6 *)src)->sin6_addr;
+       break;
+    default:
+       flog(LOG_WARNING, "sethostaddr unimplemented for family %i", src->sa_family);
+       break;
+    }
+}
+
+static int makepublic(struct sockaddr *addr)
+{
+    int ret;
+    socklen_t plen;
+    struct sockaddr *pname;
+    
+    if((ret = getpublicaddr(addr->sa_family, &pname, &plen)) < 0)
+    {
+       flog(LOG_ERR, "could not get public address: %s", strerror(errno));
+       return(-1);
+    }
+    if(ret)
+       return(0);
+    sethostaddr(addr, pname);
+    free(pname);
+    return(0);
+}
+
 int sockgetremotename(struct socket *sk, struct sockaddr **namebuf, socklen_t *lenbuf)
 {
     socklen_t len;
-    struct sockaddr_storage name;
-    struct sockaddr_in *ipv4;
-    struct sockaddr *pname;
-    socklen_t pnamelen;
+    struct sockaddr *name;
     
     switch(confgetint("net", "mode"))
     {
     case 0:
        *namebuf = NULL;
        if((sk->state == SOCK_STL) || (sk->fd < 0))
-           return(-1);
-       len = sizeof(name);
-       if(getsockname(sk->fd, (struct sockaddr *)&name, &len) < 0)
        {
-           flog(LOG_ERR, "BUG: alive socket with dead fd in sockgetremotename");
+           errno = EBADF;
            return(-1);
        }
-       if(name.ss_family == AF_INET)
+       if(!sockgetlocalname(sk, &name, &len))
        {
-           ipv4 = (struct sockaddr_in *)&name;
-           if(getpublicaddr(AF_INET, &pname, &pnamelen) < 0)
-           {
-               flog(LOG_WARNING, "could not determine public IP address - strange things may happen");
-               return(-1);
-           }
-           ipv4->sin_addr.s_addr = ((struct sockaddr_in *)pname)->sin_addr.s_addr;
-           free(pname);
+           *namebuf = name;
+           *lenbuf = len;
+           makepublic(name);
+           return(0);
        }
-       *namebuf = memcpy(smalloc(len), &name, len);
-       *lenbuf = len;
-       return(0);
+       flog(LOG_ERR, "could not get remotely accessible name by any means");
+       return(-1);
     case 1:
        errno = EOPNOTSUPP;
        return(-1);
@@ -1096,6 +1102,29 @@ int sockgetremotename(struct socket *sk, struct sockaddr **namebuf, socklen_t *l
     }
 }
 
+int sockgetremotename2(struct socket *sk, struct socket *sk2, struct sockaddr **namebuf, socklen_t *lenbuf)
+{
+    struct sockaddr *name1, *name2;
+    socklen_t len1, len2;
+    
+    if(sk->family != sk2->family)
+    {
+       flog(LOG_ERR, "using sockgetremotename2 with sockets of differing family: %i %i", sk->family, sk2->family);
+       return(-1);
+    }
+    if(sockgetremotename(sk, &name1, &len1))
+       return(-1);
+    if(sockgetremotename(sk2, &name2, &len2)) {
+       free(name1);
+       return(-1);
+    }
+    sethostaddr(name1, name2);
+    free(name2);
+    *namebuf = name1;
+    *lenbuf = len1;
+    return(0);
+}
+
 int addreq(struct sockaddr *x, struct sockaddr *y)
 {
     struct sockaddr_un *u1, *u2;
@@ -1119,6 +1148,7 @@ int addreq(struct sockaddr *x, struct sockaddr *y)
        if(n1->sin_addr.s_addr != n2->sin_addr.s_addr)
            return(0);
        break;
+#ifdef HAVE_IPV6
     case AF_INET6:
        s1 = (struct sockaddr_in6 *)x; s2 = (struct sockaddr_in6 *)y;
        if(s1->sin6_port != s2->sin6_port)
@@ -1126,6 +1156,7 @@ int addreq(struct sockaddr *x, struct sockaddr *y)
        if(memcmp(s1->sin6_addr.s6_addr, s2->sin6_addr.s6_addr, sizeof(s1->sin6_addr.s6_addr)))
            return(0);
        break;
+#endif
     }
     return(1);
 }