def testResnet(self, resnet_version):
  """Checks which ops get decorated when a ResNet bottleneck is built.

  The layer modules used by the chosen ResNet version are hijacked so their
  functions route through an `ops.ConfigurableOps` decorator; building one
  bottleneck block is then expected to construct exactly four decorated
  Conv2D ops.

  Args:
    resnet_version: Either 'resnet_v1' or 'resnet_v2'.
  """
  resnets = {
      'resnet_v1': (resnet_v1, 'v1'),
      'resnet_v2': (resnet_v2, 'v2')
  }
  resnet_module, version_suffix = resnets[resnet_version]
  decorator = ops.ConfigurableOps()

  def hijack(module):
    # Hijacking rebinds functions on real, shared modules. Register the
    # restore immediately so the module is put back after this test even if
    # an assertion fails; otherwise the hijack leaks into later tests.
    # addCleanup runs callbacks LIFO, so hijacks are undone in reverse order,
    # which stays correct if two of the modules below are the same object.
    originals = ops.hijack_module_functions(decorator, module)
    self.addCleanup(ops.recover_module_functions, originals, module)

  hijack(resnet_module.layers_lib)
  hijack(resnet_utils.layers)
  hijack(resnet_module.layers)

  inputs = tf.ones([3, 16, 16, 5])
  _ = resnet_module.bottleneck(
      inputs, depth=64, depth_bottleneck=16, stride=1)

  self.assertLen(decorator.constructed_ops, 4)
  base_name = 'bottleneck_' + version_suffix
  expected_decorated_ops = sorted([
      base_name + '/conv1/Conv2D',
      base_name + '/conv2/Conv2D',
      base_name + '/conv3/Conv2D',
      base_name + '/shortcut/Conv2D',
  ])
  self.assertAllEqual(expected_decorated_ops,
                      sorted(decorator.constructed_ops.keys()))
def testConcatHijack(self):
  """Checks that a hijacked `concat` tolerates VANISHED inputs.

  Before hijacking, concatenating with `ops.VANISHED` raises ValueError.
  After hijacking, VANISHED entries contribute nothing to the result's
  channel count, and concatenating only VANISHED entries yields a vanished
  result.
  """
  decorator = ops.ConfigurableOps()
  fake = Fake()
  ones = tf.ones([2, 3, 3, 5])
  vanished = ops.VANISHED

  def concat_shape(tensors):
    # Looks up `fake.concat` at call time, so it sees the hijacked version
    # once the hijack below has happened.
    return fake.concat(tensors, 3).shape.as_list()

  # The plain (not yet hijacked) concat rejects a VANISHED input.
  with self.assertRaises(ValueError):
    _ = concat_shape([ones, vanished])

  ops.hijack_module_functions(decorator, fake)

  # After the hijack, VANISHED inputs are skipped along axis 3.
  self.assertAllEqual(concat_shape([ones, vanished]), [2, 3, 3, 5])
  self.assertTrue(ops.is_vanished(fake.concat([vanished, vanished], 3)))
  self.assertAllEqual(concat_shape([ones, vanished, ones]), [2, 3, 3, 10])
def testRecover(self):
  """Checks that `recover_module_functions` undoes `hijack_module_functions`.

  This test temporarily rebinds functions on the real `layers` module. If it
  were left hijacked, other tests using `layers` could behave bizarrely, so
  the restore runs in a `finally` clause and happens even when an assertion
  in between fails.
  """
  decorator = ops.ConfigurableOps()
  true_separable_conv2d = layers.separable_conv2d
  original_dict = ops.hijack_module_functions(decorator, layers)
  try:
    # The returned dict holds the pre-hijack function objects.
    self.assertEqual(true_separable_conv2d, original_dict['separable_conv2d'])
    # Makes sure hijacking worked.
    self.assertNotEqual(true_separable_conv2d, layers.separable_conv2d)
  finally:
    # Recovers original ops, even if an assertion above failed.
    ops.recover_module_functions(original_dict, layers)
  self.assertEqual(true_separable_conv2d, layers.separable_conv2d)
def testHijack(self, fake_module, has_conv2d, has_separable_conv2d,
               has_fully_connected):
  """Checks that hijacked functions dispatch to `function_dict` under arg_scope.

  Args:
    fake_module: Object whose functions get hijacked.
    has_conv2d: Whether `conv2d` is expected among the hijacked functions.
    has_separable_conv2d: Same, for `separable_conv2d`.
    has_fully_connected: Same, for `fully_connected`.
  """
  # TODO(e1): Test that all is correct when hijacking a real module.

  def make_recorder(label):
    # Deliberately NOT decorated with add_arg_scope: the num_outputs set via
    # arg_scope must arrive through the hijacked wrapper, where it shows up
    # as the second positional argument.
    def record(*args, **kwargs):
      return (label, args[1], kwargs['scope'])
    return record

  decorator = ops.ConfigurableOps(function_dict={
      'fully_connected': make_recorder('testing_fully_connected'),
      'conv2d': make_recorder('testing_conv2d'),
      'separable_conv2d': make_recorder('testing_separable_conv2d'),
  })
  originals = ops.hijack_module_functions(decorator, fake_module)
  self.assertEqual('conv2d' in originals, has_conv2d)
  self.assertEqual('separable_conv2d' in originals, has_separable_conv2d)
  self.assertEqual('fully_connected' in originals, has_fully_connected)

  # Each row: (function name, expected present?, num_outputs, input shape,
  # scope, label returned by the recorder). Rows run in this order.
  checks = (
      ('conv2d', has_conv2d, 2, [10, 3, 3, 4], 'test_conv2d',
       'testing_conv2d'),
      ('fully_connected', has_fully_connected, 3, [10, 4], 'test_fc',
       'testing_fully_connected'),
      ('separable_conv2d', has_separable_conv2d, 4, [10, 3, 3, 4],
       'test_sep', 'testing_separable_conv2d'),
  )
  for fn_name, present, num_outputs, shape, scope, label in checks:
    if not present:
      continue
    hijacked_fn = getattr(fake_module, fn_name)
    with arg_scope([hijacked_fn], num_outputs=num_outputs):
      out = hijacked_fn(inputs=tf.zeros(shape), scope=scope)
    self.assertAllEqual([label, num_outputs, scope], out)