Example #1
0
    def test_infer8(self):
        """Check that ScatterNDUpdate.infer on the `inputs8` graph produces `output8`.

        The inferred value is read from the 'value' attribute of the 'output' node.
        """
        test_graph = build_graph(nodes_attributes, edges, inputs8)
        ScatterNDUpdate.infer(Node(test_graph, 'scatternd_node'))

        inferred_value = test_graph.node['output']['value']

        values_match = np.array_equal(output8, inferred_value)
        failure_message = 'values do not match expected: {} and given: {}'.format(
            output8, inferred_value)
        self.assertTrue(values_match, failure_message)
Example #2
0
    def test_partial_infer4(self):
        """Check that ScatterNDUpdate.infer on the `inputs4` graph yields shape [10, 40, 50].

        The inferred shape is read from the 'shape' attribute of the 'output' node.
        """
        test_graph = build_graph(nodes_attributes, edges, inputs4)
        ScatterNDUpdate.infer(Node(test_graph, 'scatternd_node'))

        expected_shape = np.array([10, 40, 50], dtype=np.int32)
        inferred_shape = test_graph.node['output']['shape']

        shapes_match = np.array_equal(expected_shape, inferred_shape)
        failure_message = 'values do not match expected: {} and given: {}'.format(
            expected_shape, inferred_shape)
        self.assertTrue(shapes_match, failure_message)
    def replace_op(self, graph: Graph, node: Node):
        """Rewrite `node` as a Broadcast -> ScatterNDUpdate chain.

        Inputs of `node` are rewired as follows:
          * port 1 -> Broadcast port 1 (target shape);
          * port 3 -> Broadcast port 0 (value to broadcast); if port 3 is
            disconnected, a float32 zero Const is broadcast instead;
          * port 0 -> ScatterNDUpdate port 1 (indices);
          * port 2 -> ScatterNDUpdate port 2 (updates).

        The ScatterNDUpdate node takes over the original node's name, and the
        original node is renamed with an "/AbandonedName" suffix.

        :return: a one-element list holding the id of the new ScatterNDUpdate node
        """
        base_name = node.soft_get('name', node.id)

        # Tensor of the requested shape filled with the default value.
        filled = Broadcast(graph, {'name': base_name + '/Broadcast_'}).create_node()
        node.in_port(1).get_connection().set_destination(filled.in_port(1))

        default_value_port = node.in_port(3)
        if default_value_port.disconnected():
            # No default value supplied: fill with zero.
            # NOTE(review): the zero is always float32, independent of the
            # updates' element type — confirm this is intended for int inputs.
            fill_const = Const(graph, {'name': filled.name + '/FillValue_',
                                       'value': np.float32(0)}).create_node()
            filled.in_port(0).connect(fill_const.out_port(0))
        else:
            default_value_port.get_connection().set_destination(filled.in_port(0))

        # Overwrite the filled tensor with the update values at the given indices.
        scatter = ScatterNDUpdate(graph, {'name': base_name + '/ScatterNDUpdate_'}).create_node()
        scatter.in_port(0).connect(filled.out_port(0))
        for src_port_idx, dst_port_idx in ((0, 1), (2, 2)):
            node.in_port(src_port_idx).get_connection().set_destination(
                scatter.in_port(dst_port_idx))

        rename_nodes([(node, base_name + "/AbandonedName"), (scatter, base_name)])

        return [scatter.id]
Example #4
0
    def find_and_replace_pattern(self, graph: Graph):
        """Replace each TFScatterND node whose three inputs are connected.

        TFScatterND(indices, updates, shape) is rewritten as
        ScatterNDUpdate(Broadcast(ConvertLike(0, updates), shape), indices, updates).
        In words: a zero tensor of the requested shape, converted to the type of
        `updates`, is updated with `updates` at `indices`.

        The replacement takes over the original node's name; the original is
        renamed with a '/TBD' suffix and fully disconnected. TFScatterND nodes
        with any of inputs 0..2 unconnected are skipped.
        """
        for tf_scatter_nd in graph.get_op_nodes(op='TFScatterND'):
            if not all(tf_scatter_nd.is_in_port_connected(port) for port in range(3)):
                continue

            # Fall back to the node id, as `soft_get('name', node.id)` does elsewhere.
            # NOTE(review): the previous `soft_get('id')` looks up an 'id' attribute
            # key, which presumably is not stored on the node, rather than the id.
            name = tf_scatter_nd.soft_get('name', tf_scatter_nd.id)
            indices_port = tf_scatter_nd.in_port(0).get_source()
            updates_port = tf_scatter_nd.in_port(1).get_source()
            shape_port = tf_scatter_nd.in_port(2).get_source()

            # Integer zero; ConvertLike converts it to the type of `updates`.
            zero_const = Const(graph, {
                'value': int64_array(0),
                'name': name + '/zero_const'
            }).create_node()
            convert_to_type = ConvertLike(graph, {
                'name': name + '/convert_like'
            }).create_node()
            convert_to_type.in_port(0).connect(zero_const.out_port(0))
            convert_to_type.in_port(1).connect(updates_port)

            # Zero-filled tensor of the requested output shape.
            broadcast_node = Broadcast(graph, {
                'name': name + '/broadcast'
            }).create_node()
            broadcast_node.in_port(0).connect(convert_to_type.out_port(0))
            broadcast_node.in_port(1).connect(shape_port)

            # Write `updates` into the zero tensor at `indices`.
            scatter_nd_node = ScatterNDUpdate(graph, {
                'name': name + '/replaced'
            }).create_node()
            scatter_nd_node.in_port(0).connect(broadcast_node.out_port(0))
            scatter_nd_node.in_port(1).connect(indices_port)
            scatter_nd_node.in_port(2).connect(updates_port)

            rename_nodes([(tf_scatter_nd, name + '/TBD'),
                          (scatter_nd_node, name)])

            # Move consumers over to the replacement, then detach the old node's inputs.
            tf_scatter_nd.out_port(0).get_connection().set_source(
                scatter_nd_node.out_port(0))
            for port in range(3):
                tf_scatter_nd.in_port(port).disconnect()
Example #5
0
 @classmethod
 def extract(cls, node):
     """Apply ScatterNDUpdate.update_node_stat to `node` with no extra attributes.

     Declared as a classmethod because the first parameter is `cls` and the
     body reads the class-level `enabled` flag.

     :return: the value of `cls.enabled`
     """
     ScatterNDUpdate.update_node_stat(node, {})
     return cls.enabled