def validate(args, trainer, task, epoch_itr, subsets):
    """Run validation on each subset and return the per-subset losses.

    For every name in ``subsets`` an unshuffled batch iterator is built from
    ``task``, ``trainer.valid_step`` is called on each batch, and the final
    stats are printed through the progress bar.

    Returns:
        list: ``stats['valid_loss']`` for each subset, in the order given.
    """
    # Log-output keys that are not tracked as extra meters here.
    untracked_keys = ('loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size')

    losses = []
    for subset in subsets:
        # Build the (unshuffled) iterator over this subset.
        batch_iterator = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
        )
        itr = batch_iterator.next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args, itr, epoch_itr.epoch,
            prefix="valid on '{}' subset".format(subset),
            no_progress_bar='simple',
        )

        # Start from clean validation loss meters.
        for meter_name in ('valid_loss', 'valid_nll_loss'):
            loss_meter = trainer.get_meter(meter_name)
            if loss_meter is not None:
                loss_meter.reset()
        extra_meters = collections.defaultdict(AverageMeter)

        for sample in progress:
            step_output = trainer.valid_step(sample)
            for key, value in step_output.items():
                if key not in untracked_keys:
                    extra_meters[key].update(value)

        # Merge the extra meter averages into the stats and print them.
        stats = get_valid_stats(trainer)
        for key, meter in extra_meters.items():
            stats[key] = meter.avg
        progress.print(stats)

        losses.append(stats['valid_loss'])
    return losses
Пример #2
0
def train(args, trainer, task, epoch_itr, summary_writer=None):
    """Train the model for one epoch.

    Args:
        args: parsed options; this function reads ``update_freq``,
            ``fix_batches_to_gpus``, ``valid_subset``, ``max_update``,
            ``save_interval_updates`` and ``log_interval``.
        trainer: object providing ``train_step``, ``get_meter`` and
            ``get_num_updates``.
        task: forwarded to ``validate`` for mid-epoch validation.
        epoch_itr: epoch iterator; ``epoch``, ``iterations_in_epoch`` and
            ``next_epoch_itr`` are used.
        summary_writer: optional object exposing
            ``log_stats(tag, stats, num_updates)``. When ``None`` (the
            default) summary logging is skipped.
    """

    # Update parameters every N batches
    if epoch_itr.epoch <= len(args.update_freq):
        update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
        # past the end of the schedule: keep using its last entry
        update_freq = args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(fix_batches_to_gpus=args.fix_batches_to_gpus)
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k:
                # loss-like values are weighted by the sample size
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0 and num_updates > 0:
            valid_losses = validate(args, trainer, task, epoch_itr, [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        # summary_writer is optional (defaults to None), so only log when
        # one was actually supplied.
        if summary_writer is not None and num_updates % args.log_interval == 0:
            summary_writer.log_stats('train', stats, num_updates)

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    # reset training meters
    for k in [
        'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
Пример #3
0
def main(args, task=None, model_state=None):
    """Decode ``args.gen_subset`` and score it, or dump emissions/features.

    Depending on ``args`` the loop over batches does one of three things:

    * ``args.dump_emissions``: run ``models[0]`` on the batch and collect its
      normalized log-probabilities, saved at the end with ``np.save``.
    * ``args.dump_features``: run ``models[0]`` on the batch and collect the
      encoder output together with its padding mask, saved with ``np.save``.
    * otherwise: call ``task.inference_step`` and accumulate errors and
      reference lengths through ``process_predictions``.

    Args:
        args: parsed options.
        task: optional, already set-up task. When ``None`` the task is built
            from ``args`` and ``args.gen_subset`` is loaded.
        model_state: forwarded to ``load_models_and_criterions``.

    Returns:
        tuple: ``(task, wer)`` where ``wer`` is
        ``errs_t * 100.0 / lengths_t``, or ``None`` in the dump modes or when
        the accumulated reference length is zero.
    """
    check_args(args)

    if args.max_tokens is None and args.batch_size is None:
        args.max_tokens = 4000000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    if task is None:
        # Load dataset splits
        task = tasks.setup_task(args)
        task.load_dataset(args.gen_subset)

        logger.info(
            "| {} {} {} examples".format(
                args.data, args.gen_subset, len(task.dataset(args.gen_subset))
            )
        )

    # Set dictionary
    tgt_dict = task.target_dictionary

    logger.info("| decoding with criterion {}".format(args.criterion))

    # Load ensemble

    if args.load_emissions:
        # emissions come from disk, so no model is loaded
        models, criterions = [], []
    else:
        logger.info("| loading model(s) from {}".format(args.path))
        # SECURITY NOTE: eval() executes args.model_overrides as Python code;
        # it must only ever receive trusted input.
        models, criterions, _ = load_models_and_criterions(
            args.path,
            data_path=args.data,
            arg_overrides=eval(args.model_overrides),  # noqa
            task=task,
            model_state=model_state,
        )
        optimize_models(args, use_cuda, models)

    # hack to pass transitions to W2lDecoder
    if args.criterion == "asg_loss":
        trans = criterions[0].asg.trans.data
        args.asg_transitions = torch.flatten(trans).tolist()

    # Load dataset (possibly sharded)
    itr = get_dataset_itr(args, task, models)

    # Initialize generator
    gen_timer = StopwatchMeter()

    generator = task.build_generator(models, args)

    if args.load_emissions:
        generator = ExistingEmissionsDecoder(
            generator, np.load(args.load_emissions, allow_pickle=True)
        )
        logger.info("loaded emissions from " + args.load_emissions)

    num_sentences = 0

    if args.results_path is not None and not os.path.exists(args.results_path):
        os.makedirs(args.results_path)

    # NOTE(review): max_source_pos is computed here but not read again in
    # this function; kept as-is because the calls it makes are preserved.
    max_source_pos = (
        utils.resolve_max_positions(
            task.max_positions(), *[model.max_positions() for model in models]
        ),
    )

    if max_source_pos is not None:
        max_source_pos = max_source_pos[0]
        if max_source_pos is not None:
            max_source_pos = max_source_pos[0] - 1

    if args.dump_emissions:
        emissions = {}
    if args.dump_features:
        features = {}
        models[0].bert.proj = None
    else:
        res_files = prepare_result_files(args)
    errs_t = 0
    lengths_t = 0
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if "net_input" not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample["target"][:, : args.prefix_size]

            gen_timer.start()
            if args.dump_emissions:
                with torch.no_grad():
                    encoder_out = models[0](**sample["net_input"])
                    emm = models[0].get_normalized_probs(encoder_out, log_probs=True)
                    emm = emm.transpose(0, 1).cpu().numpy()
                    for i, utt_id in enumerate(sample["id"]):
                        emissions[utt_id.item()] = emm[i]
                    continue
            elif args.dump_features:
                with torch.no_grad():
                    encoder_out = models[0](**sample["net_input"])
                    feat = encoder_out["encoder_out"].transpose(0, 1).cpu().numpy()
                    padding_mask = encoder_out["encoder_padding_mask"]
                    for i, utt_id in enumerate(sample["id"]):
                        padding = padding_mask[i].cpu().numpy() if padding_mask is not None else None
                        features[utt_id.item()] = (feat[i], padding)
                    continue
            hypos = task.inference_step(generator, models, sample, prefix_tokens)
            num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample["id"].tolist()):
                speaker = None
                # prefer "target_label" over "target" when the sample has it
                toks = sample["target"][i, :] if 'target_label' not in sample else sample["target_label"][i, :]
                target_tokens = (
                    utils.strip_pad(toks, tgt_dict.pad()).int().cpu()
                )
                # Process top predictions
                errs, length = process_predictions(
                    args, hypos[i], None, tgt_dict, target_tokens, res_files, speaker, sample_id
                )
                errs_t += errs
                lengths_t += length

            wps_meter.update(num_generated_tokens)
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += sample["nsentences"] if "nsentences" in sample else sample["id"].numel()

    wer = None
    if args.dump_emissions:
        # entries are looked up by the consecutive ids 0..len-1
        emm_arr = [emissions[i] for i in range(len(emissions))]
        np.save(args.dump_emissions, emm_arr)
        logger.info(f"saved {len(emissions)} emissions to {args.dump_emissions}")
    elif args.dump_features:
        feat_arr = [features[i] for i in range(len(features))]
        np.save(args.dump_features, feat_arr)
        logger.info(f"saved {len(features)} features to {args.dump_features}")
    else:
        if lengths_t > 0:
            wer = errs_t * 100.0 / lengths_t
            logger.info(f"WER: {wer}")

        # Guard the rate computations so a run in which nothing was timed
        # reports zero rates instead of raising ZeroDivisionError.
        sentences_per_sec = num_sentences / gen_timer.sum if gen_timer.sum else 0.0
        tokens_per_sec = 1.0 / gen_timer.avg if gen_timer.avg else 0.0
        logger.info(
            "| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f} "
            "sentences/s, {:.2f} tokens/s)".format(
                num_sentences,
                gen_timer.n,
                gen_timer.sum,
                sentences_per_sec,
                tokens_per_sec,
            )
        )
        logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam))
    return task, wer
Пример #4
0
def main(args):
    """Generate hypotheses for ``args.gen_subset`` and score them with BLEU.

    Loads the task, dataset and model ensemble described by ``args``, runs
    ``task.inference_step`` on every batch, prints source/target/hypothesis
    lines unless ``args.quiet`` is set, and scores the top hypothesis of each
    sentence against its reference when one is available.

    Returns:
        The scorer object (``bleu.SacrebleuScorer`` when ``args.sacrebleu``
        is set, otherwise ``bleu.Scorer``).
    """
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    # SECURITY NOTE: eval() executes args.model_overrides as Python code;
    # it must only ever receive trusted input.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        arg_overrides=eval(args.model_overrides),
        task=task,
        bert_ratio=args.bert_ratio if args.change_ratio else None,
        encoder_ratio=args.encoder_ratio if args.change_ratio else None,
        geargs=args,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(
                generator, models, sample,
                prefix_tokens)  # batchsize, beamsize, dict
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding
                src_tokens = utils.strip_pad(
                    sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(
                        sample['target'][i, :], tgt_dict.pad()).int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(
                        args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(
                        args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)

                if not args.quiet:
                    if src_dict is not None:
                        print('S-{}\t{}'.format(sample_id, src_str))
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str))

                # Process top predictions.
                # hypos[i] is this sentence's hypothesis list; slice it by
                # nbest directly (len(hypos) is the batch size and must not
                # cap the number of hypotheses). A separate index `j` keeps
                # the sentence index `i` intact.
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'].int().cpu()
                        if hypo['alignment'] is not None else None,
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    if not args.quiet:
                        print('H-{}\t{}\t{}'.format(sample_id, hypo['score'],
                                                    hypo_str))
                        print('P-{}\t{}'.format(
                            sample_id, ' '.join(
                                '{:.4f}'.format(x)
                                for x in hypo['positional_scores'].tolist())))

                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id, ' '.join(
                                    str(utils.item(x)) for x in alignment)))

                    # Score only the top hypothesis
                    if has_target and j == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tgt_dict.encode_line(
                                target_str, add_if_not_exist=True)
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    # Guard the rate computations so a run in which nothing was timed
    # reports zero rates instead of raising ZeroDivisionError.
    sentences_per_sec = num_sentences / gen_timer.sum if gen_timer.sum else 0.
    tokens_per_sec = 1. / gen_timer.avg if gen_timer.avg else 0.
    print(
        '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                sentences_per_sec, tokens_per_sec))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset,
                                                      args.beam,
                                                      scorer.result_string()))
    return scorer
Пример #5
0
def main(parsed_args):
    """Score ``gen_subset`` with a language-model ensemble; print loss and perplexity.

    The model arguments returned with the checkpoint are overridden by the
    command-line values in ``parsed_args`` (except for a fixed set of
    model-defining options), the dataset is loaded and optionally wrapped in
    an ``LMContextWindowDataset``, and every batch is scored with a
    ``SequenceScorer``. Positional scores are summed into an average negative
    log-likelihood; optional per-word probabilities and statistics are
    printed when requested.

    Args:
        parsed_args: parsed command-line options; ``path`` is required.
    """
    assert parsed_args.path is not None, '--path required for evaluation!'

    utils.import_user_module(parsed_args)

    print(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    # SECURITY NOTE: eval() executes parsed_args.model_overrides as Python
    # code; it must only ever receive trusted input.
    models, args = checkpoint_utils.load_model_ensemble(
        parsed_args.path.split(':'),
        arg_overrides=eval(parsed_args.model_overrides),
        task=task,
    )

    # Options that keep the value stored with the checkpoint; everything
    # else is taken from the command line.
    keep_from_checkpoint = {
        'self_target',
        'future_target',
        'past_target',
        'tokens_per_sample',
        'output_size_dictionary',
        'add_bos_token',
    }
    for arg in vars(parsed_args):
        if arg not in keep_from_checkpoint:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    print('| {} {} {} examples'.format(args.data, args.gen_subset,
                                       len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()
            # modules whose name contains 'layer_norm' are cast back to float
            for name, _module in model.named_modules():
                if 'layer_norm' in name:
                    _module.float()
        if use_cuda:
            model.cuda()

    assert len(models) > 0

    print('num. model params: {}'.format(
        sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch)

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            # indices of dictionary symbols that end with the BPE marker
            bpe_toks = {
                idx for idx in range(len(task.source_dictionary))
                if task.source_dictionary[idx].endswith(bpe_cont)
            }
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()

        for sample in t:
            if 'net_input' not in sample:
                continue

            sample = utils.move_to_cuda(sample) if use_cuda else sample

            gen_timer.start()
            hypos = scorer.generate(models, sample)
            gen_timer.stop(sample['ntokens'])

            for i, hypos_i in enumerate(hypos):
                hypo = hypos_i[0]
                sample_id = sample['id'][i]

                tokens = hypo['tokens']
                pos_scores = hypo['positional_scores'].float()

                if args.add_bos_token:
                    assert hypo['tokens'][0].item(
                    ) == task.target_dictionary.bos()
                    tokens = tokens[1:]
                    pos_scores = pos_scores[1:]

                skipped_toks = 0
                if bpe_toks is not None:
                    # Bound the loop by the current token count: when the BOS
                    # token was stripped above, the pre-strip length would let
                    # `pos + 1` index past the end of pos_scores.
                    for pos in range(tokens.numel() - 1):
                        if tokens[pos].item() in bpe_toks:
                            # fold a continuation piece's score into the next
                            # position and zero it out here
                            skipped_toks += 1
                            pos_scores[pos + 1] += pos_scores[pos]
                            pos_scores[pos] = 0

                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(
                    float('-inf'))
                if inf_scores.any():
                    print(
                        '| Skipping tokens with inf scores:',
                        task.target_dictionary.string(
                            tokens[inf_scores.nonzero()]))
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += pos_scores.sum().cpu()
                count += pos_scores.numel() - skipped_toks

                if args.output_word_probs or args.output_word_stats:
                    w = ''
                    word_prob = []
                    is_bpe = False
                    for pos in range(len(tokens)):
                        w_ind = tokens[pos].item()
                        w += task.source_dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            # continuation piece: drop the marker, keep building the word
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[pos].item()))

                            # find the next non-zero positional score, if any
                            next_prob = None
                            ind = pos + 1
                            while ind < len(tokens):
                                if pos_scores[ind].item() != 0:
                                    next_prob = pos_scores[ind]
                                    break
                                ind += 1

                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                                pos_scores[pos].item(), next_prob)
                            is_bpe = False
                            w = ''
                    if args.output_word_probs:
                        # NOTE(review): '{:2f}' is a minimum width of 2 with
                        # default precision, not two decimals; left unchanged
                        # so the printed output format stays the same.
                        print(
                            str(int(sample_id)) + " " +
                            ('\t'.join('{} [{:2f}]'.format(x[0], x[1])
                                       for x in word_prob)))

            wps_meter.update(sample['ntokens'])
            t.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
        gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss,
                                                      np.exp(avg_nll_loss)))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(),
                         key=lambda x: x.count,
                         reverse=True):
            print(ws)
Пример #6
0
def eval_tune_loss(args, trainer, task, subset, extra_state):
    """Validate on ``subset`` and update the tune-loss bookkeeping.

    Runs ``trainer.valid_step`` over an unshuffled iterator of ``subset``,
    stores the resulting loss and perplexity in ``extra_state["tune_eval"]``,
    and tracks how many validations have passed since the lowest loss.

    Returns:
        tuple: ``(extra_state, stop_due_to_tune_loss)``; the flag is True when
        ``args.stop_no_best_validate_loss`` is non-negative and the number of
        validations since the best loss exceeds it.
    """
    # Build the unshuffled validation iterator and its progress bar.
    batch_iterator = task.get_batch_iterator(
        dataset=task.dataset(subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), trainer.get_model().max_positions()
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        num_workers=args.num_workers,
    )
    itr = batch_iterator.next_epoch_itr(shuffle=False)
    progress = progress_bar.build_progress_bar(
        args=args,
        iterator=itr,
        epoch=extra_state["epoch"],
        prefix=f"valid on '{subset}' subset",
        no_progress_bar="simple",
    )

    # Start from clean validation loss meters.
    for meter_name in ("valid_loss", "valid_nll_loss"):
        loss_meter = trainer.get_meter(meter_name)
        if loss_meter is not None:
            loss_meter.reset()

    # Log-output keys that are not tracked as extra meters here.
    untracked_keys = ("loss", "nll_loss", "ntokens", "nsentences", "sample_size")
    extra_meters = defaultdict(AverageMeter)
    for sample in progress:
        step_output = trainer.valid_step(sample)

        # Report running stats after every validation step.
        running_stats = get_valid_stats(trainer)
        for key, value in step_output.items():
            if key in untracked_keys:
                continue
            meter = extra_meters[key]
            if "loss" in key:
                # loss-like values are weighted by the sample size
                meter.update(value, step_output["sample_size"])
            else:
                meter.update(value)
            running_stats[key] = meter.avg
        progress.log(running_stats)

    # Final stats for the whole subset.
    stats = get_valid_stats(trainer)
    for key, meter in extra_meters.items():
        stats[key] = meter.avg
    progress.print(stats)

    current_loss = stats["valid_loss"]
    tune_eval = extra_state["tune_eval"]
    tune_eval["loss"] = current_loss
    tune_eval["perplexity"] = stats["valid_ppl"]

    # Either record a new best loss or count one more validation without one.
    best_so_far = tune_eval["lowest_loss"]
    if best_so_far is None or current_loss < best_so_far:
        tune_eval["lowest_loss"] = current_loss
        tune_eval["num_since_best"] = 0
    else:
        tune_eval["num_since_best"] += 1

    stop_due_to_tune_loss = bool(
        args.stop_no_best_validate_loss >= 0
        and tune_eval["num_since_best"] > args.stop_no_best_validate_loss
    )
    if stop_due_to_tune_loss:
        print(
            f"Stopping training due to eval tune loss stagnation - last best "
            f"eval tune loss of {tune_eval['lowest_loss']} "
            f"(current loss: {tune_eval['loss']}) "
            f"was {tune_eval['num_since_best']} validations ago."
        )
    return extra_state, stop_due_to_tune_loss
Пример #7
0
def main(args):
    """Translate (or score the references of) ``args.gen_subset`` and print BLEU.

    Loads the task, dataset and model ensemble described by ``args``, builds
    either a ``SequenceScorer`` (when ``args.score_reference`` is set) or a
    ``SequenceGenerator``, prints the source/target/hypothesis lines for each
    sentence unless ``args.quiet`` is set, and adds the top hypothesis of each
    sentence to a ``bleu.Scorer`` when a reference is available.
    """
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Dataset split to decode
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    num_examples = len(task.dataset(args.gen_subset))
    print('| {} {} {} examples'.format(args.data, args.gen_subset, num_examples))

    # Dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Model ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _model_args = utils.load_ensemble_for_inference(
        args.path.split(':'),
        task,
        model_arg_overrides=eval(args.model_overrides),
    )

    # Prepare every model of the ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()

    # Alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Batch iterator over the dataset (possibly sharded)
    batch_iterator = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    )
    itr = batch_iterator.next_epoch_itr(shuffle=False)

    # Build the translator: reference scorer or beam-search generator
    gen_timer = StopwatchMeter()
    if args.score_reference:
        translator = SequenceScorer(models, task.target_dictionary)
    else:
        generator_options = dict(
            beam_size=args.beam,
            minlen=args.min_len,
            stop_early=(not args.no_early_stop),
            normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen,
            unk_penalty=args.unkpen,
            sampling=args.sampling,
            sampling_topk=args.sampling_topk,
            sampling_temperature=args.sampling_temperature,
            diverse_beam_groups=args.diverse_beam_groups,
            diverse_beam_strength=args.diverse_beam_strength,
            match_source_len=args.match_source_len,
            no_repeat_ngram_size=args.no_repeat_ngram_size,
        )
        translator = SequenceGenerator(
            models, task.target_dictionary, **generator_options)

    if use_cuda:
        translator.cuda()

    # Generate and accumulate BLEU statistics
    scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as progress:
        if args.score_reference:
            translations = translator.score_batched_itr(
                progress, cuda=use_cuda, timer=gen_timer)
        else:
            translations = translator.generate_batched_itr(
                progress,
                maxlen_a=args.max_len_a,
                maxlen_b=args.max_len_b,
                cuda=use_cuda,
                timer=gen_timer,
                prefix_size=args.prefix_size,
            )

        wps_meter = TimeMeter()
        for sample_id, src_tokens, target_tokens, hypos in translations:
            # Reference tokens, when the sample has any
            has_target = target_tokens is not None
            if has_target:
                target_tokens = target_tokens.int().cpu()

            # Either regenerate the sentences from tokens or fetch the originals.
            if align_dict is None:
                src_str = src_dict.string(src_tokens, args.remove_bpe)
                if has_target:
                    target_str = tgt_dict.string(
                        target_tokens, args.remove_bpe, escape_unk=True)
            else:
                src_str = task.dataset(
                    args.gen_subset).src.get_original_text(sample_id)
                target_str = task.dataset(
                    args.gen_subset).tgt.get_original_text(sample_id)

            verbose = not args.quiet
            if verbose:
                print('S-{}\t{}'.format(sample_id, src_str))
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str))

            # Walk over the best hypotheses of this sentence
            top_hypos = hypos[:min(len(hypos), args.nbest)]
            for rank, hypo in enumerate(top_hypos):
                raw_tokens = hypo['tokens'].int().cpu()
                raw_alignment = None
                if hypo['alignment'] is not None:
                    raw_alignment = hypo['alignment'].int().cpu()
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=raw_tokens,
                    src_str=src_str,
                    alignment=raw_alignment,
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )

                if verbose:
                    print('H-{}\t{}\t{}'.format(
                        sample_id, hypo['score'], hypo_str))
                    positional = ' '.join(
                        '{:.4f}'.format(score)
                        for score in hypo['positional_scores'].tolist())
                    print('P-{}\t{}'.format(sample_id, positional))

                    if args.print_alignment:
                        aligned = ' '.join(
                            str(utils.item(pos)) for pos in alignment)
                        print('A-{}\t{}'.format(sample_id, aligned))

                # Only the first (top) hypothesis is scored
                if has_target and rank == 0:
                    if align_dict is not None or args.remove_bpe is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tokenizer.Tokenizer.tokenize(
                            target_str, tgt_dict, add_if_not_exist=True)
                    scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(src_tokens.size(0))
            progress.log({'wps': round(wps_meter.avg)})
            num_sentences += 1

    summary = '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
    print(summary.format(
        num_sentences, gen_timer.n, gen_timer.sum,
        num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(
            args.gen_subset, args.beam, scorer.result_string()))
Пример #8
0
def validate(args, trainer, dataset, subset, extra_state):
    """Evaluate the model on the validation set and return the average loss.

    Args:
        args: parsed options; reads ``max_tokens``, ``max_sentences_valid``,
            ``skip_invalid_size_inputs_valid_test``, ``distributed_rank``,
            ``distributed_world_size`` and ``stop_no_best_validate_loss``.
        trainer: supplies the model, the validation meters and ``valid_step``.
        dataset: object exposing ``eval_dataloader``.
        subset: name of the subset to evaluate.
        extra_state: dict; ``extra_state["epoch"]`` is read, and
            ``extra_state["validate"]`` is created or updated in place with
            the lowest loss seen so far (``"lowest_loss"``) and the number of
            validations since it was seen (``"num_since_best"``).

    Returns:
        Tuple ``(val_loss, val_ppl, stop_due_to_val_loss)`` where the last
        item is True when ``args.stop_no_best_validate_loss`` is non-negative
        and has been exceeded by the number of validations since the best loss.
    """
    epoch = extra_state["epoch"]
    # Initialize dataloader
    max_positions_valid = (
        trainer.get_model().max_encoder_positions(),
        trainer.get_model().max_decoder_positions(),
    )
    itr = dataset.eval_dataloader(
        subset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=max_positions_valid,
        skip_invalid_size_inputs_valid_test=args.
        skip_invalid_size_inputs_valid_test,
        descending=True,  # largest batch first to warm the caching allocator
        shard_id=args.distributed_rank,
        num_shards=args.distributed_world_size,
    )
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch,
        prefix=f"valid on '{subset}' subset",
        no_progress_bar="simple")

    # reset validation loss meters
    for k in ["valid_loss", "valid_nll_loss"]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    for sample in progress:
        log_output = trainer.valid_step(sample)

        # log mid-validation stats
        stats = get_valid_stats(trainer)
        for k, v in log_output.items():
            # 'loss' and 'nll_loss' are not tracked in extra_meters
            if k in ["loss", "nll_loss"]:
                continue
            if "loss" in k:
                # other loss-like entries are weighted by the sample size
                extra_meters[k].update(v, log_output["sample_size"])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

    # log validation stats
    stats = get_valid_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    val_loss = stats["valid_loss"]
    val_ppl = stats["valid_ppl"]

    # Track the best loss and how many validations ago it was observed.
    if ("validate" not in extra_state
            or val_loss < extra_state["validate"]["lowest_loss"]):
        extra_state["validate"] = {
            "lowest_loss": val_loss,
            "num_since_best": 0
        }
    else:
        extra_state["validate"]["num_since_best"] += 1

    stop_due_to_val_loss = False
    if (args.stop_no_best_validate_loss >= 0
            and extra_state["validate"]["num_since_best"] >
            args.stop_no_best_validate_loss):
        stop_due_to_val_loss = True
        # Each fragment ends with a space so the concatenated message reads
        # as one sentence.
        print(
            f"Stopping training due to validation score stagnation - last best "
            f"validation loss of {extra_state['validate']['lowest_loss']} (current loss: {val_loss}) "
            f"was {extra_state['validate']['num_since_best']} validations ago."
        )
    return val_loss, val_ppl, stop_due_to_val_loss
Пример #9
0
def validate(args, trainer, dataset, subset, epoch):
    """Evaluate the model on the validation set and return the average loss.

    The best loss seen so far and the number of validations since it was seen
    are stored as attributes on the module-level name ``validate``
    (``validate.lowest_loss`` / ``validate.num_since_best``), so they persist
    across calls.

    Args:
        args: parsed options; reads ``max_tokens``, ``max_sentences_valid``,
            ``skip_invalid_size_inputs_valid_test``, ``distributed_rank``,
            ``distributed_world_size`` and ``stop_no_best_validate_loss``.
        trainer: supplies the model, the validation meters and ``valid_step``.
        dataset: object exposing ``eval_dataloader``.
        subset: name of the subset to evaluate.
        epoch: epoch number passed to the progress bar.

    Returns:
        Tuple ``(val_loss, val_ppl, stop_due_to_val_loss)`` where the last
        item is True when ``args.stop_no_best_validate_loss`` is non-negative
        and has been exceeded by the number of validations since the best loss.
    """
    # Initialize dataloader
    max_positions_valid = (
        trainer.get_model().max_encoder_positions(),
        trainer.get_model().max_decoder_positions(),
    )
    itr = dataset.eval_dataloader(
        subset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=max_positions_valid,
        skip_invalid_size_inputs_valid_test=args.
        skip_invalid_size_inputs_valid_test,
        descending=True,  # largest batch first to warm the caching allocator
        shard_id=args.distributed_rank,
        num_shards=args.distributed_world_size,
    )
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch,
        prefix=f'valid on \'{subset}\' subset',
        no_progress_bar='simple')

    # reset validation loss meters
    for k in ['valid_loss', 'valid_nll_loss']:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    for sample in progress:
        log_output = trainer.valid_step(sample)

        # log mid-validation stats
        stats = get_valid_stats(trainer)
        for k, v in log_output.items():
            # 'loss' and 'nll_loss' are not tracked in extra_meters
            if k in ['loss', 'nll_loss']:
                continue
            if 'loss' in k:
                # other loss-like entries are weighted by the sample size
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

    # log validation stats
    stats = get_valid_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    val_loss = stats['valid_loss']
    val_ppl = stats['valid_ppl']
    if not hasattr(validate, 'lowest_loss') or val_loss < validate.lowest_loss:
        validate.lowest_loss = val_loss
        validate.num_since_best = 0
    elif not hasattr(validate, 'num_since_best'):
        # Defensive: this function always sets both attributes together, so
        # this is only reachable if lowest_loss was assigned from elsewhere.
        validate.num_since_best = 1
    else:
        validate.num_since_best += 1

    stop_due_to_val_loss = False
    if (args.stop_no_best_validate_loss >= 0
            and validate.num_since_best > args.stop_no_best_validate_loss):
        stop_due_to_val_loss = True
        # Each fragment ends with a space so the concatenated message reads
        # as one sentence.
        print(
            f'Stopping training due to validation score stagnation - last best '
            f'validation loss of {validate.lowest_loss} (current loss: {val_loss}) '
            f'was {validate.num_since_best} validations ago.')
    return val_loss, val_ppl, stop_due_to_val_loss
Пример #10
0
def setup_epoch(
    args,
    epoch,
    batch_offset,
    trainer,
    dataset,
):
    """Prepare the training iterator, progress bar and extra meters for an epoch.

    Returns a tuple ``(itr, progress, extra_meters)`` where ``itr`` is the
    progress-wrapped training dataloader with the first ``batch_offset``
    batches skipped.
    """
    # Derive the seed from args.seed and the epoch number so that resuming
    # from a checkpoint reproduces the same results.
    epoch_seed = args.seed + epoch
    torch.manual_seed(epoch_seed)

    # Train-time position limits may differ from the validation ones
    # (e.g., RNNs may support more positions at test time than seen in
    # training), so cap the model limits by the configured ones.
    source_limit = min(
        args.max_source_positions,
        trainer.get_model().max_encoder_positions(),
    )
    target_limit = min(
        args.max_target_positions,
        trainer.get_model().max_decoder_positions(),
    )

    # Build the dataloader and wrap it in a progress bar.
    train_loader = dataset.train_dataloader(
        args.train_subset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=(source_limit, target_limit),
        seed=epoch_seed,
        epoch=epoch,
        sample_without_replacement=args.sample_without_replacement,
        sort_by_source_size=(epoch <= args.curriculum),
        shard_id=args.distributed_rank,
        num_shards=args.distributed_world_size,
    )
    progress = progress_bar.build_progress_bar(
        args, train_loader, epoch, no_progress_bar='simple')

    # Skip the batches that precede batch_offset.
    remaining_batches = itertools.islice(progress, batch_offset, None)

    # Reset every training meter the trainer actually has.
    meter_names = (
        'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip',
    )
    for meter_name in meter_names:
        training_meter = trainer.get_meter(meter_name)
        if training_meter is None:
            continue
        training_meter.reset()

    return (
        remaining_batches,
        progress,
        collections.defaultdict(lambda: AverageMeter()),
    )
Пример #11
0
def _generate_score(models, args, dataset, dataset_split):
    """Translate ``dataset_split`` with ``models`` and score the top hypotheses.

    For every sentence the top hypothesis is added to a BLEU scorer and kept
    in a list indexed by sample id. If ``args.translation_output_file`` is
    set, that list is written to the file, one hypothesis per line.

    Args:
        models: list of models passed to the sequence generator.
        args: parsed options (beam size, penalties, sharding, output flags...).
        dataset: object exposing ``splits``, ``src_dict``, ``dst_dict`` and
            ``eval_dataloader``.
        dataset_split: name of the split to translate.

    Returns:
        Tuple ``(scorer, num_sentences, gen_timer)``.

    Raises:
        ValueError: if ``args.num_shards > 1`` and ``args.shard_id`` is
            outside ``[0, args.num_shards)``.
    """
    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load ensemble
    if not args.quiet:
        print("| loading model(s) from {}".format(", ".join(args.path)))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam
        )

    # Initialize generator
    model_weights = None
    if args.model_weights:
        model_weights = [float(w.strip()) for w in args.model_weights.split(",")]
    use_char_source = isinstance(models[0], char_source_model.CharSourceModel)
    translator = beam_decode.SequenceGenerator(
        models,
        beam_size=args.beam,
        stop_early=(not args.no_early_stop),
        normalize_scores=(not args.unnormalized),
        len_penalty=args.lenpen,
        unk_penalty=args.unkpen,
        word_reward=args.word_reward,
        model_weights=model_weights,
        use_char_source=use_char_source,
    )
    if use_cuda:
        translator.cuda()
    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Keep track of translations
    # Initialize with empty translations
    translated_sentences = [""] * len(dataset.splits[dataset_split])

    # Generate and compute BLEU score
    scorer = bleu.Scorer(
        dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk()
    )
    max_positions = min(model.max_encoder_positions() for model in models)
    itr = dataset.eval_dataloader(
        dataset_split,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        skip_invalid_size_inputs_valid_test=(args.skip_invalid_size_inputs_valid_test),
    )
    if args.num_shards > 1:
        if args.shard_id < 0 or args.shard_id >= args.num_shards:
            raise ValueError("--shard-id must be between 0 and num_shards")
        itr = data.sharded_iterator(itr, args.num_shards, args.shard_id)

    num_sentences = 0
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        # Keep more detailed timing when invoked from benchmark
        if "keep_detailed_timing" in args:
            gen_timer = pytorch_translate_utils.BucketStopwatchMeter(
                args.increment,
                args.max_length,
                args.samples_per_length,
            )
        else:
            gen_timer = StopwatchMeter()
        translations = translator.generate_batched_itr(
            t,
            maxlen_a=args.max_len_a,
            maxlen_b=args.max_len_b,
            cuda=use_cuda,
            timer=gen_timer,
        )
        for sample_id, src_tokens, target_tokens, hypos in translations:
            # Process input and ground truth
            target_tokens = target_tokens.int().cpu()
            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = dataset.splits[dataset_split].src.get_original_text(sample_id)
                target_str = dataset.splits[dataset_split].dst.get_original_text(
                    sample_id
                )
            else:
                src_str = dataset.src_dict.string(src_tokens, args.remove_bpe)
                target_str = dataset.dst_dict.string(
                    target_tokens, args.remove_bpe, escape_unk=True
                )

            if not args.quiet:
                print(f"S-{sample_id}\t{src_str}")
                print(f"T-{sample_id}\t{target_str}")

            # Process top predictions
            for i, hypo in enumerate(hypos[: min(len(hypos), args.nbest)]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo["tokens"].int().cpu(),
                    src_str=src_str,
                    alignment=hypo["alignment"].int().cpu(),
                    align_dict=align_dict,
                    dst_dict=dataset.dst_dict,
                    remove_bpe=args.remove_bpe,
                )

                if not args.quiet:
                    print(f"H-{sample_id}\t{hypo['score']}\t{hypo_str}")
                    print(
                        "A-{}\t{}".format(
                            sample_id,
                            " ".join(map(lambda x: str(utils.item(x)), alignment)),
                        )
                    )

                # Score only the top hypothesis
                if i == 0:
                    if align_dict is not None or args.remove_bpe is not None:
                        # Convert back to tokens for evaluation with unk replacement
                        # and/or without BPE
                        target_tokens = tokenizer.Tokenizer.tokenize(
                            target_str, dataset.dst_dict, add_if_not_exist=True
                        )
                    scorer.add(target_tokens, hypo_tokens)
                    translated_sentences[sample_id] = hypo_str

            wps_meter.update(src_tokens.size(0))
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += 1

    # If applicable, save the translations to the output file
    # For eg. external evaluation
    if getattr(args, "translation_output_file", False):
        # Write UTF-8 explicitly: translations may contain non-ASCII text and
        # the platform default encoding is not guaranteed to represent it.
        with open(args.translation_output_file, "w", encoding="utf-8") as out_file:
            for hypo_str in translated_sentences:
                print(hypo_str, file=out_file)

    return scorer, num_sentences, gen_timer
Пример #12
0
def _score_translation_batch(args, task, generator, models, sample, tgt_dict,
                             scorer):
    """Translate one batch and add each sentence's top hypothesis to ``scorer``.

    Returns ``sample['nsentences']`` so the caller can keep a running count.
    """
    # No prefix tokens are forced during validation.
    hypos = task.inference_step(generator, models, sample, None)

    for i, _sample_id in enumerate(sample['id'].tolist()):
        has_target = sample['target'] is not None

        target_tokens = None
        target_str = None
        if has_target:
            # Remove padding
            target_tokens = utils.strip_pad(sample['target'][i, :],
                                            tgt_dict.pad()).int().cpu()
            target_str = tgt_dict.string(target_tokens,
                                         args.remove_bpe,
                                         escape_unk=True)

        # Process top predictions
        for j, hypo in enumerate(hypos[i][:args.nbest]):
            hypo_tokens, hypo_str, _ = utils.post_process_prediction(
                hypo_tokens=hypo['tokens'].int().cpu(),
                src_str="",
                alignment=hypo['alignment'].int().cpu()
                if hypo['alignment'] is not None else None,
                align_dict=None,
                tgt_dict=tgt_dict,
                remove_bpe=args.remove_bpe,
            )

            # Score only the top hypothesis
            if has_target and j == 0:
                if args.remove_bpe is not None:
                    # Convert back to tokens for evaluation without BPE
                    target_tokens = tgt_dict.encode_line(
                        target_str, add_if_not_exist=True)
                if hasattr(scorer, 'add_string'):
                    scorer.add_string(target_str, hypo_str)
                else:
                    scorer.add(target_tokens, hypo_tokens)

    return sample['nsentences']


def validate_translation(args, trainer, task, epoch_itr, generator):
    """Translate the 'valid' dataset and return the resulting scores.

    When ``task`` has an ``eval_lang_pairs`` attribute, every batch yielded by
    the iterator is treated as a mapping ``key -> sample`` and one scorer is
    kept per entry of ``task.eval_lang_pairs``. Otherwise every batch is a
    single sample and a single scorer is kept under key ``0``.

    Args:
        args: parsed options; ``args.sacrebleu`` selects the scorer class.
        trainer: supplies the model used for generation.
        task: supplies the target dictionary, the dataset, the batch iterator
            and ``inference_step``.
        epoch_itr: only ``epoch_itr.epoch`` is read (for the progress bar).
        generator: passed through to ``task.inference_step``.

    Returns:
        Dict mapping each scorer key to the value of that scorer's ``score()``.
    """
    tgt_dict = task.target_dictionary
    models = [trainer.get_model()]

    multilingual = hasattr(task, 'eval_lang_pairs')
    scorer_keys = task.eval_lang_pairs if multilingual else [0]

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer_dict = {key: bleu.SacrebleuScorer() for key in scorer_keys}
    else:
        scorer_dict = {
            key: bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
            for key in scorer_keys
        }

    itr = task.get_batch_iterator(
        dataset=task.dataset('valid'),
        max_tokens=args.max_tokens_valid,
        max_sentences=args.max_sentences_valid,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            trainer.get_model().max_positions(),
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        num_workers=args.num_workers,
        noskip=True,
    )[0].next_epoch_itr(shuffle=False)
    progress = progress_bar.build_progress_bar(args,
                                               itr,
                                               epoch_itr.epoch,
                                               prefix='translate subset',
                                               no_progress_bar='simple')

    use_cuda = torch.cuda.is_available() and not args.cpu
    num_sentences = 0
    for samples in progress:
        if use_cuda:
            samples = utils.move_to_cuda(samples)

        if multilingual:
            for key, sample in samples.items():
                num_sentences += _score_translation_batch(
                    args, task, generator, models, sample, tgt_dict,
                    scorer_dict[key])
        else:
            num_sentences += _score_translation_batch(
                args, task, generator, models, samples, tgt_dict,
                scorer_dict[0])

    print("|valid translated {} sentences".format(num_sentences))
    return {key: scorer.score() for key, scorer in scorer_dict.items()}
Пример #13
0
def validate(args, trainer, task, epoch_itr, subsets, generator=None):
    """Run validation over each subset and return the checkpoint metric(s).

    Returns a one-element list holding the sum of the translation scores when
    ``args.eval_bleu`` is set; otherwise a list with one entry per subset taken
    from that subset's stats under ``args.best_checkpoint_metric``.
    """
    # Entries of the step log output that are not tracked as extra meters.
    skipped_keys = ('loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size')

    valid_losses = []
    if args.eval_bleu:
        bleus = validate_translation(args, trainer, task, epoch_itr, generator)

    for subset in subsets:
        # Build the batch iterator for this subset.
        batch_iterators = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
            noskip=True,
        )
        itr = batch_iterators[0].next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args,
            itr,
            epoch_itr.epoch,
            prefix='valid on \'{}\' subset'.format(subset),
            no_progress_bar='simple',
        )

        # Start the validation loss meters from scratch.
        for meter_name in ('valid_loss', 'valid_nll_loss'):
            loss_meter = trainer.get_meter(meter_name)
            if loss_meter is None:
                continue
            loss_meter.reset()

        extra_meters = collections.defaultdict(lambda: AverageMeter())
        if args.eval_bleu:
            # Record the translation scores next to the other extra stats.
            for bleu_key, bleu_value in bleus.items():
                extra_meters[str(bleu_key) + ":bleu"].update(bleu_value)

        for sample in progress:
            step_output = trainer.valid_step(sample)
            for name, value in step_output.items():
                if name not in skipped_keys:
                    extra_meters[name].update(value)

        # Collect and print the validation stats.
        stats = get_valid_stats(trainer, args, extra_meters)
        if epoch_itr.epoch > args.switch_obj_epoch:
            # Store the average of every "<prefix>:loss" meter on the trainer
            # under its prefix.
            for name, meter in extra_meters.items():
                if name.endswith(":loss"):
                    trainer.valid_losses[name.split(":")[0]] = meter.avg
        for name, meter in extra_meters.items():
            stats[name] = meter.avg
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        metric = args.best_checkpoint_metric
        if metric == 'loss':
            valid_losses.append(stats[metric].avg)
        else:
            valid_losses.append(stats[metric])

    return [sum(bleus.values())] if args.eval_bleu else valid_losses
Пример #14
0
def train(args,
          trainer,
          task,
          epoch_itr,
          generator=None,
          filtered_maxpos_indices=None):
    """Train the model for one epoch.

    The epoch is processed in ``num_reset`` passes; each pass requests a fresh
    iterator from ``epoch_itr`` at a new offset. Training stops as soon as the
    trainer's update count reaches ``args.max_update``, including across
    passes.

    Args:
        args: parsed training options.
        trainer: runs ``train_step`` and owns the meters.
        task: passed through to ``validate``.
        epoch_itr: epoch batch iterator; may be replaced by a filtered
            iterator when ``epoch_itr.epoch == args.select_by_dds_epoch``.
        generator: passed through to ``validate``.
        filtered_maxpos_indices: passed through to
            ``trainer.get_filtered_train_iterator``.

    Returns:
        The (possibly replaced) ``epoch_itr``.
    """
    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]
    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    # data selection: reset epoch iter to filter out unselected data
    if epoch_itr.epoch == args.select_by_dds_epoch and args.select_by_dds_epoch > 0:
        epoch_itr, _ = trainer.get_filtered_train_iterator(
            epoch_itr.epoch, filtered_maxpos_indices=filtered_maxpos_indices)

    # Number of batches covered by one pass; also the offset stride.
    reset_span = args.update_language_sampling * args.update_freq[0] + 1

    if args.update_language_sampling > 0 and args.select_by_dds_epoch < 0 and (
            not args.data_actor_step_update):
        # Ceiling division: enough passes to cover all frozen batches.
        num_reset = len(epoch_itr.frozen_batches) // reset_span
        datasize = reset_span
        if num_reset * datasize < len(epoch_itr.frozen_batches):
            num_reset += 1
    else:
        num_reset = 1
        datasize = -1

    for reset_idx in range(num_reset):
        print("resetting at step", reset_idx)
        # Initialize data iterator
        itr = epoch_itr.next_epoch_itr(
            fix_batches_to_gpus=args.fix_batches_to_gpus,
            shuffle=(epoch_itr.epoch >= args.curriculum),
            offset=reset_idx * reset_span,
            datasize=datasize,
        )
        itr = iterators.GroupedIterator(itr, update_freq)
        progress = progress_bar.build_progress_bar(
            args,
            itr,
            epoch_itr.epoch,
            no_progress_bar='simple',
        )

        reached_max_update = False
        for i, samples in enumerate(progress,
                                    start=epoch_itr.iterations_in_epoch):
            if args.extra_data_actor == 'ave_emb':
                update_actor = (i % args.extra_update_language_sampling == 0)
            elif args.data_actor_step_update:
                update_actor = (i % args.update_language_sampling == 0)
            else:
                update_actor = False
            if (epoch_itr.epoch > args.select_by_dds_epoch
                    and args.select_by_dds_epoch > 0):
                update_actor = False
            log_output = trainer.train_step(samples, update_actor=update_actor)
            if log_output is None:
                continue

            # update sampling distribution
            if (args.update_language_sampling > 0
                    and i % args.update_language_sampling == 0
                    and args.data_actor != 'ave_emb'
                    and not args.data_actor_step_update):
                if args.data_actor_multilin:
                    trainer.update_language_sampler_multilin(
                        args, epoch=epoch_itr.epoch)
                else:
                    trainer.update_language_sampler(args)
            # log mid-epoch stats
            stats = get_training_stats(trainer)
            for k, v in log_output.items():
                if k in [
                        'loss', 'nll_loss', 'ntokens', 'nsentences',
                        'sample_size'
                ]:
                    continue  # these are already logged above
                if 'loss' in k or k == 'accuracy':
                    extra_meters[k].update(v, log_output['sample_size'])
                else:
                    extra_meters[k].update(v)
                stats[k] = extra_meters[k].avg
            progress.log(stats, tag='train', step=stats['num_updates'])

            # ignore the first mini-batch in words-per-second calculation
            if i == 0:
                trainer.get_meter('wps').reset()

            num_updates = trainer.get_num_updates()
            if (not args.disable_validation and args.save_interval_updates > 0
                    and num_updates % args.save_interval_updates == 0
                    and num_updates > 0):
                valid_losses = validate(args, trainer, task, epoch_itr,
                                        valid_subsets, generator)
                checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                                 valid_losses[0])

            if num_updates >= max_update:
                reached_max_update = True
                break

        # Leaving only the inner loop would let every remaining pass run one
        # more train_step beyond max_update, so stop the passes as well.
        if reached_max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
            'train_loss',
            'train_nll_loss',
            'wps',
            'ups',
            'wpb',
            'bsz',
            'gnorm',
            'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
    return epoch_itr
Пример #15
0
def score(args, trainer, task, epoch_itr, subset):
    """Translate ``subset`` with the current model and return its BLEU score.

    Args:
        args: parsed options; batching, beam-search, logging and BLEU
            settings are read from it.
        trainer: supplies the model to evaluate via ``get_model()``.
        task: supplies the datasets and the source/target dictionaries.
        epoch_itr: epoch iterator; only ``epoch`` is read, for the MLPerf
            log entries and the translation log file name.
        subset: name of the dataset split to translate.

    Returns:
        The value returned by ``compute_bleu`` for the collected reference
        and system token lists.
    """
    mlperf_print(key=mlperf_compliance.constants.EVAL_START,
                 metadata={'epoch_num': epoch_itr.epoch},
                 sync=True)
    begin = time.time()

    if subset not in task.datasets.keys():
        task.load_dataset(subset)

    # Work on copies: per the original author's note, generating
    # translations alters the target dictionary, which would interfere with
    # the rest of training.
    src_dict = deepcopy(task.source_dictionary)
    tgt_dict = deepcopy(task.target_dictionary)

    model = trainer.get_model()

    # Initialize data iterator
    itr = data.EpochBatchIterator(
        dataset=task.dataset(subset),
        max_tokens=min(2560, args.max_tokens),
        max_sentences=max(
            8, min(math.ceil(1024 / args.distributed_world_size), 128)),
        max_positions=(256, 256),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        seq_len_multiple=args.seq_len_multiple,
        # Use a large growth factor to get fewer buckets.
        # Fewer buckets yield faster eval since batches are filled from single bucket
        # and eval dataset is small.
        bucket_growth_factor=1.2,
        batching_scheme=args.batching_scheme,
        batch_multiple_strategy=args.batch_multiple_strategy,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    translator = SequenceGenerator(
        [model],
        tgt_dict,
        beam_size=args.beam,
        stop_early=(not args.no_early_stop),
        normalize_scores=(not args.unnormalized),
        len_penalty=args.lenpen,
        sampling=args.sampling,
        sampling_topk=args.sampling_topk,
        minlen=args.min_len,
    )

    # Generate and compute BLEU
    ref_toks = []
    sys_toks = []
    num_sentences = 0
    has_target = True

    # Optional per-rank translation log; ``None`` means logging is disabled.
    log = None
    if args.log_translations:
        log_path = os.path.join(
            args.save_dir,
            'translations_epoch{}_{}'.format(epoch_itr.epoch,
                                             args.distributed_rank))
        log = open(log_path, 'w+')

    try:
        with progress_bar.build_progress_bar(args, itr) as progress:
            translations = translator.generate_batched_itr(
                progress,
                maxlen_a=args.max_len_a,
                maxlen_b=args.max_len_b,
                cuda=True,
                timer=gen_timer,
                prefix_size=args.prefix_size,
            )

            wps_meter = TimeMeter()
            for sample_id, src_tokens, target_tokens, hypos in translations:
                # Process input and ground truth
                has_target = target_tokens is not None
                target_tokens = target_tokens.int().cpu() if has_target else None

                src_str = src_dict.string(src_tokens, args.remove_bpe)
                if has_target:
                    target_str = tgt_dict.string(target_tokens, args.remove_bpe)

                if log is not None:
                    log.write('S-{}\t{}\n'.format(sample_id, src_str))
                    if has_target:
                        log.write('T-{}\t{}\n'.format(sample_id, target_str))

                # Process top predictions
                for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'].int().cpu()
                        if hypo['alignment'] is not None else None,
                        align_dict=None,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe)
                    if log is not None:
                        log.write('H-{}\t{}\t{}\n'.format(
                            sample_id, hypo['score'], hypo_str))
                        positional = ' '.join(
                            '{:.4f}'.format(x)
                            for x in hypo['positional_scores'].tolist())
                        log.write('P-{}\t{}\n'.format(sample_id, positional))

                    # Score only the top hypothesis
                    if has_target and i == 0:
                        src_str = detokenize_subtokenized_sentence(src_str)
                        target_str = detokenize_subtokenized_sentence(target_str)
                        hypo_str = detokenize_subtokenized_sentence(hypo_str)
                        sys_tok = bleu_tokenize(
                            (hypo_str.lower() if args.ignore_case else hypo_str))
                        ref_tok = bleu_tokenize(
                            (target_str.lower() if args.ignore_case else target_str))
                        sys_toks.append(sys_tok)
                        ref_toks.append(ref_tok)

                wps_meter.update(src_tokens.size(0))
                progress.log({'wps': round(wps_meter.avg)})
                num_sentences += 1
    finally:
        # Release the log file even when generation raises part-way through.
        if log is not None:
            log.close()

    bleu_score_reference = compute_bleu(ref_toks, sys_toks, args)
    bleu_score_reference_str = '{:.4f}'.format(bleu_score_reference)

    # Skip the throughput line when nothing was timed (avoids dividing by 0).
    if gen_timer.sum != 0:
        print(
            '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
            .format(num_sentences, gen_timer.n, gen_timer.sum,
                    num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    # ``has_target`` reflects the last translated sample only.
    if has_target:
        print('| Generate {} with beam={}: bleu_score={}'.format(
            subset, args.beam, bleu_score_reference_str))
    print('| Eval completed in: {:.2f}s'.format(time.time() - begin))
    mlperf_print(key=mlperf_compliance.constants.EVAL_STOP,
                 metadata={'epoch_num': epoch_itr.epoch},
                 sync=True)

    return bleu_score_reference
Пример #16
0
def main(args):
    """Run inference on ``args.gen_subset`` and report accuracy and BLEU.

    For every sample the top hypothesis is decoded to a string, compared
    against the reference string for an exact-match accuracy, and added to
    the BLEU scorer. Per-sentence output is printed unless ``args.quiet``.

    Args:
        args: parsed command-line options.

    Returns:
        The scorer object that accumulated the reference/hypothesis pairs.
    """
    utils.import_user_module(args)

    if args.buffer_size < 1:
        args.buffer_size = 1
    if args.max_tokens is None and args.max_sentences is None:
        args.max_sentences = 1

    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
        '--max-sentences/--batch-size cannot be larger than --buffer-size'

    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu
    use_ctc_loss = args.criterion == 'ctc_loss'

    # Setup task, e.g., image captioning
    task = tasks.setup_task(args)
    # Load dataset split
    task.load_dataset(args.gen_subset, combine=True, epoch=0)

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    model_paths = args.path.split(':')
    # NOTE(review): eval() executes whatever string is passed as
    # --model-overrides; only safe with trusted command lines.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        model_paths,
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Set dictionaries
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())

    stats = collections.OrderedDict()
    num_sentences = 0
    num_correct = 0
    has_target = True

    with progress_bar.build_progress_bar(
            args,
            itr,
            prefix='inference on \'{}\' subset'.format(args.gen_subset),
            no_progress_bar='simple',
    ) as progress:
        wps_meter = TimeMeter()
        for sample in progress:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            gen_timer.start()
            hypos = task.inference_step(generator, models, sample)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None
                target_tokens = None
                if has_target:
                    if use_ctc_loss:
                        target_tokens = sample['target'][i]
                    else:
                        # Remove padding
                        target_tokens = utils.strip_pad(
                            sample['target'][i, :],
                            tgt_dict.pad()).int().cpu()
                    # Regenerate original sentences from tokens.
                    target_str = tgt_dict.string(target_tokens,
                                                 args.remove_bpe,
                                                 escape_unk=True)

                if has_target and not args.quiet:
                    print('\nT-{}\t{}'.format(sample_id, target_str))

                # Process top predictions
                hypo = hypos[i][0]
                if use_ctc_loss:
                    hypo_tokens = hypo['tokens']
                else:
                    hypo_tokens = hypo['tokens'].int().cpu()
                hypo_str = tgt_dict.string(hypo_tokens,
                                           args.remove_bpe,
                                           escape_unk=True)
                alignment = hypo['alignment'].int().cpu(
                ) if hypo['alignment'] is not None else None

                # Only compare against a reference that exists for THIS
                # sample; without the guard ``target_str`` is either unbound
                # or left over from a previous sample.
                if has_target and hypo_str == target_str:
                    num_correct += 1

                if not args.quiet:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo_str,
                                                hypo['score']))
                    # With CTC loss the positional-score field is printed as
                    # ``None`` rather than a list of scores.
                    if use_ctc_loss:
                        positional = None
                    else:
                        positional = ' '.join(
                            '{:.4f}'.format(x)
                            for x in hypo['positional_scores'].tolist())
                    print('P-{}\t{}'.format(sample_id, positional))

                    if args.print_alignment:
                        print('A-{}\t{}'.format(
                            sample_id,
                            ' '.join(str(utils.item(x)) for x in alignment)))

                # Score only the top hypothesis
                if has_target:
                    if hasattr(scorer, 'add_string'):
                        scorer.add_string(target_str, hypo_str)
                    else:
                        scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(num_generated_tokens)
            num_sentences += sample['nsentences']
            stats['wps'] = round(wps_meter.avg)
            stats['acc'] = num_correct / num_sentences
            progress.log(stats, tag='accuracy')
        progress.print(stats, tag='accuracy')

    print(
        '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset,
                                                      args.beam,
                                                      scorer.result_string()))
    return scorer
Пример #17
0
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch.

    Args:
        args: parsed options; ``update_freq``, ``valid_subset``,
            ``max_update``, ``save_interval_updates``, ``profile`` and
            ``enable_parallel_backward_allred_opt`` are read here.
        trainer: runs the train steps and owns the training meters.
        task: passed through to ``validate``.
        epoch_itr: epoch iterator supplying this epoch's batches.

    Raises:
        RuntimeError: if ``--enable-parallel-backward-allred-opt`` is
            combined with an update frequency greater than 1.
    """
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr()
    progress = progress_bar.build_progress_bar(args,
                                               itr,
                                               epoch_itr.epoch,
                                               no_progress_bar='simple')

    # update parameters every N batches
    if epoch_itr.epoch <= len(args.update_freq):
        update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
        update_freq = args.update_freq[-1]

    if args.enable_parallel_backward_allred_opt and update_freq > 1:
        raise RuntimeError(
            '--enable-parallel-backward-allred-opt is incompatible with --update-freq > 1'
        )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    num_batches = len(epoch_itr)

    for i, sample in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        last_step = (i == len(itr) - 1)
        if i < num_batches - 1 and (i + 1) % update_freq > 0:
            # buffer updates according to --update-freq
            trainer.train_step(sample,
                               update_params=False,
                               last_step=last_step)
            continue

        log_output = trainer.train_step(sample,
                                        update_params=True,
                                        last_step=last_step)

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k:
                # loss-like values are averaged weighted by sample size
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        # stop the whole process once the requested profiling step is reached
        if args.profile is not None and i == args.profile:
            import sys
            sys.exit()

        num_updates = trainer.get_num_updates()
        # ``num_updates > 0`` keeps the interval check from firing before any
        # update has been applied (0 % n == 0 would otherwise match).
        if (args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    # reset training meters
    for k in [
            'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip'
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
def main(args):
    """Extract alignments for ``args.gen_subset`` and optionally save them.

    When ``args.print_vanilla_alignment`` is set, soft alignments are
    extracted from the first model for every batch. If ``args.decoding_path``
    is also set, the collected alignments are written, ordered by sample id,
    to ``<gen_subset>.<source_lang>2<target_lang>.align`` in that directory.

    Args:
        args: parsed command-line options.
    """
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    # NOTE(review): eval() executes whatever string is passed as
    # --model-overrides; only safe with trusted command lines.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()
        model.decoder.alignment_layer = args.alignment_layer

    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # The scorer is constructed as in the original but is not used further
    # in this function.
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())

    if args.print_vanilla_alignment:
        import string
        punc = string.punctuation
        # Dictionary indices whose symbol is found in string.punctuation.
        src_punc_tokens = [
            w for w in range(len(src_dict)) if src_dict[w] in punc
        ]
        tgt_punc_tokens = [
            w for w in range(len(tgt_dict)) if tgt_dict[w] in punc
        ]
    else:
        src_punc_tokens = None
        tgt_punc_tokens = None

    import collections
    import time

    save_alignments = bool(args.print_vanilla_alignment
                           and args.decoding_path is not None)
    # Alignments keyed by sample id. A dict holds only ids actually seen, so
    # nothing is pre-allocated and ids of any size are accepted.
    align_sents = collections.defaultdict(list)

    print('start time is :', time.strftime("%Y-%m-%d %X"))
    with progress_bar.build_progress_bar(args, itr) as t:
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue
            if args.print_vanilla_alignment:
                if args.set_shift:
                    alignments = utils.extract_soft_alignment(
                        sample,
                        models[0],
                        src_punc_tokens,
                        tgt_punc_tokens,
                        alignment_task=args.alignment_task)
                else:
                    alignments = utils.extract_soft_alignment_noshift(
                        sample,
                        models[0],
                        src_punc_tokens,
                        tgt_punc_tokens,
                        alignment_task=args.alignment_task)
            else:
                alignments = None

            if save_alignments:
                for sample_id in sample['id'].tolist():
                    sample_id = int(sample_id)
                    align_sents[sample_id].append(alignments[sample_id])

    print('end time is :', time.strftime("%Y-%m-%d %X"))
    if save_alignments:
        out_path = os.path.join(
            args.decoding_path,
            f'{args.gen_subset}.{args.source_lang}2{args.target_lang}.align')
        with open(out_path, 'w') as f:
            # Ascending sample id reproduces the order of the index-based
            # list this replaces.
            for sample_id in sorted(align_sents):
                for sent in align_sents[sample_id]:
                    f.write(str(sent) + '\n')
        print("finished ...")
Пример #19
0
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch.

    Args:
        args: parsed options; ``update_freq``, ``valid_subset``,
            ``max_update`` and ``save_interval_updates`` are read here.
        trainer: runs the train steps and owns the training meters.
        task: passed through to ``validate``.
        epoch_itr: epoch iterator supplying this epoch's batches.
    """
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr()
    progress = progress_bar.build_progress_bar(args,
                                               itr,
                                               epoch_itr.epoch,
                                               no_progress_bar='simple')

    # update parameters every N batches
    if epoch_itr.epoch <= len(args.update_freq):
        update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
        update_freq = args.update_freq[-1]

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    num_batches = len(epoch_itr)
    for i, sample in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        if i < num_batches - 1 and (i + 1) % update_freq > 0:
            # buffer updates according to --update-freq
            trainer.train_step(sample, update_params=False)
            continue

        log_output = trainer.train_step(sample, update_params=True)

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k:
                # loss-like values are averaged weighted by sample size
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        # ``num_updates > 0`` keeps the interval check from firing before any
        # update has been applied (0 % n == 0 would otherwise match).
        if (args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    # reset training meters
    for k in [
            'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip'
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
Пример #20
0
def main(args):
    """Translate (or score) ``args.gen_subset`` and print a BLEU summary.

    Loads the dataset and the model ensemble, builds either a sequence scorer
    or a beam-search generator, prints source/target/hypothesis lines unless
    ``args.quiet`` is set, and feeds the top hypothesis of every sentence to
    the BLEU scorer.

    Args:
        args: parsed command-line options.
    """
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset; the raw-text loader is used when unknown-word
    # replacement is requested.
    if args.replace_unk is not None:
        dataset = data.load_raw_text_dataset(
            args.data,
            [args.gen_subset],
            args.source_lang,
            args.target_lang,
            args.doctopics,
            args.encoder_embed_dim,
        )
    else:
        dataset = data.load_dataset(
            args.data,
            [args.gen_subset],
            args.source_lang,
            args.target_lang,
        )
    if args.source_lang is None or args.target_lang is None:
        # record inferred languages in args
        args.source_lang, args.target_lang = dataset.src, dataset.dst

    # Load ensemble
    print('| loading model(s) from {}'.format(', '.join(args.path)))
    models, _ = utils.load_ensemble_for_inference(args.path, dataset.src_dict,
                                                  dataset.dst_dict)

    src_types = len(dataset.src_dict)
    print('| [{}] dictionary: {} types'.format(dataset.src, src_types))
    dst_types = len(dataset.dst_dict)
    print('| [{}] dictionary: {} types'.format(dataset.dst, dst_types))
    num_examples = len(dataset.splits[args.gen_subset])
    print('| {} {} {} examples'.format(args.data, args.gen_subset,
                                       num_examples))

    # Optimize ensemble for generation
    for member in models:
        member.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam, )

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    max_positions = min(m.max_encoder_positions() for m in models)
    itr = dataset.eval_dataloader(
        args.gen_subset,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        skip_invalid_size_inputs_valid_test=args.
        skip_invalid_size_inputs_valid_test,
    )
    if args.num_shards > 1:
        if not 0 <= args.shard_id < args.num_shards:
            raise ValueError('--shard-id must be between 0 and num_shards')
        itr = data.sharded_iterator(itr, args.num_shards, args.shard_id)

    # Initialize generator
    gen_timer = StopwatchMeter()
    if not args.score_reference:
        translator = SequenceGenerator(
            models,
            beam_size=args.beam,
            stop_early=(not args.no_early_stop),
            normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen,
            unk_penalty=args.unkpen)
    else:
        translator = SequenceScorer(models)
    if use_cuda:
        translator.cuda()

    # Generate and compute BLEU score
    scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(),
                         dataset.dst_dict.unk())
    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as progress:
        if not args.score_reference:
            translations = translator.generate_batched_itr(
                progress,
                maxlen_a=args.max_len_a,
                maxlen_b=args.max_len_b,
                cuda=use_cuda,
                timer=gen_timer,
                prefix_size=args.prefix_size)
        else:
            translations = translator.score_batched_itr(progress,
                                                        cuda=use_cuda,
                                                        timer=gen_timer)
        wps_meter = TimeMeter()
        for sample_id, src_tokens, target_tokens, hypos in translations:
            # Process input and ground truth
            has_target = target_tokens is not None
            if has_target:
                target_tokens = target_tokens.int().cpu()

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is None:
                src_str = dataset.src_dict.string(src_tokens, args.remove_bpe)
                if has_target:
                    target_str = dataset.dst_dict.string(target_tokens,
                                                         args.remove_bpe,
                                                         escape_unk=True)
                else:
                    target_str = ''
            else:
                split = dataset.splits[args.gen_subset]
                src_str = split.src.get_original_text(sample_id)
                target_str = split.dst.get_original_text(sample_id)

            verbose = not args.quiet
            if verbose:
                print('S-{}\t{}'.format(sample_id, src_str))
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str))

            # Process top predictions
            top_hypos = hypos[:min(len(hypos), args.nbest)]
            for rank, hypo in enumerate(top_hypos):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu(),
                    align_dict=align_dict,
                    dst_dict=dataset.dst_dict,
                    remove_bpe=args.remove_bpe,
                )

                if verbose:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'],
                                                hypo_str))
                    score_field = ' '.join(
                        '{:.4f}'.format(x)
                        for x in hypo['positional_scores'].tolist())
                    print('P-{}\t{}'.format(sample_id, score_field))
                    align_field = ' '.join(
                        str(utils.item(x)) for x in alignment)
                    print('A-{}\t{}'.format(sample_id, align_field))

                # Score only the top hypothesis
                if rank != 0 or not has_target:
                    continue
                if align_dict is not None or args.remove_bpe is not None:
                    # Convert back to tokens for evaluation with unk replacement and/or without BPE
                    target_tokens = tokenizer.Tokenizer.tokenize(
                        target_str,
                        dataset.dst_dict,
                        add_if_not_exist=True)
                scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(src_tokens.size(0))
            progress.log({'wps': round(wps_meter.avg)})
            num_sentences += 1

    summary = '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} tokens/s)'
    print(summary.format(num_sentences, gen_timer.n, gen_timer.sum,
                         1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset,
                                                      args.beam,
                                                      scorer.result_string()))
Пример #21
0
def main(args):
    """Decode ``args.gen_subset`` and hand every utterance to ``process_predictions``.

    Loads the task, models and criterions, builds the generator, and for each
    utterance passes its hypotheses, stripped target tokens, speaker and id
    to ``process_predictions`` together with the prepared result files.
    Throughput is logged at the end.

    Args:
        args: parsed command-line options.
    """
    check_args(args)
    import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 30000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    logger.info("| {} {} {} examples".format(
        args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Set dictionary
    tgt_dict = task.target_dictionary

    logger.info("| decoding with criterion {}".format(args.criterion))

    # Load ensemble
    logger.info("| loading model(s) from {}".format(args.path))
    # NOTE(review): eval() executes whatever string is passed as
    # --model-overrides; only safe with trusted command lines.
    models, criterions, _model_args = load_models_and_criterions(
        args.path.split(":"),
        arg_overrides=eval(args.model_overrides),  # noqa
        task=task,
    )
    optimize_models(args, use_cuda, models)

    # hack to pass transitions to W2lDecoder
    if args.criterion == "asg_loss":
        trans = criterions[0].asg.trans.data
        args.asg_transitions = torch.flatten(trans).tolist()

    # Load dataset (possibly sharded)
    itr = get_dataset_itr(args, task)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    num_sentences = 0

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.results_path, exist_ok=True)

    sp = spm.SentencePieceProcessor()
    sp.Load(os.path.join(args.data, "spm.model"))

    res_files = prepare_result_files(args)
    # Looked up once instead of twice per utterance inside the loop.
    gen_dataset = task.dataset(args.gen_subset)
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if "net_input" not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample["target"][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample,
                                        prefix_tokens)
            num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample["id"].tolist()):
                sample_idx = int(sample_id)
                speaker = gen_dataset.speakers[sample_idx]
                # named utt_id to avoid shadowing the ``id`` builtin
                utt_id = gen_dataset.ids[sample_idx]
                target_tokens = (utils.strip_pad(sample["target"][i, :],
                                                 tgt_dict.pad()).int().cpu())
                # Process top predictions
                process_predictions(args, hypos[i], sp, tgt_dict,
                                    target_tokens, res_files, speaker, utt_id)

            wps_meter.update(num_generated_tokens)
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += sample["nsentences"]

    # The message previously lacked a space before "sentences/s" because of
    # implicit string concatenation across the line break.
    logger.info("| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f} "
                "sentences/s, {:.2f} tokens/s)".format(
                    num_sentences,
                    gen_timer.n,
                    gen_timer.sum,
                    num_sentences / gen_timer.sum,
                    1.0 / gen_timer.avg,
                ))
    logger.info("| Generate {} with beam={}".format(args.gen_subset,
                                                    args.beam))
Пример #22
0
def main(args):
    """Decode ``args.gen_subset`` with a model ensemble and score the output.

    Sets up the task and dataset split, loads the model ensemble (wrapping
    word-level language models for LM fusion), generates hypotheses batch by
    batch, feeds the top hypothesis of every utterance to a ``wer.Scorer``
    and finally writes the result files under ``args.results_path``.

    Args:
        args: parsed command-line namespace. ``args.max_tokens`` is set to
            12000 in place when neither ``max_tokens`` nor ``max_sentences``
            was given.

    Returns:
        The ``wer.Scorer`` that accumulated the predictions (and the
        evaluations, when the samples carry targets).
    """
    assert args.path is not None, '--path required for recognition!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset split
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionary
    dictionary = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    # NOTE(review): eval() executes whatever expression is passed via
    # --model-overrides; ast.literal_eval would be the safe replacement if
    # the overrides are always plain literals -- confirm before switching.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )
    for i, m in enumerate(models):
        if hasattr(m, 'is_wordlm') and m.is_wordlm:
            # assume subword LM comes before word LM
            if isinstance(models[i - 1], FairseqLanguageModel):
                models[i-1] = MultiLevelLanguageModel(
                    m, models[i-1],
                    subwordlm_weight=args.subwordlm_weight,
                    oov_penalty=args.oov_penalty,
                    open_vocab=not args.disable_open_vocab,
                )
                # NOTE: this shrinks the list being enumerated; the element
                # that followed the word LM (if any) is skipped by the loop.
                del models[i]
                print('| LM fusion with Multi-level LM')
            else:
                models[i] = TensorizedLookaheadLanguageModel(
                    m, dictionary,
                    oov_penalty=args.oov_penalty,
                    open_vocab=not args.disable_open_vocab,
                )
                print('| LM fusion with Look-ahead Word LM')
        # assume subword LM comes after E2E models
        elif i == len(models) - 1 and isinstance(m, FairseqLanguageModel):
            print('| LM fusion with Subword LM')
    if args.lm_weight != 0.0:
        print('| using LM fusion with lm-weight={:.2f}'.format(args.lm_weight))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            # models without an encoder only constrain the target side
            *[model.max_positions() if hasattr(model, 'encoder')
              else (None, model.max_positions()) for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    if args.match_source_len:
        print('| The option match_source_len is not applicable to '
              'speech recognition. Ignoring it.')
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute WER
    scorer = wer.Scorer(dictionary, wer_output_filter=args.wer_output_filter)
    num_sentences = 0
    has_target = True
    # Stays None until the first attention plot is written, so the summary
    # below cannot hit an unbound name when nothing was plotted.
    save_dir = None
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(
                generator, models, sample, prefix_tokens, lm_weight=args.lm_weight,
            )
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            # obtain nonpad mask of encoder output to plot attentions
            if args.print_alignment:
                net_input = sample['net_input']
                src_tokens = net_input['src_tokens']
                output_lengths = models[0].encoder.output_lengths(net_input['src_lengths'])
                nonpad_idxs = sequence_mask(output_lengths, models[0].encoder.output_lengths(src_tokens.size(1)))

            for i in range(len(sample['id'])):
                has_target = sample['target'] is not None
                utt_id = sample['utt_id'][i]

                # Retrieve the original sentences
                if has_target:
                    target_str = sample['target_raw_text'][i]
                    if not args.quiet:
                        target_sent = dictionary.tokens_to_sentence(
                            target_str, use_unk_sym=False, bpe_symbol=args.remove_bpe,
                        )
                        print('T-{}\t{}'.format(utt_id, target_sent))

                # Process top predictions
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_str = dictionary.string(hypo['tokens'].int().cpu())  # not removing bpe at this point

                    # Only the top hypothesis is plotted, and only when an
                    # attention matrix is actually available.
                    plot_this = (
                        j == 0
                        and args.print_alignment
                        and hypo['attention'] is not None
                    )
                    # Build the readable sentence whenever it is going to be
                    # used, so a plot is never labelled with the sentence of
                    # an earlier utterance.
                    if not args.quiet or plot_this:
                        hypo_sent = dictionary.tokens_to_sentence(hypo_str, bpe_symbol=args.remove_bpe)

                    if not args.quiet:
                        print('H-{}\t{}\t{}'.format(utt_id, hypo_sent, hypo['score']))

                    # Score and obtain attention only the top hypothesis
                    if j == 0:
                        if plot_this:
                            # src_len x tgt_len
                            attention = hypo['attention'][nonpad_idxs[i]].float().cpu()
                            save_dir = os.path.join(args.results_path, 'attn_plots')
                            os.makedirs(save_dir, exist_ok=True)
                            plot_attention(attention, hypo_sent, utt_id, save_dir)
                        scorer.add_prediction(utt_id, hypo_str, bpe_symbol=args.remove_bpe)
                        if has_target:
                            scorer.add_evaluation(utt_id, target_str, hypo_str, bpe_symbol=args.remove_bpe)

            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    print('| Recognized {} utterances ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if args.print_alignment and save_dir is not None:
        print('| Saved attention plots in ' + save_dir)

    if has_target:
        assert args.test_text_files is not None
        scorer.add_ordered_utt_list(*args.test_text_files)

    os.makedirs(args.results_path, exist_ok=True)

    fn = 'decoded_char_results.txt'
    with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f:
        f.write(scorer.print_char_results())
        print('| Decoded char results saved as ' + f.name)

    fn = 'decoded_results.txt'
    with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f:
        f.write(scorer.print_results())
        print('| Decoded results saved as ' + f.name)

    if has_target:
        header = ' Recognize {} with beam={}: '.format(args.gen_subset, args.beam)
        fn = 'wer'
        with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f:
            res = 'WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%'.format(
                *(scorer.wer()))
            print('|' + header + res)
            f.write(res + '\n')
            print('| WER saved in ' + f.name)

        fn = 'cer'
        with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f:
            res = 'CER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%'.format(
                *(scorer.cer()))
            # pad with blanks so the CER line lines up under the WER line
            print('|' + ' ' * len(header) + res)
            f.write(res + '\n')
            print('| CER saved in ' + f.name)

        fn = 'aligned_results.txt'
        with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f:
            f.write(scorer.print_aligned_results())
            print('| Aligned results saved as ' + f.name)
    return scorer
Пример #23
0
def _generate_score(models, args, task, dataset):
    """Translate ``dataset`` with ``models`` and accumulate a BLEU scorer.

    Besides scoring, the function optionally exports hypotheses, translation
    info and translated sentences to the files named by the corresponding
    ``args`` attributes.

    Args:
        models: already-loaded model ensemble; each model is switched to
            fast-generation mode in place.
        args: parsed command-line namespace.
        task: task object providing the target dictionary and batch iterator.
        dataset: dataset to translate; subsampled in place when
            ``args.max_examples_to_evaluate`` is positive.

    Returns:
        Tuple ``(scorer, num_sentences, gen_timer, translation_samples)``.
    """
    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load ensemble
    if not args.quiet:
        print("| loading model(s) from {}".format(", ".join(
            args.path.split(":"))))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=True,
        )

    translator = build_sequence_generator(args, task, models)
    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    if args.max_examples_to_evaluate > 0:
        pytorch_translate_data.subsample_pair_dataset(
            dataset, args.max_examples_to_evaluate)

    # Keep track of translations
    # Initialize with empty translations
    # and zero probs scores
    translated_sentences = [""] * len(dataset)
    # NOTE(review): nothing in this function ever writes into
    # translated_scores, so the probs file below always contains
    # exp(0.0) == 1.0 per sentence -- confirm whether a per-hypothesis score
    # was meant to be stored here.
    translated_scores = [0.0] * len(dataset)
    hypos_list = []

    collect_output_hypos = getattr(args, "output_hypos_binary_path", False)
    if collect_output_hypos:
        output_hypos_token_arrays = [None] * len(dataset)

    # Generate and compute BLEU score
    dst_dict = task.target_dictionary
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(dst_dict.pad(), dst_dict.eos(), dst_dict.unk())

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    oracle_scorer = None
    if args.report_oracle_bleu:
        oracle_scorer = bleu.Scorer(dst_dict.pad(), dst_dict.eos(),
                                    dst_dict.unk())

    rescorer = None
    num_sentences = 0
    translation_samples = []
    translation_info_list = []
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        gen_timer = StopwatchMeter()
        translations = translator.generate_batched_itr(
            t,
            maxlen_a=args.max_len_a,
            maxlen_b=args.max_len_b,
            cuda=use_cuda,
            timer=gen_timer,
            prefix_size=1
            if pytorch_translate_data.is_multilingual(args) else 0,
        )

        for trans_info in _iter_translations(args, task, dataset, translations,
                                             align_dict, rescorer):
            # String-based scorers expose add_string; the others take tokens.
            if hasattr(scorer, "add_string"):
                scorer.add_string(trans_info.target_str, trans_info.hypo_str)
            else:
                scorer.add(trans_info.target_tokens, trans_info.hypo_tokens)
            if oracle_scorer is not None:
                oracle_scorer.add(trans_info.target_tokens,
                                  trans_info.best_hypo_tokens)

            if getattr(args, "translation_output_file", False):
                translated_sentences[
                    trans_info.sample_id] = trans_info.hypo_str
            if getattr(args, "hypotheses_export_path", False):
                hypos_list.append(trans_info.hypos)
            if collect_output_hypos:
                output_hypos_token_arrays[
                    trans_info.sample_id] = trans_info.best_hypo_tokens
            if args.translation_info_export_path is not None:
                # Keep only score and tokens (dropping the expensive fields)
                # and make sure the tokens are on cpu before exporting.
                hypos = [{
                    "score": hypo["score"],
                    "tokens": hypo["tokens"].cpu()
                } for hypo in trans_info.hypos]
                translation_info_list.append({
                    "src_tokens":
                    trans_info.src_tokens.cpu(),
                    "target_tokens":
                    trans_info.target_tokens,
                    "hypos":
                    hypos,
                })
            translation_samples.append(
                collections.OrderedDict({
                    "sample_id":
                    trans_info.sample_id.item(),
                    "src_str":
                    trans_info.src_str,
                    "target_str":
                    trans_info.target_str,
                    "hypo_str":
                    trans_info.hypo_str,
                }))
            wps_meter.update(trans_info.src_tokens.size(0))
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += 1

    # If applicable, save collected hypothesis tokens to binary output file
    if collect_output_hypos:
        output_dataset = pytorch_translate_data.InMemoryNumpyDataset()
        output_dataset.load_from_sequences(output_hypos_token_arrays)
        output_dataset.save(args.output_hypos_binary_path)
    if args.translation_info_export_path is not None:
        # Context manager closes the file even if pickling fails part-way.
        with open(args.translation_info_export_path, "wb") as f:
            pickle.dump(translation_info_list, f)

    # If applicable, save the translations and hypos to the output file
    # For eg. external evaluation
    if getattr(args, "translation_output_file", False):
        with open(args.translation_output_file, "w") as out_file:
            for hypo_str in translated_sentences:
                print(hypo_str, file=out_file)

    if getattr(args, "hypotheses_export_path", False):
        with open(args.hypotheses_export_path, "w") as out_file:
            for hypos in hypos_list:
                for hypo in hypos:
                    print(
                        task.tgt_dict.string(hypo["tokens"],
                                             bpe_symbol=args.remove_bpe),
                        file=out_file,
                    )

    if getattr(args, "translation_probs_file", False):
        with open(args.translation_probs_file, "w") as out_file:
            for hypo_score in translated_scores:
                print(np.exp(hypo_score), file=out_file)

    if oracle_scorer is not None:
        print(
            f"| Oracle BLEU (best hypo in beam): {oracle_scorer.result_string()}"
        )

    return scorer, num_sentences, gen_timer, translation_samples
Пример #24
0
def validate(args, trainer, task, epoch_itr, subsets):
    """Run validation on each subset and collect the checkpoint metric.

    For every subset the validation meters are reset, all batches are passed
    through ``trainer.valid_step`` and the extra logged values are averaged.
    Word/char error counts are turned into percentage ``wer``/``cer`` meters.

    Returns:
        A list with one entry per subset holding the value of
        ``args.best_checkpoint_metric`` taken from that subset's stats.
    """
    seed = args.fixed_validation_seed
    if seed is not None:
        # same seed for every validation run
        utils.set_torch_seed(seed)

    # entries of the step log that are not tracked as extra meters
    ignored_keys = (
        'loss', 'nll_loss', 'ntokens', 'nsentences',
        'sample_size', 'word_count', 'char_count',
    )

    collected = []
    for subset in subsets:
        # Build the (unshuffled) batch iterator for this subset
        batch_itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args,
            batch_itr,
            epoch_itr.epoch,
            prefix='valid on \'{}\' subset'.format(subset),
            no_progress_bar='simple',
        )

        # Start from clean validation loss meters
        for meter_name in ('valid_loss', 'valid_nll_loss'):
            loss_meter = trainer.get_meter(meter_name)
            if loss_meter is not None:
                loss_meter.reset()
        extra_meters = collections.defaultdict(AverageMeter)

        # Hand the target dataset to criterions that ask for it
        set_tgt = getattr(trainer.criterion, 'set_valid_tgt_dataset', None)
        if callable(set_tgt):
            set_tgt(task.dataset(subset).tgt)

        for sample in progress:
            log_output = trainer.valid_step(sample)

            for key, value in log_output.items():
                if key in ignored_keys:
                    continue
                if key == 'word_error':
                    wer_meter = extra_meters['wer']
                    wer_meter.update(
                        float(value) / log_output['word_count'] * 100,
                        log_output['word_count'])
                elif key == 'char_error':
                    cer_meter = extra_meters['cer']
                    cer_meter.update(
                        float(value) / log_output['char_count'] * 100,
                        log_output['char_count'])
                else:
                    extra_meters[key].update(value)

        # Log validation stats; extra meters are reported by their average
        stats = get_valid_stats(trainer, args, extra_meters)
        for name, extra_meter in extra_meters.items():
            stats[name] = extra_meter.avg
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        metric_name = args.best_checkpoint_metric
        metric_value = stats[metric_name]
        if metric_name == 'loss':
            # the 'loss' entry is a meter; the others are already plain values
            metric_value = metric_value.avg
        collected.append(metric_value)
    return collected
Пример #25
0
def main(args, task=None, model_state=None):
    """Decode ``args.gen_subset`` with a wav2letter decoder and report WER.

    Depending on ``args`` the function either dumps emissions, dumps encoder
    features, or decodes every batch and accumulates word errors through
    ``process_predictions``.

    Args:
        args: parsed command-line namespace; ``args.max_tokens`` is set to
            4000000 in place when neither ``max_tokens`` nor ``batch_size``
            is given.
        task: accepted for interface compatibility.
            NOTE(review): the value passed in is overwritten by
            ``tasks.setup_task(args)`` below and never used.
        model_state: forwarded to ``checkpoint_utils.load_model_ensemble`` as
            ``state``.

    Returns:
        Tuple ``(task, wer)``; ``wer`` is ``None`` in the dump modes or when
        no target length was accumulated.

    Raises:
        NotImplementedError: if ``args.criterion`` is ``"asg_loss"``.
    """
    check_args(args)

    if args.max_tokens is None and args.batch_size is None:
        args.max_tokens = 4000000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    logger.info("| decoding with criterion {}".format(args.criterion))

    task = tasks.setup_task(args)

    # Load ensemble
    if args.load_emissions:
        # Emissions come from disk, so no model has to be loaded.
        models, criterions = [], []
        task.load_dataset(args.gen_subset)
    else:
        logger.info("| loading model(s) from {}".format(args.path))
        models, saved_cfg = checkpoint_utils.load_model_ensemble(
            utils.split_paths(args.path),
            arg_overrides=ast.literal_eval(args.model_overrides),
            task=task,
            suffix=args.checkpoint_suffix,
            strict=(args.checkpoint_shard_count == 1),
            num_shards=args.checkpoint_shard_count,
            state=model_state,
        )
        optimize_models(args, use_cuda, models)
        task.load_dataset(args.gen_subset, task_cfg=saved_cfg.task)

    # Set dictionary
    tgt_dict = task.target_dictionary

    logger.info(
        "| {} {} {} examples".format(
            args.data, args.gen_subset, len(task.dataset(args.gen_subset))
        )
    )

    # hack to pass transitions to W2lDecoder
    if args.criterion == "asg_loss":
        raise NotImplementedError("asg_loss is currently not supported")
        # trans = criterions[0].asg.trans.data
        # args.asg_transitions = torch.flatten(trans).tolist()

    # Load dataset (possibly sharded)
    itr = get_dataset_itr(args, task, models)

    # Initialize generator
    gen_timer = StopwatchMeter()

    def build_generator(args):
        """Return the wav2letter decoder selected by ``args.w2l_decoder``.

        Falls through and returns ``None`` (after printing a notice) for an
        unsupported or missing decoder name.
        """
        w2l_decoder = getattr(args, "w2l_decoder", None)
        if w2l_decoder == "viterbi":
            from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

            return W2lViterbiDecoder(args, task.target_dictionary)
        elif w2l_decoder == "kenlm":
            from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

            return W2lKenLMDecoder(args, task.target_dictionary)
        elif w2l_decoder == "fairseqlm":
            from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder

            return W2lFairseqLMDecoder(args, task.target_dictionary)
        else:
            # NOTE(review): the caller then continues with generator=None.
            print(
                "only wav2letter decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment"
            )

    # please do not touch this unless you test both generate.py and infer.py with audio_pretraining task
    generator = build_generator(args)

    if args.load_emissions:
        generator = ExistingEmissionsDecoder(
            generator, np.load(args.load_emissions, allow_pickle=True)
        )
        logger.info("loaded emissions from " + args.load_emissions)

    num_sentences = 0

    if args.results_path is not None and not os.path.exists(args.results_path):
        os.makedirs(args.results_path)

    # NOTE(review): max_source_pos is not read again in this function.
    max_source_pos = utils.resolve_max_positions(
        task.max_positions(), *[model.max_positions() for model in models]
    )
    if max_source_pos is not None:
        max_source_pos = max_source_pos[0] - 1

    if args.dump_emissions:
        emissions = {}
    if args.dump_features:
        features = {}
        models[0].bert.proj = None
    else:
        res_files = prepare_result_files(args)
    errs_t = 0
    lengths_t = 0
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if "net_input" not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample["target"][:, : args.prefix_size]

            gen_timer.start()
            if args.dump_emissions:
                # Emission-dump mode: store per-sample emissions, no decoding.
                with torch.no_grad():
                    encoder_out = models[0](**sample["net_input"])
                    emm = models[0].get_normalized_probs(encoder_out, log_probs=True)
                    emm = emm.transpose(0, 1).cpu().numpy()
                    for i, sample_id in enumerate(sample["id"]):
                        emissions[sample_id.item()] = emm[i]
                    continue
            elif args.dump_features:
                # Feature-dump mode: store encoder output plus padding mask.
                with torch.no_grad():
                    encoder_out = models[0](**sample["net_input"])
                    feat = encoder_out["encoder_out"].transpose(0, 1).cpu().numpy()
                    for i, sample_id in enumerate(sample["id"]):
                        padding = (
                            encoder_out["encoder_padding_mask"][i].cpu().numpy()
                            if encoder_out["encoder_padding_mask"] is not None
                            else None
                        )
                        features[sample_id.item()] = (feat[i], padding)
                    continue
            hypos = task.inference_step(generator, models, sample, prefix_tokens)
            num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample["id"].tolist()):
                speaker = None
                # utt_id = task.dataset(args.gen_subset).ids[int(sample_id)]
                utt_id = sample_id
                # "target_label", when present, takes precedence over "target"
                toks = (
                    sample["target"][i, :]
                    if "target_label" not in sample
                    else sample["target_label"][i, :]
                )
                target_tokens = utils.strip_pad(toks, tgt_dict.pad()).int().cpu()
                # Process top predictions
                errs, length = process_predictions(
                    args,
                    hypos[i],
                    None,
                    tgt_dict,
                    target_tokens,
                    res_files,
                    speaker,
                    utt_id,
                )
                errs_t += errs
                lengths_t += length

            wps_meter.update(num_generated_tokens)
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += (
                sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
            )

    wer = None
    if args.dump_emissions:
        # Indexing by 0..len-1 relies on the sample ids being contiguous.
        emm_arr = []
        for i in range(len(emissions)):
            emm_arr.append(emissions[i])
        np.save(args.dump_emissions, emm_arr)
        logger.info(f"saved {len(emissions)} emissions to {args.dump_emissions}")
    elif args.dump_features:
        feat_arr = []
        for i in range(len(features)):
            feat_arr.append(features[i])
        np.save(args.dump_features, feat_arr)
        logger.info(f"saved {len(features)} features to {args.dump_features}")
    else:
        if lengths_t > 0:
            wer = errs_t * 100.0 / lengths_t
            logger.info(f"WER: {wer}")

        logger.info(
            "| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f} "
            "sentences/s, {:.2f} tokens/s)".format(
                num_sentences,
                gen_timer.n,
                gen_timer.sum,
                num_sentences / gen_timer.sum,
                1.0 / gen_timer.avg,
            )
        )
        logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam))
    return task, wer
Пример #26
0
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch.

    Iterates over the epoch's batches (grouped by the epoch's update
    frequency), logs smoothed training stats after every step that produced
    a log output, and validates plus saves a checkpoint every
    ``args.save_interval_updates`` updates unless validation is disabled.
    Stops early once ``args.max_update`` updates have been reached.
    """
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    # One update frequency per epoch; the last entry covers later epochs.
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq) else
                   args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf

    start_time = time.time()
    step = 0
    # Bound before the loop so the end-of-epoch log below still has a step
    # value when the iterator yields no batches.
    num_updates = trainer.get_num_updates()
    for samples in progress:
        log_output = trainer.train_step(samples)
        num_updates = trainer.get_num_updates()
        step += 1

        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(metrics.get_smoothed_values('train'))
        progress.log(stats, tag='train', step=num_updates)

        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            print("validate and save_checkpoint")
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
            # the first validation subset decides the checkpoint score
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if num_updates >= max_update:
            break

    train_epoch_cost = time.time() - start_time

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)
    print("epoch_cost: %.5f s, avg_speed: %.5f steps/s" %
          (train_epoch_cost, float(step) / train_epoch_cost))

    # reset epoch-level meters
    metrics.reset_meters('train')
Пример #27
0
def main(args):
    """Generate the top hypothesis per sample and print it as a JSON line.

    Forces ``args.beam``/``args.nbest`` to 1 and ``args.max_tokens`` to
    10000, loads the model ensemble onto the GPU and, for every sample of
    ``args.gen_subset``, prints a JSON object with the source string, the
    predicted string and their whitespace-token counts.
    """
    assert args.path is not None, '--path required for generation!'
    args.beam = 1
    args.nbest = 1
    args.max_tokens = 10000

    utils.import_user_module(args)

    # Set up the task and load the requested split
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    src_dict = getattr(task, 'source_dictionary', None)
    tgt_dict = task.target_dictionary

    checkpoint_paths = args.path.split(':')
    models, _model_args = checkpoint_utils.load_model_ensemble(
        checkpoint_paths,
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Prepare every ensemble member for generation on the GPU
    for member in models:
        member.make_generation_fast_(
            beamable_mm_beam_size=args.beam,
            need_attn=False,
        )
        member.cuda()

    batches = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[member.max_positions() for member in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    generator = task.build_generator(args)

    with progress_bar.build_progress_bar(args, batches) as bar:
        for sample in bar:
            sample = utils.move_to_cuda(sample)
            if 'net_input' not in sample:
                continue

            # no forced prefix for this decoding run
            hypos = task.inference_step(generator, models, sample, None)
            # computed as in the reference flow; the value is not used below
            _generated = sum(len(h[0]['tokens']) for h in hypos)

            for row, _sample_id in enumerate(sample['id'].tolist()):
                # Strip the padding from the source tokens
                src_tokens = utils.strip_pad(
                    sample['net_input']['src_tokens'][row, :],
                    tgt_dict.pad(),
                )

                src_str = (
                    src_dict.string(src_tokens, args.remove_bpe)
                    if src_dict is not None
                    else ""
                )

                # Only the best hypothesis is post-processed and reported
                best = hypos[row][0]
                _tokens, hypo_str, _alignment = utils.post_process_prediction(
                    hypo_tokens=best['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=best['alignment'],
                    align_dict=None,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )

                record = {
                    'src': src_str,
                    'pred': hypo_str,
                    'src_len': len(src_str.split()),
                    'pred_len': len(hypo_str.split()),
                }
                print(json.dumps(record))
Пример #28
0
def compute_top_k(
    task,
    models,
    dataset,
    k,
    use_cuda,
    max_tokens=None,
    max_sentences=None,
    progress_bar_args=None,
):
    """
    Score ``dataset`` with every model in ``models`` and keep, for each
    target position, the ``k`` largest ensemble-averaged probabilities and
    the token indices they belong to. The kept probabilities are L1-normalized
    over the ``k`` entries.

    Args:
        task: provides ``get_batch_iterator`` used to batch ``dataset``.
        models: ensemble members; each is called on ``sample["net_input"]``
            and must offer ``get_normalized_probs``.
        dataset: dataset to score; ``len(dataset)`` and
            ``dataset.tgt_dict.pad()`` are used.
        k: number of entries kept per target position.
        use_cuda: when true, each batch is moved with ``utils.move_to_cuda``.
        max_tokens: forwarded to ``task.get_batch_iterator``.
        max_sentences: forwarded to ``task.get_batch_iterator``.
        progress_bar_args: when not None, the batch iterator is wrapped in a
            progress bar built from these args.

    Returns: (top_k_scores, top_k_indices)
        Each a NumPy array of size (total_target_tokens, k)
    """
    num_examples = len(dataset)
    scores_by_id = [None] * num_examples
    indices_by_id = [None] * num_examples

    batches = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=max_tokens,
        max_sentences=max_sentences,
    ).next_epoch_itr(shuffle=False)
    if progress_bar_args is not None:
        batches = progress_bar.build_progress_bar(
            args=progress_bar_args,
            iterator=batches,
            prefix="top-k probs eval",
            no_progress_bar="simple",
        )

    for batch in batches:
        ids = batch["id"]
        # Lengths are taken before the optional device transfer so that
        # ``.numpy()`` is called on the tensors as they came from the iterator.
        is_token = (batch["net_input"]["prev_output_tokens"]
                    != dataset.tgt_dict.pad())
        lengths = is_token.sum(axis=1).numpy()
        if use_cuda:
            batch = utils.move_to_cuda(batch)

        # Sum the per-model probabilities in place, then divide by the
        # ensemble size to obtain the average.
        prob_sum = None
        for model in models:
            with torch.no_grad():
                net_output = model(**batch["net_input"])
                probs = model.get_normalized_probs(net_output, log_probs=False)
            if prob_sum is None:
                prob_sum = probs
            else:
                prob_sum.add_(probs)
        prob_sum.div_(len(models))

        best_probs, best_tokens = torch.topk(prob_sum, k=k)
        best_probs = F.normalize(best_probs, p=1, dim=2).cpu()
        best_tokens = best_tokens.cpu()

        # Store each sentence's rows under its dataset id, trimmed to the
        # number of non-pad target positions.
        for row, sentence_id in enumerate(ids):
            n_steps = lengths[row]
            scores_by_id[sentence_id] = best_probs[row][:n_steps].numpy()
            indices_by_id[sentence_id] = best_tokens[row][:n_steps].numpy()

    assert not any(entry is None for entry in
                   scores_by_id), "scores not calculated for all examples!"
    assert not any(entry is None for entry in
                   indices_by_id), "indices not calculated for all examples!"

    top_k_scores = np.concatenate(scores_by_id, axis=0)
    top_k_indices = np.concatenate(indices_by_id, axis=0)
    return top_k_scores, top_k_indices
Пример #29
0
def generate(cfg: UnsupGenerateConfig, models, saved_cfg, use_cuda):
    """Decode the ``gen_subset`` split and accumulate error/length statistics.

    The function sets up the task from ``cfg.fairseq.task``, builds a decoder
    selected by ``cfg.w2l_decoder``, runs every batch through ``gen_hypos``
    and ``process_predictions`` (both defined outside this function), and
    optionally scores the hypotheses with a KenLM model and against a
    Viterbi transcript.

    Args:
        cfg: generation config; fields read here include ``fairseq``,
            ``w2l_decoder``, ``lm_model``, ``results_path``, ``targets``,
            ``viterbi_transcript``, ``lexicon``, ``unit_lm``, ``debug``,
            ``beam`` and ``kaldi_decoder_config``.
        models: models forwarded unchanged to ``gen_hypos``.
        saved_cfg: config whose ``task`` section receives the labels from
            ``cfg`` and is passed to ``task.load_dataset``.
        use_cuda: forwarded unchanged to ``gen_hypos``.

    Returns:
        A ``GenResult`` built positionally from ``count``, ``errs_t``,
        ``gen_timer``, ``lengths_hyp_unit_t``, ``lengths_hyp_t``,
        ``lengths_t``, ``lm_score_sum``, ``num_feats``, ``num_sentences``,
        ``num_symbols``, ``vt_err_t`` and ``vt_length_t``.

    Raises:
        NotImplementedError: if ``cfg.w2l_decoder`` is not one of the
            handled ``DecoderType`` values.
    """
    task = tasks.setup_task(cfg.fairseq.task)
    saved_cfg.task.labels = cfg.fairseq.task.labels
    task.load_dataset(cfg.fairseq.dataset.gen_subset, task_cfg=saved_cfg.task)
    # Set dictionary
    tgt_dict = task.target_dictionary
    logger.info("| {} {} {} examples".format(
        cfg.fairseq.task.data,
        cfg.fairseq.dataset.gen_subset,
        len(task.dataset(cfg.fairseq.dataset.gen_subset)),
    ))
    # Load dataset (possibly sharded)
    itr = get_dataset_itr(cfg, task)
    # Initialize generator
    gen_timer = StopwatchMeter()

    def build_generator(cfg: UnsupGenerateConfig):
        """Return the decoder object matching ``cfg.w2l_decoder``.

        Decoder modules are imported lazily so only the selected backend
        needs to be importable.
        """
        w2l_decoder = cfg.w2l_decoder
        if w2l_decoder == DecoderType.VITERBI:
            from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

            return W2lViterbiDecoder(cfg, task.target_dictionary)
        elif w2l_decoder == DecoderType.KENLM:
            from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

            return W2lKenLMDecoder(cfg, task.target_dictionary)
        elif w2l_decoder == DecoderType.FAIRSEQ:
            from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder

            return W2lFairseqLMDecoder(cfg, task.target_dictionary)
        elif w2l_decoder == DecoderType.KALDI:
            from examples.speech_recognition.kaldi.kaldi_decoder import KaldiDecoder

            assert cfg.kaldi_decoder_config is not None

            return KaldiDecoder(
                cfg.kaldi_decoder_config,
                cfg.beam,
            )
        else:
            # NOTE(review): the message lists only three options although a
            # KALDI branch is handled above.
            raise NotImplementedError(
                "only wav2letter decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment but found "
                + str(w2l_decoder))

    generator = build_generator(cfg)

    kenlm = None
    # NOTE(review): fairseq_lm is never assigned anything but None in this
    # function, so the `elif fairseq_lm is not None` branch below cannot run.
    fairseq_lm = None
    if cfg.lm_model is not None:
        import kenlm

        # The local name is rebound from the module to the loaded model.
        kenlm = kenlm.Model(cfg.lm_model)

    num_sentences = 0
    if cfg.results_path is not None and not os.path.exists(cfg.results_path):
        os.makedirs(cfg.results_path)

    res_files = prepare_result_files(cfg)
    # Running totals accumulated over every decoded sentence.
    errs_t = 0
    lengths_hyp_t = 0
    lengths_hyp_unit_t = 0
    lengths_t = 0
    count = 0
    num_feats = 0
    all_hyp_pieces = []
    all_hyp_words = []

    # Dictionary size excluding "madeup*" symbols and the special symbols.
    num_symbols = (
        len([s for s in tgt_dict.symbols if not s.startswith("madeup")]) -
        tgt_dict.nspecial)
    # Optional reference file "<data>/<gen_subset>.<targets>", one line per
    # sample; silently skipped when the file does not exist.
    targets = None
    if cfg.targets is not None:
        tgt_path = os.path.join(
            cfg.fairseq.task.data,
            cfg.fairseq.dataset.gen_subset + "." + cfg.targets)
        if os.path.exists(tgt_path):
            with open(tgt_path, "r") as f:
                targets = f.read().splitlines()
    # Optional Viterbi transcript, stored as one whitespace-split token list
    # per line.
    viterbi_transcript = None
    if cfg.viterbi_transcript is not None and len(cfg.viterbi_transcript) > 0:
        logger.info(
            f"loading viterbi transcript from {cfg.viterbi_transcript}")
        with open(cfg.viterbi_transcript, "r") as vf:
            viterbi_transcript = vf.readlines()
            viterbi_transcript = [
                v.rstrip().split() for v in viterbi_transcript
            ]

    gen_timer.start()

    start = 0
    end = len(itr)

    # Kaldi path: a first pass collects the hypotheses for every batch, then
    # `itr` is replaced by (hypos, sample) pairs that the main loop resolves
    # with `.result()`.
    hypo_futures = None
    if cfg.w2l_decoder == DecoderType.KALDI:
        logger.info("Extracting features")
        hypo_futures = []
        samples = []
        with progress_bar.build_progress_bar(cfg.fairseq.common, itr) as t:
            for i, sample in enumerate(t):
                if "net_input" not in sample or i < start or i >= end:
                    continue
                if "padding_mask" not in sample["net_input"]:
                    sample["net_input"]["padding_mask"] = None

                hypos, num_feats = gen_hypos(generator, models, num_feats,
                                             sample, task, use_cuda)
                hypo_futures.append(hypos)
                samples.append(sample)
                # In debug mode only the first usable batch is processed.
                if cfg.debug:
                    break
        itr = list(zip(hypo_futures, samples))
        start = 0
        end = len(itr)
        logger.info("Finished extracting features")

    with progress_bar.build_progress_bar(cfg.fairseq.common, itr) as t:
        for i, sample in enumerate(t):
            if i < start or i >= end:
                continue

            if hypo_futures is not None:
                # Items are the (hypos, sample) pairs built in the first pass.
                hypos, sample = sample
                hypos = [h.result() for h in hypos]
            else:
                if "net_input" not in sample:
                    continue

                hypos, num_feats = gen_hypos(generator, models, num_feats,
                                             sample, task, use_cuda)

            # NOTE(review): this inner `i` shadows the batch index above; the
            # outer enumerate rebinds it on the next iteration, so the range
            # check is unaffected.
            for i, sample_id in enumerate(sample["id"].tolist()):
                # Reference comes from the targets file when loaded, else
                # from the sample ("target_label" takes precedence over
                # "target"), else there is no reference.
                if targets is not None:
                    target_tokens = targets[sample_id]
                elif "target" in sample or "target_label" in sample:
                    toks = (sample["target"][i, :] if "target_label"
                            not in sample else sample["target_label"][i, :])

                    target_tokens = utils.strip_pad(
                        toks, tgt_dict.pad()).int().cpu()
                else:
                    target_tokens = None

                # Process top predictions
                (
                    errs,
                    length_hyp,
                    length,
                    hyp_pieces,
                    hyp_words,
                ) = process_predictions(
                    cfg,
                    hypos[i],
                    tgt_dict,
                    target_tokens,
                    res_files,
                )
                errs_t += errs
                lengths_hyp_t += length_hyp
                # Unit length falls back to len(hyp_words) when hyp_pieces is
                # empty.
                lengths_hyp_unit_t += (len(hyp_pieces) if len(hyp_pieces) > 0
                                       else len(hyp_words))
                lengths_t += length
                count += 1
                all_hyp_pieces.append(hyp_pieces)
                all_hyp_words.append(hyp_words)

            num_sentences += (sample["nsentences"] if "nsentences" in sample
                              else sample["id"].numel())

    # Language-model score over all hypotheses: pieces when cfg.unit_lm is
    # set, otherwise words.
    lm_score_sum = 0
    if kenlm is not None:

        if cfg.unit_lm:
            lm_score_sum = sum(kenlm.score(w) for w in all_hyp_pieces)
        else:
            lm_score_sum = sum(kenlm.score(w) for w in all_hyp_words)
    elif fairseq_lm is not None:
        lm_score_sum = sum(
            fairseq_lm.score([h.split() for h in all_hyp_words])[0])

    # Edit distance between the Viterbi transcript and the hypotheses.
    vt_err_t = 0
    vt_length_t = 0
    if viterbi_transcript is not None:
        unit_hyps = []
        if cfg.targets is not None and cfg.lexicon is not None:
            # Lexicon lines are "<word> <unit> <unit> ..."; hypotheses are
            # expanded word by word into their units.
            lex = {}
            with open(cfg.lexicon, "r") as lf:
                for line in lf:
                    items = line.rstrip().split()
                    lex[items[0]] = items[1:]
            # presumably hyp_pieces entries are strings (`.split()` is
            # called on them) -- verify against process_predictions
            for h in all_hyp_pieces:
                hyp_ws = []
                for w in h.split():
                    assert w in lex, w
                    hyp_ws.extend(lex[w])
                unit_hyps.append(hyp_ws)

        else:
            unit_hyps.extend([h.split() for h in all_hyp_words])

        # zip() stops at the shorter of the two sequences.
        vt_err_t = sum(
            editdistance.eval(vt, h)
            for vt, h in zip(viterbi_transcript, unit_hyps))

        vt_length_t = sum(len(h) for h in viterbi_transcript)

    if res_files is not None:
        for r in res_files.values():
            r.close()

    gen_timer.stop(lengths_hyp_t)

    return GenResult(
        count,
        errs_t,
        gen_timer,
        lengths_hyp_unit_t,
        lengths_hyp_t,
        lengths_t,
        lm_score_sum,
        num_feats,
        num_sentences,
        num_symbols,
        vt_err_t,
        vt_length_t,
    )
Пример #30
0
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch.

    Each grouped batch is passed to ``trainer.train_step``. After every
    successful step the current training stats are sent to ``wandb`` and to
    the progress bar. When ``args.save_interval_updates`` is positive and
    validation is not disabled, the model is validated and checkpointed
    every that many updates. The loop stops once ``args.max_update`` updates
    have been made.

    Args:
        args: parsed options; reads ``update_freq``, ``fix_batches_to_gpus``,
            ``curriculum``, ``valid_subset``, ``max_update``,
            ``disable_validation`` and ``save_interval_updates``.
        trainer: provides ``train_step``, ``get_meter`` and
            ``get_num_updates``.
        task: forwarded to ``validate``.
        epoch_itr: epoch iterator providing ``epoch``, ``next_epoch_itr`` and
            ``iterations_in_epoch``.
    """
    # Imported once per call instead of on every training step. The imports
    # stay function-local so that ``AverageMeter`` (captured by the
    # ``extra_meters`` factory below) is bound before the loop starts.
    from numbers import Number
    from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter

    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)

        # wandb receives plain numbers only: meters are reduced to their
        # average (or total for stopwatches); other value types are skipped.
        wandb_stats = {}
        for k, stat in stats.items():
            if isinstance(stat, Number):
                wandb_stats[k] = stat
            elif isinstance(stat, (AverageMeter, TimeMeter)):
                wandb_stats[k] = stat.avg
            elif isinstance(stat, StopwatchMeter):
                wandb_stats[k] = stat.sum
        wandb.log(wandb_stats)

        for k, v in log_output.items():
            if k in [
                    'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size'
            ]:
                continue  # these are already logged above
            # Loss-like values and accuracy are weighted by the sample size.
            if 'loss' in k or k == 'accuracy':
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
            'train_loss',
            'train_nll_loss',
            'wps',
            'ups',
            'wpb',
            'bsz',
            'gnorm',
            'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
def train(args, trainer, task, epoch_itr):
    """Run one epoch of training.

    Batches are grouped according to the update frequency configured for the
    current epoch and passed to ``trainer.train_step``. Stats are logged
    after every successful step. When ``args.save_interval_updates`` is
    positive, the first validation subset is evaluated and a checkpoint is
    saved every that many updates. The epoch ends early once
    ``args.max_update`` updates have been made.
    """
    # Update frequency for this epoch; epochs beyond the end of the schedule
    # reuse its last entry.
    freq_schedule = args.update_freq
    freq_index = epoch_itr.epoch - 1 if epoch_itr.epoch <= len(freq_schedule) else -1
    update_freq = freq_schedule[freq_index]

    # Build the grouped batch iterator and its progress bar.
    grouped = iterators.GroupedIterator(
        epoch_itr.next_epoch_itr(fix_batches_to_gpus=args.fix_batches_to_gpus),
        update_freq,
    )
    progress = progress_bar.build_progress_bar(
        args, grouped, epoch_itr.epoch, no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    # Keys the trainer's own meters already cover.
    already_logged = ('loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size')

    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # Mid-epoch stats: trainer meters plus running averages of any
        # additional values the step reported.
        stats = get_training_stats(trainer)
        for name, value in log_output.items():
            if name in already_logged:
                continue
            meter = extra_meters[name]
            if 'loss' in name:
                meter.update(value, log_output['sample_size'])
            else:
                meter.update(value)
            stats[name] = meter.avg
        progress.log(stats)

        # The first mini-batch is excluded from the words-per-second figure.
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        checkpoint_due = (
            args.save_interval_updates > 0
            and num_updates > 0
            and num_updates % args.save_interval_updates == 0
        )
        if checkpoint_due:
            valid_losses = validate(args, trainer, task, epoch_itr, [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # End-of-epoch summary.
    stats = get_training_stats(trainer)
    for name, meter in extra_meters.items():
        stats[name] = meter.avg
    progress.print(stats)

    # Clear the training meters so the next epoch starts fresh.
    meter_names = (
        'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip',
    )
    for name in meter_names:
        meter = trainer.get_meter(name)
        if meter is not None:
            meter.reset()
Пример #32
0
def main(parsed_args):
    """Score a dataset split with a model ensemble and print loss/perplexity.

    Loads the model(s) named by ``parsed_args.path`` (colon-separated),
    scores ``args.gen_subset`` with a ``SequenceScorer``, and prints the
    average negative log-likelihood and its exponential. Optionally prints
    per-word scores (``args.output_word_probs``) and aggregated word
    statistics (``args.output_word_stats``).

    Args:
        parsed_args: command-line namespace; ``path`` is required.
    """
    assert parsed_args.path is not None, '--path required for evaluation!'

    print(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    models, args = utils.load_ensemble_for_inference(
        parsed_args.path.split(':'), task)

    # Overlay the command-line values on the args returned with the
    # ensemble, except for the options listed here, which keep the loaded
    # values.
    for arg in vars(parsed_args).keys():
        if arg not in {
                'self_target', 'future_target', 'past_target',
                'tokens_per_sample', 'output_size_dictionary'
        }:
            setattr(args, arg, getattr(parsed_args, arg))
    # Re-create the task from the merged args.
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset,
                                       len(task.dataset(args.gen_subset))))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()
        if hasattr(model, 'set_targets'):
            model.set_targets(args, task)

    assert len(models) > 0

    print('num. model params: {}'.format(
        sum(p.numel() for p in models[0].parameters())))

    # max_tokens falls back to 36000 when unset (or 0).
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]),
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        ignore_invalid_inputs=True,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(models,
                            task.target_dictionary,
                            target_idx=args.target_idx)
    if use_cuda:
        scorer.cuda()

    score_sum = 0.
    count = 0

    # When BPE removal is requested, collect the dictionary indices of all
    # symbols ending with the continuation marker; their scores are folded
    # into the following token below.
    if args.remove_bpe is not None:
        bpe_cont = args.remove_bpe.rstrip()
        bpe_toks = set(i for i in range(len(task.dictionary))
                       if task.dictionary[i].endswith(bpe_cont))
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    with progress_bar.build_progress_bar(args, itr) as t:
        results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        wps_meter = TimeMeter()
        for _, src_tokens, __, hypos in results:
            for hypo in hypos:
                pos_scores = hypo['positional_scores']

                # Move the score of every BPE-continuation token onto the
                # next position and zero it, so a word's score ends up on
                # its final piece. The last token is never merged forward.
                skipped_toks = 0
                if bpe_toks is not None:
                    for i in range(len(hypo['tokens']) - 1):
                        if hypo['tokens'][i].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[i + 1] += pos_scores[i]
                            pos_scores[i] = 0

                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(
                    float('-inf'))
                if inf_scores.any():
                    print(
                        '| Skipping tokens with inf scores:',
                        task.target_dictionary.string(
                            hypo['tokens'][inf_scores.nonzero()]))
                    # NOTE(review): after this filter pos_scores no longer
                    # lines up one-to-one with hypo['tokens'], yet the word
                    # loop below still indexes it by token position --
                    # confirm this is intended.
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += pos_scores.sum().cpu()
                # Merged continuation pieces are not counted as tokens.
                count += pos_scores.numel() - skipped_toks

                if args.output_word_probs or args.output_word_stats:
                    # Rebuild words from pieces: `w` accumulates symbol text,
                    # dropping the continuation marker, until a
                    # non-continuation token closes the word.
                    w = ''
                    word_prob = []
                    is_bpe = False
                    for i in range(len(hypo['tokens'])):
                        w_ind = hypo['tokens'][i].item()
                        w += task.dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[i].item()))

                            # Find the next non-zero score, skipping the
                            # positions zeroed by the merge above; None if
                            # there is none.
                            next_prob = None
                            ind = i + 1
                            while ind < len(hypo['tokens']):
                                if pos_scores[ind].item() != 0:
                                    next_prob = pos_scores[ind]
                                    break
                                ind += 1

                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                                pos_scores[i].item(), next_prob)
                            is_bpe = False
                            w = ''
                    if args.output_word_probs:
                        # NOTE(review): '{:2f}' is a minimum width of 2 with
                        # default precision, not two decimals ('{:.2f}') --
                        # confirm which was meant before changing output.
                        print('\t'.join('{} [{:2f}]'.format(x[0], x[1])
                                        for x in word_prob))

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})

    # Average negative log-likelihood per counted token; perplexity is its
    # exponential.
    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
        gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss,
                                                      np.exp(avg_nll_loss)))

    if args.output_word_stats:
        # Most frequent words first.
        for ws in sorted(word_stats.values(),
                         key=lambda x: x.count,
                         reverse=True):
            print(ws)
Пример #33
0
def main(args):
    """Translate (or score) a dataset split and report BLEU.

    Loads the model ensemble named by ``args.path`` (colon-separated), then
    either generates hypotheses with a ``SequenceGenerator`` or, when
    ``args.score_reference`` is set, scores the references with a
    ``SequenceScorer``. Unless ``args.quiet`` is set, source (S-), target
    (T-), hypothesis (H-), positional score (P-) and optional alignment (A-)
    lines are printed for each sample. The top hypothesis of each sample is
    added to a BLEU scorer when a target is available.

    Args:
        args: parsed command-line namespace; ``path`` is required.
    """
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    # Default batch budget when neither limit was given.
    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Load ensemble
    # NOTE(review): eval() executes args.model_overrides as arbitrary Python;
    # only safe while that string comes from a trusted command line.
    print('| loading model(s) from {}'.format(args.path))
    models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    if args.score_reference:
        translator = SequenceScorer(models, task.target_dictionary)
    else:
        translator = SequenceGenerator(
            models, task.target_dictionary, beam_size=args.beam, minlen=args.min_len,
            stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen, unk_penalty=args.unkpen,
            sampling=args.sampling, sampling_topk=args.sampling_topk, sampling_temperature=args.sampling_temperature,
            diverse_beam_groups=args.diverse_beam_groups, diverse_beam_strength=args.diverse_beam_strength,
        )

    if use_cuda:
        translator.cuda()

    # Generate and compute BLEU score
    scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        if args.score_reference:
            translations = translator.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        else:
            translations = translator.generate_batched_itr(
                t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
                cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size,
            )

        wps_meter = TimeMeter()
        for sample_id, src_tokens, target_tokens, hypos in translations:
            # Process input and ground truth
            has_target = target_tokens is not None
            target_tokens = target_tokens.int().cpu() if has_target else None

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
                target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
            else:
                src_str = src_dict.string(src_tokens, args.remove_bpe)
                # target_str is only defined when a target exists; every
                # later use is guarded by has_target.
                if has_target:
                    target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)

            if not args.quiet:
                print('S-{}\t{}'.format(sample_id, src_str))
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str))

            # Process top predictions
            for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )

                if not args.quiet:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
                    print('P-{}\t{}'.format(
                        sample_id,
                        ' '.join(map(
                            lambda x: '{:.4f}'.format(x),
                            hypo['positional_scores'].tolist(),
                        ))
                    ))

                    if args.print_alignment:
                        print('A-{}\t{}'.format(
                            sample_id,
                            ' '.join(map(lambda x: str(utils.item(x)), alignment))
                        ))

                # Score only the top hypothesis
                if has_target and i == 0:
                    if align_dict is not None or args.remove_bpe is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tokenizer.Tokenizer.tokenize(
                            target_str, tgt_dict, add_if_not_exist=True)
                    scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += 1

    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    # has_target reflects the last sample processed (True if there were none).
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
Пример #34
0
def main(parsed_args):
    """Score a dataset split with a model ensemble and print loss and perplexity.

    The model(s) named by ``parsed_args.path`` (colon-separated) are loaded,
    the ``gen_subset`` split is scored with ``SequenceScorer``, and the
    per-token scores are accumulated into an average negative log-likelihood
    and its exponential (printed as "Loss" and "Perplexity").

    When ``remove_bpe`` is set, the score of every BPE continuation token is
    folded into the following token so that the loss is averaged over whole
    words rather than sub-word pieces. Tokens with infinite scores are
    excluded from the aggregate and from the per-word output.

    Args:
        parsed_args: argument namespace. Attributes read here include
            ``path``, ``cpu``, ``data``, ``gen_subset``, ``fp16``,
            ``max_tokens``, ``max_sentences``, ``num_shards``, ``shard_id``,
            ``remove_bpe``, ``output_word_probs`` and ``output_word_stats``.

    Raises:
        AssertionError: if ``parsed_args.path`` is None or no model is loaded.
    """
    assert parsed_args.path is not None, '--path required for evaluation!'

    print(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    models, args = utils.load_ensemble_for_inference(parsed_args.path.split(':'), task)

    # Values from parsed_args take precedence over the args returned with the models
    args.__dict__.update(parsed_args.__dict__)
    print(args)

    task.args = args

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()

    assert len(models) > 0

    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(*[
            model.max_positions() for model in models
        ]),
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        ignore_invalid_inputs=True,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(models, task.target_dictionary)
    if use_cuda:
        scorer.cuda()

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        # Dictionary indices of tokens that end with the BPE continuation marker
        bpe_cont = args.remove_bpe.rstrip()
        bpe_toks = set(i for i in range(len(task.dictionary)) if task.dictionary[i].endswith(bpe_cont))
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    with progress_bar.build_progress_bar(args, itr) as t:
        results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        wps_meter = TimeMeter()
        for _, src_tokens, __, hypos in results:
            for hypo in hypos:
                pos_scores = hypo['positional_scores']
                # Convert once instead of calling .item() on every token
                tokens = hypo['tokens'].tolist()

                skipped_toks = 0
                if bpe_toks is not None:
                    # Fold each continuation token's score into the next token
                    # so the final piece of a word carries the whole word's score.
                    for i in range(len(tokens) - 1):
                        if tokens[i] in bpe_toks:
                            skipped_toks += 1
                            pos_scores[i + 1] += pos_scores[i]
                            pos_scores[i] = 0

                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                if inf_scores.any():
                    print('| Skipping tokens with inf scores:',
                          task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
                    # Keep pos_scores itself aligned with the tokens: the
                    # per-word loop below indexes it by token position.
                    finite_scores = pos_scores[(~inf_scores).nonzero()]
                else:
                    finite_scores = pos_scores
                score_sum += utils.item(finite_scores.sum())
                count += finite_scores.numel() - skipped_toks

                if args.output_word_probs or args.output_word_stats:
                    inf_flags = inf_scores.tolist()
                    w = ''
                    word_prob = []
                    is_bpe = False
                    for i, w_ind in enumerate(tokens):
                        w += task.dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            # Strip the continuation marker and keep building the word
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            # Words whose score is infinite were excluded from
                            # the aggregate above; leave them out here as well.
                            if not inf_flags[i]:
                                word_score = pos_scores[i].item()
                                word_prob.append((w, word_score))
                                word_stats.setdefault(w, WordStat(w, is_bpe)).add(word_score)
                            is_bpe = False
                            w = ''
                    if args.output_word_probs:
                        # NOTE(review): '{:2f}' means width 2 with default precision;
                        # '{:.2f}' may have been intended. Left as-is to keep the output format.
                        print('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob))

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            print(ws)