Drop looped packets.
[mctap.git] / src / mctap.c
index b021d88..24efc12 100644 (file)
@@ -6,15 +6,25 @@
 #include <netinet/in.h>
 #include <arpa/inet.h>
 #include <errno.h>
+#include <net/if.h>
+#include <linux/if_tun.h>
+#include <fcntl.h>
+#include <sys/ioctl.h>
+#include <sys/poll.h>
+#include <syslog.h>
+#include <signal.h>
 
 #include "utils.h"
 
+static int quit = 0;
+static unsigned char macaddr[6];
+
 static void usage(FILE *out)
 {
-    fprintf(out, "usage: mctap [-h] MCASTGROUP PORT\n");
+    fprintf(out, "usage: mctap [-hdpk] [-P PIDFILE] [-D TAPNAME] MCASTGROUP PORT\n");
 }
 
-static char *formataddress(struct sockaddr *arg, socklen_t arglen)
+static __attribute__ ((unused)) char *formataddress(struct sockaddr *arg, socklen_t arglen)
 {
     struct sockaddr_in *ipv4;
     struct sockaddr_in6 *ipv6;
@@ -27,7 +37,7 @@ static char *formataddress(struct sockaddr *arg, socklen_t arglen)
     switch(arg->sa_family)
     {
     case AF_UNIX:
-       ret = strdup("Unix socket");
+       ret = sstrdup("Unix socket");
        break;
     case AF_INET:
        ipv4 = (struct sockaddr_in *)arg;
@@ -51,10 +61,14 @@ static char *formataddress(struct sockaddr *arg, socklen_t arglen)
 static int mkmcastsk4(struct in_addr group, int port)
 {
     int fd;
+    int soval;
     struct sockaddr_in nm;
     struct ip_mreqn mreq;
     
     fd = socket(PF_INET, SOCK_DGRAM, IPPROTO_UDP);
+    soval = 1;
+    if(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &soval, sizeof(soval)))
+       return(-1);
     memset(&nm, 0, sizeof(nm));
     nm.sin_family = AF_INET;
     nm.sin_port = htons(port);
@@ -62,38 +76,191 @@ static int mkmcastsk4(struct in_addr group, int port)
        return(-1);
     memset(&mreq, 0, sizeof(mreq));
     mreq.imr_multiaddr = group;
-    if(setsockopt(fd, IPPROTO_IP, IP_ADD_MEMBERSHIP, &mreq, sizeof(mreq)))
+    if(setsockopt(fd, SOL_IP, IP_ADD_MEMBERSHIP, &mreq, sizeof(mreq)))
+       return(-1);
+    soval = 1;
+    if(setsockopt(fd, SOL_IP, IP_MULTICAST_LOOP, &soval, sizeof(soval)))
        return(-1);
     return(fd);
 }
 
-static void test(int fd)
+static __attribute__ ((unused)) void test(int fd)
 {
     char buf[65536];
-    int ret;
-    struct sockaddr_storage nm;
-    socklen_t nmlen;
+    int i, ret;
+    struct pollfd pfd;
     
     while(1) {
-       nmlen = sizeof(nm);
-       ret = recvfrom(fd, buf, sizeof(buf), 0, (struct sockaddr *)&nm, &nmlen);
+       pfd.fd = fd;
+       pfd.events = POLLIN;
+       ret = poll(&pfd, 1, -1);
        if(ret < 0) {
-           fprintf(stderr, "mctap: recvfrom: %s\n", strerror(ret));
+           fprintf(stderr, "mctap: poll: %s\n", strerror(errno));
            exit(1);
        }
-       printf("%s %i:\n", formataddress((struct sockaddr *)&nm, nmlen), ret);
+       if(pfd.revents) {
+           ret = read(fd, buf, sizeof(buf));
+           if(ret < 0) {
+               fprintf(stderr, "mctap: read: %s\n", strerror(errno));
+               exit(1);
+           }
+           for(i = 0; i < ret; i++) {
+               printf("%02x ", (unsigned char)buf[i]);
+               if(i % 20 == 19)
+                   putchar(10);
+           }
+           putchar(10);
+       }
+    }
+}
+
+static void bridge(int sock, int tap, struct sockaddr *dst, socklen_t dstlen)
+{
+    char buf[65536];
+    int ret;
+    struct pollfd pfds[2];
+    struct tun_pi pi;
+    
+    fcntl(sock, F_SETFL, fcntl(sock, F_GETFL) | O_NONBLOCK);
+    fcntl(tap, F_SETFL, fcntl(tap, F_GETFL) | O_NONBLOCK);
+    while(!quit) {
+       pfds[0].fd = sock;
+       pfds[0].events = POLLIN;
+       pfds[1].fd = tap;
+       pfds[1].events = POLLIN;
+       ret = poll(pfds, 2, -1);
+       if(ret < 0) {
+           if(errno != EINTR) {
+               syslog(LOG_ERR, "mctap: poll: %s", strerror(errno));
+               exit(1);
+           }
+           continue;
+       }
+       if(pfds[0].revents) {
+           ret = read(sock, buf, sizeof(buf));
+           if(ret < 0) {
+               if((errno != EINTR) && (errno != EAGAIN)) {
+                   syslog(LOG_ERR, "mctap: mcast packet: %s", strerror(errno));
+                   exit(1);
+               }
+           } else {
+               if(sizeof(buf) - ret < sizeof(pi)) {
+                   /* Drop */
+               } else if((ret < 12) || !memcmp(macaddr, buf + 6, 6)) {
+                   /* Drop looped back */
+               } else {
+                   memmove(buf + sizeof(pi), buf, ret);
+                   pi.flags = 0;
+                   pi.proto = 0;
+                   memcpy(buf, &pi, sizeof(pi));
+                   write(tap, buf, sizeof(pi) + ret);
+               }
+           }
+       }
+       if(pfds[1].revents) {
+           ret = read(tap, buf, sizeof(buf));
+           if(ret < 0) {
+               if((errno != EINTR) && (errno != EAGAIN)) {
+                   syslog(LOG_ERR, "mctap: mcast packet: %s", strerror(errno));
+                   exit(1);
+               }
+           } else {
+               if(ret < sizeof(pi)) {
+                   /* Drop */
+               } else {
+                   memcpy(&pi, buf, sizeof(pi));
+                   if(pi.flags & TUN_PKT_STRIP) {
+                       /* Drop */
+                   } else {
+                       sendto(sock, buf + sizeof(pi), ret - sizeof(pi), 0, dst, dstlen);
+                   }
+               }
+           }
+       }
+    }
+}
+
+static int maketap(char *name)
+{
+    int fd;
+    struct ifreq rb;
+    
+    if((fd = open("/dev/net/tun", O_RDWR)) < 0)
+       return(-1);
+    memset(&rb, 0, sizeof(rb));
+    rb.ifr_flags = IFF_TAP;
+    strncpy(rb.ifr_name, name, IFNAMSIZ);
+    if(ioctl(fd, TUNSETIFF, &rb))
+       return(-1);
+    if(ioctl(fd, SIOCGIFHWADDR, &rb))
+       return(-1);
+    memcpy(macaddr, rb.ifr_hwaddr.sa_data, 6);
+    return(fd);
+}
+
+static void sighand(int sig)
+{
+    switch(sig) {
+    case SIGINT:
+    case SIGTERM:
+       quit = 1;
+       break;
+    case SIGHUP:
+       break;
+    }
+}
+
+static void killrunning(char *pidfile)
+{
+    FILE *pidfd;
+    int pid;
+    
+    if((pidfd = fopen(pidfile, "r")) == NULL) {
+       fprintf(stderr, "mctab -k: could not read PID file %s: %s\n", pidfile, strerror(errno));
+       exit(1);
+    }
+    fscanf(pidfd, "%i", &pid);
+    if(kill(pid, SIGTERM)) {
+       fprintf(stderr, "mctab -k: could not kill %i: %s\n", pid, strerror(errno));
+       exit(1);
     }
+    fclose(pidfd);
 }
 
 int main(int argc, char **argv)
 {
     int c;
-    int sock;
+    int sock, tap;
     struct in_addr group;
     int port;
+    char *tapname;
+    char *pidfile;
+    int daemonize, killold;
+    struct sockaddr_in dst;
+    FILE *pidfd;
     
-    while((c = getopt(argc, argv, "h")) >= 0) {
+    tapname = "mctap";
+    daemonize = killold = 0;
+    pidfile = NULL;
+    while((c = getopt(argc, argv, "hD:dpP:k")) >= 0) {
        switch(c) {
+       case 'D':
+           tapname = optarg;
+           break;
+       case 'd':
+           daemonize = 1;
+           break;
+       case 'p':
+           pidfile = (void *)-1;
+           break;
+       case 'P':
+           pidfile = optarg;
+           break;
+       case 'k':
+           killold = 1;
+           if(pidfile == NULL)
+               pidfile = (void *)-1;
+           break;
        case 'h':
            usage(stdout);
            return(0);
@@ -102,6 +269,12 @@ int main(int argc, char **argv)
            exit(1);
        }
     }
+    if(pidfile == (void *)-1)
+       pidfile = sprintf2("/var/run/mctap.%s.pid", tapname);
+    if(killold) {
+       killrunning(pidfile);
+       return(0);
+    }
     if(argc - optind < 2) {
        usage(stderr);
        exit(1);
@@ -115,7 +288,37 @@ int main(int argc, char **argv)
        fprintf(stderr, "mctap: could not create multicast socket: %s\n", strerror(errno));
        exit(1);
     }
+    if((tap = maketap(tapname)) < 0) {
+       fprintf(stderr, "mctap: could not create TAP device: %s\n", strerror(errno));
+       exit(1);
+    }
+    openlog(sprintf2("mctap-%s", tapname), LOG_PID, LOG_DAEMON);
+    
+    pidfd = NULL;
+    if((pidfile != NULL) && ((pidfd = fopen(pidfile, "w")) == NULL)) {
+       fprintf(stderr, "mctap: could not create PID file %s: %s\n", pidfile, strerror(errno));
+       exit(1);
+    }
+    if(daemonize)
+       daemon(0, 0);
+    if(pidfd != NULL) {
+       fprintf(pidfd, "%i\n", getpid());
+       fclose(pidfd);
+    }
+    
+    signal(SIGTERM, sighand);
+    signal(SIGINT, sighand);
+    signal(SIGHUP, sighand);
+    
+    memset(&dst, 0, sizeof(dst));
+    dst.sin_family = AF_INET;
+    dst.sin_addr = group;
+    dst.sin_port = htons(port);
+    syslog(LOG_INFO, "bridge created with MAC %02x:%02x:%02x:%02x:%02x:%02x", macaddr[0], macaddr[1], macaddr[2], macaddr[3], macaddr[4], macaddr[5]);
+    bridge(sock, tap, (struct sockaddr *)&dst, sizeof(dst));
+    syslog(LOG_INFO, "exiting");
     
-    test(sock);
+    if(pidfile != NULL)
+       unlink(pidfile);
     return(0);
 }