Example #1
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', default='train_wiki', help='train file')
    parser.add_argument('--val', default='val_wiki', help='val file')
    parser.add_argument('--test', default='test_wiki', help='test file')
    parser.add_argument('--adv', default=None, help='adv file')
    parser.add_argument('--trainN', default=10, type=int, help='N in train')
    parser.add_argument('--N', default=5, type=int, help='N way')
    parser.add_argument('--K', default=5, type=int, help='K shot')
    parser.add_argument('--Q',
                        default=5,
                        type=int,
                        help='Num of query per class')
    parser.add_argument('--batch_size', default=4, type=int, help='batch size')
    parser.add_argument('--train_iter',
                        default=30000,
                        type=int,
                        help='num of iters in training')
    parser.add_argument('--val_iter',
                        default=1000,
                        type=int,
                        help='num of iters in validation')
    parser.add_argument('--test_iter',
                        default=10000,
                        type=int,
                        help='num of iters in testing')
    parser.add_argument('--val_step',
                        default=2000,
                        type=int,
                        help='val after training how many iters')
    parser.add_argument('--model', default='proto', help='model name')
    parser.add_argument('--encoder',
                        default='cnn',
                        help='encoder: cnn or bert or roberta')
    parser.add_argument('--max_length',
                        default=128,
                        type=int,
                        help='max length')
    parser.add_argument('--lr', default=1e-1, type=float, help='learning rate')
    parser.add_argument('--weight_decay',
                        default=1e-5,
                        type=float,
                        help='weight decay')
    parser.add_argument('--dropout',
                        default=0.0,
                        type=float,
                        help='dropout rate')
    parser.add_argument('--na_rate',
                        default=0,
                        type=int,
                        help='NA rate (NA = Q * na_rate)')
    parser.add_argument('--grad_iter',
                        default=1,
                        type=int,
                        help='accumulate gradient every x iterations')
    parser.add_argument('--optim', default='sgd', help='sgd / adam / adamw')
    parser.add_argument('--hidden_size',
                        default=230,
                        type=int,
                        help='hidden size')
    parser.add_argument('--load_ckpt', default=None, help='load ckpt')
    parser.add_argument('--save_ckpt', default=None, help='save ckpt')
    parser.add_argument('--fp16',
                        action='store_true',
                        help='use nvidia apex fp16')
    parser.add_argument('--only_test', action='store_true', help='only test')

    # only for bert / roberta
    parser.add_argument('--pair', action='store_true', help='use pair model')
    parser.add_argument('--pretrain_ckpt',
                        default=None,
                        help='bert / roberta pre-trained checkpoint')
    parser.add_argument(
        '--cat_entity_rep',
        action='store_true',
        help='concatenate entity representation as sentence rep')

    # only for prototypical networks
    parser.add_argument('--dot',
                        action='store_true',
                        help='use dot instead of L2 distance for proto')

    opt = parser.parse_args()
    trainN = opt.trainN
    N = opt.N
    K = opt.K
    Q = opt.Q
    batch_size = opt.batch_size
    model_name = opt.model
    encoder_name = opt.encoder
    max_length = opt.max_length

    print("{}-way-{}-shot Few-Shot Relation Classification".format(N, K))
    print("model: {}".format(model_name))
    print("encoder: {}".format(encoder_name))
    print("max_length: {}".format(max_length))

    if encoder_name == 'cnn':
        try:
            glove_mat = np.load('./pretrain/glove/glove_mat.npy')
            glove_word2id = json.load(
                open('./pretrain/glove/glove_word2id.json'))
        except FileNotFoundError:
            raise Exception(
                "Cannot find glove files. Run glove/download_glove.sh to download glove files."
            )
        sentence_encoder = CNNSentenceEncoder(glove_mat, glove_word2id,
                                              max_length)
    elif encoder_name == 'bert':
        pretrain_ckpt = opt.pretrain_ckpt or 'bert-base-uncased'
        if opt.pair:
            sentence_encoder = BERTPAIRSentenceEncoder(pretrain_ckpt,
                                                       max_length)
        else:
            sentence_encoder = BERTSentenceEncoder(
                pretrain_ckpt, max_length, cat_entity_rep=opt.cat_entity_rep)
    elif encoder_name == 'roberta':
        pretrain_ckpt = opt.pretrain_ckpt or 'roberta-base'
        if opt.pair:
            sentence_encoder = RobertaPAIRSentenceEncoder(
                pretrain_ckpt, max_length)
        else:
            sentence_encoder = RobertaSentenceEncoder(
                pretrain_ckpt, max_length, cat_entity_rep=opt.cat_entity_rep)
    else:
        raise NotImplementedError

    if opt.pair:
        train_data_loader = get_loader_pair(opt.train,
                                            sentence_encoder,
                                            N=trainN,
                                            K=K,
                                            Q=Q,
                                            na_rate=opt.na_rate,
                                            batch_size=batch_size,
                                            encoder_name=encoder_name)
        val_data_loader = get_loader_pair(opt.val,
                                          sentence_encoder,
                                          N=N,
                                          K=K,
                                          Q=Q,
                                          na_rate=opt.na_rate,
                                          batch_size=batch_size,
                                          encoder_name=encoder_name)
        test_data_loader = get_loader_pair(opt.test,
                                           sentence_encoder,
                                           N=N,
                                           K=K,
                                           Q=Q,
                                           na_rate=opt.na_rate,
                                           batch_size=batch_size,
                                           encoder_name=encoder_name)
    else:
        train_data_loader = get_loader(opt.train,
                                       sentence_encoder,
                                       N=trainN,
                                       K=K,
                                       Q=Q,
                                       na_rate=opt.na_rate,
                                       batch_size=batch_size)
        val_data_loader = get_loader(opt.val,
                                     sentence_encoder,
                                     N=N,
                                     K=K,
                                     Q=Q,
                                     na_rate=opt.na_rate,
                                     batch_size=batch_size)
        test_data_loader = get_loader(opt.test,
                                      sentence_encoder,
                                      N=N,
                                      K=K,
                                      Q=Q,
                                      na_rate=opt.na_rate,
                                      batch_size=batch_size)
        if opt.adv:
            adv_data_loader = get_loader_unsupervised(opt.adv,
                                                      sentence_encoder,
                                                      N=trainN,
                                                      K=K,
                                                      Q=Q,
                                                      na_rate=opt.na_rate,
                                                      batch_size=batch_size)

    if opt.optim == 'sgd':
        pytorch_optim = optim.SGD
    elif opt.optim == 'adam':
        pytorch_optim = optim.Adam
    elif opt.optim == 'adamw':
        from transformers import AdamW
        pytorch_optim = AdamW
    else:
        raise NotImplementedError
    if opt.adv:
        d = Discriminator(opt.hidden_size)
        framework = FewShotREFramework(train_data_loader,
                                       val_data_loader,
                                       test_data_loader,
                                       adv_data_loader,
                                       adv=opt.adv,
                                       d=d)
    else:
        framework = FewShotREFramework(train_data_loader, val_data_loader,
                                       test_data_loader)

    prefix = '-'.join(
        [model_name, encoder_name, opt.train, opt.val,
         str(N), str(K)])
    if opt.adv is not None:
        prefix += '-adv_' + opt.adv
    if opt.na_rate != 0:
        prefix += '-na{}'.format(opt.na_rate)
    if opt.dot:
        prefix += '-dot'
    if opt.cat_entity_rep:
        prefix += '-catentity'

    if model_name == 'proto':
        model = Proto(sentence_encoder, dot=opt.dot)
    elif model_name == 'gnn':
        model = GNN(sentence_encoder, N, hidden_size=opt.hidden_size)
    elif model_name == 'snail':
        model = SNAIL(sentence_encoder, N, K, hidden_size=opt.hidden_size)
    elif model_name == 'metanet':
        model = MetaNet(N, K, sentence_encoder.embedding, max_length)
    elif model_name == 'siamese':
        model = Siamese(sentence_encoder,
                        hidden_size=opt.hidden_size,
                        dropout=opt.dropout)
    elif model_name == 'pair':
        model = Pair(sentence_encoder, hidden_size=opt.hidden_size)
    else:
        raise NotImplementedError

    if not os.path.exists('checkpoint'):
        os.mkdir('checkpoint')
    ckpt = 'checkpoint/{}.pth.tar'.format(prefix)
    if opt.save_ckpt:
        ckpt = opt.save_ckpt

    if torch.cuda.is_available():
        model.cuda()

    if not opt.only_test:
        if encoder_name in ['bert', 'roberta']:
            bert_optim = True
        else:
            bert_optim = False

        framework.train(model,
                        prefix,
                        batch_size,
                        trainN,
                        N,
                        K,
                        Q,
                        pytorch_optim=pytorch_optim,
                        load_ckpt=opt.load_ckpt,
                        save_ckpt=ckpt,
                        na_rate=opt.na_rate,
                        val_step=opt.val_step,
                        fp16=opt.fp16,
                        pair=opt.pair,
                        train_iter=opt.train_iter,
                        val_iter=opt.val_iter,
                        bert_optim=bert_optim)
    else:
        ckpt = opt.load_ckpt

    acc = framework.eval(model,
                         batch_size,
                         N,
                         K,
                         Q,
                         opt.test_iter,
                         na_rate=opt.na_rate,
                         ckpt=ckpt,
                         pair=opt.pair)
    print("RESULT: %.2f" % (acc * 100))
Example #2
def main():
    parser = argparse.ArgumentParser()

    # model-related arguments
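    # NOTE: argparse's type=bool converts with bool(str), so any non-empty
    # string parses as True ("--do_train False" still yields True); see the
    # str2bool sketch after this example.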
    parser.add_argument('--do_train', default=True, type=bool, help='do train')
    parser.add_argument('--do_eval', default=True, type=bool, help='do eval')
    parser.add_argument('--do_predict',
                        default=False,
                        type=bool,
                        help='do predict')
    parser.add_argument('--do_cn_eval',
                        default=False,
                        type=bool,
                        help='do CN eval')

    parser.add_argument(
        '--proto_emb',
        default=False,
        help='Get root cause proto emb or sentence emb. Requires do_predict=True'
    )
    parser.add_argument('--train_file',
                        default='./data/source_add_CN_V2.xlsx',
                        help='source file')

    parser.add_argument('--trainN', default=5, type=int, help='N in train')
    parser.add_argument('--N', default=5, type=int, help='N way')
    parser.add_argument('--K', default=3, type=int, help='K shot')
    parser.add_argument('--Q',
                        default=2,
                        type=int,
                        help='Num of query per class')
    parser.add_argument('--batch_size', default=8, type=int, help='batch size')
    parser.add_argument('--train_iter',
                        default=10000,
                        type=int,
                        help='num of iters in training')
    parser.add_argument('--warmup_rate', default=0.1, type=float)
    parser.add_argument('--max_length',
                        default=128,
                        type=int,
                        help='max length')
    parser.add_argument('--lr', default=1e-5, type=float, help='learning rate')
    parser.add_argument('--dropout',
                        default=0.0,
                        type=float,
                        help='dropout rate')
    parser.add_argument('--seed', default=100, type=int)  # 100

    # saving and loading
    parser.add_argument('--load_ckpt',
                        default='./check_points/model_54000.bin',
                        help='load ckpt')
    parser.add_argument('--save_ckpt',
                        default='./check_points/',
                        help='save ckpt')
    parser.add_argument('--save_emb',
                        default='./data/emb.json',
                        help='save embedding')
    parser.add_argument('--save_root_emb',
                        default='./data/root_emb.json',
                        help='save embedding')

    parser.add_argument('--use_cuda', default=True, help='whether to use cuda')
    parser.add_argument('--eval_step', default=100, type=int)
    parser.add_argument('--save_step', default=500, type=int)
    parser.add_argument('--threshold', default=5, type=int)

    # bert pretrain
    parser.add_argument("--vocab_file",
                        default="./pretrain/vocab.txt",
                        type=str,
                        help="Init vocab to resume training from.")
    parser.add_argument("--config_path",
                        default="./pretrain/bert_config.json",
                        type=str,
                        help="Init config to resume training from.")
    parser.add_argument("--init_checkpoint",
                        default="./pretrain/pytorch_model.bin",
                        type=str,
                        help="Init checkpoint to resume training from.")

    opt = parser.parse_args()
    trainN = opt.trainN
    K = opt.K
    Q = opt.Q
    batch_size = opt.batch_size
    max_length = opt.max_length

    logger.info("{}-way-{}-shot Few-Shot Dignose".format(trainN, K))
    logger.info("max_length: {}".format(max_length))

    random.seed(opt.seed)
    np.random.seed(opt.seed)
    torch.manual_seed(opt.seed)

    if not os.path.exists(opt.save_ckpt):
        os.mkdir(opt.save_ckpt)

    bert_tokenizer = BertTokenizer.from_pretrained(opt.vocab_file)
    bert_config = BertConfig.from_pretrained(opt.config_path)
    bert_model = BertModel.from_pretrained(opt.init_checkpoint,
                                           config=bert_config)
    model = Proto(bert_model, opt)
    if opt.use_cuda:
        model.cuda()

    if opt.do_train:
        train_data, eval_data = read_data(opt.train_file, opt.threshold)
        train_data_loader = get_loader(train_data,
                                       bert_tokenizer,
                                       max_length,
                                       N=trainN,
                                       K=K,
                                       Q=Q,
                                       batch_size=batch_size)

        framework = FewShotREFramework(tokenizer=bert_tokenizer,
                                       train_data_loader=train_data_loader,
                                       train_data=train_data,
                                       eval_data=eval_data)
        framework.train(model, batch_size, trainN, K, Q, opt)

    if opt.do_eval:
        train_data, eval_data = read_data(opt.train_file, opt.threshold)
        state_dict = torch.load(opt.load_ckpt)
        own_state = bert_model.state_dict()
        for name, param in state_dict.items():
            name = name.replace('sentence_encoder.module.', '')
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        step = opt.load_ckpt.split('/')[-1].replace('model_', '').split('.')[0]
        bert_model.eval()
        train_data_emb, train_rc_emb = util.get_emb(bert_model, bert_tokenizer,
                                                    train_data, opt)
        eval_data_emb, _ = util.get_emb(bert_model, bert_tokenizer, eval_data,
                                        opt)
        acc1 = util.single_acc(train_data_emb, eval_data_emb)
        acc2 = util.proto_acc(train_rc_emb, eval_data_emb)
        acc3 = util.policy_acc(train_data_emb, eval_data_emb)
        logger.info(
            "single eval accuracy: [top1: %.4f] [top3: %.4f] [top5: %.4f]" %
            (acc1[0], acc1[1], acc1[2]))
        logger.info(
            "proto eval accuracy: [top1: %.4f] [top3: %.4f] [top5: %.4f]" %
            (acc2[0], acc2[1], acc2[2]))
        logger.info(
            "policy eval accuracy: [top1: %.4f] [top3: %.4f] [top5: %.4f]" %
            (acc3[0], acc3[1], acc3[2]))

        with open('./data/train_emb_%s.json' % step, 'w',
                  encoding='utf8') as f:
            json.dump(train_data_emb, f, ensure_ascii=False)

        with open('./data/train_rc_emb_%s.json' % step, 'w',
                  encoding='utf8') as f:
            json.dump(train_rc_emb, f, ensure_ascii=False)

        with open('./data/eval_emb_%s.json' % step, 'w', encoding='utf8') as f:
            json.dump(eval_data_emb, f, ensure_ascii=False)

    if opt.do_predict:
        test_data = read_data(opt.train_file, opt.threshold, False)
        # predict proto emb or sentence emb
        state_dict = torch.load(opt.load_ckpt)
        own_state = bert_model.state_dict()
        for name, param in state_dict.items():
            name = name.replace('sentence_encoder.module.', '')
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        bert_model.eval()
        id_to_emd, root_cause_emb = util.get_emb(bert_model, bert_tokenizer,
                                                 test_data, opt)

        if opt.save_emb and opt.save_root_emb:
            with open(opt.save_emb, 'w', encoding='utf8') as f:
                json.dump(id_to_emd, f, ensure_ascii=False)

            with open(opt.save_root_emb, 'w', encoding='utf8') as f:
                json.dump(root_cause_emb, f, ensure_ascii=False)
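
The type=bool flags at the top of this example silently misparse command-line input, since bool("False") is True. A common workaround (a sketch, not code from this repo) is a small converter passed as type=:

import argparse

def str2bool(v):
    # accept the usual command-line spellings of true/false
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got %r' % v)

# usage: parser.add_argument('--do_train', default=True, type=str2bool)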
Example #3
model_name = 'proto'  # assumed defaults; the top of this snippet is truncated in the source
N = 5
K = 5
if len(sys.argv) > 1:
    model_name = sys.argv[1]
if len(sys.argv) > 2:
    N = int(sys.argv[2])
if len(sys.argv) > 3:
    K = int(sys.argv[3])

print("{}-way-{}-shot Few-Shot Relation Classification".format(N, K))
print("Model: {}".format(model_name))

max_length = 40
train_data_loader = JSONFileDataLoader('./data/train.json', './data/glove.6B.50d.json', max_length=max_length)
val_data_loader = JSONFileDataLoader('./data/val.json', './data/glove.6B.50d.json', max_length=max_length)
test_data_loader = JSONFileDataLoader('./data/test.json', './data/glove.6B.50d.json', max_length=max_length)

framework = FewShotREFramework(train_data_loader, val_data_loader, test_data_loader)
sentence_encoder = CNNSentenceEncoder(train_data_loader.word_vec_mat, max_length)

if model_name == 'proto':
    model = Proto(sentence_encoder)
    framework.train(model, model_name, 4, 20, N, K, 5)
elif model_name == 'gnn':
    model = GNN(sentence_encoder, N)
    framework.train(model, model_name, 2, N, N, K, 1, learning_rate=1e-3, weight_decay=0, optimizer=optim.Adam)
elif model_name == 'snail':
    print("HINT: SNAIL works only in PyTorch 0.3.1")
    model = SNAIL(sentence_encoder, N, K)
    framework.train(model, model_name, 25, N, N, K, 1, learning_rate=1e-2, weight_decay=0, optimizer=optim.SGD)
elif model_name == 'metanet':
    model = MetaNet(N, K, train_data_loader.word_vec_mat, max_length)
    framework.train(model, model_name, 1, N, N, K, 1, learning_rate=5e-3, weight_decay=0, optimizer=optim.Adam, train_iter=300000)
else:
    raise NotImplementedError
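
Examples #1 and #3 both train in N-way K-shot episodes with Q query instances per class. A self-contained sketch of how such an episode can be sampled (the function name and data layout are assumptions, not the repo's loader):

import random

def sample_episode(data, N, K, Q):
    """data: dict mapping each relation name to a list of instances."""
    classes = random.sample(sorted(data), N)             # pick N relations
    support, query = [], []
    for label, rel in enumerate(classes):
        instances = random.sample(data[rel], K + Q)
        support += [(x, label) for x in instances[:K]]   # K shots per class
        query += [(x, label) for x in instances[K:]]     # Q queries per class
    return support, query

# usage
data = {'rel%d' % i: list(range(20)) for i in range(10)}
support, query = sample_episode(data, N=5, K=5, Q=5)
print(len(support), len(query))  # 25 25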

Example #4
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', default='train_wiki', help='train file')
    parser.add_argument('--val', default='val_wiki', help='val file')
    parser.add_argument('--test', default='test_wiki', help='test file')
    parser.add_argument('--adv', default=None, help='adv file')
    parser.add_argument('--trainN', default=10, type=int, help='N in train')
    parser.add_argument('--N', default=5, type=int, help='N way')
    parser.add_argument('--K', default=5, type=int, help='K shot')
    parser.add_argument('--Q',
                        default=5,
                        type=int,
                        help='Num of query per class')
    parser.add_argument('--batch_size', default=4, type=int, help='batch size')
    parser.add_argument('--train_iter',
                        default=20000,
                        type=int,
                        help='num of iters in training')
    parser.add_argument('--val_iter',
                        default=1000,
                        type=int,
                        help='num of iters in validation')
    parser.add_argument('--test_iter',
                        default=2000,
                        type=int,
                        help='num of iters in testing')
    parser.add_argument('--val_step',
                        default=2000,
                        type=int,
                        help='val after training how many iters')
    parser.add_argument('--model', default='proto', help='model name')
    parser.add_argument('--encoder',
                        default='cnn',
                        help='encoder: cnn or bert')
    parser.add_argument('--max_length',
                        default=128,
                        type=int,
                        help='max length')
    parser.add_argument('--lr', default=1e-1, type=float, help='learning rate')
    parser.add_argument('--weight_decay',
                        default=1e-5,
                        type=float,
                        help='weight decay')
    parser.add_argument('--dropout',
                        default=0.0,
                        type=float,
                        help='dropout rate')
    parser.add_argument('--na_rate',
                        default=0,
                        type=int,
                        help='NA rate (NA = Q * na_rate)')
    parser.add_argument('--grad_iter',
                        default=1,
                        type=int,
                        help='accumulate gradient every x iterations')
    parser.add_argument('--optim',
                        default='sgd',
                        help='sgd / adam / bert_adam')
    parser.add_argument('--hidden_size',
                        default=230,
                        type=int,
                        help='hidden size')
    parser.add_argument('--load_ckpt', default=None, help='load ckpt')
    parser.add_argument('--save_ckpt', default=None, help='save ckpt')
    parser.add_argument('--fp16',
                        action='store_true',
                        help='use nvidia apex fp16')
    parser.add_argument('--only_test', action='store_true', help='only test')
    parser.add_argument('--pair', action='store_true', help='use pair model')
    parser.add_argument('--language', type=str, default='eng', help='language')
    parser.add_argument('--sup_cost',
                        type=int,
                        default=0,
                        help='use sup classifier')

    opt = parser.parse_args()
    trainN = opt.trainN
    N = opt.N
    K = opt.K
    Q = opt.Q
    batch_size = opt.batch_size
    model_name = opt.model
    encoder_name = opt.encoder
    max_length = opt.max_length
    sup_cost = bool(opt.sup_cost)
    print("sup_cost: {}".format(sup_cost))

    print("{}-way-{}-shot Few-Shot Relation Classification".format(N, K))
    print("model: {}".format(model_name))
    print("encoder: {}".format(encoder_name))
    print("max_length: {}".format(max_length))

    embsize = 50
    if opt.language == 'chn':
        embsize = 100

    if encoder_name == 'cnn':
        try:
            if opt.language == 'chn':
                glove_mat = np.load('./pretrain/chinese_emb/emb.npy')
                glove_word2id = json.load(
                    open('./pretrain/chinese_emb/word2id.json'))
            else:
                glove_mat = np.load('./pretrain/glove/glove_mat.npy')
                glove_word2id = json.load(
                    open('./pretrain/glove/glove_word2id.json'))
        except FileNotFoundError:
            raise Exception(
                "Cannot find glove files. Run glove/download_glove.sh to download glove files."
            )
        sentence_encoder = CNNSentenceEncoder(glove_mat,
                                              glove_word2id,
                                              max_length,
                                              word_embedding_dim=embsize)
    elif encoder_name == 'bert':
        if opt.pair:
            if opt.language == 'chn':
                sentence_encoder = BERTPAIRSentenceEncoder(
                    'bert-base-chinese',  #'./pretrain/bert-base-uncased',
                    max_length)
            else:
                sentence_encoder = BERTPAIRSentenceEncoder(
                    'bert-base-uncased', max_length)
        else:
            if opt.language == 'chn':
                sentence_encoder = BERTSentenceEncoder(
                    'bert-base-chinese',  #'./pretrain/bert-base-uncased',
                    max_length)
            else:
                sentence_encoder = BERTSentenceEncoder('bert-base-uncased',
                                                       max_length)
    else:
        raise NotImplementedError

    if opt.pair:
        train_data_loader = get_loader_pair(opt.train,
                                            sentence_encoder,
                                            N=trainN,
                                            K=K,
                                            Q=Q,
                                            na_rate=opt.na_rate,
                                            batch_size=batch_size)
        val_data_loader = get_loader_pair(opt.val,
                                          sentence_encoder,
                                          N=N,
                                          K=K,
                                          Q=Q,
                                          na_rate=opt.na_rate,
                                          batch_size=batch_size)
        test_data_loader = get_loader_pair(opt.test,
                                           sentence_encoder,
                                           N=N,
                                           K=K,
                                           Q=Q,
                                           na_rate=opt.na_rate,
                                           batch_size=batch_size)
    else:
        train_data_loader = get_loader(opt.train,
                                       sentence_encoder,
                                       N=trainN,
                                       K=K,
                                       Q=Q,
                                       na_rate=opt.na_rate,
                                       batch_size=batch_size)
        val_data_loader = get_loader(opt.val,
                                     sentence_encoder,
                                     N=N,
                                     K=K,
                                     Q=Q,
                                     na_rate=opt.na_rate,
                                     batch_size=batch_size)
        test_data_loader = get_loader(opt.test,
                                      sentence_encoder,
                                      N=N,
                                      K=K,
                                      Q=Q,
                                      na_rate=opt.na_rate,
                                      batch_size=batch_size)
        if opt.adv:
            adv_data_loader = get_loader_unsupervised(opt.adv,
                                                      sentence_encoder,
                                                      N=trainN,
                                                      K=K,
                                                      Q=Q,
                                                      na_rate=opt.na_rate,
                                                      batch_size=batch_size)

    if opt.optim == 'sgd':
        pytorch_optim = optim.SGD
    elif opt.optim == 'adam':
        pytorch_optim = optim.Adam
    elif opt.optim == 'bert_adam':
        from transformers import AdamW
        pytorch_optim = AdamW
    else:
        raise NotImplementedError
    if opt.adv:
        d = Discriminator(opt.hidden_size)
        framework = FewShotREFramework(train_data_loader,
                                       val_data_loader,
                                       test_data_loader,
                                       adv_data_loader,
                                       adv=opt.adv,
                                       d=d)
    else:
        framework = FewShotREFramework(train_data_loader, val_data_loader,
                                       test_data_loader)

    prefix = '-'.join(
        [model_name, encoder_name, opt.train, opt.val,
         str(N), str(K)])
    if opt.adv is not None:
        prefix += '-adv_' + opt.adv
    if opt.na_rate != 0:
        prefix += '-na{}'.format(opt.na_rate)

    if model_name == 'proto':
        model = Proto(sentence_encoder, hidden_size=opt.hidden_size)
    elif model_name == 'gnn':
        model = GNN(sentence_encoder, N, use_sup_cost=sup_cost)
    elif model_name == 'snail':
        print("HINT: SNAIL works only in PyTorch 0.3.1")
        model = SNAIL(sentence_encoder, N, K)
    elif model_name == 'metanet':
        model = MetaNet(N,
                        K,
                        sentence_encoder.embedding,
                        max_length,
                        use_sup_cost=sup_cost)
    elif model_name == 'siamese':
        model = Siamese(sentence_encoder,
                        hidden_size=opt.hidden_size,
                        dropout=opt.dropout)
    elif model_name == 'pair':
        model = Pair(sentence_encoder, hidden_size=opt.hidden_size)
    else:
        raise NotImplementedError

    if not os.path.exists('checkpoint'):
        os.mkdir('checkpoint')
    ckpt = 'checkpoint/{}.pth.tar'.format(prefix)
    if opt.save_ckpt:
        ckpt = opt.save_ckpt

    if torch.cuda.is_available():
        model.cuda()

    if not opt.only_test:
        if encoder_name == 'bert':
            bert_optim = True
        else:
            bert_optim = False
        framework.train(model,
                        prefix,
                        batch_size,
                        trainN,
                        N,
                        K,
                        Q,
                        pytorch_optim=pytorch_optim,
                        load_ckpt=opt.load_ckpt,
                        save_ckpt=ckpt,
                        na_rate=opt.na_rate,
                        val_step=opt.val_step,
                        fp16=opt.fp16,
                        pair=opt.pair,
                        train_iter=opt.train_iter,
                        val_iter=opt.val_iter,
                        bert_optim=bert_optim,
                        sup_cls=sup_cost)
    else:
        ckpt = opt.load_ckpt

    acc = framework.eval(model,
                         batch_size,
                         N,
                         K,
                         Q,
                         opt.test_iter,
                         na_rate=opt.na_rate,
                         ckpt=ckpt,
                         pair=opt.pair)
    with open('logs/' + ckpt.replace('checkpoint/', '') + '.txt', 'a') as wfile:
        wfile.write(str(N) + '\t' + str(K) + '\t' + str(acc * 100) + '\n')
    print("RESULT: %.2f" % (acc * 100))
Example #5
model_name = 'proto'  # assumed defaults; the top of this snippet is truncated in the source
N = 5
K = 5
max_length = 40
train_data_loader = JSONFileDataLoader('./data/train.json',
                                       './data/glove.6B.50d.json',
                                       max_length=max_length)
val_data_loader = JSONFileDataLoader('./data/val.json',
                                     './data/glove.6B.50d.json',
                                     max_length=max_length)
test_data_loader = JSONFileDataLoader('./data/test.json',
                                      './data/glove.6B.50d.json',
                                      max_length=max_length)

framework = FewShotREFramework(train_data_loader, val_data_loader,
                               test_data_loader)
sentence_encoder = CNNSentenceEncoder(train_data_loader.word_vec_mat,
                                      max_length)

if model_name == 'proto':
    model = Proto(sentence_encoder).cuda()
    ckpt = 'checkpoint/proto.pth.tar'
elif model_name == 'proto_hatt':
    model = ProtoHATT(sentence_encoder, K).cuda()
    ckpt = 'checkpoint/proto_hatt.pth.tar'
else:
    raise NotImplementedError

acc = 0
for i in range(5):
    acc += framework.eval(model,
                          4,
                          N,
                          K,
                          100,
                          3000,
                          ckpt=ckpt)  # trailing arguments reconstructed; the call was truncated in the source
acc /= 5
print("RESULT: %.2f" % (acc * 100))
Example #6
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', default='train', help='train file')
    parser.add_argument('--val', default='dev', help='val file')
    parser.add_argument('--test', default='test', help='test file')
    parser.add_argument('--N',
                        default=2,
                        type=int,
                        help='Number of examples to concatenate')
    parser.add_argument('--batch_size',
                        default=32,
                        type=int,
                        help='batch size')
    parser.add_argument('--train_iter',
                        default=30000,
                        type=int,
                        help='num of iters in training')
    parser.add_argument('--val_iter',
                        default=1000,
                        type=int,
                        help='num of iters in validation')
    parser.add_argument('--test_iter',
                        default=10000,
                        type=int,
                        help='num of iters in testing')
    parser.add_argument('--val_step',
                        default=2000,
                        type=int,
                        help='val after training how many iters')
    parser.add_argument('--model', default='proto', help='model name')
    parser.add_argument('--encoder', default='cnn', help='cnn, lstm')
    parser.add_argument('--max_length',
                        default=32,
                        type=int,
                        help='max length')
    parser.add_argument('--lr', default=1e-5, type=float, help='learning rate')
    parser.add_argument('--dropout',
                        default=0.0,
                        type=float,
                        help='dropout rate')
    parser.add_argument('--grad_iter',
                        default=1,
                        type=int,
                        help='accumulate gradient every x iterations')
    parser.add_argument('--optim', default='sgd', help='sgd / adam')
    parser.add_argument('--hidden_size',
                        default=64,
                        type=int,
                        help='hidden size')
    parser.add_argument('--load_ckpt', default=None, help='load ckpt')
    parser.add_argument('--save_ckpt', default=None, help='save ckpt')
    parser.add_argument('--only_test', action='store_true', help='only test')

    opt = parser.parse_args()
    N = opt.N

    batch_size = opt.batch_size
    model_name = opt.model
    encoder_name = opt.encoder
    max_length = opt.max_length

    print("model: {}".format(model_name))
    print("encoder: {}".format(encoder_name))
    print("max_length: {}".format(max_length))

    print('Preparing vocab.')

    vocab = Vocab(opt.train, './data')
    char2id = vocab.prepare_vocab()
    print('Prepare done!')
    sentence_encoder = SentenceEncoder(char2id,
                                       max_length,
                                       opt.hidden_size,
                                       opt.hidden_size,
                                       encoder=opt.encoder)

    train_data_loader = get_loader(opt.train,
                                   sentence_encoder,
                                   N=N,
                                   batch_size=batch_size)
    val_data_loader = get_loader(opt.val,
                                 sentence_encoder,
                                 N=N,
                                 batch_size=batch_size)
    test_data_loader = get_loader(opt.test,
                                  sentence_encoder,
                                  N=N,
                                  batch_size=batch_size)

    if opt.optim == 'sgd':
        pytorch_optim = optim.SGD
    elif opt.optim == 'adam':
        pytorch_optim = optim.Adam
    else:
        raise NotImplementedError

    framework = Framework(train_data_loader, val_data_loader, test_data_loader)
    prefix = '-'.join([model_name, encoder_name, opt.train, opt.val, str(N)])

    model = Proto(sentence_encoder, hidden_size=opt.hidden_size)

    if not os.path.exists('checkpoint'):
        os.mkdir('checkpoint')
    ckpt = 'checkpoint/{}.pth.tar'.format(prefix)
    if opt.save_ckpt:
        ckpt = opt.save_ckpt

    if torch.cuda.is_available():
        model.cuda()

    framework.train(model,
                    prefix,
                    batch_size,
                    pytorch_optim=pytorch_optim,
                    load_ckpt=opt.load_ckpt,
                    save_ckpt=ckpt,
                    val_step=opt.val_step,
                    train_iter=opt.train_iter,
                    val_iter=opt.val_iter)

    acc = framework.eval(model, batch_size, opt.test_iter, ckpt=ckpt)
    print("RESULT: %.2f" % (acc * 100))
Example #7
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', default='train_data',
            help='train file')
    parser.add_argument('--val', default='val_data',
            help='val file')
    parser.add_argument('--test', default='test_data',
            help='test file')
    parser.add_argument('--adv', default=None,
            help='adv file')
    parser.add_argument('--trainN', default=10, type=int,
            help='N in train')
    parser.add_argument('--N', default=5, type=int,
            help='N way')
    parser.add_argument('--K', default=5, type=int,
            help='K shot')
    parser.add_argument('--Q', default=5, type=int,
            help='Num of query per class')
    parser.add_argument('--batch_size', default=4, type=int,
            help='batch size')
    parser.add_argument('--train_iter', default=30000, type=int,
            help='num of iters in training')
    parser.add_argument('--val_iter', default=1000, type=int,
            help='num of iters in validation')
    parser.add_argument('--test_iter', default=1000, type=int,
            help='num of iters in testing')
    parser.add_argument('--val_step', default=2000, type=int,
           help='val after training how many iters')
    parser.add_argument('--repeat_test', default=8, type=int,
                        help='repeat test stage')
    parser.add_argument('--model', default='proto',
            help='model name')
    parser.add_argument('--encoder', default='bert',
            help='encoder: cnn or bert or roberta')
    parser.add_argument('--cross_modality', default='concate',
                        help='cross-modality module')
    parser.add_argument('--max_length', default=128, type=int,
           help='max length')
    parser.add_argument('--lr', default=1e-1, type=float,
           help='learning rate')
    parser.add_argument('--weight_decay', default=1e-5, type=float,
           help='weight decay')
    parser.add_argument('--dropout', default=0.0, type=float,
           help='dropout rate')
    parser.add_argument('--na_rate', default=0, type=int,
           help='NA rate (NA = Q * na_rate)')
    parser.add_argument('--grad_iter', default=1, type=int,
           help='accumulate gradient every x iterations')
    parser.add_argument('--optim', default='sgd',
           help='sgd / adam / adamw')
    parser.add_argument('--hidden_size', default=230, type=int,
           help='hidden size')
    parser.add_argument('--multi_choose', default=1, type=int,
           help='choose multi random faces')
    parser.add_argument('--load_ckpt', default=None,
           help='load ckpt')
    parser.add_argument('--root_data', default='./data',
                        help='the root path stores data')
    parser.add_argument('--save_ckpt', default=None,
           help='save ckpt')
    parser.add_argument('--fp16', action='store_true',
           help='use nvidia apex fp16')
    parser.add_argument('--only_test', action='store_true',
           help='only test')

    # only for bert / roberta
    parser.add_argument('--pair', action='store_true',
           help='use pair model')
    parser.add_argument('--pretrain_ckpt', default=None,
           help='bert / roberta pre-trained checkpoint')
    parser.add_argument('--cat_entity_rep', action='store_true',
           help='concatenate entity representation as sentence rep')
    parser.add_argument('--use_img', action='store_true',
                        help='use img info')

    # only for prototypical networks
    parser.add_argument('--dot', action='store_true', 
           help='use dot instead of L2 distance for proto')

    parser.add_argument('--differ_scene', action='store_true',
                        help='use face image in different scenes')


    opt = parser.parse_args()
    trainN = opt.trainN
    N = opt.N
    K = opt.K
    Q = opt.Q
    batch_size = opt.batch_size
    model_name = opt.model
    encoder_name = opt.encoder
    max_length = opt.max_length
    
    print("{}-way-{}-shot Multimodal Social Relation Classification".format(N, K))
    print("model: {}".format(model_name))
    print("encoder: {}".format(encoder_name))
    print("max_length: {}".format(max_length))
    
    if encoder_name == 'cnn':
        try:
            glove_mat = np.load('./pretrain/glove/glove_mat.npy')
            glove_word2id = json.load(open('./pretrain/glove/glove_word2id.json'))
        except FileNotFoundError:
            raise Exception("Cannot find glove files. Run glove/download_glove.sh to download glove files.")
        sentence_encoder = CNNSentenceEncoder(
                glove_mat,
                glove_word2id,
                max_length)
    elif encoder_name == 'bert':
        pretrain_ckpt = opt.pretrain_ckpt or 'bert-base-chinese'
        sentence_encoder = BERTSentenceEncoder(pretrain_ckpt, max_length)
    elif encoder_name == 'roberta':
        pretrain_ckpt = opt.pretrain_ckpt or 'roberta-base'
        sentence_encoder = RobertaSentenceEncoder(
                pretrain_ckpt,
                max_length,
                cat_entity_rep=opt.cat_entity_rep)
    else:
        raise NotImplementedError

    train_data_loader = get_loader(opt.train, sentence_encoder,
            N=trainN, K=K, Q=Q, root=opt.root_data, batch_size=batch_size,
            use_img=opt.use_img, differ_scene=opt.differ_scene,
            multi_choose=opt.multi_choose)
    val_data_loader = get_loader(opt.val, sentence_encoder,
            N=N, K=K, Q=Q, root=opt.root_data, batch_size=batch_size,
            use_img=opt.use_img, differ_scene=opt.differ_scene,
            multi_choose=opt.multi_choose)
    test_data_loader = get_loader(opt.test, sentence_encoder,
            N=N, K=K, Q=Q, root=opt.root_data, batch_size=batch_size,
            use_img=opt.use_img, differ_scene=opt.differ_scene,
            multi_choose=opt.multi_choose)

   
    if opt.optim == 'sgd':
        pytorch_optim = optim.SGD
    elif opt.optim == 'adam':
        pytorch_optim = optim.Adam
    elif opt.optim == 'adamw':
        from transformers import AdamW
        pytorch_optim = AdamW
    else:
        raise NotImplementedError

    framework = FewShotREFramework(train_data_loader, val_data_loader, test_data_loader)

    if opt.differ_scene:
        prefix = '-'.join(['differ', model_name, encoder_name, opt.train, opt.val, str(N), str(K)])
    else:
        prefix = '-'.join(['same', model_name, encoder_name, opt.train, opt.val, str(N), str(K)])

    if opt.use_img:
        prefix = 'img_' + prefix
    if opt.cross_modality == 'transformer':
        prefix = 'transformer' + '-' + prefix

    if opt.multi_choose > 1:
        prefix = 'choose-' + str(opt.multi_choose) + '-' + prefix

    prefix = '0227' + prefix
    
    if model_name == 'proto':
        if opt.use_img:
            face_encoder = FacenetEncoder()
            multimodal_encoder = MultimodalEncoder(sentence_encoder=sentence_encoder, face_encoder=face_encoder, cross_modality=opt.cross_modality)
            model = Proto(multimodal_encoder, dot=opt.dot)
        else:
            model = Proto(sentence_encoder, dot=opt.dot)
    else:
        raise NotImplementedError

    # if not os.path.exists('checkpoint'):
    #     os.mkdir('checkpoint')
    # ckpt = 'checkpoint/{}.pth.tar'.format(prefix)

    checkpoint_dir = os.path.join(opt.root_data, 'checkpoint')
    if not os.path.exists(checkpoint_dir):
        os.mkdir(checkpoint_dir)
    ckpt = os.path.join(checkpoint_dir, '{}.pth.tar'.format(prefix))
    

    if opt.save_ckpt:
        ckpt = opt.save_ckpt

    if torch.cuda.is_available():
        model.cuda()

    if not opt.only_test:
        if encoder_name in ['bert', 'roberta']:
            bert_optim = True
        else:
            bert_optim = False

        framework.train(model, prefix, batch_size, trainN, N, K, Q,
                pytorch_optim=pytorch_optim, load_ckpt=opt.load_ckpt, save_ckpt=ckpt,
                na_rate=opt.na_rate, val_step=opt.val_step, grad_iter=opt.grad_iter, fp16=opt.fp16,
                pair=opt.pair, train_iter=opt.train_iter, val_iter=opt.val_iter, bert_optim=bert_optim,
                multi_choose=opt.multi_choose)
    else:
        ckpt = opt.load_ckpt

    result_txt = os.path.join(checkpoint_dir, '{}.txt'.format(opt.root_data.split('/')[1] + '_' + prefix))
    with open(result_txt, 'a', encoding='utf-8') as f:
        for _ in range(opt.repeat_test):
            acc = framework.eval(model, batch_size, N, K, Q, opt.test_iter, na_rate=opt.na_rate, ckpt=ckpt, pair=opt.pair, multi_choose=opt.multi_choose)
            print("RESULT: %.2f" % (acc * 100))
            f.write(str(acc * 100))
            f.write('\n')
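
The --cross_modality option in this example defaults to 'concate', suggesting text and face embeddings are fused by concatenation. A sketch of such a fusion module (the class name and dimensions are assumptions, not the repo's MultimodalEncoder):

import torch
import torch.nn as nn

class ConcatFusion(nn.Module):
    def __init__(self, text_dim=768, img_dim=512, out_dim=230):
        super().__init__()
        self.proj = nn.Linear(text_dim + img_dim, out_dim)

    def forward(self, text_emb, img_emb):
        # concatenate along the feature axis, then project to a shared space
        return self.proj(torch.cat([text_emb, img_emb], dim=-1))

# usage
fusion = ConcatFusion()
print(fusion(torch.randn(4, 768), torch.randn(4, 512)).shape)  # [4, 230]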
Example #8
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', default='train_wiki', help='train file')
    parser.add_argument('--val', default='val_wiki', help='val file')
    parser.add_argument('--test', default='test_wiki', help='test file')
    parser.add_argument('--adv', default=None, help='adv file')
    parser.add_argument('--trainN', default=10, type=int, help='N in train')
    parser.add_argument('--N', default=5, type=int, help='N way')
    parser.add_argument('--K', default=5, type=int, help='K shot')
    parser.add_argument('--Q',
                        default=5,
                        type=int,
                        help='Num of query per class')
    parser.add_argument('--batch_size', default=4, type=int, help='batch size')
    parser.add_argument('--train_iter',
                        default=30000,
                        type=int,
                        help='num of iters in training')
    parser.add_argument('--val_iter',
                        default=1000,
                        type=int,
                        help='num of iters in validation')
    parser.add_argument('--test_iter',
                        default=3000,
                        type=int,
                        help='num of iters in testing')
    parser.add_argument('--val_step',
                        default=2000,
                        type=int,
                        help='val after training how many iters')
    parser.add_argument('--model', default='proto', help='model name')
    parser.add_argument('--encoder',
                        default='cnn',
                        help='encoder: cnn or bert')
    parser.add_argument('--max_length',
                        default=128,
                        type=int,
                        help='max length')
    parser.add_argument('--lr', default=1e-1, type=float, help='learning rate')
    parser.add_argument('--weight_decay',
                        default=1e-5,
                        type=float,
                        help='weight decay')
    parser.add_argument('--dropout',
                        default=0.0,
                        type=float,
                        help='dropout rate')
    parser.add_argument('--na_rate',
                        default=0,
                        type=int,
                        help='NA rate (NA = Q * na_rate)')
    parser.add_argument('--grad_iter',
                        default=1,
                        type=int,
                        help='accumulate gradient every x iterations')
    parser.add_argument('--optim',
                        default='sgd',
                        help='sgd / adam / bert_adam')
    parser.add_argument('--hidden_size',
                        default=230,
                        type=int,
                        help='hidden size')
    parser.add_argument('--load_ckpt', default=None, help='load ckpt')
    parser.add_argument('--save_ckpt', default=None, help='save ckpt')
    parser.add_argument('--fp16',
                        action='store_true',
                        help='use nvidia apex fp16')
    parser.add_argument('--only_test', action='store_true', help='only test')
    parser.add_argument('--pair', action='store_true', help='use pair model')

    opt = parser.parse_args()
    trainN = opt.trainN
    N = opt.N
    K = opt.K
    Q = opt.Q
    batch_size = opt.batch_size
    model_name = opt.model
    encoder_name = opt.encoder
    max_length = opt.max_length

    print("{}-way-{}-shot Few-Shot Relation Classification".format(N, K))
    print("model: {}".format(model_name))
    print("encoder: {}".format(encoder_name))
    print("max_length: {}".format(max_length))

    if encoder_name == 'cnn':
        try:
            glove_mat = np.load('./pretrain/glove/glove_mat.npy')
            glove_word2id = json.load(
                open('./pretrain/glove/glove_word2id.json'))
        except FileNotFoundError:
            raise Exception(
                "Cannot find glove files. Run glove/download_glove.sh to download glove files."
            )
        sentence_encoder = CNNSentenceEncoder(glove_mat, glove_word2id,
                                              max_length)
    elif encoder_name == 'bert':
        if opt.pair:
            sentence_encoder = BERTPAIRSentenceEncoder(
                './pretrain/bert-base-uncased', max_length)
        else:
            sentence_encoder = BERTSentenceEncoder(
                './pretrain/bert-base-uncased', max_length)
    else:
        raise NotImplementedError

    if opt.pair:
        train_data_loader = get_loader_pair(opt.train,
                                            sentence_encoder,
                                            N=trainN,
                                            K=K,
                                            Q=Q,
                                            na_rate=opt.na_rate,
                                            batch_size=batch_size)
        val_data_loader = get_loader_pair(opt.val,
                                          sentence_encoder,
                                          N=N,
                                          K=K,
                                          Q=Q,
                                          na_rate=opt.na_rate,
                                          batch_size=batch_size)
        test_data_loader = get_loader_pair(opt.test,
                                           sentence_encoder,
                                           N=N,
                                           K=K,
                                           Q=Q,
                                           na_rate=opt.na_rate,
                                           batch_size=batch_size)
    else:
        train_data_loader = get_loader(opt.train,
                                       sentence_encoder,
                                       N=trainN,
                                       K=K,
                                       Q=Q,
                                       na_rate=opt.na_rate,
                                       batch_size=batch_size)
        val_data_loader = get_loader(opt.val,
                                     sentence_encoder,
                                     N=N,
                                     K=K,
                                     Q=Q,
                                     na_rate=opt.na_rate,
                                     batch_size=batch_size)
        test_data_loader = get_loader(opt.test,
                                      sentence_encoder,
                                      N=N,
                                      K=K,
                                      Q=Q,
                                      na_rate=opt.na_rate,
                                      batch_size=batch_size)
        if opt.adv:
            adv_data_loader = get_loader_unsupervised(opt.adv,
                                                      sentence_encoder,
                                                      N=trainN,
                                                      K=K,
                                                      Q=Q,
                                                      na_rate=opt.na_rate,
                                                      batch_size=batch_size)

    if opt.optim == 'sgd':
        pytorch_optim = optim.SGD
    elif opt.optim == 'adam':
        pytorch_optim = optim.Adam
    elif opt.optim == 'bert_adam':
        from pytorch_transformers import AdamW
        pytorch_optim = AdamW
    else:
        raise NotImplementedError
    if opt.adv:
        d = Discriminator(opt.hidden_size)
        framework = FewShotREFramework(train_data_loader,
                                       val_data_loader,
                                       test_data_loader,
                                       adv_data_loader,
                                       adv=opt.adv,
                                       d=d)
    else:
        framework = FewShotREFramework(train_data_loader, val_data_loader,
                                       test_data_loader)

    prefix = '-'.join(
        [model_name, encoder_name, opt.train, opt.val,
         str(N), str(K)])
    if opt.adv is not None:
        prefix += '-adv_' + opt.adv
    if opt.na_rate != 0:
        prefix += '-na{}'.format(opt.na_rate)

    if model_name == 'proto':
        model = Proto(sentence_encoder, hidden_size=opt.hidden_size)
    elif model_name == 'gnn':
        model = GNN(sentence_encoder, N)
    elif model_name == 'snail':
        print("HINT: SNAIL works only in PyTorch 0.3.1")
        model = SNAIL(sentence_encoder, N, K)
    elif model_name == 'metanet':
        model = MetaNet(N, K, sentence_encoder.embedding, max_length)
    elif model_name == 'siamese':
        model = Siamese(sentence_encoder,
                        hidden_size=opt.hidden_size,
                        dropout=opt.dropout)
    elif model_name == 'pair':
        model = Pair(sentence_encoder, hidden_size=opt.hidden_size)
    else:
        raise NotImplementedError

    if not os.path.exists('checkpoint'):
        os.mkdir('checkpoint')
    ckpt = 'checkpoint/{}.pth.tar'.format(prefix)
    if opt.save_ckpt:
        ckpt = opt.save_ckpt

    if torch.cuda.is_available():
        model.cuda()

    if not opt.only_test:
        if encoder_name == 'bert':
            bert_optim = True
        else:
            bert_optim = False

        framework.train(model,
                        prefix,
                        batch_size,
                        trainN,
                        N,
                        K,
                        Q,
                        pytorch_optim=pytorch_optim,
                        load_ckpt=opt.load_ckpt,
                        save_ckpt=ckpt,
                        na_rate=opt.na_rate,
                        val_step=opt.val_step,
                        fp16=opt.fp16,
                        pair=opt.pair,
                        train_iter=opt.train_iter,
                        val_iter=opt.val_iter,
                        bert_optim=bert_optim)
    else:
        ckpt = opt.load_ckpt

    acc = 0
    his_acc = []
    total_test_round = 5
    for i in range(total_test_round):
        cur_acc = framework.eval(model,
                                 batch_size,
                                 N,
                                 K,
                                 Q,
                                 opt.test_iter,
                                 na_rate=opt.na_rate,
                                 ckpt=ckpt,
                                 pair=opt.pair)
        his_acc.append(cur_acc)
        acc += cur_acc
    acc /= total_test_round
    nhis_acc = np.array(his_acc)
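    # 1.96 * std / sqrt(n): half-width of a normal-approximation 95% confidence interval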
    error = nhis_acc.std() * 1.96 / (nhis_acc.shape[0]**0.5)
    print("RESULT: %.2f\\pm%.2f" % (acc * 100, error * 100))

    with open('./result.txt', 'a+') as result_file:
        result_file.write(
            "test data: %12s | model: %45s | acc: %.6f | error: %.6f\n" %
            (opt.test, prefix, acc, error))

    result_detail = {
        'test': opt.test,
        'model': prefix,
        'acc': acc,
        'his': his_acc
    }
    with open('./result_detail.txt', 'a+') as result_file:
        result_file.write("%s\n" % (json.dumps(result_detail)))
Example #9
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', default='train', help='train file')
    parser.add_argument('--test', default='test_sim', help='test file')
    parser.add_argument('--N',
                        default=2,
                        type=int,
                        help='Number of examples to concatenate')
    parser.add_argument('--batch_size', default=1, type=int, help='batch size')
    parser.add_argument('--model', default='proto', help='model name')
    parser.add_argument('--encoder', default='cnn', help='cnn, lstm')
    parser.add_argument('--max_length',
                        default=32,
                        type=int,
                        help='max length')
    parser.add_argument('--hidden_size',
                        default=64,
                        type=int,
                        help='hidden size')

    opt = parser.parse_args()
    N = opt.N

    model_name = opt.model
    encoder_name = opt.encoder
    max_length = opt.max_length

    print("model: {}".format(model_name))
    print("encoder: {}".format(encoder_name))
    print("max_length: {}".format(max_length))

    print('Preparing vocab.')

    vocab = Vocab(opt.train, './data')
    char2id = vocab.prepare_vocab()
    print('Prepare done!')

    sentence_encoder = SentenceEncoder(char2id,
                                       max_length,
                                       opt.hidden_size,
                                       opt.hidden_size,
                                       encoder=opt.encoder)

    prefix = '-'.join([model_name, encoder_name, opt.train, 'dev', str(N)])
    model = Proto(sentence_encoder, hidden_size=opt.hidden_size)
    ckpt = 'checkpoint/{}.pth.tar'.format(prefix)

    if torch.cuda.is_available():
        model.cuda()
    model.eval()

    state_dict = __load_model__(ckpt)['state_dict']
    own_state = model.state_dict()
    for name, param in state_dict.items():
        if name not in own_state:
            continue
        own_state[name].copy_(param)

    abbs, fulls = __load_data__(opt.test, './data')
    ground_truth = {}
    for i in range(len(abbs)):
        # collect every full form seen for each abbreviation
        if abbs[i] not in ground_truth:
            ground_truth[abbs[i]] = []
        ground_truth[abbs[i]].append(fulls[i])

    abbs_set = list(set(abbs))
    fulls_set = list(set(fulls))

    abbs_data = []
    abbs_mask = []
    fulls_data = []
    fulls_mask = []
    for i in range(len(abbs_set)):
        abbs_data_ = []
        abbs_mask_ = []
        fulls_data_ = []
        fulls_mask_ = []
        for j in range(len(fulls_set)):
            indexed_abbs, mask_abb, indexed_target, mask_target = sentence_encoder.tokenize_test(
                abbs_set[i], fulls_set[j])
            indexed_abbs = torch.tensor(indexed_abbs).long().unsqueeze(0)
            mask_abb = torch.tensor(mask_abb).long().unsqueeze(0)
            indexed_target = torch.tensor(indexed_target).long().unsqueeze(0)
            mask_target = torch.tensor(mask_target).long().unsqueeze(0)

            abbs_data_.append(indexed_abbs)
            abbs_mask_.append(mask_abb)
            fulls_data_.append(indexed_target)
            fulls_mask_.append(mask_target)
        abbs_data.append(abbs_data_)
        abbs_mask.append(abbs_mask_)
        fulls_data.append(fulls_data_)
        fulls_mask.append(fulls_mask_)

    print('Start Evaluating...')
    with torch.no_grad():
        scores = []
        for i in tqdm(range(len(abbs_set))):
            scores_ = []
            for j in range(len(fulls_set)):
                input1, mask1 = abbs_data[i][j], abbs_mask[i][j]
                input2, mask2 = fulls_data[i][j], fulls_mask[i][j]
                if torch.cuda.is_available():
                    input1 = input1.cuda()
                    input2 = input2.cuda()
                    mask1 = mask1.cuda()
                    mask2 = mask2.cuda()

                logits = model(input1, mask1, input2, mask2)
                scores_.append(logits.cpu().numpy()[0][0])
            scores.append(scores_)
        scores = np.array(scores)
        print(scores.shape)
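    # NOTE: np.argsort/np.sort are ascending, so the slices below keep the 10
    # lowest-scoring candidates; if a larger logit means a better match,
    # negate scores (or reverse the sorted order) before slicing.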
    score_arg = np.argsort(scores, axis=1)
    score_sorted = np.sort(scores, axis=1)
    score_arg = score_arg[:, :10]
    score_sorted = score_sorted[:, :10]

    result = []
    for i in range(score_arg.shape[0]):
        result.append({
            'abbreviation': abbs_set[i],
            'predicted': [fulls_set[x] for x in score_arg[i]],
            'ground truth': ground_truth[abbs_set[i]],
            'score': str(score_sorted[i])
        })

    with open('result.json', 'w', encoding='utf8') as f:
        json.dump(result, f, ensure_ascii=False)
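
Examples #2 and #9 both copy a saved state_dict into a model parameter by parameter, skipping names the target model does not have. A hedged helper capturing that loop (the function name is an assumption, not a repo API):

import torch

def load_matching_params(model, state_dict, strip='sentence_encoder.module.'):
    # copy only parameters whose (possibly prefixed) names match the model
    own_state = model.state_dict()
    for name, param in state_dict.items():
        name = name.replace(strip, '')
        if name in own_state:
            own_state[name].copy_(param)

# usage (hypothetical): load_matching_params(model, torch.load(ckpt_path))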