Example #1
def validate(hparams, args, file_losses, loss_scalars, model, criterion, valset, best_val_loss_dict, iteration,
             collate_fn, logger,):
    """Handles all the validation scoring and printing"""
    model.eval()
    with torch.no_grad():
        if len(valset.filelist) < hparams.n_gpus*hparams.batch_size:
            print(f'Too few files in the validation set! Found {len(valset.filelist)}, expected {hparams.n_gpus*hparams.batch_size} or more. If your dataset has a single speaker, setting "inference_equally_sample_speakers" to False in hparams.py *may* fix the issue.\nIf you have a small amount of data, increase `dataset_p_val` or decrease `val_batch_size`.')
        val_sampler = DistributedSampler(valset) if hparams.distributed_run else None
        val_loader = DataLoader(valset, sampler=val_sampler,
                                num_workers=hparams.val_num_workers,# prefetch_factor=hparams.prefetch_factor,
                                shuffle=False, batch_size=hparams.batch_size,
                                pin_memory=False, drop_last=False, collate_fn=collate_fn)
        
        loss_dict_total = None
        for i, batch in tqdm(enumerate(val_loader), desc="Validation", total=len(val_loader), smoothing=0): # i = index, batch = stuff in array[i]
            y = model.parse_batch(batch)
            with torch.random.fork_rng(devices=[0,]):
                torch.random.manual_seed(i)# use repeatable seeds during validation so results are more consistent and comparable.
                y['speaker_embeds'] = (y['parallel_speaker_embed'] + y['non_parallel_speaker_embed'])/2  # [B, embed]
                y_pred = force(model, valid_kwargs=model_args, **{**y,
                                    'c_org': y['speaker_embeds'],
                                    'c_trg': y['speaker_embeds'],})
            
            loss_dict, file_losses_batch = criterion(model, y_pred, y, loss_scalars)
            file_losses = update_smoothed_dict(file_losses, file_losses_batch, file_losses_smoothness)
            if loss_dict_total is None:
                loss_dict_total = {k: 0. for k in loss_dict.keys()}
            
            if hparams.distributed_run:
                reduced_loss_dict = {k: reduce_tensor(v.data, args.n_gpus).item() if v is not None else 0. for k, v in loss_dict.items()}
            else:
                reduced_loss_dict = {k: v.item() if v is not None else 0. for k, v in loss_dict.items()}
            reduced_loss = reduced_loss_dict['loss']
            
            for k in loss_dict_total.keys():
                loss_dict_total[k] = loss_dict_total[k] + reduced_loss_dict[k]
            # end forloop
        loss_dict_total = {k: v/(i+1) for k, v in loss_dict_total.items()}
        # end torch.no_grad()
    
    model.train()
    
    # update best losses
    if best_val_loss_dict is None:
        best_val_loss_dict = loss_dict_total
    else:
        best_val_loss_dict = {k: min(best_val_loss_dict[k], loss_dict_total[k]) for k in best_val_loss_dict.keys()}
    
    # print, log data and return.
    if args.rank == 0:
        tqdm.write(f"Validation loss {iteration}: {loss_dict_total['loss']:9f}")
        if iteration > 1:
            log_terms = (loss_dict_total, best_val_loss_dict, model, y, y_pred, iteration)
            logger.log_validation(*log_terms)
    
    return loss_dict_total['loss'], best_val_loss_dict, file_losses
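
The helper `update_smoothed_dict` used above is not included in this listing. A minimal sketch of the behaviour implied by its call sites (per-file loss metrics blended with an exponential moving average controlled by `file_losses_smoothness`) could look like the following; the real helper may track additional fields.

def update_smoothed_dict(file_losses, file_losses_batch, smoothness):
    """Blend per-file loss metrics into the running dict with an EMA (sketch, not the repo's exact helper)."""
    for path, metrics in file_losses_batch.items():
        if path not in file_losses:
            file_losses[path] = dict(metrics)  # first time this file is seen: copy its metrics as-is
        else:
            for k, new_val in metrics.items():
                old_val = file_losses[path].get(k, new_val)
                file_losses[path][k] = old_val*smoothness + new_val*(1. - smoothness)  # EMA update
    return file_losses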
Example #2
def train(args, rank, group_name, hparams):
    """Training and validation logging results to tensorboard and stdout

    Params
    ------
    args.output_directory (string): directory to save checkpoints
    args.log_directory (string) directory to save tensorboard logs
    args.checkpoint_path(string): checkpoint path
    args.n_gpus (int): number of gpus
    rank (int): rank of current gpu
    hparams (object): comma separated list of "name=value" pairs.
    """
    # setup distributed
    hparams.n_gpus = args.n_gpus
    hparams.rank = rank
    if hparams.distributed_run:
        init_distributed(hparams, args.n_gpus, rank, group_name)
    
    # reproducibility seeds
    torch.manual_seed(hparams.seed)
    torch.cuda.manual_seed(hparams.seed)
    
    # initialize blank model
    print('Initializing Tacotron2...')
    model = load_model(hparams)
    print('Done')
    global model_args
    model_args = get_args(model.forward)
    model.eval()
    learning_rate = hparams.learning_rate
    
    # (optional) show the names of each layer in model, mainly makes it easier to copy/paste what you want to adjust
    if hparams.print_layer_names_during_startup:
        print(*[f"Layer{i} = "+str(x[0])+" "+str(x[1].shape) for i,x in enumerate(list(model.named_parameters()))], sep="\n")
    
    # (optional) Freeze layers by disabling grads
    if len(hparams.frozen_modules):
        for layer, params in list(model.named_parameters()):
            if any(layer.startswith(module) for module in hparams.frozen_modules):
                params.requires_grad = False
                print(f"Layer: {layer} has been frozen")
    
    if len(hparams.unfrozen_modules):
        for layer, params in list(model.named_parameters()):
            if any(layer.startswith(module) for module in hparams.unfrozen_modules):
                params.requires_grad = True
                print(f"Layer: {layer} has been unfrozen")
    
    # define optimizer (any params without requires_grad are ignored)
    #optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate, weight_decay=hparams.weight_decay)
    optimizer = apexopt.FusedAdam(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate, weight_decay=hparams.weight_decay)
    
    if True and rank == 0:
        pytorch_total_params = sum(p.numel() for p in model.parameters())
        print("{:,} total parameters in model".format(pytorch_total_params))
        pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print("{:,} trainable parameters.".format(pytorch_total_params))
    
    print("Initializing AMP Model / Optimzier")
    if hparams.fp16_run:
        model, optimizer = amp.initialize(model, optimizer, opt_level=f'O{hparams.fp16_run_optlvl}')
    
    print("Initializing Gradient AllReduce model wrapper.")
    if hparams.distributed_run:
        model = apply_gradient_allreduce(model)
    
    print("Initializing Tacotron2 Loss func.")
    criterion = Tacotron2Loss(hparams)
    
    print("Initializing Tacotron2 Logger.")
    logger = prepare_directories_and_logger(hparams, args)
    
    # Load checkpoint if one exists
    best_validation_loss = 1e3# used to see when "best_val_model" should be saved
    best_inf_attsc       = -99# used to see when "best_inf_attsc" should be saved
    
    n_restarts = 0
    checkpoint_iter = 0
    iteration = 0
    epoch_offset = 0
    _learning_rate = 1e-3
    saved_lookup = None
    original_filelist = None
    
    global file_losses
    file_losses = {}
    global file_losses_smoothness
    file_losses_smoothness = 0.6
    
    global best_val_loss_dict
    best_val_loss_dict = None
    global best_loss_dict
    best_loss_dict = None
    global expavg_loss_dict
    expavg_loss_dict = None
    expavg_loss_dict_iters = 0# initial iters expavg_loss_dict has been fitted
    loss_dict_smoothness = 0.95 # smoothing factor
    
    if args.checkpoint_path is not None:
        if args.warm_start:
            model, iteration, saved_lookup = warm_start_model(
                args.checkpoint_path, model, hparams.ignore_layers)
        elif args.warm_start_force:
            model, iteration, saved_lookup = warm_start_force_model(
                args.checkpoint_path, model)
        else:
            _ = load_checkpoint(args.checkpoint_path, model, optimizer, best_val_loss_dict, best_loss_dict)
            model, optimizer, _learning_rate, iteration, best_validation_loss, best_inf_attsc, saved_lookup, best_val_loss_dict, best_loss_dict = _
            if hparams.use_saved_learning_rate:
                learning_rate = _learning_rate
        checkpoint_iter = iteration
        iteration += 1  # next iteration is iteration + 1
        print('Model Loaded')
    
    # define datasets/dataloaders
    dataloader_args = [*get_args(criterion.forward), *model_args]
    if rank == 0:
        dataloader_args.extend(get_args(logger.log_training))
    train_loader, valset, collate_fn, train_sampler, trainset = prepare_dataloaders(hparams, dataloader_args, args, saved_lookup)
    epoch_offset = max(0, int(iteration / len(train_loader)))
    speaker_lookup = trainset.speaker_ids
    
    # load and/or generate global_mean
    if hparams.drop_frame_rate > 0.:
        if rank != 0: # if global_mean is not yet calculated, wait for the main thread to do it
            while not os.path.exists(hparams.global_mean_npy): time.sleep(1)
        global_mean = calculate_global_mean(train_loader, hparams.global_mean_npy, hparams)
        hparams.global_mean = global_mean
        model.global_mean = global_mean
    
    # define scheduler
    use_scheduler = 0
    if use_scheduler:
        scheduler = ReduceLROnPlateau(optimizer, factor=0.1**(1/5), patience=10)
    
    model.train()
    is_overflow = False
    validate_then_terminate = 0
    if validate_then_terminate:
        # placeholders (0, 0.0) for val_teacher_force_till / val_p_teacher_forcing;
        # arguments updated to match the validate() signature used in the main loop
        val_loss, *_ = validate(hparams, args, file_losses, model, criterion, valset,
                                best_val_loss_dict, iteration, collate_fn, logger,
                                0, 0.0, teacher_force=0)
        raise Exception("Finished Validation")
    
    for param_group in optimizer.param_groups:
        param_group['lr'] = learning_rate
    
    just_did_val = True
    rolling_loss = StreamingMovingAverage(min(int(len(train_loader)), 200))
    # ================ MAIN TRAINING LOOP! ===================
    training = True
    while training:
        try:
            for epoch in tqdm(range(epoch_offset, hparams.epochs), initial=epoch_offset, total=hparams.epochs, desc="Epoch:", position=1, unit="epoch"):
                tqdm.write("Epoch:{}".format(epoch))
                
                train_loader.dataset.shuffle_dataset()# Shuffle Dataset
                dataset_len = len(train_loader)
                
                start_time = time.time()
                # start iterating through the epoch
                for i, batch in tqdm(enumerate(train_loader), desc="Iter:  ", smoothing=0, total=len(train_loader), position=0, unit="iter"):
                    # run external code every epoch or 1000 iters, allows the run to be adjusted without restarts
                    if (i==0 or iteration % param_interval == 0):
                        try:
                            with open("run_every_epoch.py", encoding='utf-8') as f:
                                internal_text = str(f.read())
                                if len(internal_text) > 0:
                                    #code = compile(internal_text, "run_every_epoch.py", 'exec')
                                    ldict = {'iteration': iteration, 'checkpoint_iter': checkpoint_iter, 'n_restarts': n_restarts}
                                    exec(internal_text, globals(), ldict)
                                else:
                                    print("[info] tried to execute 'run_every_epoch.py' but it is empty")
                        except Exception as ex:
                            print(f"[warning] 'run_every_epoch.py' FAILED to execute!\nException:\n{ex}")
                        globals().update(ldict)
                        locals().update(ldict)
                        if show_live_params:
                            print(internal_text)
                    n_restarts = n_restarts_override if (n_restarts_override is not None) else n_restarts or 0
                    # Learning Rate Schedule
                    if custom_lr:
                        if iteration < warmup_start:
                            learning_rate = warmup_start_lr
                        elif iteration < warmup_end:
                            learning_rate = (iteration-warmup_start)*((A_+C_)-warmup_start_lr)/(warmup_end-warmup_start) + warmup_start_lr # learning rate increases from warmup_start_lr to A_ linearly over (warmup_end-warmup_start) iterations.
                        else:
                            if iteration < decay_start:
                                learning_rate = A_ + C_
                            else:
                                iteration_adjusted = iteration - decay_start
                                learning_rate = (A_*(e**(-iteration_adjusted/B_))) + C_
                        assert learning_rate > -1e-8, "Negative Learning Rate."
                        if decrease_lr_on_restart:
                            learning_rate = learning_rate/(2**(n_restarts/3))
                        if just_did_val:
                            learning_rate = 0.0
                            just_did_val=False
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = learning_rate
                    
                    # /run external code every epoch, allows the run to be adjusted without restarts/
                    model.zero_grad()
                    y = model.parse_batch(batch) # move batch to GPU (async)
                    y_pred = force(model, valid_kwargs=model_args, **{**y, "teacher_force_till": teacher_force_till, "p_teacher_forcing": p_teacher_forcing, "drop_frame_rate": drop_frame_rate})
                    
                    loss_scalars = {
                         "spec_MSE_weight": spec_MSE_weight,
                        "spec_MFSE_weight": spec_MFSE_weight,
                      "postnet_MSE_weight": postnet_MSE_weight,
                     "postnet_MFSE_weight": postnet_MFSE_weight,
                        "gate_loss_weight": gate_loss_weight,
                        "sylps_kld_weight": sylps_kld_weight,
                        "sylps_MSE_weight": sylps_MSE_weight,
                        "sylps_MAE_weight": sylps_MAE_weight,
                         "diag_att_weight": diag_att_weight,
                    }
                    loss_dict, file_losses_batch = criterion(y_pred, y, loss_scalars)
                    
                    file_losses = update_smoothed_dict(file_losses, file_losses_batch, file_losses_smoothness)
                    loss = loss_dict['loss']
                    
                    if hparams.distributed_run:
                        reduced_loss_dict = {k: reduce_tensor(v.data, args.n_gpus).item() if v is not None else 0. for k, v in loss_dict.items()}
                    else:
                        reduced_loss_dict = {k: v.item() if v is not None else 0. for k, v in loss_dict.items()}
                    
                    reduced_loss = reduced_loss_dict['loss']
                    
                    if hparams.fp16_run:
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
                    
                    if grad_clip_thresh:
                        if hparams.fp16_run:
                            grad_norm = torch.nn.utils.clip_grad_norm_(
                                amp.master_params(optimizer), grad_clip_thresh)
                            is_overflow = math.isinf(grad_norm) or math.isnan(grad_norm)
                        else:
                            grad_norm = torch.nn.utils.clip_grad_norm_(
                                model.parameters(), grad_clip_thresh)
                    else:
                        grad_norm = 0.0
                    
                    optimizer.step()
                    
                    # get current Loss Scale of first optimizer
                    loss_scale = amp._amp_state.loss_scalers[0]._loss_scale if hparams.fp16_run else 32768.
                    
                    # restart if training/model has collapsed
                    if (iteration > 1e3 and (reduced_loss > LossExplosionThreshold)) or (math.isnan(reduced_loss)) or (loss_scale < 1/4):
                        raise LossExplosion(f"\nLOSS EXPLOSION EXCEPTION ON RANK {rank}: Loss reached {reduced_loss} during iteration {iteration}.\n\n\n")
                    
                    if expavg_loss_dict is None:
                        expavg_loss_dict = reduced_loss_dict
                    else:
                        expavg_loss_dict = {k: (reduced_loss_dict[k]*(1-loss_dict_smoothness))+(expavg_loss_dict[k]*loss_dict_smoothness) for k in expavg_loss_dict.keys()}
                        expavg_loss_dict_iters += 1
                    
                    if expavg_loss_dict_iters > 100:
                        if best_loss_dict is None:
                            best_loss_dict = expavg_loss_dict
                        else:
                            best_loss_dict = {k: min(best_loss_dict[k], expavg_loss_dict[k]) for k in best_loss_dict.keys()}
                    
                    if rank == 0:
                        duration = time.time() - start_time
                        if not is_overflow:
                            average_loss = rolling_loss.process(reduced_loss)
                            tqdm.write(
                                f"{iteration} [Train_loss:{reduced_loss:.4f} Avg:{average_loss:.4f}] "
                                f"[Grad Norm {grad_norm:.4f}] [{duration:.2f}s/it] "
                                f"[{(duration/(hparams.batch_size*args.n_gpus)):.3f}s/file] "
                                f"[{learning_rate:.7f} LR] [{loss_scale:.0f} LS]")
                            logger.log_training(reduced_loss_dict, expavg_loss_dict, best_loss_dict, grad_norm, learning_rate, duration, iteration, teacher_force_till, p_teacher_forcing, drop_frame_rate)
                        else:
                            tqdm.write("Gradient Overflow, Skipping Step")
                        start_time = time.time()
                    
                    if iteration%checkpoint_interval==0 or os.path.exists(save_file_check_path):
                        # save model checkpoint like normal
                        if rank == 0:
                            checkpoint_path = os.path.join(args.output_directory, "checkpoint_{}".format(iteration))
                            save_checkpoint(model, optimizer, learning_rate, iteration, hparams, best_validation_loss, best_inf_attsc, average_loss, best_val_loss_dict, best_loss_dict, speaker_lookup, checkpoint_path)
                    
                    if iteration%dump_filelosses_interval==0:
                        print("Updating File_losses dict!")
                        file_losses = write_dict_to_file(file_losses, os.path.join(args.output_directory, 'file_losses.csv'), args.n_gpus, rank)
                    
                    if (iteration % int(validation_interval) == 0) or (os.path.exists(save_file_check_path)) or (iteration < 1000 and (iteration % 250 == 0)):
                        if rank == 0 and os.path.exists(save_file_check_path):
                            os.remove(save_file_check_path)
                        # perform validation and save "best_val_model" depending on validation loss
                        val_loss, best_val_loss_dict, file_losses = validate(hparams, args, file_losses, model, criterion, valset, best_val_loss_dict, iteration, collate_fn, logger, val_teacher_force_till, val_p_teacher_forcing, teacher_force=0)# validate/teacher_force
                        file_losses = write_dict_to_file(file_losses, os.path.join(args.output_directory, 'file_losses.csv'), args.n_gpus, rank)
                        valatt_loss, *_ = validate(hparams, args, file_losses, model, criterion, valset, best_val_loss_dict, iteration, collate_fn, logger, 0, 0.0, teacher_force=2)# infer
                        if use_scheduler:
                            scheduler.step(val_loss)
                        if (val_loss < best_validation_loss):
                            best_validation_loss = val_loss
                            if rank == 0 and hparams.save_best_val_model:
                                checkpoint_path = os.path.join(args.output_directory, "best_val_model")
                                save_checkpoint(
                                    model, optimizer, learning_rate, iteration, hparams, best_validation_loss, max(best_inf_attsc, val_loss),
                                    average_loss, best_val_loss_dict, best_loss_dict, speaker_lookup, checkpoint_path)
                        if (valatt_loss > best_inf_attsc):
                            best_inf_attsc = valatt_loss
                            if rank == 0 and hparams.save_best_inf_attsc:
                                checkpoint_path = os.path.join(args.output_directory, "best_inf_attsc")
                                save_checkpoint(
                                    model, optimizer, learning_rate, iteration, hparams, best_validation_loss, best_inf_attsc,
                                    average_loss, best_val_loss_dict, best_loss_dict, speaker_lookup, checkpoint_path)
                        just_did_val = True
                    
                    iteration += 1
                    # end of iteration loop
                
                # update filelist of training dataloader
                if (iteration > hparams.min_avg_max_att_start) and (iteration-checkpoint_iter >= dataset_len):
                    print("Updating File_losses dict!")
                    file_losses = write_dict_to_file(file_losses, os.path.join(args.output_directory, 'file_losses.csv'), args.n_gpus, rank)
                    print("Done!")
                    
                    print("Updating dataloader filtered paths!")
                    bad_file_paths = [k for k in list(file_losses.keys()) if
                        file_losses[k]['avg_max_attention'] < hparams.min_avg_max_att or # attention strength is too weak
                        file_losses[k]['att_diagonality']   > hparams.max_diagonality or # or diagonality is too high
                        file_losses[k]['spec_MSE']          > hparams.max_spec_mse]      # or audio quality is too low
                                                                                         # -> add to the bad-files list
                    bad_file_paths = set(bad_file_paths)                                 # and remove from the dataset
                    filtered_filelist = [x for x in train_loader.dataset.filelist if not (x[0] in bad_file_paths)]
                    train_loader.dataset.update_filelist(filtered_filelist)
                    print(f"Done! {len(bad_file_paths)} Files removed from dataset. {len(filtered_filelist)} Files remain.")
                    del filtered_filelist, bad_file_paths
                    if iteration > hparams.speaker_mse_sampling_start:
                        print("Updating dataset with speaker MSE Sampler!")
                        if original_filelist is None:
                            original_filelist = train_loader.dataset.filelist
                        train_loader.dataset.update_filelist(get_mse_sampled_filelist(
                                                             original_filelist, file_losses, hparams.speaker_mse_exponent, seed=iteration))
                        print("Done!")
                
                # end of epoch loop
            training = False # exit the While loop
        
        #except Exception as ex: # alternative: catch any Exception and restart from checkpoint
        except LossExplosion as ex: # print the Exception and continue from the best_val_model checkpoint (restarting like this takes < 4 seconds)
            print(ex) # print Loss
            checkpoint_path = os.path.join(args.output_directory, "best_val_model")
            assert os.path.exists(checkpoint_path), "best_val_model checkpoint must exist for automatic restarts"
            
            if hparams.fp16_run:
                amp._amp_state.loss_scalers[0]._loss_scale = 32768
            
            # clearing VRAM for load checkpoint
            model.zero_grad()
            x=y=y_pred=loss=len_loss=loss_z=loss_w=loss_s=loss_att=dur_loss_z=dur_loss_w=dur_loss_s=None
            torch.cuda.empty_cache()
            
            model.eval()
            model, optimizer, _learning_rate, iteration, best_validation_loss, saved_lookup = load_checkpoint(checkpoint_path, model, optimizer)
            learning_rate = optimizer.param_groups[0]['lr']
            epoch_offset = max(0, int(iteration / len(train_loader)))
            model.train()
            checkpoint_iter = iteration
            iteration += 1
            n_restarts += 1
        except KeyboardInterrupt as ex:
            print(ex)
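
For reference, the custom learning-rate schedule embedded in the loop above can be written as a standalone function. This is only a sketch that mirrors the in-loop arithmetic, assuming the same knobs (`warmup_start`, `warmup_end`, `warmup_start_lr`, `A_`, `B_`, `C_`, `decay_start`) that `run_every_epoch.py` injects at runtime.

from math import e

def custom_lr_schedule(iteration, warmup_start, warmup_end, warmup_start_lr,
                       A_, B_, C_, decay_start):
    """Warmup -> plateau -> exponential decay, mirroring the in-loop schedule above (sketch)."""
    if iteration < warmup_start:
        return warmup_start_lr
    if iteration < warmup_end:
        # linear ramp from warmup_start_lr up to A_+C_
        return (iteration - warmup_start)*((A_ + C_) - warmup_start_lr)/(warmup_end - warmup_start) + warmup_start_lr
    if iteration < decay_start:
        return A_ + C_
    # exponential decay from A_+C_ down towards the floor C_
    return A_*(e**(-(iteration - decay_start)/B_)) + C_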
Example #3
def validate(hparams, args, file_losses, model, criterion, valset, best_val_loss_dict, iteration,
             collate_fn, logger, val_teacher_force_till, val_p_teacher_forcing, teacher_force=-1):
    """Handles all the validation scoring and printing"""
    assert teacher_force >= 0, 'teacher_force not specified.'
    model.eval()
    with torch.no_grad():
        if teacher_force == 2:# during inference, sample from each speaker equally so speakers with smaller datasets get the same weighting on the val loss.
            orig_filelist = valset.filelist
            valset.update_filelist(get_mse_sampled_filelist(orig_filelist, file_losses, 0.0, seed=1234))
        val_sampler = DistributedSampler(valset) if hparams.distributed_run else None
        val_loader = DataLoader(valset, sampler=val_sampler, num_workers=hparams.num_workers,
                                shuffle=False, batch_size=hparams.batch_size,
                                pin_memory=False, drop_last=True, collate_fn=collate_fn)
        
        loss_dict_total = None
        for i, batch in tqdm(enumerate(val_loader), desc="Validation", total=len(val_loader), smoothing=0): # i = index, batch = stuff in array[i]
            y = model.parse_batch(batch)
            with torch.random.fork_rng(devices=[0,]):
                torch.random.manual_seed(i)# use repeatable seeds during validation so results are more consistent and comparable.
                y_pred = force(model, valid_kwargs=model_args, **{**y, "teacher_force_till": val_teacher_force_till, "p_teacher_forcing": val_p_teacher_forcing})
            
            val_loss_scalars = {
                 "spec_MSE_weight": 0.00,
                "spec_MFSE_weight": 1.00,
              "postnet_MSE_weight": 0.00,
             "postnet_MFSE_weight": 1.00,
                "gate_loss_weight": 1.00,
                "sylps_kld_weight": 0.00,
                "sylps_MSE_weight": 0.00,
                "sylps_MAE_weight": 0.05,
                 "diag_att_weight": 0.00,
            }
            loss_dict, file_losses_batch = criterion(y_pred, y, val_loss_scalars)
            file_losses = update_smoothed_dict(file_losses, file_losses_batch, file_losses_smoothness)
            if loss_dict_total is None:
                loss_dict_total = {k: 0. for k in loss_dict.keys()}
            
            if hparams.distributed_run:
                reduced_loss_dict = {k: reduce_tensor(v.data, args.n_gpus).item() if v is not None else 0. for k, v in loss_dict.items()}
            else:
                reduced_loss_dict = {k: v.item() if v is not None else 0. for k, v in loss_dict.items()}
            reduced_loss = reduced_loss_dict['loss']
            
            for k in loss_dict_total.keys():
                loss_dict_total[k] = loss_dict_total[k] + reduced_loss_dict[k]
            # end forloop
        loss_dict_total = {k: v/(i+1) for k, v in loss_dict_total.items()}
        # end torch.no_grad()
        
    # reverse changes to valset and model
    if teacher_force == 2:# restore the original filelist after the equal-speaker sampling applied above for inference
        valset.filelist = orig_filelist
    model.train()
    
    # update best losses
    if best_val_loss_dict is None:
        best_val_loss_dict = loss_dict_total
    else:
        best_val_loss_dict = {k: min(best_val_loss_dict[k], loss_dict_total[k]) for k in best_val_loss_dict.keys()}
    
    # print, log data and return.
    if args.rank == 0:
        tqdm.write(f"Validation loss {iteration}: {loss_dict_total['loss']:9f}  Average Max Attention: {loss_dict_total['avg_max_attention']:9f}")
        if iteration > 1:
            log_terms = (loss_dict_total, best_val_loss_dict, model, y, y_pred, iteration, val_teacher_force_till, val_p_teacher_forcing)
            if teacher_force == 2:
                logger.log_infer(*log_terms)
            else:
                logger.log_validation(*log_terms)
    
    if teacher_force == 2:
        return loss_dict_total['weighted_score'], best_val_loss_dict, file_losses
    else:
        return loss_dict_total['loss'], best_val_loss_dict, file_losses
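
`reduce_tensor`, used in every distributed branch above, is also not part of this listing. The usual implementation in NVIDIA-style training scripts all-reduces the tensor and divides by the GPU count; a sketch of that assumed behaviour:

import torch.distributed as dist

def reduce_tensor(tensor, n_gpus):
    """Average a scalar loss tensor across all ranks (sketch of the helper assumed above)."""
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)  # sum the value from every GPU
    return rt / n_gpus                         # back to a per-GPU average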
Example #4
def GTA_Synthesis(hparams, args, extra_info='', audio_offset=0):
    """Generate Ground-Truth-Aligned Spectrograms for Training WaveGlow."""
    rank   = args.rank
    n_gpus = args.n_gpus
    torch.manual_seed(hparams.seed)
    torch.cuda.manual_seed(hparams.seed)
    
    if args.use_validation_files:
        filelisttype = "val"
        hparams.training_files = hparams.validation_files
    else:
        filelisttype = "train"
    
    # initialize blank model
    print('Initializing Tacotron2...')
    model = load_model(hparams)
    print('Done')
    global model_args
    model_args = get_args(model.forward)
    model.eval()
    
    # Load checkpoint
    assert args.checkpoint_path is not None
    print('Loading Tacotron2 Checkpoint...')
    model = warm_start_model(args.checkpoint_path, model)
    print('Done')
    
    _ = model.train() if args.use_training_mode else model.eval()# set model to either train() or eval() mode. (controls dropout + DFR)
    
    print("Initializing AMP Model")
    if hparams.fp16_run:
        model = amp.initialize(model, opt_level='O2')
    print('Done')
    
    # define datasets/dataloaders
    train_loader, valset, collate_fn, train_sampler, trainset = prepare_dataloaders(hparams, model_args, args, None, audio_offset=audio_offset)
    
    # load and/or generate global_mean
    if args.use_training_mode and hparams.drop_frame_rate > 0.:
        if rank != 0: # if global_mean is not yet calculated, wait for the main thread to do it
            while not os.path.exists(hparams.global_mean_npy): time.sleep(1)
        hparams.global_mean = get_global_mean(train_loader, hparams.global_mean_npy, hparams)
    
    # ================ MAIN GTA SYNTHESIS LOOP! ===================
    os.makedirs(os.path.join(args.output_directory), exist_ok=True)
    f = open(os.path.join(args.output_directory, f'map_{filelisttype}_gpu{rank}.txt'),'w', encoding='utf-8')
    
    processed_files = 0
    failed_files = 0
    duration = time.time()
    total = len(train_loader)
    rolling_sum = StreamingMovingAverage(100)
    for i, y in enumerate(train_loader):
        y_gpu = model.parse_batch(y) # move batch to GPU
        
        y_pred_gpu = force(model, valid_kwargs=model_args, **{**y_gpu, "teacher_force_till": 0, "p_teacher_forcing": 1.0, "drop_frame_rate": 0.0})
        y_pred = {k: v.cpu() for k,v in y_pred_gpu.items() if v is not None}# move model outputs to CPU
        if args.fp16_save:
            y_pred = {k: v.half() for k,v in y_pred.items()}# convert model outputs to fp16
        
        if args.save_letter_alignments or args.save_phone_alignments:
            alignments = get_alignments(y_pred['alignments'], y['mel_lengths'], y['text_lengths'])# [B, mel_T, txt_T] -> [[B, mel_T, txt_T], [B, mel_T, txt_T], ...]
        
        offset_append = '' if audio_offset == 0 else str(audio_offset)
        for j in range(len(y['gt_mel'])):
            gt_mel   = y['gt_mel'  ][j, :, :y['mel_lengths'][j]]
            pred_mel = y_pred['pred_mel_postnet'][j, :, :y['mel_lengths'][j]]
            
            audiopath      = y['audiopath'][j]
            speaker_id_ext = y['speaker_id_ext'][j]
            
            if True or (args.max_mse or args.max_mae):
                MAE = F.l1_loss(pred_mel, gt_mel).item()
                MSE = F.mse_loss(pred_mel, gt_mel).item()
                if args.max_mse and MSE > args.max_mse:
                    print(f"MSE ({MSE}) is greater than max MSE ({args.max_mse}).\nFilepath: '{audiopath}'\n")
                    failed_files+=1; continue
                if args.max_mae and MAE > args.max_mae:
                    print(f"MAE ({MAE}) is greater than max MAE ({args.max_mae}).\nFilepath: '{audiopath}'\n")
                    failed_files+=1; continue
            else:
                MAE = MSE = 'N/A'
            
            print(f"PATH: '{audiopath}'\nMel Shape:{list(gt_mel.shape)}\nSpeaker_ID: {speaker_id_ext}\nMSE: {MSE}\nMAE: {MAE}")
            if not args.do_not_save_mel:
                pred_mel_path = os.path.splitext(audiopath)[0]+'.pred_mel.pt'
                torch.save(pred_mel.clone(), pred_mel_path)
                pm_audio_path = os.path.splitext(audiopath)[0]+'.pm_audio.pt'# predicted mel audio
                torch.save(y['gt_audio'][j, :y['audio_lengths'][j]].clone(), pm_audio_path)
            if args.save_letter_alignments and hparams.p_arpabet == 0.:
                save_path_align_out = os.path.splitext(audiopath)[0]+'_galign.pt'
                torch.save(alignments[j].clone(), save_path_align_out)
            if args.save_phone_alignments and hparams.p_arpabet == 1.:
                save_path_align_out = os.path.splitext(audiopath)[0]+'_palign.pt'
                torch.save(alignments[j].clone(), save_path_align_out)
            map = f"{audiopath}|{y['gtext_str'][j]}|{speaker_id_ext}|\n"
            
            f.write(map)# write paths to text file
            processed_files+=1
            print("")
        
        duration = time.time() - duration
        avg_duration = rolling_sum.process(duration)
        time_left = round(((total-i) * avg_duration)/3600, 2)
        print(f'{extra_info}Batch {i}/{total}: computed and saved GTA mel-spectrograms in {duration:.2f}s, ~{time_left}hrs left. {processed_files} processed, {failed_files} failed.')
        duration = time.time()
    f.close()
    
    if n_gpus > 1:
        torch.distributed.barrier()# wait till all graphics cards reach this point.
    
    # merge all generated filelists from every GPU
    filenames = [f'map_{filelisttype}_gpu{j}.txt' for j in range(n_gpus)]
    if rank == 0:
        with open(os.path.join(args.output_directory, f'map_{filelisttype}.txt'), 'w') as outfile:
            for fname in filenames:
                with open(os.path.join(args.output_directory, fname)) as infile:
                    for line in infile:
                        if len(line.strip()):
                            outfile.write(line)
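
Both the training loop and this GTA loop rely on a `StreamingMovingAverage` helper that is not shown here. A sketch consistent with how it is called (`process()` pushes a value and returns the running average over the last N entries) could be:

class StreamingMovingAverage:
    """Rolling average over the last `window_size` values (sketch matching the usage above)."""
    def __init__(self, window_size):
        self.window_size = window_size
        self.values = []
        self.total = 0.0
    def process(self, value):
        self.values.append(value)
        self.total += value
        if len(self.values) > self.window_size:
            self.total -= self.values.pop(0)  # drop the oldest value from the window
        return self.total / len(self.values)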