From bc7d39b5bd7e376900e5b0d0024e5b9a521697c7 Mon Sep 17 00:00:00 2001 From: Sean Hefty Date: Fri, 9 Nov 2012 10:26:38 -0800 Subject: [PATCH] rsocket: Add datagram support Signed-off-by: Sean Hefty --- docs/rsocket | 37 +- src/cma.c | 14 +- src/cma.h | 2 + src/rsocket.c | 1670 ++++++++++++++++++++++++++++++++++++++++++++----- 4 files changed, 1570 insertions(+), 153 deletions(-) diff --git a/docs/rsocket b/docs/rsocket index 1484f65b..a6602084 100644 --- a/docs/rsocket +++ b/docs/rsocket @@ -1,7 +1,7 @@ -rsocket Protocol and Design Guide 9/10/2012 +rsocket Protocol and Design Guide 11/11/2012 -Overview --------- +Data Streaming (TCP) Overview +----------------------------- Rsockets is a protocol over RDMA that supports a socket-level API for applications. For details on the current state of the implementation, readers should refer to the rsocket man page. This @@ -189,3 +189,34 @@ registered remote data buffer. From host A's perspective, the transfer appears as a normal send/write operation, with the data stream redirected directly into the receiving application's buffer. + + + +Datagram Overview +----------------- +The rsocket API supports datagram sockets. Datagram support is handled through an +entirely different protocol and internal implementation. Unlike connected rsockets, +datagram rsockets are not necessarily bound to a network (IP) address. A datagram +socket may use any number of network (IP) addresses, including those which map to +different RDMA devices. As a result, a single datagram rsocket must support +using multiple RDMA devices and ports, and a datagram rsocket references a single +UDP socket, plus zero or more UD QPs. + +Rsockets uses headers inserted before user data sent over UDP sockets to resolve +remote UD QP numbers. When a user first attempts to send a datagram to a remote +address (IP and UDP port), rsockets will take the following steps: + +1. Store the destination address into a lookup table. +2. Resolve which local network address should be used when sending + to the specified destination. +3. Allocate a UD QP on the RDMA device associated with the local address. +4. Send the user's datagram to the remote UDP socket. + +A header is inserted before the user's datagram. The header specifies the +UD QP number associated with the local network address (IP and UDP port) of +the send. + +A service thread is used to process messages received on the UDP socket. This +thread updates the rsocket lookup tables with the remote QPN and path record +data. The service thread forwards data received on the UDP socket to an +rsocket QP. \ No newline at end of file diff --git a/src/cma.c b/src/cma.c index 388be617..ff9b426c 100755 --- a/src/cma.c +++ b/src/cma.c @@ -513,7 +513,7 @@ int rdma_destroy_id(struct rdma_cm_id *id) return 0; } -static int ucma_addrlen(struct sockaddr *addr) +int ucma_addrlen(struct sockaddr *addr) { if (!addr) return 0; @@ -2232,9 +2232,19 @@ void rdma_destroy_ep(struct rdma_cm_id *id) int ucma_max_qpsize(struct rdma_cm_id *id) { struct cma_id_private *id_priv; + int i, max_size = 0; id_priv = container_of(id, struct cma_id_private, id); - return id_priv->cma_dev->max_qpsize; + if (id && id_priv->cma_dev) { + max_size = id_priv->cma_dev->max_qpsize; + } else { + ucma_init(); + for (i = 0; i < cma_dev_cnt; i++) { + if (!max_size || max_size > cma_dev_array[i].max_qpsize) + max_size = cma_dev_array[i].max_qpsize; + } + } + return max_size; } uint16_t ucma_get_port(struct sockaddr *addr) diff --git a/src/cma.h b/src/cma.h index 0a0370ee..7135a612 100644 --- a/src/cma.h +++ b/src/cma.h @@ -145,10 +145,12 @@ typedef struct { volatile int val; } atomic_t; #define atomic_set(v, s) ((v)->val = s) uint16_t ucma_get_port(struct sockaddr *addr); +int ucma_addrlen(struct sockaddr *addr); void ucma_set_sid(enum rdma_port_space ps, struct sockaddr *addr, struct sockaddr_ib *sib); int ucma_max_qpsize(struct rdma_cm_id *id); int ucma_complete(struct rdma_cm_id *id); + static inline int ERR(int err) { errno = err; diff --git a/src/rsocket.c b/src/rsocket.c index a060f66a..219aa4a2 100644 --- a/src/rsocket.c +++ b/src/rsocket.c @@ -47,6 +47,8 @@ #include #include #include +#include +#include #include #include @@ -56,7 +58,7 @@ #define RS_OLAP_START_SIZE 2048 #define RS_MAX_TRANSFER 65536 -#define RS_SNDLOWAT 64 +#define RS_SNDLOWAT 2048 #define RS_QP_MAX_SIZE 0xFFFE #define RS_QP_CTRL_SIZE 4 #define RS_CONN_RETRIES 6 @@ -64,6 +66,27 @@ static struct index_map idm; static pthread_mutex_t mut = PTHREAD_MUTEX_INITIALIZER; +struct rsocket; + +enum { + RS_SVC_DGRAM = 1 << 0 +}; + +struct rs_svc_msg { + uint32_t svcs; + uint32_t status; + struct rsocket *rs; +}; + +static pthread_t svc_id; +static int svc_sock[2]; +static int svc_cnt; +static int svc_size; +static struct rsocket **svc_rss; +static struct pollfd *svc_fds; +static uint8_t svc_buf[RS_SNDLOWAT]; +static void *rs_svc_run(void *arg); + static uint16_t def_iomap_size = 0; static uint16_t def_inline = 64; static uint16_t def_sqsize = 384; @@ -100,6 +123,14 @@ enum { #define rs_msg_set(op, data) ((op << 29) | (uint32_t) (data)) #define rs_msg_op(imm_data) (imm_data >> 29) #define rs_msg_data(imm_data) (imm_data & 0x1FFFFFFF) +#define RS_RECV_WR_ID (~((uint64_t) 0)) + +#define DS_WR_RECV 0xFFFFFFFF +#define ds_send_wr_id(offset, length) (((uint64_t) (offset)) << 32 | (uint64_t) length) +#define ds_recv_wr_id(offset) (((uint64_t) (offset)) << 32 | (uint64_t) DS_WR_RECV) +#define ds_wr_offset(wr_id) ((uint32_t) (wr_id >> 32)) +#define ds_wr_length(wr_id) ((uint32_t) wr_id) +#define ds_wr_is_recv(wr_id) (ds_wr_length(wr_id) == DS_WR_RECV) enum { RS_CTRL_DISCONNECT, @@ -111,6 +142,18 @@ struct rs_msg { uint32_t data; }; +struct ds_qp; + +struct ds_rmsg { + struct ds_qp *qp; + uint32_t offset; + uint32_t length; +}; + +struct ds_smsg { + struct ds_smsg *next; +}; + struct rs_sge { uint64_t addr; uint32_t key; @@ -145,8 +188,6 @@ struct rs_conn_data { struct rs_sge data_buf; }; -#define RS_RECV_WR_ID (~((uint64_t) 0)) - /* * rsocket states are ordered as passive, connecting, connected, disconnected. */ @@ -160,9 +201,9 @@ enum rs_state { rs_connecting = rs_opening | 0x0040, rs_accepting = rs_opening | 0x0080, rs_connected = 0x0100, - rs_connect_wr = 0x0200, - rs_connect_rd = 0x0400, - rs_connect_rdwr = rs_connected | rs_connect_rd | rs_connect_wr, + rs_writable = 0x0200, + rs_readable = 0x0400, + rs_connect_rdwr = rs_connected | rs_readable | rs_writable, rs_connect_error = 0x0800, rs_disconnected = 0x1000, rs_error = 0x2000, @@ -170,68 +211,251 @@ enum rs_state { #define RS_OPT_SWAP_SGL 1 -struct rsocket { +union socket_addr { + struct sockaddr sa; + struct sockaddr_in sin; + struct sockaddr_in6 sin6; +}; + +struct ds_header { + uint8_t version; + uint8_t length; + uint16_t port; + union { + uint32_t ipv4; + struct { + uint32_t flowinfo; + uint8_t addr[16]; + } ipv6; + } addr; +}; + +#define DS_IPV4_HDR_LEN 8 +#define DS_IPV6_HDR_LEN 24 + +struct ds_dest { + union socket_addr addr; /* must be first */ + struct ds_qp *qp; + struct ibv_ah *ah; + uint32_t qpn; +}; + +struct ds_qp { + dlist_entry list; + struct rsocket *rs; struct rdma_cm_id *cm_id; + struct ds_header hdr; + struct ds_dest dest; + + struct ibv_mr *smr; + struct ibv_mr *rmr; + uint8_t *rbuf; + + int cq_armed; +}; + +struct rsocket { + int type; + int index; fastlock_t slock; fastlock_t rlock; fastlock_t cq_lock; fastlock_t cq_wait_lock; - fastlock_t iomap_lock; - + fastlock_t map_lock; /* acquire slock first if needed */ + + union { + /* data stream */ + struct { + struct rdma_cm_id *cm_id; + uint64_t tcp_opts; + + int ctrl_avail; + uint16_t sseq_no; + uint16_t sseq_comp; + uint16_t rseq_no; + uint16_t rseq_comp; + + int remote_sge; + struct rs_sge remote_sgl; + struct rs_sge remote_iomap; + + struct ibv_mr *target_mr; + int target_sge; + int target_iomap_size; + void *target_buffer_list; + volatile struct rs_sge *target_sgl; + struct rs_iomap *target_iomap; + + int rbuf_bytes_avail; + int rbuf_free_offset; + int rbuf_offset; + struct ibv_mr *rmr; + uint8_t *rbuf; + + int sbuf_bytes_avail; + struct ibv_mr *smr; + struct ibv_sge ssgl[2]; + }; + /* datagram */ + struct { + struct ds_qp *qp_list; + void *dest_map; + struct ds_dest *conn_dest; + + int udp_sock; + int epfd; + int rqe_avail; + struct ds_smsg *smsg_free; + }; + }; + + int svcs; int opts; long fd_flags; uint64_t so_opts; - uint64_t tcp_opts; uint64_t ipv6_opts; int state; int cq_armed; int retries; int err; - int index; - int ctrl_avail; + int sqe_avail; - int sbuf_bytes_avail; - uint16_t sseq_no; - uint16_t sseq_comp; + uint32_t sbuf_size; uint16_t sq_size; uint16_t sq_inline; + uint32_t rbuf_size; uint16_t rq_size; - uint16_t rseq_no; - uint16_t rseq_comp; - int rbuf_bytes_avail; - int rbuf_free_offset; - int rbuf_offset; int rmsg_head; int rmsg_tail; - struct rs_msg *rmsg; - - int remote_sge; - struct rs_sge remote_sgl; - struct rs_sge remote_iomap; + union { + struct rs_msg *rmsg; + struct ds_rmsg *dmsg; + }; + uint8_t *sbuf; struct rs_iomap_mr *remote_iomappings; dlist_entry iomap_list; dlist_entry iomap_queue; int iomap_pending; +}; - struct ibv_mr *target_mr; - int target_sge; - int target_iomap_size; - void *target_buffer_list; - volatile struct rs_sge *target_sgl; - struct rs_iomap *target_iomap; +#define DS_UDP_TAG 0x55555555 - uint32_t rbuf_size; - struct ibv_mr *rmr; - uint8_t *rbuf; - - uint32_t sbuf_size; - struct ibv_mr *smr; - struct ibv_sge ssgl[2]; - uint8_t *sbuf; +struct ds_udp_header { + uint32_t tag; + uint8_t version; + uint8_t op; + uint8_t length; + uint8_t reserved; + uint32_t qpn; /* lower 8-bits reserved */ + union { + uint32_t ipv4; + uint8_t ipv6[16]; + } addr; }; +#define DS_UDP_IPV4_HDR_LEN 16 +#define DS_UDP_IPV6_HDR_LEN 28 + +#define ds_next_qp(qp) container_of((qp)->list.next, struct ds_qp, list) + +static void ds_insert_qp(struct rsocket *rs, struct ds_qp *qp) +{ + if (!rs->qp_list) + dlist_init(&qp->list); + else + dlist_insert_head(&qp->list, &rs->qp_list->list); + rs->qp_list = qp; +} + +static void ds_remove_qp(struct rsocket *rs, struct ds_qp *qp) +{ + if (qp->list.next != &qp->list) { + rs->qp_list = ds_next_qp(qp); + dlist_remove(&qp->list); + } else { + rs->qp_list = NULL; + } +} + +static int rs_modify_svcs(struct rsocket *rs, int svcs) +{ + struct rs_svc_msg msg; + int ret; + + pthread_mutex_lock(&mut); + if (!svc_cnt) { + ret = socketpair(AF_UNIX, SOCK_STREAM, 0, svc_sock); + if (ret) + goto unlock; + + ret = pthread_create(&svc_id, NULL, rs_svc_run, NULL); + if (ret) { + ret = ERR(ret); + goto closepair; + } + } + + msg.svcs = svcs; + msg.status = EINVAL; + msg.rs = rs; + write(svc_sock[0], &msg, sizeof msg); + read(svc_sock[0], &msg, sizeof msg); + ret = rdma_seterrno(msg.status); + if (svc_cnt) + goto unlock; +// if (ret && !svc_cnt) +// goto join; +// +// pthread_mutex_unlock(&mut); +// return ret; + + pthread_join(svc_id, NULL); +closepair: + close(svc_sock[0]); + close(svc_sock[1]); +unlock: + pthread_mutex_unlock(&mut); + return ret; +} + +//static void rs_remove_from_svc(struct rsocket *rs) +//{ +// struct rs_svc_msg msg; +// int ret; +// +// pthread_mutex_lock(&mut); +// if (svc_cnt) { +// msg.op = RS_SVC_REMOVE; +// msg.status = EINVAL; +// msg.rs = rs; +// write(svc_sock[0], &msg, sizeof msg); +// read(svc_sock[0], &msg, sizeof msg); +// } +// +// if (!svc_cnt) { +// pthread_join(svc_id, NULL); +// close(svc_sock[0]); +// close(svc_sock[1]); +// } +// +// pthread_mutex_unlock(&mut); +//} + +static int ds_compare_addr(const void *dst1, const void *dst2) +{ + const struct sockaddr *sa1, *sa2; + size_t len; + + sa1 = (const struct sockaddr *) dst1; + sa2 = (const struct sockaddr *) dst2; + + len = (sa1->sa_family == AF_INET6 && sa2->sa_family == AF_INET6) ? + sizeof(struct sockaddr_in6) : sizeof(struct sockaddr_in); + return memcmp(dst1, dst2, len); +} + static int rs_value_to_scale(int value, int bits) { return value <= (1 << (bits - 1)) ? @@ -307,10 +531,10 @@ out: pthread_mutex_unlock(&mut); } -static int rs_insert(struct rsocket *rs) +static int rs_insert(struct rsocket *rs, int index) { pthread_mutex_lock(&mut); - rs->index = idm_set(&idm, rs->cm_id->channel->fd, rs); + rs->index = idm_set(&idm, index, rs); pthread_mutex_unlock(&mut); return rs->index; } @@ -322,7 +546,7 @@ static void rs_remove(struct rsocket *rs) pthread_mutex_unlock(&mut); } -static struct rsocket *rs_alloc(struct rsocket *inherited_rs) +static struct rsocket *rs_alloc(struct rsocket *inherited_rs, int type) { struct rsocket *rs; @@ -330,29 +554,39 @@ static struct rsocket *rs_alloc(struct rsocket *inherited_rs) if (!rs) return NULL; + rs->type = type; rs->index = -1; + if (type == SOCK_DGRAM) { + rs->udp_sock = -1; + rs->epfd = -1; + } + if (inherited_rs) { rs->sbuf_size = inherited_rs->sbuf_size; rs->rbuf_size = inherited_rs->rbuf_size; rs->sq_inline = inherited_rs->sq_inline; rs->sq_size = inherited_rs->sq_size; rs->rq_size = inherited_rs->rq_size; - rs->ctrl_avail = inherited_rs->ctrl_avail; - rs->target_iomap_size = inherited_rs->target_iomap_size; + if (type == SOCK_STREAM) { + rs->ctrl_avail = inherited_rs->ctrl_avail; + rs->target_iomap_size = inherited_rs->target_iomap_size; + } } else { rs->sbuf_size = def_wmem; rs->rbuf_size = def_mem; rs->sq_inline = def_inline; rs->sq_size = def_sqsize; rs->rq_size = def_rqsize; - rs->ctrl_avail = RS_QP_CTRL_SIZE; - rs->target_iomap_size = def_iomap_size; + if (type == SOCK_STREAM) { + rs->ctrl_avail = RS_QP_CTRL_SIZE; + rs->target_iomap_size = def_iomap_size; + } } fastlock_init(&rs->slock); fastlock_init(&rs->rlock); fastlock_init(&rs->cq_lock); fastlock_init(&rs->cq_wait_lock); - fastlock_init(&rs->iomap_lock); + fastlock_init(&rs->map_lock); dlist_init(&rs->iomap_list); dlist_init(&rs->iomap_queue); return rs; @@ -360,13 +594,26 @@ static struct rsocket *rs_alloc(struct rsocket *inherited_rs) static int rs_set_nonblocking(struct rsocket *rs, long arg) { + struct ds_qp *qp; int ret = 0; - if (rs->cm_id->recv_cq_channel) - ret = fcntl(rs->cm_id->recv_cq_channel->fd, F_SETFL, arg); + if (rs->type == SOCK_STREAM) { + if (rs->cm_id->recv_cq_channel) + ret = fcntl(rs->cm_id->recv_cq_channel->fd, F_SETFL, arg); - if (!ret && rs->state < rs_connected) - ret = fcntl(rs->cm_id->channel->fd, F_SETFL, arg); + if (!ret && rs->state < rs_connected) + ret = fcntl(rs->cm_id->channel->fd, F_SETFL, arg); + } else { + ret = fcntl(rs->epfd, F_SETFL, arg); + if (!ret && rs->qp_list) { + qp = rs->qp_list; + do { + ret = fcntl(qp->cm_id->recv_cq_channel->fd, + F_SETFL, arg); + qp = ds_next_qp(qp); + } while (qp != rs->qp_list && !ret); + } + } return ret; } @@ -390,17 +637,39 @@ static void rs_set_qp_size(struct rsocket *rs) rs->rq_size = 2; } +static void ds_set_qp_size(struct rsocket *rs) +{ + uint16_t max_size; + + max_size = min(ucma_max_qpsize(NULL), RS_QP_MAX_SIZE); + + if (rs->sq_size > max_size) + rs->sq_size = max_size; + if (rs->rq_size > max_size) + rs->rq_size = max_size; + + if (rs->rq_size > (rs->rbuf_size / RS_SNDLOWAT)) + rs->rq_size = rs->rbuf_size / RS_SNDLOWAT; + else + rs->rbuf_size = rs->rq_size * RS_SNDLOWAT; + + if (rs->sq_size > (rs->sbuf_size / RS_SNDLOWAT)) + rs->sq_size = rs->sbuf_size / RS_SNDLOWAT; + else + rs->sbuf_size = rs->sq_size * RS_SNDLOWAT; +} + static int rs_init_bufs(struct rsocket *rs) { size_t len; rs->rmsg = calloc(rs->rq_size + 1, sizeof(*rs->rmsg)); if (!rs->rmsg) - return -1; + return ERR(ENOMEM); rs->sbuf = calloc(rs->sbuf_size, sizeof(*rs->sbuf)); if (!rs->sbuf) - return -1; + return ERR(ENOMEM); rs->smr = rdma_reg_msgs(rs->cm_id, rs->sbuf, rs->sbuf_size); if (!rs->smr) @@ -410,7 +679,7 @@ static int rs_init_bufs(struct rsocket *rs) sizeof(*rs->target_iomap) * rs->target_iomap_size; rs->target_buffer_list = malloc(len); if (!rs->target_buffer_list) - return -1; + return ERR(ENOMEM); rs->target_mr = rdma_reg_write(rs->cm_id, rs->target_buffer_list, len); if (!rs->target_mr) @@ -423,7 +692,7 @@ static int rs_init_bufs(struct rsocket *rs) rs->rbuf = calloc(rs->rbuf_size, sizeof(*rs->rbuf)); if (!rs->rbuf) - return -1; + return ERR(ENOMEM); rs->rmr = rdma_reg_write(rs->cm_id, rs->rbuf, rs->rbuf_size); if (!rs->rmr) @@ -440,37 +709,57 @@ static int rs_init_bufs(struct rsocket *rs) return 0; } -static int rs_create_cq(struct rsocket *rs) +static int ds_init_bufs(struct ds_qp *qp) +{ + qp->rbuf = calloc(qp->rs->rbuf_size + sizeof(struct ibv_grh), + sizeof(*qp->rbuf)); + if (!qp->rbuf) + return ERR(ENOMEM); + + qp->smr = rdma_reg_msgs(qp->cm_id, qp->rs->sbuf, qp->rs->sbuf_size); + if (!qp->smr) + return -1; + + qp->rmr = rdma_reg_msgs(qp->cm_id, qp->rbuf, qp->rs->rbuf_size + + sizeof(struct ibv_grh)); + if (!qp->rmr) + return -1; + + return 0; +} + +static int rs_create_cq(struct rsocket *rs, struct rdma_cm_id *cm_id) { - rs->cm_id->recv_cq_channel = ibv_create_comp_channel(rs->cm_id->verbs); - if (!rs->cm_id->recv_cq_channel) + cm_id->recv_cq_channel = ibv_create_comp_channel(cm_id->verbs); + if (!cm_id->recv_cq_channel) return -1; - rs->cm_id->recv_cq = ibv_create_cq(rs->cm_id->verbs, rs->sq_size + rs->rq_size, - rs->cm_id, rs->cm_id->recv_cq_channel, 0); - if (!rs->cm_id->recv_cq) + cm_id->recv_cq = ibv_create_cq(cm_id->verbs, rs->sq_size + rs->rq_size, + cm_id, cm_id->recv_cq_channel, 0); + if (!cm_id->recv_cq) goto err1; if (rs->fd_flags & O_NONBLOCK) { - if (rs_set_nonblocking(rs, O_NONBLOCK)) + if (fcntl(cm_id->recv_cq_channel->fd, F_SETFL, O_NONBLOCK)) goto err2; + } else { + ibv_req_notify_cq(cm_id->recv_cq, 0); } - rs->cm_id->send_cq_channel = rs->cm_id->recv_cq_channel; - rs->cm_id->send_cq = rs->cm_id->recv_cq; + cm_id->send_cq_channel = cm_id->recv_cq_channel; + cm_id->send_cq = cm_id->recv_cq; return 0; err2: - ibv_destroy_cq(rs->cm_id->recv_cq); - rs->cm_id->recv_cq = NULL; + ibv_destroy_cq(cm_id->recv_cq); + cm_id->recv_cq = NULL; err1: - ibv_destroy_comp_channel(rs->cm_id->recv_cq_channel); - rs->cm_id->recv_cq_channel = NULL; + ibv_destroy_comp_channel(cm_id->recv_cq_channel); + cm_id->recv_cq_channel = NULL; return -1; } -static inline int -rs_post_recv(struct rsocket *rs) +static inline int rs_post_recv(struct rsocket *rs) { struct ibv_recv_wr wr, *bad; @@ -482,6 +771,26 @@ rs_post_recv(struct rsocket *rs) return rdma_seterrno(ibv_post_recv(rs->cm_id->qp, &wr, &bad)); } +static inline int ds_post_recv(struct rsocket *rs, struct ds_qp *qp, uint32_t offset) +{ + struct ibv_recv_wr wr, *bad; + struct ibv_sge sge[2]; + + sge[0].addr = (uintptr_t) qp->rbuf + rs->rbuf_size; + sge[0].length = sizeof(struct ibv_grh); + sge[0].lkey = qp->rmr->lkey; + sge[1].addr = (uintptr_t) qp->rbuf + offset; + sge[1].length = RS_SNDLOWAT; + sge[1].lkey = qp->rmr->lkey; + + wr.wr_id = ds_recv_wr_id(offset); + wr.next = NULL; + wr.sg_list = sge; + wr.num_sge = 2; + + return rdma_seterrno(ibv_post_recv(qp->cm_id->qp, &wr, &bad)); +} + static int rs_create_ep(struct rsocket *rs) { struct ibv_qp_init_attr qp_attr; @@ -492,7 +801,7 @@ static int rs_create_ep(struct rsocket *rs) if (ret) return ret; - ret = rs_create_cq(rs); + ret = rs_create_cq(rs, rs->cm_id); if (ret) return ret; @@ -549,8 +858,70 @@ static void rs_free_iomappings(struct rsocket *rs) } } +static void ds_free_qp(struct ds_qp *qp) +{ + if (qp->smr) + rdma_dereg_mr(qp->smr); + + if (qp->rbuf) { + if (qp->rmr) + rdma_dereg_mr(qp->rmr); + free(qp->rbuf); + } + + if (qp->cm_id) { + if (qp->cm_id->qp) { + tdelete(&qp->dest.addr, &qp->rs->dest_map, ds_compare_addr); + epoll_ctl(qp->rs->epfd, EPOLL_CTL_DEL, + qp->cm_id->recv_cq_channel->fd, NULL); + rdma_destroy_qp(qp->cm_id); + } + rdma_destroy_id(qp->cm_id); + } + + free(qp); +} + +static void ds_free(struct rsocket *rs) +{ + struct ds_qp *qp; + + if (rs->udp_sock >= 0) + close(rs->udp_sock); + + if (rs->index >= 0) + rs_remove(rs); + + if (rs->dmsg) + free(rs->dmsg); + + while ((qp = rs->qp_list)) { + ds_remove_qp(rs, qp); + ds_free_qp(qp); + } + + if (rs->epfd >= 0) + close(rs->epfd); + + if (rs->sbuf) + free(rs->sbuf); + + tdestroy(rs->dest_map, free); + fastlock_destroy(&rs->map_lock); + fastlock_destroy(&rs->cq_wait_lock); + fastlock_destroy(&rs->cq_lock); + fastlock_destroy(&rs->rlock); + fastlock_destroy(&rs->slock); + free(rs); +} + static void rs_free(struct rsocket *rs) { + if (rs->type == SOCK_DGRAM) { + ds_free(rs); + return; + } + if (rs->index >= 0) rs_remove(rs); @@ -582,7 +953,7 @@ static void rs_free(struct rsocket *rs) rdma_destroy_id(rs->cm_id); } - fastlock_destroy(&rs->iomap_lock); + fastlock_destroy(&rs->map_lock); fastlock_destroy(&rs->cq_wait_lock); fastlock_destroy(&rs->cq_lock); fastlock_destroy(&rs->rlock); @@ -636,29 +1007,88 @@ static void rs_save_conn_data(struct rsocket *rs, struct rs_conn_data *conn) rs->sseq_comp = ntohs(conn->credits); } +static int ds_init(struct rsocket *rs, int domain) +{ + rs->udp_sock = socket(domain, SOCK_DGRAM, 0); + if (rs->udp_sock < 0) + return rs->udp_sock; + + rs->epfd = epoll_create(2); + if (rs->epfd < 0) + return rs->epfd; + + return 0; +} + +static int ds_init_ep(struct rsocket *rs) +{ + struct ds_smsg *msg; + int i, ret; + + ds_set_qp_size(rs); + + rs->sbuf = calloc(rs->sq_size, RS_SNDLOWAT); + if (!rs->sbuf) + return ERR(ENOMEM); + + rs->dmsg = calloc(rs->rq_size + 1, sizeof(*rs->dmsg)); + if (!rs->dmsg) + return ERR(ENOMEM); + + rs->sqe_avail = rs->sq_size; + rs->rqe_avail = rs->rq_size; + + rs->smsg_free = (struct ds_smsg *) rs->sbuf; + msg = rs->smsg_free; + for (i = 0; i < rs->sq_size - 1; i++) { + msg->next = (void *) msg + RS_SNDLOWAT; + msg = msg->next; + } + msg->next = NULL; + + ret = rs_modify_svcs(rs, RS_SVC_DGRAM); + if (ret) + return ret; + + rs->state = rs_readable | rs_writable; + return 0; +} + int rsocket(int domain, int type, int protocol) { struct rsocket *rs; - int ret; + int index, ret; if ((domain != PF_INET && domain != PF_INET6) || - (type != SOCK_STREAM) || (protocol && protocol != IPPROTO_TCP)) + ((type != SOCK_STREAM) && (type != SOCK_DGRAM)) || + (type == SOCK_STREAM && protocol && protocol != IPPROTO_TCP) || + (type == SOCK_DGRAM && protocol && protocol != IPPROTO_UDP)) return ERR(ENOTSUP); rs_configure(); - rs = rs_alloc(NULL); + rs = rs_alloc(NULL, type); if (!rs) return ERR(ENOMEM); - ret = rdma_create_id(NULL, &rs->cm_id, rs, RDMA_PS_TCP); - if (ret) - goto err; + if (type == SOCK_STREAM) { + ret = rdma_create_id(NULL, &rs->cm_id, rs, RDMA_PS_TCP); + if (ret) + goto err; + + rs->cm_id->route.addr.src_addr.sa_family = domain; + index = rs->cm_id->channel->fd; + } else { + ret = ds_init(rs, domain); + if (ret) + goto err; - ret = rs_insert(rs); + index = rs->udp_sock; + } + + ret = rs_insert(rs, index); if (ret < 0) goto err; - rs->cm_id->route.addr.src_addr.sa_family = domain; return rs->index; err: @@ -672,9 +1102,18 @@ int rbind(int socket, const struct sockaddr *addr, socklen_t addrlen) int ret; rs = idm_at(&idm, socket); - ret = rdma_bind_addr(rs->cm_id, (struct sockaddr *) addr); - if (!ret) - rs->state = rs_bound; + if (rs->type == SOCK_STREAM) { + ret = rdma_bind_addr(rs->cm_id, (struct sockaddr *) addr); + if (!ret) + rs->state = rs_bound; + } else { + if (rs->state == rs_init) { + ret = ds_init_ep(rs); + if (ret) + return ret; + } + ret = bind(rs->udp_sock, addr, addrlen); + } return ret; } @@ -710,7 +1149,7 @@ int raccept(int socket, struct sockaddr *addr, socklen_t *addrlen) int ret; rs = idm_at(&idm, socket); - new_rs = rs_alloc(rs); + new_rs = rs_alloc(rs, rs->type); if (!new_rs) return ERR(ENOMEM); @@ -718,7 +1157,7 @@ int raccept(int socket, struct sockaddr *addr, socklen_t *addrlen) if (ret) goto err; - ret = rs_insert(new_rs); + ret = rs_insert(new_rs, new_rs->cm_id->channel->fd); if (ret < 0) goto err; @@ -729,7 +1168,7 @@ int raccept(int socket, struct sockaddr *addr, socklen_t *addrlen) } if (rs->fd_flags & O_NONBLOCK) - rs_set_nonblocking(new_rs, O_NONBLOCK); + fcntl(new_rs->cm_id->channel->fd, F_SETFL, O_NONBLOCK); ret = rs_create_ep(new_rs); if (ret) @@ -831,7 +1270,7 @@ connected: break; case rs_accepting: if (!(rs->fd_flags & O_NONBLOCK)) - rs_set_nonblocking(rs, 0); + fcntl(rs->cm_id->channel->fd, F_SETFL, 0); ret = ucma_complete(rs->cm_id); if (ret) @@ -855,13 +1294,240 @@ connected: return ret; } +static int rs_any_addr(const union socket_addr *addr) +{ + if (addr->sa.sa_family == AF_INET) { + return (addr->sin.sin_addr.s_addr == INADDR_ANY || + addr->sin.sin_addr.s_addr == INADDR_LOOPBACK); + } else { + return (!memcmp(&addr->sin6.sin6_addr, &in6addr_any, 16) || + !memcmp(&addr->sin6.sin6_addr, &in6addr_loopback, 16)); + } +} + +static int ds_get_src_addr(struct rsocket *rs, + const struct sockaddr *dest_addr, socklen_t dest_len, + union socket_addr *src_addr, socklen_t *src_len) +{ + int sock, ret; + uint16_t port; + + *src_len = sizeof *src_addr; + ret = getsockname(rs->udp_sock, &src_addr->sa, src_len); + if (ret || !rs_any_addr(src_addr)) + return ret; + + port = src_addr->sin.sin_port; + sock = socket(dest_addr->sa_family, SOCK_DGRAM, 0); + if (sock < 0) + return sock; + + ret = connect(sock, dest_addr, dest_len); + if (ret) + goto out; + + *src_len = sizeof *src_addr; + ret = getsockname(sock, &src_addr->sa, src_len); + src_addr->sin.sin_port = port; +out: + close(sock); + return ret; +} + +static void ds_format_hdr(struct ds_header *hdr, union socket_addr *addr) +{ + if (addr->sa.sa_family == AF_INET) { + hdr->version = 4; + hdr->length = DS_IPV4_HDR_LEN; + hdr->port = addr->sin.sin_port; + hdr->addr.ipv4 = addr->sin.sin_addr.s_addr; + } else { + hdr->version = 6; + hdr->length = DS_IPV6_HDR_LEN; + hdr->port = addr->sin6.sin6_port; + hdr->addr.ipv6.flowinfo= addr->sin6.sin6_flowinfo; + memcpy(&hdr->addr.ipv6.addr, &addr->sin6.sin6_addr, 16); + } +} + +static int ds_add_qp_dest(struct ds_qp *qp, union socket_addr *addr, + socklen_t addrlen) +{ + struct ibv_port_attr port_attr; + struct ibv_ah_attr attr; + int ret; + + memcpy(&qp->dest.addr, addr, addrlen); + qp->dest.qp = qp; + qp->dest.qpn = qp->cm_id->qp->qp_num; + + ret = ibv_query_port(qp->cm_id->verbs, qp->cm_id->port_num, &port_attr); + if (ret) + return ret; + + memset(&attr, 0, sizeof attr); + attr.dlid = port_attr.lid; + attr.port_num = qp->cm_id->port_num; + qp->dest.ah = ibv_create_ah(qp->cm_id->pd, &attr); + if (!qp->dest.ah) + return ERR(ENOMEM); + + tsearch(&qp->dest.addr, &qp->rs->dest_map, ds_compare_addr); + return 0; +} + +static int ds_create_qp(struct rsocket *rs, union socket_addr *src_addr, + socklen_t addrlen, struct ds_qp **new_qp) +{ + struct ds_qp *qp; + struct ibv_qp_init_attr qp_attr; + struct epoll_event event; + int i, ret; + + qp = calloc(1, sizeof(*qp)); + if (!qp) + return ERR(ENOMEM); + + qp->rs = rs; + ret = rdma_create_id(NULL, &qp->cm_id, qp, RDMA_PS_UDP); + if (ret) + goto err; + + ds_format_hdr(&qp->hdr, src_addr); + ret = rdma_bind_addr(qp->cm_id, &src_addr->sa); + if (ret) + goto err; + + ret = ds_init_bufs(qp); + if (ret) + goto err; + + ret = rs_create_cq(rs, qp->cm_id); + if (ret) + goto err; + + memset(&qp_attr, 0, sizeof qp_attr); + qp_attr.qp_context = qp; + qp_attr.send_cq = qp->cm_id->send_cq; + qp_attr.recv_cq = qp->cm_id->recv_cq; + qp_attr.qp_type = IBV_QPT_UD; + qp_attr.sq_sig_all = 1; + qp_attr.cap.max_send_wr = rs->sq_size; + qp_attr.cap.max_recv_wr = rs->rq_size; + qp_attr.cap.max_send_sge = 1; + qp_attr.cap.max_recv_sge = 2; + qp_attr.cap.max_inline_data = rs->sq_inline; + ret = rdma_create_qp(qp->cm_id, NULL, &qp_attr); + if (ret) + goto err; + + ret = ds_add_qp_dest(qp, src_addr, addrlen); + if (ret) + goto err; + + event.events = EPOLLIN; + event.data.ptr = qp; + ret = epoll_ctl(rs->epfd, EPOLL_CTL_ADD, + qp->cm_id->recv_cq_channel->fd, &event); + if (ret) + goto err; + + for (i = 0; i < rs->rq_size; i++) { + ret = ds_post_recv(rs, qp, i * RS_SNDLOWAT); + if (ret) + goto err; + } + + ds_insert_qp(rs, qp); + *new_qp = qp; + return 0; +err: + ds_free_qp(qp); + return ret; +} + +static int ds_get_qp(struct rsocket *rs, union socket_addr *src_addr, + socklen_t addrlen, struct ds_qp **qp) +{ + if (rs->qp_list) { + *qp = rs->qp_list; + do { + if (!ds_compare_addr(rdma_get_local_addr((*qp)->cm_id), + src_addr)) + return 0; + + *qp = ds_next_qp(*qp); + } while (*qp != rs->qp_list); + } + + return ds_create_qp(rs, src_addr, addrlen, qp); +} + +static int ds_get_dest(struct rsocket *rs, const struct sockaddr *addr, + socklen_t addrlen, struct ds_dest **dest) +{ + union socket_addr src_addr; + socklen_t src_len; + struct ds_qp *qp; + struct ds_dest **tdest, *new_dest; + int ret = 0; + + fastlock_acquire(&rs->map_lock); + tdest = tfind(addr, &rs->dest_map, ds_compare_addr); + if (tdest) + goto found; + + ret = ds_get_src_addr(rs, addr, addrlen, &src_addr, &src_len); + if (ret) + goto out; + + ret = ds_get_qp(rs, &src_addr, src_len, &qp); + if (ret) + goto out; + + tdest = tfind(addr, &rs->dest_map, ds_compare_addr); + if (!tdest) { + new_dest = calloc(1, sizeof(*new_dest)); + if (!new_dest) { + ret = ERR(ENOMEM); + goto out; + } + + memcpy(&new_dest->addr, addr, addrlen); + new_dest->qp = qp; + tdest = tsearch(&new_dest->addr, &rs->dest_map, ds_compare_addr); + } + +found: + *dest = *tdest; +out: + fastlock_release(&rs->map_lock); + return ret; +} + int rconnect(int socket, const struct sockaddr *addr, socklen_t addrlen) { struct rsocket *rs; + int ret; rs = idm_at(&idm, socket); - memcpy(&rs->cm_id->route.addr.dst_addr, addr, addrlen); - return rs_do_connect(rs); + if (rs->type == SOCK_STREAM) { + memcpy(&rs->cm_id->route.addr.dst_addr, addr, addrlen); + ret = rs_do_connect(rs); + } else { + if (rs->state == rs_init) { + ret = ds_init_ep(rs); + if (ret) + return ret; + } + + fastlock_acquire(&rs->slock); + ret = connect(rs->udp_sock, addr, addrlen); + if (!ret) + ret = ds_get_dest(rs, addr, addrlen, &rs->conn_dest); + fastlock_release(&rs->slock); + } + return ret; } static int rs_post_write_msg(struct rsocket *rs, @@ -903,6 +1569,24 @@ static int rs_post_write(struct rsocket *rs, return rdma_seterrno(ibv_post_send(rs->cm_id->qp, &wr, &bad)); } +static int ds_post_send(struct rsocket *rs, struct ibv_sge *sge, + uint64_t wr_id) +{ + struct ibv_send_wr wr, *bad; + + wr.wr_id = wr_id; + wr.next = NULL; + wr.sg_list = sge; + wr.num_sge = 1; + wr.opcode = IBV_WR_SEND; + wr.send_flags = (sge->length <= rs->sq_inline) ? IBV_SEND_INLINE : 0; + wr.wr.ud.ah = rs->conn_dest->ah; + wr.wr.ud.remote_qpn = rs->conn_dest->qpn; + wr.wr.ud.remote_qkey = RDMA_UDP_QKEY; + + return rdma_seterrno(ibv_post_send(rs->conn_dest->qp->cm_id->qp, &wr, &bad)); +} + /* * Update target SGE before sending data. Otherwise the remote side may * update the entry before we do. @@ -1046,7 +1730,7 @@ static int rs_poll_cq(struct rsocket *rs) rs->state = rs_disconnected; return 0; } else if (rs_msg_data(imm_data) == RS_CTRL_SHUTDOWN) { - rs->state &= ~rs_connect_rd; + rs->state &= ~rs_readable; } break; case RS_OP_WRITE: @@ -1133,46 +1817,213 @@ static int rs_get_cq_event(struct rsocket *rs) */ static int rs_process_cq(struct rsocket *rs, int nonblock, int (*test)(struct rsocket *rs)) { - int ret; + int ret; + + fastlock_acquire(&rs->cq_lock); + do { + rs_update_credits(rs); + ret = rs_poll_cq(rs); + if (test(rs)) { + ret = 0; + break; + } else if (ret) { + break; + } else if (nonblock) { + ret = ERR(EWOULDBLOCK); + } else if (!rs->cq_armed) { + ibv_req_notify_cq(rs->cm_id->recv_cq, 0); + rs->cq_armed = 1; + } else { + rs_update_credits(rs); + fastlock_acquire(&rs->cq_wait_lock); + fastlock_release(&rs->cq_lock); + + ret = rs_get_cq_event(rs); + fastlock_release(&rs->cq_wait_lock); + fastlock_acquire(&rs->cq_lock); + } + } while (!ret); + + rs_update_credits(rs); + fastlock_release(&rs->cq_lock); + return ret; +} + +static int rs_get_comp(struct rsocket *rs, int nonblock, int (*test)(struct rsocket *rs)) +{ + struct timeval s, e; + uint32_t poll_time = 0; + int ret; + + do { + ret = rs_process_cq(rs, 1, test); + if (!ret || nonblock || errno != EWOULDBLOCK) + return ret; + + if (!poll_time) + gettimeofday(&s, NULL); + + gettimeofday(&e, NULL); + poll_time = (e.tv_sec - s.tv_sec) * 1000000 + + (e.tv_usec - s.tv_usec) + 1; + } while (poll_time <= polling_time); + + ret = rs_process_cq(rs, 0, test); + return ret; +} + +static int ds_valid_recv(struct ds_qp *qp, struct ibv_wc *wc) +{ + struct ds_header *hdr; + + hdr = (struct ds_header *) (qp->rbuf + ds_wr_offset(wc->wr_id)); + return ((wc->byte_len >= sizeof(struct ibv_grh) + DS_IPV4_HDR_LEN) && + ((hdr->version == 4 && hdr->length == DS_IPV4_HDR_LEN) || + (hdr->version == 6 && hdr->length == DS_IPV6_HDR_LEN))); +} + +/* + * Poll all CQs associated with a datagram rsocket. We need to drop any + * received messages that we do not have room to store. To limit drops, + * we only poll if we have room to store the receive or we need a send + * buffer. To ensure fairness, we poll the CQs round robin, remembering + * where we left off. + */ +static void ds_poll_cqs(struct rsocket *rs) +{ + struct ds_qp *qp; + struct ds_smsg *smsg; + struct ds_rmsg *rmsg; + struct ibv_wc wc; + int ret, cnt; + + if (!(qp = rs->qp_list)) + return; + + do { + cnt = 0; + do { + ret = ibv_poll_cq(qp->cm_id->recv_cq, 1, &wc); + if (ret <= 0) { + qp = ds_next_qp(qp); + continue; + } + + if (ds_wr_is_recv(wc.wr_id)) { + if (rs->rqe_avail && wc.status == IBV_WC_SUCCESS && + ds_valid_recv(qp, &wc)) { + rs->rqe_avail--; + rmsg = &rs->dmsg[rs->rmsg_tail]; + rmsg->qp = qp; + rmsg->offset = ds_wr_offset(wc.wr_id); + rmsg->length = wc.byte_len - sizeof(struct ibv_grh); + if (++rs->rmsg_tail == rs->rq_size + 1) + rs->rmsg_tail = 0; + } else { + ds_post_recv(rs, qp, ds_wr_offset(wc.wr_id)); + } + } else { + smsg = (struct ds_smsg *) + (rs->sbuf + ds_wr_offset(wc.wr_id)); + smsg->next = rs->smsg_free; + rs->smsg_free = smsg; + rs->sqe_avail++; + } + + qp = ds_next_qp(qp); + if (!rs->rqe_avail && rs->sqe_avail) { + rs->qp_list = qp; + return; + } + cnt++; + } while (qp != rs->qp_list); + } while (cnt); +} + +static void ds_req_notify_cqs(struct rsocket *rs) +{ + struct ds_qp *qp; + + if (!(qp = rs->qp_list)) + return; + + do { + if (!qp->cq_armed) { + ibv_req_notify_cq(qp->cm_id->recv_cq, 0); + qp->cq_armed = 1; + } + qp = ds_next_qp(qp); + } while (qp != rs->qp_list); +} + +static int ds_get_cq_event(struct rsocket *rs) +{ + struct epoll_event event; + struct ds_qp *qp; + struct ibv_cq *cq; + void *context; + int ret; + + if (!rs->cq_armed) + return 0; + + ret = epoll_wait(rs->epfd, &event, 1, -1); + if (ret <= 0) + return ret; + + qp = event.data.ptr; + ret = ibv_get_cq_event(qp->cm_id->recv_cq_channel, &cq, &context); + if (!ret) { + ibv_ack_cq_events(qp->cm_id->recv_cq, 1); + qp->cq_armed = 0; + rs->cq_armed = 0; + } + + return ret; +} + +static int rs_have_rdata(struct rsocket *rs); +static int ds_can_send(struct rsocket *rs); +static int rs_poll_all(struct rsocket *rs); +static int ds_all_sends_done(struct rsocket *rs); + +static int ds_process_cqs(struct rsocket *rs, int nonblock, int (*test)(struct rsocket *rs)) +{ + int ret = 0; fastlock_acquire(&rs->cq_lock); do { - rs_update_credits(rs); - ret = rs_poll_cq(rs); + ds_poll_cqs(rs); if (test(rs)) { ret = 0; break; - } else if (ret) { - break; } else if (nonblock) { ret = ERR(EWOULDBLOCK); } else if (!rs->cq_armed) { - ibv_req_notify_cq(rs->cm_id->recv_cq, 0); + ds_req_notify_cqs(rs); rs->cq_armed = 1; } else { - rs_update_credits(rs); fastlock_acquire(&rs->cq_wait_lock); fastlock_release(&rs->cq_lock); - ret = rs_get_cq_event(rs); + ret = ds_get_cq_event(rs); fastlock_release(&rs->cq_wait_lock); fastlock_acquire(&rs->cq_lock); } } while (!ret); - rs_update_credits(rs); fastlock_release(&rs->cq_lock); return ret; } -static int rs_get_comp(struct rsocket *rs, int nonblock, int (*test)(struct rsocket *rs)) +static int ds_get_comp(struct rsocket *rs, int nonblock, int (*test)(struct rsocket *rs)) { struct timeval s, e; uint32_t poll_time = 0; int ret; do { - ret = rs_process_cq(rs, 1, test); + ret = ds_process_cqs(rs, 1, test); if (!ret || nonblock || errno != EWOULDBLOCK) return ret; @@ -1184,7 +2035,7 @@ static int rs_get_comp(struct rsocket *rs, int nonblock, int (*test)(struct rsoc (e.tv_usec - s.tv_usec) + 1; } while (poll_time <= polling_time); - ret = rs_process_cq(rs, 0, test); + ret = ds_process_cqs(rs, 0, test); return ret; } @@ -1219,9 +2070,19 @@ static int rs_can_send(struct rsocket *rs) (rs->target_sgl[rs->target_sge].length != 0); } +static int ds_can_send(struct rsocket *rs) +{ + return rs->sqe_avail; +} + +static int ds_all_sends_done(struct rsocket *rs) +{ + return rs->sqe_avail == rs->sq_size; +} + static int rs_conn_can_send(struct rsocket *rs) { - return rs_can_send(rs) || !(rs->state & rs_connect_wr); + return rs_can_send(rs) || !(rs->state & rs_writable); } static int rs_conn_can_send_ctrl(struct rsocket *rs) @@ -1236,7 +2097,7 @@ static int rs_have_rdata(struct rsocket *rs) static int rs_conn_have_rdata(struct rsocket *rs) { - return rs_have_rdata(rs) || !(rs->state & rs_connect_rd); + return rs_have_rdata(rs) || !(rs->state & rs_readable); } static int rs_conn_all_sends_done(struct rsocket *rs) @@ -1245,6 +2106,67 @@ static int rs_conn_all_sends_done(struct rsocket *rs) !(rs->state & rs_connected); } +static void ds_set_src(struct sockaddr *addr, socklen_t *addrlen, + struct ds_header *hdr) +{ + union socket_addr sa; + + memset(&sa, 0, sizeof sa); + if (hdr->version == 4) { + if (*addrlen > sizeof(sa.sin)) + *addrlen = sizeof(sa.sin); + + sa.sin.sin_family = AF_INET; + sa.sin.sin_port = hdr->port; + sa.sin.sin_addr.s_addr = hdr->addr.ipv4; + } else { + if (*addrlen > sizeof(sa.sin6)) + *addrlen = sizeof(sa.sin6); + + sa.sin6.sin6_family = AF_INET6; + sa.sin6.sin6_port = hdr->port; + sa.sin6.sin6_flowinfo = hdr->addr.ipv6.flowinfo; + memcpy(&sa.sin6.sin6_addr, &hdr->addr.ipv6.addr, 16); + } + memcpy(addr, &sa, *addrlen); +} + +static ssize_t ds_recvfrom(struct rsocket *rs, void *buf, size_t len, int flags, + struct sockaddr *src_addr, socklen_t *addrlen) +{ + struct ds_rmsg *rmsg; + struct ds_header *hdr; + int ret; + + if (!(rs->state & rs_readable)) + return ERR(EINVAL); + + if (!rs_have_rdata(rs)) { + ret = ds_get_comp(rs, rs_nonblocking(rs, flags), + rs_have_rdata); + if (ret) + return ret; + } + + rmsg = &rs->dmsg[rs->rmsg_head]; + hdr = (struct ds_header *) (rmsg->qp->rbuf + rmsg->offset); + if (len > rmsg->length - hdr->length) + len = rmsg->length - hdr->length; + + memcpy(buf, (void *) hdr + hdr->length, len); + if (addrlen) + ds_set_src(src_addr, addrlen, hdr); + + if (!(flags & MSG_PEEK)) { + ds_post_recv(rs, rmsg->qp, rmsg->offset); + if (++rs->rmsg_head == rs->rq_size + 1) + rs->rmsg_head = 0; + rs->rqe_avail++; + } + + return len; +} + static ssize_t rs_peek(struct rsocket *rs, void *buf, size_t len) { size_t left = len; @@ -1290,6 +2212,13 @@ ssize_t rrecv(int socket, void *buf, size_t len, int flags) int ret; rs = idm_at(&idm, socket); + if (rs->type == SOCK_DGRAM) { + fastlock_acquire(&rs->rlock); + ret = ds_recvfrom(rs, buf, len, flags, NULL, 0); + fastlock_release(&rs->rlock); + return ret; + } + if (rs->state & rs_opening) { ret = rs_do_connect(rs); if (ret) { @@ -1339,7 +2268,7 @@ ssize_t rrecv(int socket, void *buf, size_t len, int flags) rs->rbuf_bytes_avail += rsize; } - } while (left && (flags & MSG_WAITALL) && (rs->state & rs_connect_rd)); + } while (left && (flags & MSG_WAITALL) && (rs->state & rs_readable)); fastlock_release(&rs->rlock); return ret ? ret : len - left; @@ -1348,8 +2277,17 @@ ssize_t rrecv(int socket, void *buf, size_t len, int flags) ssize_t rrecvfrom(int socket, void *buf, size_t len, int flags, struct sockaddr *src_addr, socklen_t *addrlen) { + struct rsocket *rs; int ret; + rs = idm_at(&idm, socket); + if (rs->type == SOCK_DGRAM) { + fastlock_acquire(&rs->rlock); + ret = ds_recvfrom(rs, buf, len, flags, src_addr, addrlen); + fastlock_release(&rs->rlock); + return ret; + } + ret = rrecv(socket, buf, len, flags); if (ret > 0 && src_addr) rgetpeername(socket, src_addr, addrlen); @@ -1391,14 +2329,14 @@ static int rs_send_iomaps(struct rsocket *rs, int flags) struct rs_iomap iom; int ret; - fastlock_acquire(&rs->iomap_lock); + fastlock_acquire(&rs->map_lock); while (!dlist_empty(&rs->iomap_queue)) { if (!rs_can_send(rs)) { ret = rs_get_comp(rs, rs_nonblocking(rs, flags), rs_conn_can_send); if (ret) break; - if (!(rs->state & rs_connect_wr)) { + if (!(rs->state & rs_writable)) { ret = ERR(ECONNRESET); break; } @@ -1447,10 +2385,92 @@ static int rs_send_iomaps(struct rsocket *rs, int flags) } rs->iomap_pending = !dlist_empty(&rs->iomap_queue); - fastlock_release(&rs->iomap_lock); + fastlock_release(&rs->map_lock); return ret; } +static ssize_t ds_sendv_udp(struct rsocket *rs, const struct iovec *iov, + int iovcnt, int flags, uint8_t op) +{ + struct ds_udp_header hdr; + struct msghdr msg; + struct iovec miov[8]; + ssize_t ret; + + if (iovcnt > 8) + return ERR(ENOTSUP); + + hdr.tag = htonl(DS_UDP_TAG); + hdr.version = rs->conn_dest->qp->hdr.version; + hdr.op = op; + hdr.reserved = 0; + hdr.qpn = htonl(rs->conn_dest->qp->cm_id->qp->qp_num & 0xFFFFFF); + if (rs->conn_dest->qp->hdr.version == 4) { + hdr.length = DS_UDP_IPV4_HDR_LEN; + hdr.addr.ipv4 = rs->conn_dest->qp->hdr.addr.ipv4; + } else { + hdr.length = DS_UDP_IPV6_HDR_LEN; + memcpy(hdr.addr.ipv6, &rs->conn_dest->qp->hdr.addr.ipv6, 16); + } + + miov[0].iov_base = &hdr; + miov[0].iov_len = hdr.length; + if (iov && iovcnt) + memcpy(&miov[1], iov, sizeof *iov * iovcnt); + + memset(&msg, 0, sizeof msg); + msg.msg_name = &rs->conn_dest->addr; + msg.msg_namelen = ucma_addrlen(&rs->conn_dest->addr.sa); + msg.msg_iov = miov; + msg.msg_iovlen = iovcnt + 1; + ret = sendmsg(rs->udp_sock, &msg, flags); + return ret > 0 ? ret - hdr.length : ret; +} + +static ssize_t ds_send_udp(struct rsocket *rs, const void *buf, size_t len, + int flags, uint8_t op) +{ + struct iovec iov; + if (buf && len) { + iov.iov_base = (void *) buf; + iov.iov_len = len; + return ds_sendv_udp(rs, &iov, 1, flags, op); + } else { + return ds_sendv_udp(rs, NULL, 0, flags, op); + } +} + +static ssize_t dsend(struct rsocket *rs, const void *buf, size_t len, int flags) +{ + struct ds_smsg *msg; + struct ibv_sge sge; + uint64_t offset; + int ret = 0; + + if (!rs->conn_dest->ah) + return ds_send_udp(rs, buf, len, flags, RS_OP_DATA); + + if (!ds_can_send(rs)) { + ret = ds_get_comp(rs, rs_nonblocking(rs, flags), ds_can_send); + if (ret) + return ret; + } + + msg = rs->smsg_free; + rs->smsg_free = msg->next; + rs->sqe_avail--; + + memcpy((void *) msg, &rs->conn_dest->qp->hdr, rs->conn_dest->qp->hdr.length); + memcpy((void *) msg + rs->conn_dest->qp->hdr.length, buf, len); + sge.addr = (uintptr_t) msg; + sge.length = rs->conn_dest->qp->hdr.length + len; + sge.lkey = rs->conn_dest->qp->smr->lkey; + offset = (uint8_t *) msg - rs->sbuf; + + ret = ds_post_send(rs, &sge, ds_send_wr_id(offset, sge.length)); + return ret ? ret : len; +} + /* * We overlap sending the data, by posting a small work request immediately, * then increasing the size of the send on each iteration. @@ -1464,6 +2484,13 @@ ssize_t rsend(int socket, const void *buf, size_t len, int flags) int ret = 0; rs = idm_at(&idm, socket); + if (rs->type == SOCK_DGRAM) { + fastlock_acquire(&rs->slock); + ret = dsend(rs, buf, len, flags); + fastlock_release(&rs->slock); + return ret; + } + if (rs->state & rs_opening) { ret = rs_do_connect(rs); if (ret) { @@ -1485,7 +2512,7 @@ ssize_t rsend(int socket, const void *buf, size_t len, int flags) rs_conn_can_send); if (ret) break; - if (!(rs->state & rs_connect_wr)) { + if (!(rs->state & rs_writable)) { ret = ERR(ECONNRESET); break; } @@ -1538,10 +2565,34 @@ out: ssize_t rsendto(int socket, const void *buf, size_t len, int flags, const struct sockaddr *dest_addr, socklen_t addrlen) { - if (dest_addr || addrlen) - return ERR(EISCONN); + struct rsocket *rs; + int ret; - return rsend(socket, buf, len, flags); + rs = idm_at(&idm, socket); + if (rs->type == SOCK_STREAM) { + if (dest_addr || addrlen) + return ERR(EISCONN); + + return rsend(socket, buf, len, flags); + } + + if (rs->state == rs_init) { + ret = ds_init_ep(rs); + if (ret) + return ret; + } + + fastlock_acquire(&rs->slock); + if (!rs->conn_dest || ds_compare_addr(dest_addr, &rs->conn_dest->addr)) { + ret = ds_get_dest(rs, dest_addr, addrlen, &rs->conn_dest); + if (ret) + goto out; + } + + ret = dsend(rs, buf, len, flags); +out: + fastlock_release(&rs->slock); + return ret; } static void rs_copy_iov(void *dst, const struct iovec **iov, size_t *offset, size_t len) @@ -1600,7 +2651,7 @@ static ssize_t rsendv(int socket, const struct iovec *iov, int iovcnt, int flags rs_conn_can_send); if (ret) break; - if (!(rs->state & rs_connect_wr)) { + if (!(rs->state & rs_writable)) { ret = ERR(ECONNRESET); break; } @@ -1653,7 +2704,7 @@ ssize_t rsendmsg(int socket, const struct msghdr *msg, int flags) if (msg->msg_control && msg->msg_controllen) return ERR(ENOTSUP); - return rsendv(socket, msg->msg_iov, (int) msg->msg_iovlen, msg->msg_flags); + return rsendv(socket, msg->msg_iov, (int) msg->msg_iovlen, flags); } ssize_t rwrite(int socket, const void *buf, size_t count) @@ -1690,8 +2741,8 @@ static int rs_poll_rs(struct rsocket *rs, int events, int ret; check_cq: - if ((rs->state & rs_connected) || (rs->state == rs_disconnected) || - (rs->state & rs_error)) { + if ((rs->type == SOCK_STREAM) && ((rs->state & rs_connected) || + (rs->state == rs_disconnected) || (rs->state & rs_error))) { rs_process_cq(rs, nonblock, test); revents = 0; @@ -1706,6 +2757,16 @@ check_cq: revents |= POLLERR; } + return revents; + } else if (rs->type == SOCK_DGRAM) { + ds_process_cqs(rs, nonblock, test); + + revents = 0; + if ((events & POLLIN) && rs_have_rdata(rs)) + revents |= POLLIN; + if ((events & POLLOUT) && ds_can_send(rs)) + revents |= POLLOUT; + return revents; } @@ -1766,11 +2827,14 @@ static int rs_poll_arm(struct pollfd *rfds, struct pollfd *fds, nfds_t nfds) if (fds[i].revents) return 1; - if (rs->state >= rs_connected) - rfds[i].fd = rs->cm_id->recv_cq_channel->fd; - else - rfds[i].fd = rs->cm_id->channel->fd; - + if (rs->type == SOCK_STREAM) { + if (rs->state >= rs_connected) + rfds[i].fd = rs->cm_id->recv_cq_channel->fd; + else + rfds[i].fd = rs->cm_id->channel->fd; + } else { + rfds[i].fd = rs->epfd; + } rfds[i].events = POLLIN; } else { rfds[i].fd = fds[i].fd; @@ -1793,7 +2857,10 @@ static int rs_poll_events(struct pollfd *rfds, struct pollfd *fds, nfds_t nfds) rs = idm_lookup(&idm, fds[i].fd); if (rs) { - rs_get_cq_event(rs); + if (rs->type == SOCK_STREAM) + rs_get_cq_event(rs); + else + ds_get_cq_event(rs); fds[i].revents = rs_poll_rs(rs, fds[i].events, 1, rs_poll_all); } else { fds[i].revents = rfds[i].revents; @@ -1949,7 +3016,7 @@ int rshutdown(int socket, int how) rs = idm_at(&idm, socket); if (how == SHUT_RD) { - rs->state &= ~rs_connect_rd; + rs->state &= ~rs_readable; return 0; } @@ -1959,10 +3026,10 @@ int rshutdown(int socket, int how) if (rs->state & rs_connected) { if (how == SHUT_RDWR) { ctrl = RS_CTRL_DISCONNECT; - rs->state &= ~(rs_connect_rd | rs_connect_wr); + rs->state &= ~(rs_readable | rs_writable); } else { - rs->state &= ~rs_connect_wr; - ctrl = (rs->state & rs_connect_rd) ? + rs->state &= ~rs_writable; + ctrl = (rs->state & rs_readable) ? RS_CTRL_SHUTDOWN : RS_CTRL_DISCONNECT; } if (!rs->ctrl_avail) { @@ -1987,13 +3054,32 @@ int rshutdown(int socket, int how) return 0; } +static void ds_shutdown(struct rsocket *rs) +{ + if (rs->svcs) + rs_modify_svcs(rs, 0); + + if (rs->fd_flags & O_NONBLOCK) + rs_set_nonblocking(rs, 0); + + rs->state &= ~(rs_readable | rs_writable); + ds_process_cqs(rs, 0, ds_all_sends_done); + + if (rs->fd_flags & O_NONBLOCK) + rs_set_nonblocking(rs, 1); +} + int rclose(int socket) { struct rsocket *rs; rs = idm_at(&idm, socket); - if (rs->state & rs_connected) - rshutdown(socket, SHUT_RDWR); + if (rs->type == SOCK_STREAM) { + if (rs->state & rs_connected) + rshutdown(socket, SHUT_RDWR); + } else { + ds_shutdown(rs); + } rs_free(rs); return 0; @@ -2018,8 +3104,12 @@ int rgetpeername(int socket, struct sockaddr *addr, socklen_t *addrlen) struct rsocket *rs; rs = idm_at(&idm, socket); - rs_copy_addr(addr, rdma_get_peer_addr(rs->cm_id), addrlen); - return 0; + if (rs->type == SOCK_STREAM) { + rs_copy_addr(addr, rdma_get_peer_addr(rs->cm_id), addrlen); + return 0; + } else { + return getpeername(rs->udp_sock, addr, addrlen); + } } int rgetsockname(int socket, struct sockaddr *addr, socklen_t *addrlen) @@ -2027,8 +3117,12 @@ int rgetsockname(int socket, struct sockaddr *addr, socklen_t *addrlen) struct rsocket *rs; rs = idm_at(&idm, socket); - rs_copy_addr(addr, rdma_get_local_addr(rs->cm_id), addrlen); - return 0; + if (rs->type == SOCK_STREAM) { + rs_copy_addr(addr, rdma_get_local_addr(rs->cm_id), addrlen); + return 0; + } else { + return getsockname(rs->udp_sock, addr, addrlen); + } } int rsetsockopt(int socket, int level, int optname, @@ -2040,22 +3134,31 @@ int rsetsockopt(int socket, int level, int optname, ret = ERR(ENOTSUP); rs = idm_at(&idm, socket); + if (rs->type == SOCK_DGRAM && level != SOL_RDMA) { + ret = setsockopt(rs->udp_sock, level, optname, optval, optlen); + if (ret) + return ret; + } + switch (level) { case SOL_SOCKET: opts = &rs->so_opts; switch (optname) { case SO_REUSEADDR: - ret = rdma_set_option(rs->cm_id, RDMA_OPTION_ID, - RDMA_OPTION_ID_REUSEADDR, - (void *) optval, optlen); - if (ret && ((errno == ENOSYS) || ((rs->state != rs_init) && - rs->cm_id->context && - (rs->cm_id->verbs->device->transport_type == IBV_TRANSPORT_IB)))) - ret = 0; + if (rs->type == SOCK_STREAM) { + ret = rdma_set_option(rs->cm_id, RDMA_OPTION_ID, + RDMA_OPTION_ID_REUSEADDR, + (void *) optval, optlen); + if (ret && ((errno == ENOSYS) || ((rs->state != rs_init) && + rs->cm_id->context && + (rs->cm_id->verbs->device->transport_type == IBV_TRANSPORT_IB)))) + ret = 0; + } opt_on = *(int *) optval; break; case SO_RCVBUF: - if (!rs->rbuf) + if ((rs->type == SOCK_STREAM && !rs->rbuf) || + (rs->type == SOCK_DGRAM && !rs->qp_list)) rs->rbuf_size = (*(uint32_t *) optval) << 1; ret = 0; break; @@ -2101,9 +3204,11 @@ int rsetsockopt(int socket, int level, int optname, opts = &rs->ipv6_opts; switch (optname) { case IPV6_V6ONLY: - ret = rdma_set_option(rs->cm_id, RDMA_OPTION_ID, - RDMA_OPTION_ID_AFONLY, - (void *) optval, optlen); + if (rs->type == SOCK_STREAM) { + ret = rdma_set_option(rs->cm_id, RDMA_OPTION_ID, + RDMA_OPTION_ID_AFONLY, + (void *) optval, optlen); + } opt_on = *(int *) optval; break; default: @@ -2315,7 +3420,7 @@ off_t riomap(int socket, void *buf, size_t len, int prot, int flags, off_t offse if (!rs->cm_id->pd || (prot & ~(PROT_WRITE | PROT_NONE))) return ERR(EINVAL); - fastlock_acquire(&rs->iomap_lock); + fastlock_acquire(&rs->map_lock); if (prot & PROT_WRITE) { iomr = rs_get_iomap_mr(rs); access |= IBV_ACCESS_REMOTE_WRITE; @@ -2349,7 +3454,7 @@ off_t riomap(int socket, void *buf, size_t len, int prot, int flags, off_t offse dlist_insert_tail(&iomr->entry, &rs->iomap_list); } out: - fastlock_release(&rs->iomap_lock); + fastlock_release(&rs->map_lock); return offset; } @@ -2361,7 +3466,7 @@ int riounmap(int socket, void *buf, size_t len) int ret = 0; rs = idm_at(&idm, socket); - fastlock_acquire(&rs->iomap_lock); + fastlock_acquire(&rs->map_lock); for (entry = rs->iomap_list.next; entry != &rs->iomap_list; entry = entry->next) { @@ -2382,7 +3487,7 @@ int riounmap(int socket, void *buf, size_t len) } ret = ERR(EINVAL); out: - fastlock_release(&rs->iomap_lock); + fastlock_release(&rs->map_lock); return ret; } @@ -2426,7 +3531,7 @@ size_t riowrite(int socket, const void *buf, size_t count, off_t offset, int fla rs_conn_can_send); if (ret) break; - if (!(rs->state & rs_connect_wr)) { + if (!(rs->state & rs_writable)) { ret = ERR(ECONNRESET); break; } @@ -2476,3 +3581,272 @@ out: return (ret && left == count) ? ret : count - left; } + +static int rs_svc_grow_sets(void) +{ + struct rsocket **rss; + struct pollfd *fds; + void *set; + + set = calloc(svc_size + 2, sizeof(*rss) + sizeof(*fds)); + if (!set) + return ENOMEM; + + svc_size += 2; + rss = set; + fds = set + sizeof(*rss) * svc_size; + if (svc_cnt) { + memcpy(rss, svc_rss, sizeof(*rss) * svc_cnt); + memcpy(fds, svc_fds, sizeof(*fds) * svc_cnt); + } + + free(svc_rss); + free(svc_fds); + svc_rss = rss; + svc_fds = fds; + return 0; +} + +/* + * Index 0 is reserved for the service's communication socket. + */ +static int rs_svc_add_rs(struct rsocket *rs) +{ + int ret; + + if (svc_cnt >= svc_size - 1) { + ret = rs_svc_grow_sets(); + if (ret) + return ret; + } + + svc_rss[++svc_cnt] = rs; + svc_fds[svc_cnt].fd = rs->udp_sock; + svc_fds[svc_cnt].events = POLLIN; + svc_fds[svc_cnt].revents = 0; + return 0; +} + +static int rs_svc_rm_rs(struct rsocket *rs) +{ + int i; + + for (i = 1; i <= svc_cnt; i++) { + if (svc_rss[i] == rs) { + svc_cnt--; + svc_rss[i] = svc_rss[svc_cnt]; + svc_fds[i] = svc_fds[svc_cnt]; + return 0; + } + } + return EBADF; +} + +static void rs_svc_process_sock(void) +{ + struct rs_svc_msg msg; + + read(svc_sock[1], &msg, sizeof msg); + if (msg.svcs & RS_SVC_DGRAM) { + msg.status = rs_svc_add_rs(msg.rs); + } else if (!msg.svcs) { + msg.status = rs_svc_rm_rs(msg.rs); + } + + if (!msg.status) + msg.rs->svcs = msg.svcs; + write(svc_sock[1], &msg, sizeof msg); +} + +static uint8_t rs_svc_sgid_index(struct ds_dest *dest, union ibv_gid *sgid) +{ + union ibv_gid gid; + int i, ret; + + for (i = 0; i < 16; i++) { + ret = ibv_query_gid(dest->qp->cm_id->verbs, dest->qp->cm_id->port_num, + i, &gid); + if (!memcmp(sgid, &gid, sizeof gid)) + return i; + } + return 0; +} + +static uint8_t rs_svc_path_bits(struct ds_dest *dest) +{ + struct ibv_port_attr attr; + + if (!ibv_query_port(dest->qp->cm_id->verbs, dest->qp->cm_id->port_num, &attr)) + return (uint8_t) ((1 << attr.lmc) - 1); + return 0x7f; +} + +static void rs_svc_create_ah(struct rsocket *rs, struct ds_dest *dest, uint32_t qpn) +{ + union socket_addr saddr; + struct rdma_cm_id *id; + struct ibv_ah_attr attr; + int ret; + + if (dest->ah) { + fastlock_acquire(&rs->slock); + ibv_destroy_ah(dest->ah); + dest->ah = NULL; + fastlock_release(&rs->slock); + } + + ret = rdma_create_id(NULL, &id, NULL, dest->qp->cm_id->ps); + if (ret) + return; + + memcpy(&saddr, rdma_get_local_addr(dest->qp->cm_id), + ucma_addrlen(rdma_get_local_addr(dest->qp->cm_id))); + if (saddr.sa.sa_family == AF_INET) + saddr.sin.sin_port = 0; + else + saddr.sin6.sin6_port = 0; + ret = rdma_resolve_addr(id, &saddr.sa, &dest->addr.sa, 2000); + if (ret) + goto out; + + ret = rdma_resolve_route(id, 2000); + if (ret) + goto out; + + memset(&attr, 0, sizeof attr); + if (id->route.path_rec->hop_limit > 1) { + attr.is_global = 1; + attr.grh.dgid = id->route.path_rec->dgid; + attr.grh.flow_label = ntohl(id->route.path_rec->flow_label); + attr.grh.sgid_index = rs_svc_sgid_index(dest, &id->route.path_rec->sgid); + attr.grh.hop_limit = id->route.path_rec->hop_limit; + attr.grh.traffic_class = id->route.path_rec->traffic_class; + } + attr.dlid = ntohs(id->route.path_rec->dlid); + attr.sl = id->route.path_rec->sl; + attr.src_path_bits = id->route.path_rec->slid & rs_svc_path_bits(dest); + attr.static_rate = id->route.path_rec->rate; + attr.port_num = id->port_num; + + fastlock_acquire(&rs->slock); + dest->qpn = qpn; + dest->ah = ibv_create_ah(dest->qp->cm_id->pd, &attr); + fastlock_release(&rs->slock); +out: + rdma_destroy_id(id); +} + +static int rs_svc_valid_udp_hdr(struct ds_udp_header *udp_hdr, + union socket_addr *addr) +{ + return (udp_hdr->tag == ntohl(DS_UDP_TAG)) && + ((udp_hdr->version == 4 && addr->sa.sa_family == AF_INET && + udp_hdr->length == DS_UDP_IPV4_HDR_LEN) || + (udp_hdr->version == 6 && addr->sa.sa_family == AF_INET6 && + udp_hdr->length == DS_UDP_IPV6_HDR_LEN)); +} + +static void rs_svc_forward(struct rsocket *rs, void *buf, size_t len, + union socket_addr *src) +{ + struct ds_header hdr; + struct ds_smsg *msg; + struct ibv_sge sge; + uint64_t offset; + + if (!ds_can_send(rs)) { + if (ds_get_comp(rs, 0, ds_can_send)) + return; + } + + msg = rs->smsg_free; + rs->smsg_free = msg->next; + rs->sqe_avail--; + + ds_format_hdr(&hdr, src); + memcpy((void *) msg, &hdr, hdr.length); + memcpy((void *) msg + hdr.length, buf, len); + sge.addr = (uintptr_t) msg; + sge.length = hdr.length + len; + sge.lkey = rs->conn_dest->qp->smr->lkey; + offset = (uint8_t *) msg - rs->sbuf; + + ds_post_send(rs, &sge, ds_send_wr_id(offset, sge.length)); +} + +static void rs_svc_process_rs(struct rsocket *rs) +{ + struct ds_dest *dest, *cur_dest; + struct ds_udp_header *udp_hdr; + union socket_addr addr; + socklen_t addrlen = sizeof addr; + int len, ret; + + ret = recvfrom(rs->udp_sock, svc_buf, sizeof svc_buf, 0, &addr.sa, &addrlen); + if (ret < DS_UDP_IPV4_HDR_LEN) + return; + + udp_hdr = (struct ds_udp_header *) svc_buf; + if (!rs_svc_valid_udp_hdr(udp_hdr, &addr)) + return; + + len = ret - udp_hdr->length; + udp_hdr->tag = ntohl(udp_hdr->tag); + udp_hdr->qpn = ntohl(udp_hdr->qpn) & 0xFFFFFF; + ret = ds_get_dest(rs, &addr.sa, addrlen, &dest); + if (ret) + return; + + if (udp_hdr->op == RS_OP_DATA) { + fastlock_acquire(&rs->slock); + cur_dest = rs->conn_dest; + rs->conn_dest = dest; + ds_send_udp(rs, NULL, 0, 0, RS_OP_CTRL); + rs->conn_dest = cur_dest; + fastlock_release(&rs->slock); + } + + if (!dest->ah || (dest->qpn != udp_hdr->qpn)) + rs_svc_create_ah(rs, dest, udp_hdr->qpn); + + /* to do: handle when dest local ip address doesn't match udp ip */ + if (udp_hdr->op == RS_OP_DATA) { + fastlock_acquire(&rs->slock); + cur_dest = rs->conn_dest; + rs->conn_dest = &dest->qp->dest; + rs_svc_forward(rs, svc_buf + udp_hdr->length, len, &addr); + rs->conn_dest = cur_dest; + fastlock_release(&rs->slock); + } +} + +static void *rs_svc_run(void *arg) +{ + struct rs_svc_msg msg; + int i, ret; + + ret = rs_svc_grow_sets(); + if (ret) { + msg.status = ret; + write(svc_sock[1], &msg, sizeof msg); + return (void *) (uintptr_t) ret; + } + + svc_fds[0].fd = svc_sock[1]; + svc_fds[0].events = POLLIN; + do { + for (i = 0; i <= svc_cnt; i++) + svc_fds[i].revents = 0; + + poll(svc_fds, svc_cnt + 1, -1); + if (svc_fds[0].revents) + rs_svc_process_sock(); + + for (i = 1; i <= svc_cnt; i++) { + if (svc_fds[i].revents) + rs_svc_process_rs(svc_rss[i]); + } + } while (svc_cnt >= 1); + + return NULL; +} -- 2.41.0