def _test_gather_helper(self, group, group_id, rank):
    for dest in group:
        tensor = _build_tensor(dest + 1, rank)
        tensors = [_build_tensor(dest + 1, -1) for i in group] if rank == dest else []
        dist.gather(tensor, dst=dest, gather_list=tensors, group=group_id)
        if rank == dest:
            expected_tensors = [_build_tensor(dest + 1, i) for i in group]
            for t1, t2 in zip(tensors, expected_tensors):
                self.assertEqual(t1, t2)

    self._barrier()
def gather(data, dst=0, group=None):
    """
    Run gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        dst (int): destination rank
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: on dst, a list of data gathered from each rank. Otherwise,
            an empty list.
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group=group) == 1:
        return [data]
    rank = dist.get_rank(group=group)

    tensor = _serialize_to_tensor(data, group)
    size_list, tensor = _pad_to_largest_tensor(tensor, group)

    # receiving Tensor from all ranks
    if rank == dst:
        max_size = max(size_list)
        tensor_list = [
            torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
            for _ in size_list
        ]
        dist.gather(tensor, tensor_list, dst=dst, group=group)

        data_list = []
        for size, tensor in zip(size_list, tensor_list):
            buffer = tensor.cpu().numpy().tobytes()[:size]
            data_list.append(pickle.loads(buffer))
        return data_list
    else:
        dist.gather(tensor, [], dst=dst, group=group)
        return []
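# Usage sketch (not from the source): the helper above is called collectively from
# every rank, and only the destination rank receives the non-empty list. The
# `report_metrics` wrapper and the "loss" key below are hypothetical illustrations.
import torch.distributed as dist

def report_metrics(local_metrics: dict):
    # every rank must call gather(); only dst (rank 0 here) gets the full list
    all_metrics = gather(local_metrics, dst=0)
    if dist.get_rank() == 0:
        # e.g. average a scalar metric that each rank reported
        mean_loss = sum(m["loss"] for m in all_metrics) / len(all_metrics)
        print("mean loss over {} ranks: {:.4f}".format(len(all_metrics), mean_loss))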
def check_same_model_params(same_params: bool):
    # Check that all the params are the same on all ranks
    # This should be true with and without broadcast_buffers, we don't have any real buffer here
    receptacle: List[torch.Tensor] = []

    if dist.get_backend() != "nccl":
        for pg in optimizer.param_groups:
            for p in pg["params"]:
                # Check the params
                receptacle = [p.clone() for _ in range(world_size)] if rank == 0 else []
                dist.gather(p, receptacle, dst=0)
                if rank == 0:
                    for sync_p in receptacle[1:]:
                        if same_params:
                            assert torch.all(
                                torch.eq(receptacle[0], sync_p)
                            ), "Models differ in between ranks"
                        else:
                            assert not torch.all(
                                torch.eq(receptacle[0], sync_p)
                            ), "Gradients should not have been synced"

        # Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
        if broadcast_buffers:
            for b in ddp_model.buffers():
                receptacle = [b.clone() for _ in range(world_size)] if rank == 0 else []
                dist.gather(b, receptacle, dst=0)
                if rank == 0:
                    for sync_b in receptacle[1:]:
                        if same_params:
                            assert torch.all(
                                torch.eq(receptacle[0], sync_b)
                            ), "Models differ in between ranks"
                        else:
                            assert not torch.all(
                                torch.eq(receptacle[0], sync_b)
                            ), "Gradients should not have been synced"
                    assert b.cpu().item() == 0.0
def fp_send_proc(conv_wid, conv_wn, fc_wid, fc_wn, wid, wn, pred_wid, succ_wid,
                 comm_rank, world_sz, bs, subbs, pd, input_shp, output_shp,
                 fp_tail_list, shared_cnters, global_step, sta_lidx, end_lidx):
    # fp_send:0; fp_recv:1; bp_send:2; bp_recv:3
    iter_thresh = bs / subbs
    allreduce_group, fp_gather_group, bp_scatter_group = init_processes(comm_rank, world_sz)
    print("fp_send_proc comm_rank=", comm_rank)
    # if wid == wn - 1:
    if succ_wid == -1:
        shared_cnters[1] = 4
        return
    local_fp_sent_counter = 0
    dst_rank = succ_wid * 4 + 1
    place_tensor_list = [torch.zeros(1)]
    while True:
        # print("fp send ", local_fp_sent_counter, " ", shared_cnters[1])
        # fp_tail_tensor
        if local_fp_sent_counter < shared_cnters[1]:
            # is it okay to directly send gpu tensor?
            # print("fp send ", comm_rank, " -> ", dst_rank)
            # Hard code
            if wid == 0 or wid == 1:
                # print(fp_tail_list[local_fp_sent_counter].device)
                dist.gather(tensor=fp_tail_list[local_fp_sent_counter],
                            gather_list=[],
                            dst=dst_rank,
                            group=fp_gather_group,
                            async_op=False)
            elif wid == 2:
                dist.send(tensor=fp_tail_list[local_fp_sent_counter], dst=dst_rank)
            # print("wid=", wid, " fp send ", fp_tail_list[local_fp_sent_counter].numel())
            # print("fin fp send ", comm_rank, " -> ", dst_rank)
            local_fp_sent_counter += 1
        else:
            time.sleep(0.001)
        if local_fp_sent_counter == iter_thresh:
            # reset
            local_fp_sent_counter = 0
            shared_cnters[1].zero_()
def _gather(rank, rows, columns):
    dest = 0
    tensor = _get_tensor(rank, rows, columns)
    if rank == dest:
        tensors_list = _get_zeros_tensors_list(rows, columns)
        logger.debug('Rank: {},\nTensor BEFORE gather: {}. tensors_list: {}'.format(
            rank, tensor, tensors_list))
        dist.gather(tensor=tensor, gather_list=tensors_list)
        logger.debug('Rank: {},\nTensor AFTER gather: {}. tensors_list: {}\n'.format(
            rank, tensor, tensors_list))
        for i in range(dist.get_world_size()):
            assert torch.equal(tensors_list[i], _get_tensor(i, rows, columns)), \
                'Rank {}: tensors lists are not the same after gather.'.format(rank)
    else:
        logger.debug('Rank: {},\nTensor BEFORE gather: {}\n'.format(rank, tensor))
        dist.gather(tensor=tensor, dst=dest)
        logger.debug('Rank: {},\nTensor AFTER gather: {}\n'.format(rank, tensor))
        # tensor shouldn't have changed
        assert torch.equal(tensor, _get_tensor(rank, rows, columns)), \
            'Rank {}: Tensor got changed after gather.'.format(rank)
def gather(data, *, dst_rank=0, group=None) -> list:
    """Run gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        dst_rank (int): destination rank
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: on dst_rank, a list of data gathered from each rank.
            Otherwise, an empty list.
    """
    if is_single_processes(group=group):
        return [data]

    if group is None:
        group = _get_global_gloo_group()

    tensor = _serialize_to_tensor(data, group)
    tensor, tensor_sizes = _pad_to_largest_tensor(tensor, group)

    if dist.get_rank(group=group) == dst_rank:
        max_tensor_size = max(tensor_sizes)
        tensor_list = [
            torch.empty((max_tensor_size, ), dtype=torch.uint8, device=tensor.device)
            for _ in tensor_sizes
        ]
        dist.gather(tensor, tensor_list, dst=dst_rank, group=group)

        datum = []
        for length, tensor in zip(tensor_sizes, tensor_list):
            single_data = tensor.cpu().numpy().tobytes()
            single_data = single_data[:length]
            datum.append(pickle.loads(single_data))
        return datum
    else:
        dist.gather(tensor, [], dst=dst_rank, group=group)
        return []
def main_func(numProcesses, group, src_tensor):
    while (True):
        t = torch.zeros(15)  # THE FINAL ELEMENT IS LENGTH WHEN NOT PADDED
        gather_t = [torch.ones_like(t) for _ in range(numProcesses)]
        # every process in group sends tensor to this gather_t list
        dist.gather(tensor=t, gather_list=gather_t, dst=0, group=group)
        print('GATHERED DATA')
        print(gather_t[1][:15])
        print(gather_t[2][:15])

        to_scatter = torch.rand((5, 3))
        outputTens = torch.rand((5))
        # SIZE OF EACH TENSOR to scatter is main_params.num_children*2 + 1
        # where first part is the actions, then probs, then leaf value
        # print('len to scatter: {}'.format(len(to_scatter)))
        print(to_scatter)
        to_scatter = np.split(to_scatter, 3, axis=1)

        # this is vital to make sure memory isn't shared among these vectors
        to_scatter = [torch.clone(t).squeeze() for t in to_scatter]
        # to_scatter = [x.view(1,-1) for x in to_scatter]
        # print('TO SCATTER: ', to_scatter)
        print('just before scattering: ')
        # print(to_scatter[1].type)
        # print(to_scatter[1][:15])
        # print(to_scatter[2][:15])
        dist.scatter(tensor=outputTens, scatter_list=to_scatter, src=0, group=group)

        time.sleep(5)
        exit(1)
def run(rank, numProcesses, group, trg_tensor):
    print('gathering rank: ', rank)
    # now just continually gather and scatter until scatter gives a
    # negative value which means we can exit
    # and also tell main_func that length is 0
    while (True):
        padded_output = torch.rand((15))
        print('Gathering rank: ', rank)
        print('rank: {}, sending to gather: {}'.format(rank, padded_output))
        dist.gather(tensor=padded_output, gather_list=None, dst=0, group=group)  # send to process 2
        print('Finished gather: ', rank)

        model_response = torch.rand(5)
        dist.scatter(tensor=model_response, scatter_list=None, src=0, group=group)
        print('scatter rank: {}, given: {}'.format(rank, model_response))
def gather_proc(rank):
    print("rank = ", rank)
    # every rank must contribute a tensor of the same shape for gather
    if rank == 0:
        ga_t = torch.zeros(4)
    if rank == 1:
        ga_t = torch.ones(4)
    if rank == 2:
        ga_t = torch.ones(4) * 2
    if rank == 3:
        ga_t = torch.ones(4) * 3
    print("gather tensor = ", ga_t)
    init_processes(rank, 4, backend='gloo')
    if rank == 3:
        # gather_list entries must be distinct tensors, not aliases of one another
        gather_list = [torch.zeros(4) for _ in range(4)]
        dist.gather(tensor=ga_t, gather_list=gather_list, dst=3)
        print(gather_list)
    else:
        dist.gather(tensor=ga_t, gather_list=[], dst=3)
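# Minimal self-contained sketch of the same pattern with the constraints made
# explicit: every rank contributes a tensor of the same shape, and the destination
# rank supplies one distinct tensor per rank in gather_list. Assumes a single
# machine, 4 processes, and the gloo backend; `worker` is an illustrative name.
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    t = torch.full((4,), float(rank))  # same shape on every rank
    if rank == 3:
        gather_list = [torch.zeros(4) for _ in range(world_size)]  # distinct tensors
        dist.gather(t, gather_list=gather_list, dst=3)
        print(gather_list)
    else:
        dist.gather(t, dst=3)
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(4,), nprocs=4)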
def score_parameters(self) -> List[Tensor]:
    """
    :return: List of Tensors the same shapes as the given Parameters where
        each Parameter's elements are scored by their weight times the direction
        of their gradient.
    """
    if not self._is_ddp:
        return self._movement_scores

    # move all movement scores to one device and combine into a single flat tensor
    scores_flat = torch.cat(
        [score.view(-1).to("cpu") for score in self._movement_scores]
    )

    if self._is_main_proc:
        gather_list = [
            torch.zeros_like(scores_flat) for _ in range(dist.get_world_size())
        ]
        dist.gather(scores_flat, gather_list=gather_list, group=self._gloo_handle, dst=0)
        total_scores_flat = torch.sum(torch.stack(gather_list), dim=0)
    else:
        dist.gather(scores_flat, group=self._gloo_handle, dst=0)

    # broadcast total scores to all devices
    total_scores_flat = self._broadcast_list_from_main(
        [total_scores_flat if self._is_main_proc else None]
    )[0]

    # move total scores to correct device on each process
    score_idx = 0
    for idx, score in enumerate(self._movement_scores):
        next_idx = score_idx + score.numel()
        score.view(-1)[:] = total_scores_flat[score_idx:next_idx].to(score.device)
        score_idx = next_idx

    return self._movement_scores
def average_models(model, group=None, choose_r0=True, weights=None):
    global fl_round
    global rat_per_class
    gp_size = len(all_groups_np[fl_round % len(all_groups)])
    if rank == 0 and opt.weight_avg and weights is not None:
        cur_gp = all_groups_np[fl_round % len(all_groups)]
        if opt.weight_scheme == 'exp':
            # Getting e^w for each w in weights (w here is the success rate of workers' generators)
            e_w = [np.exp(w.item()) for w in weights]
        else:
            e_w = [w.item() for w in weights]
        e_w = np.array(e_w)
        if not choose_r0:
            e_w /= sum(e_w[1:])
        else:
            e_w /= sum(e_w)
        if opt.weight_scheme == 'dirac':
            # The threshold here is 0.5
            e_w = np.array([0 if w < 0.5 else w for w in e_w])
            # Reweighting after removing the harmful/useless updates
            # (could work as a simulation of taking the forgiving updates)
            if not choose_r0:
                e_w /= sum(e_w[1:])
            else:
                e_w /= sum(e_w)
    for param in model.parameters():
        if rank == 0 and not choose_r0:
            # If rank 0 is not included in this round, put zeros instead
            param.data = torch.zeros(param.size()).cuda()
        if not opt.weight_avg or weights is None:
            dist.reduce(param.data, dst=0, op=dist.ReduceOp.SUM, group=group)
            param.data /= (gp_size if choose_r0 else gp_size - 1)
        else:
            gather_list = []
            if rank == 0:
                gather_list = [
                    torch.zeros(param.size()).cuda() if cuda else torch.zeros(param.size())
                    for _ in range(gp_size)
                ]
            dist.gather(param.data, gather_list, dst=0, group=group)
            if rank == 0:
                param.data = torch.zeros(param.size()).cuda() if cuda else torch.zeros(param.size())
                for w, t in zip(e_w, gather_list):
                    param.data += t * w
def average_gradients(model):
    size = float(dist.get_world_size())
    for param in model.parameters():
        """ using all_reduce """
        # dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
        # param.grad.data /= size

        """ using gather and scatter """
        # group = dist.new_group(list(range(int(size))))
        gather_list, scatter_list = None, None
        if args.rank == 0:
            # each entry must be a distinct tensor; [t] * n would alias one tensor n times
            gather_list = [torch.zeros_like(param.grad.data) for _ in range(int(size))]
            scatter_list = [torch.zeros_like(param.grad.data) for _ in range(int(size))]
        dist.gather(tensor=param.grad.data, dst=0, gather_list=gather_list)
        # dist.gather(tensor=param.grad.data, dst=0)
        if args.rank == 0:
            param.grad.data /= size
        dist.scatter(tensor=param.grad.data, src=0, scatter_list=scatter_list)
        # dist.scatter(tensor=param.grad.data, src=0)

        """ using ring-reduce """
def run(rank, numProcesses, group, maxlen, main_params, trg_tensor):
    mcts = MCTS(tgt_tensor=trg_tensor, group=group, rankInGroup=rank,
                max_len=maxlen, main_params=main_params)

    # here actions is list of actions corresponding to the
    # 200 probabilities in mcts_probs
    bleu, output_states, mcts_probs, actions = mcts.translate_sentence()

    # write to file
    fileName = globalsFile.CODEPATH + 'MCTSFiles/rank' + str(rank) + '.json'
    with open(fileName, 'w') as f:
        json.dump([bleu, output_states, mcts_probs, actions], f)

    print('rank: ', rank, ' is done NOW WAITING FOR REST')
    while (True):
        # now just gathering and scattering until main exits
        padded_output = torch.zeros(maxlen + 1) * globalsFile.BLANK_WORD_ID
        dist.gather(tensor=padded_output, gather_list=None, dst=0, group=group)  # send to process 2

        model_response = torch.ones(2 * main_params.num_children + 1).double()
        dist.scatter(tensor=model_response, scatter_list=None, src=0, group=group)
def prune_and_eval(rank, size, param_name, prune_threshold, ref_model_dict,
                   ref_sorted_weights, results):
    local_ref_model_dict = ref_model_dict
    local_sorted_weights = ref_sorted_weights
    gpu_id = GPU_ID
    if rank >= 4 and rank < size:  # split tasks to different GPUs
        gpu_id = GPU_ID2
    cuda.set_device(gpu_id)
    local_checkpoint = torch.load(weights, map_location=lambda storage, loc: storage)
    local_opt, _, _ = opt_initialize(local_checkpoint, 'opennmt_translate_opt.pt',
                                     'opennmt_translate_dummy_opt.pt')
    local_opt.gpuid = [gpu_id]
    _train = torch.load(TRAIN_DATA + '.train.pt')
    _valid = torch.load(TRAIN_DATA + '.valid.pt')
    local_fields = load_fields(_train, _valid, local_checkpoint, local_opt)

    # local_ref_model = init_train_model(local_checkpoint, local_opt, local_fields)  # fields need data
    thenet = init_train_model(local_checkpoint, local_opt, local_fields)  # fields need data

    pruned_model = apply_prune(thenet, local_ref_model_dict, local_sorted_weights,
                               param_name, prune_threshold[rank])
    fitness = evaluate(pruned_model, _valid, local_fields, local_opt)
    fitness[2] = rank
    tensor_list = []
    if rank == 0:  # master node
        tensor_list = [torch.FloatTensor([0.0, 0.1, 0.2]) for i in range(size)]
        dist.gather(fitness, gather_list=tensor_list)
        for ind_i in range(size):
            results[ind_i].copy_(tensor_list[ind_i])
    else:
        dist.gather(fitness, dst=0)
def run():
    src = dst = 0
    mytensor = torch.zeros(1000)
    dist.scatter(mytensor, src=src)

    # processing
    features, num_frames, freqs = mysimpl(mytensor)
    frames_features = {}
    for frame in range(num_frames + 1):
        frames_features[frame] = []
    for x in features:
        # print(x[0], x[1], x[2])
        # x[1] = frame number, x[0] = amp, x[2] = freq
        frames_features[int(x[1])].append((x[0], x[2]))

    frame_freq_bins = []
    for x in range(num_frames + 1):
        freq_bins = np.zeros(2048)
        # dict with key as freqbin
        to_be_added = {}
        for y in frames_features[x]:
            index_i = np.abs(freqs - y[1]).argmin()
            if (y[1] < freqs[index_i]):
                index_i -= 1
            if index_i not in to_be_added.keys():
                to_be_added[index_i] = []
            to_be_added[index_i].append(y[0])
        all_non_zero_bins = to_be_added.keys()
        for x in all_non_zero_bins:
            amp_array = to_be_added[x]
            amp_array = np.array(amp_array)
            avg_amp = np.mean(amp_array)
            freq_bins[x] += avg_amp
        # freq_bins = torch.LongTensor(freq_bins)
        frame_freq_bins.append(freq_bins)

    frame_freq_bins = np.array(frame_freq_bins)
    frame_freq_bins = torch.from_numpy(frame_freq_bins)
    dist.gather(frame_freq_bins, dst=dst)
    return
def main_func(numProcesses, group):
    # gather_list = [torch.ones(1), torch.ones(1)]
    t = torch.ones(1)
    for i in range(10):
        print('WORLD SIZE: ', dist.get_world_size())
        gather_t = [torch.ones_like(t) for _ in range(dist.get_world_size())]
        dist.gather(tensor=torch.zeros(1), gather_list=gather_t, dst=0, group=group)
        # tensor = torch.ones(1) * 9
        # dist.all_reduce(tensor, op=dist.reduce_op.SUM, group=group)
        print('IN MAIN gather: ', gather_t)

        gather_t = [i + 1 for i in gather_t]
        # now add 1 to each of these then scatter back
        outputTens = torch.ones(1)
        dist.scatter(tensor=outputTens, scatter_list=gather_t, src=0, group=group)
        print('main process: outputtens: ', outputTens)
def run(rank, size):
    t = torch.tensor([rank for _ in range(1)])
    for i in range(10):
        print("----gather-------")
        if rank == 0:
            gather_t = [torch.ones_like(t) for _ in range(size)]
            dist.gather(tensor=t, dst=0, gather_list=gather_t)
            print(gather_t)
        else:
            t.add_(rank)
            dist.gather(tensor=t, dst=0, gather_list=[])
            print(t)
        # t.add_(1)

        print("----broadcast-------")
        if rank == 0:
            b = torch.tensor([i for _ in range(1)])
            dist.broadcast(tensor=b, src=0)
            print(b)
        else:
            dist.broadcast(tensor=t, src=0)
            print(t)
def backward(ctx, *grad_output):
    global myreq
    my_rank = dist.get_rank()
    a2ai = ctx.a2ai
    grad_output = [t.contiguous() for t in grad_output]
    mb_split_lengths = a2ai.gNS if a2ai.gNS else [a2ai.lN] * my_size
    per_rank_split_lengths = a2ai.gSS if a2ai.gSS else [a2ai.lS] * my_size
    grad_inputs = [grad_output[0].new_empty([ctx.a2ai.N, ctx.a2ai.E]) for _ in range(a2ai.lS)]
    req_list = []
    ind = 0
    for i in range(my_size):
        for j in range(per_rank_split_lengths[i]):
            gather_list = list(grad_inputs[j].split(mb_split_lengths, dim=0)) if i == my_rank else None
            req = dist.gather(grad_output[ind], gather_list, dst=i, async_op=True)
            req_list.append(req)
            ind += 1
    myreq.req = req_list
    myreq.tensor = grad_inputs
    return tuple(grad_output)
def backward(ctx, *grad_output):
    global myreq
    my_rank = dist.get_rank()
    # print("All2All_Scatter_Wait:backward")
    assert len(grad_output) == my_size
    scatter_list = [t.contiguous() for t in grad_output]
    a2ai = ctx.a2ai
    mb_split_lengths = a2ai.gNS if a2ai.gNS else a2ai.lN
    emb_split_lengths = a2ai.gSS if a2ai.gSS else [a2ai.lS] * my_size
    grad_input = grad_output[0].new_empty([a2ai.N, a2ai.E * a2ai.lS])
    gather_list = list(grad_input.split(mb_split_lengths, dim=0))
    req_list = []
    for i in range(my_size):
        # req = dist.scatter(gather_list[i], scatter_list if i == my_rank else [], src=i, async_op=True)
        req = dist.gather(scatter_list[i], gather_list if i == my_rank else [], dst=i, async_op=True)
        req_list.append(req)
    myreq.req = req_list
    myreq.tensor = grad_input
    return grad_output
def gather(self, collectiveArgs, retFlag=False, pair=False):
    if pair:
        ipTensors = collectiveArgs.ipTensor_pair
        opTensors = collectiveArgs.opTensor_pair
    else:
        ipTensors = collectiveArgs.ipTensor
        opTensors = collectiveArgs.opTensor

    retObj = dist.gather(
        gather_list=opTensors
        if (collectiveArgs.global_rank == collectiveArgs.srcOrDst)
        else None,
        tensor=ipTensors,
        dst=collectiveArgs.srcOrDst,
        group=collectiveArgs.group,
        async_op=collectiveArgs.asyncOp,
    )  # synchronicity is maintained in runColl

    if collectiveArgs.asyncOp:
        collectiveArgs.waitObj.append(retObj)

    if retFlag:
        return retObj
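# When async_op=True, dist.gather returns a work handle instead of completing in
# place; the gathered tensors are only valid after wait() (the benchmark above defers
# this via collectiveArgs.waitObj). A hedged sketch of the bare pattern, assuming
# `tensor`, `gather_list`, `rank`, and `dst` are set up as in the other snippets:
work = dist.gather(
    tensor,
    gather_list=gather_list if rank == dst else None,
    dst=dst,
    async_op=True,
)
# ... overlap independent computation here ...
work.wait()  # gather_list contents are only safe to read after wait()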
def gather_objs(data, dst=0, group=None):
    """
    Run gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        dst (int): destination rank
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: on dst, a list of data gathered from each rank. Otherwise,
            an empty list.
    """
    if get_world_size() == 1:
        return [data]
    if dist.get_world_size() == 1:
        return [data]
    rank = dist.get_rank()
    world_size = get_world_size()

    tensor, local_size = _object_to_tensor(data)

    size_list = [
        torch.zeros(1, dtype=torch.int64, device=tensor.device)
        for _ in range(world_size)
    ]
    dist.gather(local_size, size_list if rank == 0 else None)

    # receiving Tensor from all ranks
    if rank == dst:
        tensor_list = [
            torch.empty((size,), dtype=torch.uint8, device=tensor.device)
            for size in size_list
        ]
        dist.gather(tensor, tensor_list, dst=dst)
        data_list = []
        for size, tensor in zip(size_list, tensor_list):
            buffer = tensor.cpu().numpy().tobytes()[:size]
            data_list.append(pickle.loads(buffer))
        return data_list
    else:
        dist.gather(tensor, dst=dst)
        return []
def gather( self, dst: int = 0, out: Optional[torch.Tensor] = None, ) -> None: """ Creates a full :class:`Tensor` on rank ``dst`` by gathering all shards of the sharded tensor. The API needs to be called on all ranks in SPMD fashion. All ranks should have the same ``dst``. ``out`` should be a tensor of the same size as the overall size of the sharded tensor on ``dst`` and ``None`` on all other ranks. Args: dst(int): The rank where full tensor is constructed. Default: 0 out (:class `torch.Tensor`, optional): The output full tensor. Must to be provided ONLY on ``dst`` rank. Default: ``None`` """ def shard_size(shard_md): return reduce((lambda x, y: x * y), shard_md.shard_sizes) # type: ignore[attr-defined] rank = dist.get_rank(self._process_group) full_size = self.metadata().size _validate_output_tensor_for_gather(rank, dst, full_size, out) local_shards = self.local_shards() world_size = dist.get_world_size(self._process_group) rank_sizes = [0 for _ in range(world_size)] max_rank_size = 0 shard_placement: Dict[ShardMetadata, Tuple[int, int]] = dict() # collect sizes for shard_md in self.metadata().shards_metadata: shard_rank = cast(_remote_device, shard_md.placement).rank() assert shard_rank is not None shard_placement[shard_md] = (shard_rank, rank_sizes[shard_rank]) rank_sizes[shard_rank] += shard_size(shard_md) max_rank_size = max(max_rank_size, rank_sizes[shard_rank]) gather_list: Optional[List[torch.Tensor]] if rank == dst: assert out is not None gather_list = [ torch.empty((max_rank_size, ), device=out.device) for _ in range(world_size) ] else: gather_list = None with torch.no_grad(): data = torch.empty(max_rank_size, device=self._get_preferred_device()) for shard in local_shards: src = shard.tensor.flatten() shard_offset = shard_placement[shard.metadata][1] data[shard_offset:shard_offset + src.numel()].copy_(src) dist.gather( tensor=data, gather_list=gather_list, dst=dst, group=self._process_group, ) if rank != dst: return # In _validate_output_tensor_for_gather, we raise if out == None and rank == dst out = cast(torch.Tensor, out) assert gather_list is not None full_size = self.metadata().size dims = len(full_size) for shard_md in self.metadata().shards_metadata: rank, rank_offset = shard_placement[shard_md] tensor = gather_list[rank] tensor = tensor[rank_offset:rank_offset + shard_size(shard_md)] tensor = tensor.view(shard_md.shard_sizes) out_narrow_view = out for dim in range(dims): out_narrow_view = out_narrow_view.narrow( dim, shard_md.shard_offsets[dim], shard_md.shard_sizes[dim], ) out_narrow_view.copy_(tensor)
def gather(self, tensor, gather_list, dst=0, async_op=False):
    return dist.gather(tensor, gather_list, dst, self.group, async_op)
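# Hedged usage sketch for a thin wrapper like the one above; the `comm` object and
# its construction are assumptions, not from the source. Only the destination rank
# passes a populated gather_list, mirroring dist.gather itself.
rank = dist.get_rank()
world_size = dist.get_world_size()
tensor = torch.ones(8) * rank
if rank == 0:
    outputs = [torch.empty(8) for _ in range(world_size)]
    comm.gather(tensor, outputs, dst=0)  # rank 0 receives every rank's tensor
else:
    comm.gather(tensor, None, dst=0)     # other ranks only send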
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            for i in range(0, num_tensors):
                dist.scatter(tensor, src=0)

dist.barrier()

if rank == 0:
    print_header("gather")
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        tensors = [tensor for n in range(0, dist.get_world_size())]
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            start = timer()
            for i in range(0, num_tensors):
                dist.gather(tensor, gather_list=tensors)
            end = timer()
            print_stats(bytes, num_tensors, end - start)
    print()
else:
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            for i in range(0, num_tensors):
                dist.gather(tensor, dst=0)

dist.barrier()

if rank == 0:
    print_header("all gather")
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
def retrieve(self, question_hidden_states: np.ndarray, n_docs: int) -> Tuple[np.ndarray, List[dict]]: """ Retrieves documents for specified ``question_hidden_states``. The main process, which has the access to the index stored in memory, gathers queries from all the processes in the main training process group, performs the retrieval and scatters back the results. Args: question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`): A batch of query vectors to retrieve with. n_docs (:obj:`int`): The number of docs retrieved per query. Output: retrieved_doc_embeds (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs, dim)` The retrieval embeddings of the retrieved docs per query. doc_ids (:obj:`np.ndarray` of shape :obj:`batch_size, n_docs`) The ids of the documents in the index doc_dicts (:obj:`List[dict]`): The retrieved_doc_embeds examples per query. """ # single GPU training if not dist.is_initialized(): doc_ids, retrieved_doc_embeds = self._main_retrieve( question_hidden_states, n_docs) return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts( doc_ids) # distributed training world_size = dist.get_world_size(group=self.process_group) # gather logic gather_list = None if self._is_main(): gather_list = [ torch.empty(question_hidden_states.shape, dtype=torch.float32) for _ in range(world_size) ] dist.gather(torch.tensor(question_hidden_states), dst=0, gather_list=gather_list, group=self.process_group) # scatter logic n_queries = question_hidden_states.shape[0] scatter_ids = [] scatter_vectors = [] if self._is_main(): assert len(gather_list) == world_size ids, vectors = self._main_retrieve( torch.cat(gather_list).numpy(), n_docs) ids, vectors = torch.tensor(ids), torch.tensor(vectors) scatter_ids = self._chunk_tensor(ids, n_queries) scatter_vectors = self._chunk_tensor(vectors, n_queries) doc_ids = self._scattered(scatter_ids, [n_queries, n_docs], target_type=torch.int64) retrieved_doc_embeds = self._scattered( scatter_vectors, [n_queries, n_docs, question_hidden_states.shape[1]]) return retrieved_doc_embeds.numpy(), doc_ids.numpy( ), self.index.get_doc_dicts(doc_ids)
def run(rank, model, train_pics, train_bsz):
    workers = [int(v) for v in str(args.learners).split('-')]
    _group = [w for w in workers] + [rank]
    group = dist.new_group(_group)

    for p in model.parameters():
        scatter_p_list = [p.data for _ in range(len(workers) + 1)]
        dist.scatter(tensor=p.data, scatter_list=scatter_p_list, group=group)
    print('Model Sent Finished!')

    print('Begin!')

    transform = transforms.Compose([
        transforms.Resize(128),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    tmp = [
        (0, 0)
        for _ in range(int(math.ceil(train_pics / (len(workers) * train_bsz))))
    ]
    pre_time = datetime.datetime.now()
    for epoch in range(args.epochs):
        for batch_idx, (_, _) in enumerate(tmp):
            for param in model.parameters():
                tensor = torch.zeros_like(param.data)
                # FIXME FIXED: every tensor in gather_list must be a fresh object,
                # otherwise the gather misbehaves
                gather_list = [
                    torch.zeros_like(param.data) for _ in range(len(workers) + 1)
                ]
                dist.gather(tensor=tensor, gather_list=gather_list, group=group)
                tensor = sum(gather_list) / len(workers)
                param.data -= tensor
                scatter_list = [param.data for _ in range(len(workers) + 1)]
                dist.scatter(tensor=tensor, scatter_list=scatter_list, group=group)
            print('Done {}/{}!'.format(batch_idx, len(tmp)))
        print('Done Epoch {}/{}!'.format(epoch + 1, args.epochs))
    end_time = datetime.datetime.now()

    # evaluate the parameter server's model accuracy
    h, remainder = divmod((end_time - pre_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time %02d:%02d:%02d" % (h, m, s)
    test_dataset = datasets.CIFAR10(args.data_dir,
                                    train=False,
                                    download=False,
                                    transform=transform)
    criterion = torch.nn.CrossEntropyLoss()
    test_data = DataLoader(test_dataset, batch_size=128, shuffle=True)
    test_loss, acc = test_model(dist.get_rank(), model, test_data, criterion=criterion)
    print('total time ' + str(time_str))
    f = open('./result_' + str(rank) + '_' + args.model + '.txt', 'a')
    f.write('Rank: ' + str(rank) + ', \tEpoch: ' + str(args.epochs) +
            ', \tTestLoss: ' + str(test_loss) + ', \tTestAcc: ' + str(acc) +
            ', \tTotalTime: ' + str(time_str) + '\n')
    f.close()
def run(rank, workers, model, save_path, train_data, test_data):
    # receive the initial model parameters sent from the parameter server
    _group = [w for w in workers] + [0]
    group = dist.new_group(_group)

    for p in model.parameters():
        tmp_p = torch.zeros_like(p)
        dist.scatter(tensor=tmp_p, src=0, group=group)
        p.data = tmp_p
    print('Model recved successfully!')

    optimizer = MySGD(model.parameters(), lr=0.01, momentum=0.5)
    criterion = torch.nn.CrossEntropyLoss()

    print('Begin!')

    for epoch in range(args.epochs):
        pre_time = datetime.datetime.now()
        model.train()

        # for AlexNet, decay the learning rate at the specified epochs
        if args.model == 'AlexNet':
            if epoch + 1 in [40, 60]:
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= 0.1
                    print('LR Decreased! Now: {}'.format(param_group['lr']))

        epoch_train_loss = 0
        for batch_idx, (data, target) in enumerate(train_data):
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            delta_ws = optimizer.get_delta_w()

            # synchronize with the parameter server
            for idx, param in enumerate(model.parameters()):
                dist.gather(tensor=delta_ws[idx], dst=0, group=group)
                recv = torch.zeros_like(delta_ws[idx])
                dist.scatter(tensor=recv, src=0, group=group)
                param.data = recv

            epoch_train_loss += loss.data.item()
            print('Rank {}, Epoch {}, Batch {}/{}, Loss:{}'.format(
                rank, epoch, batch_idx, len(train_data), loss.data.item()))

        end_time = datetime.datetime.now()
        h, remainder = divmod((end_time - pre_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        epoch_train_loss /= len(train_data)
        epoch_train_loss = format(epoch_train_loss, '.4f')

        # run the test set after each training epoch
        test_loss, acc = test_model(rank, model, test_data, criterion=criterion)
        print('total time ' + str(time_str))
        f = open('./result_' + str(rank) + '_' + args.model + '.txt', 'a')
        f.write('Rank: ' + str(rank) + ', \tEpoch: ' + str(epoch + 1) +
                ', \tTrainLoss: ' + str(epoch_train_loss) +
                ', \tTestLoss: ' + str(test_loss) + ', \tTestAcc: ' + str(acc) +
                ', \tTime: ' + str(time_str) + '\n')
        f.close()

        if (epoch + 1) % 5 == 0:
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            torch.save(
                model.state_dict(),
                save_path + '/' + args.model + '_' + str(epoch + 1) + '.pkl')
def main_func(numProcesses,group,src_tensor,maxlen,model): #exit('Exiting 1') starting_time = time.time() #POSSIBLE PROBLEM: #Some of the processes will finish before others from needing less calls if somehow continually hit EOS on their sims #now continually gather then scatter until we see that #the results from gather while(True): t = torch.zeros(maxlen+1) #THE FINAL ELEMENT IS LENGTH WHEN NOT PADDED gather_t = [torch.ones_like(t) for _ in range(numProcesses)] #every process in group sends tensor to this gather_t list dist.gather(tensor=t,gather_list=gather_t,dst=0,group=group) #print('GATHERED DATA') #print(gather_t[1][:15]) #print(gather_t[2][:15]) #trim them down to the maximum length of all gathered when remove padding #don't use first element in list since it's from this process dec_lengths = torch.tensor([x[-1] for x in gather_t[1:]]).long() #print('Dec_lengths: ',dec_lengths) #assert(min(dec_lengths) > 0) max_gathered_len = int(dec_lengths.max().item()) if max_gathered_len == 0: print('max gathered len is 0') #for some reason last scatter not seen by other processes so best #way to shut them all down is throw exception which returns #control our main function. exit(1) #print('max gathered len:',max_gathered_len) #print('gather_t: ',gather_t) #TO DO: #ONCE THIS IS WORKING: don't send blanks through, can filter dec_input = torch.cat([x[:max_gathered_len].view(-1,1) for x in gather_t[1:]],1).long() #print('dec_input: ',dec_input) dec_lengths[dec_lengths==0]+=1 #allows function to work for trees that are finished #mask for decoder_input happens within this function log_probs, values = model.forward(src_tensor,dec_input, sentence_lens=dec_lengths,req_grad=False) #need to get top model.num_children probs and their corresponding actions #which are the indices #print('log probs shape: ',log_probs.shape) sorted_probs,inds = torch.sort(log_probs,dim=1,descending=True) inds = inds.double() #so that concat with probs and values #now stack sorted_probs under inds then put value underneath, (this is what #the other processes are expecting as format) #print('INDS shape: ',inds.shape) #print('values shape: ',values.shape) to_scatter = torch.cat([inds[:,:model.num_children].transpose(0,1), sorted_probs[:,:model.num_children].transpose(0,1), values.unsqueeze(0)], dim=0).to('cpu') #print('to_scatter shape: ',to_scatter.shape) #print('values: ',values) #print(to_scatter[-1,:]) #print('shape to_scatter: ',to_scatter.shape) #print('to scatter type : ',to_scatter.type()) #print('shape to_scatter: ',to_scatter.shape) #now have a tensor which we need to split column wise into lists to_scatter = list(np.split(to_scatter,to_scatter.shape[1],axis=1)) #print('after split: len: {}, first el: {}'.format(len(to_scatter),to_scatter[0])) #print('first 50 of to_scatter') #print(to_scatter[1][:50]) #exit(1) #need to clone compoennets so that they don't share memory to_scatter = [t.clone().squeeze(1) for t in to_scatter] #print('len to_scatter: {}, shape to_scatter[0]: {}'.format(len(to_scatter),to_scatter[0].shape)) #now add fake tensor for this process to start of this list to_scatter.insert(0,torch.ones(len(to_scatter[0])).double()) outputTens = torch.ones(len(to_scatter[0])).double() #SIZE OF EACH TENSOR to scatter is main_params.num_children*2 +1 #where first part is the actions, then probs, then leaf value #print('len to scatter: {}'.format(len(to_scatter))) #print('just before scattering: ') #print(to_scatter[1].type) #print(to_scatter[1][:15]) #print(to_scatter[2][:15]) 
dist.scatter(tensor=outputTens,scatter_list=to_scatter,src=0,group=group)
def train_and_validate_federated_drfa(client): """The training scheme of Distributionally Robust Federated Learning DRFA. paper: https://papers.nips.cc/paper/2020/hash/ac450d10e166657ec8f93a1b65ca1b14-Abstract.html """ log('start training and validation with Federated setting.', client.args.debug) if client.args.evaluate and client.args.graph.rank == 0: # Do the testing on the server and return do_validate(client.args, client.model, client.optimizer, client.criterion, client.metrics, client.test_loader, client.all_clients_group, data_mode='test') return # Initialize lambda variable proportianate to their sample size if client.args.graph.rank == 0: gather_list_size = [ torch.tensor(0.0) for _ in range(client.args.graph.n_nodes) ] dist.gather(torch.tensor(client.args.num_samples_per_epoch, dtype=torch.float32), gather_list=gather_list_size, dst=0) lambda_vector = torch.stack( gather_list_size) / client.args.train_dataset_size else: dist.gather(torch.tensor(client.args.num_samples_per_epoch, dtype=torch.float32), dst=0) lambda_vector = torch.tensor([1 / client.args.graph.n_nodes] * client.args.graph.n_nodes) tracker = define_local_training_tracker() start_global_time = time.time() tracker['start_load_time'] = time.time() log('enter the training.', client.args.debug) # Number of communication rounds in federated setting should be defined for n_c in range(client.args.num_comms): client.args.rounds_comm += 1 client.args.comm_time.append(0.0) # Configuring the devices for this round of communication # TODO: not make the server rank hard coded log("Starting round {} of training".format(n_c + 1), client.args.debug) online_clients = set_online_clients(client.args) if n_c == 0: # The first round server should be in the communication to initilize its own training online_clients = online_clients if 0 in online_clients else online_clients + [ 0 ] online_clients_server = online_clients if 0 in online_clients else online_clients + [ 0 ] online_clients_group = dist.new_group(online_clients_server) client.args.drfa_gamma *= 0.9 if client.args.graph.rank in online_clients_server: if client.args.federated_type == 'scaffold': st = time.time() client.model_server, client.model_server_control = distribute_model_server_control( client.model_server, client.model_server_control, online_clients_group, src=0) client.args.comm_time[-1] += time.time() - st else: st = time.time() model_server = distribute_model_server(client.model_server, online_clients_group, src=0) client.args.comm_time[-1] += time.time() - st client.model.load_state_dict(client.model_server.state_dict()) # Send related variables to drfa algorithm st = time.time() dist.broadcast(client.lambda_vector, src=0, group=online_clients_group) # Sending the random number k to all nodes: # Does not fully support the epoch mode now k = torch.randint(low=1, high=client.args.local_step, size=(1, )) dist.broadcast(k, src=0, group=online_clients_group) client.args.comm_time[-1] += time.time() - st k = k.tolist()[0] local_steps = 0 # Start running updates on local machines if client.args.graph.rank in online_clients: is_sync = False while not is_sync: for _input, _target in client.train_loader: local_steps += 1 # Getting the k-th model for dual variable update if k == local_steps: client.kth_model.load_state_dict( client.model.state_dict()) client.model.train() # update local step. 
logging_load_time(tracker) # update local index and get local step client.args.local_index += 1 client.args.local_data_seen += len(_target) get_current_epoch(client.args) local_step = get_current_local_step(client.args) # adjust learning rate (based on the # of accessed samples) lr = adjust_learning_rate(client.args, client.optimizer, client.scheduler) # load data _input, _target = load_data_batch( client.args, _input, _target, tracker) # Skip batches with one sample because of BatchNorm issue in some models! if _input.size(0) == 1: is_sync = is_sync_fed(client.args) break # inference and get current performance. client.optimizer.zero_grad() loss, performance = inference(client.model, client.criterion, client.metrics, _input, _target) # compute gradient and do local SGD step. loss.backward() if client.args.federated_type == 'fedgate': for client_param, delta_param in zip( client.model.parameters(), client.model_delta.parameters()): client_param.grad.data -= delta_param.data elif client.args.federated_type == 'scaffold': for cp, ccp, scp in zip( client.model.parameters(), client.model_client_control.parameters(), client.model_server_control.parameters()): cp.grad.data += scp.data - ccp.data client.optimizer.step( apply_lr=True, apply_in_momentum=client.args.in_momentum, apply_out_momentum=False) # logging locally. # logging_computing(tracker, loss, performance, _input, lr) # display the logging info. # logging_display_training(client.args, tracker) # reset load time for the tracker. tracker['start_load_time'] = time.time() is_sync = is_sync_fed(client.args) if is_sync: break else: log("Offline in this round. Waiting on others to finish!", client.args.debug) # Validate the local models befor sync do_validate(client.args, client.model, client.optimizer, client.criterion, client.metrics, client.train_loader, online_clients_group, data_mode='train', local=True) if client.args.fed_personal: do_validate(client.args, client.model, client.optimizer, client.criterion, client.metrics, client.val_loader, online_clients_group, data_mode='validation', local=True) # Sync the model server based on model_clients log('Enter synching', client.args.debug) tracker['start_sync_time'] = time.time() client.args.global_index += 1 if client.args.federated_type == 'fedgate': client.model_server, client.model_delta = fedgate_aggregation( client.args, client.model_server, client.model, client.model_delta, client.model_memory, online_clients_group, online_clients, client.optimizer, lr, local_steps, lambda_weight=client.lambda_vector[ client.args.graph.rank].item()) elif client.args.federated_type == 'scaffold': client.model_server, client.model_client_control, client.model_server_control = scaffold_aggregation( client.args, client.model_server, client.model, client.model_server_control, client.model_client_control, online_clients_group, online_clients, client.optimizer, lr, local_steps, lambda_weight=client.lambda_vector[ client.args.graph.rank].item()) else: client.model_server = fedavg_aggregation( client.args, client.model_server, client.model, online_clients_group, online_clients, client.optimizer, lambda_weight=client.lambda_vector[ client.args.graph.rank].item()) # Average the kth_model client.kth_model = aggregate_models_virtual( client.args, client.kth_model, online_clients_group, online_clients) # evaluate the sync time logging_sync_time(tracker) # Do the validation on the server model do_validate(client.args, client.model_server, client.optimizer, client.criterion, client.metrics, client.train_loader, 
online_clients_group, data_mode='train') if client.args.fed_personal: do_validate(client.args, client.model_server, client.optimizer, client.criterion, client.metrics, client.val_loader, online_clients_group, data_mode='validation') # validate the model at the server if client.args.graph.rank == 0: do_validate(client.args, client.model_server, client.optimizer, client.criterion, client.metrics, client.test_loader, online_clients_group, data_mode='test') else: log("Offline in this round. Waiting on others to finish!", client.args.debug) # Update lambda parameters online_clients_lambda = set_online_clients(client.args) online_clients_server_lambda = online_clients_lambda if 0 in online_clients_lambda else [ 0 ] + online_clients_lambda online_clients_group_lambda = dist.new_group( online_clients_server_lambda) if client.args.graph.rank in online_clients_server_lambda: st = time.time() client.kth_model = distribute_model_server( client.kth_model, online_clients_group_lambda, src=0) client.args.comm_time[-1] += time.time() - st loss = torch.tensor(0.0) if client.args.graph.rank in online_clients_lambda: for _input, _target in client.train_loader: _input, _target = load_data_batch(client.args, _input, _target, tracker) # Skip batches with one sample because of BatchNorm issue in some models! if _input.size(0) == 1: break loss, _ = inference(client.kth_model, client.criterion, client.metrics, _input, _target) break loss_tensor_online = loss_gather( client.args, torch.tensor(loss.item()), group=online_clients_group_lambda, online_clients=online_clients_lambda) if client.args.graph.rank == 0: num_online_clients = len( online_clients_lambda ) if 0 in online_clients_lambda else len( online_clients_lambda) + 1 loss_tensor = torch.zeros(client.args.graph.n_nodes) loss_tensor[sorted( online_clients_server_lambda)] = loss_tensor_online * ( client.args.graph.n_nodes / num_online_clients) # Dual update client.lambda_vector += client.args.drfa_gamma * client.args.local_step * loss_tensor client.lambda_vector = euclidean_proj_simplex( client.lambda_vector) # Avoid zero probability lambda_zeros = client.lambda_vector <= 1e-3 if lambda_zeros.sum() > 0: client.lambda_vector[lambda_zeros] = 1e-3 client.lambda_vector /= client.lambda_vector.sum() # logging. logging_globally(tracker, start_global_time) # reset start round time. start_global_time = time.time() log( 'This round communication time is: {}'.format( client.args.comm_time[-1]), client.args.debug) dist.barrier(group=client.all_clients_group) return
def train_and_validate_federated_afl(client): """The training scheme of Federated Learning systems. This the implementation of Agnostic Federated Learning https://arxiv.org/abs/1902.00146 """ log('start training and validation with Federated setting.', client.args.debug) if client.args.evaluate and client.args.graph.rank == 0: # Do the testing on the server and return do_validate(client.args, client.model, client.optimizer, client.criterion, client.metrics, client.test_loader, client.all_clients_group, data_mode='test') return # Initialize lambda variable proportianate to their sample size if client.args.graph.rank == 0: gather_list_size = [ torch.tensor(0.0) for _ in range(client.args.graph.n_nodes) ] dist.gather(torch.tensor(client.args.num_samples_per_epoch, dtype=torch.float32), gather_list=gather_list_size, dst=0) client.lambda_vector = torch.stack( gather_list_size) / client.args.train_dataset_size else: dist.gather(torch.tensor(client.args.num_samples_per_epoch, dtype=torch.float32), dst=0) client.lambda_vector = torch.tensor([1 / client.args.graph.n_nodes] * client.args.graph.n_nodes) tracker = define_local_training_tracker() start_global_time = time.time() tracker['start_load_time'] = time.time() log('enter the training.', client.args.debug) # Number of communication rounds in federated setting should be defined for n_c in range(client.args.num_comms): client.args.rounds_comm += 1 client.args.comm_time.append(0.0) # Configuring the devices for this round of communication # TODO: not make the server rank hard coded log("Starting round {} of training".format(n_c + 1), client.args.debug) online_clients = set_online_clients(client.args) if n_c == 0: # The first round server should be in the communication to initilize its own training online_clients = online_clients if 0 in online_clients else online_clients + [ 0 ] online_clients_server = online_clients if 0 in online_clients else online_clients + [ 0 ] online_clients_group = dist.new_group(online_clients_server) if client.args.graph.rank in online_clients_server: st = time.time() client.model_server = distribute_model_server(client.model_server, online_clients_group, src=0) dist.broadcast(client.lambda_vector, src=0, group=online_clients_group) client.args.comm_time[-1] += time.time() - st client.model.load_state_dict(client.model_server.state_dict()) # This loss tensor is for those clients not participating in the first round loss = torch.tensor(0.0) # Start running updates on local machines if client.args.graph.rank in online_clients: is_sync = False while not is_sync: for _input, _target in client.train_loader: client.model.train() # update local step. logging_load_time(tracker) # update local index and get local step client.args.local_index += 1 client.args.local_data_seen += len(_target) get_current_epoch(client.args) local_step = get_current_local_step(client.args) # adjust learning rate (based on the # of accessed samples) lr = adjust_learning_rate(client.args, client.optimizer, client.scheduler) # load data _input, _target = load_data_batch( client.args, _input, _target, tracker) # Skip batches with one sample because of BatchNorm issue in some models! if _input.size(0) == 1: is_sync = is_sync_fed(client.args) break # inference and get current performance. client.optimizer.zero_grad() loss, performance = inference(client.model, client.criterion, client.metrics, _input, _target) # compute gradient and do local SGD step. 
loss.backward() client.optimizer.step( apply_lr=True, apply_in_momentum=client.args.in_momentum, apply_out_momentum=False) # logging locally. # logging_computing(tracker, loss, performance, _input, lr) # display the logging info. # logging_display_training(args, tracker) # reset load time for the tracker. tracker['start_load_time'] = time.time() is_sync = is_sync_fed(client.args) if is_sync: break else: log("Offline in this round. Waiting on others to finish!", client.args.debug) # Validate the local models befor sync do_validate(client.args, client.model, client.optimizer, client.criterion, client.metrics, client.train_loader, online_clients_group, data_mode='train', local=True) if client.args.fed_personal: do_validate(client.args, client.model, client.optimizer, client.criterion, client.metrics, client.val_loader, online_clients_group, data_mode='validation', local=True) # Sync the model server based on client models log('Enter synching', client.args.debug) tracker['start_sync_time'] = time.time() client.args.global_index += 1 client.model_server, loss_tensor_online = afl_aggregation( client.args, client.model_server, client.model, client.lambda_vector[client.args.graph.rank].item(), torch.tensor(loss.item()), online_clients_group, online_clients, client.optimizer) # evaluate the sync time logging_sync_time(tracker) # Do the validation on the server model do_validate(client.args, client.model_server, client.optimizer, client.criterion, client.metrics, client.train_loader, online_clients_group, data_mode='train') if client.args.fed_personal: do_validate(client.args, client.model_server, client.optimizer, client.criterion, client.metrics, client.val_loader, online_clients_group, data_mode='validation') # Updating lambda variable if client.args.graph.rank == 0: num_online_clients = len( online_clients ) if 0 in online_clients else len(online_clients) + 1 loss_tensor = torch.zeros(client.args.graph.n_nodes) loss_tensor[sorted(online_clients_server)] = loss_tensor_online # Dual update client.lambda_vector += client.args.drfa_gamma * loss_tensor # Projection into a simplex client.lambda_vector = euclidean_proj_simplex( client.lambda_vector) # Avoid zero probability lambda_zeros = client.lambda_vector <= 1e-3 if lambda_zeros.sum() > 0: client.lambda_vector[lambda_zeros] = 1e-3 client.lambda_vector /= client.lambda_vector.sum() # logging. logging_globally(tracker, start_global_time) # reset start round time. start_global_time = time.time() # validate the model at the server if client.args.graph.rank == 0: do_validate(client.args, client.model_server, client.optimizer, client.criterion, client.metrics, client.test_loader, online_clients_group, data_mode='test') log( 'This round communication time is: {}'.format( client.args.comm_time[-1]), client.args.debug) else: log("Offline in this round. Waiting on others to finish!", client.args.debug) dist.barrier(group=client.all_clients_group) return
def prune_and_eval(rank, size, orig_fit, acc_constraint, valid, es, ref_model, num_runs, final_results): _valid = valid gpu_id = GPU_ID total_iterations = es.Tmax / es.popsize individual_iter_count = 0 #ref_model = masked_models[rank] X = torch.Tensor(copy.deepcopy(es.pop)) communicate_size = es.n + 4 # the size of tensors transfer accross computers communicate_tensor = torch.FloatTensor(communicate_size * [0.]) fitness_list = [] itr_best_remain = 0 if rank == 0: # rank 0 is the main process to collect finesses X.share_memory_() #fitness_list = [torch.FloatTensor([0.0,0.1,0.2,0.3]).share_memory_() for i in range(size)] fitness_list = [ torch.FloatTensor(communicate_size * [0.]).share_memory_() for i in range(size) ] if rank >= 1 and rank < size: # split tasks to different GPUs gpu_id = other_GPU_IDs[rank - 1] with cuda.device(gpu_id): local_fields = onmt.IO.load_fields(torch.load(TRAIN_DATA + '.vocab.pt')) _valid.fields = local_fields # fields can not be packed, so reconstruct it in each threahds while (individual_iter_count < total_iterations): if rank == 0: # master node itr_X = torch.Tensor(es.ask()) # broadcast the fathers X.copy_(itr_X) dist.broadcast(itr_X, 0) else: # recieve fathers from the source process dist.broadcast(X, 0) # apply MP on model x = X.numpy()[rank] ref_model.change_mask(x, apply_MP_on_mask) ref_model.apply_mask() # evaluate pruned network fitness = evaluate(ref_model, _valid, local_fields) communicate_tensor[0] = fitness[0] communicate_tensor[1] = fitness[1] communicate_tensor[2] = rank communicate_tensor[3] = ref_model.get_sparsity() for i in range(x.size): communicate_tensor[i + 4] = X[rank, i] #x[i] # sync fitness if rank == 0: # collect fitness across processes dist.gather(communicate_tensor, gather_list=fitness_list) else: dist.gather(communicate_tensor, dst=0) # judge new solutions if rank == 0: # negatively correlated search in master node fit = [] X_ = [] for i in range(es.popsize): the_fitness = 100 for j in range(len( fitness_list)): # results of fitness evaluation if int(fitness_list[j] [2]) == i: # 0:ppl, 1:acc, 2:rank of individual X_.append(fitness_list[j].numpy()[4:]) if orig_fit[1] - fitness_list[j][ 1] <= acc_constraint: the_fitness = -fitness_list[j][3] else: the_fitness = (orig_fit[1] - fitness_list[j][1] ) / acc_constraint continue fit.append(the_fitness) es.tell(X_, fit) itr_best_remain = min(fit) final_results['result_NCS'].copy_(torch.Tensor(es.result()[0])) individual_iter_count += 1 if rank == 0: # record status logger.scalar_summary( 'ncs_%s_fitness' % num_runs, es.result()[1], num_runs * total_iterations + individual_iter_count) logger.scalar_summary( 'ncs_%s_best_itr_remain' % num_runs, itr_best_remain, num_runs * total_iterations + individual_iter_count) logger.histo_summary( 'ncs_%s_pop' % num_runs, es.result()[0], num_runs * total_iterations + individual_iter_count) logger.histo_summary( 'pop of 1', X_[0], num_runs * total_iterations + individual_iter_count) logger.scalar_summary( 'sp of 1', -fitness_list[0][3], num_runs * total_iterations + individual_iter_count) logger.scalar_summary( 'rank of 1', fitness_list[0][2], num_runs * total_iterations + individual_iter_count) logger.histo_summary( 'pop of 2', X_[1], num_runs * total_iterations + individual_iter_count) logger.scalar_summary( 'sp of 2', -fitness_list[1][3], num_runs * total_iterations + individual_iter_count) logger.scalar_summary( 'rank of 2', fitness_list[1][2], num_runs * total_iterations + individual_iter_count) #logger.histo_summary('pop of 3', X_[2], 
num_runs*total_iterations + individual_iter_count) #logger.scalar_summary('sp of 3', -fitness_list[2][3], num_runs*total_iterations + individual_iter_count) #logger.scalar_summary('rank of 3', fitness_list[2][2], num_runs*total_iterations + individual_iter_count) ref_model.clear_cache()
    # ---------------- ALL_REDUCE -----------------
    if True:
        tensor = new_tensor(device_id, local_value)
        dist.all_reduce(tensor, op=dist.reduce_op.SUM)
        # all ranks become 10.0
        print('{} AFTER all_reduce {}'.format(local_rank, tensor))
        assert_mean(tensor, 10.)

    # ---------------- GATHER -----------------
    if backend in ['tcp']:
        tensor = new_tensor(device_id, local_value)
        if local_rank == 0:
            gather_list = [new_tensor(device_id, 0.) for _ in range(4)]
            dist.gather(tensor, dst=0, gather_list=gather_list)
        else:
            dist.gather(tensor, dst=0)
            gather_list = None
        # rank 0 collects the per-rank values 1..4
        if local_rank == 0:
            print('{} AFTER gather {}'.format(local_rank, gather_list))
            assert_mean(gather_list[0], 1.)
            assert_mean(gather_list[1], 2.)
            assert_mean(gather_list[2], 3.)
            assert_mean(gather_list[3], 4.)

    # ---------------- ALL_GATHER -----------------
    if backend in ['tcp', 'nccl']:
        tensor = new_tensor(device_id, local_value)