Example #1: TensorFlow WaveGlow inference from a binary local-condition (mel-spectrogram) file
def main():
    try:
        args = get_arguments()

        lc = read_binary_lc(args.lc, hparams.num_mels)

        if hparams.lc_encode or hparams.transposed_upsampling:
            lc = np.reshape(lc, [1, -1, hparams.num_mels])
        else:
            # upsample the local condition by repeating each frame upsampling_rate times
            lc = np.tile(lc, [1, 1, hparams.upsampling_rate])
            lc = np.reshape(lc, [1, -1, hparams.num_mels])

        print(lc.shape)

        glow = WaveGlow(lc_dim=hparams.num_mels,
                        n_flows=hparams.n_flows,
                        n_group=hparams.n_group,
                        n_early_every=hparams.n_early_every,
                        n_early_size=hparams.n_early_size)

        lc_placeholder = tf.placeholder(tf.float32,
                                        shape=[1, None, hparams.num_mels],
                                        name='lc')
        audio = glow.infer(lc_placeholder, sigma=args.sigma)

        sess = tf.Session(config=tf.ConfigProto(log_device_placement=False,
                                                allow_soft_placement=True))
        print("restore model")
        saver = tf.train.Saver(var_list=tf.trainable_variables())
        saver.restore(sess, args.restore_from)
        print('model restored successfully!')

        audio_output = sess.run(audio, feed_dict={lc_placeholder: lc})
        audio_output = audio_output.flatten()
        print(audio_output)
        write_wav(audio_output, hparams.sample_rate, args.wave_name)
    except Exception:
        raise
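
The read_binary_lc helper used above is not shown in this example; a minimal sketch, assuming the local-condition file stores raw float32 mel frames (the storage format is an assumption):

import numpy as np

def read_binary_lc(file_path, num_mels):
    # hypothetical helper: load raw float32 mel frames from a binary file
    lc = np.fromfile(file_path, dtype=np.float32)
    # drop any trailing partial frame, then shape as [frames, num_mels]
    lc = lc[:lc.size // num_mels * num_mels]
    return np.reshape(lc, [-1, num_mels])
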
Example #2: PyTorch WaveGlow inference with weight-norm removal and denoising
def waveglow_infer(mel, config):
    print(
        colored('Running WaveGlow with ', 'blue', attrs=['bold']) +
        config.vocoder_path)

    waveglow = WaveGlow(config)
    waveglow, _, _ = load_checkpoint(config.vocoder_path, waveglow)

    #waveglow = torch.hub.load('nvidia/DeepLearningExamples:torchhub', 'nvidia_waveglow')
    waveglow = waveglow.remove_weightnorm(waveglow)
    waveglow = set_device(waveglow, config.device)
    waveglow.eval()

    denoiser = Denoiser(waveglow, config)
    denoiser = set_device(denoiser, config.device)

    with torch.no_grad():
        wave = waveglow.infer(mel, config.sigma).float()
        wave = denoiser(wave, strength=config.denoising_strength)

    # peak-normalize the waveform to [-1, 1]
    wave = wave / torch.max(torch.abs(wave))

    return wave.cpu()
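
A minimal usage sketch for waveglow_infer; the mel source, output path, and 22050 Hz sample rate are assumptions, and config is expected to carry the fields referenced above (vocoder_path, device, sigma, denoising_strength):

import torch
from scipy.io.wavfile import write

mel = torch.load('mel.pt')          # hypothetical precomputed mel, shape [1, n_mels, frames]
wave = waveglow_infer(mel, config)  # config assumed to be built elsewhere
write('output.wav', 22050, wave.squeeze().numpy())
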
Example #3: PyTorch WaveGlow training loop with separate train/eval sets and TensorBoard logging
def train(num_gpus,
          rank,
          group_name,
          output_directory,
          epochs,
          learning_rate,
          sigma,
          iters_per_checkpoint,
          batch_size,
          seed,
          fp16_run,
          checkpoint_path,
          with_tensorboard,
          num_workers=2):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        init_distributed(rank, num_gpus, group_name, **dist_config)
    #=====END:   ADDED FOR DISTRIBUTED======

    criterion = WaveGlowLoss(sigma)
    model = WaveGlow(**waveglow_config).cuda()

    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)
    #=====END:   ADDED FOR DISTRIBUTED======

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    if fp16_run:
        from apex import amp
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    # Load checkpoint if one exists
    iteration = 0
    if checkpoint_path != "":
        model, optimizer, iteration = load_checkpoint(checkpoint_path, model,
                                                      optimizer)
        iteration += 1  # next iteration is iteration + 1

    # HACK: setup separate training and eval sets
    training_files = data_config['training_files']
    eval_files = data_config['eval_files']
    del data_config['training_files']
    del data_config['eval_files']
    data_config['audio_files'] = training_files
    trainset = Mel2Samp(**data_config)
    data_config['audio_files'] = eval_files
    evalset = Mel2Samp(**data_config)

    # =====START: ADDED FOR DISTRIBUTED======
    train_sampler = DistributedSampler(trainset) if num_gpus > 1 else None
    eval_sampler = DistributedSampler(evalset) if num_gpus > 1 else None
    # =====END:   ADDED FOR DISTRIBUTED======

    print("Creating dataloaders with " + str(num_workers) + " workers")
    train_loader = DataLoader(trainset,
                              num_workers=num_workers,
                              shuffle=(train_sampler is None),  # sampler and shuffle are mutually exclusive
                              sampler=train_sampler,
                              batch_size=batch_size,
                              pin_memory=False,
                              drop_last=True)
    eval_loader = DataLoader(evalset,
                             num_workers=num_workers,
                             shuffle=(eval_sampler is None),  # sampler and shuffle are mutually exclusive
                             sampler=eval_sampler,
                             batch_size=batch_size,
                             pin_memory=False,
                             drop_last=True)

    # Get shared output_directory ready
    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        print("output directory", output_directory)

    if with_tensorboard and rank == 0:
        from tensorboardX import SummaryWriter
        logger_train = SummaryWriter(
            os.path.join(output_directory, 'logs', 'train'))
        logger_eval = SummaryWriter(
            os.path.join(output_directory, 'logs', 'eval'))

    epoch_offset = max(0, int(iteration / len(train_loader)))
    # ================ MAIN TRAINING LOOP! ===================
    for epoch in range(epoch_offset, epochs):
        model.train()
        with tqdm(total=len(train_loader)) as train_pbar:
            for i, batch in enumerate(train_loader):
                model.zero_grad()

                mel, audio = batch
                mel = torch.autograd.Variable(mel.cuda())
                audio = torch.autograd.Variable(audio.cuda())
                outputs = model((mel, audio))

                loss = criterion(outputs)
                if num_gpus > 1:
                    reduced_loss = reduce_tensor(loss.data, num_gpus).item()
                else:
                    reduced_loss = loss.item()

                if fp16_run:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                optimizer.step()

                train_pbar.set_description(
                    "Epoch {} Iter {} Loss {:.3f}".format(
                        epoch, iteration, reduced_loss))
                if with_tensorboard and rank == 0 and iteration % 10 == 0:
                    logger_train.add_scalar('loss', reduced_loss,
                                            i + len(train_loader) * epoch)
                    # adding logging for GPU utilization and memory usage
                    gpu_memory_used, gpu_utilization = get_gpu_stats()
                    k = 'gpu' + str(0)
                    logger_train.add_scalar(k + '/memory', gpu_memory_used,
                                            iteration)
                    logger_train.add_scalar(k + '/load', gpu_utilization,
                                            iteration)
                    logger_train.flush()

                if (iteration % iters_per_checkpoint == 0):
                    if rank == 0:
                        checkpoint_path = "{}/waveglow_{}".format(
                            output_directory, iteration)
                        save_checkpoint(model, optimizer, learning_rate,
                                        iteration, checkpoint_path)

                iteration += 1
                train_pbar.update(1)

        # Eval
        model.eval()
        torch.cuda.empty_cache()

        with torch.no_grad():
            tensorboard_mel, tensorboard_audio = None, None
            loss_accum = []
            with tqdm(total=len(eval_loader)) as eval_pbar:
                for i, batch in enumerate(eval_loader):
                    model.zero_grad()
                    mel, audio = batch
                    mel = torch.autograd.Variable(mel.cuda())
                    audio = torch.autograd.Variable(audio.cuda())
                    outputs = model((mel, audio))
                    loss = criterion(outputs).item()
                    loss_accum.append(loss)
                    eval_pbar.set_description("Epoch {} Eval {:.3f}".format(
                        epoch, loss))
                    outputs = None

                    # use the first batch for tensorboard audio samples
                    if i == 0:
                        tensorboard_mel = mel
                        tensorboard_audio = audio
                    eval_pbar.update(1)

            if with_tensorboard and rank == 0:
                loss_avg = statistics.mean(loss_accum)
                tqdm.write("Epoch {} Eval AVG {}".format(epoch, loss_avg))
                logger_eval.add_scalar('loss', loss_avg, iteration)

                # log audio samples to tensorboard (logger_eval only exists when
                # with_tensorboard is set and rank == 0)
                tensorboard_audio_generated = model.infer(tensorboard_mel)
                for i in range(0, 5):
                    ta = tensorboard_audio[i].cpu().numpy()
                    tag = tensorboard_audio_generated[i].cpu().numpy()
                    logger_eval.add_audio("sample " + str(i) + "/orig",
                                          ta,
                                          epoch,
                                          sample_rate=data_config['sampling_rate'])
                    logger_eval.add_audio("sample " + str(i) + "/gen",
                                          tag,
                                          epoch,
                                          sample_rate=data_config['sampling_rate'])
                logger_eval.flush()
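
The reduce_tensor helper used for multi-GPU loss reporting is not shown; a minimal sketch of the usual pattern, assuming torch.distributed has already been initialized by init_distributed:

import torch.distributed as dist

def reduce_tensor(tensor, num_gpus):
    # average a scalar tensor across all participating processes
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    return rt / num_gpus
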
Example #4: Tacotron 2 + WaveGlow TTS wrapper class with a Griffin-Lim fallback vocoder
class TTSModel(object):
    """docstring for TTSModel."""
    def __init__(self, tacotron2_path, waveglow_path, **kwargs):
        super(TTSModel, self).__init__()
        hparams = HParams(**kwargs)
        self.hparams = hparams
        self.model = Tacotron2(hparams)
        if torch.cuda.is_available():
            self.model.load_state_dict(
                torch.load(tacotron2_path)["state_dict"])
            self.model.cuda().eval()
        else:
            self.model.load_state_dict(
                torch.load(tacotron2_path, map_location="cpu")["state_dict"])
            self.model.eval()
        self.k_cache = klepto.archives.file_archive(cached=False)
        if waveglow_path:
            if torch.cuda.is_available():
                wave_params = torch.load(waveglow_path)
            else:
                wave_params = torch.load(waveglow_path, map_location="cpu")
            try:
                self.waveglow = WaveGlow(**WAVEGLOW_CONFIG)
                self.waveglow.load_state_dict(wave_params)
            except Exception:
                self.waveglow = wave_params["model"]
                self.waveglow = self.waveglow.remove_weightnorm(self.waveglow)
            if torch.cuda.is_available():
                self.waveglow.cuda().eval()
            else:
                self.waveglow.eval()
            # workaround from
            # https://github.com/NVIDIA/waveglow/issues/127
            for m in self.waveglow.modules():
                if "Conv" in str(type(m)):
                    setattr(m, "padding_mode", "zeros")
            for k in self.waveglow.convinv:
                k.float().half()
            self.denoiser = Denoiser(self.waveglow,
                                     n_mel_channels=hparams.n_mel_channels)
            self.synth_speech = klepto.safe.inf_cache(cache=self.k_cache)(
                self._synth_speech)
        else:
            self.synth_speech = klepto.safe.inf_cache(cache=self.k_cache)(
                self._synth_speech_fast)
        self.taco_stft = TacotronSTFT(
            hparams.filter_length,
            hparams.hop_length,
            hparams.win_length,
            n_mel_channels=hparams.n_mel_channels,
            sampling_rate=hparams.sampling_rate,
            mel_fmax=4000,
        )

    def _generate_mel_postnet(self, text):
        sequence = np.array(text_to_sequence(text,
                                             ["english_cleaners"]))[None, :]
        if torch.cuda.is_available():
            sequence = torch.autograd.Variable(
                torch.from_numpy(sequence)).cuda().long()
        else:
            sequence = torch.autograd.Variable(
                torch.from_numpy(sequence)).long()
        with torch.no_grad():
            mel_outputs, mel_outputs_postnet, _, alignments = self.model.inference(
                sequence)
        return mel_outputs_postnet

    def synth_speech_array(self, text, vocoder):
        mel_outputs_postnet = self._generate_mel_postnet(text)

        if vocoder == VOCODER_WAVEGLOW:
            with torch.no_grad():
                audio_t = self.waveglow.infer(mel_outputs_postnet, sigma=0.666)
                audio_t = self.denoiser(audio_t, 0.1)[0]
            audio = audio_t[0].data
        elif vocoder == VOCODER_GL:
            mel_decompress = self.taco_stft.spectral_de_normalize(
                mel_outputs_postnet)
            mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
            spec_from_mel_scaling = 1000
            spec_from_mel = torch.mm(mel_decompress[0],
                                     self.taco_stft.mel_basis)
            spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
            spec_from_mel = spec_from_mel * spec_from_mel_scaling
            spec_from_mel = (spec_from_mel.cuda()
                             if torch.cuda.is_available() else spec_from_mel)
            audio = griffin_lim(
                torch.autograd.Variable(spec_from_mel[:, :, :-1]),
                self.taco_stft.stft_fn,
                GL_ITERS,
            )
            audio = audio.squeeze()
        else:
            raise ValueError("vocoder arg should be one of [wavglow|gl]")
        audio = audio.cpu().numpy()
        return audio

    def _synth_speech(self,
                      text,
                      speed: float = 1.0,
                      sample_rate: int = OUTPUT_SAMPLE_RATE):
        audio = self.synth_speech_array(text, VOCODER_WAVEGLOW)

        return postprocess_audio(
            audio,
            src_rate=self.hparams.sampling_rate,
            dst_rate=sample_rate,
            tempo=speed,
        )

    def _synth_speech_fast(self,
                           text,
                           speed: float = 1.0,
                           sample_rate: int = OUTPUT_SAMPLE_RATE):
        audio = self.synth_speech_array(text, VOCODER_GL)

        return postprocess_audio(
            audio,
            tempo=speed,
            src_rate=self.hparams.sampling_rate,
            dst_rate=sample_rate,
        )
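
A minimal usage sketch for TTSModel; the checkpoint paths below are hypothetical and any hparams overrides would be passed as keyword arguments:

tts = TTSModel('tacotron2_statedict.pt', 'waveglow_256channels.pt')  # hypothetical paths
audio = tts.synth_speech('Hello world.')                    # WaveGlow vocoder, cached via klepto
raw = tts.synth_speech_array('Hello world.', VOCODER_GL)    # Griffin-Lim, uncached numpy audio
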
Example #5: PyTorch WaveGlow training loop with periodic validation and audio logging
def train(num_gpus, rank, group_name, output_directory, epochs, learning_rate,
          sigma, iters_per_checkpoint, iters_per_validation, batch_size, seed,
          fp16_run, checkpoint_path, with_tensorboard):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        init_distributed(rank, num_gpus, group_name, **dist_config)
    #=====END:   ADDED FOR DISTRIBUTED======

    criterion = WaveGlowLoss(sigma)
    model = WaveGlow(**waveglow_config).cuda()

    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)
    #=====END:   ADDED FOR DISTRIBUTED======

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    if fp16_run:
        from apex import amp
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    # Load checkpoint if one exists
    iteration = 0
    if checkpoint_path != "":
        model, optimizer, iteration = load_checkpoint(checkpoint_path, model,
                                                      optimizer)
        iteration += 1  # next iteration is iteration + 1

    trainset = Mel2Samp(**data_config)
    valset = Mel2Samp(**data_config, val=True)
    # =====START: ADDED FOR DISTRIBUTED======
    train_sampler = DistributedSampler(trainset) if num_gpus > 1 else None
    # =====END:   ADDED FOR DISTRIBUTED======
    train_loader = DataLoader(trainset,
                              num_workers=1,
                              shuffle=False,
                              sampler=train_sampler,
                              batch_size=batch_size,
                              pin_memory=False,
                              drop_last=True)

    # Get shared output_directory ready
    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        print("output directory", output_directory)

    if with_tensorboard and rank == 0:
        from tensorboardX import SummaryWriter
        logger = SummaryWriter(os.path.join(output_directory, 'logs'))

    model.train()
    epoch_offset = max(0, int(iteration / len(train_loader)))
    # ================ MAIN TRAINING LOOP! ===================
    for epoch in range(epoch_offset, epochs):
        print("Epoch: {}".format(epoch))
        for i, batch in enumerate(train_loader):
            model.zero_grad()

            mel, audio = batch
            mel = torch.autograd.Variable(mel.cuda())
            audio = torch.autograd.Variable(audio.cuda())
            outputs = model((mel, audio))

            loss = criterion(outputs)
            if num_gpus > 1:
                reduced_loss = reduce_tensor(loss.data, num_gpus).item()
            else:
                reduced_loss = loss.item()

            if fp16_run:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            optimizer.step()

            print("{}:\t{:.9f}".format(iteration, reduced_loss))
            if with_tensorboard and rank == 0:
                logger.add_scalar('training_loss', reduced_loss,
                                  i + len(train_loader) * epoch)
            if with_tensorboard and rank == 0 and iteration % iters_per_validation == 0:
                logger.add_audio('real_audio',
                                 audio[0, :].cpu().detach().numpy(),
                                 iteration,
                                 sample_rate=22050)
                generated_audio = model.infer(mel[0].unsqueeze(0))
                logger.add_audio('generated_audio',
                                 generated_audio.cpu().detach().numpy(),
                                 iteration,
                                 sample_rate=22050)
                validate(model, criterion, valset, iteration, batch_size,
                         logger)

            if (iteration % iters_per_checkpoint == 0):
                if rank == 0:
                    checkpoint_path = "{}/waveglow_{}".format(
                        output_directory, iteration)
                    save_checkpoint(model, optimizer, learning_rate, iteration,
                                    checkpoint_path)

            iteration += 1
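
The validate helper called above is not shown; a minimal sketch matching that call signature (the body is an assumption), averaging the WaveGlow loss over the validation set and logging it:

import torch
from torch.utils.data import DataLoader

def validate(model, criterion, valset, iteration, batch_size, logger):
    # hypothetical helper: average the WaveGlow loss over the validation set
    model.eval()
    losses = []
    with torch.no_grad():
        loader = DataLoader(valset, batch_size=batch_size, shuffle=False, drop_last=True)
        for mel, audio in loader:
            outputs = model((mel.cuda(), audio.cuda()))
            losses.append(criterion(outputs).item())
    logger.add_scalar('validation_loss', sum(losses) / max(len(losses), 1), iteration)
    model.train()
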
Example #6: Multi-GPU TensorFlow WaveGlow training with tower losses and averaged gradients
def main():
    args = get_arguments()
    args.logdir = os.path.join(hparams.logdir_root, args.run_name)
    if not os.path.exists(args.logdir):
        os.makedirs(args.logdir)

    args.gen_wave_dir = os.path.join(args.logdir, 'wave')
    os.makedirs(args.gen_wave_dir, exist_ok=True)

    assert hparams.upsampling_rate == hparams.hop_length, 'upsampling rate should be the same as hop_length'

    # Create coordinator.
    coord = tf.train.Coordinator()
    global_step = tf.get_variable("global_step", [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)
    learning_rate = tf.train.exponential_decay(hparams.lr,
                                               global_step,
                                               hparams.decay_steps,
                                               0.95,
                                               staircase=True)
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)

    with tf.device('/cpu:0'):
        with tf.name_scope('inputs'):
            reader = DataReader(coord, args.filelist, args.wave_dir,
                                args.lc_dir)

    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False,
                                            allow_soft_placement=True))
    reader.start_threads()

    audio_placeholder = tf.placeholder(tf.float32,
                                       shape=[None, None, 1],
                                       name='audio')
    lc_placeholder = tf.placeholder(tf.float32,
                                    shape=[None, None, hparams.num_mels],
                                    name='lc')

    tower_losses = []
    tower_grads = []
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        for i in range(args.ngpu):
            with tf.device('/gpu:%d' % i), tf.name_scope('tower_%d' % i):
                glow = WaveGlow(lc_dim=hparams.num_mels,
                                n_flows=hparams.n_flows,
                                n_group=hparams.n_group,
                                n_early_every=hparams.n_early_every,
                                n_early_size=hparams.n_early_size)
                print('create network %i' % i)

                local_audio_placeholder = audio_placeholder[
                    i * hparams.batch_size:(i + 1) * hparams.batch_size, :, :]
                local_lc_placeholder = lc_placeholder[
                    i * hparams.batch_size:(i + 1) * hparams.batch_size, :, :]

                output_audio, log_s_list, log_det_W_list = glow.create_forward_network(
                    local_audio_placeholder, local_lc_placeholder)
                loss = compute_waveglow_loss(output_audio,
                                             log_s_list,
                                             log_det_W_list,
                                             sigma=hparams.sigma)
                grads = optimizer.compute_gradients(
                    loss, var_list=tf.trainable_variables())

                tower_losses.append(loss)
                tower_grads.append(grads)

                tf.summary.scalar('loss_tower_%d' % i, loss)

    # # gradient clipping
    # gradients = [grad for grad, var in averaged_gradients]
    # params = [var for grad, var in averaged_gradients]
    # clipped_gradients, norm = tf.clip_by_global_norm(gradients, 1.0)
    #
    # with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    #     train_ops = optimizer.apply_gradients(zip(clipped_gradients, params), global_step=global_step)

    print("create network finished")
    loss = tf.reduce_mean(tower_losses)
    averaged_gradients = average_gradients(tower_grads)

    train_ops = optimizer.apply_gradients(averaged_gradients,
                                          global_step=global_step)

    tf.summary.scalar('loss', loss)

    # Set up logging for TensorBoard.
    writer = tf.summary.FileWriter(args.logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.summary.merge_all()

    # inference for audio
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        lc_placeholder_infer = tf.placeholder(
            tf.float32, shape=[1, None, hparams.num_mels], name='lc_infer')
        audio_infer_ops = glow.infer(lc_placeholder_infer, sigma=hparams.sigma)

    # Set up session
    init = tf.global_variables_initializer()
    sess.run(init)
    print('parameters initialization finished')

    # stats_graph(tf.get_default_graph())
    total_parameters = count()
    print("######################################################")
    print("### Total Trainable Params is {} ###".format(total_parameters))
    print("######################################################")

    saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=30)

    saved_global_step = 0
    if args.restore_from is not None:
        try:
            saved_global_step = load(saver, sess, args.restore_from)
            if saved_global_step is None:
                # The first training step will be saved_global_step + 1,
                # therefore we put 0 here for new or overwritten trainings.
                saved_global_step = 0
        except Exception:
            print(
                "Something went wrong while restoring checkpoint. "
                "We will terminate training to avoid accidentally overwriting "
                "the previous model.")
            raise

        print("restore model successfully!")

    print('start training.')
    last_saved_step = saved_global_step
    try:
        for step in range(saved_global_step + 1, hparams.train_steps):
            audio, lc = reader.dequeue(num_elements=hparams.batch_size *
                                       args.ngpu)

            if hparams.lc_conv1d or hparams.lc_encode or hparams.transposed_upsampling:
                # when the local condition is encoded (conv1d / bi-lstm) or upsampled with
                # transposed convolutions, the upsampling happens inside the TF graph
                lc = np.reshape(
                    lc, [hparams.batch_size * args.ngpu, -1, hparams.num_mels])
            else:
                # upsample by directly repeating each frame
                lc = np.tile(lc, [1, 1, hparams.upsampling_rate])
                lc = np.reshape(
                    lc, [hparams.batch_size * args.ngpu, -1, hparams.num_mels])

            start_time = time.time()
            if step % 100 == 0 and args.store_metadata:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary, loss_value, _, lr = sess.run(
                    [summaries, loss, train_ops, learning_rate],
                    feed_dict={
                        audio_placeholder: audio,
                        lc_placeholder: lc
                    },
                    options=run_options,
                    run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(args.logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                summary, loss_value, _, lr = sess.run(
                    [summaries, loss, train_ops, learning_rate],
                    feed_dict={
                        audio_placeholder: audio,
                        lc_placeholder: lc
                    })
                writer.add_summary(summary, step)

            duration = time.time() - start_time
            step_log = 'step {:d} - loss = {:.3f}, lr={:.8f}, time cost={:4f}' \
                .format(step, loss_value, lr, duration)
            print(step_log)

            if step % hparams.save_model_every == 0:
                save(saver, sess, args.logdir, step)
                last_saved_step = step

            if step % hparams.gen_test_wave_every == 0:
                generate_wave(lc_placeholder_infer, audio_infer_ops, sess,
                              step, args.gen_wave_dir)

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        if step > last_saved_step:
            save(saver, sess, args.logdir, step)
        coord.request_stop()
        coord.join()
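
The average_gradients helper used to merge the per-tower gradients is not shown; a minimal sketch of the standard multi-tower TF1.x pattern (an assumption, not this repo's exact code):

import tensorflow as tf

def average_gradients(tower_grads):
    # average each variable's gradient across all GPU towers
    averaged = []
    for grads_and_vars in zip(*tower_grads):
        grads = [tf.expand_dims(g, 0) for g, _ in grads_and_vars if g is not None]
        grad = tf.reduce_mean(tf.concat(grads, axis=0), axis=0)
        averaged.append((grad, grads_and_vars[0][1]))
    return averaged
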
Example #7: PyTorch WaveGlow training loop with fixed-sample spectrogram and audio visualization
def train(
    num_gpus,
    rank,
    group_name,
    output_directory,
    epochs,
    learning_rate,
    sigma,
    iters_per_checkpoint,
    batch_size,
    seed,
    fp16_run,
    checkpoint_path,
    with_tensorboard,
):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # =====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        init_distributed(rank, num_gpus, group_name, **dist_config)
    # =====END:   ADDED FOR DISTRIBUTED======

    criterion = WaveGlowLoss(sigma)
    model = WaveGlow(**waveglow_config).cuda()

    # =====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)
    # =====END:   ADDED FOR DISTRIBUTED======

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    if fp16_run:
        from apex import amp

        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    # Load checkpoint if one exists
    iteration = 0
    if checkpoint_path != "":
        model, optimizer, iteration = load_checkpoint(checkpoint_path, model,
                                                      optimizer)
        iteration += 1  # next iteration is iteration + 1

    trainset = Mel2Samp(**data_config)
    # =====START: ADDED FOR DISTRIBUTED======
    train_sampler = DistributedSampler(trainset) if num_gpus > 1 else None
    # =====END:   ADDED FOR DISTRIBUTED======
    train_loader = DataLoader(
        trainset,
        num_workers=1,
        shuffle=False,
        sampler=train_sampler,
        batch_size=batch_size,
        pin_memory=False,
        drop_last=True,
    )

    # Get shared output_directory ready
    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        print("output directory", output_directory)

    if with_tensorboard and rank == 0:
        from tensorboardX import SummaryWriter

        logger = SummaryWriter(os.path.join(output_directory, "logs"))

    # fixed for visualization
    real_mels, real_audios = zip(*[trainset[i] for i in range(8)])
    real_mel = torch.cat(real_mels, dim=-1)
    real_audio = torch.cat(real_audios, dim=0)

    model.train()
    epoch_offset = max(0, int(iteration / len(train_loader)))
    # ================ MAIN TRAINING LOOP! ===================
    for epoch in range(epoch_offset, epochs):
        print("Epoch: {}".format(epoch))
        for i, batch in enumerate(train_loader):
            model.zero_grad()

            mel, audio = batch
            mel = torch.autograd.Variable(mel.cuda())
            audio = torch.autograd.Variable(audio.cuda())
            outputs = model((mel, audio))

            loss = criterion(outputs)
            if num_gpus > 1:
                reduced_loss = reduce_tensor(loss.data, num_gpus).item()
            else:
                reduced_loss = loss.item()

            if fp16_run:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            optimizer.step()

            print("{}:\t{:.9f}".format(iteration, reduced_loss))
            if with_tensorboard and rank == 0:
                step = i + len(train_loader) * epoch
                logger.add_scalar("training_loss", reduced_loss, step)
                if step % 500 == 0:
                    # synthesize from the eight fixed samples selected above

                    model.eval()
                    with torch.no_grad():
                        device = mel.device
                        fake_audio = (model.infer(
                            torch.stack(real_mels).to(device)).flatten(
                                0, 1).cpu())
                    model.train()
                    fake_mel = trainset.get_mel(fake_audio)

                    logger.add_image(
                        "training_mel_real",
                        plot_spectrogram_to_numpy(real_mel),
                        step,
                        dataformats="HWC",
                    )
                    logger.add_audio(
                        "training_audio_real",
                        real_audio,
                        step,
                        22050,
                    )
                    logger.add_image(
                        "training_mel_fake",
                        plot_spectrogram_to_numpy(fake_mel),
                        step,
                        dataformats="HWC",
                    )
                    logger.add_audio(
                        "training_audio_fake",
                        fake_audio,
                        step,
                        22050,
                    )
                    logger.flush()

            if iteration % iters_per_checkpoint == 0:
                if rank == 0:
                    checkpoint_path = "{}/waveglow_{}".format(
                        output_directory, iteration)
                    save_checkpoint(model, optimizer, learning_rate, iteration,
                                    checkpoint_path)

            iteration += 1
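
The load_checkpoint and save_checkpoint helpers shared by these training loops are not shown; a minimal sketch of the usual pattern (the dictionary keys are assumptions):

import torch

def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
    # persist model weights, optimizer state, and the current iteration
    torch.save({'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'learning_rate': learning_rate,
                'iteration': iteration}, filepath)

def load_checkpoint(checkpoint_path, model, optimizer):
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(ckpt['model'])
    optimizer.load_state_dict(ckpt['optimizer'])
    return model, optimizer, ckpt['iteration']
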