Code example #1
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter


class wcLinear(nn.Linear):
    """
    custom Linear layers for quantization
    """
    def __init__(self, in_features, out_features, bias=True, rate=0.):
        super(wcLinear, self).__init__(in_features=in_features,
                                       out_features=out_features,
                                       bias=bias)

        self.binary_weight = Parameter(torch.ones(self.weight.data.size(1)))
        self.float_weight = Parameter(torch.ones(self.weight.data.size(1)))
        self.register_buffer('rate', torch.ones(1).fill_(rate))

    def compute_grad(self):
        # Transfer the gradient accumulated on the binary mask to the
        # real-valued weight, then clear it; clearing binary_weight.grad is
        # essential so the optimizer never updates the 0/1 mask directly.
        self.float_weight.grad = self.binary_weight.grad.detach().clone()
        self.binary_weight.grad = None

    def forward(self, input):
        if self.training:
            self.float_weight.data.clamp_(min=0)
            self.binary_weight.data.copy_(
                self.float_weight.data.ge(self.rate[0]).float())

        # get new weight
        new_weight = self.binary_weight.unsqueeze(0).expand_as(
            self.weight) * self.weight

        return F.linear(input, new_weight, self.bias)
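
A minimal usage sketch (not part of the original snippet), assuming the imports above: with float_weight initialized to ones and rate=0.5, every input feature passes the threshold, so the layer behaves like a plain nn.Linear whose input features are gated by a 0/1 mask.

layer = wcLinear(in_features=8, out_features=4, rate=0.5)
layer.train()                        # the mask is only recomputed in training mode

x = torch.randn(2, 8)
y = layer(x)                         # shape (2, 4)

# float_weight starts at 1.0 >= rate, so all 8 input features are kept
print(layer.binary_weight.tolist())  # [1.0, 1.0, ..., 1.0]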
Code example #2
class wcConv2d(nn.Conv2d):
    """
    custom convolutional layers for quantization
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 bias=True,
                 rate=0.):
        super(wcConv2d, self).__init__(in_channels=in_channels,
                                       out_channels=out_channels,
                                       kernel_size=kernel_size,
                                       stride=stride,
                                       padding=padding,
                                       bias=bias)
        self.binary_weight = Parameter(torch.ones(self.weight.data.size(1)))
        self.float_weight = Parameter(torch.ones(self.weight.data.size(1)))
        self.register_buffer('rate', torch.ones(1).fill_(rate))

    def compute_grad(self):
        # Transfer the gradient accumulated on the binary mask to the
        # real-valued weight, then clear it; clearing binary_weight.grad is
        # essential so the optimizer never updates the 0/1 mask directly.
        self.float_weight.grad = self.binary_weight.grad.detach().clone()
        self.binary_weight.grad = None

    def forward(self, input):
        if self.training:
            self.float_weight.data.clamp_(min=0)
            self.binary_weight.data.copy_(
                self.float_weight.data.ge(self.rate[0]).float())

        new_weight = self.binary_weight.unsqueeze(0).unsqueeze(2).unsqueeze(
            3).expand_as(self.weight) * self.weight
        return F.conv2d(input, new_weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
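
A hedged sketch (the original training loop is not shown here) of how compute_grad() could be wired into a training step, assuming the imports from code example #1: backward() deposits gradients on binary_weight, compute_grad() copies them onto float_weight, and only float_weight is updated by the optimizer.

conv = wcConv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1, rate=0.5)
opt = torch.optim.SGD([conv.float_weight], lr=0.1)  # only the real-valued gates are updated

x = torch.randn(4, 3, 16, 16)
loss = conv(x).pow(2).mean()
loss.backward()       # populates conv.binary_weight.grad through the masked weight
conv.compute_grad()   # copies that gradient onto conv.float_weight and clears it
opt.step()
opt.zero_grad()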
Code example #3
import torch
from torch import FloatTensor
from torch.nn import Module, Parameter


class DOnePoleCell(Module):
    def __init__(self, a1=0.5, b0=1.0, b1=0.0):
        super(DOnePoleCell, self).__init__()
        self.b0 = Parameter(FloatTensor([b0]))
        self.b1 = Parameter(FloatTensor([b1]))
        self.a1 = Parameter(FloatTensor([a1]))

    def init_states(self, size):
        state = torch.zeros(size).to(self.a1.device)
        return state

    def forward(self, input, state):
        # keep the pole inside the unit circle so the filter stays stable
        self.a1.data.clamp_(-1, 1)
        output = self.b0 * input + state
        state = self.b1 * input + self.a1 * output
        return output, state
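
DOnePoleCell implements the first-order recursion y[n] = b0*x[n] + s[n], s[n+1] = b1*x[n] + a1*y[n], i.e. the transfer function (b0 + b1*z^-1) / (1 - a1*z^-1) with learnable coefficients. A minimal sketch (not from the original code) running it sample by sample:

cell = DOnePoleCell(a1=0.5, b0=1.0, b1=0.0)   # a simple leaky integrator
x = torch.randn(2, 100)                       # batch of 2 signals, 100 samples
state = cell.init_states((2,))

outputs = []
for t in range(x.size(1)):
    y, state = cell(x[:, t], state)
    outputs.append(y)
y = torch.stack(outputs, dim=1)               # filtered signals, shape (2, 100)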
Code example #4
class CartesianAdj(Module):
    """Concatenates Cartesian spatial relations based on the position
    :math:`P \in \mathbb{R}^{N x D}` of graph nodes to the graph's edge
    attributes."""
    def __init__(self, r=None, trainable=False):
        super(CartesianAdj, self).__init__()
        if r is not None:
            r = torch.FloatTensor([r]).cuda()
        if trainable and r is not None:
            self.r = Parameter(r)
        else:
            self.r = r

    def __call__(self, data):
        row, col = data.index
        # Compute Cartesian pseudo-coordinates.
        weight = data.pos[col] - data.pos[row]

        max_value = weight.abs().max() if self.r is None else self.r.clamp(
            min=0.0001)

        if self.r is not None:
            weight = weight * (1 / max_value)
            factor = weight.abs().max(1)[0].clamp(min=1)
            weight = weight / factor.unsqueeze(1)
            weight = weight / 2
        else:
            weight = weight * (1 / (2 * max_value))

        weight = weight + 0.5

        if data.weight is None:
            data.weight = weight
        else:
            data.weight = torch.cat([weight, data.weight.unsqueeze(1)], dim=1)

        return data
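
A small usage sketch (assumed, not from the source, requiring the usual torch imports): the transform expects a graph data object with index, pos and weight attributes, as in early torch_geometric-style code; types.SimpleNamespace stands in for that class here.

from types import SimpleNamespace

# Toy graph: 3 nodes in 2-D, two directed edges (0->1 and 1->2), no edge attributes yet.
data = SimpleNamespace(
    index=torch.tensor([[0, 1], [1, 2]]),   # row (source) and col (target) indices
    pos=torch.tensor([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]]),
    weight=None,
)

transform = CartesianAdj()   # r=None: normalize by the largest coordinate difference
data = transform(data)
print(data.weight)           # pseudo-coordinates in [0, 1], shape (2, 2)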
Code example #5
class BaseRNNCell(nn.Module):
    def __init__(self,
                 input_size,
                 hidden_size,
                 bias=False,
                 nonlinearity="tanh",
                 hidden_min_abs=0,
                 hidden_max_abs=None,
                 hidden_init=None,
                 recurrent_init=None,
                 gradient_clip=5):
        super(BaseRNNCell, self).__init__()
        self.hidden_max_abs = hidden_max_abs
        self.hidden_min_abs = hidden_min_abs
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.nonlinearity = nonlinearity
        self.hidden_init = hidden_init
        self.recurrent_init = recurrent_init
        if self.nonlinearity == "tanh":
            self.activation = torch.tanh
        elif self.nonlinearity == "relu":
            self.activation = F.relu
        elif self.nonlinearity == "sigmoid":
            self.activation = torch.sigmoid
        elif self.nonlinearity == "log":
            self.activation = torch.log
        elif self.nonlinearity == "sin":
            self.activation = torch.sin
        else:
            raise RuntimeError("Unknown nonlinearity: {}".format(
                self.nonlinearity))

        self.weight_ih = Parameter(torch.eye(hidden_size, input_size))
        self.weight_hh = Parameter(torch.Tensor(hidden_size, 20).uniform_())
        self.weight_hh1 = Parameter(torch.eye(input_size, hidden_size))
        if bias:
            self.bias_ih = Parameter(torch.randn(hidden_size))
        else:
            self.register_parameter('bias_ih', None)
        # self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    # def reset_parameters(self):
    #     for name, weight in self.named_parameters():
    #         if "bias" in name:
    #             weight.data.zero_()
    #         elif "weight_hh" in name:
    #             if self.recurrent_init is None:
    #                 nn.init.constant_(weight, 1)
    #             else:
    #                 self.recurrent_init(weight)
    #         elif "weight_ih" in name:
    #             if self.hidden_init is None:
    #                 nn.init.normal_(weight, 0, 0.01)
    #             else:
    #                 self.hidden_init(weight)
    #         else:
    #             weight.data.normal_(0, 0.01)
    #             # weight.data.uniform_(-stdv, stdv)
    #     self.check_bounds()

    def check_bounds(self):
        if self.hidden_min_abs:
            abs_kernel = torch.abs(
                self.weight_hh.data).clamp_(min=self.hidden_min_abs)
            self.weight_hh.data = torch.sign(
                self.weight_hh.data) * abs_kernel
        if self.hidden_max_abs:
            self.weight_hh.data = self.weight_hh.clamp(
                max=self.hidden_max_abs, min=-self.hidden_max_abs)

    def forward(self, input, hx):
        # x = F.linear(input, self.weight_ih, self.bias_ih) + torch.matmul(hx, self.weight_hh.matmul(self.weight_hh1))
        # return self.talor(x)
        return self.activation(
            F.linear(input, self.weight_ih, self.bias_ih) +
            torch.matmul(hx, self.weight_ih.matmul(self.weight_hh1)))

    def talor(self, x):
        return (x -
                1) - (x - 1) * (x - 1) / 2 + (x - 1) * (x - 1) * (x - 1) / 3
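
A short sketch (not from the source, assuming the usual torch imports) stepping the cell over a toy sequence. Note that in this variant the recurrence in forward goes through weight_ih @ weight_hh1, while weight_hh is only touched by check_bounds.

cell = BaseRNNCell(input_size=10, hidden_size=20, nonlinearity="tanh")
x = torch.randn(6, 3, 10)    # 6 time steps, batch of 3
hx = torch.zeros(3, 20)
for t in range(x.size(0)):
    hx = cell(x[t], hx)      # next hidden state, shape (3, 20)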
Code example #6
class IndRNNCell(nn.Module):
    r"""An IndRNN cell with tanh or ReLU non-linearity.

    .. math::

        h' = \tanh(w_{ih} x + b_{ih} + w_{hh} \odot h)

    where :math:`\odot` denotes element-wise (Hadamard) multiplication.
    If nonlinearity='relu', then ReLU is used in place of tanh.

    Args:
        input_size: The number of expected features in the input x
        hidden_size: The number of features in the hidden state h
        bias: If ``False``, then the layer does not use bias weights b_ih and b_hh.
            Default: ``True``
        nonlinearity: The non-linearity to use ['tanh'|'relu']. Default: 'relu'
        hidden_min_abs: Minimal absolute initial value for hidden weights. Default: 0
        hidden_max_abs: Maximal absolute initial value for hidden weights. Default: None

    Inputs: input, hidden
        - **input** (batch, input_size): tensor containing input features
        - **hidden** (batch, hidden_size): tensor containing the initial hidden
          state for each element in the batch.

    Outputs: h'
        - **h'** (batch, hidden_size): tensor containing the next hidden state
          for each element in the batch

    Attributes:
        weight_ih: the learnable input-hidden weights, of shape
            `(input_size x hidden_size)`
        weight_hh: the learnable hidden-hidden weights, of shape
            `(hidden_size)`
        bias_ih: the learnable input-hidden bias, of shape `(hidden_size)`

    Examples::

        >>> rnn = IndRNNCell(10, 20)
        >>> input = torch.randn(6, 3, 10)
        >>> hx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
        ...     hx = rnn(input[i], hx)
        ...     output.append(hx)
    """
    def __init__(self,
                 input_size,
                 hidden_size,
                 bias=True,
                 nonlinearity="relu",
                 hidden_min_abs=0,
                 hidden_max_abs=None,
                 hidden_init=None,
                 recurrent_init=None,
                 gradient_clip=None):
        super(IndRNNCell, self).__init__()
        self.hidden_max_abs = hidden_max_abs
        self.hidden_min_abs = hidden_min_abs
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.nonlinearity = nonlinearity
        self.hidden_init = hidden_init
        self.recurrent_init = recurrent_init
        if self.nonlinearity == "tanh":
            self.activation = torch.tanh
        elif self.nonlinearity == "relu":
            self.activation = F.relu
        else:
            raise RuntimeError("Unknown nonlinearity: {}".format(
                self.nonlinearity))
        self.weight_ih = Parameter(torch.Tensor(hidden_size, input_size))
        self.weight_hh = Parameter(torch.Tensor(hidden_size))
        if bias:
            self.bias_ih = Parameter(torch.Tensor(hidden_size))
        else:
            self.register_parameter('bias_ih', None)

        if gradient_clip:
            if isinstance(gradient_clip, tuple):
                min_g, max_g = gradient_clip
            else:
                max_g = gradient_clip
                min_g = -max_g
            self.weight_ih.register_hook(
                lambda x: x.clamp(min=min_g, max=max_g))
            self.weight_hh.register_hook(
                lambda x: x.clamp(min=min_g, max=max_g))
            if bias:
                self.bias_ih.register_hook(
                    lambda x: x.clamp(min=min_g, max=max_g))

        self.reset_parameters()

    def reset_parameters(self):
        for name, weight in self.named_parameters():
            if "bias" in name:
                weight.data.zero_()
            elif "weight_hh" in name:
                if self.recurrent_init is None:
                    nn.init.constant_(weight, 1)
                else:
                    self.recurrent_init(weight)
            elif "weight_ih" in name:
                if self.hidden_init is None:
                    nn.init.normal_(weight, 0, 0.01)
                else:
                    self.hidden_init(weight)
            else:
                weight.data.normal_(0, 0.01)
                # weight.data.uniform_(-stdv, stdv)
        self.check_bounds()

    def check_bounds(self):
        if self.hidden_min_abs:
            abs_kernel = torch.abs(
                self.weight_hh.data).clamp_(min=self.hidden_min_abs)
            self.weight_hh.data = torch.sign(
                self.weight_hh.data) * abs_kernel
        if self.hidden_max_abs:
            self.weight_hh.data = self.weight_hh.clamp(
                max=self.hidden_max_abs, min=-self.hidden_max_abs)

    def forward(self, input, hx):
        return self.activation(
            F.linear(input, self.weight_ih, self.bias_ih) +
            torch.mul(self.weight_hh, hx))
Code example #7
class IndRNNCell(nn.Module):
    def __init__(self,
                 input_size,
                 hidden_size,
                 bias=True,
                 activation="relu",
                 recurrent_min_abs=None,
                 recurrent_max_abs=None,
                 hidden_initializer=None,
                 recurrent_initializer=None,
                 gradient_clip_min=None,
                 gradient_clip_max=None):
        super(IndRNNCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.weight_ih = Parameter(
            torch.Tensor(self.hidden_size, self.input_size))
        self.weight_hh = Parameter(torch.Tensor(hidden_size))
        if bias:
            self.bias_ih = nn.Parameter(torch.Tensor(hidden_size))
        else:
            self.register_parameter('bias_ih', None)

        if activation == "relu":
            self.activation = F.relu
        elif activation == "tanh":
            self.activation = F.tanh
        else:
            warnings.warn(
                "IndRNN supports only ReLu and tanh activations. Fallingback to ReLU "
            )
            self.activation = F.relu
        self.recurrent_min_abs = recurrent_min_abs
        self.recurrent_max_abs = recurrent_max_abs
        self.hidden_initializer = hidden_initializer
        self.recurrent_initializer = recurrent_initializer

        # Gradient clipping to prevent gradient explosion
        if gradient_clip_max is not None:
            self.gradient_clip_max = gradient_clip_max
            self.gradient_clip_min = (gradient_clip_min
                                      if gradient_clip_min is not None
                                      else -gradient_clip_max)
            # register_hook clamps each parameter's gradient before the
            # optimizer uses it during gradient descent
            clip_min, clip_max = self.gradient_clip_min, self.gradient_clip_max
            self.weight_ih.register_hook(
                lambda x: x.clamp_(min=clip_min, max=clip_max))
            self.weight_hh.register_hook(
                lambda x: x.clamp_(min=clip_min, max=clip_max))
            if self.bias:
                self.bias_ih.register_hook(
                    lambda x: x.clamp_(min=clip_min, max=clip_max))

        # Initialize all parameters of the model
        for name, weight in self.named_parameters():
            if "bias" in name:
                # self.add_variable("bias", shape=[self._num_units], initializer=init_ops.zeros_initializer(dtype=self.dtype))
                weight.data.zero_()
            elif "weight_ih" in name:
                # self._input_initializer = init_ops.random_normal_initializer(mean=0.0, stddev=0.001)
                if self.hidden_initializer is None:
                    nn.init.normal_(weight, 0, 0.01)
                else:
                    self.hidden_initializer(weight)
            elif "weight_hh" in name:
                # self._recurrent_initializer = init_ops.constant_initializer(1.)
                if self.recurrent_initializer is None:
                    nn.init.constant_(weight, 1)
                else:
                    self.recurrent_initializer(weight)
            else:
                weight.data.normal_(0, 0.01)
        self.clip_recurrent_weights()

    def clip_recurrent_weights(self):
        # Clip the absolute values of the recurrent weights to the specified minimum
        r"""
        Code from https://github.com/batzner/indrnn/blob/master/ind_rnn_cell.py
        # Clip the absolute values of the recurrent weights to the specified minimum
            if self._recurrent_min_abs:
              abs_kernel = math_ops.abs(self._recurrent_kernel)
              min_abs_kernel = math_ops.maximum(abs_kernel, self._recurrent_min_abs)
              self._recurrent_kernel = math_ops.multiply(
                  math_ops.sign(self._recurrent_kernel),
                  min_abs_kernel
              )

            # Clip the absolute values of the recurrent weights to the specified maximum
            if self._recurrent_max_abs:
              self._recurrent_kernel = clip_ops.clip_by_value(self._recurrent_kernel,
                                                              -self._recurrent_max_abs,
                                                              self._recurrent_max_abs)
        """
        if self.recurrent_min_abs:
            abs_kernel = torch.abs(
                self.weight_hh.data).clamp_(min=self.recurrent_min_abs)
            self.weight_hh.data = torch.sign(
                self.weight_hh.data) * abs_kernel
        if self.recurrent_max_abs:
            self.weight_hh.data = self.weight_hh.clamp(
                max=self.recurrent_max_abs, min=-self.recurrent_max_abs)

        # if self.recurrent_min_abs:
        #     # abs_kernel = torch.abs(self.weight_hh.data).clamp_(min=self.recurrent_min_abs)
        #     # self.weight_hh.data = self.weight_hh.mul(torch.sign(self.weight_hh.data), abs_kernel)
        #     abs_kernel = torch.abs(self.weight_hh.data).clamp_(min=self.recurrent_min_abs)
        #     self.weight_hh.data = self.weight_hh.mul(torch.sign(self.weight_hh.data), abs_kernel)
        #
        # # Clip the absolute values of the recurrent weights to the specified maximum
        # if self.recurrent_max_abs:
        #     self.weight_hh.data = self.weight_hh.clamp(min=-self._recurrent_max_abs,
        #                                                max=self._recurrent_max_abs)

        # Pending: implement code for dropouts
        # --------

    def forward(self, input, hx=None):
        # out = act(w_{ih} * x + b_{ih} + w_{hh} (*) h)
        # (*) denotes the Hadamard (element-wise) product
        return self.activation(
            F.linear(input, self.weight_ih, self.bias_ih) +
            torch.mul(self.weight_hh, hx))
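
A quick sketch (not from the project, assuming the usual torch imports) showing the effect of the gradient-clipping hooks registered in __init__: every gradient entry is clamped into [-gradient_clip_max, gradient_clip_max] during backward().

cell = IndRNNCell(input_size=4, hidden_size=8, gradient_clip_max=0.1)
x = torch.randn(5, 4)
hx = torch.zeros(5, 8)
cell(x, hx).sum().backward()
print(cell.weight_ih.grad.abs().max() <= 0.1)   # tensor(True)
print(cell.weight_hh.grad.abs().max() <= 0.1)   # tensor(True)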
Code example #8
class IndRNNCell(nn.Module):
    """
    IndRNN Cell computes:

        $$h_t = \sigma(w_{ih} x_t + b_{ih} + w_{hh} \odot h_{t-1})$$

    where \sigma is the chosen nonlinearity and \odot is the element-wise product

    hyper-params:

        1. hidden_size
        2. input_size
        3. bias: true or false
        4. act: the nonlinearity function ("tanh", "relu", "sigmoid")
        5. hidden_min_abs & hidden_max_abs
        6. reccurent_only: only compute the recurrent part, for faster computation.
        7. init: how to initialize the params. Default norm for N(0,1/\sqrt(size)); constant; uniform; orth
        8. gradient_clip: `(min,max)` or `bound`

    inputs:

        1. Input: (batch, input_size)
        2. Hidden: (batch, hidden_size)

        batch first by default

    output:

        1. output: (batch, hidden_size)
        2. hidden state: (batch, hidden_size)

    params:

        1. weight_ih: (hidden_size,input_size)
        2. weight_hh: (1,hidden_size)
        3. bias_ih: (1,hidden_size) or None

    usage:

        >>> cell = IndRNNCell(100,128)
        >>> Input = torch.randn(32,100)
        >>> Hidden = torch.randn(32,128)
        >>> _, h = cell(Input, Hidden)

    """
    def __init__(self,
                 input_size,
                 hidden_size,
                 bias=True,
                 act="relu",
                 hidden_min_abs=0,
                 hidden_max_abs=2,
                 reccurent_only=False,
                 gradient_clip=None,
                 init_ih="norm",
                 input_weight_initializer=None,
                 recurrent_weight_initializer=None,
                 name="Default",
                 debug=False):

        super(IndRNNCell, self).__init__()
        self.hidden_max_abs = hidden_max_abs
        self.hidden_min_abs = hidden_min_abs
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.act = act
        self.reccurent_only = reccurent_only
        self.init_ih = init_ih
        self.input_weight_initializer = input_weight_initializer
        self.recurrent_weight_initializer = recurrent_weight_initializer
        self.name = name
        self.debug = debug
        if self.act is None:
            self.activation = None
        elif self.act == "relu":
            self.activation = F.relu
        elif self.act == "sigmoid":
            self.activation = torch.sigmoid
        elif self.act == "tanh":
            self.activation = torch.tanh
        else:
            raise RuntimeError(f"Unknown activation type: {self.act}")
        if not self.reccurent_only:
            self.weight_ih = Parameter(torch.Tensor(hidden_size, input_size))
        else:
            self.register_parameter('weight_ih', None)
        self.weight_hh = Parameter(torch.Tensor(1, hidden_size))
        if bias:
            self.bias_ih = Parameter(torch.Tensor(1, hidden_size))
        else:
            self.register_parameter('bias_ih', None)

        if gradient_clip:
            if isinstance(gradient_clip, tuple):
                assert len(gradient_clip) == 2
                min_g, max_g = gradient_clip
            else:
                max_g = gradient_clip
                min_g = -max_g
            if not self.reccurent_only:
                self.weight_ih.register_hook(
                    lambda x: x.clamp(min=min_g, max=max_g))
            self.weight_hh.register_hook(
                lambda x: x.clamp(min=min_g, max=max_g))
            if bias:
                self.bias_ih.register_hook(
                    lambda x: x.clamp(min=min_g, max=max_g))
        # debug
        # if self.debug:
        #    pdb.set_trace()
        self.reset_parameters()

    def reset_parameters(self):
        for name, weight in self.named_parameters():
            if "bias" in name:
                weight.data.zero_()
            elif "weight" in name:
                if self.input_weight_initializer and "weight_ih" in name:
                    self.input_weight_initializer(weight)
                elif self.recurrent_weight_initializer and "weight_hh" in name:
                    self.recurrent_weight_initializer(weight)
                elif "constant" in self.init_ih:
                    nn.init.constant_(weight, 1.0)
            else:
                weight.data.normal_(0, 0.01)
        self.clip_weight()

    def clip_weight(self):
        if self.hidden_min_abs:
            abs_kernel = torch.abs(
                self.weight_hh.data).clamp(min=self.hidden_min_abs)
            self.weight_hh.data = torch.sign(self.weight_hh.data) * abs_kernel
        if self.hidden_max_abs:
            self.weight_hh.data = self.weight_hh.clamp(
                min=-self.hidden_max_abs, max=self.hidden_max_abs)
        self.weight_hh.data.detach_()

    def forward(self, Input, Hidden):

        if not self.reccurent_only:
            h = F.linear(Input, self.weight_ih) + self.weight_hh * Hidden
            if self.bias:
                h += self.bias_ih
        else:
            h = Input + self.weight_hh * Hidden
        if self.activation:
            h = self.activation(h)
        return h, h
Code example #9
File: sbp.py / Project: kckishan/bbdrop
class SBP(Gate):
    def __init__(self,
                 num_gates,
                 min_log=-20.0,
                 max_log=0.0,
                 thres=1.0,
                 kl_scale=1.0):
        super(SBP, self).__init__(num_gates)
        self.min_log = min_log
        self.max_log = max_log
        self.thres = thres
        self.kl_scale = kl_scale
        self.mu = Parameter(torch.zeros(num_gates))
        self.log_sigma = Parameter(-5 * torch.ones(num_gates))

    def _mean_truncated_log_normal(self):
        a, b = self.min_log, self.max_log
        mu = self.mu.clamp(-20.0, 5.0)
        log_sigma = self.log_sigma.clamp(-20.0, 5.0)
        sigma = log_sigma.exp()

        alpha = (a - mu) / sigma
        beta = (b - mu) / sigma
        z = phi(beta) - phi(alpha)
        mean = erfcx(
            (sigma - beta) / math.sqrt(2.0)) * torch.exp(b - beta * beta / 2)
        mean = mean - erfcx((sigma - alpha) /
                            math.sqrt(2.0)) * torch.exp(a - alpha * alpha / 2)
        mean = mean / (2 * z)
        return mean

    def _snr_truncated_log_normal(self):
        a, b = self.min_log, self.max_log
        mu = self.mu.clamp(-20.0, 5.0)
        log_sigma = self.log_sigma.clamp(-20.0, 5.0)
        sigma = log_sigma.exp()

        alpha = (a - mu) / sigma
        beta = (b - mu) / sigma
        z = phi(beta) - phi(alpha)
        ratio = erfcx((sigma - beta) /
                      math.sqrt(2.0)) * torch.exp((b - mu) - beta**2 / 2.0)
        ratio = ratio - erfcx((sigma - alpha) / math.sqrt(2.0)) * torch.exp(
            (a - mu) - alpha**2 / 2.0)
        denominator = 2 * z * erfcx(
            (2.0 * sigma - beta) / math.sqrt(2.0)) * torch.exp(2.0 * (b - mu) -
                                                               beta**2 / 2.0)
        denominator = denominator - 2 * z * erfcx(
            (2.0 * sigma - alpha) / math.sqrt(2.0)) * torch.exp(
                2.0 * (a - mu) - alpha**2 / 2.0)
        denominator = denominator - ratio**2
        ratio = ratio / torch.sqrt(denominator)

        return ratio

    def _sample_truncated_normal(self):
        a, b = self.min_log, self.max_log
        mu = self.mu.clamp(-20.0, 5.0)
        log_sigma = self.log_sigma.clamp(-20.0, 5.0)
        sigma = torch.exp(log_sigma)

        alpha = (a - mu) / sigma
        beta = (b - mu) / sigma
        u = torch.rand(self.num_gates)
        if torch.cuda.is_available():
            u = u.cuda()
        gamma = phi(alpha) + u * (phi(beta) - phi(alpha))
        return (phi_inv(gamma.clamp(1e-5, 1 - 1e-5)) * sigma + mu).clamp(
            a, b).exp()

    def get_mask(self):
        snr = self._snr_truncated_log_normal()
        return (snr > self.thres).float()

    def get_weight(self, x):
        if self.training:
            z = self._sample_truncated_normal()
        else:
            Etheta = self._mean_truncated_log_normal()
            mask = self.get_mask()
            z = Etheta * mask
        return z

    def get_reg(self, base):
        a, b = self.min_log, self.max_log
        mu = self.mu.clamp(-20.0, 5.0)
        log_sigma = self.log_sigma.clamp(-20.0, 5.0)
        sigma = log_sigma.exp()

        alpha = (a - mu) / sigma
        beta = (b - mu) / sigma
        z = phi(beta) - phi(alpha)

        def pdf(x):
            return torch.exp(-x * x / 2.0) / math.sqrt(2.0 * math.pi)

        kld = -log_sigma - torch.log(z) - (alpha * pdf(alpha) -
                                           beta * pdf(beta)) / (2.0 * z)
        kld += math.log(self.max_log -
                        self.min_log) - math.log(2.0 * math.pi * math.e) / 2.0
        kld = self.kl_scale * kld.sum()
        return kld
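
The class above relies on a Gate base class and on the helpers phi, phi_inv and erfcx from the bbdrop project, none of which appear in this snippet. A rough sketch of the standard definitions those names usually refer to (assumed, not copied from the project; torch.special.erfcx needs a reasonably recent PyTorch):

import math
import torch


def phi(x):
    # standard normal CDF
    return 0.5 * torch.erfc(-x / math.sqrt(2.0))


def phi_inv(p):
    # inverse standard normal CDF (probit)
    return math.sqrt(2.0) * torch.erfinv(2.0 * p - 1.0)


def erfcx(x):
    # scaled complementary error function: exp(x^2) * erfc(x)
    return torch.special.erfcx(x)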