def replace_sub_graph(self, graph: Graph, match: dict):
    """Rewrite a matched ``fbn`` node as MVN -> Mul(gamma) -> Add(beta).

    The rewrite only happens when the data input of ``fbn`` (port 0) is also
    the port-0 input of both the matched ``mean`` and ``sqdiff`` nodes.
    Otherwise the method returns without touching the graph.

    Args:
        graph: graph being transformed.
        match: matched pattern nodes; the keys read here are 'fbn', 'mean',
            'sqdiff' and 'variance'.
    """
    fbn = match['fbn']
    # Named ``data`` rather than ``input`` so the builtin is not shadowed.
    data = fbn.in_node(0)
    # Lazy %-style args: the message is only formatted if DEBUG is enabled.
    log.debug('Found potential MVN pattern after %s with name %s', data.op, data.name)
    if data.id != match['mean'].in_node(0).id or data.id != match['sqdiff'].in_node(0).id:
        # mean/sqdiff are fed from a different tensor: not the pattern we want.
        return

    log.debug('Confirmed MVN pattern after %s with name %s', data.op, data.name)

    mvn = MVN(graph, dict(
        name=fbn.name + '/MVN_',
        eps=fbn.eps,
        eps_mode='outside_sqrt',
        normalize_variance=1
    ))
    # Keep the op's default shape inference reachable as 'old_infer' and
    # install this transformation class's own ``infer`` in its place.
    mvn.attrs['old_infer'] = mvn.attrs['infer']
    mvn.attrs['infer'] = __class__.infer

    mul = Mul(graph, dict(operation='mul', name=fbn.name + '/Mul_'))
    add = Add(graph, dict(operation='sum', name=fbn.name + '/Add_'))

    # fbn inputs 1 and 2 are used as the multiplicative and additive terms.
    input_gamma = fbn.in_node(1)
    input_beta = fbn.in_node(2)

    # The reduction-axes inputs (port 1) of the matched mean/variance nodes
    # are forwarded to the MVN node.
    mean_reduction = match['mean'].in_node(1)
    variance_reduction = match['variance'].in_node(1)

    new_subgraph = add.create_node([
        mul.create_node([
            mvn.create_node([data, mean_reduction, variance_reduction]),
            input_gamma
        ]),
        input_beta
    ])

    fbn.replace_node(new_subgraph)
def replace_op(self, graph: Graph, node: Node):
    """Decompose ``node`` into an MVN -> Mul -> Add chain.

    The result is ``Add(Mul(MVN(in0), in1), in2)``, where ``inK`` is the
    node's K-th input. All new nodes are named under
    ``<node.name>/InstanceNormalization``.

    Returns:
        A one-element list holding the id of the final Add node.
    """
    base_name = node.name + '/InstanceNormalization'

    # Build the op descriptors in the same order as before: MVN, Mul, Add.
    mvn_op = MVN(graph, dict(name=base_name + '/MVN', eps=node.epsilon))
    mul_op = Mul(graph, dict(name=base_name + '/Mul', axis=1))
    add_op = Add(graph, dict(name=base_name + '/Add', axis=1))

    # Create the graph nodes innermost-first.
    normalized = mvn_op.create_node([node.in_node(0)])
    scaled = mul_op.create_node([normalized, node.in_node(1)])
    shifted = add_op.create_node([scaled, node.in_node(2)])

    return [shifted.id]
def replace_sub_graph(graph: Graph, match: dict):
    """Collapse the matched subgraph ending at 'truediv' into one MVN node.

    The MVN node is created with ``eps_mode='outside_sqrt'`` and
    ``normalize_variance=1``. It receives five inputs: the data input of
    'mean', the port-1 inputs of 'mean', 'variance' and 'pow', and whichever
    input of 'add' is not the 'variance' node.
    """
    divide = match['truediv']
    mean_node = match['mean']
    variance_node = match['variance']
    add_node = match['add']

    mvn = MVN(graph, dict(
        name=divide.name + '/MVN_',
        eps_mode='outside_sqrt',
        normalize_variance=1
    ))
    # Preserve the default inference as 'old_infer' and install this
    # class's ``infer`` instead.
    mvn.attrs['old_infer'] = mvn.attrs['infer']
    mvn.attrs['infer'] = __class__.infer

    # 'add' has the variance on one port and the epsilon on the other.
    eps_port = 1 if add_node.in_node(0).id == variance_node.id else 0

    mvn_inputs = [
        mean_node.in_node(0),
        mean_node.in_node(1),
        variance_node.in_node(1),
        match['pow'].in_node(1),
        add_node.in_node(eps_port),
    ]
    divide.replace_node(mvn.create_node(mvn_inputs))
def replace_sub_graph(graph: Graph, match: dict):
    """Collapse the matched subgraph ending at 'truediv' into one MVN node.

    ``required_reduction_indices`` is ``[1, 2]`` when the graph layout is
    'NHWC' and ``[2, 3]`` otherwise. The MVN node receives five inputs: the
    data input of 'mean', the port-1 inputs of 'mean', 'variance' and 'pow',
    and whichever input of 'add' is not the 'variance' node.
    """
    divide = match['truediv']
    mean_node = match['mean']
    variance_node = match['variance']
    add_node = match['add']

    if graph.graph['layout'] == 'NHWC':
        reduction_indices = [1, 2]
    else:
        reduction_indices = [2, 3]

    mvn = MVN(graph, dict(name=divide.name + '/MVN_',
                          required_reduction_indices=reduction_indices))
    # Preserve the default inference as 'old_infer' and install this
    # class's ``infer`` instead.
    mvn.attrs['old_infer'] = mvn.attrs['infer']
    mvn.attrs['infer'] = __class__.infer

    # 'add' has the variance on one port and the epsilon on the other.
    eps_port = 1 if add_node.in_node(0).id == variance_node.id else 0

    mvn_inputs = [
        mean_node.in_node(0),
        mean_node.in_node(1),
        variance_node.in_node(1),
        match['pow'].in_node(1),
        add_node.in_node(eps_port),
    ]
    divide.replace_node(mvn.create_node(mvn_inputs))