Example #1
    def test_broadcast_coalesced(self):
        numel = 5
        num_bytes = numel * 8
        tensors = [
            torch.randn(numel).long().cuda(),
            torch.randn(numel).cuda(),
            torch.randn(numel).long().cuda(),
            torch.randn(numel).long().cuda(),
            torch.randn(numel * 2).int().cuda(),  # int is half the size, so 2x the elements fit in the same bytes
            torch.randn(numel).cuda(),
        ]

        b_tensors = [comm.broadcast(t, (0, 1)) for t in tensors]
        for (_, bt), t in zip(b_tensors, tensors):
            self.assertEqual(bt.get_device(), 1)
            self.assertEqual(bt, t)
            self.assertIsInstance(bt, type(t))

        bc_tensors = comm.broadcast_coalesced(tensors, (0, 1),
                                              buffer_size=num_bytes * 5 // 2)
        bc_tensors_t = list(zip(*bc_tensors))
        self.assertEqual(b_tensors, bc_tensors_t)
        for (_, bt), (_, bct) in zip(b_tensors, bc_tensors_t):
            self.assertEqual(bt.get_device(), bct.get_device())
            self.assertIsInstance(bct, type(bt))
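
A standalone sketch of the two calls exercised above, assuming comm is torch.cuda.comm and at least two CUDA devices are visible (the tensor sizes and buffer_size are illustrative):

import torch
from torch.cuda import comm

# Illustrative tensors; any mix of dtypes and sizes works.
tensors = [torch.randn(5).cuda(0), torch.randn(5).long().cuda(0)]

# broadcast: copy a single tensor to every listed device.
copies = comm.broadcast(tensors[0], (0, 1))  # one copy per device id

# broadcast_coalesced: pack the tensors into buffers of at most
# buffer_size bytes, broadcast the buffers, then unpack on each device.
per_device = comm.broadcast_coalesced(tensors, (0, 1), buffer_size=128)
assert len(per_device) == 2                  # one tuple per device
assert len(per_device[0]) == len(tensors)    # same tensors on each device
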
Example #2
    def _test_broadcast(self, input):
        if torch.cuda.device_count() < 2:
            raise unittest.SkipTest("only one GPU detected")
        result = comm.broadcast(input, (0, 1))
        for i, t in enumerate(result):
            self.assertEqual(t.get_device(), i)
            self.assertEqual(t, input)
Example #3
def data_parallel(f,
                  input,
                  params,
                  stats,
                  mode,
                  device_ids,
                  output_device=None):
    if output_device is None:
        output_device = device_ids[0]

    if len(device_ids) == 1:
        return f(input, params, stats, mode)

    def replicate(param_dict, g):
        replicas = [{} for d in device_ids]
        for k, v in param_dict.items():
            for i, u in enumerate(g(v)):
                replicas[i][k] = u
        return replicas

    params_replicas = replicate(params, lambda x: Broadcast(device_ids)(x))
    stats_replicas = replicate(stats, lambda x: comm.broadcast(x, device_ids))

    replicas = [
        lambda x, p=p, s=s, mode=mode: f(x, p, s, mode)
        for i, (p, s) in enumerate(zip(params_replicas, stats_replicas))
    ]
    inputs = scatter(input, device_ids)
    outputs = parallel_apply(replicas, inputs)
    return gather(outputs, output_device)
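
A hedged usage sketch for the functional data_parallel above, assuming the helpers it references (Broadcast, scatter, parallel_apply, gather) are already imported from torch.nn.parallel as in the original project; the toy network f, its params/stats dictionaries, and the two-device assumption are illustrative:

import torch
import torch.nn.functional as F

def f(x, params, stats, mode):
    # toy functional network: a single linear layer, no running stats
    return F.linear(x, params['weight'], params['bias'])

params = {
    'weight': torch.randn(10, 20, device='cuda:0', requires_grad=True),
    'bias': torch.zeros(10, device='cuda:0', requires_grad=True),
}
stats = {}  # this toy f keeps no running statistics

x = torch.randn(32, 20, device='cuda:0')
y = data_parallel(f, x, params, stats, mode=True, device_ids=[0, 1])
print(y.shape)  # torch.Size([32, 10]), gathered on device_ids[0]
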
Example #4
def data_parallel(f, input, params, stats, mode, device_ids, output_device=None):
    if output_device is None:
        output_device = device_ids[0]

    if len(device_ids) == 1:  # only one device
        return f(input, params, stats, mode)

    # helper function defined inside data_parallel
    def replicate(param_dict, g):
        replicas = [{} for d in device_ids]  # one dict per device
        for k, v in param_dict.items():  # v is a parameter tensor
            for i, u in enumerate(g(v)):
                replicas[i][k] = u
        return replicas
    
    # broadcast parameters 
    params_replicas = replicate(params, lambda x: Broadcast(device_ids)(x))
    # broadcast stats 
    stats_replicas = replicate(stats, lambda x: comm.broadcast(x, device_ids))

    replicas = [lambda x, p=p, s=s, mode=mode: f(x, p, s, mode)
                for i, (p, s) in enumerate(zip(params_replicas, stats_replicas))]

    inputs = scatter(input, device_ids)

    outputs = parallel_apply(replicas, inputs)

    return gather(outputs, output_device)
Example #5
    def _test_broadcast_coalesced(self, tensors, buffer_size):
        b_tensors = [comm.broadcast(t, (0, 1)) for t in tensors]
        for (_, bt), t in zip(b_tensors, tensors):
            self.assertEqual(bt.get_device(), 1)
            self.assertEqual(bt, t)
            self.assertIsInstance(bt, type(t))

        bc_tensors = comm.broadcast_coalesced(tensors, (0, 1), buffer_size=buffer_size)
        bc_tensors_t = list(zip(*bc_tensors))
        self.assertEqual(b_tensors, bc_tensors_t)
        for (_, bt), (_, bct) in zip(b_tensors, bc_tensors_t):
            self.assertEqual(bt.get_device(), bct.get_device())
            self.assertIsInstance(bct, type(bt))
Example #6
def replicate(module, device_ids):
    from ._functions import Broadcast
    seen_params = set()
    param_remap = [{} for dev_id in device_ids]
    for param in module.parameters():
        if param in seen_params:
            continue
        seen_params.add(param)
        param_copies = Broadcast(device_ids)(param)
        for param_copy, remap in zip(param_copies, param_remap):
            remap[param] = param_copy
    for m in module.modules():
        for buffer in m._buffers.values():
            copies = comm.broadcast(buffer, device_ids)
            for buf_copy, remap in zip(copies, param_remap):
                remap[buffer] = buf_copy
    return [
        _replicate_module(module, device_id, remap)
        for device_id, remap in zip(device_ids, param_remap)
    ]
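
The same pattern is exposed publicly as torch.nn.parallel.replicate; a minimal usage sketch, assuming at least two CUDA devices:

import torch
import torch.nn as nn
from torch.nn.parallel import replicate

model = nn.Linear(20, 10).cuda(0)
replicas = replicate(model, [0, 1])  # one independent copy per device
print([next(r.parameters()).device for r in replicas])  # cuda:0, cuda:1
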
Example #7
    def forward(self, input):
        assert input.is_cuda, "Broadcast function not implemented for CPU tensors"
        self.input_device = input.get_device()
        return comm.broadcast(input, self.target_gpus)
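
Outside of an autograd Function, the same copy can be made by calling torch.cuda.comm.broadcast directly; a minimal sketch with assumed device ids:

import torch
from torch.cuda import comm

x = torch.randn(4, 4, device='cuda:0')
copies = comm.broadcast(x, (0, 1))  # tuple with one copy of x per device id
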
Example #8
    def soft_argmax_1d(self, heatmap1d):
        heatmap1d = F.softmax(heatmap1d, 1)  # (bs, 64)
        # build the index vector on the same device as the heatmap
        idx = comm.broadcast(
            torch.arange(self.output_root_hm_shape).type(torch.cuda.FloatTensor),
            devices=[heatmap1d.device.index])[0]
        accu = heatmap1d * idx  # (bs, 64)
        coord = accu.sum(dim=1)  # (bs,)
        return coord
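
In soft_argmax_1d the broadcast only serves to place the index vector on the heatmap's device; a hedged, functionally equivalent sketch (shape comments taken from the original) is:

import torch
import torch.nn.functional as F

def soft_argmax_1d(heatmap1d):               # heatmap1d: (bs, 64)
    heatmap1d = F.softmax(heatmap1d, dim=1)
    idx = torch.arange(heatmap1d.size(1), dtype=heatmap1d.dtype,
                       device=heatmap1d.device)
    return (heatmap1d * idx).sum(dim=1)      # soft-argmax coordinate, (bs,)
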
def finetune(novel_loader,
             n_query=15,
             pretrained_dataset='miniImageNet',
             freeze_backbone=False,
             n_way=5,
             n_support=5):

    iter_num = len(novel_loader)

    val_accs = AverageMeter()
    es = AverageMeter()

    models_score = []

    for task_i, (x, y) in enumerate(novel_loader):

        ###############################################################################################
        # load pretrained model on miniImageNet

        if task_i == 0:
            state_list = []
            # Detect grayscale input: undo the ImageNet normalization on
            # channels 0 and 1 and check that they match.
            imgisgray = True if torch.abs(
                torch.mean(x[0, 0, 0, :, :] * 0.229 -
                           x[0, 0, 1, :, :] * 0.224 + 0.485 -
                           0.456)) < 1e-5 else False
            for model_name in params.model_list:
                checkpoint_dir = '%s/%s%s_%s.tar' % (
                    params.model_dir, model_name,
                    'G50' if imgisgray else '', 399)
                print(checkpoint_dir)
                tmp = torch.load(checkpoint_dir)
                state = tmp['state']
                state_keys = list(state.keys())
                for _, key in enumerate(state_keys):
                    if "feature." in key:
                        newkey = key.replace("feature.", "")
                        state[newkey] = state.pop(key)
                    else:
                        state.pop(key)
                state_list.append(state)

        pretrained_model_list = []
        classifier_list = []
        model_num = len(params.model_list)
        device_list = (list(range(torch.cuda.device_count()))[:model_num]
                       if params.device_list is None else params.device_list)
        # set params.device_list =[0,0,0,0,0,0,0,0] to run all models on cuda:0
        for model_idx, model_name in enumerate(params.model_list):
            pretrained_model = model_dict[model_name]()
            pretrained_model.load_state_dict(
                copy.deepcopy(state_list[model_idx]))
            pretrained_model.to(device_list[model_idx])
            classifier = Classifier(pretrained_model.final_feat_dim, n_way)
            classifier.to(device_list[model_idx])
            pretrained_model_list.append(pretrained_model)
            classifier_list.append(classifier)

        class_p = comm.broadcast(torch.ones(1, n_way) / n_way, device_list)
        model_pool = [
            TrainingModel(mo, cl, de, params.p_coef, params.e_coef, cp, params)
            for (mo, cl, de, cp) in zip(pretrained_model_list, classifier_list,
                                        device_list, class_p)
        ]

        n_query = x.size(1) - n_support
        support_size = n_way * n_support

        y_a_i = torch.from_numpy(
            np.repeat(range(n_way), n_support).astype(np.int64))
        y_a_i_oh = torch.zeros(support_size,
                               n_way).scatter_(1, y_a_i.view(-1, 1), 1)

        x_a_i = x[:, :n_support, :, :, :].contiguous().view(
            n_way * n_support,
            *x.size()[2:])
        x_b_i = x[:, n_support:, :, :, :].contiguous().view(
            n_way * n_query,
            *x.size()[2:])

        x_a_list = comm.broadcast(x_a_i, device_list)
        x_b_list = comm.broadcast(x_b_i, device_list)
        y_a_list = comm.broadcast(y_a_i_oh, device_list)
        total_epoch = params.joint_start_epoch + params.joint_epoch
        for mi, m in enumerate(model_pool):
            m.init_task(x_a_list[mi], y_a_list[mi], x_b_list[mi])

        for epoch in range(total_epoch):
            parallel_apply(model_pool, [[] for _ in range(model_num)])
            if params.mml is not None and epoch >= params.joint_start_epoch - 1 and epoch < total_epoch - 1:
                if params.mml == 'all':
                    cur_score = [
                        m.scores_history[m.epoch - params.use_epoch:m.epoch].to(0)
                        for m in model_pool
                    ]
                    #print(cur_score)
                    mean_score = comm.broadcast(
                        torch.mean(torch.cat(cur_score, dim=0), dim=0),
                        device_list)
                    for mi, m in enumerate(model_pool):
                        m.y = torch.cat((m.y_tr, mean_score[mi]), dim=0)
                else:
                    print('Wrong type !')

        val_acc_m = []
        for mi, m in enumerate(model_pool):
            val_acc_m.append(m.val_acc.cpu().numpy())
        val_acc_m = np.array(val_acc_m)
        val_accs.update(val_acc_m)

        s = []
        for mi, m in enumerate(model_pool):
            s.append(m.scores_history.cpu().numpy())
        s = np.array(s)
        models_score.append(s)

        es.update(avg_ensemble(s))
        print(es.val, es.avg)
        for i in range(model_num):
            print(val_accs.val[i, params.joint_start_epoch - 1],
                  val_accs.val[i, -1],
                  val_accs.avg[i, params.joint_start_epoch - 1],
                  val_accs.avg[i, -1])
        print('#################################################')

    c95 = val_accs.c95()
    for i in range(model_num):
        print('%d Test Acc = %4.2f+%4.2f%%, %4.2f+%4.2f%%, %4.2f+%4.2f%%' %
              (iter_num, val_accs.avg[i, params.joint_start_epoch - 1],
               c95[i, params.joint_start_epoch - 1], val_accs.avg[i, 299],
               c95[i, 299], val_accs.avg[i, -1], c95[i, -1]))
        print(val_accs.avg[i])

    c95 = es.c95()
    print('Ensemble Avg. Acc = %4.2f+%4.2f%%' % (es.avg, c95))

    return np.array(models_score)
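
The finetune loop above leans on one pattern throughout: broadcast a CPU tensor to every device in device_list, then drive one replica per device with parallel_apply. A minimal sketch of that pattern (the modules and device ids are illustrative placeholders):

import torch
import torch.nn as nn
from torch.cuda import comm
from torch.nn.parallel import parallel_apply

device_list = [0, 1]
modules = [nn.Linear(20, 10).cuda(d) for d in device_list]

x = torch.randn(32, 20)                      # CPU tensor, as in finetune
x_list = comm.broadcast(x, device_list)      # one copy per device
outputs = parallel_apply(modules, [(xi,) for xi in x_list],
                         devices=device_list)
print([o.device for o in outputs])           # [cuda:0, cuda:1]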