Code example #1
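Evaluation-time decoder for a Bert2Gram model: it iterates over the data loader, refactors each batch into model inputs, scores it with model.test_bert2gram, converts the n-gram logits into phrase candidates via generator.gram2phrase, and finally arranges all predictions with pred_arranger.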
import logging
import traceback

from tqdm import tqdm

# `utils`, `generator` and `Decode_Candidate_Number` are assumed to be provided
# by the surrounding project (a timer helper, the phrase decoders, and the
# per-dataset number of candidates to return).


def bert2gram_decoder(args, data_loader, dataset, model, test_input_refactor,
                      pred_arranger, mode):
    logging.info(' Bert2Gram : Start Generating Keyphrases for %s ...' % mode)
    test_time = utils.Timer()

    tot_examples = 0
    tot_predictions = []

    for step, batch in enumerate(tqdm(data_loader)):

        # split the batch into model inputs, example indices and sequence lengths
        inputs, indices, lengths = test_input_refactor(batch,
                                                       model.args.device)
        try:
            logit_lists = model.test_bert2gram(inputs, lengths,
                                               args.max_phrase_words)
        except Exception:
            # log the full traceback and skip the batch that failed to decode
            logging.error(traceback.format_exc())
            continue

        # decode n-gram logits into phrase candidates for this batch
        params = {
            'examples': dataset.examples,
            'logit_lists': logit_lists,
            'indices': indices,
            'max_phrase_words': args.max_phrase_words,
            'return_num': Decode_Candidate_Number[args.dataset_class],
            'stem_flag': False
        }

        batch_predictions = generator.gram2phrase(**params)
        tot_predictions.extend(batch_predictions)

    candidate = pred_arranger(tot_predictions)
    return candidate
Code example #2
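Evaluation-time decoder for a Bert2Tag model: the same batch loop as Example #1, except the logits come from model.test_bert2tag and are converted into phrases by generator.tag2phrase using the configured tag pooling; stemming is enabled for the kp20k dataset via stem_flag.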
import logging
import traceback

from tqdm import tqdm

# As in Example #1, `utils`, `generator` and `Decode_Candidate_Number` are
# assumed to come from the surrounding project.


def bert2tag_decoder(
    args,
    data_loader,
    dataset,
    model,
    test_input_refactor,
    pred_arranger,
    mode,
    stem_flag=False,
):
    logging.info("Start Generating Keyphrases for %s ... \n" % mode)
    test_time = utils.Timer()
    if args.dataset_class == "kp20k":
        stem_flag = True

    tot_examples = 0
    tot_predictions = []
    for step, batch in enumerate(tqdm(data_loader)):
        # split the batch into model inputs, example indices and sequence lengths
        inputs, indices, lengths = test_input_refactor(batch, model.args.device)
        try:
            logit_lists = model.test_bert2tag(inputs, lengths)
        except Exception:
            # log the full traceback and skip the batch that failed to decode
            logging.error(traceback.format_exc())
            continue

        # decode tag logits into phrase candidates for this batch
        params = {
            "examples": dataset.examples,
            "logit_lists": logit_lists,
            "indices": indices,
            "max_phrase_words": args.max_phrase_words,
            "pooling": args.tag_pooling,
            "return_num": Decode_Candidate_Number[args.dataset_class],
            "stem_flag": stem_flag,
        }

        batch_predictions = generator.tag2phrase(**params)
        tot_predictions.extend(batch_predictions)

    candidate = pred_arranger(tot_predictions)
    return candidate
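
Both examples follow the same batch-decode pattern: refactor the batch, run the model, convert the logits into phrases, skip any batch that raises, and finally hand every prediction to pred_arranger. The sketch below distills that pattern into a small runnable skeleton; all names in it (decode_loop, toy_refactor, toy_forward, toy_decode) are hypothetical stand-ins for the project callbacks, not part of the original code.

import logging
import traceback

from tqdm import tqdm


def decode_loop(data_loader, refactor, forward, decode, arranger, device="cpu"):
    # shared skeleton: refactor batch -> model forward -> decode -> arrange
    tot_predictions = []
    for batch in tqdm(data_loader):
        inputs, indices, lengths = refactor(batch, device)
        try:
            logit_lists = forward(inputs, lengths)
        except Exception:
            # mirror the examples above: log the traceback, skip the batch
            logging.error(traceback.format_exc())
            continue
        tot_predictions.extend(decode(logit_lists, indices))
    return arranger(tot_predictions)


if __name__ == "__main__":
    # toy stand-ins so the skeleton runs end to end without a real model
    batches = [(["doc-%d" % i], [i], [3]) for i in range(2)]
    toy_refactor = lambda batch, device: batch            # batch is already (inputs, indices, lengths)
    toy_forward = lambda inputs, lengths: [[0.9, 0.1]] * len(inputs)
    toy_decode = lambda logits, indices: list(zip(indices, logits))
    print(decode_loop(batches, toy_refactor, toy_forward, toy_decode, dict))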