class RiemannLayer(nn.Module):
	
	def __init__(self, in_features, out_features, manifold, over_param):
		super(RiemannLayer, self).__init__()
		self.in_features = in_features
		self.out_features = out_features
		self._weight = Parameter(torch.Tensor(out_features, in_features))
		self.over_param = over_param
		if self.over_param:
			self._bias = ManifoldParameter(torch.Tensor(out_features, in_features), manifold=manifold)
		else:
			self._bias = Parameter(torch.Tensor(out_features, 1))
		self.manifold = manifold
		self.reset_parameters()
	
	@property
	def weight(self):
		# tangent vector at the origin, transported to the bias point
		return self.manifold.transp0(self.bias, self._weight)
	
	@property
	def bias(self):
		if self.over_param:
			return self._bias
		# reparameterise the bias as a point on the manifold
		return self.manifold.expmap0(self._weight.mul(self._bias))
    
	def reset_parameters(self):
		nn.init.kaiming_normal_(self._weight, a=math.sqrt(5))
		fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self._weight)
		bound = 4 / math.sqrt(fan_in)
		nn.init.uniform_(self._bias, -bound, bound)
		if self.over_param:
			with torch.no_grad():
				self._bias.set_(self.manifold.expmap0(self._bias))
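A minimal smoke test for RiemannLayer (not part of the original listing). It assumes a geoopt-style manifold exposing expmap0/transp0, e.g. geoopt.PoincareBall; the ManifoldParameter used above would come from geoopt as well.

import torch
import geoopt  # assumed dependency; any manifold with expmap0/transp0 works

ball = geoopt.PoincareBall(c=1.0)
layer = RiemannLayer(in_features=4, out_features=3, manifold=ball, over_param=False)
print(layer.weight.shape)  # torch.Size([3, 4]) -- tangent vectors transported to the bias point
print(layer.bias.shape)    # torch.Size([3, 4]) -- _weight * _bias mapped onto the ball via expmap0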
Example #2
class ConvexNet(Module):

    __constants__ = ['width']

    def __init__(self, width):
        super(ConvexNet, self).__init__()

        self.width = width
        self.weight = Parameter(torch.Tensor(2, width))
        self.bias = Parameter(torch.Tensor(width))

    def set_value(self, lines: List[Line]):
        assert len(lines) == self.width

        weight = [[], []]
        bias = []

        for line in lines:
            weight[0].append(line.a)
            weight[1].append(line.b)
            bias.append(line.c)

        self.weight = Parameter(torch.Tensor(weight))
        self.bias = Parameter(torch.Tensor(bias))

    def get_value(self) -> List[Line]:
        ret = []
        weight_list = self.weight.tolist()
        bias_list = self.bias.tolist()

        for i in range(self.width):
            ret.append(Line(weight_list[0][i], weight_list[1][i],
                            bias_list[i]))

        return ret

    def forward(self, input):
        # Squared norm of each line's normal vector (a^2 + b^2), one entry per line.
        length2 = torch.sum(self.weight.mul(self.weight), dim=0)
        length = torch.sqrt(length2)

        # a*x + b*y + c for every (point, line) pair, normalised to a signed distance.
        approx_dis = torch.addmm(self.bias, input, self.weight)
        real_dis = torch.div(approx_dis, length)

        # With the lines oriented so the region is where a*x + b*y + c <= 0, a point lies
        # inside iff its largest signed distance is non-positive; the steep sigmoid turns
        # that maximum into a soft 0/1 indicator.
        max_dis, _ = torch.max(real_dis, dim=1)
        ret = torch.sigmoid(max_dis * 1e2)

        return ret
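A hedged usage sketch (not in the original listing). The Line record is assumed here to be a simple (a, b, c) container for the line a*x + b*y + c = 0; the listing imports its own Line type elsewhere.

import torch
from collections import namedtuple

Line = namedtuple('Line', ['a', 'b', 'c'])  # hypothetical stand-in for the listing's Line type

net = ConvexNet(width=3)
net.set_value([Line(1., 0., -1.), Line(0., 1., -1.), Line(-1., -1., 0.)])  # region x<=1, y<=1, x+y>=0
points = torch.tensor([[0.5, 0.5], [2.0, 2.0]])
print(net(points))  # ~0 for the inside point, ~1 for the outside point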
Example #3
class NoisyConv2d(Module):
    """Applies a noisy conv2d transformation to the incoming data:
    More details can be found in the paper `Noisy Networks for Exploration` _ .
    Args:
        in_channels: size of each input sample
        out_channels: size of each output sample
        bias: If set to False, the layer will not learn an additive bias. Default: True
        factorised: whether or not to use factorised noise. Default: True
        std_init: initialization constant for standard deviation component of weights. If None,
            defaults to 0.017 for independent and 0.4 for factorised. Default: None
    Shape:
        - Input: :math:`(N, in\_features)`
        - Output: :math:`(N, out\_features)`
    Attributes:
        weight: the learnable weights of the module of shape (out_features x in_features)
        bias:   the learnable bias of the module of shape (out_features)
    Examples::
        >>> m = NoisyConv2d(4, 2, (3,1))
        >>> input = torch.autograd.Variable(torch.randn(1, 4, 51, 3))
        >>> output = m(input)
        >>> print(output.size())
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 bias=True,
                 stride=1,
                 padding=1,
                 dilation=1,
                 groups=1,
                 factorised=True,
                 std_init=None,
                 gpu_id=0):
        super(NoisyConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = groups
        self.factorised = factorised
        self.weight_mu = Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))
        self.weight_sigma = Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))
        self.gpu_id = gpu_id
        if bias:
            self.bias_mu = Parameter(torch.Tensor(out_channels))
            self.bias_sigma = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias_mu', None)
            self.register_parameter('bias_sigma', None)
        if not std_init:
            if self.factorised:
                self.std_init = 0.4
            else:
                self.std_init = 0.017
        else:
            self.std_init = std_init
        self.reset_parameters(bias)

    def reset_parameters(self, bias):
        if self.factorised:
            mu_range = 1. / math.sqrt(self.weight_mu.size(1))
            self.weight_mu.data.uniform_(-mu_range, mu_range)
            self.weight_sigma.data.fill_(self.std_init /
                                         math.sqrt(self.weight_sigma.size(1)))
            if bias:
                self.bias_mu.data.uniform_(-mu_range, mu_range)
                self.bias_sigma.data.fill_(self.std_init /
                                           math.sqrt(self.bias_sigma.size(0)))
        else:
            mu_range = math.sqrt(3. / self.weight_mu.size(1))
            self.weight_mu.data.uniform_(-mu_range, mu_range)
            self.weight_sigma.data.fill_(self.std_init)
            if bias:
                self.bias_mu.data.uniform_(-mu_range, mu_range)
                self.bias_sigma.data.fill_(self.std_init)

    def scale_noise(self, size):
        with torch.cuda.device(self.gpu_id):
            x = torch.Tensor(size).normal_().cuda()
            x = x.sign().mul(x.abs().sqrt())
        return x

    def forward(self, input):
        if self.factorised:
            epsilon = None
            for dim in self.weight_sigma.size():
                if epsilon is None:
                    epsilon = self.scale_noise(dim)
                else:
                    epsilon = epsilon.unsqueeze(-1) * self.scale_noise(dim)
            weight_epsilon = Variable(epsilon)
            bias_epsilon = Variable(self.scale_noise(self.out_channels))
        else:
            with torch.cuda.device(self.gpu_id):
                weight_epsilon = Variable(
                    torch.Tensor(self.out_channels,
                                 self.in_channels // self.groups,
                                 *self.kernel_size).normal_()).cuda()
                bias_epsilon = Variable(
                    torch.Tensor(self.out_channels).normal_()).cuda()
        bias = None
        if self.bias_mu is not None:
            bias = self.bias_mu + self.bias_sigma.mul(bias_epsilon)
        return F.conv2d(input,
                        self.weight_mu + self.weight_sigma.mul(weight_epsilon),
                        bias,
                        stride=self.stride,
                        padding=self.padding,
                        dilation=self.dilation,
                        groups=self.groups)

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0, ) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1, ) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias_mu is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
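The factorised noise used above follows the Noisy Networks paper: each per-dimension sample is rescaled as f(x) = sign(x) * sqrt(|x|), and the full weight noise is the outer product of such vectors. A small CPU-only sketch of that construction (independent of the GPU-bound scale_noise above):

import torch

def scale_noise(size):
    # f(x) = sign(x) * sqrt(|x|), as in NoisyConv2d.scale_noise / NoisyLinear._scale_noise
    x = torch.randn(size)
    return x.sign().mul(x.abs().sqrt())

eps_out, eps_in = scale_noise(8), scale_noise(4)
weight_eps = eps_out.ger(eps_in)  # factorised (8, 4) weight noise from only 8 + 4 samples
print(weight_eps.shape)           # torch.Size([8, 4])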
Example #4
class NoisyLinear(Module):
    r"""Applies a noisy linear transformation to the incoming data.
    During training:
        .. math:: y = (\mu_w + \sigma_w \odot \epsilon_w) x + \mu_b + \sigma_b \odot \epsilon_b
    During evaluation:
        .. math:: y = \mu_w x + \mu_b
    More details can be found in the paper `Noisy Networks for Exploration`_ .
    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to False, the layer will not learn an additive bias.
            Default: True
        factorized: whether or not to use factorized noise.
            Default: True
        std_init: constant for weight_sigma and bias_sigma initialization.
            If None, defaults to 0.017 for independent and 0.4 for factorized.
            Default: None
    Shape:
        - Input: :math:`(N, in\_features)`
        - Output: :math:`(N, out\_features)`
    Attributes:
        weight_mu, weight_sigma: the learnable weights of the module of shape
            (out_features x in_features)
        bias_mu, bias_sigma: the learnable bias of the module of shape (out_features)
    Methods:
        resample: resamples the noise tensors
    Examples::
        >>> m = NoisyLinear(20, 30)
        >>> input = autograd.Variable(torch.randn(128, 20))
        >>> output = m(input)
        >>> m.resample()
        >>> output_new = m(input)
        >>> print(output)
        >>> print(output_new)
    """
    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 factorized=True,
                 std_init=None):
        super(NoisyLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.factorized = factorized
        self.include_bias = bias
        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_sigma = Parameter(torch.Tensor(out_features, in_features))
        self.register_buffer('weight_epsilon',
                             torch.Tensor(out_features, in_features))
        if self.include_bias:
            self.bias_mu = Parameter(torch.Tensor(out_features))
            self.bias_sigma = Parameter(torch.Tensor(out_features))
            self.register_buffer('bias_epsilon', torch.Tensor(out_features))
        else:
            self.register_parameter('bias_mu', None)
            self.register_parameter('bias_sigma', None)
        if not std_init:
            if self.factorized:
                self.std_init = 0.4
            else:
                self.std_init = 0.017
        else:
            self.std_init = std_init
        self.reset_parameters()
        self.resample()

    def reset_parameters(self):
        if self.factorized:
            mu_range = 1. / math.sqrt(self.weight_mu.size(1))
            self.weight_mu.data.uniform_(-mu_range, mu_range)
            self.weight_sigma.data.fill_(self.std_init /
                                         math.sqrt(self.weight_sigma.size(1)))
            if self.include_bias:
                self.bias_mu.data.uniform_(-mu_range, mu_range)
                self.bias_sigma.data.fill_(self.std_init /
                                           math.sqrt(self.bias_sigma.size(0)))
        else:
            mu_range = math.sqrt(3. / self.weight_mu.size(1))
            self.weight_mu.data.uniform_(-mu_range, mu_range)
            self.weight_sigma.data.fill_(self.std_init)
            if self.include_bias:
                self.bias_mu.data.uniform_(-mu_range, mu_range)
                self.bias_sigma.data.fill_(self.std_init)

    def _scale_noise(self, size):
        x = torch.randn(size)
        x = x.sign().mul(x.abs().sqrt())
        return x

    def resample(self):
        if self.factorized:
            epsilon_in = self._scale_noise(self.in_features)
            epsilon_out = self._scale_noise(self.out_features)
            self.weight_epsilon.copy_(epsilon_out.ger(epsilon_in))
            if self.include_bias:
                self.bias_epsilon.copy_(self._scale_noise(self.out_features))
        else:
            self.weight_epsilon.normal_()
            if self.include_bias:
                self.bias_epsilon.normal_()

    def forward(self, input):
        if self.training:
            bias = None
            if self.include_bias:
                bias = self.bias_mu + self.bias_sigma.mul(
                    Variable(self.bias_epsilon))
            return F.linear(
                input, self.weight_mu +
                self.weight_sigma.mul(Variable(self.weight_epsilon)), bias)
        else:
            return F.linear(input, self.weight_mu, self.bias_mu)

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ', Factorized: ' \
            + str(self.factorized) + ')'
class HSDense(Module):
    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 prior_std=1.,
                 prior_std_z=1.,
                 dof=1.,
                 **kwargs):
        super(HSDense, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.prior_std = prior_std
        self.mean_w = Parameter(torch.Tensor(in_features, out_features))
        self.logvar_w = Parameter(torch.Tensor(in_features, out_features))
        self.qz_mean = Parameter(torch.Tensor(in_features))
        self.qz_logvar = Parameter(torch.Tensor(in_features))
        self.dof = dof
        self.prior_std_z = prior_std_z
        self.use_bias = False
        if bias:
            self.mean_bias = Parameter(torch.Tensor(out_features))
            self.logvar_bias = Parameter(torch.Tensor(out_features))
            self.use_bias = True
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.mean_w, mode='fan_out')
        self.logvar_w.data.normal_(-9., 1e-4)

        self.qz_mean.data.normal_(math.log(math.exp(1) - 1), 1e-3)
        self.qz_logvar.data.normal_(math.log(0.1), 1e-4)
        if self.use_bias:
            self.mean_bias.data.normal_(0, 1e-2)
            self.logvar_bias.data.normal_(-9., 1e-4)

    def constrain_parameters(self, thres_std=1.):
        self.logvar_w.data.clamp_(max=2. * math.log(thres_std))
        if self.use_bias:
            self.logvar_bias.data.clamp_(max=2. * math.log(thres_std))

    def eq_logpw(self):
        logpw = -.5 * math.log(
            2 * math.pi * self.prior_std**2) - .5 * self.logvar_w.exp().div(
                self.prior_std**2)
        logpw -= .5 * self.mean_w.pow(2).div(self.prior_std**2)
        logpb = 0.
        if self.use_bias:
            logpb = -.5 * math.log(2 * math.pi * self.prior_std**
                                   2) - .5 * self.logvar_bias.exp().div(
                                       self.prior_std**2)
            logpb -= .5 * self.mean_bias.pow(2).div(self.prior_std**2)
        return torch.sum(logpw) + torch.sum(logpb)

    def eq_logqw(self):
        logqw = -torch.sum(.5 * (math.log(2 * math.pi) + self.logvar_w + 1))
        logqb = 0.
        if self.use_bias:
            logqb = -torch.sum(.5 *
                               (math.log(2 * math.pi) + self.logvar_bias + 1))
        return logqw + logqb

    def kldiv_aux(self):
        z = self.sample_z(1)
        z = z.view(self.in_features)

        logqm = -torch.sum(.5 * (math.log(2 * math.pi) + self.qz_logvar + 1))
        logqm = logqm.add(-torch.sum(F.sigmoid(z.exp().add(-1).log()).log()))

        logpm = torch.sum(
            2 * math.lgamma(.5 * (self.dof + 1)) - math.lgamma(.5 * self.dof) -
            math.log(self.prior_std_z) - .5 * math.log(self.dof * math.pi) -
            .5 * (self.dof + 1) * torch.log(1. + z.pow(2) /
                                            (self.dof * self.prior_std_z**2)))

        return logpm - logqm

    def kldiv(self):
        return self.kldiv_aux() + self.eq_logpw() - self.eq_logqw()

    def get_eps(self, size):
        eps = self.floatTensor(size).normal_()
        eps = Variable(eps)
        return eps

    def sample_z(self, batch_size):
        z = self.qz_mean.view(1, self.in_features)
        if self.training:
            eps = self.get_eps(self.floatTensor(batch_size, self.in_features))
            z = z + eps.mul(
                self.qz_logvar.view(1, self.in_features).mul(0.5).exp_())
        return F.softplus(z)

    def sample_W(self):
        W = self.mean_w
        if self.training:
            eps = self.get_eps(self.mean_w.size())
            W = W.add(eps.mul(self.logvar_w.mul(0.5).exp_()))
        return W

    def sample_b(self):
        b = self.mean_bias
        if self.training:
            eps = self.get_eps(self.mean_bias.size())
            b = b.add(eps.mul(self.logvar_bias.mul(0.5).exp_()))
        return b

    def get_mean_x(self, input):
        mean_xin = input.mm(self.mean_w)
        if self.use_bias:
            mean_xin = mean_xin.add(self.mean_bias.view(1, self.out_features))

        return mean_xin

    def get_var_x(self, input):
        var_xin = input.pow(2).mm(self.logvar_w.exp())
        if self.use_bias:
            var_xin = var_xin.add(self.logvar_bias.exp().view(
                1, self.out_features))

        return var_xin

    def forward(self, input):
        # sampling
        batch_size = input.size(0)
        z = self.sample_z(batch_size)
        xin = input.mul(z)
        mean_xin = self.get_mean_x(xin)
        output = mean_xin
        if self.training:
            var_xin = self.get_var_x(xin)
            eps = self.get_eps(self.floatTensor(batch_size, self.out_features))
            output = output.add(var_xin.sqrt().mul(eps))
        return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ', dof: ' \
               + str(self.dof) + ', prior_std_z: ' \
               + str(self.prior_std_z) + ')'
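HSDense, like the other variational layers in this listing, exposes kldiv() for the KL term of an ELBO-style objective. A hedged sketch of how it would typically be combined with a data-fit term; the MSE likelihood, targets, and dataset size N are placeholders, not from the original source.

import torch
import torch.nn.functional as F

device = 'cuda' if torch.cuda.is_available() else 'cpu'
layer = HSDense(in_features=16, out_features=8, bias=True).to(device)
x = torch.randn(32, 16, device=device)
y = torch.randn(32, 8, device=device)

pred = layer(x)                  # stochastic pre-activations while layer.training is True
nll = F.mse_loss(pred, y)        # placeholder negative log-likelihood term
N = 10000                        # assumed size of the training set
loss = nll - layer.kldiv() / N   # kldiv() returns E_q[log p] - E_q[log q] (+ aux), i.e. minus the KL
loss.backward()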
class KernelDenseBayesian(Module):
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 dim: int = 2,
                 use_bias: bool = True,
                 prior_std: float = 1.,
                 bias_std: float = 1e-3,
                 **kwargs):
        super(KernelDenseBayesian, self).__init__()
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor
        self.in_features = in_features
        self.out_features = out_features
        self.dim = dim
        self.prior_std = prior_std
        self.bias_std = bias_std

        self.columns_mean = Parameter(
            self.floatTensor(self.in_features, self.dim))
        self.columns_logvar = Parameter(
            self.floatTensor(self.in_features, self.dim))

        self.rows_mean = Parameter(
            self.floatTensor(self.out_features, self.dim))
        self.rows_logvar = Parameter(
            self.floatTensor(self.out_features, self.dim))

        self.alpha_mean = Parameter(self.floatTensor(self.in_features))
        self.alpha_logvar = Parameter(self.floatTensor(self.in_features))

        self.use_bias = use_bias
        if self.use_bias:
            self.bias_mean = Parameter(self.floatTensor(self.out_features))
            self.bias_logvar = Parameter(self.floatTensor(self.out_features))

        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        self.columns_mean.data.normal_(std=self.prior_std)
        self.columns_logvar.data.normal_(std=self.prior_std)

        self.rows_mean.data.normal_(std=self.prior_std)
        self.rows_logvar.data.normal_(std=self.prior_std)

        self.alpha_mean.data.normal_(std=self.prior_std)
        self.alpha_logvar.data.normal_(std=self.prior_std)

        if self.use_bias:
            self.bias_mean.data.normal_(std=self.bias_std)
            self.bias_logvar.data.normal_(std=self.bias_std)

    def _calc_rbf_weights(self, rows: torch.Tensor, columns: torch.Tensor):
        x2 = rows.pow(2).sum(dim=1).view(1, self.out_features)
        y2 = columns.pow(2).sum(dim=1).view(self.in_features, 1)
        xy = columns.mm(rows.t()).mul(-2.)

        return x2.add(y2).add(xy).mul(-1).exp()

    def _sample_eps(self, shape: tuple):
        return Variable(self.floatTensor(shape).normal_())

    def _eq_logpw(self, prior_std: float, mean: torch.Tensor,
                  logvar: torch.Tensor) -> torch.Tensor:
        logpw = logvar.exp().add(mean**2).div(prior_std**2).add(
            math.log(2. * math.pi * (prior_std**2))).mul(-0.5)
        return torch.sum(logpw)

    def _eq_logqw(self, logvar: torch.Tensor):
        logqw = logvar.add(math.log(2. * math.pi)).add(1.).mul(-0.5)
        return torch.sum(logqw)

    def eq_logpw(self) -> torch.Tensor:
        rows = self._eq_logpw(prior_std=self.prior_std,
                              mean=self.rows_mean,
                              logvar=self.rows_logvar)
        columns = self._eq_logpw(prior_std=self.prior_std,
                                 mean=self.columns_mean,
                                 logvar=self.columns_logvar)
        alpha = self._eq_logpw(prior_std=self.prior_std,
                               mean=self.alpha_mean,
                               logvar=self.alpha_logvar)
        logpw = rows.add(columns).add(alpha)

        if self.use_bias:
            bias = self._eq_logpw(prior_std=self.prior_std,
                                  mean=self.bias_mean,
                                  logvar=self.bias_logvar)
            logpw = logpw.add(bias)
        return logpw

    def eq_logqw(self):
        rows = self._eq_logqw(logvar=self.rows_logvar)
        columns = self._eq_logqw(logvar=self.columns_logvar)
        alpha = self._eq_logqw(logvar=self.alpha_logvar)
        logqw = rows.add(columns).add(alpha)

        if self.use_bias:
            bias = self._eq_logqw(logvar=self.bias_logvar)
            logqw = logqw.add(bias)
        return logqw

    def kldiv(self):
        return self.eq_logpw() - self.eq_logqw()

    def kldiv_aux(self) -> float:
        return 0.

    def forward(self, input: torch.Tensor):

        rows = self.rows_mean
        columns = self.columns_mean
        alpha = self.alpha_mean

        if self.training:
            rows = rows.add(
                self.rows_logvar.mul(0.5).exp().mul(
                    self._sample_eps(rows.shape)))
            columns = columns.add(
                self.columns_logvar.mul(0.5).exp().mul(
                    self._sample_eps(columns.shape)))
            alpha = alpha.add(
                self.alpha_logvar.mul(0.5).exp().mul(
                    self._sample_eps(alpha.shape)))

        w = self._calc_rbf_weights(rows=rows, columns=columns)

        y = input.mul(alpha).mm(w)

        if self.use_bias:
            y = y.add(self.bias_mean.view(1, self.out_features))
            if self.training:
                y = y.add(
                    self.bias_logvar.mul(0.5).exp().mul(
                        self._sample_eps(self.bias_logvar.shape)).view(
                            1, self.out_features))
        return y

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ', dim: ' \
            + str(self.dim) + ')'
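The _calc_rbf_weights above builds the weight matrix as an RBF kernel between learned per-input ("columns") and per-output ("rows") embeddings: entry [i, j] equals exp(-||column_i - row_j||^2). A small check of that reading, not part of the original listing:

import torch

layer = KernelDenseBayesian(in_features=5, out_features=3, dim=2)
w = layer._calc_rbf_weights(rows=layer.rows_mean, columns=layer.columns_mean)
ref = torch.exp(-torch.cdist(layer.columns_mean, layer.rows_mean).pow(2))
print(torch.allclose(w, ref, atol=1e-5))  # expected: True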
class OrthogonalBayesianDense(Module):
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 order: int = 8,
                 use_bias: bool = True,
                 add_diagonal: bool = True,
                 prior_std: float = 1.,
                 bias_std: float = 1e-2,
                 **kwargs):
        super(OrthogonalBayesianDense, self).__init__()
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor
        self.device = torch.device(
            'cpu') if not torch.cuda.is_available() else torch.device('cuda')
        self.in_features = in_features
        self.out_features = out_features
        self.add_diagonal = add_diagonal
        self.prior_std = prior_std
        self.bias_std = bias_std

        max_feature = max(self.in_features, self.out_features)

        self.order = order or max_feature
        assert 1 <= self.order <= max_feature

        self.v_mean = Parameter(self.floatTensor(self.order, max_feature))
        self.v_logvar = Parameter(self.floatTensor(self.order, max_feature))

        if self.add_diagonal:
            self.d_mean = Parameter(
                self.floatTensor(min(self.in_features, self.out_features)))
            self.d_logvar = Parameter(
                self.floatTensor(min(self.in_features, self.out_features)))

        self.use_bias = use_bias
        if self.use_bias:
            self.bias_mean = Parameter(self.floatTensor(out_features))
            self.bias_logvar = Parameter(self.floatTensor(out_features))

        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        self.v_mean.data.normal_(std=self.prior_std)
        self.v_logvar.data.normal_(mean=-5., std=self.prior_std)

        if self.add_diagonal:
            self.d_mean.data.normal_(std=self.prior_std)
            self.d_logvar.data.normal_(mean=-5., std=self.prior_std)

        if self.use_bias:
            self.bias_mean.data.normal_(std=self.bias_std)
            self.bias_logvar.data.normal_(mean=-5., std=self.bias_std)

    def _sample_eps(self, shape: tuple):
        return Variable(self.floatTensor(shape).normal_())

    def _eq_logpw(self, prior_mean: float, prior_std: float,
                  mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        logpw = logvar.exp().add((prior_mean - mean)**2).div(prior_std**2).add(
            math.log(2. * math.pi * (prior_std**2))).mul(-0.5)
        return torch.sum(logpw)

    def _eq_logqw(self, logvar: torch.Tensor):
        logqw = logvar.add(math.log(2. * math.pi)).add(1.).mul(-0.5)
        return torch.sum(logqw)

    def eq_logpw(self) -> torch.Tensor:
        v = self._eq_logpw(prior_mean=0.,
                           prior_std=self.prior_std,
                           mean=self.v_mean,
                           logvar=self.v_logvar)

        if self.add_diagonal:
            d = self._eq_logpw(prior_mean=0.,
                               prior_std=self.prior_std,
                               mean=self.d_mean,
                               logvar=self.d_logvar)
            logpw = v.add(d)
        else:
            logpw = v

        if self.use_bias:
            bias = self._eq_logpw(prior_mean=0.,
                                  prior_std=self.prior_std,
                                  mean=self.bias_mean,
                                  logvar=self.bias_logvar)
            logpw = logpw.add(bias)
        return logpw

    def eq_logqw(self):
        v = self._eq_logqw(logvar=self.v_logvar)

        if self.add_diagonal:
            d = self._eq_logqw(logvar=self.d_logvar)
            logqw = v.add(d)
        else:
            logqw = v

        if self.use_bias:
            bias = self._eq_logqw(logvar=self.bias_logvar)
            logqw = logqw.add(bias)
        return logqw

    def kldiv(self):
        return self.eq_logpw() - self.eq_logqw()

    def kldiv_aux(self) -> float:
        return 0.

    def _chain_multiply(self, t: torch.Tensor) -> torch.Tensor:
        p = t[0]
        for i in range(1, t.shape[0]):
            p = p.mm(t[i])
        return p

    def _calc_householder_tensor(self, t: torch.Tensor) -> torch.Tensor:
        norm = t.norm(p=2, dim=1)
        t = t.div(norm.unsqueeze(1))
        h = torch.einsum('ab,ac->abc', (t, t))
        return torch.eye(n=t.shape[1],
                         device=self.device).expand_as(h) - h.mul(2.)

    def _calc_weights(self, v: torch.Tensor, d: torch.Tensor) -> torch.Tensor:

        u = self._chain_multiply(self._calc_householder_tensor(v))

        if self.out_features <= self.in_features:
            D = torch.eye(n=self.in_features,
                          m=self.out_features,
                          device=self.device).mm(torch.diag(d))
            W = u.mm(D)
        else:
            D = torch.diag(d).mm(
                torch.eye(n=self.in_features,
                          m=self.out_features,
                          device=self.device))
            W = D.mm(u)
        return W

    def forward(self, input: torch.Tensor):

        v = self.v_mean.add(
            self.v_logvar.mul(0.5).exp().mul(
                self._sample_eps(self.v_logvar.shape)))

        if self.add_diagonal:
            d = self.d_mean.add(
                self.d_logvar.mul(0.5).exp().mul(
                    self._sample_eps(self.d_logvar.shape)))
        else:
            d = torch.ones(min(self.in_features, self.out_features),
                           device=self.device)

        w = self._calc_weights(v=v, d=d)
        y = input.mm(w)

        if self.use_bias:
            bias = self.bias_mean.add(
                self.bias_logvar.mul(0.5).exp().mul(
                    self._sample_eps(self.bias_logvar.shape)))
            return y.add(bias.view(1, self.out_features))
        return y

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ', order: ' \
            + str(self.order) + ', add_diagonal: ' \
            + str(self.add_diagonal) + ')'
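OrthogonalBayesianDense parameterises its weights through a product of Householder reflections, which is orthogonal by construction. A quick sanity check of that property (not in the original listing):

import torch

layer = OrthogonalBayesianDense(in_features=6, out_features=4, order=3)
u = layer._chain_multiply(layer._calc_householder_tensor(layer.v_mean))
eye = torch.eye(u.shape[0], device=u.device)
print(torch.allclose(u @ u.t(), eye, atol=1e-5))  # expected: True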
class FFGaussDense(Module):
    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 prior_std=1.,
                 **kwargs):
        super(FFGaussDense, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.prior_std = prior_std
        self.mean_w = Parameter(torch.Tensor(in_features, out_features))
        self.logvar_w = Parameter(torch.Tensor(in_features, out_features))
        self.use_bias = False
        if bias:
            self.mean_bias = Parameter(torch.Tensor(out_features))
            self.logvar_bias = Parameter(torch.Tensor(out_features))
            self.use_bias = True
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.mean_w, mode='fan_out')
        self.logvar_w.data.normal_(-9., 1e-4)

        if self.use_bias:
            self.mean_bias.data.zero_()
            self.logvar_bias.data.normal_(-9., 1e-4)

    def constrain_parameters(self, thres_std=1.):
        self.logvar_w.data.clamp_(max=2. * math.log(thres_std))
        if self.use_bias:
            self.logvar_bias.data.clamp_(max=2. * math.log(thres_std))

    def eq_logpw(self):
        logpw = -.5 * math.log(
            2 * math.pi * self.prior_std**2) - .5 * self.logvar_w.exp().div(
                self.prior_std**2)
        logpw -= .5 * self.mean_w.pow(2).div(self.prior_std**2)
        logpb = 0.
        if self.use_bias:
            logpb = -.5 * math.log(2 * math.pi * self.prior_std**
                                   2) - .5 * self.logvar_bias.exp().div(
                                       self.prior_std**2)
            logpb -= .5 * self.mean_bias.pow(2).div(self.prior_std**2)
        return torch.sum(logpw) + torch.sum(logpb)

    def eq_logqw(self):
        logqw = -torch.sum(.5 * (math.log(2 * math.pi) + self.logvar_w + 1))
        logqb = 0.
        if self.use_bias:
            logqb = -torch.sum(.5 *
                               (math.log(2 * math.pi) + self.logvar_bias + 1))
        return logqw + logqb

    def kldiv_aux(self):
        return 0.

    def kldiv(self):
        return self.kldiv_aux() + self.eq_logpw() - self.eq_logqw()

    def get_eps(self, size):
        eps = self.floatTensor(size).normal_()
        eps = Variable(eps)
        return eps

    def sample_pW(self):
        return self.floatTensor(self.in_features, self.out_features).normal_()

    def sample_pb(self):
        return self.floatTensor(self.out_features).normal_()

    def sample_W(self):
        w = self.mean_w
        if self.training:
            eps = self.get_eps(self.mean_w.size())
            w = w.add(eps.mul(self.logvar_w.mul(0.5).exp_()))
        return w

    def sample_b(self):
        b = self.mean_bias
        if self.training:
            eps = self.get_eps(self.mean_bias.size())
            b = b.add(eps.mul(self.logvar_bias.mul(0.5).exp_()))
        return b

    def get_mean_x(self, input):
        mean_xin = input.mm(self.mean_w)
        if self.use_bias:
            mean_xin = mean_xin.add(self.mean_bias.view(1, self.out_features))
        return mean_xin

    def get_var_x(self, input):
        var_xin = input.pow(2).mm(self.logvar_w.exp())
        if self.use_bias:
            var_xin = var_xin.add(self.logvar_bias.exp().view(
                1, self.out_features))

        return var_xin

    def forward(self, input):
        # Local reparameterisation: sample the pre-activation from N(mean_x, var_x)
        # instead of sampling the weight matrix directly.
        batch_size = input.size(0)
        mean_xin = self.get_mean_x(input)
        if self.training:
            var_xin = self.get_var_x(input)
            eps = self.get_eps(self.floatTensor(batch_size, self.out_features))
            output = mean_xin.add(var_xin.sqrt().mul(eps))
        else:
            output = mean_xin
        return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ', prior_std: ' \
            + str(self.prior_std) + ')'
class KernelBayesianConv2(ConvNd):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 dim: int = 2,
                 stride: int = 1,
                 padding: int = 0,
                 dilation: int = 1,
                 groups: int = 1,
                 use_bias: bool = True,
                 prior_std: float = 1.,
                 bias_std: float = 1e-3,
                 **kwargs):

        stride = pair(stride)
        padding = pair(padding)
        dilation = pair(dilation)

        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.groups = groups
        self.dim = dim
        self.kernel_size = pair(kernel_size)
        self.use_bias = use_bias
        self.prior_std = prior_std
        self.bias_std = bias_std

        super(KernelBayesianConv2, self).__init__(in_channels, out_channels,
                                                  self.kernel_size, stride,
                                                  padding, dilation, False,
                                                  pair(0), groups, use_bias)

        self.columns_mean = Parameter(
            self.floatTensor(self.in_channels * int(np.prod(self.kernel_size)),
                             self.dim))
        self.columns_logvar = Parameter(
            self.floatTensor(self.in_channels * int(np.prod(self.kernel_size)),
                             self.dim))

        self.rows_mean = Parameter(
            self.floatTensor(
                self.out_channels * int(np.prod(self.kernel_size)) // groups,
                self.dim))
        self.rows_logvar = Parameter(
            self.floatTensor(
                self.out_channels * int(np.prod(self.kernel_size)) // groups,
                self.dim))

        self.alpha_mean = Parameter(
            self.floatTensor(self.out_channels // groups, self.in_channels))
        self.alpha_logvar = Parameter(
            self.floatTensor(self.out_channels // groups, self.in_channels))

        self.use_bias = use_bias
        if self.use_bias:
            self.bias_mean = Parameter(
                self.floatTensor(self.out_channels // self.groups))
            self.bias_logvar = Parameter(
                self.floatTensor(self.out_channels // self.groups))

        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode='fan_in')

        if hasattr(self, 'columns_mean'):
            self.columns_mean.data.normal_(std=self.prior_std)
            self.columns_logvar.data.normal_(std=self.prior_std)

        if hasattr(self, 'rows_mean'):
            self.rows_mean.data.normal_(std=self.prior_std)
            self.rows_logvar.data.normal_(std=self.prior_std)

        if hasattr(self, 'alpha_mean'):
            self.alpha_mean.data.normal_(std=self.prior_std)
            self.alpha_logvar.data.normal_(std=self.prior_std)

        if hasattr(self, 'bias_mean'):
            self.bias_mean.data.normal_(std=self.bias_std)
            self.bias_logvar.data.normal_(std=self.bias_std)

    def _calc_rbf_weights(self, rows: torch.Tensor, columns: torch.Tensor,
                          alpha: torch.Tensor) -> Parameter:
        w = self.floatTensor(self.out_channels // self.groups,
                             self.in_channels, np.prod(self.kernel_size))

        for i in range(int(np.prod(self.kernel_size))):
            row_start = i * (self.out_channels // self.groups)
            row_stop = (i + 1) * (self.out_channels // self.groups)
            col_start = i * self.in_channels
            col_stop = (i + 1) * self.in_channels

            x2 = rows[row_start:row_stop, :].pow(2).sum(dim=1).view(
                self.out_channels // self.groups, 1)
            y2 = columns[col_start:col_stop, :].pow(2).sum(dim=1).view(
                1, self.in_channels)
            xy = rows[row_start:row_stop, :].mm(
                columns[col_start:col_stop, :].t()).mul(-2.)

            w[:, :, i] = x2.add(y2).add(xy).mul(-1).exp().mul(alpha)

        return w.view(self.out_channels // self.groups, self.in_channels,
                      *self.kernel_size)

    def _sample_eps(self, shape: tuple):
        return Variable(self.floatTensor(shape).normal_())

    def _eq_logpw(self, prior_std: float, mean: torch.Tensor,
                  logvar: torch.Tensor) -> torch.Tensor:
        logpw = logvar.exp().add(mean**2).div(prior_std**2).add(
            math.log(2. * math.pi * (prior_std**2))).mul(-0.5)
        return torch.sum(logpw)

    def _eq_logqw(self, logvar: torch.Tensor):
        logqw = logvar.add(math.log(2. * math.pi)).add(1.).mul(-0.5)
        return torch.sum(logqw)

    def eq_logpw(self) -> torch.Tensor:
        rows = self._eq_logpw(prior_std=self.prior_std,
                              mean=self.rows_mean,
                              logvar=self.rows_logvar)
        columns = self._eq_logpw(prior_std=self.prior_std,
                                 mean=self.columns_mean,
                                 logvar=self.columns_logvar)
        alpha = self._eq_logpw(prior_std=self.prior_std,
                               mean=self.alpha_mean,
                               logvar=self.alpha_logvar)
        logpw = rows.add(columns).add(alpha)

        if self.use_bias:
            bias = self._eq_logpw(prior_std=self.prior_std,
                                  mean=self.bias_mean,
                                  logvar=self.bias_logvar)
            logpw = logpw.add(bias)
        return logpw

    def eq_logqw(self):
        rows = self._eq_logqw(logvar=self.rows_logvar)
        columns = self._eq_logqw(logvar=self.columns_logvar)
        alpha = self._eq_logqw(logvar=self.alpha_logvar)
        logqw = rows.add(columns).add(alpha)

        if self.use_bias:
            bias = self._eq_logqw(logvar=self.bias_logvar)
            logqw = logqw.add(bias)
        return logqw

    def kldiv(self):
        return self.eq_logpw() - self.eq_logqw()

    def kldiv_aux(self) -> float:
        return 0.

    def forward(self, input: torch.Tensor):

        rows = self.rows_mean
        columns = self.columns_mean
        alpha = self.alpha_mean

        if self.training:
            rows = self.rows_mean.add(
                self.rows_logvar.mul(0.5).exp().mul(
                    self._sample_eps(rows.shape)))
            columns = self.columns_mean.add(
                self.columns_logvar.mul(0.5).exp().mul(
                    self._sample_eps(columns.shape)))
            alpha = self.alpha_mean.add(
                self.alpha_logvar.mul(0.5).exp().mul(
                    self._sample_eps(alpha.shape)))

        weight = self._calc_rbf_weights(rows=rows,
                                        columns=columns,
                                        alpha=alpha)

        bias = None
        if self.use_bias:
            bias = self.bias_mean
            if self.training:
                bias = self.bias_mean.add(
                    self.bias_logvar.mul(0.5).exp().mul(
                        self._sample_eps(self.bias_logvar.shape)))

        return F.conv2d(input, weight, bias, self.stride, self.padding,
                        self.dilation, self.groups)

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size} '
             ', stride={stride}, dim={dim}')
        if self.padding != (0, ) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1, ) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0, ) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
class FFGaussConv2d(Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 prior_std=1,
                 **kwargs):
        super(FFGaussConv2d, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = pair(kernel_size)
        self.stride = pair(stride)
        self.padding = pair(padding)
        self.dilation = pair(dilation)
        self.output_padding = pair(0)
        self.groups = groups
        self.prior_std = prior_std
        self.use_bias = False
        self.mean_w = Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))
        self.logvar_w = Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))

        if bias:
            self.mean_bias = Parameter(torch.Tensor(out_channels))
            self.logvar_bias = Parameter(torch.Tensor(out_channels))
            self.use_bias = True
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.mean_w, mode='fan_in')
        self.logvar_w.data.normal_(-9., 1e-4)
        if self.use_bias:
            self.mean_bias.data.zero_()
            self.logvar_bias.data.normal_(-9., 1e-4)

    def constrain_parameters(self, thres_std=1.):
        self.logvar_w.data.clamp_(max=2. * math.log(thres_std))
        if self.use_bias:
            self.logvar_bias.data.clamp_(max=2. * math.log(thres_std))

    def eq_logpw(self):
        logpw = -.5 * math.log(
            2 * math.pi * self.prior_std**2) - .5 * self.logvar_w.exp().div(
                self.prior_std**2)
        logpw -= .5 * self.mean_w.pow(2).div(self.prior_std**2)
        logpb = 0.
        if self.use_bias:
            logpb = -.5 * math.log(2 * math.pi * self.prior_std**
                                   2) - .5 * self.logvar_bias.exp().div(
                                       self.prior_std**2)
            logpb -= .5 * self.mean_bias.pow(2).div(self.prior_std**2)
        return torch.sum(logpw) + torch.sum(logpb)

    def eq_logqw(self):
        logqw = -torch.sum(.5 * (math.log(2 * math.pi) + self.logvar_w + 1))
        logqb = 0.
        if self.use_bias:
            logqb = -torch.sum(.5 *
                               (math.log(2 * math.pi) + self.logvar_bias + 1))
        return logqw + logqb

    def kldiv_aux(self):
        return 0.

    def kldiv(self):
        return self.kldiv_aux() + self.eq_logpw() - self.eq_logqw()

    def get_eps(self, size):
        eps = self.floatTensor(size).normal_()
        eps = Variable(eps)
        return eps

    def sample_W(self):
        W = self.mean_w
        if self.training:
            eps = self.get_eps(self.mean_w.size())
            W = W.add(eps.mul(self.logvar_w.mul(0.5).exp_()))
        return W

    def sample_b(self):
        if not self.use_bias:
            return None
        b = self.mean_bias
        if self.training:
            eps = self.get_eps(self.mean_bias.size())
            b = b.add(eps.mul(self.logvar_bias.mul(0.5).exp_()))
        return b

    def forward(self, input_):
        W = self.sample_W()
        b = self.sample_b()

        return F.conv2d(input_, W, b, self.stride, self.padding, self.dilation,
                        self.groups)

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}, prior_std={prior_std}')
        if self.padding != (0, ) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1, ) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0, ) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if not self.use_bias:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
class HSConv2d(Module):
    '''Input channel noise'''
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 prior_std=1.,
                 prior_std_z=1.,
                 dof=1.,
                 **kwargs):
        super(HSConv2d, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = pair(kernel_size)
        self.stride = pair(stride)
        self.padding = pair(padding)
        self.dilation = pair(dilation)
        self.output_padding = pair(0)
        self.groups = groups
        self.prior_std = prior_std
        self.prior_std_z = prior_std_z
        self.use_bias = False
        self.dof = dof
        self.mean_w = Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))
        self.logvar_w = Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))
        self.qz_mean = Parameter(torch.Tensor(in_channels // groups))
        self.qz_logvar = Parameter(torch.Tensor(in_channels // groups))
        self.dim_z = in_channels // groups

        if bias:
            self.mean_bias = Parameter(torch.Tensor(out_channels))
            self.logvar_bias = Parameter(torch.Tensor(out_channels))
            self.use_bias = True
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.mean_w, mode='fan_in')
        self.logvar_w.data.normal_(-9., 1e-4)
        self.qz_mean.data.normal_(math.log(math.exp(1) - 1), 1e-3)
        self.qz_logvar.data.normal_(math.log(0.1), 1e-4)

        if self.use_bias:
            self.mean_bias.data.normal_(0, 1e-2)
            self.logvar_bias.data.normal_(-9., 1e-4)

    def constrain_parameters(self, thres_std=1.):
        self.logvar_w.data.clamp_(max=2. * math.log(thres_std))
        if self.use_bias:
            self.logvar_bias.data.clamp_(max=2. * math.log(thres_std))

    def eq_logpw(self):
        logpw = -.5 * math.log(
            2 * math.pi * self.prior_std**2) - .5 * self.logvar_w.exp().div(
                self.prior_std**2)
        logpw -= .5 * self.mean_w.pow(2).div(self.prior_std**2)
        logpb = 0.
        if self.use_bias:
            logpb = -.5 * math.log(2 * math.pi * self.prior_std**
                                   2) - .5 * self.logvar_bias.exp().div(
                                       self.prior_std**2)
            logpb -= .5 * self.mean_bias.pow(2).div(self.prior_std**2)
            logpb = torch.sum(logpb)
        return torch.sum(logpw) + logpb

    def eq_logqw(self):
        logqw = -torch.sum(.5 * (math.log(2 * math.pi) + self.logvar_w + 1))
        logqb = 0.
        if self.use_bias:
            logqb = -torch.sum(.5 *
                               (math.log(2 * math.pi) + self.logvar_bias + 1))
        return logqw + logqb

    def kldiv_aux(self):
        z = self.sample_z(1)
        z = z.view(self.dim_z)

        logqm = -torch.sum(.5 * (math.log(2 * math.pi) + self.qz_logvar + 1))
        logqm = logqm.add(-torch.sum(F.sigmoid(z.exp().add(-1).log()).log()))

        logpm = torch.sum(
            2 * math.lgamma(.5 * (self.dof + 1)) - math.lgamma(.5 * self.dof) -
            math.log(self.prior_std_z) - .5 * math.log(self.dof * math.pi) -
            .5 * (self.dof + 1) * torch.log(1. + z.pow(2) /
                                            (self.dof * self.prior_std_z**2)))

        return logpm - logqm

    def kldiv(self):
        return self.kldiv_aux() + self.eq_logpw() - self.eq_logqw()

    def get_eps(self, size):
        eps = self.floatTensor(size).normal_()
        eps = Variable(eps)
        return eps

    def sample_z(self, batch_size):
        z = self.qz_mean.view(1, self.dim_z).expand(batch_size, self.dim_z)
        if self.training:
            eps = self.get_eps(self.floatTensor(batch_size, self.dim_z))
            z = z.add(
                eps.mul(
                    self.qz_logvar.view(1, self.dim_z).expand(
                        batch_size, self.dim_z).mul(0.5).exp_()))
        z = z.contiguous().view(batch_size, self.dim_z, 1, 1)
        return F.softplus(z)

    def sample_W(self):
        W = self.mean_w
        if self.training:
            eps = self.get_eps(self.mean_w.size())
            W = W.add(eps.mul(self.logvar_w.mul(0.5).exp_()))
        return W

    def sample_b(self):
        if not self.use_bias:
            return None
        b = self.mean_bias
        if self.training:
            eps = self.get_eps(self.mean_bias.size())
            b = b.add(eps.mul(self.logvar_bias.mul(0.5).exp_()))
        return b

    def forward(self, input_):
        z = self.sample_z(input_.size(0))
        W = self.sample_W()
        b = self.sample_b()
        return F.conv2d(input_.mul(z.expand_as(input_)), W, b, self.stride,
                        self.padding, self.dilation, self.groups)

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}, prior_std_z={prior_std_z}, dof={dof}')
        if self.padding != (0, ) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1, ) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0, ) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if not self.use_bias:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
class OrthogonalBayesianConv2d(ConvNd):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 use_bias=True,
                 simple=True,
                 add_diagonal=True,
                 weight_decay=1.,
                 prior_std=1.,
                 bias_std=1e-2,
                 **kwargs):
        kernel_size = pair(kernel_size)
        stride = pair(stride)
        padding = pair(padding)
        dilation = pair(dilation)
        self.weight_decay = weight_decay
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor
        self.device = torch.device(
            'cpu') if not torch.cuda.is_available() else torch.device('cuda')
        self.simple = simple
        self.add_diagonal = add_diagonal
        super(OrthogonalBayesianConv2d,
              self).__init__(in_channels, out_channels,
                             kernel_size, stride, padding, dilation, False,
                             pair(0), groups, use_bias)

        self.in_channels = in_channels
        self.out_channels = out_channels // groups
        self.kernel_size = kernel_size[0]

        self.prior_std = prior_std
        self.bias_std = bias_std

        if simple:
            self.r_mean = Parameter(
                self.floatTensor(self.kernel_size * self.kernel_size,
                                 self.out_channels))
            self.r_logvar = Parameter(
                self.floatTensor(self.kernel_size * self.kernel_size,
                                 self.out_channels))
        else:
            self.r_mean = Parameter(self.floatTensor(2, self.out_channels))
            self.r_logvar = Parameter(self.floatTensor(2, self.out_channels))
            self.t_mean = Parameter(
                self.floatTensor(2 * (self.kernel_size - 1),
                                 self.out_channels))
            self.t_logvar = Parameter(
                self.floatTensor(2 * (self.kernel_size - 1),
                                 self.out_channels))

        if self.add_diagonal:
            self.d_mean = Parameter(
                self.floatTensor(self.kernel_size, self.kernel_size,
                                 min(self.in_channels, self.out_channels)))
            self.d_logvar = Parameter(
                self.floatTensor(self.kernel_size, self.kernel_size,
                                 min(self.in_channels, self.out_channels)))

        self.use_bias = use_bias
        if self.use_bias:
            self.bias_mean = Parameter(
                self.floatTensor(self.out_channels // self.groups))
            self.bias_logvar = Parameter(
                self.floatTensor(self.out_channels // self.groups))

        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        if hasattr(self, 'r_mean'):
            torch.nn.init.orthogonal_(self.r_mean)
            self.r_logvar.data.normal_(mean=-5., std=self.prior_std)

        if hasattr(self, 't_mean'):
            torch.nn.init.orthogonal_(self.t_mean)
            self.t_logvar.data.normal_(mean=-5., std=self.prior_std)

        if hasattr(self, 'd_mean'):
            self.d_mean.data.normal_(std=self.prior_std)
            self.d_logvar.data.normal_(mean=-5, std=self.prior_std)

        if hasattr(self, 'bias_mean'):
            self.bias_mean.data.normal_(std=self.bias_std)
            self.bias_logvar.data.normal_(mean=-5., std=self.bias_std)

    def constrain_parameters(self, thres_std=1.):
        pass

    def _sample_eps(self, shape: tuple):
        return Variable(self.floatTensor(shape).normal_())

    def _eq_logpw(self, prior_mean: float, prior_std: float,
                  mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        logpw = logvar.exp().add((prior_mean - mean)**2).div(prior_std**2).add(
            math.log(2. * math.pi * (prior_std**2))).mul(-0.5)
        return torch.sum(logpw)

    def _eq_logqw(self, logvar: torch.Tensor):
        logqw = logvar.add(math.log(2. * math.pi)).add(1.).mul(-0.5)
        return torch.sum(logqw)

    def eq_logpw(self) -> torch.Tensor:
        r = self._eq_logpw(prior_mean=0.,
                           prior_std=self.prior_std,
                           mean=self.r_mean,
                           logvar=self.r_logvar)

        if self.add_diagonal:
            d = self._eq_logpw(prior_mean=0.,
                               prior_std=self.prior_std,
                               mean=self.d_mean,
                               logvar=self.d_logvar)
            logpw = r.add(d)
        else:
            logpw = r

        if not self.simple:
            t = self._eq_logpw(prior_mean=0.,
                               prior_std=self.prior_std,
                               mean=self.t_mean,
                               logvar=self.t_logvar)
            logpw = logpw.add(t)

        if self.use_bias:
            bias = self._eq_logpw(prior_mean=0.,
                                  prior_std=self.prior_std,
                                  mean=self.bias_mean,
                                  logvar=self.bias_logvar)
            logpw = logpw.add(bias)

        return logpw

    def eq_logqw(self):
        r = self._eq_logqw(logvar=self.r_logvar)

        if self.add_diagonal:
            d = self._eq_logqw(logvar=self.d_logvar)
            logqw = r.add(d)
        else:
            logqw = r

        if not self.simple:
            t = self._eq_logqw(logvar=self.t_logvar)
            logqw = logqw.add(t)

        if self.use_bias:
            bias = self._eq_logqw(logvar=self.bias_logvar)
            logqw = logqw.add(bias)
        return logqw

    def kldiv(self):
        return self.eq_logpw() - self.eq_logqw()

    def kldiv_aux(self) -> float:
        return 0.

    def _chain_multiply(self, t: torch.Tensor) -> torch.Tensor:
        p = t[0]
        for i in range(1, t.shape[0]):
            p = p.mm(t[i])
        return p

    def _calc_householder_tensor(self, t: torch.Tensor) -> torch.Tensor:
        norm = t.norm(p=2, dim=1)
        t = t.div(norm.unsqueeze(1))
        h = torch.einsum('ab,ac->abc', (t, t))
        return torch.eye(n=t.shape[1],
                         device=self.device).expand_as(h) - h.mul(2.)

    def _calc_block_orthogonal_tensor(self, size: int, t: torch.Tensor):
        h = torch.zeros(size, 2, 2, t.shape[1], t.shape[2], device=self.device)

        for i in range(size):
            p = t[2 * i]
            q = t[2 * i + 1]
            pq = p.mm(q)
            h[i, 0, 0] = pq
            h[i, 0, 1] = p.sub(pq)
            h[i, 1, 0] = q.sub(pq)
            h[i, 1, 1] = torch.eye(p.shape[0],
                                   device=self.device).add(pq).sub(p).sub(q)
        return h

    def _matrix_conv(self, s: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        assert s[0, 0].shape[0] == t[0, 0].shape[0]

        n = s[0, 0].shape[0]
        k = int(np.sqrt(s.shape[0] * s.shape[1]))
        l = int(np.sqrt(t.shape[0] * t.shape[1]))
        size = k + l - 1

        result = torch.zeros(size, size, n, n, device=self.device)

        for i in range(size):
            for j in range(size):
                for index1 in range(min(k, i + 1)):
                    for index2 in range(min(k, j + 1)):
                        if (i - index1) < l and (j - index2) < l:
                            result[i, j] += torch.mm(s[index1, index2],
                                                     t[i - index1, j - index2])
        return result

    def _orthogonal_kernel(self,
                           r: torch.Tensor,
                           t: torch.Tensor,
                           d: torch.Tensor = None,
                           transpose: bool = False) -> torch.Tensor:
        assert self.in_channels <= self.out_channels

        if self.in_channels <= self.out_channels:
            d = torch.einsum('abc,cd->abcd', (d,
                                              torch.eye(n=self.in_channels,
                                                        m=self.out_channels,
                                                        device=self.device)))
        else:
            d = torch.einsum('ab, cda->cdab',
                             (torch.eye(n=self.in_channels,
                                        m=self.out_channels,
                                        device=self.device), d))

        if self.simple:
            q = self._calc_householder_tensor(r).view(self.kernel_size,
                                                      self.kernel_size,
                                                      self.out_channels,
                                                      self.out_channels)
        else:
            r = self._chain_multiply(self._calc_householder_tensor(r))

            if self.kernel_size == 1:
                return torch.unsqueeze(torch.unsqueeze(r, 0), 0)

            t = self._calc_block_orthogonal_tensor(
                size=self.kernel_size - 1, t=self._calc_householder_tensor(t))

            s = t[0]
            for i in range(1, self.kernel_size - 1):
                s = self._matrix_conv(s=s, t=t[i])

            q = torch.einsum('ab,debc->deac', (r, s))

        q = torch.einsum('abcd,abde->abce', (d, q))

        if transpose:
            q = q.permute(2, 3, 1, 0)
        else:
            q = q.permute(3, 2, 1, 0)

        return q

    def forward(self, input_):
        r = self.r_mean.add(
            self.r_logvar.mul(0.5).exp().mul(
                self._sample_eps(self.r_logvar.shape)))

        if self.add_diagonal:
            d = self.d_mean.add(
                self.d_logvar.mul(0.5).exp().mul(
                    self._sample_eps(self.d_logvar.shape)))
        else:
            d = torch.ones(self.kernel_size,
                           self.kernel_size,
                           min(self.in_channels, self.out_channels),
                           device=self.device)

        if self.simple:
            weight = self._orthogonal_kernel(r=r, t=None, d=d)
        else:
            t = self.t_mean.add(
                self.t_logvar.mul(0.5).exp().mul(
                    self._sample_eps(self.t_logvar.shape)))
            weight = self._orthogonal_kernel(r=r, t=t, d=d)

        bias = None
        if self.use_bias:
            bias = self.bias_mean.add(
                self.bias_logvar.mul(0.5).exp().mul(
                    self._sample_eps(self.bias_logvar.shape)))

        return F.conv2d(input_, weight, bias, self.stride, self.padding,
                        self.dilation, self.groups)

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size} '
             ', stride={stride}, weight_decay={weight_decay}')
        if self.padding != (0, ) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1, ) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0, ) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        s += ', simple={simple}'
        s += ', add_diagonal={add_diagonal}'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)