]> git.openfabrics.org - ~shefty/librdmacm.git/commitdiff
librspreload: Support server apps that call fork()
authorSean Hefty <sean.hefty@intel.com>
Tue, 24 Jul 2012 18:40:10 +0000 (11:40 -0700)
committerSean Hefty <sean.hefty@intel.com>
Tue, 24 Jul 2012 18:40:10 +0000 (11:40 -0700)
Provide limited support for applications that call fork() after
accepting a connection.

Fork support is indicated by setting the environment variable
RDMAV_FORK_SAFE.

Signed-off-by: Sean Hefty <sean.hefty@intel.com>
src/preload.c

index d2058e23906cb54843add199ab0746406eb53dc8..79340c6f4965b98b260d2fcd775b2eba0e570548 100644 (file)
@@ -46,6 +46,8 @@
 #include <string.h>
 #include <netinet/in.h>
 #include <netinet/tcp.h>
+#include <unistd.h>
+#include <semaphore.h>
 
 #include <rdma/rdma_cma.h>
 #include <rdma/rdma_verbs.h>
@@ -81,6 +83,7 @@ struct socket_calls {
        int (*getsockopt)(int socket, int level, int optname,
                          void *optval, socklen_t *optlen);
        int (*fcntl)(int socket, int cmd, ... /* arg */);
+       pid_t (*fork)(void);
 };
 
 static struct socket_calls real;
@@ -92,10 +95,12 @@ static pthread_mutex_t mut = PTHREAD_MUTEX_INITIALIZER;
 static int sq_size;
 static int rq_size;
 static int sq_inline;
+static int fork_support;
 
 enum fd_type {
        fd_normal,
-       fd_rsocket
+       fd_rsocket,
+       fd_fork
 };
 
 struct fd_info {
@@ -207,6 +212,10 @@ void getenv_options(void)
        var = getenv("RS_INLINE");
        if (var)
                sq_inline = atoi(var);
+
+       var = getenv("RDMAV_FORK_SAFE");
+       if (var)
+               fork_support = atoi(var);
 }
 
 static void init_preload(void)
@@ -244,6 +253,7 @@ static void init_preload(void)
        real.setsockopt = dlsym(RTLD_NEXT, "setsockopt");
        real.getsockopt = dlsym(RTLD_NEXT, "getsockopt");
        real.fcntl = dlsym(RTLD_NEXT, "fcntl");
+       real.fork = dlsym(RTLD_NEXT, "fork");
 
        rs.socket = dlsym(RTLD_DEFAULT, "rsocket");
        rs.bind = dlsym(RTLD_DEFAULT, "rbind");
@@ -378,8 +388,16 @@ int socket(int domain, int type, int protocol)
        ret = rsocket(domain, type, protocol);
        recursive = 0;
        if (ret >= 0) {
-               fd_store(index, ret, fd_rsocket);
-               set_rsocket_options(ret);
+               if (fork_support) {
+                       rclose(ret);
+                       ret = real.socket(domain, type, protocol);
+                       if (ret < 0)
+                               return ret;
+                       fd_store(index, ret, fd_fork);
+               } else {
+                       fd_store(index, ret, fd_rsocket);
+                       set_rsocket_options(ret);
+               }
                return index;
        }
        fd_close(index, &ret);
@@ -418,31 +436,161 @@ int listen(int socket, int backlog)
 int accept(int socket, struct sockaddr *addr, socklen_t *addrlen)
 {
        int fd, index, ret;
+       enum fd_type type;
 
-       if (fd_get(socket, &fd) == fd_rsocket) {
+       type = fd_get(socket, &fd);
+       if (type == fd_rsocket || type == fd_fork) {
                index = fd_open();
                if (index < 0)
                        return index;
 
-               ret = raccept(fd, addr, addrlen);
+               ret = (type == fd_rsocket) ? raccept(fd, addr, addrlen) :
+                                            real.accept(fd, addr, addrlen);
                if (ret < 0) {
                        fd_close(index, &fd);
                        return ret;
                }
 
-               fd_store(index, ret, fd_rsocket);
+               fd_store(index, ret, type);
                return index;
        } else {
                return real.accept(fd, addr, addrlen);
        }
 }
 
+/*
+ * We can't fork RDMA connections and pass them from the parent to the child
+ * process.  Instead, we need to establish the RDMA connection after calling
+ * fork.  To do this, we delay establishing the RDMA connection until we try
+ * to send/receive on the server side.  On the client side, we don't expect
+ * to fork, so we switch from a TCP connection to an rsocket when connecting.
+ */
+static int fork_active(int socket, const struct sockaddr *addr, socklen_t addrlen)
+{
+       int fd, ret;
+       uint32_t msg;
+       long flags;
+
+       fd = fd_getd(socket);
+       flags = real.fcntl(fd, F_GETFL);
+       real.fcntl(fd, F_SETFL, 0);
+       ret = real.connect(fd, addr, addrlen);
+       if (ret)
+               return ret;
+
+       ret = real.recv(fd, &msg, sizeof msg, MSG_PEEK);
+       if ((ret != sizeof msg) || msg) {
+               fd_store(socket, fd, fd_normal);
+               return 0;
+       }
+
+       real.fcntl(fd, F_SETFL, flags);
+       ret = transpose_socket(socket, fd_rsocket);
+       if (ret < 0)
+               return ret;
+
+       real.close(fd);
+       return rconnect(ret, addr, addrlen);
+}
+
+static void fork_passive(int socket)
+{
+       struct sockaddr_in6 sin6;
+       sem_t *sem;
+       int lfd, sfd, dfd, ret, param;
+       socklen_t len;
+       uint32_t msg;
+
+       fd_get(socket, &sfd);
+
+       len = sizeof sin6;
+       ret = real.getsockname(sfd, (struct sockaddr *) &sin6, &len);
+       if (ret)
+               goto out;
+       sin6.sin6_flowinfo = sin6.sin6_scope_id = 0;
+       memset(&sin6.sin6_addr, 0, sizeof sin6.sin6_addr);
+
+       sem = sem_open("/rsocket_fork", O_CREAT | O_RDWR,
+                      S_IRWXU | S_IRWXG, 1);
+       if (sem == SEM_FAILED) {
+               ret = -1;
+               goto out;
+       }
+
+       lfd = rsocket(sin6.sin6_family, SOCK_STREAM, 0);
+       if (lfd < 0) {
+               ret  = lfd;
+               goto sclose;
+       }
+
+       param = 1;
+       rsetsockopt(lfd, SOL_SOCKET, SO_REUSEADDR, &param, sizeof param);
+
+       sem_wait(sem);
+       ret = rbind(lfd, (struct sockaddr *) &sin6, sizeof sin6);
+       if (ret)
+               goto lclose;
+
+       ret = rlisten(lfd, 1);
+       if (ret)
+               goto lclose;
+
+       msg = 0;
+       len = real.write(sfd, &msg, sizeof msg);
+       if (len != sizeof msg)
+               goto lclose;
+
+       dfd = raccept(lfd, NULL, NULL);
+       if (dfd < 0) {
+               ret  = dfd;
+               goto lclose;
+       }
+
+       param = 1;
+       rsetsockopt(dfd, IPPROTO_TCP, TCP_NODELAY, &param, sizeof param);
+       set_rsocket_options(dfd);
+
+       copysockopts(dfd, sfd, &rs, &real);
+       real.shutdown(sfd, SHUT_RDWR);
+       real.close(sfd);
+       fd_store(socket, dfd, fd_rsocket);
+
+lclose:
+       rclose(lfd);
+       sem_post(sem);
+sclose:
+       sem_close(sem);
+out:
+       if (ret)
+               fd_store(socket, sfd, fd_normal);
+}
+
+static inline enum fd_type fd_fork_get(int index, int *fd)
+{
+       struct fd_info *fdi;
+
+       fdi = idm_lookup(&idm, index);
+       if (fdi) {
+               if (fdi->type == fd_fork)
+                       fork_passive(index);
+               *fd = fdi->fd;
+               return fdi->type;
+
+       } else {
+               *fd = index;
+               return fd_normal;
+       }
+}
+
 int connect(int socket, const struct sockaddr *addr, socklen_t addrlen)
 {
        struct sockaddr_in *sin;
        int fd, ret;
 
-       if (fd_get(socket, &fd) == fd_rsocket) {
+       switch (fd_get(socket, &fd)) {
+       case fd_fork:
+               return fork_active(socket, addr, addrlen);
+       case fd_rsocket:
                sin = (struct sockaddr_in *) addr;
                if (ntohs(sin->sin_port) > 1024) {
                        ret = rconnect(fd, addr, addrlen);
@@ -456,6 +604,9 @@ int connect(int socket, const struct sockaddr *addr, socklen_t addrlen)
 
                rclose(fd);
                fd = ret;
+               break;
+       default:
+               break;
        }
 
        return real.connect(fd, addr, addrlen);
@@ -464,7 +615,7 @@ int connect(int socket, const struct sockaddr *addr, socklen_t addrlen)
 ssize_t recv(int socket, void *buf, size_t len, int flags)
 {
        int fd;
-       return (fd_get(socket, &fd) == fd_rsocket) ?
+       return (fd_fork_get(socket, &fd) == fd_rsocket) ?
                rrecv(fd, buf, len, flags) : real.recv(fd, buf, len, flags);
 }
 
@@ -472,7 +623,7 @@ ssize_t recvfrom(int socket, void *buf, size_t len, int flags,
                 struct sockaddr *src_addr, socklen_t *addrlen)
 {
        int fd;
-       return (fd_get(socket, &fd) == fd_rsocket) ?
+       return (fd_fork_get(socket, &fd) == fd_rsocket) ?
                rrecvfrom(fd, buf, len, flags, src_addr, addrlen) :
                real.recvfrom(fd, buf, len, flags, src_addr, addrlen);
 }
@@ -480,7 +631,7 @@ ssize_t recvfrom(int socket, void *buf, size_t len, int flags,
 ssize_t recvmsg(int socket, struct msghdr *msg, int flags)
 {
        int fd;
-       return (fd_get(socket, &fd) == fd_rsocket) ?
+       return (fd_fork_get(socket, &fd) == fd_rsocket) ?
                rrecvmsg(fd, msg, flags) : real.recvmsg(fd, msg, flags);
 }
 
@@ -488,7 +639,7 @@ ssize_t read(int socket, void *buf, size_t count)
 {
        int fd;
        init_preload();
-       return (fd_get(socket, &fd) == fd_rsocket) ?
+       return (fd_fork_get(socket, &fd) == fd_rsocket) ?
                rread(fd, buf, count) : real.read(fd, buf, count);
 }
 
@@ -496,14 +647,14 @@ ssize_t readv(int socket, const struct iovec *iov, int iovcnt)
 {
        int fd;
        init_preload();
-       return (fd_get(socket, &fd) == fd_rsocket) ?
+       return (fd_fork_get(socket, &fd) == fd_rsocket) ?
                rreadv(fd, iov, iovcnt) : real.readv(fd, iov, iovcnt);
 }
 
 ssize_t send(int socket, const void *buf, size_t len, int flags)
 {
        int fd;
-       return (fd_get(socket, &fd) == fd_rsocket) ?
+       return (fd_fork_get(socket, &fd) == fd_rsocket) ?
                rsend(fd, buf, len, flags) : real.send(fd, buf, len, flags);
 }
 
@@ -511,7 +662,7 @@ ssize_t sendto(int socket, const void *buf, size_t len, int flags,
                const struct sockaddr *dest_addr, socklen_t addrlen)
 {
        int fd;
-       return (fd_get(socket, &fd) == fd_rsocket) ?
+       return (fd_fork_get(socket, &fd) == fd_rsocket) ?
                rsendto(fd, buf, len, flags, dest_addr, addrlen) :
                real.sendto(fd, buf, len, flags, dest_addr, addrlen);
 }
@@ -519,7 +670,7 @@ ssize_t sendto(int socket, const void *buf, size_t len, int flags,
 ssize_t sendmsg(int socket, const struct msghdr *msg, int flags)
 {
        int fd;
-       return (fd_get(socket, &fd) == fd_rsocket) ?
+       return (fd_fork_get(socket, &fd) == fd_rsocket) ?
                rsendmsg(fd, msg, flags) : real.sendmsg(fd, msg, flags);
 }
 
@@ -527,7 +678,7 @@ ssize_t write(int socket, const void *buf, size_t count)
 {
        int fd;
        init_preload();
-       return (fd_get(socket, &fd) == fd_rsocket) ?
+       return (fd_fork_get(socket, &fd) == fd_rsocket) ?
                rwrite(fd, buf, count) : real.write(fd, buf, count);
 }
 
@@ -535,7 +686,7 @@ ssize_t writev(int socket, const struct iovec *iov, int iovcnt)
 {
        int fd;
        init_preload();
-       return (fd_get(socket, &fd) == fd_rsocket) ?
+       return (fd_fork_get(socket, &fd) == fd_rsocket) ?
                rwritev(fd, iov, iovcnt) : real.writev(fd, iov, iovcnt);
 }