import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# Data and DataCollate are assumed to come from the project's local data
# module, as in the surrounding training script.
from data import Data, DataCollate


def prepare_dataloaders(data_config, n_gpus, batch_size):
    # Get data, data loaders and collate function ready
    ignore_keys = ['training_files', 'validation_files']
    trainset = Data(
        data_config['training_files'],
        **dict((k, v) for k, v in data_config.items()
               if k not in ignore_keys))
    # Reuse the training set's speaker id mapping for validation
    valset = Data(
        data_config['validation_files'],
        **dict((k, v) for k, v in data_config.items()
               if k not in ignore_keys),
        speaker_ids=trainset.speaker_ids)

    collate_fn = DataCollate(
        n_frames_per_step=1, use_attn_prior=trainset.use_attn_prior)

    # With multiple GPUs, let DistributedSampler shard the data and
    # disable DataLoader-level shuffling
    train_sampler, shuffle = None, True
    if n_gpus > 1:
        train_sampler, shuffle = DistributedSampler(trainset), False

    train_loader = DataLoader(trainset, num_workers=1, shuffle=shuffle,
                              sampler=train_sampler, batch_size=batch_size,
                              pin_memory=False, drop_last=True,
                              collate_fn=collate_fn)
    return train_loader, valset, collate_fn
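# Illustrative usage only (not part of the original script): the exact set of
# data_config keys beyond 'training_files'/'validation_files' depends on the
# Data class, so the paths and values below are hypothetical placeholders.
#
#   data_config = {
#       'training_files': 'filelists/train_filelist.txt',
#       'validation_files': 'filelists/val_filelist.txt',
#       # ... remaining Data() kwargs from the experiment config ...
#   }
#   train_loader, valset, collate_fn = prepare_dataloaders(
#       data_config, n_gpus=1, batch_size=6)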
def compute_validation_loss(model, criterion, valset, batch_size, n_gpus,
                            apply_ctc):
    model.eval()
    with torch.no_grad():
        collate_fn = DataCollate(
            n_frames_per_step=1, use_attn_prior=valset.use_attn_prior)
        val_sampler = DistributedSampler(valset) if n_gpus > 1 else None
        val_loader = DataLoader(valset, sampler=val_sampler, num_workers=1,
                                shuffle=False, batch_size=batch_size,
                                pin_memory=False, collate_fn=collate_fn)

        val_loss, val_loss_nll, val_loss_gate = 0.0, 0.0, 0.0
        val_loss_ctc = 0.0
        n_batches = len(val_loader)
        for i, batch in enumerate(val_loader):
            (mel, spk_ids, txt, in_lens, out_lens,
             gate_target, attn_prior) = batch
            mel, spk_ids, txt = mel.cuda(), spk_ids.cuda(), txt.cuda()
            in_lens, out_lens = in_lens.cuda(), out_lens.cuda()
            gate_target = gate_target.cuda()
            attn_prior = attn_prior.cuda() if attn_prior is not None else None

            (z, log_s_list, gate_pred, attn,
             attn_logprob, mean, log_var, prob) = model(
                mel, spk_ids, txt, in_lens, out_lens, attn_prior)

            loss_nll, loss_gate, loss_ctc = criterion(
                (z, log_s_list, gate_pred, attn,
                 attn_logprob, mean, log_var, prob),
                gate_target, in_lens, out_lens, is_validation=True)
            loss = loss_nll + loss_gate
            if apply_ctc:
                loss += loss_ctc * criterion.ctc_loss_weight

            if n_gpus > 1:
                # Average each loss term across all distributed ranks
                reduced_val_loss = reduce_tensor(loss.data, n_gpus).item()
                reduced_val_loss_nll = reduce_tensor(
                    loss_nll.data, n_gpus).item()
                reduced_val_loss_gate = reduce_tensor(
                    loss_gate.data, n_gpus).item()
                reduced_val_loss_ctc = reduce_tensor(
                    loss_ctc.data, n_gpus).item()
            else:
                reduced_val_loss = loss.item()
                reduced_val_loss_nll = loss_nll.item()
                reduced_val_loss_gate = loss_gate.item()
                reduced_val_loss_ctc = loss_ctc.item()
            val_loss += reduced_val_loss
            val_loss_nll += reduced_val_loss_nll
            val_loss_gate += reduced_val_loss_gate
            val_loss_ctc += reduced_val_loss_ctc

        val_loss = val_loss / n_batches
        val_loss_nll = val_loss_nll / n_batches
        val_loss_gate = val_loss_gate / n_batches
        val_loss_ctc = val_loss_ctc / n_batches
        print("Mean {}\nLogVar {}\nProb {}".format(mean, log_var, prob))

    model.train()
    return (val_loss, val_loss_nll, val_loss_gate, val_loss_ctc,
            attn, gate_pred, gate_target)
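# The reduce_tensor helper called above is not defined in this excerpt.
# A minimal sketch, assuming it follows the usual all-reduce-and-average
# pattern found in distributed PyTorch training scripts:
import torch.distributed as dist


def reduce_tensor(tensor, n_gpus):
    # Sum the tensor across all ranks, then divide by the world size so
    # every rank ends up with the mean value
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= n_gpus
    return rt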