git.openfabrics.org - ~shefty/librdmacm.git/commitdiff
Refresh of dsocket
author Sean Hefty <sean.hefty@intel.com>
Mon, 3 Dec 2012 19:22:59 +0000 (11:22 -0800)
committer Sean Hefty <sean.hefty@intel.com>
Mon, 3 Dec 2012 19:22:59 +0000 (11:22 -0800)
src/rsocket.c

index d7c3163b9a813d03e337f79e4b3d02b64ba81310..07cf31d914c67b82f1a7f175e6adf49f6537ec4b 100644 (file)
@@ -78,8 +78,11 @@ struct rs_svc_msg {
 };
 
 static pthread_t svc_id;
+static int svc_sock[2];        /* [0]: requester end, [1]: service thread end */
 static int svc_cnt;
-static int svc_fds[2];
+static int svc_size;   /* slots allocated in svc_rss / svc_fds */
+static struct rsocket **svc_rss;       /* registered rsockets, slots 1..svc_cnt */
+static struct pollfd *svc_fds; /* poll entries parallel to svc_rss; slot 0 polls svc_sock[1] */
 
 static uint16_t def_iomap_size = 0;
 static uint16_t def_inline = 64;
@@ -227,11 +230,19 @@ struct ds_header {
 #define DS_IPV4_HDR_LEN  8
 #define DS_IPV6_HDR_LEN 24
 
+struct ds_dest {       /* a datagram destination; &addr is the key put in dest_map */
+       union socket_addr addr; /* must be first */
+       struct ds_qp      *qp;  /* QP associated with this destination */
+       struct ibv_ah     *ah;  /* address handle; NOTE(review): dsend() appears to fall back to UDP while NULL */
+       uint32_t           qpn; /* QP number recorded for the destination */
+};
+
 struct ds_qp {
        dlist_t           list;
        struct rsocket    *rs;
        struct rdma_cm_id *cm_id;
        struct ds_header  hdr;
+       struct ds_dest    dest;
 
        struct ibv_mr     *smr;
        struct ibv_mr     *rmr;
@@ -240,13 +251,6 @@ struct ds_qp {
        int               cq_armed;
 };
 
-struct ds_dest {
-       union socket_addr addr; /* must be first */
-       struct ds_qp      *qp;
-       struct ibv_ah     *ah;
-       uint32_t           qpn;
-};
-
 struct rsocket {
        int               type;
        int               index;
@@ -332,13 +336,21 @@ struct rsocket {
        int               iomap_pending;
 };
 
-#define DS_UDP_TAG 0x5555555555555555ULL
+#define DS_UDP_TAG 0x55555555
 
 struct ds_udp_header {
-       uint64_t          tag;
+       uint32_t          tag;  /* DS_UDP_TAG; NOTE(review): ds_sendv_udp() still converts it with htonll() */
        uint8_t           version;
-       uint8_t           reserved[3];
-       uint32_t          qpn;  /* upper 8-bits reserved */
+       uint8_t           op;   /* operation code supplied by the caller of ds_sendv_udp() */
+       uint8_t           length;       /* NOTE(review): meaning not visible in this change - confirm */
+       uint8_t           reserved;
+       uint32_t          qpn;  /* lower 8-bits reserved */
+       union {         /* NOTE(review): presumably an IPv4 or IPv6 address; usage not shown */
+               uint32_t ipv4;
+               struct {
+                       uint8_t  addr[16];
+               } ipv6;
+       } addr;
 };
 
 #define ds_next_qp(qp) container_of((qp)->list.next, struct ds_qp, list)
@@ -362,8 +374,118 @@ static void ds_remove_qp(struct rsocket *rs, struct ds_qp *qp)
        }
 }
 
+/* Grow both poll sets by two slots, sharing one allocation; 0 or ENOMEM. */
+static int rs_svc_grow_sets(void)
+{
+       struct rsocket **rss;
+       struct pollfd *fds;
+
+       rss = calloc(svc_size + 2, sizeof(*rss) + sizeof(*fds));
+       if (!rss)
+               return ENOMEM;
+
+       svc_size += 2;
+       /* The pollfd array follows the svc_size rsocket pointers. */
+       fds = (struct pollfd *) (rss + svc_size);
+       if (svc_rss) {
+               /* Slots 0..svc_cnt are live; slot 0 is the service socket. */
+               memcpy(rss, svc_rss, sizeof(*rss) * (svc_cnt + 1));
+               memcpy(fds, svc_fds, sizeof(*fds) * (svc_cnt + 1));
+       }
+       /* svc_fds points inside the svc_rss allocation: free the base only. */
+       free(svc_rss);
+       svc_rss = rss;
+       svc_fds = fds;
+       return 0;
+}
+
+/*
+ * Append rs to the poll sets; slot 0 belongs to the service socket.
+ */
+static int rs_svc_add_rs(struct rsocket *rs)
+{
+       struct pollfd *pfd;
+
+       if (svc_cnt >= svc_size - 1) {
+               int err = rs_svc_grow_sets();
+               if (err)
+                       return err;
+       }
+       svc_rss[++svc_cnt] = rs;
+       pfd = &svc_fds[svc_cnt];
+       pfd->fd = rs->udp_sock;
+       pfd->events = POLLIN;
+       pfd->revents = 0;
+       return 0;
+}
+
+/* Remove rs from the poll sets (slots 1..svc_cnt): 0, or EBADF if absent. */
+static int rs_svc_rm_rs(struct rsocket *rs)
+{
+       int i;
+       for (i = 1; i <= svc_cnt; i++) {
+               if (svc_rss[i] == rs) {
+                       svc_rss[i] = svc_rss[svc_cnt];  /* last entry fills the hole */
+                       svc_fds[i] = svc_fds[svc_cnt];
+                       svc_cnt--;      /* shrink only after the copy */
+                       return 0;
+               }
+       }
+       return EBADF;
+}
+
+static void rs_svc_process_sock(void)
+{
+       struct rs_svc_msg msg;  /* request in; echoed back with status set */
+       if (read(svc_sock[1], &msg, sizeof msg) != sizeof msg)
+               return;         /* msg is undefined: act on nothing, reply to nothing */
+       switch (msg.op) {
+       case RS_SVC_INSERT:
+               msg.status = rs_svc_add_rs(msg.rs);
+               break;
+       case RS_SVC_REMOVE:
+               msg.status = rs_svc_rm_rs(msg.rs);
+               break;
+       default:
+               msg.status = ENOTSUP;
+               break;
+       }
+       write(svc_sock[1], &msg, sizeof msg);
+}
+
+static void rs_svc_process_rs(struct rsocket *rs)
+{
+       /* Stub: rs_svc_run() calls this for an rsocket with pending events; no handling yet. */
+}
+
 static int rs_svc_run(void *arg)
 {
+       struct rs_svc_msg msg;
+       int i, ret;
+
+       ret = rs_svc_grow_sets();
+       if (ret) {
+               msg.status = ret;
+               write(svc_sock[1] &msg, sizeof msg);
+               return ret;
+       }
+
+       svc_fds[0].fd = svc_sock[1];
+       svc_fds[0].events = POLLIN;
+       do {
+               for (i = 0; i <= svc_cnt; i++)
+                       svc_fds[i].revents = 0;
+
+               poll(svc_fds, svc_cnt + 1, -1);
+               if (svc_fds[0].revents)
+                       rs_svc_process_sock();
+
+               for (i = 1; i <= svc_cnt; i++) {
+                       if (svc_fds[i].revents)
+                               rs_svc_process_rs(svc_rss[i]);
+               }
+       } while (svc_cnt > 1);
+
        return 0;
 }
 
@@ -374,27 +496,35 @@ static int rs_svc_insert(struct rsocket *rs)
 
        pthread_mutex_lock(&mut);
        if (!svc_cnt) {
-               ret = socketpair(AF_INET, SOCK_STREAM, 0, &svc_fds);
+               ret = socketpair(AF_INET, SOCK_STREAM, 0, &svc_sock);
                if (ret)
-                       goto out;
+                       goto err1;
 
                ret = pthread_create(&svc_id, NULL, rs_svc_run, NULL);
                if (ret) {
-                       close(svc_fds[0]);
-                       close(svc_fds[1]);
                        ret = ERR(ret);
-                       goto out;
+                       goto err2;
                }
        }
 
        msg.op = RS_SVC_INSERT;
        msg.status = EINVAL;
        msg.rs = rs;
-       svc_cnt++;
-       write(svc_fds[0], &msg, sizeof msg);
-       read(svc_fds[0], &msg, sizeof msg);
+       write(svc_sock[0], &msg, sizeof msg);
+       read(svc_sock[0], &msg, sizeof msg);
        ret = ERR(msg.status);
-out:
+       if (ret && !svn_cnt)
+               goto err3;
+
+       pthread_mutex_unlock(&mut);
+       return ret;
+
+err3:
+       pthread_join(svc_id, NULL);
+err2:
+       close(svc_sock[0]);
+       close(svc_sock[1]);
+err1:
        pthread_mutex_unlock(&mut);
        return ret;
 }
@@ -408,11 +538,14 @@ static int rs_svc_remove(struct rsocket *rs)
        msg.op = RS_SVC_REMOVE;
        msg.status = EINVAL;
        msg.rs = rs;
-       write(svc_fds[0], &msg, sizeof msg);
-       read(svc_fds[0], &msg, sizeof msg);
+       write(svc_sock[0], &msg, sizeof msg);
+       read(svc_sock[0], &msg, sizeof msg);
        ret = ERR(msg.status);
-       if (!ret && !--svn_cnt)
+       if (!svn_cnt) {
                pthread_join(svc_id, NULL);
+               close(svc_sock[0]);
+               close(svc_sock[1]);
+       }
 
        pthread_mutex_unlock(&mut);
        return ret;
@@ -821,12 +954,14 @@ static void ds_free_qp(struct ds_qp *qp)
 
        if (qp->cm_id) {
                if (qp->cm_id->qp) {
+                       tdelete(&qp->dest.addr, &qp->rs->dest_map, ds_compare_dest);
                        epoll_ctl(qp->rs->epfd, EPOLL_CTL_DEL,
                                  qp->cm_id->recv_cq_channel->fd, NULL);
                        rdma_destroy_qp(qp->cm_id);
                }
                rdma_destroy_id(qp->cm_id);
        }
+
        free(qp);
 }
 
@@ -860,6 +995,7 @@ static void ds_free(struct rsocket *rs)
        if (rs->sbuf)
                free(rs->sbuf);
 
+       tdestroy(rs->dest_map, free);
        fastlock_destroy(&rs->map_lock);
        fastlock_destroy(&rs->cq_wait_lock);
        fastlock_destroy(&rs->cq_lock);
@@ -1317,6 +1453,32 @@ static void ds_format_hdr(struct ds_header *hdr, union socket_addr *addr)
        }
 }
 
+static int ds_add_qp_dest(struct ds_qp *qp, union socket_addr *addr,
+                         socklen_t addrlen)
+{
+       struct ibv_port_attr port_attr;
+       struct ibv_ah_attr attr;
+       int ret;
+       /* Fill the destination embedded in qp: its address, owner and QP number. */
+       memcpy(&qp->dest.addr, addr, addrlen);
+       qp->dest.qp = qp;
+       qp->dest.qpn = qp->cm_id->qp->qp_num;
+       /* The address handle created below targets the LID of the QP's own port. */
+       ret = ibv_query_port(qp->cm_id->verbs, qp->cm_id->port_num, &port_attr);
+       if (ret)
+               return ret;     /* NOTE(review): raw value, unlike the ERR() used below */
+
+       memset(&attr, 0, sizeof attr);
+       attr.dlid = port_attr.lid;
+       attr.port_num = qp->cm_id->port_num;
+       qp->dest.ah = ibv_create_ah(qp->cm_id->pd, &attr);
+       if (!qp->dest.ah)
+               return ERR(ENOMEM);
+       /* Register the address in the rsocket's map; NOTE(review): tsearch() result is unchecked. */
+       tsearch(&qp->dest.addr, &qp->rs->dest_map, ds_compare_addr);
+       return 0;
+}
+
 static int ds_create_qp(struct rsocket *rs, union socket_addr *src_addr,
                        socklen_t addrlen, struct ds_qp **qp)
 {
@@ -1361,7 +1523,11 @@ static int ds_create_qp(struct rsocket *rs, union socket_addr *src_addr,
        if (ret)
                goto err;
 
-       event.events = EPOLLIN | EPOLLOUT;
+       ret = ds_add_qp_dest(*qp, src_addr, addrlen);
+       if (ret)
+               goto err;
+
+       event.events = EPOLLIN;
        event.data.ptr = *qp;
        ret = epoll_ctl(rs->epfd,  EPOLL_CTL_ADD,
                        (*qp)->cm_id->recv_cq_channel->fd, &event);
@@ -1424,15 +1590,17 @@ static int ds_get_dest(struct rsocket *rs, const struct sockaddr *addr,
        if (ret)
                goto out;
 
-       *dest = calloc(1, sizeof(struct ds_dest));
-       if (!*dest) {
-               ret = ERR(ENOMEM);
-               goto out;
-       }
+       if ((addrlen != src_len) || memcmp(addr, src_addr, addrlen)) {
+               *dest = calloc(1, sizeof(struct ds_dest));
+               if (!*dest) {
+                       ret = ERR(ENOMEM);
+                       goto out;
+               }
 
-       memcpy(&(*dest)->addr, addr, addrlen);
-       (*dest)->qp = qp;
-       tsearch((*dest)->addr, &rs->dest_map, ds_compare_addr);
+               memcpy(&(*dest)->addr, addr, addrlen);
+               (*dest)->qp = qp;
+               tsearch((*dest)->addr, &rs->dest_map, ds_compare_addr);
+       }
 out:
        fastlock_release(&rs->map_lock);
        return ret;
@@ -2319,6 +2487,7 @@ static ssize_t ds_sendv_udp(struct rsocket *rs, const struct iovec *iov,
 
        hdr.tag = htonll(DS_UDP_TAG);
        hdr.version = 1;
+       hdr.op = op;
        memset(&hdr->reserved, 0, sizeof hdr->reserved);
        hdr.qpn = htonl(rs->conn_dest->qp->cm_id->qp->qp_num & 0xFFFFFF);
 
@@ -2329,15 +2498,20 @@ static ssize_t ds_sendv_udp(struct rsocket *rs, const struct iovec *iov,
        memset(&msg, 0, sizeof msg);
        msg.msg_iov = miov;
        msg.msg_iovlen = iovcnt + 1;
-       return sendmsg(rs->fd, msg, flags);
+       return sendmsg(rs->udp_sock, msg, flags);
 }
 
-static ssize_t ds_send_udp(struct rsocket *rs, const void *buf, size_t len, int flags)
+static ssize_t ds_send_udp(struct rsocket *rs, const void *buf, size_t len,
+                          int flags, uint8_t op)       /* single-buffer front end to ds_sendv_udp() */
 {
        struct iovec iov;
-       iov.iov_base = buf;
-       iov_iov_len = len;
-       return ds_sendv_udp(s, &iov, 1, flags);
+       if (buf && len) {       /* wrap the payload in one iovec */
+               iov.iov_base = (void *) buf;    /* iov_base is not const-qualified */
+               iov.iov_len = len;
+               return ds_sendv_udp(rs, &iov, 1, flags, op);
+       } else {                /* no payload: pass an empty vector */
+               return ds_sendv_udp(rs, NULL, 0, flags, op);
+       }
 }
 
 static ssize_t dsend(struct rsocket *rs, const void *buf, size_t len, int flags)
@@ -2348,7 +2522,7 @@ static ssize_t dsend(struct rsocket *rs, const void *buf, size_t len, int flags)
        int flags, ret = 0;
 
        if (!rs->conn_dest->ah)
-               return ds_send_udp(rs, buf, len, flags);
+               return ds_send_udp(rs, buf, len, flags, RS_OP_DATA);
 
        if (!ds_can_send(rs)) {
                ret = ds_get_comp(rs, rs_nonblocking(rs, flags), ds_can_send);
@@ -2633,8 +2807,8 @@ static int rs_poll_rs(struct rsocket *rs, int events,
        int ret;
 
 check_cq:
-       if ((rs->state & rs_connected) || (rs->state == rs_disconnected) ||
-           (rs->state & rs_error)) {
+       if ((rs->type == SOCK_STREAM) && ((rs->state & rs_connected) ||
+            (rs->state == rs_disconnected) || (rs->state & rs_error))) {
                rs_process_cq(rs, nonblock, test);
 
                revents = 0;
@@ -2649,6 +2823,16 @@ check_cq:
                                revents |= POLLERR;
                }
 
+               return revents;
+       } else if (rs->type == SOCK_DGRAM) {
+               ds_process_cqs(rs, nonblock, test);
+
+               revents = 0;
+               if ((events & POLLIN) && rs_have_rdata(rs))
+                       revents |= POLLIN;
+               if ((events & POLLOUT) && ds_can_send(rs))
+                       revents |= POLLOUT;
+
                return revents;
        }
 
@@ -2709,11 +2893,14 @@ static int rs_poll_arm(struct pollfd *rfds, struct pollfd *fds, nfds_t nfds)
                        if (fds[i].revents)
                                return 1;
 
-                       if (rs->state >= rs_connected)
-                               rfds[i].fd = rs->cm_id->recv_cq_channel->fd;
-                       else
-                               rfds[i].fd = rs->cm_id->channel->fd;
-
+                       if (rs->type == SOCK_STREAM) {
+                               if (rs->state >= rs_connected)
+                                       rfds[i].fd = rs->cm_id->recv_cq_channel->fd;
+                               else
+                                       rfds[i].fd = rs->cm_id->channel->fd;
+                       } else {
+                               rfds[i].fd = rs->epfd;
+                       }
                        rfds[i].events = POLLIN;
                } else {
                        rfds[i].fd = fds[i].fd;
@@ -2736,7 +2923,10 @@ static int rs_poll_events(struct pollfd *rfds, struct pollfd *fds, nfds_t nfds)
 
                rs = idm_lookup(&idm, fds[i].fd);
                if (rs) {
-                       rs_get_cq_event(rs);
+                       if (rs->type == SOCK_STREAM)
+                               rs_get_cq_event(rs);
+                       else
+                               ds_get_cq_event(rs);
                        fds[i].revents = rs_poll_rs(rs, fds[i].events, 1, rs_poll_all);
                } else {
                        fds[i].revents = rfds[i].revents;