def test_simple_model_forward_pass(self):
  """Verifies exact forward-pass values for a tiny deterministic model.

  With all-ones weights, zero biases, no batch norm and no dropout, the
  network output is fully determined by the input, so the activations and
  logits can be checked against hand-computed constants.
  """
  # One example with a 3-d feature vector.
  input_features = tf.constant([[1.0, 2.0, 3.0]])
  outputs, activations = models.simple_model(
      input_features,
      output_sizes={'a': 4},
      sequential_inputs=False,
      is_training=True,
      num_hidden_nodes=2,
      weight_initializer=tf.initializers.ones(),
      bias_initializer=tf.initializers.zeros(),
      weight_max_norm=0.0,
      use_batch_norm=False,
      dropout_rate=0.0,
      num_fcs_per_block=2,
      num_fc_blocks=3)

  with self.session() as sess:
    sess.run(tf.initializers.global_variables())
    output_values, activation_values = sess.run([outputs, activations])

  self.assertCountEqual(output_values.keys(), ['a'])
  self.assertAllClose(output_values['a'], [[1500.0, 1500.0, 1500.0, 1500.0]])
  self.assertCountEqual(activation_values.keys(), ['base_activations'])
  self.assertAllClose(activation_values['base_activations'], [[750.0, 750.0]])
def test_simple_model_shapes(self):
  """Checks variable and tensor shapes for a model with a bottleneck layer.

  The expected global-variable dictionary is generated programmatically: the
  input FC layer plus two fully-connected blocks of two FC layers each, all
  with Linear and BatchNorm variables, followed by bottleneck and per-head
  output logits.
  """
  # Values 1..24 reshaped to [4, 2, 3] (batch, sequence, feature).
  input_features = tf.reshape(tf.range(1.0, 25.0), [4, 2, 3])
  output_sizes = {'a': 8, 'b': [4, 3]}
  outputs, activations = models.simple_model(
      input_features,
      output_sizes,
      sequential_inputs=False,
      is_training=True,
      num_bottleneck_nodes=16)

  # Bottleneck and output-head variables; head `b` logits are flattened
  # (4 * 3 = 12) before being reshaped to the requested output shape.
  expected_shapes = {
      'SimpleModel/BottleneckLogits/weight:0': [1024, 16],
      'SimpleModel/BottleneckLogits/bias:0': [16],
      'SimpleModel/OutputLogits/a/weight:0': [16, 8],
      'SimpleModel/OutputLogits/a/bias:0': [8],
      'SimpleModel/OutputLogits/b/weight:0': [16, 12],
      'SimpleModel/OutputLogits/b/bias:0': [12],
  }
  # (scope prefix, input dimension) for every FC layer in the base network.
  fc_layers = [('SimpleModel/InputFC', 3)]
  fc_layers.extend(
      (f'SimpleModel/FullyConnectedBlock_{block}/FC_{fc}', 1024)
      for block in range(2)
      for fc in range(2))
  for prefix, input_dim in fc_layers:
    expected_shapes[f'{prefix}/Linear/weight:0'] = [input_dim, 1024]
    expected_shapes[f'{prefix}/Linear/bias:0'] = [1024]
    for suffix in ('gamma', 'beta', 'moving_mean', 'moving_variance'):
      expected_shapes[f'{prefix}/BatchNorm/{suffix}:0'] = [1024]

  self.assertDictEqual(
      {var.name: var.shape.as_list() for var in tf.global_variables()},
      expected_shapes)

  self.assertCountEqual(outputs.keys(), ['a', 'b'])
  self.assertAllEqual(outputs['a'].shape.as_list(), [4, 2, 8])
  self.assertAllEqual(outputs['b'].shape.as_list(), [4, 2, 4, 3])
  self.assertCountEqual(activations.keys(),
                        ['base_activations', 'bottleneck_activations'])
  self.assertAllEqual(activations['base_activations'].shape.as_list(),
                      [4, 2, 1024])
  self.assertAllEqual(activations['bottleneck_activations'].shape.as_list(),
                      [4, 2, 16])
def test_add_moving_average(self):
  """Checks variables created by adding exponential moving averages.

  `add_moving_average` is expected to create a scalar global step plus one
  EMA shadow variable per trainable variable. BatchNorm moving statistics
  are not trainable, so they get no shadow copy.
  """
  inputs = tf.zeros([4, 2, 3])
  models.simple_model(
      inputs, {'a': 8, 'b': 4},
      sequential_inputs=False,
      is_training=True,
      name='M')
  pipeline_utils.add_moving_average(decay=0.9999)

  expected_shapes = {'global_step:0': []}
  # (scope prefix, input dimension) for every FC layer in the base network.
  fc_layers = [('M/InputFC', 3)]
  fc_layers.extend((f'M/FullyConnectedBlock_{block}/FC_{fc}', 1024)
                   for block in range(2)
                   for fc in range(2))
  for prefix, input_dim in fc_layers:
    trainable = {
        f'{prefix}/Linear/weight': [input_dim, 1024],
        f'{prefix}/Linear/bias': [1024],
        f'{prefix}/BatchNorm/gamma': [1024],
        f'{prefix}/BatchNorm/beta': [1024],
    }
    for name, shape in trainable.items():
      expected_shapes[name + ':0'] = shape
      expected_shapes[name + '/ExponentialMovingAverage:0'] = shape
    # Moving statistics exist but have no EMA shadow variable.
    expected_shapes[f'{prefix}/BatchNorm/moving_mean:0'] = [1024]
    expected_shapes[f'{prefix}/BatchNorm/moving_variance:0'] = [1024]
  for head, head_dim in (('a', 8), ('b', 4)):
    for name, shape in ((f'M/OutputLogits/{head}/weight', [1024, head_dim]),
                        (f'M/OutputLogits/{head}/bias', [head_dim])):
      expected_shapes[name + ':0'] = shape
      expected_shapes[name + '/ExponentialMovingAverage:0'] = shape

  self.assertDictEqual(
      {var.name: var.shape.as_list() for var in tf.global_variables()},
      expected_shapes)
def test_get_moving_average_variables_to_restore(self):
  """Checks the checkpoint-name mapping used to restore EMA variables.

  Trainable variables are restored from their `/ExponentialMovingAverage`
  shadow names; BatchNorm moving statistics (not trainable, hence no shadow
  copy) are restored under their own names.
  """
  inputs = tf.zeros([4, 2, 3])
  models.simple_model(
      inputs, {'a': 8, 'b': 4},
      sequential_inputs=False,
      is_training=False,
      name='M')
  variables_to_restore = (
      pipeline_utils.get_moving_average_variables_to_restore())

  fc_prefixes = ['M/InputFC']
  fc_prefixes.extend(f'M/FullyConnectedBlock_{block}/FC_{fc}'
                     for block in range(2)
                     for fc in range(2))
  expected_names = {}
  for prefix in fc_prefixes:
    for suffix in ('Linear/weight', 'Linear/bias', 'BatchNorm/gamma',
                   'BatchNorm/beta'):
      name = f'{prefix}/{suffix}'
      expected_names[name + '/ExponentialMovingAverage'] = name + ':0'
    for suffix in ('BatchNorm/moving_mean', 'BatchNorm/moving_variance'):
      name = f'{prefix}/{suffix}'
      expected_names[name] = name + ':0'
  for head in ('a', 'b'):
    for param in ('weight', 'bias'):
      name = f'M/OutputLogits/{head}/{param}'
      expected_names[name + '/ExponentialMovingAverage'] = name + ':0'

  self.assertDictEqual(
      {key: var.name for key, var in variables_to_restore.items()},
      expected_names)