Example #1
    def testTraining(self, trainable_initial_state, max_unique_stats):
        """Test that everything trains OK, with or without trainable init. state."""
        hidden_size = 3
        batch_size = 3
        time_steps = 3
        cell = snt.BatchNormLSTM(hidden_size=hidden_size,
                                 max_unique_stats=max_unique_stats)
        inputs = tf.constant(np.random.rand(batch_size, time_steps, 3),
                             dtype=tf.float32)
        initial_state = cell.initial_state(batch_size, tf.float32,
                                           trainable_initial_state)
        output, _ = tf.nn.dynamic_rnn(
            cell.with_batch_norm_control(is_training=True),
            inputs,
            initial_state=initial_state,
            dtype=tf.float32)

        loss = tf.reduce_mean(
            tf.square(output -
                      np.random.rand(batch_size, time_steps, hidden_size)))
        train_op = tf.train.GradientDescentOptimizer(1).minimize(loss)
        init = tf.global_variables_initializer()
        with self.test_session():
            init.run()
            train_op.run()
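
For context, the inference-time counterpart of the connection above passes is_training=False so that batch norm reads the accumulated moving statistics instead of the batch statistics. A minimal standalone sketch (not part of the original test), assuming the same Sonnet 1.x / TF 1.x graph-mode APIs used above:

import numpy as np
import sonnet as snt
import tensorflow as tf

batch_size, time_steps, hidden_size = 3, 3, 3
cell = snt.BatchNormLSTM(hidden_size=hidden_size)
inputs = tf.constant(np.random.rand(batch_size, time_steps, hidden_size),
                     dtype=tf.float32)

# Inference-mode connection: the wrapper binds is_training=False, so the
# batch norm layers use their moving statistics.
test_output, _ = tf.nn.dynamic_rnn(
    cell.with_batch_norm_control(is_training=False),
    inputs,
    initial_state=cell.initial_state(batch_size, tf.float32),
    dtype=tf.float32)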
Example #2
    def testBatchNormInitializersCheck(self):
        hidden_size = 4

        # Test that passing in a batchnorm initializer when we don't request
        # that form of batchnorm raises an error.
        for key, options in [
                (snt.BatchNormLSTM.GAMMA_H,
                 {"use_batch_norm_h": False, "use_batch_norm_x": True}),
                (snt.BatchNormLSTM.GAMMA_X,
                 {"use_batch_norm_x": False, "use_batch_norm_h": True}),
                (snt.BatchNormLSTM.GAMMA_C,
                 {"use_batch_norm_c": False, "use_batch_norm_h": True}),
                (snt.BatchNormLSTM.BETA_C,
                 {"use_batch_norm_c": False, "use_batch_norm_h": True})]:
            with self.assertRaisesRegexp(KeyError, "Invalid initializer"):
                snt.BatchNormLSTM(
                    hidden_size,
                    initializers={key: tf.constant_initializer(0)},
                    **options)
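
Conversely, these keys are accepted when the corresponding form of batch norm is requested. A minimal sketch (not from the original test), assuming the same snt.BatchNormLSTM initializer constants used above:

import sonnet as snt
import tensorflow as tf

# Valid: a gamma initializer for the hidden-state batch norm is allowed
# because use_batch_norm_h is enabled.
lstm = snt.BatchNormLSTM(
    hidden_size=4,
    use_batch_norm_h=True,
    initializers={snt.BatchNormLSTM.GAMMA_H: tf.constant_initializer(0.1)})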
Example #3
def _construct_lstm(use_batch_norm_h=False,
                    use_batch_norm_x=False,
                    use_batch_norm_c=False,
                    **kwargs):
    # In preparation for deprecation, use a plain snt.LSTM when no batch norm is required.
    if any([use_batch_norm_h, use_batch_norm_x, use_batch_norm_c]):
        return snt.BatchNormLSTM(use_batch_norm_h=use_batch_norm_h,
                                 use_batch_norm_x=use_batch_norm_x,
                                 use_batch_norm_c=use_batch_norm_c,
                                 **kwargs)
    else:
        return snt.LSTM(**kwargs)
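
A possible usage sketch of this helper (hypothetical call sites, reusing _construct_lstm as defined above):

# No batch norm requested, so this returns a plain snt.LSTM.
plain_cell = _construct_lstm(hidden_size=8)

# Batch norm requested, so this returns an snt.BatchNormLSTM. As Example #4
# suggests, it still needs with_batch_norm_control() before being passed to
# an RNN helper such as tf.nn.dynamic_rnn.
bn_cell = _construct_lstm(use_batch_norm_h=True, hidden_size=8)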
Example #4
def _construct_lstm(use_batch_norm_h=False, use_batch_norm_x=False,
                    use_batch_norm_c=False, max_unique_stats=1, **kwargs):
  if any([use_batch_norm_h, use_batch_norm_x, use_batch_norm_c]):
    cell = snt.BatchNormLSTM(
        use_batch_norm_h=use_batch_norm_h,
        use_batch_norm_x=use_batch_norm_x,
        use_batch_norm_c=use_batch_norm_c,
        max_unique_stats=max_unique_stats,
        **kwargs)
    return cell, cell.with_batch_norm_control(is_training=True)
  else:
    cell = snt.LSTM(**kwargs)
    return cell, cell
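
A usage sketch for this variant (not from the original source), reusing the helper above and assuming graph-mode TF 1.x. The second return value is the object to hand to the RNN helper, while the first provides the initial state:

import numpy as np
import tensorflow as tf

batch_size, seq_len, input_size, hidden_size = 2, 5, 3, 8
inputs = tf.constant(np.random.rand(batch_size, seq_len, input_size),
                     dtype=tf.float32)

cell, rnn_cell = _construct_lstm(use_batch_norm_h=True,
                                 hidden_size=hidden_size)
output, _ = tf.nn.dynamic_rnn(
    rnn_cell,
    inputs,
    initial_state=cell.initial_state(batch_size, tf.float32),
    dtype=tf.float32)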
Example #5
  def testBatchNormBuildFlag(self, use_batch_norm_h, use_batch_norm_x,
                             use_batch_norm_c):
    """Check if an error is raised if we don't specify the is_training flag."""
    batch_size = 2
    hidden_size = 4

    inputs = tf.placeholder(tf.float32, shape=[batch_size, hidden_size])
    prev_cell = tf.placeholder(tf.float32, shape=[batch_size, hidden_size])
    prev_hidden = tf.placeholder(tf.float32, shape=[batch_size, hidden_size])

    err = "is_training flag must be explicitly specified"
    with self.assertRaisesRegexp(ValueError, err):
      lstm = snt.BatchNormLSTM(
          hidden_size,
          use_batch_norm_h=use_batch_norm_h,
          use_batch_norm_x=use_batch_norm_x,
          use_batch_norm_c=use_batch_norm_c)
      lstm(inputs, (prev_cell, prev_hidden))
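
For comparison, a sketch of a connection that satisfies the check (not from the original test): binding is_training through with_batch_norm_control, as the other examples do, and assuming the returned wrapper is callable like a standard RNNCell:

import sonnet as snt
import tensorflow as tf

batch_size, hidden_size = 2, 4
inputs = tf.placeholder(tf.float32, shape=[batch_size, hidden_size])
prev_cell = tf.placeholder(tf.float32, shape=[batch_size, hidden_size])
prev_hidden = tf.placeholder(tf.float32, shape=[batch_size, hidden_size])

lstm = snt.BatchNormLSTM(hidden_size, use_batch_norm_h=True)
# The wrapper carries the is_training flag, so this call does not raise.
rnn_cell = lstm.with_batch_norm_control(is_training=True)
output, next_state = rnn_cell(inputs, (prev_cell, prev_hidden))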
Example #6
  def testSameInStaticAndDynamicWithBatchNorm(self, use_peepholes, batch_size,
                                              max_unique_stats, seq_len):
    # Tests that the cell produces the same output whether it is used in a
    # normal TensorFlow static RNN or in TensorFlow's dynamic_rnn. This checks
    # that the cores aren't doing anything funny they shouldn't be (like
    # relying on the number of times they've been invoked).

    hidden_size = 3
    input_size = 3

    inputs = tf.placeholder(tf.float32,
                            shape=[batch_size, seq_len, input_size],
                            name="inputs")
    static_inputs = tf.unstack(inputs, axis=1)

    test_local_stats = False

    cell = snt.BatchNormLSTM(
        hidden_size=hidden_size,
        max_unique_stats=max_unique_stats,
        use_peepholes=use_peepholes,
        use_batch_norm_h=True,
        use_batch_norm_x=True,
        use_batch_norm_c=True)

    # Connect static in training and test modes
    train_static_output_unpacked, _ = tf.contrib.rnn.static_rnn(
        cell.with_batch_norm_control(is_training=True,
                                     test_local_stats=test_local_stats),
        static_inputs,
        initial_state=cell.initial_state(batch_size, tf.float32))

    test_static_output_unpacked, _ = tf.contrib.rnn.static_rnn(
        cell.with_batch_norm_control(is_training=False,
                                     test_local_stats=test_local_stats),
        static_inputs,
        initial_state=cell.initial_state(batch_size, tf.float32))

    # Connect dynamic in training and test modes
    train_dynamic_output, _ = tf.nn.dynamic_rnn(
        cell.with_batch_norm_control(is_training=True,
                                     test_local_stats=test_local_stats),
        inputs,
        initial_state=cell.initial_state(batch_size, tf.float32),
        dtype=tf.float32)

    test_dynamic_output, _ = tf.nn.dynamic_rnn(
        cell.with_batch_norm_control(is_training=False,
                                     test_local_stats=test_local_stats),
        inputs,
        initial_state=cell.initial_state(batch_size, tf.float32),
        dtype=tf.float32)

    train_static_output = tf.stack(train_static_output_unpacked, axis=1)
    test_static_output = tf.stack(test_static_output_unpacked, axis=1)

    with self.test_session() as session:
      tf.global_variables_initializer().run()

      def check_static_and_dynamic(training):
        # Check that static and dynamic give the same output
        input_data = np.random.rand(batch_size, seq_len, input_size)

        if training:
          ops = [train_static_output, train_dynamic_output]
        else:
          ops = [test_static_output, test_dynamic_output]

        static_out, dynamic_out = session.run(ops,
                                              feed_dict={inputs: input_data})
        self.assertAllClose(static_out, dynamic_out)

      # Run a few passes to update the exponential moving statistics.
      for _ in range(5):
        check_static_and_dynamic(True)

      # Then check that the same holds when using the test statistics.
      check_static_and_dynamic(False)
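
As a small follow-on sketch (not part of the original test), the inference-mode ops above can also be evaluated on fresh data once the moving statistics have been updated, inside the same test_session block:

      new_data = np.random.rand(batch_size, seq_len, input_size)
      static_out, dynamic_out = session.run(
          [test_static_output, test_dynamic_output],
          feed_dict={inputs: new_data})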