class ProbabilisticUnet(nn.Module):
    """
    A probabilistic UNet (https://arxiv.org/abs/1806.05034) implementation.
    input_channels: the number of channels in the image (1 for greyscale and 3 for RGB)
    num_classes: the number of classes to predict
    num_filters: list with the number of filters per layer
    latent_dim: dimension of the latent space
    no_convs_per_block: number of convs per block in the (convolutional) encoders of prior and posterior
    """
    def __init__(self,
                 input_channels=1,
                 num_classes=1,
                 num_filters=None,
                 latent_levels=1,
                 latent_dim=2,
                 initializers=None,
                 no_convs_fcomb=4,
                 image_size=(1, 128, 128),
                 beta=10.0,
                 reversible=False):
        super(ProbabilisticUnet, self).__init__()
        self.input_channels = input_channels
        self.num_classes = num_classes
        self.num_filters = num_filters
        self.latent_dim = latent_dim
        self.no_convs_per_block = 3
        self.no_convs_fcomb = no_convs_fcomb
        self.initializers = {'w': 'he_normal', 'b': 'normal'}
        self.beta = beta
        self.z_prior_sample = 0

        self.unet = Unet(self.input_channels,
                         self.num_classes,
                         self.num_filters,
                         initializers=self.initializers,
                         apply_last_layer=False,
                         padding=True,
                         reversible=reversible).to(device)
        self.prior = AxisAlignedConvGaussian(
            self.input_channels,
            self.num_filters,
            self.no_convs_per_block,
            self.latent_dim,
            initializers=self.initializers).to(device)
        self.posterior = AxisAlignedConvGaussian(
            self.input_channels,
            self.num_filters,
            self.no_convs_per_block,
            self.latent_dim,
            initializers=self.initializers,
            posterior=True).to(device)
        self.fcomb = Fcomb(self.num_filters,
                           self.latent_dim,
                           self.input_channels,
                           self.num_classes,
                           self.no_convs_fcomb,
                           initializers={
                               'w': 'orthogonal',
                               'b': 'normal'
                           },
                           use_tile=True).to(device)

        self.last_conv = Conv2D(32,
                                num_classes,
                                kernel_size=1,
                                activation=torch.nn.Identity,
                                norm=torch.nn.Identity)

    def forward(self, patch, segm=None, training=True):
        """
        Construct the prior latent space for patch and run patch through the UNet;
        if segm is provided, also construct the posterior latent space
        """
        if segm is not None:  # construct the posterior latent space as well, e.g. during validation
            self.posterior_latent_space = self.posterior.forward(patch, segm)
        self.prior_latent_space = self.prior.forward(patch)
        self.unet_features = self.unet.forward(patch, False)
        return self.last_conv(self.unet_features)  # added for summary writer

    def sample(self, testing=False):
        """
        Sample a segmentation by reconstructing from a prior sample
        and combining this with UNet features
        """
        if not testing:
            z_prior = self.prior_latent_space.rsample()
            self.z_prior_sample = z_prior
        else:
            # You can choose whether you want a sample or the mean here. For the GED it is important to take a sample.
            # z_prior = self.prior_latent_space.base_dist.loc
            z_prior = self.prior_latent_space.sample()
            self.z_prior_sample = z_prior
        return self.fcomb.forward(self.unet_features, z_prior)

    def reconstruct(self,
                    use_posterior_mean=False,
                    calculate_posterior=False,
                    z_posterior=None):
        """
        Reconstruct a segmentation from a posterior sample (decoding a posterior sample) and UNet feature map
        use_posterior_mean: use the posterior mean instead of sampling z_q
        calculate_posterior: if True, sample z from the posterior latent space; otherwise use the provided sample
        """
        if use_posterior_mean:
            z_posterior = self.posterior_latent_space.loc
        else:
            if calculate_posterior:
                z_posterior = self.posterior_latent_space.rsample()
        return self.fcomb.forward(self.unet_features, z_posterior)

    def accumulate_output(self, output_list, use_softmax=False):
        """Adapted to ProbUnet, which does not have an output list"""
        s_accum = output_list
        if use_softmax:
            return torch.nn.functional.softmax(s_accum, dim=1)
        return s_accum

    def KL_two_gauss_with_diag_cov(self, mu0, sigma0, mu1, sigma1):
        sigma0_fs = torch.square(torch.flatten(sigma0, start_dim=1))
        sigma1_fs = torch.square(torch.flatten(sigma1, start_dim=1))

        logsigma0_fs = torch.log(sigma0_fs + 1e-10)
        logsigma1_fs = torch.log(sigma1_fs + 1e-10)

        mu0_f = torch.flatten(mu0, start_dim=1)
        mu1_f = torch.flatten(mu1, start_dim=1)

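        # Closed-form KL between diagonal Gaussians, summed over latent
        # dimensions and averaged over the batch:
        #   KL(N(mu0, sigma0^2) || N(mu1, sigma1^2))
        #     = 0.5 * sum( sigma0^2 / sigma1^2 + (mu1 - mu0)^2 / sigma1^2
        #                  + log(sigma1^2) - log(sigma0^2) - 1 )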
        return torch.mean(0.5 * torch.sum(
            torch.div(sigma0_fs + torch.mul(
                (mu1_f - mu0_f), (mu1_f - mu0_f)), sigma1_fs + 1e-10) +
            logsigma1_fs - logsigma0_fs - 1,
            dim=1))

    def kl_divergence(self,
                      analytic=True,
                      calculate_posterior=False,
                      z_posterior=None):
        """
        Calculate the KL divergence between the posterior and prior KL(Q||P)
        analytic: calculate KL analytically or via sampling from the posterior
        calculate_posterior: if we use sampling to approximate the KL we can sample here or supply a sample
        """
        # if analytic:
        #     # Need to add this to torch source code, see: https://github.com/pytorch/pytorch/issues/13545
        #     kl_div = kl.kl_divergence(self.posterior_latent_space, self.prior_latent_space)
        # else:
        #     if calculate_posterior:
        #         z_posterior = self.posterior_latent_space.rsample()
        #     log_posterior_prob = self.posterior_latent_space.log_prob(z_posterior)
        #     log_prior_prob = self.prior_latent_space.log_prob(z_posterior)
        #     kl_div = log_posterior_prob - log_prior_prob
        mu0 = self.posterior_latent_space.mean
        sigma0 = self.posterior_latent_space.stddev
        mu1 = self.prior_latent_space.mean
        sigma1 = self.prior_latent_space.stddev
        kl_div = self.KL_two_gauss_with_diag_cov(mu0, sigma0, mu1, sigma1)
        return kl_div

    def multinoulli_loss(self, reconstruction, target):
        criterion = torch.nn.CrossEntropyLoss(reduction='none')

        batch_size = reconstruction.shape[0]

        recon_flat = reconstruction.view(batch_size, self.num_classes, -1)
        target_flat = target.view(batch_size, -1).long()
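        # With reduction='none', criterion yields per-pixel losses of shape
        # (batch, H*W); sum over pixels, then average over the batch below.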
        return torch.mean(
            torch.sum(criterion(target=target_flat, input=recon_flat), dim=1))

    def elbo(self, segm, analytic_kl=False, reconstruct_posterior_mean=False):
        """
        Calculate the evidence lower bound of the log-likelihood of P(Y|X)
        """

        criterion = self.multinoulli_loss

        z_posterior = self.posterior_latent_space.rsample()

        self.kl_divergence_loss = torch.mean(
            self.kl_divergence(analytic=analytic_kl,
                               calculate_posterior=False,
                               z_posterior=z_posterior))

        # Here we use the posterior sample sampled above
        self.reconstruction = self.reconstruct(
            use_posterior_mean=reconstruct_posterior_mean,
            calculate_posterior=False,
            z_posterior=z_posterior)

        reconstruction_loss = criterion(reconstruction=self.reconstruction,
                                        target=segm)
        self.reconstruction_loss = torch.sum(reconstruction_loss)
        self.mean_reconstruction_loss = torch.mean(reconstruction_loss)

        return -(self.reconstruction_loss + self.beta * self.kl_divergence_loss)

    def loss(self, mask):
        elbo = self.elbo(mask)
        reg_loss = l2_regularisation(self.posterior) + l2_regularisation(
            self.prior) + l2_regularisation(self.fcomb.layers)
        loss = -elbo + 1e-5 * reg_loss
        return loss
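
# A minimal, hypothetical training-step sketch for the ProbabilisticUnet above
# (not part of the original source). It assumes `device` and the helper modules
# used above (Unet, AxisAlignedConvGaussian, Fcomb, l2_regularisation) are in
# scope, and uses random tensors in place of a real data loader.
import torch

net = ProbabilisticUnet(input_channels=1, num_classes=2).to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)

patch = torch.randn(4, 1, 128, 128, device=device)                   # B x C x H x W
mask = torch.randint(0, 2, (4, 1, 128, 128), device=device).float()  # class labels

net.forward(patch, mask, training=True)  # builds prior and posterior latent spaces
loss = net.loss(mask)                    # -ELBO + 1e-5 * L2 regularisation
optimizer.zero_grad()
loss.backward()
optimizer.step()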
Example 2
    recEpoch = [0]
    kl = torch.zeros(1)
    recLoss = torch.zeros(1)
    dGED = 0
    t = time.time()
    for step, (patch, masks) in enumerate(train_loader):
        patch = patch.to(device)
        masks = masks.to(device)
        if args.singleRater or args.unet:
            rater = 0
        else:
            # Choose a random rater's mask
            rater = torch.randperm(4)[0]
        mask = masks[:, [rater]]
        if not args.unet:
            net.forward(patch, mask, training=True)
            _, _, _, elbo = net.elbo(mask, use_mask=False, analytic_kl=False)
            reg_loss = l2_regularisation(net.posterior) + l2_regularisation(
                net.prior)
            loss = -elbo + 1e-5 * reg_loss
        else:
            pred = torch.sigmoid(net.forward(patch, False))
            loss = criterion(target=mask, input=pred)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        trLoss.append(loss.item())

        if (step + 1) % 5 == 0:
            with torch.no_grad():
                for idx, (patch, masks) in enumerate(valid_loader):
                    ...
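
# A hypothetical test-time sketch (not in the original source) of drawing
# several segmentation samples, e.g. for GED-style diversity evaluation.
# It assumes `net` follows the cFlowNet API below, whose sample() also
# returns prior and posterior log-probabilities of the drawn latent.
with torch.no_grad():
    net.forward(patch, mask, training=True)  # posterior is needed for log_qz
    preds = []
    for _ in range(4):
        pred, log_pz, log_qz = net.sample(testing=True)
        preds.append(torch.sigmoid(pred))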
Example 3
class cFlowNet(nn.Module):
    """
    input_channels: the number of channels in the image (1 for greyscale and 3 for RGB)
    num_classes: the number of classes to predict
    num_filters: list with the number of filters per layer
    latent_dim: dimension of the latent space
    no_convs_per_block: number of convs per block in the (convolutional) encoders of prior and posterior
    """
    def __init__(self,
                 input_channels=1,
                 num_classes=1,
                 num_filters=(32, 64, 128, 256),
                 latent_dim=6,
                 no_convs_fcomb=4,
                 beta=1.0,
                 num_flows=4,
                 norm=False,
                 flow=False,
                 glow=False):

        super(cFlowNet, self).__init__()
        self.input_channels = input_channels
        self.num_classes = num_classes
        self.num_filters = num_filters
        self.latent_dim = latent_dim
        self.no_convs_per_block = 3
        self.no_convs_fcomb = no_convs_fcomb
        self.initializers = {'w': 'he_normal', 'b': 'normal'}
        self.beta = beta
        self.z_prior_sample = 0
        self.flow = flow
        self.flow_steps = num_flows

        self.unet = Unet(self.input_channels,
                         self.num_classes,
                         self.num_filters,
                         self.initializers,
                         apply_last_layer=False,
                         padding=True,
                         norm=norm).to(device)
        self.prior = AxisAlignedConvGaussian(self.input_channels,
                                             self.num_filters,
                                             self.no_convs_per_block,
                                             self.latent_dim,
                                             self.initializers,
                                             norm=norm).to(device)

        if flow:
            if glow:
                self.posterior = glowDensity(self.flow_steps,
                                             self.input_channels,
                                             self.num_filters,
                                             self.no_convs_per_block,
                                             self.latent_dim,
                                             self.initializers,
                                             posterior=True,
                                             norm=norm).to(device)
            else:
                self.posterior = planarFlowDensity(self.flow_steps,
                                                   self.input_channels,
                                                   self.num_filters,
                                                   self.no_convs_per_block,
                                                   self.latent_dim,
                                                   self.initializers,
                                                   posterior=True,
                                                   norm=norm).to(device)
        else:
            self.posterior = AxisAlignedConvGaussian(self.input_channels,
                                                     self.num_filters,
                                                     self.no_convs_per_block,
                                                     self.latent_dim,
                                                     self.initializers,
                                                     posterior=True,
                                                     norm=norm).to(device)

        self.fcomb = Fcomb(self.num_filters,
                           self.latent_dim,
                           self.input_channels,
                           self.num_classes,
                           self.no_convs_fcomb, {
                               'w': 'orthogonal',
                               'b': 'normal'
                           },
                           use_tile=True,
                           norm=norm).to(device)

    def forward(self, patch, segm, training=True):
        """
        Construct the prior latent space for patch and run patch through the UNet;
        if training is True, also construct the posterior latent space
        """
        if training:
            if self.flow:
                self.log_det_j, self.z0, self.z, self.posterior_latent_space = self.posterior.forward(
                    patch, segm)
            else:
                _, self.posterior_latent_space = self.posterior.forward(
                    patch, segm)
                self.z = self.posterior_latent_space.rsample()
                self.z0 = self.z.clone()
        _, self.prior_latent_space = self.prior.forward(patch)
        self.unet_features = self.unet.forward(patch, False)

    def sample(self, testing=False):
        """
        Sample a segmentation by reconstructing from a prior sample
        and combining this with UNet features
        """
        if not testing:
            z_prior = self.prior_latent_space.rsample()
            self.z_prior_sample = z_prior
        else:
            # You can choose whether you want a sample or the mean here. For the GED it is important to take a sample.
            z_prior = self.prior_latent_space.sample()
            self.z_prior_sample = z_prior
        log_pz = self.prior_latent_space.log_prob(z_prior)
        log_qz = self.posterior_latent_space.log_prob(z_prior)
        return self.fcomb.forward(self.unet_features, z_prior), log_pz, log_qz

    def reconstruct(self,
                    use_posterior_mean=False,
                    calculate_posterior=False,
                    z_posterior=None):
        """
        Reconstruct a segmentation from a posterior sample (decoding a posterior sample) and UNet feature map
        use_posterior_mean: use the posterior mean instead of sampling z_q
        calculate_posterior: if True, sample z from the posterior latent space; otherwise use the provided sample
        """
        if use_posterior_mean:
            z_posterior = self.posterior_latent_space.loc
        else:
            if calculate_posterior:
                z_posterior = self.posterior_latent_space.rsample()
        return self.fcomb.forward(self.unet_features, z_posterior)

    def kl_divergence(self,
                      analytic=True,
                      calculate_posterior=False,
                      z_posterior=None):
        """
        Calculate the KL divergence between the posterior and prior KL(Q||P)
        analytic: calculate KL analytically or via sampling from the posterior
        calculate_posterior: if we use sampling to approximate the KL we can sample here or supply a sample
        """
        if analytic:
            # Need to add this to torch source code, see: https://github.com/pytorch/pytorch/issues/13545
            kl_div = kl.kl_divergence(self.posterior_latent_space,
                                      self.prior_latent_space).sum()

        else:
            log_posterior_prob = self.posterior_latent_space.log_prob(self.z)
            log_prior_prob = self.prior_latent_space.log_prob(self.z)
            kl_div = (log_posterior_prob - log_prior_prob).sum()
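        # With a normalizing flow z0 -> zK, the change-of-variables formula gives
        # log q(zK) = log q(z0) - sum_k log|det J_k|, so the sampled KL estimate
        # is corrected by subtracting the accumulated log-determinant.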
        if self.flow:
            kl_div = kl_div - self.log_det_j.sum()
        return kl_div

    def elbo(self,
             segm,
             mask=None,
             use_mask=True,
             analytic_kl=True,
             reconstruct_posterior_mean=False):
        """
        Calculate the evidence lower bound of the log-likelihood of P(Y|X)
        """
        batch_size = segm.shape[0]
        self.kl = (self.kl_divergence(analytic=analytic_kl,
                                      calculate_posterior=False))

        # Here we use the posterior sample drawn above
        self.reconstruction = self.reconstruct(
            use_posterior_mean=reconstruct_posterior_mean,
            calculate_posterior=False,
            z_posterior=self.z)
        if use_mask:
            self.reconstruction = self.reconstruction * mask
        criterion = nn.BCEWithLogitsLoss(reduction='none')
        reconstruction_loss = criterion(input=self.reconstruction, target=segm)
        self.reconstruction_loss = torch.sum(reconstruction_loss)
        self.mean_reconstruction_loss = torch.mean(reconstruction_loss)

        return (self.reconstruction,
                self.reconstruction_loss / batch_size,
                self.kl / batch_size,
                -(self.reconstruction_loss + self.beta * self.kl) / batch_size)
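
# A minimal, hypothetical construction sketch (not in the original source)
# showing the three posterior variants selected by the constructor flags.
# It assumes the helper modules used above (AxisAlignedConvGaussian,
# planarFlowDensity, glowDensity, Fcomb, Unet) and `device` are in scope.
gaussian_net = cFlowNet()                  # axis-aligned Gaussian posterior
planar_net = cFlowNet(flow=True)           # planar-flow posterior
glow_net = cFlowNet(flow=True, glow=True)  # Glow-style posterior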