  def test_simple_model_forward_pass(self):
    input_features = tf.constant([[1.0, 2.0, 3.0]])
    output_sizes = {'a': 4}
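    # With all-ones weights, zero biases, no batch normalization and no
    # dropout, the forward pass is deterministic, so the expected outputs
    # below can be checked exactly.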
    outputs, activations = models.simple_model(
        input_features,
        output_sizes,
        sequential_inputs=False,
        is_training=True,
        num_hidden_nodes=2,
        weight_initializer=tf.initializers.ones(),
        bias_initializer=tf.initializers.zeros(),
        weight_max_norm=0.0,
        use_batch_norm=False,
        dropout_rate=0.0,
        num_fcs_per_block=2,
        num_fc_blocks=3)

    with self.session() as sess:
      sess.run(tf.initializers.global_variables())
      outputs_result, activations_result = sess.run([outputs, activations])

    self.assertCountEqual(outputs_result.keys(), ['a'])
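    # Each of the 2 hidden nodes sums the input to 1 + 2 + 3 = 6 after the
    # all-ones input FC; the expected base activation of 750 per node (and
    # 2 * 750 = 1500 per output logit) is consistent with three residual
    # blocks of two fully connected layers each.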
    self.assertAllClose(outputs_result['a'],
                        [[1500.0, 1500.0, 1500.0, 1500.0]])
    self.assertCountEqual(activations_result.keys(), ['base_activations'])
    self.assertAllClose(activations_result['base_activations'],
                        [[750.0, 750.0]])

  def test_simple_model_shapes(self):
    # Shape = [4, 2, 3].
    input_features = tf.constant([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
                                  [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]],
                                  [[13.0, 14.0, 15.0], [16.0, 17.0, 18.0]],
                                  [[19.0, 20.0, 21.0], [22.0, 23.0, 24.0]]])
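    # A list-valued output size such as [4, 3] requests 4 * 3 = 12 logits
    # reshaped to [..., 4, 3], as asserted on `outputs['b']` below.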
    output_sizes = {'a': 8, 'b': [4, 3]}
    outputs, activations = models.simple_model(
        input_features,
        output_sizes,
        sequential_inputs=False,
        is_training=True,
        num_bottleneck_nodes=16)
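    # Setting `num_bottleneck_nodes` to a positive value adds a 1024 -> 16
    # bottleneck layer before the output logits and a corresponding
    # 'bottleneck_activations' entry, as reflected in the expected variable
    # shapes and activation keys below.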

    expected_global_variable_shapes = {
        'SimpleModel/InputFC/Linear/weight:0': ([3, 1024]),
        'SimpleModel/InputFC/Linear/bias:0': ([1024]),
        'SimpleModel/InputFC/BatchNorm/gamma:0': ([1024]),
        'SimpleModel/InputFC/BatchNorm/beta:0': ([1024]),
        'SimpleModel/InputFC/BatchNorm/moving_mean:0': ([1024]),
        'SimpleModel/InputFC/BatchNorm/moving_variance:0': ([1024]),
        'SimpleModel/FullyConnectedBlock_0/FC_0/Linear/weight:0': ([1024,
                                                                    1024]),
        'SimpleModel/FullyConnectedBlock_0/FC_0/Linear/bias:0': ([1024]),
        'SimpleModel/FullyConnectedBlock_0/FC_0/BatchNorm/gamma:0': ([1024]),
        'SimpleModel/FullyConnectedBlock_0/FC_0/BatchNorm/beta:0': ([1024]),
        'SimpleModel/FullyConnectedBlock_0/FC_0/BatchNorm/moving_mean:0':
            ([1024]),
        'SimpleModel/FullyConnectedBlock_0/FC_0/BatchNorm/moving_variance:0':
            ([1024]),
        'SimpleModel/FullyConnectedBlock_0/FC_1/Linear/weight:0': ([1024,
                                                                    1024]),
        'SimpleModel/FullyConnectedBlock_0/FC_1/Linear/bias:0': ([1024]),
        'SimpleModel/FullyConnectedBlock_0/FC_1/BatchNorm/gamma:0': ([1024]),
        'SimpleModel/FullyConnectedBlock_0/FC_1/BatchNorm/beta:0': ([1024]),
        'SimpleModel/FullyConnectedBlock_0/FC_1/BatchNorm/moving_mean:0':
            ([1024]),
        'SimpleModel/FullyConnectedBlock_0/FC_1/BatchNorm/moving_variance:0':
            ([1024]),
        'SimpleModel/FullyConnectedBlock_1/FC_0/Linear/weight:0': ([1024,
                                                                    1024]),
        'SimpleModel/FullyConnectedBlock_1/FC_0/Linear/bias:0': ([1024]),
        'SimpleModel/FullyConnectedBlock_1/FC_0/BatchNorm/gamma:0': ([1024]),
        'SimpleModel/FullyConnectedBlock_1/FC_0/BatchNorm/beta:0': ([1024]),
        'SimpleModel/FullyConnectedBlock_1/FC_0/BatchNorm/moving_mean:0':
            ([1024]),
        'SimpleModel/FullyConnectedBlock_1/FC_0/BatchNorm/moving_variance:0':
            ([1024]),
        'SimpleModel/FullyConnectedBlock_1/FC_1/Linear/weight:0': ([1024,
                                                                    1024]),
        'SimpleModel/FullyConnectedBlock_1/FC_1/Linear/bias:0': ([1024]),
        'SimpleModel/FullyConnectedBlock_1/FC_1/BatchNorm/gamma:0': ([1024]),
        'SimpleModel/FullyConnectedBlock_1/FC_1/BatchNorm/beta:0': ([1024]),
        'SimpleModel/FullyConnectedBlock_1/FC_1/BatchNorm/moving_mean:0':
            ([1024]),
        'SimpleModel/FullyConnectedBlock_1/FC_1/BatchNorm/moving_variance:0':
            ([1024]),
        'SimpleModel/BottleneckLogits/weight:0': ([1024, 16]),
        'SimpleModel/BottleneckLogits/bias:0': ([16]),
        'SimpleModel/OutputLogits/a/weight:0': ([16, 8]),
        'SimpleModel/OutputLogits/a/bias:0': ([8]),
        'SimpleModel/OutputLogits/b/weight:0': ([16, 12]),
        'SimpleModel/OutputLogits/b/bias:0': ([12]),
    }

    self.assertDictEqual(
        {var.name: var.shape.as_list() for var in tf.global_variables()},
        expected_global_variable_shapes)
    self.assertCountEqual(outputs.keys(), ['a', 'b'])
    self.assertAllEqual(outputs['a'].shape.as_list(), [4, 2, 8])
    self.assertAllEqual(outputs['b'].shape.as_list(), [4, 2, 4, 3])
    self.assertCountEqual(activations.keys(),
                          ['base_activations', 'bottleneck_activations'])
    self.assertAllEqual(activations['base_activations'].shape.as_list(),
                        [4, 2, 1024])
    self.assertAllEqual(activations['bottleneck_activations'].shape.as_list(),
                        [4, 2, 16])

  def test_add_moving_average(self):
    inputs = tf.zeros([4, 2, 3])
    output_sizes = {'a': 8, 'b': 4}
    models.simple_model(
        inputs,
        output_sizes,
        sequential_inputs=False,
        is_training=True,
        name='M')
    pipeline_utils.add_moving_average(decay=0.9999)

    expected_global_variable_shapes = {
        'M/InputFC/Linear/weight:0': ([3, 1024]),
        'M/InputFC/Linear/bias:0': ([1024]),
        'M/InputFC/BatchNorm/gamma:0': ([1024]),
        'M/InputFC/BatchNorm/beta:0': ([1024]),
        'M/InputFC/BatchNorm/moving_mean:0': ([1024]),
        'M/InputFC/BatchNorm/moving_variance:0': ([1024]),
        'M/FullyConnectedBlock_0/FC_0/Linear/weight:0': ([1024, 1024]),
        'M/FullyConnectedBlock_0/FC_0/Linear/bias:0': ([1024]),
        'M/FullyConnectedBlock_0/FC_0/BatchNorm/gamma:0': ([1024]),
        'M/FullyConnectedBlock_0/FC_0/BatchNorm/beta:0': ([1024]),
        'M/FullyConnectedBlock_0/FC_0/BatchNorm/moving_mean:0': ([1024]),
        'M/FullyConnectedBlock_0/FC_0/BatchNorm/moving_variance:0': ([1024]),
        'M/FullyConnectedBlock_0/FC_1/Linear/weight:0': ([1024, 1024]),
        'M/FullyConnectedBlock_0/FC_1/Linear/bias:0': ([1024]),
        'M/FullyConnectedBlock_0/FC_1/BatchNorm/gamma:0': ([1024]),
        'M/FullyConnectedBlock_0/FC_1/BatchNorm/beta:0': ([1024]),
        'M/FullyConnectedBlock_0/FC_1/BatchNorm/moving_mean:0': ([1024]),
        'M/FullyConnectedBlock_0/FC_1/BatchNorm/moving_variance:0': ([1024]),
        'M/FullyConnectedBlock_1/FC_0/Linear/weight:0': ([1024, 1024]),
        'M/FullyConnectedBlock_1/FC_0/Linear/bias:0': ([1024]),
        'M/FullyConnectedBlock_1/FC_0/BatchNorm/gamma:0': ([1024]),
        'M/FullyConnectedBlock_1/FC_0/BatchNorm/beta:0': ([1024]),
        'M/FullyConnectedBlock_1/FC_0/BatchNorm/moving_mean:0': ([1024]),
        'M/FullyConnectedBlock_1/FC_0/BatchNorm/moving_variance:0': ([1024]),
        'M/FullyConnectedBlock_1/FC_1/Linear/weight:0': ([1024, 1024]),
        'M/FullyConnectedBlock_1/FC_1/Linear/bias:0': ([1024]),
        'M/FullyConnectedBlock_1/FC_1/BatchNorm/gamma:0': ([1024]),
        'M/FullyConnectedBlock_1/FC_1/BatchNorm/beta:0': ([1024]),
        'M/FullyConnectedBlock_1/FC_1/BatchNorm/moving_mean:0': ([1024]),
        'M/FullyConnectedBlock_1/FC_1/BatchNorm/moving_variance:0': ([1024]),
        'M/OutputLogits/a/weight:0': ([1024, 8]),
        'M/OutputLogits/a/bias:0': ([8]),
        'M/OutputLogits/b/weight:0': ([1024, 4]),
        'M/OutputLogits/b/bias:0': ([4]),
        'M/InputFC/Linear/weight/ExponentialMovingAverage:0': ([3, 1024]),
        'M/InputFC/Linear/bias/ExponentialMovingAverage:0': ([1024]),
        'M/InputFC/BatchNorm/gamma/ExponentialMovingAverage:0': ([1024]),
        'M/InputFC/BatchNorm/beta/ExponentialMovingAverage:0': ([1024]),
        'M/FullyConnectedBlock_0/FC_0/Linear/weight/ExponentialMovingAverage:0':
            ([1024, 1024]),
        'M/FullyConnectedBlock_0/FC_0/Linear/bias/ExponentialMovingAverage:0':
            ([1024]),
        'M/FullyConnectedBlock_0/FC_0/BatchNorm/gamma/ExponentialMovingAverage:0':
            ([1024]),
        'M/FullyConnectedBlock_0/FC_0/BatchNorm/beta/ExponentialMovingAverage:0':
            ([1024]),
        'M/FullyConnectedBlock_0/FC_1/Linear/weight/ExponentialMovingAverage:0':
            ([1024, 1024]),
        'M/FullyConnectedBlock_0/FC_1/Linear/bias/ExponentialMovingAverage:0':
            ([1024]),
        'M/FullyConnectedBlock_0/FC_1/BatchNorm/gamma/ExponentialMovingAverage:0':
            ([1024]),
        'M/FullyConnectedBlock_0/FC_1/BatchNorm/beta/ExponentialMovingAverage:0':
            ([1024]),
        'M/FullyConnectedBlock_1/FC_0/Linear/weight/ExponentialMovingAverage:0':
            ([1024, 1024]),
        'M/FullyConnectedBlock_1/FC_0/Linear/bias/ExponentialMovingAverage:0':
            ([1024]),
        'M/FullyConnectedBlock_1/FC_0/BatchNorm/gamma/ExponentialMovingAverage:0':
            ([1024]),
        'M/FullyConnectedBlock_1/FC_0/BatchNorm/beta/ExponentialMovingAverage:0':
            ([1024]),
        'M/FullyConnectedBlock_1/FC_1/Linear/weight/ExponentialMovingAverage:0':
            ([1024, 1024]),
        'M/FullyConnectedBlock_1/FC_1/Linear/bias/ExponentialMovingAverage:0':
            ([1024]),
        'M/FullyConnectedBlock_1/FC_1/BatchNorm/gamma/ExponentialMovingAverage:0':
            ([1024]),
        'M/FullyConnectedBlock_1/FC_1/BatchNorm/beta/ExponentialMovingAverage:0':
            ([1024]),
        'M/OutputLogits/a/weight/ExponentialMovingAverage:0': ([1024, 8]),
        'M/OutputLogits/a/bias/ExponentialMovingAverage:0': ([8]),
        'M/OutputLogits/b/weight/ExponentialMovingAverage:0': ([1024, 4]),
        'M/OutputLogits/b/bias/ExponentialMovingAverage:0': ([4]),
        'global_step:0': ([]),
    }

    self.assertDictEqual(
        {var.name: var.shape.as_list() for var in tf.global_variables()},
        expected_global_variable_shapes)

  def test_get_moving_average_variables_to_restore(self):
    inputs = tf.zeros([4, 2, 3])
    output_sizes = {'a': 8, 'b': 4}
    models.simple_model(
        inputs,
        output_sizes,
        sequential_inputs=False,
        is_training=False,
        name='M')
    variables_to_restore = (
        pipeline_utils.get_moving_average_variables_to_restore())

    expected_variable_to_restore_names = {
        'M/InputFC/Linear/weight/ExponentialMovingAverage':
            'M/InputFC/Linear/weight:0',
        'M/InputFC/Linear/bias/ExponentialMovingAverage':
            'M/InputFC/Linear/bias:0',
        'M/InputFC/BatchNorm/gamma/ExponentialMovingAverage':
            'M/InputFC/BatchNorm/gamma:0',
        'M/InputFC/BatchNorm/beta/ExponentialMovingAverage':
            'M/InputFC/BatchNorm/beta:0',
        'M/InputFC/BatchNorm/moving_mean':
            'M/InputFC/BatchNorm/moving_mean:0',
        'M/InputFC/BatchNorm/moving_variance':
            'M/InputFC/BatchNorm/moving_variance:0',
        'M/FullyConnectedBlock_0/FC_0/Linear/weight/ExponentialMovingAverage':
            'M/FullyConnectedBlock_0/FC_0/Linear/weight:0',
        'M/FullyConnectedBlock_0/FC_0/Linear/bias/ExponentialMovingAverage':
            'M/FullyConnectedBlock_0/FC_0/Linear/bias:0',
        'M/FullyConnectedBlock_0/FC_0/BatchNorm/gamma/ExponentialMovingAverage':
            'M/FullyConnectedBlock_0/FC_0/BatchNorm/gamma:0',
        'M/FullyConnectedBlock_0/FC_0/BatchNorm/beta/ExponentialMovingAverage':
            'M/FullyConnectedBlock_0/FC_0/BatchNorm/beta:0',
        'M/FullyConnectedBlock_0/FC_0/BatchNorm/moving_mean':
            'M/FullyConnectedBlock_0/FC_0/BatchNorm/moving_mean:0',
        'M/FullyConnectedBlock_0/FC_0/BatchNorm/moving_variance':
            'M/FullyConnectedBlock_0/FC_0/BatchNorm/moving_variance:0',
        'M/FullyConnectedBlock_0/FC_1/Linear/weight/ExponentialMovingAverage':
            'M/FullyConnectedBlock_0/FC_1/Linear/weight:0',
        'M/FullyConnectedBlock_0/FC_1/Linear/bias/ExponentialMovingAverage':
            'M/FullyConnectedBlock_0/FC_1/Linear/bias:0',
        'M/FullyConnectedBlock_0/FC_1/BatchNorm/gamma/ExponentialMovingAverage':
            'M/FullyConnectedBlock_0/FC_1/BatchNorm/gamma:0',
        'M/FullyConnectedBlock_0/FC_1/BatchNorm/beta/ExponentialMovingAverage':
            'M/FullyConnectedBlock_0/FC_1/BatchNorm/beta:0',
        'M/FullyConnectedBlock_0/FC_1/BatchNorm/moving_mean':
            'M/FullyConnectedBlock_0/FC_1/BatchNorm/moving_mean:0',
        'M/FullyConnectedBlock_0/FC_1/BatchNorm/moving_variance':
            'M/FullyConnectedBlock_0/FC_1/BatchNorm/moving_variance:0',
        'M/FullyConnectedBlock_1/FC_0/Linear/weight/ExponentialMovingAverage':
            'M/FullyConnectedBlock_1/FC_0/Linear/weight:0',
        'M/FullyConnectedBlock_1/FC_0/Linear/bias/ExponentialMovingAverage':
            'M/FullyConnectedBlock_1/FC_0/Linear/bias:0',
        'M/FullyConnectedBlock_1/FC_0/BatchNorm/gamma/ExponentialMovingAverage':
            'M/FullyConnectedBlock_1/FC_0/BatchNorm/gamma:0',
        'M/FullyConnectedBlock_1/FC_0/BatchNorm/beta/ExponentialMovingAverage':
            'M/FullyConnectedBlock_1/FC_0/BatchNorm/beta:0',
        'M/FullyConnectedBlock_1/FC_0/BatchNorm/moving_mean':
            'M/FullyConnectedBlock_1/FC_0/BatchNorm/moving_mean:0',
        'M/FullyConnectedBlock_1/FC_0/BatchNorm/moving_variance':
            'M/FullyConnectedBlock_1/FC_0/BatchNorm/moving_variance:0',
        'M/FullyConnectedBlock_1/FC_1/Linear/weight/ExponentialMovingAverage':
            'M/FullyConnectedBlock_1/FC_1/Linear/weight:0',
        'M/FullyConnectedBlock_1/FC_1/Linear/bias/ExponentialMovingAverage':
            'M/FullyConnectedBlock_1/FC_1/Linear/bias:0',
        'M/FullyConnectedBlock_1/FC_1/BatchNorm/gamma/ExponentialMovingAverage':
            'M/FullyConnectedBlock_1/FC_1/BatchNorm/gamma:0',
        'M/FullyConnectedBlock_1/FC_1/BatchNorm/beta/ExponentialMovingAverage':
            'M/FullyConnectedBlock_1/FC_1/BatchNorm/beta:0',
        'M/FullyConnectedBlock_1/FC_1/BatchNorm/moving_mean':
            'M/FullyConnectedBlock_1/FC_1/BatchNorm/moving_mean:0',
        'M/FullyConnectedBlock_1/FC_1/BatchNorm/moving_variance':
            'M/FullyConnectedBlock_1/FC_1/BatchNorm/moving_variance:0',
        'M/OutputLogits/a/weight/ExponentialMovingAverage':
            'M/OutputLogits/a/weight:0',
        'M/OutputLogits/a/bias/ExponentialMovingAverage':
            'M/OutputLogits/a/bias:0',
        'M/OutputLogits/b/weight/ExponentialMovingAverage':
            'M/OutputLogits/b/weight:0',
        'M/OutputLogits/b/bias/ExponentialMovingAverage':
            'M/OutputLogits/b/bias:0',
    }

    self.assertDictEqual(
        {key: var.name for key, var in variables_to_restore.items()},
        expected_variable_to_restore_names)