From 3e52eb22f44eafaefa95c4674bc5665a94e15694 Mon Sep 17 00:00:00 2001 From: Sean Hefty Date: Mon, 16 Jul 2012 14:17:58 -0700 Subject: [PATCH] librspreload: Make socket_fallback() call more generic socket_fallback is used to switch from an rsocket to a normal socket in the case of failures. Rename the call and make it more generic, so that it can switch between an rsocket and a normal socket in either direction. This will be used to support fork(). As part of this change, we move the list of hooked and rsocket calls into structures, versus maintaining a large number of static variables. Signed-off-by: Sean Hefty --- src/preload.c | 276 ++++++++++++++++++++++++++++++-------------------- 1 file changed, 164 insertions(+), 112 deletions(-) diff --git a/src/preload.c b/src/preload.c index 2750b301..d2058e23 100644 --- a/src/preload.c +++ b/src/preload.c @@ -53,38 +53,38 @@ #include "cma.h" #include "indexer.h" -static int (*real_socket)(int domain, int type, int protocol); -static int (*real_bind)(int socket, const struct sockaddr *addr, - socklen_t addrlen); -static int (*real_listen)(int socket, int backlog); -static int (*real_accept)(int socket, struct sockaddr *addr, - socklen_t *addrlen); -static int (*real_connect)(int socket, const struct sockaddr *addr, - socklen_t addrlen); -static ssize_t (*real_recv)(int socket, void *buf, size_t len, int flags); -static ssize_t (*real_recvfrom)(int socket, void *buf, size_t len, int flags, - struct sockaddr *src_addr, socklen_t *addrlen); -static ssize_t (*real_recvmsg)(int socket, struct msghdr *msg, int flags); -static ssize_t (*real_read)(int socket, void *buf, size_t count); -static ssize_t (*real_readv)(int socket, const struct iovec *iov, int iovcnt); -static ssize_t (*real_send)(int socket, const void *buf, size_t len, int flags); -static ssize_t (*real_sendto)(int socket, const void *buf, size_t len, int flags, - const struct sockaddr *dest_addr, socklen_t addrlen); -static ssize_t (*real_sendmsg)(int socket, const struct msghdr *msg, int flags); -static ssize_t (*real_write)(int socket, const void *buf, size_t count); -static ssize_t (*real_writev)(int socket, const struct iovec *iov, int iovcnt); -static int (*real_poll)(struct pollfd *fds, nfds_t nfds, int timeout); -static int (*real_shutdown)(int socket, int how); -static int (*real_close)(int socket); -static int (*real_getpeername)(int socket, struct sockaddr *addr, - socklen_t *addrlen); -static int (*real_getsockname)(int socket, struct sockaddr *addr, - socklen_t *addrlen); -static int (*real_setsockopt)(int socket, int level, int optname, - const void *optval, socklen_t optlen); -static int (*real_getsockopt)(int socket, int level, int optname, - void *optval, socklen_t *optlen); -static int (*real_fcntl)(int socket, int cmd, ... /* arg */); +struct socket_calls { + int (*socket)(int domain, int type, int protocol); + int (*bind)(int socket, const struct sockaddr *addr, socklen_t addrlen); + int (*listen)(int socket, int backlog); + int (*accept)(int socket, struct sockaddr *addr, socklen_t *addrlen); + int (*connect)(int socket, const struct sockaddr *addr, socklen_t addrlen); + ssize_t (*recv)(int socket, void *buf, size_t len, int flags); + ssize_t (*recvfrom)(int socket, void *buf, size_t len, int flags, + struct sockaddr *src_addr, socklen_t *addrlen); + ssize_t (*recvmsg)(int socket, struct msghdr *msg, int flags); + ssize_t (*read)(int socket, void *buf, size_t count); + ssize_t (*readv)(int socket, const struct iovec *iov, int iovcnt); + ssize_t (*send)(int socket, const void *buf, size_t len, int flags); + ssize_t (*sendto)(int socket, const void *buf, size_t len, int flags, + const struct sockaddr *dest_addr, socklen_t addrlen); + ssize_t (*sendmsg)(int socket, const struct msghdr *msg, int flags); + ssize_t (*write)(int socket, const void *buf, size_t count); + ssize_t (*writev)(int socket, const struct iovec *iov, int iovcnt); + int (*poll)(struct pollfd *fds, nfds_t nfds, int timeout); + int (*shutdown)(int socket, int how); + int (*close)(int socket); + int (*getpeername)(int socket, struct sockaddr *addr, socklen_t *addrlen); + int (*getsockname)(int socket, struct sockaddr *addr, socklen_t *addrlen); + int (*setsockopt)(int socket, int level, int optname, + const void *optval, socklen_t optlen); + int (*getsockopt)(int socket, int level, int optname, + void *optval, socklen_t *optlen); + int (*fcntl)(int socket, int cmd, ... /* arg */); +}; + +static struct socket_calls real; +static struct socket_calls rs; static struct index_map idm; static pthread_mutex_t mut = PTHREAD_MUTEX_INITIALIZER; @@ -221,29 +221,53 @@ static void init_preload(void) if (init) goto out; - real_socket = dlsym(RTLD_NEXT, "socket"); - real_bind = dlsym(RTLD_NEXT, "bind"); - real_listen = dlsym(RTLD_NEXT, "listen"); - real_accept = dlsym(RTLD_NEXT, "accept"); - real_connect = dlsym(RTLD_NEXT, "connect"); - real_recv = dlsym(RTLD_NEXT, "recv"); - real_recvfrom = dlsym(RTLD_NEXT, "recvfrom"); - real_recvmsg = dlsym(RTLD_NEXT, "recvmsg"); - real_read = dlsym(RTLD_NEXT, "read"); - real_readv = dlsym(RTLD_NEXT, "readv"); - real_send = dlsym(RTLD_NEXT, "send"); - real_sendto = dlsym(RTLD_NEXT, "sendto"); - real_sendmsg = dlsym(RTLD_NEXT, "sendmsg"); - real_write = dlsym(RTLD_NEXT, "write"); - real_writev = dlsym(RTLD_NEXT, "writev"); - real_poll = dlsym(RTLD_NEXT, "poll"); - real_shutdown = dlsym(RTLD_NEXT, "shutdown"); - real_close = dlsym(RTLD_NEXT, "close"); - real_getpeername = dlsym(RTLD_NEXT, "getpeername"); - real_getsockname = dlsym(RTLD_NEXT, "getsockname"); - real_setsockopt = dlsym(RTLD_NEXT, "setsockopt"); - real_getsockopt = dlsym(RTLD_NEXT, "getsockopt"); - real_fcntl = dlsym(RTLD_NEXT, "fcntl"); + real.socket = dlsym(RTLD_NEXT, "socket"); + real.bind = dlsym(RTLD_NEXT, "bind"); + real.listen = dlsym(RTLD_NEXT, "listen"); + real.accept = dlsym(RTLD_NEXT, "accept"); + real.connect = dlsym(RTLD_NEXT, "connect"); + real.recv = dlsym(RTLD_NEXT, "recv"); + real.recvfrom = dlsym(RTLD_NEXT, "recvfrom"); + real.recvmsg = dlsym(RTLD_NEXT, "recvmsg"); + real.read = dlsym(RTLD_NEXT, "read"); + real.readv = dlsym(RTLD_NEXT, "readv"); + real.send = dlsym(RTLD_NEXT, "send"); + real.sendto = dlsym(RTLD_NEXT, "sendto"); + real.sendmsg = dlsym(RTLD_NEXT, "sendmsg"); + real.write = dlsym(RTLD_NEXT, "write"); + real.writev = dlsym(RTLD_NEXT, "writev"); + real.poll = dlsym(RTLD_NEXT, "poll"); + real.shutdown = dlsym(RTLD_NEXT, "shutdown"); + real.close = dlsym(RTLD_NEXT, "close"); + real.getpeername = dlsym(RTLD_NEXT, "getpeername"); + real.getsockname = dlsym(RTLD_NEXT, "getsockname"); + real.setsockopt = dlsym(RTLD_NEXT, "setsockopt"); + real.getsockopt = dlsym(RTLD_NEXT, "getsockopt"); + real.fcntl = dlsym(RTLD_NEXT, "fcntl"); + + rs.socket = dlsym(RTLD_DEFAULT, "rsocket"); + rs.bind = dlsym(RTLD_DEFAULT, "rbind"); + rs.listen = dlsym(RTLD_DEFAULT, "rlisten"); + rs.accept = dlsym(RTLD_DEFAULT, "raccept"); + rs.connect = dlsym(RTLD_DEFAULT, "rconnect"); + rs.recv = dlsym(RTLD_DEFAULT, "rrecv"); + rs.recvfrom = dlsym(RTLD_DEFAULT, "rrecvfrom"); + rs.recvmsg = dlsym(RTLD_DEFAULT, "rrecvmsg"); + rs.read = dlsym(RTLD_DEFAULT, "rread"); + rs.readv = dlsym(RTLD_DEFAULT, "rreadv"); + rs.send = dlsym(RTLD_DEFAULT, "rsend"); + rs.sendto = dlsym(RTLD_DEFAULT, "rsendto"); + rs.sendmsg = dlsym(RTLD_DEFAULT, "rsendmsg"); + rs.write = dlsym(RTLD_DEFAULT, "rwrite"); + rs.writev = dlsym(RTLD_DEFAULT, "rwritev"); + rs.poll = dlsym(RTLD_DEFAULT, "rpoll"); + rs.shutdown = dlsym(RTLD_DEFAULT, "rshutdown"); + rs.close = dlsym(RTLD_DEFAULT, "rclose"); + rs.getpeername = dlsym(RTLD_DEFAULT, "rgetpeername"); + rs.getsockname = dlsym(RTLD_DEFAULT, "rgetsockname"); + rs.setsockopt = dlsym(RTLD_DEFAULT, "rsetsockopt"); + rs.getsockopt = dlsym(RTLD_DEFAULT, "rgetsockopt"); + rs.fcntl = dlsym(RTLD_DEFAULT, "rfcntl"); getenv_options(); init = 1; @@ -252,51 +276,73 @@ out: } /* - * Convert from an rsocket to a normal socket. The new socket should have the - * same settings and bindings as the rsocket. We currently only handle setting - * a few of the more common values. + * We currently only handle copying a few common values. */ -static int socket_fallback(int socket, int *fd) +static int copysockopts(int dfd, int sfd, struct socket_calls *dapi, + struct socket_calls *sapi) { - socklen_t len = 0; - int new_fd, param, ret; + socklen_t len; + int param, ret; - ret = rgetsockname(*fd, NULL, &len); - if (ret) - return ret; - - param = (len == sizeof(struct sockaddr_in6)) ? PF_INET6 : PF_INET; - new_fd = real_socket(param, SOCK_STREAM, IPPROTO_TCP); - if (new_fd < 0) - return new_fd; - - ret = rfcntl(*fd, F_GETFL); + ret = sapi->fcntl(sfd, F_GETFL); if (ret > 0) - ret = real_fcntl(new_fd, F_SETFL, ret); + ret = dapi->fcntl(dfd, F_SETFL, ret); if (ret) - goto err; + return ret; len = sizeof param; - ret = rgetsockopt(*fd, SOL_SOCKET, SO_REUSEADDR, ¶m, &len); + ret = sapi->getsockopt(sfd, SOL_SOCKET, SO_REUSEADDR, ¶m, &len); if (param && !ret) - ret = real_setsockopt(new_fd, SOL_SOCKET, SO_REUSEADDR, ¶m, len); + ret = dapi->setsockopt(dfd, SOL_SOCKET, SO_REUSEADDR, ¶m, len); if (ret) - goto err; + return ret; len = sizeof param; - ret = rgetsockopt(*fd, IPPROTO_TCP, TCP_NODELAY, ¶m, &len); + ret = sapi->getsockopt(sfd, IPPROTO_TCP, TCP_NODELAY, ¶m, &len); if (param && !ret) - ret = real_setsockopt(new_fd, IPPROTO_TCP, TCP_NODELAY, ¶m, len); + ret = dapi->setsockopt(dfd, IPPROTO_TCP, TCP_NODELAY, ¶m, len); if (ret) - goto err; + return ret; - rclose(*fd); - fd_store(socket, new_fd, fd_normal); - *fd = new_fd; return 0; +} + +/* + * Convert between an rsocket and a normal socket. + */ +static int transpose_socket(int socket, enum fd_type new_type) +{ + socklen_t len = 0; + int sfd, dfd, param, ret; + struct socket_calls *sapi, *dapi; + + sfd = fd_getd(socket); + if (new_type == fd_rsocket) { + dapi = &rs; + sapi = ℜ + } else { + dapi = ℜ + sapi = &rs; + } + + ret = sapi->getsockname(sfd, NULL, &len); + if (ret) + return ret; + + param = (len == sizeof(struct sockaddr_in6)) ? PF_INET6 : PF_INET; + dfd = dapi->socket(param, SOCK_STREAM, 0); + if (dfd < 0) + return dfd; + + ret = copysockopts(dfd, sfd, dapi, sapi); + if (ret) + goto err; + + fd_store(socket, dfd, new_type); + return dfd; err: - real_close(new_fd); + dapi->close(dfd); return ret; } @@ -338,7 +384,7 @@ int socket(int domain, int type, int protocol) } fd_close(index, &ret); real: - return real_socket(domain, type, protocol); + return real.socket(domain, type, protocol); } int bind(int socket, const struct sockaddr *addr, socklen_t addrlen) @@ -351,19 +397,22 @@ int bind(int socket, const struct sockaddr *addr, socklen_t addrlen) if (!sin->sin_port || ntohs(sin->sin_port) > 1024) return rbind(fd, addr, addrlen); - ret = socket_fallback(socket, &fd); - if (ret) + ret = transpose_socket(socket, fd_normal); + if (ret < 0) return ret; + + rclose(fd); + fd = ret; } - return real_bind(fd, addr, addrlen); + return real.bind(fd, addr, addrlen); } int listen(int socket, int backlog) { int fd; return (fd_get(socket, &fd) == fd_rsocket) ? - rlisten(fd, backlog) : real_listen(fd, backlog); + rlisten(fd, backlog) : real.listen(fd, backlog); } int accept(int socket, struct sockaddr *addr, socklen_t *addrlen) @@ -384,7 +433,7 @@ int accept(int socket, struct sockaddr *addr, socklen_t *addrlen) fd_store(index, ret, fd_rsocket); return index; } else { - return real_accept(fd, addr, addrlen); + return real.accept(fd, addr, addrlen); } } @@ -401,19 +450,22 @@ int connect(int socket, const struct sockaddr *addr, socklen_t addrlen) return ret; } - ret = socket_fallback(socket, &fd); - if (ret) + ret = transpose_socket(socket, fd_normal); + if (ret < 0) return ret; + + rclose(fd); + fd = ret; } - return real_connect(fd, addr, addrlen); + return real.connect(fd, addr, addrlen); } ssize_t recv(int socket, void *buf, size_t len, int flags) { int fd; return (fd_get(socket, &fd) == fd_rsocket) ? - rrecv(fd, buf, len, flags) : real_recv(fd, buf, len, flags); + rrecv(fd, buf, len, flags) : real.recv(fd, buf, len, flags); } ssize_t recvfrom(int socket, void *buf, size_t len, int flags, @@ -422,14 +474,14 @@ ssize_t recvfrom(int socket, void *buf, size_t len, int flags, int fd; return (fd_get(socket, &fd) == fd_rsocket) ? rrecvfrom(fd, buf, len, flags, src_addr, addrlen) : - real_recvfrom(fd, buf, len, flags, src_addr, addrlen); + real.recvfrom(fd, buf, len, flags, src_addr, addrlen); } ssize_t recvmsg(int socket, struct msghdr *msg, int flags) { int fd; return (fd_get(socket, &fd) == fd_rsocket) ? - rrecvmsg(fd, msg, flags) : real_recvmsg(fd, msg, flags); + rrecvmsg(fd, msg, flags) : real.recvmsg(fd, msg, flags); } ssize_t read(int socket, void *buf, size_t count) @@ -437,7 +489,7 @@ ssize_t read(int socket, void *buf, size_t count) int fd; init_preload(); return (fd_get(socket, &fd) == fd_rsocket) ? - rread(fd, buf, count) : real_read(fd, buf, count); + rread(fd, buf, count) : real.read(fd, buf, count); } ssize_t readv(int socket, const struct iovec *iov, int iovcnt) @@ -445,14 +497,14 @@ ssize_t readv(int socket, const struct iovec *iov, int iovcnt) int fd; init_preload(); return (fd_get(socket, &fd) == fd_rsocket) ? - rreadv(fd, iov, iovcnt) : real_readv(fd, iov, iovcnt); + rreadv(fd, iov, iovcnt) : real.readv(fd, iov, iovcnt); } ssize_t send(int socket, const void *buf, size_t len, int flags) { int fd; return (fd_get(socket, &fd) == fd_rsocket) ? - rsend(fd, buf, len, flags) : real_send(fd, buf, len, flags); + rsend(fd, buf, len, flags) : real.send(fd, buf, len, flags); } ssize_t sendto(int socket, const void *buf, size_t len, int flags, @@ -461,14 +513,14 @@ ssize_t sendto(int socket, const void *buf, size_t len, int flags, int fd; return (fd_get(socket, &fd) == fd_rsocket) ? rsendto(fd, buf, len, flags, dest_addr, addrlen) : - real_sendto(fd, buf, len, flags, dest_addr, addrlen); + real.sendto(fd, buf, len, flags, dest_addr, addrlen); } ssize_t sendmsg(int socket, const struct msghdr *msg, int flags) { int fd; return (fd_get(socket, &fd) == fd_rsocket) ? - rsendmsg(fd, msg, flags) : real_sendmsg(fd, msg, flags); + rsendmsg(fd, msg, flags) : real.sendmsg(fd, msg, flags); } ssize_t write(int socket, const void *buf, size_t count) @@ -476,7 +528,7 @@ ssize_t write(int socket, const void *buf, size_t count) int fd; init_preload(); return (fd_get(socket, &fd) == fd_rsocket) ? - rwrite(fd, buf, count) : real_write(fd, buf, count); + rwrite(fd, buf, count) : real.write(fd, buf, count); } ssize_t writev(int socket, const struct iovec *iov, int iovcnt) @@ -484,7 +536,7 @@ ssize_t writev(int socket, const struct iovec *iov, int iovcnt) int fd; init_preload(); return (fd_get(socket, &fd) == fd_rsocket) ? - rwritev(fd, iov, iovcnt) : real_writev(fd, iov, iovcnt); + rwritev(fd, iov, iovcnt) : real.writev(fd, iov, iovcnt); } static struct pollfd *fds_alloc(nfds_t nfds) @@ -514,7 +566,7 @@ int poll(struct pollfd *fds, nfds_t nfds, int timeout) goto use_rpoll; } - return real_poll(fds, nfds, timeout); + return real.poll(fds, nfds, timeout); use_rpoll: rfds = fds_alloc(nfds); @@ -619,14 +671,14 @@ int shutdown(int socket, int how) { int fd; return (fd_get(socket, &fd) == fd_rsocket) ? - rshutdown(fd, how) : real_shutdown(fd, how); + rshutdown(fd, how) : real.shutdown(fd, how); } int close(int socket) { int fd; init_preload(); - return (fd_close(socket, &fd) == fd_rsocket) ? rclose(fd) : real_close(fd); + return (fd_close(socket, &fd) == fd_rsocket) ? rclose(fd) : real.close(fd); } int getpeername(int socket, struct sockaddr *addr, socklen_t *addrlen) @@ -634,7 +686,7 @@ int getpeername(int socket, struct sockaddr *addr, socklen_t *addrlen) int fd; return (fd_get(socket, &fd) == fd_rsocket) ? rgetpeername(fd, addr, addrlen) : - real_getpeername(fd, addr, addrlen); + real.getpeername(fd, addr, addrlen); } int getsockname(int socket, struct sockaddr *addr, socklen_t *addrlen) @@ -642,7 +694,7 @@ int getsockname(int socket, struct sockaddr *addr, socklen_t *addrlen) int fd; return (fd_get(socket, &fd) == fd_rsocket) ? rgetsockname(fd, addr, addrlen) : - real_getsockname(fd, addr, addrlen); + real.getsockname(fd, addr, addrlen); } int setsockopt(int socket, int level, int optname, @@ -651,7 +703,7 @@ int setsockopt(int socket, int level, int optname, int fd; return (fd_get(socket, &fd) == fd_rsocket) ? rsetsockopt(fd, level, optname, optval, optlen) : - real_setsockopt(fd, level, optname, optval, optlen); + real.setsockopt(fd, level, optname, optval, optlen); } int getsockopt(int socket, int level, int optname, @@ -660,7 +712,7 @@ int getsockopt(int socket, int level, int optname, int fd; return (fd_get(socket, &fd) == fd_rsocket) ? rgetsockopt(fd, level, optname, optval, optlen) : - real_getsockopt(fd, level, optname, optval, optlen); + real.getsockopt(fd, level, optname, optval, optlen); } int fcntl(int socket, int cmd, ... /* arg */) @@ -679,7 +731,7 @@ int fcntl(int socket, int cmd, ... /* arg */) case F_GETSIG: case F_GETLEASE: ret = (fd_get(socket, &fd) == fd_rsocket) ? - rfcntl(fd, cmd) : real_fcntl(fd, cmd); + rfcntl(fd, cmd) : real.fcntl(fd, cmd); break; case F_DUPFD: /*case F_DUPFD_CLOEXEC:*/ @@ -691,12 +743,12 @@ int fcntl(int socket, int cmd, ... /* arg */) case F_NOTIFY: lparam = va_arg(args, long); ret = (fd_get(socket, &fd) == fd_rsocket) ? - rfcntl(fd, cmd, lparam) : real_fcntl(fd, cmd, lparam); + rfcntl(fd, cmd, lparam) : real.fcntl(fd, cmd, lparam); break; default: pparam = va_arg(args, void *); ret = (fd_get(socket, &fd) == fd_rsocket) ? - rfcntl(fd, cmd, pparam) : real_fcntl(fd, cmd, pparam); + rfcntl(fd, cmd, pparam) : real.fcntl(fd, cmd, pparam); break; } va_end(args); -- 2.41.0