#define rs_msg_data(imm_data) (imm_data & 0x1FFFFFFF)
enum {
- RS_CTRL_DISCONNECT
+ RS_CTRL_DISCONNECT,
+ RS_CTRL_SHUTDOWN
};
struct rs_msg {
*/
enum rs_state {
rs_init,
- rs_bound,
- rs_listening,
- rs_resolving_addr,
- rs_resolving_route,
- rs_connecting,
- rs_accepting,
- rs_connect_error,
- rs_connected,
- rs_disconnected,
- rs_error
+ rs_bound = 0x0001,
+ rs_listening = 0x0002,
+ rs_opening = 0x0004,
+ rs_resolving_addr = rs_opening | 0x0010,
+ rs_resolving_route = rs_opening | 0x0020,
+ rs_connecting = rs_opening | 0x0040,
+ rs_accepting = rs_opening | 0x0080,
+ rs_connected = 0x0100,
+ rs_connect_wr = 0x0200,
+ rs_connect_rd = 0x0400,
+ rs_connect_rdwr = rs_connected | rs_connect_rd | rs_connect_wr,
+ rs_connect_error = 0x0800,
+ rs_disconnected = 0x1000,
+ rs_error = 0x2000,
};
#define RS_OPT_SWAP_SGL 1
long fd_flags;
uint64_t so_opts;
uint64_t tcp_opts;
- enum rs_state state;
+ int state;
int cq_armed;
int retries;
int err;
if (rs->cm_id->recv_cq_channel)
ret = fcntl(rs->cm_id->recv_cq_channel->fd, F_SETFL, arg);
- if (!ret && rs->state != rs_connected)
+ if (!ret && rs->state < rs_connected)
ret = fcntl(rs->cm_id->channel->fd, F_SETFL, arg);
return ret;
rs_set_conn_data(new_rs, ¶m, &cresp);
ret = rdma_accept(new_rs->cm_id, ¶m);
if (!ret)
- new_rs->state = rs_connected;
+ new_rs->state = rs_connect_rdwr;
else if (errno == EAGAIN || errno == EWOULDBLOCK)
new_rs->state = rs_accepting;
else
}
rs_save_conn_data(rs, cresp);
- rs->state = rs_connected;
+ rs->state = rs_connect_rdwr;
break;
case rs_accepting:
if (!(rs->fd_flags & O_NONBLOCK))
if (ret)
break;
- rs->state = rs_connected;
+ rs->state = rs_connect_rdwr;
break;
default:
ret = ERR(EINVAL);
return rs_do_connect(rs);
}
+static void rs_shutdown_state(struct rsocket *rs, int state)
+{
+ rs->state &= ~state;
+ if (rs->state == rs_connected)
+ rs->state = rs_disconnected;
+}
+
static int rs_post_write(struct rsocket *rs, uint64_t wr_id,
struct ibv_sge *sgl, int nsge,
uint32_t imm_data, int flags,
{
return ((rs->rbuf_bytes_avail >= (rs->rbuf_size >> 1)) ||
((short) ((short) rs->rseq_no - (short) rs->rseq_comp) >= 0)) &&
- rs->ctrl_avail && (rs->state == rs_connected);
+ rs->ctrl_avail && (rs->state & rs_connected);
}
static void rs_update_credits(struct rsocket *rs)
case RS_OP_CTRL:
if (rs_msg_data(imm_data) == RS_CTRL_DISCONNECT) {
rs->state = rs_disconnected;
- return ERR(ECONNRESET);
+ return 0;
+ } else if (rs_msg_data(imm_data) == RS_CTRL_SHUTDOWN) {
+ rs_shutdown_state(rs, rs_connect_rd);
}
break;
default:
} else {
rs->ctrl_avail++;
}
- if (wc.status != IBV_WC_SUCCESS && rs->state == rs_connected) {
+ if (wc.status != IBV_WC_SUCCESS && (rs->state & rs_connected)) {
rs->state = rs_error;
rs->err = EIO;
}
}
}
- if (rs->state == rs_connected) {
+ if (rs->state & rs_connected) {
while (!ret && rcnt--)
ret = rdma_post_recvv(rs->cm_id, NULL, NULL, 0);
if (!ret) {
ibv_ack_cq_events(rs->cm_id->recv_cq, 1);
rs->cq_armed = 0;
- } else if (errno != EAGAIN && rs->state == rs_connected) {
+ } else if (errno != EAGAIN) {
rs->state = rs_error;
}
static int rs_conn_can_send(struct rsocket *rs)
{
- return rs_can_send(rs) || (rs->state != rs_connected);
+ return rs_can_send(rs) || !(rs->state & rs_connect_wr);
}
static int rs_can_send_ctrl(struct rsocket *rs)
static int rs_conn_have_rdata(struct rsocket *rs)
{
- return rs_have_rdata(rs) || (rs->state != rs_connected);
+ return rs_have_rdata(rs) || !(rs->state & rs_connect_rd);
}
static int rs_all_sends_done(struct rsocket *rs)
int ret;
rs = idm_at(&idm, socket);
- if (rs->state < rs_connected) {
+ if (rs->state & rs_opening) {
ret = rs_do_connect(rs);
if (ret) {
if (errno == EINPROGRESS)
fastlock_acquire(&rs->rlock);
if (!rs_have_rdata(rs)) {
ret = rs_get_comp(rs, rs_nonblocking(rs, flags), rs_conn_have_rdata);
- if (ret && errno != ECONNRESET)
+ if (ret)
goto out;
}
int ret = 0;
rs = idm_at(&idm, socket);
- if (rs->state < rs_connected) {
+ if (rs->state & rs_opening) {
ret = rs_do_connect(rs);
if (ret) {
if (errno == EINPROGRESS)
rs_conn_can_send);
if (ret)
break;
- if (rs->state != rs_connected) {
+ if (!(rs->state & rs_connect_wr)) {
ret = ERR(ECONNRESET);
break;
}
int i, ret = 0;
rs = idm_at(&idm, socket);
- if (rs->state < rs_connected) {
+ if (rs->state & rs_opening) {
ret = rs_do_connect(rs);
if (ret) {
if (errno == EINPROGRESS)
rs_conn_can_send);
if (ret)
break;
- if (rs->state != rs_connected) {
+ if (!(rs->state & rs_connect_wr)) {
ret = ERR(ECONNRESET);
break;
}
short revents;
int ret;
- switch (rs->state) {
- case rs_listening:
+check_cq:
+ if ((rs->state & rs_connected) || (rs->state == rs_disconnected) ||
+ (rs->state & rs_error)) {
+ rs_process_cq(rs, nonblock, test);
+
+ revents = 0;
+ if ((events & POLLIN) && rs_conn_have_rdata(rs))
+ revents |= POLLIN;
+ if ((events & POLLOUT) && rs_can_send(rs))
+ revents |= POLLOUT;
+ if (!(rs->state & rs_connected)) {
+ if (rs->state == rs_disconnected)
+ revents |= POLLHUP;
+ else
+ revents |= POLLERR;
+ }
+
+ return revents;
+ }
+
+ if (rs->state == rs_listening) {
fds.fd = rs->cm_id->channel->fd;
fds.events = events;
fds.revents = 0;
poll(&fds, 1, 0);
return fds.revents;
- case rs_resolving_addr:
- case rs_resolving_route:
- case rs_connecting:
- case rs_accepting:
+ }
+
+ if (rs->state & rs_opening) {
ret = rs_do_connect(rs);
if (ret) {
if (errno == EINPROGRESS) {
return POLLOUT;
}
}
- /* fall through */
- case rs_connected:
- case rs_disconnected:
- case rs_error:
- rs_process_cq(rs, nonblock, test);
-
- revents = 0;
- if ((events & POLLIN) && rs_have_rdata(rs))
- revents |= POLLIN;
- if ((events & POLLOUT) && rs_can_send(rs))
- revents |= POLLOUT;
- if (rs->state == rs_disconnected)
- revents |= POLLHUP;
- else if (rs->state == rs_error)
- revents |= POLLERR;
+ goto check_cq;
+ }
- return revents;
- case rs_connect_error:
+ if (rs->state == rs_connect_error)
return (rs->err && events & POLLOUT) ? POLLOUT : 0;
- default:
- return 0;
- }
+
+ return 0;
}
static int rs_poll_check(struct pollfd *fds, nfds_t nfds)
int rshutdown(int socket, int how)
{
struct rsocket *rs;
- int ret = 0;
+ int ctrl, ret = 0;
rs = idm_at(&idm, socket);
+ if (how == SHUT_RD) {
+ rs_shutdown_state(rs, rs_connect_rd);
+ return 0;
+ }
+
if (rs->fd_flags & O_NONBLOCK)
rs_set_nonblocking(rs, 0);
- if (rs->state == rs_connected) {
- rs->state = rs_disconnected;
+ if (rs->state & rs_connected) {
+ if (how == SHUT_RDWR) {
+ ctrl = RS_CTRL_DISCONNECT;
+ rs->state = rs_disconnected;
+ } else {
+ rs_shutdown_state(rs, rs_connect_wr);
+ ctrl = (rs->state & rs_connected) ?
+ RS_CTRL_SHUTDOWN : RS_CTRL_DISCONNECT;
+ }
if (!rs_can_send_ctrl(rs)) {
ret = rs_process_cq(rs, 0, rs_can_send_ctrl);
if (ret)
rs->ctrl_avail--;
ret = rs_post_write(rs, 0, NULL, 0,
- rs_msg_set(RS_OP_CTRL, RS_CTRL_DISCONNECT),
+ rs_msg_set(RS_OP_CTRL, ctrl),
0, 0, 0);
}
- if (!rs_all_sends_done(rs) && rs->state != rs_error)
+ if (!rs_all_sends_done(rs) && !(rs->state & rs_error))
rs_process_cq(rs, 0, rs_all_sends_done);
+ if ((rs->fd_flags & O_NONBLOCK) && (rs->state & rs_connected))
+ rs_set_nonblocking(rs, 1);
+
return 0;
}
struct rsocket *rs;
rs = idm_at(&idm, socket);
- if (rs->state == rs_connected)
+ if (rs->state & rs_connected)
rshutdown(socket, SHUT_RDWR);
rs_free(rs);
}
break;
case SOL_RDMA:
- if (rs->state > rs_listening) {
+ if (rs->state >= rs_opening) {
ret = ERR(EINVAL);
break;
}