Example #1
def main():
    """
    Launches text to speech (inference).
    Inference is executed on a single GPU.
    """
    parser = argparse.ArgumentParser(description='PyTorch Tacotron 2 Inference')
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    LOGGER.set_model_name("Tacotron2_PyT")
    LOGGER.set_backends([
        dllg.StdOutBackend(log_file=None, logging_scope=dllg.TRAIN_ITER_SCOPE, iteration_interval=1),
        dllg.JsonBackend(log_file=args.log_file, logging_scope=dllg.TRAIN_ITER_SCOPE, iteration_interval=1)
    ])
    LOGGER.register_metric("tacotron2_frames_per_sec", metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("tacotron2_latency", metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("latency", metric_scope=dllg.TRAIN_ITER_SCOPE)

    model = load_and_setup_model(parser, args)

    log_hardware()
    log_args(args)

    if args.include_warmup:
        sequences = torch.randint(low=0, high=148, size=(1,50), dtype=torch.long).cuda()
        text_lengths = torch.IntTensor([sequences.size(1)]).cuda().long()
        for i in range(3):
            with torch.no_grad():
                _, mels, _, _, mel_lengths = model.infer(sequences, text_lengths)

    os.makedirs(args.output, exist_ok=True)

    LOGGER.iteration_start()

    measurements = {}

    anchor_dirs = [os.path.join(args.dataset_path, anchor) for anchor in args.anchor_dirs]
    metadatas = [load_metadata(anchor) for anchor in anchor_dirs]
    with torch.no_grad(), MeasureTime(measurements, "tacotron2_time"):
        for speaker_id in range(len(anchor_dirs)):
            metadata = metadatas[speaker_id]
            for mel_path, text in tqdm(metadata):
                seq = text_to_sequence(text, speaker_id, ['basic_cleaners'])
                seqs = torch.from_numpy(np.stack(seq)).unsqueeze(0)
                seq_lens = torch.IntTensor([len(text)])
                melspec = torch.from_numpy(np.load(mel_path))
                target = melspec[:, ::args.reduction_factor]
                targets = torch.from_numpy(np.stack(target)).unsqueeze(0)
                target_lengths = torch.IntTensor([target.shape[1]])
                inputs = (to_gpu(seqs).long(), to_gpu(seq_lens).int(), to_gpu(targets).float(), to_gpu(target_lengths).int())
                _, mel_outs, _, _ = model(inputs)
                fname = os.path.basename(mel_path)
                np.save(os.path.join(args.output, fname), mel_outs[0, :, :melspec.shape[1]].cpu().numpy(), allow_pickle=False)

    LOGGER.log(key="tacotron2_latency", value=measurements['tacotron2_time'])
    LOGGER.log(key="latency", value=(measurements['tacotron2_time']))
    LOGGER.iteration_stop()
    LOGGER.finish()
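
Several of these examples time inference with MeasureTime(measurements, key) and later read measurements[key] as a latency in seconds, but the helper itself is not reproduced on this page. Below is a minimal sketch consistent with that usage; the CUDA synchronization calls are an assumption, added so that asynchronously queued GPU work is actually included in the measured interval.

import time
import torch

class MeasureTime:
    """Context manager that stores elapsed wall-clock seconds in a dict.

    Behavior inferred from usage such as
    with MeasureTime(measurements, "tacotron2_time"): ...
    followed by reads of measurements["tacotron2_time"].
    """
    def __init__(self, measurements, key):
        self.measurements = measurements
        self.key = key

    def __enter__(self):
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # do not bill previously queued GPU work
        self.t0 = time.perf_counter()

    def __exit__(self, exc_type, exc_value, exc_traceback):
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # wait for the measured GPU work to finish
        self.measurements[self.key] = time.perf_counter() - self.t0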
Example #2
def main():
    """
    Launches text to speech (inference).
    Inference is executed on a single GPU.
    """
    parser = argparse.ArgumentParser(
        description='PyTorch Tacotron 2 Inference')
    parser = parse_training_args(parser)
    args, _ = parser.parse_known_args()

    LOGGER.set_model_name("Tacotron2_PyT")
    LOGGER.set_backends([
        dllg.StdOutBackend(log_file=None,
                           logging_scope=dllg.TRAIN_ITER_SCOPE,
                           iteration_interval=1),
        dllg.JsonBackend(log_file=args.log_file,
                         logging_scope=dllg.TRAIN_ITER_SCOPE,
                         iteration_interval=1)
    ])
    LOGGER.register_metric("tacotron2_frames_per_sec",
                           metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("tacotron2_latency",
                           metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("latency", metric_scope=dllg.TRAIN_ITER_SCOPE)

    model, args = load_and_setup_model(parser, args)

    log_hardware()
    log_args(args)

    os.makedirs(args.output_dir, exist_ok=True)

    LOGGER.iteration_start()

    measurements = {}

    anchor_dirs = [
        os.path.join(args.dataset_path, anchor)
        for anchor in args.training_anchor_dirs
    ]
    metadatas = [load_metadata(anchor) for anchor in anchor_dirs]
    stft = TacotronSTFT(args.filter_length, args.hop_length, args.win_length,
                        args.n_mel_channels, args.sampling_rate, args.mel_fmin,
                        args.mel_fmax)
    with torch.no_grad(), MeasureTime(measurements, "tacotron2_time"):
        for speaker_id in range(len(anchor_dirs)):
            metadata = metadatas[speaker_id]
            for npy_path, text in tqdm(metadata):
                seq = text_to_sequence(text, speaker_id, ['basic_cleaners'])
                seqs = torch.from_numpy(np.stack(seq)).unsqueeze(0)
                seq_lens = torch.IntTensor([len(text)])
                wav = load_wav_to_torch(npy_path)
                mel = stft.mel_spectrogram(wav.unsqueeze(0))
                mel = mel.squeeze()
                max_target_len = mel.size(1) - 1
                max_target_len += args.n_frames_per_step - max_target_len % args.n_frames_per_step
                padded_mel = np.pad(mel, [(0, 0),
                                          (0, max_target_len - mel.size(1))],
                                    mode='constant',
                                    constant_values=args.mel_pad_val)
                target = padded_mel[:, ::args.n_frames_per_step]
                targets = torch.from_numpy(np.stack(target)).unsqueeze(0)
                target_lengths = torch.IntTensor([target.shape[1]])
                outputs = model.infer(
                    to_gpu(seqs).long(),
                    to_gpu(seq_lens).int(),
                    to_gpu(targets).half(),
                    to_gpu(target_lengths).int())
                _, mel_out, _, _ = [
                    output.cpu() for output in outputs if output is not None
                ]
                mel_out = mel_out.squeeze()[:, :mel.size(-1) - 1]
                assert (mel_out.shape[-1] == wav.shape[-1] // args.hop_length)
                fname = os.path.basename(npy_path)
                np.save(os.path.join(args.output_dir, fname),
                        mel_out,
                        allow_pickle=False)
                # GTA synthesis
                # magnitudes = stft.inv_mel_spectrogram(mel_out.squeeze())
                # wav = griffin_lim(magnitudes, stft.stft_fn, 60)
                # save_wav(wav, os.path.join(args.output_dir, 'eval.wav'))

    LOGGER.log(key="tacotron2_latency", value=measurements['tacotron2_time'])
    LOGGER.log(key="latency", value=(measurements['tacotron2_time']))
    LOGGER.iteration_stop()
    LOGGER.finish()
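
Examples #1 and #2 move tensors with a to_gpu helper before casting (to_gpu(seqs).long(), to_gpu(targets).float(), and so on). The helper is not shown on this page; a minimal sketch under the assumption that it simply transfers the tensor to the current CUDA device when one is available:

import torch

def to_gpu(x):
    # Assumed behavior: contiguous copy moved to the GPU when CUDA is
    # available, otherwise the tensor is returned unchanged.
    x = x.contiguous()
    if torch.cuda.is_available():
        x = x.cuda(non_blocking=True)
    return x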
Example #3
def main():
    """
    Launches text to speech (inference).
    Inference is executed on a single GPU.
    """
    parser = argparse.ArgumentParser(
        description='PyTorch Tacotron 2 Inference')
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    LOGGER.set_model_name("Tacotron2_PyT")
    LOGGER.set_backends([
        dllg.StdOutBackend(log_file=None,
                           logging_scope=dllg.TRAIN_ITER_SCOPE, iteration_interval=1),
        dllg.JsonBackend(log_file=args.log_file,
                         logging_scope=dllg.TRAIN_ITER_SCOPE, iteration_interval=1)
    ])
    LOGGER.register_metric("tacotron2_items_per_sec", metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("waveglow_items_per_sec", metric_scope=dllg.TRAIN_ITER_SCOPE)

    log_hardware()
    log_args(args)

    # tacotron2 model filepath was specified
    if args.tacotron2:
        # Setup Tacotron2
        tacotron2 = load_and_setup_model('Tacotron2', parser, args.tacotron2, args.fp16_run)
    # file with mel spectrogram was specified
    elif args.mel_file:
        mel = torch.load(args.mel_file)
        mel = torch.autograd.Variable(mel.cuda())
        mel = torch.unsqueeze(mel, 0)

    # Setup WaveGlow
    if args.old_waveglow:
        waveglow = torch.load(args.waveglow)['model']
        waveglow = waveglow.remove_weightnorm(waveglow)
        waveglow = waveglow.cuda()
        waveglow.eval()
    else:
        waveglow = load_and_setup_model('WaveGlow', parser, args.waveglow, args.fp16_run)

    texts = []
    try:
        with open(args.input, 'r') as f:
            texts = f.readlines()
    except Exception:
        print("Could not read file. Using default text.")
        texts = ["The forms of printed letters should be beautiful, and "
                 "that their arrangement on the page should be reasonable and "
                 "a help to the shapeliness of the letters themselves."]

    for i, text in enumerate(texts):

        LOGGER.iteration_start()

        sequence = np.array(text_to_sequence(text, ['english_cleaners']))[None, :]
        sequence = torch.autograd.Variable(
            torch.from_numpy(sequence)).cuda().long()

        if args.tacotron2:
            tacotron2_t0 = time.time()
            with torch.no_grad():
                _, mel, _, _ = tacotron2.inference(sequence)
            tacotron2_t1 = time.time()
            tacotron2_infer_perf = sequence.size(1)/(tacotron2_t1-tacotron2_t0)
            LOGGER.log(key="tacotron2_items_per_sec", value=tacotron2_infer_perf)

        waveglow_t0 = time.time()
        with torch.no_grad():
            audio = waveglow.infer(mel, sigma=args.sigma_infer)
            audio = audio.float()
        waveglow_t1 = time.time()
        waveglow_infer_perf = audio[0].size(0)/(waveglow_t1-waveglow_t0)

        audio_path = os.path.join(args.output, "audio_" + str(i) + ".wav")
        write(audio_path, args.sampling_rate, audio[0].data.cpu().numpy())

        LOGGER.log(key="waveglow_items_per_sec", value=waveglow_infer_perf)
        LOGGER.iteration_stop()

    LOGGER.finish()
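
Example #3 loads WaveGlow either from an "old" checkpoint (the branch shown above) or via load_and_setup_model('WaveGlow', parser, args.waveglow, args.fp16_run), whose internals are not shown here. A rough sketch of what such a loader plausibly does, mirroring the old_waveglow branch; the checkpoint layout and the half() cast for FP16 are assumptions:

import torch

def load_waveglow_checkpoint(checkpoint_path, fp16_run=False):
    # Sketch only: follows the same steps as the old_waveglow branch above.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    waveglow = checkpoint['model']                   # assumed checkpoint layout
    waveglow = waveglow.remove_weightnorm(waveglow)  # same call as the old branch
    if fp16_run:
        waveglow = waveglow.half()                   # assumption for FP16 runs
    waveglow = waveglow.cuda()
    waveglow.eval()
    return waveglow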
Example #4
def main():
    """
    Launches text to speech (inference).
    Inference is executed on a single GPU.
    """
    parser = argparse.ArgumentParser(
        description='PyTorch Tacotron 2 Inference')
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    LOGGER.set_model_name("Tacotron2_PyT")
    LOGGER.set_backends([
        dllg.StdOutBackend(log_file=None,
                           logging_scope=dllg.TRAIN_ITER_SCOPE,
                           iteration_interval=1),
        dllg.JsonBackend(log_file=args.log_file,
                         logging_scope=dllg.TRAIN_ITER_SCOPE,
                         iteration_interval=1)
    ])
    LOGGER.register_metric("tacotron2_items_per_sec",
                           metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("tacotron2_latency",
                           metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("waveglow_items_per_sec",
                           metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("waveglow_latency",
                           metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("latency", metric_scope=dllg.TRAIN_ITER_SCOPE)

    log_hardware()
    log_args(args)

    tacotron2 = load_and_setup_model('Tacotron2', parser, args.tacotron2,
                                     args.amp_run)
    waveglow = load_and_setup_model('WaveGlow', parser, args.waveglow,
                                    args.amp_run)
    denoiser = Denoiser(waveglow).cuda()

    tacotron2.forward = tacotron2.infer
    type(tacotron2).forward = type(tacotron2).infer
    jitted_tacotron2 = torch.jit.script(tacotron2)

    texts = []
    try:
        with open(args.input, 'r') as f:
            texts = f.readlines()
    except Exception:
        print("Could not read file")
        sys.exit(1)

    if args.include_warmup:
        sequence = torch.randint(low=0,
                                 high=148,
                                 size=(1, 50),
                                 dtype=torch.long).cuda()
        input_lengths = torch.IntTensor([sequence.size(1)]).cuda().long()
        for i in range(3):
            with torch.no_grad():
                mel, mel_lengths = jitted_tacotron2(sequence, input_lengths)
                _ = waveglow.infer(mel)

    LOGGER.iteration_start()

    measurements = {}

    sequences_padded, input_lengths = prepare_input_sequence(texts)

    with torch.no_grad(), MeasureTime(measurements, "tacotron2_time"):
        mel, mel_lengths = jitted_tacotron2(sequences_padded, input_lengths)

    with torch.no_grad(), MeasureTime(measurements, "waveglow_time"):
        audios = waveglow.infer(mel, sigma=args.sigma_infer)
        audios = audios.float()
        audios = denoiser(audios, strength=args.denoising_strength).squeeze(1)

    tacotron2_infer_perf = mel.size(0) * mel.size(
        2) / measurements['tacotron2_time']
    waveglow_infer_perf = audios.size(0) * audios.size(
        1) / measurements['waveglow_time']

    LOGGER.log(key="tacotron2_items_per_sec", value=tacotron2_infer_perf)
    LOGGER.log(key="tacotron2_latency", value=measurements['tacotron2_time'])
    LOGGER.log(key="waveglow_items_per_sec", value=waveglow_infer_perf)
    LOGGER.log(key="waveglow_latency", value=measurements['waveglow_time'])
    LOGGER.log(key="latency",
               value=(measurements['tacotron2_time'] +
                      measurements['waveglow_time']))

    for i, audio in enumerate(audios):
        audio = audio[:mel_lengths[i] * args.stft_hop_length]
        audio = audio / torch.max(torch.abs(audio))
        audio_path = os.path.join(args.output, "audio_" + str(i) + ".wav")
        write(audio_path, args.sampling_rate, audio.cpu().numpy())

    LOGGER.iteration_stop()
    LOGGER.finish()
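
Example #4 batches raw strings with prepare_input_sequence(texts), which returns a padded LongTensor of symbol ids plus per-item lengths for the scripted Tacotron 2. The helper is not included on this page; a simplified sketch of one way it could work, reusing the text_to_sequence call seen in Example #3 (the import path and the longest-first sort are assumptions):

import torch
from tacotron2.text import text_to_sequence  # assumed import path

def prepare_input_sequence(texts):
    # Encode each sentence into a list of symbol ids.
    seqs = [text_to_sequence(t, ['english_cleaners']) for t in texts]
    # Sort longest-first so padding is sized by the longest item.
    order = sorted(range(len(seqs)), key=lambda i: len(seqs[i]), reverse=True)
    seqs = [seqs[i] for i in order]
    lengths = torch.IntTensor([len(s) for s in seqs])
    padded = torch.zeros(len(seqs), int(lengths.max()), dtype=torch.long)
    for i, s in enumerate(seqs):
        padded[i, :len(s)] = torch.tensor(s, dtype=torch.long)
    if torch.cuda.is_available():
        padded, lengths = padded.cuda(), lengths.cuda()
    return padded, lengths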
Example #5
def begin(self):
    log_hardware(LOGGER)
    LOGGER.log(tags.RUN_INIT)
Example #6
def main():

    parser = argparse.ArgumentParser(description='PyTorch Tacotron 2 Training')
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    LOGGER.set_model_name("Tacotron2_PyT")
    LOGGER.set_backends([
        dllg.StdOutBackend(log_file=None,
                           logging_scope=dllg.TRAIN_ITER_SCOPE,
                           iteration_interval=1),
        dllg.JsonBackend(log_file=args.log_file if args.rank == 0 else None,
                         logging_scope=dllg.TRAIN_ITER_SCOPE,
                         iteration_interval=1)
    ])

    LOGGER.timed_block_start("run")
    LOGGER.register_metric(tags.TRAIN_ITERATION_LOSS,
                           metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("iter_time", metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("epoch_time", metric_scope=dllg.EPOCH_SCOPE)
    LOGGER.register_metric("run_time", metric_scope=dllg.RUN_SCOPE)
    LOGGER.register_metric("val_iter_loss", metric_scope=dllg.EPOCH_SCOPE)
    LOGGER.register_metric("train_epoch_items/sec",
                           metric_scope=dllg.EPOCH_SCOPE)
    LOGGER.register_metric("train_epoch_avg_loss",
                           metric_scope=dllg.EPOCH_SCOPE)

    log_hardware()

    model_name = args.model_name
    parser = models.parse_model_args(model_name, parser)

    args = parser.parse_args()

    log_args(args)

    torch.backends.cudnn.enabled = args.cudnn_enabled
    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    distributed_run = args.world_size > 1
    if distributed_run:
        init_distributed(args, args.world_size, args.rank, args.group_name)

    LOGGER.log(key=tags.RUN_START)
    run_start_time = time.time()

    model_config = models.get_model_config(model_name, args)
    model = models.get_model(model_name,
                             model_config,
                             to_fp16=args.fp16_run,
                             to_cuda=True)

    epoch_start = 0
    if args.resume:
        resume_model_path = args.resume_tacotron2_path if args.model_name == "Tacotron2" else args.resume_waveglow_path
        checkpoint = torch.load(resume_model_path, map_location='cpu')
        epoch_start = checkpoint["epoch"]
        state_dict = checkpoint['state_dict']
        if checkpoint_from_distributed(state_dict):
            state_dict = unwrap_distributed(state_dict)

        model.load_state_dict(state_dict)
        print("restore model %s" % resume_model_path)

    if distributed_run:
        model = DDP(model)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.learning_rate,
                                 weight_decay=args.weight_decay)

    if args.fp16_run:
        optimizer = FP16_Optimizer(
            optimizer, dynamic_loss_scale=args.dynamic_loss_scaling)

    try:
        sigma = args.sigma
    except AttributeError:
        sigma = None

    criterion = loss_functions.get_loss_function(model_name, sigma)

    try:
        n_frames_per_step = args.n_frames_per_step
    except AttributeError:
        n_frames_per_step = None

    collate_fn = data_functions.get_collate_function(model_name,
                                                     n_frames_per_step)
    trainset = data_functions.get_data_loader(model_name, args.dataset_path,
                                              args.training_files, args)
    train_sampler = DistributedSampler(trainset) if distributed_run else None
    train_loader = DataLoader(trainset,
                              num_workers=1,
                              shuffle=False,
                              sampler=train_sampler,
                              batch_size=args.batch_size,
                              pin_memory=False,
                              drop_last=True,
                              collate_fn=collate_fn)

    valset = data_functions.get_data_loader(model_name, args.dataset_path,
                                            args.validation_files, args)

    batch_to_gpu = data_functions.get_batch_to_gpu(model_name)

    iteration = 0
    model.train()

    LOGGER.log(key=tags.TRAIN_LOOP)

    for epoch in range(epoch_start, args.epochs):
        LOGGER.epoch_start()
        epoch_start_time = time.time()
        LOGGER.log(key=tags.TRAIN_EPOCH_START, value=epoch)

        # used to calculate avg items/sec over epoch
        reduced_num_items_epoch = 0

        # used to calculate avg loss over epoch
        train_epoch_avg_loss = 0.0
        num_iters = 0

        # if overflow at the last iteration then do not save checkpoint
        overflow = False

        for i, batch in enumerate(train_loader):
            LOGGER.iteration_start()
            iter_start_time = time.time()
            LOGGER.log(key=tags.TRAIN_ITER_START, value=i)
            print("Batch: {}/{} epoch {}".format(i, len(train_loader), epoch))

            start = time.perf_counter()
            adjust_learning_rate(epoch, optimizer, args.learning_rate,
                                 args.anneal_steps, args.anneal_factor)

            model.zero_grad()
            x, y, num_items = batch_to_gpu(batch)

            if args.fp16_run:
                y_pred = model(fp32_to_fp16(x))
                loss = criterion(fp16_to_fp32(y_pred), y)
            else:
                y_pred = model(x)
                loss = criterion(y_pred, y)

            if distributed_run:
                reduced_loss = reduce_tensor(loss.data, args.world_size).item()
                reduced_num_items = reduce_tensor(num_items.data, 1).item()
            else:
                reduced_loss = loss.item()
                reduced_num_items = num_items.item()
            if np.isnan(reduced_loss):
                raise Exception("loss is NaN")

            LOGGER.log(key=tags.TRAIN_ITERATION_LOSS, value=reduced_loss)

            train_epoch_avg_loss += reduced_loss
            num_iters += 1

            # accumulate number of items processed in this epoch
            reduced_num_items_epoch += reduced_num_items

            if args.fp16_run:
                optimizer.backward(loss)
                grad_norm = optimizer.clip_master_grads(args.grad_clip_thresh)
            else:
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.grad_clip_thresh)

            optimizer.step()

            overflow = optimizer.overflow if args.fp16_run else False
            iteration += 1

            LOGGER.log(key=tags.TRAIN_ITER_STOP, value=i)

            iter_stop_time = time.time()
            iter_time = iter_stop_time - iter_start_time
            LOGGER.log(key="train_iter_items/sec",
                       value=(reduced_num_items / iter_time))
            LOGGER.log(key="iter_time", value=iter_time)
            LOGGER.iteration_stop()

        LOGGER.log(key=tags.TRAIN_EPOCH_STOP, value=epoch)
        epoch_stop_time = time.time()
        epoch_time = epoch_stop_time - epoch_start_time

        LOGGER.log(key="train_epoch_items/sec",
                   value=(reduced_num_items_epoch / epoch_time))
        LOGGER.log(key="train_epoch_avg_loss",
                   value=(train_epoch_avg_loss /
                          num_iters if num_iters > 0 else 0.0))
        LOGGER.log(key="epoch_time", value=epoch_time)

        LOGGER.log(key=tags.EVAL_START, value=epoch)

        validate(model, criterion, valset, iteration, args.batch_size,
                 args.world_size, collate_fn, distributed_run, args.rank,
                 batch_to_gpu, args.fp16_run)

        LOGGER.log(key=tags.EVAL_STOP, value=epoch)

        if not overflow and (epoch % args.epochs_per_checkpoint
                             == 0) and args.rank == 0:
            checkpoint_path = os.path.join(
                args.output_directory,
                "checkpoint_{}_{}".format(model_name, epoch))
            save_checkpoint(model, epoch, model_config, checkpoint_path)
            save_sample(
                model_name, model, args.waveglow_checkpoint,
                args.tacotron2_checkpoint, args.phrase_path,
                os.path.join(args.output_directory,
                             "sample_{}_{}.wav".format(model_name, iteration)),
                args.sampling_rate, args.fp16_run)

        LOGGER.epoch_stop()

    run_stop_time = time.time()
    run_time = run_stop_time - run_start_time
    LOGGER.log(key="run_time", value=run_time)
    LOGGER.log(key=tags.RUN_FINAL)

    print("training time", run_stop_time - run_start_time)

    LOGGER.timed_block_stop("run")

    if args.rank == 0:
        LOGGER.finish()
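
The distributed branches above average the loss and accumulate item counts across workers with reduce_tensor(tensor, n), where n is args.world_size for the loss and 1 for the item count. The helper is not shown; a common implementation of that pattern, given here only as a sketch:

import torch.distributed as dist

def reduce_tensor(tensor, n):
    # Sum the value from every rank, then divide by n:
    # n == world_size averages (used for the loss),
    # n == 1 keeps the global sum (used for the number of items).
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    return rt / n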
Example #7
def main():
    """
    Launches text to speech (inference).
    Inference is executed on a single GPU.
    """
    parser = argparse.ArgumentParser(
        description='PyTorch Tacotron 2 Inference')
    parser = parse_training_args(parser)
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    LOGGER.set_model_name("Tacotron2_PyT")
    LOGGER.set_backends([
        dllg.StdOutBackend(log_file=None,
                           logging_scope=dllg.TRAIN_ITER_SCOPE,
                           iteration_interval=1),
        dllg.JsonBackend(log_file=args.log_file,
                         logging_scope=dllg.TRAIN_ITER_SCOPE,
                         iteration_interval=1)
    ])
    LOGGER.register_metric("tacotron2_frames_per_sec",
                           metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("tacotron2_latency",
                           metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("latency", metric_scope=dllg.TRAIN_ITER_SCOPE)

    model, args = load_and_setup_model(parser, args)

    log_hardware()
    log_args(args)

    try:
        with open(args.input_file) as f:
            sentences = [line.strip() for line in f.readlines()]
    except UnicodeDecodeError:
        with open(args.input_file, encoding='gbk') as f:
            sentences = [line.strip() for line in f.readlines()]

    os.makedirs(args.output_dir, exist_ok=True)

    LOGGER.iteration_start()

    measurements = {}

    sequences, text_lengths, ids_sorted_decreasing = prepare_input_sequence(
        sentences, args.speaker_id)

    with torch.no_grad(), MeasureTime(measurements, "tacotron2_time"):
        outputs = model.infer(sequences, text_lengths)
        _, mels, _, _, mel_lengths = [output.cpu() for output in outputs]

    tacotron2_infer_perf = mels.size(0) * mels.size(
        2) / measurements['tacotron2_time']

    LOGGER.log(key="tacotron2_frames_per_sec", value=tacotron2_infer_perf)
    LOGGER.log(key="tacotron2_latency", value=measurements['tacotron2_time'])
    LOGGER.log(key="latency", value=(measurements['tacotron2_time']))
    LOGGER.iteration_stop()
    LOGGER.finish()

    # recover to the original order and concatenate
    stft = TacotronSTFT(args.filter_length, args.hop_length, args.win_length,
                        args.n_mel_channels, args.sampling_rate, args.mel_fmin,
                        args.mel_fmax)
    ids_sorted_decreasing = ids_sorted_decreasing.numpy().tolist()
    mels = [mel[:, :length] for mel, length in zip(mels, mel_lengths)]
    mels = [
        mels[ids_sorted_decreasing.index(i)]
        for i in range(len(ids_sorted_decreasing))
    ]
    magnitudes = stft.inv_mel_spectrogram(torch.cat(mels, axis=-1))
    wav = griffin_lim(magnitudes, stft.stft_fn)
    save_wav(wav, os.path.join(args.output_dir, 'eval.wav'))
    np.save(os.path.join(args.output_dir, 'eval.npy'),
            np.concatenate(mels, axis=-1),
            allow_pickle=False)
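
Example #7 reconstructs a waveform with griffin_lim(magnitudes, stft_fn) after inverting the mel filterbank. The function body is not on this page; a sketch of the classic Griffin-Lim iteration, under the assumption that stft_fn exposes transform(signal) -> (magnitude, phase) and inverse(magnitude, phase) -> signal:

import numpy as np
import torch

def griffin_lim(magnitudes, stft_fn, n_iters=30):
    # Start from random phase and alternate between the given magnitudes
    # and the phase re-estimated from the reconstructed signal.
    angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
    angles = torch.from_numpy(angles.astype(np.float32))
    signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
    for _ in range(n_iters):
        _, angles = stft_fn.transform(signal)
        signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
    return signal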
Example #8
File: train.py  Project: herenje/tacotron2
def main():

    parser = argparse.ArgumentParser(description='PyTorch Tacotron 2 Training')
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    LOGGER.set_model_name("Tacotron2_PyT")
    LOGGER.set_backends([
        dllg.StdOutBackend(log_file=None,
                           logging_scope=dllg.TRAIN_ITER_SCOPE,
                           iteration_interval=1),
        dllg.JsonBackend(log_file=os.path.join(
            args.output_directory, args.log_file) if args.rank == 0 else None,
                         logging_scope=dllg.TRAIN_ITER_SCOPE,
                         iteration_interval=1)
    ])

    LOGGER.timed_block_start("run")
    LOGGER.register_metric(tags.TRAIN_ITERATION_LOSS,
                           metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("iter_time", metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("epoch_time", metric_scope=dllg.EPOCH_SCOPE)
    LOGGER.register_metric("run_time", metric_scope=dllg.RUN_SCOPE)
    LOGGER.register_metric("val_iter_loss", metric_scope=dllg.EPOCH_SCOPE)
    LOGGER.register_metric("train_epoch_frames/sec",
                           metric_scope=dllg.EPOCH_SCOPE)
    LOGGER.register_metric("train_epoch_avg_frames/sec",
                           metric_scope=dllg.EPOCH_SCOPE)
    LOGGER.register_metric("train_epoch_avg_loss",
                           metric_scope=dllg.EPOCH_SCOPE)

    log_hardware()

    parser = parse_tacotron2_args(parser)
    args = parser.parse_args()

    log_args(args)

    torch.backends.cudnn.enabled = args.cudnn_enabled
    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    distributed_run = args.world_size > 1
    if distributed_run:
        init_distributed(args, args.world_size, args.rank, args.group_name)

    os.makedirs(args.output_directory, exist_ok=True)

    LOGGER.log(key=tags.RUN_START)
    run_start_time = time.time()

    model = get_tacotron2_model(args,
                                len(args.training_anchor_dirs),
                                is_training=True)

    if not args.amp_run and distributed_run:
        model = DDP(model)

    model.restore_checkpoint(
        os.path.join(args.output_directory, args.latest_checkpoint_file))

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.init_lr,
                                 weight_decay=args.weight_decay)

    if args.amp_run:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
        if distributed_run:
            model = DDP(model)

    criterion = Tacotron2Loss()

    collate_fn = TextMelCollate(args)
    train_dataset = TextMelDataset(args, args.training_anchor_dirs)
    train_loader = DataLoader(train_dataset,
                              num_workers=2,
                              shuffle=False,
                              batch_size=args.batch_size //
                              len(args.training_anchor_dirs),
                              pin_memory=False,
                              drop_last=True,
                              collate_fn=collate_fn)
    # valate_dataset = TextMelDataset(args, args.validation_anchor_dirs)

    model.train()

    elapsed_epochs = model.get_elapsed_epochs()
    epochs = args.epochs - elapsed_epochs
    iteration = elapsed_epochs * len(train_loader)

    LOGGER.log(key=tags.TRAIN_LOOP)

    for epoch in range(1, epochs + 1):
        LOGGER.epoch_start()
        epoch_start_time = time.time()
        epoch += elapsed_epochs
        LOGGER.log(key=tags.TRAIN_EPOCH_START, value=epoch)

        # used to calculate avg frames/sec over epoch
        reduced_num_frames_epoch = 0

        # used to calculate avg loss over epoch
        train_epoch_avg_loss = 0.0
        train_epoch_avg_frames_per_sec = 0.0
        num_iters = 0

        adjust_learning_rate(optimizer, epoch, args)

        for i, batch in enumerate(train_loader):
            print(f"Batch: {i}/{len(train_loader)} epoch {epoch}")
            LOGGER.iteration_start()
            iter_start_time = time.time()
            LOGGER.log(key=tags.TRAIN_ITER_START, value=i)

            # start = time.perf_counter()

            optimizer.zero_grad()
            x, y, num_frames = batch_to_gpu(batch)

            y_pred = model(x)

            loss = criterion(y_pred, y)

            if distributed_run:
                reduced_loss = reduce_tensor(loss.data, args.world_size).item()
                reduced_num_frames = reduce_tensor(num_frames.data, 1).item()
            else:
                reduced_loss = loss.item()
                reduced_num_frames = num_frames.item()
            if np.isnan(reduced_loss):
                raise Exception("loss is NaN")

            LOGGER.log(key=tags.TRAIN_ITERATION_LOSS, value=reduced_loss)

            train_epoch_avg_loss += reduced_loss
            num_iters += 1

            # accumulate number of frames processed in this epoch
            reduced_num_frames_epoch += reduced_num_frames

            if args.amp_run:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    amp.master_params(optimizer), args.grad_clip_thresh)
            else:
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.grad_clip_thresh)

            optimizer.step()

            iteration += 1

            LOGGER.log(key=tags.TRAIN_ITER_STOP, value=i)

            iter_stop_time = time.time()
            iter_time = iter_stop_time - iter_start_time
            frames_per_sec = reduced_num_frames / iter_time
            train_epoch_avg_frames_per_sec += frames_per_sec

            LOGGER.log(key="train_iter_frames/sec", value=frames_per_sec)
            LOGGER.log(key="iter_time", value=iter_time)
            LOGGER.iteration_stop()

        LOGGER.log(key=tags.TRAIN_EPOCH_STOP, value=epoch)
        epoch_stop_time = time.time()
        epoch_time = epoch_stop_time - epoch_start_time

        LOGGER.log(key="train_epoch_frames/sec",
                   value=(reduced_num_frames_epoch / epoch_time))
        LOGGER.log(key="train_epoch_avg_frames/sec",
                   value=(train_epoch_avg_frames_per_sec /
                          num_iters if num_iters > 0 else 0.0))
        LOGGER.log(key="train_epoch_avg_loss",
                   value=(train_epoch_avg_loss /
                          num_iters if num_iters > 0 else 0.0))
        LOGGER.log(key="epoch_time", value=epoch_time)

        LOGGER.log(key=tags.EVAL_START, value=epoch)

        # validate(model, criterion, valate_dataset, iteration, collate_fn, distributed_run, args)

        LOGGER.log(key=tags.EVAL_STOP, value=epoch)

        # Store latest checkpoint in each epoch
        model.elapse_epoch()
        checkpoint_path = os.path.join(args.output_directory,
                                       args.latest_checkpoint_file)
        torch.save(model.state_dict(), checkpoint_path)

        # Plot alignment
        if epoch % args.epochs_per_alignment == 0 and args.rank == 0:
            alignments = y_pred[3].data.cpu().numpy()
            index = np.random.randint(len(alignments))
            plot_alignment(
                alignments[index].transpose(0, 1),  # [enc_step, dec_step]
                os.path.join(args.output_directory,
                             f"align_{epoch:04d}_{iteration}.png"),
                info=
                f"{datetime.now().strftime('%Y-%m-%d %H:%M')} Epoch={epoch:04d} Iteration={iteration} Average loss={train_epoch_avg_loss/num_iters:.5f}"
            )

        # Save checkpoint
        if epoch % args.epochs_per_checkpoint == 0 and args.rank == 0:
            checkpoint_path = os.path.join(args.output_directory,
                                           f"checkpoint_{epoch:04d}.pt")
            print(
                f"Saving model and optimizer state at epoch {epoch:04d} to {checkpoint_path}"
            )
            torch.save(model.state_dict(), checkpoint_path)

            # Save evaluation
            # save_sample(model, args.tacotron2_checkpoint, args.phrase_path,
            #             os.path.join(args.output_directory, f"sample_{epoch:04d}_{iteration}.wav"), args.sampling_rate)

        LOGGER.epoch_stop()

    run_stop_time = time.time()
    run_time = run_stop_time - run_start_time
    LOGGER.log(key="run_time", value=run_time)
    LOGGER.log(key=tags.RUN_FINAL)

    print("training time", run_stop_time - run_start_time)

    LOGGER.timed_block_stop("run")

    if args.rank == 0:
        LOGGER.finish()
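
Example #8 saves attention maps with plot_alignment(alignment, path, info=...), where alignment is an [enc_step, dec_step] array. The plotting helper is not reproduced here; a minimal matplotlib sketch that would satisfy this call (figure size and styling are arbitrary choices):

import matplotlib
matplotlib.use('Agg')  # render to file without a display
import matplotlib.pyplot as plt

def plot_alignment(alignment, path, info=None):
    # alignment: 2-D array of attention weights, [enc_step, dec_step]
    fig, ax = plt.subplots(figsize=(6, 4))
    im = ax.imshow(alignment, aspect='auto', origin='lower', interpolation='none')
    fig.colorbar(im, ax=ax)
    xlabel = 'Decoder timestep'
    if info is not None:
        xlabel += '\n\n' + info
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Encoder timestep')
    plt.tight_layout()
    fig.savefig(path)
    plt.close(fig)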
Example #9
def main():
    """
    Launches text to speech (inference).
    Inference is executed on a single GPU.
    """
    parser = argparse.ArgumentParser(
        description='PyTorch Tacotron 2 Inference')
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    LOGGER.set_model_name("Tacotron2_PyT")
    LOGGER.set_backends([
        dllg.StdOutBackend(log_file=None,
                           logging_scope=dllg.TRAIN_ITER_SCOPE,
                           iteration_interval=1),
        dllg.JsonBackend(log_file=args.log_file,
                         logging_scope=dllg.TRAIN_ITER_SCOPE,
                         iteration_interval=1)
    ])
    LOGGER.register_metric("tacotron2_frames_per_sec",
                           metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("tacotron2_latency",
                           metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("latency", metric_scope=dllg.TRAIN_ITER_SCOPE)

    model = load_and_setup_model(parser, args)

    log_hardware()
    log_args(args)

    if args.include_warmup:
        sequences = torch.randint(low=0,
                                  high=148,
                                  size=(1, 50),
                                  dtype=torch.long).cuda()
        text_lengths = torch.IntTensor([sequences.size(1)]).cuda().long()
        for i in range(3):
            with torch.no_grad():
                _, mels, _, _, mel_lengths = model.infer(
                    sequences, text_lengths)

    try:
        with open(args.input_file) as f:
            sentences = [line.strip() for line in f.readlines()]
    except UnicodeDecodeError:
        with open(args.input_file, encoding='gbk') as f:
            sentences = [line.strip() for line in f.readlines()]

    os.makedirs(args.output, exist_ok=True)

    texts = sentences[1::2]

    LOGGER.iteration_start()

    measurements = {}

    sequences, text_lengths, ids_sorted_decreasing = prepare_input_sequence(
        texts, args.speaker_id)

    with torch.no_grad(), MeasureTime(measurements, "tacotron2_time"):
        _, mels, _, _, mel_lengths = model.infer(sequences, text_lengths)

    tacotron2_infer_perf = mels.size(0) * mels.size(
        2) / measurements['tacotron2_time']

    LOGGER.log(key="tacotron2_frames_per_sec", value=tacotron2_infer_perf)
    LOGGER.log(key="tacotron2_latency", value=measurements['tacotron2_time'])
    LOGGER.log(key="latency", value=(measurements['tacotron2_time']))
    LOGGER.iteration_stop()
    LOGGER.finish()

    # recover to the original order and concatenate
    ids_sorted_decreasing = ids_sorted_decreasing.numpy().tolist()
    mels = [mel[:, :length] for mel, length in zip(mels, mel_lengths)]
    mels = [
        mels[ids_sorted_decreasing.index(i)]
        for i in range(len(ids_sorted_decreasing))
    ]
    wav = audio.inv_mel_spectrogram(np.concatenate(mels, axis=-1))
    audio.save_wav(wav, os.path.join(args.output, 'eval.wav'))
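
The "recover to the original order" step relies on ids_sorted_decreasing[k] being the original index of the k-th item in the length-sorted batch, so ids_sorted_decreasing.index(i) locates where original item i ended up. A tiny illustration with hypothetical values:

# Hypothetical values: three inputs were sorted longest-first as 2, 0, 1.
ids_sorted_decreasing = [2, 0, 1]
sorted_mels = ['mel_of_2', 'mel_of_0', 'mel_of_1']

restored = [sorted_mels[ids_sorted_decreasing.index(i)]
            for i in range(len(ids_sorted_decreasing))]
assert restored == ['mel_of_0', 'mel_of_1', 'mel_of_2']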
Example #10
def main():
    """
    Launches inference benchmark.
    Inference is executed on a single GPU.
    """
    parser = argparse.ArgumentParser(
        description='PyTorch Tacotron 2 Inference')
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    log_file = args.log_file

    LOGGER.set_model_name("Tacotron2_PyT")
    LOGGER.set_backends([
        dllg.StdOutBackend(log_file=None,
                           logging_scope=dllg.TRAIN_ITER_SCOPE,
                           iteration_interval=1),
        dllg.JsonBackend(log_file,
                         logging_scope=dllg.TRAIN_ITER_SCOPE,
                         iteration_interval=1)
    ])
    LOGGER.register_metric("items_per_sec", metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("latency", metric_scope=dllg.TRAIN_ITER_SCOPE)

    log_hardware()
    log_args(args)

    model = load_and_setup_model(args.model_name, parser, None, args.amp_run)

    warmup_iters = 3
    num_iters = 1 + warmup_iters

    for i in range(num_iters):
        if i >= warmup_iters:
            LOGGER.iteration_start()

        measurements = {}

        if args.model_name == 'Tacotron2':
            text_padded = torch.randint(low=0,
                                        high=148,
                                        size=(args.batch_size, 140),
                                        dtype=torch.long).cuda()
            input_lengths = torch.IntTensor([text_padded.size(1)] *
                                            args.batch_size).cuda().long()
            with torch.no_grad(), MeasureTime(measurements, "inference_time"):
                mels, _ = model.infer(text_padded, input_lengths)
            num_items = mels.size(0) * mels.size(2)

        if args.model_name == 'WaveGlow':
            n_mel_channels = model.upsample.in_channels
            num_mels = 895
            mel_padded = torch.zeros(args.batch_size, n_mel_channels,
                                     num_mels).normal_(-5.62, 1.98).cuda()
            if args.amp_run:
                mel_padded = mel_padded.half()

            with torch.no_grad(), MeasureTime(measurements, "inference_time"):
                audios = model.infer(mel_padded)
                audios = audios.float()
            num_items = audios.size(0) * audios.size(1)

        if i >= warmup_iters:
            LOGGER.log(key="items_per_sec",
                       value=(num_items / measurements['inference_time']))
            LOGGER.log(key="latency", value=measurements['inference_time'])
            LOGGER.iteration_stop()

    LOGGER.finish()
Example #11
def main():
    """
    Launches text to speech (inference).
    Inference is executed on a single GPU.
    """
    parser = argparse.ArgumentParser(
        description='PyTorch Tacotron 2 Inference')
    parser = parse_args(parser)
    args, unknown_args = parser.parse_known_args()

    LOGGER.set_model_name("Tacotron2_PyT")
    LOGGER.set_backends([
        dllg.JsonBackend(log_file=args.log_file,
                         logging_scope=dllg.TRAIN_ITER_SCOPE, iteration_interval=1)
    ])
    LOGGER.register_metric("pre_processing", metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("tacotron2_items_per_sec", metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("tacotron2_latency", metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("waveglow_items_per_sec", metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("waveglow_latency", metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("latency", metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("type_conversion", metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("storage", metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("data_transfer", metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("num_mels_per_audio", metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("throughput", metric_scope=dllg.TRAIN_ITER_SCOPE)

    measurements_all = {"pre_processing": [],
                        "tacotron2_latency": [],
                        "waveglow_latency": [],
                        "latency": [],
                        "type_conversion": [],
                        "data_transfer": [],
                        "storage": [],
                        "tacotron2_items_per_sec": [],
                        "waveglow_items_per_sec": [],
                        "num_mels_per_audio": [],
                        "throughput": []}

    log_hardware()
    log_args(args)

    print("args:", args, unknown_args)

    tacotron2 = load_and_setup_model('Tacotron2', parser, args.tacotron2, args.amp_run)
    waveglow = load_and_setup_model('WaveGlow', parser, args.waveglow, args.amp_run)

    texts = ["The forms of printed letters should be beautiful, and that their arrangement on the page should be reasonable and a help to the shapeliness of the letters themselves. The forms of printed letters should be beautiful, and that their arrangement on the page should be reasonable and a help to the shapeliness of the letters themselves."]
    texts = [texts[0][:args.input_length]]
    texts = texts*args.batch_size

    warmup_iters = 3

    for iter in range(args.num_iters):

        if iter >= warmup_iters:
            LOGGER.iteration_start()

        measurements = {}

        with MeasureTime(measurements, "pre_processing"):
            sequences_padded, input_lengths = prepare_input_sequence(texts)

        with torch.no_grad():
            with MeasureTime(measurements, "latency"):
                with MeasureTime(measurements, "tacotron2_latency"):
                    _, mel, _, _, mel_lengths = tacotron2.infer(sequences_padded, input_lengths)

                with MeasureTime(measurements, "waveglow_latency"):
                    audios = waveglow.infer(mel, sigma=args.sigma_infer)

        num_mels = mel.size(0)*mel.size(2)
        num_samples = audios.size(0)*audios.size(1)

        with MeasureTime(measurements, "type_conversion"):
            audios = audios.float()

        with MeasureTime(measurements, "data_transfer"):
            audios = audios.cpu()

        with MeasureTime(measurements, "storage"):
            audios = audios.numpy()
            for i, audio in enumerate(audios):
                audio_path = "audio_"+str(i)+".wav"
                write(audio_path, args.sampling_rate,
                      audio[:mel_lengths[i]*args.stft_hop_length])

        measurements['tacotron2_items_per_sec'] = num_mels/measurements['tacotron2_latency']
        measurements['waveglow_items_per_sec'] = num_samples/measurements['waveglow_latency']
        measurements['num_mels_per_audio'] = mel.size(2)
        measurements['throughput'] = num_samples/measurements['latency']

        if iter >= warmup_iters:
            for k,v in measurements.items():
                measurements_all[k].append(v)
                LOGGER.log(key=k, value=v)

            LOGGER.iteration_stop()

    LOGGER.finish()

    print(np.mean(measurements_all['latency'][1:]),
          np.mean(measurements_all['throughput'][1:]),
          np.mean(measurements_all['pre_processing'][1:]),
          np.mean(measurements_all['type_conversion'][1:])+
          np.mean(measurements_all['storage'][1:])+
          np.mean(measurements_all['data_transfer'][1:]),
          np.mean(measurements_all['num_mels_per_audio'][1:]))

    throughput = measurements_all['throughput']
    preprocessing = measurements_all['pre_processing']
    type_conversion = measurements_all['type_conversion']
    storage = measurements_all['storage']
    data_transfer = measurements_all['data_transfer']
    postprocessing = [sum(p) for p in zip(type_conversion,storage,data_transfer)]
    latency = measurements_all['latency']
    num_mels_per_audio = measurements_all['num_mels_per_audio']

    latency.sort()

    cf_50 = max(latency[:int(len(latency)*0.50)])
    cf_90 = max(latency[:int(len(latency)*0.90)])
    cf_95 = max(latency[:int(len(latency)*0.95)])
    cf_99 = max(latency[:int(len(latency)*0.99)])
    cf_100 = max(latency[:int(len(latency)*1.0)])

    print("Throughput average (samples/sec) = {:.4f}".format(np.mean(throughput)))
    print("Preprocessing average (seconds) = {:.4f}".format(np.mean(preprocessing)))
    print("Postprocessing average (seconds) = {:.4f}".format(np.mean(postprocessing)))
    print("Number of mels per audio average = {}".format(np.mean(num_mels_per_audio)))
    print("Latency average (seconds) = {:.4f}".format(np.mean(latency)))
    print("Latency std (seconds) = {:.4f}".format(np.std(latency)))
    print("Latency cl 50 (seconds) = {:.4f}".format(cf_50))
    print("Latency cl 90 (seconds) = {:.4f}".format(cf_90))
    print("Latency cl 95 (seconds) = {:.4f}".format(cf_95))
    print("Latency cl 99 (seconds) = {:.4f}".format(cf_99))
    print("Latency cl 100 (seconds) = {:.4f}".format(cf_100))
Example #12
def main():

    parser = argparse.ArgumentParser(description='PyTorch Tacotron 2 Training')
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    LOGGER.set_model_name("Tacotron2_PyT")
    LOGGER.set_backends([
        dllg.StdOutBackend(log_file=None,
                           logging_scope=dllg.TRAIN_ITER_SCOPE,
                           iteration_interval=1),
        dllg.JsonBackend(log_file=args.log_file if args.rank == 0 else None,
                         logging_scope=dllg.TRAIN_ITER_SCOPE,
                         iteration_interval=1)
    ])

    LOGGER.timed_block_start("run")
    LOGGER.register_metric(tags.TRAIN_ITERATION_LOSS,
                           metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("iter_time", metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("epoch_time", metric_scope=dllg.EPOCH_SCOPE)
    LOGGER.register_metric("run_time", metric_scope=dllg.RUN_SCOPE)
    LOGGER.register_metric("val_iter_loss", metric_scope=dllg.EPOCH_SCOPE)
    LOGGER.register_metric("train_epoch_items/sec",
                           metric_scope=dllg.EPOCH_SCOPE)
    LOGGER.register_metric("train_epoch_avg_items/sec",
                           metric_scope=dllg.EPOCH_SCOPE)
    LOGGER.register_metric("train_epoch_avg_loss",
                           metric_scope=dllg.EPOCH_SCOPE)

    log_hardware()

    # Restore training from checkpoint logic
    checkpoint = None
    start_epoch = 0

    model_name = args.model_name
    parser = models.parse_model_args(model_name, parser)

    args = parser.parse_args()

    log_args(args)

    torch.backends.cudnn.enabled = args.cudnn_enabled
    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    num_gpus = torch.cuda.device_count()
    print("gpus", num_gpus)
    distributed_run = num_gpus > 1
    if distributed_run:
        init_distributed(args, args.world_size, args.rank, args.group_name)

    LOGGER.log(key=tags.RUN_START)
    run_start_time = time.time()

    # Restore training from checkpoint logic
    if args.restore_from:
        print('Restoring from {} checkpoint'.format(args.restore_from))
        checkpoint = torch.load(args.restore_from, map_location='cpu')
        start_epoch = checkpoint['epoch'] + 1
        model_config = checkpoint['config']
        model = models.get_model(model_name, model_config, to_cuda=True)

        new_state_dict = {}
        for key, value in checkpoint['state_dict'].items():
            new_key = key.replace('module.', '')
            new_state_dict[new_key] = value

        model_dict = new_state_dict
        if args.warm_start:
            ignore_layers = ['embedding.weight']
            print('Warm start')

            if len(ignore_layers) > 0:
                model_dict = {
                    k: v
                    for k, v in model_dict.items() if k not in ignore_layers
                }
                dummy_dict = model.state_dict()
                dummy_dict.update(model_dict)
                model_dict = dummy_dict

        model.load_state_dict(model_dict)
    else:
        model_config = models.get_model_config(model_name, args)
        model = models.get_model(model_name, model_config, to_cuda=True)
        print("model configured")
        #model.cuda(4)
    model.cuda()
    # if not args.amp_run and distributed_run:
    #     model = DDP(model ,delay_allreduce=True)
    #
    #

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.learning_rate,
                                 weight_decay=args.weight_decay)

    # Restore training from checkpoint logic
    if checkpoint and 'optimizer_state_dict' in checkpoint and not args.warm_start:  # TODO: think about this more
        print('Restoring optimizer state')
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if args.amp_run:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
        print("amp initialized")

        model = DDP(model, delay_allreduce=True)
        print("ddpmodel")

    try:
        sigma = args.sigma
    except AttributeError:
        sigma = None

    print("train starting")
    criterion = loss_functions.get_loss_function(model_name, sigma)

    try:
        n_frames_per_step = args.n_frames_per_step
    except AttributeError:
        n_frames_per_step = None

    print("data loading start")
    collate_fn = data_functions.get_collate_function(model_name,
                                                     n_frames_per_step)
    trainset = data_functions.get_data_loader(model_name, args.training_files,
                                              args)

    train_sampler = DistributedSampler(trainset) if distributed_run else None
    print("train loader started")

    train_loader = DataLoader(trainset,
                              num_workers=1,
                              shuffle=False,
                              sampler=train_sampler,
                              batch_size=args.batch_size,
                              pin_memory=False,
                              drop_last=True,
                              collate_fn=collate_fn)

    valset = data_functions.get_data_loader(model_name, args.validation_files,
                                            args)

    batch_to_gpu = data_functions.get_batch_to_gpu(model_name)

    iteration = 0
    model.train()

    LOGGER.log(key=tags.TRAIN_LOOP)

    # Restore training from checkpoint logic
    if start_epoch >= args.epochs:
        print('Checkpoint epoch {} >= total epochs {}'.format(
            start_epoch, args.epochs))
    else:
        for epoch in range(start_epoch, args.epochs):
            LOGGER.epoch_start()
            epoch_start_time = time.time()
            LOGGER.log(key=tags.TRAIN_EPOCH_START, value=epoch)

            # used to calculate avg items/sec over epoch
            reduced_num_items_epoch = 0

            # used to calculate avg loss over epoch
            train_epoch_avg_loss = 0.0
            train_epoch_avg_items_per_sec = 0.0
            num_iters = 0

            # if overflow at the last iteration then do not save checkpoint
            overflow = False

            for i, batch in enumerate(train_loader):

                print("Batch: {}/{} epoch {}".format(i, len(train_loader),
                                                     epoch))
                LOGGER.iteration_start()
                iter_start_time = time.time()
                LOGGER.log(key=tags.TRAIN_ITER_START, value=i)

                start = time.perf_counter()
                adjust_learning_rate(epoch, optimizer, args.learning_rate,
                                     args.anneal_steps, args.anneal_factor)

                model.zero_grad()

                x, y, num_items = batch_to_gpu(batch)

                y_pred = model(x)

                loss = criterion(y_pred, y)

                if distributed_run:
                    reduced_loss = reduce_tensor(loss.data,
                                                 args.world_size).item()
                    reduced_num_items = reduce_tensor(num_items.data, 1).item()
                else:
                    reduced_loss = loss.item()
                    reduced_num_items = num_items.item()
                if np.isnan(reduced_loss):
                    raise Exception("loss is NaN")

                LOGGER.log(key=tags.TRAIN_ITERATION_LOSS, value=reduced_loss)

                train_epoch_avg_loss += reduced_loss
                num_iters += 1

                # accumulate number of items processed in this epoch
                reduced_num_items_epoch += reduced_num_items

                if args.amp_run:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.grad_clip_thresh)
                else:
                    loss.backward()
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.grad_clip_thresh)

                optimizer.step()

                iteration += 1

                LOGGER.log(key=tags.TRAIN_ITER_STOP, value=i)

                iter_stop_time = time.time()
                iter_time = iter_stop_time - iter_start_time
                items_per_sec = reduced_num_items / iter_time
                train_epoch_avg_items_per_sec += items_per_sec

                LOGGER.log(key="train_iter_items/sec", value=items_per_sec)
                LOGGER.log(key="iter_time", value=iter_time)
                LOGGER.iteration_stop()

            LOGGER.log(key=tags.TRAIN_EPOCH_STOP, value=epoch)
            epoch_stop_time = time.time()
            epoch_time = epoch_stop_time - epoch_start_time

            LOGGER.log(key="train_epoch_items/sec",
                       value=(reduced_num_items_epoch / epoch_time))
            LOGGER.log(key="train_epoch_avg_items/sec",
                       value=(train_epoch_avg_items_per_sec /
                              num_iters if num_iters > 0 else 0.0))
            LOGGER.log(key="train_epoch_avg_loss",
                       value=(train_epoch_avg_loss /
                              num_iters if num_iters > 0 else 0.0))
            LOGGER.log(key="epoch_time", value=epoch_time)

            LOGGER.log(key=tags.EVAL_START, value=epoch)

            validate(model, criterion, valset, iteration, args.batch_size,
                     args.world_size, collate_fn, distributed_run, args.rank,
                     batch_to_gpu)

            LOGGER.log(key=tags.EVAL_STOP, value=epoch)

            if (epoch % args.epochs_per_checkpoint == 0) and args.rank == 0:
                checkpoint_path = os.path.join(
                    args.output_directory,
                    "checkpoint_{}_{}".format(model_name, epoch))
                save_checkpoint(model, epoch, model_config, optimizer,
                                checkpoint_path)
                save_sample(
                    model_name, model, args.waveglow_checkpoint,
                    args.tacotron2_checkpoint, args.phrase_path,
                    os.path.join(
                        args.output_directory,
                        "sample_{}_{}.wav".format(model_name, iteration)),
                    args.sampling_rate)

            LOGGER.epoch_stop()

    run_stop_time = time.time()
    run_time = run_stop_time - run_start_time
    LOGGER.log(key="run_time", value=run_time)
    LOGGER.log(key=tags.RUN_FINAL)

    print("training time", run_stop_time - run_start_time)

    LOGGER.timed_block_stop("run")

    if args.rank == 0:
        LOGGER.finish()
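
Example #12 restores from checkpoints that carry 'epoch', 'config', 'state_dict' and 'optimizer_state_dict' entries and writes them with save_checkpoint(model, epoch, model_config, optimizer, checkpoint_path). The saver is not shown; a sketch that writes exactly the keys the restore path reads, stripping the DistributedDataParallel 'module.' prefix the same way the restore code does:

import torch

def save_checkpoint(model, epoch, model_config, optimizer, checkpoint_path):
    state_dict = model.state_dict()
    # Drop the 'module.' prefix added by DistributedDataParallel so the
    # checkpoint loads into a bare (unwrapped) model.
    state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
    torch.save({
        'epoch': epoch,
        'config': model_config,
        'state_dict': state_dict,
        'optimizer_state_dict': optimizer.state_dict(),
    }, checkpoint_path)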
Example #13
def main():
    """
    Launches inference benchmark.
    Inference is executed on a single GPU.
    """
    parser = argparse.ArgumentParser(
        description='PyTorch Tacotron 2 Inference')
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    log_file = ("qa/baselines/" + args.model_name + "_inferbench_BS" + str(args.batch_size) + "_FP" + ("16" if args.fp16_run else "32") +
                "_DGX1_16GB_1GPU_single" + ".json") \
                if args.create_benchmark else \
                   (args.model_name + "_infer_BS" + str(args.batch_size) + "_FP" + ("16" if args.fp16_run else "32") + \
                   "_DGX1_16GB_1GPU_single" + ".json")

    LOGGER.set_model_name("Tacotron2_PyT")
    LOGGER.set_backends([
        dllg.StdOutBackend(log_file=None,
                           logging_scope=dllg.TRAIN_ITER_SCOPE,
                           iteration_interval=1),
        dllg.JsonBackend(log_file,
                         logging_scope=dllg.TRAIN_ITER_SCOPE,
                         iteration_interval=1)
    ])
    LOGGER.register_metric("items_per_sec", metric_scope=dllg.TRAIN_ITER_SCOPE)

    log_hardware()
    log_args(args)

    # ## uncomment to generate new padded text
    # texts = []
    # f = open('qa/ljs_text_train_subset_2500.txt', 'r')
    # texts = f.readlines()
    # sequence = []
    # for i, text in enumerate(texts):
    #     sequence.append(torch.IntTensor(text_to_sequence(text, ['english_cleaners'])))

    # text_padded, input_lengths = collate_text(sequence)
    # text_padded = torch.autograd.Variable(text_padded).cuda().long()
    # torch.save(text_padded, "qa/text_padded.pt")
    # torch.save(input_lengths, "qa/input_lengths.pt")

    model = load_and_setup_model(args.model_name, parser, None, args.fp16_run)

    dry_runs = 3
    num_iters = (16 + dry_runs) if args.create_benchmark else (1 + dry_runs)

    for i in range(num_iters):
        ## skip the warmup (dry-run) inferences, which are slower
        if i >= dry_runs:
            LOGGER.iteration_start()

        if args.model_name == 'Tacotron2':
            text_padded = torch.load(args.input_text)
            text_padded = text_padded[:args.batch_size]
            text_padded = torch.autograd.Variable(text_padded).cuda().long()

            t0 = time.time()
            with torch.no_grad():
                _, mels, _, _ = model.infer(text_padded)
            t1 = time.time()
            inference_time = t1 - t0
            num_items = text_padded.size(0) * text_padded.size(1)

            # # ## uncomment to generate new padded mels
            # torch.save(mels, "qa/mel_padded.pt")

        if args.model_name == 'WaveGlow':
            mel_padded = torch.load(args.input_mels)
            mel_padded = torch.cat(
                (mel_padded, mel_padded, mel_padded, mel_padded))
            mel_padded = mel_padded[:args.batch_size]
            mel_padded = mel_padded.cuda()

            if args.fp16_run:
                mel_padded = mel_padded.half()

            t0 = time.time()
            with torch.no_grad():
                audios = model.infer(mel_padded)
                audios = audios.float()
            t1 = time.time()
            inference_time = t1 - t0
            num_items = audios.size(0) * audios.size(1)

        if i >= dry_runs:
            LOGGER.log(key="items_per_sec", value=(num_items / inference_time))
            LOGGER.iteration_stop()

    LOGGER.finish()