def test_network_step_no_mix_forward(self):
    """Synapse update must read backward tensors when forward mixing is off.

    Builds a network step with
    `use_forward_activations_for_synapse_update=False`, then looks up the
    `hebbian_pre` / `hebbian_post` ops under `step/backward/synapse_update`
    in the default graph and checks that the name of each op's first input
    contains 'backward' and does not contain 'forward'.
    """
    network_spec = blur.NetworkSpec(
        use_forward_activations_for_synapse_update=False)
    backprop_genome = genome_util.create_backprop_genome(num_species=1)
    genome = genome_util.convert_genome_to_tf_variables(backprop_genome)

    def constant_two_initializer(params):
        # Every synapse starts at 2.0, with the shape taken from `params`.
        return 2 * tf.ones(params.shape, dtype=tf.float32)

    dataset = random_dataset()
    state = blur_meta.init_first_state(
        genome,
        synapse_initializer=constant_two_initializer,
        data=dataset,
        hidden_layers=[256, 128])

    # Only the first element of the split is used; the second is discarded.
    support_iterator_fn = dataset.make_one_shot_iterator().get_next
    data_support_fn, _ = blur_meta.episode_data_fn_split(support_iterator_fn)

    # A second, independent one-shot iterator feeds the fourth argument.
    second_iterator_fn = dataset.make_one_shot_iterator().get_next
    blur.network_step(
        state,
        genome,
        data_support_fn,
        second_iterator_fn,
        network_spec=network_spec,
        env=blur_env.tf_env)

    graph = tf.get_default_graph()
    update_scope = 'step/backward/synapse_update/'
    pre_input = graph.get_operation_by_name(
        update_scope + 'hebbian_pre').inputs[0]
    post_input = graph.get_operation_by_name(
        update_scope + 'hebbian_post').inputs[0]
    checked_inputs = (pre_input, post_input)

    for tensor in checked_inputs:
        self.assertIn('backward', tensor.name)
    for tensor in checked_inputs:
        self.assertNotIn('forward', tensor.name)
def test_verify_gradient_match_jp(self):
    """Compares BLUR updates under `jp_env` with TensorFlow gradients.

    Computes the synaptic and hebbian updates for a random dense layer
    using `blur_env.jp_env`, computes reference values with `tf_gradients`,
    and hands everything to `verify_equal` for comparison.
    """
    tf.reset_default_graph()
    tf.disable_v2_behavior()

    num_in, num_out = 10, 15
    inp, out, ow = random_dense(num_in, num_out)

    env = blur_env.jp_env
    # The genome returned by get_blur_state is not used here; a fresh
    # backprop genome is created below instead.
    pre, post, synapse, _, network_spec = get_blur_state(env, inp, out, ow)
    genome = genome_util.create_backprop_genome(num_species=1)

    synaptic_updates = blur.get_synaptic_update(
        pre,
        post,
        synapse=synapse,
        input_transform_gn=genome.neuron.transform,
        update_type=synapse_util.UpdateType.BOTH,
        env=env)
    update1, update2 = synaptic_updates

    hebbian_update = blur.get_hebbian_update(
        pre,
        post,
        genome.synapse.transform,
        global_spec=network_spec,
        env=env)

    grad_weights, grad_image, y = tf_gradients(inp, out, ow)

    np.set_printoptions(precision=4, linewidth=200)
    with tf.Session():
        # .eval() relies on the session opened above being the default one.
        grad_weights_value = grad_weights.eval()
        grad_image_value = grad_image.eval()
        y_value = y.eval()
        verify_equal(
            update1,
            update2,
            hebbian_update,
            grad_weights_value,
            grad_image_value,
            y_value,
            num_in,
            num_out)
def get_blur_state(env, inp, out, ow):
    """Assembles the inputs needed to compute BLUR updates for one layer.

    Args:
      env: Environment object forwarded to the `synapse_util` helpers.
      inp: Input array; placed in the first part of the last axis of `pre`,
        followed by an equally sized block of zeros.
      out: Output array; placed in the last part of the last axis of `post`,
        preceded by an equally sized block of zeros.
      ow: Weight array; cast to the BLUR float type and given a trailing
        axis of size one before the synapse is built.

    Returns:
      A tuple `(pre, post, synapse, genome, network_spec)`.
    """
    float_type = blur_env.NP_FLOATING_TYPE

    inp_padding = np.zeros_like(inp)
    out_padding = np.zeros_like(out)
    pre = np.concatenate([inp, inp_padding], axis=-1).astype(float_type)
    post = np.concatenate([out_padding, out], axis=-1).astype(float_type)

    # Append a trailing singleton axis to the cast weights.
    weights = ow.astype(float_type)[..., None]
    transposed_weights = synapse_util.transpose_synapse(weights, env)
    combined_synapse = synapse_util.combine_in_out_synapses(
        weights, transposed_weights, env=env)
    synapse = synapse_util.sync_states_synapse(
        combined_synapse, env, num_states=2)

    backprop_genome = genome_util.create_backprop_genome(num_species=1)
    genome = genome_util.convert_genome_to_tf_variables(backprop_genome)

    network_spec = blur.NetworkSpec()
    network_spec.symmetric_synapses = True
    network_spec.batch_average = False
    network_spec.backward_update = 'multiplicative_second_state'

    return pre, post, synapse, genome, network_spec