Ejemplo n.º 1
0
  def test_sync_in_out_synapses(self):
    """Syncing copies the in->out weights into the out->in block.

    After `sync_in_and_out_synapse`, both the FORWARD submatrix and the
    transposed BACKWARD submatrix must equal the original in->out tensor.
    """
    n_in, n_out, n_states = 3, 2, 2
    tf_environment = blur_env.tf_env
    # The forward tensor has n_in + 1 rows so that it is directly comparable
    # with the include_bias=True submatrices extracted below.
    forward_init = tf.random.normal(shape=(n_in + 1, n_out, n_states))
    backward_init = tf.random.normal(shape=(n_out, n_in + 1, n_states))

    combined = synapse_util.combine_in_out_synapses(forward_init,
                                                    backward_init,
                                                    tf_environment)
    synced = synapse_util.sync_in_and_out_synapse(combined, n_in,
                                                  tf_environment)

    forward_block, backward_block = [
        synapse_util.synapse_submatrix(
            synced, n_in, update_type, include_bias=True)
        for update_type in (synapse_util.UpdateType.FORWARD,
                            synapse_util.UpdateType.BACKWARD)
    ]

    with tf.Session() as sess:
      # Fetch everything in one run: each random op is evaluated once per
      # run, so all three values derive from the same random samples.
      transposed_bwd, fwd, expected = sess.run([
          synapse_util.transpose_synapse(backward_block, tf_environment),
          forward_block,
          forward_init,
      ])
    self.assertAllEqual(fwd, expected)
    self.assertAllEqual(transposed_bwd, expected)
Ejemplo n.º 2
0
def get_blur_state(env, inp, out, ow):
  """Assembles the pieces needed to run a BLUR step from plain arrays.

  Args:
    env: Environment object forwarded to the synapse utilities.
    inp: Array-like input activations; an equally shaped block of zeros is
      appended after it along the last axis.
    out: Array-like output activations; an equally shaped block of zeros is
      prepended before it along the last axis.
    ow: Weight array (must support `.astype`); a trailing singleton axis is
      added before it is combined with its transpose.
    
  Returns:
    A tuple `(pre, post, synapse, genome, network_spec)`.
  """
  np_dtype = blur_env.NP_FLOATING_TYPE

  def _join_last_axis(first, second):
    # Concatenate along the last axis and cast to the env's float type.
    return np.concatenate([first, second], axis=-1).astype(np_dtype)

  pre = _join_last_axis(inp, np.zeros_like(inp))
  post = _join_last_axis(np.zeros_like(out), out)

  weights = ow.astype(np_dtype)[Ellipsis, None]
  transposed_weights = synapse_util.transpose_synapse(weights, env)
  synapse = synapse_util.combine_in_out_synapses(
      weights, transposed_weights, env=env)
  # Expand the single weight state to two synchronized states.
  synapse = synapse_util.sync_states_synapse(synapse, env, num_states=2)

  backprop_genome = genome_util.create_backprop_genome(num_species=1)
  genome = genome_util.convert_genome_to_tf_variables(backprop_genome)

  network_spec = blur.NetworkSpec()
  network_spec.symmetric_synapses = True
  network_spec.batch_average = False
  network_spec.backward_update = 'multiplicative_second_state'

  return pre, post, synapse, genome, network_spec
Ejemplo n.º 3
0
def default_state_update(state, genome, network_spec, env):
  """Transforms system state (currently it can only normalize synapses).

  Normalization runs only when `network_spec.synapse_normalization_spec` is
  set and its `normalize_synapses` flag is truthy; otherwise `state` is
  returned untouched and `genome` is not read.

  Args:
    state: Object with a mutable `synapses` sequence and a parallel `layers`
      sequence; `state.layers[i].shape[-2]` is used as the number of input
      channels of `state.synapses[i]`.
    genome: Genome providing `genome.synapse.rescale_to`; read only when
      normalization runs and the spec's `rescale` flag is set.
    network_spec: Spec whose `synapse_normalization_spec` (may be None)
      controls whether and how synapses are normalized.
    env: Environment forwarded to the `synapse_util` helpers.

  Returns:
    The same `state` object; its `synapses` entries are replaced in place
    when normalization runs.
  """
  synapse_norm_spec = network_spec.synapse_normalization_spec
  # Guard clause: skip building the normalizer (and reading the genome)
  # when there is nothing to do.
  if synapse_norm_spec is None or not synapse_norm_spec.normalize_synapses:
    return state

  rescale_to = genome.synapse.rescale_to if synapse_norm_spec.rescale else None
  normalize = ft.partial(
      synapse_util.normalize_synapses, rescale_to=rescale_to, env=env)

  for i, synapse in enumerate(state.synapses):
    # Index separating the input side from the output side of the combined
    # synapse; the "+ 1" is presumably a bias unit — confirm.
    split = state.layers[i].shape[-2] + 1
    forward = synapse[Ellipsis, :split, split:, :]
    # The extra (+1) entry on axis=-2 is kept because it does not influence
    # normalization, but it gives the tensors the proper shapes for
    # recombination.
    backward = synapse[Ellipsis, split:, :split, :]
    # Replacing an element by index does not disturb the enumerate() walk.
    state.synapses[i] = synapse_util.combine_in_out_synapses(
        normalize(forward), normalize(backward), env)

  return state