def testRegularizers(self):
  inputs = tf.placeholder(tf.float32, shape=[self.batch_size, self.in_size])
  prev_state = tf.placeholder(tf.float32,
                              shape=[self.batch_size, self.hidden_size])
  with self.assertRaisesRegexp(KeyError, "Invalid regularizer keys.*"):
    snt.VanillaRNN(name="rnn",
                   hidden_size=self.hidden_size,
                   regularizers={"invalid": None})

  err = "Regularizer for 'w' is not a callable function"
  with self.assertRaisesRegexp(TypeError, err):
    snt.VanillaRNN(name="rnn",
                   hidden_size=self.hidden_size,
                   regularizers={"in_to_hidden": {"w": tf.zeros([10, 10])}})

  # Nested regularizers.
  valid_regularizers = {
      "in_to_hidden": {
          "w": tf.nn.l2_loss,
      },
      "hidden_to_hidden": {
          "b": tf.nn.l2_loss,
      }
  }

  vanilla_rnn = snt.VanillaRNN(name="rnn",
                               hidden_size=self.hidden_size,
                               regularizers=valid_regularizers)
  vanilla_rnn(inputs, prev_state)
  regularizers = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
  self.assertEqual(len(regularizers), 2)
def testVariables(self):
  batch_size = 3
  in_size = 2
  hidden1_size = 4
  hidden2_size = 5
  mod_name = "deep_rnn"
  inputs = tf.placeholder(tf.float32, shape=[batch_size, in_size])
  prev_state1 = tf.placeholder(tf.float32, shape=[batch_size, hidden1_size])
  # The second core has hidden2_size units, so its state must match.
  prev_state2 = tf.placeholder(tf.float32, shape=[batch_size, hidden2_size])
  prev_state = (prev_state1, prev_state2)
  cores = [
      snt.VanillaRNN(name="rnn1", hidden_size=hidden1_size),
      snt.VanillaRNN(name="rnn2", hidden_size=hidden2_size)
  ]
  deep_rnn = snt.DeepRNN(cores, name=mod_name)
  self.assertEqual(deep_rnn.scope_name, mod_name)
  with self.assertRaisesRegexp(snt.Error, "not instantiated yet"):
    deep_rnn.get_variables()
  deep_rnn(inputs, prev_state)

  # No variables are exposed by the DeepRNN itself.
  self.assertEqual(deep_rnn.get_variables(), ())

  # Have to retrieve the variables from the cores individually.
  deep_rnn_variables = tuple(
      itertools.chain.from_iterable([c.get_variables() for c in cores]))
  self.assertEqual(len(deep_rnn_variables), 4 * len(cores),
                   "Cores should have %d variables" % (4 * len(cores)))
  for v in deep_rnn_variables:
    self.assertRegexpMatches(
        v.name, "rnn(1|2)/(in_to_hidden|hidden_to_hidden)/(w|b):0")
def testShape(self):
  batch_size = 3
  batch_size_shape = tf.TensorShape(batch_size)
  in_size = 2
  hidden1_size = 4
  hidden2_size = 5
  inputs = tf.placeholder(tf.float32, shape=[batch_size, in_size])
  prev_state0 = tf.placeholder(tf.float32, shape=[batch_size, in_size])
  prev_state1 = tf.placeholder(tf.float32, shape=[batch_size, hidden1_size])
  prev_state2 = tf.placeholder(tf.float32, shape=[batch_size, hidden2_size])
  prev_state = (prev_state0, prev_state1, prev_state2)

  # Test recurrent and non-recurrent cores.
  cores = [
      snt.VanillaRNN(name="rnn0", hidden_size=in_size),
      snt.VanillaRNN(name="rnn1", hidden_size=hidden1_size),
      snt.VanillaRNN(name="rnn2", hidden_size=hidden2_size)
  ]
  deep_rnn = snt.DeepRNN(cores, name="deep_rnn", skip_connections=True)
  output, next_state = deep_rnn(inputs, prev_state)
  output_shape = output.get_shape()
  output_size = in_size + hidden1_size + hidden2_size

  self.assertTrue(
      output_shape.is_compatible_with([batch_size, output_size]))
  self.assertTrue(
      output_shape.is_compatible_with(
          batch_size_shape.concatenate(deep_rnn.output_size)))

  next_state_shape = (next_state[0].get_shape(),
                      next_state[1].get_shape(),
                      next_state[2].get_shape())
  self.assertTrue(
      next_state_shape[0].is_compatible_with([batch_size, in_size]))
  self.assertTrue(
      next_state_shape[1].is_compatible_with([batch_size, hidden1_size]))
  self.assertTrue(
      next_state_shape[2].is_compatible_with([batch_size, hidden2_size]))

  for state_shape, expected_shape in zip(next_state_shape,
                                         deep_rnn.state_size):
    self.assertTrue(
        state_shape.is_compatible_with(
            batch_size_shape.concatenate(expected_shape)))

  # The initial state should be a valid state.
  initial_state = deep_rnn.initial_state(batch_size, tf.float32)
  self.assertEqual(len(initial_state), len(next_state))
  self.assertShapeEqual(np.ndarray((batch_size, in_size)),
                        initial_state[0])
  self.assertShapeEqual(np.ndarray((batch_size, hidden1_size)),
                        initial_state[1])
  self.assertShapeEqual(np.ndarray((batch_size, hidden2_size)),
                        initial_state[2])
def testComputation(self, skip_connections, create_initial_state):
  batch_size = 3
  in_size = 2
  hidden1_size = 4
  hidden2_size = 5
  mod_name = "deep_rnn"
  cores = [
      snt.VanillaRNN(name="rnn1", hidden_size=hidden1_size),
      snt.VanillaRNN(name="rnn2", hidden_size=hidden2_size)
  ]
  deep_rnn = snt.DeepRNN(cores, name=mod_name,
                         skip_connections=skip_connections)
  inputs = tf.constant(np.random.randn(batch_size, in_size),
                       dtype=tf.float32)
  if create_initial_state:
    prev_state = deep_rnn.initial_state(batch_size, tf.float32)
  else:
    prev_state1 = tf.constant(np.random.randn(batch_size, hidden1_size),
                              dtype=tf.float32)
    prev_state2 = tf.constant(np.random.randn(batch_size, hidden2_size),
                              dtype=tf.float32)
    prev_state = (prev_state1, prev_state2)

  output, next_state = deep_rnn(inputs, prev_state)

  # With random data, check the DeepRNN calculation matches the manual
  # stacking version.
  self.evaluate(tf.global_variables_initializer())
  outputs_value = self.evaluate([output, next_state[0], next_state[1]])
  output_value, next_state1_value, next_state2_value = outputs_value

  # Build the manual computation graph.
  output1, next_state1 = cores[0](inputs, prev_state[0])
  if skip_connections:
    input2 = tf.concat([inputs, output1], 1)
  else:
    input2 = output1
  output2, next_state2 = cores[1](input2, prev_state[1])
  if skip_connections:
    manual_output = tf.concat([output1, output2], 1)
  else:
    manual_output = output2

  manual_outputs_value = self.evaluate(
      [manual_output, next_state1, next_state2])
  manual_output_value = manual_outputs_value[0]
  manual_next_state1_value = manual_outputs_value[1]
  manual_next_state2_value = manual_outputs_value[2]

  self.assertAllClose(output_value, manual_output_value)
  self.assertAllClose(next_state1_value, manual_next_state1_value)
  self.assertAllClose(next_state2_value, manual_next_state2_value)
def testInitialState(self, trainable, use_custom_initial_value):
  batch_size = 3
  hidden1_size = 4
  hidden2_size = 5
  output1_size = 6
  output2_size = 7
  initializer = None
  if use_custom_initial_value:
    initializer = [tf.constant_initializer(8),
                   tf.constant_initializer(9)]

  # Test that the initial state of a non-recurrent DeepRNN is an empty list.
  non_recurrent_cores = [
      snt.Linear(output_size=output1_size),
      snt.Linear(output_size=output2_size)
  ]
  dummy_deep_rnn = snt.DeepRNN(non_recurrent_cores, skip_connections=False)
  dummy_initial_state = dummy_deep_rnn.initial_state(batch_size,
                                                     trainable=trainable)
  self.assertFalse(dummy_initial_state)

  # Test that the initial state of a recurrent DeepRNN is the same as
  # calling all cores' initial_state method.
  cores = [
      snt.VanillaRNN(hidden_size=hidden1_size),
      snt.VanillaRNN(hidden_size=hidden2_size)
  ]
  deep_rnn = snt.DeepRNN(cores)
  initial_state = deep_rnn.initial_state(
      batch_size, trainable=trainable, trainable_initializers=initializer)

  expected_initial_state = []
  for i, core in enumerate(cores):
    with tf.variable_scope("core-%d" % i):
      expected_initializer = None
      if initializer:
        expected_initializer = initializer[i]
      expected_initial_state.append(
          core.initial_state(batch_size,
                             trainable=trainable,
                             trainable_initializers=expected_initializer))

  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    initial_state_value = sess.run(initial_state)
    expected_initial_state_value = sess.run(expected_initial_state)

  for expected_value, actual_value in zip(expected_initial_state_value,
                                          initial_state_value):
    self.assertAllEqual(actual_value, expected_value)
def testComputation(self, skip_connections, create_initial_state):
  batch_size = 3
  in_size = 2
  hidden1_size = 4
  hidden2_size = 5
  mod_name = "deep_rnn"
  cores = [snt.VanillaRNN(name="rnn1", hidden_size=hidden1_size),
           snt.VanillaRNN(name="rnn2", hidden_size=hidden2_size)]
  deep_rnn = snt.DeepRNN(cores, name=mod_name,
                         skip_connections=skip_connections)
  inputs = tf.placeholder(tf.float32, shape=[batch_size, in_size])
  if create_initial_state:
    prev_state = deep_rnn.initial_state(batch_size, tf.float32)
  else:
    prev_state1 = tf.placeholder(
        tf.float32, shape=[batch_size, hidden1_size])
    prev_state2 = tf.placeholder(
        tf.float32, shape=[batch_size, hidden2_size])
    prev_state = (prev_state1, prev_state2)

  output, next_state = deep_rnn(inputs, prev_state)
  with self.test_session() as sess:
    # With random data, check the DeepRNN calculation matches the manual
    # stacking version.
    input_data = np.random.randn(batch_size, in_size)
    feed_dict = {inputs: input_data}
    if not create_initial_state:
      feed_dict[prev_state1] = np.random.randn(batch_size, hidden1_size)
      feed_dict[prev_state2] = np.random.randn(batch_size, hidden2_size)

    tf.global_variables_initializer().run()
    outputs_value = sess.run([output, next_state[0], next_state[1]],
                             feed_dict=feed_dict)
    output_value, next_state1_value, next_state2_value = outputs_value

    # Build the manual computation graph.
    output1, next_state1 = cores[0](inputs, prev_state[0])
    if skip_connections:
      input2 = tf.concat([inputs, output1], 1)
    else:
      input2 = output1
    output2, next_state2 = cores[1](input2, prev_state[1])
    if skip_connections:
      manual_output = tf.concat([output1, output2], 1)
    else:
      manual_output = output2

    manual_outputs_value = sess.run(
        [manual_output, next_state1, next_state2], feed_dict=feed_dict)
    manual_output_value = manual_outputs_value[0]
    manual_next_state1_value = manual_outputs_value[1]
    manual_next_state2_value = manual_outputs_value[2]

  self.assertAllClose(output_value, manual_output_value)
  self.assertAllClose(next_state1_value, manual_next_state1_value)
  self.assertAllClose(next_state2_value, manual_next_state2_value)
def testPartitioners(self):
  if tf.executing_eagerly():
    self.skipTest("Partitioned variables are not supported in eager mode.")

  inputs = tf.ones(dtype=tf.float32, shape=[self.batch_size, self.in_size])
  prev_state = tf.ones(dtype=tf.float32,
                       shape=[self.batch_size, self.hidden_size])

  with self.assertRaisesRegexp(KeyError, "Invalid partitioner keys.*"):
    snt.VanillaRNN(name="rnn",
                   hidden_size=self.hidden_size,
                   partitioners={"invalid": None})

  err = "Partitioner for 'w' is not a callable function"
  with self.assertRaisesRegexp(TypeError, err):
    snt.VanillaRNN(name="rnn",
                   hidden_size=self.hidden_size,
                   partitioners={"in_to_hidden": {"w": tf.zeros([10, 10])}})

  # Nested partitioners.
  valid_partitioners = {
      "in_to_hidden": {
          "w": tf.fixed_size_partitioner(num_shards=2),
          "b": tf.fixed_size_partitioner(num_shards=2),
      },
      "hidden_to_hidden": {
          "w": tf.fixed_size_partitioner(num_shards=2),
          "b": tf.fixed_size_partitioner(num_shards=2),
      }
  }

  vanilla_rnn = snt.VanillaRNN(name="rnn",
                               hidden_size=self.hidden_size,
                               partitioners=valid_partitioners)
  vanilla_rnn(inputs, prev_state)

  self.assertEqual(type(vanilla_rnn.in_to_hidden_linear.w),
                   variables.PartitionedVariable)
  self.assertEqual(type(vanilla_rnn.in_to_hidden_linear.b),
                   variables.PartitionedVariable)
  self.assertEqual(type(vanilla_rnn.hidden_to_hidden_linear.w),
                   variables.PartitionedVariable)
  self.assertEqual(type(vanilla_rnn.hidden_to_hidden_linear.b),
                   variables.PartitionedVariable)
def testIncompatibleOptions(self):
  in_size = 2
  hidden1_size = 4
  hidden2_size = 5
  cores = [
      snt.Linear(name="linear", output_size=in_size),
      snt.VanillaRNN(name="rnn1", hidden_size=hidden1_size),
      snt.VanillaRNN(name="rnn2", hidden_size=hidden2_size)
  ]
  with self.assertRaisesRegexp(
      ValueError, "skip_connections are enabled but not all cores are "
      "recurrent, which is not supported"):
    snt.DeepRNN(cores, name="deep_rnn", skip_connections=True)
def testVariables(self):
  mod_name = "rnn"
  inputs = tf.placeholder(tf.float32, shape=[self.batch_size, self.in_size])
  prev_state = tf.placeholder(tf.float32,
                              shape=[self.batch_size, self.hidden_size])
  vanilla_rnn = snt.VanillaRNN(name=mod_name, hidden_size=self.hidden_size)
  self.assertEqual(vanilla_rnn.scope_name, mod_name)
  with self.assertRaisesRegexp(snt.Error, "not instantiated yet"):
    vanilla_rnn.get_variables()
  vanilla_rnn(inputs, prev_state)

  rnn_variables = vanilla_rnn.get_variables()
  self.assertEqual(len(rnn_variables), 4, "RNN should have 4 variables")

  in_to_hidden_w = next(
      v for v in rnn_variables
      if v.name == "%s/in_to_hidden/w:0" % mod_name)
  in_to_hidden_b = next(
      v for v in rnn_variables
      if v.name == "%s/in_to_hidden/b:0" % mod_name)
  hidden_to_hidden_w = next(
      v for v in rnn_variables
      if v.name == "%s/hidden_to_hidden/w:0" % mod_name)
  hidden_to_hidden_b = next(
      v for v in rnn_variables
      if v.name == "%s/hidden_to_hidden/b:0" % mod_name)

  self.assertShapeEqual(np.ndarray((self.in_size, self.hidden_size)),
                        in_to_hidden_w.initial_value)
  self.assertShapeEqual(np.ndarray(self.hidden_size),
                        in_to_hidden_b.initial_value)
  self.assertShapeEqual(np.ndarray((self.hidden_size, self.hidden_size)),
                        hidden_to_hidden_w.initial_value)
  self.assertShapeEqual(np.ndarray(self.hidden_size),
                        hidden_to_hidden_b.initial_value)
def get_rnn_core(cfg):
  """Get the Sonnet RNN cell from the given config.

  Args:
    cfg: config generated from `sample_rnn_core`.

  Returns:
    A Sonnet module with the given config.
  """
  name, args = cfg
  if name == "lstm":
    init = {"w_gates": get_initializer(args["w_gates"])}
    return snt.LSTM(args["core_dim"], initializers=init)

  elif name == "gru":
    init = {}
    for init_key in ["wh", "wz", "wr", "uh", "uz", "ur"]:
      init[init_key] = get_initializer(args[init_key])
    return snt.GRU(args["core_dim"], initializers=init)

  elif name == "vrnn":
    init = {
        "in_to_hidden": {"w": get_initializer(args["in_to_hidden"])},
        "hidden_to_hidden": {"w": get_initializer(args["hidden_to_hidden"])},
    }
    act_fn = get_activation(args["act_fn"])
    return snt.VanillaRNN(args["core_dim"],
                          initializers=init,
                          activation=act_fn)

  else:
    raise ValueError("No core for name [%s] found." % name)
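# Usage sketch (hypothetical): `get_rnn_core` takes a `(name, args)` tuple.
# The arg keys below are the ones the function reads for "vrnn"; the values
# are assumed examples of whatever `get_initializer`/`get_activation` accept,
# which is not shown here.
example_cfg = ("vrnn", {
    "core_dim": 64,
    "in_to_hidden": "glorot_uniform",   # assumed initializer spec
    "hidden_to_hidden": "orthogonal",   # assumed initializer spec
    "act_fn": "tanh",                   # assumed activation spec
})
core = get_rnn_core(example_cfg)  # -> an snt.VanillaRNN with 64 units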
def testComputation(self):
  inputs = tf.placeholder(tf.float32, shape=[self.batch_size, self.in_size])
  prev_state = tf.placeholder(tf.float32,
                              shape=[self.batch_size, self.in_size])
  vanilla_rnn = snt.VanillaRNN(name="rnn", hidden_size=self.in_size)
  residual = snt.SkipConnectionCore(vanilla_rnn, name="skip")
  output, new_state = residual(inputs, prev_state)
  in_to_hid = vanilla_rnn.in_to_hidden_variables
  hid_to_hid = vanilla_rnn.hidden_to_hidden_variables
  with self.test_session() as sess:
    # With random data, check the TF calculation matches the Numpy version.
    input_data = np.random.randn(self.batch_size, self.in_size)
    prev_state_data = np.random.randn(self.batch_size, self.in_size)
    tf.global_variables_initializer().run()

    fetches = [output, new_state, in_to_hid[0], in_to_hid[1],
               hid_to_hid[0], hid_to_hid[1]]
    output = sess.run(fetches, {inputs: input_data,
                                prev_state: prev_state_data})

  output_v, new_state_v, in_to_hid_w, in_to_hid_b = output[:4]
  hid_to_hid_w, hid_to_hid_b = output[4:]

  real_in_to_hid = np.dot(input_data, in_to_hid_w) + in_to_hid_b
  real_hid_to_hid = np.dot(prev_state_data, hid_to_hid_w) + hid_to_hid_b
  vanilla_output = np.tanh(real_in_to_hid + real_hid_to_hid)
  skip_output = np.concatenate((input_data, vanilla_output), -1)

  self.assertAllClose(skip_output, output_v)
  self.assertAllClose(vanilla_output, new_state_v)
def testComputation(self):
  input_data = np.random.randn(self.batch_size, self.in_size)
  prev_state_data = np.random.randn(self.batch_size, self.hidden_size)
  inputs = tf.convert_to_tensor(input_data)
  prev_state = tf.convert_to_tensor(prev_state_data)
  vanilla_rnn = snt.VanillaRNN(name="rnn", hidden_size=self.hidden_size)
  output, next_state = vanilla_rnn(inputs, prev_state)
  in_to_hid = vanilla_rnn.in_to_hidden_variables
  hid_to_hid = vanilla_rnn.hidden_to_hidden_variables

  # With random data, check the TF calculation matches the Numpy version.
  self.evaluate(tf.global_variables_initializer())
  fetches = [output, next_state, in_to_hid[0], in_to_hid[1],
             hid_to_hid[0], hid_to_hid[1]]
  output = self.evaluate(fetches)

  output_v, next_state_v, in_to_hid_w, in_to_hid_b = output[:4]
  hid_to_hid_w, hid_to_hid_b = output[4:]

  real_in_to_hid = np.dot(input_data, in_to_hid_w) + in_to_hid_b
  real_hid_to_hid = np.dot(prev_state_data, hid_to_hid_w) + hid_to_hid_b
  real_output = np.tanh(real_in_to_hid + real_hid_to_hid)

  self.assertAllClose(real_output, output_v)
  self.assertAllClose(real_output, next_state_v)
def _():
  base_model_fn = rnn_classification(
      lambda: snt.VanillaRNN(64, activation=tf.nn.relu),
      embed_dim=64,
      aggregate_method="avg")
  dataset = imdb_subword(128, 32)
  return base.DatasetModelTask(base_model_fn, dataset)
def testIncompatibleOptions(self):
  in_size = 2
  hidden1_size = 4
  hidden2_size = 5
  cores = [snt.Linear(name="linear", output_size=in_size),
           snt.VanillaRNN(name="rnn1", hidden_size=hidden1_size),
           snt.VanillaRNN(name="rnn2", hidden_size=hidden2_size)]
  with self.assertRaisesRegexp(
      ValueError, "skip_connections are enabled but not all cores are "
      "`snt.RNNCore`s, which is not supported"):
    snt.DeepRNN(cores, name="deep_rnn", skip_connections=True)

  cells = [tf.contrib.rnn.BasicLSTMCell(5), tf.contrib.rnn.BasicLSTMCell(5)]
  with self.assertRaisesRegexp(
      ValueError, "skip_connections are enabled but not all cores are "
      "`snt.RNNCore`s, which is not supported"):
    snt.DeepRNN(cells, skip_connections=True)
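# For contrast, a minimal sketch (sizes are arbitrary) of the supported
# configuration: with skip_connections=False, DeepRNN accepts a mix of
# non-recurrent and recurrent cores, as the testInitialState case above also
# exercises with Linear-only cores.
mixed_cores = [snt.Linear(name="linear", output_size=4),
               snt.VanillaRNN(name="rnn", hidden_size=5)]
mixed_deep_rnn = snt.DeepRNN(mixed_cores, skip_connections=False)
inputs = tf.zeros([3, 2])
prev_state = mixed_deep_rnn.initial_state(3, tf.float32)
output, next_state = mixed_deep_rnn(inputs, prev_state)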
def testInitializers(self):
  inputs = tf.placeholder(tf.float32, shape=[self.batch_size, self.in_size])
  prev_state = tf.placeholder(tf.float32,
                              shape=[self.batch_size, self.hidden_size])
  with self.assertRaisesRegexp(KeyError, "Invalid initializer keys.*"):
    snt.VanillaRNN(name="rnn",
                   hidden_size=self.hidden_size,
                   initializers={"invalid": None})

  err = "Initializer for 'w' is not a callable function"
  with self.assertRaisesRegexp(TypeError, err):
    snt.VanillaRNN(name="rnn",
                   hidden_size=self.hidden_size,
                   initializers={"in_to_hidden": {"w": tf.zeros([10, 10])}})

  # Nested initializers.
  valid_initializers = {
      "in_to_hidden": {
          "w": tf.ones_initializer(),
      },
      "hidden_to_hidden": {
          "b": tf.ones_initializer(),
      }
  }

  vanilla_rnn = snt.VanillaRNN(name="rnn",
                               hidden_size=self.hidden_size,
                               initializers=valid_initializers)
  vanilla_rnn(inputs, prev_state)
  init = tf.global_variables_initializer()

  with self.test_session() as sess:
    sess.run(init)
    w_v, b_v = sess.run([
        vanilla_rnn.in_to_hidden_linear.w,
        vanilla_rnn.hidden_to_hidden_linear.b,
    ])
    self.assertAllClose(w_v, np.ones([self.in_size, self.hidden_size]))
    self.assertAllClose(b_v, np.ones([self.hidden_size]))
def testShape(self):
  inputs = tf.placeholder(tf.float32, shape=[self.batch_size, self.in_size])
  prev_state = tf.placeholder(tf.float32,
                              shape=[self.batch_size, self.hidden_size])
  vanilla_rnn = snt.VanillaRNN(name="rnn", hidden_size=self.hidden_size)
  output, next_state = vanilla_rnn(inputs, prev_state)
  shape = np.ndarray((self.batch_size, self.hidden_size))

  self.assertShapeEqual(shape, output)
  self.assertShapeEqual(shape, next_state)
def __init__(self, n_dim, n_hidden):
  super(RecurrentNormalImpl, self).__init__()
  self._n_dim = n_dim
  self._n_hidden = n_hidden

  with self._enter_variable_scope():
    self._rnn = snt.VanillaRNN(self._n_dim)
    self._readout = snt.Linear(n_dim * 2)
    self._init_state = self._rnn.initial_state(1, trainable=True)
    self._init_sample = tf.zeros((1, self._n_hidden))
def testShape(self):
  inputs = tf.placeholder(tf.float32, shape=[self.batch_size, self.in_size])
  prev_state = tf.placeholder(tf.float32,
                              shape=[self.batch_size, self.in_size])
  vanilla_rnn = snt.VanillaRNN(self.in_size)
  residual_wrapper = snt.ResidualCore(vanilla_rnn, name="residual")
  output, next_state = residual_wrapper(inputs, prev_state)
  shape = np.ndarray((self.batch_size, self.in_size))

  self.assertEqual(self.in_size, residual_wrapper.output_size)
  self.assertShapeEqual(shape, output)
  self.assertShapeEqual(shape, next_state)
def testShape(self):
  inputs = tf.placeholder(tf.float32, shape=[self.batch_size, self.in_size])
  prev_state = tf.placeholder(tf.float32,
                              shape=[self.batch_size, self.hidden_size])
  vanilla_rnn = snt.VanillaRNN(self.hidden_size)
  skip_wrapper = snt.SkipConnectionCore(vanilla_rnn, name="skip")
  output, next_state = skip_wrapper(inputs, prev_state)
  output_shape = np.ndarray((self.batch_size,
                             self.in_size + self.hidden_size))
  state_shape = np.ndarray((self.batch_size, self.hidden_size))

  self.assertShapeEqual(output_shape, output)
  self.assertShapeEqual(state_shape, next_state)
def __init__(self, n_dim, n_hidden, conditional=False,
             output_initializers=None):
  super(RecurrentNormalImpl, self).__init__()
  self._n_dim = n_dim
  self._n_hidden = n_hidden

  with self._enter_variable_scope():
    self._rnn = snt.VanillaRNN(self._n_dim)
    self._readout = snt.Linear(self._n_dim * 2,
                               initializers=output_initializers)
    self._init_state = self._rnn.initial_state(1, trainable=True)
    self._init_sample = tf.get_variable('init_sample',
                                        shape=(1, self._n_dim),
                                        trainable=True)

    if conditional:
      self._cond_state = snt.Sequential([snt.Linear(self._n_hidden),
                                         tf.nn.elu])
def make_rnn_model():
  self.hidden_size = 20
  valid_regularizers = {
      "in_to_hidden": {
          "w": tf.nn.l2_loss,
      },
      "hidden_to_hidden": {
          "b": tf.nn.l2_loss,
      }
  }
  return snt.Sequential([
      snt.VanillaRNN(name="rnn",
                     hidden_size=self.hidden_size,
                     regularizers=valid_regularizers),
      snt.LayerNorm()
  ])
def testOutputSize(self):
  inputs = tf.placeholder(tf.float32, shape=[self.batch_size, self.in_size])
  prev_state = tf.placeholder(tf.float32,
                              shape=[self.batch_size, self.hidden_size])
  vanilla_rnn = snt.VanillaRNN(self.hidden_size)
  skip_wrapper = snt.SkipConnectionCore(vanilla_rnn, name="skip")

  with self.assertRaises(ValueError):
    _ = skip_wrapper.output_size
  skip_wrapper(inputs, prev_state)
  self.assertAllEqual([self.in_size + self.hidden_size],
                      skip_wrapper.output_size.as_list())

  skip_wrapper = snt.SkipConnectionCore(
      vanilla_rnn, input_shape=(self.in_size,), name="skip")
  self.assertAllEqual([self.in_size + self.hidden_size],
                      skip_wrapper.output_size.as_list())
def testInvalidDicts(self):
  batch_size = 3
  # Mistake seen in the wild - https://github.com/deepmind/sonnet/issues/74
  # Should actually be {'hidden_to_hidden': {'w': some_initializers(), ...}}
  initializers = {"hidden_to_hidden": tf.truncated_normal_initializer(0, 1)}
  vanilla_rnn = snt.VanillaRNN(hidden_size=23, initializers=initializers)
  with self.assertRaisesRegexp(TypeError, "Expected a dict"):
    vanilla_rnn(tf.zeros([batch_size, 4], dtype=tf.float32),
                vanilla_rnn.zero_state(batch_size, dtype=tf.float32))

  # Error: should be a dict mapping strings to partitioners/regularizers.
  partitioners = tf.fixed_size_partitioner(num_shards=16)
  with self.assertRaisesRegexp(TypeError, "Expected a dict"):
    snt.LSTM(hidden_size=42, partitioners=partitioners)

  regularizers = tf.contrib.layers.l1_regularizer(scale=0.5)
  with self.assertRaisesRegexp(TypeError, "Expected a dict"):
    snt.GRU(hidden_size=108, regularizers=regularizers)
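# For reference, a minimal sketch of the correctly nested form the comment
# above points to: outer keys name the submodules ("in_to_hidden",
# "hidden_to_hidden"), inner keys name their variables ("w", "b"), mirroring
# the valid_initializers dict in testInitializers.
correct_initializers = {
    "hidden_to_hidden": {
        "w": tf.truncated_normal_initializer(0, 1),
    },
}
vanilla_rnn = snt.VanillaRNN(hidden_size=23,
                             initializers=correct_initializers)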
def load(img, num, mean_img=None):
  F = tf.flags.FLAGS
  target = F.target.lower()
  assert target in Model.TARGETS, 'Target is {} and not in {}'.format(
      F.target, Model.TARGETS)

  gradients_through_z = True
  if target != Model.TARGETS[0]:
    gradients_through_z = False

  glimpse_size = [20, 20]
  n_hidden = 32 * 8
  n_layers = 2
  n_hiddens = [n_hidden] * n_layers
  n_what = 50
  steps_pred_hidden = [128, 64]

  shape = img.shape.as_list()
  batch_size, img_size = shape[0], shape[1:]

  air_cell = AIRCell(
      img_size, glimpse_size, n_what,
      rnn=snt.VanillaRNN(256),
      input_encoder=partial(Encoder, n_hiddens),
      glimpse_encoder=partial(Encoder, n_hiddens),
      transform_estimator=partial(StochasticTransformParam, n_hiddens,
                                  scale_bias=F.transform_var_bias),
      steps_predictor=partial(StepsPredictor, steps_pred_hidden,
                              F.step_bias),
      gradients_through_z=gradients_through_z)

  glimpse_decoder = partial(Decoder, n_hiddens,
                            output_scale=F.output_multiplier)

  if F.step_success_prob != -1.:
    assert 0. <= F.step_success_prob <= 1.
    step_success_prob = F.step_success_prob
  else:
    step_success_prob = geom_success_prob(F.init_step_success_prob,
                                          F.final_step_success_prob)

  air = AttendInferRepeat(
      F.n_steps_per_image, F.output_std, step_success_prob,
      air_cell, glimpse_decoder,
      mean_img=mean_img,
      recurrent_prior=F.rec_prior,
  )

  model = Model(img, air, F.k_particles, target=target,
                target_arg=F.target_arg, presence=num)
  return model
def rnn():
  return snt.VanillaRNN(num_unit, activation=tf.nn.relu, initializers=init)