def replace_pattern(graph: Graph, match: dict):
    """Insert a TensorIteratorInput in front of the initial-value input of a BackEdge.

    ``match['BackEdge']`` must have exactly three input data nodes. Port 0 is the
    initial value; port 1 (the cycle input) and port 2 (the condition) are not
    touched by this pass. A new TensorIteratorInput is created only when the
    initial-value data node either has no producer, or has a single producer
    that is not already a TensorIteratorInput while the data node itself holds
    a valid value. In that case:

    * the existing producer edge (if any) into the data node is removed;
    * the data node becomes the input of the new TensorIteratorInput;
    * BackEdge port 0 is re-routed to the TensorIteratorInput's output data node.

    :param graph: graph being transformed; modified in place.
    :param match: pattern match dictionary; must contain the 'BackEdge' node.
    """
    log.debug('================== SimpleBackEdgeInputFind ===============')

    back_edge = match['BackEdge']
    assert len(back_edge.in_nodes()) == 3
    init_input = back_edge.in_node(0)

    # Create a new TensorIteratorInput node only if one doesn't exist already.
    num_producers = len(init_input.in_nodes())
    needs_ti_input = num_producers == 0 or (
        num_producers == 1
        and init_input.has_valid('value')
        and init_input.in_node(0).soft_get('op') != 'TensorIteratorInput'
    )
    if not needs_ti_input:
        return

    input_node = TensorIteratorInput(graph, dict(external_port_id=None,
                                                 internal_layer_id=None,
                                                 name=back_edge.name + '/TensorIteratorInput_'))

    # The data node still has a producer (presumably a Constant, since the data
    # node carries a value); disconnect it before feeding the data node into the
    # new TensorIteratorInput.
    if num_producers == 1:
        graph.remove_edge(init_input.in_node(0).id, init_input.id)

    input_data_node = input_node.create_node_with_data(inputs=[init_input])
    input_data_node.shape = int64_array(init_input.shape)

    # Re-route BackEdge port 0 from the original data node to the new output data node.
    graph.remove_edges_from([(init_input.id, back_edge.id)])
    graph.add_edges_from([(input_data_node.id, back_edge.id, {'in': 0, 'out': 0})])
def replace_pattern(graph: Graph, match: dict):
    """Replace the matched 'Enter' node with a TensorIteratorInput node.

    The TensorIteratorInput takes over Enter's input data node and output data
    node. Its external port id and internal layer id are left as ``None``. The
    Enter op node itself is then removed from the graph.

    :param graph: graph being transformed; modified in place.
    :param match: pattern match dictionary; must contain the 'Enter' node.
    """
    log.debug('================== SimpletInputFind ===============')

    enter = match['Enter']
    ti_input = TensorIteratorInput(graph, {
        'external_port_id': None,
        'internal_layer_id': None,
        'name': enter.name + '/TensorIteratorInput_',
    })
    ti_input.create_node_with_data(inputs=[enter.in_node()], data_nodes=[enter.out_node()])

    # The Enter op is now bypassed; drop it.
    graph.remove_nodes_from([enter.id])
def replace_pattern(graph: Graph, match: dict):
    """Replace a TensorArray-based input-slicing subgraph with a TensorIteratorInput.

    The matched constants are validated first:

    * ``Enter_data`` must carry a value;
    * the stack values must be 0, 1 and 1;
    * ``start_data`` must be 0 and ``delta_data`` must be 1.

    Then a TensorIteratorInput node is built with:

    * ``axis=0``;
    * ``start`` set to 0 only when the StridedSlice's shape source is a Const op,
      otherwise ``None``;
    * ``external_port_id=str(Enter_data.value)``;
    * ``internal_layer_id`` set to the id of ``TensorArrayRead_data``.

    Its inputs are the TensorArray size data node, TensorArrayScatter's input 2
    and ``Condition_data``. Its output is ``TensorArrayRead_data``. Every other
    matched node is removed.

    :param graph: graph being transformed; modified in place.
    :param match: pattern match dictionary for the SmartInputFind pattern.
    """
    log.debug('================== SmartInputFind ===============')

    assert match['Enter_data'].value is not None
    assert (match['stack_data']['value'][0] == 0 and match['stack_1_data']['value'][0] == 1
            and match['stack_2_data']['value'][0] == 1)
    assert match['start_data']['value'] == 0 and match['delta_data']['value'] == 1

    ta_size_data = match['TensorArray'].in_node()
    ta_size = ta_size_data.in_node()
    value = match['TensorArrayScatter'].in_node(2)

    start = None
    if 0 in ta_size.in_nodes():
        shape = match['StridedSlice'].in_node(0).in_node(0)
        # Case when value for Strided slice is Const, not Shape
        if shape['kind'] == 'op' and shape['op'] == 'Const':
            start = 0
            # NOTE(review): an `end = shape.value[0]` used to be computed here but
            # was never passed anywhere; confirm whether TensorIteratorInput should
            # receive it (e.g. as an 'end' attribute).
            log.warning(
                "Your network cannot be reshaped since shapes of placeholders are constants. "
                "Please, provide non-constant shapes. ")

    # Create input node with params.
    # axis == 0 because in TensorArray we ALWAYS iterate over axis 0; the other
    # params will be filled in later (with the condition).
    input_node = TensorIteratorInput(
        graph,
        dict(axis=0, start=start, stride=None, part_size=None,
             external_port_id=str(match['Enter_data'].value),
             internal_layer_id=match['TensorArrayRead_data'].id,
             name=match['TensorArrayRead'].name + '/TensorIteratorInput_'))
    input_node.create_node_with_data(
        inputs=[ta_size_data, value, match['Condition_data']],
        data_nodes=[match['TensorArrayRead_data']])

    # Delete every matched node except those still wired to the new input node.
    safe_nodes = {'TensorArrayRead_data', 'Condition', 'Condition_data'}
    graph.remove_nodes_from([match[name].id for name in match if name not in safe_nodes])
def replace_pattern(graph: Graph, match: dict):
    """Collapse the per-iteration input slicing of a loop into one TensorIteratorInput.

    The TensorIteratorInput attributes come from the matched constants:

    * ``start`` is the first element of ``InitIndex``;
    * ``stride`` is the first element of ``IndexDelta``;
    * ``axis`` is the first element of ``Axis``.

    The new node is wired as follows:

    * in_port(0): an extra destination of the connection feeding ``EnterMaxIndex``'s
      in_port(0) (the size input);
    * in_port(1): takes over ``EnterInput``'s in_port(0) connection;
    * in_port(2): connected from ``IdentityIndex``'s out_port(0) (the current index);
    * out_port(0): becomes the source for all consumers of ``SqueezeSlice``'s out_port(0).

    Afterwards the nodes that implemented the slicing are removed.

    :param graph: graph being transformed; modified in place.
    :param match: pattern match dictionary for the input-slicing pattern.
    """
    start = match['InitIndex'].value.item(0)
    stride = match['IndexDelta'].value.item(0)
    slice_axis = match['Axis'].value.item(0)

    squeeze = match['SqueezeSlice']
    squeeze_name = squeeze.soft_get('name', squeeze.id)

    ti_input = TensorIteratorInput(graph, {
        'axis': slice_axis,
        'start': start,
        'stride': stride,
        'name': squeeze_name + '/TensorIteratorInput',
    }).create_node()

    # Re-wiring is done in the same order as the original implementation,
    # since each call mutates the connections of the matched nodes.
    match['EnterMaxIndex'].in_port(0).get_connection().add_destination(ti_input.in_port(0))
    match['EnterInput'].in_port(0).get_connection().set_destination(ti_input.in_port(1))
    match['IdentityIndex'].out_port(0).connect(ti_input.in_port(2))
    squeeze.out_port(0).get_connection().set_source(ti_input.out_port(0))

    # Nodes that used to slice the input in the original graph are no longer needed.
    obsolete = ('EnterInput', 'MergeInput', 'SwitchInput', 'IdentityInput', 'NextIterationInput',
                'SqueezeSlice', 'UnsqueezeIndex', 'Gather')
    graph.remove_nodes_from([match[key].id for key in obsolete])