Example #1
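All examples below are shown without their module-level imports. A plausible shared preamble, assumed here rather than taken from the original sources, is:

import math

import numpy as np
import torch
import torch as t                      # Example #25 uses the alias `t`
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.autograd import Variable    # several older snippets still wrap tensors in Variable
from torch.nn import Module
from torch.nn.modules.rnn import RNNCellBase
from torch.nn.parameter import Parameter

N_ = None                              # alias used by Example #8 (assumed to mean None)
DEVICE = torch.device('cpu')           # Example #9 references a module-level DEVICE (assumed)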
class BoundaryDetector(nn.Module):
    '''
    Boundary Detector: boundary detection module.
    '''

    def __init__(self, i_features, h_features, s_features, inplace=False):
        super(BoundaryDetector, self).__init__()
        self.inplace = inplace
        self.Wsi = Parameter(torch.Tensor(s_features, i_features))
        self.Wsh = Parameter(torch.Tensor(s_features, h_features))
        self.bias = Parameter(torch.Tensor(s_features))
        self.vs = Parameter(torch.Tensor(1, s_features))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.Wsi.size(1))
        self.Wsi.data.uniform_(-stdv, stdv)
        self.Wsh.data.uniform_(-stdv, stdv)
        self.bias.data.uniform_(-stdv, stdv)
        self.vs.data.uniform_(-stdv, stdv)

    def forward(self, x, h):
        z = F.linear(x, self.Wsi) + F.linear(h, self.Wsh) + self.bias
        z = torch.sigmoid(F.linear(z, self.vs))
        return BinaryGate.apply(z, self.training, self.inplace)

    def __repr__(self):
        return self.__class__.__name__
Example #2
class Bilinear(Module):
    r"""Applies a bilinear transformation to the incoming data:
    :math:`y = x_1 A x_2 + b`

    Args:
        in1_features: size of each first input sample
        in2_features: size of each second input sample
        out_features: size of each output sample
        bias: If set to False, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input: :math:`(N, *, \text{in1_features})`, :math:`(N, *, \text{in2_features})`
          where :math:`*` means any number of additional dimensions. All but the last
          dimension of the inputs should be the same.
        - Output: :math:`(N, *, \text{out_features})` where all but the last dimension
          are the same shape as the input.

    Attributes:
        weight: the learnable weights of the module of shape
            `(out_features x in1_features x in2_features)`
        bias:   the learnable bias of the module of shape `(out_features)`

    Examples::

        >>> m = nn.Bilinear(20, 30, 40)
        >>> input1 = torch.randn(128, 20)
        >>> input2 = torch.randn(128, 30)
        >>> output = m(input1, input2)
        >>> print(output.size())
    """

    def __init__(self, in1_features, in2_features, out_features, bias=True):
        super(Bilinear, self).__init__()
        self.in1_features = in1_features
        self.in2_features = in2_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in1_features, in2_features))

        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input1, input2):
        return F.bilinear(input1, input2, self.weight, self.bias)

    def extra_repr(self):
        return 'in1_features={}, in2_features={}, out_features={}, bias={}'.format(
            self.in1_features, self.in2_features, self.out_features, self.bias is not None
        )
Example #3
class Bilinear(Module):
    r"""Applies a bilinear transformation to the incoming data:
    :math:`y = x_1 A x_2 + b`

    Args:
        in1_features: size of each first input sample
        in2_features: size of each second input sample
        out_features: size of each output sample
        bias: If set to False, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input: :math:`(N, in1\_features)`, :math:`(N, in2\_features)`
        - Output: :math:`(N, out\_features)`

    Attributes:
        weight: the learnable weights of the module of shape
            `(out_features x in1_features x in2_features)`
        bias:   the learnable bias of the module of shape `(out_features)`

    Examples::

        >>> m = nn.Bilinear(20, 30, 40)
        >>> input1 = torch.randn(128, 20)
        >>> input2 = torch.randn(128, 30)
        >>> output = m(input1, input2)
        >>> print(output.size())
    """

    def __init__(self, in1_features, in2_features, out_features, bias=True):
        super(Bilinear, self).__init__()
        self.in1_features = in1_features
        self.in2_features = in2_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in1_features, in2_features))

        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input1, input2):
        return F.bilinear(input1, input2, self.weight, self.bias)

    def __repr__(self):
        return self.__class__.__name__ + '(' \
            + 'in1_features=' + str(self.in1_features) \
            + ', in2_features=' + str(self.in2_features) \
            + ', out_features=' + str(self.out_features) \
            + ', bias=' + str(self.bias is not None) + ')'
Example #4
class LinearSimilarity(SimilarityFunction):
    """
    This similarity function performs a dot product between a vector of weights and some
    combination of the two input vectors, followed by an (optional) activation function.  The
    combination used is configurable.

    If the two vectors are ``x`` and ``y``, we allow the following kinds of combinations: ``x``,
    ``y``, ``x*y``, ``x+y``, ``x-y``, ``x/y``, where each of those binary operations is performed
    elementwise.  You can list as many combinations as you want, comma separated.  For example, you
    might give ``x,y,x*y`` as the ``combination`` parameter to this class.  The computed similarity
    function would then be ``w^T [x; y; x*y] + b``, where ``w`` is a vector of weights, ``b`` is a
    bias parameter, and ``[;]`` is vector concatenation.

    Note that if you want a bilinear similarity function with a diagonal weight matrix W, where the
    similarity function is computed as `x * w * y + b` (with `w` the diagonal of `W`), you can
    accomplish that with this class by using "x*y" for `combination`.

    Parameters
    ----------
    tensor_1_dim : ``int``
        The dimension of the first tensor, ``x``, described above.  This is ``x.size()[-1]`` - the
        length of the vector that will go into the similarity computation.  We need this so we can
        build weight vectors correctly.
    tensor_2_dim : ``int``
        The dimension of the second tensor, ``y``, described above.  This is ``y.size()[-1]`` - the
        length of the vector that will go into the similarity computation.  We need this so we can
        build weight vectors correctly.
    combination : ``str``, optional (default="x,y")
        Described above.
    activation : ``Activation``, optional (default=linear (i.e. no activation))
        An activation function applied after the ``w^T * [x;y] + b`` calculation.  Default is no
        activation.
    """
    def __init__(self,
                 tensor_1_dim: int,
                 tensor_2_dim: int,
                 combination: str = 'x,y',
                 activation: Activation = None) -> None:
        super(LinearSimilarity, self).__init__()
        self._combination = combination
        combined_dim = util.get_combined_dim(combination, [tensor_1_dim, tensor_2_dim])
        self._weight_vector = Parameter(torch.Tensor(combined_dim))
        self._bias = Parameter(torch.Tensor(1))
        self._activation = activation or Activation.by_name('linear')()
        self.reset_parameters()

    def reset_parameters(self):
        std = math.sqrt(6 / (self._weight_vector.size(0) + 1))
        self._weight_vector.data.uniform_(-std, std)
        self._bias.data.fill_(0)

    @overrides
    def forward(self, tensor_1: torch.Tensor, tensor_2: torch.Tensor) -> torch.Tensor:
        combined_tensors = util.combine_tensors(self._combination, [tensor_1, tensor_2])
        dot_product = torch.matmul(combined_tensors, self._weight_vector)
        return self._activation(dot_product + self._bias)
Example #5
class Linear(Module):
    r"""Applies a linear transformation to the incoming data: :math:`y = Ax + b`

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to False, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input: :math:`(N, *, in\_features)` where :math:`*` means any number of
          additional dimensions
        - Output: :math:`(N, *, out\_features)` where all but the last dimension
          are the same shape as the input.

    Attributes:
        weight: the learnable weights of the module of shape
            `(out_features x in_features)`
        bias:   the learnable bias of the module of shape `(out_features)`

    Examples::

        >>> m = nn.Linear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
    """

    def __init__(self, in_features, out_features, bias=True):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input):
        return F.linear(input, self.weight, self.bias)

    def __repr__(self):
        return self.__class__.__name__ + '(' \
            + 'in_features=' + str(self.in_features) \
            + ', out_features=' + str(self.out_features) \
            + ', bias=' + str(self.bias is not None) + ')'
Example #6
class Linear(nn.Module):
    """Custom Linear layer which allows for sharing weights (e.g. with an
    nn.Embedding layer).
    """
    def __init__(self, in_features, out_features, bias=True,
                 shared_weight=None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.shared = shared_weight is not None

        # init weight
        if not self.shared:
            self.weight = Parameter(torch.Tensor(out_features, in_features))
        else:
            if (shared_weight.size(0) != out_features or
                    shared_weight.size(1) != in_features):
                raise RuntimeError('wrong dimensions for shared weights')
            self.weight = shared_weight

        # init bias
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        if not self.shared:
            # only initialise the weight here; a shared weight must not be overwritten
            self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input):
        weight = self.weight
        if self.shared:
            # detach weight to prevent gradients from changing weight
            # (but need to detach every time so weights are up to date)
            weight = weight.detach()
        return F.linear(input, weight, self.bias)

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ')'
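A minimal usage sketch (not from the original source): tying the layer's weight to an nn.Embedding table so the output projection reuses the embedding matrix.

vocab_size, hidden = 1000, 64
embedding = nn.Embedding(vocab_size, hidden)                 # weight: vocab_size x hidden
proj = Linear(hidden, vocab_size, shared_weight=embedding.weight)
logits = proj(torch.randn(8, hidden))                        # 8 x vocab_size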
Example #7
class TemporalDecay(nn.Module):
    def __init__(self, input_size, output_size):
        super(TemporalDecay, self).__init__()
        self.build(input_size, output_size)

    def build(self, input_size, output_size):
        self.W = Parameter(torch.Tensor(output_size, input_size))
        self.b = Parameter(torch.Tensor(output_size))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.W.size(0))
        self.W.data.uniform_(-stdv, stdv)
        if self.b is not None:
            self.b.data.uniform_(-stdv, stdv)

    def forward(self, delta):
        gamma = F.relu(F.linear(delta, self.W, self.b))
        gamma = torch.exp(-gamma)
        return gamma
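A small sketch of how such a decay module is typically driven (assumed usage): time gaps go in, per-feature decay factors in (0, 1] come out.

decay = TemporalDecay(input_size=5, output_size=5)
delta = torch.rand(32, 5) * 10          # elapsed time since the last observation
gamma = decay(delta)                    # 32 x 5, every entry in (0, 1]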
Example #8
class WNlinear(Module):

    def __init__(self, in_features, out_features, 
                 bias=True, mask=N_, norm=True):
        super(WNlinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer('mask',mask)
        self.norm = norm
        self.direction = Parameter(torch.Tensor(out_features, in_features))
        self.scale = Parameter(torch.Tensor(out_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', N_)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.direction.size(1))
        self.direction.data.uniform_(-stdv, stdv)
        self.scale.data.uniform_(1, 1)
        if self.bias is not N_:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input):
        if self.norm:
            dir_ = self.direction
            direction = dir_.div(dir_.pow(2).sum(1).sqrt()[:,N_])
            weight = self.scale[:,N_].mul(direction)
        else:
            weight = self.scale[:,N_].mul(self.direction)
        if self.mask is not N_:
            #weight = weight * getattr(self.mask, 
            #                          ('cpu', 'cuda')[weight.is_cuda])()
            weight = weight * Variable(self.mask)
        return F.linear(input, weight, self.bias)

    def __repr__(self):
        return self.__class__.__name__ + '(' \
            + 'in_features=' + str(self.in_features) \
            + ', out_features=' + str(self.out_features) + ')'
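A brief sketch of the weight-normalised layer above, assuming N_ is an alias for None as in the preamble: the effective weight is scale * direction / ||direction|| row by row.

layer = WNlinear(16, 32)                # norm=True: rows of `direction` are re-normalised
out = layer(torch.randn(4, 16))         # 4 x 32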
Example #9
class FilterLinear_V2(nn.Module):
    def __init__(self,
                 in_features,
                 out_features,
                 filter_square_matrix,
                 bias=True,
                 device=DEVICE):
        '''
        filter_square_matrix : filter square matrix, each element of which is 0 or 1.
        '''
        super(FilterLinear_V2, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.filter_square_matrix = None
        self.filter_square_matrix = Variable(filter_square_matrix,
                                             requires_grad=False).to(device)
        self.weight = Parameter(
            torch.Tensor(out_features, in_features).to(device))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features).to(device))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input):
        return F.linear(input, self.filter_square_matrix.mul(self.weight),
                        self.bias)

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + 'in_features=' + str(self.in_features) \
               + ', out_features=' + str(self.out_features) \
               + ', bias=' + str(self.bias is not None) + ')'
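A quick sketch with shapes chosen for illustration: a 0/1 filter matrix, here diagonal, zeroes out weight entries so each output unit only sees its own input feature.

mask = torch.eye(10)                    # 0/1 filter matrix
layer = FilterLinear_V2(10, 10, mask, device=torch.device('cpu'))
out = layer(torch.randn(4, 10))         # 4 x 10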
Example #10
class Lookahead(nn.Module):
    # Wang et al 2016 - Lookahead Convolution Layer for Unidirectional Recurrent Neural Networks
    # input shape - sequence, batch, feature - TxNxH
    # output shape - same as input
    def __init__(self, n_features, context):
        # should we handle batch_first=True?
        super(Lookahead, self).__init__()
        self.n_features = n_features
        self.weight = Parameter(torch.Tensor(n_features, context + 1))
        assert context > 0
        self.context = context
        self.register_parameter('bias', None)
        self.init_parameters()

    def init_parameters(self):  # what's a better way to initialise this layer?
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, input):
        seq_len = input.size(0)
        # pad the 0th dimension (T/sequence) with zeroes whose number = context
        # Once pytorch's padding functions have settled, should move to those.
        padding = torch.zeros(self.context,
                              *(input.size()[1:])).type_as(input.data)
        x = torch.cat((input, Variable(padding)), 0)

        # add lookahead windows (with context+1 width) as a fourth dimension
        # for each seq-batch-feature combination
        x = [x[i:i + self.context + 1] for i in range(seq_len)
             ]  # TxLxNxH - sequence, context, batch, feature
        x = torch.stack(x)
        x = x.permute(0, 2, 3,
                      1)  # TxNxHxL - sequence, batch, feature, context

        x = torch.mul(x, self.weight).sum(dim=3)
        return x

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + 'n_features=' + str(self.n_features) \
               + ', context=' + str(self.context) + ')'
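A shape-only sketch (assumed usage): input and output are both T x N x H, and each timestep becomes a learned mix of itself and the next `context` frames.

layer = Lookahead(n_features=40, context=5)
out = layer(torch.randn(100, 8, 40))    # 100 x 8 x 40, same shape as the input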
Example #11
class GraphConvolution(Module):
    """Simple GCN layer, similar to https://github.com/tkipf/pygcn
    """
    def __init__(self, in_features, out_features, with_bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if with_bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        # self.weight.data.fill_(1)
        # if self.bias is not None:
        #     self.bias.data.fill_(1)

        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        """ Graph Convolutional Layer forward function
        """
        if input.data.is_sparse:
            support = torch.spmm(input, self.weight)
        else:
            support = torch.mm(input, self.weight)
        output = torch.spmm(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
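A minimal sketch of the X' = A_hat * X * W propagation; the identity adjacency below is only a stand-in for a properly normalised A_hat.

gc = GraphConvolution(in_features=16, out_features=8)
x = torch.randn(5, 16)                  # node features
adj = torch.eye(5).to_sparse()          # stand-in for the normalised adjacency matrix
h = torch.relu(gc(x, adj))              # 5 x 8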
Example #12
class MemoryUnit(nn.Module):
    def __init__(self, mem_dim, fea_dim, shrink_thres=0.0025):
        super(MemoryUnit, self).__init__()
        self.mem_dim = mem_dim
        self.fea_dim = fea_dim
        self.weight = Parameter(torch.Tensor(self.mem_dim,
                                             self.fea_dim))  # M x C
        #         print("memory shape", self.weight.shape)
        self.bias = None
        self.shrink_thres = shrink_thres
        # self.hard_sparse_shrink_opt = nn.Hardshrink(lambd=shrink_thres)

        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input):
        att_weight = F.linear(input,
                              self.weight)  # Fea x Mem^T, (TxC) x (CxM) = TxM
        att_weight = F.softmax(att_weight, dim=1)  # TxM
        # ReLU based shrinkage, hard shrinkage for positive value
        if (self.shrink_thres > 0):
            att_weight = hard_shrink_relu(att_weight, lambd=self.shrink_thres)
            # att_weight = F.softshrink(att_weight, lambd=self.shrink_thres)
            # normalize???
            att_weight = F.normalize(att_weight, p=1, dim=1)
            # att_weight = F.softmax(att_weight, dim=1)
            # att_weight = self.hard_sparse_shrink_opt(att_weight)
        mem_trans = self.weight.permute(1, 0)  # Mem^T, MxC
        output = F.linear(
            att_weight,
            mem_trans)  # AttWeight x Mem^T^T = AW x Mem, (TxM) x (MxC) = TxC
        return {'output': output, 'att': att_weight}  # output, att_weight

    def extra_repr(self):
        return 'mem_dim={}, fea_dim={}'.format(self.mem_dim, self.fea_dim)
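A usage sketch (shapes chosen for illustration): each input row is softly addressed against the mem_dim memory slots and re-composed from them; with shrink_thres=0 the external hard_shrink_relu helper is never called.

mem = MemoryUnit(mem_dim=50, fea_dim=256, shrink_thres=0)
feats = torch.randn(32, 256)
result = mem(feats)                     # result['output']: 32 x 256, result['att']: 32 x 50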
Example #13
class ArcFullyConnected(Module):

    def __init__(self, in_features, out_features, s, m, is_pw=True, is_hard=False):
        super(ArcFullyConnected, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.is_pw = is_pw
        self.is_hard = is_hard
        assert s > 0
        assert 0 <= m < 0.5* math.pi
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def __repr__(self):
        return ('in_features={}, out_features={}, s={}, m={}'
                .format(self.in_features, self.out_features, self.s, self.m))

    def forward(self, embed, label):
        n_weight = F.normalize(self.weight, p=2, dim=1)
        n_embed = F.normalize(embed, p=2, dim=1)*self.s
        out = F.linear(n_embed, n_weight)
        score = out.gather(1, label.view(-1, 1))
        cos_y = score / self.s
        sin_y = torch.sqrt(1 - cos_y**2)
        arc_score = self.s * (cos_y*math.cos(self.m) - sin_y*math.sin(self.m))
        if self.is_pw:
            if not self.is_hard:
                arc_score = where(score > 0, arc_score, score)
            else:
                mm = math.sin(math.pi - self.m)*self.m # actually it is sin(m)*m
                th = math.cos(math.pi - self.m) # actually it is -cos(m)
                arc_score = where((score-th) > 0, arc_score, score-self.s*mm)
        one_hot = Variable(torch.cuda.FloatTensor(out.shape).fill_(0))
        out += (arc_score - score) * one_hot.scatter_(1, label.view(-1, 1), 1)
        return out
Example #14
class GraphConvolution(nn.Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """

    def __init__(self, in_features, out_features, bias=True, skip=True):
        super(GraphConvolution, self).__init__()
        self.skip = skip
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        # TODO make fc more efficient via "pack_padded_sequence"
        # import ipdb; ipdb.set_trace()
        support = torch.bmm(input, self.weight.unsqueeze(
            0).expand(input.shape[0], -1, -1))
        output = torch.bmm(adj, support)
        #output = SparseMM(adj)(support)
        if self.bias is not None:
            output += self.bias.unsqueeze(0).expand(input.shape[0], -1, -1)
        if self.skip:
            output += support

        return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ')'
Example #15
class GraphConvolution(nn.Module):
    """
    Graph convolution layer.
    """
    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        # nn.Parameter() registers the tensor as a learnable parameter of the network.
        # Data created with nn.Parameter() is not an ordinary Tensor but a distinct type, <class 'torch.nn.parameter.Parameter'>.
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()
        return

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        # torch.mm(): dense matrix product / input * self.weight
        support = torch.mm(input, self.weight)

        # torch.spmm(): sparse matrix multiplication / adj is the adjacency matrix, stored as a sparse tensor
        output = torch.spmm(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
Example #16
class truncated_krylov_layer(general_GCN_layer):
    def __init__(self,
                 in_features,
                 n_blocks,
                 out_features,
                 LIST_A_EXP=None,
                 LIST_A_EXP_X_CAT=None):
        super(truncated_krylov_layer, self).__init__()
        self.LIST_A_EXP = LIST_A_EXP
        self.LIST_A_EXP_X_CAT = LIST_A_EXP_X_CAT
        self.in_features, self.out_features, self.n_blocks = in_features, out_features, n_blocks
        self.shared_weight, self.output_bias = Parameter(
            torch.FloatTensor(self.in_features * self.n_blocks,
                              self.out_features).cuda()), Parameter(
                                  torch.FloatTensor(self.out_features).cuda())
        self.reset_parameters()

    def reset_parameters(self):
        stdv_shared_weight, stdv_output_bias = 1. / math.sqrt(
            self.shared_weight.size(1)), 1. / math.sqrt(
                self.output_bias.size(0))
        torch.nn.init.uniform_(self.shared_weight, -stdv_shared_weight,
                               stdv_shared_weight)
        torch.nn.init.uniform_(self.output_bias, -stdv_output_bias,
                               stdv_output_bias)

    def forward(self, input, adj, eye=True):
        if self.n_blocks == 1:
            output = torch.mm(input, self.shared_weight)
        elif self.LIST_A_EXP_X_CAT is not None:
            output = torch.mm(self.LIST_A_EXP_X_CAT, self.shared_weight)
        elif self.LIST_A_EXP is not None:
            feature_output = []
            for i in range(self.n_blocks):
                AX = self.multiplication(self.LIST_A_EXP[i], input)
                feature_output.append(AX)
            output = torch.mm(torch.cat(feature_output, 1), self.shared_weight)
        if eye:
            return output + self.output_bias
        else:
            return self.multiplication(adj, output) + self.output_bias
Example #17
class DenseLayer(Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """

    def __init__(self, in_features, out_features, bias=False):
        super(DenseLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        # self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)
        torch.nn.init.xavier_uniform_(self.weight)

    def forward(self, self_vecs, neigh_vecs, neigh_num):
        self_vecs = torch.mm(self_vecs, self.weight)
        self_vecs = self_vecs.view(-1, 1, self.out_features)

        neigh_vecs = torch.mm(neigh_vecs, self.weight)
        neigh_vecs = neigh_vecs.view(-1, neigh_num, self.out_features)
        output = torch.cat([self_vecs, neigh_vecs], dim=1)

        output = torch.mean(output, dim=1)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
Example #18
class GraphConvolution(Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """
    def __init__(self,
                 in_features,
                 out_features,
                 bias=False,
                 act=lambda x: x,
                 dropout=0.0):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dropout = dropout
        self.act = act
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        #self.weight.data.uniform_(-stdv, stdv)
        torch.nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        input = F.dropout(input, self.dropout, training=self.training)
        support = torch.mm(input, self.weight)
        output = torch.spmm(adj, support)
        if self.bias is not None:
            output = output + self.bias
        return self.act(output)

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
Example #19
class GCNLPAConv(nn.Module):
    """
    A GCN-LPA layer. Please refer to: https://arxiv.org/abs/2002.06755
    """

    def __init__(self, in_features, out_features, adj, bias=True):
        super(GCNLPAConv, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        self.adjacency_mask = Parameter(adj.clone())

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, x, adj, y):
        y = y.view(y.size()[0], 1).float()
        # W * x
        support = torch.mm(x, self.weight)
        # Hadamard Product: A' = Hadamard(A, M)
        adj = adj * self.adjacency_mask
        # Row-Normalize: D^-1 * (A')
        adj = F.normalize(adj, p=1, dim=1)

        # output = D^-1 * A' * X * W
        output = torch.mm(adj, support)
        # y' = D^-1 * A' * y
        y_hat = torch.mm(adj, y)
        
        if self.bias is not None:
            return output + self.bias, y_hat
        else:
            return output, y_hat
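A shape sketch (assumed usage) for the GCN-LPA layer: features and labels are propagated with the same learned, row-normalised adjacency.

adj = torch.rand(6, 6)                  # dense adjacency; the mask parameter is cloned from it
layer = GCNLPAConv(in_features=8, out_features=4, adj=adj)
x = torch.randn(6, 8)
y = torch.randint(0, 2, (6,))
h, y_hat = layer(x, adj, y)             # h: 6 x 4, y_hat: 6 x 1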
Example #20
class GraphConvolution(Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """

    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        #print("input:", input.shape)
        #print("adj:", adj.shape)
        #print("adj[0,0]:", adj)
        #print("weight:", self.weight.shape)
        support = torch.mm(input, self.weight)
        output = torch.spmm(adj, support)
        #print("support:", support.shape)
        #print("output:", output.shape)
        #exit()
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
Example #21
class GraphConvolution(nn.Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """
    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj, params=None):
        if params is None:
            support = torch.matmul(input, self.weight)
            output = torch.matmul(adj, support)
            if self.bias is not None:
                return output + self.bias
            else:
                return output
        else:
            support = torch.matmul(input, params['weight'])
            output = torch.matmul(adj, support)
            if self.bias is not None:
                return output + params['bias']
            else:
                return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ')'
Example #22
class Affinity(nn.Module):
    """
    Affinity Layer to compute the affinity matrix via inner product from feature space.
    Me = X * Lambda * Y^T
    Mp = Ux * Uy^T
    Parameter: scale of weight d
    Input: edgewise (pairwise) feature X, Y
           pointwise (unary) feature Ux, Uy
    Output: edgewise affinity matrix Me
            pointwise affinity matrix Mp
    Weight: weight matrix Lambda = [[Lambda1, Lambda2],
                                    [Lambda2, Lambda1]]
            where Lambda1, Lambda2 > 0
    """

    def __init__(self, d):
        super(Affinity, self).__init__()
        self.d = d
        self.lambda1 = Parameter(Tensor(self.d, self.d))
        self.lambda2 = Parameter(Tensor(self.d, self.d))
        self.relu = nn.ReLU()  # problem: if weight<0, then always grad=0. So this parameter is never updated!
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.lambda1.size(1) * 2)
        self.lambda1.data.uniform_(-stdv, stdv)
        self.lambda2.data.uniform_(-stdv, stdv)
        self.lambda1.data += torch.eye(self.d) / 2
        self.lambda2.data += torch.eye(self.d) / 2

    def forward(self, X, Y, Ux, Uy, w1=1, w2=1):
        assert X.shape[1] == Y.shape[1] == 2 * self.d
        lambda1 = self.relu(self.lambda1 + self.lambda1.transpose(0, 1)) * w1
        lambda2 = self.relu(self.lambda2 + self.lambda2.transpose(0, 1)) * w2
        weight = torch.cat((torch.cat((lambda1, lambda2)),
                            torch.cat((lambda2, lambda1))), 1)
        Me = torch.matmul(X.transpose(1, 2), weight)
        Me = torch.matmul(Me, Y)
        Mp = torch.matmul(Ux.transpose(1, 2), Uy)

        return Me, Mp
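A shape sketch with dimensions chosen for illustration: edgewise features are (batch, 2*d, n_points), unary features share the point dimension, and both affinity matrices come out as (batch, n1, n2).

aff = Affinity(d=64)
X, Y = torch.randn(2, 128, 10), torch.randn(2, 128, 12)      # edgewise features, 2*d channels
Ux, Uy = torch.randn(2, 32, 10), torch.randn(2, 32, 12)      # pointwise (unary) features
Me, Mp = aff(X, Y, Ux, Uy)              # Me: 2 x 10 x 12, Mp: 2 x 10 x 12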
Example #23
class GCN_Spectral(Module):
    """ Simple GCN layer, similar to https://arxiv.org/abs/1609.02907 """
    def __init__(self,
                 in_units: int,
                 out_units: int,
                 bias: bool = True) -> None:
        super(GCN_Spectral, self).__init__()
        self.in_units = in_units
        self.out_units = out_units
        self.weight = Parameter(torch.FloatTensor(in_units, out_units))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_units))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """

        weight=(input_dim X hid_dim)
        :param input: (#samples X input_dim)
        :param adj:
        :return:
        """
        support = torch.mm(input, self.weight)
        # logger.debug((adj.dtype,support.dtype))
        output = torch.spmm(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(
            self.in_units) + ' -> ' + str(self.out_units) + ')'
Example #24
class GraphConvolution(nn.Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """
    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        self.self_loop_w = torch.nn.Linear(in_features, out_features)
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.weight.size(1))
        self.weight.data.normal_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.normal_(-stdv, stdv)

    def forward(self, input, edge_index, edge_attr=None):
        if edge_attr is None:
            edge_attr = torch.ones(edge_index.shape[1]).float().to(
                input.device)
        adj = torch.sparse_coo_tensor(
            edge_index,
            edge_attr,
            (input.shape[0], input.shape[0]),
        ).to(input.device)
        support = torch.mm(input, self.weight)
        output = torch.spmm(adj, support)
        if self.bias is not None:
            return output + self.self_loop_w(input) + self.bias
        else:
            return output + self.self_loop_w(input)

    def __repr__(self):
        return (self.__class__.__name__ + " (" + str(self.in_features) +
                " -> " + str(self.out_features) + ")")
Example #25
class BinaryModule(nn.Module):
   def __init__(self, h):
      super(BinaryModule, self).__init__()
      self.weight = Parameter(t.Tensor(1, 4))
      #self.reset_parameters()

      self.fc0 = nn.Linear(2*h, h, bias=False)
      self.fc1 = nn.Linear(h, h, bias=False)
      self.fc2 = nn.Linear(h, h, bias=False)
      self.fc3 = nn.Linear(h, h, bias=False)
      self.fc4 = nn.Linear(h, h, bias=False)

   def reset_parameters(self):
      stdv = 1. / np.sqrt(self.weight.size(1))
      self.weight.data.uniform_(-stdv, stdv)

   def forward(self, x1, x2):
      x1 = self.fc1(x1)
      x2 = self.fc2(x2)
      xx  = t.stack([x1+x2, x1-x2, x1*x2, x1/(x2+1e-4)], 1)
      xx = t.sum(xx * self.weight, 1)
      return xx
Example #26
class DistMultDecoder(nn.Module):
    """DistMult Decoder model layer for link prediction."""
    def __init__(self, input_dim, num_types, bias=True, act=lambda x: x):
        super(DistMultDecoder, self).__init__()
        self.act = act
        self.num_types = num_types
        self.bias = Parameter(torch.rand(1)) if bias else 0
        self.weight = Parameter(torch.FloatTensor(num_types, input_dim))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = math.sqrt(6. / self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, input1, input2, type_index):
        relation = torch.diag(self.weight[type_index])
        intermediate_product = torch.mm(input1, relation)
        outputs = torch.mm(intermediate_product, input2.transpose(0, 1))
        outputs = outputs + self.bias

        return self.act(outputs)
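A usage sketch (shapes chosen for illustration): scoring all pairs from two embedding sets under one relation type.

dec = DistMultDecoder(input_dim=16, num_types=3)
z_src, z_dst = torch.randn(10, 16), torch.randn(12, 16)
scores = dec(z_src, z_dst, type_index=1)   # 10 x 12 pairwise scores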
Example #27
class InputTemporalDecay(nn.Module):
    def __init__(self, input_size):
        super().__init__()

        self.W = Parameter(torch.Tensor(input_size, input_size))
        self.b = Parameter(torch.Tensor(input_size))

        m = torch.eye(input_size, input_size)
        self.register_buffer('m', m)

        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.W.size(0))
        self.W.data.uniform_(-stdv, stdv)
        if self.b is not None:
            self.b.data.uniform_(-stdv, stdv)

    def forward(self, d):
        gamma = F.relu(F.linear(d, self.W * Variable(self.m), self.b))
        return torch.exp(-gamma)
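A sketch analogous to Example #7 (assumed usage): the registered identity mask keeps W effectively diagonal, so each feature decays only with its own time gap.

decay = InputTemporalDecay(input_size=5)
gamma = decay(torch.rand(32, 5) * 10)   # 32 x 5 decay factors in (0, 1]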
Example #28
class Source_Dictionary(nn.Module):
    r"""I basically modified the source code for the nn.Linear() class
        Removed bias, and the weights are of dimension M X EMBEDDING_SIZE X K
        INPUT: BATCH_SIZE X M X K X 1
        OUTPUT:BATCH_SIZE X K X 1
    """
    def __init__(self, M, emb_size, K):
        super(Source_Dictionary, self).__init__()
        # The weight of the dictionary is the set of M dictionaries of size EMB X K
        self.weight = Parameter(torch.Tensor(M, emb_size, K))
        self.reset_parameters()

    # Initialize parameters of Source_Dictionary
    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    # Operation necessary for proper batch matrix multiplication
    def forward(self, input):
        result = torch.matmul(self.weight, input)
        return result.squeeze(-1)
Example #29
class GraphConv(Module):
    def __init__(self, in_features, out_features, bias=True):
        super(GraphConv, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        nn.init.xavier_uniform_(self.weight)
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
            stdv = 1. / math.sqrt(self.weight.size(1))
            self.bias.data.uniform_(-stdv, stdv)
            # nn.init.xavier_uniform_(self.bias)
        else:
            self.register_parameter('bias', None)

    def forward(self, x, adj):
        output = torch.mm(adj, torch.mm(x, self.weight))
        if self.bias is not None:
            return output + self.bias
        else:
            return output
Example #30
class GraphConvLayer(Module):
    """ This class is taken from https://github.com/tkipf/pygcn.
    It implements one graph convolution layer.
    """
    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        b, _, _ = np.shape(input)
        hidden = torch.bmm(
            input, self.weight.expand(b, self.in_features, self.out_features))
        output = torch.bmm(adj, hidden)
        return output
Example #31
class SelfAttention_ori(Module):
    """docstring for SelfAttention"""
    def __init__(self, in_features):
        super(SelfAttention_ori, self).__init__()
        self.a = Parameter(torch.FloatTensor(2 * in_features, 1))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.a.size(1))
        self.a.data.uniform_(-stdv, stdv)

    def forward(self, inputs):
        x = inputs.transpose(0, 1)
        self.n = x.size()[0]
        x = torch.cat([x, torch.stack([x] * self.n, dim=0)], dim=2)
        U = torch.matmul(x, self.a).transpose(0, 1)
        # non-linear activation
        U = F.leaky_relu(U)
        weights = F.softmax(U, dim=1)
        outputs = torch.matmul(weights.transpose(1, 2), inputs).squeeze(1)
        return outputs, weights
Example #32
class GCLayer(Module):
    def __init__(self, in_features, out_features):
        super(GCLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.weights = Parameter(torch.FloatTensor(in_features, out_features))
        self.bias = Parameter(torch.FloatTensor(out_features))

        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weights.size(1))
        self.weights.data.uniform_(-stdv, stdv)
        self.bias.data.fill_(0)

    def forward(self, vertex, adj):
        support = torch.mm(vertex, self.weights)
        out = torch.spmm(adj, support)
        out += self.bias
        return out
Example #33
class MyLinear(Module):
    def __init__(self, in_features, out_features, bias=True):
        super(MyLinear, self).__init__()
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input):
        if self.bias is not None:
            return torch.mm(input, self.weight) + self.bias
        else:
            return torch.mm(input, self.weight)
Example #34
class GraphConvFilter(nn.Module):
    def __init__(self, Fin, Fout, bias=True):
        super(GraphConvFilter, self).__init__()
        self.Fout = Fout
        self.weight = Parameter(torch.FloatTensor(Fin, Fout))

        if bias:
            self.bias = Parameter(torch.FloatTensor(Fout))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, x, adj):
        # Fout is the number of convolution filters
        # N is the batch size, M is the feature dimension of x (i.e. the number of nodes), Fin = 1 corresponds to a single timestep
        N, M, Fin = x.shape
        N, M, Fin = int(N), int(M), int(Fin)

        # print("N, M, Fin: ", N, M, Fin)

        x = torch.transpose(x, dim0=0, dim1=1)  # (156, 50, 1)
        # print("x size: ", x.shape)
        x = torch.squeeze(x)
        x = torch.mm(adj, x)
        x = torch.transpose(x, dim0=0, dim1=1)  # (50, 156, 1)
        x = torch.unsqueeze(x, dim=2)  # (50, 156, 1)
        x = torch.reshape(x, (N * M, Fin))  # (50*156, 1)
        x = torch.matmul(x, self.weight)  #  (50*156, 8)
        x = torch.reshape(x, (N, M, self.Fout))

        if self.bias is not None:
            return x + self.bias
        else:
            return x
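A shape sketch (assumed usage): as the comments note, the reshape logic expects single-channel node features, i.e. Fin == 1.

gcf = GraphConvFilter(Fin=1, Fout=8)
x = torch.randn(50, 156, 1)             # batch x nodes x 1
adj = torch.rand(156, 156)
out = gcf(x, adj)                       # 50 x 156 x 8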
Example #35
class GraphConvolution(nn.Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """
    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(
            in_features, out_features))  # feature weights (a learned weighting of each feature map)
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    # adj x (input x weight) + bias: graph convolution is essentially multiplying
    # the input features by the weight matrix and by the adjacency matrix
    def forward(self, input, adj):  # adj is the adjacency matrix (in effect, weights between the different proposals)
        support = torch.mm(
            input, self.weight
        )  # [168,1024] x [1024,512]--->[168,512] / [168,512]x[512,1024]--->[168,1024]
        output = torch.mm(
            adj, support
        )  # [168,168]x[168,512]-->[168,512]  / [168,168]x[168,1024]--->[168,1024]
        #output = SparseMM(adj)(support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output  # [168,512] / [168,1024]

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
Example #36
class Dense(Module):
    """
    Simple Dense layer, Do not consider adj.
    """
    def __init__(self,
                 in_features,
                 out_features,
                 activation=lambda x: x,
                 bias=True,
                 res=False):
        super(Dense, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.sigma = activation
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        self.res = res
        self.bn = nn.BatchNorm1d(out_features)
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input):
        output = torch.mm(input, self.weight)
        if self.bias is not None:
            output = output + self.bias
        output = self.bn(output)
        return self.sigma(output)

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
Example #37
class Lookahead(nn.Module):
    # Wang et al 2016 - Lookahead Convolution Layer for Unidirectional Recurrent Neural Networks
    # input shape - sequence, batch, feature - TxNxH
    # output shape - same as input
    def __init__(self, n_features, context):
        # should we handle batch_first=True?
        super(Lookahead, self).__init__()
        self.n_features = n_features
        self.weight = Parameter(torch.Tensor(n_features, context + 1))
        assert context > 0
        self.context = context
        self.register_parameter('bias', None)
        self.init_parameters()

    def init_parameters(self):  # what's a better way to initialise this layer?
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, input):
        seq_len = input.size(0)
        # pad the 0th dimension (T/sequence) with zeroes whose number = context
        # Once pytorch's padding functions have settled, should move to those.
        padding = torch.zeros(self.context, *(input.size()[1:])).type_as(input.data)
        x = torch.cat((input, Variable(padding)), 0)

        # add lookahead windows (with context+1 width) as a fourth dimension
        # for each seq-batch-feature combination
        x = [x[i:i + self.context + 1] for i in range(seq_len)]  # TxLxNxH - sequence, context, batch, feature
        x = torch.stack(x)
        x = x.permute(0, 2, 3, 1)  # TxNxHxL - sequence, batch, feature, context

        x = torch.mul(x, self.weight).sum(dim=3)
        return x

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + 'n_features=' + str(self.n_features) \
               + ', context=' + str(self.context) + ')'
Example #38
class LinearGroupNJ(BayesianLayers):
    """Fully Connected Group Normal-Jeffrey's layer (aka Group Variational Dropout).

    References:
    [1] Kingma, Diederik P., Tim Salimans, and Max Welling. "Variational dropout and the local reparameterization trick." NIPS (2015).
    [2] Molchanov, Dmitry, Arsenii Ashukha, and Dmitry Vetrov. "Variational Dropout Sparsifies Deep Neural Networks." ICML (2017).
    [3] Louizos, Christos, Karen Ullrich, and Max Welling. "Bayesian Compression for Deep Learning." NIPS (2017).
    """

    def __init__(self, in_features, out_features, cuda=False, init_weight=None, init_bias=None, clip_var=None):

        super(LinearGroupNJ, self).__init__()
        self.cuda = cuda
        self.in_features = in_features
        self.out_features = out_features
        self.clip_var = clip_var
        self.deterministic = False  # flag is used for compressed inference
        # trainable params according to Eq.(6)
        # dropout params
        self.z_mu = Parameter(torch.Tensor(in_features))
        self.z_logvar = Parameter(torch.Tensor(in_features))  # = z_mu^2 * alpha
        # weight params
        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_logvar = Parameter(torch.Tensor(out_features, in_features))

        self.bias_mu = Parameter(torch.Tensor(out_features))
        self.bias_logvar = Parameter(torch.Tensor(out_features))

        # init params either random or with pretrained net
        self.reset_parameters(init_weight, init_bias)

        # activations for kl
        self.sigmoid = nn.Sigmoid()
        self.softplus = nn.Softplus()

        # numerical stability param
        self.epsilon = 1e-8

    def reset_parameters(self, init_weight, init_bias):
        # init means
        stdv = 1. / math.sqrt(self.weight_mu.size(1))

        self.z_mu.data.normal_(1, 1e-2)

        if init_weight is not None:
            self.weight_mu.data = torch.Tensor(init_weight)
        else:
            self.weight_mu.data.normal_(0, stdv)

        if init_bias is not None:
            self.bias_mu.data = torch.Tensor(init_bias)
        else:
            self.bias_mu.data.fill_(0)

        # init logvars
        self.z_logvar.data.normal_(-9, 1e-2)
        self.weight_logvar.data.normal_(-9, 1e-2)
        self.bias_logvar.data.normal_(-9, 1e-2)

    def clip_variances(self):
        if self.clip_var:
            self.weight_logvar.data.clamp_(max=math.log(self.clip_var))
            self.bias_logvar.data.clamp_(max=math.log(self.clip_var))

    def get_log_dropout_rates(self):
        log_alpha = self.z_logvar - torch.log(self.z_mu.pow(2) + self.epsilon)
        return log_alpha

    def compute_posterior_params(self):
        weight_var, z_var = self.weight_logvar.exp(), self.z_logvar.exp()
        self.post_weight_var = self.z_mu.pow(2) * weight_var + z_var * self.weight_mu.pow(2) + z_var * weight_var
        self.post_weight_mu = self.weight_mu * self.z_mu
        # print("self.z_mu.pow(2): ", self.z_mu.pow(2).size())
        # print("weight_var: ", weight_var.size())
        # print("z_var: ", z_var.size())
        # print("self.weight_mu.pow(2): ", self.weight_mu.pow(2).size())
        # print("weight_var: ", weight_var.size())
        # print("post_weight_mu: ", self.post_weight_mu.size())
        # print("post_weight_var: ", self.post_weight_var.size())
        return self.post_weight_mu, self.post_weight_var

    def forward(self, x):
        if self.deterministic:
            assert not self.training, "Flag deterministic is True. This should not be used in training."
            return F.linear(x, self.post_weight_mu, self.bias_mu)

        batch_size = x.size()[0]
        # compute z
        # note that we reparametrise according to [2] Eq. (11) (not [1])
        z = reparametrize(self.z_mu.repeat(batch_size, 1), self.z_logvar.repeat(batch_size, 1), sampling=self.training,
                          cuda=self.cuda)

        # apply local reparametrisation trick see [1] Eq. (6)
        # to the parametrisation given in [3] Eq. (6)
        xz = x * z
        mu_activations = F.linear(xz, self.weight_mu, self.bias_mu)
        var_activations = F.linear(xz.pow(2), self.weight_logvar.exp(), self.bias_logvar.exp())

        return reparametrize(mu_activations, var_activations.log(), sampling=self.training, cuda=self.cuda)

    def kl_divergence(self):
        # KL(q(z)||p(z))
        # we use the kl divergence approximation given by [2] Eq.(14)
        k1, k2, k3 = 0.63576, 1.87320, 1.48695
        log_alpha = self.get_log_dropout_rates()
        KLD = -torch.sum(k1 * self.sigmoid(k2 + k3 * log_alpha) - 0.5 * self.softplus(-log_alpha) - k1)

        # KL(q(w|z)||p(w|z))
        # we use the kl divergence given by [3] Eq.(8)
        KLD_element = -0.5 * self.weight_logvar + 0.5 * (self.weight_logvar.exp() + self.weight_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        # KL bias
        KLD_element = -0.5 * self.bias_logvar + 0.5 * (self.bias_logvar.exp() + self.bias_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        return KLD

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
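LinearGroupNJ relies on a `reparametrize` helper and a `BayesianLayers` base class that are defined elsewhere. Below is a hedged sketch of what `reparametrize(mu, logvar, sampling, cuda)` presumably does (the standard reparameterization trick), together with a minimal usage pattern for training and for compressed, deterministic inference; the data term and the KL scaling factor are placeholders, not part of the original code.

import torch

def reparametrize(mu, logvar, sampling=True, cuda=False):
    # Assumed helper: return mu + sigma * eps while sampling, mu otherwise.
    if not sampling:
        return mu
    std = logvar.mul(0.5).exp()
    eps = torch.randn_like(std)
    return mu + std * eps

# Sketch of a training step: stochastic forward pass plus the layer's KL term.
layer = LinearGroupNJ(784, 300, clip_var=0.04)
x = torch.randn(32, 784)
out = layer(x)
loss = out.pow(2).mean() + layer.kl_divergence() / 60000.0  # placeholder data term, placeholder KL scaling
loss.backward()
layer.clip_variances()

# Compressed / deterministic inference after training.
layer.eval()
layer.compute_posterior_params()
layer.deterministic = True
out_det = layer(x)
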
Beispiel #39
0
class LSTMCellVB(RNNCellBase):

    def __init__(self, input_size, hidden_size, bias=True, prior = None):
        super(LSTMCellVB, self).__init__()
        self.type_layer = "LSTM"
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.posterior_mean = False # Flag to know if we sample from the posterior mean or we actually sample
        
        ## If no prior is specified, create a default one
        if prior is None:
            prior = Vil.Prior(0.5, np.log(0.1), np.log(0.5))
        self.prior = prior
        
        """
        Variational Inference Parameters
        """
        self.mu_weight_ih = Parameter(torch.Tensor(4 * hidden_size, input_size).to(device = Vil.device, dtype = Vil.dtype))
        self.mu_weight_hh = Parameter(torch.Tensor(4 * hidden_size, hidden_size).to(device = Vil.device, dtype = Vil.dtype))
        self.rho_weight_ih = Parameter(torch.Tensor(4 * hidden_size, input_size).to(device = Vil.device, dtype = Vil.dtype))
        self.rho_weight_hh = Parameter(torch.Tensor(4 * hidden_size, hidden_size).to(device = Vil.device, dtype = Vil.dtype))
        if bias:
            self.mu_bias_ih = Parameter(torch.Tensor(4 * hidden_size).to(device = Vil.device, dtype = Vil.dtype))
            self.mu_bias_hh = Parameter(torch.Tensor(4 * hidden_size).to(device = Vil.device, dtype = Vil.dtype))
            self.rho_bias_ih = Parameter(torch.Tensor(4 * hidden_size).to(device = Vil.device, dtype = Vil.dtype))
            self.rho_bias_hh = Parameter(torch.Tensor(4 * hidden_size).to(device = Vil.device, dtype = Vil.dtype))
        else:
            self.register_parameter('mu_bias_ih', None)
            self.register_parameter('mu_bias_hh', None)
            self.register_parameter('rho_bias_ih', None)
            self.register_parameter('rho_bias_hh', None)
        """
        Sampled weights
        """

        self.weight_ih = torch.Tensor(4 * hidden_size, input_size).to(device = Vil.device, dtype = Vil.dtype)
        self.weight_hh = torch.Tensor(4 * hidden_size, hidden_size).to(device = Vil.device, dtype = Vil.dtype)
        if bias:
            self.bias_ih = torch.Tensor(4 * hidden_size).to(device = Vil.device, dtype = Vil.dtype)
            self.bias_hh = torch.Tensor(4 * hidden_size).to(device = Vil.device, dtype = Vil.dtype)
        else:
            self.register_parameter('bias_ih', None)
            self.register_parameter('bias_hh', None)

        if(0): # set to 1 to re-enable the parameter-device debug prints below
            print ("linear bias_ih device: ",self.bias_ih.device)
            print ("linear weights_ih device: ",self.weight_ih.device)
            print ("linear bias mu_ih device: ",self.mu_bias_ih.device)
            print ("linear bias rho_ih device: ",self.rho_bias_ih.device)
            
            print ("linear weights mu_ih  device: ",self.mu_weight_ih.device)
            print ("linear weights rho_ih device: ",self.rho_weight_ih.device)
            
            print ("linear bias_hh device: ",self.bias_hh.device)
            print ("linear weights_hh device: ",self.weight_hh.device)
            print ("linear bias mu_hh device: ",self.mu_bias_hh.device)
            print ("linear bias rho_hh device: ",self.rho_bias_hh.device)
            
            print ("linear weights mu_hh  device: ",self.mu_weight_hh.device)
            print ("linear weights rho_hh device: ",self.rho_weight_hh.device)
            
        self.reset_parameters()
        self.sample_posterior()
        
    def reset_parameters(self):
        """
        Initialize the variational parameters from the prior.
        The variance of the weights is taken from the prior.
        TODO: should it also depend on the dimensionality?
        TODO: should the means follow the usual initialization scheme or be drawn from the prior? They could differ.
        """
        
        self.rho_weight_ih.data = Vil.init_rho(self.rho_weight_ih.size(), self.prior)
        self.rho_weight_hh.data = Vil.init_rho(self.rho_weight_hh.size(), self.prior)
        if self.bias:
            self.rho_bias_ih.data = Vil.init_rho(self.rho_bias_ih.size(), self.prior)
            self.rho_bias_hh.data = Vil.init_rho(self.rho_bias_hh.size(), self.prior)
            
        ## Now initialize the mean
        self.mu_weight_ih.data = Vil.init_mu(self.mu_weight_ih.size(), self.prior,Ninput = self.mu_weight_ih.size(1))
        self.mu_weight_hh.data = Vil.init_mu(self.mu_weight_hh.size(), self.prior,Ninput = self.mu_weight_hh.size(1))
        if self.bias:
            self.mu_bias_ih.data = Vil.init_mu(self.mu_bias_ih.size(), self.prior, Ninput = self.mu_weight_ih.size(1))
            self.mu_bias_hh.data = Vil.init_mu(self.mu_bias_hh.size(), self.prior, Ninput = self.mu_weight_hh.size(1))

    def sample_posterior(self):
        """
        Sample the Bayesian weights from the variational parameters and store them in the weight/bias variables.
        Sampling uses the reparametrization trick so that gradients flow with respect to mu and sigma.
        """
        if not self.posterior_mean:
            self.weight_ih = Vil.sample_posterior(self.mu_weight_ih, Vil.softplus(self.rho_weight_ih))
            self.weight_hh = Vil.sample_posterior(self.mu_weight_hh, Vil.softplus(self.rho_weight_hh))
            if self.bias:
                self.bias_ih = Vil.sample_posterior(self.mu_bias_ih, Vil.softplus(self.rho_bias_ih))
                self.bias_hh = Vil.sample_posterior(self.mu_bias_hh, Vil.softplus(self.rho_bias_hh))
        else:
            self.weight_ih.data = self.mu_weight_ih.data
            self.weight_hh.data = self.mu_weight_hh.data
            if self.bias:
                self.bias_hh.data = self.mu_bias_hh.data
                self.bias_ih.data = self.mu_bias_ih.data
                
    def get_KL_divergence(self):
        """
        Compute the KL loss for all the variational weights in this layer.
        It does not sample the weights again; it uses the ones that were already sampled.
        """
        KL_loss_ih = Vil.get_KL_divergence_Samples(self.mu_weight_ih, Vil.softplus(self.rho_weight_ih), self.weight_ih, self.prior)
        KL_loss_hh = Vil.get_KL_divergence_Samples(self.mu_weight_hh, Vil.softplus(self.rho_weight_hh), self.weight_hh, self.prior)
        
        KL_loss_bih = 0
        KL_loss_bhh = 0
        if self.bias:
            KL_loss_bih = Vil.get_KL_divergence_Samples(self.mu_bias_ih, Vil.softplus(self.rho_bias_ih), self.bias_ih, self.prior)
            KL_loss_bhh = Vil.get_KL_divergence_Samples(self.mu_bias_hh, Vil.softplus(self.rho_bias_hh), self.bias_hh, self.prior)
        KL_loss = KL_loss_ih + KL_loss_hh + KL_loss_bih + KL_loss_bhh
        return KL_loss
    
    def forward(self, input, hx=None):
        self.check_forward_input(input)
        if hx is None:
            hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
            hx = (hx, hx)
        self.check_forward_hidden(input, hx[0], '[0]')
        self.check_forward_hidden(input, hx[1], '[1]')
        return self._backend.LSTMCell(
            input, hx,
            self.weight_ih, self.weight_hh,
            self.bias_ih, self.bias_hh,
        )
    """
    Flag to set that we actually get the posterior mean and not a sample from the random variables
    """
    def set_posterior_mean(self, posterior_mean):
        self.posterior_mean = posterior_mean
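A hedged usage sketch for LSTMCellVB: weights are resampled once per minibatch with `sample_posterior()`, the cell is unrolled over the sequence, and the per-layer KL is added to the loss. Note that the forward pass above calls `self._backend.LSTMCell`, an internal interface of older PyTorch versions. `Vil`, the default prior, the data term and the KL scaling below are assumptions/placeholders.

import torch

cell = LSTMCellVB(input_size=20, hidden_size=50)

x_seq = torch.randn(35, 8, 20)          # (seq_len, batch, input_size)
hx = cx = torch.zeros(8, 50)

cell.sample_posterior()                 # draw one set of weights for this minibatch
for t in range(x_seq.size(0)):
    hx, cx = cell(x_seq[t], (hx, cx))

loss = hx.pow(2).mean() + cell.get_KL_divergence() / 1000.0  # placeholder data term and KL scaling

# At evaluation time, use the posterior mean instead of a random sample.
cell.set_posterior_mean(True)
cell.sample_posterior()
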
Beispiel #40
0
class LinearVB(nn.Module):
    """
    Bayesian linear layer with variational parameters:
        - mu: the mean of the weight distribution
        - rho: the parameter from which the standard deviation is obtained, sigma = softplus(rho)
    """
    def __init__(self, in_features, out_features, bias=True, prior = None):
        super(LinearVB, self).__init__()
        self.type_layer = "linear"
        # device= conf_a.device, dtype= conf_a.dtype,
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        self.posterior_mean = False # Flag to know if we sample from the posterior mean or we actually sample
        
        ## If no prior is specified, create a default one
        if prior is None:
            prior = Vil.Prior(0.5, np.log(0.1), np.log(0.5))
        prior = prior.get_standarized_Prior(in_features)
        self.prior = prior
        
        """
        Mean and rhos of the parameters
        """
        self.mu_weight = Parameter(torch.Tensor(out_features, in_features))# , requires_grad=True
        self.rho_weight = Parameter(torch.Tensor(out_features, in_features))
        
        if bias:
            self.rho_bias = Parameter(torch.Tensor(out_features))
            self.mu_bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('rho_bias', None)
            self.register_parameter('mu_bias', None)
            
        """
        The sampled weights
        """
        self.weight = torch.Tensor(out_features, in_features)
        if bias:
            self.bias = torch.Tensor(out_features)  # 1-D bias, as expected by F.linear
        else:
            self.register_parameter('bias', None)
        
        if(0):
            print ("linear bias device: ",self.bias.device)
            print ("linear weights device: ",self.weight.device)
            print ("linear bias mu device: ",self.mu_bias.device)
            print ("linear bias rho device: ",self.rho_bias.device)
            
            print ("linear weights mu  device: ",self.mu_weight.device)
            print ("linear weights rho device: ",self.rho_weight.device)
            
        ## Initialize the Variational variables
        self.reset_parameters()
        self.sample_posterior()
    
    def reset_parameters(self):
        """
        Initialize the variational parameters from the prior.
        The variance of the weights is taken from the prior.
        TODO: should it also depend on the dimensionality?
        TODO: should the means follow the usual initialization scheme or be drawn from the prior? They could differ.
        """
        # print("mu_bias prior LinearVB: ", self.prior.mu_bias)  # debug
        self.rho_weight.data = Vil.init_rho(self.mu_weight.size(), self.prior)
        if self.bias is not None:
            self.rho_bias.data = Vil.init_rho(self.mu_bias.size(), self.prior)
        
        ## Now initialize the mean
        self.mu_weight.data = Vil.init_mu(self.mu_weight.size(), self.prior,Ninput = self.mu_weight.size(1))
        if self.bias is not None:
            self.mu_bias.data = Vil.init_mu(self.mu_bias.size(), self.prior, Ninput = self.mu_weight.size(1))

    def sample_posterior(self):
        """
        Sample the Bayesian weights from the variational parameters and store them in the weight/bias variables.
        Sampling uses the reparametrization trick so that gradients flow with respect to mu and sigma.
        """
        if not self.posterior_mean:
            
            self.weight = Vil.sample_posterior(self.mu_weight, Vil.softplus(self.rho_weight))
            if self.bias is not None:
                self.bias = Vil.sample_posterior(self.mu_bias, Vil.softplus(self.rho_bias))
        else:
            self.weight.data = self.mu_weight.data
            if self.bias is not None:
                self.bias.data = self.mu_bias.data
        
    def get_KL_divergence(self):
        """
        Compute the KL loss for all the variational weights in this layer.
        It does not sample the weights again; it uses the ones that were already sampled.
        """
        KL_loss_W = Vil.get_KL_divergence_Samples(self.mu_weight, Vil.softplus(self.rho_weight),
                                                  self.weight, self.prior, mu_prior_fluid = self.prior.mu_weight)
        KL_loss_b = 0
        if self.bias is not None:
            KL_loss_b = Vil.get_KL_divergence_Samples(self.mu_bias, Vil.softplus(self.rho_bias), 
                                                      self.bias,  self.prior, mu_prior_fluid = self.prior.mu_bias)
            
        KL_loss = KL_loss_W + KL_loss_b
        
        return KL_loss
    
    def forward(self, X):
        """
        Function call to generate the output; every time we call it, the dynamic graph is created.
        Forward can differ between training and test:
            - In dropout we do not zero neurons at test time.
            - In variational inference we do not randomly sample from the posterior at test time.

        The forward pass performs operations between the input X (Nsam_batch, Ndim)
        and the parameters of the model initialized in __init__.
        """
        return F.linear(X, self.weight, self.bias)
    
    """
    Flag to set that we actually get the posterior mean and not a sample from the random variables
    """
    def set_posterior_mean(self, posterior_mean):
        self.posterior_mean = posterior_mean
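The `Vil` module used by these layers is not shown. The sketch below spells out what its two central helpers presumably compute for LinearVB (sigma via softplus and a reparameterized sample), followed by the usual Bayes-by-Backprop pattern of resampling weights per minibatch and adding the KL term; treat the helper bodies, the data term and the KL scaling as assumptions.

import torch
import torch.nn.functional as F

def softplus_sketch(rho):
    return F.softplus(rho)                    # sigma = log(1 + exp(rho)), assumed behaviour of Vil.softplus

def sample_posterior_sketch(mu, sigma):
    return mu + sigma * torch.randn_like(mu)  # reparameterization trick, assumed behaviour of Vil.sample_posterior

# Typical Bayes-by-Backprop step with LinearVB.
layer = LinearVB(100, 10)
x = torch.randn(32, 100)

layer.sample_posterior()                      # resample the weights for this minibatch
out = layer(x)
loss = out.pow(2).mean() + layer.get_KL_divergence() / 1000.0  # placeholder data term and KL scaling

# Deterministic prediction with the posterior mean.
layer.set_posterior_mean(True)
layer.sample_posterior()
out_mean = layer(x)
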
Beispiel #41
0
class LinearSimilarityVB(SimilarityFunction):
    """
    This similarity function performs a dot product between a vector of weights and some
    combination of the two input vectors, followed by an (optional) activation function.  The
    combination used is configurable.
    If the two vectors are ``x`` and ``y``, we allow the following kinds of combinations: ``x``,
    ``y``, ``x*y``, ``x+y``, ``x-y``, ``x/y``, where each of those binary operations is performed
    elementwise.  You can list as many combinations as you want, comma separated.  For example, you
    might give ``x,y,x*y`` as the ``combination`` parameter to this class.  The computed similarity
    function would then be ``w^T [x; y; x*y] + b``, where ``w`` is a vector of weights, ``b`` is a
    bias parameter, and ``[;]`` is vector concatenation.
    Note that if you want a bilinear similarity function with a diagonal weight matrix W, where the
    similarity function is computed as `x * w * y + b` (with `w` the diagonal of `W`), you can
    accomplish that with this class by using "x*y" for `combination`.
    Parameters
    ----------
    tensor_1_dim : ``int``
        The dimension of the first tensor, ``x``, described above.  This is ``x.size()[-1]`` - the
        length of the vector that will go into the similarity computation.  We need this so we can
        build weight vectors correctly.
    tensor_2_dim : ``int``
        The dimension of the second tensor, ``y``, described above.  This is ``y.size()[-1]`` - the
        length of the vector that will go into the similarity computation.  We need this so we can
        build weight vectors correctly.
    combination : ``str``, optional (default="x,y")
        Described above.
    activation : ``Activation``, optional (default=linear (i.e. no activation))
        An activation function applied after the ``w^T * [x;y] + b`` calculation.  Default is no
        activation.
    """
    def __init__(self,
                 tensor_1_dim: int,
                 tensor_2_dim: int,
                 combination: str = 'x,y',
                 activation: Activation = None,
                 prior = None) -> None:
        super(LinearSimilarityVB, self).__init__()
        self._combination = combination
        combined_dim = util.get_combined_dim(combination, [tensor_1_dim, tensor_2_dim])
        
        self.posterior_mean = False # Flag to know if we sample from the posterior mean or we actually sample
        
        ## If no prior is specified, create a default one
        if prior is None:
            prior = Vil.Prior(0.5, np.log(0.1), np.log(0.5))

        # combined_dim already gives the size of the combined feature vector
        size_combination = combined_dim
        prior = prior.get_standarized_Prior(size_combination)
        self.prior = prior
        
        """
        Mean and rhos of the parameters
        """
        self.mu_weight = Parameter(torch.Tensor(combined_dim))# , requires_grad=True
        self.rho_weight = Parameter(torch.Tensor(combined_dim))

        self.rho_bias = Parameter(torch.Tensor(1))
        self.mu_bias = Parameter(torch.Tensor(1))
            
        """
        The sampled weights
        """
        self.weight = torch.Tensor(combined_dim)
        self.bias = torch.Tensor(1)
        
        self._activation = activation or Activation.by_name('linear')()
        
        ## Initialize the Variational variables
        self.reset_parameters()
#        self.sample_posterior()

    def reset_parameters(self):
        self.rho_weight.data = Vil.init_rho(self.mu_weight.size(), self.prior)
        self.rho_bias.data = Vil.init_rho(self.mu_bias.size(), self.prior)
        
        ## Now initialize the mean
        self.mu_weight.data = Vil.init_mu(self.mu_weight.size(), 
                                          self.prior,Ninput = self.mu_weight.size(0), type = "LinearSimilarity")
        
        self.mu_bias.data = Vil.init_mu(self.mu_bias.size(), self.prior, Ninput = self.mu_weight.size(0), type = "LinearSimilarity")
        
    def sample_posterior(self):
        """
        Sample the Bayesian weights from the variational parameters and store them in the weight/bias variables.
        Sampling uses the reparametrization trick so that gradients flow with respect to mu and sigma.
        """
        if not self.posterior_mean:
            self.weight = Vil.sample_posterior(self.mu_weight, Vil.softplus(self.rho_weight))
            self.bias = Vil.sample_posterior(self.mu_bias, Vil.softplus(self.rho_bias))
        else:
            self.weight.data = self.mu_weight.data
            self.bias.data = self.mu_bias.data
                
    @overrides
    def forward(self, tensor_1: torch.Tensor, tensor_2: torch.Tensor) -> torch.Tensor:
        combined_tensors = util.combine_tensors(self._combination, [tensor_1, tensor_2])
        dot_product = torch.matmul(combined_tensors, self.weight)
        return self._activation(dot_product + self.bias)

    def get_KL_divergence(self):
        """
        Compute the KL loss for all the variational weights in this layer.
        It does not sample the weights again; it uses the ones that were already sampled.
        """
        KL_loss_W = Vil.get_KL_divergence_Samples(self.mu_weight, Vil.softplus(self.rho_weight), self.weight, self.prior)
        KL_loss_b = 0
        if self.bias is not None:
            KL_loss_b = Vil.get_KL_divergence_Samples(self.mu_bias, Vil.softplus(self.rho_bias), self.bias,  self.prior)
            
        KL_loss = KL_loss_W + KL_loss_b
        
        return KL_loss
    
    """
    Flag to set that we actually get the posterior mean and not a sample from the random variables
    """
    def set_posterior_mean(self, posterior_mean):
        self.posterior_mean = posterior_mean
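A hedged usage sketch for LinearSimilarityVB: unlike LinearVB, `sample_posterior()` is not called in `__init__` here, so it must be called before the first forward pass. The AllenNLP `Activation` registry and `util.combine_tensors` are assumed available, and the loss below is a placeholder.

import torch

sim = LinearSimilarityVB(tensor_1_dim=300, tensor_2_dim=300, combination='x,y,x*y')
sim.sample_posterior()                  # required before the first forward pass

a = torch.randn(32, 300)
b = torch.randn(32, 300)
scores = sim(a, b)                      # one similarity score per row, shape (32,)

loss = scores.mean() + sim.get_KL_divergence() / 1000.0  # placeholder data term and KL scaling

# Posterior-mean scoring at evaluation time.
sim.set_posterior_mean(True)
sim.sample_posterior()
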