Beispiel #1
0
def main(parsed_args):
    """Score a language-modeling split and log its base-2 loss and perplexity.

    The checkpoint ensemble named by ``parsed_args.path`` is loaded, the
    ``gen_subset`` split is scored with ``SequenceScorer`` and the summed
    positional scores are turned into an average loss / perplexity.

    Two optional kNN-LM modes are supported (mutually exclusive):

    * ``save_knnlm_dstore``: the ``dstore_keys`` returned for every hypothesis
      and the matching tokens are written to two ``np.memmap`` files
      (``<dstore_mmap>_keys.npy`` / ``<dstore_mmap>_vals.npy``).
    * ``knnlm``: a ``KNN_Dstore`` is built and handed to the scorer.

    Args:
        parsed_args: parsed command-line namespace; ``path`` must be set.

    Raises:
        AssertionError: if ``parsed_args.path`` is None, no model was loaded,
            or ``add_bos_token`` is set and a hypothesis does not start with BOS.
        ValueError: if both ``knnlm`` and ``save_knnlm_dstore`` are requested.
        NotImplementedError: if ``remove_bpe`` is ``'sentencepiece'``.
    """
    assert parsed_args.path is not None, '--path required for evaluation!'

    utils.import_user_module(parsed_args)

    logger.info(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    logger.info('loading model(s) from {}'.format(parsed_args.path))
    # NOTE(review): eval() executes arbitrary code from --model-overrides;
    # only safe while that option comes from a trusted command line.
    models, args = checkpoint_utils.load_model_ensemble(
        parsed_args.path.split(os.pathsep),
        arg_overrides=eval(parsed_args.model_overrides),
        task=task,
    )

    # Command-line values override the checkpoint's stored args, except for
    # these options, which keep the value stored with the checkpoint.
    checkpoint_only_args = {
        'self_target',
        'future_target',
        'past_target',
        'tokens_per_sample',
        'output_size_dictionary',
        'add_bos_token',
    }
    for arg, value in vars(parsed_args).items():
        if arg not in checkpoint_only_args:
            setattr(args, arg, value)

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    logger.info('{} {} {} examples'.format(args.data, args.gen_subset,
                                           len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    assert len(models) > 0

    logger.info('num. model params: {}'.format(
        sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary,
                            args.softmax_batch,
                            args=args)

    score_sum = 0.
    count = 0

    # Collect the dictionary indices of BPE continuation tokens so their
    # scores can be merged into the following token.
    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            bpe_toks = {
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            }
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    if args.knnlm and args.save_knnlm_dstore:
        raise ValueError(
            "Cannot use knnlm while trying to build the datastore!")

    if args.knnlm:
        knn_dstore = KNN_Dstore(args)

    # Bound up front so the summary print below cannot hit an unbound name
    # when no hypothesis was ever stored.
    shape = None

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()

        if args.save_knnlm_dstore:
            print('keytype being saved:', args.knn_keytype)
            if args.dstore_fp16:
                print('Saving fp16')
                key_dtype, val_dtype = np.float16, np.int16
            else:
                print('Saving fp32')
                # `np.int` was only an alias of the builtin `int` and was
                # removed in NumPy 1.24; `int` yields the same dtype.
                key_dtype, val_dtype = np.float32, int
            dstore_keys = np.memmap(args.dstore_mmap + '_keys.npy',
                                    dtype=key_dtype,
                                    mode='w+',
                                    shape=(args.dstore_size,
                                           args.decoder_embed_dim))
            dstore_vals = np.memmap(args.dstore_mmap + '_vals.npy',
                                    dtype=val_dtype,
                                    mode='w+',
                                    shape=(args.dstore_size, 1))

        dstore_idx = 0
        for ex_i, sample in enumerate(t):
            if 'net_input' not in sample:
                continue

            sample = utils.move_to_cuda(sample) if use_cuda else sample

            gen_timer.start()
            if args.knnlm:
                hypos = scorer.generate(models, sample, knn_dstore=knn_dstore)
            else:
                hypos = scorer.generate(models, sample)
            gen_timer.stop(sample['ntokens'])

            for hyp_idx, hypos_i in enumerate(hypos):
                hypo = hypos_i[0]
                if args.save_knnlm_dstore:
                    shape = hypo['dstore_keys'].shape
                    # Only full-length samples are stored.
                    if shape[0] == args.tokens_per_sample:
                        store_tokens = hypo['tokens']
                        if dstore_idx + shape[0] > args.dstore_size:
                            # Datastore is almost full: keep only what fits.
                            # Keys AND values are truncated together so both
                            # memmap assignments match their target slices;
                            # hypo['tokens'] itself stays intact for scoring.
                            shape = [args.dstore_size - dstore_idx]
                            hypo['dstore_keys'] = hypo[
                                'dstore_keys'][:shape[0]]
                            store_tokens = store_tokens.view(-1)[:shape[0]]

                        dstore_end = dstore_idx + shape[0]
                        dstore_keys[dstore_idx:dstore_end] = hypo[
                            'dstore_keys'].view(
                                -1, args.decoder_embed_dim).cpu().numpy(
                                ).astype(key_dtype)
                        dstore_vals[dstore_idx:dstore_end] = store_tokens.view(
                            -1, 1).cpu().numpy().astype(val_dtype)

                        dstore_idx += shape[0]
                    else:
                        print('Skipping this one with shape', shape)

                sample_id = sample['id'][hyp_idx]

                tokens = hypo['tokens']
                pos_scores = hypo['positional_scores'].float()

                if args.add_bos_token:
                    assert hypo['tokens'][0].item(
                    ) == task.target_dictionary.bos()
                    tokens = tokens[1:]
                    pos_scores = pos_scores[1:]

                skipped_toks = 0
                if bpe_toks is not None:
                    # Fold the score of each BPE continuation token into the
                    # next token.  The bound is taken from the (possibly
                    # BOS-stripped) tokens so that `tok_i + 1` always exists.
                    for tok_i in range(tokens.numel() - 1):
                        if tokens[tok_i].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[tok_i + 1] += pos_scores[tok_i]
                            pos_scores[tok_i] = 0

                score_sum += pos_scores.sum().cpu()
                count += pos_scores.numel() - skipped_toks

                if args.output_word_probs or args.output_word_stats:
                    w = ''
                    word_prob = []
                    is_bpe = False
                    num_tokens = len(tokens)
                    for tok_i in range(num_tokens):
                        w_ind = tokens[tok_i].item()
                        w += task.source_dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            # Strip the continuation marker and keep building
                            # the current word.
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[tok_i].item()))

                            # Score of the next token with a non-zero score
                            # (merged BPE pieces were zeroed above).
                            next_prob = None
                            ind = tok_i + 1
                            while ind < num_tokens:
                                if pos_scores[ind].item() != 0:
                                    next_prob = pos_scores[ind]
                                    break
                                ind += 1

                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                                pos_scores[tok_i].item(), next_prob)
                            is_bpe = False
                            w = ''
                    if args.output_word_probs:
                        logger.info(
                            str(int(sample_id)) + " " +
                            ('\t'.join('{} [{:2f}]'.format(x[0], x[1])
                                       for x in word_prob)))

            wps_meter.update(sample['ntokens'])
            t.log({'wps': round(wps_meter.avg)})

    if args.save_knnlm_dstore:
        print("dstore_idx", dstore_idx, "final shape", shape)
        print("Keys", dstore_keys.shape, dstore_keys.dtype)
        print("Vals", dstore_vals.shape, dstore_vals.dtype)

    avg_nll_loss = -score_sum / count / math.log(2)  # convert to base 2
    logger.info('Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
        gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    logger.info('Loss (base 2): {:.4f}, Perplexity: {:.2f}'.format(
        avg_nll_loss, 2**avg_nll_loss))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(),
                         key=lambda x: x.count,
                         reverse=True):
            logger.info(ws)
Beispiel #2
0
def _main(args, output_file):
    """Generate hypotheses for ``args.gen_subset`` and score them with BLEU.

    Loads the checkpoint ensemble from ``args.path``, runs
    ``task.inference_step`` over the split and prints the ``S-``/``T-``/
    ``H-``/``P-``/``A-``/``I-``/``E-`` lines to ``output_file``.

    kNN-LM extensions (mutually exclusive):

    * ``save_knnlm_dstore``: the ``dstore_keys`` of each top hypothesis and its
      tokens are written to ``<dstore_mmap>_keys.npy`` / ``_vals.npy``.
    * ``knnlm``: a ``KNN_Dstore`` is passed to ``task.inference_step``.

    In either mode the tokens of each top hypothesis are also dumped, one per
    line, to ``<output_tokens_file_prefix>.tgt``.

    Args:
        args: parsed command-line namespace.
        output_file: writable text stream used for logging and result lines.

    Returns:
        The BLEU scorer (``bleu.SacrebleuScorer`` or ``bleu.Scorer``).

    Raises:
        ValueError: if both ``knnlm`` and ``save_knnlm_dstore`` are requested.
        AssertionError: if a kNN-LM mode is used with ``args.task`` equal to
            ``'language_modeling'``.
    """
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger('fairseq_cli.generate')

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    # NOTE(review): eval() executes arbitrary code from --model-overrides;
    # only safe while that option comes from a trusted command line.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(os.pathsep),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Default the vocab size to the target dictionary; values stored with the
    # checkpoint take precedence for these two options.
    args.vocab_size = len(tgt_dict)
    for arg, value in vars(_model_args).items():
        if arg in {'decoder_embed_dim', 'vocab_size'}:
            setattr(args, arg, value)

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    if args.knnlm and args.save_knnlm_dstore:
        raise ValueError(
            "Cannot use knnlm while trying to build the datastore!")
    if args.knnlm:
        knn_dstore = KNN_Dstore(args)

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())

    # Bound up front so the summary print below cannot hit an unbound name
    # when no hypothesis was ever stored.
    shape = None

    if args.save_knnlm_dstore:
        print('keytype being saved:', args.knn_keytype)
        if args.dstore_fp16:
            print('Saving fp16')
            key_dtype, val_dtype = np.float16, np.int16
        else:
            print('Saving fp32')
            # `np.int` was only an alias of the builtin `int` and was removed
            # in NumPy 1.24; `int` yields the same dtype.
            key_dtype, val_dtype = np.float32, int
        dstore_keys = np.memmap(args.dstore_mmap + '_keys.npy',
                                dtype=key_dtype,
                                mode='w+',
                                shape=(args.dstore_size,
                                       args.decoder_embed_dim))
        dstore_vals = np.memmap(args.dstore_mmap + '_vals.npy',
                                dtype=val_dtype,
                                mode='w+',
                                shape=(args.dstore_size, 1))
        dstore_idx = 0
    if args.save_knnlm_dstore or args.knnlm:
        # 'w+' (not 'w'): the file is read back at the end to cross-check the
        # number of dumped tokens against the datastore size.
        target_tokens_file = open(args.output_tokens_file_prefix + '.tgt',
                                  'w+')

        # This is only for MT right now, use interactive.py for language modeling.
        # Compare the task *name*; comparing the task object itself to a
        # string could never fail.
        assert getattr(args, 'task', None) != 'language_modeling'

    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            if args.knnlm:
                hypos = task.inference_step(generator,
                                            models,
                                            sample,
                                            prefix_tokens,
                                            knn_dstore=knn_dstore)
            else:
                hypos = task.inference_step(generator, models, sample,
                                            prefix_tokens)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            if args.save_knnlm_dstore:
                for hypos_i in hypos:
                    hypo = hypos_i[0]
                    shape = hypo['dstore_keys'].shape
                    store_tokens = hypo['tokens']
                    if dstore_idx + shape[0] > args.dstore_size:
                        # Datastore is almost full: keep only what fits.
                        # Keys AND values are truncated together so both
                        # memmap assignments match their target slices;
                        # hypo['tokens'] itself stays intact for the output.
                        shape = [args.dstore_size - dstore_idx]
                        hypo['dstore_keys'] = hypo['dstore_keys'][:shape[0]]
                        store_tokens = store_tokens.view(-1)[:shape[0]]

                    dstore_end = dstore_idx + shape[0]
                    dstore_keys[dstore_idx:dstore_end] = hypo[
                        'dstore_keys'].view(
                            -1, args.decoder_embed_dim).cpu().numpy().astype(
                                key_dtype)
                    dstore_vals[dstore_idx:dstore_end] = store_tokens.view(
                        -1, 1).cpu().numpy().astype(val_dtype)
                    dstore_idx += shape[0]

            if args.save_knnlm_dstore or args.knnlm:
                for hypos_i in hypos:
                    hypo = hypos_i[0]

                    # dump the tokens to a file, used for analysis and interactive printing
                    dumped_tokens = [
                        task.target_dictionary[token]
                        for token in hypo['tokens']
                    ]
                    target_tokens_file.write('\n'.join(dumped_tokens) + '\n')

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding
                src_tokens = utils.strip_pad(
                    sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(
                        sample['target'][i, :], tgt_dict.pad()).int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(
                        args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(
                        args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)

                if not args.quiet:
                    if src_dict is not None:
                        print('S-{}\t{}'.format(sample_id, src_str),
                              file=output_file)
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str),
                              file=output_file)

                # Process top predictions
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'],
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    if not args.quiet:
                        score = hypo['score'] / math.log(
                            2)  # convert to base 2
                        print('H-{}\t{}\t{}'.format(sample_id, score,
                                                    hypo_str),
                              file=output_file)
                        # convert from base e to base 2 (in place)
                        base2_scores = hypo['positional_scores'].div_(
                            math.log(2)).tolist()
                        print('P-{}\t{}'.format(
                            sample_id,
                            ' '.join(map('{:.4f}'.format, base2_scores))),
                              file=output_file)

                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id, ' '.join([
                                    '{}-{}'.format(src_idx, tgt_idx)
                                    for src_idx, tgt_idx in alignment
                                ])),
                                  file=output_file)

                        if args.print_step:
                            print('I-{}\t{}'.format(sample_id, hypo['steps']),
                                  file=output_file)

                        if getattr(args, 'retain_iter_history', False):
                            for step, h in enumerate(hypo['history']):
                                _, h_str, _ = utils.post_process_prediction(
                                    hypo_tokens=h['tokens'].int().cpu(),
                                    src_str=src_str,
                                    alignment=None,
                                    align_dict=None,
                                    tgt_dict=tgt_dict,
                                    remove_bpe=None,
                                )
                                print('E-{}_{}\t{}'.format(
                                    sample_id, step, h_str),
                                      file=output_file)
                    # Score only the top hypothesis
                    if has_target and j == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tgt_dict.encode_line(
                                target_str, add_if_not_exist=True)
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    logger.info('NOTE: hypothesis and token scores are output in base 2')
    logger.info(
        'Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        logger.info('Generate {} with beam={}: {}'.format(
            args.gen_subset, args.beam, scorer.result_string()))

    if args.save_knnlm_dstore:
        print("dstore_idx", dstore_idx, "final shape", shape)
        print("Keys", dstore_keys.shape, dstore_keys.dtype)
        print("Vals", dstore_vals.shape, dstore_vals.dtype)
        # Rewind and count the dumped tokens (one per line); this needs the
        # file to have been opened readable.
        target_tokens_file.seek(0)
        num_lines = len(target_tokens_file.readlines())
        if dstore_idx != num_lines:
            print(
                'Warning: size of KNN datastore is {}, does not match number of lines in train tokens file which is {}'
                .format(dstore_idx, num_lines))

    if args.save_knnlm_dstore or args.knnlm:
        target_tokens_file.close()

    return scorer
def main(args):
    """Interactively translate buffered input and print kNN neighbours.

    Sentences are read with ``buffered_read`` from ``args.input``, batched with
    ``make_batches`` and decoded with ``task.inference_step``.  For every
    printed hypothesis the ``H-``/``P-`` (and optionally ``A-``) lines are
    written to stdout.  With ``--knnlm`` the training tokens are loaded from
    ``args.input_tokens_file`` and, for every hypothesis, the highest-valued
    entries of ``hypo['dists_full']`` are printed together with the training
    context that precedes them.

    Args:
        args: parsed command-line namespace.

    Raises:
        AssertionError: if ``--sampling`` is used with ``nbest != beam``, if
            ``max_sentences`` exceeds ``buffer_size``, or if ``--knnlm`` is set
            and a hypothesis carries no ``dists_full``.
    """
    utils.import_user_module(args)

    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
        '--max-sentences/--batch-size cannot be larger than --buffer-size'

    logger.info(args)
    use_cuda = torch.cuda.is_available() and not args.cpu

    # Setup task, e.g., translation
    task = tasks.setup_task(args)

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    # NOTE(review): eval() executes arbitrary code from --model-overrides;
    # only safe while that option comes from a trusted command line.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(os.pathsep),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Only these two options are copied over from the checkpoint's args.
    for arg, value in vars(_model_args).items():
        if arg in {'decoder_embed_dim', 'vocab_size'}:
            setattr(args, arg, value)

    logger.info(args)

    if args.knnlm:
        knn_dstore = KNN_Dstore(args)
        print("Loading training tokens...")
        with open(args.input_tokens_file) as infile:
            train_tokens = infile.read().split()

        print("TODO, REMOVE ME\n\n\n !!!!Skipping first training tokens...")
        train_tokens = train_tokens[3072:]  # TODO, remove this

    if args.buffer_size < 1:
        args.buffer_size = 1
    if args.max_tokens is None and args.max_sentences is None:
        args.max_sentences = 1

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Initialize generator
    generator = task.build_generator(args)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(args)
    bpe = encoders.build_bpe(args)

    def encode_fn(x):
        """Apply the tokenizer, then BPE, to raw input text."""
        if tokenizer is not None:
            x = tokenizer.encode(x)
        if bpe is not None:
            x = bpe.encode(x)
        return x

    def decode_fn(x):
        """Undo BPE, then tokenization, on a hypothesis string."""
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    max_positions = utils.resolve_max_positions(
        task.max_positions(), *[model.max_positions() for model in models])

    if args.buffer_size > 1:
        logger.info('Sentence buffer size: %s', args.buffer_size)
    logger.info('NOTE: hypothesis and token scores are output in base 2')
    logger.info('Type the input sentence and press return:')

    # Neighbour dump settings: preceding training tokens shown per neighbour
    # and number of neighbours listed per hypothesis.
    context_size = 20
    num_neighbors = 10

    start_id = 0
    for inputs in buffered_read(args.input, args.buffer_size):
        results = []
        for batch in make_batches(inputs, args, task, max_positions,
                                  encode_fn):
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            if use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()

            sample = {
                'net_input': {
                    'src_tokens': src_tokens,
                    'src_lengths': src_lengths,
                },
            }
            if args.knnlm:
                translations = task.inference_step(generator,
                                                   models,
                                                   sample,
                                                   None,
                                                   knn_dstore=knn_dstore)
            else:
                translations = task.inference_step(generator, models, sample,
                                                   None)
            for i, (batch_id, hypos) in enumerate(
                    zip(batch.ids.tolist(), translations)):
                src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
                results.append((start_id + batch_id, src_tokens_i, hypos))

        # sort output to match input order
        for sent_id, sent_src_tokens, hypos in sorted(results,
                                                      key=lambda x: x[0]):
            # Default to an empty source string so post-processing below
            # never sees an unbound name when there is no source dictionary.
            src_str = ""
            if src_dict is not None:
                src_str = src_dict.string(sent_src_tokens, args.remove_bpe)
                print('S-{}\t{}'.format(sent_id, src_str))

            # Process top predictions
            for hypo in hypos[:min(len(hypos), args.nbest)]:
                # The neighbour dump needs `train_tokens`, which is only
                # loaded when --knnlm is active.
                if args.knnlm:
                    assert hypo['dists_full'] is not None
                    # Moved to CPU so NumPy can sort it.
                    dists_full = hypo['dists_full'].float().cpu()
                    knns_full = hypo['knns_full']
                    word_tokens = [
                        task.target_dictionary[token]
                        for token in hypo['tokens']
                    ]

                    # TODO, trim off padding when its batched

                    print("Example:", " ".join(word_tokens))
                    print(dists_full)
                    # Indices of the largest values, best first.
                    best_dist_indices = np.argsort(
                        dists_full)[-num_neighbors:][::-1]
                    for j, neighbor_index in enumerate(best_dist_indices):
                        distance = dists_full[neighbor_index]
                        knn_index = int(knns_full[neighbor_index])
                        # Clamp at 0: a negative slice start would wrap
                        # around to the end of the token list.
                        context_start = max(0, knn_index - context_size)
                        print(
                            "Best neighbor {} (distance {:.2f}):".format(
                                j, distance),
                            " ".join(train_tokens[context_start:knn_index]),
                            "[[", train_tokens[knn_index], "]]")

                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )
                hypo_str = decode_fn(hypo_str)
                score = hypo['score'] / math.log(2)  # convert to base 2
                print('H-{}\t{}\t{}'.format(sent_id, score, hypo_str))
                # convert from base e to base 2 (in place)
                base2_scores = hypo['positional_scores'].div_(
                    math.log(2)).tolist()
                print('P-{}\t{}'.format(
                    sent_id, ' '.join(map('{:.4f}'.format, base2_scores))))
                if args.print_alignment:
                    alignment_str = " ".join(
                        ["{}-{}".format(src, tgt) for src, tgt in alignment])
                    print('A-{}\t{}'.format(sent_id, alignment_str))

        # update running id counter
        start_id += len(inputs)
Beispiel #4
0
def main(parsed_args):
    """Encode a dataset with a HuggingFace LM and optionally build a kNN-LM datastore.

    Each example is re-tokenized word by word with the HuggingFace tokenizer,
    run through the HuggingFace model, and the last-layer hidden state at the
    final sub-token of every word is (optionally) written as a datastore key.
    The matching value is the id of the *next* word in the task vocabulary.

    Args:
        parsed_args: parsed command-line namespace. Attributes read here
            include ``dstore_mmap``, ``dstore_size``, ``hf_model``,
            ``hf_enc_mode`` (``'masked'`` or ``'causal'``),
            ``hf_max_position``, ``context_window``, ``tokens_per_sample``,
            ``gen_subset``, ``knnlm``, ``save_knnlm_dstore``, ``save_extra``
            and the usual batching options.

    Raises:
        ValueError: if ``hf_enc_mode`` is not supported, if ``knnlm`` and
            ``save_knnlm_dstore`` are both set, or if ``save_knnlm_dstore``
            is set without ``dstore_mmap``.
        NotImplementedError: if ``remove_bpe == 'sentencepiece'``.
    """
    if parsed_args.dstore_mmap is not None:
        d = os.path.dirname(parsed_args.dstore_mmap)
        print('mmap from {}'.format(d))
        # An empty dirname means "current directory": nothing to create.
        if d and not os.path.exists(d):
            print('making dir')
            # os.makedirs avoids building a shell command from a user path.
            os.makedirs(d, exist_ok=True)

    utils.import_user_module(parsed_args)

    logger.info(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load model.
    hf_tokenizer = AutoTokenizer.from_pretrained(parsed_args.hf_model)
    if parsed_args.hf_enc_mode == 'masked':
        hf_model = AutoModelForMaskedLM.from_pretrained(parsed_args.hf_model)
    elif parsed_args.hf_enc_mode == 'causal':
        hf_model = AutoModelForCausalLM.from_pretrained(parsed_args.hf_model)
    else:
        # Previously this fell through and crashed later with a NameError.
        raise ValueError(
            "Unsupported hf_enc_mode {!r}; expected 'masked' or 'causal'.".format(
                parsed_args.hf_enc_mode))

    if use_cuda:
        hf_model.cuda()

    device = next(hf_model.parameters()).device

    # Width of the hidden states used as datastore keys. Prefer ``d_model``
    # (the attribute this script originally required) and fall back to
    # ``hidden_size`` for configs that do not define it.
    hidden_dim = getattr(hf_model.config, 'd_model', None)
    if hidden_dim is None:
        hidden_dim = hf_model.config.hidden_size

    # Probe the tokenizer to learn whether it wraps inputs in CLS/SEP tokens.
    check_input_ids = hf_tokenizer('hello world')['input_ids']
    add_cls_token = check_input_ids[0] == hf_tokenizer.cls_token_id
    add_sep_token = check_input_ids[-1] == hf_tokenizer.sep_token_id
    print('add_cls_token = {} {} {}'.format(add_cls_token, hf_tokenizer.cls_token, hf_tokenizer.cls_token_id))
    print('add_sep_token = {} {} {}'.format(add_sep_token, hf_tokenizer.sep_token, hf_tokenizer.sep_token_id))

    args = copy.deepcopy(parsed_args)

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    task_dataset = task.dataset(args.gen_subset)
    assert args.context_window > 0
    dataset = LMContextWindowDataset(
        dataset=task_dataset,
        tokens_per_sample=args.tokens_per_sample,
        context_window=args.context_window,
        pad_idx=task.source_dictionary.pad(),
    )
    logger.info('{} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))

    model_max_length = min(hf_tokenizer.model_max_length, parsed_args.hf_max_position)

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(model_max_length),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # NOTE(review): gen_timer, scorer, bpe_toks/bpe_len and knn_dstore are not
    # used further down in this function; they are kept because constructing
    # them may validate the configuration — confirm before removing.
    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch, args=args)

    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            bpe_toks = {
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            }
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    if args.knnlm and args.save_knnlm_dstore:
        raise ValueError("Cannot use knnlm while trying to build the datastore!")

    if args.save_knnlm_dstore and args.dstore_mmap is None:
        # Fail with a clear message instead of a TypeError on `None + str`.
        raise ValueError("--dstore-mmap is required when saving the datastore!")

    if args.knnlm:
        knn_dstore = KNN_Dstore(args)

    # Padding id for HF inputs. Tokenizers without a pad token fall back to 0;
    # padded positions are never stored because their context mask is -1.
    hf_pad_id = hf_tokenizer.pad_token_id
    if hf_pad_id is None:
        hf_pad_id = 0

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()

        if args.save_knnlm_dstore:
            print('keytype being saved:', args.knn_keytype)
            dstore_keys = np.memmap(args.dstore_mmap+'_keys.npy', dtype=np.float32, mode='w+', shape=(args.dstore_size, hidden_dim))
            # np.int was removed in NumPy 1.24; use an explicit 64-bit integer
            # so the on-disk layout does not depend on the platform.
            dstore_vals = np.memmap(args.dstore_mmap+'_vals.npy', dtype=np.int64, mode='w+', shape=(args.dstore_size, 1))

        if args.save_extra:
            writer = Writer(outdir='demo-out', max_size=args.save_extra_max_size, k=args.k, vec_size=1024)

        def pad(x, pad_id=-1):
            """Right-pad every list in ``x`` with ``pad_id`` to the longest length."""
            max_len = max(len(xx) for xx in x)
            return [xx + [pad_id] * (max_len - len(xx)) for xx in x]

        def batchify(batch):
            """Convert the per-example Python lists into padded long tensors."""
            new_batch = {}
            new_batch['input_ids'] = torch.tensor(pad(batch['src_tokens'], hf_pad_id), dtype=torch.long, device=device)
            new_batch['context_mask'] = torch.tensor(pad(batch['mask'], -1), dtype=torch.long, device=device)
            new_batch['word_id'] = torch.tensor(pad(batch['word_id'], -1), dtype=torch.long, device=device)
            new_batch['target'] = torch.tensor(pad(batch['target'], -1), dtype=torch.long, device=device)
            return new_batch

        dstore_idx = 0
        dstore_full = False
        num_tokens = 0
        # Shape of the most recent datastore write; stays None if nothing was
        # written so the summary print below cannot hit an unbound name.
        shape = None
        for sample in tqdm(t, desc='encode'):
            if 'net_input' not in sample:
                continue

            # Source tokens plus the final target token give the full word sequence.
            all_tokens = torch.cat([sample['net_input']['src_tokens'], sample['target'][:, -1, None]], -1)

            hf_batch = collections.defaultdict(list)
            for tok in all_tokens.tolist():
                tok = [tt for tt in tok if tt != dataset.pad_idx]
                raw_text = [task_dataset.vocab[tt] for tt in tok]
                hf_src_tokens, hf_target, hf_raw_target, hf_word_id, hf_mask = [], [], [], [], []
                # The last word has no successor, so it is only ever a target.
                for i_w in range(len(raw_text) - 1):

                    w = raw_text[i_w]
                    tok_ = hf_tokenizer.encode(w, add_special_tokens=False)
                    if i_w == 0 and add_cls_token:
                        # Slice comparison also covers a word that encodes to no tokens.
                        if tok_[:1] != [hf_tokenizer.cls_token_id]:
                            tok_ = [hf_tokenizer.cls_token_id] + tok_

                    # Stop before exceeding what the HF model can attend to.
                    if len(hf_src_tokens) + len(tok_) > model_max_length:
                        break

                    hf_src_tokens += tok_
                    hf_word_id += [i_w] * len(tok_)
                    # Only the last sub-token of each word is a datastore position.
                    hf_mask += [0] * (len(tok_) - 1) + [1]
                    hf_target += [tok[i_w + 1]] * len(tok_)
                    hf_raw_target += [raw_text[i_w + 1]]

                assert len(hf_src_tokens) == len(hf_target)
                assert len(hf_src_tokens) == len(hf_word_id)
                assert len(hf_src_tokens) == len(hf_mask)

                hf_batch['src_tokens'].append(hf_src_tokens)
                hf_batch['target'].append(hf_target) # This is indexed by KNN-LM tokenizer.
                hf_batch['raw_target'].append(hf_raw_target)
                hf_batch['word_id'].append(hf_word_id)
                hf_batch['mask'].append(hf_mask)

                num_tokens += len(hf_src_tokens)

            hf_batch_ = batchify(hf_batch)

            # Inference only: without no_grad the hidden states require grad,
            # which wastes memory and makes the later `.numpy()` call fail.
            # NOTE(review): no attention mask is passed, so for bidirectional
            # encoders padding may influence real positions — confirm intent.
            with torch.no_grad():
                model_output = hf_model(hf_batch_['input_ids'], output_hidden_states=True)

            h = model_output['hidden_states'][-1]

            assert h.shape[:2] == hf_batch_['input_ids'].shape[:2]

            if args.save_knnlm_dstore and not dstore_full:

                flat_h = h.view(-1, hidden_dim)
                mask_ = hf_batch_['context_mask'].view(-1) == 1
                keys_ = flat_h[mask_]
                vals_ = hf_batch_['target'].view(-1, 1)[mask_]

                shape = keys_.shape
                if dstore_idx + shape[0] > args.dstore_size:
                    # Truncate the final write so it fits the preallocated store.
                    shape = [args.dstore_size - dstore_idx]
                    dstore_full = True

                keys_ = keys_[:shape[0]]
                vals_ = vals_[:shape[0]]

                assert keys_.shape[0] == vals_.shape[0]

                dstore_keys[dstore_idx:shape[0]+dstore_idx] = keys_.cpu().numpy().astype(np.float32)
                dstore_vals[dstore_idx:shape[0]+dstore_idx] = vals_.cpu().numpy().astype(np.int64)

                dstore_idx += shape[0]
            if dstore_full:
                print('Datastore is full with {} items.'.format(args.dstore_size))

            wps_meter.update(sample['ntokens'])
            t.log({'wps': round(wps_meter.avg)})

            # Write saved values to disk.
            if args.save_extra:
                # NOTE(review): `extra` is never assigned in this function, so
                # this raises NameError unless it is defined at module level —
                # confirm what payload should be written.
                writer.update(extra)

    if args.save_knnlm_dstore:
        # Make sure everything reaches disk before reporting.
        dstore_keys.flush()
        dstore_vals.flush()
        print("dstore_idx", dstore_idx, "final shape", shape)
        print("Keys", dstore_keys.shape, dstore_keys.dtype)
        print("Vals", dstore_vals.shape, dstore_vals.dtype)

    logger.info('done with {} tokens'.format(num_tokens))