Example 1
    def __init__(
        self,
        input_shape,
        lr=1e-4,
        n_epochs=20,
        z_dim=512,
        model_feature_map_sizes=(16, 64, 256, 1024),
        use_geco=False,
        beta=0.01,
        ce_factor=0.5,
        score_mode="combi",
        load_path=None,
        log_dir=None,
        logger="visdom",
        print_every_iter=100,
        data_dir=None,
    ):

        self.score_mode = score_mode
        self.ce_factor = ce_factor
        self.beta = beta
        self.print_every_iter = print_every_iter
        self.n_epochs = n_epochs
        self.batch_size = input_shape[0]
        self.z_dim = z_dim
        self.use_geco = use_geco
        self.input_shape = input_shape
        self.logger = logger
        self.data_dir = data_dir

        log_dict = {}
        if logger is not None:
            log_dict = {0: logger}
        self.tx = PytorchExperimentStub(
            name="cevae",
            base_dir=log_dir,
            config=fn_args_as_config,
            loggers=log_dict,
        )

        cuda_available = torch.cuda.is_available()
        self.device = torch.device("cuda" if cuda_available else "cpu")

        self.model = VAE(input_size=input_shape[1:],
                         z_dim=z_dim,
                         fmap_sizes=model_feature_map_sizes).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

        # Running exponential moving average of the VAE reconstruction loss and the
        # GECO weighting factor theta; both are only updated when use_geco=True.
        self.vae_loss_ema = 1
        self.theta = 1

        if load_path is not None:
            PytorchExperimentLogger.load_model_static(
                self.model, os.path.join(load_path, "vae_final.pth"))
            time.sleep(5)
Example 2
    def __init__(self, basic_kws, train_kws):
        self.__dict__.update(basic_kws)
        self.__dict__.update(train_kws)

        log_dict = {}
        if self.logger is not None:
            log_dict = {0: self.logger}
        self.tx = PytorchExperimentStub(
            name=self.model_type,
            base_dir=self.log_dir,
            config=None,
            loggers=log_dict,
        )
Example 3
    def __init__(
        self,
        input_shape,
        lr=1e-4,
        n_epochs=20,
        z_dim=512,
        model_feature_map_sizes=(16, 64, 256, 1024),
        load_path=None,
        log_dir=None,
        logger="visdom",
        print_every_iter=100,
        data_dir=None,
    ):

        self.print_every_iter = print_every_iter
        self.n_epochs = n_epochs
        self.batch_size = input_shape[0]
        self.z_dim = z_dim
        self.input_shape = input_shape
        self.logger = logger
        self.data_dir = data_dir

        log_dict = {}
        if logger is not None:
            log_dict = {0: logger}
        self.tx = PytorchExperimentStub(
            name="ae3d",
            base_dir=log_dir,
            config=fn_args_as_config,
            loggers=log_dict,
        )

        cuda_available = torch.cuda.is_available()
        self.device = torch.device("cuda" if cuda_available else "cpu")

        self.model = AE(
            input_size=input_shape[1:],
            z_dim=z_dim,
            fmap_sizes=model_feature_map_sizes,
            conv_op=torch.nn.Conv3d,
            tconv_op=torch.nn.ConvTranspose3d,
        ).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

        if load_path is not None:
            PytorchExperimentLogger.load_model_static(
                self.model, os.path.join(load_path, "ae_final.pth"))
            time.sleep(5)
Example 4
class Algorithm:
    def __init__(self, basic_kws, train_kws):
        self.__dict__.update(basic_kws)
        self.__dict__.update(train_kws)

        log_dict = {}
        if self.logger is not None:
            log_dict = {0: self.logger}
        self.tx = PytorchExperimentStub(
            name=self.model_type,
            base_dir=self.log_dir,
            config=None,
            loggers=log_dict,
        )

    def train(self):
        print('train')

        n_items = None
        train_loader = get_numpy2d_dataset(
            base_dir=self.train_data_dir,
            num_processes=self.batch_size,
            pin_memory=True,
            batch_size=self.batch_size,
            mode="train",
            target_size=self.target_size,
            drop_last=True,
            n_items=n_items,
            functions_dict=self.dataset_functions,
        )
        val_loader = get_numpy2d_dataset(
            base_dir=self.train_data_dir,
            num_processes=self.batch_size // 2,
            pin_memory=True,
            batch_size=self.batch_size,
            mode="val",
            target_size=self.target_size,
            drop_last=True,
            n_items=n_items,
            functions_dict=self.dataset_functions,
        )

        for epoch in range(self.n_epochs):
            self.model.train()
            train_loss = 0

            data_loader_ = tqdm(enumerate(train_loader))
            for batch_idx, data in data_loader_:
                loss = self.train_model(data)

                train_loss += loss.item()
                if batch_idx % self.print_every_iter == 0:
                    status_str = (
                        f"Train Epoch: {epoch} [{batch_idx}/{len(train_loader)} "
                        f" ({100.0 * batch_idx / len(train_loader):.0f}%)] Loss: "
                        f"{loss.item():.6f}")
                    data_loader_.set_description_str(status_str)

                    cnt = epoch * len(train_loader) + batch_idx
                    self.tx.add_result(loss.item(),
                                       name="Train-Loss",
                                       tag="Losses",
                                       counter=cnt)

                    # if self.logger is not None:
                    # self.tx.l[0].show_image_grid(inpt, name="Input", image_args={"normalize": True})
                    # self.tx.l[0].show_image_grid(inpt_rec, name="Reconstruction", image_args={"normalize": True})

            print(
                f"====> Epoch: {epoch} Average loss: {train_loss / len(train_loader):.6f}"
            )

            # validate
            self.model.eval()

            val_loss = 0
            data_loader_ = tqdm(enumerate(val_loader))
            data_loader_.set_description_str("Validating")
            for _, data in data_loader_:
                loss = self.eval_model(data)
                val_loss += loss.item()

            self.tx.add_result(val_loss / len(val_loader),
                               name="Val-Loss",
                               tag="Losses",
                               counter=(epoch + 1) * len(train_loader))
            print(
                f"====> Epoch: {epoch} Validation loss: {val_loss / len(val_loader):.6f}"
            )

        self.tx.save_model(self.model, "model")
        time.sleep(2)

    def train_model(self, data):
        input, label = self.get_input_label(data)
        loss = self.calculate_loss(self.model, input, label)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss

    def eval_model(self, data):
        with torch.no_grad():
            input, label = self.get_input_label(data)
            loss = self.calculate_loss(self.model, input, label)
        return loss

    def predict(self):
        print('predict')
        from .nifti_io import ni_save, ni_load
        from .function import save_images, init_validation_dir

        test_dir = self.test_dir
        _, pred_pixel_dir, pred_sample_dir = init_validation_dir(
            algo_name=self.name, dataset_dir=test_dir)
        test_dir = os.path.join(test_dir, 'data')

        test_dir_list = os.listdir(test_dir)
        length = len(test_dir_list)
        handle = tqdm(enumerate(test_dir_list))

        self.model.eval()
        for i, f_name in handle:
            ni_file_path = os.path.join(test_dir, f_name)
            ni_data, ni_aff = ni_load(ni_file_path)

            # pixel
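            # NOTE: REC is not defined in this snippet; it is presumably a
            # module-level flag that controls whether reconstructions are saved.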
            result = self.score_pixel_2d(ni_data, return_rec=REC)
            save_images(pred_pixel_dir,
                        f_name,
                        ni_aff,
                        score=result['score'],
                        ori=result['ori'],
                        rec=result['rec'])

            # sample
            sample_score = self.score_sample_2d(ni_data)
            with open(os.path.join(pred_sample_dir, f_name + ".txt"),
                      "w") as target_file:
                target_file.write(str(sample_score))

            handle.set_description_str(f'predict: {i+1}/{length}')

    def validate(self):
        print('validate')
        from .function import init_validation_dir
        from scripts.evalresults import eval_dir

        test_dir = self.test_dir
        score_dir, pred_pixel_dir, pred_sample_dir = init_validation_dir(
            algo_name=self.name, dataset_dir=test_dir)

        # pixel
        pred_pixel_dir = os.path.join(pred_pixel_dir, 'score')
        eval_dir(pred_dir=pred_pixel_dir,
                 label_dir=os.path.join(test_dir, 'label', 'pixel'),
                 mode='pixel',
                 save_file=os.path.join(score_dir, 'pixel'))

        # sample
        eval_dir(pred_dir=pred_sample_dir,
                 label_dir=os.path.join(test_dir, 'label', 'sample'),
                 mode='sample',
                 save_file=os.path.join(score_dir, 'sample'))

    def statistics(self):
        print('statistics')
        from .nifti_io import ni_load, ni_save
        import matplotlib.pyplot as plt
        import numpy as np

        test_dir = self.test_dir
        predict_dir = os.path.join(test_dir, 'eval', self.name, 'predict')
        assert os.path.exists(predict_dir), 'Run predict before computing statistics'

        statistics_dir = os.path.join(predict_dir, 'statistics')
        if not os.path.exists(statistics_dir):
            os.mkdir(statistics_dir)

        for file_name in tqdm(os.listdir(os.path.join(test_dir, 'data'))):
            prefix = file_name.split('.')[0]
            each_statistics_dir = os.path.join(statistics_dir, prefix)
            if not os.path.exists(each_statistics_dir):
                os.mkdir(each_statistics_dir)

            score, ni_aff = ni_load(
                os.path.join(predict_dir, 'pixel', 'score', file_name))
            flatten_score = score.flatten()

            # Histogram of all pixel scores
            plt.hist(flatten_score, bins=50, log=True)
            plt.savefig(
                os.path.join(each_statistics_dir, 'whole_score_histogram'))
            plt.cla()

            with open(
                    os.path.join(test_dir, 'label', 'sample',
                                 file_name + '.txt'), "r") as f:
                sample_label = f.readline()

            if sample_label == '1':
                # Histogram of scores within the abnormal region
                label, _ = ni_load(
                    os.path.join(test_dir, 'label', 'pixel', file_name))
                abnormal_area_score = score[label == 1]
                plt.hist(abnormal_area_score, bins=50, log=True)
                plt.savefig(
                    os.path.join(each_statistics_dir,
                                 'abnormal_area_score_histogram'))
                plt.cla()

                abnormal_number = len(abnormal_area_score)
                print(f'abnormal_number: {abnormal_number}')
            elif sample_label == '0':
                abnormal_number = 10000
            else:
                raise Exception(f'Invalid sample_label: {sample_label}')

            # Binarize the highest-scoring region
            img = score.copy()
            ordered_flatten_score = np.sort(flatten_score)[::-1]
            threshold = ordered_flatten_score[abnormal_number]
            img[img > threshold] = 1
            img[img <= threshold] = 0
            ni_save(os.path.join(each_statistics_dir, 'highest_score'), img,
                    ni_aff)

    def score_pixel_2d(self,
                       np_array,
                       return_score=True,
                       return_ori=False,
                       return_rec=False):
        score = None
        ori = None
        rec = None
        np_array = self.transpose(np_array)

        ori_shape = np_array.shape
        to_transforms = torch.nn.Upsample((self.target_size, self.target_size),
                                          mode="bilinear")
        from_transforms = torch.nn.Upsample((ori_shape[1], ori_shape[2]),
                                            mode="bilinear")

        data_tensor = torch.from_numpy(np_array).float()
        data_tensor = to_transforms(data_tensor[None])[0]

        score_tensor, rec_tensor = self.get_pixel_score(
            self.model, data_tensor)

        if return_score:
            score_tensor = from_transforms(score_tensor[None])[0]
            score = score_tensor.detach().numpy()
            score = self.revert_transpose(score)
        if return_ori:
            data_tensor = from_transforms(data_tensor[None])[0]
            ori = data_tensor.detach().numpy()
            ori = self.revert_transpose(ori)
        if return_rec:
            rec_tensor = from_transforms(rec_tensor[None])[0]
            rec = rec_tensor.detach().numpy()
            rec = self.revert_transpose(rec)

        return {'score': score, 'ori': ori, 'rec': rec}

    def score_sample_2d(self, np_array):
        to_transforms = torch.nn.Upsample((self.target_size, self.target_size),
                                          mode="bilinear")

        data_tensor = torch.from_numpy(np_array).float()
        data_tensor = to_transforms(data_tensor[None])[0]

        sample_score = self.get_sample_score(self.model, data_tensor)
        return sample_score
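
The Algorithm class above is an abstract trainer: the model, optimizer and the hooks get_input_label, calculate_loss, get_pixel_score, get_sample_score, transpose and revert_transpose are expected to come from a subclass and from the basic_kws/train_kws dictionaries. A minimal sketch of such a subclass follows; the AE model and the attributes input_shape, z_dim and lr are illustrative assumptions, not part of the original code.

import torch
import torch.optim as optim


class AEAlgorithm(Algorithm):
    """Hypothetical subclass that plugs a plain autoencoder into Algorithm."""

    def __init__(self, basic_kws, train_kws):
        super().__init__(basic_kws, train_kws)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # AE, self.input_shape, self.z_dim and self.lr are assumed to be provided
        # by the surrounding project / keyword dictionaries.
        self.model = AE(input_size=self.input_shape[1:],
                        z_dim=self.z_dim).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

    def get_input_label(self, data):
        # Unsupervised reconstruction: the input serves as its own target.
        inpt = data.to(self.device)
        return inpt, inpt

    def calculate_loss(self, model, inpt, label):
        # Mean squared reconstruction error, matching train_model/eval_model above.
        rec = model(inpt)
        return torch.mean(torch.pow(rec - label, 2))
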
Example 5
class AE2D:
    @monkey_patch_fn_args_as_config
    def __init__(
        self,
        input_shape,
        lr=1e-4,
        n_epochs=20,
        z_dim=512,
        model_feature_map_sizes=(16, 64, 256, 1024),
        load_path=None,
        log_dir=None,
        logger="visdom",
        print_every_iter=100,
        data_dir=None,
    ):

        self.print_every_iter = print_every_iter
        self.n_epochs = n_epochs
        self.batch_size = input_shape[0]
        self.z_dim = z_dim
        self.input_shape = input_shape
        self.logger = logger
        self.data_dir = data_dir

        log_dict = {}
        if logger is not None:
            log_dict = {0: logger}
        self.tx = PytorchExperimentStub(
            name="ae2d",
            base_dir=log_dir,
            config=fn_args_as_config,
            loggers=log_dict,
        )

        cuda_available = torch.cuda.is_available()
        self.device = torch.device("cuda" if cuda_available else "cpu")

        self.model = AE(input_size=input_shape[1:],
                        z_dim=z_dim,
                        fmap_sizes=model_feature_map_sizes).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

        if load_path is not None:
            PytorchExperimentLogger.load_model_static(
                self.model, os.path.join(load_path, "ae_final.pth"))
            time.sleep(5)

    def train(self):

        train_loader = get_numpy2d_dataset(
            base_dir=self.data_dir,
            num_processes=16,
            pin_memory=True,
            batch_size=self.batch_size,
            mode="train",
            target_size=self.input_shape[2],
        )
        val_loader = get_numpy2d_dataset(
            base_dir=self.data_dir,
            num_processes=8,
            pin_memory=True,
            batch_size=self.batch_size,
            mode="val",
            target_size=self.input_shape[2],
        )

        for epoch in range(self.n_epochs):

            ### Train
            self.model.train()

            train_loss = 0
            print("\nStart epoch ", epoch)
            data_loader_ = tqdm(enumerate(train_loader))
            for batch_idx, data in data_loader_:
                inpt = data.to(self.device)

                self.optimizer.zero_grad()
                inpt_rec = self.model(inpt)

                loss = torch.mean(torch.pow(inpt - inpt_rec, 2))
                loss.backward()
                self.optimizer.step()

                train_loss += loss.item()
                if batch_idx % self.print_every_iter == 0:
                    status_str = (
                        f"Train Epoch: {epoch} [{batch_idx}/{len(train_loader)} "
                        f" ({100.0 * batch_idx / len(train_loader):.0f}%)] Loss: "
                        f"{loss.item() / len(inpt):.6f}")
                    data_loader_.set_description_str(status_str)

                    cnt = epoch * len(train_loader) + batch_idx
                    self.tx.add_result(loss.item(),
                                       name="Train-Loss",
                                       tag="Losses",
                                       counter=cnt)

                    if self.logger is not None:
                        self.tx.l[0].show_image_grid(
                            inpt, name="Input", image_args={"normalize": True})
                        self.tx.l[0].show_image_grid(
                            inpt_rec,
                            name="Reconstruction",
                            image_args={"normalize": True})

            print(
                f"====> Epoch: {epoch} Average loss: {train_loss / len(train_loader):.4f}"
            )

            ### Validate
            self.model.eval()

            val_loss = 0
            with torch.no_grad():
                data_loader_ = tqdm(enumerate(val_loader))
                data_loader_.set_description_str("Validating")
                for i, data in data_loader_:
                    inpt = data.to(self.device)
                    inpt_rec = self.model(inpt)

                    loss = torch.mean(torch.pow(inpt - inpt_rec, 2))
                    val_loss += loss.item()

                self.tx.add_result(val_loss / len(val_loader),
                                   name="Val-Loss",
                                   tag="Losses",
                                   counter=(epoch + 1) * len(train_loader))

            print(
                f"====> Epoch: {epoch} Validation loss: {val_loss / len(val_loader):.4f}"
            )

        self.tx.save_model(self.model, "ae_final")

        time.sleep(10)

    def score_sample(self, np_array):
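        # Sample-level score: reconstruct every axial slice with the AE, compute the
        # per-slice mean squared error and return the maximum over all slices.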

        orig_shape = np_array.shape
        to_transforms = torch.nn.Upsample(
            (self.input_shape[2], self.input_shape[3]), mode="bilinear")

        data_tensor = torch.from_numpy(np_array).float()
        data_tensor = to_transforms(data_tensor[None])[0]
        slice_scores = []

        for i in range(ceil(orig_shape[0] / self.batch_size)):
            batch = data_tensor[i * self.batch_size:(i + 1) *
                                self.batch_size].unsqueeze(1)
            batch = batch.to(self.device)

            with torch.no_grad():
                batch_rec = self.model(batch)
                loss = torch.mean(torch.pow(batch - batch_rec, 2),
                                  dim=(1, 2, 3))

            slice_scores += loss.cpu().tolist()

        return np.max(slice_scores)

    def score_pixels(self, np_array):
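        # Pixel-level score: squared reconstruction error per pixel, computed at the
        # model's input resolution and resized back to the original slice shape.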

        orig_shape = np_array.shape
        to_transforms = torch.nn.Upsample(
            (self.input_shape[2], self.input_shape[3]), mode="bilinear")
        from_transforms = torch.nn.Upsample((orig_shape[1], orig_shape[2]),
                                            mode="bilinear")

        data_tensor = torch.from_numpy(np_array).float()
        data_tensor = to_transforms(data_tensor[None])[0]
        target_tensor = torch.zeros_like(data_tensor)

        for i in range(ceil(orig_shape[0] / self.batch_size)):
            batch = data_tensor[i * self.batch_size:(i + 1) *
                                self.batch_size].unsqueeze(1)
            batch = batch.to(self.device)

            batch_rec = self.model(batch)

            loss = torch.pow(batch - batch_rec, 2)[:, 0, :]
            target_tensor[i * self.batch_size:(i + 1) *
                          self.batch_size] = loss.cpu()

        target_tensor = from_transforms(target_tensor[None])[0]

        return target_tensor.detach().numpy()

    def print(self, *args):
        print(*args)
        self.tx.print(*args)
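
A brief usage sketch for the AE2D class above; the shapes, directory paths and the test volume are illustrative assumptions (logger=None is used so that no visdom server is required):

import numpy as np

# input_shape is (batch_size, n_channels, height, width)
ae = AE2D(input_shape=(64, 1, 128, 128), lr=1e-4, n_epochs=20,
          log_dir="./logs", data_dir="./preprocessed_2d_slices", logger=None)
ae.train()

volume = np.random.rand(32, 256, 256).astype(np.float32)  # (slices, H, W)
sample_score = ae.score_sample(volume)  # scalar: max per-slice reconstruction error
pixel_scores = ae.score_pixels(volume)  # per-pixel squared-error map, original size
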
Example 6
    def __init__(
        self,
        input_shape,
        lr=1e-4,
        critic_iters=1,
        gen_iters=5,
        n_epochs=10,
        gp_lambda=10,
        z_dim=512,
        print_every_iter=20,
        plot_every_epoch=1,
        log_dir=None,
        load_path=None,
        logger="visdom",
        data_dir=None,
        use_encoder=True,
        enocoder_feature_weight=1e-4,
        encoder_discr_weight=0.0,
    ):

        self.plot_every_epoch = plot_every_epoch
        self.print_every_iter = print_every_iter
        self.gp_lambda = gp_lambda
        self.n_epochs = n_epochs
        self.gen_iters = gen_iters
        self.critic_iters = critic_iters
        self.size = input_shape[2]
        self.batch_size = input_shape[0]
        self.input_shape = input_shape
        self.z_dim = z_dim
        self.logger = logger
        self.data_dir = data_dir
        self.use_encoder = use_encoder
        self.enocoder_feature_weight = enocoder_feature_weight
        self.encoder_discr_weight = encoder_discr_weight

        log_dict = {}
        if logger is not None:
            log_dict = {0: logger}
        self.tx = PytorchExperimentStub(
            name="fanogan",
            base_dir=log_dir,
            config=fn_args_as_config,
            loggers=log_dict,
        )

        cuda_available = torch.cuda.is_available()
        self.device = torch.device("cuda" if cuda_available else "cpu")

        self.n_image_channels = input_shape[1]

        self.gen = IWGenerator(self.size,
                               z_dim=z_dim,
                               n_image_channels=self.n_image_channels)
        self.dis = IWDiscriminator(self.size,
                                   n_image_channels=self.n_image_channels)

        self.gen.apply(weights_init)
        self.dis.apply(weights_init)

        self.optimizer_G = torch.optim.Adam(self.gen.parameters(),
                                            lr=lr,
                                            betas=(0.5, 0.999))
        self.optimizer_D = torch.optim.Adam(self.dis.parameters(),
                                            lr=lr,
                                            betas=(0.5, 0.999))

        self.gen = self.gen.to(self.device)
        self.dis = self.dis.to(self.device)

        if self.use_encoder:
            self.enc = IWEncoder(self.size,
                                 z_dim=z_dim,
                                 n_image_channels=self.n_image_channels)
            self.enc.apply(weights_init)
            self.enc = self.enc.to(self.device)
            self.optimizer_E = torch.optim.Adam(self.enc.parameters(),
                                                lr=lr,
                                                betas=(0.5, 0.999))

        self.z = torch.randn(self.batch_size, z_dim).to(self.device)

        if load_path is not None:
            PytorchExperimentLogger.load_model_static(
                self.dis, os.path.join(load_path, "dis_final.pth"))
            PytorchExperimentLogger.load_model_static(
                self.gen, os.path.join(load_path, "gen_final.pth"))
            if self.use_encoder:
                try:
                    pass
                    # PytorchExperimentLogger.load_model_static(self.enc, os.path.join(load_path, "enc_final.pth"))
                except Exception:
                    warnings.warn("Could not find an Encoder in the directory")
            time.sleep(5)
Example 7
class fAnoGAN:
    @monkey_patch_fn_args_as_config
    def __init__(
        self,
        input_shape,
        lr=1e-4,
        critic_iters=1,
        gen_iters=5,
        n_epochs=10,
        gp_lambda=10,
        z_dim=512,
        print_every_iter=20,
        plot_every_epoch=1,
        log_dir=None,
        load_path=None,
        logger="visdom",
        data_dir=None,
        use_encoder=True,
        enocoder_feature_weight=1e-4,
        encoder_discr_weight=0.0,
    ):

        self.plot_every_epoch = plot_every_epoch
        self.print_every_iter = print_every_iter
        self.gp_lambda = gp_lambda
        self.n_epochs = n_epochs
        self.gen_iters = gen_iters
        self.critic_iters = critic_iters
        self.size = input_shape[2]
        self.batch_size = input_shape[0]
        self.input_shape = input_shape
        self.z_dim = z_dim
        self.logger = logger
        self.data_dir = data_dir
        self.use_encoder = use_encoder
        self.enocoder_feature_weight = enocoder_feature_weight
        self.encoder_discr_weight = encoder_discr_weight

        log_dict = {}
        if logger is not None:
            log_dict = {0: logger}
        self.tx = PytorchExperimentStub(
            name="fanogan",
            base_dir=log_dir,
            config=fn_args_as_config,
            loggers=log_dict,
        )

        cuda_available = torch.cuda.is_available()
        self.device = torch.device("cuda" if cuda_available else "cpu")

        self.n_image_channels = input_shape[1]

        self.gen = IWGenerator(self.size,
                               z_dim=z_dim,
                               n_image_channels=self.n_image_channels)
        self.dis = IWDiscriminator(self.size,
                                   n_image_channels=self.n_image_channels)

        self.gen.apply(weights_init)
        self.dis.apply(weights_init)

        self.optimizer_G = torch.optim.Adam(self.gen.parameters(),
                                            lr=lr,
                                            betas=(0.5, 0.999))
        self.optimizer_D = torch.optim.Adam(self.dis.parameters(),
                                            lr=lr,
                                            betas=(0.5, 0.999))

        self.gen = self.gen.to(self.device)
        self.dis = self.dis.to(self.device)

        if self.use_encoder:
            self.enc = IWEncoder(self.size,
                                 z_dim=z_dim,
                                 n_image_channels=self.n_image_channels)
            self.enc.apply(weights_init)
            self.enc = self.enc.to(self.device)
            self.optimizer_E = torch.optim.Adam(self.enc.parameters(),
                                                lr=lr,
                                                betas=(0.5, 0.999))

        self.z = torch.randn(self.batch_size, z_dim).to(self.device)

        if load_path is not None:
            PytorchExperimentLogger.load_model_static(
                self.dis, os.path.join(load_path, "dis_final.pth"))
            PytorchExperimentLogger.load_model_static(
                self.gen, os.path.join(load_path, "gen_final.pth"))
            if self.use_encoder:
                try:
                    pass
                    # PytorchExperimentLogger.load_model_static(self.enc, os.path.join(load_path, "enc_final.pth"))
                except Exception:
                    warnings.warn("Could not find an Encoder in the directory")
            time.sleep(5)

    def train(self):

        train_loader = get_numpy2d_dataset(
            base_dir=self.data_dir,
            num_processes=16,
            pin_memory=False,
            batch_size=self.batch_size,
            mode="train",
            target_size=self.size,
            slice_offset=10,
        )

        print("Training GAN...")
        for epoch in range(self.n_epochs):
            # for epoch in range(0):

            data_loader_ = tqdm(enumerate(train_loader))
            for i, batch in data_loader_:
                batch = batch * 2 - 1 + torch.randn_like(batch) * 0.01

                real_imgs = batch.to(self.device)

                # ---------------------
                #  Train Discriminator
                # ---------------------
                # disc_cost = []
                # w_dist = []
                if i % self.critic_iters == 0:
                    self.optimizer_G.zero_grad()
                    self.optimizer_D.zero_grad()

                    batch_size_curr = real_imgs.shape[0]

                    self.z.normal_()

                    fake_imgs = self.gen(self.z[:batch_size_curr])

                    real_validity = self.dis(real_imgs)
                    fake_validity = self.dis(fake_imgs)

                    gradient_penalty = self.calc_gradient_penalty(
                        self.dis,
                        real_imgs,
                        fake_imgs,
                        batch_size_curr,
                        self.size,
                        self.device,
                        self.gp_lambda,
                        n_image_channels=self.n_image_channels,
                    )

                    d_loss = -torch.mean(real_validity) + torch.mean(
                        fake_validity) + self.gp_lambda * gradient_penalty
                    d_loss.backward()
                    self.optimizer_D.step()

                    # disc_cost.append(d_loss.item())
                    w_dist = (-torch.mean(real_validity) +
                              torch.mean(fake_validity)).item()

                # -----------------
                #  Train Generator
                # -----------------
                # gen_cost = []
                if i % self.gen_iters == 0:
                    self.optimizer_G.zero_grad()
                    self.optimizer_D.zero_grad()

                    batch_size_curr = self.batch_size

                    fake_imgs = self.gen(self.z)

                    fake_validity = self.dis(fake_imgs)
                    g_loss = -torch.mean(fake_validity)

                    g_loss.backward()
                    self.optimizer_G.step()

                    # gen_cost.append(g_loss.item())

                if i % self.print_every_iter == 0:
                    status_str = (
                        f"Train Epoch: {epoch} [{i}/{len(train_loader)} "
                        f" ({100.0 * i / len(train_loader):.0f}%)] Dis: "
                        f"{d_loss.item() / batch_size_curr:.6f} vs Gen: "
                        f"{g_loss.item() / batch_size_curr:.6f} (W-Dist: {w_dist / batch_size_curr:.6f})"
                    )
                    data_loader_.set_description_str(status_str)
                    # print(f"[Epoch {epoch}/{self.n_epochs}] [Batch {i}/{len(train_loader)}]")

                    # print(d_loss.item(), g_loss.item())
                    cnt = epoch * len(train_loader) + i

                    self.tx.add_result(d_loss.item(),
                                       name="trainDisCost",
                                       tag="DisVsGen",
                                       counter=cnt)
                    self.tx.add_result(g_loss.item(),
                                       name="trainGenCost",
                                       tag="DisVsGen",
                                       counter=cnt)
                    self.tx.add_result(w_dist,
                                       "wasserstein_distance",
                                       counter=cnt)

                    self.tx.l[0].show_image_grid(
                        fake_imgs.reshape(batch_size_curr,
                                          self.n_image_channels, self.size,
                                          self.size),
                        "GeneratedImages",
                        image_args={"normalize": True},
                    )

        self.tx.save_model(self.dis, "dis_final")
        self.tx.save_model(self.gen, "gen_final")

        self.gen.train(True)
        self.dis.train(True)

        if not self.use_encoder:
            time.sleep(10)
            return

        weight_features = self.enocoder_feature_weight
        weight_disc = self.encoder_discr_weight
        print("Training Encoder...")
        for epoch in range(self.n_epochs // 2):
            data_loader_ = tqdm(enumerate(train_loader))
            for i, batch in data_loader_:
                batch = batch * 2 - 1 + torch.randn_like(batch) * 0.01
                real_img = batch.to(self.device)
                batch_size_curr = real_img.shape[0]

                self.optimizer_G.zero_grad()
                self.optimizer_D.zero_grad()
                self.optimizer_E.zero_grad()

                z = self.enc(real_img)
                recon_img = self.gen(z)

                _, img_feats = self.dis.forward_last_feature(real_img)
                disc_loss, recon_feats = self.dis.forward_last_feature(
                    recon_img)

                recon_img = recon_img.reshape(batch_size_curr,
                                              self.n_image_channels, self.size,
                                              self.size)
                loss_img = self.mse(real_img, recon_img)
                loss_feat = self.mse(img_feats, recon_feats) * weight_features
                disc_loss = -torch.mean(disc_loss) * weight_disc

                loss = loss_img + loss_feat + disc_loss

                loss.backward()
                self.optimizer_E.step()

                if i % self.print_every_iter == 0:
                    status_str = (
                        f"[Epoch {epoch}/{self.n_epochs // 2}] [Batch {i}/{len(train_loader)}] Loss:{loss:.06f}"
                    )
                    data_loader_.set_description_str(status_str)

                    cnt = epoch * len(train_loader) + i
                    self.tx.add_result(loss.item(),
                                       name="EncoderLoss",
                                       counter=cnt)

                    self.tx.l[0].show_image_grid(
                        real_img.reshape(batch_size_curr,
                                         self.n_image_channels, self.size,
                                         self.size),
                        "RealImages",
                        image_args={"normalize": True},
                    )
                    self.tx.l[0].show_image_grid(
                        recon_img.reshape(batch_size_curr,
                                          self.n_image_channels, self.size,
                                          self.size),
                        "ReconImages",
                        image_args={"normalize": True},
                    )

        self.tx.save_model(self.enc, "enc_final")
        self.enc.train(False)

        time.sleep(10)

    def score_sample(self, np_array):

        orig_shape = np_array.shape
        to_transforms = torch.nn.Upsample(
            (self.input_shape[2], self.input_shape[3]), mode="bilinear")

        data_tensor = torch.from_numpy(np_array).float()
        data_tensor = to_transforms(data_tensor[None])[0]
        slice_scores = []

        for i in range(ceil(orig_shape[0] / self.batch_size)):
            batch = data_tensor[i * self.batch_size:(i + 1) *
                                self.batch_size].unsqueeze(1)
            batch = batch * 2 - 1
            real_imgs = batch.to(self.device)
            batch_size_curr = real_imgs.shape[0]

            if self.use_encoder:
                z = self.enc(real_imgs)
            else:
                z = self.backprop_to_nearest_z(real_imgs)

            pseudo_img_recon = self.gen(z)

            pseudo_img_recon = pseudo_img_recon.reshape(
                batch_size_curr, self.n_image_channels, self.size, self.size)
            img_diff = torch.mean(torch.abs(pseudo_img_recon - real_imgs),
                                  dim=1,
                                  keepdim=True)

            loss = torch.sum(img_diff, dim=(1, 2, 3)).detach()

            slice_scores += loss.cpu().tolist()

        return np.max(slice_scores)

    def score_pixels(self, np_array):

        orig_shape = np_array.shape
        to_transforms = torch.nn.Upsample(
            (self.input_shape[2], self.input_shape[3]), mode="bilinear")
        from_transforms = torch.nn.Upsample((orig_shape[1], orig_shape[2]),
                                            mode="bilinear")

        data_tensor = torch.from_numpy(np_array).float()
        data_tensor = to_transforms(data_tensor[None])[0]
        target_tensor = torch.zeros_like(data_tensor)

        for i in range(ceil(orig_shape[0] / self.batch_size)):
            batch = data_tensor[i * self.batch_size:(i + 1) *
                                self.batch_size].unsqueeze(1)
            batch = batch * 2 - 1
            real_imgs = batch.to(self.device)
            batch_size_curr = real_imgs.shape[0]

            if self.use_encoder:
                z = self.enc(real_imgs)
            else:
                z = self.backprop_to_nearest_z(real_imgs)

            pseudo_img_recon = self.gen(z)

            pseudo_img_recon = pseudo_img_recon.reshape(
                batch_size_curr, self.n_image_channels, self.size, self.size)
            img_diff = torch.mean(torch.abs(pseudo_img_recon - real_imgs),
                                  dim=1,
                                  keepdim=True)

            loss = img_diff[:, 0, :]
            target_tensor[i * self.batch_size:(i + 1) *
                          self.batch_size] = loss.cpu()

        target_tensor = from_transforms(target_tensor[None])[0]

        return target_tensor.detach().numpy()

    def backprop_to_nearest_z(self, real_imgs):
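        # No encoder available: search for a latent code z whose generated image
        # matches the input by optimizing z directly with Adam for 200 steps,
        # using an L1 image term minus a small discriminator term.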

        batch_size_curr = real_imgs.shape[0]

        z = torch.randn(batch_size_curr, self.z_dim).to(self.device).normal_()
        z.requires_grad = True
        # optimizer_z = torch.optim.LBFGS([z], lr=0.02)
        optimizer_z = torch.optim.Adam([z], lr=0.002)
        # optimizer_z = torch.optim.RMSprop([z], lr=0.05)

        for i in range(200):

            def closure():
                self.gen.zero_grad()
                optimizer_z.zero_grad()

                pseudo_img_recon = self.gen(z)

                _, img_feats = self.dis.forward_last_feature(real_imgs)
                disc_loss, recon_feats = self.dis.forward_last_feature(
                    pseudo_img_recon)

                pseudo_img_recon = pseudo_img_recon.reshape(
                    batch_size_curr, self.n_image_channels, self.size,
                    self.size)
                disc_loss = torch.mean(disc_loss)

                imgs_diff = torch.mean(torch.abs(pseudo_img_recon - real_imgs))
                feats_diff = torch.mean(torch.abs(img_feats - recon_feats))
                loss = imgs_diff - disc_loss * 0.001  # + feats_diff

                loss.backward()

                return loss

            optimizer_z.step(closure)

        return z.detach()

    def score(self, batch):
        real_imgs = batch.to(self.device).float()

        z = self.enc(real_imgs)

        batch_size_curr = real_imgs.shape[0]

        # z = torch.randn(batch_size_curr, self.z_dim).to(self.device).normal_()
        # z.requires_grad = True
        # # optimizer_z = torch.optim.LBFGS([z], lr=0.02)
        # optimizer_z = torch.optim.Adam([z], lr=0.002)
        # # optimizer_z = torch.optim.RMSprop([z], lr=0.05)
        #
        # cn = dict(tr=0)
        #
        # self.tx.vlog.show_image_grid(real_imgs, "RealImages",
        #                              image_args={"normalize": True})
        #
        # for i in range(200):
        #     def closure():
        #         self.gen.zero_grad()
        #         optimizer_z.zero_grad()
        #
        #         pseudo_img_recon = self.gen(z)
        #
        #         _, img_feats = self.dis.forward_last_feature(real_imgs)
        #         disc_loss, recon_feats = self.dis.forward_last_feature(pseudo_img_recon)
        #
        #         pseudo_img_recon = pseudo_img_recon.reshape(batch_size_curr, self.n_image_channels, self.size, self.size)
        #         disc_loss = torch.mean(disc_loss)
        #
        #         imgs_diff = torch.mean(torch.abs(pseudo_img_recon - real_imgs))
        #         feats_diff = torch.mean(torch.abs(img_feats - recon_feats))
        #         loss = imgs_diff - disc_loss * 0.001  # + feats_diff
        #
        #         loss.backward()
        #         # optimizer_z.step()
        #         #
        #         # if cn['tr'] % 20 == 0:
        #         # pseudo_img_recon = pseudo_img_recon.clamp(-1.5, 1.5)
        #         self.tx.vlog.show_image_grid(pseudo_img_recon, "PseudoImages",
        #                                      image_args={"normalize": True})
        #         self.tx.vlog.show_image_grid(torch.mean(torch.abs(pseudo_img_recon - real_imgs), dim=1, keepdim=True),
        #                                      "DiffImages", image_args={"normalize": True})
        #         #
        #         # tx.add_result(disc_loss.item() * 0.001, name="DiscLoss", tag="AnoIter")
        #         # tx.add_result(imgs_diff.item(), name="ImgsDiff", tag="AnoIter")
        #         # tx.add_result(torch.mean(torch.pow(z, 2)).item(), name="ZDevi", tag="AnoIter")
        #         #
        #         # cn['tr'] += 1
        #
        #         return loss
        #
        #     optimizer_z.step(closure)
        #
        #     # time.sleep(1)
        #
        #     print(i)
        #
        pseudo_img_recon = self.gen(z)

        pseudo_img_recon = pseudo_img_recon.reshape(batch_size_curr,
                                                    self.n_image_channels,
                                                    self.size, self.size)
        img_diff = torch.mean(torch.abs(pseudo_img_recon - real_imgs),
                              dim=1,
                              keepdim=True)

        img_scores = torch.sum(img_diff, dim=(1, 2, 3)).detach().tolist()
        pixel_scores = img_diff.flatten().detach().tolist()

        self.tx.vlog.show_image_grid(pseudo_img_recon,
                                     "PseudoImages",
                                     image_args={"normalize": True})
        self.tx.vlog.show_image_grid(
            torch.mean(torch.abs(pseudo_img_recon - real_imgs),
                       dim=1,
                       keepdim=True),
            "DiffImages",
            image_args={"normalize": True},
        )

        # print("One Down")

        return img_scores, pixel_scores

    @staticmethod
    def mse(x, y):
        return torch.mean(torch.pow(x - y, 2))

    @staticmethod
    def calc_gradient_penalty(netD,
                              real_data,
                              fake_data,
                              batch_size,
                              dim,
                              device,
                              gp_lambda,
                              n_image_channels=3):
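        # WGAN-GP gradient penalty: evaluate the critic on random interpolations
        # between real and fake samples and penalize deviations of the gradient
        # norm from 1, scaled by gp_lambda (note that the training loop above also
        # multiplies the returned value by gp_lambda).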
        alpha = torch.rand(batch_size, 1)
        alpha = alpha.expand(batch_size, int(real_data.nelement() /
                                             batch_size)).contiguous()
        alpha = alpha.view(batch_size, n_image_channels, dim, dim)
        alpha = alpha.to(device)

        fake_data = fake_data.view(batch_size, n_image_channels, dim, dim)
        interpolates = alpha * real_data.detach() + (
            (1 - alpha) * fake_data.detach())

        interpolates = interpolates.to(device)
        interpolates.requires_grad_(True)

        disc_interpolates = netD(interpolates)

        gradients = torch.autograd.grad(
            outputs=disc_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones(disc_interpolates.size()).to(device),
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]

        gradients = gradients.view(gradients.size(0), -1)
        gradient_penalty = (
            (gradients.norm(2, dim=1) - 1)**2).mean() * gp_lambda
        return gradient_penalty

    def print(self, *args):
        print(*args)
        self.tx.print(*args)

    def log_result(self, val, key=None):
        self.tx.print(key, val)
        self.tx.add_result_without_epoch(val, key)
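
A brief usage sketch for the fAnoGAN class above; the shapes, directory paths and the test volume are illustrative assumptions. Since train() calls show_image_grid unconditionally, the default visdom logger assumes a running visdom server.

import numpy as np

# input_shape is (batch_size, n_image_channels, height, width)
gan = fAnoGAN(input_shape=(64, 1, 64, 64), lr=1e-4, n_epochs=10,
              log_dir="./logs", data_dir="./preprocessed_2d_slices",
              use_encoder=True)
gan.train()  # WGAN-GP training first, then encoder training

volume = np.random.rand(32, 256, 256).astype(np.float32)  # (slices, H, W)
sample_score = gan.score_sample(volume)  # max over per-slice anomaly scores
pixel_map = gan.score_pixels(volume)     # per-pixel |x - G(E(x))| map
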
Example 8
class Algorithm:
    def __init__(self, basic_kws, train_kws):
        self.__dict__.update(basic_kws)
        self.__dict__.update(train_kws)

        log_dict = {}
        if self.logger is not None:
            log_dict = {0: self.logger}
        self.tx = PytorchExperimentStub(
            name=self.name,
            base_dir=self.log_dir,
            config=None,
            loggers=log_dict,
        )

    def train(self):
        n_items = None
        train_loader = get_numpy2d_dataset(
            base_dir=self.train_data_dir,
            num_processes=self.batch_size,
            pin_memory=True,
            batch_size=self.batch_size,
            mode="all",
            # target_size=self.target_size,
            drop_last=False,
            n_items=n_items,
            functions_dict=self.dataset_functions,
        )
        # val_loader = get_numpy2d_dataset(
        #     base_dir=self.test_data_dir,
        #     num_processes=self.batch_size // 2,
        #     pin_memory=True,
        #     batch_size=self.batch_size,
        #     mode="all",
        #     # target_size=self.target_size,
        #     drop_last=False,
        #     n_items=n_items,
        #     functions_dict=self.dataset_functions,
        # )
        train_loader = DataPreFetcher(train_loader)
        # val_loader = DataPreFetcher(val_loader)

        for epoch in range(self.n_epochs):
            print('train')
            self.model.train()
            train_loss = 0

            data_loader_ = tqdm(enumerate(train_loader))
            # data_loader_ = enumerate(train_loader)
            for batch_idx, data in data_loader_:
                # data = data.cuda()
                loss, input, out = self.train_model(data)

                train_loss += loss.item()
                if batch_idx % self.print_every_iter == 0:
                    status_str = (
                        f"Train Epoch: {epoch + 1} [{batch_idx}/{len(train_loader)} "
                        f" ({100.0 * batch_idx / len(train_loader):.0f}%)] Loss: "
                        f"{loss.item():.6f}")
                    data_loader_.set_description_str(status_str)

                    cnt = epoch * len(train_loader) + batch_idx

                    # TensorBoard logging
                    self.tx.add_result(loss.item(),
                                       name="Train-Loss",
                                       tag="Losses",
                                       counter=cnt)

                    # if self.logger is not None:
                    #     self.tx.l[0].show_image_grid(input, name="Input", image_args={"normalize": True})
                    #     self.tx.l[0].show_image_grid(out, name="Reconstruction", image_args={"normalize": True})

            print(
                f"====> Epoch: {epoch + 1} Average loss: {train_loss / len(train_loader):.6f}"
            )

            # print('validate')
            # self.model.eval()
            # val_loss = 0

            # data_loader_ = tqdm(enumerate(val_loader))
            # data_loader_.set_description_str("Validating")
            # for _, data in data_loader_:
            #     loss = self.eval_model(data)
            #     val_loss += loss.item()

            # self.tx.add_result(
            #     val_loss / len(val_loader), name="Val-Loss", tag="Losses", counter=(epoch + 1) * len(train_loader))
            # print(f"====> Epoch: {epoch + 1} Validation loss: {val_loss / len(val_loader):.6f}")

            # if (epoch + 1) % self.save_per_epoch == 0:
            if (epoch + 1) > self.n_epochs - 5:
                self.save_model(epoch + 1)

        time.sleep(2)

    def save_model(self, new_training_epoch):
        save_epoch = self.total_epoch + new_training_epoch
        path = os.path.join(self.tx.elog.work_dir, 'checkpoint',
                            f'{save_epoch}')
        save_dict = {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'total_epoch': save_epoch,
            # 'score': score,
            # 'best_score': self.best_score,
        }
        torch.save(save_dict, path)

        # if score > best_score:
        #     self.best_score = score
        #     save_dict['best_score'] = score
        #     best_path = os.path.join(self.tx.elog.work_dir, 'checkpoint', 'best')
        #     torch.save(save_dict, best_path)

        #     log_path = os.path.join(self.tx.elog.work_dir, 'checkpoint', 'best_score_epoch.txt')
        #     with open(log_path, 'w') as target_file:
        #         target_file.write(f'{str(save_epoch)}')

    def load_model(self, path):
        load_dict = torch.load(path)
        self.model.load_state_dict(load_dict['model'])
        self.optimizer.load_state_dict(load_dict['optimizer'])
        self.total_epoch = load_dict['total_epoch']
        # self.best_score = load_dict['best_score']

    def train_model(self, data):
        input, label = self.get_input_label(data)
        loss, out = self.calculate_loss(self.model, input, label)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss, input, out

    def eval_model(self, data):
        with torch.no_grad():
            input, label = self.get_input_label(data)
            loss, _ = self.calculate_loss(self.model, input, label)
        return loss

    def predict(self, **kwargs):
        print('predict')
        from .nifti_io import ni_save, ni_load
        from .function import save_images, init_validation_dir

        test_dir = self.test_dir
        _, pred_pixel_dir, pred_sample_dir = init_validation_dir(
            algo_name=self.name, dataset_dir=test_dir)
        test_dir = os.path.join(test_dir, 'data')

        test_dir_list = os.listdir(test_dir)
        handle = tqdm(enumerate(test_dir_list))

        if 'num' in kwargs.keys():
            num = kwargs['num']
            kwargs.pop('num')
        else:
            num = None
        length = num if num is not None else len(test_dir_list)

        # return_rec = True

        self.model.eval()
        for i, f_name in handle:
            handle.set_description_str(f'predict: {i}/{length}')

            # if not f_name.startswith('n2'): continue
            if num is not None:
                if i == num: break

            ni_file_path = os.path.join(test_dir, f_name)
            ni_data, ni_aff = ni_load(ni_file_path)

            # pixel
            result = self.score_pixel_2d(ni_data, **kwargs)
            save_images(pred_pixel_dir, f_name, ni_aff, result)

            # sample
            if 'sp' in result.keys():
                # sample_score = self.score_sample_2d(ni_data)
                sample_score = result['sp']
                with open(os.path.join(pred_sample_dir, f_name + ".txt"),
                          "w") as target_file:
                    target_file.write(str(sample_score))

    def validate(self):
        print('validate')
        from .function import init_validation_dir
        from scripts.evalresults import eval_dir

        test_dir = self.test_dir
        score_dir, pred_pixel_dir, pred_sample_dir = init_validation_dir(
            algo_name=self.name, dataset_dir=test_dir)

        # pixel
        pred_pixel_dir = os.path.join(pred_pixel_dir, 'score')
        eval_dir(pred_dir=pred_pixel_dir,
                 label_dir=os.path.join(test_dir, 'label', 'pixel'),
                 mode='pixel',
                 save_file=os.path.join(score_dir, 'pixel'))

        # sample
        eval_dir(pred_dir=pred_sample_dir,
                 label_dir=os.path.join(test_dir, 'label', 'sample'),
                 mode='sample',
                 save_file=os.path.join(score_dir, 'sample'))

    def statistics(self):
        print('statistics')
        from .nifti_io import ni_load, ni_save
        import matplotlib.pyplot as plt
        import numpy as np

        test_dir = self.test_dir
        predict_dir = os.path.join(test_dir, 'eval', self.name, 'predict')
        assert os.path.exists(predict_dir), 'Run predict before computing statistics'

        statistics_dir = os.path.join(predict_dir, 'statistics')
        if not os.path.exists(statistics_dir):
            os.mkdir(statistics_dir)

        file_names = os.listdir(os.path.join(predict_dir, 'pixel', 'score'))
        length = len(file_names)
        handle = tqdm(enumerate(file_names))
        for i, file_name in handle:
            handle.set_description_str(f'{i}/{length}')

            prefix = file_name.split('.')[0]
            each_statistics_dir = os.path.join(statistics_dir, prefix)
            if not os.path.exists(each_statistics_dir):
                os.mkdir(each_statistics_dir)

            score, ni_aff = ni_load(
                os.path.join(predict_dir, 'pixel', 'score', file_name))
            flatten_score = score.flatten()

            # Histogram of all pixel scores
            plt.hist(flatten_score, bins=50, log=False)
            plt.savefig(
                os.path.join(each_statistics_dir, 'whole_score_histogram'))
            plt.cla()

            with open(
                    os.path.join(test_dir, 'label', 'sample',
                                 file_name + '.txt'), "r") as f:
                sample_label = f.readline()
            sample_label = int(sample_label)

            if sample_label == 1:
                # Histogram of scores within the abnormal region
                label, _ = ni_load(
                    os.path.join(test_dir, 'label', 'pixel', file_name))
                abnormal_area_score = score[label == 1]
                plt.hist(abnormal_area_score, bins=50, log=False)
                plt.savefig(
                    os.path.join(each_statistics_dir,
                                 'abnormal_area_score_histogram'))
                plt.cla()

                abnormal_number = len(abnormal_area_score)
                # print(f'abnormal_number: {abnormal_number}')
            elif sample_label == 0:
                abnormal_number = 10000
            else:
                raise Exception(f'Invalid sample_label: {sample_label}')

            # Histogram of the top-scoring pixels
            ordered_flatten_score = np.sort(flatten_score)[::-1]
            large_score = ordered_flatten_score[0:abnormal_number]
            plt.hist(large_score, bins=50, log=False)
            plt.savefig(
                os.path.join(each_statistics_dir,
                             'max_score_area_score_histogram'))
            plt.cla()

            max_score = large_score[0]
            img = score / max_score
            ni_save(os.path.join(each_statistics_dir, 'normalized'), img,
                    ni_aff)

            img = score
            threshold = ordered_flatten_score[abnormal_number]
            img[img >= threshold] = 1
            img[img < threshold] = 0
            ni_save(os.path.join(each_statistics_dir, 'binary'), img, ni_aff)

    def score_pixel_2d(self, np_array, **kwargs):
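        # Resize the volume to the model's input resolution, score it with the
        # subclass-provided get_pixel_score hook, then resize every returned score
        # map (except the sample score 'sp') back to the original resolution.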
        from monai.transforms import Resize

        origin_size = np_array.shape[-1]
        from_transforms = Resize((origin_size, origin_size))
        to_transforms = self.to_transforms

        np_array = self.transpose(np_array)
        np_array = to_transforms(np_array)
        data_tensor = torch.from_numpy(np_array).float().cuda()

        result = self.get_pixel_score(self.model, data_tensor, **kwargs)

        for key in result.keys():
            if key == 'sp': continue
            tensor = result[key]
            array = tensor.detach().cpu().numpy()
            array = from_transforms(array)
            array = self.revert_transpose(array)
            result[key] = array

        return result

    def score_sample_2d(self, np_array):
        data_tensor = torch.from_numpy(np_array).float().cuda()

        sample_score = self.get_sample_score(self.model, data_tensor)
        return sample_score
Example 9
class ceVAE:
    @monkey_patch_fn_args_as_config
    def __init__(
        self,
        input_shape,
        lr=1e-4,
        n_epochs=20,
        z_dim=512,
        model_feature_map_sizes=(16, 64, 256, 1024),
        use_geco=False,
        beta=0.01,
        ce_factor=0.5,
        score_mode="combi",
        load_path=None,
        log_dir=None,
        logger="visdom",
        print_every_iter=100,
        data_dir=None,
    ):

        self.score_mode = score_mode
        self.ce_factor = ce_factor
        self.beta = beta
        self.print_every_iter = print_every_iter
        self.n_epochs = n_epochs
        self.batch_size = input_shape[0]
        self.z_dim = z_dim
        self.use_geco = use_geco
        self.input_shape = input_shape
        self.logger = logger
        self.data_dir = data_dir

        log_dict = {}
        if logger is not None:
            log_dict = {
                0: (logger),
            }
        self.tx = PytorchExperimentStub(
            name="cevae",
            base_dir=log_dir,
            config=fn_args_as_config,
            loggers=log_dict,
        )

        cuda_available = torch.cuda.is_available()
        self.device = torch.device("cuda" if cuda_available else "cpu")

        self.model = VAE(input_size=input_shape[1:],
                         z_dim=z_dim,
                         fmap_sizes=model_feature_map_sizes).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

        self.vae_loss_ema = 1
        self.theta = 1

        if load_path is not None:
            PytorchExperimentLogger.load_model_static(
                self.model, os.path.join(load_path, "vae_final.pth"))
            time.sleep(5)
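
    # Training note (descriptive comment, not in the original code): each batch
    # is scaled to [-1, 1] and fed through two branches of the same model.
    # The VAE branch minimises rec_loss + beta * KL(q(z|x) || N(0, 1)); the
    # context-encoding (CE) branch overwrites random squares with noise via
    # get_square_mask and penalises the reconstruction of the clean input.
    # The two losses are blended as (1 - ce_factor) * loss_vae +
    # ce_factor * loss_ce, and with use_geco the rec/KL trade-off theta is
    # adapted from an EMA of the reconstruction error.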

    def train(self):

        train_loader = get_numpy2d_dataset(
            base_dir=self.data_dir,
            num_processes=16,
            pin_memory=False,
            batch_size=self.batch_size,
            mode="train",
            target_size=self.input_shape[2],
        )
        val_loader = get_numpy2d_dataset(
            base_dir=self.data_dir,
            num_processes=8,
            pin_memory=False,
            batch_size=self.batch_size,
            mode="val",
            target_size=self.input_shape[2],
        )

        for epoch in range(self.n_epochs):

            self.model.train()
            train_loss = 0

            print("Start epoch")
            data_loader_ = tqdm(enumerate(train_loader))
            for batch_idx, data in data_loader_:
                data = data * 2 - 1
                self.optimizer.zero_grad()

                inpt = data.to(self.device)

                ### VAE Part
                loss_vae = 0
                if self.ce_factor < 1:
                    x_rec_vae, z_dist = self.model(inpt)

                    kl_loss = 0
                    if self.beta > 0:
                        kl_loss = self.kl_loss_fn(z_dist) * self.beta
                    rec_loss_vae = self.rec_loss_fn(x_rec_vae, inpt)
                    loss_vae = kl_loss + rec_loss_vae * self.theta

                ### CE Part
                loss_ce = 0
                if self.ce_factor > 0:

                    ce_tensor = get_square_mask(
                        data.shape,
                        square_size=(0, np.max(self.input_shape[2:]) // 2),
                        noise_val=(torch.min(data).item(),
                                   torch.max(data).item()),
                        n_squares=(0, 3),
                    )
                    ce_tensor = torch.from_numpy(ce_tensor).float()
                    inpt_noisy = torch.where(ce_tensor != 0, ce_tensor, data)

                    inpt_noisy = inpt_noisy.to(self.device)
                    x_rec_ce, _ = self.model(inpt_noisy)
                    rec_loss_ce = self.rec_loss_fn(x_rec_ce, inpt)
                    loss_ce = rec_loss_ce

                loss = (1.0 -
                        self.ce_factor) * loss_vae + self.ce_factor * loss_ce

                if self.use_geco and self.ce_factor < 1:
                    g_goal = 0.1
                    g_lr = 1e-4
                    # detach the EMA so it does not keep every batch's graph alive
                    self.vae_loss_ema = ((1.0 - 0.9) * rec_loss_vae.detach()
                                         + 0.9 * self.vae_loss_ema)
                    self.theta = self.geco_beta_update(self.theta,
                                                       self.vae_loss_ema,
                                                       g_goal,
                                                       g_lr,
                                                       speedup=2)

                if torch.isnan(loss):
                    print("A wild NaN occurred")
                    continue

                loss.backward()
                self.optimizer.step()
                train_loss += loss.item()

                if batch_idx % self.print_every_iter == 0:
                    status_str = (
                        f"Train Epoch: {epoch} "
                        f"[{batch_idx}/{len(train_loader)} "
                        f"({100.0 * batch_idx / len(train_loader):.0f}%)] "
                        f"Loss: {loss.item() / len(inpt):.6f}")
                    data_loader_.set_description_str(status_str)

                    cnt = epoch * len(train_loader) + batch_idx

                    if self.ce_factor < 1:
                        self.tx.l[0].show_image_grid(
                            inpt,
                            name="Input-VAE",
                            image_args={"normalize": True})
                        self.tx.l[0].show_image_grid(
                            x_rec_vae,
                            name="Output-VAE",
                            image_args={"normalize": True})

                        if self.beta > 0:
                            self.tx.add_result(torch.mean(kl_loss).item(),
                                               name="Kl-loss",
                                               tag="Losses",
                                               counter=cnt)
                        self.tx.add_result(torch.mean(rec_loss_vae).item(),
                                           name="Rec-loss",
                                           tag="Losses",
                                           counter=cnt)
                        self.tx.add_result(loss_vae.item(),
                                           name="Train-loss",
                                           tag="Losses",
                                           counter=cnt)

                    if self.ce_factor > 0:
                        self.tx.l[0].show_image_grid(
                            inpt_noisy,
                            name="Input-CE",
                            image_args={"normalize": True})
                        self.tx.l[0].show_image_grid(
                            x_rec_ce,
                            name="Output-CE",
                            image_args={"normalize": True})

            print(
                f"====> Epoch: {epoch} Average loss: {train_loss / len(train_loader):.4f}"
            )

            self.model.eval()

            val_loss = 0
            with torch.no_grad():
                data_loader_ = tqdm(enumerate(val_loader))
                for i, data in data_loader_:
                    data = data * 2 - 1
                    inpt = data.to(self.device)

                    x_rec, z_dist = self.model(inpt, sample=False)

                    kl_loss = 0
                    if self.beta > 0:
                        kl_loss = self.kl_loss_fn(z_dist) * self.beta
                    rec_loss = self.rec_loss_fn(x_rec, inpt)
                    loss = kl_loss + rec_loss * self.theta

                    val_loss += loss.item()

                self.tx.add_result(val_loss / len(val_loader),
                                   name="Val-Loss",
                                   tag="Losses",
                                   counter=(epoch + 1) * len(train_loader))

            print(
                f"====> Epoch: {epoch} Validation loss: {val_loss / len(val_loader):.4f}"
            )

        self.tx.save_model(self.model, "vae_final")

        time.sleep(10)

    def score_sample(self, np_array):

        orig_shape = np_array.shape
        to_transforms = torch.nn.Upsample(
            (self.input_shape[2], self.input_shape[3]), mode="bilinear")

        data_tensor = torch.from_numpy(np_array).float()
        data_tensor = to_transforms(data_tensor[None])[0]
        slice_scores = []

        for i in range(ceil(orig_shape[0] / self.batch_size)):
            batch = data_tensor[i * self.batch_size:(i + 1) *
                                self.batch_size].unsqueeze(1)
            batch = batch * 2 - 1

            with torch.no_grad():
                inpt = batch.to(self.device).float()
                x_rec, z_dist = self.model(inpt, sample=False)
                kl_loss = self.kl_loss_fn(z_dist, sum_samples=False)
                rec_loss = self.rec_loss_fn(x_rec, inpt, sum_samples=False)
                img_scores = kl_loss * self.beta + rec_loss * self.theta

            slice_scores += img_scores.cpu().tolist()

        return np.max(slice_scores)
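
    # Descriptive comment (not in the original code): score_pixels supports
    # three score_mode settings. "rec" uses the squared reconstruction error,
    # "grad" uses a smoothed input gradient of the full loss, and "combi"
    # multiplies the reconstruction error by a smoothed, normalised
    # KL-gradient saliency map before resizing back to the original shape.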

    def score_pixels(self, np_array):

        orig_shape = np_array.shape
        to_transforms = torch.nn.Upsample(
            (self.input_shape[2], self.input_shape[3]), mode="bilinear")
        from_transforms = torch.nn.Upsample((orig_shape[1], orig_shape[2]),
                                            mode="bilinear")

        data_tensor = torch.from_numpy(np_array).float()
        data_tensor = to_transforms(data_tensor[None])[0]
        target_tensor = torch.zeros_like(data_tensor)

        for i in range(ceil(orig_shape[0] / self.batch_size)):
            batch = data_tensor[i * self.batch_size:(i + 1) *
                                self.batch_size].unsqueeze(1)
            batch = batch * 2 - 1

            inpt = batch.to(self.device).float()
            x_rec, z_dist = self.model(inpt, sample=False)

            if self.score_mode == "combi":

                rec = torch.pow((x_rec - inpt), 2).detach().cpu()
                rec = torch.mean(rec, dim=1, keepdim=True)

                def __err_fn(x):
                    x_r, z_d = self.model(x, sample=False)
                    loss = self.kl_loss_fn(z_d)
                    return loss

                loss_grad_kl = (get_smooth_image_gradient(
                    model=self.model,
                    inpt=inpt,
                    err_fn=__err_fn,
                    grad_type="vanilla",
                    n_runs=2).detach().cpu())
                loss_grad_kl = torch.mean(loss_grad_kl, dim=1, keepdim=True)

                pixel_scores = smooth_tensor(normalize(loss_grad_kl),
                                             kernel_size=8) * rec

            elif self.score_mode == "rec":

                rec = torch.pow((x_rec - inpt), 2).detach().cpu()
                rec = torch.mean(rec, dim=1, keepdim=True)
                pixel_scores = rec

            elif self.score_mode == "grad":

                def __err_fn(x):
                    x_r, z_d = self.model(x, sample=False)
                    kl_loss_ = self.kl_loss_fn(z_d)
                    rec_loss_ = self.rec_loss_fn(x_r, x)
                    loss_ = kl_loss_ * self.beta + rec_loss_ * self.theta
                    return torch.mean(loss_)

                loss_grad_kl = (get_smooth_image_gradient(
                    model=self.model,
                    inpt=inpt,
                    err_fn=__err_fn,
                    grad_type="vanilla",
                    n_runs=2).detach().cpu())
                loss_grad_kl = torch.mean(loss_grad_kl, dim=1, keepdim=True)

                pixel_scores = smooth_tensor(normalize(loss_grad_kl),
                                             kernel_size=8)

            self.tx.elog.show_image_grid(inpt,
                                         name="Input",
                                         image_args={"normalize": True},
                                         n_iter=i)
            self.tx.elog.show_image_grid(x_rec,
                                         name="Output",
                                         image_args={"normalize": True},
                                         n_iter=i)
            self.tx.elog.show_image_grid(pixel_scores,
                                         name="Scores",
                                         image_args={"normalize": True},
                                         n_iter=i)

            target_tensor[i * self.batch_size:(i + 1) *
                          self.batch_size] = pixel_scores.detach().cpu()[:,
                                                                         0, :]

        target_tensor = from_transforms(target_tensor[None])[0]

        return target_tensor.detach().numpy()

    @staticmethod
    def load_trained_model(model, tx, path):
        tx.elog.load_model_static(model=model, model_file=path)
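
    # Descriptive comment (not in the original code): for a diagonal Gaussian
    # posterior q = N(mu, sigma^2) and the unit Gaussian prior used below,
    # dist.kl_divergence evaluates the closed form
    #     KL(q || N(0, 1)) = 0.5 * (mu^2 + sigma^2 - log(sigma^2) - 1)
    # element-wise; `correct` toggles whether the latent dimensions are summed
    # (the proper ELBO term) or averaged, and `sum_samples` whether the batch
    # is reduced to a scalar.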

    @staticmethod
    def kl_loss_fn(z_post, sum_samples=True, correct=False):
        z_prior = dist.Normal(0, 1.0)
        kl_div = dist.kl_divergence(z_post, z_prior)
        if correct:
            kl_div = torch.sum(kl_div, dim=(1, 2, 3))
        else:
            kl_div = torch.mean(kl_div, dim=(1, 2, 3))
        if sum_samples:
            return torch.mean(kl_div)
        else:
            return kl_div
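
    # Descriptive comment (not in the original code): with correct=False the
    # reconstruction term is simply the mean absolute error, which equals the
    # negative log-likelihood of a Laplace(recon_x, scale=1) observation model
    # up to the additive constant log(2); correct=True uses that Laplace
    # likelihood explicitly and sums over all pixels.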

    @staticmethod
    def rec_loss_fn(recon_x, x, sum_samples=True, correct=False):
        if correct:
            x_dist = dist.Laplace(recon_x, 1.0)
            log_p_x_z = x_dist.log_prob(x)
            log_p_x_z = torch.sum(log_p_x_z, dim=(1, 2, 3))
        else:
            log_p_x_z = -torch.abs(recon_x - x)
            log_p_x_z = torch.mean(log_p_x_z, dim=(1, 2, 3))
        if sum_samples:
            return -torch.mean(log_p_x_z)
        else:
            return -log_p_x_z

    @staticmethod
    def get_inpt_grad(model, inpt, err_fn):
        model.zero_grad()
        inpt = inpt.detach()
        inpt.requires_grad = True

        err = err_fn(inpt)
        err.backward()

        grad = inpt.grad.detach()

        model.zero_grad()

        return torch.abs(grad.detach())
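
    # Descriptive comment (not in the original code): a GECO-style multiplier
    # update. The constraint is EMA(rec_loss) - goal; beta (used as theta in
    # train above) is scaled by exp(step_size * constraint), optionally faster
    # when the constraint is violated (speedup), and finally clamped to
    # [min_clamp, max_clamp].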

    @staticmethod
    def geco_beta_update(beta,
                         error_ema,
                         goal,
                         step_size,
                         min_clamp=1e-10,
                         max_clamp=1e4,
                         speedup=None):
        constraint = (error_ema - goal).detach()
        if speedup is not None and constraint > 0.0:
            beta = beta * torch.exp(speedup * step_size * constraint)
        else:
            beta = beta * torch.exp(step_size * constraint)
        if min_clamp is not None:
            beta = np.max((beta.item(), min_clamp))
        if max_clamp is not None:
            beta = np.min((beta.item(), max_clamp))
        return beta

    @staticmethod
    def get_ema(new, old, alpha):
        if old is None:
            return new
        return (1.0 - alpha) * new + alpha * old

    def print(self, *args):
        print(*args)
        self.tx.print(*args)

    def log_result(self, val, key=None):
        self.tx.print(key, val)
        self.tx.add_result_without_epoch(val, key)
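

# Hypothetical usage sketch (not part of the original listing): how the ceVAE
# class above might be wired together. The input shape and all paths are
# made-up placeholders, and a running visdom server is assumed for the default
# logger.
if __name__ == "__main__":
    import numpy as np

    cevae = ceVAE(
        input_shape=(64, 1, 128, 128),            # batch of 64 single-channel 128x128 slices
        data_dir="/path/to/preprocessed_slices",  # placeholder
        log_dir="/path/to/logs",                  # placeholder
        n_epochs=1,
    )
    cevae.train()

    volume = np.load("/path/to/test_volume.npy")  # (n_slices, H, W), placeholder
    print("sample-level score:", cevae.score_sample(volume))
    pixel_scores = cevae.score_pixels(volume)     # anomaly heatmap, same shape as volume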