def test_non_recurrent_mappings(self):
  insize = 2
  hidden1_size = 4
  hidden2_size = 5
  seq_length = 7
  batch_size = 3

  # As mentioned above, non-recurrent cores are not supported with
  # skip connections. But test that some number of non-recurrent cores
  # is okay (particularly as the last core) without skip connections.
  cores1 = [snt.LSTM(hidden1_size), tf.tanh, snt.Linear(hidden2_size)]
  core1 = snt.DeepRNN(cores1, skip_connections=False)
  core1_h0 = core1.initial_state(batch_size=batch_size)

  cores2 = [snt.LSTM(hidden1_size), snt.Linear(hidden2_size), tf.tanh]
  core2 = snt.DeepRNN(cores2, skip_connections=False)
  core2_h0 = core2.initial_state(batch_size=batch_size)

  xseq = tf.random_normal(shape=[seq_length, batch_size, insize])
  y1, _ = tf.nn.dynamic_rnn(core1, xseq, initial_state=core1_h0,
                            time_major=True)
  y2, _ = tf.nn.dynamic_rnn(core2, xseq, initial_state=core2_h0,
                            time_major=True)

  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run([y1, y2])
def testInitialState(self, trainable, use_custom_initial_value):
  batch_size = 3
  hidden1_size = 4
  hidden2_size = 5
  output1_size = 6
  output2_size = 7
  initializer = None
  if use_custom_initial_value:
    initializer = [tf.constant_initializer(8), tf.constant_initializer(9)]

  # Test that the initial state of a non-recurrent DeepRNN is an empty list.
  non_recurrent_cores = [
      snt.Linear(output_size=output1_size),
      snt.Linear(output_size=output2_size)
  ]
  dummy_deep_rnn = snt.DeepRNN(non_recurrent_cores, skip_connections=False)
  dummy_initial_state = dummy_deep_rnn.initial_state(
      batch_size, trainable=trainable)
  self.assertFalse(dummy_initial_state)

  # Test that the initial state of a recurrent DeepRNN is the same as calling
  # all cores' initial_state method.
  cores = [
      snt.VanillaRNN(hidden_size=hidden1_size),
      snt.VanillaRNN(hidden_size=hidden2_size)
  ]
  deep_rnn = snt.DeepRNN(cores)
  initial_state = deep_rnn.initial_state(
      batch_size, trainable=trainable, trainable_initializers=initializer)

  expected_initial_state = []
  for i, core in enumerate(cores):
    with tf.variable_scope("core-%d" % i):
      expected_initializer = None
      if initializer:
        expected_initializer = initializer[i]
      expected_initial_state.append(
          core.initial_state(
              batch_size,
              trainable=trainable,
              trainable_initializers=expected_initializer))

  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    initial_state_value = sess.run(initial_state)
    expected_initial_state_value = sess.run(expected_initial_state)

  for expected_value, actual_value in zip(expected_initial_state_value,
                                          initial_state_value):
    self.assertAllEqual(actual_value, expected_value)
def __init__(self, access_config, controller_config, output_size,
             name='components'):
  """Initializes the DNC core components.

  Args:
    access_config: dictionary of access module configurations.
    controller_config: dictionary of controller (LSTM) module configurations.
    output_size: output dimension size of core.
    name: module name (default 'components').

  Raises:
    TypeError: if direct_input_size is not None for any access module other
        than KeyValueMemory.
  """
  super(Components, self).__init__(name=name)

  with self._enter_variable_scope():
    # The controller is a stack of two LSTMs.
    self.controller = snt.DeepRNN(
        [snt.LSTM(**controller_config), snt.LSTM(**controller_config)])
    self.access = access.MemoryAccess(**access_config)
    self.output_linear = snt.Linear(output_size=output_size, use_bias=False)
    if FLAGS.is_input_embedder:
      self.input_embedder = snt.Sequential(
          [snt.Linear(output_size=64, use_bias=True), tf.nn.tanh])
def evaluator(
    self,
    variable_source: acme.VariableSource,
    counter: counting.Counter,
):
  """The evaluation process."""
  environment = self._environment_factory(True)
  network = self._network_factory(self._environment_spec.actions)
  tf2_utils.create_variables(network, [self._obs_spec])

  policy_network = snt.DeepRNN([
      network,
      lambda qs: tf.cast(tf.argmax(qs, axis=-1), tf.int32),
  ])

  variable_client = tf2_variable_utils.VariableClient(
      client=variable_source,
      variables={'policy': policy_network.variables},
      update_period=self._variable_update_period)

  # Make sure not to use a random policy after checkpoint restoration by
  # assigning variables before running the environment loop.
  variable_client.update_and_wait()

  # Create the agent.
  actor = actors.RecurrentActor(
      policy_network=policy_network, variable_client=variable_client)

  # Create the run loop and return it.
  logger = loggers.make_default_logger(
      'evaluator', save_data=True, steps_key='evaluator_steps')
  counter = counting.Counter(counter, 'evaluator')
  return acme.EnvironmentLoop(environment, actor, counter, logger)
def testVariables(self):
  batch_size = 3
  in_size = 2
  hidden1_size = 4
  hidden2_size = 5
  mod_name = "deep_rnn"
  inputs = tf.placeholder(tf.float32, shape=[batch_size, in_size])
  prev_state1 = tf.placeholder(tf.float32, shape=[batch_size, hidden1_size])
  prev_state2 = tf.placeholder(tf.float32, shape=[batch_size, hidden2_size])
  prev_state = (prev_state1, prev_state2)
  cores = [
      snt.VanillaRNN(name="rnn1", hidden_size=hidden1_size),
      snt.VanillaRNN(name="rnn2", hidden_size=hidden2_size)
  ]
  deep_rnn = snt.DeepRNN(cores, name=mod_name)
  self.assertEqual(deep_rnn.scope_name, mod_name)
  with self.assertRaisesRegexp(snt.Error, "not instantiated yet"):
    deep_rnn.get_variables()
  deep_rnn(inputs, prev_state)

  # No variables are exposed by the DeepRNN itself; they have to be
  # retrieved from the constituent cores individually.
  self.assertEqual(deep_rnn.get_variables(), ())
  deep_rnn_variables = tuple(
      itertools.chain.from_iterable([c.get_variables() for c in cores]))
  self.assertEqual(len(deep_rnn_variables), 4 * len(cores),
                   "Cores should have %d variables" % (4 * len(cores)))
  for v in deep_rnn_variables:
    self.assertRegexpMatches(
        v.name, "rnn(1|2)/(in_to_hidden|hidden_to_hidden)/(w|b):0")
def testNonRecurrentOnly(self):
  batch_size = 3
  in_size = 2
  output1_size = 4
  output2_size = 5
  cores = [
      snt.Linear(name="linear1", output_size=output1_size),
      snt.Linear(name="linear2", output_size=output2_size)
  ]

  # Build DeepRNN of non-recurrent components.
  deep_rnn = snt.DeepRNN(cores, name="deeprnn", skip_connections=False)
  input_ = tf.placeholder(tf.float32, shape=[batch_size, in_size])
  output, _ = deep_rnn(input_, ())

  # Build manual computation graph.
  output1 = cores[0](input_)
  input2 = output1
  output2 = cores[1](input2)
  manual_output = output2

  with self.test_session() as sess:
    input_data = np.random.randn(batch_size, in_size)
    feed_dict = {input_: input_data}
    tf.global_variables_initializer().run()
    output_value = sess.run([output], feed_dict=feed_dict)
    manual_out_value = sess.run([manual_output], feed_dict=feed_dict)

  self.assertAllClose(output_value, manual_out_value)
def testInitialStateNames(self):
  hidden_size_a = 3
  hidden_size_b = 4
  batch_size = 5
  deep_rnn = snt.DeepRNN(
      [snt.LSTM(hidden_size_a, name="a"), snt.LSTM(hidden_size_b, name="b")])
  deep_rnn_state = deep_rnn.initial_state(batch_size, trainable=True)
  self.assertEqual(
      deep_rnn_state[0][0].name,
      "deep_rnn_initial_state/a_initial_state/state_0_tiled:0")
  self.assertEqual(
      deep_rnn_state[0][1].name,
      "deep_rnn_initial_state/a_initial_state/state_1_tiled:0")
  self.assertEqual(
      deep_rnn_state[1][0].name,
      "deep_rnn_initial_state/b_initial_state/state_0_tiled:0")
  self.assertEqual(
      deep_rnn_state[1][1].name,
      "deep_rnn_initial_state/b_initial_state/state_1_tiled:0")

  other_start_state = deep_rnn.initial_state(
      batch_size, trainable=True, name="blah")
  self.assertEqual(other_start_state[0][0].name,
                   "blah/a_initial_state/state_0_tiled:0")
  self.assertEqual(other_start_state[0][1].name,
                   "blah/a_initial_state/state_1_tiled:0")
  self.assertEqual(other_start_state[1][0].name,
                   "blah/b_initial_state/state_0_tiled:0")
  self.assertEqual(other_start_state[1][1].name,
                   "blah/b_initial_state/state_1_tiled:0")
def testSkipConnectionOptions(self):
  batch_size = 3
  x_seq_shape = [10, batch_size, 2]
  num_hidden = 5
  num_layers = 4
  final_hidden_size = 9
  x_seq = tf.placeholder(shape=x_seq_shape, dtype=tf.float32)
  cores = [snt.LSTM(num_hidden) for _ in xrange(num_layers - 1)]
  final_core = snt.LSTM(final_hidden_size)
  cores += [final_core]
  deep_rnn_core = snt.DeepRNN(cores,
                              skip_connections=True,
                              concat_final_output_if_skip=False)
  initial_state = deep_rnn_core.initial_state(batch_size=batch_size)
  output_seq, _ = tf.nn.dynamic_rnn(deep_rnn_core,
                                    x_seq,
                                    time_major=True,
                                    initial_state=initial_state,
                                    dtype=tf.float32)
  initial_output = output_seq[0]
  feed_dict = {x_seq: np.random.normal(size=x_seq_shape)}

  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    initial_output_res = sess.run(initial_output, feed_dict=feed_dict)

  expected_shape = (batch_size, final_hidden_size)
  self.assertSequenceEqual(initial_output_res.shape, expected_shape)
def test_rnn_snapshot(self):
  """Test that the snapshotter correctly saves/restores snapshots of RNNs."""
  # Create a test network.
  net = snt.LSTM(10)
  spec = specs.Array([10], dtype=np.float32)
  tf2_utils.create_variables(net, [spec])

  # Test that adding some postprocessing without rerunning create_variables
  # still works.
  wrapped_net = snt.DeepRNN([net, lambda x: x])

  for net1 in [net, wrapped_net]:
    # Save the test network.
    directory = self.get_tempdir()
    objects_to_save = {'net': net1}
    snapshotter = tf2_savers.Snapshotter(objects_to_save, directory=directory)
    snapshotter.save()

    # Reload the test network.
    net2 = tf.saved_model.load(os.path.join(snapshotter.directory, 'net'))
    inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec))

    with tf.GradientTape() as tape:
      outputs1, next_state1 = net1(inputs, net1.initial_state(1))
      loss1 = tf.math.reduce_sum(outputs1)
      grads1 = tape.gradient(loss1, net1.trainable_variables)

    with tf.GradientTape() as tape:
      outputs2, next_state2 = net2(inputs, net2.initial_state(1))
      loss2 = tf.math.reduce_sum(outputs2)
      grads2 = tape.gradient(loss2, net2.trainable_variables)

    assert np.allclose(outputs1, outputs2)
    assert np.allclose(tree.flatten(next_state1), tree.flatten(next_state2))
    assert all(tree.map_structure(np.allclose, list(grads1), list(grads2)))
def build_module(inputs, input_len):
  cores = [
      snt.LSTM(hidden_size,
               custom_getter=lstm_bbb_custom_getter,
               forget_bias=0.0,
               name="lstm_layer_{}".format(i))
      for i in six.moves.range(n_layers)
  ]
  rnn_core = snt.DeepRNN(cores, skip_connections=True, name="deep_lstm_core")

  # Do BBB on weights but not biases of output layer.
  output_linear = Linear(n_classes,
                         custom_getter={"w": non_lstm_bbb_custom_getter})

  # initial_rnn_state = nest.map_structure(
  #     lambda t: tf.get_local_variable(
  #         "{}/rnn_state/{}".format("train", t.op.name), initializer=t),
  #     rnn_core.initial_state(batch_size))
  # assign_zero_rnn_state = nest.map_structure(
  #     lambda x: x.assign(tf.zeros_like(x)), initial_rnn_state)
  # assign_zero_rnn_state = tf.group(*nest.flatten(assign_zero_rnn_state))

  # Unroll the RNN core over the sequence.
  rnn_output_seq, rnn_final_state = tf.nn.dynamic_rnn(
      cell=rnn_core,
      inputs=inputs,
      sequence_length=input_len,
      dtype=tf.float32)

  # Persist the RNN state for the next unroll.
  # update_rnn_state = nest.map_structure(
  #     tf.assign, initial_rnn_state, rnn_final_state)
  # with tf.control_dependencies(nest.flatten(update_rnn_state)):
  #   rnn_output_seq = tf.identity(rnn_output_seq, name="rnn_output_seq")

  rnn_output_seq = tf.reshape(tf.concat(rnn_output_seq, 2), [-1, 2 * 512])
  output_logits = output_linear(rnn_output_seq)
  return output_logits  # , assign_zero_rnn_state
def _build(self, inputs):
  ## DROPOUT ##
  if self.dropout > 0.0:
    dropouts = [
        Dropout(keep_prob=self._keep_prob) for _ in range(self.num_layers)
    ]
    self.subcores = interleave(self.subcores, dropouts)

  if len(self.subcores) > 1:
    self.core = snt.DeepRNN(self.subcores,
                            name="multi_lstm_core",
                            skip_connections=self.use_skip_connections)
  else:
    self.core = self.subcores[0]

  if self.train_initial_state:
    self._initial_rnn_state = self.core.initial_state(
        self.batch_size, tf.float32, trainable=True)
    # self._initial_rnn_state = snt.TrainableInitialState(
    #     self.core.initial_state(self.batch_size, tf.float32))()
  else:
    self._initial_rnn_state = self.core.zero_state(self.batch_size,
                                                   tf.float32)

  output, final_rnn_state = tf.nn.dynamic_rnn(
      self.core,
      inputs,
      dtype=tf.float32,
      sequence_length=self._seq_len,
      initial_state=self._initial_rnn_state)
  return output, final_rnn_state
def __init__(self, action_spec: specs.DiscreteArray):
  super().__init__(name='r2d2_test_network')
  self._net = snt.DeepRNN([
      snt.Flatten(),
      snt.LSTM(20),
      snt.nets.MLP([50, 50, action_spec.num_values])
  ])
def testNonRecurrentOnly(self):
  batch_size = 3
  in_size = 2
  output1_size = 4
  output2_size = 5
  cores = [
      snt.Linear(name="linear1", output_size=output1_size),
      snt.Linear(name="linear2", output_size=output2_size)
  ]

  # Build DeepRNN of non-recurrent components.
  deep_rnn = snt.DeepRNN(cores, name="deeprnn", skip_connections=False)
  input_data = np.random.randn(batch_size, in_size)
  input_ = tf.constant(input_data, dtype=tf.float32)
  output, _ = deep_rnn(input_, ())

  # Build manual computation graph.
  output1 = cores[0](input_)
  input2 = output1
  output2 = cores[1](input2)
  manual_output = output2

  self.evaluate(tf.global_variables_initializer())
  output_value = self.evaluate(output)
  manual_out_value = self.evaluate(manual_output)
  self.assertAllClose(output_value, manual_out_value)
def _make_network(action_spec: specs.DiscreteArray) -> snt.RNNCore:
  return snt.DeepRNN([
      snt.Flatten(),
      snt.LSTM(20),
      snt.nets.MLP([50, 50]),
      networks.PolicyValueHead(action_spec.num_values),
  ])
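# A minimal usage sketch (not part of the original tests), assuming the
# Sonnet 2 RNNCore API: the DeepRNN built by _make_network is stepped like any
# single core, threading its recurrent state explicitly. The action count and
# observation shape below are illustrative assumptions.
def _example_policy_step():
  core = _make_network(specs.DiscreteArray(num_values=5))
  observation = tf.zeros([1, 24])  # Hypothetical batched observation.
  state = core.initial_state(1)
  # PolicyValueHead makes the core's output a (logits, value) pair.
  (logits, value), state = core(observation, state)
  return logits, value, state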
def __init__(self, output_size, layers, preprocess_name="identity",
             preprocess_options=None, scale=1.0, initializer=None,
             name="deep_lstm", tanh_output=False, num_linear_heads=1):
  """Creates an instance of `StandardDeepLSTM`.

  Args:
    output_size: Output size of the final linear layer.
    layers: Output sizes of LSTM layers.
    preprocess_name: Gradient preprocessing class name (in `l2l.preprocess`
        or tf modules). Default is `tf.identity`.
    preprocess_options: Gradient preprocessing options.
    scale: Gradient scaling (default is 1.0).
    initializer: Variable initializer for linear layer. See `snt.Linear` and
        `snt.LSTM` docs for more info. This parameter can be a string (e.g.
        "zeros" will be converted to tf.zeros_initializer).
    name: Module name.
    tanh_output: Whether to squash the final output with `tanh`.
    num_linear_heads: Number of linear output heads; must be at least 1.
  """
  super(StandardDeepLSTM, self).__init__(name=name)

  self._output_size = output_size
  self._scale = scale
  self._num_linear_heads = num_linear_heads
  assert self._num_linear_heads >= 1

  if preprocess_name == 'fc':
    with tf.variable_scope(self._template.variable_scope):
      init = _get_layer_initializers(initializer, "input_projection",
                                     ("w", "b"))
      self._linear_input = snt.Linear(preprocess_options["dim"],
                                      name="input_projection",
                                      initializers=init)
  elif hasattr(preprocess, preprocess_name):
    preprocess_class = getattr(preprocess, preprocess_name)
    self._preprocess = preprocess_class(initializer, **preprocess_options)
  else:
    self._preprocess = getattr(tf, preprocess_name)
  self._preprocess_name = preprocess_name

  with tf.variable_scope(self._template.variable_scope):
    self._cores = []
    for i, size in enumerate(layers, start=1):
      name = "lstm_{}".format(i)
      init = _get_layer_initializers(initializer, name,
                                     ("w_gates", "b_gates"))
      self._cores.append(snt.LSTM(size, name=name, initializers=init))
    self._rnn = snt.DeepRNN(self._cores, skip_connections=False,
                            name="deep_rnn")

    if self._num_linear_heads == 1:
      init = _get_layer_initializers(initializer, "linear", ("w", "b"))
      self._linear = snt.Linear(output_size, name="linear",
                                initializers=init)
    else:
      self._linears = []
      init = _get_layer_initializers(initializer, "linear", ("w", "b"))
      self._linears.append(
          snt.Linear(output_size, name="linear", initializers=init))
      for i in range(1, num_linear_heads):
        init = _get_layer_initializers(initializer, "linear_{}".format(i),
                                       ("w", "b"))
        self._linears.append(
            snt.Linear(output_size, name="linear_{}".format(i),
                       initializers=init))

  self.init_flag = False
  self.tanh_output = tanh_output
def testShape(self):
  batch_size = 3
  batch_size_shape = tf.TensorShape(batch_size)
  in_size = 2
  hidden1_size = 4
  hidden2_size = 5

  inputs = tf.placeholder(tf.float32, shape=[batch_size, in_size])
  prev_state0 = tf.placeholder(tf.float32, shape=[batch_size, in_size])
  prev_state1 = tf.placeholder(tf.float32, shape=[batch_size, hidden1_size])
  prev_state2 = tf.placeholder(tf.float32, shape=[batch_size, hidden2_size])
  prev_state = (prev_state0, prev_state1, prev_state2)

  # Test recurrent and non-recurrent cores.
  cores = [
      snt.VanillaRNN(name="rnn0", hidden_size=in_size),
      snt.VanillaRNN(name="rnn1", hidden_size=hidden1_size),
      snt.VanillaRNN(name="rnn2", hidden_size=hidden2_size)
  ]
  deep_rnn = snt.DeepRNN(cores, name="deep_rnn", skip_connections=True)
  output, next_state = deep_rnn(inputs, prev_state)
  output_shape = output.get_shape()

  output_size = in_size + hidden1_size + hidden2_size
  self.assertTrue(output_shape.is_compatible_with([batch_size, output_size]))
  self.assertTrue(
      output_shape.is_compatible_with(
          batch_size_shape.concatenate(deep_rnn.output_size)))

  next_state_shape = (next_state[0].get_shape(),
                      next_state[1].get_shape(),
                      next_state[2].get_shape())
  self.assertTrue(
      next_state_shape[0].is_compatible_with([batch_size, in_size]))
  self.assertTrue(
      next_state_shape[1].is_compatible_with([batch_size, hidden1_size]))
  self.assertTrue(
      next_state_shape[2].is_compatible_with([batch_size, hidden2_size]))

  for state_shape, expected_shape in zip(next_state_shape,
                                         deep_rnn.state_size):
    self.assertTrue(
        state_shape.is_compatible_with(
            batch_size_shape.concatenate(expected_shape)))

  # The initial state should be a valid state.
  initial_state = deep_rnn.initial_state(batch_size, tf.float32)
  self.assertEqual(len(initial_state), len(next_state))
  self.assertShapeEqual(np.ndarray((batch_size, in_size)), initial_state[0])
  self.assertShapeEqual(np.ndarray((batch_size, hidden1_size)),
                        initial_state[1])
  self.assertShapeEqual(np.ndarray((batch_size, hidden2_size)),
                        initial_state[2])
def testIncompatibleOptions(self):
  in_size = 2
  hidden1_size = 4
  hidden2_size = 5
  cores = [
      snt.Linear(name="linear", output_size=in_size),
      snt.VanillaRNN(name="rnn1", hidden_size=hidden1_size),
      snt.VanillaRNN(name="rnn2", hidden_size=hidden2_size)
  ]
  with self.assertRaisesRegexp(
      ValueError, "skip_connections are enabled but not all cores are "
      "`snt.RNNCore`s, which is not supported"):
    snt.DeepRNN(cores, name="deep_rnn", skip_connections=True)

  cells = [tf.contrib.rnn.BasicLSTMCell(5), tf.contrib.rnn.BasicLSTMCell(5)]
  with self.assertRaisesRegexp(
      ValueError, "skip_connections are enabled but not all cores are "
      "`snt.RNNCore`s, which is not supported"):
    snt.DeepRNN(cells, skip_connections=True)
def module_build():
  core = snt.DeepRNN([snt.LSTM(4), snt.LSTM(5)])
  initial_state1 = core.initial_state(
      batch_size, dtype=tf.float32, trainable=True)
  initial_state2 = core.initial_state(
      batch_size + 1, dtype=tf.float32, trainable=True)
  return initial_state1, initial_state2
def __init__(self,
             num_embedding,
             num_hidden,
             lstm_depth,
             output_size,
             use_dynamic_rnn=True,
             use_skip_connections=True,
             name="text_model"):
  """Constructs a `TextModel`.

  Args:
    num_embedding: Size of embedding representation, used directly after the
        one-hot encoded input.
    num_hidden: Number of hidden units in each LSTM layer.
    lstm_depth: Number of LSTM layers.
    output_size: Size of the output layer on top of the DeepRNN.
    use_dynamic_rnn: Whether to use dynamic RNN unrolling. If `False`, it
        uses static unrolling. Default is `True`.
    use_skip_connections: Whether to use skip connections in the
        `snt.DeepRNN`. Default is `True`.
    name: Name of the module.
  """
  super(TextModel, self).__init__(name=name)

  self._num_embedding = num_embedding
  self._num_hidden = num_hidden
  self._lstm_depth = lstm_depth
  self._output_size = output_size
  self._use_dynamic_rnn = use_dynamic_rnn
  self._use_skip_connections = use_skip_connections

  with self._enter_variable_scope():
    self._embed_module = snt.Linear(self._num_embedding, name="linear_embed")
    self._output_module = snt.Linear(self._output_size, name="linear_output")
    self._subcores = [
        snt.LSTM(self._num_hidden, name="lstm_{}".format(i))
        for i in range(self._lstm_depth)
    ]
    if self._use_skip_connections:
      skips = []
      current_input_shape = self._num_embedding
      for lstm in self._subcores:
        input_shape = tf.TensorShape([current_input_shape])
        skip = snt.SkipConnectionCore(
            lstm,
            input_shape=input_shape,
            name="skip_{}".format(lstm.module_name))
        skips.append(skip)
        # SkipConnectionCore concatenates the input with the output, so the
        # dimensionality increases with depth.
        current_input_shape += self._num_hidden
      self._subcores = skips
    self._core = snt.DeepRNN(self._subcores,
                             skip_connections=False,
                             name="deep_lstm")
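# A worked size example (not part of the original model), following the
# comment above: SkipConnectionCore concatenates each layer's input with its
# output, so the stream widens by num_hidden per layer and the DeepRNN's final
# output has num_embedding + lstm_depth * num_hidden features before the
# linear output layer. The default numbers below are illustrative assumptions.
def _example_skip_widths(num_embedding=32, num_hidden=64, lstm_depth=3):
  widths = []
  current = num_embedding
  for _ in range(lstm_depth):
    current += num_hidden  # concat(input, lstm_output) widens the stream.
    widths.append(current)
  return widths  # [96, 160, 224] for the defaults above.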
def make_rnn(hparams, name):
  """Constructs a DeepRNN using hparams.rnn_hidden_sizes."""
  regularizers = {snt.LSTM.W_GATES: regularizer(hparams)}
  with tf.variable_scope(name):
    layers = [
        snt.LSTM(size, regularizers=regularizers)
        for size in hparams.rnn_hidden_sizes
    ]
    return snt.DeepRNN(layers, skip_connections=False, name=name)
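# A hedged unrolling sketch (not part of the original file): the DeepRNN
# returned by make_rnn is a valid cell for tf.nn.dynamic_rnn, mirroring the
# usage elsewhere in this collection. The shapes below are illustrative
# assumptions.
def example_unroll(core, batch_size=16, seq_length=20, input_size=8):
  inputs = tf.zeros([seq_length, batch_size, input_size])
  initial_state = core.initial_state(batch_size)
  outputs, final_state = tf.nn.dynamic_rnn(
      core, inputs, initial_state=initial_state, time_major=True)
  return outputs, final_state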
def test_feedforward(self, recurrent: bool):
  model = snt.Linear(42)
  if recurrent:
    model = snt.DeepRNN([model])
  input_spec = specs.Array(shape=(10,), dtype=np.float32)
  tf2_utils.create_variables(model, [input_spec])

  variables: Sequence[tf.Variable] = model.variables
  shapes = [v.shape.as_list() for v in variables]
  self.assertSequenceEqual(shapes, [[42], [10, 42]])
def testComputation(self, skip_connections, create_initial_state):
  batch_size = 3
  in_size = 2
  hidden1_size = 4
  hidden2_size = 5
  mod_name = "deep_rnn"
  cores = [
      snt.VanillaRNN(name="rnn1", hidden_size=hidden1_size),
      snt.VanillaRNN(name="rnn2", hidden_size=hidden2_size)
  ]
  deep_rnn = snt.DeepRNN(cores, name=mod_name,
                         skip_connections=skip_connections)
  inputs = tf.constant(np.random.randn(batch_size, in_size), dtype=tf.float32)
  if create_initial_state:
    prev_state = deep_rnn.initial_state(batch_size, tf.float32)
  else:
    prev_state1 = tf.constant(np.random.randn(batch_size, hidden1_size),
                              dtype=tf.float32)
    prev_state2 = tf.constant(np.random.randn(batch_size, hidden2_size),
                              dtype=tf.float32)
    prev_state = (prev_state1, prev_state2)
  output, next_state = deep_rnn(inputs, prev_state)

  # With random data, check the DeepRNN calculation matches the manual
  # stacking version.
  self.evaluate(tf.global_variables_initializer())
  outputs_value = self.evaluate([output, next_state[0], next_state[1]])
  output_value, next_state1_value, next_state2_value = outputs_value

  # Build manual computation graph.
  output1, next_state1 = cores[0](inputs, prev_state[0])
  if skip_connections:
    input2 = tf.concat([inputs, output1], 1)
  else:
    input2 = output1
  output2, next_state2 = cores[1](input2, prev_state[1])
  if skip_connections:
    manual_output = tf.concat([output1, output2], 1)
  else:
    manual_output = output2

  manual_outputs_value = self.evaluate(
      [manual_output, next_state1, next_state2])
  manual_output_value = manual_outputs_value[0]
  manual_next_state1_value = manual_outputs_value[1]
  manual_next_state2_value = manual_outputs_value[2]

  self.assertAllClose(output_value, manual_output_value)
  self.assertAllClose(next_state1_value, manual_next_state1_value)
  self.assertAllClose(next_state2_value, manual_next_state2_value)
def testComputation(self, skip_connections, create_initial_state):
  batch_size = 3
  in_size = 2
  hidden1_size = 4
  hidden2_size = 5
  mod_name = "deep_rnn"
  cores = [
      snt.VanillaRNN(name="rnn1", hidden_size=hidden1_size),
      snt.VanillaRNN(name="rnn2", hidden_size=hidden2_size)
  ]
  deep_rnn = snt.DeepRNN(cores, name=mod_name,
                         skip_connections=skip_connections)
  inputs = tf.placeholder(tf.float32, shape=[batch_size, in_size])
  if create_initial_state:
    prev_state = deep_rnn.initial_state(batch_size, tf.float32)
  else:
    prev_state1 = tf.placeholder(tf.float32, shape=[batch_size, hidden1_size])
    prev_state2 = tf.placeholder(tf.float32, shape=[batch_size, hidden2_size])
    prev_state = (prev_state1, prev_state2)
  output, next_state = deep_rnn(inputs, prev_state)

  with self.test_session() as sess:
    # With random data, check the DeepRNN calculation matches the manual
    # stacking version.
    input_data = np.random.randn(batch_size, in_size)
    feed_dict = {inputs: input_data}
    if not create_initial_state:
      feed_dict[prev_state1] = np.random.randn(batch_size, hidden1_size)
      feed_dict[prev_state2] = np.random.randn(batch_size, hidden2_size)
    tf.global_variables_initializer().run()

    outputs_value = sess.run([output, next_state[0], next_state[1]],
                             feed_dict=feed_dict)
    output_value, next_state1_value, next_state2_value = outputs_value

    # Build manual computation graph.
    output1, next_state1 = cores[0](inputs, prev_state[0])
    if skip_connections:
      input2 = tf.concat([inputs, output1], 1)
    else:
      input2 = output1
    output2, next_state2 = cores[1](input2, prev_state[1])
    if skip_connections:
      manual_output = tf.concat([output1, output2], 1)
    else:
      manual_output = output2

    manual_outputs_value = sess.run(
        [manual_output, next_state1, next_state2], feed_dict=feed_dict)

  manual_output_value = manual_outputs_value[0]
  manual_next_state1_value = manual_outputs_value[1]
  manual_next_state2_value = manual_outputs_value[2]

  self.assertAllClose(output_value, manual_output_value)
  self.assertAllClose(next_state1_value, manual_next_state1_value)
  self.assertAllClose(next_state2_value, manual_next_state2_value)
def __init__(self,
             access_config,
             access_config_2,
             controller_config,
             output_size,
             clip_value=None,
             name='dnc'):
  """Initializes the DNC core.

  Args:
    access_config: dictionary of access module configurations.
    access_config_2: dictionary of configurations for the second
        (instruction) access module.
    controller_config: dictionary of controller (LSTM) module configurations.
    output_size: output dimension size of core.
    clip_value: clips controller and core output values to between
        `[-clip_value, clip_value]` if specified.
    name: module name (default 'dnc').

  Raises:
    TypeError: if direct_input_size is not None for any access module other
        than KeyValueMemory.
  """
  super(DNC, self).__init__(name=name)
  self._b_init = tf.zeros(access_config_2['num_reads']) + 0.01

  with self._enter_variable_scope():
    # self._controller = snt.LSTM(**controller_config)
    # self._controller = snt.DeepRNN(
    #     [snt.Linear(200), tf.nn.leaky_relu,
    #      snt.Linear(controller_config['hidden_size'])],
    #     skip_connections=False)
    self._controller = snt.DeepRNN([
        snt.Linear(200),
        tf.nn.leaky_relu,
        snt.LSTM(controller_config['hidden_size'])
    ], skip_connections=False)
    self._access = access.MemoryAccess(**access_config)
    # self._access_2 = access.MemoryAccess(**access_config_2)
    self._access_2 = access_instruction.MemoryAccessInst(**access_config_2)
    self._b = tf.get_variable(name="b", initializer=self._b_init,
                              trainable=True)

  self._access_output_size = np.prod(self._access.output_size.as_list())
  # self._access_output_size_2 = np.prod(
  #     self._access_2.output_size.as_list())
  self._clip_value = clip_value or 0

  self._output_size = tf.TensorShape([output_size])
  self._state_size = DNCState(
      access_output=self._access_output_size,
      access_state=self._access.state_size,
      # access_output_2=self._access_output_size,
      access_state_2=self._access_2.state_size,
      controller_state=self._controller.state_size)
def __init__(self, cores, name=None):
  super(StackedRNN, self).__init__(name=name)
  self._cores = cores
  can_use_conv = all(
      type(c) in self._supports_conv_mode
      for c in cores if isinstance(c, snt.AbstractModule))
  with self._enter_variable_scope():
    self._conv_core = snt.Sequential(cores) if can_use_conv else None
    self._rnn_core = snt.DeepRNN(
        cores, skip_connections=False) if cores else None
def test_output_spec_feedforward(self, recurrent: bool):
  input_spec = specs.Array(shape=(10,), dtype=np.float32)
  model = snt.Linear(42)
  expected_spec = tf.TensorSpec(shape=(42,), dtype=tf.float32)
  if recurrent:
    model = snt.DeepRNN([model])
    expected_spec = (expected_spec, ())

  output_spec = tf2_utils.create_variables(model, [input_spec])
  self.assertEqual(output_spec, expected_spec)
def __init__(self,
             action_spec: specs.DiscreteArray,
             name: Optional[Text] = None):
  super().__init__(name=name)
  # TODO: make flags for the hidden layer dimensions.
  self.flat = snt.nets.MLP([64, 64], name="mlp_1")
  self.rnn = snt.DeepRNN([
      snt.nets.MLP([50, 50], activate_final=True, name="mlp_2"),
      snt.GRU(512, name="gru"),
      networks.PolicyValueHead(action_spec.num_values)
  ])
def __init__(
    self,
    conf,
    create_scale=False,
    create_offset=False,
    recurrent_dropout=0.35,
    name="NormLSTM",
):
  super(NormLSTM, self).__init__(name=name)
  self._hidden_size = conf[0]
  train_lstm = []
  test_lstm = []
  for _ in range(conf[1]):
    dropout_lstm, lstm = snt.lstm_with_recurrent_dropout(
        self._hidden_size, dropout=recurrent_dropout)
    test_lstm.append(lstm)
    train_lstm.append(dropout_lstm)
  self._test_lstm = snt.DeepRNN(test_lstm)
  self._train_lstm = snt.DeepRNN(train_lstm)
  self._norm = snt.BatchNorm(create_scale, create_offset)
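# A hedged sketch (the module's forward pass is not shown in this excerpt):
# snt.lstm_with_recurrent_dropout returns a dropout-wrapped training core plus
# the underlying LSTM for evaluation, so the two DeepRNN stacks above share
# parameters and differ only in whether recurrent dropout is applied. A
# forward pass would plausibly pick one by phase roughly as follows.
def _select_core(self, is_training):
  return self._train_lstm if is_training else self._test_lstm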
def custom_recurrent_network(
    environment_spec: mava_specs.MAEnvironmentSpec,
    q_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = [128, 128],
    shared_weights: bool = True,
) -> Mapping[str, types.TensorTransformation]:
  """Creates networks used by the agents."""
  specs = environment_spec.get_agent_specs()

  # Create agent_type specs.
  if shared_weights:
    type_specs = {key.split("_")[0]: specs[key] for key in specs.keys()}
    specs = type_specs

  if isinstance(q_networks_layer_sizes, Sequence):
    q_networks_layer_sizes = {
        key: q_networks_layer_sizes for key in specs.keys()
    }

  def action_selector_fn(
      q_values: types.NestedTensor,
      legal_actions: types.NestedTensor,
      epsilon: Optional[tf.Variable] = None,
  ) -> types.NestedTensor:
    return epsilon_greedy_action_selector(
        action_values=q_values,
        legal_actions_mask=legal_actions,
        epsilon=epsilon)

  q_networks = {}
  action_selectors = {}
  for key in specs.keys():
    # Get total number of action dimensions from action spec.
    num_dimensions = specs[key].actions.num_values

    # Create the policy network.
    q_network = snt.DeepRNN([
        snt.Linear(q_networks_layer_sizes[key][0]),
        tf.nn.relu,
        snt.GRU(q_networks_layer_sizes[key][1]),
        networks.NearZeroInitializedLinear(num_dimensions),
    ])

    # Epsilon-greedy action selector.
    action_selector = action_selector_fn

    q_networks[key] = q_network
    action_selectors[key] = action_selector

  return {
      "q_networks": q_networks,
      "action_selectors": action_selectors,
  }
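# A minimal stepping sketch (not part of the original file): each q_network
# above is an snt.RNNCore (Linear -> relu -> GRU -> linear head), so acting
# threads a recurrent state through every call before the epsilon-greedy
# selector picks an action. The argument names are illustrative assumptions.
def example_agent_step(q_network, action_selector, observation, legal_actions,
                       state, epsilon):
  # One recurrent Q-value step followed by epsilon-greedy selection.
  q_values, next_state = q_network(observation, state)
  action = action_selector(q_values, legal_actions, epsilon=epsilon)
  return action, next_state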
def testFinalCoreHasNoSizeWarning(self):
  cores = [snt.LSTM(hidden_size=10), snt.Linear(output_size=42), tf.nn.relu]
  rnn = snt.DeepRNN(cores, skip_connections=False)

  with mock.patch.object(tf.logging, "warning") as mocked_logging_warning:
    # This will produce a warning.
    unused_output_size = rnn.output_size
    self.assertTrue(mocked_logging_warning.called)
    first_call_args = mocked_logging_warning.call_args[0]
    self.assertTrue("final core %s does not have the "
                    ".output_size field" in first_call_args[0])
    self.assertEqual(first_call_args[2], 42)