Example #1
def train_epoch(
        epoch, steps_per_epoch, model, loader_iter, optimizer, args,
        lr_scheduler=None, output_dir='', use_amp=False, scaler=None, model_ema=None):

    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()
    throughput_m = AverageMeter()

    model.train()

    end = time.time()
    last_idx = steps_per_epoch - 1
    num_updates = epoch * steps_per_epoch
    for batch_idx in range(steps_per_epoch):
        input, target = next(loader_iter)
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)

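        # Forward pass under AMP autocast (a no-op when use_amp is False)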
        with torch.cuda.amp.autocast(enabled=use_amp):
            output = model(input, target)
            loss = output['loss']

        if not args.distributed:
            losses_m.update(loss.item(), input.size(0))

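        # Scaled backward pass; unscale and clip gradients before the optimizer step when clip_grad > 0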
        scaler.scale(loss).backward()
        if args.clip_grad > 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_grad)
        scaler.step(optimizer)
        scaler.update()

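        # Reset gradients by setting them to None instead of calling optimizer.zero_grad()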
        for p in model.parameters():
            p.grad = None

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)
        num_updates += 1
        if batch_idx == 10:
            batch_time_m.reset()
            throughput_m.reset()

        batch_time_m.update(time.time() - end)
        throughput_m.update(float(input.size(0) * args.world_size / batch_time_m.val))
        if last_batch or (batch_idx+1) % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), input.size(0))

            if args.rank == 0:
                dllogger_data = {'train_batch_time': batch_time_m.avg,
                                 'train_loss': losses_m.avg,
                                 'throughput': throughput_m.avg,
                                 'lr': lr,
                                 'train_data_time': data_time_m.avg}
                dllogger.log(step=(epoch, steps_per_epoch, batch_idx), data=dllogger_data, verbosity=0)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        end = time.time()
        if args.benchmark:
            if batch_idx >= args.benchmark_steps:
                break
        # end for

    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()

    metrics = {'train_loss': losses_m.avg, 'train_batch_time': batch_time_m.avg, 'train_throughput': throughput_m.avg}
    dllogger.log(step=(epoch,), data=metrics, verbosity=0)

    return metrics
def main():

    parser = argparse.ArgumentParser(
        description='TensorRT Tacotron 2 Inference')
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
    encoder = load_engine(args.encoder, TRT_LOGGER)
    decoder_iter = load_engine(args.decoder, TRT_LOGGER)
    postnet = load_engine(args.postnet, TRT_LOGGER)
    waveglow = load_engine(args.waveglow, TRT_LOGGER)

    if args.waveglow_ckpt != "":
        # setup denoiser using WaveGlow PyTorch checkpoint
        waveglow_ckpt = load_and_setup_model('WaveGlow',
                                             parser,
                                             args.waveglow_ckpt,
                                             True,
                                             forward_is_infer=True)
        denoiser = Denoiser(waveglow_ckpt).cuda()
        # after initialization, we don't need WaveGlow PyTorch checkpoint
        # anymore - deleting
        del waveglow_ckpt
        torch.cuda.empty_cache()

    # initialize CUDA state
    torch.cuda.init()
    # create TRT contexts for each engine
    encoder_context = encoder.create_execution_context()
    decoder_context = decoder_iter.create_execution_context()
    postnet_context = postnet.create_execution_context()
    waveglow_context = waveglow.create_execution_context()

    DLLogger.init(backends=[
        JSONStreamBackend(Verbosity.DEFAULT, args.output + '/' +
                          args.log_file),
        StdOutBackend(Verbosity.VERBOSE)
    ])

    texts = []
    try:
        with open(args.input, 'r') as f:
            texts = f.readlines()
    except IOError:
        print("Could not read file")
        sys.exit(1)

    measurements = {}

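    # Preprocess the input texts, then run the Tacotron 2 and WaveGlow TensorRT engines while measuring end-to-end latency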
    sequences, sequence_lengths = prepare_input_sequence(texts)
    sequences = sequences.to(torch.int32)
    sequence_lengths = sequence_lengths.to(torch.int32)
    with MeasureTime(measurements, "latency"):
        mel, mel_lengths = infer_tacotron2_trt(encoder, decoder_iter, postnet,
                                               encoder_context,
                                               decoder_context,
                                               postnet_context, sequences,
                                               sequence_lengths, measurements)
        audios = infer_waveglow_trt(waveglow, waveglow_context, mel,
                                    measurements)

    with encoder_context, decoder_context, postnet_context, waveglow_context:
        pass

    audios = audios.float()
    if args.waveglow_ckpt != "":
        with MeasureTime(measurements, "denoiser"):
            audios = denoiser(audios,
                              strength=args.denoising_strength).squeeze(1)

    for i, audio in enumerate(audios):
        audio = audio[:mel_lengths[i] * args.stft_hop_length]
        audio = audio / torch.max(torch.abs(audio))
        audio_path = args.output + "audio_" + str(i) + "_trt.wav"
        write(audio_path, args.sampling_rate, audio.cpu().numpy())

    DLLogger.log(step=0,
                 data={
                     "tacotron2_encoder_latency":
                     measurements['tacotron2_encoder_time']
                 })
    DLLogger.log(step=0,
                 data={
                     "tacotron2_decoder_latency":
                     measurements['tacotron2_decoder_time']
                 })
    DLLogger.log(step=0,
                 data={
                     "tacotron2_postnet_latency":
                     measurements['tacotron2_postnet_time']
                 })
    DLLogger.log(step=0,
                 data={"waveglow_latency": measurements['waveglow_time']})
    DLLogger.log(step=0, data={"latency": measurements['latency']})

    if args.waveglow_ckpt != "":
        DLLogger.log(step=0, data={"denoiser": measurements['denoiser']})
    DLLogger.flush()

    prec = "fp16" if "fp16" in args.encoder else "fp32"
    latency = measurements['latency']
    throughput = audios.size(1) / latency
    log_data = "1," + str(sequence_lengths[0].item()) + "," + prec + "," + str(
        latency) + "," + str(throughput) + "," + str(
            mel_lengths[0].item()) + "\n"
    with open("log_bs1_" + prec + ".log", 'a') as f:
        f.write(log_data)
Example #3
def main():
    script_start = time.time()
    hvd_init()
    mpi_comm = MPI.COMM_WORLD
    args = parse_args()

    if hvd.rank() == 0:
        dllogger.init(backends=[dllogger.JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE,
                                                           filename=args.log_path),
                                dllogger.StdOutBackend(verbosity=dllogger.Verbosity.VERBOSE)])
    else:
        dllogger.init(backends=[])

    args.world_size = hvd.size()
    dllogger.log(data=vars(args), step='PARAMETER')

    if args.seed is None:
        if hvd.rank() == 0:
            seed = int(time.time())
        else:
            seed = None

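        # Broadcast the seed chosen on rank 0 so every worker uses the same value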
        seed = mpi_comm.bcast(seed, root=0)
    else:
        seed = args.seed

    tf.random.set_random_seed(seed)
    np.random.seed(seed)
    cp.random.seed(seed)

    if args.amp:
        os.environ["TF_ENABLE_AUTO_MIXED_PRECISION"] = "1"

    if args.checkpoint_dir is not None:
        os.makedirs(args.checkpoint_dir, exist_ok=True)
        final_checkpoint_path = os.path.join(args.checkpoint_dir, 'model.ckpt')
    else:
        final_checkpoint_path = None

    # Load converted data and get statistics
    train_df = pd.read_pickle(args.data+'/train_ratings.pickle')
    test_df = pd.read_pickle(args.data+'/test_ratings.pickle')
    nb_users, nb_items = train_df.max() + 1

    # Extract train and test feature tensors from dataframe
    pos_train_users = train_df.iloc[:, 0].values.astype(np.int32)
    pos_train_items = train_df.iloc[:, 1].values.astype(np.int32)
    pos_test_users = test_df.iloc[:, 0].values.astype(np.int32)
    pos_test_items = test_df.iloc[:, 1].values.astype(np.int32)
    # Negatives indicator for negatives generation
    neg_mat = np.ones((nb_users, nb_items), dtype=bool)
    neg_mat[pos_train_users, pos_train_items] = 0

    # Get the local training/test data
    train_users, train_items, train_labels = get_local_train_data(
        pos_train_users, pos_train_items, args.negative_samples
    )
    test_users, test_items = get_local_test_data(
        pos_test_users, pos_test_items
    )

    # Create and run Data Generator in a separate thread
    data_generator = DataGenerator(
        args.seed,
        hvd.rank(),
        nb_users,
        nb_items,
        neg_mat,
        train_users,
        train_items,
        train_labels,
        args.batch_size // hvd.size(),
        args.negative_samples,
        test_users,
        test_items,
        args.valid_users_per_batch,
        args.valid_negative,
        )

    # Create tensorflow session and saver
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    if args.xla:
        config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
    sess = tf.Session(config=config)

    # Input tensors
    users = tf.placeholder(tf.int32, shape=(None,))
    items = tf.placeholder(tf.int32, shape=(None,))
    labels = tf.placeholder(tf.int32, shape=(None,))
    is_dup = tf.placeholder(tf.float32, shape=(None,))
    dropout = tf.placeholder_with_default(args.dropout, shape=())
    # Model ops and saver
    hit_rate, ndcg, eval_op, train_op = ncf_model_ops(
        users,
        items,
        labels,
        is_dup,
        params={
            'val_batch_size': args.valid_negative+1,
            'top_k': args.topk,
            'learning_rate': args.learning_rate,
            'beta_1': args.beta1,
            'beta_2': args.beta2,
            'epsilon': args.eps,
            'num_users': nb_users,
            'num_items': nb_items,
            'num_factors': args.factors,
            'mf_reg': 0,
            'layer_sizes': args.layers,
            'layer_regs': [0. for i in args.layers],
            'dropout': dropout,
            'sigmoid': True,
            'loss_scale': args.loss_scale
        },
        mode='TRAIN' if args.mode == 'train' else 'EVAL'
    )
    saver = tf.train.Saver()

    # Accuracy metric tensors
    hr_sum = tf.get_default_graph().get_tensor_by_name('neumf/hit_rate/total:0')
    hr_cnt = tf.get_default_graph().get_tensor_by_name('neumf/hit_rate/count:0')
    ndcg_sum = tf.get_default_graph().get_tensor_by_name('neumf/ndcg/total:0')
    ndcg_cnt = tf.get_default_graph().get_tensor_by_name('neumf/ndcg/count:0')

    # Prepare evaluation data
    data_generator.prepare_eval_data()

    if args.load_checkpoint_path:
        saver.restore(sess, args.load_checkpoint_path)
    else:
        # Manual initialize weights
        sess.run(tf.global_variables_initializer())

    # If test mode, run one eval
    if args.mode == 'test':
        sess.run(tf.local_variables_initializer())
        eval_start = time.time()
        for user_batch, item_batch, dup_batch \
            in zip(data_generator.eval_users, data_generator.eval_items, data_generator.dup_mask):
            sess.run(
                eval_op,
                feed_dict={
                    users: user_batch,
                    items: item_batch,
                    is_dup:dup_batch, dropout: 0.0
                }
            )
        eval_duration = time.time() - eval_start

        # Report results
        hit_rate_sum = sess.run(hvd.allreduce(hr_sum, average=False))
        hit_rate_cnt = sess.run(hvd.allreduce(hr_cnt, average=False))
        ndcg_sum = sess.run(hvd.allreduce(ndcg_sum, average=False))
        ndcg_cnt = sess.run(hvd.allreduce(ndcg_cnt, average=False))

        hit_rate = hit_rate_sum / hit_rate_cnt
        ndcg = ndcg_sum / ndcg_cnt

        if hvd.rank() == 0:
            eval_throughput = pos_test_users.shape[0] * (args.valid_negative + 1) / eval_duration
            dllogger.log(step=tuple(), data={'eval_throughput': eval_throughput,
                                             'eval_time': eval_duration,
                                             'hr@10': float(hit_rate),
                                             'ndcg': float(ndcg)})
        return

    # Performance Metrics
    train_times = list()
    eval_times = list()
    # Accuracy Metrics
    first_to_target = None
    time_to_train = 0.0
    time_to_best = 0.0
    best_hr = 0
    best_epoch = 0
    # Buffers for global metrics
    global_hr_sum = np.ones(1)
    global_hr_count = np.ones(1)
    global_ndcg_sum = np.ones(1)
    global_ndcg_count = np.ones(1)
    # Buffers for local metrics
    local_hr_sum = np.ones(1)
    local_hr_count = np.ones(1)
    local_ndcg_sum = np.ones(1)
    local_ndcg_count = np.ones(1)

    # Begin training
    begin_train = time.time()
    for epoch in range(args.epochs):
        # Train for one epoch
        train_start = time.time()
        data_generator.prepare_train_data()
        for user_batch, item_batch, label_batch \
            in zip(data_generator.train_users_batches,
                   data_generator.train_items_batches,
                   data_generator.train_labels_batches):
            sess.run(
                train_op,
                feed_dict={
                    users: user_batch.get(),
                    items: item_batch.get(),
                    labels: label_batch.get()
                }
            )
        train_duration = time.time() - train_start
        # Only log "warm" epochs
        if epoch >= 1:
            train_times.append(train_duration)
        # Evaluate
        if epoch > args.eval_after:
            eval_start = time.time()
            sess.run(tf.local_variables_initializer())
            for user_batch, item_batch, dup_batch \
                in zip(data_generator.eval_users,
                       data_generator.eval_items,
                       data_generator.dup_mask):
                sess.run(
                    eval_op,
                    feed_dict={
                        users: user_batch,
                        items: item_batch,
                        is_dup: dup_batch,
                        dropout: 0.0
                    }
                )
            # Compute local metrics
            local_hr_sum[0] = sess.run(hr_sum)
            local_hr_count[0] = sess.run(hr_cnt)
            local_ndcg_sum[0] = sess.run(ndcg_sum)
            local_ndcg_count[0] = sess.run(ndcg_cnt)
            # Reduce metrics across all workers

            mpi_comm.Reduce(local_hr_count, global_hr_count)
            mpi_comm.Reduce(local_hr_sum, global_hr_sum)
            mpi_comm.Reduce(local_ndcg_count, global_ndcg_count)
            mpi_comm.Reduce(local_ndcg_sum, global_ndcg_sum)

            # Calculate metrics
            hit_rate = global_hr_sum[0] / global_hr_count[0]
            ndcg = global_ndcg_sum[0] / global_ndcg_count[0]

            eval_duration = time.time() - eval_start
            # Only log "warm" epochs
            if epoch >= 1:
                eval_times.append(eval_duration)

            if hvd.rank() == 0:
                dllogger.log(step=(epoch,), data={
                                'train_time': train_duration,
                                'eval_time': eval_duration,
                                'hr@10': hit_rate,
                                'ndcg': ndcg})

                # Update summary metrics
                if hit_rate > args.target and first_to_target is None:
                    first_to_target = epoch
                    time_to_train = time.time() - begin_train
                if hit_rate > best_hr:
                    best_hr = hit_rate
                    best_epoch = epoch
                    time_to_best =  time.time() - begin_train
                    if hit_rate > args.target and final_checkpoint_path:
                        saver.save(sess, final_checkpoint_path)

    # Final Summary
    if hvd.rank() == 0:
        train_times = np.array(train_times)
        train_throughputs = pos_train_users.shape[0]*(args.negative_samples+1) / train_times
        eval_times = np.array(eval_times)
        eval_throughputs = pos_test_users.shape[0]*(args.valid_negative+1) / eval_times

        dllogger.log(step=tuple(), data={
            'average_train_time_per_epoch': np.mean(train_times),
            'average_train_throughput': np.mean(train_throughputs),
            'average_eval_time_per_epoch': np.mean(eval_times),
            'average_eval_throughput': np.mean(eval_throughputs),
            'first_epoch_to_hit': first_to_target,
            'time_to_train': time_to_train,
            'time_to_best': time_to_best,
            'best_hr': best_hr,
            'best_epoch': best_epoch})
        dllogger.flush()

    sess.close()
    return
def main():
    """
    Launches text to speech (inference).
    Inference is executed on a single GPU.
    """
    parser = argparse.ArgumentParser(
        description='PyTorch Tacotron 2 Inference')
    parser = parse_args(parser)
    args, unknown_args = parser.parse_known_args()

    DLLogger.init(backends=[
        JSONStreamBackend(Verbosity.DEFAULT, args.log_file),
        StdOutBackend(Verbosity.VERBOSE)
    ])
    for k, v in vars(args).items():
        DLLogger.log(step="PARAMETER", data={k: v})
    DLLogger.log(step="PARAMETER", data={'model_name': 'Tacotron2_PyT'})

    measurements_all = {
        "pre_processing": [],
        "tacotron2_encoder_time": [],
        "tacotron2_decoder_time": [],
        "tacotron2_postnet_time": [],
        "tacotron2_latency": [],
        "waveglow_latency": [],
        "latency": [],
        "type_conversion": [],
        "data_transfer": [],
        "storage": [],
        "tacotron2_items_per_sec": [],
        "waveglow_items_per_sec": [],
        "num_mels_per_audio": [],
        "throughput": []
    }

    print("args:", args, unknown_args)

    torch.cuda.init()

    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
    encoder = load_engine(args.encoder, TRT_LOGGER)
    decoder_iter = load_engine(args.decoder, TRT_LOGGER)
    postnet = load_engine(args.postnet, TRT_LOGGER)
    waveglow = load_engine(args.waveglow, TRT_LOGGER)

    if args.waveglow_ckpt != "":
        # setup denoiser using WaveGlow PyTorch checkpoint
        waveglow_ckpt = load_and_setup_model('WaveGlow',
                                             parser,
                                             args.waveglow_ckpt,
                                             True,
                                             forward_is_infer=True)
        denoiser = Denoiser(waveglow_ckpt).cuda()
        # after initialization, we don't need WaveGlow PyTorch checkpoint
        # anymore - deleting
        del waveglow_ckpt
        torch.cuda.empty_cache()

    # create TRT contexts for each engine
    encoder_context = encoder.create_execution_context()
    decoder_context = decoder_iter.create_execution_context()
    postnet_context = postnet.create_execution_context()
    waveglow_context = waveglow.create_execution_context()

    texts = [
        "The forms of printed letters should be beautiful, and that their arrangement on the page should be reasonable and a help to the shapeliness of the letters themselves. The forms of printed letters should be beautiful, and that their arrangement on the page should be reasonable and a help to the shapeliness of the letters themselves."
    ]
    texts = [texts[0][:args.input_length]]
    texts = texts * args.batch_size

    warmup_iters = 3

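    # Repeat inference num_iters times; the first warmup_iters iterations are excluded from the logged statistics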
    for iter in range(args.num_iters):

        measurements = {}

        with MeasureTime(measurements, "pre_processing"):
            sequences_padded, input_lengths = prepare_input_sequence(texts)
            sequences_padded = sequences_padded.to(torch.int32)
            input_lengths = input_lengths.to(torch.int32)

        with torch.no_grad():
            with MeasureTime(measurements, "latency"):
                with MeasureTime(measurements, "tacotron2_latency"):
                    mel, mel_lengths = infer_tacotron2_trt(
                        encoder, decoder_iter, postnet, encoder_context,
                        decoder_context, postnet_context, sequences_padded,
                        input_lengths, measurements, args.fp16)

                with MeasureTime(measurements, "waveglow_latency"):
                    audios = infer_waveglow_trt(waveglow, waveglow_context,
                                                mel, measurements, args.fp16)

        num_mels = mel.size(0) * mel.size(2)
        num_samples = audios.size(0) * audios.size(1)

        with MeasureTime(measurements, "type_conversion"):
            audios = audios.float()

        with MeasureTime(measurements, "data_transfer"):
            audios = audios.cpu()

        with MeasureTime(measurements, "storage"):
            audios = audios.numpy()
            for i, audio in enumerate(audios):
                audio_path = "audio_" + str(i) + ".wav"
                write(audio_path, args.sampling_rate,
                      audio[:mel_lengths[i] * args.stft_hop_length])

        measurements['tacotron2_items_per_sec'] = num_mels / measurements[
            'tacotron2_latency']
        measurements['waveglow_items_per_sec'] = num_samples / measurements[
            'waveglow_latency']
        measurements['num_mels_per_audio'] = mel.size(2)
        measurements['throughput'] = num_samples / measurements['latency']

        if iter >= warmup_iters:
            for k, v in measurements.items():
                if k in measurements_all.keys():
                    measurements_all[k].append(v)
                    DLLogger.log(step=(iter - warmup_iters), data={k: v})

    DLLogger.flush()

    print_stats(measurements_all)
def evaluate(eval_iter, model, meters, log_interval, max_size=None, repeat=1):
    total_len, total_loss = 0, 0.
    eval_step = 0

    log_throughput = 0
    log_latency = 0
    log_loss = 0

    torch.cuda.synchronize()
    start_time = time.time()
    with torch.no_grad():
        mems = None
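        # Sweep over the evaluation data 'repeat' times, carrying the recurrent memory state (mems) between batches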
        for _ in range(repeat):
            for idx, (data, target, seq_len, warm) in enumerate(eval_iter):
                if max_size and idx >= max_size:
                    break
                eval_step += 1

                torch.cuda.synchronize()
                start_iter = time.time()
                loss, mems = model(data, target, mems)
                torch.cuda.synchronize()
                elapsed = time.time() - start_iter

                loss = loss.float().mean()
                log_loss += loss.item()
                if warm:
                    # assert all([m.size(0) == model.mem_len for m in mems])
                    total_loss += seq_len * loss.item()
                    total_len += seq_len

                meters['eval_latency'].update(elapsed)
                log_latency += elapsed

                target_tokens = target.numel()
                throughput = target_tokens / elapsed
                throughput = utils.distributed.all_reduce_item(throughput, op='sum')
                meters['eval_throughput'].update(throughput)
                log_throughput += throughput

                if eval_step % log_interval == 0:
                    log_throughput /= log_interval
                    log_latency /= log_interval
                    log_loss /= log_interval
                    log_ppl = math.exp(log_loss)

                    log_str = '| step {:>8d} | batches {:>6d} / {:d} ' \
                        '| ms/batch {:5.2f} | tok/s {:7.0f} | loss {:5.2f} | ppl {:5.2f}'.format(
                            eval_step,
                            idx+1,
                            eval_iter.n_batch,
                            log_latency * 1000,
                            log_throughput,
                            log_loss,
                            log_ppl,
                            )
                    logging.info(log_str)

                    dllogger_data = {
                        'eval_latency': log_latency * 1000,
                        'eval_throughput': log_throughput,
                        'eval_loss': log_loss,
                        'eval_perplexity': log_ppl,
                        }
                    dllogger.log(step=eval_step, data=dllogger_data)

                    log_throughput = 0
                    log_latency = 0
                    log_loss = 0

    utils.distributed.barrier()
    torch.cuda.synchronize()
    total_time = time.time() - start_time
    logging.info('Time : {:.2f}s, {:.2f}ms/segment'.format(
            total_time, 1000 * total_time / (idx+1)))

    avg_loss = total_loss / total_len
    avg_loss = utils.distributed.all_reduce_item(avg_loss, op='mean')
    return avg_loss
Example #6
def main():
    """
    Launches text to speech (inference).
    Inference is executed on a single GPU or CPU.
    """
    parser = argparse.ArgumentParser(
        description='PyTorch Tacotron 2 Inference')
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    DLLogger.init(backends=[
        JSONStreamBackend(Verbosity.DEFAULT, args.output + '/' +
                          args.log_file),
        StdOutBackend(Verbosity.VERBOSE)
    ])
    for k, v in vars(args).items():
        DLLogger.log(step="PARAMETER", data={k: v})
    DLLogger.log(step="PARAMETER", data={'model_name': 'Tacotron2_PyT'})

    tacotron2 = load_and_setup_model('Tacotron2',
                                     parser,
                                     args.tacotron2,
                                     args.fp16,
                                     args.cpu,
                                     forward_is_infer=True)
    waveglow = load_and_setup_model('WaveGlow',
                                    parser,
                                    args.waveglow,
                                    args.fp16,
                                    args.cpu,
                                    forward_is_infer=True)
    denoiser = Denoiser(waveglow)
    if not args.cpu:
        denoiser.cuda()

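    # Compile Tacotron 2 with TorchScript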
    jitted_tacotron2 = torch.jit.script(tacotron2)

    texts = []
    try:
        with open(args.input, 'r') as f:
            texts = f.readlines()
    except IOError:
        print("Could not read file")
        sys.exit(1)

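    # Optional warmup: run a few dummy sequences through both models before the timed inference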
    if args.include_warmup:
        sequence = torch.randint(low=0, high=148, size=(1, 50)).long()
        input_lengths = torch.IntTensor([sequence.size(1)]).long()
        if not args.cpu:
            sequence = sequence.cuda()
            input_lengths = input_lengths.cuda()
        for i in range(3):
            with torch.no_grad():
                mel, mel_lengths, _ = jitted_tacotron2(sequence, input_lengths)
                _ = waveglow(mel)

    measurements = {}

    sequences_padded, input_lengths = prepare_input_sequence(texts, args.cpu)

    with torch.no_grad(), MeasureTime(measurements, "tacotron2_time",
                                      args.cpu):
        mel, mel_lengths, alignments = jitted_tacotron2(
            sequences_padded, input_lengths)

    with torch.no_grad(), MeasureTime(measurements, "waveglow_time", args.cpu):
        audios = waveglow(mel, sigma=args.sigma_infer)
        audios = audios.float()
    with torch.no_grad(), MeasureTime(measurements, "denoiser_time", args.cpu):
        audios = denoiser(audios, strength=args.denoising_strength).squeeze(1)

    print("Stopping after", mel.size(2), "decoder steps")

    tacotron2_infer_perf = mel.size(0) * mel.size(
        2) / measurements['tacotron2_time']
    waveglow_infer_perf = audios.size(0) * audios.size(
        1) / measurements['waveglow_time']

    DLLogger.log(step=0,
                 data={"tacotron2_items_per_sec": tacotron2_infer_perf})
    DLLogger.log(step=0,
                 data={"tacotron2_latency": measurements['tacotron2_time']})
    DLLogger.log(step=0, data={"waveglow_items_per_sec": waveglow_infer_perf})
    DLLogger.log(step=0,
                 data={"waveglow_latency": measurements['waveglow_time']})
    DLLogger.log(step=0,
                 data={"denoiser_latency": measurements['denoiser_time']})
    DLLogger.log(step=0,
                 data={
                     "latency": (measurements['tacotron2_time'] +
                                 measurements['waveglow_time'] +
                                 measurements['denoiser_time'])
                 })

    for i, audio in enumerate(audios):

        plt.imshow(alignments[i].float().data.cpu().numpy().T,
                   aspect="auto",
                   origin="lower")
        figure_path = args.output + "alignment_" + str(
            i) + "_" + args.suffix + ".png"
        plt.savefig(figure_path)

        audio = audio[:mel_lengths[i] * args.stft_hop_length]
        audio = audio / torch.max(torch.abs(audio))
        audio_path = args.output + "audio_" + str(
            i) + "_" + args.suffix + ".wav"
        write(audio_path, args.sampling_rate, audio.cpu().numpy())

    DLLogger.flush()
Example #7
def main():
    parser = argparse.ArgumentParser(
        description="PyTorch Object Detection Inference")
    parser.add_argument(
        "--config-file",
        default=
        "/workspace/object_detection/configs/e2e_mask_rcnn_R_50_FPN_1x.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument("--local_rank",
                        type=int,
                        default=os.getenv('LOCAL_RANK', 0))
    parser.add_argument("--json-summary",
                        help="Out file for DLLogger",
                        default="dllogger_inference.out",
                        type=str)
    parser.add_argument(
        "--skip-eval",
        dest="skip_eval",
        help="Do not eval the predictions",
        action="store_true",
    )
    parser.add_argument(
        "--fp16",
        help="Mixed precision training",
        action="store_true",
    )
    parser.add_argument(
        "--amp",
        help="Mixed precision training",
        action="store_true",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()
    args.fp16 = args.fp16 or args.amp
    num_gpus = int(
        os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    distributed = num_gpus > 1

    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        synchronize()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    save_dir = ""
    logger = setup_logger("maskrcnn_benchmark", save_dir, get_rank())
    if is_main_process():
        dllogger.init(backends=[
            dllogger.JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE,
                                       filename=args.json_summary),
            dllogger.StdOutBackend(verbosity=dllogger.Verbosity.VERBOSE,
                                   step_format=format_step)
        ])
    else:
        dllogger.init(backends=[])

    save_dir = ""
    dllogger.log(step="PARAMETER", data={"config": cfg})
    dllogger.log(step="PARAMETER", data={"gpu_count": num_gpus})
    # dllogger.log(step="PARAMETER", data={"env_info": collect_env_info()})
    model = build_detection_model(cfg)
    model.to(cfg.MODEL.DEVICE)

    # Initialize mixed-precision
    if args.fp16:
        use_mixed_precision = True
    else:
        use_mixed_precision = cfg.DTYPE == "float16"

    output_dir = cfg.OUTPUT_DIR
    checkpointer = DetectronCheckpointer(cfg, model, save_dir=output_dir)
    _ = checkpointer.load(cfg.MODEL.WEIGHT)

    iou_types = ("bbox", )
    if cfg.MODEL.MASK_ON:
        iou_types = iou_types + ("segm", )
    output_folders = [None] * len(cfg.DATASETS.TEST)
    dataset_names = cfg.DATASETS.TEST
    if cfg.OUTPUT_DIR:
        for idx, dataset_name in enumerate(dataset_names):
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference",
                                         dataset_name)
            mkdir(output_folder)
            output_folders[idx] = output_folder
    data_loaders_val = make_data_loader(cfg,
                                        is_train=False,
                                        is_distributed=distributed)

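    # Run inference on each test dataset, using autocast when mixed precision is enabled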
    results = []
    for output_folder, dataset_name, data_loader_val in zip(
            output_folders, dataset_names, data_loaders_val):
        if use_mixed_precision:
            with torch.cuda.amp.autocast():
                result = inference(
                    model,
                    data_loader_val,
                    dataset_name=dataset_name,
                    iou_types=iou_types,
                    box_only=cfg.MODEL.RPN_ONLY,
                    device=cfg.MODEL.DEVICE,
                    expected_results=cfg.TEST.EXPECTED_RESULTS,
                    expected_results_sigma_tol=cfg.TEST.
                    EXPECTED_RESULTS_SIGMA_TOL,
                    output_folder=output_folder,
                    skip_eval=args.skip_eval,
                    dllogger=dllogger,
                )
        else:
            result = inference(
                model,
                data_loader_val,
                dataset_name=dataset_name,
                iou_types=iou_types,
                box_only=cfg.MODEL.RPN_ONLY,
                device=cfg.MODEL.DEVICE,
                expected_results=cfg.TEST.EXPECTED_RESULTS,
                expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
                output_folder=output_folder,
                skip_eval=args.skip_eval,
                dllogger=dllogger,
            )
        synchronize()
        results.append(result)

    if is_main_process() and not args.skip_eval:
        map_results, raw_results = results[0]
        bbox_map = map_results.results["bbox"]['AP']
        segm_map = map_results.results["segm"]['AP']
        dllogger.log(step=tuple(),
                     data={
                         "BBOX_mAP": bbox_map,
                         "MASK_mAP": segm_map
                     })
Example #8
def main():

    args = parse_arguments()

    if args.use_env and 'LOCAL_RANK' in os.environ:
        args.local_rank = int(os.environ['LOCAL_RANK'])

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    device, args = setup_training(args)
    dllogger.log(step="PARAMETER", data={"Config": [str(args)]})

    # Prepare optimizer
    model, optimizer, lr_scheduler, checkpoint, global_step = prepare_model_and_optimizer(
        args, device)

    if is_main_process():
        dllogger.log(step="PARAMETER", data={"SEED": args.seed})

    raw_train_start = time.time()
    if args.do_train:
        if is_main_process():
            dllogger.log(step="PARAMETER", data={"train_start": True})
            dllogger.log(step="PARAMETER",
                         data={"batch_size_per_gpu": args.train_batch_size})
            dllogger.log(step="PARAMETER",
                         data={"learning_rate": args.learning_rate})

        model.train()
        most_recent_ckpts_paths = []
        average_loss = 0.0  # averaged loss every args.log_freq steps
        epoch = 0
        training_steps = 0

        pool = ProcessPoolExecutor(1)

        # Note: We loop infinitely over epochs, termination is handled via iteration count
        while True:
            thread = None
            if not args.resume_from_checkpoint or epoch > 0 or (
                    args.phase2 and global_step < 1) or args.init_checkpoint:
                files = [
                    os.path.join(args.input_dir, f)
                    for f in os.listdir(args.input_dir)
                    if os.path.isfile(os.path.join(args.input_dir, f))
                    and 'training' in f
                ]
                files.sort()
                num_files = len(files)
                random.shuffle(files)
                f_start_id = 0
            else:
                f_start_id = checkpoint['files'][0]
                files = checkpoint['files'][1:]
                args.resume_from_checkpoint = False
                num_files = len(files)

            shared_file_list = {}

            if torch.distributed.is_initialized(
            ) and torch.distributed.get_world_size() > num_files:
                remainder = torch.distributed.get_world_size() % num_files
                data_file = files[
                    (f_start_id * torch.distributed.get_world_size() +
                     torch.distributed.get_rank() + remainder * f_start_id) %
                    num_files]
            else:
                data_file = files[
                    (f_start_id * torch.distributed.get_world_size() +
                     torch.distributed.get_rank()) % num_files]

            previous_file = data_file

            train_data = pretraining_dataset(data_file,
                                             args.max_predictions_per_seq)
            train_sampler = RandomSampler(train_data)
            train_dataloader = DataLoader(train_data,
                                          sampler=train_sampler,
                                          batch_size=args.train_batch_size *
                                          args.n_gpu,
                                          num_workers=4,
                                          pin_memory=True)
            # shared_file_list["0"] = (train_dataloader, data_file)

            overflow_buf = None
            if args.allreduce_post_accumulation:
                overflow_buf = torch.cuda.IntTensor([0])

            if len(files) == 1:
                f_start_id = -1
            for f_id in range(f_start_id + 1, len(files)):

                if torch.distributed.get_world_size() > num_files:
                    data_file = files[
                        (f_id * torch.distributed.get_world_size() +
                         torch.distributed.get_rank() + remainder * f_id) %
                        num_files]
                else:
                    data_file = files[
                        (f_id * torch.distributed.get_world_size() +
                         torch.distributed.get_rank()) % num_files]

                previous_file = data_file

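                # Asynchronously load the next data shard in a background process while training on the current one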
                dataset_future = pool.submit(create_pretraining_dataset,
                                             data_file,
                                             args.max_predictions_per_seq,
                                             shared_file_list, args)

                train_iter = tqdm(train_dataloader, desc="Iteration"
                                  ) if is_main_process() else train_dataloader
                for step, batch in enumerate(train_iter):

                    training_steps += 1
                    batch = [t.to(device) for t in batch]
                    input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch
                    loss = model(
                        input_ids=input_ids,
                        token_type_ids=segment_ids,
                        attention_mask=input_mask,
                        masked_lm_labels=masked_lm_labels,
                        next_sentence_label=next_sentence_labels,
                        checkpoint_activations=args.checkpoint_activations)
                    if args.n_gpu > 1:
                        loss = loss.mean()  # mean() to average on multi-gpu.

                    divisor = args.gradient_accumulation_steps
                    if args.gradient_accumulation_steps > 1:
                        if not args.allreduce_post_accumulation:
                            # this division was merged into predivision
                            loss = loss / args.gradient_accumulation_steps
                            divisor = 1.0
                    if args.fp16:
                        with amp.scale_loss(
                                loss,
                                optimizer,
                                delay_overflow_check=args.
                                allreduce_post_accumulation) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
                    average_loss += loss.item()

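                    # Step the LR scheduler and optimizer once every gradient_accumulation_steps micro-batches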
                    if training_steps % args.gradient_accumulation_steps == 0:
                        lr_scheduler.step()  # learning rate warmup
                        global_step = take_optimizer_step(
                            args, optimizer, model, overflow_buf, global_step)

                    if global_step >= args.max_steps:
                        train_time_raw = time.time() - raw_train_start
                        last_num_steps = int(
                            training_steps /
                            args.gradient_accumulation_steps) % args.log_freq
                        last_num_steps = args.log_freq if last_num_steps == 0 else last_num_steps
                        average_loss = torch.tensor(
                            average_loss, dtype=torch.float32).cuda()
                        average_loss = average_loss / (last_num_steps *
                                                       divisor)
                        if (torch.distributed.is_initialized()):
                            average_loss /= torch.distributed.get_world_size()
                            torch.distributed.all_reduce(average_loss)
                        final_loss = average_loss.item()
                        if is_main_process():
                            dllogger.log(step=(
                                epoch,
                                training_steps /
                                args.gradient_accumulation_steps,
                            ),
                                         data={"final_loss": final_loss})
                    elif training_steps % (
                            args.log_freq *
                            args.gradient_accumulation_steps) == 0:
                        if is_main_process():
                            dllogger.log(
                                step=(
                                    epoch,
                                    global_step,
                                ),
                                data={
                                    "average_loss":
                                    average_loss / (args.log_freq * divisor),
                                    "step_loss":
                                    loss.item() *
                                    args.gradient_accumulation_steps / divisor,
                                    "learning_rate":
                                    optimizer.param_groups[0]['lr']
                                })
                        average_loss = 0

                    if global_step >= args.max_steps or training_steps % (
                            args.num_steps_per_checkpoint *
                            args.gradient_accumulation_steps) == 0:
                        if is_main_process() and not args.skip_checkpoint:
                            # Save a trained model
                            dllogger.log(step="PARAMETER",
                                         data={"checkpoint_step": global_step})
                            model_to_save = model.module if hasattr(
                                model, 'module'
                            ) else model  # Only save the model itself
                            if args.resume_step < 0 or not args.phase2:
                                output_save_file = os.path.join(
                                    args.output_dir,
                                    "ckpt_{}.pt".format(global_step))
                            else:
                                output_save_file = os.path.join(
                                    args.output_dir,
                                    "ckpt_{}.pt".format(global_step +
                                                        args.phase1_end_step))
                            if args.do_train:
                                torch.save(
                                    {
                                        'model':
                                        model_to_save.state_dict(),
                                        'optimizer':
                                        optimizer.state_dict(),
                                        'master params':
                                        list(amp.master_params(optimizer)),
                                        'files': [f_id] + files
                                    }, output_save_file)

                                most_recent_ckpts_paths.append(
                                    output_save_file)
                                if len(most_recent_ckpts_paths) > 3:
                                    ckpt_to_be_removed = most_recent_ckpts_paths.pop(
                                        0)
                                    os.remove(ckpt_to_be_removed)

                        if global_step >= args.max_steps:
                            del train_dataloader
                            # thread.join()
                            return args, final_loss, train_time_raw

                del train_dataloader
                # thread.join()
                # Make sure pool has finished and switch train_dataloader
                # NOTE: Will block until complete
                train_dataloader, data_file = dataset_future.result(
                    timeout=None)

            epoch += 1


if __name__ == "__main__":

    now = time.time()
    args, final_loss, train_time_raw = main()
    gpu_count = args.n_gpu
    args.max_steps += args.phase1_end_step if args.phase2 else 0
    if torch.distributed.is_initialized():
        gpu_count = torch.distributed.get_world_size()
    if is_main_process():
        e2e_time = time.time() - now
        training_perf = args.train_batch_size * args.gradient_accumulation_steps * gpu_count\
                        * (args.max_steps - args.resume_step + skipped_steps) / train_time_raw
        dllogger.log(step=tuple(),
                     data={
                         "e2e_train_time": e2e_time,
                         "training_sequences_per_second": training_perf,
                         "final_loss": final_loss,
                         "raw_train_time": train_time_raw
                     })
    dllogger.flush()
Example #10
def main():
    parser = argparse.ArgumentParser(
        description='FastPitch Data Pre-processing')
    parser = parse_args(parser)
    args, unk_args = parser.parse_known_args()
    if len(unk_args) > 0:
        raise ValueError(f'Invalid options {unk_args}')

    DLLogger.init(backends=[
        JSONStreamBackend(Verbosity.DEFAULT,
                          Path(args.dataset_path, args.log_file)),
        StdOutBackend(Verbosity.VERBOSE)
    ])
    for k, v in vars(args).items():
        DLLogger.log(step="PARAMETER", data={k: v})
    DLLogger.flush()

    if args.extract_mels:
        Path(args.dataset_path, 'mels').mkdir(parents=False, exist_ok=True)

    if args.extract_pitch:
        Path(args.dataset_path, 'pitch').mkdir(parents=False, exist_ok=True)

    if args.save_alignment_priors:
        Path(args.dataset_path, 'alignment_priors').mkdir(parents=False,
                                                          exist_ok=True)

    for filelist in args.wav_text_filelists:

        print(f'Processing {filelist}...')

        dataset = TTSDataset(args.dataset_path,
                             filelist,
                             text_cleaners=['english_cleaners_v2'],
                             n_mel_channels=args.n_mel_channels,
                             p_arpabet=0.0,
                             n_speakers=args.n_speakers,
                             load_mel_from_disk=False,
                             load_pitch_from_disk=False,
                             pitch_mean=None,
                             pitch_std=None,
                             max_wav_value=args.max_wav_value,
                             sampling_rate=args.sampling_rate,
                             filter_length=args.filter_length,
                             hop_length=args.hop_length,
                             win_length=args.win_length,
                             mel_fmin=args.mel_fmin,
                             mel_fmax=args.mel_fmax,
                             betabinomial_online_dir=None,
                             pitch_online_dir=None,
                             pitch_online_method=args.f0_method)

        data_loader = DataLoader(dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 sampler=None,
                                 num_workers=args.n_workers,
                                 collate_fn=TTSCollate(),
                                 pin_memory=False,
                                 drop_last=False)

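        # Iterate over the dataset once and write the requested features (mels, pitch, alignment priors) to disk as .pt files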
        all_filenames = set()
        for i, batch in enumerate(tqdm.tqdm(data_loader)):
            tik = time.time()

            _, input_lens, mels, mel_lens, _, pitch, _, _, attn_prior, fpaths = batch

            # Ensure filenames are unique
            for p in fpaths:
                fname = Path(p).name
                if fname in all_filenames:
                    raise ValueError(f'Filename is not unique: {fname}')
                all_filenames.add(fname)

            if args.extract_mels:
                for j, mel in enumerate(mels):
                    fname = Path(fpaths[j]).with_suffix('.pt').name
                    fpath = Path(args.dataset_path, 'mels', fname)
                    torch.save(mel[:, :mel_lens[j]], fpath)

            if args.extract_pitch:
                for j, p in enumerate(pitch):
                    fname = Path(fpaths[j]).with_suffix('.pt').name
                    fpath = Path(args.dataset_path, 'pitch', fname)
                    torch.save(p[:mel_lens[j]], fpath)

            if args.save_alignment_priors:
                for j, prior in enumerate(attn_prior):
                    fname = Path(fpaths[j]).with_suffix('.pt').name
                    fpath = Path(args.dataset_path, 'alignment_priors', fname)
                    torch.save(prior[:mel_lens[j], :input_lens[j]], fpath)
Example #11
                del train_dataloader
                # thread.join()
                # Make sure pool has finished and switch train_dataloader
                # NOTE: Will block until complete
                train_dataloader, data_file = dataset_future.result(
                    timeout=None)

            epoch += 1


if __name__ == "__main__":

    now = time.time()
    args, final_loss, train_time_raw = main()
    gpu_count = args.n_gpu
    args.max_steps += args.phase1_end_step if args.phase2 else 0
    if torch.distributed.is_initialized():
        gpu_count = torch.distributed.get_world_size()
    if is_main_process():
        e2e_time = time.time() - now
        training_perf = args.train_batch_size * args.gradient_accumulation_steps * gpu_count\
                        * (args.max_steps - args.resume_step + skipped_steps) / train_time_raw
        dllogger.log(step=tuple(),
                     data={
                         "e2e_train_time": e2e_time,
                         "training_sequences_per_second": training_perf,
                         "final_loss": final_loss,
                         "raw_train_time": train_time_raw
                     })
    dllogger.flush()
Example #12
def main():

    parser = get_parser()
    args = parser.parse_args()

    log_fpath = args.log_file or str(Path(args.output_dir, 'nvlog_infer.json'))
    log_fpath = unique_log_fpath(log_fpath)
    dllogger.init(backends=[JSONStreamBackend(Verbosity.DEFAULT, log_fpath),
                            StdOutBackend(Verbosity.VERBOSE,
                                          metric_format=stdout_metric_format)])

    for k, v in vars(args).items():
        dllogger.log(step="PARAMETER", data={k: v})

    for step in ['DNN', 'data+DNN', 'data']:
        for c in [0.99, 0.95, 0.9, 0.5]:
            cs = 'avg' if c == 0.5 else f'{int(100*c)}%'
            dllogger.metadata(f'{step.lower()}_latency_{c}',
                              {'name': f'{step} latency {cs}',
                               'format': ':>7.2f', 'unit': 'ms'})
    dllogger.metadata(
        'eval_wer', {'name': 'WER', 'format': ':>3.2f', 'unit': '%'})

    if args.cpu:
        device = torch.device('cpu')
    else:
        assert torch.cuda.is_available()
        device = torch.device('cuda')
        torch.backends.cudnn.benchmark = args.cudnn_benchmark

    if args.seed is not None:
        torch.manual_seed(args.seed + args.local_rank)
        np.random.seed(args.seed + args.local_rank)
        random.seed(args.seed + args.local_rank)

    # set up distributed training
    multi_gpu = not args.cpu and int(os.environ.get('WORLD_SIZE', 1)) > 1
    if multi_gpu:
        torch.cuda.set_device(args.local_rank)
        distrib.init_process_group(backend='nccl', init_method='env://')
        print_once(f'Inference with {distrib.get_world_size()} GPUs')

    cfg = config.load(args.model_config)
    config.apply_config_overrides(cfg, args)

    symbols = helpers.add_ctc_blank(cfg['labels'])

    use_dali = args.dali_device in ('cpu', 'gpu')
    dataset_kw, features_kw = config.input(cfg, 'val')

    measure_perf = args.steps > 0

    # dataset
    if args.transcribe_wav or args.transcribe_filelist:

        if use_dali:
            print("DALI supported only with input .json files; disabling")
            use_dali = False

        assert not (args.transcribe_wav and args.transcribe_filelist)

        if args.transcribe_wav:
            dataset = SingleAudioDataset(args.transcribe_wav)
        else:
            dataset = FilelistDataset(args.transcribe_filelist)

        data_loader = get_data_loader(dataset,
                                      batch_size=1,
                                      multi_gpu=multi_gpu,
                                      shuffle=False,
                                      num_workers=0,
                                      drop_last=(True if measure_perf else False))

        _, features_kw = config.input(cfg, 'val')
        assert not features_kw['pad_to_max_duration']
        feat_proc = FilterbankFeatures(**features_kw)

    elif use_dali:
        # pad_to_max_duration is not supported by DALI - have simple padders
        if features_kw['pad_to_max_duration']:
            feat_proc = BaseFeatures(
                pad_align=features_kw['pad_align'],
                pad_to_max_duration=True,
                max_duration=features_kw['max_duration'],
                sample_rate=features_kw['sample_rate'],
                window_size=features_kw['window_size'],
                window_stride=features_kw['window_stride'])
            features_kw['pad_to_max_duration'] = False
        else:
            feat_proc = None

        data_loader = DaliDataLoader(
            gpu_id=args.local_rank or 0,
            dataset_path=args.dataset_dir,
            config_data=dataset_kw,
            config_features=features_kw,
            json_names=args.val_manifests,
            batch_size=args.batch_size,
            pipeline_type=("train" if measure_perf else "val"),  # no drop_last
            device_type=args.dali_device,
            symbols=symbols)

    else:
        dataset = AudioDataset(args.dataset_dir,
                               args.val_manifests,
                               symbols,
                               **dataset_kw)

        data_loader = get_data_loader(dataset,
                                      args.batch_size,
                                      multi_gpu=multi_gpu,
                                      shuffle=False,
                                      num_workers=4,
                                      drop_last=False)

        feat_proc = FilterbankFeatures(**features_kw)

    model = Jasper(encoder_kw=config.encoder(cfg),
                   decoder_kw=config.decoder(cfg, n_classes=len(symbols)))

    if args.ckpt is not None:
        print(f'Loading the model from {args.ckpt} ...')
        checkpoint = torch.load(args.ckpt, map_location="cpu")
        key = 'ema_state_dict' if args.ema else 'state_dict'
        state_dict = helpers.convert_v1_state_dict(checkpoint[key])
        model.load_state_dict(state_dict, strict=True)

    model.to(device)
    model.eval()

    if feat_proc is not None:
        feat_proc.to(device)
        feat_proc.eval()

    if args.amp:
        model = model.half()

    if args.torchscript:
        greedy_decoder = GreedyCTCDecoder()

        feat_proc, model, greedy_decoder = torchscript_export(
            data_loader, feat_proc, model, greedy_decoder, args.output_dir,
            use_amp=args.amp, use_conv_masks=True, model_toml=args.model_toml,
            device=device, save=args.torchscript_export)

    if multi_gpu:
        model = DistributedDataParallel(model)

    agg = {'txts': [], 'preds': [], 'logits': []}
    dur = {'data': [], 'dnn': [], 'data+dnn': []}

    looped_loader = chain.from_iterable(repeat(data_loader))
    greedy_decoder = GreedyCTCDecoder()

    sync = lambda: torch.cuda.synchronize() if device.type == 'cuda' else None

    # if no fixed step count was requested, fall back to one pass over the loader
    steps = (args.steps + args.warmup_steps) or len(data_loader)
    with torch.no_grad():

        for it, batch in enumerate(tqdm(looped_loader, initial=1, total=steps)):

            if use_dali:
                feats, feat_lens, txt, txt_lens = batch
                if feat_proc is not None:
                    feats, feat_lens = feat_proc(feats, feat_lens)
            else:
                batch = [t.to(device, non_blocking=True) for t in batch]
                audio, audio_lens, txt, txt_lens = batch
                feats, feat_lens = feat_proc(audio, audio_lens)

            sync()
            t1 = time.perf_counter()

            if args.amp:
                feats = feats.half()

            feats = F.pad(feats, (args.pad_leading, 0))
            feat_lens += args.pad_leading

            if model.encoder.use_conv_masks:
                log_probs, log_prob_lens = model(feats, feat_lens)
            else:
                log_probs = model(feats, feat_lens)

            preds = greedy_decoder(log_probs)

            sync()
            t2 = time.perf_counter()

            # burn-in period; wait for a new loader due to num_workers
            if it >= 1 and (args.steps == 0 or it >= args.warmup_steps):
                dur['data'].append(t1 - t0)
                dur['dnn'].append(t2 - t1)
                dur['data+dnn'].append(t2 - t0)

            if txt is not None:
                agg['txts'] += helpers.gather_transcripts([txt], [txt_lens],
                                                          symbols)
            agg['preds'] += helpers.gather_predictions([preds], symbols)
            agg['logits'].append(log_probs)

            if it + 1 == steps:
                break

            sync()
            t0 = time.perf_counter()

        # communicate the results
        if args.transcribe_wav:
            for idx, p in enumerate(agg['preds']):
                print_once(f'Prediction {idx+1: >3}: {p}')

        elif args.transcribe_filelist:
            pass

        elif not multi_gpu or distrib.get_rank() == 0:
            wer, _ = process_evaluation_epoch(agg)

            dllogger.log(step=(), data={'eval_wer': 100 * wer})

        if args.save_predictions:
            with open(args.save_predictions, 'w') as f:
                f.write('\n'.join(agg['preds']))

        if args.save_logits:
            logits = torch.cat(agg['logits'], dim=0).cpu()
            torch.save(logits, args.save_logits)

    # report timings
    if len(dur['data']) >= 20:
        ratios = [0.9, 0.95, 0.99]
        for stage in dur:
            lat = durs_to_percentiles(dur[stage], ratios)
            for k in [0.99, 0.95, 0.9, 0.5]:
                # use the same keys as the dllogger.metadata entries registered above
                dllogger.log(step=(), data={f'{stage.lower()}_latency_{k}': lat[k]})

    else:
        print_once('Not enough samples to measure latencies.')
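
The timing report above relies on durs_to_percentiles, which is defined elsewhere in the script. A minimal sketch of what such a helper could look like, inferred only from how its return value is indexed above (the exact implementation in the original code may differ):

import numpy as np

def durs_to_percentiles(durations, ratios):
    # Return a dict mapping each ratio (plus 0.5) to a latency in milliseconds;
    # 0.5 is treated as the average here, matching the 'avg' metadata label above.
    durs = np.sort(np.asarray(durations)) * 1000.0  # seconds -> ms
    out = {0.5: float(durs.mean())}
    for r in ratios:  # e.g. [0.9, 0.95, 0.99]
        idx = min(int(round(len(durs) * r)), len(durs) - 1)
        out[r] = float(durs[idx])
    return out
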
Exemple #13
0
def main():
    parser = argparse.ArgumentParser(
        description='PyTorch TTS Data Pre-processing')
    parser = parse_args(parser)
    args, unk_args = parser.parse_known_args()
    if len(unk_args) > 0:
        raise ValueError(f'Invalid options {unk_args}')

    if args.extract_pitch_char:
        assert args.extract_durations, "Durations required for pitch extraction"

    DLLogger.init(backends=[
        JSONStreamBackend(Verbosity.DEFAULT, args.log_file),
        StdOutBackend(Verbosity.VERBOSE)
    ])
    for k, v in vars(args).items():
        DLLogger.log(step="PARAMETER", data={k: v})

    model = load_and_setup_model(
        'Tacotron2',
        parser,
        args.tacotron2_checkpoint,
        amp=False,
        device=torch.device('cuda' if args.cuda else 'cpu'),
        forward_is_infer=False,
        ema=False)

    if args.train_mode:
        model.train()

    # n_mel_channels arg has been consumed by model's arg parser
    args.n_mel_channels = model.n_mel_channels

    for datum in ('mels', 'mels_teacher', 'attentions', 'durations',
                  'pitch_mel', 'pitch_char', 'pitch_trichar'):
        if getattr(args, f'extract_{datum}'):
            Path(args.dataset_path, datum).mkdir(parents=False, exist_ok=True)

    with open(args.wav_text_filelist, 'r') as f:
        filenames = [Path(line.split('|')[0]).stem for line in f]
    # Compatibility with Tacotron2 Data loader
    args.n_speakers = 1
    dataset = FilenamedLoader(filenames,
                              args.dataset_path,
                              args.wav_text_filelist,
                              args,
                              load_mel_from_disk=False)
    # TextMelCollate supports only n_frames_per_step=1
    data_loader = DataLoader(dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             sampler=None,
                             num_workers=0,
                             collate_fn=TextMelCollate(1),
                             pin_memory=False,
                             drop_last=False)
    pitch_vecs = {'mel': {}, 'char': {}, 'trichar': {}}
    for i, batch in enumerate(data_loader):
        tik = time.time()
        fnames = batch[-1]
        x, _, _ = batch_to_gpu(batch[:-1])
        _, text_lens, mels_padded, _, mel_lens = x

        for j, mel in enumerate(mels_padded):
            fpath = Path(args.dataset_path, 'mels', fnames[j] + '.pt')
            torch.save(mel[:, :mel_lens[j]].cpu(), fpath)

        with torch.no_grad():
            out_mels, out_mels_postnet, _, alignments = model.forward(x)

        if args.extract_mels_teacher:
            for j, mel in enumerate(out_mels_postnet):
                fpath = Path(args.dataset_path, 'mels_teacher',
                             fnames[j] + '.pt')
                torch.save(mel[:, :mel_lens[j]].cpu(), fpath)
        if args.extract_attentions:
            for j, ali in enumerate(alignments):
                ali = ali[:mel_lens[j], :text_lens[j]]
                fpath = Path(args.dataset_path, 'attentions',
                             fnames[j] + '.pt')
                torch.save(ali.cpu(), fpath)
        durations = []
        if args.extract_durations:
            for j, ali in enumerate(alignments):
                text_len = text_lens[j]
                ali = ali[:mel_lens[j], :text_len]
                dur = torch.histc(torch.argmax(ali, dim=1),
                                  min=0,
                                  max=text_len - 1,
                                  bins=text_len)
                durations.append(dur)
                fpath = Path(args.dataset_path, 'durations', fnames[j] + '.pt')
                torch.save(dur.cpu().int(), fpath)
        if args.extract_pitch_mel or args.extract_pitch_char or args.extract_pitch_trichar:
            for j, dur in enumerate(durations):
                fpath = Path(args.dataset_path, 'pitch_char',
                             fnames[j] + '.pt')
                wav = Path(args.dataset_path, 'wavs', fnames[j] + '.wav')
                p_mel, p_char, p_trichar = calculate_pitch(
                    str(wav),
                    dur.cpu().numpy())
                pitch_vecs['mel'][fnames[j]] = p_mel
                pitch_vecs['char'][fnames[j]] = p_char
                pitch_vecs['trichar'][fnames[j]] = p_trichar

        nseconds = time.time() - tik
        DLLogger.log(step=f'{i+1}/{len(data_loader)} ({nseconds:.2f}s)',
                     data={})

    if args.extract_pitch_mel:
        normalize_pitch_vectors(pitch_vecs['mel'])
        for fname, pitch in pitch_vecs['mel'].items():
            fpath = Path(args.dataset_path, 'pitch_mel', fname + '.pt')
            torch.save(torch.from_numpy(pitch), fpath)

    if args.extract_pitch_char:
        mean, std = normalize_pitch_vectors(pitch_vecs['char'])
        for fname, pitch in pitch_vecs['char'].items():
            fpath = Path(args.dataset_path, 'pitch_char', fname + '.pt')
            torch.save(torch.from_numpy(pitch), fpath)
        save_stats(args.dataset_path, args.wav_text_filelist, 'pitch_char',
                   mean, std)

    if args.extract_pitch_trichar:
        normalize_pitch_vectors(pitch_vecs['trichar'])
        for fname, pitch in pitch_vecs['trichar'].items():
            fpath = Path(args.dataset_path, 'pitch_trichar', fname + '.pt')
            torch.save(torch.from_numpy(pitch), fpath)

    DLLogger.flush()
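
The duration extraction above turns an attention alignment into per-symbol durations by counting, for each mel frame, which text position receives the strongest attention. A standalone illustration of that histogram trick on a toy alignment (shapes and values are illustrative only; the cast to float is added here because torch.histc expects a floating-point input):

import torch

# Toy alignment: 6 mel frames attending over 3 text positions.
alignment = torch.tensor([[0.9, 0.1, 0.0],
                          [0.8, 0.2, 0.0],
                          [0.2, 0.7, 0.1],
                          [0.1, 0.8, 0.1],
                          [0.0, 0.3, 0.7],
                          [0.0, 0.1, 0.9]])

text_len = alignment.size(1)
# argmax picks the dominant text position for each frame: [0, 0, 1, 1, 2, 2];
# histc then counts frames per position, i.e. the duration of each symbol.
durations = torch.histc(torch.argmax(alignment, dim=1).float(),
                        bins=text_len, min=0, max=text_len - 1)
print(durations)                                  # tensor([2., 2., 2.])
assert int(durations.sum()) == alignment.size(0)  # durations cover every frame
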
Exemple #14
0
def main():
    """
    Launches text to speech (inference).
    Inference is executed on a single GPU or CPU.
    """
    parser = argparse.ArgumentParser(
        description='PyTorch Tacotron 2 Inference')
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()
    use_custom_naming = args.custom_name
    input_path = args.input
    text_cleaners = args.text_cleaners

    check_directory_and_create(args.output, exists_warning=True)

    # import pdb; pdb.set_trace()
    DLLogger.init(backends=[
        JSONStreamBackend(Verbosity.DEFAULT, args.output + '/' +
                          args.log_file),
        StdOutBackend(Verbosity.VERBOSE)
    ])
    for k, v in vars(args).items():
        DLLogger.log(step="PARAMETER", data={k: v})
    DLLogger.log(step="PARAMETER", data={'model_name': 'Tacotron2_PyT'})

    if args.use_extracted_mels:
        print(f"mel found in {args.mel_path}")
        mel = torch.load(args.mel_path)
        mel = mel.unsqueeze(0)
        print(f"The size of the mel we just loaded is {mel.shape}")
        audios = apply_griffin_lim(args, mel)
    else:
        tacotron2 = load_and_setup_model('Tacotron2',
                                         parser,
                                         args.tacotron2,
                                         args.fp16,
                                         args.cpu,
                                         forward_is_infer=True)

        if not args.use_griffin_lim:
            waveglow = \
                load_and_setup_model('WaveGlow', parser, args.waveglow,
                                    args.fp16, args.cpu, forward_is_infer=True)
            denoiser = Denoiser(waveglow)
            if not args.cpu:
                denoiser.cuda()

        jitted_tacotron2 = torch.jit.script(tacotron2)

        texts = []
        try:
            with open(args.input, 'r') as f:
                texts = f.readlines()
        except OSError:
            print(f"Could not read file {args.input}")
            sys.exit(1)

        if args.include_warmup and (not args.use_griffin_lim):
            sequence = torch.randint(low=0, high=148, size=(1, 50)).long()
            input_lengths = torch.IntTensor([sequence.size(1)]).long()
            if not args.cpu:
                sequence = sequence.cuda()
                input_lengths = input_lengths.cuda()
            for i in range(3):
                with torch.no_grad():
                    mel, mel_lengths, _ = jitted_tacotron2(
                        sequence, input_lengths)
                    _ = waveglow(mel)

        measurements = {}

        sequences_padded, input_lengths = \
            prepare_input_sequence(texts, args.cpu, text_cleaners)

        with torch.no_grad(), MeasureTime(measurements, "tacotron2_time",
                                          args.cpu):
            mel, mel_lengths, alignments = jitted_tacotron2(
                sequences_padded, input_lengths)

        if args.use_griffin_lim:
            print(f"The size of the generated mel spec is {mel.shape}")
            audios = apply_griffin_lim(args, mel)
            # import pdb; pdb.set_trace()
            # audios = audios.cpu().numpy()
            #audio = audio.astype('int16')
            # audio_path = os.path.join('samples', "{}_synthesis.wav".format(out_filename))
            # write(audio_path, hparams.sampling_rate, audio)
            # print(audio_path)
        else:
            with torch.no_grad(), MeasureTime(measurements, "waveglow_time",
                                              args.cpu):
                audios = waveglow(mel, sigma=args.sigma_infer)
                audios = audios.float()
            with torch.no_grad(), MeasureTime(measurements, "denoiser_time",
                                              args.cpu):
                audios = denoiser(audios,
                                  strength=args.denoising_strength).squeeze(1)

            print("Stopping after", mel.size(2), "decoder steps")

            tacotron2_infer_perf = mel.size(0) * mel.size(
                2) / measurements['tacotron2_time']
            waveglow_infer_perf = audios.size(0) * audios.size(
                1) / measurements['waveglow_time']

            DLLogger.log(
                step=0, data={"tacotron2_items_per_sec": tacotron2_infer_perf})
            DLLogger.log(
                step=0,
                data={"tacotron2_latency": measurements['tacotron2_time']})
            DLLogger.log(step=0,
                         data={"waveglow_items_per_sec": waveglow_infer_perf})
            DLLogger.log(
                step=0,
                data={"waveglow_latency": measurements['waveglow_time']})
            DLLogger.log(
                step=0,
                data={"denoiser_latency": measurements['denoiser_time']})
            DLLogger.log(step=0,
                         data={
                             "latency": (measurements['tacotron2_time'] +
                                         measurements['waveglow_time'] +
                                         measurements['denoiser_time'])
                         })

    for i, audio in enumerate(audios):
        if use_custom_naming:
            if args.use_extracted_mels:
                custom_name = (args.mel_path.split("/")[-1]).split(".")[0]
            else:
                custom_name = (input_path.split("/")[-1]).split(".")[0]
            custom_path = os.path.join(args.output, custom_name)
            if not args.use_extracted_mels:
                # save alignment
                plt.imshow(alignments[i].float().data.cpu().numpy().T,
                           aspect="auto",
                           origin="lower")
                figure_path = custom_path + "_alignment.png"
                plt.savefig(figure_path)
                meltitle = "_predicted"
            else:
                meltitle = "_extracted"
            # save the predicted or extracted mel
            plot_mel_spectrogram(
                mel,
                title=meltitle,
                dirname=custom_path,
                append_name=True,
                load_mel_path=False,
                # load_mel_path=True
            )
            # save generated audio
            # if not args.use_griffin_lim:
            if not args.use_extracted_mels:
                audio = audio[:mel_lengths[i] * args.stft_hop_length]
            audio = audio / torch.max(torch.abs(audio))
            # custom_name = (input_path.split("/")[-1]).split(".")[0]
            audio_path = custom_path + ".wav"
            write(audio_path, args.sampling_rate, audio.cpu().numpy())
        else:
            plt.imshow(alignments[i].float().data.cpu().numpy().T,
                       aspect="auto",
                       origin="lower")
            figure_path = os.path.join(
                args.output, "alignment_" + str(i) + "_" + args.suffix + ".png")
            plt.savefig(figure_path)
            audio = audio[:mel_lengths[i] * args.stft_hop_length]
            audio = audio / torch.max(torch.abs(audio))
            audio_path = \
                os.path.join(args.output, "audio_"+str(i)+"_"+args.suffix+".wav")
            write(audio_path, args.sampling_rate, audio.cpu().numpy())

    DLLogger.flush()
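
MeasureTime above is used as a context manager that stores an elapsed wall-clock time into the measurements dict under a given key. A minimal sketch of such a helper, assuming that interface (the real class may differ, e.g. in how it handles CUDA synchronization):

import time
import torch

class MeasureTime:
    # Records elapsed seconds into a shared dict, e.g.
    #     with MeasureTime(measurements, "tacotron2_time", cpu_run):
    #         ...
    def __init__(self, measurements, key, cpu_run=False):
        self.measurements = measurements
        self.key = key
        self.cpu_run = cpu_run

    def __enter__(self):
        if not self.cpu_run:
            torch.cuda.synchronize()
        self.t0 = time.perf_counter()

    def __exit__(self, exc_type, exc_value, exc_traceback):
        if not self.cpu_run:
            torch.cuda.synchronize()
        self.measurements[self.key] = time.perf_counter() - self.t0
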
def main():
    args = parse_args()
    utils.gpu_affinity.set_affinity(args.local_rank)

    # Initialize device and distributed backend
    torch.cuda.set_device(args.local_rank)
    l2_promote()
    device = torch.device('cuda' if args.cuda else 'cpu')
    utils.distributed.init_distributed(args.cuda)

    args.work_dir = utils.exp_utils.build_work_dir_name(
        args.work_dir,
        args.dataset,
        args.append_dataset,
        args.append_time,
    )

    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir,
                           scripts_to_save=['train.py', 'mem_transformer.py'],
                           debug=args.debug)

    # Setup logging
    if args.log_all_ranks:
        log_file = f'train_log_rank_{utils.distributed.get_rank()}.log'
    else:
        log_file = args.txtlog_file
    dllog_file = args.dllog_file
    log_file = os.path.join(args.work_dir, log_file)
    dllog_file = os.path.join(args.work_dir, dllog_file)

    if args.debug:
        log_file = os.devnull
        dllog_file = os.devnull

    utils.exp_utils.setup_logging(
        log_all_ranks=args.log_all_ranks,
        filename=log_file,
    )
    utils.exp_utils.setup_dllogger(enabled=True, filename=dllog_file)

    if args.local_batch_size is not None:
        world_size = utils.distributed.get_world_size()
        args.batch_size = world_size * args.local_batch_size
        logging.info(f'--local_batch_size was set, adjusting global batch size'
                     f' to {args.batch_size} (local_batch_size * world_size)')

    logging.info(args)
    dllogger.log(step='PARAMETER', data=vars(args))

    logging.info(f'world size: {utils.distributed.get_world_size()}')

    if not args.no_env:
        log_env_info()

    register_ignoring_timeout_handler()

    # Set the random seed manually for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    ###########################################################################
    # Load data
    ###########################################################################
    corpus = get_lm_corpus(args.data, args.dataset, args.vocab)
    ntokens = len(corpus.vocab)
    vocab = corpus.vocab
    args.n_token = ntokens

    if args.mem_len == 0:
        eval_mem_len = 0
    else:
        eval_mem_len = args.mem_len + args.tgt_len - args.eval_tgt_len

    tr_iter = corpus.get_iterator('train',
                                  args.batch_size,
                                  args.tgt_len,
                                  device=device,
                                  ext_len=args.ext_len)
    va_iter = corpus.get_iterator('valid',
                                  args.eval_batch_size,
                                  args.eval_tgt_len,
                                  device=device,
                                  mem_len=eval_mem_len,
                                  ext_len=args.ext_len)
    te_iter = corpus.get_iterator('test',
                                  args.eval_batch_size,
                                  args.eval_tgt_len,
                                  device=device,
                                  mem_len=eval_mem_len,
                                  ext_len=args.ext_len)

    # adaptive softmax / embedding
    cutoffs, tie_projs = [], [False]
    if args.adaptive:
        assert args.dataset in ['wt103', 'lm1b']
        if args.dataset == 'wt103':
            cutoffs = [19997, 39997, 199997]
            tie_projs += [True] * len(cutoffs)
        elif args.dataset == 'lm1b':
            cutoffs = [59997, 99997, 639997]
            tie_projs += [False] * len(cutoffs)

    ###########################################################################
    # Build the model
    ###########################################################################
    model_config = {
        'n_token': ntokens,
        'n_layer': args.n_layer,
        'n_head': args.n_head,
        'd_model': args.d_model,
        'd_head': args.d_head,
        'd_inner': args.d_inner,
        'dropout': args.dropout,
        'dropatt': args.dropatt,
        'dtype': None,
        'tie_weight': args.tied,
        'd_embed': args.d_embed,
        'div_val': args.div_val,
        'tie_projs': tie_projs,
        'pre_lnorm': args.pre_lnorm,
        'tgt_len': args.tgt_len,
        'ext_len': args.ext_len,
        'mem_len': args.mem_len,
        'cutoffs': cutoffs,
        'same_length': args.same_length,
        'attn_type': args.attn_type,
        'clamp_len': args.clamp_len,
        'sample_softmax': args.sample_softmax,
    }

    model = MemTransformerLM(**model_config)

    model.apply(functools.partial(weights_init, args=args))
    # ensure embedding init is not overridden by out_layer in case of weight sharing
    model.word_emb.apply(functools.partial(weights_init, args=args))

    args.n_all_param = sum([p.nelement() for p in model.parameters()])
    args.n_nonemb_param = sum(
        [p.nelement() for p in model.layers.parameters()])

    # optimizer
    if args.optim.lower() == 'sgd':
        if args.sample_softmax > 0:
            dense_params, sparse_params = [], []
            for param in model.parameters():
                if param.size() == model.word_emb.weight.size():
                    sparse_params.append(param)
                else:
                    dense_params.append(param)
            optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2)
            optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom)
        else:
            optimizer = optim.SGD(model.parameters(),
                                  lr=args.lr,
                                  momentum=args.mom)
            optimizer_sparse = None
    elif args.optim.lower() == 'adam':
        if args.sample_softmax > 0:
            dense_params, sparse_params = [], []
            for param in model.parameters():
                if param.size() == model.word_emb.weight.size():
                    sparse_params.append(param)
                else:
                    dense_params.append(param)
            optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr)
            optimizer = optim.Adam(dense_params,
                                   lr=args.lr,
                                   weight_decay=args.weight_decay)
        else:
            optimizer = optim.Adam(model.parameters(),
                                   lr=args.lr,
                                   weight_decay=args.weight_decay)
            optimizer_sparse = None
    elif args.optim.lower() == 'adagrad':
        optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
        optimizer_sparse = None
    elif args.optim.lower() == 'lamb':
        optimizer = lamb.Lamb(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay)
        optimizer_sparse = None
    elif args.optim.lower() == 'jitlamb':
        optimizer = lamb.JITLamb(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
        optimizer_sparse = None

    model = model.to(device)

    if args.fp16:
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level=args.amp_mode,
        )

    if args.multi_gpu == 'ddp' and torch.distributed.is_initialized():
        para_model = DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
            find_unused_parameters=True,
        )
    elif args.multi_gpu == 'dp':
        if args.gpu0_bsz >= 0:
            para_model = BalancedDataParallel(args.gpu0_bsz //
                                              args.batch_chunk,
                                              model,
                                              dim=1).to(device)
        else:
            para_model = nn.DataParallel(model, dim=1).to(device)
    else:
        para_model = model

    # scheduler
    if args.scheduler == 'cosine':
        if args.max_step_scheduler:
            max_step = args.max_step_scheduler
        else:
            max_step = args.max_step

        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                         max_step -
                                                         args.warmup_step,
                                                         eta_min=args.eta_min)
        if args.sample_softmax > 0 and optimizer_sparse is not None:
            scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(
                optimizer_sparse,
                max_step - args.warmup_step,
                eta_min=args.eta_min)
        else:
            scheduler_sparse = None
    elif args.scheduler == 'inv_sqrt':
        # originally used for Transformer (in Attention is all you need)
        def lr_lambda(step):
            # return a multiplier instead of a learning rate
            if step == 0 and args.warmup_step == 0:
                return 1.
            else:
                return 1. / (step ** 0.5) if step > args.warmup_step \
                    else step / (args.warmup_step ** 1.5)

        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
        if args.sample_softmax > 0 and optimizer_sparse is not None:
            scheduler_sparse = optim.lr_scheduler.LambdaLR(optimizer_sparse,
                                                           lr_lambda=lr_lambda)
        else:
            scheduler_sparse = None
    elif args.scheduler == 'dev_perf':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=args.decay_rate,
            patience=args.patience,
            min_lr=args.lr_min,
        )
        if args.sample_softmax > 0 and optimizer_sparse is not None:
            scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer_sparse,
                factor=args.decay_rate,
                patience=args.patience,
                min_lr=args.lr_min,
            )
        else:
            scheduler_sparse = None
    elif args.scheduler == 'constant':
        pass

    logging.info('=' * 100)
    for k, v in args.__dict__.items():
        logging.info('    - {} : {}'.format(k, v))
    logging.info('=' * 100)
    logging.info('#params = {}'.format(args.n_all_param))
    logging.info('#non emb params = {}'.format(args.n_nonemb_param))

    train_step = 0
    start_epoch = 1
    last_batch = 0
    last_iter = 0
    best_val_loss = None

    if args.restart:
        try:
            checkpoint = load_checkpoint(args.restart)
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            scheduler.load_state_dict(checkpoint['scheduler_state'])
            if args.fp16:
                amp.load_state_dict(checkpoint['amp_state'])
            train_step = checkpoint['train_step']
            start_epoch = checkpoint['epoch']
            last_batch = checkpoint['batch']
            last_iter = checkpoint['last_iter']
            best_val_loss = checkpoint['best_val_loss']

            if train_step >= args.max_step:
                logging.info(
                    f'Loaded checkpoint after {train_step} steps, but '
                    f'this run was scheduled for a total of '
                    f'{args.max_step} steps, exiting')
                sys.exit(1)

            model.apply(functools.partial(update_dropout, args=args))
            model.apply(functools.partial(update_dropatt, args=args))
        except FileNotFoundError:
            logging.info(f'Could not load checkpoint from {args.restart}, '
                         f'starting training from random init')

    meters = {}
    warmup = args.mem_len // args.tgt_len + 2
    meters['train_throughput'] = AverageMeter(warmup=warmup)
    ###########################################################################
    # Train
    ###########################################################################
    # Loop over epochs.
    # At any point you can hit Ctrl + C to break out of training early.
    start_time = time.time()
    with TimeoutHandler() as timeout_handler:
        try:
            for epoch in itertools.count(start=start_epoch):
                if args.roll:
                    tr_iter.roll(seed=args.seed + epoch)
                train_step, best_val_loss = train(
                    tr_iter, va_iter, model, para_model, model_config,
                    optimizer, optimizer_sparse, scheduler, scheduler_sparse,
                    vocab, epoch, last_batch, last_iter, train_step,
                    best_val_loss, meters, timeout_handler, args)

                last_batch = 0
                last_iter = 0

                if train_step == args.max_step:
                    logging.info('-' * 100)
                    logging.info('End of training')
                    break
        except KeyboardInterrupt:
            logging.info('-' * 100)
            logging.info('Exiting from training early')
    elapsed = time.time() - start_time

    ###########################################################################
    # Test
    ###########################################################################
    summary = {}
    test_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
    if not args.debug and not args.no_eval and os.path.exists(test_path):
        # Load the best saved model.
        checkpoint = load_checkpoint(test_path)
        model.load_state_dict(checkpoint['model_state'])

        # Run on test data.
        test_start_time = time.time()
        test_loss = evaluate(te_iter, model, args)
        test_loss = utils.distributed.all_reduce_item(test_loss, 'mean')
        test_elapsed = time.time() - test_start_time

        logging.info('=' * 100)
        if args.dataset in ['enwik8', 'text8']:
            logging.info(
                '| End of training | test time: {:5.2f}s | test loss {:5.2f} | test bpc {:9.5f}'
                .format(test_elapsed, test_loss, test_loss / math.log(2)))
        else:
            logging.info(
                '| End of training | test time: {:5.2f}s | test loss {:5.2f} | test ppl {:9.3f}'
                .format(test_elapsed, test_loss, math.exp(test_loss)))
        logging.info('=' * 100)

        summary.update({
            'test_elapsed': test_elapsed,
            'test_loss': test_loss,
        })

        if args.dataset in ['enwik8', 'text8']:
            summary['test_bits_per_character'] = test_loss / math.log(2)
        else:
            summary['test_perplexity'] = math.exp(test_loss)

    logging.info(f'Training time: {(elapsed / 60):.2f} minutes')
    logging.info(
        f'Training throughput: {meters["train_throughput"].avg:.2f} tok/s')

    if best_val_loss:
        val_perplexity = math.exp(best_val_loss)
    else:
        val_perplexity = None

    summary.update({
        'train_throughput': meters['train_throughput'].avg,
        'train_elapsed': elapsed / 60,
        'valid_loss': best_val_loss,
        'valid_perplexity': val_perplexity,
    })
    dllogger.log(step=tuple(), data=summary)

    passed = benchmark(target_perplexity=args.target_perplexity,
                       test_perplexity=val_perplexity,
                       target_throughput=args.target_throughput,
                       test_throughput=meters['train_throughput'].avg)
    if not passed:
        sys.exit(1)
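
The 'inv_sqrt' scheduler above returns a multiplier for LambdaLR rather than an absolute learning rate: it ramps up linearly until warmup_step and then decays as 1/sqrt(step). A standalone check of that shape (warmup_step=4 is an arbitrary illustrative value):

def inv_sqrt_multiplier(step, warmup_step):
    # Mirrors the lr_lambda above: linear warm-up, then 1/sqrt(step) decay.
    if step == 0 and warmup_step == 0:
        return 1.0
    if step > warmup_step:
        return 1.0 / (step ** 0.5)
    return step / (warmup_step ** 1.5)

# The multiplier ramps up, peaks at the end of warm-up, then decays towards zero.
print([round(inv_sqrt_multiplier(s, warmup_step=4), 3) for s in range(9)])
# [0.0, 0.125, 0.25, 0.375, 0.5, 0.447, 0.408, 0.378, 0.354]
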
def main():
    global timeout_sent

    args = parse_arguments()

    random.seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    torch.manual_seed(args.seed + args.local_rank)
    torch.cuda.manual_seed(args.seed + args.local_rank)

    device, args = setup_training(args)
    dllogger.log(step="PARAMETER", data={"Config": [str(args)]})

    # Prepare optimizer
    model, optimizer, grad_scaler, lr_scheduler, checkpoint, global_step, criterion, epoch = prepare_model_and_optimizer(
        args,
        device,
        sequence_output_is_dense=not args.no_dense_sequence_output)
    # Prepare the data loader.
    if is_main_process():
        tic = time.perf_counter()
    train_dataloader = lddl.torch.get_bert_pretrain_data_loader(
        args.input_dir,
        local_rank=max(args.local_rank, 0),
        vocab_file=args.vocab_file,
        data_loader_kwargs={
            'batch_size': args.train_batch_size * args.n_gpu,
            'num_workers': args.num_workers,
            'pin_memory': True,
        },
        base_seed=args.seed,
        log_dir=None if args.output_dir is None else os.path.join(
            args.output_dir, 'lddl_log'),
        log_level=logging.WARNING,
        start_epoch=epoch,
    )
    if is_main_process():
        print('get_bert_pretrain_data_loader took {} s!'.format(
            time.perf_counter() - tic))

    if is_main_process():
        dllogger.log(step="PARAMETER", data={"SEED": args.seed})
        dllogger.log(step="PARAMETER", data={"train_start": True})
        dllogger.log(step="PARAMETER",
                     data={"batch_size_per_gpu": args.train_batch_size})
        dllogger.log(step="PARAMETER",
                     data={"learning_rate": args.learning_rate})

    model.train()
    most_recent_ckpts_paths = []

    stats = SyncFreeStats()
    # Host Only Stats
    stats.add_stat('model_step')
    # Device/Host Sync-ed Stats
    stats.add_stat('optimizer_step',
                   dtype=torch.int32,
                   device_func=(lambda: optimizer.param_groups[0]['step']))
    stats.add_stat('average_loss',
                   dtype=torch.float32,
                   device_tensor=torch.zeros(1,
                                             dtype=torch.float32,
                                             device=device))
    stats.add_stat('learning_rate',
                   dtype=torch.float32,
                   device_func=(lambda: optimizer.param_groups[0]['lr']))
    if grad_scaler.is_enabled():
        # This stat only indicates a skipped step occurred.  It does not accumulate the number of skipped steps
        stats.add_stat(
            'skip_optimizer_step',
            dtype=torch.float32,
            device_func=(
                lambda: grad_scaler._found_inf_per_device(optimizer)[device]))
        stats.add_stat(
            'skipped_optimizer_steps',
            dtype=torch.float32,
            device_tensor=torch.zeros(1, dtype=torch.float32, device=device),
            device_func=(lambda x: x.add_(
                grad_scaler._found_inf_per_device(optimizer)[device])))
    else:
        stats.add_stat('skip_optimizer_step', dtype=torch.float32)
        stats.add_stat('skipped_optimizer_steps', dtype=torch.float32)

    static_gpu_batch = None
    full_cudagraph = None
    grad_accum_cudagraph = None
    if args.cuda_graphs:
        static_gpu_batch = {
            'input_ids':
            torch.ones(args.train_batch_size,
                       args.max_seq_length,
                       dtype=torch.int64,
                       device=device),
            'token_type_ids':
            torch.ones(args.train_batch_size,
                       args.max_seq_length,
                       dtype=torch.int64,
                       device=device),
            'attention_mask':
            torch.ones(args.train_batch_size,
                       args.max_seq_length,
                       dtype=torch.int64,
                       device=device),
            'labels':
            torch.ones(args.train_batch_size,
                       args.max_seq_length,
                       dtype=torch.int64,
                       device=device),
            'next_sentence_labels':
            torch.ones(args.train_batch_size, dtype=torch.int64,
                       device=device),
        }

        # Warmup Steps - includes jitting fusions
        side_stream = torch.cuda.Stream()
        side_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side_stream):
            for _ in range(11):
                take_training_step(args, grad_scaler, model, criterion,
                                   static_gpu_batch, stats)
                take_optimizer_step(args, lr_scheduler, optimizer, grad_scaler,
                                    device, stats)
        torch.cuda.current_stream().wait_stream(side_stream)

        # Capture Graph
        full_cudagraph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(full_cudagraph):
            take_training_step(args, grad_scaler, model, criterion,
                               static_gpu_batch, stats)
            take_optimizer_step(args, lr_scheduler, optimizer, grad_scaler,
                                device, stats)

        # Warmup Steps - includes jitting fusions
        side_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side_stream):
            for _ in range(3):
                with model.no_sync():
                    take_training_step(args, grad_scaler, model, criterion,
                                       static_gpu_batch, stats)
        torch.cuda.current_stream().wait_stream(side_stream)

        # Capture Graph
        grad_accum_cudagraph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(grad_accum_cudagraph):
            with model.no_sync():
                take_training_step(args, grad_scaler, model, criterion,
                                   static_gpu_batch, stats)

    train_iter = tqdm(
        train_dataloader,
        desc="Iteration",
        disable=args.disable_progress_bar,
        total=len(train_dataloader),
    ) if is_main_process() else train_dataloader

    raw_train_start = None

    # avoid nvfuser compilation times in measuring perf with phase2 binning
    # ideally skip > 3 * num_bins fwd+bwd iterations to start measuring perf
    skip_fwd_bwd_for_perf = 4
    if args.phase2:  #we use 8 bins with phase2
        skip_fwd_bwd_for_perf = 50

    while True:
        for step, batch in enumerate(train_iter):
            # The first training step is 1 and not 0 when gradient accumulating
            # in order to avoid an optimizer step on the very first step
            stats.host_stat('model_step').add_(1)
            grad_accumulation_step = (stats.host_stat_value('model_step') %
                                      args.gradient_accumulation_steps) != 0

            if raw_train_start is None and step == skip_fwd_bwd_for_perf:
                raw_train_start = time.time()

            # Execute Model Step
            if args.cuda_graphs:
                for k in batch.keys():
                    static_gpu_batch[k].copy_(batch[k], non_blocking=True)
                if grad_accumulation_step:
                    grad_accum_cudagraph.replay()
                else:
                    full_cudagraph.replay()
            else:
                batch = {
                    k: v.to(device, non_blocking=True)
                    for k, v in batch.items()
                }

                if args.allreduce_post_accumulation and grad_accumulation_step:
                    with model.no_sync():
                        take_training_step(args, grad_scaler, model, criterion,
                                           batch, stats)
                else:
                    take_training_step(args, grad_scaler, model, criterion,
                                       batch, stats)

                if not grad_accumulation_step:
                    take_optimizer_step(args, lr_scheduler, optimizer,
                                        grad_scaler, device, stats)

            # Log Optimizer Step
            if (not grad_accumulation_step) or timeout_sent:
                static_optimizer_step = stats.host_stat_value(
                    'model_step') // args.gradient_accumulation_steps
                dynamic_optimizer_step = static_optimizer_step - int(
                    stats.host_stat_value('skipped_optimizer_steps'))
                no_log_steps = static_optimizer_step % args.log_freq

                # Log Final Step (MAYBE)
                # Since the stats are asynchronously pushed from the GPU to CPU, they are not always reliable
                # Therefore, a synchronization is required to guarantee you see the intended value.
                # Without a synchronization, it is possible for some GPUs to go through the exit conditional
                # and others to not because they accidentally see a different value for `skipped_optimizer_steps`.
                # In order to remove most device syncs, synchronizations only begin in the last few steps
                # where the skipped step count matters.
                if static_optimizer_step >= args.steps_this_run or timeout_sent:
                    torch.cuda.synchronize()
                    dynamic_optimizer_step = static_optimizer_step - int(
                        stats.host_stat_value('skipped_optimizer_steps'))
                    if dynamic_optimizer_step >= args.steps_this_run or timeout_sent:
                        train_time_raw = time.time() - raw_train_start

                        last_num_steps = args.log_freq if no_log_steps == 0 else no_log_steps
                        stats.device_stat('average_loss').div_(
                            last_num_steps * args.gradient_accumulation_steps)
                        if (torch.distributed.is_initialized()):
                            stats.device_stat('average_loss').div_(
                                get_world_size())
                            torch.distributed.all_reduce(
                                stats.device_stat('average_loss'))

                        # We block on this copy to ensure the final value
                        stats.host_stat('average_loss').copy_(
                            stats.device_stat('average_loss'))
                        if is_main_process():
                            dllogger.log(
                                step=(
                                    epoch,
                                    dynamic_optimizer_step,
                                ),
                                data={
                                    "final_loss":
                                    stats.host_stat_value('average_loss')
                                })

                        checkpoint_step(args, epoch, dynamic_optimizer_step,
                                        model, optimizer, grad_scaler,
                                        most_recent_ckpts_paths)

                        return args, train_time_raw, stats, skip_fwd_bwd_for_perf

                if no_log_steps == 0:
                    if is_main_process():
                        dllogger.log(
                            step=(
                                epoch,
                                dynamic_optimizer_step,
                            ),
                            data={
                                "average_loss":
                                stats.host_stat_value('average_loss') /
                                (args.log_freq *
                                 args.gradient_accumulation_steps),
                                "learning_rate":
                                stats.host_stat_value('learning_rate'),
                                "skipped_steps":
                                int(
                                    stats.host_stat_value(
                                        'skipped_optimizer_steps'))
                            })
                        if stats.host_stat_value('skip_optimizer_step') > 0.:
                            dllogger.log(
                                step="PARAMETER",
                                data={
                                    "loss_scale":
                                    grad_scaler._get_scale_async().item()
                                })

                    stats.device_stat('average_loss').zero_()

                    if not args.skip_checkpoint and (
                            dynamic_optimizer_step %
                            args.num_steps_per_checkpoint == 0):
                        checkpoint_step(args, epoch, dynamic_optimizer_step,
                                        model, optimizer, grad_scaler,
                                        most_recent_ckpts_paths)

        epoch += 1
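
The CUDA-graph path above warms up on a side stream, captures one full forward/backward/optimizer step into static input buffers, and then replays the captured graph after copying each new batch into those buffers. A self-contained sketch of the same capture/replay pattern on a toy model, following the standard torch.cuda.CUDAGraph recipe (the model, sizes, and optimizer here are placeholders, not the BERT setup above):

import torch

assert torch.cuda.is_available()
device = torch.device('cuda')

model = torch.nn.Linear(16, 1).to(device)
loss_fn = torch.nn.MSELoss()
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# Static buffers that the captured graph will read its inputs from.
static_x = torch.zeros(8, 16, device=device)
static_y = torch.zeros(8, 1, device=device)

# Warm up on a side stream before capture.
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    for _ in range(3):
        opt.zero_grad(set_to_none=True)
        loss_fn(model(static_x), static_y).backward()
        opt.step()
torch.cuda.current_stream().wait_stream(side_stream)

# Capture one full training step into a graph.
graph = torch.cuda.CUDAGraph()
opt.zero_grad(set_to_none=True)
with torch.cuda.graph(graph):
    loss = loss_fn(model(static_x), static_y)
    loss.backward()
    opt.step()

# Training loop: copy each real batch into the static buffers, then replay.
for _ in range(5):
    static_x.copy_(torch.randn(8, 16, device=device))
    static_y.copy_(torch.randn(8, 1, device=device))
    graph.replay()
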
Exemple #17
0
def train(train_loop_func, logger, args):
    # Check that GPUs are actually available
    use_cuda = not args.no_cuda

    # Setup multi-GPU if necessary
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.N_gpu = torch.distributed.get_world_size()
    else:
        args.N_gpu = 1

    if args.seed is None:
        args.seed = np.random.randint(1e4)

    if args.distributed:
        args.seed = (args.seed + torch.distributed.get_rank()) % 2**32
    print("Using seed = {}".format(args.seed))
    torch.manual_seed(args.seed)
    np.random.seed(seed=args.seed)


    # Setup data, defaults
    dboxes = dboxes300_coco()
    encoder = Encoder(dboxes)
    cocoGt = get_coco_ground_truth(args)

    train_loader = get_train_loader(args, args.seed - 2**31, len_dataset=5002)

    val_dataset = get_val_dataset(args)
    val_dataloader = get_val_dataloader(val_dataset, args)

    ssd300 = SSD300(backbone=ResNet(args.backbone, args.backbone_path), label_num=2)
    args.learning_rate = args.learning_rate * args.N_gpu * (args.batch_size / 32)
    start_epoch = 0
    iteration = 0
    loss_func = Loss(dboxes)

    if use_cuda:
        ssd300.cuda()
        loss_func.cuda()

    optimizer = torch.optim.SGD(tencent_trick(ssd300), lr=args.learning_rate,
                                    momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = MultiStepLR(optimizer=optimizer, milestones=args.multistep, gamma=0.1)
    if args.amp:
        ssd300, optimizer = amp.initialize(ssd300, optimizer, opt_level='O2')

    if args.distributed:
        ssd300 = DDP(ssd300)

    if args.checkpoint is not None:
        if os.path.isfile(args.checkpoint):
            load_checkpoint(ssd300.module if args.distributed else ssd300, args.checkpoint)
            checkpoint = torch.load(args.checkpoint,
                                    map_location=lambda storage, loc: storage.cuda(torch.cuda.current_device()))
            start_epoch = checkpoint['epoch']
            iteration = checkpoint['iteration']
            scheduler.load_state_dict(checkpoint['scheduler'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            print('Provided checkpoint is not path to a file')
            return

    inv_map = {v: k for k, v in val_dataset.label_map.items()}

    total_time = 0

    if args.mode == 'evaluation':
        acc = evaluate(ssd300, val_dataloader, cocoGt, encoder, inv_map, args)
        if args.local_rank == 0:
            print('Model precision {} mAP'.format(acc))

        return
    mean, std = generate_mean_std(args)

    for epoch in range(start_epoch, args.epochs):
        start_epoch_time = time.time()
        scheduler.step()
        iteration = train_loop_func(ssd300, loss_func, epoch, optimizer, train_loader, val_dataloader, encoder, iteration,
                                    logger, args, mean, std)
        end_epoch_time = time.time() - start_epoch_time
        total_time += end_epoch_time

        if args.local_rank == 0:
            logger.update_epoch_time(epoch, end_epoch_time)

        if epoch in args.evaluation:
            acc = evaluate(ssd300, val_dataloader, cocoGt, encoder, inv_map, args)

            if args.local_rank == 0:
                logger.update_epoch(epoch, acc)

        if args.save and args.local_rank == 0:
            print("saving model...")
            obj = {'epoch': epoch + 1,
                   'iteration': iteration,
                   'optimizer': optimizer.state_dict(),
                   'scheduler': scheduler.state_dict(),
                   'label_map': val_dataset.label_info}
            if args.distributed:
                obj['model'] = ssd300.module.state_dict()
            else:
                obj['model'] = ssd300.state_dict()
            save_path = os.path.join(args.save, f'epoch_{epoch}.pt')
            torch.save(obj, save_path)
            logger.log('model path', save_path)
        train_loader.reset()
    DLLogger.log(step=tuple(), data={'total time': total_time})
    logger.log_summary()
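
The script above scales the base learning rate linearly with the effective global batch size (args.N_gpu * args.batch_size, relative to a reference batch of 32). A small helper expressing that rule; the function name and the example numbers are illustrative only:

def scale_learning_rate(base_lr, n_gpu, per_gpu_batch_size, reference_batch_size=32):
    # Linear scaling rule: keep the lr / global-batch-size ratio constant.
    global_batch_size = n_gpu * per_gpu_batch_size
    return base_lr * global_batch_size / reference_batch_size

# Example: 8 GPUs with a per-GPU batch of 32 gives 8x the reference learning rate (~0.0208).
print(scale_learning_rate(2.6e-3, n_gpu=8, per_gpu_batch_size=32))
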
Exemple #18
0
def train(model, loss_fn, optimizer, data_loader_train, data_loader_test, scaled_lr):
    """Train and evaluate the model

    Args:
        model (dlrm):
        loss_fn (torch.nn.Module): Loss function
        optimizer (torch.nn.optim):
        data_loader_train (torch.utils.data.DataLoader):
        data_loader_test (torch.utils.data.DataLoader):
    """
    model.train()
    prefetching_enabled = is_data_prefetching_enabled()
    base_device = FLAGS.base_device
    print_freq = FLAGS.print_freq
    steps_per_epoch = len(data_loader_train)

    checkpoint_writer = make_serial_checkpoint_writer(
        embedding_indices=range(len(get_categorical_feature_sizes(FLAGS))),
        config=FLAGS.flag_values_dict()
    )

    test_freq = FLAGS.test_freq if FLAGS.test_freq is not None else steps_per_epoch - 1

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
    metric_logger.add_meter('step_time', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    if prefetching_enabled:
        data_stream = torch.cuda.Stream()

    timer = utils.StepTimer()

    best_auc = 0
    best_epoch = 0
    start_time = time()

    timer.click()

    for epoch in range(FLAGS.epochs):
        input_pipeline = iter(data_loader_train)

        if prefetching_enabled:
            input_pipeline = prefetcher(input_pipeline, data_stream)

        for step, batch in enumerate(input_pipeline):
            global_step = steps_per_epoch * epoch + step
            numerical_features, categorical_features, click = batch

            utils.lr_step(optimizer, num_warmup_iter=FLAGS.warmup_steps, current_step=global_step + 1,
                          base_lr=scaled_lr, warmup_factor=FLAGS.warmup_factor,
                          decay_steps=FLAGS.decay_steps, decay_start_step=FLAGS.decay_start_step)

            if FLAGS.max_steps and global_step > FLAGS.max_steps:
                print(F"Reached max global steps of {FLAGS.max_steps}. Stopping.")
                break

            if prefetching_enabled:
                torch.cuda.synchronize()

            output = model(numerical_features, categorical_features).squeeze().float()

            loss = loss_fn(output, click.squeeze())

            # Setting grad to None is faster than zero_grad()
            for param_group in optimizer.param_groups:
                for param in param_group['params']:
                    param.grad = None

            if FLAGS.amp:
                loss *= FLAGS.loss_scale
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            optimizer.step()

            if step % print_freq == 0 and step > 0:
                loss_value = loss.item()

                timer.click()

                if global_step < FLAGS.benchmark_warmup_steps:
                    metric_logger.update(
                        loss=loss_value, lr=optimizer.param_groups[0]["lr"])
                else:
                    unscale_factor = FLAGS.loss_scale if FLAGS.amp else 1
                    metric_logger.update(
                        loss=loss_value / unscale_factor,
                        step_time=timer.measured / FLAGS.print_freq,
                        lr=optimizer.param_groups[0]["lr"] * unscale_factor
                    )

                if global_step < FLAGS.benchmark_warmup_steps:
                    print(F'Warming up, step [{global_step}/{FLAGS.benchmark_warmup_steps}]')
                    continue

                eta_str = datetime.timedelta(seconds=int(metric_logger.step_time.global_avg * (steps_per_epoch - step)))
                metric_logger.print(
                    header=F"Epoch:[{epoch}/{FLAGS.epochs}] [{step}/{steps_per_epoch}]  eta: {eta_str}")

            if (global_step % test_freq == 0 and global_step > 0 and
                    global_step / steps_per_epoch >= FLAGS.test_after):
                loss, auc, test_step_time = evaluate(model, loss_fn, data_loader_test)
                print(F"Epoch {epoch} step {step}. Test loss {loss:.5f}, auc {auc:.6f}")

                if auc > best_auc:
                    best_auc = auc
                    best_epoch = epoch + ((step + 1) / steps_per_epoch)
                    maybe_save_checkpoint(checkpoint_writer, model, FLAGS.save_checkpoint_path)

                if FLAGS.auc_threshold and auc >= FLAGS.auc_threshold:
                    stop_time = time()
                    run_time_s = int(stop_time - start_time)
                    print(F"Hit target accuracy AUC {FLAGS.auc_threshold} at epoch "
                          F"{global_step/steps_per_epoch:.2f} in {run_time_s}s. "
                          F"Average speed {global_step * FLAGS.batch_size / run_time_s:.1f} records/s.")
                    return

    stop_time = time()
    run_time_s = int(stop_time - start_time)

    print(F"Finished training in {run_time_s}s. "
          F"Average speed {global_step * FLAGS.batch_size / run_time_s:.1f} records/s.")

    avg_throughput = FLAGS.batch_size / metric_logger.step_time.avg

    results = {'best_auc' : best_auc,
               'best_epoch' : best_epoch,
               'average_train_throughput' : avg_throughput}

    if 'test_step_time' in locals():
        avg_test_throughput = FLAGS.test_batch_size / test_step_time
        results['average_test_throughput'] = avg_test_throughput

    dllogger.log(data=results, step=tuple())
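
utils.lr_step above adjusts the optimizer's learning rate in place according to a warm-up/decay schedule. A hypothetical sketch of such a function, assuming linear warm-up followed by quadratic polynomial decay; the real implementation may differ, and warmup_factor is ignored in this simplified version:

import torch

def lr_step(optimizer, num_warmup_iter, current_step, base_lr,
            warmup_factor, decay_steps, decay_start_step):
    # Assumed schedule: linear warm-up, flat plateau, then quadratic
    # polynomial decay over decay_steps after decay_start_step.
    if current_step < num_warmup_iter:
        lr = base_lr * current_step / num_warmup_iter
    elif decay_steps and current_step > decay_start_step:
        decayed = min(current_step - decay_start_step, decay_steps)
        lr = max(base_lr * ((decay_steps - decayed) / decay_steps) ** 2, 1e-8)
    else:
        lr = base_lr
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

# Toy usage: warm up for 4 steps, hold, then decay from step 10 over 8 steps.
opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=1.0)
for step in range(1, 20):
    lr_step(opt, num_warmup_iter=4, current_step=step, base_lr=1.0,
            warmup_factor=0, decay_steps=8, decay_start_step=10)
    print(step, round(opt.param_groups[0]['lr'], 3))
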
Exemple #19
0
            result = inference(
                model,
                data_loader_val,
                dataset_name=dataset_name,
                iou_types=iou_types,
                box_only=cfg.MODEL.RPN_ONLY,
                device=cfg.MODEL.DEVICE,
                expected_results=cfg.TEST.EXPECTED_RESULTS,
                expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
                output_folder=output_folder,
                skip_eval=args.skip_eval,
                dllogger=dllogger,
            )
        synchronize()
        results.append(result)

    if is_main_process() and not args.skip_eval:
        map_results, raw_results = results[0]
        bbox_map = map_results.results["bbox"]['AP']
        segm_map = map_results.results["segm"]['AP']
        dllogger.log(step=tuple(),
                     data={
                         "BBOX_mAP": bbox_map,
                         "MASK_mAP": segm_map
                     })


if __name__ == "__main__":
    main()
    dllogger.log(step=tuple(), data={})
def main(argv):
    validate_flags()
    torch.manual_seed(FLAGS.seed)

    utils.init_logging(log_path=FLAGS.log_path)
    dllogger.log(data=FLAGS.flag_values_dict(), step='PARAMETER')

    data_loader_train, data_loader_test = get_dataloaders(
        train_batch_size=FLAGS.batch_size,
        test_batch_size=FLAGS.test_batch_size)

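    # With static loss scaling (the loss is multiplied by loss_scale before backward),
    # dividing the learning rate by the same factor leaves the SGD update unchanged.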
    scaled_lr = FLAGS.lr / FLAGS.loss_scale if FLAGS.fp16 else FLAGS.lr

    model = create_model()

    optimizer = torch.optim.SGD(model.parameters(), lr=scaled_lr)

    if FLAGS.fp16 and FLAGS.mode == 'train':
        (model.top_mlp, model.bottom_mlp), optimizer = amp.initialize(
            [model.top_mlp, model.bottom_mlp],
            optimizer,
            opt_level="O2",
            loss_scale=1)
    elif FLAGS.fp16:
        model = model.half()

    loss_fn = torch.nn.BCEWithLogitsLoss(reduction="mean")
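    # Trace the loss function once with dummy CUDA inputs so it executes as a TorchScript graph.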
    loss_fn = torch.jit.trace(loss_fn.forward, (torch.rand(
        FLAGS.batch_size, 1).cuda(), torch.rand(FLAGS.batch_size, 1).cuda()))

    if FLAGS.mode == 'test':
        loss, auc, test_step_time = evaluate(model, loss_fn, data_loader_test)

        avg_test_throughput = FLAGS.batch_size / test_step_time
        results = {
            'auc': auc,
            'avg_inference_latency': test_step_time,
            'average_test_throughput': avg_test_throughput
        }
        dllogger.log(data=results, step=tuple())

        print(F"Finished testing. Test Loss {loss:.4f}, auc {auc:.4f}")
        return

    if FLAGS.mode == 'inference_benchmark':
        results = {}

        if FLAGS.fp16:
            # can use pure FP16 for inference
            model = model.half()

        for batch_size in FLAGS.inference_benchmark_batch_sizes:
            batch_size = int(batch_size)
            _, benchmark_data_loader = get_dataloaders(
                train_batch_size=batch_size, test_batch_size=batch_size)

            latencies = inference_benchmark(
                model=model,
                data_loader=benchmark_data_loader,
                num_batches=FLAGS.inference_benchmark_steps)

            print("All inference latencies: {}".format(latencies))

            mean_latency = np.mean(latencies)
            mean_inference_throughput = batch_size / mean_latency
            subresult = {
                F'mean_inference_latency_batch_{batch_size}':
                mean_latency,
                F'mean_inference_throughput_batch_{batch_size}':
                mean_inference_throughput
            }
            results.update(subresult)
        dllogger.log(data=results, step=tuple())

        print(F"Finished inference benchmark.")
        return

    if FLAGS.mode == 'train':
        train(model, loss_fn, optimizer, data_loader_train, data_loader_test,
              scaled_lr)
Example #21
0
    def log_parameter(self, data, verbosity=0):
        dllogger.log(step="PARAMETER", data=data, verbosity=verbosity)
def train(model, loss_fn, optimizer, data_loader_train, data_loader_test,
          scaled_lr):
    """Train and evaluate the model

    Args:
        model (dlrm):
        loss_fn (torch.nn.Module): Loss function
        optimizer (torch.nn.optim):
        data_loader_train (torch.utils.data.DataLoader):
        data_loader_test (torch.utils.data.DataLoader):
    """
    model.train()
    base_device = FLAGS.base_device
    print_freq = FLAGS.print_freq
    steps_per_epoch = len(data_loader_train)

    test_freq = FLAGS.test_freq if FLAGS.test_freq is not None else steps_per_epoch

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'loss', utils.SmoothedValue(window_size=print_freq, fmt='{avg:.4f}'))
    metric_logger.add_meter(
        'step_time',
        utils.SmoothedValue(window_size=print_freq, fmt='{avg:.6f}'))
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    timer = utils.StepTimer()

    best_auc = 0
    best_epoch = 0
    start_time = time()
    for epoch in range(FLAGS.epochs):

        batch_iter = iter(data_loader_train)
        for step in range(len(data_loader_train)):
            timer.click()

            global_step = steps_per_epoch * epoch + step

            numerical_features, categorical_features, click = next(batch_iter)

            categorical_features = categorical_features.to(base_device).to(
                torch.long)
            numerical_features = numerical_features.to(base_device)
            click = click.to(base_device).to(torch.float32)

            utils.lr_step(optimizer,
                          num_warmup_iter=FLAGS.warmup_steps,
                          current_step=global_step + 1,
                          base_lr=scaled_lr,
                          warmup_factor=FLAGS.warmup_factor,
                          decay_steps=FLAGS.decay_steps,
                          decay_start_step=FLAGS.decay_start_step)

            if FLAGS.max_steps and global_step > FLAGS.max_steps:
                print(
                    F"Reached max global steps of {FLAGS.max_steps}. Stopping."
                )
                break

            output = model(numerical_features,
                           categorical_features).squeeze().float()

            loss = loss_fn(output, click.squeeze())

            optimizer.zero_grad()
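            # Static loss scaling: amp is initialized with loss_scale=1 in main(), so the
            # loss is multiplied manually here and the learning rate was pre-divided
            # (scaled_lr) to compensate.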
            if FLAGS.fp16:
                loss *= FLAGS.loss_scale
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()

            loss_value = loss.item()

            if timer.measured is None:
                # first iteration, no step time etc. to print
                continue

            if global_step < FLAGS.benchmark_warmup_steps:
                metric_logger.update(loss=loss_value,
                                     lr=optimizer.param_groups[0]["lr"])
            else:
                unscale_factor = FLAGS.loss_scale if FLAGS.fp16 else 1
                metric_logger.update(loss=loss_value / unscale_factor,
                                     step_time=timer.measured,
                                     lr=optimizer.param_groups[0]["lr"] *
                                     unscale_factor)

            if step % print_freq == 0 and step > 0:
                if global_step < FLAGS.benchmark_warmup_steps:
                    print(
                        F'Warming up, step [{global_step}/{FLAGS.benchmark_warmup_steps}]'
                    )
                    continue

                eta_str = datetime.timedelta(
                    seconds=int(metric_logger.step_time.global_avg *
                                (steps_per_epoch - step)))
                metric_logger.print(
                    header=
                    F"Epoch:[{epoch}/{FLAGS.epochs}] [{step}/{steps_per_epoch}]  eta: {eta_str}"
                )

            if ((global_step + 1) % test_freq == 0 and global_step > 0
                    and global_step / steps_per_epoch >= FLAGS.test_after):
                loss, auc, test_step_time = evaluate(model, loss_fn,
                                                     data_loader_test)
                print(
                    F"Epoch {epoch} step {step}. Test loss {loss:.5f}, auc {auc:.6f}"
                )

                if auc > best_auc:
                    best_auc = auc
                    best_epoch = epoch + ((step + 1) / steps_per_epoch)
                    maybe_save_checkpoint(model, FLAGS.save_checkpoint_path)

                if FLAGS.auc_threshold and auc >= FLAGS.auc_threshold:
                    stop_time = time()
                    run_time_s = int(stop_time - start_time)
                    print(
                        F"Hit target accuracy AUC {FLAGS.auc_threshold} at epoch "
                        F"{global_step/steps_per_epoch:.2f} in {run_time_s}s. "
                        F"Average speed {global_step * FLAGS.batch_size / run_time_s:.1f} records/s."
                    )
                    return

    avg_throughput = FLAGS.batch_size / metric_logger.step_time.avg

    results = {
        'best_auc': best_auc,
        'best_epoch': best_epoch,
        'average_train_throughput': avg_throughput
    }

    if 'test_step_time' in locals():
        avg_test_throughput = FLAGS.test_batch_size / test_step_time
        results['average_test_throughput'] = avg_test_throughput

    dllogger.log(data=results, step=tuple())
Example #23
0
def main():

    parser = argparse.ArgumentParser(description='PyTorch Tacotron 2 Training')
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    if 'LOCAL_RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        local_rank = int(os.environ['LOCAL_RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
    else:
        local_rank = args.rank
        world_size = args.world_size

    distributed_run = world_size > 1

    if local_rank == 0:
        DLLogger.init(backends=[JSONStreamBackend(Verbosity.DEFAULT,
                                                  args.output \
                                                  + '/' + args.log_file),
                                StdOutBackend(Verbosity.VERBOSE)])
    else:
        DLLogger.init(backends=[])

    for k, v in vars(args).items():
        DLLogger.log(step="PARAMETER", data={k: v})
    DLLogger.log(step="PARAMETER", data={'model_name': 'Tacotron2_PyT'})

    model_name = args.model_name
    parser = models.parse_model_args(model_name, parser)
    args, _ = parser.parse_known_args()

    torch.backends.cudnn.enabled = args.cudnn_enabled
    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    if distributed_run:
        init_distributed(args, world_size, local_rank, args.group_name)

    torch.cuda.synchronize()
    run_start_time = time.perf_counter()

    model_config = models.get_model_config(model_name, args)
    model = \
        models.get_model(model_name,
                         model_config,
                         cpu_run=False,
                         uniform_initialize_bn_weight=not args.disable_uniform_initialize_bn_weight)

    if not args.amp and distributed_run:
        model = DDP(model)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.learning_rate,
                                 weight_decay=args.weight_decay)

    if args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
        if distributed_run:
            model = DDP(model)

    try:
        sigma = args.sigma
    except AttributeError:
        sigma = None

    start_epoch = [0]

    if args.resume_from_last:
        args.checkpoint_path = \
            get_last_checkpoint_filename(args.output, model_name)

    if args.checkpoint_path != "":
        load_checkpoint(model, optimizer, start_epoch, model_config, args.amp,
                        args.checkpoint_path, local_rank,
                        args.resume_from_multiproc)

    start_epoch = start_epoch[0]

    criterion = loss_functions.get_loss_function(model_name,
                                                 sigma,
                                                 modified_taco_loss=False)

    try:
        n_frames_per_step = args.n_frames_per_step
    except AttributeError:
        n_frames_per_step = None

    collate_fn = data_functions.get_collate_function(model_name,
                                                     n_frames_per_step)
    trainset = data_functions.get_data_loader(model_name, args.dataset_path,
                                              args.training_files, args)
    if distributed_run:
        train_sampler = DistributedSampler(trainset)
        shuffle = False
    else:
        train_sampler = None
        shuffle = True

    train_loader = DataLoader(trainset,
                              num_workers=1,
                              shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=args.batch_size,
                              pin_memory=False,
                              drop_last=True,
                              collate_fn=collate_fn)

    valset = data_functions.get_data_loader(model_name, args.dataset_path,
                                            args.validation_files, args)

    batch_to_gpu = data_functions.get_batch_to_gpu(model_name)

    iteration = 0
    train_epoch_items_per_sec = 0.0
    val_loss = 0.0
    num_iters = 0

    model.train()

    for epoch in range(start_epoch, args.epochs):
        torch.cuda.synchronize()
        epoch_start_time = time.perf_counter()
        # used to calculate avg items/sec over epoch
        reduced_num_items_epoch = 0

        train_epoch_items_per_sec = 0.0

        num_iters = 0
        reduced_loss = 0

        # if overflow at the last iteration then do not save checkpoint
        overflow = False

        if distributed_run:
            train_loader.sampler.set_epoch(epoch)

        for i, batch in enumerate(train_loader):
            torch.cuda.synchronize()
            iter_start_time = time.perf_counter()
            DLLogger.log(step=(epoch, i),
                         data={
                             'glob_iter/iters_per_epoch':
                             str(iteration) + "/" + str(len(train_loader))
                         })

            adjust_learning_rate(iteration, epoch, optimizer,
                                 args.learning_rate, args.anneal_steps,
                                 args.anneal_factor, local_rank)

            model.zero_grad()
            x, y, num_items = batch_to_gpu(batch)

            y_pred = model(x)
            loss = criterion(y_pred, y)

            if distributed_run:
                reduced_loss = reduce_tensor(loss.data, world_size).item()
                reduced_num_items = reduce_tensor(num_items.data, 1).item()
            else:
                reduced_loss = loss.item()
                reduced_num_items = num_items.item()
            if np.isnan(reduced_loss):
                raise Exception("loss is NaN")
            print(f"Not reduced_loss is: {loss}")
            print(f"Current reduced loss is {reduced_loss:3f}")
            DLLogger.log(step=(epoch, i), data={'train_loss': reduced_loss})

            num_iters += 1

            # accumulate number of items processed in this epoch
            reduced_num_items_epoch += reduced_num_items

            if args.amp:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    amp.master_params(optimizer), args.grad_clip_thresh)
            else:
                loss.backward()
                grad_norm = \
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.grad_clip_thresh)

            optimizer.step()

            torch.cuda.synchronize()
            iter_stop_time = time.perf_counter()
            iter_time = iter_stop_time - iter_start_time
            items_per_sec = reduced_num_items / iter_time
            train_epoch_items_per_sec += items_per_sec

            DLLogger.log(step=(epoch, i),
                         data={'train_items_per_sec': items_per_sec})
            DLLogger.log(step=(epoch, i), data={'train_iter_time': iter_time})
            iteration += 1

        torch.cuda.synchronize()
        epoch_stop_time = time.perf_counter()
        epoch_time = epoch_stop_time - epoch_start_time

        DLLogger.log(step=(epoch, ),
                     data={
                         'train_items_per_sec':
                         (train_epoch_items_per_sec /
                          num_iters if num_iters > 0 else 0.0)
                     })
        DLLogger.log(step=(epoch, ), data={'train_loss': reduced_loss})
        DLLogger.log(step=(epoch, ), data={'train_epoch_time': epoch_time})

        val_loss = validate(model, criterion, valset, epoch, iteration,
                            args.batch_size, world_size, collate_fn,
                            distributed_run, local_rank, batch_to_gpu)

        if epoch % args.epochs_per_checkpoint == 0 and args.bench_class == "":
            save_checkpoint(model, optimizer, epoch, model_config, args.amp,
                            args.output, args.model_name, local_rank,
                            world_size)
        if local_rank == 0:
            DLLogger.flush()

    torch.cuda.synchronize()
    run_stop_time = time.perf_counter()
    run_time = run_stop_time - run_start_time
    DLLogger.log(step=tuple(), data={'run_time': run_time})
    DLLogger.log(step=tuple(), data={'val_loss': val_loss})
    DLLogger.log(step=tuple(),
                 data={
                     'train_items_per_sec':
                     (train_epoch_items_per_sec /
                      num_iters if num_iters > 0 else 0.0)
                 })

    if local_rank == 0:
        DLLogger.flush()
Example #24
0
    def train(self,
              iter_unit,
              num_iter,
              run_iter,
              batch_size,
              warmup_steps=50,
              weight_decay=1e-4,
              lr_init=0.1,
              lr_warmup_epochs=5,
              momentum=0.9,
              log_every_n_steps=1,
              loss_scale=256,
              label_smoothing=0.0,
              mixup=0.0,
              use_cosine_lr=False,
              use_static_loss_scaling=False,
              is_benchmark=False,
              quantize=False,
              symmetric=False,
              quant_delay=0,
              finetune_checkpoint=None,
              use_final_conv=False,
              use_qdq=False):

        if iter_unit not in ["epoch", "batch"]:
            raise ValueError(
                '`iter_unit` value is unknown: %s (allowed: ["epoch", "batch"])'
                % iter_unit)

        if self.run_hparams.data_dir is None and not is_benchmark:
            raise ValueError('`data_dir` must be specified for training!')

        if self.run_hparams.use_tf_amp or self.run_hparams.dtype == tf.float16:
            if use_static_loss_scaling:
                os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING"] = "0"
            else:
                os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING"] = "1"
        else:
            use_static_loss_scaling = False  # Make sure it hasn't been set to True on FP32 training

        num_gpus = hvd.size()
        global_batch_size = batch_size * num_gpus

        if self.run_hparams.data_dir is not None:
            filenames, num_samples, num_steps, num_epochs, num_decay_steps = runner_utils.parse_tfrecords_dataset(
                data_dir=self.run_hparams.data_dir,
                mode="train",
                iter_unit=iter_unit,
                num_iter=num_iter,
                global_batch_size=global_batch_size,
            )

            steps_per_epoch = num_steps / num_epochs

        else:
            num_epochs = 1
            num_steps = num_iter
            steps_per_epoch = num_steps
            num_decay_steps = num_steps
            num_samples = num_steps * batch_size

        if run_iter == -1:
            run_iter = num_steps
        else:
            run_iter = steps_per_epoch * run_iter if iter_unit == "epoch" else run_iter

        if self.run_hparams.use_dali and self.run_hparams.data_idx_dir is not None:
            idx_filenames = runner_utils.parse_dali_idx_dataset(
                data_idx_dir=self.run_hparams.data_idx_dir, mode="train")

        training_hooks = []

        if hvd.rank() == 0:
            print('Starting Model Training...')
            print("Training Epochs", num_epochs)
            print("Total Steps", num_steps)
            print("Steps per Epoch", steps_per_epoch)
            print("Decay Steps", num_decay_steps)
            print("Weight Decay Factor", weight_decay)
            print("Init Learning Rate", lr_init)
            print("Momentum", momentum)
            print("Num GPUs", num_gpus)
            print("Per-GPU Batch Size", batch_size)

            if is_benchmark:
                self.training_logging_hook = hooks.BenchmarkLoggingHook(
                    global_batch_size=global_batch_size,
                    warmup_steps=warmup_steps,
                    logging_steps=log_every_n_steps)
            else:
                self.training_logging_hook = hooks.TrainingLoggingHook(
                    global_batch_size=global_batch_size,
                    num_steps=num_steps,
                    num_samples=num_samples,
                    num_epochs=num_epochs,
                    steps_per_epoch=steps_per_epoch,
                    logging_steps=log_every_n_steps)
            training_hooks.append(self.training_logging_hook)

        if hvd.size() > 1:
            bcast_hook = hvd.hvd_global_object.BroadcastGlobalVariablesHook(0)
            training_hooks.append(bcast_hook)

        training_hooks.append(hooks.PrefillStagingAreasHook())
        training_hooks.append(hooks.TrainingPartitionHook())

        estimator_params = {
            'batch_size': batch_size,
            'steps_per_epoch': steps_per_epoch,
            'num_gpus': num_gpus,
            'momentum': momentum,
            'lr_init': lr_init,
            'lr_warmup_epochs': lr_warmup_epochs,
            'weight_decay': weight_decay,
            'loss_scale': loss_scale,
            'apply_loss_scaling': use_static_loss_scaling,
            'label_smoothing': label_smoothing,
            'mixup': mixup,
            'num_decay_steps': num_decay_steps,
            'use_cosine_lr': use_cosine_lr,
            'use_final_conv': use_final_conv,
            'quantize': quantize,
            'use_qdq': use_qdq,
            'symmetric': symmetric,
            'quant_delay': quant_delay
        }

        if finetune_checkpoint:
            estimator_params['finetune_checkpoint'] = finetune_checkpoint

        image_classifier = self._get_estimator(
            mode='train',
            run_params=estimator_params,
            use_xla=self.run_hparams.use_xla,
            use_dali=self.run_hparams.use_dali,
            gpu_memory_fraction=self.run_hparams.gpu_memory_fraction,
            gpu_id=self.run_hparams.gpu_id)

        def training_data_fn():

            if self.run_hparams.use_dali and self.run_hparams.data_idx_dir is not None:
                if hvd.rank() == 0:
                    print("Using DALI input... ")

                return data_utils.get_dali_input_fn(
                    filenames=filenames,
                    idx_filenames=idx_filenames,
                    batch_size=batch_size,
                    height=self.run_hparams.height,
                    width=self.run_hparams.width,
                    training=True,
                    distort_color=self.run_hparams.distort_colors,
                    num_threads=self.run_hparams.num_preprocessing_threads,
                    deterministic=False
                    if self.run_hparams.seed is None else True)

            elif self.run_hparams.data_dir is not None:

                return data_utils.get_tfrecords_input_fn(
                    filenames=filenames,
                    batch_size=batch_size,
                    height=self.run_hparams.height,
                    width=self.run_hparams.width,
                    training=True,
                    distort_color=self.run_hparams.distort_colors,
                    num_threads=self.run_hparams.num_preprocessing_threads,
                    deterministic=False
                    if self.run_hparams.seed is None else True)

            else:
                if hvd.rank() == 0:
                    print("Using Synthetic Data ...")
                return data_utils.get_synth_input_fn(
                    batch_size=batch_size,
                    height=self.run_hparams.height,
                    width=self.run_hparams.width,
                    num_channels=self.run_hparams.n_channels,
                    data_format=self.run_hparams.input_format,
                    num_classes=self.run_hparams.n_classes,
                    dtype=self.run_hparams.dtype,
                )

        try:
            current_step = image_classifier.get_variable_value("global_step")
        except ValueError:
            current_step = 0

        run_iter = max(0, min(run_iter, num_steps - current_step))
        print("Current step:", current_step)

        if run_iter > 0:
            try:
                image_classifier.train(
                    input_fn=training_data_fn,
                    steps=run_iter,
                    hooks=training_hooks,
                )
            except KeyboardInterrupt:
                print("Keyboard interrupt")

        if hvd.rank() == 0:
            if run_iter > 0:
                print('Ending Model Training ...')
                train_throughput = self.training_logging_hook.mean_throughput.value()
                dllogger.log(data={'train_throughput': train_throughput},
                             step=tuple())
            else:
                print('Model has already been trained for the requested '
                      'number of steps. Skipping.')
def main():
    args = parse_args()

    if args.type == 'pytorch':
        from mem_transformer import MemTransformerLM
    else:
        from inference.mem_transformer_jit import MemTransformerLM

    torch.cuda.set_device(args.local_rank)
    device = torch.device('cuda' if args.cuda else 'cpu')
    utils.distributed.init_distributed(args.cuda)

    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir, debug=args.debug)

    # Setup logging
    if args.log_all_ranks:
        log_file = f'eval_log_rank_{utils.distributed.get_rank()}.log'
    else:
        log_file = 'eval_log.log'

    dllog_file = 'eval_log.json'
    log_file = os.path.join(args.work_dir, log_file)
    dllog_file = os.path.join(args.work_dir, dllog_file)
    if args.debug:
        log_file = os.devnull
        dllog_file = os.devnull

    utils.exp_utils.setup_logging(log_all_ranks=args.log_all_ranks,
                                  filename=log_file,
                                  filemode='a',
                                  )
    utils.exp_utils.setup_dllogger(enabled=True, filename=dllog_file)

    logging.info(args)
    dllogger.log(step='PARAMETER', data=vars(args))

    if not args.no_env:
        log_env_info()

    if args.model:
        model_path = args.model
    elif args.work_dir:
        model_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
    else:
        raise RuntimeError('Specify path to checkpoint using --model or --work_dir')

    checkpoint = load_checkpoint(model_path)

    if args.manual:
        vocab = checkpoint['vocab']

        if hasattr(vocab, 'sym2idx') and not hasattr(vocab, 'unk_idx'):
            vocab.unk_idx = vocab.sym2idx['<unk>']

        text = " ".join(args.manual)
        tokenized = tokenize_raw(text)
        symbols = vocab.tokenize(tokenized, add_eos=True)
        tensor = vocab.convert_to_tensor(symbols)

        iter = data_utils.LMOrderedIterator(tensor, bsz=args.batch_size,
                                            bptt=args.tgt_len, device=device,
                                            ext_len=args.ext_len, warmup=False)
    else:
        # Load dataset
        corpus = get_lm_corpus(args.data, args.dataset, checkpoint['args'].vocab)

        if args.split == 'valid' or args.split == 'test':
            iter = corpus.get_iterator(args.split, args.batch_size, args.tgt_len,
                                       device=device, mem_len=args.mem_len,
                                       ext_len=args.ext_len)
        else:
            raise RuntimeError('Unknown split')

    if args.fp16:
        dtype = torch.float16
        math_str = 'fp16'
    else:
        dtype = torch.float32
        math_str = 'fp32'

    if args.load_torchscript:
        model = torch.jit.load(args.load_torchscript)
    else:
        checkpoint['model_config']['tgt_len'] = args.tgt_len
        checkpoint['model_config']['ext_len'] = args.ext_len
        checkpoint['model_config']['mem_len'] = args.mem_len
        checkpoint['model_config']['clamp_len'] = args.clamp_len
        checkpoint['model_config']['same_length'] = args.same_length
        checkpoint['model_config']['dtype'] = dtype

        model = MemTransformerLM(**checkpoint['model_config'])
        if args.type == 'pytorch':
            model.load_state_dict(checkpoint['model_state'])
        elif args.type == 'torchscript':
            model.load_state_dict(checkpoint['model_state'], strict=False)

    model = model.eval()
    model = model.to(device)
    model = model.to(dtype)

    if args.type == 'torchscript':
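        # Rebind the (possibly tied) projection and embedding tensors from the raw
        # state dict before scripting, since the weights were loaded with strict=False above.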
        state = checkpoint['model_state']

        tie_projs = checkpoint['model_config']['tie_projs']
        tie_weight = checkpoint['model_config']['tie_weight']
        div_val = checkpoint['model_config']['div_val']
        d_model = checkpoint['model_config']['d_model']
        d_embed = checkpoint['model_config']['d_embed']

        if div_val != 1 or d_model != d_embed:
            for i in range(len(model.word_emb.emb_projs)):
                model.word_emb.emb_projs[i] = state[f'word_emb.emb_projs.{i}'].to(dtype)

        for i in range(len(model.crit.out_projs)):
            if div_val == 1:
                src = 0
            else:
                src = i
            if model.crit.out_projs[i] is not None:
                if tie_projs[i]:
                    model.crit.out_projs[i] = state[f'word_emb.emb_projs.{src}'].to(dtype)
                else:
                    model.crit.out_projs[i] = state[f'crit.out_projs.{i}'].to(dtype)

        for i in range(len(model.crit.out_layers_biases)):
            model.crit.out_layers_biases[i] = state[f'crit.out_layers_biases.{i}'].to(dtype)

        if tie_weight:
            for i in range(len(model.crit.out_layers_weights)):
                model.crit.out_layers_weights[i] = state[f'word_emb.emb_layers.{i}.weight'].to(dtype)
        else:
            for i in range(len(model.crit.out_layers_weights)):
                model.crit.out_layers_weights[i] = state[f'crit.out_layers_weights.{i}'].to(dtype)

        model = torch.jit.script(model)

    if args.type != 'pytorch':
        compile_model(model, device, args)

    if args.type == 'torchscript' and args.save_torchscript:
        torch.jit.save(model, args.save_torchscript)

    logging.info(f'Evaluating with: math {math_str} type {args.type} '
                 f'bsz {args.batch_size} tgt_len {args.tgt_len} '
                 f'ext_len {args.ext_len} mem_len {args.mem_len} '
                 f'clamp_len {args.clamp_len}')

    meters = {}
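    # Exclude the first iterations from the timing stats: the model's memory takes
    # roughly mem_len/tgt_len steps to fill, so early latencies are not representative.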
    warmup = args.mem_len // args.tgt_len + 2
    meters['eval_throughput'] = AverageMeter(warmup=warmup, keep=args.save_data)
    meters['eval_latency'] = AverageMeter(warmup=warmup, keep=args.save_data)

    loss = evaluate(iter, model, meters, args.log_interval, args.max_size, args.repeat)
    perplexity = math.exp(loss)
    log_str = format_log(loss, args.split, args)

    summary = {
        'eval_loss': loss,
        'eval_ppl': perplexity,
        }

    logging.info('=' * 100)
    logging.info(log_str)
    logging.info('=' * 100)

    if args.save_data:
        latency_data = np.array(meters['eval_latency'].vals)
        throughput_data = np.array(meters['eval_throughput'].vals)
        precision = 'fp16' if args.fp16 else 'fp32'
        data_fname = f'eval_data_{args.batch_size}_{precision}_{args.type}'
        data_path = os.path.join(args.work_dir, data_fname)
        data = {
            'args': args,
            'throughput': throughput_data,
            'latency': latency_data,
            }
        with open(data_path, 'wb') as f:
            pickle.dump(data, f)
        logging.info(f'Throughput Avg: {throughput_data.mean():.2f} tok/s')
        logging.info(f'Latency Avg: {1000.0 * latency_data.mean():.2f} ms')
        for p in args.percentiles:
            logging.info(f'Latency {p}%: {1000.0 * np.percentile(latency_data, p):.2f} ms')

        logging.info('=' * 100)

        summary.update({
            'eval_throughput': throughput_data.mean(),
            'eval_avg_latency': 1000 * latency_data.mean(),
            })
        for p in args.percentiles:
            summary[f'eval_{p}%_latency'] = 1000 * np.percentile(latency_data, p)

    dllogger.log(step=tuple(), data=summary)

    passed = benchmark(target_perplexity=args.target_perplexity,
                       test_perplexity=perplexity,
                       target_throughput=args.target_throughput,
                       test_throughput=meters['eval_throughput'].avg,
                       )
    if not passed:
        sys.exit(1)
Example #26
0
    def evaluate(
        self,
        iter_unit,
        num_iter,
        batch_size,
        warmup_steps=50,
        log_every_n_steps=1,
        is_benchmark=False,
        export_dir=None,
        quantize=False,
        symmetric=False,
        use_qdq=False,
        use_final_conv=False,
    ):

        if iter_unit not in ["epoch", "batch"]:
            raise ValueError(
                '`iter_unit` value is unknown: %s (allowed: ["epoch", "batch"])'
                % iter_unit)

        if self.run_hparams.data_dir is None and not is_benchmark:
            raise ValueError('`data_dir` must be specified for evaluation!')

        if hvd.rank() != 0:
            raise RuntimeError('Multi-GPU inference is not supported')

        estimator_params = {
            'quantize': quantize,
            'symmetric': symmetric,
            'use_qdq': use_qdq,
            'use_final_conv': use_final_conv
        }

        image_classifier = self._get_estimator(
            mode='validation',
            run_params=estimator_params,
            use_xla=self.run_hparams.use_xla,
            use_dali=self.run_hparams.use_dali,
            gpu_memory_fraction=self.run_hparams.gpu_memory_fraction,
            gpu_id=self.run_hparams.gpu_id)

        if self.run_hparams.data_dir is not None:
            filenames, num_samples, num_steps, num_epochs, num_decay_steps = runner_utils.parse_tfrecords_dataset(
                data_dir=self.run_hparams.data_dir,
                mode="validation",
                iter_unit=iter_unit,
                num_iter=num_iter,
                global_batch_size=batch_size,
            )

        else:
            num_epochs = 1
            num_decay_steps = -1
            num_steps = num_iter

        if self.run_hparams.use_dali and self.run_hparams.data_idx_dir is not None:
            idx_filenames = runner_utils.parse_dali_idx_dataset(
                data_idx_dir=self.run_hparams.data_idx_dir, mode="validation")

        eval_hooks = []

        if hvd.rank() == 0:
            self.eval_logging_hook = hooks.BenchmarkLoggingHook(
                global_batch_size=batch_size,
                warmup_steps=warmup_steps,
                logging_steps=log_every_n_steps)
            eval_hooks.append(self.eval_logging_hook)

            print('Starting Model Evaluation...')
            print("Evaluation Epochs", num_epochs)
            print("Evaluation Steps", num_steps)
            print("Decay Steps", num_decay_steps)
            print("Global Batch Size", batch_size)

        def evaluation_data_fn():

            if self.run_hparams.use_dali and self.run_hparams.data_idx_dir is not None:
                if hvd.rank() == 0:
                    print("Using DALI input... ")

                return data_utils.get_dali_input_fn(
                    filenames=filenames,
                    idx_filenames=idx_filenames,
                    batch_size=batch_size,
                    height=self.run_hparams.height,
                    width=self.run_hparams.width,
                    training=False,
                    distort_color=self.run_hparams.distort_colors,
                    num_threads=self.run_hparams.num_preprocessing_threads,
                    deterministic=False
                    if self.run_hparams.seed is None else True)

            elif self.run_hparams.data_dir is not None:
                return data_utils.get_tfrecords_input_fn(
                    filenames=filenames,
                    batch_size=batch_size,
                    height=self.run_hparams.height,
                    width=self.run_hparams.width,
                    training=False,
                    distort_color=self.run_hparams.distort_colors,
                    num_threads=self.run_hparams.num_preprocessing_threads,
                    deterministic=False
                    if self.run_hparams.seed is None else True)

            else:
                print("Using Synthetic Data ...\n")
                return data_utils.get_synth_input_fn(
                    batch_size=batch_size,
                    height=self.run_hparams.height,
                    width=self.run_hparams.width,
                    num_channels=self.run_hparams.n_channels,
                    data_format=self.run_hparams.input_format,
                    num_classes=self.run_hparams.n_classes,
                    dtype=self.run_hparams.dtype,
                )

        try:
            eval_results = image_classifier.evaluate(
                input_fn=evaluation_data_fn,
                steps=num_steps,
                hooks=eval_hooks,
            )

            eval_throughput = self.eval_logging_hook.mean_throughput.value()
            if len(self.eval_logging_hook.latencies) > 0:
                eval_latencies = np.array(
                    self.eval_logging_hook.latencies) * 1000
                eval_latencies_q = np.quantile(eval_latencies,
                                               q=[0.9, 0.95, 0.99])
                eval_latencies_mean = np.mean(eval_latencies)
                additional_metrics = {
                    'eval_latency_avg': eval_latencies_mean,
                    'eval_latency_p90': eval_latencies_q[0],
                    'eval_latency_p95': eval_latencies_q[1],
                    'eval_latency_p99': eval_latencies_q[2],
                }
            else:
                additional_metrics = {}

            dllogger.log(data={
                'top1_accuracy': float(eval_results['top1_accuracy']),
                'top5_accuracy': float(eval_results['top5_accuracy']),
                'eval_throughput': eval_throughput,
                **additional_metrics
            },
                         step=tuple())

            if export_dir is not None:
                dllogger.log(data={'export_dir': export_dir}, step=tuple())
                input_receiver_fn = data_utils.get_serving_input_receiver_fn(
                    batch_size=None,
                    height=self.run_hparams.height,
                    width=self.run_hparams.width,
                    num_channels=self.run_hparams.n_channels,
                    data_format=self.run_hparams.input_format,
                    dtype=self.run_hparams.dtype)

                self.exported_path = image_classifier.export_savedmodel(
                    export_dir, input_receiver_fn)

        except KeyboardInterrupt:
            print("Keyboard interrupt")

        print('Model evaluation finished')
def main(argv):
    if FLAGS.experimental_columnwise_split and not FLAGS.data_parallel_bottom_mlp and FLAGS.num_numerical_features > 0:
        raise ValueError(
            'Currently, when using the --experimental_columnwise_split option '
            'you must either set --data_parallel_bottom_mlp or --num_numerical_features=0'
        )

    if FLAGS.batch_size != FLAGS.valid_batch_size:
        raise ValueError(
            'For now, validation batch size must be the same as training batch size'
        )

    hvd.init()
    init_logging(log_path=FLAGS.log_path, FLAGS=FLAGS)
    init_tf(FLAGS)

    train_pipeline, validation_pipeline, dataset_metadata, multi_gpu_metadata = create_input_pipelines(
        FLAGS)

    if FLAGS.dummy_model:
        dlrm = DummyDlrm(FLAGS=FLAGS,
                         dataset_metadata=dataset_metadata,
                         multi_gpu_metadata=multi_gpu_metadata)
    else:
        dlrm = Dlrm(FLAGS=FLAGS,
                    dataset_metadata=dataset_metadata,
                    multi_gpu_metadata=multi_gpu_metadata)

    if FLAGS.optimizer == 'sgd':
        embedding_optimizer = tf.keras.optimizers.SGD(lr=FLAGS.learning_rate,
                                                      momentum=0)
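        # With AMP enabled, wrap the optimizer in a fixed-scale LossScaleOptimizer so
        # that fp16 gradients remain representable.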
        if FLAGS.amp:
            embedding_optimizer = LossScaleOptimizer(
                embedding_optimizer,
                initial_scale=FLAGS.loss_scale,
                dynamic=False)
        mlp_optimizer = embedding_optimizer
        optimizers = [mlp_optimizer]

    elif FLAGS.optimizer == 'adam':
        embedding_optimizer = tfa.optimizers.LazyAdam(lr=FLAGS.learning_rate)
        mlp_optimizer = tf.keras.optimizers.Adam(lr=FLAGS.learning_rate)
        if FLAGS.amp:
            embedding_optimizer = LossScaleOptimizer(
                embedding_optimizer,
                initial_scale=FLAGS.loss_scale,
                dynamic=False)
            mlp_optimizer = LossScaleOptimizer(mlp_optimizer,
                                               initial_scale=FLAGS.loss_scale,
                                               dynamic=False)
        optimizers = [mlp_optimizer, embedding_optimizer]

    scheduler = LearningRateScheduler(optimizers,
                                      warmup_steps=FLAGS.warmup_steps,
                                      base_lr=FLAGS.learning_rate,
                                      decay_start_step=FLAGS.decay_start_step,
                                      decay_steps=FLAGS.decay_steps)

    timer = IterTimer(train_batch_size=FLAGS.batch_size,
                      test_batch_size=FLAGS.valid_batch_size,
                      optimizer=embedding_optimizer,
                      print_freq=FLAGS.print_freq,
                      enabled=hvd.rank() == 0)

    splitter = DataParallelSplitter(batch_size=FLAGS.batch_size)

    dlrm.maybe_restore_checkpoint(FLAGS.restore_checkpoint_path)

    if FLAGS.mode == 'inference':
        inference_benchmark(validation_pipeline, dlrm, timer, splitter, FLAGS)
        return

    elif FLAGS.mode == 'eval':
        test_auc, test_loss, _ = evaluate(validation_pipeline,
                                          dlrm,
                                          timer,
                                          auc_thresholds=FLAGS.auc_thresholds,
                                          data_parallel_splitter=splitter)
        dist_print(
            f'Evaluation completed, AUC: {test_auc:.6f}, test_loss: {test_loss:.6f}'
        )
        return

    eval_points = compute_eval_points(train_batches=len(train_pipeline),
                                      evals_per_epoch=FLAGS.evals_per_epoch)

    trainer = DlrmTrainer(dlrm,
                          embedding_optimizer=embedding_optimizer,
                          mlp_optimizer=mlp_optimizer,
                          amp=FLAGS.amp,
                          lr_scheduler=scheduler)

    best_auc = 0
    train_begin = time.time()
    for epoch in range(FLAGS.epochs):
        for step, ((numerical_features, categorical_features),
                   labels) in enumerate(train_pipeline):
            if (step == FLAGS.profiler_start_step
                    and hvd.rank() == FLAGS.profiled_rank):
                tf.profiler.experimental.start('logdir')

            if (FLAGS.profiler_start_step
                    and step == FLAGS.profiler_start_step + 100
                    and hvd.rank() == FLAGS.profiled_rank):
                tf.profiler.experimental.stop()

            labels = splitter(labels)
            if FLAGS.data_parallel_bottom_mlp:
                numerical_features = splitter(numerical_features)

            loss = trainer.train_step(numerical_features, categorical_features,
                                      labels)

            timer.step_train(loss=loss)

            if FLAGS.max_steps != -1 and step > FLAGS.max_steps:
                dist_print(f'Max steps of {FLAGS.max_steps} reached, exiting')
                break

            if step in eval_points:
                test_auc, test_loss, _ = evaluate(
                    validation_pipeline,
                    dlrm,
                    timer,
                    FLAGS.auc_thresholds,
                    data_parallel_splitter=splitter)
                dist_print(
                    f'Evaluation completed, AUC: {test_auc:.6f}, test_loss: {test_loss:.6f}'
                )
                timer.test_idx = 0
                best_auc = max(best_auc, test_auc)

    elapsed = time.time() - train_begin
    dlrm.maybe_save_checkpoint(FLAGS.save_checkpoint_path)

    if hvd.rank() == 0:
        dist_print(f'Training run completed, elapsed: {elapsed:.0f} [s]')
        results = {
            'throughput': FLAGS.batch_size / timer.mean_train_time(),
            'mean_step_time_ms': timer.mean_train_time() * 1000,
            'auc': best_auc
        }
        dllogger.log(data=results, step=tuple())
def train(tr_iter, va_iter, model, para_model, model_config, optimizer,
          optimizer_sparse, scheduler, scheduler_sparse, vocab, epoch,
          last_batch, last_iter, train_step, best_val_loss, meters,
          timeout_handler, args):
    # Turn on training mode which enables dropout.
    model.train()

    train_loss = 0
    target_tokens = 0
    log_step = 0
    log_start_time = time.time()

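    # Keep one recurrent memory per gradient-accumulation chunk, since each chunk
    # sees a different slice of the batch.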
    mems = [None for _ in range(args.batch_chunk)]
    if args.varlen:
        train_iter = tr_iter.get_varlen_iter(start=last_iter)
    else:
        train_iter = tr_iter.get_fixlen_iter(start=last_iter)

    for batch, (data, target, seq_len, _) in enumerate(train_iter,
                                                       start=last_batch + 1):
        log_step += 1
        target_tokens += target.numel()

        for param in model.parameters():
            param.grad = None

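        # Split the [seq_len, batch] tensors along the batch dimension and accumulate
        # gradients over the chunks before a single optimizer step.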
        data_chunks = torch.chunk(data, args.batch_chunk, 1)
        target_chunks = torch.chunk(target, args.batch_chunk, 1)

        for i in range(args.batch_chunk):
            data_i = data_chunks[i].contiguous()
            target_i = target_chunks[i].contiguous()
            loss, mems[i] = para_model(data_i, target_i, mems[i])
            loss = loss.float().mean().type_as(loss) / args.batch_chunk

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            train_loss += loss.float().item()

        if args.fp16:
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.clip)
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

        optimizer.step()
        if optimizer_sparse:
            optimizer_sparse.step()

        # step-wise learning rate annealing
        train_step += 1
        if args.scheduler in ['cosine', 'constant', 'dev_perf']:
            # linear warmup stage
            if train_step < args.warmup_step:
                curr_lr = args.lr * train_step / args.warmup_step
                optimizer.param_groups[0]['lr'] = curr_lr
                if optimizer_sparse:
                    optimizer_sparse.param_groups[0]['lr'] = curr_lr * 2
            else:
                if args.scheduler == 'cosine':
                    scheduler.step(train_step - args.warmup_step)
                    if scheduler_sparse:
                        scheduler_sparse.step(train_step - args.warmup_step)
        elif args.scheduler == 'inv_sqrt':
            scheduler.step(train_step)
            if scheduler_sparse:
                scheduler_sparse.step(train_step)

        if train_step % args.log_interval == 0:
            cur_loss = train_loss / log_step
            cur_loss = utils.distributed.all_reduce_item(cur_loss, op='mean')
            train_loss = 0

            elapsed = time.time() - log_start_time
            avg_elapsed = elapsed / log_step
            avg_elapsed = utils.distributed.all_reduce_item(avg_elapsed,
                                                            op='max')
            log_start_time = time.time()
            log_step = 0

            lr = optimizer.param_groups[0]['lr']
            throughput = target_tokens / elapsed
            throughput = utils.distributed.all_reduce_item(throughput,
                                                           op='sum')
            meters['train_throughput'].update(throughput)
            target_tokens = 0

            log_str = '| epoch {:3d} step {:>8d} | batches {:>6d} / {:d} | lr {:.3e} ' \
                '| ms/batch {:5.1f} | tok/s {:7.0f} | loss {:5.2f}'.format(
                    epoch,
                    train_step,
                    batch,
                    tr_iter.n_batch,
                    lr,
                    avg_elapsed * 1000,
                    throughput,
                    cur_loss,
                    )

            dllogger_data = {
                'epoch': epoch,
                'train_batch': batch + 1,
                'lr': lr,
                'train_time/batch': avg_elapsed * 1000,
                'train_throughput': throughput,
                'train_loss': cur_loss,
            }

            if args.dataset in ['enwik8', 'text8']:
                log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2))
                dllogger_data[
                    'train_bits_per_character'] = cur_loss / math.log(2)
            else:
                log_str += ' | ppl {:9.2f}'.format(math.exp(cur_loss))
                dllogger_data['train_perplexity'] = math.exp(cur_loss)

            logging.info(log_str)
            dllogger.log(step=tuple([train_step]), data=dllogger_data)

        do_periodic_eval = train_step % args.eval_interval == 0
        is_final_step = train_step == args.max_step
        interrupted = timeout_handler.interrupted

        if (do_periodic_eval or is_final_step
                or interrupted) and not args.no_eval:
            eval_start_time = time.time()
            val_loss = evaluate(va_iter, model, args)
            val_loss = utils.distributed.all_reduce_item(val_loss, op='mean')

            logging.info('-' * 100)
            log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
                      '| valid loss {:5.2f}'.format(
                          train_step // args.eval_interval,
                          train_step,
                          (time.time() - eval_start_time),
                          val_loss,
                          )

            dllogger_data = {
                'valid_elapsed': (time.time() - eval_start_time),
                'valid_loss': val_loss,
            }

            if args.dataset in ['enwik8', 'text8']:
                log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2))
                dllogger_data[
                    'valid_bits_per_character'] = val_loss / math.log(2)
            else:
                log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss))
                dllogger_data['valid_perplexity'] = math.exp(val_loss)
            logging.info(log_str)
            logging.info('-' * 100)
            dllogger.log(step=tuple([train_step]), data=dllogger_data)

            last_iter = tr_iter.last_iter

            # Check if the validation loss is the best we've seen so far.
            is_best = False
            if not best_val_loss or val_loss < best_val_loss:
                best_val_loss = val_loss
                is_best = True

            if not args.debug:
                save_checkpoint(args, model, model_config, optimizer,
                                scheduler, vocab, epoch, batch, last_iter,
                                train_step, best_val_loss, is_best,
                                args.work_dir)

            # dev-performance based learning rate annealing
            if args.scheduler == 'dev_perf':
                scheduler.step(val_loss)
                if scheduler_sparse:
                    scheduler_sparse.step(val_loss)

            # subtract eval time from timers for training
            log_start_time += time.time() - eval_start_time

        if interrupted:
            logging.info(f'Received SIGTERM, exiting')
            sys.exit(0)

        if is_final_step:
            break
    return train_step, best_val_loss
Example #29
0
def main():
    parser = argparse.ArgumentParser(description='PyTorch FastPitch Training',
                                     allow_abbrev=False)
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    if 'LOCAL_RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        local_rank = int(os.environ['LOCAL_RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
    else:
        local_rank = args.rank
        world_size = args.world_size
    distributed_run = world_size > 1

    torch.manual_seed(args.seed + local_rank)
    np.random.seed(args.seed + local_rank)

    if local_rank == 0:
        if not os.path.exists(args.output):
            os.makedirs(args.output)
        init_dllogger(args.log_file)
    else:
        init_dllogger(dummy=True)

    for k, v in vars(args).items():
        DLLogger.log(step="PARAMETER", data={k: v})

    parser = models.parse_model_args('FastPitch', parser)
    args, unk_args = parser.parse_known_args()
    if len(unk_args) > 0:
        raise ValueError(f'Invalid options {unk_args}')

    torch.backends.cudnn.enabled = args.cudnn_enabled
    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    if distributed_run:
        init_distributed(args, world_size, local_rank, args.group_name)

    device = torch.device('cuda' if args.cuda else 'cpu')
    model_config = models.get_model_config('FastPitch', args)
    model = models.get_model('FastPitch', model_config, device)

    # Store pitch mean/std as params to translate from Hz during inference
    fpath = common.utils.stats_filename(args.dataset_path, args.training_files,
                                        'pitch_char')
    with open(args.pitch_mean_std_file, 'r') as f:
        stats = json.load(f)
    model.pitch_mean[0] = stats['mean']
    model.pitch_std[0] = stats['std']

    kw = dict(lr=args.learning_rate,
              betas=(0.9, 0.98),
              eps=1e-9,
              weight_decay=args.weight_decay)
    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), **kw)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), **kw)
    else:
        raise ValueError

    if args.amp_run:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

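    # Keep a separate copy of the model whose weights will track an exponential
    # moving average of the trained weights (used for EMA validation).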
    if args.ema_decay > 0:
        ema_model = copy.deepcopy(model)
    else:
        ema_model = None

    if distributed_run:
        model = DDP(model)

    start_epoch = [1]

    assert args.checkpoint_path is None or args.checkpoint_resume is False, (
        "Specify a single checkpoint source")
    if args.checkpoint_path is not None:
        ch_fpath = args.checkpoint_path
    elif args.checkpoint_resume:
        ch_fpath = last_checkpoint(args.output)
    else:
        ch_fpath = None

    if ch_fpath is not None:
        load_checkpoint(local_rank, model, ema_model, optimizer, start_epoch,
                        model_config, args.amp_run, ch_fpath, world_size)

    start_epoch = start_epoch[0]

    criterion = loss_functions.get_loss_function(
        'FastPitch',
        dur_predictor_loss_scale=args.dur_predictor_loss_scale,
        pitch_predictor_loss_scale=args.pitch_predictor_loss_scale)

    collate_fn = data_functions.get_collate_function('FastPitch')
    trainset = data_functions.get_data_loader('FastPitch', args.dataset_path,
                                              args.training_files, args)
    valset = data_functions.get_data_loader('FastPitch', args.dataset_path,
                                            args.validation_files, args)
    if distributed_run:
        train_sampler, shuffle = DistributedSampler(trainset), False
    else:
        train_sampler, shuffle = None, True

    train_loader = DataLoader(trainset,
                              num_workers=16,
                              shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=args.batch_size,
                              pin_memory=False,
                              drop_last=True,
                              collate_fn=collate_fn)

    batch_to_gpu = data_functions.get_batch_to_gpu('FastPitch')

    model.train()

    train_tblogger = TBLogger(local_rank, args.output, 'train')
    val_tblogger = TBLogger(local_rank, args.output, 'val', dummies=True)
    if args.ema_decay > 0:
        val_ema_tblogger = TBLogger(local_rank, args.output, 'val_ema')

    val_loss = 0.0
    total_iter = 0
    torch.cuda.synchronize()
    for epoch in range(start_epoch, args.epochs + 1):
        epoch_start_time = time.time()

        epoch_loss = 0.0
        epoch_mel_loss = 0.0
        epoch_num_frames = 0
        epoch_frames_per_sec = 0.0

        if distributed_run:
            train_loader.sampler.set_epoch(epoch)

        accumulated_steps = 0
        iter_loss = 0
        iter_num_frames = 0
        iter_meta = {}

        epoch_iter = 0
        num_iters = len(train_loader) // args.gradient_accumulation_steps
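        # Each logical iteration spans gradient_accumulation_steps batches, so
        # trailing batches that do not fill a full window are skipped once
        # epoch_iter reaches num_iters (the break below).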
        for batch in train_loader:
            if accumulated_steps == 0:
                if epoch_iter == num_iters:
                    break
                total_iter += 1
                epoch_iter += 1
                iter_start_time = time.time()
                start = time.perf_counter()

                old_lr = optimizer.param_groups[0]['lr']
                adjust_learning_rate(total_iter, optimizer, args.learning_rate,
                                     args.warmup_steps)
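                # adjust_learning_rate mutates the optimizer's param_groups in
                # place according to the warmup schedule; a change is recorded to
                # TensorBoard and DLLogger below.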
                new_lr = optimizer.param_groups[0]['lr']

                if new_lr != old_lr:
                    dllog_lrate_change = f'{old_lr:.2E} -> {new_lr:.2E}'
                    train_tblogger.log_value(total_iter, 'lrate', new_lr)
                else:
                    dllog_lrate_change = None

                model.zero_grad()

            x, y, num_frames = batch_to_gpu(batch)
            y_pred = model(x, use_gt_durations=True)
            loss, meta = criterion(y_pred, y)

            loss /= args.gradient_accumulation_steps
            meta = {
                k: v / args.gradient_accumulation_steps
                for k, v in meta.items()
            }
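            # Scale the loss (and the per-term meta values) so that gradients
            # accumulated over the window match a single large-batch update.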

            if args.amp_run:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if distributed_run:
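                # reduce_tensor all-reduces and divides by its second argument:
                # the loss is averaged over world_size, while num_frames
                # (divisor 1) is effectively summed across ranks.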
                reduced_loss = reduce_tensor(loss.data, world_size).item()
                reduced_num_frames = reduce_tensor(num_frames.data, 1).item()
                meta = {
                    k: reduce_tensor(v, world_size)
                    for k, v in meta.items()
                }
            else:
                reduced_loss = loss.item()
                reduced_num_frames = num_frames.item()
            if np.isnan(reduced_loss):
                raise Exception("loss is NaN")

            accumulated_steps += 1
            iter_loss += reduced_loss
            iter_num_frames += reduced_num_frames
            iter_meta = {k: iter_meta.get(k, 0) + meta.get(k, 0) for k in meta}

            if accumulated_steps % args.gradient_accumulation_steps == 0:

                train_tblogger.log_grads(total_iter, model)
                if args.amp_run:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.grad_clip_thresh)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.grad_clip_thresh)

                optimizer.step()
                apply_ema_decay(model, ema_model, args.ema_decay)
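                # Blend the live weights into ema_model after every optimizer
                # step. A minimal sketch of a standard EMA update, assuming
                # apply_ema_decay follows the usual formulation (the real helper
                # is defined elsewhere in this script and may differ):
                #     with torch.no_grad():
                #         for p_ema, p in zip(ema_model.parameters(), model.parameters()):
                #             p_ema.copy_(args.ema_decay * p_ema + (1.0 - args.ema_decay) * p)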

                iter_stop_time = time.time()
                iter_time = iter_stop_time - iter_start_time
                frames_per_sec = iter_num_frames / iter_time
                epoch_frames_per_sec += frames_per_sec
                epoch_loss += iter_loss
                epoch_num_frames += iter_num_frames
                iter_mel_loss = iter_meta['mel_loss'].item()
                epoch_mel_loss += iter_mel_loss

                DLLogger.log(
                    (epoch, epoch_iter, num_iters),
                    OrderedDict([('train_loss', iter_loss),
                                 ('train_mel_loss', iter_mel_loss),
                                 ('train_frames/s', frames_per_sec),
                                 ('took', iter_time),
                                 ('lrate_change', dllog_lrate_change)]))
                train_tblogger.log_meta(total_iter, iter_meta)

                accumulated_steps = 0
                iter_loss = 0
                iter_num_frames = 0
                iter_meta = {}

        # Finished epoch
        epoch_stop_time = time.time()
        epoch_time = epoch_stop_time - epoch_start_time

        DLLogger.log((epoch, ),
                     data=OrderedDict([
                         ('avg_train_loss', epoch_loss / epoch_iter),
                         ('avg_train_mel_loss', epoch_mel_loss / epoch_iter),
                         ('avg_train_frames/s', epoch_num_frames / epoch_time),
                         ('took', epoch_time)
                     ]))

        tik = time.time()
        val_loss, meta, num_frames = validate(model,
                                              criterion,
                                              valset,
                                              args.batch_size,
                                              world_size,
                                              collate_fn,
                                              distributed_run,
                                              local_rank,
                                              batch_to_gpu,
                                              use_gt_durations=True)
        tok = time.time()

        DLLogger.log((epoch, ),
                     data=OrderedDict([
                         ('val_loss', val_loss),
                         ('val_mel_loss', meta['mel_loss'].item()),
                         ('val_frames/s', num_frames / (tok - tik)),
                         ('took', tok - tik),
                     ]))
        val_tblogger.log_meta(total_iter, meta)

        if args.ema_decay > 0:
            tik_e = time.time()
            val_loss_e, meta_e, num_frames_e = validate(ema_model,
                                                        criterion,
                                                        valset,
                                                        args.batch_size,
                                                        world_size,
                                                        collate_fn,
                                                        distributed_run,
                                                        local_rank,
                                                        batch_to_gpu,
                                                        use_gt_durations=True)
            tok_e = time.time()

            DLLogger.log(
                (epoch, ),
                data=OrderedDict([
                    ('val_ema_loss', val_loss_e),
                    ('val_ema_mel_loss', meta_e['mel_loss'].item()),
                    ('val_ema_frames/s', num_frames_e / (tok_e - tik_e)),
                    ('took', tok_e - tik_e),
                ]))
            val_ema_tblogger.log_meta(total_iter, meta_e)

        if (epoch > 0 and args.epochs_per_checkpoint > 0
                and (epoch % args.epochs_per_checkpoint == 0)
                and local_rank == 0):

            checkpoint_path = os.path.join(args.output,
                                           f"FastPitch_checkpoint_{epoch}.pt")
            save_checkpoint(local_rank, model, ema_model, optimizer, epoch,
                            model_config, args.amp_run, checkpoint_path)
        if local_rank == 0:
            DLLogger.flush()

    # Finished training
    DLLogger.log((),
                 data=OrderedDict([
                     ('avg_train_loss', epoch_loss / epoch_iter),
                     ('avg_train_mel_loss', epoch_mel_loss / epoch_iter),
                     ('avg_train_frames/s', epoch_num_frames / epoch_time),
                 ]))
    DLLogger.log((),
                 data=OrderedDict([
                     ('val_loss', val_loss),
                     ('val_mel_loss', meta['mel_loss'].item()),
                     ('val_frames/s', num_frames / (tok - tik)),
                 ]))
    if local_rank == 0:
        DLLogger.flush()
Exemple #30
0
def main():
    setup_default_logging()  ## TODO(sugh) replace
    args, args_text = _parse_args()
    set_affinity(args.local_rank)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        torch.cuda.manual_seed_all(args.seed)
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()

        # Set device limit on the current device
        # cudaLimitMaxL2FetchGranularity = 0x05
        pValue = ctypes.cast((ctypes.c_int*1)(), ctypes.POINTER(ctypes.c_int))
        _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
        _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
        assert pValue.contents.value == 128
    assert args.rank >= 0

    setup_dllogger(args.rank, filename=args.dllogger_file)

    if args.distributed:
        logging.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                     % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on 1 GPU.')

    if args.waymo:
        if (args.waymo_train is None) != (args.waymo_val is None):
            raise Exception("waymo_train and waymo_val must be provided together")

    memory_format = (
        torch.channels_last if args.memory_format == "nhwc" else torch.contiguous_format
    )

    model = create_model(
        args.model,
        input_size=args.input_size,
        num_classes=args.num_classes,
        bench_task='train',
        pretrained=args.pretrained,
        pretrained_backbone_path=args.pretrained_backbone_path,
        redundant_bias=args.redundant_bias,
        checkpoint_path=args.initial_checkpoint,
        label_smoothing=args.smoothing,
        fused_focal_loss=args.fused_focal_loss,
        remove_params=args.remove_weights,
        freeze_layers=args.freeze_layers,
        strict_load=False
    )
    # FIXME decide which args to keep and overlay on config / pass to backbone
    #     num_classes=args.num_classes,
    input_size = model.config.image_size
    data_config = model.config
    print("Input size to be passed to dataloaders: {}".format(input_size))
    print("Image size used in model: {}".format(model.config.image_size))

    if args.rank == 0:
        dllogger.log(step='PARAMETER', data={'model_name':args.model, 'param_count': sum([m.numel() for m in model.parameters()])})
    model = model.cuda().to(memory_format=memory_format)

    # Optionally resume from a checkpoint

    if args.distributed:
        if args.sync_bn:
            try:
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
                if args.local_rank == 0:
                    logging.info(
                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
            except Exception as e:
                logging.error('Failed to enable Synchronized BatchNorm: %s. Install Apex or Torch >= 1.1', e)
    optimizer = create_optimizer(args, model)
    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

    resume_state = {}
    resume_epoch = None
    output_base = args.output if args.output else './output'
    resume_checkpoint_path = get_latest_checkpoint(os.path.join(output_base, 'train'))
    if args.resume and resume_checkpoint_path is not None:
        print("Trying to load checkpoint from {}".format(resume_checkpoint_path))
        resume_state, resume_epoch = resume_checkpoint(unwrap_bench(model), resume_checkpoint_path)
        if resume_epoch is not None:
            print("Resume training from {} epoch".format(resume_epoch))
    if resume_state and not args.no_resume_opt:
        if 'optimizer' in resume_state:
            if args.local_rank == 0:
                logging.info('Restoring Optimizer state from checkpoint')
            optimizer.load_state_dict(resume_state['optimizer'])
        if args.amp and 'scaler' in resume_state:
            if args.local_rank == 0:
                logging.info('Restoring NVIDIA AMP state from checkpoint')
            scaler.load_state_dict(resume_state['scaler'])
    del resume_state

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        if args.resume and resume_checkpoint_path is not None:
            resume_path = resume_checkpoint_path
        else:
            resume_path = ''
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            resume=resume_path)

    if args.distributed:
        if args.local_rank == 0:
            logging.info("Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
        model = DDP(model, device_ids=[args.device])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
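    # Fast-forward the scheduler below so the learning rate matches the epoch we
    # actually start from when resuming mid-training.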
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        dllogger.log(step="PARAMETER", data={'Scheduled_epochs': num_epochs}, verbosity=0)

    # Benchmark will always override every other setting.
    if args.benchmark:
        start_epoch = 0
        num_epochs = args.epochs

    if args.waymo:
        train_annotation_path = args.waymo_train_annotation
        train_image_dir = args.waymo_train
    else:
        train_anno_set = 'train2017'
        train_annotation_path = os.path.join(args.data, 'annotations', f'instances_{train_anno_set}.json')
        train_image_dir = train_anno_set
    dataset_train = CocoDetection(os.path.join(args.data, train_image_dir), train_annotation_path, data_config)

    loader_train = create_loader(
        dataset_train,
        input_size=input_size,
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        interpolation=args.train_interpolation,
        num_workers=args.workers,
        distributed=args.distributed,
        pin_mem=args.pin_mem,
        memory_format=memory_format
    )

    loader_train_iter = iter(loader_train)
    steps_per_epoch = int(np.ceil(len(dataset_train) / (args.world_size * args.batch_size)))
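    # Each training step consumes world_size * batch_size samples globally, hence
    # the ceil division to cover the whole dataset once per epoch.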

    if args.waymo:
        val_annotation_path = args.waymo_val_annotation
        val_image_dir = args.waymo_val
    else:
        val_anno_set = 'val2017'
        val_annotation_path = os.path.join(args.data, 'annotations', f'instances_{val_anno_set}.json')
        val_image_dir = val_anno_set
    dataset_eval = CocoDetection(os.path.join(args.data, val_image_dir), val_annotation_path, data_config)

    loader_eval = create_loader(
        dataset_eval,
        input_size=input_size,
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=args.interpolation,
        num_workers=args.workers,
        distributed=args.distributed,
        pin_mem=args.pin_mem,
        memory_format=memory_format
    )

    evaluator = COCOEvaluator(dataset_eval.coco, distributed=args.distributed, waymo=args.waymo)

    eval_metric = args.eval_metric
    eval_metrics = None
    train_metrics = {}
    best_metric = -1
    is_best = False
    best_epoch = None
    saver = None
    output_dir = ''
    if args.rank == 0:
        output_base = args.output if args.output else './output'
        output_dir = get_outdirectory(output_base, 'train')
        decreasing = eval_metric == 'loss'
        saver = CheckpointSaver(checkpoint_dir=output_dir)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(
                epoch, steps_per_epoch, model, loader_train_iter, optimizer, args,
                lr_scheduler=lr_scheduler, output_dir=output_dir, use_amp=args.amp, scaler=scaler, model_ema=model_ema)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    logging.info("Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            # the overhead of evaluating with coco style datasets is fairly high, so just ema or non, not both
            if model_ema is not None:
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                
                if epoch >= args.eval_after:
                    eval_metrics = validate(model_ema.ema, loader_eval, args, evaluator, epoch, log_suffix=' (EMA)')
            else:
                eval_metrics = validate(model, loader_eval, args, evaluator, epoch)

            if lr_scheduler is not None:
                lr_scheduler.step(epoch + 1)

            if saver is not None and args.rank == 0 and epoch % args.save_checkpoint_interval == 0:
                if eval_metrics is not None:
                    # save proper checkpoint with eval metric
                    is_best = eval_metrics[eval_metric] > best_metric
                    if is_best:
                        best_epoch = epoch
                    best_metric = max(
                        eval_metrics[eval_metric],
                        best_metric
                    )
                else:
                    is_best = False
                    best_metric = 0
                saver.save_checkpoint(model, optimizer, epoch, model_ema=model_ema, metric=best_metric, is_best=is_best)


    except KeyboardInterrupt:
        dllogger.flush()
        torch.cuda.empty_cache()
    if best_metric > 0:
        train_metrics.update({'best_map': best_metric, 'best_epoch': best_epoch})
    if eval_metrics is not None:
        train_metrics.update(eval_metrics)
    dllogger.log(step=(), data=train_metrics, verbosity=0)