def replace_pattern(graph: Graph, match: dict):
    node = match['op']
    shape = node.in_port(0).data.get_shape().copy()
    assert shape[1] % node.group == 0

    # x ** p
    power_node = Power(graph, attrs={'name': node.id + '_power', 'power': node.p}).create_node()

    # Split the channel dimension into groups: [N, C] -> [N, C // group, group]
    reshape_node = create_op_node_with_second_input(graph, Reshape,
                                                    int64_array([shape[0], shape[1] // node.group, node.group]),
                                                    {'name': node.id + '_reshape'})
    reshape_node.in_port(0).connect(power_node.out_port(0))

    # Sum within each group
    reducesum_node = create_op_node_with_second_input(graph, ReduceSum, int64_array([2]),
                                                      {'name': node.id + '_sum', 'keep_dims': False})
    reducesum_node.in_port(0).connect(reshape_node.out_port(0))

    # (sum) ** (1 / p)
    invpower_node = Power(graph, attrs={'name': node.id + '_invpower', 'power': 1.0 / node.p}).create_node()
    invpower_node.in_port(0).connect(reducesum_node.out_port(0))

    node.in_port(0).get_connection().set_destination(power_node.in_port(0))
    node.out_port(0).get_connection().set_source(invpower_node.out_port(0))
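# Reference sketch (illustration only, not part of the transform above): the subgraph
# Power(p) -> Reshape -> ReduceSum -> Power(1/p) computes a grouped p-norm. Equivalent
# NumPy for a 2D input of shape [N, C] with C divisible by the group size:
import numpy as np

def grouped_pnorm_reference(x: np.ndarray, group: int, p: float) -> np.ndarray:
    n, c = x.shape
    assert c % group == 0
    powered = np.power(x, p)                          # Power node
    grouped = powered.reshape(n, c // group, group)   # Reshape node
    summed = grouped.sum(axis=2)                      # ReduceSum over axis 2, keep_dims=False
    return np.power(summed, 1.0 / p)                  # inverse Power node

# e.g. x of shape [1, 6] with group=3 and p=2 yields the per-group L2 norms, shape [1, 2]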
def replace_op(self, graph: Graph, node: Node):
    # The second input is inverted via Power(-1) and multiplied by the first: out = in0 * in1 ** -1
    reciprocal = Power(graph, {'scale': 1, 'power': np.float64(-1), 'shift': 0,
                               'name': node.name + '/reciprocal_'}).create_node()
    mul = Eltwise(graph, {'operation': 'mul', 'name': node.name + '/mul_'}).create_node()

    # Connect nodes
    node.in_port(1).get_connection().set_destination(reciprocal.in_port(0))
    node.in_port(0).get_connection().set_destination(mul.in_port(1))
    reciprocal.out_port(0).connect(mul.in_port(0))

    # The "explicit" version of the return value is: [(out_node.id, 0)]
    return [mul.id]
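# Illustration only: the replacement computes out = in0 * in1 ** -1, i.e. element-wise
# division expressed through a reciprocal Power node and a multiplication. Minimal NumPy
# check of that identity:
import numpy as np

a = np.array([6.0, 9.0, 12.0])
b = np.array([2.0, 3.0, 4.0])
reciprocal = np.power(b, -1.0)        # what the Power node with power=-1 produces
assert np.allclose(a * reciprocal, a / b)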
def replace_op(self, graph: Graph, node: Node):
    power = Power(graph, dict(scale=0, name=node.name + '/Power/')).create_node()

    # Reconnecting inputs to this new node
    node.in_port(0).get_connection().set_destination(power.in_port(0))
    node.out_port(0).get_connection().set_source(power.out_port(0))

    return [power.id]
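# Illustration only, assuming the Caffe-style Power semantics y = (shift + scale * x) ** power
# with the defaults shift=0 and power=1: setting scale=0 makes the output independent of the
# input values, so the node produces a zero tensor that only keeps the input shape.
import numpy as np

def power_reference(x, scale=0.0, shift=0.0, power=1.0):
    return np.power(shift + scale * x, power)

x = np.random.randn(2, 3)
assert np.array_equal(power_reference(x, scale=0.0), np.zeros_like(x))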
def replace_pattern(graph: Graph, match: dict):
    consumers = [n for n in match
                 if n not in ['mul_add', 'pow_d'] and not check_node_usages_out_of_match(match, n)]
    if consumers:
        log.warning('Power(mul_add,pow) pattern was detected. Non pattern consumers of nodes: "{}" were found.'
                    ' Won\'t replace'.format(', '.join([match[n].id for n in consumers])))
        return

    mul_add = match['mul_add']
    pow = match['pow']

    new_power = Power(graph, {'name': mul_add.name + '/fused_power',
                              'shift': mul_add.shift,
                              'scale': mul_add.scale,
                              'power': pow.power}).create_node()

    source = mul_add.in_port(0).get_connection().get_source()
    mul_add.in_port(0).disconnect()
    new_power.in_port(0).connect(source)
    pow.out_port(0).get_connection().set_source(new_power.out_port(0))

    log.debug('Power nodes {} and {} were fused to single Power node {}'.format(mul_add.name, pow.name,
                                                                                new_power.name))
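# Illustration only, assuming the mul_add node computes scale * x + shift and the pow node
# raises its input to `power`: a single Power node carrying the same shift, scale and power
# attributes is mathematically equivalent to the two-node chain being fused above.
import numpy as np

def mul_add_ref(x, scale, shift):
    return scale * x + shift

def pow_ref(x, power):
    return np.power(x, power)

def fused_power_ref(x, scale, shift, power):
    return np.power(scale * x + shift, power)

x = np.linspace(0.1, 2.0, 5)
assert np.allclose(pow_ref(mul_add_ref(x, scale=2.0, shift=1.0), power=3.0),
                   fused_power_ref(x, scale=2.0, shift=1.0, power=3.0))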
def replace_pattern(self, graph: Graph, match: dict):
    const = 0.99
    merge = match['merge']
    digits = significant_digits()

    pnorm = Power(graph, {'name': merge.name + '/reciprocal_',
                          'type': 'PNORM',
                          'significant': digits[0],
                          'to_significant': digits[1],
                          'scale': 1,
                          'shift': 0,
                          'power': get_power_attr()}).create_node()
    merge.in_port(0).get_connection().set_destination(pnorm.in_port(0))

    # Prepend an axis of size 1 so the per-index branches can be concatenated along axis 0
    in_shape = list(pnorm.in_port(0).data.get_shape())
    in_shape.insert(0, 1)
    reshape1 = Reshape(graph, {'name': merge.name + '/Reshape_Node1'}).create_node()
    reshape_dim1 = Const(graph, {'value': np.array(in_shape), 'name': merge.id + '/Reshape_Dim1'}).create_node()
    pnorm.out_port(0).connect(reshape1.in_port(0))
    reshape1.in_port(1).connect(reshape_dim1.out_port(0))

    concat_node = Concat(graph, {'axis': 0, 'name': merge.name + '/Concat_',
                                 'override_output_shape': True}).create_node()
    const3 = Const(graph, {'name': merge.name + '/const_reduce', 'value': 0}).create_node()

    # One branch per index: Exp(y * const ** idx) * const ** -idx, where y is the reshaped pnorm output
    for ii, idx in enumerate(range(pnorm.significant, pnorm.to_significant + 1, 1)):
        const_node = Const(graph, {'value': float_array(math.pow(const, idx)),
                                   'name': merge.name + '/Const_' + str(ii)}).create_node()
        mul_node = Mul(graph, {'name': merge.name + '/Mul_' + str(ii)}).create_node()
        const_node.out_port(0).connect(mul_node.in_port(0))
        reshape1.out_port(0).connect(mul_node.in_port(1))  # connect to the graph node

        mul_node2 = Mul(graph, {'name': merge.name + '/Mul_Div_' + str(ii)}).create_node()
        const_node2 = Const(graph, {'value': float_array(math.pow(const, -1 * idx)),
                                    'name': merge.name + '/Const_Pow_' + str(ii)}).create_node()
        cast_node = ExpOp(graph, {'name': merge.name + '/Exp_' + str(idx)}).create_node()
        mul_node.out_port(0).connect(cast_node.in_port(0))
        const_node2.out_port(0).connect(mul_node2.in_port(1))
        cast_node.out_port(0).connect(mul_node2.in_port(0))

        concat_node.add_input_port(ii, skip_if_exist=True)
        concat_node.in_port(ii).get_connection().set_source(mul_node2.out_port(0))

    in_shape = list(pnorm.in_port(0).data.get_shape())

    # Despite the variable name, this is a ReduceMean over the stacked axis 0 (the axis value comes from const3)
    reducesum_node = ReduceMean(graph, {'name': merge.id + '/_pnorm_reduced_sum',
                                        'keep_dims': True,
                                        'in_ports_count': 2,
                                        'shape': in_shape,
                                        'axis': 0,
                                        'need_shape_inference': None,
                                        'infer': reduce_infer}).create_node()
    const3.out_port(0).connect(reducesum_node.in_port(1))
    reducesum_node.in_port(0).get_connection().set_source(concat_node.out_port(0))

    # Restore the original input shape and reroute the consumers of 'merge' to the new subgraph
    reshape = Reshape(graph, {'name': merge.name + '/Reshape_Node'}).create_node()
    reshape_dim = Const(graph, {'value': np.array(in_shape), 'name': merge.id + '/Reshape_Dim'}).create_node()
    reducesum_node.out_port(0).connect(reshape.in_port(0))
    reshape.in_port(1).connect(reshape_dim.out_port(0))
    merge.out_port(0).get_connection().set_source(reshape.out_port(0))
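# Reference sketch of the dataflow assembled above, expressed in NumPy (illustration only;
# it mirrors the wiring Power -> Reshape -> per-index Exp branches -> Concat -> ReduceMean ->
# Reshape). `power`, `significant` and `to_significant` stand in for get_power_attr() and
# significant_digits(), which are defined elsewhere in the surrounding code.
import math
import numpy as np

def pnorm_subgraph_reference(x: np.ndarray, power: float, significant: int,
                             to_significant: int, const: float = 0.99) -> np.ndarray:
    powered = np.power(x, power)                       # leading Power node (scale=1, shift=0)
    expanded = powered.reshape((1,) + x.shape)         # Reshape_Node1: prepend an axis of size 1
    branches = []
    for idx in range(significant, to_significant + 1):
        scaled = expanded * math.pow(const, idx)                  # Mul with const ** idx
        branches.append(np.exp(scaled) * math.pow(const, -idx))   # Exp, then Mul with const ** -idx
    stacked = np.concatenate(branches, axis=0)         # Concat along axis 0
    reduced = np.mean(stacked, axis=0, keepdims=True)  # ReduceMean over axis 0
    return reduced.reshape(x.shape)                    # final Reshape back to the input shape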