示例#1
0
def predict(args):
    """Load a saved MoE model and report transfer accuracy on the test domain.

    Args:
        args: parsed CLI namespace; uses ``load_model`` (checkpoint path),
            ``cuda``, ``train`` (comma-separated source domain names),
            ``test`` (target domain name) and ``batch_size``.
    """
    encoder, classifiers, Us, Ps, Ns = torch.load(args.load_model)
    # NOTE: map() is lazy in Python 3 -- the original `map(lambda m:
    # m.eval(), ...)` never executed.  Iterate explicitly so every module
    # really switches to eval mode.
    for m in [encoder] + classifiers:
        m.eval()

    if args.cuda:
        # Same lazy-map fix: actually move modules and metric matrices to GPU.
        for m in [encoder] + classifiers:
            m.cuda()
        Us = [U.cuda() for U in Us]
        Ps = [P.cuda() for P in Ps]
        Ns = [N.cuda() for N in Ns]

    say("\nTransferring from %s to %s\n" % (args.train, args.test))
    source_train_sets = args.train.split(',')
    # One training loader per source domain (used for domain encoding).
    train_loaders = []
    for source in source_train_sets:
        filepath = os.path.join(DATA_DIR, "%s_train.svmlight" % (source))
        train_dataset = AmazonDataset(filepath)
        train_loader = data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=0
        )
        train_loaders.append(train_loader)

    test_filepath = os.path.join(DATA_DIR, "%s_test.svmlight" % (args.test))
    test_dataset = AmazonDataset(test_filepath)
    test_loader = data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0
    )
    say("Corpus loaded.\n")

    mats = [Us, Ps, Ns]
    (acc, oracle_acc), confusion_mat = evaluate(
        encoder, classifiers,
        mats,
        [train_loaders, test_loader],
        args
    )
    say(colored("Test accuracy/oracle {:.4f}/{:.4f}\n".format(acc, oracle_acc), 'red'))
示例#2
0
def train_moe_deep_stack(args):
    """Build stacking meta-features from pre-trained per-source experts.

    Loads the saved MoE heads (classifiers + attention modules) for the
    target domain and one pre-trained encoder per source domain, then
    iterates the target training set collecting expert outputs.

    NOTE(review): this snippet appears truncated below -- the batch loop
    ends abruptly right after computing ``alphas``.  Documented as-is.
    """
    save_model_dir = os.path.join(settings.OUT_DIR, args.test)
    # Previously trained per-source classifiers and attention modules.
    classifiers, attn_mats = torch.load(
        os.path.join(
            save_model_dir,
            "{}_{}_moe_best_now.mdl".format(args.test, args.base_model)))
    print("base model", args.base_model)
    print("classifier", classifiers[0])

    source_train_sets = args.train.split(',')
    pretrain_emb = torch.load(
        os.path.join(settings.OUT_DIR, "rnn_init_word_emb.emb"))

    # One pre-trained encoder ("expert") per source domain.
    encoders_src = []
    for src_i in range(len(source_train_sets)):
        cur_model_dir = os.path.join(settings.OUT_DIR,
                                     source_train_sets[src_i])

        if args.base_model == "cnn":
            encoder_class = CNNMatchModel(
                input_matrix_size1=args.matrix_size1,
                input_matrix_size2=args.matrix_size2,
                mat1_channel1=args.mat1_channel1,
                mat1_kernel_size1=args.mat1_kernel_size1,
                mat1_channel2=args.mat1_channel2,
                mat1_kernel_size2=args.mat1_kernel_size2,
                mat1_hidden=args.mat1_hidden,
                mat2_channel1=args.mat2_channel1,
                mat2_kernel_size1=args.mat2_kernel_size1,
                mat2_hidden=args.mat2_hidden)
        elif args.base_model == "rnn":
            encoder_class = BiLSTM(pretrain_emb=pretrain_emb,
                                   vocab_size=args.max_vocab_size,
                                   embedding_size=args.embedding_size,
                                   hidden_size=args.hidden_size,
                                   dropout=args.dropout)
        else:
            raise NotImplementedError
        if args.cuda:
            encoder_class.load_state_dict(
                torch.load(
                    os.path.join(
                        cur_model_dir,
                        "{}-match-best-now.mdl".format(args.base_model))))
        else:
            # Remap CUDA-saved weights onto the CPU.
            encoder_class.load_state_dict(
                torch.load(os.path.join(
                    cur_model_dir,
                    "{}-match-best-now.mdl".format(args.base_model)),
                           map_location=torch.device('cpu')))

        encoders_src.append(encoder_class)

    # NOTE(review): map() is lazy in Python 3, so neither eval() nor cuda()
    # below ever executes -- likely a latent bug kept as-is here.
    map(lambda m: m.eval(), encoders_src + classifiers + attn_mats)

    if args.cuda:
        map(lambda m: m.cuda(), classifiers + encoders_src + attn_mats)

    if args.base_model == "cnn":
        train_dataset_dst = ProcessedCNNInputDataset(args.test, "train")
        valid_dataset = ProcessedCNNInputDataset(args.test, "valid")
        test_dataset = ProcessedCNNInputDataset(args.test, "test")
    elif args.base_model == "rnn":
        train_dataset_dst = ProcessedRNNInputDataset(args.test, "train")
        valid_dataset = ProcessedRNNInputDataset(args.test, "valid")
        test_dataset = ProcessedRNNInputDataset(args.test, "test")
    else:
        raise NotImplementedError

    train_loader_dst = data.DataLoader(train_dataset_dst,
                                       batch_size=args.batch_size,
                                       shuffle=False,
                                       num_workers=0)

    valid_loader = data.DataLoader(valid_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=0)

    test_loader = data.DataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=0)
    say("Corpus loaded.\n")

    # Meta-feature width 192 + 2 * 8: presumably hidden size plus two logits
    # from each of 8 experts -- TODO confirm against the encoder config.
    meta_features = np.empty(shape=(0, 192 + 2 * 8))
    meta_labels = []
    n_sources = len(encoders_src)
    encoders = encoders_src

    if args.base_model == "cnn":
        for batch1, batch2, label in train_loader_dst:
            if args.cuda:
                batch1 = batch1.cuda()
                batch2 = batch2.cuda()
                label = label.cuda()

            # Per-expert hidden representation and class scores.
            outputs_dst_transfer = []
            hidden_from_src_enc = []
            for src_i in range(n_sources):
                _, cur_hidden = encoders[src_i](batch1, batch2)
                hidden_from_src_enc.append(cur_hidden)
                cur_output = classifiers[src_i](cur_hidden)
                outputs_dst_transfer.append(cur_output)

            source_ids = range(n_sources)
            support_ids = [x for x in source_ids]  # experts

            # Attention score of each expert for this batch.
            source_alphas = [
                attn_mats[j](hidden_from_src_enc[j]).squeeze()
                for j in source_ids
            ]

            support_alphas = [source_alphas[x] for x in support_ids]
            support_alphas = softmax(support_alphas)
            source_alphas = softmax(source_alphas)  # [ 32, 32, 32 ]
            alphas = source_alphas
示例#3
0
def train(args):
    """Train per-source MoE classifiers and attention heads on the target domain.

    Loads one frozen pre-trained encoder per source domain plus a target-domain
    encoder, builds a fresh classifier and attention module for each source,
    and optimizes only those task parameters on the target ("test") domain.
    Saves the best model (by validation loss) and reports test metrics.

    Args:
        args: parsed CLI namespace (train/test domain names, base_model,
            seeds, cuda flags, batch_size, lr, attn_type, lambda weights...).
    """
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    say('cuda is available %s\n' % args.cuda)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed + args.seed_delta)
    if args.cuda:
        torch.cuda.manual_seed(args.seed + args.seed_delta)

    source_train_sets = args.train.split(',')
    print("sources", source_train_sets)

    pretrain_emb = torch.load(
        os.path.join(settings.OUT_DIR, "rnn_init_word_emb.emb"))

    # One pre-trained encoder ("expert") per source domain.
    encoders_src = []
    for src_i in range(len(source_train_sets)):
        cur_model_dir = os.path.join(settings.OUT_DIR,
                                     source_train_sets[src_i])

        if args.base_model == "cnn":
            encoder_class = CNNMatchModel(
                input_matrix_size1=args.matrix_size1,
                input_matrix_size2=args.matrix_size2,
                mat1_channel1=args.mat1_channel1,
                mat1_kernel_size1=args.mat1_kernel_size1,
                mat1_channel2=args.mat1_channel2,
                mat1_kernel_size2=args.mat1_kernel_size2,
                mat1_hidden=args.mat1_hidden,
                mat2_channel1=args.mat2_channel1,
                mat2_kernel_size1=args.mat2_kernel_size1,
                mat2_hidden=args.mat2_hidden)
        elif args.base_model == "rnn":
            encoder_class = BiLSTM(pretrain_emb=pretrain_emb,
                                   vocab_size=args.max_vocab_size,
                                   embedding_size=args.embedding_size,
                                   hidden_size=args.hidden_size,
                                   dropout=args.dropout)
        else:
            raise NotImplementedError
        if args.cuda:
            encoder_class.load_state_dict(
                torch.load(
                    os.path.join(
                        cur_model_dir,
                        "{}-match-best-now.mdl".format(args.base_model))))
        else:
            # Remap CUDA-saved weights onto the CPU.
            encoder_class.load_state_dict(
                torch.load(os.path.join(
                    cur_model_dir,
                    "{}-match-best-now.mdl".format(args.base_model)),
                           map_location=torch.device('cpu')))

        encoders_src.append(encoder_class)

    # Freshly initialized encoder for the target domain.
    dst_pretrain_dir = os.path.join(settings.OUT_DIR, args.test)
    if args.base_model == "cnn":
        encoder_dst_pretrain = CNNMatchModel(
            input_matrix_size1=args.matrix_size1,
            input_matrix_size2=args.matrix_size2,
            mat1_channel1=args.mat1_channel1,
            mat1_kernel_size1=args.mat1_kernel_size1,
            mat1_channel2=args.mat1_channel2,
            mat1_kernel_size2=args.mat1_kernel_size2,
            mat1_hidden=args.mat1_hidden,
            mat2_channel1=args.mat2_channel1,
            mat2_kernel_size1=args.mat2_kernel_size1,
            mat2_hidden=args.mat2_hidden)
    elif args.base_model == "rnn":
        encoder_dst_pretrain = BiLSTM(pretrain_emb=pretrain_emb,
                                      vocab_size=args.max_vocab_size,
                                      embedding_size=args.embedding_size,
                                      hidden_size=args.hidden_size,
                                      dropout=args.dropout)
    else:
        raise NotImplementedError

    # NOTE: the original re-parsed CLI args here (`args =
    # argparser.parse_args()`), silently discarding the namespace passed in
    # by the caller.  Removed for consistency with predict(), where the same
    # lines are commented out.
    say(args)
    print()

    say("Transferring from %s to %s\n" % (args.train, args.test))

    if args.base_model == "cnn":
        train_dataset_dst = ProcessedCNNInputDataset(args.test, "train")
        valid_dataset = ProcessedCNNInputDataset(args.test, "valid")
        test_dataset = ProcessedCNNInputDataset(args.test, "test")
    elif args.base_model == "rnn":
        train_dataset_dst = ProcessedRNNInputDataset(args.test, "train")
        valid_dataset = ProcessedRNNInputDataset(args.test, "valid")
        test_dataset = ProcessedRNNInputDataset(args.test, "test")
    else:
        raise NotImplementedError

    train_loader_dst = data.DataLoader(train_dataset_dst,
                                       batch_size=args.batch_size,
                                       shuffle=False,
                                       num_workers=0)

    valid_loader = data.DataLoader(valid_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=0)

    test_loader = data.DataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=0)

    say("Corpus loaded.\n")

    # Fresh classifier + attention module per source expert.
    classifiers = []
    attn_mats = []
    for source in source_train_sets:
        classifier = nn.Sequential(
            nn.Linear(encoders_src[0].n_out, 64),
            nn.ReLU(),
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.Linear(16, 2),
        )
        # "onehot" attention: a single learned weight per source, initialized
        # to all ones (uniform attention).
        cur_att_weight = nn.Linear(len(encoders_src), 1, bias=False)
        cur_att_weight.weight = nn.Parameter(
            torch.ones(size=(1, len(encoders_src))), requires_grad=True)
        print("init cur att weight", cur_att_weight.weight)
        if args.attn_type == "onehot":
            attn_mats.append(cur_att_weight)
        elif args.attn_type == "cor":
            # "cor" attention: interaction between source and target hiddens.
            attn_mats.append(MulInteractAttention(encoders_src[0].n_out, 16))
        else:
            raise NotImplementedError
        classifiers.append(classifier)
    print("classifier build", classifiers[0])

    if args.cuda:
        # map() is lazy in Python 3; the original `map(lambda m: m.cuda(),
        # ...)` never executed.  Iterate explicitly so modules really move.
        for m in classifiers + encoders_src + attn_mats:
            m.cuda()

    for i, classifier in enumerate(classifiers):
        say("Classifier-{}: {}\n".format(i, classifier))

    # Only the task heads (classifiers + attention) are optimized; the
    # source encoders stay frozen.
    def requires_grad(x):
        return x.requires_grad

    task_params = []
    for src_i in range(len(classifiers)):
        task_params += list(classifiers[src_i].parameters())
        task_params += list(attn_mats[src_i].parameters())

    if args.base_model == "cnn":
        optim_model = optim.Adagrad(
            filter(requires_grad, task_params),
            lr=args.lr,
            weight_decay=1e-4  #TODO
        )
    elif args.base_model == "rnn":
        optim_model = optim.Adam(filter(requires_grad, task_params),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    else:
        raise NotImplementedError

    say("Training will begin from scratch\n")

    iter_cnt = 0
    min_loss_val = None
    best_test_results = None
    model_dir = os.path.join(settings.OUT_DIR, args.test)

    for epoch in range(args.max_epoch):
        print("training epoch", epoch)

        iter_cnt = train_epoch(iter_cnt, [encoders_src, encoder_dst_pretrain],
                               classifiers, attn_mats, train_loader_dst, args,
                               optim_model, epoch)

        # Pick the decision threshold on the validation set, reuse on test.
        thr, metrics_val = evaluate(epoch,
                                    [encoders_src, encoder_dst_pretrain],
                                    classifiers, attn_mats, valid_loader, True,
                                    args)

        _, metrics_test = evaluate(epoch, [encoders_src, encoder_dst_pretrain],
                                   classifiers,
                                   attn_mats,
                                   test_loader,
                                   False,
                                   args,
                                   thr=thr)

        # Checkpoint on best (lowest) validation loss.
        if min_loss_val is None or min_loss_val > metrics_val[0]:
            print("change val loss from {} to {}".format(
                min_loss_val, metrics_val[0]))
            min_loss_val = metrics_val[0]
            best_test_results = metrics_test
            torch.save([classifiers, attn_mats],
                       os.path.join(
                           model_dir,
                           "{}_{}_moe_simple_attn_best_now.mdl".format(
                               args.test, args.base_model)))
        say("\n")
        writer.flush()

    say(
        colored("Min valid loss: {:.4f}, best test results, "
                "AUC: {:.2f}, Prec: {:.2f}, Rec: {:.2f}, F1: {:.2f}\n".format(
                    min_loss_val, best_test_results[1] * 100,
                    best_test_results[2] * 100, best_test_results[3] * 100,
                    best_test_results[4] * 100)))
示例#4
0
def train_epoch(iter_cnt, encoders, classifiers, attn_mats, train_loader_dst,
                args, optim_model, epoch):
    """Run one MoE training epoch over the target-domain training loader.

    Args:
        iter_cnt: running iteration counter (returned updated).
        encoders: pair [source_encoders, target_encoder].
        classifiers: one classifier per source expert.
        attn_mats: one attention module per source expert.
        train_loader_dst: DataLoader over the target-domain training split.
        args: CLI namespace (base_model, cuda, attn_type, lambda_moe,
            lambda_entropy, ...).
        optim_model: optimizer over the task parameters.
        epoch: epoch index, used for TensorBoard logging.

    Returns:
        The updated iteration counter.
    """
    encoders, encoder_dst = encoders

    # map() is lazy in Python 3 -- the original `map(lambda m: m.train(),
    # ...)` never executed, so modules were never put into train mode.
    for m in classifiers + encoders + attn_mats:
        m.train()

    moe_criterion = nn.NLLLoss()  # with log_softmax separated
    entropy_criterion = HLoss()

    loss_total = 0
    n_batch = 0
    n_sources = len(encoders)

    for batch in train_loader_dst:
        if args.base_model == "cnn":
            batch1, batch2, label = batch
        elif args.base_model == "rnn":
            batch1, batch2, batch3, batch4, label = batch
        else:
            raise NotImplementedError

        bs = len(label)

        iter_cnt += 1
        n_batch += 1
        if args.cuda:
            batch1 = batch1.cuda()
            batch2 = batch2.cuda()
            label = label.cuda()
            if args.base_model == "rnn":
                batch3 = batch3.cuda()
                batch4 = batch4.cuda()

        # Target-encoder hidden state (used by the "cor" attention type).
        if args.base_model == "cnn":
            _, hidden_from_dst_enc = encoder_dst(batch1, batch2)
        elif args.base_model == "rnn":
            _, hidden_from_dst_enc = encoder_dst(batch1, batch2, batch3,
                                                 batch4)
        else:
            raise NotImplementedError

        # Per-expert hidden states, predictions and one-hot source selectors.
        outputs_dst_transfer = []
        hidden_from_src_enc = []
        one_hot_sources = []
        for src_i in range(n_sources):
            if args.base_model == "cnn":
                _, cur_hidden = encoders[src_i](batch1, batch2)
                hidden_from_src_enc.append(cur_hidden)
            elif args.base_model == "rnn":
                _, cur_hidden = encoders[src_i](batch1, batch2, batch3, batch4)
                hidden_from_src_enc.append(cur_hidden)
            else:
                raise NotImplementedError
            cur_output = classifiers[src_i](cur_hidden)
            outputs_dst_transfer.append(cur_output)
            cur_one_hot_sources = torch.zeros(size=(bs, n_sources))
            if args.cuda:
                # Keep the selector on the same device as the (cuda)
                # attention modules; the CPU tensor would otherwise raise a
                # device-mismatch error in the "onehot" branch below.
                cur_one_hot_sources = cur_one_hot_sources.cuda()
            cur_one_hot_sources[:, src_i] = 1
            one_hot_sources.append(cur_one_hot_sources)

        optim_model.zero_grad()

        source_ids = range(n_sources)
        support_ids = [x for x in source_ids]  # experts

        # Attention score per expert; "onehot" scores the source identity,
        # "cor" scores source-vs-target hidden interaction.
        if args.attn_type == "onehot":
            source_alphas = [
                attn_mats[j](one_hot_sources[j]).squeeze() for j in source_ids
            ]
        elif args.attn_type == "cor":
            source_alphas = [
                attn_mats[j](hidden_from_src_enc[j],
                             hidden_from_dst_enc).squeeze() for j in source_ids
            ]
        else:
            raise NotImplementedError

        support_alphas = [source_alphas[x] for x in support_ids]
        support_alphas = softmax(support_alphas)
        source_alphas = softmax(source_alphas)  # [ 32, 32, 32 ]

        if args.cuda:
            source_alphas = [alpha.cuda() for alpha in source_alphas]
        source_alphas = torch.stack(source_alphas, dim=0)
        source_alphas = source_alphas.permute(1, 0)

        # Entropy regularizer over the expert weights.
        loss_entropy = entropy_criterion(source_alphas)

        # Mixture-of-experts prediction: attention-weighted softmax outputs.
        output_moe = sum([alpha.unsqueeze(1).repeat(1, 2) * F.softmax(outputs_dst_transfer[id], dim=1) \
                            for alpha, id in zip(support_alphas, support_ids)])
        loss_moe = moe_criterion(torch.log(output_moe), label)
        lambda_moe = args.lambda_moe
        loss = lambda_moe * loss_moe
        loss += args.lambda_entropy * loss_entropy
        loss_total += loss.item()
        loss.backward()
        optim_model.step()

        if iter_cnt % 5 == 0:
            say("{} MOE loss: {:.4f}, Entropy loss: {:.4f}, "
                "loss: {:.4f}\n".format(iter_cnt, loss_moe.item(),
                                        loss_entropy.item(), loss.data.item()))

    loss_total /= n_batch
    writer.add_scalar('training_loss', loss_total, epoch)

    say("\n")
    return iter_cnt
示例#5
0
def evaluate_cross(encoder,
                   classifiers,
                   mats,
                   loaders,
                   return_best_thrs,
                   args,
                   thr=None):
    """Evaluate the MoE model across each source's validation loader.

    Args:
        encoder: shared feature encoder.
        classifiers: one classifier per source expert.
        mats: metric matrices -- (Us, Ws, Vs) for "biaffine", else (Us, Ps, Ns).
        loaders: pair [source_loaders, valid_loaders_src].
        return_best_thrs: if True, also compute the F1-optimal threshold per
            source from its precision-recall curve.
        args: CLI namespace (cuda, metric, eval_only, ...).
        thr: optional per-source decision thresholds to apply to y_score.

    Returns:
        (thresholds, metrics, alphas_weights) where metrics is an array of
        [auc, prec, rec, f1] rows and alphas_weights the mean attention each
        validation domain puts on each expert.
    """
    # map() is lazy in Python 3 -- the original `map(lambda m: m.eval(),
    # ...)` never executed.  Iterate explicitly.
    for m in [encoder] + classifiers:
        m.eval()

    if args.metric == "biaffine":
        Us, Ws, Vs = mats
    else:
        Us, Ps, Ns = mats

    source_loaders, valid_loaders_src = loaders
    domain_encs = domain_encoding(source_loaders, args, encoder)

    n_domains = len(valid_loaders_src)
    source_ids = range(n_domains)

    thresholds = []
    metrics = []
    # Mean attention weight each validation domain (row) puts on each expert
    # (column); was hard-coded to shape (4, 4).
    alphas_weights = np.zeros(shape=(n_domains, n_domains))

    for src_i in range(n_domains):
        valid_loader = valid_loaders_src[src_i]

        oracle_correct = 0
        correct = 0
        tot_cnt = 0
        y_true = []
        y_pred = []
        y_score = []

        support_ids = [x for x in source_ids]  # experts
        cur_domain_encs = [domain_encs[x] for x in support_ids]
        cur_Us = [Us[x] for x in support_ids]
        cur_Ps = [Ps[x] for x in support_ids]
        cur_Ns = [Ns[x] for x in support_ids]

        cur_alpha_weights_stack = np.empty(shape=(0, len(support_ids)))

        for batch1, batch2, label in valid_loader:
            if args.cuda:
                batch1 = batch1.cuda()
                batch2 = batch2.cuda()
                label = label.cuda()

            batch1 = Variable(batch1)
            batch2 = Variable(batch2)
            _, hidden = encoder(batch1, batch2)
            # NOTE(review): these metric-based alphas are computed but then
            # discarded and replaced by a hard one-hot on src_i below --
            # presumably a deliberate ablation; confirm before removing.
            if args.metric == "biaffine":
                alphas = [biaffine_metric_fast(hidden, mu[0], Us[0]) \
                          for mu in domain_encs]
            else:
                alphas = [mahalanobis_metric_fast(hidden, mu[0], U, mu[1], P, mu[2], N) \
                          for (mu, U, P, N) in zip(cur_domain_encs, cur_Us, cur_Ps, cur_Ns)]
            alphas = softmax(alphas)

            # Hard routing: all weight on the expert matching this domain.
            alphas = []
            for al_i in range(len(support_ids)):
                alphas.append(torch.zeros(size=(batch1.size()[0], )))
            alphas[src_i] = torch.ones(size=(batch1.size()[0], ))

            alpha_cat = torch.zeros(size=(alphas[0].shape[0],
                                          len(support_ids)))
            for col, a_list in enumerate(alphas):
                alpha_cat[:, col] = a_list
            cur_alpha_weights_stack = np.concatenate(
                (cur_alpha_weights_stack, alpha_cat.detach().numpy()))
            if args.cuda:
                alphas = [alpha.cuda() for alpha in alphas]
            alphas = [Variable(alpha) for alpha in alphas]

            # Attention-weighted mixture of the experts' softmax outputs.
            outputs = [
                F.softmax(classifiers[j](hidden), dim=1) for j in support_ids
            ]
            output = sum([alpha.unsqueeze(1).repeat(1, 2) * output_i \
                          for (alpha, output_i) in zip(alphas, outputs)])
            pred = output.data.max(dim=1)[1]
            oracle_eq = compute_oracle(outputs, label, args)

            if args.eval_only:
                # Verbose per-example dump of weights, outputs and oracle.
                for i in range(batch1.shape[0]):
                    for j in range(len(alphas)):
                        say("{:.4f}: [{:.4f}, {:.4f}], ".format(
                            alphas[j].data[i], outputs[j].data[i][0],
                            outputs[j].data[i][1]))
                    oracle_TF = "T" if oracle_eq[i] == 1 else colored(
                        "F", 'red')
                    say("gold: {}, pred: {}, oracle: {}\n".format(
                        label[i], pred[i], oracle_TF))
                say("\n")

            y_true += label.tolist()
            y_pred += pred.tolist()
            y_score += output[:, 1].data.tolist()
            correct += pred.eq(label).sum()
            oracle_correct += oracle_eq.sum()
            tot_cnt += output.size(0)

        alphas_weights[src_i, support_ids] = np.mean(cur_alpha_weights_stack,
                                                     axis=0)

        if thr is not None:
            print("using threshold %.4f" % thr[src_i])
            y_score = np.array(y_score)
            y_pred = np.zeros_like(y_score)
            y_pred[y_score > thr[src_i]] = 1

        acc = float(correct) / tot_cnt
        oracle_acc = float(oracle_correct) / tot_cnt

        prec, rec, f1, _ = precision_recall_fscore_support(y_true,
                                                           y_pred,
                                                           average="binary")
        auc = roc_auc_score(y_true, y_score)
        print("source {}, AUC: {:.2f}, Prec: {:.2f}, Rec: {:.2f}, F1: {:.2f}".
              format(src_i, auc * 100, prec * 100, rec * 100, f1 * 100))

        metrics.append([auc, prec, rec, f1])

        if return_best_thrs:
            # Threshold maximizing F1 on the PR curve (last point dropped:
            # precision_recall_curve returns one more prec/rec than thrs).
            precs, recs, thrs = precision_recall_curve(y_true, y_score)
            f1s = 2 * precs * recs / (precs + recs)
            f1s = f1s[:-1]
            thrs = thrs[~np.isnan(f1s)]
            f1s = f1s[~np.isnan(f1s)]
            best_thr = thrs[np.argmax(f1s)]
            # Original passed the %-format string and args separately to
            # print(), emitting the literal format string.
            print("best threshold=%.4f, f1=%.4f" % (best_thr, np.max(f1s)))
            thresholds.append(best_thr)

    print("source domain weight matrix\n", alphas_weights)

    metrics = np.array(metrics)
    return thresholds, metrics, alphas_weights
示例#6
0
def train_epoch(iter_cnt, encoders, classifiers, critic, mats, data_loaders,
                args, optim_model, epoch):
    """Run one epoch of multi-source mixture-of-experts training on the target.

    Args:
        iter_cnt: running iteration counter; incremented per batch and returned.
        encoders: pair ``[source_encoders, target_encoder]`` where
            ``source_encoders`` is a list with one encoder per source domain.
        classifiers: triple ``[source_classifiers, target_classifier,
            mixing_module]``; the mixing module exposes a ``multp`` scale.
        critic: domain critic, used only when ``args.lambda_critic > 0``.
        mats: metric parameters — ``[Us, Ws, Vs]`` for the biaffine metric,
            ``[Us, Ps, Ns]`` for the mahalanobis metric.
        data_loaders: ``[source_train_loaders, target_train_loader,
            unlabeled_loader, valid_loader]``.
        args: parsed command-line arguments (loss weights, cuda flag, metric).
        optim_model: optimizer over all trainable parameters.
        epoch: epoch index, used only for TensorBoard logging.

    Returns:
        The updated iteration counter ``iter_cnt``.
    """
    encoders, encoder_dst = encoders
    classifiers, classifier_dst, classifier_mix = classifiers
    # In Python 3 ``map`` is lazy; the original ``map(lambda m: m.train(), ...)``
    # never ran, so the modules were never switched to training mode.
    for module in (encoders +
                   [encoder_dst, classifier_dst, critic, classifier_mix] +
                   classifiers):
        module.train()

    train_loaders, train_loader_dst, unl_loader, valid_loader = data_loaders

    mtl_criterion = nn.NLLLoss()
    moe_criterion = nn.NLLLoss()  # inputs are already log-probabilities
    kl_criterion = nn.MSELoss()
    entropy_criterion = HLoss()

    if args.metric == "biaffine":
        metric = biaffine_metric
        Us, Ws, Vs = mats
    else:
        metric = mahalanobis_metric
        Us, Ps, Ns = mats

    loss_total = 0
    total = 0

    # Iterate sources, target and unlabeled data in lockstep; the epoch ends
    # when the shortest loader is exhausted.
    for batches, batches_dst, unl_batch in zip(zip(*train_loaders),
                                               train_loader_dst, unl_loader):
        train_batches1, train_batches2, train_labels = zip(*batches)
        unl_critic_batch1, unl_critic_batch2, unl_critic_label = unl_batch
        batches1_dst, batches2_dst, labels_dst = batches_dst

        total += len(batches1_dst)

        iter_cnt += 1
        if args.cuda:
            train_batches1 = [batch.cuda() for batch in train_batches1]
            train_batches2 = [batch.cuda() for batch in train_batches2]
            train_labels = [label.cuda() for label in train_labels]

            batches1_dst = batches1_dst.cuda()
            batches2_dst = batches2_dst.cuda()
            labels_dst = labels_dst.cuda()

            unl_critic_batch1 = unl_critic_batch1.cuda()
            unl_critic_batch2 = unl_critic_batch2.cuda()
            unl_critic_label = unl_critic_label.cuda()

        optim_model.zero_grad()
        loss_train_dst = []
        loss_mtl = []
        loss_moe = []
        loss_kl = []
        loss_entropy = []
        loss_dan = []
        loss_all = []

        ms_outputs = []  # (n_sources, n_classifiers)
        hiddens = []
        hidden_corresponding_labels = []

        # Supervised loss on the target batch with its own encoder/classifier.
        _, hidden_dst = encoder_dst(batches1_dst, batches2_dst)
        cur_output_dst = classifier_dst(hidden_dst)
        cur_output_dst_mem = torch.softmax(cur_output_dst, dim=1)
        cur_output_dst = torch.log(cur_output_dst_mem)
        loss_train_dst.append(mtl_criterion(cur_output_dst, labels_dst))

        # Target batch pushed through every source expert (encoder+classifier).
        outputs_dst_transfer = []
        for i in range(len(train_batches1)):
            _, cur_hidden = encoders[i](batches1_dst, batches2_dst)
            cur_output = classifiers[i](cur_hidden)
            outputs_dst_transfer.append(cur_output)

        for i, (batch1, batch2, label) in enumerate(
                zip(train_batches1, train_batches2, train_labels)):  # source i
            _, hidden = encoders[i](batch1, batch2)
            outputs = []
            # Output matrix: (i, j) is the i-th source batch scored by the
            # j-th classifier.
            hiddens.append(hidden)
            for classifier in classifiers:
                output = classifier(hidden)
                output = torch.log_softmax(output, dim=1)
                outputs.append(output)
            ms_outputs.append(outputs)
            hidden_corresponding_labels.append(label)
            # Multi-task loss: each source supervised by its own classifier.
            loss_mtl.append(mtl_criterion(ms_outputs[i][i], label))

            if args.lambda_critic > 0:
                critic_label = torch.cat(
                    [1 - unl_critic_label, unl_critic_label])

                if isinstance(critic, ClassificationD):
                    # NOTE(review): ``torch.cat`` takes a *sequence* of tensors
                    # (the second positional arg is ``dim``) and the encoder
                    # returns a tuple, so this branch would raise if executed —
                    # confirm before enabling ClassificationD critics.
                    critic_output = critic(
                        torch.cat(
                            hidden, encoders[i](unl_critic_batch1,
                                                unl_critic_batch2)))
                    loss_dan.append(
                        critic.compute_loss(critic_output, critic_label))
                else:
                    critic_output = critic(
                        hidden, encoders[i](unl_critic_batch1,
                                            unl_critic_batch2))
                    loss_dan.append(critic_output)
            else:
                loss_dan = Variable(torch.FloatTensor([0]))

        source_ids = range(len(train_batches1))
        support_ids = [x for x in source_ids]  # every source acts as an expert

        # Expert weights: similarity between the target hidden state and each
        # source batch under the learned metric.
        if args.metric == "biaffine":
            source_alphas = [
                metric(
                    hidden_dst,
                    hiddens[j].detach(),
                    Us[0],
                    Ws[0],
                    Vs[0],  # the biaffine metric uses one unified matrix
                    args) for j in source_ids
            ]
        else:
            source_alphas = [
                metric(
                    hidden_dst,
                    hiddens[j].detach(),
                    hidden_corresponding_labels[j],
                    Us[j],
                    Ps[j],
                    Ns[j],
                    args) for j in source_ids
            ]

        support_alphas = [source_alphas[x] for x in support_ids]
        support_alphas = softmax(support_alphas)

        # Meta-supervision: MSE loss between alphas and domain indicators.
        source_alphas = softmax(source_alphas)
        # NOTE(review): ``x`` (0..k-1) never equals ``len(train_batches1)``,
        # so these labels are all zeros — confirm this is intended.
        source_labels = [
            torch.FloatTensor([x == len(train_batches1)]) for x in source_ids
        ]
        if args.cuda:
            source_alphas = [alpha.cuda() for alpha in source_alphas]
            source_labels = [label.cuda() for label in source_labels]

        source_labels = Variable(torch.stack(source_labels, dim=0))
        source_alphas = torch.stack(source_alphas, dim=0)

        source_labels = source_labels.expand_as(source_alphas).permute(1, 0)
        source_alphas = source_alphas.permute(1, 0)
        loss_kl.append(kl_criterion(source_alphas, source_labels))

        # Entropy regularization over the expert weights.
        loss_entropy.append(entropy_criterion(source_alphas))

        # Mixture-of-experts prediction for the target batch.
        output_moe_i = sum([alpha.unsqueeze(1).repeat(1, 2) * F.softmax(outputs_dst_transfer[id], dim=1) \
                            for alpha, id in zip(support_alphas, support_ids)])
        loss_moe.append(moe_criterion(torch.log(output_moe_i), labels_dst))

        # Blend the target-only prediction with the MoE prediction using the
        # learned scalar ``classifier_mix.multp``.
        upper_out = cur_output_dst_mem + classifier_mix.multp * output_moe_i
        loss_all = mtl_criterion(torch.log_softmax(upper_out, dim=1),
                                 labels_dst)

        loss_train_dst = sum(loss_train_dst)

        loss_mtl = sum(loss_mtl)
        loss_mtl /= len(source_ids)
        loss_moe = sum(loss_moe)
        lambda_moe = args.lambda_moe
        loss = args.lambda_mtl * loss_mtl + lambda_moe * loss_moe
        loss_kl = sum(loss_kl)
        loss_entropy = sum(loss_entropy)
        loss += args.lambda_entropy * loss_entropy
        loss += loss_train_dst * args.lambda_dst
        loss += loss_all * args.lambda_all

        # Accumulate a detached Python float; the original ``loss_total +=
        # loss`` kept every batch's autograd graph alive for the whole epoch.
        loss_total += loss.item()

        if args.lambda_critic > 0:
            loss_dan = sum(loss_dan)
            loss += args.lambda_critic * loss_dan

        loss.backward()
        optim_model.step()

        if iter_cnt % 5 == 0:
            if args.metric == "biaffine":
                mats = [Us, Ws, Vs]
            else:
                mats = [Us, Ps, Ns]

            say("{} MTL loss: {:.4f}, MOE loss: {:.4f}, DAN loss: {:.4f}, "
                "loss: {:.4f}\n"
                .format(iter_cnt,
                        loss_mtl.item(),
                        loss_moe.item(),
                        loss_dan.item(),
                        loss.item(),
                        ))

    writer.add_scalar('training_loss', loss_total / total, epoch)

    say("\n")
    return iter_cnt
示例#7
0
def train(args):
    ''' Training Strategy

    Input: source = {S1, S2, ..., Sk}, target = {T}

    Train:
        Approach 1: fix metric and learn encoder only
        Approach 2: learn metric and encoder alternatively

    Builds one CNN encoder/classifier per source domain plus a dedicated
    target encoder/classifier and a learned mixing scaler, initializes the
    metric matrices, then alternates train_epoch / evaluate for
    ``args.max_epoch`` epochs.
    '''

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    say('cuda is available %s\n' % args.cuda)

    # Seed all RNGs for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed + args.seed_delta)
    if args.cuda:
        torch.cuda.manual_seed(args.seed + args.seed_delta)

    source_train_sets = args.train.split(',')
    print("sources", source_train_sets)

    # One CNN encoder per source domain (independent weights).
    encoders = []
    for _ in range(len(source_train_sets)):
        encoder_class = CNNMatchModel(input_matrix_size1=args.matrix_size1,
                                      input_matrix_size2=args.matrix_size2,
                                      mat1_channel1=args.mat1_channel1,
                                      mat1_kernel_size1=args.mat1_kernel_size1,
                                      mat1_channel2=args.mat1_channel2,
                                      mat1_kernel_size2=args.mat1_kernel_size2,
                                      mat1_hidden=args.mat1_hidden,
                                      mat2_channel1=args.mat2_channel1,
                                      mat2_kernel_size1=args.mat2_kernel_size1,
                                      mat2_hidden=args.mat2_hidden)
        encoders.append(encoder_class)

    # Separate encoder trained directly on the (small) target training set.
    encoder_dst = CNNMatchModel(input_matrix_size1=args.matrix_size1,
                                input_matrix_size2=args.matrix_size2,
                                mat1_channel1=args.mat1_channel1,
                                mat1_kernel_size1=args.mat1_kernel_size1,
                                mat1_channel2=args.mat1_channel2,
                                mat1_kernel_size2=args.mat1_kernel_size2,
                                mat1_hidden=args.mat1_hidden,
                                mat2_channel1=args.mat2_channel1,
                                mat2_kernel_size1=args.mat2_kernel_size1,
                                mat2_hidden=args.mat2_hidden)

    critic_class = get_critic_class(args.critic)
    critic_class.add_config(argparser)

    # Re-parse so that critic-specific options are picked up.
    args = argparser.parse_args()
    say(args)

    print()
    print("encoder", encoders[0])

    say("Transferring from %s to %s\n" % (args.train, args.test))
    train_loaders = []
    Us = []
    Ps = []
    Ns = []
    Ws = []
    Vs = []

    for source in source_train_sets:
        filepath = os.path.join(settings.DOM_ADAPT_DIR,
                                "{}_train.pkl".format(source))
        assert (os.path.exists(filepath))
        train_dataset = ProcessedCNNInputDataset(source, "train")
        train_loader = data.DataLoader(train_dataset,
                                       batch_size=args.batch_size,
                                       shuffle=False,
                                       num_workers=0)
        train_loaders.append(train_loader)

        # Per-source metric parameters (shape depends on the chosen metric).
        if args.metric == "biaffine":
            U = torch.FloatTensor(encoders[0].n_d, encoders[0].n_d)
            W = torch.FloatTensor(encoders[0].n_d, 1)
            # Use the in-place initializer; ``nn.init.xavier_uniform`` is
            # deprecated and inconsistent with the ``xavier_uniform_`` calls
            # below.
            nn.init.xavier_uniform_(W)
            Ws.append(W)
            V = torch.FloatTensor(encoders[0].n_d, 1)
            nn.init.xavier_uniform_(V)
            Vs.append(V)
        else:
            U = torch.FloatTensor(encoders[0].n_d, args.m_rank)

        nn.init.xavier_uniform_(U)
        Us.append(U)
        P = torch.FloatTensor(encoders[0].n_d, args.m_rank)
        nn.init.xavier_uniform_(P)
        Ps.append(P)
        N = torch.FloatTensor(encoders[0].n_d, args.m_rank)
        nn.init.xavier_uniform_(N)
        Ns.append(N)

    # Unlabeled target data, labeled by domain, for the critic.
    unl_filepath = os.path.join(settings.DOM_ADAPT_DIR,
                                "{}_train.pkl".format(args.test))
    print("****************", unl_filepath)
    assert (os.path.exists(unl_filepath))
    unl_dataset = OAGDomainDataset(args.test, "train")
    unl_loader = data.DataLoader(unl_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=0)

    train_dataset_dst = ProcessedCNNInputDataset(args.test, "train")
    train_loader_dst = data.DataLoader(train_dataset_dst,
                                       batch_size=args.batch_size,
                                       shuffle=False,
                                       num_workers=0)

    valid_dataset = ProcessedCNNInputDataset(args.test, "valid")
    print("valid y", len(valid_dataset), valid_dataset.y)
    valid_loader = data.DataLoader(valid_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=0)

    test_dataset = ProcessedCNNInputDataset(args.test, "test")
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=0)
    say("Corpus loaded.\n")

    # One linear binary classifier per source domain.
    classifiers = []
    for source in source_train_sets:  # only one layer
        classifier = nn.Linear(encoders[0].n_out, 2)  # binary classification
        classifiers.append(classifier)

    classifier_dst = nn.Linear(encoder_dst.n_out, 2)
    classifier_mix = WeightScaler()

    critic = critic_class(encoders[0], args)

    if args.cuda:
        # In Python 3 ``map`` is lazy; the original
        # ``map(lambda m: m.cuda(), ...)`` never moved the modules to GPU.
        for module in ([encoder_dst, critic, classifier_dst, classifier_mix] +
                       encoders + classifiers):
            module.cuda()
        Us = [Variable(U.cuda(), requires_grad=True) for U in Us]
        Ps = [Variable(P.cuda(), requires_grad=True) for P in Ps]
        Ns = [Variable(N.cuda(), requires_grad=True) for N in Ns]
        if args.metric == "biaffine":
            Ws = [Variable(W.cuda(), requires_grad=True) for W in Ws]
            Vs = [Variable(V.cuda(), requires_grad=True) for V in Vs]

    for i, classifier in enumerate(classifiers):
        say("Classifier-{}: {}\n".format(i, classifier))
    say("Critic: {}\n".format(critic))

    # Collect every trainable parameter, including the raw metric tensors.
    requires_grad = lambda x: x.requires_grad
    task_params = []
    for encoder in encoders:
        task_params += encoder.parameters()
    task_params += encoder_dst.parameters()
    for classifier in classifiers:
        task_params += list(classifier.parameters())
    task_params += classifier_dst.parameters()
    task_params += classifier_mix.parameters()
    task_params += list(critic.parameters())
    task_params += Us
    task_params += Ps
    task_params += Ns
    if args.metric == "biaffine":
        task_params += Ws
        task_params += Vs

    optim_model = optim.Adagrad(  # use adagrad instead of adam
        filter(requires_grad, task_params),
        lr=args.lr,
        weight_decay=1e-4)

    say("Training will begin from scratch\n")

    best_dev = 0
    best_test = 0
    iter_cnt = 0

    for epoch in range(args.max_epoch):
        say("epoch: {}\n".format(epoch))
        if args.metric == "biaffine":
            mats = [Us, Ws, Vs]
        else:
            mats = [Us, Ps, Ns]

        iter_cnt = train_epoch(
            iter_cnt, [encoders, encoder_dst],
            [classifiers, classifier_dst, classifier_mix], critic, mats,
            [train_loaders, train_loader_dst, unl_loader, valid_loader], args,
            optim_model, epoch)

        # Pick the decision threshold on the validation set, then reuse it
        # on the test set.
        thr, metrics_val = evaluate(
            epoch, [encoders, encoder_dst],
            [classifiers, classifier_dst, classifier_mix], mats,
            [train_loaders, valid_loader], True, args)
        _, metrics_test = evaluate(
            epoch, [encoders, encoder_dst],
            [classifiers, classifier_dst, classifier_mix],
            mats, [train_loaders, test_loader],
            False,
            args,
            thr=thr)
        say("\n")

    # NOTE(review): ``best_test`` is never updated in this variant, so this
    # always prints 0.0000 — the model-selection block appears to have been
    # removed.
    say(colored("Best test accuracy {:.4f}\n".format(best_test), 'red'))
示例#8
0
def train(args):
    ''' Training Strategy

    Input: source = {S1, S2, ..., Sk}, target = {T}

    Train:
        Approach 1: fix metric and learn encoder only
        Approach 2: learn metric and encoder alternatively

    MLP variant: a single encoder is shared across all source domains, with
    one linear classifier per source; trains with Adam and keeps the model
    that scores best on the validation set.
    '''

    encoder_class = get_model_class("mlp")
    encoder_class.add_config(argparser)
    critic_class = get_critic_class(args.critic)
    critic_class.add_config(argparser)

    # Re-parse so encoder-/critic-specific options are picked up.
    args = argparser.parse_args()
    say(args)

    # encoder is shared across domains
    encoder = encoder_class(args)

    print()
    print("encoder", encoder)

    say("Transferring from %s to %s\n" % (args.train, args.test))
    source_train_sets = args.train.split(',')
    train_loaders = []
    Us = []
    Ps = []
    Ns = []
    Ws = []
    Vs = []
    for source in source_train_sets:
        filepath = os.path.join(DATA_DIR, "%s_train.svmlight" % (source))
        assert (os.path.exists(filepath))
        train_dataset = AmazonDataset(filepath)
        train_loader = data.DataLoader(train_dataset,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=0)
        train_loaders.append(train_loader)

        # Per-source metric parameters (shape depends on the chosen metric).
        if args.metric == "biaffine":
            U = torch.FloatTensor(encoder.n_d, encoder.n_d)
            W = torch.FloatTensor(encoder.n_d, 1)
            # In-place initializer; ``nn.init.xavier_uniform`` is deprecated
            # and inconsistent with the ``xavier_uniform_`` calls below.
            nn.init.xavier_uniform_(W)
            Ws.append(W)
            V = torch.FloatTensor(encoder.n_d, 1)
            nn.init.xavier_uniform_(V)
            Vs.append(V)
        else:
            U = torch.FloatTensor(encoder.n_d, args.m_rank)

        nn.init.xavier_uniform_(U)
        Us.append(U)
        P = torch.FloatTensor(encoder.n_d, args.m_rank)
        nn.init.xavier_uniform_(P)
        Ps.append(P)
        N = torch.FloatTensor(encoder.n_d, args.m_rank)
        nn.init.xavier_uniform_(N)
        Ns.append(N)

    # Unlabeled target data (domain labels only) for the critic.
    unl_filepath = os.path.join(DATA_DIR, "%s_train.svmlight" % (args.test))
    print("****************", unl_filepath)
    assert (os.path.exists(unl_filepath))
    unl_dataset = AmazonDomainDataset(unl_filepath)  # using domain as labels
    unl_loader = data.DataLoader(unl_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=0)

    # NOTE(review): no separate dev split exists, so the test file doubles
    # as the validation set — model selection leaks test data.
    valid_filepath = os.path.join(DATA_DIR, "%s_test.svmlight" %
                                  (args.test))  # No dev files
    valid_dataset = AmazonDataset(valid_filepath)
    valid_loader = data.DataLoader(valid_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=0)

    test_filepath = os.path.join(DATA_DIR, "%s_test.svmlight" % (args.test))
    assert (os.path.exists(test_filepath))
    test_dataset = AmazonDataset(test_filepath)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=0)
    say("Corpus loaded.\n")

    # One linear binary classifier per source domain.
    classifiers = []
    for source in source_train_sets:
        classifier = nn.Linear(encoder.n_out, 2)  # binary classification
        # In-place initializers; the non-underscore forms are deprecated.
        nn.init.xavier_normal_(classifier.weight)
        nn.init.constant_(classifier.bias, 0.1)
        classifiers.append(classifier)

    critic = critic_class(encoder, args)

    if args.cuda:
        # In Python 3 ``map`` is lazy; the original
        # ``map(lambda m: m.cuda(), ...)`` never moved the modules to GPU.
        for module in [encoder, critic] + classifiers:
            module.cuda()
        Us = [Variable(U.cuda(), requires_grad=True) for U in Us]
        Ps = [Variable(P.cuda(), requires_grad=True) for P in Ps]
        Ns = [Variable(N.cuda(), requires_grad=True) for N in Ns]
        if args.metric == "biaffine":
            Ws = [Variable(W.cuda(), requires_grad=True) for W in Ws]
            Vs = [Variable(V.cuda(), requires_grad=True) for V in Vs]

    say("\nEncoder: {}\n".format(encoder))
    for i, classifier in enumerate(classifiers):
        say("Classifier-{}: {}\n".format(i, classifier))
    say("Critic: {}\n".format(critic))

    # Collect every trainable parameter, including the raw metric tensors.
    requires_grad = lambda x: x.requires_grad
    task_params = list(encoder.parameters())
    for classifier in classifiers:
        task_params += list(classifier.parameters())
    task_params += list(critic.parameters())
    task_params += Us
    task_params += Ps
    task_params += Ns
    if args.metric == "biaffine":
        task_params += Ws
        task_params += Vs

    optim_model = optim.Adam(filter(requires_grad, task_params),
                             lr=args.lr,
                             weight_decay=1e-4)

    say("Training will begin from scratch\n")

    best_dev = 0
    best_test = 0
    iter_cnt = 0

    for epoch in range(args.max_epoch):
        if args.metric == "biaffine":
            mats = [Us, Ws, Vs]
        else:
            mats = [Us, Ps, Ns]

        iter_cnt = train_epoch(iter_cnt, encoder, classifiers, critic, mats,
                               [train_loaders, unl_loader, valid_loader], args,
                               optim_model)

        (curr_dev, oracle_curr_dev), confusion_mat = evaluate(
            encoder, classifiers, mats, [train_loaders, valid_loader], args)
        say("Dev accuracy/oracle: {:.4f}/{:.4f}\n".format(
            curr_dev, oracle_curr_dev))
        (curr_test, oracle_curr_test), confusion_mat = evaluate(
            encoder, classifiers, mats, [train_loaders, test_loader], args)
        say("Test accuracy/oracle: {:.4f}/{:.4f}\n".format(
            curr_test, oracle_curr_test))

        # Keep the model with the best validation accuracy.
        if curr_dev >= best_dev:
            best_dev = curr_dev
            best_test = curr_test
            print(confusion_mat)
            if args.save_model:
                say(
                    colored(
                        "Save model to {}\n".format(args.save_model + ".best"),
                        'red'))
                torch.save([encoder, classifiers, Us, Ps, Ns],
                           args.save_model + ".best")
        say("\n")

    say(colored("Best test accuracy {:.4f}\n".format(best_test), 'red'))
示例#9
0
def evaluate(encoder, classifiers, mats, loaders, args):
    ''' Evaluate model using MOE

    Scores the validation loader with a mixture of the per-source
    classifiers, weighted by the metric-based similarity between each
    sample's hidden state and the source-domain encodings.

    Args:
        encoder: shared encoder.
        classifiers: one linear classifier per source domain.
        mats: [Us, Ws, Vs] for the biaffine metric, else [Us, Ps, Ns].
        loaders: [source train loaders, loader to evaluate on].
        args: parsed command-line arguments.

    Returns:
        ((accuracy, oracle_accuracy), confusion_matrix) where the oracle
        accuracy counts a sample correct if ANY expert classifies it
        correctly.
    '''
    # In Python 3 ``map`` is lazy; the original ``map(lambda m: m.eval(), ...)``
    # never ran, so the modules stayed in training mode during evaluation.
    # NOTE(review): the modules are not switched back to train mode here;
    # the caller is expected to do so.
    for module in [encoder] + classifiers:
        module.eval()

    if args.metric == "biaffine":
        Us, Ws, Vs = mats
    else:
        Us, Ps, Ns = mats

    source_loaders, valid_loader = loaders
    # Per-source (mean, ...) statistics used by the fast metric variants.
    domain_encs = domain_encoding(source_loaders, args, encoder)

    oracle_correct = 0
    correct = 0
    tot_cnt = 0
    y_true = []
    y_pred = []

    for batch, label in valid_loader:
        if args.cuda:
            batch = batch.cuda()
            label = label.cuda()

        batch = Variable(batch)
        hidden = encoder(batch)
        source_ids = range(len(domain_encs))
        # Expert weights from the metric between the batch and each domain.
        if args.metric == "biaffine":
            alphas = [ biaffine_metric_fast(hidden, mu[0], Us[0]) \
                       for mu in domain_encs ]
        else:
            alphas = [ mahalanobis_metric_fast(hidden, mu[0], U, mu[1], P, mu[2], N) \
                       for (mu, U, P, N) in zip(domain_encs, Us, Ps, Ns) ]
        alphas = softmax(alphas)
        if args.cuda:
            alphas = [alpha.cuda() for alpha in alphas]
        alphas = [Variable(alpha) for alpha in alphas]

        # Weighted mixture of the per-source softmax outputs.
        outputs = [
            F.softmax(classifier(hidden), dim=1) for classifier in classifiers
        ]
        output = sum([ alpha.unsqueeze(1).repeat(1, 2) * output_i \
                        for (alpha, output_i) in zip(alphas, outputs) ])
        pred = output.data.max(dim=1)[1]
        oracle_eq = compute_oracle(outputs, label, args)

        if args.eval_only:
            # Dump per-sample expert weights and outputs for inspection.
            for i in range(batch.shape[0]):
                for j in range(len(alphas)):
                    say("{:.4f}: [{:.4f}, {:.4f}], ".format(
                        alphas[j].data[i], outputs[j].data[i][0],
                        outputs[j].data[i][1]))
                oracle_TF = "T" if oracle_eq[i] == 1 else colored("F", 'red')
                say("gold: {}, pred: {}, oracle: {}\n".format(
                    label[i], pred[i], oracle_TF))
            say("\n")

        y_true += label.tolist()
        y_pred += pred.tolist()
        # Accumulate plain Python ints rather than 0-dim tensors.
        correct += pred.eq(label).sum().item()
        oracle_correct += oracle_eq.sum().item()
        tot_cnt += output.size(0)

    acc = float(correct) / tot_cnt
    oracle_acc = float(oracle_correct) / tot_cnt
    return (acc, oracle_acc), confusion_matrix(y_true, y_pred)
示例#10
0
def train_epoch(iter_cnt, encoder, classifiers, critic, mats, data_loaders,
                args, optim_model):
    """Run one training epoch over the zipped source-domain loaders.

    Optimizes the shared encoder, the per-source classifiers and (optionally)
    the domain critic with the combined objective:

        loss = MTL + lambda_moe * MoE + lambda_entropy * entropy
               (+ lambda_critic * DAN when args.lambda_critic > 0)

    NOTE: a meta-supervision term (``loss_kl``, MSE between the softmaxed
    alphas and the one-hot true-source indicator) is computed but NOT added
    to the total loss.

    Args:
        iter_cnt: running iteration counter; incremented per batch.
        encoder: shared feature encoder module.
        classifiers: list of per-source classifier modules.
        critic: domain critic module (used only when args.lambda_critic > 0).
        mats: metric matrices; (Us, Ws, Vs) for "biaffine", otherwise
            (Us, Ps, Ns) for the mahalanobis metric.
        data_loaders: (train_loaders, unl_loader, valid_loader) tuple.
        args: parsed command-line options.
        optim_model: optimizer over all trainable parameters.

    Returns:
        The updated iteration counter.
    """
    # BUG FIX: ``map(lambda m: m.train(), ...)`` is lazy in Python 3 and was
    # never consumed, so the modules were never put into training mode.
    for module in [encoder, critic] + classifiers:
        module.train()

    train_loaders, unl_loader, valid_loader = data_loaders
    # evaluate() iterates the source loaders itself, so give it its own copy
    dup_train_loaders = deepcopy(train_loaders)

    mtl_criterion = nn.CrossEntropyLoss()
    moe_criterion = nn.NLLLoss()  # expects log-probabilities (log_softmax separated)
    kl_criterion = nn.MSELoss()
    entropy_criterion = HLoss()

    if args.metric == "biaffine":
        metric = biaffine_metric
        Us, Ws, Vs = mats
    else:
        metric = mahalanobis_metric
        Us, Ps, Ns = mats

    for batches, unl_batch in zip(zip(*train_loaders), unl_loader):
        train_batches, train_labels = zip(*batches)
        unl_critic_batch, unl_critic_label = unl_batch

        iter_cnt += 1
        if args.cuda:
            train_batches = [batch.cuda() for batch in train_batches]
            train_labels = [label.cuda() for label in train_labels]

            unl_critic_batch = unl_critic_batch.cuda()
            unl_critic_label = unl_critic_label.cuda()

        train_batches = [Variable(batch) for batch in train_batches]
        train_labels = [Variable(label) for label in train_labels]
        unl_critic_batch = Variable(unl_critic_batch)
        unl_critic_label = Variable(unl_critic_label)

        optim_model.zero_grad()
        loss_mtl = []
        loss_moe = []
        loss_kl = []
        loss_entropy = []
        loss_dan = []

        ms_outputs = []  # (n_sources, n_classifiers)
        hiddens = []
        hidden_corresponding_labels = []
        for i, (batch, label) in enumerate(zip(train_batches, train_labels)):
            hidden = encoder(batch)
            outputs = []
            # create output matrix:
            #     - (i, j) indicates the output of i'th source batch using j'th classifier
            hiddens.append(hidden)
            for classifier in classifiers:
                output = classifier(hidden)
                outputs.append(output)
            ms_outputs.append(outputs)
            hidden_corresponding_labels.append(label)
            # multi-task loss: each source is supervised by its own classifier
            loss_mtl.append(mtl_criterion(ms_outputs[i][i], label))

            if args.lambda_critic > 0:
                # source examples labelled (1 - unl_label), unlabelled target
                # examples keep their label
                critic_label = torch.cat(
                    [1 - unl_critic_label, unl_critic_label])

                if isinstance(critic, ClassificationD):
                    critic_output = critic(
                        torch.cat(hidden, encoder(unl_critic_batch)))
                    loss_dan.append(
                        critic.compute_loss(critic_output, critic_label))
                else:
                    critic_output = critic(hidden, encoder(unl_critic_batch))
                    loss_dan.append(critic_output)
            else:
                # placeholder so loss_dan.item() is printable below
                loss_dan = Variable(torch.FloatTensor([0]))

        source_ids = range(len(train_batches))
        for i in source_ids:

            support_ids = [x for x in source_ids if x != i]  # experts

            if args.metric == "biaffine":
                source_alphas = [
                    metric(
                        hiddens[i],
                        hiddens[j].detach(),
                        Us[0],
                        Ws[0],
                        Vs[0],  # for biaffine metric, we use a unified matrix
                        args) for j in source_ids
                ]
            else:
                source_alphas = [
                    metric(
                        hiddens[i],  # i^th source
                        hiddens[j].detach(),
                        hidden_corresponding_labels[j],
                        Us[j],
                        Ps[j],
                        Ns[j],
                        args) for j in source_ids
                ]

            # leave-one-out expert weights (the i'th source is excluded)
            support_alphas = [source_alphas[x] for x in support_ids]
            support_alphas = softmax(support_alphas)

            # meta-supervision: MSE loss between softmaxed alphas and the
            # one-hot indicator of the true source (computed, not used below)
            source_alphas = softmax(source_alphas)
            source_labels = [torch.FloatTensor([x == i])
                             for x in source_ids]  # one-hot
            if args.cuda:
                source_alphas = [alpha.cuda() for alpha in source_alphas]
                source_labels = [label.cuda() for label in source_labels]

            source_labels = Variable(torch.stack(source_labels, dim=0))  # n_sources * 1
            source_alphas = torch.stack(source_alphas, dim=0)
            print("source_alpha after stack", source_alphas.size())

            source_labels = source_labels.expand_as(source_alphas).permute(
                1, 0)
            source_alphas = source_alphas.permute(1, 0)
            loss_kl.append(kl_criterion(source_alphas, source_labels))

            # entropy regularization over the alphas
            loss_entropy.append(entropy_criterion(source_alphas))

            # MoE prediction: alpha-weighted sum of the support experts'
            # softmaxed outputs on the i'th source batch
            output_moe_i = sum([ alpha.unsqueeze(1).repeat(1, 2) * F.softmax(ms_outputs[i][id], dim=1) \
                                    for alpha, id in zip(support_alphas, support_ids) ])

            loss_moe.append(
                moe_criterion(torch.log(output_moe_i), train_labels[i]))

        loss_mtl = sum(loss_mtl)
        loss_moe = sum(loss_moe)
        lambda_moe = args.lambda_moe
        lambda_entropy = args.lambda_entropy
        loss = loss_mtl + lambda_moe * loss_moe
        loss_kl = sum(loss_kl)  # computed for inspection; not added to loss
        loss_entropy = sum(loss_entropy)
        loss += lambda_entropy * loss_entropy

        if args.lambda_critic > 0:
            loss_dan = sum(loss_dan)
            loss += args.lambda_critic * loss_dan

        loss.backward()
        optim_model.step()

        # periodically report losses and validation accuracy
        if iter_cnt % 30 == 0:
            if args.metric == "biaffine":
                mats = [Us, Ws, Vs]
            else:
                mats = [Us, Ps, Ns]

            (curr_dev, oracle_curr_dev), confusion_mat = evaluate(
                encoder, classifiers, mats, [dup_train_loaders, valid_loader],
                args)

            say("{} MTL loss: {:.4f}, MOE loss: {:.4f}, DAN loss: {:.4f}, "
                "loss: {:.4f}, dev acc/oracle: {:.4f}/{:.4f}\n".format(
                    iter_cnt, loss_mtl.item(), loss_moe.item(),
                    loss_dan.item(), loss.item(), curr_dev, oracle_curr_dev))

    say("\n")
    return iter_cnt
示例#11
0
def evaluate(encoders,
             classifiers,
             mats,
             loaders,
             return_best_thrs,
             args,
             thr=None):
    """Evaluate the mixture-of-experts model on the validation loader.

    Encodes each (batch1, batch2) pair with ``encoders[0]``, weights the
    per-source classifier outputs by metric-based alphas, and reports
    AUC / precision / recall / F1 on the binary predictions.

    Args:
        encoders: sequence of encoder modules; only encoders[0] is applied
            to the input pairs.
        classifiers: per-source classifier modules.
        mats: metric matrices; (Us, Ws, Vs) for "biaffine", otherwise
            (Us, Ps, Ns).
        loaders: (source_loaders, valid_loader) tuple; the source loaders
            are used to compute per-domain encoding statistics.
        return_best_thrs: when True, also search the F1-optimal decision
            threshold on the precision-recall curve.
        args: parsed command-line options.
        thr: optional fixed threshold on the positive-class score; when
            given, predictions are re-derived as (score > thr).

    Returns:
        (best_thr, [auc, prec, rec, f1]); best_thr is None unless
        return_best_thrs is True.
    """
    # BUG FIX: ``map(lambda m: m.eval(), [encoders] + classifiers)`` was lazy
    # in Python 3 (never consumed), so eval mode was never set; consuming it
    # would also have crashed, since ``encoders`` is a sequence, not a module.
    if isinstance(encoders, (list, tuple)):
        modules = list(encoders) + classifiers
    else:
        modules = [encoders] + classifiers
    for module in modules:
        module.eval()

    if args.metric == "biaffine":
        Us, Ws, Vs = mats
    else:
        Us, Ps, Ns = mats

    source_loaders, valid_loader = loaders
    # per-domain statistics [(mu, ...), ...] used by the fast metrics below
    domain_encs = domain_encoding(source_loaders, args, encoders)

    oracle_correct = 0
    correct = 0
    tot_cnt = 0
    y_true = []
    y_pred = []
    y_score = []

    for batch1, batch2, label in valid_loader:
        if args.cuda:
            batch1 = batch1.cuda()
            batch2 = batch2.cuda()
            label = label.cuda()

        batch1 = Variable(batch1)
        batch2 = Variable(batch2)
        _, hidden = encoders[0](batch1, batch2)
        if args.metric == "biaffine":
            alphas = [biaffine_metric_fast(hidden, mu[0], Us[0]) \
                      for mu in domain_encs]
        else:
            alphas = [mahalanobis_metric_fast(hidden, mu[0], U, mu[1], P, mu[2], N) \
                      for (mu, U, P, N) in zip(domain_encs, Us, Ps, Ns)]
        alphas = softmax(alphas)
        if args.cuda:
            alphas = [alpha.cuda() for alpha in alphas]
        alphas = [Variable(alpha) for alpha in alphas]

        outputs = [
            F.softmax(classifier(hidden), dim=1) for classifier in classifiers
        ]
        # MoE prediction: alpha-weighted sum of the per-source distributions
        output = sum([alpha.unsqueeze(1).repeat(1, 2) * output_i \
                      for (alpha, output_i) in zip(alphas, outputs)])
        pred = output.data.max(dim=1)[1]
        oracle_eq = compute_oracle(outputs, label, args)

        if args.eval_only:
            # per-example dump: each expert's alpha and class distribution
            for i in range(batch1.shape[0]):
                for j in range(len(alphas)):
                    say("{:.4f}: [{:.4f}, {:.4f}], ".format(
                        alphas[j].data[i], outputs[j].data[i][0],
                        outputs[j].data[i][1]))
                oracle_TF = "T" if oracle_eq[i] == 1 else colored("F", 'red')
                say("gold: {}, pred: {}, oracle: {}\n".format(
                    label[i], pred[i], oracle_TF))
            say("\n")

        y_true += label.tolist()
        y_pred += pred.tolist()
        y_score += output[:, 1].data.tolist()

        correct += pred.eq(label).sum()
        oracle_correct += oracle_eq.sum()
        tot_cnt += output.size(0)

    if thr is not None:
        # override argmax predictions with a fixed score threshold
        print("using threshold %.4f" % thr)
        y_score = np.array(y_score)
        y_pred = np.zeros_like(y_score)
        y_pred[y_score > thr] = 1

    prec, rec, f1, _ = precision_recall_fscore_support(y_true,
                                                       y_pred,
                                                       average="binary")
    auc = roc_auc_score(y_true, y_score)
    print("AUC: {:.2f}, Prec: {:.2f}, Rec: {:.2f}, F1: {:.2f}".format(
        auc * 100, prec * 100, rec * 100, f1 * 100))

    best_thr = None
    metric = [auc, prec, rec, f1]

    if return_best_thrs:
        precs, recs, thrs = precision_recall_curve(y_true, y_score)
        # points with prec+rec == 0 yield nan; suppress the warning and
        # filter the nans out below (behavior unchanged)
        with np.errstate(divide="ignore", invalid="ignore"):
            f1s = 2 * precs * recs / (precs + recs)
        f1s = f1s[:-1]  # align with thrs, which has one fewer entry
        thrs = thrs[~np.isnan(f1s)]
        f1s = f1s[~np.isnan(f1s)]
        best_thr = thrs[np.argmax(f1s)]
        print("best threshold={:.4f}, f1={:.4f}".format(best_thr, np.max(f1s)))

    # accuracy is computed for the running counters but not returned here
    acc = float(correct) / tot_cnt
    oracle_acc = float(oracle_correct) / tot_cnt
    return best_thr, metric