Ejemplo n.º 1
0
def train_step(
    model_config,
    optim_config,
    netG,
    optG,
    netD,
    optD,
    grad_scaler,
    train,
    in_feats,
    out_feats,
    lengths,
    out_scaler,
    feats_criterion="mse",
    pitch_reg_dyn_ws=1.0,
    pitch_reg_weight=1.0,
    adv_weight=1.0,
    adv_streams=None,
    fm_weight=0.0,
    adv_use_static_feats_only=True,
    mask_nth_mgc_for_adv_loss=0,
    gan_type="lsgan",
):
    """Run one GAN training (or evaluation) step for generator and discriminator.

    The discriminator is updated first on real/detached-fake features, then
    the generator is updated with a feature loss, an adversarial loss, a
    pitch regularization term and an optional feature-matching loss.

    Args:
        model_config: config providing ``num_windows``, ``stream_sizes`` and
            ``has_dynamic_features``.
        optim_config: config providing ``netG.clip_norm`` and ``netD.clip_norm``.
        netG: generator; may be wrapped in ``nn.DataParallel``.
        optG: optimizer for ``netG``.
        netD: discriminator; must return a list of lists (one inner list per
            scale, whose last element is the discriminator output).
        optD: optimizer for ``netD``.
        grad_scaler: AMP gradient scaler, or None to disable autocast.
        train: if True, run backward passes and optimizer steps.
        in_feats: generator input features.
        out_feats: target features.
        lengths: per-utterance lengths used to build the padding mask.
        out_scaler: scaler forwarded to ``compute_distortions``.
        feats_criterion: ``"l2"``/``"mse"`` or ``"l1"``/``"mae"``.
        pitch_reg_dyn_ws: weights multiplied with the absolute lf0 residual.
        pitch_reg_weight: weight of the pitch regularization term.
        adv_weight: weight of the adversarial loss.
        adv_streams: stream selection forwarded to the stream-selection helpers.
        fm_weight: weight of the feature matching loss (disabled when <= 0).
        adv_use_static_feats_only: feed only static features to ``netD``.
        mask_nth_mgc_for_adv_loss: number of leading feature dims dropped
            from the discriminator input.
        gan_type: ``"lsgan"`` or ``"vanilla-gan"``.

    Returns:
        Tuple of (generator loss tensor, dict of logged metrics).

    Raises:
        RuntimeError: if ``feats_criterion`` is not supported.
        ValueError: if ``gan_type`` is unknown.
    """
    netG.train() if train else netG.eval()
    netD.train() if train else netD.eval()
    log_metrics = {}
    use_amp = grad_scaler is not None

    if feats_criterion in ["l2", "mse"]:
        criterion = nn.MSELoss(reduction="none")
    elif feats_criterion in ["l1", "mae"]:
        criterion = nn.L1Loss(reduction="none")
    else:
        raise RuntimeError("not supported criterion")

    # Unwrap DataParallel so that model-specific methods can be reached
    netG_core = netG.module if isinstance(netG, nn.DataParallel) else netG
    prediction_type = netG_core.prediction_type()
    # NOTE: it is not trivial to adapt GAN for probabilistic models
    assert prediction_type != PredictionType.PROBABILISTIC

    # Apply preprocess if required (e.g., FIR filter for shallow AR)
    # defaults to no-op
    out_feats = netG_core.preprocess_target(out_feats)

    def _select_adv_feats(feats):
        """Pick the features that are fed to the discriminator."""
        if adv_use_static_feats_only:
            return torch.cat(
                get_static_features(
                    feats,
                    model_config.num_windows,
                    model_config.stream_sizes,
                    model_config.has_dynamic_features,
                    adv_streams,
                ),
                dim=-1,
            )
        return select_streams(feats, model_config.stream_sizes, adv_streams)

    # Run forward
    with autocast(enabled=use_amp):
        pred_out_feats, lf0_residual = netG(in_feats, lengths)

        # Select streams for computing adversarial loss
        real_netD_in_feats = _select_adv_feats(out_feats)
        fake_netD_in_feats = _select_adv_feats(pred_out_feats)

    # Ref: http://sython.org/papers/ASJ/saito2017asja.pdf
    # 0-th mgc with adversarial training affects speech quality
    # NOTE: assuming that the first stream contains mgc
    if mask_nth_mgc_for_adv_loss > 0:
        real_netD_in_feats = real_netD_in_feats[:, :, mask_nth_mgc_for_adv_loss:]
        fake_netD_in_feats = fake_netD_in_feats[:, :, mask_nth_mgc_for_adv_loss:]

    with autocast(enabled=use_amp):
        # Real
        D_real = netD(real_netD_in_feats, in_feats, lengths)
        # NOTE: must be list of list to support multi-scale discriminators
        assert isinstance(D_real, list) and isinstance(D_real[-1], list)
        # Fake (detached: the discriminator update must not touch netG)
        D_fake_det = netD(fake_netD_in_feats.detach(), in_feats, lengths)

    # Mask (B, T, 1)
    mask = make_non_pad_mask(lengths).unsqueeze(-1).to(in_feats.device)

    def _masked_mean(loss_map, n_frames):
        """Average a per-frame loss, ignoring padded frames when possible.

        ``n_frames`` is the time length of the discriminator output the loss
        was computed from; no mask is applied if no matching mask exists.
        """
        if (
            hasattr(netD, "downsample_scale")
            and mask.shape[1] // netD.downsample_scale == n_frames
        ):
            d_mask = mask[:, :: netD.downsample_scale, :]
        elif n_frames == out_feats.shape[1]:
            d_mask = mask
        else:
            d_mask = None

        if d_mask is None:
            return loss_map.mean()
        return loss_map.masked_select(d_mask).mean()

    # Update discriminator
    eps = 1e-14
    loss_real = 0
    loss_fake = 0

    with autocast(enabled=use_amp):
        for idx, (D_real_, D_fake_det_) in enumerate(zip(D_real, D_fake_det)):
            if gan_type == "lsgan":
                loss_real_ = (D_real_[-1] - 1) ** 2
                loss_fake_ = D_fake_det_[-1] ** 2
            elif gan_type == "vanilla-gan":
                loss_real_ = -torch.log(D_real_[-1] + eps)
                loss_fake_ = -torch.log(1 - D_fake_det_[-1] + eps)
            else:
                raise ValueError(f"Unknown gan type: {gan_type}")

            # mask for D
            n_frames = D_real_[-1].shape[1]
            loss_real_ = _masked_mean(loss_real_, n_frames)
            loss_fake_ = _masked_mean(loss_fake_, n_frames)

            log_metrics[f"Loss_Real_Scale{idx}"] = loss_real_.item()
            log_metrics[f"Loss_Fake_Scale{idx}"] = loss_fake_.item()

            loss_real += loss_real_
            loss_fake += loss_fake_

        loss_d = loss_real + loss_fake

    if train:
        optD.zero_grad()
        if use_amp:
            grad_scaler.scale(loss_d).backward()
            # Gradients must be unscaled before clipping
            grad_scaler.unscale_(optD)
        else:
            loss_d.backward()
        grad_norm_d = torch.nn.utils.clip_grad_norm_(
            netD.parameters(), optim_config.netD.clip_norm
        )
        log_metrics["GradNorm_D"] = grad_norm_d
        if use_amp:
            grad_scaler.step(optD)
        else:
            optD.step()

    # Update generator
    with autocast(enabled=use_amp):
        loss_feats = criterion(
            pred_out_feats.masked_select(mask), out_feats.masked_select(mask)
        ).mean()

        # adversarial loss (non-detached fake so that gradients reach netG)
        D_fake = netD(fake_netD_in_feats, in_feats, lengths)
        loss_adv = 0
        for idx, D_fake_ in enumerate(D_fake):
            if gan_type == "lsgan":
                loss_adv_ = (1 - D_fake_[-1]) ** 2
            elif gan_type == "vanilla-gan":
                loss_adv_ = -torch.log(D_fake_[-1] + eps)
            else:
                raise ValueError(f"Unknown gan type: {gan_type}")

            # The mask is chosen from the output of the *current* scale.
            # Relying on a variable left over from the discriminator loop
            # would always look at the last scale only.
            loss_adv_ = _masked_mean(loss_adv_, D_fake_[-1].shape[1])

            log_metrics[f"Loss_Adv_Scale{idx}"] = loss_adv_.item()

            loss_adv += loss_adv_

        # Feature matching loss
        loss_fm = torch.tensor(0.0).to(in_feats.device)
        if fm_weight > 0:
            for D_fake_, D_real_ in zip(D_fake, D_real):
                for fake_fmap, real_fmap in zip(D_fake_[:-1], D_real_[:-1]):
                    loss_fm += F.l1_loss(fake_fmap, real_fmap.detach())

        # Pitch regularization
        # NOTE: l1 loss seems to be better than mse loss in my experiments
        # we could use l2 loss as suggested in the sinsy's paper
        loss_pitch = (pitch_reg_dyn_ws * lf0_residual.abs()).masked_select(mask).mean()

        loss = (
            loss_feats
            + adv_weight * loss_adv
            + pitch_reg_weight * loss_pitch
            + fm_weight * loss_fm
        )

    if train:
        optG.zero_grad()
        if use_amp:
            grad_scaler.scale(loss).backward()
            grad_scaler.unscale_(optG)
        else:
            loss.backward()
        grad_norm_g = torch.nn.utils.clip_grad_norm_(
            netG.parameters(), optim_config.netG.clip_norm
        )
        log_metrics["GradNorm_G"] = grad_norm_g
        if use_amp:
            grad_scaler.step(optG)
        else:
            optG.step()

    # NOTE: this shouldn't be called multiple times in a training step
    if train and use_amp:
        grad_scaler.update()

    # Metrics
    distortions = compute_distortions(
        pred_out_feats, out_feats, lengths, out_scaler, model_config
    )
    log_metrics.update(distortions)
    log_metrics.update(
        {
            "Loss": loss.item(),
            "Loss_Feats": loss_feats.item(),
            "Loss_Adv_Total": loss_adv.item(),
            "Loss_Feature_Matching": loss_fm.item(),
            "Loss_Pitch": loss_pitch.item(),
            "Loss_Real_Total": loss_real.item(),
            "Loss_Fake_Total": loss_fake.item(),
            "Loss_D": loss_d.item(),
        }
    )

    return loss, log_metrics
Ejemplo n.º 2
0
def train_loop(config, device, model, optimizer, lr_scheduler, data_loaders):
    """Train ``model`` utterance-wise with a masked MSE loss.

    For every epoch each phase in ``data_loaders`` is processed; phases whose
    name starts with ``"train"`` update the model, all others are evaluated
    only. The best non-training loss triggers ``save_best_checkpoint``.

    Args:
        config: config providing ``model.stream_weights``,
            ``model.stream_sizes``, ``train.nepochs``,
            ``train.stream_wise_loss`` and ``train.checkpoint_epoch_interval``.
        device: device the batches and the mask are moved to.
        model: model called as ``model(x, sorted_lengths)``.
        optimizer: optimizer stepped once per training batch.
        lr_scheduler: scheduler stepped once per epoch.
        data_loaders: mapping from phase name to an iterable of
            ``(x, y, lengths)`` batches.

    Returns:
        The trained model.
    """
    criterion = nn.MSELoss(reduction="none")
    logger.info("Start utterance-wise training...")

    stream_weights = get_stream_weight(
        config.model.stream_weights, config.model.stream_sizes
    ).to(device)

    # Any finite loss must be able to become the first "best" loss
    best_loss = float("inf")
    for epoch in tqdm(range(1, config.train.nepochs + 1)):
        for phase, data_loader in data_loaders.items():
            train = phase.startswith("train")
            model.train() if train else model.eval()
            running_loss = 0
            for x, y, lengths in data_loader:
                # Sort by lengths. This is needed for pytorch's PackedSequence
                sorted_lengths, indices = torch.sort(lengths, dim=0, descending=True)
                x, y = x[indices].to(device), y[indices].to(device)

                optimizer.zero_grad()

                # No autograd graph is needed for non-training phases
                with torch.set_grad_enabled(train):
                    # Run forward
                    y_hat = model(x, sorted_lengths)

                    # Compute loss
                    mask = make_non_pad_mask(sorted_lengths).unsqueeze(-1).to(device)

                    if config.train.stream_wise_loss:
                        # Stream-wise loss
                        streams = split_streams(y, config.model.stream_sizes)
                        streams_hat = split_streams(y_hat, config.model.stream_sizes)
                        loss = 0
                        for s_hat, s, sw in zip(streams_hat, streams, stream_weights):
                            s_hat_mask = s_hat.masked_select(mask)
                            s_mask = s.masked_select(mask)
                            loss += sw * criterion(s_hat_mask, s_mask).mean()
                    else:
                        # Joint modeling
                        y_hat = y_hat.masked_select(mask)
                        y = y.masked_select(mask)
                        loss = criterion(y_hat, y).mean()

                if train:
                    loss.backward()
                    optimizer.step()

                running_loss += loss.item()
            ave_loss = running_loss / len(data_loader)
            logger.info(f"[{phase}] [Epoch {epoch}]: loss {ave_loss}")
            if not train and ave_loss < best_loss:
                best_loss = ave_loss
                save_best_checkpoint(config, model, optimizer, best_loss)

        # step per each epoch (may consider updating per iter.)
        lr_scheduler.step()

        if epoch % config.train.checkpoint_epoch_interval == 0:
            save_checkpoint(config, model, optimizer, lr_scheduler, epoch)

    # save at last epoch
    save_checkpoint(config, model, optimizer, lr_scheduler, config.train.nepochs)
    logger.info(f"The best loss was {best_loss}")

    return model
Ejemplo n.º 3
0
def train_step(
    model,
    model_config,
    optimizer,
    grad_scaler,
    train,
    in_feats,
    out_feats,
    lengths,
    out_scaler,
    feats_criterion="mse",
    pitch_reg_dyn_ws=1.0,
    pitch_reg_weight=1.0,
):
    """Run a single training or evaluation step of an acoustic model.

    The total loss is the masked feature loss plus the weighted pitch
    regularization term. Autocast is enabled iff ``grad_scaler`` is given.

    Args:
        model: model called as ``model(in_feats, lengths, out_feats)``; may be
            wrapped in ``nn.DataParallel``. It returns either predictions or
            a 2-tuple ``(predictions, lf0_residual)``.
        model_config: config forwarded to ``compute_distortions``.
        optimizer: optimizer for ``model``.
        grad_scaler: AMP gradient scaler or None.
        train: whether to back-propagate and step the optimizer.
        in_feats: input features.
        out_feats: target features.
        lengths: per-utterance lengths used to build the padding mask.
        out_scaler: scaler forwarded to ``compute_distortions``.
        feats_criterion: ``"l2"``/``"mse"`` or ``"l1"``/``"mae"``.
        pitch_reg_dyn_ws: weights multiplied with the absolute lf0 residual.
        pitch_reg_weight: weight of the pitch regularization term.

    Returns:
        Tuple of (loss tensor, dict of logged metrics).

    Raises:
        RuntimeError: if ``feats_criterion`` is not supported.
    """
    if train:
        model.train()
    else:
        model.eval()
    optimizer.zero_grad()

    if feats_criterion in ("l2", "mse"):
        criterion = nn.MSELoss(reduction="none")
    elif feats_criterion in ("l1", "mae"):
        criterion = nn.L1Loss(reduction="none")
    else:
        raise RuntimeError("not supported criterion")

    amp_enabled = grad_scaler is not None

    # Model-specific methods live on the wrapped module under DataParallel
    bare_model = model.module if isinstance(model, nn.DataParallel) else model
    prediction_type = bare_model.prediction_type()

    # Optional target preprocessing (e.g., FIR filter for shallow AR);
    # a no-op by default
    out_feats = bare_model.preprocess_target(out_feats)

    # Forward pass
    with autocast(enabled=amp_enabled):
        outs = model(in_feats, lengths, out_feats)
    if isinstance(outs, tuple) and len(outs) == 2:
        pred_out_feats, lf0_residual = outs
    else:
        pred_out_feats, lf0_residual = outs, None

    # Mask (B, T, 1)
    mask = make_non_pad_mask(lengths).unsqueeze(-1).to(in_feats.device)

    is_probabilistic = prediction_type == PredictionType.PROBABILISTIC

    def _feats_loss(prediction):
        """Masked mean of the element-wise criterion against the target."""
        return criterion(
            prediction.masked_select(mask), out_feats.masked_select(mask)
        ).mean()

    def _pitch_loss(residual):
        """Masked mean of the weighted absolute lf0 residual."""
        return (pitch_reg_dyn_ws * residual.abs()).masked_select(mask).mean()

    # Feature loss
    if is_probabilistic:
        pi, sigma, mu = pred_out_feats
        # (B, max(T)) or (B, max(T), D_out)
        mdn_mask = mask if len(pi.shape) == 4 else mask.squeeze(-1)
        with autocast(enabled=amp_enabled):
            loss_feats = mdn_loss(pi, sigma, mu, out_feats, reduce=False)
            loss_feats = loss_feats.masked_select(mdn_mask).mean()
    else:
        with autocast(enabled=amp_enabled):
            if isinstance(pred_out_feats, list):
                # NOTE: multiple predictions; their losses are summed
                loss_feats = sum(_feats_loss(pred) for pred in pred_out_feats)
            else:
                loss_feats = _feats_loss(pred_out_feats)

    # Pitch regularization
    # NOTE: l1 loss seems to be better than mse loss in my experiments
    # we could use l2 loss as suggested in the sinsy's paper
    if lf0_residual is None:
        loss_pitch = torch.tensor(0.0).to(in_feats.device)
    else:
        with autocast(enabled=amp_enabled):
            if isinstance(lf0_residual, list):
                loss_pitch = sum(_pitch_loss(res) for res in lf0_residual)
            else:
                loss_pitch = _pitch_loss(lf0_residual)

    loss = loss_feats + pitch_reg_weight * loss_pitch

    # Point estimate used for the distortion metrics
    if is_probabilistic:
        with torch.no_grad():
            pred_for_metrics = mdn_get_most_probable_sigma_and_mu(pi, sigma, mu)[1]
    elif isinstance(pred_out_feats, list):
        pred_for_metrics = pred_out_feats[-1]
    else:
        pred_for_metrics = pred_out_feats
    distortions = compute_distortions(
        pred_for_metrics, out_feats, lengths, out_scaler, model_config
    )

    if train:
        if amp_enabled:
            grad_scaler.scale(loss).backward()
            grad_scaler.step(optimizer)
            grad_scaler.update()
        else:
            loss.backward()
            optimizer.step()

    log_metrics = dict(distortions)
    log_metrics["Loss"] = loss.item()
    log_metrics["Loss_Feats"] = loss_feats.item()
    log_metrics["Loss_Pitch"] = loss_pitch.item()

    return loss, log_metrics
Ejemplo n.º 4
0
def train_step(
    model_config,
    optim_config,
    netG_A2B,
    netG_B2A,
    optG,
    netD_A,
    netD_B,
    optD,
    grad_scaler,
    train,
    in_feats,
    out_feats,
    lengths,
    out_scaler,
    adv_weight=1.0,
    adv_streams=None,
    fm_weight=0.0,
    mask_nth_mgc_for_adv_loss=0,
    gan_type="lsgan",
    vuv_mask=False,
    cycle_weight=10.0,
    id_weight=5.0,
    use_id_loss=True,
):
    """Run one CycleGAN-style training (or evaluation) step.

    Both discriminators are updated first with a shared optimizer, then both
    generators are updated with adversarial, cycle-consistency, identity and
    optional feature-matching losses.

    Args:
        model_config: config providing ``stream_sizes``.
        optim_config: config providing ``netG.clip_norm`` and ``netD.clip_norm``.
        netG_A2B: generator applied to ``in_feats`` (domain A -> B).
        netG_B2A: generator applied to ``out_feats`` (domain B -> A).
        optG: optimizer for both generators.
        netD_A: discriminator for domain A; returns a list of lists whose
            inner last element is the discriminator output.
        netD_B: discriminator for domain B; same output convention.
        optD: optimizer for both discriminators.
        grad_scaler: AMP gradient scaler, or None to disable autocast.
        train: if True, run backward passes and optimizer steps.
        in_feats: features of domain A.
        out_feats: features of domain B.
        lengths: per-utterance lengths used to build the padding mask.
        out_scaler: scaler forwarded to ``compute_distortions``.
        adv_weight: weight of the adversarial loss.
        adv_streams: stream selection forwarded to ``select_streams``.
        fm_weight: weight of the feature matching loss (disabled when <= 0).
        mask_nth_mgc_for_adv_loss: number of leading feature dims dropped
            from the discriminator input.
        gan_type: ``"lsgan"``, ``"vanilla-gan"`` or ``"hinge"``.
        vuv_mask: if True, multiply features by a voiced-frame mask.
        cycle_weight: weight of the cycle-consistency loss.
        id_weight: weight of the identity mapping loss.
        use_id_loss: whether to compute the identity mapping loss.

    Returns:
        Tuple of (generator loss tensor, dict of logged metrics).

    Raises:
        ValueError: if ``gan_type`` is unknown.
    """
    for net in (netG_A2B, netG_B2A, netD_A, netD_B):
        net.train() if train else net.eval()

    log_metrics = {}
    use_amp = grad_scaler is not None

    if vuv_mask:
        # NOTE: Assuming 3rd stream is the V/UV
        vuv_idx = np.sum(model_config.stream_sizes[:2])
        # A frame counts as voiced only if it is voiced in both domains
        vuv = torch.logical_and(
            out_feats[:, :, vuv_idx : vuv_idx + 1] > 0,
            in_feats[:, :, vuv_idx : vuv_idx + 1] > 0,
        )
    else:
        vuv = 1.0

    # Run forward A2B and B2A
    with autocast(enabled=use_amp):
        pred_out_feats_A = netG_B2A(out_feats, lengths)
        pred_out_feats_B = netG_A2B(in_feats, lengths)

        # Cycle consistency loss
        loss_cycle = F.l1_loss(
            netG_A2B(pred_out_feats_A, lengths) * vuv, out_feats * vuv
        ) + F.l1_loss(netG_B2A(pred_out_feats_B, lengths) * vuv, in_feats * vuv)

        # Identity mapping loss
        if use_id_loss and id_weight > 0:
            loss_id = F.l1_loss(
                netG_A2B(out_feats, lengths) * vuv, out_feats * vuv
            ) + F.l1_loss(netG_B2A(in_feats, lengths) * vuv, in_feats * vuv)
        else:
            loss_id = torch.tensor(0.0).to(in_feats.device)

    real_netD_in_feats_A = select_streams(
        in_feats, model_config.stream_sizes, adv_streams
    )
    real_netD_in_feats_B = select_streams(
        out_feats, model_config.stream_sizes, adv_streams
    )
    fake_netD_in_feats_A = select_streams(
        pred_out_feats_A, model_config.stream_sizes, adv_streams
    )
    fake_netD_in_feats_B = select_streams(
        pred_out_feats_B, model_config.stream_sizes, adv_streams
    )

    # Ref: http://sython.org/papers/ASJ/saito2017asja.pdf
    # 0-th mgc with adversarial training affects speech quality
    # NOTE: assuming that the first stream contains mgc
    if mask_nth_mgc_for_adv_loss > 0:
        real_netD_in_feats_A = real_netD_in_feats_A[:, :, mask_nth_mgc_for_adv_loss:]
        real_netD_in_feats_B = real_netD_in_feats_B[:, :, mask_nth_mgc_for_adv_loss:]
        fake_netD_in_feats_A = fake_netD_in_feats_A[:, :, mask_nth_mgc_for_adv_loss:]
        fake_netD_in_feats_B = fake_netD_in_feats_B[:, :, mask_nth_mgc_for_adv_loss:]

    with autocast(enabled=use_amp):
        # Real
        D_real_A = netD_A(real_netD_in_feats_A * vuv, in_feats, lengths)
        D_real_B = netD_B(real_netD_in_feats_B * vuv, in_feats, lengths)
        # Fake (detached: the discriminator update must not touch generators)
        D_fake_det_A = netD_A(fake_netD_in_feats_A.detach() * vuv, in_feats, lengths)
        D_fake_det_B = netD_B(fake_netD_in_feats_B.detach() * vuv, in_feats, lengths)

    # Mask (B, T, 1)
    mask = make_non_pad_mask(lengths).unsqueeze(-1).to(in_feats.device)

    eps = 1e-14

    def _masked_mean(loss_map, netD, n_frames):
        """Average a per-frame loss, ignoring padded frames when possible.

        ``n_frames`` is the time length of the discriminator output the loss
        was computed from; no mask is applied if no matching mask exists.
        """
        if (
            hasattr(netD, "downsample_scale")
            and mask.shape[1] // netD.downsample_scale == n_frames
        ):
            d_mask = mask[:, :: netD.downsample_scale, :]
        elif n_frames == out_feats.shape[1]:
            d_mask = mask
        else:
            d_mask = None

        if d_mask is None:
            return loss_map.mean()
        return loss_map.masked_select(d_mask).mean()

    def _disc_terms(d_real_out, d_fake_out):
        """Per-frame discriminator losses for real and (detached) fake inputs."""
        if gan_type == "lsgan":
            return (d_real_out - 1) ** 2, d_fake_out ** 2
        if gan_type == "vanilla-gan":
            return -torch.log(d_real_out + eps), -torch.log(1 - d_fake_out + eps)
        if gan_type == "hinge":
            return F.relu(1 - d_real_out), F.relu(1 + d_fake_out)
        raise ValueError(f"Unknown gan type: {gan_type}")

    def _gen_term(d_fake_out):
        """Per-frame adversarial loss for the generator."""
        if gan_type == "lsgan":
            return (1 - d_fake_out) ** 2
        if gan_type == "vanilla-gan":
            return -torch.log(d_fake_out + eps)
        if gan_type == "hinge":
            return -d_fake_out
        raise ValueError(f"Unknown gan type: {gan_type}")

    # Update discriminator
    loss_real = 0
    loss_fake = 0

    with autocast(enabled=use_amp):
        for side, netD, D_real, D_fake_det in (
            ("A", netD_A, D_real_A, D_fake_det_A),
            ("B", netD_B, D_real_B, D_fake_det_B),
        ):
            for idx, (D_real_, D_fake_det_) in enumerate(zip(D_real, D_fake_det)):
                loss_real_, loss_fake_ = _disc_terms(D_real_[-1], D_fake_det_[-1])

                # mask for D
                n_frames = D_real_[-1].shape[1]
                loss_real_ = _masked_mean(loss_real_, netD, n_frames)
                loss_fake_ = _masked_mean(loss_fake_, netD, n_frames)

                log_metrics[f"Loss_Real_Scale{idx}_{side}"] = loss_real_.item()
                log_metrics[f"Loss_Fake_Scale{idx}_{side}"] = loss_fake_.item()

                loss_real += loss_real_
                loss_fake += loss_fake_

        loss_d = loss_real + loss_fake

    if train:
        optD.zero_grad()
        if use_amp:
            grad_scaler.scale(loss_d).backward()
            # Gradients must be unscaled before clipping
            grad_scaler.unscale_(optD)
        else:
            loss_d.backward()
        # NOTE(review): the keys below mention netG_* although the norms are
        # those of netD_A / netD_B; kept unchanged for logging compatibility.
        grad_norm_d = torch.nn.utils.clip_grad_norm_(
            netD_A.parameters(), optim_config.netD.clip_norm
        )
        log_metrics["GradNorm_D/netG_A2B"] = grad_norm_d
        grad_norm_d = torch.nn.utils.clip_grad_norm_(
            netD_B.parameters(), optim_config.netD.clip_norm
        )
        log_metrics["GradNorm_D/netG_B2A"] = grad_norm_d
        if use_amp:
            grad_scaler.step(optD)
        else:
            optD.step()

    # adversarial loss
    loss_adv = 0
    D_fake_by_side = {}

    with autocast(enabled=use_amp):
        for side, netD, fake_netD_in_feats in (
            ("A", netD_A, fake_netD_in_feats_A),
            ("B", netD_B, fake_netD_in_feats_B),
        ):
            # Non-detached fake so that gradients reach the generators
            D_fake = netD(fake_netD_in_feats * vuv, in_feats, lengths)
            D_fake_by_side[side] = D_fake
            for idx, D_fake_ in enumerate(D_fake):
                # The mask is chosen from the output of the *current* scale.
                # Relying on a variable left over from the discriminator loop
                # would always look at the last scale of netD_B only.
                loss_adv_ = _masked_mean(
                    _gen_term(D_fake_[-1]), netD, D_fake_[-1].shape[1]
                )

                log_metrics[f"Loss_Adv_Scale{idx}_{side}"] = loss_adv_.item()

                loss_adv += loss_adv_

        # Feature matching loss
        loss_fm = torch.tensor(0.0).to(in_feats.device)
        if fm_weight > 0:
            for D_fake, D_real in (
                (D_fake_by_side["A"], D_real_A),
                (D_fake_by_side["B"], D_real_B),
            ):
                for D_fake_, D_real_ in zip(D_fake, D_real):
                    for fake_fmap, real_fmap in zip(D_fake_[:-1], D_real_[:-1]):
                        loss_fm += F.l1_loss(fake_fmap, real_fmap.detach())

    loss = (
        adv_weight * loss_adv
        + cycle_weight * loss_cycle
        + id_weight * loss_id
        + fm_weight * loss_fm
    )

    if train:
        optG.zero_grad()
        if use_amp:
            grad_scaler.scale(loss).backward()
            grad_scaler.unscale_(optG)
        else:
            loss.backward()
        grad_norm_g = torch.nn.utils.clip_grad_norm_(
            netG_A2B.parameters(), optim_config.netG.clip_norm
        )
        log_metrics["GradNorm_G/netG_A2B"] = grad_norm_g
        grad_norm_g = torch.nn.utils.clip_grad_norm_(
            netG_B2A.parameters(), optim_config.netG.clip_norm
        )
        log_metrics["GradNorm_G/netG_B2A"] = grad_norm_g
        if use_amp:
            grad_scaler.step(optG)
        else:
            optG.step()

    # NOTE: this shouldn't be called multiple times in a training step
    if train and use_amp:
        grad_scaler.update()

    # Metrics
    distortions = compute_distortions(
        pred_out_feats_B, out_feats, lengths, out_scaler, model_config
    )
    log_metrics.update(distortions)
    log_metrics.update(
        {
            "Loss": loss.item(),
            "Loss_Adv_Total": loss_adv.item(),
            "Loss_Cycle": loss_cycle.item(),
            "Loss_Identity": loss_id.item(),
            "Loss_Feature_Matching": loss_fm.item(),
            "Loss_Real_Total": loss_real.item(),
            "Loss_Fake_Total": loss_fake.item(),
            "Loss_D": loss_d.item(),
        }
    )

    return loss, log_metrics
Ejemplo n.º 5
0
def train_step(
    model,
    optimizer,
    grad_scaler,
    train,
    in_feats,
    out_feats,
    lengths,
    out_scaler,
    feats_criterion="mse",
    stream_wise_loss=False,
    stream_weights=None,
    stream_sizes=None,
):
    """Run a single training or evaluation step with an optional stream-wise loss.

    Autocast is enabled iff ``grad_scaler`` is given.

    Args:
        model: model called as ``model(in_feats, lengths)``; may be wrapped in
            ``nn.DataParallel``.
        optimizer: optimizer for ``model``.
        grad_scaler: AMP gradient scaler or None.
        train: whether to back-propagate and step the optimizer.
        in_feats: input features.
        out_feats: target features.
        lengths: per-utterance lengths used to build the padding mask.
        out_scaler: scaler forwarded to ``compute_distortions``.
        feats_criterion: ``"l2"``/``"mse"`` or ``"l1"``/``"mae"``.
        stream_wise_loss: if True, sum weighted per-stream losses
            (non-probabilistic models only).
        stream_weights: forwarded to ``get_stream_weight``.
        stream_sizes: forwarded to ``get_stream_weight`` and ``split_streams``.

    Returns:
        Tuple of (loss tensor, distortions as returned by
        ``compute_distortions``).

    Raises:
        RuntimeError: if ``feats_criterion`` is not supported.
    """
    if train:
        model.train()
    else:
        model.eval()
    optimizer.zero_grad()

    if feats_criterion in ("l2", "mse"):
        criterion = nn.MSELoss(reduction="none")
    elif feats_criterion in ("l1", "mae"):
        criterion = nn.L1Loss(reduction="none")
    else:
        raise RuntimeError("not supported criterion")

    amp_enabled = grad_scaler is not None

    # Model-specific methods live on the wrapped module under DataParallel
    bare_model = model.module if isinstance(model, nn.DataParallel) else model
    prediction_type = bare_model.prediction_type()

    # Optional target preprocessing (e.g., FIR filter for shallow AR);
    # a no-op by default
    out_feats = bare_model.preprocess_target(out_feats)

    # Forward pass
    with autocast(enabled=amp_enabled):
        pred_out_feats = model(in_feats, lengths)

    # Mask (B, T, 1)
    mask = make_non_pad_mask(lengths).unsqueeze(-1).to(in_feats.device)

    is_probabilistic = prediction_type == PredictionType.PROBABILISTIC

    def _masked_loss(prediction, target):
        """Masked mean of the element-wise criterion."""
        return criterion(
            prediction.masked_select(mask), target.masked_select(mask)
        ).mean()

    # Loss
    if is_probabilistic:
        pi, sigma, mu = pred_out_feats
        # (B, max(T)) or (B, max(T), D_out)
        mdn_mask = mask if len(pi.shape) == 4 else mask.squeeze(-1)
        with autocast(enabled=amp_enabled):
            loss = mdn_loss(pi, sigma, mu, out_feats, reduce=False)
        # Masking and averaging happen outside of autocast
        loss = loss.masked_select(mdn_mask).mean()
    elif stream_wise_loss:
        weights = get_stream_weight(stream_weights, stream_sizes).to(in_feats.device)
        target_streams = split_streams(out_feats, stream_sizes)
        pred_streams = split_streams(pred_out_feats, stream_sizes)
        with autocast(enabled=amp_enabled):
            loss = sum(
                weight * _masked_loss(pred_stream, target_stream)
                for pred_stream, target_stream, weight in zip(
                    pred_streams, target_streams, weights
                )
            )
    else:
        with autocast(enabled=amp_enabled):
            loss = _masked_loss(pred_out_feats, out_feats)

    # Point estimate used for the distortion metrics
    if is_probabilistic:
        with torch.no_grad():
            pred_for_metrics = mdn_get_most_probable_sigma_and_mu(pi, sigma, mu)[1]
    else:
        pred_for_metrics = pred_out_feats
    distortions = compute_distortions(pred_for_metrics, out_feats, lengths, out_scaler)

    if train:
        if amp_enabled:
            grad_scaler.scale(loss).backward()
            grad_scaler.step(optimizer)
            grad_scaler.update()
        else:
            loss.backward()
            optimizer.step()

    return loss, distortions
Ejemplo n.º 6
0
def train_loop(config, device, model, optimizer, lr_scheduler, data_loaders):
    """Run utterance-wise training/evaluation for ``config.train.nepochs`` epochs.

    Every key of ``data_loaders`` is treated as a phase. Phases whose name
    starts with ``"train"`` update the model; all other phases only evaluate
    it, and the best average evaluation loss triggers ``save_best_checkpoint``.

    Args:
        config: configuration object; reads ``config.model.stream_weights``,
            ``config.model.stream_sizes``, ``config.train.nepochs``,
            ``config.train.stream_wise_loss`` and
            ``config.train.checkpoint_epoch_interval``.
        device: device the batches and stream weights are moved to.
        model: network exposing ``preprocess_target``, ``prediction_type`` and
            a forward call taking ``(x, sorted_lengths)``.
        optimizer: optimizer stepped once per training batch.
        lr_scheduler: scheduler stepped once per epoch.
        data_loaders: mapping from phase name to an iterable that yields
            ``(x, y, lengths)`` batches.

    Returns:
        The ``model`` that was passed in (updated in place by training).
    """
    criterion = nn.MSELoss(reduction="none")

    logger.info("Start utterance-wise training...")

    stream_weights = get_stream_weight(
        config.model.stream_weights, config.model.stream_sizes).to(device)

    # Any finite evaluation loss must be able to become the first "best" one.
    best_loss = float("inf")
    for epoch in tqdm(range(1, config.train.nepochs + 1)):
        for phase, data_loader in data_loaders.items():
            train = phase.startswith("train")
            if train:
                model.train()
            else:
                model.eval()
            running_loss = 0
            for x, y, lengths in data_loader:
                # Sort by lengths. This is needed for pytorch's PackedSequence
                sorted_lengths, indices = torch.sort(lengths, dim=0, descending=True)
                x, y = x[indices].to(device), y[indices].to(device)

                optimizer.zero_grad()

                # Autograd bookkeeping is only needed when we back-propagate;
                # disabling it during evaluation avoids building unused graphs.
                with torch.set_grad_enabled(train):
                    # Apply preprocess if required (e.g., FIR filter for shallow AR)
                    # defaults to no-op
                    y = model.preprocess_target(y)

                    # Run forward
                    if model.prediction_type() == PredictionType.PROBABILISTIC:
                        pi, sigma, mu = model(x, sorted_lengths)

                        # (B, max(T)) or (B, max(T), D_out)
                        mask = make_non_pad_mask(sorted_lengths).to(device)
                        mask = mask.unsqueeze(-1) if len(pi.shape) == 4 else mask
                        # Compute loss and apply mask
                        loss = mdn_loss(pi, sigma, mu, y, reduce=False)
                        loss = loss.masked_select(mask).mean()
                    else:
                        y_hat = model(x, sorted_lengths)

                        # Compute loss
                        mask = make_non_pad_mask(sorted_lengths).unsqueeze(-1).to(device)

                        if config.train.stream_wise_loss:
                            # Stream-wise loss
                            streams = split_streams(y, config.model.stream_sizes)
                            streams_hat = split_streams(y_hat, config.model.stream_sizes)
                            loss = 0
                            for s_hat, s, sw in zip(streams_hat, streams, stream_weights):
                                s_hat_mask = s_hat.masked_select(mask)
                                s_mask = s.masked_select(mask)
                                loss += sw * criterion(s_hat_mask, s_mask).mean()
                        else:
                            # Joint modeling
                            y_hat = y_hat.masked_select(mask)
                            y = y.masked_select(mask)
                            loss = criterion(y_hat, y).mean()

                if train:
                    loss.backward()
                    optimizer.step()

                running_loss += loss.item()
            ave_loss = running_loss / len(data_loader)
            logger.info("[%s] [Epoch %s]: loss %s", phase, epoch, ave_loss)
            if not train and ave_loss < best_loss:
                best_loss = ave_loss
                save_best_checkpoint(config, model, optimizer, best_loss)

        # step per each epoch (may consider updating per iter.)
        lr_scheduler.step()

        if epoch % config.train.checkpoint_epoch_interval == 0:
            save_checkpoint(config, model, optimizer, lr_scheduler, epoch)

    # save at last epoch
    save_checkpoint(config, model, optimizer, lr_scheduler, config.train.nepochs)
    logger.info("The best loss was %s", best_loss)

    return model
Ejemplo n.º 7
0
def _select_discriminator_mask(netD, mask, d_out, num_frames):
    """Pick the non-padding mask that matches one discriminator output.

    Args:
        netD: discriminator; its optional ``downsample_scale`` attribute is read.
        mask: frame-level non-padding mask whose dim 1 is the time axis.
        d_out: final output tensor of one discriminator scale.
        num_frames: size of the time axis of the target features.

    Returns:
        ``mask`` subsampled by ``netD.downsample_scale`` when ``d_out`` has
        ``mask.shape[1] // netD.downsample_scale`` frames, ``mask`` itself when
        ``d_out`` has ``num_frames`` frames, and ``None`` otherwise (the caller
        then averages over every element).
    """
    if (hasattr(netD, "downsample_scale")
            and mask.shape[1] // netD.downsample_scale == d_out.shape[1]):
        return mask[:, ::netD.downsample_scale, :]
    if d_out.shape[1] == num_frames:
        return mask
    return None


def train_step(
    model_config,
    optim_config,
    netG,
    optG,
    netD,
    optD,
    grad_scaler,
    train,
    in_feats,
    out_feats,
    lengths,
    out_scaler,
    mse_weight=1.0,
    adv_weight=1.0,
    adv_streams=None,
    fm_weight=0.0,
    mask_nth_mgc_for_adv_loss=0,
    gan_type="lsgan",
    vuv_mask=False,
):
    """Run one GAN step: update the discriminator, then the generator.

    Args:
        model_config: model configuration; ``stream_sizes`` is read here and
            the object is forwarded to ``compute_distortions``.
        optim_config: optimizer configuration; ``netD.clip_norm`` and
            ``netG.clip_norm`` are used for gradient clipping.
        netG: generator, called as ``netG(in_feats, lengths)``.
        optG: optimizer for ``netG``.
        netD: discriminator, called as ``netD(feats, in_feats, lengths)``. It
            must return a list (one entry per scale) of lists whose last
            element is the discriminator output and whose preceding elements
            are the feature maps used for feature matching.
        optD: optimizer for ``netD``.
        grad_scaler: gradient scaler for mixed precision, or ``None`` to
            disable autocast and loss scaling.
        train: if true, gradients are computed and both optimizers stepped;
            otherwise the networks are put in eval mode and only losses and
            metrics are computed.
        in_feats: input features; indexed as a rank-3 tensor
            (presumably ``(B, T, D_in)`` -- confirm against caller).
        out_feats: target features; indexed as a rank-3 tensor
            (presumably ``(B, T, D_out)`` -- confirm against caller).
        lengths: per-utterance lengths used to build the non-padding mask.
        out_scaler: forwarded to ``compute_distortions``.
        mse_weight: weight of the masked MSE feature loss.
        adv_weight: weight of the adversarial loss.
        adv_streams: forwarded to ``select_streams`` to choose the streams fed
            to the discriminator.
        fm_weight: weight of the feature matching loss (skipped when ``<= 0``).
        mask_nth_mgc_for_adv_loss: if ``> 0``, that many leading feature
            dimensions are dropped from the discriminator inputs.
        gan_type: one of ``"lsgan"``, ``"vanilla-gan"`` or ``"hinge"``.
        vuv_mask: if true, discriminator inputs are multiplied by a
            voiced/unvoiced mask derived from ``out_feats`` and ``in_feats``.

    Returns:
        Tuple ``(loss, log_metrics)``: the total generator loss and a dict of
        logged values (per-scale and total losses, gradient norms when
        training, and the distortions from ``compute_distortions``).

    Raises:
        ValueError: if ``gan_type`` is not a supported value.
    """
    netG.train() if train else netG.eval()
    netD.train() if train else netD.eval()

    log_metrics = {}
    use_amp = grad_scaler is not None

    if vuv_mask:
        # NOTE: Assuming 3rd stream is the V/UV
        # NOTE(review): the same channel index is applied to in_feats --
        # confirm that in_feats shares this layout.
        vuv_idx = np.sum(model_config.stream_sizes[:2])
        vuv = torch.logical_and(
            out_feats[:, :, vuv_idx:vuv_idx + 1] > 0,
            in_feats[:, :, vuv_idx:vuv_idx + 1] > 0,
        )
    else:
        vuv = 1.0

    # Run forward
    with autocast(enabled=use_amp):
        pred_out_feats = netG(in_feats, lengths)

    real_netD_in_feats = select_streams(out_feats, model_config.stream_sizes,
                                        adv_streams)
    fake_netD_in_feats = select_streams(
        pred_out_feats,
        model_config.stream_sizes,
        adv_streams,
    )

    # Ref: http://sython.org/papers/ASJ/saito2017asja.pdf
    # 0-th mgc with adversarial training affects speech quality
    # NOTE: assuming that the first stream contains mgc
    if mask_nth_mgc_for_adv_loss > 0:
        real_netD_in_feats = real_netD_in_feats[:, :,
                                                mask_nth_mgc_for_adv_loss:]
        fake_netD_in_feats = fake_netD_in_feats[:, :,
                                                mask_nth_mgc_for_adv_loss:]

    with autocast(enabled=use_amp):
        # Real
        D_real = netD(real_netD_in_feats * vuv, in_feats, lengths)
        # NOTE: must be list of list to support multi-scale discriminators
        assert isinstance(D_real, list) and isinstance(D_real[-1], list)
        # Fake; detached so the discriminator loss does not reach netG
        D_fake_det = netD(fake_netD_in_feats.detach() * vuv, in_feats, lengths)

    # Mask (B, T, 1)
    mask = make_non_pad_mask(lengths).unsqueeze(-1).to(in_feats.device)

    # Update discriminator
    eps = 1e-14  # keeps the log() of the vanilla GAN losses finite
    loss_real = 0
    loss_fake = 0

    with autocast(enabled=use_amp):
        for idx, (D_real_, D_fake_det_) in enumerate(zip(D_real, D_fake_det)):
            if gan_type == "lsgan":
                loss_real_ = (D_real_[-1] - 1)**2
                loss_fake_ = D_fake_det_[-1]**2
            elif gan_type == "vanilla-gan":
                loss_real_ = -torch.log(D_real_[-1] + eps)
                loss_fake_ = -torch.log(1 - D_fake_det_[-1] + eps)
            elif gan_type == "hinge":
                loss_real_ = F.relu(1 - D_real_[-1])
                loss_fake_ = F.relu(1 + D_fake_det_[-1])
            else:
                raise ValueError(f"Unknown gan type: {gan_type}")

            # mask for D
            D_mask = _select_discriminator_mask(
                netD, mask, D_real_[-1], out_feats.shape[1])

            if D_mask is not None:
                loss_real_ = loss_real_.masked_select(D_mask).mean()
                loss_fake_ = loss_fake_.masked_select(D_mask).mean()
            else:
                loss_real_ = loss_real_.mean()
                loss_fake_ = loss_fake_.mean()

            log_metrics[f"Loss_Real_Scale{idx}"] = loss_real_.item()
            log_metrics[f"Loss_Fake_Scale{idx}"] = loss_fake_.item()

            loss_real += loss_real_
            loss_fake += loss_fake_

        loss_d = loss_real + loss_fake

    if train:
        optD.zero_grad()
        if grad_scaler is not None:
            grad_scaler.scale(loss_d).backward()
            # Gradients must be unscaled before their norm is clipped.
            grad_scaler.unscale_(optD)
            grad_norm_d = torch.nn.utils.clip_grad_norm_(
                netD.parameters(), optim_config.netD.clip_norm)
            log_metrics["GradNorm_D"] = grad_norm_d
            grad_scaler.step(optD)
        else:
            loss_d.backward()
            grad_norm_d = torch.nn.utils.clip_grad_norm_(
                netD.parameters(), optim_config.netD.clip_norm)
            log_metrics["GradNorm_D"] = grad_norm_d
            optD.step()

    # MSE loss
    with autocast(enabled=use_amp):
        loss_feats = nn.MSELoss(reduction="none")(
            pred_out_feats.masked_select(mask),
            out_feats.masked_select(mask)).mean()

    # adversarial loss
    with autocast(enabled=use_amp):
        # Not detached this time: gradients must flow back into netG.
        D_fake = netD(fake_netD_in_feats * vuv, in_feats, lengths)

        loss_adv = 0
        for idx, D_fake_ in enumerate(D_fake):
            if gan_type == "lsgan":
                loss_adv_ = (1 - D_fake_[-1])**2
            elif gan_type == "vanilla-gan":
                loss_adv_ = -torch.log(D_fake_[-1] + eps)
            elif gan_type == "hinge":
                loss_adv_ = -D_fake_[-1]
            else:
                raise ValueError(f"Unknown gan type: {gan_type}")

            # The mask is chosen from the output of the scale being processed,
            # not from a variable left over from the discriminator loop.
            D_mask = _select_discriminator_mask(
                netD, mask, D_fake_[-1], out_feats.shape[1])

            if D_mask is not None:
                loss_adv_ = loss_adv_.masked_select(D_mask).mean()
            else:
                loss_adv_ = loss_adv_.mean()

            log_metrics[f"Loss_Adv_Scale{idx}"] = loss_adv_.item()

            loss_adv += loss_adv_

        # Feature matching loss
        loss_fm = torch.tensor(0.0).to(in_feats.device)
        if fm_weight > 0:
            for D_fake_, D_real_ in zip(D_fake, D_real):
                # All but the last entry of each scale are feature maps.
                for fake_fmap, real_fmap in zip(D_fake_[:-1], D_real_[:-1]):
                    loss_fm += F.l1_loss(fake_fmap, real_fmap.detach())

        loss = mse_weight * loss_feats + adv_weight * loss_adv + fm_weight * loss_fm

    if train:
        optG.zero_grad()
        if grad_scaler is not None:
            grad_scaler.scale(loss).backward()
            grad_scaler.unscale_(optG)
            grad_norm_g = torch.nn.utils.clip_grad_norm_(
                netG.parameters(), optim_config.netG.clip_norm)
            log_metrics["GradNorm_G"] = grad_norm_g
            grad_scaler.step(optG)
        else:
            loss.backward()
            grad_norm_g = torch.nn.utils.clip_grad_norm_(
                netG.parameters(), optim_config.netG.clip_norm)
            log_metrics["GradNorm_G"] = grad_norm_g
            optG.step()

    # NOTE: this shouldn't be called multiple times in a training step
    if train and grad_scaler is not None:
        grad_scaler.update()

    # Metrics
    distortions = compute_distortions(pred_out_feats, out_feats, lengths,
                                      out_scaler, model_config)
    log_metrics.update(distortions)
    log_metrics.update({
        "Loss": loss.item(),
        "Loss_Feats": loss_feats.item(),
        "Loss_Adv_Total": loss_adv.item(),
        "Loss_Feature_Matching": loss_fm.item(),
        "Loss_Real_Total": loss_real.item(),
        "Loss_Fake_Total": loss_fake.item(),
        "Loss_D": loss_d.item(),
    })

    return loss, log_metrics