From c57b602d27ad4e916883dce975d819b52d2f902e Mon Sep 17 00:00:00 2001 From: Sean Hefty Date: Tue, 27 Nov 2012 16:27:39 -0800 Subject: [PATCH] Refresh of dsocket --- src/rsocket.c | 229 ++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 175 insertions(+), 54 deletions(-) diff --git a/src/rsocket.c b/src/rsocket.c index a81b8f39..d7c3163b 100644 --- a/src/rsocket.c +++ b/src/rsocket.c @@ -211,10 +211,27 @@ union socket_addr { struct sockaddr_in6 sin6; }; +struct ds_header { + uint8_t version; + uint8_t length; + uint16_t port; + union { + uint32_t ipv4; + struct { + uint32_t flowinfo; + uint8_t addr[16]; + } ipv6; + } addr; +}; + +#define DS_IPV4_HDR_LEN 8 +#define DS_IPV6_HDR_LEN 24 + struct ds_qp { dlist_t list; struct rsocket *rs; struct rdma_cm_id *cm_id; + struct ds_header hdr; struct ibv_mr *smr; struct ibv_mr *rmr; @@ -324,7 +341,6 @@ struct ds_udp_header { uint32_t qpn; /* upper 8-bits reserved */ }; - #define ds_next_qp(qp) container_of((qp)->list.next, struct ds_qp, list) static void ds_insert_qp(struct rsocket *rs, struct ds_qp *qp) @@ -1285,6 +1301,22 @@ out: return ret; } +static void ds_format_hdr(struct ds_header *hdr, union socket_addr *addr) +{ + if (addr->sa.sa_family == AF_INET) { + hdr->version = 4; + hdr->length = DS_IPV4_HDR_LEN; + hdr->port = addr->sin.sin_port; + hdr->addr.ipv4 = addr->sin.sin_addr; + } else { + hdr->version = 6; + hdr->length = DS_IPV6_HDR_LEN; + hdr->port = addr->sin6.sin6_port; + hdr->addr.ipv6.flowinfo= addr->sin6.sin6_flowinfo; + memcpy(&hdr->addr.ipv6.addr, &addr->sin6.sin6_addr, 16); + } +} + static int ds_create_qp(struct rsocket *rs, union socket_addr *src_addr, socklen_t addrlen, struct ds_qp **qp) { @@ -1301,6 +1333,7 @@ static int ds_create_qp(struct rsocket *rs, union socket_addr *src_addr, if (ret) goto err; + ds_format_hdr(&(*qp)->hdr, src_addr); ret = rdma_bind_addr((*qp)->cm_id, &src_addr->sa); if (ret) goto err; @@ -1463,18 +1496,17 @@ static int rs_post_write(struct rsocket *rs, return rdma_seterrno(ibv_post_send(rs->cm_id->qp, &wr, &bad)); } -static int ds_post_send(struct rsocket *rs, - struct ibv_sge *sgl, int nsge, - uint64_t wr_id, int flags) +static int ds_post_send(struct rsocket *rs, struct ibv_sge *sge, + uint64_t wr_id) { struct ibv_send_wr wr, *bad; wr.wr_id = wr_id; wr.next = NULL; - wr.sg_list = sgl; - wr.num_sge = nsge; + wr.sg_list = sge; + wr.num_sge = 1; wr.opcode = IBV_WR_SEND; - wr.send_flags = flags; + wr.send_flags = (sge.length <= rs->sq_inline) ? IBV_SEND_INLINE : 0; wr.wr.ud.ah = rs->conn_dest->ah; wr.wr.ud.remote_qpn = rs->conn_dest->qpn; wr.wr.ud.remote_qkey = RDMA_UDP_QKEY; @@ -1512,18 +1544,6 @@ static int rs_write_data(struct rsocket *rs, flags, addr, rkey); } -static int ds_send_data(struct rsocket *rs, - struct ibv_sge *sgl, int nsge, - uint32_t length, int flags) -{ - uint64_t offset; - - rs->sqe_avail--; - rs->sbuf_bytes_avail -= length; - offset = sgl->addr - (uintptr_t) rs->sbuf; - return ds_post_send(rs, sgl, nsge, ds_send_wr_id(offset, length), flags); -} - static int rs_write_direct(struct rsocket *rs, struct rs_iomap *iom, uint64_t offset, struct ibv_sge *sgl, int nsge, uint32_t length, int flags) { @@ -1779,6 +1799,14 @@ static int rs_get_comp(struct rsocket *rs, int nonblock, int (*test)(struct rsoc return ret; } +static int ds_valid_recv(void *buf, uint32_t len) +{ + struct ds_header *hdr = (struct ds_header *) buf; + return ((len >= sizeof(*hdr)) && + ((hdr->version == 4 && hdr->length == DS_IPV4_HDR_LEN) || + (hdr->version == 6 && hdr->length == DS_IPV6_HDR_LEN))); +} + /* * Poll all CQs associated with a datagram rsocket. We need to drop any * received messages that we do not have room to store. To limit drops, @@ -1789,7 +1817,8 @@ static int rs_get_comp(struct rsocket *rs, int nonblock, int (*test)(struct rsoc static void ds_poll_cqs(struct rsocket *rs) { struct ds_qp *qp; - struct ds_smsg *msg; + struct ds_smsg *smsg; + struct ds_rmsg *rmsg; struct ibv_wc wc; int ret, cnt; @@ -1804,9 +1833,14 @@ static void ds_poll_cqs(struct rsocket *rs) } if (ds_wr_is_recv(wc.wr_id)) { - if (rs->rqe_avail && wc.status == IBV_WC_SUCCESS) { + if (rs->rqe_avail && wc.status == IBV_WC_SUCCESS && + ds_valid_recv(qp->rbuf + ds_wr_offset(wc.wr_id), + wc.byte_len)) { rs->rqe_avail--; - rs->dmsg[rs->rmsg_tail].qp = qp; + rmsg = &rs->dmsg[rs->rmsg_tail]; + rmsg->qp = qp; + rmsg->offset = ds_wr_offset(wc.wr_id); + rmsg->length = wc.byte_len; if (++rs->rmsg_tail == rs->rq_size + 1) rs->rmsg_tail = 0; } else { @@ -1814,12 +1848,10 @@ static void ds_poll_cqs(struct rsocket *rs) ds_wr_offset(wc.wr_id)); } } else { - if (ds_wr_length(wc.wr_id) > rs->sq_inline) { - msg = (struct ds_smsg *) - (rs->sbuf + ds_wr_offset(wc.wr_id)); - msg->next = rs->smsg_free; - rs->smsg_free = msg; - } + smsg = (struct ds_smsg *) + (rs->sbuf + ds_wr_offset(wc.wr_id)); + smsg->next = rs->smsg_free; + rs->smsg_free = smsg; rs->sqe_avail++; } @@ -1959,7 +1991,12 @@ static int rs_can_send(struct rsocket *rs) static int ds_can_send(struct rsocket *rs) { - return rs->sqe_avail && (rs->sbuf_bytes_avail >= RS_SNDLOWAT); + return rs->sqe_avail; +} + +static int ds_all_sends_done(struct rsocket *rs) +{ + return rs->sqe_avail == rs->sq_size; } static int rs_conn_can_send(struct rsocket *rs) @@ -1988,6 +2025,66 @@ static int rs_conn_all_sends_done(struct rsocket *rs) !(rs->state & rs_connected); } +static void ds_set_src(struct sockaddr *addr, socklen_t *addrlen, + struct ds_header *hdr) +{ + union socket_addr sa; + + if (hdr->version == 4) { + if (*addrlen > sizeof(sa.sin)) + *addrlen = sizeof(sa.sin); + + sa.sin.sin_family = AF_INET; + sa.sin.sin_port = hdr->port; + sa.sin.sin_addr.s_addr = hdr->addr.ipv4; + } else { + if (*addrlen > sizeof(sa.sin6)) + *addrlen = sizeof(sa.sin6); + + sa.sin6.sin6_family = AF_INET6; + sa.sin6.sin6_port = hdr->port; + sa.sin6.sin6_flowinfo = hdr->addr.ipv6.flowinfo; + memcpy(&sa.sin6.sin6_addr, &hdr->addr.ipv6.addr, 16); + sa.sin6.sin6_scope_id = 0; + } + memcpy(addr, &sa, *addrlen); +} + +static ssize_t ds_recvfrom(struct rsocket *rs, void *buf, size_t len, int flags, + struct sockaddr *src_addr, socklen_t *addrlen) +{ + struct ds_rmsg *rmsg; + struct ds_header *hdr; + int ret; + + if (!(rs->state & rs_readable)) + return ERR(EINVAL); + + if (!rs_have_rdata(rs)) { + ret = ds_get_comp(rs, rs_nonblocking(rs, flags), + rs_have_rdata); + if (ret) + return ret; + } + + rmsg = &rs->dmsg[rs->rmsg_head]; + hdr = (struct ds_header *) (rmsg->qp->rbuf + rmsg->offset); + if (len > rmsg->length - hdr->length) + len = rmsg->length - hdr->length; + + memcpy(buf, (void *) hdr + hdr->length, len); + if (addrlen) + ds_set_src(src_addr, addrlen, hdr); + + if (!(flags & MSG_PEEK)) { + ds_post_recv(rs, rmsg->qp, hdr); + if (++rs->rmsg_head == rs->rq_size + 1) + rs->rmsg_head = 0; + } + + return len; +} + static ssize_t rs_peek(struct rsocket *rs, void *buf, size_t len) { size_t left = len; @@ -2033,6 +2130,13 @@ ssize_t rrecv(int socket, void *buf, size_t len, int flags) int ret; rs = idm_at(&idm, socket); + if (rs->type == SOCK_DGRAM) { + fastlock_acquire(&rs->slock); + ret = ds_recvfrom(rs, buf, len, flags, src_addr, addrlen); + fastlock_release(&rs->slock); + return ret; + } + if (rs->state & rs_opening) { ret = rs_do_connect(rs); if (ret) { @@ -2093,6 +2197,14 @@ ssize_t rrecvfrom(int socket, void *buf, size_t len, int flags, { int ret; + rs = idm_at(&idm, socket); + if (rs->type == SOCK_DGRAM) { + fastlock_acquire(&rs->slock); + ret = ds_recvfrom(rs, buf, len, flags, src_addr, addrlen); + fastlock_release(&rs->slock); + return ret; + } + ret = rrecv(socket, buf, len, flags); if (ret > 0 && src_addr) rgetpeername(socket, src_addr, addrlen); @@ -2230,8 +2342,10 @@ static ssize_t ds_send_udp(struct rsocket *rs, const void *buf, size_t len, int static ssize_t dsend(struct rsocket *rs, const void *buf, size_t len, int flags) { + struct ds_smsg *msg; struct ibv_sge sge; - int ret = 0; + uint64_t offset; + int flags, ret = 0; if (!rs->conn_dest->ah) return ds_send_udp(rs, buf, len, flags); @@ -2242,29 +2356,18 @@ static ssize_t dsend(struct rsocket *rs, const void *buf, size_t len, int flags) return ret; } - if (len <= rs->sq_inline) { - sge.addr = (uintptr_t) buf; - sge.length = len; - sge.lkey = 0; - ret = ds_send_data(rs, &sge, 1, len, IBV_SEND_INLINE); - } else if (len <= rs_sbuf_left(rs)) { - memcpy((void *) (uintptr_t) rs->ssgl[0].addr, buf, len); - rs->ssgl[0].length = len; - ret = ds_send_data(rs, rs->ssgl, 1, len, 0); - if (len < rs_sbuf_left(rs)) - rs->ssgl[0].addr += len; - else - rs->ssgl[0].addr = (uintptr_t) rs->sbuf; - } else { - rs->ssgl[0].length = rs_sbuf_left(rs); - memcpy((void *) (uintptr_t) rs->ssgl[0].addr, buf, - rs->ssgl[0].length); - rs->ssgl[1].length = len - rs->ssgl[0].length; - memcpy(rs->sbuf, buf + rs->ssgl[0].length, rs->ssgl[1].length); - ret = ds_send_data(rs, rs->ssgl, 2, len, 0); - rs->ssgl[0].addr = (uintptr_t) rs->sbuf + rs->ssgl[1].length; - } + msg = rs->smsg_free; + rs->smsg_free = msg->next; + rs->sqe_avail--; + + memcpy((void *) msg, rs->conn_dest->qp->hdr, rs->conn_dest->qp->hdr.length); + memcpy((void *) msg + rs->conn_dest->qp->hdr.length, buf, len); + sge.addr = (uintptr_t) msg; + sge.length = rs->conn_dest->qp->hdr.length + len; + sge.lkey = rs->smr->lkey; + offset = (uint8_t *) msg - rs->sbuf; + ret = ds_post_send(rs, &sge, ds_send_wr_id(offset, sge.length)); return ret ? ret : len; } @@ -2827,13 +2930,31 @@ int rshutdown(int socket, int how) return 0; } +static void ds_shutdown(struct rsocket *rs) +{ + int ret = 0; + + if (rs->fd_flags & O_NONBLOCK) + rs_set_nonblocking(rs, 0); + + rs->state &= ~(rs_readable | rs_writable); + ds_process_cqs(rs, 0, ds_all_sends_done); + + if (rs->fd_flags & O_NONBLOCK) + rs_set_nonblocking(rs, 1); +} + int rclose(int socket) { struct rsocket *rs; rs = idm_at(&idm, socket); - if (rs->state & rs_connected) - rshutdown(socket, SHUT_RDWR); + if (rs->type == SOCK_STREAM) { + if (rs->state & rs_connected) + rshutdown(socket, SHUT_RDWR); + } else { + ds_shutdown(rs); + } rs_free(rs); return 0; -- 2.41.0