]> git.openfabrics.org - ~shefty/librdmacm.git/commitdiff
rsocket: Handle SHUT_WR shutdown option
authorSean Hefty <sean.hefty@intel.com>
Mon, 25 Jun 2012 21:19:54 +0000 (14:19 -0700)
committerSean Hefty <sean.hefty@intel.com>
Mon, 25 Jun 2012 22:14:30 +0000 (15:14 -0700)
Signed-off-by: Sean Hefty <sean.hefty@intel.com>
src/rsocket.c

index ed994fe30829d495f5cfd466463e12c7bebd1751..63bf03ea998fb7aa6db3f2003d4680b4ee141755 100644 (file)
@@ -96,7 +96,8 @@ enum {
 #define rs_msg_data(imm_data) (imm_data & 0x1FFFFFFF)
 
 enum {
-       RS_CTRL_DISCONNECT
+       RS_CTRL_DISCONNECT,
+       RS_CTRL_SHUTDOWN
 };
 
 struct rs_msg {
@@ -321,7 +322,7 @@ static int rs_set_nonblocking(struct rsocket *rs, long arg)
        if (rs->cm_id->recv_cq_channel)
                ret = fcntl(rs->cm_id->recv_cq_channel->fd, F_SETFL, arg);
 
-       if (!ret && rs->state != rs_connected)
+       if (!ret && rs->state < rs_connected)
                ret = fcntl(rs->cm_id->channel->fd, F_SETFL, arg);
 
        return ret;
@@ -883,6 +884,8 @@ static int rs_poll_cq(struct rsocket *rs)
                                if (rs_msg_data(imm_data) == RS_CTRL_DISCONNECT) {
                                        rs->state = rs_disconnected;
                                        return ERR(ECONNRESET);
+                               } else if (rs_msg_data(imm_data) == RS_CTRL_SHUTDOWN) {
+                                       rs->state = rs_shutdown_rd;
                                }
                                break;
                        default:
@@ -907,7 +910,7 @@ static int rs_poll_cq(struct rsocket *rs)
                }
        }
 
-       if (rs->state == rs_connected) {
+       if (rs->state != rs_error) {
                while (!ret && rcnt--)
                        ret = rdma_post_recvv(rs->cm_id, NULL, NULL, 0);
 
@@ -932,7 +935,7 @@ static int rs_get_cq_event(struct rsocket *rs)
        if (!ret) {
                ibv_ack_cq_events(rs->cm_id->recv_cq, 1);
                rs->cq_armed = 0;
-       } else if (errno != EAGAIN && rs->state == rs_connected) {
+       } else if (errno != EAGAIN) {
                rs->state = rs_error;
        }
 
@@ -1457,6 +1460,8 @@ static int rs_poll_rs(struct rsocket *rs, int events,
                }
                /* fall through */
        case rs_connected:
+       case rs_shutdown_rd:
+       case rs_shutdown_wr:
        case rs_disconnected:
        case rs_error:
                rs_process_cq(rs, nonblock, test);
@@ -1466,10 +1471,12 @@ static int rs_poll_rs(struct rsocket *rs, int events,
                        revents |= POLLIN;
                if ((events & POLLOUT) && rs_can_send(rs))
                        revents |= POLLOUT;
-               if (rs->state == rs_disconnected)
-                       revents |= POLLHUP;
-               else if (rs->state == rs_error)
-                       revents |= POLLERR;
+               if (rs->state > rs_connected) {
+                       if (rs->state == rs_error)
+                               revents |= POLLERR;
+                       else
+                               revents |= POLLHUP;
+               }
 
                return revents;
        case rs_connect_error:
@@ -1690,12 +1697,16 @@ int rshutdown(int socket, int how)
        struct rsocket *rs;
        int ret = 0;
 
+       if (how == SHUT_RD)
+               return 0;
+
        rs = idm_at(&idm, socket);
        if (rs->fd_flags & O_NONBLOCK)
                rs_set_nonblocking(rs, 0);
 
        if (rs->state == rs_connected) {
-               rs->state = rs_disconnected;
+               if (how == SHUT_RDWR)
+                       rs->state = rs_disconnected;
                if (!rs_can_send_ctrl(rs)) {
                        ret = rs_process_cq(rs, 0, rs_can_send_ctrl);
                        if (ret)
@@ -1711,6 +1722,9 @@ int rshutdown(int socket, int how)
        if (!rs_all_sends_done(rs) && rs->state != rs_error)
                rs_process_cq(rs, 0, rs_all_sends_done);
 
+       if ((rs->fd_flags & O_NONBLOCK) && (how == SHUT_WR))
+               rs_set_nonblocking(rs, 1);
+
        return 0;
 }