Example 1
import math
import numpy as np
from torch.optim.lr_scheduler import ReduceLROnPlateau
# NOTE: LinearWarmup is assumed to be a project-local warmup wrapper (not shown in this snippet).

def linear_warmup_reduce_on_plateau(optimizer, batch_size, epochs, burnin, batches_per_epoch, load_epoch=-1, mode="min"):

    # grab lr from re-loaded optim
    init_lr = [group['lr'] for group in optimizer.param_groups]    

    # correct learning rate with batch size
    batch_correction = math.sqrt(batch_size)
    
    # compute number of total batches
    burnin_steps = int(burnin * batches_per_epoch)
    total_steps = int(epochs * batches_per_epoch)

    # convert epoch number to batches
    nbatches = load_epoch * batches_per_epoch if load_epoch > -1 else -1
    loading_lin = nbatches > -1

    # compute the loaded "epoch" cosine scheduler
    cos_batch = nbatches - burnin_steps if nbatches > -1 else -1
    cos_batch = cos_batch - 2 if cos_batch > -1 else -1 

    # init the post-warmup scheduler (T_max / eta_min are unused leftovers from a cosine-annealing variant)
    T_max = total_steps - burnin_steps
    eta_min = 0
    after_scheduler = ReduceLROnPlateau(optimizer,
                                        patience = 10,
                                        factor=1./np.sqrt(10),
                                        mode=mode)
    cos_loading = cos_batch > -1

    # init the linear warmup
    nbatches = nbatches - 1 if nbatches > 0 else -1
    scheduler = LinearWarmup(optimizer, burnin_steps, batch_correction,
                             after_scheduler, nbatches)

    # handle setting lr after reloading
    if loading_lin:
        for idx,group in enumerate(optimizer.param_groups):
            group['lr'] = init_lr[idx]

    if cos_loading:
        after_scheduler._last_lr = init_lr
        scheduler._last_lr = init_lr

    return scheduler
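
A minimal usage sketch for the factory above, assuming LinearWarmup is a project-local wrapper that is stepped once per training batch and hands control to the wrapped ReduceLROnPlateau after the burn-in; the model, optimizer and all hyperparameter values below are illustrative only.

import torch

model = torch.nn.Linear(16, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# build the warmup + plateau scheduler for a fresh run (load_epoch=-1)
scheduler = linear_warmup_reduce_on_plateau(
    optimizer,
    batch_size=64,
    epochs=100,
    burnin=5,                 # warm up over the first 5 epochs
    batches_per_epoch=500,
    load_epoch=-1,
    mode="min",
)

# Typical loop shape (illustrative, depends on the LinearWarmup interface):
#   per batch:  scheduler.step()                  # advances the warmup
#   per epoch:  after_scheduler.step(val_metric)  # ReduceLROnPlateau reacts to validation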
Example 2
def train(num_gpus, rank, group_name, output_directory, epochs, learning_rate,
          sigma, loss_empthasis, iters_per_checkpoint, batch_size, seed, fp16_run,
          checkpoint_path, with_tensorboard, logdirname, datedlogdir, warm_start=False, optimizer='ADAM', start_zero=False):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        init_distributed(rank, num_gpus, group_name, **dist_config)
    #=====END:   ADDED FOR DISTRIBUTED======
    
    global WaveGlow
    global WaveGlowLoss
    
    ax = True # this is **really** bad coding practice :D
    if ax:
        from efficient_model_ax import WaveGlow
        from efficient_loss import WaveGlowLoss
    else:
        if waveglow_config["yoyo"]: # efficient_mode # TODO: Add to Config File
            from efficient_model import WaveGlow
            from efficient_loss import WaveGlowLoss
        else:
            from glow import WaveGlow, WaveGlowLoss
    
    criterion = WaveGlowLoss(sigma, loss_empthasis)
    model = WaveGlow(**waveglow_config).cuda()
    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)
    #=====END:   ADDED FOR DISTRIBUTED======
    STFTs = [STFT.TacotronSTFT(filter_length=window,
                                 hop_length=data_config['hop_length'],
                                 win_length=window,
                                 sampling_rate=data_config['sampling_rate'],
                                 n_mel_channels=160,
                                 mel_fmin=0, mel_fmax=16000) for window in data_config['validation_windows']]
    
    loader_STFT = STFT.TacotronSTFT(filter_length=data_config['filter_length'],
                                 hop_length=data_config['hop_length'],
                                 win_length=data_config['win_length'],
                                 sampling_rate=data_config['sampling_rate'],
                                 n_mel_channels=data_config['n_mel_channels'] if 'n_mel_channels' in data_config.keys() else 160,
                                 mel_fmin=data_config['mel_fmin'], mel_fmax=data_config['mel_fmax'])
    
    #optimizer = "Adam"
    optimizer = optimizer.lower()
    optimizer_fused = False  # use the Apex fused optimizer; should be identical to the normal one but slightly faster, and only works on RTX cards
    if optimizer_fused:
        from apex import optimizers as apexopt
        if optimizer == "adam":
            optimizer = apexopt.FusedAdam(model.parameters(), lr=learning_rate)
        elif optimizer == "lamb":
            optimizer = apexopt.FusedLAMB(model.parameters(), lr=learning_rate, max_grad_norm=200)
    else:
        if optimizer == "adam":
            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        elif optimizer == "lamb":
            from lamb import Lamb as optLAMB
            optimizer = optLAMB(model.parameters(), lr=learning_rate)
            #import torch_optimizer as optim
            #optimizer = optim.Lamb(model.parameters(), lr=learning_rate)
            #raise# PyTorch doesn't currently include LAMB optimizer.
    
    if fp16_run:
        global amp
        from apex import amp
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    else:
        amp = None
    
    ## LEARNING RATE SCHEDULER
    if True:
        from torch.optim.lr_scheduler import ReduceLROnPlateau
        min_lr = 1e-8
        factor = 0.1**(1/5) # amount to scale the LR by on Validation Loss plateau
        scheduler = ReduceLROnPlateau(optimizer, 'min', factor=factor, patience=20, cooldown=2, min_lr=min_lr, verbose=True, threshold=0.0001, threshold_mode='abs')
        print("ReduceLROnPlateau used as Learning Rate Scheduler.")
    else: scheduler=False
    
    # Load checkpoint if one exists
    iteration = 0
    if checkpoint_path != "":
        model, optimizer, iteration, scheduler = load_checkpoint(checkpoint_path, model,
                                                      optimizer, scheduler, fp16_run, warm_start=warm_start)
        iteration += 1  # next iteration is iteration + 1
    if start_zero:
        iteration = 0
    
    trainset = Mel2Samp(**data_config, check_files=True)
    speaker_lookup = trainset.speaker_ids
    # =====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        train_sampler = DistributedSampler(trainset, shuffle=True)
        shuffle = False
    else:
        train_sampler = None
        shuffle = True
    # =====END:   ADDED FOR DISTRIBUTED======
    train_loader = DataLoader(trainset, num_workers=3, shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=batch_size,
                              pin_memory=False,
                              drop_last=True)
    
    # Get shared output_directory ready
    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        print("output directory", output_directory)
    
    if with_tensorboard and rank == 0:
        from tensorboardX import SummaryWriter
        if datedlogdir:
            timestr = time.strftime("%Y_%m_%d-%H_%M_%S")
            log_directory = os.path.join(output_directory, logdirname, timestr)
        else:
            log_directory = os.path.join(output_directory, logdirname)
        logger = SummaryWriter(log_directory)
    
    moving_average = int(min(len(train_loader), 100)) # average loss over the last 100 iters (or the whole epoch, if shorter)
    rolling_sum = StreamingMovingAverage(moving_average)
    start_time = time.time()
    start_time_iter = time.time()
    start_time_dekaiter = time.time()
    model.train()
    
    # best (averaged) training loss
    if os.path.exists(os.path.join(output_directory, "best_model")+".txt"):
        best_model_loss = float(str(open(os.path.join(output_directory, "best_model")+".txt", "r", encoding="utf-8").read()).split("\n")[0])
    else:
        best_model_loss = -6.20
    
    # best (validation) MSE on inferred spectrogram.
    if os.path.exists(os.path.join(output_directory, "best_val_model")+".txt"):
        best_MSE = float(str(open(os.path.join(output_directory, "best_val_model")+".txt", "r", encoding="utf-8").read()).split("\n")[0])
    else:
        best_MSE = 9e9
    
    epoch_offset = max(0, int(iteration / len(train_loader)))
    
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print("{:,} total parameters in model".format(pytorch_total_params))
    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("{:,} trainable parameters.".format(pytorch_total_params))
    
    print(f"Segment Length: {data_config['segment_length']:,}\nBatch Size: {batch_size:,}\nNumber of GPUs: {num_gpus:,}\nSamples/Iter: {data_config['segment_length']*batch_size*num_gpus:,}")
    
    training = True
    while training:
        try:
            if rank == 0:
                epochs_iterator = tqdm(range(epoch_offset, epochs), initial=epoch_offset, total=epochs, smoothing=0.01, desc="Epoch", position=1, unit="epoch")
            else:
                epochs_iterator = range(epoch_offset, epochs)
            # ================ MAIN TRAINING LOOP! ===================
            for epoch in epochs_iterator:
                print(f"Epoch: {epoch}")
                if num_gpus > 1:
                    train_sampler.set_epoch(epoch)
                
                if rank == 0:
                    iters_iterator = tqdm(enumerate(train_loader), desc=" Iter", smoothing=0, total=len(train_loader), position=0, unit="iter", leave=True)
                else:
                    iters_iterator = enumerate(train_loader)
                for i, batch in iters_iterator:
                    # run external code every iter, allows the run to be adjusted without restarts
                    if (i==0 or iteration % param_interval == 0):
                        try:
                            with open("run_every_epoch.py") as f:
                                internal_text = str(f.read())
                                if len(internal_text) > 0:
                                    #code = compile(internal_text, "run_every_epoch.py", 'exec')
                                    ldict = {'iteration': iteration, 'seconds_elapsed': time.time()-start_time}
                                    exec(internal_text, globals(), ldict)
                                else:
                                    print("No Custom code found, continuing without changes.")
                        except Exception as ex:
                            print(f"Custom code FAILED to run!\n{ex}")
                        globals().update(ldict)
                        locals().update(ldict)
                        if show_live_params:
                            print(internal_text)
                    if not iteration % 50: # check actual learning rate every 50 iters (because I sometimes see learning_rate variable go out-of-sync with real LR)
                        learning_rate = optimizer.param_groups[0]['lr']
                    # Learning Rate Schedule
                    if custom_lr:
                        old_lr = learning_rate
                        if iteration < warmup_start:
                            learning_rate = warmup_start_lr
                        elif iteration < warmup_end:
                            learning_rate = (iteration-warmup_start)*((A_+C_)-warmup_start_lr)/(warmup_end-warmup_start) + warmup_start_lr # learning rate increases from warmup_start_lr to A_ linearly over (warmup_end-warmup_start) iterations.
                        else:
                            if iteration < decay_start:
                                learning_rate = A_ + C_
                            else:
                                iteration_adjusted = iteration - decay_start
                                learning_rate = (A_*(e**(-iteration_adjusted/B_))) + C_
                        assert learning_rate > -1e-8, "Negative Learning Rate."
                        if old_lr != learning_rate:
                            for param_group in optimizer.param_groups:
                                param_group['lr'] = learning_rate
                    else:
                        scheduler.patience = scheduler_patience
                        scheduler.cooldown = scheduler_cooldown
                        if override_scheduler_last_lr:
                            scheduler._last_lr = override_scheduler_last_lr
                        if override_scheduler_best:
                            scheduler.best = override_scheduler_best
                        if override_scheduler_last_lr or override_scheduler_best:
                            print("scheduler._last_lr =", scheduler._last_lr, "scheduler.best =", scheduler.best, "  |", end='')
                    model.zero_grad()
                    mel, audio, speaker_ids = batch
                    mel = torch.autograd.Variable(mel.cuda(non_blocking=True))
                    audio = torch.autograd.Variable(audio.cuda(non_blocking=True))
                    speaker_ids = speaker_ids.cuda(non_blocking=True).long().squeeze(1)
                    outputs = model(mel, audio, speaker_ids)
                    
                    loss = criterion(outputs)
                    if num_gpus > 1:
                        reduced_loss = reduce_tensor(loss.data, num_gpus).item()
                    else:
                        reduced_loss = loss.item()
                    
                    if fp16_run:
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
                    
                    if (reduced_loss > LossExplosionThreshold) or (math.isnan(reduced_loss)):
                        model.zero_grad()
                        raise LossExplosion(f"\nLOSS EXPLOSION EXCEPTION ON RANK {rank}: Loss reached {reduced_loss} during iteration {iteration}.\n\n\n")
                    
                    if use_grad_clip:
                        if fp16_run:
                            grad_norm = torch.nn.utils.clip_grad_norm_(
                                amp.master_params(optimizer), grad_clip_thresh)
                        else:
                            grad_norm = torch.nn.utils.clip_grad_norm_(
                                model.parameters(), grad_clip_thresh)
                        if type(grad_norm) == torch.Tensor:
                            grad_norm = grad_norm.item()
                        is_overflow = math.isinf(grad_norm) or math.isnan(grad_norm)
                    else: is_overflow = False; grad_norm=0.00001
                    
                    optimizer.step()
                    if not is_overflow and rank == 0:
                        # get current Loss Scale of first optimizer
                        loss_scale = amp._amp_state.loss_scalers[0]._loss_scale if fp16_run else 32768
                        
                        if with_tensorboard:
                            if (iteration % 100000 == 0):
                                # plot distribution of parameters
                                for tag, value in model.named_parameters():
                                    tag = tag.replace('.', '/')
                                    logger.add_histogram(tag, value.data.cpu().numpy(), iteration)
                            logger.add_scalar('training_loss', reduced_loss, iteration)
                            logger.add_scalar('training_loss_samples', reduced_loss, iteration*batch_size)
                            if (iteration % 20 == 0):
                                logger.add_scalar('learning.rate', learning_rate, iteration)
                            if (iteration % 10 == 0):
                                logger.add_scalar('duration', ((time.time() - start_time_dekaiter)/10), iteration)
                        
                        average_loss = rolling_sum.process(reduced_loss)
                        if (iteration % 10 == 0):
                            tqdm.write("{} {}:  {:.3f} {:.3f}  {:.3f} {:08.3F} {:.8f}LR ({:.8f} Effective)  {:.2f}s/iter {:.4f}s/item".format(time.strftime("%H:%M:%S"), iteration, reduced_loss, average_loss, best_MSE, round(grad_norm,3), learning_rate, min((grad_clip_thresh/grad_norm)*learning_rate,learning_rate), (time.time() - start_time_dekaiter)/10, ((time.time() - start_time_dekaiter)/10)/(batch_size*num_gpus)))
                            start_time_dekaiter = time.time()
                        else:
                            tqdm.write("{} {}:  {:.3f} {:.3f}  {:.3f} {:08.3F} {:.8f}LR ({:.8f} Effective) {}LS".format(time.strftime("%H:%M:%S"), iteration, reduced_loss, average_loss, best_MSE, round(grad_norm,3), learning_rate, min((grad_clip_thresh/grad_norm)*learning_rate,learning_rate), loss_scale))
                        start_time_iter = time.time()
                    
                    if rank == 0 and (len(rolling_sum.values) > moving_average-2):
                        if (average_loss+best_model_margin) < best_model_loss:
                            checkpoint_path = os.path.join(output_directory, "best_model")
                            try:
                                save_checkpoint(model, optimizer, learning_rate, iteration, amp, scheduler, speaker_lookup,
                                            checkpoint_path)
                            except KeyboardInterrupt: # Avoid corrupting the model.
                                save_checkpoint(model, optimizer, learning_rate, iteration, amp, scheduler, speaker_lookup,
                                            checkpoint_path)
                            text_file = open((f"{checkpoint_path}.txt"), "w", encoding="utf-8")
                            text_file.write(str(average_loss)+"\n"+str(iteration))
                            text_file.close()
                            best_model_loss = average_loss # only update the tracked best loss when it improves by at least best_model_margin
                    if rank == 0 and iteration > 0 and ((iteration % iters_per_checkpoint == 0) or (os.path.exists(save_file_check_path))):
                        checkpoint_path = f"{output_directory}/waveglow_{iteration}"
                        save_checkpoint(model, optimizer, learning_rate, iteration, amp, scheduler, speaker_lookup,
                                        checkpoint_path)
                        if (os.path.exists(save_file_check_path)):
                            os.remove(save_file_check_path)
                    
                    if (iteration % validation_interval == 0):
                        if rank == 0:
                            MSE, MAE = validate(model, loader_STFT, STFTs, logger, iteration, data_config['validation_files'], speaker_lookup, sigma, output_directory, data_config)
                            if scheduler:
                                MSE = torch.tensor(MSE, device='cuda')
                                if num_gpus > 1:
                                    broadcast(MSE, 0)
                                scheduler.step(MSE.item())
                                if MSE < best_MSE:
                                    checkpoint_path = os.path.join(output_directory, "best_val_model")
                                    try:
                                        save_checkpoint(model, optimizer, learning_rate, iteration, amp, scheduler, speaker_lookup,
                                                    checkpoint_path)
                                    except KeyboardInterrupt: # Avoid corrupting the model.
                                        save_checkpoint(model, optimizer, learning_rate, iteration, amp, scheduler, speaker_lookup,
                                                    checkpoint_path)
                                    text_file = open((f"{checkpoint_path}.txt"), "w", encoding="utf-8")
                                    text_file.write(str(MSE.item())+"\n"+str(iteration))
                                    text_file.close()
                                    best_MSE = MSE.item() # only update the tracked best validation MSE when it improves
                        else:
                            if scheduler:
                                MSE = torch.zeros(1, device='cuda')
                                broadcast(MSE, 0)
                                scheduler.step(MSE.item())
                        learning_rate = optimizer.param_groups[0]['lr'] #check actual learning rate (because I sometimes see learning_rate variable go out-of-sync with real LR)
                    iteration += 1
            training = False # exit the While loop
        
        except LossExplosion as ex: # print the exception and resume from the best checkpoint (a restart like this takes under 4 seconds)
            print(ex) # print Loss
            checkpoint_path = os.path.join(output_directory, "best_model")
            assert os.path.exists(checkpoint_path), "best_model must exist for automatic restarts"
            
            # clearing VRAM for load checkpoint
            audio = mel = speaker_ids = loss = None
            torch.cuda.empty_cache()
            
            model.eval()
            model, optimizer, iteration, scheduler = load_checkpoint(checkpoint_path, model, optimizer, scheduler, fp16_run)
            learning_rate = optimizer.param_groups[0]['lr']
            epoch_offset = max(0, int(iteration / len(train_loader)))
            model.train()
            iteration += 1
            pass # and continue training.
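
The custom_lr branch in the loop above implements a warmup / hold / exponential-decay schedule whose knobs (warmup_start, warmup_end, decay_start, A_, B_, C_) are read from run_every_epoch.py at run time. A standalone sketch of that schedule, with purely illustrative default values, is:

from math import e

def custom_lr_schedule(iteration,
                       warmup_start=0, warmup_start_lr=1e-5, warmup_end=1000,
                       decay_start=40000, A_=1e-4, B_=20000, C_=0.0):
    """Linear warmup to A_+C_, hold until decay_start, then exponential decay towards C_."""
    if iteration < warmup_start:
        return warmup_start_lr
    if iteration < warmup_end:
        # linear ramp from warmup_start_lr up to A_+C_ over (warmup_end - warmup_start) iters
        return (iteration - warmup_start) * ((A_ + C_) - warmup_start_lr) \
               / (warmup_end - warmup_start) + warmup_start_lr
    if iteration < decay_start:
        return A_ + C_
    # exponential decay with time-constant B_ iterations, floored at C_
    return A_ * e ** (-(iteration - decay_start) / B_) + C_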
Example 3
def load_checkpoint(checkpoint_path, model, optimizer):
    assert os.path.isfile(checkpoint_path)
    print("Loading checkpoint '{}'".format(checkpoint_path))
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    
    state_dict = {k.replace("encoder_speaker_embedding.weight","encoder.encoder_speaker_embedding.weight"): v for k,v in checkpoint_dict['state_dict'].items()}
    model.load_state_dict(state_dict) # tmp for updating old models
    
    #model.load_state_dict(checkpoint_dict['state_dict']) # original
    
    if 'optimizer' in checkpoint_dict.keys(): optimizer.load_state_dict(checkpoint_dict['optimizer'])
    if 'amp' in checkpoint_dict.keys(): amp.load_state_dict(checkpoint_dict['amp'])
    if 'learning_rate' in checkpoint_dict.keys(): learning_rate = checkpoint_dict['learning_rate']
    #if 'hparams' in checkpoint_dict.keys(): hparams = checkpoint_dict['hparams']
    if 'best_validation_loss' in checkpoint_dict.keys(): best_validation_loss = checkpoint_dict['best_validation_loss']
    if 'average_loss' in checkpoint_dict.keys(): average_loss = checkpoint_dict['average_loss']
    if (start_from_checkpoints_from_zero):
        iteration = 0
    else:
        iteration = checkpoint_dict['iteration']
    print("Loaded checkpoint '{}' from iteration {}" .format(
        checkpoint_path, iteration))
    return model, optimizer, learning_rate, iteration, best_validation_loss


def save_checkpoint(model, optimizer, learning_rate, iteration, hparams, best_validation_loss, average_loss, speaker_id_lookup, filepath):
    from utils import load_filepaths_and_text
    tqdm.write("Saving model and optimizer state at iteration {} to {}".format(
        iteration, filepath))
    
    # get speaker names to ID
    speakerlist = load_filepaths_and_text(hparams.speakerlist)
    speaker_name_lookup = {x[2]: speaker_id_lookup[x[3]] for x in speakerlist}
    
    torch.save({'iteration': iteration,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'learning_rate': learning_rate,
                #'amp': amp.state_dict(),
                'hparams': hparams,
                'speaker_id_lookup': speaker_id_lookup,
                'speaker_name_lookup': speaker_name_lookup,
                'best_validation_loss': best_validation_loss,
                'average_loss': average_loss}, filepath)
    tqdm.write("Saving Complete")


def validate(model, criterion, valset, iteration, batch_size, n_gpus,
             collate_fn, logger, distributed_run, rank, val_teacher_force_till, val_p_teacher_forcing, teacher_force=1):
    """Handles all the validation scoring and printing"""
    model.eval()
    with torch.no_grad():
        val_sampler = DistributedSampler(valset) if distributed_run else None
        val_loader = DataLoader(valset, sampler=val_sampler, num_workers=1,
                                shuffle=False, batch_size=batch_size,
                                pin_memory=False, drop_last=True, collate_fn=collate_fn)
        if teacher_force == 1:
            val_teacher_force_till = 0
            val_p_teacher_forcing = 1.0
        elif teacher_force == 2:
            val_teacher_force_till = 0
            val_p_teacher_forcing = 0.0
        val_loss = 0.0
        diagonality = torch.zeros(1)
        avg_prob = torch.zeros(1)
        for i, batch in tqdm(enumerate(val_loader), desc="Validation", total=len(val_loader), smoothing=0): # i = index, batch = stuff in array[i]
            x, y = model.parse_batch(batch)
            y_pred = model(x, teacher_force_till=val_teacher_force_till, p_teacher_forcing=val_p_teacher_forcing)
            rate, prob = alignment_metric(x, y_pred)
            diagonality += rate
            avg_prob += prob
            loss, gate_loss = criterion(y_pred, y)
            if distributed_run:
                reduced_val_loss = reduce_tensor(loss.data, n_gpus).item()
            else:
                reduced_val_loss = loss.item()
            val_loss += reduced_val_loss
            # end forloop
        val_loss = val_loss / (i + 1)
        diagonality = (diagonality / (i + 1)).item()
        avg_prob = (avg_prob / (i + 1)).item()
        # end torch.no_grad()
    model.train()
    if rank == 0:
        tqdm.write("Validation loss {}: {:9f}  Average Max Attention: {:9f}".format(iteration, val_loss, avg_prob))
        #logger.log_validation(val_loss, model, y, y_pred, iteration)
        if True:#iteration != 0:
            if teacher_force == 1:
                logger.log_teacher_forced_validation(val_loss, model, y, y_pred, iteration, val_teacher_force_till, val_p_teacher_forcing, diagonality, avg_prob)
            elif teacher_force == 2:
                logger.log_infer(val_loss, model, y, y_pred, iteration, val_teacher_force_till, val_p_teacher_forcing, diagonality, avg_prob)
            else:
                logger.log_validation(val_loss, model, y, y_pred, iteration, val_teacher_force_till, val_p_teacher_forcing, diagonality, avg_prob)
    return val_loss


def calculate_global_mean(data_loader, global_mean_npy, hparams):
    if global_mean_npy and os.path.exists(global_mean_npy):
        global_mean = np.load(global_mean_npy)
        return to_gpu(torch.tensor(global_mean).half()) if hparams.fp16_run else to_gpu(torch.tensor(global_mean).float())
    sums = []
    frames = []
    print('calculating global mean...')
    for i, batch in tqdm(enumerate(data_loader), total=len(data_loader), smoothing=0.001):
        text_padded, input_lengths, mel_padded, gate_padded,\
            output_lengths, speaker_ids, torchmoji_hidden, preserve_decoder_states = batch
        # padded values are 0.
        sums.append(mel_padded.double().sum(dim=(0, 2)))
        frames.append(output_lengths.double().sum())
    global_mean = sum(sums) / sum(frames)
    global_mean = to_gpu(global_mean.half()) if hparams.fp16_run else to_gpu(global_mean.float())
    if global_mean_npy:
        np.save(global_mean_npy, global_mean.cpu().numpy())
    return global_mean


def train(output_directory, log_directory, checkpoint_path, warm_start, warm_start_force, n_gpus,
          rank, group_name, hparams):
    """Training and validation logging results to tensorboard and stdout

    Params
    ------
    output_directory (string): directory to save checkpoints
    log_directory (string) directory to save tensorboard logs
    checkpoint_path(string): checkpoint path
    n_gpus (int): number of gpus
    rank (int): rank of current gpu
    hparams (object): comma separated list of "name=value" pairs.
    """
    hparams.n_gpus = n_gpus
    hparams.rank = rank
    if hparams.distributed_run:
        init_distributed(hparams, n_gpus, rank, group_name)

    torch.manual_seed(hparams.seed)
    torch.cuda.manual_seed(hparams.seed)

    train_loader, valset, collate_fn, train_sampler, trainset = prepare_dataloaders(hparams)
    speaker_lookup = trainset.speaker_ids
    
    if hparams.drop_frame_rate > 0.:
        if rank != 0: # if global_mean not yet calculated, wait for main thread to do it
            while not os.path.exists(hparams.global_mean_npy): time.sleep(1)
        global_mean = calculate_global_mean(train_loader, hparams.global_mean_npy, hparams)
        hparams.global_mean = global_mean
    
    model = load_model(hparams)

    model.eval() # test if this is needed anymore

    learning_rate = hparams.learning_rate
    if hparams.Apex_optimizer: # the Apex fused optimizer is slightly faster with slightly more VRAM usage in my testing. Helps in both fp32 and fp16.
        from apex import optimizers as apexopt
        optimizer = apexopt.FusedAdam(model.parameters(), lr=learning_rate, weight_decay=hparams.weight_decay)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=hparams.weight_decay)

    if hparams.fp16_run:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O2')

    if hparams.distributed_run:
        model = apply_gradient_allreduce(model)

    criterion = Tacotron2Loss(hparams)

    logger = prepare_directories_and_logger(
        output_directory, log_directory, rank)

    # Load checkpoint if one exists
    best_validation_loss = 0.8 # used to see when "best_model" should be saved, default = 0.8, load_checkpoint will update to last best value.
    iteration = 0
    epoch_offset = 0
    if checkpoint_path is not None:
        if warm_start:
            model, iteration = warm_start_model(
                checkpoint_path, model, hparams.ignore_layers)
        elif warm_start_force:
            model, iteration = warm_start_force_model(
                checkpoint_path, model)
        else:
            model, optimizer, _learning_rate, iteration, best_validation_loss = load_checkpoint(
                checkpoint_path, model, optimizer)
            if hparams.use_saved_learning_rate:
                learning_rate = _learning_rate
            iteration += 1  # next iteration is iteration + 1
            epoch_offset = max(0, int(iteration / len(train_loader)))
        print('Model Loaded')
	
    ## LEARNING RATE SCHEDULER
    if True:
        from torch.optim.lr_scheduler import ReduceLROnPlateau
        min_lr = 1e-5
        factor = 0.1**(1/5) # amount to scale the LR by on Validation Loss plateau
        scheduler = ReduceLROnPlateau(optimizer, 'min', factor=factor, patience=20, cooldown=2, min_lr=min_lr, verbose=True)
        print("ReduceLROnPlateau used as (optional) Learning Rate Scheduler.")
    else: scheduler=False

    model.train()
    is_overflow = False
	
    validate_then_terminate = 0 # I use this for testing old models with new metrics
    if validate_then_terminate:
        val_loss = validate(model, criterion, valset, iteration,
            hparams.batch_size, n_gpus, collate_fn, logger,
            hparams.distributed_run, rank)
        raise Exception("Finished Validation")
    
    for param_group in optimizer.param_groups:
        param_group['lr'] = learning_rate
    
    rolling_loss = StreamingMovingAverage(min(int(len(train_loader)), 200))
    # ================ MAIN TRAINING LOOP! ===================
    for epoch in tqdm(range(epoch_offset, hparams.epochs), initial=epoch_offset, total=hparams.epochs, desc="Epoch:", position=1, unit="epoch"):
        tqdm.write("Epoch:{}".format(epoch))
        
        if hparams.distributed_run: # shuffles the train_loader when doing multi-gpu training
            train_sampler.set_epoch(epoch)
        start_time = time.time()
        # start iterating through the epoch
        for i, batch in tqdm(enumerate(train_loader), desc="Iter:  ", smoothing=0, total=len(train_loader), position=0, unit="iter"):
                    # run external code every iter, allows the run to be adjusted without restarts
                    if (i==0 or iteration % param_interval == 0):
                        try:
                            with open("run_every_epoch.py") as f:
                                internal_text = str(f.read())
                                if len(internal_text) > 0:
                                    #code = compile(internal_text, "run_every_epoch.py", 'exec')
                                    ldict = {'iteration': iteration}
                                    exec(internal_text, globals(), ldict)
                                else:
                                    print("No Custom code found, continuing without changes.")
                        except Exception as ex:
                            print(f"Custom code FAILED to run!\n{ex}")
                        globals().update(ldict)
                        locals().update(ldict)
                        if show_live_params:
                            print(internal_text)
                    if not iteration % 50: # check actual learning rate every 50 iters (because I sometimes see learning_rate variable go out-of-sync with real LR)
                        learning_rate = optimizer.param_groups[0]['lr']
                    # Learning Rate Schedule
                    if custom_lr:
                        old_lr = learning_rate
                        if iteration < warmup_start:
                            learning_rate = warmup_start_lr
                        elif iteration < warmup_end:
                            learning_rate = (iteration-warmup_start)*((A_+C_)-warmup_start_lr)/(warmup_end-warmup_start) + warmup_start_lr # learning rate increases from warmup_start_lr to A_ linearly over (warmup_end-warmup_start) iterations.
                        else:
                            if iteration < decay_start:
                                learning_rate = A_ + C_
                            else:
                                iteration_adjusted = iteration - decay_start
                                learning_rate = (A_*(e**(-iteration_adjusted/B_))) + C_
                        assert learning_rate > -1e-8, "Negative Learning Rate."
                        if old_lr != learning_rate:
                            for param_group in optimizer.param_groups:
                                param_group['lr'] = learning_rate
                    else:
                        scheduler.patience = scheduler_patience
                        scheduler.cooldown = scheduler_cooldown
                        if override_scheduler_last_lr:
                            scheduler._last_lr = override_scheduler_last_lr
                            print("Scheduler last_lr overriden. scheduler._last_lr =", scheduler._last_lr)
                        if override_scheduler_best:
                            scheduler.best = override_scheduler_best
                            print("Scheduler best metric overriden. scheduler.best =", override_scheduler_best)
Example 4
def train(num_gpus, rank, group_name, stage, output_directory, epochs, learning_rate, sigma, iters_per_checkpoint, batch_size, seed, fp16_run, checkpoint_path, with_tensorboard, logdirname, datedlogdir, warm_start=False, optimizer='ADAM', start_zero=False):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        init_distributed(rank, num_gpus, group_name, **dist_config)
    #=====END:   ADDED FOR DISTRIBUTED======
    
    from model import HiFiGAN, HiFiGANLoss
    criterion = HiFiGANLoss(**hifigan_config).cuda()
    model = HiFiGAN(**hifigan_config).cuda()
    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)
        if stage >= 2:
            criterion = apply_gradient_allreduce(criterion)
    #=====END:   ADDED FOR DISTRIBUTED======
    
    criterion, optimizer_d = get_optimizer(criterion, optimizer, fp16_run, optimizer_fused=True) if stage >= 2 else (criterion, None)
    model, optimizer = get_optimizer(model, optimizer, fp16_run, optimizer_fused=True)
    
    ## LEARNING RATE SCHEDULER
    if True:
        from torch.optim.lr_scheduler import ReduceLROnPlateau
        min_lr = 1e-8
        factor = 0.1**(1/5) # amount to scale the LR by on Validation Loss plateau
        scheduler = ReduceLROnPlateau(optimizer, 'min', factor=factor, patience=20, cooldown=2, min_lr=min_lr, verbose=True, threshold=0.0001, threshold_mode='abs')
        print("ReduceLROnPlateau used as Learning Rate Scheduler.")
    else: scheduler=False
    
    # Load checkpoint if one exists
    iteration = 0
    if checkpoint_path != "":
        model, optimizer, criterion, optimizer_d, iteration, scheduler = load_checkpoint(checkpoint_path, model,
                                                      optimizer, criterion, optimizer_d, scheduler, fp16_run, stage, warm_start=warm_start)
        iteration += 1  # next iteration is iteration + 1
    if start_zero:
        iteration = 0
    
    trainset = Mel2Samp(**data_config, check_files=True)
    speaker_lookup = trainset.speaker_ids
    # =====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        train_sampler = DistributedSampler(trainset, shuffle=True)
        shuffle = False
    else:
        train_sampler = None
        shuffle = True
    # =====END:   ADDED FOR DISTRIBUTED======
    train_loader = DataLoader(trainset, num_workers=3, shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=batch_size,
                              pin_memory=False,
                              drop_last=True)
    
    # Get shared output_directory ready
    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        print("output directory", output_directory)
    
    if with_tensorboard and rank == 0:
        from tensorboardX import SummaryWriter
        if datedlogdir:
            timestr = time.strftime("%Y_%m_%d-%H_%M_%S")
            log_directory = os.path.join(output_directory, logdirname, timestr)
        else:
            log_directory = os.path.join(output_directory, logdirname)
        logger = SummaryWriter(log_directory)
    
    moving_average = int(min(len(train_loader), 200)) # average loss over the last 200 iters (or the whole epoch, if shorter)
    rolling_sum = StreamingMovingAverage(moving_average)
    start_time = time.time()
    start_time_iter = time.time()
    start_time_dekaiter = time.time()
    model.train()
    
    # best (averaged) training loss
    if os.path.exists(os.path.join(output_directory, "best_model")+".txt"):
        best_model_loss = float(str(open(os.path.join(output_directory, "best_model")+".txt", "r", encoding="utf-8").read()).split("\n")[0])
    else:
        best_model_loss = 9e9
    
    # best (validation) MSE on inferred spectrogram.
    if os.path.exists(os.path.join(output_directory, "best_val_model")+".txt"):
        best_MSE = float(str(open(os.path.join(output_directory, "best_val_model")+".txt", "r", encoding="utf-8").read()).split("\n")[0])
    else:
        best_MSE = 9e9
    
    epoch_offset = max(0, int(iteration / len(train_loader)))
    
    print_params(model, name='generator')
    
    print(f"Segment Length: {data_config['segment_length']:,}\nBatch Size: {batch_size:,}\nNumber of GPUs: {num_gpus:,}\nSamples/Iter: {data_config['segment_length']*batch_size*num_gpus:,}")
    
    training = True
    while training:
        try:
            if rank == 0:
                epochs_iterator = tqdm(range(epoch_offset, epochs), initial=epoch_offset, total=epochs, smoothing=0.01, desc="Epoch", position=1, unit="epoch")
            else:
                epochs_iterator = range(epoch_offset, epochs)
            # ================ MAIN TRAINING LOOP! ===================
            for epoch in epochs_iterator:
                print(f"Epoch: {epoch}")
                if num_gpus > 1:
                    train_sampler.set_epoch(epoch)
                
                if rank == 0:
                    iters_iterator = tqdm(enumerate(train_loader), desc=" Iter", smoothing=0, total=len(train_loader), position=0, unit="iter", leave=True)
                else:
                    iters_iterator = enumerate(train_loader)
                for i, batch in iters_iterator:
                    # run external code every iter, allows the run to be adjusted without restarts
                    if (i==0 or iteration % param_interval == 0):
                        try:
                            with open("run_every_epoch.py") as f:
                                internal_text = str(f.read())
                                if len(internal_text) > 0:
                                    #code = compile(internal_text, "run_every_epoch.py", 'exec')
                                    ldict = {'iteration': iteration, 'seconds_elapsed': time.time()-start_time}
                                    exec(internal_text, globals(), ldict)
                                else:
                                    print("No Custom code found, continuing without changes.")
                        except Exception as ex:
                            print(f"Custom code FAILED to run!\n{ex}")
                        globals().update(ldict)
                        locals().update(ldict)
                        if show_live_params:
                            print(internal_text)
                    # Learning Rate Schedule
                    if custom_lr:
                        old_lr = learning_rate
                        if iteration < warmup_start:
                            learning_rate = warmup_start_lr
                        elif iteration < warmup_end:
                            learning_rate = (iteration-warmup_start)*((A_+C_)-warmup_start_lr)/(warmup_end-warmup_start) + warmup_start_lr # learning rate increases from warmup_start_lr to A_ linearly over (warmup_end-warmup_start) iterations.
                        else:
                            if iteration < decay_start:
                                learning_rate = A_ + C_
                            else:
                                iteration_adjusted = iteration - decay_start
                                learning_rate = (A_*(e**(-iteration_adjusted/B_))) + C_
                        assert learning_rate > -1e-8, "Negative Learning Rate."
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = learning_rate
                        if optimizer_d is not None:
                            for param_group in optimizer_d.param_groups:
                                param_group['lr'] = learning_rate*d_lr_scale
                    else:
                        scheduler.patience = scheduler_patience
                        scheduler.cooldown = scheduler_cooldown
                        if override_scheduler_last_lr:
                            scheduler._last_lr = override_scheduler_last_lr
                        if override_scheduler_best:
                            scheduler.best = override_scheduler_best
                        if override_scheduler_last_lr or override_scheduler_best:
                            print(f"scheduler._last_lr = {scheduler._last_lr} scheduler.best = {scheduler.best}  |", end='')
                    model.zero_grad()
                    noisy_audio, gt_audio, speaker_ids = batch
                    noisy_audio = torch.autograd.Variable(noisy_audio.cuda(non_blocking=True))
                    gt_audio = torch.autograd.Variable(gt_audio.cuda(non_blocking=True))
                    speaker_ids = speaker_ids.cuda(non_blocking=True).long().squeeze(1)
                    pred_audio = model(noisy_audio)#, speaker_ids)
                    
                    metrics = criterion(pred_audio, gt_audio, amp, model, optimizer, optimizer_d, num_gpus, use_grad_clip, grad_clip_thresh)
                    
                    if not metrics['is_overflow'] and rank == 0:
                        # get current Loss Scale of first optimizer
                        loss_scale = amp._amp_state.loss_scalers[0]._loss_scale if fp16_run else 32768
                        
                        if with_tensorboard:
                            if (iteration % 100000 == 0):
                                # plot distribution of parameters
                                for tag, value in model.named_parameters():
                                    tag = tag.replace('.', '/')
                                    logger.add_histogram(tag, value.data.cpu().numpy(), iteration)
                            for key, value in metrics.items():
                                if key not in ['is_overflow',]:
                                    logger.add_scalar(key, value, iteration)
                            if (iteration % 20 == 0):
                                logger.add_scalar('learning.rate', learning_rate, iteration)
                            if (iteration % 10 == 0):
                                logger.add_scalar('duration', ((time.time() - start_time_dekaiter)/10), iteration)
                        
                        logged_loss = metrics['g_train_loss'] if stage >= 2 else metrics['train_loss']
                        grad_norm = metrics['grad_norm']
                        average_loss = rolling_sum.process(logged_loss)
                        if (iteration % 10 == 0):
                            tqdm.write("{} {}:  {:.3f} {:.3f}  {:.3f} {:08.3F} {:.8f}LR ({:.8f} Effective)  {:.2f}s/iter {:.4f}s/item".format(time.strftime("%H:%M:%S"), iteration, logged_loss, average_loss, best_MSE, round(grad_norm,3), learning_rate, min((grad_clip_thresh/grad_norm)*learning_rate,learning_rate), (time.time() - start_time_dekaiter)/10, ((time.time() - start_time_dekaiter)/10)/(batch_size*num_gpus)))
                            start_time_dekaiter = time.time()
                        else:
                            tqdm.write("{} {}:  {:.3f} {:.3f}  {:.3f} {:08.3F} {:.8f}LR ({:.8f} Effective) {}LS".format(time.strftime("%H:%M:%S"), iteration, logged_loss, average_loss, best_MSE, round(grad_norm,3), learning_rate, min((grad_clip_thresh/grad_norm)*learning_rate,learning_rate), loss_scale))
                        start_time_iter = time.time()
                    
                    if rank == 0 and (len(rolling_sum.values) > moving_average-2):
                        if (average_loss+best_model_margin) < best_model_loss:
                            checkpoint_path = os.path.join(output_directory, "best_model")
                            try:
                                save_checkpoint(model, optimizer, criterion, optimizer_d, learning_rate, iteration, amp, scheduler, speaker_lookup, checkpoint_path)
                            except KeyboardInterrupt: # Avoid corrupting the model.
                                save_checkpoint(model, optimizer, criterion, optimizer_d, learning_rate, iteration, amp, scheduler, speaker_lookup, checkpoint_path)
                            text_file = open((f"{checkpoint_path}.txt"), "w", encoding="utf-8")
                            text_file.write(str(average_loss)+"\n"+str(iteration))
                            text_file.close()
                            best_model_loss = average_loss # only update the tracked best loss when it improves by at least best_model_margin
                    if rank == 0 and iteration > 0 and ((iteration % iters_per_checkpoint == 0) or (os.path.exists(save_file_check_path))):
                        checkpoint_path = f"{output_directory}/waveglow_{iteration}"
                        save_checkpoint(model, optimizer, criterion, optimizer_d, learning_rate, iteration, amp, scheduler, speaker_lookup, checkpoint_path)
                        if (os.path.exists(save_file_check_path)):
                            os.remove(save_file_check_path)
                    
                    if iteration%validation_interval == 0:
                        if rank == 0:
                            MSE, MAE = validate(model, trainset, logger, iteration, data_config['validation_files'], speaker_lookup, output_directory, data_config)
                            if scheduler:
                                MSE = torch.tensor(MSE, device='cuda')
                                if num_gpus > 1:
                                    broadcast(MSE, 0)
                                scheduler.step(MSE.item())
                                if MSE < best_MSE:
                                    checkpoint_path = os.path.join(output_directory, "best_val_model")
                                    try:
                                        save_checkpoint(model, optimizer, criterion, optimizer_d, learning_rate, iteration, amp, scheduler, speaker_lookup, checkpoint_path)
                                    except KeyboardInterrupt: # Avoid corrupting the model.
                                        save_checkpoint(model, optimizer, criterion, optimizer_d, learning_rate, iteration, amp, scheduler, speaker_lookup, checkpoint_path)
                                    text_file = open((f"{checkpoint_path}.txt"), "w", encoding="utf-8")
                                    text_file.write(str(MSE.item())+"\n"+str(iteration))
                                    text_file.close()
                                    best_MSE = MSE.item()
                        else:
                            if scheduler:
                                MSE = torch.zeros(1, device='cuda')
                                broadcast(MSE, 0)
                                scheduler.step(MSE.item())
                    iteration += 1
            training = False # exit the training While loop
        
        except LossExplosion as ex: # print the exception and resume from the best checkpoint (a restart like this takes under 4 seconds)
            print(ex) # print Loss
            checkpoint_path = os.path.join(output_directory, "best_model")
            assert os.path.exists(checkpoint_path), "best_model must exist for automatic restarts"
            
            # clearing VRAM for load checkpoint
            noisy_audio = gt_audio = pred_audio = speaker_ids = metrics = None
            torch.cuda.empty_cache()
            
            model.eval()
            model, optimizer, criterion, optimizer_d, iteration, scheduler = load_checkpoint(checkpoint_path, model, optimizer, criterion, optimizer_d, scheduler, fp16_run, stage)
            learning_rate = optimizer.param_groups[0]['lr']
            epoch_offset = max(0, int(iteration / len(train_loader)))
            model.train()
            iteration += 1
            pass # and continue training.
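
Examples 2 and 4 track the best rolling training loss and best validation MSE in a small "<checkpoint>.txt" sidecar written next to the checkpoint, so a restarted run can resume the comparison. A minimal sketch of that bookkeeping (save_checkpoint and best_model_margin are the names used above; everything else is illustrative):

import os

def read_best_metric(txt_path, default=9e9):
    # return the first line of the sidecar file as a float, or default if it is missing
    if os.path.exists(txt_path):
        with open(txt_path, "r", encoding="utf-8") as f:
            return float(f.read().split("\n")[0])
    return default

def write_best_metric(txt_path, value, iteration):
    # record the new best value and the iteration it was reached at
    with open(txt_path, "w", encoding="utf-8") as f:
        f.write(f"{value}\n{iteration}")

# illustrative use inside the loop:
# best_MSE = read_best_metric(os.path.join(output_directory, "best_val_model.txt"))
# if MSE < best_MSE:
#     save_checkpoint(...)
#     write_best_metric(os.path.join(output_directory, "best_val_model.txt"), MSE, iteration)
#     best_MSE = MSE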
Example 5
                 learning_rate = (iteration-warmup_start)*((A_+C_)-warmup_start_lr)/(warmup_end-warmup_start) + warmup_start_lr # learning rate increases from warmup_start_lr to A_ linearly over (warmup_end-warmup_start) iterations.
             else:
                 if iteration < decay_start:
                     learning_rate = A_ + C_
                 else:
                     iteration_adjusted = iteration - decay_start
                     learning_rate = (A_*(e**(-iteration_adjusted/B_))) + C_
             assert learning_rate > -1e-8, "Negative Learning Rate."
             if old_lr != learning_rate:
                 for param_group in optimizer.param_groups:
                     param_group['lr'] = learning_rate
         else:
             scheduler.patience = scheduler_patience
             scheduler.cooldown = scheduler_cooldown
             if override_scheduler_last_lr:
                 scheduler._last_lr = override_scheduler_last_lr
                 print("Scheduler last_lr overriden. scheduler._last_lr =", scheduler._last_lr)
             if override_scheduler_best:
                 scheduler.best = override_scheduler_best
                 print("Scheduler best metric overriden. scheduler.best =", override_scheduler_best)
 
 model.zero_grad()
 x, y = model.parse_batch(batch) # move batch to GPU
 y_pred = model(x, teacher_force_till=teacher_force_till, p_teacher_forcing=p_teacher_forcing, drop_frame_rate=drop_frame_rate)
 
 loss, gate_loss, attention_loss_item = criterion(y_pred, y)
 
 if hparams.distributed_run:
     reduced_loss = reduce_tensor(loss.data, n_gpus).item()
     reduced_gate_loss = reduce_tensor(gate_loss.data, n_gpus).item()
 else:
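
The fragment above, like the other loops, averages the loss across GPUs with reduce_tensor() when hparams.distributed_run is set. That helper is not shown in these snippets; the conventional all-reduce-and-average implementation (an assumption here, not necessarily the author's exact code) looks like this:

import torch.distributed as dist

def reduce_tensor(tensor, n_gpus):
    # sum the tensor across all ranks, then divide by the number of GPUs to get the mean
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= n_gpus
    return rt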
Example 6
def train(output_directory, log_directory, checkpoint_path, warm_start,
          warm_start_force, n_gpus, rank, group_name, hparams):
    torch.manual_seed(hparams.seed)
    torch.cuda.manual_seed(hparams.seed)
    #=====START: ADDED FOR DISTRIBUTED======
    if n_gpus > 1:
        init_distributed(rank, n_gpus, group_name, **dist_config)
    #=====END:   ADDED FOR DISTRIBUTED======

    model, criterion = getCore(hparams)

    #=====START: ADDED FOR DISTRIBUTED======
    if n_gpus > 1:
        model = apply_gradient_allreduce(model)
    #=====END:   ADDED FOR DISTRIBUTED======

    STFT = [
        TacotronSTFT(filter_length=window,
                     hop_length=hparams.hop_length,
                     win_length=window,
                     sampling_rate=hparams.sampling_rate,
                     n_mel_channels=160,
                     mel_fmin=hparams.mel_fmin,
                     mel_fmax=hparams.mel_fmax)
        for window in hparams.validation_windows
    ]

    optimizer = getOptimizer(model, hparams)

    if hparams.fp16_run:
        global amp
        from apex import amp
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=hparams.fp16_opt_level,
                                          min_loss_scale=2.0)
    else:
        amp = None

    # LEARNING RATE SCHEDULER
    if hparams.LRScheduler.lower() == "ReduceLROnPlateau".lower():
        from torch.optim.lr_scheduler import ReduceLROnPlateau
        min_lr = 1e-5
        factor = 0.1**(
            1 / 5)  # amount to scale the LR by on Validation Loss plateau
        scheduler = ReduceLROnPlateau(optimizer,
                                      'min',
                                      factor=factor,
                                      patience=20,
                                      cooldown=2,
                                      min_lr=min_lr,
                                      verbose=True)
        print("ReduceLROnPlateau used as Learning Rate Scheduler.")
    else:
        scheduler = None

    # Load checkpoint if one exists
    iteration = 0
    if checkpoint_path:
        model, optimizer, iteration, scheduler = load_checkpoint(
            warm_start, warm_start_force, checkpoint_path, model, optimizer,
            scheduler, hparams.fp16_run)
        iteration += 1  # next iteration is iteration + 1

    trainset = Mel2Samp(hparams)
    speaker_lookup = trainset.speaker_ids
    # =====START: ADDED FOR DISTRIBUTED======
    if n_gpus > 1:
        train_sampler = DistributedSampler(trainset, shuffle=True)
        shuffle = False
    else:
        train_sampler = None
        shuffle = True
    # =====END:   ADDED FOR DISTRIBUTED======
    train_loader = DataLoader(trainset,
                              num_workers=hparams.n_dataloader_workers,
                              shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=hparams.batch_size,
                              pin_memory=False,
                              drop_last=True)

    # Get shared output_directory ready
    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        print("output directory", output_directory)

    if rank == 0:
        from tensorboardX import SummaryWriter
        if False:  # dated and separated log dirs for each run
            timestr = time.strftime("%Y_%m_%d-%H_%M_%S")
            log_directory = os.path.join(output_directory, log_directory,
                                         timestr)
        else:
            log_directory = os.path.join(output_directory, log_directory)
        logger = SummaryWriter(log_directory)

    moving_average = int(min(len(train_loader),
                             100))  # average loss over 100 iters
    rolling_sum = StreamingMovingAverage(moving_average)
    start_time = time.time()
    start_time_single_batch = time.time()

    model.train()

    if os.path.exists(os.path.join(output_directory, "best_train_model")):
        best_model_loss = float(
            str(
                open(os.path.join(output_directory, "best_train_model") +
                     ".txt",
                     "r",
                     encoding="utf-8").read()).split("\n")[0])
    else:
        best_model_loss = -4.20
    if os.path.exists(os.path.join(output_directory, "best_val_model")):
        with open(os.path.join(output_directory, "best_val_model") + ".txt",
                  "r",
                  encoding="utf-8") as f:
            best_MSE = float(f.read().split("\n")[0])
    else:
        best_MSE = 9e9
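    # Best-model bookkeeping: each "best_*_model" checkpoint is paired with a
    # sidecar .txt file whose first line holds the metric and second line the
    # iteration; the defaults above are fallbacks when no previous best exists.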
    epoch_offset = max(0, int(iteration / len(train_loader)))

    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print("{:,} total parameters.".format(pytorch_total_params))
    pytorch_total_params = sum(p.numel() for p in model.parameters()
                               if p.requires_grad)
    print("{:,} trainable parameters.".format(pytorch_total_params))

    learning_rate = hparams.learning_rate
    # ================ MAIN TRAINING LOOP! ===================
    for epoch in get_progress_bar(range(epoch_offset, hparams.epochs),
                                  dict(initial=epoch_offset,
                                       total=hparams.epochs,
                                       smoothing=0.01,
                                       desc="Epoch",
                                       position=1,
                                       unit="epoch"),
                                  hparams,
                                  rank=rank):
        cprint(f"Epoch: {epoch}", b_tqdm=hparams.tqdm)
        if n_gpus > 1: train_sampler.set_epoch(epoch)

        for i, batch in get_progress_bar(enumerate(train_loader),
                                         dict(desc=" Iter",
                                              smoothing=0,
                                              total=len(train_loader),
                                              position=0,
                                              unit="iter",
                                              leave=True),
                                         hparams,
                                         rank=rank):
            # run external code periodically (every param_interval iters); allows the run to be adjusted without restarts
            if (i == 0 or iteration % param_interval == 0):
                internal_text = ""
                ldict = {'iteration': iteration}
                try:
                    with open("hparams_realtime.py") as f:
                        internal_text = str(f.read())
                        exec(internal_text, globals(), ldict)
                except Exception as ex:
                    cprint(f"Custom code FAILED to run!\n{ex}",
                           b_tqdm=hparams.tqdm)
                globals().update(ldict)
                locals().update(ldict)  # no effect on function locals in CPython; the globals() update is what actually applies
                if show_live_params:
                    cprint(internal_text, b_tqdm=hparams.tqdm)
            assert warmup_start <= iteration, "Current iteration less than warmup_start."
            # Learning Rate Schedule
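            # The custom schedule below is piecewise: linear warmup from
            # warmup_start_lr up to (A_ + C_) over [warmup_start, warmup_end),
            # a flat plateau at (A_ + C_) until decay_start, then exponential
            # decay A_ * exp(-(iter - decay_start) / B_) + C_ toward C_.
            # Hypothetical example: with A_=1e-4, C_=1e-6, B_=20000, the LR is
            # about A_/e + C_ (~3.8e-5) at 20000 iterations past decay_start.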
            if custom_lr:
                old_lr = learning_rate
                if iteration < warmup_end:
                    learning_rate = (iteration - warmup_start) * (
                        (A_ + C_) - warmup_start_lr
                    ) / (
                        warmup_end - warmup_start
                    ) + warmup_start_lr  # learning rate increases from warmup_start_lr to A_ linearly over (warmup_end-warmup_start) iterations.
                else:
                    if iteration < decay_start:
                        learning_rate = A_ + C_
                    else:
                        iteration_adjusted = iteration - decay_start
                        learning_rate = (A_ *
                                         (e**(-iteration_adjusted / B_))) + C_
                assert learning_rate > -1e-8, "Negative Learning Rate."
                if old_lr != learning_rate:
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = learning_rate
            else:
                scheduler.patience = scheduler_patience
                scheduler.cooldown = scheduler_cooldown
                if override_scheduler_last_lr:
                    scheduler._last_lr = override_scheduler_last_lr
                    cprint("Scheduler last_lr overriden. scheduler._last_lr =",
                           scheduler._last_lr,
                           b_tqdm=hparams.tqdm)
                if not iteration % 20:  # check actual learning rate every 20 iters (because I sometimes see learning_rate variable go out-of-sync with real LR)
                    learning_rate = optimizer.param_groups[0]['lr']
                if override_scheduler_best:
                    scheduler.best = override_scheduler_best
                    cprint("Scheduler best metric overriden. scheduler.best =",
                           override_scheduler_best,
                           b_tqdm=hparams.tqdm)

            model.zero_grad()
            mel, audio, speaker_ids = batch
            # torch.autograd.Variable is a no-op wrapper in modern PyTorch; move the batch to the GPU directly
            mel = mel.cuda(non_blocking=True)
            audio = audio.cuda(non_blocking=True)
            if model.multispeaker:
                speaker_ids = speaker_ids.cuda(non_blocking=True).long().squeeze(1)
                outputs = model(mel, audio, speaker_ids)
            else:
                outputs = model(mel, audio)

            loss = criterion(outputs)
            if n_gpus > 1:
                reduced_loss = reduce_tensor(loss.data, n_gpus).item()
            else:
                reduced_loss = loss.item()

            assert reduced_loss < 1e5, "Model Diverged. Loss > 1e5"
            if hparams.fp16_run:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if hparams.b_grad_clip:
                if hparams.fp16_run:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), hparams.grad_clip_thresh)
                else:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), hparams.grad_clip_thresh)
                is_overflow = math.isinf(grad_norm) or math.isnan(grad_norm)
            else:
                is_overflow = False
                grad_norm = 0.00001
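            # With clipping disabled, grad_norm is a tiny placeholder so the
            # "Effective LR" term in the log lines below, min(grad_clip_thresh /
            # grad_norm * learning_rate, learning_rate), never divides by zero.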

            optimizer.step()
            if not is_overflow and rank == 0:
                if (iteration % 100000 == 0):
                    # plot distribution of parameters
                    for tag, value in model.named_parameters():
                        tag = tag.replace('.', '/')
                        logger.add_histogram(tag,
                                             value.data.cpu().numpy(),
                                             iteration)
                logger.add_scalar('training_loss', reduced_loss, iteration)
                if (iteration % 20 == 0):
                    logger.add_scalar('learning.rate', learning_rate,
                                      iteration)
                if (iteration % 10 == 0):
                    logger.add_scalar('duration',
                                      ((time.time() - start_time) / 10),
                                      iteration)
                start_time_single_batch = time.time()

            average_loss = rolling_sum.process(reduced_loss)
            if rank == 0:
                if (iteration % 10 == 0):
                    cprint(
                        "{} {}:  {:.3f}  {:.3f} {:08.3F} {:.8f}LR ({:.8f} Effective)  {:.2f}s/iter {:.4f}s/item"
                        .format(
                            time.strftime("%H:%M:%S"), iteration, reduced_loss,
                            average_loss, round(grad_norm, 3), learning_rate,
                            min((hparams.grad_clip_thresh / grad_norm) *
                                learning_rate, learning_rate),
                            (time.time() - start_time) / 10,
                            ((time.time() - start_time) / 10) /
                            (hparams.batch_size * n_gpus)),
                        b_tqdm=hparams.tqdm)
                    start_time = time.time()
                else:
                    cprint(
                        "{} {}:  {:.3f}  {:.3f} {:08.3F} {:.8f}LR ({:.8f} Effective)"
                        .format(
                            time.strftime("%H:%M:%S"), iteration, reduced_loss,
                            average_loss, round(grad_norm, 3), learning_rate,
                            min((hparams.grad_clip_thresh / grad_norm) *
                                learning_rate, learning_rate)),
                        b_tqdm=hparams.tqdm)

            if rank == 0 and (len(rolling_sum.values) > moving_average - 2):
                if (average_loss + best_model_margin) < best_model_loss:
                    checkpoint_path = os.path.join(output_directory,
                                                   "best_train_model")
                    try:
                        save_checkpoint(model, optimizer, hparams,
                                        learning_rate, iteration, amp,
                                        scheduler, speaker_lookup,
                                        checkpoint_path)
                    except KeyboardInterrupt:  # Avoid corrupting the model.
                        save_checkpoint(model, optimizer, hparams,
                                        learning_rate, iteration, amp,
                                        scheduler, speaker_lookup,
                                        checkpoint_path)
                    text_file = open((f"{checkpoint_path}.txt"),
                                     "w",
                                     encoding="utf-8")
                    text_file.write(str(average_loss) + "\n" + str(iteration))
                    text_file.close()
                    best_model_loss = average_loss  #Only save the model if X better than the current loss.
            if rank == 0 and ((iteration % hparams.iters_per_checkpoint == 0)
                              or (os.path.exists(save_file_check_path))):
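                # save_file_check_path appears to act as a manual "save now"
                # trigger: creating that file forces an immediate checkpoint,
                # after which the file is removed.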
                checkpoint_path = f"{output_directory}/waveglow_{iteration}"
                save_checkpoint(model, optimizer, hparams, learning_rate,
                                iteration, amp, scheduler, speaker_lookup,
                                checkpoint_path)
                start_time_single_batch = time.time()
                if (os.path.exists(save_file_check_path)):
                    os.remove(save_file_check_path)

            if (iteration % validation_interval == 0):
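                # Validation runs on rank 0 only; the resulting MSE is broadcast
                # so every rank can step ReduceLROnPlateau with the same value
                # and the schedulers stay in sync. Note the `scheduler and
                # n_gpus > 1` guard below means a single-GPU run skips both the
                # scheduler step and the best-val checkpointing.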
                if rank == 0:
                    MSE, MAE = validate(model, STFT, logger, iteration,
                                        speaker_lookup, hparams,
                                        output_directory)
                    if scheduler and n_gpus > 1:
                        MSE = torch.tensor(MSE, device='cuda')
                        broadcast(MSE, 0)
                        scheduler.step(MSE.item())
                        if MSE < best_MSE:
                            checkpoint_path = os.path.join(
                                output_directory, "best_val_model")
                            try:
                                save_checkpoint(model, optimizer, hparams,
                                                learning_rate, iteration, amp,
                                                scheduler, speaker_lookup,
                                                checkpoint_path)
                            except KeyboardInterrupt:  # Avoid corrupting the model.
                                save_checkpoint(model, optimizer, hparams,
                                                learning_rate, iteration, amp,
                                                scheduler, speaker_lookup,
                                                checkpoint_path)
                            text_file = open((f"{checkpoint_path}.txt"),
                                             "w",
                                             encoding="utf-8")
                            text_file.write(
                                str(MSE.item()) + "\n" + str(iteration))
                            text_file.close()
                            best_MSE = MSE.item(
                            )  #Only save the model if X better than the current loss.
                else:
                    if scheduler:
                        MSE = torch.zeros(1, device='cuda')
                        broadcast(MSE, 0)
                        scheduler.step(MSE.item())
            iteration += 1