Example 1
def test_mdn_loss(self):
    # Wrap the inverse-mapping training data as tensors on the target device
    x = torch.from_numpy(
        self.x_train_inv.reshape(self.batch_size, -1, self.d_in)).to(
            self.device)  # (B, max(T), D_in)
    y = torch.from_numpy(
        self.y_train_inv.reshape(self.batch_size, -1, self.d_out)).to(
            self.device)  # (B, max(T), D_out)
    # Fit the MDN for a fixed number of steps, printing the loss periodically
    for e in range(1000):
        self.model.zero_grad()
        pi, sigma, mu = self.model(x)
        loss = mdn.mdn_loss(pi, sigma, mu, y).mean()
        if e % 100 == 0:
            print(f"loss: {loss.item()}")
        loss.backward()
        self.opt.step()
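
The test above assumes an `mdn.mdn_loss` that returns the per-frame negative log-likelihood of a diagonal Gaussian mixture. For orientation only, here is a minimal self-contained sketch of such a loss; the tensor shapes and the assumption that `pi` holds (non-log) mixture weights are illustrative, not the library's actual API.

import torch


def mdn_loss_sketch(pi, sigma, mu, target, reduce=True):
    """Illustrative MDN negative log-likelihood (not the library code).

    Assumed shapes: pi (B, T, G), sigma/mu (B, T, G, D_out), target (B, T, D_out).
    """
    # Broadcast the target against the G mixture components
    target = target.unsqueeze(2)  # (B, T, 1, D_out)
    # Per-component log-likelihood, summed over feature dimensions -> (B, T, G)
    log_prob = torch.distributions.Normal(mu, sigma).log_prob(target).sum(dim=-1)
    # log sum_g pi_g * N(target | mu_g, sigma_g), computed stably in log space
    log_mix = torch.logsumexp(torch.log(pi + 1e-8) + log_prob, dim=-1)  # (B, T)
    nll = -log_mix
    return nll.mean() if reduce else nll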
Example 2
def train_step(
    model,
    optimizer,
    grad_scaler,
    train,
    in_feats,
    out_feats,
    lengths,
    out_scaler,
    feats_criterion="mse",
    stream_wise_loss=False,
    stream_weights=None,
    stream_sizes=None,
):
    model.train() if train else model.eval()
    optimizer.zero_grad()

    if feats_criterion in ["l2", "mse"]:
        criterion = nn.MSELoss(reduction="none")
    elif feats_criterion in ["l1", "mae"]:
        criterion = nn.L1Loss(reduction="none")
    else:
        raise RuntimeError("not supported criterion")

    prediction_type = (model.module.prediction_type() if isinstance(
        model, nn.DataParallel) else model.prediction_type())

    # Apply preprocess if required (e.g., FIR filter for shallow AR)
    # defaults to no-op
    if isinstance(model, nn.DataParallel):
        out_feats = model.module.preprocess_target(out_feats)
    else:
        out_feats = model.preprocess_target(out_feats)

    # Run forward
    with autocast(enabled=grad_scaler is not None):
        pred_out_feats = model(in_feats, lengths)

    # Mask (B, T, 1)
    mask = make_non_pad_mask(lengths).unsqueeze(-1).to(in_feats.device)

    # Compute loss
    if prediction_type == PredictionType.PROBABILISTIC:
        pi, sigma, mu = pred_out_feats
        # (B, max(T)) or (B, max(T), D_out)
        mask_ = mask if len(pi.shape) == 4 else mask.squeeze(-1)
        # Compute loss and apply mask
        with autocast(enabled=grad_scaler is not None):
            loss = mdn_loss(pi, sigma, mu, out_feats, reduce=False)
        loss = loss.masked_select(mask_).mean()
    else:
        if stream_wise_loss:
            w = get_stream_weight(stream_weights,
                                  stream_sizes).to(in_feats.device)
            streams = split_streams(out_feats, stream_sizes)
            pred_streams = split_streams(pred_out_feats, stream_sizes)
            loss = 0
            for pred_stream, stream, sw in zip(pred_streams, streams, w):
                with autocast(enabled=grad_scaler is not None):
                    loss += (sw * criterion(pred_stream.masked_select(mask),
                                            stream.masked_select(mask)).mean())
        else:
            with autocast(enabled=grad_scaler is not None):
                loss = criterion(pred_out_feats.masked_select(mask),
                                 out_feats.masked_select(mask)).mean()

    if prediction_type == PredictionType.PROBABILISTIC:
        with torch.no_grad():
            pred_out_feats_ = mdn_get_most_probable_sigma_and_mu(
                pi, sigma, mu)[1]
    else:
        pred_out_feats_ = pred_out_feats
    distortions = compute_distortions(pred_out_feats_, out_feats, lengths,
                                      out_scaler)

    if train:
        if grad_scaler is not None:
            grad_scaler.scale(loss).backward()
            grad_scaler.step(optimizer)
            grad_scaler.update()
        else:
            loss.backward()
            optimizer.step()

    return loss, distortions
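
Examples 2-4 mask out padded frames with `make_non_pad_mask` before reducing the loss. The helper below is a minimal sketch of what such a mask computes (True for valid frames), assuming `lengths` is a 1-D tensor of per-utterance frame counts; the actual nnsvs/ESPnet helper offers more options.

import torch


def make_non_pad_mask_sketch(lengths):
    """Boolean mask of shape (B, max(T)): True for valid (non-padded) frames."""
    max_len = int(lengths.max())
    frame_ids = torch.arange(max_len, device=lengths.device).unsqueeze(0)  # (1, max(T))
    return frame_ids < lengths.unsqueeze(1)  # (B, max(T))

In the training steps above, this mask is further reshaped with `unsqueeze(-1)` so that it broadcasts over the feature dimension.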
Example 3
def train_step(
    model,
    model_config,
    optimizer,
    grad_scaler,
    train,
    in_feats,
    out_feats,
    lengths,
    out_scaler,
    feats_criterion="mse",
    pitch_reg_dyn_ws=1.0,
    pitch_reg_weight=1.0,
):
    model.train() if train else model.eval()
    optimizer.zero_grad()
    log_metrics = {}

    if feats_criterion in ["l2", "mse"]:
        criterion = nn.MSELoss(reduction="none")
    elif feats_criterion in ["l1", "mae"]:
        criterion = nn.L1Loss(reduction="none")
    else:
        raise RuntimeError("not supported criterion")

    prediction_type = (
        model.module.prediction_type()
        if isinstance(model, nn.DataParallel)
        else model.prediction_type()
    )

    # Apply preprocess if required (e.g., FIR filter for shallow AR)
    # defaults to no-op
    if isinstance(model, nn.DataParallel):
        out_feats = model.module.preprocess_target(out_feats)
    else:
        out_feats = model.preprocess_target(out_feats)

    # Run forward
    with autocast(enabled=grad_scaler is not None):
        outs = model(in_feats, lengths, out_feats)
        if isinstance(outs, tuple) and len(outs) == 2:
            pred_out_feats, lf0_residual = outs
        else:
            pred_out_feats, lf0_residual = outs, None

    # Mask (B, T, 1)
    mask = make_non_pad_mask(lengths).unsqueeze(-1).to(in_feats.device)

    # Compute loss
    if prediction_type == PredictionType.PROBABILISTIC:
        pi, sigma, mu = pred_out_feats

        # (B, max(T)) or (B, max(T), D_out)
        mask_ = mask if len(pi.shape) == 4 else mask.squeeze(-1)
        # Compute loss and apply mask
        with autocast(enabled=grad_scaler is not None):
            loss_feats = mdn_loss(pi, sigma, mu, out_feats, reduce=False)
            loss_feats = loss_feats.masked_select(mask_).mean()
    else:
        with autocast(enabled=grad_scaler is not None):
            # NOTE: multiple predictions
            if isinstance(pred_out_feats, list):
                loss_feats = 0
                for pred_out_feats_ in pred_out_feats:
                    loss_feats += criterion(
                        pred_out_feats_.masked_select(mask),
                        out_feats.masked_select(mask),
                    ).mean()
            else:
                loss_feats = criterion(
                    pred_out_feats.masked_select(mask), out_feats.masked_select(mask)
                ).mean()

    # Pitch regularization
    # NOTE: L1 loss seems to work better than MSE loss in my experiments;
    # L2 loss, as suggested in the Sinsy paper, could be used instead
    if lf0_residual is not None:
        with autocast(enabled=grad_scaler is not None):
            if isinstance(lf0_residual, list):
                loss_pitch = 0
                for lf0_residual_ in lf0_residual:
                    loss_pitch += (
                        (pitch_reg_dyn_ws * lf0_residual_.abs())
                        .masked_select(mask)
                        .mean()
                    )
            else:
                loss_pitch = (
                    (pitch_reg_dyn_ws * lf0_residual.abs()).masked_select(mask).mean()
                )
    else:
        loss_pitch = torch.tensor(0.0).to(in_feats.device)

    loss = loss_feats + pitch_reg_weight * loss_pitch

    if prediction_type == PredictionType.PROBABILISTIC:
        with torch.no_grad():
            pred_out_feats_ = mdn_get_most_probable_sigma_and_mu(pi, sigma, mu)[1]
    else:
        if isinstance(pred_out_feats, list):
            pred_out_feats_ = pred_out_feats[-1]
        else:
            pred_out_feats_ = pred_out_feats
    distortions = compute_distortions(
        pred_out_feats_, out_feats, lengths, out_scaler, model_config
    )

    if train:
        if grad_scaler is not None:
            grad_scaler.scale(loss).backward()
            grad_scaler.step(optimizer)
            grad_scaler.update()
        else:
            loss.backward()
            optimizer.step()

    log_metrics.update(distortions)
    log_metrics.update(
        {
            "Loss": loss.item(),
            "Loss_Feats": loss_feats.item(),
            "Loss_Pitch": loss_pitch.item(),
        }
    )

    return loss, log_metrics
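
When the model is probabilistic, the distortion metrics are computed from the mean of the most likely mixture component via `mdn_get_most_probable_sigma_and_mu`. A rough sketch of that selection, assuming `pi` is (B, T, G) and `sigma`/`mu` are (B, T, G, D_out), not the library's exact implementation:

import torch


def mdn_get_most_probable_sigma_and_mu_sketch(pi, sigma, mu):
    """Illustrative only: pick sigma/mu of the highest-weight component per frame."""
    # Index of the most probable component: (B, T)
    top = torch.argmax(pi, dim=-1)
    # Expand the index so it can gather along the component axis
    idx = top[..., None, None].expand(-1, -1, 1, mu.shape[-1])  # (B, T, 1, D_out)
    sigma_max = torch.gather(sigma, 2, idx).squeeze(2)  # (B, T, D_out)
    mu_max = torch.gather(mu, 2, idx).squeeze(2)  # (B, T, D_out)
    return sigma_max, mu_max

Note that the argmax is unchanged whether `pi` stores weights or log-weights, since the logarithm is monotonic.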
Example 4
def train_loop(config, device, model, optimizer, lr_scheduler, data_loaders):
    criterion = nn.MSELoss(reduction="none")

    logger.info("Start utterance-wise training...")

    stream_weights = get_stream_weight(
        config.model.stream_weights, config.model.stream_sizes).to(device)

    best_loss = 10000000
    for epoch in tqdm(range(1, config.train.nepochs + 1)):
        for phase in data_loaders.keys():
            train = phase.startswith("train")
            model.train() if train else model.eval()
            running_loss = 0
            for x, y, lengths in data_loaders[phase]:
                # Sort by lengths. This is needed for PyTorch's PackedSequence
                sorted_lengths, indices = torch.sort(lengths, dim=0, descending=True)
                x, y = x[indices].to(device), y[indices].to(device)

                optimizer.zero_grad()

                # Apply preprocess if required (e.g., FIR filter for shallow AR)
                # defaults to no-op
                y = model.preprocess_target(y)

                # Run forward
                if model.prediction_type() == PredictionType.PROBABILISTIC:
                    pi, sigma, mu = model(x, sorted_lengths)

                    # (B, max(T))
                    mask = make_non_pad_mask(sorted_lengths).to(device)
                    # Compute loss and apply mask
                    loss = mdn_loss(pi, sigma, mu, y, reduce=False).masked_select(mask).mean()

                else:
                    y_hat = model(x, sorted_lengths)

                    # Compute loss
                    mask = make_non_pad_mask(sorted_lengths).unsqueeze(-1).to(device)

                    if config.train.stream_wise_loss:
                        # Stream-wise loss
                        streams = split_streams(y, config.model.stream_sizes)
                        streams_hat = split_streams(y_hat, config.model.stream_sizes)
                        loss = 0
                        for s_hat, s, sw in zip(streams_hat, streams, stream_weights):
                            s_hat_mask = s_hat.masked_select(mask)
                            s_mask = s.masked_select(mask)
                            loss += sw * criterion(s_hat_mask, s_mask).mean()
                    else:
                        # Joint modeling
                        y_hat = y_hat.masked_select(mask)
                        y = y.masked_select(mask)
                        loss = criterion(y_hat, y).mean()

                if train:
                    loss.backward()
                    optimizer.step()

                running_loss += loss.item()
            ave_loss = running_loss / len(data_loaders[phase])
            logger.info(f"[{phase}] [Epoch {epoch}]: loss {ave_loss}")
            if not train and ave_loss < best_loss:
                best_loss = ave_loss
                save_best_checkpoint(config, model, optimizer, best_loss)

        # Step the LR scheduler once per epoch (could instead update per iteration)
        lr_scheduler.step()

        if epoch % config.train.checkpoint_epoch_interval == 0:
            save_checkpoint(config, model, optimizer, lr_scheduler, epoch)

    # save at last epoch
    save_checkpoint(config, model, optimizer, lr_scheduler, config.train.nepochs)
    logger.info(f"The best loss was {best_loss}")

    return model
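
The stream-wise loss branch in Examples 2 and 4 relies on `split_streams` and `get_stream_weight`. The sketches below illustrate the assumed behavior (features concatenated along the last axis, uniform weights when none are configured); they are not the library's exact implementations.

import torch


def split_streams_sketch(feats, stream_sizes):
    """Split (B, T, sum(stream_sizes)) features into one tensor per stream."""
    return list(torch.split(feats, list(stream_sizes), dim=-1))


def get_stream_weight_sketch(stream_weights, stream_sizes):
    """Per-stream loss weights; fall back to uniform weights when unspecified."""
    if stream_weights is not None:
        return torch.tensor(stream_weights, dtype=torch.float32)
    return torch.ones(len(stream_sizes), dtype=torch.float32)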