# Example #1
# 0
def get_hebbian_update(pre, post, transform, global_spec, env):
    """Performs hebbian update of the synapse weight matrix.

    Δ w = [pre, [1], post] @ transform.pre * transform.post @ [pre, [0], post]

    Args:
      pre: [population] x batch_size x in_channels x num_states
      post: [population] x batch_size x out_channels x num_states
      transform: genome_util.HebbianTransform
      global_spec: Specification of the network.
      env: Environment

    Returns:
      Update.  [in_channels + 1 + in_channels, out_channels + 1 + out_channels, k]
    """
    # Left operand of the update: [pre, 1, post] — a row of ONES appended
    # after `pre` (the bias-like row in the docstring formula).
    inp = env.concat([env.concat_row(pre, 1), post], axis=-2)
    # Right operand: [pre, 0, post] — a row of ZEROS, per the docstring.
    # Bug fix: the original line was a copy of the one above and built `out`
    # with concat_row(pre, 1), making the two operands identical.
    out = env.concat([env.concat_row(pre, 0), post], axis=-2)
    # inp: [b x (in+out) x k],
    # transforms: [k x k]
    hebbian_update = env.einsum(FC_SYNAPSE_UPDATE, inp, transform.pre,
                                transform.post, out)
    if global_spec.symmetric_in_out_synapses:
        hebbian_update = synapse_util.sync_in_and_out_synapse(
            hebbian_update, pre.shape[-2], env)
    if global_spec.symmetric_states_synapses:
        hebbian_update = synapse_util.sync_states_synapse(hebbian_update, env)
    return hebbian_update
def get_blur_state(env, inp, out, ow):
  """Builds (pre, post, synapse, genome, network_spec) from raw in/out/weights.

  `pre` holds `inp` in its first state slot, `post` holds `out` in its second;
  the unused slot of each is zero-filled. `ow` becomes a combined, state-synced
  synapse; the genome is a single-species backprop genome.
  """
  dtype = blur_env.NP_FLOATING_TYPE
  zeros_in = np.zeros_like(inp)
  zeros_out = np.zeros_like(out)
  pre = np.concatenate([inp, zeros_in], axis=-1).astype(dtype)
  post = np.concatenate([zeros_out, out], axis=-1).astype(dtype)

  # Add a trailing singleton axis to the weights before combining.
  weights = ow.astype(dtype)[Ellipsis, None]
  combined = synapse_util.combine_in_out_synapses(
      weights, synapse_util.transpose_synapse(weights, env), env=env)
  synapse = synapse_util.sync_states_synapse(combined, env, num_states=2)

  genome = genome_util.convert_genome_to_tf_variables(
      genome_util.create_backprop_genome(num_species=1))

  spec = blur.NetworkSpec()
  spec.symmetric_synapses = True
  spec.batch_average = False
  spec.backward_update = 'multiplicative_second_state'

  return pre, post, synapse, genome, spec
# Example #3
# 0
def init_first_state(genome,
                     data,
                     hidden_layers,
                     synapse_initializer,
                     create_synapses_fn=synapse_util.create_synapses,
                     network_spec=None,
                     env=blur_env.tf_env):
    """Initialize the very first state of the graph.

    Args:
      genome: A genome, or a callable mapping a species index to a genome;
        `synapse.transform.pre.shape[-1]` supplies the number of per-neuron
        states.
      data: Dataset whose `element_spec['support']` is an (input, output)
        pair of specs; their trailing dims give layer widths / batch dims.
      hidden_layers: Sizes of the hidden layers.
      synapse_initializer: Initializer forwarded to `create_synapses_fn`.
      create_synapses_fn: Factory building synapses from the layer list;
        `None` falls back to `synapse_util.create_synapses`.
      network_spec: Optional `blur.NetworkSpec`; a fresh spec is built per
        call when omitted.
      env: Environment used when symmetrizing synapses.

    Returns:
      A `blur.NetworkState` with zeroed layers and ground truth, freshly
      created (optionally symmetrized) synapses, and per-layer updatable
      flags (input layer is not updatable).
    """
    if create_synapses_fn is None:
        create_synapses_fn = synapse_util.create_synapses
    # Bug fix: the default was `network_spec=blur.NetworkSpec()`, a mutable
    # object created once at `def` time and shared by every call that omits
    # the argument. Build a fresh spec per call instead.
    if network_spec is None:
        network_spec = blur.NetworkSpec()

    if callable(genome):
        num_neuron_states = genome(0).synapse.transform.pre.shape[-1]
    else:
        num_neuron_states = genome.synapse.transform.pre.shape[-1]
    output_shape = data.element_spec['support'][1].shape
    num_outputs = output_shape[-1]
    num_inputs = data.element_spec['support'][0].shape[-1]
    batch_dims = output_shape[:-1]

    # One zeroed state tensor per layer: input, hidden..., output.
    layer_sizes = (num_inputs, *hidden_layers, num_outputs)
    layers = [
        tf.zeros((*batch_dims, h, num_neuron_states)) for h in layer_sizes
    ]
    synapses = create_synapses_fn(layers, synapse_initializer)
    if network_spec.symmetric_in_out_synapses:
        for i, synapse in enumerate(synapses):
            synapses[i] = synapse_util.sync_in_and_out_synapse(
                synapse, layers[i].shape[-2], env)
    if network_spec.symmetric_states_synapses:
        for i, synapse in enumerate(synapses):
            synapses[i] = synapse_util.sync_states_synapse(synapse, env)
    # Every hidden layer plus the output layer is updatable; the input is not.
    num_updatable_units = len(hidden_layers) + 1
    ground_truth = tf.zeros((*batch_dims, num_outputs))

    return blur.NetworkState(layers=layers,
                             synapses=synapses,
                             ground_truth=ground_truth,
                             updatable=[False] + [True] * num_updatable_units)