def test_one(task, class_names, model, optG, criterion, args, grad):
    '''
        Evaluate the model on one sampled task: adapt fast weights on the
        support set, then predict on the query set.
    '''
    model['G'].eval()

    support, query = task

    # Build sample pairs
    YS = support['label']
    YQ = query['label']

    sampled_classes = torch.unique(support['label']).cpu().numpy().tolist()

    class_names_dict = {}
    class_names_dict['label'] = class_names['label'][sampled_classes]
    class_names_dict['text'] = class_names['text'][sampled_classes]
    class_names_dict['text_len'] = class_names['text_len'][sampled_classes]
    class_names_dict['is_support'] = False

    class_names_dict = utils.to_tensor(class_names_dict, args.cuda,
                                       exclude_keys=['is_support'])

    YS, YQ = reidx_y(args, YS, YQ)

    # Pad the shorter text tensor so support documents and class
    # descriptions share one sequence length
    if support['text'].shape[1] > class_names_dict['text'].shape[1]:
        zero = torch.zeros(
            (class_names_dict['text'].shape[0],
             support['text'].shape[1] - class_names_dict['text'].shape[1]),
            dtype=torch.long, device=class_names_dict['text'].device)
        class_names_dict['text'] = torch.cat((class_names_dict['text'], zero),
                                             dim=-1)
    elif support['text'].shape[1] < class_names_dict['text'].shape[1]:
        zero = torch.zeros(
            (support['text'].shape[0],
             class_names_dict['text'].shape[1] - support['text'].shape[1]),
            dtype=torch.long, device=support['text'].device)
        support['text'] = torch.cat((support['text'], zero), dim=-1)

    # Append the class descriptions to the support set
    support['text'] = torch.cat((support['text'], class_names_dict['text']),
                                dim=0)
    support['text_len'] = torch.cat(
        (support['text_len'], class_names_dict['text_len']), dim=0)
    support['label'] = torch.cat((support['label'], class_names_dict['label']),
                                 dim=0)

    text_sample_len = support['text'].shape[0]
    n_classes = len(sampled_classes)

    # Pair every support example with every class description: each support
    # row is repeated n_classes times (side 1), while the class descriptions
    # are tiled text_sample_len times (side 2). This replaces the original
    # element-by-element torch.cat loops with the equivalent tensor ops.
    support['text_1'] = support['text'].repeat_interleave(n_classes, dim=0)
    support['text_len_1'] = support['text_len'].repeat_interleave(n_classes)
    support['label_1'] = support['label'].repeat_interleave(n_classes)

    support['text_2'] = class_names_dict['text'].repeat(text_sample_len, 1)
    support['text_len_2'] = class_names_dict['text_len'].repeat(text_sample_len)
    support['label_2'] = class_names_dict['label'].repeat(text_sample_len)

    # 1 where a pair shares a label, 0 otherwise
    support['label_final'] = support['label_1'].eq(support['label_2']).int()

    support_1 = {'text': support['text_1'],
                 'text_len': support['text_len_1'],
                 'label': support['label_1']}
    support_2 = {'text': support['text_2'],
                 'text_len': support['text_len_2'],
                 'label': support['label_2']}

    '''first step'''
    S_out1, S_out2 = model['G'](support_1, support_2)
    supp_, que_ = model['G'](support, query)
    loss_weight = get_weight_of_test_support(supp_, que_, args)
    loss = criterion(S_out1, S_out2, support['label_final'], loss_weight)

    zero_grad(model['G'].parameters())

    # One manual SGD step per submodule: fc, conv11, conv12, conv13
    grads_fc = autograd.grad(loss, model['G'].fc.parameters(),
                             allow_unused=True, retain_graph=True)
    fast_weights_fc, orderd_params_fc = model['G'].cloned_fc_dict(), OrderedDict()
    for (key, val), grad in zip(model['G'].fc.named_parameters(), grads_fc):
        fast_weights_fc[key] = orderd_params_fc[key] = val - args.task_lr * grad

    grads_conv11 = autograd.grad(loss, model['G'].conv11.parameters(),
                                 allow_unused=True, retain_graph=True)
    fast_weights_conv11, orderd_params_conv11 = model['G'].cloned_conv11_dict(), OrderedDict()
    for (key, val), grad in zip(model['G'].conv11.named_parameters(), grads_conv11):
        fast_weights_conv11[key] = orderd_params_conv11[key] = val - args.task_lr * grad

    grads_conv12 = autograd.grad(loss, model['G'].conv12.parameters(),
                                 allow_unused=True, retain_graph=True)
    fast_weights_conv12, orderd_params_conv12 = model['G'].cloned_conv12_dict(), OrderedDict()
    for (key, val), grad in zip(model['G'].conv12.named_parameters(), grads_conv12):
        fast_weights_conv12[key] = orderd_params_conv12[key] = val - args.task_lr * grad

    grads_conv13 = autograd.grad(loss, model['G'].conv13.parameters(),
                                 allow_unused=True)
    fast_weights_conv13, orderd_params_conv13 = model['G'].cloned_conv13_dict(), OrderedDict()
    for (key, val), grad in zip(model['G'].conv13.named_parameters(), grads_conv13):
        fast_weights_conv13[key] = orderd_params_conv13[key] = val - args.task_lr * grad

    fast_weights = {'fc': fast_weights_fc,
                    'conv11': fast_weights_conv11,
                    'conv12': fast_weights_conv12,
                    'conv13': fast_weights_conv13}

    '''steps remaining'''
    for k in range(args.test_iter - 1):
        S_out1, S_out2 = model['G'](support_1, support_2, fast_weights)
        supp_, que_ = model['G'](support, query, fast_weights)
        loss_weight = get_weight_of_test_support(supp_, que_, args)
        loss = criterion(S_out1, S_out2, support['label_final'], loss_weight)

        zero_grad(orderd_params_fc.values())
        zero_grad(orderd_params_conv11.values())
        zero_grad(orderd_params_conv12.values())
        zero_grad(orderd_params_conv13.values())

        grads_fc = torch.autograd.grad(loss, orderd_params_fc.values(),
                                       allow_unused=True, retain_graph=True)
        grads_conv11 = torch.autograd.grad(loss, orderd_params_conv11.values(),
                                           allow_unused=True, retain_graph=True)
        grads_conv12 = torch.autograd.grad(loss, orderd_params_conv12.values(),
                                           allow_unused=True, retain_graph=True)
        grads_conv13 = torch.autograd.grad(loss, orderd_params_conv13.values(),
                                           allow_unused=True)

        for (key, val), grad in zip(orderd_params_fc.items(), grads_fc):
            if grad is not None:
                fast_weights['fc'][key] = orderd_params_fc[key] = val - args.task_lr * grad
        for (key, val), grad in zip(orderd_params_conv11.items(), grads_conv11):
            if grad is not None:
                fast_weights['conv11'][key] = orderd_params_conv11[key] = val - args.task_lr * grad
        for (key, val), grad in zip(orderd_params_conv12.items(), grads_conv12):
            if grad is not None:
                fast_weights['conv12'][key] = orderd_params_conv12[key] = val - args.task_lr * grad
        for (key, val), grad in zip(orderd_params_conv13.items(), grads_conv13):
            if grad is not None:
                fast_weights['conv13'][key] = orderd_params_conv13[key] = val - args.task_lr * grad

    # Compute accuracy on the query set with the adapted fast weights
    CN = model['G'].forward_once_with_param(class_names_dict, fast_weights)
    XQ = model['G'].forward_once_with_param(query, fast_weights)
    logits_q = pos_dist(XQ, CN)
    logits_q = dis_to_level(logits_q)

    _, pred = torch.max(logits_q, 1)
    acc_q = model['G'].accuracy(pred, YQ)

    return acc_q
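
# The inner loops above call a zero_grad helper on arbitrary parameter
# iterables (module parameters as well as fast-weight tensors). A minimal
# sketch consistent with that usage; the project's actual helper may differ:
#
#   def zero_grad(params):
#       for p in params:
#           if p.grad is not None:
#               p.grad.detach_()
#               p.grad.zero_()
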
def test_one(task, class_names, model, optCLF, args, grad):
    '''
        Fine-tune the classifier head on one sampled task, then evaluate it
        on the query set.
    '''
    support, query = task

    sampled_classes = torch.unique(support['label']).cpu().numpy().tolist()

    class_names_dict = {}
    class_names_dict['label'] = class_names['label'][sampled_classes]
    class_names_dict['text'] = class_names['text'][sampled_classes]
    class_names_dict['text_len'] = class_names['text_len'][sampled_classes]
    class_names_dict['is_support'] = False

    class_names_dict = utils.to_tensor(class_names_dict, args.cuda,
                                       exclude_keys=['is_support'])

    # Embed the documents and the class descriptions
    XS = model['G'](support)            # XS: [N*K, 256 (hidden_size*2)]
    YS = support['label']

    CN = model['G'](class_names_dict)   # CN: [N, 256 (hidden_size*2)]

    XQ = model['G'](query)
    YQ = query['label']

    YS, YQ = reidx_y(args, YS, YQ)

    # Fine-tune the MLP head on the support set
    for _ in range(args.test_iter):
        XS_mlp = model['clf'](XS)       # [N*K, 256] -> [N*K, 128]
        CN_mlp = model['clf'](CN)       # [N, 256]   -> [N, 128]

        neg_d = neg_dist(XS_mlp, CN_mlp)    # [N*K, N]
        mlp_loss = model['clf'].loss(neg_d, YS)

        optCLF.zero_grad()
        mlp_loss.backward(retain_graph=True)
        optCLF.step()

    # Classify each query document by its nearest class embedding
    XQ_mlp = model['clf'](XQ)
    CN_mlp = model['clf'](CN)
    neg_d = neg_dist(XQ_mlp, CN_mlp)

    _, pred = torch.max(neg_d, 1)
    acc_q = model['clf'].accuracy(pred, YQ)

    return acc_q
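
# neg_dist above is assumed to return a negative Euclidean distance between
# each row of x and each row of y, so that larger values mean "closer" and
# the [num_x, num_y] matrix can serve directly as classification logits.
# A sketch under that assumption (whether the distance is squared may differ
# in the real helper):
#
#   def neg_dist(x, y):
#       return -torch.cdist(x, y) ** 2      # [num_x, num_y]
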
def test_one(task, class_names_dict, model, optG, optCLF, args, grad):
    '''
        Fine-tune on one sampled task, then evaluate on the query set.
    '''
    support, query = task

    # Step 1: update G so the class descriptions move apart from each other
    cn_loss_all = 0
    for _ in range(5):
        CN = model['G'](class_names_dict)   # CN: [N, 256 (hidden_size*2)]
        dis = neg_dist(CN, CN)              # [N, N]
        # Sum of the off-diagonal (negative) pairwise distances
        cn_loss = torch.sum(dis) - torch.sum(torch.diag(dis))
        cn_loss_all += cn_loss

        optG.zero_grad()
        cn_loss.backward(retain_graph=True)
        optG.step()
    print('***********[TEST] cn_loss:', cn_loss_all / 5)

    # Pass the class names through the fine-tuned G; embed S and Q with G2
    XS = model['G2'](support)               # XS: [N*K, 256 (hidden_size*2)]
    YS = support['label']

    CN = model['G'](class_names_dict)       # CN: [N, 256 (hidden_size*2)]

    XQ = model['G2'](query)
    YQ = query['label']

    YS, YQ = reidx_y(args, YS, YQ)

    # Step 2: update the MLP head on the support set
    for _ in range(args.test_iter):
        XS_mlp = model['clf'](XS)           # [N*K, 256] -> [N*K, 256]
        neg_d = neg_dist(XS_mlp, CN)        # [N*K, N]
        mlp_loss = model['clf'].loss(neg_d, YS)

        optCLF.zero_grad()
        mlp_loss.backward(retain_graph=True)
        optCLF.step()

    XQ_mlp = model['clf'](XQ)
    neg_d = neg_dist(XQ_mlp, CN)

    _, pred = torch.max(neg_d, 1)
    acc_q = model['clf'].accuracy(pred, YQ)

    return acc_q
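
# Toy check of the class-separation objective above: with neg_dist returning
# negative (squared) distances, the off-diagonal sum becomes more negative as
# the class embeddings move apart, so minimizing cn_loss separates the
# classes. Hypothetical values for illustration only:
#
#   CN  = torch.tensor([[0., 0.], [3., 4.]])
#   dis = neg_dist(CN, CN)                                  # [[0, -25], [-25, 0]]
#   cn_loss = torch.sum(dis) - torch.sum(torch.diag(dis))   # -50
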
def train_one(task, class_names_dict, model, optG, optG2, optCLF, args, grad):
    '''
        Train the model on one sampled task.
    '''
    model['G'].train()
    model['G2'].train()
    model['clf'].train()

    support, query = task

    # Step 1: update G so the class descriptions move apart from each other
    cn_loss_all = 0
    for _ in range(5):
        CN = model['G'](class_names_dict)   # CN: [N, 256 (hidden_size*2)]
        # Normalize the (negative) pairwise distances by their column means
        d = neg_dist(CN, CN)
        dis = d / torch.mean(d, dim=0)      # [N, N]

        # Squared hinge on every off-diagonal pair, scaled by 1/2 and by the
        # N*(N-1) = 20 off-diagonal pairs of a 5-way task
        m = torch.tensor(2.0).cuda()
        cn_loss = torch.tensor(0.0).cuda()
        for i, row in enumerate(dis):
            for j, dd in enumerate(row):
                if i != j:
                    cn_loss = cn_loss + ((torch.max(
                        torch.tensor(0.0).cuda(), m + dd)) ** 2) / 2 / 20
        cn_loss_all += cn_loss

        optG.zero_grad()
        cn_loss.backward()
        optG.step()
    print('***********cn_loss:', cn_loss_all / 5)

    # Pass the class names through the fine-tuned G; embed S and Q with G2
    CN = model['G'](class_names_dict)       # CN: [N, 256 (hidden_size*2)]

    XS = model['G2'](support)               # XS: [N*K, 256 (hidden_size*2)]
    YS = support['label']

    XQ = model['G2'](query)
    YQ = query['label']

    YS, YQ = reidx_y(args, YS, YQ)          # re-index labels to start from 0

    # Step 2: update the MLP head on the support set
    for _ in range(args.train_iter):
        XS_mlp = model['clf'](XS)           # [N*K, 256] -> [N*K, 256]
        neg_d = neg_dist(XS_mlp, CN)        # [N*K, N]
        mlp_loss = model['clf'].loss(neg_d, YS)

        optCLF.zero_grad()
        mlp_loss.backward(retain_graph=True)
        optCLF.step()

    # Step 3: update G2 on the query set
    XQ_mlp = model['clf'](XQ)
    neg_d = neg_dist(XQ_mlp, CN)
    q_loss = model['clf'].loss(neg_d, YQ)

    optG2.zero_grad()
    q_loss.backward()
    optG2.step()

    _, pred = torch.max(neg_d, 1)
    acc_q = model['clf'].accuracy(pred, YQ)

    return q_loss, acc_q
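
# Vectorized equivalent of the pairwise margin loss in train_one above: a
# squared hinge over the off-diagonal entries of `dis`, scaled by 1/2 and
# averaged over the n*(n-1) pairs (20 in the 5-way case that the constant in
# the loop assumes). Shown as a sketch for reference:
def cn_margin_loss(dis, m=2.0):
    n = dis.shape[0]
    # Boolean mask selecting every off-diagonal (i != j) entry
    off_diag = ~torch.eye(n, dtype=torch.bool, device=dis.device)
    hinge = torch.clamp(m + dis[off_diag], min=0.0)
    return (hinge ** 2).sum() / 2 / (n * (n - 1))
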
def test_one(task, class_names, model, optG, criterion, args, grad):
    '''
        Evaluate the model on one sampled task: adapt fast weights of the fc
        layer on the support set, then predict on the query set.
    '''
    support, query = task

    # Build sample pairs
    YS = support['label']
    YQ = query['label']

    sampled_classes = torch.unique(support['label']).cpu().numpy().tolist()

    class_names_dict = {}
    class_names_dict['label'] = class_names['label'][sampled_classes]
    class_names_dict['text'] = class_names['text'][sampled_classes]
    class_names_dict['text_len'] = class_names['text_len'][sampled_classes]
    class_names_dict['is_support'] = False

    class_names_dict = utils.to_tensor(class_names_dict, args.cuda,
                                       exclude_keys=['is_support'])

    YS, YQ = reidx_y(args, YS, YQ)

    # Pad the shorter text tensor so both share one sequence length. The
    # original only padded the class names, which breaks when the class
    # descriptions are longer; handle both directions as in test_one above.
    if support['text'].shape[1] > class_names_dict['text'].shape[1]:
        zero = torch.zeros(
            (class_names_dict['text'].shape[0],
             support['text'].shape[1] - class_names_dict['text'].shape[1]),
            dtype=torch.long, device=class_names_dict['text'].device)
        class_names_dict['text'] = torch.cat((class_names_dict['text'], zero),
                                             dim=-1)
    elif support['text'].shape[1] < class_names_dict['text'].shape[1]:
        zero = torch.zeros(
            (support['text'].shape[0],
             class_names_dict['text'].shape[1] - support['text'].shape[1]),
            dtype=torch.long, device=support['text'].device)
        support['text'] = torch.cat((support['text'], zero), dim=-1)

    # Append the class descriptions to the support set
    support['text'] = torch.cat((support['text'], class_names_dict['text']), dim=0)
    support['text_len'] = torch.cat(
        (support['text_len'], class_names_dict['text_len']), dim=0)
    support['label'] = torch.cat((support['label'], class_names_dict['label']), dim=0)

    text_sample_len = support['text'].shape[0]

    # Full Cartesian product of support examples: each row is repeated
    # text_sample_len times on side 1, and the whole set is tiled
    # text_sample_len times on side 2 (equivalent to the original loops)
    support['text_1'] = support['text'].repeat_interleave(text_sample_len, dim=0)
    support['text_len_1'] = support['text_len'].repeat_interleave(text_sample_len)
    support['label_1'] = support['label'].repeat_interleave(text_sample_len)

    support['text_2'] = support['text'].repeat(text_sample_len, 1)
    support['text_len_2'] = support['text_len'].repeat(text_sample_len)
    support['label_2'] = support['label'].repeat(text_sample_len)

    # 1 where a pair shares a label, 0 otherwise
    support['label_final'] = support['label_1'].eq(support['label_2']).int()

    support_1 = {'text': support['text_1'],
                 'text_len': support['text_len_1'],
                 'label': support['label_1']}
    support_2 = {'text': support['text_2'],
                 'text_len': support['text_len_2'],
                 'label': support['label_2']}

    '''first step'''
    S_out1, S_out2 = model['G'](support_1, support_2)
    loss = criterion(S_out1, S_out2, support['label_final'])

    zero_grad(model['G'].parameters())
    grads = autograd.grad(loss, model['G'].fc.parameters(), allow_unused=True)
    fast_weights, orderd_params = model['G'].cloned_fc_dict(), OrderedDict()
    for (key, val), grad in zip(model['G'].fc.named_parameters(), grads):
        fast_weights[key] = orderd_params[key] = val - args.task_lr * grad

    '''steps remaining'''
    for k in range(args.train_iter - 1):
        S_out1, S_out2 = model['G'](support_1, support_2, fast_weights)
        loss = criterion(S_out1, S_out2, support['label_final'])

        zero_grad(orderd_params.values())
        grads = torch.autograd.grad(loss, orderd_params.values(),
                                    allow_unused=True)
        for (key, val), grad in zip(orderd_params.items(), grads):
            if grad is not None:
                fast_weights[key] = orderd_params[key] = val - args.task_lr * grad

    # Compute accuracy on the query set with the adapted fast weights
    CN = model['G'].forward_once_with_param(class_names_dict, fast_weights)
    XQ = model['G'].forward_once_with_param(query, fast_weights)
    logits_q = neg_dist(XQ, CN)

    _, pred = torch.max(logits_q, 1)
    acc_q = model['G'].accuracy(pred, YQ)

    return acc_q
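
# The fast-weight initialization above assumes the encoder exposes
# cloned_*_dict() helpers that return a name-keyed copy of a submodule's
# parameters, used to seed the fast weights. A minimal sketch of the assumed
# fc variant, written as a method of the encoder class:
#
#   def cloned_fc_dict(self):
#       return {key: val.clone() for key, val in self.fc.named_parameters()}
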
def train_one(task, class_names, model, optG, criterion, args, grad):
    '''
        Train the model on one sampled task.
    '''
    model['G'].train()

    support, query = task

    # Build sample pairs
    YS = support['label']
    YQ = query['label']

    sampled_classes = torch.unique(support['label']).cpu().numpy().tolist()

    class_names_dict = {}
    class_names_dict['label'] = class_names['label'][sampled_classes]
    class_names_dict['text'] = class_names['text'][sampled_classes]
    class_names_dict['text_len'] = class_names['text_len'][sampled_classes]
    class_names_dict['is_support'] = False

    class_names_dict = utils.to_tensor(class_names_dict, args.cuda,
                                       exclude_keys=['is_support'])

    YS, YQ = reidx_y(args, YS, YQ)

    # Pad the shorter text tensor (both directions, as in test_one above)
    if support['text'].shape[1] > class_names_dict['text'].shape[1]:
        zero = torch.zeros(
            (class_names_dict['text'].shape[0],
             support['text'].shape[1] - class_names_dict['text'].shape[1]),
            dtype=torch.long, device=class_names_dict['text'].device)
        class_names_dict['text'] = torch.cat((class_names_dict['text'], zero),
                                             dim=-1)
    elif support['text'].shape[1] < class_names_dict['text'].shape[1]:
        zero = torch.zeros(
            (support['text'].shape[0],
             class_names_dict['text'].shape[1] - support['text'].shape[1]),
            dtype=torch.long, device=support['text'].device)
        support['text'] = torch.cat((support['text'], zero), dim=-1)

    # Append the class descriptions to the support set
    support['text'] = torch.cat((support['text'], class_names_dict['text']), dim=0)
    support['text_len'] = torch.cat(
        (support['text_len'], class_names_dict['text_len']), dim=0)
    support['label'] = torch.cat((support['label'], class_names_dict['label']), dim=0)

    text_sample_len = support['text'].shape[0]

    # Full Cartesian product of support examples: each row is repeated
    # text_sample_len times on side 1, and the whole set is tiled
    # text_sample_len times on side 2 (equivalent to the original loops)
    support['text_1'] = support['text'].repeat_interleave(text_sample_len, dim=0)
    support['text_len_1'] = support['text_len'].repeat_interleave(text_sample_len)
    support['label_1'] = support['label'].repeat_interleave(text_sample_len)

    support['text_2'] = support['text'].repeat(text_sample_len, 1)
    support['text_len_2'] = support['text_len'].repeat(text_sample_len)
    support['label_2'] = support['label'].repeat(text_sample_len)

    # 1 where a pair shares a label, 0 otherwise
    support['label_final'] = support['label_1'].eq(support['label_2']).int()

    support_1 = {'text': support['text_1'],
                 'text_len': support['text_len_1'],
                 'label': support['label_1']}
    support_2 = {'text': support['text_2'],
                 'text_len': support['text_len_2'],
                 'label': support['label_2']}

    '''first step'''
    S_out1, S_out2 = model['G'](support_1, support_2)
    loss = criterion(S_out1, S_out2, support['label_final'])

    zero_grad(model['G'].parameters())
    grads = autograd.grad(loss, model['G'].fc.parameters(), allow_unused=True)
    fast_weights, orderd_params = model['G'].cloned_fc_dict(), OrderedDict()
    for (key, val), grad in zip(model['G'].fc.named_parameters(), grads):
        fast_weights[key] = orderd_params[key] = val - args.task_lr * grad

    '''steps remaining'''
    for k in range(args.train_iter - 1):
        S_out1, S_out2 = model['G'](support_1, support_2, fast_weights)
        loss = criterion(S_out1, S_out2, support['label_final'])

        zero_grad(orderd_params.values())
        grads = torch.autograd.grad(loss, orderd_params.values(),
                                    allow_unused=True)
        for (key, val), grad in zip(orderd_params.items(), grads):
            if grad is not None:
                fast_weights[key] = orderd_params[key] = val - args.task_lr * grad

    # Compute the loss on the query set and update the slow weights of G
    CN = model['G'].forward_once_with_param(class_names_dict, fast_weights)
    XQ = model['G'].forward_once_with_param(query, fast_weights)
    logits_q = neg_dist(XQ, CN)
    q_loss = model['G'].loss(logits_q, YQ)

    _, pred = torch.max(logits_q, 1)
    acc_q = model['G'].accuracy(pred, YQ)

    optG.zero_grad()
    q_loss.backward()
    optG.step()

    return q_loss, acc_q
def train_one(task, class_names, model, optG, optCLF, args, grad):
    '''
        Train the model on one sampled task.
    '''
    model['G'].train()
    model['clf'].train()

    support, query = task

    sampled_classes = torch.unique(support['label']).cpu().numpy().tolist()

    class_names_dict = {}
    class_names_dict['label'] = class_names['label'][sampled_classes]
    class_names_dict['text'] = class_names['text'][sampled_classes]
    class_names_dict['text_len'] = class_names['text_len'][sampled_classes]
    class_names_dict['is_support'] = False

    class_names_dict = utils.to_tensor(class_names_dict, args.cuda,
                                       exclude_keys=['is_support'])

    # Embed the documents and the class descriptions
    XS = model['G'](support)            # XS: [N*K, 256 (hidden_size*2)]
    YS = support['label']

    CN = model['G'](class_names_dict)   # CN: [N, 256 (hidden_size*2)]

    XQ = model['G'](query)
    YQ = query['label']

    YS, YQ = reidx_y(args, YS, YQ)

    # Update the MLP head on the support set
    for _ in range(args.train_iter):
        XS_mlp = model['clf'](XS)       # [N*K, 256] -> [N*K, 128]
        CN_mlp = model['clf'](CN)       # [N, 256]   -> [N, 128]

        neg_d = neg_dist(XS_mlp, CN_mlp)    # [N*K, N]
        mlp_loss = model['clf'].loss(neg_d, YS)

        optCLF.zero_grad()
        mlp_loss.backward(retain_graph=True)
        optCLF.step()

    # Update the encoder G on the query set
    XQ_mlp = model['clf'](XQ)
    CN_mlp = model['clf'](CN)
    neg_d = neg_dist(XQ_mlp, CN_mlp)
    g_loss = model['clf'].loss(neg_d, YQ)

    optG.zero_grad()
    g_loss.backward()
    optG.step()

    _, pred = torch.max(neg_d, 1)
    acc_q = model['clf'].accuracy(pred, YQ)

    return g_loss, acc_q
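
# Hypothetical driver showing how the train_one / test_one variants above are
# typically called once per sampled episode; the sampler and reporting names
# here are illustrative assumptions, not part of this project:
#
#   for episode, task in enumerate(train_task_sampler):
#       g_loss, acc_q = train_one(task, class_names, model, optG, optCLF,
#                                 args, grad)
#       if episode % 100 == 0:
#           print(f'episode {episode}: loss={g_loss.item():.4f}, '
#                 f'acc={float(acc_q):.4f}')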