class ProbabilisticUnet(nn.Module):
    """
    A probabilistic UNet (https://arxiv.org/abs/1806.05034) implementation.

    input_channels: the number of channels in the image (1 for greyscale and 3 for RGB)
    num_classes: the number of classes to predict
    num_filters: list with the number of filters per layer
    latent_dim: dimension of the latent space
    no_convs_fcomb: number of convolutions passed to the Fcomb module
    reversible: forwarded unchanged to the Unet constructor

    The number of convs per block in the (convolutional) encoders of the
    prior and the posterior is fixed to 3 (``self.no_convs_per_block``).

    NOTE: ``latent_levels``, ``initializers``, ``image_size`` and ``beta`` are
    accepted for interface compatibility but are not read anywhere in this
    class; the initializers are hard-coded in ``__init__``.
    """

    def __init__(self,
                 input_channels=1,
                 num_classes=1,
                 num_filters=None,
                 latent_levels=1,
                 latent_dim=2,
                 initializers=None,
                 no_convs_fcomb=4,
                 image_size=(1, 128, 128),
                 beta=10.0,
                 reversible=False):
        super().__init__()
        self.input_channels = input_channels
        self.num_classes = num_classes
        self.num_filters = num_filters
        self.latent_dim = latent_dim
        self.no_convs_per_block = 3
        self.no_convs_fcomb = no_convs_fcomb
        self.initializers = {'w': 'he_normal', 'b': 'normal'}
        self.z_prior_sample = 0

        # The UNet is built without its last layer; `fcomb` and `last_conv`
        # consume its feature map instead.
        self.unet = Unet(self.input_channels,
                         self.num_classes,
                         self.num_filters,
                         initializers=self.initializers,
                         apply_last_layer=False,
                         padding=True,
                         reversible=reversible).to(device)
        self.prior = AxisAlignedConvGaussian(
            self.input_channels,
            self.num_filters,
            self.no_convs_per_block,
            self.latent_dim,
            initializers=self.initializers).to(device)
        self.posterior = AxisAlignedConvGaussian(
            self.input_channels,
            self.num_filters,
            self.no_convs_per_block,
            self.latent_dim,
            initializers=self.initializers,
            posterior=True).to(device)
        self.fcomb = Fcomb(self.num_filters,
                           self.latent_dim,
                           self.input_channels,
                           self.num_classes,
                           self.no_convs_fcomb,
                           initializers={
                               'w': 'orthogonal',
                               'b': 'normal'
                           },
                           use_tile=True).to(device)
        # NOTE(review): the 32 input channels are hard-coded; presumably this
        # has to match the channel count of the UNet feature map - confirm.
        self.last_conv = Conv2D(32,
                                num_classes,
                                kernel_size=1,
                                activation=torch.nn.Identity,
                                norm=torch.nn.Identity)

    def forward(self, patch, segm=None, training=True):
        """
        Construct the prior latent space for patch and run patch through the
        UNet. When a segmentation is supplied, the posterior latent space is
        constructed as well (e.g. during training or validation).

        The results are cached on the instance (``posterior_latent_space``,
        ``prior_latent_space``, ``unet_features``) for later use by
        ``sample``, ``reconstruct``, ``kl_divergence`` and ``elbo``.

        ``training`` is accepted but not read; the posterior is built
        whenever ``segm`` is not None.

        Returns ``last_conv`` applied to the UNet features.
        """
        if segm is not None:
            self.posterior_latent_space = self.posterior.forward(patch, segm)
        self.prior_latent_space = self.prior.forward(patch)
        self.unet_features = self.unet.forward(patch, False)
        # The return value was added for the summary writer.
        return self.last_conv(self.unet_features)

    def sample(self, testing=False):
        """
        Sample a segmentation by reconstructing from a prior sample and
        combining this with the UNet features cached by ``forward``.

        testing: if true, draw with ``sample()``; otherwise draw with
        ``rsample()``. The drawn latent is stored in ``self.z_prior_sample``.
        """
        if testing:
            # You can choose whether you mean a sample or the mean here.
            # For the GED it is important to take a sample.
            z_prior = self.prior_latent_space.sample()
        else:
            z_prior = self.prior_latent_space.rsample()
        self.z_prior_sample = z_prior
        return self.fcomb.forward(self.unet_features, z_prior)

    def reconstruct(self,
                    use_posterior_mean=False,
                    calculate_posterior=False,
                    z_posterior=None):
        """
        Reconstruct a segmentation from a posterior sample (decoding a
        posterior sample) and the UNet feature map.

        use_posterior_mean: use the posterior mean instead of sampling z_q
        calculate_posterior: draw a fresh posterior sample here instead of
            using the supplied one
        z_posterior: the sample to decode when neither flag is set
        """
        if use_posterior_mean:
            # `.mean` is what `kl_divergence` already reads on this object;
            # it equals `.loc` for a Normal and, unlike `.loc`, is also
            # available on an Independent-wrapped distribution.
            z_posterior = self.posterior_latent_space.mean
        elif calculate_posterior:
            z_posterior = self.posterior_latent_space.rsample()
        return self.fcomb.forward(self.unet_features, z_posterior)

    def accumulate_output(self, output_list, use_softmax=False):
        """Adapted to ProbUnet, which does not have an output list.

        Returns ``output_list`` unchanged, or its softmax over dim 1 when
        ``use_softmax`` is set.
        """
        s_accum = output_list
        if use_softmax:
            return torch.nn.functional.softmax(s_accum, dim=1)
        return s_accum

    def KL_two_gauss_with_diag_cov(self, mu0, sigma0, mu1, sigma1):
        """
        Closed-form KL(N(mu0, sigma0^2) || N(mu1, sigma1^2)) for Gaussians
        with diagonal covariance.

        All inputs are flattened from dim 1 onwards; the per-sample KL is
        summed over the flattened dimension and then averaged over dim 0.

            KL = 0.5 * sum((sigma0^2 + (mu1 - mu0)^2) / sigma1^2
                           + log(sigma1^2) - log(sigma0^2) - 1)

        sigma0 / sigma1 are standard deviations (they are squared here).
        A small epsilon (1e-10) guards the logarithms and the division.
        """
        eps = 1e-10
        sigma0_f = torch.flatten(sigma0, start_dim=1)
        sigma1_f = torch.flatten(sigma1, start_dim=1)
        var0 = torch.mul(sigma0_f, sigma0_f)
        # Fixed: the prior variance used to be computed as sigma1 * sigma0.
        var1 = torch.mul(sigma1_f, sigma1_f)

        log_var0 = torch.log(var0 + eps)
        log_var1 = torch.log(var1 + eps)

        mu0_f = torch.flatten(mu0, start_dim=1)
        mu1_f = torch.flatten(mu1, start_dim=1)
        delta = mu1_f - mu0_f

        per_dim = torch.div(var0 + torch.mul(delta, delta),
                            var1 + eps) + log_var1 - log_var0 - 1
        return torch.mean(0.5 * torch.sum(per_dim, dim=1))

    def kl_divergence(self,
                      analytic=True,
                      calculate_posterior=False,
                      z_posterior=None):
        """
        Calculate the KL divergence between the posterior and prior KL(Q||P).

        The divergence is always computed in closed form from the means and
        standard deviations of the cached latent spaces; ``analytic``,
        ``calculate_posterior`` and ``z_posterior`` are kept for interface
        compatibility and are not read.
        """
        mu0 = self.posterior_latent_space.mean
        sigma0 = self.posterior_latent_space.stddev
        mu1 = self.prior_latent_space.mean
        sigma1 = self.prior_latent_space.stddev
        return self.KL_two_gauss_with_diag_cov(mu0, sigma0, mu1, sigma1)

    def multinoulli_loss(self, reconstruction, target):
        """
        Cross-entropy between the reconstruction logits and the target.

        Both tensors are viewed as (batch, ..., -1) so that the unreduced
        cross-entropy is summed over dim 1 of the loss for every sample and
        then averaged over the batch. ``target`` is cast to long.
        """
        criterion = torch.nn.CrossEntropyLoss(reduction='none')
        batch_size = reconstruction.shape[0]
        recon_flat = reconstruction.view(batch_size, self.num_classes, -1)
        target_flat = target.view(batch_size, -1).long()
        per_element = criterion(target=target_flat, input=recon_flat)
        return torch.mean(torch.sum(per_element, dim=1))

    def elbo(self, segm, analytic_kl=False, reconstruct_posterior_mean=False):
        """
        Calculate the evidence lower bound of the log-likelihood of P(Y|X).

        Requires ``forward`` to have been called with a segmentation so that
        the posterior latent space is available. Stores
        ``kl_divergence_loss``, ``reconstruction``, ``reconstruction_loss``
        and ``mean_reconstruction_loss`` on the instance.

        Returns -(reconstruction_loss + kl_divergence_loss).
        """
        criterion = self.multinoulli_loss
        z_posterior = self.posterior_latent_space.rsample()

        self.kl_divergence_loss = torch.mean(
            self.kl_divergence(analytic=analytic_kl,
                               calculate_posterior=False,
                               z_posterior=z_posterior))

        # Here we use the posterior sample sampled above
        self.reconstruction = self.reconstruct(
            use_posterior_mean=reconstruct_posterior_mean,
            calculate_posterior=False,
            z_posterior=z_posterior)

        reconstruction_loss = criterion(reconstruction=self.reconstruction,
                                        target=segm)
        self.reconstruction_loss = torch.sum(reconstruction_loss)
        self.mean_reconstruction_loss = torch.mean(reconstruction_loss)

        return -(self.reconstruction_loss + 1.0 * self.kl_divergence_loss)

    def loss(self, mask):
        """
        Training loss: the negative ELBO for ``mask`` plus 1e-5 times the
        summed ``l2_regularisation`` of the posterior, the prior and the
        fcomb layers.
        """
        elbo = self.elbo(mask)
        reg_loss = (l2_regularisation(self.posterior) +
                    l2_regularisation(self.prior) +
                    l2_regularisation(self.fcomb.layers))
        return -elbo + 1e-5 * reg_loss
# Excerpt of a training loop. `train_loader`, `valid_loader`, `net`, `args`,
# `optimizer`, `criterion`, `trLoss`, `device` and `l2_regularisation` are
# defined outside this excerpt.
recEpoch = [0]
kl = torch.zeros(1)
recLoss = torch.zeros(1)
dGED = 0
t = time.time()
for step, (patch, masks) in enumerate(train_loader):
    patch = patch.to(device)
    masks = masks.to(device)
    if args.singleRater or args.unet:
        rater = 0
    else:
        # Choose a random mask (random index in 0..3)
        rater = torch.randperm(4)[0]
    # Indexing with a list keeps dim 1 on the selected mask.
    mask = masks[:, [rater]]
    if not args.unet:
        # Probabilistic model: forward() caches the latent spaces on `net`,
        # elbo() then returns a 4-tuple whose last entry is the ELBO.
        net.forward(patch, mask, training=True)
        _, _, _, elbo = net.elbo(mask, use_mask=False, analytic_kl=False)
        reg_loss = l2_regularisation(net.posterior) + l2_regularisation(
            net.prior)
        # Minimise the negative ELBO plus a small L2 penalty.
        loss = -elbo + 1e-5 * reg_loss
    else:
        # Plain UNet baseline: sigmoid on the network output, then `criterion`.
        pred = torch.sigmoid(net.forward(patch, False))
        loss = criterion(target=mask, input=pred)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    trLoss.append(loss.item())
    # Every 5th step, iterate over the validation loader without gradients.
    if (step + 1) % 5 == 0:
        with torch.no_grad():
            for idx, (patch, masks) in enumerate(valid_loader):
class cFlowNet(nn.Module):
    """
    Probabilistic segmentation network whose posterior can optionally be a
    flow-based density (planar or glow) instead of an axis-aligned Gaussian.

    input_channels: the number of channels in the image (1 for greyscale and 3 for RGB)
    num_classes: the number of classes to predict
    num_filters: list with the number of filters per layer
        (defaults to [32, 64, 128, 256])
    latent_dim: dimension of the latent space
    no_convs_fcomb: number of convolutions passed to the Fcomb module
    beta: weight of the KL term in the ELBO
    num_flows: number of flow steps handed to the flow posterior
    norm: forwarded unchanged to every sub-module
    flow: use a flow posterior instead of the axis-aligned Gaussian
    glow: with ``flow`` set, use ``glowDensity`` instead of
        ``planarFlowDensity``

    The number of convs per block in the (convolutional) encoders of the
    prior and the posterior is fixed to 3 (``self.no_convs_per_block``).
    """

    def __init__(self,
                 input_channels=1,
                 num_classes=1,
                 num_filters=None,
                 latent_dim=6,
                 no_convs_fcomb=4,
                 beta=1.0,
                 num_flows=4,
                 norm=False,
                 flow=False,
                 glow=False):
        super().__init__()
        # A None sentinel replaces the former mutable list default so that
        # instances never share (and cannot accidentally mutate) one list.
        if num_filters is None:
            num_filters = [32, 64, 128, 256]
        self.input_channels = input_channels
        self.num_classes = num_classes
        self.num_filters = num_filters
        self.latent_dim = latent_dim
        self.no_convs_per_block = 3
        self.no_convs_fcomb = no_convs_fcomb
        self.initializers = {'w': 'he_normal', 'b': 'normal'}
        self.beta = beta
        self.z_prior_sample = 0
        self.flow = flow
        self.flow_steps = num_flows

        self.unet = Unet(self.input_channels,
                         self.num_classes,
                         self.num_filters,
                         self.initializers,
                         apply_last_layer=False,
                         padding=True,
                         norm=norm).to(device)
        self.prior = AxisAlignedConvGaussian(self.input_channels,
                                             self.num_filters,
                                             self.no_convs_per_block,
                                             self.latent_dim,
                                             self.initializers,
                                             norm=norm).to(device)

        # Select the posterior density: glow flow, planar flow or Gaussian.
        if flow:
            flow_density = glowDensity if glow else planarFlowDensity
            self.posterior = flow_density(self.flow_steps,
                                          self.input_channels,
                                          self.num_filters,
                                          self.no_convs_per_block,
                                          self.latent_dim,
                                          self.initializers,
                                          posterior=True,
                                          norm=norm).to(device)
        else:
            self.posterior = AxisAlignedConvGaussian(self.input_channels,
                                                     self.num_filters,
                                                     self.no_convs_per_block,
                                                     self.latent_dim,
                                                     self.initializers,
                                                     posterior=True,
                                                     norm=norm).to(device)

        self.fcomb = Fcomb(self.num_filters,
                           self.latent_dim,
                           self.input_channels,
                           self.num_classes,
                           self.no_convs_fcomb, {
                               'w': 'orthogonal',
                               'b': 'normal'
                           },
                           use_tile=True,
                           norm=norm).to(device)

    def forward(self, patch, segm=None, training=True):
        """
        Construct the prior latent space for patch and run patch through the
        UNet; in case training is True also construct the posterior latent
        space from (patch, segm).

        Results are cached on the instance: ``prior_latent_space`` and
        ``unet_features`` always; ``posterior_latent_space``, ``z`` and
        ``z0`` (plus ``log_det_j`` for flow posteriors) when training.
        Nothing is returned.

        Raises ValueError if ``training`` is set and no ``segm`` is given.
        """
        if training:
            if segm is None:
                raise ValueError(
                    "segm is required when forward() is called with "
                    "training=True")
            if self.flow:
                (self.log_det_j, self.z0, self.z,
                 self.posterior_latent_space) = self.posterior.forward(
                     patch, segm)
            else:
                _, self.posterior_latent_space = self.posterior.forward(
                    patch, segm)
                self.z = self.posterior_latent_space.rsample()
                # Without a flow the "base" sample equals the final sample.
                self.z0 = self.z.clone()
        _, self.prior_latent_space = self.prior.forward(patch)
        self.unet_features = self.unet.forward(patch, False)

    def sample(self, testing=False):
        """
        Sample a segmentation by reconstructing from a prior sample and
        combining this with the UNet features cached by ``forward``.

        testing: if true, draw with ``sample()``; otherwise draw with
        ``rsample()``. The drawn latent is stored in ``self.z_prior_sample``.

        Returns (segmentation, log_pz, log_qz): the log-probabilities of the
        drawn latent under the prior and under the cached posterior.

        NOTE: ``posterior_latent_space`` is only refreshed by ``forward``
        with training=True, so ``log_qz`` refers to whatever posterior was
        cached last.
        """
        if testing:
            # You can choose whether you mean a sample or the mean here.
            # For the GED it is important to take a sample.
            z_prior = self.prior_latent_space.sample()
        else:
            z_prior = self.prior_latent_space.rsample()
        self.z_prior_sample = z_prior
        log_pz = self.prior_latent_space.log_prob(z_prior)
        log_qz = self.posterior_latent_space.log_prob(z_prior)
        return self.fcomb.forward(self.unet_features, z_prior), log_pz, log_qz

    def reconstruct(self,
                    use_posterior_mean=False,
                    calculate_posterior=False,
                    z_posterior=None):
        """
        Reconstruct a segmentation from a posterior sample (decoding a
        posterior sample) and the UNet feature map.

        use_posterior_mean: use the posterior mean instead of sampling z_q
        calculate_posterior: draw a fresh posterior sample here instead of
            using the supplied one
        z_posterior: the sample to decode when neither flag is set
        """
        if use_posterior_mean:
            # NOTE(review): relies on the posterior object exposing `.loc` -
            # confirm for every posterior type used.
            z_posterior = self.posterior_latent_space.loc
        elif calculate_posterior:
            z_posterior = self.posterior_latent_space.rsample()
        return self.fcomb.forward(self.unet_features, z_posterior)

    def kl_divergence(self,
                      analytic=True,
                      calculate_posterior=False,
                      z_posterior=None):
        """
        Calculate the KL divergence between the posterior and prior KL(Q||P),
        summed over all elements.

        analytic: calculate the KL analytically, or estimate it from the
            cached posterior sample ``self.z`` as log q(z) - log p(z)

        For flow posteriors the summed log-determinant ``self.log_det_j`` is
        subtracted. ``calculate_posterior`` and ``z_posterior`` are kept for
        interface compatibility and are not read.
        """
        if analytic:
            kl_div = kl.kl_divergence(self.posterior_latent_space,
                                      self.prior_latent_space).sum()
        else:
            log_posterior_prob = self.posterior_latent_space.log_prob(self.z)
            log_prior_prob = self.prior_latent_space.log_prob(self.z)
            kl_div = (log_posterior_prob - log_prior_prob).sum()
        if self.flow:
            kl_div = kl_div - self.log_det_j.sum()
        return kl_div

    def elbo(self,
             segm,
             mask=None,
             use_mask=True,
             analytic_kl=True,
             reconstruct_posterior_mean=False):
        """
        Calculate the evidence lower bound of the log-likelihood of P(Y|X).

        segm: target segmentation; its dim 0 is used as the batch size
        mask: multiplied onto the reconstruction when ``use_mask`` is set
        use_mask: whether to apply ``mask``
        analytic_kl: forwarded to ``kl_divergence``
        reconstruct_posterior_mean: decode the posterior mean instead of
            the cached posterior sample ``self.z``

        Stores ``kl``, ``reconstruction``, ``reconstruction_loss`` and
        ``mean_reconstruction_loss`` on the instance.

        Returns (reconstruction, reconstruction_loss / batch_size,
        kl / batch_size, -(reconstruction_loss + beta * kl) / batch_size).

        Raises ValueError if ``use_mask`` is set but no ``mask`` is given.
        """
        # Fail early with a clear message instead of a TypeError from
        # multiplying a tensor by None further down.
        if use_mask and mask is None:
            raise ValueError(
                "elbo() was called with use_mask=True but mask is None; "
                "pass a mask or set use_mask=False")

        batch_size = segm.shape[0]
        self.kl = self.kl_divergence(analytic=analytic_kl,
                                     calculate_posterior=False)

        # Here we use the posterior sample cached by forward()
        self.reconstruction = self.reconstruct(
            use_posterior_mean=reconstruct_posterior_mean,
            calculate_posterior=False,
            z_posterior=self.z)
        if use_mask:
            self.reconstruction = self.reconstruction * mask

        criterion = nn.BCEWithLogitsLoss(reduction='none')
        reconstruction_loss = criterion(input=self.reconstruction,
                                        target=segm)
        self.reconstruction_loss = torch.sum(reconstruction_loss)
        self.mean_reconstruction_loss = torch.mean(reconstruction_loss)

        return (self.reconstruction,
                self.reconstruction_loss / batch_size,
                self.kl / batch_size,
                -(self.reconstruction_loss + self.beta * self.kl) / batch_size)