class Linear(nn.Module):
    """Bayesian fully-connected layer with a factorized Normal posterior.

    Every forward pass draws a weight matrix (and a bias vector, if enabled)
    from reparametrized Normal samplers. ``kl_divergence`` returns the
    closed-form KL to a zero-mean, unit-scale Normal prior and
    ``sample_kld`` sums the samplers' own ``sample_kld`` estimates.
    """

    def __init__(self, in_features, out_features, bias):
        """
        Args:
            in_features: number of input features (second weight dimension).
            out_features: number of output features (first weight dimension).
            bias: if truthy, also learn a Normal posterior over an
                ``out_features``-sized bias; otherwise the layer has no bias
                and ``self.bias_rsampler`` is ``None``.
        """
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight_rsampler = NormalReparametrizedSample(
            batch_shape=torch.Size([out_features, in_features]))
        self.bias_rsampler = NormalReparametrizedSample(
            batch_shape=torch.Size([out_features])) if bias else None
        # TODO : ELBO optimization initialization should be given

    def forward(self, input):
        """Apply ``F.linear`` with a freshly sampled weight (and bias)."""
        # NOTE(review): rsampler(1) presumably draws a batch of one sample
        # and [0] unwraps it -- confirm against NormalReparametrizedSample.
        weight = self.weight_rsampler(1)[0]
        bias = (self.bias_rsampler(1)[0]
                if self.bias_rsampler is not None else None)
        return F.linear(input, weight, bias)

    @staticmethod
    def _kl_to_standard_normal(rsampler):
        """Return the summed KL between ``rsampler``'s Normal and N(0, 1)."""
        mu = rsampler.mu
        # softplus(softplus_inv_std) is the std; it is squared before being
        # handed to KL_Normal, so KL_Normal presumably takes variances.
        var = softplus(rsampler.softplus_inv_std) ** 2
        return KL_Normal(mu, var,
                         torch.zeros_like(mu), torch.ones_like(mu)).sum()

    def kl_divergence(self):
        """Closed-form KL of the weight (and bias) posteriors to the prior.

        Returns:
            A scalar tensor: the element-wise KL terms summed over the weight
            and, when present, the bias.
        """
        kld = self._kl_to_standard_normal(self.weight_rsampler)
        if self.bias_rsampler is not None:
            kld = kld + self._kl_to_standard_normal(self.bias_rsampler)
        return kld

    def sample_kld(self):
        """Sum of the samplers' ``sample_kld()`` terms.

        The bias term is only included when the layer has a bias; a layer
        built with ``bias=False`` previously crashed here because
        ``bias_rsampler`` is ``None``.
        """
        kld = self.weight_rsampler.sample_kld().sum()
        if self.bias_rsampler is not None:
            kld = kld + self.bias_rsampler.sample_kld().sum()
        return kld
def __init__(self, in_features, out_features, bias):
    """Create the Normal reparametrized samplers of a Bayesian linear layer.

    Args:
        in_features: number of input features (second weight dimension).
        out_features: number of output features (first weight dimension).
        bias: if truthy, also create an ``out_features``-sized bias sampler;
            otherwise ``self.bias_rsampler`` is ``None``.
    """
    super(Linear, self).__init__()
    self.in_features = in_features
    self.out_features = out_features

    weight_shape = torch.Size([out_features, in_features])
    self.weight_rsampler = NormalReparametrizedSample(batch_shape=weight_shape)

    if bias:
        bias_sampler = NormalReparametrizedSample(
            batch_shape=torch.Size([out_features]))
    else:
        bias_sampler = None
    self.bias_rsampler = bias_sampler
    # TODO : ELBO optimization initialization should be given
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
             dilation=1, groups=1, bias=True, prior=None, with_global=True):
    """Set up row- and column-wise radial posteriors over a conv kernel.

    The kernel has shape ``(out_channels, in_channels // groups, kh, kw)``
    (``self.weight_size``). It is described in two ways:

    - rows: ``out_channels`` of them, each of size
      ``(in_channels // groups) * kh * kw``;
    - columns: ``in_channels // groups`` of them, each of size
      ``out_channels * kh * kw``.

    When both channel counts exceed 1, each view gets a von Mises-Fisher
    direction sampler and two Lognormal radius samplers, plus two global-scale
    samplers when ``with_global`` is set. When exactly one channel count is 1,
    a single Normal sampler is created instead.

    Args:
        in_channels, out_channels, groups: channel configuration.
        kernel_size, stride, padding, dilation: conv geometry; each is
            expanded with ``_pair`` and stored on the layer.
        bias: if truthy, register ``bias_shape_ones`` and create an
            ``out_channels``-sized Normal bias sampler.
        prior: indexed as ``prior['direction'][0]`` (must be ``'vMF'``),
            ``prior['direction'][1]`` (direction param),
            ``prior['radius'][0]`` (radius type) and
            ``prior['radius'][1]`` (radius param). Required despite the
            ``None`` default.
        with_global: if truthy, also create the global-scale samplers.

    Raises:
        ValueError: if ``prior`` is ``None`` or its direction prior is not
            ``'vMF'``.
    """
    import copy

    # Explicit checks instead of `assert` (stripped under -O) and instead of
    # an opaque TypeError when the default prior=None is used.
    if prior is None:
        raise ValueError("DoubleRadialConv2dFlatten requires a `prior`, got None")
    if prior['direction'][0] != 'vMF':
        raise ValueError("DoubleRadialConv2dFlatten only supports a 'vMF' "
                         "direction prior, got %r" % (prior['direction'][0],))

    super(DoubleRadialConv2dFlatten, self).__init__()
    self.with_global = with_global
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.kernel_size = _pair(kernel_size)
    self.stride = _pair(stride)
    self.padding = _pair(padding)
    self.dilation = _pair(dilation)
    self.groups = groups
    # Kernel shape, and the same shape with its first two dims swapped.
    self.weight_size = (torch.Size([out_channels, in_channels // groups])
                        + self.kernel_size)
    self.weight_size_col = (torch.Size([in_channels // groups, out_channels])
                            + self.kernel_size)
    kernel_numel = reduce(mul, self.kernel_size, 1)  # kh * kw

    # Row view: out_channels rows of size (in_channels // groups) * kh * kw.
    self.row_batch_shape = torch.Size([self.weight_size[0]])
    self.row_event_shape = torch.Size([self.weight_size[1] * kernel_numel])
    # Column view: (in_channels // groups) columns of size out_channels * kh * kw.
    self.col_batch_shape = torch.Size([self.weight_size[1]])
    self.col_event_shape = torch.Size([self.weight_size[0] * kernel_numel])
    self.register_buffer('row_batch_ones', torch.ones(self.row_batch_shape))
    self.register_buffer('col_batch_ones', torch.ones(self.col_batch_shape))
    if bias:
        self.register_buffer('bias_shape_ones', torch.ones([out_channels]))

    # The radius prior params are copied per axis (as DoubleRadialLinear
    # does). Without the copy, the row prior, the column prior and the
    # caller's `prior` share one object, and mutating one silently changes
    # the others. copy.copy also works if the param is an immutable
    # scalar or None.
    self.row_direction_prior_param = prior['direction'][1]
    self.row_radius_prior_type = prior['radius'][0]
    self.row_radius_prior_param = copy.copy(prior['radius'][1])
    self.col_direction_prior_param = prior['direction'][1]
    self.col_radius_prior_type = prior['radius'][0]
    self.col_radius_prior_param = copy.copy(prior['radius'][1])

    # Sampler creation order is kept unchanged (parameter registration order).
    if in_channels > 1 and out_channels > 1:
        self.row_direction_rsampler = VonMisesFisherReparametrizedSample(
            batch_shape=self.row_batch_shape, event_shape=self.row_event_shape)
        if self.with_global:
            self.row_global_scale_rsampler = LognormalReparametrizedSample(
                batch_shape=torch.Size([1]))
            self.row_global_scale_rsampler1 = LognormalReparametrizedSample(
                batch_shape=torch.Size([1]))
        self.row_radius_rsampler = LognormalReparametrizedSample(
            batch_shape=self.row_batch_shape)
        self.row_radius_rsampler1 = LognormalReparametrizedSample(
            batch_shape=self.row_batch_shape)
        self.col_direction_rsampler = VonMisesFisherReparametrizedSample(
            batch_shape=self.col_batch_shape, event_shape=self.col_event_shape)
        if self.with_global:
            self.col_global_scale_rsampler = LognormalReparametrizedSample(
                batch_shape=torch.Size([1]))
            self.col_global_scale_rsampler1 = LognormalReparametrizedSample(
                batch_shape=torch.Size([1]))
        self.col_radius_rsampler = LognormalReparametrizedSample(
            batch_shape=self.col_batch_shape)
        self.col_radius_rsampler1 = LognormalReparametrizedSample(
            batch_shape=self.col_batch_shape)
    elif in_channels == 1 and out_channels > 1:
        # Single input channel: one Normal over out_channels * kh * kw values.
        self.col_rsampler = NormalReparametrizedSample(
            batch_shape=self.col_event_shape)
    elif out_channels == 1 and in_channels > 1:
        # Single output channel: one Normal over (in_channels // groups) * kh * kw values.
        self.row_rsampler = NormalReparametrizedSample(
            batch_shape=self.row_event_shape)
    # NOTE(review): with in_channels == out_channels == 1 no weight sampler is
    # created here. The branch tests also use in_channels rather than
    # in_channels // groups. Confirm both are intended.
    self.bias_rsampler = NormalReparametrizedSample(
        batch_shape=torch.Size([out_channels])) if bias else None
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
             dilation=1, groups=1, bias=True, prior=None, with_global=True):
    """Radial (direction x radius) variational 2-D convolution.

    Expands the conv geometry to pairs and forwards everything to the parent
    constructor. It then attaches:

    - a von Mises-Fisher direction sampler;
    - two Lognormal radius samplers;
    - two global-scale samplers, only when ``self.with_global`` is set;
    - a Normal bias sampler, only when ``bias`` is truthy.

    ``self.batch_shape``, ``self.event_shape`` and ``self.with_global`` are
    read here but not assigned. Presumably the parent constructor defines
    them (verify).

    ``prior`` is indexed as ``prior['direction'][0]`` (must be ``'vMF'``),
    ``prior['direction'][1]``, ``prior['radius'][0]`` and
    ``prior['radius'][1]``. It is not checked for ``None``.
    """
    # These two attributes are set before the parent constructor runs.
    self.in_channels = in_channels
    self.out_channels = out_channels
    kernel_size, stride, padding, dilation = (
        _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation))
    super(RadialConv2d, self).__init__(
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias, prior, with_global)

    direction_prior = prior['direction']
    assert direction_prior[0] == 'vMF'
    self.direction_prior_param = direction_prior[1]
    radius_prior = prior['radius']
    self.radius_prior_type = radius_prior[0]
    self.radius_prior_param = radius_prior[1]

    self.direction_rsampler = VonMisesFisherReparametrizedSample(
        batch_shape=self.batch_shape, event_shape=self.event_shape)
    if self.with_global:
        for attr_name in ('global_scale_rsampler', 'global_scale_rsampler1'):
            setattr(self, attr_name,
                    LognormalReparametrizedSample(batch_shape=torch.Size([1])))
    for attr_name in ('radius_rsampler', 'radius_rsampler1'):
        setattr(self, attr_name,
                LognormalReparametrizedSample(batch_shape=self.batch_shape))

    if bias:
        self.bias_rsampler = NormalReparametrizedSample(
            batch_shape=torch.Size([out_channels]))
    else:
        self.bias_rsampler = None
def __init__(self, in_features, out_features, bias, prior=None, with_global=True):
    """Radial (direction x radius) variational linear layer.

    Creates the following samplers:

    - a von Mises-Fisher direction sampler with batch ``[in_features]`` and
      event ``[out_features]``;
    - two Lognormal radius samplers over the same ``[in_features]`` batch;
    - two scalar global-scale samplers, when ``with_global`` is set;
    - a Normal bias sampler, when ``bias`` is truthy.

    Args:
        in_features: batch size of the direction and radius samplers.
        out_features: event size of each direction; also the bias size.
        bias: if truthy, register a ``bias_ones`` buffer and a bias sampler;
            otherwise ``self.bias_rsampler`` is ``None``.
        prior: indexed as ``prior['direction'][0]`` (must be ``'vMF'``),
            ``prior['direction'][1]``, ``prior['radius'][0]`` and
            ``prior['radius'][1]``. It is not checked for ``None``.
        with_global: if truthy, also create the global-scale samplers.
    """
    super(RadialLinear, self).__init__()
    self.with_global = with_global
    self.in_features = in_features
    self.out_features = out_features

    per_input = torch.Size([in_features])    # sampler batch dimension
    per_output = torch.Size([out_features])  # direction event dimension
    self.batch_shape = per_input
    self.register_buffer('batch_ones', torch.ones(per_input))
    if bias:
        self.register_buffer('bias_ones', torch.ones([out_features]))

    direction_prior = prior['direction']
    assert direction_prior[0] == 'vMF'
    self.direction_prior_param = direction_prior[1]
    radius_prior = prior['radius']
    self.radius_prior_type = radius_prior[0]
    self.radius_prior_param = radius_prior[1]

    self.direction_rsampler = VonMisesFisherReparametrizedSample(
        batch_shape=per_input, event_shape=per_output)
    if self.with_global:
        scalar_shape = torch.Size([1])
        self.global_scale_rsampler = LognormalReparametrizedSample(
            batch_shape=scalar_shape)
        self.global_scale_rsampler1 = LognormalReparametrizedSample(
            batch_shape=scalar_shape)
    self.radius_rsampler = LognormalReparametrizedSample(batch_shape=per_input)
    self.radius_rsampler1 = LognormalReparametrizedSample(batch_shape=per_input)

    if bias:
        self.bias_rsampler = NormalReparametrizedSample(
            batch_shape=torch.Size([out_features]))
    else:
        self.bias_rsampler = None
def __init__(self, in_features, out_features, bias, prior=None, with_global=True):
    """Doubly-radial variational linear layer with row and column posteriors.

    The weight is described two ways:

    - rows: ``out_features`` rows of size ``in_features``;
    - columns: ``in_features`` columns of size ``out_features``.

    How the weight samplers are built depends on the sizes:

    - both sizes > 1: each view gets a von Mises-Fisher direction sampler
      and two Lognormal radius samplers, plus two global-scale samplers when
      ``with_global`` is set;
    - ``in_features == 1``: a single Normal ``row_rsampler`` is created;
    - ``out_features == 1``: a single Normal ``col_rsampler`` is created;
    - both sizes == 1: no weight sampler is created here.

    Args:
        in_features, out_features: layer sizes.
        bias: if truthy, register ``bias_shape_ones`` and create an
            ``out_features``-sized Normal bias sampler; otherwise
            ``self.bias_rsampler`` is ``None``.
        prior: indexed as ``prior['direction'][0]`` (must be ``'vMF'``),
            ``prior['direction'][1]``, ``prior['radius'][0]`` and
            ``prior['radius'][1]``. The last must support ``.copy()``. It is
            not checked for ``None``.
        with_global: if truthy, also create the global-scale samplers.
    """
    super(DoubleRadialLinear, self).__init__()
    self.with_global = with_global
    self.in_features = in_features
    self.out_features = out_features

    self.row_batch_shape = torch.Size([out_features])
    self.row_event_shape = torch.Size([in_features])
    self.col_batch_shape = torch.Size([in_features])
    self.col_event_shape = torch.Size([out_features])
    self.register_buffer('row_batch_ones', torch.ones(self.row_batch_shape))
    self.register_buffer('col_batch_ones', torch.ones(self.col_batch_shape))
    if bias:
        self.register_buffer('bias_shape_ones', torch.ones([out_features]))

    assert prior['direction'][0] == 'vMF'
    # Each axis gets its own copy of the radius prior params, so the row and
    # column priors can be adjusted independently. Copying also avoids
    # aliasing each other or the caller's `prior`.
    for axis in ('row', 'col'):
        setattr(self, axis + '_direction_prior_param', prior['direction'][1])
        setattr(self, axis + '_radius_prior_type', prior['radius'][0])
        setattr(self, axis + '_radius_prior_param', prior['radius'][1].copy())

    def attach_radial_samplers(axis, batch_shape, event_shape):
        # Creates, in order: direction, [global scale, global scale 1], radius, radius 1.
        setattr(self, axis + '_direction_rsampler',
                VonMisesFisherReparametrizedSample(
                    batch_shape=batch_shape, event_shape=event_shape))
        if self.with_global:
            for suffix in ('', '1'):
                setattr(self, axis + '_global_scale_rsampler' + suffix,
                        LognormalReparametrizedSample(
                            batch_shape=torch.Size([1])))
        for suffix in ('', '1'):
            setattr(self, axis + '_radius_rsampler' + suffix,
                    LognormalReparametrizedSample(batch_shape=batch_shape))

    if in_features > 1 and out_features > 1:
        attach_radial_samplers('row', self.row_batch_shape, self.row_event_shape)
        attach_radial_samplers('col', self.col_batch_shape, self.col_event_shape)
    elif in_features == 1 and out_features > 1:
        self.row_rsampler = NormalReparametrizedSample(
            batch_shape=self.row_batch_shape)
    elif out_features == 1 and in_features > 1:
        self.col_rsampler = NormalReparametrizedSample(
            batch_shape=self.col_batch_shape)

    if bias:
        self.bias_rsampler = NormalReparametrizedSample(
            batch_shape=torch.Size([out_features]))
    else:
        self.bias_rsampler = None