def extract(cls, node: Node):
    Shape.update_node_stat(node, {
        'data_type': tf_dtype_extractor(node.pb.attr['out_type'].type, np.int32)
    })
    return cls.enabled
def make_interpolate_reshapeable(interpolate, concat):
    assert interpolate.soft_get('type') == 'Interpolate'
    assert concat.soft_get('type') == 'Concat'

    output_shape = interpolate.out_port(0).data.get_shape()

    interp_axes = [get_canonical_axis_index(output_shape, axis) for axis in Interpolate.get_axes(interpolate)]
    concat_axis = get_canonical_axis_index(output_shape, concat.axis)
    if concat_axis in interp_axes:
        return

    concat_srcs = [port.get_source() for port in concat.in_ports().values() if not port.disconnected()]
    non_interp_concat_srcs = [src for src in concat_srcs if src.node.soft_get('type') != 'Interpolate']
    if len(non_interp_concat_srcs) == 0:
        return

    graph = interpolate.graph
    src = non_interp_concat_srcs[0]

    # take the target sizes from a sibling Concat input at runtime: ShapeOf -> Gather(interp_axes)
    shape = Shape(graph, {'name': src.node.soft_get('name', src.node.id) + '/Shape'}).create_node()
    shape.in_port(0).connect(src)
    gather = create_op_with_const_inputs(graph, Gather,
                                         {1: np.array(interp_axes, dtype=np.int32), 2: int64_array(0)},
                                         {'name': shape.name + '/Gathered'}, shape)
    interpolate.in_port(1).get_connection().set_source(gather.out_port(0))
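# Illustration (not part of the original sources): a minimal NumPy sketch of what the
# ShapeOf -> Gather subgraph inserted above computes at runtime. The shape and axes
# values are invented example inputs, not taken from any particular model.
import numpy as np

alternative_shape = np.array([1, 3, 40, 40], dtype=np.int64)  # shape of the non-Interpolate Concat source
interp_axes = np.array([2, 3], dtype=np.int32)                # axes resized by Interpolate
target_sizes = alternative_shape[interp_axes]                 # Gather(shape, axes, axis=0)
assert list(target_sizes) == [40, 40]  # becomes the data-dependent sizes input of Interpolate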
def replace_pattern(self, graph: Graph, match: dict):
    node = match['op']
    node.is_training = False

    shape = node.in_port(1).data.get_shape()
    assert shape is not None, 'The shape of scale input of the BatchNorm node {} is not defined'.format(node.name)

    # replace the mean and variance inputs with zeros and ones so that BatchNorm only scales and shifts
    bn_mean = Const(graph, {'name': node.name + '/mean', 'value': np.zeros(shape, dtype=np.float32),
                            'override_output_shape': True}).create_node()
    bn_std = Const(graph, {'name': node.name + '/std', 'value': np.ones(shape, dtype=np.float32),
                           'override_output_shape': True}).create_node()
    node.in_port(3).get_connection().set_source(bn_mean.out_port(0))
    node.in_port(4).get_connection().set_source(bn_std.out_port(0))

    # save the original shape
    original_shape = Shape(graph, {'name': node.in_port(0).get_source().node.soft_get('name')}).create_node()
    original_shape.in_port(0).connect(node.in_port(0).get_source())

    # normalize the input with MVN over a 4D view where batch and channels are fused
    mvn = MVN(graph, {'name': node.name + '/mvn_', 'eps': node.soft_get('eps', 1e-6),
                      'override_output_shape': True}).create_node()
    node.in_port(0).get_connection().insert_node(mvn)

    reshape_4d = create_op_node_with_second_input(graph, Reshape, int64_array([1, -1, 0, 0]),
                                                  {'override_output_shape': True,
                                                   'name': node.soft_get('name') + '/fused_batch_and_channels'})
    mvn.in_port(0).get_connection().insert_node(reshape_4d)

    # restore original shape
    reshape_back = Reshape(graph, {'name': mvn.soft_get('name') + '/restore_shape',
                                   'override_output_shape': True}).create_node()
    reshape_back.in_port(1).connect(original_shape.out_port(0))
    mvn.out_port(0).get_connection().insert_node(reshape_back)
def replace_sub_graph(self, graph: Graph, match: dict):
    node = match['mxreshape']

    input_index = 0
    reshape_index = 0
    shape_node = Shape(graph, dict(name=node.id + '/ShapeMXReshape')).create_node()
    shape_node.in_port(0).connect(node.in_port(0).get_source())

    # resolve every entry of the MXNet `dim` attribute into a 1D shape sub-node;
    # `resolve` advances the indices itself, so the loop only bounds the iteration count
    output_dims_nodes = []
    for _ in node.dim:
        if reshape_index < len(node.dim):
            input_index, reshape_index, output_dims_nodes = self.resolve(
                input_index, reshape_index, node.dim, shape_node, output_dims_nodes)

    concat_node = Concat(shape_node.graph, dict(name=shape_node.id + '/ConcatMXReshape_',
                                                axis=0, in_ports_count=len(output_dims_nodes))).create_node()
    for in_port_index, dim_node in enumerate(output_dims_nodes):
        concat_node.in_port(in_port_index).connect(dim_node.out_port(0))

    reshape_node = Reshape(graph, dict(name=node.id + '/Reshape_')).create_node()
    reshape_node.in_port(1).connect(concat_node.out_port(0))

    node.in_port(0).get_connection().set_destination(reshape_node.in_port(0))
    node.out_port(0).get_connection().set_source(reshape_node.out_port(0))
def replace_pattern(graph: Graph, match: dict):
    node = match['conv']
    node_name = node.soft_get('name', node.id)

    # create Reshape before convolution
    # shape = [in_shape[0], in_shape[1]/patch_stride, 1, patch_stride]
    i_shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()
    shape = Cast(graph, {'name': node_name + '/to_float',
                         'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
    i_shape.in_port(0).connect(node.in_port(0).get_source())
    shape.in_port(0).connect(i_shape.out_port(0))

    N, H = node_to_get_shape_value_of_indices(shape, [0]), node_to_get_shape_value_of_indices(shape, [1])

    div = create_op_with_const_inputs(graph, Div, {1: float_array([node.patch_stride])},
                                      {'name': node_name + '/div_stride_h'})
    div.in_port(0).connect(H.out_port(0))

    concat = create_op_with_const_inputs(graph, Concat,
                                         {2: float_array([1]), 3: float_array([node.patch_stride])},
                                         {'name': node_name + '/concat_all_dims', 'in_ports_count': 4, 'axis': 0})
    concat.in_port(0).connect(N.out_port(0))
    concat.in_port(1).connect(div.out_port(0))

    reshape_pattern = Cast(graph, {'name': node_name + '/to_int', 'dst_type': np.int64}).create_node()
    concat.out_port(0).connect(reshape_pattern.in_port(0))

    reshape_in = Reshape(graph, {'name': node_name + '/reshape_in'}).create_node()
    reshape_in.in_port(1).connect(reshape_pattern.out_port(0))

    # create Reshape after Convolution
    reshape_out = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                   {'name': node_name + '/reshape_out'})

    # connect input_reshape_node
    source = node.in_port(0).get_source()
    node.in_port(0).get_connection().set_source(reshape_in.out_port(0))
    reshape_in.in_port(0).connect(source)
    # connect output_reshape_node
    node.out_port(0).get_connection().set_source(reshape_out.out_port(0))
    node.out_port(0).connect(reshape_out.in_port(0))
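# Illustration (not part of the original sources): the reshape pattern that the
# Shape/Div/Concat/Cast chain above assembles, sketched in NumPy with an invented
# input shape and patch stride.
import numpy as np

in_shape = np.array([8, 120], dtype=np.float32)  # [N, H] of the 2D Kaldi input, example values
patch_stride = 40.0                              # hypothetical node.patch_stride
new_shape = np.array([in_shape[0], in_shape[1] / patch_stride, 1, patch_stride]).astype(np.int64)
assert list(new_shape) == [8, 3, 1, 40]  # pattern fed to the Reshape inserted before Convolution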
def replace_sub_graph(self, graph: Graph, match: dict):
    source_connection = match['split'].in_port(0).get_connection()
    source_node = source_connection.get_source().node
    cast_node = match['cast']

    # build Range(0, batch, 1) where batch is taken from the input shape at runtime
    range_node = Range(graph, {'name': source_node.id + '/Range'}).create_node()
    start_node = Const(graph, {'name': range_node.id + '/Start', 'value': int64_array(0)}).create_node()
    step_node = Const(graph, {'name': range_node.id + '/Step', 'value': int64_array(1)}).create_node()

    input_shape_node = Shape(graph, {'name': start_node.id + '/Shape'}).create_node()
    input_shape_node.in_port(0).connect(source_node.out_port(0))

    limit_node_1D = node_to_get_batch_value(input_shape_node)
    limit_node = create_op_node_with_second_input(graph, Squeeze, int64_array([0]),
                                                  {'name': source_node.id + '/batch_0D_value'}, limit_node_1D)

    range_node.in_port(0).connect(start_node.out_port(0))
    range_node.in_port(1).connect(limit_node.out_port(0))
    range_node.in_port(2).connect(step_node.out_port(0))
    cast_node.out_port(0).get_connection().set_source(range_node.out_port(0))

    graph.remove_nodes_from([node.id for node in match.values()])
def replace_sub_graph(self, graph: Graph, match: dict):
    node = match['flatten']
    name = node.soft_get('name', node.id)

    assert node.has_valid('axis'), 'Flatten {} should have `axis` attribute extracted, but it\'s not'.format(name)
    axis = node.axis

    if axis == 0:
        # flatten everything into a single row: [1, total_elements]
        dim = Const(graph, {'value': int64_array([1, -1])}).create_node()
    elif axis == 1:
        # keep the batch dimension, flatten the rest
        dim = Const(graph, {'value': int64_array([0, -1])}).create_node()
    else:
        # general case: first output dim is the product of shape[0:axis], the rest is flattened
        shape = Shape(graph, {'name': name + '/input_shape'}).create_node()

        idxs = list(range(axis)) if axis > 0 else list(range(axis, 0))

        axis_shape_portion = node_to_get_shape_value_of_indices(shape, idxs)
        first_dims = create_op_node_with_second_input(graph, ReduceProd, int64_array([0]), {'keep_dims': True})
        second_dims = Const(graph, {'value': int64_array([-1])}).create_node()

        node.in_port(0).get_source().connect(shape.in_port(0))
        axis_shape_portion.out_port(0).connect(first_dims.in_port(0))

        order_of_dims = [first_dims, second_dims] if axis > 0 else [second_dims, first_dims]

        dim = new_shape_node_from_shape_nodes(order_of_dims)

    reshape_node = Reshape(graph, {'name': node.id + '/Reshape'}).create_node()
    reshape_node.in_port(1).connect(dim.out_port(0))

    node.out_port(0).get_connection().set_source(reshape_node.out_port(0))
    node.in_port(0).get_connection().set_destination(reshape_node.in_port(0))
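# Illustration (not part of the original sources): the Flatten semantics that the
# Reshape built above reproduces, checked in NumPy for the general `axis > 1` branch.
import numpy as np

x = np.zeros((2, 3, 4, 5))
axis = 2
first = int(np.prod(x.shape[:axis]))  # ReduceProd over shape[0:axis] -> 6
flattened = x.reshape(first, -1)      # dim = [first_dims, -1]
assert flattened.shape == (6, 20)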
def find_and_replace_pattern(self, graph: Graph):
    for roll_node in graph.get_op_nodes(op='Roll'):
        if not roll_node.in_port(2).disconnected():
            # this Roll already has an axes input; skip it instead of aborting the whole pass
            continue
        node_name = roll_node.soft_get('name', roll_node.id)

        # reshape to 1d tensor
        reshape_to_1d = create_op_node_with_second_input(graph, Reshape, int64_array([-1]),
                                                         {'name': node_name + '/reshape'})
        roll_node.in_port(0).get_connection().insert_node(reshape_to_1d)

        # add zero const as axes input to roll
        const_zero = Const(graph, {'value': int64_array([0]), 'name': node_name + '/axes'}).create_node()
        const_zero.out_port(0).connect(roll_node.in_port(2))

        # reshape to original shape
        shape_of = Shape(graph, {'name': node_name + '/shape_of'}).create_node()
        roll_node.in_port(0).get_connection().add_destination(shape_of.in_port(0))
        reshape_to_orig_shape = Reshape(graph, {}).create_node()
        rename_nodes([(roll_node, node_name + '/roll'), (reshape_to_orig_shape, node_name)])
        shape_of.out_port(0).connect(reshape_to_orig_shape.in_port(1))
        roll_node.out_port(0).get_connection().insert_node(reshape_to_orig_shape)
def replace_pattern(self, graph: Graph, match: dict): matmul = match['matmul'] reshape = match['reshape'] other_input_port_idx = 0 if match['matmul'].in_port(0).get_source().node.id == match['other_input'].id else 1 shape_source = match['matmul'].in_port(other_input_port_idx).get_source() initial_reshape_pattern = reshape.in_port(1).data.get_value() if len(initial_reshape_pattern) != 2: return reshape_is_A_input = matmul.in_port(0).get_source().node.id == reshape.id if reshape_is_A_input: idx = -1 if matmul.transpose_b else -2 else: idx = -2 if matmul.transpose_a else -1 idx = get_canonical_axis_index(initial_reshape_pattern, idx) shape_name = shape_source.node.soft_get('name', shape_source.node.id) shape = Shape(graph, {'name': shape_name + '/Shape'}).create_node() shape.in_port(0).connect(shape_source) C = node_to_get_shape_value_of_indices(shape, [idx]) N = Const(graph, {'name': shape_name + '/MinusOne', 'value': int64_array([-1])}).create_node() if len(initial_reshape_pattern) == 2: if reshape_is_A_input: reshape_pattern = [C, N] if matmul.transpose_a else [N, C] else: reshape_pattern = [N, C] if matmul.transpose_b else [C, N] new_reshape_pattern = new_shape_node_from_shape_nodes(reshape_pattern) reshape.in_port(1).get_connection().set_source(new_reshape_pattern.out_port(0)) else: return
def make_interpolate_reshape_able(self, interpolate: Node, concat: Node):
    assert interpolate.soft_get('type') == 'Interpolate'
    assert concat.soft_get('type') == 'Concat'
    interp_axes = Interpolate.get_axes(interpolate)
    concat_axis = self.get_concat_axis(concat)

    if concat_axis is None or interp_axes is None or np.any(interp_axes < 0) or concat_axis < 0 \
            or concat_axis in interp_axes:
        # checks that interpolate axes and concat axis are valid and do not intersect
        return

    non_interp_concat_srcs = self.get_non_interpolate_concat_sources(concat)
    if not len(non_interp_concat_srcs):
        # there is no Concat input to take input from
        return

    graph = interpolate.graph
    src = non_interp_concat_srcs[0]

    shape = Shape(graph, {'name': src.node.soft_get('name', src.node.id) + '/Shape'}).create_node()
    shape.in_port(0).connect(src)
    gather = create_op_with_const_inputs(graph, Gather,
                                         {1: np.array(interp_axes, dtype=np.int32), 2: int64_array(0)},
                                         {'name': shape.name + '/Gathered'}, input_node=shape)
    interpolate.in_port(1).get_connection().set_source(gather.out_port(0))
def squeeze_initial_states(graph: Graph, match: dict):
    """
    Squeeze input initial states of recurrent node to 2-D shape.
    """
    hidden_init_port = 5
    cell_init_port = 6

    rnn_layer = match['rnn_layer']
    # Add input ports to rnn_layer
    rnn_layer.add_sequence_of_ports(type='in', rng=range(7))
    rnn_layer_name = rnn_layer.soft_get('name', rnn_layer.id)

    assert hidden_init_port in rnn_layer.in_nodes()
    hidden_size = rnn_layer.hidden_size

    # new shape for the initial states: [batch, hidden_size]
    shape = Shape(graph, dict(name=rnn_layer_name + '/ShapeOf')).create_node()
    rnn_layer.in_port(0).get_source().connect(shape.in_port(0))

    batch = node_to_get_shape_value_of_indices(shape, int64_array([rnn_layer.batch_dim]))
    new_dim = create_op_node_with_second_input(graph, Concat, second_input_value=int64_array([hidden_size]),
                                               op_attrs=dict(name=rnn_layer_name + '/HiddenStateResizeDim',
                                                             in_ports_count=2, axis=0),
                                               input_node=batch)
    reshape_h = Reshape(graph, dict(name=rnn_layer_name + '/HiddenStateResize',
                                    override_output_shape=True)).create_node()
    new_dim.out_port(0).connect(reshape_h.in_port(1))
    rnn_layer.in_port(hidden_init_port).get_connection().insert_node(reshape_h)

    if rnn_layer.op == 'LSTM':
        assert cell_init_port in rnn_layer.in_nodes()
        reshape_c = Reshape(graph, dict(name=rnn_layer_name + '/CellStateResize',
                                        override_output_shape=True)).create_node()
        new_dim.out_port(0).connect(reshape_c.in_port(1))
        rnn_layer.in_port(cell_init_port).get_connection().insert_node(reshape_c)
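# Illustration (not part of the original sources): the 2D target shape the
# Concat([batch, hidden_size]) above produces for the initial-state Reshape.
# All dimensions are invented, and batch_dim = 0 is an assumed layout.
import numpy as np

x = np.zeros((8, 10, 64))    # rnn_layer input 0: [batch, seq, input_size], example
h_init = np.zeros((1, 8, 128))  # 3D initial hidden state, example
batch_dim, hidden_size = 0, 128
batch = np.array(x.shape)[[batch_dim]]            # node_to_get_shape_value_of_indices -> [8]
new_dim = np.concatenate([batch, [hidden_size]])  # Concat -> [8, 128]
assert h_init.reshape(new_dim).shape == (8, 128)  # squeezed 2D initial state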
def replace_pattern(self, graph: Graph, match: dict):
    node = match['pb']
    name = node.soft_get('name', node.id)

    graph.graph['cmd_params'].static_shape = False

    assert len(node.in_ports()) == 2

    begin = Const(graph, {'value': np.array([2], dtype=np.int32), 'name': name + '/ss_begin'}).create_node()
    end = Const(graph, {'value': np.array([4], dtype=np.int32), 'name': name + '/ss_end'}).create_node()
    stride = Const(graph, {'value': np.array([1], dtype=np.int32), 'name': name + '/ss_stride'}).create_node()

    # replace the feature-map input with the [2:4] slice of its shape (the spatial dims in NCHW)
    shape_0 = Shape(graph, {'name': name + '/0_port'}).create_node()
    ss_0 = StridedSlice(graph, {'name': name + '/ss_0_port',
                                'begin_mask': np.array([1], dtype=np.int32),
                                'end_mask': np.array([0], dtype=np.int32),
                                'new_axis_mask': np.array([0], dtype=np.int32),
                                'shrink_axis_mask': np.array([0], dtype=np.int32),
                                'ellipsis_mask': np.array([0], dtype=np.int32)}).create_node()

    shape_0.out_port(0).connect(ss_0.in_port(0))
    begin.out_port(0).connect(ss_0.in_port(1))
    end.out_port(0).connect(ss_0.in_port(2))
    stride.out_port(0).connect(ss_0.in_port(3))

    source = node.in_port(0).get_connection().get_source()
    node.in_port(0).disconnect()
    source.connect(shape_0.in_port(0))
    ss_0.out_port(0).connect(node.in_port(0))

    # the same for the image input on port 1
    shape_1 = Shape(graph, {'name': name + '/1_port'}).create_node()
    ss_1 = StridedSlice(graph, {'name': name + '/ss_1_port',
                                'begin_mask': np.array([1], dtype=np.int32),
                                'end_mask': np.array([0], dtype=np.int32),
                                'new_axis_mask': np.array([0], dtype=np.int32),
                                'shrink_axis_mask': np.array([0], dtype=np.int32),
                                'ellipsis_mask': np.array([0], dtype=np.int32)}).create_node()

    shape_1.out_port(0).connect(ss_1.in_port(0))
    begin.out_port(0).connect(ss_1.in_port(1))
    end.out_port(0).connect(ss_1.in_port(2))
    stride.out_port(0).connect(ss_1.in_port(3))

    source = node.in_port(1).get_connection().get_source()
    node.in_port(1).disconnect()
    source.connect(shape_1.in_port(0))
    ss_1.out_port(0).connect(node.in_port(1))

    ss_0['force_precision_in_ports'] = {1: 'int64', 2: 'int64', 3: 'int64'}
    ss_1['force_precision_in_ports'] = {1: 'int64', 2: 'int64', 3: 'int64'}

    node['need_shape_inference'] = True
    node['override_output_shape'] = True
    node['V10_infer'] = True

    unsqueeze = create_op_node_with_second_input(graph, Unsqueeze, int64_array([0]), {'name': name + '/unsqueeze'})
    naked_priorbox_name = name + '/naked_not_unsqueezed'
    rename_nodes([(node, naked_priorbox_name), (unsqueeze, name)])

    node.out_port(0).get_connection().set_source(unsqueeze.out_port(0))
    node.out_port(0).connect(unsqueeze.in_port(0))
def replace_pattern(self, graph: Graph, match: dict):
    node = match['pb']
    assert len(node.in_ports()) == 2

    begin = Const(graph, {'value': np.array([2])}).create_node()
    end = Const(graph, {'value': np.array([4])}).create_node()
    stride = Const(graph, {'value': np.array([1])}).create_node()

    shape_0 = Shape(graph, {'name': node.name + '/0_port', 'stop_value_propagation': True}).create_node()
    ss_0 = StridedSlice(graph, {'name': node.name + '/ss_0_port',
                                'begin_mask': np.array([1]),
                                'end_mask': np.array([0]),
                                'new_axis_mask': np.array([0]),
                                'shrink_axis_mask': np.array([0]),
                                'ellipsis_mask': np.array([0])}).create_node()

    shape_0.out_port(0).connect(ss_0.in_port(0))
    begin.out_port(0).connect(ss_0.in_port(1))
    end.out_port(0).connect(ss_0.in_port(2))
    stride.out_port(0).connect(ss_0.in_port(3))

    source = node.in_port(0).get_connection().get_source()
    node.in_port(0).disconnect()
    source.connect(shape_0.in_port(0))
    ss_0.out_port(0).connect(node.in_port(0))

    shape_1 = Shape(graph, {'name': node.name + '/1_port', 'stop_value_propagation': True}).create_node()
    ss_1 = StridedSlice(graph, {'name': node.name + '/ss_1_port',
                                'begin_mask': np.array([1]),
                                'end_mask': np.array([0]),
                                'new_axis_mask': np.array([0]),
                                'shrink_axis_mask': np.array([0]),
                                'ellipsis_mask': np.array([0])}).create_node()

    shape_1.out_port(0).connect(ss_1.in_port(0))
    begin.out_port(0).connect(ss_1.in_port(1))
    end.out_port(0).connect(ss_1.in_port(2))
    stride.out_port(0).connect(ss_1.in_port(3))

    source = node.in_port(1).get_connection().get_source()
    node.in_port(1).disconnect()
    source.connect(shape_1.in_port(0))
    ss_1.out_port(0).connect(node.in_port(1))

    ss_0['force_precision_in_ports'] = {1: 'int64', 2: 'int64', 3: 'int64'}
    ss_1['force_precision_in_ports'] = {1: 'int64', 2: 'int64', 3: 'int64'}
def find_and_replace_pattern(self, graph: Graph):
    for node in graph.get_op_nodes(op='SpaceToBatch') + graph.get_op_nodes(op='BatchToSpace'):
        node.add_input_port(3, skip_if_exist=True)

        # convert TF representation of the pads/crops as [N, 2] to IE representation: [N] and [N]
        transposed_pads = create_op_with_const_inputs(graph, Transpose, {1: int64_array([1, 0])})
        node.in_port(2).get_connection().set_destination(transposed_pads.in_port(0))
        split_pads = create_op_with_const_inputs(graph, Split, {1: int64_array(0)}, {'num_splits': 2})
        transposed_pads.out_port(0).connect(split_pads.in_port(0))
        for port_ind in range(2):
            node.in_port(port_ind + 2).connect(split_pads.out_port(port_ind))
            node.in_port(port_ind + 2).get_connection().insert_node(
                create_op_with_const_inputs(graph, Squeeze, {1: int64_array([0])}))

        # add zeros/ones to related inputs to align it with data input
        in0_rank = Rank(graph, {'name': node.name + '/rank_0'}).create_node()
        in1_rank = Shape(graph, {'name': node.name + '/rank_1'}).create_node()

        diff_size = Sub(graph, {'name': node.name + '/sub_0'}).create_node()
        diff = Sub(graph, {'name': node.name + '/sub_1'}).create_node()

        const_begin = Const(graph, {'value': int64_array([1])}).create_node()
        const_pad_val = Const(graph, {'value': int64_array([1])}).create_node()

        block_shape = Pad(graph, {'name': node.name + '/aligned_block_shape', 'mode': 'constant'}).create_node()

        # in case of SpaceToBatch begin = pads_begin, end = pads_end
        # in case of BatchToSpace begin = crops_begin, end = crops_end
        new_begin_name = '/aligned_pads_begin'
        new_end_name = '/aligned_pads_end'
        if node.type == 'BatchToSpace':
            new_begin_name = '/aligned_crops_begin'
            new_end_name = '/aligned_crops_end'

        begin = Pad(graph, {'name': node.name + new_begin_name, 'mode': 'constant'}).create_node()
        end = Pad(graph, {'name': node.name + new_end_name, 'mode': 'constant'}).create_node()

        in0_rank_1d = create_op_node_with_second_input(graph, Unsqueeze, int64_array([0]),
                                                       {'name': node.name + '/1d_rank_of_0'}, in0_rank)
        in1_rank_1d = create_op_node_with_second_input(graph, Unsqueeze, int64_array([0]),
                                                       {'name': node.name + '/1d_rank_of_1'}, in1_rank)

        node.in_port(0).get_source().connect(in0_rank.in_port(0))
        node.in_port(1).get_source().connect(in1_rank.in_port(0))
        in0_rank_1d.out_port(0).connect(diff_size.in_port(0))
        in1_rank_1d.out_port(0).connect(diff_size.in_port(1))
        diff_size.out_port(0).connect(diff.in_port(0))
        const_begin.out_port(0).connect(diff.in_port(1))
        const_pad_val.out_port(0).connect(block_shape.in_port(3))

        inputs_array = [block_shape, begin, end]
        for idx, input_to_node in enumerate(inputs_array):
            node.in_port(idx + 1).get_connection().set_destination(input_to_node.in_port(0))
            const_begin.out_port(0).connect(input_to_node.in_port(1))
            diff.out_port(0).connect(input_to_node.in_port(2))
            input_to_node.out_port(0).connect(node.in_port(idx + 1))
def placeholder_scales(self, placeholder: Node):
    """
    Helper function to get scales for prior boxes out of input image size:
            [1 / im_width, 1 / im_height, 1 / im_width, 1 / im_height]
    """
    graph = placeholder.graph
    name = placeholder.soft_get('name', placeholder.id)

    shape_value = placeholder.soft_get('shape', None)
    assert shape_value is not None, \
        "[ {} replacer ] Placeholder `{}` should have shape attribute".format(self.replacement_id, name)
    assert isinstance(shape_value, np.ndarray), \
        "[ {} replacer ] Placeholder `{}` shape attribute should be np.ndarray".format(self.replacement_id, name)
    assert shape_value.size == 4, \
        "[ {} replacer ] Placeholder `{}` should be 4D. Shape: {}".format(self.replacement_id, name, shape_value)

    shape = Shape(graph, {'name': 'input_image_shape'}).create_node()
    shape.in_port(0).connect(placeholder.out_port(0))

    begin = Const(graph, {'value': int64_array([1])}).create_node()
    end = Const(graph, {'value': int64_array([3])}).create_node()
    stride = Const(graph, {'value': int64_array([1])}).create_node()
    spatial = StridedSlice(graph, {'name': name + '/get_h_w',
                                   'begin_mask': np.array([1]),
                                   'end_mask': np.array([1]),
                                   'new_axis_mask': np.array([0]),
                                   'shrink_axis_mask': np.array([0]),
                                   'ellipsis_mask': np.array([0])}).create_node()

    spatial.in_port(0).connect(shape.out_port(0))
    spatial.in_port(1).connect(begin.out_port(0))
    spatial.in_port(2).connect(end.out_port(0))
    spatial.in_port(3).connect(stride.out_port(0))

    power = Const(graph, {'value': float32_array([-1.])}).create_node()
    spatial_scale = Pow(graph, {}).create_node()

    spatial_scale.in_port(0).connect(spatial.out_port(0))
    spatial_scale.in_port(1).connect(power.out_port(0))

    # Power `type_infer` requires inputs to have equal data type
    convert_to_fp32 = Cast(graph, {'dst_type': np.float32}).create_node()
    spatial_scale.in_port(0).get_connection().insert_node(convert_to_fp32)

    order = Const(graph, {'value': int64_array([1, 0])}).create_node()
    axis_const = Const(graph, {'value': int64_array(0)}).create_node()
    reverse = Gather(graph, {}).create_node()

    reverse.in_port(0).connect(spatial_scale.out_port(0))
    reverse.in_port(1).connect(order.out_port(0))
    axis_const.out_port(0).connect(reverse.in_port(2))

    priors_scale_node = Concat(graph, {'axis': 0, 'in_ports_count': 2}).create_node()
    priors_scale_node.add_input_port(0, skip_if_exist=True)
    priors_scale_node.add_input_port(1, skip_if_exist=True)

    priors_scale_node.in_port(0).connect(reverse.out_port(0))
    priors_scale_node.in_port(1).connect(reverse.out_port(0))

    return priors_scale_node
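# Illustration (not part of the original sources): the scale vector the subgraph
# above produces, sketched in NumPy from a hypothetical NHWC placeholder shape.
import numpy as np

image_shape = np.array([1, 300, 600, 3])       # example NHWC input shape: H=300, W=600
spatial = image_shape[1:3].astype(np.float32)  # StridedSlice [1:3] -> [H, W]
scale = np.power(spatial, -1.0)                # Pow(-1) -> [1/H, 1/W]
reverse = scale[[1, 0]]                        # Gather with order [1, 0] -> [1/W, 1/H]
scales = np.concatenate([reverse, reverse])    # Concat -> [1/W, 1/H, 1/W, 1/H]
assert np.allclose(scales, [1 / 600, 1 / 300, 1 / 600, 1 / 300])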
def replace_interpolate_pattern(graph: Graph, match: dict):
    split = match['split']
    scale = np.array([get_split_scale(split)], dtype=np.float32)
    axis = int(split.in_port(1).get_connection().get_source().node.value)
    split_node_name = split.name
    axis_node = Const(graph, {'name': split_node_name + '/axis', 'value': int64_array([axis])}).create_node()

    shape_node = Shape(graph, dict(name=split_node_name + '/Shape')).create_node()
    scales_node = Const(graph, dict(name=split_node_name + '/scales', value=scale)).create_node()
    mul_node = Mul(graph, dict(name=split_node_name + '/Mul')).create_node()
    scales_node.out_port(0).connect(mul_node.in_port(1))

    strided_slice_node = create_op_with_const_inputs(graph, StridedSlice,
                                                     {1: int64_array([axis]), 2: int64_array([axis + 1])},
                                                     {'name': split_node_name + '/StridedSlice',
                                                      'begin_mask': int64_array([1]),
                                                      'end_mask': int64_array([1]),
                                                      'new_axis_mask': int64_array([0]),
                                                      'shrink_axis_mask': int64_array([0]),
                                                      'ellipsis_mask': int64_array([0])})
    shape_node.out_port(0).connect(strided_slice_node.in_port(0))

    cast_shape_to_float = Cast(graph, {'dst_type': np.float32}).create_node()

    strided_slice_node.out_port(0).connect(cast_shape_to_float.in_port(0))
    cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))

    interp_node = Interpolate(graph,
                              dict(name=split_node_name + '/Interpolate',
                                   mode='nearest', antialias=0,
                                   pads_begin=int64_array([0]), pads_end=int64_array([0]),
                                   coordinate_transformation_mode='half_pixel',
                                   nearest_mode='round_prefer_floor', cube_coeff=-0.75,
                                   version='opset4', shape_calculation_mode='scales',
                                   in_ports_count=4, maybe_part_of_sequence=True)).create_node()

    floor_node = Floor(graph, {'name': split_node_name + '/Floor'}).create_node()
    cast_mul_result_to_int = Cast(graph, {'dst_type': np.int64}).create_node()

    mul_node.out_port(0).connect(floor_node.in_port(0))
    floor_node.out_port(0).connect(cast_mul_result_to_int.in_port(0))
    cast_mul_result_to_int.out_port(0).connect(interp_node.in_port(1))
    scales_node.out_port(0).connect(interp_node.in_port(2))
    axis_node.out_port(0).connect(interp_node.in_port(3))

    match['concat'].out_port(0).get_connection().set_source(interp_node.out_port(0))

    split_connection = split.in_port(0).get_connection()
    split_connection.set_destination(interp_node.in_port(0))
    split_connection.get_source().connect(shape_node.in_port(0))
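# Illustration (not part of the original sources): how the Mul/Floor/Cast chain above
# derives the sizes input of opset4 Interpolate, with invented example values.
import numpy as np

input_shape = np.array([1, 3, 25, 25])
axis, scale = 2, np.float32(2.0)                     # example axis and get_split_scale result
dim = input_shape[axis:axis + 1].astype(np.float32)  # StridedSlice + Cast -> [25.]
sizes = np.floor(dim * scale).astype(np.int64)       # Mul -> Floor -> Cast -> [50]
assert list(sizes) == [50]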
def replace_pattern(self, graph: Graph, match: dict):
    if not self.is_applicable(match):
        return

    unsqueeze_node = match['unsqueeze']
    unsqueeze_name = unsqueeze_node.soft_get('name', unsqueeze_node.id)
    second_input_of_unsqueeze = unsqueeze_node.in_port(1).get_connection().get_source().node
    d_idx = int(second_input_of_unsqueeze.value)
    axis = d_idx - 1

    shape_node = Shape(graph, dict(name=unsqueeze_name + '/Shape')).create_node()
    axis_len_node = node_to_get_shape_value_of_indices(shape_node, [axis])

    second_input_of_tile = match['tile'].in_port(1).get_connection().get_source().node
    scale = int64_array([second_input_of_tile.value[d_idx]])
    float_scale = float32_array([second_input_of_tile.value[d_idx]])
    mul_node = create_op_with_const_inputs(graph, Mul, {1: scale}, {'name': unsqueeze_name + '/Mul'})

    axis_len_node.out_port(0).connect(mul_node.in_port(0))

    interp_node = create_op_with_const_inputs(graph, Interpolate,
                                              {2: float_scale, 3: int64_array([axis])},
                                              {'mode': 'nearest', 'antialias': 0,
                                               'pads_begin': int64_array([0]), 'pads_end': int64_array([0]),
                                               'coordinate_transformation_mode': 'half_pixel',
                                               'nearest_mode': 'round_prefer_floor', 'cube_coeff': -0.75,
                                               'version': 'opset4', 'shape_calculation_mode': 'scales',
                                               'in_ports_count': 4, 'maybe_part_of_sequence': True})
    mul_node.out_port(0).connect(interp_node.in_port(1))

    reshape_node = match['reshape']
    reshape_node.out_port(0).get_connection().set_source(interp_node.out_port(0))
    reshape_name = reshape_node.soft_get('name', reshape_node.id)
    rename_nodes([(reshape_node, reshape_name + '/delete'), (interp_node, reshape_name)])

    unsqueeze_connection = unsqueeze_node.in_port(0).get_connection()
    unsqueeze_connection.set_destination(interp_node.in_port(0))
    unsqueeze_connection.get_source().connect(shape_node.in_port(0))
def replace_sub_graph(self, graph: Graph, match: dict):
    node = match['op']
    name = node.soft_get('name', node.id)

    assert node.has_valid('output_type'), \
        'Size node should have `output_type` attribute, but it is not set for node {}'.format(name)

    # Size == number of elements == ReduceProd over the full ShapeOf vector
    shape = Shape(graph, {'name': name + '/Shape/', 'output_type': node.output_type}).create_node()
    node.in_port(0).get_connection().set_destination(shape.in_port(0))
    reduce_prod = create_op_node_with_second_input(graph, ReduceProd, int64_array([0]),
                                                   {'name': shape.name + 'ReduceProd/', 'keep_dims': False},
                                                   shape)
    node.out_port(0).get_connection().set_source(reduce_prod.out_port(0))
    rename_nodes([(node, name + '/ToBeDeleted'), (reduce_prod, name)])
def replace_pattern(graph: Graph, match: dict):
    node = match['pool']

    if node.pool_step is None:
        node.stride = int64_array([1, 1, node.window[-1], node.window[-1]])

    # create Reshape before pooling
    # shape = [in_shape[0], in_shape[1]/patch_stride, 1, patch_stride]
    shape = Shape(graph, {}).create_node()
    shape.in_port(0).connect(node.in_port(0).get_source())

    split = create_op_with_const_inputs(graph, VariadicSplit, {1: int64_array(0), 2: int64_array([1, -1])},
                                        {'out_ports_count': 2}, shape)
    node_pool_stride = Const(graph, {'value': int64_array([node.pool_stride])}).create_node()
    pow_node = create_op_node_with_second_input(graph, Pow, int64_array([-1]))
    pow_node.in_port(0).connect(node_pool_stride.out_port(0))

    mul = Mul(graph, {}).create_node()
    mul.in_port(0).connect(split.out_port(1))
    mul.in_port(1).connect(pow_node.out_port(0))

    const_1 = Const(graph, {'value': int64_array([1])}).create_node()

    concat = Concat(graph, {'in_ports_count': 4, 'axis': 0}).create_node()
    concat.in_port(0).connect(split.out_port(0))
    concat.in_port(3).connect(mul.out_port(0))
    concat.in_port(2).connect(const_1.out_port(0))
    concat.in_port(1).connect(node_pool_stride.out_port(0))

    reshape_in = Reshape(graph, {'name': '/Reshape/' + node.name}).create_node()
    reshape_in.in_port(1).connect(concat.out_port(0))

    # create Reshape after pooling
    reshape_out = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                   {'name': node.name + '/Reshape/'})

    # connect input_reshape_node
    source = node.in_port(0).get_source()
    node.in_port(0).get_connection().set_source(reshape_in.out_port(0))
    reshape_in.in_port(0).connect(source)
    # connect output_reshape_node
    node.out_port(0).get_connection().set_source(reshape_out.out_port(0))
    node.out_port(0).connect(reshape_out.in_port(0))
def create_zero_value_with_batch_from_input(input_out_port: Port, second_dim, precision=float):
    # create init_graph connected to ReadValue
    # NOTE: the default used to be `np.float`, which newer NumPy versions removed;
    # the builtin `float` is its exact equivalent
    graph = input_out_port.node.graph
    input_name = input_out_port.node.name

    shape_of_input = Shape(graph, {'name': 'shape/' + input_name}).create_node()
    shape_of_input.in_port(0).connect(input_out_port)

    dim_for_get_batch = Const(graph, {'name': 'dim/crop_batch/' + shape_of_input.name,
                                      'value': int64_array([1]), 'shape': int64_array([1])}).create_node()
    get_batch = Crop(graph, {'name': 'crop_batch/' + shape_of_input.name,
                             'axis': int64_array([0]), 'offset': int64_array([0])}).create_node()
    get_batch.in_port(0).connect(shape_of_input.out_port(0))
    get_batch.in_port(1).connect(dim_for_get_batch.out_port(0))

    mem_shape_2nd_dim = Const(graph, {'name': 'gifo_r_weights_shape/' + input_name,
                                      'value': int64_array([second_dim]),
                                      'shape': int64_array([1])}).create_node()
    mem_shape = Concat(graph, {'name': 'gather_memory_shape/' + input_name,
                               'axis': 0, 'in_ports_count': 2}).create_node()
    mem_shape.in_port(0).connect(get_batch.out_port(0))
    mem_shape.in_port(1).connect(mem_shape_2nd_dim.out_port(0))

    fill_value = Const(graph, {'name': 'fill_value/' + input_name,
                               'value': np.array([0.0], precision), 'shape': int64_array([1])}).create_node()
    init_value_prev_lstm_output = Broadcast(graph, {'name': 'init_value/' + input_name}).create_node()
    init_value_prev_lstm_output.in_port(0).connect(fill_value.out_port(0))
    init_value_prev_lstm_output.in_port(1).connect(mem_shape.out_port(0))

    return init_value_prev_lstm_output
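# Illustration (not part of the original sources): the value the ShapeOf/Crop/Concat/
# Broadcast subgraph above materializes at runtime, sketched in NumPy with example sizes.
import numpy as np

x = np.zeros((4, 77))  # tensor behind input_out_port, example
second_dim = 128       # e.g. an LSTM hidden size
batch = np.array(x.shape)[0:1]                             # Crop of ShapeOf -> [4]
mem_shape = np.concatenate([batch, [second_dim]])          # Concat -> [4, 128]
init_value = np.broadcast_to(np.array([0.0]), mem_shape)   # Broadcast of the fill value
assert init_value.shape == (4, 128) and not init_value.any()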
def resolve_minus2(self, shape_node, input_index, reshape_index, dims):
    # -2 in an MXNet reshape copies all remaining input dimensions to the output,
    # so take the tail of the input shape from `input_index` to the last axis
    rank_node = Shape(shape_node.graph,
                      dict(name=shape_node.id + '/RankShapeMXReshapeMinus2')).create_node()
    rank_node.in_port(0).connect(shape_node.out_port(0))
    shape_values_node = get_shape_values_by_range_idxs(shape=shape_node, rank=rank_node,
                                                       begin=input_index, end=-1,
                                                       include_begin=True, include_end=True)
    input_index = None
    reshape_index = reshape_index + 1
    return input_index, reshape_index, dims, shape_values_node
def replace_op(self, graph: Graph, node: Node):
    shape = Shape(graph, {'name': node.name + '/Shape/'}).create_node()
    reduce_prod = ReduceProd(graph, {'name': shape.name + 'ReduceProd/', 'keep_dims': False}).create_node()
    reduce_axis = Const(graph, {'value': int64_array([0])}).create_node()

    # Connect nodes
    node.in_port(0).get_connection().set_destination(shape.in_port(0))
    reduce_prod.in_port(0).get_connection().set_source(shape.out_port(0))
    reduce_prod.in_port(1).get_connection().set_source(reduce_axis.out_port(0))

    # The "explicit" version of the return value is: [(reduce_prod.id, 0)]
    return [reduce_prod.id]
def replace_sub_graph(self, graph: Graph, match: dict):
    if not check_applicability(match):
        return

    reshape = match['reshape']
    div_name = match['division'].name

    input_shape = Shape(graph, dict(name=div_name + '/shape/MVN_T_')).create_node()
    shape_of_reshape = reshape.in_port(1).get_connection().get_source().node.value
    c1, c2 = shape_of_reshape[1], shape_of_reshape[2]
    c = c1 * c2

    new_reshape = create_op_node_with_second_input(graph, Reshape, int64_array([0, 0, 0, c1, c2]),
                                                   dict(name=div_name + '/first_reshape/MVN_T_'))
    permute_order = int64_array([0, 1, 2, 4, 3])
    first_permute = create_op_node_with_second_input(graph, Transpose, permute_order,
                                                     dict(name=div_name + '/first_permute/MVN_T_'), new_reshape)

    add = match['add']
    variance = match['variance']
    eps_port_num = 0 if add.in_port(0).get_connection().get_source().node.id != variance.id else 1
    eps = add.in_port(eps_port_num).get_connection().get_source().node
    mvn_node = create_op_with_const_inputs(graph, MVN, {1: int64_array([1, 2, 3])},
                                           dict(name=div_name + '/MVN/MVN_T_', eps=eps.value,
                                                normalize_variance=1, eps_mode='inside_sqrt'))
    first_permute.out_port(0).connect(mvn_node.in_port(0))

    second_permute = create_op_node_with_second_input(graph, Transpose, permute_order,
                                                      dict(name=div_name + '/second_permute/MVN_T_'), mvn_node)
    new_reshape2 = Reshape(graph, dict(name=div_name + '/second_reshape/MVN_T_')).create_node()
    second_permute.out_port(0).connect(new_reshape2.in_port(0))

    gamma_val = np.reshape(match['gamma_identity'].in_port(0).get_connection().get_source().node.value,
                           int64_array([1, 1, 1, c]))
    new_mul = create_op_node_with_second_input(graph, Mul, gamma_val,
                                               dict(name=match['mul'].name + '/MVN_T_'), new_reshape2)
    beta_val = np.reshape(match['beta_identity'].in_port(0).get_connection().get_source().node.value,
                          int64_array([1, 1, 1, c]))
    new_add2 = create_op_node_with_second_input(graph, Add, beta_val,
                                                dict(name=match['add2'].name + '/MVN_T_'), new_mul)

    transpose_connection = match['transpose'].in_port(0).get_connection()
    before_transpose = transpose_connection.get_source().node
    transpose_connection.set_destination(new_reshape.in_port(0))
    input_shape.out_port(0).connect(new_reshape2.in_port(1))
    before_transpose.out_port(0).connect(input_shape.in_port(0))
    match['transpose2'].out_port(0).get_connection().set_source(new_add2.out_port(0))
def replace_interpolate_pattern(graph: Graph, match: dict):
    split = match['split']
    scale = int64_array([get_split_scale(split)])
    axis = int(split.in_port(1).get_connection().get_source().node.value)
    split_node_name = split.name

    shape_node = Shape(graph, dict(name=split_node_name + '/Shape_')).create_node()
    scales_node = Const(graph, dict(name=split_node_name + '/scales_', value=scale)).create_node()
    mul_node = Mul(graph, dict(name=split_node_name + '/Mul_')).create_node()
    scales_node.out_port(0).connect(mul_node.in_port(1))

    slice_begin = Const(graph, dict(name=split_node_name + '/slice_begin_',
                                    value=int64_array([axis]))).create_node()
    slice_end = Const(graph, dict(name=split_node_name + '/slice_end_',
                                  value=int64_array([axis + 1]))).create_node()

    strided_slice_node = StridedSlice(graph,
                                      {'name': split_node_name + '/StridedSlice_',
                                       'begin_mask': int64_array([1]),
                                       'end_mask': int64_array([1]),
                                       'new_axis_mask': int64_array([0]),
                                       'shrink_axis_mask': int64_array([0]),
                                       'ellipsis_mask': int64_array([0])}).create_node(
        [shape_node, slice_begin, slice_end])
    strided_slice_node.out_port(0).connect(mul_node.in_port(0))

    interp_node = Interpolate(graph, dict(name=split_node_name + '/Interpolate_',
                                          axes=int64_array([axis]), mode='nearest')).create_node()
    mul_node.out_port(0).connect(interp_node.in_port(1))

    match['concat'].out_port(0).get_connection().set_source(interp_node.out_port(0))

    split_connection = split.in_port(0).get_connection()
    split_connection.set_destination(interp_node.in_port(0))
    split_connection.get_source().connect(shape_node.in_port(0))
def append_variances(priors_scale_node: Node, variance: list):
    graph = priors_scale_node.graph
    name = priors_scale_node.name

    sp_shape = Shape(graph, {'name': name + '/shape'}).create_node()
    priors_scale_node.out_port(0).connect(sp_shape.in_port(0))

    begin = Const(graph, {'value': np.array([-2])}).create_node()
    end = Const(graph, {'value': np.array([-1])}).create_node()
    stride = Const(graph, {'value': np.array([1])}).create_node()
    shape_part_for_tiling = StridedSlice(graph, {'name': name + '/get_-2_dim',
                                                 'begin_mask': np.array([1]),
                                                 'end_mask': np.array([1]),
                                                 'new_axis_mask': np.array([0]),
                                                 'shrink_axis_mask': np.array([0]),
                                                 'ellipsis_mask': np.array([0])}).create_node()

    sp_shape.out_port(0).connect(shape_part_for_tiling.in_port(0))
    begin.out_port(0).connect(shape_part_for_tiling.in_port(1))
    end.out_port(0).connect(shape_part_for_tiling.in_port(2))
    stride.out_port(0).connect(shape_part_for_tiling.in_port(3))

    concat_value = Const(graph, {'value': np.array([4])}).create_node()
    shape_concat = Concat(graph, {'name': name + '/shape_for_tiling',
                                  'in_ports_count': 2, 'axis': np.array(0)}).create_node()
    shape_part_for_tiling.out_port(0).connect(shape_concat.in_port(0))
    concat_value.out_port(0).connect(shape_concat.in_port(1))

    variance = Const(graph, {'name': name + '/variance', 'value': np.array(variance)}).create_node()
    tile = Broadcast(graph, {'name': name + '/variance_tile'}).create_node()
    variance.out_port(0).connect(tile.in_port(0))
    shape_concat.out_port(0).connect(tile.in_port(1))

    reshape_dim = Const(graph, {'value': int64_array([-1, 4])}).create_node()
    sp_reshape = Reshape(graph, {'name': name + '/reshape'}).create_node()
    sp_reshape.in_port(0).connect(priors_scale_node.out_port(0))
    sp_reshape.in_port(1).connect(reshape_dim.out_port(0))

    concat = Concat(graph, {'name': name + '/priors_concat',
                            'axis': np.array(0), 'in_ports_count': 2}).create_node()
    sp_reshape.out_port(0).connect(concat.in_port(0))
    tile.out_port(0).connect(concat.in_port(1))

    output_dims = Const(graph, {'value': int64_array([1, 2, -1])}).create_node()
    output_node = Reshape(graph, {'name': name + '/3D_priors_wth_variances'}).create_node()
    concat.out_port(0).connect(output_node.in_port(0))
    output_dims.out_port(0).connect(output_node.in_port(1))

    return output_node
def make_interpolate_reshapeable(interpolate):
    assert interpolate.soft_get('type') == 'Interpolate'

    axes = Interpolate.get_axes(interpolate)
    input_shape = interpolate.in_port(0).data.get_shape()
    output_shape = interpolate.out_port(0).data.get_shape()

    # only handle pure up- or downsampling where one shape evenly divides the other
    if not np.all(np.remainder(output_shape, input_shape) == 0) and \
            not np.all(np.remainder(input_shape, output_shape) == 0):
        return

    graph = interpolate.graph
    name = interpolate.soft_get('name', interpolate.id)

    shape = Shape(graph, {'name': name + '/ShapeOf'}).create_node()
    shape.in_port(0).connect(interpolate.in_port(0).get_source())
    gather = create_op_with_const_inputs(graph, Gather,
                                         {1: np.array(axes, dtype=np.int32), 2: int64_array(0)},
                                         {'name': shape.name + '/Gathered'}, shape)
    multipliers = output_shape[axes] / input_shape[axes]
    mul = create_op_node_with_second_input(graph, Mul, multipliers, {'name': gather.name + '/Multiplied'}, gather)
    interpolate.in_port(1).get_connection().set_source(mul.out_port(0))
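# Illustration (not part of the original sources): the divisibility guard and the
# multipliers computed above, checked in NumPy for an invented 2x upsampling case.
import numpy as np

input_shape = np.array([1, 3, 20, 20])
output_shape = np.array([1, 3, 40, 40])
axes = np.array([2, 3])
assert np.all(np.remainder(output_shape, input_shape) == 0)  # guard passes
multipliers = output_shape[axes] / input_shape[axes]         # [2., 2.]
sizes = input_shape[axes] * multipliers                      # Gather(ShapeOf) * multipliers
assert list(sizes) == [40.0, 40.0]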
def replace_sub_graph(self, graph: Graph, match: dict):
    mxreshape = match['op']
    if not mxreshape.reverse:
        return

    shape_node = Shape(graph, dict(name=mxreshape.id + '/Shape')).create_node()
    forward_reverse_unsqueeze_node = create_op_node_with_second_input(
        graph, Unsqueeze, int64_array([0]), dict(name=str(mxreshape.id) + '/ForwardUnsqueeze'))
    forward_reverse_node = Reverse(graph, dict(name=mxreshape.id + '/ForwardReverse', axis=1)).create_node()
    forward_reverse_squeeze_node = create_op_node_with_second_input(
        graph, Squeeze, int64_array([0]), dict(name=str(mxreshape.id) + '/ForwardSqueeze'))
    reshape_node = Reshape(graph, dict(name=mxreshape.id + '/Reshape')).create_node()
    shape_node.in_port(0).connect(mxreshape.in_port(0).get_source())
    mxreshape.in_port(0).get_connection().set_destination(reshape_node.in_port(0))

    forward_reverse_unsqueeze_node.in_port(0).connect(shape_node.out_port(0))
    forward_reverse_node.in_port(0).connect(forward_reverse_unsqueeze_node.out_port(0))
    forward_reverse_squeeze_node.in_port(0).connect(forward_reverse_node.out_port(0))
    reshape_node.in_port(1).connect(forward_reverse_squeeze_node.out_port(0))

    reshape_shape_node = create_op_node_with_second_input(
        graph, Reshape, int64_array(np.flip(mxreshape.dim, 0)), dict(name=str(mxreshape.id) + '/ReshapeShape'))
    if np.sum(np.in1d([-2, -3, -4], mxreshape.dim), axis=0):
        reshape_shape_node = MXReshape(graph, dict(name=mxreshape.id + '/Reshape',
                                                   dim=int64_array(np.flip(mxreshape.dim, 0)))).create_node()

    reshape_shape_node.in_port(0).connect(reshape_node.out_port(0))

    backward_shape_node = Shape(graph, dict(name=mxreshape.id + '/BackwardShape')).create_node()
    backward_reverse_unsqueeze_node = create_op_node_with_second_input(
        graph, Unsqueeze, int64_array([0]), dict(name=str(mxreshape.id) + '/BackwardUnsqueeze'))
    backward_reverse_node = Reverse(graph, dict(name=mxreshape.id + '/BackwardReverse', axis=1)).create_node()
    backward_reverse_squeeze_node = create_op_node_with_second_input(
        graph, Squeeze, int64_array([0]), dict(name=str(mxreshape.id) + '/BackwardSqueeze'))
    backward_reshape_node = Reshape(graph, dict(name=mxreshape.id + '/BackwardReshape')).create_node()

    backward_shape_node.in_port(0).connect(reshape_shape_node.out_port(0))
    backward_reverse_unsqueeze_node.in_port(0).connect(backward_shape_node.out_port(0))
    backward_reverse_node.in_port(0).connect(backward_reverse_unsqueeze_node.out_port(0))
    backward_reverse_squeeze_node.in_port(0).connect(backward_reverse_node.out_port(0))
    backward_reshape_node.in_port(0).connect(reshape_shape_node.out_port(0))
    backward_reshape_node.in_port(1).connect(backward_reverse_squeeze_node.out_port(0))

    mxreshape.out_port(0).get_connection().set_source(backward_reshape_node.out_port(0))
def replace_sub_graph(self, graph: Graph, match: dict):
    tf_slice_node = match['op']
    slice_name = tf_slice_node.soft_get('name', tf_slice_node.id)
    slice_node = Slice(graph).create_node()
    rename_nodes([(tf_slice_node, slice_name + '/to_be_removed'), (slice_node, slice_name)])
    ends_node = Add(graph, {'name': slice_name + '/ends'}).create_node()

    # reconnect input, begin, and size from TFSlice to the subgraph with Slice
    tf_slice_node.in_port(0).get_connection().set_destination(slice_node.in_port(0))
    tf_slice_node.in_port(1).get_connection().set_destination(slice_node.in_port(1))
    tf_slice_node.in_port(2).get_connection().set_destination(ends_node.in_port(0))
    slice_node.in_port(1).get_connection().add_destination(ends_node.in_port(1))

    max_ends = Shape(graph, {'name': slice_name + '/ShapeOf'}).create_node()
    slice_node.in_port(0).get_connection().add_destination(max_ends.in_port(0))

    # check if size[i] == -1, applied element-wise: len(size) = len(begin) = input rank
    where_max_ends_is_needed = create_op_with_const_inputs(graph, Equal, {0: int64_array(-1)},
                                                           {'name': slice_name + '/where_max_ends_is_needed'})
    ends_node.in_port(0).get_connection().add_destination(where_max_ends_is_needed.in_port(1))

    # select requires equal dtypes, need to convert ends to I64
    ends_casted_to_i64 = Cast(graph, {'name': slice_name + '/CastToI64',
                                      'dst_type': np.int64}).create_node([ends_node])

    # if size[i] == -1 then take max_ends values
    correct_ends = Select(graph, {'name': slice_name + '/chosen_ends'}).create_node(
        [where_max_ends_is_needed, max_ends, ends_casted_to_i64])
    correct_ends.out_port(0).connect(slice_node.in_port(2))

    tf_slice_node.out_port(0).get_connection().set_source(slice_node.out_port(0))
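# Illustration (not part of the original sources): the end-coordinate arithmetic the
# Add/Equal/Select subgraph above performs, sketched element-wise in NumPy.
import numpy as np

input_shape = np.array([10, 20, 30], dtype=np.int64)  # ShapeOf of the sliced input, example
begin = np.array([1, 0, 5], dtype=np.int64)
size = np.array([4, -1, -1], dtype=np.int64)            # -1 means "to the end" in TF Slice
ends = np.where(size == -1, input_shape, begin + size)  # Select(Equal(size, -1), max_ends, begin + size)
assert list(ends) == [5, 20, 30]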
def replace_sub_graph(self, graph: Graph, match: dict):
    node = match['op']
    name = node.soft_get('name', node.id)

    # rank = ShapeOf(ShapeOf(x)): the shape of a 1D shape vector is [rank],
    # and Squeeze turns that one-element vector into a 0D scalar
    shape_of = Shape(graph, {'name': name + '/shape_of'}).create_node()
    rank_1d = Shape(graph, {'name': name + '/rank_of'}).create_node()
    rank_0d = create_op_node_with_second_input(graph, Squeeze, int64_array([0]),
                                               {'name': name + '/0d_rank_of'}, rank_1d)

    shape_of.out_port(0).connect(rank_1d.in_port(0))
    node.out_port(0).get_connection().set_source(rank_0d.out_port(0))
    node.in_port(0).get_connection().set_destination(shape_of.in_port(0))
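# Illustration (not part of the original sources): the ShapeOf-of-ShapeOf trick used
# above to compute tensor rank, mirrored in NumPy.
import numpy as np

x = np.zeros((2, 3, 4))
rank_1d = np.array(np.array(x.shape).shape)  # ShapeOf(ShapeOf(x)) -> [3]
rank_0d = rank_1d.squeeze(0)                 # Squeeze(axis=0) -> scalar 3
assert rank_0d == 3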
def replace_op(self, graph: Graph, node: Node):
    node_name = node.soft_get('name', node.id)
    assert node.has_valid('axis'), \
        'The node "{}" does not have mandatory attribute "axis"'.format(node_name)

    # flatten to 2D, apply LogSoftmax over axis 1, then restore the original shape
    flatten_node = FlattenONNX(graph, {'name': node_name + '/FlattenONNX_', 'axis': node.axis}).create_node()
    shape_node = Shape(graph, {'name': node_name + '/ShapeOf_'}).create_node()
    logsoftmax_node = LogSoftmax(graph, {'name': node_name + '/LogSoftmax_', 'axis': 1}).create_node()
    reshape_node = Reshape(graph, {}).create_node()
    rename_nodes([(node, node_name + '/delete'), (reshape_node, node_name)])

    shape_node.out_port(0).connect(reshape_node.in_port(1))
    logsoftmax_node.out_port(0).connect(reshape_node.in_port(0))
    flatten_node.out_port(0).connect(logsoftmax_node.in_port(0))

    source = node.in_port(0).get_source()
    flatten_node.in_port(0).connect(source)
    shape_node.in_port(0).connect(source)

    return [reshape_node.id]