Example #1
def eval_lm_dataloader(
    self,
    dataset,
    max_tokens: Optional[int] = 36000,
    batch_size: Optional[int] = None,
    max_positions: Optional[int] = None,
    num_shards: int = 1,
    shard_id: int = 0,
    num_workers: int = 1,
    data_buffer_size: int = 10,
    # ensures that every evaluated token has access to a context of at least
    # this size, if possible
    context_window: int = 0,
):
    if context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=self.args.tokens_per_sample,
            context_window=context_window,
            pad_idx=self.source_dictionary.pad(),
        )
    return self.get_batch_iterator(
        dataset=dataset,
        max_tokens=max_tokens,
        max_sentences=batch_size,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        num_shards=num_shards,
        shard_id=shard_id,
        num_workers=num_workers,
        data_buffer_size=data_buffer_size,
    ).next_epoch_itr(shuffle=False)
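
A small standalone sketch of the budget arithmetic this context_window argument implies. The later examples subtract the context window from tokens_per_sample before loading data, so each forward pass freshly scores only the remaining positions and the corpus is covered in correspondingly more windows. The function and numbers below are illustrative only and not part of fairseq.

import math

def num_eval_windows(n_tokens, tokens_per_sample, context_window):
    # Each window freshly scores at most (tokens_per_sample - context_window)
    # tokens; the other positions are reused as left context for them.
    stride = tokens_per_sample - context_window
    assert stride > 0, "context_window must be smaller than tokens_per_sample"
    return math.ceil(n_tokens / stride)

# Illustrative numbers: a 10,000-token corpus, 3,072-token windows, 2,560 of
# which are reserved as context, leaving 512 freshly scored tokens per pass.
print(num_eval_windows(10_000, 3_072, 2_560))  # 20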
Example #2
def main(cfg: DictConfig, override_args=None, **unused_kwargs):
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    use_fp16 = cfg.common.fp16
    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    if use_cuda:
        torch.cuda.set_device(cfg.distributed_training.device_id)

    if override_args is not None:
        overrides = vars(override_args)
        overrides.update(eval(getattr(override_args, "model_overrides", "{}")))
    else:
        overrides = None

    logger.info(cfg)

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))

    # reduce tokens per sample by the required context window size
    cfg.task.tokens_per_sample -= cfg.eval_lm.context_window

    models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
        [cfg.common_eval.path],
        arg_overrides=overrides,
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
    )

    # Load dataset splits
    gen_subset = cfg.dataset.gen_subset
    task.load_dataset(gen_subset)
    dataset = task.dataset(gen_subset)
    if cfg.eval_lm.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=cfg.task.tokens_per_sample,
            context_window=cfg.eval_lm.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    logger.info("{} {} {} examples".format(cfg.task.data, gen_subset, len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        if use_fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    assert len(models) > 0

    logger.info(
        "num. model params: {}".format(sum(p.numel() for p in models[0].parameters()))
    )

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=cfg.dataset.max_tokens or 36000,
        max_sentences=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=True,
        num_shards=max(
            cfg.dataset.num_shards,
            cfg.distributed_training.distributed_world_size,
        ),
        shard_id=max(
            cfg.dataset.shard_id,
            cfg.distributed_training.distributed_rank,
        ),
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
    )

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, cfg.eval_lm.softmax_batch)

    score_sum = 0.0
    count = 0

    if cfg.common_eval.remove_bpe is not None:
        if cfg.common_eval.remove_bpe == "sentencepiece":
            raise NotImplementedError
        else:
            bpe_cont = cfg.common_eval.remove_bpe.rstrip()
            bpe_toks = {
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            }
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    wps_meter = TimeMeter()

    for sample in progress:
        if "net_input" not in sample:
            continue

        sample = utils.move_to_cuda(sample) if use_cuda else sample

        gen_timer.start()
        hypos = scorer.generate(models, sample)
        gen_timer.stop(sample["ntokens"])

        for i, hypos_i in enumerate(hypos):
            hypo = hypos_i[0]
            sample_id = sample["id"][i]

            tokens = hypo["tokens"]
            tgt_len = tokens.numel()
            pos_scores = hypo["positional_scores"].float()

            if cfg.task.add_bos_token:
                assert hypo["tokens"][0].item() == task.target_dictionary.bos()
                tokens = tokens[1:]
                pos_scores = pos_scores[1:]

            skipped_toks = 0
            if bpe_toks is not None:
                for i in range(tgt_len - 1):
                    if tokens[i].item() in bpe_toks:
                        skipped_toks += 1
                        pos_scores[i + 1] += pos_scores[i]
                        pos_scores[i] = 0

            inf_scores = pos_scores.eq(float("inf")) | pos_scores.eq(float("-inf"))
            if inf_scores.any():
                logger.info(
                    "skipping tokens with inf scores: %s",
                    task.target_dictionary.string(tokens[inf_scores.nonzero()]),
                )
                pos_scores = pos_scores[(~inf_scores).nonzero()]
            score_sum += pos_scores.sum().cpu()
            count += pos_scores.numel() - skipped_toks

            if cfg.eval_lm.output_word_probs or cfg.eval_lm.output_word_stats:
                w = ""
                word_prob = []
                is_bpe = False
                for i in range(len(tokens)):
                    w_ind = tokens[i].item()
                    w += task.source_dictionary[w_ind]
                    if bpe_toks is not None and w_ind in bpe_toks:
                        w = w[:-bpe_len]
                        is_bpe = True
                    else:
                        word_prob.append((w, pos_scores[i].item()))

                        next_prob = None
                        ind = i + 1
                        while ind < len(tokens):
                            if pos_scores[ind].item() != 0:
                                next_prob = pos_scores[ind]
                                break
                            ind += 1

                        word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                            pos_scores[i].item(), next_prob
                        )
                        is_bpe = False
                        w = ""
                if cfg.eval_lm.output_word_probs:
                    logger.info(
                        str(int(sample_id))
                        + " "
                        + (
                            "\t".join(
                                "{} [{:2f}]".format(x[0], x[1]) for x in word_prob
                            )
                        )
                    )

        wps_meter.update(sample["ntokens"])
        progress.log({"wps": round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count / math.log(2)  # convert to base 2
    logger.info(
        "Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)".format(
            gen_timer.n, gen_timer.sum, 1.0 / gen_timer.avg
        )
    )
    logger.info(
        "Loss (base 2): {:.4f}, Perplexity: {:.2f}".format(
            avg_nll_loss, 2 ** avg_nll_loss
        )
    )

    if cfg.eval_lm.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            logger.info(ws)
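
A minimal standalone illustration of the final scoring arithmetic above: positional_scores are natural-log probabilities, so dividing the mean negative score by math.log(2) gives the loss in bits, and perplexity is 2 raised to that loss. The token scores below are made up.

import math

pos_scores = [-2.1, -0.3, -4.7, -1.0]  # hypothetical ln-probabilities of four scored tokens
score_sum = sum(pos_scores)
count = len(pos_scores)

avg_nll_loss = -score_sum / count / math.log(2)  # convert nats to bits
ppl = 2 ** avg_nll_loss

# Sanity check: the same perplexity computed directly in nats.
assert abs(ppl - math.exp(-score_sum / count)) < 1e-9
print("Loss (base 2): {:.4f}, Perplexity: {:.2f}".format(avg_nll_loss, ppl))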
Example #3
def main(parsed_args, **unused_kwargs):
    assert parsed_args.path is not None, '--path required for evaluation!'

    if torch.cuda.is_available() and not parsed_args.cpu:
        torch.cuda.set_device(parsed_args.device_id)

    utils.import_user_module(parsed_args)

    logger.info(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    logger.info('loading model(s) from {}'.format(parsed_args.path))
    models, args = checkpoint_utils.load_model_ensemble(
        parsed_args.path.split(os.pathsep),
        arg_overrides=eval(parsed_args.model_overrides),
        task=task,
        suffix=getattr(parsed_args, "checkpoint_suffix", ""),
    )

    for arg in vars(parsed_args).keys():
        if arg not in {
                'self_target',
                'future_target',
                'past_target',
                'tokens_per_sample',
                'output_size_dictionary',
                'add_bos_token',
        }:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    logger.info('{} {} {} examples'.format(args.data, args.gen_subset,
                                           len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.prepare_for_inference_(args)
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    assert len(models) > 0

    logger.info('num. model params: {}'.format(
        sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=('tqdm' if not args.no_progress_bar else 'none'),
    )

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch)

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            bpe_toks = {
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            }
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    wps_meter = TimeMeter()

    for sample in progress:
        if 'net_input' not in sample:
            continue

        sample = utils.move_to_cuda(sample) if use_cuda else sample

        gen_timer.start()
        hypos = scorer.generate(models, sample)
        gen_timer.stop(sample['ntokens'])

        for i, hypos_i in enumerate(hypos):
            hypo = hypos_i[0]
            sample_id = sample['id'][i]

            tokens = hypo['tokens']
            tgt_len = tokens.numel()
            pos_scores = hypo['positional_scores'].float()

            if getattr(args, 'add_bos_token', False):
                assert hypo['tokens'][0].item() == task.target_dictionary.bos()
                tokens = tokens[1:]
                pos_scores = pos_scores[1:]

            skipped_toks = 0
            if bpe_toks is not None:
                for i in range(tgt_len - 1):
                    if tokens[i].item() in bpe_toks:
                        skipped_toks += 1
                        pos_scores[i + 1] += pos_scores[i]
                        pos_scores[i] = 0

            inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(
                float('-inf'))
            if inf_scores.any():
                logger.info(
                    'skipping tokens with inf scores: %s',
                    task.target_dictionary.string(
                        tokens[inf_scores.nonzero()]))
                pos_scores = pos_scores[(~inf_scores).nonzero()]
            score_sum += pos_scores.sum().cpu()
            count += pos_scores.numel() - skipped_toks

            if args.output_word_probs or args.output_word_stats:
                w = ''
                word_prob = []
                is_bpe = False
                for i in range(len(tokens)):
                    w_ind = tokens[i].item()
                    w += task.source_dictionary[w_ind]
                    if bpe_toks is not None and w_ind in bpe_toks:
                        w = w[:-bpe_len]
                        is_bpe = True
                    else:
                        word_prob.append((w, pos_scores[i].item()))

                        next_prob = None
                        ind = i + 1
                        while ind < len(tokens):
                            if pos_scores[ind].item() != 0:
                                next_prob = pos_scores[ind]
                                break
                            ind += 1

                        word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                            pos_scores[i].item(), next_prob)
                        is_bpe = False
                        w = ''
                if args.output_word_probs:
                    logger.info(
                        str(int(sample_id)) + " " +
                        ('\t'.join('{} [{:.2f}]'.format(x[0], x[1])
                                   for x in word_prob)))

        wps_meter.update(sample['ntokens'])
        progress.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count / math.log(2)  # convert to base 2
    logger.info('Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
        gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    logger.info('Loss (base 2): {:.4f}, Perplexity: {:.2f}'.format(
        avg_nll_loss, 2**avg_nll_loss))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(),
                         key=lambda x: x.count,
                         reverse=True):
            logger.info(ws)
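
The inner loop above that folds BPE continuation scores into the following token can be read in isolation; here is a toy, fairseq-free sketch of the same merging, where token id 7 stands in for a subword that ends with the continuation marker (for example 'new@@').

def merge_bpe_scores(tokens, pos_scores, bpe_toks):
    # Same idea as the loop above: a continuation piece hands its log-prob to
    # the next piece and is then excluded from the token count.
    scores = list(pos_scores)
    skipped = 0
    for i in range(len(tokens) - 1):
        if tokens[i] in bpe_toks:
            skipped += 1
            scores[i + 1] += scores[i]
            scores[i] = 0.0
    return scores, skipped

tokens = [5, 7, 3, 9]  # 7 is the continuation piece
scores, skipped = merge_bpe_scores(tokens, [-1.0, -0.5, -0.25, -2.0], bpe_toks={7})
assert scores == [-1.0, 0.0, -0.75, -2.0] and skipped == 1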
Example #4
def main(parsed_args):
    assert parsed_args.path is not None, '--path required for evaluation!'

    import_user_module(parsed_args)

    print(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    models, args = utils.load_ensemble_for_inference(
        parsed_args.path.split(':'),
        task,
        model_arg_overrides=eval(parsed_args.model_overrides),
    )

    for arg in vars(parsed_args).keys():
        if arg not in {
                'self_target', 'future_target', 'past_target',
                'tokens_per_sample', 'output_size_dictionary'
        }:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    print('| {} {} {} examples'.format(args.data, args.gen_subset,
                                       len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    assert len(models) > 0

    print('num. model params: {}'.format(
        sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch)

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            bpe_toks = set(i for i in range(len(task.source_dictionary))
                           if task.source_dictionary[i].endswith(bpe_cont))
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()

        for sample in t:
            if 'net_input' not in sample:
                continue

            sample = utils.move_to_cuda(sample) if use_cuda else sample

            gen_timer.start()
            hypos = scorer.generate(models, sample)
            gen_timer.stop(sample['ntokens'])

            for hypos_i in hypos:
                hypo = hypos_i[0]

                tokens = hypo['tokens']
                tgt_len = tokens.numel()
                pos_scores = hypo['positional_scores'].float()

                skipped_toks = 0
                if bpe_toks is not None:
                    for i in range(tgt_len - 1):
                        if tokens[i].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[i + 1] += pos_scores[i]
                            pos_scores[i] = 0

                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(
                    float('-inf'))
                if inf_scores.any():
                    print(
                        '| Skipping tokens with inf scores:',
                        task.target_dictionary.string(
                            tokens[inf_scores.nonzero()]))
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += pos_scores.sum().cpu()
                count += pos_scores.numel() - skipped_toks

                if args.output_word_probs or args.output_word_stats:
                    w = ''
                    word_prob = []
                    is_bpe = False
                    for i in range(len(tokens)):
                        w_ind = tokens[i].item()
                        w += task.source_dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[i].item()))

                            next_prob = None
                            ind = i + 1
                            while ind < len(tokens):
                                if pos_scores[ind].item() != 0:
                                    next_prob = pos_scores[ind]
                                    break
                                ind += 1

                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                                pos_scores[i].item(), next_prob)
                            is_bpe = False
                            w = ''
                    if args.output_word_probs:
                        print('\t'.join('{} [{:.2f}]'.format(x[0], x[1])
                                        for x in word_prob))

            wps_meter.update(sample['ntokens'])
            t.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
        gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    if args.log2:
        avg_nll_loss /= 0.69314718056
        print(
            '| Loss: {:.4f}, Perplexity: {:.2f}, Cross-entropy: {:.4f}'.format(
                avg_nll_loss, np.power(2, avg_nll_loss),
                avg_nll_loss * (count - len(dataset)) / len(dataset)))
    else:
        print(
            '| Loss: {:.4f}, Perplexity: {:.2f}, Cross-entropy: {:.4f}'.format(
                avg_nll_loss, np.exp(avg_nll_loss),
                avg_nll_loss * (count - len(dataset)) / len(dataset)))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(),
                         key=lambda x: x.count,
                         reverse=True):
            print(ws)
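
WordStat is used but never defined in these snippets; a minimal stand-in with the interface the loops above appear to assume (a count used for sorting, and an add(log_prob, next_word_prob) method that tolerates next_word_prob being None) might look like the sketch below. This is an assumption about the interface, not necessarily fairseq's exact class.

class WordStat(object):
    """Minimal stand-in for the WordStat used above (assumed interface)."""

    def __init__(self, word, is_bpe):
        self.word = word
        self.is_bpe = is_bpe
        self.log_prob = 0.0
        self.next_word_prob = 0.0
        self.count = 0
        self.missing_next_words = 0

    def add(self, log_prob, next_word_prob):
        # Accumulate this word's score and, when available, the score of the
        # token that follows it (next_word_prob is None when nothing follows).
        if next_word_prob is not None:
            self.next_word_prob += next_word_prob
        else:
            self.missing_next_words += 1
        self.log_prob += log_prob
        self.count += 1

    def __str__(self):
        return '{}\tcount={}\tlog_prob={:.4f}\tnext_word_prob={:.4f}'.format(
            self.word, self.count, self.log_prob, self.next_word_prob)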
Example #5
def main(parsed_args):
    assert parsed_args.path is not None, '--path required for evaluation!'

    utils.import_user_module(parsed_args)

    logger.info(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    logger.info('loading model(s) from {}'.format(parsed_args.path))
    models, args = checkpoint_utils.load_model_ensemble(
        parsed_args.path.split(os.pathsep),
        arg_overrides=eval(parsed_args.model_overrides),
        task=task,
    )

    for arg in vars(parsed_args).keys():
        if arg not in {
                'self_target',
                'future_target',
                'past_target',
                'tokens_per_sample',
                'output_size_dictionary',
                'add_bos_token',
        }:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    logger.info('{} {} {} examples'.format(args.data, args.gen_subset,
                                           len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    assert len(models) > 0

    logger.info('num. model params: {}'.format(
        sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary,
                            args.softmax_batch,
                            args=args)

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            bpe_toks = {
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            }
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    if args.knnlm and args.save_knnlm_dstore:
        raise ValueError(
            "Cannot use knnlm while trying to build the datastore!")

    if args.knnlm:
        knn_dstore = KNN_Dstore(args)

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()

        if args.save_knnlm_dstore:
            print('keytype being saved:', args.knn_keytype)
            if args.dstore_fp16:
                print('Saving fp16')
                dstore_keys = np.memmap(args.dstore_mmap + '_keys.npy',
                                        dtype=np.float16,
                                        mode='w+',
                                        shape=(args.dstore_size,
                                               args.decoder_embed_dim))
                dstore_vals = np.memmap(args.dstore_mmap + '_vals.npy',
                                        dtype=np.int16,
                                        mode='w+',
                                        shape=(args.dstore_size, 1))
            else:
                print('Saving fp32')
                dstore_keys = np.memmap(args.dstore_mmap + '_keys.npy',
                                        dtype=np.float32,
                                        mode='w+',
                                        shape=(args.dstore_size,
                                               args.decoder_embed_dim))
                dstore_vals = np.memmap(args.dstore_mmap + '_vals.npy',
                                        dtype=np.int64,
                                        mode='w+',
                                        shape=(args.dstore_size, 1))

        dstore_idx = 0
        for ex_i, sample in enumerate(t):
            if 'net_input' not in sample:
                continue

            sample = utils.move_to_cuda(sample) if use_cuda else sample

            gen_timer.start()
            if args.knnlm:
                hypos = scorer.generate(models, sample, knn_dstore=knn_dstore)
            else:
                hypos = scorer.generate(models, sample)
            gen_timer.stop(sample['ntokens'])

            for i, hypos_i in enumerate(hypos):
                hypo = hypos_i[0]
                if args.save_knnlm_dstore:
                    shape = hypo['dstore_keys'].shape
                    if shape[0] == args.tokens_per_sample:
                        if dstore_idx + shape[0] > args.dstore_size:
                            shape = [args.dstore_size - dstore_idx]
                            hypo['dstore_keys'] = hypo[
                                'dstore_keys'][:shape[0]]
                        if args.dstore_fp16:
                            dstore_keys[dstore_idx:shape[0] +
                                        dstore_idx] = hypo['dstore_keys'].view(
                                            -1, args.decoder_embed_dim).cpu(
                                            ).numpy().astype(np.float16)
                            dstore_vals[dstore_idx:shape[0] +
                                        dstore_idx] = hypo['tokens'].view(
                                            -1,
                                            1).cpu().numpy().astype(np.int16)
                        else:
                            dstore_keys[dstore_idx:shape[0] +
                                        dstore_idx] = hypo['dstore_keys'].view(
                                            -1, args.decoder_embed_dim).cpu(
                                            ).numpy().astype(np.float32)
                            dstore_vals[dstore_idx:shape[0] +
                                        dstore_idx] = hypo['tokens'].view(
                                            -1, 1).cpu().numpy().astype(np.int64)

                        dstore_idx += shape[0]
                    else:
                        print('Skipping this one with shape', shape)

                sample_id = sample['id'][i]

                tokens = hypo['tokens']
                tgt_len = tokens.numel()
                pos_scores = hypo['positional_scores'].float()

                if args.add_bos_token:
                    assert hypo['tokens'][0].item(
                    ) == task.target_dictionary.bos()
                    tokens = tokens[1:]
                    pos_scores = pos_scores[1:]

                skipped_toks = 0
                if bpe_toks is not None:
                    for i in range(tgt_len - 1):
                        if tokens[i].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[i + 1] += pos_scores[i]
                            pos_scores[i] = 0

                #inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                #if inf_scores.any():
                #    logger.info(
                #        'skipping tokens with inf scores:',
                #        task.target_dictionary.string(tokens[inf_scores.nonzero()])
                #    )
                #    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += pos_scores.sum().cpu()
                count += pos_scores.numel() - skipped_toks

                if args.output_word_probs or args.output_word_stats:
                    w = ''
                    word_prob = []
                    is_bpe = False
                    for i in range(len(tokens)):
                        w_ind = tokens[i].item()
                        w += task.source_dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[i].item()))

                            next_prob = None
                            ind = i + 1
                            while ind < len(tokens):
                                if pos_scores[ind].item() != 0:
                                    next_prob = pos_scores[ind]
                                    break
                                ind += 1

                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                                pos_scores[i].item(), next_prob)
                            is_bpe = False
                            w = ''
                    if args.output_word_probs:
                        logger.info(
                            str(int(sample_id)) + " " +
                            ('\t'.join('{} [{:.2f}]'.format(x[0], x[1])
                                       for x in word_prob)))

            wps_meter.update(sample['ntokens'])
            t.log({'wps': round(wps_meter.avg)})

    if args.save_knnlm_dstore:
        print("dstore_idx", dstore_idx, "final shape", shape)
        print("Keys", dstore_keys.shape, dstore_keys.dtype)
        print("Vals", dstore_vals.shape, dstore_vals.dtype)

    avg_nll_loss = -score_sum / count / math.log(2)  # convert to base 2
    logger.info('Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
        gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    logger.info('Loss (base 2): {:.4f}, Perplexity: {:.2f}'.format(
        avg_nll_loss, 2**avg_nll_loss))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(),
                         key=lambda x: x.count,
                         reverse=True):
            logger.info(ws)
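
Once save_knnlm_dstore has filled the two memory-mapped arrays above, they can be reopened read-only with the same dtype and shape. A hedged sketch, assuming the non-fp16 branch (float32 keys, 64-bit integer values); the path, size, and key dimension below are hypothetical placeholders and must match whatever was used when the datastore was written.

import numpy as np

dstore_mmap = 'checkpoints/dstore'   # hypothetical args.dstore_mmap prefix
dstore_size = 103_225_485            # hypothetical args.dstore_size
decoder_embed_dim = 1024             # hypothetical args.decoder_embed_dim

dstore_keys = np.memmap(dstore_mmap + '_keys.npy', dtype=np.float32, mode='r',
                        shape=(dstore_size, decoder_embed_dim))
dstore_vals = np.memmap(dstore_mmap + '_vals.npy', dtype=np.int64, mode='r',
                        shape=(dstore_size, 1))

print('Keys', dstore_keys.shape, dstore_keys.dtype)
print('Vals', dstore_vals.shape, dstore_vals.dtype)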
Example #6
def test(parsed_args):
    # Make sure we didn't screw up the params
    assert parsed_args.path is not None, '--path required for evaluation!'
    assert parsed_args.sample_break_mode == 'eos', 'Sample break mode must be eos!'

    # Print the args
    import_user_module(parsed_args)
    print(parsed_args)

    # Do we use CUDA
    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    # Get the task (Language Modeling)
    task = tasks.setup_task(parsed_args)

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    models, args = utils.load_ensemble_for_inference(
        parsed_args.path.split(':'),
        task,
        model_arg_overrides=eval(parsed_args.model_overrides),
    )

    for arg in vars(parsed_args).keys():
        if arg not in {
                'self_target', 'future_target', 'past_target',
                'tokens_per_sample', 'output_size_dictionary'
        }:
            setattr(args, arg, getattr(parsed_args, arg))

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    print('| {} {} {} examples'.format(args.data, args.gen_subset,
                                       len(dataset)))

    # Optimize model for generation
    assert len(models) > 0
    model = models[0]
    model.make_generation_fast_()
    if args.fp16:
        model.half()
    if use_cuda:
        model.cuda()

    # Make data iterator
    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Iterate over batches of sentences
    # Get the sentence logps for the batch
    all_s_logps = []
    all_n_tokens = 0
    all_n_sentences = 0
    for sample in itr:
        if 'net_input' not in sample:
            continue

        # Move sample to GPU if possible
        sample = utils.move_to_cuda(sample) if use_cuda else sample

        # Number of sentences in this batch
        bsz = sample['nsentences']
        all_n_sentences += bsz

        # Get the softmax outputs for the batch
        # The resultant tensor has shape: BATCH_SZ x N_TOKENS x VOCAB_SZ
        probs = []
        net_input = sample['net_input']
        with torch.no_grad():
            model.eval()
            decoder_out = model.forward(**net_input)
            probs = model.get_normalized_probs(decoder_out,
                                               log_probs=True,
                                               sample=sample).data

        # Make sure we have a softmax-sequence for each sentence in the batch
        assert len(probs) == bsz

        # Assert that the softmax output is correct
        assert torch.allclose(torch.sum(torch.exp(probs), dim=2),
                              torch.ones(probs.shape[:2], device=probs.device))

        # Get the token logps for each sentence from the softmax outputs
        target = sample['target']
        logps = probs.gather(
            dim=2,
            index=target.unsqueeze(-1),
        ).squeeze(2)

        # Iterate over each sentence in the batch
        # Get the sum of logps for each sentence
        start_idxs = [0] * bsz
        for i in range(bsz):

            # Get the token indices / strings
            tokens = utils.strip_pad(target[i, start_idxs[i]:],
                                     task.source_dictionary.pad())
            token_idxs = [tokens[i].item() for i in range(len(tokens))]
            token_strings = [task.source_dictionary[idx] for idx in token_idxs]

            # Maintain total number of tokens
            all_n_tokens += len(tokens)

            # This is the original sentence
            sentence = ' '.join(token_strings)

            # Get the token logps for this sentence
            s_len = len(tokens)
            s_logps = logps[i][:s_len]
            all_s_logps.append(torch.sum(s_logps))

    # Get the average sentence logp over all sentences in the test set
    avg_s_logp = sum(all_s_logps) / all_n_sentences
    print(-1 * avg_s_logp.item())
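
The gather/squeeze step above is the heart of this script: it picks out, at every position, the log-probability the model assigned to the reference token. A self-contained toy version of that indexing (shapes and ids are made up):

import torch

torch.manual_seed(0)
# Toy shapes: batch of 2 sentences, 3 positions, vocabulary of 5 tokens.
log_probs = torch.log_softmax(torch.randn(2, 3, 5), dim=-1)  # B x T x V
target = torch.tensor([[1, 4, 0], [2, 2, 3]])                # B x T reference ids

# Same pattern as probs.gather(dim=2, index=target.unsqueeze(-1)).squeeze(2).
token_logps = log_probs.gather(dim=2, index=target.unsqueeze(-1)).squeeze(2)
assert token_logps.shape == (2, 3)

# Sentence log-probability is the sum of its token log-probabilities.
sentence_logps = token_logps.sum(dim=1)
print(sentence_logps)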
Example #7
def main(parsed_args):
    if parsed_args.dstore_mmap is not None:
        d = os.path.dirname(parsed_args.dstore_mmap)
        print('mmap from {}'.format(d))
        if not os.path.exists(d):
            print('making dir')
            os.system('mkdir -p {}'.format(d))

    utils.import_user_module(parsed_args)

    logger.info(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load model.
    hf_tokenizer = AutoTokenizer.from_pretrained(parsed_args.hf_model)
    if parsed_args.hf_enc_mode == 'masked':
        hf_model = AutoModelForMaskedLM.from_pretrained(parsed_args.hf_model)
    elif parsed_args.hf_enc_mode == 'causal':
        hf_model = AutoModelForCausalLM.from_pretrained(parsed_args.hf_model)
    else:
        raise ValueError('unexpected hf_enc_mode: {}'.format(parsed_args.hf_enc_mode))

    if use_cuda:
        hf_model.cuda()

    device = next(hf_model.parameters()).device

    check_input_ids = hf_tokenizer('hello world')['input_ids']
    add_cls_token = check_input_ids[0] == hf_tokenizer.cls_token_id
    add_sep_token = check_input_ids[-1] == hf_tokenizer.sep_token_id
    print('add_cls_token = {} {} {}'.format(add_cls_token, hf_tokenizer.cls_token, hf_tokenizer.cls_token_id))
    print('add_sep_token = {} {} {}'.format(add_sep_token, hf_tokenizer.sep_token, hf_tokenizer.sep_token_id))

    args = copy.deepcopy(parsed_args)

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    task_dataset = task.dataset(args.gen_subset)
    assert args.context_window > 0
    dataset = LMContextWindowDataset(
        dataset=task_dataset,
        tokens_per_sample=args.tokens_per_sample,
        context_window=args.context_window,
        pad_idx=task.source_dictionary.pad(),
    )
    logger.info('{} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))

    model_max_length = min(hf_tokenizer.model_max_length, parsed_args.hf_max_position)

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(*[
            model_max_length
        ]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)
    #).next_epoch_itr(shuffle=True)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch, args=args)

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            bpe_toks = {
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            }
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    if args.knnlm and args.save_knnlm_dstore:
        raise ValueError("Cannot use knnlm while trying to build the datastore!")

    if args.knnlm:
        knn_dstore = KNN_Dstore(args)

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()

        if args.save_knnlm_dstore:
            print('keytype being saved:', args.knn_keytype)
            dstore_keys = np.memmap(args.dstore_mmap+'_keys.npy', dtype=np.float32, mode='w+', shape=(args.dstore_size, hf_model.config.d_model))
            dstore_vals = np.memmap(args.dstore_mmap+'_vals.npy', dtype=np.int64, mode='w+', shape=(args.dstore_size, 1))

        if args.save_extra:
            writer = Writer(outdir='demo-out', max_size=args.save_extra_max_size, k=args.k, vec_size=1024)

        def pad(x, pad_id=-1):
            max_len = max([len(xx) for xx in x])
            x = [xx + [pad_id] * (max_len - len(xx)) for xx in x]
            return x

        def batchify(batch):
            new_batch = {}
            new_batch['input_ids'] = torch.tensor(pad(batch['src_tokens'], hf_tokenizer.pad_token_id), dtype=torch.long, device=device)
            new_batch['context_mask'] = torch.tensor(pad(batch['mask'], -1), dtype=torch.long, device=device)
            new_batch['word_id'] = torch.tensor(pad(batch['word_id'], -1), dtype=torch.long, device=device)
            new_batch['target'] = torch.tensor(pad(batch['target'], -1), dtype=torch.long, device=device)
            return new_batch

        dstore_idx = 0
        dstore_full = False
        num_tokens = 0
        for ex_i, sample in tqdm(enumerate(t), desc='encode'):
            if 'net_input' not in sample:
                continue

            all_tokens = torch.cat([sample['net_input']['src_tokens'], sample['target'][:, -1, None]], -1)

            hf_batch = collections.defaultdict(list)
            for tok in all_tokens.tolist():
                tok = [tt for tt in tok if tt != dataset.pad_idx]
                raw_text = [task_dataset.vocab[tt] for tt in tok]
                hf_src_tokens, hf_target, hf_raw_target, hf_raw_text, hf_word_id, hf_mask = [], [], [], [], [], []
                for i_w in range(len(raw_text) - 1):

                    w = raw_text[i_w]
                    tok_ = hf_tokenizer.encode(w, add_special_tokens=False)
                    if i_w == 0 and add_cls_token:
                        if tok_[0] != hf_tokenizer.cls_token_id:
                            tok_ = [hf_tokenizer.cls_token_id] + tok_

                    if len(hf_src_tokens) + len(tok_) > model_max_length:
                        break

                    hf_src_tokens += tok_
                    hf_raw_text += hf_tokenizer.convert_ids_to_tokens(tok_)
                    hf_word_id += [i_w] * len(tok_)
                    hf_mask += [0] * (len(tok_) - 1) + [1]
                    hf_target += [tok[i_w + 1]] * len(tok_)
                    hf_raw_target += [raw_text[i_w + 1]]

                assert len(hf_src_tokens) == len(hf_target)
                assert len(hf_src_tokens) == len(hf_word_id)
                assert len(hf_src_tokens) == len(hf_mask)

                hf_batch['src_tokens'].append(hf_src_tokens)
                hf_batch['target'].append(hf_target) # This is indexed by KNN-LM tokenizer.
                hf_batch['raw_target'].append(hf_raw_target)
                hf_batch['word_id'].append(hf_word_id)
                hf_batch['mask'].append(hf_mask)

                num_tokens += len(hf_src_tokens)

            hf_batch_ = batchify(hf_batch)

            model_output = hf_model(hf_batch_['input_ids'], output_hidden_states=True)

            h = model_output['hidden_states'][-1]

            assert h.shape[:2] == hf_batch_['input_ids'].shape[:2]

            if args.save_knnlm_dstore and not dstore_full:

                flat_h = h.view(-1, hf_model.config.d_model)
                mask_ = hf_batch_['context_mask'].view(-1) == 1
                keys_ = flat_h[mask_]
                vals_ = hf_batch_['target'].view(-1, 1)[mask_]

                shape = keys_.shape
                if dstore_idx + shape[0] > args.dstore_size:
                    shape = [args.dstore_size - dstore_idx]
                    dstore_full = True

                keys_ = keys_[:shape[0]]
                vals_ = vals_[:shape[0]]

                assert keys_.shape[0] == vals_.shape[0]

                dstore_keys[dstore_idx:shape[0]+dstore_idx] = keys_.cpu().numpy().astype(np.float32)
                dstore_vals[dstore_idx:shape[0]+dstore_idx] = vals_.cpu().numpy().astype(np.int64)

                dstore_idx += shape[0]
            if dstore_full:
                print('Datastore is full with {} items.'.format(args.dstore_size))

            wps_meter.update(sample['ntokens'])
            t.log({'wps': round(wps_meter.avg)})

            # Write saved values to disk.
            if args.save_extra:
                writer.update(extra)

    if args.save_knnlm_dstore:
        print("dstore_idx", dstore_idx, "final shape", shape)
        print("Keys", dstore_keys.shape, dstore_keys.dtype)
        print("Vals", dstore_vals.shape, dstore_vals.dtype)

    logger.info('done with {} tokens'.format(num_tokens))
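
A toy, tokenizer-free illustration of the pad/batchify helpers defined above: ragged per-sentence token lists are right-padded to a rectangle before being turned into tensors. The pad id 0 and the token ids are arbitrary.

import torch

def pad(x, pad_id=-1):
    # Same logic as the helper above: right-pad each list to the batch maximum.
    max_len = max(len(xx) for xx in x)
    return [xx + [pad_id] * (max_len - len(xx)) for xx in x]

src_tokens = [[11, 12, 13], [21, 22], [31]]
input_ids = torch.tensor(pad(src_tokens, pad_id=0), dtype=torch.long)
mask = torch.tensor(pad([[1] * len(s) for s in src_tokens], pad_id=0), dtype=torch.long)

assert input_ids.shape == (3, 3)
assert mask.tolist() == [[1, 1, 1], [1, 1, 0], [1, 0, 0]]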