Beispiel #1
0
def main(args):
    """Generate hypotheses for ``args.gen_subset`` and score them with BLEU.

    Loads the model ensemble named by ``args.path`` (a ``:``-separated list),
    runs ``task.inference_step`` over the subset and prints ``S-``/``T-``/
    ``H-``/``P-`` lines (plus ``A-``, ``I-`` and ``E-`` lines when the
    matching options are set). Unless ``args.quiet`` is set, it also writes
    ``"<sample_id> <text>"`` lines for hypotheses to ``args.system_out`` and
    for references to ``args.reference``. Both files are opened (and
    truncated) in every case.

    Args:
        args: parsed command-line namespace. ``args.max_tokens`` is set to
            12000 in place when neither it nor ``args.max_sentences`` is
            given.

    Returns:
        The scorer that accumulated the top hypothesis of every sentence
        that had a target.

    Raises:
        AssertionError: if ``args.path`` is None, if ``args.sampling`` is set
            while ``args.nbest != args.beam``, or if ``args.replace_unk`` is
            given without ``args.raw_text``.
    """
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries; a task may raise NotImplementedError for the source
    # dictionary, in which case source strings are left empty below.
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    # NOTE(security): eval() executes whatever expression is passed via
    # --model-overrides. Only use with trusted command lines;
    # ast.literal_eval would be safer if overrides are always plain
    # literals -- confirm before switching.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True

    # Both output files are managed by the same ``with`` as the progress bar,
    # so they are closed even when generation raises part-way through.
    with open(args.system_out, 'w') as system_out, \
            open(args.reference, 'w') as reference, \
            progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample,
                                        prefix_tokens)
            # Only the best hypothesis of each sentence counts towards timing.
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding
                src_tokens = utils.strip_pad(
                    sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(
                        sample['target'][i, :], tgt_dict.pad()).int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(
                        args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(
                        args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)

                if not args.quiet:
                    if src_dict is not None:
                        print('S-{}\t{}'.format(sample_id, src_str))
                    if has_target:
                        reference.write(f'{sample_id} {target_str}\n')
                        print('T-{}\t{}'.format(sample_id, target_str))

                # Process top predictions
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'],
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    if not args.quiet:
                        system_out.write(f'{sample_id} {hypo_str}\n')
                        print('H-{}\t{}\t{}'.format(sample_id, hypo['score'],
                                                    hypo_str))
                        print('P-{}\t{}'.format(
                            sample_id, ' '.join(
                                map(
                                    lambda x: '{:.4f}'.format(x),
                                    hypo['positional_scores'].tolist(),
                                ))))

                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id, ' '.join([
                                    '{}-{}'.format(src_idx, tgt_idx)
                                    for src_idx, tgt_idx in alignment
                                ])))

                        if args.print_step:
                            print('I-{}\t{}'.format(sample_id, hypo['steps']))

                        if getattr(args, 'retain_iter_history', False):
                            print("\n".join([
                                'E-{}_{}\t{}'.format(
                                    sample_id, step,
                                    utils.post_process_prediction(
                                        h['tokens'].int().cpu(), src_str, None,
                                        None, tgt_dict, None)[1])
                                for step, h in enumerate(hypo['history'])
                            ]))

                    # Score only the top hypothesis
                    if has_target and j == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tgt_dict.encode_line(
                                target_str, add_if_not_exist=True)
                        # String-based scorers take raw text; the token
                        # scorer takes id tensors.
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    print(
        '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset,
                                                      args.beam,
                                                      scorer.result_string()))

    return scorer
Beispiel #2
0
def main():
    """Parse the training command line, load the data and run training.

    Builds the model and criterion, starts a ``MultiprocessingTrainer``,
    resumes from ``<save_dir>/<restore_file>`` when the trainer returns
    saved state for it, then alternates training and validation epochs
    until the learning rate is no longer above ``args.min_lr`` or the epoch
    counter passes ``args.max_epoch``.

    Raises:
        NotImplementedError: if CUDA is not available.
    """
    parser = options.get_parser('Trainer')
    dataset_args = options.add_dataset_args(parser)
    dataset_args.add_argument('--max-tokens', default=6000, type=int, metavar='N',
                              help='maximum number of tokens in a batch')
    dataset_args.add_argument('--max-sentences', type=int, metavar='N',
                              help='maximum number of sentences in a batch')
    dataset_args.add_argument('--train-subset', default='train', metavar='SPLIT',
                              choices=['train', 'valid', 'test'],
                              help='data subset to use for training (train, valid, test)')
    dataset_args.add_argument('--valid-subset', default='valid', metavar='SPLIT',
                              help='comma separated list of data subsets '
                                   ' to use for validation (train, valid, valid1,test, test1)')
    options.add_optimization_args(parser)
    options.add_checkpoint_args(parser)
    options.add_model_args(parser)

    args = utils.parse_args_and_arch(parser)

    if args.no_progress_bar and args.log_format is None:
        args.log_format = 'simple'

    # exist_ok avoids the check-then-create race when several processes
    # start against the same save directory at once.
    os.makedirs(args.save_dir, exist_ok=True)
    torch.manual_seed(args.seed)

    # Load dataset
    # NOTE(review): the split list is hard-coded; --train-subset is not
    # consulted in this function -- confirm that is intended.
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(args.data, splits, args.source_lang, args.target_lang)
    else:
        dataset = data.load_raw_text_dataset(args.data, splits, args.source_lang, args.target_lang)
    if args.source_lang is None or args.target_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.source_lang, args.target_lang = dataset.src, dataset.dst

    print(args)
    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
    for split in splits:
        print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    num_gpus = torch.cuda.device_count()

    print('| using {} GPUs (with max tokens per GPU = {} and max sentences per GPU = {})'.format(
        num_gpus, args.max_tokens, args.max_sentences))

    # Build model and criterion
    model = utils.build_model(args, dataset.src_dict, dataset.dst_dict)
    criterion = utils.build_criterion(args, dataset.src_dict, dataset.dst_dict)
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))

    # The max number of positions can be different for train and valid
    # e.g., RNNs may support more positions at test time than seen in training
    max_positions_train = (args.max_source_positions, args.max_target_positions)
    max_positions_valid = (
        min(args.max_source_positions, model.max_encoder_positions()),
        min(args.max_target_positions, model.max_decoder_positions())
    )

    # Start multiprocessing
    trainer = MultiprocessingTrainer(args, model, criterion)

    # Load the latest checkpoint if one is available
    checkpoint_path = os.path.join(args.save_dir, args.restore_file)
    extra_state = trainer.load_checkpoint(checkpoint_path)
    if extra_state is not None:
        epoch = extra_state['epoch']
        batch_offset = extra_state['batch_offset']
        print('| loaded checkpoint {} (epoch {})'.format(checkpoint_path, epoch))
        # A zero offset presumably means the saved epoch had finished, so
        # training resumes with the following epoch.
        if batch_offset == 0:
            epoch += 1
    else:
        epoch, batch_offset = 1, 0

    # Train until the learning rate gets too small
    val_loss = None
    max_epoch = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    while lr > args.min_lr and epoch <= max_epoch:
        # train for one epoch
        train(args, epoch, batch_offset, trainer, dataset, max_positions_train, num_gpus)

        # evaluate on validate set
        for k, subset in enumerate(args.valid_subset.split(',')):
            val_loss = validate(args, epoch, trainer, dataset, max_positions_valid, subset, num_gpus)
            if k == 0:
                if not args.no_save:
                    # save checkpoint
                    save_checkpoint(trainer, args, epoch, 0, val_loss)
                # only use first validation loss to update the learning schedule
                lr = trainer.lr_step(val_loss, epoch)

        epoch += 1
        # Only the resumed epoch starts mid-way; later epochs start at batch 0.
        batch_offset = 0
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))

    # Stop multiprocessing
    trainer.stop()
def main():
    """Command-line entry point: translate one data split and report BLEU.

    Parses the generation options, loads the requested split (as raw text
    when ``--replace-unk`` is given, otherwise via ``data.load_dataset``),
    builds a ``SequenceGenerator`` over the model ensemble, prints ``S-``/
    ``T-``/``H-``/``A-`` lines unless ``--quiet`` is set, and scores the top
    hypothesis of each sentence against its reference.
    """
    # ---- command line -----------------------------------------------------
    parser = options.get_parser('Generation')
    parser.add_argument('--path', metavar='FILE', required=True,
                        action='append', help='path(s) to model file(s)')
    data_group = options.add_dataset_args(parser)
    data_group.add_argument('--batch-size', default=32, type=int,
                            metavar='N', help='batch size')
    data_group.add_argument(
        '--gen-subset', default='test', metavar='SPLIT',
        help='data subset to generate (train, valid, test)')
    options.add_generation_args(parser)

    args = parser.parse_args()
    if args.no_progress_bar and args.log_format is None:
        args.log_format = 'none'
    print(args)

    on_gpu = torch.cuda.is_available() and not args.cpu
    subset = args.gen_subset

    # ---- data -------------------------------------------------------------
    # Unknown-word replacement needs the original text, hence the raw loader.
    if args.replace_unk is None:
        load_fn = data.load_dataset
    else:
        load_fn = data.load_raw_text_dataset
    dataset = load_fn(args.data, [subset], args.source_lang, args.target_lang)
    if args.source_lang is None or args.target_lang is None:
        # keep the inferred language pair on args
        args.source_lang, args.target_lang = dataset.src, dataset.dst

    # ---- models -----------------------------------------------------------
    print('| loading model(s) from {}'.format(', '.join(args.path)))
    models, _ = utils.load_ensemble_for_inference(
        args.path, dataset.src_dict, dataset.dst_dict)

    for lang, vocab in ((dataset.src, dataset.src_dict),
                        (dataset.dst, dataset.dst_dict)):
        print('| [{}] dictionary: {} types'.format(lang, len(vocab)))
    print('| {} {} {} examples'.format(args.data, subset,
                                       len(dataset.splits[subset])))

    beamable_size = None if args.no_beamable_mm else args.beam
    for member in models:
        member.make_generation_fast_(beamable_mm_beam_size=beamable_size)

    generator = SequenceGenerator(
        models,
        beam_size=args.beam,
        stop_early=(not args.no_early_stop),
        normalize_scores=(not args.unnormalized),
        len_penalty=args.lenpen,
        unk_penalty=args.unkpen,
    )
    if on_gpu:
        generator.cuda()

    # None when no unknown-word replacement is requested, empty when no
    # alignment dictionary path was supplied.
    align_dict = utils.load_align_dict(args.replace_unk)

    # ---- generation and scoring --------------------------------------------
    dst_dict = dataset.dst_dict
    bleu_scorer = bleu.Scorer(dst_dict.pad(), dst_dict.eos(), dst_dict.unk())
    shortest_encoder = min(member.max_encoder_positions() for member in models)
    loader = dataset.eval_dataloader(
        subset,
        max_sentences=args.batch_size,
        max_positions=shortest_encoder,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
    )

    sentence_count = 0
    with utils.build_progress_bar(args, loader) as progress:
        wps_meter = TimeMeter()
        gen_timer = StopwatchMeter()
        translations = generator.generate_batched_itr(
            progress,
            maxlen_a=args.max_len_a,
            maxlen_b=args.max_len_b,
            cuda_device=0 if on_gpu else None,
            timer=gen_timer,
        )
        for sample_id, src_tokens, target_tokens, hypos in translations:
            target_tokens = target_tokens.int().cpu()

            # Recover the original text when aligning, else detokenize ids.
            if align_dict is None:
                src_str = dataset.src_dict.string(src_tokens, args.remove_bpe)
                target_str = dst_dict.string(target_tokens, args.remove_bpe,
                                             escape_unk=True)
            else:
                split = dataset.splits[subset]
                src_str = split.src.get_original_text(sample_id)
                target_str = split.dst.get_original_text(sample_id)

            if not args.quiet:
                print('S-{}\t{}'.format(sample_id, src_str))
                print('T-{}\t{}'.format(sample_id, target_str))

            keep = min(len(hypos), args.nbest)
            for rank, hypo in enumerate(hypos[:keep]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu(),
                    align_dict=align_dict,
                    dst_dict=dst_dict,
                    remove_bpe=args.remove_bpe)

                if not args.quiet:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'],
                                                hypo_str))
                    print('A-{}\t{}'.format(sample_id,
                                            ' '.join(map(str, alignment))))

                # Only the best hypothesis contributes to BLEU.
                if rank != 0:
                    continue
                if align_dict is not None or args.remove_bpe is not None:
                    # Re-tokenize the reference so it matches text that had
                    # unk replacement and/or BPE removal applied.
                    target_tokens = tokenizer.Tokenizer.tokenize(
                        target_str, dst_dict, add_if_not_exist=True)
                bleu_scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(src_tokens.size(0))
            progress.log({'wps': round(wps_meter.avg)})
            sentence_count += 1

    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} tokens/s)'.
          format(sentence_count, gen_timer.n, gen_timer.sum,
                 1. / gen_timer.avg))
    print('| Generate {} with beam={}: {}'.format(subset, args.beam,
                                                  bleu_scorer.result_string()))
Beispiel #4
0
def main(args, init_distributed=False):
    """Train the model described by ``args`` until a stop condition is met.

    Sets up seeds (and optionally distributed training), loads the
    validation subsets, builds model/criterion/trainer, restores the latest
    checkpoint together with its train iterator, then loops over epochs
    while the learning rate stays above ``args.min_lr`` and neither the
    epoch limit nor the update limit has been reached.

    Args:
        args: parsed command-line namespace.
        init_distributed: when true, call
            ``distributed_utils.distributed_init`` and store the returned
            rank in ``args.distributed_rank``.

    Raises:
        AssertionError: if neither ``args.max_tokens`` nor
            ``args.max_sentences`` is set.
    """
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # CUDA device, RNG seeds and (optionally) the distributed backend.
    cuda_ok = torch.cuda.is_available() and not args.cpu
    if cuda_ok:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    print(args)

    # Task setup, e.g. translation or language modeling.
    task = tasks.setup_task(args)

    # Validation data is loaded now; training data is loaded further down,
    # based on the latest checkpoint.
    for subset_name in args.valid_subset.split(','):
        task.load_dataset(subset_name, combine=False, epoch=0)

    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    total_params = sum(p.numel() for p in model.parameters())
    trained_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad)
    print('| num. model params: {} (num. trained: {})'.format(
        total_params, trained_params))

    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max input frames per GPU = {} and max sentences per GPU = {}'.
          format(args.max_tokens, args.max_sentences))

    # Restore the newest checkpoint, if any, along with its train iterator.
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Some criteria want access to the training targets.
    if hasattr(trainer.criterion, 'set_train_tgt_dataset'):
        train_targets = task.dataset(args.train_subset).tgt
        trainer.criterion.set_train_tgt_dataset(train_targets)

    # A falsy limit means "no limit".
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')
    while (lr > args.min_lr
           and (epoch_itr.epoch < max_epoch
                # the last epoch may still have a pending iterator to finish
                or (epoch_itr.epoch == max_epoch
                    and epoch_itr._next_epoch_itr is not None))
           and trainer.get_num_updates() < max_update):
        # one epoch of training
        train(args, trainer, task, epoch_itr)

        run_validation = (not args.disable_validation
                          and epoch_itr.epoch % args.validate_interval == 0)
        valid_losses = (
            validate(args, trainer, task, epoch_itr, valid_subsets)
            if run_validation else [None]
        )
        first_loss = valid_losses[0]

        # the learning rate follows the first validation result only
        lr = trainer.lr_step(epoch_itr.epoch, first_loss)

        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             first_loss)

        # With more than one feature file the data is sharded, so the
        # dataset is reloaded when fetching the next epoch's iterator.
        reload_dataset = len(args.train_feat_files) > 1
        epoch_itr = trainer.get_train_iterator(epoch_itr.epoch,
                                               load_dataset=reload_dataset)
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
Beispiel #5
0
def main(args):
    """Train on GPU until the LR, epoch, update or BLEU stop condition hits.

    Requires CUDA. Builds the task, model, criterion and an FP16 or regular
    trainer, warms the allocator with a dummy batch, then trains epoch by
    epoch, optionally scoring BLEU after each one and checkpointing whenever
    the BLEU improves or the save interval comes up.

    Args:
        args: parsed command-line namespace. ``args.max_tokens`` defaults to
            6000 and ``args.remove_bpe`` to ``'@@ '`` when BLEU evaluation
            is requested without it.

    Raises:
        NotImplementedError: if CUDA is not available.
        ValueError: if both ``args.fp16`` and ``args.amp`` are set.
    """
    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)
    if args.distributed_world_size > 1:
        assert (torch.distributed.is_initialized())
        # A one-element broadcast followed by a sync, before any real work.
        torch.distributed.broadcast(torch.tensor([1], device="cuda"), 0)
        torch.cuda.synchronize()
    if args.max_tokens is None:
        args.max_tokens = 6000
    print(args)

    # Set device limit 0x05 to 128 and read it back.
    # NOTE(review): 0x05 is presumably cudaLimitMaxL2FetchGranularity --
    # confirm against the CUDA runtime headers.
    limit_ptr = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
    cudart = torch.cuda.cudart()
    cudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
    torch.cuda.cudart().cudaDeviceGetLimit(limit_ptr, ctypes.c_int(0x05))
    torch.manual_seed(args.seed)

    # Task setup, e.g. translation or language modeling.
    task = tasks.setup_task(args)

    load_dataset_splits(task, ['train', 'valid'])

    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    param_count = sum(p.numel() for p in model.parameters())
    print('| num. model params: {}'.format(param_count))

    # Pick the trainer; fp16 and AMP are mutually exclusive.
    if args.fp16:
        if args.amp:
            raise ValueError('Cannot use AMP and fp16 simultaneously')
        trainer = FP16Trainer(args, task, model, criterion)
    else:
        if torch.cuda.get_device_capability(0)[0] >= 7 and not args.amp:
            print(
                '| NOTICE: your device may support faster training with --fp16'
            )
        trainer = Trainer(args, task, model, criterion)

    wants_bleu = args.online_eval or args.target_bleu
    if wants_bleu and not args.remove_bpe:
        args.remove_bpe = '@@ '
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))
    max_positions = trainer.get_model().max_positions()
    # NOTE(review): the training iterator is capped by
    # args.max_sentences_valid while the banner above prints
    # args.max_sentences -- confirm which one is intended.
    epoch_itr = data.EpochBatchIterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=8,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
    )
    load_checkpoint(args, trainer, epoch_itr)

    # Warm the caching allocator with a dummy batch.
    warmup_batch = task.dataset('train').get_dummy_batch(
        args.max_tokens, max_positions)
    trainer.dummy_train_step(warmup_batch)

    # A falsy limit means "no limit".
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    tgt_bleu = args.target_bleu or math.inf
    current_bleu = 0.0
    best_bleu = 0.0
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')

    while (lr >= args.min_lr
           and epoch_itr.epoch < max_epoch
           and trainer.get_num_updates() < max_update
           and current_bleu < tgt_bleu):
        train(args, trainer, task, epoch_itr)
        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)

        # Score BLEU when asked to, or when a target BLEU was given.
        if args.online_eval or tgt_bleu is not math.inf:
            current_bleu, _sacre_bleu = score(args, trainer, task,
                                              epoch_itr, args.gen_subset)
            if current_bleu > best_bleu:
                best_bleu = current_bleu
                save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        # the learning rate follows the first validation loss only
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # NOTE(review): this can save a second time in the same epoch when
        # the BLEU branch above already saved.
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
Beispiel #6
0
def main(args):
    """Train on GPU until the learning rate, epoch or update limit is hit.

    Requires CUDA. Builds the task, model, criterion and an FP16 or regular
    trainer, creates the training batch iterator, restores the latest
    checkpoint (or warms the allocator with a dummy batch when none was
    loaded), then trains epoch by epoch with periodic validation and
    checkpointing.

    Args:
        args: parsed command-line namespace. ``args.max_tokens`` is set to
            6000 in place when it is None.

    Raises:
        NotImplementedError: if CUDA is not available.
    """
    if args.max_tokens is None:
        args.max_tokens = 6000
    print(args)

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load dataset splits
    load_dataset_splits(task, ['train', 'valid'])

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    print('| num. model params: {}'.format(sum(p.numel() for p in model.parameters())))

    # Build trainer
    if args.fp16:
        if torch.cuda.get_device_capability(0)[0] < 7:
            print('| WARNING: your device does NOT support faster training with --fp16,'
                  ' please switch to FP32 which is likely to be faster')
        trainer = FP16Trainer(args, task, model, criterion)
    else:
        if torch.cuda.get_device_capability(0)[0] >= 7:
            print('| NOTICE: your device may support faster training with --fp16')
        trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Initialize dataloader
    # (A leftover IPython debugger breakpoint used to sit here; it stopped
    # every run at this point and required IPython to be installed.)
    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        trainer.get_model().max_positions(),
    )
    epoch_itr = task.get_batch_iterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=8,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
    )

    # Load the latest checkpoint if one is available
    if not load_checkpoint(args, trainer, epoch_itr):
        # Send a dummy batch to warm the caching allocator
        dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions)
        trainer.dummy_train_step(dummy_batch)

    # Train until the learning rate gets too small
    # (a falsy max_epoch / max_update means "no limit")
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    # Placeholder so lr_step and save_checkpoint have a value before the
    # first validation run.
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
    while lr > args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates() < max_update:
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
Beispiel #7
0
def main():
    """Train a translation model and report test scores for several beam sizes.

    Parses the command line, loads the dataset, builds the model and
    criterion, and trains with a ``MultiprocessingTrainer`` until the learning
    rate is no longer above ``args.min_lr`` or ``args.max_epoch`` is exceeded.
    After every epoch each subset in ``--valid-subset`` is validated; only the
    first subset drives checkpointing and the learning-rate schedule. Once
    training ends, every subset in ``--test-subset`` is scored with beam sizes
    1, 5, 10 and 20.

    Raises:
        NotImplementedError: if ``torch.cuda.is_available()`` is false.
    """
    parser = options.get_parser('Trainer')
    dataset_args = options.add_dataset_args(parser)
    dataset_args.add_argument('--max-tokens',
                              default=6000,
                              type=int,
                              metavar='N',
                              help='maximum number of tokens in a batch')
    dataset_args.add_argument(
        '--train-subset',
        default='train',
        metavar='SPLIT',
        choices=['train', 'valid', 'test'],
        help='data subset to use for training (train, valid, test)')
    dataset_args.add_argument(
        '--valid-subset',
        default='valid',
        metavar='SPLIT',
        help='comma separated list of data subsets '
        'to use for validation (train, valid, valid1, test, test1)')
    dataset_args.add_argument('--test-subset',
                              default='test',
                              metavar='SPLIT',
                              help='comma separated list of data subsets '
                              'to use for testing (train, valid, test)')
    options.add_optimization_args(parser)
    options.add_checkpoint_args(parser)
    options.add_model_args(parser)
    options.add_generation_args(
        parser)  # should specify generation parameters!!

    args = utils.parse_args_and_arch(parser)
    print(args)

    # specify visible devices
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices
    print('CUDA_VISIBLE_DEVICES: {}\n'.format(args.cuda_visible_devices))

    if args.no_progress_bar:
        progress_bar.enabled = False
        progress_bar.print_interval = args.log_interval

    # exist_ok avoids the check-then-create race when several jobs start
    # against the same save directory at once.
    os.makedirs(args.save_dir, exist_ok=True)
    torch.manual_seed(args.seed)

    # Load dataset
    dataset = data.load_with_check(args.data, args.source_lang,
                                   args.target_lang)
    if args.source_lang is None or args.target_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.source_lang, args.target_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))
    for split in dataset.splits:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    num_gpus = torch.cuda.device_count()

    print('| using {} GPUs (with max tokens per GPU = {})'.format(
        num_gpus, args.max_tokens))

    # Build model
    print('| model {}'.format(args.arch))
    model = utils.build_model(args, dataset)
    criterion = utils.build_criterion(args, dataset)

    # Start multiprocessing
    trainer = MultiprocessingTrainer(args,
                                     model,
                                     src_dict=dataset.src_dict,
                                     dst_dict=dataset.dst_dict)

    # Load the latest checkpoint if one is available
    epoch, batch_offset = trainer.load_checkpoint(
        os.path.join(args.save_dir, args.restore_file))
    print("batch_offset:" + str(batch_offset))

    # Train until the learning rate gets too small
    val_loss = None
    max_epoch = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    while lr > args.min_lr and epoch <= max_epoch:
        # train for one epoch
        train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus)

        # evaluate on validate set
        for k, subset in enumerate(args.valid_subset.split(',')):
            val_loss = validate(args, epoch, trainer, criterion, dataset,
                                subset, num_gpus)
            if k == 0:
                if not args.no_save:
                    # save checkpoint
                    trainer.save_checkpoint(args, epoch, 0, val_loss)
                # only use first validation loss to update the learning schedule
                lr = trainer.lr_step(val_loss, epoch)

        epoch += 1
        # Only the first epoch after a restore resumes mid-epoch.
        batch_offset = 0
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))

    # Generate on test set and compute BLEU score
    for beam in [1, 5, 10, 20]:
        for subset in args.test_subset.split(','):
            scorer = score_test(args,
                                trainer.get_model(),
                                dataset,
                                subset,
                                beam,
                                cuda_device=(0 if num_gpus > 0 else None))
            print('| Test on {} with beam={}: {}'.format(
                subset, beam, scorer.result_string()))

    # Stop multiprocessing
    trainer.stop()
Beispiel #8
0
def train(
    args,
    extra_state: Dict[str, Any],
    trainer,
    task,
    epoch_itr,
    checkpoint_manager: Optional[checkpoint.CheckpointManager],
    output_queue: Optional[mp_queues.Queue] = None,
    **train_step_kwargs,
):
    """Run the epoch/step training loop until a stop condition is met.

    The outer loop runs while the learning rate is above ``args.min_lr`` and
    ``extra_state["epoch"]`` does not exceed ``args.max_epoch`` (unbounded
    when that option is falsy). It also ends when ``trainer.train_step``
    raises ``FloatingPointError`` or when ``evals.save_and_eval`` returns a
    truthy stop flag.

    Args:
        args: parsed options; this function reads ``max_epoch``, ``min_lr``,
            ``pruning_percentile``, ``warmup_steps``, ``save_interval_updates``,
            ``shrink_lr_no_best_bleu_eval`` and (if present) ``lr_shrink``.
        extra_state: mutable training state. Keys read or written here:
            ``"batch_offset"``, ``"epoch"``, ``"num_iterations"``,
            ``"tune_bleu"`` and ``"tune_eval"``.
        trainer: object providing ``get_lr``, ``train_step``, ``get_meter``,
            ``get_num_updates``, ``lr_step`` and an ``optimizer`` attribute.
        task: forwarded to ``evals.save_and_eval``.
        epoch_itr: forwarded to ``setup_epoch``.
        checkpoint_manager: forwarded to ``evals.save_and_eval``; when truthy,
            ``remove_all_checkpoints()`` is called on it after training.
        output_queue: forwarded to ``update_output``.
        **train_step_kwargs: forwarded to ``trainer.train_step``. If it
            contains ``"augment_adv"``, that entry is overwritten on each
            step with whether ``num_iterations`` exceeds ``args.warmup_steps``.

    Returns:
        None. Results are communicated through ``extra_state`` and
        ``update_output``.
    """
    # offset for current epoch (may be different from checkpoint offset)
    starting_offset = extra_state["batch_offset"]

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    stop_training_mid_epoch = False
    stop_training_end_of_epoch = False

    # Pruning masks are built once up front and re-applied after every step.
    do_prune = args.pruning_percentile > 0
    if do_prune:
        prune_masks = create_prune_masks(args, trainer)
        apply_prune_masks(prune_masks, trainer)

    while lr > args.min_lr and extra_state["epoch"] <= max_epoch:
        # NOTE(review): the string below is a bare expression statement inside
        # the loop body, not a docstring; it has no runtime effect.
        """Train the model for one epoch."""

        itr, progress, extra_meters = setup_epoch(args=args,
                                                  epoch_itr=epoch_itr,
                                                  trainer=trainer)

        # Step indices continue from the restored offset within this epoch.
        for i, samples in enumerate(progress, start=starting_offset):
            # NOTE(review): the return value is discarded here but assigned
            # back to extra_state further down — confirm whether the helper
            # mutates in place.
            clear_per_step_extra_state(extra_state)
            extra_state["num_iterations"] = extra_state.get(
                "num_iterations", 0) + 1
            # NOTE(review): **train_step_kwargs is always a dict, so the
            # `is not None` test can never be false.
            if (train_step_kwargs is not None
                    and "augment_adv" in train_step_kwargs.keys()):
                train_step_kwargs["augment_adv"] = (
                    extra_state["num_iterations"] > args.warmup_steps)
            try:
                log_output = trainer.train_step(samples, **train_step_kwargs)
            # Fairseq's fp16_trainer raises this uncommon error to indicate
            # that we should stop training.
            except FloatingPointError as e:
                print(f"Stopping training due to: {e}.")
                stop_training_mid_epoch = True
                break

            if do_prune:
                apply_prune_masks(prune_masks, trainer)

            if i == starting_offset:
                # ignore the first mini-batch in words-per-second calculation
                trainer.get_meter("wps").reset()

            # Clear any remaining metrics from previous steps. This should already
            # have been done before, but just in case - to make sure we catch
            # any case where extra_case does not get populated correctly.
            extra_state = clear_per_step_extra_state(extra_state)
            extra_state["batch_offset"] = i + 1
            (
                extra_state,
                stop_training_mid_epoch,
                translation_samples,
            ) = evals.save_and_eval(
                args=args,
                trainer=trainer,
                task=task,
                extra_state=extra_state,
                checkpoint_manager=checkpoint_manager,
            )

            # This should come after save_and_eval. Even if log_output is None,
            # meaning that there was an overflow,  We should still run
            # save_and_eval to sync all_reduce and then skip the batch.
            # NOTE(review): this `continue` also skips the
            # `stop_training_mid_epoch` check at the bottom of the loop body,
            # so a stop request returned for a skipped batch is overwritten
            # on the next step — confirm that is intended.
            if log_output is None:
                # This indicates that the batch was skipped, typically
                # because of OOM or FP16 overflow.
                continue

            train_stats = evals.log_mid_epoch_stats(
                trainer=trainer,
                progress=progress,
                extra_meters=extra_meters,
                log_output=log_output,
            )
            extra_state = update_output(
                args=args,
                extra_state=extra_state,
                output_queue=output_queue,
                num_updates=trainer.get_num_updates(),
                train_ppl=train_stats["ppl"],
                # We only report wps at the end of an epoch, since
                # the meter gets reset at the start of every epoch.
                wps=None,
            )

            # Shrink the learning rate on save-interval steps when the
            # "num_since_best" counter under extra_state["tune_bleu"] exceeds
            # args.shrink_lr_no_best_bleu_eval.
            if (hasattr(args, "lr_shrink") and args.save_interval_updates > 0
                    and extra_state["num_iterations"] %
                    args.save_interval_updates == 0
                    and args.shrink_lr_no_best_bleu_eval > 0
                    and extra_state["tune_bleu"]["num_since_best"] >
                    args.shrink_lr_no_best_bleu_eval):
                current_lr = trainer.optimizer.get_lr()
                trainer.optimizer.set_lr(current_lr * args.lr_shrink)
                lr = trainer.optimizer.get_lr()
                print(f"Decayed lr from {current_lr} to {lr}.")

            if stop_training_mid_epoch:
                break

        # The end-of-epoch bookkeeping below also runs after a mid-epoch break.
        # log end-of-epoch stats
        train_stats = evals.log_end_epoch_stats(trainer=trainer,
                                                progress=progress,
                                                extra_meters=extra_meters)

        # batch_offset being None denotes the end of an epoch.
        extra_state["batch_offset"] = None
        (
            extra_state,
            stop_training_end_of_epoch,
            translation_samples,
        ) = evals.save_and_eval(
            args=args,
            trainer=trainer,
            task=task,
            extra_state=extra_state,
            end_of_epoch=True,
            checkpoint_manager=checkpoint_manager,
        )
        extra_state = update_output(
            args=args,
            extra_state=extra_state,
            output_queue=output_queue,
            num_updates=trainer.get_num_updates(),
            train_ppl=train_stats["ppl"],
            wps=train_stats["wps"],
        )

        if stop_training_mid_epoch or stop_training_end_of_epoch:
            break

        # Step the LR schedule with extra_state["tune_eval"]["loss"], then
        # move on to a fresh epoch starting at offset 0.
        lr = trainer.lr_step(extra_state["epoch"],
                             extra_state["tune_eval"]["loss"])
        extra_state["epoch"] += 1
        extra_state["batch_offset"] = 0
        starting_offset = 0

    train_meter.stop()
    print(f"| done training in {train_meter.sum:.1f} seconds")

    # the checkpoint manager may be None
    if checkpoint_manager:
        checkpoint_manager.remove_all_checkpoints()

    print(f"| Best BLEU score of {extra_state['tune_bleu']['best']} was from "
          f"epoch {extra_state['tune_bleu']['best_epoch']}")
Beispiel #9
0
def main(args, init_distributed=False):
    """Train a model, validating and checkpointing both raw and EMA weights.

    Args:
        args: parsed options. ``max_tokens`` is set to 6000 when it is None,
            and ``distributed_rank`` is overwritten when ``init_distributed``
            is true.
        init_distributed: when true, initialise distributed training via
            ``distributed_utils.distributed_init`` before building the model.
    """
    import_user_module(args)

    if args.max_tokens is None:
        args.max_tokens = 6000
    print(args)

    cuda_enabled = torch.cuda.is_available() and not args.cpu
    if cuda_enabled:
        torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Set up the task (e.g. translation, language modeling) and load its data.
    task = tasks.setup_task(args)
    load_dataset_splits(task, ['train', 'valid'])

    # Initialise distributed training.
    if init_distributed:
        import socket
        args.distributed_rank = distributed_utils.distributed_init(args)
        print('| initialized host {} as rank {}'.format(
            socket.gethostname(), args.distributed_rank))

    # Build the model and the loss (e.g. CrossEntropyCriterion), then print
    # the architecture and parameter counts.
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    criterion_name = criterion.__class__.__name__
    print('| model {}, criterion {}'.format(args.arch, criterion_name))
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad)
    print('| num. model params: {} (num. trained: {})'.format(
        total_params, trainable_params))

    # Dummy batches serve (i) to warm the caching allocator and (ii) as
    # placeholders for DistributedDataParallel when workers have an uneven
    # number of batches.
    max_positions = utils.resolve_max_positions(
        task.max_positions(), model.max_positions())
    dummy_batch = task.dataset('train').get_dummy_batch(
        args.max_tokens, max_positions)
    oom_batch = task.dataset('train').get_dummy_batch(1, max_positions)

    # Optionally initialise parameters from a pretrained model.
    model.copy_pretrained_params(args)

    # Build the trainer.
    trainer = Trainer(args, task, model, criterion, dummy_batch, oom_batch)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Initialise the dataloader.
    iterator_options = {
        'dataset': task.dataset(args.train_subset),
        'max_tokens': args.max_tokens,
        'max_sentences': args.max_sentences,
        'max_positions': max_positions,
        'ignore_invalid_inputs': True,
        'required_batch_size_multiple': 8,
        'seed': args.seed,
        'num_shards': args.distributed_world_size,
        'shard_id': args.distributed_rank,
        'num_workers': args.num_workers,
    }
    epoch_itr = task.get_batch_iterator(**iterator_options)

    # Resume from the latest checkpoint if there is one; otherwise run a
    # dummy step first.
    restored = load_checkpoint(args, trainer, epoch_itr)
    if not restored:
        trainer.dummy_train_step([dummy_batch])

    # Train until the learning rate gets too small.
    epoch_limit = args.max_epoch or math.inf
    update_limit = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')

    def keep_training(current_lr):
        # Same short-circuit order as a single combined loop condition.
        return (current_lr > args.min_lr
                and epoch_itr.epoch < epoch_limit
                and trainer.get_num_updates() < update_limit)

    while keep_training(lr):
        # Train for one epoch.
        train(args, trainer, task, epoch_itr)

        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(
                args, trainer, task, epoch_itr, valid_subsets)

            # EMA pass: swap in the averaged weights, validate (and maybe
            # checkpoint) with them, then put the live weights back.
            if not args.no_ema:
                live_weights = ema_restore(trainer.ema, trainer.model)
                ema_losses = validate(
                    args, trainer, task, epoch_itr, valid_subsets)
                if epoch_itr.epoch % args.save_interval == 0:
                    save_checkpoint(
                        args, trainer, epoch_itr, ema_losses[0], suffix='ema')
                ema_reverse(trainer.ema, trainer.model, live_weights)

        # Only the first validation loss updates the learning rate.
        first_valid_loss = valid_losses[0]
        lr = trainer.lr_step(epoch_itr.epoch, first_valid_loss)

        # Save a checkpoint of the live weights.
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
Beispiel #10
0
def save_checkpoint(args, trainer, epoch_itr, val_loss):
    """Write the checkpoints due at this point and prune old ones.

    Does nothing when ``args.no_save`` is set or this process is not the
    master. The best validation loss seen so far is kept on the function
    attribute ``save_checkpoint.best``.

    Args:
        args: parsed options (``save_dir``, ``save_interval``,
            ``save_interval_updates``, ``no_epoch_checkpoints``,
            ``keep_interval_updates``, ``keep_last_epochs``).
        trainer: provides ``get_num_updates()`` and
            ``save_checkpoint(path, extra_state)``.
        epoch_itr: provides ``epoch``, ``end_of_epoch()`` and ``state_dict()``.
        val_loss: latest validation loss, or None if none is available.
    """
    if args.no_save:
        return
    if not distributed_utils.is_master(args):
        return

    write_timer = StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    # Decide which checkpoint files are due, in the order they are written.
    epoch_due = (end_of_epoch and not args.no_epoch_checkpoints
                 and epoch % args.save_interval == 0)
    update_due = (not end_of_epoch and args.save_interval_updates > 0
                  and updates % args.save_interval_updates == 0)
    best_due = (val_loss is not None
                and (not hasattr(save_checkpoint, 'best')
                     or val_loss < save_checkpoint.best))
    checkpoint_conds = collections.OrderedDict([
        ('checkpoint{}.pt'.format(epoch), epoch_due),
        ('checkpoint_{}_{}.pt'.format(epoch, updates), update_due),
        ('checkpoint_best.pt', best_due),
        ('checkpoint_last.pt', True),  # always written, listed last
    ])

    # Update the running best only after best_due was computed above.
    prev_best = getattr(save_checkpoint, 'best', val_loss)
    if val_loss is not None:
        save_checkpoint.best = min(val_loss, prev_best)

    extra_state = {
        'train_iterator': epoch_itr.state_dict(),
        'val_loss': val_loss,
    }
    if hasattr(save_checkpoint, 'best'):
        extra_state['best'] = save_checkpoint.best

    targets = []
    for filename, due in checkpoint_conds.items():
        if due:
            targets.append(os.path.join(args.save_dir, filename))

    if targets:
        for target in targets:
            trainer.save_checkpoint(target, extra_state)

        write_timer.stop()
        print(
            '| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)'
            .format(targets[0], epoch, updates, write_timer.sum))

    def remove_beyond(pattern, keep):
        # checkpoint_paths returns matches sorted in descending order, so
        # everything after the first `keep` entries is the oldest.
        matches = utils.checkpoint_paths(args.save_dir, pattern=pattern)
        for stale in matches[keep:]:
            if os.path.lexists(stale):
                os.remove(stale)

    if not end_of_epoch and args.keep_interval_updates > 0:
        remove_beyond(r'checkpoint_\d+_(\d+)\.pt', args.keep_interval_updates)

    if args.keep_last_epochs > 0:
        remove_beyond(r'checkpoint(\d+)\.pt', args.keep_last_epochs)
def main(args):
    """Translate buffered input sentences, each with optional constraints.

    Every input line has the form ``sentence \\t s1|||t1 \\t s2|||t2 ...``; the
    part after ``|||`` of each extra field is collected as a constraint and
    passed to ``make_batches``. For every sentence the source (``S-``),
    hypotheses (``H-``), positional scores (``P-``) and, if requested,
    alignments (``A-``) are printed in input order, followed by a throughput
    summary.

    Args:
        args: parsed options. ``buffer_size`` is raised to at least 1 and
            ``max_sentences`` is set to 1 when neither it nor ``max_tokens``
            is given.
    """
    utils.import_user_module(args)

    if args.buffer_size < 1:
        args.buffer_size = 1
    if args.max_tokens is None and args.max_sentences is None:
        args.max_sentences = 1

    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
        '--max-sentences/--batch-size cannot be larger than --buffer-size'

    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Setup task, e.g., translation
    task = tasks.setup_task(args)

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    # SECURITY NOTE: eval() executes the --model-overrides string as Python.
    # Only pass trusted values; consider ast.literal_eval if overrides are
    # always plain literals.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Initialize generator
    generator = task.build_generator(args)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(args)
    bpe = encoders.build_bpe(args)

    def encode_fn(x):
        # Tokenize first, then apply BPE.
        if tokenizer is not None:
            x = tokenizer.encode(x)
        if bpe is not None:
            x = bpe.encode(x)
        return x

    def decode_fn(x):
        # Undo encode_fn in reverse order: BPE first, then the tokenizer.
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        *[model.max_positions() for model in models]
    )

    if args.buffer_size > 1:
        print('| Sentence buffer size:', args.buffer_size)
    print('| Type the input sentence and press return:')
    start_id = 0
    gen_timer = StopwatchMeter()
    for inputs in buffered_read(args.input, args.buffer_size):
        results = []

        # input is sentence \t s1|||t1 \t s2|||t2 ...
        new_inputs = []
        constraints = []
        for inp in inputs:
            fields = inp.split('\t')
            new_inputs.append(fields[0])
            constraints.append([tup.split('|||')[1] for tup in fields[1:]])

        for batch in make_batches(new_inputs, args, task, max_positions, encode_fn, constraints):
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            tgt_init_tokens = batch.tgt_init_tokens
            tgt_init_lengths = batch.tgt_init_lengths
            if use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()
                tgt_init_tokens = tgt_init_tokens.cuda()
                tgt_init_lengths = tgt_init_lengths.cuda()

            sample = {
                'net_input': {
                    'src_tokens': src_tokens,
                    'src_lengths': src_lengths,
                    'tgt_init_tokens': tgt_init_tokens,
                    'tgt_init_lengths': tgt_init_lengths,
                },
            }
            gen_timer.start()
            translations = task.inference_step(generator, models, sample)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in translations)
            gen_timer.stop(num_generated_tokens)
            for i, (batch_id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
                # NOTE(review): source tokens are stripped with the *target*
                # dictionary's pad index — presumably both dictionaries share
                # it; confirm.
                src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
                results.append((start_id + batch_id, src_tokens_i, hypos))

        # sort output to match input order
        for sent_id, sent_src_tokens, hypos in sorted(results, key=lambda x: x[0]):
            # Default to an empty source string so post-processing below never
            # sees an unbound (or stale, from the previous sentence) value
            # when the task has no source dictionary.
            src_str = ''
            if src_dict is not None:
                src_str = src_dict.string(sent_src_tokens, args.remove_bpe)
                print('S-{}\t{}'.format(sent_id, src_str))

            # Process top predictions
            for hypo in hypos[:min(len(hypos), args.nbest)]:
                _, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )
                hypo_str = decode_fn(hypo_str)
                print('H-{}\t{}\t{}'.format(sent_id, hypo['score'], hypo_str))
                print('P-{}\t{}'.format(
                    sent_id,
                    ' '.join('{:.4f}'.format(x) for x in hypo['positional_scores'].tolist())
                ))
                if args.print_alignment:
                    alignment_str = " ".join("{}-{}".format(src, tgt) for src, tgt in alignment)
                    print('A-{}\t{}'.format(
                        sent_id,
                        alignment_str
                    ))

        # update running id counter
        start_id += len(inputs)

    # With no input (or zero measured time) the rates are undefined; report
    # zeros instead of raising ZeroDivisionError.
    if gen_timer.n > 0 and gen_timer.sum > 0:
        sentences_per_sec = start_id / gen_timer.sum
        tokens_per_sec = 1. / gen_timer.avg
    else:
        sentences_per_sec = 0.
        tokens_per_sec = 0.
    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        start_id, gen_timer.n, gen_timer.sum, sentences_per_sec, tokens_per_sec))
def main(args):
    """Translate sentences read from stdin and log latency/throughput.

    Translations go to stdout; source text, scores, alignments and all
    diagnostics go to stderr so that stdout contains translations only.
    Timing statistics are also sent to ``dllogger``.

    Args:
        args: parsed options. ``interactive`` is set from whether stdin is a
            TTY, ``buffer_size`` is forced to 1 in interactive mode and
            capped at 50000, and ``max_sentences`` defaults to 1 when neither
            it nor ``max_tokens`` is given.
    """
    if not args.no_dllogger:
        setup_logger(args)
    else:
        dllogger.init(backends=[])

    # Named flag just to make the code more understandable.
    args.interactive = sys.stdin.isatty()
    if args.interactive:
        args.buffer_size = 1
    if args.max_tokens is None and args.max_sentences is None:
        args.max_sentences = 1
    if args.buffer_size > 50000:
        print(
            "WARNING: To prevent memory exhaustion buffer size is set to 50000",
            file=sys.stderr)
        args.buffer_size = 50000

    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
        '--max-sentences/--batch-size cannot be larger than --buffer-size'

    print(args, file=sys.stderr)

    use_cuda = torch.cuda.is_available() and not args.cpu

    processing_start = time.time()

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path), file=sys.stderr)
    model_paths = args.path.split(':')
    # SECURITY NOTE: eval() executes the --model-overrides string as Python.
    # Only pass trusted values; consider ast.literal_eval if overrides are
    # always plain literals.
    models, model_args, src_dict, tgt_dict = load_ensemble_for_inference(
        model_paths, model_arg_overrides=eval(args.model_overrides))
    if args.fp16:
        for model in models:
            model.half()

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(need_attn=args.print_alignment)

    # Initialize generator
    translator = SequenceGenerator(
        models,
        tgt_dict.get_metadata(),
        maxlen=args.max_target_positions,
        beam_size=args.beam,
        stop_early=(not args.no_early_stop),
        normalize_scores=(not args.unnormalized),
        len_penalty=args.lenpen,
        unk_penalty=args.unkpen,
        sampling=args.sampling,
        sampling_topk=args.sampling_topk,
        minlen=args.min_len,
        sampling_temperature=args.sampling_temperature)

    if use_cuda:
        translator.cuda()

    # Load BPE codes file. `bpe` stays None when no codes file is given so the
    # later make_batches call does not hit an unbound name.
    # NOTE(review): assumes BPE reads the codes eagerly in its constructor (as
    # subword-nmt's BPE does), so the file can be closed right after — confirm.
    bpe = None
    if args.bpe_codes:
        with open(args.bpe_codes, 'r') as codes:
            bpe = BPE(codes)
    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    def make_result(src_str, hypos):
        """Build a Translation record from the top ``args.nbest`` hypotheses."""
        result = Translation(
            src_str=src_str,
            hypos=[],
            pos_scores=[],
            alignments=[],
        )

        # Process top predictions
        for hypo in hypos[:min(len(hypos), args.nbest)]:
            _, hypo_str, alignment = utils.post_process_prediction(
                hypo_tokens=hypo['tokens'].int().cpu(),
                src_str=src_str,
                alignment=hypo['alignment'].int().cpu()
                if hypo['alignment'] is not None else None,
                align_dict=align_dict,
                tgt_dict=tgt_dict,
                remove_bpe=args.remove_bpe,
            )
            hypo_str = tokenizer.Tokenizer.detokenize(hypo_str, 'de').strip()
            result.hypos.append((hypo['score'], hypo_str))
            result.pos_scores.append('P\t{}'.format(' '.join(
                '{:.4f}'.format(x)
                for x in hypo['positional_scores'].tolist())))
            if args.print_alignment:
                alignment_line = 'A\t{}'.format(' '.join(
                    str(utils.item(x)) for x in alignment))
            else:
                alignment_line = None
            result.alignments.append(alignment_line)
        return result

    gen_timer = StopwatchMeter()

    def process_batch(batch):
        """Translate one batch and return a Translation per source sentence."""
        tokens = batch.tokens
        lengths = batch.lengths

        if use_cuda:
            tokens = tokens.cuda()
            lengths = lengths.cuda()

        translation_start = time.time()
        gen_timer.start()
        translations = translator.generate(
            tokens,
            lengths,
            maxlen=int(args.max_len_a * tokens.size(1) + args.max_len_b),
        )
        gen_timer.stop(sum(len(h[0]['tokens']) for h in translations))
        dllogger.log(step='infer',
                     data={'latency': time.time() - translation_start})

        return [
            make_result(batch.srcs[i], t) for i, t in enumerate(translations)
        ]

    if args.interactive:
        print('| Type the input sentence and press return:')
    for inputs in buffered_read(args.buffer_size):
        indices = []
        results = []
        for batch, batch_indices in make_batches(inputs, args, src_dict,
                                                 args.max_positions, bpe):
            indices.extend(batch_indices)
            results += process_batch(batch)

        # Emit results in the original input order.
        for i in np.argsort(indices):
            result = results[i]
            print(result.src_str, file=sys.stderr)
            for hypo, pos_scores, align in zip(result.hypos, result.pos_scores,
                                               result.alignments):
                print(f'Score {hypo[0]}', file=sys.stderr)
                print(hypo[1])
                print(pos_scores, file=sys.stderr)
                if align is not None:
                    print(align, file=sys.stderr)

    log_dict = {
        'throughput': 1. / gen_timer.avg,
        'latency_avg': sum(gen_timer.intervals) / len(gen_timer.intervals),
        'latency_p90': gen_timer.p(90),
        'latency_p95': gen_timer.p(95),
        'latency_p99': gen_timer.p(99),
        'total_infernece_time': gen_timer.sum,
        'total_run_time': time.time() - processing_start,
    }
    print('Translation time: {} s'.format(log_dict['total_infernece_time']),
          file=sys.stderr)
    print('Model throughput (beam {}): {} tokens/s'.format(
        args.beam, log_dict['throughput']),
          file=sys.stderr)
    # file=sys.stderr belongs to print(), not str.format(); passing it to
    # format() silently sent this report to stdout among the translations.
    print(
        'Latency:\n\tAverage {:.3f}s\n\tp90 {:.3f}s\n\tp95 {:.3f}s\n\tp99 {:.3f}s'
        .format(log_dict['latency_avg'],
                log_dict['latency_p90'],
                log_dict['latency_p95'],
                log_dict['latency_p99']),
        file=sys.stderr)
    print('End to end time: {} s'.format(log_dict['total_run_time']),
          file=sys.stderr)
    dllogger.log(step=[], data=log_dict)
Beispiel #13
0
def _generate_score(models, args, task, dataset, optimize=True):
    """Translate ``dataset`` with ``models`` and accumulate a BLEU scorer.

    Args:
        models: ensemble passed to the sequence generator.
        args: parsed options (generation settings, optional output paths
            ``translation_output_file`` and ``translation_probs_file``).
        task: supplies ``target_dictionary``.
        dataset: dataset to translate; ``len(dataset)`` sizes the outputs.
        optimize: when true, call ``make_generation_fast_`` on every model.

    Returns:
        Tuple ``(scorer, num_sentences, gen_timer, translation_samples)``.
    """
    use_cuda = torch.cuda.is_available() and not args.cpu

    # Announce the ensemble being used.
    if not args.quiet:
        model_list = ", ".join(args.path.split(":"))
        print("| loading model(s) from {}".format(model_list))

    # Optimize ensemble for generation.
    if optimize:
        for ensemble_member in models:
            ensemble_member.make_generation_fast_(
                beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
                need_attn=True,
            )

    translator = build_sequence_generator(args, task, models)
    # Alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary).
    align_dict = utils.load_align_dict(args.replace_unk)

    # One slot per dataset entry, pre-filled with an empty translation and a
    # zero score.
    dataset_size = len(dataset)
    translated_sentences = ["" for _ in range(dataset_size)]
    translated_scores = [0.0 for _ in range(dataset_size)]

    # Generate and compute BLEU score.
    dst_dict = task.target_dictionary
    scorer = bleu.Scorer(dst_dict.pad(), dst_dict.eos(), dst_dict.unk())
    itr = get_eval_itr(args, models, task, dataset)

    num_sentences = 0
    translation_samples = []
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        gen_timer = StopwatchMeter()
        multilingual = pytorch_translate_data.is_multilingual(args)
        translations = translator.generate_batched_itr(
            t,
            maxlen_a=args.max_len_a,
            maxlen_b=args.max_len_b,
            cuda=use_cuda,
            timer=gen_timer,
            prefix_size=1 if multilingual else 0,
        )
        first_best_translations = (
            _iter_first_best_multilingual
            if pytorch_translate_data.is_multilingual(args)
            else _iter_first_best_bilingual
        )
        best_hypotheses = first_best_translations(
            args, task, dataset, translations, align_dict
        )
        for info in best_hypotheses:
            scorer.add(info.target_tokens, info.hypo_tokens)
            sample_id = info.sample_id
            translated_sentences[sample_id] = info.hypo_str
            translated_scores[sample_id] = info.hypo_score
            translation_samples.append(
                collections.OrderedDict(
                    [
                        ("sample_id", info.sample_id.item()),
                        ("src_str", info.src_str),
                        ("target_str", info.target_str),
                        ("hypo_str", info.hypo_str),
                    ]
                )
            )
            wps_meter.update(info.src_tokens.size(0))
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += 1

    # Optionally dump the translations, e.g. for external evaluation.
    if getattr(args, "translation_output_file", False):
        with open(args.translation_output_file, "w") as out_file:
            for sentence in translated_sentences:
                print(sentence, file=out_file)

    # Optionally dump the exponentiated hypothesis scores.
    if getattr(args, "translation_probs_file", False):
        with open(args.translation_probs_file, "w") as out_file:
            for score in translated_scores:
                print(np.exp(score), file=out_file)

    return scorer, num_sentences, gen_timer, translation_samples
Beispiel #14
0
def main(args, init_distributed=False):
    """Run the training loop with validation, downstream evaluation and checkpointing.

    Sets up the task, model, criterion and trainer, restores the latest
    checkpoint, then alternates training epochs with validation, optional
    downstream evaluation, learning-rate stepping, checkpoint saving and an
    early-stopping check.

    Args:
        args: parsed option namespace (device, seed, validation subsets,
            intervals, stopping criteria, downstream-evaluation settings).
        init_distributed: when true, ``distributed_utils.distributed_init``
            is called and its result stored in ``args.distributed_rank``.
    """
    utils.import_user_module(args)

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    logger.info(args)

    # Setup tasks, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid datasets (we load training data below, based on the latest checkpoint)
    valid_subsets = args.valid_subset.split(',')
    for valid_sub_split in valid_subsets:
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build models
    model = task.build_model(args)

    # Build criterions
    criterion = task.build_criterion(args)

    logger.info(model)
    logger.info('model {}, criterion {}'.format(args.arch,
                                                criterion.__class__.__name__))
    logger.info('num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    logger.info('training on {} GPUs'.format(args.distributed_world_size))
    logger.info(
        'max tokens per GPU = {} and max sentences per GPU = {}'.format(
            args.max_tokens,
            args.max_sentences,
        ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    if distributed_utils.is_master(args) and not args.debug:
        initialize_neptune(trainer, extra_state, args)

    # Always bind downstream_dict: previously it was only created when the
    # configuration was non-empty, yet later code keyed solely on
    # ``eval_downstream`` and could hit a NameError.  An empty dict now simply
    # disables downstream evaluation.
    downstream_dict = {}
    if getattr(args, 'eval_downstream',
               None) and len(args.downstream_dict) > 0:
        for downstream_name, downstream_kwargs in args.downstream_dict.items():
            downstream_dict[downstream_name] = create_downstream_dict(
                args, downstream_name, downstream_kwargs, model)

    # Move model and criterion to gpu
    if torch.cuda.is_available() and not args.cpu:
        model.to('cuda:{}'.format(args.device_id))
        criterion.to('cuda:{}'.format(args.device_id))

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf

    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    if args.validate_before_training and extra_state is None:
        # We want to make sure we do validate_before_training
        # only when we start the training from scratch (thus, extra_state is None).
        # Here, we assert that indeed the training has just started
        # and training epoch is equal to one.
        assert epoch_itr.epoch == 1
        valid_losses = validate(args, trainer, task, 0, valid_subsets)
        if downstream_dict:
            run_downstream(args, downstream_dict, model, criterion, 0,
                           trainer.get_num_updates())

    # ``lr`` may be a scalar or a numpy array (one value per parameter group);
    # in the array case every entry must stay above ``min_lr``.
    while (not args.disable_training
           and ((isinstance(lr, np.ndarray) and all(lr > args.min_lr)) or
                (not isinstance(lr, np.ndarray) and lr > args.min_lr))
           and epoch_itr.next_epoch_idx <= max_epoch
           and trainer.get_num_updates() < max_update):
        train(args, trainer, task, epoch_itr)

        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:

            # validate on task validation set
            valid_losses = validate(args, trainer, task, epoch_itr.epoch,
                                    valid_subsets)

            # evaluate on downstream tasks
            if downstream_dict:
                run_downstream(args, downstream_dict, model, criterion,
                               epoch_itr.epoch, trainer.get_num_updates())
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:

            save_extra_state = {}
            # Persist the experiment id (when there is one) alongside the checkpoint.
            if get_experiment_id():
                save_extra_state['neptune_id'] = get_experiment_id()
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0],
                            save_extra_state)

        # early stop
        if should_stop_early(args, valid_losses[0]):
            logger.info(
                'early stop since valid performance hasn\'t improved for last {} runs'
                .format(args.patience))
            break

        reload_dataset = getattr(args, 'reload', False)
        # sharded data: get train iterator for next epoch
        epoch_itr = trainer.get_train_iterator(epoch_itr.next_epoch_idx,
                                               load_dataset=reload_dataset)

    train_meter.stop()
    logger.info('done training in {:.1f} seconds'.format(train_meter.sum))
Beispiel #15
0
def main(args):
    """Train a model on GPU, validating and checkpointing at fixed epoch intervals.

    Args:
        args: parsed option namespace.  ``args.max_tokens`` is defaulted to
            6000 when unset; ``args.outindices`` may name a text file with one
            integer example index per line to be excluded from training.

    Raises:
        NotImplementedError: if CUDA is not available.
    """
    if args.max_tokens is None:
        args.max_tokens = 6000
    print(args)

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load dataset splits
    load_dataset_splits(args, task, ['train', 'valid'])

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {}'.format(
        sum(p.numel() for p in model.parameters())))

    # Build trainer
    if args.fp16:
        trainer = FP16Trainer(args, task, model, criterion)
    else:
        if torch.cuda.get_device_capability(0)[0] >= 7:
            print(
                '| NOTICE: your device may support faster training with --fp16'
            )
        trainer = Trainer(
            args, task, model, criterion
        )  ##trainer seems to be generic enough to be used for wikicatsum
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Indices of examples to exclude, read one integer per line.
    ignoredIndices = []
    if args.outindices:
        print("*    Filter examples with indices in: " + args.outindices)
        # Use a context manager so the handle is closed even if parsing fails;
        # blank lines are skipped instead of crashing int().
        with open(args.outindices, 'r') as indices_file:
            ignoredIndices.extend(
                int(line) for line in indices_file if line.strip())

    # Initialize dataloader
    max_positions = trainer.get_model().max_positions()

    # NOTE(review): the training iterator is capped with
    # ``args.max_sentences_valid`` while the log line above reports
    # ``args.max_sentences`` -- confirm which limit is intended.
    epoch_itr = data.EpochBatchIterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=8,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        savedir=os.path.join(args.save_dir, ""),
        ignoredIndices=ignoredIndices,
    )
    print("* Epoch batch iterator created nb. {}".format(len(epoch_itr)))

    # Load the latest checkpoint if one is available
    load_checkpoint(args, trainer, epoch_itr)

    # Send a dummy batch to warm the caching allocator
    dummy_batch = task.dataset('train').get_dummy_batch(
        args.max_tokens, max_positions)

    trainer.dummy_train_step(dummy_batch)
    print("Ok dummy step")

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    # Kept across epochs: when validation is skipped for an epoch the most
    # recent losses are reused for the LR step and the checkpoint.
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
    while (lr > args.min_lr and epoch_itr.epoch < max_epoch
           and trainer.get_num_updates() < max_update):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
Beispiel #16
0
def train(args, extra_state, trainer, dataset):
    """Generator that trains epoch by epoch and yields progress after every step.

    After each mini-batch, and once more at the end of each epoch, yields a
    tuple ``(num_updates, stats)`` where ``stats`` is a dict with the keys
    ``"train_ppl"``, ``"tune_ppl"`` and ``"tune_bleu"``.

    Args:
        args: parsed option namespace (epoch/LR limits, validation, save and
            BLEU-evaluation intervals, time limit).
        extra_state: mutable dict carrying ``"epoch"`` and ``"batch_offset"``;
            updated in place (also with ``"val_loss"`` and, when continuous
            averaging is active, ``"param_totals"``/``"param_accum_count"``).
        trainer: object providing ``train_step``, ``get_lr``, ``lr_step``,
            ``get_num_updates``, ``get_meter`` and ``model``.
        dataset: dataset passed through to the epoch setup and
            validation/save helper.
    """
    start_time = time.time()

    # offset for current epoch (may be different from checkpoint offset)
    starting_offset = extra_state["batch_offset"]

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    while lr > args.min_lr and extra_state["epoch"] <= max_epoch:
        # Train the model for one epoch.
        itr, progress, extra_meters = setup_epoch(
            args=args,
            epoch=extra_state["epoch"],
            batch_offset=starting_offset,
            trainer=trainer,
            dataset=dataset,
        )

        # NOTE(review): this is reset every epoch although ``num_updates`` is
        # compared against it below -- confirm whether the BLEU interval is
        # meant to be tracked across epochs.
        last_bleu_eval = 0
        # Bind the flag before the batch loop: if the iterator yields nothing
        # the check after the loop would otherwise raise NameError (or reuse a
        # stale value from the previous epoch).
        stop_training_mid_epoch = False
        for i, sample in enumerate(itr, start=starting_offset):
            log_output = trainer.train_step(sample)

            train_stats = log_mid_epoch_stats(
                trainer=trainer,
                progress=progress,
                extra_meters=extra_meters,
                log_output=log_output,
            )

            # Accumulate running parameter sums once past the configured epoch.
            if (args.continuous_averaging_after_epochs >= 0
                    and extra_state["epoch"] >
                    args.continuous_averaging_after_epochs):
                model_param_dict = trainer.model.state_dict()
                if "param_totals" not in extra_state:
                    extra_state["param_totals"] = {}
                    for name, value in model_param_dict.items():
                        extra_state["param_totals"][name] = value.clone()
                    extra_state["param_accum_count"] = 1
                else:
                    for name, value in model_param_dict.items():
                        extra_state["param_totals"][name] += value
                    extra_state["param_accum_count"] += 1

            if i == starting_offset:
                # ignore the first mini-batch in words-per-second calculation
                trainer.get_meter("wps").reset()

            num_updates = trainer.get_num_updates()
            do_validate = (args.subepoch_validate_interval > 0 and
                           num_updates % args.subepoch_validate_interval == 0)
            do_save = (not args.no_save and args.save_interval > 0
                       and num_updates % args.save_interval == 0)
            do_eval_bleu = (
                # We can only do BLEU eval when we have a new checkpoint to load.
                do_save and
                args.generate_bleu_eval_interval > 0 and num_updates -
                last_bleu_eval >= args.generate_bleu_eval_interval)
            if do_eval_bleu:
                last_bleu_eval = num_updates

            extra_state["batch_offset"] = i + 1

            (_, val_ppl, val_bleu,
             stop_training_mid_epoch) = validate_save_and_evaluate_bleu(
                 args=args,
                 trainer=trainer,
                 dataset=dataset,
                 extra_state=extra_state,
                 do_validate=do_validate,
                 do_save=do_save,
                 do_eval_bleu=do_eval_bleu,
             )
            yield (
                trainer.get_num_updates(),
                {
                    "train_ppl": train_stats["ppl"],
                    "tune_ppl": val_ppl,
                    "tune_bleu": val_bleu,
                },
            )

            if stop_training_mid_epoch:
                break

        # log end-of-epoch stats
        train_stats = log_end_epoch_stats(trainer=trainer,
                                          progress=progress,
                                          extra_meters=extra_meters)

        if stop_training_mid_epoch:
            break

        # batch_offset being None denotes the end of an epoch.
        extra_state["batch_offset"] = None

        (val_loss, val_ppl, val_bleu,
         stop_training_end_of_epoch) = validate_save_and_evaluate_bleu(
             args=args,
             trainer=trainer,
             dataset=dataset,
             extra_state=extra_state,
             do_validate=True,
             do_save=not args.no_save and not args.no_end_of_epoch_checkpoints,
             do_eval_bleu=args.generate_bleu_eval_per_epoch,
         )
        extra_state["val_loss"] = val_loss
        yield (
            trainer.get_num_updates(),
            {
                "train_ppl": train_stats["ppl"],
                "tune_ppl": val_ppl,
                "tune_bleu": val_bleu,
            },
        )
        if stop_training_end_of_epoch:
            break

        lr = trainer.lr_step(extra_state["epoch"], val_loss)
        extra_state["epoch"] += 1
        # Only the first (possibly resumed) epoch starts at a non-zero offset.
        starting_offset = 0

        if is_training_over_time_limit(start_time, args.stop_time_hr):
            break

    train_meter.stop()
    print(f"| done training in {train_meter.sum:.1f} seconds")
    if hasattr(evaluate_bleu, "best") and hasattr(evaluate_bleu, "best_epoch"):
        print(f"| Best BLEU score of {evaluate_bleu.best} was from "
              f"epoch {evaluate_bleu.best_epoch}")
Beispiel #17
0
def main(parsed_args):
    """Score a dataset split with a model ensemble and print loss and perplexity.

    Loads the checkpoints named in ``parsed_args.path`` (``:``-separated),
    overlays the command-line options on the options returned with the
    ensemble, scores every example of ``gen_subset`` and prints the average
    negative log-likelihood and its exponential.  Optionally prints per-word
    scores and aggregated word statistics.

    Args:
        parsed_args: parsed option namespace; ``path`` must be set.
    """
    assert parsed_args.path is not None, '--path required for evaluation!'

    print(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    checkpoint_paths = parsed_args.path.split(':')
    models, args = utils.load_ensemble_for_inference(checkpoint_paths, task)

    # Command-line values override whatever came back with the ensemble.
    args.__dict__.update(parsed_args.__dict__)
    print(args)

    task.args = args

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    num_examples = len(task.dataset(args.gen_subset))
    print('| {} {} {} examples'.format(args.data, args.gen_subset, num_examples))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for member in models:
        member.make_generation_fast_()
        if args.fp16:
            member.half()

    assert len(models) > 0

    per_model_positions = [member.max_positions() for member in models]
    batch_iterator = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(*per_model_positions),
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        ignore_invalid_inputs=True,
    )
    itr = batch_iterator.next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(models, task.target_dictionary)
    if use_cuda:
        scorer.cuda()

    score_sum = 0.
    count = 0

    # Dictionary indices whose symbol ends with the continuation marker; the
    # score of such a token is folded into the token that follows it.
    bpe_toks = None
    bpe_len = 0
    if args.remove_bpe is not None:
        bpe_cont = args.remove_bpe.rstrip()
        bpe_toks = {
            idx for idx in range(len(task.dictionary))
            if task.dictionary[idx].endswith(bpe_cont)
        }
        bpe_len = len(bpe_cont)

    word_stats = {}
    want_word_output = args.output_word_probs or args.output_word_stats

    with progress_bar.build_progress_bar(args, itr) as t:
        results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        wps_meter = TimeMeter()
        for _, src_tokens, __, hypos in results:
            for hypo in hypos:
                tokens = hypo['tokens']
                pos_scores = hypo['positional_scores']

                skipped_toks = 0
                if bpe_toks is not None:
                    # The last token is never merged forward.
                    for pos in range(len(tokens) - 1):
                        if tokens[pos].item() not in bpe_toks:
                            continue
                        skipped_toks += 1
                        pos_scores[pos + 1] += pos_scores[pos]
                        pos_scores[pos] = 0

                is_pos_inf = pos_scores.eq(float('inf'))
                is_neg_inf = pos_scores.eq(float('-inf'))
                inf_scores = is_pos_inf | is_neg_inf
                if inf_scores.any():
                    dropped = tokens[inf_scores.nonzero()]
                    print('| Skipping tokens with inf scores:',
                          task.target_dictionary.string(dropped))
                    finite_positions = (~inf_scores).nonzero()
                    pos_scores = pos_scores[finite_positions]
                score_sum += utils.item(pos_scores.sum())
                count += pos_scores.numel() - skipped_toks

                if not want_word_output:
                    continue

                # Rebuild whole words from sub-word pieces and record the
                # score stored at the position of each word's final piece.
                pending = ''
                word_prob = []
                is_bpe = False
                for pos, token in enumerate(tokens):
                    w_ind = token.item()
                    pending += task.dictionary[w_ind]
                    if bpe_toks is not None and w_ind in bpe_toks:
                        pending = pending[:-bpe_len]
                        is_bpe = True
                        continue
                    word_score = pos_scores[pos].item()
                    word_prob.append((pending, word_score))
                    stat = word_stats.setdefault(pending, WordStat(pending, is_bpe))
                    stat.add(word_score)
                    is_bpe = False
                    pending = ''
                if args.output_word_probs:
                    print('\t'.join('{} [{:2f}]'.format(entry[0], entry[1])
                                    for entry in word_prob))

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))

    if args.output_word_stats:
        ranked = sorted(word_stats.values(), key=lambda stat: stat.count, reverse=True)
        for ws in ranked:
            print(ws)
Beispiel #18
0
def score(args, trainer, task, epoch_itr, subset):
    """Translate ``subset`` with the trainer's model and compute BLEU.

    Two scores are produced: a token-level BLEU accumulated by ``bleu.Scorer``
    over the top hypothesis of each sentence, and a sacreBLEU score computed
    against the reference file ``sacrebleu_reference.de`` in ``args.data``.

    Args:
        args: parsed option namespace (beam/generation settings, data dir,
            ``remove_bpe``, ``sentencepiece``, ``ignore_case``, ...).
        trainer: provides the model through ``get_model()``.
        task: provides datasets and the source/target dictionaries.
        epoch_itr: unused here; kept for interface compatibility.
        subset: name of the dataset split to translate.

    Returns:
        Tuple ``(scorer.score(order=4), sacrebleu_score.score)``.
    """
    begin = time.time()

    if subset not in task.datasets:
        task.load_dataset(subset)

    # Work on copies: generation of translations alters the target dictionary,
    # which would interfere with the rest of training.
    src_dict = deepcopy(task.source_dictionary)
    tgt_dict = deepcopy(task.target_dictionary)

    model = trainer.get_model()

    # Initialize data iterator
    itr = data.EpochBatchIterator(
        dataset=task.dataset(subset),
        max_tokens=None,
        max_sentences=max(
            8, min(math.ceil(1024 / args.distributed_world_size), 128)),
        max_positions=model.max_positions(),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    translator = SequenceGenerator(
        [model],
        tgt_dict,
        beam_size=args.beam,
        stop_early=(not args.no_early_stop),
        normalize_scores=(not args.unnormalized),
        len_penalty=args.lenpen,
        unk_penalty=args.unkpen,
        sampling=args.sampling,
        sampling_topk=args.sampling_topk,
        minlen=args.min_len,
    )
    # Generate and compute BLEU.  A fresh dictionary is used for scoring
    # (named so that it does not shadow the ``dict`` builtin).
    bleu_dict = dictionary.Dictionary()
    scorer = bleu.Scorer(bleu_dict.pad(), bleu_dict.eos(), bleu_dict.unk())
    num_sentences = 0
    has_target = True
    predictions = []
    with progress_bar.build_progress_bar(args, itr) as progress:
        translations = translator.generate_batched_itr(
            progress,
            maxlen_a=args.max_len_a,
            maxlen_b=args.max_len_b,
            cuda=True,
            timer=gen_timer,
            prefix_size=args.prefix_size,
        )

        wps_meter = TimeMeter()
        for sample_id, src_tokens, target_tokens, hypos in translations:
            # Process input and ground truth
            has_target = target_tokens is not None
            target_tokens = target_tokens.int().cpu() if has_target else None

            src_str = src_dict.string(src_tokens, args.remove_bpe)
            if has_target:
                target_str = tgt_dict.string(target_tokens,
                                             args.remove_bpe,
                                             escape_unk=True)

            # Process top predictions
            for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu()
                    if hypo['alignment'] is not None else None,
                    align_dict=None,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe)

                # Score only the top hypothesis
                if has_target and i == 0:
                    if args.sentencepiece:
                        hypo_str = hypo_str.replace(' ', '').replace('▁', ' ')
                        target_str = target_str.replace(' ',
                                                        '').replace('▁', ' ')
                    sys_tok = tokenizer.Tokenizer.tokenize(
                        (hypo_str.lower() if args.ignore_case else hypo_str),
                        bleu_dict)
                    ref_tok = tokenizer.Tokenizer.tokenize(
                        (target_str.lower()
                         if args.ignore_case else target_str), bleu_dict)
                    scorer.add(ref_tok, sys_tok)
                    if not args.sentencepiece:
                        hypo_str = tokenizer.Tokenizer.detokenize(
                            hypo_str, 'de')
                    predictions.append('{}\t{}'.format(sample_id, hypo_str))

            wps_meter.update(src_tokens.size(0))
            progress.log({'wps': round(wps_meter.avg)})
            num_sentences += 1

    if args.distributed_world_size > 1:
        _all_gather_bleu_scorer(scorer)
        predictions = _all_gather_predictions(predictions)

    with open(os.path.join(args.data, 'sacrebleu_reference.de'),
              'r') as reference:
        refs = [reference.readlines()]
    # reducing indexed predictions as strings is more memory efficient than reducing tuples
    indexed_predictions = []
    for item in predictions:
        # Split on the first tab only, so a tab inside the hypothesis text is
        # kept rather than silently truncating the hypothesis.
        fields = item.split('\t', 1)
        indexed_predictions.append((int(fields[0]), fields[1]))
    indexed_predictions.sort(key=lambda tup: tup[0])
    # ``endswith`` (instead of indexing ``[-1]``) keeps an empty hypothesis
    # from raising IndexError.
    predictions = [
        text if text.endswith('\n') else text + '\n'
        for _, text in indexed_predictions
    ]
    sacrebleu_score = sacrebleu.corpus_bleu(predictions,
                                            refs,
                                            lowercase=args.ignore_case)
    print(f'|Detokenized {sacrebleu_score}')
    if gen_timer.sum != 0:
        print(
            '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
            .format(num_sentences, gen_timer.n, gen_timer.sum,
                    num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(subset, args.beam,
                                                      scorer.result_string()))

    print('| Eval completed in: {:.2f}s'.format(time.time() - begin))

    return scorer.score(order=4), sacrebleu_score.score
Beispiel #19
0
def main(args, init_distributed=False):
    """Train until the learning rate, epoch or update limit is reached.

    Builds the task, model, criterion and trainer, restores the latest
    checkpoint, then repeatedly trains one epoch, optionally validates,
    steps the learning rate on the first validation loss and saves a
    checkpoint at the configured interval.

    Args:
        args: parsed option namespace; a batch size must be given through
            ``max_tokens`` or ``max_sentences``.
        init_distributed: when true, ``distributed_utils.distributed_init``
            is called and its result stored in ``args.distributed_rank``.
    """
    utils.import_user_module(args)
    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Device selection, seeding and (optionally) the distributed backend.
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Task setup, e.g. translation or language modeling.
    task = tasks.setup_task(args)

    # Only validation data is loaded here; training data is loaded later,
    # based on the latest checkpoint.
    for split_name in args.valid_subset.split(','):
        task.load_dataset(split_name, combine=False, epoch=0)

    model = task.build_model(args)
    criterion = task.build_criterion(args)

    criterion_name = criterion.__class__.__name__
    print('| model {}, criterion {}'.format(args.arch, criterion_name))
    total_params = sum(p.numel() for p in model.parameters())
    trained_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad)
    print('| num. model params: {} (num. trained: {})'.format(
        total_params, trained_params))

    trainer = Trainer(args, task, model, criterion)

    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Restore the latest checkpoint (if any) together with its train iterator.
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # A falsy limit means "no limit".
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()

    train_meter = StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')

    while True:
        keep_training = (
            lr > args.min_lr
            and epoch_itr.epoch < max_epoch
            and trainer.get_num_updates() < max_update
        )
        if not keep_training:
            break

        # One full pass over the training data.
        train(args, trainer, task, epoch_itr)

        validate_now = (not args.disable_validation
                        and epoch_itr.epoch % args.validate_interval == 0)
        if validate_now:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
        else:
            valid_losses = [None]
        first_valid_loss = valid_losses[0]

        # The learning rate is stepped on the first validation loss only.
        lr = trainer.lr_step(epoch_itr.epoch, first_valid_loss)

        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             first_valid_loss)

        # Sharded data (':' in the data path): fetch the next epoch's iterator.
        if ':' in getattr(args, 'data', ''):
            epoch_itr = trainer.get_train_iterator(epoch_itr.epoch)

    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
Beispiel #20
0
def main(args, init_distributed=False):
    """Train a model, optionally selecting checkpoints by BLEU.

    When ``args.best_checkpoint_metric == 'bleu'`` the test subsets are also
    loaded, a generator is built, and after each validation the BLEU value
    returned by ``multi_gpu_bleu`` replaces the validation loss used for the
    learning-rate step and for checkpoint saving.

    Args:
        args: parsed option namespace; a batch size must be given through
            ``max_tokens`` or ``max_sentences``.
        init_distributed: when true, ``distributed_utils.distributed_init``
            is called and its result stored in ``args.distributed_rank``.
    """
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    use_bleu = args.best_checkpoint_metric == 'bleu'

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)
        if use_bleu:
            # makedirs(exist_ok=True) avoids the check-then-create race of
            # ``exists`` + ``mkdir`` and also works for nested directories.
            os.makedirs(args.eval_dir, exist_ok=True)

    # Print args
    print(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    if use_bleu:
        for test_sub_split in args.test_subset.split(','):
            task.load_dataset(test_sub_split, combine=False, epoch=0)
        # BLEU evaluation forces this BPE marker regardless of the option value.
        args.remove_bpe = '@@ '

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Build generator if evaluate with BLEU score
    generator = task.build_generator(args) if use_bleu else None

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')
    while lr > args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates() < max_update:
        # train for one epoch
        train(args, trainer, task, epoch_itr, generator)

        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            if use_bleu:
                valid_bleu = multi_gpu_bleu(args, trainer, task, generator, trainer._model, epoch_itr, valid_subsets,
                                            pprefix="valid", valid_bleu=-1, log=False)
                # Called for its logging side effect (log=True); the returned
                # value was never used.
                # NOTE(review): the subset list is hard-coded to ['test'] while
                # the subsets loaded above come from args.test_subset -- confirm.
                multi_gpu_bleu(args, trainer, task, generator, trainer._model, epoch_itr, ['test'],
                               pprefix="test", valid_bleu=valid_bleu, log=True)
                # The BLEU value replaces the validation loss from here on.
                valid_losses = [valid_bleu]
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        reload_dataset = ':' in getattr(args, 'data', '')
        # sharded data: get train iterator for next epoch
        epoch_itr = trainer.get_train_iterator(epoch_itr.epoch, load_dataset=reload_dataset)
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
Beispiel #21
0
def _main(args, output_file):
    """Recognize the utterances of ``args.gen_subset`` and write score reports.

    Loads the model ensemble from ``args.path`` (wrapping word-level language
    models for LM fusion where present), decodes every batch of the subset,
    prints ``T-``/``H-`` lines to ``output_file`` unless ``args.quiet`` is
    set, and writes the result files into ``args.results_path``: the decoded
    results always, and WER, CER and aligned results when the samples carry
    reference targets.

    Args:
        args: parsed command-line options.
        output_file: text stream that receives the log records and the
            per-utterance reference/hypothesis lines.

    Returns:
        The ``wer.Scorer`` holding all accumulated predictions/evaluations.
    """
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger('espresso.speech_recognize')
    if output_file is not sys.stdout:  # also print to stdout
        logger.addHandler(logging.StreamHandler(sys.stdout))

    print_options_meaning_changes(args, logger)

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset split
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionary
    dictionary = task.target_dictionary

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    # NOTE(security): eval() executes whatever expression is passed via
    # --model-overrides; only use this with trusted command lines.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(os.pathsep),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )
    for i, m in enumerate(models):
        if hasattr(m, 'is_wordlm') and m.is_wordlm:
            # assume subword LM comes before word LM
            if isinstance(models[i - 1], FairseqLanguageModel):
                models[i - 1] = MultiLevelLanguageModel(
                    m,
                    models[i - 1],
                    subwordlm_weight=args.subwordlm_weight,
                    oov_penalty=args.oov_penalty,
                    open_vocab=not args.disable_open_vocab,
                )
                del models[i]
                logger.info('LM fusion with Multi-level LM')
            else:
                models[i] = TensorizedLookaheadLanguageModel(
                    m,
                    dictionary,
                    oov_penalty=args.oov_penalty,
                    open_vocab=not args.disable_open_vocab,
                )
                logger.info('LM fusion with Look-ahead Word LM')
        # assume subword LM comes after E2E models
        elif i == len(models) - 1 and isinstance(m, FairseqLanguageModel):
            logger.info('LM fusion with Subword LM')
    if args.lm_weight != 0.0:
        logger.info('using LM fusion with lm-weight={:.2f}'.format(
            args.lm_weight))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[
                model.max_positions() if hasattr(model, 'encoder') else
                (None, model.max_positions()) for model in models
            ]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    if args.match_source_len:
        logger.warning(
            'The option match_source_len is not applicable to speech recognition. Ignoring it.'
        )
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute WER
    scorer = wer.Scorer(dictionary, wer_output_filter=args.wer_output_filter)
    num_sentences = 0
    has_target = True
    # Stays None until an attention plot has actually been written, so the
    # summary log below never touches an unbound name.
    save_dir = None
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(
                generator,
                models,
                sample,
                prefix_tokens,
                lm_weight=args.lm_weight,
            )
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            # obtain nonpad mask of encoder output to plot attentions
            if args.print_alignment:
                net_input = sample['net_input']
                src_tokens = net_input['src_tokens']
                output_lengths = models[0].encoder.output_lengths(
                    net_input['src_lengths'])
                nonpad_idxs = sequence_mask(
                    output_lengths,
                    models[0].encoder.output_lengths(src_tokens.size(1)))

            for i in range(len(sample['id'])):
                has_target = sample['target'] is not None
                utt_id = sample['utt_id'][i]

                # Retrieve the original sentences
                if has_target:
                    target_str = sample['target_raw_text'][i]
                    if not args.quiet:
                        target_sent = dictionary.tokens_to_sentence(
                            target_str,
                            use_unk_sym=False,
                            bpe_symbol=args.remove_bpe,
                        )
                        print('T-{}\t{}'.format(utt_id, target_sent),
                              file=output_file)

                # Process top predictions
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    # not removing bpe at this point
                    hypo_str = dictionary.string(hypo['tokens'].int().cpu())
                    # The detokenized sentence is consumed by the H- line and
                    # by plot_attention for the top hypothesis. Compute it
                    # for exactly those cases so that a plot never receives
                    # the sentence of an earlier utterance.
                    if not args.quiet or (j == 0 and args.print_alignment):
                        hypo_sent = dictionary.tokens_to_sentence(
                            hypo_str, bpe_symbol=args.remove_bpe)

                    if not args.quiet:
                        score = hypo['score'] / math.log(
                            2)  # convert to base 2
                        print('H-{}\t{}\t{}'.format(utt_id, hypo_sent, score),
                              file=output_file)

                    # Score and obtain attention only the top hypothesis
                    if j == 0:
                        # src_len x tgt_len
                        attention = hypo['attention'][nonpad_idxs[i]].float().cpu() \
                            if args.print_alignment and hypo['attention'] is not None else None
                        if args.print_alignment and attention is not None:
                            save_dir = os.path.join(args.results_path,
                                                    'attn_plots')
                            os.makedirs(save_dir, exist_ok=True)
                            plot_attention(attention, hypo_sent, utt_id,
                                           save_dir)
                        scorer.add_prediction(utt_id,
                                              hypo_str,
                                              bpe_symbol=args.remove_bpe)
                        if has_target:
                            scorer.add_evaluation(utt_id,
                                                  target_str,
                                                  hypo_str,
                                                  bpe_symbol=args.remove_bpe)

            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    logger.info('NOTE: hypothesis and token scores are output in base 2')
    logger.info(
        'Recognized {} utterances ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    # Only report the plot directory if at least one plot was written.
    if args.print_alignment and save_dir is not None:
        logger.info('Saved attention plots in ' + save_dir)

    if has_target:
        assert args.test_text_files is not None
        scorer.add_ordered_utt_list(*args.test_text_files)

    fn = 'decoded_char_results.txt'
    with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f:
        f.write(scorer.print_char_results())
        logger.info('Decoded char results saved as ' + f.name)

    fn = 'decoded_results.txt'
    with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f:
        f.write(scorer.print_results())
        logger.info('Decoded results saved as ' + f.name)

    if has_target:
        header = 'Recognize {} with beam={}: '.format(args.gen_subset,
                                                      args.beam)
        fn = 'wer'
        with open(os.path.join(args.results_path, fn), 'w',
                  encoding='utf-8') as f:
            res = 'WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%'.format(
                *(scorer.wer()))
            logger.info(header + res)
            f.write(res + '\n')
            logger.info('WER saved in ' + f.name)

        fn = 'cer'
        with open(os.path.join(args.results_path, fn), 'w',
                  encoding='utf-8') as f:
            res = 'CER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%'.format(
                *(scorer.cer()))
            # pad with spaces so the CER line lines up under the WER line
            logger.info(' ' * len(header) + res)
            f.write(res + '\n')
            logger.info('CER saved in ' + f.name)

        fn = 'aligned_results.txt'
        with open(os.path.join(args.results_path, fn), 'w',
                  encoding='utf-8') as f:
            f.write(scorer.print_aligned_results())
            logger.info('Aligned results saved as ' + f.name)
    return scorer
Beispiel #22
0
def main(parsed_args):
    """Score a dataset split with a language-model ensemble and print the loss.

    Loads the ensemble from ``parsed_args.path``, copies the command-line
    options onto the checkpoint's args (except a fixed set of model-specific
    ones), scores every batch of ``args.gen_subset`` with ``SequenceScorer``
    and prints the average negative log-likelihood and its exponential as
    perplexity. With ``output_word_probs`` / ``output_word_stats`` it also
    prints per-word scores and aggregated ``WordStat`` entries.

    Args:
        parsed_args: parsed command-line options; ``path`` must be set.

    Raises:
        NotImplementedError: if ``remove_bpe`` is ``'sentencepiece'``.
    """
    assert parsed_args.path is not None, '--path required for evaluation!'

    utils.import_user_module(parsed_args)

    print(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    # NOTE(security): eval() executes whatever expression is passed via
    # --model-overrides; only use this with trusted command lines.
    models, args = checkpoint_utils.load_model_ensemble(
        parsed_args.path.split(':'),
        arg_overrides=eval(parsed_args.model_overrides),
        task=task)

    # Command-line options override the checkpoint's args, except for these
    # options, which keep the values stored with the model.
    for arg in vars(parsed_args).keys():
        if arg not in {
            'self_target', 'future_target', 'past_target', 'tokens_per_sample',
            'output_size_dictionary', 'add_bos_token',
        }:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    assert len(models) > 0

    print('num. model params: {}'.format(sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(*[
            model.max_positions() for model in models]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch)

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        # Indices of dictionary entries that end with the BPE continuation
        # marker, i.e. tokens that are continued by the following token.
        bpe_cont = args.remove_bpe.rstrip()
        bpe_toks = {
            idx
            for idx in range(len(task.source_dictionary))
            if task.source_dictionary[idx].endswith(bpe_cont)
        }
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()

        for sample in t:
            if 'net_input' not in sample:
                continue

            sample = utils.move_to_cuda(sample) if use_cuda else sample

            gen_timer.start()
            hypos = scorer.generate(models, sample)
            gen_timer.stop(sample['ntokens'])

            for i, hypos_i in enumerate(hypos):
                hypo = hypos_i[0]
                sample_id = sample['id'][i]

                tokens = hypo['tokens']
                pos_scores = hypo['positional_scores'].float()

                if args.add_bos_token:
                    assert hypo['tokens'][0].item() == task.target_dictionary.bos()
                    tokens = tokens[1:]
                    pos_scores = pos_scores[1:]

                # Measure the length only after the optional BOS strip;
                # otherwise the merge loop below would run one position past
                # the end of the shortened tensors.
                tgt_len = tokens.numel()

                skipped_toks = 0
                if bpe_toks is not None:
                    # Fold the score of every continued sub-word into the
                    # token that follows it, so each word is scored once.
                    for k in range(tgt_len - 1):
                        if tokens[k].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[k + 1] += pos_scores[k]
                            pos_scores[k] = 0

                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                if inf_scores.any():
                    print('| Skipping tokens with inf scores:',
                          task.target_dictionary.string(tokens[inf_scores.nonzero()]))
                    # NOTE(review): after this filtering pos_scores no longer
                    # lines up index-for-index with tokens, which the
                    # word-level statistics below rely on -- confirm whether
                    # inf scores can occur together with
                    # --output-word-probs / --output-word-stats.
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += pos_scores.sum().cpu()
                count += pos_scores.numel() - skipped_toks

                if args.output_word_probs or args.output_word_stats:
                    w = ''
                    word_prob = []
                    is_bpe = False
                    for pos in range(len(tokens)):
                        w_ind = tokens[pos].item()
                        w += task.source_dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            # word continues: drop the continuation marker
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[pos].item()))

                            # score of the next position with a non-zero
                            # score (zeroed positions were merged above)
                            next_prob = None
                            ind = pos + 1
                            while ind < len(tokens):
                                if pos_scores[ind].item() != 0:
                                    next_prob = pos_scores[ind]
                                    break
                                ind += 1

                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(pos_scores[pos].item(), next_prob)
                            is_bpe = False
                            w = ''
                    if args.output_word_probs:
                        print(
                            str(int(sample_id)) + " "
                            + ('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob))
                        )

            wps_meter.update(sample['ntokens'])
            t.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            print(ws)
Beispiel #23
0
def train(args, extra_state, trainer, task, epoch_itr, **train_step_kwargs):
    """Run the training loop as a generator of progress reports.

    After every completed (non-skipped) training step, and again at the end
    of every epoch that was not interrupted, yields a tuple
    ``(num_updates, metrics)`` where ``metrics`` has the keys ``train_ppl``,
    ``tune_ppl``, ``tune_bleu`` and ``translation_samples``.

    Args:
        args: training options.
        extra_state: mutable dict of training progress; ``epoch``,
            ``batch_offset``, ``last_bleu_eval`` and ``val_loss`` are
            updated in place.
        trainer: object performing the training steps and LR scheduling.
        task: passed through to ``validate_save_and_evaluate_bleu``.
        epoch_itr: passed through to ``setup_epoch``.
        **train_step_kwargs: forwarded to ``trainer.train_step``.
    """

    def _progress_report(stats, tune_ppl, tune_bleu, samples_out):
        # Build the (num_updates, metrics) tuple handed to the consumer.
        return (
            trainer.get_num_updates(),
            {
                "train_ppl": stats["ppl"],
                "tune_ppl": tune_ppl,
                "tune_bleu": tune_bleu,
                "translation_samples": samples_out,
            },
        )

    # Index of the first batch of the current epoch; it can differ from the
    # offset stored in the checkpoint.
    first_batch_idx = extra_state["batch_offset"]

    # Training continues until the learning rate drops to args.min_lr or the
    # epoch limit is passed.
    epoch_limit = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    halt_mid_epoch = False
    halt_after_epoch = False

    pruning_enabled = args.pruning_percentile > 0
    if pruning_enabled:
        prune_masks = create_prune_masks(args, trainer)
        apply_prune_masks(prune_masks, trainer)

    while lr > args.min_lr and extra_state["epoch"] <= epoch_limit:
        # Train the model for one epoch.
        _itr, progress, extra_meters = setup_epoch(
            args=args,
            epoch_itr=epoch_itr,
            trainer=trainer,
        )

        for batch_idx, samples in enumerate(progress, start=first_batch_idx):
            try:
                log_output = trainer.train_step(samples, **train_step_kwargs)
            except FloatingPointError as e:
                # Fairseq's fp16_trainer raises this uncommon error to signal
                # that training should stop.
                print(f"Stopping training due to: {e}.")
                halt_mid_epoch = True
                break

            if log_output is None:
                # The batch was skipped, typically because of OOM or FP16
                # overflow.
                continue

            if pruning_enabled:
                apply_prune_masks(prune_masks, trainer)

            train_stats = log_mid_epoch_stats(
                trainer=trainer,
                progress=progress,
                extra_meters=extra_meters,
                log_output=log_output,
            )

            if batch_idx == first_batch_idx:
                # leave the first mini-batch out of the words-per-second
                # calculation
                trainer.get_meter("wps").reset()

            num_updates = trainer.get_num_updates()
            do_validate = (
                args.subepoch_validate_interval > 0
                and num_updates % args.subepoch_validate_interval == 0
            )
            do_save = (
                not args.no_save
                and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
            )
            # BLEU eval is only possible when there is a new checkpoint to
            # load, hence the dependency on do_save.
            do_eval_bleu = (
                do_save
                and args.generate_bleu_eval_interval > 0
                and num_updates - extra_state["last_bleu_eval"]
                >= args.generate_bleu_eval_interval
            )
            if do_eval_bleu:
                extra_state["last_bleu_eval"] = num_updates

            extra_state["batch_offset"] = batch_idx + 1

            outcome = validate_save_and_evaluate_bleu(
                args=args,
                trainer=trainer,
                task=task,
                extra_state=extra_state,
                do_validate=do_validate,
                do_save=do_save,
                do_eval_bleu=do_eval_bleu,
            )
            (
                _,
                val_ppl,
                val_bleu,
                halt_mid_epoch,
                translation_samples,
                lr,
            ) = outcome
            yield _progress_report(
                train_stats, val_ppl, val_bleu, translation_samples
            )

            halt_mid_epoch = halt_mid_epoch or is_training_over_time_limit(
                extra_state["start_time"], args.stop_time_hr
            )
            if halt_mid_epoch:
                break

        # log end-of-epoch stats
        train_stats = log_end_epoch_stats(
            trainer=trainer,
            progress=progress,
            extra_meters=extra_meters,
        )

        # End-of-epoch validation/checkpointing, unless stopping mid-epoch.
        if not halt_mid_epoch:
            # batch_offset being None denotes the end of an epoch.
            extra_state["batch_offset"] = None
            outcome = validate_save_and_evaluate_bleu(
                args=args,
                trainer=trainer,
                task=task,
                extra_state=extra_state,
                do_validate=True,
                do_save=not args.no_save
                and not args.no_end_of_epoch_checkpoints,
                do_eval_bleu=args.generate_bleu_eval_per_epoch,
            )
            (
                val_loss,
                val_ppl,
                val_bleu,
                halt_after_epoch,
                translation_samples,
                lr,
            ) = outcome
            extra_state["val_loss"] = val_loss
            yield _progress_report(
                train_stats, val_ppl, val_bleu, translation_samples
            )

        if halt_mid_epoch or halt_after_epoch:
            break

        # Advance to the next epoch, which starts again at batch 0.
        lr = trainer.lr_step(extra_state["epoch"], val_loss)
        extra_state["epoch"] += 1
        extra_state["batch_offset"] = 0
        first_batch_idx = 0

    train_meter.stop()
    print(f"| done training in {train_meter.sum:.1f} seconds")
    print(
        f"| Best BLEU score of {extra_state['evaluate_bleu']['best']} was from "
        f"epoch {extra_state['evaluate_bleu']['best_epoch']}")
Beispiel #24
0
def _generate_score(models, args, task, dataset, modify_target_dict):
    """Translate ``dataset`` with ``models`` and accumulate a BLEU scorer.

    Besides scoring, optionally writes several artifacts depending on
    ``args``: hypothesis tokens as a binary dataset
    (``output_hypos_binary_path``), the source side of the dataset
    (``output_source_binary_path``), a pickle of per-sentence translation
    info (``translation_info_export_path``), translations together with
    their exponentiated scores (``translation_output_file`` and
    ``translation_probs_file``), and all hypotheses as text
    (``hypotheses_export_path``).

    Args:
        models: ensemble to translate with; each model is switched to its
            fast generation mode.
        args: generation options.
        task: task providing the target dictionary and the batch iterator.
        dataset: dataset to translate; subsampled when
            ``args.max_examples_to_evaluate > 0``.
        modify_target_dict: forwarded to ``_iter_translations``.

    Returns:
        Tuple ``(scorer, num_sentences, gen_timer, translation_samples)``.
    """
    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load ensemble
    if not args.quiet:
        print("| loading model(s) from {}".format(", ".join(
            args.path.split(CHECKPOINT_PATHS_DELIMITER))))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=True,
        )

    translator = build_sequence_generator(args, task, models)
    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    print("seed number is" + str(args.max_examples_to_evaluate_seed))
    if args.max_examples_to_evaluate > 0:
        pytorch_translate_data.subsample_pair_dataset(
            dataset, args.max_examples_to_evaluate,
            args.max_examples_to_evaluate_seed)

    # Keep track of translations
    # Initialize with empty translations
    # and zero probs scores
    translated_sentences = [""] * len(dataset)
    translated_scores = [0.0] * len(dataset)
    hypos_list = []

    collect_output_hypos = getattr(args, "output_hypos_binary_path", False)
    if collect_output_hypos:
        output_hypos_token_arrays = [None] * len(dataset)

    # Generate and compute BLEU score
    dst_dict = task.target_dictionary
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(dst_dict.pad(), dst_dict.eos(), dst_dict.unk())

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    oracle_scorer = None
    if args.report_oracle_bleu:
        oracle_scorer = bleu.Scorer(dst_dict.pad(), dst_dict.eos(),
                                    dst_dict.unk())

    rescorer = None
    num_sentences = 0
    translation_samples = []
    translation_info_list = []
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        gen_timer = StopwatchMeter()
        translations = translator.generate_batched_itr(
            t,
            maxlen_a=args.max_len_a,
            maxlen_b=args.max_len_b,
            cuda=use_cuda,
            timer=gen_timer,
            prefix_size=1
            if pytorch_translate_data.is_multilingual_many_to_one(args) else 0,
        )

        for trans_info in _iter_translations(args, task, dataset, translations,
                                             align_dict, rescorer,
                                             modify_target_dict):
            # String-based scorers expose add_string; the token-based scorer
            # is fed token tensors instead.
            if hasattr(scorer, "add_string"):
                scorer.add_string(trans_info.target_str, trans_info.hypo_str)
            else:
                scorer.add(trans_info.target_tokens, trans_info.hypo_tokens)
            if oracle_scorer is not None:
                oracle_scorer.add(trans_info.target_tokens,
                                  trans_info.best_hypo_tokens)

            if getattr(args, "translation_output_file", False):
                translated_sentences[
                    trans_info.sample_id] = trans_info.hypo_str
            if getattr(args, "translation_probs_file", False):
                translated_scores[trans_info.sample_id] = trans_info.hypo_score
            if getattr(args, "hypotheses_export_path", False):
                hypos_list.append(trans_info.hypos)
            if collect_output_hypos:
                output_hypos_token_arrays[
                    trans_info.sample_id] = trans_info.best_hypo_tokens
            if args.translation_info_export_path is not None:
                # Keep only score and tokens (dropping the expensive fields)
                # and make sure the tokens are on cpu before exporting.
                hypos = [{
                    "score": hypo["score"],
                    "tokens": hypo["tokens"].cpu()
                } for hypo in trans_info.hypos]
                translation_info_list.append({
                    "src_tokens":
                    trans_info.src_tokens.cpu(),
                    "target_tokens":
                    trans_info.target_tokens,
                    "hypos":
                    hypos,
                })
            translation_samples.append(
                collections.OrderedDict({
                    "sample_id":
                    trans_info.sample_id.item(),
                    "src_str":
                    trans_info.src_str,
                    "target_str":
                    trans_info.target_str,
                    "hypo_str":
                    trans_info.hypo_str,
                }))
            wps_meter.update(trans_info.src_tokens.size(0))
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += 1

    # If applicable, save collected hypothesis tokens to binary output file
    if collect_output_hypos:
        output_dataset = pytorch_translate_data.InMemoryIndexedDataset()
        output_dataset.load_from_sequences(output_hypos_token_arrays)
        output_dataset.save(args.output_hypos_binary_path)
    if args.output_source_binary_path:
        dataset.src.save(args.output_source_binary_path)
    if args.translation_info_export_path is not None:
        # Context manager closes the file even if pickling fails.
        with open(args.translation_info_export_path, "wb") as f:
            pickle.dump(translation_info_list, f)

    # If applicable, save the translations and scores to the output files
    # These two outputs are used in dual learning for weighted backtranslation
    if getattr(args, "translation_output_file", False) and getattr(
            args, "translation_probs_file", False):
        with open(args.translation_output_file,
                  "w") as translation_file, open(args.translation_probs_file,
                                                 "w") as score_file:
            for hypo_str, hypo_score in zip(translated_sentences,
                                            translated_scores):
                # skip entries that are empty or whitespace-only
                if len(hypo_str.strip()) > 0:
                    print(hypo_str, file=translation_file)
                    print(np.exp(hypo_score), file=score_file)

    # For eg. external evaluation
    if getattr(args, "hypotheses_export_path", False):
        with open(args.hypotheses_export_path, "w") as out_file:
            for hypos in hypos_list:
                for hypo in hypos:
                    print(
                        task.tgt_dict.string(hypo["tokens"],
                                             bpe_symbol=args.remove_bpe),
                        file=out_file,
                    )

    if oracle_scorer is not None:
        print(
            f"| Oracle BLEU (best hypo in beam): {oracle_scorer.result_string()}"
        )

    return scorer, num_sentences, gen_timer, translation_samples
Beispiel #25
0
def train(args, extra_state, trainer, task, epoch_itr):
    """Drive the training loop as a generator of per-step statistics.

    Epochs are run for as long as the learning rate stays above
    ``args.min_lr`` and ``extra_state["epoch"]`` does not exceed
    ``args.max_epoch`` (unbounded when that option is falsy). The loop also
    ends as soon as ``validate_save_and_evaluate_bleu`` or the time-limit
    check asks for a stop.

    ``extra_state`` is updated in place: the keys ``batch_offset``,
    ``epoch``, ``val_loss`` and ``last_bleu_eval`` are written here.

    Yields:
        ``(num_updates, stats)`` after every mini-batch, and once more at the
        end of each epoch that was not interrupted. ``stats`` maps
        ``train_ppl``, ``tune_ppl``, ``tune_bleu`` and
        ``translation_samples`` to their most recent values.
    """

    def _report(stats, ppl, bleu, samples):
        # Common shape of every item handed back to the caller.
        return (
            trainer.get_num_updates(),
            {
                "train_ppl": stats["ppl"],
                "tune_ppl": ppl,
                "tune_bleu": bleu,
                "translation_samples": samples,
            },
        )

    # Offset for the current epoch (may differ from the checkpoint offset).
    first_batch_idx = extra_state["batch_offset"]

    # Train until the learning rate gets too small.
    epoch_limit = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    halt_mid_epoch = False
    halt_after_epoch = False

    pruning_enabled = args.pruning_percentile > 0
    if pruning_enabled:
        masks = prune(args, trainer)

    while lr > args.min_lr and extra_state["epoch"] <= epoch_limit:
        # Each pass through this loop trains the model for one epoch.
        itr, progress, extra_meters = setup_epoch(
            args=args, epoch_itr=epoch_itr, trainer=trainer)

        for batch_idx, sample in enumerate(itr, start=first_batch_idx):
            log_output = trainer.train_step(sample)

            if pruning_enabled:
                # Zero the masked entries of every weight parameter after
                # the update.
                for param_name, param in trainer.model.named_parameters():
                    if "weight" not in param_name:
                        continue
                    param.data[masks[param_name]] = 0.0

            train_stats = log_mid_epoch_stats(
                trainer=trainer,
                progress=progress,
                extra_meters=extra_meters,
                log_output=log_output,
            )

            # Ignore the first mini-batch in the words-per-second calculation.
            if batch_idx == first_batch_idx:
                trainer.get_meter("wps").reset()

            num_updates = trainer.get_num_updates()
            do_validate = (
                args.subepoch_validate_interval > 0
                and num_updates % args.subepoch_validate_interval == 0
            )
            do_save = (
                not args.no_save
                and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
            )
            # BLEU eval is only possible when there is a new checkpoint to
            # load, hence the dependency on do_save.
            do_eval_bleu = (
                do_save
                and args.generate_bleu_eval_interval > 0
                and num_updates - extra_state["last_bleu_eval"]
                >= args.generate_bleu_eval_interval
            )
            if do_eval_bleu:
                extra_state["last_bleu_eval"] = num_updates

            extra_state["batch_offset"] = batch_idx + 1

            outcome = validate_save_and_evaluate_bleu(
                args=args,
                trainer=trainer,
                task=task,
                extra_state=extra_state,
                do_validate=do_validate,
                do_save=do_save,
                do_eval_bleu=do_eval_bleu,
            )
            (
                _unused_loss,
                val_ppl,
                val_bleu,
                halt_mid_epoch,
                translation_samples,
                lr,
            ) = outcome
            yield _report(train_stats, val_ppl, val_bleu, translation_samples)

            halt_mid_epoch = halt_mid_epoch or is_training_over_time_limit(
                extra_state["start_time"], args.stop_time_hr)
            if halt_mid_epoch:
                break

        # Log end-of-epoch stats.
        train_stats = log_end_epoch_stats(
            trainer=trainer, progress=progress, extra_meters=extra_meters)

        # Run the end-of-epoch validation unless we stopped mid-epoch.
        if not halt_mid_epoch:
            # batch_offset being None denotes the end of an epoch.
            extra_state["batch_offset"] = None
            save_at_epoch_end = (
                not args.no_save and not args.no_end_of_epoch_checkpoints
            )
            outcome = validate_save_and_evaluate_bleu(
                args=args,
                trainer=trainer,
                task=task,
                extra_state=extra_state,
                do_validate=True,
                do_save=save_at_epoch_end,
                do_eval_bleu=args.generate_bleu_eval_per_epoch,
            )
            (
                val_loss,
                val_ppl,
                val_bleu,
                halt_after_epoch,
                translation_samples,
                lr,
            ) = outcome
            extra_state["val_loss"] = val_loss
            yield _report(train_stats, val_ppl, val_bleu, translation_samples)

        if halt_mid_epoch or halt_after_epoch:
            break

        # Advance to the next epoch, which starts from its first batch.
        lr = trainer.lr_step(extra_state["epoch"], val_loss)
        extra_state["epoch"] += 1
        extra_state["batch_offset"] = 0
        first_batch_idx = 0

    train_meter.stop()
    print(f"| done training in {train_meter.sum:.1f} seconds")

    if "evaluate_bleu" in extra_state:
        print(
            f"| Best BLEU score of {extra_state['evaluate_bleu']['best']} was from "
            f"epoch {extra_state['evaluate_bleu']['best_epoch']}")
Beispiel #26
0
def main(args):
    """Build the task, model and trainer from ``args`` and run training.

    Training proceeds one epoch at a time while the learning rate is above
    ``args.min_lr``, the epoch counter is below ``args.max_epoch`` and the
    number of updates is below ``args.max_update`` (either limit is
    unbounded when its option is falsy). Validation runs every
    ``args.validate_interval`` epochs and a checkpoint is written every
    ``args.save_interval`` epochs.
    """
    import_user_module(args)

    # Fall back to a default per-batch token budget.
    if args.max_tokens is None:
        args.max_tokens = 6000
    print(args)

    gpu_enabled = torch.cuda.is_available() and not args.cpu
    if gpu_enabled:
        torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Set up the task (e.g. translation, language modeling) and its data.
    task = tasks.setup_task(args)
    load_dataset_splits(task, ['train', 'valid'])

    # Build model and criterion, then report on them.
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    criterion_name = criterion.__class__.__name__
    print('| model {}, criterion {}'.format(args.arch, criterion_name))
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad)
    print('| num. model params: {} (num. trained: {})'.format(
        total_params, trainable_params))

    # Dummy batches: one warms the caching allocator and stands in for a
    # missing batch under DistributedDataParallel when workers have an
    # uneven number of batches; the second, built with a size of 1, is
    # handed to the trainer as its OOM batch.
    position_limits = utils.resolve_max_positions(
        task.max_positions(), model.max_positions())
    dummy_batch = task.dataset('train').get_dummy_batch(
        args.max_tokens, position_limits)
    oom_batch = task.dataset('train').get_dummy_batch(1, position_limits)

    # Build trainer.
    trainer = Trainer(args, task, model, criterion, dummy_batch, oom_batch)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens, args.max_sentences))

    # Initialize dataloader.
    iterator_options = dict(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=position_limits,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=8,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        num_workers=args.num_workers,
    )
    epoch_itr = task.get_batch_iterator(**iterator_options)

    # Resume from the latest checkpoint when there is one; otherwise run a
    # single dummy step.
    resumed = load_checkpoint(args, trainer, epoch_itr)
    if not resumed:
        trainer.dummy_train_step([dummy_batch])

    # Train until the learning rate gets too small.
    epoch_cap = args.max_epoch or math.inf
    update_cap = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
    while (lr > args.min_lr
           and epoch_itr.epoch < epoch_cap
           and trainer.get_num_updates() < update_cap):
        # Train for one epoch.
        train(args, trainer, task, epoch_itr)

        validation_due = epoch_itr.epoch % args.validate_interval == 0
        if validation_due:
            valid_losses = validate(
                args, trainer, task, epoch_itr, valid_subsets)

        # Only the first validation loss drives the learning-rate schedule.
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        checkpoint_due = epoch_itr.epoch % args.save_interval == 0
        if checkpoint_due:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
Beispiel #27
0
def main(args):
    """Generate (or score) translations for ``args.gen_subset`` and print them.

    Loads the model ensemble listed in ``args.path`` (colon-separated), runs
    either a ``SequenceScorer`` (when ``args.score_reference`` is set) or a
    ``SequenceGenerator`` over the subset, prints the ``S-``/``T-``/``H-``/
    ``P-``/``A-`` lines unless ``args.quiet`` is set, and reports BLEU of the
    top hypothesis against the reference when targets are available.
    """
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    # Default batch budget when neither limit was given.
    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits.
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    num_examples = len(task.dataset(args.gen_subset))
    print('| {} {} {} examples'.format(
        args.data, args.gen_subset, num_examples))

    # Set dictionaries.
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Load ensemble.
    print('| loading model(s) from {}'.format(args.path))
    checkpoint_paths = args.path.split(':')
    # NOTE(review): eval() executes whatever expression --model-overrides
    # contains; only safe with trusted command lines.
    overrides = eval(args.model_overrides)
    models, _ = utils.load_ensemble_for_inference(
        checkpoint_paths, task, model_arg_overrides=overrides)

    # Optimize ensemble for generation.
    beamable_size = None if args.no_beamable_mm else args.beam
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=beamable_size,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()

    # Alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary).
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded).
    gen_dataset = task.dataset(args.gen_subset)
    position_limits = utils.resolve_max_positions(
        task.max_positions(),
        *[m.max_positions() for m in models]
    )
    epoch_iterator = task.get_batch_iterator(
        dataset=gen_dataset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=position_limits,
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
    )
    itr = epoch_iterator.next_epoch_itr(shuffle=False)

    # Initialize generator.
    gen_timer = StopwatchMeter()
    if args.score_reference:
        translator = SequenceScorer(models, task.target_dictionary)
    else:
        translator = SequenceGenerator(
            models,
            task.target_dictionary,
            beam_size=args.beam,
            minlen=args.min_len,
            stop_early=(not args.no_early_stop),
            normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen,
            unk_penalty=args.unkpen,
            sampling=args.sampling,
            sampling_topk=args.sampling_topk,
            sampling_temperature=args.sampling_temperature,
            diverse_beam_groups=args.diverse_beam_groups,
            diverse_beam_strength=args.diverse_beam_strength,
        )

    if use_cuda:
        translator.cuda()

    # Generate and compute BLEU score.
    scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        if args.score_reference:
            translations = translator.score_batched_itr(
                t, cuda=use_cuda, timer=gen_timer)
        else:
            translations = translator.generate_batched_itr(
                t,
                maxlen_a=args.max_len_a,
                maxlen_b=args.max_len_b,
                cuda=use_cuda,
                timer=gen_timer,
                prefix_size=args.prefix_size,
            )

        wps_meter = TimeMeter()
        for sample_id, src_tokens, target_tokens, hypos in translations:
            # Process input and ground truth.
            has_target = target_tokens is not None
            if has_target:
                target_tokens = target_tokens.int().cpu()

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
                target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
            else:
                src_str = src_dict.string(src_tokens, args.remove_bpe)
                if has_target:
                    target_str = tgt_dict.string(
                        target_tokens, args.remove_bpe, escape_unk=True)

            if not args.quiet:
                print('S-{}\t{}'.format(sample_id, src_str))
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str))

            # Process top predictions.
            top_k = min(len(hypos), args.nbest)
            for rank, hypo in enumerate(hypos[:top_k]):
                raw_alignment = hypo['alignment']
                if raw_alignment is not None:
                    raw_alignment = raw_alignment.int().cpu()
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=raw_alignment,
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )

                if not args.quiet:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
                    token_scores = ' '.join(
                        '{:.4f}'.format(s)
                        for s in hypo['positional_scores'].tolist()
                    )
                    print('P-{}\t{}'.format(sample_id, token_scores))

                    if args.print_alignment:
                        alignment_str = ' '.join(
                            str(utils.item(a)) for a in alignment)
                        print('A-{}\t{}'.format(sample_id, alignment_str))

                # Score only the top hypothesis.
                if has_target and rank == 0:
                    if align_dict is not None or args.remove_bpe is not None:
                        # Convert back to tokens for evaluation with unk
                        # replacement and/or without BPE.
                        target_tokens = tokenizer.Tokenizer.tokenize(
                            target_str, tgt_dict, add_if_not_exist=True)
                    scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += 1

    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
Beispiel #28
0
def main(args):
    """Run generation on ``args.gen_subset`` and print the hypotheses.

    Loads the model ensemble from ``args.path`` (colon-separated), applies
    the configuration returned by ``utils.get_subtransformer_config`` to each
    model, and decodes every batch with ``task.inference_step``. The decoder
    timings returned by that call and the mean source length of each batch
    are collected in local lists. ``S-``/``T-``/``H-``/``P-``/``A-`` lines
    are printed unless ``args.quiet`` is set.
    """
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    utils.import_user_module(args)

    # Default batch budget when neither limit was given.
    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # When running on CPU, use fp32 as default.
    if not use_cuda:
        args.fp16 = False

    # Load dataset splits.
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries; the source dictionary may be unavailable for a task.
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble.
    print('| loading model(s) from {}'.format(args.path))
    checkpoint_paths = args.path.split(':')
    # NOTE(review): eval() executes whatever expression --model-overrides
    # contains; only safe with trusted command lines.
    overrides = eval(args.model_overrides)
    models, _model_args = checkpoint_utils.load_model_ensemble(
        checkpoint_paths,
        arg_overrides=overrides,
        task=task,
    )

    torch.manual_seed(args.seed)

    # Optimize ensemble for generation.
    for model in models:
        if use_cuda:
            model.cuda()

        model.set_sample_config(utils.get_subtransformer_config(args))
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()
        print(model, file=sys.stderr)
        print(args.path, file=sys.stderr)

    # Alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary).
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded).
    gen_dataset = task.dataset(args.gen_subset)
    position_limits = utils.resolve_max_positions(
        task.max_positions(),
        *[m.max_positions() for m in models])
    epoch_iterator = task.get_batch_iterator(
        dataset=gen_dataset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=position_limits,
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    )
    itr = epoch_iterator.next_epoch_itr(shuffle=False)

    # Initialize generator.
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    num_sentences = 0
    has_target = True
    decoder_times_all = []
    input_len_all = []
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:

            if use_cuda:
                sample = utils.move_to_cuda(sample)
            if 'net_input' not in sample:
                continue

            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]
            else:
                prefix_tokens = None

            gen_timer.start()
            hypos, decoder_times = task.inference_step(
                generator, models, sample, prefix_tokens)
            mean_src_len = np.mean(
                sample['net_input']['src_lengths'].cpu().numpy())
            input_len_all.append(mean_src_len)

            print(decoder_times)
            decoder_times_all.append(decoder_times)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for row, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding.
                src_tokens = utils.strip_pad(
                    sample['net_input']['src_tokens'][row, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    stripped = utils.strip_pad(
                        sample['target'][row, :], tgt_dict.pad())
                    target_tokens = stripped.int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(
                        args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(
                        args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is None:
                        src_str = ""
                    else:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    if has_target:
                        target_str = tgt_dict.string(
                            target_tokens, args.remove_bpe, escape_unk=True)

                if not args.quiet:
                    if src_dict is not None:
                        print('S-{}\t{}'.format(sample_id, src_str))
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str))

                # Process top predictions.
                for hypo in hypos[row][:args.nbest]:
                    raw_alignment = hypo['alignment']
                    if raw_alignment is not None:
                        raw_alignment = raw_alignment.int().cpu()
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=raw_alignment,
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    if args.quiet:
                        continue

                    print('H-{}\t{}\t{}'.format(
                        sample_id, hypo['score'], hypo_str))
                    token_scores = ' '.join(
                        '{:.4f}'.format(s)
                        for s in hypo['positional_scores'].tolist()
                    )
                    print('P-{}\t{}'.format(sample_id, token_scores))

                    if args.print_alignment:
                        alignment_str = ' '.join(
                            str(utils.item(a)) for a in alignment)
                        print('A-{}\t{}'.format(sample_id, alignment_str))

            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']
Beispiel #29
0
def _main(args, output_file):
    """Generate translations for ``args.gen_subset`` and write them to ``output_file``.

    Logging is configured to stream to ``output_file``. The model ensemble
    from ``args.path`` (split on ``os.pathsep``) decodes every batch via
    ``task.inference_step``; unless ``args.quiet`` is set, ``S-``/``T-``/
    ``H-``/``P-``/``A-``/``I-``/``E-`` lines are printed to ``output_file``
    with scores converted to base 2. The top hypothesis of each sentence is
    added to the scorer when a target is available.

    Returns:
        The scorer (``bleu.SacrebleuScorer`` when ``args.sacrebleu`` is set,
        otherwise ``bleu.Scorer``).
    """
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger('fairseq_cli.generate')

    utils.import_user_module(args)

    # Default batch budget when neither limit was given.
    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits.
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries; the source dictionary may be unavailable for a task.
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble.
    logger.info('loading model(s) from {}'.format(args.path))
    checkpoint_paths = args.path.split(os.pathsep)
    # NOTE(review): eval() executes whatever expression --model-overrides
    # contains; only safe with trusted command lines.
    overrides = eval(args.model_overrides)
    models, _model_args = checkpoint_utils.load_model_ensemble(
        checkpoint_paths,
        arg_overrides=overrides,
        task=task,
    )

    # Optimize ensemble for generation.
    for model_idx, model in enumerate(models):
        if model_idx == 0:
            # Parameter counts are reported for the first model only.
            total_m = sum(p.numel() for p in model.parameters()) / 1e6
            trained_m = sum(
                p.numel() for p in model.parameters() if p.requires_grad
            ) / 1e6
            logger.info('num. model params: {:.2f} M (num. trained: {:.2f} M)'.format(
                total_m, trained_m))

        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary).
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded).
    gen_dataset = task.dataset(args.gen_subset)
    position_limits = utils.resolve_max_positions(
        task.max_positions(),
        *[m.max_positions() for m in models]
    )
    epoch_iterator = task.get_batch_iterator(
        dataset=gen_dataset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=position_limits,
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    )
    itr = epoch_iterator.next_epoch_itr(shuffle=False)

    # Initialize generator.
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score.
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            if use_cuda:
                sample = utils.move_to_cuda(sample)
            if 'net_input' not in sample:
                continue

            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]
            else:
                prefix_tokens = None

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample, prefix_tokens)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for row, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding.
                src_tokens = utils.strip_pad(
                    sample['net_input']['src_tokens'][row, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    stripped = utils.strip_pad(
                        sample['target'][row, :], tgt_dict.pad())
                    target_tokens = stripped.int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is None:
                        src_str = ""
                    else:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    if has_target:
                        target_str = tgt_dict.string(
                            target_tokens, args.remove_bpe, escape_unk=True)

                if not args.quiet:
                    if src_dict is not None:
                        print('S-{}\t{}'.format(sample_id, src_str), file=output_file)
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str), file=output_file)

                # Process top predictions.
                for rank, hypo in enumerate(hypos[row][:args.nbest]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'],
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    if not args.quiet:
                        # Convert from base e to base 2.
                        score = hypo['score'] / math.log(2)
                        print('H-{}\t{}\t{}'.format(sample_id, score, hypo_str), file=output_file)
                        # div_ rescales the positional scores tensor in place.
                        token_scores = ' '.join(
                            '{:.4f}'.format(s)
                            for s in hypo['positional_scores'].div_(math.log(2)).tolist()
                        )
                        print('P-{}\t{}'.format(sample_id, token_scores), file=output_file)

                        if args.print_alignment:
                            alignment_str = ' '.join(
                                '{}-{}'.format(src_idx, tgt_idx)
                                for src_idx, tgt_idx in alignment
                            )
                            print('A-{}\t{}'.format(sample_id, alignment_str), file=output_file)

                        if args.print_step:
                            print('I-{}\t{}'.format(sample_id, hypo['steps']), file=output_file)

                        if getattr(args, 'retain_iter_history', False):
                            for step, past in enumerate(hypo['history']):
                                _, past_str, _ = utils.post_process_prediction(
                                    hypo_tokens=past['tokens'].int().cpu(),
                                    src_str=src_str,
                                    alignment=None,
                                    align_dict=None,
                                    tgt_dict=tgt_dict,
                                    remove_bpe=None,
                                )
                                print('E-{}_{}\t{}'.format(sample_id, step, past_str), file=output_file)

                    # Score only the top hypothesis.
                    if has_target and rank == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk
                            # replacement and/or without BPE.
                            target_tokens = tgt_dict.encode_line(
                                target_str, add_if_not_exist=True)
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    logger.info('NOTE: hypothesis and token scores are output in base 2')
    logger.info('Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        logger.info('Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))

    return scorer
def main(args):
    """Build the task, model and trainer from ``args`` and train on GPU.

    Training proceeds one epoch at a time while the learning rate is above
    ``args.min_lr``, the epoch counter is below ``args.max_epoch`` and the
    number of updates is below ``args.max_update`` (either limit is
    unbounded when its option is falsy). Validation runs every
    ``args.validate_interval`` epochs and a checkpoint is written every
    ``args.save_interval`` epochs.

    Raises:
        NotImplementedError: if CUDA is not available.
    """
    # Fall back to a default per-batch token budget.
    if args.max_tokens is None:
        args.max_tokens = 6000
    print(args)

    cuda_ready = torch.cuda.is_available()
    if not cuda_ready:
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Set up the task (e.g. translation, language modeling) and its data.
    task = tasks.setup_task(args)
    load_dataset_splits(task, ['train', 'valid'])

    # Build model and criterion, then report on them.
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    criterion_name = criterion.__class__.__name__
    print('| model {}, criterion {}'.format(args.arch, criterion_name))
    total_params = sum(p.numel() for p in model.parameters())
    print('| num. model params: {}'.format(total_params))

    # Dummy batch: warms the caching allocator and stands in for a missing
    # batch under DistributedDataParallel when workers have an uneven number
    # of batches.
    position_limits = utils.resolve_max_positions(
        task.max_positions(), model.max_positions())
    dummy_batch = task.dataset('train').get_dummy_batch(
        args.max_tokens, position_limits)

    # Build trainer.
    trainer = Trainer(args, task, model, criterion, dummy_batch)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens, args.max_sentences))

    # Initialize dataloader.
    iterator_options = dict(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=position_limits,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=8,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
    )
    epoch_itr = task.get_batch_iterator(**iterator_options)

    # Resume from the latest checkpoint when there is one; otherwise run a
    # single dummy step.
    resumed = load_checkpoint(args, trainer, epoch_itr)
    if not resumed:
        trainer.dummy_train_step([dummy_batch])

    # Train until the learning rate gets too small.
    epoch_cap = args.max_epoch or math.inf
    update_cap = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
    while (lr > args.min_lr
           and epoch_itr.epoch < epoch_cap
           and trainer.get_num_updates() < update_cap):
        # Train for one epoch.
        train(args, trainer, task, epoch_itr)

        validation_due = epoch_itr.epoch % args.validate_interval == 0
        if validation_due:
            valid_losses = validate(
                args, trainer, task, epoch_itr, valid_subsets)

        # Only the first validation loss drives the learning-rate schedule.
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        checkpoint_due = epoch_itr.epoch % args.save_interval == 0
        if checkpoint_due:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
def main(args):
    """Decode ``args.gen_subset`` with a model ensemble and record the results.

    The function validates the arguments, loads the task, the dataset split
    and the model(s) listed in ``args.path`` (entries separated by
    ``os.pathsep``), then runs ``task.inference_step`` on every batch and
    hands each sample's hypotheses to ``process_predictions`` together with
    the result files returned by ``prepare_result_files``. Throughput
    statistics are logged once all batches have been processed.

    Args:
        args: Parsed argument namespace. Attributes read here include
            ``max_tokens``, ``max_sentences``, ``cpu``, ``gen_subset``,
            ``data``, ``criterion``, ``path``, ``model_overrides``,
            ``results_path``, ``prefix_size`` and ``beam``. The namespace is
            mutated: ``max_tokens`` defaults to 30000 when neither batch-size
            limit is given, and ``asg_transitions`` is added when
            ``criterion`` is ``"asg_loss"``.
    """
    check_args(args)
    import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 30000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    logger.info("| {} {} {} examples".format(
        args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Set dictionary
    tgt_dict = task.target_dictionary

    logger.info("| decoding with criterion {}".format(args.criterion))

    # Load ensemble
    logger.info("| loading model(s) from {}".format(args.path))
    # NOTE(security): eval() executes whatever expression is stored in
    # args.model_overrides; only ever run this with trusted arguments.
    models, criterions, _model_args = load_models_and_criterions(
        args.path.split(os.pathsep),
        arg_overrides=eval(args.model_overrides),  # noqa
        task=task,
    )
    optimize_models(args, use_cuda, models)

    # hack to pass transitions to W2lDecoder
    if args.criterion == "asg_loss":
        trans = criterions[0].asg.trans.data
        args.asg_transitions = torch.flatten(trans).tolist()

    # Load dataset (possibly sharded)
    itr = get_dataset_itr(args, task)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    num_sentences = 0

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.results_path, exist_ok=True)

    sp = spm.SentencePieceProcessor()
    sp.Load(os.path.join(args.data, "spm.model"))

    res_files = prepare_result_files(args)

    # Looked up once instead of twice per sample inside the loop below.
    gen_dataset = task.dataset(args.gen_subset)

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if "net_input" not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample["target"][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample,
                                        prefix_tokens)
            # Only the first hypothesis of each sample counts towards the
            # generated-token total.
            num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample["id"].tolist()):
                utt_index = int(sample_id)
                speaker = gen_dataset.speakers[utt_index]
                utt_id = gen_dataset.ids[utt_index]
                target_tokens = (utils.strip_pad(sample["target"][i, :],
                                                 tgt_dict.pad()).int().cpu())
                # Process top predictions
                process_predictions(args, hypos[i], sp, tgt_dict,
                                    target_tokens, res_files, speaker, utt_id)

            wps_meter.update(num_generated_tokens)
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += sample["nsentences"]

    # Guard the rate computations: when no batch was decoded the timer holds
    # zero elapsed time and zero tokens, which would otherwise raise
    # ZeroDivisionError after all the work is already done.
    elapsed = gen_timer.sum
    sentences_per_sec = num_sentences / elapsed if elapsed > 0 else 0.0
    if elapsed > 0 and gen_timer.n > 0:
        tokens_per_sec = 1.0 / gen_timer.avg
    else:
        tokens_per_sec = 0.0

    logger.info("| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f} "
                "sentences/s, {:.2f} tokens/s)".format(
                    num_sentences,
                    gen_timer.n,
                    elapsed,
                    sentences_per_sec,
                    tokens_per_sec,
                ))
    logger.info("| Generate {} with beam={}".format(args.gen_subset,
                                                    args.beam))