]> git.openfabrics.org - ~shefty/librdmacm.git/commitdiff
rsocket: Handle SHUT_WR shutdown option
authorSean Hefty <sean.hefty@intel.com>
Mon, 25 Jun 2012 21:19:54 +0000 (14:19 -0700)
committerSean Hefty <sean.hefty@intel.com>
Tue, 26 Jun 2012 23:40:09 +0000 (16:40 -0700)
Signed-off-by: Sean Hefty <sean.hefty@intel.com>
src/rsocket.c

index c833d46d8ed4f9bfe5a9ea7da167037ff216cd3d..4eb1ce1eab18bbca2780648911ca35b1741f8fc5 100644 (file)
@@ -96,7 +96,8 @@ enum {
 #define rs_msg_data(imm_data) (imm_data & 0x1FFFFFFF)
 
 enum {
-       RS_CTRL_DISCONNECT
+       RS_CTRL_DISCONNECT,
+       RS_CTRL_SHUTDOWN
 };
 
 struct rs_msg {
@@ -136,16 +137,20 @@ union rs_wr_id {
  */
 enum rs_state {
        rs_init,
-       rs_bound,
-       rs_listening,
-       rs_resolving_addr,
-       rs_resolving_route,
-       rs_connecting,
-       rs_accepting,
-       rs_connect_error,
-       rs_connected,
-       rs_disconnected,
-       rs_error
+       rs_bound           =                0x0001,
+       rs_listening       =                0x0002,
+       rs_opening         =                0x0004,
+       rs_resolving_addr  = rs_opening |   0x0010,
+       rs_resolving_route = rs_opening |   0x0020,
+       rs_connecting      = rs_opening |   0x0040,
+       rs_accepting       = rs_opening |   0x0080,
+       rs_connected       =                0x0100,
+       rs_connect_wr      = rs_connected | 0x0200,
+       rs_connect_rd      = rs_connected | 0x0400,
+       rs_connect_rdwr    = rs_connect_rd | rs_connect_wr,
+       rs_connect_error   =                0x0800,
+       rs_disconnected    =                0x1000,
+       rs_error           =                0x2000,
 };
 
 #define RS_OPT_SWAP_SGL 1
@@ -321,7 +326,7 @@ static int rs_set_nonblocking(struct rsocket *rs, long arg)
        if (rs->cm_id->recv_cq_channel)
                ret = fcntl(rs->cm_id->recv_cq_channel->fd, F_SETFL, arg);
 
-       if (!ret && rs->state != rs_connected)
+       if (!ret && rs->state < rs_connected)
                ret = fcntl(rs->cm_id->channel->fd, F_SETFL, arg);
 
        return ret;
@@ -628,7 +633,7 @@ int raccept(int socket, struct sockaddr *addr, socklen_t *addrlen)
        rs_set_conn_data(new_rs, &param, &cresp);
        ret = rdma_accept(new_rs->cm_id, &param);
        if (!ret)
-               new_rs->state = rs_connected;
+               new_rs->state = rs_connect_rdwr;
        else if (errno == EAGAIN || errno == EWOULDBLOCK)
                new_rs->state = rs_accepting;
        else
@@ -715,7 +720,7 @@ connected:
                }
 
                rs_save_conn_data(rs, cresp);
-               rs->state = rs_connected;
+               rs->state = rs_connect_rdwr;
                break;
        case rs_accepting:
                if (!(rs->fd_flags & O_NONBLOCK))
@@ -725,7 +730,7 @@ connected:
                if (ret)
                        break;
 
-               rs->state = rs_connected;
+               rs->state = rs_connect_rdwr;
                break;
        default:
                ret = ERR(EINVAL);
@@ -852,7 +857,7 @@ static int rs_give_credits(struct rsocket *rs)
 {
        return ((rs->rbuf_bytes_avail >= (rs->rbuf_size >> 1)) ||
                ((short) ((short) rs->rseq_no - (short) rs->rseq_comp) >= 0)) &&
-              rs->ctrl_avail && (rs->state == rs_connected);
+              rs->ctrl_avail && (rs->state & rs_connected);
 }
 
 static void rs_update_credits(struct rsocket *rs)
@@ -883,6 +888,8 @@ static int rs_poll_cq(struct rsocket *rs)
                                if (rs_msg_data(imm_data) == RS_CTRL_DISCONNECT) {
                                        rs->state = rs_disconnected;
                                        return ERR(ECONNRESET);
+                               } else if (rs_msg_data(imm_data) == RS_CTRL_SHUTDOWN) {
+                                       rs->state &= ~rs_connect_rd;
                                }
                                break;
                        default:
@@ -900,14 +907,14 @@ static int rs_poll_cq(struct rsocket *rs)
                        } else {
                                rs->ctrl_avail++;
                        }
-                       if (wc.status != IBV_WC_SUCCESS && rs->state == rs_connected) {
+                       if (wc.status != IBV_WC_SUCCESS && (rs->state & rs_connected)) {
                                rs->state = rs_error;
                                rs->err = EIO;
                        }
                }
        }
 
-       if (rs->state == rs_connected) {
+       if (rs->state & rs_connected) {
                while (!ret && rcnt--)
                        ret = rdma_post_recvv(rs->cm_id, NULL, NULL, 0);
 
@@ -932,7 +939,7 @@ static int rs_get_cq_event(struct rsocket *rs)
        if (!ret) {
                ibv_ack_cq_events(rs->cm_id->recv_cq, 1);
                rs->cq_armed = 0;
-       } else if (errno != EAGAIN && rs->state == rs_connected) {
+       } else if (errno != EAGAIN) {
                rs->state = rs_error;
        }
 
@@ -1043,7 +1050,7 @@ static int rs_can_send(struct rsocket *rs)
 
 static int rs_conn_can_send(struct rsocket *rs)
 {
-       return rs_can_send(rs) || (rs->state != rs_connected);
+       return rs_can_send(rs) || !(rs->state & rs_connect_wr);
 }
 
 static int rs_can_send_ctrl(struct rsocket *rs)
@@ -1058,7 +1065,7 @@ static int rs_have_rdata(struct rsocket *rs)
 
 static int rs_conn_have_rdata(struct rsocket *rs)
 {
-       return rs_have_rdata(rs) || (rs->state != rs_connected);
+       return rs_have_rdata(rs) || !(rs->state & rs_connect_rd);
 }
 
 static int rs_all_sends_done(struct rsocket *rs)
@@ -1111,7 +1118,7 @@ ssize_t rrecv(int socket, void *buf, size_t len, int flags)
        int ret;
 
        rs = idm_at(&idm, socket);
-       if (rs->state < rs_connected) {
+       if (rs->state & rs_opening) {
                ret = rs_do_connect(rs);
                if (ret) {
                        if (errno == EINPROGRESS)
@@ -1213,7 +1220,7 @@ ssize_t rsend(int socket, const void *buf, size_t len, int flags)
        int ret = 0;
 
        rs = idm_at(&idm, socket);
-       if (rs->state < rs_connected) {
+       if (rs->state & rs_opening) {
                ret = rs_do_connect(rs);
                if (ret) {
                        if (errno == EINPROGRESS)
@@ -1229,7 +1236,7 @@ ssize_t rsend(int socket, const void *buf, size_t len, int flags)
                                          rs_conn_can_send);
                        if (ret)
                                break;
-                       if (rs->state != rs_connected) {
+                       if (!(rs->state & rs_connect_wr)) {
                                ret = ERR(ECONNRESET);
                                break;
                        }
@@ -1322,7 +1329,7 @@ static ssize_t rsendv(int socket, const struct iovec *iov, int iovcnt, int flags
        int i, ret = 0;
 
        rs = idm_at(&idm, socket);
-       if (rs->state < rs_connected) {
+       if (rs->state & rs_opening) {
                ret = rs_do_connect(rs);
                if (ret) {
                        if (errno == EINPROGRESS)
@@ -1343,7 +1350,7 @@ static ssize_t rsendv(int socket, const struct iovec *iov, int iovcnt, int flags
                                          rs_conn_can_send);
                        if (ret)
                                break;
-                       if (rs->state != rs_connected) {
+                       if (!(rs->state & rs_connect_wr)) {
                                ret = ERR(ECONNRESET);
                                break;
                        }
@@ -1435,17 +1442,35 @@ static int rs_poll_rs(struct rsocket *rs, int events,
        short revents;
        int ret;
 
-       switch (rs->state) {
-       case rs_listening:
+check_cq:
+       if ((rs->state & rs_connected) || (rs->state == rs_disconnected) ||
+           (rs->state & rs_error)) {
+               rs_process_cq(rs, nonblock, test);
+
+               revents = 0;
+               if ((events & POLLIN) && rs_have_rdata(rs))
+                       revents |= POLLIN;
+               if ((events & POLLOUT) && rs_can_send(rs))
+                       revents |= POLLOUT;
+               if (!(rs->state & rs_connected)) {
+                       if (rs->state == rs_disconnected)
+                               revents |= POLLHUP;
+                       else
+                               revents |= POLLERR;
+               }
+
+               return revents;
+       }
+
+       if (rs->state == rs_listening) {
                fds.fd = rs->cm_id->channel->fd;
                fds.events = events;
                fds.revents = 0;
                poll(&fds, 1, 0);
                return fds.revents;
-       case rs_resolving_addr:
-       case rs_resolving_route:
-       case rs_connecting:
-       case rs_accepting:
+       }
+
+       if (rs->state & rs_opening) {
                ret = rs_do_connect(rs);
                if (ret) {
                        if (errno == EINPROGRESS) {
@@ -1455,28 +1480,13 @@ static int rs_poll_rs(struct rsocket *rs, int events,
                                return POLLOUT;
                        }
                }
-               /* fall through */
-       case rs_connected:
-       case rs_disconnected:
-       case rs_error:
-               rs_process_cq(rs, nonblock, test);
-
-               revents = 0;
-               if ((events & POLLIN) && rs_have_rdata(rs))
-                       revents |= POLLIN;
-               if ((events & POLLOUT) && rs_can_send(rs))
-                       revents |= POLLOUT;
-               if (rs->state == rs_disconnected)
-                       revents |= POLLHUP;
-               else if (rs->state == rs_error)
-                       revents |= POLLERR;
+               goto check_cq;
+       }
 
-               return revents;
-       case rs_connect_error:
+       if (rs->state == rs_connect_error)
                return (rs->err && events & POLLOUT) ? POLLOUT : 0;
-       default:
-               return 0;
-       }
+
+       return 0;
 }
 
 static int rs_poll_check(struct pollfd *fds, nfds_t nfds)
@@ -1688,14 +1698,27 @@ int rselect(int nfds, fd_set *readfds, fd_set *writefds,
 int rshutdown(int socket, int how)
 {
        struct rsocket *rs;
-       int ret = 0;
+       int ctrl, ret = 0;
 
        rs = idm_at(&idm, socket);
+       if (how == SHUT_RD) {
+               rs->state &= ~rs_connect_rd;
+               if (rs->state == rs_connected)
+                       rs->state = rs_disconnected;
+               return 0;
+       }
+
        if (rs->fd_flags & O_NONBLOCK)
                rs_set_nonblocking(rs, 0);
 
-       if (rs->state == rs_connected) {
-               rs->state = rs_disconnected;
+       if (rs->state & rs_connected) {
+               if (how == SHUT_RDWR) {
+                       ctrl = RS_CTRL_DISCONNECT;
+                       rs->state = rs_disconnected;
+               } else {
+                       ctrl = RS_CTRL_SHUTDOWN;
+                       rs->state &= ~rs_connect_wr;
+               }
                if (!rs_can_send_ctrl(rs)) {
                        ret = rs_process_cq(rs, 0, rs_can_send_ctrl);
                        if (ret)
@@ -1704,13 +1727,16 @@ int rshutdown(int socket, int how)
 
                rs->ctrl_avail--;
                ret = rs_post_write(rs, 0, NULL, 0,
-                                   rs_msg_set(RS_OP_CTRL, RS_CTRL_DISCONNECT),
+                                   rs_msg_set(RS_OP_CTRL, ctrl),
                                    0, 0, 0);
        }
 
-       if (!rs_all_sends_done(rs) && rs->state != rs_error)
+       if (!rs_all_sends_done(rs) && !(rs->state & rs_error))
                rs_process_cq(rs, 0, rs_all_sends_done);
 
+       if ((rs->fd_flags & O_NONBLOCK) && (how == SHUT_WR))
+               rs_set_nonblocking(rs, 1);
+
        return 0;
 }
 
@@ -1719,7 +1745,7 @@ int rclose(int socket)
        struct rsocket *rs;
 
        rs = idm_at(&idm, socket);
-       if (rs->state == rs_connected)
+       if (rs->state & rs_connected)
                rshutdown(socket, SHUT_RDWR);
 
        rs_free(rs);
@@ -1830,8 +1856,9 @@ int rsetsockopt(int socket, int level, int optname,
                default:
                        break;
                }
+               break;
        case SOL_RDMA:
-               if (rs->state > rs_listening) {
+               if (rs->state >= rs_opening) {
                        ret = ERR(EINVAL);
                        break;
                }