Example #1
 def __init__(self, in_features, out_features, bias=True, prior=None):
     super(LinearVB, self).__init__()
     self.type_layer = "linear"
     # device= conf_a.device, dtype= conf_a.dtype,
     self.in_features = in_features
     self.out_features = out_features
     self.bias = bias
     self.posterior_mean = False  # Flag: if True, use the posterior mean instead of drawing a sample
     
     ## If no prior is specified, create a default one
     if prior is None:
         prior = Vil.Prior(0.5, np.log(0.1), np.log(0.5))
     prior = prior.get_standarized_Prior(in_features)
     self.prior = prior
     
     """
     Mean and rhos of the parameters
     """
     self.mu_weight = Parameter(torch.Tensor(out_features, in_features))
     self.rho_weight = Parameter(torch.Tensor(out_features, in_features))
     
     if bias:
         self.rho_bias = Parameter(torch.Tensor(out_features))
         self.mu_bias = Parameter(torch.Tensor(out_features))
     else:
         self.register_parameter('rho_bias', None)
         self.register_parameter('mu_bias', None)
         
     """
     The sampled weights
     """
     self.weight = torch.Tensor(out_features, in_features)
     if bias:
         self.bias = torch.Tensor(out_features,1)
     else:
         self.register_parameter('bias', None)
     
     ## Initialize the Variational variables
     self.reset_parameters()
     self.sample_posterior()
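The mu/rho pairs above follow the Bayes-by-Backprop parameterization, where rho is mapped to a positive standard deviation through a softplus. sample_posterior itself is not shown in this snippet; the following is only a sketch of what such a method usually does under that assumption (the Vil helpers come from the source repo and are not reproduced; torch and torch.nn.functional as F are assumed imported as in the original module):

 def sample_posterior(self):
     # Hypothetical sketch: reparameterization trick, w = mu + softplus(rho) * eps
     if self.posterior_mean:
         self.weight = self.mu_weight
         if self.mu_bias is not None:
             self.bias = self.mu_bias
         return
     sigma_weight = F.softplus(self.rho_weight)  # std must stay positive
     self.weight = self.mu_weight + sigma_weight * torch.randn_like(self.mu_weight)
     if self.mu_bias is not None:
         sigma_bias = F.softplus(self.rho_bias)
         self.bias = self.mu_bias + sigma_bias * torch.randn_like(self.mu_bias)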
Example #2
class BoundaryDetector(nn.Module):
    '''
    Boundary Detector (boundary detection module)
    '''

    def __init__(self, i_features, h_features, s_features, inplace=False):
        super(BoundaryDetector, self).__init__()
        self.inplace = inplace
        self.Wsi = Parameter(torch.Tensor(s_features, i_features))
        self.Wsh = Parameter(torch.Tensor(s_features, h_features))
        self.bias = Parameter(torch.Tensor(s_features))
        self.vs = Parameter(torch.Tensor(1, s_features))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.Wsi.size(1))
        self.Wsi.data.uniform_(-stdv, stdv)
        self.Wsh.data.uniform_(-stdv, stdv)
        self.bias.data.uniform_(-stdv, stdv)
        self.vs.data.uniform_(-stdv, stdv)

    def forward(self, x, h):
        z = F.linear(x, self.Wsi) + F.linear(h, self.Wsh) + self.bias
        z = torch.sigmoid(F.linear(z, self.vs))  # torch.sigmoid; F.sigmoid is deprecated
        return BinaryGate.apply(z, self.training, self.inplace)

    def __repr__(self):
        return self.__class__.__name__
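A quick shape check for this module (hypothetical sizes; BinaryGate is the straight-through gate defined elsewhere in the same repo, and torch is imported as in the original file):

bd = BoundaryDetector(i_features=32, h_features=64, s_features=16)
x = torch.randn(8, 32)   # batch of 8 input vectors
h = torch.randn(8, 64)   # matching hidden states
z = bd(x, h)             # (8, 1) binary boundary indicators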
Example #3
 def __init__(self, i_features, h_features, s_features, inplace=False):
     super(BoundaryDetector, self).__init__()
     self.inplace = inplace
     self.Wsi = Parameter(torch.Tensor(s_features, i_features))
     self.Wsh = Parameter(torch.Tensor(s_features, h_features))
     self.bias = Parameter(torch.Tensor(s_features))
     self.vs = Parameter(torch.Tensor(1, s_features))
     self.reset_parameters()
Example #4
class Bilinear(Module):
    r"""Applies a bilinear transformation to the incoming data:
    :math:`y = x_1 A x_2 + b`

    Args:
        in1_features: size of each first input sample
        in2_features: size of each second input sample
        out_features: size of each output sample
        bias: If set to False, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input: :math:`(N, *, \text{in1\_features})`, :math:`(N, *, \text{in2\_features})`
          where :math:`*` means any number of additional dimensions. All but the last
          dimension of the inputs should be the same.
        - Output: :math:`(N, *, \text{out\_features})` where all but the last dimension
          are the same shape as the input.

    Attributes:
        weight: the learnable weights of the module of shape
            `(out_features x in1_features x in2_features)`
        bias:   the learnable bias of the module of shape `(out_features)`

    Examples::

        >>> m = nn.Bilinear(20, 30, 40)
        >>> input1 = torch.randn(128, 20)
        >>> input2 = torch.randn(128, 30)
        >>> output = m(input1, input2)
        >>> print(output.size())
    """

    def __init__(self, in1_features, in2_features, out_features, bias=True):
        super(Bilinear, self).__init__()
        self.in1_features = in1_features
        self.in2_features = in2_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in1_features, in2_features))

        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input1, input2):
        return F.bilinear(input1, input2, self.weight, self.bias)

    def extra_repr(self):
        return 'in1_features={}, in2_features={}, out_features={}, bias={}'.format(
            self.in1_features, self.in2_features, self.out_features, self.bias is not None
        )
Example #5
class Bilinear(Module):
    r"""Applies a bilinear transformation to the incoming data:
    :math:`y = x_1 A x_2 + b`

    Args:
        in1_features: size of each first input sample
        in2_features: size of each second input sample
        out_features: size of each output sample
        bias: If set to False, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input: :math:`(N, in1\_features)`, :math:`(N, in2\_features)`
        - Output: :math:`(N, out\_features)`

    Attributes:
        weight: the learnable weights of the module of shape
            `(out_features x in1_features x in2_features)`
        bias:   the learnable bias of the module of shape `(out_features)`

    Examples::

        >>> m = nn.Bilinear(20, 30, 40)
        >>> input1 = torch.randn(128, 20)
        >>> input2 = torch.randn(128, 30)
        >>> output = m(input1, input2)
        >>> print(output.size())
    """

    def __init__(self, in1_features, in2_features, out_features, bias=True):
        super(Bilinear, self).__init__()
        self.in1_features = in1_features
        self.in2_features = in2_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in1_features, in2_features))

        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input1, input2):
        return F.bilinear(input1, input2, self.weight, self.bias)

    def __repr__(self):
        return self.__class__.__name__ + '(' \
            + 'in1_features=' + str(self.in1_features) \
            + ', in2_features=' + str(self.in2_features) \
            + ', out_features=' + str(self.out_features) \
            + ', bias=' + str(self.bias is not None) + ')'
Example #6
 def __init__(self, n_features, context):
     # should we handle batch_first=True?
     super(Lookahead, self).__init__()
     self.n_features = n_features
     self.weight = Parameter(torch.Tensor(n_features, context + 1))
     assert context > 0
     self.context = context
     self.register_parameter('bias', None)
     self.init_parameters()
Example #7
class BilinearMatrixAttention(MatrixAttention):
    """
    Computes attention between two matrices using a bilinear attention function.  This function has
    a matrix of weights ``W`` and a bias ``b``, and the similarity between the two matrices ``X``
    and ``Y`` is computed as ``X W Y^T + b``.

    Parameters
    ----------
    matrix_1_dim : ``int``
        The dimension of the matrix ``X``, described above.  This is ``X.size()[-1]`` - the length
        of the vector that will go into the similarity computation.  We need this so we can build
        the weight matrix correctly.
    matrix_2_dim : ``int``
        The dimension of the matrix ``Y``, described above.  This is ``Y.size()[-1]`` - the length
        of the vector that will go into the similarity computation.  We need this so we can build
        the weight matrix correctly.
    activation : ``Activation``, optional (default=linear (i.e. no activation))
        An activation function applied after the ``X W Y^T + b`` calculation.  Default is no
        activation.
    use_input_biases : ``bool``, optional (default = False)
        If True, we add biases to the inputs such that the final computation
        is equivelent to the original bilinear matrix multiplication plus a
        projection of both inputs.
    """
    def __init__(self,
                 matrix_1_dim: int,
                 matrix_2_dim: int,
                 activation: Activation = None,
                 use_input_biases: bool = False) -> None:
        super().__init__()
        if use_input_biases:
            matrix_1_dim += 1
            matrix_2_dim += 1
        self._weight_matrix = Parameter(torch.Tensor(matrix_1_dim, matrix_2_dim))

        self._bias = Parameter(torch.Tensor(1))
        self._activation = activation or Activation.by_name('linear')()
        self._use_input_biases = use_input_biases
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self._weight_matrix)
        self._bias.data.fill_(0)

    @overrides
    def forward(self, matrix_1: torch.Tensor, matrix_2: torch.Tensor) -> torch.Tensor:

        if self._use_input_biases:
            bias1 = matrix_1.new_ones(matrix_1.size()[:-1] + (1,))
            bias2 = matrix_2.new_ones(matrix_2.size()[:-1] + (1,))

            matrix_1 = torch.cat([matrix_1, bias1], -1)
            matrix_2 = torch.cat([matrix_2, bias2], -1)
        intermediate = torch.matmul(matrix_1.unsqueeze(1), self._weight_matrix.unsqueeze(0))
        final = torch.matmul(intermediate, matrix_2.unsqueeze(1).transpose(2, 3))
        return self._activation(final.squeeze(1) + self._bias)
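As a shape sketch (hypothetical sizes, assuming the AllenNLP Activation registry is available): for matrix_1 of shape (batch, rows_1, matrix_1_dim) and matrix_2 of shape (batch, rows_2, matrix_2_dim), the output is (batch, rows_1, rows_2):

attention = BilinearMatrixAttention(matrix_1_dim=100, matrix_2_dim=200)
m1 = torch.randn(4, 7, 100)
m2 = torch.randn(4, 9, 200)
scores = attention(m1, m2)  # (4, 7, 9): one X W Y^T + b score per row pair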
Example #8
class LinearSimilarity(SimilarityFunction):
    """
    This similarity function performs a dot product between a vector of weights and some
    combination of the two input vectors, followed by an (optional) activation function.  The
    combination used is configurable.

    If the two vectors are ``x`` and ``y``, we allow the following kinds of combinations: ``x``,
    ``y``, ``x*y``, ``x+y``, ``x-y``, ``x/y``, where each of those binary operations is performed
    elementwise.  You can list as many combinations as you want, comma separated.  For example, you
    might give ``x,y,x*y`` as the ``combination`` parameter to this class.  The computed similarity
    function would then be ``w^T [x; y; x*y] + b``, where ``w`` is a vector of weights, ``b`` is a
    bias parameter, and ``[;]`` is vector concatenation.

    Note that if you want a bilinear similarity function with a diagonal weight matrix W, where the
    similarity function is computed as `x * w * y + b` (with `w` the diagonal of `W`), you can
    accomplish that with this class by using "x*y" for `combination`.

    Parameters
    ----------
    tensor_1_dim : ``int``
        The dimension of the first tensor, ``x``, described above.  This is ``x.size()[-1]`` - the
        length of the vector that will go into the similarity computation.  We need this so we can
        build weight vectors correctly.
    tensor_2_dim : ``int``
        The dimension of the second tensor, ``y``, described above.  This is ``y.size()[-1]`` - the
        length of the vector that will go into the similarity computation.  We need this so we can
        build weight vectors correctly.
    combination : ``str``, optional (default="x,y")
        Described above.
    activation : ``Activation``, optional (default=linear (i.e. no activation))
        An activation function applied after the ``w^T * [x;y] + b`` calculation.  Default is no
        activation.
    """
    def __init__(self,
                 tensor_1_dim: int,
                 tensor_2_dim: int,
                 combination: str = 'x,y',
                 activation: Activation = None) -> None:
        super(LinearSimilarity, self).__init__()
        self._combination = combination
        combined_dim = util.get_combined_dim(combination, [tensor_1_dim, tensor_2_dim])
        self._weight_vector = Parameter(torch.Tensor(combined_dim))
        self._bias = Parameter(torch.Tensor(1))
        self._activation = activation or Activation.by_name('linear')()
        self.reset_parameters()

    def reset_parameters(self):
        std = math.sqrt(6 / (self._weight_vector.size(0) + 1))
        self._weight_vector.data.uniform_(-std, std)
        self._bias.data.fill_(0)

    @overrides
    def forward(self, tensor_1: torch.Tensor, tensor_2: torch.Tensor) -> torch.Tensor:
        combined_tensors = util.combine_tensors(self._combination, [tensor_1, tensor_2])
        dot_product = torch.matmul(combined_tensors, self._weight_vector)
        return self._activation(dot_product + self._bias)
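For example, with combination 'x,y,x*y' the combined vector is [x; y; x*y], so the weight vector has length 3 * dim, and the elementwise product requires the two input dims to match. A hypothetical call (util.combine_tensors is AllenNLP's helper):

similarity = LinearSimilarity(tensor_1_dim=50, tensor_2_dim=50, combination='x,y,x*y')
t1 = torch.randn(4, 10, 50)
t2 = torch.randn(4, 10, 50)
out = similarity(t1, t2)  # (4, 10): w^T [x; y; x*y] + b per vector pair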
Example #9
 def __init__(self, in_features, out_features, bias=True):
     super(Linear, self).__init__()
     self.in_features = in_features
     self.out_features = out_features
     self.weight = Parameter(torch.Tensor(out_features, in_features))
     if bias:
         self.bias = Parameter(torch.Tensor(out_features))
     else:
         self.register_parameter('bias', None)
     self.reset_parameters()
Example #10
class Linear(Module):
    r"""Applies a linear transformation to the incoming data: :math:`y = Ax + b`

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to False, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input: :math:`(N, *, in\_features)` where :math:`*` means any number of
          additional dimensions
        - Output: :math:`(N, *, out\_features)` where all but the last dimension
          are the same shape as the input.

    Attributes:
        weight: the learnable weights of the module of shape
            `(out_features x in_features)`
        bias:   the learnable bias of the module of shape `(out_features)`

    Examples::

        >>> m = nn.Linear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
    """

    def __init__(self, in_features, out_features, bias=True):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input):
        return F.linear(input, self.weight, self.bias)

    def __repr__(self):
        return self.__class__.__name__ + '(' \
            + 'in_features=' + str(self.in_features) \
            + ', out_features=' + str(self.out_features) \
            + ', bias=' + str(self.bias is not None) + ')'
Example #11
 def __init__(self,
              tensor_1_dim: int,
              tensor_2_dim: int,
              combination: str = 'x,y',
              activation: Activation = Activation.by_name('linear')()) -> None:
     super(LinearSimilarity, self).__init__()
     self._combination = combination
     combined_dim = util.get_combined_dim(combination, [tensor_1_dim, tensor_2_dim])
     self._weight_vector = Parameter(torch.Tensor(combined_dim))
     self._bias = Parameter(torch.Tensor(1))
     self._activation = activation
     self.reset_parameters()
Example #12
    def __init__(self,
                 tensor_1_dim: int,
                 tensor_2_dim: int,
                 combination: str = 'x,y',
                 activation: Activation = None,
                 prior = None) -> None:
        super(LinearSimilarityVB, self).__init__()
        self._combination = combination
        combined_dim = util.get_combined_dim(combination, [tensor_1_dim, tensor_2_dim])
        
        self.posterior_mean = False  # Flag: if True, use the posterior mean instead of drawing a sample
        
        ## If no prior is specified, create a default one
        if prior is None:
            prior = Vil.Prior(0.5, np.log(0.1), np.log(0.5))

        size_combination = int(torch.Tensor(combined_dim).size()[0])
        prior = prior.get_standarized_Prior(size_combination)
        self.prior = prior
        
        """
        Mean and rhos of the parameters
        """
        self.mu_weight = Parameter(torch.Tensor(combined_dim))
        self.rho_weight = Parameter(torch.Tensor(combined_dim))

        self.rho_bias = Parameter(torch.Tensor(1))
        self.mu_bias = Parameter(torch.Tensor(1))
            
        """
        The sampled weights
        """
        self.weight = torch.Tensor(combined_dim)
        self.bias = torch.Tensor(1)
        
        self._activation = activation or Activation.by_name('linear')()
        
        ## Initialize the Variational variables
        self.reset_parameters()
Example #13
    def __init__(self, in_features, out_features, cuda=False, init_weight=None, init_bias=None, clip_var=None):

        super(LinearGroupNJ, self).__init__()
        self.cuda = cuda
        self.in_features = in_features
        self.out_features = out_features
        self.clip_var = clip_var
        self.deterministic = False  # flag is used for compressed inference
        # trainable params according to Eq.(6)
        # dropout params
        self.z_mu = Parameter(torch.Tensor(in_features))
        self.z_logvar = Parameter(torch.Tensor(in_features))  # = z_mu^2 * alpha
        # weight params
        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_logvar = Parameter(torch.Tensor(out_features, in_features))

        self.bias_mu = Parameter(torch.Tensor(out_features))
        self.bias_logvar = Parameter(torch.Tensor(out_features))

        # init params either random or with pretrained net
        self.reset_parameters(init_weight, init_bias)

        # activations for kl
        self.sigmoid = nn.Sigmoid()
        self.softplus = nn.Softplus()

        # numerical stability param
        self.epsilon = 1e-8
Example #14
    def __init__(self,
                 matrix_1_dim: int,
                 matrix_2_dim: int,
                 activation: Activation = None,
                 use_input_biases: bool = False) -> None:
        super().__init__()
        if use_input_biases:
            matrix_1_dim += 1
            matrix_2_dim += 1
        self._weight_matrix = Parameter(torch.Tensor(matrix_1_dim, matrix_2_dim))

        self._bias = Parameter(torch.Tensor(1))
        self._activation = activation or Activation.by_name('linear')()
        self._use_input_biases = use_input_biases
        self.reset_parameters()
Example #15
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding,
                 groups, bias, init_weight, init_bias, cuda=False, clip_var=None):
        super(_ConvNdGroupNJ, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups

        self.cuda = cuda
        self.clip_var = clip_var
        self.deterministic = False  # flag is used for compressed inference

        if transposed:
            self.weight_mu = Parameter(torch.Tensor(
                in_channels, out_channels // groups, *kernel_size))
            self.weight_logvar = Parameter(torch.Tensor(
                in_channels, out_channels // groups, *kernel_size))
        else:
            self.weight_mu = Parameter(torch.Tensor(
                out_channels, in_channels // groups, *kernel_size))
            self.weight_logvar = Parameter(torch.Tensor(
                out_channels, in_channels // groups, *kernel_size))

        self.bias_mu = Parameter(torch.Tensor(out_channels))
        self.bias_logvar = Parameter(torch.Tensor(out_channels))

        self.z_mu = Parameter(torch.Tensor(self.out_channels))
        self.z_logvar = Parameter(torch.Tensor(self.out_channels))

        self.reset_parameters(init_weight, init_bias)

        # activations for kl
        self.sigmoid = nn.Sigmoid()
        self.softplus = nn.Softplus()
        # numerical stability param
        self.epsilon = 1e-8
Example #16
class Linear(nn.Module):
    """Custom Linear layer which allows for sharing weights (e.g. with an
    nn.Embedding layer).
    """
    def __init__(self, in_features, out_features, bias=True,
                 shared_weight=None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.shared = shared_weight is not None

        # init weight
        if not self.shared:
            self.weight = Parameter(torch.Tensor(out_features, in_features))
        else:
            if (shared_weight.size(0) != out_features or
                    shared_weight.size(1) != in_features):
                raise RuntimeError('wrong dimensions for shared weights')
            self.weight = shared_weight

        # init bias
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        if not self.shared:
            # only initialize the weight when it is not shared; a shared
            # weight belongs to another module and must not be overwritten here
            self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input):
        weight = self.weight
        if self.shared:
            # detach weight to prevent gradients from changing weight
            # (but need to detach every time so weights are up to date)
            weight = weight.detach()
        return F.linear(input, weight, self.bias)

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ')'
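The shared-weight path is the usual input/output embedding tying trick; a minimal sketch with hypothetical sizes (nn is torch.nn, as in the snippet):

embedding = nn.Embedding(10000, 512)              # weight: (10000, 512)
projection = Linear(in_features=512, out_features=10000,
                    shared_weight=embedding.weight)
logits = projection(torch.randn(8, 512))          # (8, 10000); the weight is detached
                                                  # in forward(), so only the embedding's
                                                  # own gradients update it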
Example #17
    def __init__(self, in_features, out_features, bias=True,
                 shared_weight=None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.shared = shared_weight is not None

        # init weight
        if not self.shared:
            self.weight = Parameter(torch.Tensor(out_features, in_features))
        else:
            if (shared_weight.size(0) != out_features or
                    shared_weight.size(1) != in_features):
                raise RuntimeError('wrong dimensions for shared weights')
            self.weight = shared_weight

        # init bias
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
Example #18
class Lookahead(nn.Module):
    # Wang et al 2016 - Lookahead Convolution Layer for Unidirectional Recurrent Neural Networks
    # input shape - sequence, batch, feature - TxNxH
    # output shape - same as input
    def __init__(self, n_features, context):
        # should we handle batch_first=True?
        super(Lookahead, self).__init__()
        self.n_features = n_features
        self.weight = Parameter(torch.Tensor(n_features, context + 1))
        assert context > 0
        self.context = context
        self.register_parameter('bias', None)
        self.init_parameters()

    def init_parameters(self):  # what's a better way initialiase this layer?
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, input):
        seq_len = input.size(0)
        # pad the 0th dimension (T/sequence) with zeroes whose number = context
        # Once pytorch's padding functions have settled, should move to those.
        padding = torch.zeros(self.context, *(input.size()[1:])).type_as(input.data)
        x = torch.cat((input, Variable(padding)), 0)

        # add lookahead windows (with context+1 width) as a fourth dimension
        # for each seq-batch-feature combination
        x = [x[i:i + self.context + 1] for i in range(seq_len)]  # TxLxNxH - sequence, context, batch, feature
        x = torch.stack(x)
        x = x.permute(0, 2, 3, 1)  # TxNxHxL - sequence, batch, feature, context

        x = torch.mul(x, self.weight).sum(dim=3)
        return x

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + 'n_features=' + str(self.n_features) \
               + ', context=' + str(self.context) + ')'
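A shape sketch with hypothetical sizes (input is sequence-first, T x N x H, and the output keeps that shape):

lookahead = Lookahead(n_features=128, context=5)
x = torch.randn(100, 16, 128)  # T=100 steps, N=16 batch, H=128 features
y = lookahead(x)               # (100, 16, 128): each step mixes the current and 5 future frames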
Example #19
class LinearGroupNJ(BayesianLayers):
    """Fully Connected Group Normal-Jeffrey's layer (aka Group Variational Dropout).

    References:
    [1] Kingma, Diederik P., Tim Salimans, and Max Welling. "Variational dropout and the local reparameterization trick." NIPS (2015).
    [2] Molchanov, Dmitry, Arsenii Ashukha, and Dmitry Vetrov. "Variational Dropout Sparsifies Deep Neural Networks." ICML (2017).
    [3] Louizos, Christos, Karen Ullrich, and Max Welling. "Bayesian Compression for Deep Learning." NIPS (2017).
    """

    def __init__(self, in_features, out_features, cuda=False, init_weight=None, init_bias=None, clip_var=None):

        super(LinearGroupNJ, self).__init__()
        self.cuda = cuda
        self.in_features = in_features
        self.out_features = out_features
        self.clip_var = clip_var
        self.deterministic = False  # flag is used for compressed inference
        # trainable params according to Eq.(6)
        # dropout params
        self.z_mu = Parameter(torch.Tensor(in_features))
        self.z_logvar = Parameter(torch.Tensor(in_features))  # = z_mu^2 * alpha
        # weight params
        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_logvar = Parameter(torch.Tensor(out_features, in_features))

        self.bias_mu = Parameter(torch.Tensor(out_features))
        self.bias_logvar = Parameter(torch.Tensor(out_features))

        # init params either random or with pretrained net
        self.reset_parameters(init_weight, init_bias)

        # activations for kl
        self.sigmoid = nn.Sigmoid()
        self.softplus = nn.Softplus()

        # numerical stability param
        self.epsilon = 1e-8

    def reset_parameters(self, init_weight, init_bias):
        # init means
        stdv = 1. / math.sqrt(self.weight_mu.size(1))

        self.z_mu.data.normal_(1, 1e-2)

        if init_weight is not None:
            self.weight_mu.data = torch.Tensor(init_weight)
        else:
            self.weight_mu.data.normal_(0, stdv)

        if init_bias is not None:
            self.bias_mu.data = torch.Tensor(init_bias)
        else:
            self.bias_mu.data.fill_(0)

        # init logvars
        self.z_logvar.data.normal_(-9, 1e-2)
        self.weight_logvar.data.normal_(-9, 1e-2)
        self.bias_logvar.data.normal_(-9, 1e-2)

    def clip_variances(self):
        if self.clip_var:
            self.weight_logvar.data.clamp_(max=math.log(self.clip_var))
            self.bias_logvar.data.clamp_(max=math.log(self.clip_var))

    def get_log_dropout_rates(self):
        log_alpha = self.z_logvar - torch.log(self.z_mu.pow(2) + self.epsilon)
        return log_alpha

    def compute_posterior_params(self):
        weight_var, z_var = self.weight_logvar.exp(), self.z_logvar.exp()
        self.post_weight_var = self.z_mu.pow(2) * weight_var + z_var * self.weight_mu.pow(2) + z_var * weight_var
        self.post_weight_mu = self.weight_mu * self.z_mu
        # print("self.z_mu.pow(2): ", self.z_mu.pow(2).size())
        # print("weight_var: ", weight_var.size())
        # print("z_var: ", z_var.size())
        # print("self.weight_mu.pow(2): ", self.weight_mu.pow(2).size())
        # print("weight_var: ", weight_var.size())
        # print("post_weight_mu: ", self.post_weight_mu.size())
        # print("post_weight_var: ", self.post_weight_var.size())
        return self.post_weight_mu, self.post_weight_var

    def forward(self, x):
        if self.deterministic:
            assert not self.training, "Flag deterministic is True. This should not be used in training."
            return F.linear(x, self.post_weight_mu, self.bias_mu)

        batch_size = x.size()[0]
        # compute z
        # note that we reparametrise according to [2] Eq. (11) (not [1])
        z = reparametrize(self.z_mu.repeat(batch_size, 1), self.z_logvar.repeat(batch_size, 1), sampling=self.training,
                          cuda=self.cuda)

        # apply local reparametrisation trick see [1] Eq. (6)
        # to the parametrisation given in [3] Eq. (6)
        xz = x * z
        mu_activations = F.linear(xz, self.weight_mu, self.bias_mu)
        var_activations = F.linear(xz.pow(2), self.weight_logvar.exp(), self.bias_logvar.exp())

        return reparametrize(mu_activations, var_activations.log(), sampling=self.training, cuda=self.cuda)

    def kl_divergence(self):
        # KL(q(z)||p(z))
        # we use the kl divergence approximation given by [2] Eq.(14)
        k1, k2, k3 = 0.63576, 1.87320, 1.48695
        log_alpha = self.get_log_dropout_rates()
        KLD = -torch.sum(k1 * self.sigmoid(k2 + k3 * log_alpha) - 0.5 * self.softplus(-log_alpha) - k1)

        # KL(q(w|z)||p(w|z))
        # we use the kl divergence given by [3] Eq.(8)
        KLD_element = -0.5 * self.weight_logvar + 0.5 * (self.weight_logvar.exp() + self.weight_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        # KL bias
        KLD_element = -0.5 * self.bias_logvar + 0.5 * (self.bias_logvar.exp() + self.bias_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        return KLD

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
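During training, the layer's KL term is added to the task loss to form the variational objective. A minimal sketch with hypothetical sizes (reparametrize comes from the repo's utilities; N is the training-set size used to scale the KL):

layer = LinearGroupNJ(in_features=784, out_features=10)
logits = layer(torch.randn(32, 784))
nll = F.cross_entropy(logits, torch.randint(0, 10, (32,)))
N = 60000
loss = nll + layer.kl_divergence() / N  # per-example contribution of the KL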
Example #20
 def __init__(self, channel):
     super(ParamFree, self).__init__()
     self.sig = nn.Sigmoid()
     self.learn = Parameter(torch.zeros(3))
     self.batch_norm = nn.BatchNorm2d(channel)
Example #21
 def __init__(self, dim, device):
     super(InnerModel, self).__init__()
     self.bias = Parameter(torch.FloatTensor([1.0] * dim).to(device))
Example #22
class CustomEmbedding(torch.nn.Module):
    """
        Memory efficient way to compute weighted EmbeddingBag
    """
    def __init__(self,
                 num_embeddings,
                 embedding_dim,
                 max_norm=None,
                 norm_type=2,
                 scale_grad_by_freq=False,
                 sparse=False,
                 device="cuda:0"):
        """
            Args:
                num_embeddings: int: vocabulary size
                embedding_dim: int: dimension for embeddings
                padding_idx: int: index for <PAD>; embedding is not updated
                max_norm: 
                norm_type: int: default: 2
                scale_grad_by_freq: boolean: True/False
                sparse: boolean: sparse or dense gradients
        """
        super(CustomEmbedding, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.padding_idx = num_embeddings
        if self.padding_idx is not None:
            self.num_embeddings = num_embeddings + 1
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.weight = Parameter(
            torch.Tensor(self.num_embeddings, embedding_dim))
        self.sparse = sparse
        self.device = device
        self.reset_parameters()

    def reset_parameters(self):
        """
            Reset weights
        """
        self.weight.data.normal_(0, 1)
        if self.padding_idx is not None:
            self.weight.data[self.padding_idx].fill_(0)

    def to(self):
        super().to(self.device)

    def forward(self, batch_data):
        """
            Forward pass for embedding layer
            Arguments
                ----------
                batch_data: dict
                    {'X': torch.LongTensor (BxN),
                     'X_w': torch.FloatTensor (BxN)}
                    'X': Feature indices
                    'X_w': Feature weights
            Returns:
                ----------
                torch.Tensor
                    embedding for each sample B x embedding_dims
        """
        features = batch_data['X'].to(self.device)
        weights = batch_data['X_w'].to(self.device)
        out = F.embedding(features, self.weight, self.padding_idx,
                          self.max_norm, self.norm_type,
                          self.scale_grad_by_freq, self.sparse)
        out = weights.unsqueeze(1).bmm(out).squeeze()
        return out

    def get_weights(self):
        return self.weight.detach().cpu().numpy()

    def __repr__(self):
        s = '{name}({num_embeddings}, {embedding_dim}, {device}'
        if self.padding_idx is not None:
            s += ', padding_idx={padding_idx}'
        if self.max_norm is not None:
            s += ', max_norm={max_norm}'
        if self.norm_type != 2:
            s += ', norm_type={norm_type}'
        if self.scale_grad_by_freq is not False:
            s += ', scale_grad_by_freq={scale_grad_by_freq}'
        if self.sparse is not False:
            s += ', sparse=True'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)

    def _init_(self, state_dict):
        keys = list(state_dict.keys())
        key = [x for x in keys if x.split(".")[-1] in ['weight']][0]
        weight = state_dict[key]
        self.weight.data.copy_(weight)
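A usage sketch with a hypothetical batch (note that index num_embeddings is reserved as the zero <PAD> row):

embedding = CustomEmbedding(num_embeddings=5000, embedding_dim=300, device="cpu")
batch = {'X': torch.randint(0, 5000, (16, 20)),  # feature indices, B x N
         'X_w': torch.rand(16, 20)}              # per-feature weights, B x N
vectors = embedding(batch)                       # (16, 300): weighted sum per sample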
Example #23
    def __init__(
            self,
            size,
            num_actions=11,
            k=None,
            leg_x=2,  # distance between wheel and robot base (x coordinate)
            leg_y=2,  # distance between wheel and robot base (y coordinate)
            num_orientations=16,  # number of discrete orientations
            device=None,
            name=None,
            level_2_features=5,  # number of features for Level-2 representation
            level_3_features=10,  # number of features for Level-3 representation
            level_1_conv_features=[10, 30, 60],
            level_1_conv_kernels=[(5, 5), (3, 3), (3, 3)],
            level_1_conv_paddings=[2, 1, 1],
            level_2_conv_features=[90, 120],
            level_2_conv_kernels=[(5, 5), (3, 3)],
            level_2_conv_paddings=[2, 1],
            level_3_conv_features=[150],
            level_3_conv_kernels=[(3, 3)],
            level_3_conv_paddings=[1]):
        super(Abstraction_VIN_3D, self).__init__()
        self.size = size  # grid world size
        self.size_eff = size // 4  # size of each abstraction map

        self.level_2_features = level_2_features
        self.level_3_features = level_3_features
        self.features = 1 + level_2_features + level_3_features  # overall number of features of reward map (sum over all 3 levels)

        if name is None:
            self.name = 'Abstraction_VIN_3D_' + str(size)
        else:
            self.name = name
        print("Network name: ", self.name)
        self.device = device or torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.k = k or int(
            1.5 *
            k_values[self.size_eff])  # number of iterations within VI module
        self.leg_x = leg_x
        self.leg_y = leg_y
        self.num_orientations = num_orientations
        self.rotation_step_size = 2 * np.pi / num_orientations
        self.num_actions = num_actions

        # precompute orientation dependent local footprints
        self.local_footprints_1, self.local_footprints_2, self.local_footprints_3 = calculate_local_footprints_mulitlayer(
            leg_x, leg_y, num_orientations)
        self.local_footprints_1, self.local_footprints_2, self.local_footprints_3 = self.local_footprints_1.to(
            self.device), self.local_footprints_2.to(
                self.device), self.local_footprints_3.to(self.device)

        # learn abstract representations
        self.learn_level_2 = nn.Conv2d(in_channels=1,
                                       out_channels=level_2_features,
                                       kernel_size=(2, 2),
                                       stride=2,
                                       padding=0,
                                       bias=False)

        self.learn_level_3 = nn.Conv2d(in_channels=level_2_features,
                                       out_channels=level_3_features,
                                       kernel_size=(2, 2),
                                       stride=2,
                                       padding=0,
                                       bias=False)

        # process Level-1
        self.abstraction_1_pad = nn.ConstantPad2d(int(0.25 * self.size_eff), 0)
        self.level_1_conv = nn.ModuleList()

        self.level_1_conv.append(
            nn.Conv2d(in_channels=2,
                      out_channels=level_1_conv_features[0],
                      kernel_size=level_1_conv_kernels[0],
                      stride=1,
                      padding=level_1_conv_paddings[0],
                      bias=True))

        for i in range(1, len(level_1_conv_features)):
            self.level_1_conv.append(
                nn.Conv2d(in_channels=level_1_conv_features[i - 1],
                          out_channels=level_1_conv_features[i],
                          kernel_size=level_1_conv_kernels[i],
                          stride=1,
                          padding=level_1_conv_paddings[i],
                          bias=True))

        # process Level-2
        self.abstraction_2_pad = nn.ConstantPad2d(self.size_eff // 4, 0)
        self.level_2_conv = nn.ModuleList()

        self.level_2_conv.append(
            nn.Conv2d(in_channels=self.num_orientations + level_2_features + 1,
                      out_channels=level_2_conv_features[0],
                      kernel_size=level_2_conv_kernels[0],
                      stride=1,
                      padding=level_2_conv_paddings[0],
                      bias=True))

        for i in range(1, len(level_2_conv_features)):
            self.level_2_conv.append(
                nn.Conv2d(in_channels=level_2_conv_features[i - 1],
                          out_channels=level_2_conv_features[i],
                          kernel_size=level_2_conv_kernels[i],
                          stride=1,
                          padding=level_2_conv_paddings[i],
                          bias=True))

        # process Level-3
        self.level_3_conv = nn.ModuleList()
        self.level_3_conv.append(
            nn.Conv2d(in_channels=self.num_orientations * level_2_features +
                      level_3_features + 1,
                      out_channels=level_3_conv_features[0],
                      kernel_size=level_3_conv_kernels[0],
                      stride=1,
                      padding=level_3_conv_paddings[0],
                      bias=True))

        for i in range(1, len(level_3_conv_features)):
            self.level_3_conv.append(
                nn.Conv2d(in_channels=level_3_conv_features[i - 1],
                          out_channels=level_3_conv_features[i],
                          kernel_size=level_3_conv_kernels[i],
                          stride=1,
                          padding=level_3_conv_paddings[i],
                          bias=True))

        # generate reward map
        self.r1 = nn.Conv2d(in_channels=level_1_conv_features[-1],
                            out_channels=1 * self.num_orientations,
                            kernel_size=(1, 1),
                            stride=1,
                            padding=0,
                            bias=False)
        self.r2 = nn.Conv2d(in_channels=level_2_conv_features[-1],
                            out_channels=level_2_features *
                            self.num_orientations,
                            kernel_size=(1, 1),
                            stride=1,
                            padding=0,
                            bias=False)
        self.r3 = nn.Conv2d(in_channels=level_3_conv_features[-1],
                            out_channels=level_3_features *
                            self.num_orientations,
                            kernel_size=(1, 1),
                            stride=1,
                            padding=0,
                            bias=False)

        # value iteration
        self.q1 = nn.Conv3d(in_channels=1,
                            out_channels=num_actions,
                            kernel_size=(3, 3, 3),
                            stride=1,
                            padding=0,
                            bias=False)
        self.q2 = nn.Conv3d(in_channels=level_2_features,
                            out_channels=num_actions,
                            kernel_size=(3, 3, 3),
                            stride=1,
                            padding=0,
                            bias=False)
        self.q3 = nn.Conv3d(in_channels=level_3_features,
                            out_channels=num_actions,
                            kernel_size=(3, 3, 3),
                            stride=1,
                            padding=0,
                            bias=False)
        self.w = Parameter(torch.zeros(num_actions, 1, 3, 3, 3),
                           requires_grad=True)

        # reactive policy (map state values to action probabilities)
        self.fc = nn.Linear(in_features=11,
                            out_features=num_actions,
                            bias=False)
Example #24
 def __init__(self, m, n, k):
     super(IdentityLayer3D, self).__init__()
     self.weight = Parameter(torch.Tensor(m, n, k))
     torch.nn.init.xavier_normal_(self.weight)
Example #25
    def reset_parameters1(self):
        fm_size = self.filtermap.size()
        fm_width = fm_size[2]
        fm_height = fm_size[1]
        fm_depth = fm_size[0]
        # not for 1x1 conv, do the padding on the spatial 
        if self.filtermap.size()[1] > 1 and self.filtermap.size()[2] > 1:
           self.fm_pad_width = fm_width + 1
           self.fm_pad_height = fm_height + 1
        #for 1x1 conv no padding on the spatial 
        else:
           self.fm_pad_width = fm_width
           self.fm_pad_height = fm_height

        self.fm_pad_depth = fm_depth*2
        #set the ids for extracting filters from filtermap
        out_channels = self.out_channels
        in_channels = self.in_channels // self.groups
        k_h = self.kernel_size[0]
        k_w = self.kernel_size[1]

        sample_y = self.sample_y
        sample_x = self.sample_x
        sample_c = self.sample_c

        stride_y = self.stride_y
        stride_x = self.stride_x
        stride_c = self.stride_c
        
 
        fm_depth = self.fm_pad_depth
        fm_height = self.fm_pad_height
        fm_width = self.fm_pad_width
        
        
        ids = (torch.Tensor(range(0,k_h*k_w)))
        tmp_count = 0
        for y in range(0,k_h):
            for x in range(0,k_w):
                ids[tmp_count] = y*fm_width+x
                tmp_count = tmp_count+1
 
        ids0 = ids

        # replicate the base ids across the input channels
        for c in range(1, in_channels):
            ids_c = ids0 + c*fm_height*fm_width
            ids = torch.cat((ids, ids_c), 0)
        ids0 = ids
        for y in range(0,sample_y):
            for x in range(0,sample_x):
                if y == 0 and x == 0:
                   continue
                ss = y*stride_y*fm_width + x*stride_x
                ids_ss = ids0+ss
                ids = torch.cat((ids,ids_ss),0)
        
        # replicate across the sampled channel offsets
        ids0 = ids
        for c in range(1, sample_c):
            ids_c = ids0 + c*stride_c*fm_height*fm_width
            ids = torch.cat((ids, ids_c), 0)

        # store the gather indices as a frozen (non-trainable) parameter
        ids = ids.long()
        self.ids = Parameter(ids)
        self.ids.requires_grad = False
Example #26
class Code2Vec(nn.Module):
    """the code2vec model"""
    def __init__(self, option):
        super(Code2Vec, self).__init__()
        self.option = option
        self.terminal_embedding = nn.Embedding(option.terminal_count,
                                               option.terminal_embed_size)
        self.path_embedding = nn.Embedding(option.path_count,
                                           option.path_embed_size)
        self.input_linear = nn.Linear(option.terminal_embed_size * 2 +
                                      option.path_embed_size,
                                      option.encode_size,
                                      bias=False)
        self.input_layer_norm = nn.LayerNorm(option.encode_size)

        if 0.0 < option.dropout_prob < 1.0:
            self.input_dropout = nn.Dropout(p=option.dropout_prob)
        else:
            self.input_dropout = None

        self.attention_parameter = Parameter(torch.nn.init.xavier_normal_(
            torch.zeros(option.encode_size,
                        1,
                        dtype=torch.float32,
                        requires_grad=True)).view(-1),
                                             requires_grad=True)

        if option.angular_margin_loss:
            self.output_linear = Parameter(
                torch.FloatTensor(option.label_count, option.encode_size))
            nn.init.xavier_uniform_(self.output_linear)
            self.cos_m = math.cos(option.angular_margin)
            self.sin_m = math.sin(option.angular_margin)
            self.th = math.cos(math.pi - option.angular_margin)
            self.mm = math.sin(math.pi -
                               option.angular_margin) * option.angular_margin
        else:
            self.output_linear = nn.Linear(option.encode_size,
                                           option.label_count,
                                           bias=True)
            self.output_linear.bias.data.fill_(0.0)

    def forward(self, starts, paths, ends, label):
        option = self.option

        # embedding
        embed_starts = self.terminal_embedding(starts)
        embed_paths = self.path_embedding(paths)
        embed_ends = self.terminal_embedding(ends)
        combined_context_vectors = torch.cat(
            (embed_starts, embed_paths, embed_ends), dim=2)

        # FNN, Layer Normalization, tanh
        combined_context_vectors = self.input_linear(combined_context_vectors)
        ccv_size = combined_context_vectors.size()
        combined_context_vectors = self.input_layer_norm(
            combined_context_vectors.view(-1,
                                          option.encode_size)).view(ccv_size)
        combined_context_vectors = torch.tanh(combined_context_vectors)

        # dropout
        if self.input_dropout is not None:
            combined_context_vectors = self.input_dropout(
                combined_context_vectors)

        # attention
        attn_mask = (starts > 0).float()
        attention = self.get_attention(combined_context_vectors, attn_mask)

        # code vector
        expanded_attn = attention.unsqueeze(-1).expand_as(
            combined_context_vectors)
        code_vector = torch.sum(torch.mul(combined_context_vectors,
                                          expanded_attn),
                                dim=1)

        if option.angular_margin_loss:
            # angular margin loss
            cosine = F.linear(F.normalize(code_vector),
                              F.normalize(self.output_linear))
            sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
            phi = cosine * self.cos_m - sine * self.sin_m
            phi = torch.where(cosine > 0, phi, cosine)
            one_hot = torch.zeros(cosine.size(), device=option.device)
            one_hot.scatter_(1, label.view(-1, 1).long(), 1)
            outputs = (one_hot * phi) + ((1.0 - one_hot) * cosine)
            outputs *= option.inverse_temp
        else:
            # FNN
            outputs = self.output_linear(code_vector)

        # if opt.training and opt.dropout_prob < 1.0:
        #     outputs = F.dropout(outputs, p=opt.dropout_prob, training=opt.training)

        return outputs, code_vector, attention

    def get_attention(self, vectors, mask):
        """calculate the attention of the (masked) context vetors. mask=1: meaningful value, mask=0: padded."""
        expanded_attn_param = self.attention_parameter.unsqueeze(0).expand_as(
            vectors)
        attn_ca = torch.mul(torch.sum(vectors * expanded_attn_param, dim=2),
                            mask) + (1 - mask) * NINF
        # attn_ca = torch.sum(vectors * expanded_attn_param, dim=2)
        # attn_ca[mask == 0] = NINF
        attention = F.softmax(attn_ca, dim=1)

        return attention
Example #27
class GNJConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True):

        # Init torch module
        super(GNJConv2d, self).__init__()

        # Init conv params
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = stride
        self.padding = padding
        self.dilation = dilation

        # Init filter latents
        self.weight_mu = Parameter(Tensor(out_channels, in_channels, *self.kernel_size))
        self.weight_logvar = Parameter(Tensor(out_channels, in_channels, *self.kernel_size))


        self.bias = bias
        self.bias_mu = Parameter(Tensor(out_channels)) if self.bias else None
        self.bias_logvar = Parameter(Tensor(out_channels)) if self.bias else None

        # Init prior latents
        self.z_mu = Parameter(Tensor(out_channels))
        self.z_logvar = Parameter(Tensor(out_channels))

        # Set initial parameters
        self._init_params()

        # for brevity to conv2d calls
        self.convargs = [self.stride, self.padding, self.dilation]

        # util activations
        self.sigmoid = Sigmoid()
        self.softplus = Softplus()


    # forward network pass
    def forward(self, x):

        # vanilla forward pass if testing
        if not self.training:
            post_weight_mu = self.weight_mu * self.z_mu[:, None, None, None]
            post_bias_mu = self.bias_mu * self.z_mu if (self.bias_mu is not None) else None
            return conv2d(x, post_weight_mu, post_bias_mu, *self.convargs)

        # unpack mean/std
        mu = self.z_mu
        std = torch.exp(0.5 * self.z_logvar)

        # rsample: sample scale prior with reparam trick
        z = Normal(mu, std).rsample()[None, :, None, None]

        # weights and biases for variance estimation
        weight_v = self.weight_logvar.exp()
        bias_v = self.bias_logvar.exp() if self.bias else None

        # parameterise output distribution
        mu_out = conv2d(x, self.weight_mu, self.bias_mu, *self.convargs) * z
        var_out = conv2d(x**2, weight_v, bias_v, *self.convargs) * (z ** 2)

        # Init out, note multiplicative noise==variational dropout
        dist_out = Normal(mu_out, var_out.sqrt()).rsample()
        #dist_out = self.reparam(mu_out*z, (var_out * z.pow(2)).log())

        return dist_out

    def _init_params(self, weight=None, bias=None):

        n = self.in_channels * self.kernel_size[0] * self.kernel_size[1]
        thresh = 1/math.sqrt(n)

        # weights
        self.weight_logvar.data.normal_(-9, 1e-2)

        if weight is not None:
            self.weight_mu.data = weight
        else:
            self.weight_mu.data.uniform_(-thresh, thresh)


        if self.bias:
            # biases
            self.bias_logvar.data.normal_(-9, 1e-2)

            if bias is not None:
                self.bias_mu.data = bias
            else:
                self.bias_mu.data.fill_(0)

        # priors
        self.z_mu.data.normal_(1, 1e-2)
        self.z_logvar.data.normal_(-9, 1e-2)


    # shape,scale family reparameterization trick (rsample does this?)
    def reparam(self, mu, logvar):
        std = torch.exp(0.5 * logvar)

        # check for cuda
        #tenv = torch.cuda if cuda else torch

        # draw from normal
        eps = torch.FloatTensor(std.size()).normal_()

        return mu + eps * std

    # KL div for GNJ w. Normal approx posterior
    def kl_divergence(self):

        # for brevity in kl_scale
        sg = self.sigmoid
        sp = self.softplus

        # Approximation parameters. Molchanov et al.
        k1, k2, k3 = 0.63576, 1.87320, 1.48695
        log_alpha = self._log_alpha()
        kl_scale = torch.sum(0.5 * sp(-log_alpha) + k1 - k1 * sg(k2  + k3 * log_alpha))
        kl_weight = self._conditional_kl_div(self.weight_mu, self.weight_logvar)
        kl_bias = self._conditional_kl_div(self.bias_mu, self.bias_logvar) if self.bias else 0

        return kl_scale + kl_weight + kl_bias

    @staticmethod
    def _conditional_kl_div(mu, logvar):
        # (8) Weight/bias divergence KL(q(w|z)||p(w|z))
        kl_div = -0.5 * logvar + 0.5 * (logvar.exp() + mu ** 2 - 1)
        return torch.sum(kl_div)

    # effective dropout rate
    def _log_alpha(self):
        epsilon = 1e-8
        log_a = self.z_logvar  - torch.log(self.z_mu ** 2 + epsilon)
        return log_a
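A shape sketch with hypothetical sizes: in training mode the forward pass is stochastic, and kl_divergence() supplies the regularizer for the ELBO:

conv = GNJConv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
conv.train()
y = conv(torch.randn(8, 3, 32, 32))  # stochastic output, (8, 16, 32, 32)
kl = conv.kl_divergence()            # scalar, added to the task loss during training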
Example #28
class VBDSharedWeight(VBDClassification):
    def __init__(self,
                 dim_input,
                 dim_hidden,
                 dim_output,
                 prior_p,
                 thresh=0,
                 ard_init=1,
                 anneal=1.05,
                 anneal_max=100,
                 rw_max=20,
                 neuron=sb_neuron):
        super(VBDSharedWeight,
              self).__init__(dim_input, dim_output, prior_p, thresh, ard_init,
                             anneal, anneal_max, rw_max)

        self.dim_hidden = dim_hidden
        self.W = Parameter(torch.Tensor(dim_input, dim_hidden))
        self.b = Parameter(torch.Tensor(dim_hidden))

        self.V = Parameter(torch.Tensor(dim_hidden, dim_output))
        self.c = Parameter(torch.Tensor(dim_output))

        stdv = 1. / math.sqrt(self.dim_input)

        self.W.data.normal_(0, stdv)
        self.b.data.fill_(0)
        self.V.data.normal_(0, stdv)
        self.c.data.fill_(0)

        self.nonlinearity = F.relu
        self.neuron = neuron

    def forward(self,
                input,
                epoch=1,
                stochastic=False,
                testing=False,
                thresh=None,
                train_clip=False):
        logit_p = self.clip(self.logit_p)
        if thresh is None:
            thresh = self.thresh

        z2_mu = []
        for d in range(self.dim_output):
            mask = self.neuron(logit_p[:, d:(d + 1)].t(),
                               testing=testing,
                               stochastic=stochastic,
                               anneal_slope=self.anneal_policy(epoch))
            if train_clip:
                mask.data[logit_p.data[:, d:(d + 1)] < thresh] = 0

            masked_input = input * mask.expand_as(input)

            z_d = F.linear(masked_input, self.W.t(), self.b)
            h_d = self.nonlinearity(z_d)

            z2_d = torch.mm(h_d, self.V[:, d:(d + 1)])
            z2_mu.append(z2_d)

        mu = torch.cat(z2_mu, 1)
        mu = mu + self.c.expand_as(mu)

        return mu
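
VBDSharedWeight learns one dropout mask per output class, so each class selects its own input features while the hidden weights W and V are shared across classes. A usage sketch, assuming the VBDClassification base class and sb_neuron from the surrounding codebase, with illustrative hyperparameters and a hypothetical x_batch:

model = VBDSharedWeight(dim_input=784, dim_hidden=256, dim_output=10, prior_p=0.5)
logits = model(x_batch, epoch=5, stochastic=True)    # training: sampled masks
logits_eval = model(x_batch, testing=True)           # evaluation: deterministic masks
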
 def __init__(self, channels):
     super(Scale, self).__init__()
     self.weight = Parameter(torch.Tensor(channels))
     self.bias = Parameter(torch.Tensor(channels))
     self.channels = channels
 def __init__(self, num_features):
     super().__init__()
     self.weight = Parameter(torch.Tensor(num_features))
     self.bias = Parameter(torch.Tensor(num_features))
     self.reset_parameters()
 def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
     super(_BatchInstanceNorm, self).__init__(num_features, eps, momentum, affine)
     self.gate = Parameter(torch.Tensor(num_features))
     self.gate.data.fill_(1)
     setattr(self.gate, 'bin_gate', True)
Exemple #32
0
 def __init__(self, p=3, eps=1e-6):
     super(GeM, self).__init__()
     self.p = Parameter(torch.ones(1) * p)
     self.eps = eps
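
The GeM snippet shows only the constructor. The pooling itself raises activations to the learnable power p, averages spatially, and takes the p-th root, interpolating between average pooling (p = 1) and max pooling (p -> infinity). A plausible forward written to match this constructor (a sketch, not necessarily the original author's code):

import torch.nn.functional as F

def gem_forward(self, x):
    # x: (batch, channels, H, W); the clamp keeps pow(p) well-behaved at 0
    return F.avg_pool2d(x.clamp(min=self.eps).pow(self.p),
                        (x.size(-2), x.size(-1))).pow(1.0 / self.p)
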
class _ConvNdGroupNJ(BayesianLayers):
    """Convolutional Group Normal-Jeffrey's layers (aka Group Variational Dropout).

    References:
    [1] Kingma, Diederik P., Tim Salimans, and Max Welling. "Variational dropout and the local reparameterization trick." NIPS (2015).
    [2] Molchanov, Dmitry, Arsenii Ashukha, and Dmitry Vetrov. "Variational Dropout Sparsifies Deep Neural Networks." ICML (2017).
    [3] Louizos, Christos, Karen Ullrich, and Max Welling. "Bayesian Compression for Deep Learning." NIPS (2017).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding,
                 groups, bias, init_weight, init_bias, cuda=False, clip_var=None):
        super(_ConvNdGroupNJ, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups

        self.cuda = cuda
        self.clip_var = clip_var
        self.deterministic = False  # flag is used for compressed inference

        if transposed:
            self.weight_mu = Parameter(torch.Tensor(
                in_channels, out_channels // groups, *kernel_size))
            self.weight_logvar = Parameter(torch.Tensor(
                in_channels, out_channels // groups, *kernel_size))
        else:
            self.weight_mu = Parameter(torch.Tensor(
                out_channels, in_channels // groups, *kernel_size))
            self.weight_logvar = Parameter(torch.Tensor(
                out_channels, in_channels // groups, *kernel_size))

        self.bias_mu = Parameter(torch.Tensor(out_channels))
        self.bias_logvar = Parameter(torch.Tensor(out_channels))

        self.z_mu = Parameter(torch.Tensor(self.out_channels))
        self.z_logvar = Parameter(torch.Tensor(self.out_channels))

        self.reset_parameters(init_weight, init_bias)

        # activations for kl
        self.sigmoid = nn.Sigmoid()
        self.softplus = nn.Softplus()
        # numerical stability param
        self.epsilon = 1e-8

    def reset_parameters(self, init_weight, init_bias):
        # fan-in used for the uniform initialization bound
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)

        # init means
        if init_weight is not None:
            self.weight_mu.data = init_weight
        else:
            self.weight_mu.data.uniform_(-stdv, stdv)

        if init_bias is not None:
            self.bias_mu.data = init_bias
        else:
            self.bias_mu.data.fill_(0)

        # init z
        self.z_mu.data.normal_(1, 1e-2)

        # init logvars
        self.z_logvar.data.normal_(-9, 1e-2)
        self.weight_logvar.data.normal_(-9, 1e-2)
        self.bias_logvar.data.normal_(-9, 1e-2)

    def clip_variances(self):
        if self.clip_var:
            self.weight_logvar.data.clamp_(max=math.log(self.clip_var))
            self.bias_logvar.data.clamp_(max=math.log(self.clip_var))

    def get_log_dropout_rates(self):
        log_alpha = self.z_logvar - torch.log(self.z_mu.pow(2) + self.epsilon)
        return log_alpha
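
These log dropout rates are what drive pruning in Bayesian compression: a group whose log alpha is large is dominated by its multiplicative noise and can be removed. A hedged thresholding sketch (the cutoff of 3 follows the convention in Molchanov et al.; `layer` stands for any instance of this class):

log_alpha = layer.get_log_dropout_rates()
keep = log_alpha < 3.0                     # boolean mask over output channels
n_pruned = (~keep).sum().item()
print(f"pruning {n_pruned} of {keep.numel()} channels")
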

    def compute_posterior_params(self):
        weight_var, z_var = self.weight_logvar.exp(), self.z_logvar.exp()
        # reshape the z statistics to (out_channels, 1, 1, ...) so they broadcast
        # over the weight dimensions instead of silently hitting the last kernel axis
        z_mu = self.z_mu.view(-1, *([1] * (self.weight_mu.dim() - 1)))
        z_var = z_var.view_as(z_mu)
        # Var[z*w] = E[z]^2 Var[w] + Var[z] E[w]^2 + Var[z] Var[w]
        self.post_weight_var = z_mu.pow(2) * weight_var \
            + z_var * self.weight_mu.pow(2) + z_var * weight_var
        self.post_weight_mu = self.weight_mu * z_mu
        return self.post_weight_mu, self.post_weight_var

    def kl_divergence(self):
        # KL(q(z)||p(z))
        # we use the kl divergence approximation given by [2] Eq.(14)
        k1, k2, k3 = 0.63576, 1.87320, 1.48695
        log_alpha = self.get_log_dropout_rates()
        KLD = -torch.sum(k1 * self.sigmoid(k2 + k3 * log_alpha) - 0.5 * self.softplus(-log_alpha) - k1)

        # KL(q(w|z)||p(w|z))
        # we use the kl divergence given by [3] Eq.(8), with logvar = log variance:
        # KL(N(mu, var) || N(0, 1)) = -0.5*logvar + 0.5*(var + mu^2) - 0.5
        KLD_element = -0.5 * self.weight_logvar + 0.5 * (self.weight_logvar.exp() + self.weight_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        # KL bias
        KLD_element = -0.5 * self.bias_logvar + 0.5 * (self.bias_logvar.exp() + self.bias_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        return KLD

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        # bias parameters are always registered by this layer, so there is no bias=False case
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
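
For the compressed-inference path flagged by self.deterministic, the posterior parameters above can replace sampling entirely: convolve once with post_weight_mu (and optionally propagate post_weight_var for predictive uncertainty). A minimal sketch for the non-transposed case; `layer` and `x` are illustrative:

import torch.nn.functional as F

post_mu, post_var = layer.compute_posterior_params()
out = F.conv2d(x, post_mu, layer.bias_mu,
               layer.stride, layer.padding, layer.dilation, layer.groups)
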
Exemple #35
0
class LinearVB(nn.Module):
    """
    Bayesian Linear Layer with parameters:
        - mu: The mean value of the 
        - rho: The sigma of th OR sigma
    
    """
    def __init__(self, in_features, out_features, bias=True, prior = None):
        super(LinearVB, self).__init__()
        self.type_layer = "linear"
        # device= conf_a.device, dtype= conf_a.dtype,
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        self.posterior_mean = False # Flag to know if we sample from the posterior mean or we actually sample
        
        ## If no prior is specified we just create it ourselves
        if (type(prior) == type (None)):
            prior = Vil.Prior(0.5, np.log(0.1),np.log(0.5))
        prior =  prior.get_standarized_Prior(in_features)
        self.prior = prior 
        
        """
        Mean and rhos of the parameters
        """
        self.mu_weight = Parameter(torch.Tensor(out_features, in_features))# , requires_grad=True
        self.rho_weight = Parameter(torch.Tensor(out_features, in_features))
        
        if bias:
            self.rho_bias = Parameter(torch.Tensor(out_features))
            self.mu_bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('rho_bias', None)
            self.register_parameter('mu_bias', None)
            
        """
        The sampled weights
        """
        self.weight = torch.Tensor(out_features, in_features)
        if bias:
            self.bias = torch.Tensor(out_features,1)
        else:
            self.register_parameter('bias', None)
        
        if(0):
            print ("linear bias device: ",self.bias.device)
            print ("linear weights device: ",self.weight.device)
            print ("linear bias mu device: ",self.mu_bias.device)
            print ("linear bias rho device: ",self.rho_bias.device)
            
            print ("linear weights mu  device: ",self.mu_weight.device)
            print ("linear weights rho device: ",self.rho_weight.device)
            
        ## Initialize the Variational variables
        self.reset_parameters()
        self.sample_posterior()
    
    def reset_parameters(self):
        """
        Initialize the variational parameters from the prior.
        The variance of the weights depends on the prior!
        TODO: Should it also depend on dimensionality?
        Should the weight means follow the usual initialization scheme or be drawn from the prior? Can the two differ?
        """
        self.rho_weight.data = Vil.init_rho(self.mu_weight.size(), self.prior)
        if self.bias is not None:
            self.rho_bias.data = Vil.init_rho(self.mu_bias.size(), self.prior)
        
        ## Now initialize the mean
        self.mu_weight.data = Vil.init_mu(self.mu_weight.size(), self.prior,Ninput = self.mu_weight.size(1))
        if self.bias is not None:
            self.mu_bias.data = Vil.init_mu(self.mu_bias.size(), self.prior, Ninput = self.mu_weight.size(1))

    def sample_posterior(self):
        """
        Sample the Bayesian weights from the variational parameters and store them in
        self.weight / self.bias. Sampling uses the reparametrization trick so that we
        can differentiate with respect to sigma and mu.
        """
        if not self.posterior_mean:
            self.weight = Vil.sample_posterior(self.mu_weight, Vil.softplus(self.rho_weight))
            if self.bias is not None:
                self.bias = Vil.sample_posterior(self.mu_bias, Vil.softplus(self.rho_bias))
        else:
            self.weight.data = self.mu_weight.data
            if self.bias is not None:
                self.bias.data = self.mu_bias.data
        
    def get_KL_divergence(self):
        """
        Compute the KL divergence term for this layer's variational weights.
        It does not sample the weights again; it uses the ones already sampled.
        """
        KL_loss_W = Vil.get_KL_divergence_Samples(self.mu_weight, Vil.softplus(self.rho_weight),
                                                  self.weight, self.prior, mu_prior_fluid = self.prior.mu_weight)
        KL_loss_b = 0
        if self.bias is not None:
            KL_loss_b = Vil.get_KL_divergence_Samples(self.mu_bias, Vil.softplus(self.rho_bias), 
                                                      self.bias,  self.prior, mu_prior_fluid = self.prior.mu_bias)
            
        KL_loss = KL_loss_W + KL_loss_b
        
        return KL_loss
    
    def forward(self, X):
        """
        Generate the output; every call builds the dynamic graph anew.
        Forward can differ between training and test:
            - In dropout we do not zero neurons at test time
            - In variational inference we do not randomly sample from the posterior at test time

        The forward pass combines the input X (Nsam_batch, Ndim)
        with the parameters of the model initialized in __init__.
        """

#        o2 = torch.mm(X, self.weight) + self.bias
        o2 = F.linear(X, self.weight, self.bias)
        return o2
    
    """
    Flag to set that we actually get the posterior mean and not a sample from the random variables
    """
    def set_posterior_mean(self, posterior_mean):
        self.posterior_mean = posterior_mean
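
Putting LinearVB together: forward() uses whatever weights were last drawn, so Monte Carlo prediction re-samples explicitly, while set_posterior_mean(True) switches to the deterministic mean weights. A usage sketch (the Vil utilities are assumed importable, as in the class above):

import torch

layer = LinearVB(20, 5)

# Monte Carlo predictive samples
preds = []
for _ in range(10):
    layer.sample_posterior()
    preds.append(layer(torch.randn(3, 20)))

# deterministic prediction from the posterior mean
layer.set_posterior_mean(True)
layer.sample_posterior()      # now copies mu_weight/mu_bias into weight/bias
mean_pred = layer(torch.randn(3, 20))
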
Exemple #36
0
class Network(nn.Module):
    """
    Todo:
    - Beam search
    - check if this is right? attend during P->FC rather than during softmax->P?
    - allow length 0 inputs/targets
    - give n_examples as input to FC
    - Initialise new weights randomly, rather than as zeroes
    """
    def __init__(self,
                 input_vocabulary,
                 target_vocabulary,
                 hidden_size=512,
                 embedding_size=128,
                 cell_type="LSTM"):
        """
        :param list input_vocabulary: list of possible inputs
        :param list target_vocabulary: list of possible targets
        """
        super(Network, self).__init__()
        self.h_input_encoder_size = hidden_size
        self.h_output_encoder_size = hidden_size
        self.h_decoder_size = hidden_size
        self.embedding_size = embedding_size
        self.input_vocabulary = input_vocabulary
        self.target_vocabulary = target_vocabulary
        # Number of tokens in input vocabulary
        self.v_input = len(input_vocabulary)
        # Number of tokens in target vocabulary
        self.v_target = len(target_vocabulary)

        self.cell_type = cell_type
        if cell_type == 'GRU':
            self.input_encoder_cell = nn.GRUCell(
                input_size=self.v_input + 1,
                hidden_size=self.h_input_encoder_size,
                bias=True)
            self.input_encoder_init = Parameter(
                torch.rand(1, self.h_input_encoder_size))
            self.output_encoder_cell = nn.GRUCell(
                input_size=self.v_input + 1 + self.h_input_encoder_size,
                hidden_size=self.h_output_encoder_size,
                bias=True)
            self.decoder_cell = nn.GRUCell(input_size=self.v_target + 1,
                                           hidden_size=self.h_decoder_size,
                                           bias=True)
        if cell_type == 'LSTM':
            self.input_encoder_cell = nn.LSTMCell(
                input_size=self.v_input + 1,
                hidden_size=self.h_input_encoder_size,
                bias=True)
            self.input_encoder_init = nn.ParameterList([
                Parameter(torch.rand(1, self.h_input_encoder_size)),
                Parameter(torch.rand(1, self.h_input_encoder_size))
            ])
            self.output_encoder_cell = nn.LSTMCell(
                input_size=self.v_input + 1 + self.h_input_encoder_size,
                hidden_size=self.h_output_encoder_size,
                bias=True)
            self.output_encoder_init_c = Parameter(
                torch.rand(1, self.h_output_encoder_size))
            self.decoder_cell = nn.LSTMCell(input_size=self.v_target + 1,
                                            hidden_size=self.h_decoder_size,
                                            bias=True)
            self.decoder_init_c = Parameter(torch.rand(1, self.h_decoder_size))

        self.W = nn.Linear(self.h_output_encoder_size + self.h_decoder_size,
                           self.embedding_size)
        self.V = nn.Linear(self.embedding_size, self.v_target + 1)
        self.input_A = nn.Bilinear(self.h_input_encoder_size,
                                   self.h_output_encoder_size,
                                   1,
                                   bias=False)
        self.output_A = nn.Bilinear(self.h_output_encoder_size,
                                    self.h_decoder_size,
                                    1,
                                    bias=False)
        self.input_EOS = torch.zeros(1, self.v_input + 1)
        self.input_EOS[:, -1] = 1
        self.input_EOS = Parameter(self.input_EOS)
        self.output_EOS = torch.zeros(1, self.v_input + 1)
        self.output_EOS[:, -1] = 1
        self.output_EOS = Parameter(self.output_EOS)
        self.target_EOS = torch.zeros(1, self.v_target + 1)
        self.target_EOS[:, -1] = 1
        self.target_EOS = Parameter(self.target_EOS)

    def __getstate__(self):
        if hasattr(self, 'opt'):
            return dict([(k, v)
                         for k, v in self.__dict__.items() if k != 'opt'] +
                        [('optstate', self.opt.state_dict())])
            # return {**{k:v for k,v in self.__dict__.items() if k is not 'opt'},
            #         'optstate': self.opt.state_dict()}
        else:
            return self.__dict__

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Legacy:
        if isinstance(self.input_encoder_init, tuple):
            self.input_encoder_init = nn.ParameterList(
                list(self.input_encoder_init))

    def clear_optimiser(self):
        if hasattr(self, 'opt'):
            del self.opt
        if hasattr(self, 'optstate'):
            del self.optstate

    def get_optimiser(self):
        self.opt = torch.optim.Adam(self.parameters(), lr=0.001)
        if hasattr(self, 'optstate'):
            self.opt.load_state_dict(self.optstate)

    def optimiser_step(self, inputs, outputs, target):
        if not hasattr(self, 'opt'):
            self.get_optimiser()
        score = self.score(inputs, outputs, target, autograd=True).mean()
        (-score).backward()
        self.opt.step()
        self.opt.zero_grad()
        return score.data[0]

    def set_target_vocabulary(self, target_vocabulary):
        if target_vocabulary == self.target_vocabulary:
            return

        V_weight = []
        V_bias = []
        decoder_ih = []

        for i in range(len(target_vocabulary)):
            if target_vocabulary[i] in self.target_vocabulary:
                j = self.target_vocabulary.index(target_vocabulary[i])
                V_weight.append(self.V.weight.data[j:j + 1])
                V_bias.append(self.V.bias.data[j:j + 1])
                decoder_ih.append(self.decoder_cell.weight_ih.data[:, j:j + 1])
            else:
                V_weight.append(torch.zeros(1, self.V.weight.size(1)))
                V_bias.append(torch.ones(1) * -10)
                decoder_ih.append(
                    torch.zeros(self.decoder_cell.weight_ih.data.size(0), 1))

        V_weight.append(self.V.weight.data[-1:])
        V_bias.append(self.V.bias.data[-1:])
        decoder_ih.append(self.decoder_cell.weight_ih.data[:, -1:])

        self.target_vocabulary = target_vocabulary
        self.v_target = len(target_vocabulary)
        self.target_EOS.data = torch.zeros(1, self.v_target + 1)
        self.target_EOS.data[:, -1] = 1

        self.V.weight.data = torch.cat(V_weight, dim=0)
        self.V.bias.data = torch.cat(V_bias, dim=0)
        self.V.out_features = self.V.bias.data.size(0)

        self.decoder_cell.weight_ih.data = torch.cat(decoder_ih, dim=1)
        self.decoder_cell.input_size = self.decoder_cell.weight_ih.data.size(1)

        self.clear_optimiser()

    def input_encoder_get_init(self, batch_size):
        if self.cell_type == "GRU":
            return self.input_encoder_init.repeat(batch_size, 1)
        if self.cell_type == "LSTM":
            return tuple(
                x.repeat(batch_size, 1) for x in self.input_encoder_init)

    def output_encoder_get_init(self, input_encoder_h):
        if self.cell_type == "GRU":
            return input_encoder_h
        if self.cell_type == "LSTM":
            return (input_encoder_h,
                    self.output_encoder_init_c.repeat(input_encoder_h.size(0),
                                                      1))

    def decoder_get_init(self, output_encoder_h):
        if self.cell_type == "GRU":
            return output_encoder_h
        if self.cell_type == "LSTM":
            return (output_encoder_h,
                    self.decoder_init_c.repeat(output_encoder_h.size(0), 1))

    def cell_get_h(self, cell_state):
        if self.cell_type == "GRU":
            return cell_state
        if self.cell_type == "LSTM":
            return cell_state[0]

    def score(self, inputs, outputs, target, autograd=False):
        inputs = self.inputsToTensors(inputs)
        outputs = self.inputsToTensors(outputs)
        target = self.targetToTensor(target)
        target, score = self.run(inputs, outputs, target=target, mode="score")
        # target = self.tensorToOutput(target)
        if autograd:
            return score
        else:
            return score.data

    def sample(self, inputs, outputs):
        inputs = self.inputsToTensors(inputs)
        outputs = self.inputsToTensors(outputs)
        target, score = self.run(inputs, outputs, mode="sample")
        target = self.tensorToOutput(target)
        return target

    def sampleAndScore(self, inputs, outputs, nRepeats=None):
        inputs = self.inputsToTensors(inputs)
        outputs = self.inputsToTensors(outputs)
        if nRepeats is None:
            target, score = self.run(inputs, outputs, mode="sample")
            target = self.tensorToOutput(target)
            return target, score.data
        else:
            target = []
            score = []
            for i in range(nRepeats):
                # print("repeat %d" % i)
                t, s = self.run(inputs, outputs, mode="sample")
                t = self.tensorToOutput(t)
                target.extend(t)
                score.extend(list(s.data))
            return target, score

    def run(self, inputs, outputs, target=None, mode="sample"):
        """
        :param mode: "score" returns log p(target|input), "sample" returns target ~ p(-|input)
        :param List[LongTensor] inputs: n_examples * (max_length_input * batch_size)
        :param List[LongTensor] target: max_length_target * batch_size
        """
        assert ((mode == "score" and target is not None) or mode == "sample")

        n_examples = len(inputs)
        max_length_input = [inputs[j].size(0) for j in range(n_examples)]
        max_length_output = [outputs[j].size(0) for j in range(n_examples)]
        max_length_target = target.size(0) if target is not None else 10
        batch_size = inputs[0].size(1)

        score = Variable(torch.zeros(batch_size))
        inputs_scatter = [
            Variable(
                torch.zeros(max_length_input[j], batch_size,
                            self.v_input + 1).scatter_(2, inputs[j][:, :,
                                                                    None], 1))
            for j in range(n_examples)
        ]  # n_examples * (max_length_input * batch_size * v_input+1)
        outputs_scatter = [
            Variable(
                torch.zeros(max_length_output[j], batch_size,
                            self.v_input + 1).scatter_(2, outputs[j][:, :,
                                                                     None], 1))
            for j in range(n_examples)
        ]  # n_examples * (max_length_output * batch_size * v_input+1)
        if target is not None:
            target_scatter = Variable(
                torch.zeros(
                    max_length_target, batch_size, self.v_target + 1).scatter_(
                        2, target[:, :, None],
                        1))  # max_length_target * batch_size * v_target+1

        # -------------- Input Encoder -------------

        # n_examples * (max_length_input * batch_size * h_encoder_size)
        input_H = []
        input_embeddings = []  # h for example at INPUT_EOS
        # 0 until (and including) INPUT_EOS, then -inf
        input_attention_mask = []
        for j in range(n_examples):
            active = torch.Tensor(max_length_input[j], batch_size).byte()
            active[0, :] = 1
            state = self.input_encoder_get_init(batch_size)
            hs = []
            for i in range(max_length_input[j]):
                state = self.input_encoder_cell(inputs_scatter[j][i, :, :],
                                                state)
                if i + 1 < max_length_input[j]:
                    active[i + 1, :] = active[i, :] * \
                        (inputs[j][i, :] != self.v_input)
                h = self.cell_get_h(state)
                hs.append(h[None, :, :])
            input_H.append(torch.cat(hs, 0))
            embedding_idx = active.sum(0).long() - 1
            embedding = input_H[j].gather(
                0,
                Variable(embedding_idx[None, :, None].repeat(
                    1, 1, self.h_input_encoder_size)))[0]
            input_embeddings.append(embedding)
            input_attention_mask.append(Variable(active.float().log()))

        # -------------- Output Encoder -------------

        def input_attend(j, h_out):
            """
            'general' attention from https://arxiv.org/pdf/1508.04025.pdf
            :param j: Index of example
            :param h_out: batch_size * h_output_encoder_size
            """
            scores = self.input_A(
                input_H[j].view(max_length_input[j] * batch_size,
                                self.h_input_encoder_size),
                h_out.view(batch_size, self.h_output_encoder_size).repeat(
                    max_length_input[j],
                    1)).view(max_length_input[j],
                             batch_size) + input_attention_mask[j]
            c = (F.softmax(scores[:, :, None], dim=0) * input_H[j]).sum(0)
            return c

        # n_examples * (max_length_input * batch_size * h_encoder_size)
        output_H = []
        output_embeddings = []  # h for example at INPUT_EOS
        # 0 until (and including) INPUT_EOS, then -inf
        output_attention_mask = []
        for j in range(n_examples):
            active = torch.Tensor(max_length_output[j], batch_size).byte()
            active[0, :] = 1
            state = self.output_encoder_get_init(input_embeddings[j])
            hs = []
            h = self.cell_get_h(state)
            for i in range(max_length_output[j]):
                state = self.output_encoder_cell(
                    torch.cat(
                        [outputs_scatter[j][i, :, :],
                         input_attend(j, h)], 1), state)
                if i + 1 < max_length_output[j]:
                    active[i + 1, :] = active[i, :] * \
                        (outputs[j][i, :] != self.v_input)
                h = self.cell_get_h(state)
                hs.append(h[None, :, :])
            output_H.append(torch.cat(hs, 0))
            embedding_idx = active.sum(0).long() - 1
            embedding = output_H[j].gather(
                0,
                Variable(embedding_idx[None, :, None].repeat(
                    1, 1, self.h_output_encoder_size)))[0]
            output_embeddings.append(embedding)
            output_attention_mask.append(Variable(active.float().log()))

        # ------------------ Decoder -----------------

        def output_attend(j, h_dec):
            """
            'general' attention from https://arxiv.org/pdf/1508.04025.pdf
            :param j: Index of example
            :param h_dec: batch_size * h_decoder_size
            """
            scores = self.output_A(
                output_H[j].view(max_length_output[j] * batch_size,
                                 self.h_output_encoder_size),
                h_dec.view(batch_size, self.h_decoder_size).repeat(
                    max_length_output[j],
                    1)).view(max_length_output[j],
                             batch_size) + output_attention_mask[j]
            c = (F.softmax(scores[:, :, None], dim=0) * output_H[j]).sum(0)
            return c

        # Multi-example pooling: Figure 3, https://arxiv.org/pdf/1703.07469.pdf
        target = target if mode == "score" else torch.zeros(
            max_length_target, batch_size).long()
        decoder_states = [
            self.decoder_get_init(output_embeddings[j])
            for j in range(n_examples)
        ]  # P
        active = torch.ones(batch_size).byte()
        for i in range(max_length_target):
            FC = []
            for j in range(n_examples):
                h = self.cell_get_h(decoder_states[j])
                p_aug = torch.cat([h, output_attend(j, h)], 1)
                FC.append(F.tanh(self.W(p_aug)[None, :, :]))
            # batch_size * embedding_size
            m = torch.max(torch.cat(FC, 0), 0)[0]
            logsoftmax = F.log_softmax(self.V(m), dim=1)
            if mode == "sample":
                target[i, :] = torch.multinomial(logsoftmax.data.exp(), 1)[:,
                                                                           0]
            score = score + \
                choose(logsoftmax, target[i, :]) * Variable(active.float())
            active *= (target[i, :] != self.v_target)
            for j in range(n_examples):
                if mode == "score":
                    target_char_scatter = target_scatter[i, :, :]
                elif mode == "sample":
                    target_char_scatter = Variable(
                        torch.zeros(batch_size, self.v_target + 1).scatter_(
                            1, target[i, :, None], 1))
                decoder_states[j] = self.decoder_cell(target_char_scatter,
                                                      decoder_states[j])
        return target, score

    def inputsToTensors(self, inputss):
        """
        :param inputss: size = nBatch * nExamples
        """
        tensors = []
        for j in range(len(inputss[0])):
            inputs = [x[j] for x in inputss]
            maxlen = max(len(s) for s in inputs)
            t = torch.ones(1 if maxlen == 0 else maxlen + 1,
                           len(inputs)).long() * self.v_input
            for i in range(len(inputs)):
                s = inputs[i]
                if len(s) > 0:
                    t[:len(s), i] = torch.LongTensor(
                        [self.input_vocabulary.index(x) for x in s])
            tensors.append(t)
        return tensors

    def targetToTensor(self, targets):
        """
        :param targets:
        """
        maxlen = max(len(s) for s in targets)
        t = torch.ones(1 if maxlen == 0 else maxlen + 1,
                       len(targets)).long() * self.v_target
        for i in range(len(targets)):
            s = targets[i]
            if len(s) > 0:
                t[:len(s), i] = torch.LongTensor(
                    [self.target_vocabulary.index(x) for x in s])
        return t

    def tensorToOutput(self, tensor):
        """
        :param tensor: max_length * batch_size
        """
        out = []
        for i in range(tensor.size(1)):
            l = tensor[:, i].tolist()
            if l[0] == self.v_target:
                out.append([])
            elif self.v_target in l:
                final = tensor[:, i].tolist().index(self.v_target)
                out.append(
                    [self.target_vocabulary[x] for x in tensor[:final, i]])
            else:
                out.append([self.target_vocabulary[x] for x in tensor[:, i]])
        return out
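
Both attention helpers in run() implement the 'general' scoring of Luong et al. (https://arxiv.org/pdf/1508.04025.pdf), score(h_t, h_s) = h_t^T W h_s, realized above with nn.Bilinear producing one scalar per (source position, batch) pair. A compact batched equivalent for reference (a sketch, not this class's exact code path):

import torch

def general_attention(H_src, h_tgt, W):
    # H_src: (L, B, d_src), h_tgt: (B, d_tgt), W: (d_src, d_tgt)
    scores = torch.einsum('lbs,st,bt->lb', H_src, W, h_tgt)   # (L, B)
    weights = torch.softmax(scores, dim=0).unsqueeze(-1)      # softmax over L
    return (weights * H_src).sum(0)                           # context: (B, d_src)
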
class SequenceTagger(flair.nn.Model):
    def __init__(
        self,
        hidden_size: int,
        embeddings: TokenEmbeddings,
        tag_dictionary: Dictionary,
        tag_type: str,
        use_crf: bool = True,
        use_rnn: bool = True,
        rnn_layers: int = 1,
        dropout: float = 0.0,
        word_dropout: float = 0.05,
        locked_dropout: float = 0.5,
        reproject_to: int = None,
        train_initial_hidden_state: bool = False,
        rnn_type: str = "LSTM",
        pickle_module: str = "pickle",
        beta: float = 1.0,
        loss_weights: Dict[str, float] = None,
    ):
        """
        Initializes a SequenceTagger
        :param hidden_size: number of hidden states in RNN
        :param embeddings: word embeddings used in tagger
        :param tag_dictionary: dictionary of tags you want to predict
        :param tag_type: string identifier for tag type
        :param use_crf: if True use CRF decoder, else project directly to tag space
        :param use_rnn: if True use RNN layer, otherwise use word embeddings directly
        :param rnn_layers: number of RNN layers
        :param dropout: dropout probability
        :param word_dropout: word dropout probability
        :param reproject_to: set this to control the dimensionality of the reprojection layer
        :param locked_dropout: locked dropout probability
        :param train_initial_hidden_state: if True, trains initial hidden state of RNN
        :param beta: Parameter for F-beta score for evaluation and training annealing
        :param loss_weights: Dictionary of weights for classes (tags) for the loss function
        (if any tag's weight is unspecified it will default to 1.0)

        """

        super(SequenceTagger, self).__init__()
        self.use_rnn = use_rnn
        self.hidden_size = hidden_size
        self.use_crf: bool = use_crf
        self.rnn_layers: int = rnn_layers

        self.trained_epochs: int = 0

        self.embeddings = embeddings

        # set the dictionaries
        self.tag_dictionary: Dictionary = tag_dictionary
        # if we use a CRF, we must add special START and STOP tags to the dictionary
        if use_crf:
            self.tag_dictionary.add_item(START_TAG)
            self.tag_dictionary.add_item(STOP_TAG)

        self.tag_type: str = tag_type
        self.tagset_size: int = len(tag_dictionary)

        self.beta = beta

        self.weight_dict = loss_weights
        # Initialize the weight tensor
        if loss_weights is not None:
            n_classes = len(self.tag_dictionary)
            weight_list = [1. for i in range(n_classes)]
            for i, tag in enumerate(self.tag_dictionary.get_items()):
                if tag in loss_weights.keys():
                    weight_list[i] = loss_weights[tag]
            self.loss_weights = torch.FloatTensor(weight_list).to(flair.device)
        else:
            self.loss_weights = None

        # initialize the network architecture
        self.nlayers: int = rnn_layers
        self.hidden_word = None

        # dropouts
        self.use_dropout: float = dropout
        self.use_word_dropout: float = word_dropout
        self.use_locked_dropout: float = locked_dropout

        self.pickle_module = pickle_module

        if dropout > 0.0:
            self.dropout = torch.nn.Dropout(dropout)

        if word_dropout > 0.0:
            self.word_dropout = flair.nn.WordDropout(word_dropout)

        if locked_dropout > 0.0:
            self.locked_dropout = flair.nn.LockedDropout(locked_dropout)

        embedding_dim: int = self.embeddings.embedding_length

        # if no dimensionality for reprojection layer is set, reproject to equal dimension
        self.reproject_to = reproject_to
        if self.reproject_to is None: self.reproject_to = embedding_dim
        rnn_input_dim: int = self.reproject_to

        self.relearn_embeddings: bool = True
        if self.relearn_embeddings:
            self.embedding2nn = torch.nn.Linear(embedding_dim, rnn_input_dim)

        self.train_initial_hidden_state = train_initial_hidden_state
        self.bidirectional = True
        self.rnn_type = rnn_type

        # bidirectional LSTM on top of embedding layer
        if self.use_rnn:
            num_directions = 2 if self.bidirectional else 1

            if self.rnn_type in ["LSTM", "GRU"]:

                self.rnn = getattr(torch.nn, self.rnn_type)(
                    rnn_input_dim,
                    hidden_size,
                    num_layers=self.nlayers,
                    dropout=0.0 if self.nlayers == 1 else 0.5,
                    bidirectional=True,
                    batch_first=True,
                )
                # Create initial hidden state and initialize it
                if self.train_initial_hidden_state:
                    self.hs_initializer = torch.nn.init.xavier_normal_

                    self.lstm_init_h = Parameter(
                        torch.randn(self.nlayers * num_directions,
                                    self.hidden_size),
                        requires_grad=True,
                    )

                    self.lstm_init_c = Parameter(
                        torch.randn(self.nlayers * num_directions,
                                    self.hidden_size),
                        requires_grad=True,
                    )

                    # TODO: Decide how to initialize the hidden state variables
                    # self.hs_initializer(self.lstm_init_h)
                    # self.hs_initializer(self.lstm_init_c)

            # final linear map to tag space
            self.linear = torch.nn.Linear(hidden_size * num_directions,
                                          len(tag_dictionary))
        else:
            self.linear = torch.nn.Linear(self.embeddings.embedding_length,
                                          len(tag_dictionary))

        if self.use_crf:
            self.transitions = torch.nn.Parameter(
                torch.randn(self.tagset_size, self.tagset_size))

            self.transitions.detach()[
                self.tag_dictionary.get_idx_for_item(START_TAG), :] = -10000

            self.transitions.detach()[:,
                                      self.tag_dictionary.
                                      get_idx_for_item(STOP_TAG)] = -10000

        self.to(flair.device)

    def _get_state_dict(self):
        model_state = {
            "state_dict": self.state_dict(),
            "embeddings": self.embeddings,
            "hidden_size": self.hidden_size,
            "train_initial_hidden_state": self.train_initial_hidden_state,
            "tag_dictionary": self.tag_dictionary,
            "tag_type": self.tag_type,
            "use_crf": self.use_crf,
            "use_rnn": self.use_rnn,
            "rnn_layers": self.rnn_layers,
            "use_word_dropout": self.use_word_dropout,
            "use_locked_dropout": self.use_locked_dropout,
            "rnn_type": self.rnn_type,
            "beta": self.beta,
            "weight_dict": self.weight_dict,
            "reproject_to": self.reproject_to,
        }
        return model_state

    @staticmethod
    def _init_model_with_state_dict(state):

        rnn_type = "LSTM" if "rnn_type" not in state.keys(
        ) else state["rnn_type"]
        use_dropout = 0.0 if "use_dropout" not in state.keys(
        ) else state["use_dropout"]
        use_word_dropout = (0.0 if "use_word_dropout" not in state.keys() else
                            state["use_word_dropout"])
        use_locked_dropout = (0.0 if "use_locked_dropout" not in state.keys()
                              else state["use_locked_dropout"])
        train_initial_hidden_state = (False if "train_initial_hidden_state"
                                      not in state.keys() else
                                      state["train_initial_hidden_state"])
        beta = 1.0 if "beta" not in state.keys() else state["beta"]
        weights = None if "weight_dict" not in state.keys(
        ) else state["weight_dict"]
        reproject_to = None if "reproject_to" not in state.keys(
        ) else state["reproject_to"]

        model = SequenceTagger(
            hidden_size=state["hidden_size"],
            embeddings=state["embeddings"],
            tag_dictionary=state["tag_dictionary"],
            tag_type=state["tag_type"],
            use_crf=state["use_crf"],
            use_rnn=state["use_rnn"],
            rnn_layers=state["rnn_layers"],
            dropout=use_dropout,
            word_dropout=use_word_dropout,
            locked_dropout=use_locked_dropout,
            train_initial_hidden_state=train_initial_hidden_state,
            rnn_type=rnn_type,
            beta=beta,
            loss_weights=weights,
            reproject_to=reproject_to,
        )
        model.load_state_dict(state["state_dict"])
        return model

    def predict(
        self,
        sentences: Union[List[Sentence], Sentence, List[str], str],
        mini_batch_size=32,
        embedding_storage_mode="none",
        all_tag_prob: bool = False,
        verbose: bool = False,
        use_tokenizer: Union[bool, Callable[[str],
                                            List[Token]]] = space_tokenizer,
    ) -> List[Sentence]:
        """
        Predict sequence tags for Named Entity Recognition task
        :param sentences: a Sentence or a string, or a list of Sentences or strings.
        :param mini_batch_size: size of the minibatch; larger is usually faster but consumes more memory,
        up to a point where it has no further effect.
        :param embedding_storage_mode: 'none' for the minimum memory footprint, 'cpu' to store embeddings in RAM,
        'gpu' to store embeddings in GPU memory.
        :param all_tag_prob: True to compute the score for each tag on each token,
        otherwise only the score of the best tag is returned
        :param verbose: set to True to display a progress bar
        :param use_tokenizer: a custom tokenizer when strings are provided (default is a space-based tokenizer).
        :return: List of Sentences enriched by the predicted tags
        """
        with torch.no_grad():
            if not sentences:
                return sentences

            if isinstance(sentences, Sentence) or isinstance(sentences, str):
                sentences = [sentences]

            if (flair.device.type
                    == "cuda") and embedding_storage_mode == "cpu":
                log.warning(
                    "You are inferring on GPU with parameter 'embedding_storage_mode' set to 'cpu'. "
                    "This option will slow down your inference, usually 'none' (default value) "
                    "is a better choice.")

            # reverse sort all sequences by their length
            rev_order_len_index = sorted(range(len(sentences)),
                                         key=lambda k: len(sentences[k]),
                                         reverse=True)
            original_order_index = sorted(range(len(rev_order_len_index)),
                                          key=lambda k: rev_order_len_index[k])

            reordered_sentences: List[Union[Sentence, str]] = [
                sentences[index] for index in rev_order_len_index
            ]

            if isinstance(sentences[0], Sentence):
                # remove previous embeddings
                store_embeddings(reordered_sentences, "none")
                dataset = SentenceDataset(reordered_sentences)
            else:
                dataset = StringDataset(reordered_sentences,
                                        use_tokenizer=use_tokenizer)
            dataloader = DataLoader(dataset=dataset,
                                    batch_size=mini_batch_size,
                                    collate_fn=lambda x: x)

            if self.use_crf:
                transitions = self.transitions.detach().cpu().numpy()
            else:
                transitions = None

            # progress bar for verbosity
            if verbose:
                dataloader = tqdm(dataloader)

            results: List[Sentence] = []
            for i, batch in enumerate(dataloader):

                if verbose:
                    dataloader.set_description(f"Inferencing on batch {i}")
                results += batch
                batch = self._filter_empty_sentences(batch)
                # stop if all sentences are empty
                if not batch:
                    continue

                feature: torch.Tensor = self.forward(batch)
                tags, all_tags = self._obtain_labels(
                    feature=feature,
                    batch_sentences=batch,
                    transitions=transitions,
                    get_all_tags=all_tag_prob,
                )

                for (sentence, sent_tags) in zip(batch, tags):
                    for (token, tag) in zip(sentence.tokens, sent_tags):
                        token.add_tag_label(self.tag_type, tag)

                # all_tags will be empty if all_tag_prob is set to False, so the for loop will be avoided
                for (sentence, sent_all_tags) in zip(batch, all_tags):
                    for (token, token_all_tags) in zip(sentence.tokens,
                                                       sent_all_tags):
                        token.add_tags_proba_dist(self.tag_type,
                                                  token_all_tags)

                # clearing token embeddings to save memory
                store_embeddings(batch, storage_mode=embedding_storage_mode)

            results: List[Union[Sentence, str]] = [
                results[index] for index in original_order_index
            ]
            assert len(sentences) == len(results)
            return results

    def evaluate(
        self,
        data_loader: DataLoader,
        out_path: Path = None,
        embedding_storage_mode: str = "none",
    ) -> (Result, float):

        if type(out_path) == str:
            out_path = Path(out_path)

        with torch.no_grad():
            eval_loss = 0

            batch_no: int = 0

            metric = Metric("Evaluation", beta=self.beta)

            lines: List[str] = []

            if self.use_crf:
                transitions = self.transitions.detach().cpu().numpy()
            else:
                transitions = None

            for batch in data_loader:
                batch_no += 1

                with torch.no_grad():
                    features = self.forward(batch)
                    loss = self._calculate_loss(features, batch)
                    tags, _ = self._obtain_labels(
                        feature=features,
                        batch_sentences=batch,
                        transitions=transitions,
                        get_all_tags=False,
                    )

                eval_loss += loss

                for (sentence, sent_tags) in zip(batch, tags):
                    for (token, tag) in zip(sentence.tokens, sent_tags):
                        token: Token = token
                        token.add_tag("predicted", tag.value, tag.score)

                        # append both to file for evaluation
                        eval_line = "{} {} {} {}\n".format(
                            token.text,
                            token.get_tag(self.tag_type).value,
                            tag.value,
                            tag.score,
                        )
                        lines.append(eval_line)
                    lines.append("\n")

                for sentence in batch:
                    # make list of gold tags
                    gold_tags = [(tag.tag, tag.text)
                                 for tag in sentence.get_spans(self.tag_type)]
                    # make list of predicted tags
                    predicted_tags = [
                        (tag.tag, tag.text)
                        for tag in sentence.get_spans("predicted")
                    ]

                    # check for true positives, false positives and false negatives
                    for tag, prediction in predicted_tags:
                        if (tag, prediction) in gold_tags:
                            metric.add_tp(tag)
                        else:
                            metric.add_fp(tag)

                    for tag, gold in gold_tags:
                        if (tag, gold) not in predicted_tags:
                            metric.add_fn(tag)
                        else:
                            metric.add_tn(tag)

                store_embeddings(batch, embedding_storage_mode)

            eval_loss /= batch_no

            if out_path is not None:
                with open(out_path, "w", encoding="utf-8") as outfile:
                    outfile.write("".join(lines))

            detailed_result = (
                f"\nMICRO_AVG: acc {metric.micro_avg_accuracy():.4f} - f1-score {metric.micro_avg_f_score():.4f}"
                f"\nMACRO_AVG: acc {metric.macro_avg_accuracy():.4f} - f1-score {metric.macro_avg_f_score():.4f}"
            )
            for class_name in metric.get_classes():
                detailed_result += (
                    f"\n{class_name:<10} tp: {metric.get_tp(class_name)} - fp: {metric.get_fp(class_name)} - "
                    f"fn: {metric.get_fn(class_name)} - tn: {metric.get_tn(class_name)} - precision: "
                    f"{metric.precision(class_name):.4f} - recall: {metric.recall(class_name):.4f} - "
                    f"accuracy: {metric.accuracy(class_name):.4f} - f1-score: "
                    f"{metric.f_score(class_name):.4f}")

            result = Result(
                main_score=metric.micro_avg_f_score(),
                log_line=
                f"{metric.precision():.4f}\t{metric.recall():.4f}\t{metric.micro_avg_f_score():.4f}",
                log_header="PRECISION\tRECALL\tF1",
                detailed_results=detailed_result,
            )

            return result, eval_loss

    def forward_loss(self,
                     data_points: Union[List[Sentence], Sentence],
                     sort=True) -> torch.Tensor:
        features = self.forward(data_points)
        return self._calculate_loss(features, data_points)

    def forward(self, sentences: List[Sentence]):

        self.embeddings.embed(sentences)

        lengths: List[int] = [len(sentence.tokens) for sentence in sentences]
        longest_token_sequence_in_batch: int = max(lengths)

        pre_allocated_zero_tensor = torch.zeros(
            self.embeddings.embedding_length * longest_token_sequence_in_batch,
            dtype=torch.float,
            device=flair.device,
        )

        all_embs = list()
        for sentence in sentences:
            all_embs += [
                emb for token in sentence
                for emb in token.get_each_embedding()
            ]
            nb_padding_tokens = longest_token_sequence_in_batch - len(sentence)

            if nb_padding_tokens > 0:
                t = pre_allocated_zero_tensor[:self.embeddings.
                                              embedding_length *
                                              nb_padding_tokens]
                all_embs.append(t)

        sentence_tensor = torch.cat(all_embs).view([
            len(sentences),
            longest_token_sequence_in_batch,
            self.embeddings.embedding_length,
        ])

        # --------------------------------------------------------------------
        # FF PART
        # --------------------------------------------------------------------
        if self.use_dropout > 0.0:
            sentence_tensor = self.dropout(sentence_tensor)
        if self.use_word_dropout > 0.0:
            sentence_tensor = self.word_dropout(sentence_tensor)
        if self.use_locked_dropout > 0.0:
            sentence_tensor = self.locked_dropout(sentence_tensor)

        if self.relearn_embeddings:
            sentence_tensor = self.embedding2nn(sentence_tensor)

        if self.use_rnn:
            packed = torch.nn.utils.rnn.pack_padded_sequence(
                sentence_tensor,
                lengths,
                enforce_sorted=False,
                batch_first=True)

            # if initial hidden state is trainable, use this state
            if self.train_initial_hidden_state:
                initial_hidden_state = [
                    self.lstm_init_h.unsqueeze(1).repeat(1, len(sentences), 1),
                    self.lstm_init_c.unsqueeze(1).repeat(1, len(sentences), 1),
                ]
                rnn_output, hidden = self.rnn(packed, initial_hidden_state)
            else:
                rnn_output, hidden = self.rnn(packed)

            sentence_tensor, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(
                rnn_output, batch_first=True)

            if self.use_dropout > 0.0:
                sentence_tensor = self.dropout(sentence_tensor)
            # word dropout only before LSTM - TODO: more experimentation needed
            # if self.use_word_dropout > 0.0:
            #     sentence_tensor = self.word_dropout(sentence_tensor)
            if self.use_locked_dropout > 0.0:
                sentence_tensor = self.locked_dropout(sentence_tensor)

        features = self.linear(sentence_tensor)

        return features

    def _score_sentence(self, feats, tags, lens_):

        start = torch.tensor([self.tag_dictionary.get_idx_for_item(START_TAG)],
                             device=flair.device)
        start = start[None, :].repeat(tags.shape[0], 1)

        stop = torch.tensor([self.tag_dictionary.get_idx_for_item(STOP_TAG)],
                            device=flair.device)
        stop = stop[None, :].repeat(tags.shape[0], 1)

        pad_start_tags = torch.cat([start, tags], 1)
        pad_stop_tags = torch.cat([tags, stop], 1)

        for i in range(len(lens_)):
            pad_stop_tags[i, lens_[i]:] = self.tag_dictionary.get_idx_for_item(
                STOP_TAG)

        score = torch.FloatTensor(feats.shape[0]).to(flair.device)

        for i in range(feats.shape[0]):
            r = torch.LongTensor(range(lens_[i])).to(flair.device)

            score[i] = torch.sum(self.transitions[
                pad_stop_tags[i, :lens_[i] + 1],
                pad_start_tags[i, :lens_[i] + 1]]) + torch.sum(
                    feats[i, r, tags[i, :lens_[i]]])

        return score

    def _calculate_loss(self, features: torch.Tensor,
                        sentences: List[Sentence]) -> torch.Tensor:

        lengths: List[int] = [len(sentence.tokens) for sentence in sentences]

        tag_list: List = []
        for s_id, sentence in enumerate(sentences):
            # get the tags in this sentence
            tag_idx: List[int] = [
                self.tag_dictionary.get_idx_for_item(
                    token.get_tag(self.tag_type).value) for token in sentence
            ]
            # add tags as tensor
            tag = torch.tensor(tag_idx, device=flair.device)
            tag_list.append(tag)

        if self.use_crf:
            # pad tags if using batch-CRF decoder
            tags, _ = pad_tensors(tag_list)

            forward_score = self._forward_alg(features, lengths)
            gold_score = self._score_sentence(features, tags, lengths)

            score = forward_score - gold_score

            return score.mean()

        else:
            score = 0
            for sentence_feats, sentence_tags, sentence_length in zip(
                    features, tag_list, lengths):
                sentence_feats = sentence_feats[:sentence_length]
                score += torch.nn.functional.cross_entropy(
                    sentence_feats, sentence_tags, weight=self.loss_weights)
            score /= len(features)
            return score

    def _obtain_labels(
        self,
        feature: torch.Tensor,
        batch_sentences: List[Sentence],
        transitions: Optional[np.ndarray],
        get_all_tags: bool,
    ) -> Tuple[List[List[Label]], List[List[List[Label]]]]:
        """
        Returns a tuple of two lists:
         - The first list corresponds to the most likely `Label` per token in each sentence.
         - The second list contains a probability distribution over all `Labels` for each token
           in a sentence for all sentences.
        """

        lengths: List[int] = [
            len(sentence.tokens) for sentence in batch_sentences
        ]

        tags = []
        all_tags = []
        feature = feature.cpu()
        if self.use_crf:
            feature = feature.numpy()
        else:
            for index, length in enumerate(lengths):
                feature[index, length:] = 0
            softmax_batch = F.softmax(feature, dim=2).cpu()
            scores_batch, prediction_batch = torch.max(softmax_batch, dim=2)
            feature = zip(softmax_batch, scores_batch, prediction_batch)

        for feats, length in zip(feature, lengths):
            if self.use_crf:
                confidences, tag_seq, scores = self._viterbi_decode(
                    feats=feats[:length],
                    transitions=transitions,
                    all_scores=get_all_tags,
                )
            else:
                softmax, score, prediction = feats
                confidences = score[:length].tolist()
                tag_seq = prediction[:length].tolist()
                scores = softmax[:length].tolist()

            tags.append([
                Label(self.tag_dictionary.get_item_for_index(tag), conf)
                for conf, tag in zip(confidences, tag_seq)
            ])

            if get_all_tags:
                all_tags.append([[
                    Label(self.tag_dictionary.get_item_for_index(score_id),
                          score) for score_id, score in enumerate(score_dist)
                ] for score_dist in scores])

        return tags, all_tags

    @staticmethod
    def _softmax(x, axis):
        # reduce raw values to avoid NaN during exp
        x_norm = x - x.max(axis=axis, keepdims=True)
        y = np.exp(x_norm)
        return y / y.sum(axis=axis, keepdims=True)

    def _viterbi_decode(self, feats: np.ndarray, transitions: np.ndarray,
                        all_scores: bool):
        id_start = self.tag_dictionary.get_idx_for_item(START_TAG)
        id_stop = self.tag_dictionary.get_idx_for_item(STOP_TAG)

        backpointers = np.empty(shape=(feats.shape[0], self.tagset_size),
                                dtype=np.int_)
        backscores = np.empty(shape=(feats.shape[0], self.tagset_size),
                              dtype=np.float32)

        init_vvars = np.expand_dims(np.repeat(-10000.0, self.tagset_size),
                                    axis=0).astype(np.float32)
        init_vvars[0][id_start] = 0

        forward_var = init_vvars
        for index, feat in enumerate(feats):
            # broadcasting will do the job of reshaping and is more efficient than calling repeat
            next_tag_var = forward_var + transitions
            bptrs_t = next_tag_var.argmax(axis=1)
            viterbivars_t = next_tag_var[np.arange(bptrs_t.shape[0]), bptrs_t]
            forward_var = viterbivars_t + feat
            backscores[index] = forward_var
            forward_var = forward_var[np.newaxis, :]
            backpointers[index] = bptrs_t

        terminal_var = forward_var.squeeze() + transitions[id_stop]
        terminal_var[id_stop] = -10000.0
        terminal_var[id_start] = -10000.0
        best_tag_id = terminal_var.argmax()

        best_path = [best_tag_id]
        for bptrs_t in reversed(backpointers):
            best_tag_id = bptrs_t[best_tag_id]
            best_path.append(best_tag_id)

        start = best_path.pop()
        assert start == id_start
        best_path.reverse()

        best_scores_softmax = self._softmax(backscores, axis=1)
        best_scores_np = np.max(best_scores_softmax, axis=1)

        # default value
        all_scores_np = np.zeros(0, dtype=np.float64)
        if all_scores:
            all_scores_np = best_scores_softmax
            for index, (tag_id, tag_scores) in enumerate(
                    zip(best_path, all_scores_np)):
                # best_path entries may be plain ints or 0-dim numpy values
                tag_idx = tag_id if isinstance(tag_id, int) else tag_id.item()
                swap_index_score = tag_scores.argmax()
                if tag_idx != swap_index_score:
                    # swap the Viterbi tag's score into the argmax slot so the
                    # reported distribution is consistent with the decoded path
                    (
                        all_scores_np[index][tag_idx],
                        all_scores_np[index][swap_index_score],
                    ) = (
                        all_scores_np[index][swap_index_score],
                        all_scores_np[index][tag_idx],
                    )

        return best_scores_np.tolist(), best_path, all_scores_np.tolist()

    def _forward_alg(self, feats, lens_):

        init_alphas = torch.FloatTensor(self.tagset_size).fill_(-10000.0)
        init_alphas[self.tag_dictionary.get_idx_for_item(START_TAG)] = 0.0

        forward_var = torch.zeros(
            feats.shape[0],
            feats.shape[1] + 1,
            feats.shape[2],
            dtype=torch.float,
            device=flair.device,
        )

        forward_var[:, 0, :] = init_alphas[None, :].repeat(feats.shape[0], 1)

        transitions = self.transitions.view(1, self.transitions.shape[0],
                                            self.transitions.shape[1]).repeat(
                                                feats.shape[0], 1, 1)

        for i in range(feats.shape[1]):
            emit_score = feats[:, i, :]

            tag_var = (
                emit_score[:, :, None].repeat(1, 1, transitions.shape[2]) +
                transitions + forward_var[:, i, :][:, :, None].repeat(
                    1, 1, transitions.shape[2]).transpose(2, 1))

            max_tag_var, _ = torch.max(tag_var, dim=2)

            tag_var = tag_var - max_tag_var[:, :, None].repeat(
                1, 1, transitions.shape[2])

            agg_ = torch.log(torch.sum(torch.exp(tag_var), dim=2))

            cloned = forward_var.clone()
            cloned[:, i + 1, :] = max_tag_var + agg_

            forward_var = cloned

        forward_var = forward_var[range(forward_var.shape[0]), lens_, :]

        terminal_var = forward_var + self.transitions[
            self.tag_dictionary.get_idx_for_item(STOP_TAG)][None, :].repeat(
                forward_var.shape[0], 1)

        alpha = log_sum_exp_batch(terminal_var)

        return alpha

    @staticmethod
    def _filter_empty_sentences(sentences: List[Sentence]) -> List[Sentence]:
        filtered_sentences = [
            sentence for sentence in sentences if sentence.tokens
        ]
        if len(sentences) != len(filtered_sentences):
            log.warning(
                f"Ignore {len(sentences) - len(filtered_sentences)} sentence(s) with no tokens."
            )
        return filtered_sentences

    @staticmethod
    def _filter_empty_string(texts: List[str]) -> List[str]:
        filtered_texts = [text for text in texts if text]
        if len(texts) != len(filtered_texts):
            log.warning(
                f"Ignore {len(texts) - len(filtered_texts)} string(s) with no tokens."
            )
        return filtered_texts

    @staticmethod
    def _fetch_model(model_name) -> str:

        model_map = {}

        aws_resource_path_v04 = "https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/models-v0.4"
        hu_path: str = "https://nlp.informatik.hu-berlin.de/resources/models"

        model_map["ner"] = "/".join([
            aws_resource_path_v04, "NER-conll03-english",
            "en-ner-conll03-v0.4.pt"
        ])

        model_map["ner-fast"] = "/".join([
            aws_resource_path_v04,
            "NER-conll03--h256-l1-b32-p3-0.5-%2Bglove%2Bnews-forward-fast%2Bnews-backward-fast-normal-locked0.5-word0.05--release_4",
            "en-ner-fast-conll03-v0.4.pt",
        ])

        model_map["ner-ontonotes"] = "/".join([
            aws_resource_path_v04,
            "release-ner-ontonotes-0",
            "en-ner-ontonotes-v0.4.pt",
        ])

        model_map["ner-ontonotes-fast"] = "/".join([
            aws_resource_path_v04,
            "release-ner-ontonotes-fast-0",
            "en-ner-ontonotes-fast-v0.4.pt",
        ])

        for key in ["ner-multi", "multi-ner"]:
            model_map[key] = "/".join([
                aws_resource_path_v04,
                "release-quadner-512-l2-multi-embed",
                "quadner-large.pt",
            ])

        for key in ["ner-multi-fast", "multi-ner-fast"]:
            model_map[key] = "/".join(
                [aws_resource_path_v04, "NER-multi-fast", "ner-multi-fast.pt"])

        for key in ["ner-multi-fast-learn", "multi-ner-fast-learn"]:
            model_map[key] = "/".join([
                aws_resource_path_v04,
                "NER-multi-fast-evolve",
                "ner-multi-fast-learn.pt",
            ])

        model_map["upos"] = "/".join([
            aws_resource_path_v04,
            "POS-ontonotes--h256-l1-b32-p3-0.5-%2Bglove%2Bnews-forward%2Bnews-backward-normal-locked0.5-word0.05--v0.4_0",
            "en-pos-ontonotes-v0.4.pt",
        ])

        model_map["pos"] = "/".join([
            hu_path,
            "release-pos-0",
            "en-pos-ontonotes-v0.5.pt",
        ])

        model_map["upos-fast"] = "/".join([
            aws_resource_path_v04,
            "release-pos-fast-0",
            "en-pos-ontonotes-fast-v0.4.pt",
        ])

        model_map["pos-fast"] = "/".join([
            hu_path,
            "release-pos-fast-0",
            "en-pos-ontonotes-fast-v0.5.pt",
        ])

        for key in ["pos-multi", "multi-pos"]:
            model_map[key] = "/".join([
                aws_resource_path_v04,
                "release-dodekapos-512-l2-multi",
                "pos-multi-v0.1.pt",
            ])

        for key in ["pos-multi-fast", "multi-pos-fast"]:
            model_map[key] = "/".join([
                aws_resource_path_v04, "UPOS-multi-fast", "pos-multi-fast.pt"
            ])

        model_map["frame"] = "/".join([
            aws_resource_path_v04, "release-frame-1",
            "en-frame-ontonotes-v0.4.pt"
        ])

        model_map["frame-fast"] = "/".join([
            aws_resource_path_v04,
            "release-frame-fast-0",
            "en-frame-ontonotes-fast-v0.4.pt",
        ])

        model_map["chunk"] = "/".join([
            aws_resource_path_v04,
            "NP-conll2000--h256-l1-b32-p3-0.5-%2Bnews-forward%2Bnews-backward-normal-locked0.5-word0.05--v0.4_0",
            "en-chunk-conll2000-v0.4.pt",
        ])

        model_map["chunk-fast"] = "/".join([
            aws_resource_path_v04,
            "release-chunk-fast-0",
            "en-chunk-conll2000-fast-v0.4.pt",
        ])

        model_map["da-pos"] = "/".join(
            [aws_resource_path_v04, "POS-danish", "da-pos-v0.1.pt"])

        model_map["da-ner"] = "/".join(
            [aws_resource_path_v04, "NER-danish", "da-ner-v0.1.pt"])

        model_map["de-pos"] = "/".join(
            [hu_path, "release-de-pos-0", "de-pos-ud-hdt-v0.5.pt"])

        model_map["de-pos-tweets"] = "/".join([
            aws_resource_path_v04,
            "POS-fine-grained-german-tweets",
            "de-pos-twitter-v0.1.pt",
        ])

        model_map["de-ner"] = "/".join([
            aws_resource_path_v04, "release-de-ner-0", "de-ner-conll03-v0.4.pt"
        ])

        model_map["de-ner-germeval"] = "/".join([
            aws_resource_path_v04, "NER-germeval", "de-ner-germeval-0.4.1.pt"
        ])

        model_map["fr-ner"] = "/".join([
            aws_resource_path_v04, "release-fr-ner-0", "fr-ner-wikiner-0.4.pt"
        ])
        model_map["nl-ner"] = "/".join([
            aws_resource_path_v04, "NER-conll2002-dutch",
            "nl-ner-conll02-v0.1.pt"
        ])
        model_map[
            "ml-pos"] = "https://raw.githubusercontent.com/qburst/models-repository/master/FlairMalayalamModels/malayalam-xpos-model.pt"
        model_map[
            "ml-upos"] = "https://raw.githubusercontent.com/qburst/models-repository/master/FlairMalayalamModels/malayalam-upos-model.pt"

        cache_dir = Path("models")
        if model_name in model_map:
            model_name = cached_path(model_map[model_name],
                                     cache_dir=cache_dir)

        # the historical German taggers by the @redewiegergabe project
        if model_name == "de-historic-indirect":
            model_file = Path(
                flair.cache_root) / cache_dir / 'indirect' / 'final-model.pt'
            if not model_file.exists():
                cached_path('http://www.redewiedergabe.de/models/indirect.zip',
                            cache_dir=cache_dir)
                unzip_file(
                    Path(flair.cache_root) / cache_dir / 'indirect.zip',
                    Path(flair.cache_root) / cache_dir)
            model_name = str(
                Path(flair.cache_root) / cache_dir / 'indirect' /
                'final-model.pt')

        if model_name == "de-historic-direct":
            model_file = Path(
                flair.cache_root) / cache_dir / 'direct' / 'final-model.pt'
            if not model_file.exists():
                cached_path('http://www.redewiedergabe.de/models/direct.zip',
                            cache_dir=cache_dir)
                unzip_file(
                    Path(flair.cache_root) / cache_dir / 'direct.zip',
                    Path(flair.cache_root) / cache_dir)
            model_name = str(
                Path(flair.cache_root) / cache_dir / 'direct' /
                'final-model.pt')

        if model_name == "de-historic-reported":
            model_file = Path(
                flair.cache_root) / cache_dir / 'reported' / 'final-model.pt'
            if not model_file.exists():
                cached_path('http://www.redewiedergabe.de/models/reported.zip',
                            cache_dir=cache_dir)
                unzip_file(
                    Path(flair.cache_root) / cache_dir / 'reported.zip',
                    Path(flair.cache_root) / cache_dir)
            model_name = str(
                Path(flair.cache_root) / cache_dir / 'reported' /
                'final-model.pt')

        if model_name == "de-historic-free-indirect":
            model_file = Path(flair.cache_root
                              ) / cache_dir / 'freeIndirect' / 'final-model.pt'
            if not model_file.exists():
                cached_path(
                    'http://www.redewiedergabe.de/models/freeIndirect.zip',
                    cache_dir=cache_dir)
                unzip_file(
                    Path(flair.cache_root) / cache_dir / 'freeIndirect.zip',
                    Path(flair.cache_root) / cache_dir)
            model_name = str(
                Path(flair.cache_root) / cache_dir / 'freeIndirect' /
                'final-model.pt')

        return model_name

    def get_transition_matrix(self):
        data = []
        for to_idx, row in enumerate(self.transitions):
            for from_idx, column in enumerate(row):
                entry = [
                    self.tag_dictionary.get_item_for_index(from_idx),
                    self.tag_dictionary.get_item_for_index(to_idx),
                    column.item(),
                ]
                data.append(entry)
            data.append(["----"])
        print(tabulate(data, headers=["FROM", "TO", "SCORE"]))

    def __str__(self):
        return super(flair.nn.Model, self).__str__().rstrip(')') + \
               f'  (beta): {self.beta}\n' + \
               f'  (weights): {self.weight_dict}\n' + \
               f'  (weight_tensor) {self.loss_weights}\n)'
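
The CRF negative log-likelihood above is forward_score - gold_score, and _forward_alg closes with a call to log_sum_exp_batch, which this snippet does not show. A minimal sketch of the numerically stable helper it expects (max-subtraction before exponentiating, the same trick _forward_alg applies inside its loop), assuming a 2-D input of shape (batch, tagset):

import torch

def log_sum_exp_batch(vecs: torch.Tensor) -> torch.Tensor:
    # vecs: (batch, tagset) terminal scores; returns one log-sum-exp per row.
    # Subtracting the per-row max keeps exp() from overflowing.
    maxi, _ = torch.max(vecs, dim=1)
    return maxi + torch.log(torch.sum(torch.exp(vecs - maxi[:, None]), dim=1))
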
class Conv2d_filtermap_1x1_compression(_ConvNd_filtermap_1x1_compression):
    def __init__(self, in_channels, out_channels, channel_compression_1x1, kernel_size, binary_filtermap = False, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        
        super(Conv2d_filtermap_1x1_compression, self).__init__(
            in_channels, out_channels, channel_compression_1x1, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias)
        
        self.binary_filtermap = binary_filtermap
        self.reset_parameters1()
    def reset_parameters1(self):
        fm_size = self.filtermap.size()
        fm_width = fm_size[2]
        fm_height = fm_size[1]
        fm_depth = fm_size[0]
        # for non-1x1 convs, pad the filtermap spatially by one row/column
        if self.filtermap.size()[1] > 1 and self.filtermap.size()[2] > 1:
            self.fm_pad_width = fm_width + 1
            self.fm_pad_height = fm_height + 1
        # for 1x1 convs, no spatial padding
        else:
            self.fm_pad_width = fm_width
            self.fm_pad_height = fm_height

        self.fm_pad_depth = fm_depth*2
        #set the ids for extracting filters from filtermap
        out_channels = self.out_channels
        in_channels = self.in_channels // self.groups
        k_h = self.kernel_size[0]
        k_w = self.kernel_size[1]

        sample_y = self.sample_y
        sample_x = self.sample_x
        sample_c = self.sample_c

        stride_y = self.stride_y
        stride_x = self.stride_x
        stride_c = self.stride_c
        
 
        fm_depth = self.fm_pad_depth
        fm_height = self.fm_pad_height
        fm_width = self.fm_pad_width
        
        
        # flat indices of one k_h x k_w filter window within the padded filtermap
        ids = torch.arange(0, k_h * k_w, dtype=torch.float)
        tmp_count = 0
        for y in range(0, k_h):
            for x in range(0, k_w):
                ids[tmp_count] = y * fm_width + x
                tmp_count = tmp_count + 1
 
        ids0 = ids
               
        for c in range(1,in_channels):
            ids_c = ids0 + c*fm_height*fm_width
            ids = torch.cat((ids,ids_c),0)
        
        ids0 = ids
        for y in range(0,sample_y):
            for x in range(0,sample_x):
                if y == 0 and x == 0:
                   continue
                ss = y*stride_y*fm_width + x*stride_x
                ids_ss = ids0+ss
                ids = torch.cat((ids,ids_ss),0)
        
        ids0 = ids
        for c in range(1,sample_c):
            ids_c = ids0+c*stride_c*fm_height*fm_width
            ids = torch.cat((ids,ids_c),0)
        
        ids = ids.long()
        # store the gather indices as a frozen Parameter so they follow the
        # module across devices but are never updated by the optimizer
        self.ids = Parameter(ids)
        self.ids.requires_grad = False
    def extract_filters(self):
        out_channels = self.out_channels
        in_channels = self.in_channels // self.groups 
        k_h = self.kernel_size[0]
        k_w = self.kernel_size[1]

        # no channel compression: duplicate the filtermap along the depth axis
        filtermap_pad = torch.cat((self.filtermap, self.filtermap), 0)
        # for non-1x1 convs, pad spatially by wrapping row/column 1 around
        if self.filtermap.size()[1] > 1 and self.filtermap.size()[2] > 1:
            filtermap_pad_s1 = filtermap_pad[:, 1, :]
            filtermap_pad_s1 = filtermap_pad_s1[:, None, :]
            filtermap_pad = torch.cat((filtermap_pad, filtermap_pad_s1), 1)
            filtermap_pad_s2 = filtermap_pad[:, :, 1]
            filtermap_pad_s2 = filtermap_pad_s2[:, :, None]
            filtermap_pad = torch.cat((filtermap_pad, filtermap_pad_s2), 2)

        # gather the filter weights from the flattened padded filtermap
        ids = self.ids.detach()
        conv_weight = filtermap_pad.view(-1, 1).index_select(0, ids)
        conv_weight = conv_weight.view(out_channels, in_channels, k_h, k_w)
        if self.binary_filtermap:
            # binarize each filter: its sign pattern scaled by its L1 norm
            binary_conv_weight = conv_weight.clone()
            for nf in range(0, out_channels):
                float_filter = conv_weight[nf, :, :, :]
                L1_norm = torch.norm(float_filter.view(-1, 1), 1)
                sign_filter = torch.sign(float_filter)
                binary_filter = sign_filter * L1_norm
                binary_conv_weight[nf, :, :, :] = binary_filter
            return binary_conv_weight
        else:
            return conv_weight
        #pdb.set_trace()
        #for c in range(0,sample_c):
        #   for y in range(0,sample_y):
        #      for x in range(0,sample_x):
        #          filter_count = c*sample_y*sample_x + y*sample_x + x
        #          conv_weight_clone[filter_count,:,:,:] = filtermap_pad[c*stride_c:c*stride_c+in_channels, \
        #                                      y*stride_y:y*stride_y+k_h, x*stride_x:x*stride_x+k_w]
        #return conv_weight
    def forward(self, input):
        # pull the full convolution filters out of the compact filtermap,
        # then run a standard conv2d with them
        conv_weight = self.extract_filters()
        out = F.conv2d(input, conv_weight, self.bias, self.stride,
                       self.padding, self.dilation, self.groups)
        return out
    def fit_filtermap1(self, conv):
        conv_weight = conv.weight.data
        conv_weight_size = conv_weight.size()
        out_channels = conv_weight_size[0]
        in_channels = conv_weight_size[1]
        k_h = conv_weight_size[2]
        k_w = conv_weight_size[3]

        sample_y = self.sample_y
        sample_x = self.sample_x
        sample_c = self.sample_c

        stride_y = self.stride_y
        stride_x = self.stride_x
        stride_c = self.stride_c

        fm_ext = torch.Tensor(sample_c*in_channels,sample_y*k_h,sample_x*k_w)
        for c in range(0,sample_c):
            for y in range(0,sample_y):
                for x in range(0,sample_x):
                    filter_count = c*sample_y*sample_x + y*sample_x + x
                    fm_ext[c*in_channels:(c+1)*in_channels,y*k_h:(y+1)*k_h,x*k_w:(x+1)*k_w] = conv_weight[filter_count,:,:,:]


        for oc in range(0,sample_c):
            for c in range(1,sample_c):
                idx = sample_c-c+oc
                if idx > sample_c-1:
                   idx = 0
                fm_ext[oc*stride_c:(oc+1)*stride_c,:,:] += \
                   fm_ext[idx*stride_c:(idx+1)*stride_c,:,:]

        fm_ext = fm_ext[0:in_channels,:,:]/sample_c

        fm_ext_h = fm_ext.size()[1]
        fm_ext_w = fm_ext.size()[2]

        for y in range(0,sample_y):
            fm_ext[:,y*k_h,:] += fm_ext[:,(y*k_h-1+fm_ext_h)%fm_ext_h,:]
            fm_ext[:,y*k_h,:] = fm_ext[:,y*k_h,:]/2

        for x in range(0,sample_x):
            fm_ext[:,:,x*k_w] += fm_ext[:,:,(x*k_w-1+fm_ext_w)%fm_ext_w] 
            fm_ext[:,:,x*k_w] = fm_ext[:,:,x*k_w]/2

        y_ids = torch.Tensor([0,1])
        y_ids0 = y_ids
        for y in range(1,sample_y):
            y_ids = torch.cat((y_ids,y_ids0+y*k_h),0)
    
        x_ids = torch.Tensor([0,1])
        x_ids0 = x_ids
        for x in range(1,sample_x):
            x_ids = torch.cat((x_ids,x_ids0+x*k_w),0)

        fm = torch.index_select(fm_ext,1,y_ids.long())
        fm = torch.index_select(fm,2,x_ids.long())
    
    
        self.filtermap.data = fm
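
The index-gathering trick behind extract_filters is easier to see on a toy tensor: every filter is a window read out of one shared flat buffer, so overlapping windows share weights. An illustrative sketch (the small sizes are mine, not from the snippet):

import torch

# a "filtermap" of 6 values; two 3-element "filters" whose windows overlap
filtermap = torch.arange(6, dtype=torch.float)
ids = torch.tensor([0, 1, 2,   # filter 0 reads positions 0..2
                    2, 3, 4])  # filter 1 reads positions 2..4 (entry 2 is shared)
weights = filtermap.view(-1, 1).index_select(0, ids).view(2, 3)
print(weights)  # tensor([[0., 1., 2.], [2., 3., 4.]])
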
Exemple #39
0
class MaskedLinear(nn.Module):
    """
    Adopted from https://github.com/rtqichen/ffjord
    Creates masked linear layer for MLP MADE.
    For input (x) to hidden (h) or hidden to hidden layers choose diagonal_zeros = False.
    For hidden to output (y) layers:
    If output depends on input through y_i = f(x_{<i}) set diagonal_zeros = True.
    Else if output depends on input through y_i = f(x_{<=i}) set diagonal_zeros = False.
    """
    def __init__(self,
                 in_features,
                 out_features,
                 diagonal_zeros=False,
                 bias=True):
        super(MaskedLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.diagonal_zeros = diagonal_zeros
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter("bias", None)
        mask = torch.from_numpy(self.build_mask())
        if torch.cuda.is_available():
            mask = mask.cuda()
        # torch.autograd.Variable is deprecated; a plain, non-trainable tensor
        # attribute behaves the same here
        self.mask = mask
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_normal_(self.weight)
        if self.bias is not None:
            self.bias.data.zero_()

    def build_mask(self):
        n_in, n_out = self.in_features, self.out_features
        assert n_in % n_out == 0 or n_out % n_in == 0

        mask = np.ones((n_in, n_out), dtype=np.float32)
        if n_out >= n_in:
            k = n_out // n_in
            for i in range(n_in):
                mask[i + 1:, i * k:(i + 1) * k] = 0
                if self.diagonal_zeros:
                    mask[i:i + 1, i * k:(i + 1) * k] = 0
        else:
            k = n_in // n_out
            for i in range(n_out):
                mask[(i + 1) * k:, i:i + 1] = 0
                if self.diagonal_zeros:
                    mask[i * k:(i + 1) * k, i:i + 1] = 0
        return mask

    def forward(self, x):
        output = x.mm(self.mask * self.weight)

        if self.bias is not None:
            return output.add(self.bias.expand_as(output))
        else:
            return output

    def __repr__(self):
        if self.bias is not None:
            bias = True
        else:
            bias = False
        return (self.__class__.__name__ + " (" + str(self.in_features) +
                " -> " + str(self.out_features) + ", diagonal_zeros=" +
                str(self.diagonal_zeros) + ", bias=" + str(bias) + ")")
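
A quick way to check the mask rules from the docstring is to print the mask of a small square layer (usage sketch; assumes the torch, numpy and nn imports the class itself relies on):

layer = MaskedLinear(4, 4, diagonal_zeros=True)
print(layer.mask)
# tensor([[0., 1., 1., 1.],
#         [0., 0., 1., 1.],
#         [0., 0., 0., 1.],
#         [0., 0., 0., 0.]])
# ones only strictly above the diagonal: output j sees inputs i < j,
# i.e. the autoregressive y_i = f(x_{<i}) case described above
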
Exemple #40
0
    def __init__(self, input_size, hidden_size, bias=True, prior = None):
        super(LSTMCellVB, self).__init__()
        self.type_layer = "LSTM"
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.posterior_mean = False # Flag to know if we sample from the posterior mean or we actually sample
        
        ## If no prior is specified we just create it ourselves
        if prior is None:
            prior = Vil.Prior(0.5, np.log(0.1), np.log(0.5))
        self.prior = prior
        
        """
        Variational Inference Parameters
        """
        self.mu_weight_ih = Parameter(torch.Tensor(4 * hidden_size, input_size).to(device = Vil.device, dtype = Vil.dtype))
        self.mu_weight_hh = Parameter(torch.Tensor(4 * hidden_size, hidden_size).to(device = Vil.device, dtype = Vil.dtype))
        self.rho_weight_ih = Parameter(torch.Tensor(4 * hidden_size, input_size).to(device = Vil.device, dtype = Vil.dtype))
        self.rho_weight_hh = Parameter(torch.Tensor(4 * hidden_size, hidden_size).to(device = Vil.device, dtype = Vil.dtype))
        if bias:
            self.mu_bias_ih = Parameter(torch.Tensor(4 * hidden_size).to(device = Vil.device, dtype = Vil.dtype))
            self.mu_bias_hh = Parameter(torch.Tensor(4 * hidden_size).to(device = Vil.device, dtype = Vil.dtype))
            self.rho_bias_ih = Parameter(torch.Tensor(4 * hidden_size).to(device = Vil.device, dtype = Vil.dtype))
            self.rho_bias_hh = Parameter(torch.Tensor(4 * hidden_size).to(device = Vil.device, dtype = Vil.dtype))
        else:
            self.register_parameter('mu_bias_ih', None)
            self.register_parameter('mu_bias_hh', None)
            self.register_parameter('rho_bias_ih', None)
            self.register_parameter('rho_bias_hh', None)
        """
        Sampled weights
        """

        self.weight_ih = torch.Tensor(4 * hidden_size, input_size).to(device = Vil.device, dtype = Vil.dtype)
        self.weight_hh = torch.Tensor(4 * hidden_size, hidden_size).to(device = Vil.device, dtype = Vil.dtype)
        if bias:
            self.bias_ih = torch.Tensor(4 * hidden_size).to(device = Vil.device, dtype = Vil.dtype)
            self.bias_hh = torch.Tensor(4 * hidden_size).to(device = Vil.device, dtype = Vil.dtype)
        else:
            self.register_parameter('bias_ih', None)
            self.register_parameter('bias_hh', None)

        if(1):  # debug block: prints the device placement of every variational tensor
            print ("linear bias_ih device: ",self.bias_ih.device)
            print ("linear weights_ih device: ",self.weight_ih.device)
            print ("linear bias mu_ih device: ",self.mu_bias_ih.device)
            print ("linear bias rho_ih device: ",self.rho_bias_ih.device)
            
            print ("linear weights mu_ih  device: ",self.mu_weight_ih.device)
            print ("linear weights rho_ih device: ",self.rho_weight_ih.device)
            
            print ("linear bias_hh device: ",self.bias_hh.device)
            print ("linear weights_hh device: ",self.weight_hh.device)
            print ("linear bias mu_hh device: ",self.mu_bias_hh.device)
            print ("linear bias rho_hh device: ",self.rho_bias_hh.device)
            
            print ("linear weights mu_hh  device: ",self.mu_weight_hh.device)
            print ("linear weights rho_hh device: ",self.rho_weight_hh.device)
            
        self.reset_parameters()
        self.sample_posterior()
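
sample_posterior is not shown in this snippet, but the mu/rho naming is the usual Bayes-by-Backprop parameterization, where rho encodes a positive standard deviation through a softplus. A hedged sketch of what sampling one weight tensor typically looks like under that assumption:

import torch

def sample_weight(mu: torch.Tensor, rho: torch.Tensor) -> torch.Tensor:
    # sigma = softplus(rho) keeps the std positive; eps ~ N(0, I)
    sigma = torch.log1p(torch.exp(rho))
    eps = torch.randn_like(mu)
    # reparameterized sample: differentiable w.r.t. both mu and rho
    return mu + sigma * eps
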
Exemple #41
0
    def __init__(self,
                 input_vocabulary,
                 target_vocabulary,
                 hidden_size=512,
                 embedding_size=128,
                 cell_type="LSTM"):
        """
        :param list input_vocabulary: list of possible inputs
        :param list target_vocabulary: list of possible targets
        """
        super(Network, self).__init__()
        self.h_input_encoder_size = hidden_size
        self.h_output_encoder_size = hidden_size
        self.h_decoder_size = hidden_size
        self.embedding_size = embedding_size
        self.input_vocabulary = input_vocabulary
        self.target_vocabulary = target_vocabulary
        # Number of tokens in input vocabulary
        self.v_input = len(input_vocabulary)
        # Number of tokens in target vocabulary
        self.v_target = len(target_vocabulary)

        self.cell_type = cell_type
        if cell_type == 'GRU':
            self.input_encoder_cell = nn.GRUCell(
                input_size=self.v_input + 1,
                hidden_size=self.h_input_encoder_size,
                bias=True)
            self.input_encoder_init = Parameter(
                torch.rand(1, self.h_input_encoder_size))
            self.output_encoder_cell = nn.GRUCell(
                input_size=self.v_input + 1 + self.h_input_encoder_size,
                hidden_size=self.h_output_encoder_size,
                bias=True)
            self.decoder_cell = nn.GRUCell(input_size=self.v_target + 1,
                                           hidden_size=self.h_decoder_size,
                                           bias=True)
        if cell_type == 'LSTM':
            self.input_encoder_cell = nn.LSTMCell(
                input_size=self.v_input + 1,
                hidden_size=self.h_input_encoder_size,
                bias=True)
            self.input_encoder_init = nn.ParameterList([
                Parameter(torch.rand(1, self.h_input_encoder_size)),
                Parameter(torch.rand(1, self.h_input_encoder_size))
            ])
            self.output_encoder_cell = nn.LSTMCell(
                input_size=self.v_input + 1 + self.h_input_encoder_size,
                hidden_size=self.h_output_encoder_size,
                bias=True)
            self.output_encoder_init_c = Parameter(
                torch.rand(1, self.h_output_encoder_size))
            self.decoder_cell = nn.LSTMCell(input_size=self.v_target + 1,
                                            hidden_size=self.h_decoder_size,
                                            bias=True)
            self.decoder_init_c = Parameter(torch.rand(1, self.h_decoder_size))

        self.W = nn.Linear(self.h_output_encoder_size + self.h_decoder_size,
                           self.embedding_size)
        self.V = nn.Linear(self.embedding_size, self.v_target + 1)
        self.input_A = nn.Bilinear(self.h_input_encoder_size,
                                   self.h_output_encoder_size,
                                   1,
                                   bias=False)
        self.output_A = nn.Bilinear(self.h_output_encoder_size,
                                    self.h_decoder_size,
                                    1,
                                    bias=False)
        self.input_EOS = torch.zeros(1, self.v_input + 1)
        self.input_EOS[:, -1] = 1
        self.input_EOS = Parameter(self.input_EOS)
        self.output_EOS = torch.zeros(1, self.v_input + 1)
        self.output_EOS[:, -1] = 1
        self.output_EOS = Parameter(self.output_EOS)
        self.target_EOS = torch.zeros(1, self.v_target + 1)
        self.target_EOS[:, -1] = 1
        self.target_EOS = Parameter(self.target_EOS)
Exemple #42
0
 def register(name, tensor):
     # local helper (captures ``self`` from the enclosing method's scope)
     # to wrap a plain tensor as a named, trainable Parameter
     self.register_parameter(name, Parameter(tensor))
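
A self-contained sketch of how such a helper is used in practice (the module and tensor names are illustrative, not from the original source):

import torch
from torch import nn
from torch.nn import Parameter

class WithRegister(nn.Module):
    def __init__(self):
        super().__init__()
        def register(name, tensor):
            # closure over self, mirroring the snippet above
            self.register_parameter(name, Parameter(tensor))
        register("scale", torch.ones(8))
        register("shift", torch.zeros(8))

m = WithRegister()
print([name for name, _ in m.named_parameters()])  # ['scale', 'shift']
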
Exemple #43
0
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = 'zeros',
                 symmetry: dict = {},
                 share_bias: bool = False):
        '''
        Args:
        symmetry (dict) - number of filters that are symmetric about the horizontal, 
                          vertical, or both axes
                          e.g. {'h':4, 'v': 2, 'hv':8} has 4 filters (2 filter pairs) that are 
                          horizontally symmetric, 2 filters (1 filter pair) which are vertically 
                          symmetric, and 8 filters (2 filter quadruples) that are symmetric 
                          about both axes
        share_bias (bool) - if True, symmetric filter pairs also share their biases
        '''
        super(SymmetricConv2d,
              self).__init__(in_channels, out_channels, kernel_size, stride,
                             padding, dilation, groups, bias, padding_mode)
        if self.groups > 1:
            raise ValueError(self.__str__() + ' does not support groups>1')
        if not bias:
            self.share_bias = False
        else:
            self.share_bias = share_bias
        if symmetry is None:
            # no symmetry constraints: the layer degenerates to a standard Conv2d
            self.symmetry = None
        else:
            # Set defaults for symmetric filters pairs
            symmetry = dict(symmetry)  # make a copy
            symmetry.setdefault('h', 0)
            symmetry.setdefault('v', 0)
            symmetry.setdefault('hv', 0)
            self.symmetry = symmetry

            # sanity check: number of filters divisible by 2 resp. 4?
            for key, val in symmetry.items():
                if (key in ['h', 'v']) and (val % 2 != 0):
                    raise ValueError(
                        'Number of symmetric h and v filters must be divisible by 2'
                    )
                elif (key == 'hv') and (val % 4 != 0):
                    raise ValueError(
                        'Number of symmetric hv filters must be divisible by 4'
                    )
            # sanity check: number of symmetric filters must be <= number of filters
            assert sum(
                list(symmetry.values())
            ) <= self.out_channels, "Number of symmetric channels exceeds number of out channels"
            self.unique_out_channels = self.out_channels - symmetry[
                'h'] // 2 - symmetry['v'] // 2 - 3 * symmetry['hv'] // 4

            # Create only the unique weights
            if self.transposed:
                self.weight = Parameter(
                    torch.Tensor(in_channels, self.unique_out_channels,
                                 *self.kernel_size))
            else:
                self.weight = Parameter(
                    torch.Tensor(self.unique_out_channels, in_channels,
                                 *self.kernel_size))

        self.reset_parameters()
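
With the symmetry spec from the docstring, only the unique filters get their own weights; the mirrored copies are derived from them. A usage sketch (assuming the class subclasses nn.Conv2d, as the super().__init__ argument list suggests):

conv = SymmetricConv2d(3, 16, kernel_size=3,
                       symmetry={'h': 4, 'v': 2, 'hv': 8})
# unique = 16 - 4//2 - 2//2 - 3*8//4 = 16 - 2 - 1 - 6 = 7
print(conv.unique_out_channels)  # 7
print(conv.weight.shape)         # torch.Size([7, 3, 3, 3])
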
 def build(self, input_shape):
     if not self._built:
         if self.num_parameters is None:
             self.num_parameters = self.input_filters
         self.weight = Parameter(ones(self.num_parameters) * self.init)
         self._built = True
class PGDAttack(BaseAttack):
    """
    Spectral attack for graph data
    """
    def __init__(self,
                 model=None,
                 nnodes=None,
                 loss_type='CE',
                 feature_shape=None,
                 attack_structure=True,
                 attack_features=False,
                 loss_weight=1.0,
                 regularization_weight=0.0,
                 device='cpu'):

        super(PGDAttack, self).__init__(model, nnodes, attack_structure,
                                        attack_features, device)

        assert attack_structure or attack_features, 'attack_feature or attack_structure cannot be both False'

        self.loss_type = loss_type
        self.modified_adj = None
        self.modified_features = None
        self.loss_weight = loss_weight
        self.regularization_weight = regularization_weight

        if attack_features:
            assert False, 'Current Spectral Attack does not support attack feature'

        if attack_structure:
            assert nnodes is not None, 'Please give nnodes='
            self.adj_changes = Parameter(
                torch.FloatTensor(int(nnodes * (nnodes - 1) / 2)))
            torch.nn.init.uniform_(self.adj_changes, 0.0, 0.001)
            # self.adj_changes.data.fill_(0)

        self.complementary = None

    def set_model(self, model):
        self.surrogate = model

    def attack(self,
               ori_features,
               ori_adj,
               labels,
               idx_target,
               n_perturbations,
               att_lr,
               epochs=200,
               distance_type='l2',
               sample_type='sample',
               opt_type='max',
               verbose=True,
               **kwargs):
        """
        Generate perturbations on the input graph
        """

        victim_model = self.surrogate

        self.sparse_features = sp.issparse(ori_features)
        # ori_adj, ori_features, labels = utils.to_tensor(ori_adj, ori_features, labels, device=self.device)
        ori_adj_norm = utils.normalize_adj_tensor(ori_adj, device=self.device)
        # torch.symeig was removed in recent PyTorch; torch.linalg.eigh is the
        # drop-in replacement (ascending eigenvalues plus eigenvectors)
        ori_e, ori_v = torch.linalg.eigh(ori_adj_norm)

        l, r, m = 0, 0, 0
        victim_model.eval()
        # for t in tqdm(range(epochs), desc='Perturb Adj'):
        for t in tqdm(range(epochs)):
            modified_adj = self.get_modified_adj(ori_adj)
            adj_norm = utils.normalize_adj_tensor(modified_adj,
                                                  device=self.device)
            output = victim_model(
                ori_features,
                adj_norm)  # forward of gcn need to normalize adj first
            task_loss = self._loss(output[idx_target], labels[idx_target])

            # spectral distance term for spectral distance
            eigen_mse = torch.tensor(0)
            eigen_self = torch.tensor(0)
            eigen_gf = torch.tensor(0)
            eigen_norm = self.norm = torch.norm(ori_e)
            if self.regularization_weight != 0:
                # add noise to make the graph asymmetric
                modified_adj_noise = modified_adj
                # modified_adj_noise = self.add_random_noise(modified_adj)
                adj_norm_noise = utils.normalize_adj_tensor(modified_adj_noise,
                                                            device=self.device)
                e, v = torch.linalg.eigh(adj_norm_noise)
                eigen_mse = torch.norm(ori_e - e)
                eigen_self = torch.norm(e)

                # low-rank loss in GF-attack
                idx = torch.argsort(e)[:128]
                mask = torch.zeros_like(e).bool()
                mask[idx] = True
                eigen_gf = torch.pow(torch.norm(e * mask, p=2), 2) * torch.pow(
                    torch.norm(torch.matmul(v.detach() * mask, ori_features),
                               p=2), 2)

            reg_loss = 0
            if distance_type == 'l2':
                reg_loss = eigen_mse / eigen_norm
            elif distance_type == 'normDiv':
                reg_loss = eigen_self / eigen_norm
            elif distance_type == 'gf':
                reg_loss = eigen_gf
            else:
                raise ValueError(f'unknown distance metric: {distance_type}')

            if verbose and t % 20 == 0:
                loss_target, acc_target = calc_acc(output, labels, idx_target)
                print(
                    '-- Epoch {}, '.format(t),
                    'ptb budget/true = {:.1f}/{:.1f}'.format(
                        n_perturbations,
                        torch.clamp(self.adj_changes, 0, 1).sum()),
                    'l/r/m = {:.4f}/{:.4f}/{:.4f}'.format(l, r, m),
                    'class loss = {:.4f} | '.format(task_loss.item()),
                    'reg loss = {:.4f} | '.format(reg_loss.item()),
                    'mse_norm = {:.4f} | '.format(eigen_norm),
                    'eigen_mse = {:.4f} | '.format(eigen_mse),
                    'eigen_self = {:.4f} | '.format(eigen_self),
                    'acc/mis = {:.4f}/{:.4f}'.format(acc_target,
                                                     1 - acc_target))

            self.loss = self.loss_weight * task_loss + self.regularization_weight * reg_loss

            adj_grad = torch.autograd.grad(self.loss, self.adj_changes)[0]

            # CE and CW use the same decayed-step gradient ascent update
            if self.loss_type in ('CE', 'CW'):
                lr = att_lr / np.sqrt(t + 1)
                self.adj_changes.data.add_(lr * adj_grad)

            # return self.adj_changes.cpu().detach().numpy()

            if verbose and t % 20 == 0:
                print('budget/true={:.1f}/{:.1f}'.format(
                    n_perturbations,
                    torch.clamp(self.adj_changes, 0, 1).sum()))

            if sample_type == 'sample':
                l, r, m = self.projection(n_perturbations)
            elif sample_type == 'greedy':
                self.greedy(n_perturbations)
            elif sample_type == 'greedy2':
                self.greedy2(n_perturbations)
            elif sample_type == 'greedy3':
                self.greedy3(n_perturbations)
            else:
                exit(f"unkown sample type {sample_type}")

            if verbose and t % 20 == 0:
                print('budget/true={:.1f}/{:.1f}'.format(
                    n_perturbations,
                    torch.clamp(self.adj_changes, 0, 1).sum()))

        if sample_type == 'sample':
            self.random_sample(ori_adj, ori_features, labels, idx_target,
                               n_perturbations)
        elif sample_type == 'greedy':
            self.greedy(n_perturbations)
        elif sample_type == 'greedy2':
            self.greedy2(n_perturbations)
        elif sample_type == 'greedy3':
            self.greedy3(n_perturbations)
        else:
            exit(f"unkown sample type {sample_type}")

        print("final ptb budget/true= {:.1f}/{:.1f}".format(
            n_perturbations, self.adj_changes.sum()))
        self.modified_adj = self.get_modified_adj(ori_adj).detach()
        self.check_adj_tensor(self.modified_adj)

        # for sanity check
        ori_adj_norm = utils.normalize_adj_tensor(ori_adj, device=self.device)
        ori_e, ori_v = torch.linalg.eigh(ori_adj_norm)
        adj_norm = utils.normalize_adj_tensor(self.modified_adj,
                                              device=self.device)
        e, v = torch.linalg.eigh(adj_norm)

        self.adj = ori_adj.detach()
        self.labels = labels.detach()
        self.ori_e = ori_e
        self.ori_v = ori_v
        self.e = e
        self.v = v

    def greedy(self, n_perturbations):
        s = self.adj_changes.cpu().detach().numpy()
        # l = min(s)
        # r = max(s)
        # noise = np.random.normal((l+r)/2, 0.1*(r-l), s.shape)
        # s += noise

        s_vec = np.squeeze(np.reshape(s, (1, -1)))
        # max_index = (-np.absolute(s_vec)).argsort()[:n_perturbations]
        max_index = (-s_vec).argsort()[:n_perturbations]

        mask = np.zeros_like(s_vec)
        mask[max_index] = 1.0

        best_s = np.reshape(mask, s.shape)

        self.adj_changes.data.copy_(
            torch.clamp(torch.tensor(best_s), min=0, max=1))

    def greedy3(self, n_perturbations):
        s = self.adj_changes.cpu().detach().numpy()
        s_vec = np.squeeze(np.reshape(s, (1, -1)))
        # max_index = (-np.absolute(s_vec)).argsort()[:n_perturbations]
        max_index = (s_vec).argsort()[:n_perturbations]

        mask = np.zeros_like(s_vec)
        mask[max_index] = 1.0

        best_s = np.reshape(mask, s.shape)

        self.adj_changes.data.copy_(
            torch.clamp(torch.tensor(best_s), min=0, max=1))

    def greedy2(self, n_perturbations):
        s = self.adj_changes.cpu().detach().numpy()
        l = min(s)
        r = max(s)
        noise = np.random.normal((l + r) / 2, 0.4 * (r - l), s.shape)
        s += noise

        s_vec = np.squeeze(np.reshape(s, (1, -1)))
        max_index = (-np.absolute(s_vec)).argsort()[:n_perturbations]

        mask = np.zeros_like(s_vec)
        mask[max_index] = 1.0

        best_s = np.reshape(mask, s.shape)

        self.adj_changes.data.copy_(
            torch.clamp(torch.tensor(best_s), min=0, max=1))

    def random_sample(self, ori_adj, ori_features, labels, idx_target,
                      n_perturbations):
        K = 10
        best_loss = -1000
        victim_model = self.surrogate
        with torch.no_grad():
            s = self.adj_changes.cpu().detach().numpy()
            for i in range(K):
                sampled = np.random.binomial(1, s)
                # randm = np.random.uniform(size=s.shape[0])
                # sampled = np.where(s > randm, 1, 0)

                # if sampled.sum() > n_perturbations:
                #     continue
                while sampled.sum() > n_perturbations:
                    sampled = np.random.binomial(1, s)
                # if sampled.sum() > n_perturbations:
                #     indices = np.transpose(np.nonzero(sampled))
                #     candidate_idx = [m for m in range(indices.shape[0])]
                #     chosen_idx = np.random.choice(candidate_idx, n_perturbations, replace=False)
                #     chosen_indices = indices[chosen_idx, :]
                #     sampled = np.zeros_like(sampled)
                #     for idx in chosen_indices:
                #         sampled[idx] = 1

                self.adj_changes.data.copy_(torch.tensor(sampled))
                modified_adj = self.get_modified_adj(ori_adj)
                adj_norm = utils.normalize_adj_tensor(modified_adj,
                                                      device=self.device)
                output = victim_model(ori_features, adj_norm)
                loss = self._loss(output[idx_target], labels[idx_target])
                # loss = F.nll_loss(output[idx_target], labels[idx_target])
                # print(loss)
                if best_loss < loss:
                    best_loss = loss
                    best_s = sampled
            self.adj_changes.data.copy_(torch.tensor(best_s))

    def get_modified_adj(self, ori_adj):

        if self.complementary is None:
            self.complementary = (torch.ones_like(ori_adj) - torch.eye(
                self.nnodes).to(self.device) - ori_adj) - ori_adj

        m = torch.zeros((self.nnodes, self.nnodes)).to(self.device)
        tril_indices = torch.tril_indices(row=self.nnodes,
                                          col=self.nnodes,
                                          offset=-1)
        m[tril_indices[0], tril_indices[1]] = self.adj_changes
        m = m + m.t()
        modified_adj = self.complementary * m + ori_adj

        return modified_adj

    def add_random_noise(self, ori_adj):
        noise = 1e-4 * torch.rand(self.nnodes, self.nnodes).to(self.device)
        return (noise + torch.transpose(noise, 0, 1)) / 2.0 + ori_adj

    def projection2(self, n_perturbations):
        s = self.adj_changes.cpu().detach().numpy()
        n = np.squeeze(np.reshape(s, (1, -1))).shape[0]
        self.adj_changes.data.copy_(
            torch.clamp(self.adj_changes.data, min=0, max=n_perturbations / n))
        return 0, 0, 0

    def projection(self, n_perturbations):
        l, r, m = 0, 0, 0
        if torch.clamp(self.adj_changes, 0, 1).sum() > n_perturbations:
            left = (self.adj_changes).min()
            right = self.adj_changes.max()
            miu = self.bisection(left, right, n_perturbations, epsilon=1e-5)
            l = left.cpu().detach()
            r = right.cpu().detach()
            m = miu.cpu().detach()
            self.adj_changes.data.copy_(
                torch.clamp(self.adj_changes.data - miu, min=0, max=1))
        else:
            self.adj_changes.data.copy_(
                torch.clamp(self.adj_changes.data, min=0, max=1))

        return l, r, m

    def _loss(self, output, labels):
        if self.loss_type == "CE":
            loss = F.nll_loss(output, labels)
        elif self.loss_type == "CW":
            onehot = utils.tensor2onehot(labels)
            best_second_class = (output - 1000 * onehot).argmax(1).detach()
            margin = output[np.arange(len(output)), labels] - \
                     output[np.arange(len(output)), best_second_class]
            loss = -torch.clamp(margin, min=0).mean()
        else:
            raise ValueError(f"loss_type {self.loss_type!r} is not supported")
        return loss

    def bisection(self, a, b, n_perturbations, epsilon):
        def func(x):
            return torch.clamp(self.adj_changes - x, 0,
                               1).sum() - n_perturbations

        miu = a
        while (b - a) >= epsilon:
            miu = (a + b) / 2
            # Stop if the midpoint is a root.
            if func(miu) == 0.0:
                break
            # Otherwise keep the half-interval whose endpoints bracket the root.
            if func(miu) * func(a) < 0:
                b = miu
            else:
                a = miu
        return miu
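The bisection above is the inner loop of a projection onto the budget set {s : 0 <= s <= 1, sum(s) <= n_perturbations}: it searches for the shift miu at which the clipped entries exactly exhaust the budget. A self-contained NumPy sketch of the same projection (function and variable names are illustrative, not from the class above):

import numpy as np

def project_onto_budget(s, budget, eps=1e-5):
    """Project s onto {x : 0 <= x <= 1, sum(x) <= budget} by bisecting on a shift mu."""
    if np.clip(s, 0, 1).sum() <= budget:
        return np.clip(s, 0, 1)
    lo, hi = s.min(), s.max()
    while hi - lo >= eps:
        mu = (lo + hi) / 2
        # clip(s - mu, 0, 1).sum() is non-increasing in mu
        if np.clip(s - mu, 0, 1).sum() > budget:
            lo = mu
        else:
            hi = mu
    return np.clip(s - (lo + hi) / 2, 0, 1)

s = 2 * np.random.rand(10)
print(project_onto_budget(s, 3.0).sum())  # approximately 3.0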
 def __init__(self, w0=30.0, name=None):
     super(SIREN, self).__init__()
     self._built = True
     self.w0 = Parameter(data=to_tensor(w0, requires_grad=True))
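The snippet above only registers w0 as a trainable Parameter. As a rough self-contained sketch (not the original SIREN class), a learnable frequency of this kind is typically applied inside a sine activation:

import torch
import torch.nn as nn

class Sine(nn.Module):
    """sin(w0 * x) with a trainable frequency, in the spirit of SIREN."""
    def __init__(self, w0=30.0):
        super().__init__()
        self.w0 = nn.Parameter(torch.tensor(float(w0)))

    def forward(self, x):
        return torch.sin(self.w0 * x)

y = Sine()(torch.randn(4, 8))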
Exemple #47
0
class CBatchNorm2d(nn.Module):
    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
        buffer_num=0,
        rho=1.0,
        burnin=0,
        two_stage=True,
        FROZEN=False,
        out_p=False,
    ):
        super(CBatchNorm2d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats

        self.buffer_num = buffer_num
        self.max_buffer_num = buffer_num
        self.rho = rho
        self.burnin = burnin
        self.two_stage = two_stage
        self.FROZEN = FROZEN
        self.out_p = out_p

        self.iter_count = 0
        self.pre_mu = []
        self.pre_meanx2 = []  # mean(x^2)
        self.pre_dmudw = []
        self.pre_dmeanx2dw = []
        self.pre_weight = []
        self.ones = torch.ones(self.num_features).cuda()

        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        if self.track_running_stats:
            self.register_buffer("running_mean", torch.zeros(num_features))
            self.register_buffer("running_var", torch.ones(num_features))
        else:
            self.register_parameter("running_mean", None)
            self.register_parameter("running_var", None)
        self.reset_parameters()

    def reset_parameters(self):
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
        if self.affine:
            self.weight.data.uniform_()
            self.bias.data.zero_()

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError("expected 4D input (got {}D input)".format(
                input.dim()))

    def _update_buffer_num(self):
        if self.two_stage:
            if self.iter_count > self.burnin:
                self.buffer_num = self.max_buffer_num
            else:
                self.buffer_num = 0
        else:
            self.buffer_num = int(self.max_buffer_num *
                                  min(self.iter_count / self.burnin, 1.0))

    def forward(self, input, weight):
        # deal with weight and the gradient buffers (self.pre_dmudw, self.pre_dmeanx2dw)
        self._check_input_dim(input)
        y = input.transpose(0, 1)
        return_shape = y.shape
        y = y.contiguous().view(input.size(1), -1)

        # burnin
        if self.training and self.burnin > 0:
            self.iter_count += 1
            self._update_buffer_num()

        if (self.buffer_num > 0 and self.training
                and input.requires_grad):  # some layers are frozen!
            # cal current batch mu and sigma
            cur_mu = y.mean(dim=1)
            cur_meanx2 = torch.pow(y, 2).mean(dim=1)
            cur_sigma2 = y.var(dim=1)
            # cal dmu/dw dsigma2/dw
            dmudw = torch.autograd.grad(cur_mu,
                                        weight,
                                        self.ones,
                                        retain_graph=True)[0]
            dmeanx2dw = torch.autograd.grad(cur_meanx2,
                                            weight,
                                            self.ones,
                                            retain_graph=True)[0]
            # update cur_mu and cur_sigma2 with pres
            mu_all = torch.stack([
                cur_mu,
            ] + [
                tmp_mu + (self.rho * tmp_d *
                          (weight.data - tmp_w)).sum(1).sum(1).sum(1)
                for tmp_mu, tmp_d, tmp_w in zip(self.pre_mu, self.pre_dmudw,
                                                self.pre_weight)
            ])
            meanx2_all = torch.stack([
                cur_meanx2,
            ] + [
                tmp_meanx2 + (self.rho * tmp_d *
                              (weight.data - tmp_w)).sum(1).sum(1).sum(1)
                for tmp_meanx2, tmp_d, tmp_w in zip(
                    self.pre_meanx2, self.pre_dmeanx2dw, self.pre_weight)
            ])
            sigma2_all = meanx2_all - torch.pow(mu_all, 2)

            # with considering count
            re_mu_all = mu_all.clone()
            re_meanx2_all = meanx2_all.clone()
            re_mu_all[sigma2_all < 0] = 0
            re_meanx2_all[sigma2_all < 0] = 0
            count = (sigma2_all >= 0).sum(dim=0).float()
            mu = re_mu_all.sum(dim=0) / count
            sigma2 = re_meanx2_all.sum(dim=0) / count - torch.pow(mu, 2)

            self.pre_mu = [
                cur_mu.detach(),
            ] + self.pre_mu[:(self.buffer_num - 1)]
            self.pre_meanx2 = [
                cur_meanx2.detach(),
            ] + self.pre_meanx2[:(self.buffer_num - 1)]
            self.pre_dmudw = [
                dmudw.detach(),
            ] + self.pre_dmudw[:(self.buffer_num - 1)]
            self.pre_dmeanx2dw = [
                dmeanx2dw.detach(),
            ] + self.pre_dmeanx2dw[:(self.buffer_num - 1)]

            tmp_weight = torch.zeros_like(weight.data)
            tmp_weight.copy_(weight.data)
            self.pre_weight = [
                tmp_weight.detach(),
            ] + self.pre_weight[:(self.buffer_num - 1)]

        else:
            x = y
            mu = x.mean(dim=1)
            cur_mu = mu
            sigma2 = x.var(dim=1)
            cur_sigma2 = sigma2

        if not self.training or self.FROZEN:
            y = y - self.running_mean.view(-1, 1)
            # TODO: outside **0.5?
            if self.out_p:
                y = y / (self.running_var.view(-1, 1) + self.eps)**0.5
            else:
                y = y / (self.running_var.view(-1, 1)**0.5 + self.eps)

        else:
            if self.track_running_stats is True:
                with torch.no_grad():
                    self.running_mean = (
                        1 - self.momentum
                    ) * self.running_mean + self.momentum * cur_mu
                    self.running_var = (
                        1 - self.momentum
                    ) * self.running_var + self.momentum * cur_sigma2
            y = y - mu.view(-1, 1)
            # TODO: outside **0.5?
            if self.out_p:
                y = y / (sigma2.view(-1, 1) + self.eps)**0.5
            else:
                y = y / (sigma2.view(-1, 1)**0.5 + self.eps)

        y = self.weight.view(-1, 1) * y + self.bias.view(-1, 1)
        return y.view(return_shape).transpose(0, 1)

    def extra_repr(self):
        return (
            "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
            "buffer={max_buffer_num}, burnin={burnin}, "
            "track_running_stats={track_running_stats}".format(
                **self.__dict__))
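Unlike nn.BatchNorm2d, the forward pass above takes the preceding convolution's weight as a second argument, so that torch.autograd.grad can estimate how the batch statistics shift as the weights move (the cross-iteration compensation). A hedged wiring sketch; ConvCBN is an illustrative name, and a CUDA device is assumed because the class pins self.ones with .cuda():

import torch.nn as nn
import torch.nn.functional as F

class ConvCBN(nn.Module):
    def __init__(self, cin, cout):
        super().__init__()
        self.conv = nn.Conv2d(cin, cout, 3, padding=1, bias=False)
        self.bn = CBatchNorm2d(cout, buffer_num=3, burnin=20)

    def forward(self, x):
        # the BN layer needs the conv weight to compute d(mu)/d(w) internally
        return F.relu(self.bn(self.conv(x), self.conv.weight))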
Exemple #48
0
 def __init__(self, ave_source_num, feature_size=0, bias=False):
     super(Ave_multi_view, self).__init__()
     self.ave_source_num = ave_source_num
     self.feature_size = feature_size
     self.weight = Parameter(FloatTensor(ave_source_num))
     self.reset_parameters()
Exemple #49
0
class Linear(torch.nn.Module):
    r"""Applies a linear tranformation to the incoming data

    .. math::
        \mathbf{x}^{\prime} = \mathbf{x} \mathbf{W}^{\top} + \mathbf{b}

    similar to :class:`torch.nn.Linear`.
    It supports lazy initialization and customizable weight and bias
    initialization.

    Args:
        in_channels (int): Size of each input sample. Will be initialized
            lazily in case it is given as :obj:`-1`.
        out_channels (int): Size of each output sample.
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        weight_initializer (str, optional): The initializer for the weight
            matrix (:obj:`"glorot"`, :obj:`"uniform"`, :obj:`"kaiming_uniform"`
            or :obj:`None`).
            If set to :obj:`None`, will match default weight initialization of
            :class:`torch.nn.Linear`. (default: :obj:`None`)
        bias_initializer (str, optional): The initializer for the bias vector
            (:obj:`"zeros"` or :obj:`None`).
            If set to :obj:`None`, will match default bias initialization of
            :class:`torch.nn.Linear`. (default: :obj:`None`)

    Shapes:
        - **input:** features :math:`(*, F_{in})`
        - **output:** features :math:`(*, F_{out})`
    """
    def __init__(self, in_channels: int, out_channels: int, bias: bool = True,
                 weight_initializer: Optional[str] = None,
                 bias_initializer: Optional[str] = None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight_initializer = weight_initializer
        self.bias_initializer = bias_initializer

        if in_channels > 0:
            self.weight = Parameter(torch.Tensor(out_channels, in_channels))
        else:
            self.weight = nn.parameter.UninitializedParameter()
            self._hook = self.register_forward_pre_hook(
                self.initialize_parameters)

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self._load_hook = self._register_load_state_dict_pre_hook(
            self._lazy_load_hook)

        self.reset_parameters()

    def __deepcopy__(self, memo):
        out = Linear(self.in_channels, self.out_channels, self.bias
                     is not None, self.weight_initializer,
                     self.bias_initializer)
        if self.in_channels > 0:
            out.weight = copy.deepcopy(self.weight, memo)
        if self.bias is not None:
            out.bias = copy.deepcopy(self.bias, memo)
        return out

    def reset_parameters(self):
        if isinstance(self.weight, nn.parameter.UninitializedParameter):
            pass
        elif self.weight_initializer == 'glorot':
            inits.glorot(self.weight)
        elif self.weight_initializer == 'uniform':
            bound = 1.0 / math.sqrt(self.weight.size(-1))
            torch.nn.init.uniform_(self.weight.data, -bound, bound)
        elif self.weight_initializer == 'kaiming_uniform':
            inits.kaiming_uniform(self.weight, fan=self.in_channels,
                                  a=math.sqrt(5))
        elif self.weight_initializer is None:
            inits.kaiming_uniform(self.weight, fan=self.in_channels,
                                  a=math.sqrt(5))
        else:
            raise RuntimeError(f"Linear layer weight initializer "
                               f"'{self.weight_initializer}' is not supported")

        if isinstance(self.weight, nn.parameter.UninitializedParameter):
            pass
        elif self.bias is None:
            pass
        elif self.bias_initializer == 'zeros':
            inits.zeros(self.bias)
        elif self.bias_initializer is None:
            inits.uniform(self.in_channels, self.bias)
        else:
            raise RuntimeError(f"Linear layer bias initializer "
                               f"'{self.bias_initializer}' is not supported")

    def forward(self, x: Tensor) -> Tensor:
        r"""
        Args:
            x (Tensor): The features.
        """
        return F.linear(x, self.weight, self.bias)

    @torch.no_grad()
    def initialize_parameters(self, module, input):
        if isinstance(self.weight, nn.parameter.UninitializedParameter):
            self.in_channels = input[0].size(-1)
            self.weight.materialize((self.out_channels, self.in_channels))
            self.reset_parameters()
        self._hook.remove()
        delattr(self, '_hook')

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        if isinstance(self.weight, nn.parameter.UninitializedParameter):
            destination[prefix + 'weight'] = self.weight
        else:
            destination[prefix + 'weight'] = self.weight.detach()
        if self.bias is not None:
            destination[prefix + 'bias'] = self.bias.detach()

    def _lazy_load_hook(self, state_dict, prefix, local_metadata, strict,
                        missing_keys, unexpected_keys, error_msgs):

        weight = state_dict[prefix + 'weight']
        if isinstance(weight, nn.parameter.UninitializedParameter):
            self.in_channels = -1
            self.weight = nn.parameter.UninitializedParameter()
            if not hasattr(self, '_hook'):
                self._hook = self.register_forward_pre_hook(
                    self.initialize_parameters)

        elif isinstance(self.weight, nn.parameter.UninitializedParameter):
            self.in_channels = weight.size(-1)
            self.weight.materialize((self.out_channels, self.in_channels))
            if hasattr(self, '_hook'):
                self._hook.remove()
                delattr(self, '_hook')

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, bias={self.bias is not None})')
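When in_channels is given as -1, the weight above is an UninitializedParameter that the forward pre-hook materializes from the first input's last dimension. A small usage sketch of the lazy path:

import torch

lin = Linear(-1, 32, weight_initializer='glorot')  # in_channels inferred lazily
x = torch.randn(10, 64)
out = lin(x)             # first call materializes weight as (32, 64), then resets it
print(lin.in_channels)   # 64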
Exemple #50
0
    def __init__(self,
                 in_channels,
                 out_channels,
                 bound_min=8,
                 bound_max=16,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 dilation=1,
                 groups=1,
                 deformable_groups=1,
                 bias=True):
        super(DeformConvWithOffsetScaleGaussBoundMinMaxShared, self).__init__()

        assert in_channels % groups == 0, 'in_channels must be divisible by groups'
        assert out_channels % groups == 0, 'out_channels must be divisible by groups'
        assert out_channels % deformable_groups == 0, 'out_channels must be divisible by deformable groups'

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = groups
        self.deformable_groups = deformable_groups

        self.weight = Parameter(
            torch.Tensor(self.out_channels, self.in_channels // self.groups,
                         *self.kernel_size).cuda())
        if bias:
            self.bias = Parameter(torch.Tensor(self.out_channels).cuda())
        else:
            self.register_parameter('bias', None)

        stdv = 1. / math.sqrt(self.in_channels * kernel_size * kernel_size)
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

        self.anchor_default = torch.FloatTensor(
            [-1, -1, -1, 0, -1, 1, 0, -1, 0, 0, 0, 1, 1, -1, 1, 0, 1,
             1]).unsqueeze(0).unsqueeze(2).unsqueeze(2)

        self.anchor_gauss = torch.FloatTensor([
            -0.7071, -0.7071, -1, 0, -0.7071, 0.7071, 0, -1, 0, 0, 0, 1,
            0.7071, -0.7071, 1, 0, 0.7071, 0.7071
        ]).unsqueeze(0).unsqueeze(2).unsqueeze(2)

        self.conv_scale = nn.Conv2d(in_channels,
                                    2,
                                    kernel_size=3,
                                    stride=1,
                                    padding=1,
                                    bias=True)
        self.conv_scale.weight.data.zero_()
        # self.conv_scale.bias.data.zero_()
        nn.init.constant_(self.conv_scale.bias.data, 1)
        self.bmin = torch.nn.Hardtanh(min_val=0,
                                      max_val=bound_min,
                                      inplace=True)
        self.bmax = torch.nn.Hardtanh(min_val=bound_min,
                                      max_val=bound_max,
                                      inplace=True)
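With the defaults above (bound_min=8, bound_max=16), bmin and bmax act as chained clamps on the predicted scales: the first confines values to [0, 8], the second to [8, 16]. A standalone demonstration of the two Hardtanh clamps (without inplace, which is irrelevant here):

import torch

bmin = torch.nn.Hardtanh(min_val=0, max_val=8)
bmax = torch.nn.Hardtanh(min_val=8, max_val=16)

s = torch.tensor([-2.0, 4.0, 12.0, 30.0])
print(bmin(s))  # tensor([0., 4., 8., 8.])
print(bmax(s))  # tensor([8., 8., 12., 16.])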
Exemple #51
0
class ABAE(torch.nn.Module):
    """
        The model described in the paper ``An Unsupervised Neural Attention Model for Aspect Extraction''
        by He, Ruidan and  Lee, Wee Sun  and  Ng, Hwee Tou  and  Dahlmeier, Daniel, ACL2017
        https://aclweb.org/anthology/papers/P/P17/P17-1036/

    """
    def __init__(self,
                 wv_dim: int = 200,
                 asp_count: int = 30,
                 ortho_reg: float = 0.1,
                 maxlen: int = 201,
                 init_aspects_matrix=None):
        """
        Initializing the model

        :param wv_dim: word vector size
        :param asp_count: number of aspects
        :param ortho_reg: coefficient for tuning the ortho-regularizer's influence
        :param maxlen: sentence max length taken into account
        :param init_aspects_matrix: None or init. matrix for aspects
        """
        super(ABAE, self).__init__()
        self.wv_dim = wv_dim
        self.asp_count = asp_count
        self.ortho = ortho_reg
        self.maxlen = maxlen

        self.attention = SelfAttention(wv_dim, maxlen)
        self.linear_transform = torch.nn.Linear(self.wv_dim, self.asp_count)
        self.softmax_aspects = torch.nn.Softmax(dim=-1)
        self.aspects_embeddings = Parameter(
            torch.empty(size=(wv_dim, asp_count)))

        if init_aspects_matrix is None:
            torch.nn.init.xavier_uniform_(self.aspects_embeddings)
        else:
            self.aspects_embeddings.data = torch.from_numpy(
                init_aspects_matrix.T)

    def get_aspects_importances(self, text_embeddings):
        """
            Takes embeddings of a sentence as input, returns attention weights
        """

        # compute attention scores, looking at text embeddings average
        attention_weights = self.attention(text_embeddings)

        # multiplying text embeddings by attention scores -- and summing
        # (matmul: we sum every word embedding's coordinate with attention weights)
        weighted_text_emb = torch.matmul(
            attention_weights.unsqueeze(1),  # (batch, 1, sentence)
            text_embeddings  # (batch, sentence, wv_dim)
        ).squeeze()

        # encoding with a simple feed-forward layer (wv_dim) -> (aspects_count)
        raw_importances = self.linear_transform(weighted_text_emb)

        # computing 'aspects distribution in a sentence'
        aspects_importances = self.softmax_aspects(raw_importances)

        return attention_weights, aspects_importances, weighted_text_emb

    def forward(self, text_embeddings, negative_samples_texts):

        # negative samples are averaged
        averaged_negative_samples = torch.mean(negative_samples_texts, dim=2)

        # encoding: words embeddings -> sentence embedding, aspects importances
        _, aspects_importances, weighted_text_emb = self.get_aspects_importances(
            text_embeddings)

        # decoding: aspects embeddings matrix, aspects_importances -> recovered sentence embedding
        recovered_emb = torch.matmul(
            self.aspects_embeddings,
            aspects_importances.unsqueeze(2)).squeeze()

        # loss
        reconstruction_triplet_loss = ABAE._reconstruction_loss(
            weighted_text_emb, recovered_emb, averaged_negative_samples)
        max_margin = torch.max(reconstruction_triplet_loss,
                               torch.zeros_like(reconstruction_triplet_loss))

        return self.ortho * self._ortho_regularizer() + max_margin

    @staticmethod
    def _reconstruction_loss(text_emb, recovered_emb, averaged_negative_emb):

        positive_dot_products = torch.matmul(
            text_emb.unsqueeze(1), recovered_emb.unsqueeze(2)).squeeze()
        negative_dot_products = torch.matmul(
            averaged_negative_emb, recovered_emb.unsqueeze(2)).squeeze()
        reconstruction_triplet_loss = torch.sum(
            1 - positive_dot_products.unsqueeze(1) + negative_dot_products,
            dim=1)

        return reconstruction_triplet_loss

    def _ortho_regularizer(self):
        return torch.norm(
            torch.matmul(self.aspects_embeddings.t(), self.aspects_embeddings) \
            - torch.eye(self.asp_count))

    def get_aspect_words(self, w2v_model, topn=15):
        words = []

        # getting aspects embeddings
        aspects = self.aspects_embeddings.detach().numpy()

        # getting scalar products of word embeddings and aspect embeddings;
        # to obtain the ``probabilities'', one should also apply softmax
        words_scores = w2v_model.wv.syn0.dot(aspects)

        for row in range(aspects.shape[1]):
            argmax_scalar_products = np.argsort(-words_scores[:, row])[:topn]
            # print([w2v_model.wv.index2word[i] for i in argmax_scalar_products])
            # print([w for w, dist in w2v_model.similar_by_vector(aspects.T[row])[:topn]])
            words.append(
                [w2v_model.wv.index2word[i] for i in argmax_scalar_products])

        return words
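forward above already returns the per-sentence max-margin loss plus the orthogonality penalty, so training only needs to average and backpropagate it. A hedged loop sketch, assuming the SelfAttention module referenced by the class is available and using random tensors in place of real embeddings:

import torch

model = ABAE(wv_dim=200, asp_count=30)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

text = torch.randn(16, 201, 200)           # (batch, sentence, wv_dim)
negatives = torch.randn(16, 20, 201, 200)  # (batch, n_neg, sentence, wv_dim)

loss = model(text, negatives).mean()
opt.zero_grad()
loss.backward()
opt.step()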
Exemple #52
0
    def __init__(self, feats, k):
        super().__init__()
        self.scorer = Parameter(torch.Tensor(feats, 1))
        self.reset_param(self.scorer)

        self.k = k
Exemple #53
0
class GHConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True):

        # Init torch module
        super(GHConv2d, self).__init__()

        # Init conv params
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = stride
        self.padding = padding
        self.dilation = dilation

        # init constants according to section 5
        self.t0 = 1e-5

        # Init globals
        self.sa_mu = Parameter(Tensor(1))
        self.sa_logvar = Parameter(Tensor(1))
        self.sb_mu = Parameter(Tensor(1))
        self.sb_logvar = Parameter(Tensor(1))

        # Filter locals
        self.alpha_mu = Parameter(Tensor(out_channels))
        self.alpha_logvar = Parameter(Tensor(out_channels))
        self.beta_mu = Parameter(Tensor(out_channels))
        self.beta_logvar = Parameter(Tensor(out_channels))

        # Weight local
        self.weight_mu = Parameter(Tensor(out_channels, in_channels, *self.kernel_size))
        self.weight_logvar = Parameter(Tensor(out_channels, in_channels, *self.kernel_size))

        # Bias local if required
        self.bias = bias
        self.bias_mu = Parameter(Tensor(out_channels)) if self.bias else None
        self.bias_logvar = Parameter(Tensor(out_channels)) if self.bias else None

        # Set initial parameters
        self._init_params()

        # for brevity to conv2d calls
        self.convargs = [self.stride, self.padding, self.dilation]
    def _s_mu(self):
        return 0.5 * (self.sa_mu + self.sb_mu)

    def _s_var(self):
        return 0.25 * (self.sa_logvar.exp() + self.sb_logvar.exp())

    def _z_var(self):
        return 0.25 * (self.alpha_logvar.exp() + self.beta_logvar.exp())

    def _z_mu(self):
        return 0.5 * (self.alpha_mu + self.beta_mu)

    def forward(self, x):

        # vanilla forward pass if testing
        if not self.training:
            expect_z = torch.exp(0.5 * (self._z_var() + self._s_var()) + self._z_mu() + self._s_mu())
            post_weight_mu = self.weight_mu * expect_z[:, None, None, None]
            post_bias_mu = self.bias_mu * expect_z if (self.bias_mu is not None) else None
            return conv2d(x, post_weight_mu, post_bias_mu, *self.convargs)

        # compute global shrinkage
        s_mu = 0.5 * (self.sa_mu + self.sb_mu)
        s_sig = torch.sqrt(self._s_var())
        s = LogNormal(s_mu, s_sig).rsample()

        # compute filter scales
        z_mu = self._z_mu()
        z_var = self._z_var()
        z = s * LogNormal(z_mu, z_var.sqrt()).rsample()[None, :, None, None]


        # lognormal out params, local reparameterization trick
        bvar = self.bias_logvar.exp() if self.bias else None
        mu_out = conv2d(x, self.weight_mu, self.bias_mu, *self.convargs) * z
        scale_out = conv2d(x**2, self.weight_logvar.exp(), bvar, *self.convargs) * (z ** 2)

        # compute output weight distribution, again reparameterised
        dist_out = Normal(mu_out, scale_out.sqrt()).rsample()

        # return fully reparameterised forward pass
        return dist_out


    def _init_params(self, weight=None, bias=None):

        # initialisation params - note mean of lognormal is exp(mu + 0.5 *var)
        init_mu_logvar, init_mu, init_var = -9, 0., 1e-2

        # compute xavier initialisation on weights
        n = self.in_channels * self.kernel_size[0] * self.kernel_size[1]
        thresh = 1/math.sqrt(n)

        if weight is not None:
            self.weight_mu.data = weight
        else:
            self.weight_mu.data.uniform_(-thresh, thresh)

        # init variance according to appendix A
        self.weight_logvar.data.normal_(init_mu_logvar, init_var)

        if self.bias:
            if bias is not None:
                self.bias_mu.data = bias
            else:
                self.bias_mu.data.fill_(0)

            # biases
            self.bias_logvar.data.normal_(init_mu_logvar, init_var)

        # Decomposed prior means => E[z_init] init ~ 1
        self.alpha_mu.data.normal_(init_mu, init_var)
        self.beta_mu.data.normal_(init_mu, init_var)
        self.sa_mu.data.normal_(init_mu, init_var)
        self.sb_mu.data.normal_(init_mu, init_var)

        # Decomposed prior variances
        self.alpha_logvar.data.normal_(init_mu_logvar, init_var)
        self.beta_logvar.data.normal_(init_mu_logvar, init_var)
        self.sa_logvar.data.normal_(init_mu_logvar, init_var)
        self.sb_logvar.data.normal_(init_mu_logvar, init_var)


    # KL div for GNH with lognormal scale, normal weight variational posterior
    def kl_divergence(self):
        # negative kls, eqns (34-37)
        neg_kl_s = self._global_negative_kl()
        neg_kl_ab = self._filter_local_negative_kl()

        # weight/bias local
        kl_w = self._conditional_kl_div(self.weight_mu, self.weight_logvar)

        if self.bias:
            kl_b = self._conditional_kl_div(self.bias_mu, self.bias_logvar)
        else:
            kl_b = 0

        return kl_w + kl_b - (neg_kl_s + neg_kl_ab)


    def _global_negative_kl(self):

        # hyperparams
        t0 = self.t0

        # const added in every kl div
        c = 1 + math.log(2)

        # shape/scale of global scale parameters
        sa_mu, sb_mu = self.sa_mu, self.sb_mu
        sa_var, sb_var = self.sa_logvar.exp(), self.sb_logvar.exp()

        # Eqns (34)(35)
        kl_sa = math.log(t0) - torch.exp(sa_mu + 0.5 * sa_var)/t0 + 0.5 * (sa_mu + self.sa_logvar + c)
        kl_sb = 0.5 * (self.sb_logvar - sb_mu + c ) - torch.exp(0.5 * sb_var - sb_mu)

        return kl_sa + kl_sb


    def _filter_local_negative_kl(self):

        # const added in every kl div
        c = 1 + math.log(2)

        # hyperparams
        t0 = self.t0

        # filter level shape/scale parameters
        alpha_mu, beta_mu = self.alpha_mu, self.beta_mu
        alpha_logvar, beta_logvar = self.alpha_logvar, self.beta_logvar

        # Eqns (36)(37)
        kl_alpha = torch.sum(0.5 * (alpha_mu + alpha_logvar + c) - torch.exp(alpha_mu + 0.5 * alpha_logvar.exp()))
        kl_beta = torch.sum(0.5 * (beta_logvar - beta_mu + c) - torch.exp(0.5 * beta_logvar.exp() - beta_mu))

        return kl_alpha + kl_beta


    @staticmethod
    def _conditional_kl_div(mu, logvar):
        # eqn (8)
        kl_div = -0.5 * logvar + 0.5 * (logvar.exp() + mu ** 2 - 1)
        return torch.sum(kl_div)
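Each GHConv2d exposes its own kl_divergence(), so the training objective is assembled by summing the KL terms of all such layers into the data-fit loss, typically scaled by the number of training points. A minimal sketch, assuming a classifier built from GHConv2d layers (elbo_loss and the 1/N scaling are illustrative):

import torch.nn.functional as F

def elbo_loss(model, x, target, n_train):
    # negative ELBO: data NLL plus the KL of every horseshoe conv layer
    nll = F.cross_entropy(model(x), target)
    kl = sum(m.kl_divergence() for m in model.modules()
             if isinstance(m, GHConv2d))
    return nll + kl / n_train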
Exemple #54
0
class LinearSimilarityVB(SimilarityFunction):
    """
    This similarity function performs a dot product between a vector of weights and some
    combination of the two input vectors, followed by an (optional) activation function.  The
    combination used is configurable.
    If the two vectors are ``x`` and ``y``, we allow the following kinds of combinations: ``x``,
    ``y``, ``x*y``, ``x+y``, ``x-y``, ``x/y``, where each of those binary operations is performed
    elementwise.  You can list as many combinations as you want, comma separated.  For example, you
    might give ``x,y,x*y`` as the ``combination`` parameter to this class.  The computed similarity
    function would then be ``w^T [x; y; x*y] + b``, where ``w`` is a vector of weights, ``b`` is a
    bias parameter, and ``[;]`` is vector concatenation.
    Note that if you want a bilinear similarity function with a diagonal weight matrix W, where the
    similarity function is computed as `x * w * y + b` (with `w` the diagonal of `W`), you can
    accomplish that with this class by using "x*y" for `combination`.
    Parameters
    ----------
    tensor_1_dim : ``int``
        The dimension of the first tensor, ``x``, described above.  This is ``x.size()[-1]`` - the
        length of the vector that will go into the similarity computation.  We need this so we can
        build weight vectors correctly.
    tensor_2_dim : ``int``
        The dimension of the second tensor, ``y``, described above.  This is ``y.size()[-1]`` - the
        length of the vector that will go into the similarity computation.  We need this so we can
        build weight vectors correctly.
    combination : ``str``, optional (default="x,y")
        Described above.
    activation : ``Activation``, optional (default=linear (i.e. no activation))
        An activation function applied after the ``w^T * [x;y] + b`` calculation.  Default is no
        activation.
    """
    def __init__(self,
                 tensor_1_dim: int,
                 tensor_2_dim: int,
                 combination: str = 'x,y',
                 activation: Activation = None,
                 prior = None) -> None:
        super(LinearSimilarityVB, self).__init__()
        self._combination = combination
        combined_dim = util.get_combined_dim(combination, [tensor_1_dim, tensor_2_dim])
        
        self.posterior_mean = False # Flag to know if we sample from the posterior mean or we actually sample
        
        ## If no prior is specified we just create it ourselves
        if prior is None:
            prior = Vil.Prior(0.5, np.log(0.1), np.log(0.5))

        # The combined feature vector has length combined_dim.
        size_combination = combined_dim
        prior = prior.get_standarized_Prior(size_combination)
        self.prior = prior
        
        """
        Mean and rhos of the parameters
        """
        self.mu_weight = Parameter(torch.Tensor(combined_dim))# , requires_grad=True
        self.rho_weight = Parameter(torch.Tensor(combined_dim))

        self.rho_bias = Parameter(torch.Tensor(1))
        self.mu_bias = Parameter(torch.Tensor(1))
            
        """
        The sampled weights
        """
        self.weight = torch.Tensor(combined_dim)
        self.bias = torch.Tensor(1)
        
        self._activation = activation or Activation.by_name('linear')()
        
        ## Initialize the Variational variables
        self.reset_parameters()
#        self.sample_posterior()

    def reset_parameters(self):
        self.rho_weight.data = Vil.init_rho(self.mu_weight.size(), self.prior)
        self.rho_bias.data = Vil.init_rho(self.mu_bias.size(), self.prior)
        
        ## Now initialize the mean
        self.mu_weight.data = Vil.init_mu(self.mu_weight.size(), 
                                          self.prior,Ninput = self.mu_weight.size(0), type = "LinearSimilarity")
        
        self.mu_bias.data = Vil.init_mu(self.mu_bias.size(), self.prior, Ninput = self.mu_weight.size(0), type = "LinearSimilarity")
        
    def sample_posterior(self):
        """
        Samples the Bayesian weights from the variational parameters and stores
        them in self.weight / self.bias, using the reparametrization trick so
        that we can differentiate with respect to sigma and mu.
        """
        if not self.posterior_mean:
            self.weight = Vil.sample_posterior(self.mu_weight, Vil.softplus(self.rho_weight))
            self.bias = Vil.sample_posterior(self.mu_bias, Vil.softplus(self.rho_bias))
        else:
            self.weight.data = self.mu_weight.data
            self.bias.data = self.mu_bias.data
                
    @overrides
    def forward(self, tensor_1: torch.Tensor, tensor_2: torch.Tensor) -> torch.Tensor:
        combined_tensors = util.combine_tensors(self._combination, [tensor_1, tensor_2])
        dot_product = torch.matmul(combined_tensors, self.weight)
        return self._activation(dot_product + self.bias)

    def get_KL_divergence(self):
        """
        Computes the KL loss for all the variational weights in this layer.
        It does not sample the weights again; it uses the ones already sampled.
        """
        KL_loss_W = Vil.get_KL_divergence_Samples(self.mu_weight, Vil.softplus(self.rho_weight), self.weight, self.prior)
        KL_loss_b = 0
        if self.bias is not None:
            KL_loss_b = Vil.get_KL_divergence_Samples(self.mu_bias, Vil.softplus(self.rho_bias), self.bias,  self.prior)
            
        KL_loss = KL_loss_W + KL_loss_b
        
        return KL_loss
    
    """
    Flag to set that we actually get the posterior mean and not a sample from the random variables
    """
    def set_posterior_mean(self, posterior_mean):
        self.posterior_mean = posterior_mean
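Like the other variational layers in this collection, the layer must re-sample its weights before each stochastic forward pass, and its KL term is added to the training objective. A minimal usage sketch; t1, t2, task_loss and n_train are assumptions, not names from the source:

sim = LinearSimilarityVB(tensor_1_dim=128, tensor_2_dim=128, combination='x,y,x*y')

sim.sample_posterior()        # draw weights with the reparametrization trick
scores = sim(t1, t2)          # t1, t2: tensors with last dim 128 (assumed)
loss = task_loss(scores) + sim.get_KL_divergence() / n_train

sim.set_posterior_mean(True)  # evaluation: use posterior means instead of samples
sim.sample_posterior()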
        
Exemple #56
0
class LSTMCellVB(RNNCellBase):

    def __init__(self, input_size, hidden_size, bias=True, prior = None):
        super(LSTMCellVB, self).__init__()
        self.type_layer = "LSTM"
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.posterior_mean = False # Flag to know if we sample from the posterior mean or we actually sample
        
        ## If no prior is specified we just create it ourselves
        if prior is None:
            prior = Vil.Prior(0.5, np.log(0.1), np.log(0.5))
        self.prior = prior
        
        """
        Variational Inference Parameters
        """
        self.mu_weight_ih = Parameter(torch.Tensor(4 * hidden_size, input_size).to(device = Vil.device, dtype = Vil.dtype))
        self.mu_weight_hh = Parameter(torch.Tensor(4 * hidden_size, hidden_size).to(device = Vil.device, dtype = Vil.dtype))
        self.rho_weight_ih = Parameter(torch.Tensor(4 * hidden_size, input_size).to(device = Vil.device, dtype = Vil.dtype))
        self.rho_weight_hh = Parameter(torch.Tensor(4 * hidden_size, hidden_size).to(device = Vil.device, dtype = Vil.dtype))
        if bias:
            self.mu_bias_ih = Parameter(torch.Tensor(4 * hidden_size).to(device = Vil.device, dtype = Vil.dtype))
            self.mu_bias_hh = Parameter(torch.Tensor(4 * hidden_size).to(device = Vil.device, dtype = Vil.dtype))
            self.rho_bias_ih = Parameter(torch.Tensor(4 * hidden_size).to(device = Vil.device, dtype = Vil.dtype))
            self.rho_bias_hh = Parameter(torch.Tensor(4 * hidden_size).to(device = Vil.device, dtype = Vil.dtype))
        else:
            self.register_parameter('mu_bias_ih', None)
            self.register_parameter('mu_bias_hh', None)
            self.register_parameter('rho_bias_ih', None)
            self.register_parameter('rho_bias_hh', None)
        """
        Sampled weights
        """

        self.weight_ih = torch.Tensor(4 * hidden_size, input_size).to(device = Vil.device, dtype = Vil.dtype)
        self.weight_hh = torch.Tensor(4 * hidden_size, hidden_size).to(device = Vil.device, dtype = Vil.dtype)
        if bias:
            self.bias_ih = torch.Tensor(4 * hidden_size).to(device = Vil.device, dtype = Vil.dtype)
            self.bias_hh = torch.Tensor(4 * hidden_size).to(device = Vil.device, dtype = Vil.dtype)
        else:
            self.register_parameter('bias_ih', None)
            self.register_parameter('bias_hh', None)

        if (0):  # flip to 1 to print device placement for debugging
            print ("linear bias_ih device: ",self.bias_ih.device)
            print ("linear weights_ih device: ",self.weight_ih.device)
            print ("linear bias mu_ih device: ",self.mu_bias_ih.device)
            print ("linear bias rho_ih device: ",self.rho_bias_ih.device)
            
            print ("linear weights mu_ih  device: ",self.mu_weight_ih.device)
            print ("linear weights rho_ih device: ",self.rho_weight_ih.device)
            
            print ("linear bias_hh device: ",self.bias_hh.device)
            print ("linear weights_hh device: ",self.weight_hh.device)
            print ("linear bias mu_hh device: ",self.mu_bias_hh.device)
            print ("linear bias rho_hh device: ",self.rho_bias_hh.device)
            
            print ("linear weights mu_hh  device: ",self.mu_weight_hh.device)
            print ("linear weights rho_hh device: ",self.rho_weight_hh.device)
            
        self.reset_parameters()
        self.sample_posterior()
        
    def reset_parameters(self):
        """
        Initialize the parameters using the prior.
        The variance of the weights depends on the prior!
        TODO: Should it also depend on the dimensionality?
        Should the means be initialized with the usual scheme or from the prior? They could differ.
        """
        
        self.rho_weight_ih.data = Vil.init_rho(self.rho_weight_ih.size(), self.prior)
        self.rho_weight_hh.data = Vil.init_rho(self.rho_weight_hh.size(), self.prior)
        if self.bias is not None:
            self.rho_bias_ih.data = Vil.init_rho(self.rho_bias_ih.size(), self.prior)
            self.rho_bias_hh.data = Vil.init_rho(self.rho_bias_hh.size(), self.prior)
            
        ## Now initialize the mean
        self.mu_weight_ih.data = Vil.init_mu(self.mu_weight_ih.size(), self.prior,Ninput = self.mu_weight_ih.size(1))
        self.mu_weight_hh.data = Vil.init_mu(self.mu_weight_hh.size(), self.prior,Ninput = self.mu_weight_hh.size(1))
        if self.bias is not None:
            self.mu_bias_ih.data = Vil.init_mu(self.mu_bias_ih.size(), self.prior, Ninput = self.mu_weight_ih.size(1))
            self.mu_bias_hh.data = Vil.init_mu(self.mu_bias_hh.size(), self.prior, Ninput = self.mu_weight_hh.size(1))

    def sample_posterior(self):
        """
        Samples the Bayesian weights from the variational parameters and stores
        them in the weight/bias attributes, using the reparametrization trick so
        that we can differentiate with respect to sigma and mu.
        """
        if not self.posterior_mean:
            self.weight_ih = Vil.sample_posterior(self.mu_weight_ih, Vil.softplus(self.rho_weight_ih))
            self.weight_hh = Vil.sample_posterior(self.mu_weight_hh, Vil.softplus(self.rho_weight_hh))
            if self.bias is not None:
                self.bias_ih = Vil.sample_posterior(self.mu_bias_ih, Vil.softplus(self.rho_bias_ih))
                self.bias_hh = Vil.sample_posterior(self.mu_bias_hh, Vil.softplus(self.rho_bias_hh))
        else:
            self.weight_ih.data = self.mu_weight_ih.data
            self.weight_hh.data = self.mu_weight_hh.data
            if self.bias is not None:
                self.bias_hh.data = self.mu_bias_hh.data
                self.bias_ih.data = self.mu_bias_ih.data
                
    def get_KL_divergence(self):
        """
        Computes the KL loss for all the variational weights in this cell.
        It does not sample the weights again; it uses the ones already sampled.
        """
        KL_loss_ih = Vil.get_KL_divergence_Samples(self.mu_weight_ih, Vil.softplus(self.rho_weight_ih), self.weight_ih, self.prior)
        KL_loss_hh = Vil.get_KL_divergence_Samples(self.mu_weight_hh, Vil.softplus(self.rho_weight_hh), self.weight_hh, self.prior)
        
        KL_loss_bih = 0
        KL_loss_bhh = 0
        if self.bias is not None:
            KL_loss_bih = Vil.get_KL_divergence_Samples(self.mu_bias_ih, Vil.softplus(self.rho_bias_ih), self.bias_ih,  self.prior)
            KL_loss_bhh = Vil.get_KL_divergence_Samples(self.mu_bias_hh, Vil.softplus(self.rho_bias_hh), self.bias_hh,  self.prior)        
        KL_loss = KL_loss_ih + KL_loss_hh + KL_loss_bih +KL_loss_bhh
        return KL_loss
    
    def forward(self, input, hx=None):
        self.check_forward_input(input)
        if hx is None:
            hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
            hx = (hx, hx)
        self.check_forward_hidden(input, hx[0], '[0]')
        self.check_forward_hidden(input, hx[1], '[1]')
        return self._backend.LSTMCell(
            input, hx,
            self.weight_ih, self.weight_hh,
            self.bias_ih, self.bias_hh,
        )
    """
    Flag to set that we actually get the posterior mean and not a sample from the random variables
    """
    def set_posterior_mean(self, posterior_mean):
        self.posterior_mean = posterior_mean
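A hedged sketch of how such a variational cell is typically driven: sample the weights once per sequence, accumulate the data loss over time steps, and add the KL term scaled by the dataset size. sequence, step_loss and n_train are placeholders, not names from the source:

cell = LSTMCellVB(input_size=32, hidden_size=64)
cell.sample_posterior()          # one weight draw for the whole sequence

h = None
total_loss = 0.0
for x_t in sequence:             # sequence: iterable of (batch, 32) tensors (assumed)
    h = cell(x_t) if h is None else cell(x_t, h)   # h is the (h_t, c_t) tuple
    total_loss = total_loss + step_loss(h[0])      # step_loss: task-specific (assumed)

total_loss = total_loss + cell.get_KL_divergence() / n_train
total_loss.backward()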