enum fd_type {
fd_normal,
- fd_rsocket,
- fd_fork
+ fd_rsocket
+};
+
+enum fd_fork_state { /* fork-handling progress, kept per descriptor in fd_info.state */
+ fd_ready, /* default; fd_gets() also reports this when no fd_info exists */
+ fd_fork, /* stored with fd_normal for a descriptor obtained from real.socket() */
+ fd_fork_listen, /* listen(): real.listen() succeeded on an fd_fork socket */
+ fd_fork_active, /* connect(): non-blocking fd_fork socket, fork_active() not yet run */
+ fd_fork_passive /* accept(): descriptor returned by real.accept() on an fd_fork_listen socket */
};
struct fd_info {
enum fd_type type;
+ enum fd_fork_state state;
int fd;
int dupfd;
atomic_t refcnt;
return ret;
}
-static void fd_store(int index, int fd, enum fd_type type)
+static void fd_store(int index, int fd, enum fd_type type, enum fd_fork_state state) /* overwrite fd, type and state of the fd_info at index */
{
struct fd_info *fdi;
fdi = idm_at(&idm, index);
fdi->fd = fd;
fdi->type = type;
+ fdi->state = state; /* fork-handling state; read back through fd_gets() */
}
static inline enum fd_type fd_get(int index, int *fd)
return fdi ? fdi->fd : index;
}
+/*
+ * Return the fork-handling state recorded for a socket index.
+ *
+ * Looks the index up in the global index map; when no fd_info entry exists
+ * the descriptor is reported as fd_ready.
+ *
+ * The return type must be enum fd_fork_state: that is the type of
+ * fd_info.state and of fd_ready.
+ */
+static inline enum fd_fork_state fd_gets(int index)
+{
+	struct fd_info *fdi;
+
+	fdi = idm_lookup(&idm, index);
+	return fdi ? fdi->state : fd_ready;
+}
+
static inline enum fd_type fd_gett(int index)
{
struct fd_info *fdi;
if (ret)
goto err;
- fd_store(socket, dfd, new_type);
+ fd_store(socket, dfd, new_type, fd_ready);
return dfd;
err:
ret = real.socket(domain, type, protocol);
if (ret < 0)
return ret;
- fd_store(index, ret, fd_fork);
+ fd_store(index, ret, fd_normal, fd_fork);
} else {
- fd_store(index, ret, fd_rsocket);
+ fd_store(index, ret, fd_rsocket, fd_ready);
set_rsocket_options(ret);
}
return index;
int listen(int socket, int backlog)
{
- int fd;
- return (fd_get(socket, &fd) == fd_rsocket) ?
- rlisten(fd, backlog) : real.listen(fd, backlog);
+ int fd, ret;
+ if (fd_get(socket, &fd) == fd_rsocket) {
+ ret = rlisten(fd, backlog);
+ } else {
+ ret = real.listen(fd, backlog);
+ if (!ret && fd_gets(socket) == fd_fork) /* only once the real listen has succeeded */
+ fd_store(socket, fd, fd_normal, fd_fork_listen); /* accept() checks for fd_fork_listen on this socket */
+ }
+ return ret;
}
int accept(int socket, struct sockaddr *addr, socklen_t *addrlen)
{
int fd, index, ret;
- enum fd_type type;
- type = fd_get(socket, &fd);
- if (type == fd_rsocket || type == fd_fork) {
+ if (fd_get(socket, &fd) == fd_rsocket) {
+ index = fd_open();
+ if (index < 0)
+ return index;
+
+ ret = raccept(fd, addr, addrlen);
+ if (ret < 0) {
+ fd_close(index, &fd);
+ return ret;
+ }
+
+ fd_store(index, ret, fd_rsocket, fd_ready);
+ return index;
+ } else if (fd_gets(socket) == fd_fork_listen) {
index = fd_open();
if (index < 0)
return index;
- ret = (type == fd_rsocket) ? raccept(fd, addr, addrlen) :
- real.accept(fd, addr, addrlen);
+ ret = real.accept(fd, addr, addrlen);
if (ret < 0) {
fd_close(index, &fd);
return ret;
}
- fd_store(index, ret, type);
+ fd_store(index, ret, fd_normal, fd_fork_passive);
return index;
} else {
return real.accept(fd, addr, addrlen);
* We can't fork RDMA connections and pass them from the parent to the child
* process. Instead, we need to establish the RDMA connection after calling
* fork. To do this, we delay establishing the RDMA connection until we try
- * to send/receive on the server side. On the client side, we don't expect
- * to fork, so we switch from a TCP connection to an rsocket when connecting.
+ * to send/receive on the server side.
*/
static int fork_active(int socket, const struct sockaddr *addr, socklen_t addrlen)
{
fd = fd_getd(socket);
flags = real.fcntl(fd, F_GETFL);
real.fcntl(fd, F_SETFL, 0);
- ret = real.connect(fd, addr, addrlen);
- if (ret)
- return ret;
+
+ if (!(flags & O_NONBLOCK) && addr && addrlen) {
+ ret = real.connect(fd, addr, addrlen);
+ if (ret)
+ return ret;
+ }
ret = real.recv(fd, &msg, sizeof msg, MSG_PEEK);
if ((ret != sizeof msg) || msg) {
- fd_store(socket, fd, fd_normal);
+ fd_store(socket, fd, fd_normal, fd_ready);
return 0;
}
copysockopts(dfd, sfd, &rs, &real);
real.shutdown(sfd, SHUT_RDWR);
real.close(sfd);
- fd_store(socket, dfd, fd_rsocket);
+ fd_store(socket, dfd, fd_rsocket, fd_ready);
lclose:
rclose(lfd);
sem_close(sem);
out:
if (ret)
- fd_store(socket, sfd, fd_normal);
+ fd_store(socket, sfd, fd_normal, fd_ready);
}
static inline enum fd_type fd_fork_get(int index, int *fd)
fdi = idm_lookup(&idm, index);
if (fdi) {
- if (fdi->type == fd_fork)
+	/*
+	 * Pending fork handling is recorded in fdi->state, not fdi->type:
+	 * fd_fork_passive and fd_fork_active are enum fd_fork_state values,
+	 * so testing them against the fd_type field would never match and the
+	 * deferred fork_passive()/fork_active() call would be skipped.
+	 */
+	if (fdi->state == fd_fork_passive)
fork_passive(index);
+	else if (fdi->state == fd_fork_active)
+		fork_active(index, NULL, 0);
*fd = fdi->fd;
return fdi->type;
int connect(int socket, const struct sockaddr *addr, socklen_t addrlen)
{
int fd, ret;
+ long flags;
- switch (fd_get(socket, &fd)) {
- case fd_fork:
- return fork_active(socket, addr, addrlen);
- case fd_rsocket:
+ if (fd_get(socket, &fd) == fd_rsocket) {
ret = rconnect(fd, addr, addrlen);
if (!ret || errno == EINPROGRESS)
return ret;
rclose(fd);
fd = ret;
- break;
- default:
- break;
+ } else if (fd_gets(socket) == fd_fork) {
+ flags = real.fcntl(fd, F_GETFL);
+ if (!(flags & O_NONBLOCK))
+ return fork_active(socket, addr, addrlen);
+
+ fd_store(socket, fd, fd_normal, fd_fork_active);
}
return real.connect(fd, addr, addrlen);