def test_infer8(self):
    """Full value inference for inputs8 must reproduce output8 on the 'output' node."""
    graph = build_graph(nodes_attributes, edges, inputs8)
    node = Node(graph, 'scatternd_node')

    ScatterNDUpdate.infer(node)

    # Fetch the inferred value and compare element-wise with the reference.
    actual_value = graph.node['output']['value']
    self.assertTrue(np.array_equal(output8, actual_value),
                    'values do not match expected: {} and given: {}'.format(output8, actual_value))
def test_partial_infer4(self):
    """Shape-only (partial) inference for inputs4 must yield shape [10, 40, 50]."""
    graph = build_graph(nodes_attributes, edges, inputs4)
    node = Node(graph, 'scatternd_node')

    ScatterNDUpdate.infer(node)

    # Expected output shape taken from the data-input shape of this test case.
    expected_shape = np.array([10, 40, 50], dtype=np.int32)
    actual_shape = graph.node['output']['shape']
    self.assertTrue(np.array_equal(expected_shape, actual_shape),
                    'values do not match expected: {} and given: {}'.format(expected_shape, actual_shape))
def replace_op(self, graph: Graph, node: Node):
    """Replace the matched node with Broadcast(default value) -> ScatterNDUpdate.

    Port mapping (original node -> new sub-graph):
      in 1 -> Broadcast target shape (in 1)
      in 3 -> Broadcast fill value (in 0); a float32 zero Const when port 3 is absent
      in 0 -> ScatterNDUpdate indices (in 1)
      in 2 -> ScatterNDUpdate updates (in 2)

    Returns a single-element list with the id of the resulting ScatterNDUpdate node.
    """
    original_name = node.soft_get('name', node.id)

    # Materialize the default value over the requested output shape.
    broadcast = Broadcast(graph, {'name': original_name + '/Broadcast_'}).create_node()
    node.in_port(1).get_connection().set_destination(broadcast.in_port(1))

    if node.in_port(3).disconnected():
        # No explicit default value provided: fall back to a zero constant.
        fill_value = Const(graph, {'name': broadcast.name + '/FillValue_',
                                   'value': np.float32(0)}).create_node()
        broadcast.in_port(0).connect(fill_value.out_port(0))
    else:
        node.in_port(3).get_connection().set_destination(broadcast.in_port(0))

    # Scatter the provided updates into the broadcasted tensor at the given indices.
    scatter = ScatterNDUpdate(graph, {'name': original_name + '/ScatterNDUpdate_'}).create_node()
    scatter.in_port(0).connect(broadcast.out_port(0))
    node.in_port(0).get_connection().set_destination(scatter.in_port(1))
    node.in_port(2).get_connection().set_destination(scatter.in_port(2))

    # The replacement node takes over the original node's name.
    rename_nodes([(node, original_name + "/AbandonedName"), (scatter, original_name)])
    return [scatter.id]
def find_and_replace_pattern(self, graph: Graph):
    """Rewrite every TFScatterND node as Broadcast(zero) -> ScatterNDUpdate.

    TFScatterND ports: 0 - indices, 1 - updates, 2 - output shape.
    The zero base tensor is cast to the updates' element type via ConvertLike
    before broadcasting, since the constant itself is created as int64.
    """
    for tf_scatter_nd in graph.get_op_nodes(op='TFScatterND'):
        # All three inputs are required; skip partially connected nodes.
        ports_connected = (tf_scatter_nd.is_in_port_connected(0)
                          and tf_scatter_nd.is_in_port_connected(1)
                          and tf_scatter_nd.is_in_port_connected(2))
        if not ports_connected:
            continue

        name = tf_scatter_nd.soft_get('name', tf_scatter_nd.soft_get('id'))
        indices_port = tf_scatter_nd.in_port(0).get_source()
        updates_port = tf_scatter_nd.in_port(1).get_source()
        shape_port = tf_scatter_nd.in_port(2).get_source()

        # Zero constant (int64); its element type is corrected below.
        zero_const = Const(graph, {'value': int64_array(0.0),
                                   'name': name + '/zero_const'}).create_node()

        # Cast the zero to the same element type as the updates tensor.
        convert_to_type = ConvertLike(graph, {'name': name + '/convert_like'}).create_node()
        convert_to_type.in_port(0).connect(zero_const.out_port(0))
        convert_to_type.in_port(1).connect(updates_port)

        # Expand the zero scalar to the requested output shape.
        broadcast = Broadcast(graph, {'name': name + '/broadcast'}).create_node()
        broadcast.in_port(0).connect(convert_to_type.out_port(0))
        broadcast.in_port(1).connect(shape_port)

        # Scatter the updates into the zero tensor.
        scatter = ScatterNDUpdate(graph, {'name': name + '/replaced'}).create_node()
        scatter.in_port(0).connect(broadcast.out_port(0))
        scatter.in_port(1).connect(indices_port)
        scatter.in_port(2).connect(updates_port)

        # The replacement takes over the original name; reroute consumers.
        rename_nodes([(tf_scatter_nd, name + '/TBD'), (scatter, name)])
        tf_scatter_nd.out_port(0).get_connection().set_source(scatter.out_port(0))

        # Detach the original node from the graph.
        for port_idx in (0, 1, 2):
            tf_scatter_nd.in_port(port_idx).disconnect()
def extract(cls, node):
    """Mark the matched framework node as a ScatterNDUpdate operation."""
    # No extra attributes are extracted from the framework node.
    attrs = {}
    ScatterNDUpdate.update_node_stat(node, attrs)
    return cls.enabled