Example #1
    def valid_step(self, sample, model, criterion):
        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)

        def decode(toks, escape_unk=False, trunc_eos=True):
            s = self.tgt_dict.string(
                toks.int().cpu(),
                self.args['task']['eval_bleu_remove_bpe'],
                escape_unk=escape_unk,
                trunc_eos=trunc_eos,
            )
            if len(s) == 0:
                s = '0'
            if self.tokenizer:
                s = self.tokenizer.decode(s)
            return s

        gen_out = self.inference_step(self.sequence_generator, [model], sample)
        ids = sample['id'].tolist()
        hyps, refs = [], []
        for i in range(len(gen_out)):
            hyps.append(decode(gen_out[i][0]['tokens']))
            refs.append(decode(
                utils.strip_pad(sample['target'][i], self.tgt_dict.pad()),
                escape_unk=True,
            ))

        bleu, rouge_l, meteor = self._inference_score(hyps, refs, ids)
        logging_output['bleu'] = bleu
        logging_output['rouge_l'] = rouge_l
        logging_output['meteor'] = meteor

        return loss, sample_size, logging_output
Example #2
    def step_out(self, sample, model):
        def decode(toks, escape_unk=False, trunc_eos=True):
            s = self.tgt_dict.string(
                toks.int().cpu(),
                self.args['task']['eval_bleu_remove_bpe'],
                escape_unk=escape_unk,
                trunc_eos=trunc_eos,
            )
            if len(s) == 0:
                s = '0'  # if the predicted sentence is empty, use '0'
            if self.tokenizer:
                s = self.tokenizer.decode(s)
            return s

        gen_out = self.inference_step(self.sequence_generator, [model],
                                      sample,
                                      bos_token=self.target_dictionary.bos())
        src_ids = sample['src_ids']
        tgt_ids = sample['tgt_ids']
        hyps, refs = [], []
        for i in range(len(gen_out)):
            hyps.append(decode(gen_out[i][0]['tokens']))
            refs.append(
                decode(
                    utils.strip_pad(sample['target'][i], self.tgt_dict.pad()),
                    escape_unk=True,  # don't count <unk> as matches to the hypo
                ))
        return hyps, refs, src_ids, tgt_ids
Example #3
    def _inference_with_bleu(self, generator, sample, model):
        import sacrebleu

        def decode(toks, escape_unk=False):
            s = self.tgt_dict.string(
                toks.int().cpu(),
                self.args['task']['eval_bleu_remove_bpe'],
                escape_unk=escape_unk,
            )
            if self.tokenizer:
                s = self.tokenizer.decode(s)
            return s

        gen_out = self.inference_step(generator, [model], sample, None)
        hyps, refs = [], []
        for i in range(len(gen_out)):
            hyps.append(decode(gen_out[i][0]['tokens']))
            refs.append(
                decode(
                    utils.strip_pad(sample['target'][i], self.tgt_dict.pad()),
                    escape_unk=True,  # don't count <unk> as matches to the hypo
                ))
        if self.args['task']['eval_bleu_print_samples']:
            LOGGER.info('example hypothesis: ' + hyps[0])
            LOGGER.info('example reference: ' + refs[0])
        if self.args['task']['eval_tokenized_bleu']:
            return sacrebleu.corpus_bleu(hyps, [refs], tokenize='none')
        else:
            return sacrebleu.corpus_bleu(hyps, [refs])
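A standalone sanity check of the sacrebleu call above (a minimal sketch; the toy strings are made up):

import sacrebleu

hyps = ['the cat sat on the mat']
refs = ['the cat sat on the mat']
bleu = sacrebleu.corpus_bleu(hyps, [refs])  # second argument is a list of reference streams
print(round(bleu.score, 2))  # 100.0 for an exact match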
Example #4
    def valid_step(self, sample, model, criterion):
        loss, sample_size, logging_output = super().valid_step(
            sample, model, criterion)

        def decode(toks, escape_unk=False, trunc_eos=True):
            s = self.tgt_dict.string(
                toks.int().cpu(),
                self.args['task']['eval_bleu_remove_bpe'],
                escape_unk=escape_unk,
                trunc_eos=trunc_eos,
            )
            if self.tokenizer:
                s = self.tokenizer.decode(s)
            if len(s) == 0:
                s = '0'  # if the predicted sentence is empty, use '0'
            return s

        if self.args['task']['eval_bleu']:
            gen_out = self.inference_step(self.sequence_generator, [model],
                                          sample)
            ids = sample['id'].tolist()
            hyps, refs = [], []
            for i in range(len(gen_out)):
                hyps.append(decode(gen_out[i][0]['tokens']))
                refs.append(
                    decode(
                        utils.strip_pad(sample['target'][i],
                                        self.tgt_dict.pad()),
                        escape_unk=True,  # don't count <unk> as matches to the hypo
                    ))
            if self.args['task']['eval_with_sacrebleu']:
                import sacrebleu
                tokenize = 'none' if self.args['task']['eval_tokenized_bleu'] \
                    else sacrebleu.DEFAULT_TOKENIZER
                bleu = sacrebleu.corpus_bleu(hyps, [refs], tokenize=tokenize)
                logging_output['_bleu_sys_len'] = bleu.sys_len
                logging_output['_bleu_ref_len'] = bleu.ref_len
                # we split counts into separate entries so that they can be
                # summed efficiently across workers using fast-stat-sync
                assert len(bleu.counts) == EVAL_BLEU_ORDER
                for i in range(EVAL_BLEU_ORDER):
                    logging_output['_bleu_counts_' + str(i)] = bleu.counts[i]
                    logging_output['_bleu_totals_' + str(i)] = bleu.totals[i]
            else:
                bleu, rouge_l, meteor = self._inference_score(hyps, refs, ids)
                logging_output['bleu'] = round(bleu, 4)
                logging_output['rouge_l'] = round(rouge_l, 4)
                logging_output['meteor'] = round(meteor, 4)
        return loss, sample_size, logging_output
Example #5
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """

        assert hasattr(model.decoder, 'adaptive_softmax') \
            and model.decoder.adaptive_softmax is not None
        adaptive_softmax = model.decoder.adaptive_softmax

        net_output = model(**sample['net_input'])
        orig_target = model.get_targets(sample, net_output)

        nsentences = orig_target.size(0)
        orig_target = orig_target.view(-1)

        bsz = orig_target.size(0)

        logits, target = adaptive_softmax(net_output[0], orig_target)
        assert len(target) == len(logits)

        loss = net_output[0].new(1 if reduce else bsz).zero_()

        for i in range(len(target)):
            if target[i] is not None:
                assert (target[i].min() >= 0
                        and target[i].max() < logits[i].size(1))  # targets must be valid class indices
                loss += F.cross_entropy(
                    logits[i],
                    target[i],
                    ignore_index=self.padding_idx,
                    reduction='sum' if reduce else 'none',
                )

        orig = utils.strip_pad(orig_target, self.padding_idx)
        ntokens = orig.numel()
        sample_size = sample['target'].size(0) if self.sentence_avg else ntokens
        logging_output = {
            'loss': loss.data,
            'ntokens': ntokens,
            'nsentences': nsentences,
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output
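A toy illustration of the summed per-cluster cross-entropy computed in the loop above (hypothetical cluster sizes; not the real adaptive-softmax head):

import torch
import torch.nn.functional as F

logits = [torch.randn(4, 10), torch.randn(4, 5)]  # two clusters: 10-way and 5-way slices
targets = [torch.randint(0, 10, (4,)), torch.randint(0, 5, (4,))]

loss = logits[0].new(1).zero_()  # accumulator with the logits' dtype/device
for lg, tg in zip(logits, targets):
    loss += F.cross_entropy(lg, tg, reduction='sum')
print(loss.item())  # total token-level loss across clusters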
Example #6
    def valid_step(self, sample, model, criterion):
        loss, sample_size, logging_output = super().valid_step(
            sample, model, criterion)

        def decode(toks, escape_unk=False, trunc_eos=True):
            s = self.tgt_dict.string(
                toks.int().cpu(),
                self.args['task']['eval_bleu_remove_bpe'],
                escape_unk=escape_unk,
                trunc_eos=trunc_eos,
            )
            if len(s) == 0:
                s = '0'  # if the predicted sentence is empty, use '0'
            if self.tokenizer:
                s = self.tokenizer.decode(s)
            return s

        gen_out = self.inference_step(
            self.sequence_generator,
            [model],
            sample,
            bos_token=self.target_dictionary.bos(),
        )
        ids = sample['id'].tolist()
        hyps, refs = [], []
        for i in range(len(gen_out)):
            hyps.append(decode(gen_out[i][0]['tokens']))
            refs.append(
                decode(
                    utils.strip_pad(sample['target'][i], self.tgt_dict.pad()),
                    escape_unk=True,  # don't count <unk> as matches to the hypo
                ))

        bleu, rouge_l, meteor = self._inference_score(hyps, refs, ids)
        logging_output['bleu'] = round(bleu, 4)
        logging_output['rouge_l'] = round(rouge_l, 4)
        logging_output['meteor'] = round(meteor, 4)

        return loss, sample_size, logging_output
Example #7
def main(args, out_file=None):
    use_cuda = torch.cuda.is_available() and not args['common']['cpu']

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args['dataset']['gen_subset'])

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Load ensemble
    LOGGER.info('loading model(s) from {}'.format(args['eval']['path']))
    models, _ = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args['eval']['path']),
        arg_overrides=eval(args['eval']['model_overrides']),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None
            if args['eval']['no_beamable_mm'] else args['eval']['beam'],
            need_attn=args['eval']['print_alignment'],
        )

        if use_cuda:
            # default to the first visible CUDA device
            device = os.environ.get('CUDA_VISIBLE_DEVICES', '0').split(',')[0]
            torch.cuda.set_device(f'cuda:{device}')
            model = model.cuda()
        if args['common']['fp16'] and use_cuda:
            model.half()

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args['dataset']['gen_subset']),
        max_tokens=args['dataset']['max_tokens'],
        max_sentences=args['eval']['max_sentences'],
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args['dataset'][
            'skip_invalid_size_inputs_valid_test'],
        required_batch_size_multiple=args['dataset'][
            'required_batch_size_multiple'],
        num_shards=args['dataset']['num_shards'],
        shard_id=args['dataset']['shard_id'],
        num_workers=args['dataset']['num_workers'],
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args['common']['log_format'],
        log_interval=args['common']['log_interval'],
        default_log_format=('tqdm' if not args['common']['no_progress_bar']
                            else 'none'),
    )

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(models, args)

    sources, hypotheses, references = dict(), dict(), dict()

    for sample in progress:
        torch.cuda.empty_cache()

        sample = move_to_cuda(sample) if use_cuda else sample
        if 'net_input' not in sample:
            continue

        gen_timer.start()
        hypos = task.inference_step(generator,
                                    models,
                                    sample,
                                    bos_token=tgt_dict.bos())
        num_generated_tokens = sum(len(h[0]['tokens'])
                                   for h in hypos)  # TODO: warning
        gen_timer.stop(num_generated_tokens)

        for i, sample_id in enumerate(sample['id'].tolist()):
            has_target = sample['target'] is not None

            # Remove padding
            src_tokens = utils.strip_pad(
                sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
            target_tokens = None
            if has_target:
                target_tokens = utils.strip_pad(sample['target'][i, :],
                                                tgt_dict.pad()).int().cpu()

            hypos_tokens = utils.strip_eos(hypos[i][0]['tokens'],
                                           tgt_dict.eos()).int().cpu()
            # Either retrieve the original sentences or regenerate them from tokens.
            if src_dict is not None:
                src_str = src_dict.string(src_tokens,
                                          args['eval']['remove_bpe'])
            else:
                src_str = "0"
            target_str = "0"  # fallback when the batch has no target
            if has_target:
                target_str = tgt_dict.string(target_tokens,
                                             args['eval']['remove_bpe'],
                                             escape_unk=True)

            hypo_str = tgt_dict.string(hypos_tokens,
                                       args['eval']['remove_bpe'])

            sources[sample_id] = [src_str]
            hypotheses[sample_id] = [hypo_str]
            references[sample_id] = [target_str]

    bleu, rouge_l, meteor = \
        summarization_metrics.eval_accuracies(hypotheses, references, filename=out_file, mode='test')
    LOGGER.info('BLEU: {:.2f}\t ROUGE-L: {:.2f}\t METEOR: {:.2f}'.format(
        bleu, rouge_l, meteor))
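For reference, the dict layout fed to eval_accuracies above maps sample id to a list of strings (toy values, made up here; the helper itself is project-specific):

hypotheses = {0: ['returns the sum of two numbers']}
references = {0: ['return the sum of a and b']}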
Example #8
def _main(args, output_file):
    if args['dataset']['max_tokens'] is None \
            and args['dataset']['max_sentences'] is None:
        args['dataset']['max_tokens'] = 12000
    LOGGER.info(args)

    use_cuda = torch.cuda.is_available() and not args['common']['cpu']

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args['dataset']['gen_subset'])

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    LOGGER.info('loading model(s) from {}'.format(args['eval']['path']))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args['eval']['path']),
        arg_overrides=eval(args['eval']['model_overrides']),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None
            if args['eval']['no_beamable_mm'] else args['eval']['beam'],
            need_attn=args['eval']['print_alignment'],
        )
        if _model_args['common']['fp16']:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args['eval']['replace_unk'])

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args['dataset']['gen_subset']),
        max_tokens=args['dataset']['max_tokens'],
        max_sentences=args['eval']['max_sentences'],
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=_model_args['dataset'][
            'skip_invalid_size_inputs_valid_test'],
        required_batch_size_multiple=_model_args['dataset'][
            'required_batch_size_multiple'],
        num_shards=_model_args['dataset']['num_shards'],
        shard_id=_model_args['dataset']['shard_id'],
        num_workers=_model_args['dataset']['num_workers'],
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=_model_args['common']['log_format'],
        log_interval=_model_args['common']['log_interval'],
        default_log_format=('tqdm'
                            if not _model_args['common']['no_progress_bar']
                            else 'none'),
    )

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    num_sentences = 0
    has_target = True
    wps_meter = TimeMeter()
    sources, hypotheses, references = dict(), dict(), dict()

    for sample in progress:
        torch.cuda.empty_cache()
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if 'net_input' not in sample:
            continue

        gen_timer.start()
        hypos = task.inference_step(generator, models, sample)
        num_generated_tokens = sum(len(h[0]['tokens'])
                                   for h in hypos)  # TODO: warning
        gen_timer.stop(num_generated_tokens)

        for i, sample_id in enumerate(sample['id'].tolist()):
            has_target = sample['target'] is not None

            # Remove padding
            src_tokens = utils.strip_pad(
                sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
            target_tokens = None
            if has_target:
                target_tokens = utils.strip_pad(sample['target'][i, :],
                                                tgt_dict.pad()).int().cpu()

            hypos_tokens = utils.strip_eos(hypos[i][0]['tokens'],
                                           tgt_dict.eos()).int().cpu()
            # Either retrieve the original sentences or regenerate them from tokens.
            if src_dict is not None:
                src_str = src_dict.string(src_tokens,
                                          args['eval']['remove_bpe'])
            else:
                src_str = ""
            target_str = ""  # fallback when the batch has no target
            if has_target:
                target_str = tgt_dict.string(target_tokens,
                                             args['eval']['remove_bpe'],
                                             escape_unk=True)

            hypo_str = tgt_dict.string(hypos_tokens,
                                       args['eval']['remove_bpe'])

            sources[sample_id] = [src_str]
            hypotheses[sample_id] = [hypo_str]
            references[sample_id] = [target_str]

            if not args['eval']['quiet']:
                if src_dict is not None:
                    print('S-{}\t{}'.format(sample_id, src_str),
                          file=output_file)
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str),
                          file=output_file)

                print('H-{}\t{}'.format(sample_id, hypo_str), file=output_file)

    filename = os.path.join(os.path.dirname(__file__), 'config',
                            'predict.json')
    LOGGER.info('write predicted file at {}'.format(filename))
    bleu, rouge_l, meteor = eval_utils.eval_accuracies(hypotheses,
                                                       references,
                                                       filename=filename,
                                                       mode='test')
    LOGGER.info('BLEU: {:.2f}\t ROUGE-L: {:.2f}\t METEOR: {:.2f}'.format(
        bleu, rouge_l, meteor))
Example #9
    def generate(self, models, sample, **kwargs):
        """Score a batch of translations."""
        net_input = sample['net_input']

        def batch_for_softmax(dec_out, target):
            # assumes decoder_out[0] is the only thing needed (may not be correct for future models!)
            first, rest = dec_out[0], dec_out[1:]
            bsz, tsz, dim = first.shape
            if bsz * tsz < self.softmax_batch:
                yield dec_out, target, True
            else:
                flat = first.contiguous().view(1, -1, dim)
                flat_tgt = target.contiguous().view(flat.shape[:-1])
                s = 0
                while s < flat.size(1):
                    e = s + self.softmax_batch
                    yield (flat[:, s:e], ) + rest, flat_tgt[:, s:e], False
                    s = e

        def gather_target_probs(probs, target):
            probs = probs.gather(
                dim=2,
                index=target.unsqueeze(-1),
            )
            return probs

        orig_target = sample['target']

        # compute scores for each model in the ensemble
        avg_probs = None
        avg_attn = None
        for model in models:
            model.eval()
            decoder_out = model(**net_input)
            attn = decoder_out[1] if len(decoder_out) > 1 else None
            if type(attn) is dict:
                attn = attn.get('attn', None)

            batched = batch_for_softmax(decoder_out, orig_target)
            probs, idx = None, 0
            for bd, tgt, is_single in batched:
                sample['target'] = tgt
                curr_prob = model.get_normalized_probs(
                    bd, log_probs=len(models) == 1, sample=sample).data
                if is_single:
                    probs = gather_target_probs(curr_prob, orig_target)
                else:
                    if probs is None:
                        probs = curr_prob.new(orig_target.numel())
                    step = curr_prob.size(0) * curr_prob.size(1)
                    end = step + idx
                    tgt_probs = gather_target_probs(
                        curr_prob.view(tgt.shape + (curr_prob.size(-1), )),
                        tgt)
                    probs[idx:end] = tgt_probs.view(-1)
                    idx = end
                sample['target'] = orig_target

            probs = probs.view(sample['target'].shape)

            if avg_probs is None:
                avg_probs = probs
            else:
                avg_probs.add_(probs)
            if attn is not None and torch.is_tensor(attn):
                attn = attn.data
                if avg_attn is None:
                    avg_attn = attn
                else:
                    avg_attn.add_(attn)
        if len(models) > 1:
            avg_probs.div_(len(models))
            avg_probs.log_()
            if avg_attn is not None:
                avg_attn.div_(len(models))

        bsz = avg_probs.size(0)
        hypos = []
        start_idxs = sample.get('start_indices', [0] * bsz)
        for i in range(bsz):
            # remove padding from ref
            ref = utils.strip_pad(sample['target'][i, start_idxs[i]:], self.pad) \
                if sample['target'] is not None else None
            tgt_len = ref.numel()
            avg_probs_i = avg_probs[i][start_idxs[i]:start_idxs[i] + tgt_len]
            score_i = avg_probs_i.sum() / tgt_len
            if avg_attn is not None:
                avg_attn_i = avg_attn[i]
                if self.compute_alignment:
                    alignment = utils.extract_hard_alignment(
                        avg_attn_i,
                        sample['net_input']['src_tokens'][i],
                        sample['target'][i],
                        self.pad,
                        self.eos,
                    )
                else:
                    alignment = None
            else:
                avg_attn_i = alignment = None
            hypos.append([{
                'tokens': ref,
                'score': score_i,
                'attention': avg_attn_i,
                'alignment': alignment,
                'positional_scores': avg_probs_i,
            }])
        return hypos
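A toy check of gather_target_probs: given a (batch, time, vocab) probability tensor, it picks out each target token's probability:

import torch

probs = torch.softmax(torch.randn(2, 3, 7), dim=-1)  # (batch, time, vocab)
target = torch.randint(0, 7, (2, 3))                 # (batch, time)
picked = probs.gather(dim=2, index=target.unsqueeze(-1))
print(picked.shape)  # torch.Size([2, 3, 1])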
Example #10
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.
        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        lm_logits, output_metadata = model(**sample["net_input"])

        # reshape lm_logits from (N,T,C) to (N*T,C)
        lm_logits = lm_logits.view(-1, lm_logits.size(-1))
        lm_targets = sample['lm_target'].view(-1)
        lm_loss = compute_cross_entropy_loss(lm_logits, lm_targets,
                                             self.padding_idx)

        # compute the number of tokens for which loss is computed. This is used
        # to normalize the loss
        ntokens = utils.strip_pad(lm_targets, self.padding_idx).numel()
        loss = lm_loss / ntokens
        nsentences = sample['nsentences']

        # Compute sentence loss if masked_lm_only is False
        sentence_loss = None
        if not self.masked_lm_only:
            sentence_logits = output_metadata['sentence_logits']
            sentence_targets = sample['sentence_target'].view(-1)
            # This needs to be recomputed due to some differences between
            # TokenBlock and BlockPair dataset. This can be resolved with a
            # refactor of BERTModel which we will do in the future.
            # TODO: Remove this after refactor of BERTModel
            nsentences = sentence_targets.size(0)

            # Check for logits being none which can happen when remove_heads
            # is set to true in the BERT model. Ideally we should set
            # masked_lm_only to true in this case, but that requires some
            # refactor in the BERT model.
            if sentence_logits is not None:
                sentence_loss = compute_cross_entropy_loss(
                    sentence_logits, sentence_targets)

                loss += self.nsp_loss_weight * (sentence_loss / nsentences)

        # NOTE: as we are summing up per token mlm loss and per sentence nsp loss
        # we don't need to use sample_size as denominator for the gradient
        # here sample_size is just used for logging
        sample_size = 1
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'lm_loss': utils.item(lm_loss.data) if reduce else lm_loss.data,
            # sentence loss is not always computed
            'sentence_loss': ((utils.item(sentence_loss.data)
                               if reduce else sentence_loss.data)
                              if sentence_loss is not None else 0.0),
            'ntokens': ntokens,
            'nsentences': nsentences,
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output
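The ntokens normalization above counts only non-padding targets; utils.strip_pad(t, pad) is equivalent to boolean masking, as in this minimal sketch (the padding index is made up):

import torch

pad = 1  # hypothetical padding index
lm_targets = torch.tensor([5, 9, pad, 7, pad])
ntokens = lm_targets[lm_targets != pad].numel()  # what strip_pad(lm_targets, pad).numel() yields
print(ntokens)  # 3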
Example #11
    def complete(self, models, sample, predict_type, **kwargs):
        """Score a batch of translations."""
        net_input = sample['net_input']
        def gather_target_probs(probs, target):
            probs = probs.gather(
                dim=2,
                index=target.unsqueeze(-1),
            )
            return probs

        avg_probs = None
        avg_curr_probs = None

        for model in models:
            model.eval()
            decoder_out = model(**net_input)
            curr_prob = model.get_normalized_probs(decoder_out, log_probs=len(models) == 1, sample=sample).data

            probs = gather_target_probs(curr_prob, sample['target'])
            probs = probs.view(sample['target'].shape)

            if avg_probs is None:
                avg_probs = probs
            else:
                avg_probs.add_(probs)

            if avg_curr_probs is None:
                avg_curr_probs = curr_prob
            else:
                avg_curr_probs.add_(curr_prob)

        if len(models) > 1:
            avg_probs.div_(len(models))
            avg_probs.log_()
            avg_curr_probs.div_(len(models))
            avg_curr_probs.log_()

        bsz = avg_probs.size(0)
        hypos = []
        start_idxs = sample['start_indices'] if 'start_indices' in sample else [0] * bsz
        selected = sample['node_ids'][predict_type]

        for i in range(bsz):
            # remove padding from ref
            ref = utils.strip_pad(sample['target'][i, start_idxs[i]:], self.pad) \
                if sample['target'] is not None else None
            tgt_len = ref.numel()
            avg_probs_i = avg_probs[i][start_idxs[i]:start_idxs[i] + tgt_len]
            score_i = avg_probs_i.sum() / tgt_len

            lprob = avg_curr_probs[i]

            if selected[i]:
                selected_prob = lprob[selected[i]].contiguous()
                rank = torch.argmax(selected_prob, 1)
                mrr = np.mean([1. / (r.item() + 1) for r in rank.view(-1)])

                ncorrect = torch.sum(rank == sample['target'][i][selected[i]].contiguous())
                accuracy = ncorrect / sum(selected[i])

                hypos.append([{
                    'tokens': ref,
                    'score': score_i,
                    'positional_scores': avg_probs_i,
                    'accuracy': accuracy,
                    'mrr': mrr,
                }])
            else:
                hypos.append([{
                    'tokens': ref,
                    'score': score_i,
                    'positional_scores': avg_probs_i,
                    'accuracy': 0.0,
                    'mrr': 0.0,
                }])
        return hypos
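A toy version of the MRR formula above: the mean reciprocal of each zero-based rank plus one:

import numpy as np
import torch

rank = torch.tensor([0, 2, 4])  # hypothetical zero-based ranks
mrr = np.mean([1. / (r.item() + 1) for r in rank.view(-1)])
print(mrr)  # (1/1 + 1/3 + 1/5) / 3 ≈ 0.5111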
Example #12
def _main(args, output_file):
    if args['dataset']['max_tokens'] is None \
            and args['dataset']['max_sentences'] is None:
        args['dataset']['max_tokens'] = 12000
    LOGGER.info(args)

    use_cuda = torch.cuda.is_available() and not args['common']['cpu']

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args['dataset']['gen_subset'])

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    LOGGER.info('loading model(s) from {}'.format(args['eval']['path']))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args['eval']['path']),
        arg_overrides=eval(args['eval']['model_overrides']),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None
            if args['eval']['no_beamable_mm'] else args['eval']['beam'],
            need_attn=args['eval']['print_alignment'],
        )
        if _model_args['common']['fp16']:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args['eval']['replace_unk'])

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args['dataset']['gen_subset']),
        max_tokens=args['dataset']['max_tokens'],
        max_sentences=args['eval']['max_sentences'],
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=_model_args['dataset'][
            'skip_invalid_size_inputs_valid_test'],
        required_batch_size_multiple=_model_args['dataset'][
            'required_batch_size_multiple'],
        num_shards=_model_args['dataset']['num_shards'],
        shard_id=_model_args['dataset']['shard_id'],
        num_workers=_model_args['dataset']['num_workers'],
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=_model_args['common']['log_format'],
        log_interval=_model_args['common']['log_interval'],
        default_log_format=('tqdm'
                            if not _model_args['common']['no_progress_bar']
                            else 'none'),
    )

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    scorer = OrderedDict()
    if args['eval']['sacrebleu']:
        scorer['bleu'] = bleu_scorer.SacrebleuScorer()
    elif args['eval']['nltk_bleu']:
        scorer['bleu'] = bleu_scorer.NLTKBleuScorer()
    else:
        scorer['bleu'] = bleu_scorer.Scorer(tgt_dict.pad(), tgt_dict.eos(),
                                            tgt_dict.unk())
    # Optionally compute ROUGE as well
    if args['eval']['rouge']:
        scorer['rouge'] = rouge_scorer.RougeScorer()
    num_sentences = 0
    has_target = True
    wps_meter = TimeMeter()
    for sample in progress:
        torch.cuda.empty_cache()
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if 'net_input' not in sample:
            continue

        prefix_tokens = None
        if args['eval']['prefix_size'] > 0:
            prefix_tokens = sample['target'][:, :args['eval']['prefix_size']]

        gen_timer.start()
        hypos = task.inference_step(generator, models, sample, prefix_tokens)
        num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
        gen_timer.stop(num_generated_tokens)

        for i, sample_id in enumerate(sample['id'].tolist()):
            has_target = sample['target'] is not None

            # Remove padding
            src_tokens = utils.strip_pad(
                sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
            target_tokens = None
            if has_target:
                target_tokens = utils.strip_pad(sample['target'][i, :],
                                                tgt_dict.pad()).int().cpu()

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(
                    args['dataset']['gen_subset']).src.get_original_text(
                        sample_id)
                target_str = task.dataset(
                    args['dataset']['gen_subset']).tgt.get_original_text(
                        sample_id)
            else:
                if src_dict is not None:
                    src_str = src_dict.string(src_tokens,
                                              args['eval']['remove_bpe'])
                else:
                    src_str = ""
                if has_target:
                    target_str = tgt_dict.string(target_tokens,
                                                 args['eval']['remove_bpe'],
                                                 escape_unk=True)

            if not args['eval']['quiet']:
                if src_dict is not None:
                    print('S-{}\t{}'.format(sample_id, src_str),
                          file=output_file)
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str),
                          file=output_file)

            # Process top predictions
            for j, hypo in enumerate(hypos[i][:args['eval']['nbest']]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args['eval']['remove_bpe'],
                )

                if hypo_str == '.':
                    # ROUGE cannot handle a hypothesis that is just '.'
                    continue

                if not args['eval']['quiet']:
                    score = hypo['score'] / math.log(2)  # convert to base 2
                    print('H-{}\t{}\t{}'.format(sample_id, score, hypo_str),
                          file=output_file)
                    print(
                        'P-{}\t{}'.format(
                            sample_id,
                            ' '.join(
                                map(
                                    lambda x: '{:.4f}'.format(x),
                                    # convert from base e to base 2
                                    hypo['positional_scores'].div_(math.log(2)
                                                                   ).tolist(),
                                ))),
                        file=output_file)

                    if args['eval']['print_alignment']:
                        print('A-{}\t{}'.format(
                            sample_id, ' '.join([
                                '{}-{}'.format(src_idx, tgt_idx)
                                for src_idx, tgt_idx in alignment
                            ])),
                              file=output_file)

                    if args['eval']['print_step']:
                        print('I-{}\t{}'.format(sample_id, hypo['steps']),
                              file=output_file)

                    if args['eval']['retain_iter_history']:
                        for step, h in enumerate(hypo['history']):
                            _, h_str, _ = utils.post_process_prediction(
                                hypo_tokens=h['tokens'].int().cpu(),
                                src_str=src_str,
                                alignment=None,
                                align_dict=None,
                                tgt_dict=tgt_dict,
                                remove_bpe=None,
                            )
                            print('E-{}_{}\t{}'.format(sample_id, step, h_str),
                                  file=output_file)

                # Score only the top hypothesis
                if has_target and j == 0:
                    if (align_dict is not None
                            or args['eval']['remove_bpe'] is not None):
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tgt_dict.encode_line(
                            target_str, add_if_not_exist=True)
                    for metric in scorer:
                        if hasattr(scorer[metric], 'add_string'):
                            scorer[metric].add_string(target_str, hypo_str)
                        else:
                            scorer[metric].add(target_tokens, hypo_tokens)

        wps_meter.update(num_generated_tokens)
        progress.log({'wps': round(wps_meter.avg)})
        num_sentences += sample['nsentences']

    LOGGER.info('NOTE: hypothesis and token scores are output in base 2')
    LOGGER.info(
        'Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        LOGGER.info('Generate {} with beam={}: {}'.format(
            args['dataset']['gen_subset'], args['eval']['beam'], {
                '\n{}:\n{}'.format(str.upper(metric), value.score())
                for metric, value in scorer.items()
            }))

    return scorer
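The base conversion used for the H-/P- lines above: a natural-log score becomes base 2 when divided by ln 2:

import math

ln_prob = math.log(0.5)        # hypothetical ln-probability
print(ln_prob / math.log(2))   # -1.0, i.e. log2(0.5)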