class ERGB2Depth(BaseERGB2Depth):
    """Single-input UNet model that predicts a 1-channel map from an item's "image" entry."""

    def __init__(self, config):
        super(ERGB2Depth, self).__init__(config)
        # Hyper-parameters (num_bins_rgb, skip_type, ...) are attributes set by
        # BaseERGB2Depth from `config` -- presumably; confirm in the base class.
        self.unet = UNet(num_input_channels=self.num_bins_rgb,
                         num_output_channels=1,
                         skip_type=self.skip_type,
                         activation='sigmoid',
                         num_encoders=self.num_encoders,
                         base_num_channels=self.base_num_channels,
                         num_residual_blocks=self.num_residual_blocks,
                         norm=self.norm,
                         use_upsample_conv=self.use_upsample_conv)

    def forward(self, item, prev_super_states, prev_states_lstm):
        """
        Run the UNet on ``item["image"]``.

        :param item: mapping whose "image" entry is an input tensor (assumed
            N x num_bins_rgb x H x W -- TODO confirm); it is moved to ``self.gpu``.
            Other entries are ignored.
        :param prev_super_states: unused; accepted for interface compatibility.
        :param prev_states_lstm: returned unchanged (this model keeps no recurrent state).
        :return: tuple ``(predictions, {'image': None}, prev_states_lstm)`` where
            ``predictions["image"]`` is the UNet output (1 output channel with
            sigmoid activation, so presumably N x 1 x H x W in [0, 1]).
        """
        event_tensor = item["image"].to(self.gpu)
        prediction = self.unet.forward(event_tensor)
        predictions_dict = {"image": prediction}
        return predictions_dict, {'image': None}, prev_states_lstm
class E2VID(BaseE2VID):
    """Stateless event-to-image model built on a sigmoid-activated UNet."""

    def __init__(self, config):
        super().__init__(config)
        # Architecture settings are read from attributes the base class sets up.
        unet_settings = dict(
            num_input_channels=self.num_bins,
            num_output_channels=1,
            skip_type=self.skip_type,
            activation='sigmoid',
            num_encoders=self.num_encoders,
            base_num_channels=self.base_num_channels,
            num_residual_blocks=self.num_residual_blocks,
            norm=self.norm,
            use_upsample_conv=self.use_upsample_conv,
        )
        self.unet = UNet(**unet_settings)

    def forward(self, event_tensor, prev_states=None):
        """
        :param event_tensor: N x num_bins x H x W
        :param prev_states: ignored; this model carries no recurrent state.
        :return: (image, None) where image is the UNet output, of size
            N x 1 x H x W with values in [0,1].
        """
        image = self.unet.forward(event_tensor)
        return image, None
class ProbabilisticUnet(nn.Module):
    """
    A probabilistic UNet (https://arxiv.org/abs/1806.05034) implementation.

    input_channels: the number of channels in the image (1 for greyscale and 3 for RGB)
    num_classes: the number of classes to predict; 1 selects a binary
        (BCE-with-logits) reconstruction loss, >1 selects cross-entropy
    num_filters: list with the number of filters per level;
        defaults to [32, 64, 128, 192]
    latent_dim: dimension of the latent space
    no_convs_fcomb: number of convolutions in the Fcomb block that combines
        UNet features with a latent sample
    beta: weight of the KL term in ``elbo``

    The prior/posterior encoders use a fixed 2 convolutions per block.
    """

    def __init__(self, input_channels=1, num_classes=1, num_filters=None,
                 latent_dim=6, no_convs_fcomb=3, beta=1.0):
        super(ProbabilisticUnet, self).__init__()
        if num_filters is None:
            # Fresh list per instance; a list literal default would be shared
            # by every instance (and by the submodules it is passed to).
            num_filters = [32, 64, 128, 192]
        self.n_channels = input_channels
        self.n_classes = num_classes
        self.num_filters = num_filters
        self.latent_dim = latent_dim
        self.no_convs_per_block = 2
        self.no_convs_fcomb = no_convs_fcomb
        self.initializers = {'w': 'he_normal', 'b': 'normal'}
        self.beta = beta
        self.z_prior_sample = 0

        self.unet = UNet(n_channels=self.n_channels, n_classes=self.n_classes,
                         num_filters=self.num_filters,
                         apply_last_layer=False).to(device)
        self.prior = AxisAlignedConvGaussian(
            self.n_channels, self.num_filters, self.no_convs_per_block,
            self.latent_dim, self.initializers,
        ).to(device)
        self.posterior = AxisAlignedConvGaussian(
            self.n_channels, self.num_filters, self.no_convs_per_block,
            self.latent_dim, self.initializers, posterior=True,
        ).to(device)
        self.fcomb = Fcomb(self.num_filters, self.latent_dim, self.n_channels,
                           self.n_classes, self.no_convs_fcomb,
                           {'w': 'orthogonal', 'b': 'normal'},
                           use_tile=True).to(device)

        # Populated by forward(); read by sample/reconstruct/kl_divergence/elbo.
        self.posterior_latent_space = None
        self.prior_latent_space = None
        self.unet_features = None

    def forward(self, patch, segm, training=True):
        """
        Construct the prior latent space for ``patch`` and run ``patch`` through
        the UNet. When ``training`` is True, also construct the posterior latent
        space from ``(patch, segm)``. Results are stored on ``self``; returns None.
        """
        if training:
            self.posterior_latent_space = self.posterior.forward(patch, segm)
        self.prior_latent_space = self.prior.forward(patch)
        self.unet_features = self.unet.forward(patch)

    def sample(self, testing=False):
        """
        Sample a segmentation by drawing from the prior latent space and
        combining the draw with the stored UNet features.

        testing: False draws with ``rsample()`` (reparameterised, differentiable);
            True draws with ``sample()`` (no gradient).
        The latent draw is kept in ``self.z_prior_sample``.
        """
        if not testing:
            z_prior = self.prior_latent_space.rsample()
        else:
            # You can choose whether you want a sample or the mean here. For the
            # GED it is important to take a sample.
            # z_prior = self.prior_latent_space.base_dist.loc
            z_prior = self.prior_latent_space.sample()
        self.z_prior_sample = z_prior
        return self.fcomb.forward(self.unet_features, z_prior)

    def sample_at(self, z):
        """
        Decode a segmentation from a given latent vector ``z`` (moved to
        ``device`` and given a leading batch dimension).

        The prior density at that location can be obtained with
        ``torch.exp(self.prior_latent_space.log_prob(z))``.
        """
        return self.fcomb.forward(self.unet_features, z.to(device).unsqueeze(0))

    def reconstruct(self, use_posterior_mean=False, calculate_posterior=False,
                    z_posterior=None):
        """
        Reconstruct a segmentation by decoding a posterior latent vector together
        with the stored UNet features.

        use_posterior_mean: decode the posterior mean instead of a sample
        calculate_posterior: draw a fresh posterior sample; otherwise the given
            ``z_posterior`` is used
        """
        if use_posterior_mean:
            # `.mean` equals `loc` for a Normal and, unlike `.loc`, is also
            # defined for an Independent(Normal) latent distribution.
            z_posterior = self.posterior_latent_space.mean
        elif calculate_posterior:
            z_posterior = self.posterior_latent_space.rsample()
        return self.fcomb.forward(self.unet_features, z_posterior)

    def kl_divergence(self, analytic=True, calculate_posterior=False,
                      z_posterior=None):
        """
        Calculate the KL divergence KL(Q||P) between posterior Q and prior P.

        analytic: compute KL in closed form; otherwise estimate it from a single
            posterior sample as log q(z) - log p(z)
        calculate_posterior: when estimating by sampling, draw the sample here;
            otherwise the given ``z_posterior`` is used
        """
        if analytic:
            # Requires a registered KL for the latent distribution type
            # (see https://github.com/pytorch/pytorch/issues/13545).
            kl_div = kl.kl_divergence(self.posterior_latent_space,
                                      self.prior_latent_space)
        else:
            if calculate_posterior:
                z_posterior = self.posterior_latent_space.rsample()
            log_posterior_prob = self.posterior_latent_space.log_prob(z_posterior)
            log_prior_prob = self.prior_latent_space.log_prob(z_posterior)
            kl_div = log_posterior_prob - log_prior_prob
        return kl_div

    def elbo(self, segm, analytic_kl=True, reconstruct_posterior_mean=False):
        """
        Calculate the evidence lower bound of the log-likelihood of P(Y|X).

        Returns ``-(reconstruction_loss + beta * KL)``, where the reconstruction
        loss is summed over all elements and the KL is averaged. Also stores
        ``self.kl``, ``self.reconstruction`` and ``self.reconstruction_loss``.
        Requires ``forward(..., training=True)`` to have been called first.
        """
        if self.n_classes == 1:
            criterion = nn.BCEWithLogitsLoss(reduction='none')
        else:
            criterion = nn.CrossEntropyLoss(reduction='none')

        z_posterior = self.posterior_latent_space.rsample()
        self.kl = torch.mean(
            self.kl_divergence(analytic=analytic_kl, calculate_posterior=False,
                               z_posterior=z_posterior))

        # Here we use the posterior sample drawn above.
        self.reconstruction = self.reconstruct(
            use_posterior_mean=reconstruct_posterior_mean,
            calculate_posterior=False, z_posterior=z_posterior)
        self.reconstruction = self.reconstruction.to(device=device,
                                                     dtype=torch.float32)

        if self.n_classes == 1:
            # BCEWithLogitsLoss needs a float target with the same shape as
            # the logits, so keep (or restore) the channel dimension.
            segm = segm.to(device=device, dtype=torch.float32)
            if segm.dim() == self.reconstruction.dim() - 1:
                segm = segm.unsqueeze(1)
        else:
            # CrossEntropyLoss needs integer class indices without a channel dim.
            segm = segm.to(device=device, dtype=torch.long).squeeze(1)

        reconstruction_loss = criterion(input=self.reconstruction, target=segm)
        self.reconstruction_loss = torch.sum(reconstruction_loss)
        return -(self.reconstruction_loss + self.beta * self.kl)