#include <string.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
+#include <unistd.h>
+#include <semaphore.h>
#include <rdma/rdma_cma.h>
#include <rdma/rdma_verbs.h>
static int sq_size;
static int rq_size;
static int sq_inline;
+static int fork_support; /* atoi() of the RDMAV_FORK_SAFE environment variable; when nonzero, socket() stores a real socket as fd_fork instead of an rsocket */
enum fd_type {
fd_normal,
- fd_rsocket
+ fd_rsocket,
+ fd_fork /* real socket whose switch to an rsocket is deferred: done by fork_active() on connect, or by fork_passive() on the first send/receive */
};
struct fd_info {
var = getenv("RS_INLINE");
if (var)
sq_inline = atoi(var);
+
+ var = getenv("RDMAV_FORK_SAFE");
+ if (var)
+ fork_support = atoi(var);
}
static void init_preload(void)
ret = rsocket(domain, type, protocol);
recursive = 0;
if (ret >= 0) {
- fd_store(index, ret, fd_rsocket);
- set_rsocket_options(ret);
+ if (fork_support) {
+ rclose(ret);
+ ret = real.socket(domain, type, protocol);
+ if (ret < 0)
+ return ret;
+ fd_store(index, ret, fd_fork);
+ } else {
+ fd_store(index, ret, fd_rsocket);
+ set_rsocket_options(ret);
+ }
return index;
}
fd_close(index, &ret);
int accept(int socket, struct sockaddr *addr, socklen_t *addrlen)
{
int fd, index, ret;
+ enum fd_type type; /* kind of the listening socket, as reported by fd_get() */
- if (fd_get(socket, &fd) == fd_rsocket) {
+ type = fd_get(socket, &fd);
+ if (type == fd_rsocket || type == fd_fork) { /* both kinds get an index from fd_open() for the new connection */
index = fd_open();
if (index < 0)
return index;
- ret = raccept(fd, addr, addrlen);
+ ret = (type == fd_rsocket) ? raccept(fd, addr, addrlen) :
+ real.accept(fd, addr, addrlen); /* an fd_fork listener accepts on the real socket */
if (ret < 0) {
fd_close(index, &fd);
return ret;
}
- fd_store(index, ret, fd_rsocket);
+ fd_store(index, ret, type); /* the accepted socket is stored with the listener's kind, so an fd_fork connection stays deferred until fd_fork_get() sees it */
return index;
} else {
return real.accept(fd, addr, addrlen);
}
}
+/*
+ * We can't fork RDMA connections and pass them from the parent to the child
+ * process.  Instead, we need to establish the RDMA connection after calling
+ * fork.  To do this, we delay establishing the RDMA connection until we try
+ * to send/receive on the server side.  On the client side, we don't expect
+ * to fork, so we switch from a TCP connection to an rsocket when connecting.
+ *
+ * fork_active() is the client side of that handshake for a socket stored as
+ * fd_fork:
+ *
+ *  1. Connect the real socket with its file status flags cleared.
+ *  2. Peek (MSG_PEEK) a 32-bit word from the peer.  fork_passive() writes a
+ *     zero word once its rsocket listener is ready.
+ *  3. If a full zero word was seen, call transpose_socket() to switch the
+ *     index to an rsocket, close the real socket and rconnect() to the same
+ *     address.  Otherwise re-store the index as fd_normal and keep using the
+ *     already connected real socket.
+ *
+ * The caller's file status flags are put back on every path, including a
+ * failed connect() and the fd_normal fallback, so the application never gets
+ * its socket back with its flags cleared.
+ *
+ * Returns 0 when falling back to the real socket, the real.connect() result
+ * if that call fails, a negative transpose_socket() result, or otherwise the
+ * result of rconnect().
+ */
+static int fork_active(int socket, const struct sockaddr *addr, socklen_t addrlen)
+{
+	int fd, ret;
+	uint32_t msg;
+	long flags;
+
+	fd = fd_getd(socket);
+
+	/* Remember the caller's flags, then clear them for the handshake. */
+	flags = real.fcntl(fd, F_GETFL);
+	real.fcntl(fd, F_SETFL, 0);
+
+	ret = real.connect(fd, addr, addrlen);
+	if (ret) {
+		/*
+		 * NOTE(review): errno from connect() is only disturbed here
+		 * if this fcntl() itself fails; confirm that is acceptable.
+		 */
+		real.fcntl(fd, F_SETFL, flags);
+		return ret;
+	}
+
+	ret = real.recv(fd, &msg, sizeof msg, MSG_PEEK);
+
+	/* The handshake is over: restore the flags before either outcome. */
+	real.fcntl(fd, F_SETFL, flags);
+
+	if ((ret != sizeof msg) || msg) {
+		/* Peer did not announce an rsocket listener: stay on the real socket. */
+		fd_store(socket, fd, fd_normal);
+		return 0;
+	}
+
+	/* The value returned here is used as the rsocket descriptor below. */
+	ret = transpose_socket(socket, fd_rsocket);
+	if (ret < 0)
+		return ret;
+
+	real.close(fd);
+	return rconnect(ret, addr, addrlen);
+}
+
+/*
+ * fork_passive - server side of the deferred rsocket handshake.
+ *
+ * Called for an index stored as fd_fork the first time it is used for
+ * send/receive (see fd_fork_get()).  Steps:
+ *
+ *  1. Read the local address of the real socket and clear everything but the
+ *     family and port, so the rsocket listener binds to the same port.
+ *  2. Take the named semaphore "/rsocket_fork" (created with a count of 1 if
+ *     it does not exist).  It is held from before rbind() until the listener
+ *     has been closed again.
+ *  3. Bind and listen on a temporary rsocket, then write a zero 32-bit word
+ *     on the real socket.  fork_active() on the peer peeks for that word
+ *     before calling rconnect().
+ *  4. Accept the rsocket connection, apply options, shut down and close the
+ *     real socket, and re-store the index as fd_rsocket.
+ *
+ * On any failure the index is re-stored as fd_normal with the original real
+ * socket, so the application keeps using it and the handshake is never
+ * retried.
+ */
+static void fork_passive(int socket)
+{
+	struct sockaddr_in6 sin6;
+	sem_t *sem;
+	int lfd, sfd, dfd, ret, param;
+	socklen_t len;
+	ssize_t sent;
+	uint32_t msg;
+
+	fd_get(socket, &sfd);
+
+	len = sizeof sin6;
+	ret = real.getsockname(sfd, (struct sockaddr *) &sin6, &len);
+	if (ret)
+		goto out;
+	/* Keep family and port only; the address part is zeroed. */
+	sin6.sin6_flowinfo = sin6.sin6_scope_id = 0;
+	memset(&sin6.sin6_addr, 0, sizeof sin6.sin6_addr);
+
+	sem = sem_open("/rsocket_fork", O_CREAT | O_RDWR,
+		       S_IRWXU | S_IRWXG, 1);
+	if (sem == SEM_FAILED) {
+		ret = -1;
+		goto out;
+	}
+
+	lfd = rsocket(sin6.sin6_family, SOCK_STREAM, 0);
+	if (lfd < 0) {
+		ret = lfd;
+		goto sclose;	/* semaphore not taken yet: no sem_post() */
+	}
+
+	param = 1;
+	rsetsockopt(lfd, SOL_SOCKET, SO_REUSEADDR, &param, sizeof param);
+
+	sem_wait(sem);
+	ret = rbind(lfd, (struct sockaddr *) &sin6, sizeof sin6);
+	if (ret)
+		goto lclose;
+
+	ret = rlisten(lfd, 1);
+	if (ret)
+		goto lclose;
+
+	/* Tell the peer the rsocket listener is ready. */
+	msg = 0;
+	sent = real.write(sfd, &msg, sizeof msg);
+	if (sent != (ssize_t) sizeof msg) {
+		/*
+		 * Record the failure; with ret left at 0 the index would stay
+		 * fd_fork and every later send/receive would retry all of this.
+		 */
+		ret = -1;
+		goto lclose;
+	}
+
+	dfd = raccept(lfd, NULL, NULL);
+	if (dfd < 0) {
+		ret = dfd;
+		goto lclose;
+	}
+
+	param = 1;
+	rsetsockopt(dfd, IPPROTO_TCP, TCP_NODELAY, &param, sizeof param);
+	set_rsocket_options(dfd);
+
+	/* NOTE(review): presumably carries the real socket's options over to the rsocket - confirm. */
+	copysockopts(dfd, sfd, &rs, &real);
+	real.shutdown(sfd, SHUT_RDWR);
+	real.close(sfd);
+	fd_store(socket, dfd, fd_rsocket);
+
+	/* Success falls through the cleanup labels with ret == 0. */
+lclose:
+	rclose(lfd);
+	sem_post(sem);
+sclose:
+	sem_close(sem);
+out:
+	if (ret)
+		fd_store(socket, sfd, fd_normal);
+}
+
+/*
+ * Look up the descriptor and type stored for @index, first completing the
+ * deferred rsocket handshake if the entry is still marked fd_fork.
+ *
+ * An index with no entry is reported as fd_normal, with @index itself
+ * written to *fd.  Otherwise *fd and the return value are read from the
+ * entry after fork_passive() has run, so they reflect whatever it stored.
+ */
+static inline enum fd_type fd_fork_get(int index, int *fd)
+{
+	struct fd_info *info = idm_lookup(&idm, index);
+
+	if (!info) {
+		*fd = index;
+		return fd_normal;
+	}
+
+	if (info->type == fd_fork)
+		fork_passive(index);
+
+	*fd = info->fd;
+	return info->type;
+}
+
int connect(int socket, const struct sockaddr *addr, socklen_t addrlen)
{
struct sockaddr_in *sin;
int fd, ret;
- if (fd_get(socket, &fd) == fd_rsocket) {
+ switch (fd_get(socket, &fd)) {
+ case fd_fork:
+ return fork_active(socket, addr, addrlen);
+ case fd_rsocket:
sin = (struct sockaddr_in *) addr;
if (ntohs(sin->sin_port) > 1024) {
ret = rconnect(fd, addr, addrlen);
rclose(fd);
fd = ret;
+ break;
+ default:
+ break;
}
return real.connect(fd, addr, addrlen);
ssize_t recv(int socket, void *buf, size_t len, int flags)
{
int fd;
- return (fd_get(socket, &fd) == fd_rsocket) ?
+ return (fd_fork_get(socket, &fd) == fd_rsocket) ? /* fd_fork_get() runs fork_passive() on an fd_fork socket before its type is tested */
rrecv(fd, buf, len, flags) : real.recv(fd, buf, len, flags);
}
struct sockaddr *src_addr, socklen_t *addrlen)
{
int fd;
- return (fd_get(socket, &fd) == fd_rsocket) ?
+ return (fd_fork_get(socket, &fd) == fd_rsocket) ?
rrecvfrom(fd, buf, len, flags, src_addr, addrlen) :
real.recvfrom(fd, buf, len, flags, src_addr, addrlen);
}
ssize_t recvmsg(int socket, struct msghdr *msg, int flags)
{
int fd;
- return (fd_get(socket, &fd) == fd_rsocket) ?
+ return (fd_fork_get(socket, &fd) == fd_rsocket) ? /* an fd_fork socket goes through fork_passive() inside fd_fork_get() before this comparison */
rrecvmsg(fd, msg, flags) : real.recvmsg(fd, msg, flags);
}
{
int fd;
init_preload();
- return (fd_get(socket, &fd) == fd_rsocket) ?
+ return (fd_fork_get(socket, &fd) == fd_rsocket) ?
rread(fd, buf, count) : real.read(fd, buf, count);
}
{
int fd;
init_preload();
- return (fd_get(socket, &fd) == fd_rsocket) ?
+ return (fd_fork_get(socket, &fd) == fd_rsocket) ?
rreadv(fd, iov, iovcnt) : real.readv(fd, iov, iovcnt);
}
ssize_t send(int socket, const void *buf, size_t len, int flags)
{
int fd;
- return (fd_get(socket, &fd) == fd_rsocket) ?
+ return (fd_fork_get(socket, &fd) == fd_rsocket) ? /* fd_fork_get() runs fork_passive() on an fd_fork socket before its type is tested */
rsend(fd, buf, len, flags) : real.send(fd, buf, len, flags);
}
const struct sockaddr *dest_addr, socklen_t addrlen)
{
int fd;
- return (fd_get(socket, &fd) == fd_rsocket) ?
+ return (fd_fork_get(socket, &fd) == fd_rsocket) ?
rsendto(fd, buf, len, flags, dest_addr, addrlen) :
real.sendto(fd, buf, len, flags, dest_addr, addrlen);
}
ssize_t sendmsg(int socket, const struct msghdr *msg, int flags)
{
int fd;
- return (fd_get(socket, &fd) == fd_rsocket) ?
+ return (fd_fork_get(socket, &fd) == fd_rsocket) ? /* an fd_fork socket goes through fork_passive() inside fd_fork_get() before this comparison */
rsendmsg(fd, msg, flags) : real.sendmsg(fd, msg, flags);
}
{
int fd;
init_preload();
- return (fd_get(socket, &fd) == fd_rsocket) ?
+ return (fd_fork_get(socket, &fd) == fd_rsocket) ?
rwrite(fd, buf, count) : real.write(fd, buf, count);
}
{
int fd;
init_preload();
- return (fd_get(socket, &fd) == fd_rsocket) ?
+ return (fd_fork_get(socket, &fd) == fd_rsocket) ?
rwritev(fd, iov, iovcnt) : real.writev(fd, iov, iovcnt);
}