Example #1
# Imports assumed by this example: standard library, NumPy and PyTorch; the
# remaining helpers (to_device, batchify, SentiDataset, XLXDClassifier,
# Discriminator, GradReverse, evaluate, load_lexicon, model_save, ...) are
# expected to come from the project's own modules.
import os
import time

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def train(args):
    # set random seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available() and args.cuda:
        torch.cuda.manual_seed(args.seed)

    print('Configuration:')
    print('\n'.join('\t{:15} {}'.format(k + ':', str(v)) for k, v in sorted(dict(vars(args)).items())))
    print()

    model_path = os.path.join(args.export, 'model.pt')
    config_path = os.path.join(args.export, 'config.json')
    export_config(args, config_path)
    check_path(model_path)

    ###############################################################################
    # Load data
    ###############################################################################

    n_trg = len(args.trg)
    lang2id = {lang: i for i, lang in enumerate(args.lang)}
    dom2id = {dom: i for i, dom in enumerate(args.dom)}
    src_id, dom_id = lang2id[args.src], dom2id[args.sup_dom]
    trg_ids = [lang2id[t] for t in args.trg]

    unlabeled_set = torch.load(args.unlabeled)
    train_set = torch.load(args.train)
    val_set = torch.load(args.val)
    test_set = torch.load(args.test)
    vocabs = [train_set[lang]['vocab'] for lang in args.lang]
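    # unlabeled[lang][dom]: batchified language-modeling data for every
    # (language, domain) pair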
    unlabeled = to_device([[batchify(unlabeled_set[lang][dom], args.batch_size) for dom in args.dom] for lang in args.lang], args.cuda)
    train_x, train_y, train_l = to_device(train_set[args.src][args.sup_dom], args.cuda)
    val_ds = [to_device(val_set[t][args.sup_dom], args.cuda) for t in args.trg]
    test_ds = [to_device(test_set[t][args.sup_dom], args.cuda) for t in args.trg]

    if args.sample_unlabeled > 0:
        print('Downsampling unlabeled set...')
        print()
        unlabeled = [[x[:(args.sample_unlabeled // args.batch_size)] for x in t] for t in unlabeled]
    if args.sample_train > 0:
        print('Downsampling training set...')
        print()
        train_x, train_y, train_l = sample([train_x, train_y, train_l], args.sample_train, True)

    senti_train = DataLoader(SentiDataset(train_x, train_y, train_l), batch_size=args.clf_batch_size)
    train_iter = iter(senti_train)
    train_ds = DataLoader(SentiDataset(train_x, train_y, train_l), batch_size=args.test_batch_size)
    val_ds = [DataLoader(SentiDataset(*ds), batch_size=args.test_batch_size) for ds in val_ds]
    test_ds = [DataLoader(SentiDataset(*ds), batch_size=args.test_batch_size) for ds in test_ds]

    lexicons = []
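    # MUSE bilingual dictionaries, used to measure word-translation (BDI)
    # accuracy between the source and each target embedding space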
    for tid, tlang in zip(trg_ids, args.trg):
        sv, tv = vocabs[src_id], vocabs[tid]
        lex, lexsz = load_lexicon('data/muse/{}-{}.0-5000.txt'.format(args.src, tlang), sv, tv)
        lexicons.append((lex, lexsz, tid))

    ###############################################################################
    # Build the model
    ###############################################################################
    if args.resume:
        model, dis, lm_opt, dis_opt = model_load(args.resume)

    else:
        model = XLXDClassifier(n_classes=2, clf_p=args.dropoutc, n_langs=len(args.lang), n_doms=len(args.dom),
                               vocab_sizes=list(map(len, vocabs)), emb_size=args.emb_dim, hidden_size=args.hid_dim,
                               num_layers=args.nlayers, num_share=args.nshare, tie_weights=args.tie_softmax,
                               output_p=args.dropouto, hidden_p=args.dropouth, input_p=args.dropouti, embed_p=args.dropoute,
                               weight_p=args.dropoutw, alpha=2, beta=1)
        dis = Discriminator(args.emb_dim, args.dis_hid_dim, len(args.lang), args.dis_nlayers, args.dropoutd)

        if args.mwe:
            # initialize each language's embedding layer with pretrained
            # multilingual word embeddings and freeze it
            for lid, (v, lang) in enumerate(zip(vocabs, args.lang)):
                x, count = load_vectors_with_vocab(args.mwe_path.format(lang), v, -1)
                model.encoders[lid].weight.data.copy_(torch.from_numpy(x))
                freeze_net(model.encoders[lid])

        params = [{'params': model.models.parameters(),  'lr': args.lr},
                  {'params': model.clfs.parameters(), 'lr': args.lr}]
        if args.optimizer == 'sgd':
            lm_opt = torch.optim.SGD(params, lr=args.lr, weight_decay=args.wdecay)
            dis_opt = torch.optim.SGD(dis.parameters(), lr=args.dis_lr, weight_decay=args.wdecay)
        if args.optimizer == 'adam':
            lm_opt = torch.optim.Adam(params, lr=args.lr, weight_decay=args.wdecay, betas=(args.beta1, 0.999))
            dis_opt = torch.optim.Adam(dis.parameters(), lr=args.dis_lr, weight_decay=args.wdecay, betas=(args.beta1, 0.999))

    crit = nn.CrossEntropyLoss()

    bs = args.batch_size
    n_doms = len(args.dom)
    n_langs = len(args.lang)
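    # discriminator targets: each language id repeated batch_size times, in the
    # same order the dis_x features are concatenated below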
    dis_y = to_device(torch.arange(n_langs).unsqueeze(-1).expand(n_langs, bs).contiguous().view(-1), args.cuda)

    if args.cuda:
        model.cuda(), dis.cuda(), crit.cuda()
    else:
        model.cpu(), dis.cpu(), crit.cpu()

    print('Parameters:')
    total_params = sum([np.prod(x.size()) for x in model.parameters()])
    print('\ttotal params:   {}'.format(total_params))
    print('\tparam list:     {}'.format(len(list(model.parameters()))))
    for name, x in model.named_parameters():
        print('\t' + name + '\t', tuple(x.size()))
    for name, x in dis.named_parameters():
        print('\t' + name + '\t', tuple(x.size()))
    print()

    ###############################################################################
    # Training code
    ###############################################################################

    bptt = args.bptt
    best_accs = {tlang: 0. for tlang in args.trg}
    final_test_accs = {tlang: 0. for tlang in args.trg}
    print('Training:')
    print_line()
    ptrs = np.zeros((len(args.lang), len(args.dom)), dtype=np.int64)  # pointers for reading unlabeled data, of shape (n_lang, n_dom)
    total_loss = np.zeros((len(args.lang), len(args.dom)))  # shape (n_lang, n_dom)
    total_clf_loss = 0
    total_dis_loss = 0
    start_time = time.time()
    model.train()
    model.reset()
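    # Each training step combines up to three objectives: language modeling on
    # the unlabeled data, a language-adversarial loss on the supervised domain,
    # and sentiment classification on labeled source-language data (the first
    # two are skipped when frozen multilingual embeddings, args.mwe, are used).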
    for step in range(args.max_steps):
        loss = 0
        lm_opt.zero_grad()
        dis_opt.zero_grad()

        if not args.mwe:
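            # AWD-LSTM-style variable-length BPTT: sample a length around bptt
            # and rescale the LM learning rate proportionally for this step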
            seq_len = max(5, int(np.random.normal(bptt if np.random.random() < 0.95 else bptt / 2., 5)))
            lr0 = lm_opt.param_groups[0]['lr']
            lm_opt.param_groups[0]['lr'] = lr0 * seq_len / args.bptt

            # language modeling loss
            dis_x = []
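            # run the LM over every (language, domain) stream; hidden states
            # from the supervised domain are kept for the language discriminator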
            for lid, t in enumerate(unlabeled):
                for did, lm_x in enumerate(t):
                    if ptrs[lid, did] + bptt + 1 > lm_x.size(0):
                        ptrs[lid, did] = 0
                        model.reset(lid=lid, did=did)
                    p = ptrs[lid, did]
                    xs = lm_x[p: p + bptt].t().contiguous()
                    ys = lm_x[p + 1: p + 1 + bptt].t().contiguous()
                    lm_raw_loss, lm_loss, hid = model.lm_loss(xs, ys, lid=lid, did=did, return_h=True)
                    loss = loss + lm_loss * args.lambd_lm
                    total_loss[lid, did] += lm_raw_loss.item()
                    ptrs[lid, did] += bptt
                    if did == dom_id:
                        dis_x.append(hid[-1].mean(1))

            # language adversarial loss
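            # GradReverse flips the gradient sign, so the encoder is pushed
            # toward language-invariant features while the discriminator tries
            # to tell the languages apart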
            dis_x_rev = GradReverse.apply(torch.cat(dis_x, 0))
            dis_loss = crit(dis(dis_x_rev), dis_y)
            loss = loss + args.lambd_dis * dis_loss
            total_dis_loss += dis_loss.item()
            loss.backward()

        # sentiment classification loss
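        # draw a labeled source-language batch, cycling the DataLoader when it
        # is exhausted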
        try:
            xs, ys, ls = next(train_iter)
        except StopIteration:
            train_iter = iter(senti_train)
            xs, ys, ls = next(train_iter)
        clf_loss = crit(model(xs, ls, src_id, dom_id), ys)
        total_clf_loss += clf_loss.item()
        (args.lambd_clf * clf_loss).backward()

        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
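        # optionally clamp discriminator weights to [-dis_clip, dis_clip] to
        # keep the adversarial training stable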
        if args.dis_clip > 0:
            for x in dis.parameters():
                x.data.clamp_(-args.dis_clip, args.dis_clip)
        lm_opt.step()
        dis_opt.step()
        if not args.mwe:
            lm_opt.param_groups[0]['lr'] = lr0

        if (step + 1) % args.log_interval == 0:
            total_loss /= args.log_interval
            total_clf_loss /= args.log_interval
            total_dis_loss /= args.log_interval
            elapsed = time.time() - start_time
            print('| step {:5d} | lr {:05.5f} | ms/batch {:7.2f} | lm_loss {:7.4f} | avg_ppl {:7.2f} | clf {:7.4f} | dis {:7.4f} |'.format(
                step, lm_opt.param_groups[0]['lr'], elapsed * 1000 / args.log_interval,
                total_loss.mean(), np.exp(total_loss).mean(), total_clf_loss, total_dis_loss))
            total_loss[:, :], total_clf_loss, total_dis_loss = 0, 0, 0
            start_time = time.time()

        if (step + 1) % args.val_interval == 0:
            model.eval()
            with torch.no_grad():
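                # source-language train accuracy, per-target val/test accuracy,
                # and word-translation (BDI) accuracy from the embedding tables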
                train_acc = evaluate(model, train_ds, src_id, dom_id)
                val_accs = [evaluate(model, ds, tid, dom_id) for tid, ds in zip(trg_ids, val_ds)]
                test_accs = [evaluate(model, ds, tid, dom_id) for tid, ds in zip(trg_ids, test_ds)]
                bdi_accs = [compute_nn_accuracy(model.encoder_weight(src_id),
                                                model.encoder_weight(tid),
                                                lexicon, 10000, lexicon_size=lexsz) for lexicon, lexsz, tid in lexicons]
                print_line()
                print(('| step {:5d} | train {:.4f} |' +
                       ' val' + ' {} {:.4f}' * n_trg + ' |' +
                       ' test' + ' {} {:.4f}' * n_trg + ' |' +
                       ' bdi' + ' {} {:.4f}' * n_trg + ' |').format(step, train_acc,
                                                                    *sum([[tlang, acc] for tlang, acc in zip(args.trg, val_accs)], []),
                                                                    *sum([[tlang, acc] for tlang, acc in zip(args.trg, test_accs)], []),
                                                                    *sum([[tlang, acc] for tlang, acc in zip(args.trg, bdi_accs)], [])))
                print_line()
                print('saving model to {}'.format(model_path.replace('.pt', '_final.pt')))
                model_save(model, dis, lm_opt, dis_opt, model_path.replace('.pt', '_final.pt'))
                for tlang, val_acc, test_acc in zip(args.trg, val_accs, test_accs):
                    if val_acc > best_accs[tlang]:
                        save_path = model_path.replace('.pt', '_{}.pt'.format(tlang))
                        print('saving {} model to {}'.format(tlang, save_path))
                        model_save(model, dis, lm_opt, dis_opt, save_path)
                        best_accs[tlang] = val_acc
                        final_test_accs[tlang] = test_acc
                print_line()
            model.train()
            start_time = time.time()

    print_line()
    print('Training ended with {} steps'.format(step + 1))
    print(('Best val acc:             ' + ' {} {:.4f}' * n_trg).format(*sum([[tlang, best_accs[tlang]] for tlang in args.trg], [])))
    print(('Test acc (w/ early stop): ' + ' {} {:.4f}' * n_trg).format(*sum([[tlang, final_test_accs[tlang]] for tlang in args.trg], [])))
    print(('Test acc (w/o early stop):' + ' {} {:.4f}' * n_trg).format(*sum([[tlang, acc] for tlang, acc in zip(args.trg, test_accs)], [])))
Example #2
def train(args):
    # set random seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available() and args.cuda:
        torch.cuda.manual_seed(args.seed)

    print('Configuration:')
    print('\n'.join('\t{:15} {}'.format(k + ':', str(v))
                    for k, v in sorted(dict(vars(args)).items())))
    print()

    model_path = os.path.join(args.export, 'model.pt')
    config_path = os.path.join(args.export, 'config.json')
    export_config(args, config_path)
    check_path(model_path)

    ###############################################################################
    # Load data
    ###############################################################################

    src_lang, src_dom = args.src.split('-')
    trg_lang, trg_dom = args.trg.split('-')
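    # args.src / args.trg are 'lang-dom' strings; three unlabeled streams are
    # used: (src lang, src dom), (src lang, trg dom) and (trg lang, trg dom),
    # with id_pairs giving each stream's (language id, domain id) in the model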
    lang_dom_pairs = [[src_lang, src_dom], [src_lang, trg_dom],
                      [trg_lang, trg_dom]]
    id_pairs = [[0, 0], [0, 1], [1, 1]]

    unlabeled_set = torch.load(args.unlabeled)
    train_set = torch.load(args.train)
    val_set = torch.load(args.val)
    test_set = torch.load(args.test)

    src_vocab = train_set[src_lang]['vocab']
    trg_vocab = train_set[trg_lang]['vocab']
    unlabeled = to_device([
        batchify(unlabeled_set[lang][dom], args.batch_size)
        for lang, dom in lang_dom_pairs
    ], args.cuda)
    train_x, train_y, train_l = to_device(train_set[src_lang][src_dom],
                                          args.cuda)
    val_x, val_y, val_l = to_device(val_set[trg_lang][trg_dom], args.cuda)
    test_x, test_y, test_l = to_device(test_set[trg_lang][trg_dom], args.cuda)

    if args.sample_unlabeled > 0:
        print('Downsampling unlabeled set...')
        print()
        unlabeled = [
            x[:(args.sample_unlabeled // args.batch_size)] for x in unlabeled
        ]
    if args.sample_train > 0:
        print('Downsampling training set...')
        print()
        train_x, train_y, train_l = sample([train_x, train_y, train_l],
                                           args.sample_train, True)

    senti_train = DataLoader(SentiDataset(train_x, train_y, train_l),
                             batch_size=args.clf_batch_size)
    train_iter = iter(senti_train)
    train_ds = DataLoader(SentiDataset(train_x, train_y, train_l),
                          batch_size=args.test_batch_size)
    val_ds = DataLoader(SentiDataset(val_x, val_y, val_l),
                        batch_size=args.test_batch_size)
    test_ds = DataLoader(SentiDataset(test_x, test_y, test_l),
                         batch_size=args.test_batch_size)
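    # MUSE bilingual dictionary (source -> target), used for the
    # word-translation (BDI) evaluation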
    lexicon, lexsz = load_lexicon(
        'data/muse/{}-{}.0-5000.txt'.format(src_lang, trg_lang), src_vocab,
        trg_vocab)

    ###############################################################################
    # Build the model
    ###############################################################################

    if args.resume:
        model, dis, lm_opt, dis_opt = model_load(args.resume)

    else:
        model = XLXDClassifier(n_classes=2,
                               clf_p=args.dropoutc,
                               n_langs=2,
                               n_doms=2,
                               vocab_sizes=[len(src_vocab),
                                            len(trg_vocab)],
                               emb_size=args.emb_dim,
                               hidden_size=args.hid_dim,
                               num_layers=args.nlayers,
                               num_share=args.nshare,
                               tie_weights=args.tie_softmax,
                               output_p=args.dropouto,
                               hidden_p=args.dropouth,
                               input_p=args.dropouti,
                               embed_p=args.dropoute,
                               weight_p=args.dropoutw,
                               alpha=2,
                               beta=1)

        dis = Discriminator(args.emb_dim, args.dis_hid_dim, 2,
                            args.dis_nlayers, args.dropoutd)

        if args.mwe:
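            # load pretrained cross-lingual word embeddings for both languages
            # and freeze the embedding layers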
            x, count = load_vectors_with_vocab(args.mwe_path.format(src_lang),
                                               src_vocab, -1)
            model.encoders[0].weight.data.copy_(torch.from_numpy(x))
            x, count = load_vectors_with_vocab(args.mwe_path.format(trg_lang),
                                               trg_vocab, -1)
            model.encoders[1].weight.data.copy_(torch.from_numpy(x))
            freeze_net(model.encoders)

        params = [{
            'params': model.models.parameters(),
            'lr': args.lr
        }, {
            'params': model.clfs.parameters(),
            'lr': args.lr
        }]
        if args.optimizer == 'sgd':
            lm_opt = torch.optim.SGD(params,
                                     lr=args.lr,
                                     weight_decay=args.wdecay)
            dis_opt = torch.optim.SGD(dis.parameters(),
                                      lr=args.lr,
                                      weight_decay=args.wdecay)
        if args.optimizer == 'adam':
            lm_opt = torch.optim.Adam(params,
                                      lr=args.lr,
                                      weight_decay=args.wdecay,
                                      betas=(args.beta1, 0.999))
            dis_opt = torch.optim.Adam(dis.parameters(),
                                       lr=args.lr,
                                       weight_decay=args.wdecay,
                                       betas=(args.beta1, 0.999))

    crit = nn.CrossEntropyLoss()
    bs = args.batch_size
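    # discriminator targets: the first batch of features comes from the source
    # language (label 0), the second from the target language (label 1)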
    dis_y = to_device(torch.tensor([0] * bs + [1] * bs), args.cuda)

    if args.cuda:
        model.cuda(), dis.cuda(), crit.cuda()
    else:
        model.cpu(), dis.cpu(), crit.cpu()

    print('Parameters:')
    total_params = sum([np.prod(x.size()) for x in model.parameters()])
    print('\ttotal params:   {}'.format(total_params))
    print('\tparam list:     {}'.format(len(list(model.parameters()))))
    for name, x in model.named_parameters():
        print('\t' + name + '\t', tuple(x.size()))
    for name, x in dis.named_parameters():
        print('\t' + name + '\t', tuple(x.size()))
    print()

    ###############################################################################
    # Training code
    ###############################################################################

    bptt = args.bptt
    best_acc = 0.
    final_test_acc = 0.
    print('Training:')
    print_line()
    p = 0
    ptrs = np.zeros(3, dtype=np.int64)
    total_loss = np.zeros(3)  # running LM loss, one entry per (lang, dom) stream
    total_clf_loss = 0
    total_dis_loss = 0
    start_time = time.time()
    model.train()
    model.reset()
    for step in range(args.max_steps):
        loss = 0
        lm_opt.zero_grad()
        dis_opt.zero_grad()

        if not args.mwe:
            seq_len = max(
                5,
                int(
                    np.random.normal(
                        bptt if np.random.random() < 0.95 else bptt / 2., 5)))
            lr0 = lm_opt.param_groups[0]['lr']
            lm_opt.param_groups[0]['lr'] = lr0 * seq_len / args.bptt

            # language modeling loss
            dis_x = []
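            # run the LM over the three unlabeled streams; only features
            # computed with the source-domain parameters (did=0) are collected
            # for the language discriminator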
            for i, ((lid, did), lm_x) in enumerate(zip(id_pairs, unlabeled)):
                if ptrs[i] + bptt + 1 > lm_x.size(0):
                    ptrs[i] = 0
                    model.reset(lid=lid, did=did)
                p = ptrs[i]
                xs = lm_x[p:p + bptt].t().contiguous()
                ys = lm_x[p + 1:p + 1 + bptt].t().contiguous()
                lm_raw_loss, lm_loss, hid = model.lm_loss(xs,
                                                          ys,
                                                          lid=lid,
                                                          did=did,
                                                          return_h=True)
                loss = loss + lm_loss * args.lambd_lm
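                # collect discriminator features from the source-domain pathway:
                # directly for the source-language stream, and via an extra
                # forward pass with did=0 for the target-language stream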
                if lid == 0 and did == 0:
                    dis_x.append(hid[-1].mean(1))
                elif lid == 1 and did == 1:
                    _, _, hid = model.lm_loss(xs,
                                              ys,
                                              lid=1,
                                              did=0,
                                              return_h=True)
                    dis_x.append(hid[-1].mean(1))
                total_loss[i] += lm_raw_loss.item()
                ptrs[i] += bptt

            # language adversarial loss
            dis_x_rev = GradReverse.apply(torch.cat(dis_x, 0))
            dis_loss = crit(dis(dis_x_rev), dis_y)
            loss = loss + args.lambd_dis * dis_loss
            total_dis_loss += dis_loss.item()
            loss.backward()

        # sentiment classification loss
        try:
            xs, ys, ls = next(train_iter)
        except StopIteration:
            train_iter = iter(senti_train)
            xs, ys, ls = next(train_iter)
        clf_loss = crit(model(xs, ls, lid=0, did=0), ys)
        total_clf_loss += clf_loss.item()
        (args.lambd_clf * clf_loss).backward()

        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        if args.dis_clip > 0:
            for x in dis.parameters():
                x.data.clamp_(-args.dis_clip, args.dis_clip)
        lm_opt.step()
        dis_opt.step()
        if not args.mwe:
            lm_opt.param_groups[0]['lr'] = lr0

        if (step + 1) % args.log_interval == 0:
            total_loss /= args.log_interval
            total_clf_loss /= args.log_interval
            total_dis_loss /= args.log_interval
            elapsed = time.time() - start_time
            print(
                '| step {:5d} | lr {:05.5f} | ms/batch {:7.2f} | lm_loss {:7.4f} | avg_ppl {:7.2f} | clf {:7.4f} | dis {:7.4f} |'
                .format(step, lm_opt.param_groups[0]['lr'],
                        elapsed * 1000 / args.log_interval, total_loss.mean(),
                        np.exp(total_loss).mean(), total_clf_loss,
                        total_dis_loss))
            total_loss[:], total_clf_loss, total_dis_loss = 0, 0, 0
            start_time = time.time()

        if (step + 1) % args.val_interval == 0:
            model.eval()
            with torch.no_grad():
                train_acc = evaluate(model, train_ds, 0, 0)
                val_acc = evaluate(model, val_ds, 1, 0)
                test_acc = evaluate(model, test_ds, 1, 0)
                bdi_acc = compute_nn_accuracy(model.encoder_weight(0),
                                              model.encoder_weight(1),
                                              lexicon,
                                              10000,
                                              lexicon_size=lexsz)
                print_line()
                print(
                    '| step {:5d} | train {:.4f} |  val {:.4f} | test {:.4f} | bdi {:.4f} |'
                    .format(step, train_acc, val_acc, test_acc, bdi_acc))
                print_line()
                print('saving model to {}'.format(
                    model_path.replace('.pt', '_final.pt')))
                model_save(model, dis, lm_opt, dis_opt,
                           model_path.replace('.pt', '_final.pt'))
                if val_acc > best_acc:
                    print('saving model to {}'.format(model_path))
                    model_save(model, dis, lm_opt, dis_opt, model_path)
                    best_acc, final_test_acc = val_acc, test_acc
                print_line()
            model.train()
            start_time = time.time()

    print_line()
    print('Training ended with {} steps'.format(step + 1))
    print('Best val acc:              {}->{} {:.4f}'.format(
        args.src, args.trg, best_acc))
    print('Test acc (w/ early stop):  {}->{} {:.4f}'.format(
        args.src, args.trg, final_test_acc))
    print('Test acc (w/o early stop): {}->{} {:.4f}'.format(
        args.src, args.trg, test_acc))