Beispiel #1
0
class fAnoGAN:
    """f-AnoGAN anomaly detector: a WGAN-GP generator/critic pair plus an
    optional encoder that maps images into the GAN latent space, so anomaly
    scores can be computed from reconstruction differences.

    Review findings that were intentionally NOT changed (code kept as-is)
    are flagged inline with ``NOTE(review)`` comments.
    """

    @monkey_patch_fn_args_as_config
    def __init__(
        self,
        input_shape,
        lr=1e-4,
        critic_iters=1,
        gen_iters=5,
        n_epochs=10,
        gp_lambda=10,
        z_dim=512,
        print_every_iter=20,
        plot_every_epoch=1,
        log_dir=None,
        load_path=None,
        logger="visdom",
        data_dir=None,
        use_encoder=True,
        enocoder_feature_weight=1e-4,
        encoder_discr_weight=0.0,
    ):
        """Build generator, critic and (optionally) encoder plus optimizers.

        Args:
            input_shape: (batch_size, n_channels, size, size); the spatial
                size is taken from index 2 (square images assumed — confirm).
            lr: Adam learning rate shared by all optimizers.
            critic_iters: critic is updated on every ``critic_iters``-th
                batch (see NOTE(review) in ``train``).
            gen_iters: generator is updated on every ``gen_iters``-th batch.
            n_epochs: GAN epochs; encoder then trains for ``n_epochs // 2``.
            gp_lambda: gradient-penalty weight (see NOTE(review) in ``train``).
            z_dim: latent-space dimensionality.
            print_every_iter: logging/plotting interval in iterations.
            plot_every_epoch: stored but not used anywhere in this class.
            log_dir: base directory for the experiment logger.
            load_path: if set, restore "dis_final.pth" and "gen_final.pth"
                from this directory.
            logger: logger backend identifier (e.g. "visdom") or None.
            data_dir: base directory of the on-disk numpy 2D dataset.
            use_encoder: whether to build and train the encoder.
            enocoder_feature_weight: weight of the critic feature-matching
                term in the encoder loss. NOTE(review): misspelled
                ("encoder"); kept for backward compatibility with callers
                and saved configs.
            encoder_discr_weight: weight of the critic-score term in the
                encoder loss.
        """

        self.plot_every_epoch = plot_every_epoch
        self.print_every_iter = print_every_iter
        self.gp_lambda = gp_lambda
        self.n_epochs = n_epochs
        self.gen_iters = gen_iters
        self.critic_iters = critic_iters
        self.size = input_shape[2]  # spatial size; square images assumed
        self.batch_size = input_shape[0]
        self.input_shape = input_shape
        self.z_dim = z_dim
        self.logger = logger
        self.data_dir = data_dir
        self.use_encoder = use_encoder
        self.enocoder_feature_weight = enocoder_feature_weight
        self.encoder_discr_weight = encoder_discr_weight

        log_dict = {}
        if logger is not None:
            # Index 0 is the logger later used via ``self.tx.l[0]``.
            log_dict = {
                0: (logger),
            }
        self.tx = PytorchExperimentStub(
            name="fanogan",
            base_dir=log_dir,
            # ``fn_args_as_config`` is presumably injected into scope by the
            # @monkey_patch_fn_args_as_config decorator — TODO confirm.
            config=fn_args_as_config,
            loggers=log_dict,
        )

        cuda_available = torch.cuda.is_available()
        self.device = torch.device("cuda" if cuda_available else "cpu")

        self.n_image_channels = input_shape[1]

        self.gen = IWGenerator(self.size,
                               z_dim=z_dim,
                               n_image_channels=self.n_image_channels)
        self.dis = IWDiscriminator(self.size,
                                   n_image_channels=self.n_image_channels)

        self.gen.apply(weights_init)
        self.dis.apply(weights_init)

        self.optimizer_G = torch.optim.Adam(self.gen.parameters(),
                                            lr=lr,
                                            betas=(0.5, 0.999))
        self.optimizer_D = torch.optim.Adam(self.dis.parameters(),
                                            lr=lr,
                                            betas=(0.5, 0.999))

        self.gen = self.gen.to(self.device)
        self.dis = self.dis.to(self.device)

        if self.use_encoder:
            self.enc = IWEncoder(self.size,
                                 z_dim=z_dim,
                                 n_image_channels=self.n_image_channels)
            self.enc.apply(weights_init)
            self.enc = self.enc.to(self.device)
            self.optimizer_E = torch.optim.Adam(self.enc.parameters(),
                                                lr=lr,
                                                betas=(0.5, 0.999))

        # Shared latent buffer, resampled in place each critic iteration.
        self.z = torch.randn(self.batch_size, z_dim).to(self.device)

        if load_path is not None:
            PytorchExperimentLogger.load_model_static(
                self.dis, os.path.join(load_path, "dis_final.pth"))
            PytorchExperimentLogger.load_model_static(
                self.gen, os.path.join(load_path, "gen_final.pth"))
            if self.use_encoder:
                try:
                    # NOTE(review): dead code — the encoder restore below is
                    # commented out, so encoder weights are never loaded and
                    # the except branch is unreachable.
                    pass
                    # PytorchExperimentLogger.load_model_static(self.enc, os.path.join(load_path, "enc_final.pth"))
                except Exception:
                    warnings.warn("Could not find an Encoder in the directory")
            # Presumably gives the async logger time to come up — TODO confirm.
            time.sleep(5)

    def train(self):
        """Train the WGAN-GP, then (if ``use_encoder``) the encoder.

        Phase 1: alternating critic/generator updates for ``n_epochs``
        epochs; saves "dis_final" and "gen_final".
        Phase 2: trains only the encoder for ``n_epochs // 2`` epochs on an
        image-reconstruction + critic-feature + critic-score loss and saves
        "enc_final".
        """

        train_loader = get_numpy2d_dataset(
            base_dir=self.data_dir,
            num_processes=16,
            pin_memory=False,
            batch_size=self.batch_size,
            mode="train",
            target_size=self.size,
            slice_offset=10,
        )

        print("Training GAN...")
        for epoch in range(self.n_epochs):

            data_loader_ = tqdm(enumerate(train_loader))
            for i, batch in data_loader_:
                # Rescale to [-1, 1] (assumes data in [0, 1] — TODO confirm)
                # and add slight noise, presumably to stabilize the critic.
                batch = batch * 2 - 1 + torch.randn_like(batch) * 0.01

                real_imgs = batch.to(self.device)

                # ---------------------
                #  Train Discriminator
                # ---------------------
                # NOTE(review): this updates the critic only on every
                # ``critic_iters``-th batch; the usual WGAN-GP schedule trains
                # the critic *more* often than the generator — confirm intent.
                if i % self.critic_iters == 0:
                    self.optimizer_G.zero_grad()
                    self.optimizer_D.zero_grad()

                    batch_size_curr = real_imgs.shape[0]

                    # Resample the shared latent buffer in place.
                    self.z.normal_()

                    fake_imgs = self.gen(self.z[:batch_size_curr])

                    real_validity = self.dis(real_imgs)
                    fake_validity = self.dis(fake_imgs)

                    gradient_penalty = self.calc_gradient_penalty(
                        self.dis,
                        real_imgs,
                        fake_imgs,
                        batch_size_curr,
                        self.size,
                        self.device,
                        self.gp_lambda,
                        n_image_channels=self.n_image_channels,
                    )

                    # NOTE(review): ``calc_gradient_penalty`` already scales
                    # its result by gp_lambda, so multiplying again here
                    # applies gp_lambda twice (effective weight gp_lambda**2)
                    # — confirm against the intended WGAN-GP objective.
                    d_loss = -torch.mean(real_validity) + torch.mean(
                        fake_validity) + self.gp_lambda * gradient_penalty
                    d_loss.backward()
                    self.optimizer_D.step()

                    # Wasserstein-distance estimate (without penalty term).
                    w_dist = (-torch.mean(real_validity) +
                              torch.mean(fake_validity)).item()

                # -----------------
                #  Train Generator
                # -----------------
                if i % self.gen_iters == 0:
                    self.optimizer_G.zero_grad()
                    self.optimizer_D.zero_grad()

                    # Uses the full latent buffer, so the generator batch is
                    # always ``self.batch_size`` even on a short last batch.
                    batch_size_curr = self.batch_size

                    fake_imgs = self.gen(self.z)

                    fake_validity = self.dis(fake_imgs)
                    g_loss = -torch.mean(fake_validity)

                    g_loss.backward()
                    self.optimizer_G.step()

                if i % self.print_every_iter == 0:
                    # d_loss/g_loss/w_dist are defined once their branches ran;
                    # both branches fire at i == 0, so this is safe for any
                    # critic_iters/gen_iters >= 1.
                    status_str = (
                        f"Train Epoch: {epoch} [{i}/{len(train_loader)} "
                        f" ({100.0 * i / len(train_loader):.0f}%)] Dis: "
                        f"{d_loss.item() / batch_size_curr:.6f} vs Gen: "
                        f"{g_loss.item() / batch_size_curr:.6f} (W-Dist: {w_dist / batch_size_curr:.6f})"
                    )
                    data_loader_.set_description_str(status_str)

                    cnt = epoch * len(train_loader) + i

                    self.tx.add_result(d_loss.item(),
                                       name="trainDisCost",
                                       tag="DisVsGen",
                                       counter=cnt)
                    self.tx.add_result(g_loss.item(),
                                       name="trainGenCost",
                                       tag="DisVsGen",
                                       counter=cnt)
                    self.tx.add_result(w_dist,
                                       "wasserstein_distance",
                                       counter=cnt)

                    self.tx.l[0].show_image_grid(
                        fake_imgs.reshape(batch_size_curr,
                                          self.n_image_channels, self.size,
                                          self.size),
                        "GeneratedImages",
                        image_args={"normalize": True},
                    )

        self.tx.save_model(self.dis, "dis_final")
        self.tx.save_model(self.gen, "gen_final")

        # NOTE(review): train(True) leaves both nets in *training* mode after
        # GAN training is finished; train(False)/eval() was likely intended
        # (compare ``self.enc.train(False)`` below) — confirm.
        self.gen.train(True)
        self.dis.train(True)

        if not self.use_encoder:
            time.sleep(10)
            return

        weight_features = self.enocoder_feature_weight
        weight_disc = self.encoder_discr_weight
        print("Training Encoder...")
        for epoch in range(self.n_epochs // 2):
            data_loader_ = tqdm(enumerate(train_loader))
            for i, batch in data_loader_:
                # Same rescaling/noise as in the GAN phase.
                batch = batch * 2 - 1 + torch.randn_like(batch) * 0.01
                real_img = batch.to(self.device)
                batch_size_curr = real_img.shape[0]

                self.optimizer_G.zero_grad()
                self.optimizer_D.zero_grad()
                self.optimizer_E.zero_grad()

                # Encode -> decode; only optimizer_E steps, so only the
                # encoder's weights are updated in this phase.
                z = self.enc(real_img)
                recon_img = self.gen(z)

                # Critic features for feature-matching loss; the critic's
                # scalar output on the reconstruction is reused as a loss term.
                _, img_feats = self.dis.forward_last_feature(real_img)
                disc_loss, recon_feats = self.dis.forward_last_feature(
                    recon_img)

                recon_img = recon_img.reshape(batch_size_curr,
                                              self.n_image_channels, self.size,
                                              self.size)
                loss_img = self.mse(real_img, recon_img)
                loss_feat = self.mse(img_feats, recon_feats) * weight_features
                disc_loss = -torch.mean(disc_loss) * weight_disc

                loss = loss_img + loss_feat + disc_loss

                loss.backward()
                self.optimizer_E.step()

                if i % self.print_every_iter == 0:
                    status_str = (
                        f"[Epoch {epoch}/{self.n_epochs // 2}] [Batch {i}/{len(train_loader)}] Loss:{loss:.06f}"
                    )
                    data_loader_.set_description_str(status_str)

                    cnt = epoch * len(train_loader) + i
                    self.tx.add_result(loss.item(),
                                       name="EncoderLoss",
                                       counter=cnt)

                    self.tx.l[0].show_image_grid(
                        real_img.reshape(batch_size_curr,
                                         self.n_image_channels, self.size,
                                         self.size),
                        "RealImages",
                        image_args={"normalize": True},
                    )
                    self.tx.l[0].show_image_grid(
                        recon_img.reshape(batch_size_curr,
                                          self.n_image_channels, self.size,
                                          self.size),
                        "ReconImages",
                        image_args={"normalize": True},
                    )

        self.tx.save_model(self.enc, "enc_final")
        self.enc.train(False)

        # Presumably lets the async logger flush before returning — confirm.
        time.sleep(10)

    def score_sample(self, np_array):
        """Return one anomaly score for a 3D volume.

        Args:
            np_array: array of shape (n_slices, H, W) — presumably with
                values in [0, 1], matching training; TODO confirm.

        Returns:
            Maximum per-slice score (sum of channel-mean absolute
            reconstruction differences) over all slices.
        """

        orig_shape = np_array.shape
        to_transforms = torch.nn.Upsample(
            (self.input_shape[2], self.input_shape[3]), mode="bilinear")

        data_tensor = torch.from_numpy(np_array).float()
        data_tensor = to_transforms(data_tensor[None])[0]
        slice_scores = []

        for i in range(ceil(orig_shape[0] / self.batch_size)):
            batch = data_tensor[i * self.batch_size:(i + 1) *
                                self.batch_size].unsqueeze(1)
            batch = batch * 2 - 1  # same [-1, 1] rescaling as in training
            real_imgs = batch.to(self.device)
            batch_size_curr = real_imgs.shape[0]

            # Latent code: fast path via the encoder, otherwise iterative
            # per-batch optimization of z.
            if self.use_encoder:
                z = self.enc(real_imgs)
            else:
                z = self.backprop_to_nearest_z(real_imgs)

            pseudo_img_recon = self.gen(z)

            pseudo_img_recon = pseudo_img_recon.reshape(
                batch_size_curr, self.n_image_channels, self.size, self.size)
            img_diff = torch.mean(torch.abs(pseudo_img_recon - real_imgs),
                                  dim=1,
                                  keepdim=True)

            loss = torch.sum(img_diff, dim=(1, 2, 3)).detach()

            slice_scores += loss.cpu().tolist()

        return np.max(slice_scores)

    def score_pixels(self, np_array):
        """Return a per-pixel anomaly map for a 3D volume.

        Args:
            np_array: array of shape (n_slices, H, W).

        Returns:
            np.ndarray of shape (n_slices, H, W): channel-mean absolute
            reconstruction difference, resampled back to input resolution.

        NOTE(review): the forward passes here run without torch.no_grad(),
        so an autograd graph is built needlessly during pure scoring.
        """

        orig_shape = np_array.shape
        to_transforms = torch.nn.Upsample(
            (self.input_shape[2], self.input_shape[3]), mode="bilinear")
        from_transforms = torch.nn.Upsample((orig_shape[1], orig_shape[2]),
                                            mode="bilinear")

        data_tensor = torch.from_numpy(np_array).float()
        data_tensor = to_transforms(data_tensor[None])[0]
        target_tensor = torch.zeros_like(data_tensor)

        for i in range(ceil(orig_shape[0] / self.batch_size)):
            batch = data_tensor[i * self.batch_size:(i + 1) *
                                self.batch_size].unsqueeze(1)
            batch = batch * 2 - 1  # same [-1, 1] rescaling as in training
            real_imgs = batch.to(self.device)
            batch_size_curr = real_imgs.shape[0]

            if self.use_encoder:
                z = self.enc(real_imgs)
            else:
                z = self.backprop_to_nearest_z(real_imgs)

            pseudo_img_recon = self.gen(z)

            pseudo_img_recon = pseudo_img_recon.reshape(
                batch_size_curr, self.n_image_channels, self.size, self.size)
            img_diff = torch.mean(torch.abs(pseudo_img_recon - real_imgs),
                                  dim=1,
                                  keepdim=True)

            # Keep channel 0 of the (keepdim) channel-mean difference.
            loss = img_diff[:, 0, :]
            target_tensor[i * self.batch_size:(i + 1) *
                          self.batch_size] = loss.cpu()

        target_tensor = from_transforms(target_tensor[None])[0]

        return target_tensor.detach().numpy()

    def backprop_to_nearest_z(self, real_imgs):
        """Find latent codes for ``real_imgs`` by gradient descent on z.

        Optimizes a randomly initialized z for 200 Adam steps against the
        image reconstruction difference minus a small critic-score term.
        Used when no encoder is available.

        Args:
            real_imgs: image batch of shape (N, C, size, size) on
                ``self.device``, scaled like the training data.

        Returns:
            Detached latent tensor of shape (N, z_dim).
        """

        batch_size_curr = real_imgs.shape[0]

        z = torch.randn(batch_size_curr, self.z_dim).to(self.device).normal_()
        z.requires_grad = True
        optimizer_z = torch.optim.Adam([z], lr=0.002)

        for i in range(200):

            # Closure form so the same code also works with optimizers that
            # re-evaluate the loss (e.g. LBFGS).
            def closure():
                self.gen.zero_grad()
                optimizer_z.zero_grad()

                pseudo_img_recon = self.gen(z)

                _, img_feats = self.dis.forward_last_feature(real_imgs)
                disc_loss, recon_feats = self.dis.forward_last_feature(
                    pseudo_img_recon)

                pseudo_img_recon = pseudo_img_recon.reshape(
                    batch_size_curr, self.n_image_channels, self.size,
                    self.size)
                disc_loss = torch.mean(disc_loss)

                imgs_diff = torch.mean(torch.abs(pseudo_img_recon - real_imgs))
                # Computed but currently unused in the loss (see next line).
                feats_diff = torch.mean(torch.abs(img_feats - recon_feats))
                loss = imgs_diff - disc_loss * 0.001  # + feats_diff

                loss.backward()

                return loss

            optimizer_z.step(closure)

        return z.detach()

    def score(self, batch):
        """Return per-image and per-pixel anomaly scores for one batch.

        Unlike ``score_sample``/``score_pixels`` this always uses the
        encoder (``self.enc``) and does not fall back to latent optimization.

        Args:
            batch: image tensor of shape (N, C, H, W) — presumably already
                scaled like the training data; TODO confirm against callers.

        Returns:
            Tuple ``(img_scores, pixel_scores)`` of plain Python lists:
            per-image sums of the difference map, and the flattened map.
        """
        real_imgs = batch.to(self.device).float()

        z = self.enc(real_imgs)

        batch_size_curr = real_imgs.shape[0]

        pseudo_img_recon = self.gen(z)

        pseudo_img_recon = pseudo_img_recon.reshape(batch_size_curr,
                                                    self.n_image_channels,
                                                    self.size, self.size)
        # Channel-mean absolute reconstruction difference per pixel.
        img_diff = torch.mean(torch.abs(pseudo_img_recon - real_imgs),
                              dim=1,
                              keepdim=True)

        img_scores = torch.sum(img_diff, dim=(1, 2, 3)).detach().tolist()
        pixel_scores = img_diff.flatten().detach().tolist()

        self.tx.vlog.show_image_grid(pseudo_img_recon,
                                     "PseudoImages",
                                     image_args={"normalize": True})
        self.tx.vlog.show_image_grid(
            torch.mean(torch.abs(pseudo_img_recon - real_imgs),
                       dim=1,
                       keepdim=True),
            "DiffImages",
            image_args={"normalize": True},
        )

        return img_scores, pixel_scores

    @staticmethod
    def mse(x, y):
        """Mean squared error between two tensors (scalar tensor)."""
        return torch.mean(torch.pow(x - y, 2))

    @staticmethod
    def calc_gradient_penalty(netD,
                              real_data,
                              fake_data,
                              batch_size,
                              dim,
                              device,
                              gp_lambda,
                              n_image_channels=3):
        """WGAN-GP gradient penalty, ALREADY scaled by ``gp_lambda``.

        Penalizes deviation of the critic's gradient norm from 1 at random
        interpolates between real and fake samples.

        Args:
            netD: critic network.
            real_data: real batch of shape (batch_size, C, dim, dim).
            fake_data: fake batch, reshapeable to the same shape.
            batch_size: number of samples in the batch.
            dim: spatial size (square images).
            device: torch device for the interpolates.
            gp_lambda: penalty weight, multiplied into the return value
                (see NOTE(review) in ``train`` about double application).
            n_image_channels: number of image channels C.

        Returns:
            Scalar tensor ``gp_lambda * mean((||grad||_2 - 1)**2)``.
        """
        # One interpolation coefficient per sample, broadcast over all of
        # that sample's elements.
        alpha = torch.rand(batch_size, 1)
        alpha = alpha.expand(batch_size, int(real_data.nelement() /
                                             batch_size)).contiguous()
        alpha = alpha.view(batch_size, n_image_channels, dim, dim)
        alpha = alpha.to(device)

        fake_data = fake_data.view(batch_size, n_image_channels, dim, dim)
        interpolates = alpha * real_data.detach() + (
            (1 - alpha) * fake_data.detach())

        interpolates = interpolates.to(device)
        interpolates.requires_grad_(True)

        disc_interpolates = netD(interpolates)

        # create_graph=True so the penalty itself is differentiable w.r.t.
        # the critic's parameters.
        gradients = torch.autograd.grad(
            outputs=disc_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones(disc_interpolates.size()).to(device),
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]

        gradients = gradients.view(gradients.size(0), -1)
        gradient_penalty = (
            (gradients.norm(2, dim=1) - 1)**2).mean() * gp_lambda
        return gradient_penalty

    def print(self, *args):
        """Print to stdout and mirror the message to the experiment logger."""
        print(*args)
        self.tx.print(*args)

    def log_result(self, val, key=None):
        """Log ``val`` under ``key`` to the experiment logger (no epoch)."""
        self.tx.print(key, val)
        self.tx.add_result_without_epoch(val, key)
Beispiel #2
0
class AE2D:
    """2D convolutional autoencoder for anomaly detection.

    Trains an ``AE`` on 2D slices with an MSE reconstruction objective and
    scores unseen data by reconstruction error — one score per volume
    (``score_sample``) or a per-pixel error map (``score_pixels``).
    """

    @monkey_patch_fn_args_as_config
    def __init__(
        self,
        input_shape,
        lr=1e-4,
        n_epochs=20,
        z_dim=512,
        model_feature_map_sizes=(16, 64, 256, 1024),
        load_path=None,
        log_dir=None,
        logger="visdom",
        print_every_iter=100,
        data_dir=None,
    ):
        """Build the autoencoder, its optimizer and the experiment logger.

        Args:
            input_shape: (batch_size, n_channels, H, W) of training batches.
            lr: Adam learning rate.
            n_epochs: number of training epochs.
            z_dim: bottleneck dimensionality.
            model_feature_map_sizes: feature-map widths of the AE stages.
            load_path: if set, restore "ae_final.pth" from this directory.
            log_dir: base directory for the experiment logger.
            logger: logger backend identifier (e.g. "visdom") or None.
            print_every_iter: logging/plotting interval in iterations.
            data_dir: base directory of the on-disk numpy 2D dataset.
        """

        self.print_every_iter = print_every_iter
        self.n_epochs = n_epochs
        self.batch_size = input_shape[0]
        self.z_dim = z_dim
        self.input_shape = input_shape
        self.logger = logger
        self.data_dir = data_dir

        log_dict = {}
        if logger is not None:
            # Index 0 is the logger later used via ``self.tx.l[0]``.
            log_dict = {
                0: (logger),
            }
        self.tx = PytorchExperimentStub(
            name="ae2d",
            base_dir=log_dir,
            # ``fn_args_as_config`` is presumably injected into scope by the
            # @monkey_patch_fn_args_as_config decorator — TODO confirm.
            config=fn_args_as_config,
            loggers=log_dict,
        )

        cuda_available = torch.cuda.is_available()
        self.device = torch.device("cuda" if cuda_available else "cpu")

        self.model = AE(input_size=input_shape[1:],
                        z_dim=z_dim,
                        fmap_sizes=model_feature_map_sizes).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

        if load_path is not None:
            PytorchExperimentLogger.load_model_static(
                self.model, os.path.join(load_path, "ae_final.pth"))
            # Presumably gives the async logger time to come up — TODO confirm.
            time.sleep(5)

    def train(self):
        """Train the autoencoder with per-epoch validation.

        Logs train/val losses (plus input/reconstruction image grids when a
        logger is configured) and saves the final model as "ae_final".
        """

        train_loader = get_numpy2d_dataset(
            base_dir=self.data_dir,
            num_processes=16,
            pin_memory=True,
            batch_size=self.batch_size,
            mode="train",
            target_size=self.input_shape[2],
        )
        val_loader = get_numpy2d_dataset(
            base_dir=self.data_dir,
            num_processes=8,
            pin_memory=True,
            batch_size=self.batch_size,
            mode="val",
            target_size=self.input_shape[2],
        )

        for epoch in range(self.n_epochs):

            ### Train
            self.model.train()

            train_loss = 0
            print("\nStart epoch ", epoch)
            data_loader_ = tqdm(enumerate(train_loader))
            for batch_idx, data in data_loader_:
                inpt = data.to(self.device)

                self.optimizer.zero_grad()
                inpt_rec = self.model(inpt)

                # Plain MSE reconstruction objective.
                loss = torch.mean(torch.pow(inpt - inpt_rec, 2))
                loss.backward()
                self.optimizer.step()

                train_loss += loss.item()
                if batch_idx % self.print_every_iter == 0:
                    status_str = (
                        f"Train Epoch: {epoch} [{batch_idx}/{len(train_loader)} "
                        f" ({100.0 * batch_idx / len(train_loader):.0f}%)] Loss: "
                        f"{loss.item() / len(inpt):.6f}")
                    data_loader_.set_description_str(status_str)

                    cnt = epoch * len(train_loader) + batch_idx
                    self.tx.add_result(loss.item(),
                                       name="Train-Loss",
                                       tag="Losses",
                                       counter=cnt)

                    if self.logger is not None:
                        self.tx.l[0].show_image_grid(
                            inpt, name="Input", image_args={"normalize": True})
                        self.tx.l[0].show_image_grid(
                            inpt_rec,
                            name="Reconstruction",
                            image_args={"normalize": True})

            print(
                f"====> Epoch: {epoch} Average loss: {train_loss / len(train_loader):.4f}"
            )

            ### Validate
            self.model.eval()

            val_loss = 0
            with torch.no_grad():
                data_loader_ = tqdm(enumerate(val_loader))
                data_loader_.set_description_str("Validating")
                for i, data in data_loader_:
                    inpt = data.to(self.device)
                    inpt_rec = self.model(inpt)

                    loss = torch.mean(torch.pow(inpt - inpt_rec, 2))
                    val_loss += loss.item()

                self.tx.add_result(val_loss / len(val_loader),
                                   name="Val-Loss",
                                   tag="Losses",
                                   counter=(epoch + 1) * len(train_loader))

            print(
                f"====> Epoch: {epoch} Validation loss: {val_loss / len(val_loader):.4f}"
            )

        self.tx.save_model(self.model, "ae_final")

        # Presumably lets the async logger flush before returning — confirm.
        time.sleep(10)

    def score_sample(self, np_array):
        """Return one anomaly score for a 3D volume.

        Args:
            np_array: array of shape (n_slices, H, W).

        Returns:
            The maximum per-slice mean squared reconstruction error.
        """

        orig_shape = np_array.shape
        to_transforms = torch.nn.Upsample(
            (self.input_shape[2], self.input_shape[3]), mode="bilinear")

        data_tensor = torch.from_numpy(np_array).float()
        data_tensor = to_transforms(data_tensor[None])[0]
        slice_scores = []

        for i in range(ceil(orig_shape[0] / self.batch_size)):
            batch = data_tensor[i * self.batch_size:(i + 1) *
                                self.batch_size].unsqueeze(1)
            batch = batch.to(self.device)

            with torch.no_grad():
                batch_rec = self.model(batch)
                loss = torch.mean(torch.pow(batch - batch_rec, 2),
                                  dim=(1, 2, 3))

            slice_scores += loss.cpu().tolist()

        return np.max(slice_scores)

    def score_pixels(self, np_array):
        """Return a per-pixel anomaly map for a 3D volume.

        Args:
            np_array: array of shape (n_slices, H, W).

        Returns:
            np.ndarray of shape (n_slices, H, W): squared reconstruction
            error of channel 0, resampled back to the input resolution.
        """

        orig_shape = np_array.shape
        to_transforms = torch.nn.Upsample(
            (self.input_shape[2], self.input_shape[3]), mode="bilinear")
        from_transforms = torch.nn.Upsample((orig_shape[1], orig_shape[2]),
                                            mode="bilinear")

        data_tensor = torch.from_numpy(np_array).float()
        data_tensor = to_transforms(data_tensor[None])[0]
        target_tensor = torch.zeros_like(data_tensor)

        for i in range(ceil(orig_shape[0] / self.batch_size)):
            batch = data_tensor[i * self.batch_size:(i + 1) *
                                self.batch_size].unsqueeze(1)
            batch = batch.to(self.device)

            # Fix: run inference without building an autograd graph, matching
            # ``score_sample`` above — the original tracked gradients during
            # pure scoring, wasting memory for no benefit.
            with torch.no_grad():
                batch_rec = self.model(batch)
                loss = torch.pow(batch - batch_rec, 2)[:, 0, :]

            target_tensor[i * self.batch_size:(i + 1) *
                          self.batch_size] = loss.cpu()

        target_tensor = from_transforms(target_tensor[None])[0]

        return target_tensor.detach().numpy()

    def print(self, *args):
        """Print to stdout and mirror the message to the experiment logger."""
        print(*args)
        self.tx.print(*args)
Beispiel #3
0
class ceVAE:
    @monkey_patch_fn_args_as_config
    def __init__(
        self,
        input_shape,
        lr=1e-4,
        n_epochs=20,
        z_dim=512,
        model_feature_map_sizes=(16, 64, 256, 1024),
        use_geco=False,
        beta=0.01,
        ce_factor=0.5,
        score_mode="combi",
        load_path=None,
        log_dir=None,
        logger="visdom",
        print_every_iter=100,
        data_dir=None,
    ):
        """Set up a context-encoding VAE (ceVAE) trainer/scorer.

        Args:
            input_shape: Data shape (batch_size, channels, height, width);
                element 0 is used as the batch size, elements 1: as the model
                input size, element 2 as the dataset target size.
            lr: Adam learning rate.
            n_epochs: Number of training epochs.
            z_dim: Dimensionality of the VAE latent space.
            model_feature_map_sizes: Feature-map widths of the VAE stages.
            use_geco: If True, adapt the reconstruction weight ``theta``
                via GECO during training.
            beta: Weight of the KL term in the VAE loss.
            ce_factor: Mixing factor in [0, 1] between the context-encoding
                (inpainting) loss and the plain VAE loss.
            score_mode: Pixel-scoring strategy ("combi", "rec" or "grad").
            load_path: If given, directory containing a pretrained
                "vae_final.pth" checkpoint to load.
            log_dir: Base directory for experiment logs.
            logger: Logger backend name (e.g. "visdom"), or None to disable.
            print_every_iter: Train-logging interval in iterations.
            data_dir: Base directory of the numpy training/validation data.
        """

        self.score_mode = score_mode
        self.ce_factor = ce_factor
        self.beta = beta
        self.print_every_iter = print_every_iter
        self.n_epochs = n_epochs
        self.batch_size = input_shape[0]
        self.z_dim = z_dim
        self.use_geco = use_geco
        self.input_shape = input_shape
        self.logger = logger
        self.data_dir = data_dir

        # Logger 0 is the backend train() uses via self.tx.l[0].
        log_dict = {}
        if logger is not None:
            log_dict = {
                0: (logger),
            }
        self.tx = PytorchExperimentStub(
            name="cevae",
            base_dir=log_dir,
            # NOTE(review): fn_args_as_config is presumably injected by the
            # @monkey_patch_fn_args_as_config decorator (captures the ctor
            # arguments as the experiment config) — confirm against its impl.
            config=fn_args_as_config,
            loggers=log_dict,
        )

        cuda_available = torch.cuda.is_available()
        self.device = torch.device("cuda" if cuda_available else "cpu")

        self.model = VAE(input_size=input_shape[1:],
                         z_dim=z_dim,
                         fmap_sizes=model_feature_map_sizes).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

        # EMA of the VAE reconstruction loss and the GECO-adapted
        # reconstruction weight; both updated in train() when use_geco is set.
        self.vae_loss_ema = 1
        self.theta = 1

        if load_path is not None:
            PytorchExperimentLogger.load_model_static(
                self.model, os.path.join(load_path, "vae_final.pth"))
            # Brief pause, presumably to let the logger/IO settle after the
            # static load — TODO confirm this is still needed.
            time.sleep(5)

    def train(self):
        """Train the ceVAE on 2D slices.

        The total loss is a convex combination
        ``(1 - ce_factor) * loss_vae + ce_factor * loss_ce`` of
        a plain VAE loss (KL * beta + rec * theta) and a context-encoding
        loss (reconstruction of the clean input from a noise-masked input).
        Optionally adapts ``self.theta`` via GECO. Logs images/scalars to
        ``self.tx`` and saves the final model as "vae_final".
        """

        # Data loaders over 2D numpy slices; split selected via `mode`.
        train_loader = get_numpy2d_dataset(
            base_dir=self.data_dir,
            num_processes=16,
            pin_memory=False,
            batch_size=self.batch_size,
            mode="train",
            target_size=self.input_shape[2],
        )
        val_loader = get_numpy2d_dataset(
            base_dir=self.data_dir,
            num_processes=8,
            pin_memory=False,
            batch_size=self.batch_size,
            mode="val",
            target_size=self.input_shape[2],
        )

        for epoch in range(self.n_epochs):

            self.model.train()
            train_loss = 0

            print("Start epoch")
            data_loader_ = tqdm(enumerate(train_loader))
            for batch_idx, data in data_loader_:
                # Rescale inputs to [-1, 1] — assumes the loader yields
                # data in [0, 1]; TODO confirm against get_numpy2d_dataset.
                data = data * 2 - 1
                self.optimizer.zero_grad()

                inpt = data.to(self.device)

                ### VAE Part
                loss_vae = 0
                if self.ce_factor < 1:
                    x_rec_vae, z_dist, = self.model(inpt)

                    kl_loss = 0
                    if self.beta > 0:
                        kl_loss = self.kl_loss_fn(z_dist) * self.beta
                    rec_loss_vae = self.rec_loss_fn(x_rec_vae, inpt)
                    # theta is the (possibly GECO-adapted) rec-loss weight.
                    loss_vae = kl_loss + rec_loss_vae * self.theta

                ### CE Part
                loss_ce = 0
                if self.ce_factor > 0:

                    # Mask out up to 3 random squares with noise drawn from
                    # the batch's value range; the model must inpaint them.
                    ce_tensor = get_square_mask(
                        data.shape,
                        square_size=(0, np.max(self.input_shape[2:]) // 2),
                        noise_val=(torch.min(data).item(),
                                   torch.max(data).item()),
                        n_squares=(0, 3),
                    )
                    ce_tensor = torch.from_numpy(ce_tensor).float()
                    inpt_noisy = torch.where(ce_tensor != 0, ce_tensor, data)

                    inpt_noisy = inpt_noisy.to(self.device)
                    x_rec_ce, _ = self.model(inpt_noisy)
                    # Reconstruction target is the CLEAN input, not the
                    # noisy one — that is what makes this context encoding.
                    rec_loss_ce = self.rec_loss_fn(x_rec_ce, inpt)
                    loss_ce = rec_loss_ce

                # Convex combination of the two objectives.
                loss = (1.0 -
                        self.ce_factor) * loss_vae + self.ce_factor * loss_ce

                if self.use_geco and self.ce_factor < 1:
                    # GECO: track an EMA of the rec loss and nudge theta so
                    # the EMA approaches g_goal.
                    g_goal = 0.1
                    g_lr = 1e-4
                    self.vae_loss_ema = (
                        1.0 - 0.9) * rec_loss_vae + 0.9 * self.vae_loss_ema
                    self.theta = self.geco_beta_update(self.theta,
                                                       self.vae_loss_ema,
                                                       g_goal,
                                                       g_lr,
                                                       speedup=2)

                if torch.isnan(loss):
                    # Skip diverged batches instead of stepping on NaNs.
                    print("A wild NaN occurred")
                    continue

                loss.backward()
                self.optimizer.step()
                train_loss += loss.item()

                if batch_idx % self.print_every_iter == 0:
                    status_str = (
                        f"Train Epoch: {epoch} [{batch_idx}/{len(train_loader)} "
                        f" ({100.0 * batch_idx / len(train_loader):.0f}%)] Loss: "
                        f"{loss.item() / len(inpt):.6f}")
                    data_loader_.set_description_str(status_str)

                    # Global step counter for the scalar logs.
                    cnt = epoch * len(train_loader) + batch_idx

                    if self.ce_factor < 1:
                        self.tx.l[0].show_image_grid(
                            inpt,
                            name="Input-VAE",
                            image_args={"normalize": True})
                        self.tx.l[0].show_image_grid(
                            x_rec_vae,
                            name="Output-VAE",
                            image_args={"normalize": True})

                        if self.beta > 0:
                            self.tx.add_result(torch.mean(kl_loss).item(),
                                               name="Kl-loss",
                                               tag="Losses",
                                               counter=cnt)
                        self.tx.add_result(torch.mean(rec_loss_vae).item(),
                                           name="Rec-loss",
                                           tag="Losses",
                                           counter=cnt)
                        self.tx.add_result(loss_vae.item(),
                                           name="Train-loss",
                                           tag="Losses",
                                           counter=cnt)

                    if self.ce_factor > 0:
                        self.tx.l[0].show_image_grid(
                            inpt_noisy,
                            name="Input-CE",
                            image_args={"normalize": True})
                        self.tx.l[0].show_image_grid(
                            x_rec_ce,
                            name="Output-CE",
                            image_args={"normalize": True})

            print(
                f"====> Epoch: {epoch} Average loss: {train_loss / len(train_loader):.4f}"
            )

            # Validation: deterministic forward pass (sample=False), no grads.
            self.model.eval()

            val_loss = 0
            with torch.no_grad():
                data_loader_ = tqdm(enumerate(val_loader))
                for i, data in data_loader_:
                    data = data * 2 - 1
                    inpt = data.to(self.device)

                    x_rec, z_dist = self.model(inpt, sample=False)

                    kl_loss = 0
                    if self.beta > 0:
                        kl_loss = self.kl_loss_fn(z_dist) * self.beta
                    rec_loss = self.rec_loss_fn(x_rec, inpt)
                    loss = kl_loss + rec_loss * self.theta

                    val_loss += loss.item()

                self.tx.add_result(val_loss / len(val_loader),
                                   name="Val-Loss",
                                   tag="Losses",
                                   counter=(epoch + 1) * len(train_loader))

            print(
                f"====> Epoch: {epoch} Validation loss: {val_loss / len(val_loader):.4f}"
            )

        self.tx.save_model(self.model, "vae_final")

        # Pause, presumably so async loggers can flush — TODO confirm needed.
        time.sleep(10)

    def score_sample(self, np_array):
        """Sample-level anomaly score for a volume of slices.

        Resizes the volume to the model resolution, evaluates the VAE loss
        (KL * beta + rec * theta) per slice, and returns the maximum.

        Args:
            np_array: Volume with shape (n_slices, H, W).

        Returns:
            float: max over slices of the per-slice VAE loss.
        """
        resize = torch.nn.Upsample(
            (self.input_shape[2], self.input_shape[3]), mode="bilinear")
        volume = resize(torch.from_numpy(np_array).float()[None])[0]

        scores = []
        # torch.split yields the same chunking as manual batch_size slicing.
        for chunk in torch.split(volume, self.batch_size):
            # Same [0, 1] -> [-1, 1] mapping as used during training.
            scaled = chunk.unsqueeze(1) * 2 - 1

            with torch.no_grad():
                inpt = scaled.to(self.device).float()
                x_rec, z_dist = self.model(inpt, sample=False)
                kl_part = self.kl_loss_fn(z_dist, sum_samples=False)
                rec_part = self.rec_loss_fn(x_rec, inpt, sum_samples=False)
                per_slice = kl_part * self.beta + rec_part * self.theta

            scores.extend(per_slice.cpu().tolist())

        return np.max(scores)

    def score_pixels(self, np_array):
        """Pixel-wise anomaly scores for a volume of slices.

        The volume is resized to the model resolution, scored batch-wise
        according to ``self.score_mode``, and the score map is resized back
        to the original in-plane resolution. Inputs/reconstructions/scores
        are logged as image grids via ``self.tx.elog``.

        Args:
            np_array: Volume with shape (n_slices, H, W).

        Returns:
            np.ndarray with shape ``np_array.shape`` of per-pixel scores.

        Raises:
            ValueError: If ``self.score_mode`` is not one of
                "combi", "rec" or "grad".
        """
        # Fix: fail fast on an unsupported mode. Previously an unknown mode
        # fell through every branch and crashed with
        # `NameError: pixel_scores` deep inside the loop.
        if self.score_mode not in ("combi", "rec", "grad"):
            raise ValueError(
                f"Unknown score_mode {self.score_mode!r}; "
                "expected 'combi', 'rec' or 'grad'")

        orig_shape = np_array.shape
        to_transforms = torch.nn.Upsample(
            (self.input_shape[2], self.input_shape[3]), mode="bilinear")
        from_transforms = torch.nn.Upsample((orig_shape[1], orig_shape[2]),
                                            mode="bilinear")

        data_tensor = torch.from_numpy(np_array).float()
        data_tensor = to_transforms(data_tensor[None])[0]
        target_tensor = torch.zeros_like(data_tensor)

        for i in range(ceil(orig_shape[0] / self.batch_size)):
            batch = data_tensor[i * self.batch_size:(i + 1) *
                                self.batch_size].unsqueeze(1)
            # Same [0, 1] -> [-1, 1] mapping as used during training.
            batch = batch * 2 - 1

            inpt = batch.to(self.device).float()
            x_rec, z_dist = self.model(inpt, sample=False)

            if self.score_mode == "combi":

                # Squared reconstruction error, averaged over channels.
                rec = torch.pow((x_rec - inpt), 2).detach().cpu()
                rec = torch.mean(rec, dim=1, keepdim=True)

                def __err_fn(x):
                    # KL term only; its input gradient highlights pixels the
                    # latent code finds "surprising".
                    x_r, z_d = self.model(x, sample=False)
                    loss = self.kl_loss_fn(z_d)
                    return loss

                loss_grad_kl = (get_smooth_image_gradient(
                    model=self.model,
                    inpt=inpt,
                    err_fn=__err_fn,
                    grad_type="vanilla",
                    n_runs=2).detach().cpu())
                loss_grad_kl = torch.mean(loss_grad_kl, dim=1, keepdim=True)

                # Smoothed, normalized KL-gradient gates the rec error.
                pixel_scores = smooth_tensor(normalize(loss_grad_kl),
                                             kernel_size=8) * rec

            elif self.score_mode == "rec":

                rec = torch.pow((x_rec - inpt), 2).detach().cpu()
                rec = torch.mean(rec, dim=1, keepdim=True)
                pixel_scores = rec

            elif self.score_mode == "grad":

                def __err_fn(x):
                    # Full loss (KL * beta + rec * theta); gradient w.r.t.
                    # the input serves as the score.
                    x_r, z_d = self.model(x, sample=False)
                    kl_loss_ = self.kl_loss_fn(z_d)
                    rec_loss_ = self.rec_loss_fn(x_r, x)
                    loss_ = kl_loss_ * self.beta + rec_loss_ * self.theta
                    return torch.mean(loss_)

                loss_grad_kl = (get_smooth_image_gradient(
                    model=self.model,
                    inpt=inpt,
                    err_fn=__err_fn,
                    grad_type="vanilla",
                    n_runs=2).detach().cpu())
                loss_grad_kl = torch.mean(loss_grad_kl, dim=1, keepdim=True)

                pixel_scores = smooth_tensor(normalize(loss_grad_kl),
                                             kernel_size=8)

            self.tx.elog.show_image_grid(inpt,
                                         name="Input",
                                         image_args={"normalize": True},
                                         n_iter=i)
            self.tx.elog.show_image_grid(x_rec,
                                         name="Output",
                                         image_args={"normalize": True},
                                         n_iter=i)
            self.tx.elog.show_image_grid(pixel_scores,
                                         name="Scores",
                                         image_args={"normalize": True},
                                         n_iter=i)

            target_tensor[i * self.batch_size:(i + 1) *
                          self.batch_size] = pixel_scores.detach().cpu()[:,
                                                                         0, :]

        target_tensor = from_transforms(target_tensor[None])[0]

        return target_tensor.detach().numpy()

    @staticmethod
    def load_trained_model(model, tx, path):
        """Load the weights stored at `path` into `model` via the experiment logger."""
        tx.elog.load_model_static(model=model, model_file=path)

    @staticmethod
    def kl_loss_fn(z_post, sum_samples=True, correct=False):
        """KL divergence between the posterior `z_post` and a standard normal.

        Reduces over the non-batch dims with sum (`correct=True`) or mean,
        then averages over the batch unless `sum_samples` is False, in which
        case the per-sample values are returned.
        """
        standard_normal = dist.Normal(0, 1.0)
        per_dim = dist.kl_divergence(z_post, standard_normal)
        reduce = torch.sum if correct else torch.mean
        per_sample = reduce(per_dim, dim=(1, 2, 3))
        return torch.mean(per_sample) if sum_samples else per_sample

    @staticmethod
    def rec_loss_fn(recon_x, x, sum_samples=True, correct=False):
        """Reconstruction loss as negative log-likelihood of `x` given `recon_x`.

        With `correct=True` this is the summed Laplace(recon_x, 1) log-prob;
        otherwise a mean absolute-error proxy. Averaged over the batch unless
        `sum_samples` is False, in which case per-sample values are returned.
        """
        if correct:
            log_lik = torch.sum(
                dist.Laplace(recon_x, 1.0).log_prob(x), dim=(1, 2, 3))
        else:
            log_lik = torch.mean(-torch.abs(recon_x - x), dim=(1, 2, 3))
        return -torch.mean(log_lik) if sum_samples else -log_lik

    @staticmethod
    def get_inpt_grad(model, inpt, err_fn):
        """Absolute gradient of `err_fn(inpt)` w.r.t. the input (saliency map)."""
        model.zero_grad()

        x = inpt.detach()
        x.requires_grad = True
        err_fn(x).backward()
        saliency = x.grad.detach().abs()

        # Clear the gradients this pass accumulated on the model parameters.
        model.zero_grad()

        return saliency

    @staticmethod
    def geco_beta_update(beta,
                         error_ema,
                         goal,
                         step_size,
                         min_clamp=1e-10,
                         max_clamp=1e4,
                         speedup=None):
        """GECO multiplier update: scale `beta` by exp(step * (error_ema - goal)).

        When `speedup` is given and the constraint is still violated
        (positive), the step is multiplied by `speedup`. The result is
        clamped into [min_clamp, max_clamp] and returned as a plain float.
        """
        constraint = (error_ema - goal).detach()
        if speedup is not None and constraint > 0.0:
            scale = speedup * step_size
        else:
            scale = step_size
        beta = beta * torch.exp(scale * constraint)

        if min_clamp is not None:
            beta = np.max((beta.item(), min_clamp))
        if max_clamp is not None:
            beta = np.min((beta.item(), max_clamp))
        return beta

    @staticmethod
    def get_ema(new, old, alpha):
        """Exponential moving average; returns `new` unchanged when there is no history."""
        return new if old is None else (1.0 - alpha) * new + alpha * old

    def print(self, *args):
        """Print *args* to stdout and mirror them to the experiment logger."""
        print(*args)
        self.tx.print(*args)

    def log_result(self, val, key=None):
        """Echo and record a scalar result under `key` (no epoch counter attached)."""
        self.tx.print(key, val)
        self.tx.add_result_without_epoch(val, key)