Example no. 1
    def training_step(self, batch, batch_idx, optimizer_idx):
        images = batch

        # gp
        apply_gp = self.iterations % self.gp_every == 0

        # train discriminator
        if optimizer_idx == 0:
            images = images.requires_grad_()

            # increase resolution
            if self.iterations % self.add_layers_iters == 0:
                if self.iterations != 0:
                    self.D.increase_resolution_()

                image_size = self.D.resolution.item()
                self.G.set_image_size(image_size)
                self.image_dataset.set_transforms(image_size)

            real_out = self.D(images)

            fake_images = self.generate_samples(self.batch_size)
            fake_out = self.D(fake_images.clone().detach())

            divergence = (F.relu(1 + real_out) + F.relu(1 - fake_out)).mean()
            loss_D = divergence

            if apply_gp:
                gp = gradient_penalty(images, real_out)
                self.last_loss_gp = get_item(gp)
                loss = loss_D + gp
            else:
                loss = loss_D

            self.last_loss_D = loss_D

            tqdm_dict = {'loss_D': loss_D}
            output = OrderedDict({
                'loss': loss,
                'progress_bar': tqdm_dict,
                'log': tqdm_dict
            })
            return output

        # train generator
        if optimizer_idx == 1:
            fake_images = self.generate_samples(self.batch_size)
            loss_G = self.D(fake_images).mean()

            self.last_loss_G = loss_G

            tqdm_dict = {'loss_G': loss_G}
            output = OrderedDict({
                'loss': loss_G,
                'progress_bar': tqdm_dict,
                'log': tqdm_dict
            })

            return output
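The gradient_penalty helper called above with an (images, real_out) signature is not shown in this example. As a point of reference only, a minimal sketch of a penalty with that calling convention, computed on the real batch, could look like the following; the weight constant and the reduction are assumptions, not the example's actual code.

import torch

def gradient_penalty(images, output, weight=10):
    # Hypothetical sketch, not the example's helper: penalize the norm of the
    # gradient of the discriminator output w.r.t. the (real) input images.
    gradients = torch.autograd.grad(outputs=output,
                                    inputs=images,
                                    grad_outputs=torch.ones_like(output),
                                    create_graph=True,
                                    retain_graph=True,
                                    only_inputs=True)[0]
    gradients = gradients.reshape(images.size(0), -1)
    return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()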
Example no. 2
    def train_step(self, x, y, update_generator=True):

        y_fake = self.generator_g(x)
        x_fake = self.generator_f(y_fake)

        for p in self.discriminator.parameters():
            p.requires_grad = False

        content_loss = self.content_criterion(x, x_fake)
        tv_loss = self.tv_criterion(y_fake)

        batch_size = x.size(0)
        pos_labels = torch.ones(batch_size, dtype=torch.uint8, device=x.device)
        neg_labels = torch.zeros(batch_size,
                                 dtype=torch.uint8,
                                 device=x.device)
        realism_generation_loss = self.realism_criterion(
            self.discriminator(y_fake), pos_labels)

        if update_generator:

            generator_loss = content_loss + 100.0 * tv_loss
            generator_loss += 5e-2 * realism_generation_loss

            self.g_optimizer.zero_grad()
            self.f_optimizer.zero_grad()
            generator_loss.backward()
            self.g_optimizer.step()
            self.f_optimizer.step()

        for p in self.discriminator.parameters():
            p.requires_grad = True

        targets = torch.cat([pos_labels, neg_labels], dim=0)
        is_real_real = self.discriminator(y)
        is_fake_real = self.discriminator(y_fake.detach())
        scores = torch.cat([is_real_real, is_fake_real], dim=0)
        discriminator_loss = self.realism_criterion(scores, targets)

        lambda_constant = 5.0
        gp = gradient_penalty(y, y_fake.detach(), self.discriminator)

        self.d_optimizer.zero_grad()
        (discriminator_loss + lambda_constant * gp).backward(retain_graph=True)
        self.d_optimizer.step()

        loss_dict = {
            'content': content_loss.item(),
            'tv': tv_loss.item(),
            'realism_generation': realism_generation_loss.item(),
            'discriminator': discriminator_loss.item(),
            'gradient_penalty': gp.item()
        }
        return loss_dict
Example no. 3
def train_one_epoch(gen, critic, opt_gen, opt_crt, scaler_gen, scaler_crt,
                    dataloader, metric_logger, dataset, step, alpha, device,
                    fixed_noise, epoch, stage):

    loop = tqdm(dataloader, leave=True)
    for batch_idx, real in enumerate(loop):
        real = real.to(device)
        cur_batch_size = real.shape[0]

        # Train critic: max E[critic(real)] - E[critic(fake)], i.e. min -(E[critic(real)] - E[critic(fake)])
        noise = torch.randn(cur_batch_size, cfg.Z_DIMENSION, 1, 1).to(device)
        with torch.cuda.amp.autocast():
            fake = gen(noise, alpha, step)
            critic_real = critic(real, alpha, step)
            critic_fake = critic(fake.detach(), alpha, step)
            gp = gradient_penalty(critic, real, fake, alpha, step, device=device)
            loss_critic = (
                    -(torch.mean(critic_real) - torch.mean(critic_fake))
                    + cfg.LAMBDA_GP * gp + (0.001 * torch.mean(critic_real ** 2))
            )
        opt_crt.zero_grad()
        scaler_crt.scale(loss_critic).backward()
        scaler_crt.step(opt_crt)
        scaler_crt.update()

        # Train generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        with torch.cuda.amp.autocast():
            gen_fake = critic(fake, alpha, step)
            loss_gen = -torch.mean(gen_fake)
        opt_gen.zero_grad()
        scaler_gen.scale(loss_gen).backward()
        scaler_gen.step(opt_gen)
        scaler_gen.update()

        # update alpha
        alpha += cur_batch_size / ((cfg.PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset))
        alpha = min(alpha, 1)

        if batch_idx % cfg.FREQ == 0:
            with torch.no_grad():
                fixed_fakes = gen(fixed_noise, alpha, step) * 0.5 + 0.5
                metric_logger.log(loss_critic, loss_gen)
                metric_logger.log_image(fixed_fakes, cfg.NUM_SAMPLES, epoch, stage)

        loop.set_postfix(
            gp=gp.item(),
            loss_critic=loss_critic.item(),
        )

    return alpha
Example no. 4
    def train_discriminator(self, x_real):
        with tf.GradientTape() as t:
            z = tf.random.normal(shape=(self.spec['batch_size'], self.latent_dim))
            x_fake = self.generative_net(z, training=True)

            x_real_d_logit = self.inference_net(x_real, training=True)
            x_fake_d_logit = self.inference_net(x_fake, training=True)

            x_real_d_loss, x_fake_d_loss = self.d_loss_fn(x_real_d_logit, x_fake_d_logit)
            gp = utils.gradient_penalty(partial(self.inference_net,
                                        training=True), x_real, x_fake)

            D_loss = (x_real_d_loss + x_fake_d_loss) + gp * self.spec['gradient_penalty']

        D_grad = t.gradient(D_loss, self.inference_net.trainable_variables)
        self.D_optimizer.apply_gradients(zip(D_grad, self.inference_net.trainable_variables))

        return {'d_loss': x_real_d_loss + x_fake_d_loss, 'gp': gp}
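Here utils.gradient_penalty receives a discriminator closure plus the real and fake batches. A minimal TF2 sketch of an interpolation-based penalty with that (f, real, fake) signature, assuming 4-D NHWC image batches, might look like the following; it is an illustration, not the repository's actual utility.

import tensorflow as tf

def gradient_penalty(f, real, fake):
    # Hypothetical sketch: WGAN-GP on random interpolations between real and
    # fake samples; `f` is the discriminator closed over training=True.
    alpha = tf.random.uniform([tf.shape(real)[0], 1, 1, 1], 0., 1.)
    inter = real + alpha * (fake - real)
    with tf.GradientTape() as tape:
        tape.watch(inter)
        pred = f(inter)
    grad = tape.gradient(pred, inter)
    norm = tf.norm(tf.reshape(grad, [tf.shape(grad)[0], -1]), axis=1)
    return tf.reduce_mean((norm - 1.) ** 2)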
Example no. 5
    def fit_single_scale(self, img: np.ndarray, target_size: Tuple[int, int],
                         steps: int) -> None:
        real = torch.nn.functional.interpolate(img,
                                               target_size,
                                               mode='bilinear',
                                               align_corners=True).float().to(
                                                   self.device)

        # paragraph below equation (5) in paper
        if self.last_rec is not None:
            # compute root mean squared error between upsampled version of rec from last scale and the real target of the current scale
            rmse = torch.sqrt(
                torch.nn.functional.mse_loss(
                    real,
                    torch.nn.functional.interpolate(self.last_rec,
                                                    target_size,
                                                    mode='bilinear',
                                                    align_corners=True)))
            self.rmses.insert(0, rmse.detach().cpu().numpy().item())
        print(self.rmses)

        self.logger.set_mode('training')
        for step in range(1, steps + 1):
            self.logger.new_step()

            # ====== train discriminator =======================================
            self.d_pyramid[0].zero_grad()

            # generate a fake image
            fake = self.forward_g_pyramid(target_size=target_size)
            # let the discriminator judge the fake image patches (without any gradient flow through the generator )
            d_fake = self.d_pyramid[0](fake.detach())

            # loss for fake images
            adv_d_fake_loss = torch.mean(d_fake)
            adv_d_fake_loss.backward()

            # let the discriminator judge the real image patches
            d_real = self.d_pyramid[0](real)

            # loss for real images
            adv_d_real_loss = (-1) * torch.mean(d_real)
            adv_d_real_loss.backward()

            # gradient penalty loss
            grad_penalty = gradient_penalty(
                self.d_pyramid[0], real, fake,
                self.device) * self.hypers['grad_penalty_weight']
            grad_penalty.backward()

            # make a step against the gradient
            self.d_optimizer.step()

            # ====== train generator ===========================================
            self.g_pyramid[0].zero_grad()

            # let the discriminator judge the fake image patches
            d_fake = self.d_pyramid[0](fake)

            # loss for fake images
            adv_g_loss = (-1) * torch.mean(d_fake)
            adv_g_loss.backward()

            # reconstruct original image with fixed z_init and else no noise
            rec = self.forward_g_pyramid(target_size=target_size,
                                         Z=[self.z_init] + [0] *
                                         (len(self.g_pyramid) - 1))

            # reconstruction loss
            rec_g_loss = torch.nn.functional.mse_loss(
                rec, real) * self.hypers['rec_loss_weight']
            rec_g_loss.backward()

            # make a step against the gradient
            self.g_optimizer.step()

            loss_dict = {
                'adv_d_fake_loss': adv_d_fake_loss,
                'adv_d_real_loss': adv_d_real_loss,
                'adv_g_loss': adv_g_loss,
                'rec_g_loss': rec_g_loss,
            }

            self.logger.log_losses(loss_dict)
            print(f'[{self.N - len(self.g_pyramid) + 1}: {step}|{steps}] -',
                  loss_dict)

            # increment counter for task progress tracking
            self.done_steps += 1
            self.update_celery_state()

        # save current reconstruction of this scale temporarily
        self.last_rec = rec

        self.logger.set_mode('sampling')
        for step in range(8):
            # sample from learned distribution
            fake = self.test(target_size=target_size)

            # log image
            self.logger.log_images(fake, name=str(step), dataformats='HWC')

            # save image traditionally to file
            run_dir = os.path.join('samples', self.logger.run_name)
            if not os.path.isdir(run_dir):
                os.makedirs(run_dir)
            imwrite(
                os.path.join(run_dir,
                             f'{self.N - len(self.g_pyramid) + 1}_{step}.jpg'),
                fake)

        # save reconstruction of real image and downscaled copy of real image
        imwrite(
            os.path.join(run_dir,
                         f'{self.N - len(self.g_pyramid) + 1}_rec.jpg'),
            self.transform_output(rec[0].detach().cpu().numpy().transpose(
                1, 2, 0)).astype('uint8'))
        imwrite(
            os.path.join(run_dir,
                         f'{self.N - len(self.g_pyramid) + 1}_real.jpg'),
            self.transform_output(real[0].detach().cpu().numpy().transpose(
                1, 2, 0)).astype('uint8'))
if AE_type == "VAE":
    kld_loss_a = -tf.reduce_mean(0.5 * (1 + c_log_sigma_sq_a - c_mu_a**2 - tf.exp(c_log_sigma_sq_a)))
    kld_loss_b = -tf.reduce_mean(0.5 * (1 + c_log_sigma_sq_b - c_mu_b**2 - tf.exp(c_log_sigma_sq_b)))
    kld_loss = kld_loss_a + kld_loss_b
else:
    kld_loss_a = tf.constant(0.)
    kld_loss_b = tf.constant(0.)
    kld_loss = tf.constant(0.)

adv_loss = tf.losses.mean_squared_error(tf.reduce_mean(Disc_a), tf.reduce_mean(Disc_b))

G_loss = rec_loss + kld_loss * beta + adv_loss * gamma

# D loss components
wd_loss = tf.reduce_mean(Disc_b) - tf.reduce_mean(Disc_a)
gp_loss = utils.gradient_penalty(c_a, c_b, Disc)


D_loss = wd_loss + gp_loss * delta


# optimizers
d_vars = utils.trainable_variables('discriminator')
g_vars = utils.trainable_variables(['Encoder', 'Decoder_a', 'Decoder_b'])

# LR decay policy
iters_per_epoch = int(min_n/batch_size)
decay_steps = iters_per_epoch * decay_epochs
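gp_loss above is computed from the two latent codes and the shared discriminator in TF1 graph mode. Below is a minimal tf.gradients-based sketch matching the (real, fake, f) call order of utils.gradient_penalty(c_a, c_b, Disc); it is an assumption about the shape of that utility, not the project's code.

import tensorflow as tf

def gradient_penalty(real, fake, f):
    # Hypothetical TF1-style sketch: interpolate between the two code batches
    # and penalize the discriminator's gradient norm at the interpolations.
    shape = [tf.shape(real)[0]] + [1] * (real.shape.ndims - 1)
    alpha = tf.random_uniform(shape, minval=0., maxval=1.)
    x_hat = real + alpha * (fake - real)
    pred = f(x_hat)
    grad = tf.gradients(pred, x_hat)[0]
    norm = tf.norm(tf.reshape(grad, [tf.shape(grad)[0], -1]), axis=1)
    return tf.reduce_mean(tf.square(norm - 1.))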


Example no. 7
def train_one_epoch_with_gp(gen, opt_gen, crt, opt_crt, dataloader,
                            metric_logger, device, fixed_noise, epoch,
                            fid_model, fid_score):
    """
    Train one epoch with wasserstein distance with gradient penalty
    :param gen: ``Instance of models.model.Generator``, generator model
    :param opt_gen: ``Instance of torch.optim``, optimizer for generator
    :param crt: ``Instance of models.model.Discriminator``, discriminator model
    :param opt_dis: ``Instance of torch.optim``, optimizer for discriminator
    :param dataloader: ``Instance of torch.utils.data.DataLoader``, train dataloader
    :param metric_logger: ``Instance of metrics.MetricLogger``, logger
    :param device: ``Instance of torch.device``, cuda device
    :param fixed_noise: ``Tensor([N, 1, Z, 1, 1])``, fixed noise for display result
    :param epoch: ``int``, current epoch
    :param fid_model: model for calculate fid score
    :param fid_score: ``bool``, if True - calculate fid score otherwise no
    """
    loop = tqdm(dataloader, leave=True)
    for batch_idx, real in enumerate(loop):
        cur_batch_size = real.shape[0]
        real = real.to(device)
        real.requires_grad_()
        real_cropped_128 = center_crop(real, size=(128, 128))
        real_128 = resize(real, size=(128, 128))
        # Train critic
        for _ in range(cfg.CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, cfg.Z_DIMENSION, 1,
                                1).to(device)
            fake = gen(noise)
            real_fake_logits_real_images, decoded_real_img_part, decoded_real_img = crt(
                DiffAugment(real, policy=cfg.DIFF_AUGMENT_POLICY))
            real_fake_logits_fake_images, _, _ = crt(
                DiffAugment(fake, policy=cfg.DIFF_AUGMENT_POLICY))

            divergence = hinge_adv_loss(real_fake_logits_real_images,
                                        real_fake_logits_fake_images, device)
            i_recon_loss = reconstruction_loss_mse(real_128, decoded_real_img)
            i_part_recon_loss = reconstruction_loss_mse(
                real_cropped_128, decoded_real_img_part)
            gp = gradient_penalty(crt, real, fake, device)
            crt_loss = (divergence + i_recon_loss +
                        i_part_recon_loss) + cfg.LAMBDA_GP * gp
            crt.zero_grad()
            crt_loss.backward(retain_graph=True)
            opt_crt.step()

        # Train generator
        fake_logits, _, _ = crt(fake)
        g_loss = -torch.mean(fake_logits)
        gen.zero_grad()
        g_loss.backward()
        opt_gen.step()

        # Eval and metrics
        if fid_score:
            if epoch % cfg.FID_FREQ == 0:
                # TODO: test fid on GPU
                fid = evaluate(gen, fid_model, device)
                gen.train()
                metric_logger.log_fid(fid)
        if batch_idx % cfg.LOG_FREQ == 0:
            metric_logger.log(g_loss, crt_loss, divergence, i_recon_loss,
                              i_part_recon_loss)
        if batch_idx % cfg.LOG_IMAGE_FREQ == 0:
            with torch.no_grad():
                fixed_fakes = gen(fixed_noise)
                metric_logger.log_image(fixed_fakes,
                                        cfg.NUM_SAMPLES_IMAGES,
                                        epoch,
                                        batch_idx,
                                        normalize=True)

        loop.set_postfix(d_loss=crt_loss.item(), g_loss=g_loss.item())
Example no. 8
critic.train()

for epoch in range(NUM_EPOCHS):
    # Target labels not needed! <3 unsupervised
    for batch_idx, (real, _) in enumerate(loader):
        real = real.to(device)
        cur_batch_size = real.shape[0]

        # Train Critic: max E[critic(real)] - E[critic(fake)]
        # equivalent to minimizing the negative of that
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(device)
            fake = gen(noise)
            critic_real = critic(real).reshape(-1)
            critic_fake = critic(fake).reshape(-1)
            gp = gradient_penalty(critic, real, fake, device=device)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake)) +
                LAMBDA_GP * gp)
            critic.zero_grad()
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        gen_fake = critic(fake).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and print to tensorboard
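For reference, the gradient_penalty(critic, real, fake, device=device) helper used in this loop is not shown. A minimal sketch of the standard WGAN-GP interpolation penalty with that signature is given below; treat it as an illustration rather than the repository's implementation.

import torch

def gradient_penalty(critic, real, fake, device="cpu"):
    # Hypothetical sketch: score random interpolations of real and fake
    # images and push the critic's gradient norm towards 1.
    batch_size = real.size(0)
    eps = torch.rand(batch_size, 1, 1, 1, device=device)
    interpolated = eps * real + (1 - eps) * fake
    interpolated.requires_grad_(True)

    mixed_scores = critic(interpolated)
    gradients = torch.autograd.grad(outputs=mixed_scores,
                                    inputs=interpolated,
                                    grad_outputs=torch.ones_like(mixed_scores),
                                    create_graph=True,
                                    retain_graph=True)[0]
    gradients = gradients.view(batch_size, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()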
Example no. 9
        out_src, out_cls = discriminator(real.float().cuda(),
                                         t2.float().cuda())
        d_loss_real = -torch.mean(out_src.sum([1, 2, 3]))
        d_loss_cls = classification_loss(out_cls, label)

        ############################################## discriminator
        fake = generator(t2.float().cuda(), label)
        out_src, out_cls = discriminator(fake.detach(), t2.float().cuda())
        d_loss_fake = torch.mean(out_src.sum([1, 2, 3]))

        # Compute loss for gradient penalty.
        alpha = torch.rand(real.size(0), 1, 1, 1).cuda()
        x_hat = (alpha * real.cuda().data +
                 (1 - alpha) * fake.data).requires_grad_(True)
        out_src, _ = discriminator(x_hat, t2.float().cuda())
        d_loss_gp = gradient_penalty(out_src, x_hat)
        #         d_loss_gp.backward(retain_graph=True)
        d_loss = d_loss_real + d_loss_fake + LAMBDA_CLS * d_loss_cls + LAMBDA_GP * d_loss_gp
        optimizer_d.zero_grad()
        d_loss.backward()
        optimizer_d.step()
        ############################################## generator
        fake = generator(t2.float().cuda(), label)
        out_src, out_cls = discriminator(fake, t2.float().cuda())
        g_loss_fake = -torch.mean(out_src.sum([1, 2, 3]))
        g_loss_cls = classification_loss(out_cls, label)
        g_loss_rec = torch.mean(
            torch.abs(real.float().cuda() - fake).sum([1, 2, 3]))
        pred = unet(fake, label)
        g_loss_seg = dice_loss(pred, seg.float().cuda())
        g_loss = g_loss_fake + LAMBDA_CLS * g_loss_cls + g_loss_rec * LAMBDA_REC + g_loss_seg * LAMBDA_SEG
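In this StarGAN-style snippet the interpolation x_hat is built inline, so gradient_penalty only receives the critic output and the interpolated input. One common implementation with that (out_src, x_hat) signature is sketched below; it is an assumption, not necessarily the code used here.

import torch

def gradient_penalty(out_src, x_hat):
    # Hypothetical sketch: gradient of the critic output w.r.t. the
    # interpolated images, penalized towards unit L2 norm.
    weight = torch.ones_like(out_src)
    dydx = torch.autograd.grad(outputs=out_src,
                               inputs=x_hat,
                               grad_outputs=weight,
                               retain_graph=True,
                               create_graph=True,
                               only_inputs=True)[0]
    dydx = dydx.view(dydx.size(0), -1)
    dydx_l2norm = torch.sqrt(torch.sum(dydx ** 2, dim=1))
    return torch.mean((dydx_l2norm - 1) ** 2)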
Example no. 10
writer_real = SummaryWriter(f'runs/GAN_MNIST/real')
step = 0

# Training
for epoch in range(num_epochs):
    for batch_idx, (real, labels) in enumerate(loader):
        real = real.to(device)
        labels = labels.to(device)

        # For Critic
        for _ in range(critic_n):
            noise = torch.randn((batch_size, z_dim, 1, 1)).to(device)
            fake = gen(noise, labels)
            critic_real = critic(real, labels)
            critic_fake = critic(fake, labels)
            gp = gradient_penalty(critic, real, labels, fake, device)
            # gradient penalty is added outside the negated Wasserstein term
            loss_critic = (-(torch.mean(critic_real) - torch.mean(critic_fake)) +
                           lambda_gp * gp)

            critic.zero_grad()
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

        # For Generator
        output = critic(fake, labels).view(-1)
        loss_gen = -torch.mean(output)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        if batch_idx == 0:
Example no. 11
def train_fn(
    critic,
    gen,
    loader,
    dataset,
    step,
    alpha,
    opt_critic,
    opt_gen,
    tensorboard_step,
    writer,
):
    start = time.time()
    total_time = 0
    training = tqdm(loader, leave=True)
    for batch_idx, (real, _) in enumerate(training):
        real = real.to(device)
        cur_batch_size = real.shape[0]
        model_start = time.time()

        # Train Critic: max E[critic(real)] - E[critic(fake)]
        # which is equivalent to minimizing the negative of the expression
        for _ in range(CRITIC_ITERATIONS):
            critic.zero_grad()
            noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(device)
            fake = gen(noise, alpha, step)
            critic_real = critic(real, alpha, step).reshape(-1)
            critic_fake = critic(fake, alpha, step).reshape(-1)
            gp = gradient_penalty(critic,
                                  real,
                                  fake,
                                  alpha,
                                  step,
                                  device=device)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake)) +
                LAMBDA_GP * gp)
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        gen.zero_grad()
        fake = gen(noise, alpha, step)
        gen_fake = critic(fake, alpha, step).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        loss_gen.backward()
        opt_gen.step()

        # Update alpha and ensure less than 1
        alpha += cur_batch_size / (
            (PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)  # - step
        )
        alpha = min(alpha, 1)
        total_time += time.time() - model_start

        if batch_idx % 300 == 0:
            with torch.no_grad():
                fixed_fakes = gen(fixed_noise, alpha, step)
            plot_to_tensorboard(writer, loss_critic, loss_gen, real,
                                fixed_fakes, tensorboard_step)
            tensorboard_step += 1

    print(
        f'Fraction spent on model training: {total_time/(time.time()-start)}')
    return tensorboard_step, alpha
Example no. 12
def train_fn(
    critic,
    gen,
    loader,
    dataset,
    step,
    alpha,
    opt_critic,
    opt_gen,
    tensorboard_step,
    writer,
    scaler_gen,
    scaler_critic,
):
    start = time.time()
    total_time = 0
    loop = tqdm(loader, leave=True)
    losses_critic = []

    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(config.DEVICE)
        cur_batch_size = real.shape[0]
        model_start = time.time()

        for _ in range(4):
            # Train Critic: max E[critic(real)] - E[critic(fake)]
            # which is equivalent to minimizing the negative of the expression
            for _ in range(config.CRITIC_ITERATIONS):
                noise = torch.randn(cur_batch_size, config.Z_DIM, 1, 1).to(config.DEVICE)

                with torch.cuda.amp.autocast():
                    fake = gen(noise, alpha, step)
                    critic_real = critic(real, alpha, step).reshape(-1)
                    critic_fake = critic(fake, alpha, step).reshape(-1)
                    gp = gradient_penalty(critic, real, fake, alpha, step, device=config.DEVICE)
                    loss_critic = (
                        -(torch.mean(critic_real) - torch.mean(critic_fake))
                        + config.LAMBDA_GP * gp
                    )

                losses_critic.append(loss_critic.item())
                opt_critic.zero_grad()
                scaler_critic.scale(loss_critic).backward()
                scaler_critic.step(opt_critic)
                scaler_critic.update()
                #loss_critic.backward(retain_graph=True)
                #opt_critic.step()

            # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
            with torch.cuda.amp.autocast():
                fake = gen(noise, alpha, step)
                gen_fake = critic(fake, alpha, step).reshape(-1)
                loss_gen = -torch.mean(gen_fake)

            opt_gen.zero_grad()
            scaler_gen.scale(loss_gen).backward()
            scaler_gen.step(opt_gen)
            scaler_gen.update()
            #gen.zero_grad()
            #loss_gen.backward()
            #opt_gen.step()

        # Update alpha and ensure less than 1
        alpha += cur_batch_size / (
            (config.PROGRESSIVE_EPOCHS[step]*0.5) * len(dataset) # - step
        )
        alpha = min(alpha, 1)
        total_time += time.time()-model_start

        if batch_idx % 10 == 0:
            print(alpha)
            with torch.no_grad():
                fixed_fakes = gen(config.FIXED_NOISE, alpha, step)
            plot_to_tensorboard(
                writer, loss_critic, loss_gen, real, fixed_fakes, tensorboard_step
            )
            tensorboard_step += 1

        mean_loss = sum(losses_critic) / len(losses_critic)
        loop.set_postfix(loss=mean_loss)

    print(f'Fraction spent on model training: {total_time/(time.time()-start)}')
    return tensorboard_step, alpha
Example no. 13
    def build_graph_with_losses(self,
                                batch_data,
                                cfg,
                                summary=True,
                                reuse=None):
        # batch_pos = batch_data / 127.5 - 1
        batch_pos = batch_data
        bbox = random_bbox(cfg)
        mask = bbox2mask(bbox, cfg)

        batch_incomplete = batch_pos * (1. - mask)
        ones_x = tf.ones_like(batch_incomplete)[:, :, :, 0:1]
        coarse_network_input = tf.concat(
            [batch_incomplete, ones_x, ones_x * mask], axis=3)

        coarse_output = self.coarse_network(coarse_network_input, reuse)
        batch_complete_coarse = coarse_output * mask + batch_incomplete * (
            1. - mask)

        refine_network_input = tf.concat(
            [batch_complete_coarse, ones_x, ones_x * mask], axis=3)
        refine_output = self.refine_network(refine_network_input, reuse)
        batch_complete_refine = refine_output * mask + batch_incomplete * (
            1. - mask)

        losses = {}

        # local patches
        local_patch_pos = local_patch(batch_pos, bbox)
        local_patch_coarse = local_patch(coarse_output, bbox)
        local_patch_refine = local_patch(refine_output, bbox)
        local_patch_mask = local_patch(mask, bbox)

        l1_alpha = cfg['coarse_l1_alpha']
        losses['coarse_l1_loss'] = l1_alpha * tf.reduce_mean(
            tf.abs(local_patch_pos - local_patch_coarse) *
            spatial_discounting_mask(cfg))
        losses['coarse_ae_loss'] = l1_alpha * tf.reduce_mean(
            tf.abs(batch_pos - coarse_output) * (1. - mask))

        losses['refine_l1_loss'] = losses['coarse_l1_loss'] + \
            tf.reduce_mean(tf.abs(local_patch_pos - local_patch_refine) *
                           spatial_discounting_mask(cfg))
        losses['refine_ae_loss'] = losses['coarse_ae_loss'] + \
            tf.reduce_mean(tf.abs(batch_pos - refine_output) * (1. - mask))

        losses['coarse_ae_loss'] /= tf.reduce_mean(1. - mask)
        losses['refine_ae_loss'] /= tf.reduce_mean(1. - mask)

        # wgan
        # global discriminator patch
        batch_pos_neg = tf.concat([batch_pos, batch_complete_refine], axis=0)

        # local discriminator patch
        local_patch_pos_neg = tf.concat([local_patch_pos, local_patch_refine],
                                        axis=0)

        # wgan with gradient penalty
        pos_neg_global, pos_neg_local = self.build_wgan_discriminator(
            batch_pos_neg, local_patch_pos_neg, reuse)

        pos_global, neg_global = tf.split(pos_neg_global, 2)
        pos_local, neg_local = tf.split(pos_neg_local, 2)

        # wgan loss
        g_loss_global, d_loss_global = gan_wgan_loss(pos_global, neg_global)
        g_loss_local, d_loss_local = gan_wgan_loss(pos_local, neg_local)

        losses['refine_g_loss'] = cfg[
            'global_wgan_loss_alpha'] * g_loss_global + g_loss_local
        losses['refine_d_loss'] = d_loss_global + d_loss_local

        # gradient penalty
        interpolates_global = random_interpolates(batch_pos,
                                                  batch_complete_refine)
        interpolates_local = random_interpolates(local_patch_pos,
                                                 local_patch_refine)
        dout_global, dout_local = self.build_wgan_discriminator(
            interpolates_global, interpolates_local, reuse=True)

        # apply penalty
        penalty_global = gradient_penalty(interpolates_global,
                                          dout_global,
                                          mask=mask,
                                          norm=750.)
        penalty_local = gradient_penalty(interpolates_local,
                                         dout_local,
                                         mask=local_patch_mask,
                                         norm=750.)

        losses['gp_loss'] = cfg['wgan_gp_lambda'] * (penalty_global +
                                                     penalty_local)
        losses['refine_d_loss'] += losses['gp_loss']

        g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'coarse') +\
            tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'refine')
        d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'wgan_discriminator')

        if summary:
            # stage1
            tf.summary.scalar(
                'rec_loss/coarse_rec_loss',
                losses['coarse_l1_loss'] + losses['coarse_ae_loss'])
            # tf.summary.scalar('rec_loss/coarse_l1_loss', losses['coarse_l1_loss'])
            # tf.summary.scalar('rec_loss/coarse_ae_loss', losses['coarse_ae_loss'])
            tf.summary.scalar(
                'rec_loss/refine_rec_loss',
                losses['refine_l1_loss'] + losses['refine_ae_loss'])
            # tf.summary.scalar('rec_loss/refine_l1_loss', losses['refine_l1_loss'])
            # tf.summary.scalar('rec_loss/refine_ae_loss', losses['refine_ae_loss'])

            visual_img = [
                batch_pos, batch_incomplete, batch_complete_coarse,
                batch_complete_refine
            ]
            visual_img = tf.concat(visual_img, axis=2)
            images_summary(visual_img, 'raw_incomplete_coarse_refine', 4)

            # stage2
            gradients_summary(g_loss_global,
                              refine_output,
                              name='g_loss_global')
            gradients_summary(g_loss_local, refine_output, name='g_loss_local')

            tf.summary.scalar('convergence/refine_d_loss',
                              losses['refine_d_loss'])
            # tf.summary.scalar('convergence/refine_g_loss', losses['refine_g_loss'])
            tf.summary.scalar('convergence/local_d_loss', d_loss_local)
            tf.summary.scalar('convergence/global_d_loss', d_loss_global)

            tf.summary.scalar('gradient_penalty/gp_loss', losses['gp_loss'])
            tf.summary.scalar('gradient_penalty/gp_penalty_local',
                              penalty_local)
            tf.summary.scalar('gradient_penalty/gp_penalty_global',
                              penalty_global)

            # summary the magnitude of gradients from different losses w.r.t. predicted image
            # gradients_summary(losses['g_loss'], refine_output, name='g_loss')
            gradients_summary(losses['coarse_l1_loss'] +
                              losses['coarse_ae_loss'],
                              coarse_output,
                              name='rec_loss_grad_to_coarse')
            gradients_summary(losses['refine_l1_loss'] +
                              losses['refine_ae_loss'] +
                              losses['refine_g_loss'],
                              refine_output,
                              name='rec_loss_grad_to_refine')
            gradients_summary(losses['coarse_l1_loss'],
                              coarse_output,
                              name='l1_loss_grad_to_coarse')
            gradients_summary(losses['refine_l1_loss'],
                              refine_output,
                              name='l1_loss_grad_to_refine')
            gradients_summary(losses['coarse_ae_loss'],
                              coarse_output,
                              name='ae_loss_grad_to_coarse')
            gradients_summary(losses['refine_ae_loss'],
                              refine_output,
                              name='ae_loss_grad_to_refine')

        return g_vars, d_vars, losses
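The penalty here takes precomputed interpolates and discriminator outputs plus mask and norm arguments. A minimal sketch of a masked penalty with that signature is shown below; how mask and norm enter the formula is an assumption based on common inpainting WGAN-GP code, not necessarily this project's utility.

import tensorflow as tf

def gradient_penalty(x, y, mask=None, norm=1.):
    # Hypothetical sketch: penalize only the gradient components inside the
    # mask and measure the slope against the target value `norm`.
    gradients = tf.gradients(y, x)[0]
    if mask is None:
        mask = tf.ones_like(gradients)
    slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients) * mask, axis=[1, 2, 3]))
    return tf.reduce_mean(tf.square(slopes - norm))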
Example no. 14
def train_fn(
    loader,
    disc,
    gen,
    opt_gen,
    opt_disc,
    l1,
    vgg_loss,
    g_scaler,
    d_scaler,
    writer,
    tb_step,
):
    loop = tqdm(loader, leave=True)

    for idx, (low_res, high_res) in enumerate(loop):
        high_res = high_res.to(config.DEVICE)
        low_res = low_res.to(config.DEVICE)

        with torch.cuda.amp.autocast():
            fake = gen(low_res)
            critic_real = disc(high_res)
            critic_fake = disc(fake.detach())
            gp = gradient_penalty(disc, high_res, fake, device=config.DEVICE)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake)) +
                config.LAMBDA_GP * gp)

        opt_disc.zero_grad()
        d_scaler.scale(loss_critic).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z)))
        with torch.cuda.amp.autocast():
            l1_loss = 1e-2 * l1(fake, high_res)
            adversarial_loss = 5e-3 * -torch.mean(disc(fake))
            loss_for_vgg = vgg_loss(fake, high_res)
            gen_loss = l1_loss + loss_for_vgg + adversarial_loss

        opt_gen.zero_grad()
        g_scaler.scale(gen_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        writer.add_scalar("Critic loss",
                          loss_critic.item(),
                          global_step=tb_step)
        tb_step += 1

        if idx % 100 == 0 and idx > 0:
            plot_examples("test_images/", gen)

        loop.set_postfix(
            gp=gp.item(),
            critic=loss_critic.item(),
            l1=l1_loss.item(),
            vgg=loss_for_vgg.item(),
            adversarial=adversarial_loss.item(),
        )

    return tb_step
Example no. 15
    def train(self,
              trainloader,
              testloader,
              inception,
              N,
              batch_size,
              epoch,
              c_iter=5,
              topk=5,
              acc=True,
              test_size=1000,
              test_step=5,
              img_step=5):

        fid_total = []
        real_acc_total = []
        fake_acc_total = []
        gen_loss_total = []
        disc_loss_total = []

        loop_per_epoch = N // (batch_size * c_iter)
        for e in range(epoch):
            gen_avg_loss = 0.0
            disc_avg_loss = 0.0
            start = time.time()
            iterator = iter(trainloader)

            for i in tqdm(range(loop_per_epoch)):
                for c in range(c_iter):
                    self.D_optim.zero_grad()
                    # train D with real data
                    x_real, _ = next(iterator)
                    x_real = x_real.to(self.device)
                    d_real_loss = -self.D(x_real).mean()

                    # train D with fake data
                    z = torch.randn(batch_size, self.z_dim, device=self.device)
                    x_fake = self.G(z)
                    d_fake_loss = self.D(x_fake).mean()

                    # gradient penalty loss to satisfy Lipschitz condition
                    d_grad_loss = gradient_penalty(self.D, x_real, x_fake, 1.0,
                                                   self.device)

                    d_loss = d_real_loss + d_fake_loss + d_grad_loss
                    d_loss.backward()
                    self.D_optim.step()
                    disc_avg_loss += d_loss.item()

                # train G
                for p in self.D.parameters():
                    p.requires_grad = False
                self.G_optim.zero_grad()

                z = torch.randn(batch_size, self.z_dim, device=self.device)
                x_fake = self.G(z)
                g_loss = -self.D(x_fake).mean()

                g_loss.backward()
                self.G_optim.step()
                gen_avg_loss += g_loss.item()

                for p in self.D.parameters():
                    p.requires_grad = True

            finish = time.time()
            time_elapsed = finish - start

            gen_loss_total.append(gen_avg_loss / loop_per_epoch)
            disc_loss_total.append(disc_avg_loss / (loop_per_epoch * c_iter))
            print(
                "Epoch: %d\tdisc loss: %.5f\tgen loss: %.5f\ttime elapsed: %.3f"
                % (e + 1, disc_loss_total[-1], gen_loss_total[-1], time_elapsed))

            if (e + 1) == 1:
                eta = time_elapsed * epoch
                finish = time.asctime(time.localtime(time.time() + eta))
                print("### set your alarm at:", finish, "###")

            # save sample images
            if (e + 1) % img_step == 0:
                self.save_images(n=100,
                                 filename=os.path.join(
                                     self.out_dir, "{0}.png".format(e + 1)))

            # test
            if (e + 1) % test_step == 0:
                fid, nn_real, nn_fake = self.test(testloader, test_size,
                                                  inception)
                fid_total.append(fid)
                real_acc_total.append(nn_real)
                fake_acc_total.append(nn_fake)
                print("FID: %.5f\tReal acc: %.5f\tFake acc: %.5f" %
                      (fid, nn_real, nn_fake))

            # save statistics
            if not os.path.exists(self.out_dir):
                os.makedirs(self.out_dir)

            self.G.eval()
            self.D.eval()
            np.save(os.path.join(self.out_dir, "fid.npy"), fid_total)
            np.save(os.path.join(self.out_dir, "ra.npy"), real_acc_total)
            np.save(os.path.join(self.out_dir, "fa.npy"), fake_acc_total)
            np.save(os.path.join(self.out_dir, "gloss.npy"), gen_loss_total)
            np.save(os.path.join(self.out_dir, "dloss.npy"), disc_loss_total)
            torch.save(self.G.cpu().state_dict(),
                       os.path.join(self.out_dir, "g.pth"))
            torch.save(self.D.cpu().state_dict(),
                       os.path.join(self.out_dir, "d.pth"))
            self.G.to(self.device)
            self.D.to(self.device)
            self.G.train()
            self.D.train()

        plt.plot(fake_acc_total)
        plt.plot(real_acc_total)
        plt.plot((np.array(fake_acc_total) + np.array(real_acc_total)) * 0.5,
                 "--")
        plt.legend(["fake acc.", "real acc.", "total acc."])
        pp = PdfPages(os.path.join(self.out_dir, "accuracy.pdf"))
        pp.savefig()
        pp.close()
        plt.close()

        plt.plot(disc_loss_total)
        plt.plot(gen_loss_total)
        plt.legend(["disc. loss", "gen. loss"])
        pp = PdfPages(os.path.join(self.out_dir, "loss.pdf"))
        pp.savefig()
        pp.close()
        plt.close()

        plt.plot(fid_total)
        pp = PdfPages(os.path.join(self.out_dir, "fid.pdf"))
        pp.savefig()
        pp.close()
        plt.close()
Example no. 16
    # auto-encoders and its regularization
    R_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=_Xr_logit,
            labels=X))  # X needs to be normalized to the range [0, 1]
    #R_loss = tf.reduce_mean(tf.square(_Xr_prob - X)) # another option

    f = tf.reduce_mean(_Xr_prob - _X_prob)
    g = tf.reduce_mean(_ze - z)
    R_reg = tf.square(f - g)

    # Gradient penalty
    beta = tf.random_uniform(shape=[tf.shape(_ze)[0], 1], minval=0, maxval=1)
    _X_i = (1. - beta) * X + beta * _X_prob
    gp = gradient_penalty(_X_i, theta_D)

    # Discriminator loss on data
    d_loss_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=D_real_logit,
                                                labels=tf.ones_like(D_real)))
    d_loss_recon = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=D_recon_logit,
                                                labels=tf.zeros_like(D_recon)))
    d_loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logit,
                                                labels=tf.zeros_like(D_fake)))
    d_loss_inter = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=D_inter_logit,
                                                labels=tf.zeros_like(D_inter)))
Example no. 17
    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int,
                      optimizer_idx: int) -> Tensor:
        """A single training step on a batch from the dataloader."""
        opt_g, opt_d = self.optimizers()

        real_images, real_labels = batch
        batch_size = real_images.size(0)
        # Assign each image a label from another random image
        fake_labels = real_labels[torch.randperm(batch_size)]

        # Train discriminator
        pred_real_error, pred_real_labels = self.discriminator(real_images)

        # Train discriminator to classify image labels
        class_loss = self.classification_loss(pred_real_labels, real_labels)

        fake_images = self.generator(real_images, fake_labels)
        pred_fake_error, pred_fake_labels = self.discriminator(fake_images)

        # Train discriminator to recognize real patches
        real_loss = -pred_real_error.mean()
        # Train discriminator to recognize fake patches
        fake_loss = pred_fake_error.mean()

        # Create random noise image sample
        alpha = torch.rand(batch_size, 1, 1, 1, device=real_images.device)
        alpha = alpha.expand_as(real_images)

        # Create images distanced between the real set and the fake set
        # selecting random points between a real image and a fake image
        interpolated = (alpha * real_images.detach() +
                        (1 - alpha) * fake_images.detach())
        interpolated.requires_grad_(True)

        prob_interpolated, _ = self.discriminator(interpolated)
        # Push the gradients towards one to prevent vanishing gradients
        grad_loss = gradient_penalty(prob_interpolated, interpolated)

        # Scale and sum up losses
        class_loss *= self.config.class_weight
        grad_loss *= self.config.grad_weight
        loss = class_loss + real_loss + fake_loss + grad_loss

        self.log('discriminator loss', loss)
        self.log('discriminator classification loss', class_loss)
        self.log('discriminator real images loss', real_loss)
        self.log('discriminator fake images loss', fake_loss)
        self.log('discriminator gradient loss', grad_loss)

        # Update discriminator weights
        opt_d.zero_grad()
        self.manual_backward(loss, opt_d)
        opt_d.step()

        if (batch_idx + 1) % self.config.n_critic == 0:
            # Train generator
            fake_images = self.generator(real_images, fake_labels)
            pred_fake_error, pred_fake_labels = self.discriminator(fake_images)

            # Train generator to make images that match target labels
            class_loss = self.classification_loss(pred_fake_labels, fake_labels)

            # Train generator to make realistic images
            fake_loss = -pred_fake_error.mean()

            # Reconstruct original image with generated image and original labels
            recon_images = self.generator(fake_images, real_labels)
            # Train generator to make resulting images similar to original
            recon_loss = F.l1_loss(recon_images, real_images)

            # Scale and sum up losses
            class_loss *= self.config.class_weight
            recon_loss *= self.config.recon_weight
            loss = fake_loss + class_loss + recon_loss

            self.log('generator loss', loss)
            self.log('generator realism loss', fake_loss)
            self.log('generator classification loss', class_loss)
            self.log('generator reconstruction loss', recon_loss)

            # Update generator weights
            opt_g.zero_grad()
            self.manual_backward(loss, opt_g)
            opt_g.step()
Example no. 18
def train():
    epoch = 200
    batch_size = 1
    lr = 0.0002
    crop_size = 128
    load_size = 128
    tar_db_a = "cameron_images.tgz"
    tar_db_b = "teresa_images.tgz"
    db_a_i = importer_tar.Importer(tar_db_a)
    db_b_i = importer_tar.Importer(tar_db_b)
    image_a_names = db_a_i.get_sorted_image_name()
    image_b_names = db_b_i.get_sorted_image_name()
    train_a_size = int(len(image_a_names) * 0.8)
    train_b_size = int(len(image_b_names) * 0.8)

    image_a_train_names = image_a_names[0:train_a_size]
    image_b_train_names = image_b_names[0:train_b_size]

    image_a_test_names = image_a_names[train_a_size:]
    image_b_test_names = image_b_names[train_b_size:]

    print("A train size:{},test size:{}".format(len(image_a_train_names),
                                                len(image_a_test_names)))
    print("B train size:{},test size:{}".format(len(image_b_train_names),
                                                len(image_b_test_names)))
    """ graph """
    # models
    generator_a2b = partial(models.generator, scope='a2b')
    generator_b2a = partial(models.generator, scope='b2a')
    discriminator_a = partial(models.discriminator, scope='a')
    discriminator_b = partial(models.discriminator, scope='b')

    # operations
    a_real_in = tf.placeholder(tf.float32,
                               shape=[None, load_size, load_size, 3],
                               name="a_real")
    b_real_in = tf.placeholder(tf.float32,
                               shape=[None, load_size, load_size, 3],
                               name="b_real")
    a_real = utils.preprocess_image(a_real_in, crop_size=crop_size)
    b_real = utils.preprocess_image(b_real_in, crop_size=crop_size)

    a2b = generator_a2b(a_real)
    b2a = generator_b2a(b_real)
    b2a2b = generator_a2b(b2a)
    a2b2a = generator_b2a(a2b)

    a_logit = discriminator_a(a_real)
    b2a_logit = discriminator_a(b2a)
    b_logit = discriminator_b(b_real)
    a2b_logit = discriminator_b(a2b)

    # losses
    g_loss_a2b = -tf.reduce_mean(a2b_logit)
    g_loss_b2a = -tf.reduce_mean(b2a_logit)

    cyc_loss_a = tf.losses.absolute_difference(a_real, a2b2a)
    cyc_loss_b = tf.losses.absolute_difference(b_real, b2a2b)
    g_loss = g_loss_a2b + g_loss_b2a + cyc_loss_a * 10.0 + cyc_loss_b * 10.0

    wd_a = tf.reduce_mean(a_logit) - tf.reduce_mean(b2a_logit)
    wd_b = tf.reduce_mean(b_logit) - tf.reduce_mean(a2b_logit)

    gp_a = gradient_penalty(a_real, b2a, discriminator_a)
    gp_b = gradient_penalty(b_real, a2b, discriminator_b)

    d_loss_a = -wd_a + 10.0 * gp_a
    d_loss_b = -wd_b + 10.0 * gp_b

    # summaries
    utils.summary({
        g_loss_a2b: 'g_loss_a2b',
        g_loss_b2a: 'g_loss_b2a',
        cyc_loss_a: 'cyc_loss_a',
        cyc_loss_b: 'cyc_loss_b'
    })
    utils.summary({d_loss_a: 'd_loss_a'})
    utils.summary({d_loss_b: 'd_loss_b'})
    for variable in slim.get_model_variables():
        tf.summary.histogram(variable.op.name, variable)
    merged = tf.summary.merge_all()

    im1_op = tf.summary.image("real_a", a_real_in)
    im2_op = tf.summary.image("a2b", a2b)
    im3_op = tf.summary.image("a2b2a", a2b2a)
    im4_op = tf.summary.image("real_b", b_real_in)
    im5_op = tf.summary.image("b2a", b2a)
    im6_op = tf.summary.image("b2a2b", b2a2b)

    # optim
    t_var = tf.trainable_variables()
    d_a_var = [var for var in t_var if 'a_discriminator' in var.name]
    d_b_var = [var for var in t_var if 'b_discriminator' in var.name]
    g_var = [
        var for var in t_var
        if 'a2b_generator' in var.name or 'b2a_generator' in var.name
    ]

    d_a_train_op = tf.train.AdamOptimizer(lr,
                                          beta1=0.5).minimize(d_loss_a,
                                                              var_list=d_a_var)
    d_b_train_op = tf.train.AdamOptimizer(lr,
                                          beta1=0.5).minimize(d_loss_b,
                                                              var_list=d_b_var)
    g_train_op = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(g_loss,
                                                                var_list=g_var)
    """ train """
    ''' init '''
    # session
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    # counter
    it_cnt, update_cnt = utils.counter()
    ''' summary '''
    summary_writer = tf.summary.FileWriter('./outputs/summaries/', sess.graph)
    ''' saver '''
    saver = tf.train.Saver(max_to_keep=5)
    ''' restore '''
    ckpt_dir = './outputs/checkpoints/'
    utils.mkdir(ckpt_dir)
    try:
        utils.load_checkpoint(ckpt_dir, sess)
    except:
        sess.run(tf.global_variables_initializer())
    '''train'''
    try:
        batch_epoch = min(train_a_size, train_b_size) // batch_size
        max_it = epoch * batch_epoch
        for it in range(sess.run(it_cnt), max_it):
            sess.run(update_cnt)
            epoch = it // batch_epoch
            it_epoch = it % batch_epoch + 1

            # read data
            a_real_np = cv2.resize(
                db_a_i.get_image(image_a_train_names[it_epoch % batch_epoch]),
                (load_size, load_size))
            b_real_np = cv2.resize(
                db_b_i.get_image(image_b_train_names[it_epoch % batch_epoch]),
                (load_size, load_size))

            # train G
            sess.run(g_train_op,
                     feed_dict={
                         a_real_in: [a_real_np],
                         b_real_in: [b_real_np]
                     })

            # train discriminator
            sess.run([d_a_train_op, d_b_train_op],
                     feed_dict={
                         a_real_in: [a_real_np],
                         b_real_in: [b_real_np]
                     })

            # display
            if it % 100 == 0:
                # make summary
                summary = sess.run(merged,
                                   feed_dict={
                                       a_real_in: [a_real_np],
                                       b_real_in: [b_real_np]
                                   })
                summary_writer.add_summary(summary, it)
                print("Epoch: (%3d) (%5d/%5d)" %
                      (epoch, it_epoch, batch_epoch))

            # save
            if (it + 1) % 1000 == 0:
                save_path = saver.save(
                    sess, '{}/epoch_{}_{}.ckpt'.format(ckpt_dir, epoch,
                                                       it_epoch))
                print('###Model saved in file: {}'.format(save_path))

            # sample
            if (it + 1) % 1000 == 0:
                a_test_index = int(
                    np.random.uniform(high=len(image_a_test_names)))
                b_test_index = int(
                    np.random.uniform(high=len(image_b_test_names)))
                a_real_np = cv2.resize(
                    db_a_i.get_image(image_a_test_names[a_test_index]),
                    (load_size, load_size))
                b_real_np = cv2.resize(
                    db_b_i.get_image(image_b_test_names[b_test_index]),
                    (load_size, load_size))

                [a_opt, a2b_opt, a2b2a_opt, b_opt, b2a_opt, b2a2b_opt
                 ] = sess.run([a_real, a2b, a2b2a, b_real, b2a, b2a2b],
                              feed_dict={
                                  a_real_in: [a_real_np],
                                  b_real_in: [b_real_np]
                              })

                sample_opt = np.concatenate(
                    (a_opt, a2b_opt, a2b2a_opt, b_opt, b2a_opt, b2a2b_opt),
                    axis=0)
                [im1_sum,im2_sum,im3_sum,im4_sum,im5_sum,im6_sum] = \
                    sess.run([im1_op,im2_op,im3_op,im4_op,im5_op,im6_op],
                             feed_dict={a_real_in: [a_real_np], b_real_in: [b_real_np]})

                summary_writer.add_summary(im1_sum, it)
                summary_writer.add_summary(im2_sum, it)
                summary_writer.add_summary(im3_sum, it)
                summary_writer.add_summary(im4_sum, it)
                summary_writer.add_summary(im5_sum, it)
                summary_writer.add_summary(im6_sum, it)

                save_dir = './outputs/sample_images_while_training/'
                utils.mkdir(save_dir)
                im.imwrite(
                    im.immerge(sample_opt, 2, 3),
                    '{}/epoch_{}_it_{}.jpg'.format(save_dir, epoch, it_epoch))
    except:
        raise

    finally:
        save_path = saver.save(
            sess, '{}/epoch_{}_{}.ckpt'.format(ckpt_dir, epoch, it_epoch))
        print('###Model saved in file: {}'.format(save_path))
        sess.close()
Example no. 19
def train_fn(
    loader,
    disc,
    gen,
    opt_gen,
    opt_disc,
    l1,
    vgg_loss,
    g_scaler,
    d_scaler,
    writer,
    tb_step,
):
    loop = tqdm(loader, leave=True)

    for idx, (lr, hr) in enumerate(loop):
        hr = hr.to(config.DEVICE)
        lr = lr.to(config.DEVICE)

        # Train Discriminator
        with torch.cuda.amp.autocast():
            fake = gen(lr)
            disc_real = disc(hr)
            disc_fake = disc(fake.detach())
            gp = gradient_penalty(disc, hr, fake,
                                  device=config.DEVICE)  # ?? relative (relativistic) loss
            loss_disc = config.LAMBDA_GP * gp - (torch.mean(disc_real) -
                                                 torch.mean(disc_fake))

        opt_disc.zero_grad()
        d_scaler.scale(loss_disc).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train Generator
        with torch.cuda.amp.autocast():
            l1_loss = 1e-2 * l1(fake, hr)
            adversarial_loss = 5e-3 * -torch.mean(disc(fake))
            vgg_for_loss = vgg_loss(fake, hr)
            loss_gen = l1_loss + adversarial_loss + vgg_for_loss

        opt_gen.zero_grad()
        g_scaler.scale(loss_gen).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        writer.add_scalar("Disc loss", loss_disc.item(), global_step=tb_step)
        tb_step += 1

        if idx % 1 == 0:
            plot_examples("data/val", gen)

        loop.set_postfix(
            gp=gp.item(),
            disc=loss_disc.item(),
            l1=l1_loss.item(),
            vgg=vgg_for_loss.item(),
            adversarial=adversarial_loss.item(),
        )

    return tb_step
Example no. 20
def train_one_epoch(epoch,
                    dataloader,
                    gen,
                    critic,
                    opt_gen,
                    opt_critic,
                    fixed_noise,
                    device,
                    metric_logger,
                    num_samples,
                    freq=100):
    """
    Train one epoch
    :param epoch: ``int`` current epoch
    :param dataloader: object of dataloader
    :param gen: Generator model
    :param critic: Discriminator model
    :param opt_gen: Optimizer for generator
    :param opt_critic: Optimizer for discriminator
    :param fixed_noise: ``tensor[[cfg.BATCH_SIZE, latent_space_dimension, 1, 1]]`` fixed noise (latent space) for image metrics
    :param device: cuda device or cpu
    :param metric_logger: object of MetricLogger
    :param num_samples: ``int`` well retrievable sqrt() (for example: 4, 16, 64) for good result,
    number of samples for grid image metric
    :param freq: ``int``, freq < len(dataloader)`` freq for display results
    """
    for batch_idx, img in enumerate(dataloader):
        real = img.to(device)
        # Train critic
        acc_real = 0
        acc_fake = 0
        for _ in range(cfg.CRITIC_ITERATIONS):
            noise = get_random_noise(cfg.BATCH_SIZE, cfg.Z_DIMENSION, device)
            fake = gen(noise)
            critic_real = critic(real).reshape(-1)
            acc_real += critic_real
            critic_fake = critic(fake).reshape(-1)
            acc_fake += critic_fake
            gp = gradient_penalty(critic, real, fake, device=device)
            loss_critic = -1 * (torch.mean(critic_real) -
                                torch.mean(critic_fake)) + cfg.LAMBDA_GP * gp
            critic.zero_grad()
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

        acc_real = acc_real / cfg.CRITIC_ITERATIONS
        acc_fake = acc_fake / cfg.CRITIC_ITERATIONS

        # Train generator: minimize -E[critic(gen_fake)]
        output = critic(fake).reshape(-1)
        loss_gen = -1 * torch.mean(output)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # logs metrics
        if batch_idx % freq == 0:
            with torch.no_grad():
                metric_logger.log(loss_critic, loss_gen, acc_real, acc_fake)
                fake = gen(fixed_noise)
                metric_logger.log_image(fake, num_samples, epoch, batch_idx,
                                        len(dataloader))
                metric_logger.display_status(epoch, cfg.NUM_EPOCHS, batch_idx,
                                             len(dataloader), loss_critic,
                                             loss_gen, acc_real, acc_fake)
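The `get_random_noise` helper above only needs to return a latent batch shaped for a DCGAN-style convolutional generator, consistent with the ``fixed_noise`` shape documented in the docstring; a plausible sketch (the body is an assumption based on how the noise is consumed by `gen`):

import torch

def get_random_noise(batch_size, z_dim, device):
    # latent vectors shaped [batch_size, z_dim, 1, 1] for a convolutional generator
    return torch.randn(batch_size, z_dim, 1, 1, device=device)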
Example n. 21
def train(num_epoch=500,
          dsize=2048,
          batch_size=64,
          seq_length=128,
          conv_dim=8,
          print_interval=10,
          variables=None):

    num_dim = sum(len(v) for v in variables)
    sustains, attacks, releases = variables

    G = models.Generator(num_dim, conv_dim)
    D = models.Discriminator(num_dim, seq_length, conv_dim)

    D_optimizer = torch.optim.Adam(D.parameters())
    G_optimizer = torch.optim.Adam(G.parameters())

    loss_function = F.binary_cross_entropy_with_logits

    dataloader = utils.create_dataset(seq_length,
                                      dsize,
                                      batch_size,
                                      sus=sustains,
                                      att=attacks,
                                      rel=releases)

    for epoch in range(num_epoch):
        for sample, label in dataloader:

            input_labels = label.view(batch_size, num_dim,
                                      1).repeat(1, 1, seq_length)
            main_input = torch.cat([sample, input_labels], dim=1)
            real_sample = sample[:, 0, :].clone()

            ############
            # D - real #
            ############
            d_real_src_pred, d_real_cls_pred = D(main_input[:, :3, :])
            d_real_loss = -torch.mean(d_real_src_pred)
            d_real_cls_loss = loss_function(d_real_cls_pred, label)

            D_optimizer.zero_grad()
            d_real_loss = d_real_loss + d_real_cls_loss
            d_real_loss.backward()
            D_optimizer.step()

            ############
            # G - fake #
            ############
            gen_sample = G(main_input[:, 1:, :])
            main_input = torch.cat([gen_sample, main_input[:, 1:, :]], dim=1)
            d_fake_src_pred, _ = D(main_input[:, :3, :].detach())
            d_fake_loss = torch.mean(d_fake_src_pred)

            D_optimizer.zero_grad()
            d_fake_loss.backward()
            D_optimizer.step()

            ############
            #    GP    #
            ############
            alpha = torch.rand(batch_size, 1)
            interpolates = alpha * real_sample.squeeze().data + (
                1 - alpha) * gen_sample.squeeze().data
            interpolates = interpolates.view(batch_size, 1, seq_length)
            interpolates = interpolates.requires_grad_(True)
            interp_input = torch.cat([interpolates, main_input[:, 1:, :]],
                                     dim=1)

            gp_pred, _ = D(interp_input[:, :3, :])
            gp_loss = utils.gradient_penalty(x=interp_input, y=gp_pred)

            D_optimizer.zero_grad()
            gp_loss.backward(retain_graph=True)
            D_optimizer.step()

            ############
            #   Gen    #
            ############
            gen_sample = G(main_input[:, 1:, :])
            main_input = torch.cat([gen_sample, main_input[:, 1:, :]], dim=1)
            g_src_pred, g_cls_pred = D(main_input[:, :3, :])
            g_src_loss = -torch.mean(g_src_pred)
            g_cls_loss = loss_function(g_cls_pred, label)

            G_optimizer.zero_grad()
            D_optimizer.zero_grad()
            g_loss = g_src_loss + g_cls_loss
            g_loss.backward()
            G_optimizer.step()

        if epoch % print_interval == 0:

            i = np.random.randint(0, batch_size - 1)
            gen_sample = gen_sample[i, :].squeeze().data.numpy()
            real_sample = real_sample[i, :].data.numpy()
            label = label[i].data.numpy()

            # console update
            print('\n----- ', epoch, '-----')
            print('D real: ',
                  d_real_src_pred.data.mean().numpy(), '--- class: ',
                  d_real_cls_loss.data.numpy(), '--- gp: ',
                  gp_loss.data.numpy())
            print('G loss: ',
                  g_src_pred.data.mean().numpy(), '--- class: ',
                  g_cls_loss.data.numpy())
            print('label: ', label)
            print('real pred : ', d_real_cls_pred[i].data.numpy())
            print('fake pred : ', g_cls_pred[i].data.numpy())

            # output graphs
            plt.plot(gen_sample, 'r')
            plt.plot(real_sample, 'g')
            plt.title('label: {}'.format(label))
            plt.savefig('output_data/gen_sample_epoch_{}.png'.format(epoch))
            plt.clf()

            # save models
            filepath = 'models/gen/G_{}.pth.tar'.format(epoch)
            torch.save(G.state_dict(), filepath)

            filepath = 'models/disc/D_{}.pth.tar'.format(epoch)
            torch.save(D.state_dict(), filepath)
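Unlike the previous examples, `utils.gradient_penalty` here receives the interpolated input `x` and the corresponding discriminator output `y` directly, rather than the discriminator itself. A sketch consistent with that call site (the body is an assumption based on the standard WGAN-GP penalty):

import torch

def gradient_penalty(x, y):
    # gradient of the discriminator output with respect to the interpolated input
    grads = torch.autograd.grad(
        outputs=y,
        inputs=x,
        grad_outputs=torch.ones_like(y),
        create_graph=True,
        retain_graph=True,
    )[0]
    # penalize deviation of the per-sample gradient norm from 1
    grads = grads.reshape(grads.size(0), -1)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()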
Example n. 22
def train_fn(
    critic,
    gen,
    loader,
    dataset,
    step,
    alpha,
    opt_critic,
    opt_gen,
    tensorboard_step,
    writer,
    scaler_gen,
    scaler_critic,
):
    loop = tqdm(loader, leave=True)
    # critic_losses = []
    reals = 0
    fakes = 0
    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(config.DEVICE)
        cur_batch_size = real.shape[0]

        # Train Critic: max E[critic(real)] - E[critic(fake)],
        # i.e. min -E[critic(real)] + E[critic(fake)]
        noise = torch.randn(cur_batch_size, config.Z_DIM, 1,
                            1).to(config.DEVICE)

        with torch.cuda.amp.autocast():
            fake = gen(noise, alpha, step)
            critic_real = critic(real, alpha, step)
            critic_fake = critic(fake.detach(), alpha, step)
            reals += critic_real.mean().item()
            fakes += critic_fake.mean().item()
            gp = gradient_penalty(critic,
                                  real,
                                  fake,
                                  alpha,
                                  step,
                                  device=config.DEVICE)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake)) +
                config.LAMBDA_GP * gp + (0.001 * torch.mean(critic_real**2)))

        opt_critic.zero_grad()
        scaler_critic.scale(loss_critic).backward()
        scaler_critic.step(opt_critic)
        scaler_critic.update()

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        with torch.cuda.amp.autocast():
            gen_fake = critic(fake, alpha, step)
            loss_gen = -torch.mean(gen_fake)

        opt_gen.zero_grad()
        scaler_gen.scale(loss_gen).backward()
        scaler_gen.step(opt_gen)
        scaler_gen.update()

        # Update alpha and ensure less than 1
        alpha += cur_batch_size / (
            (config.PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset))
        alpha = min(alpha, 1)

        if batch_idx % 500 == 0:
            with torch.no_grad():
                fixed_fakes = gen(config.FIXED_NOISE, alpha, step) * 0.5 + 0.5
            plot_to_tensorboard(
                writer,
                loss_critic.item(),
                loss_gen.item(),
                real.detach(),
                fixed_fakes.detach(),
                tensorboard_step,
            )
            tensorboard_step += 1

        loop.set_postfix(
            reals=reals / (batch_idx + 1),
            fakes=fakes / (batch_idx + 1),
            gp=gp.item(),
            loss_critic=loss_critic.item(),
        )

    return tensorboard_step, alpha
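A typical way to drive this progressive-growing `train_fn` is an outer loop that walks through the resolutions and resets `alpha` at the start of each stage; a rough usage sketch (the `get_loader` helper and the starting value of `alpha` are assumptions, the rest mirrors the signature above):

for step, num_epochs in enumerate(config.PROGRESSIVE_EPOCHS):
    alpha = 1e-5  # restart the fade-in close to zero at each new resolution
    loader, dataset = get_loader(4 * 2 ** step)  # hypothetical helper: data at the current image size
    for _ in range(num_epochs):
        tensorboard_step, alpha = train_fn(
            critic, gen, loader, dataset, step, alpha,
            opt_critic, opt_gen, tensorboard_step, writer,
            scaler_gen, scaler_critic,
        )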
Example n. 23
        for _ in range(CRITIC_ITERS):
            # score real and fake samples with the critic
            noise = gen_noise(cur_batch_size, NOISE_DIM, device=device)
            fake = gen(noise)
            disc_fake = critic(fake.detach())
            disc_real = critic(real)

            # calculate gradient penalty
            epsilon = torch.rand(cur_batch_size,
                                 1,
                                 1,
                                 1,
                                 device=device,
                                 requires_grad=True)
            grads = get_grad(critic, real, fake.detach(), epsilon)
            gp = gradient_penalty(grads)

            # calculate discriminator loss
            disc_loss = -(torch.mean(disc_real) -
                          torch.mean(disc_fake)) + C_LAMBDA * gp

            # update discriminator
            opt_disc.zero_grad()
            disc_loss.backward(retain_graph=True)
            opt_disc.step()
            crit_losses.append(disc_loss.item())

        # monitor running loss
        lossD.update(np.mean(crit_losses), cur_batch_size)

        # calculate generator loss
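This last fragment splits the penalty computation into two helpers: `get_grad` returns the critic's gradient at the interpolated images, and `gradient_penalty` turns that gradient into the scalar penalty term. A sketch consistent with those call sites (the bodies are assumptions):

import torch

def get_grad(critic, real, fake, epsilon):
    # mix real and fake images with the per-sample epsilon
    mixed_images = real * epsilon + fake * (1 - epsilon)
    mixed_scores = critic(mixed_images)
    # gradient of the critic scores with respect to the mixed images
    gradient = torch.autograd.grad(
        inputs=mixed_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    return gradient

def gradient_penalty(gradient):
    # flatten per sample and penalize deviation of the L2 norm from 1
    gradient = gradient.reshape(len(gradient), -1)
    gradient_norm = gradient.norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)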