Exemplo n.º 1
0
def data_generator(data_base, chunk_length_in_sec, label_order_list, batch_size):
    """Yield ``(x, y)`` batches of mean-normalised mel spectrograms forever.

    Every key from ``data_base.get_db_keys()`` is visited in turn. For each of
    that key's time ranges (``data_base.get_db_wv_times(k)``) the waveform is
    fetched and cut into non-overlapping chunks of ``chunk_length_in_sec``
    seconds; a trailing partial chunk is dropped. Each chunk becomes a mel
    spectrogram, and each mel channel has its mean over time subtracted.

    Chunks accumulate across keys. A batch is yielded as soon as
    ``batch_size`` chunks are available, and leftovers carry over to later
    keys and passes.

    Args:
        data_base: object providing ``get_db_keys()``, ``get_db_wv_times(key)``
            and ``get_wav(key, *time_range) -> (sampling_rate, samples)``.
        chunk_length_in_sec: chunk duration in seconds.
        label_order_list: sequence whose ``index(key)`` is the class id of ``key``.
        batch_size: number of chunks per yielded batch.

    Yields:
        ``(x, y)``. ``x`` is a tensor stacked from the per-chunk mel arrays.
        ``y`` is a ``long`` tensor of class ids, one per chunk.
    """
    stft = TacotronSTFT(**SFT_CONFIG)
    keys = data_base.get_db_keys()
    batch_x = []
    batch_y = []
    while True:
        for k in keys:
            label = np.array([label_order_list.index(k)])
            # Read k's own recordings. The previous version always loaded the
            # first time range of keys[0], pairing every label with the same audio.
            for t in data_base.get_db_wv_times(k):
                sampling_rate, speech = data_base.get_wav(k, *t)
                chunks = int(len(speech) / sampling_rate / chunk_length_in_sec)
                # int() keeps slice bounds integral for fractional chunk lengths.
                audio_length = int(sampling_rate * chunk_length_in_sec)
                for chunk in range(chunks):
                    audio = speech[chunk * audio_length:(chunk + 1) * audio_length]
                    audio_norm = torch.from_numpy(audio / MAX_WAV_VALUE).float()
                    audio_norm = torch.autograd.Variable(audio_norm.unsqueeze(0),
                                                         requires_grad=False)
                    melspec = stft.mel_spectrogram(audio_norm)
                    mel_np = melspec.detach().numpy()
                    # Subtract each mel channel's mean over time (batch index 0).
                    mel_np[0] -= mel_np[0].mean(axis=1, keepdims=True)
                    batch_x.append(mel_np)
                    batch_y.append(label)
                    if len(batch_x) >= batch_size:
                        x = torch.from_numpy(np.array(batch_x))
                        y = torch.autograd.Variable(
                            torch.from_numpy(np.concatenate(batch_y)).long())
                        batch_x = []
                        batch_y = []
                        yield x, y
Exemplo n.º 2
0
 def __init__(self,
              training_files,
              segment_length,
              filter_length,
              hop_length,
              win_length,
              sampling_rate,
              mel_fmin,
              mel_fmax,
              load_mel_from_disk=False):
     """Read the training list, shuffle it reproducibly and build the mel STFT.

     When ``load_mel_from_disk`` is truthy, the list is parsed with
     ``audiopaths_and_melpaths``; otherwise it is parsed with ``files_to_list``.
     ``random.seed(1234)`` reseeds Python's global ``random`` module, so the
     shuffle order is the same on every run.
     """
     self.load_mel_from_disk = load_mel_from_disk
     self.hop_length = hop_length
     if self.load_mel_from_disk:
         file_list = audiopaths_and_melpaths(training_files)
     else:
         file_list = files_to_list(training_files)
     self.audio_files = file_list
     random.seed(1234)
     random.shuffle(self.audio_files)
     self.stft = TacotronSTFT(
         filter_length=filter_length,
         hop_length=hop_length,
         win_length=win_length,
         sampling_rate=sampling_rate,
         mel_fmin=mel_fmin,
         mel_fmax=mel_fmax,
     )
     self.sampling_rate = sampling_rate
     self.segment_length = segment_length
     self.sampling_rate = sampling_rate
Exemplo n.º 3
0
 def __init__(self,
              training_files,
              validation_files,
              segment_length,
              filter_length,
              hop_length,
              win_length,
              sampling_rate,
              mel_fmin,
              mel_fmax,
              train=True):
     """Load the training or validation file list and build the mel STFT.

     ``train`` selects which list is read: ``training_files`` when truthy,
     ``validation_files`` otherwise. The list is shuffled after reseeding
     Python's global ``random`` module with 1234, so the order is reproducible.
     """
     list_path = training_files if train else validation_files
     self.audio_files = files_to_list(list_path)
     random.seed(1234)
     random.shuffle(self.audio_files)
     self.stft = TacotronSTFT(
         filter_length=filter_length,
         hop_length=hop_length,
         win_length=win_length,
         sampling_rate=sampling_rate,
         mel_fmin=mel_fmin,
         mel_fmax=mel_fmax,
     )
     self.sampling_rate = sampling_rate
     self.segment_length = segment_length
Exemplo n.º 4
0
 def __init__(self,
              training_files,
              segment_length,
              filter_length,
              hop_length,
              win_length,
              sampling_rate,
              data_folder,
              audio_format,
              return_stft=False,
              mel_fmin=0.0,
              mel_fmax=8000.0):
     """Load and shuffle the training list and build the spectrogram transform.

     Args:
         training_files: path passed to ``files_to_list``.
         segment_length: stored on ``self.segment_length``.
         filter_length, hop_length, win_length: STFT parameters.
         sampling_rate: stored, and forwarded to ``TacotronSTFT``.
         data_folder: stored on ``self.data_folder``.
         audio_format: stored on ``self.audio_format``.
         return_stft: if truthy, build a plain ``STFT``; otherwise build a
             ``TacotronSTFT`` (mel) transform.
         mel_fmin, mel_fmax: mel frequency range for ``TacotronSTFT``. These
             were previously hard-coded; the defaults keep the old values
             (0.0 and 8000.0). They are ignored when ``return_stft`` is truthy.

     The global ``random`` module is reseeded with 1234 before shuffling, so
     the file order is reproducible.
     """
     self.audio_files = files_to_list(training_files)
     random.seed(1234)
     random.shuffle(self.audio_files)
     self.return_stft = return_stft
     if self.return_stft:
         self.stft = STFT(filter_length=filter_length,
                          hop_length=hop_length,
                          win_length=win_length)
     else:
         self.stft = TacotronSTFT(filter_length=filter_length,
                                  hop_length=hop_length,
                                  win_length=win_length,
                                  sampling_rate=sampling_rate,
                                  mel_fmin=mel_fmin,
                                  mel_fmax=mel_fmax)
     self.segment_length = segment_length
     self.sampling_rate = sampling_rate
     self.data_folder = data_folder
     self.audio_format = audio_format
Exemplo n.º 5
0
    def __init__(self,
                 training_files,
                 segment_length,
                 filter_length,
                 hop_length,
                 win_length,
                 sampling_rate,
                 mel_fmin,
                 mel_fmax,
                 debug=False):
        """Build the mel STFT and keep only audio files at least one segment long.

        Every path from ``files_to_list(training_files)`` is kept when
        ``duration(path) >= segment_length``. The order is preserved; no
        shuffling happens here.

        NOTE(review): this assumes ``duration`` returns the same unit as
        ``segment_length`` (samples vs. seconds); confirm against ``duration``.
        """
        self.stft = TacotronSTFT(
            filter_length=filter_length,
            hop_length=hop_length,
            win_length=win_length,
            sampling_rate=sampling_rate,
            mel_fmin=mel_fmin,
            mel_fmax=mel_fmax,
        )
        self.segment_length = segment_length
        self.sampling_rate = sampling_rate
        self.debug = debug

        self.audio_files = [
            path
            for path in files_to_list(training_files)
            if duration(path) >= self.segment_length
        ]
Exemplo n.º 6
0
 def __init__(self, filter_length, hop_length, win_length,
              sampling_rate, mel_fmin, mel_fmax):
     """Build the ``TacotronSTFT`` used for mel-spectrogram extraction.

     All arguments are forwarded unchanged, by keyword, to ``TacotronSTFT``.
     """
     stft_params = dict(
         filter_length=filter_length,
         hop_length=hop_length,
         win_length=win_length,
         sampling_rate=sampling_rate,
         mel_fmin=mel_fmin,
         mel_fmax=mel_fmax,
     )
     self.stft = TacotronSTFT(**stft_params)
Exemplo n.º 7
0
 def __init__(self, training_files, segment_length, filter_length,
              hop_length, win_length, sampling_rate):
     """Load and reproducibly shuffle the training list, then build the mel STFT.

     The mel frequency range and any other unspecified settings are left at
     ``TacotronSTFT``'s defaults. The global ``random`` module is reseeded
     with 1234 before shuffling.
     """
     file_list = files_to_list(training_files)
     self.audio_files = file_list
     random.seed(1234)
     random.shuffle(file_list)  # same list object as self.audio_files
     self.stft = TacotronSTFT(
         filter_length=filter_length,
         hop_length=hop_length,
         win_length=win_length,
         sampling_rate=sampling_rate,
     )
     self.sampling_rate = sampling_rate
     self.segment_length = segment_length
Exemplo n.º 8
0
 def __init__(self, training_files, segment_length, filter_length,
              hop_length, win_length, sampling_rate, mel_fmin, mel_fmax):
     """Load the training list, build the mel STFT and store ``self.pack()``.

     The global ``random`` module is reseeded with 1234, but the file list
     itself is NOT shuffled here. Any randomness inside ``pack()`` (not
     visible here) would therefore be reproducible.
     """
     self.audio_files = files_to_list(training_files)
     random.seed(1234)
     self.stft = TacotronSTFT(
         filter_length=filter_length,
         hop_length=hop_length,
         win_length=win_length,
         sampling_rate=sampling_rate,
         mel_fmin=mel_fmin,
         mel_fmax=mel_fmax,
     )
     self.sampling_rate = sampling_rate
     self.segment_length = segment_length
     # pack() runs last because it may read any of the attributes set above.
     self.dataset = self.pack()
Exemplo n.º 9
0
 def __init__(self, training_files, segment_length, filter_length,
              hop_length, win_length, sampling_rate, mel_fmin, mel_fmax, checkpoint_store_command):
     """Load and reproducibly shuffle the training list, then build the mel STFT.

     ``checkpoint_store_command`` is accepted but not used or stored by this
     constructor. NOTE(review): presumably it is here only to absorb a config
     key passed via ``**kwargs``; confirm.
     """
     file_list = files_to_list(training_files)
     random.seed(1234)
     random.shuffle(file_list)
     self.audio_files = file_list
     self.stft = TacotronSTFT(
         filter_length=filter_length,
         hop_length=hop_length,
         win_length=win_length,
         sampling_rate=sampling_rate,
         mel_fmin=mel_fmin,
         mel_fmax=mel_fmax,
     )
     self.sampling_rate = sampling_rate
     self.segment_length = segment_length
Exemplo n.º 10
0
 def __init__(self, data_path, valid, segment_length, filter_length,
              hop_length, win_length, sampling_rate, mel_fmin, mel_fmax):
     """Load and reproducibly shuffle the file list at ``data_path`` and build the mel STFT.

     ``valid`` is stored unchanged on ``self.valid``. The global ``random``
     module is reseeded with 1234 before shuffling.
     """
     file_list = files_to_list(data_path)
     self.audio_files = file_list
     self.valid = valid
     random.seed(1234)
     random.shuffle(file_list)
     stft_params = dict(
         filter_length=filter_length,
         hop_length=hop_length,
         win_length=win_length,
         sampling_rate=sampling_rate,
         mel_fmin=mel_fmin,
         mel_fmax=mel_fmax,
     )
     self.stft = TacotronSTFT(**stft_params)
     self.sampling_rate = sampling_rate
     self.segment_length = segment_length
Exemplo n.º 11
0
 def __init__(self, training_files, segment_length, filter_length,
              hop_length, win_length, sampling_rate, mel_fmin, mel_fmax):
     """Load the training list, set per-file segment positions and build the mel STFT.

     ``self.audio_files_segment_pos`` gets one entry of ``-1`` per file.
     NOTE(review): -1 presumably means "no segment chosen yet"; confirm where
     it is read. Because every entry is identical, shuffling the file list
     afterwards does not misalign it.
     """
     self.audio_files = files_to_list(training_files)
     self.audio_files_segment_pos = [-1] * len(self.audio_files)
     random.seed(1234)
     random.shuffle(self.audio_files)
     self.stft = TacotronSTFT(
         filter_length=filter_length,
         hop_length=hop_length,
         win_length=win_length,
         sampling_rate=sampling_rate,
         mel_fmin=mel_fmin,
         mel_fmax=mel_fmax,
     )
     self.sampling_rate = sampling_rate
     self.segment_length = segment_length
Exemplo n.º 12
0
    def __init__(self, training_files, segment_length, mu_quantization,
                 filter_length, hop_length, win_length, sampling_rate):
        """Load and reproducibly shuffle the training list and build the STFT.

        ``mu_quantization`` is stored unchanged on ``self.mu_quantization``.
        The mel frequency range is left at ``TacotronSTFT``'s defaults.
        """
        self.audio_files = utils.files_to_list(training_files)
        random.seed(1234)
        random.shuffle(self.audio_files)

        stft_params = dict(
            filter_length=filter_length,
            hop_length=hop_length,
            win_length=win_length,
            sampling_rate=sampling_rate,
        )
        self.stft = TacotronSTFT(**stft_params)

        self.segment_length = segment_length
        self.sampling_rate = sampling_rate
        self.mu_quantization = mu_quantization
Exemplo n.º 13
0
    def __init__(self, training_files, validation_files, validation_windows,
                 segment_length, filter_length, hop_length, win_length,
                 sampling_rate, mel_fmin, mel_fmax, load_mel_from_disk,
                 preempthasis):
        """Load the training list, drop unusable entries and build the 160-band mel STFT.

        Each entry from ``load_filepaths_and_text(training_files)`` has its
        audio path at index 0. An entry is dropped, with a printed message,
        when that path does not exist or when the audio loaded by
        ``load_wav_to_torch`` has ``segment_length`` samples or fewer.

        ``validation_files`` and ``validation_windows`` are accepted but not used
        here. ``self.preempthasise`` is only set when ``preempthasis`` is truthy.
        """
        audio_files = load_filepaths_and_text(training_files)

        print("Files before checking: ", len(audio_files))

        # Build a new list in one pass. The previous version called
        # list.remove() inside an index-offset loop, which is O(n) per removal
        # (quadratic overall) and needed fragile index bookkeeping.
        kept_files = []
        for file in audio_files:
            if not os.path.exists(file[0]):
                print(file[0], "does not exist")
                continue

            audio_data, _ = load_wav_to_torch(file[0])
            if audio_data.size(0) <= segment_length:
                print(file[0], "is too short")
                continue

            kept_files.append(file)
        self.audio_files = kept_files

        print("Files after checking: ", len(self.audio_files))

        self.load_mel_from_disk = load_mel_from_disk
        self.speaker_ids = self.create_speaker_lookup_table(self.audio_files)
        random.seed(1234)
        random.shuffle(self.audio_files)
        self.stft = TacotronSTFT(filter_length=filter_length,
                                 hop_length=hop_length,
                                 win_length=win_length,
                                 sampling_rate=sampling_rate,
                                 n_mel_channels=160,
                                 mel_fmin=mel_fmin,
                                 mel_fmax=mel_fmax)
        if preempthasis:
            self.preempthasise = PreEmphasis(preempthasis)
        self.segment_length = segment_length
        self.sampling_rate = sampling_rate
        self.hop_length = hop_length
        self.win_length = win_length
Exemplo n.º 14
0
 def __init__(
     self,
     sampling_rate,
     n_mel_channels,
     filter_length=1024,
     hop_length=256,
     win_length=1024,
     mel_fmin=0.0,
     mel_fmax=8000.0,
 ):
     """Initialise the base class and build the wrapped ``TacotronSTFT``.

     ``sampling_rate`` and ``n_mel_channels`` go to the parent constructor.
     Every argument is also forwarded by keyword to ``TacotronSTFT``, which
     is stored on ``self.taco_stft``.
     """
     super(Tacotron, self).__init__(sampling_rate, n_mel_channels)
     stft_params = dict(
         filter_length=filter_length,
         hop_length=hop_length,
         win_length=win_length,
         sampling_rate=sampling_rate,
         n_mel_channels=n_mel_channels,
         mel_fmin=mel_fmin,
         mel_fmax=mel_fmax,
     )
     self.taco_stft = TacotronSTFT(**stft_params)
Exemplo n.º 15
0
    def __init__(self, training_files, segment_length, filter_length,
                 hop_length, win_length, sampling_rate, mel_fmin, mel_fmax, num_workers,
                 use_multi_speaker, speaker_embedding_path, use_speaker_embedding_model):
        """Load and shuffle the training list, build the mel STFT and store speaker settings.

        When ``use_speaker_embedding_model`` is falsy, ``self.spk_id_map`` is
        loaded with ``pickle`` from ``speaker_embedding_path``. Otherwise
        ``self.spk_id_map`` is not set here.
        """
        self.audio_files = files_to_list(training_files)

        random.seed(1234)
        random.shuffle(self.audio_files)
        self.stft = TacotronSTFT(filter_length=filter_length,
                                 hop_length=hop_length,
                                 win_length=win_length,
                                 sampling_rate=sampling_rate,
                                 mel_fmin=mel_fmin, mel_fmax=mel_fmax)
        self.segment_length = segment_length
        self.sampling_rate = sampling_rate
        self.num_workers = num_workers
        self.use_multi_speaker = use_multi_speaker
        self.speaker_embedding_path = speaker_embedding_path
        self.use_speaker_embedding_model = use_speaker_embedding_model
        if not self.use_speaker_embedding_model:
            # Use a context manager so the file handle is closed deterministically;
            # the previous pickle.load(open(...)) leaked it.
            # NOTE(review): pickle.load can execute arbitrary code. Only point
            # speaker_embedding_path at trusted files.
            with open(self.speaker_embedding_path, "rb") as f:
                self.spk_id_map = pickle.load(f)
def data_generator(data_base, chunk_length_in_sec, label_order_list):
    """Endlessly yield one batch of mel spectrograms per database key.

    For every key, all of its time ranges are loaded via ``data_base.get_wav``
    and cut into whole chunks of ``chunk_length_in_sec`` seconds; a trailing
    partial chunk is discarded. Each chunk becomes a mel spectrogram, and each
    mel channel has its mean over time removed.

    Yields:
        ``(x, y)`` numpy arrays per key. ``x`` stacks the chunk spectrograms and
        ``y`` holds one class id (``label_order_list.index(key)``) per chunk.
        A key that produced no chunks yields ``(np.array([]), np.array([]))``.
        After each full pass over the keys, ``(None, None)`` is yielded as a
        separator.
    """
    stft = TacotronSTFT(**SFT_CONFIG)
    keys = data_base.get_db_keys()
    mel_batch, label_batch = [], []

    while True:
        for key in keys:
            class_id = np.array([label_order_list.index(key)])
            for time_range in data_base.get_db_wv_times(key):
                sampling_rate, speech = data_base.get_wav(key, *time_range)
                n_chunks = int(len(speech) / sampling_rate / chunk_length_in_sec)
                chunk_len = sampling_rate * chunk_length_in_sec
                for chunk_idx in range(n_chunks):
                    start = chunk_idx * chunk_len
                    segment = speech[start:start + chunk_len]
                    waveform = torch.from_numpy(segment / MAX_WAV_VALUE).float()
                    waveform = torch.autograd.Variable(waveform.unsqueeze(0),
                                                       requires_grad=False)
                    mel_np = stft.mel_spectrogram(waveform).detach().numpy()
                    # Each row is a view into mel_np, so this normalises in place.
                    for channel in mel_np[0]:
                        channel -= np.mean(channel)

                    mel_batch.append(mel_np)
                    label_batch.append(class_id)

            if mel_batch and label_batch:
                yield np.array(mel_batch), np.concatenate(np.array(label_batch))
            else:
                yield np.array([]), np.array([])

            mel_batch, label_batch = [], []

        yield None, None
Exemplo n.º 17
0
    def __init__(self, training_files, segment_length, filter_length,
                 hop_length, win_length, sampling_rate, mel_fmin, mel_fmax):
        """Load the training list, build the mel STFT and bin-pack the data.

        ``segment_length`` is used as the bin capacity ``self.max_time``. If it
        is 0, the capacity is chosen automatically: each candidate in
        ``range(250000, 1000000, 10000)`` is tried, and the one with the highest
        ``mean(self.volumes) / max_time`` is kept.

        After packing, ``self.volumes`` and ``self.balancer`` are shuffled with
        the same permutation so they stay aligned. NOTE(review): both are
        presumably populated by ``do_binpacking()`` (not visible here).
        """
        self.audio_files = files_to_list(training_files)
        random.seed(1234)
        random.shuffle(self.audio_files)

        self.stft = TacotronSTFT(filter_length=filter_length,
                                 hop_length=hop_length,
                                 win_length=win_length,
                                 sampling_rate=sampling_rate,
                                 mel_fmin=mel_fmin, mel_fmax=mel_fmax)
        self.segment_length = segment_length
        self.sampling_rate = sampling_rate
        self.everything = self.pack()

        self.max_time = self.segment_length

        if self.max_time == 0:  # auto configuration for maximum efficiency
            best = -1
            score = 0.0
            for candidate in range(250000, 1000000, 10000):
                self.max_time = candidate
                self.do_binpacking()

                utilized = np.asarray(self.volumes).mean() / self.max_time
                if utilized > score:
                    score = utilized
                    best = candidate
            # This assignment previously ran unconditionally, overwriting any
            # non-zero segment_length with the sentinel -1.
            self.max_time = best

        self.do_binpacking()

        # Shuffle bins with one permutation so volumes and balancer stay aligned.
        perm = list(range(len(self.balancer)))
        random.shuffle(perm)
        self.volumes = [self.volumes[p] for p in perm]
        self.balancer = [self.balancer[p] for p in perm]
Exemplo n.º 18
0
    def __init__(self, training_files, segment_length, filter_length,
                 hop_length, win_length, sampling_rate, mel_fmin, mel_fmax):
        """Load and reproducibly shuffle the training list, then build the mel STFT.

        Audio files shorter than ``segment_length`` are NOT filtered out here.
        The global ``random`` module is reseeded with 1234 before shuffling.
        """
        file_list = files_to_list(training_files)
        self.audio_files = file_list

        random.seed(1234)
        random.shuffle(file_list)
        stft_params = dict(
            filter_length=filter_length,
            hop_length=hop_length,
            win_length=win_length,
            sampling_rate=sampling_rate,
            mel_fmin=mel_fmin,
            mel_fmax=mel_fmax,
        )
        self.stft = TacotronSTFT(**stft_params)
        self.sampling_rate = sampling_rate
        self.segment_length = segment_length
Exemplo n.º 19
0
    def __init__(self,
                 training_files,
                 validation_files,
                 validation_windows,
                 segment_length,
                 filter_length,
                 hop_length,
                 win_length,
                 sampling_rate,
                 mel_fmin,
                 mel_fmax,
                 load_mel_from_disk,
                 preempthasis,
                 check_files=False):
        """Load the training list, optionally filter it, and build the 160-band mel STFT.

        Each entry from ``load_filepaths_and_text(training_files)`` holds the
        audio path at index 0 and, when ``load_mel_from_disk > 0.0``, a
        spectrogram path at index 1.

        When ``check_files`` is truthy, entries are dropped if:

        * the audio path does not exist;
        * ``load_mel_from_disk > 0.0`` and the spectrogram path does not exist;
        * ``os.stat(path).st_size // 2 < segment_length``. This is a cheap
          length estimate that assumes 2 bytes per sample and does not
          subtract any file header.

        An ``AssertionError`` is raised if any filter step leaves the list empty.

        Entries whose audio path contains ``"SlicedDialogue"`` are appended 3
        more times (4x weighting) before the reproducible shuffle.
        ``validation_files`` and ``validation_windows`` are accepted but not used
        here. ``self.preempthasise`` is only set when ``preempthasis`` is truthy.
        """
        self.audio_files = load_filepaths_and_text(training_files)

        if check_files:
            # The previous per-file verbose loop sat behind `if True:` and could
            # never run; it has been removed.
            print("Files before checking: ", len(self.audio_files))

            # filter audio files that don't exist
            self.audio_files = [
                x for x in self.audio_files if os.path.exists(x[0])
            ]
            assert len(self.audio_files), "self.audio_files is empty"

            # filter spectrograms that don't exist
            if load_mel_from_disk > 0.0:
                self.audio_files = [
                    x for x in self.audio_files if os.path.exists(x[1])
                ]
                assert len(self.audio_files), "self.audio_files is empty"

            # filter audio files that are too short (size-based sample estimate)
            self.audio_files = [
                x for x in self.audio_files
                if (os.stat(x[0]).st_size // 2) >= segment_length
            ]
            assert len(self.audio_files), "self.audio_files is empty"
            print("Files after checking: ", len(self.audio_files))

        self.load_mel_from_disk = load_mel_from_disk
        self.speaker_ids = self.create_speaker_lookup_table(self.audio_files)

        # Apply weighting to MLP Datasets: three extra copies of each matching entry.
        duplicated_audiopaths = [
            x for x in self.audio_files if "SlicedDialogue" in x[0]
        ]
        self.audio_files.extend(duplicated_audiopaths * 3)

        random.seed(1234)
        random.shuffle(self.audio_files)
        self.stft = TacotronSTFT(filter_length=filter_length,
                                 hop_length=hop_length,
                                 win_length=win_length,
                                 sampling_rate=sampling_rate,
                                 n_mel_channels=160,
                                 mel_fmin=mel_fmin,
                                 mel_fmax=mel_fmax)
        if preempthasis:
            self.preempthasise = PreEmphasis(preempthasis)
        self.segment_length = segment_length
        self.sampling_rate = sampling_rate
        self.hop_length = hop_length
        self.win_length = win_length
Exemplo n.º 20
0
def train(num_gpus,
          rank,
          group_name,
          output_directory,
          epochs,
          learning_rate,
          sigma,
          loss_empthasis,
          iters_per_checkpoint,
          batch_size,
          seed,
          fp16_run,
          checkpoint_path,
          with_tensorboard,
          logdirname,
          datedlogdir,
          warm_start=False):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        init_distributed(rank, num_gpus, group_name, **dist_config)
    #=====END:   ADDED FOR DISTRIBUTED======

    global WaveGlow
    global WaveGlowLoss

    ax = True  # this is **really** bad coding practice :D
    if ax:
        from efficient_model_ax import WaveGlow
        from efficient_loss import WaveGlowLoss
    else:
        if waveglow_config[
                "yoyo"]:  # efficient_mode # TODO: Add to Config File
            from efficient_model import WaveGlow
            from efficient_loss import WaveGlowLoss
        else:
            from glow import WaveGlow, WaveGlowLoss

    criterion = WaveGlowLoss(sigma, loss_empthasis)
    model = WaveGlow(**waveglow_config).cuda()
    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)
    #=====END:   ADDED FOR DISTRIBUTED======
    STFT = [
        TacotronSTFT(filter_length=window,
                     hop_length=data_config['hop_length'],
                     win_length=window,
                     sampling_rate=data_config['sampling_rate'],
                     n_mel_channels=160,
                     mel_fmin=0,
                     mel_fmax=16000)
        for window in data_config['validation_windows']
    ]

    loader_STFT = TacotronSTFT(filter_length=data_config['filter_length'],
                               hop_length=data_config['hop_length'],
                               win_length=data_config['win_length'],
                               sampling_rate=data_config['sampling_rate'],
                               n_mel_channels=160,
                               mel_fmin=data_config['mel_fmin'],
                               mel_fmax=data_config['mel_fmax'])

    optimizer = "LAMB"
    optimizer_fused = True  # use Apex fused optimizer, should be identical to normal but slightly faster
    if optimizer_fused:
        from apex import optimizers as apexopt
        if optimizer == "Adam":
            optimizer = apexopt.FusedAdam(model.parameters(), lr=learning_rate)
        elif optimizer == "LAMB":
            optimizer = apexopt.FusedLAMB(model.parameters(),
                                          lr=learning_rate,
                                          max_grad_norm=1000)
    else:
        if optimizer == "Adam":
            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        elif optimizer == "LAMB":
            from lamb import Lamb as optLAMB
            optimizer = optLAMB(model.parameters(), lr=learning_rate)
            #import torch_optimizer as optim
            #optimizer = optim.Lamb(model.parameters(), lr=learning_rate)
            #raise# PyTorch doesn't currently include LAMB optimizer.

    if fp16_run:
        global amp
        from apex import amp
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level='O1',
                                          min_loss_scale=1.0)
    else:
        amp = None

    ## LEARNING RATE SCHEDULER
    if True:
        from torch.optim.lr_scheduler import ReduceLROnPlateau
        min_lr = 1e-5
        factor = 0.1**(
            1 / 5)  # amount to scale the LR by on Validation Loss plateau
        scheduler = ReduceLROnPlateau(optimizer,
                                      'min',
                                      factor=factor,
                                      patience=20,
                                      cooldown=2,
                                      min_lr=min_lr,
                                      verbose=True)
        print("ReduceLROnPlateau used as Learning Rate Scheduler.")
    else:
        scheduler = False

    # Load checkpoint if one exists
    iteration = 0
    if checkpoint_path != "":

        #warm_start = 0 # WARM START THE MODEL AND RESET ANY INVALID LAYERS

        model, optimizer, iteration, scheduler = load_checkpoint(
            checkpoint_path,
            model,
            optimizer,
            scheduler,
            fp16_run,
            warm_start=warm_start)
        iteration += 1  # next iteration is iteration + 1

    trainset = Mel2Samp(**data_config, check_files=True)
    speaker_lookup = trainset.speaker_ids
    # =====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        train_sampler = DistributedSampler(trainset, shuffle=True)
        shuffle = False
    else:
        train_sampler = None
        shuffle = True
    # =====END:   ADDED FOR DISTRIBUTED======
    train_loader = DataLoader(trainset,
                              num_workers=2,
                              shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=batch_size,
                              pin_memory=False,
                              drop_last=True)

    # Get shared output_directory ready
    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        print("output directory", output_directory)

    if with_tensorboard and rank == 0:
        from tensorboardX import SummaryWriter
        if datedlogdir:
            timestr = time.strftime("%Y_%m_%d-%H_%M_%S")
            log_directory = os.path.join(output_directory, logdirname, timestr)
        else:
            log_directory = os.path.join(output_directory, logdirname)
        logger = SummaryWriter(log_directory)

    moving_average = int(min(len(train_loader),
                             100))  # average loss over entire Epoch
    rolling_sum = StreamingMovingAverage(moving_average)
    start_time = time.time()
    start_time_single_batch = time.time()
    model.train()

    # best (averaged) training loss
    if os.path.exists(os.path.join(output_directory, "best_model") + ".txt"):
        best_model_loss = float(
            str(
                open(os.path.join(output_directory, "best_model") + ".txt",
                     "r",
                     encoding="utf-8").read()).split("\n")[0])
    else:
        best_model_loss = -6.20

    # best (validation) MSE on inferred spectrogram.
    if os.path.exists(
            os.path.join(output_directory, "best_val_model") + ".txt"):
        best_MSE = float(
            str(
                open(os.path.join(output_directory, "best_val_model") + ".txt",
                     "r",
                     encoding="utf-8").read()).split("\n")[0])
    else:
        best_MSE = 9e9

    epoch_offset = max(0, int(iteration / len(train_loader)))

    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print("{:,} total parameters in model".format(pytorch_total_params))
    pytorch_total_params = sum(p.numel() for p in model.parameters()
                               if p.requires_grad)
    print("{:,} trainable parameters.".format(pytorch_total_params))

    training = True
    while training:
        try:
            if rank == 0:
                epochs_iterator = tqdm(range(epoch_offset, epochs),
                                       initial=epoch_offset,
                                       total=epochs,
                                       smoothing=0.01,
                                       desc="Epoch",
                                       position=1,
                                       unit="epoch")
            else:
                epochs_iterator = range(epoch_offset, epochs)
            # ================ MAIN TRAINING LOOP! ===================
            for epoch in epochs_iterator:
                print(f"Epoch: {epoch}")
                if num_gpus > 1: train_sampler.set_epoch(epoch)

                if rank == 0:
                    iters_iterator = tqdm(enumerate(train_loader),
                                          desc=" Iter",
                                          smoothing=0,
                                          total=len(train_loader),
                                          position=0,
                                          unit="iter",
                                          leave=True)
                else:
                    iters_iterator = enumerate(train_loader)
                for i, batch in iters_iterator:
                    # run external code every iter, allows the run to be adjusted without restarts
                    if (i == 0 or iteration % param_interval == 0):
                        try:
                            with open("run_every_epoch.py") as f:
                                internal_text = str(f.read())
                                if len(internal_text) > 0:
                                    #code = compile(internal_text, "run_every_epoch.py", 'exec')
                                    ldict = {'iteration': iteration}
                                    exec(internal_text, globals(), ldict)
                                else:
                                    print(
                                        "No Custom code found, continuing without changes."
                                    )
                        except Exception as ex:
                            print(f"Custom code FAILED to run!\n{ex}")
                        globals().update(ldict)
                        locals().update(ldict)
                        if show_live_params:
                            print(internal_text)
                    if not iteration % 50:  # check actual learning rate every 20 iters (because I sometimes see learning_rate variable go out-of-sync with real LR)
                        learning_rate = optimizer.param_groups[0]['lr']
                    # Learning Rate Schedule
                    if custom_lr:
                        old_lr = learning_rate
                        if iteration < warmup_start:
                            learning_rate = warmup_start_lr
                        elif iteration < warmup_end:
                            learning_rate = (iteration - warmup_start) * (
                                (A_ + C_) - warmup_start_lr
                            ) / (
                                warmup_end - warmup_start
                            ) + warmup_start_lr  # learning rate increases from warmup_start_lr to A_ linearly over (warmup_end-warmup_start) iterations.
                        else:
                            if iteration < decay_start:
                                learning_rate = A_ + C_
                            else:
                                iteration_adjusted = iteration - decay_start
                                learning_rate = (
                                    A_ * (e**(-iteration_adjusted / B_))) + C_
                        assert learning_rate > -1e-8, "Negative Learning Rate."
                        if old_lr != learning_rate:
                            for param_group in optimizer.param_groups:
                                param_group['lr'] = learning_rate
                    else:
                        scheduler.patience = scheduler_patience
                        scheduler.cooldown = scheduler_cooldown
                        if override_scheduler_last_lr:
                            scheduler._last_lr = override_scheduler_last_lr
                            print(
                                "Scheduler last_lr overriden. scheduler._last_lr =",
                                scheduler._last_lr)
                        if override_scheduler_best:
                            scheduler.best = override_scheduler_best
                            print(
                                "Scheduler best metric overriden. scheduler.best =",
                                override_scheduler_best)

                    model.zero_grad()
                    mel, audio, speaker_ids = batch
                    mel = torch.autograd.Variable(mel.cuda(non_blocking=True))
                    audio = torch.autograd.Variable(
                        audio.cuda(non_blocking=True))
                    if waveglow_config['WN_config']['speaker_embed_dim'] > 0:
                        speaker_ids = speaker_ids.cuda(
                            non_blocking=True).long().squeeze(1)
                        outputs = model(mel, audio, speaker_ids)
                    else:
                        outputs = model(mel, audio, None)

                    loss = criterion(outputs)
                    if num_gpus > 1:
                        reduced_loss = reduce_tensor(loss.data,
                                                     num_gpus).item()
                    else:
                        reduced_loss = loss.item()

                    if iteration > 1e3 and (
                        (reduced_loss > LossExplosionThreshold) or
                        (math.isnan(reduced_loss))):
                        raise LossExplosion(
                            f"\n\n\nLOSS EXPLOSION EXCEPTION: Loss reached {reduced_loss} during iteration {iteration}.\n\n\n"
                        )

                    if fp16_run:
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()

                    grad_clip = False
                    grad_clip_thresh = 10000
                    if grad_clip:
                        if fp16_run:
                            grad_norm = torch.nn.utils.clip_grad_norm_(
                                amp.master_params(optimizer), grad_clip_thresh)
                        else:
                            grad_norm = torch.nn.utils.clip_grad_norm_(
                                model.parameters(), grad_clip_thresh)
                        is_overflow = math.isinf(grad_norm) or math.isnan(
                            grad_norm)
                    else:
                        is_overflow = False
                        grad_norm = 0.00001

                    optimizer.step()
                    if not is_overflow and with_tensorboard and rank == 0:
                        if (iteration % 100000 == 0):
                            # plot distribution of parameters
                            for tag, value in model.named_parameters():
                                tag = tag.replace('.', '/')
                                logger.add_histogram(tag,
                                                     value.data.cpu().numpy(),
                                                     iteration)
                        logger.add_scalar('training_loss', reduced_loss,
                                          iteration)
                        #logger.add_scalar('training_loss_exp', 500*(exp(reduced_loss)), iteration)
                        logger.add_scalar('training_loss_samples',
                                          reduced_loss, iteration * batch_size)
                        if (iteration % 20 == 0):
                            logger.add_scalar('learning.rate', learning_rate,
                                              iteration)
                        if (iteration % 10 == 0):
                            logger.add_scalar(
                                'duration', ((time.time() - start_time) / 10),
                                iteration)
                        start_time_single_batch = time.time()

                    average_loss = rolling_sum.process(reduced_loss)
                    if rank == 0:
                        if (iteration % 10 == 0):
                            tqdm.write(
                                "{} {}:  {:.3f}  {:.3f} {:08.3F} {:.8f}LR ({:.8f} Effective)  {:.2f}s/iter {:.4f}s/item"
                                .format(
                                    time.strftime("%H:%M:%S"), iteration,
                                    reduced_loss, average_loss,
                                    round(grad_norm, 3), learning_rate,
                                    min((grad_clip_thresh / grad_norm) *
                                        learning_rate, learning_rate),
                                    (time.time() - start_time) / 10,
                                    ((time.time() - start_time) / 10) /
                                    (batch_size * num_gpus)))
                            start_time = time.time()
                        else:
                            tqdm.write(
                                "{} {}:  {:.3f}  {:.3f} {:08.3F} {:.8f}LR ({:.8f} Effective)"
                                .format(
                                    time.strftime("%H:%M:%S"), iteration,
                                    reduced_loss, average_loss,
                                    round(grad_norm, 3), learning_rate,
                                    min((grad_clip_thresh / grad_norm) *
                                        learning_rate, learning_rate)))

                    if rank == 0 and (len(rolling_sum.values) >
                                      moving_average - 2):
                        if (average_loss +
                                best_model_margin) < best_model_loss:
                            checkpoint_path = os.path.join(
                                output_directory, "best_model")
                            try:
                                save_checkpoint(model, optimizer,
                                                learning_rate, iteration, amp,
                                                scheduler, speaker_lookup,
                                                checkpoint_path)
                            except KeyboardInterrupt:  # Avoid corrupting the model.
                                save_checkpoint(model, optimizer,
                                                learning_rate, iteration, amp,
                                                scheduler, speaker_lookup,
                                                checkpoint_path)
                            text_file = open((f"{checkpoint_path}.txt"),
                                             "w",
                                             encoding="utf-8")
                            text_file.write(
                                str(average_loss) + "\n" + str(iteration))
                            text_file.close()
                            best_model_loss = average_loss  #Only save the model if X better than the current loss.
                    if rank == 0 and ((iteration % iters_per_checkpoint == 0)
                                      or
                                      (os.path.exists(save_file_check_path))):
                        checkpoint_path = f"{output_directory}/waveglow_{iteration}"
                        save_checkpoint(model, optimizer, learning_rate,
                                        iteration, amp, scheduler,
                                        speaker_lookup, checkpoint_path)
                        start_time_single_batch = time.time()
                        if (os.path.exists(save_file_check_path)):
                            os.remove(save_file_check_path)

                    if (iteration % validation_interval == 0):
                        if rank == 0:
                            MSE, MAE = validate(
                                model, loader_STFT, STFT, logger, iteration,
                                data_config['validation_files'],
                                speaker_lookup, sigma, output_directory,
                                data_config)
                            if scheduler:
                                MSE = torch.tensor(MSE, device='cuda')
                                if num_gpus > 1:
                                    broadcast(MSE, 0)
                                scheduler.step(MSE.item())
                                if MSE < best_MSE:
                                    checkpoint_path = os.path.join(
                                        output_directory, "best_val_model")
                                    try:
                                        save_checkpoint(
                                            model, optimizer, learning_rate,
                                            iteration, amp, scheduler,
                                            speaker_lookup, checkpoint_path)
                                    except KeyboardInterrupt:  # Avoid corrupting the model.
                                        save_checkpoint(
                                            model, optimizer, learning_rate,
                                            iteration, amp, scheduler,
                                            speaker_lookup, checkpoint_path)
                                    text_file = open(
                                        (f"{checkpoint_path}.txt"),
                                        "w",
                                        encoding="utf-8")
                                    text_file.write(
                                        str(MSE.item()) + "\n" +
                                        str(iteration))
                                    text_file.close()
                                    best_MSE = MSE.item(
                                    )  #Only save the model if X better than the current loss.
                        else:
                            if scheduler:
                                MSE = torch.zeros(1, device='cuda')
                                broadcast(MSE, 0)
                                scheduler.step(MSE.item())
                        learning_rate = optimizer.param_groups[0][
                            'lr']  #check actual learning rate (because I sometimes see learning_rate variable go out-of-sync with real LR)
                    iteration += 1
            training = False  # exit the While loop

        except LossExplosion as ex:  # print Exception and continue from checkpoint. (turns out it takes < 4 seconds to restart like this, f*****g awesome)
            print(ex)  # print Loss
            if checkpoint_path == '':
                checkpoint_path = os.path.join(output_directory,
                                               "best_val_model")
            assert 'best_val_model' in checkpoint_path, "Automatic restarts require checkpoint set to best_val_model"
            model.eval()
            model, optimizer, iteration, scheduler = load_checkpoint(
                checkpoint_path, model, optimizer, scheduler, fp16_run)
            learning_rate = optimizer.param_groups[0]['lr']
            epoch_offset = max(0, int(iteration / len(train_loader)))
            model.train()
            iteration += 1
            pass  # and continue training.