    def test_harmonize_dtypes(self):
        """harmonize_dtypes should change tensor dtype to node dtype"""
        graph_def = rewrite.GraphDef()
        const_nodes = [
            # tensor "a" is float32
            testutils.node_proto_from_json(
                '{"name":"a","op":"Const","attr":'
                '{"dtype":{"type":"DT_FLOAT"}}}'),
            # tensor "b" is int64
            testutils.node_proto_from_json(
                '{"name":"b","op":"Const","attr":'
                '{"dtype":{"type":"DT_INT64"}}}')
        ]
        graph_def.node.extend(const_nodes)
        weight_dict = {
            # weight "a" matches tensor "a"
            'a': convert_to_tensor(
                np.arange(9., dtype=np.float32).reshape((1, 3, 3))),
            # weight "b" is int32 (must be widened to match node)
            'b': convert_to_tensor(np.arange(4, dtype=np.int32)),
            # no matching node for weight "c"
            'c': convert_to_tensor(np.array(23, dtype=np.int64))
        }
        result = rewrite.harmonize_dtypes(graph_def, weight_dict)
        # existing should be unchanged if matching
        self.assertEqual(result['a'].numpy().dtype, np.float32)
        # existing should be altered to match node
        self.assertEqual(result['b'].numpy().dtype, np.int64)
        # non-existing should be unchanged
        self.assertEqual(result['c'].numpy().dtype, np.int64)
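
    # Illustrative sketch only (not the implementation under test): the
    # behaviour asserted above amounts to casting each weight whose name
    # matches a graph node to that node's declared dtype, roughly
    # (with tf being the usual tensorflow import):
    #
    #   node_dtypes = {n.name: n.attr['dtype'].type for n in graph_def.node}
    #   harmonized = {name: tf.cast(tensor, node_dtypes[name])
    #                 if name in node_dtypes else tensor
    #                 for name, tensor in weight_dict.items()}
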
    def test_is_fused_op_without_activation(self):
        """is_fused_op should return True if op is fused with BiasAdd only
           and no activation function is given
        """
        fused_matmul = testutils.node_proto_from_json(
            '{"name":"model/output/BiasAdd","op":"_FusedMatMul",'
            '"input":["model/dense/BiasAdd",'
            '"model/output/MatMul/ReadVariableOp",'
            '"model/output/BiasAdd/ReadVariableOp"],"device":"/device:CPU:0",'
            '"attr":{"transpose_b":{"b":false},"T":{"type":"DT_FLOAT"},'
            '"num_args":{"i": "1"},"epsilon":{"f": 0},'
            '"fused_ops":{"list":{"s":["Qmlhc0FkZA=="]}},'
            '"transpose_a":{"b":false}}}')
        self.assertTrue(
            rewrite.is_fused_op(fused_matmul, 'MatMul', activation=''))
        fused_conv2d = testutils.node_proto_from_json(
            '{"name":"/model/batch_normalization_v1_8/FusedBatchNormV3",'
            '"op":"_FusedConv2D",'
            '"input":["model/depthwise","model/weights","model/bn_offset"],'
            '"device":"/device:CPU:0",'
            '"attr":{"fused_ops":{"list":{"s":["Qmlhc0FkZA=="]}},'
            '"dilations":{"list":{"i":["1","1","1","1"]}},'
            '"T":{"type": "DT_FLOAT"},'
            '"strides":{"list":{"i": ["1","1","1","1"]}},'
            '"data_format":{"s":"TkhXQw=="},'
            '"explicit_paddings":{"list":{}},'
            '"num_args":{"i":"1"},'
            '"epsilon":{"f":0},'
            '"padding":{"s":"VkFMSUQ="}}}')
        self.assertTrue(
            rewrite.is_fused_op(fused_conv2d, 'Conv2D', activation=None))
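
    # Note on the JSON fixtures used throughout these tests: bytes-valued
    # attrs such as "fused_ops", "data_format", and "padding" are
    # base64-encoded in the protobuf JSON representation. For reference,
    # the encoded strings decode as follows (checkable with the standard
    # library, e.g. base64.b64decode):
    #
    #   'Qmlhc0FkZA==' -> b'BiasAdd'   (fused_ops)
    #   'UHJlbHU='     -> b'Prelu'     (fused_ops)
    #   'TkhXQw=='     -> b'NHWC'      (data_format)
    #   'VkFMSUQ='     -> b'VALID'     (padding)
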
    def test_copy_op_attrs(self):
        """copy_op_attrs should only copy attrs supported by the target node"""
        # copy_op_attrs is used to transfer attrs from a fused op node
        # (e.g. _FusedConv2D) to a standalone op (e.g. Conv2D);
        # any additional attrs of the fused op need to be ignored
        fused_op_str = (
            '{"name":"model/conv2d/BiasAdd",'
            '"op":"_FusedConv2D","input":["input",'
            '"model/conv2d/Conv2D/ReadVariableOp",'
            '"model/conv2d/BiasAdd/ReadVariableOp",'
            '"model/p_re_lu/Neg"],"device":"/device:CPU:0",'
            '"attr":{"dilations":{"list":{"i":["1","1","1","1"]}},'
            '"T":{"type":"DT_FLOAT"},"data_format":{"s":"TkhXQw=="},'
            '"strides":{"list":{"i":["1","1","1","1"]}},'
            '"use_cudnn_on_gpu":{"b":true},'
            '"explicit_paddings":{"list":{}},'
            '"num_args":{"i":"2"},"epsilon":{"f":0},'
            '"padding":{"s":"VkFMSUQ="},'
            '"fused_ops":{"list":{"s":["Qmlhc0FkZA==","UHJlbHU="]}}}}')
        fused_op = testutils.node_proto_from_json(fused_op_str)
        node = rewrite.make_op_node('Conv2D', fused_op.input[0:2])
        rewrite.copy_op_attrs(source=fused_op, target=node)
        op_def = rewrite.get_op_def(node.op)
        allowed = set(attr.name for attr in op_def.attr)
        forbidden = any(attr for attr in node.attr if attr not in allowed)
        self.assertFalse(forbidden)
        # randomly check for some of the expected attributes
        self.assertTrue('padding' in node.attr)
        self.assertTrue('strides' in node.attr)
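
    # Illustrative sketch only (not the implementation under test): copying
    # just the registry-supported attrs, as asserted above, can be expressed
    # roughly as
    #
    #   allowed = {attr.name for attr in rewrite.get_op_def(node.op).attr}
    #   for name, value in fused_op.attr.items():
    #       if name in allowed:
    #           node.attr[name].CopyFrom(value)
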
    def test_is_fused_op(self):
        """is_fused_op should be true if op is fused with BiasAdd+Activation"""
        missing_activation = testutils.node_proto_from_json(
            '{"name":"model/output/BiasAdd","op":"_FusedMatMul",'
            '"input":["model/dense/BiasAdd",'
            '"model/output/MatMul/ReadVariableOp",'
            '"model/output/BiasAdd/ReadVariableOp"],"device":"/device:CPU:0",'
            '"attr":{"transpose_b":{"b":false},"T":{"type":"DT_FLOAT"},'
            '"num_args":{"i": "1"},"epsilon":{"f": 0},'
            '"fused_ops":{"list":{"s":["Qmlhc0FkZA=="]}},'
            '"transpose_a":{"b":false}}}')
        self.assertFalse(
            rewrite.is_fused_op(missing_activation, 'MatMul', b'Relu'))
        fused_matmul = testutils.node_proto_from_json(
            '{"name":"model/dense/BiasAdd","op":"_FusedMatMul",'
            '"input":["model/flatten/Reshape",'
            '"model/dense/MatMul/ReadVariableOp",'
            '"model/dense/BiasAdd/ReadVariableOp","model/p_re_lu_2/Neg"],'
            '"device":"/device:CPU:0","attr":{"transpose_b":{"b":false},'
            '"T":{"type":"DT_FLOAT"},"num_args":{"i":"2"},"epsilon":{"f":0},'
            '"fused_ops":{"list":{"s":["Qmlhc0FkZA==","UHJlbHU="]}},'
            '"transpose_a":{"b":false}}}')
        self.assertTrue(rewrite.is_fused_op(fused_matmul, 'MatMul', b'Prelu'))
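
    # Illustrative sketch only (not the implementation under test): the
    # expectations above are consistent with a predicate of roughly this
    # shape, where an empty/None `activation` means "BiasAdd only":
    #
    #   def _is_fused(node, op, activation):
    #       fused = node.attr['fused_ops'].list.s   # e.g. [b'BiasAdd', b'Prelu']
    #       return (node.op == f'_Fused{op}'
    #               and b'BiasAdd' in fused
    #               and (not activation or activation in fused))
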
    def test_validate_supported_ops_given_invalid_graph(self):
        """validate_supported_ops should raise ValueError for unsupported op"""
        # case 1: unsupported op node
        graph_def = rewrite.GraphDef()
        unsupported_op = testutils.node_proto_from_json(
            '{"name":"model/p_re_lu_1/Relu","op":"Prelu","input":'
            '["model/add/add","model/p_re_lu_1/Neg"]}')
        graph_def.node.extend([unsupported_op])
        self.assertRaises(ValueError,
                          lambda: rewrite.validate_supported_ops(graph_def))
        # case 2: unsupported fused op
        unsupported_fused_op = testutils.node_proto_from_json(
            '{"name":"model/dense/BiasAdd","op":"_FusedMatMul",'
            '"input":["model/flatten/Reshape",'
            '"model/dense/MatMul/ReadVariableOp",'
            '"model/dense/BiasAdd/ReadVariableOp","model/p_re_lu_2/Neg"],'
            '"device":"/device:CPU:0","attr":{"transpose_b":{"b":false},'
            '"T":{"type":"DT_FLOAT"},"num_args":{"i":"2"},"epsilon":{"f":0},'
            '"fused_ops":{"list":{"s":["Qmlhc0FkZA==","UHJlbHU="]}},'
            '"transpose_a":{"b":false}}}')
        graph_def = rewrite.GraphDef()
        graph_def.node.extend([unsupported_fused_op])
        self.assertRaises(ValueError,
                          lambda: rewrite.validate_supported_ops(graph_def))
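
    # Illustrative sketch only (not the implementation under test): a
    # validator matching the cases above could walk the graph and raise for
    # any node it cannot map to a supported op, roughly:
    #
    #   for node in graph_def.node:
    #       if not _is_supported(node):   # hypothetical helper
    #           raise ValueError(f'unsupported op: {node.op} ({node.name})')
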