예제 #1
0
 def decode_output(self, output):
     """Render generator output as text that always ends with a period.

     Reads the ``'tokens'`` entry of ``output[0][0]``, passes it through
     ``utils.strip_eos`` with the target dictionary's EOS index, converts
     the result to a string with ``self.tgt_dict.string`` and appends
     ``"."`` when the string does not already end with one.

     Args:
         output: nested generator result; only ``output[0][0]['tokens']``
             is used.

     Returns:
         The rendered string, guaranteed to end with ``"."``.
     """
     first_tokens = output[0][0]['tokens']
     first_tokens = utils.strip_eos(first_tokens, self.tgt_dict.eos())
     text = self.tgt_dict.string(first_tokens)
     return text if text.endswith(".") else text + "."
    def valid_step(self, sample, model, criterion):
        """Run the parent validation step and optionally score generated text.

        The loss, sample size and logging output come from
        ``super().valid_step``. When ``eval_bleu`` or ``eval_rouge`` is set
        in ``self.args['task']``, hypotheses are generated with
        ``self.sequence_generator``, decoded together with the references
        from ``sample['target']``, and the values returned by
        ``self._inference_score`` are stored in ``logging_output`` under
        ``'bleu'``, ``'rouge_l'`` and ``'meteor'``.

        Returns:
            The tuple ``(loss, sample_size, logging_output)``.
        """
        loss, sample_size, logging_output = super().valid_step(
            sample, model, criterion)

        task_args = self.args['task']
        if not (task_args['eval_bleu'] or task_args['eval_rouge']):
            # No text metrics requested: the parent's result is returned as is.
            return loss, sample_size, logging_output

        def decode(toks, escape_unk=False, trunc_eos=True):
            """Convert a token tensor to a string via the target dictionary."""
            text = self.tgt_dict.string(
                toks.int().cpu(),
                task_args['eval_bleu_remove_bpe'],
                escape_unk=escape_unk,
                trunc_eos=trunc_eos,
            )
            if not text:
                # An empty prediction is replaced by the placeholder '0'.
                text = '0'
            if self.tokenizer:
                text = self.tokenizer.decode(text)
            return text

        generated = self.sequence_generator.generate([model], sample)
        ids = sample['id'].tolist()

        hyps = []
        refs = []
        for row, candidates in enumerate(generated):
            # Only the first candidate of each sample is scored.
            hyp_tokens = utils.strip_eos(candidates[0]['tokens'],
                                         self.tgt_dict.eos())
            hyps.append(decode(hyp_tokens))

            ref_tokens = utils.strip_pad(sample['target'][row],
                                         self.tgt_dict.pad())
            # escape_unk=True: don't count <unk> as matches to the hypo
            refs.append(decode(ref_tokens, escape_unk=True))

        bleu, rouge_l, meteor = self._inference_score(hyps, refs, ids)
        logging_output['bleu'] = bleu
        logging_output['rouge_l'] = rouge_l
        logging_output['meteor'] = meteor

        return loss, sample_size, logging_output
예제 #3
0
def main(args, out_file=None):
    """Generate hypotheses for the ``gen_subset`` split and log text metrics.

    Loads the task, dataset split and model ensemble described by ``args``,
    generates one hypothesis per sample, and logs the BLEU / ROUGE-L /
    METEOR values returned by ``summarization_metrics.eval_accuracies``.

    Args:
        args: nested configuration mapping; the ``common``, ``dataset`` and
            ``eval`` sections are read here.
        out_file: forwarded unchanged as ``filename`` to
            ``summarization_metrics.eval_accuracies``.

    Returns:
        None. Scoring is skipped (with a warning) when no sample provided a
        reference target.
    """
    use_cuda = torch.cuda.is_available() and not args['common']['cpu']

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args['dataset']['gen_subset'])

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Load ensemble
    LOGGER.info('loading model(s) from {}'.format(args['eval']['path']))
    # NOTE(review): eval() executes arbitrary Python from the
    # 'model_overrides' string. Keep this value trusted, or switch to
    # ast.literal_eval once it is confirmed to always be a plain literal.
    models, _ = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args['eval']['path']),
        arg_overrides=eval(args['eval']['model_overrides']),
        task=task,
    )

    if use_cuda:
        # The previous code read the misspelled 'CUDA_VISIBALE_DEVICES'
        # variable and therefore always fell back to device 0. That is also
        # the correct choice: CUDA_VISIBLE_DEVICES renumbers the visible
        # devices starting at 0, so ordinal 0 is the first visible device.
        torch.cuda.set_device('cuda:0')

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None
            if args['eval']['no_beamable_mm'] else args['eval']['beam'],
            need_attn=args['eval']['print_alignment'],
        )

        if use_cuda:
            model = model.cuda()
        if args['common']['fp16'] and use_cuda:
            model.half()

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args['dataset']['gen_subset']),
        max_tokens=args['dataset']['max_tokens'],
        max_sentences=args['eval']['max_sentences'],
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args['dataset']
        ['skip_invalid_size_inputs_valid_test'],
        required_batch_size_multiple=args['dataset']
        ['required_batch_size_multiple'],
        num_shards=args['dataset']['num_shards'],
        shard_id=args['dataset']['shard_id'],
        num_workers=args['dataset']['num_workers'],
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args['common']['log_format'],
        log_interval=args['common']['log_interval'],
        default_log_format=('tqdm' if not args['common']['no_progress_bar']
                            else 'none'),
    )

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(models, args)

    sources, hypotheses, references = dict(), dict(), dict()

    for sample in progress:
        torch.cuda.empty_cache()

        sample = move_to_cuda(sample) if use_cuda else sample
        if 'net_input' not in sample:
            continue

        gen_timer.start()
        hypos = task.inference_step(generator,
                                    models,
                                    sample,
                                    bos_token=tgt_dict.bos())
        num_generated_tokens = sum(len(h[0]['tokens'])
                                   for h in hypos)  # TODO: warning
        gen_timer.stop(num_generated_tokens)

        # Whether the batch carries targets does not change per row.
        has_target = sample['target'] is not None

        for i, sample_id in enumerate(sample['id'].tolist()):
            # Remove padding
            src_tokens = utils.strip_pad(
                sample['net_input']['src_tokens'][i, :], tgt_dict.pad())

            hypos_tokens = utils.strip_eos(hypos[i][0]['tokens'],
                                           tgt_dict.eos()).int().cpu()

            # Regenerate the sentences from tokens.
            if src_dict is not None:
                src_str = src_dict.string(src_tokens,
                                          args['eval']['remove_bpe'])
            else:
                src_str = "0"

            hypo_str = tgt_dict.string(hypos_tokens,
                                       args['eval']['remove_bpe'])

            sources[sample_id] = [src_str]
            hypotheses[sample_id] = [hypo_str]

            # Record a reference only when this batch has one. Previously
            # target_str was read unconditionally, which raised NameError on
            # a target-less first batch or reused the previous batch's text.
            if has_target:
                target_tokens = utils.strip_pad(sample['target'][i, :],
                                                tgt_dict.pad()).int().cpu()
                target_str = tgt_dict.string(target_tokens,
                                             args['eval']['remove_bpe'],
                                             escape_unk=True)
                references[sample_id] = [target_str]

    if not references:
        LOGGER.warning(
            'no reference targets available; skipping BLEU/ROUGE-L/METEOR')
        return

    # Score only the samples that actually have a reference.
    scored_hypotheses = {
        sample_id: hypotheses[sample_id]
        for sample_id in references
    }
    bleu, rouge_l, meteor = summarization_metrics.eval_accuracies(
        scored_hypotheses, references, filename=out_file, mode='test')
    LOGGER.info('BLEU: {:.2f}\t ROUGE-L: {:.2f}\t METEOR: {:.2f}'.format(
        bleu, rouge_l, meteor))
예제 #4
0
def _main(args, output_file):
    """Generate hypotheses, optionally print them, and log text metrics.

    Loads the task, dataset split and model ensemble described by ``args``,
    generates one hypothesis per sample, writes ``S-``/``T-``/``H-`` lines to
    ``output_file`` unless ``args['eval']['quiet']`` is set, and logs the
    BLEU / ROUGE-L / METEOR values returned by
    ``eval_utils.eval_accuracies``.

    Args:
        args: nested configuration mapping. ``args['dataset']['max_tokens']``
            is set to 12000 in place when both it and
            ``args['dataset']['max_sentences']`` are ``None``.
        output_file: file-like object passed as ``file=`` to ``print``.

    Returns:
        None. Scoring is skipped (with a warning) when no sample provided a
        reference target.
    """
    if args['dataset']['max_tokens'] is None and args['dataset'][
            'max_sentences'] is None:
        args['dataset']['max_tokens'] = 12000
    LOGGER.info(args)

    use_cuda = torch.cuda.is_available() and not args['common']['cpu']

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args['dataset']['gen_subset'])

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    LOGGER.info('loading model(s) from {}'.format(args['eval']['path']))
    # NOTE(review): eval() executes arbitrary Python from the
    # 'model_overrides' string. Keep this value trusted, or switch to
    # ast.literal_eval once it is confirmed to always be a plain literal.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args['eval']['path']),
        arg_overrides=eval(args['eval']['model_overrides']),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None
            if args['eval']['no_beamable_mm'] else args['eval']['beam'],
            need_attn=args['eval']['print_alignment'],
        )
        if _model_args['common']['fp16']:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    # NOTE(review): align_dict is not read anywhere below in this function.
    align_dict = utils.load_align_dict(args['eval']['replace_unk'])

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args['dataset']['gen_subset']),
        max_tokens=args['dataset']['max_tokens'],
        max_sentences=args['eval']['max_sentences'],
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=_model_args['dataset']
        ['skip_invalid_size_inputs_valid_test'],
        required_batch_size_multiple=_model_args['dataset']
        ['required_batch_size_multiple'],
        num_shards=_model_args['dataset']['num_shards'],
        shard_id=_model_args['dataset']['shard_id'],
        num_workers=_model_args['dataset']['num_workers'],
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=_model_args['common']['log_format'],
        log_interval=_model_args['common']['log_interval'],
        default_log_format=('tqdm'
                            if not _model_args['common']['no_progress_bar']
                            else 'none'),
    )

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # NOTE(review): wps_meter is created but not read below in this function.
    wps_meter = TimeMeter()
    sources, hypotheses, references = dict(), dict(), dict()

    for sample in progress:
        torch.cuda.empty_cache()
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if 'net_input' not in sample:
            continue

        gen_timer.start()
        hypos = task.inference_step(generator, models, sample)
        num_generated_tokens = sum(len(h[0]['tokens'])
                                   for h in hypos)  # TODO: warning
        gen_timer.stop(num_generated_tokens)

        # Whether the batch carries targets does not change per row.
        has_target = sample['target'] is not None

        for i, sample_id in enumerate(sample['id'].tolist()):
            # Remove padding
            src_tokens = utils.strip_pad(
                sample['net_input']['src_tokens'][i, :], tgt_dict.pad())

            hypos_tokens = utils.strip_eos(hypos[i][0]['tokens'],
                                           tgt_dict.eos()).int().cpu()

            # Regenerate the sentences from tokens.
            if src_dict is not None:
                src_str = src_dict.string(src_tokens,
                                          args['eval']['remove_bpe'])
            else:
                src_str = ""

            hypo_str = tgt_dict.string(hypos_tokens,
                                       args['eval']['remove_bpe'])

            sources[sample_id] = [src_str]
            hypotheses[sample_id] = [hypo_str]

            # Record a reference only when this batch has one. Previously
            # target_str was read unconditionally, which raised NameError on
            # a target-less first batch or reused the previous batch's text.
            target_str = None
            if has_target:
                target_tokens = utils.strip_pad(sample['target'][i, :],
                                                tgt_dict.pad()).int().cpu()
                target_str = tgt_dict.string(target_tokens,
                                             args['eval']['remove_bpe'],
                                             escape_unk=True)
                references[sample_id] = [target_str]

            if not args['eval']['quiet']:
                if src_dict is not None:
                    print('S-{}\t{}'.format(sample_id, src_str),
                          file=output_file)
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str),
                          file=output_file)

                print('H-{}\t{}'.format(sample_id, hypo_str), file=output_file)

    if not references:
        LOGGER.warning(
            'no reference targets available; skipping BLEU/ROUGE-L/METEOR')
        return

    # Score only the samples that actually have a reference.
    scored_hypotheses = {
        sample_id: hypotheses[sample_id]
        for sample_id in references
    }

    filename = os.path.join(os.path.dirname(__file__), 'config',
                            'predict.json')
    LOGGER.info('write predicted file at {}'.format(filename))
    bleu, rouge_l, meteor = eval_utils.eval_accuracies(scored_hypotheses,
                                                       references,
                                                       filename=filename,
                                                       mode='test')
    LOGGER.info('BLEU: {:.2f}\t ROUGE-L: {:.2f}\t METEOR: {:.2f}'.format(
        bleu, rouge_l, meteor))