Example #1
    def __init__(self, in_features, out_channels, num_repetitions, cardinality, dropout=0.0):
        """Create a gaussian layer.

        Args:
            in_features: Number of input features.
            out_channels: Number of parallel representations for each input feature.
            num_repetitions: Number of parallel repetitions of the layer.
            cardinality: Number of features covered by each gaussian.
            dropout: Dropout probability.

        """
        super().__init__(in_features, out_channels, num_repetitions, dropout)
        self.cardinality = cardinality

        # Number of different distributions: total number of features
        # divided by the number of features covered by each gaussian
        self._n_dists = np.ceil(in_features / cardinality).astype(int)
        self._out_features = self._n_dists

        # Remainder of features that do not fill a complete gaussian
        self._pad_value = in_features % cardinality
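        # e.g. in_features=4 with cardinality=2 -> _n_dists=2 and _pad_value=0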

        # Create gaussian means and stds
        self.means = nn.Parameter(torch.randn(out_channels, self._n_dists, cardinality, num_repetitions))
        self.stds = nn.Parameter(torch.rand(out_channels, self._n_dists, cardinality, num_repetitions))
        self.cov_factors = nn.Parameter(
            torch.zeros(out_channels, self._n_dists, cardinality, num_repetitions), requires_grad=False
        )
        self.gauss = dist.LowRankMultivariateNormal(loc=self.means, cov_factor=self.cov_factors, cov_diag=self.stds)
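
Example #1 builds its leaf distribution with torch.distributions.LowRankMultivariateNormal, which represents the covariance as cov_factor @ cov_factor.T + diag(cov_diag). A minimal standalone sketch of that API (the shapes below are illustrative, not the layer's own):

import torch
import torch.distributions as dist

d, rank = 3, 1                     # event size and factor rank (illustrative values)
loc = torch.zeros(d)
cov_factor = torch.randn(d, rank)  # low-rank factor W, shape (d, rank)
cov_diag = torch.ones(d)           # positive diagonal D, shape (d,)

mvn = dist.LowRankMultivariateNormal(loc=loc, cov_factor=cov_factor, cov_diag=cov_diag)

# The dense covariance the distribution represents: W @ W.T + diag(D)
dense_cov = cov_factor @ cov_factor.T + torch.diag(cov_diag)
assert torch.allclose(mvn.covariance_matrix, dense_cov)

samples = mvn.sample((4,))         # -> shape (4, 3)
log_probs = mvn.log_prob(samples)  # -> shape (4,)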
Example #2
    def forward(self, image, **kwargs):
        logits = F.relu(super().forward(image, **kwargs)[0])
        batch_size = logits.shape[0]
        event_shape = (self.num_classes, ) + logits.shape[2:]

        mean = self.mean_l(logits)
        cov_diag = self.log_cov_diag_l(logits).exp() + self.epsilon
        mean = mean.view((batch_size, -1))
        cov_diag = cov_diag.view((batch_size, -1))

        cov_factor = self.cov_factor_l(logits)
        cov_factor = cov_factor.view(
            (batch_size, self.rank, self.num_classes, -1))
        cov_factor = cov_factor.flatten(2, 3)
        cov_factor = cov_factor.transpose(1, 2)

        # covariance in the background tends to blow up to infinity, hence it is set to 0 outside the ROI
        mask = kwargs['sampling_mask']
        mask = mask.unsqueeze(1).expand((batch_size, self.num_classes) +
                                        mask.shape[1:]).reshape(
                                            batch_size, -1)
        cov_factor = cov_factor * mask.unsqueeze(-1)
        cov_diag = cov_diag * mask + self.epsilon

        if self.diagonal:
            base_distribution = td.Independent(
                td.Normal(loc=mean, scale=torch.sqrt(cov_diag)), 1)
        else:
            try:
                base_distribution = td.LowRankMultivariateNormal(
                    loc=mean, cov_factor=cov_factor, cov_diag=cov_diag)
            except (ValueError, RuntimeError):
                print(
                    'Covariance became non-invertible; using independent normals for this batch!'
                )
                base_distribution = td.Independent(
                    td.Normal(loc=mean, scale=torch.sqrt(cov_diag)), 1)

        distribution = ReshapedDistribution(base_distribution, event_shape)

        shape = (batch_size, ) + event_shape
        logit_mean = mean.view(shape)
        cov_diag_view = cov_diag.view(shape).detach()
        cov_factor_view = cov_factor.transpose(
            2, 1).reshape((batch_size, self.num_classes * self.rank) +
                          event_shape[1:]).detach()

        output_dict = {
            'logit_mean': logit_mean.detach(),
            'cov_diag': cov_diag_view,
            'cov_factor': cov_factor_view,
            'distribution': distribution
        }

        return logit_mean, output_dict
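
The try/except fallback above exists because constructing the low-rank Gaussian can fail once the mask zeroes out the covariance: with argument validation on, the constructor rejects a non-positive cov_diag, and without it the internal Cholesky factorization errors out. A toy reproduction of that failure mode (shapes and eps are illustrative, standing in for the network's tensors and self.epsilon):

import torch
import torch.distributions as td

mean = torch.zeros(4)
cov_factor = torch.zeros(4, 1)  # fully masked low-rank part
cov_diag = torch.zeros(4)       # degenerate diagonal -> singular covariance

try:
    td.LowRankMultivariateNormal(loc=mean, cov_factor=cov_factor, cov_diag=cov_diag)
except (ValueError, RuntimeError):
    # mirror the fallback above: independent normals with a jittered scale
    eps = 1e-5
    fallback = td.Independent(td.Normal(loc=mean, scale=torch.sqrt(cov_diag + eps)), 1)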
Example #3
    opt=opt,
    temperature_schedule=lambda t: max(0.01, np.exp(-1e-4 * t)),
    clip_grad=1e5,
    verbose=True,
    writer=writer)

# %%
mixture.load_state_dict(best_params)

# %%
colors = ["red", "green", "blue"]

with torch.no_grad():
    for i, component in enumerate(mixture.components):
        X_k = distrib.LowRankMultivariateNormal(
            component.loc, component.sqrt_cov_factor**2 + 0.1,
            component.sqrt_cov_diag**2 + 0.1).sample((500, ))

        plt.scatter(X_k[:, 0].numpy(), X_k[:, 1].numpy(), c=colors[i], s=5)

    plt.show()

# %%
x = np.linspace(-20, 20, 1000)
z = np.array(np.meshgrid(x, x)).transpose(1, 2, 0)
z = np.reshape(z, [z.shape[0] * z.shape[1], -1])

with torch.no_grad():
    densities = mixture.forward(torch.Tensor(z)).numpy()

mesh = z.reshape([1000, 1000, 2]).transpose(2, 0, 1)
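
# %%
# A plausible continuation: contour-plot the mixture density over the grid.
# (This plotting call is an assumption about intent, not part of the original snippet.)
plt.contourf(mesh[0], mesh[1], densities.reshape(1000, 1000))
plt.show()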