示例#1
0
def train_epoch(iter_cnt, encoders, classifiers, attn_mats, train_loader_dst,
                args, optim_model, epoch):
    """Train the attention-weighted mixture of source experts for one epoch.

    For every target-domain batch, each source encoder/classifier pair
    produces class logits. Per-source attention scores from ``attn_mats`` are
    normalised with the project's ``softmax`` helper and used to mix the
    experts' softmax probabilities. The optimised loss is
    ``args.lambda_moe * NLL(log(mixture), label)
    + args.lambda_entropy * HLoss(attention weights)``.

    Args:
        iter_cnt: global iteration counter, incremented once per batch.
        encoders: pair ``(source_encoders, encoder_dst)``.
        classifiers: per-source classifiers, indexed like the source encoders.
        attn_mats: per-source attention modules.
        train_loader_dst: target-domain loader yielding
            ``(batch1, batch2, label)`` for ``cnn`` or
            ``(batch1, batch2, batch3, batch4, label)`` for ``rnn``.
        args: config namespace; reads ``base_model``, ``cuda``, ``attn_type``,
            ``lambda_moe`` and ``lambda_entropy``.
        optim_model: optimizer, stepped once per batch.
        epoch: epoch index, used as the step of the ``training_loss`` scalar.

    Returns:
        The updated ``iter_cnt``.

    Raises:
        NotImplementedError: for an unsupported ``args.base_model`` or
            ``args.attn_type``.
    """
    encoders, encoder_dst = encoders

    # ``map`` is lazy in Python 3, so ``map(lambda m: m.train(), ...)`` never
    # actually switched anything into training mode; call it explicitly.
    for module in list(classifiers) + list(encoders) + list(attn_mats):
        module.train()

    moe_criterion = nn.NLLLoss()  # fed the log of the mixture probabilities
    entropy_criterion = HLoss()

    loss_total = 0
    n_batch = 0
    n_sources = len(encoders)

    for batch in train_loader_dst:
        if args.base_model == "cnn":
            batch1, batch2, label = batch
            inputs = [batch1, batch2]
        elif args.base_model == "rnn":
            batch1, batch2, batch3, batch4, label = batch
            inputs = [batch1, batch2, batch3, batch4]
        else:
            raise NotImplementedError

        bs = len(label)

        iter_cnt += 1
        n_batch += 1
        if args.cuda:
            inputs = [x.cuda() for x in inputs]
            label = label.cuda()

        _, hidden_from_dst_enc = encoder_dst(*inputs)

        outputs_dst_transfer = []
        hidden_from_src_enc = []
        one_hot_sources = []
        for src_i in range(n_sources):
            _, cur_hidden = encoders[src_i](*inputs)
            hidden_from_src_enc.append(cur_hidden)
            outputs_dst_transfer.append(classifiers[src_i](cur_hidden))
            cur_one_hot = torch.zeros(size=(bs, n_sources))
            cur_one_hot[:, src_i] = 1
            one_hot_sources.append(cur_one_hot)

        optim_model.zero_grad()

        # Raw attention score per source, flattened with ``reshape(-1)``.
        # ``squeeze()`` collapsed a batch of size 1 to a 0-d tensor, which
        # then broke ``permute(1, 0)`` and ``unsqueeze(1)`` below.
        if args.attn_type == "onehot":
            raw_alphas = [attn_mats[j](one_hot_sources[j]).reshape(-1)
                          for j in range(n_sources)]
        elif args.attn_type == "cor":
            raw_alphas = [attn_mats[j](hidden_from_src_enc[j],
                                       hidden_from_dst_enc).reshape(-1)
                          for j in range(n_sources)]
        else:
            raise NotImplementedError

        # Every source is an expert here, so one normalisation feeds both the
        # entropy term and the mixture. Previously the identical softmax was
        # computed twice and only one copy was moved to the GPU, so the
        # mixture could multiply CPU weights with CUDA probabilities.
        alphas = list(softmax(raw_alphas))
        if args.cuda:
            alphas = [alpha.cuda() for alpha in alphas]
        # rows = samples, columns = sources
        alpha_mat = torch.stack(alphas, dim=0).permute(1, 0)

        loss_entropy = entropy_criterion(alpha_mat)

        # Broadcasting over the class axis replaces ``repeat(1, 2)`` and works
        # for any number of classes.
        output_moe = sum(alpha.unsqueeze(1) * F.softmax(logits, dim=1)
                         for alpha, logits in zip(alphas, outputs_dst_transfer))
        loss_moe = moe_criterion(torch.log(output_moe), label)
        loss = args.lambda_moe * loss_moe + args.lambda_entropy * loss_entropy
        loss_total += loss.item()
        loss.backward()
        optim_model.step()

        if iter_cnt % 5 == 0:
            say("{} MOE loss: {:.4f}, Entropy loss: {:.4f}, "
                "loss: {:.4f}\n".format(iter_cnt, loss_moe.item(),
                                        loss_entropy.item(), loss.item()))

    # Guard against an empty loader instead of raising ZeroDivisionError.
    loss_total /= max(n_batch, 1)
    writer.add_scalar('training_loss', loss_total, epoch)

    say("\n")
    return iter_cnt
示例#2
0
def train_moe_deep_stack(args):
    """Load pretrained source experts and run them over the target train set.

    Loads the saved ``(classifiers, attn_mats)`` pair for ``args.test`` and one
    pretrained encoder per comma-separated entry of ``args.train``. It then
    builds train/valid/test loaders for the target domain and computes, per
    training batch, the normalised attention weight of every source expert.

    NOTE(review): this definition looks truncated. ``meta_features``,
    ``meta_labels``, ``valid_loader``, ``test_loader`` and the per-batch
    ``alphas`` are never consumed in the visible code, and only the ``cnn``
    base model has a feature loop.

    Raises:
        NotImplementedError: for an unsupported ``args.base_model``.
    """
    save_model_dir = os.path.join(settings.OUT_DIR, args.test)
    classifiers, attn_mats = torch.load(
        os.path.join(
            save_model_dir,
            "{}_{}_moe_best_now.mdl".format(args.test, args.base_model)))
    print("base model", args.base_model)
    print("classifier", classifiers[0])

    source_train_sets = args.train.split(',')
    pretrain_emb = torch.load(
        os.path.join(settings.OUT_DIR, "rnn_init_word_emb.emb"))

    # ``None`` is torch.load's default; on CPU-only runs remap saved CUDA
    # tensors to the CPU.
    map_location = None if args.cuda else torch.device('cpu')

    encoders_src = []
    for source_name in source_train_sets:
        cur_model_dir = os.path.join(settings.OUT_DIR, source_name)

        if args.base_model == "cnn":
            encoder_class = CNNMatchModel(
                input_matrix_size1=args.matrix_size1,
                input_matrix_size2=args.matrix_size2,
                mat1_channel1=args.mat1_channel1,
                mat1_kernel_size1=args.mat1_kernel_size1,
                mat1_channel2=args.mat1_channel2,
                mat1_kernel_size2=args.mat1_kernel_size2,
                mat1_hidden=args.mat1_hidden,
                mat2_channel1=args.mat2_channel1,
                mat2_kernel_size1=args.mat2_kernel_size1,
                mat2_hidden=args.mat2_hidden)
        elif args.base_model == "rnn":
            encoder_class = BiLSTM(pretrain_emb=pretrain_emb,
                                   vocab_size=args.max_vocab_size,
                                   embedding_size=args.embedding_size,
                                   hidden_size=args.hidden_size,
                                   dropout=args.dropout)
        else:
            raise NotImplementedError

        encoder_class.load_state_dict(
            torch.load(
                os.path.join(cur_model_dir,
                             "{}-match-best-now.mdl".format(args.base_model)),
                map_location=map_location))

        encoders_src.append(encoder_class)

    # ``map`` is lazy in Python 3, so the former ``map(lambda m: m.eval(), ...)``
    # and ``map(lambda m: m.cuda(), ...)`` never ran. The newly built encoders
    # presumably stayed in training mode on the CPU while their inputs were
    # moved to the GPU.
    for module in encoders_src + classifiers + attn_mats:
        module.eval()

    if args.cuda:
        for module in classifiers + encoders_src + attn_mats:
            module.cuda()

    if args.base_model == "cnn":
        train_dataset_dst = ProcessedCNNInputDataset(args.test, "train")
        valid_dataset = ProcessedCNNInputDataset(args.test, "valid")
        test_dataset = ProcessedCNNInputDataset(args.test, "test")
    elif args.base_model == "rnn":
        train_dataset_dst = ProcessedRNNInputDataset(args.test, "train")
        valid_dataset = ProcessedRNNInputDataset(args.test, "valid")
        test_dataset = ProcessedRNNInputDataset(args.test, "test")
    else:
        raise NotImplementedError

    train_loader_dst = data.DataLoader(train_dataset_dst,
                                       batch_size=args.batch_size,
                                       shuffle=False,
                                       num_workers=0)

    valid_loader = data.DataLoader(valid_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=0)

    test_loader = data.DataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=0)
    say("Corpus loaded.\n")

    meta_features = np.empty(shape=(0, 192 + 2 * 8))
    meta_labels = []
    n_sources = len(encoders_src)
    encoders = encoders_src

    if args.base_model == "cnn":
        for batch1, batch2, label in train_loader_dst:
            if args.cuda:
                batch1 = batch1.cuda()
                batch2 = batch2.cuda()
                label = label.cuda()

            outputs_dst_transfer = []
            hidden_from_src_enc = []
            for src_i in range(n_sources):
                _, cur_hidden = encoders[src_i](batch1, batch2)
                hidden_from_src_enc.append(cur_hidden)
                cur_output = classifiers[src_i](cur_hidden)
                outputs_dst_transfer.append(cur_output)

            source_ids = range(n_sources)
            support_ids = [x for x in source_ids]  # experts

            source_alphas = [
                attn_mats[j](hidden_from_src_enc[j]).squeeze()
                for j in source_ids
            ]

            support_alphas = [source_alphas[x] for x in support_ids]
            support_alphas = softmax(support_alphas)
            source_alphas = softmax(source_alphas)  # [ 32, 32, 32 ]
            alphas = source_alphas
示例#3
0
def _evaluate_cross_impl(encoder, classifiers, mats, loaders,
                         return_best_thrs, args, thr):
    """Body of ``evaluate_cross``; expects the modules to be in eval mode."""
    if args.metric == "biaffine":
        Us, Ws, Vs = mats
    else:
        Us, Ps, Ns = mats

    source_loaders, valid_loaders_src = loaders
    domain_encs = domain_encoding(source_loaders, args, encoder)

    n_sources = len(valid_loaders_src)
    # Every source classifier acts as an expert on every validation set.
    support_ids = list(range(n_sources))

    thresholds = []
    metrics = []
    # Rows: validation source, columns: expert. Sized from the data instead of
    # the former hard-coded (4, 4), which broke for more than four sources.
    alphas_weights = np.zeros(shape=(n_sources, n_sources))

    for src_i, valid_loader in enumerate(valid_loaders_src):
        y_true = []
        y_pred = []
        y_score = []
        alpha_chunks = [np.empty(shape=(0, len(support_ids)))]

        cur_domain_encs = [domain_encs[x] for x in support_ids]
        if args.metric != "biaffine":
            # Ps / Ns only exist for the Mahalanobis metric; building these
            # unconditionally raised NameError when metric == "biaffine".
            cur_Us = [Us[x] for x in support_ids]
            cur_Ps = [Ps[x] for x in support_ids]
            cur_Ns = [Ns[x] for x in support_ids]

        with torch.no_grad():
            for batch1, batch2, label in valid_loader:
                if args.cuda:
                    batch1 = batch1.cuda()
                    batch2 = batch2.cuda()
                    label = label.cuda()

                _, hidden = encoder(batch1, batch2)
                if args.metric == "biaffine":
                    # NOTE(review): every domain is scored with Us[0];
                    # confirm that a shared U is intended.
                    metric_alphas = [
                        biaffine_metric_fast(hidden, mu[0], Us[0])
                        for mu in domain_encs
                    ]
                else:
                    metric_alphas = [
                        mahalanobis_metric_fast(hidden, mu[0], U, mu[1], P,
                                                mu[2], N)
                        for (mu, U, P, N) in zip(cur_domain_encs, cur_Us,
                                                 cur_Ps, cur_Ns)
                    ]
                metric_alphas = softmax(metric_alphas)
                # NOTE(review): ``metric_alphas`` is computed but not used.
                # As in the original, the mixture weights are replaced by a
                # one-hot vector selecting expert ``src_i``, so each validation
                # set is scored by its own source classifier. Confirm intended.
                bs = batch1.size()[0]
                alphas = [torch.zeros(size=(bs, )) for _ in support_ids]
                alphas[src_i] = torch.ones(size=(bs, ))
                alpha_chunks.append(torch.stack(alphas, dim=1).numpy())

                if args.cuda:
                    alphas = [alpha.cuda() for alpha in alphas]

                outputs = [
                    F.softmax(classifiers[j](hidden), dim=1)
                    for j in support_ids
                ]
                output = sum(alpha.unsqueeze(1) * output_i
                             for alpha, output_i in zip(alphas, outputs))
                pred = output.max(dim=1)[1]
                oracle_eq = compute_oracle(outputs, label, args)

                if args.eval_only:
                    for i in range(batch1.shape[0]):
                        for j in range(len(alphas)):
                            say("{:.4f}: [{:.4f}, {:.4f}], ".format(
                                alphas[j].data[i], outputs[j].data[i][0],
                                outputs[j].data[i][1]))
                        oracle_TF = "T" if oracle_eq[i] == 1 else colored(
                            "F", 'red')
                        say("gold: {}, pred: {}, oracle: {}\n".format(
                            label[i], pred[i], oracle_TF))
                    say("\n")

                y_true += label.tolist()
                y_pred += pred.tolist()
                y_score += output[:, 1].tolist()

        cur_alpha_weights_stack = np.concatenate(alpha_chunks)
        alphas_weights[src_i, support_ids] = np.mean(cur_alpha_weights_stack,
                                                     axis=0)

        if thr is not None:
            print("using threshold %.4f" % thr[src_i])
            y_score = np.array(y_score)
            y_pred = np.zeros_like(y_score)
            y_pred[y_score > thr[src_i]] = 1

        prec, rec, f1, _ = precision_recall_fscore_support(y_true,
                                                           y_pred,
                                                           average="binary")
        auc = roc_auc_score(y_true, y_score)
        print("source {}, AUC: {:.2f}, Prec: {:.2f}, Rec: {:.2f}, F1: {:.2f}".
              format(src_i, auc * 100, prec * 100, rec * 100, f1 * 100))

        metrics.append([auc, prec, rec, f1])

        if return_best_thrs:
            precs, recs, thrs = precision_recall_curve(y_true, y_score)
            f1s = 2 * precs * recs / (precs + recs)
            f1s = f1s[:-1]
            thrs = thrs[~np.isnan(f1s)]
            f1s = f1s[~np.isnan(f1s)]
            best_thr = thrs[np.argmax(f1s)]
            # Previously ``print("...%4f...", a, b)`` printed the raw format
            # string; actually interpolate the values.
            print("best threshold={:.4f}, f1={:.4f}".format(
                best_thr, np.max(f1s)))
            thresholds.append(best_thr)

    print("source domain weight matrix\n", alphas_weights)

    metrics = np.array(metrics)
    return thresholds, metrics, alphas_weights


def evaluate_cross(encoder,
                   classifiers,
                   mats,
                   loaders,
                   return_best_thrs,
                   args,
                   thr=None):
    ''' Evaluate model using MOE, once per source validation set.

    Args:
        encoder: shared encoder applied to every validation batch.
        classifiers: per-source classifiers (the experts).
        mats: ``(Us, Ws, Vs)`` when ``args.metric == "biaffine"``, otherwise
            ``(Us, Ps, Ns)`` for the Mahalanobis metric.
        loaders: ``(source_loaders, valid_loaders_src)``.
        return_best_thrs: if true, also search the F1-optimal threshold per
            source.
        args: config namespace (``metric``, ``cuda``, ``eval_only``, ...).
        thr: optional per-source thresholds applied to the positive-class
            score instead of argmax predictions.

    Returns:
        ``(thresholds, metrics, alphas_weights)``. ``thresholds`` is one best
        threshold per source (empty unless ``return_best_thrs``).
        ``metrics`` is an array of ``[auc, prec, rec, f1]`` rows.
        ``alphas_weights`` is the mean expert-weight matrix, one row per
        validation source.
    '''
    # ``map`` is lazy in Python 3, so the former ``map(lambda m: m.eval(), ...)``
    # never ran. Switch to eval mode here and restore the previous modes
    # afterwards, so callers' training loops keep the mode they expect.
    modules = [encoder] + list(classifiers)
    prev_modes = [module.training for module in modules]
    for module in modules:
        module.eval()
    try:
        return _evaluate_cross_impl(encoder, classifiers, mats, loaders,
                                    return_best_thrs, args, thr)
    finally:
        for module, was_training in zip(modules, prev_modes):
            module.train(was_training)
示例#4
0
def _attn_moe_batch(encoders, encoder_dst, classifiers, attn_mats, inputs,
                    args):
    """Run the attention-weighted mixture of source experts on one batch.

    Args:
        encoders: source encoders, one per expert.
        encoder_dst: target-domain encoder (feeds the ``cor`` attention).
        classifiers: per-source classifiers, indexed like ``encoders``.
        attn_mats: per-source attention modules.
        inputs: list of input tensors passed positionally to every encoder.
        args: config namespace; reads ``attn_type`` and ``cuda``.

    Returns:
        ``(alphas, output)``. ``alphas`` holds one normalised, flattened
        weight tensor per source expert. ``output`` is the alpha-weighted
        sum of the experts' softmax class probabilities.

    Raises:
        NotImplementedError: for an unsupported ``args.attn_type``.
    """
    n_sources = len(encoders)
    bs = len(inputs[0])

    _, hidden_from_dst_enc = encoder_dst(*inputs)

    outputs_dst_transfer = []
    hidden_from_src_enc = []
    one_hot_sources = []
    for src_i in range(n_sources):
        _, cur_hidden = encoders[src_i](*inputs)
        hidden_from_src_enc.append(cur_hidden)
        outputs_dst_transfer.append(classifiers[src_i](cur_hidden))
        cur_one_hot = torch.zeros(size=(bs, n_sources))
        cur_one_hot[:, src_i] = 1
        one_hot_sources.append(cur_one_hot)

    # reshape(-1) instead of squeeze(): a final batch of size 1 must stay
    # 1-D, otherwise ``unsqueeze(1)`` below fails on a 0-d tensor.
    if args.attn_type == "onehot":
        raw_alphas = [attn_mats[j](one_hot_sources[j]).reshape(-1)
                      for j in range(n_sources)]
    elif args.attn_type == "cor":
        raw_alphas = [attn_mats[j](hidden_from_src_enc[j],
                                   hidden_from_dst_enc).reshape(-1)
                      for j in range(n_sources)]
    else:
        raise NotImplementedError

    alphas = list(softmax(raw_alphas))
    if args.cuda:
        alphas = [alpha.cuda() for alpha in alphas]

    probs = [F.softmax(logits, dim=1) for logits in outputs_dst_transfer]
    # Broadcasting over the class axis (formerly ``repeat(1, 2)``) supports
    # any number of classes.
    output = sum(alpha.unsqueeze(1) * prob
                 for alpha, prob in zip(alphas, probs))
    return alphas, output


def evaluate(epoch,
             encoders,
             classifiers,
             attn_mats,
             data_loader,
             return_best_thrs,
             args,
             thr=None):
    """Evaluate the attention-weighted mixture of experts on ``data_loader``.

    Args:
        epoch: step used for the ``val_loss`` / ``test_f1`` scalars.
        encoders: pair ``(source_encoders, encoder_dst)``.
        classifiers: per-source classifiers.
        attn_mats: per-source attention modules.
        data_loader: yields ``(batch1, batch2, label)`` for ``cnn`` or
            ``(batch1, batch2, batch3, batch4, label)`` for ``rnn``.
        return_best_thrs: if true, search the F1-optimal threshold and log
            ``val_loss``; otherwise log ``test_f1``.
        args: config namespace (``base_model``, ``attn_type``, ``cuda``).
        thr: optional threshold applied to the positive-class score instead
            of argmax predictions.

    Returns:
        ``(best_thr, [loss, auc, prec, rec, f1])``. ``best_thr`` is ``None``
        unless ``return_best_thrs`` is true.

    Raises:
        NotImplementedError: for an unsupported ``base_model``/``attn_type``.
        ValueError: if a batch has the wrong number of input tensors.
    """
    encoders, encoder_dst = encoders

    if args.base_model == "cnn":
        n_inputs = 2
    elif args.base_model == "rnn":
        n_inputs = 4
    else:
        raise NotImplementedError

    n_sources = len(encoders)
    y_true = []
    y_pred = []
    y_score = []
    loss = 0.
    tot_cnt = 0
    alpha_chunks = [np.empty(shape=(0, n_sources))]

    # ``map`` is lazy in Python 3, so the former ``map(lambda m: m.eval(), ...)``
    # never ran. Switch to eval mode here (including the target encoder) and
    # restore the previous modes afterwards, so callers' training loops are
    # unaffected.
    modules = (list(encoders) + [encoder_dst] + list(classifiers) +
               list(attn_mats))
    prev_modes = [module.training for module in modules]
    for module in modules:
        module.eval()
    try:
        with torch.no_grad():
            for batch in data_loader:
                *inputs, label = batch
                if len(inputs) != n_inputs:
                    raise ValueError(
                        "expected {} input tensors per batch for base_model "
                        "{!r}, got {}".format(n_inputs, args.base_model,
                                              len(inputs)))
                if args.cuda:
                    inputs = [x.cuda() for x in inputs]
                    label = label.cuda()
                bs = len(inputs[0])

                alphas, output = _attn_moe_batch(encoders, encoder_dst,
                                                 classifiers, attn_mats,
                                                 inputs, args)
                # rows = samples, columns = sources
                alpha_chunks.append(
                    torch.stack([alpha.cpu() for alpha in alphas],
                                dim=1).numpy())

                pred = output.max(dim=1)[1]
                loss_batch = F.nll_loss(torch.log(output), label)
                loss += bs * loss_batch.item()

                y_true += label.tolist()
                y_pred += pred.tolist()
                tot_cnt += output.size(0)
                y_score += output[:, 1].tolist()
    finally:
        for module, was_training in zip(modules, prev_modes):
            module.train(was_training)

    # One concatenation at the end instead of re-copying the stack per batch.
    alpha_weights = np.mean(np.concatenate(alpha_chunks), axis=0)
    print("alpha weights", alpha_weights)

    if thr is not None:
        print("using threshold %.4f" % thr)
        y_score = np.array(y_score)
        y_pred = np.zeros_like(y_score)
        y_pred[y_score > thr] = 1

    loss /= tot_cnt

    prec, rec, f1, _ = precision_recall_fscore_support(y_true,
                                                       y_pred,
                                                       average="binary")
    auc = roc_auc_score(y_true, y_score)
    print("Loss: {:.4f}, AUC: {:.2f}, Prec: {:.2f}, Rec: {:.2f}, F1: {:.2f}".
          format(loss, auc * 100, prec * 100, rec * 100, f1 * 100))

    best_thr = None
    metric = [loss, auc, prec, rec, f1]

    if return_best_thrs:
        precs, recs, thrs = precision_recall_curve(y_true, y_score)
        f1s = 2 * precs * recs / (precs + recs)
        f1s = f1s[:-1]
        thrs = thrs[~np.isnan(f1s)]
        f1s = f1s[~np.isnan(f1s)]
        best_thr = thrs[np.argmax(f1s)]
        print("best threshold={:.4f}, f1={:.4f}".format(best_thr, np.max(f1s)))

        writer.add_scalar('val_loss', loss, epoch)
    else:
        writer.add_scalar('test_f1', f1, epoch)

    return best_thr, metric
示例#5
0
def _evaluate_moe_mix_impl(epoch, encoders, encoder_dst, classifiers, mats,
                           loaders, return_best_thrs, args, thr):
    """Body of the target/mixture ``evaluate``; modules must be in eval mode."""
    if args.metric == "biaffine":
        Us, Ws, Vs = mats
    else:
        Us, Ps, Ns = mats

    source_loaders, valid_loader = loaders
    domain_encs = domain_encoding(source_loaders, args, encoders)

    tot_cnt = 0
    y_true = []
    y_pred = []
    y_score = []
    loss = 0.

    with torch.no_grad():
        for batch1, batch2, label in valid_loader:
            if args.cuda:
                batch1 = batch1.cuda()
                batch2 = batch2.cuda()
                label = label.cuda()
            bs = len(batch1)

            # The target encoder's representation drives the domain weights.
            _, hidden_dst = encoder_dst(batch1, batch2)

            outputs_dst_transfer = []
            for src_i in range(len(source_loaders)):
                _, cur_hidden = encoders[src_i](batch1, batch2)
                outputs_dst_transfer.append(classifiers[src_i](cur_hidden))

            if args.metric == "biaffine":
                # NOTE(review): every domain is scored with Us[0]; confirm
                # that a shared U is intended.
                alphas = [biaffine_metric_fast(hidden_dst, mu[0], Us[0])
                          for mu in domain_encs]
            else:
                alphas = [mahalanobis_metric_fast(hidden_dst, mu[0], U, mu[1],
                                                  P, mu[2], N)
                          for (mu, U, P, N) in zip(domain_encs, Us, Ps, Ns)]
            alphas = softmax(alphas)
            if args.cuda:
                alphas = [alpha.cuda() for alpha in alphas]

            # Weighted sum of the source classifiers' outputs (no softmax is
            # applied here). Broadcasting replaces ``repeat(1, 2)``.
            scores = sum(alpha.unsqueeze(1) * output_i
                         for alpha, output_i in zip(alphas,
                                                    outputs_dst_transfer))
            # NOTE(review): the original also computed
            # ``softmax(classifier_dst(hidden_dst)) + classifier_mix.multp *
            # scores`` and its log-softmax, then immediately overwrote that
            # result with ``scores``. The target classifier and the mixer
            # therefore never affected predictions, scores or loss, so that
            # dead computation is dropped. Restore it deliberately if the
            # mixed output was meant to be used.
            pred = scores.max(dim=1)[1]

            # NOTE(review): nll_loss expects log-probabilities; confirm that
            # the source classifiers emit log-probabilities.
            loss_batch = F.nll_loss(scores, label)
            loss += bs * loss_batch.item()

            y_true += label.tolist()
            y_pred += pred.tolist()
            y_score += scores[:, 1].tolist()
            tot_cnt += scores.size(0)

    if thr is not None:
        print("using threshold %.4f" % thr)
        y_score = np.array(y_score)
        y_pred = np.zeros_like(y_score)
        y_pred[y_score > thr] = 1

    loss /= tot_cnt

    prec, rec, f1, _ = precision_recall_fscore_support(y_true,
                                                       y_pred,
                                                       average="binary")
    auc = roc_auc_score(y_true, y_score)
    print("Loss: {:.4f}, AUC: {:.2f}, Prec: {:.2f}, Rec: {:.2f}, F1: {:.2f}".
          format(loss, auc * 100, prec * 100, rec * 100, f1 * 100))

    best_thr = None
    metric = [auc, prec, rec, f1]

    if return_best_thrs:
        precs, recs, thrs = precision_recall_curve(y_true, y_score)
        f1s = 2 * precs * recs / (precs + recs)
        f1s = f1s[:-1]
        thrs = thrs[~np.isnan(f1s)]
        f1s = f1s[~np.isnan(f1s)]
        best_thr = thrs[np.argmax(f1s)]
        print("best threshold={:.4f}, f1={:.4f}".format(best_thr, np.max(f1s)))

        writer.add_scalar('val_loss', loss, epoch)
    else:
        writer.add_scalar('test_f1', f1, epoch)

    return best_thr, metric


def evaluate(epoch,
             encoders,
             classifiers,
             mats,
             loaders,
             return_best_thrs,
             args,
             thr=None):
    ''' Evaluate model using MOE weighted by domain similarity.

    Args:
        epoch: step used for the ``val_loss`` / ``test_f1`` scalars.
        encoders: pair ``(source_encoders, encoder_dst)``.
        classifiers: triple ``(source_classifiers, classifier_dst,
            classifier_mix)``.
        mats: ``(Us, Ws, Vs)`` for ``args.metric == "biaffine"``, otherwise
            ``(Us, Ps, Ns)``.
        loaders: ``(source_loaders, valid_loader)``; ``valid_loader`` yields
            ``(batch1, batch2, label)``.
        return_best_thrs: if true, search the F1-optimal threshold and log
            ``val_loss``; otherwise log ``test_f1``.
        args: config namespace (``metric``, ``cuda``).
        thr: optional threshold applied to the positive-class score.

    Returns:
        ``(best_thr, [auc, prec, rec, f1])``. ``best_thr`` is ``None`` unless
        ``return_best_thrs`` is true.
    '''
    encoders, encoder_dst = encoders
    classifiers, classifier_dst, classifier_mix = classifiers

    # ``map`` is lazy in Python 3, so the former ``map(lambda m: m.eval(), ...)``
    # never ran. Switch to eval mode here and restore the previous modes
    # afterwards, so the caller's training loop keeps the mode it expects.
    modules = (list(encoders) + list(classifiers) +
               [encoder_dst, classifier_dst, classifier_mix])
    prev_modes = [module.training for module in modules]
    for module in modules:
        module.eval()
    try:
        return _evaluate_moe_mix_impl(epoch, encoders, encoder_dst,
                                      classifiers, mats, loaders,
                                      return_best_thrs, args, thr)
    finally:
        for module, was_training in zip(modules, prev_modes):
            module.train(was_training)
示例#6
0
def train_epoch(iter_cnt, encoders, classifiers, critic, mats, data_loaders,
                args, optim_model, epoch):
    """Run one training epoch of the multi-source MoE model with a target branch.

    Each step draws one batch from every source loader, one labelled target
    batch and one unlabelled batch, builds the combined loss and takes one
    optimizer step.

    Args:
        iter_cnt: global iteration counter; incremented once per step.
        encoders: pair ``(source_encoders, encoder_dst)``. Every encoder is
            called as ``encoder(x1, x2)`` and returns ``(_, hidden)``.
        classifiers: triple ``(source_classifiers, classifier_dst,
            classifier_mix)``; ``classifier_mix.multp`` scales the MoE output.
        critic: domain critic, only used when ``args.lambda_critic > 0``.
        mats: ``(Us, Ws, Vs)`` when ``args.metric == "biaffine"``, otherwise
            ``(Us, Ps, Ns)`` for the Mahalanobis metric.
        data_loaders: ``(train_loaders, train_loader_dst, unl_loader,
            valid_loader)``; ``valid_loader`` is not used here.
        args: namespace with ``cuda``, ``metric`` and the ``lambda_*`` weights.
        optim_model: optimizer stepped once per iteration.
        epoch: step value used for the ``training_loss`` scalar.

    Returns:
        The updated ``iter_cnt``.
    """
    encoders, encoder_dst = encoders
    classifiers, classifier_dst, classifier_mix = classifiers

    # ``map`` is lazy in Python 3, so ``map(lambda m: m.train(), ...)`` never
    # actually switched the modules; do it explicitly. ``critic`` may be None
    # when no critic is used (assumption), so missing modules are skipped.
    for module in (list(encoders) +
                   [encoder_dst, classifier_dst, critic, classifier_mix] +
                   list(classifiers)):
        if module is not None:
            module.train()

    train_loaders, train_loader_dst, unl_loader, _valid_loader = data_loaders

    mtl_criterion = nn.NLLLoss()  # inputs are log-probabilities
    moe_criterion = nn.NLLLoss()  # with log_softmax separated
    kl_criterion = nn.MSELoss()
    entropy_criterion = HLoss()

    if args.metric == "biaffine":
        metric = biaffine_metric
        Us, Ws, Vs = mats
    else:
        metric = mahalanobis_metric
        Us, Ps, Ns = mats

    loss_total = 0.0
    total = 0

    for batches, batches_dst, unl_batch in zip(zip(*train_loaders),
                                               train_loader_dst, unl_loader):
        train_batches1, train_batches2, train_labels = zip(*batches)
        unl_critic_batch1, unl_critic_batch2, unl_critic_label = unl_batch
        batches1_dst, batches2_dst, labels_dst = batches_dst
        n_sources = len(train_batches1)

        total += len(batches1_dst)
        iter_cnt += 1

        if args.cuda:
            train_batches1 = [batch.cuda() for batch in train_batches1]
            train_batches2 = [batch.cuda() for batch in train_batches2]
            train_labels = [label.cuda() for label in train_labels]

            batches1_dst = batches1_dst.cuda()
            batches2_dst = batches2_dst.cuda()
            labels_dst = labels_dst.cuda()

            unl_critic_batch1 = unl_critic_batch1.cuda()
            unl_critic_batch2 = unl_critic_batch2.cuda()
            unl_critic_label = unl_critic_label.cuda()

        optim_model.zero_grad()
        loss_mtl = []
        loss_dan = []

        hiddens = []
        hidden_corresponding_labels = []

        # Target-domain private branch.
        _, hidden_dst = encoder_dst(batches1_dst, batches2_dst)
        logits_dst = classifier_dst(hidden_dst)
        cur_output_dst_mem = torch.softmax(logits_dst, dim=1)
        # log_softmax is the numerically stable form of log(softmax(x)).
        loss_train_dst = mtl_criterion(torch.log_softmax(logits_dst, dim=1),
                                       labels_dst)

        # Each source expert applied to the target batch (the MoE inputs).
        outputs_dst_transfer = []
        for src in range(n_sources):
            _, cur_hidden = encoders[src](batches1_dst, batches2_dst)
            outputs_dst_transfer.append(classifiers[src](cur_hidden))

        # ms_outputs[i][j]: log-probs of source i's batch under classifier j.
        ms_outputs = []
        for i, (batch1, batch2, label) in enumerate(
                zip(train_batches1, train_batches2, train_labels)):
            _, hidden = encoders[i](batch1, batch2)
            hiddens.append(hidden)
            ms_outputs.append([
                torch.log_softmax(classifier(hidden), dim=1)
                for classifier in classifiers
            ])
            hidden_corresponding_labels.append(label)
            # Multi-task loss: each source batch through its own classifier.
            loss_mtl.append(mtl_criterion(ms_outputs[i][i], label))

            if args.lambda_critic > 0:
                critic_label = torch.cat(
                    [1 - unl_critic_label, unl_critic_label])
                # Encoders return (_, hidden); the critic only needs hidden.
                _, unl_hidden = encoders[i](unl_critic_batch1,
                                            unl_critic_batch2)
                if isinstance(critic, ClassificationD):
                    # torch.cat expects a sequence of tensors.
                    critic_output = critic(torch.cat([hidden, unl_hidden]))
                    loss_dan.append(
                        critic.compute_loss(critic_output, critic_label))
                else:
                    loss_dan.append(critic(hidden, unl_hidden))
            else:
                loss_dan = Variable(torch.FloatTensor([0]))

        source_ids = range(n_sources)
        support_ids = list(source_ids)  # every source acts as an expert

        if args.metric == "biaffine":
            # The biaffine metric uses one matrix set shared by all sources.
            source_alphas = [
                metric(hidden_dst, hiddens[j].detach(), Us[0], Ws[0], Vs[0],
                       args) for j in source_ids
            ]
        else:
            source_alphas = [
                metric(hidden_dst, hiddens[j].detach(),
                       hidden_corresponding_labels[j], Us[j], Ps[j], Ns[j],
                       args) for j in source_ids
            ]

        support_alphas = softmax([source_alphas[x] for x in support_ids])

        source_alphas = softmax(source_alphas)
        # NOTE(review): ``x == n_sources`` is never true for x in
        # range(n_sources), so these "one-hot" labels are all zeros, and
        # loss_kl is not added to the total loss below. Kept as before.
        source_labels = [
            torch.FloatTensor([x == n_sources]) for x in source_ids
        ]
        if args.cuda:
            source_alphas = [alpha.cuda() for alpha in source_alphas]
            source_labels = [label.cuda() for label in source_labels]

        source_labels = Variable(torch.stack(source_labels, dim=0))
        source_alphas = torch.stack(source_alphas, dim=0)
        source_labels = source_labels.expand_as(source_alphas).permute(1, 0)
        source_alphas = source_alphas.permute(1, 0)
        loss_kl = kl_criterion(source_alphas, source_labels)
        # Entropy regulariser over the per-example source weights.
        loss_entropy = entropy_criterion(source_alphas)

        # Alpha-weighted mixture of the experts' probabilities (2 classes).
        output_moe = sum(
            alpha.unsqueeze(1).repeat(1, 2) *
            F.softmax(outputs_dst_transfer[src], dim=1)
            for alpha, src in zip(support_alphas, support_ids))
        loss_moe = moe_criterion(torch.log(output_moe), labels_dst)

        # Combine the target branch with the scaled MoE output.
        upper_out = cur_output_dst_mem + classifier_mix.multp * output_moe
        loss_all = mtl_criterion(torch.log_softmax(upper_out, dim=1),
                                 labels_dst)

        loss_mtl = sum(loss_mtl) / len(source_ids)

        loss = args.lambda_mtl * loss_mtl + args.lambda_moe * loss_moe
        loss += args.lambda_entropy * loss_entropy
        loss += loss_train_dst * args.lambda_dst
        loss += loss_all * args.lambda_all

        # .item() so the running total does not keep every batch's autograd
        # graph alive for the whole epoch. The critic term is excluded.
        loss_total += loss.item()

        if args.lambda_critic > 0:
            loss_dan = sum(loss_dan)
            loss += args.lambda_critic * loss_dan

        loss.backward()
        optim_model.step()

        if iter_cnt % 5 == 0:
            say("{} MTL loss: {:.4f}, MOE loss: {:.4f}, DAN loss: {:.4f}, "
                "loss: {:.4f}\n".format(iter_cnt, loss_mtl.item(),
                                        loss_moe.item(), loss_dan.item(),
                                        loss.item()))

    # Guard against an empty loader (would otherwise divide by zero).
    if total > 0:
        writer.add_scalar('training_loss', loss_total / total, epoch)

    say("\n")
    return iter_cnt
示例#7
0
def evaluate(encoder, classifiers, mats, loaders, args):
    ''' Evaluate model using MOE

    Each validation batch is scored by every source classifier. The
    per-source probabilities are mixed with weights derived from the domain
    metric (``biaffine_metric_fast`` or ``mahalanobis_metric_fast``) against
    the source encodings returned by ``domain_encoding``. The mixing uses
    ``repeat(1, 2)``, i.e. it assumes two output classes.

    Args:
        encoder: shared encoder, called as ``encoder(batch)`` -> hidden.
        classifiers: one classifier per source domain.
        mats: ``(Us, Ws, Vs)`` if ``args.metric == "biaffine"``, else
            ``(Us, Ps, Ns)``.
        loaders: ``(source_loaders, valid_loader)``.
        args: namespace providing ``cuda``, ``metric`` and ``eval_only``.

    Returns:
        ``((acc, oracle_acc), confusion_matrix(y_true, y_pred))``.

    The models are switched to eval mode for the call and restored to their
    previous train/eval mode afterwards, so training callers are unaffected.
    '''
    if args.metric == "biaffine":
        Us, Ws, Vs = mats
    else:
        Us, Ps, Ns = mats

    source_loaders, valid_loader = loaders

    # ``map`` is lazy in Python 3, so ``map(lambda m: m.eval(), ...)`` never
    # put the models in eval mode; do it explicitly and remember prior modes.
    modules = [m for m in [encoder] + list(classifiers) if m is not None]
    prev_modes = [m.training for m in modules]
    for module in modules:
        module.eval()

    oracle_correct = 0
    correct = 0
    tot_cnt = 0
    y_true = []
    y_pred = []

    try:
        # Evaluation needs no gradients; avoid building autograd graphs.
        with torch.no_grad():
            domain_encs = domain_encoding(source_loaders, args, encoder)

            for batch, label in valid_loader:
                if args.cuda:
                    batch = batch.cuda()
                    label = label.cuda()

                batch = Variable(batch)
                hidden = encoder(batch)
                if args.metric == "biaffine":
                    alphas = [biaffine_metric_fast(hidden, mu[0], Us[0])
                              for mu in domain_encs]
                else:
                    alphas = [
                        mahalanobis_metric_fast(hidden, mu[0], U, mu[1], P,
                                                mu[2], N)
                        for (mu, U, P, N) in zip(domain_encs, Us, Ps, Ns)
                    ]
                alphas = softmax(alphas)
                if args.cuda:
                    alphas = [alpha.cuda() for alpha in alphas]
                alphas = [Variable(alpha) for alpha in alphas]

                outputs = [
                    F.softmax(classifier(hidden), dim=1)
                    for classifier in classifiers
                ]
                # Alpha-weighted mixture of the experts' probabilities.
                output = sum(alpha.unsqueeze(1).repeat(1, 2) * output_i
                             for alpha, output_i in zip(alphas, outputs))
                pred = output.data.max(dim=1)[1]
                oracle_eq = compute_oracle(outputs, label, args)

                if args.eval_only:
                    for i in range(batch.shape[0]):
                        for j in range(len(alphas)):
                            say("{:.4f}: [{:.4f}, {:.4f}], ".format(
                                alphas[j].data[i], outputs[j].data[i][0],
                                outputs[j].data[i][1]))
                        oracle_TF = ("T" if oracle_eq[i] == 1 else
                                     colored("F", 'red'))
                        say("gold: {}, pred: {}, oracle: {}\n".format(
                            label[i], pred[i], oracle_TF))
                    say("\n")

                y_true += label.tolist()
                y_pred += pred.tolist()
                correct += pred.eq(label).sum()
                oracle_correct += oracle_eq.sum()
                tot_cnt += output.size(0)
    finally:
        # Restore each module's original train/eval mode.
        for module, mode in zip(modules, prev_modes):
            module.train(mode)

    acc = float(correct) / tot_cnt
    oracle_acc = float(oracle_correct) / tot_cnt
    return (acc, oracle_acc), confusion_matrix(y_true, y_pred)
示例#8
0
def train_epoch(iter_cnt, encoder, classifiers, critic, mats, data_loaders,
                args, optim_model):
    """Run one training epoch of the shared-encoder multi-source MoE model.

    Each step draws one batch per source plus an unlabelled batch. For every
    source ``i`` the remaining sources act as experts whose outputs are mixed
    with metric-derived weights (leave-one-out MoE). Every 30 iterations the
    model is evaluated on ``valid_loader`` and progress is printed with
    ``say``.

    Args:
        iter_cnt: global iteration counter; incremented once per step.
        encoder: shared encoder, called as ``encoder(batch)`` -> hidden.
        classifiers: one classifier per source domain.
        critic: domain critic, only used when ``args.lambda_critic > 0``.
        mats: ``(Us, Ws, Vs)`` if ``args.metric == "biaffine"``, else
            ``(Us, Ps, Ns)``.
        data_loaders: ``(train_loaders, unl_loader, valid_loader)``.
        args: namespace with ``cuda``, ``metric`` and ``lambda_*`` weights.
        optim_model: optimizer stepped once per iteration.

    Returns:
        The updated ``iter_cnt``.
    """
    # ``map`` is lazy in Python 3, so ``map(lambda m: m.train(), ...)`` never
    # ran. ``critic`` may be None when unused (assumption), so skip it then.
    modules = [m for m in [encoder, critic] + list(classifiers)
               if m is not None]
    for module in modules:
        module.train()

    train_loaders, unl_loader, valid_loader = data_loaders
    # Separate copy of the source loaders for the periodic evaluation
    # (presumably so it does not disturb the training iterators).
    dup_train_loaders = deepcopy(train_loaders)

    mtl_criterion = nn.CrossEntropyLoss()
    moe_criterion = nn.NLLLoss()  # with log_softmax separated
    kl_criterion = nn.MSELoss()
    entropy_criterion = HLoss()

    if args.metric == "biaffine":
        metric = biaffine_metric
        Us, Ws, Vs = mats
    else:
        metric = mahalanobis_metric
        Us, Ps, Ns = mats

    for batches, unl_batch in zip(zip(*train_loaders), unl_loader):
        train_batches, train_labels = zip(*batches)
        unl_critic_batch, unl_critic_label = unl_batch

        iter_cnt += 1
        if args.cuda:
            train_batches = [batch.cuda() for batch in train_batches]
            train_labels = [label.cuda() for label in train_labels]

            unl_critic_batch = unl_critic_batch.cuda()
            unl_critic_label = unl_critic_label.cuda()

        train_batches = [Variable(batch) for batch in train_batches]
        train_labels = [Variable(label) for label in train_labels]
        unl_critic_batch = Variable(unl_critic_batch)
        unl_critic_label = Variable(unl_critic_label)

        optim_model.zero_grad()
        loss_mtl = []
        loss_moe = []
        loss_kl = []
        loss_entropy = []
        loss_dan = []

        ms_outputs = []  # ms_outputs[i][j]: source i's batch, classifier j
        hiddens = []
        hidden_corresponding_labels = []
        for i, (batch, label) in enumerate(zip(train_batches, train_labels)):
            hidden = encoder(batch)
            hiddens.append(hidden)
            ms_outputs.append([classifier(hidden) for classifier in classifiers])
            hidden_corresponding_labels.append(label)
            # Multi-task loss: each source batch through its own classifier.
            loss_mtl.append(mtl_criterion(ms_outputs[i][i], label))

            if args.lambda_critic > 0:
                critic_label = torch.cat(
                    [1 - unl_critic_label, unl_critic_label])
                if isinstance(critic, ClassificationD):
                    # torch.cat expects a sequence of tensors.
                    critic_output = critic(
                        torch.cat([hidden, encoder(unl_critic_batch)]))
                    loss_dan.append(
                        critic.compute_loss(critic_output, critic_label))
                else:
                    critic_output = critic(hidden, encoder(unl_critic_batch))
                    loss_dan.append(critic_output)
            else:
                loss_dan = Variable(torch.FloatTensor([0]))

        source_ids = range(len(train_batches))
        for i in source_ids:
            # Leave-one-out: all sources except i are experts for source i.
            support_ids = [x for x in source_ids if x != i]

            if args.metric == "biaffine":
                # The biaffine metric uses one matrix set shared by all sources.
                source_alphas = [
                    metric(hiddens[i], hiddens[j].detach(), Us[0], Ws[0],
                           Vs[0], args) for j in source_ids
                ]
            else:
                source_alphas = [
                    metric(hiddens[i], hiddens[j].detach(),
                           hidden_corresponding_labels[j], Us[j], Ps[j],
                           Ns[j], args) for j in source_ids
                ]

            support_alphas = softmax([source_alphas[x] for x in support_ids])

            # Meta-supervision targets: one-hot on the true source i.
            source_alphas = softmax(source_alphas)
            source_labels = [torch.FloatTensor([x == i]) for x in source_ids]
            if args.cuda:
                source_alphas = [alpha.cuda() for alpha in source_alphas]
                source_labels = [label.cuda() for label in source_labels]

            source_labels = Variable(torch.stack(source_labels, dim=0))
            source_alphas = torch.stack(source_alphas, dim=0)

            source_labels = source_labels.expand_as(source_alphas).permute(
                1, 0)
            source_alphas = source_alphas.permute(1, 0)
            # NOTE(review): loss_kl is collected but not added to the loss.
            loss_kl.append(kl_criterion(source_alphas, source_labels))

            # Entropy regulariser over the per-example source weights.
            loss_entropy.append(entropy_criterion(source_alphas))

            # Alpha-weighted mixture of the experts' probabilities (2 classes).
            output_moe_i = sum(
                alpha.unsqueeze(1).repeat(1, 2) *
                F.softmax(ms_outputs[i][src], dim=1)
                for alpha, src in zip(support_alphas, support_ids))

            loss_moe.append(
                moe_criterion(torch.log(output_moe_i), train_labels[i]))

        loss_mtl = sum(loss_mtl)
        loss_moe = sum(loss_moe)
        loss = loss_mtl + args.lambda_moe * loss_moe
        loss_kl = sum(loss_kl)
        loss_entropy = sum(loss_entropy)
        loss += args.lambda_entropy * loss_entropy

        if args.lambda_critic > 0:
            loss_dan = sum(loss_dan)
            loss += args.lambda_critic * loss_dan

        loss.backward()
        optim_model.step()

        if iter_cnt % 30 == 0:
            if args.metric == "biaffine":
                mats = [Us, Ws, Vs]
            else:
                mats = [Us, Ps, Ns]

            (curr_dev, oracle_curr_dev), confusion_mat = evaluate(
                encoder, classifiers, mats, [dup_train_loaders, valid_loader],
                args)
            # evaluate() may leave the modules in eval mode; resume training.
            for module in modules:
                module.train()

            say("{} MTL loss: {:.4f}, MOE loss: {:.4f}, DAN loss: {:.4f}, "
                "loss: {:.4f}, dev acc/oracle: {:.4f}/{:.4f}\n".format(
                    iter_cnt, loss_mtl.item(), loss_moe.item(),
                    loss_dan.item(), loss.item(), curr_dev, oracle_curr_dev))

    say("\n")
    return iter_cnt
示例#9
0
def evaluate(encoders,
             classifiers,
             mats,
             loaders,
             return_best_thrs,
             args,
             thr=None):
    ''' Evaluate model using MOE

    Validation batches are encoded with ``encoders[0]`` and scored by every
    source classifier. The probabilities are mixed with weights from the
    domain metric against the encodings returned by ``domain_encoding``. The
    mixing uses ``repeat(1, 2)``, i.e. it assumes two output classes. AUC,
    precision, recall and F1 (binary) are printed and returned.

    Args:
        encoders: indexable collection of encoders; ``encoders[0](b1, b2)``
            returns ``(_, hidden)``.
        classifiers: one classifier per source domain.
        mats: ``(Us, Ws, Vs)`` if ``args.metric == "biaffine"``, else
            ``(Us, Ps, Ns)``.
        loaders: ``(source_loaders, valid_loader)``.
        return_best_thrs: if true, also find the threshold maximising F1
            on the precision/recall curve.
        args: namespace providing ``cuda``, ``metric`` and ``eval_only``.
        thr: optional decision threshold on the class-1 score; when given,
            predictions are ``score > thr`` instead of the argmax.

    Returns:
        ``(best_thr, [auc, prec, rec, f1])``; ``best_thr`` is None unless
        ``return_best_thrs`` is true.

    The models are switched to eval mode for the call and restored to their
    previous train/eval mode afterwards.
    '''
    if args.metric == "biaffine":
        Us, Ws, Vs = mats
    else:
        Us, Ps, Ns = mats

    source_loaders, valid_loader = loaders

    # ``map`` is lazy in Python 3, so ``map(lambda m: m.eval(), ...)`` never
    # ran; it also wrapped the whole ``encoders`` container in a list, so
    # ``.eval()`` would have been called on the container itself.
    modules = [m for m in list(encoders) + list(classifiers) if m is not None]
    prev_modes = [m.training for m in modules]
    for module in modules:
        module.eval()

    y_true = []
    y_pred = []
    y_score = []

    try:
        # Evaluation needs no gradients; avoid building autograd graphs.
        with torch.no_grad():
            domain_encs = domain_encoding(source_loaders, args, encoders)

            for batch1, batch2, label in valid_loader:
                if args.cuda:
                    batch1 = batch1.cuda()
                    batch2 = batch2.cuda()
                    label = label.cuda()

                batch1 = Variable(batch1)
                batch2 = Variable(batch2)
                _, hidden = encoders[0](batch1, batch2)
                if args.metric == "biaffine":
                    alphas = [biaffine_metric_fast(hidden, mu[0], Us[0])
                              for mu in domain_encs]
                else:
                    alphas = [
                        mahalanobis_metric_fast(hidden, mu[0], U, mu[1], P,
                                                mu[2], N)
                        for (mu, U, P, N) in zip(domain_encs, Us, Ps, Ns)
                    ]
                alphas = softmax(alphas)
                if args.cuda:
                    alphas = [alpha.cuda() for alpha in alphas]
                alphas = [Variable(alpha) for alpha in alphas]

                outputs = [
                    F.softmax(classifier(hidden), dim=1)
                    for classifier in classifiers
                ]
                # Alpha-weighted mixture of the experts' probabilities.
                output = sum(alpha.unsqueeze(1).repeat(1, 2) * output_i
                             for alpha, output_i in zip(alphas, outputs))
                pred = output.data.max(dim=1)[1]
                oracle_eq = compute_oracle(outputs, label, args)

                if args.eval_only:
                    for i in range(batch1.shape[0]):
                        for j in range(len(alphas)):
                            say("{:.4f}: [{:.4f}, {:.4f}], ".format(
                                alphas[j].data[i], outputs[j].data[i][0],
                                outputs[j].data[i][1]))
                        oracle_TF = ("T" if oracle_eq[i] == 1 else
                                     colored("F", 'red'))
                        say("gold: {}, pred: {}, oracle: {}\n".format(
                            label[i], pred[i], oracle_TF))
                    say("\n")

                y_true += label.tolist()
                y_pred += pred.tolist()
                # Probability of class 1 is the ranking score for AUC/PR.
                y_score += output[:, 1].data.tolist()
    finally:
        # Restore each module's original train/eval mode.
        for module, mode in zip(modules, prev_modes):
            module.train(mode)

    if thr is not None:
        print("using threshold %.4f" % thr)
        y_score = np.array(y_score)
        y_pred = np.zeros_like(y_score)
        y_pred[y_score > thr] = 1

    prec, rec, f1, _ = precision_recall_fscore_support(y_true,
                                                       y_pred,
                                                       average="binary")
    auc = roc_auc_score(y_true, y_score)
    print("AUC: {:.2f}, Prec: {:.2f}, Rec: {:.2f}, F1: {:.2f}".format(
        auc * 100, prec * 100, rec * 100, f1 * 100))

    best_thr = None
    metric = [auc, prec, rec, f1]

    if return_best_thrs:
        precs, recs, thrs = precision_recall_curve(y_true, y_score)
        f1s = 2 * precs * recs / (precs + recs)
        # The last (precision=1, recall=0) point has no threshold.
        f1s = f1s[:-1]
        # Drop 0/0 points where both precision and recall are zero.
        thrs = thrs[~np.isnan(f1s)]
        f1s = f1s[~np.isnan(f1s)]
        best_thr = thrs[np.argmax(f1s)]
        print("best threshold={:.4f}, f1={:.4f}".format(best_thr, np.max(f1s)))

    return best_thr, metric