]> git.openfabrics.org - ~shefty/librdmacm.git/commitdiff
rspreload: Do not block connect when supporting fork
authorSean Hefty <sean.hefty@intel.com>
Sat, 11 Aug 2012 04:44:39 +0000 (21:44 -0700)
committerSean Hefty <sean.hefty@intel.com>
Sat, 11 Aug 2012 04:44:39 +0000 (21:44 -0700)
Many FTP servers require fork support.  However, FTP clients,
such as ncftp, will perform the following call sequence:

send PASV request to server over connection 1
         server will listen for connection 2
issue nonblocking connect to server
send ACCEPT request to server over connection 1
         server will accept connection 2

The current fork support converts all nonblocking connect
calls to blocking.  The result is that the FTP client ends up
blocked waiting for the server to accept the connection,
which it will never do.

To handle this case, we have the active side follow the same
rule as the server side and defer establishing the rsocket
connection until the user calls the first data transfer routine.

Signed-off-by: Sean Hefty <sean.hefty@intel.com>
src/preload.c

index b18d31005a4759bef9531ed909c734ff0c22ae8d..9182b6aa7feaefc5ae7072eb4d79f86a757fd830 100644 (file)
@@ -99,12 +99,20 @@ static int fork_support;
 
 enum fd_type {
        fd_normal,
-       fd_rsocket,
-       fd_fork
+       fd_rsocket
+};
+
+enum fd_fork_state {
+       fd_ready,
+       fd_fork,
+       fd_fork_listen,
+       fd_fork_active,
+       fd_fork_passive
 };
 
 struct fd_info {
        enum fd_type type;
+       enum fd_fork_state state;
        int fd;
        int dupfd;
        atomic_t refcnt;
@@ -143,13 +151,14 @@ err1:
        return ret;
 }
 
-static void fd_store(int index, int fd, enum fd_type type)
+static void fd_store(int index, int fd, enum fd_type type, enum fd_fork_state state)
 {
        struct fd_info *fdi;
 
        fdi = idm_at(&idm, index);
        fdi->fd = fd;
        fdi->type = type;
+       fdi->state = state;
 }
 
 static inline enum fd_type fd_get(int index, int *fd)
@@ -175,6 +184,14 @@ static inline int fd_getd(int index)
        return fdi ? fdi->fd : index;
 }
 
+static inline enum fd_fork_state fd_gets(int index)
+{
+       struct fd_info *fdi;
+
+       fdi = idm_lookup(&idm, index);
+       return fdi ? fdi->state : fd_ready;
+}
+
 static inline enum fd_type fd_gett(int index)
 {
        struct fd_info *fdi;
@@ -353,7 +370,7 @@ static int transpose_socket(int socket, enum fd_type new_type)
        if (ret)
                goto err;
 
-       fd_store(socket, dfd, new_type);
+       fd_store(socket, dfd, new_type, fd_ready);
        return dfd;
 
 err:
@@ -398,9 +415,9 @@ int socket(int domain, int type, int protocol)
                        ret = real.socket(domain, type, protocol);
                        if (ret < 0)
                                return ret;
-                       fd_store(index, ret, fd_fork);
+                       fd_store(index, ret, fd_normal, fd_fork);
                } else {
-                       fd_store(index, ret, fd_rsocket);
+                       fd_store(index, ret, fd_rsocket, fd_ready);
                        set_rsocket_options(ret);
                }
                return index;
@@ -419,30 +436,46 @@ int bind(int socket, const struct sockaddr *addr, socklen_t addrlen)
 
 int listen(int socket, int backlog)
 {
-       int fd;
-       return (fd_get(socket, &fd) == fd_rsocket) ?
-               rlisten(fd, backlog) : real.listen(fd, backlog);
+       int fd, ret;
+       if (fd_get(socket, &fd) == fd_rsocket) {
+               ret = rlisten(fd, backlog);
+       } else {
+               ret = real.listen(fd, backlog);
+               if (!ret && fd_gets(socket) == fd_fork)
+                       fd_store(socket, fd, fd_normal, fd_fork_listen);
+       }
+       return ret;
 }
 
 int accept(int socket, struct sockaddr *addr, socklen_t *addrlen)
 {
        int fd, index, ret;
-       enum fd_type type;
 
-       type = fd_get(socket, &fd);
-       if (type == fd_rsocket || type == fd_fork) {
+       if (fd_get(socket, &fd) == fd_rsocket) {
                index = fd_open();
                if (index < 0)
                        return index;
 
-               ret = (type == fd_rsocket) ? raccept(fd, addr, addrlen) :
-                                            real.accept(fd, addr, addrlen);
+               ret = raccept(fd, addr, addrlen);
                if (ret < 0) {
                        fd_close(index, &fd);
                        return ret;
                }
 
-               fd_store(index, ret, type);
+               fd_store(index, ret, fd_rsocket, fd_ready);
+               return index;
+       } else if (fd_gets(socket) == fd_fork_listen) {
+               index = fd_open();
+               if (index < 0)
+                       return index;
+
+               ret = real.accept(fd, addr, addrlen);
+               if (ret < 0) {
+                       fd_close(index, &fd);
+                       return ret;
+               }
+
+               fd_store(index, ret, fd_normal, fd_fork_passive);
                return index;
        } else {
                return real.accept(fd, addr, addrlen);
@@ -453,35 +486,49 @@ int accept(int socket, struct sockaddr *addr, socklen_t *addrlen)
  * We can't fork RDMA connections and pass them from the parent to the child
  * process.  Instead, we need to establish the RDMA connection after calling
  * fork.  To do this, we delay establishing the RDMA connection until we try
- * to send/receive on the server side.  On the client side, we don't expect
- * to fork, so we switch from a TCP connection to an rsocket when connecting.
+ * to send/receive on the server side.
  */
-static int fork_active(int socket, const struct sockaddr *addr, socklen_t addrlen)
+static void fork_active(int socket)
 {
-       int fd, ret;
+       struct sockaddr_storage addr;
+       int sfd, dfd, ret;
+       socklen_t len;
        uint32_t msg;
        long flags;
 
-       fd = fd_getd(socket);
-       flags = real.fcntl(fd, F_GETFL);
-       real.fcntl(fd, F_SETFL, 0);
-       ret = real.connect(fd, addr, addrlen);
+       sfd = fd_getd(socket);
+
+       len = sizeof addr;
+       ret = real.getpeername(sfd, (struct sockaddr *) &addr, &len);
        if (ret)
-               return ret;
+               goto err1;
 
-       ret = real.recv(fd, &msg, sizeof msg, MSG_PEEK);
-       if ((ret != sizeof msg) || msg) {
-               fd_store(socket, fd, fd_normal);
-               return 0;
-       }
+       dfd = rsocket(addr.ss_family, SOCK_STREAM, 0);
+       if (dfd < 0)
+               goto err1;
 
-       real.fcntl(fd, F_SETFL, flags);
-       ret = transpose_socket(socket, fd_rsocket);
-       if (ret < 0)
-               return ret;
+       flags = real.fcntl(sfd, F_GETFL);
+       real.fcntl(sfd, F_SETFL, 0);
+       ret = real.recv(sfd, &msg, sizeof msg, MSG_PEEK);
+       real.fcntl(sfd, F_SETFL, flags);
+       if ((ret != sizeof msg) || msg)
+               goto err2;
+
+       ret = rconnect(ret, (struct sockaddr *) &addr, len);
+       if (ret)
+               goto err2;
 
-       real.close(fd);
-       return rconnect(ret, addr, addrlen);
+       set_rsocket_options(dfd);
+       copysockopts(dfd, sfd, &rs, &real);
+       real.shutdown(sfd, SHUT_RDWR);
+       real.close(sfd);
+       fd_store(socket, dfd, fd_rsocket, fd_ready);
+       return;
+
+err2:
+       rclose(dfd);
+err1:
+       fd_store(socket, sfd, fd_normal, fd_ready);
 }
 
 static void fork_passive(int socket)
@@ -492,7 +539,7 @@ static void fork_passive(int socket)
        socklen_t len;
        uint32_t msg;
 
-       fd_get(socket, &sfd);
+       sfd = fd_getd(socket);
 
        len = sizeof sin6;
        ret = real.getsockname(sfd, (struct sockaddr *) &sin6, &len);
@@ -510,7 +557,7 @@ static void fork_passive(int socket)
 
        lfd = rsocket(sin6.sin6_family, SOCK_STREAM, 0);
        if (lfd < 0) {
-               ret  = lfd;
+               ret = lfd;
                goto sclose;
        }
 
@@ -537,14 +584,11 @@ static void fork_passive(int socket)
                goto lclose;
        }
 
-       param = 1;
-       rsetsockopt(dfd, IPPROTO_TCP, TCP_NODELAY, &param, sizeof param);
        set_rsocket_options(dfd);
-
        copysockopts(dfd, sfd, &rs, &real);
        real.shutdown(sfd, SHUT_RDWR);
        real.close(sfd);
-       fd_store(socket, dfd, fd_rsocket);
+       fd_store(socket, dfd, fd_rsocket, fd_ready);
 
 lclose:
        rclose(lfd);
@@ -553,7 +597,7 @@ sclose:
        sem_close(sem);
 out:
        if (ret)
-               fd_store(socket, sfd, fd_normal);
+               fd_store(socket, sfd, fd_normal, fd_ready);
 }
 
 static inline enum fd_type fd_fork_get(int index, int *fd)
@@ -562,8 +606,10 @@ static inline enum fd_type fd_fork_get(int index, int *fd)
 
        fdi = idm_lookup(&idm, index);
        if (fdi) {
-               if (fdi->type == fd_fork)
+               if (fdi->type == fd_fork_passive)
                        fork_passive(index);
+               else if (fdi->type == fd_fork_active)
+                       fork_active(index);
                *fd = fdi->fd;
                return fdi->type;
 
@@ -577,10 +623,7 @@ int connect(int socket, const struct sockaddr *addr, socklen_t addrlen)
 {
        int fd, ret;
 
-       switch (fd_get(socket, &fd)) {
-       case fd_fork:
-               return fork_active(socket, addr, addrlen);
-       case fd_rsocket:
+       if (fd_get(socket, &fd) == fd_rsocket) {
                ret = rconnect(fd, addr, addrlen);
                if (!ret || errno == EINPROGRESS)
                        return ret;
@@ -591,9 +634,8 @@ int connect(int socket, const struct sockaddr *addr, socklen_t addrlen)
 
                rclose(fd);
                fd = ret;
-               break;
-       default:
-               break;
+       } else if (fd_gets(socket) == fd_fork) {
+               fd_store(socket, fd, fd_normal, fd_fork_active);
        }
 
        return real.connect(fd, addr, addrlen);