]> git.openfabrics.org - ~shefty/librdmacm.git/commitdiff
librspreload: Make socket_fallback() call more generic
authorSean Hefty <sean.hefty@intel.com>
Mon, 16 Jul 2012 21:17:58 +0000 (14:17 -0700)
committerSean Hefty <sean.hefty@intel.com>
Thu, 19 Jul 2012 21:22:07 +0000 (14:22 -0700)
socket_fallback is used to switch from an rsocket to a normal
socket in the case of failures.  Rename the call and make it
more generic, so that it can switch between an rsocket and
a normal socket in either direction.  This will be used to
support fork().

As part of this change, we move the list of hooked and rsocket
calls into structures, versus maintaining a large number of
static variables.

Signed-off-by: Sean Hefty <sean.hefty@intel.com>
src/preload.c

index 2750b301b60daf1c7b35ae60be2007b0cc346318..d2058e23906cb54843add199ab0746406eb53dc8 100644 (file)
 #include "cma.h"
 #include "indexer.h"
 
-static int (*real_socket)(int domain, int type, int protocol);
-static int (*real_bind)(int socket, const struct sockaddr *addr,
-                       socklen_t addrlen);
-static int (*real_listen)(int socket, int backlog);
-static int (*real_accept)(int socket, struct sockaddr *addr,
-                         socklen_t *addrlen);
-static int (*real_connect)(int socket, const struct sockaddr *addr,
-                          socklen_t addrlen);
-static ssize_t (*real_recv)(int socket, void *buf, size_t len, int flags);
-static ssize_t (*real_recvfrom)(int socket, void *buf, size_t len, int flags,
-                               struct sockaddr *src_addr, socklen_t *addrlen);
-static ssize_t (*real_recvmsg)(int socket, struct msghdr *msg, int flags);
-static ssize_t (*real_read)(int socket, void *buf, size_t count);
-static ssize_t (*real_readv)(int socket, const struct iovec *iov, int iovcnt);
-static ssize_t (*real_send)(int socket, const void *buf, size_t len, int flags);
-static ssize_t (*real_sendto)(int socket, const void *buf, size_t len, int flags,
-                             const struct sockaddr *dest_addr, socklen_t addrlen);
-static ssize_t (*real_sendmsg)(int socket, const struct msghdr *msg, int flags);
-static ssize_t (*real_write)(int socket, const void *buf, size_t count);
-static ssize_t (*real_writev)(int socket, const struct iovec *iov, int iovcnt);
-static int (*real_poll)(struct pollfd *fds, nfds_t nfds, int timeout);
-static int (*real_shutdown)(int socket, int how);
-static int (*real_close)(int socket);
-static int (*real_getpeername)(int socket, struct sockaddr *addr,
-                              socklen_t *addrlen);
-static int (*real_getsockname)(int socket, struct sockaddr *addr,
-                              socklen_t *addrlen);
-static int (*real_setsockopt)(int socket, int level, int optname,
-                             const void *optval, socklen_t optlen);
-static int (*real_getsockopt)(int socket, int level, int optname,
-                             void *optval, socklen_t *optlen);
-static int (*real_fcntl)(int socket, int cmd, ... /* arg */);
+struct socket_calls {
+       int (*socket)(int domain, int type, int protocol);
+       int (*bind)(int socket, const struct sockaddr *addr, socklen_t addrlen);
+       int (*listen)(int socket, int backlog);
+       int (*accept)(int socket, struct sockaddr *addr, socklen_t *addrlen);
+       int (*connect)(int socket, const struct sockaddr *addr, socklen_t addrlen);
+       ssize_t (*recv)(int socket, void *buf, size_t len, int flags);
+       ssize_t (*recvfrom)(int socket, void *buf, size_t len, int flags,
+                           struct sockaddr *src_addr, socklen_t *addrlen);
+       ssize_t (*recvmsg)(int socket, struct msghdr *msg, int flags);
+       ssize_t (*read)(int socket, void *buf, size_t count);
+       ssize_t (*readv)(int socket, const struct iovec *iov, int iovcnt);
+       ssize_t (*send)(int socket, const void *buf, size_t len, int flags);
+       ssize_t (*sendto)(int socket, const void *buf, size_t len, int flags,
+                         const struct sockaddr *dest_addr, socklen_t addrlen);
+       ssize_t (*sendmsg)(int socket, const struct msghdr *msg, int flags);
+       ssize_t (*write)(int socket, const void *buf, size_t count);
+       ssize_t (*writev)(int socket, const struct iovec *iov, int iovcnt);
+       int (*poll)(struct pollfd *fds, nfds_t nfds, int timeout);
+       int (*shutdown)(int socket, int how);
+       int (*close)(int socket);
+       int (*getpeername)(int socket, struct sockaddr *addr, socklen_t *addrlen);
+       int (*getsockname)(int socket, struct sockaddr *addr, socklen_t *addrlen);
+       int (*setsockopt)(int socket, int level, int optname,
+                         const void *optval, socklen_t optlen);
+       int (*getsockopt)(int socket, int level, int optname,
+                         void *optval, socklen_t *optlen);
+       int (*fcntl)(int socket, int cmd, ... /* arg */);
+};
+
+static struct socket_calls real;
+static struct socket_calls rs;
 
 static struct index_map idm;
 static pthread_mutex_t mut = PTHREAD_MUTEX_INITIALIZER;
@@ -221,29 +221,53 @@ static void init_preload(void)
        if (init)
                goto out;
 
-       real_socket = dlsym(RTLD_NEXT, "socket");
-       real_bind = dlsym(RTLD_NEXT, "bind");
-       real_listen = dlsym(RTLD_NEXT, "listen");
-       real_accept = dlsym(RTLD_NEXT, "accept");
-       real_connect = dlsym(RTLD_NEXT, "connect");
-       real_recv = dlsym(RTLD_NEXT, "recv");
-       real_recvfrom = dlsym(RTLD_NEXT, "recvfrom");
-       real_recvmsg = dlsym(RTLD_NEXT, "recvmsg");
-       real_read = dlsym(RTLD_NEXT, "read");
-       real_readv = dlsym(RTLD_NEXT, "readv");
-       real_send = dlsym(RTLD_NEXT, "send");
-       real_sendto = dlsym(RTLD_NEXT, "sendto");
-       real_sendmsg = dlsym(RTLD_NEXT, "sendmsg");
-       real_write = dlsym(RTLD_NEXT, "write");
-       real_writev = dlsym(RTLD_NEXT, "writev");
-       real_poll = dlsym(RTLD_NEXT, "poll");
-       real_shutdown = dlsym(RTLD_NEXT, "shutdown");
-       real_close = dlsym(RTLD_NEXT, "close");
-       real_getpeername = dlsym(RTLD_NEXT, "getpeername");
-       real_getsockname = dlsym(RTLD_NEXT, "getsockname");
-       real_setsockopt = dlsym(RTLD_NEXT, "setsockopt");
-       real_getsockopt = dlsym(RTLD_NEXT, "getsockopt");
-       real_fcntl = dlsym(RTLD_NEXT, "fcntl");
+       real.socket = dlsym(RTLD_NEXT, "socket");
+       real.bind = dlsym(RTLD_NEXT, "bind");
+       real.listen = dlsym(RTLD_NEXT, "listen");
+       real.accept = dlsym(RTLD_NEXT, "accept");
+       real.connect = dlsym(RTLD_NEXT, "connect");
+       real.recv = dlsym(RTLD_NEXT, "recv");
+       real.recvfrom = dlsym(RTLD_NEXT, "recvfrom");
+       real.recvmsg = dlsym(RTLD_NEXT, "recvmsg");
+       real.read = dlsym(RTLD_NEXT, "read");
+       real.readv = dlsym(RTLD_NEXT, "readv");
+       real.send = dlsym(RTLD_NEXT, "send");
+       real.sendto = dlsym(RTLD_NEXT, "sendto");
+       real.sendmsg = dlsym(RTLD_NEXT, "sendmsg");
+       real.write = dlsym(RTLD_NEXT, "write");
+       real.writev = dlsym(RTLD_NEXT, "writev");
+       real.poll = dlsym(RTLD_NEXT, "poll");
+       real.shutdown = dlsym(RTLD_NEXT, "shutdown");
+       real.close = dlsym(RTLD_NEXT, "close");
+       real.getpeername = dlsym(RTLD_NEXT, "getpeername");
+       real.getsockname = dlsym(RTLD_NEXT, "getsockname");
+       real.setsockopt = dlsym(RTLD_NEXT, "setsockopt");
+       real.getsockopt = dlsym(RTLD_NEXT, "getsockopt");
+       real.fcntl = dlsym(RTLD_NEXT, "fcntl");
+
+       rs.socket = dlsym(RTLD_DEFAULT, "rsocket");
+       rs.bind = dlsym(RTLD_DEFAULT, "rbind");
+       rs.listen = dlsym(RTLD_DEFAULT, "rlisten");
+       rs.accept = dlsym(RTLD_DEFAULT, "raccept");
+       rs.connect = dlsym(RTLD_DEFAULT, "rconnect");
+       rs.recv = dlsym(RTLD_DEFAULT, "rrecv");
+       rs.recvfrom = dlsym(RTLD_DEFAULT, "rrecvfrom");
+       rs.recvmsg = dlsym(RTLD_DEFAULT, "rrecvmsg");
+       rs.read = dlsym(RTLD_DEFAULT, "rread");
+       rs.readv = dlsym(RTLD_DEFAULT, "rreadv");
+       rs.send = dlsym(RTLD_DEFAULT, "rsend");
+       rs.sendto = dlsym(RTLD_DEFAULT, "rsendto");
+       rs.sendmsg = dlsym(RTLD_DEFAULT, "rsendmsg");
+       rs.write = dlsym(RTLD_DEFAULT, "rwrite");
+       rs.writev = dlsym(RTLD_DEFAULT, "rwritev");
+       rs.poll = dlsym(RTLD_DEFAULT, "rpoll");
+       rs.shutdown = dlsym(RTLD_DEFAULT, "rshutdown");
+       rs.close = dlsym(RTLD_DEFAULT, "rclose");
+       rs.getpeername = dlsym(RTLD_DEFAULT, "rgetpeername");
+       rs.getsockname = dlsym(RTLD_DEFAULT, "rgetsockname");
+       rs.setsockopt = dlsym(RTLD_DEFAULT, "rsetsockopt");
+       rs.getsockopt = dlsym(RTLD_DEFAULT, "rgetsockopt");
+       rs.fcntl = dlsym(RTLD_DEFAULT, "rfcntl");
 
        getenv_options();
        init = 1;
@@ -252,51 +276,73 @@ out:
 }
 
 /*
- * Convert from an rsocket to a normal socket.  The new socket should have the
- * same settings and bindings as the rsocket.  We currently only handle setting
- * a few of the more common values.
+ * We currently only handle copying a few common values.
  */
-static int socket_fallback(int socket, int *fd)
+static int copysockopts(int dfd, int sfd, struct socket_calls *dapi,
+                       struct socket_calls *sapi)
 {
-       socklen_t len = 0;
-       int new_fd, param, ret;
+       socklen_t len;
+       int param, ret;
 
-       ret = rgetsockname(*fd, NULL, &len);
-       if (ret)
-               return ret;
-
-       param = (len == sizeof(struct sockaddr_in6)) ? PF_INET6 : PF_INET;
-       new_fd = real_socket(param, SOCK_STREAM, IPPROTO_TCP);
-       if (new_fd < 0)
-               return new_fd;
-
-       ret = rfcntl(*fd, F_GETFL);
+       ret = sapi->fcntl(sfd, F_GETFL);
        if (ret > 0)
-               ret = real_fcntl(new_fd, F_SETFL, ret);
+               ret = dapi->fcntl(dfd, F_SETFL, ret);
        if (ret)
-               goto err;
+               return ret;
 
        len = sizeof param;
-       ret = rgetsockopt(*fd, SOL_SOCKET, SO_REUSEADDR, &param, &len);
+       ret = sapi->getsockopt(sfd, SOL_SOCKET, SO_REUSEADDR, &param, &len);
        if (param && !ret)
-               ret = real_setsockopt(new_fd, SOL_SOCKET, SO_REUSEADDR, &param, len);
+               ret = dapi->setsockopt(dfd, SOL_SOCKET, SO_REUSEADDR, &param, len);
        if (ret)
-               goto err;
+               return ret;
 
        len = sizeof param;
-       ret = rgetsockopt(*fd, IPPROTO_TCP, TCP_NODELAY, &param, &len);
+       ret = sapi->getsockopt(sfd, IPPROTO_TCP, TCP_NODELAY, &param, &len);
        if (param && !ret)
-               ret = real_setsockopt(new_fd, IPPROTO_TCP, TCP_NODELAY, &param, len);
+               ret = dapi->setsockopt(dfd, IPPROTO_TCP, TCP_NODELAY, &param, len);
        if (ret)
-               goto err;
+               return ret;
 
-       rclose(*fd);
-       fd_store(socket, new_fd, fd_normal);
-       *fd = new_fd;
        return 0;
+}
+
+/*
+ * Convert between an rsocket and a normal socket.
+ */
+static int transpose_socket(int socket, enum fd_type new_type)
+{
+       socklen_t len = 0;
+       int sfd, dfd, param, ret;
+       struct socket_calls *sapi, *dapi;
+
+       sfd = fd_getd(socket);
+       if (new_type == fd_rsocket) {
+               dapi = &rs;
+               sapi = &real;
+       } else {
+               dapi = &real;
+               sapi = &rs;
+       }
+
+       ret = sapi->getsockname(sfd, NULL, &len);
+       if (ret)
+               return ret;
+
+       param = (len == sizeof(struct sockaddr_in6)) ? PF_INET6 : PF_INET;
+       dfd = dapi->socket(param, SOCK_STREAM, 0);
+       if (dfd < 0)
+               return dfd;
+
+       ret = copysockopts(dfd, sfd, dapi, sapi);
+       if (ret)
+               goto err;
+
+       fd_store(socket, dfd, new_type);
+       return dfd;
 
 err:
-       real_close(new_fd);
+       dapi->close(dfd);
        return ret;
 }
 
@@ -338,7 +384,7 @@ int socket(int domain, int type, int protocol)
        }
        fd_close(index, &ret);
 real:
-       return real_socket(domain, type, protocol);
+       return real.socket(domain, type, protocol);
 }
 
 int bind(int socket, const struct sockaddr *addr, socklen_t addrlen)
@@ -351,19 +397,22 @@ int bind(int socket, const struct sockaddr *addr, socklen_t addrlen)
                if (!sin->sin_port || ntohs(sin->sin_port) > 1024)
                        return rbind(fd, addr, addrlen);
 
-               ret = socket_fallback(socket, &fd);
-               if (ret)
+               ret = transpose_socket(socket, fd_normal);
+               if (ret < 0)
                        return ret;
+
+               rclose(fd);
+               fd = ret;
        }
 
-       return real_bind(fd, addr, addrlen);
+       return real.bind(fd, addr, addrlen);
 }
 
 int listen(int socket, int backlog)
 {
        int fd;
        return (fd_get(socket, &fd) == fd_rsocket) ?
-               rlisten(fd, backlog) : real_listen(fd, backlog);
+               rlisten(fd, backlog) : real.listen(fd, backlog);
 }
 
 int accept(int socket, struct sockaddr *addr, socklen_t *addrlen)
@@ -384,7 +433,7 @@ int accept(int socket, struct sockaddr *addr, socklen_t *addrlen)
                fd_store(index, ret, fd_rsocket);
                return index;
        } else {
-               return real_accept(fd, addr, addrlen);
+               return real.accept(fd, addr, addrlen);
        }
 }
 
@@ -401,19 +450,22 @@ int connect(int socket, const struct sockaddr *addr, socklen_t addrlen)
                                return ret;
                }
 
-               ret = socket_fallback(socket, &fd);
-               if (ret)
+               ret = transpose_socket(socket, fd_normal);
+               if (ret < 0)
                        return ret;
+
+               rclose(fd);
+               fd = ret;
        }
 
-       return real_connect(fd, addr, addrlen);
+       return real.connect(fd, addr, addrlen);
 }
 
 ssize_t recv(int socket, void *buf, size_t len, int flags)
 {
        int fd;
        return (fd_get(socket, &fd) == fd_rsocket) ?
-               rrecv(fd, buf, len, flags) : real_recv(fd, buf, len, flags);
+               rrecv(fd, buf, len, flags) : real.recv(fd, buf, len, flags);
 }
 
 ssize_t recvfrom(int socket, void *buf, size_t len, int flags,
@@ -422,14 +474,14 @@ ssize_t recvfrom(int socket, void *buf, size_t len, int flags,
        int fd;
        return (fd_get(socket, &fd) == fd_rsocket) ?
                rrecvfrom(fd, buf, len, flags, src_addr, addrlen) :
-               real_recvfrom(fd, buf, len, flags, src_addr, addrlen);
+               real.recvfrom(fd, buf, len, flags, src_addr, addrlen);
 }
 
 ssize_t recvmsg(int socket, struct msghdr *msg, int flags)
 {
        int fd;
        return (fd_get(socket, &fd) == fd_rsocket) ?
-               rrecvmsg(fd, msg, flags) : real_recvmsg(fd, msg, flags);
+               rrecvmsg(fd, msg, flags) : real.recvmsg(fd, msg, flags);
 }
 
 ssize_t read(int socket, void *buf, size_t count)
@@ -437,7 +489,7 @@ ssize_t read(int socket, void *buf, size_t count)
        int fd;
        init_preload();
        return (fd_get(socket, &fd) == fd_rsocket) ?
-               rread(fd, buf, count) : real_read(fd, buf, count);
+               rread(fd, buf, count) : real.read(fd, buf, count);
 }
 
 ssize_t readv(int socket, const struct iovec *iov, int iovcnt)
@@ -445,14 +497,14 @@ ssize_t readv(int socket, const struct iovec *iov, int iovcnt)
        int fd;
        init_preload();
        return (fd_get(socket, &fd) == fd_rsocket) ?
-               rreadv(fd, iov, iovcnt) : real_readv(fd, iov, iovcnt);
+               rreadv(fd, iov, iovcnt) : real.readv(fd, iov, iovcnt);
 }
 
 ssize_t send(int socket, const void *buf, size_t len, int flags)
 {
        int fd;
        return (fd_get(socket, &fd) == fd_rsocket) ?
-               rsend(fd, buf, len, flags) : real_send(fd, buf, len, flags);
+               rsend(fd, buf, len, flags) : real.send(fd, buf, len, flags);
 }
 
 ssize_t sendto(int socket, const void *buf, size_t len, int flags,
@@ -461,14 +513,14 @@ ssize_t sendto(int socket, const void *buf, size_t len, int flags,
        int fd;
        return (fd_get(socket, &fd) == fd_rsocket) ?
                rsendto(fd, buf, len, flags, dest_addr, addrlen) :
-               real_sendto(fd, buf, len, flags, dest_addr, addrlen);
+               real.sendto(fd, buf, len, flags, dest_addr, addrlen);
 }
 
 ssize_t sendmsg(int socket, const struct msghdr *msg, int flags)
 {
        int fd;
        return (fd_get(socket, &fd) == fd_rsocket) ?
-               rsendmsg(fd, msg, flags) : real_sendmsg(fd, msg, flags);
+               rsendmsg(fd, msg, flags) : real.sendmsg(fd, msg, flags);
 }
 
 ssize_t write(int socket, const void *buf, size_t count)
@@ -476,7 +528,7 @@ ssize_t write(int socket, const void *buf, size_t count)
        int fd;
        init_preload();
        return (fd_get(socket, &fd) == fd_rsocket) ?
-               rwrite(fd, buf, count) : real_write(fd, buf, count);
+               rwrite(fd, buf, count) : real.write(fd, buf, count);
 }
 
 ssize_t writev(int socket, const struct iovec *iov, int iovcnt)
@@ -484,7 +536,7 @@ ssize_t writev(int socket, const struct iovec *iov, int iovcnt)
        int fd;
        init_preload();
        return (fd_get(socket, &fd) == fd_rsocket) ?
-               rwritev(fd, iov, iovcnt) : real_writev(fd, iov, iovcnt);
+               rwritev(fd, iov, iovcnt) : real.writev(fd, iov, iovcnt);
 }
 
 static struct pollfd *fds_alloc(nfds_t nfds)
@@ -514,7 +566,7 @@ int poll(struct pollfd *fds, nfds_t nfds, int timeout)
                        goto use_rpoll;
        }
 
-       return real_poll(fds, nfds, timeout);
+       return real.poll(fds, nfds, timeout);
 
 use_rpoll:
        rfds = fds_alloc(nfds);
@@ -619,14 +671,14 @@ int shutdown(int socket, int how)
 {
        int fd;
        return (fd_get(socket, &fd) == fd_rsocket) ?
-               rshutdown(fd, how) : real_shutdown(fd, how);
+               rshutdown(fd, how) : real.shutdown(fd, how);
 }
 
 int close(int socket)
 {
        int fd;
        init_preload();
-       return (fd_close(socket, &fd) == fd_rsocket) ? rclose(fd) : real_close(fd);
+       return (fd_close(socket, &fd) == fd_rsocket) ? rclose(fd) : real.close(fd);
 }
 
 int getpeername(int socket, struct sockaddr *addr, socklen_t *addrlen)
@@ -634,7 +686,7 @@ int getpeername(int socket, struct sockaddr *addr, socklen_t *addrlen)
        int fd;
        return (fd_get(socket, &fd) == fd_rsocket) ?
                rgetpeername(fd, addr, addrlen) :
-               real_getpeername(fd, addr, addrlen);
+               real.getpeername(fd, addr, addrlen);
 }
 
 int getsockname(int socket, struct sockaddr *addr, socklen_t *addrlen)
@@ -642,7 +694,7 @@ int getsockname(int socket, struct sockaddr *addr, socklen_t *addrlen)
        int fd;
        return (fd_get(socket, &fd) == fd_rsocket) ?
                rgetsockname(fd, addr, addrlen) :
-               real_getsockname(fd, addr, addrlen);
+               real.getsockname(fd, addr, addrlen);
 }
 
 int setsockopt(int socket, int level, int optname,
@@ -651,7 +703,7 @@ int setsockopt(int socket, int level, int optname,
        int fd;
        return (fd_get(socket, &fd) == fd_rsocket) ?
                rsetsockopt(fd, level, optname, optval, optlen) :
-               real_setsockopt(fd, level, optname, optval, optlen);
+               real.setsockopt(fd, level, optname, optval, optlen);
 }
 
 int getsockopt(int socket, int level, int optname,
@@ -660,7 +712,7 @@ int getsockopt(int socket, int level, int optname,
        int fd;
        return (fd_get(socket, &fd) == fd_rsocket) ?
                rgetsockopt(fd, level, optname, optval, optlen) :
-               real_getsockopt(fd, level, optname, optval, optlen);
+               real.getsockopt(fd, level, optname, optval, optlen);
 }
 
 int fcntl(int socket, int cmd, ... /* arg */)
@@ -679,7 +731,7 @@ int fcntl(int socket, int cmd, ... /* arg */)
        case F_GETSIG:
        case F_GETLEASE:
                ret = (fd_get(socket, &fd) == fd_rsocket) ?
-                       rfcntl(fd, cmd) : real_fcntl(fd, cmd);
+                       rfcntl(fd, cmd) : real.fcntl(fd, cmd);
                break;
        case F_DUPFD:
        /*case F_DUPFD_CLOEXEC:*/
@@ -691,12 +743,12 @@ int fcntl(int socket, int cmd, ... /* arg */)
        case F_NOTIFY:
                lparam = va_arg(args, long);
                ret = (fd_get(socket, &fd) == fd_rsocket) ?
-                       rfcntl(fd, cmd, lparam) : real_fcntl(fd, cmd, lparam);
+                       rfcntl(fd, cmd, lparam) : real.fcntl(fd, cmd, lparam);
                break;
        default:
                pparam = va_arg(args, void *);
                ret = (fd_get(socket, &fd) == fd_rsocket) ?
-                       rfcntl(fd, cmd, pparam) : real_fcntl(fd, cmd, pparam);
+                       rfcntl(fd, cmd, pparam) : real.fcntl(fd, cmd, pparam);
                break;
        }
        va_end(args);