class RepNormal(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mu = Parameter(FloatTensor([0.0]))
        self.log_variance = Parameter(FloatTensor([0.0]))

    def __call__(self):
        # reparameterization: sample standard normal noise, then shift and scale
        z = torch.randn(1)
        return self.mu + (0.5 * self.log_variance).exp() * z

    def _repr_pretty_(self, p, cycle):
        p.text("mu = {}".format(self.mu))
        p.text("std = {}".format((0.5 * self.log_variance).exp()))
Example 2
class RelaxedBernoulli(Distribution):

    def __init__(self, probs=[0.5], temperature=1.0, learnable=True):
        super().__init__()
        self.n_dims = len(probs)
        self.temperature = torch.tensor(temperature)
        if not isinstance(probs, torch.Tensor):
            probs = torch.tensor(probs)
        self.logits = utils.log(probs.float())
        if learnable:
            self.logits = Parameter(self.logits)

    def log_prob(self, value):
        model = dists.RelaxedBernoulli(self.temperature, self.probs)
        return model.log_prob(value).sum(-1)

    def sample(self, batch_size):
        model = dists.RelaxedBernoulli(self.temperature, self.probs)
        return model.sample((batch_size,))

    @property
    def probs(self):
        return self.logits.exp()

    def get_parameters(self):
        if self.n_dims == 1: return {'probs':self.probs.item()}
        return {'probs':self.probs.detach().numpy()}
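The wrapper above defers to torch.distributions.RelaxedBernoulli. A self-contained sketch of the same idea (learnable log-probabilities plus temperature-relaxed, differentiable samples), using torch.distributions directly rather than the wrapper's Distribution base class:

import torch
from torch import distributions as dists

logits = torch.nn.Parameter(torch.tensor([0.3, 0.7]).log())   # log-probabilities, as above
temperature = torch.tensor(0.5)

model = dists.RelaxedBernoulli(temperature, probs=logits.exp())
x = model.rsample((16,))                        # reparameterized samples in (0, 1)
model.log_prob(x).sum(-1).mean().backward()     # gradients reach `logits`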
Example 3
class Binomial(Distribution):
    def __init__(self, total_count=10, probs=[0.5], learnable=True):
        super().__init__()
        if not isinstance(total_count, torch.Tensor):
            total_count = torch.tensor(total_count)
        if not isinstance(probs, torch.Tensor):
            probs = torch.tensor(probs).view(-1)
        self.n_dims = len(probs)
        self.total_count = total_count.float()
        self.logits = log(probs.float())
        if learnable:
            self.total_count = Parameter(self.total_count)
            self.logits = Parameter(self.logits)

    def log_prob(self, value):
        return BN(self.total_count, probs=self.probs).log_prob(value).sum(-1)

    def sample(self, batch_size):
        return BN(self.total_count, probs=self.probs).sample((batch_size, ))

    def entropy(self):
        return 0.5 * (2 * pi * e * self.total_count * self.probs *
                      (1 - self.probs)).log()

    @property
    def expectation(self):
        return self.total_count * self.probs

    @property
    def mode(self):
        return ((self.total_count + 1) * self.probs).floor()

    @property
    def median(self):
        return (self.total_count * self.probs).floor()

    @property
    def variance(self):
        return self.total_count * self.probs * (1 - self.probs)

    @property
    def skewness(self):
        return (1 - 2 * self.probs) / (self.total_count * self.probs *
                                       (1 - self.probs)).sqrt()

    @property
    def kurtosis(self):
        pq = self.probs * (1 - self.probs)
        return (1 - 6 * pq) / (self.total_count * pq)

    @property
    def probs(self):
        return self.logits.exp()

    def get_parameters(self):
        return {
            'total_count': self.total_count.detach().numpy(),
            'probs': self.probs.detach().numpy()
        }
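The entropy above is the Gaussian approximation 0.5 * log(2*pi*e*n*p*(1-p)) rather than the exact binomial entropy. A self-contained check sketch (illustration only) against the entropy computed directly from the pmf:

import math
import torch

n, p = torch.tensor(50.), torch.tensor(0.3)
approx = 0.5 * (2 * math.pi * math.e * n * p * (1 - p)).log()

k = torch.arange(0, int(n) + 1).float()
log_pmf = (torch.lgamma(n + 1) - torch.lgamma(k + 1) - torch.lgamma(n - k + 1)
           + k * p.log() + (n - k) * (1 - p).log())
exact = -(log_pmf.exp() * log_pmf).sum()
# approx and exact agree closely for moderately large n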
Example 4
class LinReg(torch.nn.Module):
    def __init__(self, p):
        super(LinReg, self).__init__()
        self.b = Param(torch.randn(p, requires_grad=True))
        self.log_sig = Param(torch.randn(1, requires_grad=True))

    def forward(self, y, X):
        return Normal(X.matmul(self.b), self.log_sig.exp()).log_prob(y).mean()
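A minimal training sketch, added for illustration and assuming Param is torch.nn.Parameter and Normal is torch.distributions.Normal: forward returns the mean log-likelihood, so minimizing its negative fits b and log_sig.

import torch

X = torch.randn(200, 3)
true_b = torch.tensor([1.0, -2.0, 0.5])
y = X @ true_b + 0.1 * torch.randn(200)

model = LinReg(p=3)
opt = torch.optim.Adam(model.parameters(), lr=5e-2)
for _ in range(1000):
    nll = -model(y, X)          # negative mean log-likelihood
    opt.zero_grad()
    nll.backward()
    opt.step()
# model.b approaches true_b; model.log_sig.exp() approaches the noise scale 0.1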
Example 5
class GaussianLayer1(Module):
    def __init__(self, in_features, in_dim, out_features, out_dim,
                 num_components, sigma_gamma):
        super(GaussianLayer1, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_components = num_components
        self.sigma_gamma = sigma_gamma
        self.in_dim = in_dim
        self.out_dim = out_dim

        self.mus = Parameter(
            torch.rand(num_components, self.in_dim + self.out_dim))
        # self.log_vars = Parameter(-5. - torch.rand(num_components, 2) / (2 * torch.sqrt(torch.Tensor([float(num_components)]))) * sigma_gamma)
        log_var = (
            1 /
            torch.sqrt(torch.Tensor([float(num_components)]))).pow(2).log()
        self.log_vars = Parameter(log_var -
                                  torch.rand(num_components, self.in_dim +
                                             self.out_dim))
        self.weights = Parameter((torch.rand(num_components) - 0.5))
        self.bias = Parameter((torch.rand(out_features)) - 0.5)
        self.x_in_idx = torch.linspace(0, 1, self.in_features)
        self.x_out_idx = torch.linspace(0, 1, self.out_features)

    def forward(self, x):
        self.x_in_idx = self.x_in_idx.to(self.mus.device)
        self.x_out_idx = self.x_out_idx.to(self.mus.device)

        vars = self.log_vars.exp()
        sigmas = vars.sqrt()

        t0 = time.time()

        mus_x0 = self.mus[:, 0]
        deltas_x0 = (self.x_in_idx.view(-1, 1) - mus_x0).pow(2)
        log_prob_x0 = -deltas_x0 / (2 * vars[:, 0])

        mus_x1 = self.mus[:, 1]
        deltas_x1 = (self.x_out_idx.view(-1, 1) - mus_x1).pow(2)
        log_prob_x1 = -deltas_x1 / (2 * vars[:, 1])

        log_prob_x1 = log_prob_x1.unsqueeze_(1)

        log_probs = log_prob_x1 + log_prob_x0

        probs = (log_probs.exp() * self.weights).sum(dim=-1)

        x1 = torch.mm(x, probs.transpose(1, 0)) + self.bias

        t1 = time.time()
        # print(t1-t0)
        return x1
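GaussianLayer1 synthesizes its (out_features x in_features) weight matrix from num_components Gaussian bumps placed on the normalized (input index, output index) plane. A quick shape-check sketch, assuming the snippet's own imports (Module, Parameter, time) resolve to the usual names:

import torch

layer = GaussianLayer1(in_features=32, in_dim=1, out_features=8, out_dim=1,
                       num_components=4, sigma_gamma=1.0)
x = torch.randn(16, 32)
out = layer(x)
assert out.shape == (16, 8)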
Example 6
class NegativeBinomial(Distribution):
    def __init__(self, total_count=10, probs=[0.5], learnable=True):
        super().__init__()
        if not isinstance(total_count, torch.Tensor):
            total_count = torch.tensor(total_count)
        if not isinstance(probs, torch.Tensor):
            probs = torch.tensor(probs).view(-1)
        self.n_dims = len(probs)
        self.total_count = total_count.float()
        self.logits = log(probs.float())
        if learnable:
            self.total_count = Parameter(self.total_count)
            self.logits = Parameter(self.logits)

    def log_prob(self, value):
        return NB(self.total_count, probs=self.probs).log_prob(value).sum(-1)

    def sample(self, batch_size):
        return NB(self.total_count, probs=self.probs).sample((batch_size, ))

    @property
    def expectation(self):
        return (self.probs * self.total_count) / (1 - self.probs)

    @property
    def mode(self):
        if self.total_count > 1:
            return (self.probs * (self.total_count - 1) /
                    (1 - self.probs)).floor()
        return torch.tensor(0.).float()

    @property
    def variance(self):
        return self.probs * self.total_count / (1 - self.probs).pow(2)

    @property
    def skewness(self):
        return (1 + self.probs) / (self.probs * self.total_count).sqrt()

    @property
    def kurtosis(self):
        return 6. / self.total_count + (1 - self.probs).pow(2) / (
            self.probs * self.total_count)

    @property
    def probs(self):
        return self.logits.exp()

    def get_parameters(self):
        return {
            'total_count': self.total_count.detach().numpy(),
            'probs': self.probs.detach().numpy()
        }
Example 7
class LogNormal(Distribution):
    def __init__(self, loc=0., scale=1., learnable=True):
        super().__init__()
        if not isinstance(loc, torch.Tensor):
            loc = torch.tensor(loc).view(-1)
        self.n_dims = len(loc)
        if not isinstance(scale, torch.Tensor):
            scale = torch.tensor(scale).view(-1)
        self.loc = loc.float()
        self._scale = utils.softplus_inverse(scale.float())
        if learnable:
            self.loc = Parameter(self.loc)
            self._scale = Parameter(self._scale)

    def log_prob(self, value):
        return dists.LogNormal(self.loc, self.scale).log_prob(value).sum(-1)

    def sample(self, batch_size):
        return dists.LogNormal(self.loc, self.scale).rsample((batch_size, ))

    @property
    def expectation(self):
        return (self.loc + self.scale.pow(2) / 2).exp()

    @property
    def variance(self):
        s_square = self.scale.pow(2)
        return (s_square.exp() - 1) * (2 * self.loc + s_square).exp()

    @property
    def median(self):
        return self.loc.exp()

    @property
    def mode(self):
        return (self.loc - self.scale.pow(2)).exp()

    def entropy(self):
        return dists.LogNormal(self.loc, self.scale).entropy()

    @property
    def scale(self):
        return softplus(self._scale)

    def get_parameters(self):
        if self.n_dims == 1:
            return {'loc': self.loc.item(), 'scale': self.scale.item()}
        return {
            'loc': self.loc.detach().numpy(),
            'scale': self.scale.detach().numpy()
        }
Example 8
class GaussianLayer2(Module):
    def __init__(self, in_features, out_features, num_components, sigma_gamma):
        super(GaussianLayer2, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_components = num_components
        self.sigma_gamma = sigma_gamma

        self.mus = Parameter(torch.rand(num_components, 2))
        # self.log_vars = Parameter(-5. - torch.rand(num_components, 2) / (2 * torch.sqrt(torch.Tensor([float(num_components)]))) * sigma_gamma)
        log_var = (
            1 /
            torch.sqrt(torch.Tensor([float(num_components)]))).pow(2).log()
        self.log_vars = Parameter(log_var - torch.rand(num_components, 2))
        self.weights = Parameter(((torch.rand(num_components)) - 0.5))

        self.x_in_idx = torch.linspace(0, 1, self.in_features)
        self.x_out_idx = torch.linspace(0, 1, self.out_features)

    def forward(self, x):
        vars = self.log_vars.exp()
        sigmas = vars.sqrt()

        x_out = torch.zeros(x.shape[0], self.out_features)

        # compute Cs
        C = torch.zeros(x.shape[0], self.num_components)
        for m in range(self.num_components):
            mu_x0 = self.mus[m, 0]
            deltas = (self.x_in_idx - mu_x0).pow(2)
            exp_terms = -deltas / (2 * vars[m, 0])
            # smth = torch.mm(x, exp_terms.exp().reshape(-1,1))
            C[:, m] = torch.mm(
                x,
                exp_terms.exp().reshape(-1, 1)).squeeze() / torch.sqrt(
                    2. * math.pi * vars[m, 0]) * self.weights[m]

        for j in range(self.out_features):
            # calc prob of x1_i along x1 axis by summing up probs of x1_i given each of m Gaussian components
            dist = Normal(self.mus[:, 1], sigmas[:, 1])  # X1 marginal
            log_probs = dist.log_prob(self.x_out_idx[j])
            probs = log_probs.exp()
            weighted_probs = probs * C

            result_prob = weighted_probs.sum(dim=-1)
            x_out[:, j] = result_prob / (self.in_features * self.weights.sum())

        return x_out
Example 9
class LogCauchy(Distribution):
    def __init__(self, loc=0., scale=1., learnable=True):
        super().__init__()
        if not isinstance(loc, torch.Tensor):
            loc = torch.tensor(loc).view(-1)
        self.n_dims = len(loc)
        if not isinstance(scale, torch.Tensor):
            scale = torch.tensor(scale).view(-1)
        self.loc = loc.float()
        self._scale = utils.softplus_inverse(scale.float())
        if learnable:
            self.loc = Parameter(self.loc)
            self._scale = Parameter(self._scale)

    def log_prob(self, value):
        model = TransformDistribution(
            Cauchy(self.loc, self.scale, learnable=False), [Exp()])
        return model.log_prob(value)

    def sample(self, batch_size):
        model = TransformDistribution(
            Cauchy(self.loc, self.scale, learnable=False), [Exp()])
        return model.sample(batch_size)

    def cdf(self, value):
        std_term = (value.log() - self.loc) / self.scale
        return (1. / math.pi) * torch.atan(std_term) + 0.5

    def icdf(self, value):
        cauchy_icdf = torch.tan(math.pi *
                                (value - 0.5)) * self.scale + self.loc
        return cauchy_icdf.exp()

    @property
    def scale(self):
        return softplus(self._scale)

    @property
    def median(self):
        return self.loc.exp()

    def get_parameters(self):
        if self.n_dims == 1:
            return {'loc': self.loc.item(), 'scale': self.scale.item()}
        return {
            'loc': self.loc.detach().numpy(),
            'scale': self.scale.detach().numpy()
        }
Example 10
class MaskedConv2d(nn.Module):
    """
    Conv2d with mask and weight normalization.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 mask_type='A',
                 order='A',
                 masked_channels=None,
                 stride=1,
                 dilation=1,
                 groups=1):
        super(MaskedConv2d, self).__init__()
        assert mask_type in {'A', 'B'}
        assert order in {'A', 'B'}
        self.mask_type = mask_type
        self.order = order
        kernel_size = _pair(kernel_size)
        for k in kernel_size:
            assert k % 2 == 1, 'kernel size cannot include an even number: {}'.format(
                kernel_size)
        padding = (kernel_size[0] // 2, kernel_size[1] // 2)
        stride = _pair(stride)
        dilation = _pair(dilation)

        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        # masked all input channels by default
        masked_channels = in_channels if masked_channels is None else masked_channels
        self.masked_channels = masked_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups

        self.weight_v = Parameter(
            torch.Tensor(out_channels, in_channels // groups, *kernel_size))
        self.weight_g = Parameter(torch.Tensor(out_channels, 1, 1, 1))
        self.bias = Parameter(torch.Tensor(out_channels))

        self.register_buffer('mask', torch.ones(self.weight_v.size()))
        _, _, kH, kW = self.weight_v.size()
        mask = np.ones([*self.mask.size()], dtype=np.float32)
        mask[:, :masked_channels, kH // 2, kW // 2 + (mask_type == 'B'):] = 0
        mask[:, :masked_channels, kH // 2 + 1:] = 0

        # reverse order
        if order == 'B':
            reverse_mask = mask[:, :, ::-1, :]
            reverse_mask = reverse_mask[:, :, :, ::-1]
            mask = reverse_mask.copy()
        self.mask.copy_(torch.from_numpy(mask).float())
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.weight_v, mean=0.0, std=0.05)
        self.weight_v.data.mul_(self.mask)
        _norm = norm(self.weight_v, 0).data + 1e-8
        self.weight_g.data.copy_(_norm.log())
        nn.init.constant_(self.bias, 0)

    def init(self, x, init_scale=1.0):
        with torch.no_grad():
            # [batch, n_channels, H, W]
            out = self(x)
            n_channels = out.size(1)
            out = out.transpose(0, 1).contiguous().view(n_channels, -1)
            # [n_channels]
            mean = out.mean(dim=1)
            std = out.std(dim=1)
            inv_stdv = init_scale / (std + 1e-6)
            self.weight_g.add_(inv_stdv.log().view(n_channels, 1, 1, 1))
            self.bias.add_(-mean).mul_(inv_stdv)
            return self(x)

    def forward(self, input):
        self.weight_v.data.mul_(self.mask)
        _norm = norm(self.weight_v, 0) + 1e-8
        weight = self.weight_v * (self.weight_g.exp() / _norm)
        return F.conv2d(input, weight, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)

    def extra_repr(self):
        s = (
            '{in_channels}({masked_channels}), {out_channels}, kernel_size={kernel_size}'
            ', stride={stride}')
        if self.padding != (0, ) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1, ) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        s += ', type={mask_type}, order={order}'
        return s.format(**self.__dict__)
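A self-contained sketch of the mask construction above (illustration only): for a single channel and a 3x3 kernel, a type-'A' mask removes the center pixel and everything after it in raster order, while type 'B' keeps the center.

import numpy as np

def build_mask(kH, kW, mask_type):
    # mirrors the per-channel masking logic in MaskedConv2d
    mask = np.ones((kH, kW), dtype=np.float32)
    mask[kH // 2, kW // 2 + (mask_type == 'B'):] = 0
    mask[kH // 2 + 1:] = 0
    return mask

print(build_mask(3, 3, 'A'))
# [[1. 1. 1.]
#  [1. 0. 0.]
#  [0. 0. 0.]]
print(build_mask(3, 3, 'B'))
# [[1. 1. 1.]
#  [1. 1. 0.]
#  [0. 0. 0.]]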
Example 11
class GaussianLayerBase(Module):
    def __init__(self, in_shape, out_shape, num_components, sigma_gamma):
        super(GaussianLayerBase, self).__init__()
        self.in_shape = in_shape
        self.out_shape = out_shape
        self.num_components = num_components
        self.sigma_gamma = sigma_gamma

        self.in_dim = len(self.in_shape)
        self.out_dim = len(self.out_shape)

        self.num_in_units = np.prod(self.in_shape)
        self.num_out_units = np.prod(self.out_shape)

        self.mus = Parameter(
            torch.rand(num_components, self.in_dim + self.out_dim))
        # self.log_vars = Parameter(-5. - torch.rand(num_components, 2) / (2 * torch.sqrt(torch.Tensor([float(num_components)]))) * sigma_gamma)
        log_var = (
            1 /
            torch.sqrt(torch.Tensor([float(num_components)]))).pow(2).log()
        self.log_vars = Parameter(log_var -
                                  torch.rand(num_components, self.in_dim +
                                             self.out_dim))
        self.weights = Parameter((torch.rand(num_components) - 0.5))

        self.bias = Parameter((torch.rand(self.num_out_units)) - 0.5)

        self.x_in_idx = [
            torch.linspace(0, 1, in_shape_i) for in_shape_i in self.in_shape
        ]
        self.x_out_idx = [
            torch.linspace(0, 1, out_shape_i) for out_shape_i in self.out_shape
        ]

    def forward(self, x):
        batch_size = x.shape[0]
        vars = self.log_vars.exp()
        sigmas = vars.sqrt()

        t0 = time.time()

        log_probs = torch.zeros((self.num_components))

        for i in range(self.in_dim):
            mus_x_i = self.mus[:, i]
            deltas_x_i = (self.x_in_idx[i].view(-1, 1) - mus_x_i).pow(2)
            log_prob_x_i = -deltas_x_i / (2 * vars[:, i])
            # log_prob_x_i = log_prob_x_i.transpose()
            log_probs = log_probs.unsqueeze_(-2)
            log_probs = log_probs + log_prob_x_i
        # log_prob_x0 = log_prob_x_i

        for j in range(self.out_dim):
            mus_x_j = self.mus[:, self.in_dim + j]
            deltas_x_j = (self.x_out_idx[j].view(-1, 1) - mus_x_j).pow(2)
            log_prob_x_j = -deltas_x_j / (2 * vars[:, self.in_dim + j])
            # log_prob_x_j = log_prob_x_j.transpose(0,1)
            log_probs = log_probs.unsqueeze_(-2)
            log_probs = log_probs + log_prob_x_j
        # log_prob_x1 = log_prob_x_j.unsqueeze_(1)

        # log_probs = log_prob_x1 + log_prob_x0

        probs = (log_probs.exp() * self.weights).sum(dim=-1)

        probs = probs.view(self.num_in_units, self.num_out_units)
        x = x.view(batch_size, self.num_in_units)

        x_out = torch.mm(x, probs) + self.bias
        x_out = x_out.view(batch_size, *self.out_shape)

        t1 = time.time()
        # print(t1-t0)
        return x_out

    def to(self, *args, **kwargs):
        module = super().to(*args, **kwargs)
        # the index grids are plain tensors, not parameters or buffers,
        # so move them to the parameters' device manually
        device = self.mus.device
        if device != self.x_in_idx[0].device:
            self.x_in_idx = [idx.to(device) for idx in self.x_in_idx]
            self.x_out_idx = [idx.to(device) for idx in self.x_out_idx]
        return module
Example 12
class BayesLinearMF(Module):
    r"""
    Applies Bayesian Linear

    Arguments:

    .. note:: other arguments are following linear of pytorch 1.2.0.
    https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py

    """
    __constants__ = ['bias', 'in_features', 'out_features']

    def __init__(self,
                 single_eps,
                 local_reparam,
                 in_features,
                 out_features,
                 bias=True,
                 deterministic=False):
        super(BayesLinearMF, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.single_eps = single_eps
        self.local_reparam = local_reparam
        self.deterministic = deterministic

        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_log_sigma = Parameter(
            torch.Tensor(out_features, in_features))

        if bias is None or bias is False:
            self.bias = False
        else:
            self.bias = True

        if self.bias:
            self.bias_mu = Parameter(torch.Tensor(out_features))
            self.bias_log_sigma = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias_mu', None)
            self.register_parameter('bias_log_sigma', None)

        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight_mu.size(1))
        self.weight_mu.data.uniform_(-stdv, stdv)
        self.weight_log_sigma.data.fill_(-5)
        if self.bias:
            self.bias_mu.data.uniform_(-stdv, stdv)
            self.bias_log_sigma.data.fill_(-5)

    def forward(self, input):
        r"""
        Overridden.
        """
        if self.single_eps or self.deterministic:
            if self.deterministic:
                weight = self.weight_mu
                bias = self.bias_mu if self.bias else None
            else:
                weight = self.weight_mu + torch.exp(
                    self.weight_log_sigma) * torch.randn(self.out_features,
                                                         self.in_features,
                                                         device=input.device,
                                                         dtype=input.dtype)
                bias = self.bias_mu + torch.exp(
                    self.bias_log_sigma) * torch.randn(
                        self.out_features,
                        device=input.device,
                        dtype=input.dtype) if self.bias else None
            out = F.linear(input, weight, bias)
        else:
            if self.local_reparam:
                act_mu = F.linear(input, self.weight_mu,
                                  self.bias_mu if self.bias else None)
                act_var = F.linear(
                    input**2,
                    self.weight_log_sigma.exp()**2,
                    self.bias_log_sigma.exp()**2 if self.bias else None)
                act_std = torch.sqrt(act_var + 1e-16)
                out = act_mu + act_std * torch.randn_like(act_mu)
            else:
                weight = self.weight_mu + torch.exp(
                    self.weight_log_sigma) * torch.randn(input.shape[0],
                                                         self.out_features,
                                                         self.in_features,
                                                         device=input.device,
                                                         dtype=input.dtype)
                out = torch.bmm(weight, input.unsqueeze(2)).squeeze(-1)
                if self.bias:
                    bias = self.bias_mu + torch.exp(
                        self.bias_log_sigma) * torch.randn(input.shape[0],
                                                           self.out_features,
                                                           device=input.device,
                                                           dtype=input.dtype)
                    out = out + bias
        return out

    def extra_repr(self):
        r"""
        Overridden.
        """
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias)
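A usage sketch, added for illustration and assuming the snippet's imports (Module, Parameter, math, and F as torch.nn.functional) are in scope: with local_reparam=True the layer samples activation noise per example, so repeated stochastic passes differ, while deterministic=True falls back to the posterior means.

import torch

layer = BayesLinearMF(single_eps=False, local_reparam=True,
                      in_features=8, out_features=4, bias=True)
x = torch.randn(32, 8)
out1, out2 = layer(x), layer(x)
assert out1.shape == (32, 4) and not torch.allclose(out1, out2)   # stochastic

layer.deterministic = True
assert torch.allclose(layer(x), layer(x))                         # mean weights only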
Example 13
class MaskedLinear(nn.Module):
    """
    masked linear module with weight normalization
    """
    def __init__(self, in_features, out_features, mask_type, total_units, max_units=None, bias=True):
        """
        Args:
            in_features: number of units in the inputs
            out_features: number of units in the outputs.
            max_units: the list containing the maximum units each input unit depends on.
            mask_type: type of the masked linear.
            total_units: the total number of units to assign.
            bias: using bias vector.
        """
        super(MaskedLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.weight_v = Parameter(torch.Tensor(out_features, in_features))
        self.weight_g = Parameter(torch.Tensor(out_features, 1))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)

        layer_type, order = mask_type
        self.layer_type = layer_type
        self.order = order
        assert layer_type in {'input-hidden', 'hidden-hidden', 'hidden-output', 'input-output'}
        assert order in {'A', 'B'}
        self.register_buffer('mask', self.weight_v.data.clone())

        # override the max_units for input layer
        if layer_type.startswith('input'):
            max_units = np.arange(in_features) + 1
        else:
            assert max_units is not None and len(max_units) == in_features

        if layer_type.endswith('output'):
            assert out_features > total_units
            self.max_units = np.arange(out_features)
            self.max_units[total_units:] = total_units
        else:
            units_per_units = float(total_units) / out_features
            self.max_units = np.zeros(out_features, dtype=np.int32)
            for i in range(out_features):
                self.max_units[i] = np.ceil((i + 1) * units_per_units)

        mask = np.zeros([out_features, in_features], dtype=np.float32)
        for i in range(out_features):
            for j in range(in_features):
                mask[i, j] = float(self.max_units[i] >= max_units[j])

        # reverse order
        if order == 'B':
            reverse_mask = mask[::-1, :]
            reverse_mask = reverse_mask[:, ::-1]
            mask = np.copy(reverse_mask)

        self.mask.copy_(torch.from_numpy(mask).float())
        self.reset_parameters()

        self._init = True

    def reset_parameters(self):
        nn.init.normal_(self.weight_v, mean=0.0, std=0.05)
        self.weight_v.data.mul_(self.mask)
        _norm = norm(self.weight_v, 0).data + 1e-8
        self.weight_g.data.copy_(_norm.log())
        if self.bias is not None:
            nn.init.constant_(self.bias, 0.)

    def initialize(self, x, init_scale=1.0):
        with torch.no_grad():
            # [batch, out_features]
            out = self(x)
            # [out_features]
            mean = out.mean(dim=0)
            std = out.std(dim=0)
            std = std + std.le(0).float()
            inv_stdv = init_scale / (std + 1e-6)

            self.weight_g.add_(inv_stdv.log().unsqueeze(1))
            if self.bias is not None:
                self.bias.add_(-mean).mul_(inv_stdv)
            return self(x)

    def forward(self, input):
        self.weight_v.data.mul_(self.mask)
        _norm = norm(self.weight_v, 0) + 1e-8
        weight = self.weight_v * (self.weight_g.exp() / _norm)
        return F.linear(input, weight, self.bias)

    @overrides
    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}, type={}, order={}'.format(
            self.in_features, self.out_features, self.bias is not None,
            self.layer_type, self.order
        )
Example 14
class Geometric(Distribution):
    def __init__(self, probs=0.5, learnable=True):
        super().__init__()
        if not isinstance(probs, torch.Tensor):
            probs = torch.tensor(probs).view(-1)
        self.n_dims = len(probs)
        self.logits = log(probs.float())
        if learnable:
            self.logits = Parameter(self.logits)

    def log_prob(self, value):
        return (value * (-self.probs).log1p() + log(self.probs)).sum(-1)

    def sample(self, batch_size):
        u = torch.rand((batch_size, self.n_dims))
        return (u.log() / (-self.probs).log1p()).floor()

    def cdf(self, value):
        return 1 - (1 - self.probs).pow(value + 1.)

    def icdf(self, value):
        return ((-value).log1p() / (-self.probs).log1p()) - 1.

    def entropy(self):
        q = (1. - self.probs)
        return -(q * utils.log(q) +
                 self.probs * utils.log(self.probs)) / self.probs

    def kl(self, other):
        if isinstance(other, Geometric):
            return (-self.entropy() - (-other.probs).log1p() / self.probs -
                    other.logits).sum()
        return None

    @property
    def expectation(self):
        return 1. / self.probs - 1.

    @property
    def variance(self):
        return (1. / self.probs - 1.) / self.probs

    @property
    def mode(self):
        return torch.tensor(0.).float()

    @property
    def skewness(self):
        return (2 - self.probs) / (1 - self.probs).sqrt()

    @property
    def kurtosis(self):
        return 6. + (self.probs.pow(2) / (1. - self.probs))

    @property
    def median(self):
        return (-1 / (-self.probs).log1p()).ceil() - 1.

    @property
    def probs(self):
        return self.logits.exp()

    def get_parameters(self):
        return {'probs': self.probs.detach().numpy()}
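A quick sanity-check sketch (illustrative, assuming the snippet's log and Parameter helpers are available): the inverse-CDF sampler above should reproduce the closed-form mean 1/p - 1.

import torch

g = Geometric(probs=0.25, learnable=False)
samples = g.sample(batch_size=100000)
print(samples.mean().item())     # roughly 3.0
print(g.expectation.item())      # exactly 1/0.25 - 1 = 3.0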
Example 15
class ActNormFlow(Flow):
    def __init__(self, in_features, inverse=False):
        super(ActNormFlow, self).__init__(inverse)
        self.in_features = in_features
        self.log_scale = Parameter(torch.Tensor(in_features))
        self.bias = Parameter(torch.Tensor(in_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.log_scale, mean=0, std=0.05)
        nn.init.constant_(self.bias, 0.)

    @overrides
    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """

        Args:
            input: Tensor
                input tensor [batch, N1, N2, ..., in_channels]

        Returns: out: Tensor , logdet: Tensor
            out: [batch, N1, N2, ..., in_channels], the output of the flow
            logdet: [batch], the log determinant of :math:`\partial output / \partial input`

        """
        out = input * self.log_scale.exp() + self.bias
        logdet = self.log_scale.sum(dim=0, keepdim=True)
        if input.dim() > 2:
            num = np.prod(input.size()[1:-1])
            logdet = logdet * num.astype(float)
        return out, logdet

    @overrides
    def backward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """

        Args:
            input: Tensor
                input tensor [batch, N1, N2, ..., in_channels]

        Returns: out: Tensor , logdet: Tensor
            out: [batch, N1, N2, ..., in_channels], the output of the flow
            logdet: [batch], the log determinant of :math:`\partial output / \partial input`

        """
        out = input - self.bias
        out = out.div(self.log_scale.exp() + 1e-8)
        logdet = self.log_scale.sum(dim=0, keepdim=True) * -1.0
        if input.dim() > 2:
            num = np.prod(input.size()[1:-1])
            logdet = logdet * num.astype(float)
        return out, logdet

    @overrides
    def init(self, data, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
        """

        Args:
            data: Tensor
                input tensor [batch, N1, N2, ..., in_channels]

        Returns: out: Tensor , logdet: Tensor
            out: [batch, N1, N2, ..., in_channels], the output of the flow
            logdet: [batch], the log determinant of :math:`\partial output / \partial input`

        """
        with torch.no_grad():
            out, _ = self.forward(data)
            mean = out.view(-1, self.in_features).mean(dim=0)
            std = out.view(-1, self.in_features).std(dim=0)
            inv_stdv = init_scale / (std + 1e-6)

            self.log_scale.add_(inv_stdv.log())
            self.bias.add_(-mean).mul_(inv_stdv)
            return self.forward(data)

    @overrides
    def extra_repr(self):
        return 'inverse={}, in_features={}'.format(self.inverse, self.in_features)

    @classmethod
    def from_params(cls, params: Dict) -> "ActNormFlow":
        return ActNormFlow(**params)
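A usage sketch, added for illustration and assuming the surrounding module's imports (Flow, Parameter, overrides, np) are available: after one data-dependent init call the outputs are roughly zero-mean, unit-variance per feature, and forward/backward invert each other.

import torch

flow = ActNormFlow(in_features=6)
data = torch.randn(128, 6) * 3.0 + 1.0
out, _ = flow.init(data)
print(out.mean(dim=0), out.std(dim=0))      # approximately 0 and 1 per feature

y, logdet_fwd = flow.forward(data)
x_rec, logdet_bwd = flow.backward(y)
assert torch.allclose(x_rec, data, atol=1e-5)
assert torch.allclose(logdet_fwd, -logdet_bwd)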
Example 16
class ActNorm2dFlow(Flow):
    def __init__(self, in_channels, inverse=False):
        super(ActNorm2dFlow, self).__init__(inverse)
        self.in_channels = in_channels
        self.log_scale = Parameter(torch.Tensor(in_channels, 1, 1))
        self.bias = Parameter(torch.Tensor(in_channels, 1, 1))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.log_scale, mean=0, std=0.05)
        nn.init.constant_(self.bias, 0.)

    @overrides
    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """

        Args:
            input: Tensor
                input tensor [batch, in_channels, H, W]

        Returns: out: Tensor , logdet: Tensor
            out: [batch, in_channels, H, W], the output of the flow
            logdet: [batch], the log determinant of :math:`\partial output / \partial input`

        """
        batch, channels, H, W = input.size()
        out = input * self.log_scale.exp() + self.bias
        logdet = self.log_scale.sum(dim=0).squeeze(1).mul(H * W)
        return out, logdet

    @overrides
    def backward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """

        Args:
            input: Tensor
                input tensor [batch, in_channels, H, W]

        Returns: out: Tensor , logdet: Tensor
            out: [batch, in_channels, H, W], the output of the flow
            logdet: [batch], the log determinant of :math:`\partial output / \partial input`

        """
        batch, channels, H, W = input.size()
        out = input - self.bias
        out = out.div(self.log_scale.exp() + 1e-8)
        logdet = self.log_scale.sum(dim=0).squeeze(1).mul(H * -W)
        return out, logdet

    @overrides
    def init(self, data, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
        with torch.no_grad():
            # [batch, n_channels, H, W]
            out, _ = self.forward(data)
            out = out.transpose(0, 1).contiguous().view(self.in_channels, -1)
            # [n_channels, 1, 1]
            mean = out.mean(dim=1).view(self.in_channels, 1, 1)
            std = out.std(dim=1).view(self.in_channels, 1, 1)
            inv_stdv = init_scale / (std + 1e-6)

            self.log_scale.add_(inv_stdv.log())
            self.bias.add_(-mean).mul_(inv_stdv)
            return self.forward(data)

    @overrides
    def extra_repr(self):
        return 'inverse={}, in_channels={}'.format(self.inverse, self.in_channels)

    @classmethod
    def from_params(cls, params: Dict) -> "ActNorm2dFlow":
        return ActNorm2dFlow(**params)
Example 17
class ActNorm1dFlow(Flow):
    def __init__(self, in_features, inverse=False):
        super(ActNorm1dFlow, self).__init__(inverse)
        self.in_features = in_features
        self.log_scale = Parameter(torch.Tensor(in_features))
        self.bias = Parameter(torch.Tensor(in_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.log_scale, mean=0, std=0.05)
        nn.init.constant_(self.bias, 0.)

    @overrides
    def forward(self, input: torch.Tensor, mask: Union[torch.Tensor, None] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """

        Args:
            input: Tensor
                input tensor [batch, N1, N2, ..., in_channels]
            mask: Tensor or None
                mask tensor [batch, N1, N2, ...,Nl]

        Returns: out: Tensor , logdet: Tensor
            out: [batch, N1, N2, ..., in_channels], the output of the flow
            logdet: [batch], the log determinant of :math:`\partial output / \partial input`

        """
        dim = input.dim()
        out = input * self.log_scale.exp() + self.bias
        if mask is not None:
            out = out * mask.unsqueeze(dim - 1)

        logdet = self.log_scale.sum(dim=0, keepdim=True)
        if dim > 2:
            num = np.prod(input.size()[1:-1]).astype(float) if mask is None else mask.view(out.size(0), -1).sum(dim=1)
            logdet = logdet * num
        return out, logdet

    @overrides
    def backward(self, input: torch.Tensor, mask: Union[torch.Tensor, None] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """

        Args:
            input: Tensor
                input tensor [batch, N1, N2, ..., in_channels]
            mask: Tensor or None
                mask tensor [batch, N1, N2, ...,Nl]

        Returns: out: Tensor , logdet: Tensor
            out: [batch, N1, N2, ..., in_channels], the output of the flow
            logdet: [batch], the log determinant of :math:`\partial output / \partial input`

        """
        dim = input.dim()
        out = (input - self.bias).div(self.log_scale.exp() + 1e-8)
        if mask is not None:
            out = out * mask.unsqueeze(dim - 1)

        logdet = self.log_scale.sum(dim=0, keepdim=True) * -1.0
        if input.dim() > 2:
            num = np.prod(input.size()[1:-1]).astype(float) if mask is None else mask.view(out.size(0), -1).sum(dim=1)
            logdet = logdet * num
        return out, logdet

    @overrides
    def init(self, data, mask: Union[torch.Tensor, None] = None, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
        """

        Args:
            data: Tensor
                input tensor [batch, N1, N2, ..., in_channels]
            mask: Tensor or None
                mask tensor [batch, N1, N2, ...,Nl]
            init_scale: float
                initial scale

        Returns: out: Tensor , logdet: Tensor
            out: [batch, N1, N2, ..., in_channels], the output of the flow
            logdet: [batch], the log determinant of :math:`\partial output / \partial input`

        """
        with torch.no_grad():
            # [batch * N1 * ... * Nl, in_features]
            out, _ = self.forward(data, mask=mask)
            out = out.view(-1, self.in_features)
            mean = out.mean(dim=0)
            std = out.std(dim=0)
            inv_stdv = init_scale / (std + 1e-6)

            self.log_scale.add_(inv_stdv.log())
            self.bias.add_(-mean).mul_(inv_stdv)
            return self.forward(data, mask=mask)

    @overrides
    def extra_repr(self):
        return 'inverse={}, in_features={}'.format(self.inverse, self.in_features)

    @classmethod
    def from_params(cls, params: Dict) -> "ActNorm1dFlow":
        return ActNorm1dFlow(**params)
Example 18
class ADKL_KRR_net(MetaNetwork):
    TRAIN = 0
    DESCR = 1
    BOTH = 2
    NB_KERNEL_PARAMS = 1

    def __init__(self,
                 input_features_extractor_params,
                 target_features_extractor_params,
                 condition_on='train',
                 task_descr_extractor_params=None,
                 dataset_encoder_params=None,
                 hp_mode='fixe',
                 l2=0.1,
                 device='cuda',
                 task_encoder_reg=0.,
                 n_pseudo_inputs=0,
                 pseudo_inputs_reg=0,
                 stationary_kernel=False):
        """
        In the constructor we instantiate an lstm module
        """
        super(ADKL_KRR_net, self).__init__()

        if condition_on.lower() in ['train', 'train_samples']:
            assert dataset_encoder_params is not None, 'dataset_encoder_params must be specified'
            self.condition_on = self.TRAIN
        elif condition_on.lower() in ['descr', 'task_descr']:
            assert task_descr_extractor_params is not None, 'task_descr_extractor_params must be specified'
            self.condition_on = self.DESCR
        elif condition_on.lower() in ['both']:
            assert dataset_encoder_params is not None, 'dataset_encoder_params must be specified'
            assert task_descr_extractor_params is not None, 'task_descr_extractor_params must be specified'
            self.condition_on = self.BOTH
        else:
            raise ValueError('Invalid option for parameter condition_on')

        self.task_encoder_reg = task_encoder_reg

        if input_features_extractor_params.get('pooling_fn', 0) is None:
            pooling = GlobalAvgPool1d(dim=1)
            spectral_kernel = True
        else:
            pooling = None
            spectral_kernel = False

        task_encoder = TaskEncoderNet(
            input_features_extractor_params,
            target_features_extractor_params,
            dataset_encoder_params,
            complement_module_input_fextractor=pooling)
        self.features_extractor = task_encoder.input_fextractor
        fe_dim = self.features_extractor.output_dim

        self.task_descr_extractor = None
        tde_dim, de_dim = 0, 0
        if self.condition_on in [self.DESCR, self.BOTH]:
            self.task_descr_extractor = FeaturesExtractorFactory()(
                **task_descr_extractor_params)
            tde_dim = self.task_descr_extractor.output_dim
        if self.condition_on in [self.TRAIN, self.BOTH]:
            self.dataset_encoder = task_encoder
            de_dim = self.dataset_encoder.output_dim

        self.l2 = l2
        self.pseudo_inputs_reg = pseudo_inputs_reg
        self.hp_mode = hp_mode
        self.device = device
        if not stationary_kernel:
            self.kernel_network = NonStationaryKernel(fe_dim, de_dim + tde_dim,
                                                      fe_dim, spectral_kernel)
        else:
            self.kernel_network = StationaryKernel(fe_dim, de_dim + tde_dim,
                                                   fe_dim, spectral_kernel)

        if n_pseudo_inputs > 0:
            # create the Parameter on the target device directly; calling .to()
            # on a Parameter returns a plain, unregistered tensor
            self.pseudo_inputs = Parameter(
                torch.Tensor(n_pseudo_inputs, fe_dim).to(device))
        else:
            self.pseudo_inputs = None
        self.phis_train_mean, self.phis_train_std = 0, 0

        if hp_mode.lower() in ['learn', 'learned', 'l']:
            self.hp_mode = 'l'
        elif hp_mode.lower() in [
                'predicted', 'predict', 'p', 't', 'task-specific', 'per-task'
        ]:
            self.hp_mode = 't'
            d = (de_dim if de_dim else 0) + (tde_dim if tde_dim else 0)
            self.hp_net = Linear(d, self.NB_KERNEL_PARAMS)
        else:
            raise Exception('hp_mode should be one of those: fixe, learn, cv')
        self._init_kernel_params(device)

    def _init_kernel_params(self, device):
        if self.pseudo_inputs is not None:
            init.kaiming_uniform_(self.pseudo_inputs, a=math.sqrt(5))
        self.l2 = torch.FloatTensor([self.l2]).to(device)

        if self.hp_mode == 'l':
            self.l2 = Parameter(self.l2)

    def compute_batch_gram_matrix(self, x, y, task_phis):
        k_ = self.kernel_network(x, y, task_phis)
        if self.pseudo_inputs is not None:
            ps = self.pseudo_inputs.unsqueeze(0).expand(
                x.shape[0], *self.pseudo_inputs.shape)
            k_g = self.kernel_network(x, ps, task_phis)
            k_ = torch.cat((k_, k_g), dim=-1)
        return k_

    def set_kernel_params(self, task_phis):
        if self.hp_mode == 't':
            self.l2 = self.hp_net(task_phis).squeeze(-1)

        l2 = hardtanh(self.l2.exp(), 1e-4, 1e1)
        return l2

    def add_pseudo_inputs_loss(self, loss):
        n = self.pseudo_inputs.shape[0]
        d = self.pseudo_inputs.shape[-1]
        p = self.pseudo_inputs.reshape(-1, d)

        # reg = torch.exp(-0.5 * (p.unsqueeze(2) - p.unsqueeze(1)).pow(2).sum(-1))
        # reg = torch.tril(reg).sum() / (n * (n - 1))

        if self.pseudo_inputs.dim() == 2:
            pi_mean = torch.mean(self.pseudo_inputs, dim=0)
            pi_std = torch.std(self.pseudo_inputs, dim=0)
        elif self.pseudo_inputs.dim() == 3:
            pi_mean = torch.mean(self.pseudo_inputs, dim=(0, 1))
            pi_std = torch.std(self.pseudo_inputs, dim=(0, 1))
        else:
            raise Exception(
                'Pseudo inputs: the number of dimensions is incorrect')

        kl = kl_divergence(
            MultivariateNormal(pi_mean, torch.diag(pi_std + 0.1)),
            MultivariateNormal(self.phis_train_mean,
                               torch.diag(self.phis_train_std + 0.1)))

        res = self.pseudo_inputs_reg * kl
        return loss + res, res

    def add_task_encoder_loss(self, loss):
        reg = self.task_encoder_reg * self.task_encoder_loss
        return loss + reg, reg

    def get_alphas(self, phis, ys, masks, task_phis=None):
        l2 = self.set_kernel_params(task_phis)
        bsize, n_train = phis.shape[:2]

        k_ = self.compute_batch_gram_matrix(phis, phis, task_phis=task_phis)
        k = torch.bmm(k_, k_.transpose(1, 2))
        k_mask = masks[:, None, :] * masks[:, :, None]
        k = k * k_mask

        identity = torch.eye(n_train, device=k.device).unsqueeze(0).expand(
            (bsize, n_train, n_train))
        batch_K_inv = torch.inverse(k +
                                    l2.unsqueeze(1).unsqueeze(1) * identity)
        alphas = torch.bmm(batch_K_inv, ys)

        return alphas, k_

    def get_preds(self,
                  alphas,
                  K_train,
                  phis_train,
                  masks_train,
                  phis_test,
                  masks_test,
                  task_phis=None):
        k = self.compute_batch_gram_matrix(phis_test,
                                           phis_train,
                                           task_phis=task_phis)
        k = torch.bmm(k, K_train.transpose(1, 2))
        k_mask = masks_test[:, :, None] * masks_train[:, None, :]
        k = k * k_mask

        preds = torch.bmm(k, alphas)
        return preds

    def get_task_phis(self, tasks_descr, xs_train, ys_train, mask_train):
        if self.condition_on == self.DESCR:
            task_phis = self.task_descr_extractor(tasks_descr)
        elif self.condition_on == self.TRAIN:
            task_phis = self.dataset_encoder(None, xs_train, ys_train,
                                             mask_train)
        else:
            task_phis = torch.cat([
                self.task_descr_extractor(tasks_descr),
                self.dataset_encoder(None, xs_train, ys_train, mask_train)
            ],
                                  dim=1)
        return task_phis

    def get_phis(self, xs, train=False):
        phis = self.features_extractor(xs.reshape(-1, xs.shape[2]))
        if train:
            alpha = 0.8
            if phis.dim() == 2:
                self.phis_train_mean = (
                    1 - alpha) * self.phis_train_mean + alpha * torch.mean(
                        phis, dim=0).detach()
                self.phis_train_std = (
                    1 - alpha) * self.phis_train_std + alpha * torch.std(
                        phis, dim=0).detach()
            elif phis.dim() == 3:
                self.phis_train_mean = (
                    1 - alpha) * self.phis_train_mean + alpha * torch.mean(
                        phis, dim=(0, 1)).detach()
                self.phis_train_std = (
                    1 - alpha) * self.phis_train_std + alpha * torch.std(
                        phis, dim=(0, 1)).detach()

        phis = phis.reshape(*xs.shape[:2], *phis.shape[1:])
        return phis

    def forward(self, episodes):
        if self.condition_on == self.TRAIN:
            train, test = pack_episodes(episodes, return_tasks_descr=False)
            xs_train, ys_train, lens_train, mask_train = train
            xs_test, lens_test, mask_test = test
            task_phis = self.get_task_phis(None, xs_train, ys_train,
                                           mask_train)
        else:
            train, test, tasks_descr = pack_episodes(episodes,
                                                     return_tasks_descr=True)
            xs_train, ys_train, lens_train, mask_train = train
            xs_test, lens_test, mask_test = test
            task_phis = self.get_task_phis(tasks_descr, xs_train, ys_train,
                                           mask_train)

        phis_train, phis_test = self.get_phis(
            xs_train, train=True), self.get_phis(xs_test, train=False)

        # training
        alphas, K_train = self.get_alphas(phis_train, ys_train, mask_train,
                                          task_phis)

        # testing
        preds = self.get_preds(alphas, K_train, phis_train, mask_train,
                               phis_test, mask_test, task_phis)

        self.compute_task_encoder_loss_last_batch(episodes)

        if isinstance(preds, tuple):
            return [
                tuple(x[:n] for x in pred)
                for n, pred in zip(lens_test, preds)
            ]
        else:
            return [x[:n] for n, x in zip(lens_test, preds)]

    def compute_task_encoder_loss_last_batch(self, episodes):
        # train x test
        set_code = self.dataset_encoder(episodes)

        y_preds_class = torch.arange(len(set_code))
        if set_code.is_cuda:
            y_preds_class = y_preds_class.to('cuda')

        accuracy = (set_code.argmax(dim=1)
                    == y_preds_class).sum().item() / len(set_code)

        b = set_code.size(0)
        mi = set_code.diagonal().mean() \
             - torch.log((set_code * (1 - torch.eye(b, device=set_code.device))).exp().sum() / (b * (b - 1)))
        loss = -mi

        self.accuracy = accuracy
        self.task_encoder_loss = loss
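get_alphas and get_preds above implement batched kernel ridge regression with a learned kernel. A self-contained single-task sketch of the same closed form, with a plain RBF kernel standing in for kernel_network (illustration only):

import torch

def rbf(a, b, gamma=1.0):
    return torch.exp(-gamma * torch.cdist(a, b).pow(2))

x_train, y_train = torch.randn(50, 4), torch.randn(50, 1)
x_test = torch.randn(10, 4)
l2 = 0.1

K = rbf(x_train, x_train)
alphas = torch.linalg.solve(K + l2 * torch.eye(50), y_train)   # (K + l2*I)^-1 y
preds = rbf(x_test, x_train) @ alphas                          # shape (10, 1)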
Example 19
class Bernoulli(Distribution):
    def __init__(self, probs=[0.5], learnable=True):
        super().__init__()
        if not isinstance(probs, torch.Tensor):
            probs = torch.tensor(probs).view(-1)
        self.n_dims = len(probs)
        self.logits = log(probs.float())
        if learnable:
            self.logits = Parameter(self.logits)

    def log_prob(self, value):
        q = 1. - self.probs
        return (value * (self.probs + eps).log() + (1. - value) *
                (q + eps).log()).sum(-1)

    def sample(self, batch_size):
        return torch.bernoulli(
            self.probs.unsqueeze(0).expand((batch_size, *self.probs.shape)))

    def entropy(self):
        q = 1. - self.probs
        return -q * (q + eps).log() - self.probs * (self.probs + eps).log()

    def kl(self, other):
        if isinstance(other, Bernoulli):
            t1 = self.probs * (self.probs / other.probs).log()
            t1[other.probs == 0] = inf
            t1[self.probs == 0] = 0
            t2 = (1 - self.probs) * ((1 - self.probs) /
                                     (1 - other.probs)).log()
            t2[other.probs == 1] = inf
            t2[self.probs == 1] = 0
            return (t1 + t2).sum()
        if isinstance(other, Poisson):
            return (-self.entropy() -
                    (self.probs * other.rate.log() - other.rate)).sum()
        return None

    @property
    def expectation(self):
        return self.probs

    @property
    def variance(self):
        return self.probs * (1 - self.probs)

    @property
    def skewness(self):
        return (1 - 2 * self.probs) / (self.probs * (1 - self.probs)).sqrt()

    @property
    def kurtosis(self):
        q = (1 - self.probs)
        return (1 - 6. * self.probs * q) / (self.probs * q)

    @property
    def probs(self):
        return self.logits.exp()

    def get_parameters(self):
        if self.n_dims == 1:
            return {'probs': self.probs.detach().item()}
        return {'probs': self.probs.detach().numpy()}
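A quick check sketch (illustrative, assuming the snippet's Distribution base, log, eps and inf are importable): the hand-written kl should agree with torch.distributions.kl_divergence for two Bernoulli distributions, up to the sum over dimensions.

import torch
from torch import distributions as dists

p = Bernoulli(probs=[0.2, 0.7], learnable=False)
q = Bernoulli(probs=[0.4, 0.4], learnable=False)
manual = p.kl(q)
reference = dists.kl_divergence(dists.Bernoulli(probs=p.probs),
                                dists.Bernoulli(probs=q.probs)).sum()
print(manual.item(), reference.item())   # the two values should match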