from collections import OrderedDict

import numpy as np
import torch
import torch.nn.functional as F
from torch import autograd

import utils


def pre_calculate(train_data, class_names, net, args):
    '''
        Pre-compute the probability matrices used for task sampling:
        for each class, a distribution over the other classes, and a
        distribution over that class's own examples.
    '''
    with torch.no_grad():
        all_classes = np.unique(train_data['label'])

        # Probability matrix used when sampling classes
        train_class_names = {
            'text': class_names['text'][all_classes],
            'text_len': class_names['text_len'][all_classes],
            'label': class_names['label'][all_classes],
        }
        train_class_names = utils.to_tensor(train_class_names, args.cuda)

        # Mean-pool the word embeddings of each class name
        train_class_names_ebd = net.ebd(train_class_names)  # [N, L, 300], e.g. [10, 36, 300]
        train_class_names_ebd = torch.sum(
            train_class_names_ebd,
            dim=1) / train_class_names['text_len'].view((-1, 1))  # [N, 300]

        # Pairwise distances between class-name embeddings
        dist_metrix = -neg_dist(train_class_names_ebd,
                                train_class_names_ebd)  # [N, N]

        # Drop the diagonal (each class's zero distance to itself)
        dist_metrix_nodiag = torch.cat(
            [del_tensor_ele(d, i).view((1, -1))
             for i, d in enumerate(dist_metrix)], dim=0)  # [N, N-1]

        # Softmax over the distances to the other classes
        prob_metrix = F.softmax(dist_metrix_nodiag, dim=1).cpu().numpy()  # [N, N-1]

        # Probability matrices used when sampling examples within each class
        example_prob_metrix = []
        for i, label in enumerate(all_classes):
            mask = train_data['label'] == label
            train_examples = {
                'text': train_data['text'][mask],
                'text_len': train_data['text_len'][mask],
                'label': train_data['label'][mask],
            }
            train_examples = utils.to_tensor(train_examples, args.cuda)

            # Mean-pool the word embeddings of each example
            train_examples_ebd = net.ebd(train_examples)
            train_examples_ebd = torch.sum(
                train_examples_ebd,
                dim=1) / train_examples['text_len'].view((-1, 1))  # [n_i, 300]

            # Distance from the class-name embedding to each of its examples
            example_prob_metrix_one = -neg_dist(
                train_class_names_ebd[i].view((1, -1)), train_examples_ebd)
            example_prob_metrix_one = F.softmax(example_prob_metrix_one,
                                                dim=1)  # [1, n_i]
            example_prob_metrix.append(example_prob_metrix_one.cpu().numpy())

    return prob_metrix, example_prob_metrix

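# `neg_dist` and `del_tensor_ele` are used throughout this file but defined
# elsewhere in the repo. Minimal sketches of what they are assumed to do,
# inferred from how they are called here (negative squared Euclidean distance,
# and removal of one element from a 1-D tensor) -- not necessarily the
# project's exact implementations:

def neg_dist(instances, class_proto):
    '''Negative squared Euclidean distance between rows of two matrices.

    instances: [B, D], class_proto: [N, D] -> [B, N]; larger (less negative)
    values mean the vectors are closer.
    '''
    return -torch.pow(instances.unsqueeze(1) - class_proto.unsqueeze(0),
                      2).sum(dim=2)


def del_tensor_ele(t, index):
    '''Return a copy of the 1-D tensor t with the element at `index` removed.'''
    return torch.cat((t[:index], t[index + 1:]), dim=0)
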
def test_one(task, class_names, model, optCLF, args, grad):
    '''
        Adapt the classifier head on the support set of one sampled task
        and evaluate it on the query set.
    '''
    support, query = task

    sampled_classes = torch.unique(support['label']).cpu().numpy().tolist()

    class_names_dict = {
        'label': class_names['label'][sampled_classes],
        'text': class_names['text'][sampled_classes],
        'text_len': class_names['text_len'][sampled_classes],
        'is_support': False,
    }
    class_names_dict = utils.to_tensor(class_names_dict, args.cuda,
                                       exclude_keys=['is_support'])

    # Embed the documents and the class names
    XS = model['G'](support)           # [N*K, 256 (hidden_size*2)]
    YS = support['label']
    CN = model['G'](class_names_dict)  # [N, 256 (hidden_size*2)]
    XQ = model['G'](query)
    YQ = query['label']

    YS, YQ = reidx_y(args, YS, YQ)  # remap labels so they start from 0

    # Fine-tune the MLP head on the support set
    for _ in range(args.test_iter):
        XS_mlp = model['clf'](XS)  # [N*K, 256] -> [N*K, 128]
        CN_mlp = model['clf'](CN)  # [N, 256]   -> [N, 128]

        neg_d = neg_dist(XS_mlp, CN_mlp)  # [N*K, N]
        mlp_loss = model['clf'].loss(neg_d, YS)

        optCLF.zero_grad()
        mlp_loss.backward(retain_graph=True)
        optCLF.step()

    # Evaluate on the query set
    XQ_mlp = model['clf'](XQ)
    CN_mlp = model['clf'](CN)
    neg_d = neg_dist(XQ_mlp, CN_mlp)

    _, pred = torch.max(neg_d, 1)
    acc_q = model['clf'].accuracy(pred, YQ)

    return acc_q

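# `reidx_y` is assumed to remap the sampled (global) class labels of a task to
# 0..N-1 so that they can index the rows of CN / the columns of neg_dist. A
# minimal sketch consistent with that usage (an assumption, not the repo's
# code; `args` is unused in this sketch):

def reidx_y(args, YS, YQ):
    '''Map support/query labels onto a shared 0..N-1 index space.'''
    unique_classes = torch.unique(YS)  # sorted global labels
    mapping = {c.item(): i for i, c in enumerate(unique_classes)}
    YS_new = torch.tensor([mapping[y.item()] for y in YS], device=YS.device)
    YQ_new = torch.tensor([mapping[y.item()] for y in YQ], device=YQ.device)
    return YS_new, YQ_new
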
def test_one(task, class_names_dict, model, optG, optCLF, args, grad):
    '''
        Variant with a separate document encoder G2: first fine-tune G on the
        class names, then adapt the classifier head on the support set and
        evaluate on the query set.
    '''
    support, query = task

    # Step 1: update G so the class-name embeddings move apart from each other
    cn_loss_all = 0
    for _ in range(5):
        CN = model['G'](class_names_dict)  # [N, 256 (hidden_size*2)]

        dis = neg_dist(CN, CN)  # [N, N]
        # Sum of the off-diagonal entries (the diagonal is each class's zero
        # distance to itself); minimizing the summed negative distances pushes
        # the class-name embeddings apart.
        cn_loss = torch.sum(dis) - torch.sum(torch.diag(dis))
        cn_loss_all += cn_loss

        optG.zero_grad()
        cn_loss.backward(retain_graph=True)
        optG.step()
    print('***********[TEST] cn_loss:', cn_loss_all / 5)

    # Pass the class names through the fine-tuned G; pass S and Q through G2
    XS = model['G2'](support)          # [N*K, 256 (hidden_size*2)]
    YS = support['label']
    CN = model['G'](class_names_dict)  # [N, 256 (hidden_size*2)]
    XQ = model['G2'](query)
    YQ = query['label']

    YS, YQ = reidx_y(args, YS, YQ)

    # Step 2: fine-tune the MLP head on the support set
    for _ in range(args.test_iter):
        XS_mlp = model['clf'](XS)     # [N*K, 256] -> [N*K, 256]
        neg_d = neg_dist(XS_mlp, CN)  # [N*K, N]
        mlp_loss = model['clf'].loss(neg_d, YS)

        optCLF.zero_grad()
        mlp_loss.backward(retain_graph=True)
        optCLF.step()

    # Evaluate on the query set
    XQ_mlp = model['clf'](XQ)
    neg_d = neg_dist(XQ_mlp, CN)

    _, pred = torch.max(neg_d, 1)
    acc_q = model['clf'].accuracy(pred, YQ)

    return acc_q

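# The `sum(dis) - sum(diag(dis))` expression above is a compact way of summing
# the off-diagonal entries of a square matrix. A small self-contained check of
# that identity:

def _offdiag_sum_check():
    d = torch.arange(9.0).view(3, 3)
    compact = torch.sum(d) - torch.sum(torch.diag(d))
    explicit = sum(d[i][j] for i in range(3) for j in range(3) if i != j)
    assert torch.isclose(compact, explicit)  # both equal 24.0
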
def train_one(task, class_names_dict, model, optG, optG2, optCLF, args, grad):
    '''
        Train the G / G2 / clf variant on one sampled task.
    '''
    model['G'].train()
    model['G2'].train()
    model['clf'].train()

    support, query = task

    # Step 1: update G so the class-name embeddings move apart from each other
    cn_loss_all = 0
    for _ in range(5):
        CN = model['G'](class_names_dict)  # [N, 256 (hidden_size*2)]

        dis = neg_dist(CN, CN)
        dis = dis / torch.mean(dis, dim=0)  # [N, N], column-normalized

        # Hinge-style penalty over all off-diagonal pairs with margin m,
        # divided by 2 and by 20 (presumably the number of off-diagonal
        # pairs in a 5-way task).
        m = torch.tensor(2.0).cuda()
        cn_loss = torch.tensor(0.0).cuda()
        for i, d in enumerate(dis):
            for j, dd in enumerate(d):
                if i != j:
                    cn_loss = cn_loss + (torch.max(
                        torch.tensor(0.0).cuda(), m + dd) ** 2) / 2 / 20
        cn_loss_all += cn_loss

        optG.zero_grad()
        cn_loss.backward()
        optG.step()
    print('***********cn_loss:', cn_loss_all / 5)

    # Pass the class names through the fine-tuned G; pass S and Q through G2
    CN = model['G'](class_names_dict)  # [N, 256 (hidden_size*2)]

    XS = model['G2'](support)  # [N*K, 256 (hidden_size*2)]
    YS = support['label']
    XQ = model['G2'](query)
    YQ = query['label']

    YS, YQ = reidx_y(args, YS, YQ)  # remap labels so they start from 0

    # Step 2: fine-tune the MLP head on the support set
    for _ in range(args.train_iter):
        XS_mlp = model['clf'](XS)     # [N*K, 256] -> [N*K, 256]
        neg_d = neg_dist(XS_mlp, CN)  # [N*K, N]
        mlp_loss = model['clf'].loss(neg_d, YS)

        optCLF.zero_grad()
        mlp_loss.backward(retain_graph=True)
        optCLF.step()

    # Step 3: update G2 with the query loss
    XQ_mlp = model['clf'](XQ)
    neg_d = neg_dist(XQ_mlp, CN)
    q_loss = model['clf'].loss(neg_d, YQ)

    optG2.zero_grad()
    q_loss.backward()
    optG2.step()

    _, pred = torch.max(neg_d, 1)
    acc_q = model['clf'].accuracy(pred, YQ)

    return q_loss, acc_q

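# The double Python loop above is a slow way to sum a hinge penalty over all
# off-diagonal pairs. A vectorized sketch that computes the same quantity
# (same margin and scaling, assuming `dis` as computed above):

def cn_hinge_loss(dis, margin=2.0, scale=20.0):
    '''Hinge-style penalty summed over the off-diagonal entries of dis.'''
    off_diag = 1.0 - torch.eye(dis.shape[0], device=dis.device)
    hinge = torch.clamp(margin + dis, min=0.0) ** 2 / 2.0 / scale
    return torch.sum(hinge * off_diag)
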
def test_one(task, class_names, model, optG, criterion, args, grad):
    '''
        MAML-style variant: adapt fast weights of G's final layer on Siamese
        pairs built from the support set, then evaluate on the query set.
    '''
    support, query = task

    # Build sample pairs
    YS = support['label']
    YQ = query['label']

    sampled_classes = torch.unique(support['label']).cpu().numpy().tolist()

    class_names_dict = {
        'label': class_names['label'][sampled_classes],
        'text': class_names['text'][sampled_classes],
        'text_len': class_names['text_len'][sampled_classes],
        'is_support': False,
    }
    class_names_dict = utils.to_tensor(class_names_dict, args.cuda,
                                       exclude_keys=['is_support'])

    YS, YQ = reidx_y(args, YS, YQ)

    # Pad the class-name texts so they can be concatenated with the support texts
    if support['text'].shape[1] != class_names_dict['text'].shape[1]:
        zero = torch.zeros(
            (class_names_dict['text'].shape[0],
             support['text'].shape[1] - class_names_dict['text'].shape[1]),
            dtype=torch.long)
        class_names_dict['text'] = torch.cat(
            (class_names_dict['text'], zero.cuda()), dim=-1)

    support['text'] = torch.cat((support['text'], class_names_dict['text']),
                                dim=0)
    support['text_len'] = torch.cat(
        (support['text_len'], class_names_dict['text_len']), dim=0)
    support['label'] = torch.cat((support['label'], class_names_dict['label']),
                                 dim=0)

    # Build all ordered pairs (i, j): text i goes through branch 1 and text j
    # through branch 2, for every combination of the texts above. (Equivalent
    # to the original nested concatenation loops, vectorized.)
    text_sample_len = support['text'].shape[0]
    support['text_1'] = torch.repeat_interleave(support['text'],
                                                text_sample_len, dim=0)
    support['text_len_1'] = torch.repeat_interleave(support['text_len'],
                                                    text_sample_len, dim=0)
    support['label_1'] = torch.repeat_interleave(support['label'],
                                                 text_sample_len, dim=0)

    support['text_2'] = support['text'].repeat(text_sample_len, 1)
    support['text_len_2'] = support['text_len'].repeat(text_sample_len)
    support['label_2'] = support['label'].repeat(text_sample_len)

    # Pair label: 1 if the two texts share a class, 0 otherwise
    support['label_final'] = support['label_1'].eq(support['label_2']).int()

    support_1 = {
        'text': support['text_1'],
        'text_len': support['text_len_1'],
        'label': support['label_1'],
    }
    support_2 = {
        'text': support['text_2'],
        'text_len': support['text_len_2'],
        'label': support['label_2'],
    }

    # First adaptation step
    S_out1, S_out2 = model['G'](support_1, support_2)
    loss = criterion(S_out1, S_out2, support['label_final'])
    zero_grad(model['G'].parameters())
    grads = autograd.grad(loss, model['G'].fc.parameters(), allow_unused=True)
    fast_weights, ordered_params = model['G'].cloned_fc_dict(), OrderedDict()
    for (key, val), g in zip(model['G'].fc.named_parameters(), grads):
        fast_weights[key] = ordered_params[key] = val - args.task_lr * g

    # Remaining adaptation steps
    for k in range(args.train_iter - 1):
        S_out1, S_out2 = model['G'](support_1, support_2, fast_weights)
        loss = criterion(S_out1, S_out2, support['label_final'])
        zero_grad(ordered_params.values())
        grads = autograd.grad(loss, ordered_params.values(), allow_unused=True)
        for (key, val), g in zip(ordered_params.items(), grads):
            if g is not None:
                fast_weights[key] = ordered_params[key] = val - args.task_lr * g

    # Evaluate on the query set with the adapted (fast) weights
    CN = model['G'].forward_once_with_param(class_names_dict, fast_weights)
    XQ = model['G'].forward_once_with_param(query, fast_weights)
    logits_q = neg_dist(XQ, CN)

    _, pred = torch.max(logits_q, 1)
    acc_q = model['G'].accuracy(pred, YQ)

    return acc_q

def train_one(task, class_names, model, optG, criterion, args, grad):
    '''
        MAML-style variant: adapt fast weights of G's final layer on Siamese
        pairs built from the support set, then update G with the query loss.
    '''
    model['G'].train()

    support, query = task

    # Build sample pairs
    YS = support['label']
    YQ = query['label']

    sampled_classes = torch.unique(support['label']).cpu().numpy().tolist()

    class_names_dict = {
        'label': class_names['label'][sampled_classes],
        'text': class_names['text'][sampled_classes],
        'text_len': class_names['text_len'][sampled_classes],
        'is_support': False,
    }
    class_names_dict = utils.to_tensor(class_names_dict, args.cuda,
                                       exclude_keys=['is_support'])

    YS, YQ = reidx_y(args, YS, YQ)

    # Pad the class-name texts so they can be concatenated with the support texts
    if support['text'].shape[1] != class_names_dict['text'].shape[1]:
        zero = torch.zeros(
            (class_names_dict['text'].shape[0],
             support['text'].shape[1] - class_names_dict['text'].shape[1]),
            dtype=torch.long)
        class_names_dict['text'] = torch.cat(
            (class_names_dict['text'], zero.cuda()), dim=-1)

    support['text'] = torch.cat((support['text'], class_names_dict['text']),
                                dim=0)
    support['text_len'] = torch.cat(
        (support['text_len'], class_names_dict['text_len']), dim=0)
    support['label'] = torch.cat((support['label'], class_names_dict['label']),
                                 dim=0)

    # Build all ordered pairs (i, j), as in test_one above
    text_sample_len = support['text'].shape[0]
    support['text_1'] = torch.repeat_interleave(support['text'],
                                                text_sample_len, dim=0)
    support['text_len_1'] = torch.repeat_interleave(support['text_len'],
                                                    text_sample_len, dim=0)
    support['label_1'] = torch.repeat_interleave(support['label'],
                                                 text_sample_len, dim=0)

    support['text_2'] = support['text'].repeat(text_sample_len, 1)
    support['text_len_2'] = support['text_len'].repeat(text_sample_len)
    support['label_2'] = support['label'].repeat(text_sample_len)

    # Pair label: 1 if the two texts share a class, 0 otherwise
    support['label_final'] = support['label_1'].eq(support['label_2']).int()

    support_1 = {
        'text': support['text_1'],
        'text_len': support['text_len_1'],
        'label': support['label_1'],
    }
    support_2 = {
        'text': support['text_2'],
        'text_len': support['text_len_2'],
        'label': support['label_2'],
    }

    # First adaptation step
    S_out1, S_out2 = model['G'](support_1, support_2)
    loss = criterion(S_out1, S_out2, support['label_final'])
    zero_grad(model['G'].parameters())
    grads = autograd.grad(loss, model['G'].fc.parameters(), allow_unused=True)
    fast_weights, ordered_params = model['G'].cloned_fc_dict(), OrderedDict()
    for (key, val), g in zip(model['G'].fc.named_parameters(), grads):
        fast_weights[key] = ordered_params[key] = val - args.task_lr * g

    # Remaining adaptation steps
    for k in range(args.train_iter - 1):
        S_out1, S_out2 = model['G'](support_1, support_2, fast_weights)
        loss = criterion(S_out1, S_out2, support['label_final'])
        zero_grad(ordered_params.values())
        grads = autograd.grad(loss, ordered_params.values(), allow_unused=True)
        for (key, val), g in zip(ordered_params.items(), grads):
            if g is not None:
                fast_weights[key] = ordered_params[key] = val - args.task_lr * g

    # Compute the loss on Q with the adapted (fast) weights and update G
    CN = model['G'].forward_once_with_param(class_names_dict, fast_weights)
    XQ = model['G'].forward_once_with_param(query, fast_weights)
    logits_q = neg_dist(XQ, CN)
    q_loss = model['G'].loss(logits_q, YQ)

    _, pred = torch.max(logits_q, 1)
    acc_q = model['G'].accuracy(pred, YQ)

    optG.zero_grad()
    q_loss.backward()
    optG.step()

    return q_loss, acc_q

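# `zero_grad` above is called on plain iterables of tensors (not on an
# optimizer), so it is assumed to be a small helper like the sketch below --
# the usual formulation in MAML-style code, not necessarily this repo's exact
# implementation:

def zero_grad(params):
    '''Clear the .grad buffers of an iterable of parameters/tensors.'''
    for p in params:
        if p.grad is not None:
            p.grad.detach_()
            p.grad.zero_()
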
def train_one(task, class_names, model, optG, optCLF, args, grad):
    '''
        Train the basic G + clf variant on one sampled task.
    '''
    model['G'].train()
    model['clf'].train()

    support, query = task

    sampled_classes = torch.unique(support['label']).cpu().numpy().tolist()

    class_names_dict = {
        'label': class_names['label'][sampled_classes],
        'text': class_names['text'][sampled_classes],
        'text_len': class_names['text_len'][sampled_classes],
        'is_support': False,
    }
    class_names_dict = utils.to_tensor(class_names_dict, args.cuda,
                                       exclude_keys=['is_support'])

    # Embed the documents and the class names
    XS = model['G'](support)           # [N*K, 256 (hidden_size*2)]
    YS = support['label']
    CN = model['G'](class_names_dict)  # [N, 256 (hidden_size*2)]
    XQ = model['G'](query)
    YQ = query['label']

    YS, YQ = reidx_y(args, YS, YQ)  # remap labels so they start from 0

    # Fine-tune the MLP head on the support set
    for _ in range(args.train_iter):
        XS_mlp = model['clf'](XS)  # [N*K, 256] -> [N*K, 128]
        CN_mlp = model['clf'](CN)  # [N, 256]   -> [N, 128]

        neg_d = neg_dist(XS_mlp, CN_mlp)  # [N*K, N]
        mlp_loss = model['clf'].loss(neg_d, YS)

        optCLF.zero_grad()
        mlp_loss.backward(retain_graph=True)
        optCLF.step()

    # Update G with the query loss
    XQ_mlp = model['clf'](XQ)
    CN_mlp = model['clf'](CN)
    neg_d = neg_dist(XQ_mlp, CN_mlp)
    g_loss = model['clf'].loss(neg_d, YQ)

    optG.zero_grad()
    g_loss.backward()
    optG.step()

    _, pred = torch.max(neg_d, 1)
    acc_q = model['clf'].accuracy(pred, YQ)

    return g_loss, acc_q

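# A minimal sketch of how these per-task routines are typically driven by an
# episodic training loop. `sample_task`, the episode count, and the logging
# interval are hypothetical placeholders, not names from this repo:

def train_loop_sketch(train_data, class_names, model, optG, optCLF, args):
    grad = {'G': [], 'clf': []}
    for episode in range(1000):  # hypothetical number of episodes
        # hypothetical sampler returning a (support, query) pair
        task = sample_task(train_data, args)
        g_loss, acc_q = train_one(task, class_names, model, optG, optCLF,
                                  args, grad)
        if episode % 100 == 0:
            print('episode {}: loss={:.4f}, acc={:.4f}'.format(
                episode, float(g_loss), float(acc_q)))
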