コード例 #1
0
 def _hijack_and_recover(self, parameterization, **kwargs):
     """Builds the all-configurable-types model while `tm` is hijacked.

     Hijacks the layer functions of `tm` via `ops.hijack_keras_module`,
     builds `tm.build_model_with_all_configurable_types(self.inputs)`, and
     always restores the original functions afterwards, even if building
     the model raises.

     Args:
         parameterization: Dict forwarded to `ops.hijack_keras_module`
             (presumably op name -> number of output channels; confirm).
         **kwargs: Extra keyword arguments forwarded to
             `ops.hijack_keras_module`.

     Returns:
         The output of `tm.build_model_with_all_configurable_types`.
     """
     module = tm
     _, original_layers = ops.hijack_keras_module(parameterization, module,
                                                  **kwargs)
     try:
         out = tm.build_model_with_all_configurable_types(self.inputs)
     finally:
         # Restore even on failure so later tests see the unpatched module.
         ops.recover_module_functions(original_layers, module)
     return out
コード例 #2
0
 def testHijackingLocalAliases(self):
     """Hijacking `tm` also affects layers referenced via local aliases."""
     parameterization = {'conv/Conv2D': 1}
     module = tm
     _, original_layers = ops.hijack_keras_module(parameterization, module)
     try:
         out = tm.build_simple_keras_model_from_local_aliases(self.inputs)
     finally:
         # Always undo the patch so a failure here cannot leak into other
         # tests.
         ops.recover_module_functions(original_layers, module)
     self.assertEqual(out.shape.as_list()[-1], 1)
コード例 #3
0
 def testHijackingImportedLayerLib(self):
     """Hijacking `tm.layers` affects the model built by `tm`."""
     parameterization = {'conv/Conv2D': 1}
     module = tm.layers
     _, original_layers = ops.hijack_keras_module(parameterization, module)
     try:
         out = tm.build_simple_keras_model(self.inputs)
     finally:
         # Always undo the patch so a failure here cannot leak into other
         # tests.
         ops.recover_module_functions(original_layers, module)
     self.assertEqual(out.shape.as_list()[-1], 1)
コード例 #4
0
 def testConstructedOps(self):
     """The hijack reports every parameterized op it constructed."""
     parameterization = {
         'conv/Conv2D': 1,
         'sep_conv/separable_conv2d': 2,
         'dense/Tensordot/MatMul': 3,
     }
     module = tm
     constructed_ops, original_layers = ops.hijack_keras_module(
         parameterization, module)
     try:
         out = tm.build_model_with_all_configurable_types(self.inputs)
     finally:
         # Always undo the patch so a failure here cannot leak into other
         # tests.
         ops.recover_module_functions(original_layers, module)
     self.assertEqual(out.shape.as_list()[-1], 3)
     self.assertDictEqual(constructed_ops, parameterization)
コード例 #5
0
    def testPassThroughHijacking(self):
        """A zero-channel branch vanishes through pass-through layers."""
        parameterization = {
            'conv1/Conv2D':
            0,  # followed by BatchNorm, Activation, Add and Concat
            'conv2/Conv2D': 1,
        }
        module = tm.layers
        _, original_layers = ops.hijack_keras_module(
            parameterization, module, keep_first_channel_alive=False)

        try:
            # output = Concat([branch1, branch2, branch1 + branch2]).
            # if branch1 vanishes, output should have only 2 channels.
            output, branch1, branch2 = tm.build_two_branch_model(self.inputs)
        finally:
            # Always undo the patch so a failure here cannot leak into other
            # tests.
            ops.recover_module_functions(original_layers, module)

        self.assertEqual(branch1, ops.VANISHED)
        self.assertEqual(branch2.shape.as_list(), [1, 8, 8, 1])
        self.assertEqual(output.shape.as_list(), [1, 8, 8, 2])