From ad3be3a08de1c984c85f2f44d953894f30d279f6 Mon Sep 17 00:00:00 2001 From: Sean Hefty Date: Fri, 9 Nov 2012 10:26:38 -0800 Subject: [PATCH] rsocket: Add datagram support Signed-off-by: Sean Hefty --- docs/rsocket | 37 +- src/cma.c | 13 +- src/cma.h | 2 + src/rsocket.c | 1609 ++++++++++++++++++++++++++++++++++++++++++++----- 4 files changed, 1518 insertions(+), 143 deletions(-) diff --git a/docs/rsocket b/docs/rsocket index 1484f65b..a6602084 100644 --- a/docs/rsocket +++ b/docs/rsocket @@ -1,7 +1,7 @@ -rsocket Protocol and Design Guide 9/10/2012 +rsocket Protocol and Design Guide 11/11/2012 -Overview --------- +Data Streaming (TCP) Overview +----------------------------- Rsockets is a protocol over RDMA that supports a socket-level API for applications. For details on the current state of the implementation, readers should refer to the rsocket man page. This @@ -189,3 +189,34 @@ registered remote data buffer. From host A's perspective, the transfer appears as a normal send/write operation, with the data stream redirected directly into the receiving application's buffer. + + + +Datagram Overview +----------------- +The rsocket API supports datagram sockets. Datagram support is handled through an +entirely different protocol and internal implementation. Unlike connected rsockets, +datagram rsockets are not necessarily bound to a network (IP) address. A datagram +socket may use any number of network (IP) addresses, including those which map to +different RDMA devices. As a result, a single datagram rsocket must support +using multiple RDMA devices and ports, and a datagram rsocket references a single +UDP socket, plus zero or more UD QPs. + +Rsockets uses headers inserted before user data sent over UDP sockets to resolve +remote UD QP numbers. When a user first attempts to send a datagram to a remote +address (IP and UDP port), rsockets will take the following steps: + +1. Store the destination address into a lookup table. +2. 
Resolve which local network address should be used when sending + to the specified destination. +3. Allocate a UD QP on the RDMA device associated with the local address. +4. Send the user's datagram to the remote UDP socket. + +A header is inserted before the user's datagram. The header specifies the +UD QP number associated with the local network address (IP and UDP port) of +the send. + +A service thread is used to process messages received on the UDP socket. This +thread updates the rsocket lookup tables with the remote QPN and path record +data. The service thread forwards data received on the UDP socket to an +rsocket QP. \ No newline at end of file diff --git a/src/cma.c b/src/cma.c index 388be617..0f589668 100755 --- a/src/cma.c +++ b/src/cma.c @@ -513,7 +513,7 @@ int rdma_destroy_id(struct rdma_cm_id *id) return 0; } -static int ucma_addrlen(struct sockaddr *addr) +int ucma_addrlen(struct sockaddr *addr) { if (!addr) return 0; @@ -2232,9 +2232,18 @@ void rdma_destroy_ep(struct rdma_cm_id *id) int ucma_max_qpsize(struct rdma_cm_id *id) { struct cma_id_private *id_priv; + int i, max_size = 0; id_priv = container_of(id, struct cma_id_private, id); - return id_priv->cma_dev->max_qpsize; + if (id && id_priv->cma_dev) { + max_size = id_priv->cma_dev->max_qpsize; + } else { + for (i = 0; i < cma_dev_cnt; i++) { + if (!max_size || max_size > cma_dev_array[i].max_qpsize) + max_size = cma_dev_array[i].max_qpsize; + } + } + return max_size; } uint16_t ucma_get_port(struct sockaddr *addr) diff --git a/src/cma.h b/src/cma.h index 0a0370ee..7135a612 100644 --- a/src/cma.h +++ b/src/cma.h @@ -145,10 +145,12 @@ typedef struct { volatile int val; } atomic_t; #define atomic_set(v, s) ((v)->val = s) uint16_t ucma_get_port(struct sockaddr *addr); +int ucma_addrlen(struct sockaddr *addr); void ucma_set_sid(enum rdma_port_space ps, struct sockaddr *addr, struct sockaddr_ib *sib); int ucma_max_qpsize(struct rdma_cm_id *id); int ucma_complete(struct rdma_cm_id *id); + static 
inline int ERR(int err) { errno = err; diff --git a/src/rsocket.c b/src/rsocket.c index a060f66a..9996d333 100644 --- a/src/rsocket.c +++ b/src/rsocket.c @@ -47,6 +47,8 @@ #include #include #include +#include +#include #include #include @@ -56,7 +58,7 @@ #define RS_OLAP_START_SIZE 2048 #define RS_MAX_TRANSFER 65536 -#define RS_SNDLOWAT 64 +#define RS_SNDLOWAT 2048 #define RS_QP_MAX_SIZE 0xFFFE #define RS_QP_CTRL_SIZE 4 #define RS_CONN_RETRIES 6 @@ -64,6 +66,28 @@ static struct index_map idm; static pthread_mutex_t mut = PTHREAD_MUTEX_INITIALIZER; +enum { + RS_SVC_INSERT, + RS_SVC_REMOVE +}; + +struct rsocket; + +struct rs_svc_msg { + uint32_t op; + uint32_t status; + struct rsocket *rs; +}; + +static pthread_t svc_id; +static int svc_sock[2]; +static int svc_cnt; +static int svc_size; +static struct rsocket **svc_rss; +static struct pollfd *svc_fds; +static uint8_t svc_buf[RS_SNDLOWAT]; +static void *rs_svc_run(void *arg); + static uint16_t def_iomap_size = 0; static uint16_t def_inline = 64; static uint16_t def_sqsize = 384; @@ -100,6 +124,14 @@ enum { #define rs_msg_set(op, data) ((op << 29) | (uint32_t) (data)) #define rs_msg_op(imm_data) (imm_data >> 29) #define rs_msg_data(imm_data) (imm_data & 0x1FFFFFFF) +#define RS_RECV_WR_ID (~((uint64_t) 0)) + +#define DS_WR_RECV 0xFFFFFFFF +#define ds_send_wr_id(offset, length) (((uint64_t) (offset)) << 32 | (uint64_t) length) +#define ds_recv_wr_id(offset) (((uint64_t) (offset)) << 32 | (uint64_t) DS_WR_RECV) +#define ds_wr_offset(wr_id) ((uint32_t) (wr_id >> 32)) +#define ds_wr_length(wr_id) ((uint32_t) wr_id) +#define ds_wr_is_recv(wr_id) (ds_wr_length(wr_id) == DS_WR_RECV) enum { RS_CTRL_DISCONNECT, @@ -111,6 +143,18 @@ struct rs_msg { uint32_t data; }; +struct ds_qp; + +struct ds_rmsg { + struct ds_qp *qp; + uint32_t offset; + uint32_t length; +}; + +struct ds_smsg { + struct ds_smsg *next; +}; + struct rs_sge { uint64_t addr; uint32_t key; @@ -145,8 +189,6 @@ struct rs_conn_data { struct rs_sge data_buf; }; 
-#define RS_RECV_WR_ID (~((uint64_t) 0)) - /* * rsocket states are ordered as passive, connecting, connected, disconnected. */ @@ -160,9 +202,9 @@ enum rs_state { rs_connecting = rs_opening | 0x0040, rs_accepting = rs_opening | 0x0080, rs_connected = 0x0100, - rs_connect_wr = 0x0200, - rs_connect_rd = 0x0400, - rs_connect_rdwr = rs_connected | rs_connect_rd | rs_connect_wr, + rs_writable = 0x0200, + rs_readable = 0x0400, + rs_connect_rdwr = rs_connected | rs_readable | rs_writable, rs_connect_error = 0x0800, rs_disconnected = 0x1000, rs_error = 0x2000, @@ -170,68 +212,248 @@ enum rs_state { #define RS_OPT_SWAP_SGL 1 -struct rsocket { +union socket_addr { + struct sockaddr sa; + struct sockaddr_in sin; + struct sockaddr_in6 sin6; +}; + +struct ds_header { + uint8_t version; + uint8_t length; + uint16_t port; + union { + uint32_t ipv4; + struct { + uint32_t flowinfo; + uint8_t addr[16]; + } ipv6; + } addr; +}; + +#define DS_IPV4_HDR_LEN 8 +#define DS_IPV6_HDR_LEN 24 + +struct ds_dest { + union socket_addr addr; /* must be first */ + struct ds_qp *qp; + struct ibv_ah *ah; + uint32_t qpn; +}; + +struct ds_qp { + dlist_entry list; + struct rsocket *rs; struct rdma_cm_id *cm_id; + struct ds_header hdr; + struct ds_dest dest; + + struct ibv_mr *smr; + struct ibv_mr *rmr; + uint8_t *rbuf; + + int cq_armed; +}; + +struct rsocket { + int type; + int index; fastlock_t slock; fastlock_t rlock; fastlock_t cq_lock; fastlock_t cq_wait_lock; - fastlock_t iomap_lock; + fastlock_t map_lock; /* acquire slock first if needed */ + + union { + /* data stream */ + struct { + struct rdma_cm_id *cm_id; + uint64_t tcp_opts; + + int ctrl_avail; + uint16_t sseq_no; + uint16_t sseq_comp; + uint16_t rseq_no; + uint16_t rseq_comp; + + int remote_sge; + struct rs_sge remote_sgl; + struct rs_sge remote_iomap; + + struct ibv_mr *target_mr; + int target_sge; + int target_iomap_size; + void *target_buffer_list; + volatile struct rs_sge *target_sgl; + struct rs_iomap *target_iomap; + + int 
rbuf_bytes_avail; + int rbuf_free_offset; + int rbuf_offset; + struct ibv_mr *rmr; + uint8_t *rbuf; + + int sbuf_bytes_avail; + struct ibv_mr *smr; + struct ibv_sge ssgl[2]; + }; + /* datagram */ + struct { + struct ds_qp *qp_list; + void *dest_map; + struct ds_dest *conn_dest; + + int udp_sock; + int epfd; + int rqe_avail; + struct ds_smsg *smsg_free; + }; + }; int opts; long fd_flags; uint64_t so_opts; - uint64_t tcp_opts; uint64_t ipv6_opts; int state; int cq_armed; int retries; int err; - int index; - int ctrl_avail; + int sqe_avail; - int sbuf_bytes_avail; - uint16_t sseq_no; - uint16_t sseq_comp; + uint32_t sbuf_size; uint16_t sq_size; uint16_t sq_inline; + uint32_t rbuf_size; uint16_t rq_size; - uint16_t rseq_no; - uint16_t rseq_comp; - int rbuf_bytes_avail; - int rbuf_free_offset; - int rbuf_offset; int rmsg_head; int rmsg_tail; - struct rs_msg *rmsg; - - int remote_sge; - struct rs_sge remote_sgl; - struct rs_sge remote_iomap; + union { + struct rs_msg *rmsg; + struct ds_rmsg *dmsg; + }; + uint8_t *sbuf; struct rs_iomap_mr *remote_iomappings; dlist_entry iomap_list; dlist_entry iomap_queue; int iomap_pending; +}; - struct ibv_mr *target_mr; - int target_sge; - int target_iomap_size; - void *target_buffer_list; - volatile struct rs_sge *target_sgl; - struct rs_iomap *target_iomap; +#define DS_UDP_TAG 0x55555555 - uint32_t rbuf_size; - struct ibv_mr *rmr; - uint8_t *rbuf; - - uint32_t sbuf_size; - struct ibv_mr *smr; - struct ibv_sge ssgl[2]; - uint8_t *sbuf; +struct ds_udp_header { + uint32_t tag; + uint8_t version; + uint8_t op; + uint8_t length; + uint8_t reserved; + uint32_t qpn; /* lower 8-bits reserved */ + union { + uint32_t ipv4; + uint8_t ipv6[16]; + } addr; }; +#define DS_UDP_IPV4_HDR_LEN 16 +#define DS_UDP_IPV6_HDR_LEN 28 + +#define ds_next_qp(qp) container_of((qp)->list.next, struct ds_qp, list) + +static void ds_insert_qp(struct rsocket *rs, struct ds_qp *qp) +{ + if (!rs->qp_list) + dlist_init(&qp->list); + else + dlist_insert_head(&qp->list, 
&rs->qp_list->list); + rs->qp_list = qp; +} + +static void ds_remove_qp(struct rsocket *rs, struct ds_qp *qp) +{ + if (qp->list.next != &qp->list) { + rs->qp_list = ds_next_qp(qp); + dlist_remove(&qp->list); + } else { + rs->qp_list = NULL; + } +} + +static int rs_add_to_svc(struct rsocket *rs) +{ + struct rs_svc_msg msg; + int ret; + + pthread_mutex_lock(&mut); + if (!svc_cnt) { + ret = socketpair(AF_UNIX, SOCK_STREAM, 0, svc_sock); /* Linux socketpair() does not support AF_INET */ + if (ret) + goto err1; + + ret = pthread_create(&svc_id, NULL, rs_svc_run, NULL); + if (ret) { + ret = ERR(ret); + goto err2; + } + } + + msg.op = RS_SVC_INSERT; + msg.status = EINVAL; + msg.rs = rs; + write(svc_sock[0], &msg, sizeof msg); + read(svc_sock[0], &msg, sizeof msg); + ret = rdma_seterrno(msg.status); /* a zero status must return 0 */ + if (ret && !svc_cnt) + goto err3; + + pthread_mutex_unlock(&mut); + return ret; + +err3: + pthread_join(svc_id, NULL); +err2: + close(svc_sock[0]); + close(svc_sock[1]); +err1: + pthread_mutex_unlock(&mut); + return ret; +} + +static int rs_remove_from_svc(struct rsocket *rs) +{ + struct rs_svc_msg msg; + int ret; + + pthread_mutex_lock(&mut); + msg.op = RS_SVC_REMOVE; + msg.status = EINVAL; + msg.rs = rs; + write(svc_sock[0], &msg, sizeof msg); + read(svc_sock[0], &msg, sizeof msg); + ret = rdma_seterrno(msg.status); /* a zero status must return 0 */ + if (!svc_cnt) { + pthread_join(svc_id, NULL); + close(svc_sock[0]); + close(svc_sock[1]); + } + + pthread_mutex_unlock(&mut); + return ret; +} + +static int ds_compare_addr(const void *dst1, const void *dst2) +{ + const struct sockaddr *sa1, *sa2; + size_t len; + + sa1 = (const struct sockaddr *) dst1; + sa2 = (const struct sockaddr *) dst2; + + len = (sa1->sa_family == AF_INET6 && sa2->sa_family == AF_INET6) ? + sizeof(struct sockaddr_in6) : sizeof(struct sockaddr); + return memcmp(dst1, dst2, len); +} + static int rs_value_to_scale(int value, int bits) { return value <= (1 << (bits - 1)) ? 
@@ -307,10 +529,10 @@ out: pthread_mutex_unlock(&mut); } -static int rs_insert(struct rsocket *rs) +static int rs_insert(struct rsocket *rs, int index) { pthread_mutex_lock(&mut); - rs->index = idm_set(&idm, rs->cm_id->channel->fd, rs); + rs->index = idm_set(&idm, index, rs); pthread_mutex_unlock(&mut); return rs->index; } @@ -322,7 +544,7 @@ static void rs_remove(struct rsocket *rs) pthread_mutex_unlock(&mut); } -static struct rsocket *rs_alloc(struct rsocket *inherited_rs) +static struct rsocket *rs_alloc(struct rsocket *inherited_rs, int type) { struct rsocket *rs; @@ -330,7 +552,11 @@ static struct rsocket *rs_alloc(struct rsocket *inherited_rs) if (!rs) return NULL; + rs->type = type; rs->index = -1; + rs->udp_sock = -1; + rs->epfd = -1; + if (inherited_rs) { rs->sbuf_size = inherited_rs->sbuf_size; rs->rbuf_size = inherited_rs->rbuf_size; @@ -352,7 +578,7 @@ static struct rsocket *rs_alloc(struct rsocket *inherited_rs) fastlock_init(&rs->rlock); fastlock_init(&rs->cq_lock); fastlock_init(&rs->cq_wait_lock); - fastlock_init(&rs->iomap_lock); + fastlock_init(&rs->map_lock); dlist_init(&rs->iomap_list); dlist_init(&rs->iomap_queue); return rs; @@ -360,13 +586,27 @@ static struct rsocket *rs_alloc(struct rsocket *inherited_rs) static int rs_set_nonblocking(struct rsocket *rs, long arg) { + struct ds_qp *qp; int ret = 0; - if (rs->cm_id->recv_cq_channel) - ret = fcntl(rs->cm_id->recv_cq_channel->fd, F_SETFL, arg); + if (rs->type == SOCK_STREAM) { + if (rs->cm_id->recv_cq_channel) + ret = fcntl(rs->cm_id->recv_cq_channel->fd, F_SETFL, arg); - if (!ret && rs->state < rs_connected) - ret = fcntl(rs->cm_id->channel->fd, F_SETFL, arg); + if (!ret && rs->state < rs_connected) + ret = fcntl(rs->cm_id->channel->fd, F_SETFL, arg); + } else { + ret = fcntl(rs->epfd, F_SETFL, arg); + + if (!ret && rs->qp_list) { + qp = rs->qp_list; + do { + ret = fcntl(qp->cm_id->recv_cq_channel->fd, + F_SETFL, arg); + qp = ds_next_qp(qp); + } while (qp != rs->qp_list && !ret); + } + } 
return ret; } @@ -390,17 +630,39 @@ static void rs_set_qp_size(struct rsocket *rs) rs->rq_size = 2; } +static void ds_set_qp_size(struct rsocket *rs) +{ + uint16_t max_size; + + max_size = min(ucma_max_qpsize(NULL), RS_QP_MAX_SIZE); + + if (rs->sq_size > max_size) + rs->sq_size = max_size; + if (rs->rq_size > max_size) + rs->rq_size = max_size; + + if (rs->rq_size > (rs->rbuf_size / RS_SNDLOWAT)) + rs->rq_size = rs->rbuf_size / RS_SNDLOWAT; + else + rs->rbuf_size = rs->rq_size * RS_SNDLOWAT; + + if (rs->sq_size > (rs->sbuf_size / RS_SNDLOWAT)) + rs->sq_size = rs->sbuf_size / RS_SNDLOWAT; + else + rs->sbuf_size = rs->sq_size * RS_SNDLOWAT; +} + static int rs_init_bufs(struct rsocket *rs) { size_t len; rs->rmsg = calloc(rs->rq_size + 1, sizeof(*rs->rmsg)); if (!rs->rmsg) - return -1; + return ERR(ENOMEM); rs->sbuf = calloc(rs->sbuf_size, sizeof(*rs->sbuf)); if (!rs->sbuf) - return -1; + return ERR(ENOMEM); rs->smr = rdma_reg_msgs(rs->cm_id, rs->sbuf, rs->sbuf_size); if (!rs->smr) @@ -410,7 +672,7 @@ static int rs_init_bufs(struct rsocket *rs) sizeof(*rs->target_iomap) * rs->target_iomap_size; rs->target_buffer_list = malloc(len); if (!rs->target_buffer_list) - return -1; + return ERR(ENOMEM); rs->target_mr = rdma_reg_write(rs->cm_id, rs->target_buffer_list, len); if (!rs->target_mr) @@ -423,7 +685,7 @@ static int rs_init_bufs(struct rsocket *rs) rs->rbuf = calloc(rs->rbuf_size, sizeof(*rs->rbuf)); if (!rs->rbuf) - return -1; + return ERR(ENOMEM); rs->rmr = rdma_reg_write(rs->cm_id, rs->rbuf, rs->rbuf_size); if (!rs->rmr) @@ -440,15 +702,32 @@ static int rs_init_bufs(struct rsocket *rs) return 0; } -static int rs_create_cq(struct rsocket *rs) +static int ds_init_bufs(struct ds_qp *qp) { - rs->cm_id->recv_cq_channel = ibv_create_comp_channel(rs->cm_id->verbs); - if (!rs->cm_id->recv_cq_channel) + qp->rbuf = calloc(qp->rs->rbuf_size, sizeof(*qp->rbuf)); + if (!qp->rbuf) + return ERR(ENOMEM); + + qp->smr = rdma_reg_msgs(qp->cm_id, qp->rs->sbuf, qp->rs->sbuf_size); + if 
(!qp->smr) return -1; - rs->cm_id->recv_cq = ibv_create_cq(rs->cm_id->verbs, rs->sq_size + rs->rq_size, - rs->cm_id, rs->cm_id->recv_cq_channel, 0); - if (!rs->cm_id->recv_cq) + qp->rmr = rdma_reg_msgs(qp->cm_id, qp->rbuf, qp->rs->rbuf_size); + if (!qp->rmr) + return -1; + + return 0; +} + +static int rs_create_cq(struct rsocket *rs, struct rdma_cm_id *cm_id) +{ + cm_id->recv_cq_channel = ibv_create_comp_channel(cm_id->verbs); + if (!cm_id->recv_cq_channel) + return -1; + + cm_id->recv_cq = ibv_create_cq(cm_id->verbs, rs->sq_size + rs->rq_size, + cm_id, cm_id->recv_cq_channel, 0); + if (!cm_id->recv_cq) goto err1; if (rs->fd_flags & O_NONBLOCK) { @@ -456,21 +735,20 @@ static int rs_create_cq(struct rsocket *rs) goto err2; } - rs->cm_id->send_cq_channel = rs->cm_id->recv_cq_channel; - rs->cm_id->send_cq = rs->cm_id->recv_cq; + cm_id->send_cq_channel = cm_id->recv_cq_channel; + cm_id->send_cq = cm_id->recv_cq; return 0; err2: - ibv_destroy_cq(rs->cm_id->recv_cq); - rs->cm_id->recv_cq = NULL; + ibv_destroy_cq(cm_id->recv_cq); + cm_id->recv_cq = NULL; err1: - ibv_destroy_comp_channel(rs->cm_id->recv_cq_channel); - rs->cm_id->recv_cq_channel = NULL; + ibv_destroy_comp_channel(cm_id->recv_cq_channel); + cm_id->recv_cq_channel = NULL; return -1; } -static inline int -rs_post_recv(struct rsocket *rs) +static inline int rs_post_recv(struct rsocket *rs) { struct ibv_recv_wr wr, *bad; @@ -482,6 +760,23 @@ rs_post_recv(struct rsocket *rs) return rdma_seterrno(ibv_post_recv(rs->cm_id->qp, &wr, &bad)); } +static inline int ds_post_recv(struct rsocket *rs, struct ds_qp *qp, void *buf) +{ + struct ibv_recv_wr wr, *bad; + struct ibv_sge sge; + + sge.addr = (uintptr_t) buf; + sge.length = RS_SNDLOWAT; + sge.lkey = qp->rmr->lkey; + + wr.wr_id = ds_recv_wr_id((uint32_t) ((uint8_t *) buf - rs->rbuf)); + wr.next = NULL; + wr.sg_list = &sge; + wr.num_sge = 1; + + return rdma_seterrno(ibv_post_recv(qp->cm_id->qp, &wr, &bad)); +} + static int rs_create_ep(struct rsocket *rs) { struct 
ibv_qp_init_attr qp_attr; @@ -492,7 +787,7 @@ static int rs_create_ep(struct rsocket *rs) if (ret) return ret; - ret = rs_create_cq(rs); + ret = rs_create_cq(rs, rs->cm_id); if (ret) return ret; @@ -549,8 +844,74 @@ static void rs_free_iomappings(struct rsocket *rs) } } +static void ds_free_qp(struct ds_qp *qp) +{ + if (qp->smr) + rdma_dereg_mr(qp->smr); + + if (qp->rbuf) { + if (qp->rmr) + rdma_dereg_mr(qp->rmr); + free(qp->rbuf); + } + + if (qp->cm_id) { + if (qp->cm_id->qp) { + tdelete(&qp->dest.addr, &qp->rs->dest_map, ds_compare_addr); + epoll_ctl(qp->rs->epfd, EPOLL_CTL_DEL, + qp->cm_id->recv_cq_channel->fd, NULL); + rdma_destroy_qp(qp->cm_id); + } + rdma_destroy_id(qp->cm_id); + } + + free(qp); +} + +static void ds_free(struct rsocket *rs) +{ + if (rs->state & (rs_readable | rs_writable)) + rs_remove_from_svc(rs); + + if (rs->udp_sock >= 0) + close(rs->udp_sock); + + if (rs->index >= 0) + rs_remove(rs); + + if (rs->dmsg) + free(rs->dmsg); + + if (rs->smsg_free) + free(rs->smsg_free); + + while (rs->qp_list) { + ds_remove_qp(rs, rs->qp_list); + ds_free_qp(rs->qp_list); + } + + if (rs->epfd >= 0) + close(rs->epfd); + + if (rs->sbuf) + free(rs->sbuf); + + tdestroy(rs->dest_map, free); + fastlock_destroy(&rs->map_lock); + fastlock_destroy(&rs->cq_wait_lock); + fastlock_destroy(&rs->cq_lock); + fastlock_destroy(&rs->rlock); + fastlock_destroy(&rs->slock); + free(rs); +} + static void rs_free(struct rsocket *rs) { + if (rs->type == SOCK_DGRAM) { + ds_free(rs); + return; + } + if (rs->index >= 0) rs_remove(rs); @@ -582,7 +943,7 @@ static void rs_free(struct rsocket *rs) rdma_destroy_id(rs->cm_id); } - fastlock_destroy(&rs->iomap_lock); + fastlock_destroy(&rs->map_lock); fastlock_destroy(&rs->cq_wait_lock); fastlock_destroy(&rs->cq_lock); fastlock_destroy(&rs->rlock); @@ -636,29 +997,54 @@ static void rs_save_conn_data(struct rsocket *rs, struct rs_conn_data *conn) rs->sseq_comp = ntohs(conn->credits); } +static int ds_init(struct rsocket *rs, int domain) +{ + 
rs->udp_sock = socket(domain, SOCK_DGRAM, 0); + if (rs->udp_sock < 0) + return rs->udp_sock; + + rs->epfd = epoll_create(2); /* size must be > 0 or epoll_create() fails with EINVAL */ + if (rs->epfd < 0) + return rs->epfd; + + return 0; +} + int rsocket(int domain, int type, int protocol) { struct rsocket *rs; - int ret; + int index, ret; if ((domain != PF_INET && domain != PF_INET6) || - (type != SOCK_STREAM) || (protocol && protocol != IPPROTO_TCP)) + ((type != SOCK_STREAM) && (type != SOCK_DGRAM)) || + (type == SOCK_STREAM && protocol && protocol != IPPROTO_TCP) || + (type == SOCK_DGRAM && protocol && protocol != IPPROTO_UDP)) return ERR(ENOTSUP); rs_configure(); - rs = rs_alloc(NULL); + rs = rs_alloc(NULL, type); if (!rs) return ERR(ENOMEM); - ret = rdma_create_id(NULL, &rs->cm_id, rs, RDMA_PS_TCP); - if (ret) - goto err; + if (type == SOCK_STREAM) { + ret = rdma_create_id(NULL, &rs->cm_id, rs, RDMA_PS_TCP); + if (ret) + goto err; + + rs->cm_id->route.addr.src_addr.sa_family = domain; + index = rs->cm_id->channel->fd; + } else { + ret = ds_init(rs, domain); + if (ret) + goto err; + + index = rs->udp_sock; + } - ret = rs_insert(rs); + ret = rs_insert(rs, index); if (ret < 0) goto err; - rs->cm_id->route.addr.src_addr.sa_family = domain; return rs->index; err: @@ -672,9 +1058,18 @@ int rbind(int socket, const struct sockaddr *addr, socklen_t addrlen) int ret; rs = idm_at(&idm, socket); - ret = rdma_bind_addr(rs->cm_id, (struct sockaddr *) addr); - if (!ret) - rs->state = rs_bound; + if (rs->type == SOCK_STREAM) { + ret = rdma_bind_addr(rs->cm_id, (struct sockaddr *) addr); + if (!ret) + rs->state = rs_bound; + } else { + ret = bind(rs->udp_sock, addr, addrlen); + if (!ret) { + ret = rs_add_to_svc(rs); + if (!ret) + rs->state = rs_readable | rs_writable; + } + } return ret; } @@ -710,7 +1105,7 @@ int raccept(int socket, struct sockaddr *addr, socklen_t *addrlen) int ret; rs = idm_at(&idm, socket); - new_rs = rs_alloc(rs); + new_rs = rs_alloc(rs, rs->type); if (!new_rs) return ERR(ENOMEM); @@ -718,7 +1113,7 @@ int 
raccept(int socket, struct sockaddr *addr, socklen_t *addrlen) if (ret) goto err; - ret = rs_insert(new_rs); + ret = rs_insert(new_rs, new_rs->cm_id->channel->fd); /* index by the accepted id's fd, not the listener's */ if (ret < 0) goto err; @@ -855,13 +1250,268 @@ connected: return ret; } +static int ds_init_ep(struct rsocket *rs) +{ + struct ds_smsg *msg; + int i, ret; + + ds_set_qp_size(rs); + + rs->sbuf = calloc(rs->sq_size, RS_SNDLOWAT); + if (!rs->sbuf) + return ERR(ENOMEM); + + rs->dmsg = calloc(rs->rq_size + 1, sizeof(*rs->dmsg)); + if (!rs->dmsg) + return ERR(ENOMEM); + + rs->sbuf_bytes_avail = rs->sbuf_size; + rs->sqe_avail = rs->sq_size; + rs->rqe_avail = rs->rq_size; + + rs->smsg_free = (struct ds_smsg *) rs->sbuf; + msg = rs->smsg_free; + for (i = 0; i < rs->sq_size - 1; i++) { + msg->next = (void *) msg + RS_SNDLOWAT; /* link to the adjacent slot; msg already advances each pass */ + msg = msg->next; + } + msg->next = NULL; + + ret = rs_add_to_svc(rs); + if (ret) + return ret; + + rs->state = rs_readable | rs_writable; + return 0; +} + +static int rs_any_addr(const union socket_addr *addr) +{ + if (addr->sa.sa_family == AF_INET) { + return (addr->sin.sin_addr.s_addr == INADDR_ANY || + addr->sin.sin_addr.s_addr == htonl(INADDR_LOOPBACK)); /* s_addr is in network byte order */ + } else { + return (!memcmp(&addr->sin6.sin6_addr, &in6addr_any, 16) || + !memcmp(&addr->sin6.sin6_addr, &in6addr_loopback, 16)); + } +} + +static int ds_get_src_addr(struct rsocket *rs, + const struct sockaddr *dest_addr, socklen_t dest_len, + union socket_addr *src_addr, socklen_t *src_len) +{ + int sock, ret; + uint16_t port; + + *src_len = sizeof *src_addr; /* size of the address union, not of the pointer */ + ret = getsockname(rs->udp_sock, &src_addr->sa, src_len); + if (ret || !rs_any_addr(src_addr)) + return ret; + + port = src_addr->sin.sin_port; + sock = socket(dest_addr->sa_family, SOCK_DGRAM, 0); + if (sock < 0) + return sock; + + ret = connect(sock, dest_addr, dest_len); + if (ret) + goto out; + + *src_len = sizeof *src_addr; /* size of the address union, not of the pointer */ + ret = getsockname(sock, &src_addr->sa, src_len); + src_addr->sin.sin_port = port; +out: + close(sock); + return ret; +} + +static void ds_format_hdr(struct 
ds_header *hdr, union socket_addr *addr) +{ + if (addr->sa.sa_family == AF_INET) { + hdr->version = 4; + hdr->length = DS_IPV4_HDR_LEN; + hdr->port = addr->sin.sin_port; + hdr->addr.ipv4 = addr->sin.sin_addr.s_addr; + } else { + hdr->version = 6; + hdr->length = DS_IPV6_HDR_LEN; + hdr->port = addr->sin6.sin6_port; + hdr->addr.ipv6.flowinfo= addr->sin6.sin6_flowinfo; + memcpy(&hdr->addr.ipv6.addr, &addr->sin6.sin6_addr, 16); + } +} + +static int ds_add_qp_dest(struct ds_qp *qp, union socket_addr *addr, + socklen_t addrlen) +{ + struct ibv_port_attr port_attr; + struct ibv_ah_attr attr; + int ret; + + memcpy(&qp->dest.addr, addr, addrlen); + qp->dest.qp = qp; + qp->dest.qpn = qp->cm_id->qp->qp_num; + + ret = ibv_query_port(qp->cm_id->verbs, qp->cm_id->port_num, &port_attr); + if (ret) + return ret; + + memset(&attr, 0, sizeof attr); + attr.dlid = port_attr.lid; + attr.port_num = qp->cm_id->port_num; + qp->dest.ah = ibv_create_ah(qp->cm_id->pd, &attr); + if (!qp->dest.ah) + return ERR(ENOMEM); + + tsearch(&qp->dest.addr, &qp->rs->dest_map, ds_compare_addr); + return 0; +} + +static int ds_create_qp(struct rsocket *rs, union socket_addr *src_addr, + socklen_t addrlen, struct ds_qp **qp) +{ + struct ibv_qp_init_attr qp_attr; + struct epoll_event event; + int i, ret; + + *qp = calloc(1, sizeof(struct ds_qp)); + if (!*qp) + return ERR(ENOMEM); + + (*qp)->rs = rs; + ret = rdma_create_id(NULL, &(*qp)->cm_id, *qp, RDMA_PS_UDP); + if (ret) + goto err; + + ds_format_hdr(&(*qp)->hdr, src_addr); + ret = rdma_bind_addr((*qp)->cm_id, &src_addr->sa); + if (ret) + goto err; + + ret = ds_init_bufs(*qp); + if (ret) + goto err; + + ret = rs_create_cq(rs, (*qp)->cm_id); + if (ret) + goto err; + + memset(&qp_attr, 0, sizeof qp_attr); + qp_attr.qp_context = *qp; /* the ds_qp itself, not the address of the caller's pointer */ + qp_attr.send_cq = (*qp)->cm_id->send_cq; /* CQ was created on this QP's cm_id; rs->cm_id aliases qp_list for datagrams */ + qp_attr.recv_cq = (*qp)->cm_id->recv_cq; + qp_attr.qp_type = IBV_QPT_UD; + qp_attr.sq_sig_all = 1; + qp_attr.cap.max_send_wr = rs->sq_size; + qp_attr.cap.max_recv_wr = rs->rq_size; + 
qp_attr.cap.max_send_sge = 2; + qp_attr.cap.max_recv_sge = 1; + qp_attr.cap.max_inline_data = rs->sq_inline; + ret = rdma_create_qp((*qp)->cm_id, NULL, &qp_attr); + if (ret) + goto err; + + ret = ds_add_qp_dest(*qp, src_addr, addrlen); + if (ret) + goto err; + + event.events = EPOLLIN; + event.data.ptr = *qp; + ret = epoll_ctl(rs->epfd, EPOLL_CTL_ADD, + (*qp)->cm_id->recv_cq_channel->fd, &event); + if (ret) + goto err; + + for (i = 0; i < rs->rq_size; i++) { + ret = ds_post_recv(rs, *qp, (*qp)->rbuf + i * RS_SNDLOWAT); + if (ret) + goto err; + } + + ds_insert_qp(rs, *qp); + return 0; +err: + ds_free_qp(*qp); + return ret; +} + +static int ds_get_qp(struct rsocket *rs, union socket_addr *src_addr, + socklen_t addrlen, struct ds_qp **qp) +{ + if (rs->qp_list) { + *qp = rs->qp_list; + do { + if (!ds_compare_addr(rdma_get_local_addr((*qp)->cm_id), + src_addr)) + return 0; + + *qp = ds_next_qp(*qp); + } while (*qp != rs->qp_list); + } + + return ds_create_qp(rs, src_addr, addrlen, qp); +} + +static int ds_get_dest(struct rsocket *rs, const struct sockaddr *addr, + socklen_t addrlen, struct ds_dest **dest) +{ + union socket_addr src_addr; + socklen_t src_len; + struct ds_qp *qp; + int ret = 0; + + fastlock_acquire(&rs->map_lock); + dest = tfind(addr, &rs->dest_map, ds_compare_addr); + if (dest) + goto out; + + if (rs->state == rs_init) { + ret = ds_init_ep(rs); + if (ret) + goto out; + } + + ret = ds_get_src_addr(rs, addr, addrlen, &src_addr, &src_len); + if (ret) + goto out; + + ret = ds_get_qp(rs, &src_addr, src_len, &qp); + if (ret) + goto out; + + if ((addrlen != src_len) || memcmp(addr, &src_addr, addrlen)) { + *dest = calloc(1, sizeof(struct ds_dest)); + if (!*dest) { + ret = ERR(ENOMEM); + goto out; + } + + memcpy(&(*dest)->addr, addr, addrlen); + (*dest)->qp = qp; + tsearch(&(*dest)->addr, &rs->dest_map, ds_compare_addr); + } +out: + fastlock_release(&rs->map_lock); + return ret; +} + int rconnect(int socket, const struct sockaddr *addr, socklen_t addrlen) { 
struct rsocket *rs; + int ret; rs = idm_at(&idm, socket); - memcpy(&rs->cm_id->route.addr.dst_addr, addr, addrlen); - return rs_do_connect(rs); + if (rs->type == SOCK_STREAM) { + memcpy(&rs->cm_id->route.addr.dst_addr, addr, addrlen); + ret = rs_do_connect(rs); + } else { + fastlock_acquire(&rs->slock); + ret = connect(rs->udp_sock, addr, addrlen); + if (!ret) + ret = ds_get_dest(rs, addr, addrlen, &rs->conn_dest); + fastlock_release(&rs->slock); + } + return ret; } static int rs_post_write_msg(struct rsocket *rs, @@ -903,6 +1553,24 @@ static int rs_post_write(struct rsocket *rs, return rdma_seterrno(ibv_post_send(rs->cm_id->qp, &wr, &bad)); } +static int ds_post_send(struct rsocket *rs, struct ibv_sge *sge, + uint64_t wr_id) +{ + struct ibv_send_wr wr, *bad; + + wr.wr_id = wr_id; + wr.next = NULL; + wr.sg_list = sge; + wr.num_sge = 1; + wr.opcode = IBV_WR_SEND; + wr.send_flags = (sge->length <= rs->sq_inline) ? IBV_SEND_INLINE : 0; + wr.wr.ud.ah = rs->conn_dest->ah; + wr.wr.ud.remote_qpn = rs->conn_dest->qpn; + wr.wr.ud.remote_qkey = RDMA_UDP_QKEY; + + return rdma_seterrno(ibv_post_send(rs->conn_dest->qp->cm_id->qp, &wr, &bad)); +} + /* * Update target SGE before sending data. Otherwise the remote side may * update the entry before we do. 
@@ -1046,7 +1714,7 @@ static int rs_poll_cq(struct rsocket *rs) rs->state = rs_disconnected; return 0; } else if (rs_msg_data(imm_data) == RS_CTRL_SHUTDOWN) { - rs->state &= ~rs_connect_rd; + rs->state &= ~rs_readable; } break; case RS_OP_WRITE: @@ -1133,46 +1801,205 @@ static int rs_get_cq_event(struct rsocket *rs) */ static int rs_process_cq(struct rsocket *rs, int nonblock, int (*test)(struct rsocket *rs)) { - int ret; + int ret; + + fastlock_acquire(&rs->cq_lock); + do { + rs_update_credits(rs); + ret = rs_poll_cq(rs); + if (test(rs)) { + ret = 0; + break; + } else if (ret) { + break; + } else if (nonblock) { + ret = ERR(EWOULDBLOCK); + } else if (!rs->cq_armed) { /* request a CQ notification once, then loop to poll again before waiting */ + ibv_req_notify_cq(rs->cm_id->recv_cq, 0); + rs->cq_armed = 1; + } else { + rs_update_credits(rs); + fastlock_acquire(&rs->cq_wait_lock); /* cq_lock is dropped while waiting for the event and retaken afterwards */ + fastlock_release(&rs->cq_lock); + + ret = rs_get_cq_event(rs); + fastlock_release(&rs->cq_wait_lock); + fastlock_acquire(&rs->cq_lock); + } + } while (!ret); + + rs_update_credits(rs); + fastlock_release(&rs->cq_lock); + return ret; +} + +static int rs_get_comp(struct rsocket *rs, int nonblock, int (*test)(struct rsocket *rs)) +{ /* Poll without blocking until the elapsed microseconds exceed polling_time, then fall back to a blocking rs_process_cq call. Returns early on success, on a nonblocking request, or on any error other than EWOULDBLOCK. */ + struct timeval s, e; + uint32_t poll_time = 0; + int ret; + + do { + ret = rs_process_cq(rs, 1, test); + if (!ret || nonblock || errno != EWOULDBLOCK) + return ret; + + if (!poll_time) + gettimeofday(&s, NULL); + + gettimeofday(&e, NULL); + poll_time = (e.tv_sec - s.tv_sec) * 1000000 + + (e.tv_usec - s.tv_usec) + 1; + } while (poll_time <= polling_time); + + ret = rs_process_cq(rs, 0, test); + return ret; +} + +static int ds_valid_recv(void *buf, uint32_t len) +{ /* True when len covers a struct ds_header and the version/length pair in the header matches the IPv4 or the IPv6 layout. */ + struct ds_header *hdr = (struct ds_header *) buf; + return ((len >= sizeof(*hdr)) && + ((hdr->version == 4 && hdr->length == DS_IPV4_HDR_LEN) || + (hdr->version == 6 && hdr->length == DS_IPV6_HDR_LEN))); +} + +/* + * Poll all CQs associated with a datagram rsocket. We need to drop any + * received messages that we do not have room to store. 
To limit drops, + * we only poll if we have room to store the receive or we need a send + * buffer. To ensure fairness, we poll the CQs round robin, remembering + * where we left off. + */ +static void ds_poll_cqs(struct rsocket *rs) +{ + struct ds_qp *qp; + struct ds_smsg *smsg; + struct ds_rmsg *rmsg; + struct ibv_wc wc; + int ret, cnt; + + if (!(qp = rs->qp_list)) return; /* no QP has been created yet, so there is nothing to poll */ + do { + cnt = 0; + do { + ret = ibv_poll_cq(qp->cm_id->recv_cq, 1, &wc); + if (ret <= 0) { + qp = ds_next_qp(qp); + continue; + } + + if (ds_wr_is_recv(wc.wr_id)) { + if (rs->rqe_avail && wc.status == IBV_WC_SUCCESS && + ds_valid_recv(qp->rbuf + ds_wr_offset(wc.wr_id), + wc.byte_len)) { + rs->rqe_avail--; /* queue the receive on the dmsg ring for ds_recvfrom */ + rmsg = &rs->dmsg[rs->rmsg_tail]; + rmsg->qp = qp; + rmsg->offset = ds_wr_offset(wc.wr_id); + rmsg->length = wc.byte_len; + if (++rs->rmsg_tail == rs->rq_size + 1) + rs->rmsg_tail = 0; + } else { /* no room, failed completion, or bad header: drop the message and repost the buffer */ + ds_post_recv(rs, qp, qp->rbuf + + ds_wr_offset(wc.wr_id)); + } + } else { /* send completion: return the send buffer to the free list */ + smsg = (struct ds_smsg *) + (rs->sbuf + ds_wr_offset(wc.wr_id)); + smsg->next = rs->smsg_free; + rs->smsg_free = smsg; + rs->sqe_avail++; + } + + qp = ds_next_qp(qp); + if (!rs->rqe_avail && rs->sqe_avail) { /* receive ring full and a send buffer is free: stop, remembering where to resume */ + rs->qp_list = qp; + return; + } + cnt++; + } while (qp != rs->qp_list); + } while (cnt); +} + +static void ds_req_notify_cqs(struct rsocket *rs) +{ /* Request a completion notification on every CQ of the rsocket that is not already armed. */ + struct ds_qp *qp; + + if (!(qp = rs->qp_list)) return; /* nothing to arm before the first QP exists */ + do { + if (!qp->cq_armed) { + ibv_req_notify_cq(qp->cm_id->recv_cq, 0); + qp->cq_armed = 1; + } + qp = ds_next_qp(qp); + } while (qp != rs->qp_list); +} + +static int ds_get_cq_event(struct rsocket *rs) +{ /* Wait on the epoll fd for one armed CQ to signal, then consume and acknowledge that event. Returns 0 when nothing is armed or on success, otherwise the failing call result. */ + struct epoll_event event; + struct ds_qp *qp; + struct ibv_cq *cq; + void *context; + int ret; + + if (!rs->cq_armed) + return 0; + + ret = epoll_wait(rs->epfd, &event, 1, -1); + if (ret <= 0) + return ret; + + qp = event.data.ptr; /* ds_create_qp registers each CQ channel fd with its ds_qp as the epoll data */ + ret = ibv_get_cq_event(qp->cm_id->recv_cq_channel, &cq, &context); /* read the event from the channel that signalled, not from rs->cm_id */ + if (!ret) { + ibv_ack_cq_events(qp->cm_id->recv_cq, 1); + qp->cq_armed = 0; + rs->cq_armed = 0; + } + + return ret; +} + +static int 
ds_process_cqs(struct rsocket *rs, int nonblock, int (*test)(struct rsocket *rs)) +{ /* Datagram variant of the CQ processing loop: polls the CQs through ds_poll_cqs until test(rs) is true, nonblock forces EWOULDBLOCK, or waiting for an event fails. */ + int ret = 0; fastlock_acquire(&rs->cq_lock); do { - rs_update_credits(rs); - ret = rs_poll_cq(rs); + ds_poll_cqs(rs); if (test(rs)) { ret = 0; break; - } else if (ret) { - break; } else if (nonblock) { ret = ERR(EWOULDBLOCK); } else if (!rs->cq_armed) { - ibv_req_notify_cq(rs->cm_id->recv_cq, 0); + ds_req_notify_cqs(rs); /* NOTE(review): the blocking branch that follows still calls rs_update_credits(), shared with the stream path - confirm it is safe for datagram rsockets */ rs->cq_armed = 1; } else { rs_update_credits(rs); fastlock_acquire(&rs->cq_wait_lock); fastlock_release(&rs->cq_lock); - ret = rs_get_cq_event(rs); + ret = ds_get_cq_event(rs); fastlock_release(&rs->cq_wait_lock); fastlock_acquire(&rs->cq_lock); } } while (!ret); - rs_update_credits(rs); fastlock_release(&rs->cq_lock); return ret; } -static int rs_get_comp(struct rsocket *rs, int nonblock, int (*test)(struct rsocket *rs)) +static int ds_get_comp(struct rsocket *rs, int nonblock, int (*test)(struct rsocket *rs)) { struct timeval s, e; uint32_t poll_time = 0; int ret; do { - ret = rs_process_cq(rs, 1, test); + ret = ds_process_cqs(rs, 1, test); /* nonblocking pass; the loop repeats until polling_time is exceeded */ if (!ret || nonblock || errno != EWOULDBLOCK) return ret; @@ -1184,7 +2011,7 @@ static int rs_get_comp(struct rsocket *rs, int nonblock, int (*test)(struct rsoc (e.tv_usec - s.tv_usec) + 1; } while (poll_time <= polling_time); - ret = rs_process_cq(rs, 0, test); + ret = ds_process_cqs(rs, 0, test); /* blocking pass once the polling window is over */ return ret; } @@ -1219,9 +2046,19 @@ static int rs_can_send(struct rsocket *rs) (rs->target_sgl[rs->target_sge].length != 0); } +static int ds_can_send(struct rsocket *rs) +{ /* Nonzero while at least one send queue entry is available. */ + return rs->sqe_avail; +} + +static int ds_all_sends_done(struct rsocket *rs) +{ /* True once sqe_avail is back to sq_size, that is, no send is outstanding. */ + return rs->sqe_avail == rs->sq_size; +} + static int rs_conn_can_send(struct rsocket *rs) { - return rs_can_send(rs) || !(rs->state & rs_connect_wr); + return rs_can_send(rs) || !(rs->state & rs_writable); } static int rs_conn_can_send_ctrl(struct rsocket *rs) @@ -1236,7 +2073,7 @@ static int rs_have_rdata(struct rsocket *rs) static int rs_conn_have_rdata(struct rsocket *rs) { 
- return rs_have_rdata(rs) || !(rs->state & rs_connect_rd); + return rs_have_rdata(rs) || !(rs->state & rs_readable); } static int rs_conn_all_sends_done(struct rsocket *rs) @@ -1245,6 +2082,66 @@ static int rs_conn_all_sends_done(struct rsocket *rs) !(rs->state & rs_connected); } +static void ds_set_src(struct sockaddr *addr, socklen_t *addrlen, + struct ds_header *hdr) +{ + union socket_addr sa; + + if (hdr->version == 4) { + if (*addrlen > sizeof(sa.sin)) + *addrlen = sizeof(sa.sin); + + sa.sin.sin_family = AF_INET; + sa.sin.sin_port = hdr->port; + sa.sin.sin_addr.s_addr = hdr->addr.ipv4; + } else { + if (*addrlen > sizeof(sa.sin6)) + *addrlen = sizeof(sa.sin6); + + sa.sin6.sin6_family = AF_INET6; + sa.sin6.sin6_port = hdr->port; + sa.sin6.sin6_flowinfo = hdr->addr.ipv6.flowinfo; + memcpy(&sa.sin6.sin6_addr, &hdr->addr.ipv6.addr, 16); + sa.sin6.sin6_scope_id = 0; + } + memcpy(addr, &sa, *addrlen); +} + +static ssize_t ds_recvfrom(struct rsocket *rs, void *buf, size_t len, int flags, + struct sockaddr *src_addr, socklen_t *addrlen) +{ + struct ds_rmsg *rmsg; + struct ds_header *hdr; + int ret; + + if (!(rs->state & rs_readable)) + return ERR(EINVAL); + + if (!rs_have_rdata(rs)) { + ret = ds_get_comp(rs, rs_nonblocking(rs, flags), + rs_have_rdata); + if (ret) + return ret; + } + + rmsg = &rs->dmsg[rs->rmsg_head]; + hdr = (struct ds_header *) (rmsg->qp->rbuf + rmsg->offset); + if (len > rmsg->length - hdr->length) + len = rmsg->length - hdr->length; + + memcpy(buf, (void *) hdr + hdr->length, len); + if (addrlen) + ds_set_src(src_addr, addrlen, hdr); + + if (!(flags & MSG_PEEK)) { + ds_post_recv(rs, rmsg->qp, hdr); + if (++rs->rmsg_head == rs->rq_size + 1) + rs->rmsg_head = 0; + } + + return len; +} + static ssize_t rs_peek(struct rsocket *rs, void *buf, size_t len) { size_t left = len; @@ -1290,6 +2187,13 @@ ssize_t rrecv(int socket, void *buf, size_t len, int flags) int ret; rs = idm_at(&idm, socket); + if (rs->type == SOCK_DGRAM) { + 
fastlock_acquire(&rs->slock); + ret = ds_recvfrom(rs, buf, len, flags, NULL, 0); + fastlock_release(&rs->slock); + return ret; + } + if (rs->state & rs_opening) { ret = rs_do_connect(rs); if (ret) { @@ -1339,7 +2243,7 @@ ssize_t rrecv(int socket, void *buf, size_t len, int flags) rs->rbuf_bytes_avail += rsize; } - } while (left && (flags & MSG_WAITALL) && (rs->state & rs_connect_rd)); + } while (left && (flags & MSG_WAITALL) && (rs->state & rs_readable)); fastlock_release(&rs->rlock); return ret ? ret : len - left; @@ -1348,8 +2252,17 @@ ssize_t rrecv(int socket, void *buf, size_t len, int flags) ssize_t rrecvfrom(int socket, void *buf, size_t len, int flags, struct sockaddr *src_addr, socklen_t *addrlen) { + struct rsocket *rs; int ret; + rs = idm_at(&idm, socket); + if (rs->type == SOCK_DGRAM) { /* NOTE(review): a blocking receive holds slock, which the send path and the service thread also take - confirm this cannot stall them */ + fastlock_acquire(&rs->slock); + ret = ds_recvfrom(rs, buf, len, flags, src_addr, addrlen); + fastlock_release(&rs->slock); + return ret; + } + ret = rrecv(socket, buf, len, flags); if (ret > 0 && src_addr) rgetpeername(socket, src_addr, addrlen); @@ -1391,14 +2304,14 @@ static int rs_send_iomaps(struct rsocket *rs, int flags) struct rs_iomap iom; int ret; - fastlock_acquire(&rs->iomap_lock); + fastlock_acquire(&rs->map_lock); while (!dlist_empty(&rs->iomap_queue)) { if (!rs_can_send(rs)) { ret = rs_get_comp(rs, rs_nonblocking(rs, flags), rs_conn_can_send); if (ret) break; - if (!(rs->state & rs_connect_wr)) { + if (!(rs->state & rs_writable)) { ret = ERR(ECONNRESET); break; } @@ -1447,10 +2360,90 @@ static int rs_send_iomaps(struct rsocket *rs, int flags) } rs->iomap_pending = !dlist_empty(&rs->iomap_queue); - fastlock_release(&rs->iomap_lock); + fastlock_release(&rs->map_lock); return ret; } +static ssize_t ds_sendv_udp(struct rsocket *rs, const struct iovec *iov, + int iovcnt, int flags, uint8_t op) +{ /* Send the caller iovecs over the UDP socket to rs->conn_dest, prefixed by a ds_udp_header that advertises the local QP number and address. At most 8 iovecs are accepted. */ + struct ds_udp_header hdr; + struct msghdr msg; + struct iovec miov[9]; /* slot 0 carries the header, leaving room for the 8 caller iovecs allowed below */ + + if (iovcnt > 8) + return ERR(ENOTSUP); + + hdr.tag = htonl(DS_UDP_TAG); + hdr.version = rs->conn_dest->qp->hdr.version; /* the receiver only accepts version 4 or 6, matching the address family */ 
+ hdr.op = op; + hdr.reserved = 0; + hdr.qpn = htonl(rs->conn_dest->qp->cm_id->qp->qp_num & 0xFFFFFF); + if (rs->conn_dest->qp->hdr.version == 4) { + hdr.length = DS_UDP_IPV4_HDR_LEN; + hdr.addr.ipv4 = rs->conn_dest->qp->hdr.addr.ipv4; + } else { + hdr.length = DS_UDP_IPV6_HDR_LEN; + memcpy(hdr.addr.ipv6, &rs->conn_dest->qp->hdr.addr.ipv6, 16); + } + + miov[0].iov_base = &hdr; + miov[0].iov_len = sizeof hdr; + if (iov && iovcnt) + memcpy(&miov[1], iov, sizeof *iov * iovcnt); + + memset(&msg, 0, sizeof msg); + msg.msg_name = &rs->conn_dest->addr; + msg.msg_namelen = ucma_addrlen(&rs->conn_dest->addr.sa); + msg.msg_iov = miov; + msg.msg_iovlen = iovcnt + 1; + return sendmsg(rs->udp_sock, &msg, flags); +} + +static ssize_t ds_send_udp(struct rsocket *rs, const void *buf, size_t len, + int flags, uint8_t op) +{ + struct iovec iov; + if (buf && len) { + iov.iov_base = (void *) buf; + iov.iov_len = len; + return ds_sendv_udp(rs, &iov, 1, flags, op); + } else { + return ds_sendv_udp(rs, NULL, 0, flags, op); + } +} + +static ssize_t dsend(struct rsocket *rs, const void *buf, size_t len, int flags) +{ + struct ds_smsg *msg; + struct ibv_sge sge; + uint64_t offset; + int ret = 0; + + if (!rs->conn_dest->ah) + return ds_send_udp(rs, buf, len, flags, RS_OP_DATA); + + if (!ds_can_send(rs)) { + ret = ds_get_comp(rs, rs_nonblocking(rs, flags), ds_can_send); + if (ret) + return ret; + } + + msg = rs->smsg_free; + rs->smsg_free = msg->next; + rs->sqe_avail--; + + memcpy((void *) msg, &rs->conn_dest->qp->hdr, rs->conn_dest->qp->hdr.length); + memcpy((void *) msg + rs->conn_dest->qp->hdr.length, buf, len); + sge.addr = (uintptr_t) msg; + sge.length = rs->conn_dest->qp->hdr.length + len; + sge.lkey = rs->smr->lkey; + offset = (uint8_t *) msg - rs->sbuf; + + ret = ds_post_send(rs, &sge, ds_send_wr_id(offset, sge.length)); + return ret ? 
ret : len; +} + /* * We overlap sending the data, by posting a small work request immediately, * then increasing the size of the send on each iteration. @@ -1464,6 +2457,13 @@ ssize_t rsend(int socket, const void *buf, size_t len, int flags) int ret = 0; rs = idm_at(&idm, socket); + if (rs->type == SOCK_DGRAM) { + fastlock_acquire(&rs->slock); + ret = rs->conn_dest ? dsend(rs, buf, len, flags) : ERR(EDESTADDRREQ); /* dsend dereferences conn_dest; an unconnected datagram rsocket has none */ + fastlock_release(&rs->slock); + return ret; + } + if (rs->state & rs_opening) { ret = rs_do_connect(rs); if (ret) { @@ -1485,7 +2485,7 @@ ssize_t rsend(int socket, const void *buf, size_t len, int flags) rs_conn_can_send); if (ret) break; - if (!(rs->state & rs_connect_wr)) { + if (!(rs->state & rs_writable)) { ret = ERR(ECONNRESET); break; } @@ -1538,10 +2538,27 @@ out: ssize_t rsendto(int socket, const void *buf, size_t len, int flags, const struct sockaddr *dest_addr, socklen_t addrlen) { - if (dest_addr || addrlen) - return ERR(EISCONN); + struct rsocket *rs; + int ret; + + rs = idm_at(&idm, socket); + if (rs->type == SOCK_STREAM) { + if (dest_addr || addrlen) + return ERR(EISCONN); + + return rsend(socket, buf, len, flags); + } - return rsend(socket, buf, len, flags); + fastlock_acquire(&rs->slock); + if (dest_addr && (!rs->conn_dest || ds_compare_addr(dest_addr, &rs->conn_dest->addr))) { /* switch destination only when the caller names one that differs from the cached entry */ + ret = ds_get_dest(rs, dest_addr, addrlen, &rs->conn_dest); + if (ret) + goto out; + } + ret = rs->conn_dest ? dsend(rs, buf, len, flags) : ERR(EDESTADDRREQ); /* no address given and none connected */ +out: + fastlock_release(&rs->slock); + return ret; } static void rs_copy_iov(void *dst, const struct iovec **iov, size_t *offset, size_t len) @@ -1600,7 +2617,7 @@ static ssize_t rsendv(int socket, const struct iovec *iov, int iovcnt, int flags rs_conn_can_send); if (ret) break; - if (!(rs->state & rs_connect_wr)) { + if (!(rs->state & rs_writable)) { ret = ERR(ECONNRESET); break; } @@ -1653,7 +2670,7 @@ static ssize_t rsendv(int socket, const struct iovec *iov, int iovcnt, int flags if (msg->msg_control && msg->msg_controllen) return ERR(ENOTSUP); - return rsendv(socket, msg->msg_iov, (int) 
msg->msg_iovlen, msg->msg_flags); + return rsendv(socket, msg->msg_iov, (int) msg->msg_iovlen, flags); /* forward the flags argument of the call instead of msg->msg_flags */ } ssize_t rwrite(int socket, const void *buf, size_t count) @@ -1690,8 +2707,8 @@ static int rs_poll_rs(struct rsocket *rs, int events, int ret; check_cq: - if ((rs->state & rs_connected) || (rs->state == rs_disconnected) || - (rs->state & rs_error)) { + if ((rs->type == SOCK_STREAM) && ((rs->state & rs_connected) || + (rs->state == rs_disconnected) || (rs->state & rs_error))) { rs_process_cq(rs, nonblock, test); revents = 0; @@ -1706,6 +2723,16 @@ check_cq: revents |= POLLERR; } + return revents; + } else if (rs->type == SOCK_DGRAM) { /* datagram rsockets: process the CQs, then report POLLIN from queued receives and POLLOUT from free send entries */ + ds_process_cqs(rs, nonblock, test); + + revents = 0; + if ((events & POLLIN) && rs_have_rdata(rs)) + revents |= POLLIN; + if ((events & POLLOUT) && ds_can_send(rs)) + revents |= POLLOUT; + return revents; } @@ -1766,11 +2793,14 @@ static int rs_poll_arm(struct pollfd *rfds, struct pollfd *fds, nfds_t nfds) if (fds[i].revents) return 1; - if (rs->state >= rs_connected) - rfds[i].fd = rs->cm_id->recv_cq_channel->fd; - else - rfds[i].fd = rs->cm_id->channel->fd; - + if (rs->type == SOCK_STREAM) { + if (rs->state >= rs_connected) + rfds[i].fd = rs->cm_id->recv_cq_channel->fd; + else + rfds[i].fd = rs->cm_id->channel->fd; + } else { + rfds[i].fd = rs->epfd; /* datagram rsockets are waited on through their epoll fd */ + } rfds[i].events = POLLIN; } else { rfds[i].fd = fds[i].fd; @@ -1793,7 +2823,10 @@ static int rs_poll_events(struct pollfd *rfds, struct pollfd *fds, nfds_t nfds) rs = idm_lookup(&idm, fds[i].fd); if (rs) { - rs_get_cq_event(rs); + if (rs->type == SOCK_STREAM) + rs_get_cq_event(rs); + else + ds_get_cq_event(rs); /* datagram variant of the CQ event wait */ fds[i].revents = rs_poll_rs(rs, fds[i].events, 1, rs_poll_all); } else { fds[i].revents = rfds[i].revents; @@ -1949,7 +2982,7 @@ int rshutdown(int socket, int how) rs = idm_at(&idm, socket); if (how == SHUT_RD) { - rs->state &= ~rs_connect_rd; + rs->state &= ~rs_readable; return 0; } @@ -1959,10 +2992,10 @@ int rshutdown(int socket, int how) if (rs->state & rs_connected) { 
if (how == SHUT_RDWR) { ctrl = RS_CTRL_DISCONNECT; - rs->state &= ~(rs_connect_rd | rs_connect_wr); + rs->state &= ~(rs_readable | rs_writable); } else { - rs->state &= ~rs_connect_wr; - ctrl = (rs->state & rs_connect_rd) ? + rs->state &= ~rs_writable; + ctrl = (rs->state & rs_readable) ? RS_CTRL_SHUTDOWN : RS_CTRL_DISCONNECT; } if (!rs->ctrl_avail) { @@ -1987,13 +3020,29 @@ int rshutdown(int socket, int how) return 0; } +static void ds_shutdown(struct rsocket *rs) +{ /* Clear the readable and writable state bits and wait through ds_process_cqs until ds_all_sends_done holds. Nonblocking mode is switched off for the wait and restored afterwards. */ + if (rs->fd_flags & O_NONBLOCK) + rs_set_nonblocking(rs, 0); + + rs->state &= ~(rs_readable | rs_writable); + ds_process_cqs(rs, 0, ds_all_sends_done); /* result is not checked */ + + if (rs->fd_flags & O_NONBLOCK) + rs_set_nonblocking(rs, 1); +} + int rclose(int socket) { struct rsocket *rs; rs = idm_at(&idm, socket); - if (rs->state & rs_connected) - rshutdown(socket, SHUT_RDWR); + if (rs->type == SOCK_STREAM) { + if (rs->state & rs_connected) + rshutdown(socket, SHUT_RDWR); + } else { + ds_shutdown(rs); /* datagram rsockets wait for outstanding sends before being freed */ + } rs_free(rs); return 0; @@ -2018,8 +3067,12 @@ int rgetpeername(int socket, struct sockaddr *addr, socklen_t *addrlen) struct rsocket *rs; rs = idm_at(&idm, socket); - rs_copy_addr(addr, rdma_get_peer_addr(rs->cm_id), addrlen); - return 0; + if (rs->type == SOCK_STREAM) { + rs_copy_addr(addr, rdma_get_peer_addr(rs->cm_id), addrlen); + return 0; + } else { + return getpeername(rs->udp_sock, addr, addrlen); /* datagram rsockets answer from the underlying UDP socket */ + } } int rgetsockname(int socket, struct sockaddr *addr, socklen_t *addrlen) @@ -2027,8 +3080,12 @@ int rgetsockname(int socket, struct sockaddr *addr, socklen_t *addrlen) struct rsocket *rs; rs = idm_at(&idm, socket); - rs_copy_addr(addr, rdma_get_local_addr(rs->cm_id), addrlen); - return 0; + if (rs->type == SOCK_STREAM) { + rs_copy_addr(addr, rdma_get_local_addr(rs->cm_id), addrlen); + return 0; + } else { + return getsockname(rs->udp_sock, addr, addrlen); /* datagram rsockets answer from the underlying UDP socket */ + } } int rsetsockopt(int socket, int level, int optname, @@ -2040,18 +3097,26 @@ int rsetsockopt(int socket, int level, int optname, ret = ERR(ENOTSUP); rs = 
idm_at(&idm, socket); + if (rs->type == SOCK_DGRAM && level != SOL_RDMA) { /* datagram rsockets first apply every non-RDMA option to the underlying UDP socket and stop on failure */ + ret = setsockopt(rs->udp_sock, level, optname, optval, optlen); + if (ret) + return ret; + } + switch (level) { case SOL_SOCKET: opts = &rs->so_opts; switch (optname) { case SO_REUSEADDR: - ret = rdma_set_option(rs->cm_id, RDMA_OPTION_ID, - RDMA_OPTION_ID_REUSEADDR, - (void *) optval, optlen); - if (ret && ((errno == ENOSYS) || ((rs->state != rs_init) && - rs->cm_id->context && - (rs->cm_id->verbs->device->transport_type == IBV_TRANSPORT_IB)))) - ret = 0; + if (rs->type == SOCK_STREAM) { /* rdma_set_option is applied to stream rsockets only */ + ret = rdma_set_option(rs->cm_id, RDMA_OPTION_ID, + RDMA_OPTION_ID_REUSEADDR, + (void *) optval, optlen); + if (ret && ((errno == ENOSYS) || ((rs->state != rs_init) && + rs->cm_id->context && + (rs->cm_id->verbs->device->transport_type == IBV_TRANSPORT_IB)))) + ret = 0; + } opt_on = *(int *) optval; break; case SO_RCVBUF: @@ -2101,9 +3166,11 @@ int rsetsockopt(int socket, int level, int optname, opts = &rs->ipv6_opts; switch (optname) { case IPV6_V6ONLY: - ret = rdma_set_option(rs->cm_id, RDMA_OPTION_ID, - RDMA_OPTION_ID_AFONLY, - (void *) optval, optlen); + if (rs->type == SOCK_STREAM) { /* rdma_set_option is applied to stream rsockets only */ + ret = rdma_set_option(rs->cm_id, RDMA_OPTION_ID, + RDMA_OPTION_ID_AFONLY, + (void *) optval, optlen); + } opt_on = *(int *) optval; break; default: @@ -2315,7 +3382,7 @@ off_t riomap(int socket, void *buf, size_t len, int prot, int flags, off_t offse if (!rs->cm_id->pd || (prot & ~(PROT_WRITE | PROT_NONE))) return ERR(EINVAL); - fastlock_acquire(&rs->iomap_lock); + fastlock_acquire(&rs->map_lock); if (prot & PROT_WRITE) { iomr = rs_get_iomap_mr(rs); access |= IBV_ACCESS_REMOTE_WRITE; @@ -2349,7 +3416,7 @@ off_t riomap(int socket, void *buf, size_t len, int prot, int flags, off_t offse dlist_insert_tail(&iomr->entry, &rs->iomap_list); } out: - fastlock_release(&rs->iomap_lock); + fastlock_release(&rs->map_lock); return offset; } @@ -2361,7 +3428,7 @@ int riounmap(int socket, void *buf, size_t len) int ret = 0; rs = 
idm_at(&idm, socket); - fastlock_acquire(&rs->iomap_lock); + fastlock_acquire(&rs->map_lock); for (entry = rs->iomap_list.next; entry != &rs->iomap_list; entry = entry->next) { @@ -2382,7 +3449,7 @@ int riounmap(int socket, void *buf, size_t len) } ret = ERR(EINVAL); out: - fastlock_release(&rs->iomap_lock); + fastlock_release(&rs->map_lock); return ret; } @@ -2426,7 +3493,7 @@ size_t riowrite(int socket, const void *buf, size_t count, off_t offset, int fla rs_conn_can_send); if (ret) break; - if (!(rs->state & rs_connect_wr)) { + if (!(rs->state & rs_writable)) { ret = ERR(ECONNRESET); break; } @@ -2476,3 +3543,269 @@ out: return (ret && left == count) ? ret : count - left; } + +static int rs_svc_grow_sets(void) +{ /* Grow the service arrays by two slots. A single allocation backs both: svc_size rsocket pointers followed by svc_size pollfd entries. Returns 0 or ENOMEM. */ + struct rsocket **rss; + struct pollfd *fds; + void *set; + + set = calloc(svc_size + 2, sizeof(*rss) + sizeof(*fds)); + if (!set) + return ENOMEM; + + svc_size += 2; + rss = set; + fds = set + sizeof(*rss) * svc_size; + if (svc_cnt) { /* slots 0 through svc_cnt are in use because index 0 is reserved, hence svc_cnt + 1 entries */ + memcpy(rss, svc_rss, sizeof(*rss) * (svc_cnt + 1)); + memcpy(fds, svc_fds, sizeof(*fds) * (svc_cnt + 1)); + } + + /* svc_fds points into the svc_rss allocation: release the block once, through its base */ + free(svc_rss); + svc_rss = rss; + svc_fds = fds; + return 0; +} + +/* + * Index 0 is reserved for the service's communication socket. 
+ */ +static int rs_svc_add_rs(struct rsocket *rs) +{ + int ret; + + if (svc_cnt >= svc_size - 1) { + ret = rs_svc_grow_sets(); + if (ret) + return ret; + } + + svc_rss[++svc_cnt] = rs; + svc_fds[svc_cnt].fd = rs->udp_sock; + svc_fds[svc_cnt].events = POLLIN; + svc_fds[svc_cnt].revents = 0; + return 0; +} + +static int rs_svc_rm_rs(struct rsocket *rs) +{ + int i; + + for (i = 1; i <= svc_cnt; i++) { + if (svc_rss[i] == rs) { + svc_cnt--; + svc_rss[i] = svc_rss[svc_cnt]; + svc_fds[i] = svc_fds[svc_cnt]; + return 0; + } + } + return EBADF; +} + +static void rs_svc_process_sock(void) +{ + struct rs_svc_msg msg; + + read(svc_sock[1], &msg, sizeof msg); + switch (msg.op) { + case RS_SVC_INSERT: + msg.status = rs_svc_add_rs(msg.rs); + break; + case RS_SVC_REMOVE: + msg.status = rs_svc_rm_rs(msg.rs); + break; + default: + msg.status = ENOTSUP; + break; + } + write(svc_sock[1], &msg, sizeof msg); +} + +static uint8_t rs_svc_sgid_index(struct ds_dest *dest, union ibv_gid *sgid) +{ + union ibv_gid gid; + int i, ret; + + for (i = 0; i < 16; i++) { + ret = ibv_query_gid(dest->qp->cm_id->verbs, dest->qp->cm_id->port_num, + i, &gid); + if (!memcmp(sgid, &gid, sizeof gid)) + return i; + } + return 0; +} + +static uint8_t rs_svc_path_bits(struct ds_dest *dest) +{ + struct ibv_port_attr attr; + + if (!ibv_query_port(dest->qp->cm_id->verbs, dest->qp->cm_id->port_num, &attr)) + return (uint8_t) ((1 << attr.lmc) - 1); + return 0x7f; +} + +static void rs_svc_create_ah(struct rsocket *rs, struct ds_dest *dest, uint32_t qpn) +{ + union socket_addr saddr; + struct rdma_cm_id *id; + struct ibv_ah_attr attr; + int ret; + + if (dest->ah) { + fastlock_acquire(&rs->slock); + ibv_destroy_ah(dest->ah); + dest->ah = NULL; + fastlock_release(&rs->slock); + } + + ret = rdma_create_id(NULL, &id, NULL, dest->qp->cm_id->ps); + if (ret) + return; + + memcpy(&saddr, rdma_get_local_addr(dest->qp->cm_id), + ucma_addrlen(rdma_get_local_addr(dest->qp->cm_id))); + if (saddr.sa.sa_family == AF_INET) + 
saddr.sin.sin_port = 0; + else + saddr.sin6.sin6_port = 0; + ret = rdma_resolve_addr(id, &saddr.sa, &dest->addr.sa, 2000); + if (ret) + goto out; + + ret = rdma_resolve_route(id, 2000); + if (ret) + goto out; + + memset(&attr, 0, sizeof attr); + if (id->route.path_rec->hop_limit) { + attr.is_global = 1; + attr.grh.dgid = id->route.path_rec->dgid; + attr.grh.flow_label = ntohl(id->route.path_rec->flow_label); /* path record fields are in network byte order; AH attributes are host order */ + attr.grh.sgid_index = rs_svc_sgid_index(dest, &id->route.path_rec->sgid); + attr.grh.hop_limit = id->route.path_rec->hop_limit; + attr.grh.traffic_class = id->route.path_rec->traffic_class; + } + attr.dlid = ntohs(id->route.path_rec->dlid); + attr.sl = id->route.path_rec->sl; + attr.src_path_bits = ntohs(id->route.path_rec->slid) & rs_svc_path_bits(dest); + attr.static_rate = id->route.path_rec->rate; + attr.port_num = id->port_num; + + fastlock_acquire(&rs->slock); + dest->qpn = qpn; + dest->ah = ibv_create_ah(dest->qp->cm_id->pd, &attr); /* may be NULL on failure; dsend then keeps using UDP */ + fastlock_release(&rs->slock); +out: + rdma_destroy_id(id); +} + +static int rs_svc_valid_udp_hdr(struct ds_udp_header *udp_hdr, + union socket_addr *addr) +{ /* Validate a received UDP header before its fields are converted to host order: tag, and a version/length pair consistent with the address family of the sender. */ + return (udp_hdr->tag == htonl(DS_UDP_TAG)) && + ((udp_hdr->version == 4 && addr->sa.sa_family == AF_INET && + udp_hdr->length == DS_UDP_IPV4_HDR_LEN) || + (udp_hdr->version == 6 && addr->sa.sa_family == AF_INET6 && + udp_hdr->length == DS_UDP_IPV6_HDR_LEN)); +} + +static void rs_svc_forward(struct rsocket *rs, void *buf, size_t len, + union socket_addr *src) +{ /* Re-send a datagram that arrived over UDP through the UD QP of rs->conn_dest, prefixed with a ds_header built from src. Silently drops the message if no send buffer becomes available. NOTE(review): len is not checked against the send buffer slot size - confirm. */ + struct ds_header hdr; + struct ds_smsg *msg; + struct ibv_sge sge; + uint64_t offset; + + if (!ds_can_send(rs)) { + if (ds_get_comp(rs, 0, ds_can_send)) + return; + } + + msg = rs->smsg_free; + rs->smsg_free = msg->next; + rs->sqe_avail--; + + ds_format_hdr(&hdr, src); + memcpy((void *) msg, &hdr, hdr.length); + memcpy((void *) msg + hdr.length, buf, len); + sge.addr = (uintptr_t) msg; + sge.length = hdr.length + len; + sge.lkey = rs->smr->lkey; + offset = (uint8_t *) msg - rs->sbuf; + + ds_post_send(rs, &sge, ds_send_wr_id(offset, 
sge.length)); +} + +static void rs_svc_process_rs(struct rsocket *rs) +{ + struct ds_dest *dest, *cur_dest; + struct ds_udp_header *udp_hdr; + union socket_addr addr; + socklen_t addrlen = sizeof addr; + int len, ret; + + ret = recvfrom(rs->udp_sock, svc_buf, sizeof svc_buf, 0, &addr.sa, &addrlen); + if (ret < DS_UDP_IPV4_HDR_LEN) + return; + + udp_hdr = (struct ds_udp_header *) svc_buf; + if (!rs_svc_valid_udp_hdr(udp_hdr, &addr)) + return; + + len = ret - udp_hdr->length; + udp_hdr->tag = ntohl(udp_hdr->tag); + udp_hdr->qpn = ntohl(udp_hdr->qpn) & 0xFFFFFF; + ret = ds_get_dest(rs, &addr.sa, addrlen, &dest); + if (ret) + return; + + if (!dest->ah || (dest->qpn != udp_hdr->qpn)) + rs_svc_create_ah(rs, dest, udp_hdr->qpn); + + /* to do: handle when dest local ip address doesn't match udp ip */ + fastlock_acquire(&rs->slock); + cur_dest = rs->conn_dest; + if (udp_hdr->op == RS_OP_DATA) { + rs->conn_dest = &dest->qp->dest; + rs_svc_forward(rs, svc_buf + udp_hdr->length, len, &addr); + } + + rs->conn_dest = dest; + ds_send_udp(rs, svc_buf + udp_hdr->length, len, 0, RS_OP_CTRL); + rs->conn_dest = cur_dest; + fastlock_release(&rs->slock); +} + +static void *rs_svc_run(void *arg) +{ + struct rs_svc_msg msg; + int i, ret; + + ret = rs_svc_grow_sets(); + if (ret) { + msg.status = ret; + write(svc_sock[1], &msg, sizeof msg); + return (void *) (uintptr_t) ret; + } + + svc_fds[0].fd = svc_sock[1]; + svc_fds[0].events = POLLIN; + do { + for (i = 0; i <= svc_cnt; i++) + svc_fds[i].revents = 0; + + poll(svc_fds, svc_cnt + 1, -1); + if (svc_fds[0].revents) + rs_svc_process_sock(); + + for (i = 1; i <= svc_cnt; i++) { + if (svc_fds[i].revents) + rs_svc_process_rs(svc_rss[i]); + } + } while (svc_cnt > 1); + + return NULL; +} -- 2.41.0