class VariationalDropout(nn.Module):
    def __init__(self, input_size, out_size, log_sigma2=-10, threshold=3):
        """
        :param input_size: An int of input size
        :param log_sigma2: Initial value of log sigma^2.
               It is crucial for training since it determines the initial value of alpha
        :param threshold: Threshold used at validation/inference: if log_alpha > threshold, the weight is zeroed
        :param out_size: An int of output size
        """
        super(VariationalDropout, self).__init__()

        self.input_size = input_size
        self.out_size = out_size

        self.theta = Parameter(t.FloatTensor(input_size, out_size))
        self.bias = Parameter(t.Tensor(out_size))

        self.log_sigma2 = Parameter(t.FloatTensor(input_size, out_size).fill_(log_sigma2))

        self.reset_parameters()

        self.threshold = threshold

    def forward(self, input): # Local Reparameterization Trick
        log_alpha = self.clip(self.log_sigma2 - t.log(self.theta ** 2))
        kld = self.kld(log_alpha)

        if not self.training:
            mask = log_alpha > self.threshold
            return t.addmm(self.bias, input, self.theta.masked_fill(mask, 0))

        mu = t.mm(input, self.theta)
        std = t.sqrt(t.mm(input ** 2, self.log_sigma2.exp()) + 1e-6)

        eps = Variable(t.randn(*mu.size()))
        if input.is_cuda:
            eps = eps.cuda()

        return std * eps + mu + self.bias, kld


    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.out_size)

        self.theta.data.uniform_(-stdv, stdv)
        self.bias.data.uniform_(-stdv, stdv)

    def clip(self, input, to=8):
        input = input.masked_fill(input < -to, -to)
        input = input.masked_fill(input > to, to)

        return input

    def kld(self, log_alpha):  # KL approximation from "Variational Dropout Sparsifies Deep Neural Networks" (Molchanov et al., 2017)
        k = [0.63576, 1.87320, 1.48695]

        first_term = k[0] * t.sigmoid(k[1] + k[2] * log_alpha)
        second_term = 0.5 * t.log(1 + t.exp(-log_alpha))

        return - (first_term - second_term - k[0]).sum() / (self.input_size * self.out_size)
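For reference, `kld` above implements the approximation of the negative KL divergence from Molchanov et al., "Variational Dropout Sparsifies Deep Neural Networks" (their Eq. 14), averaged over the weight matrix:

$$
-D_{KL}\!\left(q(w_{ij})\,\|\,p(w_{ij})\right) \;\approx\; k_1\,\sigma\!\left(k_2 + k_3 \log\alpha_{ij}\right) \;-\; \tfrac{1}{2}\log\!\left(1 + e^{-\log\alpha_{ij}}\right) \;-\; k_1,
\qquad k_1 = 0.63576,\; k_2 = 1.87320,\; k_3 = 1.48695,
$$

where $\alpha_{ij} = \sigma_{ij}^2/\theta_{ij}^2$, so $\log\alpha_{ij} = \log\sigma_{ij}^2 - \log\theta_{ij}^2$ (clipped to $[-8, 8]$ by `clip`); the method returns the resulting positive KL contribution averaged over the `input_size * out_size` weights.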
Example #2
class VDropLinear(Module):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.w_mu = Parameter(torch.Tensor(out_features, in_features))
        init.kaiming_normal_(self.w_mu, mode="fan_out")

        self.w_logsigma2 = Parameter(torch.Tensor(out_features, in_features))
        self.w_logsigma2.data.fill_(-10)

        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
            self.bias.data.fill_(0)
        else:
            self.bias = None

        self.threshold = 3
        self.epsilon = 1e-8
        self.tensor = (torch.FloatTensor if not torch.cuda.is_available()
                       else torch.cuda.FloatTensor)

    def compute_mask(self):
        w_logalpha = self.w_logsigma2 - (self.w_mu ** 2 + self.epsilon).log()
        return (w_logalpha < self.threshold).float()

    def forward(self, x):
        if self.training:
            y_mu = F.linear(x, self.w_mu, self.bias)

            # Avoid sqrt(0), otherwise a divide-by-zero occurs during backprop.
            y_sigma = F.linear(
                x ** 2, self.w_logsigma2.exp()
            ).clamp(self.epsilon).sqrt()

            rv = self.tensor(y_mu.size()).normal_()
            return y_mu + (rv * y_sigma)
        else:
            return F.linear(x, self.w_mu * self.compute_mask(), self.bias)

    def regularization(self):
        k1, k2, k3 = 0.63576, 1.8732, 1.48695
        w_logalpha = self.w_logsigma2 - (self.w_mu ** 2 + self.epsilon).log()

        return -(k1 * torch.sigmoid(k2 + k3 * w_logalpha)
                 - 0.5 * F.softplus(-w_logalpha) - k1).sum()

    def get_inference_nonzeros(self):
        return self.compute_mask().int().sum(dim=1)

    def count_inference_flops(self):
        # For each unit, multiply with its n inputs then do n - 1 additions.
        # To capture the -1, subtract it, but only in cases where there is at
        # least one weight.
        nz_by_unit = self.get_inference_nonzeros()
        multiplies = torch.sum(nz_by_unit)
        adds = multiplies - torch.sum(nz_by_unit > 0)
        return multiplies.item(), adds.item()

    def weight_size(self):
        return self.w_mu.size()
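A minimal usage sketch for the `VDropLinear` layer above, assuming its own imports (`Module`, `Parameter`, `init`, `F`) are already in scope from the listing; the KL weight `kl_scale`, the dummy data, and the optimizer choice are illustrative assumptions, not part of the original example.

import torch
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
layer = VDropLinear(in_features=20, out_features=5).to(device)
optimizer = torch.optim.Adam(layer.parameters(), lr=1e-3)
kl_scale = 1e-4                                      # assumed regularizer weight

x = torch.randn(32, 20, device=device)               # dummy batch
target = torch.randint(0, 5, (32,), device=device)   # dummy labels

layer.train()
logits = layer(x)                                    # stochastic forward pass
loss = F.cross_entropy(logits, target) + kl_scale * layer.regularization()
optimizer.zero_grad()
loss.backward()
optimizer.step()

layer.eval()
with torch.no_grad():
    logits = layer(x)                                # deterministic, masked forward
    print(layer.get_inference_nonzeros())            # surviving weights per output unit
    print(layer.count_inference_flops())             # (multiplies, adds) at inference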
Example #3
class OptNetEq(nn.Module):
    def __init__(self, n, Qpenalty, qp_solver, trueInit=False):
        super().__init__()

        self.qp_solver = qp_solver

        nx = (n**2)**3
        self.Q = Variable(Qpenalty * torch.eye(nx).double().cuda())
        self.Q_idx = spa.csc_matrix(self.Q.detach().cpu().numpy()).nonzero()

        self.G = Variable(-torch.eye(nx).double().cuda())
        self.h = Variable(torch.zeros(nx).double().cuda())
        t = get_sudoku_matrix(n)

        if trueInit:
            self.A = Parameter(torch.DoubleTensor(t).cuda())
        else:
            self.A = Parameter(torch.rand(t.shape).double().cuda())
        self.log_z0 = Parameter(torch.zeros(nx).double().cuda())
        # self.b = Variable(torch.ones(self.A.size(0)).double().cuda())

        if self.qp_solver == 'osqpth':
            t = torch.cat((self.A, self.G), dim=0)
            self.AG_idx = spa.csc_matrix(t.detach().cpu().numpy()).nonzero()

    # @profile
    def forward(self, puzzles):
        nBatch = puzzles.size(0)

        p = -puzzles.view(nBatch, -1)
        b = self.A.mv(self.log_z0.exp())

        if self.qp_solver == 'qpth':
            y = QPFunction(verbose=-1)(self.Q, p.double(), self.G, self.h,
                                       self.A, b).float().view_as(puzzles)
        elif self.qp_solver == 'osqpth':
            _l = torch.cat((b,
                            torch.full(self.h.shape,
                                       float('-inf'),
                                       device=self.h.device,
                                       dtype=self.h.dtype)),
                           dim=0)
            _u = torch.cat((b, self.h), dim=0)
            Q_data = self.Q[self.Q_idx[0], self.Q_idx[1]]

            AG = torch.cat((self.A, self.G), dim=0)
            AG_data = AG[self.AG_idx[0], self.AG_idx[1]]
            y = OSQP(self.Q_idx,
                     self.Q.shape,
                     self.AG_idx,
                     AG.shape,
                     diff_mode=DiffModes.FULL)(Q_data, p.double(), AG_data, _l,
                                               _u).float().view_as(puzzles)
        else:
            assert False

        return y
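For context, the forward pass above solves, for each puzzle in the batch, a quadratic program in the standard form accepted by `qpth`'s `QPFunction` (and, with stacked lower/upper bounds, by `osqpth`), where $p$ is the negated flattened puzzle and $b = A\,e^{z_0}$:

$$
\min_{y}\ \tfrac{1}{2}\,y^{\top} Q\, y + p^{\top} y
\quad\text{s.t.}\quad A y = b,\qquad G y \le h,
$$

with $Q = \text{Qpenalty}\cdot I$, $G = -I$, and $h = 0$, so the inequality constraints simply enforce $y \ge 0$.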
Example #4
class VDropLinear(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        init.kaiming_normal_(self.weight, mode="fan_out")

        self.w_logvar = Parameter(torch.Tensor(out_features, in_features))
        self.w_logvar.data.fill_(-10)

        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
            self.bias.data.fill_(0)
        else:
            self.bias = None

        self.threshold = 3
        self.epsilon = 1e-8
        self.tensor_constructor = (torch.FloatTensor
                                   if not torch.cuda.is_available()
                                   else torch.cuda.FloatTensor)

    def constrain_parameters(self):
        self.w_logvar.data.clamp_(min=-10., max=10.)

    def compute_mask(self):
        w_logalpha = self.w_logvar - (self.weight ** 2 + self.epsilon).log()
        return (w_logalpha < self.threshold).float()

    def forward(self, x):
        if self.training:
            return vdrop_linear_forward(x,
                                        lambda: self.weight,
                                        lambda: self.w_logvar.exp(),
                                        self.bias, self.tensor_constructor,
                                        self.epsilon)
        else:
            return F.linear(x, self.weight * self.compute_mask(), self.bias)

    def regularization(self):
        w_logalpha = self.w_logvar - (self.weight ** 2 + self.epsilon).log()
        return vdrop_regularization(w_logalpha).sum()

    def get_inference_nonzeros(self):
        return self.compute_mask().int().sum(dim=1)

    def count_inference_flops(self):
        # For each unit, multiply with its n inputs then do n - 1 additions.
        # To capture the -1, subtract it, but only in cases where there is at
        # least one weight.
        nz_by_unit = self.get_inference_nonzeros()
        multiplies = torch.sum(nz_by_unit)
        adds = multiplies - torch.sum(nz_by_unit > 0)
        return multiplies.item(), adds.item()

    def weight_size(self):
        return self.weight.size()
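This variant delegates to `vdrop_linear_forward` and `vdrop_regularization`, neither of which is included in the excerpt. By analogy with the explicit `regularization()` method in Example #2, a plausible sketch of the elementwise helper (the callers apply `.sum()` or per-group weighting themselves) might be:

import torch
import torch.nn.functional as F

def vdrop_regularization(w_logalpha):
    # Elementwise negative of the Molchanov et al. (2017) KL approximation;
    # callers sum (and optionally weight) the result afterwards.
    k1, k2, k3 = 0.63576, 1.8732, 1.48695
    return -(k1 * torch.sigmoid(k2 + k3 * w_logalpha)
             - 0.5 * F.softplus(-w_logalpha) - k1)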
class VariationalDropout(nn.Module):
    def __init__(self, input_size, out_size, log_sigma2=-10, threshold=3):
        """
        :param input_size: An int of input size
        :param log_sigma2: Initial value of log_sigma^2 (crucial for training as it determines initial value of alpha)
        :param threshold: Threshold used at validation/inference: if log_alpha > threshold, the weight is zeroed
        :param out_size: An int of output size
        """
        super(VariationalDropout, self).__init__()

        self.input_size = input_size
        self.out_size = out_size

        self.theta = Parameter(torch.FloatTensor(input_size, out_size))
        self.bias = Parameter(torch.Tensor(out_size))

        self.log_sigma2 = Parameter(
            torch.FloatTensor(input_size, out_size).fill_(log_sigma2))

        self.reset_parameters()

        self.k = [0.63576, 1.87320, 1.48695]

        self.threshold = threshold

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.out_size)
        self.theta.data.uniform_(-stdv, stdv)
        self.bias.data.uniform_(-stdv, stdv)

    @staticmethod
    def clip(input, to=8):
        input = input.masked_fill(input < -to, -to)
        input = input.masked_fill(input > to, to)
        return input

    def kld(self, log_alpha):
        first_term = self.k[0] * F.sigmoid(self.k[1] + self.k[2] * log_alpha)
        second_term = 0.5 * torch.log(1 + torch.exp(-log_alpha))
        return -(first_term - second_term - self.k[0]).sum() / (
            self.input_size * self.out_size)

    def forward(self, input):
        """
        :param input: A float tensor with shape of [batch_size, input_size]
        :return: A float tensor with shape of [batch_size, out_size] and negative layer-kld estimation
        """

        log_alpha = self.clip(self.log_sigma2 - torch.log(self.theta**2))
        kld = self.kld(log_alpha)

        if not self.training:
            mask = log_alpha > self.threshold
            return torch.addmm(self.bias, input,
                               self.theta.masked_fill(mask, 0))

        mu = torch.mm(input, self.theta)
        std = torch.sqrt(torch.mm(input**2, self.log_sigma2.exp()) + 1e-6)

        eps = Variable(torch.randn(*mu.size()))
        if input.is_cuda:
            eps = eps.cuda()

        return std * eps + mu + self.bias, kld

    def max_alpha(self):
        log_alpha = self.log_sigma2 - torch.log(self.theta**2)  # log alpha = log sigma^2 - log theta^2
        return torch.max(log_alpha.exp())
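A minimal usage sketch for this `VariationalDropout` layer (the dummy data, the 0.01 KL weight, and the optimizer are illustrative assumptions; the class itself relies on the listing's imports such as `Parameter`, `Variable`, `F`, and `math`). Note that `forward` returns an `(output, kld)` tuple in training mode but a single tensor in eval mode, so callers must unpack accordingly.

import torch
import torch.nn.functional as F

layer = VariationalDropout(input_size=20, out_size=5)
optimizer = torch.optim.Adam(layer.parameters(), lr=1e-3)

x = torch.randn(32, 20)                              # dummy batch
target = torch.randint(0, 5, (32,))                  # dummy labels

layer.train()
out, kld = layer(x)                                  # stochastic output + KL estimate
loss = F.cross_entropy(out, target) + 0.01 * kld     # assumed KL weight
optimizer.zero_grad()
loss.backward()
optimizer.step()

layer.eval()
with torch.no_grad():
    out = layer(x)                                   # deterministic forward with thresholded weights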
Example #6
class VariationalDropout(nn.Module):
    def __init__(self, input_size, out_size, log_sigma2=-10, threshold=3):
        """
        This module creates a fully connected layer with variational dropout enabled.
        
        :param input_size: An int of input size
        :param log_sigma2: Initial value of log sigma ^ 2.
               It is crucial for training since it determines initial value of alpha
        :param threshold: Threshold used at validation/inference: if log_alpha > threshold, the weight is zeroed
        :param out_size: An int of output size
        """
        super(VariationalDropout, self).__init__()

        self.input_size = input_size
        self.out_size = out_size

        self.theta = Parameter(t.FloatTensor(
            input_size, out_size))  # fully connected weight
        self.bias = Parameter(t.Tensor(out_size))  # bias

        self.log_sigma2 = Parameter(
            t.FloatTensor(input_size, out_size).fill_(
                log_sigma2))  # log-variance of the Gaussian noise, one entry per weight

        self.reset_parameters()

        self.k = [0.63576, 1.87320, 1.48695]

        self.threshold = threshold  # as noted above, used to zero a weight when its noise variance (alpha) grows too large

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.out_size)

        self.theta.data.uniform_(-stdv, stdv)
        self.bias.data.uniform_(-stdv, stdv)

    @staticmethod
    def clip(input, to=8):
        input = input.masked_fill(input < -to, -to)
        input = input.masked_fill(input > to, to)

        return input

    def kld(self, log_alpha):

        first_term = self.k[0] * F.sigmoid(self.k[1] + self.k[2] * log_alpha)
        second_term = 0.5 * t.log(1 + t.exp(-log_alpha))

        return -(first_term - second_term - self.k[0]).sum() / (
            self.input_size * self.out_size)

    def forward(self, input):
        """
        :param input: A float tensor with shape of [batch_size, input_size]
        :return: A float tensor with shape of [batch_size, out_size] and negative layer-kld estimation
        """

        log_alpha = self.clip(self.log_sigma2 - t.log(self.theta**2))
        kld = self.kld(log_alpha)

        if not self.training:
            mask = log_alpha > self.threshold
            return t.addmm(self.bias, input, self.theta.masked_fill(mask, 0))

        mu = t.mm(input, self.theta)
        std = t.sqrt(t.mm(input**2, self.log_sigma2.exp()) + 1e-6)

        eps = Variable(
            t.randn(*mu.size()))  # sample from standard normal distribution
        if input.is_cuda:
            eps = eps.cuda()

        return std * eps + mu + self.bias, kld  # a reparameterization trick to form the Gaussian dropout

    def max_alpha(self):
        log_alpha = self.log_sigma2 - t.log(self.theta**2)  # log alpha = log sigma^2 - log theta^2
        return t.max(log_alpha.exp())
class VariationalDropout(nn.Module):
    def __init__(self, input_size, out_size, log_sigma2=-10, threshold=3):
        """
        :param input_size: An int of input size
        :param log_sigma2: Initial value of log sigma^2.
               It is crucial for training since it determines the initial value of alpha
        :param threshold: Threshold used at validation/inference: if log_alpha > threshold, the weight is zeroed
        :param out_size: An int of output size
        """
        super(VariationalDropout, self).__init__()

        self.input_size = input_size
        self.out_size = out_size

        self.theta = Parameter(t.FloatTensor(input_size, out_size))
        self.bias = Parameter(t.Tensor(out_size))

        self.log_sigma2 = Parameter(
            t.FloatTensor(input_size, out_size).fill_(log_sigma2))

        self.reset_parameters()

        self.k = [0.63576, 1.87320, 1.48695]

        self.threshold = threshold

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.out_size)

        self.theta.data.uniform_(-stdv, stdv)
        self.bias.data.uniform_(-stdv, stdv)

    @staticmethod
    def clip(input, to=8.):
        input = input.masked_fill(input < -to, -to)
        input = input.masked_fill(input > to, to)

        return input

    def kld(self, log_alpha):

        first_term = self.k[0] * F.sigmoid(self.k[1] + self.k[2] * log_alpha)
        second_term = 0.5 * t.log(1 + t.exp(-log_alpha))
        return (first_term - second_term -
                self.k[0]).sum() / (self.input_size * self.out_size)

    def forward(self, input, train):
        """
        :param input: A float tensor with shape of [batch_size, input_size]
        :return: A float tensor with shape of [batch_size, out_size] and negative layer-kld estimation
        """
        log_alpha = self.clip(self.log_sigma2 - t.log(self.theta**2))
        fh = open("log_alpha_values_during_training.txt", 'a')
        fh.write(
            str(self.input_size) + "||||" + str(log_alpha.data.numpy()[0][0]) +
            "\n")
        fh.close()
        #print(log_alpha.data.numpy()[0][0])
        kld = self.kld(log_alpha)

        if not train:
            mask = log_alpha > self.threshold
            if (t.nonzero(mask).dim() != 0):
                zeroed_weights = t.nonzero(mask).size(0)

            else:
                zeroed_weights = 0

            total_weights = mask.size(0) * mask.size(1)
            print('number of zeroed weights is {}'.format(zeroed_weights))
            print('total number of weights is {}'.format(total_weights))
            print('ratio for non zeroed weights is {}'.format(
                (total_weights - zeroed_weights) / total_weights))
            return t.addmm(self.bias, input, self.theta.masked_fill(mask, 0))

        mu = t.mm(input, self.theta)
        std = t.sqrt(t.mm(input**2, self.log_sigma2.exp()) + 1e-6)

        eps = Variable(t.randn(*mu.size()))
        if input.is_cuda:
            eps = eps.cuda()

        return std * eps + mu + self.bias, kld

    def max_alpha(self):
        log_alpha = self.log_sigma2 - t.log(self.theta**2)  # log alpha = log sigma^2 - log theta^2
        return t.max(log_alpha.exp())
class HSConv2d(Module):
    '''Input channel noise'''
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 prior_std=1.,
                 prior_std_z=1.,
                 dof=1.,
                 **kwargs):
        super(HSConv2d, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = pair(kernel_size)
        self.stride = pair(stride)
        self.padding = pair(padding)
        self.dilation = pair(dilation)
        self.output_padding = pair(0)
        self.groups = groups
        self.prior_std = prior_std
        self.prior_std_z = prior_std_z
        self.use_bias = False
        self.dof = dof
        self.mean_w = Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))
        self.logvar_w = Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))
        self.qz_mean = Parameter(torch.Tensor(in_channels // groups))
        self.qz_logvar = Parameter(torch.Tensor(in_channels // groups))
        self.dim_z = in_channels // groups

        if bias:
            self.mean_bias = Parameter(torch.Tensor(out_channels))
            self.logvar_bias = Parameter(torch.Tensor(out_channels))
            self.use_bias = True
        self.floatTensor = (torch.FloatTensor if not torch.cuda.is_available()
                            else torch.cuda.FloatTensor)
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.mean_w, mode='fan_in')
        self.logvar_w.data.normal_(-9., 1e-4)
        self.qz_mean.data.normal_(math.log(math.exp(1) - 1), 1e-3)
        self.qz_logvar.data.normal_(math.log(0.1), 1e-4)

        if self.use_bias:
            self.mean_bias.data.normal_(0, 1e-2)
            self.logvar_bias.data.normal_(-9., 1e-4)

    def constrain_parameters(self, thres_std=1.):
        self.logvar_w.data.clamp_(max=2. * math.log(thres_std))
        if self.use_bias:
            self.logvar_bias.data.clamp_(max=2. * math.log(thres_std))

    def eq_logpw(self):
        logpw = -.5 * math.log(
            2 * math.pi * self.prior_std**2) - .5 * self.logvar_w.exp().div(
                self.prior_std**2)
        logpw -= .5 * self.mean_w.pow(2).div(self.prior_std**2)
        logpb = 0.
        if self.use_bias:
            logpb = -.5 * math.log(2 * math.pi * self.prior_std**
                                   2) - .5 * self.logvar_bias.exp().div(
                                       self.prior_std**2)
            logpb -= .5 * self.mean_bias.pow(2).div(self.prior_std**2)
            logpb = torch.sum(logpb)
        return torch.sum(logpw) + logpb

    def eq_logqw(self):
        logqw = -torch.sum(.5 * (math.log(2 * math.pi) + self.logvar_w + 1))
        logqb = 0.
        if self.use_bias:
            logqb = -torch.sum(.5 *
                               (math.log(2 * math.pi) + self.logvar_bias + 1))
        return logqw + logqb

    def kldiv_aux(self):
        z = self.sample_z(1)
        z = z.view(self.dim_z)

        logqm = -torch.sum(.5 * (math.log(2 * math.pi) + self.qz_logvar + 1))
        logqm = logqm.add(-torch.sum(F.sigmoid(z.exp().add(-1).log()).log()))

        logpm = torch.sum(
            2 * math.lgamma(.5 * (self.dof + 1)) - math.lgamma(.5 * self.dof) -
            math.log(self.prior_std_z) - .5 * math.log(self.dof * math.pi) -
            .5 * (self.dof + 1) * torch.log(1. + z.pow(2) /
                                            (self.dof * self.prior_std_z**2)))

        return logpm - logqm

    def kldiv(self):
        return self.kldiv_aux() + self.eq_logpw() - self.eq_logqw()

    def get_eps(self, size):
        eps = self.floatTensor(size).normal_()
        eps = Variable(eps)
        return eps

    def sample_z(self, batch_size):
        z = self.qz_mean.view(1, self.dim_z).expand(batch_size, self.dim_z)
        if self.training:
            eps = self.get_eps(self.floatTensor(batch_size, self.dim_z))
            z = z.add(
                eps.mul(
                    self.qz_logvar.view(1, self.dim_z).expand(
                        batch_size, self.dim_z).mul(0.5).exp_()))
        z = z.contiguous().view(batch_size, self.dim_z, 1, 1)
        return F.softplus(z)

    def sample_W(self):
        W = self.mean_w
        if self.training:
            eps = self.get_eps(self.mean_w.size())
            W = W.add(eps.mul(self.logvar_w.mul(0.5).exp_()))
        return W

    def sample_b(self):
        if not self.use_bias:
            return None
        b = self.mean_bias
        if self.training:
            eps = self.get_eps(self.mean_bias.size())
            b = b.add(eps.mul(self.logvar_bias.mul(0.5).exp_()))
        return b

    def forward(self, input_):
        z = self.sample_z(input_.size(0))
        W = self.sample_W()
        b = self.sample_b()
        return F.conv2d(input_.mul(z.expand_as(input_)), W, b, self.stride,
                        self.padding, self.dilation, self.groups)

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}, prior_std_z={prior_std_z}, dof={dof}')
        if self.padding != (0, ) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1, ) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0, ) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if not self.use_bias:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
class GHConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True):

        # Init torch module
        super(GHConv2d, self).__init__()

        # Init conv params
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = stride
        self.padding = padding
        self.dilation = dilation

        # init constants according to section 5
        self.t0 = 1e-5

        # Init globals
        self.sa_mu = Parameter(Tensor(1))
        self.sa_logvar = Parameter(Tensor(1))
        self.sb_mu = Parameter(Tensor(1))
        self.sb_logvar = Parameter(Tensor(1))

        # Filter locals
        self.alpha_mu = Parameter(Tensor(out_channels))
        self.alpha_logvar = Parameter(Tensor(out_channels))
        self.beta_mu = Parameter(Tensor(out_channels))
        self.beta_logvar = Parameter(Tensor(out_channels))

        # Weight local
        self.weight_mu = Parameter(Tensor(out_channels, in_channels, *self.kernel_size))
        self.weight_logvar = Parameter(Tensor(out_channels, in_channels, *self.kernel_size))

        # Bias local if required
        self.bias = bias
        self.bias_mu = Parameter(Tensor(out_channels)) if self.bias else None
        self.bias_logvar = Parameter(Tensor(out_channels)) if self.bias else None

        # Set initial parameters
        self._init_params()

        # for brevity to conv2d calls
        self.convargs = [self.stride, self.padding, self.dilation]
    def _s_mu(self):
        return 0.5 * (self.sa_mu + self.sb_mu)

    def _s_var(self):
        return 0.25 * (self.sa_logvar.exp() + self.sb_logvar.exp())

    def _z_var(self):
        return 0.25 * (self.alpha_logvar.exp() + self.beta_logvar.exp())

    def _z_mu(self):
        return 0.5 * (self.alpha_mu + self.beta_mu)

    def forward(self, x):

        # vanilla forward pass if testing
        if not self.training:
            expect_z = torch.exp(0.5 * (self._z_var() + self._s_var()) + self._z_mu() + self._s_mu())
            post_weight_mu = self.weight_mu * expect_z[:, None, None, None]
            post_bias_mu = self.bias_mu * expect_z if (self.bias_mu is not None) else None
            return conv2d(x, post_weight_mu, post_bias_mu, *self.convargs)

        # compute global shrinkage
        s_mu = 0.5 * (self.sa_mu + self.sb_mu)
        s_sig = torch.sqrt(self._s_var())
        s = LogNormal(s_mu, s_sig).rsample()

        # compute filter scales
        z_mu = self._z_mu()
        z_var = self._z_var()
        z = s * LogNormal(z_mu, z_var.sqrt()).rsample()[None, :, None, None]


        # lognormal out params, local reparameterization trick
        bvar = self.bias_logvar.exp() if self.bias else None
        mu_out = conv2d(x, self.weight_mu, self.bias_mu, *self.convargs) * z
        scale_out = conv2d(x**2, self.weight_logvar.exp(), bvar, *self.convargs) * (z ** 2)

        # compute output weight distribution, again reparameterised
        dist_out = Normal(mu_out, scale_out.sqrt()).rsample()

        # return fully reparameterised forward pass
        return dist_out


    def _init_params(self, weight=None, bias=None):

        # initialisation params - note mean of lognormal is exp(mu + 0.5 *var)
        init_mu_logvar, init_mu, init_var = -9, 0., 1e-2

        # compute xavier initialisation on weights
        n = self.in_channels * self.kernel_size[0] * self.kernel_size[1]
        thresh = 1/math.sqrt(n)

        if weight is not None:
            self.weight_mu.data = weight
        else:
            self.weight_mu.data.uniform_(-thresh, thresh)

        # init variance according to appendix A
        self.weight_logvar.data.normal_(init_mu_logvar, init_var)

        if self.bias:
            if bias is not None:
                self.bias_mu.data = bias
            else:
                self.bias_mu.data.fill_(0)

            # biases
            self.bias_logvar.data.normal_(init_mu_logvar, init_var)

        # Decomposed prior means => E[z_init] init ~ 1
        self.alpha_mu.data.normal_(init_mu, init_var)
        self.beta_mu.data.normal_(init_mu, init_var)
        self.sa_mu.data.normal_(init_mu, init_var)
        self.sb_mu.data.normal_(init_mu, init_var)

        # Decomposed prior variances
        self.alpha_logvar.data.normal_(init_mu_logvar, init_var)
        self.beta_logvar.data.normal_(init_mu_logvar, init_var)
        self.sa_logvar.data.normal_(init_mu_logvar, init_var)
        self.sb_logvar.data.normal_(init_mu_logvar, init_var)


    # KL div for GNH with lognormal scale, normal weight variational posterior
    def kl_divergence(self):
        # negative kls, eqns (34-37)
        neg_kl_s = self._global_negative_kl()
        neg_kl_ab = self._filter_local_negative_kl()

        # weight/bias local
        kl_w = self._conditional_kl_div(self.weight_mu, self.weight_logvar)

        if self.bias:
            kl_b = self._conditional_kl_div(self.bias_mu, self.bias_logvar)
        else:
            kl_b = 0

        return kl_w + kl_b - (neg_kl_s + neg_kl_ab)


    def _global_negative_kl(self):

        # hyperparams
        t0 = self.t0

        # const added in every kl div
        c = 1 + math.log(2)

        # shape/scale of global scale parameters
        sa_mu, sb_mu = self.sa_mu, self.sb_mu
        sa_var, sb_var = self.sa_logvar.exp(), self.sb_logvar.exp()

        # Eqns (34)(35)
        kl_sa = math.log(t0) - torch.exp(sa_mu + 0.5 * sa_var)/t0 + 0.5 * (sa_mu + self.sa_logvar + c)
        kl_sb = 0.5 * (self.sb_logvar - sb_mu + c ) - torch.exp(0.5 * sb_var - sb_mu)

        return kl_sa + kl_sb


    def _filter_local_negative_kl(self):

        # const added in every kl div
        c = 1 + math.log(2)

        # hyperparams
        t0 = self.t0

        # filter level shape/scale parameters
        alpha_mu, beta_mu = self.alpha_mu, self.beta_mu
        alpha_logvar, beta_logvar = self.alpha_logvar, self.beta_logvar

        # Eqns (36)(37)
        kl_alpha = torch.sum(0.5 * (alpha_mu + alpha_logvar + c) - torch.exp(alpha_mu + 0.5 * alpha_logvar.exp()))
        kl_beta = torch.sum(0.5 * (beta_logvar - beta_mu + c) - torch.exp(0.5 * beta_logvar.exp() - beta_mu))

        return kl_alpha + kl_beta


    @staticmethod
    def _conditional_kl_div(mu, logvar):
        # eqn (8)
        kl_div = -0.5 * logvar + 0.5 * (logvar.exp() + mu ** 2 - 1)
        return torch.sum(kl_div)
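`_conditional_kl_div` above is the closed-form KL divergence between the Gaussian weight posterior $q(w)=\mathcal{N}(\mu,\sigma^2)$ and a standard-normal conditional prior, summed over all weights (with $\log\sigma^2$ stored as `logvar` in the code):

$$
D_{KL}\!\left(\mathcal{N}(\mu,\sigma^2)\,\|\,\mathcal{N}(0,1)\right)
= -\tfrac{1}{2}\log\sigma^2 + \tfrac{1}{2}\!\left(\sigma^2 + \mu^2 - 1\right).
$$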
Example #10
class VDropCentralData(nn.Module):
    """
    Stores data for a set of variational dropout (VDrop) modules in large
    central tensors. The VDrop modules access the data using views. This makes
    it possible to operate on all of the data at once (rather than e.g. 53
    times with ResNet-50).

    Usage:
    1. Instantiate
    2. Pass into multiple constructed VDropLinear and VDropConv2d modules
    3. Call finalize

    Before calling forward on the model, call "compute_forward_data".
    After calling forward on the model, call "clear_forward_data".

    The parameters are stored in terms of z_mu and z_var rather than w_mu and
    w_var to support group variational dropout (e.g. to allow for pruning entire
    channels.)
    """
    def __init__(self, z_logvar_init=-10):
        super().__init__()
        self.z_chunk_sizes = []
        self.z_logvar_init = z_logvar_init
        self.z_logvar_min = min(z_logvar_init, -10)
        self.z_logvar_max = 10.
        self.epsilon = 1e-8
        self.data_views = {}
        self.modules = []

        # Populated during register(), deleted during finalize()
        self.all_z_mu = []
        self.all_z_logvar = []
        self.all_num_weights = []

        # Populated during finalize()
        self.z_mu = None
        self.z_logvar = None
        self.z_num_weights = None

        self.threshold = 3

    def extra_repr(self):
        s = f"z_logvar_init={self.z_logvar_init}"
        return s

    def __getitem__(self, key):
        return self.data_views[key]

    def register(self, module, z_mu, z_logvar, num_weights_per_z=1):
        self.all_z_mu.append(z_mu.flatten())
        self.all_z_logvar.append(z_logvar.flatten())
        self.all_num_weights.append(num_weights_per_z)

        self.modules.append(module)
        data_index = len(self.z_chunk_sizes)
        self.z_chunk_sizes.append(z_mu.numel())

        return data_index

    def finalize(self):
        self.z_mu = Parameter(torch.cat(self.all_z_mu))
        self.z_logvar = Parameter(torch.cat(self.all_z_logvar))
        self.z_num_weights = torch.tensor(self.all_num_weights,
                                          dtype=torch.float).repeat_interleave(
                                              torch.tensor(self.z_chunk_sizes))
        del self.all_z_mu
        del self.all_z_logvar
        del self.all_num_weights

    def to(self, *args, **kwargs):
        ret = super().to(*args, **kwargs)
        self.z_num_weights = self.z_num_weights.to(*args, **kwargs)
        return ret

    def compute_forward_data(self):
        if self.training:
            self.data_views["z_mu"] = self.z_mu.split(self.z_chunk_sizes)
            self.data_views["z_var"] = self.z_logvar.exp().split(
                self.z_chunk_sizes)
        else:
            self.data_views["z_mu"] = (
                self.z_mu *
                (self.compute_z_logalpha() < self.threshold).float()).split(
                    self.z_chunk_sizes)

    def clear_forward_data(self):
        self.data_views.clear()

    def compute_z_logalpha(self):
        return self.z_logvar - (self.z_mu.square() + self.epsilon).log()

    def regularization(self):
        return (vdrop_regularization(self.compute_z_logalpha()) *
                self.z_num_weights).sum()

    def constrain_parameters(self):
        self.z_logvar.data.clamp_(min=self.z_logvar_min, max=self.z_logvar_max)
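A minimal sketch of the call order described in the docstring, driving `VDropCentralData` directly with dummy tensors instead of through the `VDropLinear`/`VDropConv2d` modules it is normally paired with. The shapes, the placeholder owner module, and the decision to skip `regularization()` (which additionally needs the `vdrop_regularization` helper) are assumptions for illustration; the class itself still needs the listing's `torch` and `Parameter` imports.

import torch
import torch.nn as nn

central = VDropCentralData(z_logvar_init=-10)

# Stand-in for what a VDrop module's constructor would register.
owner = nn.Linear(4, 3)                          # placeholder "owner" module
z_mu = torch.randn(3, 4)
z_logvar = torch.full((3, 4), -10.0)
idx = central.register(owner, z_mu, z_logvar, num_weights_per_z=1)

central.finalize()                               # concatenate into the big central tensors

central.compute_forward_data()                   # call before the model's forward pass
z_mu_chunk = central["z_mu"][idx].view(3, 4)     # a module reads back its own chunk
z_var_chunk = central["z_var"][idx].view(3, 4)
central.clear_forward_data()                     # call after the forward pass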
Example #11
class _ConvNdGroupNJ(Module):
    """Convolutional Group Normal-Jeffrey's layers (aka Group Variational Dropout).

    References:
    [1] Kingma, Diederik P., Tim Salimans, and Max Welling. "Variational dropout and the local reparameterization trick." NIPS (2015).
    [2] Molchanov, Dmitry, Arsenii Ashukha, and Dmitry Vetrov. "Variational Dropout Sparsifies Deep Neural Networks." ICML (2017).
    [3] Louizos, Christos, Karen Ullrich, and Max Welling. "Bayesian Compression for Deep Learning." NIPS (2017).
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding,
                 groups, bias, init_weight, init_bias, cuda=False, clip_var=None):
        super(_ConvNdGroupNJ, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups

        self.cuda = cuda
        self.clip_var = clip_var
        self.deterministic = False  # flag is used for compressed inference

        if transposed:
            self.weight_mu = Parameter(torch.Tensor(
                in_channels, out_channels // groups, *kernel_size))
            self.weight_logvar = Parameter(torch.Tensor(
                in_channels, out_channels // groups, *kernel_size))
        else:
            self.weight_mu = Parameter(torch.Tensor(
                out_channels, in_channels // groups, *kernel_size))
            self.weight_logvar = Parameter(torch.Tensor(
                out_channels, in_channels // groups, *kernel_size))

        self.bias_mu = Parameter(torch.Tensor(out_channels))
        self.bias_logvar = Parameter(torch.Tensor(out_channels))

        self.z_mu = Parameter(torch.Tensor(self.out_channels))
        self.z_logvar = Parameter(torch.Tensor(self.out_channels))

        self.reset_parameters(init_weight, init_bias)

        # activations for kl
        self.sigmoid = nn.Sigmoid()
        self.softplus = nn.Softplus()
        # numerical stability param
        self.epsilon = 1e-8

    def reset_parameters(self, init_weight, init_bias):
        # init means
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)

        # init means
        if init_weight is not None:
            self.weight_mu.data = init_weight
        else:
            self.weight_mu.data.uniform_(-stdv, stdv)

        if init_bias is not None:
            self.bias_mu.data = init_bias
        else:
            self.bias_mu.data.fill_(0)

        # init z
        self.z_mu.data.normal_(1, 1e-2)

        # init logvars
        self.z_logvar.data.normal_(-9, 1e-2)
        self.weight_logvar.data.normal_(-9, 1e-2)
        self.bias_logvar.data.normal_(-9, 1e-2)

    def clip_variances(self):
        if self.clip_var:
            self.weight_logvar.data.clamp_(max=math.log(self.clip_var))
            self.bias_logvar.data.clamp_(max=math.log(self.clip_var))

    def get_log_dropout_rates(self):
        log_alpha = self.z_logvar - torch.log(self.z_mu.pow(2) + self.epsilon)
        return log_alpha

    def compute_posterior_params(self):
        weight_var, z_var = self.weight_logvar.exp(), self.z_logvar.exp()
        self.post_weight_var = self.z_mu.pow(2) * weight_var + z_var * self.weight_mu.pow(2) + z_var * weight_var
        self.post_weight_mu = self.weight_mu * self.z_mu
        return self.post_weight_mu, self.post_weight_var

    def kl_divergence(self):
        # KL(q(z)||p(z))
        # we use the kl divergence approximation given by [2] Eq.(14)
        k1, k2, k3 = 0.63576, 1.87320, 1.48695
        log_alpha = self.get_log_dropout_rates()
        KLD = -torch.sum(k1 * self.sigmoid(k2 + k3 * log_alpha) - 0.5 * self.softplus(-log_alpha) - k1)

        # KL(q(w|z)||p(w|z))
        # we use the kl divergence given by [3] Eq.(8)
        KLD_element = - 0.5 * self.weight_logvar + 0.5 * (self.weight_logvar.exp() + self.weight_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        # KL bias
        KLD_element = - 0.5 * self.bias_logvar + 0.5 * (self.bias_logvar.exp() + self.bias_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        return KLD

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias_mu is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
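The posterior statistics computed in `compute_posterior_params` are the mean and variance of a product of two independent Gaussians, $w \sim \mathcal{N}(\mu_w, \sigma_w^2)$ and $z \sim \mathcal{N}(\mu_z, \sigma_z^2)$:

$$
\mathbb{E}[wz] = \mu_w\,\mu_z,
\qquad
\operatorname{Var}[wz] = \mu_z^2\,\sigma_w^2 + \sigma_z^2\,\mu_w^2 + \sigma_z^2\,\sigma_w^2 .
$$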
Example #12
class TrimConv2d(Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 log_alpha=-10.0,
                 lamda=0.1,
                 h=0):
        super(TrimConv2d, self).__init__()
        if in_channels % groups != 0:
            raise ValueError("in_channels must be divisible by groups")
        if out_channels % groups != 0:
            raise ValueError("out_channels must be divisible by groups")
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.weight = Parameter(
            torch.Tensor(out_channels, in_channels // groups, *kernel_size))
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        # For Bayesian inference
        self.log_alpha = Parameter(
            torch.randn(*self.weight.size()).fill_(log_alpha))

        # KL divergence
        self.c1 = 1.16145124
        self.c2 = -1.50204118
        self.c3 = 0.58629921

        # Trimming Parameters
        self.lamda = lamda  # regularization parameters
        self.n_pnt = out_channels - h  # the number of penalties

        if torch.cuda.is_available():
            self.regw = torch.ones(
                out_channels).cuda()  # output feature map sparsity
        else:
            self.regw = torch.ones(out_channels)

        self.mask = None
        self.bias_mask = None
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_normal_(self.weight, mode='fan_in')
        if self.bias is not None:
            self.bias.data.fill_(0.0)
        self.regw.data.fill_(self.n_pnt / self.out_channels)
        self.regw.requires_grad_()

    def forward(self, inputs):
        self.log_alpha.data.clamp_(max=0.0)
        if self.training:
            # Local Reparametrization Trick on Training Phase
            mu = F.conv2d(inputs, self.weight)
            std = torch.sqrt(
                F.conv2d(inputs**2,
                         self.log_alpha.exp() * self.weight**2) + 1e-8)

            # Sample the noise only once per data point
            eps = torch.randn(*mu.size())
            if inputs.is_cuda:
                eps = eps.cuda()

            output_size = eps.size()
            conv_bias = self.bias.unsqueeze(0).repeat(
                output_size[0],
                1).unsqueeze(-1).repeat(1, 1,
                                        output_size[2]).unsqueeze(-1).repeat(
                                            1, 1, 1, output_size[3])
            return torch.add((mu + std * eps), conv_bias)
        else:
            # Test Phase
            if self.mask is not None:
                return F.conv2d(inputs, self.mask * self.weight, self.bias)
            else:
                return F.conv2d(inputs, self.weight, self.bias)

    def kld(self):
        """
        Variational Dropout and the Local Reparametrization Trick
        --> Variational (A2) method
        """
        self.log_alpha.data.clamp_(max=0.0)
        alpha = self.log_alpha.exp()
        nkld = 0.5 * self.log_alpha + self.c1 * alpha + self.c2 * alpha**2 + self.c3 * alpha**3
        kld = -nkld
        return kld.mean() / 3

    def compute_expected_flops(self):
        """
        To be implemented
        """
        return

    def reg_theta_loss(self):
        """
        For subgradient method
        """
        eps = torch.randn(*self.weight.size())
        if torch.cuda.is_available():
            eps = eps.cuda()
        mu = self.weight
        std = torch.sqrt(self.log_alpha.exp() * self.weight**2 + 1e-8)
        samples = mu + std * eps
        return self.lamda * torch.sum(self.regw.data * conv_norm(samples))

    def reg_w_loss(self):
        eps = torch.randn(*self.weight.size())
        if torch.cuda.is_available():
            eps = eps.cuda()
        mu = self.weight.data
        std = torch.sqrt(self.log_alpha.data.exp() * self.weight.data**2 +
                         1e-8)
        samples = mu + std * eps
        return torch.sum(self.regw * conv_norm(samples))

    def set_weight_mask(self):
        _, in_filters, height, width = self.weight.size()
        self.mask = self.regw.data < 1.0
        if torch.cuda.is_available():
            self.mask = self.mask.type(torch.cuda.FloatTensor)
        else:
            self.mask = self.mask.type(torch.FloatTensor)
        self.mask = self.mask.unsqueeze(-1).repeat(
            1, in_filters).unsqueeze(-1).repeat(1, 1,
                                                height).unsqueeze(-1).repeat(
                                                    1, 1, 1, width)

    def set_bias_mask(self):
        self.bias_mask = self.regw.data < 1.0
        if torch.cuda.is_available():
            self.bias_mask = self.bias_mask.type(torch.cuda.FloatTensor)
        else:
            self.bias_mask = self.bias_mask.type(torch.FloatTensor)
        return

    def apply_mask(self):
        self.weight.data = self.weight.data * self.mask
        if self.bias_mask is not None:
            self.bias.data = self.bias.data * self.bias_mask
        return

    def extra_repr(self):
        return "in_channels={}, out_channels={}, bias={}".format(
            self.in_channels, self.out_channels, self.bias is not None)
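`kld` above uses the cubic polynomial approximation of the negative KL divergence from Kingma, Salimans and Welling, "Variational Dropout and the Local Reparameterization Trick" (the variational (A2) parameterisation), valid for $\alpha \le 1$, which is why `log_alpha` is clamped to be non-positive:

$$
-D_{KL}\!\left(q(w)\,\|\,p(w)\right)
\;\approx\; \tfrac{1}{2}\log\alpha + c_1\alpha + c_2\alpha^2 + c_3\alpha^3 + \text{const},
\qquad c_1 = 1.16145124,\; c_2 = -1.50204118,\; c_3 = 0.58629921;
$$

the code drops the constant and returns the resulting positive KL contribution, averaged over the weights and divided by 3.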
Example #13
class TrimDense(Module):
    """
    Dense layer for the Trimmed \ell_1 Regularization
    We treat the network parameters as a Bayesian
    """
    def __init__(self,
                 in_features,
                 out_features,
                 log_alpha='hidden',
                 lamda=0.1,
                 h=0,
                 bias=True):
        """
        Args:
            in_features:    the number of input-neurons
            out_features:   the number of output-neurons
            h:              the number of largest entries that are not penalized
            bias:           use bias or not
        """
        super(TrimDense, self).__init__()
        assert in_features >= h

        # Fully-Connected Layers
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)

        # For Bayesian Inference
        self.log_alpha = Parameter(
            torch.randn(in_features, out_features).fill_(-10.0))
        self.c1 = 1.16145124
        self.c2 = -1.50204118
        self.c3 = 0.58629921

        # Trimming Parameters
        self.lamda = lamda  # regularization parameters
        self.n_pnt = in_features - h  # the number of penalties

        if torch.cuda.is_available():
            self.regw = torch.ones(in_features).cuda()  # input-neuron sparsity
        else:
            self.regw = torch.ones(in_features)

        self.mask = None
        self.bias_mask = None
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_normal_(self.weight.data, mode='fan_out')
        if self.bias is not None:
            self.bias.data.fill_(0)
        self.regw.data.fill_(self.n_pnt / self.in_features)
        self.regw.requires_grad_()

    def forward(self, inputs):
        self.log_alpha.data.clamp_(max=0.0)
        if self.training:
            # Local Reparametrization Trick on Training Phase
            mu = torch.mm(inputs, self.weight)
            std = torch.sqrt(
                torch.mm(inputs**2,
                         self.log_alpha.exp() * self.weight**2) + 1e-8)

            # Sample the noise only once per data point
            eps = torch.randn(*mu.size())
            if inputs.is_cuda:
                eps = eps.cuda()
            return std * eps + mu + self.bias
        else:
            # Test Phase
            if self.mask is not None:
                return torch.addmm(self.bias, inputs, self.mask * self.weight)
            else:
                return torch.addmm(self.bias, inputs, self.weight)

    def kld(self):
        """
        Variational Dropout and the Local Reparametrization Trick
        --> Variational (A2) method
        """
        self.log_alpha.data.clamp_(max=0.0)
        alpha = self.log_alpha.exp()
        nkld = 0.5 * self.log_alpha + self.c1 * alpha + self.c2 * alpha**2 + self.c3 * alpha**3
        kld = -nkld
        return kld.mean() / 3

    def compute_expected_flops(self):
        """
        To be implemented
        """
        return

    def reg_theta_loss(self, batch_size=100):
        """
        For subgradient method
        """
        eps = torch.randn(*self.weight.size())
        if torch.cuda.is_available():
            eps = eps.cuda()
        mu = self.weight
        std = torch.sqrt(self.log_alpha.exp() * self.weight**2 + 1e-8)
        samples = mu + std * eps

        return self.lamda * torch.sum(self.regw.data * samples.norm(dim=1))

    def reg_w_loss(self, batch_size=100):
        eps = torch.randn(*self.weight.size())
        if torch.cuda.is_available():
            eps = eps.cuda()
        mu = self.weight.data
        std = torch.sqrt(self.log_alpha.data.exp() * self.weight.data**2 +
                         1e-8)
        samples = mu + std * eps
        return torch.sum(self.regw * samples.norm(dim=1))

    def set_weight_mask(self):
        self.mask = (self.regw.data < 1.0).unsqueeze(-1).repeat(
            1, self.out_features)
        if torch.cuda.is_available():
            self.mask = self.mask.type(torch.cuda.FloatTensor)
        else:
            self.mask = self.mask.type(torch.FloatTensor)

    def set_bias_mask(self, next_layer_regw):
        assert len(next_layer_regw) == self.out_features
        self.bias_mask = next_layer_regw.data < 1.0
        if torch.cuda.is_available():
            self.bias_mask = self.bias_mask.type(torch.cuda.FloatTensor)
        else:
            self.bias_mask = self.bias_mask.type(torch.FloatTensor)
        return

    def apply_mask(self):
        self.weight.data = self.weight.data * self.mask
        if self.bias_mask is not None:
            self.bias.data = self.bias.data * self.bias_mask
        return

    def extra_repr(self):
        return "in_features={}, out_features={}, bias={}".format(
            self.in_features, self.out_features, self.bias is not None)
class LinearGroupNJ(BayesianLayers):
    """Fully Connected Group Normal-Jeffrey's layer (aka Group Variational Dropout).

    References:
    [1] Kingma, Diederik P., Tim Salimans, and Max Welling. "Variational dropout and the local reparameterization trick." NIPS (2015).
    [2] Molchanov, Dmitry, Arsenii Ashukha, and Dmitry Vetrov. "Variational Dropout Sparsifies Deep Neural Networks." ICML (2017).
    [3] Louizos, Christos, Karen Ullrich, and Max Welling. "Bayesian Compression for Deep Learning." NIPS (2017).
    """

    def __init__(self, in_features, out_features, cuda=False, init_weight=None, init_bias=None, clip_var=None):

        super(LinearGroupNJ, self).__init__()
        self.cuda = cuda
        self.in_features = in_features
        self.out_features = out_features
        self.clip_var = clip_var
        self.deterministic = False  # flag is used for compressed inference
        # trainable params according to Eq.(6)
        # dropout params
        self.z_mu = Parameter(torch.Tensor(in_features))
        self.z_logvar = Parameter(torch.Tensor(in_features))  # = z_mu^2 * alpha
        # weight params
        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_logvar = Parameter(torch.Tensor(out_features, in_features))

        self.bias_mu = Parameter(torch.Tensor(out_features))
        self.bias_logvar = Parameter(torch.Tensor(out_features))

        # init params either random or with pretrained net
        self.reset_parameters(init_weight, init_bias)

        # activations for kl
        self.sigmoid = nn.Sigmoid()
        self.softplus = nn.Softplus()

        # numerical stability param
        self.epsilon = 1e-8

    def reset_parameters(self, init_weight, init_bias):
        # init means
        stdv = 1. / math.sqrt(self.weight_mu.size(1))

        self.z_mu.data.normal_(1, 1e-2)

        if init_weight is not None:
            self.weight_mu.data = torch.Tensor(init_weight)
        else:
            self.weight_mu.data.normal_(0, stdv)

        if init_bias is not None:
            self.bias_mu.data = torch.Tensor(init_bias)
        else:
            self.bias_mu.data.fill_(0)

        # init logvars
        self.z_logvar.data.normal_(-9, 1e-2)
        self.weight_logvar.data.normal_(-9, 1e-2)
        self.bias_logvar.data.normal_(-9, 1e-2)

    def clip_variances(self):
        if self.clip_var:
            self.weight_logvar.data.clamp_(max=math.log(self.clip_var))
            self.bias_logvar.data.clamp_(max=math.log(self.clip_var))

    def get_log_dropout_rates(self):
        log_alpha = self.z_logvar - torch.log(self.z_mu.pow(2) + self.epsilon)
        return log_alpha

    def compute_posterior_params(self):
        weight_var, z_var = self.weight_logvar.exp(), self.z_logvar.exp()
        self.post_weight_var = self.z_mu.pow(2) * weight_var + z_var * self.weight_mu.pow(2) + z_var * weight_var
        self.post_weight_mu = self.weight_mu * self.z_mu
        # print("self.z_mu.pow(2): ", self.z_mu.pow(2).size())
        # print("weight_var: ", weight_var.size())
        # print("z_var: ", z_var.size())
        # print("self.weight_mu.pow(2): ", self.weight_mu.pow(2).size())
        # print("weight_var: ", weight_var.size())
        # print("post_weight_mu: ", self.post_weight_mu.size())
        # print("post_weight_var: ", self.post_weight_var.size())
        return self.post_weight_mu, self.post_weight_var

    def forward(self, x):
        if self.deterministic:
            assert self.training == False, "Flag deterministic is True. This should not be used in training."
            return F.linear(x, self.post_weight_mu, self.bias_mu)

        batch_size = x.size()[0]
        # compute z
        # note that we reparametrise according to [2] Eq. (11) (not [1])
        z = reparametrize(self.z_mu.repeat(batch_size, 1), self.z_logvar.repeat(batch_size, 1), sampling=self.training,
                          cuda=self.cuda)

        # apply local reparametrisation trick see [1] Eq. (6)
        # to the parametrisation given in [3] Eq. (6)
        xz = x * z
        mu_activations = F.linear(xz, self.weight_mu, self.bias_mu)
        var_activations = F.linear(xz.pow(2), self.weight_logvar.exp(), self.bias_logvar.exp())

        return reparametrize(mu_activations, var_activations.log(), sampling=self.training, cuda=self.cuda)

    def kl_divergence(self):
        # KL(q(z)||p(z))
        # we use the kl divergence approximation given by [2] Eq.(14)
        k1, k2, k3 = 0.63576, 1.87320, 1.48695
        log_alpha = self.get_log_dropout_rates()
        KLD = -torch.sum(k1 * self.sigmoid(k2 + k3 * log_alpha) - 0.5 * self.softplus(-log_alpha) - k1)

        # KL(q(w|z)||p(w|z))
        # we use the kl divergence given by [3] Eq.(8)
        KLD_element = -0.5 * self.weight_logvar + 0.5 * (self.weight_logvar.exp() + self.weight_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        # KL bias
        KLD_element = -0.5 * self.bias_logvar + 0.5 * (self.bias_logvar.exp() + self.bias_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        return KLD

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
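A hedged sketch of how group-NJ layers like this are typically combined into a training objective: each layer's `kl_divergence()` is summed and added to the data term, scaled by the dataset size. The model, the loss, and the 1/N scaling are assumptions for illustration; the layer itself also depends on a `BayesianLayers` base class and a `reparametrize` helper that are not shown in this excerpt.

import torch.nn.functional as F

def elbo_loss(model, output, target, dataset_size):
    # Data term plus the summed KL of every layer that exposes kl_divergence(),
    # scaled by 1/N as in a standard stochastic-ELBO objective.
    kl = sum(layer.kl_divergence()
             for layer in model.modules()
             if hasattr(layer, "kl_divergence"))
    return F.cross_entropy(output, target) + kl / dataset_size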
Example #15
class VaDE(torch.nn.Module):
    """Variational Deep Embedding(VaDE).

    Args:
        n_classes (int): Number of clusters.
        data_dim (int): Dimension of observed data.
        latent_dim (int): Dimension of latent space.
    """
    def __init__(self, n_classes, data_dim, latent_dim):
        super(VaDE, self).__init__()

        self._pi = Parameter(torch.zeros(n_classes))
        self.mu = Parameter(torch.randn(n_classes, latent_dim))
        self.logvar = Parameter(torch.randn(n_classes, latent_dim))

        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(data_dim, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 2048),
            torch.nn.ReLU(),
        )
        self.encoder_mu = torch.nn.Linear(2048, latent_dim)
        self.encoder_logvar = torch.nn.Linear(2048, latent_dim)

        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(latent_dim, 2048),
            torch.nn.ReLU(),
            torch.nn.Linear(2048, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, data_dim),
            torch.nn.Sigmoid(),
        )

    @property
    def weights(self):
        return torch.softmax(self._pi, dim=0)

    def encode(self, x):
        h = self.encoder(x)
        mu = self.encoder_mu(h)
        logvar = self.encoder_logvar(h)
        return mu, logvar

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = _reparameterize(mu, logvar)
        recon_x = self.decode(z)
        return recon_x, mu, logvar

    def classify(self, x, n_samples=8):
        with torch.no_grad():
            mu, logvar = self.encode(x)
            z = torch.stack(
                [_reparameterize(mu, logvar) for _ in range(n_samples)], dim=1)
            z = z.unsqueeze(2)
            h = z - self.mu
            h = torch.exp(-0.5 * torch.sum(h * h / self.logvar.exp(), dim=3))
            # Same as `torch.sqrt(torch.prod(self.logvar.exp(), dim=1))`
            h = h / torch.sum(0.5 * self.logvar, dim=1).exp()
            # constant factor only; any missing (2*pi)^(d/2) normalisation is the
            # same for every cluster and cancels in the normalisation below
            p_z_given_c = h / (2 * math.pi)
            p_z_c = p_z_given_c * self.weights
            y = p_z_c / torch.sum(p_z_c, dim=2, keepdim=True)
            y = torch.sum(y, dim=1)
            pred = torch.argmax(y, dim=1)
        return pred
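
# Usage sketch for VaDE (an addition, not from the original example). The
# module-level helper `_reparameterize` used in VaDE.forward is assumed to be
# the standard Gaussian reparameterization and is restated here so the snippet
# is self-contained; the dimensions below are illustrative.
import torch

def _reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)            # logvar -> standard deviation
    return mu + std * torch.randn_like(std)  # z = mu + sigma * eps

vade = VaDE(n_classes=10, data_dim=784, latent_dim=16)
recon, mu, logvar = vade(torch.rand(32, 784))  # terms for the reconstruction + KL loss
labels = vade.classify(torch.rand(32, 784))    # cluster index per sample, shape [32]
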
class _ConvNdGroupNJ(BayesianLayers):
    """Convolutional Group Normal-Jeffrey's layers (aka Group Variational Dropout).

    References:
    [1] Kingma, Diederik P., Tim Salimans, and Max Welling. "Variational dropout and the local reparameterization trick." NIPS (2015).
    [2] Molchanov, Dmitry, Arsenii Ashukha, and Dmitry Vetrov. "Variational Dropout Sparsifies Deep Neural Networks." ICML (2017).
    [3] Louizos, Christos, Karen Ullrich, and Max Welling. "Bayesian Compression for Deep Learning." NIPS (2017).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding,
                 groups, bias, init_weight, init_bias, cuda=False, clip_var=None):
        super(_ConvNdGroupNJ, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups

        self.cuda = cuda
        self.clip_var = clip_var
        self.deterministic = False  # flag is used for compressed inference

        if transposed:
            self.weight_mu = Parameter(torch.Tensor(
                in_channels, out_channels // groups, *kernel_size))
            self.weight_logvar = Parameter(torch.Tensor(
                in_channels, out_channels // groups, *kernel_size))
        else:
            self.weight_mu = Parameter(torch.Tensor(
                out_channels, in_channels // groups, *kernel_size))
            self.weight_logvar = Parameter(torch.Tensor(
                out_channels, in_channels // groups, *kernel_size))

        self.bias_mu = Parameter(torch.Tensor(out_channels))
        self.bias_logvar = Parameter(torch.Tensor(out_channels))

        self.z_mu = Parameter(torch.Tensor(self.out_channels))
        self.z_logvar = Parameter(torch.Tensor(self.out_channels))

        self.reset_parameters(init_weight, init_bias)

        # activations for kl
        self.sigmoid = nn.Sigmoid()
        self.softplus = nn.Softplus()
        # numerical stability param
        self.epsilon = 1e-8

    def reset_parameters(self, init_weight, init_bias):
        # init means
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)

        # init means
        if init_weight is not None:
            self.weight_mu.data = init_weight
        else:
            self.weight_mu.data.uniform_(-stdv, stdv)

        if init_bias is not None:
            self.bias_mu.data = init_bias
        else:
            self.bias_mu.data.fill_(0)

        # init z
        self.z_mu.data.normal_(1, 1e-2)

        # init logvars
        self.z_logvar.data.normal_(-9, 1e-2)
        self.weight_logvar.data.normal_(-9, 1e-2)
        self.bias_logvar.data.normal_(-9, 1e-2)

    def clip_variances(self):
        if self.clip_var:
            self.weight_logvar.data.clamp_(max=math.log(self.clip_var))
            self.bias_logvar.data.clamp_(max=math.log(self.clip_var))

    def get_log_dropout_rates(self):
        log_alpha = self.z_logvar - torch.log(self.z_mu.pow(2) + self.epsilon)
        return log_alpha

    def compute_posterior_params(self):
        weight_var, z_var = self.weight_logvar.exp(), self.z_logvar.exp()
        # z statistics are per output channel; reshape them to broadcast over the
        # (non-transposed) weight layout [out_channels, in_channels // groups, *kernel_size]
        shape = (-1,) + (1,) * (self.weight_mu.dim() - 1)
        z_mu, z_var = self.z_mu.view(shape), z_var.view(shape)
        self.post_weight_var = z_mu.pow(2) * weight_var + z_var * self.weight_mu.pow(2) + z_var * weight_var
        self.post_weight_mu = self.weight_mu * z_mu
        return self.post_weight_mu, self.post_weight_var

    def kl_divergence(self):
        # KL(q(z)||p(z))
        # we use the kl divergence approximation given by [2] Eq.(14)
        k1, k2, k3 = 0.63576, 1.87320, 1.48695
        log_alpha = self.get_log_dropout_rates()
        KLD = -torch.sum(k1 * self.sigmoid(k2 + k3 * log_alpha) - 0.5 * self.softplus(-log_alpha) - k1)

        # KL(q(w|z)||p(w|z))
        # we use the kl divergence given by [3] Eq.(8)
        KLD_element = -0.5 * self.weight_logvar + 0.5 * (self.weight_logvar.exp() + self.weight_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        # KL bias
        KLD_element = -0.5 * self.bias_logvar + 0.5 * (self.bias_logvar.exp() + self.bias_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        return KLD

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
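
# Compression sketch for the group layers above (an assumption, following [3]):
# after training, output channels whose log dropout rate exceeds a threshold can
# be pruned. Shown on stand-in tensors since _ConvNdGroupNJ is an abstract base.
import torch

z_mu = torch.randn(64)                                   # stand-ins for a trained layer's z_mu
z_logvar = torch.randn(64) - 9                           # and z_logvar
log_alpha = z_logvar - torch.log(z_mu.pow(2) + 1e-8)     # as in get_log_dropout_rates()
keep = log_alpha < 3.0                                   # threshold of 3, as in [2]
n_kept = int(keep.sum().item())                          # surviving output channels
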
class HSDense(Module):
    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 prior_std=1.,
                 prior_std_z=1.,
                 dof=1.,
                 **kwargs):
        super(HSDense, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.prior_std = prior_std
        self.mean_w = Parameter(torch.Tensor(in_features, out_features))
        self.logvar_w = Parameter(torch.Tensor(in_features, out_features))
        self.qz_mean = Parameter(torch.Tensor(in_features))
        self.qz_logvar = Parameter(torch.Tensor(in_features))
        self.dof = dof
        self.prior_std_z = prior_std_z
        self.use_bias = False
        if bias:
            self.mean_bias = Parameter(torch.Tensor(out_features))
            self.logvar_bias = Parameter(torch.Tensor(out_features))
            self.use_bias = True
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.mean_w, mode='fan_out')
        self.logvar_w.data.normal_(-9., 1e-4)

        self.qz_mean.data.normal_(math.log(math.exp(1) - 1), 1e-3)
        self.qz_logvar.data.normal_(math.log(0.1), 1e-4)
        if self.use_bias:
            self.mean_bias.data.normal_(0, 1e-2)
            self.logvar_bias.data.normal_(-9., 1e-4)

    def constrain_parameters(self, thres_std=1.):
        self.logvar_w.data.clamp_(max=2. * math.log(thres_std))
        if self.use_bias:
            self.logvar_bias.data.clamp_(max=2. * math.log(thres_std))

    def eq_logpw(self):
        logpw = -.5 * math.log(
            2 * math.pi * self.prior_std**2) - .5 * self.logvar_w.exp().div(
                self.prior_std**2)
        logpw -= .5 * self.mean_w.pow(2).div(self.prior_std**2)
        logpb = 0.
        if self.use_bias:
            logpb = (-.5 * math.log(2 * math.pi * self.prior_std**2)
                     - .5 * self.logvar_bias.exp().div(self.prior_std**2))
            logpb -= .5 * self.mean_bias.pow(2).div(self.prior_std**2)
        return torch.sum(logpw) + torch.sum(logpb)

    def eq_logqw(self):
        logqw = -torch.sum(.5 * (math.log(2 * math.pi) + self.logvar_w + 1))
        logqb = 0.
        if self.use_bias:
            logqb = -torch.sum(.5 *
                               (math.log(2 * math.pi) + self.logvar_bias + 1))
        return logqw + logqb

    def kldiv_aux(self):
        z = self.sample_z(1)
        z = z.view(self.in_features)

        logqm = -torch.sum(.5 * (math.log(2 * math.pi) + self.qz_logvar + 1))
        logqm = logqm.add(-torch.sum(F.sigmoid(z.exp().add(-1).log()).log()))

        logpm = torch.sum(
            2 * math.lgamma(.5 * (self.dof + 1)) - math.lgamma(.5 * self.dof) -
            math.log(self.prior_std_z) - .5 * math.log(self.dof * math.pi) -
            .5 * (self.dof + 1) * torch.log(1. + z.pow(2) /
                                            (self.dof * self.prior_std_z**2)))

        return logpm - logqm

    def kldiv(self):
        return self.kldiv_aux() + self.eq_logpw() - self.eq_logqw()

    def get_eps(self, size):
        eps = self.floatTensor(size).normal_()
        eps = Variable(eps)
        return eps

    def sample_z(self, batch_size):
        z = self.qz_mean.view(1, self.in_features)
        if self.training:
            eps = self.get_eps(self.floatTensor(batch_size, self.in_features))
            z = z + eps.mul(
                self.qz_logvar.view(1, self.in_features).mul(0.5).exp_())
        return F.softplus(z)

    def sample_W(self):
        W = self.mean_w
        if self.training:
            eps = self.get_eps(self.mean_w.size())
            W = W.add(eps.mul(self.logvar_w.mul(0.5).exp_()))
        return W

    def sample_b(self):
        b = self.mean_bias
        if self.training:
            eps = self.get_eps(self.mean_bias.size())
            b = b.add(eps.mul(self.logvar_bias.mul(0.5).exp_()))
        return b

    def get_mean_x(self, input):
        mean_xin = input.mm(self.mean_w)
        if self.use_bias:
            mean_xin = mean_xin.add(self.mean_bias.view(1, self.out_features))

        return mean_xin

    def get_var_x(self, input):
        var_xin = input.pow(2).mm(self.logvar_w.exp())
        if self.use_bias:
            var_xin = var_xin.add(self.logvar_bias.exp().view(
                1, self.out_features))

        return var_xin

    def forward(self, input):
        # sampling
        batch_size = input.size(0)
        z = self.sample_z(batch_size)
        xin = input.mul(z)
        mean_xin = self.get_mean_x(xin)
        output = mean_xin
        if self.training:
            var_xin = self.get_var_x(xin)
            eps = self.get_eps(self.floatTensor(batch_size, self.out_features))
            output = output.add(var_xin.sqrt().mul(eps))
        return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ', dof: ' \
               + str(self.dof) + ', prior_std_z: ' \
               + str(self.prior_std_z) + ')'
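
# Usage sketch for HSDense (an addition; the dimensions are illustrative). Note
# the sign convention: kldiv() returns E_q[log p] - E_q[log q] plus the scale
# term, i.e. minus the KL, so it is *added* to the ELBO (equivalently, the loss
# is the data term minus kldiv() / dataset_size).
import torch

hs = HSDense(784, 300)
h = hs(torch.rand(16, 784))   # stochastic pre-activations in training mode
reg = hs.kldiv()              # subtract reg / N from the per-batch task loss
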
class FFGaussDense(Module):
    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 prior_std=1.,
                 **kwargs):
        super(FFGaussDense, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.prior_std = prior_std
        self.mean_w = Parameter(torch.Tensor(in_features, out_features))
        self.logvar_w = Parameter(torch.Tensor(in_features, out_features))
        self.use_bias = False
        if bias:
            self.mean_bias = Parameter(torch.Tensor(out_features))
            self.logvar_bias = Parameter(torch.Tensor(out_features))
            self.use_bias = True
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.mean_w, mode='fan_out')
        self.logvar_w.data.normal_(-9., 1e-4)

        if self.use_bias:
            self.mean_bias.data.zero_()
            self.logvar_bias.data.normal_(-9., 1e-4)

    def constrain_parameters(self, thres_std=1.):
        self.logvar_w.data.clamp_(max=2. * math.log(thres_std))
        if self.use_bias:
            self.logvar_bias.data.clamp_(max=2. * math.log(thres_std))

    def eq_logpw(self):
        logpw = -.5 * math.log(
            2 * math.pi * self.prior_std**2) - .5 * self.logvar_w.exp().div(
                self.prior_std**2)
        logpw -= .5 * self.mean_w.pow(2).div(self.prior_std**2)
        logpb = 0.
        if self.use_bias:
            logpb = -.5 * math.log(2 * math.pi * self.prior_std**
                                   2) - .5 * self.logvar_bias.exp().div(
                                       self.prior_std**2)
            logpb -= .5 * self.mean_bias.pow(2).div(self.prior_std**2)
        return torch.sum(logpw) + torch.sum(logpb)

    def eq_logqw(self):
        logqw = -torch.sum(.5 * (math.log(2 * math.pi) + self.logvar_w + 1))
        logqb = 0.
        if self.use_bias:
            logqb = -torch.sum(.5 *
                               (math.log(2 * math.pi) + self.logvar_bias + 1))
        return logqw + logqb

    def kldiv_aux(self):
        return 0.

    def kldiv(self):
        return self.kldiv_aux() + self.eq_logpw() - self.eq_logqw()

    def get_eps(self, size):
        eps = self.floatTensor(size).normal_()
        eps = Variable(eps)
        return eps

    def sample_pW(self):
        return self.floatTensor(self.in_features, self.out_features).normal_()

    def sample_pb(self):
        return self.floatTensor(self.out_features).normal_()

    def sample_W(self):
        w = self.mean_w
        if self.training:
            eps = self.get_eps(self.mean_w.size())
            w = w.add(eps.mul(self.logvar_w.mul(0.5).exp_()))
        return w

    def sample_b(self):
        b = self.mean_bias
        if self.training:
            eps = self.get_eps(self.mean_bias.size())
            b = b.add(eps.mul(self.logvar_bias.mul(0.5).exp_()))
        return b

    def get_mean_x(self, input):
        mean_xin = input.mm(self.mean_w)
        if self.use_bias:
            mean_xin = mean_xin.add(self.mean_bias.view(1, self.out_features))
        return mean_xin

    def get_var_x(self, input):
        var_xin = input.pow(2).mm(self.logvar_w.exp())
        if self.use_bias:
            var_xin = var_xin.add(self.logvar_bias.exp().view(
                1, self.out_features))

        return var_xin

    def forward(self, input):
        batch_size = input.size(0)
        mean_xin = self.get_mean_x(input)
        if self.training:
            var_xin = self.get_var_x(input)
            eps = self.get_eps(self.floatTensor(batch_size, self.out_features))
            output = mean_xin.add(var_xin.sqrt().mul(eps))
        else:
            output = mean_xin
        return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ', prior_std: ' \
            + str(self.prior_std) + ')'
Example #19
class VDropLinear2(nn.Module):
    """
    A self-contained VDropLinear (doesn't use the VDropCentralData)
    """
    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 w_logvar_init=-10):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.w_logvar_min = min(w_logvar_init, -10)
        self.w_logvar_max = 10.
        self.pruned_logvar_sentinel = self.w_logvar_max - 0.00058
        self.epsilon = 1e-8

        self.w_mu = Parameter(torch.Tensor(self.out_features,
                                           self.in_features))
        self.w_logvar = Parameter(
            torch.Tensor(self.out_features, self.in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.bias = None

        self.w_logvar.data.fill_(w_logvar_init)
        # Standard nn.Linear initialization.
        init.kaiming_uniform_(self.w_mu, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.w_mu)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

        self.tensor_constructor = (torch.FloatTensor
                                   if not torch.cuda.is_available() else
                                   torch.cuda.FloatTensor)

    def extra_repr(self):
        s = f"{self.in_features}, {self.out_features}"
        if self.bias is None:
            s += ", bias=False"
        return s

    def get_w_mu(self):
        return self.w_mu

    def get_w_var(self):
        return self.w_logvar.exp()

    def forward(self, x):
        if self.training:
            return vdrop_linear_forward(x, self.get_w_mu, self.get_w_var,
                                        self.bias, self.tensor_constructor)
        else:
            return F.linear(x, self.get_w_mu(), self.bias)

    def compute_w_logalpha(self):
        return self.w_logvar - (self.w_mu.square() + self.epsilon).log()

    def regularization(self):
        return vdrop_regularization(self.compute_w_logalpha()).sum()

    def constrain_parameters(self):
        self.w_logvar.data.clamp_(min=self.w_logvar_min, max=self.w_logvar_max)
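
# `vdrop_linear_forward` above is an external helper not shown in this example.
# A plausible definition, stated as an assumption only to make the training path
# concrete, is the usual local reparameterization: sample the pre-activations
# rather than the weights.
import torch.nn.functional as F

def _vdrop_linear_forward_sketch(x, get_w_mu, get_w_var, bias, tensor_constructor,
                                 epsilon=1e-8):
    y_mu = F.linear(x, get_w_mu(), bias)
    y_var = F.linear(x ** 2, get_w_var()).clamp(min=epsilon)  # avoid sqrt(0) in backprop
    noise = tensor_constructor(y_mu.size()).normal_()
    return y_mu + noise * y_var.sqrt()
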
Example #20
class VariationalDropoutCNN(nn.Module):
    def __init__(self,
                 in_channel,
                 out_channel,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 log_sigma2=-8,
                 threshold=3):
        """
        :param input_channel: An int of input channel
        :param log_sigma2: Initial value of log sigma ^ 2.
               It is crusial for training since it determines initial value of alpha
        :param threshold: Value for thresholding of validation. If log_alpha > threshold, then weight is zeroed
        :param out_channel: An int of output channel
        """
        super(VariationalDropoutCNN, self).__init__()
        #         self.m = img_row
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups

        self.theta = Parameter(
            t.Tensor(out_channel, in_channel // groups, kernel_size,
                     kernel_size))
        self.prior_theta = 0.
        self.prior_log_sigma2 = -2.
        #         self.bias = Parameter(t.Tensor(out_channel, in_channel // groups, kernel_size, kernel_size))
        #         self.bias = Parameter(t.Tensor(out_channel, self.m-kernel_size+1, self.m-kernel_size+1))
        self.sz = out_channel * (in_channel // groups) * kernel_size**2
        self.log_sigma2 = Parameter(
            t.FloatTensor(out_channel, in_channel // groups, kernel_size,
                          kernel_size).fill_(log_sigma2))
        # `scale`, `gaussian_window`, `kllu` and `_kl_loss` are referenced here but
        # defined elsewhere in the source file (not shown in this example)
        self.s = Parameter(t.Tensor([scale]))
        self.code = t.Tensor([0.2, 0, -0.2])

        self.reset_parameters()

        self.k = [0.63576, 1.87320, 1.48695]

        self.threshold = threshold

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.out_channel)

        self.theta.data.uniform_(-stdv, stdv)
#         self.bias.data.uniform_(-stdv, stdv)

    @staticmethod
    def clip_logsig(input):
        input = input.masked_fill(input < -10, -10)
        input = input.masked_fill(input > 1, 1)

        return input

    def clip(self):
        self.log_sigma2.data.masked_fill_(self.log_sigma2.data < -10, -10)
        self.log_sigma2.data.masked_fill_(self.log_sigma2.data > 1, 1)
        self.theta.data = t.where(
            self.theta < (-0.2 - 0.3679 * t.sqrt(self.log_sigma2.exp())),
            (-0.2 - 0.3679 * t.sqrt(self.log_sigma2.exp())), self.theta)
        self.theta.data = t.where(
            self.theta > (0.2 + 0.3679 * t.sqrt(self.log_sigma2.exp())),
            (0.2 + 0.3679 * t.sqrt(self.log_sigma2.exp())), self.theta)
#         self.theta.masked_fill(self.theta < (-0.2-0.3679*t.sqrt(self.log_sigma2.exp())), (-0.2-0.3679*t.sqrt(self.log_sigma2.exp())))
#         self.theta.masked_fill(self.theta > (0.2+0.3679*t.sqrt(self.log_sigma2.exp())), (0.2+0.3679*t.sqrt(self.log_sigma2.exp())))

    def kld(self, idx):

        window1 = gaussian_window(self.theta * self.s, 0.2)
        window2 = gaussian_window(self.theta * self.s, -0.2)

        log_alpha1 = self.log_sigma2 + 2 * t.log(self.s) - t.log(
            (self.theta * self.s - 0.2)**2)
        log_alpha2 = self.log_sigma2 + 2 * t.log(self.s) - t.log(
            (self.theta * self.s)**2)
        log_alpha3 = self.log_sigma2 + 2 * t.log(self.s) - t.log(
            (self.theta * self.s + 0.2)**2)

        F_KLLU1 = kllu(log_alpha1)
        F_KLLU2 = kllu(log_alpha2)
        F_KLLU3 = kllu(log_alpha3)
        F_KL = F_KLLU1 * window1 + F_KLLU3 * window2 + F_KLLU2 * (1 - window1 -
                                                                  window2)
        return F_KL.sum() / (self.sz)

    def forward(self, input, train, noquan):
        """
        :param input: An float tensor with shape of [batch_size, input_size]
        :return: An float tensor with shape of [batch_size, out_size] and negative layer-kld estimation
        """
        self.clip()
        c1 = (self.theta * self.s - 0.2)**2
        c2 = (self.theta * self.s)**2
        c3 = (self.theta * self.s + 0.2)**2
        mean = t.min(t.min(c1, c2), c3)
        c = t.stack((c1, c2, c3), 0)
        idx = t.argmin(c, 0)

        if not train and not noquan:
            """
            mask = log_alpha > self.threshold
            return F.conv2d( input, weight = self.theta.masked_fill(mask, 0), stride=self.stride, 
                          padding=self.padding,dilation=self.dilation, groups=self.groups)
            """

            theta_q = self.theta.data.clone()

            theta_q[:] = self.code[idx].cuda() / self.s
            mu = F.conv2d(input,
                          weight=theta_q,
                          stride=self.stride,
                          padding=self.padding,
                          dilation=self.dilation,
                          groups=self.groups)

            kld = t.sum((theta_q - self.theta)**2)
            return mu, kld  #+self.bias , kld
        if noquan:
            kld = 0
            theta_q = self.theta.data.clone()
            mu = F.conv2d(input,
                          weight=theta_q,
                          stride=self.stride,
                          padding=self.padding,
                          dilation=self.dilation,
                          groups=self.groups)

            return mu, kld  #+self.bias , kld
        kld = _kl_loss(self.theta, self.log_sigma2, self.prior_theta,
                       self.prior_log_sigma2) / self.sz
        mu = F.conv2d(input,
                      weight=self.theta * self.s,
                      stride=self.stride,
                      padding=self.padding,
                      dilation=self.dilation,
                      groups=self.groups)
        std = t.sqrt(
            F.conv2d(input**2,
                     weight=self.log_sigma2.exp() * self.s**2,
                     stride=self.stride,
                     padding=self.padding,
                     dilation=self.dilation,
                     groups=self.groups) + 1e-6)

        eps = Variable(t.randn(*mu.size()))
        if input.is_cuda:
            eps = eps.cuda()
        return std * eps + mu, kld  # + self.bias , kld

    def max_alpha(self):
        log_alpha = self.log_sigma2 - (self.theta - 0.2)**2
        return t.max(log_alpha.exp())
Example #21
class MaskedVDropConv2d(nn.Module):
    """
    A self-contained masked Conv2d (doesn't use the VDropCentralData)
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 mask=None,
                 w_logvar_init=-10):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = pair(kernel_size)
        self.stride = pair(stride)
        self.padding = pair(padding)
        self.dilation = pair(dilation)
        self.groups = groups

        self.w_logvar_min = min(w_logvar_init, -10)
        self.w_logvar_max = 10.
        self.pruned_logvar_sentinel = self.w_logvar_max - 0.00058
        self.epsilon = 1e-8

        self.w_mu = Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))
        self.w_logvar = Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.bias = None

        self.w_logvar.data.fill_(w_logvar_init)

        self.register_buffer(
            "w_mask",
            torch.HalfTensor(out_channels, in_channels // groups,
                             *self.kernel_size))

        # Standard nn.Conv2d initialization.
        init.kaiming_uniform_(self.w_mu, a=math.sqrt(5))

        if mask is not None:
            self.w_mask[:] = mask
            self.w_mu.data *= self.w_mask
            self.w_logvar.data[self.w_mask ==
                               0.0] = self.pruned_logvar_sentinel
        else:
            self.w_mask.fill_(1.0)

        # Standard nn.Conv2d initialization.
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.w_mu)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

        self.tensor_constructor = (torch.FloatTensor
                                   if not torch.cuda.is_available() else
                                   torch.cuda.FloatTensor)

    def extra_repr(self):
        s = (f"{self.in_channels}, {self.out_channels}, "
             f"kernel_size={self.kernel_size}, stride={self.stride}")
        if self.padding != (0, ) * len(self.padding):
            s += f", padding={self.padding}"
        if self.dilation != (1, ) * len(self.dilation):
            s += f", dilation={self.dilation}"
        if self.groups != 1:
            s += f", groups={self.groups}"
        if self.bias is None:
            s += ", bias=False"
        return s

    def get_w_mu(self):
        return self.w_mu * self.w_mask

    def get_w_var(self):
        return self.w_logvar.exp() * self.w_mask

    def forward(self, x):
        if self.training:
            return vdrop_conv_forward(x, self.get_w_mu, self.get_w_var,
                                      self.bias, self.stride, self.padding,
                                      self.dilation, self.groups,
                                      self.tensor_constructor)
        else:
            return F.conv2d(x, self.get_w_mu(), self.bias, self.stride,
                            self.padding, self.dilation, self.groups)

    def compute_w_logalpha(self):
        return self.w_logvar - (self.w_mu.square() + self.epsilon).log()

    def regularization(self):
        return (vdrop_regularization(self.compute_w_logalpha()) *
                self.w_mask).sum()

    def constrain_parameters(self):
        self.w_logvar.data.clamp_(min=self.w_logvar_min, max=self.w_logvar_max)
Example #22
class VDropConv2d(Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = pair(kernel_size)
        self.stride = pair(stride)
        self.padding = pair(padding)
        self.dilation = pair(dilation)
        self.groups = groups

        self.w_mu = Parameter(torch.Tensor(out_channels, in_channels // groups,
                                           *self.kernel_size))
        init.kaiming_normal_(self.w_mu, mode="fan_out")

        self.w_logsigma2 = Parameter(torch.Tensor(out_channels, in_channels // groups,
                                                  *self.kernel_size))
        self.w_logsigma2.data.fill_(-10)

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
            self.bias.data.fill_(0)
        else:
            self.bias = None

        self.input_shape = None

        self.threshold = 3
        self.epsilon = 1e-8
        self.tensor = (torch.FloatTensor if not torch.cuda.is_available()
                       else torch.cuda.FloatTensor)

    def compute_mask(self):
        w_logalpha = self.w_logsigma2 - (self.w_mu ** 2 + self.epsilon).log()
        return (w_logalpha < self.threshold).float()

    def forward(self, x):
        if self.input_shape is None:
            self.input_shape = x.size()

        if self.training:
            y_mu = F.conv2d(x, self.w_mu, self.bias, self.stride, self.padding,
                            self.dilation, self.groups)

            # Avoid sqrt(0), otherwise a divide-by-zero occurs during backprop.
            y_sigma = F.conv2d(
                x ** 2, self.w_logsigma2.exp(), None, self.stride, self.padding,
                self.dilation, self.groups
            ).clamp(self.epsilon).sqrt()

            rv = self.tensor(y_mu.size()).normal_()
            return y_mu + (rv * y_sigma)
        else:
            return F.conv2d(x, self.w_mu * self.compute_mask(), self.bias,
                            self.stride, self.padding, self.dilation,
                            self.groups)

    def regularization(self):
        k1, k2, k3 = 0.63576, 1.8732, 1.48695
        w_logalpha = self.w_logsigma2 - (self.w_mu ** 2 + self.epsilon).log()

        return -(k1 * torch.sigmoid(k2 + k3 * w_logalpha)
                 - 0.5 * F.softplus(-w_logalpha) - k1).sum()

    def get_inference_nonzeros(self):
        mask = self.compute_mask().int()
        return mask.sum(dim=tuple(range(1, len(mask.shape))))

    def count_inference_flops(self):
        # For each unit, multiply with its n inputs then do n - 1 additions.
        # To capture the -1, subtract it, but only in cases where there is at
        # least one weight.
        nz_by_unit = self.get_inference_nonzeros()
        multiplies_per_instance = torch.sum(nz_by_unit)
        adds_per_instance = multiplies_per_instance - torch.sum(nz_by_unit > 0)

        # for rows
        instances = (
            (self.input_shape[-2] - self.kernel_size[0]
             + 2 * self.padding[0]) / self.stride[0]) + 1
        # multiplying with cols
        instances *= (
            (self.input_shape[-1] - self.kernel_size[1] + 2 * self.padding[1])
            / self.stride[1]) + 1

        multiplies = multiplies_per_instance * instances
        adds = adds_per_instance * instances

        return multiplies.item(), adds.item()

    def weight_size(self):
        return self.w_mu.size()
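
# Usage sketch for the FLOP accounting above (an addition). The `pair` helper
# used by the class is assumed to behave like torch.nn.modules.utils._pair and
# is imported under that name so the snippet runs on its own.
import torch
from torch.nn.modules.utils import _pair as pair

conv = VDropConv2d(3, 16, kernel_size=3, padding=1)
_ = conv(torch.rand(1, 3, 32, 32))            # records input_shape
conv.eval()
mults, adds = conv.count_inference_flops()    # counts based on surviving weights per unit
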
Example #23
class discrete_vision_actor_critic_Net(nn.Module):
    def __init__(self,
                 s_dim,
                 n_actions,
                 latent_dim,
                 n_heads=8,
                 init_log_alpha=0.0,
                 parallel=True,
                 lr=1e-4,
                 lr_alpha=1e-4,
                 lr_actor=1e-4):
        super().__init__()

        self.s_dim = s_dim
        self.n_actions = n_actions
        self._parallel = parallel

        self.q = vision_multihead_dueling_q_Net(s_dim, latent_dim, n_actions,
                                                n_heads, lr)
        self.q_target = vision_multihead_dueling_q_Net(s_dim, latent_dim,
                                                       n_actions, n_heads, lr)
        self.update(rate=1.0)

        self.actor = vision_softmax_policy_Net(s_dim,
                                               latent_dim,
                                               n_actions,
                                               noisy=False,
                                               lr=lr_alpha)

        self.log_alpha = Parameter(torch.Tensor(1))
        nn.init.constant_(self.log_alpha, init_log_alpha)
        self.alpha_optimizer = Adam([self.log_alpha], lr=lr_alpha)

    def forward(self):
        pass

    def evaluate_critic(self, inner_state, outer_state, next_inner_state,
                        next_outer_state):
        q = self.q(inner_state, outer_state)
        next_q = self.q_target(next_inner_state, next_outer_state)
        next_pi, next_log_pi = self.actor(next_inner_state, next_outer_state)
        log_alpha = self.log_alpha.view(-1, 1)
        return q, next_q, next_pi, next_log_pi, log_alpha

    def evaluate_actor(self, inner_state, outer_state):
        q = self.q(inner_state, outer_state)
        pi, log_pi = self.actor(inner_state, outer_state)
        return q, pi, log_pi

    def sample_action(self, inner_state, outer_state, explore=True):
        PA_s = self.actor(inner_state.view(1, -1),
                          outer_state.unsqueeze(0))[0].squeeze(0).view(-1)
        assert torch.all(PA_s == PA_s), 'Boom. Capoot.'
        if explore:
            A = Categorical(probs=PA_s).sample().item()
        else:
            tie_breaking_dist = torch.isclose(PA_s, PA_s.max()).float()
            tie_breaking_dist /= tie_breaking_dist.sum()
            A = Categorical(probs=tie_breaking_dist).sample().item()
        return A, PA_s.detach().cpu().numpy()

    def update(self, rate=5e-3):
        updateNet(self.q_target, self.q, rate)

    def get_alpha(self):
        return self.log_alpha.exp().item()
Example #24
class GNJConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True):

        # Init torch module
        super(GNJConv2d, self).__init__()

        # Init conv params
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = stride
        self.padding = padding
        self.dilation = dilation

        # Init filter latents
        self.weight_mu = Parameter(Tensor(out_channels, in_channels, *self.kernel_size))
        self.weight_logvar = Parameter(Tensor(out_channels, in_channels, *self.kernel_size))


        self.bias = bias
        self.bias_mu = Parameter(Tensor(out_channels)) if self.bias else None
        self.bias_logvar = Parameter(Tensor(out_channels)) if self.bias else None

        # Init prior latents
        self.z_mu = Parameter(Tensor(out_channels))
        self.z_logvar = Parameter(Tensor(out_channels))

        # Set initial parameters
        self._init_params()

        # for brevity to conv2d calls
        self.convargs = [self.stride, self.padding, self.dilation]

        # util activations
        self.sigmoid = Sigmoid()
        self.softplus = Softplus()


    # forward network pass
    def forward(self, x):

        # vanilla forward pass if testing
        if not self.training:
            post_weight_mu = self.weight_mu * self.z_mu[:, None, None, None]
            post_bias_mu = self.bias_mu * self.z_mu if (self.bias_mu is not None) else None
            return conv2d(x, post_weight_mu, post_bias_mu, *self.convargs)

        #batch_size = x.size()[0]

        # unpack mean/std
        mu = self.z_mu
        std = torch.exp(0.5 * self.z_logvar)

        # rsample: sample scale prior with reparam trick
        z = Normal(mu, std).rsample()[None, :, None, None]

        # weights and biases for variance estimation
        weight_v = self.weight_logvar.exp()
        bias_v = self.bias_logvar.exp() if self.bias else None

        # parameterise output distribution
        mu_out = conv2d(x, self.weight_mu, self.bias_mu, *self.convargs) * z
        var_out = conv2d(x**2, weight_v, bias_v, *self.convargs) * (z ** 2)

        # Init out, note multiplicative noise==variational dropout
        dist_out = Normal(mu_out, var_out.sqrt()).rsample()
        #dist_out = self.reparam(mu_out*z, (var_out * z.pow(2)).log())

        return dist_out

    def _init_params(self, weight=None, bias=None):

        n = self.in_channels * self.kernel_size[0] * self.kernel_size[1]
        thresh = 1/math.sqrt(n)

        # weights
        self.weight_logvar.data.normal_(-9, 1e-2)

        if weight is not None:
            self.weight_mu.data = weight
        else:
            self.weight_mu.data.uniform_(-thresh, thresh)


        if self.bias:
            # biases
            self.bias_logvar.data.normal_(-9, 1e-2)

            if bias is not None:
                self.bias_mu.data = bias
            else:
                self.bias_mu.data.fill_(0)

        # priors
        self.z_mu.data.normal_(1, 1e-2)
        self.z_logvar.data.normal_(-9, 1e-2)


    # shape,scale family reparameterization trick (rsample does this?)
    def reparam(self, mu, logvar):
        std = torch.exp(0.5 * logvar)

        # check for cuda
        #tenv = torch.cuda if cuda else torch

        # draw from normal
        eps = torch.FloatTensor(std.size()).normal_()

        return mu + eps * std

    # KL div for GNJ w. Normal approx posterior
    def kl_divergence(self):

        # for brevity in kl_scale
        sg = self.sigmoid
        sp = self.softplus

        # Approximation parameters. Molchanov et al.
        k1, k2, k3 = 0.63576, 1.87320, 1.48695
        log_alpha = self._log_alpha()
        kl_scale = torch.sum(0.5 * sp(-log_alpha) + k1 - k1 * sg(k2  + k3 * log_alpha))
        kl_weight = self._conditional_kl_div(self.weight_mu, self.weight_logvar)
        kl_bias = self._conditional_kl_div(self.bias_mu, self.bias_logvar) if self.bias else 0

        return kl_scale + kl_weight + kl_bias

    @staticmethod
    def _conditional_kl_div(mu, logvar):
        # (8) Weight/bias divergence KL(q(w|z)||p(w|z))
        kl_div = -0.5 * logvar + 0.5 * (logvar.exp() + mu ** 2 - 1)
        return torch.sum(kl_div)

    # effective dropout rate
    def _log_alpha(self):
        epsilon = 1e-8
        log_a = self.z_logvar  - torch.log(self.z_mu ** 2 + epsilon)
        return log_a
Example #25
class LinearGroupNJ(Module):
    """Fully Connected Group Normal-Jeffrey's layer (aka Group Variational Dropout).

    References:
    [1] Kingma, Diederik P., Tim Salimans, and Max Welling. "Variational dropout and the local reparameterization trick." NIPS (2015).
    [2] Molchanov, Dmitry, Arsenii Ashukha, and Dmitry Vetrov. "Variational Dropout Sparsifies Deep Neural Networks." ICML (2017).
    [3] Louizos, Christos, Karen Ullrich, and Max Welling. "Bayesian Compression for Deep Learning." NIPS (2017).
    """

    def __init__(self, in_features, out_features, cuda=False, init_weight=None, init_bias=None, clip_var=None):

        super(LinearGroupNJ, self).__init__()
        self.cuda = cuda
        self.in_features = in_features
        self.out_features = out_features
        self.clip_var = clip_var
        self.deterministic = False  # flag is used for compressed inference
        # trainable params according to Eq.(6)
        # dropout params
        self.z_mu = Parameter(torch.Tensor(in_features))
        self.z_logvar = Parameter(torch.Tensor(in_features))  # = z_mu^2 * alpha
        # weight params
        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_logvar = Parameter(torch.Tensor(out_features, in_features))

        self.bias_mu = Parameter(torch.Tensor(out_features))
        self.bias_logvar = Parameter(torch.Tensor(out_features))

        # init params either random or with pretrained net
        self.reset_parameters(init_weight, init_bias)

        # activations for kl
        self.sigmoid = nn.Sigmoid()
        self.softplus = nn.Softplus()

        # numerical stability param
        self.epsilon = 1e-8

    def reset_parameters(self, init_weight, init_bias):
        # init means
        stdv = 1. / math.sqrt(self.weight_mu.size(1))

        self.z_mu.data.normal_(1, 1e-2)

        if init_weight is not None:
            self.weight_mu.data = torch.Tensor(init_weight)
        else:
            self.weight_mu.data.normal_(0, stdv)

        if init_bias is not None:
            self.bias_mu.data = torch.Tensor(init_bias)
        else:
            self.bias_mu.data.fill_(0)

        # init logvars
        self.z_logvar.data.normal_(-9, 1e-2)
        self.weight_logvar.data.normal_(-9, 1e-2)
        self.bias_logvar.data.normal_(-9, 1e-2)

    def clip_variances(self):
        if self.clip_var:
            self.weight_logvar.data.clamp_(max=math.log(self.clip_var))
            self.bias_logvar.data.clamp_(max=math.log(self.clip_var))

    def get_log_dropout_rates(self):
        log_alpha = self.z_logvar - torch.log(self.z_mu.pow(2) + self.epsilon)
        return log_alpha

    def compute_posterior_params(self):
        weight_var, z_var = self.weight_logvar.exp(), self.z_logvar.exp()
        self.post_weight_var = self.z_mu.pow(2) * weight_var + z_var * self.weight_mu.pow(2) + z_var * weight_var
        self.post_weight_mu = self.weight_mu * self.z_mu
        return self.post_weight_mu, self.post_weight_var

    def forward(self, x):
        if self.deterministic:
            assert self.training == False, "Flag deterministic is True. This should not be used in training."
            return F.linear(x, self.post_weight_mu, self.bias_mu)

        batch_size = x.size()[0]
        # compute z  
        # note that we reparametrise according to [2] Eq. (11) (not [1])
        z = reparametrize(self.z_mu.repeat(batch_size, 1), self.z_logvar.repeat(batch_size, 1), sampling=self.training)

        # apply local reparametrisation trick see [1] Eq. (6)
        # to the parametrisation given in [3] Eq. (6)
        xz = x * z
        mu_activations = F.linear(xz, self.weight_mu, self.bias_mu)
        var_activations = F.linear(xz.pow(2), self.weight_logvar.exp(), self.bias_logvar.exp())

        return reparametrize(mu_activations, var_activations.log(), sampling=self.training)

    def kl_divergence(self):
        # KL(q(z)||p(z))
        # we use the kl divergence approximation given by [2] Eq.(14)
        k1, k2, k3 = 0.63576, 1.87320, 1.48695
        log_alpha = self.get_log_dropout_rates()
        KLD = -torch.sum(k1 * self.sigmoid(k2 + k3 * log_alpha) - 0.5 * self.softplus(-log_alpha) - k1)

        # KL(q(w|z)||p(w|z))
        # we use the kl divergence given by [3] Eq.(8)
        KLD_element = -0.5 * self.weight_logvar + 0.5 * (self.weight_logvar.exp() + self.weight_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        # KL bias
        KLD_element = -0.5 * self.bias_logvar + 0.5 * (self.bias_logvar.exp() + self.bias_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        return KLD

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
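
# End-to-end sketch for LinearGroupNJ (an addition). `reparametrize` is a
# module-level helper in the source; the standard definition is assumed and
# restated here so the snippet is self-contained.
import torch

def reparametrize(mu, logvar, sampling=True, cuda=False):
    if not sampling:
        return mu
    eps = torch.randn_like(mu)               # same shape/device as mu
    return mu + eps * (0.5 * logvar).exp()   # mu + sigma * eps

layer = LinearGroupNJ(784, 300)
out = layer(torch.rand(8, 784))              # stochastic training-time pass
kl = layer.kl_divergence()                   # add kl / dataset_size to the task loss

layer.eval()
layer.compute_posterior_params()             # collapse q(w|z)q(z) into per-weight Gaussians
layer.deterministic = True                   # compressed, deterministic inference path
pred = layer(torch.rand(8, 784))
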
class FFGaussConv2d(Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 prior_std=1,
                 **kwargs):
        super(FFGaussConv2d, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = pair(kernel_size)
        self.stride = pair(stride)
        self.padding = pair(padding)
        self.dilation = pair(dilation)
        self.output_padding = pair(0)
        self.groups = groups
        self.prior_std = prior_std
        self.use_bias = False
        self.mean_w = Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))
        self.logvar_w = Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))

        if bias:
            self.mean_bias = Parameter(torch.Tensor(out_channels))
            self.logvar_bias = Parameter(torch.Tensor(out_channels))
            self.use_bias = True
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.mean_w, mode='fan_in')
        self.logvar_w.data.normal_(-9., 1e-4)
        if self.use_bias:
            self.mean_bias.data.zero_()
            self.logvar_bias.data.normal_(-9., 1e-4)

    def constrain_parameters(self, thres_std=1.):
        self.logvar_w.data.clamp_(max=2. * math.log(thres_std))
        if self.use_bias:
            self.logvar_bias.data.clamp_(max=2. * math.log(thres_std))

    def eq_logpw(self):
        logpw = -.5 * math.log(
            2 * math.pi * self.prior_std**2) - .5 * self.logvar_w.exp().div(
                self.prior_std**2)
        logpw -= .5 * self.mean_w.pow(2).div(self.prior_std**2)
        logpb = 0.
        if self.use_bias:
            logpb = -.5 * math.log(2 * math.pi * self.prior_std**
                                   2) - .5 * self.logvar_bias.exp().div(
                                       self.prior_std**2)
            logpb -= .5 * self.mean_bias.pow(2).div(self.prior_std**2)
        return torch.sum(logpw) + torch.sum(logpb)

    def eq_logqw(self):
        logqw = -torch.sum(.5 * (math.log(2 * math.pi) + self.logvar_w + 1))
        logqb = 0.
        if self.use_bias:
            logqb = -torch.sum(.5 *
                               (math.log(2 * math.pi) + self.logvar_bias + 1))
        return logqw + logqb

    def kldiv_aux(self):
        return 0.

    def kldiv(self):
        return self.kldiv_aux() + self.eq_logpw() - self.eq_logqw()

    def get_eps(self, size):
        eps = self.floatTensor(size).normal_()
        eps = Variable(eps)
        return eps

    def sample_W(self):
        W = self.mean_w
        if self.training:
            eps = self.get_eps(self.mean_w.size())
            W = W.add(eps.mul(self.logvar_w.mul(0.5).exp_()))
        return W

    def sample_b(self):
        if not self.use_bias:
            return None
        b = self.mean_bias
        if self.training:
            eps = self.get_eps(self.mean_bias.size())
            b = b.add(eps.mul(self.logvar_bias.mul(0.5).exp_()))
        return b

    def forward(self, input_):
        W = self.sample_W()
        b = self.sample_b()

        return F.conv2d(input_, W, b, self.stride, self.padding, self.dilation,
                        self.groups)

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}, prior_std={prior_std}')
        if self.padding != (0, ) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1, ) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0, ) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if not self.use_bias:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
Example #27
class VariationalDropout(nn.Module):
    def __init__(self, input_size, out_size, log_sigma2=-8, threshold=3):
        """
        :param input_size: An int of input size
        :param log_sigma2: Initial value of log sigma ^ 2.
               It is crusial for training since it determines initial value of alpha
        :param threshold: Value for thresholding of validation. If log_alpha > threshold, then weight is zeroed
        :param out_size: An int of output size
        """
        super(VariationalDropout, self).__init__()

        self.input_size = input_size
        self.out_size = out_size

        self.theta = Parameter(t.FloatTensor(input_size, out_size))
        self.bias = Parameter(t.Tensor(out_size))
        self.prior_theta = 0.
        self.prior_log_sigma2 = -2.
        self.log_sigma2 = Parameter(
            t.FloatTensor(input_size, out_size).fill_(log_sigma2))
        self.sz = input_size * out_size
        self.s = Parameter(t.Tensor([scale]))
        self.code = t.Tensor([0.2, 0, -0.2])

        self.reset_parameters()

        self.k = [0.63576, 1.87320, 1.48695]

        self.threshold = threshold

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.out_size)

        self.theta.data.uniform_(-stdv, stdv)
        self.bias.data.uniform_(-stdv, stdv)

    def clip(self):
        self.log_sigma2.data.masked_fill_(self.log_sigma2.data < -10, -10)
        self.log_sigma2.data.masked_fill_(self.log_sigma2.data > 1, 1)
        self.theta.data = t.where(
            self.theta < (-0.2 - 0.3679 * t.sqrt(self.log_sigma2.exp())),
            (-0.2 - 0.3679 * t.sqrt(self.log_sigma2.exp())), self.theta)
        self.theta.data = t.where(
            self.theta > (0.2 + 0.3679 * t.sqrt(self.log_sigma2.exp())),
            (0.2 + 0.3679 * t.sqrt(self.log_sigma2.exp())), self.theta)
#         self.theta.masked_fill(self.theta < (-0.2-0.3679*t.sqrt(self.log_sigma2.exp())), (-0.2-0.3679*t.sqrt(self.log_sigma2.exp())))
#         self.theta.masked_fill(self.theta > (0.2+0.3679*t.sqrt(self.log_sigma2.exp())), (0.2+0.3679*t.sqrt(self.log_sigma2.exp())))

    @staticmethod
    def clip__(input, to=8):
        input = input.masked_fill(input < -to, -to)
        input = input.masked_fill(input > to, to)

        return input


#     def kllu(self,log_alpha):
#         first_term = self.k[0] * F.sigmoid(self.k[1] + self.k[2] * log_alpha)
#         second_term = 0.5 * t.log(1 + t.exp(-log_alpha))
#         return -(first_term - second_term - self.k[0])

    def kld(self, mean, idx):

        window1 = gaussian_window(mean * self.s, 0.2)
        window2 = gaussian_window(mean * self.s, -0.2)

        log_alpha1 = self.log_sigma2 + 2 * t.log(self.s) - t.log(
            (mean * self.s - 0.2)**2)
        log_alpha2 = self.log_sigma2 + 2 * t.log(self.s) - t.log(
            (mean * self.s)**2)
        log_alpha3 = self.log_sigma2 + 2 * t.log(self.s) - t.log(
            (mean * self.s + 0.2)**2)

        F_KLLU1 = kllu(log_alpha1)
        F_KLLU2 = kllu(log_alpha2)
        F_KLLU3 = kllu(log_alpha3)
        F_KL = F_KLLU1 * window1 + F_KLLU3 * window2 + F_KLLU2 * (1 - window1 -
                                                                  window2)
        return F_KL.sum() / (self.sz)

    def forward(self, input, train, noquan):
        """
        :param input: An float tensor with shape of [batch_size, input_size]
        :return: An float tensor with shape of [batch_size, out_size] and negative layer-kld estimation
        """
        self.clip()
        c1 = (self.theta * self.s - 0.2)**2
        c2 = (self.theta * self.s)**2
        c3 = (self.theta * self.s + 0.2)**2
        mean = t.min(t.min(c1, c2), c3)
        c = t.stack((c1, c2, c3), 0)
        idx = t.argmin(c, 0)
        #         print(idx)

        if not train and not noquan:
            """
            mask = log_alpha > self.threshold
            return t.addmm(self.bias, input, self.theta.masked_fill(mask, 0))
            """
            theta_q = self.theta.data.clone()
            theta_q[:] = self.code[idx].cuda() / self.s
            #             mask = log_alpha > self.threshold

            mu = t.mm(input, theta_q)
            kld = t.sum((theta_q - self.theta)**2)

            return mu + self.bias, kld
        if noquan:
            kld = 0
            """
            mask = log_alpha > self.threshold
            return t.addmm(self.bias, input, self.theta.masked_fill(mask, 0))
            """
            theta_q = self.theta.data.clone()
            mu = t.mm(input, theta_q)

            return mu + self.bias, kld
        kld = _kl_loss(self.theta, self.log_sigma2, self.prior_theta,
                       self.prior_log_sigma2) / self.sz
        mu = t.mm(input, self.theta * self.s)
        std = t.sqrt(t.mm(input**2, self.s**2 * self.log_sigma2.exp()) + 1e-6)

        eps = Variable(t.randn(*mu.size()))
        if input.is_cuda:
            eps = eps.cuda()

        return std * eps + mu + self.bias, kld

    def max_alpha(self):
        log_alpha = self.log_sigma2 - self.theta**2
        return t.max(log_alpha.exp())
Example #28
class VDropConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = pair(kernel_size)
        self.stride = pair(stride)
        self.padding = pair(padding)
        self.dilation = pair(dilation)
        self.groups = groups

        self.weight = Parameter(torch.Tensor(out_channels,
                                             in_channels // groups,
                                             *self.kernel_size))
        init.kaiming_normal_(self.weight, mode="fan_out")

        self.w_logvar = Parameter(torch.Tensor(out_channels,
                                               in_channels // groups,
                                               *self.kernel_size))
        self.w_logvar.data.fill_(-10)

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
            self.bias.data.fill_(0)
        else:
            self.bias = None

        self.input_shape = None

        self.threshold = 3
        self.epsilon = 1e-8
        self.tensor_constructor = (torch.FloatTensor
                                   if not torch.cuda.is_available()
                                   else torch.cuda.FloatTensor)

    def extra_repr(self):
        s = (f"{self.in_channels}, {self.out_channels}, "
             f"kernel_size={self.kernel_size}, stride={self.stride}")
        if self.padding != (0,) * len(self.padding):
            s += f", padding={self.padding}"
        if self.dilation != (1,) * len(self.dilation):
            s += f", dilation={self.dilation}"
        if self.groups != 1:
            s += f", groups={self.groups}"
        if self.bias is None:
            s += ", bias=False"
        return s

    def constrain_parameters(self):
        self.w_logvar.data.clamp_(min=-10., max=10.)

    def compute_mask(self):
        w_logalpha = self.w_logvar - (self.weight ** 2 + self.epsilon).log()
        return (w_logalpha < self.threshold).float()

    def forward(self, x):
        if self.input_shape is None:
            self.input_shape = x.size()

        if self.training:
            return vdrop_conv_forward(x,
                                      lambda: self.weight,
                                      lambda: self.w_logvar.exp(),
                                      self.bias, self.stride, self.padding,
                                      self.dilation, self.groups,
                                      self.tensor_constructor, self.epsilon)
        else:
            return F.conv2d(x, self.weight * self.compute_mask(), self.bias,
                            self.stride, self.padding, self.dilation,
                            self.groups)

    def regularization(self):
        w_logalpha = self.w_logvar - (self.weight ** 2 + self.epsilon).log()
        return vdrop_regularization(w_logalpha).sum()

    def get_inference_nonzeros(self):
        mask = self.compute_mask().int()
        return mask.sum(dim=tuple(range(1, len(mask.shape))))

    def count_inference_flops(self):
        # For each unit, multiply with its n inputs then do n - 1 additions.
        # To capture the -1, subtract it, but only in cases where there is at
        # least one weight.
        nz_by_unit = self.get_inference_nonzeros()
        multiplies_per_instance = torch.sum(nz_by_unit)
        adds_per_instance = multiplies_per_instance - torch.sum(nz_by_unit > 0)

        # for rows
        instances = (
            (self.input_shape[-2] - self.kernel_size[0]
             + 2 * self.padding[0]) / self.stride[0]) + 1
        # multiplying with cols
        instances *= (
            (self.input_shape[-1] - self.kernel_size[1] + 2 * self.padding[1])
            / self.stride[1]) + 1

        multiplies = multiplies_per_instance * instances
        adds = adds_per_instance * instances

        return multiplies.item(), adds.item()

    def weight_size(self):
        return self.weight.size()
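
# `vdrop_conv_forward`, used by the convolutional layers above, is likewise an
# external helper. A plausible definition (an assumption, mirroring the linear
# sketch earlier) applies the local reparameterization to the conv pre-activations:
import torch.nn.functional as F

def _vdrop_conv_forward_sketch(x, get_w_mu, get_w_var, bias, stride, padding,
                               dilation, groups, tensor_constructor, epsilon=1e-8):
    y_mu = F.conv2d(x, get_w_mu(), bias, stride, padding, dilation, groups)
    y_var = F.conv2d(x ** 2, get_w_var(), None, stride, padding,
                     dilation, groups).clamp(min=epsilon)      # avoid sqrt(0) in backprop
    noise = tensor_constructor(y_mu.size()).normal_()
    return y_mu + noise * y_var.sqrt()
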