Code example #1
def pair(command):
    pair = Pair(command[1].upper(), command[2].upper())
    back_days = core.safe_execute(gcnv.BACK_DAYS, ValueError,
                                  lambda x: int(x) * 30, command[3])

    print(f"  Correlation: {format(pair.correlation(back_days), '.2f')}")
    print(f"  Beta:        {format(pair.beta(back_days), '.2f')}")
    print(f"  Volat ratio: {format(pair.stdev_ratio(back_days), '.2f')}")
Code example #2
    def __init__(self, chpt_path, max_length=128):
        """
            Initializer
        """
        self.checkpoint_path = chpt_path
        self.bert_pretrained_checkpoint = 'bert-base-uncased'
        self.max_length = max_length
        self.sentence_encoder = BERTPAIRSentenceEncoder(
            self.bert_pretrained_checkpoint, self.max_length)

        self.model = Pair(self.sentence_encoder, hidden_size=768)
        if torch.cuda.is_available():
            self.model = self.model.cuda()
        self.model.eval()

        #         self.nlp_coref = spacy.load("en_core_web_sm")
        #         neuralcoref.add_to_pipe(self.nlp_coref)
        self.nlp_no_coref = spacy.load("en_core_web_sm")
        self.load_model()
Code example #3
def table(command):
    back_days = core.safe_execute(gcnv.BACK_DAYS, ValueError,
                                  lambda x: int(x) * 30, command[2])
    header = [
        "", "SPY", "TLT", "IEF", "GLD", "USO", "UNG", "FXE", "FXY", "FXB",
        "IYR", "XLU", "EFA", "EEM", "VXX"
    ]
    rows = []
    for symbol in util.read_symbol_list(
            f"{gcnv.APP_PATH}/input/{command[1]}.txt"):
        row = [symbol]
        for head_symbol in header[1:]:
            if symbol == head_symbol:
                row.append("-")
            else:
                try:
                    pair = Pair(head_symbol, symbol)
                    row.append(pair.correlation(back_days))
                except GettingInfoError:
                    row.append("-")
        rows.append(row)
    return header, rows
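
A short sketch of how the returned header and rows could be printed; the "portfolio" symbol list is hypothetical and would need to exist as input/portfolio.txt under gcnv.APP_PATH.

# Hypothetical usage of table(): command[1] names the symbol list, command[2] the month multiplier.
header, rows = table(["table", "portfolio", "6"])
for line in [header] + rows:
    print("  ".join(str(cell) for cell in line))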
Code example #4
def get_row(pair, command):
    ps = process_pair_string(pair)
    ticker1 = ps.ticker1
    ticker2 = ps.ticker2
    fixed_stdev_ratio = ps.stdev_ratio
    back_days = core.safe_execute(gcnv.PAIR_BACK_DAYS, ValueError,
                                  lambda x: int(x) * 30, command[2])
    bring_if_connected(ticker1)
    bring_if_connected(ticker2)
    try:
        pair = Pair(ticker1, ticker2, fixed_stdev_ratio)
        max_stored_date = gcnv.data_handler.get_max_stored_date(
            "stock", ticker1)
        date = ('-' if max_stored_date is None
                else util.date_in_string(max_stored_date))  # Need to change this
        row = [
            ticker1 + '-' + ticker2,
            date,
            '-',
            pair.get_last_close(
                back_days),  # GettingInfoError raised here if not stored data
            pair.min(back_days),
            pair.max(back_days),
            pair.current_rank(back_days),
            pair.ma(back_days),
            '-',
            pair.stdev_ratio(back_days),
            pair.correlation(back_days),
            pair.hv_to_10_ratio(back_days),
            '-'
        ]
        closes = pair.closes(back_days)[-gcnv.PAIR_PAST_RESULTS:]
        closes.reverse()
        row += closes
        return row
    except (GettingInfoError, ZeroDivisionError,
            statistics.StatisticsError) as e:
        print(e)
        return []
Code example #5
def main():
    opt = {'train': 'train_wiki',
           'val': 'val_wiki',
           'test': 'val_semeval',
           'adv': None,
           'trainN': 10,
           'N': 5,
           'K': 1,
           'Q': 1,
           'batch_size': 1,
           'train_iter': 5,
           'val_iter': 1000,
           'test_iter': 10,
           'val_step': 2000,
           'model': 'pair',
           'encoder': 'bert',
           'max_length': 64,
           'hidden_size': 768,  # needed by Pair below; 768 is the bert-base hidden size
           'lr': -1,
           'weight_decay': 1e-5,
           'dropout': 0.0,
           'na_rate':0,
           'optim': 'adam',
           'load_ckpt': './src/FewRel-master/Checkpoints/ckpt_5-Way-1-Shot_FewRel.pth',
           'save_ckpt': './src/FewRel-master/Checkpoints/post_ckpt_5-Way-1-Shot_FewRel.pth',
           'fp16':False,
           'only_test': False,
           'ckpt_name': 'Checkpoints/ckpt_5-Way-1-Shot_FewRel.pth',
           'pair': True,
           'pretrain_ckpt': '',
           'cat_entity_rep': False,
           'dot': False,
           'no_dropout': False,
           'mask_entity': False,
           'use_sgd_for_bert': False
           }
    opt = DotMap(opt)
    trainN = opt.trainN
    N = opt.N
    K = opt.K
    Q = opt.Q
    batch_size = opt.batch_size
    model_name = opt.model
    encoder_name = opt.encoder
    max_length = opt.max_length
    
    print("{}-way-{}-shot Few-Shot Relation Classification".format(N, K))
    print("model: {}".format(model_name))
    print("encoder: {}".format(encoder_name))
    print("max_length: {}".format(max_length))

    pretrain_ckpt = opt.pretrain_ckpt or 'bert-base-uncased'
    sentence_encoder = BERTPAIRSentenceEncoder(
            pretrain_ckpt,
            max_length)
    
    train_data_loader = get_loader_pair(opt.train, sentence_encoder,
            N=trainN, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size, encoder_name=encoder_name)
    val_data_loader = get_loader_pair(opt.val, sentence_encoder,
            N=N, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size, encoder_name=encoder_name)
    test_data_loader = get_loader_pair(opt.test, sentence_encoder,
            N=N, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size, encoder_name=encoder_name)
   
    if opt.optim == 'sgd':
        pytorch_optim = optim.SGD
    elif opt.optim == 'adam':
        pytorch_optim = optim.Adam
    elif opt.optim == 'adamw':
        from transformers import AdamW
        pytorch_optim = AdamW
    else:
        raise NotImplementedError
    framework = FewShotREFramework(train_data_loader, val_data_loader, test_data_loader)
        
    prefix = '-'.join([model_name, encoder_name, opt.train, opt.val, str(N), str(K)])
    if opt.na_rate != 0:
        prefix += '-na{}'.format(opt.na_rate)
    if opt.dot:
        prefix += '-dot'
    if opt.cat_entity_rep:
        prefix += '-catentity'
    if len(opt.ckpt_name) > 0:
        prefix += '-' + opt.ckpt_name

    model = Pair(sentence_encoder, hidden_size=opt.hidden_size)
    
    if not os.path.exists('checkpoint'):
        os.mkdir('checkpoint')
    ckpt = 'checkpoint/{}.pth.tar'.format(prefix)
    if opt.save_ckpt:
        ckpt = opt.save_ckpt

    if torch.cuda.is_available():
        model.cuda()

    if not opt.only_test:
        if encoder_name in ['bert', 'roberta']:
            bert_optim = True
        else:
            bert_optim = False

        if opt.lr == -1:
            if bert_optim:
                opt.lr = 2e-5
            else:
                opt.lr = 1e-1

        framework.train(model, prefix, batch_size, trainN, N, K, Q,
                pytorch_optim=pytorch_optim, load_ckpt=opt.load_ckpt, save_ckpt=ckpt,
                na_rate=opt.na_rate, val_step=opt.val_step, fp16=opt.fp16, pair=opt.pair, 
                train_iter=opt.train_iter, val_iter=opt.val_iter, bert_optim=bert_optim, 
                learning_rate=opt.lr, use_sgd_for_bert=opt.use_sgd_for_bert)
    else:
        ckpt = opt.load_ckpt
        if ckpt is None:
            print("Warning: --load_ckpt is not specified. Will load Huggingface pre-trained checkpoint.")
            ckpt = 'none'

    acc = framework.eval(model, batch_size, N, K, Q, opt.test_iter, na_rate=opt.na_rate, ckpt=ckpt, pair=opt.pair)
    print("RESULT: %.2f" % (acc * 100))
Code example #6
File: train_demo.py, Project: zhangpiepie/FewRel
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', default='train_wiki', help='train file')
    parser.add_argument('--val', default='val_wiki', help='val file')
    parser.add_argument('--test', default='test_wiki', help='test file')
    parser.add_argument('--adv', default=None, help='adv file')
    parser.add_argument('--trainN', default=10, type=int, help='N in train')
    parser.add_argument('--N', default=5, type=int, help='N way')
    parser.add_argument('--K', default=5, type=int, help='K shot')
    parser.add_argument('--Q',
                        default=5,
                        type=int,
                        help='Num of query per class')
    parser.add_argument('--batch_size', default=4, type=int, help='batch size')
    parser.add_argument('--train_iter',
                        default=30000,
                        type=int,
                        help='num of iters in training')
    parser.add_argument('--val_iter',
                        default=1000,
                        type=int,
                        help='num of iters in validation')
    parser.add_argument('--test_iter',
                        default=10000,
                        type=int,
                        help='num of iters in testing')
    parser.add_argument('--val_step',
                        default=2000,
                        type=int,
                        help='val after training how many iters')
    parser.add_argument('--model', default='proto', help='model name')
    parser.add_argument('--encoder',
                        default='cnn',
                        help='encoder: cnn or bert or roberta')
    parser.add_argument('--max_length',
                        default=128,
                        type=int,
                        help='max length')
    parser.add_argument('--lr', default=1e-1, type=float, help='learning rate')
    parser.add_argument('--weight_decay',
                        default=1e-5,
                        type=float,
                        help='weight decay')
    parser.add_argument('--dropout',
                        default=0.0,
                        type=float,
                        help='dropout rate')
    parser.add_argument('--na_rate',
                        default=0,
                        type=int,
                        help='NA rate (NA = Q * na_rate)')
    parser.add_argument('--grad_iter',
                        default=1,
                        type=int,
                        help='accumulate gradient every x iterations')
    parser.add_argument('--optim', default='sgd', help='sgd / adam / adamw')
    parser.add_argument('--hidden_size',
                        default=230,
                        type=int,
                        help='hidden size')
    parser.add_argument('--load_ckpt', default=None, help='load ckpt')
    parser.add_argument('--save_ckpt', default=None, help='save ckpt')
    parser.add_argument('--fp16',
                        action='store_true',
                        help='use nvidia apex fp16')
    parser.add_argument('--only_test', action='store_true', help='only test')

    # only for bert / roberta
    parser.add_argument('--pair', action='store_true', help='use pair model')
    parser.add_argument('--pretrain_ckpt',
                        default=None,
                        help='bert / roberta pre-trained checkpoint')
    parser.add_argument(
        '--cat_entity_rep',
        action='store_true',
        help='concatenate entity representation as sentence rep')

    # only for prototypical networks
    parser.add_argument('--dot',
                        action='store_true',
                        help='use dot instead of L2 distance for proto')

    opt = parser.parse_args()
    trainN = opt.trainN
    N = opt.N
    K = opt.K
    Q = opt.Q
    batch_size = opt.batch_size
    model_name = opt.model
    encoder_name = opt.encoder
    max_length = opt.max_length

    print("{}-way-{}-shot Few-Shot Relation Classification".format(N, K))
    print("model: {}".format(model_name))
    print("encoder: {}".format(encoder_name))
    print("max_length: {}".format(max_length))

    if encoder_name == 'cnn':
        try:
            glove_mat = np.load('./pretrain/glove/glove_mat.npy')
            glove_word2id = json.load(
                open('./pretrain/glove/glove_word2id.json'))
        except:
            raise Exception(
                "Cannot find glove files. Run glove/download_glove.sh to download glove files."
            )
        sentence_encoder = CNNSentenceEncoder(glove_mat, glove_word2id,
                                              max_length)
    elif encoder_name == 'bert':
        pretrain_ckpt = opt.pretrain_ckpt or 'bert-base-uncased'
        if opt.pair:
            sentence_encoder = BERTPAIRSentenceEncoder(pretrain_ckpt,
                                                       max_length)
        else:
            sentence_encoder = BERTSentenceEncoder(
                pretrain_ckpt, max_length, cat_entity_rep=opt.cat_entity_rep)
    elif encoder_name == 'roberta':
        pretrain_ckpt = opt.pretrain_ckpt or 'roberta-base'
        if opt.pair:
            sentence_encoder = RobertaPAIRSentenceEncoder(
                pretrain_ckpt, max_length)
        else:
            sentence_encoder = RobertaSentenceEncoder(
                pretrain_ckpt, max_length, cat_entity_rep=opt.cat_entity_rep)
    else:
        raise NotImplementedError

    if opt.pair:
        train_data_loader = get_loader_pair(opt.train,
                                            sentence_encoder,
                                            N=trainN,
                                            K=K,
                                            Q=Q,
                                            na_rate=opt.na_rate,
                                            batch_size=batch_size,
                                            encoder_name=encoder_name)
        val_data_loader = get_loader_pair(opt.val,
                                          sentence_encoder,
                                          N=N,
                                          K=K,
                                          Q=Q,
                                          na_rate=opt.na_rate,
                                          batch_size=batch_size,
                                          encoder_name=encoder_name)
        test_data_loader = get_loader_pair(opt.test,
                                           sentence_encoder,
                                           N=N,
                                           K=K,
                                           Q=Q,
                                           na_rate=opt.na_rate,
                                           batch_size=batch_size,
                                           encoder_name=encoder_name)
    else:
        train_data_loader = get_loader(opt.train,
                                       sentence_encoder,
                                       N=trainN,
                                       K=K,
                                       Q=Q,
                                       na_rate=opt.na_rate,
                                       batch_size=batch_size)
        val_data_loader = get_loader(opt.val,
                                     sentence_encoder,
                                     N=N,
                                     K=K,
                                     Q=Q,
                                     na_rate=opt.na_rate,
                                     batch_size=batch_size)
        test_data_loader = get_loader(opt.test,
                                      sentence_encoder,
                                      N=N,
                                      K=K,
                                      Q=Q,
                                      na_rate=opt.na_rate,
                                      batch_size=batch_size)
        if opt.adv:
            adv_data_loader = get_loader_unsupervised(opt.adv,
                                                      sentence_encoder,
                                                      N=trainN,
                                                      K=K,
                                                      Q=Q,
                                                      na_rate=opt.na_rate,
                                                      batch_size=batch_size)

    if opt.optim == 'sgd':
        pytorch_optim = optim.SGD
    elif opt.optim == 'adam':
        pytorch_optim = optim.Adam
    elif opt.optim == 'adamw':
        from transformers import AdamW
        pytorch_optim = AdamW
    else:
        raise NotImplementedError
    if opt.adv:
        d = Discriminator(opt.hidden_size)
        framework = FewShotREFramework(train_data_loader,
                                       val_data_loader,
                                       test_data_loader,
                                       adv_data_loader,
                                       adv=opt.adv,
                                       d=d)
    else:
        framework = FewShotREFramework(train_data_loader, val_data_loader,
                                       test_data_loader)

    prefix = '-'.join(
        [model_name, encoder_name, opt.train, opt.val,
         str(N), str(K)])
    if opt.adv is not None:
        prefix += '-adv_' + opt.adv
    if opt.na_rate != 0:
        prefix += '-na{}'.format(opt.na_rate)
    if opt.dot:
        prefix += '-dot'
    if opt.cat_entity_rep:
        prefix += '-catentity'

    if model_name == 'proto':
        model = Proto(sentence_encoder, dot=opt.dot)
    elif model_name == 'gnn':
        model = GNN(sentence_encoder, N, hidden_size=opt.hidden_size)
    elif model_name == 'snail':
        model = SNAIL(sentence_encoder, N, K, hidden_size=opt.hidden_size)
    elif model_name == 'metanet':
        model = MetaNet(N, K, sentence_encoder.embedding, max_length)
    elif model_name == 'siamese':
        model = Siamese(sentence_encoder,
                        hidden_size=opt.hidden_size,
                        dropout=opt.dropout)
    elif model_name == 'pair':
        model = Pair(sentence_encoder, hidden_size=opt.hidden_size)
    else:
        raise NotImplementedError

    if not os.path.exists('checkpoint'):
        os.mkdir('checkpoint')
    ckpt = 'checkpoint/{}.pth.tar'.format(prefix)
    if opt.save_ckpt:
        ckpt = opt.save_ckpt

    if torch.cuda.is_available():
        model.cuda()

    if not opt.only_test:
        if encoder_name in ['bert', 'roberta']:
            bert_optim = True
        else:
            bert_optim = False

        framework.train(model,
                        prefix,
                        batch_size,
                        trainN,
                        N,
                        K,
                        Q,
                        pytorch_optim=pytorch_optim,
                        load_ckpt=opt.load_ckpt,
                        save_ckpt=ckpt,
                        na_rate=opt.na_rate,
                        val_step=opt.val_step,
                        fp16=opt.fp16,
                        pair=opt.pair,
                        train_iter=opt.train_iter,
                        val_iter=opt.val_iter,
                        bert_optim=bert_optim)
    else:
        ckpt = opt.load_ckpt

    acc = framework.eval(model,
                         batch_size,
                         N,
                         K,
                         Q,
                         opt.test_iter,
                         na_rate=opt.na_rate,
                         ckpt=ckpt,
                         pair=opt.pair)
    print("RESULT: %.2f" % (acc * 100))
Code example #7
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', default='train_wiki', help='train file')
    parser.add_argument('--val', default='val_wiki', help='val file')
    parser.add_argument('--test', default='test_wiki', help='test file')
    parser.add_argument('--adv', default=None, help='adv file')
    parser.add_argument('--trainN', default=10, type=int, help='N in train')
    parser.add_argument('--N', default=5, type=int, help='N way')
    parser.add_argument('--K', default=5, type=int, help='K shot')
    parser.add_argument('--Q',
                        default=5,
                        type=int,
                        help='Num of query per class')
    parser.add_argument('--batch_size', default=4, type=int, help='batch size')
    parser.add_argument('--train_iter',
                        default=20000,
                        type=int,
                        help='num of iters in training')
    parser.add_argument('--val_iter',
                        default=1000,
                        type=int,
                        help='num of iters in validation')
    parser.add_argument('--test_iter',
                        default=2000,
                        type=int,
                        help='num of iters in testing')
    parser.add_argument('--val_step',
                        default=2000,
                        type=int,
                        help='val after training how many iters')
    parser.add_argument('--model', default='proto', help='model name')
    parser.add_argument('--encoder',
                        default='cnn',
                        help='encoder: cnn or bert')
    parser.add_argument('--max_length',
                        default=128,
                        type=int,
                        help='max length')
    parser.add_argument('--lr', default=1e-1, type=float, help='learning rate')
    parser.add_argument('--weight_decay',
                        default=1e-5,
                        type=float,
                        help='weight decay')
    parser.add_argument('--dropout',
                        default=0.0,
                        type=float,
                        help='dropout rate')
    parser.add_argument('--na_rate',
                        default=0,
                        type=int,
                        help='NA rate (NA = Q * na_rate)')
    parser.add_argument('--grad_iter',
                        default=1,
                        type=int,
                        help='accumulate gradient every x iterations')
    parser.add_argument('--optim',
                        default='sgd',
                        help='sgd / adam / bert_adam')
    parser.add_argument('--hidden_size',
                        default=230,
                        type=int,
                        help='hidden size')
    parser.add_argument('--load_ckpt', default=None, help='load ckpt')
    parser.add_argument('--save_ckpt', default=None, help='save ckpt')
    parser.add_argument('--fp16',
                        action='store_true',
                        help='use nvidia apex fp16')
    parser.add_argument('--only_test', action='store_true', help='only test')
    parser.add_argument('--pair', action='store_true', help='use pair model')
    parser.add_argument('--language', type=str, default='eng', help='language')
    parser.add_argument('--sup_cost',
                        type=int,
                        default=0,
                        help='use sup classifier')

    opt = parser.parse_args()
    trainN = opt.trainN
    N = opt.N
    K = opt.K
    Q = opt.Q
    batch_size = opt.batch_size
    model_name = opt.model
    encoder_name = opt.encoder
    max_length = opt.max_length
    sup_cost = bool(opt.sup_cost)
    print(sup_cost)

    print("{}-way-{}-shot Few-Shot Relation Classification".format(N, K))
    print("model: {}".format(model_name))
    print("encoder: {}".format(encoder_name))
    print("max_length: {}".format(max_length))

    embsize = 50
    if opt.language == 'chn':
        embsize = 100

    if encoder_name == 'cnn':
        try:
            if opt.language == 'chn':
                glove_mat = np.load('./pretrain/chinese_emb/emb.npy')
                glove_word2id = json.load(
                    open('./pretrain/chinese_emb/word2id.json'))
            else:
                glove_mat = np.load('./pretrain/glove/glove_mat.npy')
                glove_word2id = json.load(
                    open('./pretrain/glove/glove_word2id.json'))
        except:
            raise Exception(
                "Cannot find glove files. Run glove/download_glove.sh to download glove files."
            )
        sentence_encoder = CNNSentenceEncoder(glove_mat,
                                              glove_word2id,
                                              max_length,
                                              word_embedding_dim=embsize)
    elif encoder_name == 'bert':
        if opt.pair:
            if opt.language == 'chn':
                sentence_encoder = BERTPAIRSentenceEncoder(
                    'bert-base-chinese',  #'./pretrain/bert-base-uncased',
                    max_length)
            else:
                sentence_encoder = BERTPAIRSentenceEncoder(
                    'bert-base-uncased', max_length)
        else:
            if opt.language == 'chn':
                sentence_encoder = BERTSentenceEncoder(
                    'bert-base-chinese',  #'./pretrain/bert-base-uncased',
                    max_length)
            else:
                sentence_encoder = BERTSentenceEncoder('bert-base-uncased',
                                                       max_length)
    else:
        raise NotImplementedError

    if opt.pair:
        train_data_loader = get_loader_pair(opt.train,
                                            sentence_encoder,
                                            N=trainN,
                                            K=K,
                                            Q=Q,
                                            na_rate=opt.na_rate,
                                            batch_size=batch_size)
        val_data_loader = get_loader_pair(opt.val,
                                          sentence_encoder,
                                          N=N,
                                          K=K,
                                          Q=Q,
                                          na_rate=opt.na_rate,
                                          batch_size=batch_size)
        test_data_loader = get_loader_pair(opt.test,
                                           sentence_encoder,
                                           N=N,
                                           K=K,
                                           Q=Q,
                                           na_rate=opt.na_rate,
                                           batch_size=batch_size)
    else:
        train_data_loader = get_loader(opt.train,
                                       sentence_encoder,
                                       N=trainN,
                                       K=K,
                                       Q=Q,
                                       na_rate=opt.na_rate,
                                       batch_size=batch_size)
        val_data_loader = get_loader(opt.val,
                                     sentence_encoder,
                                     N=N,
                                     K=K,
                                     Q=Q,
                                     na_rate=opt.na_rate,
                                     batch_size=batch_size)
        test_data_loader = get_loader(opt.test,
                                      sentence_encoder,
                                      N=N,
                                      K=K,
                                      Q=Q,
                                      na_rate=opt.na_rate,
                                      batch_size=batch_size)
        if opt.adv:
            adv_data_loader = get_loader_unsupervised(opt.adv,
                                                      sentence_encoder,
                                                      N=trainN,
                                                      K=K,
                                                      Q=Q,
                                                      na_rate=opt.na_rate,
                                                      batch_size=batch_size)

    if opt.optim == 'sgd':
        pytorch_optim = optim.SGD
    elif opt.optim == 'adam':
        pytorch_optim = optim.Adam
    elif opt.optim == 'bert_adam':
        from transformers import AdamW
        pytorch_optim = AdamW
    else:
        raise NotImplementedError
    if opt.adv:
        d = Discriminator(opt.hidden_size)
        framework = FewShotREFramework(train_data_loader,
                                       val_data_loader,
                                       test_data_loader,
                                       adv_data_loader,
                                       adv=opt.adv,
                                       d=d)
    else:
        framework = FewShotREFramework(train_data_loader, val_data_loader,
                                       test_data_loader)

    prefix = '-'.join(
        [model_name, encoder_name, opt.train, opt.val,
         str(N), str(K)])
    if opt.adv is not None:
        prefix += '-adv_' + opt.adv
    if opt.na_rate != 0:
        prefix += '-na{}'.format(opt.na_rate)

    if model_name == 'proto':
        model = Proto(sentence_encoder, hidden_size=opt.hidden_size)
    elif model_name == 'gnn':
        model = GNN(sentence_encoder, N, use_sup_cost=sup_cost)
    elif model_name == 'snail':
        print("HINT: SNAIL works only in PyTorch 0.3.1")
        model = SNAIL(sentence_encoder, N, K)
    elif model_name == 'metanet':
        model = MetaNet(N,
                        K,
                        sentence_encoder.embedding,
                        max_length,
                        use_sup_cost=sup_cost)
    elif model_name == 'siamese':
        model = Siamese(sentence_encoder,
                        hidden_size=opt.hidden_size,
                        dropout=opt.dropout)
    elif model_name == 'pair':
        model = Pair(sentence_encoder, hidden_size=opt.hidden_size)
    else:
        raise NotImplementedError

    if not os.path.exists('checkpoint'):
        os.mkdir('checkpoint')
    ckpt = 'checkpoint/{}.pth.tar'.format(prefix)
    if opt.save_ckpt:
        ckpt = opt.save_ckpt

    if torch.cuda.is_available():
        model.cuda()

    if not opt.only_test:
        if encoder_name == 'bert':
            bert_optim = True
        else:
            bert_optim = False
        framework.train(model,
                        prefix,
                        batch_size,
                        trainN,
                        N,
                        K,
                        Q,
                        pytorch_optim=pytorch_optim,
                        load_ckpt=opt.load_ckpt,
                        save_ckpt=ckpt,
                        na_rate=opt.na_rate,
                        val_step=opt.val_step,
                        fp16=opt.fp16,
                        pair=opt.pair,
                        train_iter=opt.train_iter,
                        val_iter=opt.val_iter,
                        bert_optim=bert_optim,
                        sup_cls=sup_cost)
    else:
        ckpt = opt.load_ckpt

    acc = framework.eval(model,
                         batch_size,
                         N,
                         K,
                         Q,
                         opt.test_iter,
                         na_rate=opt.na_rate,
                         ckpt=ckpt,
                         pair=opt.pair)
    wfile = open('logs/' + ckpt.replace('checkpoint/', '') + '.txt', 'a')
    wfile.write(str(N) + '\t' + str(K) + '\t' + str(acc * 100) + '\n')
    wfile.close()
    print("RESULT: %.2f" % (acc * 100))
Code example #8
def get_row(ticker, command):
    try:
        bring_if_connected(ticker)
        date = gcnv.data_handler.get_max_stored_date("stock", ticker)
        if date is None:
            return []
        date = util.date_in_string(date)
        back_days = core.safe_execute(gcnv.BACK_DAYS, ValueError,
                                      lambda x: int(x) * 30, command[2])

        iv = IV(ticker)
        hv = HV(ticker)
        mixed_vs = MixedVs(iv, hv)
        stock = Stock(ticker)
        spy_pair = Pair(ticker, "SPY")
        spy_iv = IV("SPY")
        earnings_data = load_earnings()
        row = [ticker, date]
        # Price related data
        row += [
            stock.get_close_at(
                date),  # GettingInfoError raised here if not stored data
            f"{stock.min(back_days)} - {stock.max(back_days)}",
            round(stock.min_max_rank(date, back_days)),
            stock.range(28) / spy_pair.stdev_ratio(back_days),
            stock.range(14) / spy_pair.stdev_ratio(back_days),
            stock.move(7) / spy_pair.stdev_ratio(back_days),
            stock.get_last_percentage_change(),
            up_down_closes_str(stock, 14),
            core.safe_execute('-', GettingInfoError, spy_pair.correlation,
                              back_days),
            spy_pair.stdev_ratio(back_days),
            stock.hv_to_10_ratio(back_days),
            round(
                notional.directional_stock_number(
                    stock.get_close_at(date),
                    spy_pair.stdev_ratio(back_days))),
            round(
                notional.neutral_options_number(
                    stock.get_close_at(date), spy_pair.stdev_ratio(back_days)),
                1),
            round(
                notional.directional_options_number(
                    stock.get_close_at(date), spy_pair.stdev_ratio(back_days)),
                1),
            earnings_data[ticker][0],
            earnings_data[ticker][1],
            chart_link(ticker)
        ]
        # Volatility related data
        try:
            row += [
                iv.current_to_average_ratio(date, back_days),
                mixed_vs.iv_current_to_hv_average(date, back_days),
                mixed_vs.positive_difference_ratio(back_days),
                mixed_vs.difference_average(back_days),
                iv.current_percentile_iv_rank(back_days)
            ]
            row += iv.period_iv_ranks(back_days, max_results=gcnv.IVR_RESULTS)
        except (GettingInfoError, ZeroDivisionError,
                statistics.StatisticsError):
            result_row_len = 5  # number of volatility values appended above
            row += ['-'] * (result_row_len + gcnv.IVR_RESULTS)
        return row
    except (GettingInfoError, InputError, ZeroDivisionError,
            statistics.StatisticsError) as e:
        print(e)
        return []
Code example #9
def validate(req_id, ena_dir):
    """
    Validates fastq files for a seq submission. This method is called by the Django API.
    It create a new DB record for the validation job that can be retrieved by the calling the check endpoint.
    
    Jobs statuses are:
        - P => Pending (used when the job is still running, or execution errors appeared)
        - F ==> Failed
        - V ==> Valid
     
    :param req_id: The request ID used by the client as a unique identifier for their job.
    :type req_id: str
    :param ena_dir: The directory on ENA machine that containing the datafiles and the SDRF.
    :type ena_dir: str

    """
    report = {
        'file_errors': {},
        'pairs_errors': [],
        'valid_files': [],
        'execution_errors': [],
        'integrity_errors': []
    }
    v = Validate.objects.filter(job_id=req_id)
    if not v:
        v = Validate(job_id=str(req_id), data_dir=ena_dir)
        v.save()
    else:
        v = v[0]
    dir_name = ena_dir.split('/')[-1]
    if ena_dir.endswith('/'):
        dir_name = ena_dir.split('/')[-2]
    print(ena_dir)
    print(dir_name)
    try:
        local_dir = os.path.join(TEMP_FOLDER, str(req_id) + dir_name)
        if os.path.exists(local_dir):
            shutil.rmtree(local_dir)
        if not ena_dir.startswith(ENA_DIR):
            ena_dir = os.path.join(ENA_DIR, ena_dir)
        out, err = copy_files(ena_dir, local_dir)
        print(out)
        print(err)
        if err:
            report['execution_errors'].append(err)
        sdrf_file = ''
        data_files = []
        pairs = []

        for f in os.listdir(local_dir):
            if f.endswith('.sdrf.txt'):
                sdrf_file = os.path.join(local_dir, f)
                break
        try:
            sdrf = SdrfCollection(sdrf_file)
        except Exception as e:
            report['integrity_errors'].append(str(e))
            v.status = 'F'
            v.validation_report = json.dumps(report)
            v.save()
            return

        for i in range(len(sdrf.rows)):
            r = sdrf.rows[i]
            if r.is_paired:
                continue
            print(colored.yellow(
                str(
                    dict(out_file=os.path.join(local_dir, str(i + 1)),
                         name=str(i + 1),
                         file_name=r.data_file,
                         base_dir=local_dir,
                         ena_dir=ena_dir))))
            data_file = FileObject(out_file=os.path.join(local_dir, str(i + 1)),
                                   name=str(i + 1),
                                   file_name=r.data_file,
                                   base_dir=local_dir,
                                   ena_dir=ena_dir)
            data_file.start()
            data_files.append(data_file)

        for p1, p2 in sdrf.pairs:
            p = Pair(p1.data_file, p2.data_file, local_dir, ena_dir)
            p.run()

            pairs.append(p)

        live = True
        while live:
            time.sleep(10)
            p_live = False
            f_live = False
            for p in pairs:
                if p.is_alive():
                    p_live = True
                    break
            for f in data_files:
                if f.is_alive():
                    f_live = True
                    break
            live = f_live or p_live
        for p in pairs:
            if not p.errors:
                if p.file_1.errors:
                    report['file_errors'][p.file_1.file_name] = p.file_1.errors
                else:
                    report['valid_files'].append(p.file_1.file_name)
                if p.file_2.errors:
                    report['file_errors'][p.file_2.file_name] = p.file_2.errors

                else:
                    report['valid_files'].append(p.file_2.file_name)

                if p.file_1.execution_error:
                    report['execution_errors'].append(p.file_1.execution_error)
                if p.file_2.execution_error:
                    report['execution_errors'].append(p.file_2.execution_error)

            report['pairs_errors'] += p.errors

        for data_file in data_files:
            if data_file.errors:
                report['file_errors'][data_file.file_name] = data_file.errors
            else:
                report['valid_files'].append(data_file.file_name)
            if data_file.execution_error:
                report['execution_errors'].append(data_file.execution_error)
                shutil.rmtree(local_dir)
Code example #10
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', default='train_wiki', help='train file')
    parser.add_argument('--val', default='val_wiki', help='val file')
    parser.add_argument('--test', default='test_wiki', help='test file')
    parser.add_argument('--adv', default=None, help='adv file')
    parser.add_argument('--trainN', default=10, type=int, help='N in train')
    parser.add_argument('--N', default=5, type=int, help='N way')
    parser.add_argument('--K', default=5, type=int, help='K shot')
    parser.add_argument('--Q',
                        default=5,
                        type=int,
                        help='Num of query per class')
    parser.add_argument('--batch_size', default=4, type=int, help='batch size')
    parser.add_argument('--train_iter',
                        default=30000,
                        type=int,
                        help='num of iters in training')
    parser.add_argument('--val_iter',
                        default=1000,
                        type=int,
                        help='num of iters in validation')
    parser.add_argument('--test_iter',
                        default=3000,
                        type=int,
                        help='num of iters in testing')
    parser.add_argument('--val_step',
                        default=2000,
                        type=int,
                        help='val after training how many iters')
    parser.add_argument('--model', default='proto', help='model name')
    parser.add_argument('--encoder',
                        default='cnn',
                        help='encoder: cnn or bert')
    parser.add_argument('--max_length',
                        default=128,
                        type=int,
                        help='max length')
    parser.add_argument('--lr', default=1e-1, type=float, help='learning rate')
    parser.add_argument('--weight_decay',
                        default=1e-5,
                        type=float,
                        help='weight decay')
    parser.add_argument('--dropout',
                        default=0.0,
                        type=float,
                        help='dropout rate')
    parser.add_argument('--na_rate',
                        default=0,
                        type=int,
                        help='NA rate (NA = Q * na_rate)')
    parser.add_argument('--grad_iter',
                        default=1,
                        type=int,
                        help='accumulate gradient every x iterations')
    parser.add_argument('--optim',
                        default='sgd',
                        help='sgd / adam / bert_adam')
    parser.add_argument('--hidden_size',
                        default=230,
                        type=int,
                        help='hidden size')
    parser.add_argument('--load_ckpt', default=None, help='load ckpt')
    parser.add_argument('--save_ckpt', default=None, help='save ckpt')
    parser.add_argument('--fp16',
                        action='store_true',
                        help='use nvidia apex fp16')
    parser.add_argument('--only_test', action='store_true', help='only test')
    parser.add_argument('--pair', action='store_true', help='use pair model')

    opt = parser.parse_args()
    trainN = opt.trainN
    N = opt.N
    K = opt.K
    Q = opt.Q
    batch_size = opt.batch_size
    model_name = opt.model
    encoder_name = opt.encoder
    max_length = opt.max_length

    print("{}-way-{}-shot Few-Shot Relation Classification".format(N, K))
    print("model: {}".format(model_name))
    print("encoder: {}".format(encoder_name))
    print("max_length: {}".format(max_length))

    if encoder_name == 'cnn':
        try:
            glove_mat = np.load('./pretrain/glove/glove_mat.npy')
            glove_word2id = json.load(
                open('./pretrain/glove/glove_word2id.json'))
        except:
            raise Exception(
                "Cannot find glove files. Run glove/download_glove.sh to download glove files."
            )
        sentence_encoder = CNNSentenceEncoder(glove_mat, glove_word2id,
                                              max_length)
    elif encoder_name == 'bert':
        if opt.pair:
            sentence_encoder = BERTPAIRSentenceEncoder(
                './pretrain/bert-base-uncased', max_length)
        else:
            sentence_encoder = BERTSentenceEncoder(
                './pretrain/bert-base-uncased', max_length)
    else:
        raise NotImplementedError

    if opt.pair:
        train_data_loader = get_loader_pair(opt.train,
                                            sentence_encoder,
                                            N=trainN,
                                            K=K,
                                            Q=Q,
                                            na_rate=opt.na_rate,
                                            batch_size=batch_size)
        val_data_loader = get_loader_pair(opt.val,
                                          sentence_encoder,
                                          N=N,
                                          K=K,
                                          Q=Q,
                                          na_rate=opt.na_rate,
                                          batch_size=batch_size)
        test_data_loader = get_loader_pair(opt.test,
                                           sentence_encoder,
                                           N=N,
                                           K=K,
                                           Q=Q,
                                           na_rate=opt.na_rate,
                                           batch_size=batch_size)
    else:
        train_data_loader = get_loader(opt.train,
                                       sentence_encoder,
                                       N=trainN,
                                       K=K,
                                       Q=Q,
                                       na_rate=opt.na_rate,
                                       batch_size=batch_size)
        val_data_loader = get_loader(opt.val,
                                     sentence_encoder,
                                     N=N,
                                     K=K,
                                     Q=Q,
                                     na_rate=opt.na_rate,
                                     batch_size=batch_size)
        test_data_loader = get_loader(opt.test,
                                      sentence_encoder,
                                      N=N,
                                      K=K,
                                      Q=Q,
                                      na_rate=opt.na_rate,
                                      batch_size=batch_size)
        if opt.adv:
            adv_data_loader = get_loader_unsupervised(opt.adv,
                                                      sentence_encoder,
                                                      N=trainN,
                                                      K=K,
                                                      Q=Q,
                                                      na_rate=opt.na_rate,
                                                      batch_size=batch_size)

    if opt.optim == 'sgd':
        pytorch_optim = optim.SGD
    elif opt.optim == 'adam':
        pytorch_optim = optim.Adam
    elif opt.optim == 'bert_adam':
        from pytorch_transformers import AdamW
        pytorch_optim = AdamW
    else:
        raise NotImplementedError
    if opt.adv:
        d = Discriminator(opt.hidden_size)
        framework = FewShotREFramework(train_data_loader,
                                       val_data_loader,
                                       test_data_loader,
                                       adv_data_loader,
                                       adv=opt.adv,
                                       d=d)
    else:
        framework = FewShotREFramework(train_data_loader, val_data_loader,
                                       test_data_loader)

    prefix = '-'.join(
        [model_name, encoder_name, opt.train, opt.val,
         str(N), str(K)])
    if opt.adv is not None:
        prefix += '-adv_' + opt.adv
    if opt.na_rate != 0:
        prefix += '-na{}'.format(opt.na_rate)

    if model_name == 'proto':
        model = Proto(sentence_encoder, hidden_size=opt.hidden_size)
    elif model_name == 'gnn':
        model = GNN(sentence_encoder, N)
    elif model_name == 'snail':
        print("HINT: SNAIL works only in PyTorch 0.3.1")
        model = SNAIL(sentence_encoder, N, K)
    elif model_name == 'metanet':
        model = MetaNet(N, K, sentence_encoder.embedding, max_length)
    elif model_name == 'siamese':
        model = Siamese(sentence_encoder,
                        hidden_size=opt.hidden_size,
                        dropout=opt.dropout)
    elif model_name == 'pair':
        model = Pair(sentence_encoder, hidden_size=opt.hidden_size)
    else:
        raise NotImplementedError

    if not os.path.exists('checkpoint'):
        os.mkdir('checkpoint')
    ckpt = 'checkpoint/{}.pth.tar'.format(prefix)
    if opt.save_ckpt:
        ckpt = opt.save_ckpt

    if torch.cuda.is_available():
        model.cuda()

    if not opt.only_test:
        if encoder_name == 'bert':
            bert_optim = True
        else:
            bert_optim = False

        framework.train(model,
                        prefix,
                        batch_size,
                        trainN,
                        N,
                        K,
                        Q,
                        pytorch_optim=pytorch_optim,
                        load_ckpt=opt.load_ckpt,
                        save_ckpt=ckpt,
                        na_rate=opt.na_rate,
                        val_step=opt.val_step,
                        fp16=opt.fp16,
                        pair=opt.pair,
                        train_iter=opt.train_iter,
                        val_iter=opt.val_iter,
                        bert_optim=bert_optim)
    else:
        ckpt = opt.load_ckpt

    acc = 0
    his_acc = []
    total_test_round = 5
    for i in range(total_test_round):
        cur_acc = framework.eval(model,
                                 batch_size,
                                 N,
                                 K,
                                 Q,
                                 opt.test_iter,
                                 na_rate=opt.na_rate,
                                 ckpt=ckpt,
                                 pair=opt.pair)
        his_acc.append(cur_acc)
        acc += cur_acc
    acc /= total_test_round
    nhis_acc = np.array(his_acc)
    error = nhis_acc.std() * 1.96 / (nhis_acc.shape[0]**0.5)
    print("RESULT: %.2f\\pm%.2f" % (acc * 100, error * 100))

    result_file = open('./result.txt', 'a+')
    result_file.write(
        "test data: %12s | model: %45s | acc: %.6f | error: %.6f\n" %
        (opt.test, prefix, acc, error))
    result_file.close()

    result_file = open('./result_detail.txt', 'a+')
    result_detail = {
        'test': opt.test,
        'model': prefix,
        'acc': acc,
        'his': his_acc
    }
    result_file.write("%s\n" % (json.dumps(result_detail)))
    result_file.close()
Code example #11
class Detector:
    # Runs on the GPU if available and PyTorch has CUDA support

    def __init__(self, chpt_path, max_length=128):
        """
            Initializer
        """
        self.checkpoint_path = chpt_path
        self.bert_pretrained_checkpoint = 'bert-base-uncased'
        self.max_length = max_length
        self.sentence_encoder = BERTPAIRSentenceEncoder(
            self.bert_pretrained_checkpoint, self.max_length)

        self.model = Pair(self.sentence_encoder, hidden_size=768)
        if torch.cuda.is_available():
            self.model = self.model.cuda()
        self.model.eval()

        #         self.nlp_coref = spacy.load("en_core_web_sm")
        #         neuralcoref.add_to_pipe(self.nlp_coref)
        self.nlp_no_coref = spacy.load("en_core_web_sm")
        self.load_model()

    def __load_model_from_checkpoint__(self, ckpt):
        '''
        ckpt: Path of the checkpoint
        return: Checkpoint dict
        '''
        if os.path.isfile(ckpt):
            checkpoint = torch.load(ckpt)
            print("Successfully loaded checkpoint '%s'" % ckpt)
            return checkpoint
        else:
            raise Exception("No checkpoint found at '%s'" % ckpt)

    def bert_tokenize(self, tokens, head_indices, tail_indices):
        word = self.sentence_encoder.tokenize(tokens, head_indices,
                                              tail_indices)
        return word

    def load_model(self):
        """
            Loads the model from the checkpoint
        """
        state_dict = self.__load_model_from_checkpoint__(
            self.checkpoint_path)['state_dict']
        own_state = self.model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)

#     def spacy_tokenize_coref(self,sentence):
#         """
#             Tokenizes the sentence using spacy
#         """
#         return list(map(str, self.nlp_coref(sentence)))

    def spacy_tokenize_no_coref(self, sentence):
        """
            Tokenizes the sentence using spacy
        """
        try:
            return list(map(str, self.nlp_no_coref(sentence)))
        except TypeError as e:
            print("problem sentence: '{}'".format(sentence))
            raise e

#     def get_head_tail_pairs(self,sentence):
#         """
#             Gets pairs of heads and tails of named entities so that relation identification can be done on these.
#         """
#         acceptable_entity_types = ['PERSON', 'NORP', 'ORG', 'GPE', 'PRODUCT', 'EVENT', 'LAW', 'LOC', 'FAC']
#         doc = self.nlp_coref(sentence)
#         entity_info = [(X.text, X.label_) for X in doc.ents]
#         entity_info = set(map(lambda x:x[0], filter(lambda x:x[1] in acceptable_entity_types, entity_info)))

#         return combinations(entity_info, 2)

    def _get_indices_alt(self, tokens, tokenized_head, tokenized_tail):
        """
            Alternative implementation for getting the indices of the head and tail when exact matches cannot be found.
        """
        head_indices = None
        tail_indices = None
        for i in range(len(tokens)):
            if tokens[i] in tokenized_head[0] or tokenized_head[0] in tokens[i]:
                broke = False
                for k, j in zip(tokens[i:i + len(tokenized_head)],
                                tokenized_head):
                    if k not in j and j not in k:
                        broke = True
                        break
                if not broke:
                    head_indices = list(range(i, i + len(tokenized_head)))
                    break
        for i in range(len(tokens)):
            if tokens[i] in tokenized_tail[0] or tokenized_tail[0] in tokens[i]:
                broke = False
                for k, j in zip(tokens[i:i + len(tokenized_tail)],
                                tokenized_tail):
                    if k not in j and j not in k:
                        broke = True
                        break
                if not broke:
                    tail_indices = list(range(i, i + len(tokenized_tail)))
                    break
        return head_indices, tail_indices

    def _calculate_conf(self, logits, order, pred):
        exp = list(float(i) for i in logits[0][0])
        exp = [math.exp(i) for i in exp]
        if pred == 'NA':
            return exp[-1] * 100 / sum(exp)
        return exp[order.index(pred)] * 100 / sum(exp)

    def run_detection_algorithm(self, query, relation_data):
        """
            Runs the algorithm/model on the given query using the given support data.
        """
        N = len(relation_data)
        K = len(relation_data[0]['examples'])
        Q = 1
        head = query['head']
        tail = query['tail']
        fusion_set = {'word': [], 'mask': [], 'seg': []}
        tokens = self.spacy_tokenize_no_coref(query['sentence'])

        print("head: '{}' tail: '{}' sentence: '{}'".format(
            head, tail, query['sentence']))

        tokenized_head = self.spacy_tokenize_no_coref(head)
        tokenized_tail = self.spacy_tokenize_no_coref(tail)

        head_indices = None
        tail_indices = None
        for i in range(len(tokens)):
            if tokens[i] == tokenized_head[0] and tokens[
                    i:i + len(tokenized_head)] == tokenized_head:
                head_indices = list(range(i, i + len(tokenized_head)))
                break
        for i in range(len(tokens)):
            if tokens[i] == tokenized_tail[0] and tokens[
                    i:i + len(tokenized_tail)] == tokenized_tail:
                tail_indices = list(range(i, i + len(tokenized_tail)))
                break

        if head_indices is None or tail_indices is None:
            head_indices, tail_indices = self._get_indices_alt(
                tokens, tokenized_head, tokenized_tail)

        if head_indices is None or tail_indices is None:
            print(tokenized_head)
            print(tokenized_tail)
            print(tokens)
            raise ValueError(
                "Head/Tail indices error: head: {} \n tail: {} \n sentence: {}"
                .format(head, tail, query['sentence']))

        bert_query_tokens = self.bert_tokenize(tokens, head_indices,
                                               tail_indices)
        for relation in relation_data:
            for ex in relation['examples']:
                tokens = self.spacy_tokenize_no_coref(ex['sentence'])
                tokenized_head = self.spacy_tokenize_no_coref(
                    ex['head']
                )  #head and tail spelling and punctuation should match the corefered output exactly
                tokenized_tail = self.spacy_tokenize_no_coref(ex['tail'])

                head_indices = None
                tail_indices = None
                for i in range(len(tokens)):
                    if tokens[i] == tokenized_head[0] and tokens[
                            i:i + len(tokenized_head)] == tokenized_head:
                        head_indices = list(range(i, i + len(tokenized_head)))
                        break
                for i in range(len(tokens)):
                    if tokens[i] == tokenized_tail[0] and tokens[
                            i:i + len(tokenized_tail)] == tokenized_tail:
                        tail_indices = list(range(i, i + len(tokenized_tail)))
                        break
                if head_indices is None or tail_indices is None:
                    raise ValueError(
                        "Head/Tail indices error: head: {} \n tail: {} \n sentence: {}"
                        .format(ex['head'], ex['tail'], ex['sentence']))

                bert_relation_example_tokens = self.bert_tokenize(
                    tokens, head_indices, tail_indices)

                SEP = self.sentence_encoder.tokenizer.convert_tokens_to_ids(
                    ['[SEP]'])
                CLS = self.sentence_encoder.tokenizer.convert_tokens_to_ids(
                    ['[CLS]'])
                word_tensor = torch.zeros((self.max_length)).long()

                new_word = CLS + bert_relation_example_tokens + SEP + bert_query_tokens + SEP
                for i in range(min(self.max_length, len(new_word))):
                    word_tensor[i] = new_word[i]
                mask_tensor = torch.zeros((self.max_length)).long()
                mask_tensor[:min(self.max_length, len(new_word))] = 1
                seg_tensor = torch.ones((self.max_length)).long()
                seg_tensor[:min(self.max_length,
                                len(bert_relation_example_tokens) + 1)] = 0
                fusion_set['word'].append(word_tensor)
                fusion_set['mask'].append(mask_tensor)
                fusion_set['seg'].append(seg_tensor)

        fusion_set['word'] = torch.stack(fusion_set['word'])
        fusion_set['seg'] = torch.stack(fusion_set['seg'])
        fusion_set['mask'] = torch.stack(fusion_set['mask'])

        if torch.cuda.is_available():
            fusion_set['word'] = fusion_set['word'].cuda()
            fusion_set['seg'] = fusion_set['seg'].cuda()
            fusion_set['mask'] = fusion_set['mask'].cuda()

        logits, pred = self.model(fusion_set, N, K, Q)
        gc.collect()
        order = list(r['name'] for r in relation_data)
        pred_relation = relation_data[
            pred.item()]['name'] if pred.item() < len(relation_data) else 'NA'
        return {
            'sentence': query['sentence'],
            'head': head,
            'tail': tail,
            'pred_relation': pred_relation,
            'conf': int(self._calculate_conf(logits, order, pred_relation))
        }  #returns (sentence, head, tail, prediction relation name)

    def print_result(self, sentence, head, tail, prediction):
        """
            Helper function to print the results to the stdout.
        """
        print('Sentence: \"{}\", head: \"{}\", tail: \"{}\", prediction: {}'.
              format(sentence, head, tail, prediction))
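
A usage sketch inferred from the fields run_detection_algorithm reads; the checkpoint path, relation names, and sentences below are hypothetical.

# Hypothetical data shaped the way run_detection_algorithm expects it:
# each relation has a 'name' and K example dicts with 'sentence', 'head', 'tail'.
detector = Detector("checkpoints/pair_model.pth.tar")  # path is an assumption
relation_data = [
    {"name": "founded_by",
     "examples": [{"sentence": "Apple was founded by Steve Jobs.",
                   "head": "Apple", "tail": "Steve Jobs"}]},
    {"name": "capital_of",
     "examples": [{"sentence": "Paris is the capital of France.",
                   "head": "Paris", "tail": "France"}]},
]
query = {"sentence": "Microsoft was founded by Bill Gates.",
         "head": "Microsoft", "tail": "Bill Gates"}
result = detector.run_detection_algorithm(query, relation_data)
detector.print_result(result['sentence'], result['head'],
                      result['tail'], result['pred_relation'])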