def get_epoch(self):
    """Yield `num_episodes` (support, query) task pairs produced by the worker thread.

    Each episode is pulled from `self.done_queue`, converted to torch tensors,
    optionally augmented with BERT embeddings and a meta target weight `w`,
    then tagged with its support/query role before being yielded.
    """
    for _ in range(self.num_episodes):
        # wait until self.thread finishes
        support, query = self.done_queue.get()

        # convert to torch.tensor; entries under 'raw' are left as-is
        support = utils.to_tensor(support, self.args.cuda, ['raw'])
        query = utils.to_tensor(query, self.args.cuda, ['raw'])

        if 'bert_id' in support.keys():
            # run bert to get ebd
            # +2 presumably accounts for [CLS]/[SEP] tokens — TODO confirm
            support['ebd'] = self.get_bert(support['bert_id'],
                                           support['text_len']+2)
            query['ebd'] = self.get_bert(query['bert_id'],
                                         query['text_len']+2)

        if self.args.meta_w_target:
            if self.args.meta_target_entropy:
                w = stats.get_w_target(support, self.data['vocab_size'],
                                       self.data['avg_ebd'],
                                       self.args.meta_w_target_lam)
            else:
                # use rr approxmation (this one is faster)
                w = stats.get_w_target_rr(support, self.data['vocab_size'],
                                          self.data['avg_ebd'],
                                          self.args.meta_w_target_lam)
            # detach: w is computed from support only and must not
            # backpropagate through the target-weight computation
            support['w_target'] = w.detach()
            query['w_target'] = w.detach()

        support['is_support'] = True
        query['is_support'] = False

        yield support, query
def __getitem__(self, item):
    """Load one video's frames and return `(item, x, y, x_t, y_t)`.

    `x`/`y` are stacks of `return_count - 1` frame tensors and their rendered
    landmark tensors; `x_t`/`y_t` are the first sampled frame/landmark pair,
    kept separate (unsqueezed by the caller) as the "target" frame.
    """
    video = self.videos[item]
    metadata_file = os.path.join(self.root_dir, video, 'metadata.pkl')
    with open(metadata_file, 'rb') as f_in:
        frame_list = pickle.load(f_in)
    frame_count = len(frame_list)
    remains = self.return_count
    sampled_frames = list()
    # If more frames are requested than exist, cycle through all frames
    # whole-list at a time, then sample the remainder without replacement.
    while remains > frame_count:
        sampled_frames.extend(range(frame_count))
        remains -= frame_count
    sampled_frames.extend(random.sample(range(frame_count), remains))
    # sanity check
    assert len(sampled_frames) == self.return_count
    x = list()
    y = list()
    for i in sampled_frames:
        f, landmarks = frame_list[i]
        full_path = os.path.join(self.root_dir, video, f)
        img = Image.open(full_path).convert('RGB')
        img = img.resize((self.image_size, self.image_size), Image.LANCZOS)
        if self.random_flip:
            indicator = random.random()
            if indicator > 0.5:
                # flip image and mirror the landmark x-coordinates to match.
                # NOTE(review): landmarks is mutated in place; safe only
                # because frame_list is re-loaded from the pickle each call.
                img = img.transpose(Image.FLIP_LEFT_RIGHT)
                landmarks[:, 0] = self.image_size - 1 - landmarks[:, 0]
        x.append(to_tensor(img, self.normalize))
        rendered = plot_landmarks(self.image_size, landmarks)
        y.append(to_tensor(rendered, self.normalize))
        # debug: dump the frame, its landmark render, and an overlay
        if self._debug:
            img.save('%d.jpg' % i)
            rendered.save('%d_lm.jpg' % i)
            debug_img = plot_landmarks(self.image_size, landmarks,
                                       original_image=img)
            debug_img.save('%d_bg.jpg' % i)
    x_t = x[0]
    y_t = y[0]
    x = torch.stack(x[1:])  # return_count * c * h * w
    y = torch.stack(y[1:])
    return item, x, y, x_t, y_t
def pre_calculate(train_data, class_names, net, args):
    """Precompute sampling probability matrices from embedding distances.

    Returns:
        prob_metrix: [C, C-1] numpy array — for each class, a softmax over
            distances to every *other* class name embedding (diagonal removed),
            used as the probability of co-sampling classes.
        example_prob_metrix: list of C numpy arrays, each [1, N_c] — for each
            class, a softmax over distances from its class-name embedding to
            each of its training examples, used to sample examples.
    """
    with torch.no_grad():
        all_classes = np.unique(train_data['label'])
        num_classes = len(all_classes)

        # probability matrix used when sampling classes
        train_class_names = {}
        train_class_names['text'] = class_names['text'][all_classes]
        train_class_names['text_len'] = class_names['text_len'][all_classes]
        train_class_names['label'] = class_names['label'][all_classes]
        train_class_names = utils.to_tensor(train_class_names, args.cuda)

        # mean-pool the word embeddings of each class name
        train_class_names_ebd = net.ebd(train_class_names)  # [10, 36, 300]
        train_class_names_ebd = torch.sum(
            train_class_names_ebd, dim=1) / train_class_names['text_len'].view(
                (-1, 1))  # [10, 300]

        dist_metrix = -neg_dist(train_class_names_ebd,
                                train_class_names_ebd)  # [10, 10]
        # drop the diagonal (self-distance) row by row
        for i, d in enumerate(dist_metrix):
            if i == 0:
                dist_metrix_nodiag = del_tensor_ele(d, i).view((1, -1))
            else:
                dist_metrix_nodiag = torch.cat(
                    (dist_metrix_nodiag, del_tensor_ele(d, i).view((1, -1))),
                    dim=0)
        prob_metrix = F.softmax(dist_metrix_nodiag, dim=1)  # [10, 9]
        prob_metrix = prob_metrix.cpu().numpy()

        # probability matrices used when sampling examples within a class
        example_prob_metrix = []
        for i, label in enumerate(all_classes):
            train_examples = {}
            train_examples['text'] = train_data['text'][train_data['label'] ==
                                                        label]
            train_examples['text_len'] = train_data['text_len'][
                train_data['label'] == label]
            train_examples['label'] = train_data['label'][train_data['label']
                                                          == label]
            train_examples = utils.to_tensor(train_examples, args.cuda)

            # mean-pooled embedding of each example of this class
            train_examples_ebd = net.ebd(train_examples)
            train_examples_ebd = torch.sum(
                train_examples_ebd, dim=1) / train_examples['text_len'].view(
                    (-1, 1))  # [N, 300]

            example_prob_metrix_one = -neg_dist(
                train_class_names_ebd[i].view((1, -1)), train_examples_ebd)
            example_prob_metrix_one = F.softmax(example_prob_metrix_one,
                                                dim=1)  # [1, 1000]
            example_prob_metrix_one = example_prob_metrix_one.cpu().numpy()
            example_prob_metrix.append(example_prob_metrix_one)

        return prob_metrix, example_prob_metrix
def get_epoch(self):
    """Yield `num_episodes` (support, query) pairs from the worker thread.

    Each pair is pulled from the done-queue (blocking), converted to torch
    tensors ('raw' entries excluded), and tagged with its episode role.
    """
    episode = 0
    while episode < self.num_episodes:
        # block until the sampling thread has finished an episode
        sup, que = self.done_queue.get()

        # move everything but the 'raw' entries onto torch tensors
        sup = utils.to_tensor(sup, self.args.cuda, ['raw'])
        que = utils.to_tensor(que, self.args.cuda, ['raw'])

        # mark which side of the episode each dict represents
        sup['is_support'], que['is_support'] = True, False

        yield sup, que
        episode += 1
def test_one(task, class_names, model, optCLF, args, grad): ''' Train the model on one sampled task. ''' # model['G'].eval() # model['clf'].train() support, query = task # print("support, query:", support, query) # print("class_names_dict:", class_names_dict) sampled_classes = torch.unique(support['label']).cpu().numpy().tolist() # print("sampled_classes:", sampled_classes) class_names_dict = {} class_names_dict['label'] = class_names['label'][sampled_classes] # print("class_names_dict['label']:", class_names_dict['label']) class_names_dict['text'] = class_names['text'][sampled_classes] class_names_dict['text_len'] = class_names['text_len'][sampled_classes] class_names_dict['is_support'] = False class_names_dict = utils.to_tensor(class_names_dict, args.cuda, exclude_keys=['is_support']) # Embedding the document XS = model['G'](support) # XS:[N*K, 256(hidden_size*2)] # print("XS:", XS.shape) YS = support['label'] # print('YS:', YS) CN = model['G'](class_names_dict) # CN:[N, 256(hidden_size*2)]] # print("CN:", CN.shape) XQ = model['G'](query) YQ = query['label'] # print('YQ:', YQ) YS, YQ = reidx_y(args, YS, YQ) for _ in range(args.test_iter): # Embedding the document XS_mlp = model['clf'](XS) # [N*K, 256(hidden_size*2)] -> [N*K, 128] CN_mlp = model['clf'](CN) # [N, 256(hidden_size*2)]] -> [N, 128] neg_d = neg_dist(XS_mlp, CN_mlp) # [N*K, N] # print("neg_d:", neg_d.shape) mlp_loss = model['clf'].loss(neg_d, YS) # print("mlp_loss:", mlp_loss) optCLF.zero_grad() mlp_loss.backward(retain_graph=True) optCLF.step() XQ_mlp = model['clf'](XQ) CN_mlp = model['clf'](CN) neg_d = neg_dist(XQ_mlp, CN_mlp) _, pred = torch.max(neg_d, 1) acc_q = model['clf'].accuracy(pred, YQ) return acc_q
def main(model_file):
    """Few-shot finetune the meta-trained talking-head GAN per test person.

    For each person in the test dataset: sample T frames from one of their
    videos, compute the embedding e_hat, bake it into G/D via the projection
    matrix P and discriminator w0, finetune G and D on the T frames, and save
    before/after generations against a target frame from a *different* person.
    """
    run_id = datetime.now().strftime('%Y%m%d_%H%M_finetune')
    output_path = os.path.join('output', run_id)
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    print('The ID of this run: ' + run_id)
    print('Output directory: ' + output_path)
    all_people = os.listdir(config.test_dataset)
    people_count = len(all_people)
    print('People count: %d' % people_count)
    for i, person in enumerate(all_people):
        print('Progress: %d/%d' % (i, people_count))
        # T training images should come from the same video
        xx = sample_frames(person, config.finetune_T)
        # pick the target frame from a different person (rejection-sample)
        person_t = person
        while person_t == person:
            (person_t, ) = random.sample(all_people, 1)
        xx_t = sample_frames(person_t, 1)
        xx_all = xx_t + xx
        x = list()
        y = list()
        detector = get_detector('cuda')
        for filename in xx_all:
            img = Image.open(filename).convert('RGB')
            img = img.resize((config.input_size, config.input_size),
                             Image.LANCZOS)
            x.append(to_tensor(img, config.input_normalize))
            arr = np.array(img)
            landmarks = extract_landmark(detector, arr)
            rendered = plot_landmarks(config.input_size, landmarks)
            y.append(to_tensor(rendered, config.input_normalize))
        # free the landmark detector's GPU memory before loading the GAN
        del detector
        torch.set_grad_enabled(True)
        # index 0 is the target (other-person) frame
        x_t = torch.unsqueeze(x[0], dim=0)
        y_t = torch.unsqueeze(y[0], dim=0)
        y_t = y_t.cuda()
        x = torch.stack(x[1:])  # n * c * h * w
        y = torch.stack(y[1:])
        # sanity check
        assert x.size(0) == config.finetune_T
        # load models
        save_data = torch.load(model_file)
        _, _, _, G_state_dict, E_state_dict, D_state_dict = save_data[:6]
        G = Generator(config.G_config, config.input_normalize)
        G = G.eval()
        G = G.cuda()
        E = Embedder(config.E_config, config.embedding_dim)
        E = E.eval()
        E = E.cuda()
        D = Discriminator(config.V_config, config.embedding_dim)
        D = D.eval()
        D = D.cuda()
        with torch.no_grad():
            E.load_state_dict(E_state_dict)
            # embed each (frame, landmark) pair and average to get e_hat
            E_input = torch.cat((x, y), dim=1)
            E_input = E_input.cuda()
            e_hat = E(E_input)
            e_hat = e_hat.view(1, -1, config.embedding_dim)
            e_hat_mean = torch.mean(e_hat, dim=1, keepdim=False)
            del E
            # project e_hat through P to get G's AdaIN parameters, then
            # replace the projection with the fixed 'adain' buffer
            P = G_state_dict['P.weight']
            adain = torch.matmul(e_hat_mean, torch.transpose(P, 0, 1))
            del G_state_dict['P.weight']
            adain = adain.view(1, -1, 2)
            assert adain.size(1) == G.adain_param_count
            G_state_dict['adain'] = adain.data
            G.load_state_dict(G_state_dict)
            # specialize D: fold e_hat into w0 and drop the embedding table
            del D_state_dict['embedding.weight']
            w0 = D_state_dict['w0']
            w = w0 + e_hat_mean
            del D_state_dict['w0']
            D_state_dict['w'] = w.data
            D.load_state_dict(D_state_dict)
            # generation BEFORE finetuning, for comparison in the output image
            x_hat_0 = G(y_t)
            x_hat_0_img = to_pil_image(x_hat_0, config.input_normalize)
            del x_hat_0
        G = G.train()
        set_grad_enabled(G, True)
        D = D.train()
        set_grad_enabled(D, True)
        # loss (frozen; used only to compute gradients w.r.t. G/D)
        L_EG = Loss_EG_finetune(config.vgg19_layers, config.vggface_layers,
                                config.vgg19_weight_file,
                                config.vggface_weight_file,
                                config.vgg19_loss_weight,
                                config.vggface_loss_weight,
                                config.fm_loss_weight, config.input_normalize)
        L_EG = L_EG.eval()
        L_EG = L_EG.cuda()
        set_grad_enabled(L_EG, False)
        optim_EG = optim.Adam(G.parameters(),
                              lr=config.lr_EG,
                              betas=config.adam_betas)
        optim_D = optim.Adam(D.parameters(),
                             lr=config.lr_D,
                             betas=config.adam_betas)
        # dataset
        dataset = TensorDataset(x, y)
        dataloader = DataLoader(dataset,
                                batch_size=config.finetune_batch_size,
                                shuffle=config.dataset_shuffle,
                                num_workers=config.num_worker,
                                pin_memory=True,
                                drop_last=False)
        # finetune
        for epoch in range(config.finetune_epoch):
            for _, (xx, yy) in enumerate(dataloader):
                xx = xx.cuda()
                yy = yy.cuda()
                optim_EG.zero_grad()
                optim_D.zero_grad()
                x_hat = G(yy)
                d_output = D(torch.cat((xx, yy), dim=1))
                d_output_hat = D(torch.cat((x_hat, yy), dim=1))
                d_features = d_output[:-1]
                d_features_hat = d_output_hat[:-1]
                d_score = d_output[-1]
                d_score_hat = d_output_hat[-1]
                l_eg, l_vgg19, l_vggface, l_cnt, l_adv, l_fm = \
                    L_EG(xx, x_hat, d_features, d_features_hat, d_score_hat)
                l_d = Loss_DSC(d_score_hat, d_score)
                loss = l_eg + l_d
                loss.backward()
                optim_EG.step()
                optim_D.step()
                # train D again
                optim_D.zero_grad()
                x_hat = x_hat.detach()  # do not need to train the generator
                d_output = D(torch.cat((xx, yy), dim=1))
                d_output_hat = D(torch.cat((x_hat, yy), dim=1))
                d_score = d_output[-1]
                d_score_hat = d_output_hat[-1]
                l_d2 = Loss_DSC(d_score_hat, d_score)
                l_d2.backward()
                optim_D.step()
        # after finetuning: regenerate the same target pose for comparison
        with torch.no_grad():
            x_hat_1 = G(y_t)
            x_hat_1_img = to_pil_image(x_hat_1, config.input_normalize)
            del x_hat_1
        # save image
        training_img = Image.new(
            'RGB', (config.finetune_T * config.input_size, config.input_size))
        # NOTE(review): canvas is sized by finetune_T but the loop uses
        # metatrain_T — if the two differ this crops or under-fills the
        # training strip (and can IndexError if metatrain_T > finetune_T).
        # Confirm whether config.finetune_T was intended here.
        for j in range(config.metatrain_T):
            img = to_pil_image(x[j], config.input_normalize)
            training_img.paste(img, (j * config.input_size, 0))
        training_img.save(os.path.join(output_path, 't_%d.jpg' % i))
        x_t_img = to_pil_image(x_t, config.input_normalize)
        y_t_img = to_pil_image(y_t, config.input_normalize)
        # strip: [before-finetune | after-finetune | target frame | landmarks]
        output_img = Image.new('RGB',
                               (4 * config.input_size, config.input_size))
        output_img.paste(x_hat_0_img, (0, 0))
        output_img.paste(x_hat_1_img, (config.input_size, 0))
        output_img.paste(x_t_img, (2 * config.input_size, 0))
        output_img.paste(y_t_img, (3 * config.input_size, 0))
        output_img.save(os.path.join(output_path, 'o_%d.jpg' % i))
def test_one(task, class_names, model, optG, criterion, args, grad):
    '''
        Adapt the model on one sampled task (MAML-style inner loop over the
        fc and conv11/12/13 fast weights) and return the query accuracy.
    '''
    model['G'].eval()

    support, query = task

    '''分样本对'''  # build (example, class-name) sample pairs
    YS = support['label']
    YQ = query['label']

    sampled_classes = torch.unique(support['label']).cpu().numpy().tolist()

    class_names_dict = {}
    class_names_dict['label'] = class_names['label'][sampled_classes]
    class_names_dict['text'] = class_names['text'][sampled_classes]
    class_names_dict['text_len'] = class_names['text_len'][sampled_classes]
    class_names_dict['is_support'] = False

    class_names_dict = utils.to_tensor(class_names_dict,
                                       args.cuda,
                                       exclude_keys=['is_support'])

    # remap raw labels to 0..N-1 task-local indices
    YS, YQ = reidx_y(args, YS, YQ)

    """维度填充"""  # pad token dimension so support and class-name texts align
    if support['text'].shape[1] > class_names_dict['text'].shape[1]:
        zero = torch.zeros(
            (class_names_dict['text'].shape[0], support['text'].shape[1] -
             class_names_dict['text'].shape[1]),
            dtype=torch.long)
        class_names_dict['text'] = torch.cat(
            (class_names_dict['text'], zero.cuda()), dim=-1)
    elif support['text'].shape[1] < class_names_dict['text'].shape[1]:
        zero = torch.zeros(
            (support['text'].shape[0], class_names_dict['text'].shape[1] -
             support['text'].shape[1]),
            dtype=torch.long)
        support['text'] = torch.cat((support['text'], zero.cuda()), dim=-1)

    # append the class-name "documents" to the support set
    support['text'] = torch.cat((support['text'], class_names_dict['text']),
                                dim=0)
    support['text_len'] = torch.cat(
        (support['text_len'], class_names_dict['text_len']), dim=0)
    support['label'] = torch.cat((support['label'], class_names_dict['label']),
                                 dim=0)

    text_sample_len = support['text'].shape[0]

    # side 1 of each pair: every support row repeated once per sampled class
    support['text_1'] = support['text'][0].view((1, -1))
    support['text_len_1'] = support['text_len'][0].view(-1)
    support['label_1'] = support['label'][0].view(-1)
    for i in range(text_sample_len):
        if i == 0:
            # row 0 is already in place, so repeat it only N-1 more times
            for j in range(1, len(sampled_classes)):
                support['text_1'] = torch.cat(
                    (support['text_1'], support['text'][i].view((1, -1))),
                    dim=0)
                support['text_len_1'] = torch.cat(
                    (support['text_len_1'], support['text_len'][i].view(-1)),
                    dim=0)
                support['label_1'] = torch.cat(
                    (support['label_1'], support['label'][i].view(-1)), dim=0)
        else:
            for j in range(len(sampled_classes)):
                support['text_1'] = torch.cat(
                    (support['text_1'], support['text'][i].view((1, -1))),
                    dim=0)
                support['text_len_1'] = torch.cat(
                    (support['text_len_1'], support['text_len'][i].view(-1)),
                    dim=0)
                support['label_1'] = torch.cat(
                    (support['label_1'], support['label'][i].view(-1)), dim=0)

    # side 2 of each pair: the class-name rows cycled once per support row
    support['text_2'] = class_names_dict['text'][0].view((1, -1))
    support['text_len_2'] = class_names_dict['text_len'][0].view(-1)
    support['label_2'] = class_names_dict['label'][0].view(-1)
    for i in range(text_sample_len):
        if i == 0:
            for j in range(1, len(sampled_classes)):
                support['text_2'] = torch.cat(
                    (support['text_2'], class_names_dict['text'][j].view(
                        (1, -1))),
                    dim=0)
                support['text_len_2'] = torch.cat(
                    (support['text_len_2'],
                     class_names_dict['text_len'][j].view(-1)),
                    dim=0)
                support['label_2'] = torch.cat(
                    (support['label_2'],
                     class_names_dict['label'][j].view(-1)),
                    dim=0)
        else:
            for j in range(len(sampled_classes)):
                support['text_2'] = torch.cat(
                    (support['text_2'], class_names_dict['text'][j].view(
                        (1, -1))),
                    dim=0)
                support['text_len_2'] = torch.cat(
                    (support['text_len_2'],
                     class_names_dict['text_len'][j].view(-1)),
                    dim=0)
                support['label_2'] = torch.cat(
                    (support['label_2'],
                     class_names_dict['label'][j].view(-1)),
                    dim=0)

    # 1 where the pair's example label matches the class-name label, else 0
    support['label_final'] = support['label_1'].eq(support['label_2']).int()

    support_1 = {}
    support_1['text'] = support['text_1']
    support_1['text_len'] = support['text_len_1']
    support_1['label'] = support['label_1']
    support_2 = {}
    support_2['text'] = support['text_2']
    support_2['text_len'] = support['text_len_2']
    support_2['label'] = support['label_2']

    '''first step'''
    # one adaptation step directly from the model's own parameters
    S_out1, S_out2 = model['G'](support_1, support_2)
    supp_, que_ = model['G'](support, query)
    loss_weight = get_weight_of_test_support(supp_, que_, args)
    loss = criterion(S_out1, S_out2, support['label_final'], loss_weight)
    zero_grad(model['G'].parameters())

    # NOTE(review): the loop variable `grad` below shadows the `grad`
    # parameter; also allow_unused=True can yield None grads here, which
    # would raise in `val - args.task_lr * grad` (the later loops guard
    # against None but these first-step loops do not) — confirm intent.
    grads_fc = autograd.grad(loss,
                             model['G'].fc.parameters(),
                             allow_unused=True,
                             retain_graph=True)
    fast_weights_fc, orderd_params_fc = model['G'].cloned_fc_dict(
    ), OrderedDict()
    for (key, val), grad in zip(model['G'].fc.named_parameters(), grads_fc):
        fast_weights_fc[key] = orderd_params_fc[
            key] = val - args.task_lr * grad

    grads_conv11 = autograd.grad(loss,
                                 model['G'].conv11.parameters(),
                                 allow_unused=True,
                                 retain_graph=True)
    fast_weights_conv11, orderd_params_conv11 = model['G'].cloned_conv11_dict(
    ), OrderedDict()
    for (key, val), grad in zip(model['G'].conv11.named_parameters(),
                                grads_conv11):
        fast_weights_conv11[key] = orderd_params_conv11[
            key] = val - args.task_lr * grad

    grads_conv12 = autograd.grad(loss,
                                 model['G'].conv12.parameters(),
                                 allow_unused=True,
                                 retain_graph=True)
    fast_weights_conv12, orderd_params_conv12 = model['G'].cloned_conv12_dict(
    ), OrderedDict()
    for (key, val), grad in zip(model['G'].conv12.named_parameters(),
                                grads_conv12):
        fast_weights_conv12[key] = orderd_params_conv12[
            key] = val - args.task_lr * grad

    # last grad call frees the graph (no retain_graph)
    grads_conv13 = autograd.grad(loss,
                                 model['G'].conv13.parameters(),
                                 allow_unused=True)
    fast_weights_conv13, orderd_params_conv13 = model['G'].cloned_conv13_dict(
    ), OrderedDict()
    for (key, val), grad in zip(model['G'].conv13.named_parameters(),
                                grads_conv13):
        fast_weights_conv13[key] = orderd_params_conv13[
            key] = val - args.task_lr * grad

    fast_weights = {}
    fast_weights['fc'] = fast_weights_fc
    fast_weights['conv11'] = fast_weights_conv11
    fast_weights['conv12'] = fast_weights_conv12
    fast_weights['conv13'] = fast_weights_conv13

    '''steps remaining'''
    # further adaptation steps, differentiating through the fast weights
    for k in range(args.test_iter - 1):
        S_out1, S_out2 = model['G'](support_1, support_2, fast_weights)
        supp_, que_ = model['G'](support, query, fast_weights)
        loss_weight = get_weight_of_test_support(supp_, que_, args)
        loss = criterion(S_out1, S_out2, support['label_final'], loss_weight)
        zero_grad(orderd_params_fc.values())
        zero_grad(orderd_params_conv11.values())
        zero_grad(orderd_params_conv12.values())
        zero_grad(orderd_params_conv13.values())
        grads_fc = torch.autograd.grad(loss,
                                       orderd_params_fc.values(),
                                       allow_unused=True,
                                       retain_graph=True)
        grads_conv11 = torch.autograd.grad(loss,
                                           orderd_params_conv11.values(),
                                           allow_unused=True,
                                           retain_graph=True)
        grads_conv12 = torch.autograd.grad(loss,
                                           orderd_params_conv12.values(),
                                           allow_unused=True,
                                           retain_graph=True)
        grads_conv13 = torch.autograd.grad(loss,
                                           orderd_params_conv13.values(),
                                           allow_unused=True)

        for (key, val), grad in zip(orderd_params_fc.items(), grads_fc):
            if grad is not None:
                fast_weights['fc'][key] = orderd_params_fc[
                    key] = val - args.task_lr * grad
        for (key, val), grad in zip(orderd_params_conv11.items(),
                                    grads_conv11):
            if grad is not None:
                fast_weights['conv11'][key] = orderd_params_conv11[
                    key] = val - args.task_lr * grad
        for (key, val), grad in zip(orderd_params_conv12.items(),
                                    grads_conv12):
            if grad is not None:
                fast_weights['conv12'][key] = orderd_params_conv12[
                    key] = val - args.task_lr * grad
        for (key, val), grad in zip(orderd_params_conv13.items(),
                                    grads_conv13):
            if grad is not None:
                fast_weights['conv13'][key] = orderd_params_conv13[
                    key] = val - args.task_lr * grad

    """计算Q上的损失"""  # score the query set with the adapted fast weights
    CN = model['G'].forward_once_with_param(class_names_dict, fast_weights)
    XQ = model['G'].forward_once_with_param(query, fast_weights)
    logits_q = pos_dist(XQ, CN)
    logits_q = dis_to_level(logits_q)

    _, pred = torch.max(logits_q, 1)
    acc_q = model['G'].accuracy(pred, YQ)

    return acc_q
def test(test_data,
         class_names,
         optG,
         optCLF,
         model,
         args,
         num_episodes,
         verbose=True):
    '''
        Evaluate the model on a bag of sampled tasks.

        For each of `num_episodes` rounds, sample a task's classes, build the
        matching class-name dict, draw episodes from a ParallelSampler, and
        collect per-episode query accuracy from test_one.

        Return the mean accuracy and its std.
    '''
    model['G'].train()
    model['G2'].train()
    model['clf'].train()

    acc = []
    for ep in range(num_episodes):
        sampled_classes, source_classes = task_sampler(test_data, args)

        # meta data for the classes sampled into this round's tasks
        class_names_dict = {}
        class_names_dict['label'] = class_names['label'][sampled_classes]
        class_names_dict['text'] = class_names['text'][sampled_classes]
        class_names_dict['text_len'] = class_names['text_len'][sampled_classes]
        class_names_dict['is_support'] = False

        train_gen = ParallelSampler(test_data, args, sampled_classes,
                                    source_classes, args.train_episodes)
        sampled_tasks = train_gen.get_epoch()
        class_names_dict = utils.to_tensor(class_names_dict,
                                           args.cuda,
                                           exclude_keys=['is_support'])

        grad = {'clf': [], 'G': []}

        if not args.notqdm:
            sampled_tasks = tqdm(sampled_tasks,
                                 total=train_gen.num_episodes,
                                 ncols=80,
                                 leave=False,
                                 desc=colored('Training on train', 'yellow'))

        for task in sampled_tasks:
            if task is None:
                break
            q_acc = test_one(task, class_names_dict, model, optG, optCLF,
                             args, grad)
            acc.append(q_acc.cpu().item())

    acc = np.array(acc)

    if verbose:
        # The original branched on args.embedding != 'mlada' but both
        # branches printed the identical message; the dead branch is removed.
        print("{}, {:s} {:>7.4f}, {:s} {:>7.4f}".format(
            datetime.datetime.now(),
            colored("test acc mean", "blue"),
            np.mean(acc),
            colored("test std", "blue"),
            np.std(acc),
        ),
              flush=True)

    return np.mean(acc), np.std(acc)
def train(train_data, val_data, model, class_names, args):
    '''
        Train the model
        Use val_data to do early stopping

        Returns (optG, optCLF) so callers can keep finetuning/evaluating
        with the same optimizer state.
    '''
    # creating a tmp directory to save the models
    out_dir = os.path.abspath(
        os.path.join(os.path.curdir, "tmp-runs", str(int(time.time() * 1e7))))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    best_acc = 0
    sub_cycle = 0  # epochs-without-improvement counter for early stopping
    best_path = None

    optG = torch.optim.Adam(grad_param(model, ['G']), lr=args.meta_lr)
    optG2 = torch.optim.Adam(grad_param(model, ['G2']), lr=args.task_lr)
    optCLF = torch.optim.Adam(grad_param(model, ['clf']), lr=args.task_lr)

    if args.lr_scheduler == 'ReduceLROnPlateau':
        schedulerG = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optG, 'max', patience=args.patience // 2, factor=0.1, verbose=True)
        schedulerCLF = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optCLF, 'max', patience=args.patience // 2, factor=0.1,
            verbose=True)

    elif args.lr_scheduler == 'ExponentialLR':
        schedulerG = torch.optim.lr_scheduler.ExponentialLR(
            optG, gamma=args.ExponentialLR_gamma)
        schedulerCLF = torch.optim.lr_scheduler.ExponentialLR(
            optCLF, gamma=args.ExponentialLR_gamma)

    print("{}, Start training".format(datetime.datetime.now()), flush=True)

    # train_gen = ParallelSampler(train_data, args, args.train_episodes)
    # train_gen_val = ParallelSampler_Test(train_data, args, args.val_episodes)
    # val_gen = ParallelSampler_Test(val_data, args, args.val_episodes)

    # running accumulators, reset every 200 epochs when stats are reported
    acc = 0
    loss = 0
    for ep in range(args.train_epochs):
        # a fresh task (classes + episodes) is sampled every epoch
        sampled_classes, source_classes = task_sampler(train_data, args)

        class_names_dict = {}
        class_names_dict['label'] = class_names['label'][sampled_classes]
        class_names_dict['text'] = class_names['text'][sampled_classes]
        class_names_dict['text_len'] = class_names['text_len'][sampled_classes]
        class_names_dict['is_support'] = False

        train_gen = ParallelSampler(train_data, args, sampled_classes,
                                    source_classes, args.train_episodes)

        sampled_tasks = train_gen.get_epoch()
        class_names_dict = utils.to_tensor(class_names_dict,
                                           args.cuda,
                                           exclude_keys=['is_support'])

        grad = {'clf': [], 'G': []}

        if not args.notqdm:
            sampled_tasks = tqdm(sampled_tasks,
                                 total=train_gen.num_episodes,
                                 ncols=80,
                                 leave=False,
                                 desc=colored('Training on train', 'yellow'))

        for task in sampled_tasks:
            if task is None:
                break
            q_loss, q_acc = train_one(task, class_names_dict, model, optG,
                                      optG2, optCLF, args, grad)
            acc += q_acc
            loss += q_loss

        if ep % 100 == 0:
            # q_loss/q_acc here are from the LAST episode of this epoch
            print("--------[TRAIN] ep:" + str(ep) + ", loss:" +
                  str(q_loss.item()) + ", acc:" + str(q_acc.item()) +
                  "-----------")

        if (ep % 200 == 0) and (ep != 0):
            # average the accumulators over 200 epochs of train_episodes
            acc = acc / args.train_episodes / 200
            loss = loss / args.train_episodes / 200
            print("--------[TRAIN] ep:" + str(ep) + ", mean_loss:" +
                  str(loss.item()) + ", mean_acc:" + str(acc.item()) +
                  "-----------")

            # evaluate a deep copy so test-time adaptation can't touch the
            # model being trained
            net = copy.deepcopy(model)
            acc, std = test(train_data, class_names, optG, optCLF, net, args,
                            args.test_epochs, False)
            print(
                "[TRAIN] {}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f} ".
                format(
                    datetime.datetime.now(),
                    "ep",
                    ep,
                    colored("train", "red"),
                    colored("acc:", "blue"),
                    acc,
                    std,
                ),
                flush=True)
            acc = 0
            loss = 0

            # Evaluate validation accuracy
            cur_acc, cur_std = test(val_data, class_names, optG, optCLF, net,
                                    args, args.test_epochs, False)
            print(("[EVAL] {}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f}, "
                   "{:s} {:s}{:>7.4f}, {:s}{:>7.4f}").format(
                       datetime.datetime.now(),
                       "ep",
                       ep,
                       colored("val  ", "cyan"),
                       colored("acc:", "blue"),
                       cur_acc,
                       cur_std,
                       colored("train stats", "cyan"),
                       colored("G_grad:", "blue"),
                       np.mean(np.array(grad['G'])),
                       colored("clf_grad:", "blue"),
                       np.mean(np.array(grad['clf'])),
                   ),
                  flush=True)

            # Update the current best model if val acc is better
            if cur_acc > best_acc:
                best_acc = cur_acc
                best_path = os.path.join(out_dir, str(ep))

                # save current model
                print("{}, Save cur best model to {}".format(
                    datetime.datetime.now(), best_path))

                torch.save(model['G'].state_dict(), best_path + '.G')
                torch.save(model['G2'].state_dict(), best_path + '.G2')
                torch.save(model['clf'].state_dict(), best_path + '.clf')

                sub_cycle = 0
            else:
                sub_cycle += 1

            # Break if the val acc hasn't improved in the past patience epochs
            if sub_cycle == args.patience:
                break

            if args.lr_scheduler == 'ReduceLROnPlateau':
                schedulerG.step(cur_acc)
                schedulerCLF.step(cur_acc)

            elif args.lr_scheduler == 'ExponentialLR':
                schedulerG.step()
                schedulerCLF.step()

    print("{}, End of training. Restore the best weights".format(
        datetime.datetime.now()),
          flush=True)

    # restore the best saved model
    model['G'].load_state_dict(torch.load(best_path + '.G'))
    model['G2'].load_state_dict(torch.load(best_path + '.G2'))
    model['clf'].load_state_dict(torch.load(best_path + '.clf'))

    if args.save:
        # save the current model
        out_dir = os.path.abspath(
            os.path.join(os.path.curdir, "saved-runs",
                         str(int(time.time() * 1e7))))
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        best_path = os.path.join(out_dir, 'best')

        print("{}, Save best model to {}".format(datetime.datetime.now(),
                                                 best_path),
              flush=True)

        torch.save(model['G'].state_dict(), best_path + '.G')
        torch.save(model['clf'].state_dict(), best_path + '.clf')

        with open(best_path + '_args.txt', 'w') as f:
            for attr, value in sorted(args.__dict__.items()):
                f.write("{}={}\n".format(attr, value))

    return optG, optCLF
def test_one(task, class_names, model, optG, criterion, args, grad):
    '''
        Adapt the model on one sampled task (MAML-style inner loop over the
        fc fast weights only) and return the query accuracy.
    '''
    support, query = task

    '''分样本对'''  # build all-pairs (row i, row j) over the augmented support
    YS = support['label']
    YQ = query['label']

    sampled_classes = torch.unique(support['label']).cpu().numpy().tolist()

    class_names_dict = {}
    class_names_dict['label'] = class_names['label'][sampled_classes]
    class_names_dict['text'] = class_names['text'][sampled_classes]
    class_names_dict['text_len'] = class_names['text_len'][sampled_classes]
    class_names_dict['is_support'] = False

    class_names_dict = utils.to_tensor(class_names_dict,
                                       args.cuda,
                                       exclude_keys=['is_support'])

    # remap raw labels to 0..N-1 task-local indices
    YS, YQ = reidx_y(args, YS, YQ)

    """维度填充"""  # pad class-name texts up to the support text width
    # NOTE(review): only the support-wider case is handled; if the class-name
    # texts were wider than the support texts, the pad size below would be
    # negative and torch.zeros would raise — confirm support is always wider
    # (the sibling variant of this function handles both directions).
    if support['text'].shape[1] != class_names_dict['text'].shape[1]:
        zero = torch.zeros(
            (class_names_dict['text'].shape[0], support['text'].shape[1] -
             class_names_dict['text'].shape[1]),
            dtype=torch.long)
        class_names_dict['text'] = torch.cat(
            (class_names_dict['text'], zero.cuda()), dim=-1)

    # append the class-name "documents" to the support set
    support['text'] = torch.cat((support['text'], class_names_dict['text']),
                                dim=0)
    support['text_len'] = torch.cat(
        (support['text_len'], class_names_dict['text_len']), dim=0)
    support['label'] = torch.cat((support['label'], class_names_dict['label']),
                                 dim=0)

    text_sample_len = support['text'].shape[0]

    # side 1 of each pair: every support row repeated text_sample_len times
    support['text_1'] = support['text'][0].view((1, -1))
    support['text_len_1'] = support['text_len'][0].view(-1)
    support['label_1'] = support['label'][0].view(-1)
    for i in range(text_sample_len):
        if i == 0:
            # row 0 is already in place, so repeat it only len-1 more times
            for j in range(1, text_sample_len):
                support['text_1'] = torch.cat(
                    (support['text_1'], support['text'][i].view((1, -1))),
                    dim=0)
                support['text_len_1'] = torch.cat(
                    (support['text_len_1'], support['text_len'][i].view(-1)),
                    dim=0)
                support['label_1'] = torch.cat(
                    (support['label_1'], support['label'][i].view(-1)), dim=0)
        else:
            for j in range(text_sample_len):
                support['text_1'] = torch.cat(
                    (support['text_1'], support['text'][i].view((1, -1))),
                    dim=0)
                support['text_len_1'] = torch.cat(
                    (support['text_len_1'], support['text_len'][i].view(-1)),
                    dim=0)
                support['label_1'] = torch.cat(
                    (support['label_1'], support['label'][i].view(-1)), dim=0)

    # side 2 of each pair: the whole support cycled once per row of side 1
    support['text_2'] = support['text'][0].view((1, -1))
    support['text_len_2'] = support['text_len'][0].view(-1)
    support['label_2'] = support['label'][0].view(-1)
    for i in range(text_sample_len):
        if i == 0:
            for j in range(1, text_sample_len):
                support['text_2'] = torch.cat(
                    (support['text_2'], support['text'][j].view((1, -1))),
                    dim=0)
                support['text_len_2'] = torch.cat(
                    (support['text_len_2'], support['text_len'][j].view(-1)),
                    dim=0)
                support['label_2'] = torch.cat(
                    (support['label_2'], support['label'][j].view(-1)), dim=0)
        else:
            for j in range(text_sample_len):
                support['text_2'] = torch.cat(
                    (support['text_2'], support['text'][j].view((1, -1))),
                    dim=0)
                support['text_len_2'] = torch.cat(
                    (support['text_len_2'], support['text_len'][j].view(-1)),
                    dim=0)
                support['label_2'] = torch.cat(
                    (support['label_2'], support['label'][j].view(-1)), dim=0)

    # 1 where the two sides of the pair share a label, else 0
    support['label_final'] = support['label_1'].eq(support['label_2']).int()

    support_1 = {}
    support_1['text'] = support['text_1']
    support_1['text_len'] = support['text_len_1']
    support_1['label'] = support['label_1']
    support_2 = {}
    support_2['text'] = support['text_2']
    support_2['text_len'] = support['text_len_2']
    support_2['label'] = support['label_2']

    '''first step'''
    # one adaptation step directly from the model's own fc parameters.
    # NOTE(review): loop variable `grad` shadows the `grad` parameter, and
    # with allow_unused=True a None grad would raise in the update below
    # (the later loop guards against None but this one does not).
    S_out1, S_out2 = model['G'](support_1, support_2)
    loss = criterion(S_out1, S_out2, support['label_final'])
    zero_grad(model['G'].parameters())
    grads = autograd.grad(loss, model['G'].fc.parameters(), allow_unused=True)
    fast_weights, orderd_params = model['G'].cloned_fc_dict(), OrderedDict()
    for (key, val), grad in zip(model['G'].fc.named_parameters(), grads):
        fast_weights[key] = orderd_params[key] = val - args.task_lr * grad

    '''steps remaining'''
    # NOTE(review): this test-time routine iterates args.train_iter - 1
    # times (not args.test_iter) — confirm which was intended.
    for k in range(args.train_iter - 1):
        S_out1, S_out2 = model['G'](support_1, support_2, fast_weights)
        loss = criterion(S_out1, S_out2, support['label_final'])
        zero_grad(orderd_params.values())
        grads = torch.autograd.grad(loss,
                                    orderd_params.values(),
                                    allow_unused=True)
        for (key, val), grad in zip(orderd_params.items(), grads):
            if grad is not None:
                fast_weights[key] = orderd_params[
                    key] = val - args.task_lr * grad

    """计算Q上的损失"""  # score the query set with the adapted fast weights
    CN = model['G'].forward_once_with_param(class_names_dict, fast_weights)
    XQ = model['G'].forward_once_with_param(query, fast_weights)
    logits_q = neg_dist(XQ, CN)
    _, pred = torch.max(logits_q, 1)
    acc_q = model['G'].accuracy(pred, YQ)

    return acc_q
def train_one(task, class_names, model, optG, criterion, args, grad):
    '''
        Train the model on one sampled few-shot task using a MAML-style
        inner loop over the encoder's fc head and a pairwise (Siamese)
        contrastive loss on the support set.

        Args:
            task:        (support, query) pair of dicts with 'text',
                         'text_len', 'label' tensors.
            class_names: dict of per-class name 'text', 'text_len', 'label'.
            model:       dict holding the encoder under key 'G'; the encoder
                         exposes fc, cloned_fc_dict(), forward_once_with_param(),
                         loss() and accuracy().
            optG:        optimizer for model['G'] (outer/meta update).
            criterion:   contrastive loss taking (out1, out2, pair_label).
            args:        needs .cuda, .task_lr, .train_iter.
            grad:        unused here; kept so all train_one variants share
                         one signature.

        Returns:
            (q_loss, acc_q): query-set loss tensor and accuracy.
    '''
    model['G'].train()

    support, query = task

    YS = support['label']
    YQ = query['label']

    # Restrict the class-name bank to the classes present in this task.
    sampled_classes = torch.unique(support['label']).cpu().numpy().tolist()
    class_names_dict = {}
    class_names_dict['label'] = class_names['label'][sampled_classes]
    class_names_dict['text'] = class_names['text'][sampled_classes]
    class_names_dict['text_len'] = class_names['text_len'][sampled_classes]
    class_names_dict['is_support'] = False
    class_names_dict = utils.to_tensor(class_names_dict, args.cuda,
                                       exclude_keys=['is_support'])

    YS, YQ = reidx_y(args, YS, YQ)  # remap task labels to 0..N-1

    # Pad the class-name token matrix to the support width so the two can be
    # concatenated.  NOTE(review): assumes support['text'] is at least as wide
    # as class_names_dict['text'] (and that CUDA is available) — confirm
    # against the data pipeline.
    if support['text'].shape[1] != class_names_dict['text'].shape[1]:
        zero = torch.zeros((class_names_dict['text'].shape[0],
                            support['text'].shape[1]
                            - class_names_dict['text'].shape[1]),
                           dtype=torch.long)
        class_names_dict['text'] = torch.cat(
            (class_names_dict['text'], zero.cuda()), dim=-1)

    # Append the class-name "examples" to the support set.
    support['text'] = torch.cat((support['text'], class_names_dict['text']),
                                dim=0)
    support['text_len'] = torch.cat(
        (support['text_len'], class_names_dict['text_len']), dim=0)
    support['label'] = torch.cat((support['label'],
                                  class_names_dict['label']), dim=0)

    n = support['text'].shape[0]

    # Build all n*n ordered pairs (cartesian product).  The original code did
    # this with duplicated nested Python loops of torch.cat calls; these two
    # vectorized ops produce the identical row ordering:
    #   side 1: each row repeated n times consecutively
    #   side 2: the whole row set cycled n times
    support['text_1'] = support['text'].repeat_interleave(n, dim=0)
    support['text_len_1'] = support['text_len'].repeat_interleave(n)
    support['label_1'] = support['label'].repeat_interleave(n)

    support['text_2'] = support['text'].repeat(n, 1)
    support['text_len_2'] = support['text_len'].repeat(n)
    support['label_2'] = support['label'].repeat(n)

    # 1 where the pair shares a class label, 0 otherwise.
    support['label_final'] = support['label_1'].eq(support['label_2']).int()

    support_1 = {'text': support['text_1'],
                 'text_len': support['text_len_1'],
                 'label': support['label_1']}
    support_2 = {'text': support['text_2'],
                 'text_len': support['text_len_2'],
                 'label': support['label_2']}

    # --- MAML inner loop: first step differentiates the model's own fc ---
    S_out1, S_out2 = model['G'](support_1, support_2)
    loss = criterion(S_out1, S_out2, support['label_final'])
    zero_grad(model['G'].parameters())
    grads = autograd.grad(loss, model['G'].fc.parameters(), allow_unused=True)
    fast_weights, orderd_params = model['G'].cloned_fc_dict(), OrderedDict()
    for (key, val), g in zip(model['G'].fc.named_parameters(), grads):
        fast_weights[key] = orderd_params[key] = val - args.task_lr * g

    # --- Remaining steps differentiate the fast weights themselves ---
    for k in range(args.train_iter - 1):
        S_out1, S_out2 = model['G'](support_1, support_2, fast_weights)
        loss = criterion(S_out1, S_out2, support['label_final'])
        zero_grad(orderd_params.values())
        grads = torch.autograd.grad(loss, orderd_params.values(),
                                    allow_unused=True)
        for (key, val), g in zip(orderd_params.items(), grads):
            if g is not None:
                fast_weights[key] = orderd_params[key] = \
                    val - args.task_lr * g

    # --- Outer (meta) loss on the query set, scored against class names ---
    CN = model['G'].forward_once_with_param(class_names_dict, fast_weights)
    XQ = model['G'].forward_once_with_param(query, fast_weights)
    logits_q = neg_dist(XQ, CN)  # negative distance to each class embedding
    q_loss = model['G'].loss(logits_q, YQ)

    _, pred = torch.max(logits_q, 1)
    acc_q = model['G'].accuracy(pred, YQ)

    optG.zero_grad()
    q_loss.backward()
    optG.step()

    return q_loss, acc_q
def train_one(task, class_names, model, optG, optCLF, args, grad):
    '''
        Train the model on one sampled task.

        Embeds support, query and class-name texts with model['G'], then
        alternates: fit the MLP head model['clf'] on the support set for
        args.train_iter steps (optCLF), and take one encoder step (optG)
        on the query loss.  `grad` is unused; kept for signature parity.

        Returns (g_loss, acc_q): query loss tensor and query accuracy.
    '''
    model['G'].train()
    model['clf'].train()

    support, query = task

    # Keep only the class names belonging to this task's sampled classes.
    sampled_classes = torch.unique(support['label']).cpu().numpy().tolist()
    class_names_dict = {
        'label': class_names['label'][sampled_classes],
        'text': class_names['text'][sampled_classes],
        'text_len': class_names['text_len'][sampled_classes],
        'is_support': False,
    }
    class_names_dict = utils.to_tensor(class_names_dict, args.cuda,
                                       exclude_keys=['is_support'])

    # Encode documents and class names once; the inner loop reuses them.
    support_ebd = model['G'](support)          # [N*K, hidden]
    support_y = support['label']
    names_ebd = model['G'](class_names_dict)   # [N, hidden]
    query_ebd = model['G'](query)
    query_y = query['label']

    support_y, query_y = reidx_y(args, support_y, query_y)

    # Step 1: adapt the MLP head on the support set.
    for _ in range(args.train_iter):
        head_support = model['clf'](support_ebd)
        head_names = model['clf'](names_ebd)
        scores = neg_dist(head_support, head_names)  # [N*K, N]
        mlp_loss = model['clf'].loss(scores, support_y)
        optCLF.zero_grad()
        mlp_loss.backward(retain_graph=True)
        optCLF.step()

    # Step 2: update the encoder with the query loss.
    head_query = model['clf'](query_ebd)
    head_names = model['clf'](names_ebd)
    scores = neg_dist(head_query, head_names)
    g_loss = model['clf'].loss(scores, query_y)
    optG.zero_grad()
    g_loss.backward()
    optG.step()

    _, pred = torch.max(scores, 1)
    acc_q = model['clf'].accuracy(pred, query_y)

    return g_loss, acc_q