]> git.openfabrics.org - ~shefty/librdmacm.git/commitdiff
Refresh of dsocket
author Sean Hefty <sean.hefty@intel.com>
Wed, 28 Nov 2012 00:27:39 +0000 (16:27 -0800)
committer Sean Hefty <sean.hefty@intel.com>
Wed, 28 Nov 2012 00:27:39 +0000 (16:27 -0800)
src/rsocket.c

index a81b8f3915971eec0e802318d84e6698b5b6b425..d7c3163b9a813d03e337f79e4b3d02b64ba81310 100644 (file)
@@ -211,10 +211,27 @@ union socket_addr {
        struct sockaddr_in6     sin6;
 };
 
+/*
+ * Address header carried at the front of every datagram sent over a QP.
+ * ds_format_hdr() builds it from a socket address, dsend() copies the first
+ * 'length' bytes of it in front of the payload, and ds_set_src() turns a
+ * received header back into a sockaddr.
+ */
+struct ds_header {
+       /* 4 or 6; selects which member of 'addr' is meaningful */
+       uint8_t           version;
+       /* bytes of this header actually used: DS_IPV4_HDR_LEN or DS_IPV6_HDR_LEN */
+       uint8_t           length;
+       /* copied verbatim from sin_port / sin6_port */
+       uint16_t          port;
+       union {
+               /* version 4: copied from the sockaddr_in address */
+               uint32_t  ipv4;
+               /* version 6: sin6_flowinfo followed by the 16 address bytes */
+               struct {
+                       uint32_t flowinfo;
+                       uint8_t  addr[16];
+               } ipv6;
+       } addr;
+};
+
+/*
+ * Header bytes used per address family: the 4 leading bytes (version,
+ * length, port) plus the 4-byte ipv4 member, or plus the 20-byte ipv6 member.
+ */
+#define DS_IPV4_HDR_LEN  8
+#define DS_IPV6_HDR_LEN 24
+
 struct ds_qp {
        dlist_t           list;
        struct rsocket    *rs;
        struct rdma_cm_id *cm_id;
+       struct ds_header  hdr;
 
        struct ibv_mr     *smr;
        struct ibv_mr     *rmr;
@@ -324,7 +341,6 @@ struct ds_udp_header {
        uint32_t          qpn;  /* upper 8-bits reserved */
 };
 
-
 #define ds_next_qp(qp) container_of((qp)->list.next, struct ds_qp, list)
 
 static void ds_insert_qp(struct rsocket *rs, struct ds_qp *qp)
@@ -1285,6 +1301,22 @@ out:
        return ret;
 }
 
+/*
+ * Fill in the datagram header that describes the socket address 'addr'.
+ *
+ * hdr:  header to populate.
+ * addr: address to encode.  AF_INET produces a version 4 header of
+ *       DS_IPV4_HDR_LEN bytes; any other family is encoded as version 6
+ *       with DS_IPV6_HDR_LEN bytes.
+ *
+ * Port, address and flowinfo are copied as stored in the sockaddr; no
+ * byte-order conversion is applied here.
+ */
+static void ds_format_hdr(struct ds_header *hdr, union socket_addr *addr)
+{
+       if (addr->sa.sa_family == AF_INET) {
+               hdr->version = 4;
+               hdr->length = DS_IPV4_HDR_LEN;
+               hdr->port = addr->sin.sin_port;
+               /* sin_addr is a struct in_addr; the header stores its 32-bit word */
+               hdr->addr.ipv4 = addr->sin.sin_addr.s_addr;
+       } else {
+               hdr->version = 6;
+               hdr->length = DS_IPV6_HDR_LEN;
+               hdr->port = addr->sin6.sin6_port;
+               hdr->addr.ipv6.flowinfo = addr->sin6.sin6_flowinfo;
+               memcpy(&hdr->addr.ipv6.addr, &addr->sin6.sin6_addr, 16);
+       }
+}
+
 static int ds_create_qp(struct rsocket *rs, union socket_addr *src_addr,
                        socklen_t addrlen, struct ds_qp **qp)
 {
@@ -1301,6 +1333,7 @@ static int ds_create_qp(struct rsocket *rs, union socket_addr *src_addr,
        if (ret)
                goto err;
 
+       ds_format_hdr(&(*qp)->hdr, src_addr);
        ret = rdma_bind_addr((*qp)->cm_id, &src_addr->sa);
        if (ret)
                goto err;
@@ -1463,18 +1496,17 @@ static int rs_post_write(struct rsocket *rs,
        return rdma_seterrno(ibv_post_send(rs->cm_id->qp, &wr, &bad));
 }
 
-static int ds_post_send(struct rsocket *rs,
-                       struct ibv_sge *sgl, int nsge,
-                       uint64_t wr_id, int flags)
+static int ds_post_send(struct rsocket *rs, struct ibv_sge *sge,
+                       uint64_t wr_id)
 {
        struct ibv_send_wr wr, *bad;
 
        wr.wr_id = wr_id;
        wr.next = NULL;
-       wr.sg_list = sgl;
-       wr.num_sge = nsge;
+       wr.sg_list = sge;
+       wr.num_sge = 1;
        wr.opcode = IBV_WR_SEND;
-       wr.send_flags = flags;
+       wr.send_flags = (sge.length <= rs->sq_inline) ? IBV_SEND_INLINE : 0;
        wr.wr.ud.ah = rs->conn_dest->ah;
        wr.wr.ud.remote_qpn = rs->conn_dest->qpn;
        wr.wr.ud.remote_qkey = RDMA_UDP_QKEY;
@@ -1512,18 +1544,6 @@ static int rs_write_data(struct rsocket *rs,
                                 flags, addr, rkey);
 }
 
-static int ds_send_data(struct rsocket *rs,
-                       struct ibv_sge *sgl, int nsge,
-                       uint32_t length, int flags)
-{
-       uint64_t offset;
-
-       rs->sqe_avail--;
-       rs->sbuf_bytes_avail -= length;
-       offset = sgl->addr - (uintptr_t) rs->sbuf;
-       return ds_post_send(rs, sgl, nsge, ds_send_wr_id(offset, length), flags);
-}
-
 static int rs_write_direct(struct rsocket *rs, struct rs_iomap *iom, uint64_t offset,
                           struct ibv_sge *sgl, int nsge, uint32_t length, int flags)
 {
@@ -1779,6 +1799,14 @@ static int rs_get_comp(struct rsocket *rs, int nonblock, int (*test)(struct rsoc
        return ret;
 }
 
+/*
+ * Check that a received buffer starts with a well-formed ds_header.
+ *
+ * buf: start of the received message.
+ * len: number of bytes received.
+ *
+ * Returns non-zero when the version/length pair is one of the two known
+ * combinations and the message is long enough to contain that header,
+ * zero otherwise.  The length is compared against the per-version header
+ * length rather than sizeof(struct ds_header): the struct is sized for the
+ * IPv6 layout, so using sizeof would reject IPv4 datagrams that carry only
+ * a few payload bytes.
+ */
+static int ds_valid_recv(void *buf, uint32_t len)
+{
+       struct ds_header *hdr = (struct ds_header *) buf;
+
+       /* The shortest valid header is the IPv4 one; do not read past 'len'. */
+       if (len < DS_IPV4_HDR_LEN)
+               return 0;
+
+       if (hdr->version == 4)
+               return hdr->length == DS_IPV4_HDR_LEN;
+
+       if (hdr->version == 6)
+               return hdr->length == DS_IPV6_HDR_LEN &&
+                      len >= DS_IPV6_HDR_LEN;
+
+       return 0;
+}
+
 /*
  * Poll all CQs associated with a datagram rsocket.  We need to drop any
  * received messages that we do not have room to store.  To limit drops,
@@ -1789,7 +1817,8 @@ static int rs_get_comp(struct rsocket *rs, int nonblock, int (*test)(struct rsoc
 static void ds_poll_cqs(struct rsocket *rs)
 {
        struct ds_qp *qp;
-       struct ds_smsg *msg;
+       struct ds_smsg *smsg;
+       struct ds_rmsg *rmsg;
        struct ibv_wc wc;
        int ret, cnt;
 
@@ -1804,9 +1833,14 @@ static void ds_poll_cqs(struct rsocket *rs)
                        }
 
                        if (ds_wr_is_recv(wc.wr_id)) {
-                               if (rs->rqe_avail && wc.status == IBV_WC_SUCCESS) {
+                               if (rs->rqe_avail && wc.status == IBV_WC_SUCCESS &&
+                                   ds_valid_recv(qp->rbuf + ds_wr_offset(wc.wr_id),
+                                                 wc.byte_len)) {
                                        rs->rqe_avail--;
-                                       rs->dmsg[rs->rmsg_tail].qp = qp;
+                                       rmsg = &rs->dmsg[rs->rmsg_tail];
+                                       rmsg->qp = qp;
+                                       rmsg->offset = ds_wr_offset(wc.wr_id);
+                                       rmsg->length = wc.byte_len;
                                        if (++rs->rmsg_tail == rs->rq_size + 1)
                                                rs->rmsg_tail = 0;
                                } else {
@@ -1814,12 +1848,10 @@ static void ds_poll_cqs(struct rsocket *rs)
                                                             ds_wr_offset(wc.wr_id));
                                }
                        } else {
-                               if (ds_wr_length(wc.wr_id) > rs->sq_inline) {
-                                       msg = (struct ds_smsg *)
-                                             (rs->sbuf + ds_wr_offset(wc.wr_id));
-                                       msg->next = rs->smsg_free;
-                                       rs->smsg_free = msg;
-                               }
+                               smsg = (struct ds_smsg *)
+                                      (rs->sbuf + ds_wr_offset(wc.wr_id));
+                               smsg->next = rs->smsg_free;
+                               rs->smsg_free = smsg;
                                rs->sqe_avail++;
                        }
 
@@ -1959,7 +1991,12 @@ static int rs_can_send(struct rsocket *rs)
 
+/*
+ * Non-zero when the datagram rsocket has at least one free send queue
+ * entry (rs->sqe_avail).
+ */
 static int ds_can_send(struct rsocket *rs)
 {
-       return rs->sqe_avail && (rs->sbuf_bytes_avail >= RS_SNDLOWAT);
+       return rs->sqe_avail;
+}
+
+/*
+ * Non-zero once every send queue entry is free again, i.e. rs->sqe_avail
+ * has returned to rs->sq_size.  Used as the completion test in
+ * ds_shutdown().
+ */
+static int ds_all_sends_done(struct rsocket *rs)
+{
+       return rs->sqe_avail == rs->sq_size;
 }
 
 static int rs_conn_can_send(struct rsocket *rs)
@@ -1988,6 +2025,66 @@ static int rs_conn_all_sends_done(struct rsocket *rs)
               !(rs->state & rs_connected);
 }
 
+/*
+ * Convert a received datagram header into a socket address for the caller.
+ *
+ * addr:    destination buffer supplied by the caller.
+ * addrlen: in: size of 'addr'; out: number of bytes written, which is the
+ *          smaller of the input size and the size of the sockaddr_in /
+ *          sockaddr_in6 being returned.
+ * hdr:     header of the received message; version 4 yields AF_INET, any
+ *          other version is treated as AF_INET6.
+ */
+static void ds_set_src(struct sockaddr *addr, socklen_t *addrlen,
+                      struct ds_header *hdr)
+{
+       union socket_addr sa;
+
+       /*
+        * Clear the staging address first: fields that are not assigned
+        * below (for example sin_zero) would otherwise be copied to the
+        * caller with indeterminate contents.
+        */
+       memset(&sa, 0, sizeof(sa));
+
+       if (hdr->version == 4) {
+               if (*addrlen > sizeof(sa.sin))
+                       *addrlen = sizeof(sa.sin);
+
+               sa.sin.sin_family = AF_INET;
+               sa.sin.sin_port = hdr->port;
+               sa.sin.sin_addr.s_addr = hdr->addr.ipv4;
+       } else {
+               if (*addrlen > sizeof(sa.sin6))
+                       *addrlen = sizeof(sa.sin6);
+
+               sa.sin6.sin6_family = AF_INET6;
+               sa.sin6.sin6_port = hdr->port;
+               sa.sin6.sin6_flowinfo = hdr->addr.ipv6.flowinfo;
+               memcpy(&sa.sin6.sin6_addr, &hdr->addr.ipv6.addr, 16);
+               sa.sin6.sin6_scope_id = 0;
+       }
+       /* Copy only as much as the caller has room for. */
+       memcpy(addr, &sa, *addrlen);
+}
+
+/*
+ * Receive one datagram from the head of the rsocket's received-message
+ * queue (rs->dmsg[rs->rmsg_head]).
+ *
+ * rs:       datagram rsocket; rs_readable must be set in rs->state.
+ * buf, len: destination for the payload.  The copy is clipped to the
+ *           payload size, i.e. the message length minus the ds_header
+ *           length found at the start of the message.
+ * flags:    MSG_PEEK leaves the message queued; the flags are also passed
+ *           to rs_nonblocking() when waiting for data in ds_get_comp().
+ * src_addr, addrlen: optional.  When both are non-NULL the sender address
+ *           taken from the message header is stored via ds_set_src().
+ *
+ * Returns the number of payload bytes copied, ERR(EINVAL) when the socket
+ * is not readable, or the non-zero result of ds_get_comp().
+ */
+static ssize_t ds_recvfrom(struct rsocket *rs, void *buf, size_t len, int flags,
+                          struct sockaddr *src_addr, socklen_t *addrlen)
+{
+       struct ds_rmsg *rmsg;
+       struct ds_header *hdr;
+       int ret;
+
+       if (!(rs->state & rs_readable))
+               return ERR(EINVAL);
+
+       if (!rs_have_rdata(rs)) {
+               ret = ds_get_comp(rs, rs_nonblocking(rs, flags),
+                                 rs_have_rdata);
+               if (ret)
+                       return ret;
+       }
+
+       rmsg = &rs->dmsg[rs->rmsg_head];
+       hdr = (struct ds_header *) (rmsg->qp->rbuf + rmsg->offset);
+       if (len > rmsg->length - hdr->length)
+               len = rmsg->length - hdr->length;
+
+       /* The payload starts immediately after the variable-length header. */
+       memcpy(buf, (void *) hdr + hdr->length, len);
+
+       /*
+        * A caller may pass a NULL src_addr together with a non-NULL addrlen;
+        * only report the source when there is somewhere to store it.
+        */
+       if (src_addr && addrlen)
+               ds_set_src(src_addr, addrlen, hdr);
+
+       if (!(flags & MSG_PEEK)) {
+               /* Consume the message: repost its buffer and advance the ring. */
+               ds_post_recv(rs, rmsg->qp, hdr);
+               if (++rs->rmsg_head == rs->rq_size + 1)
+                       rs->rmsg_head = 0;
+       }
+
+       return len;
+}
+
 static ssize_t rs_peek(struct rsocket *rs, void *buf, size_t len)
 {
        size_t left = len;
@@ -2033,6 +2130,13 @@ ssize_t rrecv(int socket, void *buf, size_t len, int flags)
        int ret;
 
        rs = idm_at(&idm, socket);
+       if (rs->type == SOCK_DGRAM) {
+               fastlock_acquire(&rs->slock);
+               ret = ds_recvfrom(rs, buf, len, flags, src_addr, addrlen);
+               fastlock_release(&rs->slock);
+               return ret;
+       }
+
        if (rs->state & rs_opening) {
                ret = rs_do_connect(rs);
                if (ret) {
@@ -2093,6 +2197,14 @@ ssize_t rrecvfrom(int socket, void *buf, size_t len, int flags,
 {
        int ret;
 
+       rs = idm_at(&idm, socket);
+       if (rs->type == SOCK_DGRAM) {
+               fastlock_acquire(&rs->slock);
+               ret = ds_recvfrom(rs, buf, len, flags, src_addr, addrlen);
+               fastlock_release(&rs->slock);
+               return ret;
+       }
+
        ret = rrecv(socket, buf, len, flags);
        if (ret > 0 && src_addr)
                rgetpeername(socket, src_addr, addrlen);
@@ -2230,8 +2342,10 @@ static ssize_t ds_send_udp(struct rsocket *rs, const void *buf, size_t len, int
 
 static ssize_t dsend(struct rsocket *rs, const void *buf, size_t len, int flags)
 {
+       struct ds_smsg *msg;
        struct ibv_sge sge;
-       int ret = 0;
+       uint64_t offset;
+       int flags, ret = 0;
 
        if (!rs->conn_dest->ah)
                return ds_send_udp(rs, buf, len, flags);
@@ -2242,29 +2356,18 @@ static ssize_t dsend(struct rsocket *rs, const void *buf, size_t len, int flags)
                        return ret;
        }
 
-       if (len <= rs->sq_inline) {
-               sge.addr = (uintptr_t) buf;
-               sge.length = len;
-               sge.lkey = 0;
-               ret = ds_send_data(rs, &sge, 1, len, IBV_SEND_INLINE);
-       } else if (len <= rs_sbuf_left(rs)) {
-               memcpy((void *) (uintptr_t) rs->ssgl[0].addr, buf, len);
-               rs->ssgl[0].length = len;
-               ret = ds_send_data(rs, rs->ssgl, 1, len, 0);
-               if (len < rs_sbuf_left(rs))
-                       rs->ssgl[0].addr += len;
-               else
-                       rs->ssgl[0].addr = (uintptr_t) rs->sbuf;
-       } else {
-               rs->ssgl[0].length = rs_sbuf_left(rs);
-               memcpy((void *) (uintptr_t) rs->ssgl[0].addr, buf,
-                       rs->ssgl[0].length);
-               rs->ssgl[1].length = len - rs->ssgl[0].length;
-               memcpy(rs->sbuf, buf + rs->ssgl[0].length, rs->ssgl[1].length);
-               ret = ds_send_data(rs, rs->ssgl, 2, len, 0);
-               rs->ssgl[0].addr = (uintptr_t) rs->sbuf + rs->ssgl[1].length;
-       }
+       msg = rs->smsg_free;
+       rs->smsg_free = msg->next;
+       rs->sqe_avail--;
+
+       memcpy((void *) msg, rs->conn_dest->qp->hdr, rs->conn_dest->qp->hdr.length);
+       memcpy((void *) msg + rs->conn_dest->qp->hdr.length, buf, len);
+       sge.addr = (uintptr_t) msg;
+       sge.length = rs->conn_dest->qp->hdr.length + len;
+       sge.lkey = rs->smr->lkey;
+       offset = (uint8_t *) msg - rs->sbuf;
 
+       ret = ds_post_send(rs, &sge, ds_send_wr_id(offset, sge.length));
        return ret ? ret : len;
 }
 
@@ -2827,13 +2930,31 @@ int rshutdown(int socket, int how)
        return 0;
 }
 
+/*
+ * Quiesce a datagram rsocket before it is freed; rclose() calls this for
+ * sockets that are not SOCK_STREAM.
+ *
+ * Clears rs_readable and rs_writable from rs->state, then runs
+ * ds_process_cqs() with ds_all_sends_done as the completion test.  If the
+ * socket has O_NONBLOCK set, non-blocking mode is switched off around that
+ * call and switched back on afterwards (presumably so the wait can block
+ * until outstanding sends complete -- confirm against ds_process_cqs()).
+ * The result of ds_process_cqs() is intentionally not propagated.
+ */
+static void ds_shutdown(struct rsocket *rs)
+{
+       if (rs->fd_flags & O_NONBLOCK)
+               rs_set_nonblocking(rs, 0);
+
+       rs->state &= ~(rs_readable | rs_writable);
+       ds_process_cqs(rs, 0, ds_all_sends_done);
+
+       if (rs->fd_flags & O_NONBLOCK)
+               rs_set_nonblocking(rs, 1);
+}
+
 int rclose(int socket)
 {
        struct rsocket *rs;
 
        rs = idm_at(&idm, socket);
-       if (rs->state & rs_connected)
-               rshutdown(socket, SHUT_RDWR);
+       if (rs->type == SOCK_STREAM) {
+               if (rs->state & rs_connected)
+                       rshutdown(socket, SHUT_RDWR);
+       } else {
+               ds_shutdown(rs);
+       }
 
        rs_free(rs);
        return 0;