From: Sean Hefty Date: Tue, 24 Jul 2012 18:40:10 +0000 (-0700) Subject: librspreload: Support server apps that call fork() X-Git-Url: https://openfabrics.org/gitweb/?a=commitdiff_plain;h=f34057dc48aa90a8a996efd71b2bd434991a48d9;p=~shefty%2Flibrdmacm.git librspreload: Support server apps that call fork() Provide limited support for applications that call fork() after accepting a connection. Fork support is indicated by setting the environment variable RDMAV_FORK_SAFE. Signed-off-by: Sean Hefty --- diff --git a/src/preload.c b/src/preload.c index d2058e23..79340c6f 100644 --- a/src/preload.c +++ b/src/preload.c @@ -46,6 +46,8 @@ #include #include #include +#include +#include #include #include @@ -81,6 +83,7 @@ struct socket_calls { int (*getsockopt)(int socket, int level, int optname, void *optval, socklen_t *optlen); int (*fcntl)(int socket, int cmd, ... /* arg */); + pid_t (*fork)(void); }; static struct socket_calls real; @@ -92,10 +95,12 @@ static pthread_mutex_t mut = PTHREAD_MUTEX_INITIALIZER; static int sq_size; static int rq_size; static int sq_inline; +static int fork_support; enum fd_type { fd_normal, - fd_rsocket + fd_rsocket, + fd_fork }; struct fd_info { @@ -207,6 +212,10 @@ void getenv_options(void) var = getenv("RS_INLINE"); if (var) sq_inline = atoi(var); + + var = getenv("RDMAV_FORK_SAFE"); + if (var) + fork_support = atoi(var); } static void init_preload(void) @@ -244,6 +253,7 @@ static void init_preload(void) real.setsockopt = dlsym(RTLD_NEXT, "setsockopt"); real.getsockopt = dlsym(RTLD_NEXT, "getsockopt"); real.fcntl = dlsym(RTLD_NEXT, "fcntl"); + real.fork = dlsym(RTLD_NEXT, "fork"); rs.socket = dlsym(RTLD_DEFAULT, "rsocket"); rs.bind = dlsym(RTLD_DEFAULT, "rbind"); @@ -378,8 +388,16 @@ int socket(int domain, int type, int protocol) ret = rsocket(domain, type, protocol); recursive = 0; if (ret >= 0) { - fd_store(index, ret, fd_rsocket); - set_rsocket_options(ret); + if (fork_support) { + rclose(ret); + ret = real.socket(domain, type, protocol); + 
if (ret < 0) + return ret; + fd_store(index, ret, fd_fork); + } else { + fd_store(index, ret, fd_rsocket); + set_rsocket_options(ret); + } return index; } fd_close(index, &ret); @@ -418,31 +436,161 @@ int listen(int socket, int backlog) int accept(int socket, struct sockaddr *addr, socklen_t *addrlen) { int fd, index, ret; + enum fd_type type; - if (fd_get(socket, &fd) == fd_rsocket) { + type = fd_get(socket, &fd); + if (type == fd_rsocket || type == fd_fork) { index = fd_open(); if (index < 0) return index; - ret = raccept(fd, addr, addrlen); + ret = (type == fd_rsocket) ? raccept(fd, addr, addrlen) : + real.accept(fd, addr, addrlen); if (ret < 0) { fd_close(index, &fd); return ret; } - fd_store(index, ret, fd_rsocket); + fd_store(index, ret, type); return index; } else { return real.accept(fd, addr, addrlen); } } +/* + * We can't fork RDMA connections and pass them from the parent to the child + * process. Instead, we need to establish the RDMA connection after calling + * fork. To do this, we delay establishing the RDMA connection until we try + * to send/receive on the server side. On the client side, we don't expect + * to fork, so we switch from a TCP connection to an rsocket when connecting. 
+ */ +static int fork_active(int socket, const struct sockaddr *addr, socklen_t addrlen) +{ + int fd, ret; + uint32_t msg; + long flags; + + fd = fd_getd(socket); + flags = real.fcntl(fd, F_GETFL); + real.fcntl(fd, F_SETFL, 0); + ret = real.connect(fd, addr, addrlen); + if (ret) + return ret; + + ret = real.recv(fd, &msg, sizeof msg, MSG_PEEK); + if ((ret != sizeof msg) || msg) { + fd_store(socket, fd, fd_normal); + return 0; + } + + real.fcntl(fd, F_SETFL, flags); + ret = transpose_socket(socket, fd_rsocket); + if (ret < 0) + return ret; + + real.close(fd); + return rconnect(ret, addr, addrlen); +} + +static void fork_passive(int socket) +{ + struct sockaddr_in6 sin6; + sem_t *sem; + int lfd, sfd, dfd, ret, param; + socklen_t len; + uint32_t msg; + + fd_get(socket, &sfd); + + len = sizeof sin6; + ret = real.getsockname(sfd, (struct sockaddr *) &sin6, &len); + if (ret) + goto out; + sin6.sin6_flowinfo = sin6.sin6_scope_id = 0; + memset(&sin6.sin6_addr, 0, sizeof sin6.sin6_addr); + + sem = sem_open("/rsocket_fork", O_CREAT | O_RDWR, + S_IRWXU | S_IRWXG, 1); + if (sem == SEM_FAILED) { + ret = -1; + goto out; + } + + lfd = rsocket(sin6.sin6_family, SOCK_STREAM, 0); + if (lfd < 0) { + ret = lfd; + goto sclose; + } + + param = 1; + rsetsockopt(lfd, SOL_SOCKET, SO_REUSEADDR, &param, sizeof param); + + sem_wait(sem); + ret = rbind(lfd, (struct sockaddr *) &sin6, sizeof sin6); + if (ret) + goto lclose; + + ret = rlisten(lfd, 1); + if (ret) + goto lclose; + + msg = 0; + len = real.write(sfd, &msg, sizeof msg); + if (len != sizeof msg) + goto lclose; + + dfd = raccept(lfd, NULL, NULL); + if (dfd < 0) { + ret = dfd; + goto lclose; + } + + param = 1; + rsetsockopt(dfd, IPPROTO_TCP, TCP_NODELAY, &param, sizeof param); + set_rsocket_options(dfd); + + copysockopts(dfd, sfd, &rs, &real); + real.shutdown(sfd, SHUT_RDWR); + real.close(sfd); + fd_store(socket, dfd, fd_rsocket); + +lclose: + rclose(lfd); + sem_post(sem); +sclose: + sem_close(sem); +out: + if (ret) + fd_store(socket, sfd, 
fd_normal); +} + +static inline enum fd_type fd_fork_get(int index, int *fd) +{ + struct fd_info *fdi; + + fdi = idm_lookup(&idm, index); + if (fdi) { + if (fdi->type == fd_fork) + fork_passive(index); + *fd = fdi->fd; + return fdi->type; + + } else { + *fd = index; + return fd_normal; + } +} + int connect(int socket, const struct sockaddr *addr, socklen_t addrlen) { struct sockaddr_in *sin; int fd, ret; - if (fd_get(socket, &fd) == fd_rsocket) { + switch (fd_get(socket, &fd)) { + case fd_fork: + return fork_active(socket, addr, addrlen); + case fd_rsocket: sin = (struct sockaddr_in *) addr; if (ntohs(sin->sin_port) > 1024) { ret = rconnect(fd, addr, addrlen); @@ -456,6 +604,9 @@ int connect(int socket, const struct sockaddr *addr, socklen_t addrlen) rclose(fd); fd = ret; + break; + default: + break; } return real.connect(fd, addr, addrlen); @@ -464,7 +615,7 @@ int connect(int socket, const struct sockaddr *addr, socklen_t addrlen) ssize_t recv(int socket, void *buf, size_t len, int flags) { int fd; - return (fd_get(socket, &fd) == fd_rsocket) ? + return (fd_fork_get(socket, &fd) == fd_rsocket) ? rrecv(fd, buf, len, flags) : real.recv(fd, buf, len, flags); } @@ -472,7 +623,7 @@ ssize_t recvfrom(int socket, void *buf, size_t len, int flags, struct sockaddr *src_addr, socklen_t *addrlen) { int fd; - return (fd_get(socket, &fd) == fd_rsocket) ? + return (fd_fork_get(socket, &fd) == fd_rsocket) ? rrecvfrom(fd, buf, len, flags, src_addr, addrlen) : real.recvfrom(fd, buf, len, flags, src_addr, addrlen); } @@ -480,7 +631,7 @@ ssize_t recvfrom(int socket, void *buf, size_t len, int flags, ssize_t recvmsg(int socket, struct msghdr *msg, int flags) { int fd; - return (fd_get(socket, &fd) == fd_rsocket) ? + return (fd_fork_get(socket, &fd) == fd_rsocket) ? rrecvmsg(fd, msg, flags) : real.recvmsg(fd, msg, flags); } @@ -488,7 +639,7 @@ ssize_t read(int socket, void *buf, size_t count) { int fd; init_preload(); - return (fd_get(socket, &fd) == fd_rsocket) ? 
+ return (fd_fork_get(socket, &fd) == fd_rsocket) ? rread(fd, buf, count) : real.read(fd, buf, count); } @@ -496,14 +647,14 @@ ssize_t readv(int socket, const struct iovec *iov, int iovcnt) { int fd; init_preload(); - return (fd_get(socket, &fd) == fd_rsocket) ? + return (fd_fork_get(socket, &fd) == fd_rsocket) ? rreadv(fd, iov, iovcnt) : real.readv(fd, iov, iovcnt); } ssize_t send(int socket, const void *buf, size_t len, int flags) { int fd; - return (fd_get(socket, &fd) == fd_rsocket) ? + return (fd_fork_get(socket, &fd) == fd_rsocket) ? rsend(fd, buf, len, flags) : real.send(fd, buf, len, flags); } @@ -511,7 +662,7 @@ ssize_t sendto(int socket, const void *buf, size_t len, int flags, const struct sockaddr *dest_addr, socklen_t addrlen) { int fd; - return (fd_get(socket, &fd) == fd_rsocket) ? + return (fd_fork_get(socket, &fd) == fd_rsocket) ? rsendto(fd, buf, len, flags, dest_addr, addrlen) : real.sendto(fd, buf, len, flags, dest_addr, addrlen); } @@ -519,7 +670,7 @@ ssize_t sendto(int socket, const void *buf, size_t len, int flags, ssize_t sendmsg(int socket, const struct msghdr *msg, int flags) { int fd; - return (fd_get(socket, &fd) == fd_rsocket) ? + return (fd_fork_get(socket, &fd) == fd_rsocket) ? rsendmsg(fd, msg, flags) : real.sendmsg(fd, msg, flags); } @@ -527,7 +678,7 @@ ssize_t write(int socket, const void *buf, size_t count) { int fd; init_preload(); - return (fd_get(socket, &fd) == fd_rsocket) ? + return (fd_fork_get(socket, &fd) == fd_rsocket) ? rwrite(fd, buf, count) : real.write(fd, buf, count); } @@ -535,7 +686,7 @@ ssize_t writev(int socket, const struct iovec *iov, int iovcnt) { int fd; init_preload(); - return (fd_get(socket, &fd) == fd_rsocket) ? + return (fd_fork_get(socket, &fd) == fd_rsocket) ? rwritev(fd, iov, iovcnt) : real.writev(fd, iov, iovcnt); }