class FairseqAgent(TorchAgent):
    """Generic wrapper around fairseq for use in ParlAI"""

    metrics = {}

    @classmethod
    def add_cmdline_args(cls, argparser):
        """Add command-line arguments specifically for this agent."""
        # first we need to add the general torch agent operations
        super(FairseqAgent, cls).add_cmdline_args(argparser)

        # let's store any defaults that were overridden
        old_defaults = argparser._defaults
        if 'clip_norm' not in old_defaults:
            # fairseq has a few awful defaults
            old_defaults['clip_norm'] = 1.0
        if 'optimizer' not in old_defaults:
            old_defaults['optimizer'] = 'adam'
            old_defaults['adam_betas'] = '(0.9,0.98)'

        agent = argparser.add_argument_group('Fairseq Arguments')
        agent.add_argument(
            '--fp16',
            default=False,
            type='bool',
            help='Use fp16 training'
        )
        agent.add_argument(
            '--fp16-init-scale',
            default=2**7,
            type=int,
            help='default FP16 loss scale'
        )
        agent.add_argument(
            '--seed',
            default=1,
            type=int,
            metavar='N',
            help='pseudo random number generator seed'
        )
        agent.add_argument(
            '--skip-generation',
            default=False,
            type='bool',
            metavar='BOOL',
            help='Skips test time beam search. Much faster if you only need PPL',
        )
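        # A hypothetical invocation exercising these flags might look like
        # (script, model, and task names here are illustrative only):
        #   python examples/train_model.py -m fairseq -t convai2 \
        #       --arch transformer --skip-generation true --fp16 true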

        # Check subargs for generation, optimizers, criterions, archs, etc
        options.add_generation_args(argparser)
        options.add_optimization_args(argparser)
        options.add_checkpoint_args(argparser)

        # restore any user set defaults that fairseq possibly overrode
        argparser.set_defaults(**old_defaults)
        known_args = argparser.parse_known_args(nohelp=True)[0]

        if hasattr(known_args, "optimizer"):
            optimizer = known_args.optimizer
            opt_group = argparser.add_argument_group(
                '{} optimizer arguments'.format(optimizer)
            )
            optim.OPTIMIZER_REGISTRY[optimizer].add_args(opt_group)
        if hasattr(known_args, "lr_scheduler"):
            lr_scheduler = known_args.lr_scheduler
            lr_group = argparser.add_argument_group(
                '{} scheduler arguments'.format(lr_scheduler)
            )
            optim.lr_scheduler.LR_SCHEDULER_REGISTRY[lr_scheduler].add_args(lr_group)
        # We need to find out the fairseq model-specific options, so grab the
        # architecture stuff and look up its options
        arch_group = options.add_model_args(argparser)
        # Fairseq marks the arch flag as required, but it may be specified
        # by a saved model cache, so we do some weird stuff to undo that
        for a in arch_group._actions:
            if a.dest == "arch":
                a.required = False
                a.default = None
                break

        # once again restore any user-set defaults
        argparser.set_defaults(**old_defaults)
        known_args = argparser.parse_known_args(nohelp=True)[0]

        if hasattr(known_args, "arch") and known_args.arch is not None:
            arch = known_args.arch
            arch_group = argparser.add_argument_group(
                "{} architecture arguments".format(arch)
            )
            models.ARCH_MODEL_REGISTRY[arch].add_args(arch_group)

        if hasattr(known_args, "criterion"):
            crit_group = argparser.add_argument_group(
                '{} criterion arguments'.format(known_args.criterion)
            )
            criterions.CRITERION_REGISTRY[known_args.criterion].add_args(crit_group)

        # one last time, restore any user set defaults
        argparser.set_defaults(**old_defaults)

    @staticmethod
    def dictionary_class():
        # Force use of the Fairseq Dictionary
        return _FairseqDictionary

    def __init__(self, opt, shared=None):
        # In general use a basic TorchAgent wherever possible
        super().__init__(opt, shared)
        if not shared:
            # this is not a shared instance of this class, so do full initialization

            # check early if we're going to be loading the model from a checkpoint
            model_file_exists = (
                self.opt.get('model_file') and os.path.isfile(self.opt['model_file'])
            )

            # fairseq expects options to be in argparse format, instead of a dict
            # We also need to do some argument postprocessing and whatnot
            # We'll skip pretrained embeddings if we're going to override them with
            # a model checkpoint anyway
            self.args, self.opt = _fairseq_opt_wrapper(opt, model_file_exists)

            # seed the RNG
            torch.manual_seed(self.args.seed)

            # Just some identifying info
            self.id = "fairseq:{}".format(self.args.arch)

            # We need a placeholder task for fairseq
            self.task = _ParlaiTask(self.dict)

            # meters for keeping track of loss, ppl, etc.
            self.meters = defaultdict(AverageMeter)

            # actually construct the model and generator
            self.model = self.build_model()

            # Construct the generator and scorer
            self.generator = SequenceGenerator(
                [self.model],
                tgt_dict=self.dict,
                beam_size=self.args.beam,
                stop_early=(not self.args.no_early_stop),
                normalize_scores=(not self.args.unnormalized),
                len_penalty=self.args.lenpen,
                unk_penalty=self.args.unkpen,
                sampling=self.args.sampling,
                sampling_topk=self.args.sampling_topk,
                sampling_temperature=self.args.sampling_temperature,
            )
            self.scorer = SequenceScorer([self.model], self.dict)

            # set up the grader and the trainer
            self.criterion = criterions.build_criterion(self.args, self.task)

            # TODO: we might choose to add a --no-fp16 opt in the future to
            # explicitly disable fp16 instead
            # only query CUDA device capability when a GPU is in use
            if (self.use_cuda and not self.args.fp16
                    and torch.cuda.get_device_capability(0)[0] >= 7):
                print("Heads up: using --fp16 could be a lot faster!")
            if self.use_cuda:
                self.trainer = trainer.Trainer(
                    self.args, self.task, self.model, self.criterion, None,
                )
                self.trainer._build_optimizer()
            else:
                self.trainer = None

            # if the model already existed, let's preload it and the trainer
            if model_file_exists:
                print('Loading existing model params from ' + self.opt['model_file'])
                self.load(self.opt.get('model_file'))

            # move things to the GPU if possible
            if self.use_cuda:
                self.model = self.model.cuda()
                self.generator = self.generator.cuda()
        else:
            self.model = shared['model']
            self.trainer = shared['trainer']
            self.generator = shared['generator']
            self.dict = shared['dict']
            self.args = shared['args']
            self.meters = shared['meters']

        # Start things off clean
        self.reset()

    def _check_opts_unchanged(self, saved_opts, current_opts):
        """Verify that critical options do not differ in command line vs saved model"""
        for k in NON_OVERRIDABLE_ARGS:
            if k not in saved_opts or k not in current_opts:
                # if it's not an option needed by this fairseq model, don't stress
                continue
            if saved_opts[k] != current_opts[k]:
                raise ValueError(
                    '{} cannot be overridden when --model-file is specified'.format(k)
                )

    def build_model(self):
        """
        Construct the actual Fairseq model. Default implementation is to use
        Fairseq's arch builder, but this method may be overridden to build custom
        models.
        """
        model_class = models.ARCH_MODEL_REGISTRY[self.args.arch]
        model = model_class.build_model(self.args, self.task)
        if self.args.embedding_type != 'random':
            self._copy_embeddings(
                model.encoder.embed_tokens.weight, self.args.embedding_type
            )
        return model

    def share(self):
        shared = super().share()
        shared['model'] = self.model
        shared['trainer'] = self.trainer
        shared['generator'] = self.generator
        shared['dict'] = self.dict
        shared['args'] = self.args
        shared['meters'] = self.meters
        return shared

    def save(self, path):
        """Save using fairseq's checkpointing."""
        if not path:
            return
        self.trainer.save_checkpoint(path, {'opt': self.opt, 'epoch': 0})
        # Parlai expects options to also be saved
        with open(path + '.opt', 'w') as handle:
            # overridden options shouldn't be stored, only the main ones
            if 'override' in self.opt:
                del self.opt['override']
            json.dump(self.opt, handle)

        # force save the dict
        self.dict.save(path + '.dict', sort=False)

    def load(self, path):
        """Load using fairseq's checkpointing."""
        if self.trainer:
            old_options = self.trainer.load_checkpoint(path, self.args.reset_optimizer)
            self._check_opts_unchanged(old_options, self.opt)
        else:
            load_model_state(path, self.model)

    def shutdown(self):
        if not hasattr(self, 'trainer'):
            # looks like this is a "fake" model that isn't actually used for batch_act.
            # we don't need to save this one.
            return
        super().shutdown()

    def reset(self):
        """Reset observation and episode_done."""
        super().reset()
        self.reset_metrics()

    def is_valid(self, obs):
        """Override from TorchAgent.
        Check if an observation has no tokens in it."""
        return len(obs.get('text_vec', [])) > 0

    def batchify(self, obs_batch):
        """
        Override parent batchify to set requirements for fairseq.

        Fairseq depends on sorted batch inputs for a call to rnn.pad_packed_sequence.
        Fairseq models cannot handle zero length sentences
        """
        return super().batchify(obs_batch, sort=True)

    def _update_metrics(self, metrics, sample):
        if metrics is None:
            # probably got an overflow in fp16 mode. don't count this sample
            return

        bsz = len(sample['target'])
        ntok = sample['ntokens']
        ssize = metrics['sample_size']

        for k, v in metrics.items():
            if k in {'ntokens', 'nsentences', 'sample_size'}:
                # don't need these
                continue
            elif k == "nll_loss":
                # nll loss is always normalized by ntokens
                self.meters[k].update(v, ntok)
            elif k == "loss":
                # loss is explicitly normalized by passed up sample size
                self.meters[k].update(v, ssize)
            else:
                # assume everything else is averaged over bsz
                self.meters[k].update(v, bsz)

    def train_step(self, batch):
        """Process batch of inputs and targets and train on them.

        :param batch: parlai.core.torch_agent.Batch, contains tensorized
                      version of observations.
        """
        if batch.text_vec is None:
            return
        self.is_training = True
        sample = self._make_sample(batch)
        self.model.train()
        metrics = self.trainer.train_step([sample])
        self._update_metrics(metrics, sample)

    def eval_step(self, batch):
        """Process batch of inputs.

        If the batch includes labels, calculate validation metrics as well.
        If --skip-generation is not set, return a prediction for each input.

        :param batch: parlai.core.torch_agent.Batch, contains tensorized
                      version of observations.
        """
        if batch.text_vec is None:
            return
        self.is_training = False
        samples = self._make_sample(batch)
        self.model.eval()
        if batch.label_vec is not None and self.trainer is not None:
            # Interactive mode won't have a gold label
            metrics = self.trainer.valid_step(samples)
            self._update_metrics(metrics, samples)

        # Output placeholders
        reranked_cands = None
        generated_output = None

        # Grade each of the candidate sequences
        if batch.candidate_vecs is not None:
            bsz = len(batch.text_vec)
            reranked_cands = []
            # score the candidates for each item in the batch separately, so that
            # we can support variable number of candidates
            for i in range(bsz):
                cands = batch.candidate_vecs[i]
                if not cands:
                    reranked_cands.append(None)
                    continue
                ncand = len(cands)
                # repeat the input many times
                xs = batch.text_vec[i].unsqueeze(0).expand(ncand, -1)
                # some models crash if there's leading padding on every example
                xs = xs[:, :batch.text_lengths[i]]
                # and appropriately pack the outputs
                ys, _ = padded_tensor(cands, self.NULL_IDX, self.use_cuda)
                sample = self._make_sample(xs=xs, ys=ys)
                # perform the actual grading, extract the scores
                scored = list(
                    self.scorer.score_batched_itr([sample], cuda=self.use_cuda)
                )
                scores = [out[3][0]['score'].item() for out in scored]
                # intentional hanging comma here; argsort returns a list
                ranked, = argsort(scores, batch.candidates[i], descending=True)
                reranked_cands.append(ranked)

        # Next generate freely to create our response
        if not self.args.skip_generation:
            generated_output = self._generate(samples)
        elif reranked_cands:
            # we're skipping generation, but we're also grading candidates
            # so output the highest ranked candidate
            # In the case of zero candidates, we don't have something to rank,
            # so we may need to pass on that None
            generated_output = [
                ranked and ranked[0] or None for ranked in reranked_cands
            ]
        else:
            # no output at all
            pass

        return Output(generated_output, reranked_cands)

    def _generate(self, samples):
        no_prev_token = {
            k: v for k, v in samples['net_input'].items() if k != 'prev_output_tokens'
        }
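        # the generator consumes only encoder-side inputs, so strip the gold
        # prev_output_tokens instead of letting them leak into generation;
        # maxlen=64 caps the length of generated responses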
        gens = self.generator.generate(no_prev_token, maxlen=64)
        bsz = samples['net_input']['src_tokens'].size(0)
        responses = []
        for i in range(bsz):
            beams = gens[i]
            selected = max(beams, key=lambda x: x["score"])
            tokens = selected["tokens"]
            start = 0
            end = -1
            for idx, t in enumerate(tokens):
                t = t.item()
                if t == self.dict.bos_index:
                    # don't include <s> token
                    start = idx + 1
                    continue
                if t == self.dict.eos_index:
                    # stop (and don't include) </s> token
                    end = idx
                    break
            responses.append(self.dict.vec2txt(tokens[start:end]))
        return responses

    def report(self):
        """Return metrics calculated by the model."""
        # if we haven't initialized yet, just return an empty report
        if not hasattr(self, "trainer"):
            return {}

        output = {k: v.avg for k, v in self.meters.items()}

        if "nll_loss" in self.meters:
            # special case, we used sentence averaging so ppl comes from nll_loss
            output["ppl"] = np.exp2(self.meters["nll_loss"].avg)
        else:
            # normal case, just use loss
            output["ppl"] = np.exp2(self.meters["loss"].avg)

        # Fairseq trainer metrics we also pass along in the report
        trainer_metrics = {"ups", "wps", "gnorm", "clip"}
        if self.is_training:
            for k in trainer_metrics:
                output[k] = self.trainer.meters[k].avg

        # for display purposes
        output = {k: round_sigfigs(v, 4) for k, v in output.items()}
        return output

    def reset_metrics(self):
        """Reset metrics calculated by the model back to zero."""
        if not hasattr(self, "trainer"):
            # We haven't set up the trainer yet, so we don't have any metrics
            return
        # We need to reset everything
        self.meters.clear()
        if self.trainer:
            for k in self.trainer.meters:
                self.trainer.meters[k].reset()

    def receive_metrics(self, metrics_dict):
        """Update lr scheduler with validation loss."""
        # TODO: this should be smarter
        self.trainer.lr_step(-1, metrics_dict["loss"])

    # Helper functions
    def _seq_length(self, xs):
        """Compute length of the sequence (non-padded size)."""
        return xs.ne(self.dict.pad_index).long().sum(dim=-1)

    def _right_shifted_ys(self, ys):
        """Replace first token with EOS and shift remaining tokens right 1."""
        result = torch.LongTensor(ys.size())
        result[:, 0] = self.dict.eos_index
        result[:, 1:] = ys[:, :-1]
        return result

    def _make_sample(self, batch=None, xs=None, ys=None):
        """Generate a sample object that Fairseq expects."""
        # add extra info to samples
        if batch is None and xs is None:
            raise ValueError("Must supply either batch or xs")
        if batch is None and ys is None:
            raise ValueError("Must supply either batch or ys")
        if xs is None:
            xs = batch.text_vec
        if ys is None:
            ys = batch.label_vec
        repadded = convert_padding_direction(xs, self.dict.pad(), right_to_left=True)
        sample = {}
        sample["id"] = torch.arange(len(xs) - 1)
        sample["net_input"] = {
            "src_tokens": repadded,
            "src_lengths": self._seq_length(xs),
        }
        if ys is not None:
            sample["target"] = ys
            sample["ntokens"] = sum(self._seq_length(ys)).item()
            sample["net_input"]["prev_output_tokens"] = self._right_shifted_ys(ys)
        return sample
Example #3
def main(parsed_args):
    assert parsed_args.path is not None, '--path required for evaluation!'

    print(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    models, args = utils.load_ensemble_for_inference(
        parsed_args.path.split(':'), task)

    for arg in vars(parsed_args).keys():
        if arg not in {
                'self_target', 'future_target', 'past_target',
                'tokens_per_sample', 'output_size_dictionary'
        }:
            setattr(args, arg, getattr(parsed_args, arg))
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset,
                                       len(task.dataset(args.gen_subset))))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()

    assert len(models) > 0

    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]),
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        ignore_invalid_inputs=True,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(models, task.target_dictionary)
    if use_cuda:
        scorer.cuda()

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        bpe_cont = args.remove_bpe.rstrip()
        bpe_toks = set(i for i in range(len(task.dictionary))
                       if task.dictionary[i].endswith(bpe_cont))
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0
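    # e.g. passing --remove-bpe '@@ ' makes bpe_cont '@@', so every
    # dictionary entry ending in '@@' is treated as a word-internal
    # piece when merging subword scores below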

    word_stats = dict()

    with progress_bar.build_progress_bar(args, itr) as t:
        results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        wps_meter = TimeMeter()
        for _, src_tokens, __, hypos in results:
            for hypo in hypos:
                pos_scores = hypo['positional_scores']

                skipped_toks = 0
                if bpe_toks is not None:
                    for i in range(len(hypo['tokens']) - 1):
                        if hypo['tokens'][i].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[i + 1] += pos_scores[i]
                            pos_scores[i] = 0

                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(
                    float('-inf'))
                if inf_scores.any():
                    print(
                        '| Skipping tokens with inf scores:',
                        task.target_dictionary.string(
                            hypo['tokens'][inf_scores.nonzero()]))
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += utils.item(pos_scores.sum())
                count += pos_scores.numel() - skipped_toks

                if args.output_word_probs or args.output_word_stats:
                    w = ''
                    word_prob = []
                    is_bpe = False
                    for i in range(len(hypo['tokens'])):
                        w_ind = hypo['tokens'][i].item()
                        w += task.dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[i].item()))
                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                                pos_scores[i].item())
                            is_bpe = False
                            w = ''
                    if args.output_word_probs:
                        print('\t'.join('{} [{:.2f}]'.format(x[0], x[1])
                                        for x in word_prob))

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
        gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss,
                                                      np.exp(avg_nll_loss)))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(),
                         key=lambda x: x.count,
                         reverse=True):
            print(ws)
Example #4
def main(parsed_args):
    if parsed_args.dstore_mmap is not None:
        d = os.path.dirname(parsed_args.dstore_mmap)
        print('mmap from {}'.format(d))
        if not os.path.exists(d):
            print('making dir')
            os.makedirs(d, exist_ok=True)

    utils.import_user_module(parsed_args)

    logger.info(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load model.
    hf_tokenizer = AutoTokenizer.from_pretrained(parsed_args.hf_model)
    if parsed_args.hf_enc_mode == 'masked':
        hf_model = AutoModelForMaskedLM.from_pretrained(parsed_args.hf_model)
    elif parsed_args.hf_enc_mode == 'causal':
        hf_model = AutoModelForCausalLM.from_pretrained(parsed_args.hf_model)

    if use_cuda:
        hf_model.cuda()

    device = next(hf_model.parameters()).device

    check_input_ids = hf_tokenizer('hello world')['input_ids']
    add_cls_token = check_input_ids[0] == hf_tokenizer.cls_token_id
    add_sep_token = check_input_ids[-1] == hf_tokenizer.sep_token_id
    print('add_cls_token = {} {} {}'.format(add_cls_token, hf_tokenizer.cls_token, hf_tokenizer.cls_token_id))
    print('add_sep_token = {} {} {}'.format(add_sep_token, hf_tokenizer.sep_token, hf_tokenizer.sep_token_id))

    args = copy.deepcopy(parsed_args)

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    task_dataset = task.dataset(args.gen_subset)
    assert args.context_window > 0
    dataset = LMContextWindowDataset(
        dataset=task_dataset,
        tokens_per_sample=args.tokens_per_sample,
        context_window=args.context_window,
        pad_idx=task.source_dictionary.pad(),
    )
    logger.info('{} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))

    model_max_length = min(hf_tokenizer.model_max_length, parsed_args.hf_max_position)

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(*[
            model_max_length
        ]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)
    #).next_epoch_itr(shuffle=True)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch, args=args)

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            bpe_toks = {
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            }
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    if args.knnlm and args.save_knnlm_dstore:
        raise ValueError("Cannot use knnlm while trying to build the datastore!")

    if args.knnlm:
        knn_dstore = KNN_Dstore(args)

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()

        if args.save_knnlm_dstore:
            print('keytype being saved:', args.knn_keytype)
            dstore_keys = np.memmap(args.dstore_mmap+'_keys.npy', dtype=np.float32, mode='w+', shape=(args.dstore_size, hf_model.config.d_model))
            # np.int was removed from NumPy; int64 matches the old default
            dstore_vals = np.memmap(args.dstore_mmap+'_vals.npy', dtype=np.int64, mode='w+', shape=(args.dstore_size, 1))

        if args.save_extra:
            writer = Writer(outdir='demo-out', max_size=args.save_extra_max_size, k=args.k, vec_size=1024)

        def pad(x, pad_id=-1):
            max_len = max([len(xx) for xx in x])
            x = [xx + [pad_id] * (max_len - len(xx)) for xx in x]
            return x
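        # e.g. pad([[1, 2, 3], [4]]) -> [[1, 2, 3], [4, -1, -1]]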

        def batchify(batch):
            new_batch = {}
            new_batch['input_ids'] = torch.tensor(pad(batch['src_tokens'], hf_tokenizer.pad_token_id), dtype=torch.long, device=device)
            new_batch['context_mask'] = torch.tensor(pad(batch['mask'], -1), dtype=torch.long, device=device)
            new_batch['word_id'] = torch.tensor(pad(batch['word_id'], -1), dtype=torch.long, device=device)
            new_batch['target'] = torch.tensor(pad(batch['target'], -1), dtype=torch.long, device=device)
            return new_batch

        dstore_idx = 0
        dstore_full = False
        num_tokens = 0
        for ex_i, sample in tqdm(enumerate(t), desc='encode'):
            if 'net_input' not in sample:
                continue

            all_tokens = torch.cat([sample['net_input']['src_tokens'], sample['target'][:, -1, None]], -1)

            hf_batch = collections.defaultdict(list)
            for tok in all_tokens.tolist():
                tok = [tt for tt in tok if tt != dataset.pad_idx]
                raw_text = [task_dataset.vocab[tt] for tt in tok]
                hf_src_tokens, hf_target, hf_raw_target, hf_raw_text, hf_word_id, hf_mask = [], [], [], [], [], []
                for i_w in range(len(raw_text) - 1):

                    w = raw_text[i_w]
                    tok_ = hf_tokenizer.encode(w, add_special_tokens=False)
                    if i_w == 0 and add_cls_token:
                        if tok_[0] != hf_tokenizer.cls_token_id:
                            tok_ = [hf_tokenizer.cls_token_id] + tok_

                    if len(hf_src_tokens) + len(tok_) > model_max_length:
                        break

                    hf_src_tokens += tok_
                    hf_raw_text += hf_tokenizer.convert_ids_to_tokens(tok_)
                    hf_word_id += [i_w] * len(tok_)
                    hf_mask += [0] * (len(tok_) - 1) + [1]
                    hf_target += [tok[i_w + 1]] * len(tok_)
                    hf_raw_target += [raw_text[i_w + 1]]

                assert len(hf_src_tokens) == len(hf_target)
                assert len(hf_src_tokens) == len(hf_word_id)
                assert len(hf_src_tokens) == len(hf_mask)

                hf_batch['src_tokens'].append(hf_src_tokens)
                hf_batch['target'].append(hf_target) # This is indexed by KNN-LM tokenizer.
                hf_batch['raw_target'].append(hf_raw_target)
                hf_batch['word_id'].append(hf_word_id)
                hf_batch['mask'].append(hf_mask)

                num_tokens += len(hf_src_tokens)

            hf_batch_ = batchify(hf_batch)

            model_output = hf_model(hf_batch_['input_ids'], output_hidden_states=True)

            h = model_output['hidden_states'][-1]

            assert h.shape[:2] == hf_batch_['input_ids'].shape[:2]

            if args.save_knnlm_dstore and not dstore_full:

                flat_h = h.view(-1, hf_model.config.d_model)
                mask_ = hf_batch_['context_mask'].view(-1) == 1
                keys_ = flat_h[mask_]
                vals_ = hf_batch_['target'].view(-1, 1)[mask_]

                shape = keys_.shape
                if dstore_idx + shape[0] > args.dstore_size:
                    shape = [args.dstore_size - dstore_idx]
                    dstore_full = True

                keys_ = keys_[:shape[0]]
                vals_ = vals_[:shape[0]]

                assert keys_.shape[0] == vals_.shape[0]

                dstore_keys[dstore_idx:shape[0]+dstore_idx] = keys_.cpu().numpy().astype(np.float32)
                dstore_vals[dstore_idx:shape[0]+dstore_idx] = vals_.cpu().numpy().astype(np.int64)

                dstore_idx += shape[0]
            if dstore_full:
                print('Datastore is full with {} items.'.format(args.dstore_size))

            wps_meter.update(sample['ntokens'])
            t.log({'wps': round(wps_meter.avg)})

            # Write saved values to disk.
            if args.save_extra:
                writer.update(extra)

    if args.save_knnlm_dstore:
        print("dstore_idx", dstore_idx, "final shape", shape)
        print("Keys", dstore_keys.shape, dstore_keys.dtype)
        print("Vals", dstore_vals.shape, dstore_vals.dtype)

    logger.info('done with {} tokens'.format(num_tokens))
Example #5
    def test_sequence_scorer(self):
        # construct dummy dictionary
        d = test_utils.dummy_dictionary(vocab_size=2)
        self.assertEqual(d.pad(), 1)
        self.assertEqual(d.eos(), 2)
        self.assertEqual(d.unk(), 3)
        eos = d.eos()
        w1 = 4
        w2 = 5

        # construct dataloader
        data = [
            {
                "source": torch.LongTensor([w1, w2, eos]),
                "target": torch.LongTensor([w1, w2, w1, eos]),
            },
            {
                "source": torch.LongTensor([w2, eos]),
                "target": torch.LongTensor([w2, w1, eos]),
            },
            {
                "source": torch.LongTensor([w2, eos]),
                "target": torch.LongTensor([w2, eos]),
            },
        ]
        data_itr = test_utils.dummy_dataloader(data)

        # specify expected output probabilities
        args = argparse.Namespace()
        unk = 0.0
        args.beam_probs = [
            # step 0:
            torch.FloatTensor([
                # eos      w1   w2
                [0.0, unk, 0.6, 0.4],  # sentence 1
                [0.0, unk, 0.4, 0.6],  # sentence 2
                [0.0, unk, 0.7, 0.3],  # sentence 3
            ]),
            # step 1:
            torch.FloatTensor([
                # eos      w1   w2
                [0.0, unk, 0.2, 0.7],  # sentence 1
                [0.0, unk, 0.8, 0.2],  # sentence 2
                [0.7, unk, 0.1, 0.2],  # sentence 3
            ]),
            # step 2:
            torch.FloatTensor([
                # eos       w1    w2
                [0.10, unk, 0.50, 0.4],  # sentence 1
                [0.15, unk, 0.15, 0.7],  # sentence 2
                [0.00, unk, 0.00, 0.0],  # sentence 3
            ]),
            # step 3:
            torch.FloatTensor([
                # eos      w1    w2
                [0.9, unk, 0.05, 0.05],  # sentence 1
                [0.0, unk, 0.00, 0.0],  # sentence 2
                [0.0, unk, 0.00, 0.0],  # sentence 3
            ]),
        ]
        expected_scores = [
            [0.6, 0.7, 0.5, 0.9],  # sentence 1
            [0.6, 0.8, 0.15],  # sentence 2
            [0.3, 0.7],  # sentence 3
        ]
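        # each expected score is the probability assigned to the gold target
        # token at that step: sentence 1 scores w1 at step 0 (0.6), w2 at
        # step 1 (0.7), w1 at step 2 (0.5), and eos at step 3 (0.9)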

        task = test_utils.TestTranslationTask.setup_task(args, d, d)
        model = task.build_model(args)
        scorer = SequenceScorer(task.target_dictionary)
        for sample in data_itr:
            hypos = task.inference_step(scorer, [model], sample)
            for id, hypos_id in zip(sample["id"].tolist(), hypos):
                self.assertHypoTokens(hypos_id[0], data[id]["target"])
                self.assertHypoScore(hypos_id[0], expected_scores[id])
Example #6
    def build_generator(self, args):
        if getattr(args, 'score_reference', False):
            from fairseq.sequence_scorer import SequenceScorer
            return SequenceScorer(
                self.target_dictionary,
                compute_alignment=getattr(args, 'print_alignment', False),
            )

        from fairseq.sequence_generator import SequenceGenerator, SequenceGeneratorWithAlignment

        # Choose search strategy. Defaults to Beam Search.
        sampling = getattr(args, 'sampling', False)
        sampling_topk = getattr(args, 'sampling_topk', -1)
        sampling_topp = getattr(args, 'sampling_topp', -1.0)
        diverse_beam_groups = getattr(args, 'diverse_beam_groups', -1)
        # note: no trailing comma here, or this silently becomes a tuple
        diverse_beam_strength = getattr(args, 'diverse_beam_strength', 0.5)
        match_source_len = getattr(args, 'match_source_len', False)
        diversity_rate = getattr(args, 'diversity_rate', -1)
        if (sum(
                int(cond) for cond in [
                    sampling,
                    diverse_beam_groups > 0,
                    match_source_len,
                    diversity_rate > 0,
                ]) > 1):
            raise ValueError(
                'Provided Search parameters are mutually exclusive.')
        assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling'
        assert sampling_topp < 0 or sampling, '--sampling-topp requires --sampling'

        if sampling:
            search_strategy = search.Sampling(self.target_dictionary,
                                              sampling_topk, sampling_topp)
        elif diverse_beam_groups > 0:
            search_strategy = search.DiverseBeamSearch(self.target_dictionary,
                                                       diverse_beam_groups,
                                                       diverse_beam_strength)
        elif match_source_len:
            # this is useful for tagging applications where the output
            # length should match the input length, so we hardcode the
            # length constraints for simplicity
            search_strategy = search.LengthConstrainedBeamSearch(
                self.target_dictionary,
                min_len_a=1,
                min_len_b=0,
                max_len_a=1,
                max_len_b=0,
            )
        elif diversity_rate > -1:
            search_strategy = search.DiverseSiblingsSearch(
                self.target_dictionary, diversity_rate)
        else:
            search_strategy = search.BeamSearch(self.target_dictionary)

        if getattr(args, 'print_alignment', False):
            seq_gen_cls = SequenceGeneratorWithAlignment
        else:
            seq_gen_cls = SequenceGenerator

        return seq_gen_cls(
            self.target_dictionary,
            beam_size=getattr(args, 'beam', 5),
            max_len_a=getattr(args, 'max_len_a', 0),
            max_len_b=getattr(args, 'max_len_b', 200),
            min_len=getattr(args, 'min_len', 1),
            normalize_scores=(not getattr(args, 'unnormalized', False)),
            len_penalty=getattr(args, 'lenpen', 1),
            unk_penalty=getattr(args, 'unkpen', 0),
            temperature=getattr(args, 'temperature', 1.),
            match_source_len=getattr(args, 'match_source_len', False),
            no_repeat_ngram_size=getattr(args, 'no_repeat_ngram_size', 0),
            search_strategy=search_strategy,
        )
Example #7
def eval_from_file(models,
                   task,
                   args,
                   use_cuda,
                   source_filename=None,
                   target_filename=None,
                   score_filename=None):
    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # I/O files
    source_filename = source_filename if source_filename is not None else args.source_file
    target_filename = target_filename if target_filename is not None else args.target_file
    score_filename = score_filename if score_filename is not None else args.score_file
    if score_filename is None:
        score_filename = target_filename + ".eval.score"
    outfile = open(score_filename, "w")

    # Get sorted input (and reversed)
    sorted_inputs, sorted_keys, sorted_targets = _get_sorted_inputs(
        source_filename, args.num_shards, args.delimiter, target_filename,
        args.shard_id, args.dup_src, args.dup_tgt)

    # Build input iterator
    src_tokens = [
        tokenizer.Tokenizer.tokenize(src_str, src_dict,
                                     add_if_not_exist=False).long()
        for src_str in sorted_inputs
    ]
    tgt_tokens = [
        tokenizer.Tokenizer.tokenize(tgt_str, tgt_dict,
                                     add_if_not_exist=False).long()
        for tgt_str in sorted_targets
    ] if sorted_targets is not None else None
    src_sizes = np.array([t.numel() for t in src_tokens])
    tgt_sizes = np.array([t.numel() for t in tgt_tokens])
    print('| loading {} examples, {} tokens'.format(len(sorted_inputs),
                                                    sum(src_sizes)))

    dataset = data.LanguagePairDataset(src_tokens,
                                       src_sizes,
                                       src_dict,
                                       tgt_tokens,
                                       tgt_sizes,
                                       tgt_dict,
                                       shuffle=False)

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # itr = data.EpochBatchIterator(
    #     dataset=dataset,
    #     max_tokens=args.max_tokens,
    #     max_sentences=args.max_sentences,
    #     max_positions=models[0].max_positions(),
    #     ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
    #     required_batch_size_multiple=8,
    #     num_shards=args.num_shards,
    #     shard_id=args.shard_id,
    # ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(models, task.target_dictionary)
    if use_cuda:
        scorer.cuda()

    all_scores = dict()
    score_sum = 0.
    count, sen_count = 0, 0
    results = scorer.score_batched_itr(itr, cuda=use_cuda, timer=gen_timer)
    wps_meter = TimeMeter()
    for sample_id, src_tokens, target_tokens, hypos in results:
        for i, hypo in enumerate(hypos):
            pos_scores = hypo['positional_scores']
            inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(
                float('-inf'))
            if inf_scores.any():
                print(
                    '| Skipping tokens with inf scores:',
                    task.target_dictionary.string(
                        hypo['tokens'][inf_scores.nonzero()]))
                pos_scores = pos_scores[(~inf_scores).nonzero()]
            score_sum += pos_scores.sum()
            count += pos_scores.numel()
            sentence_score = hypo['score']
            if i == 0:
                all_scores[sample_id.tolist()] = sentence_score
        sen_count += 1
        wps_meter.update(src_tokens.size(0))

    print("| [eval] writing scores into {}".format(score_filename))
    # print(sids)
    for index in range(len(sorted_inputs)):
        outfile.write("{}{}".format(all_scores[sorted_keys[index]],
                                    args.delimiter))
    outfile.close()

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
        gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(
        float(avg_nll_loss), np.exp(float(avg_nll_loss))))
def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Set up functions for multiturn
    def train_path(lang):
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    def file_name(prefix, lang):
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        return dest_path("dict", lang) + ".txt"
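    # e.g. file_name('train', 'en') -> 'train.en', so
    # dict_path('en') -> os.path.join(args.destdir, 'dict.en.txt')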

    def build_dictionary(filenames, src=False, tgt=False):
        assert src ^ tgt
        return task.build_dictionary(
            filenames,
            workers=args.workers,
            threshold=args.thresholdsrc if src else args.thresholdtgt,
            nwords=args.nwordssrc if src else args.nwordstgt,
            padding_factor=args.padding_factor,
        )

    def make_binary_dataset(input_prefix, output_prefix, lang, num_workers):
        dict = task.load_dictionary(dict_path(lang))
        print("| [{}] Dictionary: {} types".format(lang, len(dict) - 1))
        n_seq_tok = [0, 0]
        replaced = Counter()

        def merge_result(worker_result):
            replaced.update(worker_result["replaced"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        input_file = "{}{}".format(input_prefix,
                                   ("." + lang) if lang is not None else "")
        offsets = Tokenizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize,
                    (
                        args,
                        input_file,
                        dict,
                        prefix,
                        lang,
                        offsets[worker_id],
                        offsets[worker_id + 1],
                    ),
                    callback=merge_result,
                )
            pool.close()

        ds = indexed_dataset.IndexedDatasetBuilder(
            dataset_dest_file(args, output_prefix, lang, "bin"))
        merge_result(
            Tokenizer.binarize(input_file,
                               dict,
                               lambda t: ds.add_item(t),
                               offset=0,
                               end=offsets[1]))
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        print("| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
            lang,
            input_file,
            n_seq_tok[0],
            n_seq_tok[1],
            100 * sum(replaced.values()) / n_seq_tok[1],
            dict.unk_word,
        ))

    def make_dataset(input_prefix, output_prefix, lang, num_workers=1):
        if args.output_format == "binary":
            make_binary_dataset(input_prefix, output_prefix, lang, num_workers)
        elif args.output_format == "raw":
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix +
                ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)

    def make_all(lang):
        if args.multiturnpref:
            make_dataset(args.multiturnpref,
                         "test",
                         lang,
                         num_workers=args.workers)

    # Load dataset splits
    task = tasks.setup_task(args)

    # Multiturn tracking: prompt in test set, turn in debate
    turn = 0
    prompt = 1
    first_pass = True
    while first_pass or args.multiturn:
        if args.multiturn:
            # Set up first turn
            if turn == 0:
                multiturn_file = "{}{}".format(args.multiturnpref,
                                               ("." + args.source_lang))
                test_file = "{}{}".format(args.testpref,
                                          ("." + args.source_lang))
                if args.interactive:
                    line = input('What subject would you like to debate?')
                else:
                    with open(test_file, 'r', encoding='utf-8') as f:
                        for i in range(prompt):
                            line = f.readline()
                with open(multiturn_file, 'w', encoding='utf-8') as f:
                    f.write(line)
                prompt += 1

            target = not args.only_source
            assert (args.multiturnpref), "--multiturnpref must be set"
            if args.joined_dictionary:
                assert (
                    not args.srcdict or not args.tgtdict
                ), "cannot use both --srcdict and --tgtdict with --joined-dictionary"

                if args.srcdict:
                    src_dict = task.load_dictionary(args.srcdict)
                elif args.tgtdict:
                    src_dict = task.load_dictionary(args.tgtdict)
                else:
                    assert (
                        args.trainpref
                    ), "--trainpref must be set if --srcdict is not specified"
                    src_dict = build_dictionary(
                        {
                            train_path(lang)
                            for lang in [args.source_lang, args.target_lang]
                        },
                        src=True)
                tgt_dict = src_dict
            else:
                if args.srcdict:
                    src_dict = task.load_dictionary(args.srcdict)
                else:
                    assert (
                        args.trainpref
                    ), "--trainpref must be set if --srcdict is not specified"
                    src_dict = build_dictionary([train_path(args.source_lang)],
                                                src=True)
            if target:
                if args.tgtdict:
                    tgt_dict = task.load_dictionary(args.tgtdict)
                else:
                    assert (
                        args.trainpref
                    ), "--trainpref must be set if --tgtdict is not specified"
                    tgt_dict = build_dictionary([train_path(args.target_lang)],
                                                tgt=True)
            else:
                tgt_dict = None

            src_dict.save(dict_path(args.source_lang))
            if target and tgt_dict is not None:
                tgt_dict.save(dict_path(args.target_lang))

            make_all(args.source_lang)
            if target:
                make_all(args.target_lang)
            if first_pass:
                print("| Wrote preprocessed data to {}".format(args.destdir))
                print('| Generating multiturn debate')
            task.load_dataset('test')
        else:
            task.load_dataset(args.gen_subset)
            print('| {} {} {} examples'.format(
                args.data, args.gen_subset,
                len(task.dataset(args.gen_subset))))

        if first_pass:
            # Set dictionaries
            src_dict = task.source_dictionary
            tgt_dict = task.target_dictionary

            # Load ensemble
            print('| loading model(s) from {}'.format(args.path))
            models, _model_args = utils.load_ensemble_for_inference(
                args.path.split(':'),
                task,
                model_arg_overrides=eval(args.model_overrides),
            )

            # Optimize ensemble for generation
            for model in models:
                model.make_generation_fast_(
                    beamable_mm_beam_size=None
                    if args.no_beamable_mm else args.beam,
                    need_attn=args.print_alignment,
                )
                if args.fp16:
                    model.half()

        # Load alignment dictionary for unknown word replacement
        # (None if no unknown word replacement, empty if no path to align dictionary)
        align_dict = utils.load_align_dict(args.replace_unk)

        # Load dataset (possibly sharded)
        itr = task.get_batch_iterator(
            dataset=task.dataset(args.gen_subset),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                *[model.max_positions() for model in models]),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=8,
            num_shards=args.num_shards,
            shard_id=args.shard_id,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)

        # Initialize generator
        gen_timer = StopwatchMeter()
        if args.score_reference:
            translator = SequenceScorer(models, task.target_dictionary)
        else:
            translator = SequenceGenerator(
                models,
                task.target_dictionary,
                beam_size=args.beam,
                minlen=args.min_len,
                stop_early=(not args.no_early_stop),
                normalize_scores=(not args.unnormalized),
                len_penalty=args.lenpen,
                unk_penalty=args.unkpen,
                sampling=args.sampling,
                sampling_topk=args.sampling_topk,
                sampling_temperature=args.sampling_temperature,
                diverse_beam_groups=args.diverse_beam_groups,
                diverse_beam_strength=args.diverse_beam_strength,
                match_source_len=args.match_source_len,
                no_repeat_ngram_size=args.no_repeat_ngram_size,
            )

        if use_cuda:
            translator.cuda()

        # Generate and compute BLEU score
        if args.sacrebleu:
            scorer = bleu.SacrebleuScorer()
        else:
            scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(),
                                 tgt_dict.unk())
        num_sentences = 0
        has_target = True
        with progress_bar.build_progress_bar(args, itr) as t:
            if args.score_reference:
                translations = translator.score_batched_itr(t,
                                                            cuda=use_cuda,
                                                            timer=gen_timer)
            else:
                translations = translator.generate_batched_itr(
                    t,
                    maxlen_a=args.max_len_a,
                    maxlen_b=args.max_len_b,
                    cuda=use_cuda,
                    timer=gen_timer,
                    prefix_size=args.prefix_size,
                )

            wps_meter = TimeMeter()
            for sample_id, src_tokens, target_tokens, hypos in translations:

                # Process input and ground truth
                has_target = target_tokens is not None
                target_tokens = target_tokens.int().cpu() if has_target else None

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(
                        args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(
                        args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    src_str = src_dict.string(src_tokens, args.remove_bpe)
                    if has_target:
                        target_str = tgt_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)

                if not args.quiet:
                    print('S-{}\t{}'.format(sample_id, src_str))
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str))

                # Process top predictions
                for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'].int().cpu()
                        if hypo['alignment'] is not None else None,
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    if not args.quiet:
                        print('H-{}\t{}\t{}'.format(sample_id, hypo['score'],
                                                    hypo_str))
                        print('P-{}\t{}'.format(
                            sample_id, ' '.join(
                                map(
                                    lambda x: '{:.4f}'.format(x),
                                    hypo['positional_scores'].tolist(),
                                ))))

                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id, ' '.join(
                                    map(lambda x: str(utils.item(x)),
                                        alignment))))

                    if args.multiturn:
                        multiturn_file = "{}{}".format(
                            args.multiturnpref, ("." + args.source_lang))
                        output_file = "{}{}".format(args.outputpref,
                                                    ("." + args.target_lang))
                        with open(multiturn_file, 'r', encoding='utf-8') as f:
                            line = f.readline()
                            if args.interactive:
                                interactive_response = input('Please respond:')
                                line += f' <EOA> {interactive_response}'
                        if turn < MAX_TURNS - 1:
                            with open(multiturn_file, 'w',
                                      encoding='utf-8') as f:
                                f.write(f'{line[:-1]} <EOA> {hypo_str}')
                            turn += 1
                        elif turn == MAX_TURNS - 1:
                            with open(output_file, 'a', encoding='utf-8') as f:
                                f.write(f'{line[:-1]} <EOA> {hypo_str}\n')
                            turn = 0

                    # Score only the top hypothesis
                    if has_target and i == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tokenizer.Tokenizer.tokenize(
                                target_str, tgt_dict, add_if_not_exist=True)
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)

                wps_meter.update(src_tokens.size(0))
                t.log({'wps': round(wps_meter.avg)})
                num_sentences += 1

        print(
            '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
            .format(num_sentences, gen_timer.n, gen_timer.sum,
                    num_sentences / gen_timer.sum, 1. / gen_timer.avg))
        if has_target:
            print('| Generate {} with beam={}: {}'.format(
                args.gen_subset, args.beam, scorer.result_string()))

        first_pass = False
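The multiturn branch above round-trips a one-line transcript through multiturn_file, gluing each new hypothesis on with an <EOA> separator until MAX_TURNS is reached. A standalone sketch of how that transcript grows (the seed line and turn count below are illustrative, not from the original):

# Illustrative only: how the <EOA>-joined transcript accumulates across turns,
# assuming MAX_TURNS = 3 and one hypothesis per turn.
transcript = "hello there"  # initial contents of multiturn_file
for hypo_str in ["hi!", "how are you?", "fine."]:
    transcript = f"{transcript} <EOA> {hypo_str}"
# After the final turn the finished line is appended to the output file:
# "hello there <EOA> hi! <EOA> how are you? <EOA> fine."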
Example #9
class LMScorer(object):
    def __init__(self, parsed_args):
        self.args = parsed_args
        import_user_module(parsed_args)
        assert parsed_args.path is not None, '--path required for evaluation'

        print(parsed_args)

        self.use_cuda = torch.cuda.is_available() and not parsed_args.cpu

        self.task = tasks.setup_task(parsed_args)

        # Load ensemble
        print('| loading model(s) from {}'.format(parsed_args.path))
        self.models, args = utils.load_ensemble_for_inference(
            parsed_args.path.split(':'),
            self.task,
            model_arg_overrides=eval(parsed_args.model_overrides),
        )

        for model in self.models:
            model.make_generation_fast_()
            if self.use_cuda:
                model.cuda()

        for arg in vars(parsed_args).keys():
            if arg not in {
                    'self_target', 'future_target', 'past_target',
                    'tokens_per_sample', 'output_size_dictionary'
            }:
                setattr(args, arg, getattr(parsed_args, arg))
        self.task = tasks.setup_task(args)

        self.gen_timer = StopwatchMeter()
        self.scorer = SequenceScorer(self.task.target_dictionary)

    def score_sent(self, line):
        score_dict = self.score([line])
        return score_dict[0]

    def make_batches(self, lines):
        token_lst = [
            self.task.source_dictionary.encode_line(
                line, add_if_not_exist=False).long() for line in lines
        ]
        length_lst = torch.LongTensor([tokens.numel() for tokens in token_lst])

        ds = data.TokenBlockDataset(token_lst,
                                    length_lst,
                                    self.args.tokens_per_sample,
                                    pad=self.task.dictionary.pad(),
                                    eos=self.task.dictionary.eos(),
                                    break_mode='eos',
                                    include_targets=True)
        add_eos_for_other_targets = self.args.sample_break_mode is not None and self.args.sample_break_mode != 'none'
        itr = self.task.get_batch_iterator(
            dataset=data.MonolingualDataset(ds,
                                            ds.sizes,
                                            self.task.dictionary,
                                            self.task.target_dictionary,
                                            add_eos_for_other_targets,
                                            shuffle=False,
                                            targets=self.task.targets),
            max_tokens=self.args.max_tokens or 3000,
            max_sentences=self.args.max_sentences,
            max_positions=utils.resolve_max_positions(
                *[model.max_positions() for model in self.models]),
            num_shards=self.args.num_shards,
            shard_id=self.args.shard_id,
            ignore_invalid_inputs=True,
            num_workers=self.args.num_workers,
        ).next_epoch_itr(shuffle=False)

        return itr

    def score(self, lines):

        batch = self.make_batches(lines)

        sample_score_dict = {}

        # with progress_bar.build_progress_bar(self.args, itr) as t:
        for sample in batch:
            sample_id_lst = sample['id']
            sample = utils.move_to_cuda(sample) if self.use_cuda else sample
            if 'net_input' not in sample:
                continue

            hypos = self.scorer.generate(self.models, sample)

            # print(hypos)

            for sample_id, hypos_i in zip(sample_id_lst, hypos):
                hypo = hypos_i[0]
                pos_scores = hypo['positional_scores']

                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(
                    float('-inf'))
                if inf_scores.any():
                    print(
                        '| Skipping tokens with inf scores:',
                        self.task.target_dictionary.string(
                            hypo['tokens'][inf_scores.nonzero()]))
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                sample_score = pos_scores.sum().cpu()
                count = pos_scores.numel()

                w_lst = []
                word_prob = []
                for i in range(len(hypo['tokens'])):
                    w_ind = hypo['tokens'][i].item()
                    w = self.task.dictionary[w_ind]
                    word_prob.append((w, pos_scores[i].item()))
                    w_lst.append(w)

                sample_score = -sample_score / count

                if not self.args.quiet:
                    if self.args.output_sent:
                        print('H-{}\t{}\t{}'.format(sample_id, sample_score,
                                                    ' '.join(w_lst)))
                    else:
                        print('H-{}\t{}'.format(sample_id, sample_score))
                sample_score_dict[sample_id.item()] = sample_score.item()
                # print(sample_id, sample_score.item())

        return sample_score_dict
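A minimal usage sketch for LMScorer. The parser calls below assume fairseq's standard eval-LM option helpers and are not part of the original snippet:

# Hypothetical call site, assuming fairseq's eval_lm option parser.
from fairseq import options

parser = options.get_eval_lm_parser()
parsed_args = options.parse_args_and_arch(parser)

scorer = LMScorer(parsed_args)
# score_sent returns the length-normalized negative log-likelihood of one line
print(scorer.score_sent('the cat sat on the mat'))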
Example #10
def main(cfg: DictConfig, **unused_kwargs):
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    use_fp16 = cfg.common.fp16
    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    if use_cuda:
        torch.cuda.set_device(cfg.distributed_training.device_id)

    logger.info(cfg)

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))

    # reduce tokens per sample by the required context window size
    cfg.task.tokens_per_sample -= cfg.eval_lm.context_window

    # Initialize the task using the current *cfg*
    task = tasks.setup_task(cfg.task)

    # Initialize the model (but not the task) using the checkpoint's *cfg*
    models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
        [cfg.common_eval.path],
        arg_overrides=eval(cfg.common_eval.model_overrides),
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
        task=task,
    )

    # Load dataset splits
    gen_subset = cfg.dataset.gen_subset
    task.load_dataset(gen_subset)
    dataset = task.dataset(gen_subset)
    if cfg.eval_lm.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=cfg.task.tokens_per_sample,
            context_window=cfg.eval_lm.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    logger.info("{} {} {} examples".format(cfg.task.data, gen_subset,
                                           len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        if use_fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    assert len(models) > 0

    logger.info("num. model params: {}".format(
        sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=cfg.dataset.max_tokens or 36000,
        max_sentences=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=True,
        num_shards=max(
            cfg.dataset.num_shards,
            cfg.distributed_training.distributed_world_size,
        ),
        shard_id=max(
            cfg.dataset.shard_id,
            cfg.distributed_training.distributed_rank,
        ),
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm"
                            if not cfg.common.no_progress_bar else "simple"),
    )

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, cfg.eval_lm.softmax_batch)

    score_sum = 0.0
    count = 0

    if cfg.common_eval.post_process is not None:
        if cfg.common_eval.post_process == "sentencepiece":
            raise NotImplementedError
        else:
            bpe_cont = cfg.common_eval.post_process.rstrip()
            bpe_toks = {
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            }
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    wps_meter = TimeMeter()

    for sample in progress:
        if "net_input" not in sample:
            continue

        sample = utils.move_to_cuda(sample) if use_cuda else sample

        gen_timer.start()
        hypos = scorer.generate(models, sample)
        gen_timer.stop(sample["ntokens"])

        for i, hypos_i in enumerate(hypos):
            hypo = hypos_i[0]
            sample_id = sample["id"][i]

            tokens = hypo["tokens"]
            tgt_len = tokens.numel()
            pos_scores = hypo["positional_scores"].float()

            if getattr(cfg.task, "add_bos_token", False):
                assert hypo["tokens"][0].item() == task.target_dictionary.bos()
                tokens = tokens[1:]
                pos_scores = pos_scores[1:]

            skipped_toks = 0
            if bpe_toks is not None:
                for i in range(tgt_len - 1):
                    if tokens[i].item() in bpe_toks:
                        skipped_toks += 1
                        pos_scores[i + 1] += pos_scores[i]
                        pos_scores[i] = 0

            inf_scores = pos_scores.eq(float("inf")) | pos_scores.eq(
                float("-inf"))
            if inf_scores.any():
                logger.info(
                    "skipping tokens with inf scores: %s",
                    task.target_dictionary.string(
                        tokens[inf_scores.nonzero()]),
                )
                pos_scores = pos_scores[(~inf_scores).nonzero()]
            score_sum += pos_scores.sum().cpu()
            count += pos_scores.numel() - skipped_toks

            if cfg.eval_lm.output_word_probs or cfg.eval_lm.output_word_stats:
                w = ""
                word_prob = []
                is_bpe = False
                for i in range(len(tokens)):
                    w_ind = tokens[i].item()
                    w += task.source_dictionary[w_ind]
                    if bpe_toks is not None and w_ind in bpe_toks:
                        w = w[:-bpe_len]
                        is_bpe = True
                    else:
                        word_prob.append((w, pos_scores[i].item()))

                        next_prob = None
                        ind = i + 1
                        while ind < len(tokens):
                            if pos_scores[ind].item() != 0:
                                next_prob = pos_scores[ind]
                                break
                            ind += 1

                        word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                            pos_scores[i].item(), next_prob)
                        is_bpe = False
                        w = ""
                if cfg.eval_lm.output_word_probs:
                    logger.info(
                        str(int(sample_id)) + " " +
                        ("\t".join("{} [{:2f}]".format(x[0], x[1])
                                   for x in word_prob)))

        wps_meter.update(sample["ntokens"])
        progress.log({"wps": round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count / math.log(2) if count > 0 else 0  # convert to base 2
    logger.info("Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)".format(
        gen_timer.n, gen_timer.sum,
        1.0 / gen_timer.avg if gen_timer.avg > 0 else 0))
    logger.info("Loss (base 2): {:.4f}, Perplexity: {:.2f}".format(
        avg_nll_loss, 2**avg_nll_loss))

    if cfg.eval_lm.output_word_stats:
        for ws in sorted(word_stats.values(),
                         key=lambda x: x.count,
                         reverse=True):
            logger.info(ws)
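The loss above is accumulated in nats (natural-log probabilities) and converted to bits by dividing by ln 2, so the reported perplexity is 2 raised to the base-2 loss. A one-line check of that arithmetic (the sample value is made up):

import math

avg_nll_nats = 3.2094                        # made-up mean NLL per token, in nats
avg_nll_bits = avg_nll_nats / math.log(2)    # nats -> bits (base 2)
perplexity = 2 ** avg_nll_bits               # equals math.exp(avg_nll_nats)
assert abs(perplexity - math.exp(avg_nll_nats)) < 1e-6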
Example #11
    def test_sequence_scorer(self):
        # construct dummy dictionary
        d = test_utils.dummy_dictionary(vocab_size=2)
        self.assertEqual(d.pad(), 1)
        self.assertEqual(d.eos(), 2)
        self.assertEqual(d.unk(), 3)
        eos = d.eos()
        w1 = 4
        w2 = 5

        # construct dataloader
        data = [
            {
                'source': torch.LongTensor([w1, w2, eos]),
                'target': torch.LongTensor([w1, w2, w1, eos]),
            },
            {
                'source': torch.LongTensor([w2, eos]),
                'target': torch.LongTensor([w2, w1, eos]),
            },
            {
                'source': torch.LongTensor([w2, eos]),
                'target': torch.LongTensor([w2, eos]),
            },
        ]
        data_itr = test_utils.dummy_dataloader(data)

        # specify expected output probabilities
        args = argparse.Namespace()
        unk = 0.
        args.beam_probs = [
            # step 0:
            torch.FloatTensor([
                # eos      w1   w2
                [0.0, unk, 0.6, 0.4],  # sentence 1
                [0.0, unk, 0.4, 0.6],  # sentence 2
                [0.0, unk, 0.7, 0.3],  # sentence 3
            ]),
            # step 1:
            torch.FloatTensor([
                # eos      w1   w2
                [0.0, unk, 0.2, 0.7],  # sentence 1
                [0.0, unk, 0.8, 0.2],  # sentence 2
                [0.7, unk, 0.1, 0.2],  # sentence 3
            ]),
            # step 2:
            torch.FloatTensor([
                # eos       w1    w2
                [0.10, unk, 0.50, 0.4],  # sentence 1
                [0.15, unk, 0.15, 0.7],  # sentence 2
                [0.00, unk, 0.00, 0.0],  # sentence 3
            ]),
            # step 3:
            torch.FloatTensor([
                # eos      w1    w2
                [0.9, unk, 0.05, 0.05],  # sentence 1
                [0.0, unk, 0.00, 0.0],  # sentence 2
                [0.0, unk, 0.00, 0.0],  # sentence 3
            ]),
        ]
        expected_scores = [
            [0.6, 0.7, 0.5, 0.9],  # sentence 1
            [0.6, 0.8, 0.15],  # sentence 2
            [0.3, 0.7],  # sentence 3
        ]

        task = test_utils.TestTranslationTask.setup_task(args, d, d)
        model = task.build_model(args)
        scorer = SequenceScorer([model], task.target_dictionary)
        for id, _src, _ref, hypos in scorer.score_batched_itr(data_itr):
            self.assertHypoTokens(hypos[0], data[id]['target'])
            self.assertHypoScore(hypos[0], expected_scores[id])
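The expected scores read straight off args.beam_probs: each row is ordered [eos, unk, w1, w2], so for sentence 1 the target [w1, w2, w1, eos] picks up 0.6 at step 0, 0.7 at step 1, 0.5 at step 2 and 0.9 at step 3, which is exactly the first row of expected_scores.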
Example #12
class FairseqAgent(TorchAgent):
    """Generic wrapper around fairseq for use in ParlAI"""

    DEFAULT_OPTIONS = {
        "adam_betas": "(0.9,0.98)",
        "optimizer": "adam",
        "clip_norm": 0.1,
    }

    metrics = {}

    @classmethod
    def add_cmdline_args(cls, argparser):
        """Add command-line arguments specifically for this agent."""
        # first we need to add the general torch agent operations
        TorchAgent.add_cmdline_args(argparser)

        agent = argparser.add_argument_group('Fairseq Arguments')
        agent.add_argument('--fp16',
                           default=False,
                           type='bool',
                           help='Use fp16 training')
        agent.add_argument('--seed',
                           default=1,
                           type=int,
                           metavar='N',
                           help='pseudo random number generator seed')
        agent.add_argument(
            '--skip-generation',
            default=False,
            type='bool',
            metavar='BOOL',
            help='Skips test time beam search. Much faster if you only need PPL',
        )

        # Dictionary construction stuff. Using the subclass in case we end up
        # needing any fairseq specific things
        cls.dictionary_class().add_cmdline_args(argparser)

        # Check subargs for generation, optimizers, criterions, archs, etc
        options.add_generation_args(argparser)
        options.add_optimization_args(argparser)

        # make sure we set defaults according to the model before parsing
        argparser.set_defaults(**cls.DEFAULT_OPTIONS)
        known_args = argparser.parse_known_args(nohelp=True)[0]

        if hasattr(known_args, "optimizer"):
            optimizer = known_args.optimizer
            opt_group = argparser.add_argument_group(
                '{} optimizer arguments'.format(optimizer))
            optim.OPTIMIZER_REGISTRY[optimizer].add_args(opt_group)
        if hasattr(known_args, "lr_scheduler"):
            lr_scheduler = known_args.lr_scheduler
            lr_group = argparser.add_argument_group(
                '{} scheduler arguments'.format(lr_scheduler))
            optim.lr_scheduler.LR_SCHEDULER_REGISTRY[lr_scheduler].add_args(
                lr_group)
        # We need to find out the fairseq model-specific options, so grab the
        # architecture stuff and look up its options
        arch_group = options.add_model_args(argparser)
        # Fairseq marks the arch flag as required, but it may be specified
        # by a saved model cache, so we do some weird stuff to undo that
        for a in arch_group._actions:
            if a.dest == "arch":
                a.required = False
                a.default = None
                break

        # make sure we set defaults according to parlai model before parsing
        argparser.set_defaults(**cls.DEFAULT_OPTIONS)
        known_args = argparser.parse_known_args(nohelp=True)[0]

        if hasattr(known_args, "arch") and known_args.arch is not None:
            arch = known_args.arch
            arch_group = argparser.add_argument_group(
                "{} architecture arguments".format(arch))
            models.ARCH_MODEL_REGISTRY[arch].add_args(arch_group)

        if hasattr(known_args, "criterion"):
            crit_group = argparser.add_argument_group(
                '{} criterion arguments'.format(known_args.criterion))
            criterions.CRITERION_REGISTRY[known_args.criterion].add_args(
                crit_group)

        # As one final check, let's make sure we set defaults correctly
        argparser.set_defaults(**cls.DEFAULT_OPTIONS)

    @staticmethod
    def dictionary_class():
        # Force use of the Fairseq Dictionary
        return _FairseqDictionary

    def __init__(self, opt, shared=None):
        # In general use a basic TorchAgent wherever possible
        super().__init__(opt, shared)
        if not shared:
            # this is not a shared instance of this class, so do full initialization

            # check early if we're going to be loading the model from a checkpoint
            model_file_exists = (self.opt.get('model_file')
                                 and os.path.isfile(self.opt['model_file']))

            # fairseq expects options to be in argparse format, instead of a dict
            # We also need to do some argument postprocessing and whatnot
            # We'll skip pretrained embeddings if we're going to override them with
            # a model checkpoint anyway
            self.args, self.opt = _fairseq_opt_wrapper(opt, model_file_exists)

            # seed the RNG
            torch.manual_seed(self.args.seed)

            # Just some identifying info
            self.id = "fairseq:{}".format(self.args.arch)

            # We need a placeholder task for fairseq
            self.task = _ParlaiTask(self.dict)

            # actually construct the model and generator
            self.model = self.build_model()

            # Construct the generator and scorer
            self.generator = SequenceGenerator(
                [self.model],
                tgt_dict=self.dict,
                beam_size=self.args.beam,
                stop_early=(not self.args.no_early_stop),
                normalize_scores=(not self.args.unnormalized),
                len_penalty=self.args.lenpen,
                unk_penalty=self.args.unkpen,
                sampling=self.args.sampling,
                sampling_topk=self.args.sampling_topk,
                sampling_temperature=self.args.sampling_temperature,
            )
            self.scorer = SequenceScorer([self.model], self.dict)

            # set up the grader and the trainer
            self.criterion = criterions.build_criterion(self.args, self.task)

            if getattr(self.args, 'fp16', None):
                self.trainer = fp16_trainer.FP16Trainer(
                    self.args, self.task, self.model, self.criterion)
            else:
                # TODO: we might choose to add a --no-fp16 opt in the future to
                # explicitly disable fp16 instead
                if torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 7:
                    print("Heads up: using --fp16 could be a lot faster!")
                self.trainer = trainer.Trainer(self.args, self.task,
                                               self.model, self.criterion)
            self.trainer._build_optimizer()

            # if the model already existed, let's preload it and the trainer
            if model_file_exists:
                print('Loading existing model params from ' +
                      self.opt['model_file'])
                self.load(self.opt.get('model_file'))

            # move things to the GPU if possible
            if self.use_cuda:
                self.model = self.model.cuda()
                self.generator = self.generator.cuda()
        else:
            self.model = shared['model']
            self.trainer = shared['trainer']
            self.generator = shared['generator']
            self.dict = shared['dict']
            self.args = shared['args']

        # Start things off clean
        self.reset()

    def _check_opts_unchanged(self, saved_opts, current_opts):
        """Verify that critical options do not differ in command line vs saved model"""
        for k in NON_OVERRIDABLE_ARGS:
            if k not in saved_opts or k not in current_opts:
                # if it's not an option needed by this fairseq model, don't stress
                continue
            if saved_opts[k] != current_opts[k]:
                raise ValueError(
                    '{} cannot be overridden when --model-file is specified'.
                    format(k))

    def build_model(self):
        """
        Construct the actual Fairseq model. Default implementation is to use
        Fairseq's arch builder, but this method may be overridden to build custom
        models.
        """
        model_class = models.ARCH_MODEL_REGISTRY[self.args.arch]
        return model_class.build_model(self.args, self.task)

    def share(self):
        shared = super().share()
        shared['model'] = self.model
        shared['trainer'] = self.trainer
        shared['generator'] = self.generator
        shared['dict'] = self.dict
        shared['args'] = self.args
        return shared

    def save(self, path):
        """Save using fairseq's checkpointing."""
        if not path:
            return
        self.trainer.save_checkpoint(path, {'opt': self.opt, 'epoch': 0})
        # Parlai expects options to also be saved
        with open(path + ".opt", 'wb') as handle:
            # overridden options shouldn't be stored, only the main ones
            if 'override' in self.opt:
                del self.opt['override']
            pickle.dump(self.opt, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def load(self, path):
        """Load using fairseq's checkpointing."""
        old_options = self.trainer.load_checkpoint(path)
        self._check_opts_unchanged(old_options, self.opt)

    def shutdown(self):
        if not hasattr(self, 'trainer'):
            # looks like this is a "fake" model that isn't actually used for batch_act.
            # we don't need to save this one.
            return
        super().shutdown()

    def reset(self):
        """Reset observation and episode_done."""
        super().reset()
        self.reset_metrics()

    def batchify(self, obs_batch):
        """
        Override parent batchify to set requirements for fairseq.

        Fairseq depends on sorted batch inputs for a call to rnn.pad_packed_sequence.
        Fairseq models cannot handle zero length sentences
        """
        return super().batchify(obs_batch,
                                sort=True,
                                is_valid=_is_nonempty_observation)

    def train_step(self, batch):
        """Process batch of inputs and targets and train on them.

        :param batch: parlai.core.torch_agent.Batch, contains tensorized
                      version of observations.
        """
        if batch.text_vec is None:
            return
        self.is_training = True
        samples = self._make_sample(batch.text_vec, batch.label_vec)
        self.model.train()
        self.trainer.train_step(samples)

    def eval_step(self, batch):
        """Process batch of inputs.

        If the batch includes labels, calculate validation metrics as well.
        If --skip-generation is not set, return a prediction for each input.

        :param batch: parlai.core.torch_agent.Batch, contains tensorized
                      version of observations.
        """
        if batch.text_vec is None:
            return
        self.is_training = False
        samples = self._make_sample(batch.text_vec, batch.label_vec)
        self.model.eval()
        if batch.label_vec is not None:
            # Interactive mode won't have a gold label
            self.trainer.valid_step(samples)

        # Output placeholders
        reranked_cands = None
        generated_output = None

        # Grade each of the candidate sequences
        if batch.candidate_vecs is not None:
            bsz = len(batch.text_vec)
            reranked_cands = []
            # score the candidates for each item in the batch separately, so that
            # we can support variable number of candidates
            for i in range(bsz):
                cands = batch.candidate_vecs[i]
                if not cands:
                    reranked_cands.append(None)
                    continue
                ncand = len(cands)
                # repeat the input many times
                xs = batch.text_vec[i].unsqueeze(0).expand(ncand, -1)
                # some models crash if there's leading padding on every example
                xs = xs[:, :batch.text_lengths[i]]
                # and appropriately pack the outputs
                ys, _ = padded_tensor(cands, self.NULL_IDX, self.use_cuda)
                s = self._make_sample(xs, ys)
                # perform the actual grading, extract the scores
                scored = list(
                    self.scorer.score_batched_itr([s], cuda=self.use_cuda))
                scores = [s[3][0]['score'].item() for s in scored]
                # intentional hanging comma here; argsort returns a list
                ranked, = argsort(scores, batch.candidates[i], descending=True)
                reranked_cands.append(ranked)

        # Next generate freely to create our response
        if not self.args.skip_generation:
            generated_output = self._generate(samples)
        elif reranked_cands:
            # we're skipping generation, but we're also grading candidates
            # so output the highest ranked candidate
            # In the case of zero candidates, we don't have something to rank,
            # so we may need to pass on that None
            generated_output = [
                ranked and ranked[0] or None for ranked in reranked_cands
            ]
        else:
            # no output at all
            pass

        return Output(generated_output, reranked_cands)

    def _generate(self, samples):
        src_tokens = samples["net_input"]["src_tokens"]
        src_lengths = samples["net_input"]["src_lengths"]
        gens = self.generator.generate(src_tokens, src_lengths, maxlen=64)
        responses = []
        for i in range(len(src_tokens)):
            beams = gens[i]
            selected = max(beams, key=lambda x: x["score"])
            tokens = selected["tokens"]
            start = 0
            end = -1
            for i, t in enumerate(tokens):
                t = t.item()
                if t == self.dict.bos_index:
                    # don't include <s> token
                    start = i + 1
                    continue
                if t == self.dict.eos_index:
                    # stop (and don't include) </s> token
                    end = i
                    break
            responses.append(self.dict.vec2txt(tokens[start:end]))
        return responses

    def report(self):
        """Return metrics calculated by the model."""
        # if we haven't initialized yet, just return a dummy object
        if not hasattr(self, "trainer"):
            return {}

        # These are the metrics we'll pass up the way, and their new names
        train_metrics = {"train_loss", "ups", "wps", "gnorm", "clip"}
        valid_metrics = {"valid_loss"}

        metrics = train_metrics if self.is_training else valid_metrics

        m = {k: self.trainer.meters[k].avg for k in metrics}

        # additionally output perplexity. note that fairseq models use base 2
        # in cross_entropy:
        # github.com/pytorch/fairseq/blob/master/fairseq/criterions/cross_entropy.py#L55
        if "train_loss" in m:
            m["train_ppl"] = np.exp2(m["train_loss"])
        if "valid_loss" in m:
            m["ppl"] = np.exp2(m["valid_loss"])

        for k, v in m.items():
            # clean up: rounds to sigfigs and converts tensors to floats
            m[k] = round_sigfigs(v, 4)

        return m

    def reset_metrics(self):
        """Reset metrics calculated by the model back to zero."""
        if not hasattr(self, "trainer"):
            # We haven't set up the trainer yet, so we don't have any metrics
            return
        # We need to reset everything
        for k in self.trainer.meters:
            self.trainer.meters[k].reset()

    def receive_metrics(self, metrics_dict):
        """Update lr scheduler with validation loss."""
        self.trainer.lr_step(-1, metrics_dict["valid_loss"])

    # Helper functions
    def _seq_length(self, xs):
        """Compute length of the sequence (non-padded size)."""
        return xs.ne(self.dict.pad_index).long().sum(dim=-1)

    def _right_shifted_ys(self, ys):
        """Replace first token with EOS and shift remaining tokens right 1."""
        result = torch.LongTensor(ys.size())
        result[:, 0] = self.dict.eos_index
        result[:, 1:] = ys[:, :-1]
        return result

    def _make_sample(self, xs, ys):
        """Generate a sample object that Fairseq expects."""
        # add extra info to samples
        # TODO: should the right/left padding thing be in torch agent?
        sample = {}
        sample["id"] = torch.arange(len(xs) - 1)
        sample["net_input"] = {
            "src_tokens": xs,
            "src_lengths": self._seq_length(xs),
        }
        if ys is not None:
            sample["target"] = ys
            sample["ntokens"] = sum(self._seq_length(ys)).item()
            sample["net_input"]["prev_output_tokens"] = self._right_shifted_ys(
                ys)
        return sample
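A sketch of the dict this produces, since fairseq decoders expect prev_output_tokens to be the target shifted right with EOS in front (the pad/eos indices below are illustrative, not taken from the original dictionary):

import torch

# Hypothetical batch of 2 sequences, pad index 0, eos index 2.
xs = torch.LongTensor([[4, 5, 6, 0], [4, 5, 0, 0]])  # padded source batch
ys = torch.LongTensor([[7, 8, 2, 0], [7, 2, 0, 0]])  # padded target batch
prev = torch.LongTensor(ys.size())
prev[:, 0] = 2            # EOS leads the decoder input...
prev[:, 1:] = ys[:, :-1]  # ...and the targets shift right by one
# sample = {'id': tensor([0, 1]),
#           'net_input': {'src_tokens': xs, 'src_lengths': tensor([3, 2]),
#                         'prev_output_tokens': prev},
#           'target': ys, 'ntokens': 5}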
Example #13
def main(parsed_args, **unused_kwargs):
    assert parsed_args.path is not None, '--path required for evaluation!'

    if torch.cuda.is_available() and not parsed_args.cpu:
        torch.cuda.set_device(parsed_args.device_id)

    utils.import_user_module(parsed_args)

    logger.info(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    logger.info('loading model(s) from {}'.format(parsed_args.path))
    models, args = checkpoint_utils.load_model_ensemble(
        parsed_args.path.split(os.pathsep),
        arg_overrides=eval(parsed_args.model_overrides),
        task=task,
        suffix=getattr(parsed_args, "checkpoint_suffix", ""),
    )

    for arg in vars(parsed_args).keys():
        if arg not in {
            'self_target', 'future_target', 'past_target', 'tokens_per_sample',
            'output_size_dictionary', 'add_bos_token',
        }:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    logger.info('{} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.prepare_for_inference_(args)
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    assert len(models) > 0

    logger.info('num. model params: {}'.format(sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(*[
            model.max_positions() for model in models
        ]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
        data_buffer_size=args.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=('tqdm' if not args.no_progress_bar else 'none'),
    )

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch)

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            bpe_toks = {
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            }
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    print(os.path.dirname(args.jason_test_output))
    checkpoint_utils.verify_checkpoint_directory(os.path.dirname(args.jason_test_output))
    test_loss_writer = open(args.jason_test_output, 'w')
    # test_loss_uid_writer = open(args.jason_test_uid_output, 'w')

    wps_meter = TimeMeter()

    for sample in progress:
        if 'net_input' not in sample:
            continue

        sample = utils.move_to_cuda(sample) if use_cuda else sample

        gen_timer.start()
        hypos = scorer.generate(models, sample)
        gen_timer.stop(sample['ntokens'])

        for i, hypos_i in enumerate(hypos):
            hypo = hypos_i[0]
            sample_id = sample['id'][i]

            tokens = hypo['tokens']
            tgt_len = tokens.numel()
            pos_scores = hypo['positional_scores'].float()

            if getattr(args, 'add_bos_token', False):
                assert hypo['tokens'][0].item() == task.target_dictionary.bos()
                tokens = tokens[1:]
                pos_scores = pos_scores[1:]

            skipped_toks = 0
            if bpe_toks is not None:
                for i in range(tgt_len - 1):
                    if tokens[i].item() in bpe_toks:
                        skipped_toks += 1
                        pos_scores[i + 1] += pos_scores[i]
                        pos_scores[i] = 0

            inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
            if inf_scores.any():
                logger.info(
                    'skipping tokens with inf scores: %s',
                    task.target_dictionary.string(tokens[inf_scores.nonzero()])
                )
                pos_scores = pos_scores[(~inf_scores).nonzero()]
            
            score_sum += pos_scores.sum().cpu()
            count += pos_scores.numel() - skipped_toks
            # print(i, pos_scores.size(), pos_scores.cpu()[-3:], pos_scores.sum().cpu(), pos_scores.numel() - skipped_toks)
            # print(parsed_args.jason_test_output_dir)
            pos_scores_cpu = pos_scores.cpu()
            output_line = ""
            for j in range(pos_scores_cpu.size()[0]):
                nll_loss_base2 = - pos_scores_cpu[j].item() / math.log(2)
                test_loss_writer.write(f"{nll_loss_base2}\n")
                output_line += f"{nll_loss_base2:.5f},"
            output_line = output_line[:-1] + "\n"
            # test_loss_uid_writer.write(output_line)

            if args.output_word_probs or args.output_word_stats:
                w = ''
                word_prob = []
                is_bpe = False
                for i in range(len(tokens)):
                    w_ind = tokens[i].item()
                    w += task.source_dictionary[w_ind]
                    if bpe_toks is not None and w_ind in bpe_toks:
                        w = w[:-bpe_len]
                        is_bpe = True
                    else:
                        word_prob.append((w, pos_scores[i].item()))

                        next_prob = None
                        ind = i + 1
                        while ind < len(tokens):
                            if pos_scores[ind].item() != 0:
                                next_prob = pos_scores[ind]
                                break
                            ind += 1

                        word_stats.setdefault(w, WordStat(w, is_bpe)).add(pos_scores[i].item(), next_prob)
                        is_bpe = False
                        w = ''
                if args.output_word_probs:
                    logger.info(
                        str(int(sample_id)) + " "
                        + ('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob))
                    )

        wps_meter.update(sample['ntokens'])
        progress.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count / math.log(2)  # convert to base 2
    logger.info('Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
        gen_timer.n, gen_timer.sum, 1. / gen_timer.avg
    ))
    logger.info('Loss (base 2): {:.4f}, Perplexity: {:.2f}'.format(
        avg_nll_loss, 2**avg_nll_loss
    ))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            logger.info(ws)
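This variant also streams each token's base-2 loss to args.jason_test_output, one value per line. A hedged post-hoc check (hypothetical filename; it assumes no BPE continuation tokens were merged, since merged positions are written as zeros):

# Hypothetical sanity check: re-average the per-token base-2 losses on disk.
with open('test_losses.txt') as f:  # stand-in for args.jason_test_output
    losses = [float(line) for line in f]
avg_bits = sum(losses) / len(losses)
print(f'Loss (base 2): {avg_bits:.4f}, Perplexity: {2 ** avg_bits:.2f}')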
Example #14
    def test_sequence_scorer(self):
        # construct dummy dictionary
        d = test_utils.dummy_dictionary(vocab_size=2)
        self.assertEqual(d.pad(), 1)
        self.assertEqual(d.eos(), 2)
        self.assertEqual(d.unk(), 3)
        eos = d.eos()
        w1 = 4
        w2 = 5

        # construct dataloader
        data = [
            {
                'source': torch.LongTensor([w1, w2, eos]),
                'target': torch.LongTensor([w1, w2, w1, eos]),
            },
            {
                'source': torch.LongTensor([w2, eos]),
                'target': torch.LongTensor([w2, w1, eos]),
            },
            {
                'source': torch.LongTensor([w2, eos]),
                'target': torch.LongTensor([w2, eos]),
            },
        ]
        data_itr = test_utils.dummy_dataloader(data)

        # specify expected output probabilities
        args = argparse.Namespace()
        unk = 0.
        args.beam_probs = [
            # step 0:
            torch.FloatTensor([
                # eos      w1   w2
                [0.0, unk, 0.6, 0.4],  # sentence 1
                [0.0, unk, 0.4, 0.6],  # sentence 2
                [0.0, unk, 0.7, 0.3],  # sentence 3
            ]),
            # step 1:
            torch.FloatTensor([
                # eos      w1   w2
                [0.0, unk, 0.2, 0.7],  # sentence 1
                [0.0, unk, 0.8, 0.2],  # sentence 2
                [0.7, unk, 0.1, 0.2],  # sentence 3
            ]),
            # step 2:
            torch.FloatTensor([
                # eos       w1    w2
                [0.10, unk, 0.50, 0.4],  # sentence 1
                [0.15, unk, 0.15, 0.7],  # sentence 2
                [0.00, unk, 0.00, 0.0],  # sentence 3
            ]),
            # step 3:
            torch.FloatTensor([
                # eos      w1    w2
                [0.9, unk, 0.05, 0.05],  # sentence 1
                [0.0, unk, 0.00, 0.0],  # sentence 2
                [0.0, unk, 0.00, 0.0],  # sentence 3
            ]),
        ]
        expected_scores = [
            [0.6, 0.7, 0.5, 0.9],  # sentence 1
            [0.6, 0.8, 0.15],  # sentence 2
            [0.3, 0.7],  # sentence 3
        ]

        model = test_utils.TestModel.build_model(args, d, d)
        scorer = SequenceScorer([model])
        for id, _src, _ref, hypos in scorer.score_batched_itr(data_itr):
            self.assertHypoTokens(hypos[0], data[id]['target'])
            self.assertHypoScore(hypos[0], expected_scores[id])
Example #15
def eval_lm(
    models: List[fairseq.models.FairseqModel],
    source_dictionary: fairseq.data.Dictionary,
    batch_iterator: Iterable,
    post_process: Optional[str] = None,
    output_word_probs: bool = False,
    output_word_stats: bool = False,
    target_dictionary: Optional[fairseq.data.Dictionary] = None,
    softmax_batch: int = 0,
    remove_bos_token: bool = False,
    device: Optional[torch.device] = None,
):
    """
    Args:
        models (List[~fairseq.models.FairseqModel]): list of models to
            evaluate. Models are essentially `nn.Module` instances, but
            must be compatible with fairseq's `SequenceScorer`.
        source_dictionary (~fairseq.data.Dictionary): dictionary for
            applying any relevant post processing or outputting word
            probs/stats.
        batch_iterator (Iterable): yield batches of data
        post_process (Optional[str]): post-process text by removing BPE,
            letter segmentation, etc. Valid options can be found in
            fairseq.data.utils.post_process, although not all options
            are implemented here.
        output_word_probs (bool): output words and their
            predicted log probabilities
        output_word_stats (bool): output word statistics such
            as word count and average probability
        target_dictionary (Optional[~fairseq.data.Dictionary]): output
            dictionary (defaults to *source_dictionary*)
        softmax_batch (int): if BxT is more than this, will
            batch the softmax over vocab to this amount of tokens, in
            order to fit into GPU memory
        remove_bos_token (bool): if True, confirm that the
            first token is the beginning-of-sentence symbol (according
            to the relevant dictionary) and remove it from the output
        device (Optional[torch.device]): device to use for evaluation
            (defaults to device of first model parameter)
    """
    if target_dictionary is None:
        target_dictionary = source_dictionary
    if device is None:
        device = next(models[0].parameters()).device

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(target_dictionary, softmax_batch)

    score_sum = 0.0
    count = 0

    if post_process is not None:
        if post_process in {"subword_nmt", "@@ "}:
            bpe_cont = post_process.rstrip()
            bpe_toks = {
                i
                for i in range(len(source_dictionary))
                if source_dictionary[i].endswith(bpe_cont)
            }
        else:
            raise NotImplementedError(
                f"--post-process={post_process} is not implemented")
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    for sample in batch_iterator:
        if "net_input" not in sample:
            continue

        sample = utils.move_to_cuda(sample, device=device)

        gen_timer.start()
        hypos = scorer.generate(models, sample)
        gen_timer.stop(sample["ntokens"])

        for i, hypos_i in enumerate(hypos):
            hypo = hypos_i[0]
            sample_id = sample["id"][i]

            tokens = hypo["tokens"]
            tgt_len = tokens.numel()
            pos_scores = hypo["positional_scores"].float()

            if remove_bos_token:
                assert hypo["tokens"][0].item() == target_dictionary.bos()
                tokens = tokens[1:]
                pos_scores = pos_scores[1:]

            skipped_toks = 0
            if bpe_toks is not None:
                for i in range(tgt_len - 1):
                    if tokens[i].item() in bpe_toks:
                        skipped_toks += 1
                        pos_scores[i + 1] += pos_scores[i]
                        pos_scores[i] = 0

            inf_scores = pos_scores.eq(float("inf")) | pos_scores.eq(
                float("-inf"))
            if inf_scores.any():
                logger.info(
                    "skipping tokens with inf scores: %s",
                    target_dictionary.string(tokens[inf_scores.nonzero()]),
                )
                pos_scores = pos_scores[(~inf_scores).nonzero()]
            score_sum += pos_scores.sum().cpu()
            count += pos_scores.numel() - skipped_toks

            if output_word_probs or output_word_stats:
                w = ""
                word_prob = []
                is_bpe = False
                for i in range(len(tokens)):
                    w_ind = tokens[i].item()
                    w += source_dictionary[w_ind]
                    if bpe_toks is not None and w_ind in bpe_toks:
                        w = w[:-bpe_len]
                        is_bpe = True
                    else:
                        word_prob.append((w, pos_scores[i].item()))

                        next_prob = None
                        ind = i + 1
                        while ind < len(tokens):
                            if pos_scores[ind].item() != 0:
                                next_prob = pos_scores[ind]
                                break
                            ind += 1

                        word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                            pos_scores[i].item(), next_prob)
                        is_bpe = False
                        w = ""
                if output_word_probs:
                    logger.info(
                        str(int(sample_id)) + " " +
                        ("\t".join("{} [{:2f}]".format(x[0], x[1])
                                   for x in word_prob)))

    avg_nll_loss = -score_sum / count / math.log(2) if count > 0 else 0  # convert to base 2
    logger.info("Evaluated {:,} tokens in {:.1f}s ({:.2f} tokens/s)".format(
        gen_timer.n, gen_timer.sum,
        1.0 / gen_timer.avg if gen_timer.avg > 0 else 0))

    if output_word_stats:
        for ws in sorted(word_stats.values(),
                         key=lambda x: x.count,
                         reverse=True):
            logger.info(ws)

    return {
        "loss": avg_nll_loss,
        "perplexity": 2**avg_nll_loss,
    }
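A hedged sketch of calling this helper; `task` and `models` must be prepared elsewhere (dataset loaded, models moved to the right device), as in the surrounding examples:

# Hypothetical call site for eval_lm, assuming `task` and `models` exist.
itr = task.get_batch_iterator(
    dataset=task.dataset('test'),
    max_tokens=36000,
).next_epoch_itr(shuffle=False)

results = eval_lm(
    models=models,
    source_dictionary=task.source_dictionary,
    batch_iterator=itr,
    post_process='subword_nmt',
)
print(results['loss'], results['perplexity'])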
Example #16
def main(parsed_args):
    assert parsed_args.path is not None, '--path required for evaluation!'

    utils.import_user_module(parsed_args)

    print(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    models, args = checkpoint_utils.load_model_ensemble(
        parsed_args.path.split(':'),
        arg_overrides=eval(parsed_args.model_overrides),
        task=task,
    )

    for arg in vars(parsed_args).keys():
        if arg not in {
                'self_target',
                'future_target',
                'past_target',
                'tokens_per_sample',
                'output_size_dictionary',
                'add_bos_token',
        }:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    print('| {} {} {} examples'.format(args.data, args.gen_subset,
                                       len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    assert len(models) > 0

    print('num. model params: {}'.format(
        sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch)

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            bpe_toks = set(i for i in range(len(task.source_dictionary))
                           if task.source_dictionary[i].endswith(bpe_cont))
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()

        for sample in t:
            if 'net_input' not in sample:
                continue

            sample = utils.move_to_cuda(sample) if use_cuda else sample

            gen_timer.start()
            hypos = scorer.generate(models, sample)
            gen_timer.stop(sample['ntokens'])

            for i, hypos_i in enumerate(hypos):
                hypo = hypos_i[0]
                sample_id = sample['id'][i]

                tokens = hypo['tokens']
                tgt_len = tokens.numel()
                pos_scores = hypo['positional_scores'].float()

                if args.add_bos_token:
                    assert hypo['tokens'][0].item(
                    ) == task.target_dictionary.bos()
                    tokens = tokens[1:]
                    pos_scores = pos_scores[1:]

                skipped_toks = 0
                if bpe_toks is not None:
                    for i in range(tgt_len - 1):
                        if tokens[i].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[i + 1] += pos_scores[i]
                            pos_scores[i] = 0

                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(
                    float('-inf'))
                if inf_scores.any():
                    print(
                        '| Skipping tokens with inf scores:',
                        task.target_dictionary.string(
                            tokens[inf_scores.nonzero()]))
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += pos_scores.sum().cpu()
                count += pos_scores.numel() - skipped_toks

                if args.output_word_probs or args.output_word_stats:
                    w = ''
                    word_prob = []
                    is_bpe = False
                    for i in range(len(tokens)):
                        w_ind = tokens[i].item()
                        w += task.source_dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[i].item()))

                            next_prob = None
                            ind = i + 1
                            while ind < len(tokens):
                                if pos_scores[ind].item() != 0:
                                    next_prob = pos_scores[ind]
                                    break
                                ind += 1

                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                                pos_scores[i].item(), next_prob)
                            is_bpe = False
                            w = ''
                    if args.output_word_probs:
                        print(
                            str(int(sample_id)) + " " +
                            ('\t'.join('{} [{:.2f}]'.format(x[0], x[1])
                                       for x in word_prob)))

            wps_meter.update(sample['ntokens'])
            t.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count if count > 0 else 0
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
        gen_timer.n, gen_timer.sum,
        1. / gen_timer.avg if gen_timer.avg > 0 else 0))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss,
                                                      np.exp(avg_nll_loss)))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(),
                         key=lambda x: x.count,
                         reverse=True):
            print(ws)
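The BPE handling in the example above folds each continuation token's log-probability into the token that follows it before counting words. A toy re-implementation of that loop, assuming plain Python lists in place of tensors:

def fold_bpe_scores(tokens, pos_scores, bpe_toks):
    """Merge the score of every BPE continuation token into its successor."""
    pos_scores = list(pos_scores)
    skipped = 0
    for i in range(len(tokens) - 1):
        if tokens[i] in bpe_toks:
            skipped += 1
            pos_scores[i + 1] += pos_scores[i]
            pos_scores[i] = 0.0
    return pos_scores, skipped

# token id 7 is a continuation piece: its score moves onto the next position
scores, skipped = fold_bpe_scores([7, 3, 5], [-1.0, -2.0, -0.5], {7})
assert scores == [0.0, -3.0, -0.5] and skipped == 1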
Example #17
def main(args):
    assert args.path is not None, '--path required for generation!'
    print(args)
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset
    if args.replace_unk is None:
        dataset = data.load_dataset(
            args.data,
            [args.gen_subset],
            args.source_lang,
            args.target_lang,
        )
    else:
        dataset = data.load_raw_text_dataset(
            args.data,
            [args.gen_subset],
            args.source_lang,
            args.target_lang,
        )
    if args.source_lang is None or args.target_lang is None:
        # record inferred languages in args
        args.source_lang, args.target_lang = dataset.src, dataset.dst

    # Load ensemble
    print('| loading model(s) from {}'.format(', '.join(args.path)))
    models, _ = utils.load_ensemble_for_inference(args.path, dataset.src_dict,
                                                  dataset.dst_dict)

    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))
    print('| {} {} {} examples'.format(args.data, args.gen_subset,
                                       len(dataset.splits[args.gen_subset])))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    max_positions = min(model.max_encoder_positions() for model in models)
    itr = dataset.eval_dataloader(
        args.gen_subset,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        skip_invalid_size_inputs_valid_test=True,
    )
    if args.num_shards > 1:
        if args.shard_id < 0 or args.shard_id >= args.num_shards:
            raise ValueError('--shard-id must be between 0 and num_shards')
        itr = data.sharded_iterator(itr, args.num_shards, args.shard_id)

    # Initialize generator
    gen_timer = StopwatchMeter()
    if args.score_reference:
        translator = SequenceScorer(models)
    else:
        translator = SequenceGenerator(
            models,
            beam_size=args.beam,
            stop_early=(not args.no_early_stop),
            normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen,
            unk_penalty=args.unkpen,
            sampling=args.sampling)
    if use_cuda:
        translator.cuda()

    # Generate and compute BLEU score
    scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(),
                         dataset.dst_dict.unk())
    # `detailed_file`, `get_bleu`, `remove_pad` and `make_bold` used below are
    # assumed to be defined elsewhere in the original script
    check = []  # collects per-hypothesis BLEU scores
    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        if args.score_reference:
            translations = translator.score_batched_itr(t,
                                                        cuda=use_cuda,
                                                        timer=gen_timer)
        else:
            translations = translator.generate_batched_itr(
                t,
                maxlen_a=args.max_len_a,
                maxlen_b=args.max_len_b,
                cuda=use_cuda,
                timer=gen_timer,
                prefix_size=args.prefix_size)
        wps_meter = TimeMeter()
        for sample_id, src_tokens, guess_tokens, target_tokens, hypos, marker in translations:
            # Process input and ground truth
            has_target = target_tokens is not None
            target_tokens = target_tokens.int().cpu() if has_target else None
            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = dataset.splits[
                    args.gen_subset].src.get_original_text(sample_id)
                guess_str = dataset.splits[
                    args.gen_subset].guess.get_original_text(sample_id)
                target_str = dataset.splits[
                    args.gen_subset].dst.get_original_text(sample_id)
            else:
                src_str = dataset.src_dict.string(src_tokens, args.remove_bpe)
                guess_str = dataset.dst_dict.string(guess_tokens,
                                                    args.remove_bpe,
                                                    escape_unk=True)
                target_str = dataset.dst_dict.string(
                    target_tokens, args.remove_bpe,
                    escape_unk=True) if has_target else ''
            if not args.quiet:
                #print('S-{}\t{}'.format(sample_id, src_str))
                if has_target:
                    y = str(sample_id.cpu().numpy()) + ' T= ' + str(
                        target_str) + '\n'
                    detailed_file.write(y)

                    print('G-{}\t{}'.format(sample_id, guess_str))
                    print('T-{}\t{}'.format(sample_id, target_str))
                else:
                    y = str(sample_id.cpu().numpy()) + 'checkcheck\n'
                    detailed_file.write(y)
            # Process top predictions
            for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu(),
                    align_dict=align_dict,
                    dst_dict=dataset.dst_dict,
                    remove_bpe=args.remove_bpe,
                )

                if not args.quiet:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'],
                                                hypo_str))

                    guess_score = get_bleu(target_str, remove_pad(guess_str))
                    hypo_score = get_bleu(target_str, hypo_str)
                    check.append(hypo_score)
                    guess_str = make_bold(guess_str, marker)

                    y = str(sample_id.cpu().numpy()) + ' ' + str(
                        guess_score) + ' G= ' + str(guess_str) + '\n'
                    detailed_file.write(y)
                    y = str(sample_id.cpu().numpy()) + ' ' + str(
                        hypo_score) + ' H= ' + str(hypo_str) + '\n'
                    detailed_file.write(y)

                    # print('P-{}\t{}'.format(sample_id, ' '.join(map(
                    #     lambda x: '{:.4f}'.format(x),
                    #     hypo['positional_scores'].tolist(),
                    # ))))
                    # print('A-{}\t{}'.format(
                    #     sample_id,
                    #     ' '.join(map(lambda x: str(utils.item(x)), alignment))
                    # ))

                # Score only the top hypothesis
                if has_target and i == 0:
                    if align_dict is not None or args.remove_bpe is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tokenizer.Tokenizer.tokenize(
                            target_str,
                            dataset.dst_dict,
                            add_if_not_exist=True)
                    scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += 1

    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} tokens/s)'.
          format(num_sentences, gen_timer.n, gen_timer.sum,
                 1. / gen_timer.avg))
    if has_target:
        # `check` is only populated when not running with --quiet
        if check:
            print('| Check BLEU =', sum(check) / len(check))
        print('| Generate {} with beam={}: {}'.format(args.gen_subset,
                                                      args.beam,
                                                      scorer.result_string()))
Example #18
    def build_generator(self,
                        models,
                        args,
                        seq_gen_cls=None,
                        extra_gen_cls_kwargs=None):
        if getattr(args, "score_reference", False):
            from fairseq.sequence_scorer import SequenceScorer

            return SequenceScorer(
                self.target_dictionary,
                compute_alignment=getattr(args, "print_alignment", False),
            )

        from fairseq.sequence_generator import (
            SequenceGenerator,
            SequenceGeneratorWithAlignment,
        )

        # Choose search strategy. Defaults to Beam Search.
        sampling = getattr(args, "sampling", False)
        sampling_topk = getattr(args, "sampling_topk", -1)
        sampling_topp = getattr(args, "sampling_topp", -1.0)
        diverse_beam_groups = getattr(args, "diverse_beam_groups", -1)
        diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5)
        match_source_len = getattr(args, "match_source_len", False)
        diversity_rate = getattr(args, "diversity_rate", -1)
        constrained = getattr(args, "constraints", False)
        prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn",
                                           None)
        if (sum(
                int(cond) for cond in [
                    sampling,
                    diverse_beam_groups > 0,
                    match_source_len,
                    diversity_rate > 0,
                ]) > 1):
            raise ValueError(
                "Provided Search parameters are mutually exclusive.")
        assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
        assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"

        if sampling:
            search_strategy = search.Sampling(self.target_dictionary,
                                              sampling_topk, sampling_topp)
        elif diverse_beam_groups > 0:
            search_strategy = search.DiverseBeamSearch(self.target_dictionary,
                                                       diverse_beam_groups,
                                                       diverse_beam_strength)
        elif match_source_len:
            # this is useful for tagging applications where the output
            # length should match the input length, so we hardcode the
            # length constraints for simplicity
            search_strategy = search.LengthConstrainedBeamSearch(
                self.target_dictionary,
                min_len_a=1,
                min_len_b=0,
                max_len_a=1,
                max_len_b=0,
            )
        elif diversity_rate > -1:
            search_strategy = search.DiverseSiblingsSearch(
                self.target_dictionary, diversity_rate)
        elif constrained:
            search_strategy = search.LexicallyConstrainedBeamSearch(
                self.target_dictionary, args.constraints)
        elif prefix_allowed_tokens_fn:
            search_strategy = search.PrefixConstrainedBeamSearch(
                self.target_dictionary, prefix_allowed_tokens_fn)
        else:
            search_strategy = search.BeamSearch(self.target_dictionary)

        extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
        if seq_gen_cls is None:
            if getattr(args, "print_alignment", False):
                seq_gen_cls = SequenceGeneratorWithAlignment
                extra_gen_cls_kwargs['print_alignment'] = args.print_alignment
            else:
                seq_gen_cls = SequenceGenerator

        return seq_gen_cls(
            models,
            self.target_dictionary,
            beam_size=getattr(args, "beam", 5),
            max_len_a=getattr(args, "max_len_a", 0),
            max_len_b=getattr(args, "max_len_b", 200),
            min_len=getattr(args, "min_len", 1),
            normalize_scores=(not getattr(args, "unnormalized", False)),
            len_penalty=getattr(args, "lenpen", 1),
            unk_penalty=getattr(args, "unkpen", 0),
            temperature=getattr(args, "temperature", 1.0),
            match_source_len=getattr(args, "match_source_len", False),
            no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
            search_strategy=search_strategy,
            **extra_gen_cls_kwargs,
        )
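The strategy selection above is guarded by a mutual-exclusion check before any search object is built. A self-contained sketch of that guard, with hypothetical flag values standing in for the parsed args:

flags = dict(sampling=True, diverse_beam_groups=-1,
             match_source_len=False, diversity_rate=-1)
enabled = [
    flags['sampling'],
    flags['diverse_beam_groups'] > 0,
    flags['match_source_len'],
    flags['diversity_rate'] > 0,
]
# at most one of the special search modes may be active at a time
if sum(int(cond) for cond in enabled) > 1:
    raise ValueError('Provided Search parameters are mutually exclusive.')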
Example #19
def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset, aligned=False)
    print('| {} {} {} examples'.format(args.data, args.gen_subset,
                                       len(task.dataset(args.gen_subset))))

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _ = utils.load_ensemble_for_inference(args.path.split(':'),
                                                  task,
                                                  model_arg_overrides=eval(
                                                      args.model_overrides))
    first_model = models[0]

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Initialize generator
    gen_timer = StopwatchMeter()
    if args.score_reference:
        translator = SequenceScorer(models, task.target_dictionary)
    else:
        translator = SequenceGenerator(
            models,
            task.target_dictionary,
            beam_size=args.beam,
            stop_early=(not args.no_early_stop),
            normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen,
            unk_penalty=args.unkpen,
            sampling=args.sampling,
            sampling_topk=args.sampling_topk,
            minlen=args.min_len,
        )

    if use_cuda:
        translator.cuda()

    for data_idx in [0, 1]:

        # Load dataset (possibly sharded)
        itr = data.EpochBatchIterator(
            dataset=task.dataset(args.gen_subset)[data_idx],
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=models[0].max_positions(),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=8,
            num_shards=args.num_shards,
            shard_id=args.shard_id,
        ).next_epoch_itr(shuffle=False)

        # Generate and compute BLEU score
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
        num_sentences = 0
        has_target = True
        res = []
        out_obj = []
        with progress_bar.build_progress_bar(args, itr) as t:
            if args.score_reference:
                translations = translator.score_batched_itr(t,
                                                            cuda=use_cuda,
                                                            timer=gen_timer)
            else:
                translations = translator.generate_batched_itr(
                    t,
                    maxlen_a=args.max_len_a,
                    maxlen_b=args.max_len_b,
                    cuda=use_cuda,
                    timer=gen_timer,
                    prefix_size=args.prefix_size,
                    to_trg=(data_idx == 0),
                )

            wps_meter = TimeMeter()
            for sample_id, src_tokens, target_tokens, hypos in translations:

                # sample out dict
                sample_out_dict = {}

                # Process input and ground truth
                has_target = target_tokens is not None
                target_tokens = target_tokens.int().cpu(
                ) if has_target else None

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(
                        args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(
                        args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    src_str = src_dict.string(src_tokens, args.remove_bpe)
                    if has_target:
                        target_str = tgt_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)

                sample_out_dict['source'] = src_str
                if has_target:
                    sample_out_dict['target'] = target_str

                if not args.quiet:
                    print('S-{}\t{}'.format(sample_id, src_str))
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str))

                # Process top predictions
                preds = []

                sample_out_dict['translations'] = []
                sample_out_dict['gen_scores'] = []
                sample_out_dict['class_scores'] = []
                sample_out_dict['oracle_scores'] = []

                for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'].int().cpu()
                        if hypo['alignment'] is not None else None,
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )
                    sample_out_dict['translations'].append(hypo_str)
                    sample_out_dict['gen_scores'].append(hypo['score'])

                    # res.append((sample_id.item(), hypo_str, hypo['score']))
                    preds.append([hypo['score'], hypo_str, sample_id.item()])

                    # oracle_score
                    # oracle_score = sentence_bleu([target_str.split()], hypo_str.split())
                    # sample_out_dict['oracle_scores'].append(oracle_score)
                    # if args.oracle_score:
                    #     if has_target: # score the prediction
                    #         # replace the hypo score with the testing one
                    #         preds[-1][0] = oracle_score
                    #     else:
                    #         print("# WARNING: Not target to compute oracle")

                    # disc_score
                    padded_hypo_tokens = collate_tokens(
                        [hypo['tokens']],
                        pad_idx=first_model.src_dict.pad(),
                        eos_idx=first_model.src_dict.eos(),
                        left_pad=False,
                        min_size=5,
                    )
                    # print("padded_hypo_tokens.size", padded_hypo_tokens.size())
                    # print(models[0].discriminator.pred(padded_hypo_tokens)[0].size())
                    # assumes the model exposes a discriminator head; see the
                    # hasattr() guard below
                    disc_score = first_model.discriminator.pred(
                        padded_hypo_tokens)[0][0][1 - data_idx].item()
                    sample_out_dict['class_scores'].append(disc_score)
                    if args.disc_score:
                        if hasattr(first_model, 'discriminator'):

                            preds[-1][0] = -float(
                                "inf") if disc_score < 0.5 else preds[-1][0]
                            # print("{}:{}".format(hypo_str, preds[-1][0]))
                        else:
                            print("# WARNING: No discriminator to score")

                    if not args.quiet:
                        print('H-{}\t{}\t{}'.format(sample_id, hypo['score'],
                                                    hypo_str))
                        print('P-{}\t{}'.format(
                            sample_id, ' '.join(
                                map(
                                    lambda x: '{:.4f}'.format(x),
                                    hypo['positional_scores'].tolist(),
                                ))))

                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id, ' '.join(
                                    map(lambda x: str(utils.item(x)),
                                        alignment))))

                    # Score only the top hypothesis
                    if has_target and i == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tokenizer.Tokenizer.tokenize(
                                target_str, tgt_dict, add_if_not_exist=True)
                        scorer.add(target_tokens, hypo_tokens)

                preds = sorted(preds, reverse=True)
                res.append((preds[0][2], preds[0][1], preds[0][0]))

                wps_meter.update(src_tokens.size(0))
                t.log({'wps': round(wps_meter.avg)})
                num_sentences += 1

                out_obj.append(sample_out_dict)

        if args.output_path is not None:
            if data_idx == 0:
                output_suffix = '.' + args.source_lang + '-' + args.target_lang
            else:
                output_suffix = '.' + args.target_lang + '-' + args.source_lang
            res = sorted(res)
            with open(args.output_path + output_suffix, 'w') as out:
                for r in res:
                    if args.score_reference:
                        out.write("{} ||| {:.4f}\n".format(r[1], r[2]))
                    else:
                        out.write(r[1] + '\n')

            with open(args.output_path + output_suffix + '.json',
                      'w') as f_out:
                f_out.write(
                    json.dumps(out_obj,
                               ensure_ascii=False,
                               sort_keys=False,
                               indent=4))

    print(
        '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset,
                                                      args.beam,
                                                      scorer.result_string()))
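The re-ranking in this example keeps `preds` as [score, hypothesis, sample_id] triples and demotes candidates the discriminator rejects before taking the best one. A minimal sketch of that logic with hypothetical scores:

preds = [
    [-0.9, 'first candidate', 0],
    [-0.4, 'second candidate', 0],
]
disc_scores = [0.8, 0.3]  # hypothetical discriminator outputs per candidate

for pred, disc in zip(preds, disc_scores):
    if disc < 0.5:
        pred[0] = -float('inf')  # push rejected candidates to the bottom

best_score, best_hypo, best_id = sorted(preds, reverse=True)[0]
assert best_hypo == 'first candidate'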
Example #20
def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset,
                                       len(task.dataset(args.gen_subset))))

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary
    src_dict_sen_piece = task.source_sen_piece_dictionary
    tgt_dict_sen_piece = task.target_sen_piece_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    prefix_path = os.path.split(args.path.split(':')[0])[0]
    models, _ = utils.load_ensemble_for_inference(args.path.split(':'),
                                                  task,
                                                  model_arg_overrides=eval(
                                                      args.model_overrides))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,  # default need_attn=False
        )
        if args.fp16:
            model.half()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    if args.score_reference:
        translator = SequenceScorer(models, task.target_dictionary)
    else:
        translator = SequenceGenerator(
            models,
            task.target_dictionary,
            task.target_sen_piece_dictionary,
            beam_size=args.beam,
            minlen=args.min_len,
            stop_early=(not args.no_early_stop),
            normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen,
            unk_penalty=args.unkpen,
            sampling=args.sampling,
            sampling_topk=args.sampling_topk,
            sampling_temperature=args.sampling_temperature,
            diverse_beam_groups=args.diverse_beam_groups,
            diverse_beam_strength=args.diverse_beam_strength,
        )

    if use_cuda:
        translator.cuda()

    # Generate and compute BLEU score
    scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    sp = spm.SentencePieceProcessor()
    # prefix = '/home/v-lijuwu'
    sp.Load(args.senpiece_model)
    with progress_bar.build_progress_bar(args, itr) as t:
        if args.score_reference:
            translations = translator.score_batched_itr(t,
                                                        cuda=use_cuda,
                                                        timer=gen_timer)
        else:
            translations = translator.generate_batched_itr(
                t,
                maxlen_a=args.max_len_a,
                maxlen_b=args.max_len_b,
                cuda=use_cuda,
                timer=gen_timer,
                prefix_size=args.prefix_size,
            )

        ftgt = open(prefix_path + '/ref_tgt.txt', 'w', encoding='utf-8')
        fbpe_src = open(prefix_path + '/bpe_src.tok', 'w', encoding='utf-8')
        fbpe_hyp = open(prefix_path + '/bpe_trans.tok', 'w', encoding='utf-8')
        fsp_src = open(prefix_path + '/sp_src.detok', 'w', encoding='utf-8')
        fsp_hyp = open(prefix_path + '/trans.txt', 'w', encoding='utf-8')
        fhyp_tok = open(prefix_path + '/hyp_trans.txt', 'w', encoding='utf-8')
        fhyp_tok_ids = open(prefix_path + '/hyp_ids.txt',
                            'w',
                            encoding='utf-8')
        wps_meter = TimeMeter()
        sent_id = 0  # renamed from `id` to avoid shadowing the builtin
        for sample_id, src_tokens, target_tokens, src_sen_piece_tokens, target_sen_piece_tokens, hypos, hypos_sen_piece in translations:
            # Process input and ground truth
            has_target = target_tokens is not None
            target_tokens = target_tokens.int().cpu() if has_target else None
            target_sen_piece_tokens = target_sen_piece_tokens.int().cpu(
            ) if has_target else None

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(
                    args.gen_subset).src.get_original_text(sample_id)
                target_str = task.dataset(
                    args.gen_subset).tgt.get_original_text(sample_id)
                src_str_sen_piece = task.dataset(
                    args.gen_subset).src_sen_piece.get_original_text(sample_id)
                tgt_str_sen_piece = task.dataset(
                    args.gen_subset).tgt_sen_piece.get_original_text(sample_id)
            else:
                src_str = src_dict.string(src_tokens, args.remove_bpe)
                fbpe_src.write(src_str + '\n')  # write  bpe_token data
                if has_target:
                    target_str = tgt_dict.string(target_tokens,
                                                 args.remove_bpe,
                                                 escape_unk=True)

                src_str_sen_piece = src_dict_sen_piece.string(
                    src_sen_piece_tokens)  # return list, not string
                src_str_sen_piece_list = src_dict_sen_piece.to_list(
                    src_sen_piece_tokens)
                src_str_out = sp.DecodePieces(src_str_sen_piece_list)
                fsp_src.write(src_str_out + '\n')  # write sp_detok data
                if has_target:
                    tgt_str_sen_piece_list = tgt_dict_sen_piece.to_list(
                        target_sen_piece_tokens, escape_unk=True)
                    tgt_str_sen_piece = tgt_dict_sen_piece.string(
                        target_sen_piece_tokens, escape_unk=True)
                    tgt_str_out = sp.DecodePieces(tgt_str_sen_piece_list)
                    ftgt.write(tgt_str_out + '\n')

            if not args.quiet:
                print('S-{}\t{}'.format(sample_id, src_str))
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str))
                print('SS-{}\t{}'.format(sample_id, src_str_sen_piece))
                if has_target:
                    print('TS-{}\t{}'.format(sample_id, tgt_str_sen_piece))

            score1 = 0.
            hypo_str1 = ""
            # Process top predictions
            for i, hypo in enumerate(
                    hypos[:min(len(hypos), args.nbest)]):  # args.nbest=1
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu()
                    if hypo['alignment'] is not None else None,
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )

                if not args.quiet:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'],
                                                hypo_str))
                    print('P-{}\t{}'.format(
                        sample_id, ' '.join(
                            map(
                                lambda x: '{:.4f}'.format(x),
                                hypo['positional_scores'].tolist(),
                            ))))

                    if args.print_alignment:
                        print('A-{}\t{}'.format(
                            sample_id, ' '.join(
                                map(lambda x: str(utils.item(x)), alignment))))

                # Score only the top hypothesis
                if has_target and i == 0:
                    score1 = hypo['score']
                    hypo_str1 = hypo_str
                    if align_dict is not None or args.remove_bpe is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tokenizer.Tokenizer.tokenize(
                            target_str, tgt_dict, add_if_not_exist=True)
                    scorer.add(target_tokens, hypo_tokens)
                # write bpe_trans to file
                fbpe_hyp.write(hypo_str + '\n')

            score2 = 0.
            # process sen_piece and save translations to file
            for i, hypo in enumerate(
                    hypos_sen_piece[:min(len(hypos_sen_piece), args.nbest)]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str_sen_piece,
                    alignment=hypo['alignment'].int().cpu()
                    if hypo['alignment'] is not None else None,
                    align_dict=align_dict,
                    tgt_dict=tgt_dict_sen_piece,
                    remove_bpe=None,
                    to_list=True,
                )
                if not args.quiet:
                    print('HS-{}\t{}\t{}'.format(sample_id, hypo['score'],
                                                 hypo_str))
                hypo_str_out = sp.DecodePieces(hypo_str)
                fsp_hyp.write(hypo_str_out + '\n')  # detokenized data

                # Score only the top hypothesis
                if has_target and i == 0:
                    score2 = hypo['score']
            if score1 > score2:
                fhyp_tok.write(hypo_str1 + '\n')
                fhyp_tok_ids.write(str(sent_id) + '\n')
            sent_id += 1
            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += 1
    ftgt.close()
    fbpe_src.close()
    fbpe_hyp.close()
    fsp_src.close()
    fsp_hyp.close()
    fhyp_tok.close()
    fhyp_tok_ids.close()
    print(
        '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset,
                                                      args.beam,
                                                      scorer.result_string()))
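Detokenization in this example goes through SentencePiece's DecodePieces, which joins the pieces and strips the '▁' word markers. A minimal standalone sketch (the model path is a placeholder; the exact output depends on the loaded model):

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load('/path/to/model.spm')  # placeholder path
detok = sp.DecodePieces(['▁Hello', '▁world'])
print(detok)  # typically "Hello world"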
Example #21
def main(parsed_args):
    assert parsed_args.path is not None, '--path required for evaluation!'

    utils.import_user_module(parsed_args)

    logger.info(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    logger.info('loading model(s) from {}'.format(parsed_args.path))
    models, args = checkpoint_utils.load_model_ensemble(
        parsed_args.path.split(os.pathsep),
        arg_overrides=eval(parsed_args.model_overrides),
        task=task,
    )

    for arg in vars(parsed_args).keys():
        if arg not in {
                'self_target',
                'future_target',
                'past_target',
                'tokens_per_sample',
                'output_size_dictionary',
                'add_bos_token',
        }:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    logger.info('{} {} {} examples'.format(args.data, args.gen_subset,
                                           len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    assert len(models) > 0

    logger.info('num. model params: {}'.format(
        sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary,
                            args.softmax_batch,
                            args=args)

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            bpe_toks = {
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            }
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    if args.knnlm and args.save_knnlm_dstore:
        raise ValueError(
            "Cannot use knnlm while trying to build the datastore!")

    if args.knnlm:
        knn_dstore = KNN_Dstore(args)

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()

        if args.save_knnlm_dstore:
            print('keytype being saved:', args.knn_keytype)
            if args.dstore_fp16:
                print('Saving fp16')
                dstore_keys = np.memmap(args.dstore_mmap + '_keys.npy',
                                        dtype=np.float16,
                                        mode='w+',
                                        shape=(args.dstore_size,
                                               args.decoder_embed_dim))
                dstore_vals = np.memmap(args.dstore_mmap + '_vals.npy',
                                        dtype=np.int16,
                                        mode='w+',
                                        shape=(args.dstore_size, 1))
            else:
                print('Saving fp32')
                dstore_keys = np.memmap(args.dstore_mmap + '_keys.npy',
                                        dtype=np.float32,
                                        mode='w+',
                                        shape=(args.dstore_size,
                                               args.decoder_embed_dim))
                dstore_vals = np.memmap(args.dstore_mmap + '_vals.npy',
                                        # np.int was removed in NumPy 1.24
                                        dtype=np.int64,
                                        mode='w+',
                                        shape=(args.dstore_size, 1))

        dstore_idx = 0

        #knn_probs_file = open(args.output_log_probs_file_prefix + '.knn.txt', 'w')
        #orig_probs_file = open(args.output_log_probs_file_prefix + '.orig.txt', 'w')
        if args.knnlm:
            dists_file = open(args.output_log_probs_file_prefix + '.dists.txt',
                              'w')
            knns_file = open(
                args.output_log_probs_file_prefix + '.knn_indices.txt', 'w')
        if args.save_knnlm_dstore or args.knnlm:
            tokens_file = open(args.output_tokens_file, 'w')

        for ex_i, sample in enumerate(t):
            if 'net_input' not in sample:
                continue

            sample = utils.move_to_cuda(sample) if use_cuda else sample

            gen_timer.start()
            if args.knnlm:
                hypos = scorer.generate(models, sample, knn_dstore=knn_dstore)
            else:
                hypos = scorer.generate(models, sample)
            gen_timer.stop(sample['ntokens'])

            for i, hypos_i in enumerate(hypos):
                if i == len(hypos) - 1:
                    continue
                hypo = hypos_i[0]
                skipped = False
                if args.save_knnlm_dstore:
                    shape = hypo['dstore_keys'].shape
                    if shape[0] == args.tokens_per_sample:
                        if dstore_idx + shape[0] > args.dstore_size:
                            shape = [args.dstore_size - dstore_idx]
                            hypo['dstore_keys'] = hypo[
                                'dstore_keys'][:shape[0]]
                        if args.dstore_fp16:
                            dstore_keys[dstore_idx:shape[0] +
                                        dstore_idx] = hypo['dstore_keys'].view(
                                            -1, args.decoder_embed_dim).cpu(
                                            ).numpy().astype(np.float16)
                            dstore_vals[dstore_idx:shape[0] +
                                        dstore_idx] = hypo['tokens'].view(
                                            -1,
                                            1).cpu().numpy().astype(np.int16)
                        else:
                            dstore_keys[dstore_idx:shape[0] +
                                        dstore_idx] = hypo['dstore_keys'].view(
                                            -1, args.decoder_embed_dim).cpu(
                                            ).numpy().astype(np.float32)
                            dstore_vals[dstore_idx:shape[0] +
                                        dstore_idx] = hypo['tokens'].view(
                                            -1, 1).cpu().numpy().astype(
                                                np.int64)

                        dstore_idx += shape[0]
                    else:
                        skipped = True
                        print('Skipping this one with shape', shape)

                sample_id = sample['id'][i]

                tokens = hypo['tokens']
                tgt_len = tokens.numel()
                pos_scores = hypo['positional_scores'].float()
                orig_scores = hypo['original_scores'].float()
                yhat_scores = hypo['yhat_scores'].float()
                if args.knnlm:
                    assert hypo['dists_full'] is not None
                    dists_full = hypo['dists_full'].float()
                    knns_full = hypo['knns_full']

                    # knn_probs_file.write('\n'.join([str(prob) for prob in yhat_scores.tolist()]) + '\n')
                    # orig_probs_file.write('\n'.join([str(prob) for prob in orig_scores.tolist()]) + '\n')
                    dists_file.write('\n'.join([
                        str(dists_for_token)
                        for dists_for_token in dists_full.tolist()
                    ]) + '\n')
                    knns_file.write('\n'.join([
                        str(knns_for_token)
                        for knns_for_token in knns_full.tolist()
                    ]) + '\n')

                if args.save_knnlm_dstore or args.knnlm:
                    if not skipped:
                        word_tokens = [
                            task.target_dictionary[token]
                            for token in hypo['tokens']
                        ]
                        tokens_file.write('\n'.join(word_tokens) + '\n')
                        assert len(
                            hypo['yhat_scores'].float().tolist()) == len(
                                word_tokens)
                '''
                doc = spacy.tokens.doc.Doc(
                    nlp.vocab, words=word_tokens, spaces=[True for token in tokens])
                for name, proc in nlp.pipeline:
                    doc = proc(doc)
                '''

                if args.add_bos_token:
                    assert hypo['tokens'][0].item(
                    ) == task.target_dictionary.bos()
                    tokens = tokens[1:]
                    pos_scores = pos_scores[1:]

                skipped_toks = 0
                if bpe_toks is not None:
                    for i in range(tgt_len - 1):
                        if tokens[i].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[i + 1] += pos_scores[i]
                            pos_scores[i] = 0

                #inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                #if inf_scores.any():
                #    logger.info(
                #        'skipping tokens with inf scores:',
                #        task.target_dictionary.string(tokens[inf_scores.nonzero()])
                #    )
                #    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += pos_scores.sum().cpu()
                count += pos_scores.numel() - skipped_toks

                if args.output_word_probs or args.output_word_stats:
                    w = ''
                    word_prob = []
                    is_bpe = False
                    for i in range(len(tokens)):
                        w_ind = tokens[i].item()
                        w += task.source_dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[i].item()))

                            next_prob = None
                            ind = i + 1
                            while ind < len(tokens):
                                if pos_scores[ind].item() != 0:
                                    next_prob = pos_scores[ind]
                                    break
                                ind += 1

                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                                pos_scores[i].item(), next_prob)
                            is_bpe = False
                            w = ''

            wps_meter.update(sample['ntokens'])
            t.log({'wps': round(wps_meter.avg)})

    if args.save_knnlm_dstore:
        print("dstore_idx", dstore_idx, "final shape", shape)
        print("Keys", dstore_keys.shape, dstore_keys.dtype)
        print("Vals", dstore_vals.shape, dstore_vals.dtype)

    # knn_probs_file.close()
    # orig_probs_file.close()
    if args.knnlm:
        dists_file.close()
        knns_file.close()
    if args.save_knnlm_dstore or args.knnlm:
        # tokens_file is only opened in these modes (see above)
        tokens_file.close()

    # Entities
    # mask = torch.tensor([1 if token.ent_type_ else 0 for token in doc], dtype=float)
    # count_entities = mask.sum()
    # if torch.cuda.is_available() and not parsed_args.cpu:
    #     mask = mask.cuda()
    # avg_nll_loss_entities = - (pos_scores * mask).sum() / count_entities.cpu() / math.log(2)

    avg_nll_loss = -score_sum / count / math.log(2)  # convert to base 2
    logger.info('Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
        gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    logger.info('Loss (base 2): {:.4f}, Perplexity: {:.2f}'.format(
        avg_nll_loss, 2**avg_nll_loss))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(),
                         key=lambda x: x.count,
                         reverse=True):
            logger.info(ws)
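The datastore written above can be read back with np.memmap as long as the dtype and shape match what was used at write time. A hedged sketch with hypothetical sizes (substitute your own --dstore-mmap prefix, --dstore-size, and decoder embedding dim):

import numpy as np

dstore_size, embed_dim = 1000, 1024  # hypothetical values
keys = np.memmap('dstore_keys.npy', dtype=np.float16, mode='r',
                 shape=(dstore_size, embed_dim))
vals = np.memmap('dstore_vals.npy', dtype=np.int16, mode='r',
                 shape=(dstore_size, 1))
print(keys.shape, vals.shape)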
Example #22
def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset,
                                       len(task.dataset(args.gen_subset))))

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _model_args = utils.load_ensemble_for_inference(
        args.path.split(':'),
        task,
        model_arg_overrides=eval(args.model_overrides),
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    if args.score_reference:
        translator = SequenceScorer(models, task.target_dictionary)
    else:
        translator = SequenceGenerator(
            models,
            task.target_dictionary,
            beam_size=args.beam,
            minlen=args.min_len,
            stop_early=(not args.no_early_stop),
            normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen,
            unk_penalty=args.unkpen,
            sampling=args.sampling,
            sampling_topk=args.sampling_topk,
            sampling_temperature=args.sampling_temperature,
            diverse_beam_groups=args.diverse_beam_groups,
            diverse_beam_strength=args.diverse_beam_strength,
            match_source_len=args.match_source_len,
            no_repeat_ngram_size=args.no_repeat_ngram_size,
        )

    if use_cuda:
        translator.cuda()

    # Generate and compute BLEU score
    scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    # collect hypotheses indexed by sample_id; 21678 is the hard-coded size
    # of the evaluated split
    result = [''] * 21678
    with progress_bar.build_progress_bar(args, itr) as t:
        if args.score_reference:
            translations = translator.score_batched_itr(t,
                                                        cuda=use_cuda,
                                                        timer=gen_timer)
        else:
            translations = translator.generate_batched_itr(
                t,
                maxlen_a=args.max_len_a,
                maxlen_b=args.max_len_b,
                cuda=use_cuda,
                timer=gen_timer,
                prefix_size=args.prefix_size,
            )

        wps_meter = TimeMeter()
        for sample_id, src_tokens, target_tokens, hypos in translations:
            # Process input and ground truth
            has_target = target_tokens is not None
            target_tokens = target_tokens.int().cpu() if has_target else None

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(
                    args.gen_subset).src.get_original_text(sample_id)
                target_str = task.dataset(
                    args.gen_subset).tgt.get_original_text(sample_id)
            else:
                src_str = src_dict.string(src_tokens, args.remove_bpe)
                if has_target:
                    target_str = tgt_dict.string(target_tokens,
                                                 args.remove_bpe,
                                                 escape_unk=True)

            if not args.quiet:
                print('S-{}\t{}'.format(sample_id, src_str))
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str))

            # Process top predictions
            for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu()
                    if hypo['alignment'] is not None else None,
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )

                result[sample_id] = hypo_str
                if not args.quiet:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'],
                                                hypo_str))
                    print('P-{}\t{}'.format(
                        sample_id, ' '.join(
                            map(
                                lambda x: '{:.4f}'.format(x),
                                hypo['positional_scores'].tolist(),
                            ))))

                    if args.print_alignment:
                        print('A-{}\t{}'.format(
                            sample_id, ' '.join(
                                map(lambda x: str(utils.item(x)), alignment))))

                # Score only the top hypothesis
                if has_target and i == 0:
                    if align_dict is not None or args.remove_bpe is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tokenizer.Tokenizer.tokenize(
                            target_str, tgt_dict, add_if_not_exist=True)
                    scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += 1

    print(
        '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset,
                                                      args.beam,
                                                      scorer.result_string()))
    # output the result
    return result
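
For context, here is a minimal sketch of how an entry point like the one above is typically driven. It assumes the excerpt belongs to a main(args)-style function (its def line falls outside this excerpt) and uses fairseq's legacy options parser; the data and checkpoint paths are placeholders, not part of the original snippet.

# Hypothetical driver, assuming the snippet above is a `main(args)`-style
# entry point. Paths follow the legacy fairseq CLI conventions.
from fairseq import options

if __name__ == '__main__':
    parser = options.get_generation_parser()
    args = options.parse_args_and_arch(parser)
    # e.g.: python generate.py data-bin/wmt14.en-fr \
    #           --path checkpoints/checkpoint_best.pt --beam 5 --remove-bpe
    hypotheses = main(args)  # the function above returns one string per sample id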
Example #23
    def build_generator(
        self,
        models,
        args,
        seq_gen_cls=None,
        extra_gen_cls_kwargs=None,
        prefix_allowed_tokens_fn=None,
    ):
        """
        Build a :class:`~fairseq.SequenceGenerator` instance for this
        task.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models
            args (fairseq.dataclass.configs.GenerationConfig):
                configuration object (dataclass) for generation
            seq_gen_cls (class, optional): generator class to instantiate
                in place of the default SequenceGenerator
            extra_gen_cls_kwargs (Dict[str, Any]): extra options to pass
                through to SequenceGenerator
            prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]):
                If provided, this function constrains the beam search to
                allowed tokens only at each step. The provided function
                should take 2 arguments: the batch ID (`batch_id: int`)
                and a unidimensional tensor of token ids (`inputs_ids:
                torch.Tensor`). It has to return a `List[int]` with the
                allowed tokens for the next generation step conditioned
                on the previously generated tokens (`inputs_ids`) and
                the batch ID (`batch_id`). This argument is useful for
                constrained generation conditioned on the prefix, as
                described in "Autoregressive Entity Retrieval"
                (https://arxiv.org/abs/2010.00904) and
                https://github.com/facebookresearch/GENRE. (A toy
                sketch of such a function follows this method.)
        """
        if getattr(args, "score_reference", False):
            from fairseq.sequence_scorer import SequenceScorer

            return SequenceScorer(
                self.target_dictionary,
                compute_alignment=getattr(args, "print_alignment", False),
            )

        from fairseq.sequence_generator import (
            SequenceGenerator,
            SequenceGeneratorWithAlignment,
        )

        # Choose search strategy. Defaults to Beam Search.
        sampling = getattr(args, "sampling", False)
        sampling_topk = getattr(args, "sampling_topk", -1)
        sampling_topp = getattr(args, "sampling_topp", -1.0)
        diverse_beam_groups = getattr(args, "diverse_beam_groups", -1)
        diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5)
        match_source_len = getattr(args, "match_source_len", False)
        diversity_rate = getattr(args, "diversity_rate", -1)
        constrained = getattr(args, "constraints", False)
        if prefix_allowed_tokens_fn is None:
            prefix_allowed_tokens_fn = getattr(args,
                                               "prefix_allowed_tokens_fn",
                                               None)
        if (sum(
                int(cond) for cond in [
                    sampling,
                    diverse_beam_groups > 0,
                    match_source_len,
                    diversity_rate > 0,
                ]) > 1):
            raise ValueError(
                "Provided Search parameters are mutually exclusive.")
        assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
        assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"

        if sampling:
            search_strategy = search.Sampling(self.target_dictionary,
                                              sampling_topk, sampling_topp)
        elif diverse_beam_groups > 0:
            search_strategy = search.DiverseBeamSearch(self.target_dictionary,
                                                       diverse_beam_groups,
                                                       diverse_beam_strength)
        elif match_source_len:
            # this is useful for tagging applications where the output
            # length should match the input length, so we hardcode the
            # length constraints for simplicity
            search_strategy = search.LengthConstrainedBeamSearch(
                self.target_dictionary,
                min_len_a=1,
                min_len_b=0,
                max_len_a=1,
                max_len_b=0,
            )
        elif diversity_rate > -1:
            search_strategy = search.DiverseSiblingsSearch(
                self.target_dictionary, diversity_rate)
        elif constrained:
            search_strategy = search.LexicallyConstrainedBeamSearch(
                self.target_dictionary, args.constraints)
        elif prefix_allowed_tokens_fn:
            search_strategy = search.PrefixConstrainedBeamSearch(
                self.target_dictionary, prefix_allowed_tokens_fn)
        else:
            search_strategy = search.BeamSearch(self.target_dictionary)

        extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
        if seq_gen_cls is None:
            if getattr(args, "print_alignment", False):
                seq_gen_cls = SequenceGeneratorWithAlignment
                extra_gen_cls_kwargs["print_alignment"] = args.print_alignment
            else:
                seq_gen_cls = SequenceGenerator

        return seq_gen_cls(
            models,
            self.target_dictionary,
            beam_size=getattr(args, "beam", 5),
            max_len_a=getattr(args, "max_len_a", 0),
            max_len_b=getattr(args, "max_len_b", 200),
            min_len=getattr(args, "min_len", 1),
            normalize_scores=(not getattr(args, "unnormalized", False)),
            len_penalty=getattr(args, "lenpen", 1),
            unk_penalty=getattr(args, "unkpen", 0),
            temperature=getattr(args, "temperature", 1.0),
            match_source_len=getattr(args, "match_source_len", False),
            no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
            search_strategy=search_strategy,
            **extra_gen_cls_kwargs,
        )
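
The prefix_allowed_tokens_fn hook documented above is easiest to see with a toy constraint. The sketch below builds a function that always returns a fixed whitelist of token ids; the whitelist and the helper name make_whitelist_fn are invented for illustration, while Dictionary.index() is fairseq's standard symbol-to-id lookup. A real constraint would condition on batch_id and the generated prefix, as in the GENRE setup cited in the docstring.

# Toy prefix constraint: only ever allow tokens from a fixed whitelist.
# make_whitelist_fn and allowed_words are hypothetical names for this example.
def make_whitelist_fn(tgt_dict, allowed_words):
    allowed_ids = [tgt_dict.index(w) for w in allowed_words]

    def prefix_allowed_tokens_fn(batch_id, input_ids):
        # A real implementation would condition on batch_id and input_ids
        # (the tokens generated so far); this toy version ignores them.
        return allowed_ids

    return prefix_allowed_tokens_fn

# generator = task.build_generator(
#     models, args,
#     prefix_allowed_tokens_fn=make_whitelist_fn(task.target_dictionary,
#                                                ['hello', 'world']),
# )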
Example #24
File: generate.py Project: fyabc/fairseq
def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    if args.score_reference:
        translator = SequenceScorer(models, task.target_dictionary)
    else:
        translator = SequenceGenerator(
            models, task.target_dictionary, beam_size=args.beam, minlen=args.min_len,
            stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen, unk_penalty=args.unkpen,
            sampling=args.sampling, sampling_topk=args.sampling_topk, sampling_temperature=args.sampling_temperature,
            diverse_beam_groups=args.diverse_beam_groups, diverse_beam_strength=args.diverse_beam_strength,
        )

    if use_cuda:
        translator.cuda()

    # Generate and compute BLEU score
    scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        if args.score_reference:
            translations = translator.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        else:
            translations = translator.generate_batched_itr(
                t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
                cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size,
            )

        wps_meter = TimeMeter()
        for sample_id, src_tokens, target_tokens, hypos in translations:
            # Process input and ground truth
            has_target = target_tokens is not None
            target_tokens = target_tokens.int().cpu() if has_target else None

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
                target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
            else:
                src_str = src_dict.string(src_tokens, args.remove_bpe)
                if has_target:
                    target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)

            if not args.quiet:
                print('S-{}\t{}'.format(sample_id, src_str))
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str))

            # Process top predictions
            for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )

                if not args.quiet:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
                    print('P-{}\t{}'.format(
                        sample_id,
                        ' '.join(map(
                            lambda x: '{:.4f}'.format(x),
                            hypo['positional_scores'].tolist(),
                        ))
                    ))

                    if args.print_alignment:
                        print('A-{}\t{}'.format(
                            sample_id,
                            ' '.join(map(lambda x: str(utils.item(x)), alignment))
                        ))

                # Score only the top hypothesis
                if has_target and i == 0:
                    if align_dict is not None or args.remove_bpe is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tokenizer.Tokenizer.tokenize(
                            target_str, tgt_dict, add_if_not_exist=True)
                    scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += 1

    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
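
The BLEU bookkeeping at the end of this example (scorer.add per sentence, result_string once at the end) can be exercised in isolation. Below is a minimal sketch with invented token ids; bleu.Scorer is fairseq's C-backed scorer and requires the compiled libbleu extension.

import torch
from fairseq import bleu

# Toy corpus-BLEU accumulation in the same style as the scoring loop above.
# The pad/eos/unk ids and token sequences below are placeholders.
scorer = bleu.Scorer(1, 2, 3)          # pad, eos, unk ids
ref = torch.IntTensor([4, 5, 6, 2])    # reference token ids, eos-terminated
hyp = torch.IntTensor([4, 5, 7, 2])    # top hypothesis token ids
scorer.add(ref, hyp)                   # accumulate n-gram statistics
print(scorer.result_string())          # corpus BLEU over all added pairs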
Example #25
File: eval_lm.py Project: fyabc/fairseq
def main(parsed_args):
    assert parsed_args.path is not None, '--path required for evaluation!'

    print(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    models, args = utils.load_ensemble_for_inference(parsed_args.path.split(':'), task)

    args.__dict__.update(parsed_args.__dict__)
    print(args)

    task.args = args

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()

    assert len(models) > 0

    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(*[
            model.max_positions() for model in models
        ]),
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        ignore_invalid_inputs=True,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(models, task.target_dictionary)
    if use_cuda:
        scorer.cuda()

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        bpe_cont = args.remove_bpe.rstrip()
        bpe_toks = set(i for i in range(len(task.dictionary)) if task.dictionary[i].endswith(bpe_cont))
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    with progress_bar.build_progress_bar(args, itr) as t:
        results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        wps_meter = TimeMeter()
        for _, src_tokens, __, hypos in results:
            for hypo in hypos:
                pos_scores = hypo['positional_scores']

                skipped_toks = 0
                if bpe_toks is not None:
                    for i in range(len(hypo['tokens']) - 1):
                        if hypo['tokens'][i].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[i + 1] += pos_scores[i]
                            pos_scores[i] = 0

                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                if inf_scores.any():
                    print('| Skipping tokens with inf scores:',
                          task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += utils.item(pos_scores.sum())
                count += pos_scores.numel() - skipped_toks

                if args.output_word_probs or args.output_word_stats:
                    w = ''
                    word_prob = []
                    is_bpe = False
                    for i in range(len(hypo['tokens'])):
                        w_ind = hypo['tokens'][i].item()
                        w += task.dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[i].item()))
                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(pos_scores[i].item())
                            is_bpe = False
                            w = ''
                    if args.output_word_probs:
                        print('\t'.join('{} [{:.2f}]'.format(x[0], x[1]) for x in word_prob))

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            print(ws)
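
The BPE score-folding loop in this example (moving a continuation token's log-probability onto the token that completes the word) is the subtlest part. Here is a self-contained toy illustration of the same arithmetic; the subword strings and scores are invented, and the '@@' continuation marker follows the usual --remove-bpe convention.

# Toy illustration of the score-folding above: subwords ending in the
# continuation marker pass their log-prob forward, so each completed word
# carries the sum of its subword scores. All values are made up.
tokens = ['hug@@', 'ging', 'face']      # 'hugging' split into two subwords
pos_scores = [-0.5, -0.3, -1.0]         # per-subword log-probabilities
bpe_cont = '@@'

skipped = 0
for i in range(len(tokens) - 1):
    if tokens[i].endswith(bpe_cont):
        skipped += 1
        pos_scores[i + 1] += pos_scores[i]  # fold into the completing subword
        pos_scores[i] = 0.0

print(pos_scores)                  # [0.0, -0.8, -1.0]: 'hugging' scores -0.8
print(len(pos_scores) - skipped)   # 2 scored words instead of 3 subwords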