Example #1
import torch
import torch.nn as nn
from torch.distributions import kl

# Unet, AxisAlignedConvGaussian, Fcomb and the global `device` are assumed to be
# defined in the surrounding repository this example was taken from.
class ProbabilisticUnet(nn.Module):
    """
    概率UNet(https://arxiv.org/abs/1806.05034)实现。
     input_channels:图像中的通道数(灰度为1,RGB为3)
     num_classes:要预测的类数
     num_filters:是过滤器层数的列表一致性
     latent_dim:隐空间的维度
     no_cons_per_block:先验和后验(卷积)编码器中的每个块卷积编号

    A probabilistic UNet (https://arxiv.org/abs/1806.05034) implementation.
    input_channels: the number of channels in the image (1 for greyscale and 3 for RGB)
    num_classes: the number of classes to predict
    num_filters: is a list consisint of the amount of filters layer
    latent_dim: dimension of the latent space
    no_cons_per_block: no convs per block in the (convolutional) encoder of prior and posterior
    """
    def __init__(self,
                 input_channels=1,
                 num_classes=1,
                 num_filters=[32, 64, 128, 192],
                 latent_dim=6,
                 no_convs_fcomb=4,
                 beta=10.0):
        super(ProbabilisticUnet, self).__init__()
        self.input_channels = input_channels  # number of input image channels
        self.num_classes = num_classes  # number of segmentation classes
        self.num_filters = num_filters  # number of filters per layer
        self.latent_dim = latent_dim  # dimension of the latent space
        self.no_convs_per_block = 3
        self.no_convs_fcomb = no_convs_fcomb
        self.initializers = {'w': 'he_normal', 'b': 'normal'}  # weight/bias initialisation
        self.beta = beta
        self.z_prior_sample = 0

        self.unet = Unet(self.input_channels,
                         self.num_classes,
                         self.num_filters,
                         self.initializers,
                         apply_last_layer=False,
                         padding=True).to(device)
        self.prior = AxisAlignedConvGaussian(
            self.input_channels,
            self.num_filters,
            self.no_convs_per_block,
            self.latent_dim,
            self.initializers,
        ).to(device)
        self.posterior = AxisAlignedConvGaussian(self.input_channels,
                                                 self.num_filters,
                                                 self.no_convs_per_block,
                                                 self.latent_dim,
                                                 self.initializers,
                                                 posterior=True).to(device)
        self.fcomb = Fcomb(self.num_filters,
                           self.latent_dim,
                           self.input_channels,
                           self.num_classes,
                           self.no_convs_fcomb, {
                               'w': 'orthogonal',
                               'b': 'normal'
                           },
                           use_tile=True).to(device)

    def forward(self, patch, segm, training=True):
        """
        为patch构建先验隐空间,并通过UNet运行patch,
        如果training=True,则还可以构造后方潜在空间

        Construct prior latent space for patch and run patch through UNet,
        in case training is True also construct posterior latent space
        """
        if training:
            self.posterior_latent_space = self.posterior.forward(patch, segm)
        self.prior_latent_space = self.prior.forward(patch)
        self.unet_features = self.unet.forward(patch, False)

    def sample(self, testing=False):
        """
        通过根据先验样本进行重构来对切割进行采样
        并将其与UNet特征相结合

        Sample a segmentation by reconstructing from a prior sample
        and combining this with UNet features
        """
        if not testing:
            z_prior = self.prior_latent_space.rsample()
            self.z_prior_sample = z_prior
        else:
            # You can choose between a sample and the mean here. For the GED it is important to take a sample.
            # z_prior = self.prior_latent_space.base_dist.loc
            z_prior = self.prior_latent_space.sample()
            self.z_prior_sample = z_prior
        return self.fcomb.forward(self.unet_features, z_prior)

    def reconstruct(self,
                    use_posterior_mean=False,
                    calculate_posterior=False,
                    z_posterior=None):
        """
        从后验样本(解码后验样本)和UNet特征图重建分割
        use_posterior_mean:使用posterior_mean代替对z_q的采样
        compute_posterior:使用提供的样本或来自后潜在空间的样本

        Reconstruct a segmentation from a posterior sample (decoding a posterior sample) and UNet feature map
        use_posterior_mean: use posterior_mean instead of sampling z_q
        calculate_posterior: use a provided sample or sample from posterior latent space
        """
        if use_posterior_mean:
            z_posterior = self.posterior_latent_space.loc
        else:
            if calculate_posterior:
                z_posterior = self.posterior_latent_space.rsample()
        return self.fcomb.forward(self.unet_features, z_posterior)

    def kl_divergence(self,
                      analytic=True,
                      calculate_posterior=False,
                      z_posterior=None):
        """
        计算后验KL(Q||P)和先验KL(Q||P)之间的KL散度
        分析:通过分析或通过后验采样来计算KL
        compute_posterior:如果我们使用samapling来近似KL,则可以在此处采样或提供样本

        Calculate the KL divergence between the posterior and prior KL(Q||P)
        analytic: calculate KL analytically or via sampling from the posterior
        calculate_posterior: if we use samapling to approximate KL we can sample here or supply a sample
        """
        if analytic:
            # Need to add this to the torch source code, see: https://github.com/pytorch/pytorch/issues/13545
            kl_div = kl.kl_divergence(self.posterior_latent_space,
                                      self.prior_latent_space)
        else:
            if calculate_posterior:
                z_posterior = self.posterior_latent_space.rsample()
            log_posterior_prob = self.posterior_latent_space.log_prob(
                z_posterior)
            log_prior_prob = self.prior_latent_space.log_prob(z_posterior)
            kl_div = log_posterior_prob - log_prior_prob
        return kl_div

    def elbo(self, segm, analytic_kl=True, reconstruct_posterior_mean=False):
        """
        计算P(Y|X)的边际似然函数下界

        Calculate the evidence lower bound of the log-likelihood of P(Y|X)
        """

        criterion = nn.BCEWithLogitsLoss(size_average=False,
                                         reduce=False,
                                         reduction=None)
        z_posterior = self.posterior_latent_space.rsample()

        self.kl = torch.mean(
            self.kl_divergence(analytic=analytic_kl,
                               calculate_posterior=False,
                               z_posterior=z_posterior))

        #Here we use the posterior sample sampled above
        self.reconstruction = self.reconstruct(
            use_posterior_mean=reconstruct_posterior_mean,
            calculate_posterior=False,
            z_posterior=z_posterior)

        reconstruction_loss = criterion(input=self.reconstruction, target=segm)
        self.reconstruction_loss = torch.sum(reconstruction_loss)
        self.mean_reconstruction_loss = torch.mean(reconstruction_loss)

        return -(self.reconstruction_loss + self.beta * self.kl)
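
A minimal training-step sketch for the class above, assuming the module-level `device` used inside `ProbabilisticUnet` is defined and that `patch` and `mask` are a hypothetical batch of greyscale patches and binary ground-truth masks (matching the constructor defaults); the returned ELBO is negated to obtain the loss:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net = ProbabilisticUnet(input_channels=1, num_classes=1, latent_dim=6, beta=10.0)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)

patch = torch.rand(4, 1, 128, 128, device=device)                   # hypothetical input batch
mask = (torch.rand(4, 1, 128, 128, device=device) > 0.5).float()    # hypothetical binary masks

net.forward(patch, mask, training=True)  # builds prior/posterior latent spaces and UNet features
loss = -net.elbo(mask)                   # elbo() returns -(reconstruction_loss + beta * kl)
optimizer.zero_grad()
loss.backward()
optimizer.step()
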
Example #3
import os
from argparse import ArgumentParser
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
from sklearn import metrics
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import transforms

from cytomine import Cytomine
from cytomine.models import AnnotationCollection, ImageInstance

# Unet, AnnotationCrop, RemoteAnnotationTrainDataset, predict_roi, find_intersecting_annotations,
# segmentation_transform and worker_init are project-specific helpers assumed to be available
# in the surrounding code base.
def main(argv):
    """

    IMAGES VALID:
    * 005-TS_13C08351_2-2014-02-12 12.22.44.ndpi | id : 77150767
    * 024-12C07162_2A-2012-08-14-17.21.05.jp2 | id : 77150761
    * 019-CP_12C04234_2-2012-08-10-12.49.26.jp2 | id : 77150809

    IMAGES TEST:
    * 004-PF_08C11886_1-2012-08-09-19.05.53.jp2 | id : 77150623
    * 011-TS_13C10153_3-2014-02-13 15.22.21.ndpi | id : 77150611
    * 018-PF_07C18435_1-2012-08-17-00.55.09.jp2 | id : 77150755

    """
    with Cytomine.connect_from_cli(argv):
        parser = ArgumentParser()
        parser.add_argument("-b",
                            "--batch_size",
                            dest="batch_size",
                            default=4,
                            type=int)
        parser.add_argument("-j",
                            "--n_jobs",
                            dest="n_jobs",
                            default=1,
                            type=int)
        parser.add_argument("-e",
                            "--epochs",
                            dest="epochs",
                            default=1,
                            type=int)
        parser.add_argument("-d", "--device", dest="device", default="cpu")
        parser.add_argument("-o",
                            "--overlap",
                            dest="overlap",
                            default=0,
                            type=int)
        parser.add_argument("-t",
                            "--tile_size",
                            dest="tile_size",
                            default=256,
                            type=int)
        parser.add_argument("-z",
                            "--zoom_level",
                            dest="zoom_level",
                            default=0,
                            type=int)
        parser.add_argument("--lr", dest="lr", default=0.01, type=float)
        parser.add_argument("--init_fmaps",
                            dest="init_fmaps",
                            default=16,
                            type=int)
        parser.add_argument("--data_path",
                            "--dpath",
                            dest="data_path",
                            default=os.path.join(str(Path.home()), "tmp"))
        parser.add_argument("-w",
                            "--working_path",
                            "--wpath",
                            dest="working_path",
                            default=os.path.join(str(Path.home()), "tmp"))
        parser.add_argument("-s",
                            "--save_path",
                            dest="save_path",
                            default=os.path.join(str(Path.home()), "tmp"))
        args, _ = parser.parse_known_args(argv)

        os.makedirs(args.save_path, exist_ok=True)
        os.makedirs(args.data_path, exist_ok=True)
        os.makedirs(args.working_path, exist_ok=True)

        # fetch annotations (filter val/test sets + other annotations)
        all_annotations = AnnotationCollection(project=77150529,
                                               showWKT=True,
                                               showMeta=True,
                                               showTerm=True).fetch()
        val_ids = {77150767, 77150761, 77150809}
        test_ids = {77150623, 77150611, 77150755}
        val_test_ids = val_ids.union(test_ids)
        train_collection = all_annotations.filter(lambda a: (
            a.user in {55502856} and len(a.term) > 0 and a.term[0] in
            {35777351, 35777321, 35777459} and a.image not in val_test_ids))
        val_rois = all_annotations.filter(
            lambda a: (a.user in {142954314} and a.image in val_ids and len(
                a.term) > 0 and a.term[0] in {154890363}))
        val_foreground = all_annotations.filter(
            lambda a: (a.user in {142954314} and a.image in val_ids and len(
                a.term) > 0 and a.term[0] in {154005477}))

        train_wsi_ids = list({an.image
                              for an in all_annotations
                              }.difference(val_test_ids))
        val_wsi_ids = list(val_ids)

        download_path = os.path.join(args.data_path,
                                     "crops-{}".format(args.tile_size))
        images = {
            _id: ImageInstance().fetch(_id)
            for _id in (train_wsi_ids + val_wsi_ids)
        }

        train_crops = [
            AnnotationCrop(images[annot.image],
                           annot,
                           download_path,
                           args.tile_size,
                           zoom_level=args.zoom_level)
            for annot in train_collection
        ]
        val_crops = [
            AnnotationCrop(images[annot.image],
                           annot,
                           download_path,
                           args.tile_size,
                           zoom_level=args.zoom_level) for annot in val_rois
        ]

        for crop in train_crops + val_crops:
            crop.download()

        np.random.seed(42)
        dataset = RemoteAnnotationTrainDataset(
            train_crops, seg_trans=segmentation_transform)
        loader = DataLoader(dataset,
                            shuffle=True,
                            batch_size=args.batch_size,
                            num_workers=args.n_jobs,
                            worker_init_fn=worker_init)

        # network
        device = torch.device(args.device)
        unet = Unet(args.init_fmaps, n_classes=1)
        unet.train()
        unet.to(device)

        optimizer = Adam(unet.parameters(), lr=args.lr)
        loss_fn = BCEWithLogitsLoss(reduction="mean")

        results = {
            "train_losses": [],
            "val_losses": [],
            "val_metrics": [],
            "save_path": []
        }

        for e in range(args.epochs):
            print("########################")
            print("        Epoch {}".format(e))
            print("########################")

            epoch_losses = list()
            unet.train()
            for i, (x, y) in enumerate(loader):
                x, y = (t.to(device) for t in [x, y])
                y_pred = unet.forward(x)
                loss = loss_fn(y_pred, y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
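                # keep the six most recent batch losses for the smoothed running-mean printout below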
                epoch_losses = [loss.detach().cpu().item()] + epoch_losses[:5]
                print("{} - {:1.5f}".format(i, np.mean(epoch_losses)))
                results["train_losses"].append(epoch_losses[0])

            unet.eval()
            # validation
            val_losses = np.zeros(len(val_rois), dtype=float)
            val_roc_auc = np.zeros(len(val_rois), dtype=float)
            val_cm = np.zeros([len(val_rois), 2, 2], dtype=int)

            for i, roi in enumerate(val_crops):
                foregrounds = find_intersecting_annotations(
                    roi.annotation, val_foreground)
                with torch.no_grad():
                    y_pred, y_true = predict_roi(
                        roi,
                        foregrounds,
                        unet,
                        device,
                        in_trans=transforms.ToTensor(),
                        batch_size=args.batch_size,
                        tile_size=args.tile_size,
                        overlap=args.overlap,
                        n_jobs=args.n_jobs,
                        zoom_level=args.zoom_level)

                val_losses[i] = metrics.log_loss(y_true.flatten(),
                                                 y_pred.flatten())
                val_roc_auc[i] = metrics.roc_auc_score(y_true.flatten(),
                                                       y_pred.flatten())
                val_cm[i] = metrics.confusion_matrix(
                    y_true.flatten().astype(np.uint8),
                    (y_pred.flatten() > 0.5).astype(np.uint8))

            print("------------------------------")
            print("Epoch {}:".format(e))
            val_loss = np.mean(val_losses)
            roc_auc = np.mean(val_roc_auc)
            print("> val_loss: {:1.5f}".format(val_loss))
            print("> roc_auc : {:1.5f}".format(roc_auc))
            cm = np.sum(val_cm, axis=0)
            cnt = np.sum(val_cm)
            print("CM at 0.5 threshold")
            print("> {:3.2f}%  {:3.2f}%".format(100 * cm[0, 0] / cnt,
                                                100 * cm[0, 1] / cnt))
            print("> {:3.2f}%  {:3.2f}%".format(100 * cm[1, 0] / cnt,
                                                100 * cm[1, 1] / cnt))
            print("------------------------------")

            filename = "{}_e_{}_val_{:0.4f}_roc_{:0.4f}_z{}_s{}.pth".format(
                datetime.now().timestamp(), e, val_loss, roc_auc,
                args.zoom_level, args.tile_size)
            torch.save(unet.state_dict(), os.path.join(args.save_path,
                                                       filename))

            results["val_losses"].append(val_loss)
            results["val_metrics"].append(roc_auc)
            results["save_path"].append(filename)

        return results
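
A hypothetical entry point for the script: the raw command-line arguments are forwarded so that `Cytomine.connect_from_cli` can consume the Cytomine credential flags, while `parse_known_args` inside `main` picks up the training options:

if __name__ == "__main__":
    import sys
    results = main(sys.argv[1:])
    print("epochs trained: {}".format(len(results["val_losses"])))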