def bn_gamma_beta(self, x):
    # Broadcast the learned shift (and optional scale) across the batch via
    # an outer product of a ones column with the 1 x D parameter row.
    if self.use_cuda:
        ones = Parameter(torch.ones(x.size()[0], 1).cuda())
    else:
        ones = Parameter(torch.ones(x.size()[0], 1))
    t = x + ones.mm(self.bn_beta)
    if self.train_bn_scaling:
        t = torch.mul(t, ones.mm(self.bn_gamma))
    return t
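The ones.mm(...) pattern above is just an outer-product way of tiling a 1 x D bias row over the batch; on modern PyTorch, plain broadcasting does the same thing. A minimal equivalence check (my sketch; the shapes are illustrative assumptions, not taken from the original class):

import torch

beta = torch.randn(1, 3)               # a 1 x D row, like self.bn_beta
x = torch.randn(4, 3)                  # a batch of 4
ones = torch.ones(x.size(0), 1)
assert torch.allclose(x + ones.mm(beta), x + beta)  # broadcasting matches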
Example #2
class GraphLearning(Module):
    def __init__(self, in_features):
        super(GraphLearning, self).__init__()
        self.in_features = in_features
        self.weight = Parameter(torch.FloatTensor(1, in_features))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, inputs):
        # Pairwise similarity s[i, j] = exp(relu(w . |x_i - x_j|)); softmax
        # over each row then yields the learned adjacency A.
        s = torch.zeros((len(inputs), len(inputs)), device=inputs[0].device)
        for i in range(len(inputs)):
            for j in range(len(inputs)):
                s[i, j] = torch.exp(
                    F.relu(
                        self.weight.mm(
                            torch.abs(inputs[i] -
                                      inputs[j]).unsqueeze(0).t())))[0][0]
        A = F.softmax(s, dim=1)
        D = torch.diag(torch.sum(A, dim=1))  # degree matrix of A
        return A, D

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.in_features) + ')'
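The nested Python loops above launch O(N^2) tiny kernels; the same A and D can be computed with broadcasting. A vectorized sketch (my rewrite, not part of the original class; graph_learning_dense is a hypothetical name):

import torch
import torch.nn.functional as F

def graph_learning_dense(inputs, weight):
    # inputs: (N, F) node features, weight: (1, F) as in GraphLearning
    diff = (inputs.unsqueeze(1) - inputs.unsqueeze(0)).abs()  # (N, N, F)
    s = torch.exp(F.relu((diff * weight).sum(-1)))            # (N, N)
    A = F.softmax(s, dim=1)
    D = torch.diag(A.sum(dim=1))
    return A, D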
Example #3
def encode(self, x, add_noise=False):
    # Optionally corrupt the input (denoising-autoencoder style), then
    # apply the affine map followed by a sigmoid.
    if add_noise:
        tilde_x = self.corrupt(x)
    else:
        tilde_x = x.clone()
    ones = Parameter(torch.ones(self.batch_size, 1))
    t = tilde_x.mm(self.W)
    t = t + ones.mm(self.b)
    t = self.sigmoid1.forward(t)
    return t
Example #4
    def g(self, tilde_z_l, u_l):
        # The original passed self.use_cuda (apparently a bool) straight to
        # .to(); mapping it to an explicit device string here.
        device = 'cuda' if self.use_cuda else 'cpu'
        ones = Parameter(torch.ones(tilde_z_l.size()[0], 1, device=device))
        b_a1 = ones.mm(self.a1)
        b_a2 = ones.mm(self.a2)
        b_a3 = ones.mm(self.a3)
        b_a4 = ones.mm(self.a4)
        b_a5 = ones.mm(self.a5)

        b_a6 = ones.mm(self.a6)
        b_a7 = ones.mm(self.a7)
        b_a8 = ones.mm(self.a8)
        b_a9 = ones.mm(self.a9)
        b_a10 = ones.mm(self.a10)

        # mu(u) and v(u): per-unit sigmoid-modulated affine functions of u_l
        mu_l = torch.mul(b_a1, torch.sigmoid(torch.mul(b_a2, u_l) + b_a3)) + \
               torch.mul(b_a4, u_l) + \
               b_a5

        v_l = torch.mul(b_a6, torch.sigmoid(torch.mul(b_a7, u_l) + b_a8)) + \
              torch.mul(b_a9, u_l) + \
              b_a10

        hat_z_l = torch.mul(tilde_z_l - mu_l, v_l) + mu_l

        return hat_z_l
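This g(tilde_z_l, u_l) appears to be the "vanilla" denoising combinator of the Ladder Network: hat_z = (tilde_z - mu(u)) * v(u) + mu(u), with mu(u) = a1 * sigmoid(a2 * u + a3) + a4 * u + a5 and v(u) built the same way from a6..a10. Each a_i is a 1 x D row of per-unit parameters, and ones.mm(a_i) simply tiles it over the batch.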
Example #5
class Dainet(torch.nn.Module):
    def __init__(self, m, n):
        super(Dainet, self).__init__()
        self.vec_1t_n = torch.ones(1, n)
        self.vec_1t_m = torch.ones(1, m)
        self.m = m
        self.r = Parameter(torch.ones(m, 1))

        print(self.r)
        print(self.r.size())
        print(self.r.requires_grad)

    def forward(self, x):
        r1t = self.r.mm(self.vec_1t_n)   # (m, n): row i filled with r_i
        r1tA = r1t.mul(x)                # elementwise r_i * x_ij
        output = self.vec_1t_m.mm(r1tA)  # (1, n): weighted column sums

        # fixed: the original referenced an undefined global `m`
        new_m = self.r.t().mm(torch.ones(self.m, 1))  # sum of the weights
        return output / new_m
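So Dainet.forward is a learned weighted average of the rows of x, normalized by the sum of the weights r. A quick numeric check (my sketch; shapes are illustrative):

import torch

m, n = 3, 4
r = torch.tensor([[1.0], [2.0], [3.0]])
x = torch.arange(12, dtype=torch.float32).view(m, n)

expected = (r * x).sum(0, keepdim=True) / r.sum()      # weighted row average
out = torch.ones(1, m).mm(r.mm(torch.ones(1, n)).mul(x)) \
      / r.t().mm(torch.ones(m, 1))
assert torch.allclose(out, expected)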
Example #6
class SVDConv2d(Module):
    '''
    W = UdV
    '''
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 norm=False):
        self.eps = 1e-8
        self.norm = norm

        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super(SVDConv2d, self).__init__()

        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.total_in_dim = in_channels * kernel_size[0] * kernel_size[1]
        self.weiSize = (self.out_channels, in_channels, kernel_size[0],
                        kernel_size[1])

        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.output_padding = _pair(0)
        self.groups = groups

        self.scale = Parameter(torch.Tensor(1))
        self.scale.data.fill_(1)

        if self.out_channels <= self.total_in_dim:
            self.Uweight = Parameter(
                torch.Tensor(self.out_channels, self.out_channels))
            self.Dweight = Parameter(torch.Tensor(self.out_channels))
            self.Vweight = Parameter(
                torch.Tensor(self.out_channels, self.total_in_dim))
            self.Uweight.data.normal_(0, math.sqrt(2. / self.out_channels))
            self.Vweight.data.normal_(0, math.sqrt(2. / self.total_in_dim))
            self.Dweight.data.fill_(1)
        else:
            self.Uweight = Parameter(
                torch.Tensor(self.out_channels, self.total_in_dim))
            self.Dweight = Parameter(torch.Tensor(self.total_in_dim))
            self.Vweight = Parameter(
                torch.Tensor(self.total_in_dim, self.total_in_dim))
            self.Uweight.data.normal_(0, math.sqrt(2. / self.out_channels))
            self.Vweight.data.normal_(0, math.sqrt(2. / self.total_in_dim))
            self.Dweight.data.fill_(1)
        self.projectiter = 0
        self.project(style='qr', interval=1)

        if bias:
            self.bias = Parameter(torch.Tensor(self.out_channels))
            self.bias.data.fill_(0)
        else:
            self.register_parameter('bias', None)

        if norm:
            self.register_buffer(
                'input_norm_wei',
                torch.ones(1, in_channels // groups, *kernel_size))

    def update_sigma(self):
        self.Dweight.data = self.Dweight.data / self.Dweight.data.abs().max()

    def spectral_reg(self):
        # assumes Dweight stays positive; log() would return NaN otherwise
        return -(torch.log(self.Dweight)).mean()

    @property
    def W_(self):
        self.update_sigma()
        return self.Uweight.mm(self.Dweight.diag()).mm(self.Vweight).view(
            self.weiSize) * self.scale

    def forward(self, input):
        _output = F.conv2d(input, self.W_, self.bias, self.stride,
                           self.padding, self.dilation, self.groups)
        return _output

    def orth_reg(self):
        penalty = 0

        if self.out_channels <= self.total_in_dim:
            W = self.Uweight
        else:
            W = self.Uweight.t()
        WWt = W.mm(W.t())
        # Variable is a deprecated no-op wrapper; also match W's device
        # instead of hard-coding .cuda()
        I = torch.eye(WWt.size(0), device=WWt.device)
        penalty = penalty + ((WWt - I) ** 2).sum()

        W = self.Vweight
        WWt = W.mm(W.t())
        I = torch.eye(WWt.size(0), device=WWt.device)
        penalty = penalty + ((WWt - I) ** 2).sum()
        return penalty

    def project(self, style='none', interval=1):
        '''
        Project U and V back to (semi-)orthogonal matrices every `interval`
        calls. Note: torch.qr and torch.svd are deprecated on recent PyTorch
        in favor of torch.linalg.qr and torch.linalg.svd.
        '''
        self.projectiter = self.projectiter + 1
        if style == 'qr' and self.projectiter % interval == 0:
            # Compute the qr factorization for U
            if self.out_channels <= self.total_in_dim:
                q, r = torch.qr(self.Uweight.data.t())
            else:
                q, r = torch.qr(self.Uweight.data)
            # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
            d = torch.diag(r, 0)
            ph = d.sign()
            q *= ph
            if self.out_channels <= self.total_in_dim:
                self.Uweight.data = q.t()
            else:
                self.Uweight.data = q

            # Compute the qr factorization for V
            q, r = torch.qr(self.Vweight.data.t())
            # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
            d = torch.diag(r, 0)
            ph = d.sign()
            q *= ph
            self.Vweight.data = q.t()
        elif style == 'svd' and self.projectiter % interval == 0:
            # Compute the svd factorization (may be not stable) for U
            u, s, v = torch.svd(self.Uweight.data)
            self.Uweight.data = u.mm(v.t())

            # Compute the svd factorization (may be not stable) for V
            u, s, v = torch.svd(self.Vweight.data)
            self.Vweight.data = u.mm(v.t())

    def showOrthInfo(self):
        s = self.Dweight.data
        _D = self.Dweight.data.diag()
        W = self.Uweight.data.mm(_D).mm(self.Vweight.data)
        _, ss, _ = torch.svd(W.t())
        print('Singular Value Summary: ')
        print('max :', s.max().item(), 'max* :', ss.max().item())
        print('mean:', s.mean().item(), 'mean*:', ss.mean().item())
        print('min :', s.min().item(), 'min* :', ss.min().item())
        print('var :', s.var().item(), 'var* :', ss.var().item())
        print('s RMSE: ', ((s - ss)**2).mean().item()**0.5)
        if self.out_channels <= self.total_in_dim:
            pu = (self.Uweight.data.mm(self.Uweight.data.t()) -
                  torch.eye(self.Uweight.size(0),
                            device=self.Uweight.device)).norm().item()**2
        else:
            pu = (self.Uweight.data.t().mm(self.Uweight.data) -
                  torch.eye(self.Uweight.size(1),
                            device=self.Uweight.device)).norm().item()**2
        pv = (self.Vweight.data.mm(self.Vweight.data.t()) -
              torch.eye(self.Vweight.size(0),
                        device=self.Vweight.device)).norm().item()**2
        print('penalty :', pu, ' (U) + ', pv, ' (V)')
        return ss
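A minimal usage sketch (assumes the imports this class relies on: torch, math, torch.nn.functional as F, Parameter, Module, and _pair from torch.nn.modules.utils; CPU is fine since the device-specific calls above now follow the weights):

conv = SVDConv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
x = torch.randn(2, 3, 32, 32)
y = conv(x)                   # convolves with W = U diag(D) V, times scale
reg = conv.spectral_reg()     # pushes singular values toward 1
conv.project(style='qr')      # re-orthogonalize U and V after an update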
Example #7
def decode(self, x):
    # Reconstruct with the tied (transposed) encoder weights plus the
    # decoder bias b_prime.
    ones = Parameter(torch.ones(self.batch_size, 1))
    t = x.mm(self.W.transpose(1, 0)) + ones.mm(self.b_prime)
    t = self.sigmoid2.forward(t)
    return t
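Together with encode from Example #3, this is a tied-weights autoencoder: one matrix W maps inputs to the hidden code and its transpose maps back. A self-contained sketch of the round trip (dimensions are illustrative assumptions):

import torch

batch_size, n_in, n_hid = 8, 20, 5
W = torch.randn(n_in, n_hid) * 0.1
b = torch.zeros(1, n_hid)
b_prime = torch.zeros(1, n_in)
x = torch.rand(batch_size, n_in)

h = torch.sigmoid(x.mm(W) + b)                 # encode
x_hat = torch.sigmoid(h.mm(W.t()) + b_prime)   # decode with W transposed
print(x_hat.shape)                             # torch.Size([8, 20])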
Example #8
class LSTM_attention(nn.Module):
    def __init__(self, word_dim1, word_dim2, hidden_dim=128):
        super(LSTM_attention, self).__init__()
        # Assign instance variables
        self.word_dim1 = word_dim1
        self.word_dim2 = word_dim2
        self.hidden_dim = hidden_dim
        self.beta1 = 0.9
        self.beta2 = 0.999
        self.eps = 1e-8
        # Randomly initialize the network parameters
        self.U = Parameter(torch.randn(4 * hidden_dim, word_dim1))
        self.W = Parameter(torch.randn(4 * hidden_dim, hidden_dim))

        self.U_ = Parameter(torch.randn(4 * hidden_dim, word_dim2))
        self.W_ = Parameter(torch.randn(4 * hidden_dim, hidden_dim))

        self.V = Parameter(torch.randn(word_dim2, 2 * hidden_dim))

        # Adam moment buffers (vestigial: training below uses torch.optim.Adam)
        self.vU = torch.zeros_like(self.U)
        self.vV = torch.zeros_like(self.V)
        self.vW = torch.zeros_like(self.W)
        self.mU = torch.zeros_like(self.U)
        self.mV = torch.zeros_like(self.V)
        self.mW = torch.zeros_like(self.W)

        self.vU_ = torch.zeros_like(self.U_)
        self.vW_ = torch.zeros_like(self.W_)
        self.mU_ = torch.zeros_like(self.U_)
        self.mW_ = torch.zeros_like(self.W_)

        self.Loss = 0

    def forward_propagation(self, x, y):
        # The total number of time steps
        T1 = len(x)
        # During forward propagation we save all hidden states in h because we need them later.
        # We add one additional element for the initial hidden, which we set to 0
        h = torch.zeros((T1 + 1, self.hidden_dim))
        h[-1] = torch.zeros(self.hidden_dim)
        c = torch.zeros((T1 + 1, self.hidden_dim))
        o = torch.zeros((T1, self.hidden_dim))
        i = torch.zeros((T1, self.hidden_dim))
        f = torch.zeros((T1, self.hidden_dim))
        g = torch.zeros((T1, self.hidden_dim))
        # The outputs at each time step. Again, we save them for later.
        # For each time step...
        H = self.hidden_dim
        for t in torch.arange(T1):
            # Note that we are indexing U by x[t]. This is the same as multiplying U with a one-hot vector.
            temp = self.U[:, x[t]] + self.W.mm(h[t - 1].clone().view(
                -1, 1)).squeeze()
            i[t] = sigmoid(temp[:H])
            f[t] = sigmoid(temp[H:2 * H])
            o[t] = sigmoid(temp[2 * H:3 * H])
            g[t] = torch.tanh(temp[3 * H:])
            c[t] = f[t].clone() * (c[t - 1].clone()) + i[t].clone() * (
                g[t].clone())
            h[t] = o[t].clone() * torch.tanh(c[t].clone())

        T2 = len(y)
        # During forward propagation we save all decoder states in s because we need them later.
        # We add one additional element for the initial hidden, initialized from the encoder's final state.
        s = torch.zeros((T2 + 1, self.hidden_dim))
        s[-1] = h[-2]
        c_ = torch.zeros((T2 + 1, self.hidden_dim))
        c_[-1] = c[-2]
        o_ = torch.zeros((T2, self.hidden_dim))
        i_ = torch.zeros((T2, self.hidden_dim))
        f_ = torch.zeros((T2, self.hidden_dim))
        g_ = torch.zeros((T2, self.hidden_dim))
        # The outputs at each time step. Again, we save them for later.
        output = torch.zeros((T2, self.word_dim2))
        result = torch.zeros((T2, self.hidden_dim * 2))
        alpha = torch.zeros((T2, T1))
        # For each time step...
        for t in torch.arange(T2):
            # Note that we are indexing U_ by y[t]. This is the same as multiplying U_ with a one-hot vector.
            temp_ = self.U_[:, y[t]] + self.W_.mm(s[t - 1].clone().view(
                -1, 1)).squeeze()
            i_[t] = sigmoid(temp_[:H])
            f_[t] = sigmoid(temp_[H:2 * H])
            o_[t] = sigmoid(temp_[2 * H:3 * H])
            g_[t] = torch.tanh(temp_[3 * H:])
            c_[t] = f_[t].clone() * (c_[t - 1].clone()) + i_[t].clone() * (
                g_[t].clone())
            s[t] = o_[t].clone() * torch.tanh(c_[t].clone())
            e = torch.mm(h[:-1].clone(), s[t].clone().view(-1, 1)).squeeze()
            alpha[t] = softmax(e)
            a = torch.mm(alpha[t].clone().view(1, -1),
                         h[:-1].clone()).squeeze()
            result[t] = torch.cat((a, s[t]), 0)
            output[t] = softmax(self.V.mm(result[t].clone().view(
                -1, 1))).squeeze()

        return output

    def bptt(self, x, y):
        T2 = len(y)
        y_ = y[1:]
        y_.append(1)
        # Perform forward propagation
        output = self.forward_propagation(x, y)
        # We accumulate the gradients in these variables
        loss = -torch.log(
            torch.gather(output, 1,
                         torch.tensor(y_).view(output.shape[0], 1)))
        loss = torch.sum(loss)

        return loss

    def predict(self, x):
        T1 = len(x)
        # During forward propagation we save all hidden states in h because we need them later.
        # We add one additional element for the initial hidden, which we set to 0
        h = torch.zeros((T1 + 1, self.hidden_dim))
        h[-1] = torch.zeros(self.hidden_dim)
        c = torch.zeros((T1 + 1, self.hidden_dim))
        o = torch.zeros((T1, self.hidden_dim))
        i = torch.zeros((T1, self.hidden_dim))
        f = torch.zeros((T1, self.hidden_dim))
        g = torch.zeros((T1, self.hidden_dim))
        # The outputs at each time step. Again, we save them for later.
        # For each time step...
        H = self.hidden_dim
        for t in torch.arange(T1):
            # Note that we are indexing U by x[t]. This is the same as multiplying U with a one-hot vector.
            temp = self.U[:, x[t]] + self.W.mm(h[t - 1].clone().view(
                -1, 1)).squeeze()
            i[t] = sigmoid(temp[:H])
            f[t] = sigmoid(temp[H:2 * H])
            o[t] = sigmoid(temp[2 * H:3 * H])
            g[t] = torch.tanh(temp[3 * H:])
            c[t] = f[t].clone() * (c[t - 1].clone()) + i[t].clone() * (
                g[t].clone())
            h[t] = o[t].clone() * torch.tanh(c[t].clone())

        s_pre = h[-2]
        c__pre = c[-2]
        z = 0
        pred = []
        att = []

        # The outputs at each time step. Again, we save them for later.
        output = torch.zeros((self.word_dim2))
        # For each time step...
        step = 0
        while z != 1 and step <= 10:
            temp_ = self.U_[:, z] + self.W_.mm(s_pre.clone().view(
                -1, 1)).squeeze()
            i_ = sigmoid(temp_[:H])
            f_ = sigmoid(temp_[H:2 * H])
            o_ = sigmoid(temp_[2 * H:3 * H])
            g_ = torch.tanh(temp_[3 * H:])
            c_ = f_.clone() * (c__pre.clone()) + i_.clone() * (g_.clone())
            s = o_.clone() * torch.tanh(c_.clone())
            e = torch.mm(h[:-1].clone(), s.clone().view(-1, 1)).squeeze()
            alpha = softmax(e)
            a = torch.mm(alpha.clone().view(1, -1), h[:-1].clone()).squeeze()
            result = torch.cat((a, s), 0)
            output = softmax(self.V.mm(result.clone().view(-1, 1))).squeeze()
            z = torch.argmax(output)
            pred.append(z)
            att.append(alpha)
            s_pre = s
            c__pre = c_
            step += 1

        return pred, att

    # Performs one optimization step on a single example.
    def numpy_sdg_step(self, x, y, learning_rate):
        # fixed: the original referenced an undefined global `model`.
        # NOTE: constructing a fresh Adam every step discards its moment
        # estimates; normally one optimizer would be created once and reused.
        optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
        loss = self.bptt(x, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss

    def train_with_sgd(self,
                       X_train,
                       y_train,
                       learning_rate=0.0025,
                       nepoch=100):
        # We keep track of the losses so we can plot them later
        losses = []
        num_examples_seen = 1
        Loss_len = 0
        for epoch in range(nepoch):
            # For each training example...
            for i in range(len(y_train)):
                # One SGD step
                self.Loss += self.numpy_sdg_step(X_train[i], y_train[i],
                                                 learning_rate)
                Loss_len += 1
                if num_examples_seen == len(y_train) * nepoch:
                    last_loss = self.Loss.item() / Loss_len
                elif num_examples_seen % (1024 * 3) == 0:
                    time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
                    print(time,
                          ' ',
                          int(100 * num_examples_seen /
                              (len(y_train) * nepoch)),
                          end='')
                    print('% done!!!', end='')
                    print('   loss :', self.Loss.item() / Loss_len)
                    self.Loss = 0
                    Loss_len = 0
                num_examples_seen += 1

        return last_loss
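Inside the decoder loop above, the attention is plain dot-product ("Luong-style") attention: scores e = h s_t, weights alpha = softmax(e), context a = alpha h, and the output distribution softmax(V [a; s_t]). A stand-alone sketch of just that step (dimension names are illustrative assumptions):

import torch
import torch.nn.functional as F

T1, H = 5, 8
h = torch.randn(T1, H)             # encoder hidden states
s_t = torch.randn(H)               # current decoder state
e = h.mv(s_t)                      # dot-product scores, shape (T1,)
alpha = F.softmax(e, dim=0)        # attention weights over source steps
a = alpha @ h                      # context vector, shape (H,)
result = torch.cat((a, s_t), 0)    # what gets multiplied by V above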
Example #9
class SpGraphAttentionLayer(Module):
    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(SpGraphAttentionLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        self.W = Parameter(torch.FloatTensor(in_features, out_features))
        self.a = Parameter(torch.FloatTensor(1, 2 * out_features))

        self.dropout = nn.Dropout(dropout)
        self.leakyrelu = nn.LeakyReLU(self.alpha)
        self.special_spmm = SpecialSpmm()
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

    def forward(self, input, adj):
        # Apply the linear feature transformation
        h = torch.mm(input, self.W)
        # h: N x out
        assert not torch.isnan(h).any()
        N = input.size()[0]

        # edge: 2 x E (source and target node index for each edge)
        edge = adj._indices()

        # Self-attention on the nodes - shared attention mechanism
        # edge_h: 2*out x E (features of both endpoints, stacked)
        edge_h = torch.cat((h[edge[0, :], :], h[edge[1, :], :]), dim=1).t()
        # edge_e: E
        edge_e = self.leakyrelu(self.a.mm(edge_h).squeeze())

        # Sparse softmax. Softmax is shift-invariant (softmax([1, 3]) ==
        # softmax([-2, 0]) == [0.1192, 0.8808]), so a numerically safer
        # variant would subtract each row's max before exponentiating;
        # here the scores are simply negated before exp instead.
        edge_e = torch.exp(-edge_e)
        assert not torch.isnan(edge_e).any()
        # Do sum
        # e_rowsum: N x 1
        e_rowsum = self.special_spmm(edge, edge_e, torch.Size([N, N]),
                                     input.new_full([N, 1], fill_value=1))

        # edge_e: E
        edge_e = self.dropout(edge_e)

        # h_prime: N x out
        h_prime = self.special_spmm(edge, edge_e, torch.Size([N, N]), h)
        assert not torch.isnan(h_prime).any()

        # h_prime: N x out
        h_prime = h_prime.div(e_rowsum + 1e-9)
        assert not torch.isnan(h_prime).any()

        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
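SpecialSpmm is defined elsewhere in this codebase; in the reference SpGAT implementation it wraps a sparse-dense matmul with a custom backward pass. A minimal stand-in sketch for recent PyTorch, where torch.sparse.mm supports autograd (an assumption; the original custom Function predates that):

import torch

class SpecialSpmm(torch.nn.Module):
    # multiply a sparse matrix, given as (indices, values, shape), by dense b
    def forward(self, indices, values, shape, b):
        a = torch.sparse_coo_tensor(indices, values, shape)
        return torch.sparse.mm(a, b)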
Example #10
class Nnsae(nn.Module):

    # Constructor for an NNSAE
    # input:
    #  - inpDim: dimension of a data sample
    #  - hidDim: size of the hidden layer
    # output:
    #  - net: the created Non-Negative Sparse Autoencoder
    __constants__ = ['inpDim', 'hidDim']

    def __init__(self, inpDim, hidDim, batch_size=1):
        torch.autograd.set_detect_anomaly(True)
        super(Nnsae, self).__init__()
        self.inpDim = inpDim  # number of input neurons (and output neurons)
        self.hidDim = hidDim  # number of hidden neurons
        self.nonlin = torch.sigmoid

        self.inp = torch.zeros(self.inpDim, 1)  # vector holding current input
        self.out = torch.zeros(self.hidDim, 1)  # output neurons
        # neural activity before non-linearity
        self.h = torch.zeros(self.hidDim,
                             batch_size)  # hidden neuron activation
        self.g = torch.zeros(self.hidDim, batch_size)  # pre hidden neuron
        self.a = Parameter(torch.ones(self.hidDim, 1))
        self.b = Parameter(torch.ones(self.hidDim, 1) * (-3.0))
        self.weights = Parameter(torch.zeros(inpDim, hidDim))
        self.scale = 0.025
        self.weights.data = self.scale * (2 * torch.rand(
            inpDim, hidDim) - 0.5 * torch.ones(inpDim, hidDim)) + self.scale

        # learning rate for synaptic plasticity of read-out layer (RO)
        self.lrateRO = 0.01
        self.regRO = 0.0002  # numerical regularization constant

        self.lrateIP = 0.001  # learning rate for intrinsic plasticity (IP)
        self.meanIP = 0.2  # desired mean activity, a parameter of IP
        self._cuda = False

    def __setstate__(self, state):
        super(Nnsae, self).__setstate__(state)

    @property
    def cuda(self):
        # NOTE: shadows nn.Module.cuda(); use .to(device) to move the model
        return self._cuda

    def to(self, device):
        super().to(device)  # moves the registered Parameters in place
        self._cuda = device.type != 'cpu'
        # plain tensors are not registered buffers, so reassign explicitly
        # (the original called .to() on Parameters and discarded the result)
        self.h = self.h.to(device)
        self.g = self.g.to(device)
        self.out = self.out.to(device)
        self.inp = self.inp.to(device)
        return self

    def ip(self):
        # Intrinsic-plasticity update: a local rule, not backprop. Run under
        # torch.no_grad(), since a and b are Parameters updated in place.
        h = self.h
        tmp = self.lrateIP * (1.0 - (2.0 + 1.0 / self.meanIP) * h +
                              (h**2) / self.meanIP)
        self.b += tmp.sum(1, keepdim=True)
        a_tmp = self.lrateIP / self.a + self.g * tmp
        self.a += a_tmp.sum(1, keepdim=True)

    def bpdc(self, error):
        # calculate adaptive learning rate
        lrate = (self.lrateRO / (self.regRO +
                                 (self.h**2).sum(0, keepdim=True))).diag()
        self.weights.data += error.mm(lrate * (self.h).t())

    def fit(self, inp):
        # forward path
        out = self.forward(inp)
        # bpdc.step()
        error = inp - out
        self.bpdc(error)
        # non negative constraint
        self.weights.data[self.weights < 0] = 0
        # intrinsic plasticity
        self.ip()
        return out, error

    def forward(self, x):
        # encode: linear map, then sigmoid with per-neuron gain a and bias b
        g = self.weights.t().mm(x)
        h = self.nonlin(self.a * g + self.b)
        out = self.weights.mm(h)  # decode with the tied non-negative weights

        self.g[:, :] = g.detach()
        self.h[:, :] = h.detach()
        return out

    def save_state_dict(self, fileName):
        torch.save(self.state_dict(), fileName)

    def extra_repr(self):
        s = ('({inpDim} x {hidDim})')
        s += ', Intrinsic plasticity: mean={meanIP}, learning rate={lrateIP}'
        s += '; Synaptic plasticity: learning rate={lrateRO}, epsilon={regRO}'
        return s.format(**self.__dict__)
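A minimal training sketch (my assumptions: inputs are column-major, shape (inpDim, batch_size), to match weights.t().mm(x) in forward; the updates are local rules rather than backprop, so the loop runs under torch.no_grad() to allow the in-place Parameter updates in ip()):

import torch

net = Nnsae(inpDim=16, hidDim=4, batch_size=8)
x = torch.rand(16, 8)              # non-negative data suits the constraint
with torch.no_grad():
    for _ in range(50):
        out, err = net.fit(x)      # BPDC weight update + intrinsic plasticity
print((err ** 2).mean().item())    # reconstruction error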