def main():
    """
    Launches translation (inference).

    Inference is executed on a single GPU, implementation supports beam search
    with length normalization and coverage penalty.

    The model is either restored from the checkpoint given by ``args.model``
    or, with ``args.synthetic``, constructed without loading any weights.
    Inference is then run once for every combination of ``args.math``,
    ``args.batch_size`` and ``args.beam_size``, and the per-run statistics are
    collected into latency, throughput and accuracy tables. The final dllogger
    summary and the benchmark check use the statistics of the last combination
    only.

    :returns: the value returned by ``utils.benchmark`` for the last run
    :raises RuntimeError: if neither a model source (``args.model`` /
        ``args.synthetic``) nor an input source (``args.input`` /
        ``args.input_text`` / ``args.synthetic``) is configured
    """
    args = parse_args()
    if args.affinity != 'disabled':
        nproc_per_node = torch.cuda.device_count()
        affinity = gpu_affinity.set_affinity(args.local_rank, nproc_per_node,
                                             args.affinity)
        print(f'{args.local_rank}: thread affinity: {affinity}')
    device = utils.set_device(args.cuda, args.local_rank)
    utils.init_distributed(args.cuda)
    args.rank = utils.get_rank()
    os.makedirs(args.save_dir, exist_ok=True)
    utils.setup_logging()

    dllog_file = os.path.join(args.save_dir, args.dllog_file)
    utils.setup_dllogger(enabled=True, filename=dllog_file)

    if args.profile:
        # pyprof is looked up lazily; a NameError here means the name was
        # never bound in this module, so profiling is skipped with a warning
        try:
            pyprof.init(enable_function_stack=True)
        except NameError:
            warnings.warn('Called pyprof.init() but pyprof is not available')

    if args.env:
        utils.log_env_info()

    logging.info(f'Run arguments: {args}')
    dllogger.log(step='PARAMETER', data=vars(args))

    if not args.cuda and torch.cuda.is_available():
        warnings.warn('cuda is available but not enabled')
    if not args.cudnn:
        torch.backends.cudnn.enabled = False

    # load checkpoint and deserialize to CPU (to save GPU memory)
    if args.model:
        # NOTE: torch.load unpickles the file; only load checkpoints that
        # come from a trusted source.
        checkpoint = torch.load(args.model, map_location={'cuda:0': 'cpu'})

        # build GNMT model
        tokenizer = Tokenizer()
        tokenizer.set_state(checkpoint['tokenizer'])
        model_config = checkpoint['model_config']
        model_config['batch_first'] = args.batch_first
        model_config['vocab_size'] = tokenizer.vocab_size
        model = GNMT(**model_config)
        model.load_state_dict(checkpoint['state_dict'])
    elif args.synthetic:
        model = GNMT(args.synthetic_vocab, batch_first=args.batch_first)
        tokenizer = None
    else:
        raise RuntimeError(
            'Specify model either with --synthetic or with --model flag')

    # construct the dataset
    if args.input:
        data = RawTextDataset(
            raw_datafile=args.input,
            tokenizer=tokenizer,
            sort=args.sort,
        )
    elif args.input_text:
        data = RawTextDataset(
            raw_data=args.input_text,
            tokenizer=tokenizer,
            sort=args.sort,
        )
    elif args.synthetic:
        data = SyntheticDataset(args.synthetic_vocab, args.synthetic_len,
                                args.batch_size[0] * args.synthetic_batches)
    else:
        # without this branch `data` stays unbound and the failure surfaces
        # later as an opaque NameError when the data loader is built
        raise RuntimeError(
            'No input data specified: set one of input, input_text or '
            'synthetic')

    latency_table = tables.LatencyTable(args.percentiles)
    throughput_table = tables.ThroughputTable(args.percentiles)
    accuracy_table = tables.AccuracyTable('BLEU')

    # tensor type the model is cast to for each requested math mode
    dtype = {
        'fp32': torch.FloatTensor,
        'tf32': torch.FloatTensor,
        'fp16': torch.HalfTensor
    }

    for (math_mode, batch_size, beam_size) in product(args.math,
                                                      args.batch_size,
                                                      args.beam_size):
        logging.info(f'math: {math_mode}, batch size: {batch_size}, '
                     f'beam size: {beam_size}')

        model.type(dtype[math_mode])
        model = model.to(device)
        model.eval()

        # build the data loader
        loader = data.get_loader(
            batch_size=batch_size,
            batch_first=args.batch_first,
            pad=True,
            repeat=args.repeat[batch_size],
            num_workers=0,
        )

        # build the translator object
        translator = Translator(
            model=model,
            tokenizer=tokenizer,
            loader=loader,
            beam_size=beam_size,
            max_seq_len=args.max_seq_len,
            len_norm_factor=args.len_norm_factor,
            len_norm_const=args.len_norm_const,
            cov_penalty_factor=args.cov_penalty_factor,
            print_freq=args.print_freq,
        )

        # execute the inference
        with torch.autograd.profiler.emit_nvtx(enabled=args.profile):
            output, stats = translator.run(
                calc_bleu=args.bleu,
                eval_path=args.output,
                summary=True,
                warmup=args.warmup,
                reference_path=args.reference,
            )

        # print translated outputs (only from rank 0, and only when they were
        # not requested to be written to args.output)
        if not args.synthetic and (not args.output and args.rank == 0):
            logging.info('Translated output:')
            for out in output:
                print(out)

        key = (batch_size, beam_size)
        latency_table.add(key, {math_mode: stats['runtimes']})
        throughput_table.add(key, {math_mode: stats['throughputs']})
        accuracy_table.add(key, {math_mode: stats['bleu']})

    if args.tables:
        accuracy_table.write('Inference accuracy', args.math)

        # fp16 results are reported relative to fp32 when both were run,
        # otherwise relative to tf32, otherwise without a baseline
        if 'fp16' in args.math and 'fp32' in args.math:
            relative = 'fp32'
        elif 'fp16' in args.math and 'tf32' in args.math:
            relative = 'tf32'
        else:
            relative = None

        # throughput tables are written first, then latency tables; only the
        # fp16 latency table is written with reverse_speedup=True
        reports = (
            ('Inference throughput', throughput_table, {}),
            ('Inference latency', latency_table, {'reverse_speedup': True}),
        )
        for title, table, fp16_options in reports:
            for baseline in ('fp32', 'tf32'):
                if baseline in args.math:
                    table.write(title, baseline)
            if 'fp16' in args.math:
                table.write(title, 'fp16', relative=relative, **fp16_options)

    # `stats` here refers to the last (math, batch size, beam size) run
    avg_throughput = np.array(stats['throughputs']).mean()
    avg_latency = np.array(stats['runtimes']).mean()
    summary = {
        'eval_throughput': avg_throughput,
        'eval_bleu': stats['bleu'],
        'eval_avg_latency': avg_latency,
    }
    for p in args.percentiles:
        summary[f'eval_{p}%_latency'] = np.percentile(stats['runtimes'], p)

    dllogger.log(step=tuple(), data=summary)

    passed = utils.benchmark(stats['bleu'], args.target_bleu,
                             stats['tokens_per_sec'], args.target_perf)
    return passed
# Example #2
# 0
def main():
    """
    Launches data-parallel multi-gpu training.

    Builds the tokenizer, the train/validation/test datasets and loaders, the
    GNMT model, the criterion, a Translator for test-set evaluation and a
    Seq2SeqTrainer, optionally restoring trainer state from ``args.resume``.
    Each epoch in ``range(args.start_epoch, args.epochs)`` runs one
    optimization pass; when ``args.eval`` is set it additionally computes the
    validation loss, saves a checkpoint from rank 0 and runs the translator to
    obtain test BLEU, stopping early once ``args.target_bleu`` is reached.

    After training, a summary table and a dllogger summary are written and
    the process exits with status 1 if ``utils.benchmark`` reports failure.
    """
    training_start = time.time()
    args = parse_args()
    if args.affinity != 'disabled':
        nproc_per_node = torch.cuda.device_count()
        affinity = gpu_affinity.set_affinity(args.local_rank, nproc_per_node,
                                             args.affinity)
        print(f'{args.local_rank}: thread affinity: {affinity}')
    device = utils.set_device(args.cuda, args.local_rank)
    utils.init_distributed(args.cuda)
    args.rank = utils.get_rank()

    if not args.cudnn:
        torch.backends.cudnn.enabled = False

    # create directory for results
    os.makedirs(args.save_dir, exist_ok=True)

    # setup logging
    log_filename = f'log_rank_{utils.get_rank()}.log'
    utils.setup_logging(args.log_all_ranks,
                        os.path.join(args.save_dir, log_filename))

    dllog_file = os.path.join(args.save_dir, args.dllog_file)
    utils.setup_dllogger(enabled=True, filename=dllog_file)

    if args.env:
        utils.log_env_info()

    logging.info(f'Saving results to: {args.save_dir}')
    logging.info(f'Run arguments: {args}')
    dllogger.log(step='PARAMETER', data=vars(args))

    args.train_iter_size = set_iter_size(args.train_iter_size,
                                         args.train_global_batch_size,
                                         args.train_batch_size)

    worker_seeds, shuffling_seeds = utils.setup_seeds(args.seed, args.epochs,
                                                      device)
    worker_seed = worker_seeds[args.rank]
    logging.info(f'Worker {args.rank} is using worker seed: {worker_seed}')
    torch.manual_seed(worker_seed)

    # build tokenizer
    pad_vocab = utils.pad_vocabulary(args.math)
    tokenizer = Tokenizer(args.vocab, args.bpe_codes, args.lang, pad_vocab)

    # build datasets
    train_data = LazyParallelDataset(
        src_fname=args.train_src,
        tgt_fname=args.train_tgt,
        tokenizer=tokenizer,
        min_len=args.train_min_length,
        max_len=args.train_max_length,
        sort=False,
        max_size=args.train_max_size,
    )

    val_data = ParallelDataset(
        src_fname=args.val_src,
        tgt_fname=args.val_tgt,
        tokenizer=tokenizer,
        min_len=args.val_min_length,
        max_len=args.val_max_length,
        sort=True,
    )

    test_data = TextDataset(
        src_fname=args.test_src,
        tokenizer=tokenizer,
        min_len=args.test_min_length,
        max_len=args.test_max_length,
        sort=True,
    )

    vocab_size = tokenizer.vocab_size

    # build GNMT model
    model_config = {
        'hidden_size': args.hidden_size,
        'vocab_size': vocab_size,
        'num_layers': args.num_layers,
        'dropout': args.dropout,
        'batch_first': False,
        'share_embedding': args.share_embedding,
    }
    model = GNMT(**model_config).to(device)
    logging.info(model)

    batch_first = model.batch_first

    # define loss function (criterion) and optimizer
    criterion = build_criterion(vocab_size, config.PAD,
                                args.smoothing).to(device)

    # args.optimizer_extra is a Python-literal string; its entries extend or
    # override the base optimizer settings
    opt_config = {'optimizer': args.optimizer, 'lr': args.lr}
    opt_config.update(literal_eval(args.optimizer_extra))
    logging.info(f'Training optimizer config: {opt_config}')

    scheduler_config = {
        'warmup_steps': args.warmup_steps,
        'remain_steps': args.remain_steps,
        'decay_interval': args.decay_interval,
        'decay_steps': args.decay_steps,
        'decay_factor': args.decay_factor
    }

    logging.info(f'Training LR schedule config: {scheduler_config}')

    num_parameters = sum(param.nelement() for param in model.parameters())
    logging.info(f'Number of parameters: {num_parameters}')

    batching_opt = {
        'shard_size': args.shard_size,
        'num_buckets': args.num_buckets
    }
    # get data loaders
    train_loader = train_data.get_loader(batch_size=args.train_batch_size,
                                         seeds=shuffling_seeds,
                                         batch_first=batch_first,
                                         shuffle=True,
                                         batching=args.batching,
                                         batching_opt=batching_opt,
                                         num_workers=args.train_loader_workers)

    val_loader = val_data.get_loader(batch_size=args.val_batch_size,
                                     batch_first=batch_first,
                                     shuffle=False,
                                     num_workers=args.val_loader_workers)

    test_loader = test_data.get_loader(batch_size=args.test_batch_size,
                                       batch_first=batch_first,
                                       shuffle=False,
                                       pad=True,
                                       num_workers=args.test_loader_workers)

    translator = Translator(
        model=model,
        tokenizer=tokenizer,
        loader=test_loader,
        beam_size=args.beam_size,
        max_seq_len=args.test_max_length,
        len_norm_factor=args.len_norm_factor,
        len_norm_const=args.len_norm_const,
        cov_penalty_factor=args.cov_penalty_factor,
        print_freq=args.print_freq,
        reference=args.test_tgt,
    )

    # create trainer
    total_train_iters = len(train_loader) // args.train_iter_size * args.epochs
    save_info = {
        'model_config': model_config,
        'config': args,
        'tokenizer': tokenizer.get_state()
    }
    loss_scaling = {
        'init_scale': args.init_scale,
        'upscale_interval': args.upscale_interval
    }
    trainer_options = dict(
        model=model,
        criterion=criterion,
        grad_clip=args.grad_clip,
        iter_size=args.train_iter_size,
        save_dir=args.save_dir,
        save_freq=args.save_freq,
        save_info=save_info,
        opt_config=opt_config,
        scheduler_config=scheduler_config,
        train_iterations=total_train_iters,
        keep_checkpoints=args.keep_checkpoints,
        math=args.math,
        loss_scaling=loss_scaling,
        print_freq=args.print_freq,
        intra_epoch_eval=args.intra_epoch_eval,
        translator=translator,
        prealloc_mode=args.prealloc_mode,
        warmup=args.warmup,
    )

    trainer = trainers.Seq2SeqTrainer(**trainer_options)

    # optionally resume from a checkpoint; a directory is resolved to the
    # 'model_best.pth' file inside it
    if args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            checkpoint_file = os.path.join(checkpoint_file, 'model_best.pth')
        if os.path.isfile(checkpoint_file):
            trainer.load(checkpoint_file)
        else:
            # a missing checkpoint is only logged; training still proceeds
            logging.error(f'No checkpoint found at {args.resume}')

    # training loop
    best_loss = float('inf')
    training_perf = []
    break_training = False
    test_bleu = None
    # stays None when the epoch range below is empty, so the final benchmark
    # call does not hit an unbound name
    train_perf = None
    for epoch in range(args.start_epoch, args.epochs):
        logging.info(f'Starting epoch {epoch}')

        train_loader.sampler.set_epoch(epoch)

        trainer.epoch = epoch
        train_loss, train_perf = trainer.optimize(train_loader)
        training_perf.append(train_perf)

        if args.eval:
            # evaluate on validation set
            logging.info('Running validation on dev set')
            val_loss, val_perf = trainer.evaluate(val_loader)

            # track the best validation loss; only rank 0 saves checkpoints
            if args.rank == 0:
                is_best = val_loss < best_loss
                best_loss = min(val_loss, best_loss)
                trainer.save(save_all=args.save_all, is_best=is_best)

            # synchronize ranks before running the test-set translation
            utils.barrier()
            eval_fname = f'eval_epoch_{epoch}'
            eval_path = os.path.join(args.save_dir, eval_fname)
            _, eval_stats = translator.run(
                calc_bleu=True,
                epoch=epoch,
                eval_path=eval_path,
            )
            test_bleu = eval_stats['bleu']
            if args.target_bleu and test_bleu >= args.target_bleu:
                logging.info('Target accuracy reached')
                break_training = True

        acc_log = []
        acc_log += [f'Summary: Epoch: {epoch}']
        acc_log += [f'Training Loss: {train_loss:.4f}']
        if args.eval:
            acc_log += [f'Validation Loss: {val_loss:.4f}']
            acc_log += [f'Test BLEU: {test_bleu:.2f}']

        perf_log = []
        perf_log += [f'Performance: Epoch: {epoch}']
        perf_log += [f'Training: {train_perf:.0f} Tok/s']
        if args.eval:
            perf_log += [f'Validation: {val_perf:.0f} Tok/s']

        if args.rank == 0:
            logging.info('\t'.join(acc_log))
            logging.info('\t'.join(perf_log))

        logging.info(f'Finished epoch {epoch}')
        if break_training:
            break

    utils.barrier()
    training_stop = time.time()
    training_time = training_stop - training_start
    logging.info(f'Total training time {training_time:.0f} s')

    table = TrainingTable()
    if training_perf:
        avg_training_perf = sum(training_perf) / len(training_perf)
    else:
        # No epoch was executed (args.start_epoch >= args.epochs), so there is
        # nothing to average; report None the same way test_bleu is reported
        # when evaluation is disabled.
        # NOTE(review): assumes TrainingTable/dllogger/utils.benchmark accept
        # None for performance as they do for test_bleu - confirm.
        logging.warning('No training epochs were executed')
        avg_training_perf = None
    table.add(utils.get_world_size(), args.train_batch_size, test_bleu,
              avg_training_perf, training_time)
    if utils.get_rank() == 0:
        table.write('Training Summary', args.math)

    summary = {
        'train_throughput': avg_training_perf,
        'train_elapsed': training_time,
        'test_bleu': test_bleu,
    }
    dllogger.log(step=tuple(), data=summary)

    # the benchmark check uses the throughput of the last executed epoch
    passed = utils.benchmark(test_bleu, args.target_bleu, train_perf,
                             args.target_perf)
    if not passed:
        sys.exit(1)