def test_vae(args, model, test_loader, device, epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        if args.dataset == 'ljspeech':
            for step, (x, y, c, g,
                       input_lengths) in tqdm(enumerate(test_loader)):
                # Prepare data
                x, y = x.to(device), y.to(device)
                input_lengths = input_lengths.to(device)
                c = c.to(device) if c is not None else None
                g = g.to(device) if g is not None else None
                c = c.unsqueeze(1)
                x_tilde, kl_d = model(c)
                target = torch.zeros(c.size(0), c.size(1), c.size(2),
                                     c.size(3))
                target[:, :, :, :x_tilde.size(3)] = x_tilde
                target = target.to(device)
                loss = mse_loss(target, c, kl_d)
                test_loss += loss.item()
                # if batch_idx == 0:
                # n = min(data.size(0), 8)
                # comparison = torch.cat([data[:n],
                # recon_batch.view(args.batch_size, 1, 28, 28)[:n]])
                # save_image(comparison.cpu(),
                # './results/reconstruction_' + str(epoch) + '.png', nrow=n)
        else:
            for i, (data, _) in enumerate(test_loader):
                data = data.to(device)
                recon_batch, kl_d = model(data)
                loss = mse_loss(recon_batch, data, kl_d)
                test_loss += loss.item()
                # if i == 0:
                # n = min(data.size(0), 8)
                # comparison = torch.cat([data[:n],
                # recon_batch.view(args.batch_size, 1, 28, 28)[:n]])
                # save_image(comparison.cpu(),
                # './results/reconstruction_' + str(args.model)  + str(epoch) + '.png', nrow=n)

        test_loss /= len(test_loader.dataset)
        print('====> Test set loss: {:.4f}'.format(test_loss))
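The mse_loss helper used above is not shown in this snippet. A minimal sketch of what it presumably does, assuming it sums a sum-reduced reconstruction MSE term with the KL divergence returned by the model (the exact reduction and weighting are assumptions):

import torch
import torch.nn.functional as F

def mse_loss(recon, target, kl_d):
    # Assumed helper: sum-reduced reconstruction error plus the KL term.
    recon_loss = F.mse_loss(recon, target, reduction='sum')
    return recon_loss + kl_d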
Example #2
def train():
    # Check NNabla version
    if utils.get_nnabla_version_integer() < 11900:
        raise ValueError(
            'Please update nnabla to v1.19.0 or later; the memory efficiency of the core engine was improved in v1.19.0.'
        )

    parser, args = get_train_args()

    # Get context.
    ctx = get_extension_context(args.context, device_id=args.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    ext = import_extension_module(args.context)

    # Set up monitors for logging.
    monitor_path = args.output
    monitor = Monitor(monitor_path)

    monitor_best_epoch = MonitorSeries('Best epoch', monitor, interval=1)
    monitor_training_loss = MonitorSeries('Training loss', monitor, interval=1)
    monitor_validation_loss = MonitorSeries('Validation loss',
                                            monitor,
                                            interval=1)
    monitor_lr = MonitorSeries('learning rate', monitor, interval=1)
    monitor_time = MonitorTimeElapsed("training time per iteration",
                                      monitor,
                                      interval=1)

    if comm.rank == 0:
        print("Mixing coef. is {}, i.e., MDL = {}*TD-Loss + FD-Loss".format(
            args.mcoef, args.mcoef))
        if not os.path.isdir(args.output):
            os.makedirs(args.output)

    # Initialize DataIterator for MUSDB.
    train_source, valid_source, args = load_datasources(parser, args)

    train_iter = data_iterator(train_source,
                               args.batch_size,
                               RandomState(args.seed),
                               with_memory_cache=False,
                               with_file_cache=False)

    valid_iter = data_iterator(valid_source,
                               1,
                               RandomState(args.seed),
                               with_memory_cache=False,
                               with_file_cache=False)

    if comm.n_procs > 1:
        train_iter = train_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)

        valid_iter = valid_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)

    # Calculate maxiter per GPU device.
    max_iter = int((train_source._size // args.batch_size) // comm.n_procs)
    weight_decay = args.weight_decay * comm.n_procs

    print("max_iter", max_iter)

    # Calculate the statistics (mean and variance) of the dataset
    scaler_mean, scaler_std = utils.get_statistics(args, train_source)

    max_bin = utils.bandwidth_to_max_bin(train_source.sample_rate, args.nfft,
                                         args.bandwidth)

    unmix = OpenUnmix_CrossNet(input_mean=scaler_mean,
                               input_scale=scaler_std,
                               nb_channels=args.nb_channels,
                               hidden_size=args.hidden_size,
                               n_fft=args.nfft,
                               n_hop=args.nhop,
                               max_bin=max_bin)

    # Create input variables.
    mixture_audio = nn.Variable([args.batch_size] +
                                list(train_source._get_data(0)[0].shape))
    target_audio = nn.Variable([args.batch_size] +
                               list(train_source._get_data(0)[1].shape))

    vmixture_audio = nn.Variable(
        [1] + [2, valid_source.sample_rate * args.valid_dur])
    vtarget_audio = nn.Variable([1] +
                                [8, valid_source.sample_rate * args.valid_dur])

    # create training graph
    mix_spec, M_hat, pred = unmix(mixture_audio)
    Y = Spectrogram(*STFT(target_audio, n_fft=unmix.n_fft, n_hop=unmix.n_hop),
                    mono=(unmix.nb_channels == 1))
    loss_f = mse_loss(mix_spec, M_hat, Y)
    loss_t = sdr_loss(mixture_audio, pred, target_audio)
    loss = args.mcoef * loss_t + loss_f
    loss.persistent = True

    # Create Solver and set parameters.
    solver = S.Adam(args.lr)
    solver.set_parameters(nn.get_parameters())

    # create validation graph
    vmix_spec, vM_hat, vpred = unmix(vmixture_audio, test=True)
    vY = Spectrogram(*STFT(vtarget_audio, n_fft=unmix.n_fft,
                           n_hop=unmix.n_hop),
                     mono=(unmix.nb_channels == 1))
    vloss_f = mse_loss(vmix_spec, vM_hat, vY)
    vloss_t = sdr_loss(vmixture_audio, vpred, vtarget_audio)
    vloss = args.mcoef * vloss_t + vloss_f
    vloss.persistent = True

    # Initialize Early Stopping
    es = utils.EarlyStopping(patience=args.patience)

    # Initialize LR Scheduler (ReduceLROnPlateau)
    lr_scheduler = ReduceLROnPlateau(lr=args.lr,
                                     factor=args.lr_decay_gamma,
                                     patience=args.lr_decay_patience)
    best_epoch = 0

    # Training loop.
    for epoch in trange(args.epochs):
        # TRAINING
        losses = utils.AverageMeter()
        for batch in range(max_iter):
            mixture_audio.d, target_audio.d = train_iter.next()
            solver.zero_grad()
            loss.forward(clear_no_need_grad=True)
            if comm.n_procs > 1:
                all_reduce_callback = comm.get_all_reduce_callback()
                loss.backward(clear_buffer=True,
                              communicator_callbacks=all_reduce_callback)
            else:
                loss.backward(clear_buffer=True)
            solver.weight_decay(weight_decay)
            solver.update()
            losses.update(loss.d.copy(), args.batch_size)
        training_loss = losses.avg

        # clear cache memory
        ext.clear_memory_cache()

        # VALIDATION
        vlosses = utils.AverageMeter()
        for batch in range(int(valid_source._size // comm.n_procs)):
            x, y = valid_iter.next()
            dur = int(valid_source.sample_rate * args.valid_dur)
            sp, cnt = 0, 0
            loss_tmp = nn.NdArray()
            loss_tmp.zero()
            while True:
                vmixture_audio.d = x[Ellipsis, sp:sp + dur]
                vtarget_audio.d = y[Ellipsis, sp:sp + dur]
                vloss.forward(clear_no_need_grad=True)
                cnt += 1
                sp += dur
                loss_tmp += vloss.data
                if x[Ellipsis,
                     sp:sp + dur].shape[-1] < dur or x.shape[-1] == cnt * dur:
                    break
            loss_tmp = loss_tmp / cnt
            if comm.n_procs > 1:
                comm.all_reduce(loss_tmp, division=True, inplace=True)
            vlosses.update(loss_tmp.data.copy(), 1)
        validation_loss = vlosses.avg

        # clear cache memory
        ext.clear_memory_cache()

        lr = lr_scheduler.update_lr(validation_loss, epoch=epoch)
        solver.set_learning_rate(lr)
        stop = es.step(validation_loss)

        if comm.rank == 0:
            monitor_best_epoch.add(epoch, best_epoch)
            monitor_training_loss.add(epoch, training_loss)
            monitor_validation_loss.add(epoch, validation_loss)
            monitor_lr.add(epoch, lr)
            monitor_time.add(epoch)

            if validation_loss == es.best:
                # save best model
                nn.save_parameters(os.path.join(args.output, 'best_xumx.h5'))
                best_epoch = epoch

        if stop:
            print("Apply Early Stopping")
            break
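The loop above assumes a utils.EarlyStopping helper whose step() returns True when training should stop and which tracks the best validation loss in .best. A minimal sketch under those assumptions (the actual utility in the repository may differ):

class EarlyStopping:
    # Minimal early-stopping tracker: stop after `patience` epochs without improvement.
    def __init__(self, patience=10):
        self.patience = patience
        self.best = None
        self.num_bad_epochs = 0

    def step(self, metric):
        if self.best is None or metric < self.best:
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        return self.num_bad_epochs > self.patience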
Example #3
def get_model(args, input_mean, input_scale, max_bin=None):
    '''
    Create computation graph and variables for X-UMX/UMX.
    '''
    # Target channels: 2 for UMX, 8 for X-UMX (2 channels * 4 target sources).
    target_channels = args.nb_channels if args.umx_train else 4 * args.nb_channels

    # Create input variables.
    mixture_audio = nn.Variable(
        (args.batch_size, args.nb_channels, args.sample_rate * args.seq_dur))
    target_audio = nn.Variable(
        (args.batch_size, target_channels, args.sample_rate * args.seq_dur))
    vmixture_audio = nn.Variable(
        (1, args.nb_channels, args.sample_rate * args.valid_dur))
    vtarget_audio = nn.Variable(
        (1, target_channels, args.sample_rate * args.valid_dur))

    if args.umx_train:
        # create training graph for UMX
        unmix = OpenUnmix(input_mean=input_mean,
                          input_scale=input_scale,
                          max_bin=max_bin)
        pred_spec = unmix(mixture_audio)
        target_spec = get_spectogram(*get_stft(target_audio,
                                               n_fft=unmix.n_fft,
                                               n_hop=unmix.n_hop,
                                               center=False),
                                     mono=(unmix.nb_channels == 1))
        loss = F.mean(F.squared_error(pred_spec, target_spec))

        # create validation graph for UMX
        vpred_spec = unmix(vmixture_audio, test=True)
        vtarget_spec = get_spectogram(*get_stft(vtarget_audio,
                                                n_fft=unmix.n_fft,
                                                n_hop=unmix.n_hop),
                                      mono=(unmix.nb_channels == 1))
        vloss = F.mean(F.squared_error(vpred_spec, vtarget_spec))
    else:
        # create training graph for X-UMX
        unmix = OpenUnmix_CrossNet(input_mean=input_mean,
                                   input_scale=input_scale,
                                   max_bin=max_bin)
        mix_spec, m_hat, pred = unmix(mixture_audio)
        target_spec = get_spectogram(*get_stft(target_audio,
                                               n_fft=unmix.n_fft,
                                               n_hop=unmix.n_hop),
                                     mono=(unmix.nb_channels == 1))
        loss_f = mse_loss(mix_spec, m_hat, target_spec)
        loss_t = sdr_loss(mixture_audio, pred, target_audio)
        loss = args.mcoef * loss_t + loss_f

        # create validation graph for X-UMX
        vmix_spec, vm_hat, vpred = unmix(vmixture_audio, test=True)
        vtarget_spec = get_spectogram(*get_stft(vtarget_audio,
                                                n_fft=unmix.n_fft,
                                                n_hop=unmix.n_hop),
                                      mono=(unmix.nb_channels == 1))
        vloss_f = mse_loss(vmix_spec, vm_hat, vtarget_spec)
        vloss_t = sdr_loss(vmixture_audio, vpred, vtarget_audio)
        vloss = args.mcoef * vloss_t + vloss_f

    loss.persistent = True
    vloss.persistent = True

    Network = collections.namedtuple(
        'Network',
        'loss, vloss, mixture_audio, target_audio, vmixture_audio, vtarget_audio'
    )
    return Network(loss=loss,
                   vloss=vloss,
                   mixture_audio=mixture_audio,
                   target_audio=target_audio,
                   vmixture_audio=vmixture_audio,
                   vtarget_audio=vtarget_audio)
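For reference, a sketch of how the returned Network tuple might be consumed in a training step, following the same pattern as the training loop in Example #2 (the solver and iterator names here are illustrative, not part of this function):

# Illustrative training step using the Network namedtuple (names are assumptions).
model = get_model(args, input_mean=scaler_mean, input_scale=scaler_std, max_bin=max_bin)
solver = S.Adam(args.lr)
solver.set_parameters(nn.get_parameters())

model.mixture_audio.d, model.target_audio.d = train_iter.next()
solver.zero_grad()
model.loss.forward(clear_no_need_grad=True)
model.loss.backward(clear_buffer=True)
solver.update()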
Example #4
def main(unused_argv):
    root = Path(FLAGS.dataset)
    clean_root = str(root / "clean")
    noisy_root = str(root / "noisy")
    test_root = str(root / "test")

    resnet = ResNet(3, 8, 64).cuda()
    flow = Flow(32, 3, 3).cuda()
    optimizer = optim.Adam(resnet.parameters(), lr=FLAGS.learning_rate)

    ckpt = torch.load(FLAGS.saved_flow_model)
    flow.load_state_dict(ckpt["flow_state_dict"])
    flow.eval()

    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    log_folder = Path(FLAGS.logs) / timestamp
    writer = SummaryWriter(str(log_folder))

    for epoch in range(1, FLAGS.num_epochs + 1):
        trainloader = get_dataloader(noisy_root,
                                     False,
                                     FLAGS.batch_size,
                                     crop_size=FLAGS.crop_size)

        pbar = tqdm(trainloader, desc="epoch %d" % epoch, leave=False)
        resnet.train()
        for n, noisy_images in enumerate(pbar):

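            # cuda_device is assumed to be a module-level torch.device defined outside this snippet.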
            noisy_images = noisy_images.to(cuda_device)

            optimizer.zero_grad()
            denoised_images = resnet(noisy_images)
            zs, log_det = flow(denoised_images)

            mse = mse_loss(denoised_images, noisy_images)
            batch_size = noisy_images.shape[0]
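            # `one` is assumed to be a module-level scalar tensor (e.g. torch.tensor(1.0, device=cuda_device));
            # the log(256) term appears to be the usual 8-bit dequantization constant in flow NLLs.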
            nll = noisy_images.numel() * torch.log(256 * one) / batch_size
            nll = nll + nll_loss(zs, log_det)
            loss = mse + FLAGS.alpha * nll

            loss.backward()
            optimizer.step()

            pbar.set_postfix(loss=loss.item())

            step = (epoch - 1) * len(trainloader) + n
            writer.add_scalar("loss", loss, step)
            writer.add_scalar("mse", mse, step)
            writer.add_scalar("nll", nll, step)

        if epoch % FLAGS.test_every == 0:
            resnet.eval()
            noisy_testloader = get_dataloader(test_root, False, 1, train=False)
            clean_testloader = get_dataloader(test_root, False, 1, train=False)
            with torch.no_grad():
                psnr = []
                sample_images = None
                for n, (clean_images, noisy_images) in enumerate(
                        zip(clean_testloader, noisy_testloader)):
                    # for n, noisy_images in enumerate(noisy_testloader):
                    clean_images = clean_images.to(cuda_device)
                    noisy_images = noisy_images.to(cuda_device)

                    denoised_images = resnet(noisy_images)
                    mse = F.mse_loss(clean_images, denoised_images)
                    psnr.append(10 * torch.log10(1 / mse))

                    if n < 5:
                        sample_images = torch.cat(
                            [clean_images, noisy_images, denoised_images])
                        # sample_images = torch.cat([noisy_images, denoised_images])
                        writer.add_images("sample_image_%d" % n, sample_images,
                                          step)
                test_psnr = torch.as_tensor(psnr).mean()

            writer.add_scalar("test_psnr", test_psnr, step)
            # writer.add_images("sample_images_%d" % n, sample_images, step)

            torch.save(
                {
                    "epoch": epoch,
                    "resnet_state_dict": resnet.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                }, str(log_folder / ("ckpt-%d.pth" % step)))

    writer.close()
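The nll_loss helper is not shown; a minimal sketch assuming the flow's negative log-likelihood under a standard-normal prior, averaged over the batch (treating zs as a list of latent tensors; the exact reduction is an assumption):

import math
import torch

def nll_loss(zs, log_det):
    # zs: latent tensors from the flow; log_det: per-sample log-determinant of the Jacobian.
    batch_size = log_det.shape[0]
    log_p = 0.0
    for z in zs:
        # log-density under a standard normal, summed over all dimensions
        log_p = log_p + (-0.5 * z.pow(2) - 0.5 * math.log(2 * math.pi)).sum()
    return -(log_p + log_det.sum()) / batch_size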
Example #5
    def trainstep(real_human, real_anime, big_anime):
        with tf.GradientTape(persistent=True) as tape:
            latent_anime = encode_share(encode_anime(real_anime))
            latent_human = encode_share(encode_human(real_human))

            recon_anime = decode_anime(decode_share(latent_anime))
            recon_human = decode_human(decode_share(latent_human))

            fake_anime = decode_anime(decode_share(latent_human))
            latent_human_cycled = encode_share(encode_anime(fake_anime))

            fake_human = decode_human(decode_share(latent_anime))
            latent_anime_cycled = encode_share(encode_human(fake_human))

            def kl_loss(mean, log_var):
                # Closed-form KL divergence between N(mean, exp(log_var)) and a standard normal.
                loss = 1 + log_var - tf.math.square(mean) - tf.math.exp(log_var)
                loss = tf.reduce_sum(loss, axis=-1) * -0.5
                return loss

            disc_fake = D(fake_anime)
            disc_real = D(real_anime)

            c_dann_anime = c_dann(latent_anime)
            c_dann_human = c_dann(latent_human)

            loss_anime_encode = mse_loss(real_anime, recon_anime) * 4
            loss_human_encode = mse_loss(real_human, recon_human) * 4

            loss_domain_adversarial = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=tf.zeros_like(c_dann_anime),
                    logits=c_dann_anime)) + tf.reduce_mean(
                        tf.nn.sigmoid_cross_entropy_with_logits(
                            labels=tf.ones_like(c_dann_human),
                            logits=c_dann_human))
            loss_domain_adversarial = tf.math.minimum(loss_domain_adversarial,
                                                      100)
            loss_domain_adversarial = loss_domain_adversarial * 0.2
            tf.print(loss_domain_adversarial)

            loss_semantic_consistency = identity_loss(
                latent_anime, latent_anime_cycled) * 2 + identity_loss(
                    latent_human, latent_human_cycled) * 2

            loss_gan = mse_loss(tf.zeros_like(disc_fake), disc_fake) * 5

            anime_encode_total_loss = (loss_anime_encode +
                                       loss_domain_adversarial +
                                       loss_semantic_consistency + loss_gan)
            human_encode_total_loss = (loss_human_encode +
                                       loss_domain_adversarial +
                                       loss_semantic_consistency)
            share_encode_total_loss = (loss_anime_encode +
                                       loss_domain_adversarial +
                                       loss_semantic_consistency + loss_gan +
                                       loss_human_encode)

            share_decode_total_loss = loss_anime_encode + loss_human_encode + loss_gan
            anime_decode_total_loss = loss_anime_encode + loss_gan
            human_decode_total_loss = loss_human_encode

            loss_disc = (mse_loss(tf.ones_like(disc_fake), disc_fake) +
                         mse_loss(tf.zeros_like(disc_real), disc_real)) * 10

            losses = [
                anime_encode_total_loss, human_encode_total_loss,
                share_encode_total_loss, loss_domain_adversarial,
                share_decode_total_loss, anime_decode_total_loss,
                human_decode_total_loss, loss_disc
            ]

            scaled_losses = [
                optim.get_scaled_loss(loss)
                for optim, loss in zip(optims, losses)
            ]

        list_variables = [
            encode_anime.trainable_variables, encode_human.trainable_variables,
            encode_share.trainable_variables, c_dann.trainable_variables,
            decode_share.trainable_variables, decode_anime.trainable_variables,
            decode_human.trainable_variables, D.trainable_variables
        ]
        gan_grad = [
            tape.gradient(scaled_loss, train_variable) for scaled_loss,
            train_variable in zip(scaled_losses, list_variables)
        ]
        gan_grad = [
            optim.get_unscaled_gradients(x)
            for optim, x in zip(optims, gan_grad)
        ]
        for optim, grad, trainable in zip(optims, gan_grad, list_variables):
            optim.apply_gradients(zip(grad, trainable))
        # dis_grad = dis_optim.get_unscaled_gradients(
        #     tape.gradient(scaled_loss_disc, D.trainable_variables)
        # )
        # dis_optim.apply_gradients(zip(dis_grad, D.trainable_variables))

        return (
            real_human,
            real_anime,
            recon_anime,
            recon_human,
            fake_anime,
            fake_human,
            loss_anime_encode,
            loss_human_encode,
            loss_domain_adversarial,
            loss_semantic_consistency,
            loss_gan,
            loss_disc,
        )
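For reference, kl_loss above (defined but unused in this step) is the closed-form KL divergence between the encoder distribution and a standard normal prior, with \log\sigma_j^2 = log_var:

D_{\mathrm{KL}}\big(\mathcal{N}(\mu,\sigma^2)\,\|\,\mathcal{N}(0,I)\big)
    = -\tfrac{1}{2}\sum_j\big(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\big)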
Example #6
    def train(self):
        epochs = 1000
        self.genA2B.train(), self.genB2A.train(), self.disGA.train()
        self.disGB.train(), self.disLA.train(), self.disLB.train()
        print('training start !')
        start_time = time.time()
        '''Load pretrained models'''
        if self.pretrain:
            str_genA2B = "Parameters/genA2B%03d.pdparams" % (self.start - 1)
            str_genB2A = "Parameters/genB2A%03d.pdparams" % (self.start - 1)
            str_disGA = "Parameters/disGA%03d.pdparams" % (self.start - 1)
            str_disGB = "Parameters/disGB%03d.pdparams" % (self.start - 1)
            str_disLA = "Parameters/disLA%03d.pdparams" % (self.start - 1)
            str_disLB = "Parameters/disLB%03d.pdparams" % (self.start - 1)
            genA2B_para, gen_A2B_opt = fluid.load_dygraph(str_genA2B)
            genB2A_para, gen_B2A_opt = fluid.load_dygraph(str_genB2A)
            disGA_para, disGA_opt = fluid.load_dygraph(str_disGA)
            disGB_para, disGB_opt = fluid.load_dygraph(str_disGB)
            disLA_para, disLA_opt = fluid.load_dygraph(str_disLA)
            disLB_para, disLB_opt = fluid.load_dygraph(str_disLB)
            self.genA2B.load_dict(genA2B_para)
            self.genB2A.load_dict(genB2A_para)
            self.disGA.load_dict(disGA_para)
            self.disGB.load_dict(disGB_para)
            self.disLA.load_dict(disLA_para)
            self.disLB.load_dict(disLB_para)
        for epoch in range(self.start, epochs):
            for block_id, data in enumerate(self.train_reader()):
                real_A = np.array([x[0] for x in data], np.float32)
                real_B = np.array([x[1] for x in data], np.float32)
                real_A = totensor(real_A, block_id, 'train')
                real_B = totensor(real_B, block_id, 'train')

                # Update D

                fake_A2B, _, _ = self.genA2B(real_A)
                fake_B2A, _, _ = self.genB2A(real_B)

                real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
                real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
                real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
                real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)

                fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
                fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
                fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
                fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

                D_ad_loss_GA = mse_loss(1, real_GA_logit) + mse_loss(
                    0, fake_GA_logit)
                D_ad_cam_loss_GA = mse_loss(1, real_GA_cam_logit) + mse_loss(
                    0, fake_GA_cam_logit)

                D_ad_loss_LA = mse_loss(1, real_LA_logit) + mse_loss(
                    0, fake_LA_logit)
                D_ad_cam_loss_LA = mse_loss(1, real_LA_cam_logit) + mse_loss(
                    0, fake_LA_cam_logit)

                D_ad_loss_GB = mse_loss(1, real_GB_logit) + mse_loss(
                    0, fake_GB_logit)
                D_ad_cam_loss_GB = mse_loss(1, real_GB_cam_logit) + mse_loss(
                    0, fake_GB_cam_logit)

                D_ad_loss_LB = mse_loss(1, real_LB_logit) + mse_loss(
                    0, fake_LB_logit)
                D_ad_cam_loss_LB = mse_loss(1, real_LB_cam_logit) + mse_loss(
                    0, fake_LB_cam_logit)

                D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA +
                                              D_ad_loss_LA + D_ad_cam_loss_LA)
                D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB +
                                              D_ad_loss_LB + D_ad_cam_loss_LB)

                Discriminator_loss = D_loss_A + D_loss_B
                Discriminator_loss.backward()
                self.D_opt.minimize(Discriminator_loss)
                self.disGA.clear_gradients(), self.disGB.clear_gradients()
                self.disLA.clear_gradients(), self.disLB.clear_gradients()

                # Update G

                fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
                fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)
                print("fake_A2B.shape:", fake_A2B.shape)
                fake_A2B2A, _, _ = self.genB2A(fake_A2B)
                fake_B2A2B, _, _ = self.genA2B(fake_B2A)

                fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
                fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

                fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
                fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
                fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
                fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

                G_ad_loss_GA = mse_loss(1, fake_GA_logit)
                G_ad_cam_loss_GA = mse_loss(1, fake_GA_cam_logit)

                G_ad_loss_LA = mse_loss(1, fake_LA_logit)
                G_ad_cam_loss_LA = mse_loss(1, fake_LA_cam_logit)

                G_ad_loss_GB = mse_loss(1, fake_GB_logit)
                G_ad_cam_loss_GB = mse_loss(1, fake_GB_cam_logit)

                G_ad_loss_LB = mse_loss(1, fake_LB_logit)
                G_ad_cam_loss_LB = mse_loss(1, fake_LB_cam_logit)

                G_recon_loss_A = self.L1loss(fake_A2B2A, real_A)
                G_recon_loss_B = self.L1loss(fake_B2A2B, real_B)

                G_identity_loss_A = self.L1loss(fake_A2A, real_A)
                G_identity_loss_B = self.L1loss(fake_B2B, real_B)

                G_cam_loss_A = bce_loss(1, fake_B2A_cam_logit) + bce_loss(
                    0, fake_A2A_cam_logit)
                G_cam_loss_B = bce_loss(1, fake_A2B_cam_logit) + bce_loss(
                    0, fake_B2B_cam_logit)

                G_loss_A = self.adv_weight * (
                    G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA +
                    G_ad_cam_loss_LA
                ) + self.cycle_weight * G_recon_loss_A + self.identity_weight * G_identity_loss_A + self.cam_weight * G_cam_loss_A
                G_loss_B = self.adv_weight * (
                    G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB +
                    G_ad_cam_loss_LB
                ) + self.cycle_weight * G_recon_loss_B + self.identity_weight * G_identity_loss_B + self.cam_weight * G_cam_loss_B

                Generator_loss = G_loss_A + G_loss_B
                Generator_loss.backward()
                self.G_opt.minimize(Generator_loss)
                self.genA2B.clear_gradients(), self.genB2A.clear_gradients()

                print("[%5d/%5d] time: %4.4f d_loss: %.5f, g_loss: %.5f" %
                      (epoch, block_id, time.time() - start_time,
                       Discriminator_loss.numpy(), Generator_loss.numpy()))
                print("G_loss_A: %.5f G_loss_B: %.5f" %
                      (G_loss_A.numpy(), G_loss_B.numpy()))
                print("G_ad_loss_GA: %.5f   G_ad_loss_GB: %.5f" %
                      (G_ad_loss_GA.numpy(), G_ad_loss_GB.numpy()))
                print("G_ad_loss_LA: %.5f   G_ad_loss_LB: %.5f" %
                      (G_ad_loss_LA.numpy(), G_ad_loss_LB.numpy()))
                print("G_cam_loss_A:%.5f  G_cam_loss_B:%.5f" %
                      (G_cam_loss_A.numpy(), G_cam_loss_B.numpy()))
                print("G_recon_loss_A:%.5f  G_recon_loss_B:%.5f" %
                      (G_recon_loss_A.numpy(), G_recon_loss_B.numpy()))
                print("G_identity_loss_A:%.5f  G_identity_loss_B:%.5f" %
                      (G_identity_loss_A.numpy(), G_identity_loss_B.numpy()))

                if epoch % 2 == 1 and block_id % self.print_freq == 0:

                    A2B = np.zeros((self.img_size * 7, 0, 3))
                    # B2A = np.zeros((self.img_size * 7, 0, 3))
                    for eval_id, eval_data in enumerate(self.test_reader()):
                        if eval_id == 10:
                            break
                        real_A = np.array([x[0] for x in eval_data],
                                          np.float32)
                        real_B = np.array([x[1] for x in eval_data],
                                          np.float32)
                        real_A = totensor(real_A, eval_id, 'eval')
                        real_B = totensor(real_B, eval_id, 'eval')

                        fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                        fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                        fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(
                            fake_A2B)
                        fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(
                            fake_B2A)

                        fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                        fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                        a = tensor2numpy(denorm(real_A[0]))
                        b = cam(tensor2numpy(fake_A2A_heatmap[0]),
                                self.img_size)
                        c = tensor2numpy(denorm(fake_A2A[0]))
                        d = cam(tensor2numpy(fake_A2B_heatmap[0]),
                                self.img_size)
                        e = tensor2numpy(denorm(fake_A2B[0]))
                        f = cam(tensor2numpy(fake_A2B2A_heatmap[0]),
                                self.img_size)
                        g = tensor2numpy(denorm(fake_A2B2A[0]))
                        A2B = np.concatenate((A2B, (np.concatenate(
                            (a, b, c, d, e, f, g)) * 255).astype(np.uint8)),
                                             1).astype(np.uint8)
                    A2B = Image.fromarray(A2B)
                    A2B.save('Images/%d_%d.png' % (epoch, block_id))
                    self.genA2B.train(), self.genB2A.train()
                    self.disGA.train(), self.disGB.train()
                    self.disLA.train(), self.disLB.train()
            if epoch % 4 == 0:
                fluid.save_dygraph(self.genA2B.state_dict(),
                                   "Parameters/genA2B%03d" % (epoch))
                fluid.save_dygraph(self.genB2A.state_dict(),
                                   "Parameters/genB2A%03d" % (epoch))
                fluid.save_dygraph(self.disGA.state_dict(),
                                   "Parameters/disGA%03d" % (epoch))
                fluid.save_dygraph(self.disGB.state_dict(),
                                   "Parameters/disGB%03d" % (epoch))
                fluid.save_dygraph(self.disLA.state_dict(),
                                   "Parameters/disLA%03d" % (epoch))
                fluid.save_dygraph(self.disLB.state_dict(),
                                   "Parameters/disLB%03d" % (epoch))
Example #7
    def trainstep(real_human, real_anime, big_anime):

        with tf.GradientTape(persistent=True) as tape:

            fake_anime = generator_to_anime(real_human, training=True)
            cycled_human = generator_to_human(fake_anime, training=True)

            fake_human = generator_to_human(real_anime, training=True)
            cycled_anime = generator_to_anime(fake_human, training=True)

            # same_human and same_anime are used for identity loss.
            same_human = generator_to_human(real_human, training=True)
            same_anime = generator_to_anime(real_anime, training=True)

            disc_real_human = discriminator_human(real_human, training=True)
            disc_real_anime = discriminator_anime(real_anime, training=True)

            disc_fake_human = discriminator_human(fake_human, training=True)
            disc_fake_anime = discriminator_anime(fake_anime, training=True)

            # assert()
            # calculate the loss
            gen_anime_loss = mse_loss(disc_fake_anime,
                                      tf.zeros_like(disc_fake_anime))
            gen_human_loss = mse_loss(disc_fake_human,
                                      tf.zeros_like(disc_fake_human))

            total_cycle_loss = cycle_loss(real_human,
                                          cycled_human) + cycle_loss(
                                              real_anime, cycled_anime)
            total_gen_anime_loss = (gen_anime_loss + total_cycle_loss +
                                    identity_loss(real_anime, same_anime) +
                                    mse_loss(real_anime, fake_anime) * 0.1)

            total_gen_human_loss = (gen_human_loss + total_cycle_loss +
                                    identity_loss(real_human, same_human) +
                                    mse_loss(real_anime, fake_anime))
            disc_human_loss = mse_loss(
                disc_real_human, tf.ones_like(disc_real_human)) + mse_loss(
                    disc_fake_human, -1 * tf.ones_like(disc_fake_human))
            disc_anime_loss = mse_loss(
                disc_real_anime, tf.ones_like(disc_real_anime)) + mse_loss(
                    disc_fake_anime, -1 * tf.ones_like(disc_fake_anime))

            fake_anime_upscale = generator_anime_upscale(fake_anime,
                                                         training=True)
            same_anime_upscale = generator_anime_upscale(same_anime,
                                                         training=True)
            disc_fake_upscale = discriminator_anime_upscale(fake_anime_upscale,
                                                            training=True)
            disc_same_upscale = discriminator_anime_upscale(same_anime_upscale,
                                                            training=True)
            disc_real_big = discriminator_anime_upscale(big_anime,
                                                        training=True)

            gen_upscale_loss = (
                mse_loss(disc_fake_upscale, tf.zeros_like(disc_fake_upscale)) +
                mse_loss(disc_same_upscale, tf.zeros_like(disc_same_upscale)) *
                0.1)
            # tf.print("gen_upscale_loss", gen_upscale_loss)

            print("generator_to_anime.count_params()",
                  generator_to_anime.count_params())
            print("discriminator_anime.count_params()",
                  discriminator_anime.count_params())
            print("generator_anime_upscale.count_params()",
                  generator_anime_upscale.count_params())
            print(
                "discriminator_anime_upscale.count_params()",
                discriminator_anime_upscale.count_params(),
            )

            disc_upscale_loss = mse_loss(
                disc_fake_upscale,
                -1 * tf.ones_like(disc_fake_upscale)) + mse_loss(
                    disc_real_big, tf.ones_like(disc_real_big))

            scaled_total_gen_anime_loss = generator_to_anime_optimizer.get_scaled_loss(
                total_gen_anime_loss)
            scaled_total_gen_human_loss = generator_to_human_optimizer.get_scaled_loss(
                total_gen_human_loss)
            scaled_gen_upscale_loss = generator_anime_upscale_optimizer.get_scaled_loss(
                gen_upscale_loss)
            scaled_disc_human_loss = discriminator_human_optimizer.get_scaled_loss(
                disc_human_loss)
            scaled_disc_anime_loss = discriminator_anime_optimizer.get_scaled_loss(
                disc_anime_loss)
            scaled_disc_upscale_loss = discriminator_anime_upscale_optimizer.get_scaled_loss(
                disc_upscale_loss)

        generator_to_anime_gradients = generator_to_anime_optimizer.get_unscaled_gradients(
            tape.gradient(scaled_total_gen_anime_loss,
                          generator_to_anime.trainable_variables))

        generator_to_human_gradients = generator_to_human_optimizer.get_unscaled_gradients(
            tape.gradient(scaled_total_gen_human_loss,
                          generator_to_human.trainable_variables))

        generator_upscale_gradients = generator_anime_upscale_optimizer.get_unscaled_gradients(
            tape.gradient(scaled_gen_upscale_loss,
                          generator_anime_upscale.trainable_variables))

        discriminator_human_gradients = discriminator_human_optimizer.get_unscaled_gradients(
            tape.gradient(scaled_disc_human_loss,
                          discriminator_human.trainable_variables))
        discriminator_anime_gradients = discriminator_anime_optimizer.get_unscaled_gradients(
            tape.gradient(scaled_disc_anime_loss,
                          discriminator_anime.trainable_variables))

        discriminator_upscale_gradients = discriminator_anime_upscale_optimizer.get_unscaled_gradients(
            tape.gradient(
                scaled_disc_upscale_loss,
                discriminator_anime_upscale.trainable_variables,
            ))

        generator_to_anime_optimizer.apply_gradients(
            zip(generator_to_anime_gradients,
                generator_to_anime.trainable_variables))

        generator_to_human_optimizer.apply_gradients(
            zip(generator_to_human_gradients,
                generator_to_human.trainable_variables))

        generator_anime_upscale_optimizer.apply_gradients(
            zip(generator_upscale_gradients,
                generator_anime_upscale.trainable_variables))

        discriminator_human_optimizer.apply_gradients(
            zip(discriminator_human_gradients,
                discriminator_human.trainable_variables))

        discriminator_anime_optimizer.apply_gradients(
            zip(discriminator_anime_gradients,
                discriminator_anime.trainable_variables))

        discriminator_anime_upscale_optimizer.apply_gradients(
            zip(
                discriminator_upscale_gradients,
                discriminator_anime_upscale.trainable_variables,
            ))

        return [
            real_human,
            real_anime,
            fake_anime,
            cycled_human,
            fake_human,
            cycled_anime,
            same_human,
            same_anime,
            fake_anime_upscale,
            same_anime_upscale,
            # real_anime_upscale,
            gen_anime_loss,
            gen_human_loss,
            disc_human_loss,
            disc_anime_loss,
            total_gen_anime_loss,
            total_gen_human_loss,
            gen_upscale_loss,
            disc_upscale_loss,
        ]
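The TensorFlow loss helpers assumed in this example (and in Example #5) are not shown. A minimal sketch under the usual CycleGAN conventions, with mean squared error for the adversarial terms and L1 distances for cycle and identity consistency (the weights are assumptions):

import tensorflow as tf

def mse_loss(a, b):
    # Mean squared error, used for the LSGAN generator/discriminator terms.
    return tf.reduce_mean(tf.math.squared_difference(a, b))

def cycle_loss(real, cycled, weight=10.0):
    # L1 cycle-consistency loss; the weight is an assumption.
    return weight * tf.reduce_mean(tf.abs(real - cycled))

def identity_loss(real, same, weight=5.0):
    # L1 identity-mapping loss; the weight is an assumption.
    return weight * tf.reduce_mean(tf.abs(real - same))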