def _create_sub(graph: Graph, input_1: Node, port_1: int, input_2: Node, port_2: int):
    """Express ``input_1 - input_2`` as a sum with a negated second operand.

    A ``Power`` operation with ``scale=-1`` is applied to ``(input_2, port_2)``
    and its result is fed, together with ``(input_1, port_1)``, into an
    ``Eltwise`` operation configured with ``operation='sum'``.

    :param graph: graph in which the new operations are created.
    :param input_1: node supplying the first operand.
    :param port_1: port paired with ``input_1`` when it is wired into the sum.
    :param input_2: node supplying the operand that gets negated.
    :param port_2: port paired with ``input_2`` when it is wired into the negation.
    :return: the node created by the summing ``Eltwise`` operation.
    """
    negate_op = Power(graph, {'scale': -1, 'name': input_2.name + '/negate_'})
    sum_op = Eltwise(graph, {'operation': 'sum', 'name': input_1.name + '/add_'})

    # The negated operand is created first, then joined with the first operand.
    minuend = (input_1, port_1)
    negated = negate_op.create_node([(input_2, port_2)])
    return sum_op.create_node([minuend, negated])
def replace_op(self, graph: nx.MultiDiGraph, node: Node):
    """Replace the matched node with ``sum(in0, Power(in1, scale=-1))``.

    The second input of ``node`` goes through a ``Power`` operation with
    ``scale=-1`` and is then summed with the first input by an ``Eltwise``
    operation. Each input is connected through the ``'out'`` value recorded
    on the corresponding incoming edge of ``node``.

    :param graph: graph in which the new operations are created.
    :param node: matched node being replaced.
    :return: single-element list holding the id of the summing node. Returning
        a bare id means the edge from out port 0 of the matched node is
        replaced with an edge from port 0 of the new node; the explicit
        equivalent is ``[(out_node.id, 0)]``.
    """
    negate_op = Power(graph, {'scale': -1, 'name': node.name + '/negate_'})
    sum_op = Eltwise(graph, {'operation': 'sum', 'name': node.name + '/add_'})

    first_operand = (node.in_node(0), node.in_edge(0)['out'])
    second_operand = (node.in_node(1), node.in_edge(1)['out'])
    negated = negate_op.create_node([second_operand])
    result = sum_op.create_node([first_operand, negated])

    return [result.id]
def replace_op(self, graph: Graph, node: Node):
    """Unroll a multi-input element-wise node into a chain of two-input ones.

    Starting from the first input of ``node``, every further input is combined
    with the running result by a new ``Eltwise`` operation that uses
    ``node.operation``. When ``node`` has only one input, no operation is
    created and the first input itself is the result.

    :param graph: graph in which the new operations are created.
    :param node: matched node whose inputs are folded pairwise.
    :return: single-element list with the id of the last node in the chain.
    """
    op_kind = node.operation
    accumulated = node.in_node(0)
    input_count = len(node.in_nodes())

    for position in range(1, input_count):
        # Each link in the chain gets a name suffixed with its input index.
        link_name = node.name + '/' + op_kind + '_' + str(position)
        link_op = Eltwise(graph, {'operation': op_kind, 'name': link_name})
        accumulated = link_op.create_node([accumulated, node.in_node(position)])

    return [accumulated.id]
def replace_sub_graph(self, graph: Graph, match: dict):
    """Replace the matched ``'fbn'`` node with ``Add(Mul(MVN(x), gamma), beta)``.

    The replacement is applied only when the first input of ``match['fbn']``
    is also the first input of both ``match['mean']`` and ``match['sqdiff']``;
    otherwise the method returns without touching the graph.

    The MVN operation receives the data input plus the second inputs of
    ``match['mean']`` and ``match['variance']``. Its original ``infer``
    attribute is preserved under ``old_infer`` and replaced by this class's
    own ``infer``.

    :param graph: graph in which the new operations are created.
    :param match: mapping of pattern names to matched nodes; the keys read
        here are ``'fbn'``, ``'mean'``, ``'sqdiff'`` and ``'variance'``.
    :return: None.
    """
    fbn = match['fbn']
    data = fbn.in_node(0)
    log.debug('Found potential MVN pattern after {} with name {}'.format(
        data.op, data.name))

    # Both the mean and the squared-difference branches must consume the same
    # data node as the matched node, otherwise this is not the expected pattern.
    if data.id != match['mean'].in_node(0).id:
        return
    if data.id != match['sqdiff'].in_node(0).id:
        return

    log.debug('Confirmed MVN pattern after {} with name {}'.format(
        data.op, data.name))

    mvn_class = Op.get_op_class_by_name('MVN')
    mvn_attrs = {'name': fbn.name + '/MVN_', 'eps': fbn.eps}
    if fbn.data_format == b'NHWC':
        mvn_attrs['required_reduction_indices'] = [1, 2]
    else:
        mvn_attrs['required_reduction_indices'] = [2, 3]
    mvn_op = mvn_class(graph, mvn_attrs)

    # Keep the stock inference function reachable while overriding it.
    mvn_op.attrs['old_infer'] = mvn_op.attrs['infer']
    mvn_op.attrs['infer'] = __class__.infer

    scale_op = Eltwise(graph, {'operation': 'mul', 'name': fbn.name + '/Mul_'})
    shift_op = Eltwise(graph, {'operation': 'sum', 'name': fbn.name + '/Add_'})

    gamma = fbn.in_node(1)
    beta = fbn.in_node(2)
    mean_axes = match['mean'].in_node(1)
    variance_axes = match['variance'].in_node(1)

    # Nodes are created innermost-first: MVN, then the scale, then the shift.
    normalized = mvn_op.create_node([data, mean_axes, variance_axes])
    scaled = scale_op.create_node([normalized, gamma])
    shifted = shift_op.create_node([scaled, beta])

    fbn.replace_node(shifted)
def replace_op(self, graph: nx.MultiDiGraph, node: Node):
    """Replace the matched node with ``mul(in0, Power(in1, power=-1))``.

    The second input of ``node`` goes through a ``Power`` operation configured
    with ``scale=1``, ``power=-1`` and ``shift=0``, and the result is
    multiplied with the first input by an ``Eltwise`` operation. Each input is
    connected through the ``'out'`` value recorded on the corresponding
    incoming edge of ``node``.

    :param graph: graph in which the new operations are created.
    :param node: matched node being replaced.
    :return: single-element list holding the id of the multiplying node.
        Returning a bare id means the edge from out port 0 of the matched node
        is replaced with an edge from port 0 of the new node; the explicit
        equivalent is ``[(out_node.id, 0)]``.
    """
    reciprocal_attrs = {
        'scale': 1,
        'power': np.float64(-1),
        'shift': 0,
        'name': node.name + '/reciprocal_',
    }
    reciprocal_op = Power(graph, reciprocal_attrs)
    product_op = Eltwise(graph, {'operation': 'mul', 'name': node.name + '/mul_'})

    numerator = (node.in_node(0), node.in_edge(0)['out'])
    denominator = (node.in_node(1), node.in_edge(1)['out'])
    inverted = reciprocal_op.create_node([denominator])
    result = product_op.create_node([numerator, inverted])

    return [result.id]