示例#1
0
 def _hijack_and_recover(self, parameterization, **kwargs):
     """Builds the all-configurable-types model with `tm`'s layers hijacked.

     Hijacks the Keras layer functions of `tm` according to
     `parameterization`, builds `tm.build_model_with_all_configurable_types`
     on `self.inputs`, then restores the original functions.

     Args:
       parameterization: Dict mapping op names (e.g. 'conv/Conv2D') to integer
         values passed to `ops.hijack_keras_module`; judging from the tests
         these appear to be output channel counts -- TODO confirm.
       **kwargs: Extra keyword arguments forwarded to
         `ops.hijack_keras_module`.

     Returns:
       The output of `tm.build_model_with_all_configurable_types`.
     """
     module = tm
     _, original_layers = ops.hijack_keras_module(parameterization, module,
                                                  **kwargs)
     try:
         return tm.build_model_with_all_configurable_types(self.inputs)
     finally:
         # Restore even if model construction raises; otherwise the hijacked
         # layer functions would leak into every test that runs afterwards.
         ops.recover_module_functions(original_layers, module)
示例#2
0
 def testHijackingLocalAliases(self):
     """Hijacks `tm` and builds the model from its local layer aliases.

     With 'conv/Conv2D' hijacked to 1, the output's last dimension must be 1.
     """
     parameterization = {'conv/Conv2D': 1}
     module = tm
     _, original_layers = ops.hijack_keras_module(parameterization, module)
     try:
         out = tm.build_simple_keras_model_from_local_aliases(self.inputs)
     finally:
         # Always restore the module so a failing build cannot leak hijacked
         # layer functions into other tests.
         ops.recover_module_functions(original_layers, module)
     self.assertEqual(out.shape.as_list()[-1], 1)
示例#3
0
 def testHijackingImportedLayerLib(self):
     """Hijacks `tm.layers` (the imported layer library) and builds a model.

     With 'conv/Conv2D' hijacked to 1, the output's last dimension must be 1.
     """
     parameterization = {'conv/Conv2D': 1}
     module = tm.layers
     _, original_layers = ops.hijack_keras_module(parameterization, module)
     try:
         out = tm.build_simple_keras_model(self.inputs)
     finally:
         # Always restore the module so a failing build cannot leak hijacked
         # layer functions into other tests.
         ops.recover_module_functions(original_layers, module)
     self.assertEqual(out.shape.as_list()[-1], 1)
示例#4
0
 def testConstructedOps(self):
     """Checks the ops reported by `ops.hijack_keras_module`.

     Builds the all-configurable-types model with conv, separable conv and
     dense ops hijacked. It checks the output's last dimension and that the
     returned constructed-ops dict equals `parameterization`.
     """
     parameterization = {
         'conv/Conv2D': 1,
         'sep_conv/separable_conv2d': 2,
         'dense/Tensordot/MatMul': 3,
     }
     module = tm
     constructed_ops, original_layers = ops.hijack_keras_module(
         parameterization, module)
     try:
         out = tm.build_model_with_all_configurable_types(self.inputs)
     finally:
         # Always restore the module so a failing build cannot leak hijacked
         # layer functions into other tests.
         ops.recover_module_functions(original_layers, module)
     self.assertEqual(out.shape.as_list()[-1], 3)
     self.assertDictEqual(constructed_ops, parameterization)
示例#5
0
    def testRecover(self):
        """Checks that recover_module_functions undoes hijack_module_functions.

        Hijacking swaps out the functions in `layers`. Recovery therefore runs
        in a `finally` clause, so a failed assertion cannot leave `layers`
        hijacked and disturb other tests.
        """
        decorator = ops.ConfigurableOps()
        true_separable_conv2d = layers.separable_conv2d
        original_dict = ops.hijack_module_functions(decorator, layers)
        try:
            # The returned dict must hold the original function.
            self.assertEqual(true_separable_conv2d,
                             original_dict['separable_conv2d'])
            # Makes sure hijacking worked.
            self.assertNotEqual(true_separable_conv2d, layers.separable_conv2d)
        finally:
            # Recovers original ops, even if an assertion above failed.
            ops.recover_module_functions(original_dict, layers)
        self.assertEqual(true_separable_conv2d, layers.separable_conv2d)
示例#6
0
    def testPassThroughHijacking(self):
        """Checks that a vanished branch propagates through downstream ops.

        'conv1/Conv2D' is hijacked to 0 with keep_first_channel_alive=False.
        branch1 should be ops.VANISHED, and the concatenated output should
        keep only 2 channels.
        """
        parameterization = {
            'conv1/Conv2D':
            0,  # followed by BatchNorm, Activation, Add and Concat
            'conv2/Conv2D': 1,
        }
        module = tm.layers
        _, original_layers = ops.hijack_keras_module(
            parameterization, module, keep_first_channel_alive=False)

        try:
            # output = Concat([branch1, branch2, branch1 + branch2]).
            # if branch1 vanishes, output should have only 2 channels.
            output, branch1, branch2 = tm.build_two_branch_model(self.inputs)
        finally:
            # Always restore the module so a failing build cannot leak
            # hijacked layer functions into other tests.
            ops.recover_module_functions(original_layers, module)

        self.assertEqual(branch1, ops.VANISHED)
        self.assertEqual(branch2.shape.as_list(), [1, 8, 8, 1])
        self.assertEqual(output.shape.as_list(), [1, 8, 8, 2])