Example #1
def predict(args):
    encoder, classifier = torch.load(args.load_model)
    for m in [encoder, classifier]:
        m.eval()

    if args.cuda:
        for m in [encoder, classifier]:
            m.cuda()

    test_filepath = os.path.join(DATA_DIR, "%s_train.svmlight" % (args.test))
    assert (os.path.exists(test_filepath))
    test_dataset = AmazonDataset(test_filepath)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=0)

    acc, confusion_mat, _ = evaluate(encoder, classifier, test_loader, args)
    say(colored("Test accuracy {:.4f}\n".format(acc), 'red'))
    print(confusion_mat)
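The snippet assumes a module-level argparser plus repository helpers (say, evaluate, AmazonDataset, DATA_DIR) that are not shown. A minimal, hypothetical entry point covering only the flags this function actually reads might look like the following (flag names and defaults are assumptions):

# Hypothetical invocation sketch; the real argparser lives elsewhere in the repository.
import argparse

argparser = argparse.ArgumentParser()
argparser.add_argument("--load_model", type=str, required=True)
argparser.add_argument("--test", type=str, default="books")
argparser.add_argument("--batch_size", type=int, default=32)
argparser.add_argument("--cuda", action="store_true")

if __name__ == "__main__":
    predict(argparser.parse_args())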
Example #2
def predict(args):
    encoder, classifiers, Us, Ps, Ns = torch.load(args.load_model)
    for m in [encoder] + classifiers:
        m.eval()

    # args = argparser.parse_args()
    # say(args)
    if args.cuda:
        for m in [encoder] + classifiers:
            m.cuda()
        Us = [ U.cuda() for U in Us ]
        Ps = [ P.cuda() for P in Ps ]
        Ns = [ N.cuda() for N in Ns ]

    say("\nTransferring from %s to %s\n" % (args.train, args.test))
    source_train_sets = args.train.split(',')
    train_loaders = []
    for source in source_train_sets:
        filepath = os.path.join(DATA_DIR, "%s_train.svmlight" % (source))
        train_dataset = AmazonDataset(filepath)
        train_loader = data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=0
        )
        train_loaders.append(train_loader)

    test_filepath = os.path.join(DATA_DIR, "%s_test.svmlight" % (args.test))
    test_dataset = AmazonDataset(test_filepath)
    test_loader = data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0
    )
    say("Corpus loaded.\n")

    mats = [Us, Ps, Ns]
    (acc, oracle_acc), confusion_mat = evaluate(
            encoder, classifiers,
            mats,
            [train_loaders, test_loader],
            args
        )
    say(colored("Test accuracy/oracle {:.4f}/{:.4f}\n".format(acc, oracle_acc), 'red'))
Example #3
def train_epoch(iter_cnt, encoder, classifier, critic, train_loaders,
                target_d_loader, valid_loader, args, optimizer):
    encoder.train()
    classifier.train()
    critic.train()

    task_criterion = nn.CrossEntropyLoss()
    ae_criterion = nn.MSELoss()

    for source_batches, target_batch in zip(zip(*train_loaders),
                                            target_d_loader):

        all_task_batch, all_task_labels = zip(*source_batches)
        target_d_batch, target_d_labels = target_batch

        iter_cnt += 1

        if args.cuda:
            all_task_batch = [
                task_batch.cuda() for task_batch in all_task_batch
            ]
            all_task_labels = [
                task_labels.cuda() for task_labels in all_task_labels
            ]

            target_d_batch = target_d_batch.cuda()
            target_d_labels = target_d_labels.cuda()

        all_task_batch = [
            Variable(task_batch) for task_batch in all_task_batch
        ]
        all_task_labels = [
            Variable(task_labels) for task_labels in all_task_labels
        ]

        target_d_batch = Variable(target_d_batch)
        target_d_labels = Variable(target_d_labels)

        optimizer.zero_grad()

        loss_c = []
        loss_d = []
        for task_batch, task_labels in zip(all_task_batch, all_task_labels):
            ''' compute task loss '''
            hidden = encoder(task_batch)
            task_output = classifier(hidden)

            ### task accuracy on training batch
            task_pred = torch.squeeze(task_output.max(dim=1)[1])
            task_acc = (task_pred == task_labels).float().mean()

            ### task loss
            loss_c.append(task_criterion(task_output, task_labels))
            ''' domain-critic loss '''
            critic_batch = torch.cat([task_batch, target_d_batch])
            critic_labels = torch.cat([1 - target_d_labels, target_d_labels])

            # hidden_d = encoder(critic_batch)
            # move the grl layer to DANN
            # hidden_d_grl = flip_gradient(hidden_d)
            if isinstance(critic, ClassificationD):
                critic_output = critic(encoder(critic_batch))
                critic_pred = torch.squeeze(critic_output.max(dim=1)[1])
                critic_acc = (critic_pred == critic_labels).float().mean()
                loss_d.append(critic.compute_loss(critic_output,
                                                  critic_labels))
            else:  # mmd, coral, wd
                if args.cond is not None:
                    # outer(encoding, g) where g is the class distribution
                    target_d_g = F.softmax(classifier(encoder(target_d_batch)),
                                           dim=1).detach()
                    # task_g = F.softmax(task_output, dim=1).detach()
                    task_g = one_hot(task_labels, cuda=args.cuda)
                    # print torch.cat([task_g, task_labels.unsqueeze(1).float()], dim=1)
                    # print target_d_g

                    if args.cond == "concat":
                        task_encoding = torch.cat([hidden, task_g], dim=1)
                        target_d_encoding = torch.cat(
                            [encoder(target_d_batch), target_d_g], dim=1)
                    else:  # "outer"
                        task_encoding = torch.bmm(hidden.unsqueeze(2),
                                                  task_g.unsqueeze(1))
                        target_d_encoding = torch.bmm(
                            encoder(target_d_batch).unsqueeze(2),
                            target_d_g.unsqueeze(1))
                        task_encoding = task_encoding.view(
                            task_encoding.shape[0], -1).contiguous()
                        target_d_encoding = target_d_encoding.view(
                            target_d_encoding.shape[0], -1).contiguous()
                        # print task_encoding.shape, target_d_encoding.shape
                    critic_output = critic(task_encoding, target_d_encoding)
                else:
                    critic_output = critic(hidden, encoder(target_d_batch))
                loss_d.append(critic_output)

        loss_c = sum(loss_c)
        loss_d = sum(loss_d)
        loss = loss_c + args.lambda_critic * loss_d
        loss.backward()
        optimizer.step()

        if valid_loader and iter_cnt % 30 == 0:
            curr_test, confusion_mat, _ = evaluate(encoder, classifier,
                                                   valid_loader, args)

            # say("\r" + " " * 50)
            say("{} task loss/acc: {:.4f}/{:.4f}, "
                "domain critic loss: {:.4f}, "
                # "adversarial loss/acc: {:.4f}/{:.4f}, "
                "loss: {:.4f}, "
                "test acc: {:.4f}\n".format(
                    iter_cnt,
                    loss_c.item(),
                    task_acc.item(),
                    loss_d.item(),
                    # loss_target_d.item(), target_d_acc.item(),
                    loss.item(),
                    curr_test))

    say("\n")
    return iter_cnt
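The helper one_hot(task_labels, cuda=args.cuda) used in the conditional branch above is not part of the excerpt. A minimal sketch consistent with its call site (binary integer labels, optional CUDA placement) could be:

import torch

def one_hot(labels, n_classes=2, cuda=False):
    # Scatter integer class labels (shape: batch) into a (batch, n_classes)
    # one-hot float matrix; a sketch, not the repository's definition.
    out = torch.zeros(labels.size(0), n_classes)
    if cuda:
        out = out.cuda()
    out.scatter_(1, labels.view(-1, 1).long(), 1.0)
    return out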
Example #4
def train(args):
    encoder_class = get_model_class(args.encoder)
    encoder_class.add_config(argparser)
    critic_class = get_critic_class(args.critic)
    critic_class.add_config(argparser)

    args = argparser.parse_args()
    say(args)

    say("Transferring from %s to %s\n" % (args.train, args.test))

    source_train_sets = args.train.split(',')
    train_loaders = []
    for source in source_train_sets:
        filepath = os.path.join(DATA_DIR, "%s_train.svmlight" % (source))
        assert (os.path.exists(filepath))
        train_dataset = AmazonDataset(filepath)
        train_loader = data.DataLoader(train_dataset,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=0)
        train_loaders.append(train_loader)

    target_d_filepath = os.path.join(DATA_DIR,
                                     "%s_train.svmlight" % (args.test))
    assert (os.path.exists(target_d_filepath))
    train_target_d_dataset = AmazonDomainDataset(target_d_filepath, domain=1)
    train_target_d_loader = data.DataLoader(train_target_d_dataset,
                                            batch_size=args.batch_size_d,
                                            shuffle=True,
                                            num_workers=0)

    valid_filepath = os.path.join(DATA_DIR, "%s_dev.svmlight" % (args.test))
    # assert (os.path.exists(valid_filepath))
    if os.path.exists(valid_filepath):
        valid_dataset = AmazonDataset(valid_filepath)
        valid_loader = data.DataLoader(valid_dataset,
                                       batch_size=args.batch_size,
                                       shuffle=False,
                                       num_workers=0)
    else:
        valid_loader = None

    test_filepath = os.path.join(DATA_DIR, "%s_test.svmlight" % (args.test))
    assert (os.path.exists(test_filepath))
    test_dataset = AmazonDataset(test_filepath)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=0)
    say("Corpus loaded.\n")

    encoder = encoder_class(args)
    critic = critic_class(encoder, args)
    classifier = nn.Linear(encoder.n_out, 2)  # binary classification
    nn.init.xavier_normal_(classifier.weight)
    nn.init.constant_(classifier.bias, 0.1)

    gan_gen = encoder_class(args)
    gan_disc = MMD(gan_gen, args)

    if args.cuda:
        encoder = encoder.cuda()
        critic = critic.cuda()
        classifier = classifier.cuda()
        gan_gen = gan_gen.cuda()
        gan_disc = gan_disc.cuda()

    say("\n{}\n\n".format(encoder))
    say("\n{}\n\n".format(critic))
    say("\n{}\n\n".format(classifier))
    say("\n{}\n\n".format(gan_gen))
    say("\n{}\n\n".format(gan_disc))

    print(encoder.state_dict().keys())
    print(critic.state_dict().keys())
    print(classifier.state_dict().keys())
    print(gan_gen.state_dict().keys())
    print(gan_disc.state_dict().keys())

    requires_grad = lambda x: x.requires_grad
    task_params = list(encoder.parameters()) + \
                  list(classifier.parameters()) + \
                  list(critic.parameters())
    optimizer = optim.Adam(filter(requires_grad, task_params),
                           lr=args.lr,
                           weight_decay=1e-4)

    reg_params = list(encoder.parameters()) + \
                 list(gan_gen.parameters())
    optimizer_reg = optim.Adam(filter(requires_grad, reg_params),
                               lr=args.lr,
                               weight_decay=1e-4)

    say("Training will begin from scratch\n")

    best_dev = 0
    best_test = 0
    iter_cnt = 0

    for epoch in range(args.max_epoch):
        iter_cnt = train_epoch(iter_cnt, encoder, classifier, critic,
                               train_loaders, train_target_d_loader,
                               valid_loader, args, optimizer)

        if args.advreg:
            for loader in train_loaders + [train_target_d_loader]:
                train_advreg_mmd(iter_cnt, encoder, gan_gen, gan_disc, loader,
                                 args, optimizer_reg)

        if valid_loader:
            curr_dev, confusion_mat, _ = evaluate(encoder, classifier,
                                                  valid_loader, args)
            say("Dev accuracy: {:.4f}\n".format(curr_dev))

        curr_test, confusion_mat, _ = evaluate(encoder, classifier,
                                               test_loader, args)
        say("Test accuracy: {:.4f}\n".format(curr_test))

        if valid_loader and curr_dev >= best_dev:
            best_dev = curr_dev
            best_test = curr_test
            # print(confusion_mat)
            if args.save_model:
                say(
                    colored(
                        "Save model to {}\n".format(args.save_model + ".best"),
                        'red'))
                torch.save([encoder, classifier], args.save_model + ".best")
            say("\n")

    if valid_loader:
        say(colored("Best test accuracy {:.4f}\n".format(best_test), 'red'))
    say(
        colored("Test accuracy after training {:.4f}\n".format(curr_test),
                'red'))
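The MMD class wrapped around gan_gen above is defined elsewhere. Purely as a reference for what such a discrepancy critic computes, here is a single-kernel (RBF) squared-MMD estimate between two batches of encodings; this is a simplified sketch, not the repository's implementation, which may use a multi-kernel variant:

import torch

def gaussian_mmd(x, y, sigma=1.0):
    # Squared MMD with a single RBF kernel between batches x and y of shape
    # (batch, features); a simplified sketch for illustration only.
    def sq_dists(a, b):
        return (a.pow(2).sum(1, keepdim=True)
                + b.pow(2).sum(1).unsqueeze(0)
                - 2.0 * a.mm(b.t()))
    kernel = lambda a, b: torch.exp(-sq_dists(a, b) / (2.0 * sigma ** 2))
    return kernel(x, x).mean() + kernel(y, y).mean() - 2.0 * kernel(x, y).mean()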
Example #5
def train(args):
    ''' Training Strategy

    Input: source = {S1, S2, ..., Sk}, target = {T}

    Train:
        Approach 1: fix metric and learn encoder only
        Approach 2: learn metric and encoder alternatively
    '''

    # test_mahalanobis_metric() and return

    encoder_class = get_model_class("mlp")
    encoder_class.add_config(argparser)
    critic_class = get_critic_class(args.critic)
    critic_class.add_config(argparser)

    args = argparser.parse_args()
    say(args)

    # encoder is shared across domains
    encoder = encoder_class(args)

    say("Transferring from %s to %s\n" % (args.train, args.test))
    source_train_sets = args.train.split(',')
    train_loaders = []
    Us = []
    Ps = []
    Ns = []
    Ws = []
    Vs = []
    # Ms = []
    for source in source_train_sets:
        filepath = os.path.join(DATA_DIR, "%s_train.svmlight" % (source))
        assert (os.path.exists(filepath))
        train_dataset = AmazonDataset(filepath)
        train_loader = data.DataLoader(train_dataset,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=0)
        train_loaders.append(train_loader)

        if args.metric == "biaffine":
            U = torch.FloatTensor(encoder.n_d, encoder.n_d)
            W = torch.FloatTensor(encoder.n_d, 1)
            nn.init.xavier_uniform_(W)
            Ws.append(W)
            V = torch.FloatTensor(encoder.n_d, 1)
            nn.init.xavier_uniform_(V)
            Vs.append(V)
        else:
            U = torch.FloatTensor(encoder.n_d, args.m_rank)

        nn.init.xavier_uniform_(U)
        Us.append(U)
        P = torch.FloatTensor(encoder.n_d, args.m_rank)
        nn.init.xavier_uniform_(P)
        Ps.append(P)
        N = torch.FloatTensor(encoder.n_d, args.m_rank)
        nn.init.xavier_uniform_(N)
        Ns.append(N)
        # Ms.append(U.mm(U.t()))

    unl_filepath = os.path.join(DATA_DIR, "%s_train.svmlight" % (args.test))
    assert (os.path.exists(unl_filepath))
    unl_dataset = AmazonDomainDataset(unl_filepath)
    unl_loader = data.DataLoader(unl_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=0)

    valid_filepath = os.path.join(DATA_DIR, "%s_dev.svmlight" % (args.test))
    if os.path.exists(valid_filepath):
        valid_dataset = AmazonDataset(valid_filepath)
        valid_loader = data.DataLoader(valid_dataset,
                                       batch_size=args.batch_size,
                                       shuffle=False,
                                       num_workers=0)
    else:
        valid_loader = None

    test_filepath = os.path.join(DATA_DIR, "%s_test.svmlight" % (args.test))
    assert (os.path.exists(test_filepath))
    test_dataset = AmazonDataset(test_filepath)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=0)
    say("Corpus loaded.\n")

    classifiers = []
    for source in source_train_sets:
        classifier = nn.Linear(encoder.n_out, 2)  # binary classification
        nn.init.xavier_normal_(classifier.weight)
        nn.init.constant_(classifier.bias, 0.1)
        classifiers.append(classifier)

    critic = critic_class(encoder, args)

    # if args.save_model:
    #     say(colored("Save model to {}\n".format(args.save_model + ".init"), 'red'))
    #     torch.save([encoder, classifiers, Us, Ps, Ns], args.save_model + ".init")

    if args.cuda:
        # map(lambda m: m.cuda(), [encoder, critic] + classifiers)
        encoder = encoder.cuda()
        critic = critic.cuda()
        classifiers = [x.cuda() for x in classifiers]
        Us = [Variable(U.cuda(), requires_grad=True) for U in Us]
        Ps = [Variable(P.cuda(), requires_grad=True) for P in Ps]
        Ns = [Variable(N.cuda(), requires_grad=True) for N in Ns]
        if args.metric == "biaffine":
            Ws = [Variable(W.cuda(), requires_grad=True) for W in Ws]
            Vs = [Variable(V.cuda(), requires_grad=True) for V in Vs]

    # Ms = [ U.mm(U.t()) for U in Us ]

    say("\nEncoder: {}\n".format(encoder))
    for i, classifier in enumerate(classifiers):
        say("Classifier-{}: {}\n".format(i, classifier))
    say("Critic: {}\n".format(critic))

    requires_grad = lambda x: x.requires_grad
    task_params = list(encoder.parameters())
    for classifier in classifiers:
        task_params += list(classifier.parameters())
    task_params += list(critic.parameters())
    task_params += Us
    task_params += Ps
    task_params += Ns
    if args.metric == "biaffine":
        task_params += Ws
        task_params += Vs

    optim_model = optim.Adam(filter(requires_grad, task_params),
                             lr=args.lr,
                             weight_decay=1e-4)

    say("Training will begin from scratch\n")

    best_dev = 0
    best_test = 0
    iter_cnt = 0

    for epoch in range(args.max_epoch):
        if args.metric == "biaffine":
            mats = [Us, Ws, Vs]
        else:
            mats = [Us, Ps, Ns]

        iter_cnt = train_epoch(iter_cnt, encoder, classifiers, critic, mats,
                               [train_loaders, unl_loader, valid_loader], args,
                               optim_model)

        if valid_loader:
            (curr_dev, oracle_curr_dev), confusion_mat = evaluate(
                encoder, classifiers, mats, [train_loaders, valid_loader],
                args)
            say("Dev accuracy/oracle: {:.4f}/{:.4f}\n".format(
                curr_dev, oracle_curr_dev))
        (curr_test, oracle_curr_test), confusion_mat = evaluate(
            encoder, classifiers, mats, [train_loaders, test_loader], args)
        say("Test accuracy/oracle: {:.4f}/{:.4f}\n".format(
            curr_test, oracle_curr_test))

        if valid_loader and curr_dev >= best_dev:
            best_dev = curr_dev
            best_test = curr_test
            print(confusion_mat)
            if args.save_model:
                say(
                    colored(
                        "Save model to {}\n".format(args.save_model + ".best"),
                        'red'))
                torch.save([encoder, classifiers, Us, Ps, Ns],
                           args.save_model + ".best")
            say("\n")

    if valid_loader:
        say(colored("Best test accuracy {:.4f}\n".format(best_test), 'red'))
    say(
        colored("Test accuracy after training {:.4f}\n".format(curr_test),
                'red'))
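The metric functions themselves (mahalanobis_metric, biaffine_metric) are not shown. As an illustration of why U has shape (n_d, m_rank): with M = U U^T, the Mahalanobis form (h - mu)^T M (h - mu) reduces to the squared norm of the U-projected difference, which the following hypothetical helper computes:

import torch

def lowrank_mahalanobis_sq(hidden, mu, U):
    # Hypothetical illustration: for M = U U^T with U of shape (n_d, m_rank),
    # (h - mu)^T M (h - mu) equals the squared norm of (h - mu) projected by U.
    diff = (hidden - mu).mm(U)          # (batch, m_rank)
    return (diff * diff).sum(dim=1)     # (batch,)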
Example #6
def evaluate(encoder, classifiers, mats, loaders, args):
    ''' Evaluate model using MOE
    '''
    for m in [encoder] + classifiers:
        m.eval()

    if args.metric == "biaffine":
        Us, Ws, Vs = mats
    else:
        Us, Ps, Ns = mats

    source_loaders, valid_loader = loaders
    domain_encs = domain_encoding(source_loaders, args, encoder)

    oracle_correct = 0
    correct = 0
    tot_cnt = 0
    y_true = []
    y_pred = []

    for batch, label in valid_loader:
        if args.cuda:
            batch = batch.cuda()
            label = label.cuda()

        batch = Variable(batch)
        hidden = encoder(batch)
        source_ids = range(len(domain_encs))
        if args.metric == "biaffine":
            alphas = [ biaffine_metric_fast(hidden, mu[0], Us[0]) \
                       for mu in domain_encs ]
        else:
            alphas = [ mahalanobis_metric_fast(hidden, mu[0], U, mu[1], P, mu[2], N) \
                       for (mu, U, P, N) in zip(domain_encs, Us, Ps, Ns) ]
        # alphas = [ (1 - x / sum(alphas)) for x in alphas ]
        alphas = softmax(alphas)
        if args.cuda:
            alphas = [alpha.cuda() for alpha in alphas]
        alphas = [Variable(alpha) for alpha in alphas]

        outputs = [
            F.softmax(classifier(hidden), dim=1) for classifier in classifiers
        ]
        output = sum([ alpha.unsqueeze(1).repeat(1, 2) * output_i \
                        for (alpha, output_i) in zip(alphas, outputs) ])
        pred = output.data.max(dim=1)[1]
        oracle_eq = compute_oracle(outputs, label, args)

        if args.eval_only:
            for i in range(batch.shape[0]):
                for j in range(len(alphas)):
                    say("{:.4f}: [{:.4f}, {:.4f}], ".format(
                        alphas[j].data[i], outputs[j].data[i][0],
                        outputs[j].data[i][1]))
                oracle_TF = "T" if oracle_eq[i] == 1 else colored("F", 'red')
                say("gold: {}, pred: {}, oracle: {}\n".format(
                    label[i], pred[i], oracle_TF))
            say("\n")
            # print torch.cat(
            #         [
            #             torch.cat([ x.unsqueeze(1) for x in alphas ], 1),
            #             torch.cat([ x for x in outputs ], 1)
            #         ], 1
            #     )

        y_true += label.tolist()
        y_pred += pred.tolist()
        correct += pred.eq(label).sum()
        oracle_correct += oracle_eq.sum()
        tot_cnt += output.size(0)

    acc = float(correct) / tot_cnt
    oracle_acc = float(oracle_correct) / tot_cnt
    return (acc, oracle_acc), confusion_matrix(y_true, y_pred)
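The softmax call above operates on a Python list of per-example score vectors (one per source domain) rather than on a single tensor; a sketch of such a list-wise helper, consistent with how its output is consumed, is:

import torch
import torch.nn.functional as F

def softmax(scores):
    # Normalise a list of per-example score vectors (each of shape (batch,))
    # across the sources; returns a list of the same length. A sketch only.
    stacked = torch.stack(scores, dim=0)    # (n_sources, batch)
    weights = F.softmax(stacked, dim=0)     # softmax over the source axis
    return [w for w in weights]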
Example #7
def train_epoch(iter_cnt, encoder, classifiers, critic, mats, data_loaders,
                args, optim_model):
    for m in [encoder, critic] + classifiers:
        m.train()

    train_loaders, unl_loader, valid_loader = data_loaders
    dup_train_loaders = deepcopy(train_loaders)

    mtl_criterion = nn.CrossEntropyLoss()
    moe_criterion = nn.NLLLoss()  # with log_softmax separated
    kl_criterion = nn.MSELoss()
    entropy_criterion = HLoss()

    if args.metric == "biaffine":
        metric = biaffine_metric
        Us, Ws, Vs = mats
    else:
        metric = mahalanobis_metric
        Us, Ps, Ns = mats

    for batches, unl_batch in zip(zip(*train_loaders), unl_loader):
        train_batches, train_labels = zip(*batches)
        unl_critic_batch, unl_critic_label = unl_batch

        iter_cnt += 1
        if args.cuda:
            train_batches = [batch.cuda() for batch in train_batches]
            train_labels = [label.cuda() for label in train_labels]

            unl_critic_batch = unl_critic_batch.cuda()
            unl_critic_label = unl_critic_label.cuda()

        train_batches = [Variable(batch) for batch in train_batches]
        train_labels = [Variable(label) for label in train_labels]
        unl_critic_batch = Variable(unl_critic_batch)
        unl_critic_label = Variable(unl_critic_label)

        optim_model.zero_grad()
        loss_mtl = []
        loss_moe = []
        loss_kl = []
        loss_entropy = []
        loss_dan = []

        ms_outputs = []  # (n_sources, n_classifiers)
        hiddens = []
        hidden_corresponding_labels = []
        # labels = []
        for i, (batch, label) in enumerate(zip(train_batches, train_labels)):
            hidden = encoder(batch)
            outputs = []
            # create output matrix:
            #     - (i, j) indicates the output of i'th source batch using j'th classifier
            hiddens.append(hidden)
            for classifier in classifiers:
                output = classifier(hidden)
                outputs.append(output)
            ms_outputs.append(outputs)
            hidden_corresponding_labels.append(label)
            # multi-task loss
            loss_mtl.append(mtl_criterion(ms_outputs[i][i], label))
            # labels.append(label)

            if args.lambda_critic > 0:
                # critic_batch = torch.cat([batch, unl_critic_batch])
                critic_label = torch.cat(
                    [1 - unl_critic_label, unl_critic_label])
                # critic_label = torch.cat([1 - unl_critic_label] * len(train_batches) + [unl_critic_label])

                if isinstance(critic, ClassificationD):
                    critic_output = critic(
                        torch.cat([hidden, encoder(unl_critic_batch)]))
                    loss_dan.append(
                        critic.compute_loss(critic_output, critic_label))
                else:
                    critic_output = critic(hidden, encoder(unl_critic_batch))
                    loss_dan.append(critic_output)

                    # critic_output = critic(torch.cat(hiddens), encoder(unl_critic_batch))
                    # loss_dan = critic_output
            else:
                loss_dan = Variable(torch.FloatTensor([0]))

        # assert (len(outputs) == len(outputs[0]))
        source_ids = range(len(train_batches))
        for i in source_ids:

            support_ids = [x for x in source_ids if x != i]  # experts

            # support_alphas = [ metric(
            #                      hiddens[i],
            #                      hiddens[j].detach(),
            #                      hidden_corresponding_labels[j],
            #                      Us[j], Ps[j], Ns[j],
            #                      args) for j in support_ids ]

            if args.metric == "biaffine":
                source_alphas = [
                    metric(
                        hiddens[i],
                        hiddens[j].detach(),
                        Us[0],
                        Ws[0],
                        Vs[0],  # for biaffine metric, we use a unified matrix
                        args) for j in source_ids
                ]
            else:
                source_alphas = [
                    metric(hiddens[i], hiddens[j].detach(),
                           hidden_corresponding_labels[j], Us[j], Ps[j], Ns[j],
                           args) for j in source_ids
                ]

            support_alphas = [source_alphas[x] for x in support_ids]

            # print torch.cat([ x.unsqueeze(1) for x in support_alphas ], 1)
            support_alphas = softmax(support_alphas)

            # meta-supervision: KL loss over \alpha and real source
            source_alphas = softmax(source_alphas)  # [ 32, 32, 32 ]
            source_labels = [torch.FloatTensor([x == i])
                             for x in source_ids]  # one-hot
            if args.cuda:
                source_alphas = [alpha.cuda() for alpha in source_alphas]
                source_labels = [label.cuda() for label in source_labels]

            source_labels = Variable(torch.stack(source_labels, dim=0))  # 3*1
            source_alphas = torch.stack(source_alphas, dim=0)
            source_labels = source_labels.expand_as(source_alphas).permute(
                1, 0)
            source_alphas = source_alphas.permute(1, 0)
            loss_kl.append(kl_criterion(source_alphas, source_labels))

            # entropy loss over \alpha
            # entropy_loss = entropy_criterion(torch.stack(support_alphas, dim=0).permute(1, 0))
            # print source_alphas
            loss_entropy.append(entropy_criterion(source_alphas))

            output_moe_i = sum([ alpha.unsqueeze(1).repeat(1, 2) * F.softmax(ms_outputs[i][id], dim=1) \
                                    for alpha, id in zip(support_alphas, support_ids) ])
            # output_moe_full = sum([ alpha.unsqueeze(1).repeat(1, 2) * F.softmax(ms_outputs[i][id], dim=1) \
            #                         for alpha, id in zip(full_alphas, source_ids) ])

            loss_moe.append(
                moe_criterion(torch.log(output_moe_i), train_labels[i]))
            # loss_moe.append(moe_criterion(torch.log(output_moe_full), train_labels[i]))

        loss_mtl = sum(loss_mtl)
        loss_moe = sum(loss_moe)
        # if iter_cnt < 400:
        #     lambda_moe = 0
        #     lambda_entropy = 0
        # else:
        lambda_moe = args.lambda_moe
        lambda_entropy = args.lambda_entropy
        # loss = (1 - lambda_moe) * loss_mtl + lambda_moe * loss_moe
        loss = loss_mtl + lambda_moe * loss_moe
        loss_kl = sum(loss_kl)
        loss_entropy = sum(loss_entropy)
        loss += args.lambda_entropy * loss_entropy

        if args.lambda_critic > 0:
            loss_dan = sum(loss_dan)
            loss += args.lambda_critic * loss_dan

        loss.backward()
        optim_model.step()

        if valid_loader and iter_cnt % 30 == 0:
            # [(mu_i, covi_i), ...]
            # domain_encs = domain_encoding(dup_train_loaders, args, encoder)
            if args.metric == "biaffine":
                mats = [Us, Ws, Vs]
            else:
                mats = [Us, Ps, Ns]

            (curr_dev, oracle_curr_dev), confusion_mat = evaluate(
                encoder, classifiers, mats, [dup_train_loaders, valid_loader],
                args)

            # say("\r" + " " * 50)
            # TODO: print train acc as well
            say("{} MTL loss: {:.4f}, MOE loss: {:.4f}, DAN loss: {:.4f}, "
                "loss: {:.4f}, dev acc/oracle: {:.4f}/{:.4f}\n".format(
                    iter_cnt, loss_mtl.item(), loss_moe.item(),
                    loss_dan.item(), loss.item(), curr_dev, oracle_curr_dev))

    say("\n")
    return iter_cnt
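HLoss (the entropy criterion) is instantiated but not defined in the excerpt. A minimal sketch that matches how it is called on already-softmaxed alpha rows is given below; whether the repository flips the sign to encourage peaked rather than uniform weights cannot be determined from this code:

import torch
import torch.nn as nn

class HLoss(nn.Module):
    # Entropy of row-wise probability distributions, summed over the batch;
    # a sketch matching the call entropy_criterion(source_alphas) above.
    def forward(self, probs, eps=1e-8):
        return -(probs * torch.log(probs + eps)).sum()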
Example #8
def visualize(args):
    if args.mop == 3:
        encoder, classifiers, source_classifier = torch.load(args.load_model)
    elif args.mop == 2:
        encoder, classifiers, Us, Ps, Ns = torch.load(args.load_model)
    else:
        say("\nUndefined --mop\n")
        return

    for m in [encoder] + classifiers:
        m.eval()
    if args.cuda:
        for m in [encoder] + classifiers:
            m.cuda()

    source_train_sets = args.train.split(',')
    train_loaders = []
    for source in source_train_sets:
        filepath = os.path.join(DATA_DIR, "%s_train.svmlight" % (source))
        train_dataset = AmazonDataset(filepath)
        train_loader = data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=0
        )
        train_loaders.append(train_loader)

    test_filepath = os.path.join(DATA_DIR, "%s_train.svmlight" % (args.test))
    test_dataset = AmazonDataset(test_filepath)
    test_loader = data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0
    )
    say("Corpus loaded.\n")

    source_hs = []
    source_ys = []
    source_num = []
    for loader in train_loaders:
        encoding_vecs = torch.FloatTensor()
        labels = torch.LongTensor()
        if args.cuda:
            encoding_vecs = encoding_vecs.cuda()
            labels = labels.cuda()

        for batch, label in loader:
            if args.cuda:
                batch = batch.cuda()
                label = label.cuda()

            batch = Variable(batch)
            hidden = encoder(batch)
            encoding_vecs = torch.cat([encoding_vecs, hidden.data])
            labels = torch.cat([labels, label.view(-1, 1)])

        source_hs.append(encoding_vecs)
        source_ys.append(labels)
        source_num.append(labels.shape[0])

    ht = torch.FloatTensor()
    yt = torch.LongTensor()
    if args.cuda:
        ht = ht.cuda()
        yt = yt.cuda()

    for batch, label in test_loader:
        if args.cuda:
            batch = batch.cuda()
            label = label.cuda()

        batch = Variable(batch)
        hidden = encoder(batch)
        ht = torch.cat([ht, hidden.data])
        yt = torch.cat([yt, label.view(-1, 1)])

    h_both = torch.cat(source_hs + [ht]).cpu().numpy()
    y_both = torch.cat(source_ys + [yt]).cpu().numpy()

    say("Dimension reduction...\n")
    tsne = TSNE(perplexity=30, n_components=2, n_iter=3300)
    vdata = tsne.fit_transform(h_both)
    print(vdata.shape, source_num)
    torch.save([vdata, y_both, source_num], 'vis/%s-%s-mop%d.vdata' % (args.train, args.test, args.mop))
    ms_plot_embedding_sep(vdata, y_both, source_num, args.save_image)
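ms_plot_embedding_sep is imported from elsewhere in the repository; a hypothetical matplotlib sketch that splits the 2-D t-SNE points back into the per-source segments recorded in source_num and scatters each group separately might look like:

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

def plot_embedding_by_source(vdata, source_num, save_path):
    # vdata: (N, 2) t-SNE coordinates; the first sum(source_num) rows are the
    # source domains in order, the remaining rows the target domain. A sketch.
    offset = 0
    for i, n in enumerate(source_num):
        pts = vdata[offset:offset + n]
        plt.scatter(pts[:, 0], pts[:, 1], s=3, label="source-%d" % i)
        offset += n
    target = vdata[offset:]
    plt.scatter(target[:, 0], target[:, 1], s=3, label="target")
    plt.legend()
    plt.savefig(save_path, bbox_inches="tight")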