Example #1
def make_batches(inputs_buffer, args, src_dict, ctx_dict, max_positions):
    ctx_tokens = [
        tokenizer.Tokenizer.tokenize(inputs[1], ctx_dict, add_if_not_exist=False).long()
        for inputs in inputs_buffer
    ]

    tokens = [
        tokenizer.Tokenizer.tokenize(inputs[0], src_dict, add_if_not_exist=False).long()
        for inputs in inputs_buffer
    ]

    src_sizes = np.array([t.numel() for t in tokens])
    ctx_sizes = np.array([t.numel() for t in ctx_tokens])
    # Make sure max_positions also has an entry for the context stream
    if len(max_positions) < 3:
        max_positions += (max_positions[0],)
    itr = data.EpochBatchIterator(
        dataset=data.LanguageTripleDataset(
            src=tokens, src_sizes=src_sizes, src_dict=src_dict,
            ctx=ctx_tokens, ctx_sizes=ctx_sizes, ctx_dict=ctx_dict,
        ),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
    ).next_epoch_itr(shuffle=False)

    for batch in itr:
        yield Batch(
            srcs=[inputs_buffer[i][0] for i in batch['id']],
            tokens=batch['net_input']['src_tokens'],
            lengths=batch['net_input']['src_lengths'],
            ctxs=[inputs_buffer[i][1] for i in batch['id']],
            ctx_tokens=batch['net_input']['ctx_tokens'],
            ctx_lengths=batch['net_input']['ctx_lengths']
        ), batch['id']
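The Batch container these generators yield is never defined in the snippets. As a minimal sketch, assuming it is a plain namedtuple (as in fairseq's interactive scripts), its fields would simply mirror the keyword arguments used in the yield above:

import collections

# Hypothetical definition matching the fields used in Example #1's yield;
# the other examples follow the same pattern with fewer fields.
Batch = collections.namedtuple(
    'Batch', ['srcs', 'tokens', 'lengths', 'ctxs', 'ctx_tokens', 'ctx_lengths'])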
Example #2
def get_dataloader(args, encoder=None):
    """Return a dataloader for inference."""
    assert not (args.part == 'decoder' and encoder is None), \
        "Cannot export decoder without providing encoder"
    src_dict, tgt_dict = data_utils.load_dictionaries(args)
    datasets = load_dataset_splits(args, ['valid'], src_dict, tgt_dict)
    itr = data.EpochBatchIterator(
        dataset=datasets['valid'],
        max_tokens=args.max_tokens,
        max_positions=args.max_positions,
    ).next_epoch_itr(shuffle=False)

    def input_itr():
        for batch in itr:
            if itr.count > args.num_batches:
                break
            ni = batch['net_input']
            if args.part == 'decoder':  # this part works only on GPU
                with torch.no_grad():
                    encoder_out = encoder(ni['src_tokens'].cuda(),
                                          ni['src_lengths'].cuda())
                yield ni['prev_output_tokens'], encoder_out[0], encoder_out[1]
            elif args.part == 'encoder':
                yield ni['src_tokens'], ni['src_lengths']
            else:
                yield ni['src_tokens'], ni['src_lengths'], ni[
                    'prev_output_tokens']

    return input_itr()
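The `itr.count > args.num_batches` guard relies on the iterator exposing a running count. A library-agnostic sketch of the same capping, using itertools.islice over any batch iterable:

import itertools

def capped(batches, num_batches):
    # Yield at most num_batches items from any iterable.
    yield from itertools.islice(batches, num_batches)

for batch in capped(range(10), 3):  # stand-in for the real batch iterator
    print(batch)  # prints 0, 1, 2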
Example #3
def make_batches(lines, args, src_dict, max_positions):
    pairs = [
        tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False, reverse_order=args.reverse_order)
        for src_str in lines
    ]
    tokens = [p[0].long() for p in pairs]
    words = [p[1] for p in pairs]
    lengths = np.array([t.numel() for t in tokens])

    trg_tokens = None
    trg_lengths = None
    itr = data.EpochBatchIterator(
        dataset=data.LanguagePairDataset(tokens, lengths, src_dict, trg_tokens, trg_lengths, src_dict, use_copy=args.use_copy),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
    ).next_epoch_itr(shuffle=False)
    for batch in itr:
        yield Batch(
            srcs=[lines[i] for i in batch['id']],
            words=[words[i] for i in batch['id']],
            tokens=batch['net_input']['src_tokens'],
            lengths=batch['net_input']['src_lengths'],
            prev_output_tokens=batch['net_input']['prev_output_tokens'],
            target=batch['target'],
        ), batch['id'], batch
Example #4
def main(args):
    assert args.path is not None, '--path required for evaluation!'

    args.tokens_per_sample = getattr(args, 'tokens_per_sample', 1024)
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task)

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()

    itr = data.EpochBatchIterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences or 4,
        max_positions=models[0].max_positions(),
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        ignore_invalid_inputs=True,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(models, task.target_dictionary)
    if use_cuda:
        scorer.cuda()

    score_sum = 0.
    count = 0
    with progress_bar.build_progress_bar(args, itr) as t:
        results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        wps_meter = TimeMeter()
        for _, src_tokens, __, hypos in results:
            for hypo in hypos:
                pos_scores = hypo['positional_scores']
                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                if inf_scores.any():
                    print('| Skipping tokens with inf scores:',
                          task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += pos_scores.sum()
                count += pos_scores.numel()
            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))
Example #5
def get_eval_itr(args, models, task, dataset_split):
    return data.EpochBatchIterator(
        dataset=task.dataset(dataset_split),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=models[0].max_positions(),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
    ).next_epoch_itr(shuffle=False)
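The `num_shards`/`shard_id` arguments split one batch stream across distributed workers. A toy illustration of the round-robin split this implies (an assumption about the exact strategy, though it matches fairseq's ShardedIterator):

def shard(batches, num_shards, shard_id):
    # Each worker takes every num_shards-th batch, offset by its shard_id.
    return batches[shard_id::num_shards]

batches = list(range(10))
print(shard(batches, num_shards=2, shard_id=0))  # [0, 2, 4, 6, 8]
print(shard(batches, num_shards=2, shard_id=1))  # [1, 3, 5, 7, 9]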
Example #6
def validate(args,
             trainer,
             task,
             epoch_itr,
             subsets,
             ignoredIndicesValid=None):
    """Evaluate the model on the validation set(s) and return the losses."""
    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = data.EpochBatchIterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=trainer.get_model().max_positions(),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            ignoredIndices=ignoredIndicesValid,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args,
            itr,
            epoch_itr.epoch,
            prefix='valid on \'{}\' subset'.format(subset),
            no_progress_bar='simple')

        # reset validation loss meters
        for k in ['valid_loss', 'valid_nll_loss']:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()
        extra_meters = collections.defaultdict(lambda: AverageMeter())

        for sample in progress:
            log_output = trainer.valid_step(sample)

            for k, v in log_output.items():
                if k in ['loss', 'nll_loss', 'sample_size']:
                    continue
                extra_meters[k].update(v)

        # log validation stats
        stats = get_valid_stats(trainer)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg
        progress.print(stats)

        valid_losses.append(stats['valid_loss'])
    return valid_losses
Example #7
def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
    tokens = torch.LongTensor(list(range(epoch_size))).view(1, -1)
    tokens_ds = data.TokenBlockDataset(
        tokens, sizes=[tokens.size(-1)], block_size=1, pad=0, eos=1, include_targets=False,
    )
    trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
    dataset = data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False)
    epoch_itr = data.EpochBatchIterator(
        dataset=dataset,
        collate_fn=dataset.collater,
        batch_sampler=[[i] for i in range(epoch_size)],
    )
    return trainer, epoch_itr
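Note that Example #7 sidesteps the token-budget logic entirely by passing an explicit `batch_sampler` (one sentence per batch). The same pattern works with a plain PyTorch DataLoader; a minimal self-contained sketch:

import torch
from torch.utils.data import DataLoader, TensorDataset

# An explicit batch_sampler is just an iterable of index lists; here one
# example per batch, mirroring [[i] for i in range(epoch_size)] above.
dataset = TensorDataset(torch.arange(6).view(6, 1))
loader = DataLoader(dataset, batch_sampler=[[i] for i in range(len(dataset))])
for (batch,) in loader:
    print(batch.shape)  # torch.Size([1, 1]) -- one example per batch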
Example #8
def make_batches(lines, src_dict, max_positions):
    tokens = [
        tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
        for src_str in lines
    ]
    idx_to_words = {v: k for k, v in src_dict.indices.items()}
    lengths = np.array([t.numel() for t in tokens])
    itr = data.EpochBatchIterator(
        dataset=data.MonolingualDataset([(s[:-1], s[1:]) for s in tokens], lengths, src_dict, False),
        max_tokens=100,
        max_sentences=5,
        max_positions=max_positions,
    ).next_epoch_itr(shuffle=False)
    return itr
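The `max_tokens`/`max_sentences` pair bounds each batch: a batch is closed once adding the next sentence would exceed either budget. A simplified, self-contained sketch of that grouping rule (the real iterator additionally sorts by length and accounts for padding):

# Simplified batching by token budget: close the current batch when adding
# the next sentence would exceed max_tokens or max_sentences.
def make_token_batches(lengths, max_tokens=100, max_sentences=5):
    batches, batch, tokens = [], [], 0
    for idx, n in enumerate(lengths):
        if batch and (tokens + n > max_tokens or len(batch) >= max_sentences):
            batches.append(batch)
            batch, tokens = [], 0
        batch.append(idx)
        tokens += n
    if batch:
        batches.append(batch)
    return batches

print(make_token_batches([30, 40, 50, 20, 90]))  # [[0, 1], [2, 3], [4]]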
Example #9
def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates,
                              iterations_in_epoch):
    tokens = torch.LongTensor(list(range(epoch_size)))
    tokens_ds = data.TokenBlockDataset(tokens, [len(tokens)],
                                       1,
                                       include_targets=False)
    trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
    epoch_itr = data.EpochBatchIterator(
        dataset=data.LanguagePairDataset(tokens_ds,
                                         tokens_ds.sizes,
                                         mock_dict(),
                                         shuffle=False),
        max_tokens=1,
    )
    return trainer, epoch_itr
Example #10
def make_batches(lines, args, src_dict, max_positions):
    tokens = [
        tokenizer.Tokenizer.tokenize(src_str, src_dict,
                                     add_if_not_exist=False).long()
        for src_str in lines
    ]
    lengths = np.array([t.numel() for t in tokens])
    itr = data.EpochBatchIterator(
        dataset=data.LanguagePairDataset(tokens, lengths, src_dict),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
    ).next_epoch_itr(shuffle=False)
    for batch in itr:
        yield Batch(
            srcs=[lines[i] for i in batch['id']],
            tokens=batch['net_input']['src_tokens'],
            lengths=batch['net_input']['src_lengths'],
        ), batch['id']
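Each of these generators yields `batch['id']` alongside the Batch because batches are built in length-sorted order; the ids let callers scatter model outputs back to the original line order. A toy sketch of that reordering:

# Toy sketch: outputs arrive batch-by-batch in sorted order; the ids map
# them back to their original positions.
results = [None] * 4
for ids, outputs in [((2, 0), ('r2', 'r0')), ((3, 1), ('r3', 'r1'))]:
    for i, out in zip(ids, outputs):
        results[i] = out
print(results)  # ['r0', 'r1', 'r2', 'r3']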
Example #11
def create_iterator(args, trainer, task, adv_split):
    """Sets up data and progress meters for one pass of adversarial attack."""
    # Set seed based on args.seed
    torch.manual_seed(args.seed)

    # reset training meters
    for k in ["wps", "ups", "wpb", "bsz"]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()

    return data.EpochBatchIterator(
        dataset=task.dataset(adv_split),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=trainer.get_model().max_positions(),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
    ).next_epoch_itr(shuffle=False)
Example #12
def validate(args, trainer, datasets, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""
    # Reset validation iterations counter
    trainer._num_val_iterations = 0

    valid_losses = []
    for subset in subsets:

        if len(subsets) > 1:
            print('Validating on \'{}\' subset'.format(subset))

        # Initialize data iterator
        itr = data.EpochBatchIterator(
            dataset=datasets[subset],
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=args.max_positions,
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
        ).next_epoch_itr(shuffle=False)

        # reset validation loss meters
        DLLogger.flush()

        subset_losses = []
        for sample in itr:
            loss = trainer.valid_step(sample)
            subset_losses.append(loss)
        subset_loss = sum(subset_losses) / len(subset_losses)

        DLLogger.flush()

        valid_losses.append(subset_loss)
        print(f'Validation loss on subset {subset}: {subset_loss}')

    return valid_losses
Example #13
def score(args, trainer, dataset, src_dict, tgt_dict, ref_file):

    begin = time.time()

    # Deep copies are necessary: generating translations alters the target
    # dictionary, which would mess up the rest of training.
    src_dict = deepcopy(src_dict)
    tgt_dict = deepcopy(tgt_dict)

    model = trainer.get_model()

    # Initialize data iterator
    itr = data.EpochBatchIterator(
        dataset=dataset,
        max_tokens=None,
        max_sentences=max(
            8, min(math.ceil(1024 / args.distributed_world_size), 128)),
        max_positions=args.max_positions,
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    translator = SequenceGenerator(
        [model],
        tgt_dict.get_metadata(),
        maxlen=args.max_target_positions - 1,  # do not include the EOS token
        beam_size=args.beam,
        stop_early=(not args.no_early_stop),
        normalize_scores=(not args.unnormalized),
        len_penalty=args.lenpen,
        unk_penalty=args.unkpen,
        sampling=args.sampling,
        sampling_topk=args.sampling_topk,
        minlen=args.min_len,
    )
    # Generate and compute BLEU
    bleu_dict = dictionary.Dictionary()
    num_sentences = 0
    predictions = []
    translations = translator.generate_batched_itr(
        itr,
        maxlen_a=args.max_len_a,
        maxlen_b=args.max_len_b,
        cuda=True,
        timer=gen_timer,
        prefix_size=args.prefix_size,
    )

    for sample_id, src_tokens, target_tokens, hypos in translations:
        # Process input and ground truth
        target_tokens = target_tokens.int().cpu()

        src_str = src_dict.string(src_tokens, args.remove_bpe)
        target_str = tgt_dict.string(target_tokens,
                                     args.remove_bpe,
                                     escape_unk=True)

        # Process top predictions
        for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
            hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                hypo_tokens=hypo['tokens'].int().cpu(),
                src_str=src_str,
                alignment=hypo['alignment'].int().cpu()
                if hypo['alignment'] is not None else None,
                align_dict=None,
                tgt_dict=tgt_dict,
                remove_bpe=args.remove_bpe)

            # Score only the top hypothesis
            if i == 0:
                if args.sentencepiece:
                    hypo_str = hypo_str.replace(' ', '').replace('▁', ' ')
                    target_str = target_str.replace(' ', '').replace('▁', ' ')
                sys_tok = tokenizer.Tokenizer.tokenize(
                    (hypo_str.lower()
                     if not args.test_cased_bleu else hypo_str), bleu_dict)
                ref_tok = tokenizer.Tokenizer.tokenize(
                    (target_str.lower()
                     if not args.test_cased_bleu else target_str), bleu_dict)
                if not args.sentencepiece:
                    hypo_str = tokenizer.Tokenizer.detokenize(hypo_str, 'de')
                predictions.append('{}\t{}'.format(sample_id, hypo_str))

        num_sentences += 1

    if args.distributed_world_size > 1:
        predictions = _all_gather_predictions(predictions)

    with open(os.path.join(args.data, ref_file), 'r') as reference:
        refs = [reference.readlines()]
    # reducing indexed predictions as strings is more memory efficient than reducing tuples
    predictions = [tuple(item.split('\t')) for item in predictions]
    predictions = [(int(item[0]), item[1]) for item in predictions]
    predictions.sort(key=lambda tup: tup[0])
    predictions = [
        hypo[1] + ('\n' if hypo[1][-1] != '\n' else '') for hypo in predictions
    ]
    sacrebleu_score = sacrebleu.corpus_bleu(
        predictions, refs, lowercase=not args.test_cased_bleu).score
    if args.save_predictions:
        os.makedirs(os.path.join(args.save_dir, 'predictions'), exist_ok=True)
        with open(
                os.path.join(
                    args.save_dir, 'predictions',
                    ref_file + '.pred.update_{}'.format(trainer._num_updates)),
                'w') as f:
            f.write(''.join(predictions))

    DLLogger.log(step=trainer.get_num_updates(),
                 data={
                     'inference tokens/s':
                     float(args.distributed_world_size) / gen_timer.avg
                 },
                 verbosity=0)
    DLLogger.flush()
    if gen_timer.sum != 0:
        print(
            '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
            .format(len(predictions), gen_timer.n, gen_timer.sum,
                    len(predictions) / gen_timer.sum,
                    float(args.distributed_world_size) / gen_timer.avg))

    print('| Eval completed in: {:.2f}s | {}CASED BLEU {:.2f}'.format(
        time.time() - begin, '' if args.test_cased_bleu else 'UN',
        sacrebleu_score))

    return sacrebleu_score
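Stripped of the distributed gathering, the BLEU computation at the end reduces to a single sacrebleu call over aligned hypothesis and reference lists. A minimal, self-contained illustration (the sentences are invented):

import sacrebleu

hyps = ['the cat sat on the mat', 'hello world']
refs = [['the cat sat on the mat', 'hello world']]  # one reference stream
print(sacrebleu.corpus_bleu(hyps, refs, lowercase=True).score)  # 100.0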
Example #14
def main(args):
    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)
    from mlperf_compliance.mlperf_log import transformer_print
    transformer_print(
        key=mlperf_log.RUN_CLEAR_CACHES
    )  # before this tag we should run clearing caches on the host
    # mlperf compliance synchronization
    if args.distributed_world_size > 1:
        assert (torch.distributed.is_initialized())
        torch.distributed.broadcast(torch.tensor([1], device="cuda"), 0)
        torch.cuda.synchronize()
    transformer_print(key=mlperf_log.RUN_START)
    if args.max_tokens is None:
        args.max_tokens = 6000
    print(args)
    transformer_print(key=mlperf_log.OPT_NAME, value=args.optimizer)
    transformer_print(key=mlperf_log.OPT_LR, value=args.lr)
    transformer_print(key=mlperf_log.OPT_HP_ADAM_BETA1,
                      value=eval(args.adam_betas)[0])
    transformer_print(key=mlperf_log.OPT_HP_ADAM_BETA2,
                      value=eval(args.adam_betas)[1])
    transformer_print(key=mlperf_log.OPT_HP_ADAM_EPSILON, value=args.adam_eps)
    pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
    result = torch.cuda.cudart().cudaDeviceSetLimit(ctypes.c_int(0x05),
                                                    ctypes.c_int(128))
    result = torch.cuda.cudart().cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
    torch.manual_seed(args.seed)
    transformer_print(key=mlperf_log.RUN_SET_RANDOM_SEED, value=args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)
    transformer_print(key=mlperf_log.MODEL_HP_SEQ_BEAM_SEARCH,
                      value={
                          'alpha': args.lenpen,
                          'beam_size': args.beam,
                          'extra_decode_length': args.max_len_b,
                          'vocab_size': task.target_dictionary.__len__()
                      })

    # Load dataset splits
    load_dataset_splits(task, ['train', 'valid'])

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {}'.format(
        sum(p.numel() for p in model.parameters())))

    # Build trainer
    if args.fp16:
        trainer = FP16Trainer(args, task, model, criterion)
    else:
        if torch.cuda.get_device_capability(0)[0] >= 7:
            print(
                '| NOTICE: your device may support faster training with --fp16'
            )
        trainer = Trainer(args, task, model, criterion)
    if (args.online_eval or args.target_bleu) and not args.remove_bpe:
        args.remove_bpe = '@@ '
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))
    transformer_print(key=mlperf_log.INPUT_BATCH_SIZE, value=args.max_tokens)
    transformer_print(key=mlperf_log.INPUT_ORDER)
    # Initialize dataloader
    max_positions = trainer.get_model().max_positions()

    # Send a dummy batch to warm the caching allocator
    dummy_batch = task.dataset('train').get_dummy_batch(
        args.max_tokens, max_positions)
    trainer.dummy_train_step(dummy_batch)

    # Train until the learning rate gets too small or model reaches target score
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    tgt_bleu = args.target_bleu or math.inf
    current_bleu = 0.0
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')

    ctr = 0

    class DummyEpochBatchIterator:
        def __init__(self, epoch=0):
            self.epoch = epoch

    epoch_itr = DummyEpochBatchIterator(0)
    transformer_print(key=mlperf_log.TRAIN_LOOP)
    while (lr >= args.min_lr and epoch_itr.epoch < max_epoch
           and trainer.get_num_updates() < max_update
           and current_bleu < tgt_bleu):
        transformer_print(key=mlperf_log.TRAIN_EPOCH, value=epoch_itr.epoch)
        import time
        start = time.time()
        epoch_itr = data.EpochBatchIterator(
            dataset=task.dataset(args.train_subset),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=max_positions,
            ignore_invalid_inputs=True,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            epoch=epoch_itr.epoch if ctr != 0 else 0)
        print("got epoch iterator", time.time() - start)

        # Load the latest checkpoint if one is available
        if ctr == 0:
            load_checkpoint(args, trainer, epoch_itr)

        # train for one epoch
        start = time.time()
        train(args, trainer, task, epoch_itr)
        print("epoch time ", time.time() - start)

        start = time.time()

        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)

        # Eval BLEU score
        transformer_print(key=mlperf_log.EVAL_START, value=epoch_itr.epoch)
        if args.online_eval or tgt_bleu != math.inf:
            current_bleu = score(args, trainer, task, epoch_itr,
                                 args.gen_subset)
            transformer_print(key=mlperf_log.EVAL_ACCURACY,
                              value={
                                  'epoch': epoch_itr.epoch,
                                  'value': current_bleu
                              })
            transformer_print(key=mlperf_log.EVAL_TARGET, value=tgt_bleu)
        transformer_print(key=mlperf_log.EVAL_STOP, value=epoch_itr.epoch)

        # Only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # Save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        ctr += 1
        print("validation and scoring ", time.time() - start)

    train_meter.stop()
    transformer_print(key=mlperf_log.RUN_STOP)
    transformer_print(key=mlperf_log.RUN_FINAL)
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
Example #15
def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    if not args.quiet:
        print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset,
                                       len(task.dataset(args.gen_subset))))

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task)

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)
        if args.fp16:
            model.half()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = data.EpochBatchIterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=models[0].max_positions(),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    if args.score_reference:
        translator = SequenceScorer(models, task.target_dictionary)
    else:
        translator = SequenceGenerator(
            models,
            task.target_dictionary,
            beam_size=args.beam,
            stop_early=(not args.no_early_stop),
            normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen,
            unk_penalty=args.unkpen,
            sampling=args.sampling,
            sampling_topk=args.sampling_topk,
            minlen=args.min_len,
        )

    if use_cuda:
        translator.cuda()

    # Generate and compute BLEU score
    scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True

    if args.score_reference:
        translations = translator.score_batched_itr(itr,
                                                    cuda=use_cuda,
                                                    timer=gen_timer)
    else:
        translations = translator.generate_batched_itr(
            itr,
            maxlen_a=args.max_len_a,
            maxlen_b=args.max_len_b,
            cuda=use_cuda,
            timer=gen_timer,
            prefix_size=args.prefix_size,
        )

    wps_meter = TimeMeter()
    decoded = dict()
    for sample_id, src_tokens, target_tokens, hypos in translations:
        sample_index = sample_id.tolist()
        # Process input and ground truth
        has_target = target_tokens is not None
        target_tokens = target_tokens.int().cpu() if has_target else None

        # Either retrieve the original sentences or regenerate them from tokens.
        if align_dict is not None:
            src_str = task.dataset(
                args.gen_subset).src.get_original_text(sample_id)
            target_str = task.dataset(
                args.gen_subset).tgt.get_original_text(sample_id)
        else:
            src_str = src_dict.string(src_tokens, args.remove_bpe)
            if has_target:
                target_str = tgt_dict.string(target_tokens,
                                             args.remove_bpe,
                                             escape_unk=True)

        if not args.quiet:
            print('S-{}\t{}'.format(sample_id, src_str))
            if has_target:
                print('T-{}\t{}'.format(sample_id, target_str))
        decoded[sample_index] = ["S\t{}".format(src_str)]
        if has_target:
            decoded[sample_index] = ["T\t{}".format(target_str)]

        # Process top predictions
        for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
            hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                hypo_tokens=hypo['tokens'].int().cpu(),
                src_str=src_str,
                alignment=hypo['alignment'].int().cpu(),
                align_dict=align_dict,
                tgt_dict=tgt_dict,
                remove_bpe=args.remove_bpe,
            )

            if not args.quiet:
                print('H-{}\t{}\t{}'.format(sample_id, hypo['score'],
                                            hypo_str))
                print('P-{}\t{}'.format(
                    sample_id, ' '.join(
                        map(
                            lambda x: '{:.4f}'.format(x),
                            hypo['positional_scores'].tolist(),
                        ))))
                print('A-{}\t{}'.format(
                    sample_id,
                    ' '.join(map(lambda x: str(utils.item(x)), alignment))))
            decoded[sample_index].append("H\t{}\t{}".format(
                hypo['score'], hypo_str))
            decoded[sample_index].append("P\t{}".format(' '.join(
                map(
                    lambda x: '{:.4f}'.format(x),
                    hypo['positional_scores'].tolist(),
                ))))
            decoded[sample_index].append("A\t{}".format(' '.join(
                map(lambda x: str(utils.item(x)), alignment))))

            # Score only the top hypothesis
            if has_target and i == 0:
                if align_dict is not None or args.remove_bpe is not None:
                    # Convert back to tokens for evaluation with unk replacement and/or without BPE
                    target_tokens = tokenizer.Tokenizer.tokenize(
                        target_str, tgt_dict, add_if_not_exist=True)
                scorer.add(target_tokens, hypo_tokens)

        decoded[sample_index] = "\n".join(decoded[sample_index])
        wps_meter.update(src_tokens.size(0))

        num_sentences += 1

    for i in range(num_sentences):
        print(decoded[i])

    print(
        '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset,
                                                      args.beam,
                                                      scorer.result_string()))
Example #16
File: train.py | Project: lhu17/translate
def validate(args, trainer, task, subset, extra_state):
    """Evaluate the model on the validation set and return the average loss."""
    epoch = extra_state["epoch"]

    # Initialize dataloader
    itr = data.EpochBatchIterator(
        dataset=task.dataset(subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=trainer.get_model().max_positions(),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch,
        prefix=f"valid on '{subset}' subset",
        no_progress_bar="simple")

    # reset validation loss meters
    for k in ["valid_loss", "valid_nll_loss"]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    for sample in progress:
        log_output = trainer.valid_step(sample)

        # log mid-validation stats
        stats = get_valid_stats(trainer)
        for k, v in log_output.items():
            if k in ["loss", "nll_loss"]:
                continue
            if "loss" in k:
                extra_meters[k].update(v, log_output["sample_size"])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

    # log validation stats
    stats = get_valid_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    val_loss = stats["valid_loss"]
    val_ppl = stats["valid_ppl"]

    if (extra_state["validate"]["lowest_loss"] is None
            or val_loss < extra_state["validate"]["lowest_loss"]):
        extra_state["validate"] = {
            "lowest_loss": val_loss,
            "num_since_best": 0
        }
    else:
        extra_state["validate"]["num_since_best"] += 1

    stop_due_to_val_loss = False
    if (args.stop_no_best_validate_loss >= 0
            and extra_state["validate"]["num_since_best"] >
            args.stop_no_best_validate_loss):
        stop_due_to_val_loss = True
        print(
            f"Stopping training due to validation score stagnation - last best "
            f"validation loss of {extra_state['validate']['lowest_loss']} (current loss: {val_loss}) "
            f"was {extra_state['validate']['num_since_best']} validations ago."
        )
    return val_loss, val_ppl, stop_due_to_val_loss
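The stopping rule above is a standard patience check: remember the lowest validation loss seen and stop once it is more than `stop_no_best_validate_loss` validations old. A condensed, self-contained version of the same logic:

# Condensed patience logic equivalent to the extra_state bookkeeping above.
def should_stop(val_losses, patience):
    # Stop when the best loss is more than `patience` validations old.
    if patience < 0 or not val_losses:
        return False
    best = min(range(len(val_losses)), key=val_losses.__getitem__)
    return (len(val_losses) - 1 - best) > patience

print(should_stop([3.0, 2.5, 2.6, 2.7, 2.8], patience=2))  # True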
Example #17
File: train.py | Project: lhu17/translate
def setup_training(args):
    """Parse args, load dataset, and load model trainer."""
    if not torch.cuda.is_available():
        raise NotImplementedError("Training on CPU is not supported")
    torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Setup task and load dataset
    task = tasks.setup_task(args)
    task.load_dataset(
        args.train_subset,
        args.train_source_binary_path,
        args.train_target_binary_path,
        weights_file=getattr(args, "train_weights_path", None),
    )
    task.load_dataset(args.valid_subset, args.eval_source_binary_path,
                      args.eval_target_binary_path)

    # Build model and criterion
    model = task.build_model(args)
    print("| building criterion")
    criterion = task.build_criterion(args)
    print(f"| model {args.arch}, criterion {criterion.__class__.__name__}")
    print(f"| num. model params: \
        {sum(p.numel() for p in model.parameters())}")

    # Build trainer
    if args.fp16:
        trainer = FP16Trainer(args, task, model, criterion)
    else:
        if torch.cuda.get_device_capability(0)[0] >= 7:
            print(
                "| NOTICE: your device may support faster training with --fp16"
            )
        trainer = Trainer(args, task, model, criterion)
    print(f"| training on {args.distributed_world_size} GPUs")
    print(
        f"| max tokens per GPU = {args.max_tokens} and "
        f"max sentences per GPU = {args.max_sentences}",
        flush=True,
    )

    os.makedirs(args.save_dir, exist_ok=True)

    # If --restore-file is already present under --save-dir, use that one
    # instead of --pretrained-checkpoint-file. The idea is that
    # --pretrained-checkpoint-file allows the user to specify restoring from a
    # different run's checkpoint (possibly with different training params),
    # while not polluting the previous run's checkpoint directory
    # with new checkpoints. However, if training gets interrupted
    # and the user restarts training, we want to resume from
    # the checkpoints under --save-dir, instead of
    # restarting again from the old run's checkpoint at
    # --pretrained-checkpoint-file.
    #
    # Note that if args.restore_file is an absolute path, os.path.join() will
    # ignore previous directory args and just use the absolute path as is.
    checkpoint_path = os.path.join(args.save_dir, args.restore_file)
    restore_state = True
    if os.path.exists(checkpoint_path):
        print(
            f"| Using --save-dir={args.save_dir}, --restore-file={args.restore_file}."
        )
    elif args.pretrained_checkpoint_file and os.path.exists(
            args.pretrained_checkpoint_file):
        checkpoint_path = args.pretrained_checkpoint_file
        restore_state = args.load_pretrained_checkpoint_state
        print(
            f"| Using --pretrained-checkpoint-file={args.pretrained_checkpoint_file}, "
            f"--load-pretrained-checkpoint-state={args.load_pretrained_checkpoint_state}."
        )

    extra_state = default_extra_state(args)
    if not os.path.isfile(checkpoint_path) and args.multi_model_restore_files:
        print(
            f"| Restoring individual models from {args.multi_model_restore_files}"
        )
        multi_model.import_individual_models(args.multi_model_restore_files,
                                             trainer)
    else:
        loaded, loaded_extra_state = load_existing_checkpoint(
            checkpoint_path=checkpoint_path,
            trainer=trainer,
            restore_state=restore_state,
        )
        if loaded_extra_state:
            extra_state.update(loaded_extra_state)
        if loaded:
            args.path = [checkpoint_path]
            calculate_bleu_on_subset(
                args=args,
                task=task,
                epoch_str="initial loaded checkpoint",
                offset=None,
                dataset_split=args.valid_subset,
            )
    print(f"| extra_state: {extra_state}")

    epoch_itr = data.EpochBatchIterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=trainer.get_model().max_positions(),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
    )
    epoch = extra_state["epoch"]
    if extra_state["batch_offset"] == 0:
        epoch -= 1  # this will be incremented when we call epoch_itr.next_epoch_itr()
    epoch_itr.load_state_dict({
        "epoch": epoch,
        "iterations_in_epoch": extra_state["batch_offset"],
    })

    return extra_state, trainer, task, epoch_itr
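Restoring `iterations_in_epoch` via `load_state_dict` makes the iterator resume mid-epoch after an interruption. Assuming a deterministic batch order (which is why the seed is passed above), the effect is equivalent to skipping the already-consumed batches:

import itertools

def resume(batches, iterations_in_epoch):
    # Skip the batches consumed before the interruption, continue from there.
    return itertools.islice(iter(batches), iterations_in_epoch, None)

print(list(resume(['b0', 'b1', 'b2', 'b3'], iterations_in_epoch=2)))
# ['b2', 'b3'] -- training continues with the third batch of the epoch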
Example #18
def main(args):
    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)
    if args.distributed_world_size > 1:
        assert (torch.distributed.is_initialized())
        torch.distributed.broadcast(torch.tensor([1], device="cuda"), 0)
        torch.cuda.synchronize()
    if args.max_tokens is None:
        args.max_tokens = 6000
    print(args)
    pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
    result = torch.cuda.cudart().cudaDeviceSetLimit(ctypes.c_int(0x05),
                                                    ctypes.c_int(128))
    result = torch.cuda.cudart().cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
    torch.manual_seed(args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load dataset splits
    load_dataset_splits(task, ['train', 'valid'])

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {}'.format(
        sum(p.numel() for p in model.parameters())))

    # Build trainer
    if args.fp16:
        trainer = FP16Trainer(args, task, model, criterion)
    else:
        if torch.cuda.get_device_capability(0)[0] >= 7:
            print(
                '| NOTICE: your device may support faster training with --fp16'
            )
        trainer = Trainer(args, task, model, criterion)
    if (args.online_eval or args.target_bleu) and not args.remove_bpe:
        args.remove_bpe = '@@ '
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))
    max_positions = trainer.get_model().max_positions()
    epoch_itr = data.EpochBatchIterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=8,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
    )
    # Load the latest checkpoint if one is available
    load_checkpoint(args, trainer, epoch_itr)

    # Send a dummy batch to warm the caching allocator
    dummy_batch = task.dataset('train').get_dummy_batch(
        args.max_tokens, max_positions)
    trainer.dummy_train_step(dummy_batch)

    # Train until the learning rate gets too small or model reaches target score
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    tgt_bleu = args.target_bleu or math.inf
    current_bleu = 0.0
    best_bleu = 0.0
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')

    while (lr >= args.min_lr and epoch_itr.epoch < max_epoch
           and trainer.get_num_updates() < max_update
           and current_bleu < tgt_bleu):
        # train for one epoch
        train(args, trainer, task, epoch_itr)
        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)

        # Eval BLEU score
        if args.online_eval or tgt_bleu != math.inf:
            current_bleu, current_sc_bleu = score(args, trainer, task,
                                                  epoch_itr, args.gen_subset)
            if current_bleu > best_bleu:
                best_bleu = current_bleu
                save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        # Only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # Save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
Example #19
def main(args):
    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)

    mlperf_compliance.mlperf_log.LOGGER.propagate = False

    # framework = f'Pytorch NGC {os.environ["NVIDIA_PYTORCH_VERSION"]}'
    # mlperf_submission_log(
    #     benchmark=mlperf_compliance.constants.TRANSFORMER,
    #     framework=framework)

    mlperf_compliance.mlperf_log.setdefault(
        root_dir=os.path.dirname(os.path.abspath(__file__)),
        benchmark=mlperf_compliance.constants.TRANSFORMER,
        stack_offset=1,
        extra_print=False)

    mlperf_print(key=mlperf_compliance.constants.INIT_START,
                 log_all_ranks=True)

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # preinit and warmup streams/groups for allreduce communicators
    allreduce_communicators = None
    if args.distributed_world_size > 1 and args.enable_parallel_backward_allred_opt:
        allreduce_groups = [
            torch.distributed.new_group()
            for _ in range(args.parallel_backward_allred_cuda_nstreams)
        ]
        allreduce_streams = [
            torch.cuda.Stream()
            for _ in range(args.parallel_backward_allred_cuda_nstreams)
        ]
        for group, stream in zip(allreduce_groups, allreduce_streams):
            with torch.cuda.stream(stream):
                torch.distributed.all_reduce(torch.cuda.FloatTensor(1),
                                             group=group)
        allreduce_communicators = (allreduce_groups, allreduce_streams)

    if args.max_tokens is None:
        args.max_tokens = 6000

    print(args)

    mlperf_print(key=mlperf_compliance.constants.GLOBAL_BATCH_SIZE,
                 value=args.max_tokens * args.distributed_world_size)
    mlperf_print(key=mlperf_compliance.constants.OPT_NAME,
                 value=args.optimizer)
    assert (len(args.lr) == 1)
    mlperf_print(key=mlperf_compliance.constants.OPT_BASE_LR,
                 value=args.lr[0] if len(args.lr) == 1 else args.lr)
    mlperf_print(key=mlperf_compliance.constants.OPT_LR_WARMUP_STEPS,
                 value=args.warmup_updates)
    assert (args.max_source_positions == args.max_target_positions)
    mlperf_print(key=mlperf_compliance.constants.MAX_SEQUENCE_LENGTH,
                 value=args.max_target_positions)
    mlperf_print(key=mlperf_compliance.constants.OPT_ADAM_BETA_1,
                 value=eval(args.adam_betas)[0])
    mlperf_print(key=mlperf_compliance.constants.OPT_ADAM_BETA_2,
                 value=eval(args.adam_betas)[1])
    mlperf_print(key=mlperf_compliance.constants.OPT_ADAM_EPSILON,
                 value=args.adam_eps)

    pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
    result = torch.cuda.cudart().cudaDeviceSetLimit(ctypes.c_int(0x05),
                                                    ctypes.c_int(128))
    result = torch.cuda.cudart().cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))


    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)

    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {}'.format(
        sum(p.numel() for p in model.parameters())))

    # Build trainer
    if args.fp16:
        trainer = FP16Trainer(args,
                              task,
                              model,
                              criterion,
                              allreduce_communicators=allreduce_communicators)
    else:
        if torch.cuda.get_device_capability(0)[0] >= 7:
            print(
                '| NOTICE: your device may support faster training with --fp16'
            )

        trainer = Trainer(args,
                          task,
                          model,
                          criterion,
                          allreduce_communicators=None)

    #if (args.online_eval or args.target_bleu) and not args.remove_bpe:
    #    args.remove_bpe='@@ '

    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Initialize dataloader
    max_positions = trainer.get_model().max_positions()

    # Send a dummy batch to warm the caching allocator
    dummy_batch = language_pair_dataset.get_dummy_batch_isolated(
        args.max_tokens, max_positions, 8)
    trainer.dummy_train_step(dummy_batch)

    # Train until the learning rate gets too small or model reaches target score
    max_epoch = args.max_epoch if args.max_epoch >= 0 else math.inf
    max_update = args.max_update or math.inf
    tgt_bleu = args.target_bleu or math.inf
    current_bleu = 0.0
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')

    # mlperf compliance synchronization
    if args.distributed_world_size > 1:
        assert (torch.distributed.is_initialized())
        torch.distributed.all_reduce(torch.cuda.FloatTensor(1))
        torch.cuda.synchronize()

    mlperf_print(key=mlperf_compliance.constants.INIT_STOP, sync=True)

    mlperf_print(key=mlperf_compliance.constants.RUN_START, sync=True)
    # second sync after RUN_START tag is printed.
    # this ensures no rank touches data until after RUN_START tag is printed.
    barrier()

    # Load dataset splits
    load_dataset_splits(task, ['train', 'test'])

    ctr = 0

    class DummyEpochBatchIterator:
        def __init__(self, epoch=0):
            self.epoch = epoch

    epoch_itr = DummyEpochBatchIterator(0)

    # Main training loop
    while (lr >= args.min_lr and epoch_itr.epoch < max_epoch
           and trainer.get_num_updates() < max_update
           and current_bleu < tgt_bleu):
        first_epoch = epoch_itr.epoch + 1
        mlperf_print(key=mlperf_compliance.constants.BLOCK_START,
                     metadata={
                         'first_epoch_num': first_epoch,
                         'epoch_count': 1
                     },
                     sync=True)
        mlperf_print(key=mlperf_compliance.constants.EPOCH_START,
                     metadata={'epoch_num': first_epoch},
                     sync=True)
        start = time.time()

        gc.disable()

        epoch_itr = data.EpochBatchIterator(
            dataset=task.dataset(args.train_subset),
            dataloader_num_workers=args.dataloader_num_workers,
            dataloader_pin_memory=args.enable_dataloader_pin_memory,
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=max_positions,
            ignore_invalid_inputs=True,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            epoch=epoch_itr.epoch if ctr != 0 else 0,
            bucket_growth_factor=args.bucket_growth_factor,
            seq_len_multiple=args.seq_len_multiple,
            batching_scheme=args.batching_scheme,
            batch_multiple_strategy=args.batch_multiple_strategy,
        )

        print("got epoch iterator", time.time() - start)

        # Load the latest checkpoint if one is available
        if ctr == 0:
            load_checkpoint(args, trainer, epoch_itr)

        # train for one epoch
        start = time.time()
        train(args, trainer, task, epoch_itr)
        print("epoch time ", time.time() - start)

        start = time.time()
        mlperf_print(key=mlperf_compliance.constants.EPOCH_STOP,
                     metadata={'epoch_num': first_epoch},
                     sync=True)

        #if epoch_itr.epoch % args.validate_interval == 0:
        #    valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)

        # Eval BLEU score
        if args.online_eval or tgt_bleu != math.inf:
            current_bleu = score(args, trainer, task, epoch_itr,
                                 args.gen_subset)
            mlperf_print(key=mlperf_compliance.tags.EVAL_ACCURACY,
                         value=str(current_bleu),
                         metadata={'epoch_num': first_epoch})

        gc.enable()

        # Only use first validation loss to update the learning rate
        #lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # Save checkpoint
        #if epoch_itr.epoch % args.save_interval == 0:
        #    save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        ctr += 1
        print("validation and scoring ", time.time() - start)
        mlperf_print(key=mlperf_compliance.constants.BLOCK_STOP,
                     metadata={'first_epoch_num': first_epoch},
                     sync=True)

    train_meter.stop()
    status = 'success' if current_bleu >= tgt_bleu else 'aborted'
    mlperf_print(key=mlperf_compliance.constants.RUN_STOP,
                 metadata={'status': status})
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
Example #20
def main(args):
    if args.max_tokens is None:
        args.max_tokens = 6000
    print(args)

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load dataset splits
    load_dataset_splits(task, ['train', 'valid'])

    # print("<AFTER>load_dataset_splits")
    # Build model and criterion
    model = task.build_model(args)
    print('| num. model params: {}'.format(sum(p.numel() for p in model.parameters())))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))
    # print("<AFTER>build_model")

    # Validation
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
    val_criterion = task.build_criterion(args, 'label_smoothed_cross_entropy')
    val_trainer = Trainer(args, task, model, val_criterion)

    class_pretrain_flag = False
    mt_pretrain_flag = False

    # Pre-training on the CNN discriminator and Seq2Seq reconstruction
    if args.task == 'style_transfer':
        # classification pretrain
        criterion = task.build_criterion(args, 'classification')
        trainer = Trainer(args, task, model, criterion)
        print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
        max_positions = trainer.get_model().max_positions()
        epoch_itr = data.EpochBatchIterator(
            dataset=task.dataset('train'),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=max_positions,
            ignore_invalid_inputs=True,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
        )

        # Load the latest checkpoint if one is available
        load_checkpoint(args, trainer, epoch_itr, load_optim=True, find_best=args.restore_best)

        max_epoch = args.pre_train_max_epoch
        while epoch_itr.epoch < max_epoch:
            class_pretrain_flag = True

            # train for one epoch
            train(args, trainer, task, epoch_itr)
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)

            # save to checkpoint
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        print("Done classification pretrain")

        # MT pretrain
        criterion = task.build_criterion(args, 'style_transfer_pretrain')
        trainer = Trainer(args, task, model, criterion)
        print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))

        # Load the latest checkpoint if one is available
        if epoch_itr.epoch <= args.pre_train_max_epoch:
            load_checkpoint(args, trainer, epoch_itr, load_optim=False, find_best=True)
            epoch_itr.epoch = args.pre_train_max_epoch
            save_checkpoint.best = float("inf")
        else:
            load_checkpoint(args, trainer, epoch_itr, load_optim=True, find_best=args.restore_best)

        # Send a dummy batch to warm the caching allocator
        dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions)
        trainer.dummy_train_step(dummy_batch)

        max_epoch = 2 * args.pre_train_max_epoch
        while epoch_itr.epoch < max_epoch:
            mt_pretrain_flag = True

            # train for one epoch
            train(args, trainer, task, epoch_itr)
            valid_losses = validate(args, val_trainer, task, epoch_itr, valid_subsets)

            # save to checkpoint
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        print("Done MT pretrain")

    # Training
    if args.task == 'style_transfer':
         criterion_name = "style_transfer_train"
         print("Loading plain data")
         load_dataset_splits(task, ['plain'])
    else:
         criterion_name = None
    criterion = task.build_criterion(args, criterion_name)
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))

    trainer = Trainer(args, task, model, criterion)

    # Load the latest checkpoint if one is available
    if epoch_itr.epoch <= 2 * args.pre_train_max_epoch:
        load_checkpoint(args, trainer, epoch_itr, load_optim=False,
                        fix_discriminator=True, find_best=True)
    else:
        load_checkpoint(args, trainer, epoch_itr, load_optim=True,
                        fix_discriminator=True, find_best=args.restore_best)
        print("# WARNING: Loading checkpoint with optimizer")

    # Initialize dataloader
    max_positions = trainer.get_model().max_positions()
    if args.task == 'style_transfer':
        src_plain_epoch_iter = data.EpochBatchIterator(
            dataset=task.dataset('plain')[0],
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=max_positions,
            ignore_invalid_inputs=True,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
        )
        trg_plain_epoch_iter = data.EpochBatchIterator(
            dataset=task.dataset('plain')[1],
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=max_positions,
            ignore_invalid_inputs=True,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
        )
        pre_train_max_epoch = 2 * args.pre_train_max_epoch
    else:
        # Guard against NameError below when no style-transfer pretraining ran.
        src_plain_epoch_iter = None
        trg_plain_epoch_iter = None
        pre_train_max_epoch = 0

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()

    while (lr > args.min_lr
           and epoch_itr.epoch < (max_epoch + pre_train_max_epoch)
           and trainer.get_num_updates() < max_update):
        # train for one epoch
        train(args, trainer, task, epoch_itr,
                use_plain=(args.task=='style_transfer'),
                src_plain_epoch_iter=src_plain_epoch_iter,
                trg_plain_epoch_iter=trg_plain_epoch_iter,
            )

        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(
                    args, val_trainer, task, epoch_itr, valid_subsets)

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
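
The epoch arithmetic in the example above is the subtle part: a single
epoch_itr.epoch counter runs through classification pretraining, MT
pretraining, and the final style-transfer phase. A self-contained sketch of
that schedule (the helper name and asserts are ours, not from the snippet):

def phase_for_epoch(epoch, pre_train_max_epoch, max_epoch):
    """Map a global epoch number onto the three training phases above."""
    if epoch <= pre_train_max_epoch:
        return 'classification_pretrain'
    if epoch <= 2 * pre_train_max_epoch:
        return 'mt_pretrain'
    if epoch <= 2 * pre_train_max_epoch + max_epoch:
        return 'style_transfer_train'
    return 'done'

assert phase_for_epoch(3, 5, 10) == 'classification_pretrain'
assert phase_for_epoch(7, 5, 10) == 'mt_pretrain'
assert phase_for_epoch(12, 5, 10) == 'style_transfer_train'
assert phase_for_epoch(21, 5, 10) == 'done'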
예제 #21
파일: train.py 프로젝트: sk210892/fairseq
def main(args):
    if args.max_tokens is None:
        args.max_tokens = 6000
    print(args)

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load dataset splits
    load_dataset_splits(args, task, ['train', 'valid'])

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {}'.format(
        sum(p.numel() for p in model.parameters())))

    # Build trainer
    if args.fp16:
        trainer = FP16Trainer(args, task, model, criterion)
    else:
        if torch.cuda.get_device_capability(0)[0] >= 7:
            print(
                '| NOTICE: your device may support faster training with --fp16'
            )
        trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Initialize dataloader
    max_positions = trainer.get_model().max_positions()
    epoch_itr = data.EpochBatchIterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=8,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
    )

    # Load the latest checkpoint if one is available
    load_checkpoint(args, trainer, epoch_itr)

    # Send a dummy batch to warm the caching allocator
    dummy_batch = task.dataset('train').get_dummy_batch(
        args.max_tokens, max_positions)
    trainer.dummy_train_step(dummy_batch)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
    while (lr > args.min_lr and epoch_itr.epoch <= max_epoch
           and trainer.get_num_updates() < max_update):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
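
Both training drivers above rely on the same contract: EpochBatchIterator is
constructed once over the training split, load_checkpoint() restores its epoch
counter, and each next_epoch_itr() call advances it by one. A toy illustration
of that contract (our own stand-in class, not fairseq's implementation):

class ToyEpochIterator:
    """Minimal stand-in that tracks epochs across next_epoch_itr() calls."""

    def __init__(self, batches):
        self.batches = batches
        self.epoch = 0

    def next_epoch_itr(self, shuffle=True):
        self.epoch += 1
        return iter(self.batches)

itr = ToyEpochIterator([[1, 2], [3, 4]])
for _ in itr.next_epoch_itr(shuffle=False):
    pass
assert itr.epoch == 1  # this counter is what save/load_checkpoint persist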
예제 #22
def score(args, trainer, task, epoch_itr, subset):

    begin = time.time()

    if subset not in task.datasets:
        task.load_dataset(subset)

    # Deep-copy the dictionaries: generating translations mutates the target
    # dictionary, which would corrupt the rest of training.
    src_dict = deepcopy(task.source_dictionary)
    tgt_dict = deepcopy(task.target_dictionary)

    model = trainer.get_model()

    #mlperf_log.transformer_print(key=mlperf_log.EVAL_SIZE, value=task.dataset(subset).__len__())
    # Initialize data iterator
    itr = data.EpochBatchIterator(
        dataset=task.dataset(subset),
        max_tokens=None,
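        # Heuristic batch size (our reading): aim for ~1024 sentences per eval
        # step split across workers, clamped to [8, 128] sentences per shard.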
        max_sentences=max(
            8, min(math.ceil(1024 / args.distributed_world_size), 128)),
        max_positions=model.max_positions(),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    translator = SequenceGenerator(
        [model],
        tgt_dict,
        beam_size=args.beam,
        stop_early=(not args.no_early_stop),
        normalize_scores=(not args.unnormalized),
        len_penalty=args.lenpen,
        unk_penalty=args.unkpen,
        sampling=args.sampling,
        sampling_topk=args.sampling_topk,
        minlen=args.min_len,
    )
    # Generate and compute BLEU
    bleu_dict = dictionary.Dictionary()  # avoid shadowing the built-in dict
    scorer = bleu.Scorer(bleu_dict.pad(), bleu_dict.eos(), bleu_dict.unk())
    num_sentences = 0
    has_target = True
    if args.log_translations:
        log = open(
            os.path.join(
                args.save_dir,
                'translations_epoch{}_{}'.format(epoch_itr.epoch,
                                                 args.distributed_rank)), 'w+')
    with progress_bar.build_progress_bar(args, itr) as progress:
        translations = translator.generate_batched_itr(
            progress,
            maxlen_a=args.max_len_a,
            maxlen_b=args.max_len_b,
            cuda=True,
            timer=gen_timer,
            prefix_size=args.prefix_size,
        )

        wps_meter = TimeMeter()
        for sample_id, src_tokens, target_tokens, hypos in translations:
            # Process input and ground truth
            has_target = target_tokens is not None
            target_tokens = target_tokens.int().cpu() if has_target else None

            src_str = src_dict.string(src_tokens, args.remove_bpe)
            if has_target:
                target_str = tgt_dict.string(target_tokens,
                                             args.remove_bpe,
                                             escape_unk=True)

            if args.log_translations:
                log.write('S-{}\t{}\n'.format(sample_id, src_str))
                if has_target:
                    log.write('T-{}\t{}\n'.format(sample_id, target_str))

            # Process top predictions
            for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu()
                    if hypo['alignment'] is not None else None,
                    align_dict=None,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe)
                if args.log_translations:
                    log.write('H-{}\t{}\t{}\n'.format(sample_id, hypo['score'],
                                                      hypo_str))
                    # log.write(str(hypo_tokens))
                    log.write('P-{}\t{}\n'.format(
                        sample_id, ' '.join(
                            map(
                                lambda x: '{:.4f}'.format(x),
                                hypo['positional_scores'].tolist(),
                            ))))

                # Score only the top hypothesis
                if has_target and i == 0:
                    sys_tok = tokenizer.Tokenizer.tokenize(
                        (hypo_str.lower() if args.ignore_case else hypo_str),
                        bleu_dict)
                    ref_tok = tokenizer.Tokenizer.tokenize(
                        (target_str.lower()
                         if args.ignore_case else target_str), bleu_dict)
                    scorer.add(ref_tok, sys_tok)

            wps_meter.update(src_tokens.size(0))
            progress.log({'wps': round(wps_meter.avg)})
            num_sentences += 1

    if args.distributed_world_size > 1:
        _all_gather_bleu_scorer(scorer)
    if args.log_translations:
        log.close()
    if gen_timer.sum != 0:
        print(
            '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
            .format(num_sentences, gen_timer.n, gen_timer.sum,
                    num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(subset, args.beam,
                                                      scorer.result_string()))

    print('| Eval completed in: {:.2f}s'.format(time.time() - begin))

    return scorer.score(order=4)
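
The BLEU bookkeeping above reduces to a small idiom: one shared Dictionary
tokenizes both sides, scorer.add() consumes (reference, hypothesis) pairs, and
result_string() reports corpus BLEU. A minimal usage sketch, assuming the same
fairseq-style API as the snippet:

bleu_dict = dictionary.Dictionary()
scorer = bleu.Scorer(bleu_dict.pad(), bleu_dict.eos(), bleu_dict.unk())
ref_tok = tokenizer.Tokenizer.tokenize('the cat sat .', bleu_dict)
sys_tok = tokenizer.Tokenizer.tokenize('the cat sat .', bleu_dict)
scorer.add(ref_tok, sys_tok)  # reference first, then system hypothesis
print(scorer.result_string())  # identical strings score BLEU4 = 100.00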
예제 #23
def main(args):
    assert args.path is not None, '--path required for evaluation!'

    args.tokens_per_sample = getattr(args, 'tokens_per_sample', 1024)
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset,
                                       len(task.dataset(args.gen_subset))))

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task)

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()

    assert len(models) > 0

    itr = data.EpochBatchIterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=models[0].max_positions(),
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        ignore_invalid_inputs=True,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(models, task.target_dictionary)
    if use_cuda:
        scorer.cuda()

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        bpe_cont = args.remove_bpe.rstrip()
        bpe_toks = set(i for i in range(len(task.dictionary))
                       if task.dictionary[i].endswith(bpe_cont))
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    with progress_bar.build_progress_bar(args, itr) as t:
        results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        wps_meter = TimeMeter()
        for _, src_tokens, __, hypos in results:
            for hypo in hypos:
                pos_scores = hypo['positional_scores']

                skipped_toks = 0
                if bpe_toks is not None:
                    for i in range(len(hypo['tokens']) - 1):
                        if hypo['tokens'][i].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[i + 1] += pos_scores[i]
                            pos_scores[i] = 0

                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(
                    float('-inf'))
                if inf_scores.any():
                    print(
                        '| Skipping tokens with inf scores:',
                        task.target_dictionary.string(
                            hypo['tokens'][inf_scores.nonzero()]))
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += pos_scores.sum()
                count += pos_scores.numel() - skipped_toks

                if args.output_word_probs:
                    w = ''
                    word_prob = []
                    for i in range(len(hypo['tokens'])):
                        w_ind = hypo['tokens'][i].item()
                        w += task.dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            w = w[:-bpe_len]
                        else:
                            word_prob.append((w, pos_scores[i].item()))
                            w = ''
                    print('\t'.join('{} [{:.2f}]'.format(x[0], x[1])
                                    for x in word_prob))

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
        gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss,
                                                      np.exp(avg_nll_loss)))
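
The inner loop above folds BPE continuation scores forward so that each full
word carries a single log-probability. The same rule in isolation, on plain
Python lists (function name and '@@' marker are illustrative):

def merge_bpe_scores(tokens, scores, cont='@@'):
    """Merge per-subword scores into (word, total_score) pairs."""
    merged, word, acc = [], '', 0.0
    for tok, score in zip(tokens, scores):
        acc += score
        if tok.endswith(cont):
            word += tok[:-len(cont)]   # continuation: defer to the next piece
        else:
            merged.append((word + tok, acc))
            word, acc = '', 0.0
    return merged

assert merge_bpe_scores(['lo@@', 'ving', 'it'], [-1.0, -0.5, -2.0]) == \
    [('loving', -1.5), ('it', -2.0)]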
예제 #24
def eval_from_file(models, task, args, use_cuda, source_filename=None,
                   target_filename=None, score_filename=None):
    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # I/O files
    source_filename = source_filename if source_filename is not None else args.source_file
    target_filename = target_filename if target_filename is not None else args.target_file
    score_filename = score_filename if score_filename is not None else args.score_file
    if score_filename is None:
        score_filename = target_filename + ".eval.score"
    outfile = open(score_filename, "w")

    # Get sorted input (and reversed)
    sorted_inputs, sorted_keys, sorted_targets = _get_sorted_inputs(
        source_filename, args.num_shards, args.delimiter, target_filename, args.shard_id,
        args.dup_src, args.dup_tgt)

    # Build input iterator
    src_tokens = [
        tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
        for src_str in sorted_inputs]
    tgt_tokens = [
        tokenizer.Tokenizer.tokenize(tgt_str, tgt_dict, add_if_not_exist=False).long()
        for tgt_str in sorted_targets] if sorted_targets is not None else None
    src_sizes = np.array([t.numel() for t in src_tokens])
    tgt_sizes = np.array(
        [t.numel() for t in tgt_tokens]) if tgt_tokens is not None else None
    print('| loading {} examples, {} tokens'.format(len(sorted_inputs), sum(src_sizes)))

    dataset = data.LanguagePairDataset(
        src_tokens, src_sizes, src_dict, tgt_tokens, tgt_sizes, tgt_dict, shuffle=False)
    itr = data.EpochBatchIterator(
        dataset=dataset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=models[0].max_positions(),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(models, task.target_dictionary)
    if use_cuda:
        scorer.cuda()

    all_scores = dict()
    score_sum = 0.
    count, sen_count = 0, 0
    results = scorer.score_batched_itr(itr, cuda=use_cuda, timer=gen_timer)
    wps_meter = TimeMeter()
    for sample_id, src_tokens, target_tokens, hypos in results:
        for i, hypo in enumerate(hypos):
            pos_scores = hypo['positional_scores']
            inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
            if inf_scores.any():
                print('| Skipping tokens with inf scores:',
                      task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
                pos_scores = pos_scores[(~inf_scores).nonzero()]
            score_sum += pos_scores.sum()
            count += pos_scores.numel()
            sentence_score = hypo['score']
            if i == 0:
                all_scores[sample_id.tolist()] = sentence_score
        sen_count += 1
        wps_meter.update(src_tokens.size(0))

    print("| [eval] writing scores into {}".format(score_filename))
    for index in range(len(sorted_inputs)):
        outfile.write("{}{}".format(all_scores[sorted_keys[index]], args.delimiter))
    outfile.close()

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))
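
_get_sorted_inputs is not shown in this excerpt, but the sorted_keys lookup
above implies the usual trick: sort inputs by length for efficient batching,
and keep an index that maps each original position to its sorted position so
outputs can be written back in input order. A sketch of that assumed behavior:

def sort_by_length(lines):
    """Sort lines longest-first; keys[orig_index] gives the sorted position."""
    order = sorted(range(len(lines)), key=lambda i: len(lines[i].split()),
                   reverse=True)
    sorted_lines = [lines[i] for i in order]
    keys = [0] * len(lines)
    for pos, orig in enumerate(order):
        keys[orig] = pos
    return sorted_lines, keys

lines = ['a b c', 'a', 'a b']
sorted_lines, keys = sort_by_length(lines)
assert [sorted_lines[k] for k in keys] == lines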
예제 #25
def main(args):

    print(args)
    setup_logger(args)

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)
    if args.distributed_world_size > 1:
        assert torch.distributed.is_initialized()
        torch.distributed.broadcast(torch.tensor([1], device="cuda"), 0)
        torch.cuda.synchronize()
    if args.max_tokens is None:
        args.max_tokens = 6000
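    # The two ctypes calls below appear to raise cudaLimitMaxL2FetchGranularity
    # (device limit 0x05) to 128 bytes and read it back into pValue to verify;
    # this assumes libcudart.so is loadable from the default search path.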
    pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
    ctypes.CDLL('libcudart.so').cudaDeviceSetLimit(ctypes.c_int(0x05),
                                                   ctypes.c_int(128))
    ctypes.CDLL('libcudart.so').cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
    torch.manual_seed(args.seed)

    src_dict, tgt_dict = data_utils.load_dictionaries(args)
    add_extra_items_to_checkpoint({'src_dict': src_dict, 'tgt_dict': tgt_dict})
    datasets = load_dataset_splits(args, ['train', 'valid', 'test'], src_dict,
                                   tgt_dict)

    model = build_model(args)
    print('| num. model params: {}'.format(
        sum(p.numel() for p in model.parameters())))

    # Build trainer
    if torch.cuda.get_device_capability(0)[0] >= 7 and not args.amp:
        print('| NOTICE: your device may support faster training with --amp')
    trainer = DDPTrainer(args, model)
    print('| model {}, criterion {}'.format(
        args.arch, trainer.criterion.__class__.__name__))

    if (args.online_eval or args.target_bleu) and not args.remove_bpe:
        args.remove_bpe = '@@ '
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    epoch_itr = data.EpochBatchIterator(
        dataset=datasets[args.train_subset],
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=args.max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=8,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
    )
    # Load the latest checkpoint if one is available
    load_checkpoint(args, trainer, epoch_itr)

    # Send a dummy batch to warm the caching allocator
    dummy_batch = data_utils.get_dummy_batch(args.max_tokens, src_dict,
                                             tgt_dict)
    trainer.dummy_train_step(dummy_batch)

    # Sanity check
    if args.do_sanity_check:
        print('Performing sanity check...')
        sanity_score = score(args, trainer, datasets['test'], src_dict,
                             tgt_dict, 'test.raw.de')
        DLLogger.log(step='SANITY_CHECK',
                     data={'sanity_check_score': sanity_score},
                     verbosity=1)

    # Train until the learning rate gets too small or model reaches target score
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    tgt_bleu = args.target_bleu or math.inf
    current_bleu = 0.0
    best_bleu = -1.0
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
    run_summary = {
        'loss': float('inf'),
        'val_loss': float('inf'),
        'speed': 0,
        'accuracy': 0
    }

    while (lr >= args.min_lr and epoch_itr.epoch < max_epoch
           and trainer.get_num_updates() < max_update
           and current_bleu < tgt_bleu):
        DLLogger.log(step=trainer.get_num_updates(),
                     data={'epoch': epoch_itr.epoch},
                     verbosity=0)
        # train for one epoch
        with torch.autograd.profiler.profile(enabled=args.profile,
                                             use_cuda=True) as prof:
            train(args, trainer, datasets, epoch_itr)
        if args.profile:
            print(prof.key_averages().table(sort_by="cuda_time_total"))
            if args.profiler_file:
                with open(os.path.join(args.save_dir, args.profiler_file),
                          'w') as f:
                    f.write(
                        prof.key_averages().table(sort_by="cuda_time_total"))
            exit(0)

        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, datasets, valid_subsets)
            valid_bleu = score(args, trainer, datasets[valid_subsets[0]],
                               src_dict, tgt_dict, 'valid.raw.de')
            DLLogger.log(step=trainer.get_num_updates(),
                         data={
                             'val_loss': valid_losses[0],
                             'val_bleu': valid_bleu
                         },
                         verbosity=1)

        # Eval BLEU score
        if args.online_eval or tgt_bleu is not math.inf:
            current_bleu = score(args, trainer, datasets[args.gen_subset],
                                 src_dict, tgt_dict, 'test.raw.de')
            DLLogger.log(step=trainer.get_num_updates(),
                         data={'test_bleu': current_bleu},
                         verbosity=1)
            if current_bleu > best_bleu:
                best_bleu = current_bleu
                DLLogger.log(step='RUN', data={'BLEU': best_bleu}, verbosity=0)
                save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if valid_losses[0] is not None and valid_losses[0] < run_summary['val_loss']:
            run_summary['val_loss'] = valid_losses[0]
            if best_bleu < 0:
                run_summary['accuracy'] = valid_bleu
            else:
                run_summary['accuracy'] = best_bleu
        run_summary['loss'] = valid_losses[0]
        run_summary['speed'] = trainer.throughput_meter.u_avg

        # Only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # Save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

    train_meter.stop()
    DLLogger.log(step=[], data=run_summary, verbosity=0)
    DLLogger.log(step='RUN', data={'walltime': train_meter.sum}, verbosity=0)
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
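
Training in this example stops on whichever bound trips first: the
learning-rate floor, the epoch or update budget, or the target BLEU. The same
condition restated as a tiny predicate (helper name and defaults are ours):

import math

def should_stop(lr, epoch, updates, bleu, min_lr=0.0,
                max_epoch=math.inf, max_update=math.inf, tgt_bleu=math.inf):
    return (lr < min_lr or epoch >= max_epoch
            or updates >= max_update or bleu >= tgt_bleu)

assert not should_stop(1e-3, 3, 100, 20.0, tgt_bleu=25.0)
assert should_stop(1e-3, 3, 100, 25.5, tgt_bleu=25.0)  # target BLEU reached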
예제 #26
def score(args, trainer, task, epoch_itr, subset):

    mlperf_print(key=mlperf_compliance.constants.EVAL_START,
                 metadata={'epoch_num': epoch_itr.epoch},
                 sync=True)
    begin = time.time()

    if subset not in task.datasets:
        task.load_dataset(subset)

    # Deep-copy the dictionaries: generating translations mutates the target
    # dictionary, which would corrupt the rest of training.
    src_dict = deepcopy(task.source_dictionary)
    tgt_dict = deepcopy(task.target_dictionary)

    model = trainer.get_model()

    # Initialize data iterator
    itr = data.EpochBatchIterator(
        dataset=task.dataset(subset),
        max_tokens=min(2560, args.max_tokens),
        max_sentences=max(
            8, min(math.ceil(1024 / args.distributed_world_size), 128)),
        max_positions=(256, 256),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        seq_len_multiple=args.seq_len_multiple,
        # Use a large growth factor to get fewer buckets.
        # Fewer buckets yield faster eval since batches are filled from single bucket
        # and eval dataset is small.
        bucket_growth_factor=1.2,
        batching_scheme=args.batching_scheme,
        batch_multiple_strategy=args.batch_multiple_strategy,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    translator = SequenceGenerator(
        [model],
        tgt_dict,
        beam_size=args.beam,
        stop_early=(not args.no_early_stop),
        normalize_scores=(not args.unnormalized),
        len_penalty=args.lenpen,
        sampling=args.sampling,
        sampling_topk=args.sampling_topk,
        minlen=args.min_len,
    )
    # Generate and compute BLEU
    ref_toks = []
    sys_toks = []
    num_sentences = 0
    has_target = True
    if args.log_translations:
        log = open(
            os.path.join(
                args.save_dir,
                'translations_epoch{}_{}'.format(epoch_itr.epoch,
                                                 args.distributed_rank)), 'w+')
    with progress_bar.build_progress_bar(args, itr) as progress:
        translations = translator.generate_batched_itr(
            progress,
            maxlen_a=args.max_len_a,
            maxlen_b=args.max_len_b,
            cuda=True,
            timer=gen_timer,
            prefix_size=args.prefix_size,
        )

        wps_meter = TimeMeter()
        for sample_id, src_tokens, target_tokens, hypos in translations:
            # Process input and ground truth
            has_target = target_tokens is not None
            target_tokens = target_tokens.int().cpu() if has_target else None

            src_str = src_dict.string(src_tokens, args.remove_bpe)
            if has_target:
                target_str = tgt_dict.string(target_tokens, args.remove_bpe)

            if args.log_translations:
                log.write('S-{}\t{}\n'.format(sample_id, src_str))
                if has_target:
                    log.write('T-{}\t{}\n'.format(sample_id, target_str))

            # Process top predictions
            for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu()
                    if hypo['alignment'] is not None else None,
                    align_dict=None,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe)
                if args.log_translations:
                    log.write('H-{}\t{}\t{}\n'.format(sample_id, hypo['score'],
                                                      hypo_str))
                    # log.write(str(hypo_tokens))
                    log.write('P-{}\t{}\n'.format(
                        sample_id, ' '.join(
                            map(
                                lambda x: '{:.4f}'.format(x),
                                hypo['positional_scores'].tolist(),
                            ))))

                # Score only the top hypothesis
                if has_target and i == 0:
                    src_str = detokenize_subtokenized_sentence(src_str)
                    target_str = detokenize_subtokenized_sentence(target_str)
                    hypo_str = detokenize_subtokenized_sentence(hypo_str)
                    sys_tok = bleu_tokenize(
                        (hypo_str.lower() if args.ignore_case else hypo_str))
                    ref_tok = bleu_tokenize((target_str.lower() if
                                             args.ignore_case else target_str))
                    sys_toks.append(sys_tok)
                    ref_toks.append(ref_tok)

            wps_meter.update(src_tokens.size(0))
            progress.log({'wps': round(wps_meter.avg)})
            num_sentences += 1

    bleu_score_reference = compute_bleu(ref_toks, sys_toks, args)
    bleu_score_reference_str = '{:.4f}'.format(bleu_score_reference)
    if args.log_translations:
        log.close()
    if gen_timer.sum != 0:
        print(
            '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
            .format(num_sentences, gen_timer.n, gen_timer.sum,
                    num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: bleu_score={}'.format(
            subset, args.beam, bleu_score_reference_str))
    print('| Eval completed in: {:.2f}s'.format(time.time() - begin))
    mlperf_print(key=mlperf_compliance.constants.EVAL_STOP,
                 metadata={'epoch_num': epoch_itr.epoch},
                 sync=True)

    return bleu_score_reference
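
bleu_tokenize and detokenize_subtokenized_sentence come from the surrounding
project and are not shown here; the essential step is splitting punctuation
away from words before scoring, as in the mteval reference tokenizer. A
deliberately simplified stand-in:

import re

def simple_bleu_tokenize(segment):
    # Isolate each punctuation mark with spaces, then split on whitespace.
    return re.sub(r'([^\w\s])', r' \1 ', segment).split()

assert simple_bleu_tokenize('Hello, world!') == ['Hello', ',', 'world', '!']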
예제 #27
def decode_from_file(models, task, args, use_cuda, source_filename=None,
                     target_filename=None, output_filename=None):
    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # I/O files
    source_filename = source_filename if source_filename is not None else args.decode_source_file
    target_filename = target_filename if target_filename is not None else args.decode_target_file
    output_filename = output_filename if output_filename is not None else args.decode_output_file
    if output_filename is not None:
        base_filename = output_filename
    else:
        base_filename = source_filename
        if args.num_shards:
            base_filename += "%.2d" % args.shard_id
    decode_filename = _decode_filename(base_filename, args)
    outfile = open(decode_filename, "w")
    if args.decode_to_file:
        print("| [decode] writing decodes into {}".format(decode_filename))

    # Get sorted input (and reversed)
    sorted_inputs, sorted_keys, sorted_targets = _get_sorted_inputs(
        source_filename, args.num_shards, args.delimiter, target_filename, args.shard_id)

    # Build input iterator
    src_tokens = [
        tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
        for src_str in sorted_inputs]
    src_sizes = np.array([t.numel() for t in src_tokens])
    tgt_tokens = [
        tokenizer.Tokenizer.tokenize(tgt_str, tgt_dict, add_if_not_exist=False).long()
        for tgt_str in sorted_targets] if sorted_targets is not None else None
    tgt_sizes = np.array([t.numel() for t in tgt_tokens]) if tgt_tokens is not None else None
    print('| loading {} examples, {} tokens'.format(len(sorted_inputs), sum(src_sizes)))

    dataset = data.LanguagePairDataset(
        src_tokens, src_sizes, src_dict, tgt_tokens, tgt_sizes, tgt_dict, shuffle=False)
    itr = data.EpochBatchIterator(
        dataset=dataset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=models[0].max_positions(),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    if args.score_reference:
        translator = SequenceScorer(models, task.target_dictionary)
    else:
        translator = SequenceGenerator(
            models, task.target_dictionary, beam_size=args.beam,
            stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen, unk_penalty=args.unkpen,
            sampling=args.sampling, sampling_topk=args.sampling_topk, minlen=args.min_len,
        )

    if use_cuda:
        translator.cuda()

    # Generate and compute BLEU score
    scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True

    if args.score_reference:
        translations = translator.score_batched_itr(itr, cuda=use_cuda, timer=gen_timer)
    else:
        translations = translator.generate_batched_itr(
            itr, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
            cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size,
        )

    decodes = dict()
    sids = []
    wps_meter = TimeMeter()
    start = time.perf_counter()
    for sample_id, src_tokens, target_tokens, hypos in translations:
        # Process input and ground truth
        has_target = target_tokens is not None
        target_tokens = target_tokens.int().cpu() if has_target else None

        # Either retrieve the original sentences or regenerate them from tokens.
        if align_dict is not None:
            src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
            target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
        else:
            src_str = src_dict.string(src_tokens, args.remove_bpe)
            if has_target:
                target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)

        if not args.quiet:
            try:
                print('S-{}\t{}'.format(sample_id, src_str))
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str))
            except UnicodeEncodeError:
                print('S-{}\t{}'.format(sample_id, src_str.encode('utf-8')))
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str.encode('utf-8')))

        # Process top predictions
        for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
            hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                hypo_tokens=hypo['tokens'].int().cpu(),
                src_str=src_str,
                alignment=hypo['alignment'].int().cpu(),
                align_dict=align_dict,
                tgt_dict=tgt_dict,
                remove_bpe=args.remove_bpe,
            )
            if i == 0:
                decodes[sample_id.tolist()] = hypo_str
                # sids.append(sample_id.tolist())

            if not args.quiet:
                try:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
                except UnicodeEncodeError:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str.encode('utf-8')))
                print('P-{}\t{}'.format(
                    sample_id,
                    ' '.join(map(
                        lambda x: '{:.4f}'.format(x),
                        hypo['positional_scores'].tolist(),
                    ))
                ))
                print('A-{}\t{}'.format(
                    sample_id,
                    ' '.join(map(lambda x: str(utils.item(x)), alignment))
                ))

            # Score only the top hypothesis
            if has_target and i == 0:
                if align_dict is not None or args.remove_bpe is not None:
                    # Convert back to tokens for evaluation with unk replacement and/or without BPE
                    target_tokens = tokenizer.Tokenizer.tokenize(
                        target_str, tgt_dict, add_if_not_exist=True)
                scorer.add(target_tokens, hypo_tokens)

        wps_meter.update(src_tokens.size(0))

        num_sentences += 1
        if args.quiet and num_sentences % 100 == 0:
            print("| {} / {} sentences decoded ({})".format(num_sentences, len(sorted_inputs), len(decodes)))

    used_time = time.perf_counter() - start
    print("| Used time: " + repr(used_time))
    print("| Average time: " + repr(used_time / len(sorted_inputs)))

    if args.decode_to_file:
        print("| [decode] writing decodes into {}".format(decode_filename))
        # print(sids)
        for index in range(len(sorted_inputs)):
            try:
                outfile.write("{}{}".format(decodes[sorted_keys[index]], args.delimiter))
            except UnicodeEncodeError:
                outfile.write("{}{}".format(decodes[sorted_keys[index]].encode('utf-8'), args.delimiter))
        outfile.close()

    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
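
When num_shards is set, each worker decodes its own slice of the input and
writes to a file suffixed with "%.2d" % shard_id. A sketch of the slicing we
assume _get_sorted_inputs performs (round-robin by line number):

def shard_lines(lines, num_shards, shard_id):
    """Round-robin shard: worker shard_id keeps every num_shards-th line."""
    return lines[shard_id::num_shards]

assert shard_lines(['a', 'b', 'c', 'd', 'e', 'f'], 3, 1) == ['b', 'e']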
예제 #28
def score(args, trainer, task, epoch_itr, subset):

    begin = time.time()

    if subset not in task.datasets:
        task.load_dataset(subset)

    # Deep-copy the dictionaries: generating translations mutates the target
    # dictionary, which would corrupt the rest of training.
    src_dict = deepcopy(task.source_dictionary)
    tgt_dict = deepcopy(task.target_dictionary)

    model = trainer.get_model()

    # Initialize data iterator
    itr = data.EpochBatchIterator(
        dataset=task.dataset(subset),
        max_tokens=None,
        max_sentences=max(
            8, min(math.ceil(1024 / args.distributed_world_size), 128)),
        max_positions=model.max_positions(),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    translator = SequenceGenerator(
        [model],
        tgt_dict,
        beam_size=args.beam,
        stop_early=(not args.no_early_stop),
        normalize_scores=(not args.unnormalized),
        len_penalty=args.lenpen,
        unk_penalty=args.unkpen,
        sampling=args.sampling,
        sampling_topk=args.sampling_topk,
        minlen=args.min_len,
    )
    # Generate and compute BLEU
    bleu_dict = dictionary.Dictionary()  # avoid shadowing the built-in dict
    scorer = bleu.Scorer(bleu_dict.pad(), bleu_dict.eos(), bleu_dict.unk())
    num_sentences = 0
    has_target = True
    predictions = []
    with progress_bar.build_progress_bar(args, itr) as progress:
        translations = translator.generate_batched_itr(
            progress,
            maxlen_a=args.max_len_a,
            maxlen_b=args.max_len_b,
            cuda=True,
            timer=gen_timer,
            prefix_size=args.prefix_size,
        )

        wps_meter = TimeMeter()
        for sample_id, src_tokens, target_tokens, hypos in translations:
            # Process input and ground truth
            has_target = target_tokens is not None
            target_tokens = target_tokens.int().cpu() if has_target else None

            src_str = src_dict.string(src_tokens, args.remove_bpe)
            if has_target:
                target_str = tgt_dict.string(target_tokens,
                                             args.remove_bpe,
                                             escape_unk=True)

            # Process top predictions
            for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu()
                    if hypo['alignment'] is not None else None,
                    align_dict=None,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe)

                # Score only the top hypothesis
                if has_target and i == 0:
                    if args.sentencepiece:
                        hypo_str = hypo_str.replace(' ', '').replace('▁', ' ')
                        target_str = target_str.replace(' ',
                                                        '').replace('▁', ' ')
                    sys_tok = tokenizer.Tokenizer.tokenize(
                        (hypo_str.lower() if args.ignore_case else hypo_str),
                        bleu_dict)
                    ref_tok = tokenizer.Tokenizer.tokenize(
                        (target_str.lower()
                         if args.ignore_case else target_str), bleu_dict)
                    scorer.add(ref_tok, sys_tok)
                    if not args.sentencepiece:
                        hypo_str = tokenizer.Tokenizer.detokenize(
                            hypo_str, 'de')
                    predictions.append('{}\t{}'.format(sample_id, hypo_str))

            wps_meter.update(src_tokens.size(0))
            progress.log({'wps': round(wps_meter.avg)})
            num_sentences += 1

    if args.distributed_world_size > 1:
        _all_gather_bleu_scorer(scorer)
        predictions = _all_gather_predictions(predictions)

    with open(os.path.join(args.data, 'sacrebleu_reference.de'),
              'r') as reference:
        refs = [reference.readlines()]
    # reducing indexed predictions as strings is more memory efficient than reducing tuples
    predictions = [tuple(item.split('\t')) for item in predictions]
    predictions = [(int(item[0]), item[1]) for item in predictions]
    predictions.sort(key=lambda tup: tup[0])
    predictions = [
        hypo[1] + ('\n' if hypo[-1] != '\n' else '') for hypo in predictions
    ]
    sacrebleu_score = sacrebleu.corpus_bleu(predictions,
                                            refs,
                                            lowercase=args.ignore_case)
    print(f'| Detokenized {sacrebleu_score}')
    if gen_timer.sum != 0:
        print(
            '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
            .format(num_sentences, gen_timer.n, gen_timer.sum,
                    num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(subset, args.beam,
                                                      scorer.result_string()))

    print('| Eval completed in: {:.2f}s'.format(time.time() - begin))

    return scorer.score(order=4), sacrebleu_score.score
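
Encoding predictions as "id\ttext" strings keeps the distributed gather cheap,
and lets every rank rebuild a globally ordered list afterwards. The
reduce-and-sort step in isolation (the gather itself would go through
torch.distributed; this helper is ours):

def merge_rank_predictions(per_rank):
    """Flatten per-rank 'id\ttext' strings and restore global sample order."""
    pairs = [p.split('\t', 1) for preds in per_rank for p in preds]
    pairs.sort(key=lambda pair: int(pair[0]))
    return [text for _, text in pairs]

assert merge_rank_predictions([['1\tworld'], ['0\thello']]) == ['hello', 'world']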
예제 #29
def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset,
                                       len(task.dataset(args.gen_subset))))

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task)

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)
        if args.fp16:
            model.half()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    ignoredIndices = []
    if args.outindices:
        with open(args.outindices, 'r') as f:
            for line in f:
                ignoredIndices.append(int(line.strip()))
    print("{} indices to be ignored from validation set.".format(
        len(ignoredIndices)))

    # Load dataset (possibly sharded)
    itr = data.EpochBatchIterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=models[0].max_positions(),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        savedir=os.path.join(args.decode_dir, "valid_"),
        ignoredIndices=ignoredIndices,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    if args.score_reference:
        translator = SequenceScorer(models, task.target_dictionary)
    elif args.sepahypo:
        translator = SequenceGeneratorWCSSepahypo(
            models,
            task.target_dictionary,
            beam_size=args.beam,
            stop_early=(not args.no_early_stop),
            normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen,
            unk_penalty=args.unkpen,
            sampling=args.sampling,
            sampling_topk=args.sampling_topk,
            minlen=args.min_len,
            maxlen=None,
            context=args.context,
            ngram=args.ngram,
            naive=args.naive,
            num_topics=args.num_topics,
            flatenc=args.flatenc,
            flatten_source=args.flatten_source,
            cov_penalty=args.covpen,
            keystop=args.keystop,
        )
    elif args.flatdec:
        translator = SequenceGenerator(
            models,
            task.target_dictionary,
            beam_size=args.beam,
            stop_early=(not args.no_early_stop),
            normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen,
            unk_penalty=args.unkpen,
            sampling=args.sampling,
            sampling_topk=args.sampling_topk,
            minlen=args.min_len,
            flatdec=True,
        )
    else:
        translator = SequenceGeneratorWCS(
            models,
            task.target_dictionary,
            beam_size=args.beam,
            stop_early=(not args.no_early_stop),
            normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen,
            unk_penalty=args.unkpen,
            sampling=args.sampling,
            sampling_topk=args.sampling_topk,
            minlen=args.min_len,
            maxlen=None,
            context=args.context,
            ngram=args.ngram,
            num_topics=args.num_topics,
            flatenc=args.flatenc,
            dechatt=args.dechatt,
            flatten_source=args.flatten_source,
        )

    if use_cuda:
        translator.cuda()

    # Generate and compute BLEU score
    scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    outlog = open(args.decode_dir + '/out.log', 'w', encoding='utf8')
    print(
        "* Generating target texts of max length a*x + b, with b = {}".format(
            args.max_len_b))
    with progress_bar.build_progress_bar(args, itr) as t:
        if args.score_reference:
            translations = translator.score_batched_itr(t,
                                                        cuda=use_cuda,
                                                        timer=gen_timer)
        else:
            translations = translator.generate_batched_itr(
                t,
                maxlen_a=args.max_len_a,
                maxlen_b=args.max_len_b,
                cuda=use_cuda,
                timer=gen_timer,
                prefix_size=args.prefix_size,
            )

        wps_meter = TimeMeter()
        for sample_id, src_tokens, target_tokens, hypos in translations:  # for each batch
            # Process input and ground truth
            has_target = target_tokens is not None
            target_tokens = target_tokens.int().cpu() if has_target else None

            # Either retrieve the original sentences or regenerate them from tokens.
            target_str = None
            if align_dict is not None and args.raw_text:
                src_str = task.dataset(
                    args.gen_subset).src.get_original_text(sample_id)
                target_str = task.dataset(
                    args.gen_subset).tgt.get_original_text(sample_id)
            else:
                src_str = src_dict.string(src_tokens, args.remove_bpe)
                if has_target and args.target_raw_text:
                    target_str_tok = tgt_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)
                    target_str = task.dataset(
                        args.gen_subset).get_target_original_text(sample_id)

            # Process top predictions
            if args.flatdec:
                processFlatHypo(sample_id, src_tokens, target_tokens, hypos,
                                src_str, align_dict, tgt_dict, args.remove_bpe,
                                has_target, target_str)
            else:
                for j in range(min(len(hypos), args.nbest)):  # for each beam
                    doc_hypo_tokens = []
                    doc_hypo_str = []
                    doc_target_str = []

                    for i in range(len(hypos[j]['beam'])):  # for each sentence of the beam
                        hypo = hypos[j]['beam'][i]
                        hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                            hypo_tokens=hypo['tokens'].int().cpu(),
                            src_str=src_str,
                            alignment=hypo['alignment'].int().cpu(),
                            align_dict=align_dict,
                            tgt_dict=tgt_dict,
                            remove_bpe=args.remove_bpe,
                        )

                        if not args.quiet:
                            print('H({})-{}\t{}\t{}'.format(
                                j, sample_id, hypo['score'], hypo_str))
                            print('P({})-{}\t{}'.format(
                                j, sample_id, ' '.join(
                                    map(
                                        lambda x: '{:.4f}'.format(x),
                                        hypo['positional_scores'].tolist(),
                                    ))))
                            print('A({})-{}\t{}'.format(
                                j, sample_id, ' '.join(
                                    map(lambda x: str(utils.item(x)),
                                        alignment))))

                        subhypo = False
                        tokens_curhypo = set(hypo_str.split())
                        for hyp in doc_hypo_str:
                            tokens_hyp = set(hyp.split())

                            # if it's contained in a previous sentence hypothesis
                            if hypo_str.strip()[:-1] in hyp:
                                subhypo = True
                                break


                            # treat as duplicate if it overlaps on >= 80% of
                            # its tokens with a previous hypothesis
                            threshold = round(len(tokens_curhypo) * 0.8)
                            if len(tokens_curhypo.intersection(
                                    tokens_hyp)) >= threshold:
                                subhypo = True

                        if not (hypo_str in doc_hypo_str or subhypo):
                            doc_hypo_str.append(hypo_str)
                        else:
                            print("repeated on {} / {}".format(sample_id, i))
                            print(hypo_str)

                        if has_target and i == 0:
                            doc_hypo_tokens.append(hypo_tokens)

                # write files for ROUGE
                with open(
                        os.path.join(args.decode_dir,
                                     "{}.dec".format(sample_id)), 'w') as f:
                    f.write(
                        make_html_safe(" ".join(doc_hypo_str).replace(
                            tgt_dict.eod_word, "").strip()))

                # TODO: call scorer for BLEU

                if target_str:
                    doc_target_str.append(target_str)
                    with open(
                            os.path.join(args.reference_dir,
                                         "{}.ref".format(sample_id)),
                            'w') as f:
                        f.write(make_html_safe(" ".join(doc_target_str)))
                    with open(
                            os.path.join(args.reference_dir + "_fromdict",
                                         "{}.ref".format(sample_id)),
                            'w') as f:
                        f.write(make_html_safe(target_str_tok))
                outlog.write("[{}] ".format(sample_id) +
                             " ".join(doc_hypo_str).replace(
                                 tgt_dict.eod_word, "").strip() + "\n")

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += 1

    outlog.close()

    print(
        '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
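
The duplicate filter above drops a sentence hypothesis when it is a substring
of an earlier one or shares at least 80% of its tokens with one. The same rule
as a stand-alone helper (function name and the empty-prefix guard are ours):

def is_subhypo(hypo_str, kept_strs, overlap=0.8):
    tokens = set(hypo_str.split())
    for prev in kept_strs:
        prefix = hypo_str.strip()[:-1]
        if prefix and prefix in prev:
            return True  # contained in a previous hypothesis
        if len(tokens & set(prev.split())) >= round(len(tokens) * overlap):
            return True  # overlaps on >= 80% of its tokens
    return False

assert is_subhypo('the cat sat down', ['the cat sat down today'])
assert not is_subhypo('a completely new sentence', ['the cat sat down'])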
예제 #30
def decode_from_dataset(models, task, args, use_cuda, output_filename=None):
    # Load dataset splits
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    output_filename = output_filename if output_filename is not None else args.decode_output_file
    if output_filename is not None:
        base_filename = output_filename
    else:
        base_filename = args.gen_subset
        if args.num_shards:
            base_filename += "%.2d" % args.shard_id
    decode_filename = _decode_filename(base_filename, args)
    outfile = open(decode_filename, "w")

    # Load dataset (possibly sharded)
    itr = data.EpochBatchIterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=models[0].max_positions(),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    if args.score_reference:
        translator = SequenceScorer(models, task.target_dictionary)
    else:
        translator = SequenceGenerator(
            models, task.target_dictionary, beam_size=args.beam,
            stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen, unk_penalty=args.unkpen,
            sampling=args.sampling, sampling_topk=args.sampling_topk, minlen=args.min_len,
        )
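    # With --score-reference the models only score the given reference
    # translations (no beam search); otherwise SequenceGenerator runs beam
    # search with the decoding options above.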

    if use_cuda:
        translator.cuda()

    # Generate and compute BLEU score
    scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True

    if args.score_reference:
        translations = translator.score_batched_itr(itr, cuda=use_cuda, timer=gen_timer)
    else:
        translations = translator.generate_batched_itr(
            itr, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
            cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size,
        )

    wps_meter = TimeMeter()
    for sample_id, src_tokens, target_tokens, hypos in translations:
        # Process input and ground truth
        has_target = target_tokens is not None
        target_tokens = target_tokens.int().cpu() if has_target else None

        # Either retrieve the original sentences or regenerate them from tokens.
        if align_dict is not None:
            src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
            target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
        else:
            src_str = src_dict.string(src_tokens, args.remove_bpe)
            if has_target:
                target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)

        if not args.quiet:
            try:
                print('S-{}\t{}'.format(sample_id, src_str))
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str))
            except UnicodeEncodeError:
                print('S-{}\t{}'.format(sample_id, src_str.encode('utf-8')))
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str.encode('utf-8')))

        # Process top predictions
        for i, hypo in enumerate(hypos[:args.nbest]):  # slicing already clamps to len(hypos)
            hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                hypo_tokens=hypo['tokens'].int().cpu(),
                src_str=src_str,
                alignment=hypo['alignment'].int().cpu(),
                align_dict=align_dict,
                tgt_dict=tgt_dict,
                remove_bpe=args.remove_bpe,
            )

            if not args.quiet:
                try:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
                except UnicodeEncodeError:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str.encode('utf-8')))
                print('P-{}\t{}'.format(
                    sample_id,
                    ' '.join(map(
                        lambda x: '{:.4f}'.format(x),
                        hypo['positional_scores'].tolist(),
                    ))
                ))
                print('A-{}\t{}'.format(
                    sample_id,
                    ' '.join(map(lambda x: str(utils.item(x)), alignment))
                ))

            # Score only the top hypothesis
            if has_target and i == 0:
                if align_dict is not None or args.remove_bpe is not None:
                    # Convert back to tokens for evaluation with unk replacement and/or without BPE
                    target_tokens = tokenizer.Tokenizer.tokenize(
                        target_str, tgt_dict, add_if_not_exist=True)
                scorer.add(target_tokens, hypo_tokens)

        wps_meter.update(src_tokens.size(0))

        num_sentences += 1

    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
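
For context, here is a minimal sketch of how decode_from_dataset could be driven end to end. The option parsing and the load_ensemble_for_inference call follow typical fairseq usage of this vintage; the exact entry point is an assumption, not part of the example above.

# Hedged usage sketch (assumed fairseq-style entry point, not from the example above)
import torch
from fairseq import options, tasks, utils

def main():
    # Parse generation options (--gen-subset, --beam, --max-tokens, ... as used above)
    parser = options.get_generation_parser()
    args = options.parse_args_and_arch(parser)
    use_cuda = torch.cuda.is_available() and not args.cpu

    # Set up the task and load the model ensemble listed in --path
    task = tasks.setup_task(args)
    models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task)

    decode_from_dataset(models, task, args, use_cuda)

if __name__ == '__main__':
    main()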