Ejemplo n.º 1
0
 def __init__(self, model, dataloader, optimizer, hparams):
     r"""Store the training collaborators and prepare progress reporting.

     Args:
         model: Model being trained.
         dataloader (dict): Iterators keyed by ``'train'`` and ``'valid'``;
             each exposes a ``size`` attribute.
         optimizer: Optimizer used to update the model parameters.
         hparams: Hyper-parameters; ``batch_size``, ``output_path``,
             ``comm.rank`` and ``save`` are used here.
     """
     self.model, self.dataloader, self.hparams = model, dataloader, hparams
     batch = hparams.batch_size
     # Whole mini-batches per epoch (integer division drops any remainder).
     self.one_epoch_train, self.one_epoch_valid = (
         dataloader['train'].size // batch,
         dataloader['valid'].size // batch,
     )
     self.placeholder = {}
     self.optimizer = optimizer
     # Non-zero ranks run the progress meter in quiet mode.
     self.monitor = ProgressMeter(
         self.one_epoch_train, hparams.output_path,
         quiet=(hparams.comm.rank > 0))
     settings_file = Path(hparams.output_path) / 'settings.json'
     hparams.save(settings_file)
Ejemplo n.º 2
0
    def __init__(self, gen, gen_optim, dis, dis_optim, dataloader, rng, hp):
        r"""Store the generator/discriminator pair and training utilities.

        Args:
            gen: Generator network.
            gen_optim: Optimizer for the generator.
            dis: Discriminator network.
            dis_optim: Optimizer for the discriminator.
            dataloader (dict): Iterators keyed by split; ``'train'`` must
                expose ``size``.
            rng: Random generator used elsewhere by the trainer.
            hp: Hyper-parameters; ``batch_size``, ``output_path``,
                ``comm.rank`` and ``save`` are used here.
        """
        self.gen, self.gen_optim = gen, gen_optim
        self.dis, self.dis_optim = dis, dis_optim

        self.dataloader, self.rng, self.hp = dataloader, rng, hp
        # Whole mini-batches per training epoch.
        self.one_epoch_train = dataloader['train'].size // hp.batch_size

        self.placeholder = {}
        # Non-zero ranks run the progress meter in quiet mode.
        self.monitor = ProgressMeter(
            self.one_epoch_train, hp.output_path, quiet=(hp.comm.rank > 0))
        settings_file = os.path.join(hp.output_path, 'settings.json')
        hp.save(settings_file)
Ejemplo n.º 3
0
def train(train_loader, model, criterion, optimizer, epoch, args, writer):
    """Train ``model`` for one epoch over ``train_loader``.

    The optimized objective is ``criterion(output, target)`` plus a small
    L1 penalty from ``SubnetL1RegLoss`` (weight 1e-8).

    Args:
        train_loader: Iterable of ``(images, target)`` batches; its
            ``batch_size`` and ``len()`` are used for TensorBoard steps.
        model: Network being trained.
        criterion: Task loss ``criterion(output, target)``.
        optimizer: Optimizer stepping the model parameters.
        epoch (int): Current epoch index (used for logging only).
        args: Run configuration; ``gpu`` and ``print_freq`` are read.
        writer: TensorBoard writer passed to ``ProgressMeter``.

    Returns:
        tuple: ``(top1.avg, top5.avg)`` accuracy averages for the epoch.
    """
    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.3f")
    top1 = AverageMeter("Acc@1", ":6.2f")
    top5 = AverageMeter("Acc@5", ":6.2f")
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix=f"Epoch: [{epoch}]",
    )

    # switch to train mode
    model.train()

    batch_size = train_loader.batch_size
    num_batches = len(train_loader)
    end = time.time()
    l1reg = SubnetL1RegLoss(temperature=1.0)

    for i, (images, target) in tqdm.tqdm(
        enumerate(train_loader), ascii=True, total=num_batches
    ):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)

        target = target.cuda(args.gpu, non_blocking=True)

        # compute output
        output = model(images)

        loss = criterion(output, target)
        regloss = l1reg(model) * 1e-8

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        n_samples = images.size(0)
        losses.update(loss.item(), n_samples)
        top1.update(acc1.item(), n_samples)
        top5.update(acc5.item(), n_samples)

        # compute gradient and do SGD step.
        # Backpropagate the task loss together with the regularizer; calling
        # backward on ``regloss`` alone would leave ``criterion`` without any
        # effect on the parameters.
        optimizer.zero_grad()
        (loss + regloss).backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            # global step measured in samples seen so far
            t = (num_batches * epoch + i) * batch_size
            progress.display(i)
            progress.write_to_tensorboard(writer, prefix="train", global_step=t)

    return top1.avg, top5.avg
Ejemplo n.º 4
0
def validate(val_loader, model, criterion, args, writer, epoch):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Args:
        val_loader: Iterable of ``(images, target)`` batches with ``len()``.
        model: Network under evaluation.
        criterion: Loss ``criterion(output, target)``.
        args: Run configuration; ``gpu`` and ``print_freq`` are read.
        writer: TensorBoard writer, or ``None`` to skip TensorBoard output.
        epoch (int): Global step used for the TensorBoard write.

    Returns:
        tuple: ``(top1.avg, top5.avg)`` accuracy averages.
    """
    batch_time = AverageMeter("Time", ":6.3f", write_val=False)
    losses = AverageMeter("Loss", ":.3f", write_val=False)
    top1 = AverageMeter("Acc@1", ":6.2f", write_val=False)
    top5 = AverageMeter("Acc@5", ":6.2f", write_val=False)
    n_batches = len(val_loader)
    progress = ProgressMeter(
        n_batches, [batch_time, losses, top1, top5], prefix="Test: "
    )

    # evaluation mode for the duration of the pass
    model.eval()

    with torch.no_grad():
        tic = time.time()
        batches = tqdm.tqdm(enumerate(val_loader), ascii=True, total=n_batches)
        for step, (images, target) in batches:
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)

            output = model(images)
            loss = criterion(output, target)

            # accumulate per-sample weighted statistics
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            count = images.size(0)
            losses.update(loss.item(), count)
            top1.update(acc1.item(), count)
            top5.update(acc5.item(), count)

            batch_time.update(time.time() - tic)
            tic = time.time()

            if not step % args.print_freq:
                progress.display(step)

        progress.display(n_batches)

        if writer is not None:
            progress.write_to_tensorboard(writer, prefix="test", global_step=epoch)

    return top1.avg, top5.avg
Ejemplo n.º 5
0
class TacotronTrainer(ABC):
    r"""Trainer for Tacotron.

    Builds separate training and validation graphs and runs the epoch loop.
    Rank 0 additionally writes validation artifacts: attention and
    spectrogram images, a synthesized waveform, and model parameters.

    Args:
        model (model.module.Module): Tacotron model.
        dataloader (dict): Data iterators keyed by ``'train'`` and
            ``'valid'``; each must expose ``size`` and ``next()``.
        optimizer (Optimizer): An optimizer used to update the parameters.
        hparams (HParams): Hyper-parameters.
    """

    def __init__(self, model, dataloader, optimizer, hparams):
        self.model = model
        self.dataloader = dataloader
        self.hparams = hparams
        # Number of whole mini-batches drawn per epoch; a trailing partial
        # batch is never used.
        self.one_epoch_train = dataloader['train'].size // hparams.batch_size
        self.one_epoch_valid = dataloader['valid'].size // hparams.batch_size
        self.placeholder = dict()
        self.optimizer = optimizer
        self.monitor = ProgressMeter(self.one_epoch_train,
                                     hparams.output_path,
                                     quiet=hparams.comm.rank > 0)
        hparams.save(Path(hparams.output_path) / 'settings.json')

    def update_graph(self, key='train'):
        r"""Builds the graph and update the placeholder.

        Args:
            key (str, optional): Type of computational graph, either
                ``'train'`` or ``'valid'``. Defaults to ``'train'``.
        """
        assert key in ('train', 'valid')

        self.model.training = key != 'valid'
        hp = self.hparams

        # define input variables
        x_txt = nn.Variable([hp.batch_size, hp.text_len])
        x_mel = nn.Variable([hp.batch_size, hp.n_frames, hp.n_mels * hp.r])
        t_mag = nn.Variable(
            [hp.batch_size, hp.n_frames * hp.r, hp.n_fft // 2 + 1])

        # output variables, marked persistent (presumably so clear_buffer
        # in forward/backward keeps their data readable afterwards)
        o_mel, o_mag, o_att = self.model(x_txt, x_mel)
        o_mel = o_mel.apply(persistent=True)
        o_mag = o_mag.apply(persistent=True)
        o_att = o_att.apply(persistent=True)

        # loss functions
        def criteria(x, t):
            return F.mean(F.absolute_error(x, t))

        # Number of linear-spectrogram bins below 3000 (Hz, assuming hp.sr
        # is in Hz): fraction of the Nyquist band times the bin count.
        n_prior = int(3000 / (hp.sr * 0.5) * (hp.n_fft // 2 + 1))

        l_mel = criteria(o_mel, x_mel).apply(persistent=True)
        # Half the magnitude loss over all bins, half over the low bins only.
        l_mag = 0.5*criteria(o_mag, t_mag) + 0.5 * \
            criteria(o_mag[..., :n_prior], t_mag[..., :n_prior])
        l_mag.persistent = True

        l_net = (l_mel + l_mag).apply(persistent=True)

        self.placeholder[key] = {
            'x_mel': x_mel,
            't_mag': t_mag,
            'x_txt': x_txt,
            'o_mel': o_mel,
            'o_mag': o_mag,
            'o_att': o_att,
            'l_mel': l_mel,
            'l_mag': l_mag,
            'l_net': l_net
        }

    def callback_on_start(self):
        r"""Builds both graphs, binds the optimizer and resets the loss."""
        self.update_graph('train')
        params = self.model.get_parameters(grad_only=True)
        self.optimizer.set_parameters(params)
        self.update_graph('valid')
        # Running sum of validation loss, reset at the end of every epoch.
        self.loss = nn.NdArray.from_numpy_array(np.zeros((1, )))
        if self.hparams.comm.n_procs > 1:
            # Gradient arrays all-reduced across workers in train_on_batch.
            self._grads = [x.grad for x in params.values()]

    def run(self):
        r"""Run the training process."""
        self.callback_on_start()
        for cur_epoch in range(self.hparams.epoch):
            self.monitor.reset()
            lr = self.optimizer.get_learning_rate()
            self.monitor.info(f'Running epoch={cur_epoch}\tlr={lr:.5f}\n')
            self.cur_epoch = cur_epoch
            for i in range(self.one_epoch_train):
                self.train_on_batch()
                if i % (self.hparams.print_frequency) == 0:
                    self.monitor.display(
                        i, ['train/l_mel', 'train/l_mag', 'train/l_net'])
            for i in trange(self.one_epoch_valid,
                            disable=self.hparams.comm.rank > 0):
                self.valid_on_batch()
            self.callback_on_epoch_end()
        self.callback_on_finish()
        self.monitor.close()

    def train_on_batch(self):
        r"""Updates the model parameters."""
        batch_size = self.hparams.batch_size
        p, dl = self.placeholder['train'], self.dataloader['train']
        self.optimizer.zero_grad()
        if self.hparams.comm.n_procs > 1:
            self.hparams.event.default_stream_synchronize()
        p['x_mel'].d, p['t_mag'].d, p['x_txt'].d = dl.next()
        p['l_net'].forward(clear_no_need_grad=True)
        p['l_net'].backward(clear_buffer=True)
        for name in ('l_mel', 'l_mag', 'l_net'):
            self.monitor.update(f'train/{name}', p[name].d.copy(), batch_size)
        if self.hparams.comm.n_procs > 1:
            # Average gradients across workers before the update.
            self.hparams.comm.all_reduce(self._grads,
                                         division=True,
                                         inplace=False)
            self.hparams.event.add_default_stream_event()
        self.optimizer.update()

    def valid_on_batch(self):
        r"""Performs validation."""
        batch_size = self.hparams.batch_size
        p, dl = self.placeholder['valid'], self.dataloader['valid']
        if self.hparams.comm.n_procs > 1:
            self.hparams.event.default_stream_synchronize()
        p['x_mel'].d, p['t_mag'].d, p['x_txt'].d = dl.next()
        p['l_net'].forward(clear_buffer=True)
        # Accumulate the sample-weighted loss; normalized at epoch end.
        self.loss.data += p['l_net'].d.copy() * batch_size
        for name in ('l_mel', 'l_mag', 'l_net'):
            self.monitor.update(f'valid/{name}', p[name].d.copy(), batch_size)

    def callback_on_epoch_end(self):
        r"""Reports the validation loss and, on rank 0, writes artifacts."""
        if self.hparams.comm.n_procs > 1:
            self.hparams.comm.all_reduce([self.loss],
                                         division=True,
                                         inplace=False)
        # Only ``one_epoch_valid`` whole batches were accumulated, so divide
        # by the number of samples actually seen rather than the dataset
        # size (which would bias the mean low when size % batch_size != 0).
        n_seen = self.one_epoch_valid * self.hparams.batch_size
        self.loss.data /= max(n_seen, 1)
        if self.hparams.comm.rank == 0:
            p, hp = self.placeholder['valid'], self.hparams
            self.monitor.info(f'valid/loss={self.loss.data[0]:.5f}\n')
            if self.cur_epoch % hp.epochs_per_checkpoint == 0:
                path = Path(
                    hp.output_path) / 'output' / f'epoch_{self.cur_epoch}'
                path.mkdir(parents=True, exist_ok=True)
                titles = {
                    'o_att': 'Attention',
                    'o_mel': 'Mel spectrogram',
                    'o_mag': 'Spectrogram'
                }
                # write attention and spectrogram outputs of the first
                # sample of the last validation batch
                for k in ('o_att', 'o_mel', 'o_mag'):
                    p[k].forward(clear_buffer=True)
                    data = p[k].d[0].copy()
                    if k == 'o_mel':
                        image = data.reshape((-1, hp.n_mels)).T
                    else:
                        image = data.T
                    if k == 'o_att':
                        label = ('Decoder timestep', 'Encoder timestep')
                        figsize = (6, 5)
                    else:
                        label = ('Frame', 'Channel')
                        figsize = (6, 3)
                    save_image(data=image,
                               path=path / (k + '.png'),
                               label=label,
                               title=titles[k],
                               figsize=figsize)
                wave = synthesize_from_spec(p['o_mag'].d[0].copy(), hp)
                wavfile.write(path / 'sample.wav', rate=hp.sr, data=wave)
                self.model.save_parameters(
                    str(path / f'model_{self.cur_epoch}.h5'))
        self.loss.zero()

    def callback_on_finish(self):
        r"""Calls this on finishing the run method."""
        if self.hparams.comm.rank == 0:
            path = str(Path(self.hparams.output_path) / 'model.h5')
            self.model.save_parameters(path)
Ejemplo n.º 6
0
def main_worker(args):
    """Build model, data and optimizer from ``args`` and train or evaluate.

    With ``args.evaluate`` set, a single validation pass is run and the
    function returns. Otherwise the model is trained from
    ``args.start_epoch`` to ``args.epochs``. Along the way, checkpoints are
    written via ``save_checkpoint`` and diagnostics go to TensorBoard. At
    the end, a summary is written with ``write_result_to_csv``.
    """
    args.gpu = None
    train, validate, modifier = get_trainer(args)

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model and optimizer
    model = get_model(args)
    model = set_gpu(args, model)

    if args.pretrained:
        pretrained(args, model)

    optimizer = get_optimizer(args, model)
    data = get_dataset(args)
    lr_policy = get_policy(args.lr_policy)(optimizer, args)

    if args.label_smoothing is None:
        criterion = nn.CrossEntropyLoss().cuda()
    else:
        criterion = LabelSmoothing(smoothing=args.label_smoothing)

    # optionally resume from a checkpoint
    best_acc1 = 0.0
    best_acc5 = 0.0
    best_train_acc1 = 0.0
    best_train_acc5 = 0.0

    if args.resume:
        best_acc1 = resume(args, model, optimizer)

    # evaluation-only mode
    if args.evaluate:
        acc1, acc5 = validate(data.val_loader,
                              model,
                              criterion,
                              args,
                              writer=None,
                              epoch=args.start_epoch)

        return

    # Set up directories
    run_base_dir, ckpt_base_dir, log_base_dir = get_directories(args)
    args.ckpt_base_dir = ckpt_base_dir

    writer = SummaryWriter(log_dir=log_base_dir)
    epoch_time = AverageMeter("epoch_time", ":.4f", write_avg=False)
    validation_time = AverageMeter("validation_time", ":.4f", write_avg=False)
    train_time = AverageMeter("train_time", ":.4f", write_avg=False)
    progress_overall = ProgressMeter(1,
                                     [epoch_time, validation_time, train_time],
                                     prefix="Overall Timing")

    end_epoch = time.time()
    args.start_epoch = args.start_epoch or 0
    # Defined up front so the final CSV write works even when the epoch
    # loop does not execute (start_epoch >= epochs).
    acc1 = None
    acc5 = None

    # Save the initial state
    save_checkpoint(
        {
            "epoch": 0,
            "arch": args.arch,
            "state_dict": model.state_dict(),
            "best_acc1": best_acc1,
            "best_acc5": best_acc5,
            "best_train_acc1": best_train_acc1,
            "best_train_acc5": best_train_acc5,
            "optimizer": optimizer.state_dict(),
            "curr_acc1": acc1 if acc1 else "Not evaluated",
        },
        False,
        filename=ckpt_base_dir / "initial.state",
        save=False,
    )

    # Start training
    for epoch in range(args.start_epoch, args.epochs):
        lr_policy(epoch, iteration=None)
        modifier(args, epoch, model)

        cur_lr = get_lr(optimizer)

        # train for one epoch
        start_train = time.time()
        train_acc1, train_acc5 = train(data.train_loader,
                                       model,
                                       criterion,
                                       optimizer,
                                       epoch,
                                       args,
                                       writer=writer)
        train_time.update((time.time() - start_train) / 60)

        # evaluate on validation set
        start_validation = time.time()
        acc1, acc5 = validate(data.val_loader, model, criterion, args, writer,
                              epoch)
        validation_time.update((time.time() - start_validation) / 60)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        best_acc5 = max(acc5, best_acc5)
        best_train_acc1 = max(train_acc1, best_train_acc1)
        best_train_acc5 = max(train_acc5, best_train_acc5)

        # Test save_every > 0 first: ``epoch % 0`` raises ZeroDivisionError.
        save = args.save_every > 0 and epoch % args.save_every == 0
        if is_best or save or epoch == args.epochs - 1:
            if is_best:
                print(
                    f"==> New best, saving at {ckpt_base_dir / 'model_best.pth'}"
                )

            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "arch": args.arch,
                    "state_dict": model.state_dict(),
                    "best_acc1": best_acc1,
                    "best_acc5": best_acc5,
                    "best_train_acc1": best_train_acc1,
                    "best_train_acc5": best_train_acc5,
                    "optimizer": optimizer.state_dict(),
                    "curr_acc1": acc1,
                    "curr_acc5": acc5,
                },
                is_best,
                filename=ckpt_base_dir / f"epoch_{epoch}.state",
                save=save,
            )

        epoch_time.update((time.time() - end_epoch) / 60)
        progress_overall.display(epoch)
        progress_overall.write_to_tensorboard(writer,
                                              prefix="diagnostics",
                                              global_step=epoch)

        if args.conv_type == "SampleSubnetConv":
            count = 0
            sum_pr = 0.0
            for n, m in model.named_modules():
                if isinstance(m, SampleSubnetConv):
                    # avg pr across 10 samples
                    pr = 0.0
                    for _ in range(10):
                        pr += ((torch.rand_like(m.clamped_scores) >=
                                m.clamped_scores).float().mean().item())
                    pr /= 10.0
                    writer.add_scalar("pr/{}".format(n), pr, epoch)
                    sum_pr += pr
                    count += 1

            # Avoid ZeroDivisionError when no SampleSubnetConv layer exists;
            # args.prune_rate is then left as configured.
            if count > 0:
                args.prune_rate = sum_pr / count
                writer.add_scalar("pr/average", args.prune_rate, epoch)

        writer.add_scalar("test/lr", cur_lr, epoch)
        end_epoch = time.time()

    write_result_to_csv(
        best_acc1=best_acc1,
        best_acc5=best_acc5,
        best_train_acc1=best_train_acc1,
        best_train_acc5=best_train_acc5,
        prune_rate=args.prune_rate,
        curr_acc1=acc1,
        curr_acc5=acc5,
        base_config=args.config,
        name=args.name,
    )
Ejemplo n.º 7
0
class Trainer:
    r"""Trainer is a basic class for training a model.

    Jointly trains a generator ``gen`` and a discriminator ``dis`` with
    their own optimizers. Training can resume from ``checkpoint.json`` in
    ``hp.output_path``, and gradients are all-reduced through ``hp.comm``.
    """

    def __init__(self, gen, gen_optim, dis, dis_optim, dataloader, rng, hp):
        self.gen = gen
        self.gen_optim = gen_optim
        self.dis = dis
        self.dis_optim = dis_optim

        self.dataloader = dataloader
        self.rng = rng
        self.hp = hp
        # Whole mini-batches per training epoch.
        self.one_epoch_train = dataloader['train'].size // hp.batch_size

        self.placeholder = dict()
        self.monitor = ProgressMeter(self.one_epoch_train,
                                     hp.output_path,
                                     quiet=hp.comm.rank > 0)
        hp.save(os.path.join(hp.output_path, 'settings.json'))

    def update_graph(self, key='train'):
        r"""Builds the graph and update the placeholder.

        Args:
            key (str, optional): Type of the graph, either ``'train'`` or
                ``'valid'``. Defaults to ``'train'``.
        """
        assert key in ('train', 'valid')

        self.gen.training = key == 'train'
        self.dis.training = key == 'train'
        hp = self.hp

        def data_aug(v):
            v = random_flip(v)
            v = random_scaling(v, hp.scale_low, hp.scale_high)
            return v

        # define input variables
        input_x = nn.Variable((hp.batch_size, 1, hp.segment_length))
        input_y = nn.Variable((hp.batch_size, 1, hp.segment_length))
        label_x = nn.Variable((hp.batch_size, 1))
        label_y = nn.Variable((hp.batch_size, 1))

        x_aug = data_aug(input_x)
        r_jitter_x = random_jitter(x_aug, hp.max_jitter_steps)

        x_real_con = self.gen.encode(x_aug)
        s_real, s_mu, s_logvar = self.gen.embed(data_aug(input_x))
        x_real = self.gen.decode(x_real_con, s_real)

        # conversion: content of x decoded with the embedding of y
        r_fake = self.gen.embed(data_aug(input_y))[0]
        x_fake = self.gen.decode(x_real_con, r_fake)
        x_fake_con = self.gen.encode(random_flip(x_fake))

        dis_real_x = self.dis(data_aug(input_x), label_x)
        dis_fake_x = self.dis(data_aug(x_fake), label_y)

        # ------------------------------ Discriminator -----------------------
        d_loss = (self.dis.adversarial_loss(dis_real_x, 1.0) +
                  self.dis.adversarial_loss(dis_fake_x, 0.0))
        # --------------------------------------------------------------------

        # -------------------------------- Generator -------------------------
        g_loss_avd = self.dis.adversarial_loss(self.dis(x_fake, label_y), 1.0)
        g_loss_con = self.dis.preservation_loss(x_fake_con, x_real_con)
        g_loss_kld = self.gen.kl_loss(s_mu, s_logvar)
        g_loss_rec = (self.dis.perceptual_loss(x_real, r_jitter_x) +
                      self.dis.spectral_loss(x_real, r_jitter_x))
        g_loss = (g_loss_avd + hp.lambda_con * g_loss_con +
                  hp.lambda_rec * g_loss_rec + hp.lambda_kld * g_loss_kld)

        # -------------------------------------------------------------------
        set_persistent_all(g_loss_con, g_loss_avd, g_loss, d_loss, x_fake,
                           g_loss_kld, g_loss_rec)

        self.placeholder[key] = dict(
            input_x=input_x,
            label_x=label_x,
            input_y=input_y,
            label_y=label_y,
            x_fake=x_fake,
            d_loss=d_loss,
            g_loss_avd=g_loss_avd,
            g_loss_con=g_loss_con,
            g_loss_rec=g_loss_rec,
            g_loss_kld=g_loss_kld,
            g_loss=g_loss,
        )

    def callback_on_start(self):
        r"""Restores any checkpoint, builds the graph and binds optimizers."""
        self.cur_epoch = 0
        checkpoint = Path(self.hp.output_path) / 'checkpoint.json'
        if checkpoint.is_file():
            self.load_checkpoint_model(str(checkpoint))

        self.update_graph('train')
        params = self.gen.get_parameters(grad_only=True)
        self.gen_optim.set_parameters(params)

        dis_params = self.dis.get_parameters(grad_only=True)
        self.dis_optim.set_parameters(dis_params)

        # Optimizer states can only be restored after set_parameters.
        if checkpoint.is_file():
            self.load_checkpoint_optim(str(checkpoint))

        self._grads = [x.grad for x in params.values()]
        self._discs = [x.grad for x in dis_params.values()]

        self.log_variables = [
            'g_loss_avd',
            'g_loss_con',
            'g_loss_rec',
            'g_loss_kld',
            'd_loss',
        ]

    def run(self):
        r"""Run the training process."""
        self.callback_on_start()

        # Epochs are 1-based; resume continues after the saved epoch.
        for cur_epoch in range(self.cur_epoch + 1, self.hp.epoch + 1):
            self.monitor.reset()
            lr = self.gen_optim.get_learning_rate()
            self.monitor.info(f'Running epoch={cur_epoch}\tlr={lr:.5f}\n')
            self.cur_epoch = cur_epoch

            for i in range(self.one_epoch_train):
                self.train_on_batch(i)
                if i % (self.hp.print_frequency) == 0:
                    self.monitor.display(i, self.log_variables)

            self.callback_on_epoch_end()

        self.callback_on_finish()
        self.monitor.close()

    def _zero_grads(self):
        self.gen_optim.zero_grad()
        self.dis_optim.zero_grad()

    def getdata(self, key='train'):
        r"""Returns ``(x, label_x, y, label_y)`` where ``y`` is a shuffled ``x``."""
        data, label = self.dataloader[key].next()
        idx = self.rng.permutation(self.hp.batch_size)
        return data, label, data[idx], label[idx]

    def train_on_batch(self, i):
        r"""Updates the model parameters."""
        hp = self.hp
        bs, p = hp.batch_size, self.placeholder['train']
        p['input_x'].d, p['label_x'].d, p['input_y'].d, p['label_y'].d = \
            self.getdata('train')

        # ----------------------------- train discriminator ------------------
        if i % hp.n_D_updates == 0:
            self._zero_grads()
            # Cut the graph at x_fake so this step does not backpropagate
            # into the generator.
            p['x_fake'].need_grad = False
            p['d_loss'].forward()
            p['d_loss'].backward(clear_buffer=True)
            self.monitor.update('d_loss', p['d_loss'].d.copy(), bs)
            hp.comm.all_reduce(self._discs, division=True, inplace=False)
            self.dis_optim.update()
            p['x_fake'].need_grad = True
        # ---------------------------------------------------------------------

        # ------------------------------ train generator ----------------------
        self._zero_grads()
        p['g_loss'].forward()
        p['g_loss'].backward(clear_buffer=True)
        for name in ('g_loss', 'g_loss_avd', 'g_loss_con', 'g_loss_rec',
                     'g_loss_kld'):
            self.monitor.update(name, p[name].d.copy(), bs)
        hp.comm.all_reduce(self._grads, division=True, inplace=False)
        self.gen_optim.update()
        # -------------------------------------------------------------------------

    def callback_on_epoch_end(self):
        r"""Saves a resumable checkpoint and samples; periodically snapshots."""
        hp = self.hp
        if hp.comm.rank == 0:
            path = Path(hp.output_path) / 'artifacts'
            path.joinpath('states').mkdir(parents=True, exist_ok=True)
            path.joinpath('samples').mkdir(parents=True, exist_ok=True)
            self.save_checkpoint(path / 'states')
            self.write_samples(path / 'samples')
            if self.cur_epoch % hp.epochs_per_checkpoint == 0:
                path = path / f"epoch_{self.cur_epoch}"
                path.mkdir(parents=True, exist_ok=True)
                self.gen.save_parameters(str(path / 'model.h5'))
                self.dis.save_parameters(str(path / 'cls.h5'))

    def callback_on_finish(self):
        r"""Saves the final generator and discriminator parameters."""
        if self.hp.comm.rank == 0:
            path = Path(self.hp.output_path)
            self.gen.save_parameters(str(path / 'model.h5'))
            self.dis.save_parameters(str(path / 'cls.h5'))

    def save_checkpoint(self, path):
        r"""Save the current states of the trainer.

        Parameters and optimizer states are written under ``path``. An index
        file ``checkpoint.json`` in ``hp.output_path`` records that location,
        the current epoch and the optimizer iteration counters. Only rank 0
        writes anything.
        """
        if self.hp.comm.rank == 0:
            path = Path(path)
            self.gen.save_parameters(str(path / 'model.h5'))
            self.dis.save_parameters(str(path / 'cls.h5'))
            self.gen_optim.save_states(str(path / 'gen_optim.h5'))
            self.dis_optim.save_states(str(path / 'dis_optim.h5'))
            info = dict(
                cur_epoch=self.cur_epoch,
                params_path=str(path),
                gen_optim_n_iters=self.gen_optim._iter,
                dis_optim_n_iters=self.dis_optim._iter,
            )
            # Write to a temporary file and atomically swap it in, so an
            # interrupted write cannot leave a truncated checkpoint.json
            # that would make callback_on_start fail when resuming.
            index = Path(self.hp.output_path) / 'checkpoint.json'
            tmp_index = index.with_name(index.name + '.tmp')
            with open(tmp_index, 'w') as f:
                json.dump(info, f)
            os.replace(tmp_index, index)
            self.monitor.info(f"Checkpoint saved: {str(path)}\n")

    def load_checkpoint_model(self, checkpoint):
        r"""Load the last model parameters referenced by ``checkpoint``."""
        with open(checkpoint, 'r') as file:
            info = json.load(file)
            path = Path(info['params_path'])
        self.gen.load_parameters(str(path / 'model.h5'), raise_if_missing=True)
        self.dis.load_parameters(str(path / 'cls.h5'), raise_if_missing=True)

    def load_checkpoint_optim(self, checkpoint):
        r"""Load the last optimizer states and epoch referenced by ``checkpoint``."""
        with open(checkpoint, 'r') as file:
            info = json.load(file)
            path = Path(info['params_path'])
            self.gen_optim._iter = info['gen_optim_n_iters']
            self.dis_optim._iter = info['dis_optim_n_iters']
            self.cur_epoch = info['cur_epoch']
        self.gen_optim.load_states(str(path / 'gen_optim.h5'))
        self.dis_optim.load_states(str(path / 'dis_optim.h5'))

    def write_samples(self, path):
        r"""Write the last training batch's inputs and conversions as wav files."""
        hp = self.hp
        p = self.placeholder['train']
        X, Z = p['input_x'].d.copy(), p['x_fake'].d.copy()
        for i, (x, z) in enumerate(zip(X, Z)):
            sf.write(str(path / f'input_{i}.wav'), x[0], hp.sr)
            sf.write(str(path / f'convert_{i}.wav'), z[0], hp.sr)