Пример #1
0
    def replace_op(self, graph: Graph, node: Node):
        """Decompose ``node`` into a Broadcast -> ScatterNDUpdate subgraph.

        A Broadcast builds a tensor of the required shape filled with a default
        value. A ScatterNDUpdate then writes the supplied values into that
        tensor at the supplied locations. The ScatterNDUpdate node takes over
        the original node's name.

        Port mapping (original node -> new nodes):
            in 0 -> ScatterNDUpdate in 1 (locations to update)
            in 1 -> Broadcast in 1 (required shape)
            in 2 -> ScatterNDUpdate in 2 (values to write)
            in 3 -> Cast(int32) -> Broadcast in 0 (default value) when connected;
                    otherwise a float32 zero Const feeds Broadcast in 0.

        :param graph: graph that owns ``node``; the new nodes are created in it.
        :param node: node being replaced.
        :return: one-element list holding the id of the new ScatterNDUpdate node
                 (presumably used by the caller to rewire the original outputs — confirm).
        """
        original_name = node.soft_get('name', node.id)

        # Step 1: tensor of the required shape, filled with the default value.
        filled = Broadcast(graph, {'name': original_name + '/Broadcast_'}).create_node()
        node.in_port(1).get_connection().set_destination(filled.in_port(1))

        default_value_port = node.in_port(3)
        if default_value_port.disconnected():
            # No default supplied: fill with a float32 zero constant.
            fill_target = filled.in_port(0)
            zero_const = Const(graph, {'name': filled.name + '/FillValue_',
                                       'value': np.float32(0)}).create_node()
            fill_target.connect(zero_const.out_port(0))
        else:
            # TODO: remove casting once we start to support I64 model input
            # The default value is cast to I32 because of a limitation on I64 input support,
            # so that the input parameter and the default value share the I32 type ScatterNDUpdate requires.
            to_int32 = Cast(graph, {'name': original_name + '/CastDefaultValue',
                                    'dst_type': np.int32}).create_node()
            default_value_port.get_connection().set_destination(to_int32.in_port(0))
            filled.in_port(0).connect(to_int32.out_port(0))

        # Step 2: write the values into the filled tensor at the required locations.
        scatter = ScatterNDUpdate(graph, {'name': original_name + '/ScatterNDUpdate_'}).create_node()
        scatter.in_port(0).connect(filled.out_port(0))
        for source_idx, target_idx in ((0, 1), (2, 2)):
            node.in_port(source_idx).get_connection().set_destination(scatter.in_port(target_idx))

        # Move the original node out of the way so the scatter node can take over its name.
        rename_nodes([(node, original_name + "/AbandonedName"), (scatter, original_name)])

        return [scatter.id]
Пример #2
0
    def test_infer8(self):
        """ScatterNDUpdate.infer on the ``inputs8`` fixture must yield the ``output8`` value."""
        test_graph = build_graph(nodes_attributes, edges, inputs8)
        ScatterNDUpdate.infer(Node(test_graph, 'scatternd_node'))

        # Value written by inference onto the output node.
        actual_value = test_graph.node['output']['value']

        values_match = np.array_equal(output8, actual_value)
        failure_message = 'values do not match expected: {} and given: {}'.format(output8, actual_value)
        self.assertTrue(values_match, failure_message)
Пример #3
0
    def test_partial_infer3(self):
        """ScatterNDUpdate.infer on the ``inputs3`` fixture must yield output shape [20, 30, 5]."""
        test_graph = build_graph(nodes_attributes, edges, inputs3)
        ScatterNDUpdate.infer(Node(test_graph, 'scatternd_node'))

        expected_shape = np.array([20, 30, 5], dtype=np.int32)
        # Shape written by inference onto the output node.
        actual_shape = test_graph.node['output']['shape']

        shapes_match = np.array_equal(expected_shape, actual_shape)
        failure_message = 'values do not match expected: {} and given: {}'.format(expected_shape, actual_shape)
        self.assertTrue(shapes_match, failure_message)
    def replace_op(self, graph: Graph, node: Node):
        """Decompose ``node`` into a Broadcast -> ScatterNDUpdate subgraph.

        A Broadcast builds a tensor of the required shape filled with a default
        value. A ScatterNDUpdate then writes the supplied values into that
        tensor at the supplied locations. The ScatterNDUpdate node takes over
        the original node's name.

        Port mapping (original node -> new nodes):
            in 0 -> ScatterNDUpdate in 1 (locations to update)
            in 1 -> Broadcast in 1 (required shape)
            in 2 -> ScatterNDUpdate in 2 (values to write)
            in 3 -> Broadcast in 0 (default value) when connected;
                    otherwise a float32 zero Const feeds Broadcast in 0.

        :param graph: graph that owns ``node``; the new nodes are created in it.
        :param node: node being replaced.
        :return: one-element list holding the id of the new ScatterNDUpdate node
                 (presumably used by the caller to rewire the original outputs — confirm).
        """
        original_name = node.soft_get('name', node.id)

        # Step 1: tensor of the required shape, filled with the default value.
        filled = Broadcast(graph, {'name': original_name + '/Broadcast_'}).create_node()
        node.in_port(1).get_connection().set_destination(filled.in_port(1))

        default_value_port = node.in_port(3)
        if default_value_port.disconnected():
            # No default supplied: fill with a float32 zero constant.
            fill_target = filled.in_port(0)
            zero_const = Const(graph, {'name': filled.name + '/FillValue_',
                                       'value': np.float32(0)}).create_node()
            fill_target.connect(zero_const.out_port(0))
        else:
            default_value_port.get_connection().set_destination(filled.in_port(0))

        # Step 2: write the values into the filled tensor at the required locations.
        scatter = ScatterNDUpdate(graph, {'name': original_name + '/ScatterNDUpdate_'}).create_node()
        scatter.in_port(0).connect(filled.out_port(0))
        for source_idx, target_idx in ((0, 1), (2, 2)):
            node.in_port(source_idx).get_connection().set_destination(scatter.in_port(target_idx))

        # Move the original node out of the way so the scatter node can take over its name.
        rename_nodes([(node, original_name + "/AbandonedName"), (scatter, original_name)])

        return [scatter.id]
Пример #5
0
 def extract(cls, node):
     """Tag ``node`` as a ScatterNDUpdate with no extra attributes; return ``cls.enabled``.

     NOTE(review): takes ``cls``, so presumably decorated with ``@classmethod``
     outside this view — confirm.
     """
     extra_attrs = {}
     ScatterNDUpdate.update_node_stat(node, extra_attrs)
     return cls.enabled