# Assumed module-level imports for this excerpt (standard libraries only); the
# project-specific helpers used below -- to_device, batchify, SentiDataset,
# load_lexicon, load_vectors_with_vocab, freeze_net, sample, export_config,
# check_path, model_load, model_save, evaluate, compute_nn_accuracy, print_line,
# XLXDClassifier, Discriminator and GradReverse -- are expected to come from the
# repo's own modules.
import os
import time

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def train(args):
    # set random seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available() and args.cuda:
        torch.cuda.manual_seed(args.seed)

    print('Configuration:')
    print('\n'.join('\t{:15} {}'.format(k + ':', str(v))
                    for k, v in sorted(dict(vars(args)).items())))
    print()

    model_path = os.path.join(args.export, 'model.pt')
    config_path = os.path.join(args.export, 'config.json')
    export_config(args, config_path)
    check_path(model_path)

    ###############################################################################
    # Load data
    ###############################################################################
    n_trg = len(args.trg)
    lang2id = {lang: i for i, lang in enumerate(args.lang)}
    dom2id = {dom: i for i, dom in enumerate(args.dom)}
    src_id, dom_id = lang2id[args.src], dom2id[args.sup_dom]
    trg_ids = [lang2id[t] for t in args.trg]

    unlabeled_set = torch.load(args.unlabeled)
    train_set = torch.load(args.train)
    val_set = torch.load(args.val)
    test_set = torch.load(args.test)
    vocabs = [train_set[lang]['vocab'] for lang in args.lang]

    unlabeled = to_device([[batchify(unlabeled_set[lang][dom], args.batch_size)
                            for dom in args.dom] for lang in args.lang], args.cuda)
    train_x, train_y, train_l = to_device(train_set[args.src][args.sup_dom], args.cuda)
    val_ds = [to_device(val_set[t][args.sup_dom], args.cuda) for t in args.trg]
    test_ds = [to_device(test_set[t][args.sup_dom], args.cuda) for t in args.trg]

    if args.sample_unlabeled > 0:
        print('Downsampling unlabeled set...')
        print()
        unlabeled = [[x[:(args.sample_unlabeled // args.batch_size)] for x in t]
                     for t in unlabeled]

    if args.sample_train > 0:
        print('Downsampling training set...')
        print()
        train_x, train_y, train_l = sample([train_x, train_y, train_l],
                                           args.sample_train, True)

    senti_train = DataLoader(SentiDataset(train_x, train_y, train_l),
                             batch_size=args.clf_batch_size)
    train_iter = iter(senti_train)
    train_ds = DataLoader(SentiDataset(train_x, train_y, train_l),
                          batch_size=args.test_batch_size)
    val_ds = [DataLoader(SentiDataset(*ds), batch_size=args.test_batch_size)
              for ds in val_ds]
    test_ds = [DataLoader(SentiDataset(*ds), batch_size=args.test_batch_size)
               for ds in test_ds]

    lexicons = []
    for tid, tlang in zip(trg_ids, args.trg):
        sv, tv = vocabs[src_id], vocabs[tid]
        lex, lexsz = load_lexicon('data/muse/{}-{}.0-5000.txt'.format(args.src, tlang),
                                  sv, tv)
        lexicons.append((lex, lexsz, tid))

    ###############################################################################
    # Build the model
    ###############################################################################
    if args.resume:
        model, dis, lm_opt, dis_opt = model_load(args.resume)
    else:
        model = XLXDClassifier(n_classes=2, clf_p=args.dropoutc,
                               n_langs=len(args.lang), n_doms=len(args.dom),
                               vocab_sizes=list(map(len, vocabs)),
                               emb_size=args.emb_dim, hidden_size=args.hid_dim,
                               num_layers=args.nlayers, num_share=args.nshare,
                               tie_weights=args.tie_softmax,
                               output_p=args.dropouto, hidden_p=args.dropouth,
                               input_p=args.dropouti, embed_p=args.dropoute,
                               weight_p=args.dropoutw, alpha=2, beta=1)
        dis = Discriminator(args.emb_dim, args.dis_hid_dim, len(args.lang),
                            args.dis_nlayers, args.dropoutd)
        if args.mwe:
            # initialize each language's embeddings with pretrained multilingual
            # word embeddings and keep them frozen
            for lid, (v, lang) in enumerate(zip(vocabs, args.lang)):
                x, count = load_vectors_with_vocab(args.mwe_path.format(lang), v, -1)
                model.encoders[lid].weight.data.copy_(torch.from_numpy(x))
                freeze_net(model.encoders[lid])
        params = [{'params': model.models.parameters(), 'lr': args.lr},
                  {'params': model.clfs.parameters(), 'lr': args.lr}]
        if args.optimizer == 'sgd':
            lm_opt = torch.optim.SGD(params, lr=args.lr, weight_decay=args.wdecay)
            dis_opt = torch.optim.SGD(dis.parameters(), lr=args.dis_lr,
                                      weight_decay=args.wdecay)
        if args.optimizer == 'adam':
            lm_opt = torch.optim.Adam(params, lr=args.lr, weight_decay=args.wdecay,
                                      betas=(args.beta1, 0.999))
            dis_opt = torch.optim.Adam(dis.parameters(), lr=args.dis_lr,
                                       weight_decay=args.wdecay,
                                       betas=(args.beta1, 0.999))

    crit = nn.CrossEntropyLoss()
    bs = args.batch_size
    n_doms = len(args.dom)
    n_langs = len(args.lang)
    # language labels for the discriminator: bs consecutive examples per language
    dis_y = to_device(torch.arange(n_langs).unsqueeze(-1).expand(n_langs, bs)
                      .contiguous().view(-1), args.cuda)

    if args.cuda:
        model.cuda(), dis.cuda(), crit.cuda()
    else:
        model.cpu(), dis.cpu(), crit.cpu()

    print('Parameters:')
    total_params = sum([np.prod(x.size()) for x in model.parameters()])
    print('\ttotal params: {}'.format(total_params))
    print('\tparam list: {}'.format(len(list(model.parameters()))))
    for name, x in model.named_parameters():
        print('\t' + name + '\t', tuple(x.size()))
    for name, x in dis.named_parameters():
        print('\t' + name + '\t', tuple(x.size()))
    print()

    ###############################################################################
    # Training code
    ###############################################################################
    bptt = args.bptt
    best_accs = {tlang: 0. for tlang in args.trg}
    final_test_accs = {tlang: 0. for tlang in args.trg}
    print('Training:')
    print_line()
    # pointers for reading unlabeled data, of shape (n_lang, n_dom)
    ptrs = np.zeros((len(args.lang), len(args.dom)), dtype=np.int64)
    total_loss = np.zeros((len(args.lang), len(args.dom)))  # shape (n_lang, n_dom)
    total_clf_loss = 0
    total_dis_loss = 0
    start_time = time.time()
    model.train()
    model.reset()
    for step in range(args.max_steps):
        loss = 0
        lm_opt.zero_grad()
        dis_opt.zero_grad()

        if not args.mwe:
            # vary the BPTT window and rescale the lr accordingly (AWD-LSTM style)
            seq_len = max(5, int(np.random.normal(
                bptt if np.random.random() < 0.95 else bptt / 2., 5)))
            lr0 = lm_opt.param_groups[0]['lr']
            lm_opt.param_groups[0]['lr'] = lr0 * seq_len / args.bptt

        # language modeling loss
        dis_x = []
        for lid, t in enumerate(unlabeled):
            for did, lm_x in enumerate(t):
                if ptrs[lid, did] + bptt + 1 > lm_x.size(0):
                    ptrs[lid, did] = 0
                    model.reset(lid=lid, did=did)
                p = ptrs[lid, did]
                xs = lm_x[p: p + bptt].t().contiguous()
                ys = lm_x[p + 1: p + 1 + bptt].t().contiguous()
                lm_raw_loss, lm_loss, hid = model.lm_loss(xs, ys, lid=lid, did=did,
                                                          return_h=True)
                loss = loss + lm_loss * args.lambd_lm
                total_loss[lid, did] += lm_raw_loss.item()
                ptrs[lid, did] += bptt
                if did == dom_id:
                    dis_x.append(hid[-1].mean(1))

        # language adversarial loss
        dis_x_rev = GradReverse.apply(torch.cat(dis_x, 0))
        dis_loss = crit(dis(dis_x_rev), dis_y)
        loss = loss + args.lambd_dis * dis_loss
        total_dis_loss += dis_loss.item()
        loss.backward()

        # sentiment classification loss
        try:
            xs, ys, ls = next(train_iter)
        except StopIteration:
            train_iter = iter(senti_train)
            xs, ys, ls = next(train_iter)
        clf_loss = crit(model(xs, ls, src_id, dom_id), ys)
        total_clf_loss += clf_loss.item()
        (args.lambd_clf * clf_loss).backward()

        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        if args.dis_clip > 0:
            for x in dis.parameters():
                x.data.clamp_(-args.dis_clip, args.dis_clip)
        lm_opt.step()
        dis_opt.step()

        if not args.mwe:
            lm_opt.param_groups[0]['lr'] = lr0

        if (step + 1) % args.log_interval == 0:
            total_loss /= args.log_interval
            total_clf_loss /= args.log_interval
            total_dis_loss /= args.log_interval
            elapsed = time.time() - start_time
            print('| step {:5d} | lr {:05.5f} | ms/batch {:7.2f} | lm_loss {:7.4f} '
                  '| avg_ppl {:7.2f} | clf {:7.4f} | dis {:7.4f} |'.format(
                      step, lm_opt.param_groups[0]['lr'],
                      elapsed * 1000 / args.log_interval,
                      total_loss.mean(), np.exp(total_loss).mean(),
                      total_clf_loss, total_dis_loss))
            total_loss[:, :], total_clf_loss, total_dis_loss = 0, 0, 0
            start_time = time.time()

        if (step + 1) % args.val_interval == 0:
            model.eval()
            with torch.no_grad():
                train_acc = evaluate(model, train_ds, src_id, dom_id)
                val_accs = [evaluate(model, ds, tid, dom_id)
                            for tid, ds in zip(trg_ids, val_ds)]
                test_accs = [evaluate(model, ds, tid, dom_id)
                             for tid, ds in zip(trg_ids, test_ds)]
                bdi_accs = [compute_nn_accuracy(model.encoder_weight(src_id),
                                                model.encoder_weight(tid),
                                                lexicon, 10000, lexicon_size=lexsz)
                            for lexicon, lexsz, tid in lexicons]
                print_line()
                print(('| step {:5d} | train {:.4f} |' +
                       ' val' + ' {} {:.4f}' * n_trg + ' |' +
                       ' test' + ' {} {:.4f}' * n_trg + ' |' +
                       ' bdi' + ' {} {:.4f}' * n_trg + ' |').format(
                          step, train_acc,
                          *sum([[tlang, acc] for tlang, acc in zip(args.trg, val_accs)], []),
                          *sum([[tlang, acc] for tlang, acc in zip(args.trg, test_accs)], []),
                          *sum([[tlang, acc] for tlang, acc in zip(args.trg, bdi_accs)], [])))
                print_line()
                print('saving model to {}'.format(model_path.replace('.pt', '_final.pt')))
                model_save(model, dis, lm_opt, dis_opt,
                           model_path.replace('.pt', '_final.pt'))
                for tlang, val_acc, test_acc in zip(args.trg, val_accs, test_accs):
                    if val_acc > best_accs[tlang]:
                        save_path = model_path.replace('.pt', '_{}.pt'.format(tlang))
                        print('saving {} model to {}'.format(tlang, save_path))
                        model_save(model, dis, lm_opt, dis_opt, save_path)
                        best_accs[tlang] = val_acc
                        final_test_accs[tlang] = test_acc
                print_line()
            model.train()
            start_time = time.time()

    print_line()
    print('Training ended with {} steps'.format(step + 1))
    print(('Best val acc: ' + ' {} {:.4f}' * n_trg).format(
        *sum([[tlang, best_accs[tlang]] for tlang in args.trg], [])))
    print(('Test acc (w/ early stop): ' + ' {} {:.4f}' * n_trg).format(
        *sum([[tlang, final_test_accs[tlang]] for tlang in args.trg], [])))
    print(('Test acc (w/o early stop):' + ' {} {:.4f}' * n_trg).format(
        *sum([[tlang, acc] for tlang, acc in zip(args.trg, test_accs)], [])))
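# ---------------------------------------------------------------------------
# GradReverse is applied above (GradReverse.apply(...)) but defined elsewhere in
# the repo. For reference, a minimal sketch of a standard gradient-reversal
# autograd Function is shown below; it is an illustrative assumption (the repo's
# own GradReverse may, for example, also take a scaling coefficient), not its
# actual definition.
# ---------------------------------------------------------------------------
class GradReverseSketch(torch.autograd.Function):
    """Identity in the forward pass; flips the sign of gradients in backward."""

    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Negating the gradient trains the shared encoder to *fool* the language
        # discriminator, which is what makes the dis_loss term adversarial.
        return -grad_output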
def train(args):
    # set random seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available() and args.cuda:
        torch.cuda.manual_seed(args.seed)

    print('Configuration:')
    print('\n'.join('\t{:15} {}'.format(k + ':', str(v))
                    for k, v in sorted(dict(vars(args)).items())))
    print()

    model_path = os.path.join(args.export, 'model.pt')
    config_path = os.path.join(args.export, 'config.json')
    export_config(args, config_path)
    check_path(model_path)

    ###############################################################################
    # Load data
    ###############################################################################
    src_lang, src_dom = args.src.split('-')
    trg_lang, trg_dom = args.trg.split('-')
    lang_dom_pairs = [[src_lang, src_dom], [src_lang, trg_dom], [trg_lang, trg_dom]]
    id_pairs = [[0, 0], [0, 1], [1, 1]]

    unlabeled_set = torch.load(args.unlabeled)
    train_set = torch.load(args.train)
    val_set = torch.load(args.val)
    test_set = torch.load(args.test)
    src_vocab = train_set[src_lang]['vocab']
    trg_vocab = train_set[trg_lang]['vocab']

    unlabeled = to_device([batchify(unlabeled_set[lang][dom], args.batch_size)
                           for lang, dom in lang_dom_pairs], args.cuda)
    train_x, train_y, train_l = to_device(train_set[src_lang][src_dom], args.cuda)
    val_x, val_y, val_l = to_device(val_set[trg_lang][trg_dom], args.cuda)
    test_x, test_y, test_l = to_device(test_set[trg_lang][trg_dom], args.cuda)

    if args.sample_unlabeled > 0:
        print('Downsampling unlabeled set...')
        print()
        unlabeled = [x[:(args.sample_unlabeled // args.batch_size)] for x in unlabeled]

    if args.sample_train > 0:
        print('Downsampling training set...')
        print()
        train_x, train_y, train_l = sample([train_x, train_y, train_l],
                                           args.sample_train, True)

    senti_train = DataLoader(SentiDataset(train_x, train_y, train_l),
                             batch_size=args.clf_batch_size)
    train_iter = iter(senti_train)
    train_ds = DataLoader(SentiDataset(train_x, train_y, train_l),
                          batch_size=args.test_batch_size)
    val_ds = DataLoader(SentiDataset(val_x, val_y, val_l),
                        batch_size=args.test_batch_size)
    test_ds = DataLoader(SentiDataset(test_x, test_y, test_l),
                         batch_size=args.test_batch_size)

    lexicon, lexsz = load_lexicon('data/muse/{}-{}.0-5000.txt'.format(src_lang, trg_lang),
                                  src_vocab, trg_vocab)

    ###############################################################################
    # Build the model
    ###############################################################################
    if args.resume:
        model, dis, lm_opt, dis_opt = model_load(args.resume)
    else:
        model = XLXDClassifier(n_classes=2, clf_p=args.dropoutc,
                               n_langs=2, n_doms=2,
                               vocab_sizes=[len(src_vocab), len(trg_vocab)],
                               emb_size=args.emb_dim, hidden_size=args.hid_dim,
                               num_layers=args.nlayers, num_share=args.nshare,
                               tie_weights=args.tie_softmax,
                               output_p=args.dropouto, hidden_p=args.dropouth,
                               input_p=args.dropouti, embed_p=args.dropoute,
                               weight_p=args.dropoutw, alpha=2, beta=1)
        dis = Discriminator(args.emb_dim, args.dis_hid_dim, 2,
                            args.dis_nlayers, args.dropoutd)
        if args.mwe:
            # initialize both embedding layers with pretrained multilingual word
            # embeddings and keep them frozen
            x, count = load_vectors_with_vocab(args.mwe_path.format(src_lang), src_vocab, -1)
            model.encoders[0].weight.data.copy_(torch.from_numpy(x))
            x, count = load_vectors_with_vocab(args.mwe_path.format(trg_lang), trg_vocab, -1)
            model.encoders[1].weight.data.copy_(torch.from_numpy(x))
            freeze_net(model.encoders)
        params = [{'params': model.models.parameters(), 'lr': args.lr},
                  {'params': model.clfs.parameters(), 'lr': args.lr}]
        if args.optimizer == 'sgd':
            lm_opt = torch.optim.SGD(params, lr=args.lr, weight_decay=args.wdecay)
            dis_opt = torch.optim.SGD(dis.parameters(), lr=args.lr,
                                      weight_decay=args.wdecay)
        if args.optimizer == 'adam':
            lm_opt = torch.optim.Adam(params, lr=args.lr, weight_decay=args.wdecay,
                                      betas=(args.beta1, 0.999))
            dis_opt = torch.optim.Adam(dis.parameters(), lr=args.lr,
                                       weight_decay=args.wdecay,
                                       betas=(args.beta1, 0.999))

    crit = nn.CrossEntropyLoss()
    bs = args.batch_size
    # language labels for the discriminator: first bs examples are source language,
    # next bs examples are target language
    dis_y = to_device(torch.tensor([0] * bs + [1] * bs), args.cuda)

    if args.cuda:
        model.cuda(), dis.cuda(), crit.cuda()
    else:
        model.cpu(), dis.cpu(), crit.cpu()

    print('Parameters:')
    total_params = sum([np.prod(x.size()) for x in model.parameters()])
    print('\ttotal params: {}'.format(total_params))
    print('\tparam list: {}'.format(len(list(model.parameters()))))
    for name, x in model.named_parameters():
        print('\t' + name + '\t', tuple(x.size()))
    for name, x in dis.named_parameters():
        print('\t' + name + '\t', tuple(x.size()))
    print()

    ###############################################################################
    # Training code
    ###############################################################################
    bptt = args.bptt
    best_acc = 0.
    final_test_acc = 0.
    print('Training:')
    print_line()
    # one read pointer and one running LM loss per (lang, dom) pair in lang_dom_pairs
    ptrs = np.zeros(3, dtype=np.int64)
    total_loss = np.zeros(3)
    total_clf_loss = 0
    total_dis_loss = 0
    start_time = time.time()
    model.train()
    model.reset()
    for step in range(args.max_steps):
        loss = 0
        lm_opt.zero_grad()
        dis_opt.zero_grad()

        if not args.mwe:
            # vary the BPTT window and rescale the lr accordingly (AWD-LSTM style)
            seq_len = max(5, int(np.random.normal(
                bptt if np.random.random() < 0.95 else bptt / 2., 5)))
            lr0 = lm_opt.param_groups[0]['lr']
            lm_opt.param_groups[0]['lr'] = lr0 * seq_len / args.bptt

        # language modeling loss
        dis_x = []
        for i, ((lid, did), lm_x) in enumerate(zip(id_pairs, unlabeled)):
            if ptrs[i] + bptt + 1 > lm_x.size(0):
                ptrs[i] = 0
                model.reset(lid=lid, did=did)
            p = ptrs[i]
            xs = lm_x[p: p + bptt].t().contiguous()
            ys = lm_x[p + 1: p + 1 + bptt].t().contiguous()
            lm_raw_loss, lm_loss, hid = model.lm_loss(xs, ys, lid=lid, did=did,
                                                      return_h=True)
            loss = loss + lm_loss * args.lambd_lm
            if lid == 0 and did == 0:
                dis_x.append(hid[-1].mean(1))
            elif lid == 1 and did == 1:
                # re-encode target-language text with the source domain id (did=0),
                # presumably so the discriminator sees source- and target-language
                # features from the same domain pathway
                _, _, hid = model.lm_loss(xs, ys, lid=1, did=0, return_h=True)
                dis_x.append(hid[-1].mean(1))
            total_loss[i] += lm_raw_loss.item()
            ptrs[i] += bptt

        # language adversarial loss
        dis_x_rev = GradReverse.apply(torch.cat(dis_x, 0))
        dis_loss = crit(dis(dis_x_rev), dis_y)
        loss = loss + args.lambd_dis * dis_loss
        total_dis_loss += dis_loss.item()
        loss.backward()

        # sentiment classification loss
        try:
            xs, ys, ls = next(train_iter)
        except StopIteration:
            train_iter = iter(senti_train)
            xs, ys, ls = next(train_iter)
        clf_loss = crit(model(xs, ls, lid=0, did=0), ys)
        total_clf_loss += clf_loss.item()
        (args.lambd_clf * clf_loss).backward()

        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        if args.dis_clip > 0:
            for x in dis.parameters():
                x.data.clamp_(-args.dis_clip, args.dis_clip)
        lm_opt.step()
        dis_opt.step()

        if not args.mwe:
            lm_opt.param_groups[0]['lr'] = lr0

        if (step + 1) % args.log_interval == 0:
            total_loss /= args.log_interval
            total_clf_loss /= args.log_interval
            total_dis_loss /= args.log_interval
            elapsed = time.time() - start_time
            print('| step {:5d} | lr {:05.5f} | ms/batch {:7.2f} | lm_loss {:7.4f} '
                  '| avg_ppl {:7.2f} | clf {:7.4f} | dis {:7.4f} |'.format(
                      step, lm_opt.param_groups[0]['lr'],
                      elapsed * 1000 / args.log_interval,
                      total_loss.mean(), np.exp(total_loss).mean(),
                      total_clf_loss, total_dis_loss))
            total_loss[:], total_clf_loss, total_dis_loss = 0, 0, 0
            start_time = time.time()

        if (step + 1) % args.val_interval == 0:
            model.eval()
            with torch.no_grad():
                train_acc = evaluate(model, train_ds, 0, 0)
                val_acc = evaluate(model, val_ds, 1, 0)
                test_acc = evaluate(model, test_ds, 1, 0)
                bdi_acc = compute_nn_accuracy(model.encoder_weight(0),
                                              model.encoder_weight(1),
                                              lexicon, 10000, lexicon_size=lexsz)
                print_line()
                print('| step {:5d} | train {:.4f} | val {:.4f} | test {:.4f} '
                      '| bdi {:.4f} |'.format(step, train_acc, val_acc, test_acc, bdi_acc))
                print_line()
                print('saving model to {}'.format(model_path.replace('.pt', '_final.pt')))
                model_save(model, dis, lm_opt, dis_opt,
                           model_path.replace('.pt', '_final.pt'))
                if val_acc > best_acc:
                    print('saving model to {}'.format(model_path))
                    model_save(model, dis, lm_opt, dis_opt, model_path)
                    best_acc, final_test_acc = val_acc, test_acc
                print_line()
            model.train()
            start_time = time.time()

    print_line()
    print('Training ended with {} steps'.format(step + 1))
    print('Best val acc: {}->{} {:.4f}'.format(args.src, args.trg, best_acc))
    print('Test acc (w/ early stop): {}->{} {:.4f}'.format(args.src, args.trg, final_test_acc))
    print('Test acc (w/o early stop): {}->{} {:.4f}'.format(args.src, args.trg, test_acc))
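# ---------------------------------------------------------------------------
# batchify is a repo helper used above to lay out each unlabeled corpus for
# truncated BPTT. A minimal sketch under the usual AWD-LSTM-style assumption
# (a flat 1-D LongTensor of token ids reshaped into batch_size columns) is given
# below; the repo's actual implementation may differ in details.
# ---------------------------------------------------------------------------
def batchify_sketch(data, batch_size):
    """Trim a 1-D token-id tensor and reshape it to (n_steps, batch_size)."""
    n_steps = data.size(0) // batch_size
    data = data[:n_steps * batch_size]  # drop the ragged tail
    return data.view(batch_size, n_steps).t().contiguous()
# The training loops then read windows lm_x[p:p + bptt] (inputs) and
# lm_x[p + 1:p + 1 + bptt] (next-token targets) and transpose them to
# (batch_size, bptt) before calling model.lm_loss.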