def replace_op(self, graph: Graph, node: Node):
    """Replace this node with a Broadcast (default-value fill) feeding a ScatterNDUpdate.

    Port layout assumed from the wiring below: port 0 = indices, port 1 = output shape,
    port 2 = values, port 3 = optional default value.
    Returns a list with the id of the resulting ScatterNDUpdate node.
    """
    name = node.soft_get('name', node.id)

    # Broadcast the default value across the requested output shape.
    broadcast = Broadcast(graph, {'name': name + '/Broadcast_'}).create_node()
    node.in_port(1).get_connection().set_destination(broadcast.in_port(1))

    if node.in_port(3).disconnected():
        # No default value supplied: fill the output with zeros.
        fill = Const(graph, {'name': broadcast.name + '/FillValue_',
                             'value': np.float32(0)}).create_node()
        broadcast.in_port(0).connect(fill.out_port(0))
    else:
        # TODO: remove casting once we start to support I64 model input
        # cast default value to I32 due limitation about I64 input support
        # so that input parameter and default value will be of the same I32 type as required ScatterNDUpdate
        cast = Cast(graph, {'name': name + '/CastDefaultValue', 'dst_type': np.int32}).create_node()
        node.in_port(3).get_connection().set_destination(cast.in_port(0))
        broadcast.in_port(0).connect(cast.out_port(0))

    # Scatter the provided values into the broadcasted tensor at the given indices.
    scatternd = ScatterNDUpdate(graph, {'name': name + '/ScatterNDUpdate_'}).create_node()
    scatternd.in_port(0).connect(broadcast.out_port(0))
    node.in_port(0).get_connection().set_destination(scatternd.in_port(1))
    node.in_port(2).get_connection().set_destination(scatternd.in_port(2))

    # Hand the original node's name over to the replacement.
    rename_nodes([(node, name + "/AbandonedName"), (scatternd, name)])
    return [scatternd.id]
def test_infer8(self):
    """Value inference for the inputs8 fixture must reproduce output8."""
    graph = build_graph(nodes_attributes, edges, inputs8)
    node = Node(graph, 'scatternd_node')
    ScatterNDUpdate.infer(node)

    # Fetch the inferred value and compare it against the reference.
    actual = graph.node['output']['value']
    self.assertTrue(np.array_equal(output8, actual),
                    'values do not match expected: {} and given: {}'.format(output8, actual))
def test_partial_infer3(self):
    """Shape inference for the inputs3 fixture must yield shape [20, 30, 5]."""
    graph = build_graph(nodes_attributes, edges, inputs3)
    node = Node(graph, 'scatternd_node')
    ScatterNDUpdate.infer(node)

    # Expected output shape for this fixture.
    expected_shape = np.array([20, 30, 5], dtype=np.int32)

    actual_shape = graph.node['output']['shape']
    self.assertTrue(np.array_equal(expected_shape, actual_shape),
                    'values do not match expected: {} and given: {}'.format(expected_shape, actual_shape))
def replace_op(self, graph: Graph, node: Node):
    """Replace this node with a Broadcast (default-value fill) feeding a ScatterNDUpdate.

    Port layout assumed from the wiring below: port 0 = indices, port 1 = output shape,
    port 2 = values, port 3 = optional default value.
    Returns a list with the id of the resulting ScatterNDUpdate node.
    """
    name = node.soft_get('name', node.id)

    # Broadcast the default value across the requested output shape.
    broadcast = Broadcast(graph, {'name': name + '/Broadcast_'}).create_node()
    node.in_port(1).get_connection().set_destination(broadcast.in_port(1))

    if node.in_port(3).disconnected():
        # No default value supplied: fill the output with zeros.
        fill = Const(graph, {'name': broadcast.name + '/FillValue_',
                             'value': np.float32(0)}).create_node()
        broadcast.in_port(0).connect(fill.out_port(0))
    else:
        # Route the provided default value straight into the Broadcast.
        node.in_port(3).get_connection().set_destination(broadcast.in_port(0))

    # Scatter the provided values into the broadcasted tensor at the given indices.
    scatternd = ScatterNDUpdate(graph, {'name': name + '/ScatterNDUpdate_'}).create_node()
    scatternd.in_port(0).connect(broadcast.out_port(0))
    node.in_port(0).get_connection().set_destination(scatternd.in_port(1))
    node.in_port(2).get_connection().set_destination(scatternd.in_port(2))

    # Hand the original node's name over to the replacement.
    rename_nodes([(node, name + "/AbandonedName"), (scatternd, name)])
    return [scatternd.id]
def extract(cls, node):
    """Mark the node as a ScatterNDUpdate operation; no extra attributes are extracted."""
    attrs = {}
    ScatterNDUpdate.update_node_stat(node, attrs)
    return cls.enabled