def test_non_recurrent_mappings(self):
    """Checks DeepRNN accepts non-recurrent cores when skip connections are off.

    Two stacks are unrolled over the same random sequence; they differ only
    in the order of the trailing `tf.tanh` and `snt.Linear` cores.
    """
    input_dim = 2
    lstm_units = 4
    linear_units = 5
    num_steps = 7
    batch_size = 3

    # Non-recurrent cores are not supported with skip connections, but some
    # number of them (particularly as the last core) must be fine without
    # skip connections.
    tanh_then_linear = snt.DeepRNN(
        [snt.LSTM(lstm_units), tf.tanh, snt.Linear(linear_units)],
        skip_connections=False)
    tanh_then_linear_h0 = tanh_then_linear.initial_state(
        batch_size=batch_size)

    linear_then_tanh = snt.DeepRNN(
        [snt.LSTM(lstm_units), snt.Linear(linear_units), tf.tanh],
        skip_connections=False)
    linear_then_tanh_h0 = linear_then_tanh.initial_state(
        batch_size=batch_size)

    # Time-major input: [time, batch, features].
    xseq = tf.random_normal(shape=[num_steps, batch_size, input_dim])
    outputs_a, _ = tf.nn.dynamic_rnn(
        tanh_then_linear, xseq,
        initial_state=tanh_then_linear_h0, time_major=True)
    outputs_b, _ = tf.nn.dynamic_rnn(
        linear_then_tanh, xseq,
        initial_state=linear_then_tanh_h0, time_major=True)

    # The test passes if both graphs can be initialised and evaluated.
    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run([outputs_a, outputs_b])
def __init__(self, access_config, controller_config, output_size,
             name='components'):
    """Builds the controller, memory-access and output-projection modules.

    Args:
      access_config: dictionary of keyword arguments forwarded to
        `access.MemoryAccess`.
      controller_config: dictionary of keyword arguments forwarded to each
        of the two stacked `snt.LSTM` controller layers.
      output_size: output dimension of the bias-free output projection.
      name: module name (default 'components').
    """
    super(Components, self).__init__(name=name)

    with self._enter_variable_scope():
        # Two LSTM layers built from the same configuration, stacked.
        controller_layers = [
            snt.LSTM(**controller_config),
            snt.LSTM(**controller_config),
        ]
        self.controller = snt.DeepRNN(controller_layers)
        self.access = access.MemoryAccess(**access_config)
        self.output_linear = snt.Linear(
            output_size=output_size, use_bias=False)

        # `input_embedder` only exists when the flag is enabled.
        if FLAGS.is_input_embedder:
            embedder_layers = [
                snt.Linear(output_size=64, use_bias=True),
                tf.nn.tanh,
            ]
            self.input_embedder = snt.Sequential(embedder_layers)
def testSkipConnectionOptions(self):
    """Without the final concat, the output has the last core's size."""
    batch_size = 3
    x_seq_shape = [10, batch_size, 2]
    num_hidden = 5
    num_layers = 4
    final_hidden_size = 9

    x_seq = tf.placeholder(shape=x_seq_shape, dtype=tf.float32)

    # The last layer gets a distinct size so the output shape shows which
    # core produced it.
    cores = [snt.LSTM(num_hidden) for _ in xrange(num_layers - 1)]
    cores.append(snt.LSTM(final_hidden_size))
    deep_rnn_core = snt.DeepRNN(
        cores, skip_connections=True, concat_final_output_if_skip=False)

    start_state = deep_rnn_core.initial_state(batch_size=batch_size)
    output_seq, _ = tf.nn.dynamic_rnn(
        deep_rnn_core,
        x_seq,
        time_major=True,
        initial_state=start_state,
        dtype=tf.float32)
    first_step_output = output_seq[0]

    feed_dict = {x_seq: np.random.normal(size=x_seq_shape)}
    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        first_step_value = sess.run(first_step_output, feed_dict=feed_dict)

    self.assertSequenceEqual(
        first_step_value.shape, (batch_size, final_hidden_size))
def testInitialStateNames(self):
    """Trainable initial-state tensors are named per core and per scope."""
    hidden_size_a = 3
    hidden_size_b = 4
    batch_size = 5
    deep_rnn = snt.DeepRNN([
        snt.LSTM(hidden_size_a, name="a"),
        snt.LSTM(hidden_size_b, name="b"),
    ])

    def check_names(state, expected):
        # `expected[layer][part]` is the tensor name for that state entry.
        for layer, layer_names in enumerate(expected):
            for part, tensor_name in enumerate(layer_names):
                self.assertEqual(state[layer][part].name, tensor_name)

    # Default scope name.
    default_state = deep_rnn.initial_state(batch_size, trainable=True)
    check_names(default_state, (
        ("deep_rnn_initial_state/a_initial_state/state_0_tiled:0",
         "deep_rnn_initial_state/a_initial_state/state_1_tiled:0"),
        ("deep_rnn_initial_state/b_initial_state/state_0_tiled:0",
         "deep_rnn_initial_state/b_initial_state/state_1_tiled:0"),
    ))

    # Explicit scope name.
    named_state = deep_rnn.initial_state(
        batch_size, trainable=True, name="blah")
    check_names(named_state, (
        ("blah/a_initial_state/state_0_tiled:0",
         "blah/a_initial_state/state_1_tiled:0"),
        ("blah/b_initial_state/state_0_tiled:0",
         "blah/b_initial_state/state_1_tiled:0"),
    ))
def module_build():
    """Builds a two-layer DeepRNN and two trainable initial states.

    The two states are requested with batch sizes `batch_size` and
    `batch_size + 1` (`batch_size` comes from the enclosing scope).

    Returns:
      A pair `(initial_state1, initial_state2)`.
    """
    core = snt.DeepRNN([snt.LSTM(4), snt.LSTM(5)])
    state_options = dict(dtype=tf.float32, trainable=True)
    initial_state1 = core.initial_state(batch_size, **state_options)
    initial_state2 = core.initial_state(batch_size + 1, **state_options)
    return initial_state1, initial_state2
def testSideBySide(self):
    """Repeated trainable initial states must not collide on variable names."""
    hidden_size = 3
    batch_size = 4
    first_lstm = snt.LSTM(hidden_size=hidden_size)
    second_lstm = snt.LSTM(hidden_size=hidden_size)

    first_lstm.initial_state(batch_size, trainable=True)

    # Previously either of the two calls below would cause a crash due to
    # Variable name collision.
    for lstm in (first_lstm, second_lstm):
        lstm.initial_state(batch_size, trainable=True)
def _build(self, inputs):
    """Embeds `inputs`, runs a bidirectional LSTM and projects the result.

    Args:
      inputs: passed as the first argument of
        `tf.contrib.layers.embed_sequence` together with the constants
        128 and 32.

    Returns:
      The output of a dense layer of width `self.size` applied to the
      concatenated forward/backward slices.
    """
    embedded = tf.contrib.layers.embed_sequence(inputs, 128, 32)

    forward_cell = snt.LSTM(self.size, name="lstm_fw")
    backward_cell = snt.LSTM(self.size, name="lstm_bw")
    outputs, _ = tf.nn.bidirectional_dynamic_rnn(
        forward_cell, backward_cell, embedded, dtype=tf.float32)
    forward_out, backward_out = outputs

    # NOTE(review): both directions are sliced at index -1 on axis 1; for
    # the backward direction that may not be the fully-processed step --
    # confirm this is intended.
    last_steps = [forward_out[:, -1, :], backward_out[:, -1, :]]
    merged = tf.concat(last_steps, 1)
    return tf.layers.dense(merged, self.size)
def testConflictingNormalization(self):
    """Layer norm combined with any batch-norm option raises ValueError."""
    conflicts = (
        ("use_batch_norm_h",
         "Only one of use_batch_norm_h and layer_norm is allowed."),
        ("use_batch_norm_x",
         "Only one of use_batch_norm_x and layer_norm is allowed."),
        ("use_batch_norm_c",
         "Only one of use_batch_norm_c and layer_norm is allowed."),
    )
    for batch_norm_flag, message in conflicts:
        with self.assertRaisesRegexp(ValueError, message):
            snt.LSTM(hidden_size=3, use_layer_norm=True,
                     **{batch_norm_flag: True})
def build_common_network(inputs):
    """Unrolls a two-layer LSTM stack (64 then 32 units) over `inputs`.

    :param inputs: [Time, Batch, state_size]; the batch size is read from
        the static shape (axis 1).
    :return: [Time, Batch, hidden_size] output sequence; the final RNN
        state is discarded.
    """
    batch_size = inputs.get_shape().as_list()[1]

    # Build the recurrent core.
    rnn = snt.DeepRNN([
        snt.LSTM(64, name='rnn_first'),
        snt.LSTM(32, name='rnn_second'),
    ])
    start_state = rnn.initial_state(batch_size)

    # Unroll over the time-major sequence.
    output_sequence, _ = tf.nn.dynamic_rnn(
        rnn, inputs, initial_state=start_state, time_major=True)
    return output_sequence
def __init__(self, config, name='updater'):
    """Builds the conv/linear feature stack, optional LSTM and output heads.

    Args:
      config: mapping read for the keys 'upd_channel', 'upd_kernel',
        'upd_stride' (equal-length sequences, one entry per conv layer),
        'image_shape', 'upd_hidden', 'state_size' and 'latent_size'.
      name: module name.
    """
    super(Updater, self).__init__(name=name)
    assert len(config['upd_channel']) == len(config['upd_kernel'])
    assert len(config['upd_channel']) == len(config['upd_stride'])

    with self._enter_variable_scope(check_same_graph=False):
        self._grid = get_grid(config['image_shape'][:-1], name='grid')

        self._layers = []
        conv_specs = zip(config['upd_channel'], config['upd_kernel'],
                         config['upd_stride'])
        for idx, (channel, kernel, stride) in enumerate(conv_specs):
            self._layers.append(
                snt.Conv2D(channel, kernel, stride=stride,
                           name='conv_{}'.format(idx)))
            self._layers.append(
                partial(tf.nn.elu, name='conv_{}_elu'.format(idx)))

        # Mean over axes 1 and 2 before the fully-connected layers.
        self._layers.append(
            partial(tf.math.reduce_mean, axis=[1, 2],
                    name='global_avg_pool'))

        for idx, hidden in enumerate(config['upd_hidden']):
            self._layers.append(
                snt.Linear(hidden, name='linear_{}'.format(idx)))
            self._layers.append(
                partial(tf.nn.elu, name='linear_{}_elu'.format(idx)))

        # A falsy 'state_size' disables the recurrent part.
        state_size = config['state_size']
        self._lstm = snt.LSTM(state_size, name='lstm') if state_size else None

        latent_size = config['latent_size']
        self._linear_loc = snt.Linear(latent_size, name='linear_loc')
        self._linear_scale = snt.Linear(latent_size, name='linear_scale')
def testBatchNormVariables(self, use_peepholes, use_batch_norm_h,
                           use_batch_norm_x, use_batch_norm_c):
    """The cell exposes its flags and creates the expected variable count."""
    cell = snt.LSTM(hidden_size=3,
                    use_peepholes=use_peepholes,
                    use_batch_norm_h=use_batch_norm_h,
                    use_batch_norm_x=use_batch_norm_x,
                    use_batch_norm_c=use_batch_norm_c)

    # Need to connect the cell before it has variables.
    batch_size = 3
    inputs = tf.placeholder(tf.float32, shape=[batch_size, 3, 3])
    start_state = cell.initial_state(batch_size, tf.float32)
    tf.nn.dynamic_rnn(cell, inputs, initial_state=start_state)

    self.assertEqual(use_peepholes, cell.use_peepholes)
    self.assertEqual(use_batch_norm_h, cell.use_batch_norm_h)
    self.assertEqual(use_batch_norm_x, cell.use_batch_norm_x)
    self.assertEqual(use_batch_norm_c, cell.use_batch_norm_c)

    # Gate bias plus either two weights (batch norm on h or x) or one.
    expected = 3 if (use_batch_norm_h or use_batch_norm_x) else 2
    if use_peepholes:
        expected += 3
    if use_batch_norm_h:
        expected += 1  # gamma_h
    if use_batch_norm_x:
        expected += 1  # gamma_x
    if use_batch_norm_c:
        expected += 2  # gamma_c, beta_c

    self.assertEqual(len(cell.get_variables()), expected)
def testTraining(self, trainable_initial_state, max_unique_stats):
    """Test that everything trains OK, with or without trainable init. state."""
    hidden_size = 3
    batch_size = 3
    time_steps = 3

    cell = snt.LSTM(hidden_size=hidden_size,
                    use_batch_norm_h=True,
                    max_unique_stats=max_unique_stats)
    inputs = tf.constant(
        np.random.rand(batch_size, time_steps, 3), dtype=tf.float32)
    start_state = cell.initial_state(
        batch_size, tf.float32, trainable_initial_state)
    output, _ = tf.nn.dynamic_rnn(
        cell, inputs, initial_state=start_state, dtype=tf.float32)

    # Regress onto random targets; only a clean training step matters.
    targets = np.random.rand(batch_size, time_steps, hidden_size)
    loss = tf.reduce_mean(tf.square(output - targets))
    train_op = tf.train.GradientDescentOptimizer(1).minimize(loss)
    init = tf.global_variables_initializer()

    with self.test_session():
        init.run()
        train_op.run()
def testPartitioners(self, use_peepholes, use_batch_norm_h,
                     use_batch_norm_x, use_batch_norm_c):
    """An LSTM built with partitioners ends up with partitioned variables."""
    batch_size = 2
    hidden_size = 4
    state_shape = [batch_size, hidden_size]

    initializer_keys = snt.LSTM.get_possible_initializer_keys(
        use_peepholes, use_batch_norm_h, use_batch_norm_x, use_batch_norm_c)
    partitioners = {}
    for key in initializer_keys:
        partitioners[key] = tf.variable_axis_size_partitioner(10)

    # Test we can successfully create the LSTM with partitioners.
    lstm = snt.LSTM(hidden_size,
                    use_peepholes=use_peepholes,
                    use_batch_norm_h=use_batch_norm_h,
                    use_batch_norm_x=use_batch_norm_x,
                    use_batch_norm_c=use_batch_norm_c,
                    partitioners=partitioners)

    # Test we can build the LSTM.
    inputs = tf.placeholder(tf.float32, shape=state_shape)
    prev_cell = tf.placeholder(tf.float32, shape=state_shape)
    prev_hidden = tf.placeholder(tf.float32, shape=state_shape)
    lstm(inputs, (prev_hidden, prev_cell))

    # Test that the variables are partitioned.
    for var_name in _get_lstm_variable_names(lstm):
        variable = getattr(lstm, "_" + var_name)
        self.assertEqual(type(variable), variables.PartitionedVariable)
def testRegularizers(self, use_peepholes, use_batch_norm_h,
                     use_batch_norm_x, use_batch_norm_c):
    """An LSTM built with regularizers registers regularization losses."""
    batch_size = 2
    hidden_size = 4
    state_shape = [batch_size, hidden_size]

    keys = snt.LSTM.get_possible_initializer_keys(
        use_peepholes, use_batch_norm_h, use_batch_norm_x, use_batch_norm_c)
    regularizers = dict.fromkeys(keys, tf.nn.l2_loss)

    # Test we can successfully create the LSTM with regularizers.
    lstm = snt.LSTM(hidden_size,
                    use_peepholes=use_peepholes,
                    use_batch_norm_h=use_batch_norm_h,
                    use_batch_norm_x=use_batch_norm_x,
                    use_batch_norm_c=use_batch_norm_c,
                    regularizers=regularizers)

    # Test we can build the LSTM.
    inputs = tf.placeholder(tf.float32, shape=state_shape)
    prev_cell = tf.placeholder(tf.float32, shape=state_shape)
    prev_hidden = tf.placeholder(tf.float32, shape=state_shape)
    lstm(inputs, (prev_hidden, prev_cell))

    # Test that we have regularization losses; one extra loss is expected
    # when batch norm is applied to h or x.
    num_reg_losses = len(
        tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    extra_losses = 1 if (use_batch_norm_h or use_batch_norm_x) else 0
    self.assertEqual(num_reg_losses, len(keys) + extra_losses)
def __init__(self, obs, nums, glimpse_size=(20, 20),
             inpt_encoder_hidden=[256]*2,
             glimpse_encoder_hidden=[256]*2,
             glimpse_decoder_hidden=[252]*2,
             transform_estimator_hidden=[256]*2,
             steps_pred_hidden=[50]*1,
             baseline_hidden=[256, 128]*1,
             transform_var_bias=-2.,
             step_bias=0.,
             *args, **kwargs):
    """Configures the parent model with MNIST-specific component factories.

    The `*_hidden` arguments are forwarded to the respective component
    constructors (`Encoder`, `Decoder`, `StochasticTransformParam`,
    `StepsPredictor`, `BaselineMLP`). Their defaults are list objects
    shared between calls and are never mutated in this method.

    Args:
      obs: forwarded to the parent constructor as `obs`.
      nums: forwarded to the parent constructor as `nums`.
      glimpse_size: forwarded to the parent constructor.
      inpt_encoder_hidden: argument for the input `Encoder`.
      glimpse_encoder_hidden: argument for the glimpse `Encoder`.
      glimpse_decoder_hidden: first argument for the glimpse `Decoder`.
      transform_estimator_hidden: first argument for
        `StochasticTransformParam`.
      steps_pred_hidden: first argument for `StepsPredictor`.
      baseline_hidden: argument for `BaselineMLP`.
      transform_var_bias: passed as `scale_bias` to
        `StochasticTransformParam`.
      step_bias: second argument for `StepsPredictor`.
      *args: extra positional arguments for the parent constructor.
      **kwargs: extra keyword arguments for the parent constructor.
    """
    # The baseline is created before the parent constructor runs.
    self.baseline = BaselineMLP(baseline_hidden)

    def _make_transform_estimator(x):
        return StochasticTransformParam(
            transform_estimator_hidden, x, scale_bias=transform_var_bias)

    super(AIRonMNIST, self).__init__(
        *args,
        obs=obs,
        nums=nums,
        glimpse_size=glimpse_size,
        n_appearance=50,
        transition=snt.LSTM(256),
        input_encoder=(lambda: Encoder(inpt_encoder_hidden)),
        glimpse_encoder=(lambda: Encoder(glimpse_encoder_hidden)),
        glimpse_decoder=(lambda x: Decoder(glimpse_decoder_hidden, x)),
        transform_estimator=_make_transform_estimator,
        steps_predictor=(
            lambda: StepsPredictor(steps_pred_hidden, step_bias)),
        output_std=.3,
        **kwargs
    )
def __init__(self, hidden_size, num_res_blocks):
    """Creates an LSTM core and the residual network that precedes it.

    Args:
      hidden_size: size passed to both `snt.LSTM` and `resnet`.
      num_res_blocks: first argument of `resnet`.
    """
    super().__init__(name='LSTM')
    self._hidden_size = hidden_size
    self._lstm = snt.LSTM(hidden_size)
    # A resnet is used before the LSTM.
    self._resnet = resnet(num_res_blocks, hidden_size)
def testCellClipping(self):
    """Feeding a pre-clipped cell state gives the same result as unclipped."""
    core = snt.LSTM(hidden_size=5, cell_clip_value=1.0)
    obs = tf.constant(np.random.rand(3, 10), dtype=tf.float32)
    hidden = tf.placeholder(tf.float32, shape=[3, 5])
    cell = tf.placeholder(tf.float32, shape=[3, 5])
    output = core(obs, [hidden, cell])

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())

        # Rescale so the largest entry is 2.0, i.e. outside the clip range.
        raw_state = np.random.rand(3, 5) - 0.5
        raw_state *= 2.0 / raw_state.max()
        clipped_state = raw_state.clip(-1., 1.)

        def run_with_cell(cell_value):
            return sess.run(
                output, feed_dict={hidden: raw_state, cell: cell_value})

        output1, (hidden1, cell1) = run_with_cell(raw_state)
        output2, (hidden2, cell2) = run_with_cell(clipped_state)

        self.assertAllClose(output1, output2)
        self.assertAllClose(hidden1, hidden2)
        self.assertAllClose(cell1, cell2)
def build_module(inputs, input_len):
    """Builds the deep LSTM model and returns the output-layer logits.

    `hidden_size`, `n_layers`, `n_classes` and the custom getters come from
    the enclosing scope.

    Args:
      inputs: batch-major input sequence passed to `tf.nn.dynamic_rnn`.
      input_len: per-example sequence lengths (`sequence_length`).

    Returns:
      Logits produced by the output linear layer, with the batch and time
      dimensions of the RNN output flattened together.
    """
    cores = [
        snt.LSTM(hidden_size,
                 custom_getter=lstm_bbb_custom_getter,
                 forget_bias=0.0,
                 name="lstm_layer_{}".format(i))
        for i in six.moves.range(n_layers)
    ]
    rnn_core = snt.DeepRNN(
        cores, skip_connections=True, name="deep_lstm_core")

    # Do BBB on weights but not biases of output layer.
    output_linear = Linear(
        n_classes, custom_getter={"w": non_lstm_bbb_custom_getter})

    # Unroll the RNN core over the sequence. No initial state is supplied
    # and the final state is not persisted between unrolls.
    rnn_output_seq, _ = tf.nn.dynamic_rnn(
        cell=rnn_core,
        inputs=inputs,
        sequence_length=input_len,
        dtype=tf.float32)

    # With skip connections the core emits the concatenated outputs of all
    # layers, so the feature width is n_layers * hidden_size. Deriving it
    # keeps the reshape correct for every configuration (it was previously
    # hard-coded to 2 * 512).
    rnn_output_width = n_layers * hidden_size
    rnn_output_seq = tf.reshape(rnn_output_seq, [-1, rnn_output_width])

    output_logits = output_linear(rnn_output_seq)
    return output_logits
def test_rnn_snapshot(self):
    """Test that snapshotter correctly calls saves/restores snapshots on RNNs."""
    # Create a test network.
    net = snt.LSTM(10)
    spec = specs.Array([10], dtype=np.float32)
    tf2_utils.create_variables(net, [spec])

    # Test that if you add some postprocessing without rerunning
    # create_variables, it still works.
    wrapped_net = snt.DeepRNN([net, lambda x: x])

    for original in (net, wrapped_net):
        # Save the test network.
        snapshotter = tf2_savers.Snapshotter(
            {'net': original}, directory=self.get_tempdir())
        snapshotter.save()

        # Reload the test network.
        restored = tf.saved_model.load(
            os.path.join(snapshotter.directory, 'net'))

        inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec))

        def forward_and_grads(module):
            # One step from the initial state, plus gradients of the
            # summed output w.r.t. the module's trainable variables.
            with tf.GradientTape() as tape:
                outputs, next_state = module(
                    inputs, module.initial_state(1))
                loss = tf.math.reduce_sum(outputs)
                grads = tape.gradient(loss, module.trainable_variables)
            return outputs, next_state, grads

        outputs1, next_state1, grads1 = forward_and_grads(original)
        outputs2, next_state2, grads2 = forward_and_grads(restored)

        assert np.allclose(outputs1, outputs2)
        assert np.allclose(tree.flatten(next_state1),
                           tree.flatten(next_state2))
        assert all(tree.map_structure(np.allclose,
                                      list(grads1), list(grads2)))
def testVariables(self):
    """A plain LSTM has exactly a gate bias and a gate weight matrix."""
    batch_size = 5
    hidden_size = 20
    mod_name = "rnn"
    state_shape = [batch_size, hidden_size]

    inputs = tf.placeholder(tf.float32, shape=state_shape)
    prev_cell = tf.placeholder(tf.float32, shape=state_shape)
    prev_hidden = tf.placeholder(tf.float32, shape=state_shape)

    lstm = snt.LSTM(hidden_size, name=mod_name)
    self.assertEqual(lstm.scope_name, mod_name)

    # Variables are unavailable until the module has been connected.
    with self.assertRaisesRegexp(snt.Error, "not instantiated yet"):
        lstm.get_variables()
    lstm(inputs, (prev_hidden, prev_cell))

    lstm_variables = lstm.get_variables()
    self.assertEqual(len(lstm_variables), 2, "LSTM should have 2 variables")

    # Key each variable by its bare name (no scope prefix, no ":0").
    param_map = {}
    for param in lstm_variables:
        bare_name = param.name.split("/")[-1].split(":")[0]
        param_map[bare_name] = param

    self.assertShapeEqual(np.ndarray(4 * hidden_size),
                          param_map[snt.LSTM.B_GATES].initial_value)
    self.assertShapeEqual(np.ndarray((2 * hidden_size, 4 * hidden_size)),
                          param_map[snt.LSTM.W_GATES].initial_value)
def __init__(self, controller_config, memory_config, output_size,
             classic_dnc_output=False, clip_value=20, name="dnc"):
    """Creates the controller, the memory and the projection layers.

    Args:
      controller_config: keyword arguments for the `snt.LSTM` controller;
        `cell_clip_value` is supplied separately from `clip_value`.
      memory_config: keyword arguments for `Memory`; must contain
        "read_heads_num" and "word_size".
      output_size: size of the output projections.
      classic_dnc_output: when false, separate controller-to-output and
        memory-to-output linear layers are created.
      clip_value: stored on the instance and used as the controller's
        `cell_clip_value`.
      name: module name.
    """
    super().__init__(name=name)
    self._output_size = output_size
    self._clip_value = clip_value
    self._classic_dnc_output = classic_dnc_output

    read_heads = memory_config["read_heads_num"]
    word_size = memory_config["word_size"]
    self._R = read_heads
    self._W = word_size
    # Interface vector width: R*W + 3*W + 5*R + 3.
    self._interface_vector_size = (
        read_heads * word_size + 3 * word_size + 5 * read_heads + 3)

    with self._enter_variable_scope():
        self._controller = snt.LSTM(
            **controller_config, cell_clip_value=clip_value)
        self._memory = Memory(**memory_config)
        self._controller_to_interface_weights = snt.Linear(
            self._interface_vector_size, name='controller_to_interface')

        if not self._classic_dnc_output:
            self._controller_to_output_weights = snt.Linear(
                self._output_size, name="controller_to_output")
            self._memory_to_output_weights = snt.Linear(
                self._output_size, name="memory_to_output")
def testSameInStaticAndDynamic(self):
    """Static and dynamic unrolls of one cell produce the same outputs."""
    batch_size = 3
    seq_len = 2
    hidden_size = 3
    input_size = 3
    input_shape = [batch_size, seq_len, input_size]

    inputs = tf.placeholder(tf.float32, shape=input_shape, name="inputs")
    # The static unroll takes one tensor per time step.
    static_inputs = tf.unstack(inputs, axis=1)

    cell = snt.LSTM(hidden_size=hidden_size)

    static_output_unpacked, _ = tf.contrib.rnn.static_rnn(
        cell,
        static_inputs,
        initial_state=cell.initial_state(batch_size, tf.float32))
    dynamic_output, _ = tf.nn.dynamic_rnn(
        cell,
        inputs,
        initial_state=cell.initial_state(batch_size, tf.float32),
        dtype=tf.float32)

    # Re-stack the per-step outputs along the same axis they were split on.
    static_output = tf.stack(static_output_unpacked, axis=1)

    with self.test_session() as session:
        tf.global_variables_initializer().run()

        # Check that static and dynamic give the same output.
        input_data = np.random.rand(*input_shape)
        static_out, dynamic_out = session.run(
            [static_output, dynamic_output],
            feed_dict={inputs: input_data})
        self.assertAllClose(static_out, dynamic_out)
def testPeephole(self):
    """Checks peephole variable shapes and the step against a NumPy model."""
    batch_size = 5
    hidden_size = 20
    state_shape = [batch_size, hidden_size]

    # Initialize the rnn and verify the number of parameter sets.
    inputs = tf.placeholder(tf.float32, shape=state_shape)
    prev_cell = tf.placeholder(tf.float32, shape=state_shape)
    prev_hidden = tf.placeholder(tf.float32, shape=state_shape)
    lstm = snt.LSTM(hidden_size, use_peepholes=True)
    _, (next_hidden, next_cell) = lstm(inputs, (prev_hidden, prev_cell))
    lstm_variables = lstm.get_variables()
    self.assertEqual(len(lstm_variables), 5, "LSTM should have 5 variables")

    # Unpack parameters into dict and check their sizes.
    param_map = {}
    for param in lstm_variables:
        bare_name = param.name.split("/")[-1].split(":")[0]
        param_map[bare_name] = param

    expected_shapes = (
        (snt.LSTM.B_GATES, 4 * hidden_size),
        (snt.LSTM.W_GATES, (2 * hidden_size, 4 * hidden_size)),
        (snt.LSTM.W_F_DIAG, hidden_size),
        (snt.LSTM.W_I_DIAG, hidden_size),
        (snt.LSTM.W_O_DIAG, hidden_size),
    )
    for key, shape in expected_shapes:
        self.assertShapeEqual(np.ndarray(shape),
                              param_map[key].initial_value)

    # With random data, check the TF calculation matches the Numpy version.
    input_data = np.random.randn(batch_size, hidden_size)
    prev_hidden_data = np.random.randn(batch_size, hidden_size)
    prev_cell_data = np.random.randn(batch_size, hidden_size)

    with self.test_session() as session:
        tf.global_variables_initializer().run()
        fetches = [(next_hidden, next_cell),
                   param_map[snt.LSTM.W_GATES],
                   param_map[snt.LSTM.B_GATES],
                   param_map[snt.LSTM.W_F_DIAG],
                   param_map[snt.LSTM.W_I_DIAG],
                   param_map[snt.LSTM.W_O_DIAG]]
        feed = {inputs: input_data,
                prev_cell: prev_cell_data,
                prev_hidden: prev_hidden_data}
        output = session.run(fetches, feed)

    (hidden_tf, cell_tf), w_ex, b_ex, wfd_ex, wid_ex, wod_ex = output

    gate_input = np.concatenate((input_data, prev_hidden_data), axis=1)
    real_gate = np.dot(gate_input, w_ex) + b_ex
    # i = input_gate, j = next_input, f = forget_gate, o = output_gate
    i, j, f, o = np.hsplit(real_gate, 4)

    # Sigmoid denominators; the forget and input gates get a peephole term
    # from the previous cell state.
    forget_denominator = 1 + np.exp(
        -(f + lstm._forget_bias + wfd_ex * prev_cell_data))
    input_denominator = 1 + np.exp(-(i + wid_ex * prev_cell_data))
    real_cell = (prev_cell_data / forget_denominator +
                 1 / input_denominator * np.tanh(j))
    # The output peephole term is added to the cell before the tanh.
    real_hidden = (np.tanh(real_cell + wod_ex * real_cell) *
                   1 / (1 + np.exp(-o)))

    self.assertAllClose(real_hidden, hidden_tf)
    self.assertAllClose(real_cell, cell_tf)
def compute_dynamic_rnn(inputs, config, sequence_length=None):
    """Optionally runs an LSTM over `inputs` and returns its last hidden state.

    Args:
      inputs: batch-major input sequence (`time_major=False`).
      config: mapping with "lstm_hidden_size"; "use_layer_norm" is also
        read when the LSTM is enabled.
      sequence_length: optional lengths with a singleton axis 1, which is
        squeezed away before being given to `tf.nn.dynamic_rnn`.

    Returns:
      The `hidden` field of the final LSTM state when
      `config["lstm_hidden_size"] > 0`; otherwise `inputs` unchanged.
    """
    if not config["lstm_hidden_size"] > 0:
        # No LSTM requested: pass the inputs straight through.
        return inputs

    # Identity check rather than `!= None`: comparing an array or tensor
    # with `!=` is evaluated elementwise and has no usable truth value.
    if sequence_length is not None:
        sequence_length = tf.squeeze(sequence_length, axis=1)

    # TODO: bidirectional LSTM?
    lstm_cell = snt.LSTM(hidden_size=config["lstm_hidden_size"],
                         use_layer_norm=config["use_layer_norm"])

    # TODO: optional initial state built from the vision value.
    _, final_state = tf.nn.dynamic_rnn(
        lstm_cell,
        inputs=inputs,
        dtype=tf.float32,
        time_major=False,
        # If available, reduces time and complexity.
        sequence_length=sequence_length,
    )
    return final_state.hidden
def get_rnn_core(cfg):
    """Get the Sonnet rnn cell from the given config.

    Args:
      cfg: config generated from `sample_rnn_core`, a `(name, args)` pair
        where `name` is one of "lstm", "gru" or "vrnn".

    Returns:
      A Sonnet module with the given config.

    Raises:
      ValueError: if `name` is not a known core.
    """
    name, args = cfg

    if name == "lstm":
        lstm_init = {"w_gates": get_initializer(args["w_gates"])}
        return snt.LSTM(args["core_dim"], initializers=lstm_init)

    if name == "gru":
        gru_keys = ["wh", "wz", "wr", "uh", "uz", "ur"]
        gru_init = {key: get_initializer(args[key]) for key in gru_keys}
        return snt.GRU(args["core_dim"], initializers=gru_init)

    if name == "vrnn":
        vrnn_init = {
            "in_to_hidden": {"w": get_initializer(args["in_to_hidden"])},
            "hidden_to_hidden": {
                "w": get_initializer(args["hidden_to_hidden"])
            },
        }
        activation = get_activation(args["act_fn"])
        return snt.VanillaRNN(
            args["core_dim"], initializers=vrnn_init, activation=activation)

    raise ValueError("No core for name [%s] found." % name)
def _make_network(action_spec: specs.DiscreteArray) -> snt.RNNCore:
    """Builds a flatten -> LSTM(20) -> MLP(50, 50) -> policy/value network.

    Args:
      action_spec: its `num_values` is passed to the policy/value head.

    Returns:
      A `snt.DeepRNN` wrapping the four layers.
    """
    layers = [
        snt.Flatten(),
        snt.LSTM(20),
        snt.nets.MLP([50, 50]),
        networks.PolicyValueHead(action_spec.num_values),
    ]
    return snt.DeepRNN(layers)
def __init__(self, hidden_sizes, name="lstm"):
    """Creates one `snt.LSTM` per entry of `hidden_sizes`.

    Args:
      hidden_sizes: iterable of hidden sizes, one per layer, in order.
      name: module name.
    """
    super().__init__(name=name)
    self._hidden_sizes = hidden_sizes
    with self._enter_variable_scope():
        layers = []
        for size in self._hidden_sizes:
            layers.append(snt.LSTM(hidden_size=size))
        self._lstm_layers = layers
def __init__(self, action_spec: specs.DiscreteArray):
    """Builds a flatten -> LSTM(20) -> MLP network.

    Args:
      action_spec: its `num_values` is the width of the MLP's last layer.
    """
    super().__init__(name='r2d2_test_network')
    mlp_sizes = [50, 50, action_spec.num_values]
    layers = [
        snt.Flatten(),
        snt.LSTM(20),
        snt.nets.MLP(mlp_sizes),
    ]
    self._net = snt.DeepRNN(layers)
def __init__(self, access_config, controller_config, output_size,
             clip_value=None, name='dnc'):
    """Initializes the DNC core.

    Args:
      access_config: dictionary of access module configurations.
      controller_config: dictionary of controller (LSTM) module
        configurations.
      output_size: output dimension size of core.
      clip_value: clips controller and core output values to between
        `[-clip_value, clip_value]` if specified. `None` is stored as 0.
      name: module name (default 'dnc').

    Raises:
      TypeError: per the original documentation, if direct_input_size is
        not None for any access module other than KeyValueMemory. The
        check is not in this method -- TODO confirm where it is raised.
    """
    super(DNC, self).__init__(name=name)

    with self._enter_variable_scope():
        self._controller = snt.LSTM(**controller_config)
        self._access = access.MemoryAccess(**access_config)

    # Flattened size of the access module's output.
    self._access_output_size = np.prod(self._access.output_size.as_list())
    self._clip_value = clip_value or 0

    # Stored once, directly as a TensorShape (the earlier plain-int
    # assignment was immediately overwritten and has been removed).
    self._output_size = tf.TensorShape([output_size])
    self._state_size = DNCState(
        access_output=self._access_output_size,
        access_state=self._access.state_size,
        controller_state=self._controller.state_size)
def __init__(self, output_size, layers, preprocess_name="identity",
             preprocess_options=None, scale=1.0, initializer=None,
             name="deep_lstm", tanh_output=False, num_linear_heads=1):
    """Creates an instance of `StandardDeepLSTM`.

    Args:
      output_size: Output sizes of the final linear layer.
      layers: Output sizes of LSTM layers.
      preprocess_name: Gradient preprocessing class name (in
        `l2l.preprocess` or tf modules). Default is `tf.identity`. The
        special value 'fc' creates a linear input projection instead; no
        `_preprocess` attribute is set in that case.
      preprocess_options: Gradient preprocessing options. Must contain
        "dim" when `preprocess_name` is 'fc'. `None` is treated as no
        options for preprocessing classes.
      scale: Gradient scaling (default is 1.0).
      initializer: Variable initializer for linear layer. See `snt.Linear`
        and `snt.LSTM` docs for more info. This parameter can be a string
        (e.g. "zeros" will be converted to tf.zeros_initializer).
      name: Module name.
      tanh_output: Stored as `self.tanh_output`; not otherwise used here.
      num_linear_heads: Number of output linear layers, at least 1. With
        one head it is stored as `self._linear`; otherwise all heads are
        stored in the list `self._linears`.
    """
    super(StandardDeepLSTM, self).__init__(name=name)

    self._output_size = output_size
    self._scale = scale
    self._num_linear_heads = num_linear_heads
    assert self._num_linear_heads >= 1

    if preprocess_name == 'fc':
        with tf.variable_scope(self._template.variable_scope):
            init = _get_layer_initializers(
                initializer, "input_projection", ("w", "b"))
            print("initialization of input projection !!!!!!!!!!")
            print(init)
            self._linear_input = snt.Linear(
                preprocess_options["dim"],
                name="input_projection",
                initializers=init)
    elif hasattr(preprocess, preprocess_name):
        preprocess_class = getattr(preprocess, preprocess_name)
        # `**None` would raise TypeError, so fall back to no options.
        self._preprocess = preprocess_class(
            initializer, **(preprocess_options or {}))
    else:
        self._preprocess = getattr(tf, preprocess_name)
    self._preprocess_name = preprocess_name

    with tf.variable_scope(self._template.variable_scope):
        self._cores = []
        for i, size in enumerate(layers, start=1):
            # Distinct local name so the `name` parameter is not shadowed.
            core_name = "lstm_{}".format(i)
            init = _get_layer_initializers(
                initializer, core_name, ("w_gates", "b_gates"))
            self._cores.append(
                snt.LSTM(size, name=core_name, initializers=init))
        self._rnn = snt.DeepRNN(
            self._cores, skip_connections=False, name="deep_rnn")

        # Heads are named "linear", "linear_1", "linear_2", ...
        head_names = ["linear"]
        head_names += [
            "linear_{}".format(i) for i in range(1, num_linear_heads)
        ]
        heads = []
        for head_name in head_names:
            init = _get_layer_initializers(
                initializer, head_name, ("w", "b"))
            heads.append(
                snt.Linear(output_size, name=head_name, initializers=init))

        if self._num_linear_heads == 1:
            self._linear = heads[0]
        else:
            self._linears = heads

    self.init_flag = False
    self.tanh_output = tanh_output