Example #1
0
 def draw_sequentially(self, mode, m_noise, m_image):
     """Run the already-trained generators in ``self.Gs`` from the first scale upwards.

     Starting from ``self.first_img_input``, each generator ``G`` receives the
     previous output cropped to the current real image's size and padded with
     ``m_image``, plus ``noise_amp`` times a noise map. Its output is resized by
     ``1 / self.config.scale_factor`` and cropped to the next real image's size.

     :param mode: ``'rec'`` uses the stored, already padded noise in ``self.Zs``;
         ``'rand'`` draws fresh noise at every scale and pads it with ``m_noise``.
         Any other value performs no generation.
     :param m_noise: padding module applied to fresh noise (only used in ``'rand'``).
     :param m_image: padding module applied to the cropped previous image.
     :return: the last resized generator output, cropped to the next real image's
         size, or ``self.first_img_input`` when no generator ran.
     """
     upscaled_prev = self.first_img_input
     if not self.Gs or mode not in ('rec', 'rand'):
         # Nothing trained yet, or an unrecognised mode: no generation happens.
         return upscaled_prev

     if mode == 'rand':
         # Width of the padding that m_noise adds on each side; subtracted from the
         # stored padded noise size to get the unpadded noise size.
         pad_noise = int(((self.config.kernel_size - 1) * self.config.num_layers) / 2)

     stages = zip(self.Gs, self.Zs, self.reals, self.reals[1:], self.noise_amps)
     for scale_idx, (G, padded_rec_z, cur_real, next_real, noise_amp) in enumerate(stages):
         if mode == 'rec':
             # Reconstruction: reuse the stored noise map for this scale.
             padded_z = padded_rec_z
         else:
             noise_h = padded_rec_z.shape[2] - 2 * pad_noise
             noise_w = padded_rec_z.shape[3] - 2 * pad_noise
             if scale_idx == 0:
                 # First scale: 1-channel noise expanded to 3 channels.
                 random_noise = generate_noise([1, noise_h, noise_w], device=self.config.device)
                 random_noise = random_noise.expand(1, 3, random_noise.shape[2], random_noise.shape[3])
             else:
                 random_noise = generate_noise([self.config.img_channel, noise_h, noise_w],
                                               device=self.config.device)
             padded_z = m_noise(random_noise)

         upscaled_prev = upscaled_prev[:, :, 0:cur_real.shape[2], 0:cur_real.shape[3]]
         padded_img = m_image(upscaled_prev)
         padded_img_with_z = noise_amp * padded_z + padded_img
         generated_img = G(padded_img_with_z.detach(), padded_img)
         up_scaled_img = resize_img(generated_img, 1 / self.config.scale_factor, self.config)
         upscaled_prev = up_scaled_img[:, :, 0:next_real.shape[2], 0:next_real.shape[3]]
     return upscaled_prev
Example #2
0
def _spec_image_to_wav(spec, wav_path):
    """Convert a spectrogram image array to audio and write it to ``wav_path``.

    Uses the module-level conversion settings (``threshold``, ``sample_rate``,
    ``n_fft``, ``n_mels``, ``shrink_size``, ``power``, ``griffin_lim_iter``,
    ``win_length``, ``hop_length``, ``pre_emphasis_rate``).
    """
    stft = mel_to_stft(spectrogram_img_to_mel(spec, threshold), sample_rate,
                       n_fft, n_mels, shrink_size, power)
    wave = griffin_lim(stft, griffin_lim_iter, n_fft, win_length, hop_length,
                       pre_emphasis_rate)
    # NOTE(review): librosa.output.write_wav was removed in librosa 0.8; this
    # requires an older pinned librosa - verify the environment.
    librosa.output.write_wav(wav_path, wave, sample_rate, norm=True)


def generate_all(cnt):
    """Process the ``cnt``-th entry of ``source_wav_list``.

    1. Convert the source audio to a spectrogram PNG, load it as the generator
       input (``get_image`` + ``transform_image``), and convert that PNG back
       to audio.
    2. If ``have_target`` is set, do the same audio->PNG->audio round trip for
       the file of the same name in ``target_wav_dir``.
    3. For each of ``noise_per_image`` noise samples, generate a spectrogram
       with ``netG``, save it as PNG, rescale its width by ``1 / ratio`` and
       save the resulting audio.

    :param cnt: index into ``source_wav_list``.
    """
    wav_name = source_wav_list[cnt]
    stem = os.path.splitext(wav_name)[0]
    spec_name = stem + '.png'

    # Source: audio -> spectrogram image, then load it as generator input.
    src_spec_path = os.path.join(source_spec_save_dir, spec_name)
    spec_from_path_to_path(os.path.join(source_wav_dir, wav_name), src_spec_path)
    spec_src, ratio = get_image(src_spec_path, sz, resize_input)
    spec_src = transform_image(spec_src, sz, ic, resize_input)

    # Source: spectrogram image -> audio.
    _spec_image_to_wav(cv2.imread(src_spec_path),
                       os.path.join(source_wav_save_dir, wav_name))

    if have_target:
        # Target: audio -> spectrogram image -> audio (same file name as the source).
        tgt_spec_path = os.path.join(target_spec_save_dir, spec_name)
        spec_from_path_to_path(os.path.join(target_wav_dir, wav_name), tgt_spec_path)
        _spec_image_to_wav(cv2.imread(tgt_spec_path),
                           os.path.join(target_wav_save_dir, wav_name))

    for i in range(noise_per_image):
        out_stem = f'{stem}-{i}'
        out_spec_path = os.path.join(out_spec_save_dir, out_stem + '.png')

        # Generate a spectrogram for this noise sample and save it.
        noise = generate_noise(1, nz, device)
        out = generate(netG, spec_src, noise, oc, device)
        cv2.imwrite(out_spec_path, out)

        # Reload it and undo the width scaling applied when loading the input.
        spec = cv2.imread(out_spec_path)
        spec = cv2.resize(spec, (0, 0), fx=1 / ratio, fy=1)
        _spec_image_to_wav(spec, os.path.join(out_wav_save_dir, out_stem + '.wav'))
    def __init__(self,
                 loss_type,
                 netD,
                 netG,
                 device,
                 train_dl,
                 val_dl,
                 lr_D=0.0002,
                 lr_G=0.0002,
                 rec_weight=10,
                 ds_weight=8,
                 use_rec_feature=False,
                 resample=True,
                 weight_clip=None,
                 use_gradient_penalty=False,
                 loss_interval=50,
                 image_interval=50,
                 save_img_dir='saved_images/'):
        """Set up a trainer for one discriminator / one generator pair.

        Stores the networks, loaders and options, builds the GAN loss selected
        by ``loss_type`` plus a ``DSGAN_Loss``, creates one Adam optimizer per
        network (``betas=(0.5, 0.999)``), draws
        ``fixed_noise = generate_noise(3, netG.nz, device)`` and makes sure
        ``save_img_dir`` exists.

        :param loss_type: loss identifier passed to ``get_require_type`` and ``get_gan_loss``.
        :param netD: discriminator network.
        :param netG: generator network; must expose an ``nz`` attribute.
        :param device: device handed to the noise and loss helpers.
        :param train_dl: training loader; ``len(train_dl)`` is stored as
            ``train_iteration_per_epoch``.
        :param val_dl: validation loader.
        :param lr_D: discriminator learning rate.
        :param lr_G: generator learning rate.
        :param rec_weight: stored for the training loop (not used here).
        :param ds_weight: stored for the training loop (not used here).
        :param use_rec_feature: stored for the training loop (not used here).
        :param resample: stored for the training loop (not used here).
        :param weight_clip: stored for the training loop (not used here).
        :param use_gradient_penalty: stored for the training loop (not used here).
        :param loss_interval: presumably the loss-logging period in iterations.
        :param image_interval: presumably the sample-image period in iterations.
        :param save_img_dir: directory for sample images; created if missing.
        """
        self.netD = netD
        self.netG = netG
        self.train_dl = train_dl
        self.val_dl = val_dl
        self.lr_D = lr_D
        self.lr_G = lr_G
        self.train_iteration_per_epoch = len(self.train_dl)
        self.device = device
        self.resample = resample
        self.weight_clip = weight_clip
        self.use_gradient_penalty = use_gradient_penalty
        self.rec_weight = rec_weight
        self.use_rec_feature = use_rec_feature
        self.ds_weight = ds_weight

        self.nz = self.netG.nz
        self.fixed_noise = generate_noise(3, self.nz, self.device)

        self.loss_type = loss_type
        self.require_type = get_require_type(self.loss_type)
        self.loss = get_gan_loss(self.device, self.loss_type)
        self.ds_loss = DSGAN_Loss(self.device, self.nz)

        self.optimizerD = optim.Adam(self.netD.parameters(),
                                     lr=self.lr_D,
                                     betas=(0.5, 0.999))
        self.optimizerG = optim.Adam(self.netG.parameters(),
                                     lr=self.lr_G,
                                     betas=(0.5, 0.999))

        self.loss_interval = loss_interval
        self.image_interval = image_interval

        self.save_cnt = 0
        self.save_img_dir = save_img_dir
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.save_img_dir, exist_ok=True)
Example #4
0
    def _get_reconstruction_noise_tensor(self, index_scale):
        """
        Build the noise tensor used by the reconstruction loss at ``index_scale``.

        Only the lowest scale (``index_scale == 0``) gets random noise, drawn with
        ``generate_noise(..., std=1)`` for shape ``(1, *self.scaled_images[0].shape)``.
        Every other scale gets ``tf.zeros`` of shape
        ``(1, *self.scaled_images[index_scale].shape)``. This is the noise input
        of the reconstruction term of the training loss (see the SinGAN paper).
        :param index_scale: index into ``self.scaled_images`` of the scale in question.
        :return: the reconstruction noise tensor for that scale.
        """
        if index_scale != 0:
            zero_shape = (1, *self.scaled_images[index_scale].shape)
            return tf.zeros(zero_shape)

        lowest_shape = (1, *self.scaled_images[0].shape)
        return generate_noise(shape=lowest_shape, std=1)
    def generate_image(self, img_from_prev_scale=None):
        """
        Sample one image from this scale's generator.

        Fresh noise of shape ``(1, *self.img_shape)`` is drawn with
        ``generate_noise`` using ``self.noise_amplitude``. At the lowest scale
        (``self.index_scale == 0``) the generator gets only the noise; otherwise
        it gets ``[noise, img_from_prev_scale]``.
        :param img_from_prev_scale: input image upsampled from the previous scale;
            ignored (may be None) at the lowest scale.
        :return: the output of ``self.generator`` for those inputs.
        """
        sampled_noise = generate_noise((1, *self.img_shape), self.noise_amplitude)

        if self.index_scale != 0:
            return self.generator([sampled_noise, img_from_prev_scale])
        return self.generator(sampled_noise)
Example #6
0
    def train(self, num_epoch):
        """Train both generators and all four discriminators for ``num_epoch`` epochs.

        Each iteration:

        1. Update ``netD_A``/``netD_A_2`` on real ``a`` vs ``fake_a = netG_B2A(b, noise)``
           and ``netD_B``/``netD_B_2`` on real ``b`` vs ``fake_b = netG_A2B(a, noise)``.
        2. If ``weight_clip`` is set, clamp every discriminator weight to
           ``[-weight_clip, weight_clip]``.
        3. Update both generators with the adversarial loss, L1 cycle and identity
           losses (weighted by ``cycle_weight``/``identity_weight``) and, when
           ``ds_weight != 0``, ``DSGAN_Loss``.

        When ``self.resample`` is true, the generator step uses the next batch
        from a second iterator over ``self.train_dl`` instead of reusing the
        discriminator batch. Losses are printed every ``loss_interval``
        iterations; a sample grid is written to ``save_img_dir`` every
        ``image_interval`` iterations.

        :param num_epoch: number of passes over ``self.train_dl``.
        :raises ValueError: if ``self.require_type`` is not 0, 1 or 2.
        """
        l1 = nn.L1Loss()
        for epoch in range(num_epoch):
            if self.resample:
                train_dl_iter = iter(self.train_dl)
            for i, (a, b) in enumerate(tqdm(self.train_dl)):
                a = a.to(self.device)
                b = b.to(self.device)
                bs = a.size(0)
                noise = generate_noise(bs, self.nz, self.device)
                fake_a = self.netG_B2A(b, noise)
                fake_b = self.netG_A2B(a, noise)

                # ----- Discriminator updates -----
                self.optimizerD_A.zero_grad()
                errD_A = self._discriminator_pair_loss(self.netD_A, self.netD_A_2, a, fake_a)
                errD_A.backward()
                self.optimizerD_A.step()

                self.optimizerD_B.zero_grad()
                errD_B = self._discriminator_pair_loss(self.netD_B, self.netD_B_2, b, fake_b)
                errD_B.backward()
                self.optimizerD_B.step()

                if self.weight_clip is not None:
                    for net in (self.netD_A, self.netD_A_2, self.netD_B, self.netD_B_2):
                        for param in net.parameters():
                            param.data.clamp_(-self.weight_clip, self.weight_clip)

                # ----- Generator update -----
                self.optimizerG.zero_grad()
                if self.resample:
                    # Use a fresh batch for the generator step.
                    a, b = next(train_dl_iter)
                    a = a.to(self.device)
                    b = b.to(self.device)
                    bs = a.size(0)
                    noise = generate_noise(bs, self.nz, self.device)
                    fake_a = self.netG_B2A(b, noise)
                    fake_b = self.netG_A2B(a, noise)

                cycle_a = self.netG_B2A(fake_b, noise)
                cycle_b = self.netG_A2B(fake_a, noise)
                identity_a = self.netG_B2A(a, noise)
                identity_b = self.netG_A2B(b, noise)

                errG_a = self._generator_pair_adv_loss(self.netD_A, self.netD_A_2, a, fake_a)
                errG_b = self._generator_pair_adv_loss(self.netD_B, self.netD_B_2, b, fake_b)

                cycle_a_loss = l1(cycle_a, a)
                cycle_b_loss = l1(cycle_b, b)
                identity_a_loss = l1(identity_a, a)
                identity_b_loss = l1(identity_b, b)

                if self.ds_weight == 0:
                    ds_loss = 0
                else:
                    # Two noise draws per direction, compared by DSGAN_Loss.
                    noise1 = generate_noise(bs, self.nz, self.device)
                    noise2 = generate_noise(bs, self.nz, self.device)
                    fake_a1 = self.netG_B2A(b, noise1)
                    fake_a2 = self.netG_B2A(b, noise2)
                    fake_b1 = self.netG_A2B(a, noise1)
                    fake_b2 = self.netG_A2B(a, noise2)
                    ds_loss1 = self.ds_loss.get_loss(fake_a1, fake_a2, noise1, noise2)
                    ds_loss2 = self.ds_loss.get_loss(fake_b1, fake_b2, noise1, noise2)
                    ds_loss = (ds_loss1 + ds_loss2) / 2.0

                errG = (errG_a + errG_b
                        + (cycle_a_loss + cycle_b_loss) * self.cycle_weight
                        + (identity_a_loss + identity_b_loss) * self.identity_weight)
                errG = errG + ds_loss * self.ds_weight
                errG.backward()
                # Update G using the gradients calculated above.
                self.optimizerG.step()

                if i % self.loss_interval == 0:
                    print(
                        '[%d/%d] [%d/%d] errD_A : %.4f, errD_B : %.4f, errG : %.4f'
                        %
                        (epoch + 1, num_epoch, i + 1,
                         self.train_iteration_per_epoch, errD_A, errD_B, errG))

                if i % self.image_interval == 0:
                    if self.nz is None:
                        sample_images_list = get_sample_images_list(
                            (self.val_dl, self.netG_A2B, self.netG_B2A,
                             self.device))
                        plot_image = get_display_samples(
                            sample_images_list, 6, 3)
                    else:
                        sample_images_list = get_sample_images_list_noise(
                            (self.val_dl, self.netG_A2B, self.netG_B2A,
                             self.fixed_noise, self.device))
                        plot_image = get_display_samples(
                            sample_images_list, 9, 6)

                    # NOTE(review): ':' is not allowed in Windows file names, where
                    # cv2.imwrite would fail (it returns False rather than raising).
                    cur_file_name = os.path.join(
                        self.save_img_dir,
                        str(self.save_cnt) + ' : ' + str(epoch) + '-' +
                        str(i) + '.jpg')
                    self.save_cnt += 1
                    cv2.imwrite(cur_file_name, plot_image)

    def _discriminator_pair_loss(self, net_1, net_2, real, fake):
        """Average ``self.loss.d_loss`` over the two discriminators of one domain.

        ``fake`` is detached so this loss yields no gradients for the generators.
        Each network is evaluated on ``real`` before ``fake``.
        """
        fake = fake.detach()
        err_1 = self.loss.d_loss(net_1(real).view(-1), net_1(fake).view(-1))
        err_2 = self.loss.d_loss(net_2(real).view(-1), net_2(fake).view(-1))
        return (err_1 + err_2) / 2.0

    def _generator_pair_adv_loss(self, net_1, net_2, real, fake):
        """Average the generator adversarial loss over two discriminators of one domain.

        ``require_type == 0`` calls ``self.loss.g_loss`` with fake scores only;
        ``require_type`` 1 or 2 also passes the real scores (real evaluated first).

        :raises ValueError: for any other ``require_type``.
        """
        if self.require_type == 0:
            err_1 = self.loss.g_loss(net_1(fake).view(-1))
            err_2 = self.loss.g_loss(net_2(fake).view(-1))
        elif self.require_type in (1, 2):
            err_1 = self.loss.g_loss(net_1(real).view(-1), net_1(fake).view(-1))
            err_2 = self.loss.g_loss(net_2(real).view(-1), net_2(fake).view(-1))
        else:
            raise ValueError(f'Unsupported require_type: {self.require_type!r}')
        return (err_1 + err_2) / 2.0
Example #7
0
    def __init__(self,
                 loss_type,
                 netD_A,
                 netD_B,
                 netD_A_2,
                 netD_B_2,
                 netG_A2B,
                 netG_B2A,
                 device,
                 train_dl,
                 val_dl,
                 lr_D=0.0002,
                 lr_G=0.0002,
                 cycle_weight=10,
                 identity_weight=5.0,
                 ds_weight=8,
                 resample=True,
                 weight_clip=None,
                 use_gradient_penalty=False,
                 loss_interval=50,
                 image_interval=50,
                 save_img_dir='saved_images/'):
        """Set up a two-domain (A <-> B) trainer with two discriminators per domain.

        Stores the networks, loaders and options, and builds the GAN loss selected
        by ``loss_type`` plus a ``DSGAN_Loss``. It creates three Adam optimizers
        (``betas=(0.5, 0.999)``): one over ``netD_A`` + ``netD_A_2``, one over
        ``netD_B`` + ``netD_B_2``, and one over both generators. It also draws
        ``fixed_noise = generate_noise(3, netG_A2B.nz, device)`` and makes sure
        ``save_img_dir`` exists.

        :param loss_type: loss identifier passed to ``get_require_type`` and ``get_gan_loss``.
        :param netD_A: first discriminator for domain A.
        :param netD_B: first discriminator for domain B.
        :param netD_A_2: second discriminator for domain A.
        :param netD_B_2: second discriminator for domain B.
        :param netG_A2B: generator A -> B; must expose an ``nz`` attribute.
        :param netG_B2A: generator B -> A.
        :param device: device handed to the noise and loss helpers.
        :param train_dl: training loader; ``len(train_dl)`` is stored as
            ``train_iteration_per_epoch``.
        :param val_dl: validation loader.
        :param lr_D: discriminator learning rate.
        :param lr_G: generator learning rate.
        :param cycle_weight: stored for the training loop.
        :param identity_weight: stored for the training loop.
        :param ds_weight: stored for the training loop.
        :param resample: stored for the training loop.
        :param weight_clip: stored for the training loop.
        :param use_gradient_penalty: stored; not used by this constructor.
        :param loss_interval: presumably the loss-logging period in iterations.
        :param image_interval: presumably the sample-image period in iterations.
        :param save_img_dir: directory for sample images; created if missing.
        """
        self.netD_A = netD_A
        self.netD_B = netD_B
        self.netD_A_2 = netD_A_2
        self.netD_B_2 = netD_B_2

        self.netG_A2B = netG_A2B
        self.netG_B2A = netG_B2A
        self.train_dl = train_dl
        self.val_dl = val_dl
        self.lr_D = lr_D
        self.lr_G = lr_G
        self.train_iteration_per_epoch = len(self.train_dl)
        self.device = device
        self.resample = resample
        self.weight_clip = weight_clip
        self.use_gradient_penalty = use_gradient_penalty
        self.cycle_weight = cycle_weight
        self.identity_weight = identity_weight
        self.ds_weight = ds_weight

        self.nz = self.netG_A2B.nz
        self.fixed_noise = generate_noise(3, self.nz, self.device)

        self.loss_type = loss_type
        self.require_type = get_require_type(self.loss_type)
        self.loss = get_gan_loss(self.device, self.loss_type)
        self.ds_loss = DSGAN_Loss(self.device, self.nz)

        self.optimizerD_A = optim.Adam(chain(self.netD_A.parameters(),
                                             self.netD_A_2.parameters()),
                                       lr=self.lr_D,
                                       betas=(0.5, 0.999))
        self.optimizerD_B = optim.Adam(chain(self.netD_B.parameters(),
                                             self.netD_B_2.parameters()),
                                       lr=self.lr_D,
                                       betas=(0.5, 0.999))
        self.optimizerG = optim.Adam(chain(self.netG_A2B.parameters(),
                                           self.netG_B2A.parameters()),
                                     lr=self.lr_G,
                                     betas=(0.5, 0.999))

        self.loss_interval = loss_interval
        self.image_interval = image_interval

        self.save_cnt = 0
        self.save_img_dir = save_img_dir
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.save_img_dir, exist_ok=True)
    def train(self, real_image, epochs, singan_monitor):
        """Train the critic and generator of the current scale.

        Computes the noise amplitude from ``real_image`` and the reconstruction of
        the previous scale (None at scale 0). It then creates fresh Adam
        optimizers with ``PiecewiseConstantDecay`` over
        ``cfg.TRAIN.LR_SCHEDULER_STEPS`` and runs ``epochs`` epochs. Each epoch
        has ``self.critic_steps`` critic steps followed by ``self.gen_steps``
        generator steps, all sharing one noise tensor. ``singan_monitor`` saves
        images after every epoch and once more at the end.

        :param real_image: the real image of this scale.
        :param epochs: number of epochs.
        :param singan_monitor: object providing ``save_imgs_on_epoch_end``.
        :return: ``(list_c_wass_loss, list_c_gp_loss, list_g_adv_loss,
            list_g_rec_loss, list_lr)``. Each loss list has one entry per
            critic/generator step; ``list_lr`` has one entry per epoch.
        """
        list_c_wass_loss = []
        list_c_gp_loss = []
        list_g_adv_loss = []
        list_g_rec_loss = []
        list_lr = []

        if self.index_scale == 0:
            rec_from_prev_scale = None
        else:
            rec_from_prev_scale = self._get_upsampled_reconstructed_img_from_prev_scale(
            )

        self._calculate_noise_amplitude(real_image, rec_from_prev_scale)
        print(
            f"Noise amplitude at scale {self.index_scale}: {self.noise_amplitude}"
        )

        # Set up optimizers with learning rate decay
        self.c_optimizer = Adam(
            learning_rate=tf.keras.optimizers.schedules.PiecewiseConstantDecay(
                cfg.TRAIN.LR_SCHEDULER_STEPS, [5e-4, 5e-5, 5e-6]),
            beta_1=0.5,
            beta_2=0.999)
        self.g_optimizer = Adam(
            learning_rate=tf.keras.optimizers.schedules.PiecewiseConstantDecay(
                cfg.TRAIN.LR_SCHEDULER_STEPS, [5e-4, 5e-5, 5e-6]),
            beta_1=0.5,
            beta_2=0.999)
        t = trange(epochs, desc="Epoch ")
        for i in t:
            # Get a single noise tensor to use for all the training steps of the current epoch
            noise = generate_noise((1, *self.img_shape), self.noise_amplitude)

            # ----- Train the critic -----
            img_from_prev_scale = None
            for _ in range(self.critic_steps):
                # A new upsampled image from the previous scale for every critic step
                img_from_prev_scale = self._get_upsampled_img_from_prev_scale()
                img_from_prev_scale = img_from_prev_scale[np.newaxis, :, :, :]
                c_wass_loss, c_gp_loss = self._train_critic_step(
                    real_image, noise, img_from_prev_scale)

                list_c_wass_loss.append(c_wass_loss)
                list_c_gp_loss.append(c_gp_loss)

            if img_from_prev_scale is None:
                # No critic step ran (critic_steps == 0), but the generator
                # still needs an input image from the previous scale.
                img_from_prev_scale = self._get_upsampled_img_from_prev_scale()
                img_from_prev_scale = img_from_prev_scale[np.newaxis, :, :, :]

            # ----- Train the generator (reuses the last critic input image) -----
            for _ in range(self.gen_steps):
                g_adv_loss, g_rec_loss = self._train_generator_step(
                    real_image, noise, img_from_prev_scale,
                    rec_from_prev_scale)

                list_g_adv_loss.append(g_adv_loss)
                list_g_rec_loss.append(g_rec_loss)

            # NOTE(review): the schedule is evaluated at the epoch index ``i``; if the
            # optimizer steps more than once per epoch this may differ from the LR
            # actually applied - confirm intent.
            list_lr.append(self.g_optimizer.lr(i).numpy())

            # Print the latest losses (None if no step of that kind ran)
            last_g_adv = list_g_adv_loss[-1] if list_g_adv_loss else None
            last_c_wass = list_c_wass_loss[-1] if list_c_wass_loss else None
            t.set_postfix_str(
                f"Gen. Loss: {last_g_adv} - Critic Loss: {last_c_wass}"
            )
            t.refresh()

            singan_monitor.save_imgs_on_epoch_end(index_scale=self.index_scale,
                                                  epoch=i)

        # Plot images on training end
        singan_monitor.save_imgs_on_epoch_end(index_scale=self.index_scale,
                                              epoch=epochs)

        return list_c_wass_loss, list_c_gp_loss, list_g_adv_loss, list_g_rec_loss, list_lr
# Spectrogram / Griffin-Lim settings for this interactive preview.
shrink_size = 3.5
threshold = 5
griffin_lim_iter = 100

sz, ic, oc, use_bn, norm_type = 256, 1, 1, True, 'instancenorm'
netG = UNet_G(ic, oc, sz, nz, use_bn, norm_type).to(device)
# Alternative generator:
# netG = ResNet_G(ic, oc, sz, nz = nz, norm_type = norm_type).to(device)
# NOTE(review): torch.load unpickles the checkpoint; only load trusted model files.
netG.load_state_dict(torch.load(model_path, map_location='cpu'))
netG.eval()

cnt, total_num = 0, 10
y = read_audio(os.path.join(input_wav_dir, input_wav_list[cnt]), sample_rate, pre_emphasis_rate)
mel = get_mel(get_stft(y, n_fft, win_length, hop_length), sample_rate, n_fft, n_mels, power, shrink_size)
spec = cv2.resize(cv2.cvtColor(mel_to_spectrogram(mel, threshold, None), cv2.COLOR_GRAY2RGB), (sz, sz))
spec_t = transform_image(spec, sz, ic)
noise = generate_noise(1, nz, device)
out = generate(netG, spec_t, noise, oc, sz, device)

# Interactive loop: 'r' resamples the noise and regenerates, 'q' quits.
while True:
    cv2.imshow('Input', spec)
    cv2.imshow('Output', out)

    key = cv2.waitKey(1) & 0xFF

    if key == ord('q'):
        break

    if key == ord('r'):
        noise = generate_noise(1, nz, device)
        out = generate(netG, spec_t, noise, oc, sz, device)

cv2.destroyAllWindows()
Example #10
0
    def inference(self, start_img_input):
        """Sample ``self.config.num_samples`` images through every scale of the pyramid.

        At each scale ``idx`` a noise map is drawn and zero-padded. At ``idx == 0``
        it is 1 channel expanded to 3; otherwise it has ``config.img_channel``
        channels. When ``config.use_fixed_noise`` is set and
        ``idx < config.gen_start_scale``, the stored ``Z_opt`` replaces it.
        The previous scale's sample ``i`` is resized by ``1 / config.scale_factor``
        and padded; outside ``train_SR`` mode it is also cropped and resampled to
        the noise size. It is then combined with ``noise_amp * noise`` and fed to ``G``.

        Images go to ``config.infer_dir``. With ``config.save_all_pyramid`` every
        scale is saved as ``{i}_{idx}.png``; otherwise only the last scale is
        saved, as ``{i}.png``. With ``config.save_attention_map`` the
        discriminator's attention maps for each saved image are also written via
        ``save_heatmap``, with ``{infer_dir}/attention`` as the target directory.

        :param start_img_input: input for the first scale, or None to use a
            zero tensor shaped like ``self.reals[0]``.
        :return: the last sample produced at the last scale, detached.
        """
        global global_att_dir
        if self.config.save_attention_map:
            global_att_dir = f'{self.config.infer_dir}/attention'
            os.makedirs(global_att_dir, exist_ok=True)

        if start_img_input is None:
            # Float zeros: torch.full(shape, 0) would infer an integer (int64) dtype.
            start_img_input = torch.zeros(self.reals[0].shape, device=self.config.device)

        # The padding does not depend on the scale, so build it once.
        padding_size = ((self.config.kernel_size - 1) * self.config.num_layers) / 2
        pad = nn.ZeroPad2d(int(padding_size))
        last_idx = len(self.reals) - 1

        cur_images = []

        for idx, (G, D, Z_opt, noise_amp, real) in enumerate(zip(self.Gs, self.Ds, self.Zs, self.noise_amps, self.reals)):
            # NOTE(review): padding_size is a float (true division) and scale_h/scale_w may
            # be non-integers, so these sizes can be floats - confirm generate_noise accepts them.
            output_h = (Z_opt.shape[2] - padding_size * 2) * self.config.scale_h
            output_w = (Z_opt.shape[3] - padding_size * 2) * self.config.scale_w
            is_last_scale = idx == last_idx

            prev_images = cur_images
            cur_images = []

            for i in tqdm(range(self.config.num_samples)):
                if idx == 0:
                    random_z = generate_noise([1, output_h, output_w], device=self.config.device)
                    random_z = random_z.expand(1, 3, random_z.shape[2], random_z.shape[3])
                else:
                    random_z = generate_noise([self.config.img_channel, output_h, output_w], device=self.config.device)
                padded_random_z = pad(random_z)

                if self.config.use_fixed_noise and idx < self.config.gen_start_scale:
                    padded_random_z = Z_opt

                if not prev_images:
                    padded_random_img = pad(start_img_input)
                else:
                    upscaled_prev_random_img = resize_img(prev_images[i], 1 / self.config.scale_factor, self.config)
                    if self.config.mode == "train_SR":
                        padded_random_img = pad(upscaled_prev_random_img)
                    else:
                        upscaled_prev_random_img = upscaled_prev_random_img[:, :,
                                                   0:round(self.config.scale_h * self.reals[idx].shape[2]),
                                                   0:round(self.config.scale_w * self.reals[idx].shape[3])]
                        padded_random_img = pad(upscaled_prev_random_img)
                        padded_random_img = padded_random_img[:, :, 0:padded_random_z.shape[2], 0:padded_random_z.shape[3]]
                        padded_random_img = upsampling(padded_random_img, padded_random_z.shape[2], padded_random_z.shape[3])

                padded_random_img_with_z = noise_amp * padded_random_z + padded_random_img
                cur_image = G(padded_random_img_with_z.detach(), padded_random_img)

                np_cur_image = torch2np(cur_image.detach())
                if self.config.save_all_pyramid:
                    plt.imsave(f'{self.config.infer_dir}/{i}_{idx}.png', np_cur_image, vmin=0, vmax=1)
                elif is_last_scale:
                    plt.imsave(f'{self.config.infer_dir}/{i}.png', np_cur_image, vmin=0, vmax=1)

                if self.config.save_attention_map and (self.config.save_all_pyramid or is_last_scale):
                    self._save_inference_attention_maps(D, cur_image, np_cur_image, real, f'{i}_{idx}thG')

                cur_images.append(cur_image)

        return cur_image.detach()

    def _save_inference_attention_maps(self, D, cur_image, np_cur_image, real, epoch_tag):
        """Save D's two attention maps for one inference sample with ``save_heatmap``.

        Sets the module-level ``global_epoch`` to ``epoch_tag`` before dispatching
        (presumably read by ``save_heatmap`` - verify), then runs ``save_heatmap``
        through ``parmap.map`` with 2 processes. The 'infer_add' map is paired
        with the generated image and the 'infer_sub' map with ``real``.
        """
        global global_epoch
        np_real = torch2np(real)
        _, _, cur_add_att_maps, cur_sub_att_maps = D(cur_image.detach())
        cur_add_att_maps = cur_add_att_maps.detach().to(torch.device('cpu')).numpy().transpose(1, 2, 3, 0)
        cur_sub_att_maps = cur_sub_att_maps.detach().to(torch.device('cpu')).numpy().transpose(1, 2, 3, 0)
        global_epoch = epoch_tag
        parmap.map(save_heatmap, [[np_cur_image, cur_add_att_maps, 'infer_add'], [np_real, cur_sub_att_maps, 'infer_sub']],
                   pm_pbar=False, pm_processes=2)
Example #11
0
    def train_single_stage(self, cur_discriminator, cur_generator):
        real = self.reals[len(self.Gs)]
        _, _, real_h, real_w = real.shape

        # Set padding layer(Initial padding) - To Do (Change this for noise padding not zero-padding)g
        self.config.receptive_field = self.config.kernel_size + ((self.config.kernel_size - 1) * (self.config.num_layers - 1)) * self.config.stride
        padding_size = int(((self.config.kernel_size - 1) * self.config.num_layers) / 2)
        noise_pad = nn.ZeroPad2d(int(padding_size))
        image_pad = nn.ZeroPad2d(int(padding_size))

        # MultiStepLR: lr *= gamma every time reaches one of the milestones
        D_optimizer = optim.Adam(cur_discriminator.parameters(), lr=self.config.d_lr, betas=(self.config.beta1, self.config.beta2))
        G_optimizer = optim.Adam(cur_generator.parameters(), lr=self.config.g_lr, betas=(self.config.beta1, self.config.beta2))
        D_scheduler = optim.lr_scheduler.MultiStepLR(optimizer=D_optimizer, milestones=self.config.milestones, gamma=self.config.gamma)
        G_scheduler = optim.lr_scheduler.MultiStepLR(optimizer=G_optimizer, milestones=self.config.milestones, gamma=self.config.gamma)

        # Calculate noise amp(amount of info to generate) and recover prev_rec image
        if not self.Gs:
            rec_z = generate_noise([1, real_h, real_w], device=self.config.device).expand(1, 3, real_h, real_w)
            self.first_img_input = torch.full([1, self.config.img_channel, real_h, real_w], 0, device=self.config.device)
            upscaled_prev_rec_img = self.first_img_input
            self.config.noise_amp = 1
        else:
            rec_z = torch.full([1, self.config.img_channel, real_h, real_w], 0, device=self.config.device)
            upscaled_prev_rec_img = self.draw_sequentially('rec', noise_pad, image_pad)
            criterion = nn.MSELoss()
            rmse = torch.sqrt(criterion(real, upscaled_prev_rec_img))
            self.config.noise_amp = self.config.noise_amp_init * rmse
        padded_rec_z = noise_pad(rec_z)
        padded_rec_img = image_pad(upscaled_prev_rec_img)

        for epoch in tqdm(range(self.config.num_iter), desc=f'{len(self.Gs)}th GAN'):
            random_z = generate_noise([1, real_h, real_w], device=self.config.device).expand(1, 3, real_h, real_w) \
                if not self.Gs else generate_noise([self.config.img_channel, real_h, real_w], device=self.config.device)
            padded_random_z = noise_pad(random_z)

            # Train Discriminator: Maximize D(x) - D(G(z)) -> Minimize D(G(z)) - D(X)
            for i in range(self.config.n_critic):
                # Make random image input
                upscaled_prev_random_img = self.draw_sequentially('rand', noise_pad, image_pad)
                padded_random_img = image_pad(upscaled_prev_random_img)
                padded_random_img_with_z = self.config.noise_amp * padded_random_z + padded_random_img

                # Calculate loss with real data
                cur_discriminator.zero_grad()
                real_prob_out, real_acm_oth, real_add_att_maps, real_sub_att_maps = cur_discriminator(real)
                d_real_loss = -real_prob_out.mean()                         # Maximize D(X) -> Minimize -D(X)

                # Calculate loss with fake data
                fake = cur_generator(padded_random_img_with_z.detach(), padded_random_img)
                fake_prob_out, fake_acm_oth, _, _ = cur_discriminator(fake.detach())
                d_fake_loss = fake_prob_out.mean()                          # Minimize D(G(z))

                # Gradient penalty
                gradient_penalty = calcul_gp(cur_discriminator, real, fake, self.config.device)

                # Update parameters
                d_loss = d_real_loss + d_fake_loss + (gradient_penalty * self.config.gp_weights)
                if self.config.use_acm_oth:
                    d_loss += (torch.abs(real_acm_oth.mean()) + torch.abs(fake_acm_oth.mean())) * self.config.acm_weights
                d_loss.backward()
                D_optimizer.step()

                # Log losses
                critic = -d_real_loss.item() -d_fake_loss.item()            # D(x) - D(G_z)
                self.log_losses[f'{len(self.Gs)}th_D/d'] = d_loss.item()
                self.log_losses[f'{len(self.Gs)}th_D/d_critic'] = critic
                self.log_losses[f'{len(self.Gs)}th_D/d_gp'] = gradient_penalty.item()
                if self.config.use_acm_oth:
                    self.log_losses[f'{len(self.Gs)}th_D/d_oth'] = real_acm_oth.item() + fake_acm_oth.item()

            # Train Generator : Maximize D(G(z)) -> Minimize -D(G(z))
            for i in range(self.config.generator_iter):
                cur_generator.zero_grad()

                # Make fake sample for every iteration
                upscaled_prev_random_img = self.draw_sequentially('rand', noise_pad, image_pad)
                padded_random_img = image_pad(upscaled_prev_random_img)
                padded_random_img_with_z = self.config.noise_amp * padded_random_z + padded_random_img
                fake = cur_generator(padded_random_img_with_z.detach(), padded_random_img)

                # Adversarial loss
                fake_prob_out, _, fake_add_att_maps, fake_sub_att_maps = cur_discriminator(fake)
                g_adv_loss = -fake_prob_out.mean()

                # Reconstruction loss
                mse_criterion = nn.MSELoss()
                padded_rec_img_with_z = self.config.noise_amp * padded_rec_z + padded_rec_img
                g_rec_loss = mse_criterion(cur_generator(padded_rec_img_with_z.detach(), padded_rec_img), real)

                # Update parameters
                g_loss = g_adv_loss + (g_rec_loss * self.config.rec_weights)
                g_loss.backward()
                G_optimizer.step()

                # Log losses
                self.log_losses[f'{len(self.Gs)}th_G/g'] = g_loss
                self.log_losses[f'{len(self.Gs)}th_G/g_critic'] = -g_adv_loss.item()
                self.log_losses[f'{len(self.Gs)}th_G/g_rec'] = g_rec_loss.item()

            # Log losses
            for key, value in self.log_losses.items():
                self.writer.add_scalar(key, value, epoch)
            self.log_losses = {}

            # Log image
            if epoch % self.config.img_save_iter == 0 or epoch == (self.config.num_iter - 1):
                np_real = torch2np(real)
                np_fake = torch2np(fake.detach())
                plt.imsave(f'{self.config.result_dir}/{epoch}_fake_sample.png', np_fake, vmin=0, vmax=1)
                plt.imsave(f'{self.config.result_dir}/{epoch}_fixed_noise.png', torch2np(padded_rec_img_with_z.detach() * 2 - 1), vmin=0, vmax=1)
                plt.imsave(f'{self.config.result_dir}/{epoch}_reconstruction.png', torch2np(cur_generator(padded_rec_img_with_z.detach(), padded_rec_img).detach()), vmin=0, vmax=1)
                real_add_att_maps = real_add_att_maps.detach().to(torch.device('cpu')).numpy().transpose(1, 2, 3, 0)
                real_sub_att_maps = real_sub_att_maps.detach().to(torch.device('cpu')).numpy().transpose(1, 2, 3, 0)
                fake_add_att_maps = fake_add_att_maps.detach().to(torch.device('cpu')).numpy().transpose(1, 2, 3, 0)
                fake_sub_att_maps = fake_sub_att_maps.detach().to(torch.device('cpu')).numpy().transpose(1, 2, 3, 0)

                global global_att_dir
                global global_epoch
                global_att_dir = self.config.att_dir
                global_epoch = epoch

                parmap.map(save_heatmap, [[np_real, real_add_att_maps, 'real_add'], [np_real, real_sub_att_maps, 'real_sub'],
                                          [np_fake, fake_add_att_maps, 'fake_add'], [np_real, fake_sub_att_maps, 'fake_sub']],
                           pm_pbar=False, pm_processes=4)

            D_scheduler.step()
            G_scheduler.step()

        # Save model weights
        torch.save(cur_generator.state_dict(), f'{self.config.result_dir}/generator.pth')
        torch.save(cur_discriminator.state_dict(), f'{self.config.result_dir}/ACM_discriminator.pth')

        return padded_rec_z, cur_generator, cur_discriminator
    def train(self, num_epoch):
        """Run the paired (conditional) GAN training loop for ``num_epoch`` epochs.

        For every batch ``(x, y)`` of ``self.train_dl`` one discriminator
        update is followed by one generator update. Losses are printed every
        ``self.loss_interval`` batches and a grid of validation samples is
        written to ``self.save_img_dir`` every ``self.image_interval`` batches.

        Loss selection is driven by ``self.require_type``:
            0 -> ``self.loss.d_loss(c_xr, c_xf)`` / ``self.loss.g_loss(c_xf)``
            1 -> ``self.loss.d_loss(c_xr, c_xf)`` / ``self.loss.g_loss(c_xr, c_xf)``
            2 -> ``self.loss.d_loss(c_xr, c_xf, y, fake_y)`` /
                 ``self.loss.g_loss(c_xr, c_xf)``

        Args:
            num_epoch: number of passes over ``self.train_dl``.

        Raises:
            ValueError: if ``self.require_type`` is not 0, 1 or 2.
        """
        if self.require_type not in (0, 1, 2):
            raise ValueError(
                f'require_type must be 0, 1 or 2, got {self.require_type!r}')

        l1 = nn.L1Loss()

        for epoch in range(num_epoch):
            if self.resample:
                # Independent iterator so the generator step can draw a
                # different batch than the discriminator step.
                train_dl_iter = iter(self.train_dl)

            for i, (x, y) in enumerate(tqdm(self.train_dl)):
                x = x.to(self.device)
                y = y.to(self.device)
                bs = x.size(0)
                noise = generate_noise(bs, self.nz, self.device)
                fake_y = self.netG(x, noise)

                # ---------------- Discriminator step ----------------
                self.netD.zero_grad()

                c_xr = self.netD(x, y).view(-1)
                c_xf = self.netD(x, fake_y.detach()).view(-1)

                if self.require_type == 2:
                    errD = self.loss.d_loss(c_xr, c_xf, y, fake_y)
                else:
                    errD = self.loss.d_loss(c_xr, c_xf)

                if self.use_gradient_penalty:
                    # use_gradient_penalty doubles as the penalty weight.
                    errD = errD + self.use_gradient_penalty * self.gradient_penalty(
                        x, y, fake_y)

                errD.backward()
                self.optimizerD.step()

                if self.weight_clip is not None:
                    # Clip discriminator weights to [-weight_clip, weight_clip].
                    for param in self.netD.parameters():
                        param.data.clamp_(-self.weight_clip, self.weight_clip)

                # ---------------- Generator step ----------------
                self.netG.zero_grad()
                if self.resample:
                    x, y = next(train_dl_iter)
                    x = x.to(self.device)
                    y = y.to(self.device)
                    bs = x.size(0)
                    noise = generate_noise(bs, self.nz, self.device)
                    fake_y = self.netG(x, noise)

                # netD(..., True) also returns intermediate features (f1 for
                # the fake pair, f2 for the real pair) used by the
                # feature-matching reconstruction loss below.
                f2 = None
                if self.require_type == 0:
                    c_xf, f1 = self.netD(x, fake_y, True)  # (bs, 1, 1, 1)
                    c_xf = c_xf.view(-1)  # (bs)
                    errG_1 = self.loss.g_loss(c_xf)
                else:
                    c_xr, f2 = self.netD(x, y, True)  # (bs, 1, 1, 1)
                    c_xr = c_xr.view(-1)  # (bs)
                    c_xf, f1 = self.netD(x, fake_y, True)  # (bs, 1, 1, 1)
                    c_xf = c_xf.view(-1)  # (bs)
                    errG_1 = self.loss.g_loss(c_xr, c_xf)

                if self.ds_weight == 0:
                    ds_loss = 0
                else:
                    # Compare two samples generated for the same x with
                    # different noise via self.ds_loss.
                    noise1 = generate_noise(bs, self.nz, self.device)
                    noise2 = generate_noise(bs, self.nz, self.device)
                    fake_y1 = self.netG(x, noise1)
                    fake_y2 = self.netG(x, noise2)
                    ds_loss = self.ds_loss.get_loss(fake_y1, fake_y2, noise1,
                                                    noise2)

                if self.rec_weight == 0:
                    rec_loss = 0
                elif self.use_rec_feature:
                    # Mean absolute difference between discriminator features
                    # of (x, fake_y) and (x, y), averaged over feature levels.
                    # Real-pair features are only computed here if the
                    # adversarial branch above did not already produce them.
                    if f2 is None:
                        _, f2 = self.netD(x, y, True)
                    rec_loss = sum((f1_ - f2_).abs().mean()
                                   for f1_, f2_ in zip(f1, f2)) / len(f1)
                else:
                    rec_loss = l1(fake_y, y)

                errG = errG_1 + rec_loss * self.rec_weight + ds_loss * self.ds_weight
                errG.backward()
                # Update G using the gradients calculated above.
                self.optimizerG.step()

                if i % self.loss_interval == 0:
                    print('[%d/%d] [%d/%d] errD : %.4f, errG : %.4f' %
                          (epoch + 1, num_epoch, i + 1,
                           self.train_iteration_per_epoch, float(errD),
                           float(errG)))

                if i % self.image_interval == 0:
                    if self.nz is None:
                        sample_images_list = get_sample_images_list(
                            (self.val_dl, self.netG, self.device))
                        plot_image = get_display_samples(
                            sample_images_list, 3, 3)
                    else:
                        sample_images_list = get_sample_images_list_noise(
                            (self.val_dl, self.netG, self.fixed_noise,
                             self.device))
                        plot_image = get_display_samples(
                            sample_images_list, 9, 3)

                    cur_file_name = os.path.join(
                        self.save_img_dir,
                        str(self.save_cnt) + ' : ' + str(epoch) + '-' +
                        str(i) + '.jpg')
                    self.save_cnt += 1
                    cv2.imwrite(cur_file_name, plot_image)
Example #13
0
    def inference(self, start_img_input):
        """Generate ``config.num_samples`` images through the generator pyramid.

        The coarsest generator starts from ``start_img_input``. Every later
        scale refines the resized output of the previous scale, independently
        for each sample. Images are saved to ``config.infer_dir``: every scale
        when ``config.save_all_pyramid`` is set, otherwise only the scale with
        index ``len(self.reals) - 1``.

        Args:
            start_img_input: image tensor fed to the coarsest generator, or
                ``None`` to use a zero image shaped like ``self.reals[0]``.

        Returns:
            torch.Tensor: the last sample generated at the last scale, detached.

        Raises:
            ValueError: if there are no generators or ``config.num_samples``
                is less than 1, so no image was produced.
        """
        if start_img_input is None:
            # torch.full(shape, 0) would infer an int64 dtype from the integer
            # fill value; the generator expects a floating-point image.
            start_img_input = torch.zeros(self.reals[0].shape,
                                          device=self.config.device)

        # The padding is the same at every scale. Keep it integral so the
        # unpadded noise size below matches what `pad` actually adds.
        padding_size = int(
            ((self.config.kernel_size - 1) * self.config.num_layers) / 2)
        pad = nn.ZeroPad2d(padding_size)

        cur_images = []
        for idx, (G, Z_opt, noise_amp) in tqdm(
                enumerate(zip(self.Gs, self.Zs, self.noise_amps))):
            # Unpadded noise size for this scale, scaled by scale_h / scale_w.
            output_h = (Z_opt.shape[2] -
                        padding_size * 2) * self.config.scale_h
            output_w = (Z_opt.shape[3] -
                        padding_size * 2) * self.config.scale_w

            prev_images = cur_images
            cur_images = []

            for i in range(self.config.num_samples):
                if idx == 0:
                    # Coarsest scale: single-channel noise expanded to 3 channels.
                    random_z = generate_noise([1, output_h, output_w],
                                              device=self.config.device)
                    random_z = random_z.expand(1, 3, random_z.shape[2],
                                               random_z.shape[3])
                else:
                    random_z = generate_noise(
                        [self.config.img_channel, output_h, output_w],
                        device=self.config.device)
                padded_random_z = pad(random_z)

                # Optionally replay the stored noise for the first scales.
                if self.config.use_fixed_noise and idx < self.config.gen_start_scale:
                    padded_random_z = Z_opt

                if not prev_images:
                    padded_random_img = pad(start_img_input)
                else:
                    prev_img = prev_images[i]
                    upscaled_prev_random_img = resize_img(
                        prev_img, 1 / self.config.scale_factor, self.config)
                    if self.config.mode == "train_SR":
                        padded_random_img = pad(upscaled_prev_random_img)
                    else:
                        # Crop to this scale's (scaled) size, pad, crop to the
                        # noise map's size, then resize to exactly match it.
                        upscaled_prev_random_img = upscaled_prev_random_img[:, :, 0:round(
                            self.config.scale_h * self.reals[idx].shape[2]
                        ), 0:round(self.config.scale_w *
                                   self.reals[idx].shape[3])]
                        padded_random_img = pad(upscaled_prev_random_img)
                        padded_random_img = padded_random_img[
                            :, :, 0:padded_random_z.shape[2],
                            0:padded_random_z.shape[3]]
                        padded_random_img = upsampling(
                            padded_random_img, padded_random_z.shape[2],
                            padded_random_z.shape[3])

                padded_random_img_with_z = noise_amp * padded_random_z + padded_random_img
                cur_image = G(padded_random_img_with_z.detach(),
                              padded_random_img)

                if self.config.save_all_pyramid:
                    plt.imsave(f'{self.config.infer_dir}/{i}_{idx}.png',
                               torch2np(cur_image.detach()),
                               vmin=0,
                               vmax=1)
                elif idx == len(self.reals) - 1:
                    plt.imsave(f'{self.config.infer_dir}/{i}.png',
                               torch2np(cur_image.detach()),
                               vmin=0,
                               vmax=1)

                cur_images.append(cur_image)

        if not cur_images:
            raise ValueError(
                'inference produced no image: need at least one generator '
                'and config.num_samples >= 1')
        return cur_images[-1].detach()
Example #14
0
    def train_single_stage(self, cur_discriminator, cur_generator):
        """Train the generator/discriminator pair of the current scale.

        The current scale is ``len(self.Gs)`` and its target image is
        ``self.reals[len(self.Gs)]``. Each epoch performs these updates:

        * ``config.n_critic`` discriminator updates using
          ``-D(real) + D(fake)`` plus ``calcul_gp(...) * config.gp_weights``.
        * ``config.generator_iter`` generator updates using
          ``-D(fake)`` plus ``config.rec_weights`` times the MSE between the
          reconstruction and the real image.

        Losses go to ``self.writer``. Sample images and the final weights go
        to ``config.result_dir``.

        Side effects:

        * Sets ``config.receptive_field``.
        * Sets ``config.noise_amp``: 1 at the first scale, otherwise
          ``config.noise_amp_init`` times the RMSE of the previous scales'
          reconstruction.
        * At the first scale, also sets ``self.first_img_input``.

        Args:
            cur_discriminator: discriminator of this scale. ``D(img)`` returns
                a score tensor that is averaged.
            cur_generator: generator of this scale, called as
                ``G(noisy_input, prev_image)``.

        Returns:
            tuple: ``(padded_rec_z, cur_generator)``, where ``padded_rec_z``
            is the padded reconstruction noise of the last epoch.
        """
        real = self.reals[len(self.Gs)]
        _, _, real_h, real_w = real.shape

        # Padding layers (initial zero padding).
        # TODO: use noise padding instead of zero padding.
        self.config.receptive_field = self.config.kernel_size + (
            (self.config.kernel_size - 1) *
            (self.config.num_layers - 1)) * self.config.stride
        padding_size = int(
            ((self.config.kernel_size - 1) * self.config.num_layers) / 2)
        noise_pad = nn.ZeroPad2d(padding_size)
        image_pad = nn.ZeroPad2d(padding_size)

        # MultiStepLR: lr *= gamma every time one of the milestones is reached.
        D_optimizer = optim.Adam(cur_discriminator.parameters(),
                                 lr=self.config.d_lr,
                                 betas=(self.config.beta1, self.config.beta2))
        G_optimizer = optim.Adam(cur_generator.parameters(),
                                 lr=self.config.g_lr,
                                 betas=(self.config.beta1, self.config.beta2))
        D_scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer=D_optimizer,
            milestones=self.config.milestones,
            gamma=self.config.gamma)
        G_scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer=G_optimizer,
            milestones=self.config.milestones,
            gamma=self.config.gamma)

        # Reconstruction noise. It stays all-zero after the first scale.
        # torch.zeros is used because torch.full(..., 0) would infer int64
        # from the integer fill value.
        rec_z = torch.zeros([1, self.config.img_channel, real_h, real_w],
                            device=self.config.device)
        mse_criterion = nn.MSELoss()

        for epoch in tqdm(range(self.config.num_iter),
                          desc=f'{len(self.Gs)}th GAN'):
            # Make noise input
            if not self.Gs:
                # NOTE(review): at the first scale the reconstruction noise is
                # re-drawn every epoch and only the last draw is returned;
                # confirm this is intended.
                rec_z = generate_noise([1, real_h, real_w],
                                       device=self.config.device).expand(
                                           1, 3, real_h, real_w)
                random_z = generate_noise([1, real_h, real_w],
                                          device=self.config.device).expand(
                                              1, 3, real_h, real_w)
            else:
                random_z = generate_noise(
                    [self.config.img_channel, real_h, real_w],
                    device=self.config.device)
            padded_rec_z = noise_pad(rec_z)
            padded_random_z = noise_pad(random_z)

            # Train Discriminator: Maximize D(x) - D(G(z)) -> Minimize D(G(z)) - D(X)
            for i in range(self.config.n_critic):
                # Train with real data
                cur_discriminator.zero_grad()
                real_prob_out = cur_discriminator(real)
                d_real_loss = -real_prob_out.mean()  # Maximize D(X) -> Minimize -D(X)
                d_real_loss.backward(retain_graph=True)
                D_x = -d_real_loss.item()

                if i == 0 and epoch == 0:
                    # One-time setup of the previous-scale inputs and of the
                    # noise amplitude.
                    if not self.Gs:
                        # First scale: the previous images are all zeros.
                        self.first_img_input = torch.zeros(
                            [1, self.config.img_channel, real_h, real_w],
                            device=self.config.device)
                        upscaled_prev_random_img = self.first_img_input
                        padded_random_img = image_pad(upscaled_prev_random_img)
                        upscaled_prev_rec_img = torch.zeros(
                            [1, self.config.img_channel, real_h, real_w],
                            device=self.config.device)
                        padded_rec_img = image_pad(upscaled_prev_rec_img)
                        self.config.noise_amp = 1
                    else:
                        upscaled_prev_random_img = self.draw_sequentially(
                            'rand', noise_pad, image_pad)
                        padded_random_img = image_pad(upscaled_prev_random_img)
                        upscaled_prev_rec_img = self.draw_sequentially(
                            'rec', noise_pad, image_pad)
                        # The noise amplitude follows the RMSE between the
                        # real image and the previous scales' reconstruction.
                        rmse = torch.sqrt(
                            mse_criterion(real, upscaled_prev_rec_img))
                        self.config.noise_amp = self.config.noise_amp_init * rmse
                        padded_rec_img = image_pad(upscaled_prev_rec_img)
                else:
                    upscaled_prev_random_img = self.draw_sequentially(
                        'rand', noise_pad, image_pad)
                    padded_random_img = image_pad(upscaled_prev_random_img)

                # Make random image input
                if not self.Gs:
                    padded_random_img_with_z = padded_random_z
                else:
                    padded_random_img_with_z = (
                        self.config.noise_amp *
                        padded_random_z) + padded_random_img

                # Train with fake data
                fake = cur_generator(padded_random_img_with_z.detach(),
                                     padded_random_img)
                fake_prob_out = cur_discriminator(fake.detach())
                d_fake_loss = fake_prob_out.mean()  # Minimize D(G(z))
                d_fake_loss.backward(retain_graph=True)
                D_G_z = d_fake_loss.item()

                # Gradient penalty
                gradient_penalty = calcul_gp(cur_discriminator, real, fake,
                                             self.config.device,
                                             False) * self.config.gp_weights
                gradient_penalty.backward()

                D_optimizer.step()
                d_loss = d_real_loss + d_fake_loss + gradient_penalty
                critic = D_x - D_G_z
                self.log_losses[f'{len(self.Gs)}th_D/d'] = d_loss.item()
                self.log_losses[f'{len(self.Gs)}th_D/d_critic'] = critic
                self.log_losses[
                    f'{len(self.Gs)}th_D/d_gp'] = gradient_penalty.item()

            # Train Generator : Maximize D(G(z)) -> Minimize -D(G(z))
            for _ in range(self.config.generator_iter):
                cur_generator.zero_grad()

                # Rebuild the fake sample from the last critic input. The
                # graph of the critic-loop `fake` is stale once
                # G_optimizer.step() has updated the generator in place, so
                # reusing it on a later generator iteration breaks backward.
                fake = cur_generator(padded_random_img_with_z.detach(),
                                     padded_random_img)

                # Adversarial loss
                fake_prob_out = cur_discriminator(fake)
                g_adv_loss = -fake_prob_out.mean()
                g_adv_loss.backward(retain_graph=True)

                # Reconstruction loss (the unweighted MSE is kept for logging).
                padded_rec_img_with_z = self.config.noise_amp * padded_rec_z + padded_rec_img
                g_rec_loss = mse_criterion(
                    cur_generator(padded_rec_img_with_z.detach(),
                                  padded_rec_img), real)
                (self.config.rec_weights * g_rec_loss).backward(
                    retain_graph=True)

                G_optimizer.step()

                g_adv_value = g_adv_loss.item()
                g_rec_value = g_rec_loss.item()
                self.log_losses[f'{len(self.Gs)}th_G/g'] = (
                    g_adv_value + self.config.rec_weights * g_rec_value)
                self.log_losses[f'{len(self.Gs)}th_G/g_critic'] = -g_adv_value
                self.log_losses[f'{len(self.Gs)}th_G/g_rec'] = g_rec_value

            # Log losses
            for key, value in self.log_losses.items():
                self.writer.add_scalar(key, value, epoch)
            self.log_losses = {}

            # Log image
            if epoch % self.config.img_save_iter == 0 or epoch == (
                    self.config.num_iter - 1):
                plt.imsave(f'{self.config.result_dir}/{epoch}_fake_sample.png',
                           torch2np(fake.detach()),
                           vmin=0,
                           vmax=1)
                plt.imsave(f'{self.config.result_dir}/{epoch}_fixed_noise.png',
                           torch2np(padded_rec_img_with_z.detach() * 2 - 1),
                           vmin=0,
                           vmax=1)
                # Logging-only forward pass: no autograd graph needed.
                with torch.no_grad():
                    reconstruction = cur_generator(
                        padded_rec_img_with_z.detach(), padded_rec_img)
                plt.imsave(
                    f'{self.config.result_dir}/{epoch}_reconstruction.png',
                    torch2np(reconstruction.detach()),
                    vmin=0,
                    vmax=1)

            D_scheduler.step()
            G_scheduler.step()

        # Save model weights
        torch.save(cur_generator.state_dict(),
                   f'{self.config.result_dir}/generator.pth')
        torch.save(cur_discriminator.state_dict(),
                   f'{self.config.result_dir}/discriminator.pth')

        return padded_rec_z, cur_generator