def test_network_step_no_mix_forward(self):
  """Checks backward Hebbian updates take no forward-scoped inputs.

  Builds one `blur.network_step` with
  `use_forward_activations_for_synapse_update=False`. It then asserts that the
  first input of the `hebbian_pre` and `hebbian_post` ops under
  `step/backward/synapse_update` is named within a `backward` scope and does
  not mention `forward`.
  """
  no_mix_spec = blur.NetworkSpec(
      use_forward_activations_for_synapse_update=False)
  backprop_genome = genome_util.convert_genome_to_tf_variables(
      genome_util.create_backprop_genome(num_species=1))

  def constant_two(params):
    # Every synapse starts at 2.0, matching the shape it is asked to fill.
    return 2 * tf.ones(params.shape, dtype=tf.float32)

  dataset = random_dataset()
  first_state = blur_meta.init_first_state(
      backprop_genome,
      synapse_initializer=constant_two,
      data=dataset,
      hidden_layers=[256, 128])
  support_fn, _ = blur_meta.episode_data_fn_split(
      dataset.make_one_shot_iterator().get_next)
  blur.network_step(
      first_state,
      backprop_genome,
      support_fn,
      dataset.make_one_shot_iterator().get_next,
      network_spec=no_mix_spec,
      env=blur_env.tf_env)

  graph = tf.get_default_graph()
  # Both ops are looked up before anything is asserted.
  pre_input = graph.get_operation_by_name(
      'step/backward/synapse_update/hebbian_pre').inputs[0]
  post_input = graph.get_operation_by_name(
      'step/backward/synapse_update/hebbian_post').inputs[0]

  checked = (pre_input, post_input)
  for tensor in checked:
    self.assertIn('backward', tensor.name)
  for tensor in checked:
    self.assertNotIn('forward', tensor.name)
def get_blur_state(env, inp, out, ow):
  """Builds pre/post activations, a synapse, a genome and a spec for tests.

  Args:
    env: Environment forwarded to the `synapse_util` helpers.
    inp: Array that is followed by an equally shaped block of zeros along the
      last axis to form `pre`.
    out: Array that is preceded by an equally shaped block of zeros along the
      last axis to form `post`.
    ow: Array cast to `blur_env.NP_FLOATING_TYPE` and given a trailing
      singleton axis. It is combined with its transpose into a synapse, which
      is then synced with `num_states=2`.

  Returns:
    Tuple `(pre, post, synapse, genome, network_spec)`. `genome` is a
    single-species backprop genome converted to TF variables. `network_spec`
    is a `blur.NetworkSpec` with the three overrides set below.
  """
  float_type = blur_env.NP_FLOATING_TYPE
  pre = np.concatenate((inp, np.zeros_like(inp)), axis=-1).astype(float_type)
  post = np.concatenate((np.zeros_like(out), out), axis=-1).astype(float_type)

  weights = np.expand_dims(ow.astype(float_type), axis=-1)
  synapse = synapse_util.sync_states_synapse(
      synapse_util.combine_in_out_synapses(
          weights, synapse_util.transpose_synapse(weights, env), env=env),
      env,
      num_states=2)

  genome = genome_util.convert_genome_to_tf_variables(
      genome_util.create_backprop_genome(num_species=1))

  spec = blur.NetworkSpec()
  # NOTE(review): `init_first_state` reads `symmetric_in_out_synapses` and
  # `symmetric_states_synapses`, not `symmetric_synapses`. Confirm this
  # attribute is consumed elsewhere.
  spec.symmetric_synapses = True
  spec.batch_average = False
  spec.backward_update = 'multiplicative_second_state'
  return pre, post, synapse, genome, spec
def init_first_state(genome,
                     data,
                     hidden_layers,
                     synapse_initializer,
                     create_synapses_fn=synapse_util.create_synapses,
                     network_spec=None,
                     env=blur_env.tf_env):
  """Initialize the very first state of the graph.

  Args:
    genome: A genome object, or a callable returning one. When callable, only
      `genome(0)` is inspected. The number of neuron states is taken from the
      last dimension of `<genome>.synapse.transform.pre.shape`.
    data: Object exposing `element_spec`. Element 0 supplies the input width
      (last dim of its shape). Element 1 supplies the output width (last dim)
      and the batch dims (all preceding dims).
    hidden_layers: Sequence of hidden layer sizes, placed between the input
      and output layer sizes.
    synapse_initializer: Passed through to `create_synapses_fn`.
    create_synapses_fn: Called as `create_synapses_fn(layers,
      synapse_initializer)`. If None, `blur.create_synapses` is used.
    network_spec: `blur.NetworkSpec`. Only `symmetric_in_out_synapses` and
      `symmetric_states_synapses` are read. Defaults to a fresh
      `blur.NetworkSpec()` per call.
    env: Environment forwarded to the synapse sync helpers.

  Returns:
    `blur.NetworkState` containing:
      * `layers`: zero tensors of shape
        `(*batch_dims, layer_size, num_neuron_states)`.
      * `synapses`: the created (optionally synced) synapses.
      * `ground_truth`: zeros of shape `(*batch_dims, num_outputs)`.
      * `updatable`: False for the input layer, True for every later layer.
  """
  if network_spec is None:
    # Built per call instead of as a definition-time default, so no single
    # NetworkSpec instance is shared by every call.
    network_spec = blur.NetworkSpec()
  if create_synapses_fn is None:
    # NOTE(review): this fallback differs from the parameter default
    # (`synapse_util.create_synapses`); confirm that is intended.
    create_synapses_fn = blur.create_synapses

  reference_genome = genome(0) if callable(genome) else genome
  num_neuron_states = reference_genome.synapse.transform.pre.shape[-1]

  output_shape = data.element_spec[1].shape
  num_outputs = output_shape[-1]
  num_inputs = data.element_spec[0].shape[-1]
  batch_dims = output_shape[:-1]

  layer_sizes = (num_inputs, *hidden_layers, num_outputs)
  layers = [
      tf.zeros((*batch_dims, size, num_neuron_states)) for size in layer_sizes
  ]

  synapses = create_synapses_fn(layers, synapse_initializer)
  if network_spec.symmetric_in_out_synapses:
    # Synapse i is synced using the layer-size dimension of layer i.
    synapses = [
        blur.sync_in_and_out_synapse(synapse, layers[i].shape[-2], env)
        for i, synapse in enumerate(synapses)
    ]
  if network_spec.symmetric_states_synapses:
    synapses = [blur.sync_states_synapse(synapse, env) for synapse in synapses]

  # The input layer is fixed; all hidden layers plus the output layer update.
  num_updatable_units = len(hidden_layers) + 1
  ground_truth = tf.zeros((*batch_dims, num_outputs))
  return blur.NetworkState(
      layers=layers,
      synapses=synapses,
      ground_truth=ground_truth,
      updatable=[False] + [True] * num_updatable_units)