--- /dev/null
+Bottom: af4cc3c64b7827f75b1392c03c0414ad0a869bca
+Top: 1242d5df80aef5749c32bdb6490d768bfdd28ffe
+Author: Sean Hefty <sean.hefty@intel.com>
+Date: 2012-06-28 11:19:00 -0700
+
+Refresh of shut_wr
+
+---
+
+diff --git a/src/rsocket.c b/src/rsocket.c
+index 012bb5e..bdb756f 100644
+--- a/src/rsocket.c
++++ b/src/rsocket.c
+@@ -96,7 +96,8 @@ enum {
+ #define rs_msg_data(imm_data) (imm_data & 0x1FFFFFFF)
+
+ enum {
+- RS_CTRL_DISCONNECT
++ RS_CTRL_DISCONNECT,
++ RS_CTRL_SHUTDOWN
+ };
+
+ struct rs_msg {
+@@ -136,16 +137,20 @@ union rs_wr_id {
+ */
+ enum rs_state {
+ rs_init,
+- rs_bound,
+- rs_listening,
+- rs_resolving_addr,
+- rs_resolving_route,
+- rs_connecting,
+- rs_accepting,
+- rs_connect_error,
+- rs_connected,
+- rs_disconnected,
+- rs_error
++ rs_bound = 0x0001,
++ rs_listening = 0x0002,
++ rs_opening = 0x0004,
++ rs_resolving_addr = rs_opening | 0x0010,
++ rs_resolving_route = rs_opening | 0x0020,
++ rs_connecting = rs_opening | 0x0040,
++ rs_accepting = rs_opening | 0x0080,
++ rs_connected = 0x0100,
++ rs_connect_wr = 0x0200,
++ rs_connect_rd = 0x0400,
++ rs_connect_rdwr = rs_connected | rs_connect_rd | rs_connect_wr,
++ rs_connect_error = 0x0800,
++ rs_disconnected = 0x1000,
++ rs_error = 0x2000,
+ };
+
+ #define RS_OPT_SWAP_SGL 1
+@@ -161,7 +166,7 @@ struct rsocket {
+ long fd_flags;
+ uint64_t so_opts;
+ uint64_t tcp_opts;
+- enum rs_state state;
++ int state;
+ int cq_armed;
+ int retries;
+ int err;
+@@ -320,7 +325,7 @@ static int rs_set_nonblocking(struct rsocket *rs, long arg)
+ if (rs->cm_id->recv_cq_channel)
+ ret = fcntl(rs->cm_id->recv_cq_channel->fd, F_SETFL, arg);
+
+- if (!ret && rs->state != rs_connected)
++ if (!ret && rs->state < rs_connected)
+ ret = fcntl(rs->cm_id->channel->fd, F_SETFL, arg);
+
+ return ret;
+@@ -627,7 +632,7 @@ int raccept(int socket, struct sockaddr *addr, socklen_t *addrlen)
+ rs_set_conn_data(new_rs, ¶m, &cresp);
+ ret = rdma_accept(new_rs->cm_id, ¶m);
+ if (!ret)
+- new_rs->state = rs_connected;
++ new_rs->state = rs_connect_rdwr;
+ else if (errno == EAGAIN || errno == EWOULDBLOCK)
+ new_rs->state = rs_accepting;
+ else
+@@ -714,7 +719,7 @@ connected:
+ }
+
+ rs_save_conn_data(rs, cresp);
+- rs->state = rs_connected;
++ rs->state = rs_connect_rdwr;
+ break;
+ case rs_accepting:
+ if (!(rs->fd_flags & O_NONBLOCK))
+@@ -724,7 +729,7 @@ connected:
+ if (ret)
+ break;
+
+- rs->state = rs_connected;
++ rs->state = rs_connect_rdwr;
+ break;
+ default:
+ ret = ERR(EINVAL);
+@@ -751,6 +756,13 @@ int rconnect(int socket, const struct sockaddr *addr, socklen_t addrlen)
+ return rs_do_connect(rs);
+ }
+
++static void rs_shutdown_state(struct rsocket *rs, int state)
++{
++ rs->state &= ~state;
++ if (rs->state == rs_connected)
++ rs->state = rs_disconnected;
++}
++
+ static int rs_post_write(struct rsocket *rs, uint64_t wr_id,
+ struct ibv_sge *sgl, int nsge,
+ uint32_t imm_data, int flags,
+@@ -851,7 +863,7 @@ static int rs_give_credits(struct rsocket *rs)
+ {
+ return ((rs->rbuf_bytes_avail >= (rs->rbuf_size >> 1)) ||
+ ((short) ((short) rs->rseq_no - (short) rs->rseq_comp) >= 0)) &&
+- rs->ctrl_avail && (rs->state == rs_connected);
++ rs->ctrl_avail && (rs->state & rs_connected);
+ }
+
+ static void rs_update_credits(struct rsocket *rs)
+@@ -881,7 +893,9 @@ static int rs_poll_cq(struct rsocket *rs)
+ case RS_OP_CTRL:
+ if (rs_msg_data(imm_data) == RS_CTRL_DISCONNECT) {
+ rs->state = rs_disconnected;
+- return ERR(ECONNRESET);
++ return 0;
++ } else if (rs_msg_data(imm_data) == RS_CTRL_SHUTDOWN) {
++ rs_shutdown_state(rs, rs_connect_rd);
+ }
+ break;
+ default:
+@@ -899,14 +913,14 @@ static int rs_poll_cq(struct rsocket *rs)
+ } else {
+ rs->ctrl_avail++;
+ }
+- if (wc.status != IBV_WC_SUCCESS && rs->state == rs_connected) {
++ if (wc.status != IBV_WC_SUCCESS && (rs->state & rs_connected)) {
+ rs->state = rs_error;
+ rs->err = EIO;
+ }
+ }
+ }
+
+- if (rs->state == rs_connected) {
++ if (rs->state & rs_connected) {
+ while (!ret && rcnt--)
+ ret = rdma_post_recvv(rs->cm_id, NULL, NULL, 0);
+
+@@ -931,7 +945,7 @@ static int rs_get_cq_event(struct rsocket *rs)
+ if (!ret) {
+ ibv_ack_cq_events(rs->cm_id->recv_cq, 1);
+ rs->cq_armed = 0;
+- } else if (errno != EAGAIN && rs->state == rs_connected) {
++ } else if (errno != EAGAIN) {
+ rs->state = rs_error;
+ }
+
+@@ -1042,7 +1056,7 @@ static int rs_can_send(struct rsocket *rs)
+
+ static int rs_conn_can_send(struct rsocket *rs)
+ {
+- return rs_can_send(rs) || (rs->state != rs_connected);
++ return rs_can_send(rs) || !(rs->state & rs_connect_wr);
+ }
+
+ static int rs_can_send_ctrl(struct rsocket *rs)
+@@ -1057,7 +1071,7 @@ static int rs_have_rdata(struct rsocket *rs)
+
+ static int rs_conn_have_rdata(struct rsocket *rs)
+ {
+- return rs_have_rdata(rs) || (rs->state != rs_connected);
++ return rs_have_rdata(rs) || !(rs->state & rs_connect_rd);
+ }
+
+ static int rs_all_sends_done(struct rsocket *rs)
+@@ -1110,7 +1124,7 @@ ssize_t rrecv(int socket, void *buf, size_t len, int flags)
+ int ret;
+
+ rs = idm_at(&idm, socket);
+- if (rs->state < rs_connected) {
++ if (rs->state & rs_opening) {
+ ret = rs_do_connect(rs);
+ if (ret) {
+ if (errno == EINPROGRESS)
+@@ -1121,7 +1135,7 @@ ssize_t rrecv(int socket, void *buf, size_t len, int flags)
+ fastlock_acquire(&rs->rlock);
+ if (!rs_have_rdata(rs)) {
+ ret = rs_get_comp(rs, rs_nonblocking(rs, flags), rs_conn_have_rdata);
+- if (ret && errno != ECONNRESET)
++ if (ret)
+ goto out;
+ }
+
+@@ -1212,7 +1226,7 @@ ssize_t rsend(int socket, const void *buf, size_t len, int flags)
+ int ret = 0;
+
+ rs = idm_at(&idm, socket);
+- if (rs->state < rs_connected) {
++ if (rs->state & rs_opening) {
+ ret = rs_do_connect(rs);
+ if (ret) {
+ if (errno == EINPROGRESS)
+@@ -1228,7 +1242,7 @@ ssize_t rsend(int socket, const void *buf, size_t len, int flags)
+ rs_conn_can_send);
+ if (ret)
+ break;
+- if (rs->state != rs_connected) {
++ if (!(rs->state & rs_connect_wr)) {
+ ret = ERR(ECONNRESET);
+ break;
+ }
+@@ -1321,7 +1335,7 @@ static ssize_t rsendv(int socket, const struct iovec *iov, int iovcnt, int flags
+ int i, ret = 0;
+
+ rs = idm_at(&idm, socket);
+- if (rs->state < rs_connected) {
++ if (rs->state & rs_opening) {
+ ret = rs_do_connect(rs);
+ if (ret) {
+ if (errno == EINPROGRESS)
+@@ -1342,7 +1356,7 @@ static ssize_t rsendv(int socket, const struct iovec *iov, int iovcnt, int flags
+ rs_conn_can_send);
+ if (ret)
+ break;
+- if (rs->state != rs_connected) {
++ if (!(rs->state & rs_connect_wr)) {
+ ret = ERR(ECONNRESET);
+ break;
+ }
+@@ -1434,17 +1448,35 @@ static int rs_poll_rs(struct rsocket *rs, int events,
+ short revents;
+ int ret;
+
+- switch (rs->state) {
+- case rs_listening:
++check_cq:
++ if ((rs->state & rs_connected) || (rs->state == rs_disconnected) ||
++ (rs->state & rs_error)) {
++ rs_process_cq(rs, nonblock, test);
++
++ revents = 0;
++ if ((events & POLLIN) && rs_conn_have_rdata(rs))
++ revents |= POLLIN;
++ if ((events & POLLOUT) && rs_can_send(rs))
++ revents |= POLLOUT;
++ if (!(rs->state & rs_connected)) {
++ if (rs->state == rs_disconnected)
++ revents |= POLLHUP;
++ else
++ revents |= POLLERR;
++ }
++
++ return revents;
++ }
++
++ if (rs->state == rs_listening) {
+ fds.fd = rs->cm_id->channel->fd;
+ fds.events = events;
+ fds.revents = 0;
+ poll(&fds, 1, 0);
+ return fds.revents;
+- case rs_resolving_addr:
+- case rs_resolving_route:
+- case rs_connecting:
+- case rs_accepting:
++ }
++
++ if (rs->state & rs_opening) {
+ ret = rs_do_connect(rs);
+ if (ret) {
+ if (errno == EINPROGRESS) {
+@@ -1454,28 +1486,13 @@ static int rs_poll_rs(struct rsocket *rs, int events,
+ return POLLOUT;
+ }
+ }
+- /* fall through */
+- case rs_connected:
+- case rs_disconnected:
+- case rs_error:
+- rs_process_cq(rs, nonblock, test);
+-
+- revents = 0;
+- if ((events & POLLIN) && rs_have_rdata(rs))
+- revents |= POLLIN;
+- if ((events & POLLOUT) && rs_can_send(rs))
+- revents |= POLLOUT;
+- if (rs->state == rs_disconnected)
+- revents |= POLLHUP;
+- else if (rs->state == rs_error)
+- revents |= POLLERR;
++ goto check_cq;
++ }
+
+- return revents;
+- case rs_connect_error:
++ if (rs->state == rs_connect_error)
+ return (rs->err && events & POLLOUT) ? POLLOUT : 0;
+- default:
+- return 0;
+- }
++
++ return 0;
+ }
+
+ static int rs_poll_check(struct pollfd *fds, nfds_t nfds)
+@@ -1687,14 +1704,26 @@ int rselect(int nfds, fd_set *readfds, fd_set *writefds,
+ int rshutdown(int socket, int how)
+ {
+ struct rsocket *rs;
+- int ret = 0;
++ int ctrl, ret = 0;
+
+ rs = idm_at(&idm, socket);
++ if (how == SHUT_RD) {
++ rs_shutdown_state(rs, rs_connect_rd);
++ return 0;
++ }
++
+ if (rs->fd_flags & O_NONBLOCK)
+ rs_set_nonblocking(rs, 0);
+
+- if (rs->state == rs_connected) {
+- rs->state = rs_disconnected;
++ if (rs->state & rs_connected) {
++ if (how == SHUT_RDWR) {
++ ctrl = RS_CTRL_DISCONNECT;
++ rs->state = rs_disconnected;
++ } else {
++ rs_shutdown_state(rs, rs_connect_wr);
++ ctrl = (rs->state & rs_connected) ?
++ RS_CTRL_SHUTDOWN : RS_CTRL_DISCONNECT;
++ }
+ if (!rs_can_send_ctrl(rs)) {
+ ret = rs_process_cq(rs, 0, rs_can_send_ctrl);
+ if (ret)
+@@ -1703,13 +1732,16 @@ int rshutdown(int socket, int how)
+
+ rs->ctrl_avail--;
+ ret = rs_post_write(rs, 0, NULL, 0,
+- rs_msg_set(RS_OP_CTRL, RS_CTRL_DISCONNECT),
++ rs_msg_set(RS_OP_CTRL, ctrl),
+ 0, 0, 0);
+ }
+
+- if (!rs_all_sends_done(rs) && rs->state != rs_error)
++ if (!rs_all_sends_done(rs) && !(rs->state & rs_error))
+ rs_process_cq(rs, 0, rs_all_sends_done);
+
++ if ((rs->fd_flags & O_NONBLOCK) && (rs->state & rs_connected))
++ rs_set_nonblocking(rs, 1);
++
+ return 0;
+ }
+
+@@ -1718,7 +1750,7 @@ int rclose(int socket)
+ struct rsocket *rs;
+
+ rs = idm_at(&idm, socket);
+- if (rs->state == rs_connected)
++ if (rs->state & rs_connected)
+ rshutdown(socket, SHUT_RDWR);
+
+ rs_free(rs);
+@@ -1818,7 +1850,7 @@ int rsetsockopt(int socket, int level, int optname,
+ }
+ break;
+ case SOL_RDMA:
+- if (rs->state > rs_listening) {
++ if (rs->state >= rs_opening) {
+ ret = ERR(EINVAL);
+ break;
+ }