def testTraining(self, trainable_initial_state, max_unique_stats):
  """Test that everything trains OK, with or without trainable init. state."""
  hidden_size = 3
  batch_size = 3
  time_steps = 3
  cell = snt.BatchNormLSTM(hidden_size=hidden_size,
                           max_unique_stats=max_unique_stats)
  inputs = tf.constant(np.random.rand(batch_size, time_steps, 3),
                       dtype=tf.float32)
  initial_state = cell.initial_state(batch_size, tf.float32,
                                     trainable_initial_state)
  output, _ = tf.nn.dynamic_rnn(
      cell.with_batch_norm_control(is_training=True),
      inputs,
      initial_state=initial_state,
      dtype=tf.float32)
  loss = tf.reduce_mean(tf.square(
      output - np.random.rand(batch_size, time_steps, hidden_size)))
  train_op = tf.train.GradientDescentOptimizer(1).minimize(loss)
  init = tf.global_variables_initializer()
  with self.test_session():
    init.run()
    train_op.run()
def testBatchNormInitializersCheck(self):
  hidden_size = 4

  # Test that passing in a batchnorm initializer when we don't request
  # that form of batchnorm raises an error.
  for key, options in [
      (snt.BatchNormLSTM.GAMMA_H, {"use_batch_norm_h": False,
                                   "use_batch_norm_x": True}),
      (snt.BatchNormLSTM.GAMMA_X, {"use_batch_norm_x": False,
                                   "use_batch_norm_h": True}),
      (snt.BatchNormLSTM.GAMMA_C, {"use_batch_norm_c": False,
                                   "use_batch_norm_h": True}),
      (snt.BatchNormLSTM.BETA_C, {"use_batch_norm_c": False,
                                  "use_batch_norm_h": True})]:
    with self.assertRaisesRegexp(KeyError, "Invalid initializer"):
      snt.BatchNormLSTM(hidden_size,
                        initializers={key: tf.constant_initializer(0)},
                        **options)
def _construct_lstm(use_batch_norm_h=False, use_batch_norm_x=False,
                    use_batch_norm_c=False, **kwargs):
  # Preparing for deprecation, this uses plain LSTM if no batch norm required.
  if any([use_batch_norm_h, use_batch_norm_x, use_batch_norm_c]):
    return snt.BatchNormLSTM(use_batch_norm_h=use_batch_norm_h,
                             use_batch_norm_x=use_batch_norm_x,
                             use_batch_norm_c=use_batch_norm_c,
                             **kwargs)
  else:
    return snt.LSTM(**kwargs)
def _construct_lstm(use_batch_norm_h=False, use_batch_norm_x=False,
                    use_batch_norm_c=False, max_unique_stats=1, **kwargs):
  """Returns a `(cell, training_cell)` pair for use in the tests."""
  if any([use_batch_norm_h, use_batch_norm_x, use_batch_norm_c]):
    cell = snt.BatchNormLSTM(use_batch_norm_h=use_batch_norm_h,
                             use_batch_norm_x=use_batch_norm_x,
                             use_batch_norm_c=use_batch_norm_c,
                             max_unique_stats=max_unique_stats,
                             **kwargs)
    return cell, cell.with_batch_norm_control(is_training=True)
  else:
    cell = snt.LSTM(**kwargs)
    return cell, cell
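# A minimal usage sketch (not part of the original tests), assuming only the
# helper above and the APIs already exercised in this file: `_construct_lstm`
# returns the raw cell plus a training-mode wrapper, so graph connection
# mirrors the `dynamic_rnn` calls used elsewhere. The sizes below are
# illustrative placeholders.
def _example_construct_lstm_usage():
  batch_size, seq_len, input_size = 2, 5, 3  # illustrative sizes
  cell, train_cell = _construct_lstm(hidden_size=4, use_batch_norm_h=True)
  inputs = tf.placeholder(tf.float32, shape=[batch_size, seq_len, input_size])
  # The wrapped cell already carries is_training=True, so it can be passed
  # straight to dynamic_rnn; the initial state still comes from the raw cell.
  output, _ = tf.nn.dynamic_rnn(
      train_cell, inputs,
      initial_state=cell.initial_state(batch_size, tf.float32))
  return output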
def testBatchNormBuildFlag(self, use_batch_norm_h, use_batch_norm_x,
                           use_batch_norm_c):
  """Check that an error is raised if we don't specify the is_training flag."""
  batch_size = 2
  hidden_size = 4
  inputs = tf.placeholder(tf.float32, shape=[batch_size, hidden_size])
  prev_cell = tf.placeholder(tf.float32, shape=[batch_size, hidden_size])
  prev_hidden = tf.placeholder(tf.float32, shape=[batch_size, hidden_size])

  err = "is_training flag must be explicitly specified"
  with self.assertRaisesRegexp(ValueError, err):
    lstm = snt.BatchNormLSTM(hidden_size,
                             use_batch_norm_h=use_batch_norm_h,
                             use_batch_norm_x=use_batch_norm_x,
                             use_batch_norm_c=use_batch_norm_c)
    lstm(inputs, (prev_cell, prev_hidden))
def testSameInStaticAndDynamicWithBatchNorm(self, use_peepholes, batch_size,
                                            max_unique_stats, seq_len):
  # Tests that the cell gives the same output whether it is used in a normal
  # TensorFlow static rnn or in TensorFlow's dynamic_rnn. This is to check
  # that the cores aren't doing anything funny they shouldn't be (like relying
  # on the number of times they've been invoked).
  hidden_size = 3
  input_size = 3

  inputs = tf.placeholder(tf.float32,
                          shape=[batch_size, seq_len, input_size],
                          name="inputs")
  static_inputs = tf.unstack(inputs, axis=1)

  test_local_stats = False

  cell = snt.BatchNormLSTM(hidden_size=hidden_size,
                           max_unique_stats=max_unique_stats,
                           use_peepholes=use_peepholes,
                           use_batch_norm_h=True,
                           use_batch_norm_x=True,
                           use_batch_norm_c=True)

  # Connect static in training and test modes
  train_static_output_unpacked, _ = tf.contrib.rnn.static_rnn(
      cell.with_batch_norm_control(is_training=True,
                                   test_local_stats=test_local_stats),
      static_inputs,
      initial_state=cell.initial_state(batch_size, tf.float32))

  test_static_output_unpacked, _ = tf.contrib.rnn.static_rnn(
      cell.with_batch_norm_control(is_training=False,
                                   test_local_stats=test_local_stats),
      static_inputs,
      initial_state=cell.initial_state(batch_size, tf.float32))

  # Connect dynamic in training and test modes
  train_dynamic_output, _ = tf.nn.dynamic_rnn(
      cell.with_batch_norm_control(is_training=True,
                                   test_local_stats=test_local_stats),
      inputs,
      initial_state=cell.initial_state(batch_size, tf.float32),
      dtype=tf.float32)

  test_dynamic_output, _ = tf.nn.dynamic_rnn(
      cell.with_batch_norm_control(is_training=False,
                                   test_local_stats=test_local_stats),
      inputs,
      initial_state=cell.initial_state(batch_size, tf.float32),
      dtype=tf.float32)

  train_static_output = tf.stack(train_static_output_unpacked, axis=1)
  test_static_output = tf.stack(test_static_output_unpacked, axis=1)

  with self.test_session() as session:
    tf.global_variables_initializer().run()

    def check_static_and_dynamic(training):
      # Check that static and dynamic give the same output.
      input_data = np.random.rand(batch_size, seq_len, input_size)

      if training:
        ops = [train_static_output, train_dynamic_output]
      else:
        ops = [test_static_output, test_dynamic_output]

      static_out, dynamic_out = session.run(ops,
                                            feed_dict={inputs: input_data})
      self.assertAllClose(static_out, dynamic_out)

    # Do a few passes to train the exponential moving statistics.
    for _ in range(5):
      check_static_and_dynamic(True)

    # And check that the same holds when using test statistics.
    check_static_and_dynamic(False)