def test_broadcast_coalesced(self):
    numel = 5
    num_bytes = numel * 8
    tensors = [
        torch.randn(numel).long().cuda(),
        torch.randn(numel).cuda(),
        torch.randn(numel).long().cuda(),
        torch.randn(numel).long().cuda(),
        torch.randn(numel * 2).int().cuda(),  # int is 2x shorter
        torch.randn(numel).cuda(),
    ]
    b_tensors = [comm.broadcast(t, (0, 1)) for t in tensors]
    for (_, bt), t in zip(b_tensors, tensors):
        self.assertEqual(bt.get_device(), 1)
        self.assertEqual(bt, t)
        self.assertIsInstance(bt, type(t))
    bc_tensors = comm.broadcast_coalesced(tensors, (0, 1), buffer_size=num_bytes * 5 // 2)
    bc_tensors_t = list(zip(*bc_tensors))
    self.assertEqual(b_tensors, bc_tensors_t)
    for (_, bt), (_, bct) in zip(b_tensors, bc_tensors_t):
        self.assertEqual(bt.get_device(), bct.get_device())
        self.assertIsInstance(bct, type(bt))
def _test_broadcast(self, input):
    if torch.cuda.device_count() < 2:
        raise unittest.SkipTest("only one GPU detected")
    result = comm.broadcast(input, (0, 1))
    for i, t in enumerate(result):
        self.assertEqual(t.get_device(), i)
        self.assertEqual(t, input)
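For orientation, a minimal standalone sketch of the call the test above exercises, assuming at least two visible CUDA devices (variable names are illustrative):

import torch
from torch.cuda import comm

x = torch.randn(5).cuda(0)          # source tensor on device 0
copies = comm.broadcast(x, (0, 1))  # returns one copy per listed device
for i, t in enumerate(copies):
    assert t.get_device() == i                # each copy lives on its own device
    assert torch.equal(t.cpu(), x.cpu())      # contents are identical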
def data_parallel(f, input, params, stats, mode, device_ids, output_device=None):
    if output_device is None:
        output_device = device_ids[0]
    if len(device_ids) == 1:  # only one device: run f directly
        return f(input, params, stats, mode)

    # helper inside data_parallel: broadcast every tensor in a dict and
    # regroup the copies into one dict per device
    def replicate(param_dict, g):
        replicas = [{} for _ in device_ids]  # list of n_devices dicts
        for k, v in param_dict.items():  # .iteritems() is Python 2 only
            for i, u in enumerate(g(v)):
                replicas[i][k] = u
        return replicas

    # broadcast parameters (differentiable) and stats (plain copies)
    params_replicas = replicate(params, lambda x: Broadcast(device_ids)(x))
    stats_replicas = replicate(stats, lambda x: comm.broadcast(x, device_ids))

    # bind each device's parameter/stat dicts into a callable replica
    replicas = [lambda x, p=p, s=s, mode=mode: f(x, p, s, mode)
                for p, s in zip(params_replicas, stats_replicas)]
    inputs = scatter(input, device_ids)
    outputs = parallel_apply(replicas, inputs)
    return gather(outputs, output_device)
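A hypothetical invocation of the helper above, assuming f is a functional model taking (input, params, stats, mode) and the dicts hold CUDA tensors on device_ids[0]:

# hypothetical example; `f`, `batch`, `params`, and `stats` are placeholders
# for a functional model definition and its tensors (not defined here)
out = data_parallel(f, batch, params, stats, mode='train', device_ids=[0, 1])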
def _test_broadcast_coalesced(self, tensors, buffer_size):
    b_tensors = [comm.broadcast(t, (0, 1)) for t in tensors]
    for (_, bt), t in zip(b_tensors, tensors):
        self.assertEqual(bt.get_device(), 1)
        self.assertEqual(bt, t)
        self.assertIsInstance(bt, type(t))
    bc_tensors = comm.broadcast_coalesced(tensors, (0, 1), buffer_size=buffer_size)
    bc_tensors_t = list(zip(*bc_tensors))
    self.assertEqual(b_tensors, bc_tensors_t)
    for (_, bt), (_, bct) in zip(b_tensors, bc_tensors_t):
        self.assertEqual(bt.get_device(), bct.get_device())
        self.assertIsInstance(bct, type(bt))
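What the helper checks, in isolation: comm.broadcast_coalesced packs many small tensors into buffers of at most buffer_size bytes before broadcasting, and must produce the same per-device copies as broadcasting each tensor separately. A minimal sketch, with device IDs and buffer size chosen for illustration:

import torch
from torch.cuda import comm

tensors = [torch.randn(5).cuda(0), torch.randn(5).long().cuda(0)]
# returns one sequence of copies per device, in the same order as `tensors`
per_device = comm.broadcast_coalesced(tensors, (0, 1), buffer_size=10 * 1024)
assert per_device[1][0].get_device() == 1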
def replicate(module, device_ids):
    from ._functions import Broadcast

    # broadcast each parameter exactly once, even if it is shared
    # between several submodules
    seen_params = set()
    param_remap = [{} for _ in device_ids]
    for param in module.parameters():
        if param in seen_params:
            continue
        seen_params.add(param)
        param_copies = Broadcast(device_ids)(param)
        for param_copy, remap in zip(param_copies, param_remap):
            remap[param] = param_copy

    # buffers are not differentiable, so a plain comm.broadcast suffices
    for m in module.modules():
        for buffer in m._buffers.values():
            copies = comm.broadcast(buffer, device_ids)
            for buf_copy, remap in zip(copies, param_remap):
                remap[buffer] = buf_copy

    return [_replicate_module(module, device_id, remap)
            for device_id, remap in zip(device_ids, param_remap)]
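For context, replicate is one stage of the four-step pipeline that torch.nn.DataParallel runs on each forward pass. A sketch of the full sequence, using the public helpers from torch.nn.parallel (module and batch are placeholders):

from torch.nn.parallel import replicate, scatter, parallel_apply, gather

device_ids = [0, 1]
replicas = replicate(module, device_ids)    # copy the module to every device
inputs = scatter(batch, device_ids)         # split the batch across devices
outputs = parallel_apply(replicas, inputs)  # run all replicas concurrently
result = gather(outputs, target_device=0)   # collect results on device 0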
def forward(self, input):
    assert input.is_cuda, "Broadcast function not implemented for CPU tensors"
    self.input_device = input.get_device()
    return comm.broadcast(input, self.target_gpus)
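Because this forward lives on an autograd Function, calling it (as the data_parallel and replicate snippets above do) makes the broadcast differentiable: during backward, the gradients flowing into the per-device copies are reduce-summed back onto the recorded input device. A sketch matching those call sites:

# old-style autograd Function invocation, as in data_parallel/replicate above
param_copies = Broadcast([0, 1])(param)  # one grad-tracked copy per GPU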
def soft_argmax_1d(self, heatmap1d):
    heatmap1d = F.softmax(heatmap1d, 1)  # (bs, 64)
    accu = heatmap1d * comm.broadcast(
        torch.arange(self.output_root_hm_shape).type(torch.cuda.FloatTensor),
        devices=[heatmap1d.device.index])[0]  # (bs, 64)
    coord = accu.sum(dim=1)  # (bs,)
    return coord
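The broadcast here only moves the index vector torch.arange(...) onto the heatmap's device; the computation itself is the usual soft-argmax, i.e. the expected index under the softmax distribution. An equivalent single-device sketch (function name hypothetical):

import torch
import torch.nn.functional as F

def soft_argmax_1d_ref(heatmap1d):
    probs = F.softmax(heatmap1d, dim=1)                  # (bs, n_bins)
    idx = torch.arange(probs.size(1), device=probs.device,
                       dtype=probs.dtype)                # (n_bins,)
    return (probs * idx).sum(dim=1)                      # expected index, (bs,)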
def finetune(novel_loader, n_query=15, pretrained_dataset='miniImageNet',
             freeze_backbone=False, n_way=5, n_support=5):
    iter_num = len(novel_loader)
    val_accs = AverageMeter()
    es = AverageMeter()
    models_score = []
    for task_i, (x, y) in enumerate(novel_loader):
        ###############################################################################
        # load the checkpoints pretrained on miniImageNet once, on the first task
        if task_i == 0:
            state_list = []
            # heuristic: detect grayscale input from the normalized channel stats
            imgisgray = True if torch.abs(
                torch.mean(x[0, 0, 0, :, :] * 0.229 - x[0, 0, 1, :, :] * 0.224
                           + 0.485 - 0.456)) < 1e-5 else False
            for model_name in params.model_list:
                checkpoint_dir = '%s/%s%s_%s.tar' % (
                    params.model_dir, model_name, 'G50' if imgisgray else '', 399)
                print(checkpoint_dir)
                tmp = torch.load(checkpoint_dir)
                state = tmp['state']
                state_keys = list(state.keys())
                # keep only the backbone weights, stripping the "feature." prefix
                for _, key in enumerate(state_keys):
                    if "feature." in key:
                        newkey = key.replace("feature.", "")
                        state[newkey] = state.pop(key)
                    else:
                        state.pop(key)
                state_list.append(state)

        # rebuild fresh models from the saved states for every task
        pretrained_model_list = []
        classifier_list = []
        model_num = len(params.model_list)
        # set params.device_list = [0, 0, 0, 0, 0, 0, 0, 0] to run all models on cuda:0
        device_list = (list(range(torch.cuda.device_count()))[:model_num]
                       if params.device_list is None else params.device_list)
        for model_idx, model_name in enumerate(params.model_list):
            pretrained_model = model_dict[model_name]()
            pretrained_model.load_state_dict(copy.deepcopy(state_list[model_idx]))
            pretrained_model.to(device_list[model_idx])
            classifier = Classifier(pretrained_model.final_feat_dim, n_way)
            classifier.to(device_list[model_idx])
            pretrained_model_list.append(pretrained_model)
            classifier_list.append(classifier)

        # uniform class prior, one copy per device
        class_p = comm.broadcast(torch.ones(1, n_way) / n_way, device_list)
        model_pool = [
            TrainingModel(mo, cl, de, params.p_coef, params.e_coef, cp, params)
            for (mo, cl, de, cp) in zip(pretrained_model_list, classifier_list,
                                        device_list, class_p)
        ]

        # split the task into support and query sets
        n_query = x.size(1) - n_support
        support_size = n_way * n_support
        y_a_i = torch.from_numpy(np.repeat(range(n_way), n_support).astype(np.int64))
        y_a_i_oh = torch.zeros(support_size, n_way).scatter_(1, y_a_i.view(-1, 1), 1)
        x_a_i = x[:, :n_support, :, :, :].contiguous().view(
            n_way * n_support, *x.size()[2:])
        x_b_i = x[:, n_support:, :, :, :].contiguous().view(
            n_way * n_query, *x.size()[2:])

        # broadcast support/query data and labels to every device
        x_a_list = comm.broadcast(x_a_i, device_list)
        x_b_list = comm.broadcast(x_b_i, device_list)
        y_a_list = comm.broadcast(y_a_i_oh, device_list)

        total_epoch = params.joint_start_epoch + params.joint_epoch
        for mi, m in enumerate(model_pool):
            m.init_task(x_a_list[mi], y_a_list[mi], x_b_list[mi])
        for epoch in range(total_epoch):
            # run one training epoch of every model in parallel
            parallel_apply(model_pool, [[] for _ in range(model_num)])
            if (params.mml is not None and epoch >= params.joint_start_epoch - 1
                    and epoch < total_epoch - 1):
                if params.mml == 'all':
                    # average the recent score history of all models ...
                    cur_score = [m.scores_history[m.epoch - params.use_epoch:m.epoch].to(0)
                                 for m in model_pool]
                    # print(cur_score)
                    mean_score = comm.broadcast(
                        torch.mean(torch.cat(cur_score, dim=0), dim=0), device_list)
                    # ... and append the consensus to each model's targets
                    for mi, m in enumerate(model_pool):
                        m.y = torch.cat((m.y_tr, mean_score[mi]), dim=0)
                else:
                    print('Wrong type !')

        # collect per-model validation accuracy for this task
        val_acc_m = []
        for mi, m in enumerate(model_pool):
            val_acc_m.append(m.val_acc.cpu().numpy())
        val_acc_m = np.array(val_acc_m)
        val_accs.update(val_acc_m)

        s = []
        for mi, m in enumerate(model_pool):
            s.append(m.scores_history.cpu().numpy())
        s = np.array(s)
        models_score.append(s)
        es.update(avg_ensemble(s))
        print(es.val, es.avg)
        for i in range(model_num):
            print(val_accs.val[i, params.joint_start_epoch - 1], val_accs.val[i, -1],
                  val_accs.avg[i, params.joint_start_epoch - 1], val_accs.avg[i, -1])
        print('#################################################')

    c95 = val_accs.c95()
    for i in range(model_num):
        print('%d Test Acc = %4.2f+%4.2f%%, %4.2f+%4.2f%%, %4.2f+%4.2f%%' %
              (iter_num,
               val_accs.avg[i, params.joint_start_epoch - 1],
               c95[i, params.joint_start_epoch - 1],
               val_accs.avg[i, 299], c95[i, 299],
               val_accs.avg[i, -1], c95[i, -1]))
        print(val_accs.avg[i])
    c95 = es.c95()
    print('Ensemble Avg. Acc = %4.2f+%4.2f%%' % (es.avg, c95))
    return np.array(models_score)
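The multi-GPU pattern in finetune reduces to: broadcast the shared tensors once per task, then drive one model per device with parallel_apply. A stripped-down sketch of that loop body (model_pool, x_support, and the init_task signature are hypothetical simplifications):

from torch.cuda import comm
from torch.nn.parallel import parallel_apply

device_list = [0, 1]
x_copies = comm.broadcast(x_support, device_list)  # one copy per device
for m, xc in zip(model_pool, x_copies):
    m.init_task(xc)                                # hypothetical per-model setup
# empty argument tuples, as in finetune: each module runs on its own device
parallel_apply(model_pool, [() for _ in model_pool])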