Exemple #1
0
def train_gan(generator,
              discriminator,
              image_loader,
              num_epochs,
              batch_size,
              cuda=True,
              g_lr=1e-3,
              d_lr=1e-3,
              filename_prefix="results",
              save_gen_images=False):
    """Train a generator/discriminator pair with a WGAN-style objective.

    Args:
        generator: module mapping a noise batch to an image batch.
        discriminator: module mapping an image batch to scalar critic scores.
        image_loader: iterable yielding (images, labels) batches.
        num_epochs: number of passes over image_loader.
        batch_size: expected batch size; smaller tail batches are skipped.
        cuda: when True, move both models and all tensors to the GPU.
        g_lr: generator learning rate.
        d_lr: discriminator learning rate.
        filename_prefix: prefix passed to save_images for sample dumps.
        save_gen_images: when True, save generator samples after each epoch.

    Returns:
        (discriminator, generator) after training.
    """
    if cuda:
        dtype = torch.cuda.FloatTensor
        generator.cuda()
        discriminator.cuda()
    else:
        dtype = torch.FloatTensor

    iters = 0
    d_optimizer = create_optimizer(discriminator, lr=d_lr, betas=(.5, .999))
    g_optimizer = create_optimizer(generator, lr=g_lr, betas=(.5, .999))
    BCELoss = nn.BCELoss()
    for epoch in range(num_epochs):
        for x, _ in image_loader:
            # Skip incomplete tail batches so batch-sized tensors line up.
            if x.shape[0] != batch_size:
                continue

            real_data = x.type(dtype)

            # ---- Generator step ----
            # Bug fix: zero grads BEFORE backward; previously the critic's
            # backward pass (fakes were not detached) left stale gradients
            # in the generator that the next g_optimizer.step() consumed.
            g_optimizer.zero_grad()
            z = generate_noise(batch_size).type(dtype)
            fake_images = generator(z)
            g_result = discriminator(fake_images).squeeze()
            # WGAN generator objective: maximize the critic score on fakes,
            # i.e. minimize -mean(D(G(z))).  Bug fix: the old +mean pushed
            # the generator in the wrong direction given d_cost below.
            g_cost = -torch.mean(g_result)
            g_cost.backward()
            g_optimizer.step()

            # ---- Discriminator (critic) step ----
            d_optimizer.zero_grad()
            z = generate_noise(batch_size).type(dtype)
            # detach() so the critic's backward pass does not write
            # gradients into the generator.
            fake_images = generator(z).detach()
            d_spred_fake = discriminator(fake_images).squeeze()
            d_spred_real = discriminator(real_data).squeeze()
            # WGAN critic objective: minimize mean(D(fake)) - mean(D(real)).
            d_cost = -torch.mean(d_spred_real - d_spred_fake)
            d_cost.backward()
            d_optimizer.step()
            iters += 1
        # Bug fix: the flag is save_gen_images; the old code tested the
        # save_images *function* object, which is always truthy.
        if save_gen_images:
            save_images(generator, epoch, iters, filename_prefix)
        print("Epoch", epoch, "Iter", iters)
        print("d_cost", d_cost)
        print("g_cost", g_cost)
        print("Inception Score", get_inception_score(generator))

    return discriminator, generator
Exemple #2
0
def calc_losses_sagan(batch_real, generator, discriminator, optim_generator,
                      optim_discr, device, noise_size):
    """One SAGAN training step with hinge loss (https://arxiv.org/abs/1805.08318).

    Updates the discriminator first, then the generator, and returns
    (generator_loss, discriminator_loss, real_accuracy_%, fake_accuracy_%).
    """
    discriminator.train()
    generator.train()

    # --- Discriminator update: hinge loss on generated and real batches ---
    optim_discr.zero_grad()
    z_discr = utils.generate_noise(batch_real.size()[0],
                                   noise_size).to(device)
    generated_for_discr = generator(z_discr).to(device)
    pred_fake = discriminator(generated_for_discr)
    loss_on_fake = relu(1.0 + pred_fake).mean()
    pred_real = discriminator(batch_real)
    loss_on_real = relu(1.0 - pred_real).mean()
    loss_discr = loss_on_fake + loss_on_real
    loss_discr.backward()
    optim_discr.step()

    # --- Generator update: ascend the critic's score on fresh fakes ---
    generator.train()
    optim_generator.zero_grad()
    z_gen = utils.generate_noise(batch_real.size()[0],
                                 noise_size).to(device)
    generated_for_gen = generator(z_gen).to(device)
    pred_fake_for_gen = discriminator(generated_for_gen)
    # Negated so that minimizing performs gradient ascent on the score.
    loss_gen = -pred_fake_for_gen.mean()
    loss_gen.backward()
    optim_generator.step()

    # --- Batch accuracies from the sign of the (pre-update) critic scores ---
    n = batch_real.size(0)
    correct_real = (pred_real.detach().cpu() > 0).float().numpy().sum()
    correct_fake = (pred_fake.detach().cpu() < 0).float().numpy().sum()
    acc_real = correct_real * 100 / n
    acc_fake = correct_fake * 100 / n

    return loss_gen.item(), loss_discr.item(), acc_real, acc_fake
Exemple #3
0
    def train_one_iter_D(self, images):
        """Run one discriminator update on a batch of real images."""
        self.optimizerD.zero_grad()

        # Sample latents and build the target tensors.
        z = generate_noise(self.batch_size, self.z_dim, type='normal')
        z = z.to(self.device)
        ones = torch.ones(self.batch_size).to(self.device)
        zeros = torch.zeros(self.batch_size).to(self.device)

        # Score generated samples against the "fake" target ...
        generated = self.generator(z)
        pred_fake, _ = self.discriminator(generated)
        loss_fake = self.adv_loss(pred_fake, zeros)

        # ... and real samples against the "real" target.
        pred_real, _ = self.discriminator(images)
        loss_real = self.adv_loss(pred_real, ones)

        # One backward pass over the summed loss, then update D.
        total_loss = loss_fake + loss_real
        total_loss.backward()
        self.optimizerD.step()

        # Record D(x), D(G(z)) and the loss for logging.
        self.D_x.update(pred_real.mean(0).item(), self.batch_size)
        self.D_G_z.update(pred_fake.mean(0).item(), self.batch_size)
        self.D_loss.update(total_loss, self.batch_size)
Exemple #4
0
	def __init__(self, loss_type, netD, netG, device, train_ds, val_ds, lr_D = 0.0002, lr_G = 0.0002, rec_weight = 10, ds_weight = 8, use_rec_feature = False, resample = True, weight_clip = None, use_gradient_penalty = False, loss_interval = 50, image_interval = 50, save_img_dir = 'saved_images/'):
		"""Bundle networks, datasets, losses and optimizers for training."""
		# Networks, datasets and device.
		self.netD = netD
		self.netG = netG
		self.train_ds = train_ds
		self.val_ds = val_ds
		self.device = device

		# Training hyper-parameters.
		self.lr_D = lr_D
		self.lr_G = lr_G
		self.resample = resample
		self.weight_clip = weight_clip
		self.use_gradient_penalty = use_gradient_penalty
		self.rec_weight = rec_weight
		self.use_rec_feature = use_rec_feature
		self.ds_weight = ds_weight

		# Fixed latent batch (3 samples) used to visualize progress.
		self.nz = self.netG.nz
		self.fixed_noise = generate_noise(3, self.nz, self.device)

		# Loss selection driven by loss_type.
		self.loss_type = loss_type
		self.require_type = get_require_type(self.loss_type)
		self.loss = get_gan_loss(self.device, self.loss_type)
		self.ds_loss = DSGAN_Loss(self.device, self.nz)

		# Adam optimizers with betas (0, 0.9).
		self.optimizerD = optim.Adam(self.netD.parameters(), lr = self.lr_D, betas = (0, 0.9))
		self.optimizerG = optim.Adam(self.netG.parameters(), lr = self.lr_G, betas = (0, 0.9))

		# Logging cadence and output directory.
		self.loss_interval = loss_interval
		self.image_interval = image_interval
		self.save_cnt = 0
		self.save_img_dir = save_img_dir
		if not os.path.exists(self.save_img_dir):
			os.makedirs(self.save_img_dir)
Exemple #5
0
    def forward(self, noise_init, noise_amp, mode='rand'):
        """Run the multi-scale generator body.

        Args:
            noise_init: noise tensor fed (padded) to the coarsest block.
            noise_amp: per-scale noise amplitudes, indexed by block index.
            mode: 'rand' injects fresh noise at every scale; any other value
                runs the deterministic "reconstruction" path with no noise.

        Returns:
            Generated output squashed to [-1, 1] by tanh.
        """
        # Coarsest scale: pad the input noise and run the first block.
        x_prev_out = self.body[0](F.pad(noise_init, self.p3d))

        for idx, block in enumerate(self.body[1:], 1):
            x_prev_out = torch.tanh(x_prev_out)

            # Upscale
            x_prev_out_up = utils.upscale(x_prev_out, idx, self.opt)

            # Add noise if "random" sampling, else, add no noise is "reconstruction" mode
            if mode == 'rand':
                # Enlarge by (num_layer + 2) * 2 per spatial dim — presumably
                # compensating for size lost inside the block's convolutions.
                # NOTE(review): this interpolates x_prev_out (pre-upscale),
                # not x_prev_out_up — confirm that is intentional.
                x_prev_out_up_2 = utils.interpolate_3D(
                    x_prev_out,
                    size=[
                        x_prev_out_up.shape[-3] + (self.opt.num_layer + 2) * 2,
                        x_prev_out_up.shape[-2] + (self.opt.num_layer + 2) * 2,
                        x_prev_out_up.shape[-1] + (self.opt.num_layer + 2) * 2
                    ])
                noise = utils.generate_noise(ref=x_prev_out_up_2)
                x_prev = block(x_prev_out_up_2 + noise * noise_amp[idx])
            else:
                x_prev = block(F.pad(x_prev_out_up, self.p3d))
            # Residual connection across scales.
            x_prev_out = x_prev + x_prev_out_up

        out = torch.tanh(x_prev_out)
        return out
Exemple #6
0
    def __init__(self, opts, data, weights):
        """Create a TF session, declare graph handles, and build the model."""
        # New session whose graph is the default graph.
        self._session = tf.Session()
        self._trained = False
        self._data = data
        # Copy so later reweighting cannot mutate the caller's array.
        self._data_weights = np.copy(weights)
        # Noise batch sampled once; reused to decode/plot during training.
        self._noise_for_plots = utils.generate_noise(opts, 500)

        # Placeholders (populated by _build_model_internal).
        self._real_points_ph = None
        self._noise_ph = None

        # Main graph operations (populated by _build_model_internal).
        self._loss = None
        self._loss_reconstruct = None
        self._loss_kl = None
        self._generated = None
        self._reconstruct_x = None

        # Optimizer handle.
        self.optim = None

        with self._session.as_default(), self._session.graph.as_default():
            logging.error('Building the graph...')
            self._build_model_internal(opts)

        # AdamOptimizer variables (if any) must exist before
        # global_variables_initializer() is called.
        init_op = tf.global_variables_initializer()
        self._session.run(init_op)
	def __init__(self, netD, netG, device, train_dl, lr_D = 0.0002, lr_G = 0.0002, beta1 = 0.5, loss_interval = 50, image_interval = 50, snapshot_interval = None, save_img_dir = 'saved_images/', save_snapshot_dir = 'saved_snapshots', resample = False):
		"""Hold networks, Adam optimizers and output settings for training."""
		# Networks, data loader and device.
		self.netD = netD
		self.netG = netG
		self.train_dl = train_dl
		self.train_iteration_per_epoch = len(self.train_dl)
		self.device = device
		self.lr_D = lr_D
		self.lr_G = lr_G
		self.resample = resample
		self.special = None

		# Adam optimizers with the supplied beta1.
		self.optimizerD = optim.Adam(self.netD.parameters(), lr = self.lr_D, betas = (beta1, 0.999))
		self.optimizerG = optim.Adam(self.netG.parameters(), lr = self.lr_G, betas = (beta1, 0.999))

		# Label conventions and latent size.
		self.real_label = 1
		self.fake_label = 0
		self.nz = self.netG.nz

		# Fixed batch of 16 latents used to visualize generator progress.
		self.fixed_noise = generate_noise(16, self.nz, self.device)

		# Logging / snapshot cadence and loss histories.
		self.loss_interval = loss_interval
		self.image_interval = image_interval
		self.snapshot_interval = snapshot_interval
		self.errD_records = []
		self.errG_records = []

		# Output directories (created if missing).
		self.save_cnt = 0
		self.save_img_dir = save_img_dir
		self.save_snapshot_dir = save_snapshot_dir
		for directory in (self.save_img_dir, self.save_snapshot_dir):
			if not os.path.exists(directory):
				os.makedirs(directory)
Exemple #8
0
    def train_one_iter_G(self, images=None):
        """Run one generator update.

        Args:
            images: batch of real images; required only when
                self.loss_type == 'feature_matching', ignored otherwise.

        Raises:
            ValueError: for loss_type 'both' (not implemented) or any
                unrecognized loss_type.
        """
        self.optimizerG.zero_grad()
        # get noise (z)
        noise = generate_noise(self.batch_size, self.z_dim, type='normal')
        noise = noise.to(self.device)
        # get real labels ( modified loss, minimize -log(D(G(z)) )
        real_label = torch.ones(self.batch_size).to(self.device)
        fake_images = self.generator(noise)

        if self.loss_type == 'adversarial':
            # vanilla adversarial loss
            fake_pred, _ = self.discriminator(fake_images)
            G_loss = self.adv_loss(fake_pred, real_label)
            G_loss.backward()

        elif self.loss_type == 'feature_matching':
            _, fmaps_fake = self.discriminator(fake_images)
            _, fmaps_real = self.discriminator(images)
            # get feature map statistics of last layer
            fmap_fake = fmaps_fake[-1].mean(0)
            fmap_real = fmaps_real[-1].mean(0).detach()  # treat as constant
            G_loss = self.fm_loss(fmap_fake, fmap_real)
            G_loss.backward()

        elif self.loss_type == 'both':
            # Bug fix: message previously read "los_type".
            raise ValueError("loss_type = both is not yet implemented")

        else:
            raise ValueError("invalid loss_type.")

        self.optimizerG.step()
        self.G_loss.update(G_loss, self.batch_size)
Exemple #9
0
    def _sample_internal(self, opts, num):
        """Draw `num` samples from the trained GAN model."""
        z = utils.generate_noise(opts, num)
        return self._run_batch(opts, self._generated, self._noise_ph, z,
                               self._is_training_ph, False)
Exemple #10
0
    def __init__(self,
                 netD,
                 netG,
                 device,
                 train_dl,
                 lr_D=0.0002,
                 lr_G=0.0002,
                 n_critic=5,
                 lambd=10,
                 loss_interval=50,
                 image_interval=50,
                 snapshot_interval=None,
                 save_img_dir='saved_images/',
                 save_snapshot_dir='saved_snapshots',
                 resample=None):
        """WGAN-GP trainer state: networks, optimizers, output directories.

        Raises:
            AssertionError: if `resample` is supplied — WGAN-GP always
                resamples, so the parameter must be left as None.
        """
        # Bug fix: the assertion was inverted (`is not None`), which made it
        # fire on the default arguments and accept any explicit value.
        # Validate up front, before any side effects (directory creation).
        assert (
            resample is None
        ), "Resample parameter is unnecessary for wgan_gp because it already resamples by default."

        # Networks, data loader and device.
        self.netD = netD
        self.netG = netG
        self.train_dl = train_dl
        self.lr_D = lr_D
        self.lr_G = lr_G
        self.n_critic = n_critic
        self.lambd = lambd
        self.train_iteration_per_epoch = len(self.train_dl)
        self.device = device
        self.resample = resample
        self.special = None

        # Adam with betas (0, 0.9).
        self.optimizerD = optim.Adam(self.netD.parameters(),
                                     lr=self.lr_D,
                                     betas=(0, 0.9))
        self.optimizerG = optim.Adam(self.netG.parameters(),
                                     lr=self.lr_G,
                                     betas=(0, 0.9))

        self.real_label = 1
        self.fake_label = 0
        self.nz = self.netG.nz

        # Fixed batch of 16 latents used to visualize generator progress.
        self.fixed_noise = generate_noise(16, self.nz, self.device)
        self.loss_interval = loss_interval
        self.image_interval = image_interval
        self.snapshot_interval = snapshot_interval

        self.errD_records = []
        self.errG_records = []
        self.w_dist_records = []

        # Output directories (created if missing).
        self.save_cnt = 0
        self.save_img_dir = save_img_dir
        self.save_snapshot_dir = save_snapshot_dir
        if (not os.path.exists(self.save_img_dir)):
            os.makedirs(self.save_img_dir)
        if (not os.path.exists(self.save_snapshot_dir)):
            os.makedirs(self.save_snapshot_dir)
Exemple #11
0
def draw_concat(Gs, Zs, reals, NoiseAmp, in_s, mode, m_noise, m_image, opt):
    """Build the input image for the current scale by running every
    previously-trained generator from coarse to fine (SinGAN-style).

    Args:
        Gs: trained generators, coarsest first.
        Zs: fixed noise maps, one per scale (used in 'rec' mode).
        reals: pyramid of real images, one per scale.
        NoiseAmp: per-scale noise amplitude multipliers.
        in_s: starting image at the coarsest scale.
        mode: 'rand' samples fresh noise per scale; 'rec' replays the fixed
            Z_opt maps for reconstruction.
        m_noise, m_image: padding modules applied to noise / image tensors.
        opt: options namespace (ker_size, num_layer, scale_factor, ...).

    Returns:
        The generated image at the resolution following the last scale run.
    """
    G_z = in_s
    if len(Gs) > 0:
        if mode == 'rand':
            count = 0
            # Padding the generator's valid convolutions will consume.
            pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
            if opt.mode == 'animation_train':
                pad_noise = 0
            for G, Z_opt, real_curr, real_next, noise_amp in zip(
                    Gs, Zs, reals, reals[1:], NoiseAmp):
                if count == 0:
                    # Coarsest scale: single-channel noise broadcast to 3 channels.
                    z = functions.generate_noise([
                        1, Z_opt.shape[2] - 2 * pad_noise,
                        Z_opt.shape[3] - 2 * pad_noise
                    ],
                                                 device=opt.device)
                    z = z.expand(1, 3, z.shape[2], z.shape[3])
                else:
                    z = functions.generate_noise([
                        opt.nc_z, Z_opt.shape[2] - 2 * pad_noise,
                        Z_opt.shape[3] - 2 * pad_noise
                    ],
                                                 device=opt.device)
                z = m_noise(z)
                # Crop to this scale's real-image size, pad, and mix in noise.
                G_z = G_z[:, :, 0:real_curr.shape[2], 0:real_curr.shape[3]]
                G_z = m_image(G_z)
                z_in = noise_amp * z + G_z
                # detach(): do not backprop into earlier scales' outputs.
                G_z = G(z_in.detach(), G_z)
                # Upsample to the next scale's resolution, then crop.
                G_z = imresize(G_z, 1 / opt.scale_factor, opt)
                G_z = G_z[:, :, 0:real_next.shape[2], 0:real_next.shape[3]]
                count += 1
        if mode == 'rec':
            count = 0
            for G, Z_opt, real_curr, real_next, noise_amp in zip(
                    Gs, Zs, reals, reals[1:], NoiseAmp):
                G_z = G_z[:, :, 0:real_curr.shape[2], 0:real_curr.shape[3]]
                G_z = m_image(G_z)
                # Replay the stored reconstruction noise for this scale.
                z_in = noise_amp * Z_opt + G_z
                G_z = G(z_in.detach(), G_z)
                G_z = imresize(G_z, 1 / opt.scale_factor, opt)
                G_z = G_z[:, :, 0:real_next.shape[2], 0:real_next.shape[3]]
                #if count != (len(Gs)-1):
                #    G_z = m_image(G_z)
                count += 1
    return G_z
Exemple #12
0
def get_shapes_with_noise(initial_shape: np.array):
    """Return [(noised_shape, noise_level)] for levels 0.05 .. 0.50 in 0.05 steps."""
    levels = [step / 100 for step in range(5, 55, 5)]
    return [(generate_noise(initial_shape, level), level) for level in levels]
Exemple #13
0
    def before_train(self):
        """Prepare logging, the W&B logger, fixed noise, and the save folder."""
        # Route training logs to a file.
        logging.basicConfig(filename='train.log', level=logging.INFO)

        # Weights & Biases run for this project.
        self.wandb_logger = WandbLogger(project=self.cfg.project)

        # Make sure the output folder exists before anything is saved.
        os.makedirs(self.cfg.save_folder, exist_ok=True)

        # Fixed batch of 16 latent vectors used to visualize G's progress.
        self.noise = generate_noise(16, self.z_dim).to(self.device)
Exemple #14
0
def calc_inception_score(device,
                         noise_size,
                         generator,
                         splits=10,
                         eval_size=50000):
    """Compute the Inception Score of `generator`.

    Generates ~eval_size images, classifies them with an ImageNet-pretrained
    Inception-v3, and returns the mean over `splits` chunks of
    exp(mean KL(p(y|x) || p(y))).

    Args:
        device: torch device for both models.
        noise_size: generator latent dimensionality.
        generator: module mapping a noise batch to an image batch.
        splits: number of equal chunks to average the score over.
        eval_size: target number of generated images (rounded down to a
            multiple of the internal batch size).

    Returns:
        Mean Inception Score as a float.
    """
    batch_size = 32
    inception = inception_v3(pretrained=True)
    inception.eval()
    inception = inception.to(device)

    generator.eval()

    num_batches = eval_size // batch_size
    # Collect per-batch softmax predictions in a list and concatenate once
    # at the end; repeated np.concatenate inside the loop was O(n^2).
    fake_prediction_batches = []
    for _ in range(num_batches):
        noise = utils.generate_noise(batch_size, noise_size).to(device)
        fake_batch = generator(noise).to(device).detach()
        prepared_fake_batch = F.interpolate(fake_batch, INCEPTION_SIZE)
        inception_logits = inception(prepared_fake_batch)
        predictions = F.softmax(inception_logits,
                                dim=-1).detach().cpu().numpy()
        fake_prediction_batches.append(predictions)
    all_fake_predictions = np.concatenate(fake_prediction_batches)

    # Trim so every split holds the same number of samples.
    eval_size_equal = (len(all_fake_predictions) // splits) * splits

    # Per-split score: exp(mean KL(p(y|x) || p(y))).
    all_scores = []
    for split in np.split(all_fake_predictions[:eval_size_equal], splits):
        # p(y): marginal class distribution, broadcast to each sample.
        prob_y = np.repeat(np.mean(split, axis=0, keepdims=True),
                           len(split),
                           axis=0)
        kl_div = stats.entropy(split, prob_y, axis=1)
        all_scores.append(np.exp(np.mean(kl_div)))

    inception_score = np.mean(all_scores)

    generator.train()

    return inception_score
Exemple #15
0
    def __init__(self,
                 netD,
                 netG,
                 n_classes,
                 device,
                 train_dl,
                 lr_D=0.0002,
                 lr_G=0.0002,
                 loss_interval=50,
                 image_interval=50,
                 snapshot_interval=None,
                 save_img_dir='saved_images/',
                 save_snapshot_dir='saved_snapshots'):
        """Class-conditional GAN trainer state with RMSprop optimizers."""
        # Networks, data loader and device.
        self.netD = netD
        self.netG = netG
        self.n_classes = n_classes
        self.train_dl = train_dl
        self.train_iteration_per_epoch = len(self.train_dl)
        self.device = device
        self.lr_D = lr_D
        self.lr_G = lr_G

        # RMSprop optimizers for both networks.
        self.optimizerD = optim.RMSprop(self.netD.parameters(), lr=self.lr_D)
        self.optimizerG = optim.RMSprop(self.netG.parameters(), lr=self.lr_G)

        # Label conventions and latent size.
        self.real_label = 1
        self.fake_label = 0
        self.nz = self.netG.nz

        # One fixed latent per class, paired with identity one-hot labels,
        # for visualizing class-conditional samples.
        self.fixed_noise = generate_noise(self.n_classes, self.nz, self.device)
        self.fixed_one_hot_labels = torch.diagflat(
            torch.ones(self.n_classes)).to(self.device)

        # Logging / snapshot cadence and loss histories.
        self.loss_interval = loss_interval
        self.image_interval = image_interval
        self.snapshot_interval = snapshot_interval
        self.errD_records = []
        self.errG_records = []

        # Output directories (created if missing).
        self.save_cnt = 0
        self.save_img_dir = save_img_dir
        self.save_snapshot_dir = save_snapshot_dir
        for directory in (self.save_img_dir, self.save_snapshot_dir):
            if not os.path.exists(directory):
                os.makedirs(directory)
	def __init__(self, loss_type, netD, netG, device, train_dl, lr_D = 0.0002, lr_G = 0.0002, resample = True, weight_clip = None, use_gradient_penalty = False, loss_interval = 50, image_interval = 50, save_img_dir = 'saved_images/'):
		"""GAN trainer parameterized by a loss name (SGAN/LSGAN/.../QPGAN)."""
		self.loss_type = loss_type
		self.loss_dict = {'SGAN':SGAN, 'LSGAN':LSGAN, 'HINGEGAN':HINGEGAN, 'WGAN':WGAN, 'RASGAN':RASGAN, 'RALSGAN':RALSGAN, 'RAHINGEGAN':RAHINGEGAN, 'QPGAN':QPGAN}
		# require_type records which arguments the chosen loss expects
		# (0: plain scores, 1: relativistic real+fake scores, 2: scores
		# plus image batches, -1: unrecognized name).
		if loss_type in ('SGAN', 'LSGAN', 'HINGEGAN', 'WGAN'):
			self.require_type = 0
			self.loss = self.loss_dict[self.loss_type](device)
		elif loss_type in ('RASGAN', 'RALSGAN', 'RAHINGEGAN'):
			self.require_type = 1
			self.loss = self.loss_dict[self.loss_type](device)
		elif loss_type == 'QPGAN':
			self.require_type = 2
			self.loss = self.loss_dict[self.loss_type](device, 'L1')
		else:
			self.require_type = -1

		# Networks, data loader and device.
		self.netD = netD
		self.netG = netG
		self.train_dl = train_dl
		self.train_iteration_per_epoch = len(self.train_dl)
		self.device = device
		self.lr_D = lr_D
		self.lr_G = lr_G
		self.resample = resample
		self.weight_clip = weight_clip
		self.use_gradient_penalty = use_gradient_penalty
		self.special = None

		# Adam optimizers with betas (0, 0.9).
		self.optimizerD = optim.Adam(self.netD.parameters(), lr = self.lr_D, betas = (0, 0.9))
		self.optimizerG = optim.Adam(self.netG.parameters(), lr = self.lr_G, betas = (0, 0.9))

		# Label conventions and latent size.
		self.real_label = 1
		self.fake_label = 0
		self.nz = self.netG.nz

		# Fixed batch of 49 latents (7x7 grid) used to visualize progress.
		self.fixed_noise = generate_noise(49, self.nz, self.device)
		self.loss_interval = loss_interval
		self.image_interval = image_interval

		self.errD_records = []
		self.errG_records = []

		# Output directory (created if missing).
		self.save_cnt = 0
		self.save_img_dir = save_img_dir
		if not os.path.exists(self.save_img_dir):
			os.makedirs(self.save_img_dir)
Exemple #17
0
    def refinement_layers(self, start_idx, x_prev_out, noise_amp, mode):
        """Refine x_prev_out through self.body[start_idx:] scale by scale."""
        for scale, layer in enumerate(self.body[start_idx:], start_idx):
            # Cut the autograd graph at the VAE-levels boundary.
            if self.opt.vae_levels == scale + 1:
                x_prev_out.detach_()

            # Upscale the previous output to this scale's resolution.
            upscaled = utils.upscale(x_prev_out, scale + 1, self.opt)

            # 'rand' mode injects fresh noise weighted by this level's
            # amplitude; otherwise run the noise-free reconstruction path.
            if mode == 'rand':
                eps = utils.generate_noise(ref=upscaled)
                refined = layer(upscaled + eps * noise_amp[scale + 1])
            else:
                refined = layer(upscaled)

            # Residual connection plus tanh squash.
            x_prev_out = torch.tanh(refined + upscaled)

        return x_prev_out
Exemple #18
0
def inference(inp, filename):
    """Generate a grid of images for label `inp` and save it as `filename`.png.

    Args:
        inp: label fed (replicated to a full batch) through the label model.
        filename: basename (without extension) for the saved image grid.
    """
    lm, gen = init_inference()
    checkpoint = torch.load(f"{config.OUT_DIR}/checkpoint.pt")

    # Character-to-index encoding saved during training.  `with` guarantees
    # the file is closed even if unpickling raises (was a manual close).
    with open(f"{config.BASE_DIR}/src/ctoi.txt", "rb") as ctoi_file:
        encoding_dict = pickle.load(ctoi_file)

    lm.load_state_dict(checkpoint["lm"])
    gen.load_state_dict(checkpoint["gen"])

    test = preprocess_labels([inp] * config.BATCH_SIZE, encoding_dict)
    with torch.no_grad():
        zin = generate_noise(config.Z_LEN, config.BATCH_SIZE, device)
        gin = lm(test.to(device))
        gout = gen(zin, gin)
        tgrid = torchvision.utils.make_grid(gout.detach().cpu(), nrow=4)
        # Bug fix: the `filename` parameter was unused and the output path
        # contained a literal placeholder; use the caller-supplied name.
        imshow(tgrid, f"{config.OUT_DIR}/inference/{filename}.png")

    print(f'Inference Finished. Check "out" directory for {filename}.png')
Exemple #19
0
def inference_tb(inp, writer):
    """Generate a grid of images for label `inp` and log it to TensorBoard.

    Args:
        inp: label fed (replicated to a full batch) through the label model.
        writer: TensorBoard SummaryWriter-like object with add_image().
    """
    lm, gen = init_inference()
    checkpoint = torch.load(f"{config.OUT_DIR}/checkpoint.pt")

    # Character-to-index encoding saved during training.  `with` guarantees
    # the file is closed even if unpickling raises (was a manual close).
    with open(f"{config.BASE_DIR}/src/ctoi.txt", "rb") as ctoi_file:
        encoding_dict = pickle.load(ctoi_file)

    lm.load_state_dict(checkpoint["lm"])
    gen.load_state_dict(checkpoint["gen"])

    test = preprocess_labels([inp] * config.BATCH_SIZE, encoding_dict)
    with torch.no_grad():
        zin = generate_noise(config.Z_LEN, config.BATCH_SIZE, device)
        gin = lm(test.to(device))
        gout = gen(zin, gin)
        tgrid = torchvision.utils.make_grid(gout.detach().cpu(), nrow=4)
        # Tag the image with the checkpoint's epoch number.
        writer.add_image(str(checkpoint["epoch"]), tgrid)
	def train(self, num_epoch):
		"""Adversarial training loop over `num_epoch` epochs.

		Each iteration updates D on a real and a (detached) fake batch,
		optionally applies a gradient penalty and/or weight clipping, then
		updates G.  Losses are recorded, printed every loss_interval steps,
		and sample grids are written every image_interval steps.
		"""
		for epoch in range(num_epoch):
			for i, data in enumerate(tqdm(self.train_dl)):
				# ---- Discriminator update ----
				self.netD.zero_grad()
				real_images = data[0].to(self.device)
				bs = real_images.size(0)

				noise = generate_noise(bs, self.nz, self.device)
				fake_images = self.netG(noise)

				# Critic scores for the real and detached fake batches.
				c_xr = self.netD(real_images)
				c_xr = c_xr.view(-1)
				c_xf = self.netD(fake_images.detach())
				c_xf = c_xf.view(-1)

				# require_type selects the d_loss signature (set in __init__).
				# NOTE(review): errD is unbound if require_type is -1.
				if(self.require_type == 0 or self.require_type == 1):
					errD = self.loss.d_loss(c_xr, c_xf)
				elif(self.require_type == 2):
					errD = self.loss.d_loss(c_xr, c_xf, real_images, fake_images)
				
				# Optional gradient penalty, weighted by the flag's value.
				if(self.use_gradient_penalty != False):
					errD += self.use_gradient_penalty * self.gradient_penalty(real_images, fake_images)

				errD.backward()
				self.optimizerD.step()

				# Optional WGAN-style weight clipping.
				if(self.weight_clip != None):
					for param in self.netD.parameters():
						param.data.clamp_(-self.weight_clip, self.weight_clip)

			
				# ---- Generator update ----
				self.netG.zero_grad()
				# Optionally draw a fresh noise batch for the G step.
				if(self.resample):
					noise = generate_noise(bs, self.nz, self.device)
					fake_images = self.netG(noise)

				if(self.require_type == 0):
					c_xf = self.netD(fake_images)
					c_xf = c_xf.view(-1)
					errG = self.loss.g_loss(c_xf)
				if(self.require_type == 1 or self.require_type == 2):
					# These loss types also need fresh real scores.
					c_xr = self.netD(real_images)				# (bs, 1, 1, 1)
					c_xr = c_xr.view(-1)						# (bs)
					c_xf = self.netD(fake_images)		# (bs, 1, 1, 1)
					c_xf = c_xf.view(-1)
					errG = self.loss.g_loss(c_xr, c_xf)
				errG.backward()
				self.optimizerG.step()

				# Bookkeeping.
				self.errD_records.append(float(errD))
				self.errG_records.append(float(errG))

				if(i % self.loss_interval == 0):
					print('[%d/%d] [%d/%d] errD : %.4f, errG : %.4f'
						  %(epoch+1, num_epoch, i+1, self.train_iteration_per_epoch, errD, errG))

				# Periodically dump a 7x7 grid of fixed-noise samples.
				if(i % self.image_interval == 0):
					if(self.special == None):
						sample_images_list = get_sample_images_list('Unsupervised', (self.fixed_noise, self.netG))
						plot_img = get_display_samples(sample_images_list, 7, 7)
						cur_file_name = os.path.join(self.save_img_dir, str(self.save_cnt)+' : '+str(epoch)+'-'+str(i)+'.jpg')
						self.save_cnt += 1
						cv2.imwrite(cur_file_name, plot_img)

					elif(self.special == 'Wave'):
						# Audio variant: save spectrogram grids instead.
						sample_audios_list = get_sample_images_list('Unsupervised_Audio', (self.fixed_noise, self.netG))
						plot_fig = plot_multiple_spectrograms(sample_audios_list, 7, 7, freq = 16000)
						cur_file_name = os.path.join(self.save_img_dir, str(self.save_cnt)+' : '+str(epoch)+'-'+str(i)+'.jpg')
						self.save_cnt += 1
						save_fig(cur_file_name, plot_fig)
						plot_fig.clf()
						
Exemple #21
0
    def _train_internal(self, opts):
        """Train the VAE.

        Runs opts['gan_epoch_num'] epochs of minibatch updates with a manual
        learning-rate decay schedule, periodic checkpointing, and periodic
        plotting of generated samples and reconstructions.

        Note: Python 2 era code (uses xrange); batches_num relies on `/`
        being integer division for ints.

        Args:
            opts: dict of training options (batch_size, work_dir, ...).
        """

        batches_num = self._data.num_points / opts['batch_size']
        train_size = self._data.num_points
        num_plot = 320
        # Previous plotted sample batch, used to track L2 movement of samples
        # between successive plots.
        sample_prev = np.zeros([num_plot] + list(self._data.data_shape))
        l2s = []

        counter = 0
        decay = 1.
        logging.error('Training VAE')
        for _epoch in xrange(opts["gan_epoch_num"]):

            # Manual step-wise learning-rate decay at fixed epochs.
            if opts['decay_schedule'] == "manual":
                if _epoch == 30:
                    decay = decay / 2.
                if _epoch == 50:
                    decay = decay / 5.
                if _epoch == 100:
                    decay = decay / 10.

            # Periodic checkpoint of the model.
            if _epoch > 0 and _epoch % opts['save_every_epoch'] == 0:
                # NOTE(review): this join result is unused.
                os.path.join(opts['work_dir'], opts['ckpt_dir'])
                self._saver.save(self._session,
                                 os.path.join(opts['work_dir'],
                                              opts['ckpt_dir'],
                                              'trained-pot'),
                                 global_step=counter)

            for _idx in xrange(batches_num):
                # logging.error('Step %d of %d' % (_idx, batches_num ) )
                # Weighted sampling of a minibatch (without replacement).
                data_ids = np.random.choice(train_size, opts['batch_size'],
                                            replace=False, p=self._data_weights)
                batch_images = self._data.data[data_ids].astype(np.float)
                batch_noise = utils.generate_noise(opts, opts['batch_size'])
                # One optimization step; also fetch the current losses.
                _, loss, loss_kl, loss_reconstruct = self._session.run(
                    [self._optim, self._loss, self._loss_kl,
                     self._loss_reconstruct],
                    feed_dict={self._real_points_ph: batch_images,
                               self._noise_ph: batch_noise,
                               self._lr_decay_ph: decay,
                               self._is_training_ph: True})
                counter += 1

                # Periodic textual progress report.
                if opts['verbose'] and counter % opts['plot_every'] == 0:
                    debug_str = 'Epoch: %d/%d, batch:%d/%d' % (
                        _epoch+1, opts['gan_epoch_num'], _idx+1, batches_num)
                    debug_str += '  [L=%.2g, Recon=%.2g, KLQ=%.2g]' % (
                        loss, loss_reconstruct, loss_kl)
                    logging.error(debug_str)

                # Periodic plots of samples and reconstructions.
                if opts['verbose'] and counter % opts['plot_every'] == 0:
                    metrics = Metrics()
                    points_to_plot = self._run_batch(
                        opts, self._generated, self._noise_ph,
                        self._noise_for_plots[0:num_plot],
                        self._is_training_ph, False)
                    # L2 distance to the previously plotted samples.
                    l2s.append(np.sum((points_to_plot - sample_prev)**2))
                    metrics.l2s = l2s[:]
                    metrics.make_plots(
                        opts,
                        counter,
                        None,
                        points_to_plot,
                        prefix='sample_e%04d_mb%05d_' % (_epoch, _idx))
                    reconstructed = self._session.run(
                        self._reconstruct_x,
                        feed_dict={self._real_points_ph: batch_images,
                                   self._is_training_ph: False})
                    metrics.l2s = None
                    metrics.make_plots(
                        opts,
                        counter,
                        None,
                        reconstructed,
                        prefix='reconstr_e%04d_mb%05d_' % (_epoch, _idx))
                # Optional early stopping by global step count.
                if opts['early_stop'] > 0 and counter > opts['early_stop']:
                    break
        # Final checkpoint after training (if at least one epoch ran).
        if _epoch > 0:
            # NOTE(review): this join result is unused.
            os.path.join(opts['work_dir'], opts['ckpt_dir'])
            self._saver.save(self._session,
                             os.path.join(opts['work_dir'],
                                          opts['ckpt_dir'],
                                          'trained-pot-final'),
                             global_step=counter)
    def train(self, num_epoch):
        """Run the vanilla-GAN (BCE) training loop for ``num_epoch`` epochs.

        Per batch: the discriminator is updated on real images (label
        ``self.real_label``) and on detached fake images (label
        ``self.fake_label``); the generator is then updated to push
        D(G(z)) toward the real label.  Scalar losses are appended to
        ``self.errD_records`` / ``self.errG_records``; image (or
        spectrogram) sample grids and full training-state snapshots are
        written at the configured intervals.
        """
        criterion = nn.BCELoss()
        for epoch in range(num_epoch):
            for i, data in enumerate(tqdm(self.train_dl)):
                # (1) : maximize log(D(x)) + log(1 - D(G(z)))
                # 		also means minimize (-log(D(x))) + (-log(1 - D(G(z))))
                self.netD.zero_grad()

                # first, calculate -log(D(x)) and its gradients using real images
                # real images (bs, nc, 64, 64)
                real_images = data[0].to(self.device)
                bs = real_images.size(0)
                # real labels (bs)
                label = torch.full((bs, ), self.real_label, device=self.device)
                output = self.netD(real_images)  # (bs, 1, 1, 1)
                output = output.view(-1)  # (bs)
                # BCELoss of output(bs), and real label(bs)
                errD_real = criterion(output, label)  # -log(D(x))
                # calculate the gradients
                errD_real.backward()

                # second, calculate -log(1 - D(G(z))) and its gradients using fake images
                # noise (bs, nz, 1, 1), fake images (bs, nc, 64, 64)
                noise = generate_noise(bs, self.nz, self.device)
                fake_images = self.netG(noise)
                # fake labels (bs) — reuse the same tensor, refilled in place
                label.fill_(self.fake_label)
                # detach() keeps this backward pass from reaching into G
                output = self.netD(fake_images.detach())  # (bs, 1, 1, 1)
                output = output.view(-1)  # (bs)
                # BCELoss of output(bs), and fake labels(bs)
                errD_fake = criterion(output, label)  # -log(1 - D(G(z)))
                # calculate the gradients (accumulate onto errD_real's grads)
                errD_fake.backward()

                # calculate the final loss value, (-log(D(x))) + (-log(1 - D(G(z))))
                errD = errD_real + errD_fake
                # update D using the gradients calculated previously
                self.optimizerD.step()

                # (2) : maximize log(D(G(z)))
                #		also means minimize -log(D(G(z)))
                self.netG.zero_grad()
                if (self.resample):
                    # optionally draw a fresh z for the generator update
                    noise = generate_noise(bs, self.nz, self.device)
                    fake_images = self.netG(noise)

                # first, calculate -log(D(G(z))) and its gradients using fake images
                # real labels (bs)
                label.fill_(self.real_label)
                output = self.netD(fake_images)  # (bs, 1, 1, 1)
                output = output.view(-1)  # (bs)
                # BCELoss of output(bs), and real labels(bs)
                errG = criterion(output, label)  # -log(D(G(z)))
                #calculate the gradients
                errG.backward()

                #update G using the gradients calculated previously
                self.optimizerG.step()

                self.errD_records.append(float(errD))
                self.errG_records.append(float(errG))

                if (i % self.loss_interval == 0):
                    print('[%d/%d] [%d/%d] errD : %.4f, errG : %.4f' %
                          (epoch + 1, num_epoch, i + 1,
                           self.train_iteration_per_epoch, errD, errG))

                if (i % self.image_interval == 0):
                    if (self.special == None):
                        sample_images_list = get_sample_images_list(
                            'Unsupervised', (self.fixed_noise, self.netG))
                        plot_fig = plot_multiple_images(
                            sample_images_list, 4, 4)
                        # NOTE(review): ' : ' in the filename is invalid on
                        # Windows — confirm the target OS is POSIX.
                        cur_file_name = os.path.join(
                            self.save_img_dir,
                            str(self.save_cnt) + ' : ' + str(epoch) + '-' +
                            str(i) + '.jpg')
                        self.save_cnt += 1
                        save_fig(cur_file_name, plot_fig)
                        plot_fig.clf()

                    elif (self.special == 'Wave'):
                        sample_audios_list = get_sample_images_list(
                            'Unsupervised_Audio',
                            (self.fixed_noise, self.netG))
                        plot_fig = plot_multiple_spectrograms(
                            sample_audios_list, 4, 4, freq=16000)
                        cur_file_name = os.path.join(
                            self.save_img_dir,
                            str(self.save_cnt) + ' : ' + str(epoch) + '-' +
                            str(i) + '.jpg')
                        self.save_cnt += 1
                        save_fig(cur_file_name, plot_fig)
                        plot_fig.clf()

                if (self.snapshot_interval is not None):
                    if (i % self.snapshot_interval == 0):
                        save(
                            os.path.join(
                                self.save_snapshot_dir, 'Epoch' + str(epoch) +
                                '_' + str(i) + '.state'), self.netD, self.netG,
                            self.optimizerD, self.optimizerG)
Exemple #23
0
def eval(opt, netG):
    """Sample random videos from a trained generator and save them.

    Re-generates the dataset frames for the current scale, then for each
    of ``opt.niter`` iterations draws ``opt.num_samples`` fakes from
    ``netG``, logs real/fake/VAE videos to the summary writer, and
    finally dumps the concatenated samples to
    ``<eval_dir>/random_samples.pth``.

    Fix: the batch iterator was built from an undefined global
    ``data_loader``; it now uses ``opt.data_loader``, consistent with
    the StopIteration refill below.
    """
    # Re-generate dataset frames for the current scale index.
    fps, td, fps_index = utils.get_fps_td_by_index(opt.scale_idx, opt)
    opt.fps = fps
    opt.td = td
    opt.fps_index = fps_index
    # opt.tds.append(opt.td)
    opt.dataset.generate_frames(opt.scale_idx)

    torch.save(opt.dataset.frames,
               os.path.join(opt.saver.eval_dir, "real_full_scale.pth"))

    # Lazily compute the initial latent size the first time through.
    if not hasattr(opt, 'Z_init_size'):
        initial_size = utils.get_scales_by_index(0, opt.scale_factor,
                                                 opt.stop_scale, opt.img_size)
        initial_size = [int(initial_size * opt.ar), initial_size]
        opt.Z_init_size = [
            opt.batch_size, opt.latent_dim, opt.td, *initial_size
        ]

    # Parallel
    if opt.device == 'cuda':
        G_curr = torch.nn.DataParallel(netG)
    else:
        G_curr = netG

    progressbar_args = {
        "iterable":
        range(opt.niter),
        "desc":
        "Generation scale [{}/{}]".format(opt.scale_idx + 1,
                                          opt.stop_scale + 1),
        "train":
        True,
        "offset":
        0,
        "logging_on_update":
        False,
        "logging_on_close":
        True,
        "postfix":
        True
    }
    epoch_iterator = tools.create_progressbar(**progressbar_args)

    iterator = iter(opt.data_loader)

    random_samples = []

    for iteration in epoch_iterator:
        try:
            data = next(iterator)
        except StopIteration:
            # Restart the loader once it is exhausted.
            iterator = iter(opt.data_loader)
            data = next(iterator)

        # At scales > 0 the loader yields (real, real_zero) pairs.
        if opt.scale_idx > 0:
            real, real_zero = data
            real = real.to(opt.device)
        else:
            real = data.to(opt.device)

        noise_init = utils.generate_noise(size=opt.Z_init_size,
                                          device=opt.device)

        # Update progress bar
        epoch_iterator.set_description(
            'Scale [{}/{}], Iteration [{}/{}]'.format(
                opt.scale_idx + 1,
                opt.stop_scale + 1,
                iteration + 1,
                opt.niter,
            ))

        with torch.no_grad():
            fake_var = []
            fake_vae_var = []
            for _ in range(opt.num_samples):
                # Redraw noise with the same shape/device as the reference.
                noise_init = utils.generate_noise(ref=noise_init)
                fake, fake_vae = G_curr(noise_init,
                                        opt.Noise_Amps,
                                        noise_init=noise_init,
                                        mode="rand")
                fake_var.append(fake)
                fake_vae_var.append(fake_vae)
            fake_var = torch.cat(fake_var, dim=0)
            fake_vae_var = torch.cat(fake_vae_var, dim=0)

        opt.summary.visualize_video(opt, iteration, real, 'Real')
        opt.summary.visualize_video(opt, iteration, fake_var, 'Fake var')
        opt.summary.visualize_video(opt, iteration, fake_vae_var,
                                    'Fake VAE var')

        random_samples.append(fake_var)

    random_samples = torch.cat(random_samples, dim=0)
    torch.save(random_samples,
               os.path.join(opt.saver.eval_dir, "random_samples.pth"))
    epoch_iterator.close()
Exemple #24
0
def eval(opt, netG):
    """Sample images from a trained generator using a hand-crafted
    rank-1 latent (experimental/debug variant) and save the results.

    For each iteration, ``opt.num_samples`` fakes are generated from a
    structured outer-product noise ``UV`` (one latent channel spiked per
    sample), visualized via the summary writer, and finally written to
    ``test.png`` and ``<eval_dir>/random_samples.pth``.

    Fixes: the batch iterator was built from an undefined global
    ``data_loader`` (now ``opt.data_loader``); per-iteration imports
    were hoisted out of the sampling loop; the unused ``import sys``
    (only referenced by commented-out code) was dropped.
    """
    import numpy as np

    # Re-generate dataset frames

    if not hasattr(opt, 'Z_init_size'):
        initial_size = utils.get_scales_by_index(0, opt.scale_factor,
                                                 opt.stop_scale, opt.img_size)
        initial_size = [int(initial_size * opt.ar), initial_size]
        opt.Z_init_size = [opt.batch_size, opt.latent_dim, *initial_size]

    # Parallel
    if opt.device == 'cuda':
        G_curr = torch.nn.DataParallel(netG)
    else:
        G_curr = netG

    progressbar_args = {
        "iterable":
        range(opt.niter),
        "desc":
        "Training scale [{}/{}]".format(opt.scale_idx + 1, opt.stop_scale + 1),
        "train":
        True,
        "offset":
        0,
        "logging_on_update":
        False,
        "logging_on_close":
        True,
        "postfix":
        True
    }
    epoch_iterator = tools.create_progressbar(**progressbar_args)

    iterator = iter(opt.data_loader)

    random_samples = []

    for iteration in epoch_iterator:
        try:
            data = next(iterator)
        except StopIteration:
            # Restart the loader once it is exhausted.
            iterator = iter(opt.data_loader)
            data = next(iterator)

        if opt.scale_idx > 0:
            real, real_zero = data
            real = real.to(opt.device)
        else:
            real = data.to(opt.device)

        noise_init = utils.generate_noise(size=opt.Z_init_size,
                                          device=opt.device)

        # Update progress bar
        epoch_iterator.set_description(
            'Scale [{}/{}], Iteration [{}/{}]'.format(
                opt.scale_idx + 1,
                opt.stop_scale + 1,
                iteration + 1,
                opt.niter,
            ))
        G_curr.eval()
        with torch.no_grad():
            fake_var = []
            fake_vae_var = []
            for _ in range(opt.num_samples):
                noise_init = utils.generate_noise(ref=noise_init)
                # NOTE(review): channel_idxs is never used below; kept only
                # because it advances NumPy's global RNG — confirm intent.
                channel_idxs = np.random.choice(np.arange(0, 128),
                                                127,
                                                replace=False)
                # Build a rank-1 latent: spike latent channel `_` (U) over a
                # fixed spatial window (V), then normalize.
                # U = torch.zeros(1, 128, 5).normal_(0, 1).to(noise_init.device)
                U = torch.zeros(1, 128, 1).to(noise_init.device)
                U[:, _] = 4
                # U[:, :120] =
                V = torch.zeros(1, 1, 22, 33).to(noise_init.device)
                # V.bernoulli_(p=0.01)
                V[:, :, 1:4, 20:32] = 1
                # V[:, :, 4:10, 8:10] = 1
                V = V.flatten(2)
                UV = torch.bmm(U, V).view(1, 128, 22, 33)
                UV = (UV - UV.mean()) / UV.std()
                # noise_init[:] = 0
                # noise_init[:, :, 5:11, 16:18] = _
                # noise_init[:, 108, 0:4, 0:4] = 100
                # noise_init[:, 21, _:_ + 1, 16:19] = 0.01
                # noise_init[:, :, 3:11, 16:18] = -10 / opt.num_samples

                # normed_z_vae = z_vae / ((z_vae ** 2).sum() + sys.float_info.epsilon)
                # noise_init = noise_init / ((noise_init ** 2).sum() + sys.float_info.epsilon)
                noise_init = UV
                fake, fake_vae = G_curr(noise_init,
                                        opt.Noise_Amps,
                                        noise_init=noise_init,
                                        mode="rand")
                fake_var.append(fake)
                fake_vae_var.append(fake_vae)
            fake_var = torch.cat(fake_var, dim=0)
            fake_vae_var = torch.cat(fake_vae_var, dim=0)

        opt.summary.visualize_image(opt, iteration, real, 'Real')
        opt.summary.visualize_image(opt, iteration, fake_var, 'Fake var')
        opt.summary.visualize_image(opt, iteration, fake_vae_var,
                                    'Fake VAE var')

        random_samples.append(fake_var)

    random_samples = torch.cat(random_samples, dim=0)
    from torchvision.utils import save_image
    save_image(random_samples, 'test.png', normalize=True)
    torch.save(random_samples,
               os.path.join(opt.saver.eval_dir, "random_samples.pth"))
    epoch_iterator.close()
Exemple #25
0
def train_single_scale(netD,
                       netG,
                       reals,
                       Gs,
                       Zs,
                       in_s,
                       NoiseAmp,
                       opt,
                       centers=None):
    """Train one scale of a SinGAN-style pyramid (WGAN-GP critic + generator
    with an optional reconstruction term).

    Runs ``opt.niter`` epochs of ``opt.Dsteps`` critic updates and
    ``opt.Gsteps`` generator updates, periodically saving sample images and
    ``z_opt``, then persists the networks for this scale.

    Fix: ``torch.full(..., 0, ...)`` filled with the *int* 0 — on modern
    PyTorch that infers an int64 tensor, which breaks the later float
    convolutions and noise arithmetic.  All fills now use ``0.0``.

    Returns:
        (z_opt, in_s, netG) for use by the next scale.
    """
    real = reals[len(Gs)]
    opt.nzx = real.shape[2]  #+(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real.shape[3]  #+(opt.ker_size-1)*(opt.num_layer)
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) *
                                          (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    # Reconstruction-loss weight; 0 disables the rec term entirely.
    alpha = opt.alpha

    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                           device=opt.device)
    # 0.0 (not 0) so the tensor is float, matching the network weights.
    z_opt = torch.full(fixed_noise.shape, 0.0, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(),
                            lr=opt.lr_d,
                            betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(),
                            lr=opt.lr_g,
                            betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)

    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    for epoch in range(opt.niter):
        # Coarsest scale uses single-channel noise expanded to 3 channels.
        if (Gs == []) & (opt.mode != 'SR_train'):
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy],
                                             device=opt.device)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()

            output = netD(real).to(opt.device)
            #D_real_map = output.detach()
            errD_real = -output.mean()  #-a
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if (j == 0) & (epoch == 0):
                # First step ever at this scale: bootstrap prev / z_prev.
                if (Gs == []) & (opt.mode != 'SR_train'):
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy],
                                      0.0,
                                      device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy],
                                        0.0,
                                        device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                       m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec',
                                         m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    # Scale noise amplitude by how well previous scales
                    # already reconstruct the real image.
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                   m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf),
                           functions.convert_image_np(prev),
                           vmin=0,
                           vmax=1)

            if (Gs == []) & (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            fake = netG(noise.detach(), prev)
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################

        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            #D_fake_map = output.detach()
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf),
                               functions.convert_image_np(z_prev),
                               vmin=0,
                               vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0

            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 25 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf),
                       functions.convert_image_np(fake.detach()),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf),
                       functions.convert_image_np(
                           netG(Z_opt.detach(), z_prev).detach()),
                       vmin=0,
                       vmax=1)
            #plt.imsave('%s/D_fake.png'   % (opt.outf), functions.convert_image_np(D_fake_map))
            #plt.imsave('%s/D_real.png'   % (opt.outf), functions.convert_image_np(D_real_map))
            #plt.imsave('%s/z_opt.png'    % (opt.outf), functions.convert_image_np(z_opt.detach()), vmin=0, vmax=1)
            #plt.imsave('%s/prev.png'     %  (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)
            #plt.imsave('%s/noise.png'    %  (opt.outf), functions.convert_image_np(noise), vmin=0, vmax=1)
            #plt.imsave('%s/z_prev.png'   % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)

            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG
Exemple #26
0
	def train(self, res_num_epochs, res_percentage, bs):
		"""Progressive-growing image-to-image GAN training.

		For each resolution stage ``i`` (32 * 2**i pixels), trains for
		``res_num_epochs[i]`` epochs with batch size ``bs[i]``, fading the
		blend factor ``p`` from ``i - 1`` to ``i`` over the first
		``res_percentage[i]`` fraction of epochs.  Supports optional
		gradient penalty, weight clipping, resampling, a diversity
		(mode-seeking) loss and an L1 / feature reconstruction loss.

		Fix: the diversity-loss branch called ``self.netG(x, noise)``
		without the blend factor ``p`` that every other generator call
		passes; ``p`` is now supplied.
		"""
		l1 = nn.L1Loss()

		p = 0
		# First stage has no fade-in transition.
		res_percentage = [None] + res_percentage

		for i, (num_epoch, percentage, cur_bs) in enumerate(zip(res_num_epochs, res_percentage, bs)):
			train_dl = self.train_ds.get_loader(32 * (2**i), cur_bs)
			val_dl = list(self.val_ds.get_loader(32 * (2**i), 3))[0]
			train_dl_len = len(train_dl)

			if(percentage is None):
				num_epoch_transition = 0
			else:
				num_epoch_transition = int(num_epoch * percentage)

			cnt = 1
			for epoch in range(num_epoch):
				p = i
				if(self.resample):
					train_dl_iter = iter(train_dl)

				for j, (x, y) in enumerate(tqdm(train_dl)):
					# During the transition, p fades from i-1 toward i.
					if(epoch < num_epoch_transition):
						p = i + cnt / (train_dl_len * num_epoch_transition) - 1
						cnt+=1

					x = x.to(self.device)
					y = y.to(self.device)
					bs = x.size(0)
					noise = generate_noise(bs, self.nz, self.device)
					fake_y = self.netG(x, p, noise)

					# --- Discriminator update ---
					self.netD.zero_grad()

					c_xr = self.netD(x, y)
					c_xr = c_xr.view(-1)
					c_xf = self.netD(x, fake_y.detach())
					c_xf = c_xf.view(-1)

					if(self.require_type == 0 or self.require_type == 1):
						errD = self.loss.d_loss(c_xr, c_xf)
					elif(self.require_type == 2):
						errD = self.loss.d_loss(c_xr, c_xf, y, fake_y)

					if(self.use_gradient_penalty != False):
						errD += self.use_gradient_penalty * self.gradient_penalty(x, y, fake_y)

					errD.backward()
					self.optimizerD.step()

					# WGAN-style weight clipping, if configured.
					if(self.weight_clip != None):
						for param in self.netD.parameters():
							param.data.clamp_(-self.weight_clip, self.weight_clip)

					# --- Generator update ---
					self.netG.zero_grad()
					if(self.resample):
						x, y = next(train_dl_iter)
						x = x.to(self.device)
						y = y.to(self.device)
						bs = x.size(0)
						noise = generate_noise(bs, self.nz, self.device)
						fake_y = self.netG(x, p, noise)

					if(self.require_type == 0):
						c_xr = None
						c_xf, f1 = self.netD(x, fake_y, True)		# (bs, 1, 1, 1)
						c_xf = c_xf.view(-1)						# (bs)
						errG_1 = self.loss.g_loss(c_xf)
					if(self.require_type == 1 or self.require_type == 2):
						c_xr, f2 = self.netD(x, y, True)				# (bs, 1, 1, 1)
						c_xr = c_xr.view(-1)						# (bs)
						c_xf, f1 = self.netD(x, fake_y, True)		# (bs, 1, 1, 1)
						c_xf = c_xf.view(-1)						# (bs)
						errG_1 = self.loss.g_loss(c_xr, c_xf)

					# Diversity (mode-seeking) loss between two noise draws.
					if(self.ds_weight == 0):
						ds_loss = 0
					else:
						noise1 = generate_noise(bs, self.nz, self.device)
						noise2 = generate_noise(bs, self.nz, self.device)
						# fixed: pass the blend factor p, as in every other netG call
						fake_y1 = self.netG(x, p, noise1)
						fake_y2 = self.netG(x, p, noise2)
						ds_loss = self.ds_loss.get_loss(fake_y1, fake_y2, noise1, noise2)

					if(self.rec_weight == 0):
						rec_loss = 0
					else:
						if(self.use_rec_feature):
							rec_loss = 0
							# NOTE(review): the feature-matching loop only runs
							# when c_xr was None (require_type == 0); for types
							# 1/2 rec_loss stays 0 here — confirm this is intended.
							if(c_xr == None):
								c_xr, f2 = self.netD(x, y, True)				# (bs, 1, 1, 1)
								c_xr = c_xr.view(-1)						# (bs)
								for f1_, f2_ in zip(f1, f2):
									rec_loss += (f1_ - f2_).abs().mean()
								rec_loss /= len(f1)

						else:
							rec_loss = l1(fake_y, y)

					errG = errG_1 + rec_loss * self.rec_weight + ds_loss * self.ds_weight
					errG.backward()
					# update G using the gradients calculated previously
					self.optimizerG.step()

					if(j % self.loss_interval == 0):
						print('[%d/%d] [%d/%d] errD : %.4f, errG : %.4f'
							  %(epoch+1, num_epoch, i+1, train_dl_len, errD, errG))

					if(j % self.image_interval == 0):
						if(self.nz == None):
							sample_images_list = get_sample_images_list((val_dl, self.netG, p, self.device))
							plot_image = get_display_samples(sample_images_list, 3, 3)
						else:
							sample_images_list = get_sample_images_list_noise((val_dl, self.netG, p, self.fixed_noise, self.device))
							plot_image = get_display_samples(sample_images_list, 9, 3)

						cur_file_name = os.path.join(self.save_img_dir, str(self.save_cnt)+' : '+str(epoch)+'-'+str(j)+'.jpg')
						self.save_cnt += 1
						cv2.imwrite(cur_file_name, plot_image)
    def train(self, num_epoch):
        """Train a conditional GAN (cGAN) with BCE loss for ``num_epoch`` epochs.

        Per batch: the discriminator is updated on (real image, real-class
        one-hot) pairs and on detached (fake image, random-class one-hot)
        pairs; the generator is then updated to make D accept its
        class-conditioned fakes.  Losses are recorded and sample grids /
        training snapshots written at the configured intervals.
        """
        criterion = nn.BCELoss()
        for epoch in range(num_epoch):
            for i, data in enumerate(tqdm(self.train_dl)):
                # (1) : minimize 0.5 * mean((D(x, y) - 1)^2) + 0.5 * mean((D(G(z, y), y) - 0)^2)
                self.netD.zero_grad()
                real_images = data[0].to(self.device)
                real_class = data[1].to(self.device)

                bs = real_images.size(0)
                # real labels (bs)
                real_label = torch.full((bs, ),
                                        self.real_label,
                                        device=self.device)
                # fake labels (bs)
                fake_label = torch.full((bs, ),
                                        self.fake_label,
                                        device=self.device)

                # one hot labels (bs, n_classes) for the real classes
                one_hot_labels = torch.FloatTensor(bs, self.n_classes).to(
                    self.device)
                one_hot_labels.zero_()
                one_hot_labels.scatter_(1, real_class.view(bs, 1), 1.0)

                # noise (bs, nz, 1, 1), fake images (bs, nc, 64, 64)
                noise = generate_noise(bs, self.nz, self.device)

                # random target classes for the fakes
                fake_class = torch.randint(0, self.n_classes,
                                           size=(bs,
                                                 1)).view(bs,
                                                          1).to(self.device)
                # one hot labels (bs, n_classes) for the fake classes
                one_hot_labels_fake = torch.FloatTensor(bs, self.n_classes).to(
                    self.device)
                one_hot_labels_fake.zero_()
                one_hot_labels_fake.scatter_(1,
                                             fake_class.view(bs, 1).long(),
                                             1.0)

                fake_images = self.netG(noise, one_hot_labels_fake)

                # calculate the discriminator results for both real & fake
                c_xr = self.netD(real_images, one_hot_labels)  # (bs, 1, 1, 1)
                c_xr = c_xr.view(-1)  # (bs)
                # detach() keeps D's backward pass from reaching into G
                c_xf = self.netD(fake_images.detach(),
                                 one_hot_labels_fake)  # (bs, 1, 1, 1)
                c_xf = c_xf.view(-1)  # (bs)
                # calculate the discriminator loss
                errD = criterion(c_xr, real_label) + criterion(
                    c_xf, fake_label)
                errD.backward()
                # update D using the gradients calculated previously
                self.optimizerD.step()

                # (2) : minimize 0.5 * mean((D(G(z)) - 1)^2)
                self.netG.zero_grad()
                if (self.resample):
                    # optionally draw fresh noise/classes for the G update
                    noise = generate_noise(bs, self.nz, self.device)
                    one_hot_labels_fake = torch.FloatTensor(
                        bs, self.n_classes).to(self.device)
                    one_hot_labels_fake.zero_()
                    one_hot_labels_fake.scatter_(1,
                                                 fake_class.view(bs, 1).long(),
                                                 1.0)
                    fake_images = self.netG(noise, one_hot_labels_fake)
                # we updated the discriminator once, therefore recalculate c_xf
                c_xf = self.netD(fake_images,
                                 one_hot_labels_fake)  # (bs, 1, 1, 1)
                c_xf = c_xf.view(-1)  # (bs)
                # calculate the Generator loss
                errG = criterion(c_xf,
                                 real_label)  # 0.5 * mean((D(G(z)) - 1)^2)
                errG.backward()
                #update G using the gradients calculated previously
                self.optimizerG.step()

                self.errD_records.append(float(errD))
                self.errG_records.append(float(errG))

                if (i % self.loss_interval == 0):
                    print('[%d/%d] [%d/%d] errD : %.4f, errG : %.4f' %
                          (epoch + 1, num_epoch, i + 1,
                           self.train_iteration_per_epoch, errD, errG))

                if (i % self.image_interval == 0):
                    sample_images_list = get_sample_images_list(
                        'Conditional',
                        (self.fixed_noise, self.fixed_one_hot_labels,
                         self.n_classes, self.netG))
                    plot_fig = plot_multiple_images(sample_images_list,
                                                    self.n_classes, 1)
                    # NOTE(review): ' : ' in the filename is invalid on
                    # Windows — confirm the target OS is POSIX.
                    cur_file_name = os.path.join(
                        self.save_img_dir,
                        str(self.save_cnt) + ' : ' + str(epoch) + '-' +
                        str(i) + '.jpg')
                    self.save_cnt += 1
                    save_fig(cur_file_name, plot_fig)
                    plot_fig.clf()

                if (self.snapshot_interval is not None):
                    if (i % self.snapshot_interval == 0):
                        save(
                            os.path.join(
                                self.save_snapshot_dir, 'Epoch' + str(epoch) +
                                '_' + str(i) + '.state'), self.netD, self.netG,
                            self.optimizerD, self.optimizerG)
Exemple #28
0
def train(opt, netG):
    """Train one scale of the multi-scale video generator.

    Regenerates the dataset frames for the current scale, optionally builds
    and restores a 3D discriminator (adversarial scales only), assembles
    per-block learning rates for the generator optimizer, then runs
    ``opt.niter`` iterations.  Scales below ``opt.vae_levels`` are trained
    with a VAE reconstruction + KL objective; upper scales use a WGAN-GP
    adversarial objective plus reconstruction.  Checkpoints and Tensorboard
    summaries are written through ``opt.saver`` / ``opt.summary``.

    Args:
        opt: options namespace; read extensively and mutated in place
            (``fps``, ``td``, ``noise_amp``, ``Noise_Amps``, ``Z_init_size``).
        netG: the multi-scale generator; its ``state_dict`` is checkpointed
            at the end of the scale.
    """
    # Re-generate dataset frames for the current temporal/spatial scale.
    fps, td, fps_index = utils.get_fps_td_by_index(opt.scale_idx, opt)
    opt.fps = fps
    opt.td = td
    opt.fps_index = fps_index

    with logger.LoggingBlock("Updating dataset", emph=True):
        logging.info("{}FPS :{} {}{}".format(green, clear, opt.fps, clear))
        logging.info("{}Time-Depth :{} {}{}".format(green, clear, opt.td,
                                                    clear))
        logging.info("{}Sampling-Ratio :{} {}{}".format(
            green, clear, opt.sampling_rates[opt.fps_index], clear))
        opt.dataset.generate_frames(opt.scale_idx)

    # Initialize noise shape once; cached on opt so later scales reuse it.
    if not hasattr(opt, 'Z_init_size'):
        initial_size = utils.get_scales_by_index(0, opt.scale_factor,
                                                 opt.stop_scale, opt.img_size)
        initial_size = [int(initial_size * opt.ar), initial_size]
        opt.Z_init_size = [
            opt.batch_size, opt.latent_dim, opt.td, *initial_size
        ]

    # A discriminator is only needed once we are past the pure-VAE scales.
    if opt.vae_levels < opt.scale_idx + 1:
        D_curr = getattr(networks_3d, opt.discriminator)(opt).to(opt.device)

        # Warm-start D from the previous scale's checkpoint when resuming
        # or when the previous scale was already adversarial.
        if (opt.netG != '') and (opt.resumed_idx == opt.scale_idx):
            D_curr.load_state_dict(
                torch.load('{}/netD_{}.pth'.format(
                    opt.resume_dir, opt.scale_idx - 1))['state_dict'])
        elif opt.vae_levels < opt.scale_idx:
            D_curr.load_state_dict(
                torch.load(
                    '{}/netD_{}.pth'.format(opt.saver.experiment_dir,
                                            opt.scale_idx - 1))['state_dict'])

        # Current optimizers
        optimizerD = optim.Adam(D_curr.parameters(),
                                lr=opt.lr_d,
                                betas=(opt.beta1, 0.999))

    parameter_list = []
    # Generator Adversary: earlier blocks get geometrically smaller LRs
    # (opt.lr_scale ** depth) so coarse scales move slowly.
    if not opt.train_all:
        if opt.vae_levels < opt.scale_idx + 1:
            train_depth = min(opt.train_depth,
                              len(netG.body) - opt.vae_levels + 1)
            parameter_list += [{
                "params":
                block.parameters(),
                "lr":
                opt.lr_g *
                (opt.lr_scale**(len(netG.body[-train_depth:]) - 1 - idx))
            } for idx, block in enumerate(netG.body[-train_depth:])]
        else:
            # VAE: also train the encoder/decoder at the VAE scales.
            parameter_list += [{
                "params": netG.encode.parameters(),
                "lr": opt.lr_g * (opt.lr_scale**opt.scale_idx)
            }, {
                "params": netG.decoder.parameters(),
                "lr": opt.lr_g * (opt.lr_scale**opt.scale_idx)
            }]
            parameter_list += [{
                "params":
                block.parameters(),
                "lr":
                opt.lr_g *
                (opt.lr_scale**(len(netG.body[-opt.train_depth:]) - 1 - idx))
            } for idx, block in enumerate(netG.body[-opt.train_depth:])]
    else:
        if len(netG.body) < opt.train_depth:
            parameter_list += [{
                "params": netG.encode.parameters(),
                "lr": opt.lr_g * (opt.lr_scale**opt.scale_idx)
            }, {
                "params": netG.decoder.parameters(),
                "lr": opt.lr_g * (opt.lr_scale**opt.scale_idx)
            }]
            parameter_list += [{
                "params":
                block.parameters(),
                "lr":
                opt.lr_g * (opt.lr_scale**(len(netG.body) - 1 - idx))
            } for idx, block in enumerate(netG.body)]
        else:
            parameter_list += [{
                "params":
                block.parameters(),
                "lr":
                opt.lr_g *
                (opt.lr_scale**(len(netG.body[-opt.train_depth:]) - 1 - idx))
            } for idx, block in enumerate(netG.body[-opt.train_depth:])]

    optimizerG = optim.Adam(parameter_list,
                            lr=opt.lr_g,
                            betas=(opt.beta1, 0.999))

    # Parallel
    if opt.device == 'cuda':
        G_curr = torch.nn.DataParallel(netG)
        if opt.vae_levels < opt.scale_idx + 1:
            D_curr = torch.nn.DataParallel(D_curr)
    else:
        G_curr = netG

    progressbar_args = {
        "iterable":
        range(opt.niter),
        "desc":
        "Training scale [{}/{}]".format(opt.scale_idx + 1, opt.stop_scale + 1),
        "train":
        True,
        "offset":
        0,
        "logging_on_update":
        False,
        "logging_on_close":
        True,
        "postfix":
        True
    }
    epoch_iterator = tools.create_progressbar(**progressbar_args)

    # BUG FIX: the original read the undefined name `data_loader` here;
    # the StopIteration handler below shows `opt.data_loader` is intended.
    iterator = iter(opt.data_loader)

    for iteration in epoch_iterator:
        try:
            data = next(iterator)
        except StopIteration:
            # Restart the loader when it is exhausted mid-training.
            iterator = iter(opt.data_loader)
            data = next(iterator)

        if opt.scale_idx > 0:
            real, real_zero = data
            real = real.to(opt.device)
            real_zero = real_zero.to(opt.device)
        else:
            # Coarsest scale: the raw sample doubles as the zero-scale target.
            real = data.to(opt.device)
            real_zero = real

        noise_init = utils.generate_noise(size=opt.Z_init_size,
                                          device=opt.device)

        ############################
        # calculate noise_amp (once, on the first iteration of the scale)
        ###########################
        if iteration == 0:
            if opt.const_amp:
                opt.Noise_Amps.append(1)
            else:
                with torch.no_grad():
                    if opt.scale_idx == 0:
                        opt.noise_amp = 1
                        opt.Noise_Amps.append(opt.noise_amp)
                    else:
                        # Scale the injected noise by the reconstruction RMSE
                        # of the current generator at this scale.
                        opt.Noise_Amps.append(0)
                        z_reconstruction, _, _ = G_curr(real_zero,
                                                        opt.Noise_Amps,
                                                        mode="rec")

                        RMSE = torch.sqrt(F.mse_loss(real, z_reconstruction))
                        opt.noise_amp = opt.noise_amp_init * RMSE.item(
                        ) / opt.batch_size
                        opt.Noise_Amps[-1] = opt.noise_amp

        ############################
        # (1) Update VAE network
        ###########################
        total_loss = 0

        generated, generated_vae, (mu, logvar) = G_curr(real_zero,
                                                        opt.Noise_Amps,
                                                        mode="rec")

        if opt.vae_levels >= opt.scale_idx + 1:
            rec_vae_loss = opt.rec_loss(generated, real) + opt.rec_loss(
                generated_vae, real_zero)
            kl_loss = kl_criterion(mu, logvar)
            vae_loss = opt.rec_weight * rec_vae_loss + opt.kl_weight * kl_loss

            total_loss += vae_loss
        else:
            ############################
            # (2) Update D network: maximize D(x) + D(G(z))
            ###########################
            # train with real
            #################

            # Train 3D Discriminator
            D_curr.zero_grad()
            output = D_curr(real)
            errD_real = -output.mean()

            # train with fake
            #################
            fake, _ = G_curr(noise_init,
                             opt.Noise_Amps,
                             noise_init=noise_init,
                             mode="rand")

            # Train 3D Discriminator (detach so G gets no gradient here)
            output = D_curr(fake.detach())
            errD_fake = output.mean()

            gradient_penalty = calc_gradient_penalty(D_curr, real, fake,
                                                     opt.lambda_grad,
                                                     opt.device)
            errD_total = errD_real + errD_fake + gradient_penalty
            errD_total.backward()
            optimizerD.step()

            ############################
            # (3) Update G network: maximize D(G(z))
            ###########################
            errG_total = 0
            rec_loss = opt.rec_loss(generated, real)
            errG_total += opt.rec_weight * rec_loss

            # Train with 3D Discriminator
            output = D_curr(fake)
            errG = -output.mean() * opt.disc_loss_weight
            errG_total += errG

            total_loss += errG_total

        G_curr.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(G_curr.parameters(), opt.grad_clip)
        optimizerG.step()

        # Update progress bar
        epoch_iterator.set_description(
            'Scale [{}/{}], Iteration [{}/{}]'.format(
                opt.scale_idx + 1,
                opt.stop_scale + 1,
                iteration + 1,
                opt.niter,
            ))

        if opt.visualize:
            # Tensorboard (noise_amp was logged twice in the original;
            # the duplicate write has been removed).
            opt.summary.add_scalar(
                'Video/Scale {}/noise_amp'.format(opt.scale_idx),
                opt.noise_amp, iteration)
            if opt.vae_levels >= opt.scale_idx + 1:
                opt.summary.add_scalar(
                    'Video/Scale {}/KLD'.format(opt.scale_idx), kl_loss.item(),
                    iteration)
            else:
                opt.summary.add_scalar(
                    'Video/Scale {}/rec loss'.format(opt.scale_idx),
                    rec_loss.item(), iteration)
            if opt.vae_levels < opt.scale_idx + 1:
                opt.summary.add_scalar(
                    'Video/Scale {}/errG'.format(opt.scale_idx), errG.item(),
                    iteration)
                opt.summary.add_scalar(
                    'Video/Scale {}/errD_fake'.format(opt.scale_idx),
                    errD_fake.item(), iteration)
                opt.summary.add_scalar(
                    'Video/Scale {}/errD_real'.format(opt.scale_idx),
                    errD_real.item(), iteration)
            else:
                opt.summary.add_scalar(
                    'Video/Scale {}/Rec VAE'.format(opt.scale_idx),
                    rec_vae_loss.item(), iteration)

            if iteration % opt.print_interval == 0:
                # Sample a few random generations for visual inspection.
                with torch.no_grad():
                    fake_var = []
                    fake_vae_var = []
                    for _ in range(3):
                        noise_init = utils.generate_noise(ref=noise_init)
                        fake, fake_vae = G_curr(noise_init,
                                                opt.Noise_Amps,
                                                noise_init=noise_init,
                                                mode="rand")
                        fake_var.append(fake)
                        fake_vae_var.append(fake_vae)
                    fake_var = torch.cat(fake_var, dim=0)
                    fake_vae_var = torch.cat(fake_vae_var, dim=0)

                opt.summary.visualize_video(opt, iteration, real, 'Real')
                opt.summary.visualize_video(opt, iteration, generated,
                                            'Generated')
                opt.summary.visualize_video(opt, iteration, generated_vae,
                                            'Generated VAE')
                opt.summary.visualize_video(opt, iteration, fake_var,
                                            'Fake var')
                opt.summary.visualize_video(opt, iteration, fake_vae_var,
                                            'Fake VAE var')

    epoch_iterator.close()

    # Save data
    opt.saver.save_checkpoint({'data': opt.Noise_Amps}, 'Noise_Amps.pth')
    opt.saver.save_checkpoint(
        {
            'scale': opt.scale_idx,
            'state_dict': netG.state_dict(),
            'optimizer': optimizerG.state_dict(),
            'noise_amps': opt.Noise_Amps,
        }, 'netG.pth')
    if opt.vae_levels < opt.scale_idx + 1:
        # Unwrap DataParallel before saving so the checkpoint loads on CPU too.
        opt.saver.save_checkpoint(
            {
                'scale':
                opt.scale_idx,
                'state_dict':
                D_curr.module.state_dict()
                if opt.device == 'cuda' else D_curr.state_dict(),
                'optimizer':
                optimizerD.state_dict(),
            }, 'netD_{}.pth'.format(opt.scale_idx))
    def train(self, num_epoch):
        """Run the RaLSGAN training loop for ``num_epoch`` epochs.

        Per batch: one discriminator step minimizing the relativistic
        average least-squares loss, then one generator step on the
        symmetric objective (D outputs are recomputed since D just
        changed).  Losses are appended to the record lists, and images /
        snapshots are written at the configured intervals.

        Args:
            num_epoch: number of passes over ``self.train_dl``.
        """
        for epoch in range(num_epoch):
            # Optional second stream of real batches, so the G step sees
            # data independent from the D step.
            if (self.resample):
                train_dl_iter = iter(self.train_dl)
            for i, data in enumerate(tqdm(self.train_dl)):
                # (1) : minimizes mean((D(x) - mean(D(G(z))) - 1)**2) + mean((D(G(z)) - mean(D(x)) + 1)**2)
                self.netD.zero_grad()
                real_images = data[0].to(self.device)
                bs = real_images.size(0)
                # real labels (bs)
                real_label = torch.full((bs, ),
                                        self.real_label,
                                        device=self.device)
                # fake labels (bs)
                fake_label = torch.full((bs, ),
                                        self.fake_label,
                                        device=self.device)
                # noise (bs, nz, 1, 1), fake images (bs, cn, 64, 64)
                noise = generate_noise(bs, self.nz, self.device)
                fake_images = self.netG(noise)
                # calculate the discriminator results for both real & fake
                c_xr = self.netD(real_images)  # (bs, 1, 1, 1)
                c_xr = c_xr.view(-1)  # (bs)
                c_xf = self.netD(fake_images.detach())  # (bs, 1, 1, 1)
                c_xf = c_xf.view(-1)  # (bs)
                # calculate the Discriminator loss
                errD = (torch.mean(
                    (c_xr - torch.mean(c_xf) - real_label)**2) + torch.mean(
                        (c_xf - torch.mean(c_xr) + real_label)**2)) / 2.0
                errD.backward()
                # update D using the gradients calculated previously
                self.optimizerD.step()

                # (2) : minimizes mean((D(G(z)) - mean(D(x)) - 1)**2) + mean((D(x) - mean(D(G(z))) + 1)**2)
                self.netG.zero_grad()
                if (self.resample):
                    real_images = next(train_dl_iter)[0].to(self.device)
                    noise = generate_noise(bs, self.nz, self.device)
                    fake_images = self.netG(noise)
                # we updated the discriminator once, therefore recalculate c_xr, c_xf
                c_xr = self.netD(real_images)  # (bs, 1, 1, 1)
                c_xr = c_xr.view(-1)  # (bs)
                c_xf = self.netD(fake_images)  # (bs, 1, 1, 1)
                c_xf = c_xf.view(-1)  # (bs)
                # calculate the Generator loss
                errG = (torch.mean(
                    (c_xf - torch.mean(c_xr) - real_label)**2) + torch.mean(
                        (c_xr - torch.mean(c_xf) + real_label)**2)) / 2.0
                errG.backward()
                # update G using the gradients calculated previously
                self.optimizerG.step()

                self.errD_records.append(float(errD))
                self.errG_records.append(float(errG))

                if (i % self.loss_interval == 0):
                    print('[%d/%d] [%d/%d] errD : %.4f, errG : %.4f' %
                          (epoch + 1, num_epoch, i + 1,
                           self.train_iteration_per_epoch, errD, errG))

                if (i % self.image_interval == 0):
                    # Fix: identity comparison with None (was `== None`).
                    if (self.special is None):
                        sample_images_list = get_sample_images_list(
                            'Unsupervised', (self.fixed_noise, self.netG))
                        plot_fig = plot_multiple_images(
                            sample_images_list, 4, 4)
                        cur_file_name = os.path.join(
                            self.save_img_dir,
                            str(self.save_cnt) + ' : ' + str(epoch) + '-' +
                            str(i) + '.jpg')
                        self.save_cnt += 1
                        save_fig(cur_file_name, plot_fig)
                        plot_fig.clf()

                    elif (self.special == 'Wave'):
                        # Audio mode: plot spectrograms instead of images.
                        sample_audios_list = get_sample_images_list(
                            'Unsupervised_Audio',
                            (self.fixed_noise, self.netG))
                        plot_fig = plot_multiple_spectrograms(
                            sample_audios_list, 4, 4, freq=16000)
                        cur_file_name = os.path.join(
                            self.save_img_dir,
                            str(self.save_cnt) + ' : ' + str(epoch) + '-' +
                            str(i) + '.jpg')
                        self.save_cnt += 1
                        save_fig(cur_file_name, plot_fig)
                        plot_fig.clf()

                if (self.snapshot_interval is not None):
                    if (i % self.snapshot_interval == 0):
                        save(
                            os.path.join(
                                self.save_snapshot_dir, 'Epoch' + str(epoch) +
                                '_' + str(i) + '.state'), self.netD, self.netG,
                            self.optimizerD, self.optimizerG)
Exemple #30
0
def train(opt):
    """Train a SinGAN-style pyramid of per-scale GANs (PaddlePaddle fluid).

    For each scale of the real-image pyramid, builds a fresh
    discriminator/generator pair, trains them with a WGAN-like objective
    plus an alpha-weighted reconstruction loss, periodically dumps sample
    images, saves the dygraph state dicts, then freezes the pair and
    records its outputs as the prior for the next scale.

    Args:
        opt: options namespace; read and mutated in place (``padd_size``,
            ``noise_amp``).
    """
    images = creat_reals_pyramid(opt)
    place = fluid.CUDAPlace(0) if opt.use_gpu else fluid.CPUPlace()
    priors = []
    prior_recons = []
    netD_arrs = []
    netG_arrs = []
    noiseamp_arrs = []
    opt.padd_size = 1
    for idx in range(0, len(images)):
        outdir = "%s/%d/" % (opt.out, idx)
        if not os.path.isdir(outdir):
            os.mkdir(outdir)
        with fluid.dygraph.guard():
            real = images[idx]
            in_s = np.zeros(shape=real.shape, dtype=np.float32)
            zero = fluid.layers.zeros(shape=[1], dtype='float32')
            #zero.stop_gradient = True
            one = fluid.layers.ones(shape=[1], dtype='float32')
            #one.stop_gradient = True
            alpha = to_variable(np.array([opt.alpha]).astype('float32'))
            # NOTE(review): optimizerG uses opt.lr_d — opt.lr_g is probably
            # intended (cf. the commented-out RMSProp lines below); confirm
            # before changing, as it alters training dynamics.
            optimizerG = fluid.optimizer.Adam(learning_rate=opt.lr_d,
                                              beta1=opt.beta1,
                                              beta2=0.999,
                                              name='net_GA')
            optimizerD = fluid.optimizer.Adam(learning_rate=opt.lr_d,
                                              beta1=opt.beta1,
                                              beta2=0.999,
                                              name='net_DA')
            backward_strategy = fluid.dygraph.BackwardStrategy()
            backward_strategy.sort_sum_gradient = True
            #optimizerD = fluid.optimizer.RMSPropOptimizer(learning_rate=opt.lr_d, name="opD")
            #optimizerG = fluid.optimizer.RMSPropOptimizer(learning_rate=opt.lr_g, name="opG")
            #fluid.clip.set_gradient_clip(fluid.clip.GradientClipByValue(0,1))
            netD = Discriminator("DA", opt)
            netG = Generator("GA", opt)
            # fluid.clip.set_gradient_clip(fluid.clip.GradientClipByValue(min=-0.01, max=0.01),param_list=[netD.parameters(),netG.parameters()])
            vreal = to_variable(real)
            for epoch in range(opt.niter):
                noise_epoch = generate_noise(real.shape, opt)
                prev = in_s
                prev_rec = in_s
                opt.noise_amp = 1
                # Walk the frozen lower scales; only the last one's prior /
                # reconstruction / noise amplitude is kept.
                # BUG FIX: loop variable renamed (was `idx`), which shadowed
                # the outer scale index.
                for prev_idx in range(len(netG_arrs)):
                    prev = priors[prev_idx]
                    prev_rec = prior_recons[prev_idx]
                    opt.noise_amp = noiseamp_arrs[prev_idx]
                prev = resize(prev, (real.shape[3], real.shape[2]))
                prev_rec = resize(prev_rec, (real.shape[3], real.shape[2]))
                vprev = to_variable(prev)
                vprev_rec = to_variable(prev_rec)
                for j in range(opt.Dsteps):
                    netD.clear_gradients()
                    outD_real = netD(vreal)
                    errD_real = fluid.layers.mean(outD_real)
                    errD_real = 0.0 - errD_real
                    errD_real.backward(backward_strategy)
                    #errD_real = fluid.layers.elementwise_sub(zero, errD_real)
                    #                    errD_real.backward(backward_strategy)
                    noise = opt.noise_amp * noise_epoch + prev
                    vnoise = to_variable(noise)
                    outG_fake = netG(vnoise.detach(), vprev)
                    outD_fake = netD(outG_fake.detach())
                    errD_fake = fluid.layers.mean(outD_fake)
                    #        errD_fake = fluid.layers.elementwise_sub(zero, errD_fake)
                    errD_fake.backward(backward_strategy)
                    #gradient_penalty = calc_gradient_penalty(netD, vreal, outG_fake, opt, backward_strategy)
                    #gradient_penalty.backward()
                    #errD = errD_real + errD_fake + gradient_penalty
                    errD = errD_real + errD_fake
                    params_d = optimizerD.backward(
                        errD, parameter_list=netD.parameters())
                    optimizerD.apply_gradients(params_d)
                for j in range(opt.Gsteps):
                    netG.clear_gradients()
                    outD_fakeG = netD(outG_fake)
                    errD_fakeG = 0.0 - fluid.layers.mean(outD_fakeG)
                    #             errD_fakeG = fluid.layers.elementwise_add(zero, errD_fakeG)
                    errD_fakeG.backward(backward_strategy)
                    noise_fake = opt.noise_amp * noise_epoch + prev_rec
                    noise_fake = to_variable(noise_fake)
                    outG_fake_rec = netG(noise_fake.detach(), vprev_rec)
                    rec_loss = fluid.layers.mse_loss(vreal, outG_fake_rec)
                    # RMSE feeds the next scale's noise amplitude (below).
                    RMSE = fluid.layers.sqrt(rec_loss).numpy()
                    rec_loss = fluid.layers.elementwise_mul(alpha, rec_loss)
                    rec_loss.backward(backward_strategy)
                    errG = rec_loss + errD_fakeG
                    #errG =  errD_fakeG
                    params_g = optimizerG.backward(
                        errG, parameter_list=netG.parameters())
                    optimizerG.apply_gradients(params_g)
                #netD.clear_gradients()
                #netG.clear_gradients()
                if epoch % 25 == 0 or epoch == (opt.niter - 1):
                    print(
                        'shape %s [epoch:%d/%d][errD:%.5f][errG:%.5f][rec_loss:%.5f][noise_amp:%.5f][errD_real:%.5f][errD_fake:%.5f][outD_fakeG:%.5f]'
                        % (real.shape, epoch, opt.niter, errD.numpy(),
                           errG.numpy(), rec_loss.numpy(), opt.noise_amp,
                           errD_real.numpy(), errD_fake.numpy(),
                           errD_fakeG.numpy()))
                if epoch % 500 == 0 or epoch == (opt.niter - 1):
                    dump_img(
                        outG_fake.numpy(),
                        os.path.join(
                            outdir,
                            "fake_sample_%d_%s" % (epoch, opt.input_name)))
                    dump_img(
                        outG_fake_rec.numpy(),
                        os.path.join(
                            outdir,
                            "G(z_opt)_%d_%s" % (epoch, opt.input_name)))
            fluid.dygraph.save_dygraph(netD.state_dict(),
                                       os.path.join(outdir, "DA"))
            fluid.dygraph.save_dygraph(netG.state_dict(),
                                       os.path.join(outdir, "GA"))
            fluid.dygraph.save_dygraph(optimizerD.state_dict(),
                                       os.path.join(outdir, "DA"))
            fluid.dygraph.save_dygraph(optimizerG.state_dict(),
                                       os.path.join(outdir, "GA"))
            # Noise amplitude for the next scale is proportional to this
            # scale's final reconstruction RMSE.
            opt.noise_amp = opt.noise_amp_init * RMSE
            netD.eval()
            netD_arrs.append(netD)
            netG.eval()
            netG_arrs.append(netG)
            priors.append(outG_fake.numpy())
            prior_recons.append(outG_fake_rec.numpy())
            noiseamp_arrs.append(opt.noise_amp)