Example #1
class SkewNormalKernel(Module):
    def __init__(self, shape=None, loc=None, scale=None):
        super().__init__()
        self.__relu = ReLU()
        self.__shape = Parameter(torch.rand(1)*0.1 if shape is None else torch.Tensor([shape]), requires_grad=True)
        self.__loc = Parameter(torch.rand(1)+6 if loc is None else torch.Tensor([loc]), requires_grad=True)
        self.__scale = Parameter(torch.rand(1)+1 if scale is None else torch.Tensor([scale]), requires_grad=True)

    @property
    def shape(self) -> float:
        return self.__shape.item()

    @property
    def loc(self) -> float:
        return self.__loc.item()

    @property
    def scale(self) -> float:
        with torch.no_grad():
            return self.__relu(self.__scale).item()

    @property
    def params(self):
        return self.shape, self.loc, self.scale

    def forward(self, classes_φ):
        interval = torch.abs((classes_φ[0]-classes_φ[-1]) / (classes_φ.shape[0]-1)).item()
        shape = self.__shape
        loc = self.__loc
        scale = self.__relu(self.__scale)
        x = classes_φ
        pdf = skew_normal_pdf(x, shape, loc, scale)
        # scale pdf to frequency
        frequency = pdf * interval
        return frequency
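
The skew_normal_pdf helper is imported from elsewhere and not shown. A minimal stand-in, assuming the standard skew-normal density with shape, location and scale parameters (this is an assumption about the helper, not its actual source):

import math
import torch

# Assumed implementation of skew_normal_pdf:
# f(x) = (2 / scale) * phi(z) * Phi(shape * z), with z = (x - loc) / scale,
# where phi and Phi are the standard normal pdf and cdf.
def skew_normal_pdf(x, shape, loc, scale):
    z = (x - loc) / scale
    phi = torch.exp(-0.5 * z ** 2) / math.sqrt(2.0 * math.pi)   # standard normal pdf
    Phi = 0.5 * (1.0 + torch.erf(shape * z / math.sqrt(2.0)))   # standard normal cdf
    return (2.0 / scale) * phi * Phi
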
Example #2
    def init_embeddings(self, num_entities: int, num_relations: int,
                        embedding_range: nn.Parameter) -> None:
        """
        Initialise the embeddings (to be done by subclass)

        :param num_entities: int, > 0
            The number of unique entities
        :param num_relations: int, > 0
            The number of unique relations
        :param embedding_range: nn.Parameter
            Scalar bound of the uniform initialisation interval
            [-embedding_range, embedding_range]

        :return: None
        """
        self.num_entities = num_entities
        self.num_relations = num_relations

        self.entity_embedding = nn.Parameter(torch.zeros(num_entities,
                                             self.embedding_dim))
        nn.init.uniform_(
            tensor=self.entity_embedding,
            a=-embedding_range.item(),
            b=embedding_range.item()
        )

        self.relation_embedding = nn.Parameter(torch.zeros(num_relations,
                                                           self.embedding_dim))
        nn.init.uniform_(
            tensor=self.relation_embedding,
            a=-embedding_range.item(),
            b=embedding_range.item()
        )
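
The embedding_range argument is only read via .item() as the symmetric bound of the uniform initialisation. Its construction is not shown here; a common pattern in RotatE-style knowledge-graph models derives it from a margin gamma, so a possible (assumed, with hypothetical values) construction is:

import torch
import torch.nn as nn

# Assumed construction of embedding_range: a fixed, non-trainable scalar
# Parameter derived from a margin gamma (values below are hypothetical).
gamma, epsilon, embedding_dim = 12.0, 2.0, 500
embedding_range = nn.Parameter(
    torch.Tensor([(gamma + epsilon) / embedding_dim]),
    requires_grad=False
)
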
Example #3
def optimize(extrapolate=True, optimizer='sgd'):
    x = Parameter(torch.ones(1))
    y = Parameter(torch.ones(1))
    g = Game()

    def make_optimizer(params, optimizer):
        if optimizer == 'sgd':
            return SGD(
                params,
                lr=1e-1,
            )
        else:
            return Adam(params, lr=1e-2, betas=(0., 0.9), amsgrad=False)

    opt_x = ExtraOptimizer(make_optimizer([x], optimizer))
    opt_y = ExtraOptimizer(make_optimizer([y], optimizer))
    trace = []
    for i in range(10000):
        distance = (x**2 + y**2).item()
        trace.append(dict(x=x.item(), y=y.item(), c=i * 2, distance=distance))
        opt_x.zero_grad()
        opt_y.zero_grad()
        lx, ly = g(x, y)
        lx.backward()
        ly.backward()
        if extrapolate:
            opt_x.step(extrapolate=i % 2)
            opt_y.step(extrapolate=i % 2)
        else:
            opt_x.step()
            opt_y.step()
    return pd.DataFrame(trace)
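
Game is not defined in this snippet. Extragradient demos of this kind typically optimise a bilinear min-max game, so a minimal stand-in (an assumption, not the original class) could be:

# Hypothetical stand-in for Game: a bilinear min-max game in which the x
# player minimizes x*y and the y player minimizes -x*y. Plain simultaneous
# gradient steps cycle around the saddle point at the origin, while the
# extrapolated (extragradient) updates drive the recorded distance down.
class Game:
    def __call__(self, x, y):
        lx = x * y       # loss minimized by x
        ly = -x * y      # loss minimized by y
        return lx, ly
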
Example #4
class Power(Transform):

    def __init__(self, power=1.0, learnable=True):
        super().__init__()
        if not isinstance(power, torch.Tensor):
            power = torch.tensor(power).view(1, -1)
        self.power = power
        if learnable:
            self.power = Parameter(self.power)

    def forward(self, x):
        if self.power == 0.:
            return x.exp()
        return (1. + x * self.power) ** (1. / self.power)

    def inverse(self, y):
        if self.power == 0.:
            return y.log()
        return (y**self.power - 1.) / self.power

    def log_abs_det_jacobian(self, x, y):
        if self.power == 0.:
            return x.sum(-1)
        return ((1. / self.power - 1.) * (x * self.power).log1p()).sum(-1)

    def get_parameters(self):
        return {'type':'power', 'power':self.power.item()}
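
A standalone sanity check (plain torch, independent of the Transform base class) that forward and inverse above are mutual inverses for a nonzero power:

import torch

# Round-trip check of the Box-Cox-style mapping y = (1 + p*x)**(1/p)
# and its inverse x = (y**p - 1) / p.
p = torch.tensor(0.5)
x = torch.linspace(0.1, 2.0, 5)
y = (1. + x * p) ** (1. / p)
x_rec = (y ** p - 1.) / p
print(torch.allclose(x, x_rec))  # True, up to floating-point error
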
Example #5
class LogLaplace(Distribution):

    def __init__(self, loc=0., scale=1., learnable=True):
        super().__init__()
        if not isinstance(loc, torch.Tensor):
            loc = torch.tensor(loc).view(-1)
        self.n_dims = len(loc)
        if not isinstance(scale, torch.Tensor):
            scale = torch.tensor(scale).view(-1)
        self.loc = loc.float()
        self._scale = utils.softplus_inverse(scale.float())
        if learnable:
            self.loc = Parameter(self.loc)
            self._scale = Parameter(self._scale)

    def log_prob(self, value):
        model = TransformDistribution(Laplace(self.loc, self.scale, learnable=False),
                                      [Exp()])
        return model.log_prob(value)

    def sample(self, batch_size):
        model = TransformDistribution(Laplace(self.loc, self.scale, learnable=False),
                                      [Exp()])
        return model.sample(batch_size)

    @property
    def scale(self):
        return softplus(self._scale)

    def get_parameters(self):
        if self.n_dims == 1:
            return {'loc':self.loc.item(), 'scale':self.scale.item()}
        return {'loc':self.loc.detach().numpy(),
                'scale':self.scale.detach().numpy()}
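
Most of the distributions in these examples store softplus_inverse(scale) so that the raw parameter is unconstrained while the effective scale stays positive. The utils.softplus_inverse helper is not shown; a common definition (an assumption about that helper) is:

import torch
from torch.nn.functional import softplus

# Assumed definition of utils.softplus_inverse: the inverse of softplus,
# so that softplus(softplus_inverse(s)) == s for s > 0.
def softplus_inverse(s):
    return s.expm1().log()   # log(exp(s) - 1)

s = torch.tensor([0.5, 1.0, 2.0])
print(torch.allclose(softplus(softplus_inverse(s)), s))  # True
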
Example #6
class DiffBoundary:
    def __init__(self, bit_width=4):
        # TODO: add channel-wise option?
        self.bit_width = bit_width
        self.register_boundaries()

    def register_boundaries(self):
        assert hasattr(self, "weight")
        self.lb = Parameter(self.weight.data.min())
        self.ub = Parameter(self.weight.data.max())

    def reset_boundaries(self):
        assert hasattr(self, "weight")
        self.lb.data = self.weight.data.min()
        self.ub.data = self.weight.data.max()

    def get_quant_weight(self, align_zero=True):
        # TODO: set `align_zero`?
        if align_zero:
            return self._get_quant_weight_align_zero()
        else:
            return self._get_quant_weight()

    def _get_quant_weight(self):
        round_ = RoundSTE.apply
        w = self.weight.detach()
        delta = (self.ub - self.lb) / (2**self.bit_width - 1)
        w = torch.clamp(w, self.lb.item(), self.ub.item())
        idx = round_((w - self.lb).div(delta))  # TODO: do we need STE here?
        qw = (idx * delta) + self.lb
        return qw

    def _get_quant_weight_align_zero(self):
        # Align the quantization grid with zero: shift the boundaries to
        # integer multiples of delta so that 0.0 maps exactly onto a level.
        round_ = RoundSTE.apply
        n = 2**self.bit_width - 1
        w = self.weight.detach()
        delta = (self.ub - self.lb) / n
        z = round_(self.lb.abs() / delta)
        lb = -z * delta
        ub = (n - z) * delta
        w = torch.clamp(w, lb.item(), ub.item())
        idx = round_((w - self.lb).div(delta))  # TODO: do we need STE here?
        qw = (idx - z) * delta
        return qw
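
RoundSTE is referenced but not defined above. A typical straight-through rounding function (an assumption about the real implementation) rounds in the forward pass and passes gradients through unchanged:

import torch

# Assumed RoundSTE: round to the nearest integer in the forward pass,
# use the identity (straight-through) gradient in the backward pass.
class RoundSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output
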
Example #7
class LogNormal(Distribution):
    def __init__(self, loc=0., scale=1., learnable=True):
        super().__init__()
        if not isinstance(loc, torch.Tensor):
            loc = torch.tensor(loc).view(-1)
        self.n_dims = len(loc)
        if not isinstance(scale, torch.Tensor):
            scale = torch.tensor(scale).view(-1)
        self.loc = loc.float()
        self._scale = utils.softplus_inverse(scale.float())
        if learnable:
            self.loc = Parameter(self.loc)
            self._scale = Parameter(self._scale)

    def log_prob(self, value):
        return dists.LogNormal(self.loc, self.scale).log_prob(value).sum(-1)

    def sample(self, batch_size):
        return dists.LogNormal(self.loc, self.scale).rsample((batch_size, ))

    @property
    def expectation(self):
        return (self.loc + self.scale.pow(2) / 2).exp()

    @property
    def variance(self):
        s_square = self.scale.pow(2)
        return (s_square.exp() - 1) * (2 * self.loc + s_square).exp()

    @property
    def median(self):
        return self.loc.exp()

    @property
    def mode(self):
        return (self.loc - self.scale.pow(2)).exp()

    def entropy(self):
        return dists.LogNormal(self.loc, self.scale).entropy()

    @property
    def scale(self):
        return softplus(self._scale)

    def get_parameters(self):
        if self.n_dims == 1:
            return {'loc': self.loc.item(), 'scale': self.scale.item()}
        return {
            'loc': self.loc.detach().numpy(),
            'scale': self.scale.detach().numpy()
        }
Example #8
class LogCauchy(Distribution):
    def __init__(self, loc=0., scale=1., learnable=True):
        super().__init__()
        if not isinstance(loc, torch.Tensor):
            loc = torch.tensor(loc).view(-1)
        self.n_dims = len(loc)
        if not isinstance(scale, torch.Tensor):
            scale = torch.tensor(scale).view(-1)
        self.loc = loc.float()
        self._scale = utils.softplus_inverse(scale.float())
        if learnable:
            self.loc = Parameter(self.loc)
            self._scale = Parameter(self._scale)

    def log_prob(self, value):
        model = TransformDistribution(
            Cauchy(self.loc, self.scale, learnable=False), [Exp()])
        return model.log_prob(value)

    def sample(self, batch_size):
        model = TransformDistribution(
            Cauchy(self.loc, self.scale, learnable=False), [Exp()])
        return model.sample(batch_size)

    def cdf(self, value):
        std_term = (value.log() - self.loc) / self.scale
        return (1. / math.pi) * torch.atan(std_term) + 0.5

    def icdf(self, value):
        cauchy_icdf = torch.tan(math.pi *
                                (value - 0.5)) * self.scale + self.loc
        return cauchy_icdf.exp()

    @property
    def scale(self):
        return softplus(self._scale)

    @property
    def median(self):
        return self.loc.exp()

    def get_parameters(self):
        if self.n_dims == 1:
            return {'loc': self.loc.item(), 'scale': self.scale.item()}
        return {
            'loc': self.loc.detach().numpy(),
            'scale': self.scale.detach().numpy()
        }
Example #9
class BiLinearLSR(torch.nn.Linear):
    def __init__(self, in_features, out_features, bias=False, binary_act=True):
        super(BiLinearLSR, self).__init__(in_features, out_features, bias=bias)
        self.binary_act = binary_act

        # A nn.Parameter placeholder must be registered for model loading:
        # self.register_parameter('scale', None) does not put 'scale' into
        # state_dict, which causes an unexpected-key error when loading a
        # saved model, so scale is initialised as a real Parameter. However,
        # Parameter(None) has size [0] rather than [] (a scalar), hence the
        # squeeze trick below.
        self.register_parameter('scale',
                                Parameter(torch.Tensor([0.0]).squeeze()))

    def reset_scale(self, input):
        bw = self.weight
        ba = input
        bw = bw - bw.mean()
        self.scale = Parameter(
            (F.linear(ba, bw).std() /
             F.linear(torch.sign(ba), torch.sign(bw)).std()).float().to(
                 ba.device))
        # corner case when ba is all 0.0
        if torch.isnan(self.scale):
            self.scale = Parameter(
                (bw.std() / torch.sign(bw).std()).float().to(ba.device))

    def forward(self, input):
        bw = self.weight
        ba = input
        bw = bw - bw.mean()

        if self.scale.item() == 0.0:
            self.reset_scale(input)

        bw = BinaryQuantize().apply(bw)
        bw = bw * self.scale
        if self.binary_act:
            ba = BinaryQuantize().apply(ba)
        output = F.linear(ba, bw)
        return output
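
BinaryQuantize is used but not shown. A typical sign-based binarizer with a straight-through gradient (an assumption about the real implementation):

import torch

# Assumed BinaryQuantize: binarize with sign() in the forward pass and pass
# the gradient straight through in the backward pass.
class BinaryQuantize(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return torch.sign(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output
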
Example #10
class tensorizedlinear(nn.Module):
    def __init__(self,
                 in_size,
                 out_size,
                 in_rank,
                 out_rank,
                 alpha=1,
                 beta=0.1,
                 c=1e-3,
                 **kwargs):
        super(tensorizedlinear, self).__init__()
        self.in_size = list(in_size)
        self.out_size = list(out_size)
        self.in_rank = list(in_rank)
        self.out_rank = list(out_rank)
        self.factors_in = ParameterList([
            Parameter(torch.Tensor(r, s)) for (r, s) in zip(in_rank, in_size)
        ])
        self.factors_out = ParameterList([
            Parameter(torch.Tensor(s, r))
            for (r, s) in zip(out_rank, out_size)
        ])
        self.core = Parameter(torch.Tensor(np.prod(out_rank),
                                           np.prod(in_rank)))
        self.bias = Parameter(torch.Tensor(np.prod(out_size)))
        self.lamb_in = ParameterList(
            [Parameter(torch.ones(r)) for r in in_rank])
        self.lamb_out = ParameterList(
            [Parameter(torch.ones(r)) for r in out_rank])
        self.alpha = Parameter(torch.tensor(alpha), requires_grad=False)
        self.beta = Parameter(torch.tensor(beta), requires_grad=False)
        self.c = Parameter(torch.tensor(c), requires_grad=False)
        self._initialize_weights()

    def forward(self, x):
        x = x.reshape((x.shape[0], *self.in_size))
        for i in range(len(self.factors_in)):
            x = tl.tenalg.mode_dot(x, self.factors_in[i], i + 1)
        x = x.reshape((x.shape[0], -1))
        x = torch.nn.functional.linear(x, self.core)
        x = x.reshape((x.shape[0], *self.out_rank))
        for i in range(len(self.factors_out)):
            x = tl.tenalg.mode_dot(x, self.factors_out[i], i + 1)
        x = x.reshape((x.shape[0], -1))
        x = x + self.bias
        x /= np.prod(self.out_rank)**0.5
        return x

    def _initialize_weights(self):
        for f in self.factors_in:
            nn.init.kaiming_uniform_(f)
        for f in self.factors_out:
            nn.init.kaiming_uniform_(f)
        nn.init.kaiming_uniform_(self.core)
        #        self.core.data /= np.prod(self.out_rank) **0.5
        nn.init.constant_(self.bias, 0)

    def regularizer(self, exp=True):
        ret = 0
        if exp:
            for l, f, s in zip(self.lamb_in, self.factors_in, self.in_size):
                ret += torch.sum(torch.sum(f**2, dim=1) * torch.exp(l) / 2)
                ret -= s * torch.sum(l) / 2
                ret -= torch.sum(self.alpha * l)
                ret += torch.sum(torch.exp(l)) / self.beta
            for l, f, s in zip(self.lamb_out, self.factors_out, self.out_size):
                ret += torch.sum(torch.sum(f**2, dim=0) * torch.exp(l) / 2)
                ret -= s * torch.sum(l) / 2
                ret -= torch.sum(self.alpha * l)
                ret += torch.sum(torch.exp(l)) / self.beta
            ret += torch.sum(self.core**2 / 2)
            core_shape = list(self.out_rank) + list(self.in_rank)
            core = self.core.reshape(core_shape)
            core2 = core**2
            for d, l in enumerate(list(self.lamb_out) + list(self.lamb_in)):
                s = [1] * len(core_shape)
                s[d] = -1
                l = l.reshape(s)
                core2 = core2 * torch.exp(l)
                ret -= core2.numel() / l.numel() * torch.sum(l) / 2


            # core2 = self.core ** 2
            ret += torch.sum(core2) * self.c / 2
        else:
            for l, f, s in zip(self.lamb_in, self.factors_in, self.in_size):
                l.data.clamp_min_(1e-6)
                ret += torch.sum(torch.sum(f**2, dim=1) / l / 2)
                ret += s * torch.sum(torch.log(l)) / 2
                ret += torch.sum(self.beta / l)
                ret += (self.alpha + 1) * torch.sum(torch.log(l))
            for l, f, s in zip(self.lamb_out, self.factors_out, self.out_size):
                l.data.clamp_min_(1e-6)
                ret += torch.sum(torch.sum(f**2, dim=0) / l / 2)
                ret += s * torch.sum(torch.log(l)) / 2
                ret += torch.sum(self.beta / l)
                ret += (self.alpha + 1) * torch.sum(torch.log(l))

            # ret += torch.sum(self.core ** 2 / 2)
            core_shape = list(self.out_rank) + list(self.in_rank)
            core = self.core.reshape(core_shape)
            core2 = core**2
            for d, l in enumerate(list(self.lamb_out) + list(self.lamb_in)):
                s = [1] * len(core_shape)
                s[d] = -1
                l = l.reshape(s)
                core2 = core2 / l
            ret += torch.sum(core2) / 2
        return ret

    def get_lamb_ths(self, exp=True):
        if exp:
            ths_in = [
                np.log(((s + self.core.numel() / self.in_rank[ind]) / 2 +
                        self.alpha.item()) * self.beta.item())
                for (ind, s) in enumerate(self.in_size)
            ]
            ths_out = [
                np.log(((s + self.core.numel() / self.out_rank[ind]) / 2 +
                        self.alpha.item()) * self.beta.item())
                for (ind, s) in enumerate(self.out_size)
            ]
        else:
            ths_in = [
                self.beta.item() / (s / 2 + self.alpha.item() + 1)
                for s in self.in_size
            ]
            ths_out = [
                self.beta.item() / (s / 2 + self.alpha.item() + 1)
                for s in self.out_size
            ]
        return (ths_in, ths_out)
Example #11
class Normal(Distribution):
    def __init__(self, loc=0., scale=1., learnable=True):
        super().__init__()
        if not isinstance(loc, torch.Tensor):
            loc = torch.tensor(loc).float()
        if not isinstance(scale, torch.Tensor):
            scale = torch.tensor(scale).float()

        if len(loc.shape) == 0:
            loc = loc.view(-1)
            scale = scale.view(-1)
            self.n_dims = 1
            self._scale = softplus_inverse(scale)
            self._diag_type = 'diag'

        if len(loc.shape) == 1:
            self.n_dims = len(loc)
            scale = scale.view(-1)
            if scale.numel() == 1:
                scale = scale.expand_as(loc)

            if scale.shape == loc.shape:
                self._scale = softplus_inverse(scale)
                self._diag_type = 'diag'
            else:
                self._scale = scale.view(self.n_dims, self.n_dims).cholesky()
                self._diag_type = 'cholesky'

            self.loc = loc

        if len(loc.shape) > 1:
            assert len(loc.shape) == len(scale.shape)
            self.loc = loc

            scale = scale.expand_as(loc)
            self._diag_type = 'diag'
            self._scale = softplus_inverse(scale)
            self.n_dims = loc.shape

        if learnable:
            self.loc = Parameter(self.loc)
            self._scale = Parameter(self._scale)

    def log_prob(self, value):
        if self._diag_type == "cholesky":
            return dists.MultivariateNormal(self.loc,
                                            self.scale).log_prob(value)
        elif self._diag_type == 'diag':
            return dists.Normal(self.loc, self.std).log_prob(value).sum(dim=-1)
        else:
            raise NotImplementedError(
                "_diag_type can only be cholesky or diag")

    def sample(self, batch_size):
        if self._diag_type == "cholesky":
            return dists.MultivariateNormal(self.loc, self.scale).rsample(
                (batch_size, ))
        elif self._diag_type == 'diag':
            return dists.Normal(self.loc, self.std).rsample((batch_size, ))
        else:
            raise NotImplementedError(
                "_diag_type can only be cholesky or diag")

    def entropy(self, batch_size=None):
        if self._diag_type == "cholesky":
            return 0.5 * torch.det(2 * math.pi * math.e * self.scale).log()
        elif self._diag_type == 'diag':
            return 0.5 + 0.5 * math.log(2 * math.pi) + torch.log(self.scale)
        else:
            raise NotImplementedError(
                "_diag_type can only be cholesky or diag")

    def cdf(self, value):
        if self._diag_type == 'diag':
            return dists.Normal(self.loc, self.std).cdf(value)
        else:
            raise NotImplementedError(
                "CDF only implemented for _diag_type diag")

    def icdf(self, value):
        if self._diag_type == 'diag':
            return dists.Normal(self.loc, self.std).icdf(value)
        else:
            raise NotImplementedError(
                "CDF only implemented for _diag_type diag")

    def kl(self, other):
        if isinstance(other, Normal):
            if other._diag_type == 'diag':  # regular normal
                var_ratio = (self.scale / other.scale).pow(2)
                t1 = ((self.loc - other.loc) / other.scale).pow(2)
                return (0.5 * (var_ratio + t1 - 1. - var_ratio.log())).sum()
        return None

    @property
    def expectation(self):
        return self.loc

    @property
    def variance(self):
        return self.scale

    @property
    def std(self):
        return torch.diagonal(self.scale, dim1=-2, dim2=-1).sqrt()

    @property
    def scale(self):
        if self._diag_type == 'cholesky':
            return torch.mm(self._scale, self._scale.t())
        elif self._diag_type == 'diag':
            return torch.diag_embed(softplus(self._scale))
        else:
            raise NotImplementedError(
                "_diag_type can only be cholesky or diag")

    def get_parameters(self):
        if self.n_dims == 1:
            return {'loc': self.loc.item(), 'scale': self.scale.item()}
        return {
            'loc': self.loc.detach().numpy(),
            'scale': self.scale.detach().numpy()
        }
Example #12
class Logistic(Distribution):
    def __init__(self, loc=0., scale=1., learnable=True):
        super().__init__()
        if not isinstance(loc, torch.Tensor):
            loc = torch.tensor(loc).view(-1)
        self.n_dims = len(loc)
        if not isinstance(scale, torch.Tensor):
            scale = torch.tensor(scale).view(-1)
        self.loc = loc.float()
        self._scale = utils.softplus_inverse(scale.float())
        if learnable:
            self.loc = Parameter(self.loc)
            self._scale = Parameter(self._scale)

    def create_dist(self):
        zero = torch.zeros_like(self.loc)
        one = torch.ones_like(self.loc)
        model = TransformDistribution(
            Uniform(zero, one, learnable=False),
            [Logit(), Affine(self.loc, self.scale, learnable=False)])
        return model

    def log_prob(self, value):
        model = self.create_dist()
        return model.log_prob(value)

    def sample(self, batch_size):
        model = self.create_dist()
        return model.sample(batch_size)

    def cdf(self, value):
        model = self.create_dist()
        return model.cdf(value)

    def icdf(self, value):
        model = self.create_dist()
        return model.icdf(value)

    @property
    def scale(self):
        return softplus(self._scale)

    def entropy(self):
        return self.scale.log() + 2.

    @property
    def expectation(self):
        return self.loc

    @property
    def mode(self):
        return self.loc

    @property
    def variance(self):
        return self.scale.pow(2) * (math.pi**2) / 3

    @property
    def median(self):
        return self.loc

    @property
    def skewness(self):
        return torch.tensor(0.).float()

    @property
    def kurtosis(self):
        return torch.tensor(6. / 5.).float()

    def get_parameters(self):
        if self.n_dims == 1:
            return {'loc': self.loc.item(), 'scale': self.scale.item()}
        return {
            'loc': self.loc.detach().numpy(),
            'scale': self.scale.detach().numpy()
        }
Example #13
class Gumbel(Distribution):
    def __init__(self, loc=0., scale=1., learnable=True):
        super().__init__()
        if not isinstance(loc, torch.Tensor):
            loc = torch.tensor(loc).view(-1)
        self.n_dims = len(loc)
        if not isinstance(scale, torch.Tensor):
            scale = torch.tensor(scale).view(-1)
        self.loc = loc.float()
        self._scale = softplus_inverse(scale.float())
        if learnable:
            self.loc = Parameter(self.loc)
            self._scale = Parameter(self._scale)

    def log_prob(self, value):
        z = (value - self.loc) / self.scale
        return (-self.scale.log() - (z + (-z).exp())).sum(-1)

    def sample(self, batch_size):
        U = torch.rand((batch_size, self.n_dims))
        return self.icdf(U)

    def cdf(self, value):
        return (-(-(value - self.loc) / self.scale).exp()).exp()

    def icdf(self, value):
        return self.loc - self.scale * (-(value + eps).log()).log()

    def entropy(self):
        return self.scale.log() + euler_mascheroni + 1.

    @property
    def expectation(self):
        return self.loc + self.scale * euler_mascheroni

    @property
    def mode(self):
        return self.loc

    @property
    def variance(self):
        return ((math.pi**2) / 6.) * self.scale.pow(2)

    @property
    def median(self):
        return self.loc - self.scale * math.log(math.log(2))

    @property
    def skewness(self):
        return torch.tensor(1.14).float()  # expand this out

    @property
    def kurtosis(self):
        return torch.tensor(12. / 5.).float()

    @property
    def scale(self):
        return softplus(self._scale)

    def get_parameters(self):
        if self.n_dims == 1:
            return {'loc': self.loc.item(), 'scale': self.scale.item()}
        return {
            'loc': self.loc.detach().numpy(),
            'scale': self.scale.detach().numpy()
        }
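
A quick standalone check of the inverse-CDF sampler above (torch only): for loc = 0 and scale = 1 the sample mean should approach the Euler-Mascheroni constant, matching the expectation property.

import torch

# Inverse-CDF sampling for a standard Gumbel(0, 1): x = -log(-log(U)).
eps = 1e-10
U = torch.rand(100000)
samples = -(-(U + eps).log()).log()
print(samples.mean())  # approximately 0.5772 (Euler-Mascheroni constant)
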
Example #14
class AsymmetricLaplace(Distribution):
    def __init__(self, loc=0., scale=1., asymmetry=1., learnable=True):
        super().__init__()
        if not isinstance(loc, torch.Tensor):
            loc = torch.tensor(loc).view(-1)
        self.n_dims = len(loc)
        if not isinstance(scale, torch.Tensor):
            scale = torch.tensor(scale).view(-1)
        if not isinstance(asymmetry, torch.Tensor):
            asymmetry = torch.tensor(asymmetry).view(-1)
        self.loc = loc.float()
        self._scale = utils.softplus_inverse(scale.float())
        self._asymmetry = utils.softplus_inverse(asymmetry.float())
        if learnable:
            self.loc = Parameter(self.loc)
            self._scale = Parameter(self._scale)
            self._asymmetry = Parameter(self._asymmetry)

    def log_prob(self, value):
        s = (value - self.loc).sign()
        exponent = -(value -
                     self.loc).abs() * self.scale * self.asymmetry.pow(s)
        coeff = self.scale.log() - (self.asymmetry +
                                    (1 / self.asymmetry)).log()
        return (coeff + exponent).sum(-1)

    def sample(self, batch_size):
        U = Uniform(low=-self.asymmetry,
                    high=(1. / self.asymmetry),
                    learnable=False).sample(batch_size)
        s = U.sign()
        log_term = (1. - U * s * self.asymmetry.pow(s)).log()
        return self.loc - (1. /
                           (self.scale * s * self.asymmetry.pow(s))) * log_term

    def cdf(self, value):
        s = (value - self.loc).sign()
        exponent = -(value -
                     self.loc).abs() * self.scale * self.asymmetry.pow(s)
        exponent = exponent.exp()
        return (value > self.loc).float() - s * self.asymmetry.pow(1 - s) / (
            1 + self.asymmetry.pow(2)) * exponent

    # def icdf(self, value):
    #     return

    def entropy(self):
        return (utils.e * (1 + self.asymmetry.pow(2)) /
                (self.asymmetry * self.scale)).log().sum()

    @property
    def expectation(self):
        return self.loc + ((1 - self.asymmetry.pow(2)) /
                           (self.scale * self.asymmetry))

    @property
    def variance(self):
        return (1 + self.asymmetry.pow(4)) / (self.scale.pow(2) *
                                              self.asymmetry.pow(2))

    @property
    def mode(self):
        return self.loc

    @property
    def median(self):
        return self.loc + (self.asymmetry / self.scale) * (
            (1 + self.asymmetry.pow(2)) / (2 * self.asymmetry.pow(2))).log()

    @property
    def skewness(self):
        return (2 *
                (1 - self.asymmetry.pow(6))) / (1 + self.asymmetry.pow(4)).pow(
                    3. / 2.)

    @property
    def kurtosis(self):
        return (6 *
                (1 + self.asymmetry.pow(8))) / (1 +
                                                self.asymmetry.pow(4)).pow(2)

    @property
    def scale(self):
        return softplus(self._scale)

    @property
    def asymmetry(self):
        return softplus(self._asymmetry)

    def get_parameters(self):
        if self.n_dims == 1:
            return {
                'loc': self.loc.item(),
                'scale': self.scale.item(),
                'asymmetry': self.asymmetry.item()
            }
        return {
            'loc': self.loc.detach().numpy(),
            'scale': self.scale.detach().numpy(),
            'asymmetry': self.asymmetry.detach().numpy()
        }
Example #15
class Laplace(Distribution):
    def __init__(self, loc=0., scale=1., learnable=True):
        super().__init__()
        if not isinstance(loc, torch.Tensor):
            loc = torch.tensor(loc).view(-1)
        self.n_dims = len(loc)
        if not isinstance(scale, torch.Tensor):
            scale = torch.tensor(scale).view(-1)
        self.loc = loc.float()
        self._scale = utils.softplus_inverse(scale.float())
        if learnable:
            self.loc = Parameter(self.loc)
            self._scale = Parameter(self._scale)

    def log_prob(self, value):
        return (-(2. * self.scale).log() -
                ((value - self.loc).abs() / self.scale)).sum(-1)

    def sample(self, batch_size):
        return dists.Laplace(self.loc, self.scale).rsample((batch_size, ))

    def cdf(self, value):
        return 0.5 - 0.5 * (value - self.loc).sign() * (
            -(value - self.loc).abs() / self.scale).expm1()

    def icdf(self, value):
        term = value - 0.5
        return self.loc - self.scale * term.sign() * (-2 * term.abs()).log1p()

    def entropy(self):
        return 1 + (2 * self.scale).log()

    def kl(self, other):
        if isinstance(other, Laplace):
            scale_ratio = self.scale / other.scale
            loc_abs_diff = (self.loc - other.loc).abs()
            t1 = -scale_ratio.log()
            t2 = loc_abs_diff / other.scale
            t3 = scale_ratio * (-loc_abs_diff / self.scale).exp()
            return (t1 + t2 + t3 - 1.).sum()
        return None

    @property
    def expectation(self):
        return self.loc

    @property
    def variance(self):
        return 2 * self.scale.pow(2)

    @property
    def median(self):
        return self.loc

    @property
    def stddev(self):
        return (2**0.5) * self.scale

    @property
    def mode(self):
        return self.loc

    @property
    def skewness(self):
        return torch.tensor(0.).float()

    @property
    def kurtosis(self):
        return torch.tensor(3.).float()

    @property
    def scale(self):
        return softplus(self._scale)

    def get_parameters(self):
        if self.n_dims == 1:
            return {'loc': self.loc.item(), 'scale': self.scale.item()}
        return {
            'loc': self.loc.detach().numpy(),
            'scale': self.scale.detach().numpy()
        }
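
A standalone check (torch only) that the cdf and icdf above are mutual inverses for loc = 0 and scale = 1:

import torch

# Laplace(0, 1): u = cdf(x), then icdf(u) should recover x.
x = torch.linspace(-3., 3., 7)
u = 0.5 - 0.5 * x.sign() * (-x.abs()).expm1()      # cdf
t = u - 0.5
x_rec = -t.sign() * (-2 * t.abs()).log1p()          # icdf
print(torch.allclose(x, x_rec, atol=1e-5))          # True
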
Example #16
class StudentT(Distribution):
    def __init__(self, df=1., loc=0., scale=1., learnable=True):
        super().__init__()
        if not isinstance(loc, torch.Tensor):
            loc = torch.tensor(loc).view(-1)
        self.n_dims = len(loc)
        if not isinstance(scale, torch.Tensor):
            scale = torch.tensor(scale).view(-1)
        if not isinstance(df, torch.Tensor):
            df = torch.tensor(df).view(-1)
        self.loc = loc.float()
        self._scale = utils.softplus_inverse(scale.float())
        self._df = utils.softplus_inverse(df.float())
        if learnable:
            self.loc = Parameter(self.loc)
            self._scale = Parameter(self._scale)
            self._df = Parameter(self._df)

    def log_prob(self, value):
        model = dists.StudentT(self.df, self.loc, self.scale)
        return model.log_prob(value).sum(-1)

    def sample(self, batch_size):
        model = dists.StudentT(self.df, self.loc, self.scale)
        return model.rsample((batch_size, ))

    def entropy(self):
        return dists.StudentT(self.df, self.loc, self.scale).entropy()

    @property
    def expectation(self):
        return dists.StudentT(self.df, self.loc, self.scale).mean

    @property
    def mode(self):
        return self.expectation

    @property
    def variance(self):
        return dists.StudentT(self.df, self.loc, self.scale).variance

    @property
    def scale(self):
        return softplus(self._scale)

    @property
    def df(self):
        return softplus(self._df)

    def get_parameters(self):
        if self.n_dims == 1:
            return {
                'loc': self.loc.item(),
                'scale': self.scale.item(),
                'df': self.df.item()
            }
        return {
            'loc': self.loc.detach().numpy(),
            'scale': self.scale.detach().numpy(),
            'df': self.df.detach().numpy()
        }
Example #17
class bfcp(Module):
    def __init__(self, size, rank, alpha=1, beta=0.2, c=1, d=1e6, init='unif'):
        super(bfcp, self).__init__()
        self.size = size
        self.rank = rank
        self.dim = len(size)
        self.lamb = Parameter(torch.Tensor(self.rank))
        self.tau = Parameter(torch.tensor(1.0))
        self.alpha = Parameter(torch.tensor(alpha), requires_grad=False)
        self.beta = Parameter(torch.tensor(beta), requires_grad=False)
        self.c = Parameter(torch.tensor(c), requires_grad=False)
        self.d = Parameter(torch.tensor(d), requires_grad=False)
        #        self.lamb = torch.Tensor(self.rank)

        self.factors = ParameterList(
            [Parameter(torch.Tensor(s, rank)) for s in size])
        self.reset_parameters(init)

    def reset_parameters(self, initms='unif'):
        #        init.uniform_(self.lamb)
        init.constant_(self.lamb, 1)
        for f in self.factors:
            if initms == 'unif':
                init.uniform_(f)
            elif initms == 'normal':
                init.normal_(f)
            else:
                raise NotImplementedError

    # @weak_script_method

    def forward(self, input):
        #input: indexes, batch*dim, torch.long
        vsz = input.shape[0]
        vals = torch.zeros(vsz, device=input.device)
        for j in range(self.rank):
            #            tmpvals = self.lamb[j] * torch.ones(vsz)
            tmpvals = torch.ones(vsz, device=input.device)
            for k in range(self.dim):
                akvals = self.factors[k][input[:, k], j]
                tmpvals = tmpvals * akvals

            vals = vals + tmpvals
        return vals

    def prior_theta(self):
        self.lamb.data.clamp_min_(1e-6)

        ret = 0
        for f in self.factors:
            ret += torch.sum(torch.sum(f**2, dim=0) / self.lamb / 2)
        ret += torch.sum(torch.log(self.lamb)) * sum(self.size) / 2

        #        for f in self.factors:
        #            ret += torch.sum(torch.sum(f**2, dim=0) * self.lamb)
        #        ret -= torch.sum(torch.log(self.lamb))*sum(self.size)/ 2
        """inverse gamma distribution"""
        ret += torch.sum(self.beta / self.lamb)
        ret += (self.alpha + 1) * torch.sum(torch.log(self.lamb))

        return ret

    def prior_tau_exp(self):
        return -self.c * self.tau + torch.exp(self.tau) / self.d

    def prior_theta_exp(self):

        ret = 0
        for f in self.factors:
            ret += torch.sum(torch.sum(f**2, dim=0) * torch.exp(self.lamb))
        ret -= sum(self.size) * torch.sum(self.lamb) / 2

        ret -= self.alpha * torch.sum(self.lamb)
        ret += torch.sum(torch.exp(self.lamb)) / self.beta

        return ret


#    def prior_tau(self):
#        self.tau.data.clamp_min_(1e-6)
#
#        return (self.c+1) * torch.log(self.tau) + self.d / self.tau

    def extra_repr(self):
        return 'size={}, rank={}, alpha={}, beta={}, c={}, d={}'.format(
            self.size, self.rank, self.alpha.item(), self.beta.item(),
            self.c.item(), self.d.item())

    def full(self):
        return tensorly.kruskal_to_tensor(
            (torch.ones(self.rank, device=self.lamb.device), self.factors))
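
The forward pass above reads individual entries of a rank-R CP (Kruskal) tensor with unit weights. A small standalone illustration (torch only; the shapes and indices below are hypothetical) comparing the entry-wise computation with a dense reconstruction:

import torch

# Rank-3 CP factors for a 4 x 5 x 6 tensor, and two entries to read.
factors = [torch.randn(4, 3), torch.randn(5, 3), torch.randn(6, 3)]
idx = torch.tensor([[0, 1, 2], [3, 4, 5]])

# Dense reconstruction, then direct lookup of the requested entries.
full = torch.einsum('ir,jr,kr->ijk', *factors)
dense_vals = torch.stack([full[tuple(i.tolist())] for i in idx])

# Entry-wise, as in forward(): product over modes, summed over the rank.
sparse_vals = (factors[0][idx[:, 0]] * factors[1][idx[:, 1]] * factors[2][idx[:, 2]]).sum(-1)
print(torch.allclose(dense_vals, sparse_vals))  # True
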
Example #18
class bftucker(Module):
    def __init__(self,
                 size,
                 rank,
                 alpha=1,
                 beta=0.2,
                 c=1,
                 d=1e6,
                 e=1,
                 init='unif'):
        super(bftucker, self).__init__()
        self.size = size
        self.rank = rank
        self.dim = len(size)

        self.tau = Parameter(torch.tensor(1.0))
        self.alpha = Parameter(torch.tensor(alpha), requires_grad=False)
        self.beta = Parameter(torch.tensor(beta), requires_grad=False)
        self.c = Parameter(torch.tensor(c), requires_grad=False)
        self.d = Parameter(torch.tensor(d), requires_grad=False)
        self.e = Parameter(torch.tensor(e), requires_grad=False)
        #        self.lamb = torch.Tensor(self.rank)

        self.lamb = ParameterList([Parameter(torch.Tensor(r)) for r in rank])
        self.factors = ParameterList(
            [Parameter(torch.Tensor(s, r)) for (s, r) in zip(size, rank)])
        self.core = Parameter(torch.zeros(rank))
        self.reset_parameters(init)

    def reset_parameters(self, initms='unif'):
        #        init.uniform_(self.lamb)
        for l in self.lamb:
            init.constant_(l, 1)
        for f in self.factors:
            if initms == 'unif':
                init.uniform_(f)
            elif initms == 'normal':
                init.normal_(f)
            else:
                raise NotImplementedError
        if initms == 'unif':
            init.uniform_(self.core)
        elif initms == 'normal':
            init.normal_(self.core)

    # @weak_script_method

    def forward(self, input):
        #input: indexes, batch*dim, torch.long
        vals = torch.zeros(input.shape[0]).to(input.device)
        for b in range(input.shape[0]):
            inputd = input[b]
            factors = [
                self.factors[i][inputd[i], :].reshape((1, -1))
                for i in range(input.shape[1])
            ]
            vals[b] = tensorly.tucker_to_vec((self.core, factors))
        return vals

    def prior_theta(self):
        for l in self.lamb:
            l.data.clamp_min_(1e-6)

        ret = 0
        for f, l, s in zip(self.factors, self.lamb, self.size):
            # Gaussian prior on each factor: quadratic term plus the
            # per-mode log-determinant of the column variances.
            ret += torch.sum(torch.sum(f**2, dim=0) / l / 2)
            ret += torch.sum(torch.log(l)) * s / 2

        core2 = self.core**2
        for d, l in enumerate(list(self.lamb)):
            s = [1] * len(self.rank)
            s[d] = -1
            l = l.reshape(s)
            core2 = core2 / l
            ret -= core2.numel() / l.numel() * torch.sum(l) / 2


        # core2 = self.core ** 2
        ret += torch.sum(core2) * self.e / 2
        #        for f in self.factors:
        #            ret += torch.sum(torch.sum(f**2, dim=0) * self.lamb)
        #        ret -= torch.sum(torch.log(self.lamb))*sum(self.size)/ 2

        for l in self.lamb:
            ret += torch.sum(self.beta / l)
            ret += (self.alpha + 1) * torch.sum(torch.log(l))

        return ret

    def prior_tau_exp(self):
        return -self.c * self.tau + torch.exp(self.tau) / self.d

    def extra_repr(self):
        return 'size={}, rank={}, alpha={}, beta={}, c={}, d={}'.format(
            self.size, self.rank, self.alpha.item(), self.beta.item(),
            self.c.item(), self.d.item())

    def full(self):
        return tensorly.kruskal_to_tensor(
            (torch.ones(self.rank, device=self.lamb.device), self.factors))
Example #19
class TTlinear(nn.Module):
    def __init__(self, in_size, out_size, rank, alpha=1, beta=0.1, **kwargs):
        # increase beta to decrease rank
        super(TTlinear, self).__init__()
        assert (len(in_size) == len(out_size))
        assert (len(rank) == len(in_size) - 1)
        self.in_size = list(in_size)
        self.out_size = list(out_size)
        self.rank = list(rank)
        self.factors = ParameterList()
        r1 = [1] + list(rank)
        r2 = list(rank) + [1]
        for ri, ro, si, so in zip(r1, r2, in_size, out_size):
            p = Parameter(torch.Tensor(ri, si, so, ro))
            self.factors.append(p)
        self.bias = Parameter(torch.Tensor(np.prod(out_size)))
        self.lamb = ParameterList([Parameter(torch.ones(r)) for r in rank])
        self.alpha = Parameter(torch.tensor(alpha), requires_grad=False)
        self.beta = Parameter(torch.tensor(beta), requires_grad=False)

        self._initialize_weights()

    def forward(self, x):
        def mode2_dot(tensor, matrix, mode):
            ms = matrix.shape
            matrix = matrix.reshape(ms[0] * ms[1], ms[2] * ms[3])

            sp = list(tensor.shape)
            sp[mode:mode + 2] = [sp[mode] * sp[mode + 1], 1]

            sn = list(tensor.shape)
            sn[mode:mode + 2] = ms[2:4]

            tensor = tensor.reshape(sp)
            tensor = tl.tenalg.mode_dot(tensor, matrix.t(), mode)
            return tensor.reshape(sn)

        x = x.reshape((x.shape[0], 1, *self.in_size))
        for (i, f) in enumerate(self.factors):
            x = mode2_dot(x, f, i + 1)
        x = x.reshape((x.shape[0], -1))
        x = x + self.bias
        return x

    def _initialize_weights(self):
        for f in self.factors:
            nn.init.kaiming_uniform_(f)
        nn.init.constant_(self.bias, 0)

    def regularizer(self, exp=True):
        ret = 0
        if exp:
            for i in range(len(self.rank)):
                # ret += torch.sum(torch.sum(self.factors[i]**2, dim=[0, 1, 2])
                # * torch.exp(self.lamb[i]) / 2)
                ret -= np.prod(self.factors[i].shape[:-1]) \
                    * torch.sum(self.lamb[i]) / 2
                # ret += torch.sum(torch.sum(self.factors[i+1]**2, dim=[1, 2, 3])
                # * torch.exp(self.lamb[i] / 2))
                ret -= np.prod(self.factors[i+1].shape[1:]) \
                     * torch.sum(self.lamb[i]) / 2
                ret -= torch.sum(self.alpha * self.lamb[i])
                ret += torch.sum(torch.exp(self.lamb[i])) / self.beta

            for i in range(len(self.rank) + 1):
                m = torch.sum(self.factors[i]**2, dim=[1, 2])
                if i > 0:
                    m = m * torch.exp(self.lamb[i - 1]).reshape([-1, 1])
                if i < len(self.rank):
                    m = m * torch.exp(self.lamb[i]).reshape([1, -1])
                ret += torch.sum(m, dim=[0, 1]) / 2

        else:
            for i in range(len(self.rank)):
                self.lamb[i].data.clamp_min_(1e-6)
                ret += torch.sum(
                    torch.sum(self.factors[i]**2, dim=[0, 1, 2]) /
                    self.lamb[i] / 2)
                ret += np.prod(self.factors[i].shape[:-1]) \
                    * torch.sum(torch.log(self.lamb[i])) / 2
                ret += torch.sum(
                    torch.sum(self.factors[i + 1]**2, dim=[1, 2, 3]) /
                    self.lamb[i] / 2)
                ret += np.prod(self.factors[i+1].shape[1:]) \
                    * torch.sum(torch.log(self.lamb[i])) / 2

                ret += torch.sum(self.beta / self.lamb[i])
                ret += (self.alpha + 1) * torch.sum(torch.log(self.lamb[i]))

        return ret

    def get_lamb_ths(self, exp=True):
        if (exp):
            lamb_ths = [
                np.log((np.prod(self.factors[i].shape[:-1]) / 2 +
                        np.prod(self.factors[i + 1].shape[1:]) / 2 +
                        self.alpha.item()) * self.beta.item())
                for i in range(len(self.lamb))
            ]
        else:
            lamb_ths = [
                self.beta.item() /
                (np.prod(self.factors[i].shape[:-1]) / 2 +
                 np.prod(self.factors[i + 1].shape[1:]) / 2 +
                 self.alpha.item() + 1) for i in range(len(self.lamb))
            ]
        return lamb_ths
Example #20
class TTConv2d(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 rank,
                 stride=1,
                 padding=0,
                 dilation=1,
                 alpha=1,
                 beta=0.1,
                 **kwargs):
        # increase beta to decrease rank
        super(TTConv2d, self).__init__()
        assert (len(in_channels) == len(out_channels))
        assert (len(rank) == len(in_channels) - 1)
        self.in_channels = list(in_channels)
        self.out_channels = list(out_channels)
        self.rank = list(rank)
        self.factors = ParameterList()

        r1 = [1] + self.rank[:-1]
        r2 = self.rank
        for ri, ro, si, so in zip(r1, r2, in_channels[:-1], out_channels[:-1]):
            p = Parameter(torch.Tensor(ri, si, so, ro))
            self.factors.append(p)
        self.bias = Parameter(torch.Tensor(np.prod(out_channels)))

        self.conv = nn.Conv2d(in_channels=self.rank[-1] * in_channels[-1],
                              out_channels=out_channels[-1],
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              dilation=dilation,
                              bias=False)

        self.lamb = ParameterList([Parameter(torch.ones(r)) for r in rank])
        self.alpha = Parameter(torch.tensor(alpha), requires_grad=False)
        self.beta = Parameter(torch.tensor(beta), requires_grad=False)

        self._initialize_weights()

    def forward(self, x):
        def mode2_dot(tensor, matrix, mode):
            ms = matrix.shape
            matrix = matrix.reshape(ms[0] * ms[1], ms[2] * ms[3])

            sp = list(tensor.shape)
            sp[mode:mode + 2] = [sp[mode] * sp[mode + 1], 1]

            sn = list(tensor.shape)
            sn[mode:mode + 2] = ms[2:4]

            tensor = tensor.reshape(sp)
            tensor = tl.tenalg.mode_dot(tensor, matrix.t(), mode)
            return tensor.reshape(sn)

        (b, c, h, w) = x.shape
        x = x.reshape((x.shape[0], 1, *self.in_channels, h, w))
        for (i, f) in enumerate(self.factors):
            x = mode2_dot(x, f, i + 1)
        x = x.reshape((b * np.prod(self.out_channels[:-1]),
                       self.rank[-1] * self.in_channels[-1], h, w))
        x = self.conv(x)
        x = x.reshape((b, np.prod(self.out_channels), h, w))
        x = x + self.bias.reshape((1, -1, 1, 1))
        return x

    def _initialize_weights(self):
        for f in self.factors:
            nn.init.kaiming_uniform_(f)
        nn.init.constant_(self.bias, 0)

    def regularizer(self, exp=True):
        ret = 0
        if exp:
            for i in range(len(self.rank)):
                # ret += torch.sum(torch.sum(self.factors[i]**2, dim=[0, 1, 2])
                #     * torch.exp(self.lamb[i]) / 2)
                m = torch.sum(self.factors[i]**2, dim=[1, 2])
                if i > 0:
                    m = m * torch.exp(self.lamb[i-1]).reshape([-1, 1]) \
                        / np.exp(self.get_lamb_ths(exp)[i-1])
                m = m * torch.exp(self.lamb[i]).reshape([1, -1])
                ret += torch.sum(m, dim=[0, 1]) / 2
                ret -= np.prod(self.factors[i].shape[:-1]) \
                    * torch.sum(self.lamb[i]) / 2
                if i != len(self.rank) - 1:
                    # ret += torch.sum(torch.sum(self.factors[i+1]**2, dim=[1, 2, 3])
                    #     * torch.exp(self.lamb[i] / 2))
                    ret -= np.prod(self.factors[i+1].shape[1:]) \
                        * torch.sum(self.lamb[i]) / 2
                else:
                    w = self.conv.weight.transpose(0, 1)
                    w = w.reshape(self.rank[i], -1)
                    ret += torch.sum(
                        torch.sum(w**2, dim=1) * torch.exp(self.lamb[i]) / 2)
                    ret -= w.shape[1] * torch.sum(self.lamb[i]) / 2

                ret -= torch.sum(self.alpha * self.lamb[i])
                ret += torch.sum(torch.exp(self.lamb[i])) / self.beta

        else:
            for i in range(len(self.rank)):
                self.lamb[i].data.clamp_min_(1e-6)
                ret += torch.sum(
                    torch.sum(self.factors[i]**2, dim=[0, 1, 2]) /
                    self.lamb[i] / 2)
                ret += np.prod(self.factors[i].shape[:-1]) \
                    * torch.sum(torch.log(self.lamb[i])) / 2
                if i != len(self.rank) - 1:
                    ret += torch.sum(
                        torch.sum(self.factors[i + 1]**2, dim=[1, 2, 3]) /
                        self.lamb[i] / 2)
                    ret += np.prod(self.factors[i+1].shape[1:]) \
                        * torch.sum(torch.log(self.lamb[i])) / 2
                else:
                    w = self.conv.weight.transpose(0, 1)
                    w = w.reshape(self.rank[i], -1)
                    ret += torch.sum(torch.sum(w**2, dim=1) / self.lamb[i] / 2)
                    ret += w.shape[1] * torch.sum(torch.log(self.lamb[i])) / 2

                ret += torch.sum(self.beta / self.lamb[i])
                ret += (self.alpha + 1) * torch.sum(torch.log(self.lamb[i]))

        return ret

    def get_lamb_ths(self, exp=True):
        if (exp):
            lamb_ths = [
                np.log((np.prod(self.factors[i].shape[:-1]) / 2 +
                        np.prod(self.factors[i + 1].shape[1:]) / 2 +
                        self.alpha.item()) * self.beta.item())
                for i in range(len(self.lamb) - 1)
            ]
            lamb_ths.append(
                np.log(
                    (np.prod(self.factors[-1].shape[:-1]) / 2 +
                     (self.out_channels[-1] * self.in_channels[-1] *
                      self.conv.weight.shape[2] * self.conv.weight.shape[3]) /
                     2 + self.alpha.item()) * self.beta.item()))
        else:
            lamb_ths = [
                self.beta.item() /
                (np.prod(self.factors[i].shape[:-1]) / 2 +
                 np.prod(self.factors[i + 1].shape[1:]) / 2 +
                 self.alpha.item() + 1) for i in range(len(self.lamb) - 1)
            ]
            lamb_ths.append(
                self.beta.item() /
                (np.prod(self.factors[-1].shape[:-1]) / 2 +
                 (self.out_channels[-1] * self.in_channels[-1] *
                  self.conv.weight.shape[2] * self.conv.weight.shape[3]) / 2 +
                 self.alpha.item() + 1))
        return lamb_ths
Example #21
class InfiniteMixtureModel(Distribution):
    def __init__(self,
                 df,
                 loc,
                 scale,
                 loc_learnable=True,
                 scale_learnable=True,
                 df_learnable=True):
        super().__init__()
        self.loc = torch.tensor(loc).view(-1)
        self.n_dims = len(self.loc)
        if loc_learnable:
            self.loc = Parameter(self.loc)
        self._scale = utils.softplus_inverse(torch.tensor(scale).view(-1))
        if scale_learnable:
            self._scale = Parameter(self._scale)
        self._df = utils.softplus_inverse(torch.tensor(df).view(-1))
        if df_learnable:
            self._df = Parameter(self._df)

    def sample(self, batch_size, return_latents=False):
        weight_model = Gamma(self.df / 2, self.df / 2, learnable=False)
        latent_samples = weight_model.sample(batch_size)
        normal_model = Normal(self.loc.expand(batch_size),
                              (self.scale / latent_samples).squeeze(1),
                              learnable=False)
        if return_latents:
            return normal_model.sample(1).squeeze(0).unsqueeze(
                1), latent_samples
        else:
            return normal_model.sample(1).squeeze(0).unsqueeze(1)

    def log_prob(self, samples, latents=None):
        if latents is None:
            raise NotImplementedError(
                "InfiniteMixtureModel log_prob not implemented")
        weight_model = Gamma(self.df / 2, self.df / 2, learnable=False)
        normal_model = Normal(self.loc.expand(latents.size(0)),
                              (self.scale / latents).squeeze(1),
                              learnable=False)
        return normal_model.log_prob(samples) + weight_model.log_prob(latents)

    @property
    def scale(self):
        return softplus(self._scale)

    @property
    def df(self):
        return softplus(self._df)

    @property
    def has_latents(self):
        return True

    def get_parameters(self):
        if self.n_dims == 1:
            return {
                "loc": self.loc.item(),
                "scale": self.scale.item(),
                "df": self.df.item(),
            }
        return {
            "loc": self.loc.detach().numpy(),
            "scale": self.scale.detach().numpy(),
            "df": self.df.detach().numpy(),
        }