def _hijack_and_recover(self, parameterization, **kwargs):
    """Builds the all-configurable-types model while `tm` is hijacked.

    Args:
      parameterization: Forwarded unchanged to `ops.hijack_keras_module`.
      **kwargs: Extra keyword arguments forwarded to
        `ops.hijack_keras_module`.

    Returns:
      The value returned by
      `tm.build_model_with_all_configurable_types(self.inputs)`.
    """
    module = tm
    _, original_layers = ops.hijack_keras_module(
        parameterization, module, **kwargs)
    try:
        return tm.build_model_with_all_configurable_types(self.inputs)
    finally:
        # Recover even when model construction raises; otherwise the module
        # would stay hijacked for whatever runs after this call.
        ops.recover_module_functions(original_layers, module)
def testHijackingLocalAliases(self):
    """Hijacks `tm` and checks the local-alias model ends with 1 channel."""
    parameterization = {'conv/Conv2D': 1}
    module = tm
    _, original_layers = ops.hijack_keras_module(parameterization, module)
    try:
        out = tm.build_simple_keras_model_from_local_aliases(self.inputs)
    finally:
        # Recover even when the build fails, so the hijacked functions do not
        # remain installed on the module for subsequent tests.
        ops.recover_module_functions(original_layers, module)

    self.assertEqual(out.shape.as_list()[-1], 1)
def testHijackingImportedLayerLib(self):
    """Hijacks `tm.layers` and checks the simple model ends with 1 channel."""
    parameterization = {'conv/Conv2D': 1}
    module = tm.layers
    _, original_layers = ops.hijack_keras_module(parameterization, module)
    try:
        out = tm.build_simple_keras_model(self.inputs)
    finally:
        # Recover even when the build fails, so the hijacked functions do not
        # remain installed on the module for subsequent tests.
        ops.recover_module_functions(original_layers, module)

    self.assertEqual(out.shape.as_list()[-1], 1)
def testConstructedOps(self):
    """Checks the constructed-ops dict reported by the hijack.

    After building the model with all configurable types, the dict returned
    by `ops.hijack_keras_module` must equal the parameterization that was
    passed in, and the model output must have 3 channels in its last
    dimension.
    """
    parameterization = {
        'conv/Conv2D': 1,
        'sep_conv/separable_conv2d': 2,
        'dense/Tensordot/MatMul': 3,
    }
    module = tm
    constructed_ops, original_layers = ops.hijack_keras_module(
        parameterization, module)
    try:
        out = tm.build_model_with_all_configurable_types(self.inputs)
    finally:
        # Recover even when the build fails, so the hijacked functions do not
        # remain installed on the module for subsequent tests.
        ops.recover_module_functions(original_layers, module)

    self.assertEqual(out.shape.as_list()[-1], 3)
    self.assertDictEqual(constructed_ops, parameterization)
def testPassThroughHijacking(self):
    """Checks a branch parameterized to 0 is reported as `ops.VANISHED`.

    `tm.layers` is hijacked with `keep_first_channel_alive=False`, then the
    two-branch model is built and the shapes of the surviving tensors are
    verified.
    """
    parameterization = {
        'conv1/Conv2D': 0,  # followed by BatchNorm, Activation, Add and Concat
        'conv2/Conv2D': 1,
    }
    module = tm.layers
    _, original_layers = ops.hijack_keras_module(
        parameterization, module, keep_first_channel_alive=False)
    try:
        # output = Concat([branch1, branch2, branch1 + branch2]).
        # if branch1 vanishes, output should have only 2 channels.
        output, branch1, branch2 = tm.build_two_branch_model(self.inputs)
    finally:
        # Recover even when the build fails, so the hijacked functions do not
        # remain installed on the module for subsequent tests.
        ops.recover_module_functions(original_layers, module)

    self.assertEqual(branch1, ops.VANISHED)
    self.assertEqual(branch2.shape.as_list(), [1, 8, 8, 1])
    self.assertEqual(output.shape.as_list(), [1, 8, 8, 2])