示例#1
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Replace the matched node with ``input / sqrt(float32(shape(input)[-1]))``.

        The tensor feeding input 0 of ``match['op']`` is routed into a new
        ``Div`` node. The divisor is built as Shape -> index [-1] ->
        Cast(float32) -> AttributedPower(power=0.5). The ``Div`` then takes over
        the original node's outgoing connection and its name. The original node
        is renamed with a ``/ShouldBeDeleted`` suffix.

        :param graph: graph being transformed in place.
        :param match: pattern match; only the ``'op'`` entry is used.
        """
        op_node = match['op']
        op_name = op_node.soft_get('name', op_node.id)

        # Shape of the tensor currently feeding the matched node.
        shape_of_input = Shape(graph, dict(name=op_name + '/Shape')).create_node()
        input_source = op_node.in_port(0).get_source()
        shape_of_input.in_port(0).connect(input_source)

        # Presumably extracts shape[-1] (last dimension); helper defined elsewhere.
        last_dim = node_to_get_shape_value_of_indices(shape_node=shape_of_input,
                                                      indices=[-1])

        sqrt_node = AttributedPower(graph, dict(name=op_name + '/Sqrt',
                                                power=mo_array(0.5))).create_node()

        # Due to specification, Power must have inputs with the same data type.
        cast_name = last_dim.name + '/ConvertToFP32'
        to_fp32 = Cast(graph, dict(dst_type=np.float32, name=cast_name)).create_node()
        divide = Div(graph, dict(name="Div")).create_node()

        # Divisor chain: shape[-1] -> Cast(float32) -> Power(0.5).
        last_dim.out_port(0).connect(to_fp32.in_port(0))
        to_fp32.out_port(0).connect(sqrt_node.in_port(0))

        # Re-route: original input -> Div.in0, sqrt -> Div.in1, Div.out -> old consumers.
        op_node.in_port(0).get_connection().set_destination(divide.in_port(0))
        divide.in_port(1).connect(sqrt_node.out_port(0))
        op_node.out_port(0).get_connection().set_source(divide.out_port(0))

        # Div inherits the original name (presumably to keep output naming stable);
        # the detached original is tagged for removal.
        rename_nodes([
            (op_node, op_name + '/ShouldBeDeleted'),
            (divide, op_name),
        ])
示例#2
0
 def extract(cls, node):
     """Configure ``node`` as an AttributedPower with exponent 2.

     The element type is read from the protobuf attribute ``T``. It is passed
     through ``tf_dtype_extractor``, whose result is used both as the
     ``data_type`` attribute and to build the exponent value.
     The result is presumably a NumPy scalar type; confirm against
     ``tf_dtype_extractor``.

     :return: ``cls.enabled``.
     """
     raw_type = node.pb.attr["T"].type
     dtype = tf_dtype_extractor(raw_type)
     power_attrs = {'power': dtype(2), 'data_type': dtype}
     AttributedPower.update_node_stat(node, power_attrs)
     return cls.enabled
示例#3
0
 def extract(cls, node: Node):
     """Copy the float attribute ``scale`` onto ``node`` as an AttributedPower ``scale``.

     When the attribute is absent, ``mo_array(1.0)`` is used as the default.
     The value read is converted with ``mo_array``.

     :return: ``cls.enabled``.
     """
     scale_value = onnx_attr(node, 'scale', 'f', default=mo_array(1.0), dst_type=mo_array)
     AttributedPower.update_node_stat(node, {'scale': scale_value})
     return cls.enabled
 def extract(cls, node: Node):
     """Map the layer's ``power_param`` fields onto AttributedPower attributes.

     Copies ``power``, ``scale`` and ``shift`` from ``node.pb.power_param``.
     It also sets ``output_spatial_shape`` to ``None``.
     ``power_param`` is presumably a Caffe ``PowerParameter``; confirm.

     :raises AssertionError: if ``node.pb`` is falsy (only while asserts are enabled).
     :return: ``cls.enabled``.
     """
     layer = node.pb
     assert layer, 'Protobuf layer can not be empty'
     power_param = layer.power_param
     AttributedPower.update_node_stat(node, dict(
         output_spatial_shape=None,
         power=power_param.power,
         scale=power_param.scale,
         shift=power_param.shift,
     ))
     return cls.enabled
示例#5
0
 def extract(cls, node):
     """Configure ``node`` as an AttributedPower with ``power`` set to 0.5.

     :return: ``cls.enabled``.
     """
     sqrt_attrs = dict(power=0.5)
     AttributedPower.update_node_stat(node, sqrt_attrs)
     return cls.enabled
示例#6
0
 def extract(cls, node: Node):
     """Configure ``node`` as an AttributedPower with ``scale`` set to -1.

     Only ``scale`` is provided here. Any other AttributedPower attributes come
     from whatever defaults ``update_node_stat`` or the op applies (not shown
     here).

     :return: ``cls.enabled``.
     """
     negate_attrs = dict(scale=-1)
     AttributedPower.update_node_stat(node, negate_attrs)
     return cls.enabled
示例#7
0
 def extract(cls, node: Node):
     """Copy ``node.module.exponent`` onto ``node`` as the AttributedPower ``power``.

     ``node.module`` is presumably the source-framework layer object attached
     to the node; confirm against the loader that populates it.

     :return: ``cls.enabled``.
     """
     exponent = node.module.exponent
     AttributedPower.update_node_stat(node, {'power': exponent})
     return cls.enabled
示例#8
0
 def extract(cls, node: Node):
     """Configure ``node`` as an AttributedPower with ``power`` set to 0.5.

     :return: ``cls.enabled``.
     """
     AttributedPower.update_node_stat(node, dict(power=0.5))
     return cls.enabled