Example #1
    def __init__(self,
                 in_features,
                 out_features,
                 p_logvar_init=0.05,
                 q_logvar_init=0.05):
        # p_logvar_init: float, log-variance of the fixed Gaussian prior
        # q_logvar_init: float, the approximate posterior is currently always
        #                a factorized Gaussian

        super(BBBLinearFactorial, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.p_logvar_init = p_logvar_init
        self.q_logvar_init = q_logvar_init

        # Approximate posterior weights...
        self.fc_qw_mean = Parameter(torch.Tensor(out_features, in_features))
        self.fc_qw_logvar = Parameter(torch.Tensor(out_features, in_features))
        self.weight = Normal(mu=self.fc_qw_mean, logvar=self.fc_qw_logvar)

        self.register_buffer('eps_weight',
                             torch.Tensor(out_features, in_features))

        # initialise log(alpha) used by the local reparameterization trick
        self.log_alpha = Parameter(torch.Tensor(1, 1))

        # prior model
        self.pw = Normal(mu=0.0, logvar=p_logvar_init)

        # initialize all parameters
        self.reset_parameters()
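
These snippets construct Normal(mu=..., logvar=...) objects and later call .sample() and .logpdf() on them, but the class itself is defined elsewhere in the repository. A minimal sketch consistent with that usage follows (mean/log-variance parameterization, reparameterized sampling, element-wise log-density); the details are an assumption, not the repository's exact code.

import math

import torch


class Normal:
    # minimal diagonal Gaussian parameterized by mean and log-variance;
    # a sketch of the helper these examples assume, not the original class
    def __init__(self, mu, logvar):
        self.mu = torch.as_tensor(mu, dtype=torch.float32)
        self.logvar = torch.as_tensor(logvar, dtype=torch.float32)

    def sample(self):
        # reparameterization trick: mu + sigma * eps, with eps ~ N(0, I)
        eps = torch.randn_like(self.mu)
        return self.mu + torch.exp(0.5 * self.logvar) * eps

    def logpdf(self, x):
        # element-wise log N(x | mu, exp(logvar))
        return (-0.5 * math.log(2.0 * math.pi)
                - 0.5 * self.logvar
                - (x - self.mu).pow(2) / (2.0 * torch.exp(self.logvar)))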
Example #2
    def convprobforward(self, input):
        """
        Convolutional probabilistic forwarding method.
        :param input: data tensor
        :return: output, KL-divergence
        """

        # sig_weight = torch.exp(self.conv_qw_std)
        # weight = self.conv_qw_mean + sig_weight * self.eps_weight.normal_()

        weight = self.weight.sample()

        # local reparameterization trick for convolutional layer

        conv_qw_mean = F.conv2d(input=input,
                                weight=weight,
                                stride=self.stride,
                                padding=self.padding,
                                dilation=self.dilation,
                                groups=self.groups)
        conv_qw_std = torch.sqrt(1e-8 +
                                 F.conv2d(input=input.pow(2),
                                          weight=torch.exp(self.log_alpha) *
                                          weight.pow(2),
                                          stride=self.stride,
                                          padding=self.padding,
                                          dilation=self.dilation,
                                          groups=self.groups))

        # sample from the induced output distribution; the noise is created
        # directly on the right device, so no extra .cuda() calls are needed
        if cuda:
            eps = torch.cuda.FloatTensor(conv_qw_mean.size()).normal_()
        else:
            eps = torch.randn(conv_qw_mean.size())
        output = conv_qw_mean + conv_qw_std * eps

        # Normal is parameterized by a log-variance, so convert the standard
        # deviation computed above before constructing the distribution
        conv_qw = Normal(mu=conv_qw_mean,
                         logvar=torch.log(conv_qw_std.pow(2)))

        # self.conv_qw_mean = Parameter(torch.Tensor(conv_qw_mean.cpu()))
        # self.conv_qw_std = Parameter(torch.Tensor(conv_qw_std.cpu()))

        w_sample = conv_qw.sample()

        # single-sample Monte Carlo estimate of KL(q || p); note that the
        # sample lives in activation space, not weight space
        qw_logpdf = conv_qw.logpdf(w_sample)

        kl = torch.sum(qw_logpdf - self.pw.logpdf(w_sample))

        return output, kl
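
For intuition, here is the local reparameterization trick from convprobforward in isolation: instead of sampling a weight tensor, the input is convolved with the weight means and, squared, with the weight variances, giving per-activation Gaussian statistics that are sampled directly. A standalone sketch with made-up shapes, not the repository's API:

import torch
import torch.nn.functional as F

# hypothetical shapes: a batch of 8 RGB images, 16 output channels, 3x3 kernels
x = torch.randn(8, 3, 32, 32)
w_mean = 0.1 * torch.randn(16, 3, 3, 3)
w_logvar = torch.full((16, 3, 3, 3), -6.0)

# per-activation mean and variance of the induced Gaussian output
act_mean = F.conv2d(x, w_mean, padding=1)
act_var = F.conv2d(x.pow(2), torch.exp(w_logvar), padding=1)
act_std = torch.sqrt(1e-8 + act_var)

# sample once in activation space instead of sampling a full weight tensor
output = act_mean + act_std * torch.randn_like(act_mean)
print(output.shape)  # torch.Size([8, 16, 32, 32])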
Example #3
    def __init__(self,
                 in_features,
                 out_features,
                 p_logvar_init=0.05,
                 p_pi=1.0,
                 q_logvar_init=0.05):
        # p_logvar_init, p_pi can each be either
        # (list/tuple): the prior is a mixture of Gaussians with
        #               len(p_pi) == len(p_logvar_init) components
        # float: the prior is a single Gaussian distribution
        # q_logvar_init: float, the approximate posterior is currently always
        #                a factorized Gaussian
        super(BBBLinearFactorial, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.p_logvar_init = p_logvar_init
        self.q_logvar_init = q_logvar_init

        # self.weight = Parameter(torch.Tensor(out_features, in_features))

        self.fc_qw_mean = Parameter(torch.Tensor(out_features, in_features))
        self.fc_qw_std = Parameter(torch.Tensor(out_features, in_features))

        self.weight = Normal(mu=self.fc_qw_mean, logvar=self.fc_qw_std)

        # Approximate posterior weights...
        # self.qw_mean = Parameter(torch.Tensor(out_features, in_features))
        # self.qw_logvar = Parameter(torch.Tensor(out_features, in_features))

        # optionally add bias
        # self.qb_mean = Parameter(torch.Tensor(out_features))
        # self.qb_logvar = Parameter(torch.Tensor(out_features))

        # ...and output...

        self.register_buffer('eps_weight',
                             torch.Tensor(out_features, in_features))

        # ...as normal distributions
        # self.qw = Normal(mu=self.qw_mean, logvar=self.qw_logvar)
        # self.qb = Normal(mu=self.qb_mean, logvar=self.qb_logvar)
        # self.fc_qw = Normal(mu=self.fc_qw_mean, logvar=self.fc_qw_std)

        # initialise log(alpha) used by the local reparameterization trick
        self.log_alpha = Parameter(torch.Tensor(1, 1))

        # prior model
        self.pw = distribution_selector(mu=0.0, logvar=p_logvar_init, pi=p_pi)
        # self.pb = distribution_selector(mu=0.0, logvar=p_logvar_init, pi=p_pi)

        # initialize all parameters
        self.reset_parameters()
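
distribution_selector is defined elsewhere in the repository; judging from the comments above, it returns a single fixed Gaussian for scalar arguments and a fixed mixture of Gaussians when p_logvar_init and p_pi are lists/tuples. A rough sketch under that assumption, reusing the Normal sketch after Example #1 (MixtureNormal is a hypothetical stand-in, not the repository's class):

import math

import torch


def distribution_selector(mu, logvar, pi):
    # list/tuple arguments -> fixed mixture-of-Gaussians prior;
    # scalar arguments -> single fixed Gaussian prior
    if isinstance(logvar, (list, tuple)) and isinstance(pi, (list, tuple)):
        assert len(logvar) == len(pi)
        return MixtureNormal(mu, logvar, pi)
    return Normal(mu=mu, logvar=logvar)


class MixtureNormal:
    # hypothetical stand-in for the repository's mixture prior
    def __init__(self, mu, logvar, pi):
        self.components = [Normal(mu=mu, logvar=lv) for lv in logvar]
        self.pi = pi

    def logpdf(self, x):
        # log sum_k pi_k N(x | mu, exp(logvar_k)) via logsumexp
        stacked = torch.stack([math.log(p) + c.logpdf(x)
                               for p, c in zip(self.pi, self.components)])
        return torch.logsumexp(stacked, dim=0)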
Example #4
    def fcprobforward(self, input):
        """
        Probabilistic forwarding method.
        :param input: data tensor
        :return: output, KL-divergence
        """

        # sig_weight = torch.exp(self.fc_qw_std)
        # weight = self.fc_qw_mean + sig_weight * self.eps_weight.normal_()

        weight = self.weight.sample()

        fc_qw_mean = F.linear(input=input, weight=weight)
        fc_qw_std = torch.sqrt(1e-8 +
                               F.linear(input=input.pow(2),
                                        weight=torch.exp(self.log_alpha) *
                                        weight.pow(2)))

        # sample from the induced output distribution; the noise is created
        # directly on the right device, so no extra .cuda() calls are needed
        if cuda:
            eps = torch.cuda.FloatTensor(fc_qw_mean.size()).normal_()
        else:
            eps = torch.randn(fc_qw_mean.size())
        output = fc_qw_mean + fc_qw_std * eps

        # self.fc_qw_mean = Parameter(torch.Tensor(fc_qw_mean.cpu()))
        # self.fc_qw_std = Parameter(torch.Tensor(fc_qw_std.cpu()))

        # Normal is parameterized by a log-variance, so convert the standard
        # deviation computed above before constructing the distribution
        fc_qw = Normal(mu=fc_qw_mean, logvar=torch.log(fc_qw_std.pow(2)))

        w_sample = fc_qw.sample()

        # single-sample Monte Carlo estimate of KL(q || p); note that the
        # sample lives in activation space, not weight space
        qw_logpdf = fc_qw.logpdf(w_sample)

        kl = torch.sum(qw_logpdf - self.pw.logpdf(w_sample))

        return output, kl
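
In training, the KL terms returned by these probabilistic forward methods are summed across layers and added to the data-fit term, giving the negative ELBO that is actually minimized. A hedged usage sketch; the model.probforward name and the beta weighting are assumptions about the surrounding training code:

import torch.nn.functional as F


def elbo_loss(logits, targets, total_kl, beta):
    # negative ELBO: cross-entropy data term plus the weighted KL penalty
    return F.cross_entropy(logits, targets) + beta * total_kl

# typical use inside a training step (hypothetical model API):
# logits, total_kl = model.probforward(x)
# loss = elbo_loss(logits, y, total_kl, beta=1.0 / num_batches)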
Example #5
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 dilation,
                 output_padding,
                 groups,
                 p_logvar_init=0.05,
                 p_pi=1.0,
                 q_logvar_init=0.05):
        super(_ConvNd, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.output_padding = output_padding
        self.groups = groups

        # initialize log variance of p and q
        self.p_logvar_init = p_logvar_init
        self.q_logvar_init = q_logvar_init

        # self.weight = Parameter(torch.Tensor(out_channels, in_channels// groups, kernel_size, kernel_size))

        # approximate posterior weights...
        # self.qw_mean = Parameter(torch.Tensor(out_channels, in_channels // groups, kernel_size, kernel_size))
        # self.qw_logvar = Parameter(torch.Tensor(out_channels, in_channels // groups, kernel_size, kernel_size))

        # optionally add bias
        # self.qb_mean = Parameter(torch.Tensor(out_channels))
        # self.qb_logvar = Parameter(torch.Tensor(out_channels))

        # ...and output...
        self.conv_qw_mean = Parameter(
            torch.Tensor(out_channels, in_channels // groups, *kernel_size))
        self.conv_qw_std = Parameter(
            torch.Tensor(out_channels, in_channels // groups, *kernel_size))

        self.register_buffer(
            'eps_weight',
            torch.Tensor(out_channels, in_channels // groups, *kernel_size))

        # ...as normal distributions
        # self.qw = Normal(mu=self.qw_mean, logvar=self.qw_logvar)
        # self.qb = Normal(mu=self.qb_mean, logvar=self.qb_logvar)

        # self.conv_qw = Normal(mu=self.conv_qw_mean, logvar=self.conv_qw_std)

        self.weight = Normal(mu=self.conv_qw_mean, logvar=self.conv_qw_std)

        # initialise log(alpha) used by the local reparameterization trick
        self.log_alpha = Parameter(torch.Tensor(1, 1))

        # prior model (has no trainable parameters, so we use a fixed Normal
        # or a fixed mixture-of-Normals distribution)
        self.pw = distribution_selector(mu=0.0, logvar=p_logvar_init, pi=p_pi)
        # self.pb = distribution_selector(mu=0.0, logvar=p_logvar_init, pi=p_pi)

        # initialize all parameters
        self.reset_parameters()
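
reset_parameters is called in each constructor but not shown in these excerpts. A plausible sketch for the convolutional layer, assuming a fan-in scaled uniform initialization for the posterior means and constant fills for the variance-related parameters; the repository's actual scheme may differ:

import math


def reset_parameters(self):
    # fan-in of one output unit: (in_channels / groups) * prod(kernel_size)
    n = self.in_channels // self.groups
    for k in self.kernel_size:
        n *= k
    stdv = 1.0 / math.sqrt(n)
    self.conv_qw_mean.data.uniform_(-stdv, stdv)
    self.conv_qw_std.data.fill_(self.q_logvar_init)
    self.eps_weight.data.zero_()
    self.log_alpha.data.fill_(-5.0)  # hypothetical starting value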