Ejemplo n.º 1
0
def generate_walks():
    """Sample synthetic meta-path walks from the trained generator.

    Returns two lists — node-type sequences and node-id sequences — each
    entry truncated at its end-of-walk marker by utils.delete_from_tail.
    """
    print("Start Generating Graph...")
    walk_transitions = 4 - 1
    iter_transitions = 20e4
    total_eval_transitions = 80e7
    walks_per_iter = int(np.round(iter_transitions / walk_transitions))
    total_walks = total_eval_transitions / walk_transitions
    num_iters = int(np.round(total_walks / walks_per_iter))

    types_out, nodes_out = [], []
    for batch_idx in range(num_iters):
        noise = utils.make_noise((args.batch_size, args.noise_dim),
                                 args.noise_type).to(device)
        gen_type, gen_node = generator(noise)
        gen_type = torch.argmax(gen_type.cpu(), dim=2).numpy().astype(np.int32)
        gen_node = torch.argmax(gen_node.cpu(), dim=2).numpy().astype(np.int32)
        types_out += utils.delete_from_tail(gen_type, 3)
        nodes_out += utils.delete_from_tail(gen_node, _N)
        if batch_idx % 100 == 0:
            print("Done, generating {} of {} batches meta-paths...".format(
                batch_idx, num_iters))

    return types_out, nodes_out
Ejemplo n.º 2
0
def validate_classifier(G,
                        deformator,
                        shift_predictor,
                        params_dict=None,
                        trainer=None):
    """Estimate shift-direction classification accuracy over 100 batches."""
    n_steps = 100
    if trainer is None:
        trainer = Trainer(params=Params(**params_dict), verbose=False)

    accuracies = torch.empty([n_steps])
    for i in range(n_steps):
        latents = make_noise(trainer.p.batch_size, G.dim_z).cuda()
        target_indices, shifts, z_shift = trainer.make_shifts(G.dim_z)

        # Global deformation warps both latents; otherwise only the shift
        # vector passes through the deformator.
        if trainer.p.global_deformation:
            shifted_latents = deformator(latents + z_shift)
            latents = deformator(latents)
        else:
            shifted_latents = latents + deformator(z_shift)

        imgs = G(latents)
        imgs_shifted = G(shifted_latents)
        logits, _ = shift_predictor(imgs, imgs_shifted)
        correct = torch.argmax(logits, dim=1) == target_indices
        accuracies[i] = correct.to(torch.float32).mean()

    return accuracies.mean()
Ejemplo n.º 3
0
def validate_classifier(G,
                        deformator,
                        shift_predictor,
                        params_dict=None,
                        trainer=None):
    """Measure shift-direction classification accuracy over 10 batches.

    Variant for a flow-style generator (reverse pass + nvp_shifted).
    """
    n_steps = 10
    if trainer is None:
        trainer = Trainer(params=Params(**params_dict), verbose=False)

    batch_accuracy = torch.empty([n_steps])
    for i in range(n_steps):
        latents = make_noise(trainer.p.batch_size, G.dim_z, trainer.p.z_std,
                             trainer.p.truncation).cuda()
        target_indices, shifts, basis_shift = trainer.make_shifts(
            deformator.input_dim)

        # Deform the basis shift, then reshape it to the generator's latent shape.
        latent_shift = deformator(basis_shift).view([-1] + G.dim_z)

        with torch.set_grad_enabled(trainer.p.torch_grad):
            imgs = G(z=latents, reverse=True).clamp_(0, 1).to('cuda:0')
        imgs_shifted = G.nvp_shifted(latents, latent_shift,
                                     reverse=True).clamp_(0, 1).to('cuda:0')

        logits, _ = shift_predictor(imgs, imgs_shifted)
        correct = torch.argmax(logits, dim=1) == target_indices
        batch_accuracy[i] = correct.to(torch.float32).mean()

    return batch_accuracy.mean()
Ejemplo n.º 4
0
    def train(self, G, deformator, shift_predictor, multi_gpu=False):
        """Jointly train the deformator and shift predictor against a frozen G.

        G is kept in eval mode and never optimized. ID/RANDOM deformators
        receive no optimizer. Logs running statistics via self.log each step.
        """
        G.cuda().eval()
        deformator.cuda().train()
        shift_predictor.cuda().train()

        should_gen_classes = is_conditional(G)
        if multi_gpu:
            G = DataParallelPassthrough(G)

        deformator_opt = torch.optim.Adam(deformator.parameters(), lr=self.p.deformator_lr) \
            if deformator.type not in [DeformatorType.ID, DeformatorType.RANDOM] else None
        shift_predictor_opt = torch.optim.Adam(
            shift_predictor.parameters(), lr=self.p.shift_predictor_lr)

        avgs = MeanTracker('percent'), MeanTracker('loss'), MeanTracker('direction_loss'),\
               MeanTracker('shift_loss')
        avg_correct_percent, avg_loss, avg_label_loss, avg_shift_loss = avgs

        recovered_step = self.start_from_checkpoint(deformator, shift_predictor)
        for step in range(recovered_step, self.p.n_steps, 1):
            G.zero_grad()
            deformator.zero_grad()
            shift_predictor.zero_grad()

            z = make_noise(self.p.batch_size, G.dim_z, self.p.truncation).cuda()
            target_indices, shifts, basis_shift = self.make_shifts(deformator.input_dim)

            if should_gen_classes:
                classes = G.mixed_classes(z.shape[0])

            # Deformation
            shift = deformator(basis_shift)
            if should_gen_classes:
                imgs = G(z, classes)
                imgs_shifted = G.gen_shifted(z, shift, classes)
            else:
                imgs = G(z)
                imgs_shifted = G.gen_shifted(z, shift)

            logits, shift_prediction = shift_predictor(imgs, imgs_shifted)
            logit_loss = self.p.label_weight * self.cross_entropy(logits, target_indices)
            shift_loss = self.p.shift_weight * torch.mean(torch.abs(shift_prediction - shifts))

            # total loss
            loss = logit_loss + shift_loss
            loss.backward()

            if deformator_opt is not None:
                deformator_opt.step()
            shift_predictor_opt.step()

            # update statistics trackers
            avg_correct_percent.add(torch.mean(
                    (torch.argmax(logits, dim=1) == target_indices).to(torch.float32)).detach())
            avg_loss.add(loss.item())
            avg_label_loss.add(logit_loss.item())
            # Store a plain float like the sibling trackers; passing the live
            # tensor kept the autograd graph alive between log flushes.
            avg_shift_loss.add(shift_loss.item())

            self.log(G, deformator, shift_predictor, step, avgs)
Ejemplo n.º 5
0
 def log_generated_images(self, G, step):
     """Log a grid of 16 freshly generated samples to TensorBoard."""
     generated_imgs = []
     for _ in range(2):
         z = make_noise(8, G.dim_z, self.p.z_std, self.p.truncation).cuda()
         imgs = G(z=z, reverse=True).clamp_(0, 1).cuda()
         generated_imgs.append(imgs)
     # make_grid expects one 4D batch (or a list of single 3D images); the
     # original passed a list of two 4D batches, which make_grid stacks into
     # an invalid 5D tensor. Concatenate the batches along dim 0 instead.
     image_grid = torchvision.utils.make_grid(torch.cat(generated_imgs))
     self.writer.add_image("Generated Images", image_grid, step)
Ejemplo n.º 6
0
 def make_noise(self, batch_size, device):
     """Return a batch of latent codes.

     Samples fresh noise when no pool is set; otherwise draws (with
     replacement) from the pre-sampled pool in self.zs.
     """
     if self.zs is None:
         return make_noise(batch_size, self.G.dim_z).to(device)
     picked = torch.randint(0, len(self.zs), [batch_size], dtype=torch.long)
     return self.zs[picked].to(device)
Ejemplo n.º 7
0
    def eval(self, G, deformator, shift_predictor, inception, target_id):
        """Compare inception statistics of original vs. shifted generations.

        Returns (kl, l2): the mean KL divergence between per-feature normal
        fits of the two feature sets, and the mean squared feature distance.
        """
        G.cuda().eval()
        deformator.cuda().eval()
        shift_predictor.cuda().eval()

        z = make_noise(self.p.batch_size, G.dim_z).cuda()
        target_indices = torch.full([self.p.batch_size],
                                    target_id,
                                    device='cuda').type(torch.long)
        _, shifts, z_shift = self.make_shifts(G.dim_z,
                                              target_indices=target_indices)

        # Deformation

        if self.p.global_deformation:
            z_shifted = deformator(z + z_shift)
            z = deformator(z)
        else:
            z_shifted = z + deformator(z_shift)

        ##########################
        img_feats_list = []
        for _z in z.split(128):
            imgs = G(_z)
            img_feats = inception(imgs)
            if isinstance(img_feats, list):
                img_feats = img_feats[0]
            # Flatten by the actual chunk size: chunks hold at most 128
            # samples, so viewing with self.p.batch_size (as before) silently
            # folded feature dims whenever batch_size != chunk size.
            img_feats_list.append(img_feats.view(_z.shape[0], -1))
        img_feats = torch.cat(img_feats_list)

        img_shifted_feats_list = []
        for _z_shifted in z_shifted.split(128):
            imgs_shifted = G(_z_shifted)
            img_shifted_feats = inception(imgs_shifted)
            if isinstance(img_shifted_feats, list):
                img_shifted_feats = img_shifted_feats[0]
            img_shifted_feats_list.append(
                img_shifted_feats.view(_z_shifted.shape[0], -1))
        img_shifted_feats = torch.cat(img_shifted_feats_list)

        mean_img_feats = img_feats.mean(0)
        std_img_feats = img_feats.std(0)
        img_feats_distr = torch.distributions.Normal(loc=mean_img_feats,
                                                     scale=std_img_feats)

        mean_img_shifted_feats = img_shifted_feats.mean(0)
        std_img_shifted_feats = img_shifted_feats.std(0)
        img_shifted_feats_distr = torch.distributions.Normal(
            loc=mean_img_shifted_feats, scale=std_img_shifted_feats)

        kl = torch.distributions.kl.kl_divergence(
            img_shifted_feats_distr, img_feats_distr).mean().item()
        l2 = ((img_feats - img_shifted_feats)**2).mean().item()

        print(f"Target id {target_id} | KL {kl:.3} | L2 {l2:.3}")
        return kl, l2
Ejemplo n.º 8
0
def save_results_charts(G, deformator, params, out_dir):
    """Write interpolation charts at 1x and 3x the configured shift scale."""
    deformator.eval()
    G.eval()
    zs = make_noise(3, G.dim_z, params.truncation).cuda()
    for shifts_r in (params.shift_scale, 3 * params.shift_scale):
        chart_dir = os.path.join(out_dir, 'charts_s{}'.format(int(shifts_r)))
        inspect_all_directions(G, deformator, chart_dir, zs=zs,
                               shifts_r=shifts_r)
Ejemplo n.º 9
0
    def log_interpolation(self, G, deformator, step):
        """Log interpolation charts for a fresh random latent and a fixed one.

        The fixed latent is captured on first call so later steps are
        directly comparable.
        """
        noise = make_noise(1, G.dim_z, self.p.truncation).cuda()
        if self.fixed_test_noise is None:
            self.fixed_test_noise = noise.clone()
        for prefix, z in (('rand', noise), ('fixed', self.fixed_test_noise)):
            fig = make_interpolation_chart(G,
                                           deformator,
                                           z=z,
                                           shifts_r=3 * self.p.shift_scale,
                                           shifts_count=3,
                                           dims_count=15,
                                           dpi=500)
            tag = '{}_deformed_interpolation'.format(prefix)
            self.writer.add_figure(tag, fig, step)
            out_path = os.path.join(self.images_dir,
                                    '{}_{}.jpg'.format(prefix, step))
            fig_to_image(fig).convert("RGB").save(out_path)
Ejemplo n.º 10
0
def save_results_charts(G, deformator, params, out_dir):
    """Render per-direction interpolation charts at several shift scales.

    Writes one chart directory per scale under out_dir.
    """
    print("[Charts]: Starting creating visualization charts:")
    deformator.eval()
    G.eval()
    z = make_noise(3, G.dim_z, params.z_std, params.truncation).cuda()
    # Same chart at 2x, 1x and 3x the configured shift scale; the original
    # spelled the call out three times with only the multiplier changing.
    for factor in (2, 1, 3):
        shifts_r = factor * params.shift_scale
        inspect_all_directions_per_direction(
            G, deformator,
            os.path.join(out_dir, 'charts_s{}'.format(int(shifts_r))),
            directions_count=params.directions_count, zs=z, std=params.z_std,
            shifts_r=shifts_r)
Ejemplo n.º 11
0
def train():
    """WGAN training loop with weight clipping.

    Trains the discriminator every batch and the generator every
    args.n_critic batches, reading models/optimizers from module scope.
    """
    batches_done = 0
    print("Start Training...")
    for epoch in range(args.n_epochs):
        for i, (real_node_type, real_node_seq) in enumerate(dataloader):
            # Configure input
            real_node_type = Variable(real_node_type.float()).to(device)
            real_node_seq = Variable(real_node_seq.float()).to(device)
            initial_noise = utils.make_noise((args.batch_size, args.noise_dim),
                                             args.noise_type).to(device)
            # ---------------------
            #  Train Discriminator
            # ---------------------
            discriminator.train()
            optimizer_D.zero_grad()
            # One forward pass through the generator; the original called it
            # twice (once per output), doubling the cost of every step.
            fake_type, fake_node = generator(initial_noise)
            fake_type, fake_node = fake_type.detach(), fake_node.detach()
            loss_D = torch.mean(discriminator((fake_type, fake_node))) -\
                torch.mean(discriminator((real_node_type, real_node_seq)))

            loss_D.backward()
            optimizer_D.step()

            # Clip weights of discriminator
            for p in discriminator.parameters():
                p.data.clamp_(-args.clip_value, args.clip_value)

            # Train the generator every n_critic iterations
            if i % args.n_critic == 0:
                # -----------------
                #  Train Generator
                # -----------------
                optimizer_G.zero_grad()

                # Generate a batch of random walks
                syn_types, syn_node_seq = generator(initial_noise)
                # Adversarial loss
                loss_G = -torch.mean(discriminator((syn_types, syn_node_seq)))
                # Loss back-propagation
                loss_G.backward()
                optimizer_G.step()

                print(
                    "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" %
                    (epoch + 1, args.n_epochs, batches_done % len(dataloader),
                     len(dataloader), loss_D.item(), loss_G.item()))
            batches_done += 1
Ejemplo n.º 12
0
def validate_classifier(G, deformator, shift_predictor, params_dict=None, trainer=None):
    """Estimate shift-direction classification accuracy over 100 batches."""
    n_steps = 100
    if trainer is None:
        trainer = Trainer(params=Params(**params_dict), verbose=False)

    scores = torch.empty([n_steps])
    for i in range(n_steps):
        latents = make_noise(trainer.p.batch_size, G.dim_z, trainer.p.truncation).cuda()
        target_indices, shifts, basis_shift = trainer.make_shifts(deformator.input_dim)

        originals = G(latents)
        shifted = G.gen_shifted(latents, deformator(basis_shift))

        logits, _ = shift_predictor(originals, shifted)
        correct = torch.argmax(logits, dim=1) == target_indices
        scores[i] = correct.to(torch.float32).mean()

    return scores.mean()
Ejemplo n.º 13
0
def make_interpolation_chart(G,
                             deformator=None,
                             z=None,
                             shifts_r=10.0,
                             shifts_count=5,
                             dims=None,
                             dims_count=10,
                             texts=None,
                             **kwargs):
    """Plot the original image plus one interpolation strip per direction.

    Extra kwargs (dpi, figsize, ...) are forwarded to plt.subplots.
    Restores the deformator's training mode before returning the figure.
    """
    with_deformation = deformator is not None
    if with_deformation:
        deformator_is_training = deformator.training
        deformator.eval()
    z = z if z is not None else make_noise(1, G.dim_z).cuda()

    # The original if/else here had byte-identical branches; collapsed.
    original_img = G(z).cpu()

    imgs = []
    if dims is None:
        dims = range(dims_count)
    for i in dims:
        imgs.append(interpolate(G, z, shifts_r, shifts_count, i, deformator))

    rows_count = len(imgs) + 1
    fig, axs = plt.subplots(rows_count, **kwargs)

    axs[0].axis('off')
    axs[0].imshow(to_image(original_img, True))

    if texts is None:
        texts = dims
    for ax, shifts_imgs, text in zip(axs[1:], imgs, texts):
        ax.axis('off')
        plt.subplots_adjust(left=0.5)
        ax.imshow(
            to_image(
                make_grid(shifts_imgs, nrow=(2 * shifts_count + 1), padding=1),
                True))
        ax.text(-20, 21, str(text), fontsize=10)

    if deformator is not None and deformator_is_training:
        deformator.train()

    return fig
Ejemplo n.º 14
0
def inspect_all_directions(G,
                           deformator,
                           out_dir,
                           zs=None,
                           num_z=3,
                           shifts_r=8.0):
    """Save interpolation charts covering all deformator directions.

    Directions are charted in groups of `step` (20); each output JPEG
    horizontally stacks the charts of every latent in `zs`.
    """
    os.makedirs(out_dir, exist_ok=True)

    step = 20
    # max_dim = G.dim_shift
    max_dim = deformator.input_dim
    zs = zs if zs is not None else make_noise(num_z, G.dim_z).cuda()
    shifts_count = zs.shape[0]

    for start in range(0, max_dim - 1, step):
        imgs = []
        dims = range(start, min(start + step, max_dim))
        for z in zs:
            z = z.unsqueeze(0)
            fig = make_interpolation_chart(G,
                                           deformator=deformator,
                                           z=z,
                                           shifts_count=shifts_count,
                                           dims=dims,
                                           shifts_r=shifts_r,
                                           dpi=250,
                                           figsize=(int(shifts_count * 4.0),
                                                    int(0.5 * step) + 2))
            fig.canvas.draw()
            img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
            img = img.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
            # Close only after reading the render buffer; the original closed
            # the figure first and relied on the buffer surviving close().
            plt.close(fig)

            # crop borders
            nonzero_columns = np.count_nonzero(img != 255, axis=0)[:, 0] > 0
            img = img.transpose(1, 0, 2)[nonzero_columns].transpose(1, 0, 2)
            imgs.append(img)

        out_file = os.path.join(out_dir, '{}_{}.jpg'.format(dims[0], dims[-1]))
        print('saving chart to {}'.format(out_file))
        Image.fromarray(np.hstack(imgs)).save(out_file)
Ejemplo n.º 15
0
    def train(self, G, deformator, shift_predictor, trial, multi_gpu=False):
        """Train deformator and shift predictor under an Optuna trial.

        Supports checkpoint recovery, ReduceLROnPlateau scheduling for the
        deformator, a retry loop that regenerates batches whose images
        contain Inf values, and Optuna-based pruning. Returns the last
        detached loss value (or None on early exit when the checkpoint is
        already fully trained).
        """
        # torch.autograd.set_detect_anomaly(True)
        G.cuda().eval()
        deformator.cuda().train()
        shift_predictor.cuda().train()

        should_gen_classes = is_conditional(G)
        if multi_gpu:
            G = DataParallelPassthrough(G, device_ids=[0]).to('cuda:0')

        # Optimizers
        # ID/RANDOM deformators have no trainable parameters, hence no optimizer.
        deformator_opt = torch.optim.Adam(deformator.parameters(), lr=self.p.deformator_lr) \
            if deformator.type not in [DeformatorType.ID, DeformatorType.RANDOM] else None
        shift_predictor_opt = torch.optim.Adam(shift_predictor.parameters(),
                                               lr=self.p.shift_predictor_lr)

        # Optimization Scheduler
        scheduler_def = lr_scheduler.ReduceLROnPlateau(
            deformator_opt,
            'min',
            min_lr=0.00001,
            factor=0.5,
            patience=1000,
            verbose=True,
            threshold=0.001,
            cooldown=500) if deformator.type not in [
                DeformatorType.ID, DeformatorType.RANDOM
            ] else None
        # NOTE(review): scheduler_pred is constructed (and checkpointed below)
        # but its step() call further down is commented out.
        scheduler_pred = lr_scheduler.ReduceLROnPlateau(shift_predictor_opt,
                                                        'min',
                                                        min_lr=0.00001,
                                                        patience=1000,
                                                        verbose=True,
                                                        threshold=0.001,
                                                        cooldown=500)

        # Measures Trackers
        avgs = MeanTracker('class_correct_percent'), MeanTracker(
            'loss'), MeanTracker('direction_loss'), MeanTracker(
                'shift_loss'), MeanTracker('learning_rate')
        avg_correct_percent, avg_loss, avg_label_loss, avg_shift_loss, avg_lr = avgs

        # Load the checkpoint (deformator's weights, shift_predictor's weight, step)

        recovered_step = self.start_from_checkpoint(
            deformator, shift_predictor, deformator_opt, shift_predictor_opt,
            scheduler_def, scheduler_pred)
        # recovered_step = self.start_from_saved_model(deformator, shift_predictor,deformator_opt, shift_predictor_opt,
        #                                              scheduler_def, scheduler_pred,
        #                                              './anime_results_dir/BS=256 sft-lr=0.0100000 def-lr=0.0100000 dir_n=100 def-type=proj/models',
        #                                              5000)
        # NOTE(review): hard override of the predictor LR after checkpoint
        # recovery — presumably intentional for resumed runs; confirm.
        shift_predictor_opt.param_groups[0]['lr'] = 0.0001
        if recovered_step == self.p.n_steps - 1:
            print("[Trainer]: the model has been trained before.",
                  "Trying next hyperparameters")
            return

        for step in range(recovered_step, self.p.n_steps, 1):
            # Time only the first few steps (the `or False` disables extras).
            if step < 11 or False: begin = time.time()

            G.zero_grad()
            deformator.zero_grad()
            shift_predictor.zero_grad()

            # ID deformators act on the raw 48*8*8 latent; others on their own input_dim.
            if deformator.type == DeformatorType.ID:
                target_indices, shifts, basis_shift = self.make_shifts(48 * 8 *
                                                                       8)
            else:
                target_indices, shifts, basis_shift = self.make_shifts(
                    deformator.input_dim)

            # Deformation
            shift = deformator(basis_shift)
            shift = shift.view([-1] + G.dim_z)

            # Image Generation
            # z = make_noise(self.p.batch_size, G.dim_z, self.p.z_std, self.p.truncation).cuda()
            # with torch.set_grad_enabled(self.p.torch_grad):
            #     imgs = G(z=z, reverse=True)  # .clamp_(0, 1)
            #     imgs_shifted = G.nvp_shifted(z, shift, reverse=True)  # .clamp_(0, 1)

            # Sometimes the generator make infinite values in output images and cause gradient explosion. This part checks
            # if there is inf values in images and so substitue the invalid images vith valid images (have no inf).

            # if (torch.sum(torch.isinf(imgs)) + torch.sum(torch.isinf(imgs_shifted))) > 0:
            #     print('| Invalid Images')
            #     valid = False
            #     while not valid:
            #         for i in range(len(imgs)):
            #             if torch.sum(torch.isinf(imgs[i])) + torch.sum(torch.isinf(imgs_shifted[i])) > 0:
            #                 print(f'| image #{i} has Inf values.')
            #                 zz = make_noise(1, G.dim_z, self.p.z_std, self.p.truncation).cuda()
            #                 imgs[i] = G(z=zz, reverse=True)[0]
            #                 imgs_shifted[i] = G.nvp_shifted(zz, shift[i].unsqueeze(0), reverse=True)[0]
            #         torch.cuda.empty_cache()
            #         valid = (torch.sum(torch.isinf(imgs)) + torch.sum(torch.isinf(imgs_shifted))) == 0

            # Regenerate the whole batch until no Inf values appear in either
            # image set (guards against gradient explosions downstream).
            while True:
                z = make_noise(self.p.batch_size, G.dim_z, self.p.z_std,
                               self.p.truncation).cuda()
                with torch.set_grad_enabled(self.p.torch_grad):
                    imgs = G(z=z, reverse=True).to('cuda:0')
                    imgs_shifted = G(z=z + shift, reverse=True).to('cuda:0')

                if (torch.sum(torch.isinf(imgs)) +
                        torch.sum(torch.isinf(imgs_shifted))) == 0:
                    imgs = torch.clamp(imgs, min=0, max=1)
                    imgs_shifted = torch.clamp(imgs_shifted, min=0, max=1)
                    break
                print(
                    "Inf values in generated images!   Repeat image generation"
                )
                del imgs
                del imgs_shifted
                torch.cuda.empty_cache()

            # Periodically log a grid of original + shifted samples.
            if step % (self.p.steps_per_img_log / 2) == 0:
                image_grid = torchvision.utils.make_grid(
                    list(imgs[0:8]) + list(imgs_shifted[0:8]))
                self.writer.add_image("Training Shifted Images", image_grid,
                                      step)

            logits, shift_prediction = shift_predictor(imgs, imgs_shifted)
            logit_loss = self.p.label_weight * self.cross_entropy(
                logits, target_indices)
            shift_loss = self.p.shift_weight * torch.mean(
                torch.abs(shift_prediction - shifts))

            # total loss
            loss = logit_loss + shift_loss
            self.loss = loss.detach().item()
            loss.backward()

            if deformator_opt is not None:
                deformator_opt.step()
            shift_predictor_opt.step()

            if deformator.type not in [
                    DeformatorType.ID, DeformatorType.RANDOM
            ]:
                scheduler_def.step(loss)
            # scheduler_pred.step(loss)

            # update statistics trackers
            avg_correct_percent.add(
                torch.mean((torch.argmax(logits, dim=1) == target_indices).to(
                    torch.float32)).detach())
            avg_loss.add(loss.item())
            avg_label_loss.add(logit_loss.item())
            # NOTE(review): the live tensor is stored here while the sibling
            # trackers receive .item() floats — possibly retains the graph.
            avg_shift_loss.add(shift_loss)
            avg_lr.add(shift_predictor_opt.param_groups[0]['lr'])

            # Optuna, report intermediate objective value
            trial.report(avg_loss.mean(), step)

            self.log(G, deformator, shift_predictor, deformator_opt,
                     shift_predictor_opt, scheduler_def, scheduler_pred, step,
                     avgs)

            # Handle pruning based on the intermediate value.
            if trial.should_prune():
                self.writer.add_hparams(
                    {
                        "bs": self.p.batch_size,
                        "sft-lr": self.p.shift_predictor_lr,
                        "dfr-lr": self.p.deformator_lr
                    }, {"loss": self.loss})
                print("%%%% The trial was prouned %%%%")
                raise optuna.TrialPruned()

            if step < 11 or False:
                end = time.time()
                print(f"Epoch {step} = {(end - begin):.2f} seconds.")

        return self.loss
Ejemplo n.º 16
0
                            })
# you can use weight clip in the training scope
# RMSProp trainer for the discriminator (weight clipping left disabled here).
trainer_dis = gluon.Trainer(
    discriminator.collect_params(),
    optimizer='rmsprop',
    optimizer_params={
        'learning_rate': lr,
        'epsilon': 1e-11,
        # 'clip_weights': 0.01
    })

# Reuse a cached fixed noise batch (keyed by nz and batch_size) so sample
# grids are comparable across runs; create and cache it on first use.
fix_noise_name = os.path.join(fix_noise_dir, '{}_{}'.format(nz, batch_size))
if os.path.exists(fix_noise_name):
    fix_noise = mx.nd.load(fix_noise_name)[0]
else:
    fix_noise = make_noise()
    mx.nd.save(fix_noise_name, fix_noise)

# %% begin training
logger.info("Begin training")
# Per-phase timing accumulators and loss trackers for the epoch loop below.
dis_update_time = 0
gen_update_time = 0
iter4G = 100

g_train_loss = 0.0
d_train_loss = 0.0
for ep in tqdm.tqdm(range(epoch_start, epoch + 1),
                    total=epoch,
                    desc="Total Progress",
                    leave=False,
                    initial=epoch_start,
Ejemplo n.º 17
0
    def train(self, G, deformator, shift_predictor, inception):
        """Train deformator and shift predictor with an inception-feature loss.

        G stays frozen in eval mode. Besides the direction/shift losses, an
        L2 distance between inception features of original and shifted images
        (weighted by inception_loss_weight) regularizes the shifts, plus an
        optional deformator penalty selected by self.p.deformation_loss.
        """
        G.cuda().eval()
        deformator.cuda().train()
        shift_predictor.cuda().train()

        # ID/RANDOM deformators have no trainable parameters.
        deformator_opt = torch.optim.Adam(deformator.parameters(), lr=self.p.deformator_lr) \
            if deformator.type not in [DeformatorType.ID, DeformatorType.RANDOM] else None
        shift_predictor_opt = torch.optim.Adam(shift_predictor.parameters(),
                                               lr=self.p.shift_predictor_lr)

        avgs = MeanTracker('percent'), MeanTracker('loss'), MeanTracker('direction_loss'),\
               MeanTracker('shift_loss'), MeanTracker('deformator_loss'), MeanTracker('inception_loss')
        avg_correct_percent, avg_loss, avg_label_loss, avg_shift_loss, avg_deformator_loss, avg_inception_loss = avgs

        recovered_step = self.start_from_checkpoint(deformator,
                                                    shift_predictor)
        for step in range(recovered_step, self.p.n_steps, 1):
            G.zero_grad()
            deformator.zero_grad()
            shift_predictor.zero_grad()

            z = make_noise(self.p.batch_size, G.dim_z).cuda()
            z_orig = torch.clone(z)
            target_indices, shifts, z_shift = self.make_shifts(G.dim_z)

            # Deformation

            if self.p.global_deformation:
                z_shifted = deformator(z + z_shift)
                z = deformator(z)
            else:
                z_shifted = z + deformator(z_shift)

            imgs = G(z)
            imgs_shifted = G(z_shifted)

            ##########################
            # Map generator outputs from [-1, 1] to [0, 1] before inception.
            img_feats = inception(((imgs + 1.) / 2.).clamp(0, 1))
            if isinstance(img_feats, list):
                img_feats = img_feats[0]

            img_shifted_feats = inception(
                ((imgs_shifted + 1.) / 2.).clamp(0, 1))
            if isinstance(img_shifted_feats, list):
                img_shifted_feats = img_shifted_feats[0]

            # mean_img_feats = img_feats.mean(0)
            # std_img_feats = img_feats.std(0)
            # img_feats_distr = torch.distributions.Normal(loc=mean_img_feats, scale=std_img_feats)
            #
            # mean_img_shifted_feats = img_shifted_feats.mean(0)
            # std_img_shifted_feats = img_shifted_feats.std(0)
            # img_shifted_feats_distr = torch.distributions.Normal(loc=mean_img_shifted_feats,
            #                                              scale=std_img_shifted_feats)
            #
            # kl = torch.distributions.kl.kl_divergence(img_shifted_feats_distr, img_feats_distr)
            # inception_loss = self.p.inception_loss_weight * kl.mean()
            l2 = ((img_feats - img_shifted_feats)**2).mean()
            inception_loss = self.p.inception_loss_weight * l2
            ##########################

            logits, shift_prediction = shift_predictor(imgs, imgs_shifted)
            logit_loss = self.p.label_weight * self.cross_entropy(
                logits, target_indices)
            shift_loss = self.p.shift_weight * torch.mean(
                torch.abs(shift_prediction - shifts))

            # Loss

            # deformator penalty
            if self.p.deformation_loss == DeformatorLoss.STAT:
                z_std, z_mean = normal_projection_stat(z)
                z_loss = self.p.z_mean_weight * torch.abs(z_mean) + \
                    self.p.z_std_weight * torch.abs(1.0 - z_std)

            elif self.p.deformation_loss == DeformatorLoss.L2:
                z_loss = self.p.deformation_loss_weight * torch.mean(
                    torch.norm(z, dim=1))
                # Zero out the penalty when the deformed norm is already small
                # relative to the original latents.
                if z_loss < self.p.z_norm_loss_low_bound * torch.mean(
                        torch.norm(z_orig, dim=1)):
                    z_loss = torch.tensor([0.0], device='cuda')

            elif self.p.deformation_loss == DeformatorLoss.RELATIVE:
                deformation_norm = torch.norm(z - z_shifted, dim=1)
                z_loss = self.p.deformation_loss_weight * torch.mean(
                    torch.abs(deformation_norm - shifts))

            else:
                z_loss = torch.tensor([0.0], device='cuda')

            # total loss
            loss = logit_loss + shift_loss + z_loss + inception_loss
            loss.backward()

            if deformator_opt is not None:
                deformator_opt.step()
            shift_predictor_opt.step()

            # update statistics trackers
            avg_correct_percent.add(
                torch.mean((torch.argmax(logits, dim=1) == target_indices).to(
                    torch.float32)).detach())
            avg_loss.add(loss.item())
            avg_label_loss.add(logit_loss.item())
            # Store a plain float like the sibling trackers; passing the live
            # tensor kept the autograd graph alive between log flushes.
            avg_shift_loss.add(shift_loss.item())
            avg_deformator_loss.add(z_loss.item())
            avg_inception_loss.add(inception_loss.item())

            self.log(G, deformator, shift_predictor, step, avgs)