    train_iter = iter(train_loader)
    num_iter = len(train_loader)

    model = AttenLSTM(vocab, len(meshlabel_to_ix), char_to_ix)

    if torch.cuda.is_available():
        model.cuda(opt.gpu)

    optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate)
    criterion = nn.CrossEntropyLoss()

    if opt.fine_tune:
        utils.unfreeze_net(model.embedding)
    else:
        utils.freeze_net(model.embedding)
    #pre-training dictionary instance
    if opt.pretraining:

        dict_iter = iter(dict_loader)
        dict_num_iter = len(dict_loader)

        #start training dictionary
        logging.info("batch_size: %s,  dict_num_iter %s, train num_iter %s" %
                     (str(opt.batch_size), str(dict_num_iter), str(num_iter)))
        for epoch in range(opt.dict_iteration):
            epoch_start = time.time()

            # sum_dict_cost = 0.0
            correct_1, total_1 = 0, 0
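
The fragment above is cut off mid-function. The utils.freeze_net/utils.unfreeze_net helpers it calls (and which recur in every example below) are not shown on this page; a minimal sketch of what such helpers typically do, offered as an assumption rather than the original utils code:

def freeze_net(net):
    # Disable gradients for every parameter; tolerate None for optional nets.
    if net is None:
        return
    for p in net.parameters():
        p.requires_grad = False

def unfreeze_net(net):
    # Re-enable gradients for every parameter; tolerate None for optional nets.
    if net is None:
        return
    for p in net.parameters():
        p.requires_grad = True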
Example #2
def train(vocabs, char_vocab, tag_vocab, train_sets, dev_sets, test_sets, unlabeled_sets):
    """
    train_sets, dev_sets, test_sets: dict[lang] -> AmazonDataset
    For unlabeled langs, no train_sets are available
    """
    # dataset loaders
    train_loaders, unlabeled_loaders = {}, {}
    train_iters, unlabeled_iters, d_unlabeled_iters = {}, {}, {}
    dev_loaders, test_loaders = {}, {}
    my_collate = utils.sorted_collate if opt.model=='lstm' else utils.unsorted_collate
    for lang in opt.langs:
        train_loaders[lang] = DataLoader(train_sets[lang],
                opt.batch_size, shuffle=True, collate_fn = my_collate)
        train_iters[lang] = iter(train_loaders[lang])
    for lang in opt.dev_langs:
        dev_loaders[lang] = DataLoader(dev_sets[lang],
                opt.batch_size, shuffle=False, collate_fn = my_collate)
        test_loaders[lang] = DataLoader(test_sets[lang],
                opt.batch_size, shuffle=False, collate_fn = my_collate)
    for lang in opt.all_langs:
        if lang in opt.unlabeled_langs:
            uset = unlabeled_sets[lang]
        else:
            # for labeled langs, consider which data to use as unlabeled set
            if opt.unlabeled_data == 'both':
                uset = ConcatDataset([train_sets[lang], unlabeled_sets[lang]])
            elif opt.unlabeled_data == 'unlabeled':
                uset = unlabeled_sets[lang]
            elif opt.unlabeled_data == 'train':
                uset = train_sets[lang]
            else:
                raise Exception(f'Unknown options for the unlabeled data usage: {opt.unlabeled_data}')
        unlabeled_loaders[lang] = DataLoader(uset,
                opt.batch_size, shuffle=True, collate_fn = my_collate)
        unlabeled_iters[lang] = iter(unlabeled_loaders[lang])
        d_unlabeled_iters[lang] = iter(unlabeled_loaders[lang])

    # embeddings
    emb = MultiLangWordEmb(vocabs, char_vocab, opt.use_wordemb, opt.use_charemb).to(opt.device)
    # models
    F_s = None
    F_p = None
    C, D = None, None
    num_experts = len(opt.langs)+1 if opt.expert_sp else len(opt.langs)
    if opt.model.lower() == 'lstm':
        if opt.shared_hidden_size > 0:
            F_s = LSTMFeatureExtractor(opt.total_emb_size, opt.F_layers, opt.shared_hidden_size,
                                       opt.word_dropout, opt.dropout, opt.bdrnn)
        if opt.private_hidden_size > 0:
            if not opt.concat_sp:
                assert opt.shared_hidden_size == opt.private_hidden_size, "shared dim != private dim when using add_sp!"
            F_p = nn.Sequential(
                    LSTMFeatureExtractor(opt.total_emb_size, opt.F_layers, opt.private_hidden_size,
                            opt.word_dropout, opt.dropout, opt.bdrnn),
                    MixtureOfExperts(opt.MoE_layers, opt.private_hidden_size,
                            len(opt.langs), opt.private_hidden_size,
                            opt.private_hidden_size, opt.dropout, opt.MoE_bn, False)
                    )
    else:
        raise Exception(f'Unknown model architecture {opt.model}')

    if opt.C_MoE:
        C = SpMixtureOfExperts(opt.C_layers, opt.shared_hidden_size, opt.private_hidden_size, opt.concat_sp,
                num_experts, opt.shared_hidden_size + opt.private_hidden_size, len(tag_vocab),
                opt.mlp_dropout, opt.C_bn)
    else:
        C = SpMlpTagger(opt.C_layers, opt.shared_hidden_size, opt.private_hidden_size, opt.concat_sp,
                opt.shared_hidden_size + opt.private_hidden_size, len(tag_vocab),
                opt.mlp_dropout, opt.C_bn)
    if opt.shared_hidden_size > 0 and opt.n_critic > 0:
        if opt.D_model.lower() == 'lstm':
            d_args = {
                'num_layers': opt.D_lstm_layers,
                'input_size': opt.shared_hidden_size,
                'hidden_size': opt.shared_hidden_size,
                'word_dropout': opt.D_word_dropout,
                'dropout': opt.D_dropout,
                'bdrnn': opt.D_bdrnn,
                'attn_type': opt.D_attn
            }
        elif opt.D_model.lower() == 'cnn':
            d_args = {
                'num_layers': 1,
                'input_size': opt.shared_hidden_size,
                'hidden_size': opt.shared_hidden_size,
                'kernel_num': opt.D_kernel_num,
                'kernel_sizes': opt.D_kernel_sizes,
                'word_dropout': opt.D_word_dropout,
                'dropout': opt.D_dropout
            }
        else:
            d_args = None

        if opt.D_model.lower() == 'mlp':
            D = MLPLanguageDiscriminator(opt.D_layers, opt.shared_hidden_size,
                    opt.shared_hidden_size, len(opt.all_langs), opt.loss, opt.D_dropout, opt.D_bn)
        else:
            D = LanguageDiscriminator(opt.D_model, opt.D_layers,
                    opt.shared_hidden_size, opt.shared_hidden_size,
                    len(opt.all_langs), opt.D_dropout, opt.D_bn, d_args)
    if opt.use_data_parallel:
        F_s = nn.DataParallel(F_s).to(opt.device) if F_s else None
        C = nn.DataParallel(C).to(opt.device)
        D = nn.DataParallel(D).to(opt.device) if D else None
    else:
        F_s = F_s.to(opt.device) if F_s else None
        C = C.to(opt.device)
        D = D.to(opt.device) if D else None
    if F_p:
        if opt.use_data_parallel:
            F_p = nn.DataParallel(F_p).to(opt.device)
        else:
            F_p = F_p.to(opt.device)
    # optimizers
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, itertools.chain(*map(list,
        [emb.parameters(), F_s.parameters() if F_s else [],
        C.parameters(), F_p.parameters() if F_p else []]))),
        lr=opt.learning_rate,
        weight_decay=opt.weight_decay)
    if D:
        optimizerD = optim.Adam(D.parameters(), lr=opt.D_learning_rate, weight_decay=opt.D_weight_decay)

    # testing
    if opt.test_only:
        log.info(f'Loading model from {opt.model_save_file}...')
        emb.load_state_dict(torch.load(os.path.join(opt.model_save_file,
            'net_emb.pth')))
        if F_s:
            F_s.load_state_dict(torch.load(os.path.join(opt.model_save_file,
                'netF_s.pth')))
        # F_p is a single shared module saved once as net_F_p.pth, so load it
        # once rather than once per language
        if F_p:
            F_p.load_state_dict(torch.load(os.path.join(opt.model_save_file,
                'net_F_p.pth')))
        C.load_state_dict(torch.load(os.path.join(opt.model_save_file,
            'netC.pth')))
        if D:
            D.load_state_dict(torch.load(os.path.join(opt.model_save_file,
                'netD.pth')))

        log.info('Evaluating validation sets:')
        acc = {}
        log.info(dev_loaders)
        log.info(vocabs)
        for lang in opt.all_langs:
            acc[lang] = evaluate(f'{lang}_dev', dev_loaders[lang], vocabs[lang], tag_vocab,
                    emb, lang, F_s, F_p, C)
        avg_acc = sum([acc[d] for d in opt.dev_langs]) / len(opt.dev_langs)
        log.info(f'Average validation accuracy: {avg_acc}')
        log.info('Evaluating test sets:')
        test_acc = {}
        for lang in opt.all_langs:
            test_acc[lang] = evaluate(f'{lang}_test', test_loaders[lang], vocabs[lang], tag_vocab,
                    emb, lang, F_s, F_p, C)
        avg_test_acc = sum([test_acc[d] for d in opt.dev_langs]) / len(opt.dev_langs)
        log.info(f'Average test accuracy: {avg_test_acc}')
        return {'valid': acc, 'test': test_acc}

    # training
    best_acc, best_avg_acc = defaultdict(float), 0.0
    epochs_since_decay = 0
    # lambda scheduling
    if opt.lambd > 0 and opt.lambd_schedule:
        opt.lambd_orig = opt.lambd
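    # one "epoch" is the geometric mean of the per-language loader lengths, so
    # languages with very different corpus sizes get balanced coverage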
    num_iter = int(utils.gmean([len(train_loaders[l]) for l in opt.langs]))
    # adapt max_epoch
    if opt.max_epoch > 0 and num_iter * opt.max_epoch < 15000:
        opt.max_epoch = 15000 // num_iter
        log.info(f"Setting max_epoch to {opt.max_epoch}")
    for epoch in range(opt.max_epoch):
        emb.train()
        if F_s:
            F_s.train()
        C.train()
        if D:
            D.train()
        if F_p:
            F_p.train()
            
        # lambda scheduling
        if hasattr(opt, 'lambd_orig') and opt.lambd_schedule:
            if epoch == 0:
                opt.lambd = opt.lambd_orig
            elif epoch == 5:
                opt.lambd = 10 * opt.lambd_orig
            elif epoch == 15:
                opt.lambd = 100 * opt.lambd_orig
            log.info(f'Scheduling lambda = {opt.lambd}')

        # training accuracy
        correct, total = defaultdict(int), defaultdict(int)
        gate_correct = defaultdict(int)
        c_gate_correct = defaultdict(int)
        # D accuracy
        d_correct, d_total = 0, 0
        for i in tqdm(range(num_iter), ascii=True):
            # D iterations
            if opt.shared_hidden_size > 0:
                utils.freeze_net(emb)
                utils.freeze_net(F_s)
                utils.freeze_net(F_p)
                utils.freeze_net(C)
                utils.unfreeze_net(D)
                # WGAN n_critic trick since D trains slower
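                # (extra critic updates early in training and every 500 steps,
                # mirroring the schedule from the WGAN training code)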
                n_critic = opt.n_critic
                if opt.wgan_trick:
                    if opt.n_critic > 0 and ((epoch == 0 and i < 25) or i % 500 == 0):
                        n_critic = 100

                for _ in range(n_critic):
                    D.zero_grad()
                    loss_d = {}
                    lang_features = {}
                    # train on both labeled and unlabeled langs
                    for lang in opt.all_langs:
                        # targets not used
                        d_inputs, _ = utils.endless_get_next_batch(
                                unlabeled_loaders, d_unlabeled_iters, lang)
                        d_inputs, d_lengths, mask, d_chars, d_char_lengths = d_inputs
                        d_embeds = emb(lang, d_inputs, d_chars, d_char_lengths)
                        shared_feat = F_s((d_embeds, d_lengths))
                        if opt.grad_penalty != 'none':
                            lang_features[lang] = shared_feat.detach()
                        if opt.D_model.lower() == 'mlp':
                            d_outputs = D(shared_feat)
                            # if token-level D, we can reuse the gate label generator
                            d_targets = utils.get_gate_label(d_outputs, lang, mask, False, all_langs=True)
                            d_total += torch.sum(d_lengths).item()
                        else:
                            d_outputs = D((shared_feat, d_lengths))
                            d_targets = utils.get_lang_label(opt.loss, lang, len(d_lengths))
                            d_total += len(d_lengths)
                        # D accuracy
                        _, pred = torch.max(d_outputs, -1)
                        # d_total += len(d_lengths)
                        d_correct += (pred==d_targets).sum().item()
                        if opt.use_data_parallel:
                            l_d = functional.nll_loss(d_outputs.view(-1, D.module.num_langs),
                                    d_targets.view(-1), ignore_index=-1)
                        else:
                            l_d = functional.nll_loss(d_outputs.view(-1, D.num_langs),
                                    d_targets.view(-1), ignore_index=-1)

                        l_d.backward()
                        loss_d[lang] = l_d.item()
                    # gradient penalty
                    if opt.grad_penalty != 'none':
                        gp = utils.calc_gradient_penalty(D, lang_features,
                                onesided=opt.onesided_gp, interpolate=(opt.grad_penalty=='wgan'))
                        gp.backward()
                    optimizerD.step()

            # F&C iteration
            utils.unfreeze_net(emb)
            if opt.use_wordemb and opt.fix_emb:
                for lang in emb.langs:
                    emb.wordembs[lang].weight.requires_grad = False
            if opt.use_charemb and opt.fix_charemb:
                emb.charemb.weight.requires_grad = False
            utils.unfreeze_net(F_s)
            utils.unfreeze_net(F_p)
            utils.unfreeze_net(C)
            utils.freeze_net(D)
            emb.zero_grad()
            if F_s:
                F_s.zero_grad()
            if F_p:
                F_p.zero_grad()
            C.zero_grad()
            # optimizer.zero_grad()
            for lang in opt.langs:
                inputs, targets = utils.endless_get_next_batch(
                        train_loaders, train_iters, lang)
                inputs, lengths, mask, chars, char_lengths = inputs
                bs, seq_len = inputs.size()
                embeds = emb(lang, inputs, chars, char_lengths)
                shared_feat, private_feat = None, None
                if opt.shared_hidden_size > 0:
                    shared_feat = F_s((embeds, lengths))
                if opt.private_hidden_size > 0:
                    private_feat, gate_outputs = F_p((embeds, lengths))
                if opt.C_MoE:
                    c_outputs, c_gate_outputs = C((shared_feat, private_feat))
                else:
                    c_outputs = C((shared_feat, private_feat))
                # targets are padded with -1
                l_c = functional.nll_loss(c_outputs.view(bs*seq_len, -1),
                        targets.view(-1), ignore_index=-1)
                # gate loss
                if F_p:
                    gate_targets = utils.get_gate_label(gate_outputs, lang, mask, False)
                    l_gate = functional.cross_entropy(gate_outputs.view(bs*seq_len, -1),
                            gate_targets.view(-1), ignore_index=-1)
                    l_c += opt.gate_loss_weight * l_gate
                    _, gate_pred = torch.max(gate_outputs.view(bs*seq_len, -1), -1)
                    gate_correct[lang] += (gate_pred == gate_targets.view(-1)).sum().item()
                if opt.C_MoE and opt.C_gate_loss_weight > 0:
                    c_gate_targets = utils.get_gate_label(c_gate_outputs, lang, mask, opt.expert_sp)
                    _, c_gate_pred = torch.max(c_gate_outputs.view(bs*seq_len, -1), -1)
                    if opt.expert_sp:
                        l_c_gate = functional.binary_cross_entropy_with_logits(
                                mask.unsqueeze(-1) * c_gate_outputs, c_gate_targets)
                        c_gate_correct[lang] += torch.index_select(c_gate_targets.view(bs*seq_len, -1),
                                -1, c_gate_pred.view(bs*seq_len)).sum().item()
                    else:
                        l_c_gate = functional.cross_entropy(c_gate_outputs.view(bs*seq_len, -1),
                                c_gate_targets.view(-1), ignore_index=-1)
                        c_gate_correct[lang] += (c_gate_pred == c_gate_targets.view(-1)).sum().item()
                    l_c += opt.C_gate_loss_weight * l_c_gate
                l_c.backward()
                _, pred = torch.max(c_outputs, -1)
                total[lang] += torch.sum(lengths).item()
                correct[lang] += (pred == targets).sum().item()

            # update F with D gradients on all langs
            if D:
                for lang in opt.all_langs:
                    inputs, _ = utils.endless_get_next_batch(
                            unlabeled_loaders, unlabeled_iters, lang)
                    inputs, lengths, mask, chars, char_lengths = inputs
                    embeds = emb(lang, inputs, chars, char_lengths)
                    shared_feat = F_s((embeds, lengths))
                    # d_outputs = D((shared_feat, lengths))
                    if opt.D_model.lower() == 'mlp':
                        d_outputs = D(shared_feat)
                        # if token-level D, we can reuse the gate label generator
                        d_targets = utils.get_gate_label(d_outputs, lang, mask, False, all_langs=True)
                    else:
                        d_outputs = D((shared_feat, lengths))
                        d_targets = utils.get_lang_label(opt.loss, lang, len(lengths))
                    if opt.use_data_parallel:
                        l_d = functional.nll_loss(d_outputs.view(-1, D.module.num_langs),
                                d_targets.view(-1), ignore_index=-1)
                    else:
                        l_d = functional.nll_loss(d_outputs.view(-1, D.num_langs),
                                d_targets.view(-1), ignore_index=-1)
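                    # negating the D loss (scaled by lambd) makes F_s maximize
                    # discriminator error -- an explicit gradient-reversal step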
                    if opt.lambd > 0:
                        l_d *= -opt.lambd
                    l_d.backward()

            optimizer.step()

        # end of epoch
        log.info('Ending epoch {}'.format(epoch+1))
        if d_total > 0:
            log.info('D Training Accuracy: {}%'.format(100.0*d_correct/d_total))
        log.info('Training accuracy:')
        log.info('\t'.join(opt.langs))
        log.info('\t'.join([str(100.0*correct[d]/total[d]) for d in opt.langs]))
        log.info('Gate accuracy:')
        log.info('\t'.join([str(100.0*gate_correct[d]/total[d]) for d in opt.langs]))
        log.info('Tagger Gate accuracy:')
        log.info('\t'.join([str(100.0*c_gate_correct[d]/total[d]) for d in opt.langs]))
        log.info('Evaluating validation sets:')
        acc = {}
        for lang in opt.dev_langs:
            acc[lang] = evaluate(f'{lang}_dev', dev_loaders[lang], vocabs[lang], tag_vocab,
                    emb, lang, F_s, F_p, C)
        avg_acc = sum([acc[d] for d in opt.dev_langs]) / len(opt.dev_langs)
        log.info(f'Average validation accuracy: {avg_acc}')
        log.info('Evaluating test sets:')
        test_acc = {}
        for lang in opt.dev_langs:
            test_acc[lang] = evaluate(f'{lang}_test', test_loaders[lang], vocabs[lang], tag_vocab,
                    emb, lang, F_s, F_p, C)
        avg_test_acc = sum([test_acc[d] for d in opt.dev_langs]) / len(opt.dev_langs)
        log.info(f'Average test accuracy: {avg_test_acc}')

        if avg_acc > best_avg_acc:
            epochs_since_decay = 0
            log.info(f'New best average validation accuracy: {avg_acc}')
            best_acc['valid'] = acc
            best_acc['test'] = test_acc
            best_avg_acc = avg_acc
            with open(os.path.join(opt.model_save_file, 'options.pkl'), 'wb') as ouf:
                pickle.dump(opt, ouf)
            if F_s:
                torch.save(F_s.state_dict(),
                        '{}/netF_s.pth'.format(opt.model_save_file))
            torch.save(emb.state_dict(),
                    '{}/net_emb.pth'.format(opt.model_save_file))
            if F_p:
                torch.save(F_p.state_dict(),
                        '{}/net_F_p.pth'.format(opt.model_save_file))
            torch.save(C.state_dict(),
                    '{}/netC.pth'.format(opt.model_save_file))
            if D:
                torch.save(D.state_dict(),
                        '{}/netD.pth'.format(opt.model_save_file))
        else:
            epochs_since_decay += 1
            if opt.lr_decay < 1 and epochs_since_decay >= opt.lr_decay_epochs:
                epochs_since_decay = 0
                old_lr = optimizer.param_groups[0]['lr']
                optimizer.param_groups[0]['lr'] = old_lr * opt.lr_decay
                log.info(f'Decreasing LR to {old_lr * opt.lr_decay}')

    # end of training
    log.info(f'Best average validation accuracy: {best_avg_acc}')
    return best_acc
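
All of these training loops pull batches through utils.endless_get_next_batch, which keeps DataLoader iterators alive across epochs. A minimal sketch, assuming the dict-of-loaders signature used in this example (a later example also calls it with a single loader/iterator pair, which would need a small variant); this is not the original utils implementation:

def endless_get_next_batch(loaders, iters, key):
    # Return the next batch for `key`, transparently re-creating the
    # iterator from its DataLoader once it is exhausted.
    try:
        batch = next(iters[key])
    except StopIteration:
        iters[key] = iter(loaders[key])
        batch = next(iters[key])
    return batch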
Example #3
def train(vocab, train_sets, dev_sets, test_sets, unlabeled_sets):
    """
    train_sets, dev_sets, test_sets: dict[domain] -> AmazonDataset
    For unlabeled domains, no train_sets are available
    """
    # dataset loaders
    train_loaders, unlabeled_loaders = {}, {}
    train_iters, unlabeled_iters = {}, {}
    dev_loaders, test_loaders = {}, {}
    my_collate = utils.sorted_collate if opt.model == 'lstm' else utils.unsorted_collate
    for domain in opt.domains:
        train_loaders[domain] = DataLoader(train_sets[domain],
                                           opt.batch_size, shuffle=True, collate_fn=my_collate)
        train_iters[domain] = iter(train_loaders[domain])
    for domain in opt.dev_domains:
        dev_loaders[domain] = DataLoader(dev_sets[domain],
                                         opt.batch_size, shuffle=False, collate_fn=my_collate)
        test_loaders[domain] = DataLoader(test_sets[domain],
                                          opt.batch_size, shuffle=False, collate_fn=my_collate)
    for domain in opt.all_domains:
        if domain in opt.unlabeled_domains:
            uset = unlabeled_sets[domain]
        else:
            # for labeled domains, consider which data to use as unlabeled set
            if opt.unlabeled_data == 'both':
                uset = ConcatDataset([train_sets[domain], unlabeled_sets[domain]])
            elif opt.unlabeled_data == 'unlabeled':
                uset = unlabeled_sets[domain]
            elif opt.unlabeled_data == 'train':
                uset = train_sets[domain]
            else:
                raise Exception('Unknown options for the unlabeled data usage: {}'.format(opt.unlabeled_data))
        unlabeled_loaders[domain] = DataLoader(uset,
                                               opt.batch_size, shuffle=True, collate_fn=my_collate)
        unlabeled_iters[domain] = iter(unlabeled_loaders[domain])

    # models
    F_s = None
    C, D = None, None
    if opt.model.lower() == 'dan':
        F_s = DanFeatureExtractor(vocab, opt.F_layers, opt.shared_hidden_size,
                                  opt.sum_pooling, opt.dropout, opt.F_bn)
    elif opt.model.lower() == 'lstm':
        F_s = LSTMFeatureExtractor(vocab, opt.F_layers, opt.shared_hidden_size,
                                   opt.dropout, opt.bdrnn, opt.attn)
    elif opt.model.lower() == 'cnn':
        F_s = CNNFeatureExtractor(vocab, opt.F_layers, opt.shared_hidden_size,
                                  opt.kernel_num, opt.kernel_sizes, opt.dropout)
    else:
        raise Exception('Unknown model architecture {}'.format(opt.model))

    C = SentimentClassifier(opt.C_layers, opt.shared_hidden_size + opt.domain_hidden_size,
                            opt.shared_hidden_size + opt.domain_hidden_size, opt.num_labels,
                            opt.dropout, opt.C_bn)
    D = DomainClassifier(opt.D_layers, opt.shared_hidden_size, opt.shared_hidden_size,
                         len(opt.all_domains), opt.loss, opt.dropout, opt.D_bn)

    F_s, C, D = F_s.to(opt.device), C.to(opt.device), D.to(opt.device)
    # optimizers
    optimizer = optim.Adam(itertools.chain(
        *map(list, [F_s.parameters() if F_s else [], C.parameters()])),
                           lr=opt.learning_rate)
    optimizerD = optim.Adam(D.parameters(), lr=opt.D_learning_rate)


    # training
    best_acc, best_avg_acc = defaultdict(float), 0.0
    for epoch in range(opt.max_epoch):
        F_s.train()
        C.train()
        D.train()

        # training accuracy
        correct, total = defaultdict(int), defaultdict(int)
        # D accuracy
        d_correct, d_total = 0, 0
        # conceptually view 1 epoch as 1 epoch of the first domain
        num_iter = len(train_loaders[opt.domains[0]])
        for i in tqdm(range(num_iter)):
            # D iterations
            utils.freeze_net(F_s)
            utils.freeze_net(C)
            utils.unfreeze_net(D)
            # WGAN n_critic trick since D trains slower
            n_critic = opt.n_critic
            if opt.wgan_trick:
                if opt.n_critic > 0 and ((epoch == 0 and i < 25) or i % 500 == 0):
                    n_critic = 100

            for _ in range(n_critic):
                D.zero_grad()
                loss_d = {}
                # train on both labeled and unlabeled domains
                for domain in opt.all_domains:
                    # targets not used
                    d_inputs, _ = utils.endless_get_next_batch(
                        unlabeled_loaders, unlabeled_iters, domain)
                    d_targets = utils.get_domain_label(opt.loss, domain, len(d_inputs[1]))
                    shared_feat = F_s(d_inputs)
                    d_outputs = D(shared_feat)
                    # D accuracy
                    _, pred = torch.max(d_outputs, 1)
                    d_total += len(d_inputs[1])
                    if opt.loss.lower() == 'l2':
                        _, tgt_indices = torch.max(d_targets, 1)
                        d_correct += (pred == tgt_indices).sum().item()
                        l_d = functional.mse_loss(d_outputs, d_targets)
                        l_d.backward()
                    else:
                        d_correct += (pred == d_targets).sum().item()
                        l_d = functional.nll_loss(d_outputs, d_targets)
                        l_d.backward()
                    loss_d[domain] = l_d.item()
                optimizerD.step()

            # F&C iteration
            utils.unfreeze_net(F_s)
            utils.unfreeze_net(C)
            utils.freeze_net(D)
            if opt.fix_emb:
                utils.freeze_net(F_s.word_emb)
            F_s.zero_grad()
            C.zero_grad()
            for domain in opt.domains:
                inputs, targets = utils.endless_get_next_batch(
                    train_loaders, train_iters, domain)
                targets = targets.to(opt.device)
                shared_feat = F_s(inputs)
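                # this variant trains no private extractors, so C's input is
                # padded with a zero "domain feature" to keep its width fixed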
                domain_feat = torch.zeros(len(targets), opt.domain_hidden_size).to(opt.device)
                features = torch.cat((shared_feat, domain_feat), dim=1)
                c_outputs = C(features)
                l_c = functional.nll_loss(c_outputs, targets)
                l_c.backward(retain_graph=True)
                _, pred = torch.max(c_outputs, 1)
                total[domain] += targets.size(0)
                correct[domain] += (pred == targets).sum().item()
            # update F with D gradients on all domains
            for domain in opt.all_domains:
                d_inputs, _ = utils.endless_get_next_batch(
                    unlabeled_loaders, unlabeled_iters, domain)
                shared_feat = F_s(d_inputs)
                d_outputs = D(shared_feat)
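                # three confusion variants: 'gr' negates the D loss (gradient
                # reversal); 'bs' and 'l2' pull D's outputs toward a uniform
                # domain distribution with KL and MSE respectively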
                if opt.loss.lower() == 'gr':
                    d_targets = utils.get_domain_label(opt.loss, domain, len(d_inputs[1]))
                    l_d = functional.nll_loss(d_outputs, d_targets)
                    if opt.lambd > 0:
                        l_d *= -opt.lambd
                elif opt.loss.lower() == 'bs':
                    d_targets = utils.get_random_domain_label(opt.loss, len(d_inputs[1]))
                    l_d = functional.kl_div(d_outputs, d_targets, size_average=False)
                    if opt.lambd > 0:
                        l_d *= opt.lambd
                elif opt.loss.lower() == 'l2':
                    d_targets = utils.get_random_domain_label(opt.loss, len(d_inputs[1]))
                    l_d = functional.mse_loss(d_outputs, d_targets)
                    if opt.lambd > 0:
                        l_d *= opt.lambd
                l_d.backward()

            optimizer.step()

        # end of epoch
        log.info('Ending epoch {}'.format(epoch + 1))
        if d_total > 0:
            log.info('D Training Accuracy: {}%'.format(100.0 * d_correct / d_total))
        log.info('Training accuracy:')
        log.info('\t'.join(opt.domains))
        log.info('\t'.join([str(100.0 * correct[d] / total[d]) for d in opt.domains]))
        log.info('Evaluating validation sets:')
        acc = {}
        for domain in opt.dev_domains:
            acc[domain] = evaluate(domain, dev_loaders[domain],
                                   F_s, None, C)
        avg_acc = sum([acc[d] for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info('Average validation accuracy: {}'.format(avg_acc))
        log.info('Evaluating test sets:')
        test_acc = {}
        for domain in opt.dev_domains:
            test_acc[domain] = evaluate(domain, test_loaders[domain],
                                        F_s, None, C)
        avg_test_acc = sum([test_acc[d] for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info('Average test accuracy: {}'.format(avg_test_acc))

        if avg_acc > best_avg_acc:
            log.info('New best average validation accuracy: {}'.format(avg_acc))
            best_acc['valid'] = acc
            best_acc['test'] = test_acc
            best_avg_acc = avg_acc
            with open(os.path.join(opt.model_save_file, 'options.pkl'), 'wb') as ouf:
                pickle.dump(opt, ouf)
            torch.save(F_s.state_dict(),
                       '{}/netF_s.pth'.format(opt.model_save_file))
            torch.save(C.state_dict(),
                       '{}/netC.pth'.format(opt.model_save_file))
            torch.save(D.state_dict(),
                       '{}/netD.pth'.format(opt.model_save_file))

    # end of training
    log.info('Best average validation accuracy: {}'.format(best_avg_acc))
    return best_acc
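
The utils.get_domain_label/utils.get_random_domain_label helpers encode the three loss variants used above. A hypothetical sketch consistent with those call sites ('gr' takes index targets for nll_loss, 'l2' one-hot vectors for mse_loss, 'bs' a uniform distribution for kl_div); not the original utils code:

def get_domain_label(loss, domain, batch_size):
    # Target for the discriminator: domain index, or a one-hot row for 'l2'.
    idx = opt.all_domains.index(domain)
    if loss.lower() == 'l2':
        labels = torch.zeros(batch_size, len(opt.all_domains))
        labels[:, idx] = 1.0
        return labels.to(opt.device)
    return torch.full((batch_size,), idx, dtype=torch.long, device=opt.device)

def get_random_domain_label(loss, batch_size):
    # Uniform distribution over domains: the "confusion" target for 'bs'/'l2'.
    n = len(opt.all_domains)
    return torch.full((batch_size, n), 1.0 / n, device=opt.device)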
def train(target_domain_idx, train_sets, test_sets):
    """
    train_sets, test_sets: unlabeled domain(dev domain) -> AmazonDataset
    """
    # test dataset
    print(train_sets[0])
    print(test_sets[0])
    print(train_sets[0][0].shape)
    print(test_sets[0][1].shape)

    # ---------------------- dataloader ---------------------- #
    train_loaders = DataLoader(train_sets, opt.batch_size, shuffle=False)
    test_loaders = DataLoader(test_sets, opt.batch_size, shuffle=False)

    # ---------------------- model initialization ---------------------- #
    F_s = None
    F_d = {}
    C = None
    if opt.model.lower() == 'mlp':
        F_s = MlpFeatureExtractor(opt.feature_num, opt.F_hidden_sizes,
                opt.shared_hidden_size, opt.dropout, opt.F_bn)
        for domain in opt.domains:
            F_d[domain] = MlpFeatureExtractor(opt.feature_num, opt.F_hidden_sizes,
                opt.domain_hidden_size, opt.dropout, opt.F_bn)
    else:
        raise Exception(f'Unknown model architecture {opt.model}')

    C = SentimentClassifier(opt.C_layers, opt.shared_hidden_size + opt.domain_hidden_size,
                            opt.shared_hidden_size + opt.domain_hidden_size, opt.num_labels,
                            opt.dropout, opt.C_bn)
    D = DomainClassifier(opt.D_layers, opt.shared_hidden_size, opt.shared_hidden_size,
                         len(opt.all_domains), opt.loss, opt.dropout, opt.D_bn)

    # move the models to the GPU
    F_s, C, D = F_s.to(opt.device), C.to(opt.device), D.to(opt.device)
    for f_d in F_d.values():
        f_d.to(opt.device)  # nn.Module.to() moves parameters in place

    # ---------------------- load pre-training model ---------------------- #
    log.info(f'Loading model from {opt.exp2_model_save_file}...')
    F_s.load_state_dict(torch.load(os.path.join(opt.exp2_model_save_file,
                                                f'netF_s.pth')))
    for domain in opt.all_domains:
        if domain in F_d:
            F_d[domain].load_state_dict(torch.load(os.path.join(opt.exp2_model_save_file,
                                                                f'net_F_d_{domain}.pth')))
    C.load_state_dict(torch.load(os.path.join(opt.exp2_model_save_file,
                                              f'netC.pth')))
    D.load_state_dict(torch.load(os.path.join(opt.exp2_model_save_file,
                                              f'netD.pth')))

    # ---------------------- generate pseudo-labels (hard labels) ---------------------- #
    log.info('Generating pseudo-labels:')
    pseudo_labels, _ = genarate_labels(target_domain_idx, False, train_loaders, F_s, F_d, C, D)

    # ********************** Test the accuracy of the label prediction ********************** #
    test_pseudo_labels, targets_total = genarate_labels(target_domain_idx, True, test_loaders, F_s, F_d, C, D)
    label_correct_acc = calc_label_prediction(test_pseudo_labels, targets_total)
    log.info(f'Pseudo-label prediction accuracy: {label_correct_acc}')
    # ********************** Test the accuracy of the label prediction ********************** #

    # ------------------------------- build the target dataset ------------------------------- #
    target_dataset = Subset(train_sets, pseudo_labels)
    target_dataloader_labelled = DataLoader(target_dataset, opt.batch_size, shuffle=True)

    target_train_iters = iter(target_dataloader_labelled)

    # ------------------------------- define the F_d_target model and necessary parameters ------------------------------- #
    F_d_target = MlpFeatureExtractor(opt.feature_num, opt.F_hidden_sizes,
                                     opt.domain_hidden_size, opt.dropout, opt.F_bn)

    F_d_target = F_d_target.to(opt.device)

    # ******************************* load pretrained weights so we don't train from scratch ******************************* #
    # randrange avoids the off-by-one of randint, whose upper bound is inclusive
    f_d_choice = random.randrange(len(opt.domains))
    F_d_target.load_state_dict(torch.load(os.path.join(opt.exp2_model_save_file,
                                                       f'net_F_d_{opt.domains[f_d_choice]}.pth')))
    optimizer_F_d_target = optim.Adam(F_d_target.parameters(), lr=opt.learning_rate)

    # ------------------------------- check the accuracy using only the shared feature ------------------------------- #
    log.info('Evaluating test sets only on shared feature:')
    test_acc = evaluate(opt.dev_domains, test_loaders, F_s, None, C)
    log.info(f'test accuracy: {test_acc}')

    for epoch in range(opt.max_epoch):
        F_d_target.train()
        C.train()
        # training accuracy
        correct, total = 0, 0
        num_iter = len(target_dataloader_labelled)
        print(num_iter)
        for _ in tqdm(range(num_iter)):
            utils.freeze_net(F_s)
            C.zero_grad()
            F_d_target.zero_grad()
            inputs, targets = utils.endless_get_next_batch(target_dataloader_labelled, target_train_iters)
            targets = targets.to(opt.device)
            shared_feat = F_s(inputs)
            domain_feat = F_d_target(inputs)
            features = torch.cat((shared_feat, domain_feat), dim=1)
            c_outputs = C(features)
            l_c = functional.nll_loss(c_outputs, targets)
            l_c.backward()
            _, pred_idx = torch.max(c_outputs, 1)
            total += targets.size(0)
            correct += (pred_idx == targets).sum().item()

            optimizer_F_d_target.step()

        # An open issue for now: there is no validation set here, because setting
        # one up would require splitting the training set -- the unlabeled domain
        # only has 2000 labeled samples plus raw_unlabeled_sets (4000+ samples);
        # the latter is used for training and the former for testing.
        # end of epoch
        print("correct is %d" % correct)
        print("total is %d" % total)
        log.info('Ending epoch {}'.format(epoch + 1))
        log.info('Training accuracy:')
        log.info('\t'.join(opt.unlabeled_domains))
        log.info(str(100.0 * correct / total))

    # save the models
    torch.save(F_d_target.state_dict(),
               '{}/net_F_d_target.pth'.format(opt.exp2_target_model_save_file))
    torch.save(C.state_dict(),
               '{}/netC_target.pth'.format(opt.exp2_target_model_save_file))

    # evaluate on the test set: the 2000 labeled samples are used so the result
    # can be compared against training directly on labeled data
    log.info('Evaluating test sets:')
    test_acc = evaluate(opt.dev_domains, test_loaders, F_s, F_d_target, C)
    log.info(f'test accuracy: {test_acc}')
    return test_acc
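
calc_label_prediction is also not shown; a plausible sketch (an assumption) simply measures how many pseudo-labels agree with the ground truth:

def calc_label_prediction(pseudo_labels, targets_total):
    # Percentage of pseudo-labels matching the true targets.
    correct = (pseudo_labels == targets_total).sum().item()
    return 100.0 * correct / len(targets_total)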
def train(vocab, train_sets, dev_sets, test_sets, unlabeled_sets):
    """
    train_sets, dev_sets, test_sets: dict[domain] -> AmazonDataset
    For unlabeled domains, no train_sets are available
    """
    # dataset loaders
    train_loaders, unlabeled_loaders = {}, {}
    train_iters, unlabeled_iters = {}, {}
    dev_loaders, test_loaders = {}, {}
    my_collate = utils.sorted_collate if opt.model == 'lstm' else utils.unsorted_collate
    for domain in opt.domains:
        train_loaders[domain] = DataLoader(train_sets[domain],
                                           opt.batch_size,
                                           shuffle=True,
                                           collate_fn=my_collate)
        train_iters[domain] = iter(train_loaders[domain])
    for domain in opt.dev_domains:
        dev_loaders[domain] = DataLoader(dev_sets[domain],
                                         opt.batch_size,
                                         shuffle=False,
                                         collate_fn=my_collate)
        test_loaders[domain] = DataLoader(test_sets[domain],
                                          opt.batch_size,
                                          shuffle=False,
                                          collate_fn=my_collate)
    for domain in opt.all_domains:
        if domain in opt.unlabeled_domains:
            uset = unlabeled_sets[domain]
        else:
            # for labeled domains, consider which data to use as unlabeled set
            if opt.unlabeled_data == 'both':
                uset = ConcatDataset(
                    [train_sets[domain], unlabeled_sets[domain]])
            elif opt.unlabeled_data == 'unlabeled':
                uset = unlabeled_sets[domain]
            elif opt.unlabeled_data == 'train':
                uset = train_sets[domain]
            else:
                raise Exception(
                    f'Unknown options for the unlabeled data usage: {opt.unlabeled_data}'
                )
        unlabeled_loaders[domain] = DataLoader(uset,
                                               opt.batch_size,
                                               shuffle=True,
                                               collate_fn=my_collate)
        unlabeled_iters[domain] = iter(unlabeled_loaders[domain])

    # model
    F_s = None
    F_d = {}
    C, D = None, None
    if opt.model.lower() == 'dan':
        F_s = DanFeatureExtractor(vocab, opt.F_layers, opt.shared_hidden_size,
                                  opt.sum_pooling, opt.dropout, opt.F_bn)
        for domain in opt.domains:
            F_d[domain] = DanFeatureExtractor(vocab, opt.F_layers,
                                              opt.domain_hidden_size,
                                              opt.sum_pooling, opt.dropout,
                                              opt.F_bn)
    elif opt.model.lower() == 'lstm':
        F_s = LSTMFeatureExtractor(vocab, opt.F_layers, opt.shared_hidden_size,
                                   opt.dropout, opt.bdrnn, opt.attn)
        for domain in opt.domains:
            F_d[domain] = LSTMFeatureExtractor(vocab, opt.F_layers,
                                               opt.domain_hidden_size,
                                               opt.dropout, opt.bdrnn,
                                               opt.attn)
    elif opt.model.lower() == 'cnn':
        F_s = CNNFeatureExtractor(vocab, opt.F_layers, opt.shared_hidden_size,
                                  opt.kernel_num, opt.kernel_sizes,
                                  opt.dropout)
        for domain in opt.domains:
            F_d[domain] = CNNFeatureExtractor(vocab, opt.F_layers,
                                              opt.domain_hidden_size,
                                              opt.kernel_num, opt.kernel_sizes,
                                              opt.dropout)
    else:
        raise Exception(f'Unknown model architecture {opt.model}')

    C = SentimentClassifier(opt.C_layers,
                            opt.shared_hidden_size + opt.domain_hidden_size,
                            opt.shared_hidden_size + opt.domain_hidden_size,
                            opt.num_labels, opt.dropout, opt.C_bn)
    D = DomainClassifier(opt.D_layers,
                         opt.shared_hidden_size, opt.shared_hidden_size,
                         len(opt.all_domains), opt.loss, opt.dropout, opt.D_bn)

    F_s, C, D = F_s.to(opt.device), C.to(opt.device), D.to(opt.device)
    for f_d in F_d.values():
        f_d.to(opt.device)  # nn.Module.to() moves parameters in place
    # optimizers
    optimizer = optim.Adam(itertools.chain(
        *map(list, [F_s.parameters() if F_s else [],
                    C.parameters()] + [f.parameters() for f in F_d.values()])),
                           lr=opt.learning_rate)
    optimizerD = optim.Adam(D.parameters(), lr=opt.D_learning_rate)

    # testing
    if opt.test_only:
        log.info(f'Loading model from {opt.exp3_model_save_file}...')
        if F_s:
            F_s.load_state_dict(
                torch.load(
                    os.path.join(opt.exp3_model_save_file, f'netF_s.pth')))
        for domain in opt.all_domains:
            if domain in F_d:
                F_d[domain].load_state_dict(
                    torch.load(
                        os.path.join(opt.exp3_model_save_file,
                                     f'net_F_d_{domain}.pth')))
        C.load_state_dict(
            torch.load(os.path.join(opt.exp3_model_save_file, f'netC.pth')))
        D.load_state_dict(
            torch.load(os.path.join(opt.exp3_model_save_file, f'netD.pth')))

        log.info('Evaluating validation sets:')
        acc = {}
        for domain in opt.all_domains:
            acc[domain] = evaluate(domain, dev_loaders[domain], F_s,
                                   F_d[domain] if domain in F_d else None, C)
        avg_acc = sum([acc[d] for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info(f'Average validation accuracy: {avg_acc}')
        log.info('Evaluating test sets:')
        test_acc = {}
        for domain in opt.all_domains:
            test_acc[domain] = evaluate(domain, test_loaders[domain], F_s,
                                        F_d[domain] if domain in F_d else None,
                                        C)
        avg_test_acc = sum([test_acc[d]
                            for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info(f'Average test accuracy: {avg_test_acc}')
        return {'valid': acc, 'test': test_acc}

    # training
    best_acc, best_avg_acc = defaultdict(float), 0.0
    for epoch in range(opt.max_epoch):
        F_s.train()
        C.train()
        D.train()
        LAMBDA = 3
        lambda1 = 0.1
        lambda2 = 0.1

        for f in F_d.values():
            f.train()

        # training accuracy
        correct, total = defaultdict(int), defaultdict(int)
        # D accuracy
        shared_d_correct, private_d_correct, d_total = 0, 0, 0
        # conceptually view 1 epoch as 1 epoch of the first domain
        num_iter = len(train_loaders[opt.domains[0]])
        for i in tqdm(range(num_iter)):
            # D iterations
            utils.freeze_net(F_s)
            # note: map() is lazy in Python 3, so freeze each net explicitly
            for f_d in F_d.values():
                utils.freeze_net(f_d)
            utils.freeze_net(C)
            utils.unfreeze_net(D)
            # WGAN n_critic trick since D trains slower
            # ********************** D iterations on all domains ********************** #
            # ---------------------- update D with D gradients on all domains ---------------------- #
            n_critic = opt.n_critic
            if opt.wgan_trick:
                if opt.n_critic > 0 and ((epoch == 0 and i < 25)
                                         or i % 500 == 0):
                    n_critic = 100

            for _ in range(n_critic):
                D.zero_grad()
                loss_d = {}
                # train on both labeled and unlabeled domains
                for domain in opt.all_domains:
                    # targets not used
                    d_inputs, _ = utils.endless_get_next_batch(
                        unlabeled_loaders, unlabeled_iters, domain)
                    d_targets = utils.get_domain_label(opt.loss, domain,
                                                       len(d_inputs[1]))
                    shared_feat = F_s(d_inputs)
                    shared_d_outputs = D(shared_feat)
                    _, shared_pred = torch.max(shared_d_outputs, 1)
                    if domain != opt.dev_domains[0]:
                        private_feat = F_d[domain](d_inputs)
                        private_d_outputs = D(private_feat)
                        _, private_pred = torch.max(private_d_outputs, 1)

                    d_total += len(d_inputs[1])
                    if opt.loss.lower() == 'l2':
                        _, tgt_indices = torch.max(d_targets, 1)
                        shared_d_correct += (
                            shared_pred == tgt_indices).sum().item()
                        shared_l_d = functional.mse_loss(
                            shared_d_outputs, d_targets)
                        private_l_d = 0.0
                        if domain != opt.dev_domains[0]:
                            private_d_correct += (
                                private_pred == tgt_indices).sum().item()
                            private_l_d = functional.mse_loss(
                                private_d_outputs, d_targets) / len(
                                    opt.domains)

                        l_d_sum = shared_l_d + private_l_d
                        l_d_sum.backward()
                    else:
                        shared_d_correct += (
                            shared_pred == d_targets).sum().item()
                        shared_l_d = functional.nll_loss(
                            shared_d_outputs, d_targets)
                        private_l_d = 0.0
                        if domain != opt.dev_domains[0]:
                            private_d_correct += (
                                private_pred == d_targets).sum().item()
                            private_l_d = functional.nll_loss(
                                private_d_outputs, d_targets) / len(
                                    opt.domains)
                        l_d_sum = shared_l_d + private_l_d
                        l_d_sum.backward()

                    loss_d[domain] = l_d_sum.item()
                optimizerD.step()

            # ---------------------- update D with C gradients on all domains ---------------------- #
            # ********************** D iterations on all domains ********************** #

            # ********************** F&C iteration ********************** #
            # ---------------------- update F_s & F_ds with C gradients on all labeled domains ---------------------- #
            # F&C iteration
            utils.unfreeze_net(F_s)
            # note: map() is lazy in Python 3, so unfreeze each net explicitly
            for f_d in F_d.values():
                utils.unfreeze_net(f_d)
            utils.unfreeze_net(C)
            utils.freeze_net(D)
            if opt.fix_emb:
                utils.freeze_net(F_s.word_emb)
                for f_d in F_d.values():
                    utils.freeze_net(f_d.word_emb)
            F_s.zero_grad()
            for f_d in F_d.values():
                f_d.zero_grad()
            C.zero_grad()
            for domain in opt.domains:
                inputs, targets = utils.endless_get_next_batch(
                    train_loaders, train_iters, domain)
                targets = targets.to(opt.device)
                shared_feat = F_s(inputs)
                domain_feat = F_d[domain](inputs)
                features = torch.cat((shared_feat, domain_feat), dim=1)
                c_outputs = C(features)
                loss_part_1 = functional.nll_loss(c_outputs, targets)

                # build one-hot targets sized by the actual batch (the last
                # batch can be smaller than opt.batch_size); keep `targets`
                # 1-D so the accuracy comparison below broadcasts correctly
                targets_onehot = torch.zeros(targets.size(0), 2)
                targets_onehot.scatter_(1, targets.unsqueeze(1).cpu(), 1)
                targets_onehot = targets_onehot.to(opt.device)
                loss_part_2 = lambda1 * margin_regularization(
                    inputs, targets_onehot, F_d[domain], LAMBDA)

                loss_part_3 = -lambda2 * center_point_constraint(
                    domain_feat, targets)
                print("lambda1: " + str(lambda1))
                print("lambda2: " + str(lambda2))
                print("loss_part_1: " + str(loss_part_1))
                print("loss_part_2: " + str(loss_part_2))
                print("loss_part_3: " + str(loss_part_3))
                l_c = loss_part_1 + loss_part_2 + loss_part_3
                l_c.backward(retain_graph=True)
                _, pred = torch.max(c_outputs, 1)
                total[domain] += targets.size(0)
                correct[domain] += (pred == targets).sum().item()
            # ---------------------- update F_s & F_ds with C gradients on all labeled domains ---------------------- #

            # ---------------------- update F_s with D gradients on all domains ---------------------- #
            # update F with D gradients on all domains
            for domain in opt.all_domains:
                d_inputs, _ = utils.endless_get_next_batch(
                    unlabeled_loaders, unlabeled_iters, domain)
                shared_feat = F_s(d_inputs)
                shared_d_outputs = D(shared_feat)
                if domain != opt.dev_domains[0]:
                    private_feat = F_d[domain](d_inputs)
                    private_d_outputs = D(private_feat)

                l_d_sum = None
                if opt.loss.lower() == 'gr':
                    d_targets = utils.get_domain_label(opt.loss, domain,
                                                       len(d_inputs[1]))
                    shared_l_d = functional.nll_loss(shared_d_outputs,
                                                     d_targets)
                    private_l_d, l_d_sum = 0.0, 0.0
                    if domain != opt.dev_domains[0]:
                        # note the loss function here
                        private_l_d = functional.nll_loss(
                            private_d_outputs, d_targets) * -1. / len(
                                opt.domains)
                    if opt.shared_lambd > 0:
                        l_d_sum = shared_l_d * opt.shared_lambd * -1.
                    else:
                        l_d_sum = shared_l_d * -1.
                    if opt.private_lambd > 0:
                        l_d_sum += private_l_d * opt.private_lambd * -1.
                    else:
                        l_d_sum += private_l_d * -1.
                elif opt.loss.lower() == 'bs':
                    d_targets = utils.get_random_domain_label(
                        opt.loss, len(d_inputs[1]))
                    shared_l_d = functional.kl_div(shared_d_outputs,
                                                   d_targets,
                                                   size_average=False)
                    private_l_d, l_d_sum = 0.0, 0.0
                    if domain != opt.dev_domains[0]:
                        private_l_d = functional.kl_div(private_d_outputs, d_targets, size_average=False) \
                                      * -1. / len(opt.domains)
                    if opt.shared_lambd > 0:
                        l_d_sum = shared_l_d * opt.shared_lambd
                    else:
                        l_d_sum = shared_l_d
                    if opt.private_lambd > 0:
                        l_d_sum += private_l_d * opt.private_lambd
                    else:
                        l_d_sum += private_l_d
                elif opt.loss.lower() == 'l2':
                    d_targets = utils.get_random_domain_label(
                        opt.loss, len(d_inputs[1]))
                    shared_l_d = functional.mse_loss(shared_d_outputs,
                                                     d_targets)
                    private_l_d, l_d_sum = 0.0, 0.0
                    if domain != opt.dev_domains[0]:
                        private_l_d = functional.mse_loss(
                            private_d_outputs, d_targets) * -1. / len(
                                opt.domains)
                    if opt.shared_lambd > 0:
                        l_d_sum = shared_l_d * opt.shared_lambd
                    else:
                        l_d_sum = shared_l_d
                    if opt.private_lambd > 0:
                        l_d_sum += private_l_d * opt.private_lambd
                    else:
                        l_d_sum += private_l_d
                l_d_sum.backward()

            optimizer.step()

        # end of epoch
        log.info('Ending epoch {}'.format(epoch + 1))
        if d_total > 0:
            log.info('shared D Training Accuracy: {}%'.format(
                100.0 * shared_d_correct / d_total))
            log.info('private D Training Accuracy(average): {}%'.format(
                100.0 * private_d_correct / len(opt.domains) / d_total))
        log.info('Training accuracy:')
        log.info('\t'.join(opt.domains))
        log.info('\t'.join(
            [str(100.0 * correct[d] / total[d]) for d in opt.domains]))

        # evaluate on the validation sets
        log.info('Evaluating validation sets:')
        acc = {}
        for domain in opt.dev_domains:
            acc[domain] = evaluate(domain, dev_loaders[domain], F_s,
                                   F_d[domain] if domain in F_d else None, C)
        avg_acc = sum([acc[d] for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info(f'Average validation accuracy: {avg_acc}')

        # evaluate on the test sets
        log.info('Evaluating test sets:')
        test_acc = {}
        for domain in opt.dev_domains:
            test_acc[domain] = evaluate(domain, test_loaders[domain], F_s,
                                        F_d[domain] if domain in F_d else None,
                                        C)
        avg_test_acc = sum([test_acc[d]
                            for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info(f'Average test accuracy: {avg_test_acc}')

        # save the models
        if avg_acc > best_avg_acc:
            log.info(f'New best average validation accuracy: {avg_acc}')
            best_acc['valid'] = acc
            best_acc['test'] = test_acc
            best_avg_acc = avg_acc
            with open(os.path.join(opt.exp3_model_save_file, 'options.pkl'),
                      'wb') as ouf:
                pickle.dump(opt, ouf)
            torch.save(F_s.state_dict(),
                       '{}/netF_s.pth'.format(opt.exp3_model_save_file))
            for d in opt.domains:
                if d in F_d:
                    torch.save(
                        F_d[d].state_dict(),
                        '{}/net_F_d_{}.pth'.format(opt.exp3_model_save_file,
                                                   d))
            torch.save(C.state_dict(),
                       '{}/netC.pth'.format(opt.exp3_model_save_file))
            torch.save(D.state_dict(),
                       '{}/netD.pth'.format(opt.exp3_model_save_file))

    # end of training
    log.info(f'Best average validation accuracy: {best_avg_acc}')

    log.info(
        f'Loading model for feature visualization from {opt.exp3_model_save_file}...'
    )

    for domain in opt.domains:
        F_d[domain].load_state_dict(
            torch.load(
                os.path.join(opt.exp3_model_save_file,
                             f'net_F_d_{domain}.pth')))
    num_iter = len(train_loaders[opt.domains[0]])
    # visual_features does not include the shared feature for now
    # visual_features, senti_labels = get_visual_features(num_iter, unlabeled_loaders, unlabeled_iters, F_s, F_d)
    visual_features, senti_labels = get_visual_features(
        num_iter, unlabeled_loaders, unlabeled_iters, F_d)
    return best_acc, visual_features, senti_labels
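
get_visual_features gathers private-feature vectors and their sentiment labels for the feature visualization mentioned above. A hypothetical sketch inferred from the call site (the signature and return types are assumptions, not the original code):

def get_visual_features(num_iter, unlabeled_loaders, unlabeled_iters, F_d):
    feats, labels = [], []
    with torch.no_grad():
        for _ in range(num_iter):
            for domain in opt.domains:
                inputs, targets = utils.endless_get_next_batch(
                    unlabeled_loaders, unlabeled_iters, domain)
                feats.append(F_d[domain](inputs).cpu())
                labels.append(targets.cpu())
    return torch.cat(feats), torch.cat(labels)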
Example #6
def train(vocab, train_sets, dev_sets, test_sets, unlabeled_sets):
    train_loaders, unlabeled_loaders = {}, {}
    train_iters, unlabeled_iters = {}, {}
    dev_loaders, test_loaders = {}, {}
    my_collate = utils.sorted_collate if opt.model == 'lstm' else utils.unsorted_collate
    for domain in opt.domains:
        train_loaders[domain] = DataLoader(train_sets[domain],
                                           opt.batch_size,
                                           shuffle=True,
                                           collate_fn=my_collate)
        train_iters[domain] = iter(train_loaders[domain])
    for domain in opt.dev_domains:
        dev_loaders[domain] = DataLoader(dev_sets[domain],
                                         opt.batch_size,
                                         shuffle=False,
                                         collate_fn=my_collate)
        test_loaders[domain] = DataLoader(test_sets[domain],
                                          opt.batch_size,
                                          shuffle=False,
                                          collate_fn=my_collate)
    for domain in opt.all_domains:
        if domain in opt.unlabeled_domains:
            uset = unlabeled_sets[domain]
        else:
            # for labeled domains, consider which data to use as unlabeled set
            if opt.unlabeled_data == 'both':
                uset = ConcatDataset(
                    [train_sets[domain], unlabeled_sets[domain]])
            elif opt.unlabeled_data == 'unlabeled':
                uset = unlabeled_sets[domain]
            elif opt.unlabeled_data == 'train':
                uset = train_sets[domain]
            else:
                raise Exception(
                    'Unknown options for the unlabeled data usage: {}'.format(
                        opt.unlabeled_data))
        unlabeled_loaders[domain] = DataLoader(uset,
                                               opt.batch_size,
                                               shuffle=True,
                                               collate_fn=my_collate)
        unlabeled_iters[domain] = iter(unlabeled_loaders[domain])

    # models
    F_s = None
    C = None
    if opt.model.lower() == 'dan':
        F_s = DanFeatureExtractor(vocab, opt.F_layers, opt.shared_hidden_size,
                                  opt.sum_pooling, opt.dropout, opt.F_bn)
    elif opt.model.lower() == 'lstm':
        F_s = LSTMFeatureExtractor(vocab, opt.F_layers, opt.shared_hidden_size,
                                   opt.dropout, opt.bdrnn, opt.attn)
    elif opt.model.lower() == 'cnn':
        F_s = CNNFeatureExtractor(vocab, opt.F_layers, opt.shared_hidden_size,
                                  opt.kernel_num, opt.kernel_sizes,
                                  opt.dropout)
    else:
        raise Exception('Unknown model architecture {}'.format(opt.model))

    C = SentimentClassifier(opt.C_layers,
                            opt.shared_hidden_size + opt.domain_hidden_size,
                            opt.shared_hidden_size + opt.domain_hidden_size,
                            opt.num_labels, opt.dropout, opt.C_bn)

    F_s, C = F_s.to(opt.device), C.to(opt.device)
    optimizer = optim.Adam(itertools.chain(
        *map(list, [F_s.parameters() if F_s else [],
                    C.parameters()])),
                           lr=opt.learning_rate)

    best_acc, best_avg_acc = defaultdict(float), 0.0
    for epoch in range(opt.max_epoch):
        F_s.train()
        C.train()

        # training accuracy
        correct, total = defaultdict(int), defaultdict(int)
        # conceptually view 1 epoch as 1 epoch of the first domain
        num_iter = len(train_loaders[opt.domains[0]])
        for i in tqdm(range(num_iter)):
            # F&C iteration
            utils.unfreeze_net(F_s)
            utils.unfreeze_net(C)
            if opt.fix_emb:
                utils.freeze_net(F_s.word_emb)
            F_s.zero_grad()
            C.zero_grad()
            for domain in opt.domains:
                inputs, targets = utils.endless_get_next_batch(
                    train_loaders, train_iters, domain)
                targets = targets.to(opt.device)
                shared_feat = F_s(inputs)
                domain_feat = torch.zeros(
                    len(targets), opt.domain_hidden_size).to(opt.device)
                features = torch.cat((shared_feat, domain_feat), dim=1)
                c_outputs = C(features)
                l_c = functional.nll_loss(c_outputs, targets)
                l_c.backward(retain_graph=True)
                _, pred = torch.max(c_outputs, 1)
                total[domain] += targets.size(0)
                correct[domain] += (pred == targets).sum().item()
            optimizer.step()

        # end of epoch
        log.info('Ending epoch {}'.format(epoch + 1))
        log.info('Training accuracy:')
        log.info('\t'.join(opt.domains))
        log.info('\t'.join(
            [str(100.0 * correct[d] / total[d]) for d in opt.domains]))
        log.info('Evaluating validation sets:')
        acc = {}
        for domain in opt.dev_domains:
            acc[domain] = evaluate(domain, dev_loaders[domain], F_s, None, C)
        avg_acc = sum([acc[d] for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info('Average validation accuracy: {}'.format(avg_acc))
        log.info('Evaluating test sets:')
        test_acc = {}
        for domain in opt.dev_domains:
            test_acc[domain] = evaluate(domain, test_loaders[domain], F_s,
                                        None, C)
        avg_test_acc = sum([test_acc[d]
                            for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info('Average test accuracy: {}'.format(avg_test_acc))

        if avg_acc > best_avg_acc:
            log.info(
                'New best average validation accuracy: {}'.format(avg_acc))
            best_acc['valid'] = acc
            best_acc['test'] = test_acc
            best_avg_acc = avg_acc
            with open(os.path.join(opt.model_save_file, 'options.pkl'),
                      'wb') as ouf:
                pickle.dump(opt, ouf)
            torch.save(F_s.state_dict(),
                       '{}/netF_s.pth'.format(opt.model_save_file))

            torch.save(C.state_dict(),
                       '{}/netC.pth'.format(opt.model_save_file))

    # end of training
    log.info('Best average validation accuracy: {}'.format(best_avg_acc))
    return best_acc
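# utils.endless_get_next_batch, used throughout these examples, cycles a
# DataLoader forever. A plausible implementation, assumed for illustration
# only (the repo's actual utils module may differ):
def endless_get_next_batch_sketch(loaders, iters, domain):
    try:
        batch = next(iters[domain])
    except StopIteration:
        # this domain's loader is exhausted; restart its iterator
        iters[domain] = iter(loaders[domain])
        batch = next(iters[domain])
    return batch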
def genarate_labels(target_domain_idx, test_for_label_acc, dataloader, F_s, F_d, C, D):
    """Genrate pseudo labels for unlabeled domain dataset."""
    # ---------------------- Switch to eval() mode ---------------------- #
    optimizerD = optim.Adam(D.parameters(), lr=opt.D_learning_rate)
    F_s.eval()
    C.eval()
    for f_d in F_d.values():
        f_d.eval()
    D.eval()

    it = iter(dataloader)               # batch_size is 128
    print(len(it))

    F_d_features = []
    F_d_outputs = []
    targets_total, pseudo_labels = None, None

    with torch.no_grad():
        for inputs, targets in tqdm(it):

            # drop batches with fewer than batch_size samples
            if inputs.shape[0] < opt.batch_size:
                continue
            # ---------------------- compute F_d_features and shared_feature ---------------------- #
            for f_d in F_d.values():
                F_d_features.append(f_d(inputs))
            shared_feature = F_s(inputs)

            # ---------------------- feed into D for training ---------------------- #
            if test_for_label_acc is False:
                utils.freeze_net(F_s)
                for f_d in F_d.values():  # map() is lazy in Python 3, so loop explicitly
                    utils.freeze_net(f_d)
                utils.freeze_net(C)
                utils.unfreeze_net(D)
                # note: this runs inside torch.no_grad(), so train_D must
                # re-enable gradients itself for the D update to take effect
                train_D(target_domain_idx, optimizerD, F_d_features, shared_feature, D)

            # ---------------------- process the "scores" of F_s after D ---------------------- #
            # shared_feature_d_outputs has one column per source domain plus one for the target domain
            shared_feature_d_outputs = D(shared_feature)
            # print(shared_feature_d_outputs.shape)

            # drop the target-domain dimension
            indices = []
            for i in range(len(opt.all_domains)):
                if opt.all_domains[i] in opt.unlabeled_domains:
                    continue
                else:
                    indices.append(i)

            indices = torch.tensor(indices).to(opt.device)
            # print(indices)
            selected_shared_feature_d_outputs = torch.index_select(shared_feature_d_outputs, 1, indices)
            # print(selected_shared_feature_d_outputs.shape)

            # normalize selected_shared_feature_d_outputs with a softmax
            norm_selected_shared_feature_d_outputs = F.softmax(selected_shared_feature_d_outputs, dim=1)

            # ---------------------- classifier scores from C for every F_d feature ---------------------- #
            for f_d_feature in F_d_features:
                padding_shared_feature = torch.zeros((shared_feature.shape[0], shared_feature.shape[1]),
                                                     requires_grad=True).to(opt.device)
                f_d_feature = torch.cat([f_d_feature, padding_shared_feature], 1)
                F_d_outputs.append(C(f_d_feature))

            # ---------------------- combine classifier scores with domain weights (c * w) ---------------------- #
            # (torch.repeat_interleave would avoid the numpy round trip below)
            norm_selected_shared_feature_d_outputs = np.repeat(norm_selected_shared_feature_d_outputs.cpu().numpy(),
                                                               repeats=opt.num_labels, axis=1)
            norm_selected_shared_feature_d_outputs = torch.from_numpy(norm_selected_shared_feature_d_outputs)
            # stack along dim 1 so each row keeps one sample's per-domain scores together
            F_d_outputs_tensor = torch.stack(F_d_outputs, 1).reshape(opt.batch_size, opt.num_labels * (len(opt.all_domains) - 1))
            c_mul_w = norm_selected_shared_feature_d_outputs * F_d_outputs_tensor.cpu()

            # ---------------------- reduce c_mul_w to hard labels ---------------------- #
            # even columns hold class-0 scores, odd columns class-1 scores
            # (this indexing assumes opt.num_labels == 2)
            even_indices = torch.LongTensor(np.arange(0, 2 * (len(opt.all_domains) - 1), 2))
            odd_indices = torch.LongTensor(np.arange(1, 2 * (len(opt.all_domains) - 1), 2))
            even_index_scores = torch.index_select(c_mul_w, 1, even_indices)
            odd_index_scores = torch.index_select(c_mul_w, 1, odd_indices)
            even_index_scores_sum = torch.sum(even_index_scores, 1).unsqueeze(1)
            odd_index_scores_sum = torch.sum(odd_index_scores, 1).unsqueeze(1)
            pred_scores = torch.cat([even_index_scores_sum, odd_index_scores_sum], 1)
            _, pred_idx = torch.max(pred_scores, 1)

            # ---------------------- accumulate pseudo_labels ---------------------- #
            if pseudo_labels is None:
                pseudo_labels = pred_idx
                targets_total = targets
            else:
                pseudo_labels = torch.cat(
                    [pseudo_labels, pred_idx], 0)
                targets_total = torch.cat(
                    [targets_total, targets], 0)

            F_d_features.clear()
            F_d_outputs.clear()

    print(pseudo_labels.shape)
    print(targets_total.shape)
    print(">>> Generate pseudo labels {}, target samples {}".format(
        len(pseudo_labels), targets_total.shape[0]))

    return pseudo_labels, targets_total
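# The scheme above is a per-sample weighted vote: D's softmaxed domain scores
# weight each private classifier's output, and the even/odd-index bookkeeping
# merely separates the two class columns. A compact torch equivalent with toy
# shapes (an illustration assuming binary labels, not the repo's code):
import torch
w = torch.softmax(torch.randn(4, 3), dim=1)             # D's weights over 3 source domains
c = torch.randn(4, 3, 2)                                 # per-domain classifier scores, 2 labels
pseudo = (w.unsqueeze(2) * c).sum(dim=1).argmax(dim=1)   # weighted vote -> hard labels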
def train_nobert(vocab, train_loaders, unlabeled_loaders, train_iters, unlabeled_iters, dev_loaders, test_loaders):
    # models
    F_s = None
    F_d = {}
    C, D = None, None
    if opt.model.lower() == 'dan':
        F_s = DanFeatureExtractor(vocab, opt.F_layers, opt.shared_hidden_size,
                               opt.sum_pooling, opt.dropout, opt.F_bn)
        for domain in opt.domains:
            F_d[domain] = DanFeatureExtractor(vocab, opt.F_layers, opt.domain_hidden_size,
                                           opt.sum_pooling, opt.dropout, opt.F_bn)
    elif opt.model.lower() == 'lstm':
        F_s = LSTMFeatureExtractor(vocab, opt.F_layers, opt.shared_hidden_size,
                                   opt.dropout, opt.bdrnn, opt.attn)
        for domain in opt.domains:
            F_d[domain] = LSTMFeatureExtractor(vocab, opt.F_layers, opt.domain_hidden_size,
                                               opt.dropout, opt.bdrnn, opt.attn)
    elif opt.model.lower() == 'cnn':
        F_s = CNNFeatureExtractor(vocab, opt.F_layers, opt.shared_hidden_size,
                                  opt.kernel_num, opt.kernel_sizes, opt.dropout)
        for domain in opt.domains:
            F_d[domain] = CNNFeatureExtractor(vocab, opt.F_layers, opt.domain_hidden_size,
                                              opt.kernel_num, opt.kernel_sizes, opt.dropout)
    else:
        raise Exception('Unknown model architecture {}'.format(opt.model))

    C = SentimentClassifier(opt.C_layers, opt.shared_hidden_size + opt.domain_hidden_size,
            opt.shared_hidden_size + opt.domain_hidden_size, opt.num_labels,
            opt.dropout, opt.C_bn)
    D = DomainClassifier(opt.D_layers, opt.shared_hidden_size, opt.shared_hidden_size,
                         len(opt.all_domains), opt.loss, opt.dropout, opt.D_bn)
    
    F_s, C, D = F_s.to(opt.device), C.to(opt.device), D.to(opt.device)
    for f_d in F_d.values():
        f_d.to(opt.device)  # nn.Module.to() moves parameters in place
    # optimizers
    optimizer = optim.Adam(itertools.chain(
        *map(list, [F_s.parameters() if F_s else [], C.parameters()] +
                   [f.parameters() for f in F_d.values()])),
        lr=opt.learning_rate)
    optimizerD = optim.Adam(D.parameters(), lr=opt.D_learning_rate)

    # testing
    if opt.test_only:
        log.info('Loading model from {}...'.format(opt.model_save_file))
        if F_s:
            F_s.load_state_dict(torch.load(os.path.join(opt.model_save_file,
                                           'netF_s.pth')))
        for domain in opt.all_domains:
            if domain in F_d:
                F_d[domain].load_state_dict(torch.load(os.path.join(opt.model_save_file,
                        'net_F_d_{}.pth'.format(domain))))
        C.load_state_dict(torch.load(os.path.join(opt.model_save_file,
                                                  'netC.pth')))
        D.load_state_dict(torch.load(os.path.join(opt.model_save_file,
                                                  'netD.pth')))

        log.info('Evaluating validation sets:')
        acc = {}
        for domain in opt.all_domains:
            acc[domain] = evaluate(domain, dev_loaders[domain],
                                   F_s, F_d[domain] if domain in F_d else None, C)
        avg_acc = sum([acc[d] for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info('Average validation accuracy: {}'.format(avg_acc))
        log.info('Evaluating test sets:')
        test_acc = {}
        for domain in opt.all_domains:
            test_acc[domain] = evaluate(domain, test_loaders[domain],
                    F_s, F_d[domain] if domain in F_d else None, C)
        avg_test_acc = sum([test_acc[d] for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info('Average test accuracy: {}'.format(avg_test_acc))
        return {'valid': acc, 'test': test_acc}

    # training
    best_acc, best_avg_acc = defaultdict(float), 0.0
    for epoch in range(opt.max_epoch):
        F_s.train()
        C.train()
        D.train()
        for f in F_d.values():
            f.train()

        # training accuracy
        correct, total = defaultdict(int), defaultdict(int)
        # D accuracy
        d_correct, d_total = 0, 0
        # conceptually view 1 epoch as 1 epoch of the first domain
        num_iter = len(train_loaders[opt.domains[0]])
        for i in tqdm(range(num_iter)):
            # D iterations
            utils.freeze_net(F_s)
            for f_d in F_d.values():  # map() is lazy in Python 3, so loop explicitly
                utils.freeze_net(f_d)
            utils.freeze_net(C)
            utils.unfreeze_net(D)
            # WGAN n_critic trick since D trains slower
            n_critic = opt.n_critic
            if opt.wgan_trick:
                if opt.n_critic>0 and ((epoch==0 and i<25) or i%500==0):
                    n_critic = 100

            for _ in range(n_critic):
                D.zero_grad()
                loss_d = {}
                # train on both labeled and unlabeled domains
                for domain in opt.all_domains:
                    # targets not used
                    batch = utils.endless_get_next_batch(
                            unlabeled_loaders, unlabeled_iters, domain)
                    batch = tuple(t.to(opt.device) for t in batch)
                    emb_ids, input_ids, input_mask, segment_ids, label_ids = batch
                    d_targets = utils.get_domain_label(opt.loss, domain, len(emb_ids))
                    shared_feat = F_s(emb_ids)
                    d_outputs = D(shared_feat)
                    # D accuracy
                    _, pred = torch.max(d_outputs, 1)
                    d_total += len(emb_ids)
                    if opt.loss.lower() == 'l2':
                        _, tgt_indices = torch.max(d_targets, 1)
                        d_correct += (pred==tgt_indices).sum().item()
                        l_d = functional.mse_loss(d_outputs, d_targets)
                        l_d.backward()
                    else:
                        d_correct += (pred==d_targets).sum().item()
                        l_d = functional.nll_loss(d_outputs, d_targets)
                        l_d.backward()
                    loss_d[domain] = l_d.item()
                optimizerD.step()

            # F&C iteration
            utils.unfreeze_net(F_s)
            for f_d in F_d.values():
                utils.unfreeze_net(f_d)
            utils.unfreeze_net(C)
            utils.freeze_net(D)
            # if opt.fix_emb:
            #     utils.freeze_net(F_s.word_emb)
            #     for f_d in F_d.values():
            #         utils.freeze_net(f_d.word_emb)
            F_s.zero_grad()
            for f_d in F_d.values():
                f_d.zero_grad()
            C.zero_grad()
            for domain in opt.domains:
                batch = utils.endless_get_next_batch(
                        train_loaders, train_iters, domain)
                batch = tuple(t.to(opt.device) for t in batch)
                emb_ids, input_ids, input_mask, segment_ids, label_ids = batch
                inputs = emb_ids
                targets = label_ids
                shared_feat = F_s(inputs)
                domain_feat = F_d[domain](inputs)
                features = torch.cat((shared_feat, domain_feat), dim=1)
                c_outputs = C(features)
                l_c = functional.nll_loss(c_outputs, targets)
                l_c.backward(retain_graph=True)
                _, pred = torch.max(c_outputs, 1)
                total[domain] += targets.size(0)
                correct[domain] += (pred == targets).sum().item()
            # update F with D gradients on all domains
            for domain in opt.all_domains:
                batch = utils.endless_get_next_batch(
                        unlabeled_loaders, unlabeled_iters, domain)
                batch = tuple(t.to(opt.device) for t in batch)
                emb_ids, input_ids, input_mask, segment_ids, label_ids = batch
                d_inputs = emb_ids
                shared_feat = F_s(d_inputs)
                d_outputs = D(shared_feat)
                if opt.loss.lower() == 'gr':
                    d_targets = utils.get_domain_label(opt.loss, domain, len(d_inputs))
                    l_d = functional.nll_loss(d_outputs, d_targets)
                    if opt.lambd > 0:
                        l_d *= -opt.lambd
                elif opt.loss.lower() == 'bs':
                    d_targets = utils.get_random_domain_label(opt.loss, len(d_inputs))
                    l_d = functional.kl_div(d_outputs, d_targets, size_average=False)
                    if opt.lambd > 0:
                        l_d *= opt.lambd
                elif opt.loss.lower() == 'l2':
                    d_targets = utils.get_random_domain_label(opt.loss, len(d_inputs))
                    l_d = functional.mse_loss(d_outputs, d_targets)
                    if opt.lambd > 0:
                        l_d *= opt.lambd
                l_d.backward()

            optimizer.step()

        # end of epoch
        log.info('Ending epoch {}'.format(epoch+1))
        if d_total > 0:
            log.info('D Training Accuracy: {}%'.format(100.0*d_correct/d_total))
        log.info('Training accuracy:')
        log.info('\t'.join(opt.domains))
        log.info('\t'.join([str(100.0*correct[d]/total[d]) for d in opt.domains]))
        log.info('Evaluating validation sets:')
        acc = {}
        for domain in opt.dev_domains:
            acc[domain] = evaluate(domain, dev_loaders[domain],
                    F_s, F_d[domain] if domain in F_d else None, C)
        avg_acc = sum([acc[d] for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info('Average validation accuracy: {}'.format(avg_acc))
        log.info('Evaluating test sets:')
        test_acc = {}
        for domain in opt.dev_domains:
            test_acc[domain] = evaluate(domain, test_loaders[domain],
                    F_s, F_d[domain] if domain in F_d else None, C)
        avg_test_acc = sum([test_acc[d] for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info('Average test accuracy: {}'.format(avg_test_acc))

        if avg_acc > best_avg_acc:
            log.info('New best average validation accuracy: {}'.format(avg_acc))
            best_acc['valid'] = acc
            best_acc['test'] = test_acc
            best_avg_acc = avg_acc
            with open(os.path.join(opt.model_save_file, 'options.pkl'), 'wb') as ouf:
                pickle.dump(opt, ouf)
            # torch.save(F_s.state_dict(),
            #            '{}/netF_s.pth'.format(opt.model_save_file))
            # for d in opt.domains:
            #     if d in F_d:
            #         torch.save(F_d[d].state_dict(),
            #                    '{}/net_F_d_{}.pth'.format(opt.model_save_file, d))
            # torch.save(C.state_dict(),
            #            '{}/netC.pth'.format(opt.model_save_file))
            # torch.save(D.state_dict(),
            #         '{}/netD.pth'.format(opt.model_save_file))

    # end of training
    log.info('Best average validation accuracy: {}'.format(best_avg_acc))
    return best_acc
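# get_domain_label / get_random_domain_label are used above with three loss
# options. A plausible sketch of their contract (assumed, not the repo's
# utils; the _sketch suffix marks them as hypothetical): 'l2' works on one-hot
# target vectors, 'gr' on class indices, and the random variant targets a
# uniform distribution to confuse D.
import torch

def get_domain_label_sketch(loss, domain_idx, size, num_domains):
    if loss.lower() == 'l2':
        targets = torch.zeros(size, num_domains)
        targets[:, domain_idx] = 1.0        # one-hot target per sample
        return targets
    return torch.full((size,), domain_idx, dtype=torch.long)

def get_random_domain_label_sketch(loss, size, num_domains):
    return torch.full((size, num_domains), 1.0 / num_domains)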
def train_private_nobert(vocab, train_loaders, unlabeled_loaders, train_iters, unlabeled_iters, dev_loaders, test_loaders):
    # models
    F_d = {}
    C = None
    if opt.model.lower() == 'dan':
        for domain in opt.domains:
            F_d[domain] = DanFeatureExtractor(vocab, opt.F_layers, opt.domain_hidden_size,
                                              opt.sum_pooling, opt.dropout, opt.F_bn)
    elif opt.model.lower() == 'lstm':
        for domain in opt.domains:
            F_d[domain] = LSTMFeatureExtractor(vocab, opt.F_layers, opt.domain_hidden_size,
                                               opt.dropout, opt.bdrnn, opt.attn)
    elif opt.model.lower() == 'cnn':
        for domain in opt.domains:
            F_d[domain] = CNNFeatureExtractor(vocab, opt.F_layers, opt.domain_hidden_size,
                                              opt.kernel_num, opt.kernel_sizes, opt.dropout)
    else:
        raise Exception('Unknown model architecture {}'.format(opt.model))

    C = SentimentClassifier(opt.C_layers, opt.shared_hidden_size + opt.domain_hidden_size,
                            opt.shared_hidden_size + opt.domain_hidden_size, opt.num_labels,
                            opt.dropout, opt.C_bn)

    C = C.to(opt.device)
    for f_d in F_d.values():
        f_d.to(opt.device)  # nn.Module.to() moves parameters in place
    # optimizers
    optimizer = optim.Adam(itertools.chain(
        *map(list, [C.parameters()] + [f.parameters() for f in F_d.values()])),
                           lr=opt.learning_rate)


    # training
    best_acc, best_avg_acc = defaultdict(float), 0.0
    for epoch in range(opt.max_epoch):
        C.train()
        for f in F_d.values():
            f.train()

        # training accuracy
        correct, total = defaultdict(int), defaultdict(int)
        # D accuracy (no D in this private-only variant, so these stay zero)
        d_correct, d_total = 0, 0
        # conceptually view 1 epoch as 1 epoch of the first domain
        num_iter = len(train_loaders[opt.domains[0]])
        for i in tqdm(range(num_iter)):

            # F&C iteration
            for f_d in F_d.values():  # map() is lazy in Python 3, so loop explicitly
                utils.unfreeze_net(f_d)
            utils.unfreeze_net(C)
            if opt.fix_emb:
                for f_d in F_d.values():
                    utils.freeze_net(f_d.word_emb)
            for f_d in F_d.values():
                f_d.zero_grad()
            C.zero_grad()
            for domain in opt.domains:
                batch = utils.endless_get_next_batch(
                    train_loaders, train_iters, domain)
                batch = tuple(t.to(opt.device) for t in batch)
                emb_ids, input_ids, input_mask, segment_ids, label_ids = batch
                inputs = emb_ids
                targets = label_ids
                shared_feat = torch.zeros(len(targets), opt.shared_hidden_size).to(opt.device)
                domain_feat = F_d[domain](inputs)
                features = torch.cat((shared_feat, domain_feat), dim=1)
                c_outputs = C(features)
                l_c = functional.nll_loss(c_outputs, targets)
                l_c.backward(retain_graph=True)
                _, pred = torch.max(c_outputs, 1)
                total[domain] += targets.size(0)
                correct[domain] += (pred == targets).sum().item()

            optimizer.step()

        # end of epoch
        log.info('Ending epoch {}'.format(epoch + 1))
        if d_total > 0:
            log.info('D Training Accuracy: {}%'.format(100.0 * d_correct / d_total))
        log.info('Training accuracy:')
        log.info('\t'.join(opt.domains))
        log.info('\t'.join([str(100.0 * correct[d] / total[d]) for d in opt.domains]))
        log.info('Evaluating validation sets:')
        acc = {}
        for domain in opt.dev_domains:
            acc[domain] = evaluate(domain, dev_loaders[domain],
                                   None, F_d[domain] if domain in F_d else None, C)
        avg_acc = sum([acc[d] for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info('Average validation accuracy: {}'.format(avg_acc))
        log.info('Evaluating test sets:')
        test_acc = {}
        for domain in opt.dev_domains:
            test_acc[domain] = evaluate(domain, test_loaders[domain],
                                        None, F_d[domain] if domain in F_d else None, C)
        avg_test_acc = sum([test_acc[d] for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info('Average test accuracy: {}'.format(avg_test_acc))

        if avg_acc > best_avg_acc:
            log.info('New best average validation accuracy: {}'.format(avg_acc))
            best_acc['valid'] = acc
            best_acc['test'] = test_acc
            best_avg_acc = avg_acc
            with open(os.path.join(opt.model_save_file, 'options.pkl'), 'wb') as ouf:
                pickle.dump(opt, ouf)
            # for d in opt.domains:
            #     if d in F_d:
            #         torch.save(F_d[d].state_dict(),
            #                    '{}/net_F_d_{}.pth'.format(opt.model_save_file, d))
            # torch.save(C.state_dict(),
            #            '{}/netC.pth'.format(opt.model_save_file))

    # end of training
    log.info('Best average validation accuracy: {}'.format(best_avg_acc))
    return best_acc
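# This private-only variant zero-pads the missing shared half so that C's
# input dimension never changes; the shared-only variant does the mirror-image
# padding. A toy illustration of the trick (assumed sizes):
import torch
shared = torch.zeros(4, 8)                      # absent shared features
private = torch.randn(4, 6)                     # private extractor output
features = torch.cat((shared, private), dim=1)  # C always sees 8 + 6 columns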
Exemple #10
def train(opt):
    # vocab
    log.info(f'Loading Embeddings...')
    vocab = Vocab(opt.emb_filename)
    # datasets
    log.info(f'Loading data...')
    yelp_X_train = os.path.join(opt.src_data_dir, 'X_train.txt.tok.shuf.lower')
    yelp_Y_train = os.path.join(opt.src_data_dir, 'Y_train.txt.shuf')
    yelp_X_test = os.path.join(opt.src_data_dir, 'X_test.txt.tok.lower')
    yelp_Y_test = os.path.join(opt.src_data_dir, 'Y_test.txt')
    yelp_train, yelp_valid = get_yelp_datasets(vocab, yelp_X_train, yelp_Y_train,
            opt.en_train_lines, yelp_X_test, yelp_Y_test, opt.max_seq_len)
    chn_X_file = os.path.join(opt.tgt_data_dir, 'X.sent.txt.shuf.lower')
    chn_Y_file = os.path.join(opt.tgt_data_dir, 'Y.txt.shuf')
    chn_train, chn_valid, chn_test = get_chn_htl_datasets(vocab, chn_X_file, chn_Y_file,
            opt.ch_train_lines, opt.max_seq_len)
    log.info('Done loading datasets.')
    opt.num_labels = yelp_train.num_labels

    if opt.max_seq_len <= 0:
        # set to true max_seq_len in the datasets
        opt.max_seq_len = max(yelp_train.get_max_seq_len(),
                              chn_train.get_max_seq_len())
    # dataset loaders
    my_collate = utils.sorted_collate if opt.model=='lstm' else utils.unsorted_collate
    yelp_train_loader = DataLoader(yelp_train, opt.batch_size,
            shuffle=True, collate_fn=my_collate)
    yelp_train_loader_Q = DataLoader(yelp_train,
                                     opt.batch_size,
                                     shuffle=True, collate_fn=my_collate)
    chn_train_loader = DataLoader(chn_train, opt.batch_size,
            shuffle=True, collate_fn=my_collate)
    chn_train_loader_Q = DataLoader(chn_train,
                                    opt.batch_size,
                                    shuffle=True, collate_fn=my_collate)
    yelp_train_iter_Q = iter(yelp_train_loader_Q)
    chn_train_iter = iter(chn_train_loader)
    chn_train_iter_Q = iter(chn_train_loader_Q)

    yelp_valid_loader = DataLoader(yelp_valid, opt.batch_size,
            shuffle=False, collate_fn=my_collate)
    chn_valid_loader = DataLoader(chn_valid, opt.batch_size,
            shuffle=False, collate_fn=my_collate)
    chn_test_loader = DataLoader(chn_test, opt.batch_size,
            shuffle=False, collate_fn=my_collate)

    # models
    if opt.model.lower() == 'dan':
        F = DANFeatureExtractor(vocab, opt.F_layers, opt.hidden_size, opt.dropout, opt.F_bn)
    elif opt.model.lower() == 'lstm':
        F = LSTMFeatureExtractor(vocab, opt.F_layers, opt.hidden_size, opt.dropout,
                opt.bdrnn, opt.attn)
    elif opt.model.lower() == 'cnn':
        F = CNNFeatureExtractor(vocab, opt.F_layers,
                opt.hidden_size, opt.kernel_num, opt.kernel_sizes, opt.dropout)
    else:
        raise Exception('Unknown model')
    P = SentimentClassifier(opt.P_layers, opt.hidden_size, opt.num_labels,
            opt.dropout, opt.P_bn)
    Q = LanguageDetector(opt.Q_layers, opt.hidden_size, opt.dropout, opt.Q_bn)
    F, P, Q = F.to(opt.device), P.to(opt.device), Q.to(opt.device)
    optimizer = optim.Adam(list(F.parameters()) + list(P.parameters()),
                           lr=opt.learning_rate)
    optimizerQ = optim.Adam(Q.parameters(), lr=opt.Q_learning_rate)

    # training
    best_acc = 0.0
    for epoch in range(opt.max_epoch):
        F.train()
        P.train()
        Q.train()
        yelp_train_iter = iter(yelp_train_loader)
        # training accuracy
        correct, total = 0, 0
        sum_en_q, sum_ch_q = (0, 0.0), (0, 0.0)
        grad_norm_p, grad_norm_q = (0, 0.0), (0, 0.0)
        for i, (inputs_en, targets_en) in tqdm(enumerate(yelp_train_iter),
                                               total=len(yelp_train)//opt.batch_size):
            try:
                inputs_ch, _ = next(chn_train_iter)  # Chinese labels are not used
            except StopIteration:
                # Chinese data exhausted; restart the iterator
                chn_train_iter = iter(chn_train_loader)
                inputs_ch, _ = next(chn_train_iter)

            # Q iterations
            n_critic = opt.n_critic
            if n_critic>0 and ((epoch==0 and i<=25) or (i%500==0)):
                n_critic = 10
            utils.freeze_net(F)
            utils.freeze_net(P)
            utils.unfreeze_net(Q)
            for qiter in range(n_critic):
                # clip Q weights
                for p in Q.parameters():
                    p.data.clamp_(opt.clip_lower, opt.clip_upper)
                Q.zero_grad()
                # get a minibatch of data
                try:
                    # labels are not used
                    q_inputs_en, _ = next(yelp_train_iter_Q)
                except StopIteration:
                    # check if dataloader is exhausted
                    yelp_train_iter_Q = iter(yelp_train_loader_Q)
                    q_inputs_en, _ = next(yelp_train_iter_Q)
                try:
                    q_inputs_ch, _ = next(chn_train_iter_Q)
                except StopIteration:
                    chn_train_iter_Q = iter(chn_train_loader_Q)
                    q_inputs_ch, _ = next(chn_train_iter_Q)

                features_en = F(q_inputs_en)
                o_en_ad = Q(features_en)
                l_en_ad = torch.mean(o_en_ad)
                (-l_en_ad).backward()
                log.debug(f'Q grad norm: {Q.net[1].weight.grad.data.norm()}')
                sum_en_q = (sum_en_q[0] + 1, sum_en_q[1] + l_en_ad.item())

                features_ch = F(q_inputs_ch)
                o_ch_ad = Q(features_ch)
                l_ch_ad = torch.mean(o_ch_ad)
                l_ch_ad.backward()
                log.debug(f'Q grad norm: {Q.net[1].weight.grad.data.norm()}')
                sum_ch_q = (sum_ch_q[0] + 1, sum_ch_q[1] + l_ch_ad.item())

                optimizerQ.step()

            # F&P iteration
            utils.unfreeze_net(F)
            utils.unfreeze_net(P)
            utils.freeze_net(Q)
            if opt.fix_emb:
                utils.freeze_net(F.word_emb)
            # clip Q weights
            for p in Q.parameters():
                p.data.clamp_(opt.clip_lower, opt.clip_upper)
            F.zero_grad()
            P.zero_grad()
            
            features_en = F(inputs_en)
            o_en_sent = P(features_en)
            l_en_sent = functional.nll_loss(o_en_sent, targets_en)
            l_en_sent.backward(retain_graph=True)
            o_en_ad = Q(features_en)
            l_en_ad = torch.mean(o_en_ad)
            (opt.lambd*l_en_ad).backward(retain_graph=True)
            # training accuracy
            _, pred = torch.max(o_en_sent, 1)
            total += targets_en.size(0)
            correct += (pred == targets_en).sum().item()

            features_ch = F(inputs_ch)
            o_ch_ad = Q(features_ch)
            l_ch_ad = torch.mean(o_ch_ad)
            (-opt.lambd*l_ch_ad).backward()

            optimizer.step()
    
        # end of epoch
        log.info('Ending epoch {}'.format(epoch+1))
        # logs
        if sum_en_q[0] > 0:
            log.info(f'Average English Q output: {sum_en_q[1]/sum_en_q[0]}')
            log.info(f'Average Foreign Q output: {sum_ch_q[1]/sum_ch_q[0]}')
        # evaluate
        log.info('Training Accuracy: {}%'.format(100.0*correct/total))
        log.info('Evaluating English Validation set:')
        evaluate(opt, yelp_valid_loader, F, P)
        log.info('Evaluating Foreign validation set:')
        acc = evaluate(opt, chn_valid_loader, F, P)
        if acc > best_acc:
            log.info(f'New Best Foreign validation accuracy: {acc}')
            best_acc = acc
            torch.save(F.state_dict(),
                    '{}/netF_epoch_{}.pth'.format(opt.model_save_file, epoch))
            torch.save(P.state_dict(),
                    '{}/netP_epoch_{}.pth'.format(opt.model_save_file, epoch))
            torch.save(Q.state_dict(),
                    '{}/netQ_epoch_{}.pth'.format(opt.model_save_file, epoch))
        log.info('Evaluating Foreign test set:')
        evaluate(opt, chn_test_loader, F, P)
    log.info(f'Best Foreign validation accuracy: {best_acc}')
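# The Q iterations above follow the WGAN critic recipe: maximize
# E[Q(F(en))] - E[Q(F(ch))] and enforce the Lipschitz constraint by weight
# clipping. A self-contained toy version with stand-in modules (assumed
# shapes, not the repo's models):
import torch
import torch.nn as nn

F_toy, Q_toy = nn.Linear(8, 4), nn.Linear(4, 1)
x_en, x_ch = torch.randn(2, 8), torch.randn(2, 8)
loss_q = -(Q_toy(F_toy(x_en)).mean() - Q_toy(F_toy(x_ch)).mean())
loss_q.backward()
for p in Q_toy.parameters():
    p.data.clamp_(-0.01, 0.01)  # Lipschitz constraint via clipping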
Exemple #11
def train_shared(vocab, train_loaders, unlabeled_loaders, train_iters, unlabeled_iters, dev_loaders, test_loaders, F_s):

    C = None

    C = SentimentClassifier(opt.C_layers, opt.shared_hidden_size + opt.domain_hidden_size,
                            opt.shared_hidden_size + opt.domain_hidden_size, opt.num_labels,
                            opt.dropout, opt.C_bn)

    F_s, C = F_s.to(opt.device), C.to(opt.device)
    # optimizers
    optimizer = optim.Adam(itertools.chain(
        *map(list, [F_s.parameters() if F_s else [], C.parameters()])),
        lr=opt.learning_rate)

    # training
    best_acc, best_avg_acc = defaultdict(float), 0.0
    for epoch in range(opt.max_epoch):
        F_s.train()
        C.train()

        # training accuracy
        correct, total = defaultdict(int), defaultdict(int)
        # D accuracy (no D iterations in this shared-only variant, so these stay zero)
        d_correct, d_total = 0, 0
        # conceptually view 1 epoch as 1 epoch of the first domain
        num_iter = len(train_loaders[opt.domains[0]])
        for i in tqdm(range(num_iter)):
            # F&C iteration
            utils.unfreeze_net(F_s)
            utils.unfreeze_net(C)
            if opt.fix_emb:
                utils.freeze_net(F_s.word_emb)
            F_s.zero_grad()
            C.zero_grad()
            for domain in opt.domains:
                batch = utils.endless_get_next_batch(
                    train_loaders, train_iters, domain)
                batch = tuple(t.to(opt.device) for t in batch)
                emb_ids, input_ids, input_mask, segment_ids, label_ids = batch
                inputs = emb_ids
                targets = label_ids
                _, shared_feat = F_s(input_ids, input_mask, segment_ids)
                domain_feat = torch.zeros(len(targets), opt.domain_hidden_size).to(opt.device)
                features = torch.cat((shared_feat, domain_feat), dim=1)
                c_outputs = C(features)
                l_c = functional.nll_loss(c_outputs, targets)
                l_c.backward(retain_graph=True)
                _, pred = torch.max(c_outputs, 1)
                total[domain] += targets.size(0)
                correct[domain] += (pred == targets).sum().item()

            optimizer.step()

        # end of epoch
        log.info('Ending epoch {}'.format(epoch + 1))
        if d_total > 0:
            log.info('D Training Accuracy: {}%'.format(100.0 * d_correct / d_total))
        log.info('Training accuracy:')
        log.info('\t'.join(opt.domains))
        log.info('\t'.join([str(100.0 * correct[d] / total[d]) for d in opt.domains]))
        log.info('Evaluating validation sets:')
        acc = {}
        for domain in opt.dev_domains:
            acc[domain] = evaluate(domain, dev_loaders[domain],
                                   F_s, None, C)
        avg_acc = sum([acc[d] for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info('Average validation accuracy: {}'.format(avg_acc))
        log.info('Evaluating test sets:')
        test_acc = {}
        for domain in opt.dev_domains:
            test_acc[domain] = evaluate(domain, test_loaders[domain],
                                        F_s, None, C)
        avg_test_acc = sum([test_acc[d] for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info('Average test accuracy: {}'.format(avg_test_acc))

        if avg_acc > best_avg_acc:
            log.info('New best average validation accuracy: {}'.format(avg_acc))
            best_acc['valid'] = acc
            best_acc['test'] = test_acc
            best_avg_acc = avg_acc
            with open(os.path.join(opt.model_save_file, 'options.pkl'), 'wb') as ouf:
                pickle.dump(opt, ouf)
            # torch.save(C.state_dict(),
            #            '{}/netC.pth'.format(opt.model_save_file))

    # end of training
    log.info('Best average validation accuracy: {}'.format(best_avg_acc))
    return best_acc
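# In this variant F_s behaves like a BERT-style encoder returning
# (sequence_output, pooled_output), and the pooled vector serves as the
# shared feature. A minimal wrapper sketch under that assumption (the legacy
# pytorch_pretrained_bert return convention; not the repo's actual class):
import torch.nn as nn

class BertSharedEncoderSketch(nn.Module):
    def __init__(self, bert, hidden_size, shared_size):
        super().__init__()
        self.bert = bert  # any encoder with a BERT-like forward
        self.proj = nn.Linear(hidden_size, shared_size)

    def forward(self, input_ids, input_mask, segment_ids):
        sequence_output, pooled = self.bert(input_ids, input_mask, segment_ids)
        return sequence_output, self.proj(pooled)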
Exemple #12
def train(train_sets, dev_sets, test_sets, unlabeled_sets, fold):
    """
    train_sets, dev_sets, test_sets: dict[domain] -> TensorDataset
    For unlabeled domains, no train_sets are available
    """
    # dataset loaders
    train_loaders, unlabeled_loaders = {}, {}
    train_iters, unlabeled_iters = {}, {}
    dev_loaders, test_loaders = {}, {}
    for domain in opt.domains:
        train_loaders[domain] = DataLoader(train_sets[domain],
                                           opt.batch_size,
                                           shuffle=True)
        train_iters[domain] = iter(train_loaders[domain])
    for domain in opt.all_domains:
        dev_loaders[domain] = DataLoader(dev_sets[domain],
                                         opt.batch_size,
                                         shuffle=False)
        test_loaders[domain] = DataLoader(test_sets[domain],
                                          opt.batch_size,
                                          shuffle=False)
        if domain in opt.unlabeled_domains:
            uset = unlabeled_sets[domain]
        else:
            # for labeled domains, consider which data to use as unlabeled set
            if opt.unlabeled_data == 'both':
                uset = ConcatDataset(
                    [train_sets[domain], unlabeled_sets[domain]])
            elif opt.unlabeled_data == 'unlabeled':
                uset = unlabeled_sets[domain]
            elif opt.unlabeled_data == 'train':
                uset = train_sets[domain]
            else:
                raise Exception(
                    f'Unknown options for the unlabeled data usage: {opt.unlabeled_data}'
                )
        unlabeled_loaders[domain] = DataLoader(uset,
                                               opt.batch_size,
                                               shuffle=True)
        unlabeled_iters[domain] = iter(unlabeled_loaders[domain])

    # models
    F_s = None
    F_d = {}
    C, D = None, None
    if opt.model.lower() == 'mlp':
        F_s = MlpFeatureExtractor(opt.feature_num, opt.F_hidden_sizes,
                                  opt.shared_hidden_size, opt.dropout,
                                  opt.F_bn)
        for domain in opt.domains:
            F_d[domain] = MlpFeatureExtractor(opt.feature_num,
                                              opt.F_hidden_sizes,
                                              opt.domain_hidden_size,
                                              opt.dropout, opt.F_bn)
    else:
        raise Exception(f'Unknown model architecture {opt.model}')
    C = SentimentClassifier(opt.C_layers,
                            opt.shared_hidden_size + opt.domain_hidden_size,
                            opt.shared_hidden_size + opt.domain_hidden_size,
                            opt.num_labels, opt.dropout, opt.C_bn)
    D = DomainClassifier(opt.D_layers,
                         opt.shared_hidden_size, opt.shared_hidden_size,
                         len(opt.all_domains), opt.loss, opt.dropout, opt.D_bn)

    F_s, C, D = F_s.to(opt.device), C.to(opt.device), D.to(opt.device)
    for f_d in F_d.values():
        f_d.to(opt.device)  # nn.Module.to() moves parameters in place
    # optimizers
    optimizer = optim.Adam(itertools.chain(
        *map(list, [F_s.parameters() if F_s else [],
                    C.parameters()] + [f.parameters() for f in F_d.values()])),
                           lr=opt.learning_rate)
    optimizerD = optim.Adam(D.parameters(), lr=opt.D_learning_rate)

    # testing
    if opt.test_only:
        log.info(f'Loading model from {opt.model_save_file}...')
        F_s.load_state_dict(
            torch.load(
                os.path.join(opt.model_save_file, f'netF_s_fold{fold}.pth')))
        for domain in opt.all_domains:
            if domain in F_d:
                F_d[domain].load_state_dict(
                    torch.load(
                        os.path.join(opt.model_save_file,
                                     f'net_F_d_{domain}_fold{fold}.pth')))
        C.load_state_dict(
            torch.load(
                os.path.join(opt.model_save_file, f'netC_fold{fold}.pth')))
        D.load_state_dict(
            torch.load(
                os.path.join(opt.model_save_file, f'netD_fold{fold}.pth')))

        log.info('Evaluating validation sets:')
        acc = {}
        for domain in opt.all_domains:
            acc[domain] = evaluate(domain, dev_loaders[domain], F_s,
                                   F_d[domain] if domain in F_d else None, C)
        avg_acc = sum([acc[d] for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info(f'Average validation accuracy: {avg_acc}')
        log.info('Evaluating test sets:')
        test_acc = {}
        for domain in opt.all_domains:
            test_acc[domain] = evaluate(domain, test_loaders[domain], F_s,
                                        F_d[domain] if domain in F_d else None,
                                        C)
        avg_test_acc = sum([test_acc[d]
                            for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info(f'Average test accuracy: {avg_test_acc}')
        return {'valid': acc, 'test': test_acc}

    # training
    best_acc, best_avg_acc = defaultdict(float), 0.0

    for epoch in range(opt.max_epoch):
        F_s.train()
        C.train()
        D.train()
        for f in F_d.values():
            f.train()

        # training accuracy
        correct, total = defaultdict(int), defaultdict(int)
        # D accuracy
        d_correct, d_total = 0, 0
        # conceptually view 1 epoch as 1 epoch of the first domain
        num_iter = len(train_loaders[opt.domains[0]])
        for i in tqdm(range(num_iter)):
            # D iterations
            utils.freeze_net(F_s)
            for f_d in F_d.values():  # map() is lazy in Python 3, so loop explicitly
                utils.freeze_net(f_d)
            utils.freeze_net(C)
            utils.unfreeze_net(D)
            # optional WGAN n_critic trick
            n_critic = opt.n_critic
            if opt.wgan_trick:
                if opt.n_critic > 0 and ((epoch == 0 and i < 25)
                                         or i % 500 == 0):
                    n_critic = 100

            for _ in range(n_critic):
                D.zero_grad()
                loss_d = {}
                # train on both labeled and unlabeled domains
                for domain in opt.all_domains:
                    # targets not used
                    d_inputs, _ = utils.endless_get_next_batch(
                        unlabeled_loaders, unlabeled_iters, domain)
                    d_targets = utils.get_domain_label(opt.loss, domain,
                                                       len(d_inputs))
                    shared_feat = F_s(d_inputs)
                    d_outputs = D(shared_feat)
                    # D accuracy
                    _, pred = torch.max(d_outputs, 1)
                    d_total += len(d_inputs)
                    if opt.loss.lower() == 'l2':
                        _, tgt_indices = torch.max(d_targets, 1)
                        d_correct += (pred == tgt_indices).sum().item()
                        l_d = functional.mse_loss(d_outputs, d_targets)
                        l_d.backward()
                    else:
                        d_correct += (pred == d_targets).sum().item()
                        l_d = functional.nll_loss(d_outputs, d_targets)
                        l_d.backward()
                    loss_d[domain] = l_d.item()
                optimizerD.step()

            # F&C iteration
            utils.unfreeze_net(F_s)
            for f_d in F_d.values():
                utils.unfreeze_net(f_d)
            utils.unfreeze_net(C)
            utils.freeze_net(D)
            F_s.zero_grad()
            for f_d in F_d.values():
                f_d.zero_grad()
            C.zero_grad()
            shared_feats, domain_feats = [], []
            for domain in opt.domains:
                inputs, targets = utils.endless_get_next_batch(
                    train_loaders, train_iters, domain)
                targets = targets.to(opt.device)
                shared_feat = F_s(inputs)
                shared_feats.append(shared_feat)
                domain_feat = F_d[domain](inputs)
                domain_feats.append(domain_feat)
                features = torch.cat((shared_feat, domain_feat), dim=1)
                c_outputs = C(features)
                l_c = functional.nll_loss(c_outputs, targets)
                l_c.backward(retain_graph=True)
                # training accuracy
                _, pred = torch.max(c_outputs, 1)
                total[domain] += targets.size(0)
                correct[domain] += (pred == targets).sum().item()
            # update F with D gradients on all domains
            for domain in opt.all_domains:
                d_inputs, _ = utils.endless_get_next_batch(
                    unlabeled_loaders, unlabeled_iters, domain)
                shared_feat = F_s(d_inputs)
                d_outputs = D(shared_feat)
                if opt.loss.lower() == 'gr':
                    d_targets = utils.get_domain_label(opt.loss, domain,
                                                       len(d_inputs))
                    l_d = functional.nll_loss(d_outputs, d_targets)
                    log.debug(f'D loss: {l_d.item()}')
                    if opt.lambd > 0:
                        l_d *= -opt.lambd
                elif opt.loss.lower() == 'bs':
                    d_targets = utils.get_random_domain_label(
                        opt.loss, len(d_inputs))
                    l_d = functional.kl_div(d_outputs,
                                            d_targets,
                                            size_average=False)
                    if opt.lambd > 0:
                        l_d *= opt.lambd
                elif opt.loss.lower() == 'l2':
                    d_targets = utils.get_random_domain_label(
                        opt.loss, len(d_inputs))
                    l_d = functional.mse_loss(d_outputs, d_targets)
                    if opt.lambd > 0:
                        l_d *= opt.lambd
                l_d.backward()
                if opt.model.lower() != 'lstm':
                    log.debug(
                        f'F_s norm: {F_s.net[-2].weight.grad.data.norm()}')

            optimizer.step()

        # end of epoch
        log.info('Ending epoch {}'.format(epoch + 1))
        if d_total > 0:
            log.info('D Training Accuracy: {}%'.format(100.0 * d_correct /
                                                       d_total))
        log.info('Training accuracy:')
        log.info('\t'.join(opt.domains))
        log.info('\t'.join(
            [str(100.0 * correct[d] / total[d]) for d in opt.domains]))
        log.info('Evaluating validation sets:')
        acc = {}
        for domain in opt.all_domains:
            acc[domain] = evaluate(domain, dev_loaders[domain], F_s,
                                   F_d[domain] if domain in F_d else None, C)
        avg_acc = sum([acc[d] for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info(f'Average validation accuracy: {avg_acc}')
        log.info('Evaluating test sets:')
        test_acc = {}
        for domain in opt.all_domains:
            test_acc[domain] = evaluate(domain, test_loaders[domain], F_s,
                                        F_d[domain] if domain in F_d else None,
                                        C)
        avg_test_acc = sum([test_acc[d]
                            for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info(f'Average test accuracy: {avg_test_acc}')

        if avg_acc > best_avg_acc:
            log.info(f'New best average validation accuracy: {avg_acc}')
            best_acc['valid'] = acc
            best_acc['test'] = test_acc
            best_avg_acc = avg_acc
            with open(os.path.join(opt.model_save_file, 'options.pkl'),
                      'wb') as ouf:
                pickle.dump(opt, ouf)
            torch.save(
                F_s.state_dict(),
                '{}/netF_s_fold{}.pth'.format(opt.model_save_file, fold))
            for d in opt.domains:
                if d in F_d:
                    torch.save(
                        F_d[d].state_dict(), '{}/net_F_d_{}_fold{}.pth'.format(
                            opt.model_save_file, d, fold))
            torch.save(C.state_dict(),
                       '{}/netC_fold{}.pth'.format(opt.model_save_file, fold))
            torch.save(D.state_dict(),
                       '{}/netD_fold{}.pth'.format(opt.model_save_file, fold))

    # end of training
    log.info(f'Best average validation accuracy: {best_avg_acc}')
    return best_acc
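# For reference, the 'gr' option above implements adversarial training by
# negating the loss (l_d *= -opt.lambd); the same effect is often written as
# a gradient reversal layer. A standard sketch, not this repo's code:
import torch

class GradReverse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # pass gradients through with flipped sign, scaled by lambd
        return -ctx.lambd * grad_output, None

# usage: reversed_feat = GradReverse.apply(shared_feat, opt.lambd)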