Allow the usage of the TOS API to set DSCP values on IPv4 sockets.
[doldaconnect.git] / daemon / net.c
index 9b5a34a..e7c63d1 100644 (file)
@@ -28,7 +28,7 @@
 #include <sys/ioctl.h>
 #include <sys/socket.h>
 #include <sys/un.h>
-#include <sys/poll.h>
+#include <sys/select.h>
 #include <arpa/inet.h>
 #include <netinet/in.h>
 #include <netdb.h>
@@ -66,8 +66,6 @@ static struct configvar myvars[] =
      * and net.visibleipv4 are unspecified the address of the hub
      * connection is used. */
     {CONF_VAR_STRING, "publicif", {.str = L""}},
-    /* Diffserv should be supported on IPv4, too, but I don't know the
-     * API to do that. */
     /** The Diffserv value to use on IPv6 connections when the
      * minimize cost TOS value is used (see the TOS VALUES
      * section). */
@@ -84,6 +82,11 @@ static struct configvar myvars[] =
      * minimize delay TOS value is used (see the TOS VALUES
      * section). */
     {CONF_VAR_INT, "diffserv-mindelay", {.num = 0}},
+    /** If enabled, the IP TOS interface will be used to set Diffserv
+     * codepoints on IPv4 sockets, by shifting the DSCP value to bits
+     * to the left (across the ECN bits). This may only work on
+     * Linux. */
+    {CONF_VAR_BOOL, "dscp-tos", {.num = 0}},
     {CONF_VAR_END}
 };
 
@@ -172,6 +175,7 @@ static struct socket *newsock(int type)
     struct socket *new;
     
     new = smalloc(sizeof(*new));
+    memset(new, 0, sizeof(*new));
     new->refcount = 2;
     new->fd = -1;
     new->isrealsocket = 1;
@@ -388,12 +392,17 @@ static void sockrecv(struct socket *sk)
 #if defined(HAVE_LINUX_SOCKIOS_H) && defined(SIOCINQ)
        /* SIOCINQ is Linux-specific AFAIK, but I really have no idea
         * how to read the inqueue size on other OSs */
-       if(ioctl(sk->fd, SIOCINQ, &inq))
-       {
-           /* I don't really know what could go wrong here, so let's
-            * assume it's transient. */
-           flog(LOG_WARNING, "SIOCINQ return %s on socket %i, falling back to 2048 bytes", strerror(errno), sk->fd);
-           inq = 2048;
+       if(sk->isrealsocket) {
+           if(ioctl(sk->fd, SIOCINQ, &inq))
+           {
+               /* I don't really know what could go wrong here, so let's
+                * assume it's transient. */
+               flog(LOG_WARNING, "SIOCINQ return %s on socket %i, falling back to 2048 bytes", strerror(errno), sk->fd);
+               inq = 2048;
+           }
+       } else {
+           /* There are perils when trying to use SIOCINQ on files >2GiB... */
+           inq = 65536;
        }
 #else
        inq = 2048;
@@ -520,8 +529,12 @@ static void sockflush(struct socket *sk)
            ret = write(sk->fd, sk->outbuf.s.buf, sk->outbuf.s.datasize);
        if(ret < 0)
        {
-           /* For now, assume transient error, since
-            * the socket is polled for errors */
+           if((errno != EINTR) && (errno != EAGAIN))
+           {
+               if(sk->errcb != NULL)
+                   sk->errcb(sk, errno, sk->data);
+               closesock(sk);
+           }
            break;
        }
        if(ret > 0)
@@ -838,61 +851,55 @@ static void acceptunix(struct socket *sk)
 
 int pollsocks(int timeout)
 {
-    int i, num, ret;
+    int ret, fd;
     socklen_t retlen;
-    int newfd;
-    struct pollfd *pfds;
+    int newfd, maxfd;
+    fd_set rfds, wfds, efds;
     struct socket *sk, *next, *newsk;
     struct sockaddr_storage ss;
     socklen_t sslen;
+    struct timeval tv;
     
-    pfds = smalloc(sizeof(*pfds) * (num = numsocks));
-    for(i = 0, sk = sockets; i < num; sk = sk->next)
+    FD_ZERO(&rfds);
+    FD_ZERO(&wfds);
+    FD_ZERO(&efds);
+    for(maxfd = 0, sk = sockets; sk != NULL; sk = sk->next)
     {
-       if(sk->state == SOCK_STL)
-       {
-           num--;
+       if((sk->state == SOCK_STL) || (sk->fd < 0))
            continue;
-       }
-       pfds[i].fd = sk->fd;
-       pfds[i].events = 0;
        if(!sk->ignread)
-           pfds[i].events |= POLLIN;
+           FD_SET(sk->fd, &rfds);
        if((sk->state == SOCK_SYN) || (sockqueuesize(sk) > 0))
-           pfds[i].events |= POLLOUT;
-       pfds[i].revents = 0;
-       i++;
+           FD_SET(sk->fd, &wfds);
+       FD_SET(sk->fd, &efds);
+       if(sk->fd > maxfd)
+           maxfd = sk->fd;
     }
-    ret = poll(pfds, num, timeout);
+    tv.tv_sec = timeout / 1000;
+    tv.tv_usec = (timeout % 1000) * 1000;
+    ret = select(maxfd + 1, &rfds, &wfds, &efds, (timeout < 0)?NULL:&tv);
     if(ret < 0)
     {
        if(errno != EINTR)
        {
-           flog(LOG_CRIT, "pollsocks: poll errored out: %s", strerror(errno));
+           flog(LOG_CRIT, "pollsocks: select errored out: %s", strerror(errno));
            /* To avoid CPU hogging in case it's bad, which it
             * probably is. */
            sleep(1);
        }
-       free(pfds);
        return(1);
     }
     for(sk = sockets; sk != NULL; sk = next)
     {
        next = sk->next;
-       for(i = 0; i < num; i++)
-       {
-           if(pfds[i].fd == sk->fd)
-               break;
-       }
-       if(i == num)
-           continue;
+       fd = sk->fd;
        switch(sk->state)
        {
        case SOCK_LST:
-           if(pfds[i].revents & POLLIN)
+           if(FD_ISSET(fd, &rfds))
            {
                sslen = sizeof(ss);
-               if((newfd = accept(sk->fd, (struct sockaddr *)&ss, &sslen)) < 0)
+               if((newfd = accept(fd, (struct sockaddr *)&ss, &sslen)) < 0)
                {
                    if(sk->errcb != NULL)
                        sk->errcb(sk, errno, sk->data);
@@ -909,26 +916,26 @@ int pollsocks(int timeout)
                    sk->acceptcb(sk, newsk, sk->data);
                putsock(newsk);
            }
-           if(pfds[i].revents & POLLERR)
+           if(FD_ISSET(fd, &efds))
            {
                retlen = sizeof(ret);
-               getsockopt(sk->fd, SOL_SOCKET, SO_ERROR, &ret, &retlen);
+               getsockopt(fd, SOL_SOCKET, SO_ERROR, &ret, &retlen);
                if(sk->errcb != NULL)
                    sk->errcb(sk, ret, sk->data);
                continue;
            }
            break;
        case SOCK_SYN:
-           if(pfds[i].revents & POLLERR)
+           if(FD_ISSET(fd, &efds))
            {
                retlen = sizeof(ret);
-               getsockopt(sk->fd, SOL_SOCKET, SO_ERROR, &ret, &retlen);
+               getsockopt(fd, SOL_SOCKET, SO_ERROR, &ret, &retlen);
                if(sk->conncb != NULL)
                    sk->conncb(sk, ret, sk->data);
                closesock(sk);
                continue;
            }
-           if(pfds[i].revents & (POLLIN | POLLOUT))
+           if(FD_ISSET(fd, &rfds) || FD_ISSET(fd, &wfds))
            {
                sk->state = SOCK_EST;
                if(sk->conncb != NULL)
@@ -936,41 +943,25 @@ int pollsocks(int timeout)
            }
            break;
        case SOCK_EST:
-           if(pfds[i].revents & POLLERR)
+           if(FD_ISSET(fd, &efds))
            {
                retlen = sizeof(ret);
-               getsockopt(sk->fd, SOL_SOCKET, SO_ERROR, &ret, &retlen);
+               getsockopt(fd, SOL_SOCKET, SO_ERROR, &ret, &retlen);
                if(sk->errcb != NULL)
                    sk->errcb(sk, ret, sk->data);
                closesock(sk);
                continue;
            }
-           if(pfds[i].revents & POLLIN)
+           if(FD_ISSET(fd, &rfds))
                sockrecv(sk);
-           if(pfds[i].revents & POLLOUT)
+           if(FD_ISSET(fd, &wfds))
            {
                if(sockqueuesize(sk) > 0)
                    sockflush(sk);
            }
            break;
        }
-       if(pfds[i].revents & POLLNVAL)
-       {
-           flog(LOG_CRIT, "BUG: stale socket struct on fd %i", sk->fd);
-           sk->state = SOCK_STL;
-           unlinksock(sk);
-           continue;
-       }
-       if(pfds[i].revents & POLLHUP)
-       {
-           if(sk->errcb != NULL)
-               sk->errcb(sk, 0, sk->data);
-           closesock(sk);
-           unlinksock(sk);
-           continue;
-       }
     }
-    free(pfds);
     for(sk = sockets; sk != NULL; sk = next)
     {
        next = sk->next;
@@ -993,27 +984,41 @@ int pollsocks(int timeout)
 int socksettos(struct socket *sk, int tos)
 {
     int buf;
+    int dscp2tos;
     
     if(sk->family == AF_UNIX)
        return(0); /* Unix sockets are always perfect. :) */
     if(sk->family == AF_INET)
     {
+       dscp2tos = confgetint("net", "dscp-tos");
        switch(tos)
        {
        case 0:
            buf = 0;
            break;
        case SOCK_TOS_MINCOST:
-           buf = 0x02;
+           if(dscp2tos)
+               buf = confgetint("net", "diffserv-mincost") << 2;
+           else
+               buf = 0x02;
            break;
        case SOCK_TOS_MAXREL:
-           buf = 0x04;
+           if(dscp2tos)
+               buf = confgetint("net", "diffserv-maxrel") << 2;
+           else
+               buf = 0x04;
            break;
        case SOCK_TOS_MAXTP:
-           buf = 0x08;
+           if(dscp2tos)
+               buf = confgetint("net", "diffserv-maxtp") << 2;
+           else
+               buf = 0x08;
            break;
        case SOCK_TOS_MINDELAY:
-           buf = 0x10;
+           if(dscp2tos)
+               buf = confgetint("net", "diffserv-mindelay") << 2;
+           else
+               buf = 0x10;
            break;
        default:
            flog(LOG_WARNING, "attempted to set unknown TOS value %i to IPv4 sock", tos);