def get_hebbian_update(pre, post, transform, global_spec, env):
  """Performs hebbian update of the synapse weight matrix.

  Δ w = [pre, [1], post] @ transform.pre * transform.post @ [pre, [0], post]

  The row-side operand extends `pre` with a constant 1 (the bias unit). The
  column-side operand extends `pre` with a constant 0. Entries at the bias
  position of the column operand therefore receive a zero update. Presumably
  these are the weights leading *into* the bias unit, which should never
  change (NOTE(review): confirm against FC_SYNAPSE_UPDATE's index layout).

  Args:
    pre: [population] x batch_size x in_channels x num_states
    post: [population] x batch_size x out_channels x num_states
    transform: genome_util.HebbianTransform; its `pre` and `post` members are
      contracted with the activations via FC_SYNAPSE_UPDATE.
    global_spec: Specification of the network. Its `symmetric_in_out_synapses`
      and `symmetric_states_synapses` flags enable optional syncing of the
      resulting update.
    env: Environment providing `concat`, `concat_row` and `einsum`.

  Returns:
    Update. Both operands have in_channels + 1 + out_channels entries along
    axis -2, so the update is expected to be
    [..., in_channels + 1 + out_channels, in_channels + 1 + out_channels, k]
    (NOTE(review): exact layout is defined by FC_SYNAPSE_UPDATE — confirm).
  """
  # Row-side operand [pre, 1, post]: [b x (in + 1 + out) x k].
  inp = env.concat([env.concat_row(pre, 1), post], axis=-2)
  # Column-side operand [pre, 0, post]: same shape, but the bias slot is 0 as
  # in the formula above, so no Hebbian signal flows into the bias position.
  out = env.concat([env.concat_row(pre, 0), post], axis=-2)
  # transforms: [k x k]
  hebbian_update = env.einsum(FC_SYNAPSE_UPDATE, inp, transform.pre,
                              transform.post, out)

  if global_spec.symmetric_in_out_synapses:
    # pre.shape[-2] is the number of input channels (excluding the bias).
    hebbian_update = synapse_util.sync_in_and_out_synapse(
        hebbian_update, pre.shape[-2], env)
  if global_spec.symmetric_states_synapses:
    hebbian_update = synapse_util.sync_states_synapse(hebbian_update, env)
  return hebbian_update
def test_sync_in_out_synapses(self):
  """After syncing, forward and transposed backward submatrices match input.

  Builds a combined synapse from independent random forward (in -> out, with
  a bias row) and backward (out -> in) weights, then syncs it. The test checks
  that the forward submatrix and the transposed backward submatrix (both
  including the bias) equal the original forward weights.
  """
  n_in, n_out, n_states = 3, 2, 2
  env = blur_env.tf_env

  forward = tf.random.normal(shape=(n_in + 1, n_out, n_states))
  backward = tf.random.normal(shape=(n_out, n_in + 1, n_states))
  combined = synapse_util.combine_in_out_synapses(forward, backward, env)
  synced = synapse_util.sync_in_and_out_synapse(combined, n_in, env)

  def _submatrix(update_type):
    # Extract one direction of the synced synapse, bias included.
    return synapse_util.synapse_submatrix(
        synced, n_in, update_type, include_bias=True)

  forward_part = _submatrix(synapse_util.UpdateType.FORWARD)
  backward_part = _submatrix(synapse_util.UpdateType.BACKWARD)

  with tf.Session() as sess:
    fetched = sess.run([
        synapse_util.transpose_synapse(backward_part, env),
        forward_part,
        forward,
    ])
  backward_value, forward_value, expected = fetched

  self.assertAllEqual(forward_value, expected)
  self.assertAllEqual(backward_value, expected)
def init_first_state(genome,
                     data,
                     hidden_layers,
                     synapse_initializer,
                     create_synapses_fn=synapse_util.create_synapses,
                     network_spec=None,
                     env=blur_env.tf_env):
  """Initialize the very first state of the graph.

  Args:
    genome: Either a genome object or a callable returning one. A callable is
      invoked as `genome(0)`. The number of neuron states is read from the
      last dimension of `<genome>.synapse.transform.pre`.
    data: Object whose `element_spec['support']` is indexable. Its `[0]`
      entry's last dimension is the number of inputs, and its `[1]` entry's
      shape gives the batch dimensions and the number of outputs (presumably a
      tf.data.Dataset — confirm).
    hidden_layers: Iterable of hidden layer sizes.
    synapse_initializer: Passed through to `create_synapses_fn`.
    create_synapses_fn: Callable `(layers, synapse_initializer) -> synapses`.
      None falls back to `synapse_util.create_synapses`.
    network_spec: Network specification whose `symmetric_in_out_synapses` and
      `symmetric_states_synapses` flags control syncing of the initial
      synapses. None (the default) uses a fresh `blur.NetworkSpec()`.
    env: Environment passed to the synapse sync helpers.

  Returns:
    blur.NetworkState with all-zero layers, the (optionally synced) synapses,
    all-zero ground truth, and `updatable` set to False for the first layer and
    True for every later one.
  """
  if create_synapses_fn is None:
    create_synapses_fn = synapse_util.create_synapses
  if network_spec is None:
    # Build the default per call instead of sharing one instance that was
    # created when the function was defined.
    network_spec = blur.NetworkSpec()

  sample_genome = genome(0) if callable(genome) else genome
  num_neuron_states = sample_genome.synapse.transform.pre.shape[-1]

  support_spec = data.element_spec['support']
  output_shape = support_spec[1].shape
  num_outputs = output_shape[-1]
  num_inputs = support_spec[0].shape[-1]
  batch_dims = output_shape[:-1]

  layer_sizes = (num_inputs, *hidden_layers, num_outputs)
  layers = [
      tf.zeros((*batch_dims, size, num_neuron_states)) for size in layer_sizes
  ]
  synapses = create_synapses_fn(layers, synapse_initializer)

  if network_spec.symmetric_in_out_synapses:
    # Presumably synapse i connects layers[i] to layers[i + 1]; the channel
    # count of layers[i] is passed as the synapse's input size.
    synapses = [
        synapse_util.sync_in_and_out_synapse(synapse, layer.shape[-2], env)
        for synapse, layer in zip(synapses, layers)
    ]
  if network_spec.symmetric_states_synapses:
    synapses = [
        synapse_util.sync_states_synapse(synapse, env) for synapse in synapses
    ]

  # Every layer except the input layer is updatable. Deriving the count from
  # `layers` (== len(hidden_layers) + 1) also works when hidden_layers is a
  # one-shot iterable that was already consumed above.
  num_updatable_units = len(layers) - 1
  ground_truth = tf.zeros((*batch_dims, num_outputs))
  return blur.NetworkState(
      layers=layers,
      synapses=synapses,
      ground_truth=ground_truth,
      updatable=[False] + [True] * num_updatable_units)