def test_mdn_loss(self):
    # wrap up the inverse data as tensors
    x = torch.from_numpy(
        self.x_train_inv.reshape(self.batch_size, -1, self.d_in)).to(
            self.device)  # (B, max(T), D_in)
    y = torch.from_numpy(
        self.y_train_inv.reshape(self.batch_size, -1, self.d_out)).to(
            self.device)  # (B, max(T), D_out)

    for e in range(1000):
        self.model.zero_grad()
        pi, sigma, mu = self.model(x)
        loss = mdn.mdn_loss(pi, sigma, mu, y).mean()
        if e % 100 == 0:
            print(f"loss: {loss.item()}")
        loss.backward()
        self.opt.step()
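# mdn.mdn_loss itself is not shown in this section. Below is a minimal,
# hypothetical sketch of such a loss, assuming diagonal Gaussian mixture
# components with pi holding unnormalized mixture logits of shape (B, T, G)
# and sigma/mu of shape (B, T, G, D_out); the actual implementation may differ.
import torch

def mdn_loss_sketch(pi, sigma, mu, target, reduce=True):
    """Negative log-likelihood of a diagonal Gaussian mixture."""
    # (B, T, 1, D_out) so the target broadcasts over the G components
    target = target.unsqueeze(2)
    # Per-component log-density, summed over feature dims: (B, T, G)
    log_prob = torch.distributions.Normal(mu, sigma).log_prob(target).sum(-1)
    # Mix the components in log space for numerical stability
    nll = -torch.logsumexp(torch.log_softmax(pi, dim=-1) + log_prob, dim=-1)
    return nll.mean() if reduce else nll  # scalar or (B, T)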
def train_step(
    model,
    optimizer,
    grad_scaler,
    train,
    in_feats,
    out_feats,
    lengths,
    out_scaler,
    feats_criterion="mse",
    stream_wise_loss=False,
    stream_weights=None,
    stream_sizes=None,
):
    model.train() if train else model.eval()
    optimizer.zero_grad()

    if feats_criterion in ["l2", "mse"]:
        criterion = nn.MSELoss(reduction="none")
    elif feats_criterion in ["l1", "mae"]:
        criterion = nn.L1Loss(reduction="none")
    else:
        raise RuntimeError("not supported criterion")

    prediction_type = (
        model.module.prediction_type()
        if isinstance(model, nn.DataParallel)
        else model.prediction_type()
    )

    # Apply preprocess if required (e.g., FIR filter for shallow AR)
    # defaults to no-op
    if isinstance(model, nn.DataParallel):
        out_feats = model.module.preprocess_target(out_feats)
    else:
        out_feats = model.preprocess_target(out_feats)

    # Run forward
    with autocast(enabled=grad_scaler is not None):
        pred_out_feats = model(in_feats, lengths)

    # Mask (B, T, 1)
    mask = make_non_pad_mask(lengths).unsqueeze(-1).to(in_feats.device)

    # Compute loss
    if prediction_type == PredictionType.PROBABILISTIC:
        pi, sigma, mu = pred_out_feats

        # (B, max(T)) or (B, max(T), D_out)
        mask_ = mask if len(pi.shape) == 4 else mask.squeeze(-1)

        # Compute loss and apply mask
        with autocast(enabled=grad_scaler is not None):
            loss = mdn_loss(pi, sigma, mu, out_feats, reduce=False)
            loss = loss.masked_select(mask_).mean()
    else:
        if stream_wise_loss:
            w = get_stream_weight(stream_weights, stream_sizes).to(in_feats.device)
            streams = split_streams(out_feats, stream_sizes)
            pred_streams = split_streams(pred_out_feats, stream_sizes)
            loss = 0
            for pred_stream, stream, sw in zip(pred_streams, streams, w):
                with autocast(enabled=grad_scaler is not None):
                    loss += (
                        sw
                        * criterion(
                            pred_stream.masked_select(mask),
                            stream.masked_select(mask),
                        ).mean()
                    )
        else:
            with autocast(enabled=grad_scaler is not None):
                loss = criterion(
                    pred_out_feats.masked_select(mask),
                    out_feats.masked_select(mask),
                ).mean()

    if prediction_type == PredictionType.PROBABILISTIC:
        with torch.no_grad():
            pred_out_feats_ = mdn_get_most_probable_sigma_and_mu(pi, sigma, mu)[1]
    else:
        pred_out_feats_ = pred_out_feats
    distortions = compute_distortions(pred_out_feats_, out_feats, lengths, out_scaler)

    if train:
        if grad_scaler is not None:
            grad_scaler.scale(loss).backward()
            grad_scaler.step(optimizer)
            grad_scaler.update()
        else:
            loss.backward()
            optimizer.step()

    return loss, distortions
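# make_non_pad_mask, used above and in the functions below, is assumed to
# return a boolean (B, max(T)) mask that is True at valid frames and False
# at padded ones; a minimal sketch under that assumption (the real helper
# may differ):
import torch

def make_non_pad_mask_sketch(lengths):
    """True for positions t < length, False for padding."""
    lengths = torch.as_tensor(lengths)
    # (1, max(T)) position indices compared against (B, 1) lengths
    positions = torch.arange(int(lengths.max()), device=lengths.device)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)  # (B, max(T))

# Example: make_non_pad_mask_sketch(torch.tensor([3, 1]))
# -> tensor([[ True,  True,  True],
#            [ True, False, False]])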
def train_step(
    model,
    model_config,
    optimizer,
    grad_scaler,
    train,
    in_feats,
    out_feats,
    lengths,
    out_scaler,
    feats_criterion="mse",
    pitch_reg_dyn_ws=1.0,
    pitch_reg_weight=1.0,
):
    model.train() if train else model.eval()
    optimizer.zero_grad()
    log_metrics = {}

    if feats_criterion in ["l2", "mse"]:
        criterion = nn.MSELoss(reduction="none")
    elif feats_criterion in ["l1", "mae"]:
        criterion = nn.L1Loss(reduction="none")
    else:
        raise RuntimeError("not supported criterion")

    prediction_type = (
        model.module.prediction_type()
        if isinstance(model, nn.DataParallel)
        else model.prediction_type()
    )

    # Apply preprocess if required (e.g., FIR filter for shallow AR)
    # defaults to no-op
    if isinstance(model, nn.DataParallel):
        out_feats = model.module.preprocess_target(out_feats)
    else:
        out_feats = model.preprocess_target(out_feats)

    # Run forward
    with autocast(enabled=grad_scaler is not None):
        outs = model(in_feats, lengths, out_feats)
        if isinstance(outs, tuple) and len(outs) == 2:
            pred_out_feats, lf0_residual = outs
        else:
            pred_out_feats, lf0_residual = outs, None

    # Mask (B, T, 1)
    mask = make_non_pad_mask(lengths).unsqueeze(-1).to(in_feats.device)

    # Compute loss
    if prediction_type == PredictionType.PROBABILISTIC:
        pi, sigma, mu = pred_out_feats

        # (B, max(T)) or (B, max(T), D_out)
        mask_ = mask if len(pi.shape) == 4 else mask.squeeze(-1)

        # Compute loss and apply mask
        with autocast(enabled=grad_scaler is not None):
            loss_feats = mdn_loss(pi, sigma, mu, out_feats, reduce=False)
            loss_feats = loss_feats.masked_select(mask_).mean()
    else:
        with autocast(enabled=grad_scaler is not None):
            # NOTE: multiple predictions
            if isinstance(pred_out_feats, list):
                loss_feats = 0
                for pred_out_feats_ in pred_out_feats:
                    loss_feats += criterion(
                        pred_out_feats_.masked_select(mask),
                        out_feats.masked_select(mask),
                    ).mean()
            else:
                loss_feats = criterion(
                    pred_out_feats.masked_select(mask), out_feats.masked_select(mask)
                ).mean()

    # Pitch regularization
    # NOTE: L1 loss seems to work better than MSE loss in my experiments.
    # We could use L2 loss as suggested in the Sinsy paper.
    if lf0_residual is not None:
        with autocast(enabled=grad_scaler is not None):
            if isinstance(lf0_residual, list):
                loss_pitch = 0
                for lf0_residual_ in lf0_residual:
                    loss_pitch += (
                        (pitch_reg_dyn_ws * lf0_residual_.abs())
                        .masked_select(mask)
                        .mean()
                    )
            else:
                loss_pitch = (
                    (pitch_reg_dyn_ws * lf0_residual.abs()).masked_select(mask).mean()
                )
    else:
        loss_pitch = torch.tensor(0.0).to(in_feats.device)

    loss = loss_feats + pitch_reg_weight * loss_pitch

    if prediction_type == PredictionType.PROBABILISTIC:
        with torch.no_grad():
            pred_out_feats_ = mdn_get_most_probable_sigma_and_mu(pi, sigma, mu)[1]
    else:
        if isinstance(pred_out_feats, list):
            pred_out_feats_ = pred_out_feats[-1]
        else:
            pred_out_feats_ = pred_out_feats
    distortions = compute_distortions(
        pred_out_feats_, out_feats, lengths, out_scaler, model_config
    )

    if train:
        if grad_scaler is not None:
            grad_scaler.scale(loss).backward()
            grad_scaler.step(optimizer)
            grad_scaler.update()
        else:
            loss.backward()
            optimizer.step()

    log_metrics.update(distortions)
    log_metrics.update(
        {
            "Loss": loss.item(),
            "Loss_Feats": loss_feats.item(),
            "Loss_Pitch": loss_pitch.item(),
        }
    )

    return loss, log_metrics
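# mdn_get_most_probable_sigma_and_mu is assumed to pick, for each frame, the
# sigma and mu of the mixture component with the largest weight (the [1]
# indexing above then keeps only mu). A sketch under that assumption, with
# pi: (B, T, G) and sigma/mu: (B, T, G, D_out); not necessarily the actual
# implementation.
import torch

def mdn_get_most_probable_sigma_and_mu_sketch(pi, sigma, mu):
    # Index of the dominant component per frame: (B, T)
    max_component = pi.argmax(dim=-1)
    # Expand to (B, T, 1, D_out) so we can gather along the component axis
    idx = max_component[..., None, None].expand(-1, -1, 1, mu.shape[-1])
    sigma_max = torch.gather(sigma, 2, idx).squeeze(2)  # (B, T, D_out)
    mu_max = torch.gather(mu, 2, idx).squeeze(2)  # (B, T, D_out)
    return sigma_max, mu_max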
def train_loop(config, device, model, optimizer, lr_scheduler, data_loaders):
    criterion = nn.MSELoss(reduction="none")
    logger.info("Start utterance-wise training...")

    stream_weights = get_stream_weight(
        config.model.stream_weights, config.model.stream_sizes).to(device)

    best_loss = 10000000
    for epoch in tqdm(range(1, config.train.nepochs + 1)):
        for phase in data_loaders.keys():
            train = phase.startswith("train")
            model.train() if train else model.eval()
            running_loss = 0
            for x, y, lengths in data_loaders[phase]:
                # Sort by lengths. This is needed for pytorch's PackedSequence
                sorted_lengths, indices = torch.sort(lengths, dim=0, descending=True)
                x, y = x[indices].to(device), y[indices].to(device)

                optimizer.zero_grad()

                # Apply preprocess if required (e.g., FIR filter for shallow AR)
                # defaults to no-op
                y = model.preprocess_target(y)

                # Run forward
                if model.prediction_type() == PredictionType.PROBABILISTIC:
                    pi, sigma, mu = model(x, sorted_lengths)

                    # (B, max(T))
                    mask = make_non_pad_mask(sorted_lengths).to(device)

                    # Compute loss and apply mask
                    loss = mdn_loss(pi, sigma, mu, y, reduce=False).masked_select(mask).mean()
                else:
                    y_hat = model(x, sorted_lengths)

                    # Compute loss
                    mask = make_non_pad_mask(sorted_lengths).unsqueeze(-1).to(device)

                    if config.train.stream_wise_loss:
                        # Stream-wise loss
                        streams = split_streams(y, config.model.stream_sizes)
                        streams_hat = split_streams(y_hat, config.model.stream_sizes)
                        loss = 0
                        for s_hat, s, sw in zip(streams_hat, streams, stream_weights):
                            s_hat_mask = s_hat.masked_select(mask)
                            s_mask = s.masked_select(mask)
                            loss += sw * criterion(s_hat_mask, s_mask).mean()
                    else:
                        # Joint modeling
                        y_hat = y_hat.masked_select(mask)
                        y = y.masked_select(mask)
                        loss = criterion(y_hat, y).mean()

                if train:
                    loss.backward()
                    optimizer.step()

                running_loss += loss.item()
            ave_loss = running_loss / len(data_loaders[phase])
            logger.info(f"[{phase}] [Epoch {epoch}]: loss {ave_loss}")
            if not train and ave_loss < best_loss:
                best_loss = ave_loss
                save_best_checkpoint(config, model, optimizer, best_loss)

        # step per each epoch (may consider updating per iteration)
        lr_scheduler.step()

        if epoch % config.train.checkpoint_epoch_interval == 0:
            save_checkpoint(config, model, optimizer, lr_scheduler, epoch)

    # save at last epoch
    save_checkpoint(config, model, optimizer, lr_scheduler, config.train.nepochs)
    logger.info(f"The best loss was {best_loss}")

    return model
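# split_streams and get_stream_weight are assumed to split stacked feature
# streams (e.g., mgc/lf0/vuv/bap) along the feature axis and to provide
# per-stream loss weights; minimal sketches under those assumptions (the
# real helpers may differ):
import torch

def split_streams_sketch(y, stream_sizes):
    """Split (B, T, sum(stream_sizes)) into a list of (B, T, size) tensors."""
    return list(torch.split(y, stream_sizes, dim=-1))

def get_stream_weight_sketch(stream_weights, stream_sizes):
    """Use the configured weights if given, otherwise weight streams equally."""
    if stream_weights is not None:
        return torch.tensor(stream_weights)
    return torch.ones(len(stream_sizes))

# Example: split_streams_sketch(y, [60, 1, 1, 5]) yields four tensors of
# widths 60, 1, 1, and 5 along the last axis.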