Beispiel #1
0
class ERGB2Depth(BaseERGB2Depth):
    """Depth predictor that runs a single UNet on the ``"image"`` entry of a batch dict.

    Network hyper-parameters (``num_bins_rgb``, ``skip_type``, ``num_encoders``,
    ...) are read from attributes set by ``BaseERGB2Depth.__init__(config)``.
    """

    def __init__(self, config):
        super().__init__(config)

        unet_kwargs = dict(
            num_input_channels=self.num_bins_rgb,
            num_output_channels=1,
            skip_type=self.skip_type,
            activation='sigmoid',
            num_encoders=self.num_encoders,
            base_num_channels=self.base_num_channels,
            num_residual_blocks=self.num_residual_blocks,
            norm=self.norm,
            use_upsample_conv=self.use_upsample_conv,
        )
        self.unet = UNet(**unet_kwargs)

    def forward(self, item, prev_super_states, prev_states_lstm):
        """Predict depth from ``item["image"]``.

        :param item: mapping that must contain the key ``"image"``; its value is
            moved to ``self.gpu`` and fed to the UNet (presumably an
            N x num_bins_rgb x H x W tensor — TODO confirm).
        :param prev_super_states: unused; accepted for interface compatibility.
        :param prev_states_lstm: returned unchanged.
        :return: tuple ``(predictions, {'image': None}, prev_states_lstm)`` where
            ``predictions`` maps ``"image"`` to the UNet output. The UNet is
            built with ``activation='sigmoid'``, so values are presumably in
            [0, 1].
        """
        image_tensor = item["image"].to(self.gpu)
        prediction = self.unet.forward(image_tensor)
        predictions_dict = {"image": prediction}
        return predictions_dict, {'image': None}, prev_states_lstm
class E2VID(BaseE2VID):
    """Event-to-image network consisting of a single, non-recurrent UNet.

    Hyper-parameters (``num_bins``, ``skip_type``, ``num_encoders``, ...) are
    attributes populated by ``BaseE2VID.__init__(config)``.
    """

    def __init__(self, config):
        super().__init__(config)

        unet_kwargs = dict(
            num_input_channels=self.num_bins,
            num_output_channels=1,
            skip_type=self.skip_type,
            activation='sigmoid',
            num_encoders=self.num_encoders,
            base_num_channels=self.base_num_channels,
            num_residual_blocks=self.num_residual_blocks,
            norm=self.norm,
            use_upsample_conv=self.use_upsample_conv,
        )
        self.unet = UNet(**unet_kwargs)

    def forward(self, event_tensor, prev_states=None):
        """Run the UNet on an event tensor.

        :param event_tensor: N x num_bins x H x W input.
        :param prev_states: ignored; this model keeps no recurrent state.
        :return: ``(prediction, None)``; ``prediction`` is N x 1 x H x W, and
            with the sigmoid activation it is expected to lie in [0, 1].
        """
        prediction = self.unet.forward(event_tensor)
        return prediction, None
class ProbabilisticUnet(nn.Module):
    """
    A probabilistic UNet (https://arxiv.org/abs/1806.05034) implementation.

    input_channels: number of channels in the input image (1 for greyscale, 3 for RGB)
    num_classes: number of classes to predict (1 selects a binary BCE loss in elbo)
    num_filters: list with the number of filters per UNet / encoder level;
        defaults to [32, 64, 128, 192] (a fresh list per instance)
    latent_dim: dimension of the latent space
    no_convs_fcomb: number of convolutions passed to the Fcomb head
    beta: weight of the KL term in the ELBO

    The prior and posterior encoders always use 2 convolutions per block.
    """

    def __init__(self,
                 input_channels=1,
                 num_classes=1,
                 num_filters=None,
                 latent_dim=6,
                 no_convs_fcomb=3,
                 beta=1.0):
        super(ProbabilisticUnet, self).__init__()
        if num_filters is None:
            # A list literal as the default would be one object shared by every
            # instance (and by the submodules it is handed to below).
            num_filters = [32, 64, 128, 192]
        self.n_channels = input_channels
        self.n_classes = num_classes
        self.num_filters = num_filters
        self.latent_dim = latent_dim
        self.no_convs_per_block = 2
        self.no_convs_fcomb = no_convs_fcomb
        self.initializers = {'w': 'he_normal', 'b': 'normal'}
        self.beta = beta
        self.z_prior_sample = 0

        self.unet = UNet(n_channels=self.n_channels,
                         n_classes=self.n_classes,
                         num_filters=self.num_filters,
                         apply_last_layer=False).to(device)
        self.prior = AxisAlignedConvGaussian(
            self.n_channels,
            self.num_filters,
            self.no_convs_per_block,
            self.latent_dim,
            self.initializers,
        ).to(device)
        self.posterior = AxisAlignedConvGaussian(self.n_channels,
                                                 self.num_filters,
                                                 self.no_convs_per_block,
                                                 self.latent_dim,
                                                 self.initializers,
                                                 posterior=True).to(device)
        self.fcomb = Fcomb(self.num_filters,
                           self.latent_dim,
                           self.n_channels,
                           self.n_classes,
                           self.no_convs_fcomb, {
                               'w': 'orthogonal',
                               'b': 'normal'
                           },
                           use_tile=True).to(device)

        # Populated by forward(); consumed by sample/reconstruct/kl/elbo.
        self.posterior_latent_space = None
        self.prior_latent_space = None
        self.unet_features = None

    def forward(self, patch, segm, training=True):
        """
        Build the prior latent space for `patch` and run `patch` through the UNet.
        When `training` is True, also build the posterior latent space from
        (`patch`, `segm`). Nothing is returned; results are stored on self.
        """
        if training:
            self.posterior_latent_space = self.posterior.forward(patch, segm)
        self.prior_latent_space = self.prior.forward(patch)
        self.unet_features = self.unet.forward(patch)

    def sample(self, testing=False):
        """
        Sample a segmentation by decoding a prior sample together with the UNet
        features. The drawn latent is stored in `self.z_prior_sample`.

        testing: False draws a reparameterised sample (rsample, differentiable);
            True draws a plain sample.
        """
        if not testing:
            z_prior = self.prior_latent_space.rsample()
        else:
            # A sample (not the mean, `base_dist.loc`) is what matters for the
            # GED metric.
            z_prior = self.prior_latent_space.sample()
        self.z_prior_sample = z_prior
        return self.fcomb.forward(self.unet_features, z_prior)

    def sample_at(self, z):
        """
        Decode the segmentation for a given latent vector `z`; a batch
        dimension is prepended and `z` is moved to `device` first.
        The prior density at z can be obtained with
        torch.exp(self.prior_latent_space.log_prob(z)).
        """
        return self.fcomb.forward(self.unet_features,
                                  z.to(device).unsqueeze(0))

    def reconstruct(self,
                    use_posterior_mean=False,
                    calculate_posterior=False,
                    z_posterior=None):
        """
        Reconstruct a segmentation by decoding a posterior latent with the UNet features.

        use_posterior_mean: decode the posterior mean instead of a sample
        calculate_posterior: if True (and use_posterior_mean is False) draw a
            fresh reparameterised posterior sample; otherwise `z_posterior` is used
        z_posterior: latent to decode when neither flag is set (must not be None then)
        """
        if use_posterior_mean:
            # `.mean` works for a plain Normal and for an Independent-wrapped
            # Normal alike; `.loc` only exists on the former.
            z_posterior = self.posterior_latent_space.mean
        elif calculate_posterior:
            z_posterior = self.posterior_latent_space.rsample()
        return self.fcomb.forward(self.unet_features, z_posterior)

    def kl_divergence(self,
                      analytic=True,
                      calculate_posterior=False,
                      z_posterior=None):
        """
        Calculate the KL divergence between posterior and prior, KL(Q||P).

        analytic: compute KL in closed form, or estimate it from a single posterior sample
        calculate_posterior: for the sampled estimate, draw the sample here
            instead of using `z_posterior`
        z_posterior: posterior sample for the sampled estimate
        """
        if analytic:
            # Requires a registered KL for these distribution types; older
            # PyTorch needed a patch, see https://github.com/pytorch/pytorch/issues/13545
            kl_div = kl.kl_divergence(self.posterior_latent_space,
                                      self.prior_latent_space)
        else:
            if calculate_posterior:
                z_posterior = self.posterior_latent_space.rsample()
            log_posterior_prob = self.posterior_latent_space.log_prob(
                z_posterior)
            log_prior_prob = self.prior_latent_space.log_prob(z_posterior)
            kl_div = log_posterior_prob - log_prior_prob
        return kl_div

    def elbo(self, segm, analytic_kl=True, reconstruct_posterior_mean=False):
        """
        Calculate the evidence lower bound of the log-likelihood of P(Y|X).

        Requires a prior call to forward(..., training=True) so the posterior
        latent space exists. Stores `self.kl`, `self.reconstruction` and
        `self.reconstruction_loss`, and returns
        -(reconstruction_loss + beta * kl).
        """
        z_posterior = self.posterior_latent_space.rsample()

        self.kl = torch.mean(
            self.kl_divergence(analytic=analytic_kl,
                               calculate_posterior=False,
                               z_posterior=z_posterior))

        # Decode the same posterior sample used for the KL term.
        self.reconstruction = self.reconstruct(
            use_posterior_mean=reconstruct_posterior_mean,
            calculate_posterior=False,
            z_posterior=z_posterior)
        self.reconstruction = self.reconstruction.to(device=device,
                                                     dtype=torch.float32)

        if self.n_classes == 1:
            # BCE-with-logits needs float targets with the same shape as the
            # logits, so the target is neither cast to long nor squeezed.
            # NOTE(review): assumes segm already matches the reconstruction's
            # shape (e.g. N x 1 x H x W) — confirm against the data loader.
            criterion = nn.BCEWithLogitsLoss(reduction='none')
            target = segm.to(device=device, dtype=torch.float32)
        else:
            # CrossEntropy wants integer class indices without the channel dim.
            criterion = nn.CrossEntropyLoss(reduction='none')
            target = segm.to(device=device, dtype=torch.long).squeeze(1)

        reconstruction_loss = criterion(input=self.reconstruction, target=target)
        self.reconstruction_loss = torch.sum(reconstruction_loss)

        return -(self.reconstruction_loss + self.beta * self.kl)