예제 #1
0
    def testResnet(self, resnet_version):
        """Hijacking the ResNet modules decorates every conv in a bottleneck.

        Hijacks the layer functions reachable from the chosen ResNet module,
        builds a single bottleneck unit, and checks that the decorator
        constructed exactly the four expected Conv2D ops.

        Args:
          resnet_version: Either 'resnet_v1' or 'resnet_v2'.
        """
        resnets = {
            'resnet_v1': (resnet_v1, 'v1'),
            'resnet_v2': (resnet_v2, 'v2')
        }
        resnet_module, version_suffix = resnets[resnet_version]

        decorator = ops.ConfigurableOps()
        # Hijacking replaces functions on module objects that other tests also
        # use, so every hijack is undone when this test finishes (pass or
        # fail). Cleanups run in LIFO order, which restores the true originals
        # even in case two of these names refer to the same module object.
        for hijacked_module in (resnet_module.layers_lib,
                                resnet_utils.layers,
                                resnet_module.layers):
            originals = ops.hijack_module_functions(decorator, hijacked_module)
            self.addCleanup(ops.recover_module_functions, originals,
                            hijacked_module)

        inputs = tf.ones([3, 16, 16, 5])
        _ = resnet_module.bottleneck(inputs,
                                     depth=64,
                                     depth_bottleneck=16,
                                     stride=1)

        self.assertLen(decorator.constructed_ops, 4)

        base_name = 'bottleneck_' + version_suffix
        expected_decorated_ops = sorted([
            base_name + '/conv1/Conv2D',
            base_name + '/conv2/Conv2D',
            base_name + '/conv3/Conv2D',
            base_name + '/shortcut/Conv2D',
        ])

        self.assertAllEqual(expected_decorated_ops,
                            sorted(decorator.constructed_ops.keys()))
예제 #2
0
    def testConcatHijack(self):
        """A hijacked `concat` tolerates VANISHED inputs.

        Without the hijack, concatenating a VANISHED placeholder raises
        ValueError. After the hijack, VANISHED entries are skipped, and
        concatenating only VANISHED entries yields a vanished result.
        """
        configurable = ops.ConfigurableOps()
        fake = Fake()
        ones = tf.ones([2, 3, 3, 5])
        vanished = ops.VANISHED

        def concat_shape(tensors):
            # Concatenate along the last (channel) axis and report the shape.
            return fake.concat(tensors, 3).shape.as_list()

        # The un-hijacked concat cannot cope with the VANISHED placeholder.
        with self.assertRaises(ValueError):
            concat_shape([ones, vanished])

        ops.hijack_module_functions(configurable, fake)

        # With the hijack in place, VANISHED entries are dropped.
        self.assertAllEqual(concat_shape([ones, vanished]), [2, 3, 3, 5])
        self.assertTrue(ops.is_vanished(fake.concat([vanished, vanished], 3)))
        self.assertAllEqual(concat_shape([ones, vanished, ones]),
                            [2, 3, 3, 10])
예제 #3
0
    def testRecover(self):
        """recover_module_functions undoes hijack_module_functions on `layers`."""
        # If this test does not work well, then it might have some bizarre
        # effect on other tests as it changes the functions in layers. The
        # recovery therefore runs in a `finally`, so a failing assertion can
        # never leave `layers` hijacked.
        decorator = ops.ConfigurableOps()
        true_separable_conv2d = layers.separable_conv2d
        original_dict = ops.hijack_module_functions(decorator, layers)

        try:
            # The returned dict records the pre-hijack function.
            self.assertEqual(true_separable_conv2d,
                             original_dict['separable_conv2d'])
            # Makes sure hijacking worked.
            self.assertNotEqual(true_separable_conv2d, layers.separable_conv2d)
        finally:
            # Recovers original ops, even if an assertion above failed.
            ops.recover_module_functions(original_dict, layers)

        self.assertEqual(true_separable_conv2d, layers.separable_conv2d)
예제 #4
0
    def testHijack(self, fake_module, has_conv2d, has_separable_conv2d,
                   has_fully_connected):
        """Hijacked functions honor arg_scope and dispatch to function_dict.

        Args:
          fake_module: Object whose functions get hijacked.
          has_conv2d: Expected presence of 'conv2d' among the hijacked
            originals.
          has_separable_conv2d: Expected presence of 'separable_conv2d'.
          has_fully_connected: Expected presence of 'fully_connected'.
        """
        # This test verifies that hijacking works with arg scope.
        # TODO(e1): Test that all is correct when hijacking a real module.
        def name_and_output_fn(name):
            # By design there is no add arg_scope here.
            def fn(*args, **kwargs):
                # args[1] is expected to be num_outputs, which the assertions
                # below set only through arg_scope.
                return (name, args[1], kwargs['scope'])

            return fn

        function_dict = {
            'fully_connected': name_and_output_fn('testing_fully_connected'),
            'conv2d': name_and_output_fn('testing_conv2d'),
            'separable_conv2d': name_and_output_fn('testing_separable_conv2d')
        }

        decorator = ops.ConfigurableOps(function_dict=function_dict)
        originals = ops.hijack_module_functions(decorator, fake_module)
        # Undo the hijack when the test ends (pass or fail) so the modified
        # fake module cannot leak into other test cases that reuse it.
        self.addCleanup(ops.recover_module_functions, originals, fake_module)

        self.assertEqual('conv2d' in originals, has_conv2d)
        self.assertEqual('separable_conv2d' in originals, has_separable_conv2d)
        self.assertEqual('fully_connected' in originals, has_fully_connected)

        if has_conv2d:
            with arg_scope([fake_module.conv2d], num_outputs=2):
                out = fake_module.conv2d(inputs=tf.zeros([10, 3, 3, 4]),
                                         scope='test_conv2d')
            self.assertAllEqual(['testing_conv2d', 2, 'test_conv2d'], out)

        if has_fully_connected:
            with arg_scope([fake_module.fully_connected], num_outputs=3):
                out = fake_module.fully_connected(inputs=tf.zeros([10, 4]),
                                                  scope='test_fc')
            self.assertAllEqual(['testing_fully_connected', 3, 'test_fc'], out)

        if has_separable_conv2d:
            with arg_scope([fake_module.separable_conv2d], num_outputs=4):
                out = fake_module.separable_conv2d(
                    inputs=tf.zeros([10, 3, 3, 4]), scope='test_sep')
            self.assertAllEqual(['testing_separable_conv2d', 4, 'test_sep'],
                                out)