def test_network_step_no_mix_forward(self):
    """With forward-activation mixing disabled, the backward synapse update
    must read only backward-pass tensors (no 'forward'-scoped inputs)."""
    spec = blur.NetworkSpec(use_forward_activations_for_synapse_update=False)
    genome = genome_util.convert_genome_to_tf_variables(
        genome_util.create_backprop_genome(num_species=1))
    data = random_dataset()
    # Constant initializer: every synapse parameter starts at 2.0.
    synapse_init = lambda params: 2 * tf.ones(params.shape, dtype=tf.float32)
    state = blur_meta.init_first_state(
        genome,
        synapse_initializer=synapse_init,
        data=data,
        hidden_layers=[256, 128])
    support_fn, _ = blur_meta.episode_data_fn_split(
        data.make_one_shot_iterator().get_next)
    # Build one full network step; we only inspect the resulting graph.
    blur.network_step(
        state,
        genome,
        support_fn,
        data.make_one_shot_iterator().get_next,
        network_spec=spec,
        env=blur_env.tf_env)

    graph = tf.get_default_graph()
    pre_input = graph.get_operation_by_name(
        'step/backward/synapse_update/hebbian_pre').inputs[0]
    post_input = graph.get_operation_by_name(
        'step/backward/synapse_update/hebbian_post').inputs[0]
    # Both hebbian inputs come from the backward pass and never the forward one.
    self.assertIn('backward', pre_input.name)
    self.assertIn('backward', post_input.name)
    self.assertNotIn('forward', pre_input.name)
    self.assertNotIn('forward', post_input.name)
  def test_verify_gradient_match_jp(self):
    """Checks blur synaptic/hebbian updates in the jp env against tf gradients."""
    tf.reset_default_graph()
    tf.disable_v2_behavior()

    num_in, num_out = 10, 15
    inp, out, ow = random_dense(num_in, num_out)
    env = blur_env.jp_env
    # The tf-variable genome from get_blur_state is discarded; this test uses
    # the plain backprop genome directly.
    pre, post, synapse, _, network_spec = get_blur_state(env, inp, out, ow)
    genome = genome_util.create_backprop_genome(num_species=1)

    fwd_update, bwd_update = blur.get_synaptic_update(
        pre,
        post,
        synapse=synapse,
        input_transform_gn=genome.neuron.transform,
        update_type=synapse_util.UpdateType.BOTH,
        env=env)
    hebbian_update = blur.get_hebbian_update(
        pre,
        post,
        genome.synapse.transform,
        global_spec=network_spec,
        env=env)

    # Reference values computed with plain tf gradients.
    grad_weights, grad_image, y = tf_gradients(inp, out, ow)
    np.set_printoptions(precision=4, linewidth=200)
    with tf.Session():
      verify_equal(fwd_update, bwd_update, hebbian_update,
                   grad_weights.eval(), grad_image.eval(), y.eval(),
                   num_in, num_out)
def get_blur_state(env, inp, out, ow):
  """Builds neuron states, a symmetric synapse, a genome, and a network spec.

  Args:
    env: blur environment (e.g. blur_env.jp_env) used for synapse ops.
    inp: input-layer activations.
    out: output-layer activations.
    ow: dense weight matrix connecting input to output.

  Returns:
    Tuple (pre, post, synapse, genome, network_spec).
  """
  cast = lambda a: a.astype(blur_env.NP_FLOATING_TYPE)
  # Two sub-states are stacked along the last axis: the pre layer carries its
  # activations in the first half (zeros in the second), the post layer the
  # reverse.
  pre = cast(np.concatenate([inp, np.zeros_like(inp)], axis=-1))
  post = cast(np.concatenate([np.zeros_like(out), out], axis=-1))

  # Add a trailing singleton axis, then build a symmetric in/out synapse and
  # replicate it across both states.
  weights = cast(ow)[Ellipsis, None]
  synapse = synapse_util.combine_in_out_synapses(
      weights, synapse_util.transpose_synapse(weights, env), env=env)
  synapse = synapse_util.sync_states_synapse(synapse, env, num_states=2)

  genome = genome_util.convert_genome_to_tf_variables(
      genome_util.create_backprop_genome(num_species=1))

  network_spec = blur.NetworkSpec()
  network_spec.symmetric_synapses = True
  network_spec.batch_average = False
  network_spec.backward_update = 'multiplicative_second_state'

  return pre, post, synapse, genome, network_spec