From 2e5b0fc95964f74ea59dd725e849027faa0cd526 Mon Sep 17 00:00:00 2001 From: Sean Hefty Date: Mon, 25 Jun 2012 14:19:54 -0700 Subject: [PATCH] rsocket: Handle other shutdown option Handle SHUT_RD and SHUT_WR shutdown options. In order to handle shutting down the send and receive sides separately, we break the connection state into multiple sub-states. This allows us to be partially connected (i.e. for either just reads or just writes). Support for SHUT_WR is needed to handle netperf properly, which shuts down a socket by having the client use SHUT_WR, followed by the server completing the disconnect with SHUT_RDWR. The following patch eliminates an error message from netperf: 'shutdown_control: no response received errno 95' Signed-off-by: Sean Hefty --- src/rsocket.c | 156 ++++++++++++++++++++++++++++++-------------------- 1 file changed, 94 insertions(+), 62 deletions(-) diff --git a/src/rsocket.c b/src/rsocket.c index 012bb5ef..bdb756f9 100644 --- a/src/rsocket.c +++ b/src/rsocket.c @@ -96,7 +96,8 @@ enum { #define rs_msg_data(imm_data) (imm_data & 0x1FFFFFFF) enum { - RS_CTRL_DISCONNECT + RS_CTRL_DISCONNECT, + RS_CTRL_SHUTDOWN }; struct rs_msg { @@ -136,16 +137,20 @@ union rs_wr_id { */ enum rs_state { rs_init, - rs_bound, - rs_listening, - rs_resolving_addr, - rs_resolving_route, - rs_connecting, - rs_accepting, - rs_connect_error, - rs_connected, - rs_disconnected, - rs_error + rs_bound = 0x0001, + rs_listening = 0x0002, + rs_opening = 0x0004, + rs_resolving_addr = rs_opening | 0x0010, + rs_resolving_route = rs_opening | 0x0020, + rs_connecting = rs_opening | 0x0040, + rs_accepting = rs_opening | 0x0080, + rs_connected = 0x0100, + rs_connect_wr = 0x0200, + rs_connect_rd = 0x0400, + rs_connect_rdwr = rs_connected | rs_connect_rd | rs_connect_wr, + rs_connect_error = 0x0800, + rs_disconnected = 0x1000, + rs_error = 0x2000, }; #define RS_OPT_SWAP_SGL 1 @@ -161,7 +166,7 @@ struct rsocket { long fd_flags; uint64_t so_opts; uint64_t tcp_opts; - enum rs_state state; + int state; int cq_armed; int retries; int err; @@ -320,7 +325,7 @@ static int rs_set_nonblocking(struct rsocket *rs, long arg) if (rs->cm_id->recv_cq_channel) ret = fcntl(rs->cm_id->recv_cq_channel->fd, F_SETFL, arg); - if (!ret && rs->state != rs_connected) + if (!ret && rs->state < rs_connected) ret = fcntl(rs->cm_id->channel->fd, F_SETFL, arg); return ret; @@ -627,7 +632,7 @@ int raccept(int socket, struct sockaddr *addr, socklen_t *addrlen) rs_set_conn_data(new_rs, ¶m, &cresp); ret = rdma_accept(new_rs->cm_id, ¶m); if (!ret) - new_rs->state = rs_connected; + new_rs->state = rs_connect_rdwr; else if (errno == EAGAIN || errno == EWOULDBLOCK) new_rs->state = rs_accepting; else @@ -714,7 +719,7 @@ connected: } rs_save_conn_data(rs, cresp); - rs->state = rs_connected; + rs->state = rs_connect_rdwr; break; case rs_accepting: if (!(rs->fd_flags & O_NONBLOCK)) @@ -724,7 +729,7 @@ connected: if (ret) break; - rs->state = rs_connected; + rs->state = rs_connect_rdwr; break; default: ret = ERR(EINVAL); @@ -751,6 +756,13 @@ int rconnect(int socket, const struct sockaddr *addr, socklen_t addrlen) return rs_do_connect(rs); } +static void rs_shutdown_state(struct rsocket *rs, int state) +{ + rs->state &= ~state; + if (rs->state == rs_connected) + rs->state = rs_disconnected; +} + static int rs_post_write(struct rsocket *rs, uint64_t wr_id, struct ibv_sge *sgl, int nsge, uint32_t imm_data, int flags, @@ -851,7 +863,7 @@ static int rs_give_credits(struct rsocket *rs) { return ((rs->rbuf_bytes_avail >= (rs->rbuf_size >> 1)) || ((short) ((short) rs->rseq_no - (short) rs->rseq_comp) >= 0)) && - rs->ctrl_avail && (rs->state == rs_connected); + rs->ctrl_avail && (rs->state & rs_connected); } static void rs_update_credits(struct rsocket *rs) @@ -881,7 +893,9 @@ static int rs_poll_cq(struct rsocket *rs) case RS_OP_CTRL: if (rs_msg_data(imm_data) == RS_CTRL_DISCONNECT) { rs->state = rs_disconnected; - return ERR(ECONNRESET); + return 0; + } else if (rs_msg_data(imm_data) == RS_CTRL_SHUTDOWN) { + rs_shutdown_state(rs, rs_connect_rd); } break; default: @@ -899,14 +913,14 @@ static int rs_poll_cq(struct rsocket *rs) } else { rs->ctrl_avail++; } - if (wc.status != IBV_WC_SUCCESS && rs->state == rs_connected) { + if (wc.status != IBV_WC_SUCCESS && (rs->state & rs_connected)) { rs->state = rs_error; rs->err = EIO; } } } - if (rs->state == rs_connected) { + if (rs->state & rs_connected) { while (!ret && rcnt--) ret = rdma_post_recvv(rs->cm_id, NULL, NULL, 0); @@ -931,7 +945,7 @@ static int rs_get_cq_event(struct rsocket *rs) if (!ret) { ibv_ack_cq_events(rs->cm_id->recv_cq, 1); rs->cq_armed = 0; - } else if (errno != EAGAIN && rs->state == rs_connected) { + } else if (errno != EAGAIN) { rs->state = rs_error; } @@ -1042,7 +1056,7 @@ static int rs_can_send(struct rsocket *rs) static int rs_conn_can_send(struct rsocket *rs) { - return rs_can_send(rs) || (rs->state != rs_connected); + return rs_can_send(rs) || !(rs->state & rs_connect_wr); } static int rs_can_send_ctrl(struct rsocket *rs) @@ -1057,7 +1071,7 @@ static int rs_have_rdata(struct rsocket *rs) static int rs_conn_have_rdata(struct rsocket *rs) { - return rs_have_rdata(rs) || (rs->state != rs_connected); + return rs_have_rdata(rs) || !(rs->state & rs_connect_rd); } static int rs_all_sends_done(struct rsocket *rs) @@ -1110,7 +1124,7 @@ ssize_t rrecv(int socket, void *buf, size_t len, int flags) int ret; rs = idm_at(&idm, socket); - if (rs->state < rs_connected) { + if (rs->state & rs_opening) { ret = rs_do_connect(rs); if (ret) { if (errno == EINPROGRESS) @@ -1121,7 +1135,7 @@ ssize_t rrecv(int socket, void *buf, size_t len, int flags) fastlock_acquire(&rs->rlock); if (!rs_have_rdata(rs)) { ret = rs_get_comp(rs, rs_nonblocking(rs, flags), rs_conn_have_rdata); - if (ret && errno != ECONNRESET) + if (ret) goto out; } @@ -1212,7 +1226,7 @@ ssize_t rsend(int socket, const void *buf, size_t len, int flags) int ret = 0; rs = idm_at(&idm, socket); - if (rs->state < rs_connected) { + if (rs->state & rs_opening) { ret = rs_do_connect(rs); if (ret) { if (errno == EINPROGRESS) @@ -1228,7 +1242,7 @@ ssize_t rsend(int socket, const void *buf, size_t len, int flags) rs_conn_can_send); if (ret) break; - if (rs->state != rs_connected) { + if (!(rs->state & rs_connect_wr)) { ret = ERR(ECONNRESET); break; } @@ -1321,7 +1335,7 @@ static ssize_t rsendv(int socket, const struct iovec *iov, int iovcnt, int flags int i, ret = 0; rs = idm_at(&idm, socket); - if (rs->state < rs_connected) { + if (rs->state & rs_opening) { ret = rs_do_connect(rs); if (ret) { if (errno == EINPROGRESS) @@ -1342,7 +1356,7 @@ static ssize_t rsendv(int socket, const struct iovec *iov, int iovcnt, int flags rs_conn_can_send); if (ret) break; - if (rs->state != rs_connected) { + if (!(rs->state & rs_connect_wr)) { ret = ERR(ECONNRESET); break; } @@ -1434,17 +1448,35 @@ static int rs_poll_rs(struct rsocket *rs, int events, short revents; int ret; - switch (rs->state) { - case rs_listening: +check_cq: + if ((rs->state & rs_connected) || (rs->state == rs_disconnected) || + (rs->state & rs_error)) { + rs_process_cq(rs, nonblock, test); + + revents = 0; + if ((events & POLLIN) && rs_conn_have_rdata(rs)) + revents |= POLLIN; + if ((events & POLLOUT) && rs_can_send(rs)) + revents |= POLLOUT; + if (!(rs->state & rs_connected)) { + if (rs->state == rs_disconnected) + revents |= POLLHUP; + else + revents |= POLLERR; + } + + return revents; + } + + if (rs->state == rs_listening) { fds.fd = rs->cm_id->channel->fd; fds.events = events; fds.revents = 0; poll(&fds, 1, 0); return fds.revents; - case rs_resolving_addr: - case rs_resolving_route: - case rs_connecting: - case rs_accepting: + } + + if (rs->state & rs_opening) { ret = rs_do_connect(rs); if (ret) { if (errno == EINPROGRESS) { @@ -1454,28 +1486,13 @@ static int rs_poll_rs(struct rsocket *rs, int events, return POLLOUT; } } - /* fall through */ - case rs_connected: - case rs_disconnected: - case rs_error: - rs_process_cq(rs, nonblock, test); - - revents = 0; - if ((events & POLLIN) && rs_have_rdata(rs)) - revents |= POLLIN; - if ((events & POLLOUT) && rs_can_send(rs)) - revents |= POLLOUT; - if (rs->state == rs_disconnected) - revents |= POLLHUP; - else if (rs->state == rs_error) - revents |= POLLERR; + goto check_cq; + } - return revents; - case rs_connect_error: + if (rs->state == rs_connect_error) return (rs->err && events & POLLOUT) ? POLLOUT : 0; - default: - return 0; - } + + return 0; } static int rs_poll_check(struct pollfd *fds, nfds_t nfds) @@ -1687,14 +1704,26 @@ int rselect(int nfds, fd_set *readfds, fd_set *writefds, int rshutdown(int socket, int how) { struct rsocket *rs; - int ret = 0; + int ctrl, ret = 0; rs = idm_at(&idm, socket); + if (how == SHUT_RD) { + rs_shutdown_state(rs, rs_connect_rd); + return 0; + } + if (rs->fd_flags & O_NONBLOCK) rs_set_nonblocking(rs, 0); - if (rs->state == rs_connected) { - rs->state = rs_disconnected; + if (rs->state & rs_connected) { + if (how == SHUT_RDWR) { + ctrl = RS_CTRL_DISCONNECT; + rs->state = rs_disconnected; + } else { + rs_shutdown_state(rs, rs_connect_wr); + ctrl = (rs->state & rs_connected) ? + RS_CTRL_SHUTDOWN : RS_CTRL_DISCONNECT; + } if (!rs_can_send_ctrl(rs)) { ret = rs_process_cq(rs, 0, rs_can_send_ctrl); if (ret) @@ -1703,13 +1732,16 @@ int rshutdown(int socket, int how) rs->ctrl_avail--; ret = rs_post_write(rs, 0, NULL, 0, - rs_msg_set(RS_OP_CTRL, RS_CTRL_DISCONNECT), + rs_msg_set(RS_OP_CTRL, ctrl), 0, 0, 0); } - if (!rs_all_sends_done(rs) && rs->state != rs_error) + if (!rs_all_sends_done(rs) && !(rs->state & rs_error)) rs_process_cq(rs, 0, rs_all_sends_done); + if ((rs->fd_flags & O_NONBLOCK) && (rs->state & rs_connected)) + rs_set_nonblocking(rs, 1); + return 0; } @@ -1718,7 +1750,7 @@ int rclose(int socket) struct rsocket *rs; rs = idm_at(&idm, socket); - if (rs->state == rs_connected) + if (rs->state & rs_connected) rshutdown(socket, SHUT_RDWR); rs_free(rs); @@ -1818,7 +1850,7 @@ int rsetsockopt(int socket, int level, int optname, } break; case SOL_RDMA: - if (rs->state > rs_listening) { + if (rs->state >= rs_opening) { ret = ERR(EINVAL); break; } -- 2.41.0