def test_copy_op_attrs(self):
        """copy_op_attrs should only copy attrs supported by the target node"""
        # copy_op_attrs is used to transfer attrs from a fused op node
        # (e.g. _FusedConv2D) to a standalone op (e.g. Conv2D);
        # any additional attrs of the fused op need to be ignored
        fused_op_str = (
            '{"name":"model/conv2d/BiasAdd",'
            '"op":"_FusedConv2D","input":["input",'
            '"model/conv2d/Conv2D/ReadVariableOp",'
            '"model/conv2d/BiasAdd/ReadVariableOp",'
            '"model/p_re_lu/Neg"],"device":"/device:CPU:0",'
            '"attr":{"dilations":{"list":{"i":["1","1","1","1"]}},'
            '"T":{"type":"DT_FLOAT"},"data_format":{"s":"TkhXQw=="},'
            '"strides":{"list":{"i":["1","1","1","1"]}},'
            '"use_cudnn_on_gpu":{"b":true},'
            '"explicit_paddings":{"list":{}},'
            '"num_args":{"i":"2"},"epsilon":{"f":0},'
            '"padding":{"s":"VkFMSUQ="},'
            '"fused_ops":{"list":{"s":["Qmlhc0FkZA==","UHJlbHU="]}}}}'
        )
        fused_op = testutils.node_proto_from_json(fused_op_str)
        node = rewrite.make_op_node('Conv2D', fused_op.input[0:2])
        rewrite.copy_op_attrs(source=fused_op, target=node)

        op_def = rewrite.get_op_def(node.op)
        allowed = {attr.name for attr in op_def.attr}
        # collect the offending names rather than using any(): a failure
        # then lists exactly which unsupported attrs were copied, and the
        # check no longer depends on the truthiness of the attr names
        forbidden = sorted(name for name in node.attr if name not in allowed)

        self.assertEqual(forbidden, [])
        # spot-check some of the attributes the target op should receive
        self.assertIn('padding', node.attr)
        self.assertIn('strides', node.attr)
def _split_fused_depthwise(node: util.NodeDef, input_node_map: util.NameToNode,
                           weight_mods: util.WeightModifiers) -> util.NodeList:
    """Decompose fused op into DepthwiseConv2dNative + BiasAdd [+ Activation]

    Args:
        node: the fused node; its 'fused_ops' attr names the op(s) that
            follow the depthwise convolution (first entry is applied to the
            convolution output, an optional second entry is the activation)
        input_node_map: passed to generate_name_from when deriving the
            names of the new nodes
        weight_mods: not used by this function

    Returns:
        the new nodes in order: [depthwise, bias_add] or
        [depthwise, bias_add, activation]
    """
    op_names = [
        entry.decode('utf-8') for entry in node.attr['fused_ops'].list.s
    ]
    inputs = node.input
    taken = set()

    def unique_name(position):
        """Derive a sub-op name from an input name, adding the fused-op name
        as a suffix if that name was already handed out"""
        # clamp the index: PReLU has 4 inputs, others only 3
        source_input = inputs[min(position, len(inputs) - 1)]
        candidate = generate_name_from(source_input, input_node_map)
        if candidate in taken:
            candidate = generate_name_from(candidate, input_node_map,
                                           suffix=op_names[position - 2])
        taken.add(candidate)
        return candidate

    # depthwise convolution over the data input and the filter weights
    conv_inputs = inputs[0:2]
    depthwise = util.make_op_node('DepthwiseConv2dNative', conv_inputs,
                                  unique_name(1))
    depthwise = util.copy_op_attrs(source=node, target=depthwise)

    # first fused op, fed by the convolution and the third input
    bias_op = op_names[0]
    bias_inputs = [depthwise, inputs[2]]
    bias_add = util.make_op_node(bias_op, bias_inputs, unique_name(2))
    bias_add = util.copy_op_attrs(source=node, target=bias_add)

    result = [depthwise, bias_add]
    if len(op_names) <= 1:
        return result

    # a second fused op is the activation function; any inputs beyond the
    # third are passed along to it
    activation_op = op_names[1]
    activation_inputs = [bias_add, *inputs[3:]]
    if util.get_op_def(activation_op) is not None:
        # known op: transfer whichever attributes it supports
        activation = util.make_op_node(activation_op, activation_inputs,
                                       unique_name(3))
        activation = util.copy_op_attrs(source=node, target=activation)
    else:
        # no op definition available: only the 'T' type is carried over
        dtype = depthwise.attr['T'].type
        activation = util.make_op_node(activation_op, activation_inputs,
                                       unique_name(3), dtype)
    result.append(activation)
    return result
 def test_get_op_def_given_bogus_op(self):
     """get_op_def yields None when no op definition matches the name"""
     # 'CureCancer' is a made-up op name used as the unknown operation
     self.assertIsNone(rewrite.get_op_def('CureCancer'))
 def test_get_op_def_given_known_op(self):
     """Should return the matching op def for known operations"""
     op_def = rewrite.get_op_def('MatMul')
     self.assertIsNotNone(op_def)
     # a non-None result alone would also accept the definition of a
     # different op; make sure the returned def is actually for MatMul
     self.assertEqual(op_def.name, 'MatMul')