Example #1
def train_moe_deep_stack(args):
    save_model_dir = os.path.join(settings.OUT_DIR, args.test)
    classifiers, attn_mats = torch.load(
        os.path.join(
            save_model_dir,
            "{}_{}_moe_best_now.mdl".format(args.test, args.base_model)))
    print("base model", args.base_model)
    print("classifier", classifiers[0])

    source_train_sets = args.train.split(',')
    pretrain_emb = torch.load(
        os.path.join(settings.OUT_DIR, "rnn_init_word_emb.emb"))

    encoders_src = []
    for src_i in range(len(source_train_sets)):
        cur_model_dir = os.path.join(settings.OUT_DIR,
                                     source_train_sets[src_i])

        if args.base_model == "cnn":
            encoder_class = CNNMatchModel(
                input_matrix_size1=args.matrix_size1,
                input_matrix_size2=args.matrix_size2,
                mat1_channel1=args.mat1_channel1,
                mat1_kernel_size1=args.mat1_kernel_size1,
                mat1_channel2=args.mat1_channel2,
                mat1_kernel_size2=args.mat1_kernel_size2,
                mat1_hidden=args.mat1_hidden,
                mat2_channel1=args.mat2_channel1,
                mat2_kernel_size1=args.mat2_kernel_size1,
                mat2_hidden=args.mat2_hidden)
        elif args.base_model == "rnn":
            encoder_class = BiLSTM(pretrain_emb=pretrain_emb,
                                   vocab_size=args.max_vocab_size,
                                   embedding_size=args.embedding_size,
                                   hidden_size=args.hidden_size,
                                   dropout=args.dropout)
        else:
            raise NotImplementedError
        if args.cuda:
            encoder_class.load_state_dict(
                torch.load(
                    os.path.join(
                        cur_model_dir,
                        "{}-match-best-now.mdl".format(args.base_model))))
        else:
            encoder_class.load_state_dict(
                torch.load(os.path.join(
                    cur_model_dir,
                    "{}-match-best-now.mdl".format(args.base_model)),
                           map_location=torch.device('cpu')))

        encoders_src.append(encoder_class)

    # put all loaded modules in eval mode
    for m in encoders_src + classifiers + attn_mats:
        m.eval()

    if args.cuda:
        for m in classifiers + encoders_src + attn_mats:
            m.cuda()

    if args.base_model == "cnn":
        train_dataset_dst = ProcessedCNNInputDataset(args.test, "train")
        valid_dataset = ProcessedCNNInputDataset(args.test, "valid")
        test_dataset = ProcessedCNNInputDataset(args.test, "test")
    elif args.base_model == "rnn":
        train_dataset_dst = ProcessedRNNInputDataset(args.test, "train")
        valid_dataset = ProcessedRNNInputDataset(args.test, "valid")
        test_dataset = ProcessedRNNInputDataset(args.test, "test")
    else:
        raise NotImplementedError

    train_loader_dst = data.DataLoader(train_dataset_dst,
                                       batch_size=args.batch_size,
                                       shuffle=False,
                                       num_workers=0)

    valid_loader = data.DataLoader(valid_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=0)

    test_loader = data.DataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=0)
    say("Corpus loaded.\n")

    meta_features = np.empty(shape=(0, 192 + 2 * 8))
    meta_labels = []
    n_sources = len(encoders_src)
    encoders = encoders_src

    if args.base_model == "cnn":
        for batch1, batch2, label in train_loader_dst:
            if args.cuda:
                batch1 = batch1.cuda()
                batch2 = batch2.cuda()
                label = label.cuda()

            outputs_dst_transfer = []
            hidden_from_src_enc = []
            for src_i in range(n_sources):
                _, cur_hidden = encoders[src_i](batch1, batch2)
                hidden_from_src_enc.append(cur_hidden)
                cur_output = classifiers[src_i](cur_hidden)
                outputs_dst_transfer.append(cur_output)

            source_ids = range(n_sources)
            support_ids = list(source_ids)  # expert indices

            source_alphas = [
                attn_mats[j](hidden_from_src_enc[j]).squeeze()
                for j in source_ids
            ]

            support_alphas = [source_alphas[x] for x in support_ids]
            support_alphas = softmax(support_alphas)
            source_alphas = softmax(source_alphas)  # one weight per source for each example
            alphas = source_alphas
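
The example is truncated here. A minimal sketch of how such softmax-normalized source weights are typically combined with the per-expert outputs (a hypothetical continuation, not code from this page):

    # each alpha: [batch]; each expert output in outputs_dst_transfer: [batch, 2]
    weighted = [alpha.unsqueeze(1) * out
                for alpha, out in zip(alphas, outputs_dst_transfer)]
    moe_output = torch.stack(weighted, dim=0).sum(dim=0)  # [batch, 2]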
Example #2
def main(args=args):
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    logger.info('cuda is available %s', args.cuda)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed + args.seed_delta)
    if args.cuda:
        torch.cuda.manual_seed(args.seed + args.seed_delta)

    dataset = ProcessedCNNInputDataset(args.entity_type, "train")
    dataset_valid = ProcessedCNNInputDataset(args.entity_type, "valid")
    dataset_test = ProcessedCNNInputDataset(args.entity_type, "test")
    N = len(dataset)
    N_valid = len(dataset_valid)
    N_test = len(dataset_test)
    train_loader = DataLoader(dataset,
                              batch_size=args.batch,
                              sampler=ChunkSampler(N, 0))
    valid_loader = DataLoader(dataset_valid,
                              batch_size=args.batch,
                              sampler=ChunkSampler(N_valid, 0))
    test_loader = DataLoader(dataset_test,
                             batch_size=args.batch,
                             sampler=ChunkSampler(N_test, 0))
    model = CNNMatchModel(input_matrix_size1=args.matrix_size1,
                          input_matrix_size2=args.matrix_size2,
                          mat1_channel1=args.mat1_channel1,
                          mat1_kernel_size1=args.mat1_kernel_size1,
                          mat1_channel2=args.mat1_channel2,
                          mat1_kernel_size2=args.mat1_kernel_size2,
                          mat1_hidden=args.mat1_hidden,
                          mat2_channel1=args.mat2_channel1,
                          mat2_kernel_size1=args.mat2_kernel_size1,
                          mat2_hidden=args.mat2_hidden)
    model = model.float()

    if args.cuda:
        model.cuda()

    optimizer = optim.Adagrad(
        model.parameters(),
        lr=args.lr,
        # initial_accumulator_value=args.initial_accumulator_value,
        weight_decay=args.weight_decay)
    t_total = time.time()
    logger.info("training...")

    model_dir = join(settings.OUT_DIR, args.entity_type)
    os.makedirs(model_dir, exist_ok=True)

    # model.load_state_dict(torch.load(join(settings.OUT_VENUE_DIR, "venue-matching-cnn.mdl")))
    evaluate(0, test_loader, model, thr=None, args=args)
    min_loss_val = None
    best_test_metrics = None
    for epoch in range(args.epochs):
        print("training epoch", epoch)
        metrics_val, metrics_test = train(epoch,
                                          train_loader,
                                          valid_loader,
                                          test_loader,
                                          model,
                                          optimizer,
                                          args=args)
        if metrics_val is not None:
            if min_loss_val is None or min_loss_val > metrics_val[0]:
                min_loss_val = metrics_val[0]
                best_test_metrics = metrics_test
                torch.save(model.state_dict(),
                           join(model_dir, "cnn-match-best-now.mdl"))

    logger.info("optimization Finished!")
    logger.info("total time elapsed: {:.4f}s".format(time.time() - t_total))

    # torch.save(model.state_dict(), join(model_dir, 'paper-matching-cnn.mdl'))
    # logger.info('paper matching CNN model saved')

    print(
        "min valid loss {:.4f}, best test metrics: AUC: {:.2f}, Prec: {:.4f}, Rec: {:.4f}, F1: {:.4f}"
        .format(min_loss_val, best_test_metrics[1], best_test_metrics[2],
                best_test_metrics[3], best_test_metrics[4]))

    # evaluate(args.epochs, test_loader, model, thr=best_thr, args=args)
    writer.close()
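
ChunkSampler is not defined on this page. A minimal sketch consistent with its usage above (sampling num_samples elements sequentially from a start offset) could look like this; the exact implementation is an assumption:

    from torch.utils.data.sampler import Sampler

    class ChunkSampler(Sampler):
        """Samples elements sequentially from a start offset."""

        def __init__(self, num_samples, start=0):
            self.num_samples = num_samples
            self.start = start

        def __iter__(self):
            return iter(range(self.start, self.start + self.num_samples))

        def __len__(self):
            return self.num_samples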
Example #3
def train(args):
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    say('cuda is available %s\n' % args.cuda)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed + args.seed_delta)
    if args.cuda:
        torch.cuda.manual_seed(args.seed + args.seed_delta)

    source_train_sets = args.train.split(',')
    print("sources", source_train_sets)

    pretrain_emb = torch.load(
        os.path.join(settings.OUT_DIR, "rnn_init_word_emb.emb"))

    encoders_src = []
    for src_i in range(len(source_train_sets)):
        cur_model_dir = os.path.join(settings.OUT_DIR,
                                     source_train_sets[src_i])

        if args.base_model == "cnn":
            encoder_class = CNNMatchModel(
                input_matrix_size1=args.matrix_size1,
                input_matrix_size2=args.matrix_size2,
                mat1_channel1=args.mat1_channel1,
                mat1_kernel_size1=args.mat1_kernel_size1,
                mat1_channel2=args.mat1_channel2,
                mat1_kernel_size2=args.mat1_kernel_size2,
                mat1_hidden=args.mat1_hidden,
                mat2_channel1=args.mat2_channel1,
                mat2_kernel_size1=args.mat2_kernel_size1,
                mat2_hidden=args.mat2_hidden)
        elif args.base_model == "rnn":
            encoder_class = BiLSTM(pretrain_emb=pretrain_emb,
                                   vocab_size=args.max_vocab_size,
                                   embedding_size=args.embedding_size,
                                   hidden_size=args.hidden_size,
                                   dropout=args.dropout)
        else:
            raise NotImplementedError
        if args.cuda:
            encoder_class.load_state_dict(
                torch.load(
                    os.path.join(
                        cur_model_dir,
                        "{}-match-best-now.mdl".format(args.base_model))))
        else:
            encoder_class.load_state_dict(
                torch.load(os.path.join(
                    cur_model_dir,
                    "{}-match-best-now.mdl".format(args.base_model)),
                           map_location=torch.device('cpu')))

        encoders_src.append(encoder_class)

    dst_pretrain_dir = os.path.join(settings.OUT_DIR, args.test)
    if args.base_model == "cnn":
        encoder_dst_pretrain = CNNMatchModel(
            input_matrix_size1=args.matrix_size1,
            input_matrix_size2=args.matrix_size2,
            mat1_channel1=args.mat1_channel1,
            mat1_kernel_size1=args.mat1_kernel_size1,
            mat1_channel2=args.mat1_channel2,
            mat1_kernel_size2=args.mat1_kernel_size2,
            mat1_hidden=args.mat1_hidden,
            mat2_channel1=args.mat2_channel1,
            mat2_kernel_size1=args.mat2_kernel_size1,
            mat2_hidden=args.mat2_hidden)
    elif args.base_model == "rnn":
        encoder_dst_pretrain = BiLSTM(pretrain_emb=pretrain_emb,
                                      vocab_size=args.max_vocab_size,
                                      embedding_size=args.embedding_size,
                                      hidden_size=args.hidden_size,
                                      dropout=args.dropout)
    else:
        raise NotImplementedError

    args = argparser.parse_args()  # re-parses CLI args; argparser is assumed to be module-level
    say(args)
    print()

    say("Transferring from %s to %s\n" % (args.train, args.test))

    if args.base_model == "cnn":
        train_dataset_dst = ProcessedCNNInputDataset(args.test, "train")
        valid_dataset = ProcessedCNNInputDataset(args.test, "valid")
        test_dataset = ProcessedCNNInputDataset(args.test, "test")

    elif args.base_model == "rnn":
        train_dataset_dst = ProcessedRNNInputDataset(args.test, "train")
        valid_dataset = ProcessedRNNInputDataset(args.test, "valid")
        test_dataset = ProcessedRNNInputDataset(args.test, "test")
    else:
        raise NotImplementedError

    train_loader_dst = data.DataLoader(train_dataset_dst,
                                       batch_size=args.batch_size,
                                       shuffle=False,
                                       num_workers=0)

    valid_loader = data.DataLoader(valid_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=0)

    test_loader = data.DataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=0)

    say("Corpus loaded.\n")

    classifiers = []
    attn_mats = []
    for source in source_train_sets:

        classifier = nn.Sequential(
            nn.Linear(encoders_src[0].n_out, 64),
            nn.ReLU(),
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.Linear(16, 2),
        )
        # cur_att_weight = nn.Linear(len(encoders_src), 1, bias=True)
        cur_att_weight = nn.Linear(len(encoders_src), 1, bias=False)
        # nn.init.uniform_(cur_att_weight.weight)
        # print(cur_att_weight)
        cur_att_weight.weight = nn.Parameter(
            torch.ones(size=(1, len(encoders_src))), requires_grad=True)
        print("init cur att weight", cur_att_weight.weight)
        if args.attn_type == "onehot":
            attn_mats.append(
                # nn.Linear(encoders_src[0].n_out, 1)
                cur_att_weight
                # nn.Linear(encoders_src[0].n_out, encoders_src[0].n_out)
                # MulInteractAttention(encoders_src[0].n_out, 16)
            )
        elif args.attn_type == "cor":
            attn_mats.append(MulInteractAttention(encoders_src[0].n_out, 16))
        else:
            raise NotImplementedError
        classifiers.append(classifier)
    print("classifier build", classifiers[0])

    if args.cuda:
        for m in classifiers + encoders_src + attn_mats:
            m.cuda()

    for i, classifier in enumerate(classifiers):
        say("Classifier-{}: {}\n".format(i, classifier))

    requires_grad = lambda x: x.requires_grad
    task_params = []
    for src_i in range(len(classifiers)):
        task_params += list(classifiers[src_i].parameters())
        task_params += list(attn_mats[src_i].parameters())

    if args.base_model == "cnn":
        optim_model = optim.Adagrad(
            filter(requires_grad, task_params),
            lr=args.lr,
            weight_decay=1e-4  #TODO
        )
    elif args.base_model == "rnn":
        optim_model = optim.Adam(filter(requires_grad, task_params),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    else:
        raise NotImplementedError

    say("Training will begin from scratch\n")

    iter_cnt = 0
    min_loss_val = None
    best_test_results = None
    model_dir = os.path.join(settings.OUT_DIR, args.test)

    for epoch in range(args.max_epoch):
        print("training epoch", epoch)

        iter_cnt = train_epoch(iter_cnt, [encoders_src, encoder_dst_pretrain],
                               classifiers, attn_mats, train_loader_dst, args,
                               optim_model, epoch)

        thr, metrics_val = evaluate(epoch,
                                    [encoders_src, encoder_dst_pretrain],
                                    classifiers, attn_mats, valid_loader, True,
                                    args)

        _, metrics_test = evaluate(epoch, [encoders_src, encoder_dst_pretrain],
                                   classifiers,
                                   attn_mats,
                                   test_loader,
                                   False,
                                   args,
                                   thr=thr)

        if min_loss_val is None or min_loss_val > metrics_val[0]:
            print("change val loss from {} to {}".format(
                min_loss_val, metrics_val[0]))
            min_loss_val = metrics_val[0]
            best_test_results = metrics_test
            torch.save([classifiers, attn_mats],
                       os.path.join(
                           model_dir,
                           "{}_{}_moe_simple_attn_best_now.mdl".format(
                               args.test, args.base_model)))
        say("\n")
        writer.flush()

    say(
        colored("Min valid loss: {:.4f}, best test results, "
                "AUC: {:.2f}, Prec: {:.2f}, Rec: {:.2f}, F1: {:.2f}\n".format(
                    min_loss_val, best_test_results[1] * 100,
                    best_test_results[2] * 100, best_test_results[3] * 100,
                    best_test_results[4] * 100)))
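
A hypothetical reload of the checkpoint saved above, mirroring the load at the top of Example #1:

    classifiers, attn_mats = torch.load(
        os.path.join(model_dir,
                     "{}_{}_moe_simple_attn_best_now.mdl".format(
                         args.test, args.base_model)),
        map_location=torch.device('cpu'))
    for m in classifiers + attn_mats:
        m.eval()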
Example #4
def train_one_time(args, wf, repeat_seed=0):
    tb_dir = 'runs/{}_sup_base_{}_attn_{}_moe_sources_{}_train_num_{}_repeat_{}'.format(
        args.test, args.base_model, args.attn_type, n_sources, args.train_num,
        repeat_seed)
    if os.path.exists(tb_dir) and os.path.isdir(tb_dir):
        shutil.rmtree(tb_dir)
    writer = SummaryWriter(tb_dir)

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    say('cuda is available %s\n' % args.cuda)

    np.random.seed(args.seed + repeat_seed)
    torch.manual_seed(args.seed + repeat_seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed + repeat_seed)

    source_train_sets = args.train.split(',')
    print("sources", source_train_sets)
    sources_idx = [source_to_idx[s] for s in source_train_sets]

    pretrain_emb = torch.load(
        os.path.join(settings.OUT_DIR, "rnn_init_word_emb.emb"))

    encoders_src = []
    # for src_i in range(len(source_train_sets)):
    for src_i in sources_idx:
        # cur_model_dir = os.path.join(settings.OUT_DIR, source_train_sets[src_i])
        cur_model_dir = os.path.join(settings.OUT_DIR, sources_all[src_i])

        if args.base_model == "cnn":
            encoder_class = CNNMatchModel(
                input_matrix_size1=args.matrix_size1,
                input_matrix_size2=args.matrix_size2,
                mat1_channel1=args.mat1_channel1,
                mat1_kernel_size1=args.mat1_kernel_size1,
                mat1_channel2=args.mat1_channel2,
                mat1_kernel_size2=args.mat1_kernel_size2,
                mat1_hidden=args.mat1_hidden,
                mat2_channel1=args.mat2_channel1,
                mat2_kernel_size1=args.mat2_kernel_size1,
                mat2_hidden=args.mat2_hidden)
        elif args.base_model == "rnn":
            encoder_class = BiLSTM(pretrain_emb=pretrain_emb,
                                   vocab_size=args.max_vocab_size,
                                   embedding_size=args.embedding_size,
                                   hidden_size=args.hidden_size,
                                   dropout=args.dropout)
        else:
            raise NotImplementedError
        if args.cuda:
            encoder_class.load_state_dict(
                torch.load(
                    os.path.join(
                        cur_model_dir,
                        "{}-match-best-now-train-num-{}-try-{}.mdl".format(
                            args.base_model, args.train_num, repeat_seed))))
        else:
            encoder_class.load_state_dict(
                torch.load(os.path.join(
                    cur_model_dir,
                    "{}-match-best-now-train-num-{}-try-{}.mdl".format(
                        args.base_model, args.train_num, repeat_seed)),
                           map_location=torch.device('cpu')))

        encoders_src.append(encoder_class)

    dst_model_dir = os.path.join(settings.OUT_DIR, args.test)
    if args.base_model == "cnn":
        encoder_dst_pretrain = CNNMatchModel(
            input_matrix_size1=args.matrix_size1,
            input_matrix_size2=args.matrix_size2,
            mat1_channel1=args.mat1_channel1,
            mat1_kernel_size1=args.mat1_kernel_size1,
            mat1_channel2=args.mat1_channel2,
            mat1_kernel_size2=args.mat1_kernel_size2,
            mat1_hidden=args.mat1_hidden,
            mat2_channel1=args.mat2_channel1,
            mat2_kernel_size1=args.mat2_kernel_size1,
            mat2_hidden=args.mat2_hidden)
    elif args.base_model == "rnn":
        encoder_dst_pretrain = BiLSTM(pretrain_emb=pretrain_emb,
                                      vocab_size=args.max_vocab_size,
                                      embedding_size=args.embedding_size,
                                      hidden_size=args.hidden_size,
                                      dropout=args.dropout)
    else:
        raise NotImplementedError

    encoder_dst_pretrain.load_state_dict(
        torch.load(
            os.path.join(
                dst_model_dir,
                "{}-match-best-now-train-num-{}-try-{}.mdl".format(
                    args.base_model, args.train_num, repeat_seed))))

    # args = argparser.parse_args()
    say(args)
    print()

    say("Transferring from %s to %s\n" % (args.train, args.test))

    if args.base_model == "cnn":
        train_dataset_dst = ProcessedCNNInputDataset(args.test, "train",
                                                     args.train_num,
                                                     repeat_seed)
        valid_dataset = ProcessedCNNInputDataset(args.test, "valid")
        test_dataset = ProcessedCNNInputDataset(args.test, "test")

    elif args.base_model == "rnn":
        train_dataset_dst = ProcessedRNNInputDataset(args.test, "train",
                                                     args.train_num,
                                                     repeat_seed)
        valid_dataset = ProcessedRNNInputDataset(args.test, "valid")
        test_dataset = ProcessedRNNInputDataset(args.test, "test")
    else:
        raise NotImplementedError

    print("train num", len(train_dataset_dst))

    train_loader_dst = data.DataLoader(train_dataset_dst,
                                       batch_size=args.batch_size,
                                       shuffle=False,
                                       num_workers=0)

    valid_loader = data.DataLoader(valid_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=0)

    test_loader = data.DataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=0)

    say("Corpus loaded.\n")

    classifiers = []
    attn_mats = []

    # for src_i in sources_idx:
    #     classifier = torch.load(os.path.join(dst_model_dir, "{}_{}_classifier_from_src_{}_train_num_{}_try_{}.mdl".format(
    #             args.test, args.base_model, src_i, args.train_num, repeat_seed
    #         )))
    #     classifiers.append(classifier)

    for source in source_train_sets:
        classifier = nn.Sequential(
            nn.Linear(encoders_src[0].n_out, 64),
            nn.ReLU(),
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.Linear(16, 2),
        )
        classifiers.append(classifier)

        if args.attn_type == "onehot":
            # cur_att_weight = nn.Linear(len(encoders_src), 1, bias=False)
            # cur_att_weight.weight = nn.Parameter(torch.ones(size=(1, len(encoders_src))), requires_grad=True)
            attn_mats.append(OneHotAttention(n_sources=len(encoders_src)))
        elif args.attn_type == "cor":
            attn_mats.append(MulInteractAttention(encoders_src[0].n_out, 16))
        elif args.attn_type == "mlp":
            attn_mats.append(MLP(encoders_src[0].n_out))
        else:
            raise NotImplementedError

    if args.cuda:
        encoder_dst_pretrain.cuda()
        for m in encoders_src + classifiers + attn_mats:
            m.cuda()

    for i, classifier in enumerate(classifiers):
        say("Classifier-{}: {}\n".format(i, classifier))

    requires_grad = lambda x: x.requires_grad
    task_params = []
    for src_i in range(len(classifiers)):
        task_params += list(classifiers[src_i].parameters())
        task_params += list(attn_mats[src_i].parameters())

        # train the classifiers and attention modules; keep source encoders frozen
        for para in classifiers[src_i].parameters():
            para.requires_grad = True
        for para in encoders_src[src_i].parameters():
            para.requires_grad = False
        for para in attn_mats[src_i].parameters():
            para.requires_grad = True

    if args.base_model == "cnn":
        optim_model = optim.Adagrad(filter(requires_grad, task_params),
                                    lr=args.lr,
                                    weight_decay=args.weight_decay)
    elif args.base_model == "rnn":
        optim_model = optim.Adam(filter(requires_grad, task_params),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    else:
        raise NotImplementedError

    say("Training will begin from scratch\n")

    iter_cnt = 0
    min_loss_val = None
    max_auc_val = None
    best_test_results = None
    weights_sources = None
    model_dir = os.path.join(settings.OUT_DIR, args.test)

    for epoch in range(args.max_epoch):
        print("training epoch", epoch)

        iter_cnt = train_epoch(iter_cnt, [encoders_src, encoder_dst_pretrain],
                               classifiers, attn_mats, train_loader_dst, args,
                               optim_model, epoch, writer)

        thr, metrics_val, alpha_weights_val = evaluate(
            epoch, [encoders_src, encoder_dst_pretrain], classifiers,
            attn_mats, valid_loader, True, args, writer)

        _, metrics_test, alpha_weights_test = evaluate(
            epoch, [encoders_src, encoder_dst_pretrain],
            classifiers,
            attn_mats,
            test_loader,
            False,
            args,
            writer,
            thr=thr)

        if min_loss_val is None or min_loss_val > metrics_val[0]:
            # if max_auc_val is None or max_auc_val < metrics_val[1]:
            print("change val loss from {} to {}".format(
                min_loss_val, metrics_val[0]))
            # print("change val auc from {} to {}".format(max_auc_val, metrics_val[1]))
            min_loss_val = metrics_val[0]
            max_auc_val = metrics_val[1]
            best_test_results = metrics_test
            weights_sources = [alpha_weights_val, alpha_weights_test]
            torch.save(
                [classifiers, attn_mats],
                os.path.join(
                    model_dir,
                    "{}_{}_moe_attn_{}_sources_{}_train_num_{}_seed_{}_best_now.mdl"
                    .format(args.test, args.base_model, args.attn_type,
                            n_sources, args.train_num, repeat_seed)))

    print()
    print(
        "min valid loss {:.4f}, best test metrics: AUC: {:.2f}, Prec: {:.2f}, Rec: {:.2f}, F1: {:.2f}\n"
        .format(min_loss_val, best_test_results[1] * 100,
                best_test_results[2] * 100, best_test_results[3] * 100,
                best_test_results[4] * 100))
    # print("max valid auc {:.4f}, best test metrics: AUC: {:.2f}, Prec: {:.2f}, Rec: {:.2f}, F1: {:.2f}\n".format(
    #             max_auc_val, best_test_results[1] * 100, best_test_results[2] * 100, best_test_results[3] * 100,
    #                           best_test_results[4] * 100
    #         ))

    # with open(os.path.join(model_dir, "{}_{}_moe_attn_{}_sources_{}_train_num_{}_seed_{}_results.txt".format(
    #         args.test, args.base_model, args.attn_type, n_sources, args.train_num, args.seed_delta)), "w") as wf:
    wf.write(
        "min valid loss {:.4f}, best test metrics: AUC: {:.2f}, Prec: {:.2f}, Rec: {:.2f}, F1: {:.2f}\n"
        .format(min_loss_val, best_test_results[1] * 100,
                best_test_results[2] * 100, best_test_results[3] * 100,
                best_test_results[4] * 100))
    # wf.write(
    #     "max valid auc {:.4f}, best test metrics: AUC: {:.2f}, Prec: {:.2f}, Rec: {:.2f}, F1: {:.2f}\n".format(
    #         max_auc_val, best_test_results[1] * 100, best_test_results[2] * 100, best_test_results[3] * 100,
    #                       best_test_results[4] * 100
    #     ))
    wf.write("val weights: ")
    for w in weights_sources[0]:
        wf.write("{:.4f}, ".format(w))
    wf.write("\n")
    wf.write("test weights: ")
    for w in weights_sources[1]:
        wf.write("{:.4f}, ".format(w))
    wf.write("\n\n")
    # wf.write(json.dumps(vars(args)) + "\n")
    writer.close()
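
OneHotAttention, used in the example above, is not defined on this page. The commented-out lines at its construction site suggest a bias-free linear layer over the sources with weights initialized to ones; a sketch under that assumption (the forward contract is also assumed):

    import torch
    import torch.nn as nn

    class OneHotAttention(nn.Module):
        def __init__(self, n_sources):
            super().__init__()
            self.att = nn.Linear(n_sources, 1, bias=False)
            self.att.weight = nn.Parameter(torch.ones(1, n_sources),
                                           requires_grad=True)

        def forward(self, x):
            # x: [batch, n_sources] of stacked per-source scores
            return self.att(x)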
Example #5
def train(args):
    ''' Training strategy

    Input: sources = {S1, S2, ..., Sk}, target = {T}

    Train:
        Approach 1: fix the metric and learn the encoders only
        Approach 2: learn the metric and the encoders alternately
    '''

    # test_mahalanobis_metric() and return

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    say('cuda is available %s\n' % args.cuda)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed + args.seed_delta)
    if args.cuda:
        torch.cuda.manual_seed(args.seed + args.seed_delta)

    source_train_sets = args.train.split(',')
    print("sources", source_train_sets)

    encoders = []
    for _ in range(len(source_train_sets)):
        # encoder_class = get_model_class("mlp")
        encoder_class = CNNMatchModel(input_matrix_size1=args.matrix_size1,
                                      input_matrix_size2=args.matrix_size2,
                                      mat1_channel1=args.mat1_channel1,
                                      mat1_kernel_size1=args.mat1_kernel_size1,
                                      mat1_channel2=args.mat1_channel2,
                                      mat1_kernel_size2=args.mat1_kernel_size2,
                                      mat1_hidden=args.mat1_hidden,
                                      mat2_channel1=args.mat2_channel1,
                                      mat2_kernel_size1=args.mat2_kernel_size1,
                                      mat2_hidden=args.mat2_hidden)

        # encoder_class.add_config(argparser)
        encoders.append(encoder_class)

    encoder_dst = CNNMatchModel(input_matrix_size1=args.matrix_size1,
                                input_matrix_size2=args.matrix_size2,
                                mat1_channel1=args.mat1_channel1,
                                mat1_kernel_size1=args.mat1_kernel_size1,
                                mat1_channel2=args.mat1_channel2,
                                mat1_kernel_size2=args.mat1_kernel_size2,
                                mat1_hidden=args.mat1_hidden,
                                mat2_channel1=args.mat2_channel1,
                                mat2_kernel_size1=args.mat2_kernel_size1,
                                mat2_hidden=args.mat2_hidden)

    critic_class = get_critic_class(args.critic)
    critic_class.add_config(argparser)

    args = argparser.parse_args()
    say(args)

    # encoder is shared across domains
    # encoder = encoder_class(args)
    # encoder = encoder_class

    print()
    print("encoder", encoders[0])

    say("Transferring from %s to %s\n" % (args.train, args.test))
    train_loaders = []
    # valid_loaders_src = []
    # test_loaders_src = []
    Us = []
    Ps = []
    Ns = []
    Ws = []
    Vs = []
    # Ms = []

    for source in source_train_sets:
        # filepath = os.path.join(DATA_DIR, "%s_train.svmlight" % (source))
        filepath = os.path.join(settings.DOM_ADAPT_DIR,
                                "{}_train.pkl".format(source))
        assert os.path.exists(filepath)
        # train_dataset = AmazonDataset(filepath)
        train_dataset = ProcessedCNNInputDataset(source, "train")
        train_loader = data.DataLoader(train_dataset,
                                       batch_size=args.batch_size,
                                       shuffle=False,
                                       num_workers=0)
        train_loaders.append(train_loader)

        # cur_valid_dataset = ProcessedCNNInputDataset(source, "valid")
        # cur_valid_loader = data.DataLoader(
        #     cur_valid_dataset,
        #     batch_size=args.batch_size,
        #     shuffle=False,
        #     num_workers=0
        # )
        # valid_loaders_src.append(cur_valid_loader)
        #
        # cur_test_dataset = ProcessedCNNInputDataset(source, "test")
        # cur_test_loader = data.DataLoader(
        #     cur_test_dataset,
        #     batch_size=args.batch_size,
        #     shuffle=False,
        #     num_workers=0
        # )
        # test_loaders_src.append(cur_test_loader)

        if args.metric == "biaffine":
            U = torch.FloatTensor(encoders[0].n_d, encoders[0].n_d)
            W = torch.FloatTensor(encoders[0].n_d, 1)
            nn.init.xavier_uniform(W)
            Ws.append(W)
            V = torch.FloatTensor(encoders[0].n_d, 1)
            nn.init.xavier_uniform(V)
            Vs.append(V)
        else:
            U = torch.FloatTensor(encoders[0].n_d, args.m_rank)

        nn.init.xavier_uniform_(U)
        Us.append(U)
        P = torch.FloatTensor(encoders[0].n_d, args.m_rank)
        nn.init.xavier_uniform_(P)
        Ps.append(P)
        N = torch.FloatTensor(encoders[0].n_d, args.m_rank)
        nn.init.xavier_uniform_(N)
        Ns.append(N)
        # Ms.append(U.mm(U.t()))

    # unl_filepath = os.path.join(DATA_DIR, "%s_train.svmlight" % (args.test))
    unl_filepath = os.path.join(settings.DOM_ADAPT_DIR,
                                "{}_train.pkl".format(args.test))
    print("****************", unl_filepath)
    assert os.path.exists(unl_filepath)
    # unl_dataset = AmazonDomainDataset(unl_filepath)  # using domain as labels
    unl_dataset = OAGDomainDataset(args.test, "train")
    unl_loader = data.DataLoader(unl_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=0)

    train_dataset_dst = ProcessedCNNInputDataset(args.test, "train")
    train_loader_dst = data.DataLoader(train_dataset_dst,
                                       batch_size=args.batch_size,
                                       shuffle=False,
                                       num_workers=0)

    # valid_filepath = os.path.join(DATA_DIR, "%s_test.svmlight" % (args.test))  # No dev files
    # valid_dataset = AmazonDataset(valid_filepath)
    valid_dataset = ProcessedCNNInputDataset(args.test, "valid")
    print("valid y", len(valid_dataset), valid_dataset.y)
    valid_loader = data.DataLoader(valid_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=0)

    # test_filepath = os.path.join(DATA_DIR, "%s_test.svmlight" % (args.test))
    # assert (os.path.exists(test_filepath))
    # test_dataset = AmazonDataset(test_filepath)
    test_dataset = ProcessedCNNInputDataset(args.test, "test")
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=0)
    say("Corpus loaded.\n")

    classifiers = []
    for source in source_train_sets:  # only one layer
        classifier = nn.Linear(encoders[0].n_out, 2)  # binary classification
        # classifier = encoder.fc_out
        # nn.init.xavier_normal(classifier.weight)
        # nn.init.constant(classifier.bias, 0.1)
        classifiers.append(classifier)

    classifier_dst = nn.Linear(encoder_dst.n_out, 2)
    # classifier_mix = nn.Linear(2, 2)
    classifier_mix = WeightScaler()

    critic = critic_class(encoders[0], args)

    # if args.save_model:
    #     say(colored("Save model to {}\n".format(args.save_model + ".init"), 'red'))
    #     torch.save([encoder, classifiers, Us, Ps, Ns], args.save_model + ".init")

    if args.cuda:
        for m in ([encoder_dst, critic, classifier_dst, classifier_mix] +
                  encoders + classifiers):
            m.cuda()
        Us = [Variable(U.cuda(), requires_grad=True) for U in Us]
        Ps = [Variable(P.cuda(), requires_grad=True) for P in Ps]
        Ns = [Variable(N.cuda(), requires_grad=True) for N in Ns]
        if args.metric == "biaffine":
            Ws = [Variable(W.cuda(), requires_grad=True) for W in Ws]
            Vs = [Variable(V.cuda(), requires_grad=True) for V in Vs]

    # Ms = [ U.mm(U.t()) for U in Us ]

    # say("\nEncoder: {}\n".format(encoder))
    for i, classifier in enumerate(classifiers):
        say("Classifier-{}: {}\n".format(i, classifier))
    say("Critic: {}\n".format(critic))

    requires_grad = lambda x: x.requires_grad
    # task_params = list(encoder.parameters())
    task_params = []
    for encoder in encoders:
        task_params += encoder.parameters()
    task_params += encoder_dst.parameters()
    for classifier in classifiers:
        task_params += list(classifier.parameters())
    task_params += classifier_dst.parameters()
    task_params += classifier_mix.parameters()
    # task_params += [classifier_mix.data]
    task_params += list(critic.parameters())
    task_params += Us
    task_params += Ps
    task_params += Ns
    if args.metric == "biaffine":
        task_params += Ws
        task_params += Vs

    optim_model = optim.Adagrad(  # use adagrad instead of adam
        filter(requires_grad, task_params),
        lr=args.lr,
        weight_decay=1e-4)

    say("Training will begin from scratch\n")

    best_dev = 0
    best_test = 0
    iter_cnt = 0

    # encoder.load_state_dict(torch.load(os.path.join(settings.OUT_VENUE_DIR, "venue-matching-cnn.mdl")))

    for epoch in range(args.max_epoch):
        say("epoch: {}\n".format(epoch))
        if args.metric == "biaffine":
            mats = [Us, Ws, Vs]
        else:
            mats = [Us, Ps, Ns]

        iter_cnt = train_epoch(
            iter_cnt, [encoders, encoder_dst],
            [classifiers, classifier_dst, classifier_mix], critic, mats,
            [train_loaders, train_loader_dst, unl_loader, valid_loader], args,
            optim_model, epoch)

        # thrs, metrics_val, src_weights_val = evaluate_cross(
        #     encoder, classifiers,
        #     mats,
        #     [train_loaders, valid_loaders_src],
        #     return_best_thrs=True,
        #     args=args
        # )
        #
        # _, metrics_test, src_weights_test = evaluate_cross(
        #     encoder, classifiers,
        #     mats,
        #     [train_loaders, test_loaders_src],
        #     return_best_thrs=False,
        #     args=args,
        #     thr=thrs
        # )

        thr, metrics_val = evaluate(
            epoch, [encoders, encoder_dst],
            [classifiers, classifier_dst, classifier_mix], mats,
            [train_loaders, valid_loader], True, args)
        # say("Dev accuracy/oracle: {:.4f}/{:.4f}\n".format(curr_dev, oracle_curr_dev))
        _, metrics_test = evaluate(
            epoch, [encoders, encoder_dst],
            [classifiers, classifier_dst, classifier_mix],
            mats, [train_loaders, test_loader],
            False,
            args,
            thr=thr)
        # say("Test accuracy/oracle: {:.4f}/{:.4f}\n".format(curr_test, oracle_curr_test))

        # if curr_dev >= best_dev:
        #     best_dev = curr_dev
        #     best_test = curr_test
        #     print(confusion_mat)
        #     if args.save_model:
        #         say(colored("Save model to {}\n".format(args.save_model + ".best"), 'red'))
        #         torch.save([encoder, classifiers, Us, Ps, Ns], args.save_model + ".best")
        say("\n")

    say(colored("Best test accuracy {:.4f}\n".format(best_test), 'red'))
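
The commented-out "Ms.append(U.mm(U.t()))" lines above suggest the learned metric is the low-rank product M = U U^T. A sketch of the squared Mahalanobis distance this implies (an assumption, not code from this page):

    def mahalanobis_sq(h1, h2, U):
        # d(h1, h2) = (h1 - h2)^T (U U^T) (h1 - h2) = ||U^T (h1 - h2)||^2
        diff = (h1 - h2) @ U  # [batch, m_rank]
        return (diff * diff).sum(dim=1)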
Example #6
def train_single_domain_transfer(args, wf, src, repeat_seed):
    tb_dir = 'runs/{}_sup_base_{}_source_{}_train_num_{}_tune_{}_{}'.format(
        args.test, args.base_model, src, args.train_num, args.n_tune,
        repeat_seed)
    if os.path.exists(tb_dir) and os.path.isdir(tb_dir):
        shutil.rmtree(tb_dir)
    writer = SummaryWriter(tb_dir)

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    say('cuda is available %s\n' % args.cuda)

    np.random.seed(args.seed + repeat_seed)
    torch.manual_seed(args.seed + repeat_seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed + repeat_seed)

    pretrain_emb = torch.load(
        os.path.join(settings.OUT_DIR, "rnn_init_word_emb.emb"))

    src_model_dir = os.path.join(settings.OUT_DIR, src)

    if args.base_model == "cnn":
        encoder_src = CNNMatchModel(input_matrix_size1=args.matrix_size1,
                                    input_matrix_size2=args.matrix_size2,
                                    mat1_channel1=args.mat1_channel1,
                                    mat1_kernel_size1=args.mat1_kernel_size1,
                                    mat1_channel2=args.mat1_channel2,
                                    mat1_kernel_size2=args.mat1_kernel_size2,
                                    mat1_hidden=args.mat1_hidden,
                                    mat2_channel1=args.mat2_channel1,
                                    mat2_kernel_size1=args.mat2_kernel_size1,
                                    mat2_hidden=args.mat2_hidden)
    elif args.base_model == "rnn":
        encoder_src = BiLSTM(pretrain_emb=pretrain_emb,
                             vocab_size=args.max_vocab_size,
                             embedding_size=args.embedding_size,
                             hidden_size=args.hidden_size,
                             dropout=args.dropout)
    else:
        raise NotImplementedError
    if args.cuda:
        encoder_src.load_state_dict(
            torch.load(
                os.path.join(
                    src_model_dir,
                    "{}-match-best-now-train-num-{}-try-{}.mdl".format(
                        args.base_model, args.train_num, repeat_seed))))
    else:
        encoder_src.load_state_dict(
            torch.load(os.path.join(
                src_model_dir,
                "{}-match-best-now-train-num-{}-try-{}.mdl".format(
                    args.base_model, args.train_num, repeat_seed)),
                       map_location=torch.device('cpu')))

    dst_model_dir = os.path.join(settings.OUT_DIR, args.test)
    if args.base_model == "cnn":
        encoder_dst_pretrain = CNNMatchModel(
            input_matrix_size1=args.matrix_size1,
            input_matrix_size2=args.matrix_size2,
            mat1_channel1=args.mat1_channel1,
            mat1_kernel_size1=args.mat1_kernel_size1,
            mat1_channel2=args.mat1_channel2,
            mat1_kernel_size2=args.mat1_kernel_size2,
            mat1_hidden=args.mat1_hidden,
            mat2_channel1=args.mat2_channel1,
            mat2_kernel_size1=args.mat2_kernel_size1,
            mat2_hidden=args.mat2_hidden)
    elif args.base_model == "rnn":
        encoder_dst_pretrain = BiLSTM(pretrain_emb=pretrain_emb,
                                      vocab_size=args.max_vocab_size,
                                      embedding_size=args.embedding_size,
                                      hidden_size=args.hidden_size,
                                      dropout=args.dropout)
    else:
        raise NotImplementedError

    encoder_dst_pretrain.load_state_dict(
        torch.load(
            os.path.join(
                dst_model_dir,
                "{}-match-best-now-train-num-{}-try-{}.mdl".format(
                    args.base_model, args.train_num, repeat_seed))))

    # args = argparser.parse_args()
    say(args)
    print()

    say("Transferring from %s to %s\n" % (src, args.test))

    if args.base_model == "cnn":
        train_dataset_dst = ProcessedCNNInputDataset(args.test, "train",
                                                     args.train_num)
        valid_dataset = ProcessedCNNInputDataset(args.test, "valid")
        test_dataset = ProcessedCNNInputDataset(args.test, "test")

    elif args.base_model == "rnn":
        train_dataset_dst = ProcessedRNNInputDataset(args.test, "train",
                                                     args.train_num)
        valid_dataset = ProcessedRNNInputDataset(args.test, "valid")
        test_dataset = ProcessedRNNInputDataset(args.test, "test")
    else:
        raise NotImplementedError

    print("train num", len(train_dataset_dst))

    train_loader_dst = data.DataLoader(train_dataset_dst,
                                       batch_size=args.batch_size,
                                       shuffle=False,
                                       num_workers=0)

    valid_loader = data.DataLoader(valid_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=0)

    test_loader = data.DataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=0)

    say("Corpus loaded.\n")

    if args.n_tune == 3:
        classifier = nn.Sequential(
            nn.Linear(encoder_src.n_out, 64),
            nn.ReLU(),
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.Linear(16, 2),
        )
    elif args.n_tune == 1:
        classifier = nn.Sequential(nn.Linear(16, 2))
    # elif args.n_tune == 0:
    #     classifier = None
    else:
        raise NotImplementedError

    if args.cuda:
        encoder_dst_pretrain.cuda()
        encoder_src.cuda()
        classifier.cuda()

    requires_grad = lambda x: x.requires_grad
    task_params = list(classifier.parameters())

    if args.base_model == "cnn":
        optim_model = optim.Adagrad(filter(requires_grad, task_params),
                                    lr=args.lr,
                                    weight_decay=args.weight_decay)
    elif args.base_model == "rnn":
        optim_model = optim.Adam(filter(requires_grad, task_params),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    else:
        raise NotImplementedError

    say("Training will begin from scratch\n")

    iter_cnt = 0
    min_loss_val = None
    max_auc_val = None
    best_test_results = None
    weights_sources = None
    model_dir = os.path.join(settings.OUT_DIR, args.test)

    cur_src_idx = source_to_idx[src]

    for epoch in range(args.max_epoch):
        print("training epoch", epoch)

        iter_cnt = train_epoch(iter_cnt, [encoder_src, encoder_dst_pretrain],
                               classifier, train_loader_dst, args, optim_model,
                               epoch, writer)

        thr, metrics_val = evaluate(epoch, [encoder_src, encoder_dst_pretrain],
                                    classifier, valid_loader, True, args,
                                    writer)

        _, metrics_test = evaluate(epoch, [encoder_src, encoder_dst_pretrain],
                                   classifier,
                                   test_loader,
                                   False,
                                   args,
                                   writer,
                                   thr=thr)

        if min_loss_val is None or min_loss_val > metrics_val[0]:
            print("change val loss from {} to {}".format(
                min_loss_val, metrics_val[0]))
            min_loss_val = metrics_val[0]
            best_test_results = metrics_test
            torch.save(
                classifier,
                os.path.join(
                    model_dir,
                    "{}_{}_classifier_from_src_{}_train_num_{}_try_{}.mdl".
                    format(args.test, args.base_model, cur_src_idx,
                           args.train_num, repeat_seed)))

    print()
    print(
        "src: {}, min valid loss {:.4f}, best test metrics: AUC: {:.2f}, Prec: {:.2f}, Rec: {:.2f}, F1: {:.2f}\n"
        .format(src, min_loss_val, best_test_results[1] * 100,
                best_test_results[2] * 100, best_test_results[3] * 100,
                best_test_results[4] * 100))

    wf.write(
        "from src {}, min valid loss {:.4f}, best test metrics: AUC: {:.2f}, Prec: {:.2f}, Rec: {:.2f}, F1: {:.2f}\n"
        .format(src, min_loss_val, best_test_results[1] * 100,
                best_test_results[2] * 100, best_test_results[3] * 100,
                best_test_results[4] * 100))
    writer.close()

    return min_loss_val, best_test_results[1:]
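
A hypothetical caller that sweeps the sources and keeps the best transfer by validation loss (the driver itself is not shown on this page):

    best = None
    for src in args.train.split(','):
        val_loss, test_metrics = train_single_domain_transfer(
            args, wf, src, repeat_seed=0)
        if best is None or val_loss < best[0]:
            best = (val_loss, src, test_metrics)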
Example #7
def train_one_time(args, wf, repeat_seed):
    tb_dir = 'runs/{}_cnn_train_num_{}_{}'.format(args.entity_type,
                                                  args.train_num, repeat_seed)
    if os.path.exists(tb_dir) and os.path.isdir(tb_dir):
        shutil.rmtree(tb_dir)
    writer = SummaryWriter(tb_dir)

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    logger.info('cuda is available %s', args.cuda)

    np.random.seed(args.seed + repeat_seed)
    torch.manual_seed(args.seed + repeat_seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed + repeat_seed)

    dataset = ProcessedCNNInputDataset(args.entity_type, "train",
                                       args.train_num, repeat_seed)
    dataset_valid = ProcessedCNNInputDataset(args.entity_type, "valid")
    dataset_test = ProcessedCNNInputDataset(args.entity_type, "test")
    N = len(dataset)
    N_valid = len(dataset_valid)
    N_test = len(dataset_test)
    print("n_train", N)
    train_loader = DataLoader(dataset,
                              batch_size=args.batch,
                              sampler=ChunkSampler(N, 0))
    valid_loader = DataLoader(dataset_valid,
                              batch_size=args.batch,
                              sampler=ChunkSampler(N_valid, 0))
    test_loader = DataLoader(dataset_test,
                             batch_size=args.batch,
                             sampler=ChunkSampler(N_test, 0))

    model = CNNMatchModel(input_matrix_size1=args.matrix_size1,
                          input_matrix_size2=args.matrix_size2,
                          mat1_channel1=args.mat1_channel1,
                          mat1_kernel_size1=args.mat1_kernel_size1,
                          mat1_channel2=args.mat1_channel2,
                          mat1_kernel_size2=args.mat1_kernel_size2,
                          mat1_hidden=args.mat1_hidden,
                          mat2_channel1=args.mat2_channel1,
                          mat2_kernel_size1=args.mat2_kernel_size1,
                          mat2_hidden=args.mat2_hidden)
    model = model.float()

    if args.cuda:
        model.cuda()

    optimizer = optim.Adagrad(
        model.parameters(),
        lr=args.lr,
        # initial_accumulator_value=args.initial_accumulator_value,
        weight_decay=args.weight_decay)
    t_total = time.time()
    logger.info("training...")

    model_dir = join(settings.OUT_DIR, args.entity_type)
    os.makedirs(model_dir, exist_ok=True)
    n_paras = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of paras:", n_paras)

    evaluate(0, test_loader, model, writer, thr=None, args=args)

    min_loss_val = None
    best_test_metrics = None

    for epoch in range(args.epochs):
        print("training epoch", epoch)
        metrics_val, metrics_test = train(epoch,
                                          train_loader,
                                          valid_loader,
                                          test_loader,
                                          model,
                                          optimizer,
                                          writer,
                                          args=args)
        if metrics_val is not None:
            if min_loss_val is None or min_loss_val > metrics_val[0]:
                min_loss_val = metrics_val[0]
                best_test_metrics = metrics_test
                torch.save(
                    model.state_dict(),
                    join(
                        model_dir,
                        "cnn-match-best-now-train-num-{}-try-{}.mdl".format(
                            args.train_num, repeat_seed)))

    logger.info("optimization Finished!")
    logger.info("total time elapsed: {:.4f}s".format(time.time() - t_total))

    print(
        "min valid loss {:.4f}, best test metrics: AUC: {:.2f}, Prec: {:.4f}, Rec: {:.4f}, F1: {:.4f}"
        .format(min_loss_val, best_test_metrics[1], best_test_metrics[2],
                best_test_metrics[3], best_test_metrics[4]))

    # with open(join(model_dir, "{}_cnn_train_num_{}_results.txt".format(args.entity_type, args.train_num)), "w") as wf:
    wf.write(
        "min valid loss {:.4f}, best test metrics: AUC: {:.2f}, Prec: {:.2f}, Rec: {:.2f}, F1: {:.2f}\n\n"
        .format(min_loss_val, best_test_metrics[1] * 100,
                best_test_metrics[2] * 100, best_test_metrics[3] * 100,
                best_test_metrics[4] * 100))
    # wf.write(json.dumps(vars(args)) + "\n")

    writer.close()
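
The commented-out "with open(...)" line above shows the intended results file; a matching hypothetical caller (the repeat count is an assumption):

    model_dir = join(settings.OUT_DIR, args.entity_type)
    with open(join(model_dir, "{}_cnn_train_num_{}_results.txt".format(
            args.entity_type, args.train_num)), "w") as wf:
        for repeat_seed in range(5):
            train_one_time(args, wf, repeat_seed)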