enum fd_type {
fd_normal,
- fd_rsocket,
- fd_fork
+ fd_rsocket
+};
+
+enum fd_fork_state {
+ fd_ready,
+ fd_fork,
+ fd_fork_listen,
+ fd_fork_active,
+ fd_fork_passive
};
struct fd_info {
enum fd_type type;
+ enum fd_fork_state state;
int fd;
int dupfd;
atomic_t refcnt;
return ret;
}
-static void fd_store(int index, int fd, enum fd_type type)
+static void fd_store(int index, int fd, enum fd_type type, enum fd_fork_state state)
{
struct fd_info *fdi;
fdi = idm_at(&idm, index);
fdi->fd = fd;
fdi->type = type;
+ fdi->state = state;
}
static inline enum fd_type fd_get(int index, int *fd)
return fdi ? fdi->fd : index;
}
+static inline enum fd_state fd_gets(int index)
+{
+ struct fd_info *fdi;
+
+ fdi = idm_lookup(&idm, index);
+ return fdi ? fdi->state : fd_ready;
+}
+
static inline enum fd_type fd_gett(int index)
{
struct fd_info *fdi;
if (ret)
goto err;
- fd_store(socket, dfd, new_type);
+ fd_store(socket, dfd, new_type, fd_ready);
return dfd;
err:
ret = real.socket(domain, type, protocol);
if (ret < 0)
return ret;
- fd_store(index, ret, fd_fork);
+ fd_store(index, ret, fd_normal, fd_fork);
} else {
- fd_store(index, ret, fd_rsocket);
+ fd_store(index, ret, fd_rsocket, fd_ready);
set_rsocket_options(ret);
}
return index;
int listen(int socket, int backlog)
{
- int fd;
- return (fd_get(socket, &fd) == fd_rsocket) ?
- rlisten(fd, backlog) : real.listen(fd, backlog);
+ int fd, ret;
+ if (fd_get(socket, &fd) == fd_rsocket) {
+ ret = rlisten(fd, backlog);
+ } else {
+ ret = real.listen(fd, backlog);
+ if (!ret && fd_gets(socket) == fd_fork)
+ fd_store(socket, fd, fd_normal, fd_fork_listen);
+ }
+ return ret;
}
int accept(int socket, struct sockaddr *addr, socklen_t *addrlen)
{
int fd, index, ret;
- enum fd_type type;
- type = fd_get(socket, &fd);
- if (type == fd_rsocket || type == fd_fork) {
+ if (fd_get(socket, &fd) == fd_rsocket) {
+ index = fd_open();
+ if (index < 0)
+ return index;
+
+ ret = raccept(fd, addr, addrlen);
+ if (ret < 0) {
+ fd_close(index, &fd);
+ return ret;
+ }
+
+ fd_store(index, ret, fd_rsocket, fd_ready);
+ return index;
+ } else if (fd_gets(socket) == fd_fork_listen) {
index = fd_open();
if (index < 0)
return index;
- ret = (type == fd_rsocket) ? raccept(fd, addr, addrlen) :
- real.accept(fd, addr, addrlen);
+ ret = real.accept(fd, addr, addrlen);
if (ret < 0) {
fd_close(index, &fd);
return ret;
}
- fd_store(index, ret, type);
+ fd_store(index, ret, fd_normal, fd_fork_passive);
return index;
} else {
return real.accept(fd, addr, addrlen);
* We can't fork RDMA connections and pass them from the parent to the child
* process. Instead, we need to establish the RDMA connection after calling
* fork. To do this, we delay establishing the RDMA connection until we try
- * to send/receive on the server side. On the client side, we don't expect
- * to fork, so we switch from a TCP connection to an rsocket when connecting.
+ * to send/receive on the server side.
*/
-static int fork_active(int socket, const struct sockaddr *addr, socklen_t addrlen)
+static void fork_active(int socket)
{
- int fd, ret;
+ struct sockaddr_storage addr;
+ int sfd, dfd, ret;
+ socklen_t len;
uint32_t msg;
long flags;
- fd = fd_getd(socket);
- flags = real.fcntl(fd, F_GETFL);
- real.fcntl(fd, F_SETFL, 0);
- ret = real.connect(fd, addr, addrlen);
+ sfd = fd_getd(socket);
+
+ len = sizeof addr;
+ ret = real.getpeername(sfd, (struct sockaddr *) &addr, &len);
if (ret)
- return ret;
+ goto err1;
- ret = real.recv(fd, &msg, sizeof msg, MSG_PEEK);
- if ((ret != sizeof msg) || msg) {
- fd_store(socket, fd, fd_normal);
- return 0;
- }
+ dfd = rsocket(addr.ss_family, SOCK_STREAM, 0);
+ if (dfd < 0)
+ goto err1;
- real.fcntl(fd, F_SETFL, flags);
- ret = transpose_socket(socket, fd_rsocket);
- if (ret < 0)
- return ret;
+ flags = real.fcntl(sfd, F_GETFL);
+ real.fcntl(sfd, F_SETFL, 0);
+ ret = real.recv(sfd, &msg, sizeof msg, MSG_PEEK);
+ real.fcntl(sfd, F_SETFL, flags);
+ if ((ret != sizeof msg) || msg)
+ goto err2;
- real.close(fd);
- return rconnect(ret, addr, addrlen);
+ ret = rconnect(ret, &sin6, len);
+ if (ret)
+ goto err2;
+
+ set_rsocket_options(dfd);
+ copysockopts(dfd, sfd, &rs, &real);
+ real.shutdown(sfd, SHUT_RDWR);
+ real.close(sfd);
+ fd_store(socket, dfd, fd_rsocket, fd_ready);
+ return;
+
+err2:
+ rclose(dfd);
+err1:
+ fd_store(socket, sfd, fd_normal, fd_ready);
}
static void fork_passive(int socket)
socklen_t len;
uint32_t msg;
- fd_get(socket, &sfd);
+ sfd = fd_getd(socket);
len = sizeof sin6;
ret = real.getsockname(sfd, (struct sockaddr *) &sin6, &len);
lfd = rsocket(sin6.sin6_family, SOCK_STREAM, 0);
if (lfd < 0) {
- ret = lfd;
+ ret = lfd;
goto sclose;
}
goto lclose;
msg = 0;
- len = real.write(sfd, &msg, sizeof msg);
+ len = real.send(sfd, &msg, sizeof msg, MSG_NODELAY);
if (len != sizeof msg)
goto lclose;
goto lclose;
}
- param = 1;
- rsetsockopt(dfd, IPPROTO_TCP, TCP_NODELAY, ¶m, sizeof param);
set_rsocket_options(dfd);
-
copysockopts(dfd, sfd, &rs, &real);
real.shutdown(sfd, SHUT_RDWR);
real.close(sfd);
- fd_store(socket, dfd, fd_rsocket);
+ fd_store(socket, dfd, fd_rsocket, fd_ready);
lclose:
rclose(lfd);
sem_close(sem);
out:
if (ret)
- fd_store(socket, sfd, fd_normal);
+ fd_store(socket, sfd, fd_normal, fd_ready);
}
static inline enum fd_type fd_fork_get(int index, int *fd)
fdi = idm_lookup(&idm, index);
if (fdi) {
- if (fdi->type == fd_fork)
+ if (fdi->type == fd_fork_passive)
fork_passive(index);
+ else if (fdi->type == fd_fork_active)
+ fork_active(index);
*fd = fdi->fd;
return fdi->type;
int connect(int socket, const struct sockaddr *addr, socklen_t addrlen)
{
int fd, ret;
+ long flags;
- switch (fd_get(socket, &fd)) {
- case fd_fork:
- return fork_active(socket, addr, addrlen);
- case fd_rsocket:
+ if (fd_get(socket, &fd) == fd_rsocket) {
ret = rconnect(fd, addr, addrlen);
if (!ret || errno == EINPROGRESS)
return ret;
rclose(fd);
fd = ret;
- break;
- default:
- break;
+ } else if (fd_gets(socket) == fd_fork) {
+ fd_store(socket, fd, fd_normal, fd_fork_active);
}
return real.connect(fd, addr, addrlen);