Example No. 1
def test_store_restore(decay, use_num_updates, explicit_params):
    model = torch.nn.Linear(10, 2)
    ema = ExponentialMovingAverage(
        model.parameters(),
        decay=decay,
        use_num_updates=use_num_updates
    )
    orig_weight = model.weight.clone().detach()
    if explicit_params:
        ema.store(model.parameters())
    else:
        ema.store()
    with torch.no_grad():
        model.weight.uniform_(0.0, 1.0)
    if explicit_params:
        ema.restore(model.parameters())
    else:
        ema.restore()
    assert torch.all(model.weight == orig_weight)
Example No. 2
    def __init__(
            self,
            steps,
            epochs,
            data_loader,
            sampler,
            model,
            criterion,
            optimizer,
            scheduler,
            config,
            device=torch.device("cpu"),
    ):
        """Initialize trainer.

        Args:
            steps (int): Initial global steps.
            epochs (int): Initial global epochs.
            data_loader (dict): Dict of data loaders. It must contain "train" and "dev" loaders.
            sampler (dict): Dict of samplers. It must contain the "train" sampler.
            model (dict): Dict of models. It must contain "generator" and "discriminator" models.
            criterion (dict): Dict of criterions. It must contain "stft" and "mse" criterions.
            optimizer (dict): Dict of optimizers. It must contain "generator" and "discriminator" optimizers.
            scheduler (dict): Dict of schedulers. It must contain "generator" and "discriminator" schedulers.
            config (dict): Config dict loaded from yaml format configuration file.
            device (torch.device): PyTorch device instance.

        """
        self.steps = steps
        self.epochs = epochs
        self.data_loader = data_loader
        self.sampler = sampler
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.config = config
        self.device = device
        self.writer = SummaryWriter(config["outdir"])
        self.finish_train = False
        self.total_train_loss = defaultdict(float)
        self.total_eval_loss = defaultdict(float)
        self.ema = ExponentialMovingAverage(model["generator"].parameters(),
                                            decay=0.995)
Example No. 3
def test_update(decay, explicit_params):
    model = torch.nn.Linear(10, 2, bias=False)
    with torch.no_grad():
        model.weight.fill_(0.0)
    ema = ExponentialMovingAverage(
        model.parameters(),
        decay=decay,
        use_num_updates=False
    )
    with torch.no_grad():
        model.weight.fill_(1.0)
    if explicit_params:
        ema.update(model.parameters())
    else:
        ema.update()
    assert torch.all(model.weight == 1.0), "ema.update changed model weights"
    if explicit_params:
        ema.copy_to(model.parameters())
    else:
        ema.copy_to()
    assert torch.allclose(
        model.weight,
        torch.full(size=(1,), fill_value=(1.0 - decay))
    ), "average was wrong"
Example No. 4
def test_explicit_params():
    model = torch.nn.Linear(10, 2)
    with torch.no_grad():
        model.weight.fill_(0.0)
    ema = ExponentialMovingAverage(model.parameters(), decay=0.9)
    model2 = torch.nn.Linear(10, 2)
    with torch.no_grad():
        model2.weight.fill_(1.0)
    ema.update(model2.parameters())
    ema.copy_to()
    assert not torch.all(model.weight == 0.0)
Example No. 5
    }
    num_workers = hyperparams["dataloaders"]["num_workers"]

    if not amp: scaler = None

    utils.seed_everything()

    # create model, loaders, optimizer, etc
    transforms = utils.build_transforms(second_stage=(stage == 'second'))
    loaders = utils.build_loaders(data_dir, transforms, batch_sizes, num_workers, second_stage=(stage == 'second'))
    model = utils.build_model(backbone, second_stage=(stage == 'second'), num_classes=num_classes, ckpt_pretrained=ckpt_pretrained).cuda()

    if ema:
        iters = len(loaders['train_features_loader'])
        ema_decay = ema_decay_per_epoch**(1/iters)
        ema = ExponentialMovingAverage(model.parameters(), decay=ema_decay)

    optim = utils.build_optim(model, optimizer_params, scheduler_params, criterion_params)
    criterion, optimizer, scheduler = (
        optim["criterion"],
        optim["optimizer"],
        optim["scheduler"],
    )

    # handle logging (regular logs, tensorboard, and weights)
    if logging_name is None:
        logging_name = "stage_{}_model_{}_dataset_{}".format(stage, backbone, data_dir.split("/")[-1])

    shutil.rmtree("weights/{}".format(logging_name), ignore_errors=True)
    shutil.rmtree(
        "runs/{}".format(logging_name),
Example No. 6
def run_single_nn(cfg,
                  train,
                  test,
                  folds,
                  num_features,
                  cat_features,
                  target,
                  device,
                  logger,
                  fold_num=0,
                  seed=7):

    # Set seed
    logger.info(f'Set seed {seed}')
    seed_everything(seed=seed)

    # loader
    trn_idx = folds[folds['fold'] != fold_num].index
    val_idx = folds[folds['fold'] == fold_num].index
    train_folds = train.loc[trn_idx].reset_index(drop=True)
    valid_folds = train.loc[val_idx].reset_index(drop=True)
    train_target = target[trn_idx]
    valid_target = target[val_idx]
    train_dataset = TrainDataset(train_folds, num_features, cat_features,
                                 train_target)
    valid_dataset = TrainDataset(valid_folds, num_features, cat_features,
                                 valid_target)
    train_loader = DataLoader(train_dataset,
                              batch_size=cfg.batch_size,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True,
                              drop_last=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=cfg.batch_size,
                              shuffle=False,
                              num_workers=4,
                              pin_memory=True,
                              drop_last=False)

    # model
    if cfg.ex_name == "baseline":
        model = TabularNN(cfg)
    if "add_cate_x" in cfg.ex_name:
        model = TabularNNV2(cfg)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=cfg.lr,
                                 weight_decay=cfg.weight_decay)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        pct_start=0.1,
        div_factor=1e3,
        max_lr=1e-2,
        epochs=cfg.epochs,
        steps_per_epoch=len(train_loader))
    if "ema" in cfg.ex_name:
        ema = ExponentialMovingAverage(model.parameters(), decay=0.995)
    else:
        ema = None

    # log
    log_df = pd.DataFrame(columns=(['EPOCH'] + ['TRAIN_LOSS'] +
                                   ['VALID_LOSS']))

    # train & validate
    best_loss = np.inf
    for epoch in range(cfg.epochs):
        train_loss = train_fn(train_loader, model, optimizer, scheduler,
                              device, ema)
        valid_loss, val_preds = validate_fn(valid_loader, model, device)
        log_row = {
            'EPOCH': epoch,
            'TRAIN_LOSS': train_loss,
            'VALID_LOSS': valid_loss,
        }
        log_df = pd.concat([log_df, pd.DataFrame(log_row, index=[0])],
                           sort=False)
        # logger.info(log_df.tail(1))
        if valid_loss < best_loss:
            logger.info(f'epoch{epoch} save best model... {valid_loss}')
            best_loss = valid_loss
            oof = np.zeros((len(train), len(cfg.target_cols)))
            oof[val_idx] = val_preds
            if ema is not None:
                ema.copy_to(model.parameters())
            torch.save(
                model.state_dict(),
                os.path.join(cfg.ex_name, f"fold{fold_num}_seed{seed}.pth"))

    # predictions
    test_dataset = TestDataset(test, num_features, cat_features)
    test_loader = DataLoader(test_dataset,
                             batch_size=cfg.batch_size,
                             shuffle=False,
                             num_workers=4,
                             pin_memory=True)
    if cfg.ex_name == "baseline":
        model = TabularNN(cfg)
    if "add_cate_x" in cfg.ex_name:
        model = TabularNNV2(cfg)
    model.load_state_dict(
        torch.load(os.path.join(cfg.ex_name,
                                f"fold{fold_num}_seed{seed}.pth")))
    model.to(device)
    predictions = inference_fn(test_loader, model, device)

    # del
    torch.cuda.empty_cache()

    return oof, predictions
Example No. 7
        'drop_last': False
    }
    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   **dataloader_extras)
    valid_dataloader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=max(1, int(args.batch_size / 3)),
        **dataloader_extras)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.RMSprop(model.parameters(),
                                    lr=args.learning_rate,
                                    weight_decay=1e-5,
                                    momentum=0.9)
    ema = ExponentialMovingAverage(model.parameters(), decay=0.9999)

    total_epochs = args.epochs
    run_name = args.run_name if args.run_name is not None else 'MobileNetV3-Large-LR%.6f' % (
        args.learning_rate, )
    run_name = run_name + datetime.datetime.now().strftime(
        '-%Y-%m-%d-%H-%M-%S')

    print("학습 시작 (run_name: %s)" % run_name)

    summary_writer = torch.utils.tensorboard.SummaryWriter(
        log_dir=os.path.join('logs', run_name))

    if args.continue_weight:
        if not os.path.isfile(args.continue_weight):
            print("Weight file %s not found!" % args.continue_weight)
Example No. 8
class Trainer(object):
    """Customized trainer module for Parallel WaveGAN training."""
    def __init__(
            self,
            steps,
            epochs,
            data_loader,
            sampler,
            model,
            criterion,
            optimizer,
            scheduler,
            config,
            device=torch.device("cpu"),
    ):
        """Initialize trainer.

        Args:
            steps (int): Initial global steps.
            epochs (int): Initial global epochs.
            data_loader (dict): Dict of data loaders. It must contain "train" and "dev" loaders.
            sampler (dict): Dict of samplers. It must contain the "train" sampler.
            model (dict): Dict of models. It must contain "generator" and "discriminator" models.
            criterion (dict): Dict of criterions. It must contain "stft" and "mse" criterions.
            optimizer (dict): Dict of optimizers. It must contain "generator" and "discriminator" optimizers.
            scheduler (dict): Dict of schedulers. It must contain "generator" and "discriminator" schedulers.
            config (dict): Config dict loaded from yaml format configuration file.
            device (torch.device): PyTorch device instance.

        """
        self.steps = steps
        self.epochs = epochs
        self.data_loader = data_loader
        self.sampler = sampler
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.config = config
        self.device = device
        self.writer = SummaryWriter(config["outdir"])
        self.finish_train = False
        self.total_train_loss = defaultdict(float)
        self.total_eval_loss = defaultdict(float)
        self.ema = ExponentialMovingAverage(model["generator"].parameters(),
                                            decay=0.995)

    def run(self):
        """Run training."""
        self.tqdm = tqdm(initial=self.steps,
                         total=self.config["train_max_steps"],
                         desc="[train]")
        while True:
            # train one epoch
            self._train_epoch()

            # check whether training is finished
            if self.finish_train:
                break

        self.tqdm.close()
        logging.info("Finished training.")

    def save_checkpoint(self, checkpoint_path):
        """Save checkpoint.

        Args:
            checkpoint_path (str): Checkpoint path to be saved.

        """
        state_dict = {
            "optimizer": {
                "generator": self.optimizer["generator"].state_dict(),
                "discriminator": self.optimizer["discriminator"].state_dict(),
            },
            "scheduler": {
                "generator": self.scheduler["generator"].state_dict(),
                "discriminator": self.scheduler["discriminator"].state_dict(),
            },
            "steps": self.steps,
            "epochs": self.epochs,
        }
        if self.config["distributed"]:
            state_dict["model"] = {
                "generator": self.model["generator"].module.state_dict(),
                "discriminator":
                self.model["discriminator"].module.state_dict(),
            }
        else:
            state_dict["model"] = {
                "generator": self.model["generator"].state_dict(),
                "discriminator": self.model["discriminator"].state_dict(),
            }

        # save the EMA-averaged generator weights (not the optimizer state)
        with self.ema.average_parameters():
            generator = self.model["generator"].module if self.config[
                "distributed"] else self.model["generator"]
            state_dict["model"]["generator_ema"] = generator.state_dict()

        if not os.path.exists(os.path.dirname(checkpoint_path)):
            os.makedirs(os.path.dirname(checkpoint_path))
        torch.save(state_dict, checkpoint_path)

    def load_checkpoint(self, checkpoint_path, load_only_params=False):
        """Load checkpoint.

        Args:
            checkpoint_path (str): Checkpoint path to be loaded.
            load_only_params (bool): Whether to load only model parameters.

        """
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        if self.config["distributed"]:
            self.model["generator"].module.load_state_dict(
                state_dict["model"]["generator"])
            self.model["discriminator"].module.load_state_dict(
                state_dict["model"]["discriminator"])
        else:
            self.model["generator"].load_state_dict(
                state_dict["model"]["generator"])
            self.model["discriminator"].load_state_dict(
                state_dict["model"]["discriminator"])
        if not load_only_params:
            self.steps = state_dict["steps"]
            self.epochs = state_dict["epochs"]
            self.optimizer["generator"].load_state_dict(
                state_dict["optimizer"]["generator"])
            self.optimizer["discriminator"].load_state_dict(
                state_dict["optimizer"]["discriminator"])
            self.scheduler["generator"].load_state_dict(
                state_dict["scheduler"]["generator"])
            self.scheduler["discriminator"].load_state_dict(
                state_dict["scheduler"]["discriminator"])

    def _train_step(self, batch):
        """Train model one step."""
        # parse batch
        x, y = batch
        x = tuple([x_.to(self.device) for x_ in x])
        y = y.to(self.device)

        #######################
        #      Generator      #
        #######################
        y_ = self.model["generator"](*x)

        # reconstruct the signal from multi-band signal
        if self.config["generator_params"]["out_channels"] > 1:
            y_mb_ = y_
            y_ = self.criterion["pqmf"].synthesis(y_mb_)

        gen_loss = 0.0

        # multi-resolution stft loss
        if self.config.get("use_stft_loss", True):
            sc_loss, mag_loss = self.criterion["stft"](y_.squeeze(1),
                                                       y.squeeze(1))
            self.total_train_loss[
                "train/spectral_convergence_loss"] += sc_loss.item()
            self.total_train_loss[
                "train/log_stft_magnitude_loss"] += mag_loss.item()
            gen_loss = gen_loss + sc_loss + mag_loss

        # subband multi-resolution stft loss
        if self.config.get("use_subband_stft_loss", False):
            gen_loss *= 0.5  # for balancing with subband stft loss
            y_mb = self.criterion["pqmf"].analysis(y)
            y_mb = y_mb.view(-1, y_mb.size(2))  # (B, C, T) -> (B x C, T)
            y_mb_ = y_mb_.view(-1, y_mb_.size(2))  # (B, C, T) -> (B x C, T)
            sub_sc_loss, sub_mag_loss = self.criterion["sub_stft"](y_mb_, y_mb)
            self.total_train_loss[
                "train/sub_spectral_convergence_loss"] += sub_sc_loss.item()
            self.total_train_loss[
                "train/sub_log_stft_magnitude_loss"] += sub_mag_loss.item()
            gen_loss = gen_loss + 0.5 * (sub_sc_loss + sub_mag_loss)

        if self.config.get("use_feat_match_loss_wav2vec", False):
            wav2vec_loss = self.criterion["l1"](self.criterion["wav2vec"](
                y_.squeeze(1)), self.criterion["wav2vec"](y.squeeze(1)))
            gen_loss = gen_loss + wav2vec_loss
            self.total_train_loss["train/wav2vec_loss"] += wav2vec_loss.item()

        # adversarial loss
        if self.steps > self.config["discriminator_train_start_steps"]:
            p = self.model["discriminator"](y)
            p_ = self.model["discriminator"](y_)
            if not isinstance(p_, list):
                # for standard discriminator
                adv_loss = relativistic_loss(p.detach(), p_, -1)
                self.total_train_loss[
                    "train/adversarial_loss"] += adv_loss.item()
            else:
                # for multi-scale discriminator
                adv_loss = 0.0
                for i in range(len(p_)):
                    adv_loss = adv_loss + relativistic_loss(
                        p[i][-1].detach(), p_[i][-1], -1)
                adv_loss /= (i + 1)
                self.total_train_loss[
                    "train/adversarial_loss"] += adv_loss.item()

                # feature matching loss
                if self.config["use_feat_match_loss"]:
                    # no need to track gradients
                    # with torch.no_grad():
                    #     p = self.model["discriminator"](y)
                    fm_loss = 0.0
                    for i in range(len(p_)):
                        for j in range(len(p_[i]) - 1):
                            fm_loss = fm_loss + self.criterion["l1"](
                                p_[i][j], p[i][j].detach())
                    fm_loss /= (i + 1) * (j + 1)
                    self.total_train_loss[
                        "train/feature_matching_loss"] += fm_loss.item()
                    adv_loss = adv_loss + self.config[
                        "lambda_feat_match"] * fm_loss

            # add adversarial loss to generator loss
            gen_loss = gen_loss + self.config["lambda_adv"] * adv_loss

        self.total_train_loss["train/generator_loss"] += gen_loss.item()

        # update generator
        self.optimizer["generator"].zero_grad()
        gen_loss.backward()
        if self.config["generator_grad_norm"] > 0:
            torch.nn.utils.clip_grad_norm_(
                self.model["generator"].parameters(),
                self.config["generator_grad_norm"])
        self.optimizer["generator"].step()
        self.scheduler["generator"].step()
        self.ema.update()

        #######################
        #    Discriminator    #
        #######################
        if self.steps > self.config["discriminator_train_start_steps"]:
            # re-compute y_ which leads to better quality
            # with torch.no_grad():
            #     y_ = self.model["generator"](*x)
            # if self.config["generator_params"]["out_channels"] > 1:
            #     y_ = self.criterion["pqmf"].synthesis(y_)

            # discriminator loss
            # p = self.model["discriminator"](y)
            p_ = self.model["discriminator"](y_.detach())
            if not isinstance(p, list):
                # for standard discriminator
                dis_loss = relativistic_loss(p, p_, 1)
            else:
                # for multi-scale discriminator
                dis_loss = 0.0
                for i in range(len(p_)):
                    dis_loss = dis_loss + relativistic_loss(
                        p[i][-1], p_[i][-1], 1)
                dis_loss /= (i + 1)

            self.total_train_loss["train/discriminator_loss"] += dis_loss.item(
            )

            # update discriminator
            self.optimizer["discriminator"].zero_grad()
            dis_loss.backward()
            if self.config["discriminator_grad_norm"] > 0:
                torch.nn.utils.clip_grad_norm_(
                    self.model["discriminator"].parameters(),
                    self.config["discriminator_grad_norm"])
            self.optimizer["discriminator"].step()
            self.scheduler["discriminator"].step()

        # update counts
        self.steps += 1
        self.tqdm.update(1)
        self._check_train_finish()

    def _train_epoch(self):
        """Train model one epoch."""
        for train_steps_per_epoch, batch in enumerate(
                self.data_loader["train"], 1):
            # train one step
            self._train_step(batch)

            # check interval
            if self.config["rank"] == 0:
                self._check_log_interval()
                self._check_eval_interval()
                self._check_save_interval()

            # check whether training is finished
            if self.finish_train:
                return

        # update
        self.epochs += 1
        self.train_steps_per_epoch = train_steps_per_epoch
        logging.info(
            f"(Steps: {self.steps}) Finished {self.epochs} epoch training "
            f"({self.train_steps_per_epoch} steps per epoch).")

        # needed for shuffle in distributed training
        if self.config["distributed"]:
            self.sampler["train"].set_epoch(self.epochs)

    @torch.no_grad()
    def _eval_step(self, batch):
        """Evaluate model one step."""
        # parse batch
        x, y = batch
        x = tuple([x_.to(self.device) for x_ in x])
        y = y.to(self.device)

        #######################
        #      Generator      #
        #######################
        y_ = self.model["generator"](*x)
        if self.config["generator_params"]["out_channels"] > 1:
            y_mb_ = y_
            y_ = self.criterion["pqmf"].synthesis(y_mb_)

        aux_loss = 0.0

        # multi-resolution stft loss
        if self.config.get("use_stft_loss", True):
            sc_loss, mag_loss = self.criterion["stft"](y_.squeeze(1),
                                                       y.squeeze(1))
            self.total_eval_loss[
                "eval/spectral_convergence_loss"] += sc_loss.item()
            self.total_eval_loss[
                "eval/log_stft_magnitude_loss"] += mag_loss.item()
            aux_loss = aux_loss + sc_loss + mag_loss

        # subband multi-resolution stft loss
        if self.config.get("use_subband_stft_loss", False):
            aux_loss *= 0.5  # for balancing with subband stft loss
            y_mb = self.criterion["pqmf"].analysis(y)
            y_mb = y_mb.view(-1, y_mb.size(2))  # (B, C, T) -> (B x C, T)
            y_mb_ = y_mb_.view(-1, y_mb_.size(2))  # (B, C, T) -> (B x C, T)
            sub_sc_loss, sub_mag_loss = self.criterion["sub_stft"](y_mb_, y_mb)
            self.total_eval_loss[
                "eval/sub_spectral_convergence_loss"] += sub_sc_loss.item()
            self.total_eval_loss[
                "eval/sub_log_stft_magnitude_loss"] += sub_mag_loss.item()
            aux_loss += 0.5 * (sub_sc_loss + sub_mag_loss)

        if self.config.get("use_feat_match_loss_wav2vec", False):
            wav2vec_loss = self.criterion["l1"](self.criterion["wav2vec"](
                y_.squeeze(1)), self.criterion["wav2vec"](y.squeeze(1)))
            aux_loss = aux_loss + wav2vec_loss
            self.total_eval_loss["eval/wav2vec_loss"] += wav2vec_loss.item()

        # adversarial loss
        p = self.model["discriminator"](y)
        p_ = self.model["discriminator"](y_)
        if not isinstance(p_, list):
            # for standard discriminator
            adv_loss = relativistic_loss(p, p_, -1)
            gen_loss = aux_loss + self.config["lambda_adv"] * adv_loss
        else:
            # for multi-scale discriminator
            adv_loss = 0.0
            for i in range(len(p_)):
                adv_loss += relativistic_loss(p[i][-1], p_[i][-1], -1)
            adv_loss /= (i + 1)
            gen_loss = aux_loss + self.config["lambda_adv"] * adv_loss

            # feature matching loss
            if self.config["use_feat_match_loss"]:
                p = self.model["discriminator"](y)
                fm_loss = 0.0
                for i in range(len(p_)):
                    for j in range(len(p_[i]) - 1):
                        fm_loss += self.criterion["l1"](p_[i][j], p[i][j])
                fm_loss /= (i + 1) * (j + 1)
                self.total_eval_loss[
                    "eval/feature_matching_loss"] += fm_loss.item()
                gen_loss += self.config["lambda_adv"] * self.config[
                    "lambda_feat_match"] * fm_loss

        #######################
        #    Discriminator    #
        #######################
        # discriminator loss
        if not isinstance(p_, list):
            # for standard discriminator
            dis_loss = relativistic_loss(p, p_, 1)
        else:
            # for multi-scale discriminator
            dis_loss = 0.0
            for i in range(len(p_)):
                dis_loss = dis_loss + relativistic_loss(p[i][-1], p_[i][-1], 1)
            dis_loss /= (i + 1)

        # add to total eval loss
        self.total_eval_loss["eval/adversarial_loss"] += adv_loss.item()
        self.total_eval_loss["eval/generator_loss"] += gen_loss.item()
        self.total_eval_loss["eval/discriminator_loss"] += dis_loss.item()

    def _eval_epoch(self):
        """Evaluate model one epoch."""
        logging.info(f"(Steps: {self.steps}) Start evaluation.")
        # change mode
        for key in self.model.keys():
            self.model[key].eval()

        # calculate loss for each batch
        for eval_steps_per_epoch, batch in enumerate(
                tqdm(self.data_loader["dev"], desc="[eval]"), 1):
            # eval one step
            self._eval_step(batch)

            # save intermediate result
            if eval_steps_per_epoch == 1:
                self._generate_and_save_intermediate_result(batch)

        # average loss
        for key in self.total_eval_loss.keys():
            self.total_eval_loss[key] /= eval_steps_per_epoch
            logging.info(
                f"(Steps: {self.steps}) {key} = {self.total_eval_loss[key]:.4f}."
            )

        # record
        self._write_to_tensorboard(self.total_eval_loss)

        # reset
        self.total_eval_loss = defaultdict(float)

        with self.ema.average_parameters():
            # calculate loss for each batch
            for eval_steps_per_epoch, batch in enumerate(
                    tqdm(self.data_loader["dev"], desc="[eval]"), 1):
                # eval one step
                self._eval_step(batch)

        for key in self.total_eval_loss.keys():
            self.total_eval_loss[key + '_ema'] = self.total_eval_loss[key]
            del self.total_eval_loss[key]

        # average loss
        for key in self.total_eval_loss.keys():
            self.total_eval_loss[key] /= eval_steps_per_epoch
            logging.info(
                f"(Steps: {self.steps}) {key} = {self.total_eval_loss[key]:.4f}."
            )

        # record
        self._write_to_tensorboard(self.total_eval_loss)

        # reset
        self.total_eval_loss = defaultdict(float)

        # restore mode
        for key in self.model.keys():
            self.model[key].train()

    @torch.no_grad()
    def _generate_and_save_intermediate_result(self, batch):
        """Generate and save intermediate result."""
        # delayed import to avoid errors related to the matplotlib backend
        import matplotlib.pyplot as plt

        # generate
        x_batch, y_batch = batch
        x_batch = tuple([x.to(self.device) for x in x_batch])
        y_batch = y_batch.to(self.device)
        y_batch_ = self.model["generator"](*x_batch)
        if self.config["generator_params"]["out_channels"] > 1:
            y_batch_ = self.criterion["pqmf"].synthesis(y_batch_)

        with self.ema.average_parameters():
            y_batch_ema = self.model["generator"](*x_batch)
            if self.config["generator_params"]["out_channels"] > 1:
                y_batch_ema = self.criterion["pqmf"].synthesis(y_batch_ema)

        # check directory
        dirname = os.path.join(self.config["outdir"],
                               f"predictions/{self.steps}steps")
        if not os.path.exists(dirname):
            os.makedirs(dirname)

        for idx, (y, y_,
                  y_ema) in enumerate(zip(y_batch, y_batch_, y_batch_ema), 1):
            # convert to ndarray
            y, y_, y_ema = y.view(-1).cpu().numpy(), y_.view(
                -1).cpu().numpy(), y_ema.view(-1).cpu().numpy()

            # plot figure and save it
            figname = os.path.join(dirname, f"{idx}.png")
            plt.subplot(3, 1, 1)
            plt.plot(y)
            plt.title("groundtruth speech")
            plt.subplot(3, 1, 2)
            plt.plot(y_)
            plt.title(f"generated speech @ {self.steps} steps")
            plt.subplot(3, 1, 3)
            plt.plot(y_ema)
            plt.title(f"generated speech ema @ {self.steps} steps")
            plt.tight_layout()
            plt.savefig(figname)
            plt.close()

            # save as wavfile
            y = np.clip(y, -1, 1)
            y_ = np.clip(y_, -1, 1)
            y_ema = np.clip(y_ema, -1, 1)
            sf.write(figname.replace(".png", "_ref.wav"), y,
                     self.config["sampling_rate"], "PCM_16")
            sf.write(figname.replace(".png", "_gen.wav"), y_,
                     self.config["sampling_rate"], "PCM_16")
            sf.write(figname.replace(".png", "_gen_ema.wav"), y_ema,
                     self.config["sampling_rate"], "PCM_16")

            if idx >= self.config["num_save_intermediate_results"]:
                break

    def _write_to_tensorboard(self, loss):
        """Write to tensorboard."""
        for key, value in loss.items():
            self.writer.add_scalar(key, value, self.steps)

    def _check_save_interval(self):
        if self.steps % self.config["save_interval_steps"] == 0:
            self.save_checkpoint(
                os.path.join(self.config["outdir"],
                             f"checkpoint-{self.steps}steps.pkl"))
            logging.info(
                f"Successfully saved checkpoint @ {self.steps} steps.")

    def _check_eval_interval(self):
        if self.steps % self.config["eval_interval_steps"] == 0:
            self._eval_epoch()

    def _check_log_interval(self):
        if self.steps % self.config["log_interval_steps"] == 0:
            for key in self.total_train_loss.keys():
                self.total_train_loss[key] /= self.config["log_interval_steps"]
                logging.info(
                    f"(Steps: {self.steps}) {key} = {self.total_train_loss[key]:.4f}."
                )
            self._write_to_tensorboard(self.total_train_loss)

            # reset
            self.total_train_loss = defaultdict(float)

    def _check_train_finish(self):
        if self.steps >= self.config["train_max_steps"]:
            self.finish_train = True
Example No. 9
def test_state_dict(decay, use_num_updates, explicit_params):
    model = torch.nn.Linear(10, 2, bias=False)
    with torch.no_grad():
        model.weight.fill_(0.0)
    ema = ExponentialMovingAverage(
        model.parameters(),
        decay=decay,
        use_num_updates=False
    )
    state_dict = copy.deepcopy(ema.state_dict())

    model2 = torch.nn.Linear(10, 2, bias=False)
    ema2 = ExponentialMovingAverage(model2.parameters(), decay=0.0)
    ema2.load_state_dict(state_dict)
    assert ema2.decay == decay
    assert torch.allclose(ema2.shadow_params[0], ema.shadow_params[0])

    with torch.no_grad():
        model2.weight.fill_(1.0)
    if explicit_params:
        ema2.update(model2.parameters())
    else:
        ema2.update()
    assert torch.all(model2.weight == 1.0), "ema.update changed model weights"

    ema.load_state_dict(ema2.state_dict())

    if explicit_params:
        ema.copy_to(model.parameters())
    else:
        ema.copy_to()
    assert torch.allclose(
        model.weight,
        torch.full(size=(1,), fill_value=(1.0 - decay))
    ), "average was wrong"
Example No. 10
def test_val_error(decay, use_num_updates, explicit_params):
    """Confirm that EMA validation error is lower than raw validation error."""
    torch.manual_seed(0)
    x_train = torch.rand((100, 10))
    y_train = torch.rand(100).round().long()
    x_val = torch.rand((100, 10))
    y_val = torch.rand(100).round().long()
    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    ema = ExponentialMovingAverage(
        model.parameters(),
        decay=decay,
        use_num_updates=use_num_updates
    )

    # Train for a few epochs
    model.train()
    for _ in range(20):
        logits = model(x_train)
        loss = torch.nn.functional.cross_entropy(logits, y_train)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if explicit_params:
            ema.update(model.parameters())
        else:
            ema.update()

    # Validation: original
    model.eval()
    logits = model(x_val)
    loss_orig = torch.nn.functional.cross_entropy(logits, y_val)
    print(f"Original loss: {loss_orig}")

    # Validation: with EMA
    # First save original parameters before replacing with EMA version
    if explicit_params:
        ema.store(model.parameters())
    else:
        ema.store()
    # Copy EMA parameters to model
    if explicit_params:
        ema.copy_to(model.parameters())
    else:
        ema.copy_to()
    logits = model(x_val)
    loss_ema = torch.nn.functional.cross_entropy(logits, y_val)

    print(f"EMA loss: {loss_ema}")
    assert loss_ema < loss_orig, "EMA loss wasn't lower"

    # Test restore
    if explicit_params:
        ema.restore(model.parameters())
    else:
        ema.restore()
    model.eval()
    logits = model(x_val)
    loss_orig2 = torch.nn.functional.cross_entropy(logits, y_val)
    assert torch.allclose(loss_orig, loss_orig2), \
        "Restored model wasn't the same as stored model"