def replace_op(self, graph: Graph, node: Node):
    """Lower the node into a Broadcast + ScatterNDUpdate pair.

    Port mapping performed here:
      * in_port(0) -> ScatterNDUpdate input 1 (indices)
      * in_port(1) -> Broadcast input 1 (target shape)
      * in_port(2) -> ScatterNDUpdate input 2 (updates)
      * in_port(3) -> Broadcast input 0 (default value); when this port is disconnected a
        float32 zero Const is used instead.

    Returns:
        A one-element list with the id of the new ScatterNDUpdate node, which takes over the
        original node's name (the original node is renamed to '<name>/AbandonedName').
    """
    original_name = node.soft_get('name', node.id)

    # Dense "background" tensor: the default value broadcast to the requested shape.
    dense_background = Broadcast(graph, {'name': original_name + '/Broadcast_'}).create_node()
    node.in_port(1).get_connection().set_destination(dense_background.in_port(1))

    default_value_port = node.in_port(3)
    if default_value_port.disconnected():
        zero_fill = Const(graph, {'name': dense_background.name + '/FillValue_',
                                  'value': np.float32(0)}).create_node()
        dense_background.in_port(0).connect(zero_fill.out_port(0))
    else:
        default_value_port.get_connection().set_destination(dense_background.in_port(0))

    # Write the provided values into the background tensor at the provided indices.
    scatter = ScatterNDUpdate(graph, {'name': original_name + '/ScatterNDUpdate_'}).create_node()
    scatter.in_port(0).connect(dense_background.out_port(0))
    for src_port, dst_port in ((0, 1), (2, 2)):
        node.in_port(src_port).get_connection().set_destination(scatter.in_port(dst_port))

    # The replacement inherits the original node's name.
    rename_nodes([(node, original_name + "/AbandonedName"), (scatter, original_name)])
    return [scatter.id]
def append_variances(priors_scale_node: Node, variance: list):
    """Build a subgraph that appends variance values to the scaled prior boxes.

    Steps wired below:
      1. The priors' shape is sliced with begin=-2, end=-1, stride=1 (node '<name>/get_-2_dim').
      2. That slice is concatenated with the constant 4 to form a tiling shape.
      3. ``variance`` (as a float32 Const) is broadcast to that shape.
      4. The priors are reshaped to [-1, 4] and concatenated with the broadcast variances on axis 0.
      5. The concatenation is reshaped to [1, 2, -1].

    Args:
        priors_scale_node: node whose output 0 carries the scaled prior boxes.
        variance: variance values, wrapped into a float32 constant.

    Returns:
        The final Reshape node, named '<priors_scale_node name>/3D_priors_wth_variances'.
    """
    graph = priors_scale_node.graph
    base_name = priors_scale_node.name

    priors_shape = Shape(graph, {'name': base_name + '/shape'}).create_node()
    priors_scale_node.out_port(0).connect(priors_shape.in_port(0))

    slice_begin = Const(graph, {'value': int64_array([-2])}).create_node()
    slice_end = Const(graph, {'value': int64_array([-1])}).create_node()
    slice_stride = Const(graph, {'value': int64_array([1])}).create_node()
    slice_attrs = {
        'name': base_name + '/get_-2_dim',
        'begin_mask': int64_array([1]),
        'end_mask': int64_array([1]),
        'new_axis_mask': int64_array([0]),
        'shrink_axis_mask': int64_array([0]),
        'ellipsis_mask': int64_array([0]),
    }
    boxes_dim = StridedSlice(graph, slice_attrs).create_node()
    # StridedSlice inputs: 0 = data, 1 = begin, 2 = end, 3 = stride.
    for port_idx, producer in enumerate((priors_shape, slice_begin, slice_end, slice_stride)):
        producer.out_port(0).connect(boxes_dim.in_port(port_idx))

    tiling_shape = create_op_node_with_second_input(graph, Concat, int64_array([4]),
                                                    {'name': base_name + '/shape_for_tiling',
                                                     'in_ports_count': 2,
                                                     'axis': int64_array(0)},
                                                    boxes_dim)

    variance_const = Const(graph, {'name': base_name + '/variance',
                                   'value': float32_array(variance)}).create_node()
    variance_tile = Broadcast(graph, {'name': base_name + '/variance_tile'}).create_node()
    variance_const.out_port(0).connect(variance_tile.in_port(0))
    tiling_shape.out_port(0).connect(variance_tile.in_port(1))

    flat_shape = Const(graph, {'value': int64_array([-1, 4])}).create_node()
    flat_priors = Reshape(graph, {'name': base_name + '/reshape'}).create_node()
    flat_priors.in_port(0).connect(priors_scale_node.out_port(0))
    flat_priors.in_port(1).connect(flat_shape.out_port(0))

    priors_with_variances = Concat(graph, {'name': base_name + '/priors_concat',
                                           'axis': int64_array(0),
                                           'in_ports_count': 2}).create_node()
    for port_idx, producer in enumerate((flat_priors, variance_tile)):
        producer.out_port(0).connect(priors_with_variances.in_port(port_idx))

    target_dims = Const(graph, {'value': int64_array([1, 2, -1])}).create_node()
    result = Reshape(graph, {'name': base_name + '/3D_priors_wth_variances'}).create_node()
    priors_with_variances.out_port(0).connect(result.in_port(0))
    target_dims.out_port(0).connect(result.in_port(1))
    return result
def find_and_replace_pattern(self, graph: Graph):
    """Replace every ``ConstantOfShape`` op with a ``Broadcast`` of its ``fill_value``.

    For each ConstantOfShape node:
      * its input 0 is moved to Broadcast input 1,
      * a new Const holding the node's ``fill_value`` attribute feeds Broadcast input 0,
      * all consumers of its output 0 are moved to the Broadcast output.
    """
    for const_of_shape_node in graph.get_op_nodes(op='ConstantOfShape'):
        # Fall back to the node id when no 'name' attribute is set; a bare `.name` lookup is
        # not guaranteed to succeed for unnamed nodes.
        node_name = const_of_shape_node.soft_get('name', const_of_shape_node.id)
        broadcast_node = Broadcast(graph, {'name': node_name + '/Broadcast'}).create_node()
        const_of_shape_node.in_port(0).get_connection().set_destination(broadcast_node.in_port(1))
        fill_value_node = Const(graph, {'name': broadcast_node.name + '/FillValue',
                                        'value': const_of_shape_node.fill_value}).create_node()
        broadcast_node.in_port(0).connect(fill_value_node.out_port(0))
        const_of_shape_node.out_port(0).get_connection().set_source(broadcast_node.out_port(0))
def find_and_replace_pattern(self, graph: Graph):
    """Replace ``Fill`` and ``ConstantFill`` ops with ``Broadcast``.

    For both op types, input 0 is moved to Broadcast input 1 (target shape), and consumers
    of output 0 are moved to the Broadcast output.
      * ``Fill``: the value comes from its input 1, moved to Broadcast input 0.
      * ``ConstantFill``: the value comes from its ``fill_value`` attribute (as a new Const).
        The node must have ``fill_value`` and ``input_as_shape`` set (checked with asserts).
    """
    for fill in graph.get_op_nodes(op='Fill'):
        fill_name = fill.soft_get('name', fill.id)
        bcast = Broadcast(graph, {'name': fill_name + '/Broadcast'}).create_node()
        shape_connection = fill.in_port(0).get_connection()
        shape_connection.set_destination(bcast.in_port(1))
        value_connection = fill.in_port(1).get_connection()
        value_connection.set_destination(bcast.in_port(0))
        fill.out_port(0).get_connection().set_source(bcast.out_port(0))

    for constant_fill in graph.get_op_nodes(op='ConstantFill'):
        fill_name = constant_fill.soft_get('name', constant_fill.id)
        assert constant_fill.has_valid('fill_value')
        assert constant_fill.has_and_set('input_as_shape')

        fill_value = mo_array(constant_fill.fill_value)
        value_const = Const(graph, {'value': fill_value, 'name': fill_name + '/value'}).create_node()
        bcast = Broadcast(graph, {'name': fill_name + '/Broadcast'}).create_node()
        constant_fill.in_port(0).get_connection().set_destination(bcast.in_port(1))
        value_const.out_port(0).connect(bcast.in_port(0))
        constant_fill.out_port(0).get_connection().set_source(bcast.out_port(0))
def find_and_replace_pattern(self, graph: Graph):
    """Replace every ``Reverse`` op with an equivalent ``ReverseSequence`` subgraph.

    The Reverse node must have input 1 disconnected and a valid ``axis`` attribute (asserted).

    Rank-1 inputs:
      * are unsqueezed to rank 2 so that the batch axis (0) differs from the sequence axis (1),
      * are squeezed back after the ReverseSequence.

    Other ranks:
      * the sequence axis is ``axis``, normalised to be non-negative,
      * the batch axis is 0, or 1 when the sequence axis is 0.

    The ``seq_lengths`` input is built as shape[seq_axis] broadcast to shape[batch_axis].
    All original Reverse nodes are removed at the end.
    """
    reverse_nodes = graph.get_op_nodes(op='Reverse')
    for reverse in reverse_nodes:
        reverse_name = reverse.soft_get('name', reverse.id)

        assert reverse.in_port(1).disconnected()
        assert reverse.has_valid('axis')

        in_shape_rank = len(reverse.in_port(0).data.get_shape())
        # 1. Add new dimension as batch for rank = 1 to have batch != seq_axis
        if in_shape_rank == 1:
            unsq_node = create_op_node_with_second_input(graph, Unsqueeze, int64_array([0]),
                                                         {'name': reverse_name + "/Unsqueeze"})
            reverse.in_port(0).get_source().connect(unsq_node.in_port(0))
            new_in = unsq_node.out_port(0)
            batch_axis = 0
            seq_axis = 1
        else:
            new_in = reverse.in_port(0).get_source()
            seq_axis = reverse['axis']
            # Normalise a negative axis: otherwise e.g. axis == -rank (i.e. axis 0) would select
            # batch_axis = 0 as well, and a negative index would be fed to the shape Gather below.
            if seq_axis < 0:
                seq_axis += in_shape_rank
            batch_axis = 0 if seq_axis != 0 else 1

        # 2. For ReverseSequence 1-port input is seq_lengths => create this input node as
        # shape[seq_axis] broadcasted to shape[batch_axis]
        # in ---> ShapeOf ----> Gather(seq_axis)  ----> Broadcast----->
        #            |                                      |
        #            | -------> Gather(batch_axis)----------|
        shape_node = Shape(graph, {'name': reverse_name + "/Shape"}).create_node()
        new_in.connect(shape_node.in_port(0))
        seq_axis_node = node_to_get_shape_value_of_indices(shape_node, [seq_axis])
        batch_node = node_to_get_shape_value_of_indices(shape_node, [batch_axis])

        broadcast_node = Broadcast(graph, {'name': reverse_name + "/Broadcast"}).create_node()
        broadcast_node.in_port(0).connect(seq_axis_node.out_port(0))
        broadcast_node.in_port(1).connect(batch_node.out_port(0))

        # 3. Create new ReverseSequence node and reconnect all inputs/outputs to it
        rename_node(reverse, reverse_name + '/to_delete')
        reverse_sequence = ReverseSequence(graph, {'name': reverse_name, 'seq_axis': seq_axis,
                                                   'batch_axis': batch_axis}).create_node()
        reverse_sequence.in_port(0).connect(new_in)
        reverse_sequence.in_port(1).connect(broadcast_node.out_port(0))

        # 4. Remove added dimension for rank = 1
        if in_shape_rank == 1:
            rename_node(reverse_sequence, reverse_name + '/ReverseSequence')
            squeeze_node = create_op_node_with_second_input(graph, Squeeze, int64_array([0]),
                                                            {'name': reverse_name})
            squeeze_node.in_port(0).connect(reverse_sequence.out_port(0))
            reverse.out_port(0).get_connection().set_source(squeeze_node.out_port(0))
        else:
            reverse.out_port(0).get_connection().set_source(reverse_sequence.out_port(0))

    # 5. Delete old Reverse nodes
    graph.remove_nodes_from([node.id for node in reverse_nodes])
def find_and_replace_pattern(self, graph: Graph):
    """Decompose ``TFScatterND`` into ``ConvertLike`` + ``Broadcast`` + ``ScatterNDUpdate``.

    TFScatterND inputs as used here:
      * 0 = indices,
      * 1 = updates,
      * 2 = shape (fed to Broadcast as the target shape).

    A zero scalar is converted to the element type of ``updates`` (ConvertLike) and broadcast
    to that shape. ``updates`` is then scattered into it at ``indices``. The new
    ScatterNDUpdate takes over the original node's name and output consumers.

    Nodes with any of the three inputs unconnected are left untouched.
    """
    for tf_scatter_nd in graph.get_op_nodes(op='TFScatterND'):
        if not all(tf_scatter_nd.is_in_port_connected(port) for port in range(3)):
            continue
        # Fall back to the node id itself (not to a node attribute called 'id').
        name = tf_scatter_nd.soft_get('name', tf_scatter_nd.id)
        indices_port = tf_scatter_nd.in_port(0).get_source()
        updates_port = tf_scatter_nd.in_port(1).get_source()
        shape_port = tf_scatter_nd.in_port(2).get_source()

        # Zero fill value; its element type is aligned with `updates` by the ConvertLike below.
        zero_const = Const(graph, {
            'value': int64_array(0),
            'name': name + '/zero_const'
        }).create_node()

        convert_to_type = ConvertLike(graph, {
            'name': name + '/convert_like'
        }).create_node()
        convert_to_type.in_port(0).connect(zero_const.out_port(0))
        convert_to_type.in_port(1).connect(updates_port)

        broad_cast_node = Broadcast(graph, {
            'name': name + '/broadcast'
        }).create_node()
        broad_cast_node.in_port(0).connect(convert_to_type.out_port(0))
        broad_cast_node.in_port(1).connect(shape_port)

        scatter_nd_node = ScatterNDUpdate(graph, {
            'name': name + '/replaced'
        }).create_node()
        scatter_nd_node.in_port(0).connect(broad_cast_node.out_port(0))
        scatter_nd_node.in_port(1).connect(indices_port)
        scatter_nd_node.in_port(2).connect(updates_port)

        rename_nodes([(tf_scatter_nd, name + '/TBD'), (scatter_nd_node, name)])

        tf_scatter_nd.out_port(0).get_connection().set_source(scatter_nd_node.out_port(0))
        for port in range(3):
            tf_scatter_nd.in_port(port).disconnect()
def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
    """Replace the matched RetinaNet post-processing subgraph with a DetectionOutput node.

    Matched inputs, as consumed below:
      * 0 = box regressions,
      * 1 = class scores,
      * 2 = prior boxes.

    Processing steps:
      * The priors are sliced to the first batch and scaled using ``self.placeholder_scales``.
      * The ``variance`` custom attribute is appended to the scaled priors.
      * The regressions are multiplied by the prior widths/heights.
      * DetectionOutput is created with the NMS IoU threshold taken from the first
        NonMaxSuppression node in the graph.

    Raises:
        Error: if the replacement config has no ``variance`` custom attribute, or if the IoU
            threshold cannot be retrieved from a NonMaxSuppression node.
    """
    reshape_classes_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                            dict(name='do_reshape_classes'),
                                                            match.single_input_node(1)[0])

    initial_priors_node = match.single_input_node(2)[0]
    priors_name = initial_priors_node.soft_get('name', initial_priors_node.id)
    # model calculates identical prior boxes for each batch, so we take first slice of them
    begin = Const(graph, {'value': mo_array([0, 0, 0], dtype=np.int32)}).create_node()
    end = Const(graph, {'value': mo_array([1, 0, 0], dtype=np.int32)}).create_node()
    stride = Const(graph, {'value': mo_array([1, 1, 1], dtype=np.int32)}).create_node()

    priors_node = StridedSlice(graph, {'name': priors_name + '/0_batch_slice',
                                       'begin_mask': int64_array([1, 1, 1]),
                                       'end_mask': int64_array([1, 0, 0]),
                                       'new_axis_mask': int64_array([0]),
                                       'shrink_axis_mask': int64_array([0]),
                                       'ellipsis_mask': int64_array([0])}).create_node()
    initial_priors_node.out_port(0).connect(priors_node.in_port(0))
    begin.out_port(0).connect(priors_node.in_port(1))
    end.out_port(0).connect(priors_node.in_port(2))
    stride.out_port(0).connect(priors_node.in_port(3))

    placeholders = graph.get_op_nodes(type='Parameter')
    assert len(placeholders) == 1, "{} replacer requires model to have one Placeholder, but current model has " \
                                   "{} placeholders".format(self.replacement_id, len(placeholders))
    placeholder = placeholders[0]

    # scale prior boxes to the [0, 1] interval
    node_with_scales_for_prior_boxes = self.placeholder_scales(placeholder)
    priors_scale_node = Mul(graph, {'name': 'scale_priors'}).create_node()

    broadcast = Broadcast(graph, {'name': 'scales_broadcast'}).create_node()
    shape_of_priors = Shape(graph, {'name': 'priors_shape'}).create_node()
    priors_node.out_port(0).connect(shape_of_priors.in_port(0))
    broadcast.in_port(1).connect(shape_of_priors.out_port(0))
    broadcast.in_port(0).connect(node_with_scales_for_prior_boxes.out_port(0))

    priors_scale_node.in_port(0).connect(priors_node.out_port(0))
    priors_scale_node.in_port(1).connect(broadcast.out_port(0))

    # Only a missing key means "no variance configured"; other failures are real bugs and must
    # not be masked (the previous bare `except:` also swallowed e.g. KeyboardInterrupt).
    try:
        variance = match.custom_replacement_desc.custom_attributes['variance']
    except KeyError as err:
        raise Error('There is no variance attribute in {} replacement config file `custom_attributes`'
                    ''.format(self.replacement_id)) from err

    priors = self.append_variances(priors_scale_node, variance)

    # calculate prior boxes widths and heights
    split_node = create_op_with_const_inputs(
        graph, VariadicSplit, {1: int64_array(2), 2: int64_array([1, 1, 1, 1])}, {'out_ports_count': 4},
        priors_scale_node)

    priors_width_node = Sub(graph, dict(name=split_node.name + '/sub_2-0_')
                            ).create_node([(split_node, 2), (split_node, 0)])
    priors_height_node = Sub(graph, dict(name=split_node.name + '/sub_3-1_')
                             ).create_node([(split_node, 3), (split_node, 1)])

    # Concat widths and heights into a single [w, h, w, h] tensor and multiply it with the box
    # coordinate regression values.
    # WA: three chained 2-input Concats are used instead of a single 4-input Concat to keep the
    # model reshapable.
    concat_1 = Concat(graph, {'name': 'concat_width_height',
                              'axis': -1, 'in_ports_count': 2}).create_node([priors_width_node, priors_height_node])
    concat_2 = Concat(graph, {'name': 'concat_width_height_width',
                              'axis': -1, 'in_ports_count': 2}).create_node([concat_1, priors_width_node])
    concat_width_height_node = Concat(graph, {'name': 'concat_priors_width_height', 'axis': -1,
                                              'in_ports_count': 2}).create_node([concat_2, priors_height_node])

    applied_width_height_regressions_node = Mul(graph, {'name': 'final_regressions'}).create_node(
        [concat_width_height_node, match.single_input_node(0)[0]])

    # reshape to 2D tensor as Inference Engine Detection Output layer expects
    reshape_regression_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                               dict(name='reshape_regression'),
                                                               applied_width_height_regressions_node)

    detection_output_op = DetectionOutput(graph, match.custom_replacement_desc.custom_attributes)
    # get nms from the original network
    iou_threshold = None
    nms_nodes = graph.get_op_nodes(op='NonMaxSuppression')
    if len(nms_nodes) > 0:
        # It is highly unlikely that NMS uses different IoU thresholds for different classes;
        # moreover DetectionOutput accepts only scalar values for iou_threshold (nms_threshold).
        # soft_get yields None for a missing value, so the Error below reports it clearly.
        iou_threshold = nms_nodes[0].in_node(3).soft_get('value', None)
    if iou_threshold is None:
        raise Error('During {} `iou_threshold` was not retrieved from RetinaNet graph'.format(self.replacement_id))

    detection_output_node = detection_output_op.create_node(
        [reshape_regression_node, reshape_classes_node, priors],
        dict(name=detection_output_op.attrs['type'], nms_threshold=iou_threshold, clip_after_nms=1, normalized=1,
             variance_encoded_in_target=0, background_label_id=1000))

    # As outputs are replaced with a postprocessing node, the outgoing tensor names no longer
    # correspond to the original tensors and should be removed from output->Result edges.
    out_nodes = [match.output_node(out)[0] for out in range(match.outputs_count())]
    clear_tensor_names_info(out_nodes)

    return {'detection_output_node': detection_output_node}