class Linear(nn.Module):
    """Bayesian fully-connected layer with Normal posteriors over its parameters.

    Each forward pass draws one reparametrised sample of the weight (and of the
    bias, when enabled) and applies ``F.linear``.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: if truthy, a Normal posterior over a ``(out_features,)`` bias is
            also created; otherwise ``bias_rsampler`` is ``None``.
    """

    def __init__(self, in_features, out_features, bias):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight_rsampler = NormalReparametrizedSample(
            batch_shape=torch.Size([out_features, in_features]))
        self.bias_rsampler = NormalReparametrizedSample(
            batch_shape=torch.Size([out_features])) if bias else None

    # TODO : ELBO optimization initialization should be given

    def forward(self, input):
        """Apply the layer with a single posterior sample of its parameters."""
        # The samplers are called with 1 and the first entry is taken.
        # NOTE(review): assumes the sampler stacks draws along dim 0 — confirm.
        weight = self.weight_rsampler(1)[0]
        bias = None
        if self.bias_rsampler is not None:
            bias = self.bias_rsampler(1)[0]
        return F.linear(input, weight, bias)

    @staticmethod
    def _kl_to_standard_normal(rsampler):
        """Summed ``KL_Normal`` from ``rsampler``'s posterior to a 0/1 prior."""
        ones = torch.ones_like(rsampler.mu)
        # Second argument is softplus(inv-std) squared, presumably the posterior
        # variance; the prior is passed as zeros (mean) and ones.
        # NOTE(review): confirm KL_Normal's (mean, var, prior_mean, prior_var) order.
        return KL_Normal(rsampler.mu,
                         softplus(rsampler.softplus_inv_std) ** 2,
                         torch.zeros_like(ones), ones).sum()

    def kl_divergence(self):
        """Closed-form KL of the weight (and bias, if any) posterior to the prior.

        Returns:
            A scalar tensor: the element-wise ``KL_Normal`` terms summed over the
            weight and, when the layer has a bias, over the bias as well.
        """
        kld = self._kl_to_standard_normal(self.weight_rsampler)
        if self.bias_rsampler is not None:
            kld = kld + self._kl_to_standard_normal(self.bias_rsampler)
        return kld

    def sample_kld(self):
        """Sum of each sampler's ``sample_kld()`` output.

        The bias term is skipped when the layer was built without a bias
        (``bias_rsampler is None``), matching ``forward`` and ``kl_divergence``.
        """
        kld = self.weight_rsampler.sample_kld().sum()
        if self.bias_rsampler is not None:
            kld = kld + self.bias_rsampler.sample_kld().sum()
        return kld
 def __init__(self, in_features, out_features, bias):
     """Create Normal reparametrised samplers for the weight and optional bias.

     The weight sampler has batch shape ``(out_features, in_features)``. The
     bias sampler, built only when ``bias`` is truthy, has batch shape
     ``(out_features,)``; otherwise ``bias_rsampler`` is ``None``.
     """
     super(Linear, self).__init__()
     self.in_features, self.out_features = in_features, out_features
     self.weight_rsampler = NormalReparametrizedSample(
         batch_shape=torch.Size((out_features, in_features)))
     if not bias:
         self.bias_rsampler = None
     else:
         self.bias_rsampler = NormalReparametrizedSample(
             batch_shape=torch.Size((out_features,)))
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, prior=None, with_global=True):
        """Build row- and column-wise radial posteriors for a conv weight.

        The weight has shape ``(out_channels, in_channels // groups, *kernel_size)``.
        It is viewed both as ``out_channels`` rows of length
        ``(in_channels // groups) * prod(kernel_size)`` and as
        ``in_channels // groups`` columns of length
        ``out_channels * prod(kernel_size)``.

        When both channel counts exceed 1, each view gets:

        * a von Mises-Fisher direction sampler;
        * two log-normal radius samplers;
        * if ``with_global``, two scalar log-normal global-scale samplers.

        When exactly one of the channel counts is 1, a single Normal sampler
        is used instead.

        Args:
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups: convolution hyper-parameters; ``kernel_size``, ``stride``,
                ``padding`` and ``dilation`` are normalised with ``_pair``.
            bias: if truthy, register ``bias_shape_ones`` and create a Normal
                ``bias_rsampler`` of shape ``(out_channels,)``; else ``None``.
            prior: mapping with ``'direction'`` -> ``('vMF', param)`` and
                ``'radius'`` -> ``(type, param)``. Required despite the ``None``
                default; the radius ``param`` must support ``.copy()``.
            with_global: whether to create the global-scale samplers.

        Raises:
            ValueError: if ``prior`` is ``None`` or its direction prior is not
                ``'vMF'``.
        """
        # Validate with real exceptions: ``assert`` is stripped under ``python -O``,
        # and a missing prior would otherwise surface as an opaque TypeError.
        if prior is None:
            raise ValueError("prior must be provided, got None")
        if prior['direction'][0] != 'vMF':
            raise ValueError("only a 'vMF' direction prior is supported, got %r"
                             % (prior['direction'][0],))

        super(DoubleRadialConv2dFlatten, self).__init__()
        self.with_global = with_global
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = groups

        self.weight_size = torch.Size([out_channels, in_channels // groups]) + self.kernel_size
        self.weight_size_col = torch.Size([in_channels // groups, out_channels]) + self.kernel_size
        # Number of elements in one kernel window (product of kernel dims).
        kernel_numel = reduce(mul, self.kernel_size, 1)
        self.row_batch_shape = torch.Size([self.weight_size[0]])
        self.row_event_shape = torch.Size([self.weight_size[1] * kernel_numel])
        self.col_batch_shape = torch.Size([self.weight_size[1]])
        self.col_event_shape = torch.Size([self.weight_size[0] * kernel_numel])

        self.register_buffer('row_batch_ones', torch.ones(self.row_batch_shape))
        self.register_buffer('col_batch_ones', torch.ones(self.col_batch_shape))
        if bias:
            self.register_buffer('bias_shape_ones', torch.ones([out_channels]))

        self.row_direction_prior_param = prior['direction'][1]
        self.row_radius_prior_type = prior['radius'][0]
        # Separate copies so the row and column radius priors do not alias each
        # other (or the caller's ``prior``); mutating one leaves the other intact.
        self.row_radius_prior_param = prior['radius'][1].copy()
        self.col_direction_prior_param = prior['direction'][1]
        self.col_radius_prior_type = prior['radius'][0]
        self.col_radius_prior_param = prior['radius'][1].copy()

        # NOTE(review): when in_channels == out_channels == 1 none of the branches
        # below runs, so no weight sampler is created — confirm forward handles it.
        if in_channels > 1 and out_channels > 1:
            self.row_direction_rsampler = VonMisesFisherReparametrizedSample(batch_shape=self.row_batch_shape, event_shape=self.row_event_shape)
            if self.with_global:
                self.row_global_scale_rsampler = LognormalReparametrizedSample(batch_shape=torch.Size([1]))
                self.row_global_scale_rsampler1 = LognormalReparametrizedSample(batch_shape=torch.Size([1]))
            self.row_radius_rsampler = LognormalReparametrizedSample(batch_shape=self.row_batch_shape)
            self.row_radius_rsampler1 = LognormalReparametrizedSample(batch_shape=self.row_batch_shape)

            self.col_direction_rsampler = VonMisesFisherReparametrizedSample(batch_shape=self.col_batch_shape, event_shape=self.col_event_shape)
            if self.with_global:
                self.col_global_scale_rsampler = LognormalReparametrizedSample(batch_shape=torch.Size([1]))
                self.col_global_scale_rsampler1 = LognormalReparametrizedSample(batch_shape=torch.Size([1]))
            self.col_radius_rsampler = LognormalReparametrizedSample(batch_shape=self.col_batch_shape)
            self.col_radius_rsampler1 = LognormalReparametrizedSample(batch_shape=self.col_batch_shape)
        elif in_channels == 1 and out_channels > 1:
            # A single column: one Normal per element of the (out * kernel) vector.
            self.col_rsampler = NormalReparametrizedSample(batch_shape=self.col_event_shape)
        elif out_channels == 1 and in_channels > 1:
            # A single row: one Normal per element of the (in * kernel) vector.
            self.row_rsampler = NormalReparametrizedSample(batch_shape=self.row_event_shape)
        self.bias_rsampler = NormalReparametrizedSample(batch_shape=torch.Size([out_channels])) if bias else None
Example #4
0
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 prior=None,
                 with_global=True):
        """Build a radial (direction x radius) posterior over a conv weight.

        Creates the following samplers:

        * a von Mises-Fisher direction sampler;
        * two log-normal radius samplers;
        * when ``self.with_global`` is truthy, two scalar (batch shape
          ``(1,)``) log-normal global-scale samplers;
        * when ``bias`` is truthy, a Normal bias sampler of shape
          ``(out_channels,)``.

        Args:
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups: convolution hyper-parameters; ``kernel_size``, ``stride``,
                ``padding`` and ``dilation`` are normalised with ``_pair``
                before being passed to the parent constructor.
            bias: if truthy, create ``bias_rsampler``; otherwise it is ``None``.
            prior: mapping whose ``'direction'`` entry is ``('vMF', param)``
                and whose ``'radius'`` entry is ``(type, param)``. The ``None``
                default is not usable here: subscripting it raises
                ``TypeError`` (unless the parent constructor rejects it first).
            with_global: forwarded to the parent constructor.
        """
        # NOTE(review): assigned before the parent constructor runs; presumably
        # the parent sets these as well — confirm whether this is still needed.
        self.in_channels = in_channels
        self.out_channels = out_channels
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super(RadialConv2d,
              self).__init__(in_channels, out_channels, kernel_size, stride,
                             padding, dilation, groups, bias, prior,
                             with_global)

        # Only a von Mises-Fisher direction prior is supported. Note that
        # ``assert`` is skipped under ``python -O``.
        assert prior['direction'][0] == 'vMF'
        self.direction_prior_param = prior['direction'][1]
        self.radius_prior_type = prior['radius'][0]
        self.radius_prior_param = prior['radius'][1]

        # self.batch_shape, self.event_shape and self.with_global are not set in
        # this method; presumably the parent constructor defines them — TODO confirm.
        self.direction_rsampler = VonMisesFisherReparametrizedSample(
            batch_shape=self.batch_shape, event_shape=self.event_shape)
        if self.with_global:
            # NOTE(review): the role of the second ("...1") sampler is not
            # visible in this method.
            self.global_scale_rsampler = LognormalReparametrizedSample(
                batch_shape=torch.Size([1]))
            self.global_scale_rsampler1 = LognormalReparametrizedSample(
                batch_shape=torch.Size([1]))
        self.radius_rsampler = LognormalReparametrizedSample(
            batch_shape=self.batch_shape)
        self.radius_rsampler1 = LognormalReparametrizedSample(
            batch_shape=self.batch_shape)
        self.bias_rsampler = NormalReparametrizedSample(
            batch_shape=torch.Size([out_channels])) if bias else None
    def __init__(self,
                 in_features,
                 out_features,
                 bias,
                 prior=None,
                 with_global=True):
        """Radial Bayesian linear layer: vMF direction times log-normal radius.

        The direction sampler has batch shape ``(in_features,)`` and event
        shape ``(out_features,)``. The two radius samplers share that batch
        shape. With ``with_global``, two scalar (batch shape ``(1,)``)
        log-normal global-scale samplers are added.

        Args:
            in_features: size of each input sample.
            out_features: size of each output sample.
            bias: if truthy, register ``bias_ones`` and build a Normal
                ``bias_rsampler`` of shape ``(out_features,)``; else ``None``.
            prior: mapping with ``'direction'`` -> ``('vMF', param)`` and
                ``'radius'`` -> ``(type, param)``; must not be ``None``.
            with_global: whether to create the global-scale samplers.
        """
        super(RadialLinear, self).__init__()
        self.with_global = with_global
        self.in_features = in_features
        self.out_features = out_features
        self.batch_shape = torch.Size([in_features])
        direction_event = torch.Size([out_features])

        # Buffer registration order is preserved (it determines state_dict order).
        self.register_buffer('batch_ones', torch.ones(self.batch_shape))
        if bias:
            self.register_buffer('bias_ones', torch.ones([out_features]))

        direction_prior = prior['direction']
        assert direction_prior[0] == 'vMF'
        self.direction_prior_param = direction_prior[1]
        radius_prior = prior['radius']
        self.radius_prior_type = radius_prior[0]
        self.radius_prior_param = radius_prior[1]

        # Submodules are attached in the same order as before so that
        # parameter iteration order is unchanged.
        self.direction_rsampler = VonMisesFisherReparametrizedSample(
            batch_shape=self.batch_shape, event_shape=direction_event)
        if self.with_global:
            for attr in ('global_scale_rsampler', 'global_scale_rsampler1'):
                setattr(self, attr,
                        LognormalReparametrizedSample(batch_shape=torch.Size([1])))
        for attr in ('radius_rsampler', 'radius_rsampler1'):
            setattr(self, attr,
                    LognormalReparametrizedSample(batch_shape=self.batch_shape))
        if bias:
            self.bias_rsampler = NormalReparametrizedSample(
                batch_shape=torch.Size([out_features]))
        else:
            self.bias_rsampler = None
    def __init__(self,
                 in_features,
                 out_features,
                 bias,
                 prior=None,
                 with_global=True):
        """Radial Bayesian linear layer factorised over rows and columns.

        Rows have batch shape ``(out_features,)`` and event shape
        ``(in_features,)``; columns have batch shape ``(in_features,)`` and
        event shape ``(out_features,)``.

        When both sizes exceed 1, each side gets:

        * a vMF direction sampler;
        * two log-normal radius samplers;
        * if ``with_global``, two scalar log-normal global-scale samplers.

        If ``in_features == 1`` (and ``out_features > 1``) a Normal
        ``row_rsampler`` is used instead. If ``out_features == 1`` (and
        ``in_features > 1``) a Normal ``col_rsampler`` is used instead.

        Args:
            in_features: size of each input sample.
            out_features: size of each output sample.
            bias: if truthy, register ``bias_shape_ones`` and build a Normal
                ``bias_rsampler`` of shape ``(out_features,)``; else ``None``.
            prior: mapping with ``'direction'`` -> ``('vMF', param)`` and
                ``'radius'`` -> ``(type, param)``; the radius ``param`` must
                support ``.copy()``. Must not be ``None``.
            with_global: whether to create the global-scale samplers.
        """
        super(DoubleRadialLinear, self).__init__()
        self.with_global = with_global
        self.in_features = in_features
        self.out_features = out_features
        out_size = torch.Size([out_features])
        in_size = torch.Size([in_features])
        self.row_batch_shape, self.row_event_shape = out_size, in_size
        self.col_batch_shape, self.col_event_shape = in_size, out_size

        sides = ('row', 'col')
        # Buffers are registered row first, then col, then bias (state_dict order).
        for side in sides:
            self.register_buffer(side + '_batch_ones',
                                 torch.ones(getattr(self, side + '_batch_shape')))
        if bias:
            self.register_buffer('bias_shape_ones', torch.ones([out_features]))

        assert prior['direction'][0] == 'vMF'
        for side in sides:
            setattr(self, side + '_direction_prior_param', prior['direction'][1])
            setattr(self, side + '_radius_prior_type', prior['radius'][0])
            # Each side gets its own copy, so mutating one side's radius prior
            # affects neither the other side nor the caller's ``prior``.
            setattr(self, side + '_radius_prior_param', prior['radius'][1].copy())

        if in_features > 1 and out_features > 1:
            # Attach row samplers, then col samplers, in the original order so
            # submodule / parameter iteration order is unchanged.
            for side in sides:
                side_batch = getattr(self, side + '_batch_shape')
                side_event = getattr(self, side + '_event_shape')
                setattr(self, side + '_direction_rsampler',
                        VonMisesFisherReparametrizedSample(
                            batch_shape=side_batch, event_shape=side_event))
                if self.with_global:
                    for suffix in ('', '1'):
                        setattr(self, side + '_global_scale_rsampler' + suffix,
                                LognormalReparametrizedSample(
                                    batch_shape=torch.Size([1])))
                for suffix in ('', '1'):
                    setattr(self, side + '_radius_rsampler' + suffix,
                            LognormalReparametrizedSample(batch_shape=side_batch))
        elif in_features == 1 and out_features > 1:
            self.row_rsampler = NormalReparametrizedSample(
                batch_shape=self.row_batch_shape)
        elif out_features == 1 and in_features > 1:
            self.col_rsampler = NormalReparametrizedSample(
                batch_shape=self.col_batch_shape)

        if bias:
            self.bias_rsampler = NormalReparametrizedSample(
                batch_shape=torch.Size([out_features]))
        else:
            self.bias_rsampler = None