]> git.openfabrics.org - ~shefty/librdmacm.git/commitdiff
librspreload: Support server apps that call fork()
authorSean Hefty <sean.hefty@intel.com>
Fri, 13 Jul 2012 22:25:53 +0000 (15:25 -0700)
committerSean Hefty <sean.hefty@intel.com>
Thu, 19 Jul 2012 21:22:07 +0000 (14:22 -0700)
Provide limited support for applications that call fork() after
accepting a connection.

Fork support is indicated by setting the environment variable
RDMAV_FORK_SAFE.

Signed-off-by: Sean Hefty <sean.hefty@intel.com>
src/preload.c

index d2058e23906cb54843add199ab0746406eb53dc8..f824af3fb2a0fc41927cfe00189a8e962fe2c374 100644 (file)
@@ -46,6 +46,8 @@
 #include <string.h>
 #include <netinet/in.h>
 #include <netinet/tcp.h>
+#include <unistd.h>
+#include <semaphore.h>
 
 #include <rdma/rdma_cma.h>
 #include <rdma/rdma_verbs.h>
@@ -81,6 +83,7 @@ struct socket_calls {
        int (*getsockopt)(int socket, int level, int optname,
                          void *optval, socklen_t *optlen);
        int (*fcntl)(int socket, int cmd, ... /* arg */);
+       pid_t (*fork)(void);
 };
 
 static struct socket_calls real;
@@ -92,10 +95,13 @@ static pthread_mutex_t mut = PTHREAD_MUTEX_INITIALIZER;
 static int sq_size;
 static int rq_size;
 static int sq_inline;
+static int fork_support;
+static int last_accept = -1;
 
 enum fd_type {
        fd_normal,
-       fd_rsocket
+       fd_rsocket,
+       fd_fork
 };
 
 struct fd_info {
@@ -207,6 +213,10 @@ void getenv_options(void)
        var = getenv("RS_INLINE");
        if (var)
                sq_inline = atoi(var);
+
+       var = getenv("RDMAV_FORK_SAFE");
+       if (var)
+               fork_support = atoi(var);
 }
 
 static void init_preload(void)
@@ -244,6 +254,7 @@ static void init_preload(void)
        real.setsockopt = dlsym(RTLD_NEXT, "setsockopt");
        real.getsockopt = dlsym(RTLD_NEXT, "getsockopt");
        real.fcntl = dlsym(RTLD_NEXT, "fcntl");
+       real.fork = dlsym(RTLD_NEXT, "fork");
 
        rs.socket = dlsym(RTLD_DEFAULT, "rsocket");
        rs.bind = dlsym(RTLD_DEFAULT, "rbind");
@@ -378,8 +389,16 @@ int socket(int domain, int type, int protocol)
        ret = rsocket(domain, type, protocol);
        recursive = 0;
        if (ret >= 0) {
-               fd_store(index, ret, fd_rsocket);
-               set_rsocket_options(ret);
+               if (fork_support) {
+                       rclose(ret);
+                       ret = real.socket(domain, type, protocol);
+                       if (ret < 0)
+                               return ret;
+                       fd_store(index, ret, fd_fork);
+               } else {
+                       fd_store(index, ret, fd_rsocket);
+                       set_rsocket_options(ret);
+               }
                return index;
        }
        fd_close(index, &ret);
@@ -418,31 +437,67 @@ int listen(int socket, int backlog)
 int accept(int socket, struct sockaddr *addr, socklen_t *addrlen)
 {
        int fd, index, ret;
+       enum fd_type type;
 
-       if (fd_get(socket, &fd) == fd_rsocket) {
+       type = fd_get(socket, &fd);
+       if (type == fd_rsocket || type == fd_fork) {
                index = fd_open();
                if (index < 0)
                        return index;
 
-               ret = raccept(fd, addr, addrlen);
+               ret = (type == fd_rsocket) ? raccept(fd, addr, addrlen) :
+                                            real.accept(fd, addr, addrlen);
                if (ret < 0) {
                        fd_close(index, &fd);
                        return ret;
                }
 
-               fd_store(index, ret, fd_rsocket);
+               fd_store(index, ret, type);
+               last_accept = (type == fd_fork) ? index : -1;
                return index;
        } else {
+               last_accept = -1;
                return real.accept(fd, addr, addrlen);
        }
 }
 
+static int connect_fork(int socket, const struct sockaddr *addr, socklen_t addrlen)
+{
+       int fd, ret;
+       uint32_t msg;
+       long flags;
+
+       fd = fd_getd(socket);
+       flags = real.fcntl(fd, F_GETFL);
+       real.fcntl(fd, F_SETFL, 0);
+       ret = real.connect(fd, addr, addrlen);
+       if (ret)
+               return ret;
+
+       ret = real.recv(fd, &msg, sizeof msg, MSG_PEEK);
+       if ((ret != sizeof msg) || msg) {
+               fd_store(socket, fd, fd_normal);
+               return 0;
+       }
+
+       real.fcntl(fd, F_SETFL, flags);
+       ret = transpose_socket(socket, fd_rsocket);
+       if (ret < 0)
+               return ret;
+
+       real.close(fd);
+       return rconnect(ret, addr, addrlen);
+}
+
 int connect(int socket, const struct sockaddr *addr, socklen_t addrlen)
 {
        struct sockaddr_in *sin;
        int fd, ret;
 
-       if (fd_get(socket, &fd) == fd_rsocket) {
+       switch (fd_get(socket, &fd)) {
+       case fd_fork:
+               return connect_fork(socket, addr, addrlen);
+       case fd_rsocket:
                sin = (struct sockaddr_in *) addr;
                if (ntohs(sin->sin_port) > 1024) {
                        ret = rconnect(fd, addr, addrlen);
@@ -456,6 +511,9 @@ int connect(int socket, const struct sockaddr *addr, socklen_t addrlen)
 
                rclose(fd);
                fd = ret;
+               break;
+       default:
+               break;
        }
 
        return real.connect(fd, addr, addrlen);
@@ -754,3 +812,85 @@ int fcntl(int socket, int cmd, ... /* arg */)
        va_end(args);
        return ret;
 }
+
+/*
+ * We can't fork RDMA connections and pass them from the parent to the child
+ * process.  Intercept the fork call, and if we're the child establish the
+ * RDMA connection after calling fork.  The assumption is that the last
+ * connection accepted by the server will be processed by the child after the
+ * fork call.
+ *
+ * It would be better to establishing the RDMA connection once the child
+ * process tries to use the connection after the fork call (i.e. in a read
+ * or write call), rather than making the previous assumption.
+ */
+pid_t fork(void)
+{
+       struct sockaddr_in6 sin6;
+       pid_t pid;
+       sem_t *sem;
+       int lfd, sfd, dfd, ret, param;
+       socklen_t len;
+       uint32_t msg;
+
+       init_preload();
+       pid = real.fork();
+       if (pid || !fork_support || (last_accept < 0) ||
+           (fd_get(last_accept, &sfd) != fd_fork))
+               goto out;
+
+       len = sizeof sin6;
+       ret = real.getsockname(sfd, (struct sockaddr *) &sin6, &len);
+       if (ret)
+               goto out;
+       sin6.sin6_flowinfo = sin6.sin6_scope_id = 0;
+       memset(&sin6.sin6_addr, 0, sizeof sin6.sin6_addr);
+
+       sem = sem_open("/rsocket_fork", O_CREAT | O_RDWR,
+                      S_IRWXU | S_IRWXG, 1);
+       if (sem == SEM_FAILED)
+               goto out;
+
+       lfd = rsocket(sin6.sin6_family, SOCK_STREAM, 0);
+       if (lfd < 0)
+               goto sclose;
+
+       param = 1;
+       rsetsockopt(lfd, SOL_SOCKET, SO_REUSEADDR, &param, sizeof param);
+
+       sem_wait(sem);
+       ret = rbind(lfd, (struct sockaddr *) &sin6, sizeof sin6);
+       if (ret)
+               goto lclose;
+
+       ret = rlisten(lfd, 1);
+       if (ret)
+               goto lclose;
+
+       msg = 0;
+       ret = real.write(sfd, &msg, sizeof msg);
+       if (ret != sizeof msg)
+               goto lclose;
+
+       dfd = raccept(lfd, NULL, NULL);
+       if (dfd < 0)
+               goto lclose;
+
+       param = 1;
+       rsetsockopt(dfd, IPPROTO_TCP, TCP_NODELAY, &param, sizeof param);
+       set_rsocket_options(dfd);
+
+       copysockopts(dfd, sfd, &rs, &real);
+       real.shutdown(sfd, SHUT_RDWR);
+       real.close(sfd);
+       fd_store(last_accept, dfd, fd_rsocket);
+
+lclose:
+       rclose(lfd);
+       sem_post(sem);
+sclose:
+       sem_close(sem);
+out:
+       last_accept = -1;
+       return pid;
+}