Example #1
def train_loop(
    config,
    logger,
    device,
    netG_A2B,
    netG_B2A,
    optG,
    schedulerG,
    netD_A,
    netD_B,
    optD,
    schedulerD,
    grad_scaler,
    data_loaders,
    writer,
    out_scaler,
    use_mlflow,
):
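    """CycleGAN-style training loop for the A2B/B2A generator pair.

    For each epoch, every phase in ``data_loaders`` is iterated; ``train_step``
    computes the losses (and updates the generators and discriminators in the
    training phase), the averaged metrics are logged to ``writer``/MLflow, and
    checkpoints are saved for all four networks (best on dev loss, at regular
    intervals, and at the end of training).

    Returns the dev loss of the last epoch.
    """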
    out_dir = Path(to_absolute_path(config.train.out_dir))
    best_dev_loss = torch.finfo(torch.float32).max
    last_dev_loss = torch.finfo(torch.float32).max

    adv_streams = config.train.adv_streams
    if len(adv_streams) != len(config.model.stream_sizes):
        raise ValueError("adv_streams must be specified for all streams")

    for epoch in tqdm(range(1, config.train.nepochs + 1)):
        for phase in data_loaders.keys():
            train = phase.startswith("train")
            running_loss = 0
            running_metrics = {}
            evaluated = False
            for in_feats, out_feats, lengths in data_loaders[phase]:
                # NOTE: This is needed for pytorch's PackedSequence
                lengths, indices = torch.sort(lengths, dim=0, descending=True)
                in_feats, out_feats = (
                    in_feats[indices].to(device),
                    out_feats[indices].to(device),
                )

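                # Run evaluation on the first dev batch of each epoch only
                # and log the results via the writer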
                if (not train) and (not evaluated):
                    eval_spss_model(
                        epoch,
                        netG_A2B,
                        in_feats,
                        out_feats,
                        lengths,
                        config.model,
                        out_scaler,
                        writer,
                        sr=config.data.sample_rate,
                    )
                    evaluated = True

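                # Use the identity-mapping loss only up to id_loss_until epochs
                # (always enabled if the limit is <= 0)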
                if (
                    config.train.id_loss_until > 0
                    and epoch > config.train.id_loss_until
                ):
                    use_id_loss = False
                else:
                    use_id_loss = True

                loss, log_metrics = train_step(
                    model_config=config.model,
                    optim_config=config.train.optim,
                    netG_A2B=netG_A2B,
                    netG_B2A=netG_B2A,
                    optG=optG,
                    netD_A=netD_A,
                    netD_B=netD_B,
                    optD=optD,
                    grad_scaler=grad_scaler,
                    train=train,
                    in_feats=in_feats,
                    out_feats=out_feats,
                    lengths=lengths,
                    out_scaler=out_scaler,
                    adv_weight=config.train.adv_weight,
                    adv_streams=adv_streams,
                    fm_weight=config.train.fm_weight,
                    mask_nth_mgc_for_adv_loss=config.train.mask_nth_mgc_for_adv_loss,
                    gan_type=config.train.gan_type,
                    vuv_mask=config.train.vuv_mask,
                    cycle_weight=config.train.cycle_weight,
                    id_weight=config.train.id_weight,
                    use_id_loss=use_id_loss,
                )
                running_loss += loss.item()
                for k, v in log_metrics.items():
                    try:
                        running_metrics[k] += float(v)
                    except KeyError:
                        running_metrics[k] = float(v)

            ave_loss = running_loss / len(data_loaders[phase])
            logger.info("[%s] [Epoch %s]: loss %s", phase, epoch, ave_loss)

            for k, v in running_metrics.items():
                ave_v = v / len(data_loaders[phase])
                if writer is not None:
                    writer.add_scalar(f"{k}/{phase}", ave_v, epoch)
                if use_mlflow:
                    mlflow.log_metric(f"{phase}_{k}", ave_v, step=epoch)

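            # Track the dev loss and save the best checkpoints for all four networks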
            if not train:
                last_dev_loss = ave_loss
            if not train and ave_loss < best_dev_loss:
                best_dev_loss = ave_loss
                for model, opt, scheduler, postfix in [
                    (netG_A2B, optG, schedulerG, "_A2B"),
                    (netG_B2A, optG, schedulerG, "_B2A"),
                    (netD_A, optD, schedulerD, "_D_A"),
                    (netD_B, optD, schedulerD, "_D_B"),
                ]:
                    save_checkpoint(
                        logger,
                        out_dir,
                        model,
                        opt,
                        scheduler,
                        epoch,
                        is_best=True,
                        postfix=postfix,
                    )

        schedulerG.step()
        schedulerD.step()

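        # Periodic checkpointing for all four networks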
        if epoch % config.train.checkpoint_epoch_interval == 0:
            for model, opt, scheduler, postfix in [
                (netG_A2B, optG, schedulerG, "_A2B"),
                (netG_B2A, optG, schedulerG, "_B2A"),
                (netD_A, optD, schedulerD, "_D_A"),
                (netD_B, optD, schedulerD, "_D_B"),
            ]:
                save_checkpoint(
                    logger,
                    out_dir,
                    model,
                    opt,
                    scheduler,
                    epoch,
                    is_best=False,
                    postfix=postfix,
                )

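    # Save final checkpoints for all four networks at the end of training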
    for model, opt, scheduler, postfix in [
        (netG_A2B, optG, schedulerG, "_A2B"),
        (netG_B2A, optG, schedulerG, "_B2A"),
        (netD_A, optD, schedulerD, "_D_A"),
        (netD_B, optD, schedulerD, "_D_B"),
    ]:
        save_checkpoint(
            logger,
            out_dir,
            model,
            opt,
            scheduler,
            config.train.nepochs,
            postfix=postfix,
        )
    logger.info("The best loss was %s", best_dev_loss)
    if use_mlflow:
        mlflow.log_metric("best_dev_loss", best_dev_loss, step=epoch)
        mlflow.log_artifacts(out_dir)

    return last_dev_loss
Example #2
def train_loop(
    config,
    logger,
    device,
    netG,
    optG,
    schedulerG,
    netD,
    optD,
    schedulerD,
    grad_scaler,
    data_loaders,
    writer,
    in_scaler,
    out_scaler,
    use_mlflow,
):
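    """GAN-based acoustic model training loop with pitch regularization.

    For each epoch, every phase in ``data_loaders`` is iterated; a time-variant
    pitch regularization weight is computed from the denormalized score log-F0,
    ``train_step`` computes the losses (and updates the networks in the training
    phase), the averaged metrics are logged to ``writer``/MLflow, and generator
    and discriminator checkpoints are saved (best on dev loss, periodic, final).

    Returns the dev loss of the last epoch.
    """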
    out_dir = Path(to_absolute_path(config.train.out_dir))
    best_dev_loss = torch.finfo(torch.float32).max
    last_dev_loss = torch.finfo(torch.float32).max

    in_lf0_idx = config.data.in_lf0_idx
    in_rest_idx = config.data.in_rest_idx
    if in_lf0_idx is None or in_rest_idx is None:
        raise ValueError("in_lf0_idx and in_rest_idx must be specified")
    pitch_reg_weight = config.train.pitch_reg_weight
    fm_weight = config.train.fm_weight
    adv_streams = config.train.adv_streams
    if len(adv_streams) != len(config.model.stream_sizes):
        raise ValueError("adv_streams must be specified for all streams")

    if "sample_rate" not in config.data:
        logger.warning(
            "sample_rate is not found in the data config. Fallback to 48000."
        )
        sr = 48000
    else:
        sr = config.data.sample_rate

    if "feats_criterion" not in config.train:
        logger.warning(
            "feats_criterion is not found in the data config. Fallback to MSE."
        )
        feats_criterion = "mse"
    else:
        feats_criterion = config.train.feats_criterion

    for epoch in tqdm(range(1, config.train.nepochs + 1)):
        for phase in data_loaders.keys():
            train = phase.startswith("train")
            running_loss = 0
            running_metrics = {}
            evaluated = False
            for in_feats, out_feats, lengths in data_loaders[phase]:
                # NOTE: This is needed for pytorch's PackedSequence
                lengths, indices = torch.sort(lengths, dim=0, descending=True)
                in_feats, out_feats = (
                    in_feats[indices].to(device),
                    out_feats[indices].to(device),
                )
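                # Run evaluation on the first dev batch of each epoch only
                # and log the results via the writer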
                if (not train) and (not evaluated):
                    eval_spss_model(
                        epoch,
                        netG,
                        in_feats,
                        out_feats,
                        lengths,
                        config.model,
                        out_scaler,
                        writer,
                        sr=sr,
                    )
                    evaluated = True

                # Compute denormalized log-F0 in the musical scores
                lf0_score_denorm = (
                    in_feats[:, :, in_lf0_idx] - in_scaler.min_[in_lf0_idx]
                ) / in_scaler.scale_[in_lf0_idx]
                # Fill zeros for rest and padded frames
                lf0_score_denorm *= (in_feats[:, :, in_rest_idx] <= 0).float()
                for idx, length in enumerate(lengths):
                    lf0_score_denorm[idx, length:] = 0
                # Compute time-variant pitch regularization weight vector
                pitch_reg_dyn_ws = compute_batch_pitch_regularization_weight(
                    lf0_score_denorm
                )

                loss, log_metrics = train_step(
                    model_config=config.model,
                    optim_config=config.train.optim,
                    netG=netG,
                    optG=optG,
                    netD=netD,
                    optD=optD,
                    grad_scaler=grad_scaler,
                    train=train,
                    in_feats=in_feats,
                    out_feats=out_feats,
                    lengths=lengths,
                    out_scaler=out_scaler,
                    feats_criterion=feats_criterion,
                    pitch_reg_dyn_ws=pitch_reg_dyn_ws,
                    pitch_reg_weight=pitch_reg_weight,
                    adv_weight=config.train.adv_weight,
                    adv_streams=adv_streams,
                    fm_weight=fm_weight,
                    adv_use_static_feats_only=config.train.adv_use_static_feats_only,
                    mask_nth_mgc_for_adv_loss=config.train.mask_nth_mgc_for_adv_loss,
                    gan_type=config.train.gan_type,
                )
                running_loss += loss.item()
                for k, v in log_metrics.items():
                    try:
                        running_metrics[k] += float(v)
                    except KeyError:
                        running_metrics[k] = float(v)

            ave_loss = running_loss / len(data_loaders[phase])
            logger.info("[%s] [Epoch %s]: loss %s", phase, epoch, ave_loss)

            for k, v in running_metrics.items():
                ave_v = v / len(data_loaders[phase])
                if writer is not None:
                    writer.add_scalar(f"{k}/{phase}", ave_v, epoch)
                if use_mlflow:
                    mlflow.log_metric(f"{phase}_{k}", ave_v, step=epoch)

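            # Track the dev loss and save the best generator/discriminator checkpoints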
            if not train:
                last_dev_loss = ave_loss
            if not train and ave_loss < best_dev_loss:
                best_dev_loss = ave_loss
                for model, opt, scheduler, postfix in [
                    (netG, optG, schedulerG, ""),
                    (netD, optD, schedulerD, "_D"),
                ]:
                    save_checkpoint(
                        logger,
                        out_dir,
                        model,
                        opt,
                        scheduler,
                        epoch,
                        is_best=True,
                        postfix=postfix,
                    )

        schedulerG.step()
        schedulerD.step()

        if epoch % config.train.checkpoint_epoch_interval == 0:
            for model, opt, scheduler, postfix in [
                (netG, optG, schedulerG, ""),
                (netD, optD, schedulerD, "_D"),
            ]:
                save_checkpoint(
                    logger,
                    out_dir,
                    model,
                    opt,
                    scheduler,
                    epoch,
                    is_best=False,
                    postfix=postfix,
                )

    for model, opt, scheduler, postfix in [
        (netG, optG, schedulerG, ""),
        (netD, optD, schedulerD, "_D"),
    ]:
        save_checkpoint(
            logger,
            out_dir,
            model,
            opt,
            scheduler,
            config.train.nepochs,
            postfix=postfix,
        )
    logger.info("The best loss was %s", best_dev_loss)
    if use_mlflow:
        mlflow.log_metric("best_dev_loss", best_dev_loss, step=epoch)
        mlflow.log_artifacts(out_dir)

    return last_dev_loss
Example #3
def train_loop(
    config,
    logger,
    device,
    model,
    optimizer,
    lr_scheduler,
    grad_scaler,
    data_loaders,
    writer,
    out_scaler,
    use_mlflow,
):
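    """Training loop for a single (non-adversarial) acoustic model.

    For each epoch, every phase in ``data_loaders`` is iterated; ``train_step``
    computes the loss (and updates the model in the training phase), the
    averaged loss and distortion metrics are logged to ``writer``/MLflow, and
    checkpoints are saved (best on dev loss, periodic, final).

    Returns the dev loss of the last epoch.
    """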
    out_dir = Path(to_absolute_path(config.train.out_dir))
    best_dev_loss = torch.finfo(torch.float32).max
    last_dev_loss = torch.finfo(torch.float32).max

    if "feats_criterion" not in config.train:
        logger.warning(
            "feats_criterion is not found in the train config. Fallback to MSE."
        )
        feats_criterion = "mse"
    else:
        feats_criterion = config.train.feats_criterion

    for epoch in tqdm(range(1, config.train.nepochs + 1)):
        for phase in data_loaders.keys():
            train = phase.startswith("train")
            running_loss = 0
            running_metrics = {}
            for in_feats, out_feats, lengths in data_loaders[phase]:
                # NOTE: This is needed for pytorch's PackedSequence
                lengths, indices = torch.sort(lengths, dim=0, descending=True)
                in_feats, out_feats = (
                    in_feats[indices].to(device),
                    out_feats[indices].to(device),
                )
                loss, distortions = train_step(
                    model=model,
                    optimizer=optimizer,
                    grad_scaler=grad_scaler,
                    train=train,
                    in_feats=in_feats,
                    out_feats=out_feats,
                    lengths=lengths,
                    out_scaler=out_scaler,
                    feats_criterion=feats_criterion,
                    stream_wise_loss=config.train.stream_wise_loss,
                    stream_weights=config.model.stream_weights,
                    stream_sizes=config.model.stream_sizes,
                )
                running_loss += loss.item()
                for k, v in distortions.items():
                    try:
                        running_metrics[k] += float(v)
                    except KeyError:
                        running_metrics[k] = float(v)

            ave_loss = running_loss / len(data_loaders[phase])
            if writer is not None:
                writer.add_scalar(f"Loss/{phase}", ave_loss, epoch)
            if use_mlflow:
                mlflow.log_metric(f"{phase}_loss", ave_loss, step=epoch)

            logger.info("[%s] [Epoch %s]: loss %s", phase, epoch, ave_loss)

            for k, v in running_metrics.items():
                ave_v = v / len(data_loaders[phase])
                if writer is not None:
                    writer.add_scalar(f"{k}/{phase}", ave_v, epoch)
                if use_mlflow:
                    mlflow.log_metric(f"{phase}_{k}", ave_v, step=epoch)

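            # Track the dev loss and save the best model checkpoint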
            if not train:
                last_dev_loss = ave_loss
            if not train and ave_loss < best_dev_loss:
                best_dev_loss = ave_loss
                save_checkpoint(logger,
                                out_dir,
                                model,
                                optimizer,
                                lr_scheduler,
                                epoch,
                                is_best=True)

        lr_scheduler.step()
        if epoch % config.train.checkpoint_epoch_interval == 0:
            save_checkpoint(logger,
                            out_dir,
                            model,
                            optimizer,
                            lr_scheduler,
                            epoch,
                            is_best=False)

    save_checkpoint(logger, out_dir, model, optimizer, lr_scheduler,
                    config.train.nepochs)
    logger.info("The best loss was %s", best_dev_loss)
    if use_mlflow:
        mlflow.log_metric("best_dev_loss", best_dev_loss, step=epoch)
        mlflow.log_artifacts(out_dir)

    return last_dev_loss