# Readout / classification head built on top of the recurrent spiking network.
# `outputs`, `FLAGS`, `dt`, `n_output_symbols`, `rec_neuron_sign` and the helpers
# used below (exp_convolve, weight_sampler, einsum_bij_jk_to_bik, tf_downsample)
# are defined elsewhere in the script.

# Split the spike tensor along its last axis into the first n_regular neurons
# and the remaining (adaptive) neurons.
# NOTE(review): assumes `outputs` is rank 3, presumably (batch, time, neurons) —
# confirm against the RNN call that produces it.
z = outputs
z_regular = z[:, :, :FLAGS.n_regular]
z_adaptive = z[:, :, FLAGS.n_regular:]

with tf.name_scope('ClassificationLoss'):
    # Per-step decay factor exp(-dt / tau_v) for filtering the spikes before
    # the linear readout.
    psp_decay = np.exp(
        -dt / FLAGS.tau_v
    )  # output layer psp decay; choose a value between 15 and 30 ms, as for tau_v
    # Presumably an exponential low-pass filter of the spike trains with the
    # decay above (exp_convolve is defined elsewhere — confirm).
    psp = exp_convolve(z, decay=psp_decay)
    n_neurons = z.get_shape()[2]

    # Define the readout weights
    # With rewiring, weight_sampler returns a 4-tuple whose first element is used
    # as the weight value; otherwise a dense [n_neurons, n_output_symbols]
    # variable is created with the TF1 default initializer.
    # NOTE(review): the recurrent cell only enables rewiring for
    # 0 < rewiring_connectivity < 1, while this branch accepts any value > 0 —
    # confirm whether values >= 1 are intended to reach weight_sampler here.
    # NOTE(review): w_out_sign / w_out_var are bound only in the rewiring branch.
    if 0 < FLAGS.rewiring_connectivity:
        w_out, w_out_sign, w_out_var, _ = weight_sampler(
            FLAGS.n_regular + FLAGS.n_adaptive,
            n_output_symbols,
            FLAGS.rewiring_connectivity,
            neuron_sign=rec_neuron_sign)
    else:
        w_out = tf.get_variable(name='out_weight',
                                shape=[n_neurons, n_output_symbols])
    # Output bias, initialized to zeros.
    b_out = tf.get_variable(name='out_bias',
                            shape=[n_output_symbols],
                            initializer=tf.zeros_initializer())

    # Define the loss function
    # Linear readout of the filtered spikes; per the helper's name this computes
    # out[b, i, k] = sum_j psp[b, i, j] * w_out[j, k], then adds b_out.
    out = einsum_bij_jk_to_bik(psp, w_out) + b_out

    if FLAGS.crs_thr:
        # Resample the readout along axis 1 to (28 + 2) * ext_time steps.
        # NOTE(review): the trailing shape comment looks like batch 32, 30 steps,
        # 10 classes for ext_time == 1 — confirm.
        outt = tf_downsample(out, new_size=(28 + 2) * FLAGS.ext_time,
                             axis=1)  # 32 x 30 x 10
# Example no. 2
# (listing score: 0)
    def __init__(self,
                 n_in,
                 n_rec,
                 tau=20.,
                 thr=0.03,
                 dt=1.,
                 n_refractory=0,
                 dtype=tf.float32,
                 n_delay=1,
                 rewiring_connectivity=-1,
                 in_neuron_sign=None,
                 rec_neuron_sign=None,
                 dampening_factor=0.3,
                 injected_noise_current=0.,
                 V0=1.):
        """
        Tensorflow cell object that simulates a LIF (leaky integrate-and-fire) neuron
        with an approximation of the spike derivatives.

        Builds the (non-trainable) time-constant and threshold variables and the
        input / recurrent weight matrices, including a synaptic-delay dimension.

        :param n_in: number of input neurons
        :param n_rec: number of recurrent neurons
        :param tau: membrane time constant; a scalar is broadcast to all n_rec neurons
        :param thr: threshold voltage; a scalar is broadcast to all n_rec neurons
        :param dt: time step of the simulation
        :param n_refractory: number of refractory time steps
        :param dtype: data type of the cell tensors
        :param n_delay: number of synaptic delays; each synapse gets a fixed random
            integer delay index drawn uniformly from [0, n_delay). The intended range
            is 1 to n_delay time steps (presumably index d means d + 1 steps —
            confirm in weight_matrix_with_delay_dimension)
        :param rewiring_connectivity: if 0 < value < 1, input and recurrent weights
            are created by weight_sampler with this value; otherwise dense weights
            N(0, 1) / sqrt(fan_in) are created
        :param in_neuron_sign: neuron signs passed to weight_sampler for the input
            weights (rewiring only)
        :param rec_neuron_sign: neuron signs passed to weight_sampler for the
            recurrent weights (rewiring only)
        :param dampening_factor: stored as self.dampening_factor; presumably scales
            the spike pseudo-derivative elsewhere in the class — confirm
        :param injected_noise_current: stored as self.injected_noise_current; used
            elsewhere in the class
        :param V0: multiplicative scale applied to the input and recurrent weight values
        :raises NotImplementedError: if in_neuron_sign or rec_neuron_sign is given
            while rewiring is disabled (rewiring_connectivity not in (0, 1))
        """

        # Broadcast scalar tau / thr to one value per recurrent neuron.
        if np.isscalar(tau): tau = tf.ones(n_rec, dtype=dtype) * np.mean(tau)
        if np.isscalar(thr): thr = tf.ones(n_rec, dtype=dtype) * np.mean(thr)
        tau = tf.cast(tau, dtype=dtype)
        dt = tf.cast(dt, dtype=dtype)

        self.dampening_factor = dampening_factor

        # Parameters
        self.n_delay = n_delay
        self.n_refractory = n_refractory

        self.dt = dt
        self.n_in = n_in
        self.n_rec = n_rec
        self.data_type = dtype

        self._num_units = self.n_rec

        self.tau = tf.Variable(tau, dtype=dtype, name="Tau", trainable=False)
        # Per-step membrane decay exp(-dt / tau).
        # NOTE(review): computed from the local `tau` tensor rather than self.tau,
        # so assigning to self.tau later would not change _decay.
        self._decay = tf.exp(-dt / tau)
        self.thr = tf.Variable(thr,
                               dtype=dtype,
                               name="Threshold",
                               trainable=False)

        self.V0 = V0
        self.injected_noise_current = injected_noise_current

        self.rewiring_connectivity = rewiring_connectivity
        self.in_neuron_sign = in_neuron_sign
        self.rec_neuron_sign = rec_neuron_sign

        with tf.variable_scope('InputWeights'):

            # Input weights
            if 0 < rewiring_connectivity < 1:
                self.w_in_val, self.w_in_sign, self.w_in_var, _ = weight_sampler(
                    n_in,
                    n_rec,
                    rewiring_connectivity,
                    neuron_sign=in_neuron_sign)
            else:
                # Dense Gaussian init scaled by 1/sqrt(n_in).
                self.w_in_var = tf.Variable(rd.randn(n_in, n_rec) /
                                            np.sqrt(n_in),
                                            dtype=dtype,
                                            name="InputWeight")
                self.w_in_val = self.w_in_var

            self.w_in_val = self.V0 * self.w_in_val
            # Fixed random delay index per input synapse, drawn from [0, n_delay).
            self.w_in_delay = tf.Variable(rd.randint(
                self.n_delay, size=n_in * n_rec).reshape(n_in, n_rec),
                                          dtype=tf.int64,
                                          name="InDelays",
                                          trainable=False)
            self.W_in = weight_matrix_with_delay_dimension(
                self.w_in_val, self.w_in_delay, self.n_delay)

        with tf.variable_scope('RecWeights'):
            if 0 < rewiring_connectivity < 1:
                self.w_rec_val, self.w_rec_sign, self.w_rec_var, _ = weight_sampler(
                    n_rec,
                    n_rec,
                    rewiring_connectivity,
                    neuron_sign=rec_neuron_sign)
            else:
                # Note: in this branch the dense input weights have already been
                # created above before this check raises.
                if rec_neuron_sign is not None or in_neuron_sign is not None:
                    raise NotImplementedError(
                        'Neuron sign requested but this is only implemented with rewiring'
                    )
                # NOTE(review): uses bare `Variable` whereas the input weights use
                # tf.Variable — confirm `Variable` is imported and intended here.
                self.w_rec_var = Variable(rd.randn(n_rec, n_rec) /
                                          np.sqrt(n_rec),
                                          dtype=dtype,
                                          name='RecurrentWeight')
                self.w_rec_val = self.w_rec_var

            # Boolean identity mask: True on the diagonal (self-connections).
            recurrent_disconnect_mask = np.diag(np.ones(n_rec, dtype=bool))

            self.w_rec_val = self.w_rec_val * self.V0
            self.w_rec_val = tf.where(recurrent_disconnect_mask,
                                      tf.zeros_like(self.w_rec_val),
                                      self.w_rec_val)  # Disconnect autapses (zero the diagonal)
            # Fixed random delay index per recurrent synapse, drawn from [0, n_delay).
            self.w_rec_delay = tf.Variable(rd.randint(
                self.n_delay, size=n_rec * n_rec).reshape(n_rec, n_rec),
                                           dtype=tf.int64,
                                           name="RecDelays",
                                           trainable=False)
            self.W_rec = weight_matrix_with_delay_dimension(
                self.w_rec_val, self.w_rec_delay, self.n_delay)