Code example #1
File: blur.py  Project: lucifer2288/google-research
def get_hebbian_update(pre, post, transform, global_spec, env):
    """Performs hebbian update of the synapse weight matrix.

  Δ w = [pre, [1], post] @ transform.pre * transform.post @  [pre, [0],
  post]

  Args:
   pre: [population] x batch_size x in_channels x num_states
   post: [population] x batch_size x out_channels x num_states
   transform: genome_util.HebbianTransform
   global_spec: Specification of the network.
   env: Environment

  Returns:
      Update of shape
      [in_channels + 1 + out_channels, in_channels + 1 + out_channels, k].
    """
    inp = env.concat([env.concat_row(pre, 1), post], axis=-2)
    out = env.concat([env.concat_row(pre, 0), post], axis=-2)
    # inp, out: [b x (in+1+out) x k],
    # transforms: [k x k]
    hebbian_update = env.einsum(FC_SYNAPSE_UPDATE, inp, transform.pre,
                                transform.post, out)
    if global_spec.symmetric_in_out_synapses:
        hebbian_update = synapse_util.sync_in_and_out_synapse(
            hebbian_update, pre.shape[-2], env)
    if global_spec.symmetric_states_synapses:
        hebbian_update = synapse_util.sync_states_synapse(hebbian_update, env)
    return hebbian_update
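To make the shape bookkeeping concrete, here is a minimal NumPy sketch of the outer-product update the docstring describes. The k x k transforms and the einsum pattern below are illustrative stand-ins; the actual contraction is defined by the FC_SYNAPSE_UPDATE constant and the env abstraction, neither of which is shown in this snippet.

import numpy as np

batch, in_ch, out_ch, k = 4, 3, 2, 2
pre = np.random.randn(batch, in_ch, k)    # pre-synaptic activity
post = np.random.randn(batch, out_ch, k)  # post-synaptic activity
pre_t = np.random.randn(k, k)             # stand-in for transform.pre
post_t = np.random.randn(k, k)            # stand-in for transform.post

inp = np.concatenate([pre, np.ones((batch, 1, k)), post], axis=-2)   # [pre, [1], post]
out = np.concatenate([pre, np.zeros((batch, 1, k)), post], axis=-2)  # [pre, [0], post]

# Outer product of (inp @ pre_t) and (out @ post_t) per state channel,
# summed over the batch.
delta_w = np.einsum('bik,kl,ml,bjm->ijl', inp, pre_t, post_t, out)
print(delta_w.shape)  # (in_ch + 1 + out_ch, in_ch + 1 + out_ch, k)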
Code example #2
  def test_sync_in_out_synapses(self):
    num_in = 3
    num_out = 2
    num_states = 2
    env = blur_env.tf_env
    in_out_synapse = tf.random.normal(shape=(num_in + 1, num_out, num_states))
    out_in_synapse = tf.random.normal(shape=(num_out, num_in + 1, num_states))

    # Combine the forward and backward weights into a single synapse matrix,
    # then force the backward block to mirror the forward one.
    synapse = synapse_util.combine_in_out_synapses(in_out_synapse,
                                                   out_in_synapse, env)
    synapse_synced = synapse_util.sync_in_and_out_synapse(synapse, num_in, env)
    fwd_sync_submatrix = synapse_util.synapse_submatrix(
        synapse_synced,
        num_in,
        synapse_util.UpdateType.FORWARD,
        include_bias=True)

    bkw_sync_submatrix = synapse_util.synapse_submatrix(
        synapse_synced,
        num_in,
        synapse_util.UpdateType.BACKWARD,
        include_bias=True)

    # After syncing, both the forward submatrix and the transposed backward
    # submatrix should equal the original forward weights.
    with tf.Session() as s:
      bwd, fwd, inp = s.run([
          synapse_util.transpose_synapse(bkw_sync_submatrix, env),
          fwd_sync_submatrix, in_out_synapse
      ])
    self.assertAllEqual(fwd, inp)
    self.assertAllEqual(bwd, inp)
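The test above relies on combine_in_out_synapses and sync_in_and_out_synapse from synapse_util. Assuming, based only on the shapes in this test, that the combined matrix holds a forward (in→out) block and a backward (out→in) block and that syncing copies the transposed forward block over the backward one, the check reduces to the NumPy sketch below. The local helpers are stand-ins, not the synapse_util API.

import numpy as np

num_in, num_out, num_states = 3, 2, 2
w_fwd = np.random.randn(num_in + 1, num_out, num_states)  # forward weights (+ bias row)
w_bwd = np.random.randn(num_out, num_in + 1, num_states)  # backward weights

# Stand-in for combine_in_out_synapses: lay both blocks out in one matrix.
n = (num_in + 1) + num_out
synapse = np.zeros((n, n, num_states))
synapse[:num_in + 1, num_in + 1:] = w_fwd
synapse[num_in + 1:, :num_in + 1] = w_bwd

# Stand-in for sync_in_and_out_synapse: the backward block becomes the
# transpose of the forward block.
synapse[num_in + 1:, :num_in + 1] = synapse[:num_in + 1, num_in + 1:].transpose(1, 0, 2)

fwd_block = synapse[:num_in + 1, num_in + 1:]
bwd_block = synapse[num_in + 1:, :num_in + 1]
assert np.allclose(fwd_block, w_fwd)                     # forward weights untouched
assert np.allclose(bwd_block.transpose(1, 0, 2), w_fwd)  # backward mirrors forward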
Code example #3
def init_first_state(genome,
                     data,
                     hidden_layers,
                     synapse_initializer,
                     create_synapses_fn=synapse_util.create_synapses,
                     network_spec=blur.NetworkSpec(),
                     env=blur_env.tf_env):
    """Initialize the very first state of the graph."""
    if create_synapses_fn is None:
        create_synapses_fn = synapse_util.create_synapses

    # The number of per-neuron states is read off the genome's Hebbian transform.
    if callable(genome):
        num_neuron_states = genome(0).synapse.transform.pre.shape[-1]
    else:
        num_neuron_states = genome.synapse.transform.pre.shape[-1]
    # Input/output widths and batch dimensions come from the dataset spec.
    output_shape = data.element_spec['support'][1].shape
    num_outputs = output_shape[-1]
    num_inputs = data.element_spec['support'][0].shape[-1]
    batch_dims = output_shape[:-1]

    # One zero-initialized state tensor per layer: input, hidden layers, output.
    layer_sizes = (num_inputs, *hidden_layers, num_outputs)
    layers = [
        tf.zeros((*batch_dims, h, num_neuron_states)) for h in layer_sizes
    ]
    synapses = create_synapses_fn(layers, synapse_initializer)
    if network_spec.symmetric_in_out_synapses:
        for i in range(len(synapses)):
            synapses[i] = synapse_util.sync_in_and_out_synapse(
                synapses[i], layers[i].shape[-2], env)
    if network_spec.symmetric_states_synapses:
        for i in range(len(synapses)):
            synapses[i] = synapse_util.sync_states_synapse(synapses[i], env)
    # Every layer except the input layer is marked as updatable.
    num_updatable_units = len(hidden_layers) + 1
    ground_truth = tf.zeros((*batch_dims, num_outputs))

    return blur.NetworkState(layers=layers,
                             synapses=synapses,
                             ground_truth=ground_truth,
                             updatable=[False] + [True] * num_updatable_units)
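As a rough illustration of the bookkeeping in init_first_state, the toy numbers below (batch of 8, 28 inputs, two hidden layers of 16, 10 outputs, 2 neuron states) are invented; the snippet only mirrors how layer_sizes, the zero-initialized layer states, and the updatable flags line up, and does not call the blur API.

import tensorflow as tf

batch_dims = (8,)      # would normally come from data.element_spec
num_inputs, hidden_layers, num_outputs = 28, (16, 16), 10
num_neuron_states = 2  # would normally come from genome.synapse.transform.pre

layer_sizes = (num_inputs, *hidden_layers, num_outputs)
layers = [tf.zeros((*batch_dims, h, num_neuron_states)) for h in layer_sizes]

# One flag per layer: the input layer is frozen, every later layer is updated.
updatable = [False] + [True] * (len(hidden_layers) + 1)

print([tuple(l.shape) for l in layers])
# [(8, 28, 2), (8, 16, 2), (8, 16, 2), (8, 10, 2)]
print(updatable)  # [False, True, True, True]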