#include <string.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
+#include <unistd.h>
+#include <semaphore.h>
#include <rdma/rdma_cma.h>
#include <rdma/rdma_verbs.h>
int (*getsockopt)(int socket, int level, int optname,
void *optval, socklen_t *optlen);
int (*fcntl)(int socket, int cmd, ... /* arg */);
+ pid_t (*fork)(void);
};
static struct socket_calls real;
static int rq_size;
static int sq_inline;
static int fork_support;
+static int last_accept = -1;
enum fd_type {
fd_normal,
real.setsockopt = dlsym(RTLD_NEXT, "setsockopt");
real.getsockopt = dlsym(RTLD_NEXT, "getsockopt");
real.fcntl = dlsym(RTLD_NEXT, "fcntl");
+ real.fork = dlsym(RTLD_NEXT, "fork");
rs.socket = dlsym(RTLD_DEFAULT, "rsocket");
rs.bind = dlsym(RTLD_DEFAULT, "rbind");
* the same settings and bindings as the current socket. We currently only
* handle setting a few of the more common values.
*/
-static int transpose_socket(int index, int *fd, enum fd_type new_type)
+static int transpose_socket(int socket, enum fd_type new_type)
{
- socklen_t len = 0;
- int new_fd, param, ret;
+ int fd, new_fd, param, ret;
struct socket_calls *new, *old;
+ socklen_t len = 0;
+ fd = fd_getd(socket);
if (new_type == fd_rsocket) {
new = &rs;
old = ℜ
old = &rs;
}
- ret = old->getsockname(*fd, NULL, &len);
+ ret = old->getsockname(fd, NULL, &len);
if (ret)
return ret;
if (new_fd < 0)
return new_fd;
- ret = old->fcntl(*fd, F_GETFL);
+ ret = old->fcntl(fd, F_GETFL);
if (ret > 0)
ret = new->fcntl(new_fd, F_SETFL, ret);
if (ret)
goto err;
len = sizeof param;
- ret = old->getsockopt(*fd, SOL_SOCKET, SO_REUSEADDR, ¶m, &len);
+ ret = old->getsockopt(fd, SOL_SOCKET, SO_REUSEADDR, ¶m, &len);
if (param && !ret)
ret = new->setsockopt(new_fd, SOL_SOCKET, SO_REUSEADDR, ¶m, len);
if (ret)
goto err;
len = sizeof param;
- ret = old->getsockopt(*fd, IPPROTO_TCP, TCP_NODELAY, ¶m, &len);
+ ret = old->getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, ¶m, &len);
if (param && !ret)
ret = new->setsockopt(new_fd, IPPROTO_TCP, TCP_NODELAY, ¶m, len);
if (ret)
goto err;
- old->close(*fd);
fd_store(socket, new_fd, new_type);
- *fd = new_fd;
- return 0;
+ return new_fd;
err:
new->close(new_fd);
if (!sin->sin_port || ntohs(sin->sin_port) > 1024)
return rbind(fd, addr, addrlen);
- ret = transpose_socket(socket, &fd, fd_normal);
- if (ret)
+ ret = transpose_socket(socket, fd_normal);
+ if (ret < 0)
return ret;
+ rclose(fd);
+ fd = ret;
}
return real.bind(fd, addr, addrlen);
}
fd_store(index, ret, fd_rsocket);
+ last_accept = index;
return index;
} else {
+ last_accept = -1;
return real.accept(fd, addr, addrlen);
}
}
int fd, ret;
uint32_t msg;
- fd_get(socket, &fd);
+ fd = fd_getd(socket);
ret = real.connect(fd, addr, addrlen);
if (ret)
return ret;
ret = real.read(fd, &msg, sizeof msg);
- if (ret != sizeof msg)
- return ret;
+ if ((ret != sizeof msg) || msg) {
+ fd_store(socket, fd, fd_normal);
+ return 0;
+ }
- ret = transpose_socket(socket, &fd, fd_rsocket);
- if (ret)
+ ret = transpose_socket(socket, fd_rsocket);
+ if (ret < 0)
return ret;
- return rconnect(fd, addr, addrlen);
+ real.close(fd);
+ return rconnect(ret, addr, addrlen);
}
int connect(int socket, const struct sockaddr *addr, socklen_t addrlen)
return ret;
}
- ret = transpose_socket(socket, &fd, fd_normal);
- if (ret)
+ ret = transpose_socket(socket, fd_normal);
+ if (ret < 0)
return ret;
+ rclose(fd);
+ fd = ret;
+ break;
}
return real.connect(fd, addr, addrlen);
va_end(args);
return ret;
}
+
+/*
+ * We can't fork RDMA connections and pass them from the parent to the child
+ * process. Intercept the fork call, and if we're the child establish the
+ * RDMA connection after calling fork. The assumption is that the last
+ * connection accepted by the server will be processed by the child after the
+ * fork call.
+ *
+ * It would be better to establishing the RDMA connection once the child
+ * process tries to use the connection after the fork call (i.e. in a read
+ * or write call), rather than making the previous assumption.
+ */
+pid_t fork(void)
+{
+ struct sockaddr_storage sa;
+ pid_t pid;
+ sem_t *sem;
+ int fd, lfd, newfd, ret, len, param;
+ uint32_t msg;
+
+ init_preload();
+ pid = real.fork();
+ if (pid || !fork_support || (last_accept < 0) ||
+ (fd_get(last_accept, &fd) != fd_fork))
+ goto out;
+
+ sem = sem_open("/rsocket_fork", O_CREAT, 0644, 1);
+ if (sem == SEM_FAILED)
+ goto out;
+
+ lfd = transpose_socket(last_accept, fd_rsocket);
+ if (lfd < 0)
+ goto out;
+
+ param = 1;
+ len = sizeof param;
+ rsetsockopt(lfd, SOL_SOCKET, SO_REUSEADDR, ¶m, len);
+
+ len = sizeof sa;
+ ret = real.getsockname(fd, &sa, &len);
+ if (ret)
+ goto out;
+
+ sem_wait(sem);
+ ret = rbind()
+
+ real.close(fd);
+
+ sem_post(sem);
+ sem_close(sem);
+out:
+ last_accept = -1;
+ return pid;
+}