From: Sean Hefty Date: Tue, 26 Jun 2012 00:11:41 +0000 (-0700) Subject: Refresh of shut_wr X-Git-Url: https://openfabrics.org/gitweb/?a=commitdiff_plain;h=d363fe91691a296f6a2253bf1346ea113558b9fd;p=~shefty%2Flibrdmacm.git Refresh of shut_wr --- diff --git a/src/rsocket.c b/src/rsocket.c index ed994fe3..63bf03ea 100644 --- a/src/rsocket.c +++ b/src/rsocket.c @@ -96,7 +96,8 @@ enum { #define rs_msg_data(imm_data) (imm_data & 0x1FFFFFFF) enum { - RS_CTRL_DISCONNECT + RS_CTRL_DISCONNECT, + RS_CTRL_SHUTDOWN }; struct rs_msg { @@ -321,7 +322,7 @@ static int rs_set_nonblocking(struct rsocket *rs, long arg) if (rs->cm_id->recv_cq_channel) ret = fcntl(rs->cm_id->recv_cq_channel->fd, F_SETFL, arg); - if (!ret && rs->state != rs_connected) + if (!ret && rs->state < rs_connected) ret = fcntl(rs->cm_id->channel->fd, F_SETFL, arg); return ret; @@ -883,6 +884,8 @@ static int rs_poll_cq(struct rsocket *rs) if (rs_msg_data(imm_data) == RS_CTRL_DISCONNECT) { rs->state = rs_disconnected; return ERR(ECONNRESET); + } else if (rs_msg_data(imm_data) == RS_CTRL_SHUTDOWN) { + rs->state = rs_shutdown_rd; } break; default: @@ -907,7 +910,7 @@ static int rs_poll_cq(struct rsocket *rs) } } - if (rs->state == rs_connected) { + if (rs->state != rs_error) { while (!ret && rcnt--) ret = rdma_post_recvv(rs->cm_id, NULL, NULL, 0); @@ -932,7 +935,7 @@ static int rs_get_cq_event(struct rsocket *rs) if (!ret) { ibv_ack_cq_events(rs->cm_id->recv_cq, 1); rs->cq_armed = 0; - } else if (errno != EAGAIN && rs->state == rs_connected) { + } else if (errno != EAGAIN) { rs->state = rs_error; } @@ -1457,6 +1460,8 @@ static int rs_poll_rs(struct rsocket *rs, int events, } /* fall through */ case rs_connected: + case rs_shutdown_rd: + case rs_shutdown_wr: case rs_disconnected: case rs_error: rs_process_cq(rs, nonblock, test); @@ -1466,10 +1471,12 @@ static int rs_poll_rs(struct rsocket *rs, int events, revents |= POLLIN; if ((events & POLLOUT) && rs_can_send(rs)) revents |= POLLOUT; - if (rs->state == rs_disconnected) - revents |= POLLHUP; - else if (rs->state == rs_error) - revents |= POLLERR; + if (rs->state > rs_connected) { + if (rs->state == rs_error) + revents |= POLLERR; + else + revents |= POLLHUP; + } return revents; case rs_connect_error: @@ -1690,12 +1697,16 @@ int rshutdown(int socket, int how) struct rsocket *rs; int ret = 0; + if (how == SHUT_RD) + return 0; + rs = idm_at(&idm, socket); if (rs->fd_flags & O_NONBLOCK) rs_set_nonblocking(rs, 0); if (rs->state == rs_connected) { - rs->state = rs_disconnected; + if (how == SHUT_RDWR) + rs->state = rs_disconnected; if (!rs_can_send_ctrl(rs)) { ret = rs_process_cq(rs, 0, rs_can_send_ctrl); if (ret) @@ -1711,6 +1722,9 @@ int rshutdown(int socket, int how) if (!rs_all_sends_done(rs) && rs->state != rs_error) rs_process_cq(rs, 0, rs_all_sends_done); + if ((rs->fd_flags & O_NONBLOCK) && (how == SHUT_WR)) + rs_set_nonblocking(rs, 1); + return 0; }