train_iter = iter(train_loader) num_iter = len(train_loader) model = AttenLSTM(vocab, len(meshlabel_to_ix), char_to_ix) if torch.cuda.is_available(): model.cuda(opt.gpu) optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate) criterion = nn.CrossEntropyLoss() if opt.fine_tune: utils.unfreeze_net(model.embedding) else: utils.freeze_net(model.embedding) #pre-training dictionary instance if opt.pretraining: dict_iter = iter(dict_loader) dict_num_iter = len(dict_loader) #start training dictionary logging.info("batch_size: %s, dict_num_iter %s, train num_iter %s" % (str(opt.batch_size), str(dict_num_iter), str(num_iter))) for epoch in range(opt.dict_iteration): epoch_start = time.time() # sum_dict_cost = 0.0 correct_1, total_1 = 0, 0
def train(vocabs, char_vocab, tag_vocab, train_sets, dev_sets, test_sets, unlabeled_sets):
    """
    Adversarially train a multilingual tagger with shared (F_s) and private (F_p)
    feature extractors, a tagger head C, and a language discriminator D.

    train_sets, dev_sets, test_sets: dict[lang] -> AmazonDataset
    For unlabeled langs, no train_sets are available.

    Returns a dict with the best 'valid' and 'test' per-language accuracies
    (or, in test-only mode, the accuracies of the loaded checkpoint).
    """
    # dataset loaders
    train_loaders, unlabeled_loaders = {}, {}
    train_iters, unlabeled_iters, d_unlabeled_iters = {}, {}, {}
    dev_loaders, test_loaders = {}, {}
    # sorted_collate is needed for LSTM packing; otherwise order is irrelevant
    my_collate = utils.sorted_collate if opt.model=='lstm' else utils.unsorted_collate
    for lang in opt.langs:
        train_loaders[lang] = DataLoader(train_sets[lang],
                opt.batch_size, shuffle=True, collate_fn = my_collate)
        train_iters[lang] = iter(train_loaders[lang])
    for lang in opt.dev_langs:
        dev_loaders[lang] = DataLoader(dev_sets[lang],
                opt.batch_size, shuffle=False, collate_fn = my_collate)
        test_loaders[lang] = DataLoader(test_sets[lang],
                opt.batch_size, shuffle=False, collate_fn = my_collate)
    for lang in opt.all_langs:
        if lang in opt.unlabeled_langs:
            uset = unlabeled_sets[lang]
        else:
            # for labeled langs, consider which data to use as unlabeled set
            if opt.unlabeled_data == 'both':
                uset = ConcatDataset([train_sets[lang], unlabeled_sets[lang]])
            elif opt.unlabeled_data == 'unlabeled':
                uset = unlabeled_sets[lang]
            elif opt.unlabeled_data == 'train':
                uset = train_sets[lang]
            else:
                raise Exception(f'Unknown options for the unlabeled data usage: {opt.unlabeled_data}')
        unlabeled_loaders[lang] = DataLoader(uset,
                opt.batch_size, shuffle=True, collate_fn = my_collate)
        # two independent iterators over the same loader: one for the D step,
        # one for the feature-extractor (F) adversarial step
        unlabeled_iters[lang] = iter(unlabeled_loaders[lang])
        d_unlabeled_iters[lang] = iter(unlabeled_loaders[lang])

    # embeddings
    emb = MultiLangWordEmb(vocabs, char_vocab, opt.use_wordemb, opt.use_charemb).to(opt.device)
    # models
    F_s = None
    F_p = None
    C, D = None, None
    # one expert per labeled language, plus an optional shared/private expert
    num_experts = len(opt.langs)+1 if opt.expert_sp else len(opt.langs)
    if opt.model.lower() == 'lstm':
        if opt.shared_hidden_size > 0:
            F_s = LSTMFeatureExtractor(opt.total_emb_size, opt.F_layers, opt.shared_hidden_size,
                    opt.word_dropout, opt.dropout, opt.bdrnn)
        if opt.private_hidden_size > 0:
            if not opt.concat_sp:
                # additive shared+private fusion requires matching dims
                assert opt.shared_hidden_size == opt.private_hidden_size, "shared dim != private dim when using add_sp!"
            F_p = nn.Sequential(
                LSTMFeatureExtractor(opt.total_emb_size, opt.F_layers, opt.private_hidden_size,
                        opt.word_dropout, opt.dropout, opt.bdrnn),
                MixtureOfExperts(opt.MoE_layers, opt.private_hidden_size,
                        len(opt.langs), opt.private_hidden_size,
                        opt.private_hidden_size, opt.dropout, opt.MoE_bn, False)
            )
    else:
        raise Exception(f'Unknown model architecture {opt.model}')
    if opt.C_MoE:
        C = SpMixtureOfExperts(opt.C_layers, opt.shared_hidden_size,
                opt.private_hidden_size, opt.concat_sp, num_experts,
                opt.shared_hidden_size + opt.private_hidden_size,
                len(tag_vocab), opt.mlp_dropout, opt.C_bn)
    else:
        C = SpMlpTagger(opt.C_layers, opt.shared_hidden_size,
                opt.private_hidden_size, opt.concat_sp,
                opt.shared_hidden_size + opt.private_hidden_size,
                len(tag_vocab), opt.mlp_dropout, opt.C_bn)
    # D is only built when there is a shared space to make language-invariant
    # and adversarial training is actually enabled (n_critic > 0)
    if opt.shared_hidden_size > 0 and opt.n_critic > 0:
        if opt.D_model.lower() == 'lstm':
            d_args = {
                'num_layers': opt.D_lstm_layers,
                'input_size': opt.shared_hidden_size,
                'hidden_size': opt.shared_hidden_size,
                'word_dropout': opt.D_word_dropout,
                'dropout': opt.D_dropout,
                'bdrnn': opt.D_bdrnn,
                'attn_type': opt.D_attn
            }
        elif opt.D_model.lower() == 'cnn':
            d_args = {
                'num_layers': 1,
                'input_size': opt.shared_hidden_size,
                'hidden_size': opt.shared_hidden_size,
                'kernel_num': opt.D_kernel_num,
                'kernel_sizes': opt.D_kernel_sizes,
                'word_dropout': opt.D_word_dropout,
                'dropout': opt.D_dropout
            }
        else:
            d_args = None
        if opt.D_model.lower() == 'mlp':
            # token-level discriminator
            D = MLPLanguageDiscriminator(opt.D_layers, opt.shared_hidden_size,
                    opt.shared_hidden_size, len(opt.all_langs), opt.loss,
                    opt.D_dropout, opt.D_bn)
        else:
            # sentence-level discriminator
            D = LanguageDiscriminator(opt.D_model, opt.D_layers,
                    opt.shared_hidden_size, opt.shared_hidden_size,
                    len(opt.all_langs), opt.D_dropout, opt.D_bn, d_args)
    if opt.use_data_parallel:
        F_s, C, D = nn.DataParallel(F_s).to(opt.device) if F_s else None, \
                nn.DataParallel(C).to(opt.device), \
                nn.DataParallel(D).to(opt.device) if D else None
    else:
        F_s, C, D = F_s.to(opt.device) if F_s else None, \
                C.to(opt.device), D.to(opt.device) if D else None
    if F_p:
        if opt.use_data_parallel:
            F_p = nn.DataParallel(F_p).to(opt.device)
        else:
            F_p = F_p.to(opt.device)

    # optimizers: one joint optimizer for emb/F_s/C/F_p, a separate one for D
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, itertools.chain(*map(list,
            [emb.parameters(), F_s.parameters() if F_s else [], \
            C.parameters(), F_p.parameters() if F_p else []]))),
            lr=opt.learning_rate, weight_decay=opt.weight_decay)
    if D:
        optimizerD = optim.Adam(D.parameters(), lr=opt.D_learning_rate, weight_decay=opt.D_weight_decay)

    # testing
    if opt.test_only:
        log.info(f'Loading model from {opt.model_save_file}...')
        if F_s:
            F_s.load_state_dict(torch.load(os.path.join(opt.model_save_file, f'netF_s.pth')))
        # NOTE(review): this loop loads the same net_F_p.pth once per language;
        # a single load outside the loop would be equivalent — confirm intent
        for lang in opt.all_langs:
            F_p.load_state_dict(torch.load(os.path.join(opt.model_save_file, f'net_F_p.pth')))
        C.load_state_dict(torch.load(os.path.join(opt.model_save_file, f'netC.pth')))
        if D:
            D.load_state_dict(torch.load(os.path.join(opt.model_save_file, f'netD.pth')))
        log.info('Evaluating validation sets:')
        acc = {}
        log.info(dev_loaders)
        log.info(vocabs)
        # NOTE(review): dev_loaders/test_loaders were only built for opt.dev_langs;
        # iterating opt.all_langs here raises KeyError unless all_langs == dev_langs
        for lang in opt.all_langs:
            acc[lang] = evaluate(f'{lang}_dev', dev_loaders[lang], vocabs[lang], tag_vocab,
                    emb, lang, F_s, F_p, C)
        avg_acc = sum([acc[d] for d in opt.dev_langs]) / len(opt.dev_langs)
        log.info(f'Average validation accuracy: {avg_acc}')
        log.info('Evaluating test sets:')
        test_acc = {}
        for lang in opt.all_langs:
            test_acc[lang] = evaluate(f'{lang}_test', test_loaders[lang], vocabs[lang], tag_vocab,
                    emb, lang, F_s, F_p, C)
        avg_test_acc = sum([test_acc[d] for d in opt.dev_langs]) / len(opt.dev_langs)
        log.info(f'Average test accuracy: {avg_test_acc}')
        return {'valid': acc, 'test': test_acc}

    # training
    best_acc, best_avg_acc = defaultdict(float), 0.0
    epochs_since_decay = 0
    # lambda scheduling: remember the initial lambda so it can be scaled up later
    if opt.lambd > 0 and opt.lambd_schedule:
        opt.lambd_orig = opt.lambd
    # one "epoch" = geometric mean of the per-language loader lengths
    num_iter = int(utils.gmean([len(train_loaders[l]) for l in opt.langs]))
    # adapt max_epoch so that at least ~15000 iterations are run in total
    if opt.max_epoch > 0 and num_iter * opt.max_epoch < 15000:
        opt.max_epoch = 15000 // num_iter
        log.info(f"Setting max_epoch to {opt.max_epoch}")

    for epoch in range(opt.max_epoch):
        emb.train()
        if F_s: F_s.train()
        C.train()
        if D: D.train()
        if F_p: F_p.train()

        # lambda scheduling: x1 at epoch 0, x10 at epoch 5, x100 at epoch 15
        if hasattr(opt, 'lambd_orig') and opt.lambd_schedule:
            if epoch == 0:
                opt.lambd = opt.lambd_orig
            elif epoch == 5:
                opt.lambd = 10 * opt.lambd_orig
            elif epoch == 15:
                opt.lambd = 100 * opt.lambd_orig
            log.info(f'Scheduling lambda = {opt.lambd}')

        # training accuracy
        correct, total = defaultdict(int), defaultdict(int)
        gate_correct = defaultdict(int)
        c_gate_correct = defaultdict(int)
        # D accuracy
        d_correct, d_total = 0, 0
        for i in tqdm(range(num_iter), ascii=True):
            # D iterations: train discriminator with everything else frozen
            if opt.shared_hidden_size > 0:
                utils.freeze_net(emb)
                utils.freeze_net(F_s)
                utils.freeze_net(F_p)
                utils.freeze_net(C)
                utils.unfreeze_net(D)
                # WGAN n_critic trick since D trains slower
                n_critic = opt.n_critic
                if opt.wgan_trick:
                    if opt.n_critic>0 and ((epoch==0 and i<25) or i%500==0):
                        n_critic = 100
                for _ in range(n_critic):
                    D.zero_grad()
                    loss_d = {}
                    lang_features = {}
                    # train on both labeled and unlabeled langs
                    for lang in opt.all_langs:
                        # targets not used
                        d_inputs, _ = utils.endless_get_next_batch(
                            unlabeled_loaders, d_unlabeled_iters, lang)
                        d_inputs, d_lengths, mask, d_chars, d_char_lengths = d_inputs
                        d_embeds = emb(lang, d_inputs, d_chars, d_char_lengths)
                        shared_feat = F_s((d_embeds, d_lengths))
                        if opt.grad_penalty != 'none':
                            # detach: the penalty is for D only, not F_s
                            lang_features[lang] = shared_feat.detach()
                        if opt.D_model.lower() == 'mlp':
                            d_outputs = D(shared_feat)
                            # if token-level D, we can reuse the gate label generator
                            d_targets = utils.get_gate_label(d_outputs, lang, mask, False, all_langs=True)
                            d_total += torch.sum(d_lengths).item()
                        else:
                            d_outputs = D((shared_feat, d_lengths))
                            d_targets = utils.get_lang_label(opt.loss, lang, len(d_lengths))
                            d_total += len(d_lengths)
                        # D accuracy
                        _, pred = torch.max(d_outputs, -1)
                        # d_total += len(d_lengths)
                        d_correct += (pred==d_targets).sum().item()
                        if opt.use_data_parallel:
                            l_d = functional.nll_loss(d_outputs.view(-1, D.module.num_langs),
                                    d_targets.view(-1), ignore_index=-1)
                        else:
                            l_d = functional.nll_loss(d_outputs.view(-1, D.num_langs),
                                    d_targets.view(-1), ignore_index=-1)
                        l_d.backward()
                        loss_d[lang] = l_d.item()
                    # gradient penalty
                    if opt.grad_penalty != 'none':
                        gp = utils.calc_gradient_penalty(D, lang_features,
                                onesided=opt.onesided_gp, interpolate=(opt.grad_penalty=='wgan'))
                        gp.backward()
                    optimizerD.step()

            # F&C iteration: train everything except D
            utils.unfreeze_net(emb)
            # keep pretrained embeddings fixed when requested
            if opt.use_wordemb and opt.fix_emb:
                for lang in emb.langs:
                    emb.wordembs[lang].weight.requires_grad = False
            if opt.use_charemb and opt.fix_charemb:
                emb.charemb.weight.requires_grad = False
            utils.unfreeze_net(F_s)
            utils.unfreeze_net(F_p)
            utils.unfreeze_net(C)
            utils.freeze_net(D)
            emb.zero_grad()
            if F_s: F_s.zero_grad()
            if F_p: F_p.zero_grad()
            C.zero_grad()
            # optimizer.zero_grad()
            for lang in opt.langs:
                inputs, targets = utils.endless_get_next_batch(
                    train_loaders, train_iters, lang)
                inputs, lengths, mask, chars, char_lengths = inputs
                bs, seq_len = inputs.size()
                embeds = emb(lang, inputs, chars, char_lengths)
                shared_feat, private_feat = None, None
                if opt.shared_hidden_size > 0:
                    shared_feat = F_s((embeds, lengths))
                if opt.private_hidden_size > 0:
                    private_feat, gate_outputs = F_p((embeds, lengths))
                if opt.C_MoE:
                    c_outputs, c_gate_outputs = C((shared_feat, private_feat))
                else:
                    c_outputs = C((shared_feat, private_feat))
                # targets are padded with -1
                l_c = functional.nll_loss(c_outputs.view(bs*seq_len, -1),
                        targets.view(-1), ignore_index=-1)
                # gate loss: supervise the F_p MoE gate towards the true language
                if F_p:
                    gate_targets = utils.get_gate_label(gate_outputs, lang, mask, False)
                    l_gate = functional.cross_entropy(gate_outputs.view(bs*seq_len, -1),
                            gate_targets.view(-1), ignore_index=-1)
                    l_c += opt.gate_loss_weight * l_gate
                    _, gate_pred = torch.max(gate_outputs.view(bs*seq_len, -1), -1)
                    gate_correct[lang] += (gate_pred == gate_targets.view(-1)).sum().item()
                # tagger-gate loss (only for the MoE tagger head)
                if opt.C_MoE and opt.C_gate_loss_weight > 0:
                    c_gate_targets = utils.get_gate_label(c_gate_outputs, lang, mask, opt.expert_sp)
                    _, c_gate_pred = torch.max(c_gate_outputs.view(bs*seq_len, -1), -1)
                    if opt.expert_sp:
                        # multi-label targets: BCE over masked gate logits
                        l_c_gate = functional.binary_cross_entropy_with_logits(
                                mask.unsqueeze(-1) * c_gate_outputs, c_gate_targets)
                        c_gate_correct[lang] += torch.index_select(
                                c_gate_targets.view(bs*seq_len, -1),
                                -1, c_gate_pred.view(bs*seq_len)).sum().item()
                    else:
                        l_c_gate = functional.cross_entropy(c_gate_outputs.view(bs*seq_len, -1),
                                c_gate_targets.view(-1), ignore_index=-1)
                        c_gate_correct[lang] += (c_gate_pred == c_gate_targets.view(-1)).sum().item()
                    l_c += opt.C_gate_loss_weight * l_c_gate
                l_c.backward()
                _, pred = torch.max(c_outputs, -1)
                # padded positions never match the -1 targets, so only real
                # tokens can count as correct; total counts real tokens only
                total[lang] += torch.sum(lengths).item()
                correct[lang] += (pred == targets).sum().item()

            # update F with D gradients on all langs (adversarial step)
            if D:
                for lang in opt.all_langs:
                    inputs, _ = utils.endless_get_next_batch(
                        unlabeled_loaders, unlabeled_iters, lang)
                    inputs, lengths, mask, chars, char_lengths = inputs
                    embeds = emb(lang, inputs, chars, char_lengths)
                    shared_feat = F_s((embeds, lengths))
                    # d_outputs = D((shared_feat, lengths))
                    if opt.D_model.lower() == 'mlp':
                        d_outputs = D(shared_feat)
                        # if token-level D, we can reuse the gate label generator
                        d_targets = utils.get_gate_label(d_outputs, lang, mask, False, all_langs=True)
                    else:
                        d_outputs = D((shared_feat, lengths))
                        d_targets = utils.get_lang_label(opt.loss, lang, len(lengths))
                    if opt.use_data_parallel:
                        l_d = functional.nll_loss(d_outputs.view(-1, D.module.num_langs),
                                d_targets.view(-1), ignore_index=-1)
                    else:
                        l_d = functional.nll_loss(d_outputs.view(-1, D.num_langs),
                                d_targets.view(-1), ignore_index=-1)
                    # negate the loss: F_s is trained to FOOL the discriminator
                    if opt.lambd > 0:
                        l_d *= -opt.lambd
                    l_d.backward()
            optimizer.step()

        # end of epoch
        log.info('Ending epoch {}'.format(epoch+1))
        if d_total > 0:
            log.info('D Training Accuracy: {}%'.format(100.0*d_correct/d_total))
        log.info('Training accuracy:')
        log.info('\t'.join(opt.langs))
        log.info('\t'.join([str(100.0*correct[d]/total[d]) for d in opt.langs]))
        log.info('Gate accuracy:')
        log.info('\t'.join([str(100.0*gate_correct[d]/total[d]) for d in opt.langs]))
        log.info('Tagger Gate accuracy:')
        log.info('\t'.join([str(100.0*c_gate_correct[d]/total[d]) for d in opt.langs]))
        log.info('Evaluating validation sets:')
        acc = {}
        for lang in opt.dev_langs:
            acc[lang] = evaluate(f'{lang}_dev', dev_loaders[lang], vocabs[lang], tag_vocab,
                    emb, lang, F_s, F_p, C)
        avg_acc = sum([acc[d] for d in opt.dev_langs]) / len(opt.dev_langs)
        log.info(f'Average validation accuracy: {avg_acc}')
        log.info('Evaluating test sets:')
        test_acc = {}
        for lang in opt.dev_langs:
            test_acc[lang] = evaluate(f'{lang}_test', test_loaders[lang], vocabs[lang], tag_vocab,
                    emb, lang, F_s, F_p, C)
        avg_test_acc = sum([test_acc[d] for d in opt.dev_langs]) / len(opt.dev_langs)
        log.info(f'Average test accuracy: {avg_test_acc}')
        # checkpoint on new best validation accuracy; otherwise count towards LR decay
        if avg_acc > best_avg_acc:
            epochs_since_decay = 0
            log.info(f'New best average validation accuracy: {avg_acc}')
            best_acc['valid'] = acc
            best_acc['test'] = test_acc
            best_avg_acc = avg_acc
            with open(os.path.join(opt.model_save_file, 'options.pkl'), 'wb') as ouf:
                pickle.dump(opt, ouf)
            if F_s:
                torch.save(F_s.state_dict(), '{}/netF_s.pth'.format(opt.model_save_file))
            torch.save(emb.state_dict(), '{}/net_emb.pth'.format(opt.model_save_file))
            if F_p:
                torch.save(F_p.state_dict(), '{}/net_F_p.pth'.format(opt.model_save_file))
            torch.save(C.state_dict(), '{}/netC.pth'.format(opt.model_save_file))
            if D:
                torch.save(D.state_dict(), '{}/netD.pth'.format(opt.model_save_file))
        else:
            epochs_since_decay += 1
            if opt.lr_decay < 1 and epochs_since_decay >= opt.lr_decay_epochs:
                epochs_since_decay = 0
                old_lr = optimizer.param_groups[0]['lr']
                optimizer.param_groups[0]['lr'] = old_lr * opt.lr_decay
                log.info(f'Decreasing LR to {old_lr * opt.lr_decay}')

    # end of training
    log.info(f'Best average validation accuracy: {best_avg_acc}')
    return best_acc
def train(vocab, train_sets, dev_sets, test_sets, unlabeled_sets):
    """
    Adversarially train a shared feature extractor F_s, sentiment classifier C,
    and domain discriminator D (no per-domain private extractors here — the
    private slot of C's input is zero-filled).

    train_sets, dev_sets, test_sets: dict[domain] -> AmazonDataset
    For unlabeled domains, no train_sets are available.

    Returns the best 'valid'/'test' per-domain accuracies seen during training.
    """
    # dataset loaders
    train_loaders, unlabeled_loaders = {}, {}
    train_iters, unlabeled_iters = {}, {}
    dev_loaders, test_loaders = {}, {}
    # sorted_collate is needed for LSTM packing; otherwise order is irrelevant
    my_collate = utils.sorted_collate if opt.model == 'lstm' else utils.unsorted_collate
    for domain in opt.domains:
        train_loaders[domain] = DataLoader(train_sets[domain],
                opt.batch_size, shuffle=True, collate_fn=my_collate)
        train_iters[domain] = iter(train_loaders[domain])
    for domain in opt.dev_domains:
        dev_loaders[domain] = DataLoader(dev_sets[domain],
                opt.batch_size, shuffle=False, collate_fn=my_collate)
        test_loaders[domain] = DataLoader(test_sets[domain],
                opt.batch_size, shuffle=False, collate_fn=my_collate)
    for domain in opt.all_domains:
        if domain in opt.unlabeled_domains:
            uset = unlabeled_sets[domain]
        else:
            # for labeled domains, consider which data to use as unlabeled set
            if opt.unlabeled_data == 'both':
                uset = ConcatDataset([train_sets[domain], unlabeled_sets[domain]])
            elif opt.unlabeled_data == 'unlabeled':
                uset = unlabeled_sets[domain]
            elif opt.unlabeled_data == 'train':
                uset = train_sets[domain]
            else:
                raise Exception('Unknown options for the unlabeled data usage: {}'.format(opt.unlabeled_data))
        unlabeled_loaders[domain] = DataLoader(uset,
                opt.batch_size, shuffle=True, collate_fn=my_collate)
        unlabeled_iters[domain] = iter(unlabeled_loaders[domain])

    # models
    F_s = None
    C, D = None, None
    if opt.model.lower() == 'dan':
        F_s = DanFeatureExtractor(vocab, opt.F_layers, opt.shared_hidden_size,
                opt.sum_pooling, opt.dropout, opt.F_bn)
    elif opt.model.lower() == 'lstm':
        F_s = LSTMFeatureExtractor(vocab, opt.F_layers, opt.shared_hidden_size,
                opt.dropout, opt.bdrnn, opt.attn)
    elif opt.model.lower() == 'cnn':
        F_s = CNNFeatureExtractor(vocab, opt.F_layers, opt.shared_hidden_size,
                opt.kernel_num, opt.kernel_sizes, opt.dropout)
    else:
        raise Exception('Unknown model architecture {}'.format(opt.model))
    # C's input is shared + domain features even though domain features are
    # zero here — presumably to stay weight-compatible with the F_d variant
    C = SentimentClassifier(opt.C_layers, opt.shared_hidden_size + opt.domain_hidden_size,
            opt.shared_hidden_size + opt.domain_hidden_size,
            opt.num_labels, opt.dropout, opt.C_bn)
    D = DomainClassifier(opt.D_layers, opt.shared_hidden_size, opt.shared_hidden_size,
            len(opt.all_domains), opt.loss, opt.dropout, opt.D_bn)
    F_s, C, D = F_s.to(opt.device), C.to(opt.device), D.to(opt.device)

    # optimizers
    # NOTE(review): the trailing "+ []" looks vestigial (likely where
    # F_d parameter lists are appended in the sibling implementation)
    optimizer = optim.Adam(itertools.chain(
            *map(list, [F_s.parameters() if F_s else [], C.parameters()] + [])),
            lr=opt.learning_rate)
    optimizerD = optim.Adam(D.parameters(), lr=opt.D_learning_rate)

    # training
    best_acc, best_avg_acc = defaultdict(float), 0.0
    for epoch in range(opt.max_epoch):
        F_s.train()
        C.train()
        D.train()
        # training accuracy
        correct, total = defaultdict(int), defaultdict(int)
        # D accuracy
        d_correct, d_total = 0, 0
        # conceptually view 1 epoch as 1 epoch of the first domain
        num_iter = len(train_loaders[opt.domains[0]])
        for i in tqdm(range(num_iter)):
            # D iterations: train discriminator with F_s and C frozen
            utils.freeze_net(F_s)
            utils.freeze_net(C)
            utils.unfreeze_net(D)
            # WGAN n_critic trick since D trains slower
            n_critic = opt.n_critic
            if opt.wgan_trick:
                if opt.n_critic > 0 and ((epoch == 0 and i < 25) or i % 500 == 0):
                    n_critic = 100
            for _ in range(n_critic):
                D.zero_grad()
                loss_d = {}
                # train on both labeled and unlabeled domains
                for domain in opt.all_domains:
                    # targets not used
                    d_inputs, _ = utils.endless_get_next_batch(
                        unlabeled_loaders, unlabeled_iters, domain)
                    # d_inputs[1] holds the per-sample lengths, hence the batch size
                    d_targets = utils.get_domain_label(opt.loss, domain, len(d_inputs[1]))
                    shared_feat = F_s(d_inputs)
                    d_outputs = D(shared_feat)
                    # D accuracy
                    _, pred = torch.max(d_outputs, 1)
                    d_total += len(d_inputs[1])
                    if opt.loss.lower() == 'l2':
                        # L2 targets are one-hot-like vectors; recover the index
                        _, tgt_indices = torch.max(d_targets, 1)
                        d_correct += (pred == tgt_indices).sum().item()
                        l_d = functional.mse_loss(d_outputs, d_targets)
                        l_d.backward()
                    else:
                        d_correct += (pred == d_targets).sum().item()
                        l_d = functional.nll_loss(d_outputs, d_targets)
                        l_d.backward()
                    loss_d[domain] = l_d.item()
                optimizerD.step()

            # F&C iteration: train F_s and C with D frozen
            utils.unfreeze_net(F_s)
            utils.unfreeze_net(C)
            utils.freeze_net(D)
            if opt.fix_emb:
                utils.freeze_net(F_s.word_emb)
            F_s.zero_grad()
            C.zero_grad()
            for domain in opt.domains:
                inputs, targets = utils.endless_get_next_batch(
                    train_loaders, train_iters, domain)
                targets = targets.to(opt.device)
                shared_feat = F_s(inputs)
                # no private extractor in this variant: zero-fill the domain slot
                domain_feat = torch.zeros(len(targets), opt.domain_hidden_size).to(opt.device)
                features = torch.cat((shared_feat, domain_feat), dim=1)
                c_outputs = C(features)
                l_c = functional.nll_loss(c_outputs, targets)
                # NOTE(review): retain_graph=True is presumably unnecessary here
                # since each batch builds a fresh graph — confirm before removing
                l_c.backward(retain_graph=True)
                _, pred = torch.max(c_outputs, 1)
                total[domain] += targets.size(0)
                correct[domain] += (pred == targets).sum().item()

            # update F with D gradients on all domains (adversarial step)
            for domain in opt.all_domains:
                d_inputs, _ = utils.endless_get_next_batch(
                    unlabeled_loaders, unlabeled_iters, domain)
                shared_feat = F_s(d_inputs)
                d_outputs = D(shared_feat)
                if opt.loss.lower() == 'gr':
                    # gradient-reversal style: maximize D's loss via negation
                    d_targets = utils.get_domain_label(opt.loss, domain, len(d_inputs[1]))
                    l_d = functional.nll_loss(d_outputs, d_targets)
                    if opt.lambd > 0:
                        l_d *= -opt.lambd
                elif opt.loss.lower() == 'bs':
                    # boundary-seeking: push D's output towards a uniform label
                    d_targets = utils.get_random_domain_label(opt.loss, len(d_inputs[1]))
                    l_d = functional.kl_div(d_outputs, d_targets, size_average=False)
                    if opt.lambd > 0:
                        l_d *= opt.lambd
                elif opt.loss.lower() == 'l2':
                    d_targets = utils.get_random_domain_label(opt.loss, len(d_inputs[1]))
                    l_d = functional.mse_loss(d_outputs, d_targets)
                    if opt.lambd > 0:
                        l_d *= opt.lambd
                l_d.backward()
            optimizer.step()

        # end of epoch
        log.info('Ending epoch {}'.format(epoch + 1))
        if d_total > 0:
            log.info('D Training Accuracy: {}%'.format(100.0 * d_correct / d_total))
        log.info('Training accuracy:')
        log.info('\t'.join(opt.domains))
        log.info('\t'.join([str(100.0 * correct[d] / total[d]) for d in opt.domains]))
        log.info('Evaluating validation sets:')
        acc = {}
        for domain in opt.dev_domains:
            acc[domain] = evaluate(domain, dev_loaders[domain], F_s, None, C)
        avg_acc = sum([acc[d] for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info('Average validation accuracy: {}'.format(avg_acc))
        log.info('Evaluating test sets:')
        test_acc = {}
        for domain in opt.dev_domains:
            test_acc[domain] = evaluate(domain, test_loaders[domain], F_s, None, C)
        avg_test_acc = sum([test_acc[d] for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info('Average test accuracy: {}'.format(avg_test_acc))
        # checkpoint whenever validation accuracy improves
        if avg_acc > best_avg_acc:
            log.info('New best average validation accuracy: {}'.format(avg_acc))
            best_acc['valid'] = acc
            best_acc['test'] = test_acc
            best_avg_acc = avg_acc
            with open(os.path.join(opt.model_save_file, 'options.pkl'), 'wb') as ouf:
                pickle.dump(opt, ouf)
            torch.save(F_s.state_dict(), '{}/netF_s.pth'.format(opt.model_save_file))
            torch.save(C.state_dict(), '{}/netC.pth'.format(opt.model_save_file))
            torch.save(D.state_dict(), '{}/netD.pth'.format(opt.model_save_file))

    # end of training
    log.info('Best average validation accuracy: {}'.format(best_avg_acc))
    return best_acc
def train(target_domain_idx, train_sets, test_sets):
    """
    Fine-tune a target-domain private extractor (F_d_target) using pseudo
    labels produced by a pretrained shared extractor / classifier ensemble.

    train_sets, test_sets: unlabeled domain (dev domain) -> AmazonDataset

    Returns the test-set accuracy of (F_s + F_d_target + C) after fine-tuning.
    """
    # test dataset (debug output)
    print(train_sets[0])
    print(test_sets[0])
    print(train_sets[0][0].shape)
    print(test_sets[0][1].shape)
    # ---------------------- dataloader ---------------------- #
    train_loaders = DataLoader(train_sets, opt.batch_size, shuffle=False)
    test_loaders = DataLoader(test_sets, opt.batch_size, shuffle=False)
    # ---------------------- model initialization ---------------------- #
    F_s = None
    F_d = {}
    C = None
    if opt.model.lower() == 'mlp':
        F_s = MlpFeatureExtractor(opt.feature_num, opt.F_hidden_sizes,
                opt.shared_hidden_size, opt.dropout, opt.F_bn)
        for domain in opt.domains:
            F_d[domain] = MlpFeatureExtractor(opt.feature_num, opt.F_hidden_sizes,
                    opt.domain_hidden_size, opt.dropout, opt.F_bn)
    else:
        raise Exception(f'Unknown model architecture {opt.model}')
    C = SentimentClassifier(opt.C_layers, opt.shared_hidden_size + opt.domain_hidden_size,
            opt.shared_hidden_size + opt.domain_hidden_size,
            opt.num_labels, opt.dropout, opt.C_bn)
    D = DomainClassifier(opt.D_layers, opt.shared_hidden_size, opt.shared_hidden_size,
            len(opt.all_domains), opt.loss, opt.dropout, opt.D_bn)
    # move to GPU (nn.Module.to is in-place, so mutating the loop var is fine)
    F_s, C, D = F_s.to(opt.device), C.to(opt.device), D.to(opt.device)
    for f_d in F_d.values():
        f_d = f_d.to(opt.device)
    # ---------------------- load pre-training model ---------------------- #
    log.info(f'Loading model from {opt.exp2_model_save_file}...')
    F_s.load_state_dict(torch.load(os.path.join(opt.exp2_model_save_file, f'netF_s.pth')))
    for domain in opt.all_domains:
        if domain in F_d:
            F_d[domain].load_state_dict(torch.load(
                    os.path.join(opt.exp2_model_save_file, f'net_F_d_{domain}.pth')))
    C.load_state_dict(torch.load(os.path.join(opt.exp2_model_save_file, f'netC.pth')))
    D.load_state_dict(torch.load(os.path.join(opt.exp2_model_save_file, f'netD.pth')))
    # ---------------------- get fake label (hard label) ---------------------- #
    log.info('Get fake label:')
    pseudo_labels, _ = genarate_labels(target_domain_idx, False, train_loaders, F_s, F_d, C, D)
    # ********** Test the accuracy of the label prediction ********** #
    test_pseudo_labels, targets_total = genarate_labels(target_domain_idx, True,
            test_loaders, F_s, F_d, C, D)
    label_correct_acc = calc_label_prediction(test_pseudo_labels, targets_total)
    log.info(f'the correct rate of label prediction: {label_correct_acc}')
    # ------------------- build the pseudo-labelled target dataset ------------------- #
    target_dataset = Subset(train_sets, pseudo_labels)
    target_dataloader_labelled = DataLoader(target_dataset, opt.batch_size, shuffle=True)
    target_train_iters = iter(target_dataloader_labelled)
    # ------------------- F_d_target model and required parameters ------------------- #
    F_d_target = MlpFeatureExtractor(opt.feature_num, opt.F_hidden_sizes,
            opt.domain_hidden_size, opt.dropout, opt.F_bn)
    F_d_target = F_d_target.to(opt.device)
    # warm-start F_d_target from a randomly chosen source domain's weights
    # instead of training from scratch.
    # BUGFIX: random.randint is inclusive on both ends, so
    # randint(0, len(opt.domains)) could index one past the end of opt.domains;
    # randrange excludes the upper bound.
    f_d_choice = random.randrange(len(opt.domains))
    F_d_target.load_state_dict(torch.load(os.path.join(
            opt.exp2_model_save_file, f'net_F_d_{opt.domains[f_d_choice]}.pth')))
    optimizer_F_d_target = optim.Adam(F_d_target.parameters(), lr=opt.learning_rate)
    # ------------------- baseline: accuracy with shared features only ------------------- #
    log.info('Evaluating test sets only on shared feature:')
    test_acc = evaluate(opt.dev_domains, test_loaders, F_s, None, C)
    log.info(f'test accuracy: {test_acc}')
    for epoch in range(opt.max_epoch):
        F_d_target.train()
        C.train()
        # training accuracy
        correct, total = 0, 0
        num_iter = len(target_dataloader_labelled)
        print(num_iter)
        for _ in tqdm(range(num_iter)):
            # only F_d_target and C are updated; the shared extractor stays fixed
            utils.freeze_net(F_s)
            C.zero_grad()
            F_d_target.zero_grad()
            inputs, targets = utils.endless_get_next_batch(
                    target_dataloader_labelled, target_train_iters)
            targets = targets.to(opt.device)
            shared_feat = F_s(inputs)
            domain_feat = F_d_target(inputs)
            features = torch.cat((shared_feat, domain_feat), dim=1)
            c_outputs = C(features)
            l_c = functional.nll_loss(c_outputs, targets)
            l_c.backward()
            _, pred_idx = torch.max(c_outputs, 1)
            total += targets.size(0)
            correct += (pred_idx == targets).sum().item()
            optimizer_F_d_target.step()
        # Open question: no validation set here — building one would require
        # splitting the training set, since the unlabeled domain only has
        # 2000 labelled samples plus raw_unlabeled_sets (~4000+); the latter
        # is used for training and the former for testing.
        # end of epoch
        # BUGFIX: this summary block was copy-paste duplicated; logged once now.
        print("correct is %d" % correct)
        print("total is %d" % total)
        log.info('Ending epoch {}'.format(epoch + 1))
        log.info('Training accuracy:')
        log.info('\t'.join(opt.unlabeled_domains))
        log.info(str(100.0 * correct / total))
    # save the fine-tuned models
    torch.save(F_d_target.state_dict(),
            '{}/net_F_d_target.pth'.format(opt.exp2_target_model_save_file))
    torch.save(C.state_dict(),
            '{}/netC_target.pth'.format(opt.exp2_target_model_save_file))
    # evaluate on the 2000 labelled samples of the target domain for a fair
    # comparison against training with labels directly
    log.info('Evaluating test sets:')
    test_acc = evaluate(opt.dev_domains, test_loaders, F_s, F_d_target, C)
    log.info(f'test accuracy: {test_acc}')
    return test_acc
def train(vocab, train_sets, dev_sets, test_sets, unlabeled_sets):
    """Adversarial multi-domain training (exp3 variant).

    train_sets, dev_sets, test_sets: dict[domain] -> AmazonDataset
    For unlabeled domains, no train_sets are available.

    Trains a shared extractor F_s, per-domain private extractors F_d, a
    sentiment classifier C and a domain discriminator D.  Returns
    (best_acc, visual_features, senti_labels) for feature visualization.

    Fixes vs. previous revision:
      * `map(utils.freeze_net, F_d.values())` was a lazy iterator that was
        never consumed, so the private nets were never frozen/unfrozen;
        replaced with explicit loops.
      * one-hot buffer was sized by opt.batch_size, which crashes on a
        final partial batch; now sized by the actual batch.
      * accuracy compared `pred` (B,) against the unsqueezed targets (B,1),
        which broadcasts to (B,B) and inflates `correct`; the 1-D targets
        are now kept for the comparison.
    """
    # dataset loaders
    train_loaders, unlabeled_loaders = {}, {}
    train_iters, unlabeled_iters = {}, {}
    dev_loaders, test_loaders = {}, {}
    my_collate = utils.sorted_collate if opt.model == 'lstm' else utils.unsorted_collate
    for domain in opt.domains:
        train_loaders[domain] = DataLoader(train_sets[domain],
                                           opt.batch_size,
                                           shuffle=True,
                                           collate_fn=my_collate)
        train_iters[domain] = iter(train_loaders[domain])
    for domain in opt.dev_domains:
        dev_loaders[domain] = DataLoader(dev_sets[domain],
                                         opt.batch_size,
                                         shuffle=False,
                                         collate_fn=my_collate)
        test_loaders[domain] = DataLoader(test_sets[domain],
                                          opt.batch_size,
                                          shuffle=False,
                                          collate_fn=my_collate)
    for domain in opt.all_domains:
        if domain in opt.unlabeled_domains:
            uset = unlabeled_sets[domain]
        else:
            # for labeled domains, consider which data to use as unlabeled set
            if opt.unlabeled_data == 'both':
                uset = ConcatDataset(
                    [train_sets[domain], unlabeled_sets[domain]])
            elif opt.unlabeled_data == 'unlabeled':
                uset = unlabeled_sets[domain]
            elif opt.unlabeled_data == 'train':
                uset = train_sets[domain]
            else:
                raise Exception(
                    f'Unknown options for the unlabeled data usage: {opt.unlabeled_data}'
                )
        unlabeled_loaders[domain] = DataLoader(uset,
                                               opt.batch_size,
                                               shuffle=True,
                                               collate_fn=my_collate)
        unlabeled_iters[domain] = iter(unlabeled_loaders[domain])

    # model
    F_s = None
    F_d = {}
    C, D = None, None
    if opt.model.lower() == 'dan':
        F_s = DanFeatureExtractor(vocab, opt.F_layers, opt.shared_hidden_size,
                                  opt.sum_pooling, opt.dropout, opt.F_bn)
        for domain in opt.domains:
            F_d[domain] = DanFeatureExtractor(vocab, opt.F_layers,
                                              opt.domain_hidden_size,
                                              opt.sum_pooling, opt.dropout,
                                              opt.F_bn)
    elif opt.model.lower() == 'lstm':
        F_s = LSTMFeatureExtractor(vocab, opt.F_layers,
                                   opt.shared_hidden_size, opt.dropout,
                                   opt.bdrnn, opt.attn)
        for domain in opt.domains:
            F_d[domain] = LSTMFeatureExtractor(vocab, opt.F_layers,
                                               opt.domain_hidden_size,
                                               opt.dropout, opt.bdrnn,
                                               opt.attn)
    elif opt.model.lower() == 'cnn':
        F_s = CNNFeatureExtractor(vocab, opt.F_layers, opt.shared_hidden_size,
                                  opt.kernel_num, opt.kernel_sizes,
                                  opt.dropout)
        for domain in opt.domains:
            F_d[domain] = CNNFeatureExtractor(vocab, opt.F_layers,
                                              opt.domain_hidden_size,
                                              opt.kernel_num,
                                              opt.kernel_sizes, opt.dropout)
    else:
        raise Exception(f'Unknown model architecture {opt.model}')
    # C consumes the concatenation of shared + private features
    C = SentimentClassifier(opt.C_layers,
                            opt.shared_hidden_size + opt.domain_hidden_size,
                            opt.shared_hidden_size + opt.domain_hidden_size,
                            opt.num_labels, opt.dropout, opt.C_bn)
    D = DomainClassifier(opt.D_layers, opt.shared_hidden_size,
                         opt.shared_hidden_size, len(opt.all_domains),
                         opt.loss, opt.dropout, opt.D_bn)
    F_s, C, D = F_s.to(opt.device), C.to(opt.device), D.to(opt.device)
    for f_d in F_d.values():
        # nn.Module.to() moves parameters in place, so no rebinding is needed
        f_d.to(opt.device)
    # optimizers: D gets its own optimizer (adversarial setup)
    optimizer = optim.Adam(itertools.chain(
        *map(list, [F_s.parameters() if F_s else [], C.parameters()] +
             [f.parameters() for f in F_d.values()])),
                           lr=opt.learning_rate)
    optimizerD = optim.Adam(D.parameters(), lr=opt.D_learning_rate)

    # testing
    if opt.test_only:
        log.info(f'Loading model from {opt.exp3_model_save_file}...')
        if F_s:
            F_s.load_state_dict(
                torch.load(
                    os.path.join(opt.exp3_model_save_file, f'netF_s.pth')))
        for domain in opt.all_domains:
            if domain in F_d:
                F_d[domain].load_state_dict(
                    torch.load(
                        os.path.join(opt.exp3_model_save_file,
                                     f'net_F_d_{domain}.pth')))
        C.load_state_dict(
            torch.load(os.path.join(opt.exp3_model_save_file, f'netC.pth')))
        D.load_state_dict(
            torch.load(os.path.join(opt.exp3_model_save_file, f'netD.pth')))
        log.info('Evaluating validation sets:')
        acc = {}
        for domain in opt.all_domains:
            acc[domain] = evaluate(domain, dev_loaders[domain], F_s,
                                   F_d[domain] if domain in F_d else None, C)
        avg_acc = sum([acc[d] for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info(f'Average validation accuracy: {avg_acc}')
        log.info('Evaluating test sets:')
        test_acc = {}
        for domain in opt.all_domains:
            test_acc[domain] = evaluate(domain, test_loaders[domain], F_s,
                                        F_d[domain] if domain in F_d else None,
                                        C)
        avg_test_acc = sum([test_acc[d]
                            for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info(f'Average test accuracy: {avg_test_acc}')
        return {'valid': acc, 'test': test_acc}

    # training
    best_acc, best_avg_acc = defaultdict(float), 0.0
    for epoch in range(opt.max_epoch):
        F_s.train()
        C.train()
        D.train()
        # regularization weights for the extra loss terms
        LAMBDA = 3
        lambda1 = 0.1
        lambda2 = 0.1
        for f in F_d.values():
            f.train()
        # training accuracy
        correct, total = defaultdict(int), defaultdict(int)
        # D accuracy
        shared_d_correct, private_d_correct, d_total = 0, 0, 0
        # conceptually view 1 epoch as 1 epoch of the first domain
        num_iter = len(train_loaders[opt.domains[0]])
        for i in tqdm(range(num_iter)):
            # ------------------- D iterations on all domains -------------------
            # FIX: map() is lazy in Python 3 and the old
            # `map(utils.freeze_net, F_d.values())` never ran; use a loop.
            utils.freeze_net(F_s)
            for f_d in F_d.values():
                utils.freeze_net(f_d)
            utils.freeze_net(C)
            utils.unfreeze_net(D)
            # WGAN n_critic trick since D trains slower
            n_critic = opt.n_critic
            if opt.wgan_trick:
                if opt.n_critic > 0 and ((epoch == 0 and i < 25)
                                         or i % 500 == 0):
                    n_critic = 100
            for _ in range(n_critic):
                D.zero_grad()
                loss_d = {}
                # train on both labeled and unlabeled domains
                for domain in opt.all_domains:
                    # targets not used
                    d_inputs, _ = utils.endless_get_next_batch(
                        unlabeled_loaders, unlabeled_iters, domain)
                    d_targets = utils.get_domain_label(opt.loss, domain,
                                                       len(d_inputs[1]))
                    shared_feat = F_s(d_inputs)
                    shared_d_outputs = D(shared_feat)
                    _, shared_pred = torch.max(shared_d_outputs, 1)
                    # the dev (target) domain has no private extractor
                    if domain != opt.dev_domains[0]:
                        private_feat = F_d[domain](d_inputs)
                        private_d_outputs = D(private_feat)
                        _, private_pred = torch.max(private_d_outputs, 1)
                    d_total += len(d_inputs[1])
                    if opt.loss.lower() == 'l2':
                        _, tgt_indices = torch.max(d_targets, 1)
                        shared_d_correct += (
                            shared_pred == tgt_indices).sum().item()
                        shared_l_d = functional.mse_loss(
                            shared_d_outputs, d_targets)
                        private_l_d = 0.0
                        if domain != opt.dev_domains[0]:
                            private_d_correct += (
                                private_pred == tgt_indices).sum().item()
                            # averaged over source domains
                            private_l_d = functional.mse_loss(
                                private_d_outputs, d_targets) / len(
                                    opt.domains)
                        l_d_sum = shared_l_d + private_l_d
                        l_d_sum.backward()
                    else:
                        shared_d_correct += (
                            shared_pred == d_targets).sum().item()
                        shared_l_d = functional.nll_loss(
                            shared_d_outputs, d_targets)
                        private_l_d = 0.0
                        if domain != opt.dev_domains[0]:
                            private_d_correct += (
                                private_pred == d_targets).sum().item()
                            private_l_d = functional.nll_loss(
                                private_d_outputs, d_targets) / len(
                                    opt.domains)
                        l_d_sum = shared_l_d + private_l_d
                        l_d_sum.backward()
                    loss_d[domain] = l_d_sum.item()
                optimizerD.step()
            # ------------------- F&C iteration -------------------
            # FIX: same lazy-map bug as above for unfreezing.
            utils.unfreeze_net(F_s)
            for f_d in F_d.values():
                utils.unfreeze_net(f_d)
            utils.unfreeze_net(C)
            utils.freeze_net(D)
            if opt.fix_emb:
                utils.freeze_net(F_s.word_emb)
                for f_d in F_d.values():
                    utils.freeze_net(f_d.word_emb)
            F_s.zero_grad()
            for f_d in F_d.values():
                f_d.zero_grad()
            C.zero_grad()
            for domain in opt.domains:
                inputs, targets = utils.endless_get_next_batch(
                    train_loaders, train_iters, domain)
                targets = targets.to(opt.device)
                shared_feat = F_s(inputs)
                domain_feat = F_d[domain](inputs)
                features = torch.cat((shared_feat, domain_feat), dim=1)
                c_outputs = C(features)
                loss_part_1 = functional.nll_loss(c_outputs, targets)
                # Keep a separate 2-D view instead of overwriting `targets`:
                # the old code compared pred (B,) with targets (B,1), which
                # broadcasts to (B,B) and corrupted the accuracy count.
                targets_2d = targets.unsqueeze(1)
                # FIX: size by the actual batch, not opt.batch_size, so the
                # final partial batch does not crash scatter_.
                targets_onehot = torch.FloatTensor(targets.size(0), 2)
                targets_onehot.zero_()
                targets_onehot.scatter_(1, targets_2d.cpu(), 1)
                targets_onehot = targets_onehot.to(opt.device)
                loss_part_2 = lambda1 * margin_regularization(
                    inputs, targets_onehot, F_d[domain], LAMBDA)
                loss_part_3 = -lambda2 * center_point_constraint(
                    domain_feat, targets_2d)
                print("lambda1: " + str(lambda1))
                print("lambda2: " + str(lambda2))
                print("loss_part_1: " + str(loss_part_1))
                print("loss_part_2: " + str(loss_part_2))
                print("loss_part_3: " + str(loss_part_3))
                l_c = loss_part_1 + loss_part_2 + loss_part_3
                l_c.backward(retain_graph=True)
                _, pred = torch.max(c_outputs, 1)
                total[domain] += targets.size(0)
                correct[domain] += (pred == targets).sum().item()
            # ------- update F_s with D gradients on all domains -------
            for domain in opt.all_domains:
                d_inputs, _ = utils.endless_get_next_batch(
                    unlabeled_loaders, unlabeled_iters, domain)
                shared_feat = F_s(d_inputs)
                shared_d_outputs = D(shared_feat)
                if domain != opt.dev_domains[0]:
                    private_feat = F_d[domain](d_inputs)
                    private_d_outputs = D(private_feat)
                l_d_sum = None
                if opt.loss.lower() == 'gr':
                    d_targets = utils.get_domain_label(opt.loss, domain,
                                                       len(d_inputs[1]))
                    shared_l_d = functional.nll_loss(shared_d_outputs,
                                                     d_targets)
                    private_l_d, l_d_sum = 0.0, 0.0
                    if domain != opt.dev_domains[0]:
                        # note the sign: gradient-reversal style loss
                        private_l_d = functional.nll_loss(
                            private_d_outputs, d_targets) * -1. / len(
                                opt.domains)
                    if opt.shared_lambd > 0:
                        l_d_sum = shared_l_d * opt.shared_lambd * -1.
                    else:
                        l_d_sum = shared_l_d * -1.
                    if opt.private_lambd > 0:
                        l_d_sum += private_l_d * opt.private_lambd * -1.
                    else:
                        l_d_sum += private_l_d * -1.
                elif opt.loss.lower() == 'bs':
                    d_targets = utils.get_random_domain_label(
                        opt.loss, len(d_inputs[1]))
                    shared_l_d = functional.kl_div(shared_d_outputs,
                                                   d_targets,
                                                   size_average=False)
                    private_l_d, l_d_sum = 0.0, 0.0
                    if domain != opt.dev_domains[0]:
                        private_l_d = functional.kl_div(private_d_outputs, d_targets, size_average=False) \
                            * -1. / len(opt.domains)
                    if opt.shared_lambd > 0:
                        l_d_sum = shared_l_d * opt.shared_lambd
                    else:
                        l_d_sum = shared_l_d
                    if opt.private_lambd > 0:
                        l_d_sum += private_l_d * opt.private_lambd
                    else:
                        l_d_sum += private_l_d
                elif opt.loss.lower() == 'l2':
                    d_targets = utils.get_random_domain_label(
                        opt.loss, len(d_inputs[1]))
                    shared_l_d = functional.mse_loss(shared_d_outputs,
                                                     d_targets)
                    private_l_d, l_d_sum = 0.0, 0.0
                    if domain != opt.dev_domains[0]:
                        private_l_d = functional.mse_loss(
                            private_d_outputs, d_targets) * -1. / len(
                                opt.domains)
                    if opt.shared_lambd > 0:
                        l_d_sum = shared_l_d * opt.shared_lambd
                    else:
                        l_d_sum = shared_l_d
                    if opt.private_lambd > 0:
                        l_d_sum += private_l_d * opt.private_lambd
                    else:
                        l_d_sum += private_l_d
                l_d_sum.backward()
            optimizer.step()

        # end of epoch
        log.info('Ending epoch {}'.format(epoch + 1))
        if d_total > 0:
            log.info('shared D Training Accuracy: {}%'.format(
                100.0 * shared_d_correct / d_total))
            log.info('private D Training Accuracy(average): {}%'.format(
                100.0 * private_d_correct / len(opt.domains) / d_total))
        log.info('Training accuracy:')
        log.info('\t'.join(opt.domains))
        log.info('\t'.join(
            [str(100.0 * correct[d] / total[d]) for d in opt.domains]))
        # evaluate on the validation sets
        log.info('Evaluating validation sets:')
        acc = {}
        for domain in opt.dev_domains:
            acc[domain] = evaluate(domain, dev_loaders[domain], F_s,
                                   F_d[domain] if domain in F_d else None, C)
        avg_acc = sum([acc[d] for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info(f'Average validation accuracy: {avg_acc}')
        # evaluate on the test sets
        log.info('Evaluating test sets:')
        test_acc = {}
        for domain in opt.dev_domains:
            test_acc[domain] = evaluate(domain, test_loaders[domain], F_s,
                                        F_d[domain] if domain in F_d else None,
                                        C)
        avg_test_acc = sum([test_acc[d]
                            for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info(f'Average test accuracy: {avg_test_acc}')
        # checkpoint on a new best validation accuracy
        if avg_acc > best_avg_acc:
            log.info(f'New best average validation accuracy: {avg_acc}')
            best_acc['valid'] = acc
            best_acc['test'] = test_acc
            best_avg_acc = avg_acc
            with open(os.path.join(opt.exp3_model_save_file, 'options.pkl'),
                      'wb') as ouf:
                pickle.dump(opt, ouf)
            torch.save(F_s.state_dict(),
                       '{}/netF_s.pth'.format(opt.exp3_model_save_file))
            for d in opt.domains:
                if d in F_d:
                    torch.save(
                        F_d[d].state_dict(),
                        '{}/net_F_d_{}.pth'.format(opt.exp3_model_save_file,
                                                   d))
            torch.save(C.state_dict(),
                       '{}/netC.pth'.format(opt.exp3_model_save_file))
            torch.save(D.state_dict(),
                       '{}/netD.pth'.format(opt.exp3_model_save_file))

    # end of training
    log.info(f'Best average validation accuracy: {best_avg_acc}')
    log.info(
        f'Loading model for feature visualization from {opt.exp3_model_save_file}...'
    )
    for domain in opt.domains:
        F_d[domain].load_state_dict(
            torch.load(
                os.path.join(opt.exp3_model_save_file,
                             f'net_F_d_{domain}.pth')))
    num_iter = len(train_loaders[opt.domains[0]])
    # visual_features currently exclude the shared feature
    # visual_features, senti_labels = get_visual_features(num_iter, unlabeled_loaders, unlabeled_iters, F_s, F_d)
    visual_features, senti_labels = get_visual_features(
        num_iter, unlabeled_loaders, unlabeled_iters, F_d)
    return best_acc, visual_features, senti_labels
def train(vocab, train_sets, dev_sets, test_sets, unlabeled_sets):
    """Baseline training using ONLY the shared feature extractor F_s.

    train_sets/dev_sets/test_sets: dict[domain] -> dataset.  The private
    feature slot fed to C is zero-padded, so the classifier sees
    (shared_feat, zeros) — this isolates the contribution of shared
    features.  Returns best_acc = {'valid': ..., 'test': ...} recorded at
    the epoch with the best average validation accuracy.
    """
    # dataset loaders, one per domain; iterators are recycled endlessly
    train_loaders, unlabeled_loaders = {}, {}
    train_iters, unlabeled_iters = {}, {}
    dev_loaders, test_loaders = {}, {}
    # LSTM models need length-sorted batches for packed sequences
    my_collate = utils.sorted_collate if opt.model == 'lstm' else utils.unsorted_collate
    for domain in opt.domains:
        train_loaders[domain] = DataLoader(train_sets[domain],
                                           opt.batch_size,
                                           shuffle=True,
                                           collate_fn=my_collate)
        train_iters[domain] = iter(train_loaders[domain])
    for domain in opt.dev_domains:
        dev_loaders[domain] = DataLoader(dev_sets[domain],
                                         opt.batch_size,
                                         shuffle=False,
                                         collate_fn=my_collate)
        test_loaders[domain] = DataLoader(test_sets[domain],
                                          opt.batch_size,
                                          shuffle=False,
                                          collate_fn=my_collate)
    for domain in opt.all_domains:
        if domain in opt.unlabeled_domains:
            uset = unlabeled_sets[domain]
        else:
            # for labeled domains, consider which data to use as unlabeled set
            if opt.unlabeled_data == 'both':
                uset = ConcatDataset(
                    [train_sets[domain], unlabeled_sets[domain]])
            elif opt.unlabeled_data == 'unlabeled':
                uset = unlabeled_sets[domain]
            elif opt.unlabeled_data == 'train':
                uset = train_sets[domain]
            else:
                raise Exception(
                    'Unknown options for the unlabeled data usage: {}'.format(
                        opt.unlabeled_data))
        unlabeled_loaders[domain] = DataLoader(uset,
                                               opt.batch_size,
                                               shuffle=True,
                                               collate_fn=my_collate)
        unlabeled_iters[domain] = iter(unlabeled_loaders[domain])

    # models: only the shared extractor and the classifier (no F_d, no D)
    F_s = None
    C = None
    if opt.model.lower() == 'dan':
        F_s = DanFeatureExtractor(vocab, opt.F_layers, opt.shared_hidden_size,
                                  opt.sum_pooling, opt.dropout, opt.F_bn)
    elif opt.model.lower() == 'lstm':
        F_s = LSTMFeatureExtractor(vocab, opt.F_layers,
                                   opt.shared_hidden_size, opt.dropout,
                                   opt.bdrnn, opt.attn)
    elif opt.model.lower() == 'cnn':
        F_s = CNNFeatureExtractor(vocab, opt.F_layers, opt.shared_hidden_size,
                                  opt.kernel_num, opt.kernel_sizes,
                                  opt.dropout)
    else:
        raise Exception('Unknown model architecture {}'.format(opt.model))
    # C keeps the full (shared + private) input width so checkpoints stay
    # compatible with the adversarial variants; the private part is zeros
    C = SentimentClassifier(opt.C_layers,
                            opt.shared_hidden_size + opt.domain_hidden_size,
                            opt.shared_hidden_size + opt.domain_hidden_size,
                            opt.num_labels, opt.dropout, opt.C_bn)
    F_s, C = F_s.to(opt.device), C.to(opt.device)
    optimizer = optim.Adam(itertools.chain(
        *map(list, [F_s.parameters() if F_s else [], C.parameters()])),
                           lr=opt.learning_rate)

    best_acc, best_avg_acc = defaultdict(float), 0.0
    for epoch in range(opt.max_epoch):
        F_s.train()
        C.train()
        # training accuracy
        correct, total = defaultdict(int), defaultdict(int)
        # conceptually view 1 epoch as 1 epoch of the first domain
        num_iter = len(train_loaders[opt.domains[0]])
        for i in tqdm(range(num_iter)):
            # F&C iteration
            utils.unfreeze_net(F_s)
            utils.unfreeze_net(C)
            if opt.fix_emb:
                utils.freeze_net(F_s.word_emb)
            F_s.zero_grad()
            C.zero_grad()
            for domain in opt.domains:
                inputs, targets = utils.endless_get_next_batch(
                    train_loaders, train_iters, domain)
                targets = targets.to(opt.device)
                shared_feat = F_s(inputs)
                # zero placeholder where the private features would go
                domain_feat = torch.zeros(
                    len(targets), opt.domain_hidden_size).to(opt.device)
                features = torch.cat((shared_feat, domain_feat), dim=1)
                c_outputs = C(features)
                l_c = functional.nll_loss(c_outputs, targets)
                l_c.backward(retain_graph=True)
                _, pred = torch.max(c_outputs, 1)
                total[domain] += targets.size(0)
                correct[domain] += (pred == targets).sum().item()
            # a single step applies the gradients accumulated over all domains
            optimizer.step()
        # end of epoch
        log.info('Ending epoch {}'.format(epoch + 1))
        log.info('Training accuracy:')
        log.info('\t'.join(opt.domains))
        log.info('\t'.join(
            [str(100.0 * correct[d] / total[d]) for d in opt.domains]))
        log.info('Evaluating validation sets:')
        acc = {}
        for domain in opt.dev_domains:
            # None: no private extractor in this baseline
            acc[domain] = evaluate(domain, dev_loaders[domain], F_s, None, C)
        avg_acc = sum([acc[d] for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info('Average validation accuracy: {}'.format(avg_acc))
        log.info('Evaluating test sets:')
        test_acc = {}
        for domain in opt.dev_domains:
            test_acc[domain] = evaluate(domain, test_loaders[domain], F_s,
                                        None, C)
        avg_test_acc = sum([test_acc[d]
                            for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info('Average test accuracy: {}'.format(avg_test_acc))
        # checkpoint on a new best validation accuracy
        if avg_acc > best_avg_acc:
            log.info(
                'New best average validation accuracy: {}'.format(avg_acc))
            best_acc['valid'] = acc
            best_acc['test'] = test_acc
            best_avg_acc = avg_acc
            with open(os.path.join(opt.model_save_file, 'options.pkl'),
                      'wb') as ouf:
                pickle.dump(opt, ouf)
            torch.save(F_s.state_dict(),
                       '{}/netF_s.pth'.format(opt.model_save_file))
            torch.save(C.state_dict(),
                       '{}/netC.pth'.format(opt.model_save_file))
    # end of training
    log.info('Best average validation accuracy: {}'.format(best_avg_acc))
    return best_acc
def genarate_labels(target_domain_idx, test_for_label_acc, dataloader, F_s,
                    F_d, C, D):
    """Generate pseudo labels for the unlabeled (target) domain dataset.

    For each batch: the private extractors F_d produce per-source-domain
    features, C scores each of them, and the domain discriminator D's
    softmax over source domains weights those scores; the weighted per-class
    sums are argmax'ed into hard pseudo labels.

    Returns (pseudo_labels, targets_total) as 1-D tensors over all full
    batches (batches smaller than opt.batch_size are dropped).

    NOTE(review): when test_for_label_acc is False, train_D() is invoked
    inside the enclosing torch.no_grad() block — gradients are disabled
    there, so it is unclear how D can actually be updated; confirm intent.

    Fix vs. previous revision: `map(utils.freeze_net, F_d.values())` was a
    lazy iterator that never ran; replaced with an explicit loop.
    """
    # ---------------------- switch every net to eval() mode ----------------------
    optimizerD = optim.Adam(D.parameters(), lr=opt.D_learning_rate)
    F_s.eval()
    C.eval()
    for f_d in F_d.values():
        f_d.eval()
    D.eval()
    it = iter(dataloader)  # batch size is opt.batch_size (e.g. 128)
    print(len(it))
    F_d_features = []
    F_d_outputs = []
    targets_total, pseudo_labels = None, None
    with torch.no_grad():
        for inputs, targets in tqdm(it):
            # drop batches smaller than one full batch
            if inputs.shape[0] < opt.batch_size:
                continue
            # ------------- private features per source domain + shared feature -------------
            for f_d in F_d.values():
                F_d_features.append(f_d(inputs))
            shared_feature = F_s(inputs)
            # ------------- optionally adapt D on this batch -------------
            if test_for_label_acc is False:
                utils.freeze_net(F_s)
                # FIX: map() is lazy in Python 3; the old call never froze
                # the private nets.  Use an explicit loop.
                for f_d in F_d.values():
                    utils.freeze_net(f_d)
                utils.freeze_net(C)
                utils.unfreeze_net(D)
                train_D(target_domain_idx, optimizerD, F_d_features,
                        shared_feature, D)
            # ------------- domain "scores" of the shared feature -------------
            # D outputs one score per domain: all source domains + the target
            shared_feature_d_outputs = D(shared_feature)
            # drop the target-domain column, keep only source-domain scores
            # NOTE(review): this compares a single domain name against
            # opt.unlabeled_domains — assumes unlabeled_domains is a single
            # name here (elsewhere it is treated as a list); verify.
            indices = []
            for i in range(len(opt.all_domains)):
                if opt.all_domains[i] == opt.unlabeled_domains:
                    continue
                else:
                    indices.append(i)
            indices = torch.tensor(indices).to(opt.device)
            selected_shared_feature_d_outputs = torch.index_select(
                shared_feature_d_outputs, 1, indices)
            # normalize the remaining domain scores with softmax
            norm_selected_shared_feature_d_outputs = F.softmax(
                selected_shared_feature_d_outputs, dim=1)
            # ------------- classifier scores for every private feature -------------
            for f_d_feature in F_d_features:
                # zero-pad the shared slot so C sees its full input width
                padding_shared_feature = torch.zeros(
                    (shared_feature.shape[0], shared_feature.shape[1]),
                    requires_grad=True).to(opt.device)
                f_d_feature = torch.cat(
                    [f_d_feature, padding_shared_feature], 1)
                F_d_outputs.append(C(f_d_feature))
            # ------------- combine: classifier scores * domain weights -------------
            # repeat each domain weight num_labels times so it lines up with
            # the flattened per-domain class scores
            norm_selected_shared_feature_d_outputs = np.repeat(
                norm_selected_shared_feature_d_outputs.cpu().numpy(),
                repeats=opt.num_labels,
                axis=1)
            norm_selected_shared_feature_d_outputs = torch.from_numpy(
                norm_selected_shared_feature_d_outputs)
            F_d_outputs_tensor = torch.stack(F_d_outputs, 0).reshape(
                opt.batch_size, opt.num_labels * (len(opt.all_domains) - 1))
            c_mul_w = norm_selected_shared_feature_d_outputs * \
                F_d_outputs_tensor.cpu()
            # ------------- reduce c_mul_w to hard labels -------------
            # even columns hold class-0 scores, odd columns class-1 scores
            # (assumes num_labels == 2, matching the hard-coded stride)
            even_indices = torch.LongTensor(
                np.arange(0, 2 * (len(opt.all_domains) - 1), 2))
            odd_indices = torch.LongTensor(
                np.arange(1, 2 * (len(opt.all_domains) - 1), 2))
            even_index_scores = torch.index_select(c_mul_w, 1, even_indices)
            odd_index_scores = torch.index_select(c_mul_w, 1, odd_indices)
            even_index_scores_sum = torch.sum(even_index_scores,
                                              1).unsqueeze(1)
            odd_index_scores_sum = torch.sum(odd_index_scores, 1).unsqueeze(1)
            pred_scores = torch.cat(
                [even_index_scores_sum, odd_index_scores_sum], 1)
            _, pred_idx = torch.max(pred_scores, 1)
            # ------------- accumulate pseudo labels -------------
            if pseudo_labels is None:
                pseudo_labels = pred_idx
                targets_total = targets
            else:
                pseudo_labels = torch.cat([pseudo_labels, pred_idx], 0)
                targets_total = torch.cat([targets_total, targets], 0)
            # reuse the scratch lists for the next batch
            F_d_features.clear()
            F_d_outputs.clear()
    print(pseudo_labels.shape)
    print(targets_total.shape)
    print(">>> Generate pseudo labels {}, target samples {}".format(
        len(pseudo_labels), targets_total.shape[0]))
    return pseudo_labels, targets_total
def train_nobert(vocab, train_loaders, unlabeled_loaders, train_iters,
                 unlabeled_iters, dev_loaders, test_loaders):
    """Adversarial multi-domain training on pre-extracted BERT embeddings.

    Batches are 5-tuples (emb_ids, input_ids, input_mask, segment_ids,
    label_ids); only emb_ids (features) and label_ids (targets) are used.
    Trains shared extractor F_s, per-domain F_d, classifier C and domain
    discriminator D; returns best_acc keyed by 'valid'/'test'.

    Fix vs. previous revision: `map(utils.freeze_net/unfreeze_net,
    F_d.values())` was a lazy iterator that never executed, so the private
    nets were never frozen/unfrozen; replaced with explicit loops.
    """
    # models
    F_s = None
    F_d = {}
    C, D = None, None
    if opt.model.lower() == 'dan':
        F_s = DanFeatureExtractor(vocab, opt.F_layers, opt.shared_hidden_size,
                                  opt.sum_pooling, opt.dropout, opt.F_bn)
        for domain in opt.domains:
            F_d[domain] = DanFeatureExtractor(vocab, opt.F_layers,
                                              opt.domain_hidden_size,
                                              opt.sum_pooling, opt.dropout,
                                              opt.F_bn)
    elif opt.model.lower() == 'lstm':
        F_s = LSTMFeatureExtractor(vocab, opt.F_layers,
                                   opt.shared_hidden_size, opt.dropout,
                                   opt.bdrnn, opt.attn)
        for domain in opt.domains:
            F_d[domain] = LSTMFeatureExtractor(vocab, opt.F_layers,
                                               opt.domain_hidden_size,
                                               opt.dropout, opt.bdrnn,
                                               opt.attn)
    elif opt.model.lower() == 'cnn':
        F_s = CNNFeatureExtractor(vocab, opt.F_layers, opt.shared_hidden_size,
                                  opt.kernel_num, opt.kernel_sizes,
                                  opt.dropout)
        for domain in opt.domains:
            F_d[domain] = CNNFeatureExtractor(vocab, opt.F_layers,
                                              opt.domain_hidden_size,
                                              opt.kernel_num,
                                              opt.kernel_sizes, opt.dropout)
    else:
        raise Exception('Unknown model architecture {}'.format(opt.model))
    C = SentimentClassifier(opt.C_layers,
                            opt.shared_hidden_size + opt.domain_hidden_size,
                            opt.shared_hidden_size + opt.domain_hidden_size,
                            opt.num_labels, opt.dropout, opt.C_bn)
    D = DomainClassifier(opt.D_layers, opt.shared_hidden_size,
                         opt.shared_hidden_size, len(opt.all_domains),
                         opt.loss, opt.dropout, opt.D_bn)
    F_s, C, D = F_s.to(opt.device), C.to(opt.device), D.to(opt.device)
    for f_d in F_d.values():
        # nn.Module.to() moves parameters in place
        f_d.to(opt.device)
    # optimizers: D gets its own optimizer (adversarial setup)
    optimizer = optim.Adam(itertools.chain(
        *map(list, [F_s.parameters() if F_s else [], C.parameters()] +
             [f.parameters() for f in F_d.values()])),
                           lr=opt.learning_rate)
    optimizerD = optim.Adam(D.parameters(), lr=opt.D_learning_rate)

    # testing
    if opt.test_only:
        log.info('Loading model from {}...'.format(opt.model_save_file))
        if F_s:
            F_s.load_state_dict(
                torch.load(os.path.join(opt.model_save_file, 'netF_s.pth')))
        for domain in opt.all_domains:
            if domain in F_d:
                F_d[domain].load_state_dict(
                    torch.load(
                        os.path.join(opt.model_save_file,
                                     'net_F_d_{}.pth'.format(domain))))
        C.load_state_dict(
            torch.load(os.path.join(opt.model_save_file, 'netC.pth')))
        D.load_state_dict(
            torch.load(os.path.join(opt.model_save_file, 'netD.pth')))
        log.info('Evaluating validation sets:')
        acc = {}
        for domain in opt.all_domains:
            acc[domain] = evaluate(domain, dev_loaders[domain], F_s,
                                   F_d[domain] if domain in F_d else None, C)
        avg_acc = sum([acc[d] for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info('Average validation accuracy: {}'.format(avg_acc))
        log.info('Evaluating test sets:')
        test_acc = {}
        for domain in opt.all_domains:
            test_acc[domain] = evaluate(domain, test_loaders[domain], F_s,
                                        F_d[domain] if domain in F_d else None,
                                        C)
        avg_test_acc = sum([test_acc[d]
                            for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info('Average test accuracy: {}'.format(avg_test_acc))
        return {'valid': acc, 'test': test_acc}

    # training
    best_acc, best_avg_acc = defaultdict(float), 0.0
    for epoch in range(opt.max_epoch):
        F_s.train()
        C.train()
        D.train()
        for f in F_d.values():
            f.train()
        # training accuracy
        correct, total = defaultdict(int), defaultdict(int)
        # D accuracy
        d_correct, d_total = 0, 0
        # conceptually view 1 epoch as 1 epoch of the first domain
        num_iter = len(train_loaders[opt.domains[0]])
        for i in tqdm(range(num_iter)):
            # D iterations
            # FIX: map() is lazy in Python 3; the old call never froze the
            # private nets.  Use an explicit loop.
            utils.freeze_net(F_s)
            for f_d in F_d.values():
                utils.freeze_net(f_d)
            utils.freeze_net(C)
            utils.unfreeze_net(D)
            # WGAN n_critic trick since D trains slower
            n_critic = opt.n_critic
            if opt.wgan_trick:
                if opt.n_critic > 0 and ((epoch == 0 and i < 25)
                                         or i % 500 == 0):
                    n_critic = 100
            for _ in range(n_critic):
                D.zero_grad()
                loss_d = {}
                # train on both labeled and unlabeled domains
                for domain in opt.all_domains:
                    # targets not used
                    batch = utils.endless_get_next_batch(
                        unlabeled_loaders, unlabeled_iters, domain)
                    batch = tuple(t.to(opt.device) for t in batch)
                    emb_ids, input_ids, input_mask, segment_ids, label_ids = batch
                    d_targets = utils.get_domain_label(
                        opt.loss, domain, len(emb_ids))
                    shared_feat = F_s(emb_ids)
                    d_outputs = D(shared_feat)
                    # D accuracy
                    _, pred = torch.max(d_outputs, 1)
                    d_total += len(emb_ids)
                    if opt.loss.lower() == 'l2':
                        _, tgt_indices = torch.max(d_targets, 1)
                        d_correct += (pred == tgt_indices).sum().item()
                        l_d = functional.mse_loss(d_outputs, d_targets)
                        l_d.backward()
                    else:
                        d_correct += (pred == d_targets).sum().item()
                        l_d = functional.nll_loss(d_outputs, d_targets)
                        l_d.backward()
                    loss_d[domain] = l_d.item()
                optimizerD.step()
            # F&C iteration
            # FIX: same lazy-map bug as above for unfreezing.
            utils.unfreeze_net(F_s)
            for f_d in F_d.values():
                utils.unfreeze_net(f_d)
            utils.unfreeze_net(C)
            utils.freeze_net(D)
            # embedding freezing is a no-op here: inputs are precomputed
            # BERT embeddings, so the extractors have no word_emb to fix
            # if opt.fix_emb:
            #     utils.freeze_net(F_s.word_emb)
            #     for f_d in F_d.values():
            #         utils.freeze_net(f_d.word_emb)
            F_s.zero_grad()
            for f_d in F_d.values():
                f_d.zero_grad()
            C.zero_grad()
            for domain in opt.domains:
                batch = utils.endless_get_next_batch(train_loaders,
                                                     train_iters, domain)
                batch = tuple(t.to(opt.device) for t in batch)
                emb_ids, input_ids, input_mask, segment_ids, label_ids = batch
                inputs = emb_ids
                targets = label_ids
                shared_feat = F_s(inputs)
                domain_feat = F_d[domain](inputs)
                features = torch.cat((shared_feat, domain_feat), dim=1)
                c_outputs = C(features)
                l_c = functional.nll_loss(c_outputs, targets)
                l_c.backward(retain_graph=True)
                _, pred = torch.max(c_outputs, 1)
                total[domain] += targets.size(0)
                correct[domain] += (pred == targets).sum().item()
            # update F with D gradients on all domains
            for domain in opt.all_domains:
                batch = utils.endless_get_next_batch(unlabeled_loaders,
                                                     unlabeled_iters, domain)
                batch = tuple(t.to(opt.device) for t in batch)
                emb_ids, input_ids, input_mask, segment_ids, label_ids = batch
                d_inputs = emb_ids
                shared_feat = F_s(d_inputs)
                d_outputs = D(shared_feat)
                if opt.loss.lower() == 'gr':
                    # gradient-reversal: negate the discriminator loss
                    d_targets = utils.get_domain_label(
                        opt.loss, domain, len(d_inputs))
                    l_d = functional.nll_loss(d_outputs, d_targets)
                    if opt.lambd > 0:
                        l_d *= -opt.lambd
                elif opt.loss.lower() == 'bs':
                    # confusion loss towards a random/uniform domain label
                    d_targets = utils.get_random_domain_label(
                        opt.loss, len(d_inputs))
                    l_d = functional.kl_div(d_outputs,
                                            d_targets,
                                            size_average=False)
                    if opt.lambd > 0:
                        l_d *= opt.lambd
                elif opt.loss.lower() == 'l2':
                    d_targets = utils.get_random_domain_label(
                        opt.loss, len(d_inputs))
                    l_d = functional.mse_loss(d_outputs, d_targets)
                    if opt.lambd > 0:
                        l_d *= opt.lambd
                l_d.backward()
            optimizer.step()

        # end of epoch
        log.info('Ending epoch {}'.format(epoch + 1))
        if d_total > 0:
            log.info('D Training Accuracy: {}%'.format(100.0 * d_correct /
                                                       d_total))
        log.info('Training accuracy:')
        log.info('\t'.join(opt.domains))
        log.info('\t'.join(
            [str(100.0 * correct[d] / total[d]) for d in opt.domains]))
        log.info('Evaluating validation sets:')
        acc = {}
        for domain in opt.dev_domains:
            acc[domain] = evaluate(domain, dev_loaders[domain], F_s,
                                   F_d[domain] if domain in F_d else None, C)
        avg_acc = sum([acc[d] for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info('Average validation accuracy: {}'.format(avg_acc))
        log.info('Evaluating test sets:')
        test_acc = {}
        for domain in opt.dev_domains:
            test_acc[domain] = evaluate(domain, test_loaders[domain], F_s,
                                        F_d[domain] if domain in F_d else None,
                                        C)
        avg_test_acc = sum([test_acc[d]
                            for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info('Average test accuracy: {}'.format(avg_test_acc))
        if avg_acc > best_avg_acc:
            log.info(
                'New best average validation accuracy: {}'.format(avg_acc))
            best_acc['valid'] = acc
            best_acc['test'] = test_acc
            best_avg_acc = avg_acc
            with open(os.path.join(opt.model_save_file, 'options.pkl'),
                      'wb') as ouf:
                pickle.dump(opt, ouf)
            # model checkpointing is intentionally disabled here; only the
            # options are persisted
            # torch.save(F_s.state_dict(),
            #            '{}/netF_s.pth'.format(opt.model_save_file))
            # for d in opt.domains:
            #     if d in F_d:
            #         torch.save(F_d[d].state_dict(),
            #                    '{}/net_F_d_{}.pth'.format(opt.model_save_file, d))
            # torch.save(C.state_dict(),
            #            '{}/netC.pth'.format(opt.model_save_file))
            # torch.save(D.state_dict(),
            #            '{}/netD.pth'.format(opt.model_save_file))
    # end of training
    log.info('Best average validation accuracy: {}'.format(best_avg_acc))
    return best_acc
def train_private_nobert(vocab, train_loaders, unlabeled_loaders, train_iters,
                         unlabeled_iters, dev_loaders, test_loaders):
    """Private-extractors-only baseline on pre-extracted BERT embeddings.

    Trains only per-domain extractors F_d and classifier C; the shared
    feature slot fed to C is zero-padded, isolating the private features'
    contribution.  Batches are 5-tuples (emb_ids, input_ids, input_mask,
    segment_ids, label_ids); only emb_ids and label_ids are used.
    Returns best_acc keyed by 'valid'/'test'.

    Fix vs. previous revision: `map(utils.unfreeze_net, F_d.values())` was
    a lazy iterator that never executed, so the private nets were never
    unfrozen; replaced with an explicit loop.
    """
    # models (no shared extractor, no discriminator)
    F_d = {}
    C = None
    if opt.model.lower() == 'dan':
        for domain in opt.domains:
            F_d[domain] = DanFeatureExtractor(vocab, opt.F_layers,
                                              opt.domain_hidden_size,
                                              opt.sum_pooling, opt.dropout,
                                              opt.F_bn)
    elif opt.model.lower() == 'lstm':
        for domain in opt.domains:
            F_d[domain] = LSTMFeatureExtractor(vocab, opt.F_layers,
                                               opt.domain_hidden_size,
                                               opt.dropout, opt.bdrnn,
                                               opt.attn)
    elif opt.model.lower() == 'cnn':
        for domain in opt.domains:
            F_d[domain] = CNNFeatureExtractor(vocab, opt.F_layers,
                                              opt.domain_hidden_size,
                                              opt.kernel_num,
                                              opt.kernel_sizes, opt.dropout)
    else:
        raise Exception('Unknown model architecture {}'.format(opt.model))
    # C keeps the full (shared + private) input width for checkpoint
    # compatibility; the shared part is zero-padded below
    C = SentimentClassifier(opt.C_layers,
                            opt.shared_hidden_size + opt.domain_hidden_size,
                            opt.shared_hidden_size + opt.domain_hidden_size,
                            opt.num_labels, opt.dropout, opt.C_bn)
    C = C.to(opt.device)
    for f_d in F_d.values():
        # nn.Module.to() moves parameters in place
        f_d.to(opt.device)
    # optimizers (empty list stands in for the absent shared extractor)
    optimizer = optim.Adam(itertools.chain(
        *map(list,
             [[], C.parameters()] + [f.parameters() for f in F_d.values()])),
                           lr=opt.learning_rate)

    # training
    best_acc, best_avg_acc = defaultdict(float), 0.0
    for epoch in range(opt.max_epoch):
        C.train()
        for f in F_d.values():
            f.train()
        # training accuracy
        correct, total = defaultdict(int), defaultdict(int)
        # D accuracy (no discriminator in this baseline: stays 0, so the
        # d_total guard below never fires)
        d_correct, d_total = 0, 0
        # conceptually view 1 epoch as 1 epoch of the first domain
        num_iter = len(train_loaders[opt.domains[0]])
        for i in tqdm(range(num_iter)):
            # F&C iteration
            # FIX: map() is lazy in Python 3; the old call never unfroze the
            # private nets.  Use an explicit loop.
            for f_d in F_d.values():
                utils.unfreeze_net(f_d)
            utils.unfreeze_net(C)
            if opt.fix_emb:
                for f_d in F_d.values():
                    utils.freeze_net(f_d.word_emb)
            for f_d in F_d.values():
                f_d.zero_grad()
            C.zero_grad()
            for domain in opt.domains:
                batch = utils.endless_get_next_batch(train_loaders,
                                                     train_iters, domain)
                batch = tuple(t.to(opt.device) for t in batch)
                emb_ids, input_ids, input_mask, segment_ids, label_ids = batch
                inputs = emb_ids
                targets = label_ids
                # zero placeholder where the shared features would go
                shared_feat = torch.zeros(
                    len(targets), opt.shared_hidden_size).to(opt.device)
                domain_feat = F_d[domain](inputs)
                features = torch.cat((shared_feat, domain_feat), dim=1)
                c_outputs = C(features)
                l_c = functional.nll_loss(c_outputs, targets)
                l_c.backward(retain_graph=True)
                _, pred = torch.max(c_outputs, 1)
                total[domain] += targets.size(0)
                correct[domain] += (pred == targets).sum().item()
            # one step applies the gradients accumulated over all domains
            optimizer.step()

        # end of epoch
        log.info('Ending epoch {}'.format(epoch + 1))
        if d_total > 0:
            log.info('D Training Accuracy: {}%'.format(100.0 * d_correct /
                                                       d_total))
        log.info('Training accuracy:')
        log.info('\t'.join(opt.domains))
        log.info('\t'.join(
            [str(100.0 * correct[d] / total[d]) for d in opt.domains]))
        log.info('Evaluating validation sets:')
        acc = {}
        for domain in opt.dev_domains:
            # None: no shared extractor in this baseline
            acc[domain] = evaluate(domain, dev_loaders[domain], None,
                                   F_d[domain] if domain in F_d else None, C)
        avg_acc = sum([acc[d] for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info('Average validation accuracy: {}'.format(avg_acc))
        log.info('Evaluating test sets:')
        test_acc = {}
        for domain in opt.dev_domains:
            test_acc[domain] = evaluate(domain, test_loaders[domain], None,
                                        F_d[domain] if domain in F_d else None,
                                        C)
        avg_test_acc = sum([test_acc[d]
                            for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info('Average test accuracy: {}'.format(avg_test_acc))
        if avg_acc > best_avg_acc:
            log.info(
                'New best average validation accuracy: {}'.format(avg_acc))
            best_acc['valid'] = acc
            best_acc['test'] = test_acc
            best_avg_acc = avg_acc
            with open(os.path.join(opt.model_save_file, 'options.pkl'),
                      'wb') as ouf:
                pickle.dump(opt, ouf)
            # model checkpointing is intentionally disabled here; only the
            # options are persisted
            # for d in opt.domains:
            #     if d in F_d:
            #         torch.save(F_d[d].state_dict(),
            #                    '{}/net_F_d_{}.pth'.format(opt.model_save_file, d))
            # torch.save(C.state_dict(),
            #            '{}/netC.pth'.format(opt.model_save_file))
    # end of training
    log.info('Best average validation accuracy: {}'.format(best_avg_acc))
    return best_acc
def train(opt):
    """Adversarial cross-lingual training (ADAN-style).

    Trains a feature extractor F and sentiment classifier P on labeled
    English (Yelp) data while a language detector Q (WGAN critic, weights
    clamped to [clip_lower, clip_upper]) is trained to separate English
    from Chinese features; F is updated to fool Q.

    Args:
        opt: options namespace; reads data paths, model hyper-parameters,
            learning rates, n_critic, lambd, clip bounds, model_save_file.
            Side effect: sets opt.num_labels and may set opt.max_seq_len.
    """
    # vocab
    log.info(f'Loading Embeddings...')
    vocab = Vocab(opt.emb_filename)
    # datasets
    log.info(f'Loading data...')
    yelp_X_train = os.path.join(opt.src_data_dir, 'X_train.txt.tok.shuf.lower')
    yelp_Y_train = os.path.join(opt.src_data_dir, 'Y_train.txt.shuf')
    yelp_X_test = os.path.join(opt.src_data_dir, 'X_test.txt.tok.lower')
    yelp_Y_test = os.path.join(opt.src_data_dir, 'Y_test.txt')
    yelp_train, yelp_valid = get_yelp_datasets(vocab, yelp_X_train,
                                               yelp_Y_train,
                                               opt.en_train_lines,
                                               yelp_X_test, yelp_Y_test,
                                               opt.max_seq_len)
    chn_X_file = os.path.join(opt.tgt_data_dir, 'X.sent.txt.shuf.lower')
    chn_Y_file = os.path.join(opt.tgt_data_dir, 'Y.txt.shuf')
    chn_train, chn_valid, chn_test = get_chn_htl_datasets(vocab, chn_X_file,
                                                          chn_Y_file,
                                                          opt.ch_train_lines,
                                                          opt.max_seq_len)
    log.info('Done loading datasets.')
    opt.num_labels = yelp_train.num_labels
    if opt.max_seq_len <= 0:
        # set to true max_seq_len in the datasets
        opt.max_seq_len = max(yelp_train.get_max_seq_len(),
                              chn_train.get_max_seq_len())

    # dataset loaders; the *_Q loaders feed the critic independently of the
    # main F/P stream so the two do not consume each other's batches
    my_collate = utils.sorted_collate if opt.model=='lstm' else utils.unsorted_collate
    yelp_train_loader = DataLoader(yelp_train, opt.batch_size,
                                   shuffle=True, collate_fn=my_collate)
    yelp_train_loader_Q = DataLoader(yelp_train, opt.batch_size,
                                     shuffle=True, collate_fn=my_collate)
    chn_train_loader = DataLoader(chn_train, opt.batch_size,
                                  shuffle=True, collate_fn=my_collate)
    chn_train_loader_Q = DataLoader(chn_train, opt.batch_size,
                                    shuffle=True, collate_fn=my_collate)
    yelp_train_iter_Q = iter(yelp_train_loader_Q)
    chn_train_iter = iter(chn_train_loader)
    chn_train_iter_Q = iter(chn_train_loader_Q)
    yelp_valid_loader = DataLoader(yelp_valid, opt.batch_size,
                                   shuffle=False, collate_fn=my_collate)
    chn_valid_loader = DataLoader(chn_valid, opt.batch_size,
                                  shuffle=False, collate_fn=my_collate)
    chn_test_loader = DataLoader(chn_test, opt.batch_size,
                                 shuffle=False, collate_fn=my_collate)

    # models
    if opt.model.lower() == 'dan':
        F = DANFeatureExtractor(vocab, opt.F_layers, opt.hidden_size,
                                opt.dropout, opt.F_bn)
    elif opt.model.lower() == 'lstm':
        F = LSTMFeatureExtractor(vocab, opt.F_layers, opt.hidden_size,
                                 opt.dropout, opt.bdrnn, opt.attn)
    elif opt.model.lower() == 'cnn':
        F = CNNFeatureExtractor(vocab, opt.F_layers, opt.hidden_size,
                                opt.kernel_num, opt.kernel_sizes, opt.dropout)
    else:
        raise Exception('Unknown model')
    P = SentimentClassifier(opt.P_layers, opt.hidden_size, opt.num_labels,
                            opt.dropout, opt.P_bn)
    Q = LanguageDetector(opt.Q_layers, opt.hidden_size, opt.dropout, opt.Q_bn)
    F, P, Q = F.to(opt.device), P.to(opt.device), Q.to(opt.device)
    optimizer = optim.Adam(list(F.parameters()) + list(P.parameters()),
                           lr=opt.learning_rate)
    optimizerQ = optim.Adam(Q.parameters(), lr=opt.Q_learning_rate)

    # training
    best_acc = 0.0
    for epoch in range(opt.max_epoch):
        F.train()
        P.train()
        Q.train()
        yelp_train_iter = iter(yelp_train_loader)
        # training accuracy
        correct, total = 0, 0
        sum_en_q, sum_ch_q = (0, 0.0), (0, 0.0)
        grad_norm_p, grad_norm_q = (0, 0.0), (0, 0.0)
        for i, (inputs_en, targets_en) in tqdm(
                enumerate(yelp_train_iter),
                total=len(yelp_train)//opt.batch_size):
            try:
                inputs_ch, _ = next(chn_train_iter)  # Chinese labels are not used
            except StopIteration:
                # BUG FIX: was a bare `except:` which also swallowed
                # KeyboardInterrupt/CUDA errors; only restart on exhaustion.
                chn_train_iter = iter(chn_train_loader)
                inputs_ch, _ = next(chn_train_iter)

            # Q iterations: train the critic harder early on and periodically
            n_critic = opt.n_critic
            if n_critic>0 and ((epoch==0 and i<=25) or (i%500==0)):
                n_critic = 10
            utils.freeze_net(F)
            utils.freeze_net(P)
            utils.unfreeze_net(Q)
            for qiter in range(n_critic):
                # clip Q weights (WGAN weight clamping)
                for p in Q.parameters():
                    p.data.clamp_(opt.clip_lower, opt.clip_upper)
                Q.zero_grad()
                # get a minibatch of data
                try:
                    # labels are not used
                    q_inputs_en, _ = next(yelp_train_iter_Q)
                except StopIteration:
                    # check if dataloader is exhausted
                    yelp_train_iter_Q = iter(yelp_train_loader_Q)
                    q_inputs_en, _ = next(yelp_train_iter_Q)
                try:
                    q_inputs_ch, _ = next(chn_train_iter_Q)
                except StopIteration:
                    chn_train_iter_Q = iter(chn_train_loader_Q)
                    q_inputs_ch, _ = next(chn_train_iter_Q)

                # maximize critic output on English, minimize on Chinese
                features_en = F(q_inputs_en)
                o_en_ad = Q(features_en)
                l_en_ad = torch.mean(o_en_ad)
                (-l_en_ad).backward()
                log.debug(f'Q grad norm: {Q.net[1].weight.grad.data.norm()}')
                sum_en_q = (sum_en_q[0] + 1, sum_en_q[1] + l_en_ad.item())
                features_ch = F(q_inputs_ch)
                o_ch_ad = Q(features_ch)
                l_ch_ad = torch.mean(o_ch_ad)
                l_ch_ad.backward()
                log.debug(f'Q grad norm: {Q.net[1].weight.grad.data.norm()}')
                sum_ch_q = (sum_ch_q[0] + 1, sum_ch_q[1] + l_ch_ad.item())
                optimizerQ.step()

            # F&P iteration
            utils.unfreeze_net(F)
            utils.unfreeze_net(P)
            utils.freeze_net(Q)
            if opt.fix_emb:
                utils.freeze_net(F.word_emb)
            # clip Q weights
            for p in Q.parameters():
                p.data.clamp_(opt.clip_lower, opt.clip_upper)
            F.zero_grad()
            P.zero_grad()
            features_en = F(inputs_en)
            o_en_sent = P(features_en)
            l_en_sent = functional.nll_loss(o_en_sent, targets_en)
            l_en_sent.backward(retain_graph=True)
            # adversarial term: F fools the frozen critic Q
            o_en_ad = Q(features_en)
            l_en_ad = torch.mean(o_en_ad)
            (opt.lambd*l_en_ad).backward(retain_graph=True)
            # training accuracy
            _, pred = torch.max(o_en_sent, 1)
            total += targets_en.size(0)
            correct += (pred == targets_en).sum().item()
            features_ch = F(inputs_ch)
            o_ch_ad = Q(features_ch)
            l_ch_ad = torch.mean(o_ch_ad)
            (-opt.lambd*l_ch_ad).backward()
            optimizer.step()

        # end of epoch
        log.info('Ending epoch {}'.format(epoch+1))
        # logs
        if sum_en_q[0] > 0:
            log.info(f'Average English Q output: {sum_en_q[1]/sum_en_q[0]}')
            log.info(f'Average Foreign Q output: {sum_ch_q[1]/sum_ch_q[0]}')
        # evaluate
        log.info('Training Accuracy: {}%'.format(100.0*correct/total))
        log.info('Evaluating English Validation set:')
        evaluate(opt, yelp_valid_loader, F, P)
        log.info('Evaluating Foreign validation set:')
        acc = evaluate(opt, chn_valid_loader, F, P)
        if acc > best_acc:
            log.info(f'New Best Foreign validation accuracy: {acc}')
            best_acc = acc
            torch.save(F.state_dict(),
                       '{}/netF_epoch_{}.pth'.format(opt.model_save_file, epoch))
            torch.save(P.state_dict(),
                       '{}/netP_epoch_{}.pth'.format(opt.model_save_file, epoch))
            torch.save(Q.state_dict(),
                       '{}/netQ_epoch_{}.pth'.format(opt.model_save_file, epoch))
        log.info('Evaluating Foreign test set:')
        evaluate(opt, chn_test_loader, F, P)
    log.info(f'Best Foreign validation accuracy: {best_acc}')
def train_shared(vocab, train_loaders, unlabeled_loaders, train_iters,
                 unlabeled_iters, dev_loaders, test_loaders, F_s):
    """Train only the shared (BERT-style) feature extractor F_s plus the
    sentiment classifier C, with the per-domain feature slot zeroed out.

    Counterpart of train_private_nobert: here the private half of C's input
    is a zero tensor, so this run measures how well shared features alone
    predict sentiment.

    Args:
        vocab: unused here; kept for signature parity with siblings.
        train_loaders / train_iters: dict[domain] -> DataLoader / iterator
            over labeled training batches.
        unlabeled_loaders / unlabeled_iters: unused here.
        dev_loaders / test_loaders: dict[domain] -> DataLoader for evaluation.
        F_s: shared feature extractor; called as
            F_s(input_ids, input_mask, segment_ids) and expected to return a
            pair whose second element is the shared feature tensor.

    Returns:
        dict with 'valid' and 'test' per-domain accuracies of the epoch that
        achieved the best average validation accuracy.
    """
    C = None
    C = SentimentClassifier(opt.C_layers,
                            opt.shared_hidden_size + opt.domain_hidden_size,
                            opt.shared_hidden_size + opt.domain_hidden_size,
                            opt.num_labels, opt.dropout, opt.C_bn)
    F_s, C = F_s.to(opt.device), C.to(opt.device)
    # optimizers (single Adam over F_s and C; the trailing `+ []` mirrors the
    # sibling functions that append per-domain parameter lists there)
    optimizer = optim.Adam(itertools.chain(
        *map(list, [F_s.parameters() if F_s else [], C.parameters()] + [])),
        lr=opt.learning_rate)
    # training
    best_acc, best_avg_acc = defaultdict(float), 0.0
    for epoch in range(opt.max_epoch):
        F_s.train()
        C.train()
        # training accuracy
        correct, total = defaultdict(int), defaultdict(int)
        # D accuracy (no domain classifier in this variant; d_total stays 0
        # so the D-accuracy log below never fires)
        d_correct, d_total = 0, 0
        # conceptually view 1 epoch as 1 epoch of the first domain
        num_iter = len(train_loaders[opt.domains[0]])
        for i in tqdm(range(num_iter)):
            # F&C iteration
            utils.unfreeze_net(F_s)
            utils.unfreeze_net(C)
            if opt.fix_emb:
                utils.freeze_net(F_s.word_emb)
            F_s.zero_grad()
            C.zero_grad()
            for domain in opt.domains:
                batch = utils.endless_get_next_batch(
                    train_loaders, train_iters, domain)
                batch = tuple(t.to(opt.device) for t in batch)
                emb_ids, input_ids, input_mask, segment_ids, label_ids = batch
                inputs = emb_ids
                targets = label_ids
                # F_s returns a pair; only the second element (shared
                # features) is used here
                _, shared_feat = F_s(input_ids, input_mask, segment_ids)
                # private features deliberately zeroed: shared-only ablation
                domain_feat = torch.zeros(
                    len(targets), opt.domain_hidden_size).to(opt.device)
                features = torch.cat((shared_feat, domain_feat), dim=1)
                c_outputs = C(features)
                l_c = functional.nll_loss(c_outputs, targets)
                l_c.backward(retain_graph=True)
                _, pred = torch.max(c_outputs, 1)
                total[domain] += targets.size(0)
                correct[domain] += (pred == targets).sum().item()
            optimizer.step()
        # end of epoch
        log.info('Ending epoch {}'.format(epoch + 1))
        if d_total > 0:
            log.info('D Training Accuracy: {}%'.format(
                100.0 * d_correct / d_total))
        log.info('Training accuracy:')
        log.info('\t'.join(opt.domains))
        log.info('\t'.join([str(100.0 * correct[d] / total[d])
                            for d in opt.domains]))
        log.info('Evaluating validation sets:')
        acc = {}
        for domain in opt.dev_domains:
            acc[domain] = evaluate(domain, dev_loaders[domain], F_s, None, C)
        avg_acc = sum([acc[d] for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info('Average validation accuracy: {}'.format(avg_acc))
        log.info('Evaluating test sets:')
        test_acc = {}
        for domain in opt.dev_domains:
            test_acc[domain] = evaluate(domain, test_loaders[domain], F_s,
                                        None, C)
        avg_test_acc = sum([test_acc[d] for d in opt.dev_domains]) / \
            len(opt.dev_domains)
        log.info('Average test accuracy: {}'.format(avg_test_acc))
        if avg_acc > best_avg_acc:
            log.info('New best average validation accuracy: {}'.format(
                avg_acc))
            best_acc['valid'] = acc
            best_acc['test'] = test_acc
            best_avg_acc = avg_acc
            with open(os.path.join(opt.model_save_file, 'options.pkl'),
                      'wb') as ouf:
                pickle.dump(opt, ouf)
            # torch.save(C.state_dict(),
            #            '{}/netC.pth'.format(opt.model_save_file))
    # end of training
    log.info('Best average validation accuracy: {}'.format(best_avg_acc))
    return best_acc
def train(train_sets, dev_sets, test_sets, unlabeled_sets, fold):
    """Adversarial multi-domain training (MAN-style) for one CV fold.

    Trains a shared extractor F_s, per-domain private extractors F_d, a
    sentiment classifier C, and a domain classifier D. D is trained to
    identify the domain from shared features; F_s is then updated
    adversarially (loss flavor selected by opt.loss: 'gr', 'bs', or 'l2').

    Args:
        train_sets, dev_sets, test_sets: dict[domain] -> TensorDataset.
            For unlabeled domains, no train_sets are available.
        unlabeled_sets: dict[domain] -> dataset used for D training.
        fold: fold index used in checkpoint file names.

    Returns:
        If opt.test_only: dict with 'valid'/'test' per-domain accuracies of
        the loaded checkpoint. Otherwise the accuracies of the epoch with the
        best average validation accuracy.
    """
    # dataset loaders
    train_loaders, unlabeled_loaders = {}, {}
    train_iters, unlabeled_iters = {}, {}
    dev_loaders, test_loaders = {}, {}
    for domain in opt.domains:
        train_loaders[domain] = DataLoader(train_sets[domain],
                                           opt.batch_size, shuffle=True)
        train_iters[domain] = iter(train_loaders[domain])
    for domain in opt.all_domains:
        dev_loaders[domain] = DataLoader(dev_sets[domain],
                                         opt.batch_size, shuffle=False)
        test_loaders[domain] = DataLoader(test_sets[domain],
                                          opt.batch_size, shuffle=False)
        if domain in opt.unlabeled_domains:
            uset = unlabeled_sets[domain]
        else:
            # for labeled domains, consider which data to use as unlabeled set
            if opt.unlabeled_data == 'both':
                uset = ConcatDataset(
                    [train_sets[domain], unlabeled_sets[domain]])
            elif opt.unlabeled_data == 'unlabeled':
                uset = unlabeled_sets[domain]
            elif opt.unlabeled_data == 'train':
                uset = train_sets[domain]
            else:
                raise Exception(
                    f'Unknown options for the unlabeled data usage: {opt.unlabeled_data}'
                )
        unlabeled_loaders[domain] = DataLoader(uset, opt.batch_size,
                                               shuffle=True)
        unlabeled_iters[domain] = iter(unlabeled_loaders[domain])

    # models
    F_s = None
    F_d = {}
    C, D = None, None
    if opt.model.lower() == 'mlp':
        F_s = MlpFeatureExtractor(opt.feature_num, opt.F_hidden_sizes,
                                  opt.shared_hidden_size, opt.dropout,
                                  opt.F_bn)
        for domain in opt.domains:
            F_d[domain] = MlpFeatureExtractor(opt.feature_num,
                                              opt.F_hidden_sizes,
                                              opt.domain_hidden_size,
                                              opt.dropout, opt.F_bn)
    else:
        raise Exception(f'Unknown model architecture {opt.model}')
    C = SentimentClassifier(opt.C_layers,
                            opt.shared_hidden_size + opt.domain_hidden_size,
                            opt.shared_hidden_size + opt.domain_hidden_size,
                            opt.num_labels, opt.dropout, opt.C_bn)
    D = DomainClassifier(opt.D_layers, opt.shared_hidden_size,
                         opt.shared_hidden_size, len(opt.all_domains),
                         opt.loss, opt.dropout, opt.D_bn)
    F_s, C, D = F_s.to(opt.device), C.to(opt.device), D.to(opt.device)
    for f_d in F_d.values():
        # nn.Module.to() moves parameters in place; rebinding not needed
        f_d.to(opt.device)
    # optimizers: D gets its own optimizer/learning rate
    optimizer = optim.Adam(itertools.chain(
        *map(list, [F_s.parameters() if F_s else [], C.parameters()] +
             [f.parameters() for f in F_d.values()])),
        lr=opt.learning_rate)
    optimizerD = optim.Adam(D.parameters(), lr=opt.D_learning_rate)

    # testing: load checkpoints for this fold and evaluate only
    if opt.test_only:
        log.info(f'Loading model from {opt.model_save_file}...')
        F_s.load_state_dict(
            torch.load(
                os.path.join(opt.model_save_file, f'netF_s_fold{fold}.pth')))
        for domain in opt.all_domains:
            if domain in F_d:
                F_d[domain].load_state_dict(
                    torch.load(
                        os.path.join(opt.model_save_file,
                                     f'net_F_d_{domain}_fold{fold}.pth')))
        C.load_state_dict(
            torch.load(
                os.path.join(opt.model_save_file, f'netC_fold{fold}.pth')))
        D.load_state_dict(
            torch.load(
                os.path.join(opt.model_save_file, f'netD_fold{fold}.pth')))
        log.info('Evaluating validation sets:')
        acc = {}
        for domain in opt.all_domains:
            acc[domain] = evaluate(domain, dev_loaders[domain], F_s,
                                   F_d[domain] if domain in F_d else None, C)
        avg_acc = sum([acc[d] for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info(f'Average validation accuracy: {avg_acc}')
        log.info('Evaluating test sets:')
        test_acc = {}
        for domain in opt.all_domains:
            test_acc[domain] = evaluate(domain, test_loaders[domain], F_s,
                                        F_d[domain] if domain in F_d else None,
                                        C)
        avg_test_acc = sum([test_acc[d] for d in opt.dev_domains]) / \
            len(opt.dev_domains)
        log.info(f'Average test accuracy: {avg_test_acc}')
        return {'valid': acc, 'test': test_acc}

    # training
    best_acc, best_avg_acc = defaultdict(float), 0.0
    for epoch in range(opt.max_epoch):
        F_s.train()
        C.train()
        D.train()
        for f in F_d.values():
            f.train()
        # training accuracy
        correct, total = defaultdict(int), defaultdict(int)
        # D accuracy
        d_correct, d_total = 0, 0
        # conceptually view 1 epoch as 1 epoch of the first domain
        num_iter = len(train_loaders[opt.domains[0]])
        for i in tqdm(range(num_iter)):
            # D iterations: only D trains; everything else is frozen
            utils.freeze_net(F_s)
            # BUG FIX: the original `map(utils.freeze_net, F_d.values())` is
            # a lazy iterator in Python 3 and never ran; use a real loop.
            for f_d in F_d.values():
                utils.freeze_net(f_d)
            utils.freeze_net(C)
            utils.unfreeze_net(D)
            # optional WGAN n_critic trick
            n_critic = opt.n_critic
            if opt.wgan_trick:
                if opt.n_critic > 0 and ((epoch == 0 and i < 25)
                                         or i % 500 == 0):
                    n_critic = 100
            for _ in range(n_critic):
                D.zero_grad()
                loss_d = {}
                # train on both labeled and unlabeled domains
                for domain in opt.all_domains:
                    # targets not used
                    d_inputs, _ = utils.endless_get_next_batch(
                        unlabeled_loaders, unlabeled_iters, domain)
                    # assumes get_domain_label returns a tensor on the right
                    # device — TODO confirm against utils
                    d_targets = utils.get_domain_label(opt.loss, domain,
                                                       len(d_inputs))
                    shared_feat = F_s(d_inputs)
                    d_outputs = D(shared_feat)
                    # D accuracy
                    _, pred = torch.max(d_outputs, 1)
                    d_total += len(d_inputs)
                    if opt.loss.lower() == 'l2':
                        # l2 targets are one-hot-like; compare argmaxes
                        _, tgt_indices = torch.max(d_targets, 1)
                        d_correct += (pred == tgt_indices).sum().item()
                        l_d = functional.mse_loss(d_outputs, d_targets)
                        l_d.backward()
                    else:
                        d_correct += (pred == d_targets).sum().item()
                        l_d = functional.nll_loss(d_outputs, d_targets)
                        l_d.backward()
                    loss_d[domain] = l_d.item()
                optimizerD.step()

            # F&C iteration: D frozen, everything else trains
            utils.unfreeze_net(F_s)
            # BUG FIX: same lazy-map no-op as above.
            for f_d in F_d.values():
                utils.unfreeze_net(f_d)
            utils.unfreeze_net(C)
            utils.freeze_net(D)
            F_s.zero_grad()
            for f_d in F_d.values():
                f_d.zero_grad()
            C.zero_grad()
            shared_feats, domain_feats = [], []
            for domain in opt.domains:
                inputs, targets = utils.endless_get_next_batch(
                    train_loaders, train_iters, domain)
                targets = targets.to(opt.device)
                shared_feat = F_s(inputs)
                shared_feats.append(shared_feat)
                domain_feat = F_d[domain](inputs)
                domain_feats.append(domain_feat)
                features = torch.cat((shared_feat, domain_feat), dim=1)
                c_outputs = C(features)
                l_c = functional.nll_loss(c_outputs, targets)
                l_c.backward(retain_graph=True)
                # training accuracy
                _, pred = torch.max(c_outputs, 1)
                total[domain] += targets.size(0)
                correct[domain] += (pred == targets).sum().item()
            # update F with D gradients on all domains (adversarial step)
            for domain in opt.all_domains:
                d_inputs, _ = utils.endless_get_next_batch(
                    unlabeled_loaders, unlabeled_iters, domain)
                shared_feat = F_s(d_inputs)
                d_outputs = D(shared_feat)
                if opt.loss.lower() == 'gr':
                    # gradient reversal: negate D's loss
                    d_targets = utils.get_domain_label(opt.loss, domain,
                                                       len(d_inputs))
                    l_d = functional.nll_loss(d_outputs, d_targets)
                    log.debug(f'D loss: {l_d.item()}')
                    if opt.lambd > 0:
                        l_d *= -opt.lambd
                elif opt.loss.lower() == 'bs':
                    d_targets = utils.get_random_domain_label(
                        opt.loss, len(d_inputs))
                    # NOTE: size_average=False is deprecated in newer torch
                    # (equivalent to reduction='sum'); kept for behavior parity
                    l_d = functional.kl_div(d_outputs, d_targets,
                                            size_average=False)
                    if opt.lambd > 0:
                        l_d *= opt.lambd
                elif opt.loss.lower() == 'l2':
                    d_targets = utils.get_random_domain_label(
                        opt.loss, len(d_inputs))
                    l_d = functional.mse_loss(d_outputs, d_targets)
                    if opt.lambd > 0:
                        l_d *= opt.lambd
                else:
                    # previously fell through to a NameError on l_d;
                    # fail with a clear message instead
                    raise Exception(f'Unknown loss: {opt.loss}')
                l_d.backward()
                if opt.model.lower() != 'lstm':
                    log.debug(
                        f'F_s norm: {F_s.net[-2].weight.grad.data.norm()}')
            optimizer.step()

        # end of epoch
        log.info('Ending epoch {}'.format(epoch + 1))
        if d_total > 0:
            log.info('D Training Accuracy: {}%'.format(
                100.0 * d_correct / d_total))
        log.info('Training accuracy:')
        log.info('\t'.join(opt.domains))
        log.info('\t'.join(
            [str(100.0 * correct[d] / total[d]) for d in opt.domains]))
        log.info('Evaluating validation sets:')
        acc = {}
        for domain in opt.all_domains:
            acc[domain] = evaluate(domain, dev_loaders[domain], F_s,
                                   F_d[domain] if domain in F_d else None, C)
        avg_acc = sum([acc[d] for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info(f'Average validation accuracy: {avg_acc}')
        log.info('Evaluating test sets:')
        test_acc = {}
        for domain in opt.all_domains:
            test_acc[domain] = evaluate(domain, test_loaders[domain], F_s,
                                        F_d[domain] if domain in F_d else None,
                                        C)
        avg_test_acc = sum([test_acc[d] for d in opt.dev_domains]) / \
            len(opt.dev_domains)
        log.info(f'Average test accuracy: {avg_test_acc}')
        if avg_acc > best_avg_acc:
            log.info(f'New best average validation accuracy: {avg_acc}')
            best_acc['valid'] = acc
            best_acc['test'] = test_acc
            best_avg_acc = avg_acc
            with open(os.path.join(opt.model_save_file, 'options.pkl'),
                      'wb') as ouf:
                pickle.dump(opt, ouf)
            torch.save(
                F_s.state_dict(),
                '{}/netF_s_fold{}.pth'.format(opt.model_save_file, fold))
            for d in opt.domains:
                if d in F_d:
                    torch.save(
                        F_d[d].state_dict(),
                        '{}/net_F_d_{}_fold{}.pth'.format(
                            opt.model_save_file, d, fold))
            torch.save(C.state_dict(),
                       '{}/netC_fold{}.pth'.format(opt.model_save_file, fold))
            torch.save(D.state_dict(),
                       '{}/netD_fold{}.pth'.format(opt.model_save_file, fold))

    # end of training
    log.info(f'Best average validation accuracy: {best_avg_acc}')
    return best_acc