def test_network_step_no_mix_forward(self):
    """Checks that with forward-activation mixing disabled, the backward
    synapse update reads only backward-pass tensors.

    Builds a one-species backprop genome, runs a single `blur.network_step`
    on random data, then inspects the default TF graph: the inputs feeding
    the `hebbian_pre`/`hebbian_post` ops under `step/backward/synapse_update`
    must come from 'backward'-scoped tensors and never from 'forward' ones.

    NOTE(review): relies on TF1 global-graph semantics
    (`tf.get_default_graph`) and on exact op-name scopes produced by
    `blur.network_step` — renaming scopes in `blur` will break this test.
    """
    # Disable use of forward activations when updating synapses; this is
    # the behavior under test.
    spec = blur.NetworkSpec(use_forward_activations_for_synapse_update=False)
    genome = genome_util.convert_genome_to_tf_variables(
        genome_util.create_backprop_genome(num_species=1))
    # Constant initializer: every synapse weight starts at 2.0.
    initializer = lambda params: 2 * tf.ones(params.shape, dtype=tf.float32)
    data = random_dataset()
    state = blur_meta.init_first_state(
        genome,
        synapse_initializer=initializer,
        data=data,
        hidden_layers=[256, 128])
    input_fn = data.make_one_shot_iterator().get_next
    data_support_fn, _ = blur_meta.episode_data_fn_split(input_fn)
    # Run one step purely for its graph-building side effects; the
    # assertions below inspect the resulting op names, not returned values.
    blur.network_step(
        state,
        genome,
        data_support_fn,
        data.make_one_shot_iterator().get_next,
        network_spec=spec,
        env=blur_env.tf_env)
    g = tf.get_default_graph()

    # First input of each Hebbian-update op is the activation tensor it
    # consumes; its name encodes which pass (forward/backward) produced it.
    synapse_pre = g.get_operation_by_name(
        'step/backward/synapse_update/hebbian_pre').inputs[0]
    synapse_post = g.get_operation_by_name(
        'step/backward/synapse_update/hebbian_post').inputs[0]
    self.assertIn('backward', synapse_pre.name)
    self.assertIn('backward', synapse_post.name)
    self.assertNotIn('forward', synapse_pre.name)
    self.assertNotIn('forward', synapse_post.name)
def get_blur_state(env, inp, out, ow):
  """Builds the pieces needed to drive a single blur update by hand.

  Args:
    env: blur environment (numpy or TF flavor) used by synapse helpers.
    inp: input activations array.
    out: output activations array.
    ow: raw 2-D weight matrix connecting `inp` to `out`.

  Returns:
    Tuple `(pre, post, synapse, genome, network_spec)` where `pre`/`post`
    are two-state activation tensors (input in the first state slot,
    output in the second), `synapse` is the combined symmetric in/out
    synapse, `genome` a one-species backprop genome as TF variables, and
    `network_spec` a spec configured for symmetric multiplicative updates.
  """
  dtype = blur_env.NP_FLOATING_TYPE
  # Pre-synaptic state: input occupies the first state channel, the
  # second channel starts zeroed.
  pre = np.concatenate([inp, np.zeros_like(inp)], axis=-1).astype(dtype)
  # Post-synaptic state: output occupies the second state channel.
  post = np.concatenate([np.zeros_like(out), out], axis=-1).astype(dtype)

  # Add a trailing singleton state axis, then build a symmetric synapse
  # from the weights and their transpose, synced across both states.
  weights = ow.astype(dtype)[Ellipsis, None]
  synapse = synapse_util.combine_in_out_synapses(
      weights, synapse_util.transpose_synapse(weights, env), env=env)
  synapse = synapse_util.sync_states_synapse(synapse, env, num_states=2)

  genome = genome_util.convert_genome_to_tf_variables(
      genome_util.create_backprop_genome(num_species=1))

  spec = blur.NetworkSpec()
  spec.symmetric_synapses = True
  spec.batch_average = False
  spec.backward_update = 'multiplicative_second_state'

  return pre, post, synapse, genome, spec
# Example #3
def init_first_state(genome,
                     data,
                     hidden_layers,
                     synapse_initializer,
                     create_synapses_fn=synapse_util.create_synapses,
                     network_spec=blur.NetworkSpec(),
                     env=blur_env.tf_env):
    """Initialize the very first state of the graph.

    Args:
      genome: a genome object, or a callable mapping a species index to a
        genome; only `synapse.transform.pre.shape[-1]` is read here to
        determine the number of per-neuron states.
      data: a tf.data.Dataset whose `element_spec` is an (input, output)
        pair; shapes provide input/output widths and batch dimensions.
      hidden_layers: iterable of hidden-layer widths.
      synapse_initializer: initializer passed to `create_synapses_fn`.
      create_synapses_fn: factory building synapses from layers; `None`
        selects the same default factory, `synapse_util.create_synapses`.
      network_spec: blur.NetworkSpec controlling symmetry options.
        NOTE(review): the default is a single shared instance evaluated at
        import time — callers mutating it affect later calls; confirm this
        is intended before changing.
      env: execution environment (defaults to the TF environment).

    Returns:
      A `blur.NetworkState` with zeroed layers, initialized synapses,
      zeroed ground truth, and the input layer marked non-updatable.
    """
    if create_synapses_fn is None:
        # Keep the None-fallback consistent with the keyword default.
        # (Previously this assigned blur.create_synapses, so passing None
        # silently selected a different factory than omitting the arg.)
        create_synapses_fn = synapse_util.create_synapses

    # Number of per-neuron state channels comes from the genome's synapse
    # transform shape; support both callable and plain genomes.
    if callable(genome):
        num_neuron_states = genome(0).synapse.transform.pre.shape[-1]
    else:
        num_neuron_states = genome.synapse.transform.pre.shape[-1]
    output_shape = data.element_spec[1].shape
    num_outputs = output_shape[-1]
    num_inputs = data.element_spec[0].shape[-1]
    batch_dims = output_shape[:-1]

    # One zeroed activation tensor per layer: input, hiddens, output.
    layer_sizes = (num_inputs, *hidden_layers, num_outputs)
    layers = [
        tf.zeros((*batch_dims, h, num_neuron_states)) for h in layer_sizes
    ]
    synapses = create_synapses_fn(layers, synapse_initializer)
    if network_spec.symmetric_in_out_synapses:
        for i in range(len(synapses)):
            synapses[i] = blur.sync_in_and_out_synapse(synapses[i],
                                                       layers[i].shape[-2],
                                                       env)
    if network_spec.symmetric_states_synapses:
        for i in range(len(synapses)):
            synapses[i] = blur.sync_states_synapse(synapses[i], env)
    # Every layer except the input layer is updatable.
    num_updatable_units = len(hidden_layers) + 1
    ground_truth = tf.zeros((*batch_dims, num_outputs))

    return blur.NetworkState(layers=layers,
                             synapses=synapses,
                             ground_truth=ground_truth,
                             updatable=[False] + [True] * num_updatable_units)