From: Sean Hefty
Date: Fri, 9 Nov 2012 18:26:38 +0000 (-0800)
Subject: rsocket: Add datagram support
X-Git-Tag: v1.0.17~8
X-Git-Url: https://openfabrics.org/gitweb/?a=commitdiff_plain;h=e6e93ed4231976eeab707b31e283be0a7acff6db;p=~shefty%2Flibrdmacm.git

rsocket: Add datagram support

Add datagram support through the rsocket API.

Datagram support is handled through an entirely different protocol
and internal implementation than streaming sockets. Unlike connected
rsockets, datagram rsockets are not necessarily bound to a network
(IP) address. A datagram socket may use any number of network (IP)
addresses, including those which map to different RDMA devices. As a
result, a single datagram rsocket must support using multiple RDMA
devices and ports, and a datagram rsocket references a single UDP
socket, plus zero or more UD QPs.

Rsockets uses headers inserted before user data sent over UDP sockets
to resolve remote UD QP numbers. When a user first attempts to send a
datagram to a remote address (IP and UDP port), rsockets will take the
following steps:

1. Store the destination address into a lookup table.
2. Resolve which local network address should be used when sending
   to the specified destination.
3. Allocate a UD QP on the RDMA device associated with the local
   address.
4. Send the user's datagram to the remote UDP socket.

A header is inserted before the user's datagram. The header specifies
the UD QP number associated with the local network address (IP and UDP
port) of the send.

A service thread is used to process messages received on the UDP
socket. This thread updates the rsocket lookup tables with the remote
QPN and path record data. The service thread forwards data received on
the UDP socket to an rsocket QP. After the remote QPN and path records
have been resolved, datagram communication between two nodes is done
over the UD QP.
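
Editorial illustration (not part of the original commit): the header
exchange described above can be sketched roughly as follows. The struct
layout and the DS_UDP_* constants are copied from the patch below. The
two helper functions and their names are hypothetical, and storing the
QPN in network byte order is an assumption. The reserved bits of the
qpn field are not modeled.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#include <arpa/inet.h>

/* Constants and layout as they appear in the patch. */
#define DS_UDP_TAG          0x55555555
#define DS_UDP_IPV4_HDR_LEN 16
#define DS_UDP_IPV6_HDR_LEN 28

struct ds_udp_header {
	uint32_t tag;
	uint8_t  version;
	uint8_t  op;
	uint8_t  length;
	uint8_t  reserved;
	uint32_t qpn;
	union {
		uint32_t ipv4;
		uint8_t  ipv6[16];
	} addr;
};

/*
 * Hypothetical helper: fill in the header that precedes user data on
 * the UDP socket for an IPv4 sender. Returns the number of header
 * bytes to transmit. ipv4_be is already in network byte order.
 */
static size_t example_fmt_udp_hdr_v4(struct ds_udp_header *hdr, uint8_t op,
				     uint32_t qpn, uint32_t ipv4_be)
{
	assert(hdr != NULL);
	memset(hdr, 0, sizeof *hdr);
	hdr->tag = htonl(DS_UDP_TAG);
	hdr->version = 4;
	hdr->op = op;
	hdr->length = DS_UDP_IPV4_HDR_LEN;
	hdr->qpn = htonl(qpn);		/* assumption: network byte order */
	hdr->addr.ipv4 = ipv4_be;
	return hdr->length;
}

/*
 * Hypothetical helper: receive-side sanity check. Accepts a buffer only
 * if the tag matches and the length field agrees with the IP version.
 */
static int example_udp_hdr_valid(const void *buf, size_t len)
{
	struct ds_udp_header hdr;

	if (len < DS_UDP_IPV4_HDR_LEN)
		return 0;

	memset(&hdr, 0, sizeof hdr);
	memcpy(&hdr, buf, len < sizeof hdr ? len : sizeof hdr);

	if (ntohl(hdr.tag) != DS_UDP_TAG)
		return 0;
	if (hdr.version == 4)
		return hdr.length == DS_UDP_IPV4_HDR_LEN;
	if (hdr.version == 6)
		return hdr.length == DS_UDP_IPV6_HDR_LEN &&
		       len >= DS_UDP_IPV6_HDR_LEN;
	return 0;
}
```

In this sketch a sender would transmit the first hdr.length bytes of the
header followed by the user's datagram, and a receiver would drop any
UDP message for which the check returns 0.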
Signed-off-by: Sean Hefty
---

diff --git a/docs/rsocket b/docs/rsocket
index 1484f65b..f453c1b6 100644
--- a/docs/rsocket
+++ b/docs/rsocket
@@ -1,7 +1,7 @@
-rsocket Protocol and Design Guide 9/10/2012
+rsocket Protocol and Design Guide 11/11/2012
 
-Overview
---------
+Data Streaming (TCP) Overview
+-----------------------------
 Rsockets is a protocol over RDMA that supports a socket-level API
 for applications. For details on the current state of the
 implementation, readers should refer to the rsocket man page. This
@@ -189,3 +189,91 @@
 registered remote data buffer. From host A's perspective, the transfer
 appears as a normal send/write operation, with the data stream redirected
 directly into the receiving application's buffer.
+
+
+
+Datagram Overview
+-----------------
+The rsocket API supports datagram sockets. Datagram support is handled through an
+entirely different protocol and internal implementation. Unlike connected rsockets,
+datagram rsockets are not necessarily bound to a network (IP) address. A datagram
+socket may use any number of network (IP) addresses, including those which map to
+different RDMA devices. As a result, a single datagram rsocket must support
+using multiple RDMA devices and ports, and a datagram rsocket references a single
+UDP socket, plus zero or more UD QPs.
+
+Rsockets uses headers inserted before user data sent over UDP sockets to resolve
+remote UD QP numbers. When a user first attempts to send a datagram to a remote
+address (IP and UDP port), rsockets will take the following steps:
+
+1. Store the destination address into a lookup table.
+2. Resolve which local network address should be used when sending
+   to the specified destination.
+3. Allocate a UD QP on the RDMA device associated with the local address.
+4. Send the user's datagram to the remote UDP socket.
+
+A header is inserted before the user's datagram. The header specifies the
+UD QP number associated with the local network address (IP and UDP port) of
+the send.
+
+A service thread is used to process messages received on the UDP socket. This
+thread updates the rsocket lookup tables with the remote QPN and path record
+data. The service thread forwards data received on the UDP socket to an
+rsocket QP. After the remote QPN and path records have been resolved, datagram
+communication between two nodes are done over the UD QP.
+
+UDP Message Format
+------------------
+Rsockets uses messages exchanged over UDP sockets to resolve remote QP numbers.
+If a user sends a datagram to a remote service and the local rsocket is not
+yet configured to send directly to a remote UD QP, the user data is sent over
+a UDP socket with the following header inserted before the user data.
+
+struct ds_udp_header {
+	uint32_t tag;
+	uint8_t version;
+	uint8_t op;
+	uint8_t length;
+	uint8_t reserved;
+	uint32_t qpn; /* lower 8-bits reserved */
+	union {
+		uint32_t ipv4;
+		uint8_t ipv6[16];
+	} addr;
+};
+
+Tag - Marker used to help identify that the UDP header is present.
+#define DS_UDP_TAG 0x55555555
+
+Version - IP address version, either 4 or 6
+Op - Indicates message type, used to control the receiver's operation.
+     Valid operations are RS_OP_DATA and RS_OP_CTRL. Data messages
+     carry user data, while control messages are used to reply with the
+     local QP number.
+Length - Size of the UDP header.
+QPN - UD QP number associated with sender's IP address and port.
+      The sender's address and port is extracted from the received UDP
+      datagram.
+Addr - Target IP address of the sent datagram.
+
+Once the remote QP information has been resolved, data is sent directly
+between UD QPs. The following header is inserted before any user data that
+is transferred over a UD QP.
+ +struct ds_header { + uint8_t version; + uint8_t length; + uint16_t port; + union { + uint32_t ipv4; + struct { + uint32_t flowinfo; + uint8_t addr[16]; + } ipv6; + } addr; +}; + +Verion - IP address version +Length - Size of the header +Port - Associated source address UDP port +Addr - Associated source IP address \ No newline at end of file diff --git a/src/cma.c b/src/cma.c index 388be617..ff9b426c 100755 --- a/src/cma.c +++ b/src/cma.c @@ -513,7 +513,7 @@ int rdma_destroy_id(struct rdma_cm_id *id) return 0; } -static int ucma_addrlen(struct sockaddr *addr) +int ucma_addrlen(struct sockaddr *addr) { if (!addr) return 0; @@ -2232,9 +2232,19 @@ void rdma_destroy_ep(struct rdma_cm_id *id) int ucma_max_qpsize(struct rdma_cm_id *id) { struct cma_id_private *id_priv; + int i, max_size = 0; id_priv = container_of(id, struct cma_id_private, id); - return id_priv->cma_dev->max_qpsize; + if (id && id_priv->cma_dev) { + max_size = id_priv->cma_dev->max_qpsize; + } else { + ucma_init(); + for (i = 0; i < cma_dev_cnt; i++) { + if (!max_size || max_size > cma_dev_array[i].max_qpsize) + max_size = cma_dev_array[i].max_qpsize; + } + } + return max_size; } uint16_t ucma_get_port(struct sockaddr *addr) diff --git a/src/cma.h b/src/cma.h index 0a0370ee..7135a612 100644 --- a/src/cma.h +++ b/src/cma.h @@ -145,10 +145,12 @@ typedef struct { volatile int val; } atomic_t; #define atomic_set(v, s) ((v)->val = s) uint16_t ucma_get_port(struct sockaddr *addr); +int ucma_addrlen(struct sockaddr *addr); void ucma_set_sid(enum rdma_port_space ps, struct sockaddr *addr, struct sockaddr_ib *sib); int ucma_max_qpsize(struct rdma_cm_id *id); int ucma_complete(struct rdma_cm_id *id); + static inline int ERR(int err) { errno = err; diff --git a/src/rsocket.c b/src/rsocket.c index a060f66a..7be42cab 100644 --- a/src/rsocket.c +++ b/src/rsocket.c @@ -47,6 +47,8 @@ #include #include #include +#include +#include #include #include @@ -56,7 +58,7 @@ #define RS_OLAP_START_SIZE 2048 #define 
RS_MAX_TRANSFER 65536 -#define RS_SNDLOWAT 64 +#define RS_SNDLOWAT 2048 #define RS_QP_MAX_SIZE 0xFFFE #define RS_QP_CTRL_SIZE 4 #define RS_CONN_RETRIES 6 @@ -64,6 +66,27 @@ static struct index_map idm; static pthread_mutex_t mut = PTHREAD_MUTEX_INITIALIZER; +struct rsocket; + +enum { + RS_SVC_DGRAM = 1 << 0 +}; + +struct rs_svc_msg { + uint32_t svcs; + uint32_t status; + struct rsocket *rs; +}; + +static pthread_t svc_id; +static int svc_sock[2]; +static int svc_cnt; +static int svc_size; +static struct rsocket **svc_rss; +static struct pollfd *svc_fds; +static uint8_t svc_buf[RS_SNDLOWAT]; +static void *rs_svc_run(void *arg); + static uint16_t def_iomap_size = 0; static uint16_t def_inline = 64; static uint16_t def_sqsize = 384; @@ -100,6 +123,14 @@ enum { #define rs_msg_set(op, data) ((op << 29) | (uint32_t) (data)) #define rs_msg_op(imm_data) (imm_data >> 29) #define rs_msg_data(imm_data) (imm_data & 0x1FFFFFFF) +#define RS_RECV_WR_ID (~((uint64_t) 0)) + +#define DS_WR_RECV 0xFFFFFFFF +#define ds_send_wr_id(offset, length) (((uint64_t) (offset)) << 32 | (uint64_t) length) +#define ds_recv_wr_id(offset) (((uint64_t) (offset)) << 32 | (uint64_t) DS_WR_RECV) +#define ds_wr_offset(wr_id) ((uint32_t) (wr_id >> 32)) +#define ds_wr_length(wr_id) ((uint32_t) wr_id) +#define ds_wr_is_recv(wr_id) (ds_wr_length(wr_id) == DS_WR_RECV) enum { RS_CTRL_DISCONNECT, @@ -111,6 +142,18 @@ struct rs_msg { uint32_t data; }; +struct ds_qp; + +struct ds_rmsg { + struct ds_qp *qp; + uint32_t offset; + uint32_t length; +}; + +struct ds_smsg { + struct ds_smsg *next; +}; + struct rs_sge { uint64_t addr; uint32_t key; @@ -145,8 +188,6 @@ struct rs_conn_data { struct rs_sge data_buf; }; -#define RS_RECV_WR_ID (~((uint64_t) 0)) - /* * rsocket states are ordered as passive, connecting, connected, disconnected. 
*/ @@ -160,9 +201,9 @@ enum rs_state { rs_connecting = rs_opening | 0x0040, rs_accepting = rs_opening | 0x0080, rs_connected = 0x0100, - rs_connect_wr = 0x0200, - rs_connect_rd = 0x0400, - rs_connect_rdwr = rs_connected | rs_connect_rd | rs_connect_wr, + rs_writable = 0x0200, + rs_readable = 0x0400, + rs_connect_rdwr = rs_connected | rs_readable | rs_writable, rs_connect_error = 0x0800, rs_disconnected = 0x1000, rs_error = 0x2000, @@ -170,68 +211,223 @@ enum rs_state { #define RS_OPT_SWAP_SGL 1 -struct rsocket { +union socket_addr { + struct sockaddr sa; + struct sockaddr_in sin; + struct sockaddr_in6 sin6; +}; + +struct ds_header { + uint8_t version; + uint8_t length; + uint16_t port; + union { + uint32_t ipv4; + struct { + uint32_t flowinfo; + uint8_t addr[16]; + } ipv6; + } addr; +}; + +#define DS_IPV4_HDR_LEN 8 +#define DS_IPV6_HDR_LEN 24 + +struct ds_dest { + union socket_addr addr; /* must be first */ + struct ds_qp *qp; + struct ibv_ah *ah; + uint32_t qpn; +}; + +struct ds_qp { + dlist_entry list; + struct rsocket *rs; struct rdma_cm_id *cm_id; + struct ds_header hdr; + struct ds_dest dest; + + struct ibv_mr *smr; + struct ibv_mr *rmr; + uint8_t *rbuf; + + int cq_armed; +}; + +struct rsocket { + int type; + int index; fastlock_t slock; fastlock_t rlock; fastlock_t cq_lock; fastlock_t cq_wait_lock; - fastlock_t iomap_lock; - + fastlock_t map_lock; /* acquire slock first if needed */ + + union { + /* data stream */ + struct { + struct rdma_cm_id *cm_id; + uint64_t tcp_opts; + + int ctrl_avail; + uint16_t sseq_no; + uint16_t sseq_comp; + uint16_t rseq_no; + uint16_t rseq_comp; + + int remote_sge; + struct rs_sge remote_sgl; + struct rs_sge remote_iomap; + + struct ibv_mr *target_mr; + int target_sge; + int target_iomap_size; + void *target_buffer_list; + volatile struct rs_sge *target_sgl; + struct rs_iomap *target_iomap; + + int rbuf_bytes_avail; + int rbuf_free_offset; + int rbuf_offset; + struct ibv_mr *rmr; + uint8_t *rbuf; + + int sbuf_bytes_avail; + 
struct ibv_mr *smr; + struct ibv_sge ssgl[2]; + }; + /* datagram */ + struct { + struct ds_qp *qp_list; + void *dest_map; + struct ds_dest *conn_dest; + + int udp_sock; + int epfd; + int rqe_avail; + struct ds_smsg *smsg_free; + }; + }; + + int svcs; int opts; long fd_flags; uint64_t so_opts; - uint64_t tcp_opts; uint64_t ipv6_opts; int state; int cq_armed; int retries; int err; - int index; - int ctrl_avail; + int sqe_avail; - int sbuf_bytes_avail; - uint16_t sseq_no; - uint16_t sseq_comp; + uint32_t sbuf_size; uint16_t sq_size; uint16_t sq_inline; + uint32_t rbuf_size; uint16_t rq_size; - uint16_t rseq_no; - uint16_t rseq_comp; - int rbuf_bytes_avail; - int rbuf_free_offset; - int rbuf_offset; int rmsg_head; int rmsg_tail; - struct rs_msg *rmsg; - - int remote_sge; - struct rs_sge remote_sgl; - struct rs_sge remote_iomap; + union { + struct rs_msg *rmsg; + struct ds_rmsg *dmsg; + }; + uint8_t *sbuf; struct rs_iomap_mr *remote_iomappings; dlist_entry iomap_list; dlist_entry iomap_queue; int iomap_pending; +}; - struct ibv_mr *target_mr; - int target_sge; - int target_iomap_size; - void *target_buffer_list; - volatile struct rs_sge *target_sgl; - struct rs_iomap *target_iomap; +#define DS_UDP_TAG 0x55555555 - uint32_t rbuf_size; - struct ibv_mr *rmr; - uint8_t *rbuf; - - uint32_t sbuf_size; - struct ibv_mr *smr; - struct ibv_sge ssgl[2]; - uint8_t *sbuf; +struct ds_udp_header { + uint32_t tag; + uint8_t version; + uint8_t op; + uint8_t length; + uint8_t reserved; + uint32_t qpn; /* lower 8-bits reserved */ + union { + uint32_t ipv4; + uint8_t ipv6[16]; + } addr; }; +#define DS_UDP_IPV4_HDR_LEN 16 +#define DS_UDP_IPV6_HDR_LEN 28 + +#define ds_next_qp(qp) container_of((qp)->list.next, struct ds_qp, list) + +static void ds_insert_qp(struct rsocket *rs, struct ds_qp *qp) +{ + if (!rs->qp_list) + dlist_init(&qp->list); + else + dlist_insert_head(&qp->list, &rs->qp_list->list); + rs->qp_list = qp; +} + +static void ds_remove_qp(struct rsocket *rs, struct ds_qp *qp) +{ + 
if (qp->list.next != &qp->list) { + rs->qp_list = ds_next_qp(qp); + dlist_remove(&qp->list); + } else { + rs->qp_list = NULL; + } +} + +static int rs_modify_svcs(struct rsocket *rs, int svcs) +{ + struct rs_svc_msg msg; + int ret; + + pthread_mutex_lock(&mut); + if (!svc_cnt) { + ret = socketpair(AF_UNIX, SOCK_STREAM, 0, svc_sock); + if (ret) + goto unlock; + + ret = pthread_create(&svc_id, NULL, rs_svc_run, NULL); + if (ret) { + ret = ERR(ret); + goto closepair; + } + } + + msg.svcs = svcs; + msg.status = EINVAL; + msg.rs = rs; + write(svc_sock[0], &msg, sizeof msg); + read(svc_sock[0], &msg, sizeof msg); + ret = rdma_seterrno(msg.status); + if (svc_cnt) + goto unlock; + + pthread_join(svc_id, NULL); +closepair: + close(svc_sock[0]); + close(svc_sock[1]); +unlock: + pthread_mutex_unlock(&mut); + return ret; +} + +static int ds_compare_addr(const void *dst1, const void *dst2) +{ + const struct sockaddr *sa1, *sa2; + size_t len; + + sa1 = (const struct sockaddr *) dst1; + sa2 = (const struct sockaddr *) dst2; + + len = (sa1->sa_family == AF_INET6 && sa2->sa_family == AF_INET6) ? + sizeof(struct sockaddr_in6) : sizeof(struct sockaddr_in); + return memcmp(dst1, dst2, len); +} + static int rs_value_to_scale(int value, int bits) { return value <= (1 << (bits - 1)) ? 
@@ -307,10 +503,10 @@ out: pthread_mutex_unlock(&mut); } -static int rs_insert(struct rsocket *rs) +static int rs_insert(struct rsocket *rs, int index) { pthread_mutex_lock(&mut); - rs->index = idm_set(&idm, rs->cm_id->channel->fd, rs); + rs->index = idm_set(&idm, index, rs); pthread_mutex_unlock(&mut); return rs->index; } @@ -322,7 +518,7 @@ static void rs_remove(struct rsocket *rs) pthread_mutex_unlock(&mut); } -static struct rsocket *rs_alloc(struct rsocket *inherited_rs) +static struct rsocket *rs_alloc(struct rsocket *inherited_rs, int type) { struct rsocket *rs; @@ -330,29 +526,39 @@ static struct rsocket *rs_alloc(struct rsocket *inherited_rs) if (!rs) return NULL; + rs->type = type; rs->index = -1; + if (type == SOCK_DGRAM) { + rs->udp_sock = -1; + rs->epfd = -1; + } + if (inherited_rs) { rs->sbuf_size = inherited_rs->sbuf_size; rs->rbuf_size = inherited_rs->rbuf_size; rs->sq_inline = inherited_rs->sq_inline; rs->sq_size = inherited_rs->sq_size; rs->rq_size = inherited_rs->rq_size; - rs->ctrl_avail = inherited_rs->ctrl_avail; - rs->target_iomap_size = inherited_rs->target_iomap_size; + if (type == SOCK_STREAM) { + rs->ctrl_avail = inherited_rs->ctrl_avail; + rs->target_iomap_size = inherited_rs->target_iomap_size; + } } else { rs->sbuf_size = def_wmem; rs->rbuf_size = def_mem; rs->sq_inline = def_inline; rs->sq_size = def_sqsize; rs->rq_size = def_rqsize; - rs->ctrl_avail = RS_QP_CTRL_SIZE; - rs->target_iomap_size = def_iomap_size; + if (type == SOCK_STREAM) { + rs->ctrl_avail = RS_QP_CTRL_SIZE; + rs->target_iomap_size = def_iomap_size; + } } fastlock_init(&rs->slock); fastlock_init(&rs->rlock); fastlock_init(&rs->cq_lock); fastlock_init(&rs->cq_wait_lock); - fastlock_init(&rs->iomap_lock); + fastlock_init(&rs->map_lock); dlist_init(&rs->iomap_list); dlist_init(&rs->iomap_queue); return rs; @@ -360,13 +566,26 @@ static struct rsocket *rs_alloc(struct rsocket *inherited_rs) static int rs_set_nonblocking(struct rsocket *rs, long arg) { + struct ds_qp *qp; int 
ret = 0; - if (rs->cm_id->recv_cq_channel) - ret = fcntl(rs->cm_id->recv_cq_channel->fd, F_SETFL, arg); + if (rs->type == SOCK_STREAM) { + if (rs->cm_id->recv_cq_channel) + ret = fcntl(rs->cm_id->recv_cq_channel->fd, F_SETFL, arg); - if (!ret && rs->state < rs_connected) - ret = fcntl(rs->cm_id->channel->fd, F_SETFL, arg); + if (!ret && rs->state < rs_connected) + ret = fcntl(rs->cm_id->channel->fd, F_SETFL, arg); + } else { + ret = fcntl(rs->epfd, F_SETFL, arg); + if (!ret && rs->qp_list) { + qp = rs->qp_list; + do { + ret = fcntl(qp->cm_id->recv_cq_channel->fd, + F_SETFL, arg); + qp = ds_next_qp(qp); + } while (qp != rs->qp_list && !ret); + } + } return ret; } @@ -390,17 +609,39 @@ static void rs_set_qp_size(struct rsocket *rs) rs->rq_size = 2; } +static void ds_set_qp_size(struct rsocket *rs) +{ + uint16_t max_size; + + max_size = min(ucma_max_qpsize(NULL), RS_QP_MAX_SIZE); + + if (rs->sq_size > max_size) + rs->sq_size = max_size; + if (rs->rq_size > max_size) + rs->rq_size = max_size; + + if (rs->rq_size > (rs->rbuf_size / RS_SNDLOWAT)) + rs->rq_size = rs->rbuf_size / RS_SNDLOWAT; + else + rs->rbuf_size = rs->rq_size * RS_SNDLOWAT; + + if (rs->sq_size > (rs->sbuf_size / RS_SNDLOWAT)) + rs->sq_size = rs->sbuf_size / RS_SNDLOWAT; + else + rs->sbuf_size = rs->sq_size * RS_SNDLOWAT; +} + static int rs_init_bufs(struct rsocket *rs) { size_t len; rs->rmsg = calloc(rs->rq_size + 1, sizeof(*rs->rmsg)); if (!rs->rmsg) - return -1; + return ERR(ENOMEM); rs->sbuf = calloc(rs->sbuf_size, sizeof(*rs->sbuf)); if (!rs->sbuf) - return -1; + return ERR(ENOMEM); rs->smr = rdma_reg_msgs(rs->cm_id, rs->sbuf, rs->sbuf_size); if (!rs->smr) @@ -410,7 +651,7 @@ static int rs_init_bufs(struct rsocket *rs) sizeof(*rs->target_iomap) * rs->target_iomap_size; rs->target_buffer_list = malloc(len); if (!rs->target_buffer_list) - return -1; + return ERR(ENOMEM); rs->target_mr = rdma_reg_write(rs->cm_id, rs->target_buffer_list, len); if (!rs->target_mr) @@ -423,7 +664,7 @@ static int 
rs_init_bufs(struct rsocket *rs) rs->rbuf = calloc(rs->rbuf_size, sizeof(*rs->rbuf)); if (!rs->rbuf) - return -1; + return ERR(ENOMEM); rs->rmr = rdma_reg_write(rs->cm_id, rs->rbuf, rs->rbuf_size); if (!rs->rmr) @@ -440,37 +681,61 @@ static int rs_init_bufs(struct rsocket *rs) return 0; } -static int rs_create_cq(struct rsocket *rs) +static int ds_init_bufs(struct ds_qp *qp) +{ + qp->rbuf = calloc(qp->rs->rbuf_size + sizeof(struct ibv_grh), + sizeof(*qp->rbuf)); + if (!qp->rbuf) + return ERR(ENOMEM); + + qp->smr = rdma_reg_msgs(qp->cm_id, qp->rs->sbuf, qp->rs->sbuf_size); + if (!qp->smr) + return -1; + + qp->rmr = rdma_reg_msgs(qp->cm_id, qp->rbuf, qp->rs->rbuf_size + + sizeof(struct ibv_grh)); + if (!qp->rmr) + return -1; + + return 0; +} + +/* + * If a user is waiting on a datagram rsocket through poll or select, then + * we need the first completion to generate an event on the related epoll fd + * in order to signal the user. We arm the CQ on creation for this purpose + */ +static int rs_create_cq(struct rsocket *rs, struct rdma_cm_id *cm_id) { - rs->cm_id->recv_cq_channel = ibv_create_comp_channel(rs->cm_id->verbs); - if (!rs->cm_id->recv_cq_channel) + cm_id->recv_cq_channel = ibv_create_comp_channel(cm_id->verbs); + if (!cm_id->recv_cq_channel) return -1; - rs->cm_id->recv_cq = ibv_create_cq(rs->cm_id->verbs, rs->sq_size + rs->rq_size, - rs->cm_id, rs->cm_id->recv_cq_channel, 0); - if (!rs->cm_id->recv_cq) + cm_id->recv_cq = ibv_create_cq(cm_id->verbs, rs->sq_size + rs->rq_size, + cm_id, cm_id->recv_cq_channel, 0); + if (!cm_id->recv_cq) goto err1; if (rs->fd_flags & O_NONBLOCK) { - if (rs_set_nonblocking(rs, O_NONBLOCK)) + if (fcntl(cm_id->recv_cq_channel->fd, F_SETFL, O_NONBLOCK)) goto err2; } - rs->cm_id->send_cq_channel = rs->cm_id->recv_cq_channel; - rs->cm_id->send_cq = rs->cm_id->recv_cq; + ibv_req_notify_cq(cm_id->recv_cq, 0); + cm_id->send_cq_channel = cm_id->recv_cq_channel; + cm_id->send_cq = cm_id->recv_cq; return 0; err2: - 
ibv_destroy_cq(rs->cm_id->recv_cq); - rs->cm_id->recv_cq = NULL; + ibv_destroy_cq(cm_id->recv_cq); + cm_id->recv_cq = NULL; err1: - ibv_destroy_comp_channel(rs->cm_id->recv_cq_channel); - rs->cm_id->recv_cq_channel = NULL; + ibv_destroy_comp_channel(cm_id->recv_cq_channel); + cm_id->recv_cq_channel = NULL; return -1; } -static inline int -rs_post_recv(struct rsocket *rs) +static inline int rs_post_recv(struct rsocket *rs) { struct ibv_recv_wr wr, *bad; @@ -482,6 +747,26 @@ rs_post_recv(struct rsocket *rs) return rdma_seterrno(ibv_post_recv(rs->cm_id->qp, &wr, &bad)); } +static inline int ds_post_recv(struct rsocket *rs, struct ds_qp *qp, uint32_t offset) +{ + struct ibv_recv_wr wr, *bad; + struct ibv_sge sge[2]; + + sge[0].addr = (uintptr_t) qp->rbuf + rs->rbuf_size; + sge[0].length = sizeof(struct ibv_grh); + sge[0].lkey = qp->rmr->lkey; + sge[1].addr = (uintptr_t) qp->rbuf + offset; + sge[1].length = RS_SNDLOWAT; + sge[1].lkey = qp->rmr->lkey; + + wr.wr_id = ds_recv_wr_id(offset); + wr.next = NULL; + wr.sg_list = sge; + wr.num_sge = 2; + + return rdma_seterrno(ibv_post_recv(qp->cm_id->qp, &wr, &bad)); +} + static int rs_create_ep(struct rsocket *rs) { struct ibv_qp_init_attr qp_attr; @@ -492,7 +777,7 @@ static int rs_create_ep(struct rsocket *rs) if (ret) return ret; - ret = rs_create_cq(rs); + ret = rs_create_cq(rs, rs->cm_id); if (ret) return ret; @@ -549,8 +834,70 @@ static void rs_free_iomappings(struct rsocket *rs) } } +static void ds_free_qp(struct ds_qp *qp) +{ + if (qp->smr) + rdma_dereg_mr(qp->smr); + + if (qp->rbuf) { + if (qp->rmr) + rdma_dereg_mr(qp->rmr); + free(qp->rbuf); + } + + if (qp->cm_id) { + if (qp->cm_id->qp) { + tdelete(&qp->dest.addr, &qp->rs->dest_map, ds_compare_addr); + epoll_ctl(qp->rs->epfd, EPOLL_CTL_DEL, + qp->cm_id->recv_cq_channel->fd, NULL); + rdma_destroy_qp(qp->cm_id); + } + rdma_destroy_id(qp->cm_id); + } + + free(qp); +} + +static void ds_free(struct rsocket *rs) +{ + struct ds_qp *qp; + + if (rs->udp_sock >= 0) + 
close(rs->udp_sock); + + if (rs->index >= 0) + rs_remove(rs); + + if (rs->dmsg) + free(rs->dmsg); + + while ((qp = rs->qp_list)) { + ds_remove_qp(rs, qp); + ds_free_qp(qp); + } + + if (rs->epfd >= 0) + close(rs->epfd); + + if (rs->sbuf) + free(rs->sbuf); + + tdestroy(rs->dest_map, free); + fastlock_destroy(&rs->map_lock); + fastlock_destroy(&rs->cq_wait_lock); + fastlock_destroy(&rs->cq_lock); + fastlock_destroy(&rs->rlock); + fastlock_destroy(&rs->slock); + free(rs); +} + static void rs_free(struct rsocket *rs) { + if (rs->type == SOCK_DGRAM) { + ds_free(rs); + return; + } + if (rs->index >= 0) rs_remove(rs); @@ -582,7 +929,7 @@ static void rs_free(struct rsocket *rs) rdma_destroy_id(rs->cm_id); } - fastlock_destroy(&rs->iomap_lock); + fastlock_destroy(&rs->map_lock); fastlock_destroy(&rs->cq_wait_lock); fastlock_destroy(&rs->cq_lock); fastlock_destroy(&rs->rlock); @@ -636,29 +983,88 @@ static void rs_save_conn_data(struct rsocket *rs, struct rs_conn_data *conn) rs->sseq_comp = ntohs(conn->credits); } +static int ds_init(struct rsocket *rs, int domain) +{ + rs->udp_sock = socket(domain, SOCK_DGRAM, 0); + if (rs->udp_sock < 0) + return rs->udp_sock; + + rs->epfd = epoll_create(2); + if (rs->epfd < 0) + return rs->epfd; + + return 0; +} + +static int ds_init_ep(struct rsocket *rs) +{ + struct ds_smsg *msg; + int i, ret; + + ds_set_qp_size(rs); + + rs->sbuf = calloc(rs->sq_size, RS_SNDLOWAT); + if (!rs->sbuf) + return ERR(ENOMEM); + + rs->dmsg = calloc(rs->rq_size + 1, sizeof(*rs->dmsg)); + if (!rs->dmsg) + return ERR(ENOMEM); + + rs->sqe_avail = rs->sq_size; + rs->rqe_avail = rs->rq_size; + + rs->smsg_free = (struct ds_smsg *) rs->sbuf; + msg = rs->smsg_free; + for (i = 0; i < rs->sq_size - 1; i++) { + msg->next = (void *) msg + RS_SNDLOWAT; + msg = msg->next; + } + msg->next = NULL; + + ret = rs_modify_svcs(rs, RS_SVC_DGRAM); + if (ret) + return ret; + + rs->state = rs_readable | rs_writable; + return 0; +} + int rsocket(int domain, int type, int protocol) { struct 
rsocket *rs; - int ret; + int index, ret; if ((domain != PF_INET && domain != PF_INET6) || - (type != SOCK_STREAM) || (protocol && protocol != IPPROTO_TCP)) + ((type != SOCK_STREAM) && (type != SOCK_DGRAM)) || + (type == SOCK_STREAM && protocol && protocol != IPPROTO_TCP) || + (type == SOCK_DGRAM && protocol && protocol != IPPROTO_UDP)) return ERR(ENOTSUP); rs_configure(); - rs = rs_alloc(NULL); + rs = rs_alloc(NULL, type); if (!rs) return ERR(ENOMEM); - ret = rdma_create_id(NULL, &rs->cm_id, rs, RDMA_PS_TCP); - if (ret) - goto err; + if (type == SOCK_STREAM) { + ret = rdma_create_id(NULL, &rs->cm_id, rs, RDMA_PS_TCP); + if (ret) + goto err; - ret = rs_insert(rs); + rs->cm_id->route.addr.src_addr.sa_family = domain; + index = rs->cm_id->channel->fd; + } else { + ret = ds_init(rs, domain); + if (ret) + goto err; + + index = rs->udp_sock; + } + + ret = rs_insert(rs, index); if (ret < 0) goto err; - rs->cm_id->route.addr.src_addr.sa_family = domain; return rs->index; err: @@ -672,9 +1078,18 @@ int rbind(int socket, const struct sockaddr *addr, socklen_t addrlen) int ret; rs = idm_at(&idm, socket); - ret = rdma_bind_addr(rs->cm_id, (struct sockaddr *) addr); - if (!ret) - rs->state = rs_bound; + if (rs->type == SOCK_STREAM) { + ret = rdma_bind_addr(rs->cm_id, (struct sockaddr *) addr); + if (!ret) + rs->state = rs_bound; + } else { + if (rs->state == rs_init) { + ret = ds_init_ep(rs); + if (ret) + return ret; + } + ret = bind(rs->udp_sock, addr, addrlen); + } return ret; } @@ -710,7 +1125,7 @@ int raccept(int socket, struct sockaddr *addr, socklen_t *addrlen) int ret; rs = idm_at(&idm, socket); - new_rs = rs_alloc(rs); + new_rs = rs_alloc(rs, rs->type); if (!new_rs) return ERR(ENOMEM); @@ -718,7 +1133,7 @@ int raccept(int socket, struct sockaddr *addr, socklen_t *addrlen) if (ret) goto err; - ret = rs_insert(new_rs); + ret = rs_insert(new_rs, new_rs->cm_id->channel->fd); if (ret < 0) goto err; @@ -729,7 +1144,7 @@ int raccept(int socket, struct sockaddr *addr, 
socklen_t *addrlen) } if (rs->fd_flags & O_NONBLOCK) - rs_set_nonblocking(new_rs, O_NONBLOCK); + fcntl(new_rs->cm_id->channel->fd, F_SETFL, O_NONBLOCK); ret = rs_create_ep(new_rs); if (ret) @@ -831,7 +1246,7 @@ connected: break; case rs_accepting: if (!(rs->fd_flags & O_NONBLOCK)) - rs_set_nonblocking(rs, 0); + fcntl(rs->cm_id->channel->fd, F_SETFL, 0); ret = ucma_complete(rs->cm_id); if (ret) @@ -855,13 +1270,240 @@ connected: return ret; } +static int rs_any_addr(const union socket_addr *addr) +{ + if (addr->sa.sa_family == AF_INET) { + return (addr->sin.sin_addr.s_addr == INADDR_ANY || + addr->sin.sin_addr.s_addr == INADDR_LOOPBACK); + } else { + return (!memcmp(&addr->sin6.sin6_addr, &in6addr_any, 16) || + !memcmp(&addr->sin6.sin6_addr, &in6addr_loopback, 16)); + } +} + +static int ds_get_src_addr(struct rsocket *rs, + const struct sockaddr *dest_addr, socklen_t dest_len, + union socket_addr *src_addr, socklen_t *src_len) +{ + int sock, ret; + uint16_t port; + + *src_len = sizeof *src_addr; + ret = getsockname(rs->udp_sock, &src_addr->sa, src_len); + if (ret || !rs_any_addr(src_addr)) + return ret; + + port = src_addr->sin.sin_port; + sock = socket(dest_addr->sa_family, SOCK_DGRAM, 0); + if (sock < 0) + return sock; + + ret = connect(sock, dest_addr, dest_len); + if (ret) + goto out; + + *src_len = sizeof *src_addr; + ret = getsockname(sock, &src_addr->sa, src_len); + src_addr->sin.sin_port = port; +out: + close(sock); + return ret; +} + +static void ds_format_hdr(struct ds_header *hdr, union socket_addr *addr) +{ + if (addr->sa.sa_family == AF_INET) { + hdr->version = 4; + hdr->length = DS_IPV4_HDR_LEN; + hdr->port = addr->sin.sin_port; + hdr->addr.ipv4 = addr->sin.sin_addr.s_addr; + } else { + hdr->version = 6; + hdr->length = DS_IPV6_HDR_LEN; + hdr->port = addr->sin6.sin6_port; + hdr->addr.ipv6.flowinfo= addr->sin6.sin6_flowinfo; + memcpy(&hdr->addr.ipv6.addr, &addr->sin6.sin6_addr, 16); + } +} + +static int ds_add_qp_dest(struct ds_qp *qp, union socket_addr 
*addr, + socklen_t addrlen) +{ + struct ibv_port_attr port_attr; + struct ibv_ah_attr attr; + int ret; + + memcpy(&qp->dest.addr, addr, addrlen); + qp->dest.qp = qp; + qp->dest.qpn = qp->cm_id->qp->qp_num; + + ret = ibv_query_port(qp->cm_id->verbs, qp->cm_id->port_num, &port_attr); + if (ret) + return ret; + + memset(&attr, 0, sizeof attr); + attr.dlid = port_attr.lid; + attr.port_num = qp->cm_id->port_num; + qp->dest.ah = ibv_create_ah(qp->cm_id->pd, &attr); + if (!qp->dest.ah) + return ERR(ENOMEM); + + tsearch(&qp->dest.addr, &qp->rs->dest_map, ds_compare_addr); + return 0; +} + +static int ds_create_qp(struct rsocket *rs, union socket_addr *src_addr, + socklen_t addrlen, struct ds_qp **new_qp) +{ + struct ds_qp *qp; + struct ibv_qp_init_attr qp_attr; + struct epoll_event event; + int i, ret; + + qp = calloc(1, sizeof(*qp)); + if (!qp) + return ERR(ENOMEM); + + qp->rs = rs; + ret = rdma_create_id(NULL, &qp->cm_id, qp, RDMA_PS_UDP); + if (ret) + goto err; + + ds_format_hdr(&qp->hdr, src_addr); + ret = rdma_bind_addr(qp->cm_id, &src_addr->sa); + if (ret) + goto err; + + ret = ds_init_bufs(qp); + if (ret) + goto err; + + ret = rs_create_cq(rs, qp->cm_id); + if (ret) + goto err; + + memset(&qp_attr, 0, sizeof qp_attr); + qp_attr.qp_context = qp; + qp_attr.send_cq = qp->cm_id->send_cq; + qp_attr.recv_cq = qp->cm_id->recv_cq; + qp_attr.qp_type = IBV_QPT_UD; + qp_attr.sq_sig_all = 1; + qp_attr.cap.max_send_wr = rs->sq_size; + qp_attr.cap.max_recv_wr = rs->rq_size; + qp_attr.cap.max_send_sge = 1; + qp_attr.cap.max_recv_sge = 2; + qp_attr.cap.max_inline_data = rs->sq_inline; + ret = rdma_create_qp(qp->cm_id, NULL, &qp_attr); + if (ret) + goto err; + + ret = ds_add_qp_dest(qp, src_addr, addrlen); + if (ret) + goto err; + + event.events = EPOLLIN; + event.data.ptr = qp; + ret = epoll_ctl(rs->epfd, EPOLL_CTL_ADD, + qp->cm_id->recv_cq_channel->fd, &event); + if (ret) + goto err; + + for (i = 0; i < rs->rq_size; i++) { + ret = ds_post_recv(rs, qp, i * RS_SNDLOWAT); + if (ret) 
+ goto err; + } + + ds_insert_qp(rs, qp); + *new_qp = qp; + return 0; +err: + ds_free_qp(qp); + return ret; +} + +static int ds_get_qp(struct rsocket *rs, union socket_addr *src_addr, + socklen_t addrlen, struct ds_qp **qp) +{ + if (rs->qp_list) { + *qp = rs->qp_list; + do { + if (!ds_compare_addr(rdma_get_local_addr((*qp)->cm_id), + src_addr)) + return 0; + + *qp = ds_next_qp(*qp); + } while (*qp != rs->qp_list); + } + + return ds_create_qp(rs, src_addr, addrlen, qp); +} + +static int ds_get_dest(struct rsocket *rs, const struct sockaddr *addr, + socklen_t addrlen, struct ds_dest **dest) +{ + union socket_addr src_addr; + socklen_t src_len; + struct ds_qp *qp; + struct ds_dest **tdest, *new_dest; + int ret = 0; + + fastlock_acquire(&rs->map_lock); + tdest = tfind(addr, &rs->dest_map, ds_compare_addr); + if (tdest) + goto found; + + ret = ds_get_src_addr(rs, addr, addrlen, &src_addr, &src_len); + if (ret) + goto out; + + ret = ds_get_qp(rs, &src_addr, src_len, &qp); + if (ret) + goto out; + + tdest = tfind(addr, &rs->dest_map, ds_compare_addr); + if (!tdest) { + new_dest = calloc(1, sizeof(*new_dest)); + if (!new_dest) { + ret = ERR(ENOMEM); + goto out; + } + + memcpy(&new_dest->addr, addr, addrlen); + new_dest->qp = qp; + tdest = tsearch(&new_dest->addr, &rs->dest_map, ds_compare_addr); + } + +found: + *dest = *tdest; +out: + fastlock_release(&rs->map_lock); + return ret; +} + int rconnect(int socket, const struct sockaddr *addr, socklen_t addrlen) { struct rsocket *rs; + int ret; rs = idm_at(&idm, socket); - memcpy(&rs->cm_id->route.addr.dst_addr, addr, addrlen); - return rs_do_connect(rs); + if (rs->type == SOCK_STREAM) { + memcpy(&rs->cm_id->route.addr.dst_addr, addr, addrlen); + ret = rs_do_connect(rs); + } else { + if (rs->state == rs_init) { + ret = ds_init_ep(rs); + if (ret) + return ret; + } + + fastlock_acquire(&rs->slock); + ret = connect(rs->udp_sock, addr, addrlen); + if (!ret) + ret = ds_get_dest(rs, addr, addrlen, &rs->conn_dest); + 
fastlock_release(&rs->slock); + } + return ret; } static int rs_post_write_msg(struct rsocket *rs, @@ -903,6 +1545,24 @@ static int rs_post_write(struct rsocket *rs, return rdma_seterrno(ibv_post_send(rs->cm_id->qp, &wr, &bad)); } +static int ds_post_send(struct rsocket *rs, struct ibv_sge *sge, + uint64_t wr_id) +{ + struct ibv_send_wr wr, *bad; + + wr.wr_id = wr_id; + wr.next = NULL; + wr.sg_list = sge; + wr.num_sge = 1; + wr.opcode = IBV_WR_SEND; + wr.send_flags = (sge->length <= rs->sq_inline) ? IBV_SEND_INLINE : 0; + wr.wr.ud.ah = rs->conn_dest->ah; + wr.wr.ud.remote_qpn = rs->conn_dest->qpn; + wr.wr.ud.remote_qkey = RDMA_UDP_QKEY; + + return rdma_seterrno(ibv_post_send(rs->conn_dest->qp->cm_id->qp, &wr, &bad)); +} + /* * Update target SGE before sending data. Otherwise the remote side may * update the entry before we do. @@ -1046,7 +1706,7 @@ static int rs_poll_cq(struct rsocket *rs) rs->state = rs_disconnected; return 0; } else if (rs_msg_data(imm_data) == RS_CTRL_SHUTDOWN) { - rs->state &= ~rs_connect_rd; + rs->state &= ~rs_readable; } break; case RS_OP_WRITE: @@ -1133,46 +1793,208 @@ static int rs_get_cq_event(struct rsocket *rs) */ static int rs_process_cq(struct rsocket *rs, int nonblock, int (*test)(struct rsocket *rs)) { - int ret; + int ret; + + fastlock_acquire(&rs->cq_lock); + do { + rs_update_credits(rs); + ret = rs_poll_cq(rs); + if (test(rs)) { + ret = 0; + break; + } else if (ret) { + break; + } else if (nonblock) { + ret = ERR(EWOULDBLOCK); + } else if (!rs->cq_armed) { + ibv_req_notify_cq(rs->cm_id->recv_cq, 0); + rs->cq_armed = 1; + } else { + rs_update_credits(rs); + fastlock_acquire(&rs->cq_wait_lock); + fastlock_release(&rs->cq_lock); + + ret = rs_get_cq_event(rs); + fastlock_release(&rs->cq_wait_lock); + fastlock_acquire(&rs->cq_lock); + } + } while (!ret); + + rs_update_credits(rs); + fastlock_release(&rs->cq_lock); + return ret; +} + +static int rs_get_comp(struct rsocket *rs, int nonblock, int (*test)(struct rsocket *rs)) +{ + struct 
timeval s, e; + uint32_t poll_time = 0; + int ret; + + do { + ret = rs_process_cq(rs, 1, test); + if (!ret || nonblock || errno != EWOULDBLOCK) + return ret; + + if (!poll_time) + gettimeofday(&s, NULL); + + gettimeofday(&e, NULL); + poll_time = (e.tv_sec - s.tv_sec) * 1000000 + + (e.tv_usec - s.tv_usec) + 1; + } while (poll_time <= polling_time); + + ret = rs_process_cq(rs, 0, test); + return ret; +} + +static int ds_valid_recv(struct ds_qp *qp, struct ibv_wc *wc) +{ + struct ds_header *hdr; + + hdr = (struct ds_header *) (qp->rbuf + ds_wr_offset(wc->wr_id)); + return ((wc->byte_len >= sizeof(struct ibv_grh) + DS_IPV4_HDR_LEN) && + ((hdr->version == 4 && hdr->length == DS_IPV4_HDR_LEN) || + (hdr->version == 6 && hdr->length == DS_IPV6_HDR_LEN))); +} + +/* + * Poll all CQs associated with a datagram rsocket. We need to drop any + * received messages that we do not have room to store. To limit drops, + * we only poll if we have room to store the receive or we need a send + * buffer. To ensure fairness, we poll the CQs round robin, remembering + * where we left off. 
+ */ +static void ds_poll_cqs(struct rsocket *rs) +{ + struct ds_qp *qp; + struct ds_smsg *smsg; + struct ds_rmsg *rmsg; + struct ibv_wc wc; + int ret, cnt; + + if (!(qp = rs->qp_list)) + return; + + do { + cnt = 0; + do { + ret = ibv_poll_cq(qp->cm_id->recv_cq, 1, &wc); + if (ret <= 0) { + qp = ds_next_qp(qp); + continue; + } + + if (ds_wr_is_recv(wc.wr_id)) { + if (rs->rqe_avail && wc.status == IBV_WC_SUCCESS && + ds_valid_recv(qp, &wc)) { + rs->rqe_avail--; + rmsg = &rs->dmsg[rs->rmsg_tail]; + rmsg->qp = qp; + rmsg->offset = ds_wr_offset(wc.wr_id); + rmsg->length = wc.byte_len - sizeof(struct ibv_grh); + if (++rs->rmsg_tail == rs->rq_size + 1) + rs->rmsg_tail = 0; + } else { + ds_post_recv(rs, qp, ds_wr_offset(wc.wr_id)); + } + } else { + smsg = (struct ds_smsg *) + (rs->sbuf + ds_wr_offset(wc.wr_id)); + smsg->next = rs->smsg_free; + rs->smsg_free = smsg; + rs->sqe_avail++; + } + + qp = ds_next_qp(qp); + if (!rs->rqe_avail && rs->sqe_avail) { + rs->qp_list = qp; + return; + } + cnt++; + } while (qp != rs->qp_list); + } while (cnt); +} + +static void ds_req_notify_cqs(struct rsocket *rs) +{ + struct ds_qp *qp; + + if (!(qp = rs->qp_list)) + return; + + do { + if (!qp->cq_armed) { + ibv_req_notify_cq(qp->cm_id->recv_cq, 0); + qp->cq_armed = 1; + } + qp = ds_next_qp(qp); + } while (qp != rs->qp_list); +} + +static int ds_get_cq_event(struct rsocket *rs) +{ + struct epoll_event event; + struct ds_qp *qp; + struct ibv_cq *cq; + void *context; + int ret; + + if (!rs->cq_armed) + return 0; + + ret = epoll_wait(rs->epfd, &event, 1, -1); + if (ret <= 0) + return ret; + + qp = event.data.ptr; + ret = ibv_get_cq_event(qp->cm_id->recv_cq_channel, &cq, &context); + if (!ret) { + ibv_ack_cq_events(qp->cm_id->recv_cq, 1); + qp->cq_armed = 0; + rs->cq_armed = 0; + } + + return ret; +} + +static int ds_process_cqs(struct rsocket *rs, int nonblock, int (*test)(struct rsocket *rs)) +{ + int ret = 0; fastlock_acquire(&rs->cq_lock); do { - rs_update_credits(rs); - ret = 
rs_poll_cq(rs); + ds_poll_cqs(rs); if (test(rs)) { ret = 0; break; - } else if (ret) { - break; } else if (nonblock) { ret = ERR(EWOULDBLOCK); } else if (!rs->cq_armed) { - ibv_req_notify_cq(rs->cm_id->recv_cq, 0); + ds_req_notify_cqs(rs); rs->cq_armed = 1; } else { - rs_update_credits(rs); fastlock_acquire(&rs->cq_wait_lock); fastlock_release(&rs->cq_lock); - ret = rs_get_cq_event(rs); + ret = ds_get_cq_event(rs); fastlock_release(&rs->cq_wait_lock); fastlock_acquire(&rs->cq_lock); } } while (!ret); - rs_update_credits(rs); fastlock_release(&rs->cq_lock); return ret; } -static int rs_get_comp(struct rsocket *rs, int nonblock, int (*test)(struct rsocket *rs)) +static int ds_get_comp(struct rsocket *rs, int nonblock, int (*test)(struct rsocket *rs)) { struct timeval s, e; uint32_t poll_time = 0; int ret; do { - ret = rs_process_cq(rs, 1, test); + ret = ds_process_cqs(rs, 1, test); if (!ret || nonblock || errno != EWOULDBLOCK) return ret; @@ -1184,7 +2006,7 @@ static int rs_get_comp(struct rsocket *rs, int nonblock, int (*test)(struct rsoc (e.tv_usec - s.tv_usec) + 1; } while (poll_time <= polling_time); - ret = rs_process_cq(rs, 0, test); + ret = ds_process_cqs(rs, 0, test); return ret; } @@ -1219,9 +2041,19 @@ static int rs_can_send(struct rsocket *rs) (rs->target_sgl[rs->target_sge].length != 0); } +static int ds_can_send(struct rsocket *rs) +{ + return rs->sqe_avail; +} + +static int ds_all_sends_done(struct rsocket *rs) +{ + return rs->sqe_avail == rs->sq_size; +} + static int rs_conn_can_send(struct rsocket *rs) { - return rs_can_send(rs) || !(rs->state & rs_connect_wr); + return rs_can_send(rs) || !(rs->state & rs_writable); } static int rs_conn_can_send_ctrl(struct rsocket *rs) @@ -1236,7 +2068,7 @@ static int rs_have_rdata(struct rsocket *rs) static int rs_conn_have_rdata(struct rsocket *rs) { - return rs_have_rdata(rs) || !(rs->state & rs_connect_rd); + return rs_have_rdata(rs) || !(rs->state & rs_readable); } static int rs_conn_all_sends_done(struct 
rsocket *rs) @@ -1245,6 +2077,67 @@ static int rs_conn_all_sends_done(struct rsocket *rs) !(rs->state & rs_connected); } +static void ds_set_src(struct sockaddr *addr, socklen_t *addrlen, + struct ds_header *hdr) +{ + union socket_addr sa; + + memset(&sa, 0, sizeof sa); + if (hdr->version == 4) { + if (*addrlen > sizeof(sa.sin)) + *addrlen = sizeof(sa.sin); + + sa.sin.sin_family = AF_INET; + sa.sin.sin_port = hdr->port; + sa.sin.sin_addr.s_addr = hdr->addr.ipv4; + } else { + if (*addrlen > sizeof(sa.sin6)) + *addrlen = sizeof(sa.sin6); + + sa.sin6.sin6_family = AF_INET6; + sa.sin6.sin6_port = hdr->port; + sa.sin6.sin6_flowinfo = hdr->addr.ipv6.flowinfo; + memcpy(&sa.sin6.sin6_addr, &hdr->addr.ipv6.addr, 16); + } + memcpy(addr, &sa, *addrlen); +} + +static ssize_t ds_recvfrom(struct rsocket *rs, void *buf, size_t len, int flags, + struct sockaddr *src_addr, socklen_t *addrlen) +{ + struct ds_rmsg *rmsg; + struct ds_header *hdr; + int ret; + + if (!(rs->state & rs_readable)) + return ERR(EINVAL); + + if (!rs_have_rdata(rs)) { + ret = ds_get_comp(rs, rs_nonblocking(rs, flags), + rs_have_rdata); + if (ret) + return ret; + } + + rmsg = &rs->dmsg[rs->rmsg_head]; + hdr = (struct ds_header *) (rmsg->qp->rbuf + rmsg->offset); + if (len > rmsg->length - hdr->length) + len = rmsg->length - hdr->length; + + memcpy(buf, (void *) hdr + hdr->length, len); + if (addrlen) + ds_set_src(src_addr, addrlen, hdr); + + if (!(flags & MSG_PEEK)) { + ds_post_recv(rs, rmsg->qp, rmsg->offset); + if (++rs->rmsg_head == rs->rq_size + 1) + rs->rmsg_head = 0; + rs->rqe_avail++; + } + + return len; +} + static ssize_t rs_peek(struct rsocket *rs, void *buf, size_t len) { size_t left = len; @@ -1290,6 +2183,13 @@ ssize_t rrecv(int socket, void *buf, size_t len, int flags) int ret; rs = idm_at(&idm, socket); + if (rs->type == SOCK_DGRAM) { + fastlock_acquire(&rs->rlock); + ret = ds_recvfrom(rs, buf, len, flags, NULL, 0); + fastlock_release(&rs->rlock); + return ret; + } + if (rs->state & rs_opening) 
{ ret = rs_do_connect(rs); if (ret) { @@ -1339,7 +2239,7 @@ ssize_t rrecv(int socket, void *buf, size_t len, int flags) rs->rbuf_bytes_avail += rsize; } - } while (left && (flags & MSG_WAITALL) && (rs->state & rs_connect_rd)); + } while (left && (flags & MSG_WAITALL) && (rs->state & rs_readable)); fastlock_release(&rs->rlock); return ret ? ret : len - left; @@ -1348,8 +2248,17 @@ ssize_t rrecv(int socket, void *buf, size_t len, int flags) ssize_t rrecvfrom(int socket, void *buf, size_t len, int flags, struct sockaddr *src_addr, socklen_t *addrlen) { + struct rsocket *rs; int ret; + rs = idm_at(&idm, socket); + if (rs->type == SOCK_DGRAM) { + fastlock_acquire(&rs->rlock); + ret = ds_recvfrom(rs, buf, len, flags, src_addr, addrlen); + fastlock_release(&rs->rlock); + return ret; + } + ret = rrecv(socket, buf, len, flags); if (ret > 0 && src_addr) rgetpeername(socket, src_addr, addrlen); @@ -1391,14 +2300,14 @@ static int rs_send_iomaps(struct rsocket *rs, int flags) struct rs_iomap iom; int ret; - fastlock_acquire(&rs->iomap_lock); + fastlock_acquire(&rs->map_lock); while (!dlist_empty(&rs->iomap_queue)) { if (!rs_can_send(rs)) { ret = rs_get_comp(rs, rs_nonblocking(rs, flags), rs_conn_can_send); if (ret) break; - if (!(rs->state & rs_connect_wr)) { + if (!(rs->state & rs_writable)) { ret = ERR(ECONNRESET); break; } @@ -1447,10 +2356,92 @@ static int rs_send_iomaps(struct rsocket *rs, int flags) } rs->iomap_pending = !dlist_empty(&rs->iomap_queue); - fastlock_release(&rs->iomap_lock); + fastlock_release(&rs->map_lock); return ret; } +static ssize_t ds_sendv_udp(struct rsocket *rs, const struct iovec *iov, + int iovcnt, int flags, uint8_t op) +{ + struct ds_udp_header hdr; + struct msghdr msg; + struct iovec miov[9]; /* slot 0 holds the header, plus up to 8 user iovecs */ + ssize_t ret; + + if (iovcnt > 8) + return ERR(ENOTSUP); + + hdr.tag = htonl(DS_UDP_TAG); + hdr.version = rs->conn_dest->qp->hdr.version; + hdr.op = op; + hdr.reserved = 0; + hdr.qpn = htonl(rs->conn_dest->qp->cm_id->qp->qp_num & 0xFFFFFF); + if 
(rs->conn_dest->qp->hdr.version == 4) { + hdr.length = DS_UDP_IPV4_HDR_LEN; + hdr.addr.ipv4 = rs->conn_dest->qp->hdr.addr.ipv4; + } else { + hdr.length = DS_UDP_IPV6_HDR_LEN; + memcpy(hdr.addr.ipv6, &rs->conn_dest->qp->hdr.addr.ipv6, 16); + } + + miov[0].iov_base = &hdr; + miov[0].iov_len = hdr.length; + if (iov && iovcnt) + memcpy(&miov[1], iov, sizeof *iov * iovcnt); + + memset(&msg, 0, sizeof msg); + msg.msg_name = &rs->conn_dest->addr; + msg.msg_namelen = ucma_addrlen(&rs->conn_dest->addr.sa); + msg.msg_iov = miov; + msg.msg_iovlen = iovcnt + 1; + ret = sendmsg(rs->udp_sock, &msg, flags); + return ret > 0 ? ret - hdr.length : ret; +} + +static ssize_t ds_send_udp(struct rsocket *rs, const void *buf, size_t len, + int flags, uint8_t op) +{ + struct iovec iov; + if (buf && len) { + iov.iov_base = (void *) buf; + iov.iov_len = len; + return ds_sendv_udp(rs, &iov, 1, flags, op); + } else { + return ds_sendv_udp(rs, NULL, 0, flags, op); + } +} + +static ssize_t dsend(struct rsocket *rs, const void *buf, size_t len, int flags) +{ + struct ds_smsg *msg; + struct ibv_sge sge; + uint64_t offset; + int ret = 0; + + if (!rs->conn_dest->ah) + return ds_send_udp(rs, buf, len, flags, RS_OP_DATA); + + if (!ds_can_send(rs)) { + ret = ds_get_comp(rs, rs_nonblocking(rs, flags), ds_can_send); + if (ret) + return ret; + } + + msg = rs->smsg_free; + rs->smsg_free = msg->next; + rs->sqe_avail--; + + memcpy((void *) msg, &rs->conn_dest->qp->hdr, rs->conn_dest->qp->hdr.length); + memcpy((void *) msg + rs->conn_dest->qp->hdr.length, buf, len); + sge.addr = (uintptr_t) msg; + sge.length = rs->conn_dest->qp->hdr.length + len; + sge.lkey = rs->conn_dest->qp->smr->lkey; + offset = (uint8_t *) msg - rs->sbuf; + + ret = ds_post_send(rs, &sge, ds_send_wr_id(offset, sge.length)); + return ret ? ret : len; +} + /* * We overlap sending the data, by posting a small work request immediately, * then increasing the size of the send on each iteration. 
@@ -1464,6 +2455,13 @@ ssize_t rsend(int socket, const void *buf, size_t len, int flags) int ret = 0; rs = idm_at(&idm, socket); + if (rs->type == SOCK_DGRAM) { + fastlock_acquire(&rs->slock); + ret = dsend(rs, buf, len, flags); + fastlock_release(&rs->slock); + return ret; + } + if (rs->state & rs_opening) { ret = rs_do_connect(rs); if (ret) { @@ -1485,7 +2483,7 @@ ssize_t rsend(int socket, const void *buf, size_t len, int flags) rs_conn_can_send); if (ret) break; - if (!(rs->state & rs_connect_wr)) { + if (!(rs->state & rs_writable)) { ret = ERR(ECONNRESET); break; } @@ -1538,10 +2536,34 @@ out: ssize_t rsendto(int socket, const void *buf, size_t len, int flags, const struct sockaddr *dest_addr, socklen_t addrlen) { - if (dest_addr || addrlen) - return ERR(EISCONN); + struct rsocket *rs; + int ret; + + rs = idm_at(&idm, socket); + if (rs->type == SOCK_STREAM) { + if (dest_addr || addrlen) + return ERR(EISCONN); + + return rsend(socket, buf, len, flags); + } + + if (rs->state == rs_init) { + ret = ds_init_ep(rs); + if (ret) + return ret; + } + + fastlock_acquire(&rs->slock); + if (!rs->conn_dest || ds_compare_addr(dest_addr, &rs->conn_dest->addr)) { + ret = ds_get_dest(rs, dest_addr, addrlen, &rs->conn_dest); + if (ret) + goto out; + } - return rsend(socket, buf, len, flags); + ret = dsend(rs, buf, len, flags); +out: + fastlock_release(&rs->slock); + return ret; } static void rs_copy_iov(void *dst, const struct iovec **iov, size_t *offset, size_t len) @@ -1600,7 +2622,7 @@ static ssize_t rsendv(int socket, const struct iovec *iov, int iovcnt, int flags rs_conn_can_send); if (ret) break; - if (!(rs->state & rs_connect_wr)) { + if (!(rs->state & rs_writable)) { ret = ERR(ECONNRESET); break; } @@ -1653,7 +2675,7 @@ ssize_t rsendmsg(int socket, const struct msghdr *msg, int flags) if (msg->msg_control && msg->msg_controllen) return ERR(ENOTSUP); - return rsendv(socket, msg->msg_iov, (int) msg->msg_iovlen, msg->msg_flags); + return rsendv(socket, msg->msg_iov, (int) 
msg->msg_iovlen, flags); } ssize_t rwrite(int socket, const void *buf, size_t count) @@ -1690,8 +2712,8 @@ static int rs_poll_rs(struct rsocket *rs, int events, int ret; check_cq: - if ((rs->state & rs_connected) || (rs->state == rs_disconnected) || - (rs->state & rs_error)) { + if ((rs->type == SOCK_STREAM) && ((rs->state & rs_connected) || + (rs->state == rs_disconnected) || (rs->state & rs_error))) { rs_process_cq(rs, nonblock, test); revents = 0; @@ -1706,6 +2728,16 @@ check_cq: revents |= POLLERR; } + return revents; + } else if (rs->type == SOCK_DGRAM) { + ds_process_cqs(rs, nonblock, test); + + revents = 0; + if ((events & POLLIN) && rs_have_rdata(rs)) + revents |= POLLIN; + if ((events & POLLOUT) && ds_can_send(rs)) + revents |= POLLOUT; + return revents; } @@ -1766,18 +2798,20 @@ static int rs_poll_arm(struct pollfd *rfds, struct pollfd *fds, nfds_t nfds) if (fds[i].revents) return 1; - if (rs->state >= rs_connected) - rfds[i].fd = rs->cm_id->recv_cq_channel->fd; - else - rfds[i].fd = rs->cm_id->channel->fd; - + if (rs->type == SOCK_STREAM) { + if (rs->state >= rs_connected) + rfds[i].fd = rs->cm_id->recv_cq_channel->fd; + else + rfds[i].fd = rs->cm_id->channel->fd; + } else { + rfds[i].fd = rs->epfd; + } rfds[i].events = POLLIN; } else { rfds[i].fd = fds[i].fd; rfds[i].events = fds[i].events; } rfds[i].revents = 0; - } return 0; } @@ -1793,7 +2827,10 @@ static int rs_poll_events(struct pollfd *rfds, struct pollfd *fds, nfds_t nfds) rs = idm_lookup(&idm, fds[i].fd); if (rs) { - rs_get_cq_event(rs); + if (rs->type == SOCK_STREAM) + rs_get_cq_event(rs); + else + ds_get_cq_event(rs); fds[i].revents = rs_poll_rs(rs, fds[i].events, 1, rs_poll_all); } else { fds[i].revents = rfds[i].revents; @@ -1949,7 +2986,7 @@ int rshutdown(int socket, int how) rs = idm_at(&idm, socket); if (how == SHUT_RD) { - rs->state &= ~rs_connect_rd; + rs->state &= ~rs_readable; return 0; } @@ -1959,10 +2996,10 @@ int rshutdown(int socket, int how) if (rs->state & rs_connected) { if 
(how == SHUT_RDWR) { ctrl = RS_CTRL_DISCONNECT; - rs->state &= ~(rs_connect_rd | rs_connect_wr); + rs->state &= ~(rs_readable | rs_writable); } else { - rs->state &= ~rs_connect_wr; - ctrl = (rs->state & rs_connect_rd) ? + rs->state &= ~rs_writable; + ctrl = (rs->state & rs_readable) ? RS_CTRL_SHUTDOWN : RS_CTRL_DISCONNECT; } if (!rs->ctrl_avail) { @@ -1987,13 +3024,32 @@ int rshutdown(int socket, int how) return 0; } +static void ds_shutdown(struct rsocket *rs) +{ + if (rs->svcs) + rs_modify_svcs(rs, 0); + + if (rs->fd_flags & O_NONBLOCK) + rs_set_nonblocking(rs, 0); + + rs->state &= ~(rs_readable | rs_writable); + ds_process_cqs(rs, 0, ds_all_sends_done); + + if (rs->fd_flags & O_NONBLOCK) + rs_set_nonblocking(rs, 1); +} + int rclose(int socket) { struct rsocket *rs; rs = idm_at(&idm, socket); - if (rs->state & rs_connected) - rshutdown(socket, SHUT_RDWR); + if (rs->type == SOCK_STREAM) { + if (rs->state & rs_connected) + rshutdown(socket, SHUT_RDWR); + } else { + ds_shutdown(rs); + } rs_free(rs); return 0; @@ -2018,8 +3074,12 @@ int rgetpeername(int socket, struct sockaddr *addr, socklen_t *addrlen) struct rsocket *rs; rs = idm_at(&idm, socket); - rs_copy_addr(addr, rdma_get_peer_addr(rs->cm_id), addrlen); - return 0; + if (rs->type == SOCK_STREAM) { + rs_copy_addr(addr, rdma_get_peer_addr(rs->cm_id), addrlen); + return 0; + } else { + return getpeername(rs->udp_sock, addr, addrlen); + } } int rgetsockname(int socket, struct sockaddr *addr, socklen_t *addrlen) @@ -2027,8 +3087,12 @@ int rgetsockname(int socket, struct sockaddr *addr, socklen_t *addrlen) struct rsocket *rs; rs = idm_at(&idm, socket); - rs_copy_addr(addr, rdma_get_local_addr(rs->cm_id), addrlen); - return 0; + if (rs->type == SOCK_STREAM) { + rs_copy_addr(addr, rdma_get_local_addr(rs->cm_id), addrlen); + return 0; + } else { + return getsockname(rs->udp_sock, addr, addrlen); + } } int rsetsockopt(int socket, int level, int optname, @@ -2040,22 +3104,31 @@ int rsetsockopt(int socket, int level, int 
optname, ret = ERR(ENOTSUP); rs = idm_at(&idm, socket); + if (rs->type == SOCK_DGRAM && level != SOL_RDMA) { + ret = setsockopt(rs->udp_sock, level, optname, optval, optlen); + if (ret) + return ret; + } + switch (level) { case SOL_SOCKET: opts = &rs->so_opts; switch (optname) { case SO_REUSEADDR: - ret = rdma_set_option(rs->cm_id, RDMA_OPTION_ID, - RDMA_OPTION_ID_REUSEADDR, - (void *) optval, optlen); - if (ret && ((errno == ENOSYS) || ((rs->state != rs_init) && - rs->cm_id->context && - (rs->cm_id->verbs->device->transport_type == IBV_TRANSPORT_IB)))) - ret = 0; + if (rs->type == SOCK_STREAM) { + ret = rdma_set_option(rs->cm_id, RDMA_OPTION_ID, + RDMA_OPTION_ID_REUSEADDR, + (void *) optval, optlen); + if (ret && ((errno == ENOSYS) || ((rs->state != rs_init) && + rs->cm_id->context && + (rs->cm_id->verbs->device->transport_type == IBV_TRANSPORT_IB)))) + ret = 0; + } opt_on = *(int *) optval; break; case SO_RCVBUF: - if (!rs->rbuf) + if ((rs->type == SOCK_STREAM && !rs->rbuf) || + (rs->type == SOCK_DGRAM && !rs->qp_list)) rs->rbuf_size = (*(uint32_t *) optval) << 1; ret = 0; break; @@ -2101,9 +3174,11 @@ int rsetsockopt(int socket, int level, int optname, opts = &rs->ipv6_opts; switch (optname) { case IPV6_V6ONLY: - ret = rdma_set_option(rs->cm_id, RDMA_OPTION_ID, - RDMA_OPTION_ID_AFONLY, - (void *) optval, optlen); + if (rs->type == SOCK_STREAM) { + ret = rdma_set_option(rs->cm_id, RDMA_OPTION_ID, + RDMA_OPTION_ID_AFONLY, + (void *) optval, optlen); + } opt_on = *(int *) optval; break; default: @@ -2315,7 +3390,7 @@ off_t riomap(int socket, void *buf, size_t len, int prot, int flags, off_t offse if (!rs->cm_id->pd || (prot & ~(PROT_WRITE | PROT_NONE))) return ERR(EINVAL); - fastlock_acquire(&rs->iomap_lock); + fastlock_acquire(&rs->map_lock); if (prot & PROT_WRITE) { iomr = rs_get_iomap_mr(rs); access |= IBV_ACCESS_REMOTE_WRITE; @@ -2349,7 +3424,7 @@ off_t riomap(int socket, void *buf, size_t len, int prot, int flags, off_t offse dlist_insert_tail(&iomr->entry, 
&rs->iomap_list); } out: - fastlock_release(&rs->iomap_lock); + fastlock_release(&rs->map_lock); return offset; } @@ -2361,7 +3436,7 @@ int riounmap(int socket, void *buf, size_t len) int ret = 0; rs = idm_at(&idm, socket); - fastlock_acquire(&rs->iomap_lock); + fastlock_acquire(&rs->map_lock); for (entry = rs->iomap_list.next; entry != &rs->iomap_list; entry = entry->next) { @@ -2382,7 +3457,7 @@ int riounmap(int socket, void *buf, size_t len) } ret = ERR(EINVAL); out: - fastlock_release(&rs->iomap_lock); + fastlock_release(&rs->map_lock); return ret; } @@ -2426,7 +3501,7 @@ size_t riowrite(int socket, const void *buf, size_t count, off_t offset, int fla rs_conn_can_send); if (ret) break; - if (!(rs->state & rs_connect_wr)) { + if (!(rs->state & rs_writable)) { ret = ERR(ECONNRESET); break; } @@ -2476,3 +3551,272 @@ out: return (ret && left == count) ? ret : count - left; } + +static int rs_svc_grow_sets(void) +{ + struct rsocket **rss; + struct pollfd *fds; + void *set; + + set = calloc(svc_size + 2, sizeof(*rss) + sizeof(*fds)); + if (!set) + return ENOMEM; + + svc_size += 2; + rss = set; + fds = set + sizeof(*rss) * svc_size; + if (svc_cnt) { + memcpy(rss, svc_rss, sizeof(*rss) * (svc_cnt + 1)); + memcpy(fds, svc_fds, sizeof(*fds) * (svc_cnt + 1)); + } + + free(svc_rss); + /* svc_fds points into the svc_rss allocation; it is not freed separately */ + svc_rss = rss; + svc_fds = fds; + return 0; +} + +/* + * Index 0 is reserved for the service's communication socket. 
+ */ +static int rs_svc_add_rs(struct rsocket *rs) +{ + int ret; + + if (svc_cnt >= svc_size - 1) { + ret = rs_svc_grow_sets(); + if (ret) + return ret; + } + + svc_rss[++svc_cnt] = rs; + svc_fds[svc_cnt].fd = rs->udp_sock; + svc_fds[svc_cnt].events = POLLIN; + svc_fds[svc_cnt].revents = 0; + return 0; +} + +static int rs_svc_rm_rs(struct rsocket *rs) +{ + int i; + + for (i = 1; i <= svc_cnt; i++) { + if (svc_rss[i] == rs) { + svc_rss[i] = svc_rss[svc_cnt]; + svc_fds[i] = svc_fds[svc_cnt]; + svc_cnt--; + return 0; + } + } + return EBADF; +} + +static void rs_svc_process_sock(void) +{ + struct rs_svc_msg msg; + + read(svc_sock[1], &msg, sizeof msg); + if (msg.svcs & RS_SVC_DGRAM) { + msg.status = rs_svc_add_rs(msg.rs); + } else if (!msg.svcs) { + msg.status = rs_svc_rm_rs(msg.rs); + } + + if (!msg.status) + msg.rs->svcs = msg.svcs; + write(svc_sock[1], &msg, sizeof msg); +} + +static uint8_t rs_svc_sgid_index(struct ds_dest *dest, union ibv_gid *sgid) +{ + union ibv_gid gid; + int i, ret; + + for (i = 0; i < 16; i++) { + ret = ibv_query_gid(dest->qp->cm_id->verbs, dest->qp->cm_id->port_num, + i, &gid); + if (!memcmp(sgid, &gid, sizeof gid)) + return i; + } + return 0; +} + +static uint8_t rs_svc_path_bits(struct ds_dest *dest) +{ + struct ibv_port_attr attr; + + if (!ibv_query_port(dest->qp->cm_id->verbs, dest->qp->cm_id->port_num, &attr)) + return (uint8_t) ((1 << attr.lmc) - 1); + return 0x7f; +} + +static void rs_svc_create_ah(struct rsocket *rs, struct ds_dest *dest, uint32_t qpn) +{ + union socket_addr saddr; + struct rdma_cm_id *id; + struct ibv_ah_attr attr; + int ret; + + if (dest->ah) { + fastlock_acquire(&rs->slock); + ibv_destroy_ah(dest->ah); + dest->ah = NULL; + fastlock_release(&rs->slock); + } + + ret = rdma_create_id(NULL, &id, NULL, dest->qp->cm_id->ps); + if (ret) + return; + + memcpy(&saddr, rdma_get_local_addr(dest->qp->cm_id), + ucma_addrlen(rdma_get_local_addr(dest->qp->cm_id))); + if (saddr.sa.sa_family == AF_INET) + saddr.sin.sin_port = 0; + 
else + saddr.sin6.sin6_port = 0; + ret = rdma_resolve_addr(id, &saddr.sa, &dest->addr.sa, 2000); + if (ret) + goto out; + + ret = rdma_resolve_route(id, 2000); + if (ret) + goto out; + + memset(&attr, 0, sizeof attr); + if (id->route.path_rec->hop_limit > 1) { + attr.is_global = 1; + attr.grh.dgid = id->route.path_rec->dgid; + attr.grh.flow_label = ntohl(id->route.path_rec->flow_label); + attr.grh.sgid_index = rs_svc_sgid_index(dest, &id->route.path_rec->sgid); + attr.grh.hop_limit = id->route.path_rec->hop_limit; + attr.grh.traffic_class = id->route.path_rec->traffic_class; + } + attr.dlid = ntohs(id->route.path_rec->dlid); + attr.sl = id->route.path_rec->sl; + attr.src_path_bits = id->route.path_rec->slid & rs_svc_path_bits(dest); + attr.static_rate = id->route.path_rec->rate; + attr.port_num = id->port_num; + + fastlock_acquire(&rs->slock); + dest->qpn = qpn; + dest->ah = ibv_create_ah(dest->qp->cm_id->pd, &attr); + fastlock_release(&rs->slock); +out: + rdma_destroy_id(id); +} + +static int rs_svc_valid_udp_hdr(struct ds_udp_header *udp_hdr, + union socket_addr *addr) +{ + return (udp_hdr->tag == ntohl(DS_UDP_TAG)) && + ((udp_hdr->version == 4 && addr->sa.sa_family == AF_INET && + udp_hdr->length == DS_UDP_IPV4_HDR_LEN) || + (udp_hdr->version == 6 && addr->sa.sa_family == AF_INET6 && + udp_hdr->length == DS_UDP_IPV6_HDR_LEN)); +} + +static void rs_svc_forward(struct rsocket *rs, void *buf, size_t len, + union socket_addr *src) +{ + struct ds_header hdr; + struct ds_smsg *msg; + struct ibv_sge sge; + uint64_t offset; + + if (!ds_can_send(rs)) { + if (ds_get_comp(rs, 0, ds_can_send)) + return; + } + + msg = rs->smsg_free; + rs->smsg_free = msg->next; + rs->sqe_avail--; + + ds_format_hdr(&hdr, src); + memcpy((void *) msg, &hdr, hdr.length); + memcpy((void *) msg + hdr.length, buf, len); + sge.addr = (uintptr_t) msg; + sge.length = hdr.length + len; + sge.lkey = rs->conn_dest->qp->smr->lkey; + offset = (uint8_t *) msg - rs->sbuf; + + ds_post_send(rs, &sge, 
ds_send_wr_id(offset, sge.length)); +} + +static void rs_svc_process_rs(struct rsocket *rs) +{ + struct ds_dest *dest, *cur_dest; + struct ds_udp_header *udp_hdr; + union socket_addr addr; + socklen_t addrlen = sizeof addr; + int len, ret; + + ret = recvfrom(rs->udp_sock, svc_buf, sizeof svc_buf, 0, &addr.sa, &addrlen); + if (ret < DS_UDP_IPV4_HDR_LEN) + return; + + udp_hdr = (struct ds_udp_header *) svc_buf; + if (!rs_svc_valid_udp_hdr(udp_hdr, &addr)) + return; + + len = ret - udp_hdr->length; + udp_hdr->tag = ntohl(udp_hdr->tag); + udp_hdr->qpn = ntohl(udp_hdr->qpn) & 0xFFFFFF; + ret = ds_get_dest(rs, &addr.sa, addrlen, &dest); + if (ret) + return; + + if (udp_hdr->op == RS_OP_DATA) { + fastlock_acquire(&rs->slock); + cur_dest = rs->conn_dest; + rs->conn_dest = dest; + ds_send_udp(rs, NULL, 0, 0, RS_OP_CTRL); + rs->conn_dest = cur_dest; + fastlock_release(&rs->slock); + } + + if (!dest->ah || (dest->qpn != udp_hdr->qpn)) + rs_svc_create_ah(rs, dest, udp_hdr->qpn); + + /* to do: handle when dest local ip address doesn't match udp ip */ + if (udp_hdr->op == RS_OP_DATA) { + fastlock_acquire(&rs->slock); + cur_dest = rs->conn_dest; + rs->conn_dest = &dest->qp->dest; + rs_svc_forward(rs, svc_buf + udp_hdr->length, len, &addr); + rs->conn_dest = cur_dest; + fastlock_release(&rs->slock); + } +} + +static void *rs_svc_run(void *arg) +{ + struct rs_svc_msg msg; + int i, ret; + + ret = rs_svc_grow_sets(); + if (ret) { + msg.status = ret; + write(svc_sock[1], &msg, sizeof msg); + return (void *) (uintptr_t) ret; + } + + svc_fds[0].fd = svc_sock[1]; + svc_fds[0].events = POLLIN; + do { + for (i = 0; i <= svc_cnt; i++) + svc_fds[i].revents = 0; + + poll(svc_fds, svc_cnt + 1, -1); + if (svc_fds[0].revents) + rs_svc_process_sock(); + + for (i = 1; i <= svc_cnt; i++) { + if (svc_fds[i].revents) + rs_svc_process_rs(svc_rss[i]); + } + } while (svc_cnt >= 1); + + return NULL; +}
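
Note on the service thread's socket set: rs_svc_add_rs() and rs_svc_rm_rs() keep live entries in slots 1..svc_cnt, with slot 0 reserved for the service's communication socket. Removal fills the freed slot with the last live entry, so that entry has to be copied while svc_cnt still indexes it. The following is a minimal standalone sketch of that bookkeeping, not code from this patch. The names svc_set, set_add, set_rm and SET_MAX are hypothetical, and a fixed-size int array stands in for the svc_rss/svc_fds pair.

```c
#include <assert.h>
#include <stddef.h>

#define SET_MAX 8	/* hypothetical capacity; the patch grows its arrays instead */

/*
 * Hypothetical stand-in for the svc_rss/svc_fds arrays: slot 0 is reserved,
 * live entries occupy indices 1..cnt.
 */
struct svc_set {
	int cnt;
	int ids[SET_MAX + 1];
};

/* Append an entry in the next free slot; returns 0 on success, -1 if full. */
static int set_add(struct svc_set *s, int id)
{
	assert(s->cnt >= 0 && s->cnt <= SET_MAX);
	if (s->cnt >= SET_MAX)
		return -1;

	s->ids[++s->cnt] = id;
	return 0;
}

/* Remove an entry by moving the last live entry into its slot. */
static int set_rm(struct svc_set *s, int id)
{
	int i;

	for (i = 1; i <= s->cnt; i++) {
		if (s->ids[i] == id) {
			/* Copy first: the last live entry sits at index cnt. */
			s->ids[i] = s->ids[s->cnt];
			s->cnt--;
			return 0;
		}
	}
	return -1;
}
```

Decrementing the count before the copy would read slot cnt - 1 instead, which drops the last entry and duplicates its neighbour (or copies the reserved slot 0 when only one entry is live).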