Code Example #1
    def update(self, src_rank):
        """Receive gradients from one worker and update the parameters."""
        keys = list(self.params.keys())
        grads = dict()
        # Pre-allocate one receive buffer per parameter tensor.
        recv_list = []
        for key in keys:
            to_recv = self.params[key]
            recv_list.append(torch.zeros(to_recv.size()).cuda())

        # Group the NCCL recv calls so they are issued as a single batch.
        groupStart()
        for i in range(len(keys)):
            collective.recv(recv_list[i], src_rank, "default")
        groupEnd()

        for i in range(len(keys)):
            grads[keys[i]] = recv_list[i]

        # Accumulate this worker's gradients; step once every worker has reported.
        self._inc_gradients(grads)
        if self.grad_counts == len(self.workers):
            self.optimizer.step()
            self.optimizer.zero_grad()
            # Reset the counter so the next round of gradients starts fresh.
            self.grad_counts = 0

        return True
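
update() relies on an accumulation helper that is not shown above. A minimal sketch of what _inc_gradients might look like, assuming gradients are summed into each parameter's .grad buffer and grad_counts counts the workers that have reported (the body below is an assumption, not the original implementation):

    def _inc_gradients(self, grads):
        # Hypothetical helper: add one worker's gradients into the
        # parameters' .grad buffers and count the contribution.
        for name, grad in grads.items():
            param = self.params[name]
            if param.grad is None:
                param.grad = grad.clone()
            else:
                param.grad += grad
        self.grad_counts += 1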
Code Example #2
    def __init__(self, workers, world_size, rank):
        self.params = dict()
        self.optimizer = None
        self.workers = workers
        self.world_size = world_size
        self.rank = rank
        self.grad_counts = 0
        collective.init_collective_group(self.world_size, self.rank, "nccl",
                                         "default")
        # Handshake: receive a dummy tensor from every worker, then send one
        # back, confirming the collective group is wired up on both sides.
        for i in range(len(self.workers)):
            recv = torch.zeros(1).cuda()
            collective.recv(recv, i, "default")
        for i in range(len(self.workers)):
            recv = torch.zeros(1).cuda()
            collective.send(recv, i, "default")
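
The constructor leaves self.params empty and self.optimizer as None, so something must populate them before update() can run. A hypothetical setter is sketched below; the name set_params, the use of nn.Parameter, and the SGD optimizer are assumptions, not part of the original code:

    def set_params(self, params, lr=0.01):
        # Hypothetical: register the parameter tensors and build an optimizer
        # over them so that update() has something to step.
        self.params = {
            name: torch.nn.Parameter(tensor.cuda())
            for name, tensor in params.items()
        }
        self.optimizer = torch.optim.SGD(self.params.values(), lr=lr)
        return True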
Code Example #3
    def compute(self):
        """Return the loss and send gradients to the parameter servers."""
        # First receive the current parameter shards from the servers.
        weights = self.get_weights(cpu=False)
        params = dict()
        # Pre-allocate the receive buffers so the recv calls can be grouped.
        recv_list = []
        for i in range(self.num_ps):
            recv_list.append([])
            param_shard_keys = self.name_list[i]
            for key in param_shard_keys:
                to_recv = weights[key]
                recv_list[-1].append(torch.zeros(to_recv.size()).cuda())

        # Issue every recv as one NCCL group; shard i comes from server i,
        # whose rank sits after the worker ranks.
        groupStart()
        for i in range(self.num_ps):
            for j in range(len(self.name_list[i])):
                collective.recv(recv_list[i][j], self.num_workers + i,
                                "default")
        groupEnd()
        for i in range(self.num_ps):
            param_shard_keys = self.name_list[i]
            for j in range(len(param_shard_keys)):
                params[param_shard_keys[j]] = recv_list[i][j]

        grad, loss = self.compute_gradients(params)
        split_grad = self.split_gradients(grad, self.assignments)
        # Send each server its shard of the gradients, again as one group.
        groupStart()
        for i in range(self.num_ps):
            this_shard = self.index_shard(split_grad, i)
            for _, v in this_shard.items():
                collective.send(v, self.num_workers + i, "default")
        groupEnd()
        return loss
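
compute() leans on two helpers that are not part of the example: split_gradients, which buckets the gradient dict by server assignment, and index_shard, which picks out one bucket. A minimal sketch, assuming assignments maps each parameter name to a server index (both bodies are assumptions):

    def split_gradients(self, grad, assignments):
        # Hypothetical helper: bucket the gradient dict per parameter server,
        # assuming `assignments` maps a parameter name to a server index.
        shards = [dict() for _ in range(self.num_ps)]
        for name, g in grad.items():
            shards[assignments[name]][name] = g
        return shards

    def index_shard(self, shards, index):
        # Hypothetical helper: pick out the gradient bucket for one server.
        return shards[index]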
Code Example #4
    def __init__(self, model, batch_size, world_size, rank, num_ps):
        self.model_type = model
        print("=> creating model '{}'".format(model))
        self.model = torchmodels.__dict__[model]().cuda()
        self.criterion = nn.CrossEntropyLoss().cuda()
        self.batch_size = batch_size
        self.train_loader = self.get_data_loader(self.batch_size)
        self.world_size = world_size
        self.rank = rank
        self.num_ps = num_ps
        self.num_workers = self.world_size - self.num_ps
        self.assignments = None
        # Index i of this list stores the names of the params on the ith server.
        self.name_list = [[] for i in range(num_ps)]
        collective.init_collective_group(world_size, rank, "nccl", "default")
        # Handshake with every server: send a dummy tensor, then receive one
        # back (the mirror image of the server-side loop in Code Example #2).
        for i in range(num_ps):
            send = torch.ones(1).cuda()
            collective.send(send, self.num_workers + i, "default")
        for i in range(num_ps):
            send = torch.ones(1).cuda()
            collective.recv(send, self.num_workers + i, "default")
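
get_data_loader is not shown either. A stand-in sketch using torchvision's CIFAR-10; the dataset, transforms, and loader settings are assumptions for illustration only:

    def get_data_loader(self, batch_size):
        # Hypothetical helper: build a training DataLoader.
        # CIFAR-10 and these transforms are placeholders, not the original data.
        import torchvision
        import torchvision.transforms as transforms
        from torch.utils.data import DataLoader
        transform = transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        dataset = torchvision.datasets.CIFAR10(
            root="./data", train=True, download=True, transform=transform)
        return DataLoader(dataset, batch_size=batch_size, shuffle=True,
                          num_workers=2)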
Code Example #5
    def do_recv(self, group_name="default", src_rank=0):
        # Receive into this actor's pre-allocated buffer from src_rank.
        col.recv(self.buffer, src_rank, group_name)
        return self.buffer
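
Example #5 shows only the receive side. A matching send helper, assuming the same self.buffer tensor and the col alias for ray.util.collective (the name do_send is an assumption):

    def do_send(self, group_name="default", dst_rank=0):
        # Mirror of do_recv: push this actor's buffer to dst_rank.
        col.send(self.buffer, dst_rank, group_name)
        return self.buffer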