from tensorflow.python.client.session import Session
from tensorflow.python.framework import ops
from tensorflow.python.ops.variables import Variable


def copy_variable_to_graph(org_instance, to_graph, scope=''):
    """Given a `Variable` instance from one `Graph`, initializes and returns
  a copy of it from another `Graph`, under the specified scope
  (default `""`).

  Args:
    org_instance: A `Variable` from some `Graph`.
    to_graph: The `Graph` to copy the `Variable` to.
    scope: A scope for the new `Variable` (default `""`).

  Returns:
    The copied `Variable` from `to_graph`.

  Raises:
    TypeError: If `org_instance` is not a `Variable`.
  """

    if not isinstance(org_instance, Variable):
        raise TypeError(str(org_instance) + ' is not a Variable')

    # The name of the new variable (strip the ':0' output suffix).
    if scope != '':
        new_name = (scope + '/' +
                    org_instance.name[:org_instance.name.index(':')])
    else:
        new_name = org_instance.name[:org_instance.name.index(':')]

    # Get the collections that the new instance needs to be added to.
    # The new collections will also be a part of the given scope,
    # except the special ones required for variable initialization and
    # training.
    collections = []
    for name, collection in org_instance.graph._collections.items():
        if org_instance in collection:
            if (name == ops.GraphKeys.GLOBAL_VARIABLES
                    or name == ops.GraphKeys.TRAINABLE_VARIABLES
                    or scope == ''):
                collections.append(name)
            else:
                collections.append(scope + '/' + name)

    # See if it is trainable.
    trainable = (org_instance in org_instance.graph.get_collection(
        ops.GraphKeys.TRAINABLE_VARIABLES))
    # Get the initial value by evaluating it in a temporary session,
    # which is closed afterwards to free its resources.
    with org_instance.graph.as_default():
        with Session() as temp_session:
            init_value = temp_session.run(org_instance.initialized_value())

    # Initialize the new variable in the target graph.
    with to_graph.as_default():
        new_var = Variable(init_value,
                           trainable=trainable,
                           name=new_name,
                           collections=collections,
                           validate_shape=False)

    return new_var
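
A minimal usage sketch (hedged: the graph names, the variable `v`, and the scope are illustrative; it assumes the TF1 imports above):

import tensorflow as tf

graph_a = tf.Graph()
graph_b = tf.Graph()

with graph_a.as_default():
    v = tf.Variable([1.0, 2.0], name='v')  # source variable

# Copy `v` into graph_b under the scope 'copied'; yields 'copied/v:0'.
v_copy = copy_variable_to_graph(v, graph_b, scope='copied')

with graph_b.as_default():
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(v_copy))  # -> [1. 2.]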
Example #2
    def test_alias_tensors(self):
        a = constant(1)
        v = Variable(2)
        s = 'a'
        l = [1, 2, 3]

        new_a, new_v, new_s, new_l = misc.alias_tensors(a, v, s, l)

        self.assertFalse(new_a is a)
        self.assertTrue(new_v is v)
        self.assertTrue(new_s is s)
        self.assertTrue(new_l is l)
        with self.cached_session() as sess:
            self.assertEqual(1, sess.run(new_a))
def add_variable_to_graph(output_graph, var_name, init_value,
                          trainable=True, collections=None, scope=''):
    # Build the scoped name for the new variable.
    if scope != '':
        new_name = scope + '/' + var_name
    else:
        new_name = var_name

    with output_graph.as_default():
        new_var = Variable(
            init_value,
            trainable=trainable,
            name=new_name,
            collections=collections,
            validate_shape=False)
        # validate_shape=False skips the static shape check, so set the
        # shape explicitly from the initial value.
        new_var.set_shape(init_value.shape)
    return new_var
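
A minimal usage sketch (hedged: the graph, the variable name, and the NumPy initial value are illustrative):

import numpy as np
import tensorflow as tf

g = tf.Graph()
init = np.zeros((3, 4), dtype=np.float32)
w = add_variable_to_graph(g, 'w', init, trainable=True, scope='layer1')
# w.name == 'layer1/w:0' and w's static shape is (3, 4).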
def f():
    inputs = Variable(array_ops.zeros([32, 100], dtypes.float32))
    del inputs
Example #5
    def __init__(self,
                 n_in,
                 n_rec,
                 tau=20.,
                 thr=0.03,
                 dt=1.,
                 n_refractory=0,
                 dtype=tf.float32,
                 n_delay=1,
                 rewiring_connectivity=-1,
                 in_neuron_sign=None,
                 rec_neuron_sign=None,
                 dampening_factor=0.3,
                 injected_noise_current=0.,
                 V0=1.):
        """
        Tensorflow cell object that simulates a LIF neuron with an approximation of the spike derivatives.

        :param n_in: number of input neurons
        :param n_rec: number of recurrent neurons
        :param tau: membrane time constant
        :param thr: threshold voltage
        :param dt: time step of the simulation
        :param n_refractory: number of refractory time steps
        :param dtype: data type of the cell tensors
        :param n_delay: number of synaptic delay, the delay range goes from 1 to n_delay time steps
        :param reset: method of resetting membrane potential after spike thr-> by fixed threshold amount, zero-> to zero
        """

        if np.isscalar(tau):
            tau = tf.ones(n_rec, dtype=dtype) * np.mean(tau)
        if np.isscalar(thr):
            thr = tf.ones(n_rec, dtype=dtype) * np.mean(thr)
        tau = tf.cast(tau, dtype=dtype)
        dt = tf.cast(dt, dtype=dtype)

        self.dampening_factor = dampening_factor

        # Parameters
        self.n_delay = n_delay
        self.n_refractory = n_refractory

        self.dt = dt
        self.n_in = n_in
        self.n_rec = n_rec
        self.data_type = dtype

        self._num_units = self.n_rec

        self.tau = tf.Variable(tau, dtype=dtype, name="Tau", trainable=False)
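        # Per-step voltage decay factor: exact integration of the membrane
        # leak dV/dt = -V / tau over one simulation step dt.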
        self._decay = tf.exp(-dt / tau)
        self.thr = tf.Variable(thr,
                               dtype=dtype,
                               name="Threshold",
                               trainable=False)

        self.V0 = V0
        self.injected_noise_current = injected_noise_current

        self.rewiring_connectivity = rewiring_connectivity
        self.in_neuron_sign = in_neuron_sign
        self.rec_neuron_sign = rec_neuron_sign

        with tf.variable_scope('InputWeights'):

            # Input weights
            if 0 < rewiring_connectivity < 1:
                self.w_in_val, self.w_in_sign, self.w_in_var, _ = weight_sampler(
                    n_in,
                    n_rec,
                    rewiring_connectivity,
                    neuron_sign=in_neuron_sign)
            else:
                self.w_in_var = tf.Variable(rd.randn(n_in, n_rec) /
                                            np.sqrt(n_in),
                                            dtype=dtype,
                                            name="InputWeight")
                self.w_in_val = self.w_in_var

            # Scale the input weights by the global voltage scale V0.
            self.w_in_val = self.V0 * self.w_in_val
            self.w_in_delay = tf.Variable(rd.randint(
                self.n_delay, size=n_in * n_rec).reshape(n_in, n_rec),
                                          dtype=tf.int64,
                                          name="InDelays",
                                          trainable=False)
            self.W_in = weight_matrix_with_delay_dimension(
                self.w_in_val, self.w_in_delay, self.n_delay)

        with tf.variable_scope('RecWeights'):
            if 0 < rewiring_connectivity < 1:
                self.w_rec_val, self.w_rec_sign, self.w_rec_var, _ = weight_sampler(
                    n_rec,
                    n_rec,
                    rewiring_connectivity,
                    neuron_sign=rec_neuron_sign)
            else:
                if rec_neuron_sign is not None or in_neuron_sign is not None:
                    raise NotImplementedError(
                        'Neuron sign requested but this is only implemented with rewiring'
                    )
                self.w_rec_var = tf.Variable(rd.randn(n_rec, n_rec) /
                                             np.sqrt(n_rec),
                                             dtype=dtype,
                                             name='RecurrentWeight')
                self.w_rec_val = self.w_rec_var

            # Boolean mask over the diagonal, i.e. over self-connections.
            recurrent_disconnect_mask = np.diag(np.ones(n_rec, dtype=bool))

            self.w_rec_val = self.w_rec_val * self.V0
            self.w_rec_val = tf.where(recurrent_disconnect_mask,
                                      tf.zeros_like(self.w_rec_val),
                                      self.w_rec_val)  # Disconnect autapses
            self.w_rec_delay = tf.Variable(rd.randint(
                self.n_delay, size=n_rec * n_rec).reshape(n_rec, n_rec),
                                           dtype=tf.int64,
                                           name="RecDelays",
                                           trainable=False)
            self.W_rec = weight_matrix_with_delay_dimension(
                self.w_rec_val, self.w_rec_delay, self.n_delay)
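
A hedged instantiation sketch; the enclosing class name `LIF` is an assumption inferred from the docstring, and `weight_sampler` / `weight_matrix_with_delay_dimension` must come from the same codebase:

# Hypothetical: `LIF` is the class whose __init__ is shown above.
cell = LIF(n_in=40, n_rec=100, tau=20., thr=0.03, dt=1.,
           n_refractory=3, n_delay=5)
# After construction, cell.W_in and cell.W_rec hold the delayed weight
# tensors built by weight_matrix_with_delay_dimension.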