Example 1
    def _test_gather_helper(self, group, group_id, rank):
        for dest in group:
            tensor = _build_tensor(dest + 1, rank)
            tensors = [_build_tensor(dest + 1, -1) for i in group] if rank == dest else []
            dist.gather(tensor, dst=dest, gather_list=tensors, group=group_id)
            if rank == dest:
                expected_tensors = [_build_tensor(dest + 1, i) for i in group]
                for t1, t2 in zip(tensors, expected_tensors):
                    self.assertEqual(t1, t2)

        self._barrier()
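The test above exercises the core contract of torch.distributed.gather: only the destination rank allocates and passes a gather_list; every other rank just sends its tensor. A minimal, self-contained sketch of that contract (a hypothetical two-process gloo run; the port number and tensor shape are illustrative, not taken from the original test):

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    # single-machine gloo rendezvous; the port is arbitrary for this sketch
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501",
                            rank=rank, world_size=world_size)
    tensor = torch.full((3,), float(rank))
    # only the destination rank supplies receive buffers
    gather_list = [torch.zeros(3) for _ in range(world_size)] if rank == 0 else None
    dist.gather(tensor, gather_list=gather_list, dst=0)
    if rank == 0:
        print(gather_list)  # [tensor([0., 0., 0.]), tensor([1., 1., 1.])]
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)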
def gather(data, dst=0, group=None):
    """
    Run gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        dst (int): destination rank
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: on dst, a list of data gathered from each rank. Otherwise,
            an empty list.
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group=group) == 1:
        return [data]
    rank = dist.get_rank(group=group)

    tensor = _serialize_to_tensor(data, group)
    size_list, tensor = _pad_to_largest_tensor(tensor, group)

    # receiving Tensor from all ranks
    if rank == dst:
        max_size = max(size_list)
        tensor_list = [
            torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
        ]
        dist.gather(tensor, tensor_list, dst=dst, group=group)

        data_list = []
        for size, tensor in zip(size_list, tensor_list):
            buffer = tensor.cpu().numpy().tobytes()[:size]
            data_list.append(pickle.loads(buffer))
        return data_list
    else:
        dist.gather(tensor, [], dst=dst, group=group)
        return []
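A hedged usage sketch for the helper above, assuming torch.distributed is already initialized and the private helpers it relies on (_serialize_to_tensor, _pad_to_largest_tensor, _get_global_gloo_group) come from the same module; the payload is hypothetical:

# every rank contributes an arbitrary picklable object
payload = {"rank": dist.get_rank(), "num_samples": 123}
all_payloads = gather(payload, dst=0)
if dist.get_rank() == 0:
    print(len(all_payloads))  # == world size on the destination rank
else:
    print(all_payloads)       # [] everywhere else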
        def check_same_model_params(same_params: bool):
            # Check that all the params are the same on all ranks
            # This should be true with and without broadcast_buffers, since we don't have any real buffers here
            receptacle: List[torch.Tensor] = []

            if dist.get_backend() != "nccl":
                for pg in optimizer.param_groups:
                    for p in pg["params"]:
                        # Check the params
                        receptacle = [p.clone() for _ in range(world_size)
                                      ] if rank == 0 else []
                        dist.gather(p, receptacle, dst=0)
                        if rank == 0:
                            for sync_p in receptacle[1:]:
                                if same_params:
                                    assert torch.all(
                                        torch.eq(receptacle[0], sync_p)
                                    ), "Models differ in between ranks"
                                else:
                                    assert not torch.all(
                                        torch.eq(receptacle[0], sync_p)
                                    ), "Gradients should not have been synced"

                # Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
                if broadcast_buffers:
                    for b in ddp_model.buffers():
                        receptacle = [b.clone() for _ in range(world_size)
                                      ] if rank == 0 else []
                        dist.gather(b, receptacle, dst=0)
                        if rank == 0:
                            for sync_b in receptacle[1:]:
                                if same_params:
                                    assert torch.all(
                                        torch.eq(receptacle[0], sync_b)
                                    ), "Models differ in between ranks"
                                else:
                                    assert not torch.all(
                                        torch.eq(receptacle[0], sync_b)
                                    ), "Gradients should not have been synced"

                        assert b.cpu().item() == 0.0
Example 4
def fp_send_proc(conv_wid, conv_wn, fc_wid, fc_wn, wid, wn, pred_wid, succ_wid,
                 comm_rank, world_sz, bs, subbs, pd, input_shp, output_shp,
                 fp_tail_list, shared_cnters, global_step, sta_lidx, end_lidx):
    #fp_send:0; fp_recv:1; bp_send:2; bp_recv:3
    iter_thresh = bs / subbs
    allreduce_group, fp_gather_group, bp_scatter_group = init_processes(
        comm_rank, world_sz)
    print("fp_send_proc comm_rank=", comm_rank)
    #if wid == wn -1:
    if succ_wid == -1:
        shared_cnters[1] = 4
        return
    local_fp_sent_counter = 0
    dst_rank = succ_wid * 4 + 1
    place_tensor_list = [torch.zeros(1)]
    while True:
        #print("fp send ", local_fp_sent_counter, " ", shared_cnters[1])
        #fp_tail_tensor
        if local_fp_sent_counter < shared_cnters[1]:
            # is it okay to directly send gpu tensor?
            #print("fp send ", comm_rank, "  -> ", dst_rank)
            #Hard code
            if wid == 0 or wid == 1:
                #print(fp_tail_list[local_fp_sent_counter].device)
                dist.gather(tensor=fp_tail_list[local_fp_sent_counter],
                            gather_list=[],
                            dst=dst_rank,
                            group=fp_gather_group,
                            async_op=False)
            elif wid == 2:
                dist.send(tensor=fp_tail_list[local_fp_sent_counter],
                          dst=dst_rank)
            #print("wid=",wid, " fp send ", fp_tail_list[local_fp_sent_counter].numel())
            #print("fin fp send ", comm_rank, "  -> ", dst_rank)
            local_fp_sent_counter += 1
        else:
            time.sleep(0.001)
        if local_fp_sent_counter == iter_thresh:
            #reset
            local_fp_sent_counter = 0
            shared_cnters[1].zero_()
def _gather(rank, rows, columns):
    dest = 0
    tensor = _get_tensor(rank, rows, columns)
    if rank == dest:
        tensors_list = _get_zeros_tensors_list(rows, columns)
        logger.debug('Rank: {},\nTensor BEFORE gather: {}. tensors_list: {}'.format(
            rank, tensor, tensors_list))
        dist.gather(tensor=tensor, gather_list=tensors_list)
        logger.debug('Rank: {},\nTensor AFTER gather: {}. tensors_list: {}\n'.format(
            rank, tensor, tensors_list))
        for i in range(dist.get_world_size()):
            assert torch.equal(tensors_list[i], _get_tensor(i, rows, columns)), \
                'Rank {}: tensors lists are not the same after gather.'.format(rank)
    else:
        logger.debug('Rank: {},\nTensor BEFORE gather: {}\n'.format(rank, tensor))
        dist.gather(tensor=tensor, dst=dest)
        logger.debug('Rank: {},\nTensor AFTER gather: {}\n'.format(rank, tensor))

    # tensor shouldn't have changed
    assert torch.equal(tensor, _get_tensor(rank, rows, columns)), \
        'Rank {}: Tensor got changed after gather.'.format(rank)
Example 6
def gather(data, *, dst_rank=0, group=None) -> list:
    """Run gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        dst_rank (int): destination rank
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: on dst, a list of data gathered from each rank. Otherwise,
            an empty list.
    """

    if is_single_processes(group=group):
        return [data]
    if group is None:
        group = _get_global_gloo_group()

    tensor = _serialize_to_tensor(data, group)
    tensor, tensor_sizes = _pad_to_largest_tensor(tensor, group)

    if dist.get_rank(group=group) == dst_rank:
        max_tensor_size = max(tensor_sizes)
        tensor_list = [
            torch.empty((max_tensor_size, ),
                        dtype=torch.uint8,
                        device=tensor.device) for _ in tensor_sizes
        ]
        dist.gather(tensor, tensor_list, dst=dst_rank, group=group)

        datum = []
        for length, tensor in zip(tensor_sizes, tensor_list):
            single_data = tensor.cpu().numpy().tobytes()
            single_data = single_data[:length]
            datum.append(pickle.loads(single_data))
        return datum
    else:
        dist.gather(tensor, [], dst=dst_rank, group=group)
        return []
Example 7
def main_func(numProcesses, group, src_tensor):

    while (True):
        t = torch.zeros(15)  #THE FINAL ELEMENT IS LENGTH WHEN NOT PADDED
        gather_t = [torch.ones_like(t) for _ in range(numProcesses)]

        #every process in group sends tensor to this gather_t list
        dist.gather(tensor=t, gather_list=gather_t, dst=0, group=group)

        print('GATHERED DATA')
        print(gather_t[1][:15])
        print(gather_t[2][:15])

        to_scatter = torch.rand((5, 3))

        outputTens = torch.rand((5))

        #SIZE OF EACH TENSOR to scatter is main_params.num_children*2 +1
        #where first part is the actions, then probs, then leaf value
        #print('len to scatter: {}'.format(len(to_scatter)))
        print(to_scatter)
        to_scatter = np.split(to_scatter, 3, axis=1)

        #this is vital to make sure memory isn't shared among these vectors
        to_scatter = [torch.clone(t).squeeze() for t in to_scatter]

        #to_scatter = [x.view(1,-1) for x in to_scatter]

        #print('TO SCATTER: ',to_scatter)
        print('just before scattering: ')
        #print(to_scatter[1].type)
        #print(to_scatter[1][:15])
        #print(to_scatter[2][:15])
        dist.scatter(tensor=outputTens,
                     scatter_list=to_scatter,
                     src=0,
                     group=group)

        time.sleep(5)
        exit(1)
Example 8
def run(rank, numProcesses, group, trg_tensor):

    print('gathering rank: ', rank)

    #now just continually gather and scatter until scatter gives a
    #negative value which means we can exit
    #and also tell main_func that length is 0
    while (True):

        padded_output = torch.rand((15))
        print('Gathering rank: ', rank)
        print('rank: {}, sending to gather: {}'.format(rank, padded_output))
        dist.gather(tensor=padded_output, gather_list=None, dst=0,
                    group=group)  #send to process 2
        print('Finished gather: ', rank)

        model_response = torch.rand(5)
        dist.scatter(tensor=model_response,
                     scatter_list=None,
                     src=0,
                     group=group)
        print('scatter rank: {}, given: {}'.format(rank, model_response))
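Examples 7 and 8 are the two halves of one loop: rank 0 supplies the gather_list and scatter_list, while the workers pass None for both. A condensed sketch of that pairing (tensor sizes are illustrative; it assumes the group contains rank 0 plus the workers):

import torch
import torch.distributed as dist


def coordinator_step(group, world_size):
    # rank 0: collect one 15-element tensor per rank, then send a 5-element reply to each
    inbox = [torch.zeros(15) for _ in range(world_size)]
    dist.gather(tensor=torch.zeros(15), gather_list=inbox, dst=0, group=group)
    outbox = [torch.rand(5) for _ in range(world_size)]
    reply = torch.zeros(5)
    dist.scatter(tensor=reply, scatter_list=outbox, src=0, group=group)
    return inbox


def worker_step(group):
    # every other rank: send its tensor, then receive its slice of the reply
    dist.gather(tensor=torch.rand(15), gather_list=None, dst=0, group=group)
    reply = torch.zeros(5)
    dist.scatter(tensor=reply, scatter_list=None, src=0, group=group)
    return reply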
Example 9
def gather_proc(rank):
	print("rank = ", rank)
	if rank == 0:
		ga_t = torch.zeros(2)
	if rank == 1:
		ga_t = torch.ones(4)
	if rank == 2:
		ga_t = torch.ones(4)*2
	if rank == 3:
		ga_t = torch.ones(4)*3
	print("gather tensor = ", ga_t)
	init_processes(rank, 4, backend='gloo')
	if rank == 3:
		g1 = torch.zeros(4)
		g2 = torch.zeros(4)
		g3 = torch.zeros(4)
		g4 = g3
		gather_list = [g1, g2, g3, g4]
		dist.gather(tensor=ga_t, gather_list=gather_list, dst=3)
		print(gather_list)
	else:
		dist.gather(tensor=ga_t, gather_list=[], dst=3)
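Note that g4 = g3 above makes slots 2 and 3 of gather_list share storage, so after the gather both slots show the same data; note also that rank 0 contributes a 2-element tensor while the receive buffers hold 4 elements, and dist.gather generally expects matching sizes across ranks. If the aliasing is not intentional, distinct receive buffers avoid it; a minimal fix sketch:

# one fresh buffer per rank, no shared storage between slots
gather_list = [torch.zeros(4) for _ in range(4)]
dist.gather(tensor=ga_t, gather_list=gather_list, dst=3)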
Example 10
    def score_parameters(self) -> List[Tensor]:
        """
        :return: List of Tensors the same shapes as the given Parameters where
            each Parameter's elements are scored by their weight times the direction
            of their gradient.
        """
        if not self._is_ddp:
            return self._movement_scores

        # move all movement scores to one device and combine into one flat tensor
        scores_flat = torch.cat(
            [score.view(-1).to("cpu") for score in self._movement_scores]
        )
        if self._is_main_proc:
            gather_list = [
                torch.zeros_like(scores_flat)
                for _ in range(dist.get_world_size())
            ]
            dist.gather(scores_flat,
                        gather_list=gather_list,
                        group=self._gloo_handle,
                        dst=0)
            total_scores_flat = torch.sum(torch.stack(gather_list), dim=0)
        else:
            dist.gather(scores_flat, group=self._gloo_handle, dst=0)

        # broadcast total scores to all devices
        total_scores_flat = self._broadcast_list_from_main(
            [total_scores_flat if self._is_main_proc else None])[0]

        # move total scores to correct device on each process
        score_idx = 0
        for idx, score in enumerate(self._movement_scores):
            next_idx = score_idx + score.numel()
            score.view(-1)[:] = total_scores_flat[score_idx:next_idx].to(
                score.device)
            score_idx = next_idx

        return self._movement_scores
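When every rank only needs the element-wise sum of the scores, the gather-then-sum-then-broadcast round trip above can be collapsed into a single collective; a minimal sketch (assuming a flat CPU tensor and a gloo group handle like the self._gloo_handle used above):

import torch
import torch.distributed as dist


def sum_scores_across_ranks(scores_flat: torch.Tensor, group) -> torch.Tensor:
    # in-place SUM reduction; every rank ends up with the same total,
    # so no separate gather + broadcast step is required
    dist.all_reduce(scores_flat, op=dist.ReduceOp.SUM, group=group)
    return scores_flat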
Example 11
def average_models(model, group=None, choose_r0=True, weights=None):
    global fl_round
    global rat_per_class
    gp_size = len(all_groups_np[fl_round%len(all_groups)])
    if rank == 0 and opt.weight_avg and weights is not None:
        cur_gp = all_groups_np[fl_round%len(all_groups)]
        if opt.weight_scheme == 'exp':
            e_w = [np.exp(w.item()) for w in weights]               #Getting e^w for each w in weights (w here is the success rate of workers' generators)
        else:
            e_w = [w.item() for w in weights]

        e_w = np.array(e_w)
        if not choose_r0:
            e_w/= sum(e_w[1:])
        else:
            e_w/= sum(e_w)
        if opt.weight_scheme == 'dirac':
            e_w = [0 if w < 0.5 else w for w in e_w]		#The threshold here is 0.5
            #Reweighting after removing the harmful/useless updates (could work as a simulation of taking the forgiving updates)
            if not choose_r0:
                e_w/= sum(e_w[1:])
            else:
                e_w/= sum(e_w)

    for param in model.parameters():
        if rank == 0 and not choose_r0:				#If rank 0 is not included in this round, put zeros instead
            param.data = torch.zeros(param.size()).cuda()
        if not opt.weight_avg or weights is None:
            dist.reduce(param.data, dst=0, op=dist.ReduceOp.SUM, group=group)
            param.data /= (gp_size if choose_r0 else gp_size - 1)
        else:
            gather_list = []
            if rank == 0:
                gather_list = [torch.zeros(param.size()).cuda() if cuda else torch.zeros(param.size()) for _ in range(gp_size)]
            dist.gather(param.data, gather_list, dst=0, group=group)
            if rank == 0:
                param.data = torch.zeros(param.size()).cuda() if cuda else torch.zeros(param.size())
                for w,t in zip(e_w,gather_list):
                    param.data+= t*w
Example 12
def average_gradients(model):
    size = float(dist.get_world_size())
    for param in model.parameters():
        """ using all_reduce """
        # dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
        # param.grad.data /= size

        """ using gather and scatter """
        # group = dist.new_group(list(range(int(size))))
        gather_list, scatter_list = None, None
        if args.rank == 0:
            gather_list = [torch.zeros_like(param.grad.data)] * int(size)
            scatter_list = [torch.zeros_like(param.grad.data)] * int(size)

        dist.gather(tensor=param.grad.data, dst=0, gather_list=gather_list)
        # dist.gather(tensor=param.grad.data, dst=0)
        if args.rank == 0:
            param.grad.data /= size

        dist.scatter(tensor=param.grad.data, src=0, scatter_list=scatter_list)
        # dist.scatter(tensor=param.grad.data, src=0)

        """ using ring-reduce """
Example 13
def run(rank,numProcesses,group,maxlen,main_params,trg_tensor):
	
	mcts = MCTS(tgt_tensor=trg_tensor,group=group,rankInGroup=rank,
				max_len=maxlen,main_params=main_params)


	#here actions is list of actions corresponding to the 
	#200 probabilities in mcts_probs
	bleu, output_states, mcts_probs,actions = mcts.translate_sentence()
	#write to file
	fileName = globalsFile.CODEPATH+'MCTSFiles/rank'+str(rank)+'.json'
	with open(fileName,'w') as f:
		json.dump([bleu,output_states,mcts_probs,actions],f)

	print('rank: ',rank, ' is done NOW WAITING FOR REST')


	while(True):
		#now just gathering and scattering until main exits
		padded_output = torch.zeros(maxlen+1)*globalsFile.BLANK_WORD_ID
		dist.gather(tensor=padded_output,gather_list=None, dst=0,group=group) #send to process 2
		model_response = torch.ones(2*main_params.num_children + 1).double()
		dist.scatter(tensor=model_response,scatter_list=None,src=0,group=group)
Example 14
def prune_and_eval(rank, size, param_name, prune_threshold, ref_model_dict,
                   ref_sorted_weights, results):
    local_ref_model_dict = ref_model_dict
    local_sorted_weights = ref_sorted_weights
    gpu_id = GPU_ID

    if rank >= 4 and rank < size:  # split tasks to different GPUs
        gpu_id = GPU_ID2
        cuda.set_device(gpu_id)

    local_checkpoint = torch.load(weights,
                                  map_location=lambda storage, loc: storage)
    local_opt, _, _ = opt_initialize(local_checkpoint,
                                     'opennmt_translate_opt.pt',
                                     'opennmt_translate_dummy_opt.pt')
    local_opt.gpuid = [gpu_id]
    _train = torch.load(TRAIN_DATA + '.train.pt')
    _valid = torch.load(TRAIN_DATA + '.valid.pt')
    local_fields = load_fields(_train, _valid, local_checkpoint, local_opt)
    #local_ref_model = init_train_model(local_checkpoint, local_opt, local_fields) # fields need data

    thenet = init_train_model(local_checkpoint, local_opt,
                              local_fields)  # fields need data
    pruned_model = apply_prune(thenet, local_ref_model_dict,
                               local_sorted_weights, param_name,
                               prune_threshold[rank])
    fitness = evaluate(pruned_model, _valid, local_fields, local_opt)
    fitness[2] = rank
    tensor_list = []

    if rank == 0:  # master node
        tensor_list = [torch.FloatTensor([0.0, 0.1, 0.2]) for i in range(size)]
        dist.gather(fitness, gather_list=tensor_list)
        for ind_i in range(size):
            results[ind_i].copy_(tensor_list[ind_i])
    else:
        dist.gather(fitness, dst=0)
Example 15
def run():
    src = dst = 0
    mytensor = torch.zeros(1000)
    dist.scatter(mytensor,src=src)

    #processing
    features,num_frames,freqs = mysimpl(mytensor)
    frames_features = {}
    for frame in range(num_frames+1):
        frames_features[frame] = []
    for x in features:
        # print(x[0],x[1],x[2]) x[1] = framenumber x[0] amp x[2] freq
        frames_features[int(x[1])].append((x[0],x[2]))
    frame_freq_bins =[]
    for x in range(num_frames+1):
        freq_bins = np.zeros(2048)
        #dict with key as freqbin
        to_be_added ={}
        for y in frames_features[x]:
            index_i = np.abs(freqs - y[1]).argmin()
            if y[1] < freqs[index_i]:
                index_i -= 1
            if index_i not in to_be_added.keys():
                to_be_added[index_i] = []
            to_be_added[index_i].append(y[0])
        all_non_zero_bins = to_be_added.keys()
        for x in all_non_zero_bins:
            amp_array =to_be_added[x]
            amp_array = np.array(amp_array)
            avg_amp = np.mean(amp_array)
            freq_bins[x] += avg_amp
            # freq_bins = torch.LongTensor(freq_bins)
        frame_freq_bins.append(freq_bins)
    frame_freq_bins = np.array(frame_freq_bins)
    frame_freq_bins = torch.from_numpy(frame_freq_bins)
    dist.gather(frame_freq_bins, dst=dst)
    return
Example 16
def main_func(numProcesses, group):
    #gather_list = [torch.ones(1),torch.ones(1)]
    t = torch.ones(1)
    for i in range(10):
        print('WORLD SIZE: ', dist.get_world_size())

        gather_t = [torch.ones_like(t) for _ in range(dist.get_world_size())]
        dist.gather(tensor=torch.zeros(1),
                    gather_list=gather_t,
                    dst=0,
                    group=group)

        #tensor = torch.ones(1)*9
        #dist.all_reduce(tensor, op=dist.reduce_op.SUM, group=group)
        print('IN MAIN gather: ', gather_t)
        gather_t = [i + 1 for i in gather_t]
        #now add 1 to each of these then scatter back
        outputTens = torch.ones(1)
        dist.scatter(tensor=outputTens,
                     scatter_list=gather_t,
                     src=0,
                     group=group)

        print('main process: outputtens: ', outputTens)
Example 17
def run(rank, size):
    t = torch.tensor([rank for _ in range(1)])

    for i in range(10):
        print("----gather-------")
        if rank == 0:
            gather_t = [torch.ones_like(t) for _ in range(size)]
            dist.gather(tensor=t, dst=0, gather_list=gather_t)
            print(gather_t)
        else:
            t.add_(rank)
            dist.gather(tensor=t, dst=0, gather_list=[])
            print(t)

        # t.add_(1)

        print("----broadcast-------")
        if rank == 0:
            b = torch.tensor([i for _ in range(1)])
            dist.broadcast(tensor=b, src=0)
            print(b)
        else:
            dist.broadcast(tensor=t, src=0)
            print(t)
Example 18
 def backward(ctx, *grad_output):
     global myreq
     my_rank = dist.get_rank()
     a2ai = ctx.a2ai
     grad_output = [t.contiguous() for t in grad_output]
     mb_split_lengths = a2ai.gNS if a2ai.gNS else [a2ai.lN] * my_size
     per_rank_split_lengths = a2ai.gSS if a2ai.gSS else [a2ai.lS] * my_size
     grad_inputs = [grad_output[0].new_empty([ctx.a2ai.N, ctx.a2ai.E]) for _ in range(a2ai.lS)]
     req_list = []
     ind = 0
     for i in range(my_size):
         for j in range(per_rank_split_lengths[i]):
             gather_list = list(grad_inputs[j].split(mb_split_lengths, dim = 0)) if i == my_rank else None
             req = dist.gather(grad_output[ind], gather_list, dst = i, async_op=True)
             req_list.append(req)
             ind += 1
     myreq.req = req_list
     myreq.tensor = grad_inputs
     return tuple(grad_output)
Example 19
 def backward(ctx, *grad_output):
     global myreq
     my_rank = dist.get_rank()
     #print("All2All_Scatter_Wait:backward")
     assert len(grad_output) == my_size
     scatter_list = [t.contiguous() for t in grad_output]
     a2ai = ctx.a2ai
     mb_split_lengths = a2ai.gNS if a2ai.gNS else a2ai.lN
     emb_split_lengths = a2ai.gSS if a2ai.gSS else [a2ai.lS] * my_size
     grad_input = grad_output[0].new_empty([a2ai.N, a2ai.E*a2ai.lS])
     gather_list = list(grad_input.split(mb_split_lengths, dim=0))
     req_list = []
     for i in range(my_size):
         #req = dist.scatter(gather_list[i], scatter_list if i == my_rank else [], src=i, async_op=True)
         req = dist.gather(scatter_list[i], gather_list if i == my_rank else [], dst=i, async_op=True)
         req_list.append(req)
     myreq.req = req_list
     myreq.tensor = grad_input
     return grad_output
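Both backward passes above issue gathers with async_op=True and stash the returned work handles on myreq; those handles must be waited on before the gathered buffers are read. A minimal completion sketch following the req_list convention used above (the names come from the snippets, the placement of the wait is an assumption):

# later, before myreq.tensor is consumed:
for req in myreq.req:
    if req is not None:        # dist.gather returns a work handle when async_op=True
        req.wait()             # block until this gather has completed
grad_inputs = myreq.tensor     # now safe to read
myreq.req = None
myreq.tensor = None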
    def gather(self, collectiveArgs, retFlag=False, pair=False):
        if pair:
            ipTensors = collectiveArgs.ipTensor_pair
            opTensors = collectiveArgs.opTensor_pair
        else:
            ipTensors = collectiveArgs.ipTensor
            opTensors = collectiveArgs.opTensor

        retObj = dist.gather(
            gather_list=opTensors
            if (collectiveArgs.global_rank == collectiveArgs.srcOrDst)
            else None,
            tensor=ipTensors,
            dst=collectiveArgs.srcOrDst,
            group=collectiveArgs.group,
            async_op=collectiveArgs.asyncOp,
        )  # synchronicity is maintained in runColl

        if collectiveArgs.asyncOp:
            collectiveArgs.waitObj.append(retObj)

        if retFlag:
            return retObj
Example 21
def gather_objs(data, dst=0, group=None):
    """
    Run gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        dst (int): destination rank
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: on dst, a list of data gathered from each rank. Otherwise,
            an empty list.
    """
    if get_world_size() == 1:
        return [data]
    if dist.get_world_size() == 1:
        return [data]
    rank = dist.get_rank()
    world_size = get_world_size()

    tensor, local_size = _object_to_tensor(data)
    size_list = [
        torch.zeros(1, dtype=torch.int64, device=tensor.device) for _ in range(world_size)
    ]
    dist.gather(local_size, size_list if rank == 0 else None)

    # receiving Tensor from all ranks
    if rank == dst:
        tensor_list = [
            torch.empty((size,), dtype=torch.uint8, device=tensor.device) for size in size_list
        ]
        dist.gather(tensor, tensor_list, dst=dst)

        data_list = []
        for size, tensor in zip(size_list, tensor_list):
            buffer = tensor.cpu().numpy().tobytes()[:size]
            data_list.append(pickle.loads(buffer))
        return data_list
    else:
        dist.gather(tensor, dst=dst)
        return []
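Recent PyTorch releases ship this pickle-and-pad pattern as a built-in object collective; a hedged alternative sketch using torch.distributed.gather_object (typically run over the gloo backend):

import torch.distributed as dist


def gather_objs_builtin(data, dst=0, group=None):
    world_size = dist.get_world_size(group=group)
    if world_size == 1:
        return [data]
    is_dst = dist.get_rank(group=group) == dst
    output = [None] * world_size if is_dst else None
    dist.gather_object(data, output, dst=dst, group=group)
    return output if is_dst else []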
Example 22
    def gather(
        self,
        dst: int = 0,
        out: Optional[torch.Tensor] = None,
    ) -> None:
        """
        Creates a full :class:`Tensor` on rank ``dst`` by gathering all shards of the
        sharded tensor.

        The API needs to be called on all ranks in SPMD fashion. All ranks should have
        the same ``dst``. ``out`` should be a tensor of the same size as the overall
        size of the sharded tensor on ``dst`` and ``None`` on all other ranks.

        Args:
            dst(int): The rank where full tensor is constructed.
                Default: 0
            out (:class `torch.Tensor`, optional): The output full tensor.
                Must to be provided ONLY on ``dst`` rank.
                Default: ``None``
        """
        def shard_size(shard_md):
            return reduce((lambda x, y: x * y),
                          shard_md.shard_sizes)  # type: ignore[attr-defined]

        rank = dist.get_rank(self._process_group)
        full_size = self.metadata().size
        _validate_output_tensor_for_gather(rank, dst, full_size, out)

        local_shards = self.local_shards()
        world_size = dist.get_world_size(self._process_group)
        rank_sizes = [0 for _ in range(world_size)]
        max_rank_size = 0
        shard_placement: Dict[ShardMetadata, Tuple[int, int]] = dict()
        # collect sizes
        for shard_md in self.metadata().shards_metadata:
            shard_rank = cast(_remote_device, shard_md.placement).rank()
            assert shard_rank is not None

            shard_placement[shard_md] = (shard_rank, rank_sizes[shard_rank])
            rank_sizes[shard_rank] += shard_size(shard_md)
            max_rank_size = max(max_rank_size, rank_sizes[shard_rank])

        gather_list: Optional[List[torch.Tensor]]
        if rank == dst:
            assert out is not None
            gather_list = [
                torch.empty((max_rank_size, ), device=out.device)
                for _ in range(world_size)
            ]
        else:
            gather_list = None

        with torch.no_grad():
            data = torch.empty(max_rank_size,
                               device=self._get_preferred_device())

            for shard in local_shards:
                src = shard.tensor.flatten()
                shard_offset = shard_placement[shard.metadata][1]
                data[shard_offset:shard_offset + src.numel()].copy_(src)

        dist.gather(
            tensor=data,
            gather_list=gather_list,
            dst=dst,
            group=self._process_group,
        )
        if rank != dst:
            return
        # In _validate_output_tensor_for_gather, we raise if out == None and rank == dst
        out = cast(torch.Tensor, out)
        assert gather_list is not None

        full_size = self.metadata().size
        dims = len(full_size)
        for shard_md in self.metadata().shards_metadata:
            rank, rank_offset = shard_placement[shard_md]
            tensor = gather_list[rank]
            tensor = tensor[rank_offset:rank_offset + shard_size(shard_md)]
            tensor = tensor.view(shard_md.shard_sizes)

            out_narrow_view = out
            for dim in range(dims):
                out_narrow_view = out_narrow_view.narrow(
                    dim,
                    shard_md.shard_offsets[dim],
                    shard_md.shard_sizes[dim],
                )

            out_narrow_view.copy_(tensor)
Example 23
 def gather(self, tensor, gather_list, dst=0, async_op=False):
     return dist.gather(tensor, gather_list, dst, self.group, async_op)
Example 24
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            for i in range(0, num_tensors):
                dist.scatter(tensor, src=0)
dist.barrier()

if rank == 0:
    print_header("gather")
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        tensors = [tensor for n in range(0, dist.get_world_size())]
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            start = timer()
            for i in range(0, num_tensors):
                dist.gather(tensor, gather_list=tensors)
            end = timer()
            print_stats(bytes, num_tensors, end - start)
    print()
else:
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            for i in range(0, num_tensors):
                dist.gather(tensor, dst=0)
dist.barrier()

if rank == 0:
    print_header("all gather")
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
    def retrieve(self, question_hidden_states: np.ndarray,
                 n_docs: int) -> Tuple[np.ndarray, List[dict]]:
        """
        Retrieves documents for specified ``question_hidden_states``. The main process, which has access to the index stored in memory, gathers queries
        from all the processes in the main training process group, performs the retrieval and scatters back the results.

        Args:
            question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`):
                A batch of query vectors to retrieve with.
            n_docs (:obj:`int`):
                The number of docs retrieved per query.

        Output:
            retrieved_doc_embeds (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs, dim)`
                The retrieval embeddings of the retrieved docs per query.
            doc_ids (:obj:`np.ndarray` of shape :obj:`batch_size, n_docs`)
                The ids of the documents in the index
            doc_dicts (:obj:`List[dict]`):
                The retrieved_doc_embeds examples per query.
        """

        # single GPU training
        if not dist.is_initialized():
            doc_ids, retrieved_doc_embeds = self._main_retrieve(
                question_hidden_states, n_docs)
            return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts(
                doc_ids)

        # distributed training
        world_size = dist.get_world_size(group=self.process_group)

        # gather logic
        gather_list = None
        if self._is_main():
            gather_list = [
                torch.empty(question_hidden_states.shape, dtype=torch.float32)
                for _ in range(world_size)
            ]
        dist.gather(torch.tensor(question_hidden_states),
                    dst=0,
                    gather_list=gather_list,
                    group=self.process_group)

        # scatter logic
        n_queries = question_hidden_states.shape[0]
        scatter_ids = []
        scatter_vectors = []
        if self._is_main():
            assert len(gather_list) == world_size
            ids, vectors = self._main_retrieve(
                torch.cat(gather_list).numpy(), n_docs)
            ids, vectors = torch.tensor(ids), torch.tensor(vectors)
            scatter_ids = self._chunk_tensor(ids, n_queries)
            scatter_vectors = self._chunk_tensor(vectors, n_queries)
        doc_ids = self._scattered(scatter_ids, [n_queries, n_docs],
                                  target_type=torch.int64)
        retrieved_doc_embeds = self._scattered(
            scatter_vectors,
            [n_queries, n_docs, question_hidden_states.shape[1]])

        return retrieved_doc_embeds.numpy(), doc_ids.numpy(
        ), self.index.get_doc_dicts(doc_ids)
Example 26
def run(rank, model, train_pics, train_bsz):
    workers = [int(v) for v in str(args.learners).split('-')]
    _group = [w for w in workers] + [rank]
    group = dist.new_group(_group)

    for p in model.parameters():
        scatter_p_list = [p.data for _ in range(len(workers) + 1)]
        dist.scatter(tensor=p.data, scatter_list=scatter_p_list, group=group)

    print('Model Sent Finished!')

    print('Begin!')

    transform = transforms.Compose([
        transforms.Resize(128),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    tmp = [
        (0, 0)
        for _ in range(int(math.ceil(train_pics / (len(workers) * train_bsz))))
    ]

    pre_time = datetime.datetime.now()
    for epoch in range(args.epochs):
        for batch_idx, (_, _) in enumerate(tmp):
            for param in model.parameters():
                tensor = torch.zeros_like(param.data)

                # FIXME FIXED: every Tensor in gather_list must be a new object, otherwise problems occur
                gather_list = [
                    torch.zeros_like(param.data)
                    for _ in range(len(workers) + 1)
                ]
                dist.gather(tensor=tensor,
                            gather_list=gather_list,
                            group=group)
                tensor = sum(gather_list) / len(workers)
                param.data -= tensor
                scatter_list = [param.data for _ in range(len(workers) + 1)]
                dist.scatter(tensor=tensor,
                             scatter_list=scatter_list,
                             group=group)

            print('Done {}/{}!'.format(batch_idx, len(tmp)))
        print('Done Epoch {}/{}!'.format(epoch + 1, args.epochs))

    end_time = datetime.datetime.now()
    # Evaluate the accuracy of the PS-side model
    h, remainder = divmod((end_time - pre_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time %02d:%02d:%02d" % (h, m, s)

    test_dataset = datasets.CIFAR10(args.data_dir,
                                    train=False,
                                    download=False,
                                    transform=transform)
    criterion = torch.nn.CrossEntropyLoss()
    test_data = DataLoader(test_dataset, batch_size=128, shuffle=True)

    test_loss, acc = test_model(dist.get_rank(),
                                model,
                                test_data,
                                criterion=criterion)
    print('total time ' + str(time_str))
    f = open('./result_' + str(rank) + '_' + args.model + '.txt', 'a')
    f.write('Rank: ' + str(rank) + ', \tEpoch: ' + str(args.epochs) +
            ', \tTestLoss: ' + str(test_loss) + ', \tTestAcc: ' + str(acc) +
            ', \tTotalTime: ' + str(time_str) + '\n')
    f.close()
Example 27
def run(rank, workers, model, save_path, train_data, test_data):
    # Receive the initial model parameters sent from the PS (rank 0)
    _group = [w for w in workers] + [0]
    group = dist.new_group(_group)

    for p in model.parameters():
        tmp_p = torch.zeros_like(p)
        dist.scatter(tensor=tmp_p, src=0, group=group)
        p.data = tmp_p
    print('Model recved successfully!')

    optimizer = MySGD(model.parameters(), lr=0.01, momentum=0.5)
    criterion = torch.nn.CrossEntropyLoss()
    print('Begin!')

    for epoch in range(args.epochs):
        pre_time = datetime.datetime.now()
        model.train()

        # For AlexNet, decay the learning rate (LR) at the specified epochs
        if args.model == 'AlexNet':
            if epoch + 1 in [40, 60]:
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= 0.1
                    print('LR Decreased! Now: {}'.format(param_group['lr']))

        epoch_train_loss = 0
        for batch_idx, (data, target) in enumerate(train_data):
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()

            delta_ws = optimizer.get_delta_w()
            # Synchronization with the parameter server
            for idx, param in enumerate(model.parameters()):
                dist.gather(tensor=delta_ws[idx], dst=0, group=group)
                recv = torch.zeros_like(delta_ws[idx])
                dist.scatter(tensor=recv, src=0, group=group)
                param.data = recv

            epoch_train_loss += loss.data.item()
            print('Rank {}, Epoch {}, Batch {}/{}, Loss:{}'.format(
                rank, epoch, batch_idx, len(train_data), loss.data.item()))

        end_time = datetime.datetime.now()
        h, remainder = divmod((end_time - pre_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)

        epoch_train_loss /= len(train_data)
        epoch_train_loss = format(epoch_train_loss, '.4f')

        # Run the test after training
        test_loss, acc = test_model(rank,
                                    model,
                                    test_data,
                                    criterion=criterion)
        print('total time ' + str(time_str))
        f = open('./result_' + str(rank) + '_' + args.model + '.txt', 'a')
        f.write('Rank: ' + str(rank) + ', \tEpoch: ' + str(epoch + 1) +
                ', \tTrainLoss: ' + str(epoch_train_loss) + ', \tTestLoss: ' +
                str(test_loss) + ', \tTestAcc: ' + str(acc) + ', \tTime: ' +
                str(time_str) + '\n')
        f.close()

        if (epoch + 1) % 5 == 0:
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            torch.save(
                model.state_dict(),
                save_path + '/' + args.model + '_' + str(epoch + 1) + '.pkl')
Example 28
def main_func(numProcesses,group,src_tensor,maxlen,model):
	#exit('Exiting 1')
	starting_time = time.time()
	#POSSIBLE PROBLEM:
	#Some of the processes may finish before others since they need fewer calls if they happen to keep hitting EOS in their sims
	#now continually gather then scatter until we see that
	#the results from gather
	
	while(True):
		t = torch.zeros(maxlen+1) #THE FINAL ELEMENT IS LENGTH WHEN NOT PADDED
		gather_t = [torch.ones_like(t) for _ in range(numProcesses)]
		
		#every process in group sends tensor to this gather_t list
		dist.gather(tensor=t,gather_list=gather_t,dst=0,group=group)
		
		#print('GATHERED DATA')
		#print(gather_t[1][:15])
		#print(gather_t[2][:15])

		#trim them down to the maximum length of all gathered tensors once the padding is removed
		#don't use first element in list since it's from this process
		dec_lengths = torch.tensor([x[-1] for x in gather_t[1:]]).long()
		#print('Dec_lengths: ',dec_lengths)
		#assert(min(dec_lengths) > 0)
		max_gathered_len = int(dec_lengths.max().item())
		if max_gathered_len == 0:
			print('max gathered len is 0')
			#for some reason the last scatter is not seen by the other processes, so the best
			#way to shut them all down is to throw an exception, which returns
			#control to our main function.
			exit(1)

		#print('max gathered len:',max_gathered_len)
		#print('gather_t: ',gather_t)
		#TO DO: 
		#ONCE THIS IS WORKING: don't send blanks through, can filter

		dec_input = torch.cat([x[:max_gathered_len].view(-1,1) for x in gather_t[1:]],1).long()
		
		#print('dec_input: ',dec_input)
		dec_lengths[dec_lengths==0]+=1 #allows function to work for trees that are finished
		#mask for decoder_input happens within this function
		log_probs, values = model.forward(src_tensor,dec_input,
										sentence_lens=dec_lengths,req_grad=False)

		#need to get top model.num_children probs and their corresponding actions
		#which are the indices
		#print('log probs shape: ',log_probs.shape)
		sorted_probs,inds = torch.sort(log_probs,dim=1,descending=True)
		inds = inds.double() #so that concat with probs and values

		#now stack sorted_probs under inds then put value underneath, (this is what 
		#the other processes are expecting as format)
		#print('INDS shape: ',inds.shape)
		#print('values shape: ',values.shape)
		to_scatter = torch.cat([inds[:,:model.num_children].transpose(0,1),
							sorted_probs[:,:model.num_children].transpose(0,1),
							values.unsqueeze(0)], dim=0).to('cpu')
		#print('to_scatter shape: ',to_scatter.shape)
		#print('values: ',values)
		#print(to_scatter[-1,:])
		#print('shape to_scatter: ',to_scatter.shape)
		#print('to scatter type : ',to_scatter.type())
		#print('shape to_scatter: ',to_scatter.shape)
		#now have a tensor which we need to split column wise into lists
		to_scatter = list(np.split(to_scatter,to_scatter.shape[1],axis=1))
		#print('after split: len: {}, first el: {}'.format(len(to_scatter),to_scatter[0]))
		#print('first 50 of to_scatter')
		#print(to_scatter[1][:50])
		#exit(1)

		#need to clone components so that they don't share memory
		to_scatter = [t.clone().squeeze(1) for t in to_scatter]
		#print('len to_scatter: {}, shape to_scatter[0]: {}'.format(len(to_scatter),to_scatter[0].shape))
		#now add fake tensor for this process to start of this list
		to_scatter.insert(0,torch.ones(len(to_scatter[0])).double())
		
		outputTens = torch.ones(len(to_scatter[0])).double()
		
		#SIZE OF EACH TENSOR to scatter is main_params.num_children*2 +1
		#where first part is the actions, then probs, then leaf value
		#print('len to scatter: {}'.format(len(to_scatter)))
		#print('just before scattering: ')
		#print(to_scatter[1].type)
		#print(to_scatter[1][:15])
		#print(to_scatter[2][:15])
		dist.scatter(tensor=outputTens,scatter_list=to_scatter,src=0,group=group)
Example 29
def train_and_validate_federated_drfa(client):
    """The training scheme of Distributionally Robust Federated Learning DRFA.
        paper: https://papers.nips.cc/paper/2020/hash/ac450d10e166657ec8f93a1b65ca1b14-Abstract.html
    """
    log('start training and validation with Federated setting.',
        client.args.debug)

    if client.args.evaluate and client.args.graph.rank == 0:
        # Do the testing on the server and return
        do_validate(client.args,
                    client.model,
                    client.optimizer,
                    client.criterion,
                    client.metrics,
                    client.test_loader,
                    client.all_clients_group,
                    data_mode='test')
        return

    # Initialize the lambda variable proportionate to the clients' sample sizes
    if client.args.graph.rank == 0:
        gather_list_size = [
            torch.tensor(0.0) for _ in range(client.args.graph.n_nodes)
        ]
        dist.gather(torch.tensor(client.args.num_samples_per_epoch,
                                 dtype=torch.float32),
                    gather_list=gather_list_size,
                    dst=0)
        lambda_vector = torch.stack(
            gather_list_size) / client.args.train_dataset_size
    else:
        dist.gather(torch.tensor(client.args.num_samples_per_epoch,
                                 dtype=torch.float32),
                    dst=0)
        lambda_vector = torch.tensor([1 / client.args.graph.n_nodes] *
                                     client.args.graph.n_nodes)

    tracker = define_local_training_tracker()
    start_global_time = time.time()
    tracker['start_load_time'] = time.time()
    log('enter the training.', client.args.debug)

    # Number of communication rounds in federated setting should be defined
    for n_c in range(client.args.num_comms):
        client.args.rounds_comm += 1
        client.args.comm_time.append(0.0)
        # Configuring the devices for this round of communication
        # TODO: not make the server rank hard coded
        log("Starting round {} of training".format(n_c + 1), client.args.debug)
        online_clients = set_online_clients(client.args)
        if n_c == 0:
            # The first round server should be in the communication to initialize its own training
            online_clients = online_clients if 0 in online_clients else online_clients + [
                0
            ]
        online_clients_server = online_clients if 0 in online_clients else online_clients + [
            0
        ]
        online_clients_group = dist.new_group(online_clients_server)
        client.args.drfa_gamma *= 0.9
        if client.args.graph.rank in online_clients_server:
            if client.args.federated_type == 'scaffold':
                st = time.time()
                client.model_server, client.model_server_control = distribute_model_server_control(
                    client.model_server,
                    client.model_server_control,
                    online_clients_group,
                    src=0)
                client.args.comm_time[-1] += time.time() - st
            else:
                st = time.time()
                model_server = distribute_model_server(client.model_server,
                                                       online_clients_group,
                                                       src=0)
                client.args.comm_time[-1] += time.time() - st
            client.model.load_state_dict(client.model_server.state_dict())

            # Send related variables to drfa algorithm
            st = time.time()
            dist.broadcast(client.lambda_vector,
                           src=0,
                           group=online_clients_group)
            # Sending the random number k to all nodes:
            # Does not fully support the epoch mode now
            k = torch.randint(low=1, high=client.args.local_step, size=(1, ))
            dist.broadcast(k, src=0, group=online_clients_group)
            client.args.comm_time[-1] += time.time() - st

            k = k.tolist()[0]
            local_steps = 0
            # Start running updates on local machines
            if client.args.graph.rank in online_clients:
                is_sync = False
                while not is_sync:
                    for _input, _target in client.train_loader:
                        local_steps += 1
                        # Getting the k-th model for dual variable update
                        if k == local_steps:
                            client.kth_model.load_state_dict(
                                client.model.state_dict())
                        client.model.train()

                        # update local step.
                        logging_load_time(tracker)

                        # update local index and get local step
                        client.args.local_index += 1
                        client.args.local_data_seen += len(_target)
                        get_current_epoch(client.args)
                        local_step = get_current_local_step(client.args)

                        # adjust learning rate (based on the # of accessed samples)
                        lr = adjust_learning_rate(client.args,
                                                  client.optimizer,
                                                  client.scheduler)

                        # load data
                        _input, _target = load_data_batch(
                            client.args, _input, _target, tracker)
                        # Skip batches with one sample because of BatchNorm issue in some models!
                        if _input.size(0) == 1:
                            is_sync = is_sync_fed(client.args)
                            break

                        # inference and get current performance.
                        client.optimizer.zero_grad()
                        loss, performance = inference(client.model,
                                                      client.criterion,
                                                      client.metrics, _input,
                                                      _target)

                        # compute gradient and do local SGD step.
                        loss.backward()

                        if client.args.federated_type == 'fedgate':
                            for client_param, delta_param in zip(
                                    client.model.parameters(),
                                    client.model_delta.parameters()):
                                client_param.grad.data -= delta_param.data
                        elif client.args.federated_type == 'scaffold':
                            for cp, ccp, scp in zip(
                                    client.model.parameters(),
                                    client.model_client_control.parameters(),
                                    client.model_server_control.parameters()):
                                cp.grad.data += scp.data - ccp.data

                        client.optimizer.step(
                            apply_lr=True,
                            apply_in_momentum=client.args.in_momentum,
                            apply_out_momentum=False)

                        # logging locally.
                        # logging_computing(tracker, loss, performance, _input, lr)

                        # display the logging info.
                        # logging_display_training(client.args, tracker)

                        # reset load time for the tracker.
                        tracker['start_load_time'] = time.time()
                        is_sync = is_sync_fed(client.args)
                        if is_sync:
                            break
            else:
                log("Offline in this round. Waiting on others to finish!",
                    client.args.debug)

        # Validate the local models before sync
            do_validate(client.args,
                        client.model,
                        client.optimizer,
                        client.criterion,
                        client.metrics,
                        client.train_loader,
                        online_clients_group,
                        data_mode='train',
                        local=True)
            if client.args.fed_personal:
                do_validate(client.args,
                            client.model,
                            client.optimizer,
                            client.criterion,
                            client.metrics,
                            client.val_loader,
                            online_clients_group,
                            data_mode='validation',
                            local=True)
            # Sync the model server based on model_clients
            log('Enter synching', client.args.debug)
            tracker['start_sync_time'] = time.time()
            client.args.global_index += 1

            if client.args.federated_type == 'fedgate':
                client.model_server, client.model_delta = fedgate_aggregation(
                    client.args,
                    client.model_server,
                    client.model,
                    client.model_delta,
                    client.model_memory,
                    online_clients_group,
                    online_clients,
                    client.optimizer,
                    lr,
                    local_steps,
                    lambda_weight=client.lambda_vector[
                        client.args.graph.rank].item())
            elif client.args.federated_type == 'scaffold':
                client.model_server, client.model_client_control, client.model_server_control = scaffold_aggregation(
                    client.args,
                    client.model_server,
                    client.model,
                    client.model_server_control,
                    client.model_client_control,
                    online_clients_group,
                    online_clients,
                    client.optimizer,
                    lr,
                    local_steps,
                    lambda_weight=client.lambda_vector[
                        client.args.graph.rank].item())
            else:
                client.model_server = fedavg_aggregation(
                    client.args,
                    client.model_server,
                    client.model,
                    online_clients_group,
                    online_clients,
                    client.optimizer,
                    lambda_weight=client.lambda_vector[
                        client.args.graph.rank].item())
            # Average the kth_model
            client.kth_model = aggregate_models_virtual(
                client.args, client.kth_model, online_clients_group,
                online_clients)
            # evaluate the sync time
            logging_sync_time(tracker)

            # Do the validation on the server model
            do_validate(client.args,
                        client.model_server,
                        client.optimizer,
                        client.criterion,
                        client.metrics,
                        client.train_loader,
                        online_clients_group,
                        data_mode='train')
            if client.args.fed_personal:
                do_validate(client.args,
                            client.model_server,
                            client.optimizer,
                            client.criterion,
                            client.metrics,
                            client.val_loader,
                            online_clients_group,
                            data_mode='validation')

            # validate the model at the server
            if client.args.graph.rank == 0:
                do_validate(client.args,
                            client.model_server,
                            client.optimizer,
                            client.criterion,
                            client.metrics,
                            client.test_loader,
                            online_clients_group,
                            data_mode='test')

        else:
            log("Offline in this round. Waiting on others to finish!",
                client.args.debug)

        # Update lambda parameters
        online_clients_lambda = set_online_clients(client.args)
        online_clients_server_lambda = online_clients_lambda if 0 in online_clients_lambda else [
            0
        ] + online_clients_lambda
        online_clients_group_lambda = dist.new_group(
            online_clients_server_lambda)

        if client.args.graph.rank in online_clients_server_lambda:
            st = time.time()
            client.kth_model = distribute_model_server(
                client.kth_model, online_clients_group_lambda, src=0)
            client.args.comm_time[-1] += time.time() - st
            loss = torch.tensor(0.0)

            if client.args.graph.rank in online_clients_lambda:
                for _input, _target in client.train_loader:
                    _input, _target = load_data_batch(client.args, _input,
                                                      _target, tracker)
                    # Skip batches with one sample because of BatchNorm issue in some models!
                    if _input.size(0) == 1:
                        break
                    loss, _ = inference(client.kth_model, client.criterion,
                                        client.metrics, _input, _target)
                    break
            loss_tensor_online = loss_gather(
                client.args,
                torch.tensor(loss.item()),
                group=online_clients_group_lambda,
                online_clients=online_clients_lambda)
            if client.args.graph.rank == 0:
                num_online_clients = (len(online_clients_lambda)
                                      if 0 in online_clients_lambda
                                      else len(online_clients_lambda) + 1)
                loss_tensor = torch.zeros(client.args.graph.n_nodes)
                loss_tensor[sorted(online_clients_server_lambda)] = (
                    loss_tensor_online *
                    (client.args.graph.n_nodes / num_online_clients))
                # Dual update
                client.lambda_vector += client.args.drfa_gamma * client.args.local_step * loss_tensor
                client.lambda_vector = euclidean_proj_simplex(
                    client.lambda_vector)

                # Avoid zero probability
                lambda_zeros = client.lambda_vector <= 1e-3
                if lambda_zeros.sum() > 0:
                    client.lambda_vector[lambda_zeros] = 1e-3
                    client.lambda_vector /= client.lambda_vector.sum()

        # logging.
        logging_globally(tracker, start_global_time)

        # reset start round time.
        start_global_time = time.time()
        log(
            'This round communication time is: {}'.format(
                client.args.comm_time[-1]), client.args.debug)
        dist.barrier(group=client.all_clients_group)
    return
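
Note: the dual update above keeps the mixing weights on the probability simplex via euclidean_proj_simplex. Below is a minimal sketch of such a projection, assuming the standard sorting-based algorithm (Duchi et al., 2008); the helper name and the exact behaviour of the project's own function are assumptions, not taken from this snippet.

import torch

def euclidean_proj_simplex_sketch(v: torch.Tensor, s: float = 1.0) -> torch.Tensor:
    """Project a 1-D tensor v onto the simplex {x : x >= 0, sum(x) = s}."""
    # Sort in descending order and build the shifted running sums.
    u, _ = torch.sort(v, descending=True)
    cssv = torch.cumsum(u, dim=0) - s
    ind = torch.arange(1, v.numel() + 1, dtype=v.dtype, device=v.device)
    # rho is the largest (1-based) index with u_i > (cumsum_i - s) / i.
    cond = u - cssv / ind > 0
    rho = ind[cond][-1]
    theta = cssv[cond][-1] / rho
    # Shift every coordinate by theta and clip at zero.
    return torch.clamp(v - theta, min=0)
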
Example no. 30
def train_and_validate_federated_afl(client):
    """The training scheme of Federated Learning systems.
        This is the implementation of Agnostic Federated Learning
        https://arxiv.org/abs/1902.00146
    """
    log('start training and validation with Federated setting.',
        client.args.debug)

    if client.args.evaluate and client.args.graph.rank == 0:
        # Do the testing on the server and return
        do_validate(client.args,
                    client.model,
                    client.optimizer,
                    client.criterion,
                    client.metrics,
                    client.test_loader,
                    client.all_clients_group,
                    data_mode='test')
        return

    # Initialize the lambda variable proportional to each client's sample size
    if client.args.graph.rank == 0:
        gather_list_size = [
            torch.tensor(0.0) for _ in range(client.args.graph.n_nodes)
        ]
        dist.gather(torch.tensor(client.args.num_samples_per_epoch,
                                 dtype=torch.float32),
                    gather_list=gather_list_size,
                    dst=0)
        client.lambda_vector = torch.stack(
            gather_list_size) / client.args.train_dataset_size
    else:
        dist.gather(torch.tensor(client.args.num_samples_per_epoch,
                                 dtype=torch.float32),
                    dst=0)
        client.lambda_vector = torch.tensor([1 / client.args.graph.n_nodes] *
                                            client.args.graph.n_nodes)

    tracker = define_local_training_tracker()
    start_global_time = time.time()
    tracker['start_load_time'] = time.time()
    log('enter the training.', client.args.debug)

    # The number of communication rounds in the federated setting must be defined
    for n_c in range(client.args.num_comms):
        client.args.rounds_comm += 1
        client.args.comm_time.append(0.0)
        # Configuring the devices for this round of communication
        # TODO: do not hard-code the server rank
        log("Starting round {} of training".format(n_c + 1), client.args.debug)
        online_clients = set_online_clients(client.args)
        if n_c == 0:
            # In the first round the server must take part in the communication to initialize its own training
            online_clients = (online_clients if 0 in online_clients
                              else online_clients + [0])
        online_clients_server = (online_clients if 0 in online_clients
                                 else online_clients + [0])
        online_clients_group = dist.new_group(online_clients_server)

        if client.args.graph.rank in online_clients_server:
            st = time.time()
            client.model_server = distribute_model_server(client.model_server,
                                                          online_clients_group,
                                                          src=0)
            dist.broadcast(client.lambda_vector,
                           src=0,
                           group=online_clients_group)
            client.args.comm_time[-1] += time.time() - st
            client.model.load_state_dict(client.model_server.state_dict())

            # This loss tensor covers ranks in the group that do not run local training this round
            loss = torch.tensor(0.0)
            # Start running updates on local machines
            if client.args.graph.rank in online_clients:
                is_sync = False
                while not is_sync:
                    for _input, _target in client.train_loader:

                        client.model.train()
                        # update local step.
                        logging_load_time(tracker)
                        # update local index and get local step
                        client.args.local_index += 1
                        client.args.local_data_seen += len(_target)
                        get_current_epoch(client.args)
                        local_step = get_current_local_step(client.args)

                        # adjust learning rate (based on the # of accessed samples)
                        lr = adjust_learning_rate(client.args,
                                                  client.optimizer,
                                                  client.scheduler)

                        # load data
                        _input, _target = load_data_batch(
                            client.args, _input, _target, tracker)

                        # Skip batches with one sample because of BatchNorm issue in some models!
                        if _input.size(0) == 1:
                            is_sync = is_sync_fed(client.args)
                            break

                        # inference and get current performance.
                        client.optimizer.zero_grad()
                        loss, performance = inference(client.model,
                                                      client.criterion,
                                                      client.metrics, _input,
                                                      _target)
                        # compute gradient and do local SGD step.
                        loss.backward()
                        client.optimizer.step(
                            apply_lr=True,
                            apply_in_momentum=client.args.in_momentum,
                            apply_out_momentum=False)

                        # logging locally.
                        # logging_computing(tracker, loss, performance, _input, lr)

                        # display the logging info.
                        # logging_display_training(args, tracker)

                        # reset load time for the tracker.
                        tracker['start_load_time'] = time.time()
                        is_sync = is_sync_fed(client.args)
                        if is_sync:
                            break
            else:
                log("Offline in this round. Waiting on others to finish!",
                    client.args.debug)

            # Validate the local models before syncing
            do_validate(client.args,
                        client.model,
                        client.optimizer,
                        client.criterion,
                        client.metrics,
                        client.train_loader,
                        online_clients_group,
                        data_mode='train',
                        local=True)
            if client.args.fed_personal:
                do_validate(client.args,
                            client.model,
                            client.optimizer,
                            client.criterion,
                            client.metrics,
                            client.val_loader,
                            online_clients_group,
                            data_mode='validation',
                            local=True)
            # Sync the model server based on client models
            log('Enter synching', client.args.debug)
            tracker['start_sync_time'] = time.time()
            client.args.global_index += 1

            client.model_server, loss_tensor_online = afl_aggregation(
                client.args, client.model_server, client.model,
                client.lambda_vector[client.args.graph.rank].item(),
                torch.tensor(loss.item()), online_clients_group,
                online_clients, client.optimizer)

            # record the synchronization time
            logging_sync_time(tracker)
            # Do the validation on the server model
            do_validate(client.args,
                        client.model_server,
                        client.optimizer,
                        client.criterion,
                        client.metrics,
                        client.train_loader,
                        online_clients_group,
                        data_mode='train')
            if client.args.fed_personal:
                do_validate(client.args,
                            client.model_server,
                            client.optimizer,
                            client.criterion,
                            client.metrics,
                            client.val_loader,
                            online_clients_group,
                            data_mode='validation')

            # Updating lambda variable
            if client.args.graph.rank == 0:
                num_online_clients = (len(online_clients)
                                      if 0 in online_clients
                                      else len(online_clients) + 1)
                loss_tensor = torch.zeros(client.args.graph.n_nodes)
                loss_tensor[sorted(online_clients_server)] = loss_tensor_online
                # Dual update
                client.lambda_vector += client.args.drfa_gamma * loss_tensor
                # Projection onto the probability simplex
                client.lambda_vector = euclidean_proj_simplex(
                    client.lambda_vector)
                # Avoid zero probability
                lambda_zeros = client.lambda_vector <= 1e-3
                if lambda_zeros.sum() > 0:
                    client.lambda_vector[lambda_zeros] = 1e-3
                    client.lambda_vector /= client.lambda_vector.sum()

            # logging.
            logging_globally(tracker, start_global_time)

            # reset start round time.
            start_global_time = time.time()
            # validate the model at the server
            if client.args.graph.rank == 0:
                do_validate(client.args,
                            client.model_server,
                            client.optimizer,
                            client.criterion,
                            client.metrics,
                            client.test_loader,
                            online_clients_group,
                            data_mode='test')
            log(
                'This round communication time is: {}'.format(
                    client.args.comm_time[-1]), client.args.debug)
        else:
            log("Offline in this round. Waiting on others to finish!",
                client.args.debug)
        dist.barrier(group=client.all_clients_group)

    return
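
Note: the lambda initialization above is a plain scalar gather, where every rank sends its sample count and rank 0 assembles the list. A self-contained sketch of that pattern is given below (the helper name is hypothetical; it assumes a gloo/CPU process group has already been initialized):

import torch
import torch.distributed as dist

def gather_scalar_to_root(value: float, dst: int = 0):
    """Gather one float per rank onto rank `dst`; other ranks receive None."""
    tensor = torch.tensor(value, dtype=torch.float32)
    if dist.get_rank() == dst:
        gather_list = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
        dist.gather(tensor, gather_list=gather_list, dst=dst)
        return torch.stack(gather_list)  # shape: (world_size,)
    dist.gather(tensor, dst=dst)
    return None

# On rank 0 this yields the per-client sample counts, which can then be
# normalized by the total dataset size to obtain the initial lambda vector.
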
def prune_and_eval(rank, size, orig_fit, acc_constraint, valid, es, ref_model,
                   num_runs, final_results):
    _valid = valid
    gpu_id = GPU_ID
    total_iterations = es.Tmax / es.popsize
    individual_iter_count = 0
    #ref_model = masked_models[rank]
    X = torch.Tensor(copy.deepcopy(es.pop))
    communicate_size = es.n + 4  # the size of the tensors transferred across machines
    communicate_tensor = torch.FloatTensor(communicate_size * [0.])
    fitness_list = []
    itr_best_remain = 0

    if rank == 0:  # rank 0 is the main process that collects fitnesses
        X.share_memory_()
        #fitness_list = [torch.FloatTensor([0.0,0.1,0.2,0.3]).share_memory_() for i in range(size)]
        fitness_list = [
            torch.FloatTensor(communicate_size * [0.]).share_memory_()
            for i in range(size)
        ]

    if rank >= 1 and rank < size:  # split tasks to different GPUs
        gpu_id = other_GPU_IDs[rank - 1]

    with cuda.device(gpu_id):
        local_fields = onmt.IO.load_fields(torch.load(TRAIN_DATA +
                                                      '.vocab.pt'))
        _valid.fields = local_fields  # fields cannot be packed, so reconstruct them in each thread

        while (individual_iter_count < total_iterations):
            if rank == 0:  # master node
                itr_X = torch.Tensor(es.ask())
                # broadcast the fathers
                X.copy_(itr_X)
                dist.broadcast(itr_X, 0)
            else:
                # receive the fathers from the source process
                dist.broadcast(X, 0)

            # apply MP on model
            x = X.numpy()[rank]
            ref_model.change_mask(x, apply_MP_on_mask)

            ref_model.apply_mask()

            # evaluate pruned network
            fitness = evaluate(ref_model, _valid, local_fields)
            communicate_tensor[0] = fitness[0]
            communicate_tensor[1] = fitness[1]
            communicate_tensor[2] = rank
            communicate_tensor[3] = ref_model.get_sparsity()
            for i in range(x.size):
                communicate_tensor[i + 4] = X[rank, i]  #x[i]

            # sync fitness
            if rank == 0:  # collect fitness across processes
                dist.gather(communicate_tensor, gather_list=fitness_list)
            else:
                dist.gather(communicate_tensor, dst=0)

            # judge new solutions
            if rank == 0:  # negatively correlated search in master node
                fit = []
                X_ = []
                for i in range(es.popsize):
                    the_fitness = 100
                    for j in range(len(fitness_list)):  # results of fitness evaluation
                        if int(fitness_list[j][2]) == i:  # 0: ppl, 1: acc, 2: rank of individual
                            X_.append(fitness_list[j].numpy()[4:])
                            if orig_fit[1] - fitness_list[j][1] <= acc_constraint:
                                the_fitness = -fitness_list[j][3]
                            else:
                                the_fitness = ((orig_fit[1] - fitness_list[j][1])
                                               / acc_constraint)
                            continue
                    fit.append(the_fitness)

                es.tell(X_, fit)

                itr_best_remain = min(fit)

            final_results['result_NCS'].copy_(torch.Tensor(es.result()[0]))
            individual_iter_count += 1

            if rank == 0:  # record status
                logger.scalar_summary(
                    'ncs_%s_fitness' % num_runs,
                    es.result()[1],
                    num_runs * total_iterations + individual_iter_count)
                logger.scalar_summary(
                    'ncs_%s_best_itr_remain' % num_runs, itr_best_remain,
                    num_runs * total_iterations + individual_iter_count)
                logger.histo_summary(
                    'ncs_%s_pop' % num_runs,
                    es.result()[0],
                    num_runs * total_iterations + individual_iter_count)
                logger.histo_summary(
                    'pop of 1', X_[0],
                    num_runs * total_iterations + individual_iter_count)
                logger.scalar_summary(
                    'sp of 1', -fitness_list[0][3],
                    num_runs * total_iterations + individual_iter_count)
                logger.scalar_summary(
                    'rank of 1', fitness_list[0][2],
                    num_runs * total_iterations + individual_iter_count)
                logger.histo_summary(
                    'pop of 2', X_[1],
                    num_runs * total_iterations + individual_iter_count)
                logger.scalar_summary(
                    'sp of 2', -fitness_list[1][3],
                    num_runs * total_iterations + individual_iter_count)
                logger.scalar_summary(
                    'rank of 2', fitness_list[1][2],
                    num_runs * total_iterations + individual_iter_count)
                #logger.histo_summary('pop of 3', X_[2], num_runs*total_iterations + individual_iter_count)
                #logger.scalar_summary('sp of 3', -fitness_list[2][3], num_runs*total_iterations + individual_iter_count)
                #logger.scalar_summary('rank of 3', fitness_list[2][2], num_runs*total_iterations + individual_iter_count)

    ref_model.clear_cache()
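
Note: communicate_tensor above flattens one worker's evaluation into a fixed layout (ppl, accuracy, rank, sparsity, then the es.n genome entries) so a single dist.gather call can collect everything. A small pack/unpack pair under that assumed layout is sketched below (helper names are hypothetical):

import torch

def pack_result(ppl, acc, rank, sparsity, genome):
    """Flatten one worker's evaluation into the fixed-size gather tensor."""
    header = torch.tensor([ppl, acc, float(rank), sparsity], dtype=torch.float32)
    return torch.cat([header, torch.as_tensor(genome, dtype=torch.float32)])

def unpack_result(t):
    """Inverse of pack_result: returns (ppl, acc, rank, sparsity, genome)."""
    ppl, acc, rank, sparsity = t[:4].tolist()
    return ppl, acc, int(rank), sparsity, t[4:].clone()
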
Example no. 32
# ---------------- ALL_REDUCE -----------------
if True:
    tensor = new_tensor(device_id, local_value)
    dist.all_reduce(tensor, op=dist.reduce_op.SUM)

    # with local values 1..4 across 4 ranks, every rank holds the sum 10.0
    print('{} AFTER all_reduce {}'.format(local_rank, tensor))
    assert_mean(tensor, 10.)

# ---------------- GATHER -----------------
if backend in ['tcp']:
    tensor = new_tensor(device_id, local_value)
    if local_rank == 0:
        gather_list = [new_tensor(device_id, 0.) for _ in range(4)]
        dist.gather(tensor, dst=0, gather_list=gather_list)
    else:
        dist.gather(tensor, dst=0)
        gather_list = None

    # rank 0 gathers the per-rank values [1., 2., 3., 4.]
    if local_rank == 0:
        print('{} AFTER gather {}'.format(local_rank, gather_list))
        assert_mean(gather_list[0], 1.)
        assert_mean(gather_list[1], 2.)
        assert_mean(gather_list[2], 3.)
        assert_mean(gather_list[3], 4.)

# ---------------- ALL_GATHER -----------------
if backend in ['tcp', 'nccl']:
    tensor = new_tensor(device_id, local_value)