def replace_sub_graph(self, graph: Graph, match: dict):
    """Replace the matched op with a chain of elementwise operations.

    Starting from the producer of the op's input 0, up to three nodes are
    appended in this order: ``Mul`` by ``scale``, ``Add`` of ``shift`` and
    ``Pow`` by ``power``. A stage is only created when ``soft_get`` for its
    attribute returns something different from the neutral value (1, 0 and
    1 respectively), which is also the default handed to ``soft_get``.
    Finally the op's output connection is re-sourced to the end of the chain.

    :param graph: graph the new nodes are created in
    :param match: pattern match; the node under key ``'op'`` is replaced
    """
    node = match['op']
    # Output port of the last stage built so far; starts at the op's producer.
    tail = node.in_port(0).get_source()

    # (attribute, neutral value, elementwise op class, node name suffix)
    stages = (
        ('scale', 1, Mul, '/mul_'),
        ('shift', 0, Add, '/add_'),
        ('power', 1, Pow, '/pow_'),
    )
    for attr, neutral, op_cls, suffix in stages:
        if node.soft_get(attr, neutral) != neutral:
            operand = Const(graph, {'value': np.array(getattr(node, attr))}).create_node()
            stage = op_cls(graph, {'name': node.name + suffix}).create_node()
            operand.out_port(0).connect(stage.in_port(1))
            tail.connect(stage.in_port(0))
            tail = stage.out_port(0)

    # Consumers of the original op now read from the end of the chain.
    node.out_port(0).get_connection().set_source(tail)
def replace_op(self, graph: Graph, node: Node):
    """Build ``Pow(x, -1)`` in place of ``node``, where ``x`` feeds input 0.

    :param graph: graph the new nodes are created in
    :param node: node being replaced; its input-0 connection is moved to
                 the new ``Pow`` node
    :return: single-element list holding the id of the new ``Pow`` node
    """
    base_name = node.name

    exponent_attrs = {
        'value': np.array(-1.),
        'name': base_name + '/reciprocal_pow_const_',
    }
    exponent = Const(graph, exponent_attrs).create_node()
    inverse = Pow(graph, {'name': base_name + '/reciprocal_pow_'}).create_node()

    # Base comes from the replaced node's producer, exponent is the constant.
    incoming = node.in_port(0).get_connection()
    incoming.set_destination(inverse.in_port(0))
    exponent.out_port(0).connect(inverse.in_port(1))

    return [inverse.id]
def replace_op(self, graph: Graph, node: Node):
    """Rewrite ``node`` as ``in0 * in1 ** -1`` using ``Pow`` and ``Mul`` nodes.

    :param graph: graph the new nodes are created in
    :param node: node being replaced; both of its input connections are
                 moved onto the new nodes
    :return: single-element list holding the id of the final ``Mul`` node
    """
    prefix = node.name

    minus_one = Const(graph, {'value': np.float64(-1)}).create_node()
    inverse = Pow(graph, {'name': prefix + '/reciprocal_'}).create_node()
    product = Mul(graph, {'name': prefix + '/mul_'}).create_node()

    # in1 ** -1
    second_input = node.in_port(1).get_connection()
    second_input.set_destination(inverse.in_port(0))
    minus_one.out_port(0).connect(inverse.in_port(1))

    # in0 * (in1 ** -1)
    first_input = node.in_port(0).get_connection()
    first_input.set_destination(product.in_port(1))
    inverse.out_port(0).connect(product.in_port(0))

    # The "explicit" form of this return value is [(out_node.id, 0)].
    return [product.id]
def replace_pattern(graph: Graph, match: [str, Node]):
    """Rewrite the matched ``'div'`` node as ``in0 * in1 ** -1``.

    New ``Const``, ``Pow`` and ``Mul`` nodes are created, both input
    connections of the matched node are moved onto them, and the matched
    node's output connection is re-sourced to the ``Mul`` output.

    :param graph: graph the new nodes are created in
    :param match: pattern match; the node under key ``'div'`` is replaced
    """
    div = match['div']
    prefix = div.name

    minus_one = Const(graph, {'value': np.float64(-1)}).create_node()
    inverse = Pow(graph, {'name': prefix + '/reciprocal_'}).create_node()
    product = Mul(graph, {'name': prefix + '/mul_'}).create_node()

    # in1 ** -1
    second_input = div.in_port(1).get_connection()
    second_input.set_destination(inverse.in_port(0))
    minus_one.out_port(0).connect(inverse.in_port(1))

    # in0 * (in1 ** -1)
    first_input = div.in_port(0).get_connection()
    first_input.set_destination(product.in_port(1))
    inverse.out_port(0).connect(product.in_port(0))

    # Everything that consumed the matched node now consumes the product.
    consumers = div.out_port(0).get_connection()
    consumers.set_source(product.out_port(0))
def replace_op(self, graph: Graph, node: Node):
    """Rewrite ``node`` as ``(in0 + in1 * -1) ** 2``.

    :param graph: graph the new nodes are created in
    :param node: node being replaced; both of its input connections are
                 moved onto the new nodes
    :return: single-element list holding the id of the final ``Pow`` node
    """
    prefix = node.name

    # Node creation
    minus_one = Const(graph, {
        'value': np.array(-1),
        'name': prefix + '/negate_const_',
    }).create_node()
    negated = Mul(graph, {'name': prefix + '/negate_'}).create_node()
    difference = Add(graph, {'name': prefix + '/add_'}).create_node()
    two = Const(graph, {'value': np.array(2)}).create_node()
    square = Pow(graph, {'name': prefix + '/squared_'}).create_node()

    # Wiring: in0 goes straight into the sum, in1 is multiplied by -1 first.
    first_input = node.in_port(0).get_connection()
    first_input.set_destination(difference.in_port(0))
    second_input = node.in_port(1).get_connection()
    second_input.set_destination(negated.in_port(0))
    minus_one.out_port(0).connect(negated.in_port(1))
    negated.out_port(0).connect(difference.in_port(1))

    # Raise the sum to the power of two.
    difference.out_port(0).connect(square.in_port(0))
    two.out_port(0).connect(square.in_port(1))

    # The "explicit" form of this return value is [(out_node.id, 0)].
    return [square.id]
def placeholder_scales(self, placeholder: Node):
    """
    Helper that builds a sub-graph producing scales for prior boxes out of the input image size:
        [1 / im_width, 1 / im_height, 1 / im_width, 1 / im_height]

    :param placeholder: node whose `shape` attribute must be a 4-element np.ndarray;
                        the sub-graph is attached to its output port 0
    :return: the Concat node that outputs the scales
    """
    graph = placeholder.graph
    name = placeholder.soft_get('name', placeholder.id)

    # The statically known shape is only validated here; the values used by
    # the sub-graph come from a Shape node.
    shape_value = placeholder.soft_get('shape', None)
    assert shape_value is not None, \
        "[ {} replacer ] Placeholder `{}` should have shape attribute".format(self.replacement_id, name)
    assert isinstance(shape_value, np.ndarray), \
        "[ {} replacer ] Placeholder `{}` shape attribute should be np.ndarray".format(self.replacement_id, name)
    assert shape_value.size == 4, \
        "[ {} replacer ] Placeholder `{}` should be 4D. Shape: {}".format(self.replacement_id, name, shape_value)

    image_shape = Shape(graph, {'name': 'input_image_shape'}).create_node()
    image_shape.in_port(0).connect(placeholder.out_port(0))

    # Slice elements [1, 3) out of the shape vector (the node name suggests
    # these are height and width).
    slice_begin = Const(graph, {'value': int64_array([1])}).create_node()
    slice_end = Const(graph, {'value': int64_array([3])}).create_node()
    slice_step = Const(graph, {'value': int64_array([1])}).create_node()
    height_width = StridedSlice(graph, {
        'name': name + '/get_h_w',
        'begin_mask': np.array([1]),
        'end_mask': np.array([1]),
        'new_axis_mask': np.array([0]),
        'shrink_axis_mask': np.array([0]),
        'ellipsis_mask': np.array([0]),
    }).create_node()
    slice_inputs = (image_shape, slice_begin, slice_end, slice_step)
    for port_idx, producer in enumerate(slice_inputs):
        height_width.in_port(port_idx).connect(producer.out_port(0))

    # Element-wise reciprocal: x ** -1.
    minus_one = Const(graph, {'value': float32_array([-1.])}).create_node()
    reciprocal = Pow(graph, {}).create_node()
    reciprocal.in_port(0).connect(height_width.out_port(0))
    reciprocal.in_port(1).connect(minus_one.out_port(0))

    # Power `type_infer` requires inputs to have equal data type
    to_fp32 = Cast(graph, {'dst_type': np.float32}).create_node()
    reciprocal.in_port(0).get_connection().insert_node(to_fp32)

    # Swap the two elements with a Gather over indices [1, 0] along axis 0.
    swap_indices = Const(graph, {'value': int64_array([1, 0])}).create_node()
    gather_axis = Const(graph, {'value': int64_array(0)}).create_node()
    swapped = Gather(graph, {}).create_node()
    swapped.in_port(0).connect(reciprocal.out_port(0))
    swapped.in_port(1).connect(swap_indices.out_port(0))
    gather_axis.out_port(0).connect(swapped.in_port(2))

    # Feed the same pair into both Concat inputs to get four values.
    priors_scale_node = Concat(graph, {'axis': 0, 'in_ports_count': 2}).create_node()
    for port_idx in (0, 1):
        priors_scale_node.add_input_port(port_idx, skip_if_exist=True)
    for port_idx in (0, 1):
        priors_scale_node.in_port(port_idx).connect(swapped.out_port(0))

    return priors_scale_node
def replace_pattern(self, graph: Graph, match: dict):
    """Replace the matched ``'merge'`` node with a generated sub-graph.

    The sub-graph built here is:

    1. ``Pow(input0, -1.0)``.
    2. For every ``idx`` in ``[merge.significant, merge.to_significant]``:
       ``Cast(10**idx * pow_out, float32) * 10**-idx``, fed into input
       number ``ii`` (the enumeration index) of a ``Concat`` with axis 0.
    3. ``ReduceMean`` over the ``Concat`` output, with a constant ``0`` as
       its second input (presumably the reduction axis).
    4. ``Reshape`` of the result, target taken from a constant ``[1, 5]``.

    The output connection of the matched node is re-sourced to the
    ``Reshape`` output.

    :param graph: graph the new nodes are created in
    :param match: pattern match; the node under key ``'merge'`` is replaced
    """
    merge = match['merge']
    base = merge.name

    inverse = Pow(graph, {
        'name': base + '/reciprocal_',
        'type': 'PNORM',
    }).create_node()
    minus_one = Const(graph, {
        'value': -1.0,
        'name': base + '/negate_const',
    }).create_node()
    merge.in_port(0).get_connection().set_destination(inverse.in_port(0))
    minus_one.out_port(0).connect(inverse.in_port(1))

    concat_node = Concat(graph, {
        'axis': 0,
        'name': base + '/Concat_',
        'override_output_shape': True,
    }).create_node()
    reduce_axis = Const(graph, {
        'name': base + '/const_reduce',
        'value': 0,
    }).create_node()

    exponents = range(merge.significant, merge.to_significant + 1)
    for ii, idx in enumerate(exponents):
        branch = str(ii)

        scale_up = Const(graph, {
            'value': float_array(math.pow(10.0, idx)),
            'name': base + '/Const_' + branch,
        }).create_node()
        scaled = Mul(graph, {'name': base + '/Mul_' + branch}).create_node()
        scale_up.out_port(0).connect(scaled.in_port(0))
        # Every branch reads from the shared Pow output.
        inverse.out_port(0).connect(scaled.in_port(1))

        rescaled = Mul(graph, {'name': base + '/Mul_Div_' + branch}).create_node()
        scale_down = Const(graph, {
            'value': float_array(math.pow(10.0, -idx)),
            'name': base + '/Const_Pow_' + branch,
        }).create_node()
        # NOTE(review): this name uses the range value `idx`, while every
        # other name in the loop uses the enumeration index `ii` - confirm
        # this is intended.
        to_fp32 = Cast(graph, {
            'name': base + '/Cast_' + str(idx),
            'dst_type': np.float32,
        }).create_node()

        scaled.out_port(0).connect(to_fp32.in_port(0))
        scale_down.out_port(0).connect(rescaled.in_port(1))
        to_fp32.out_port(0).connect(rescaled.in_port(0))

        concat_node.add_input_port(ii, skip_if_exist=True)
        concat_node.in_port(ii).get_connection().set_source(rescaled.out_port(0))

    # NOTE(review): the node name says "reduced_sum" but the op is ReduceMean;
    # names here are built from `merge.id`, not `merge.name` - confirm both.
    reduced = ReduceMean(graph, {
        'name': merge.id + '/_pnorm_reduced_sum',
        'keep_dims': False,
        'in_ports_count': 2,
        'need_shape_inference': None,
        'infer': reduce_infer,
    }).create_node()
    reduce_axis.out_port(0).connect(reduced.in_port(1))
    reduced.in_port(0).get_connection().set_source(concat_node.out_port(0))

    reshape = Reshape(graph, {'name': base + '/Reshape_Node'}).create_node()
    # NOTE(review): the target shape is hard-coded to [1, 5] - verify it
    # matches the shapes this pattern is applied to.
    target_shape = Const(graph, {
        'value': np.array([1, 5]),
        'name': merge.id + '/Reshape_Dim',
    }).create_node()
    reduced.out_port(0).connect(reshape.in_port(0))
    reshape.in_port(1).connect(target_shape.out_port(0))

    # Consumers of the matched node now read from the reshaped result.
    merge.out_port(0).get_connection().set_source(reshape.out_port(0))