def test_harmonize_dtypes(self):
    """harmonize_dtypes should cast weights to the dtype of their Const node.

    A weight whose dtype already agrees with its node is left as is, a
    weight whose dtype disagrees is converted to the node's dtype, and a
    weight with no node of the same name is passed through untouched.
    """
    # Const node "a" declares float32, Const node "b" declares int64.
    float_node = testutils.node_proto_from_json(
        '{"name":"a","op":"Const","attr":'
        '{"dtype":{"type":"DT_FLOAT"}}}')
    int64_node = testutils.node_proto_from_json(
        '{"name":"b","op":"Const","attr":'
        '{"dtype":{"type":"DT_INT64"}}}')
    graph = rewrite.GraphDef()
    graph.node.extend([float_node, int64_node])

    weights = {}
    # "a": float32 data that already agrees with node "a"
    weights['a'] = convert_to_tensor(
        np.arange(9., dtype=np.float32).reshape((1, 3, 3)))
    # "b": int32 data while node "b" declares int64 -> must be widened
    weights['b'] = convert_to_tensor(np.arange(4, dtype=np.int32))
    # "c": the graph has no node named "c"
    weights['c'] = convert_to_tensor(np.array(23, dtype=np.int64))

    harmonized = rewrite.harmonize_dtypes(graph, weights)

    expected_dtypes = (
        ('a', np.float32),  # matching dtype: unchanged
        ('b', np.int64),    # mismatched dtype: converted to node dtype
        ('c', np.int64),    # no matching node: unchanged
    )
    for weight_name, expected in expected_dtypes:
        self.assertEqual(harmonized[weight_name].numpy().dtype, expected)
 def test_is_fused_op_without_activation(self):
     """is_fused_op should accept an op fused with BiasAdd alone when no
     activation is requested, whether that is spelled '' or None.
     """
     # Both nodes list a single fused op: base64 "Qmlhc0FkZA==" ("BiasAdd").
     matmul_json = (
         '{"name":"model/output/BiasAdd","op":"_FusedMatMul",'
         '"input":["model/dense/BiasAdd",'
         '"model/output/MatMul/ReadVariableOp",'
         '"model/output/BiasAdd/ReadVariableOp"],"device":"/device:CPU:0",'
         '"attr":{"transpose_b":{"b":false},"T":{"type":"DT_FLOAT"},'
         '"num_args":{"i": "1"},"epsilon":{"f": 0},'
         '"fused_ops":{"list":{"s":["Qmlhc0FkZA=="]}},'
         '"transpose_a":{"b":false}}}')
     conv2d_json = (
         '{"name":"/model/batch_normalization_v1_8/FusedBatchNormV3",'
         '"op":"_FusedConv2D",'
         '"input":["model/depthwise","model/weights","model/bn_offset"],'
         '"device":"/device:CPU:0",'
         '"attr":{"fused_ops":{"list":{"s":["Qmlhc0FkZA=="]}},'
         '"dilations":{"list":{"i":["1","1","1","1"]}},'
         '"T":{"type": "DT_FLOAT"},'
         '"strides":{"list":{"i": ["1","1","1","1"]}},'
         '"data_format":{"s":"TkhXQw=="},'
         '"explicit_paddings":{"list":{}},'
         '"num_args":{"i":"1"},'
         '"epsilon":{"f":0},'
         '"padding":{"s":"VkFMSUQ="}}}')

     # (node JSON, base op name, value meaning "no activation")
     cases = (
         (matmul_json, 'MatMul', ''),
         (conv2d_json, 'Conv2D', None),
     )
     for node_json, base_op, no_activation in cases:
         node = testutils.node_proto_from_json(node_json)
         self.assertTrue(
             rewrite.is_fused_op(node, base_op, activation=no_activation))
    def test_copy_op_attrs(self):
        """copy_op_attrs should only copy attrs supported by the target node"""
        # copy_op_attrs is used to transfer attrs from a fused op node
        # (e.g. _FusedConv2D) to a standalone op (e.g. Conv2D); any
        # additional attrs of the fused op need to be ignored.
        # fused_ops below are base64 for "BiasAdd" and "Prelu".
        fused_op = testutils.node_proto_from_json(
            '{"name":"model/conv2d/BiasAdd",'
            '"op":"_FusedConv2D","input":["input",'
            '"model/conv2d/Conv2D/ReadVariableOp",'
            '"model/conv2d/BiasAdd/ReadVariableOp",'
            '"model/p_re_lu/Neg"],"device":"/device:CPU:0",'
            '"attr":{"dilations":{"list":{"i":["1","1","1","1"]}},'
            '"T":{"type":"DT_FLOAT"},"data_format":{"s":"TkhXQw=="},'
            '"strides":{"list":{"i":["1","1","1","1"]}},'
            '"use_cudnn_on_gpu":{"b":true},'
            '"explicit_paddings":{"list":{}},'
            '"num_args":{"i":"2"},"epsilon":{"f":0},'
            '"padding":{"s":"VkFMSUQ="},'
            '"fused_ops":{"list":{"s":["Qmlhc0FkZA==","UHJlbHU="]}}}}')
        # Only the first two inputs of the fused op are kept for the
        # standalone Conv2D node.
        node = rewrite.make_op_node('Conv2D', fused_op.input[0:2])
        rewrite.copy_op_attrs(source=fused_op, target=node)

        op_def = rewrite.get_op_def(node.op)
        allowed = {attr.name for attr in op_def.attr}
        # Collect the offending names (rather than a bare any()) so that a
        # failure reports exactly which attrs leaked onto the target node.
        unsupported = sorted(
            name for name in node.attr if name not in allowed)
        self.assertEqual(unsupported, [])
        # spot-check a few attributes that are expected to be copied
        self.assertIn('padding', node.attr)
        self.assertIn('strides', node.attr)
 def test_is_fused_op(self):
     """is_fused_op should only match when the requested activation is fused.

     A _FusedMatMul carrying just BiasAdd is not a Relu-fused op, whereas a
     _FusedMatMul whose fused_ops are BiasAdd followed by Prelu matches Prelu.
     """
     # fused_ops: ["BiasAdd"] (base64 "Qmlhc0FkZA==") -- no activation fused
     bias_only = testutils.node_proto_from_json(
         '{"name":"model/output/BiasAdd","op":"_FusedMatMul",'
         '"input":["model/dense/BiasAdd",'
         '"model/output/MatMul/ReadVariableOp",'
         '"model/output/BiasAdd/ReadVariableOp"],"device":"/device:CPU:0",'
         '"attr":{"transpose_b":{"b":false},"T":{"type":"DT_FLOAT"},'
         '"num_args":{"i": "1"},"epsilon":{"f": 0},'
         '"fused_ops":{"list":{"s":["Qmlhc0FkZA=="]}},'
         '"transpose_a":{"b":false}}}')
     relu = b'Relu'
     self.assertFalse(rewrite.is_fused_op(bias_only, 'MatMul', relu))

     # fused_ops: ["BiasAdd", "Prelu"] (base64 "Qmlhc0FkZA==", "UHJlbHU=")
     bias_and_prelu = testutils.node_proto_from_json(
         '{"name":"model/dense/BiasAdd","op":"_FusedMatMul",'
         '"input":["model/flatten/Reshape",'
         '"model/dense/MatMul/ReadVariableOp",'
         '"model/dense/BiasAdd/ReadVariableOp","model/p_re_lu_2/Neg"],'
         '"device":"/device:CPU:0","attr":{"transpose_b":{"b":false},'
         '"T":{"type":"DT_FLOAT"},"num_args":{"i":"2"},"epsilon":{"f":0},'
         '"fused_ops":{"list":{"s":["Qmlhc0FkZA==","UHJlbHU="]}},'
         '"transpose_a":{"b":false}}}')
     prelu = b'Prelu'
     self.assertTrue(rewrite.is_fused_op(bias_and_prelu, 'MatMul', prelu))
 def test_validate_supported_ops_given_invalid_graph(self):
     """validate_supported_ops should raise ValueError for unsupported ops.

     Two graphs are checked independently: one holding a standalone Prelu
     node, and one holding a _FusedMatMul that fuses BiasAdd and Prelu.
     """
     # Case 1: a plain, unsupported Prelu node.
     prelu_node = testutils.node_proto_from_json(
         '{"name":"model/p_re_lu_1/Relu","op":"Prelu","input":'
         '["model/add/add","model/p_re_lu_1/Neg"]}')
     prelu_graph = rewrite.GraphDef()
     prelu_graph.node.extend([prelu_node])
     with self.assertRaises(ValueError):
         rewrite.validate_supported_ops(prelu_graph)

     # Case 2: an unsupported fused op; fused_ops are base64 for
     # "BiasAdd" and "Prelu".
     fused_node = testutils.node_proto_from_json(
         '{"name":"model/dense/BiasAdd","op":"_FusedMatMul",'
         '"input":["model/flatten/Reshape",'
         '"model/dense/MatMul/ReadVariableOp",'
         '"model/dense/BiasAdd/ReadVariableOp","model/p_re_lu_2/Neg"],'
         '"device":"/device:CPU:0","attr":{"transpose_b":{"b":false},'
         '"T":{"type":"DT_FLOAT"},"num_args":{"i":"2"},"epsilon":{"f":0},'
         '"fused_ops":{"list":{"s":["Qmlhc0FkZA==","UHJlbHU="]}},'
         '"transpose_a":{"b":false}}}')
     fused_graph = rewrite.GraphDef()
     fused_graph.node.extend([fused_node])
     with self.assertRaises(ValueError):
         rewrite.validate_supported_ops(fused_graph)