Example #1
0
import os

from nltk.translate.bleu_score import SmoothingFunction
from scipy import stats

import utils


def main(args):
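    """Compute BLEU and self-BLEU (NIST smoothing) for the generated text and report their harmonic mean."""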
    print(args)
    print("Started experiment!")
    utils.print_args(args)
    utils.set_seed(args.seed)

    smoothing_method = {"nist": SmoothingFunction().method3}
    results = {}
    scores = {}
    for name, method in smoothing_method.items():
        scores[name] = utils.evaluate_bleu(
            args.input_file,
            os.path.join("data", "valid.txt"),
            num_real_sentences=args.num_sentences,
            num_generated_sentences=args.num_sentences,
            gram=args.gram,
            smoothing_method=method,
            chunk_size=15)
        print()

    for name in smoothing_method.keys():
        results[name] = {}
        results[name]['scores'] = scores[name]

    print("Results:", results)
    bleu = results['nist']['scores']['bleu5']
    sbleu = results['nist']['scores']['self-bleu5']
    hmean = stats.hmean([bleu, 1.0 / sbleu])
    print("Harmonic Mean:", hmean)
Example #2
0
def main():
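    """Train an attention-based EN->JA seq2seq model, run inference on sample inputs, and report BLEU."""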
    # Set the hyperparameters.
    batch_size = 32
    epochs = 100
    model_path = 'models/attention_model.h5'
    enc_arch = 'models/encoder.json'
    dec_arch = 'models/decoder.json'
    data_path = 'data/jpn.txt'
    num_words = 10000
    num_data = 20000

    # Load the dataset.
    en_texts, ja_texts = load_dataset(data_path)
    en_texts, ja_texts = en_texts[:num_data], ja_texts[:num_data]

    # Preprocess the dataset.
    ja_texts = preprocess_ja(ja_texts)
    ja_texts = preprocess_dataset(ja_texts)
    en_texts = preprocess_dataset(en_texts)
    x_train, x_test, y_train, y_test = train_test_split(en_texts,
                                                        ja_texts,
                                                        test_size=0.2,
                                                        random_state=42)

    en_vocab = build_vocabulary(x_train, num_words)
    ja_vocab = build_vocabulary(y_train, num_words)
    x_train, y_train = create_dataset(x_train, y_train, en_vocab, ja_vocab)

    # Build the model.
    encoder = Encoder(num_words, return_sequences=True)
    decoder = AttentionDecoder(num_words)
    seq2seq = Seq2seq(encoder, decoder)
    model = seq2seq.build()
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

    # Prepare the callbacks.
    callbacks = [
        EarlyStopping(patience=3),
        ModelCheckpoint(model_path,
                        save_best_only=True,
                        save_weights_only=True)
    ]

    # Train the model.
    model.fit(x=x_train,
              y=y_train,
              batch_size=batch_size,
              epochs=epochs,
              callbacks=callbacks,
              validation_split=0.1)
    encoder.save_as_json(enc_arch)
    decoder.save_as_json(dec_arch)

    # Inference.
    encoder = Encoder.load(enc_arch, model_path)
    decoder = Decoder.load(dec_arch, model_path)
    api = InferenceAPIforAttention(encoder, decoder, en_vocab, ja_vocab)
    texts = sorted(set(en_texts[:50]), key=len)
    for text in texts:
        decoded = api.predict(text=text)
        print('English : {}'.format(text))
        print('Japanese: {}'.format(decoded))

    # Evaluate performance.
    y_test = [y.split(' ')[1:-1] for y in y_test]
    bleu_score = evaluate_bleu(x_test, y_test, api)
    print('BLEU: {}'.format(bleu_score))
def evaluation(args):
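    """Reload a trained model and its saved fields, decode the Opta test sets (full and length-limited), and write hypotheses and references to disk."""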
    source = pickle_load(os.path.join(args.model_path, 'source.pkl'))
    target = pickle_load(os.path.join(args.model_path, 'target.pkl'))
    target_test = pickle_load(os.path.join(args.model_path, 'target_test.pkl'))
    setting = load_setting(os.path.join(args.model_path, 'setting.yaml'))
    start_id, end_id = setting['start_id'], setting['end_id']
    type_size = setting['type_size']
    player_size = setting['player_size']
    team_size = setting['team_size']
    detail_size = setting['detail_size']
    detail_dim = setting['detail_dim']
    src_embed = setting['src_embed']
    event_size = setting['event_size']
    vocab_size = setting['vocab_size']
    trg_embed = setting['trg_embed']
    hidden = setting['hidden']
    class_weight = None
    mlp_layers = setting['mlp_layers']
    max_length = setting['max_length']
    dropout = setting['dropout']
    loss_weight = None
    disc_loss = setting['disc_loss']
    loss_func = setting['loss_func']
    net = setting['net']
    dataset = setting['dataset']
    numbering = setting['numbering']
    reverse_decode = setting['reverse_decode']
    home_player_tag = target.word_to_id.get(target.home_player_tag)
    away_player_tag = target.word_to_id.get(target.away_player_tag)
    home_team_tag = target.word_to_id.get(target.home_team_tag)
    away_team_tag = target.word_to_id.get(target.away_team_tag)
    test = OptaDataset(path=dataset + '.test',
                       fields={
                           'source': source,
                           'target': target_test
                       })
    test20 = OptaDataset(path=dataset + '.test',
                         fields={
                             'source': source,
                             'target': target_test
                         },
                         limit_length=20)
    test15 = OptaDataset(path=dataset + '.test',
                         fields={
                             'source': source,
                             'target': target_test
                         },
                         limit_length=15)
    test10 = OptaDataset(path=dataset + '.test',
                         fields={
                             'source': source,
                             'target': target_test
                         },
                         limit_length=10)

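    # Discriminative ('disc') variants also need the content-word vocabulary size.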
    if 'disc' in net:
        content_word_size = len(target.content_word_to_id)
    print('vocab size: {}'.format(vocab_size))
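    # Instantiate the network variant recorded in the saved settings.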
    if net == 'plain':
        model = MLPEncoder2AttentionDecoder(type_size,
                                            player_size,
                                            team_size,
                                            detail_size,
                                            detail_dim,
                                            src_embed,
                                            event_size,
                                            vocab_size,
                                            trg_embed,
                                            hidden,
                                            start_id,
                                            end_id,
                                            class_weight,
                                            mlp_layers,
                                            max_length,
                                            dropout,
                                            IGNORE_LABEL,
                                            reverse_decode=reverse_decode)
    elif net == 'tmpl':
        model = MLPEncoder2AttentionDecoder(type_size,
                                            player_size,
                                            team_size,
                                            detail_size,
                                            detail_dim,
                                            src_embed,
                                            event_size,
                                            vocab_size,
                                            trg_embed,
                                            hidden,
                                            start_id,
                                            end_id,
                                            class_weight,
                                            mlp_layers,
                                            max_length,
                                            dropout,
                                            IGNORE_LABEL,
                                            source.id_to_player,
                                            home_player_tag,
                                            away_player_tag,
                                            source.id_to_team,
                                            home_team_tag,
                                            away_team_tag,
                                            target.player_to_id,
                                            target.players,
                                            reverse_decode=reverse_decode)
    elif net == 'gate':
        model = MLPEncoder2GatedAttentionDecoder(type_size,
                                                 player_size,
                                                 team_size,
                                                 detail_size,
                                                 detail_dim,
                                                 src_embed,
                                                 event_size,
                                                 vocab_size,
                                                 trg_embed,
                                                 hidden,
                                                 start_id,
                                                 end_id,
                                                 class_weight,
                                                 mlp_layers,
                                                 max_length,
                                                 dropout,
                                                 IGNORE_LABEL,
                                                 reverse_decode=reverse_decode)
    elif net == 'gate-tmpl':
        model = MLPEncoder2GatedAttentionDecoder(type_size,
                                                 player_size,
                                                 team_size,
                                                 detail_size,
                                                 detail_dim,
                                                 src_embed,
                                                 event_size,
                                                 vocab_size,
                                                 trg_embed,
                                                 hidden,
                                                 start_id,
                                                 end_id,
                                                 class_weight,
                                                 mlp_layers,
                                                 max_length,
                                                 dropout,
                                                 IGNORE_LABEL,
                                                 source.id_to_player,
                                                 home_player_tag,
                                                 away_player_tag,
                                                 source.id_to_team,
                                                 home_team_tag,
                                                 away_team_tag,
                                                 target.player_to_id,
                                                 target.players,
                                                 reverse_decode=reverse_decode)
    elif net == 'disc':
        model = DiscriminativeMLPEncoder2AttentionDecoder(
            type_size,
            player_size,
            team_size,
            detail_size,
            detail_dim,
            src_embed,
            event_size,
            vocab_size,
            content_word_size,
            trg_embed,
            hidden,
            start_id,
            end_id,
            class_weight,
            loss_weight,
            disc_loss,
            loss_func,
            mlp_layers,
            max_length,
            dropout,
            IGNORE_LABEL,
            reverse_decode=reverse_decode)
    elif net == 'disc-tmpl':
        model = DiscriminativeMLPEncoder2AttentionDecoder(
            type_size,
            player_size,
            team_size,
            detail_size,
            detail_dim,
            src_embed,
            event_size,
            vocab_size,
            content_word_size,
            trg_embed,
            hidden,
            start_id,
            end_id,
            class_weight,
            loss_weight,
            disc_loss,
            loss_func,
            mlp_layers,
            max_length,
            dropout,
            IGNORE_LABEL,
            source.id_to_player,
            home_player_tag,
            away_player_tag,
            source.id_to_team,
            home_team_tag,
            away_team_tag,
            target.player_to_id,
            target.players,
            reverse_decode=reverse_decode)
    elif net == 'gate-disc':
        model = DiscriminativeMLPEncoder2GatedAttentionDecoder(
            type_size,
            player_size,
            team_size,
            detail_size,
            detail_dim,
            src_embed,
            event_size,
            vocab_size,
            content_word_size,
            trg_embed,
            hidden,
            start_id,
            end_id,
            class_weight,
            loss_weight,
            disc_loss,
            loss_func,
            mlp_layers,
            max_length,
            dropout,
            IGNORE_LABEL,
            reverse_decode=reverse_decode)
    elif net == 'gate-disc-tmpl':
        model = DiscriminativeMLPEncoder2GatedAttentionDecoder(
            type_size,
            player_size,
            team_size,
            detail_size,
            detail_dim,
            src_embed,
            event_size,
            vocab_size,
            content_word_size,
            trg_embed,
            hidden,
            start_id,
            end_id,
            class_weight,
            loss_weight,
            disc_loss,
            loss_func,
            mlp_layers,
            max_length,
            dropout,
            IGNORE_LABEL,
            source.id_to_player,
            home_player_tag,
            away_player_tag,
            source.id_to_team,
            home_team_tag,
            away_team_tag,
            target.player_to_id,
            target.players,
            reverse_decode=reverse_decode)
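    # Attach player/team numbering ids if the model was trained with numbering.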
    if numbering:
        model.player_id = target.player_id
        model.team_id = target.team_id
    # load best model
    if args.gpu is not None:
        model.use_gpu(args.gpu)
    model.id_to_word = target.id_to_word
    model.load_model(os.path.join(args.model_path, 'best.model'))
    batch_size = args.batch
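    # Iterators over the full test set and the length-limited (20/15/10) subsets.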
    src_test_iter = SequentialIterator(test.source,
                                       batch_size,
                                       None,
                                       event_size,
                                       source.fillvalue,
                                       gpu=args.gpu)
    src_test20_iter = SequentialIterator(test20.source,
                                         batch_size,
                                         None,
                                         event_size,
                                         source.fillvalue,
                                         gpu=args.gpu)
    src_test15_iter = SequentialIterator(test15.source,
                                         batch_size,
                                         None,
                                         event_size,
                                         source.fillvalue,
                                         gpu=args.gpu)
    src_test10_iter = SequentialIterator(test10.source,
                                         batch_size,
                                         None,
                                         event_size,
                                         source.fillvalue,
                                         gpu=args.gpu)
    trg_test_iter = Iterator(test.target,
                             batch_size,
                             wrapper=EndTokenIdRemoval(end_id),
                             gpu=None)
    trg_test20_iter = Iterator(test20.target,
                               batch_size,
                               wrapper=EndTokenIdRemoval(end_id),
                               gpu=None)
    trg_test15_iter = Iterator(test15.target,
                               batch_size,
                               wrapper=EndTokenIdRemoval(end_id),
                               gpu=None)
    trg_test10_iter = Iterator(test10.target,
                               batch_size,
                               wrapper=EndTokenIdRemoval(end_id),
                               gpu=None)

    with open('./dataset/player_list.json.new') as f:
        id_to_player = json.load(f)
    with open('./dataset/team_list.json.new') as f:
        id_to_team = json.load(f)

    def convert(ind, no_tag=False):
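        """Resolve masked 'player'/'team' tokens to real names when no_tag is True; otherwise return the token unchanged."""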
        if 'player' in ind:
            if no_tag:
                i = ind.replace('player', '')
                return id_to_player.get(i, ind)
            else:
                return ind
        elif 'team' in ind:
            if no_tag:
                i = ind.replace('team', '')
                return id_to_team.get(i, ind)
            else:
                return ind
        else:
            return ind

    if 'disc' in net:
        bleu_score, accuracy, hypotheses = evaluate_bleu_and_accuracy(
            model, src_test_iter, trg_test_iter)
        bleu_score20, _, hypotheses20 = evaluate_bleu_and_accuracy(
            model, src_test20_iter, trg_test20_iter)
        bleu_score15, _, hypotheses15 = evaluate_bleu_and_accuracy(
            model, src_test15_iter, trg_test15_iter)
        bleu_score10, _, hypotheses10 = evaluate_bleu_and_accuracy(
            model, src_test10_iter, trg_test10_iter)
    else:
        bleu_score, hypotheses = evaluate_bleu(model, src_test_iter,
                                               trg_test_iter)
        bleu_score20, hypotheses20 = evaluate_bleu(model, src_test20_iter,
                                                   trg_test20_iter)
        bleu_score15, hypotheses15 = evaluate_bleu(model, src_test15_iter,
                                                   trg_test15_iter)
        bleu_score10, hypotheses10 = evaluate_bleu(model, src_test10_iter,
                                                   trg_test10_iter)

    print('best score: {}'.format(bleu_score))
    print('best score20: {}'.format(bleu_score20))
    print('best score15: {}'.format(bleu_score15))
    print('best score10: {}'.format(bleu_score10))
    # save hypothesis
    hypotheses_for_save = [
        ' '.join([convert(y, True) for y in h]) for h in hypotheses
    ]
    hypotheses20_for_save = [
        ' '.join([convert(y, True) for y in h]) for h in hypotheses20
    ]
    hypotheses15_for_save = [
        ' '.join([convert(y, True) for y in h]) for h in hypotheses15
    ]
    hypotheses10_for_save = [
        ' '.join([convert(y, True) for y in h]) for h in hypotheses10
    ]
    references_for_save = [
        ' '.join(convert(y, True) for y in r[0]) for r in test.target
    ]
    references20_for_save = [
        ' '.join(convert(y, True) for y in r[0]) for r in test20.target
    ]
    references15_for_save = [
        ' '.join(convert(y, True) for y in r[0]) for r in test15.target
    ]
    references10_for_save = [
        ' '.join(convert(y, True) for y in r[0]) for r in test10.target
    ]
    TextFile(os.path.join(args.model_path, 'hypo'), hypotheses_for_save).save()
    TextFile(os.path.join(args.model_path, 'hypo_len20'),
             hypotheses20_for_save).save()
    TextFile(os.path.join(args.model_path, 'hypo_len15'),
             hypotheses15_for_save).save()
    TextFile(os.path.join(args.model_path, 'hypo_len10'),
             hypotheses10_for_save).save()
    TextFile(os.path.join('./dataset', 'ref'), references_for_save).save()
    TextFile(os.path.join('./dataset', 'ref_len20'),
             references20_for_save).save()
    TextFile(os.path.join('./dataset', 'ref_len15'),
             references15_for_save).save()
    TextFile(os.path.join('./dataset', 'ref_len10'),
             references10_for_save).save()
    # generate readable text
    result = []
    for ref, hyp in zip(test.target.data, hypotheses):
        if isinstance(ref, tuple):
            ref = ref[0]
        ref = ' '.join([convert(y) for y in ref]).split()
        try:
            bleu_score = sentence_bleu(
                [ref], hyp, smoothing_function=SmoothingFunction().method1)
        except Exception:
            bleu_score = 0
        ref = ' '.join([convert(y, True) for y in ref]).split()
        hyp = ' '.join([convert(y, True) for y in hyp]).split()
        result.append((' '.join(ref), ' '.join(hyp), bleu_score))
    inputs = []
    for xs in test20.source.data:
        data = []
        for x in xs[:5]:
            event = event_type_mapper.get(x[0], x[0])
            player = id_to_player.get(str(x[1]), x[1])
            team = id_to_team.get(str(x[2]), x[2])
            detail = ','.join(
                [qualifier_type_mapper.get(i[-1], i[-1]) for i in x[-1]])
            data.append('event: {} player: {} team: {} detail: {}'.format(
                event, player, team, detail))
        inputs.append('\n'.join(data))
    result = [[x, *y] for x, y in zip(inputs, result)]
    result = sorted(result, key=lambda x: -x[-1])
    TextFile(os.path.join(args.model_path, 'test20_gate_disc_tmpl.txt'), [
        'src:\n{}\nref: {}\nhyp: {}\nbleu: {}\n##\n'.format(*x) for x in result
    ]).save()
Example #4
0
def main():
    # Set hyper-parameters.
    batch_size = 32
    epochs = 100
    model_path = 'atmodel.h5'
    enc_arch = 'encoder.json'
    dec_arch = 'decoder.json'
    data_path = '../data/w16to19hukusimaconv.txt'
    num_words = 7000
    num_data = 4367

    # Data loading.
    en_texts, ja_texts = load_dataset(data_path)
    en_texts, ja_texts = en_texts[:num_data], ja_texts[:num_data]

    # Preprocessings.
    #ja_texts = preprocess_ja(ja_texts)
    ja_texts = preprocess_dataset(ja_texts)
    en_texts = preprocess_dataset(en_texts)
    x_train, x_test, y_train, y_test = train_test_split(en_texts,
                                                        ja_texts,
                                                        test_size=0.2,
                                                        random_state=42)

    en_vocab = build_vocabulary(x_train, num_words)
    ja_vocab = build_vocabulary(y_train, num_words)
    print(x_train[:3])
    print(y_train[:3])
    x_train, y_train = create_dataset(x_train, y_train, en_vocab, ja_vocab)

    print(en_vocab.word_index)
    print(ja_vocab.word_index)

    # Build a simple model.
    encoder = Encoder(num_words)
    decoder = Decoder(num_words)
    # Build an attention model.
    #encoder = Encoder(num_words, return_sequences=True)
    #decoder = AttentionDecoder(num_words)
    seq2seq = Seq2seq(encoder, decoder)
    model = seq2seq.build()
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

    # Train the model.
    callbacks = [
        EarlyStopping(patience=10),
        ModelCheckpoint(model_path,
                        save_best_only=True,
                        save_weights_only=True)
    ]
    """
    model.fit(x=x_train,
              y=y_train,
              batch_size=batch_size,
              epochs=epochs,
              callbacks=callbacks,
              validation_split=0.1)"""
    encoder.save_as_json(enc_arch)
    decoder.save_as_json(dec_arch)

    # Inference.
    encoder = Encoder.load(enc_arch, model_path)
    decoder = Decoder.load(dec_arch, model_path)
    api = InferenceAPI(encoder, decoder, en_vocab, ja_vocab)
    #api = InferenceAPIforAttention(encoder, decoder, en_vocab, ja_vocab)
    texts = sorted(set(en_texts[:50]), key=len)
    texts = ["お聞きしたいと思います", "さっき の 答弁 全く 納得 できません", "全く 納得 い き ません", "ありがとうございました", "おはようございます",\
            "よろしいでしょうか", "是非 よろしくお願いいたします", "もう少し 具体的に 教えて いただける と 助 か る んですけれども", "ちょっと 待 って", "質問 主 意 書 では 当然 混 同 は しておりません",\
            "正 式 な 要求 でいい んですか", "時間ですので まとめて ください", "ちょっと 静粛に お願いします", "よろしいですか", "静粛に お願いします",\
            "答弁 を まとめて ください", "時間 ですから", "驚 き の答弁 ですね", "それは いつ ごろ でしょうか", "そのとおり です"
    ]
    for text in texts:
        decoded = api.predict(text=text)
        print('入力: {}'.format(text))
        print('応答: {}'.format(decoded))

    y_test = [y.split(' ')[1:-1] for y in y_test]
    bleu_score = evaluate_bleu(x_test, y_test, api)
    print('BLEU: {}'.format(bleu_score))
def training(args):
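    """Build fields, vocabularies, and datasets, train the selected data-to-text network, and report BLEU on the dev and test splits."""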
    source = EventField(fix_length=args.event_size, embed_size=args.src_embed)
    mask_flag = 'tmpl' in args.net
    sentence_size = args.sentence_size if args.truncate else None
    reverse_decode = args.reverse_decode
    if 'disc' in args.net:
        target = TextAndContentWordField(start_token=None,
                                         fix_length=sentence_size,
                                         mask_player=mask_flag,
                                         mask_team=mask_flag,
                                         numbering=args.numbering,
                                         reverse=reverse_decode,
                                         bpc=args.bpc,
                                         multi_tag=args.multi_tag)
    else:
        target = TextField(start_token=None,
                           fix_length=sentence_size,
                           mask_player=mask_flag,
                           mask_team=mask_flag,
                           numbering=args.numbering,
                           reverse=reverse_decode,
                           bpc=args.bpc,
                           multi_tag=args.multi_tag)
    if args.truncate:
        train = OptaDataset(path=args.dataset + '.train',
                            fields={
                                'source': source,
                                'target': target
                            })
    else:
        train = OptaDataset(path=args.dataset + '.train',
                            fields={
                                'source': source,
                                'target': target
                            },
                            limit_length=args.limit)
    source.build_vocabulary(train.source)
    target.build_vocabulary(train.target, size=args.vocab_size)
    target.player_to_id = source.player_to_id
    target.players = source.id_to_player
    if mask_flag or 'disc' in args.net:
        content_word_to_id = getattr(target, 'content_word_to_id', None)
        target_test = TestTextField(source.id_to_player,
                                    source.id_to_team,
                                    target.word_to_id,
                                    content_word_to_id,
                                    target.unk_id,
                                    fix_length=None,
                                    bpc=args.bpc)
    else:
        target_test = TextField(start_token=None,
                                end_token=None,
                                fix_length=None,
                                bpc=args.bpc)
        target_test.word_to_id = target.word_to_id
        target_test.id_to_word = target.id_to_word
        target_test.unk_id = target.unk_id
    dev = OptaDataset(path=args.dataset + '.dev',
                      fields={
                          'source': source,
                          'target': target_test
                      },
                      limit_length=args.limit)
    train2 = OptaDataset(path=args.dataset + '.train',
                         fields={
                             'source': source,
                             'target': target_test
                         },
                         limit_length=args.limit)
    test = OptaDataset(path=args.dataset + '.test',
                       fields={
                           'source': source,
                           'target': target_test
                       })
    test20 = OptaDataset(path=args.dataset + '.test',
                         fields={
                             'source': source,
                             'target': target_test
                         },
                         limit_length=20)
    test15 = OptaDataset(path=args.dataset + '.test',
                         fields={
                             'source': source,
                             'target': target_test
                         },
                         limit_length=15)
    test10 = OptaDataset(path=args.dataset + '.test',
                         fields={
                             'source': source,
                             'target': target_test
                         },
                         limit_length=10)

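    # Special-token ids and class weights for player-name tokens.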
    start_id, end_id = target.word_to_id['<s>'], target.word_to_id['</s>']
    class_weight = compute_class_weight('./dataset/player_list.txt',
                                        target.word_to_id,
                                        args.class_weight[0],
                                        args.class_weight[1],
                                        gpu=args.gpu)
    dirname = Utility.get_save_directory(
        args.net, './debug' if args.debug else args.output)
    if args.debug:
        save_path = os.path.join('./debug', dirname)
    else:
        save_path = os.path.join(args.output, dirname)
    Utility.make_directory(save_path)
    del args.vocab_size
    setting = {
        'vocab_size': len(target.word_to_id),
        'type_size': len(source.type_to_id),
        'player_size': len(source.player_to_id),
        'team_size': len(source.team_to_id),
        'detail_size': len(source.detail_to_id),
        'detail_dim': source.details_dimention,
        'start_id': start_id,
        'end_id': end_id,
        'unk_id': target.unk_id,
        'save_path': save_path,
        **vars(args)
    }
    dump_setting(setting, os.path.join(save_path, 'setting.yaml'))
    home_player_tag = target.word_to_id.get(target.home_player_tag)
    away_player_tag = target.word_to_id.get(target.away_player_tag)
    home_team_tag = target.word_to_id.get(target.home_team_tag)
    away_team_tag = target.word_to_id.get(target.away_team_tag)
    print('vocab size: {}'.format(len(target.word_to_id)))
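    # Instantiate the network variant selected by args.net.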
    if args.net == 'plain':
        model = MLPEncoder2AttentionDecoder(len(source.type_to_id),
                                            len(source.player_to_id),
                                            len(source.team_to_id),
                                            len(source.detail_to_id),
                                            source.details_dimention,
                                            args.src_embed,
                                            args.event_size,
                                            len(target.word_to_id),
                                            args.trg_embed,
                                            args.hidden,
                                            start_id,
                                            end_id,
                                            class_weight,
                                            args.mlp_layers,
                                            args.max_length,
                                            args.dropout,
                                            IGNORE_LABEL,
                                            reverse_decode=reverse_decode)
    elif args.net == 'tmpl':
        model = MLPEncoder2AttentionDecoder(len(source.type_to_id),
                                            len(source.player_to_id),
                                            len(source.team_to_id),
                                            len(source.detail_to_id),
                                            source.details_dimention,
                                            args.src_embed,
                                            args.event_size,
                                            len(target.word_to_id),
                                            args.trg_embed,
                                            args.hidden,
                                            start_id,
                                            end_id,
                                            class_weight,
                                            args.mlp_layers,
                                            args.max_length,
                                            args.dropout,
                                            IGNORE_LABEL,
                                            source.id_to_player,
                                            home_player_tag,
                                            away_player_tag,
                                            source.id_to_team,
                                            home_team_tag,
                                            away_team_tag,
                                            target.player_to_id,
                                            target.players,
                                            reverse_decode=reverse_decode)
    elif args.net == 'gate':
        model = MLPEncoder2GatedAttentionDecoder(len(source.type_to_id),
                                                 len(source.player_to_id),
                                                 len(source.team_to_id),
                                                 len(source.detail_to_id),
                                                 source.details_dimention,
                                                 args.src_embed,
                                                 args.event_size,
                                                 len(target.word_to_id),
                                                 args.trg_embed,
                                                 args.hidden,
                                                 start_id,
                                                 end_id,
                                                 class_weight,
                                                 args.mlp_layers,
                                                 args.max_length,
                                                 args.dropout,
                                                 IGNORE_LABEL,
                                                 reverse_decode=reverse_decode)
    elif args.net == 'gate-tmpl':
        model = MLPEncoder2GatedAttentionDecoder(len(source.type_to_id),
                                                 len(source.player_to_id),
                                                 len(source.team_to_id),
                                                 len(source.detail_to_id),
                                                 source.details_dimention,
                                                 args.src_embed,
                                                 args.event_size,
                                                 len(target.word_to_id),
                                                 args.trg_embed,
                                                 args.hidden,
                                                 start_id,
                                                 end_id,
                                                 class_weight,
                                                 args.mlp_layers,
                                                 args.max_length,
                                                 args.dropout,
                                                 IGNORE_LABEL,
                                                 source.id_to_player,
                                                 home_player_tag,
                                                 away_player_tag,
                                                 source.id_to_team,
                                                 home_team_tag,
                                                 away_team_tag,
                                                 target.player_to_id,
                                                 target.players,
                                                 reverse_decode=reverse_decode)
    elif args.net == 'disc':
        model = DiscriminativeMLPEncoder2AttentionDecoder(
            len(source.type_to_id),
            len(source.player_to_id),
            len(source.team_to_id),
            len(source.detail_to_id),
            source.details_dimention,
            args.src_embed,
            args.event_size,
            len(target.word_to_id),
            len(target.content_word_to_id),
            args.trg_embed,
            args.hidden,
            start_id,
            end_id,
            class_weight,
            args.loss_weight,
            args.disc_loss,
            args.loss_func,
            args.mlp_layers,
            args.max_length,
            args.dropout,
            IGNORE_LABEL,
            reverse_decode=reverse_decode)
    elif args.net == 'disc-tmpl':
        model = DiscriminativeMLPEncoder2AttentionDecoder(
            len(source.type_to_id),
            len(source.player_to_id),
            len(source.team_to_id),
            len(source.detail_to_id),
            source.details_dimention,
            args.src_embed,
            args.event_size,
            len(target.word_to_id),
            len(target.content_word_to_id),
            args.trg_embed,
            args.hidden,
            start_id,
            end_id,
            class_weight,
            args.loss_weight,
            args.disc_loss,
            args.loss_func,
            args.mlp_layers,
            args.max_length,
            args.dropout,
            IGNORE_LABEL,
            source.id_to_player,
            home_player_tag,
            away_player_tag,
            source.id_to_team,
            home_team_tag,
            away_team_tag,
            target.player_to_id,
            target.players,
            reverse_decode=reverse_decode)
    elif args.net == 'gate-disc':
        model = DiscriminativeMLPEncoder2GatedAttentionDecoder(
            len(source.type_to_id),
            len(source.player_to_id),
            len(source.team_to_id),
            len(source.detail_to_id),
            source.details_dimention,
            args.src_embed,
            args.event_size,
            len(target.word_to_id),
            len(target.content_word_to_id),
            args.trg_embed,
            args.hidden,
            start_id,
            end_id,
            class_weight,
            args.loss_weight,
            args.disc_loss,
            args.loss_func,
            args.mlp_layers,
            args.max_length,
            args.dropout,
            IGNORE_LABEL,
            reverse_decode=reverse_decode)
    elif args.net == 'gate-disc-tmpl':
        model = DiscriminativeMLPEncoder2GatedAttentionDecoder(
            len(source.type_to_id),
            len(source.player_to_id),
            len(source.team_to_id),
            len(source.detail_to_id),
            source.details_dimention,
            args.src_embed,
            args.event_size,
            len(target.word_to_id),
            len(target.content_word_to_id),
            args.trg_embed,
            args.hidden,
            start_id,
            end_id,
            class_weight,
            args.loss_weight,
            args.disc_loss,
            args.loss_func,
            args.mlp_layers,
            args.max_length,
            args.dropout,
            IGNORE_LABEL,
            source.id_to_player,
            home_player_tag,
            away_player_tag,
            source.id_to_team,
            home_team_tag,
            away_team_tag,
            target.player_to_id,
            target.players,
            reverse_decode=reverse_decode)
    elif args.net == 'conv-gate-disc-tmpl':
        model = DiscriminativeGLUEncoder2GatedAttentionDecoder(
            len(source.type_to_id),
            len(source.player_to_id),
            len(source.team_to_id),
            len(source.detail_to_id),
            source.details_dimention,
            args.src_embed,
            args.event_size,
            len(target.word_to_id),
            len(target.content_word_to_id),
            args.trg_embed,
            args.hidden,
            start_id,
            end_id,
            class_weight,
            args.loss_weight,
            args.disc_loss,
            args.loss_func,
            args.mlp_layers,
            args.max_length,
            args.dropout,
            IGNORE_LABEL,
            source.id_to_player,
            home_player_tag,
            away_player_tag,
            source.id_to_team,
            home_team_tag,
            away_team_tag,
            target.player_to_id,
            target.players,
            reverse_decode=reverse_decode)

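    # Token ids of domain keywords (save, block, shot, ...) attached to the model.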
    model.keyword_ids = [
        target.word_to_id['save'], target.word_to_id['block'],
        target.word_to_id['chance'], target.word_to_id['shot'],
        target.word_to_id['clearance'], target.word_to_id['kick'],
        target.word_to_id['ball'], target.word_to_id['blocked'],
        target.word_to_id['denied']
    ]
    model.id_to_word = target.id_to_word
    if args.numbering:
        model.player_id = target.player_id
        model.team_id = target.team_id

    if args.gpu is not None:
        model.use_gpu(args.gpu)
    opt = optimizers.Adam(args.lr)
    opt.setup(model)
    if args.clipping > 0:
        opt.add_hook(GradientClipping(args.clipping))
    if args.decay > 0:
        opt.add_hook(WeightDecay(args.decay))

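    # Training batches are drawn in a random order; evaluation iterators stay sequential.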
    N = len(train.source)
    batch_size = args.batch
    order_provider = OrderProvider(Sampling.get_random_order(N))
    src_train_iter = SequentialIterator(train.source,
                                        batch_size,
                                        order_provider,
                                        args.event_size,
                                        source.fillvalue,
                                        gpu=args.gpu)
    if 'disc' in args.net:
        trg_train_iter = TextAndLabelIterator(train.target,
                                              batch_size,
                                              order_provider,
                                              args.sentence_size,
                                              IGNORE_LABEL,
                                              gpu=args.gpu)
    else:
        trg_train_iter = SequentialIterator(train.target,
                                            batch_size,
                                            order_provider,
                                            args.sentence_size,
                                            IGNORE_LABEL,
                                            gpu=args.gpu)
    src_dev_iter = SequentialIterator(dev.source,
                                      batch_size,
                                      None,
                                      args.event_size,
                                      source.fillvalue,
                                      gpu=args.gpu)
    trg_dev_iter = Iterator(dev.target,
                            batch_size,
                            wrapper=EndTokenIdRemoval(end_id),
                            gpu=None)
    src_test_iter = SequentialIterator(test.source,
                                       batch_size,
                                       None,
                                       args.event_size,
                                       source.fillvalue,
                                       gpu=args.gpu)
    src_test20_iter = SequentialIterator(test20.source,
                                         batch_size,
                                         None,
                                         args.event_size,
                                         source.fillvalue,
                                         gpu=args.gpu)
    src_test15_iter = SequentialIterator(test15.source,
                                         batch_size,
                                         None,
                                         args.event_size,
                                         source.fillvalue,
                                         gpu=args.gpu)
    src_test10_iter = SequentialIterator(test10.source,
                                         batch_size,
                                         None,
                                         args.event_size,
                                         source.fillvalue,
                                         gpu=args.gpu)
    src_train2_iter = SequentialIterator(train2.source,
                                         batch_size,
                                         None,
                                         args.event_size,
                                         source.fillvalue,
                                         gpu=args.gpu)
    trg_train2_iter = Iterator(train2.target,
                               batch_size,
                               wrapper=EndTokenIdRemoval(end_id),
                               gpu=None)
    trg_test_iter = Iterator(test.target,
                             batch_size,
                             wrapper=EndTokenIdRemoval(end_id),
                             gpu=None)
    trg_test20_iter = Iterator(test20.target,
                               batch_size,
                               wrapper=EndTokenIdRemoval(end_id),
                               gpu=None)
    trg_test15_iter = Iterator(test15.target,
                               batch_size,
                               wrapper=EndTokenIdRemoval(end_id),
                               gpu=None)
    trg_test10_iter = Iterator(test10.target,
                               batch_size,
                               wrapper=EndTokenIdRemoval(end_id),
                               gpu=None)
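    # Discriminative variants train with the label-aware trainer and a BLEU+accuracy evaluator.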
    if 'disc' in args.net:
        trainer = Seq2SeqWithLabelTrainer(
            model, opt, src_train_iter, trg_train_iter, src_dev_iter,
            trg_dev_iter, order_provider, evaluate_bleu_and_accuracy,
            args.epoch, save_path, args.eval_step, src_train2_iter,
            trg_train2_iter)
    else:
        trainer = Seq2SeqTrainer(model, opt, src_train_iter, trg_train_iter,
                                 src_dev_iter, trg_dev_iter, order_provider,
                                 evaluate_bleu, args.epoch, save_path,
                                 args.eval_step, src_train2_iter,
                                 trg_train2_iter)

    trainer.run()

    # load best model
    model.load_model(os.path.join(save_path, 'best.model'))
    if 'disc' in args.net:
        bleu_score_dev, _, _ = evaluate_bleu_and_accuracy(
            model, src_dev_iter, trg_dev_iter)
        bleu_score, _, _ = evaluate_bleu_and_accuracy(model, src_test_iter,
                                                      trg_test_iter)
        bleu_score20, _, hypotheses = evaluate_bleu_and_accuracy(
            model, src_test20_iter, trg_test20_iter)
        bleu_score15, _, _ = evaluate_bleu_and_accuracy(
            model, src_test15_iter, trg_test15_iter)
        bleu_score10, _, _ = evaluate_bleu_and_accuracy(
            model, src_test10_iter, trg_test10_iter)
    else:
        bleu_score_dev, _ = evaluate_bleu(model, src_dev_iter, trg_dev_iter)
        bleu_score, _ = evaluate_bleu(model, src_test_iter, trg_test_iter)
        bleu_score20, hypotheses = evaluate_bleu(model, src_test20_iter,
                                                 trg_test20_iter)
        bleu_score15, _ = evaluate_bleu(model, src_test15_iter,
                                        trg_test15_iter)
        bleu_score10, _ = evaluate_bleu(model, src_test10_iter,
                                        trg_test10_iter)
    TextFile(os.path.join(save_path, 'hypotheses.txt'),
             [' '.join(ys) for ys in trainer.hypotheses]).save()
    print('dev score: {}'.format(bleu_score_dev))
    print('test score: {}'.format(bleu_score))
    print('test score20: {}'.format(bleu_score20))
    print('test score15: {}'.format(bleu_score15))
    print('test score10: {}'.format(bleu_score10))

    # saving fields
    pickle_dump(os.path.join(save_path, 'source.pkl'), source)
    pickle_dump(os.path.join(save_path, 'target.pkl'), target)
    pickle_dump(os.path.join(save_path, 'target_test.pkl'), target_test)
def score_gold(params): 
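    """Score a gold corpus against a reference corpus with BLEU/self-BLEU or embedding-based metrics and merge the results into a JSON results file."""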
    utils.print_args(params)
    utils.set_seed(params['seed'])  # assumes params provides the random seed

    gold_file = params['gold_file']
    gold_filename = os.path.basename(gold_file)
    with open(gold_file, "r") as f:
        print("Gold file length:", len(f.readlines()))
    reference_file = params['reference_corpus']
    with open(reference_file, "r") as f:
        print("Reference file length:", len(f.readlines()))
    chunk = params['chunk']
    ngram = params['gram']
    num_ref_sentences = params['num_ref_sentences']
    num_gold_sentences = params['num_gold_sentences']
    evaluation_method = params['eval_method']
    device = params['device']
    output_file = "results/gold_corpora/results.txt"

    results = {}
    if evaluation_method=="BLEU":
        gram = params['gram']

        smoothing_method = {"nist": SmoothingFunction().method3}
        scores = {}

        for name, method in smoothing_method.items():
            scores[name] = utils.evaluate_bleu(
                gold_file,
                reference_file,
                num_real_sentences=num_ref_sentences,
                num_generated_sentences=num_gold_sentences,
                gram=gram,
                smoothing_method=method,
                chunk_size=15)

        print(scores)
        scores['nist']['scores'] = {}
        scores['nist']['scores']['bleu5'] = scores['nist']['bleu5'] * -1.0
        for name in smoothing_method.keys():
            results[name] = {}
            results[name]['scores'] = scores[name]
            results['num_ref_sentences'] = num_ref_sentences 
            results['num_gold_sentences'] = num_gold_sentences 
    elif evaluation_method=="Embedding":
        knn = params['knn']
        # number of sentences to compare; assumed to be provided in params
        num_sentences = params['num_sentences']

        # use Embedding calculation
        from sentence_transformers import SentenceTransformer
        sentence_model = SentenceTransformer('bert-base-nli-mean-tokens')
        # TODO: update chunks to begin more naturally.
        all_text_sentences = []
       	with open(reference_file, "r+") as f:
            reference_sentences = [i.replace("</s>", "\n") for i in f.readlines()] 
        with open(gold_file, "r+") as f:
            all_text_sentences = [i.replace("</s>", "\n") for i in f.readlines()] 	 
        encoding_batch_size = 500 
        candidate_sentences = all_text_sentences 

        random.shuffle(candidate_sentences)
        candidate_sentences = candidate_sentences[:num_sentences]

        random.shuffle(reference_sentences)
        reference_sentences = reference_sentences[:num_sentences]
        # TODO: split the embeddings up into sentences
        # and mean-pool the sentences to get better embeddings.
        print("Beginning embedding eval...")
        print("Lengths:", len(candidate_sentences), len(reference_sentences))
        
        reference_embeddings = embeddings.encode_sentences(sentence_model, reference_sentences, batch_size=encoding_batch_size, device=device)
        candidate_embeddings = embeddings.encode_sentences(sentence_model, candidate_sentences, batch_size=encoding_batch_size, device=device)

        bleu = embeddings.compute_scores(candidate_sentences, candidate_embeddings, 
                reference_sentences, reference_embeddings, k=knn).item() * -1.0

        sbleu = embeddings.compute_scores(candidate_sentences, candidate_embeddings, 
                candidate_sentences, candidate_embeddings, k=knn, is_self=True).item() 

        results['knn'] = knn
        results['chunk'] = chunk
        results['num_sentences'] = num_sentences
        results['nist'] = {}
        results['nist']['scores'] = {}
        results['nist']['scores']['bleu5'] = bleu
        results['nist']['scores']['self-bleu5'] = sbleu


    if os.path.exists(output_file):
        with open(output_file, "r+") as f:
            current = json.load(f)
    else:
        current = {"BLEU": defaultdict(lambda: defaultdict({})),
                "Embedding": defaultdict(lambda: defaultdict({}))} 

    if evaluation_method=="BLEU":
        num_sentences = f"{num_ref_sentences}-{num_gold_sentences}"
        if num_sentences not in current['BLEU']: 
            current['BLEU'][num_sentences] = {}
        if chunk not in current['BLEU'][num_sentences]: 
            current['BLEU'][num_sentences][chunk] = {} 
        if ngram not in current['BLEU'][num_sentences][chunk]:
            current['BLEU'][num_sentences][chunk][ngram] = {}

        current['BLEU'][num_sentences][chunk][ngram] = results

    else:
        if knn not in current['Embedding']: 
            current['Embedding'][knn] = {}
        if chunk not in current['Embedding'][knn]: 
            current['Embedding'][knn][chunk] = {} 
        if num_sentences not in current['Embedding'][knn][chunk]:
            current['Embedding'][knn][chunk][num_sentences] = {}

        current['Embedding'][knn][chunk][num_sentences] = results

    with open(output_file, "w+") as f:
        json.dump(current, f)
def main(args, subparsers):
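    """Sample text from a pretrained causal LM under the configured sampling schedule and collect per-token statistics (k_prime, p_prime)."""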
    print(args)
    print("Started experiment!")
    utils.print_args(args)
    utils.set_seed(args.seed)

    ###################################################################################
    ################################# Initialization ##################################
    ###################################################################################
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = transformers.AutoModelForCausalLM.from_pretrained(
        args.pretrained_class, local_files_only=True).eval()
    # register special tokens
    # num_added_tokens = tokenizer.add_special_tokens({"bos_token": "<BOS>", "eos_token": "<EOS>",
    # "pad_token": "<PAD>"})
    model.to(device)

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.pretrained_class, local_files_only=True)
    sampler_args = vars(subparsers[args.sampler].parse_known_args()[0])
    sampler_args_items = "-".join(
        [f"{k}:{v}" for k, v in sampler_args.items()])

    tokenizer_args = f"tokenizer:{tokenizer.__class__.__name__}"
    args.pretrained_class = args.pretrained_class.replace("/", "_")
    if args.pretrained_class == "ctrl":
        tokenizer_args += f"-ctrl_code:{args.ctrl_code}"
    elif "gpt2" in args.pretrained_class:
        tokenizer.pad_token = tokenizer.eos_token

    pretrained_class = args.pretrained_class.replace("-", "_")
    sampler_name = args.sampler
    if (args.sampler == "NegativeSampler"):
        print(sampler_args)
        sampler_name += "_Negative_" + sampler_args['negative_base']
    output_file = f"model:{model.__class__.__name__}-model_class:{pretrained_class}-{tokenizer_args}-sampler:{sampler_name}-temperature:{args.temperature}-seq_length:{args.max_seq_length}-ngram:{args.gram}-{sampler_args_items}.txt"

    results_file = os.path.join("results/", args.pretrained_class,
                                args.results_file)

    # if our results file exists
    if os.path.exists(results_file):
        with open(results_file, "r+") as f:
            current = json.load(f)
        key = output_file[:-4]
        # check if we have already run this
        if key in current:
            raise Exception("We've already computed the result!" + " " +
                            results_file)

    print("Using", args.prefix_file, "as the prefix file!")
    if not args.prefix_file:
        if args.pretrained_class == "ctrl":
            input_tokens = [tokenizer.control_codes[args.ctrl_code]]
        else:
            input_tokens = [tokenizer.bos_token_id]
        input_tokens = torch.tensor(input_tokens).to(device).unsqueeze(0)
    else:
        with open(args.prefix_file, "r") as f:
            # remove lines that are empty
            lines = []
            for line in f.readlines():
                if line.strip() and line.count(" ") > args.prefix_length:
                    lines.append(line)

            # shuffle to ensure we have some diversity
            random.shuffle(lines)
            # truncate to number of the sentences that we are generating
            lines = lines[:args.num_sentences]
            input_tokens = tokenizer.batch_encode_plus(
                lines,
                add_special_tokens=False,
                truncation=True,
                max_length=args.prefix_length,
                padding="max_length",
                return_tensors="pt")
            attention_mask = input_tokens['attention_mask']
            input_tokens = input_tokens['input_ids']
            # prepend an attention bit and a BOS token to every prefix row,
            # sized from the actual batch in case the file yields fewer lines
            # than args.num_sentences
            num_rows = input_tokens.shape[0]
            attn_token = torch.ones(num_rows, 1, dtype=attention_mask.dtype)
            attention_mask = torch.cat((attn_token, attention_mask), dim=1)
            assert tokenizer.bos_token_id not in input_tokens[0]
            bos_token = torch.tensor([tokenizer.bos_token_id]).repeat(num_rows, 1)
            input_tokens = torch.cat((bos_token, input_tokens), dim=1)

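    # At this point input_tokens is either a single BOS/control-code token or a
    # batch of [BOS + encoded prefix] rows of shape [num_rows, 1 + prefix_length].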
    print("Input Tokens:", input_tokens.shape)

    all_sentences = []
    k_primes, p_primes, entropy_primes = [], [], []
    num_sentences_left = args.num_sentences
    sentences_per_batch = args.generation_batch_size
    all_logprobs = []

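    # Generate in batches of generation_batch_size, collecting the sampled
    # token ids, per-sentence log probabilities, and the k'/p'/entropy
    # statistics of the filtered sampling distribution along the way.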
    with torch.no_grad():
        for idx in range(ceil(args.num_sentences / sentences_per_batch)):
            batch_size = min(sentences_per_batch, num_sentences_left)

            schedule = getattr(sampler, args.sampler)(**sampler_args)
            if input_tokens.shape[0] == 1:
                # single BOS/control-code prompt: reuse it for every batch
                input_ids = input_tokens
            else:
                # slice this batch's prefixes by cumulative offset
                start = idx * sentences_per_batch
                input_ids = input_tokens[start:start + batch_size].to(device)
            num_return_sequences = 1

            num_sentences_left -= batch_size

            sentences, model_logits, transformed_logits = filtering.generate(
                model=model,
                input_ids=input_ids,
                max_length=args.max_seq_length,
                do_sample=True,
                num_beams=None,
                temperature=args.temperature,
                schedule=schedule,
                repetition_penalty=1.0,
                bos_token_id=tokenizer.bos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                num_return_sequences=num_return_sequences,
                dry_run=args.dry_run)

            #########################################################################
            ############################### K Prime #################################
            #########################################################################
            # [batch, generated_length, vocab_size] of the filtered logits
            sz = list(transformed_logits.size())
            # careful: keep only positions with non-zero token ids (assumed to
            # be non-padding) when aggregating per-token statistics
            mask = (sentences[:, -sz[1]:] > 0).to(device)

            # distribution after the sampler's filtering/transformation
            distro = torch.softmax(
                transformed_logits.view(-1, sz[-1]).to(device), dim=-1)
            # k' = log of the number of tokens whose filtered probability
            # exceeds the uniform baseline 1/|V| (cast Bool to float before sum)
            k_prime = torch.sum(
                (distro > (1.0 / sz[-1])).float(), dim=-1).view(sz[0], sz[1])
            k_prime = torch.masked_select(k_prime, mask)
            assert torch.min(k_prime).item() > 0
            k_prime = torch.log(k_prime)
            k_primes.extend(k_prime.cpu().tolist())

            # original (unfiltered) model distribution over the generated span
            ori_distro = torch.softmax(
                model_logits[:, -sz[1]:, :].contiguous().view(-1, sz[-1]).to(device),
                dim=-1)

            # p' = log of the original probability mass retained by the tokens
            # that survive filtering
            ori_distro = ori_distro * (distro > (1.0 / sz[-1])).float()
            p_prime = torch.sum(ori_distro, dim=-1).view(sz[0], sz[1])
            p_prime = torch.log(torch.masked_select(p_prime, mask).float())
            p_primes.extend(p_prime.cpu().tolist())

            # entropy of the filtered distribution at each generated position
            entropy = -torch.sum(distro * torch.log(distro + 1e-10),
                                 dim=-1).view(sz[0], sz[1])
            entropy = torch.masked_select(entropy, mask)
            entropy_primes.extend(entropy.cpu().tolist())

            ##################################################################
            ############################ K Prime Ends ##########################
            ##################################################################

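            # Per-sentence log probabilities; calculate_logprobs appears to
            # interpolate between the raw model logits and the filtered logits
            # with filter_weight, skipping the prefix tokens.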
            transformed_logits = transformed_logits.to(device)
            model_logits = model_logits.to(device)
            sentences = sentences.to(device)
            logprobs = utils.calculate_logprobs(
                sentences,
                transformed_logits,
                model_logits,
                args.prefix_length,
                0,
                interpolate_ratio=args.filter_weight,
                batch_size=args.generation_batch_size)
            del model_logits
            del transformed_logits
            gc.collect()

            all_logprobs.append(logprobs.cpu().detach())
            all_sentences.append(sentences.cpu().detach())

    all_sentences = torch.cat(all_sentences, dim=0)
    all_logprobs = torch.cat(all_logprobs, dim=0)
    k_prime = np.mean(k_primes)
    p_prime = np.mean(p_primes)
    entropy_prime = np.mean(entropy_primes)
    print('Entropy Prime:', entropy_prime, 'K Prime:', k_prime, 'P Prime:',
          p_prime)
    results = {
        'k_prime': k_prime,
        'p_prime': p_prime,
        'entropy_prime': entropy_prime
    }

    del model
    print("Final shapes:", all_sentences.shape, all_logprobs.shape)
    # all_text_sentences excludes the prefix (generated continuation only)
    all_text_sentences = []
    # prefixed_text_sentences includes the prefix
    prefixed_text_sentences = []
    for idx in range(
            all_sentences.shape[0]):  # iterate over the batch dimension
        # sentence_id = sentence[0]
        idx_offset = 1 if args.pretrained_class == "ctrl" else 0
        prefixed_sentence = all_sentences[idx, idx_offset:].tolist()
        idx_offset += args.prefix_length
        sentence = all_sentences[idx, idx_offset:].tolist()

        decoded_sentence = tokenizer.decode(sentence,
                                            skip_special_tokens=True,
                                            clean_up_tokenization_spaces=True)
        prefixed_decoded_sentence = tokenizer.decode(
            prefixed_sentence,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True)
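        # Trailing "!" characters presumably come from padding positions that
        # decode to "!", so strip them to keep only the actual text.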
        decoded_sentence = decoded_sentence.rstrip("!")
        prefixed_decoded_sentence = prefixed_decoded_sentence.rstrip("!")

        # all_text is without the prefix, prefixed includes the prefix.
        all_text_sentences.append(decoded_sentence)
        prefixed_text_sentences.append(prefixed_decoded_sentence)

    ###################################################################################
    ############################ Score the Generated Texts ############################
    ###################################################################################
    results_file = os.path.join("results/", args.pretrained_class,
                                args.results_file)
    results_basename = os.path.basename(results_file).replace(".json", "")
    results_dir = os.path.dirname(results_file)
    if not os.path.isdir(results_dir):
        os.makedirs(results_dir)

    #results = {}  # moved to k/p/ent_prime
    scores = {}
    files = os.path.join("saved_generations/", results_basename,
                         args.pretrained_class, args.output_dir, output_file)
    files_dir = os.path.dirname(files)
    if not os.path.isdir(files_dir):
        os.makedirs(files_dir)

    print(f"Writing generated sentences to {files}.")
    utils.write_sentences(all_text_sentences, files)

    preprocessed_files = os.path.join("preprocessed_generations/",
                                      results_basename, args.pretrained_class,
                                      args.output_dir, output_file)
    preprocessed_files_dir = os.path.dirname(preprocessed_files)
    if not os.path.isdir(preprocessed_files_dir):
        os.makedirs(preprocessed_files_dir)

    print(f"Writing preprocessed sentences to {preprocessed_files}.")
    preprocessed_sentences, filtered_indices, filtered_lengths = utils.preprocess_text(
        prefixed_text_sentences,
        tokenizer,
        lmin=args.preprocessed_min,
        lmax=args.preprocessed_max)
    utils.write_sentences(preprocessed_sentences, preprocessed_files)

    # update the reference file to be chunked to our size
    reference_file = args.eval_text
    chunked_reference_file = f"{reference_file}_seq:{args.max_seq_length}_min:{args.preprocessed_min}_max:{args.preprocessed_max}_prefix:{args.prefix_length}_model:{args.pretrained_class.replace('models/', '')}"
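    # The reference corpus is chunked once per configuration and cached; a file
    # lock guards against concurrent runs building the same chunked file.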
    if not os.path.exists(chunked_reference_file):
        utils.lock(chunked_reference_file)
        print("Reference lock acquired!")
        # begin critical section!
        utils.chunk_and_prefix_file(reference_file,
                                    tokenizer,
                                    args.preprocessed_min,
                                    args.preprocessed_max,
                                    chunked_reference_file,
                                    prefix_length=args.prefix_length)
        # end critical section!
        utils.unlock(chunked_reference_file)

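    # Keep tokenizations and log probabilities only for the sentences that
    # survived preprocessing, aligned by their original indices.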
    filtered_tokenizations = []
    filtered_logprobs = []
    for idx in filtered_indices:
        filtered_tokenizations.append(all_sentences[idx])
        filtered_logprobs.append(all_logprobs[idx])
    filtered_tokenizations = torch.stack(filtered_tokenizations, dim=0)
    filtered_logprobs = torch.stack(filtered_logprobs, dim=0)

    del all_logprobs
    gc.collect()

    if args.eval_method == "BLEU":
        # use BLEU calculation
        smoothing_method = {"nist": SmoothingFunction().method3}
        for name, method in smoothing_method.items():
            scores[name] = utils.evaluate_bleu(
                files,
                chunked_reference_file,
                num_real_sentences=args.num_sentences,
                num_generated_sentences=args.num_sentences,
                gram=args.gram,
                smoothing_method=method,
                chunk_size=15)
            print()

        for name in smoothing_method.keys():
            results[name] = {}
            results[name]['scores'] = scores[name]

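        # bleu5 is stored negated, presumably so that lower is better for both
        # metrics downstream (e.g. when plotting against self-BLEU); the local
        # bleu below is flipped back to the usual positive value for printing.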
        results['nist']['scores']['bleu5'] *= -1.0
        bleu = results['nist']['scores']['bleu5'] * -1.0
        sbleu = results['nist']['scores']['self-bleu5']
    else:
        raise Exception("We don't support other automatic metrics!")

    print("Results: BLEU-5 =", bleu, "Self-BLEU-5 =", sbleu)

    ###################################################################################
    ############################# Result Reporting Section ############################
    ###################################################################################

    if not args.dry_run:
        results_file = os.path.join("results/", args.pretrained_class,
                                    args.results_file)
        results_dir = os.path.dirname(results_file)
        if not os.path.isdir(results_dir):
            os.makedirs(results_dir)
        utils.lock(results_file)
        print("Lock acquired!")

        # begin critical section!
        if os.path.exists(results_file):
            with open(results_file, "r+") as f:
                current = json.load(f)
        else:
            current = {}

        key = output_file[:-4]
        current[key] = results
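        # Write the merged results to a randomly named temporary file and then
        # rename it over the results file; the rename is atomic on POSIX, so a
        # concurrent reader never sees a half-written JSON.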
        random_file = ''.join(
            random.SystemRandom().choice(string.ascii_uppercase +
                                         string.digits) for _ in range(10))
        random_file = os.path.join("results/", args.pretrained_class,
                                   random_file)
        with open(random_file, "w+") as f:
            json.dump(current, f)

        os.rename(random_file, results_file)

        # save generations
        saved_tokens_file = os.path.join("tokens/", results_basename,
                                         args.pretrained_class,
                                         args.output_dir, output_file)
        saved_tokens_dir = os.path.dirname(saved_tokens_file)
        if not os.path.isdir(saved_tokens_dir):
            os.makedirs(saved_tokens_dir)

        saved_tokens = {}
        saved_tokens['args'] = [
            vars(args),
            vars(subparsers[args.sampler].parse_known_args()[0])
        ]
        idx_offset = 1 if args.pretrained_class == "ctrl" else 0
        saved_tokens['with_prefix'] = all_sentences[:, idx_offset:].tolist()
        idx_offset += args.prefix_length
        saved_tokens['without_prefix'] = all_sentences[:, idx_offset:].tolist()

        with open(saved_tokens_file, "w+") as f:
            json.dump(saved_tokens, f)

        # save log probabilities
        preprocessed_logits = os.path.join("preprocessed_logprobs/",
                                           results_basename,
                                           args.pretrained_class,
                                           args.output_dir, output_file)
        preprocessed_logits_dir = os.path.dirname(preprocessed_logits)
        if not os.path.isdir(preprocessed_logits_dir):
            os.makedirs(preprocessed_logits_dir)

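        # Key each preprocessed sentence by its SHA-256 digest and record its
        # model score and length (excluding the prefix) for later analysis.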
        d = {}
        print(filtered_logprobs.shape)
        for idx in range(filtered_logprobs.shape[0]):
            sent_id = hashlib.sha256(
                preprocessed_sentences[idx].encode()).hexdigest()
            if sent_id in d:
                raise Exception("Duplicate sentences found!")
            d[sent_id] = {
                "model_score": filtered_logprobs[idx].item(),
                "lengths": filtered_lengths[idx] - args.prefix_length,
                "sentence": preprocessed_sentences[idx]
            }
        print("Avg log probabilities:",
              (filtered_logprobs /
               (torch.tensor(filtered_lengths) - args.prefix_length)).mean(
                   dim=0))

        with open(preprocessed_logits, "w") as f:
            json.dump(d, f)

        # create plot
        plots_file = os.path.join("plots/", args.pretrained_class,
                                  args.results_file)
        plots_dir = os.path.dirname(plots_file)
        if not os.path.isdir(plots_dir):
            os.makedirs(plots_dir)

        plot = plotter.Plotter(results_file)
        plot.plot_curves()
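        # Optionally overlay the "gold" reference point; if it has not been
        # scored for this configuration yet, score the chunked reference corpus
        # and retry.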
        if args.plot_gold:
            params = {
                "eval_method": args.eval_method,
                "chunk": args.max_seq_length,
                "ngram": args.gram,
                "knn": args.knn,
                "num_sentences": args.num_sentences
            }
            result = plot.plot_gold(params)
            if not result:
                # We don't have a proper score for our reference file, so let's go ahead and create it.
                params['gold_file'] = chunked_reference_file.replace(
                    "test", "valid")
                print(
                    f"Evaluating gold point on {params['gold_file']} with KNN={args.knn}"
                )
                params['num_sentences'] = args.num_sentences
                params['reference_corpus'] = chunked_reference_file
                params['chunk'] = args.max_seq_length
                params['eval_method'] = args.eval_method
                params['knn'] = args.knn
                params['gram'] = args.gram
                params['device'] = device
                score_gold(params)
                result = plot.plot_gold(params)

        plot.save(plots_file.replace(".json", ""))
        # end critical section!
        utils.unlock(results_file)