def update(self, src_rank):
    """Receive gradients and update"""
    # Receive one gradient tensor per parameter shard from the worker at
    # src_rank, accumulate it, and step the optimizer once every worker
    # has contributed.
    keys = list(self.params.keys())
    grads = dict()
    recv_list = []
    # Pre-allocate one zeroed receive buffer per parameter, matching the
    # parameter's size (NCCL recv fills the buffer in place).
    for key in keys:
        to_recv = self.params[key]
        recv_list.append(torch.zeros(to_recv.size()).cuda())
    # Batch all recvs into a single grouped NCCL call; the recv order must
    # match the sender's send order over the same key sequence.
    groupStart()
    for i in range(len(keys)):
        collective.recv(recv_list[i], src_rank, "default")
    groupEnd()
    for i in range(len(keys)):
        grads[keys[i]] = recv_list[i]
    self._inc_gradients(grads)
    # NOTE(review): grad_counts is presumably incremented (and reset)
    # inside _inc_gradients — not visible here, confirm against that method.
    if self.grad_counts == len(self.workers):
        #self.optimizer.zero_grad()
        #self._set_gradients(grads)
        self.optimizer.step()
        self.optimizer.zero_grad()
    return True
def __init__(self, workers, world_size, rank):
    """Parameter server: register state, join the NCCL group, then
    handshake with every worker (receive a ping, send a reply back).
    """
    self.params = dict()
    self.optimizer = None
    self.workers = workers
    self.world_size = world_size
    self.rank = rank
    self.grad_counts = 0
    collective.init_collective_group(self.world_size, self.rank, "nccl", "default")
    # Mirror of the worker-side handshake: workers send first and then
    # wait for our reply, so we receive first and reply second.
    num_workers = len(self.workers)
    for worker_rank in range(num_workers):
        collective.recv(torch.zeros(1, ).cuda(), worker_rank, "default")
    for worker_rank in range(num_workers):
        collective.send(torch.zeros(1, ).cuda(), worker_rank, "default")
def compute(self):
    """Run one training step and send gradients to the parameter servers.

    Receives the current parameter shards from every server, computes
    gradients on one batch, and sends each gradient shard back to the
    server that owns it.

    Returns:
        The training loss for this step.

    Fixes applied: removed debugging leftovers that broke the receive
    phase — an unconditional ``break`` that only received from server 0,
    an ``if j == 2: break`` that capped each shard at 3 tensors, a
    ``time.sleep(100)`` stall, noisy ``logging.warning`` buffer dumps,
    sentinel ``ones * 2`` receive buffers, and the unused
    ``param_shards`` list.
    """
    # First receive params from servers.
    weights = self.get_weights(cpu=False)
    params = dict()
    # Pre-allocate receive buffers, grouped per server, so the collective
    # recvs can be batched into one grouped NCCL call.
    recv_list = []
    for i in range(self.num_ps):
        recv_list.append([
            torch.zeros(weights[key].size()).cuda()
            for key in self.name_list[i]
        ])
    # Recv order must match each server's send order over the same keys.
    groupStart()
    for i in range(self.num_ps):
        for j in range(len(self.name_list[i])):
            collective.recv(recv_list[i][j], self.num_workers + i, "default")
    groupEnd()
    # Reassemble the flat name -> tensor mapping from the per-server shards.
    for i in range(self.num_ps):
        for j, key in enumerate(self.name_list[i]):
            params[key] = recv_list[i][j]
    grad, loss = self.compute_gradients(params)
    split_grad = self.split_gradients(grad, self.assignments)
    # Send each gradient shard back to the server that owns it.
    groupStart()
    for i in range(self.num_ps):
        this_shard = self.index_shard(split_grad, i)
        for _, v in this_shard.items():
            collective.send(v, self.num_workers + i, "default")
    groupEnd()
    return loss
def __init__(self, model, batch_size, world_size, rank, num_ps):
    """Worker: build the model, criterion and data loader, join the NCCL
    group, then handshake with every parameter server (send a ping,
    wait for each server's reply).
    """
    self.model_type = model
    print("=> creating model '{}'".format(model))
    self.model = torchmodels.__dict__[model]().cuda()
    self.criterion = nn.CrossEntropyLoss().cuda()
    self.batch_size = batch_size
    self.train_loader = self.get_data_loader(self.batch_size)
    self.world_size = world_size
    self.rank = rank
    self.num_ps = num_ps
    self.num_workers = self.world_size - self.num_ps
    self.assignments = None
    # name_list[i] holds the names of the params stored on the ith server.
    self.name_list = [[] for _ in range(num_ps)]
    collective.init_collective_group(world_size, rank, "nccl", "default")
    # Servers occupy the ranks after the workers: num_workers .. world_size-1.
    for offset in range(num_ps):
        collective.send(torch.ones(1, ).cuda(), self.num_workers + offset, "default")
    for offset in range(num_ps):
        collective.recv(torch.ones(1, ).cuda(), self.num_workers + offset, "default")
def do_recv(self, group_name="default", src_rank=0):
    """Blocking receive from ``src_rank`` into the pre-allocated buffer.

    Returns the buffer so the caller can read the received payload.
    """
    # NOTE(review): assumes self.buffer is already allocated with the
    # size/dtype the sender will transmit — col.recv fills it in place.
    col.recv(self.buffer, src_rank, group_name)
    return self.buffer