예제 #1
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Swap the matched ``fbn`` node for an MVN -> Mul -> Add chain.

        The rewrite is performed only when the first input of ``fbn`` is also
        the first input of the matched ``mean`` and ``sqdiff`` nodes; otherwise
        the method returns before any new operation is created.

        :param graph: graph in which the new operations are created.
        :param match: mapping from pattern node names to matched nodes; the
            keys ``fbn``, ``mean``, ``sqdiff`` and ``variance`` are read here.
        :return: None.
        """
        fbn = match['fbn']
        data = fbn.in_node(0)
        log.debug('Found potential MVN pattern after {} with name {}'.format(data.op, data.name))

        # Both matched nodes have to consume the very same producer as `fbn`;
        # the second comparison is only evaluated when the first one passes.
        if data.id != match['mean'].in_node(0).id:
            return
        if data.id != match['sqdiff'].in_node(0).id:
            return

        log.debug('Confirmed MVN pattern after {} with name {}'.format(data.op, data.name))

        mvn_attrs = {
            'name': fbn.name + '/MVN_',
            'eps': fbn.eps,
            'eps_mode': 'outside_sqrt',
            'normalize_variance': 1,
        }
        mvn = MVN(graph, mvn_attrs)
        # Remember the op's own infer callback under 'old_infer', then install
        # the enclosing class's `infer` in its place.
        mvn.attrs['old_infer'] = mvn.attrs['infer']
        mvn.attrs['infer'] = __class__.infer

        mul = Mul(graph, {'operation': 'mul', 'name': fbn.name + '/Mul_'})
        add = Add(graph, {'operation': 'sum', 'name': fbn.name + '/Add_'})

        gamma = fbn.in_node(1)
        beta = fbn.in_node(2)

        reduce_for_mean = match['mean'].in_node(1)
        reduce_for_variance = match['variance'].in_node(1)

        # Build the chain innermost-first: MVN, then multiply by the second
        # `fbn` input, then add the third one.
        normalized = mvn.create_node([data, reduce_for_mean, reduce_for_variance])
        scaled = mul.create_node([normalized, gamma])
        shifted = add.create_node([scaled, beta])

        fbn.replace_node(shifted)
예제 #2
0
    def replace_op(self, graph: Graph, node: Node):
        """Build an MVN -> Mul -> Add chain from the three inputs of ``node``.

        MVN is fed the first input of ``node``; its result is multiplied by the
        second input and the third input is then added.

        :param graph: graph in which the new operations are created.
        :param node: node being replaced; its ``name``, ``epsilon`` and first
            three inputs are read.
        :return: single-element list holding the id of the final Add node.
        """
        prefix = node.name + '/InstanceNormalization'

        mvn = MVN(graph, {'name': prefix + '/MVN', 'eps': node.epsilon})
        mul = Mul(graph, {'name': prefix + '/Mul', 'axis': 1})
        add = Add(graph, {'name': prefix + '/Add', 'axis': 1})

        # Nodes are created innermost-first, each input being fetched right
        # before the operation that consumes it.
        normalized = mvn.create_node([node.in_node(0)])
        scaled = mul.create_node([normalized, node.in_node(1)])
        shifted = add.create_node([scaled, node.in_node(2)])

        return [shifted.id]
    def replace_sub_graph(graph: Graph, match: dict):
        """Replace the matched ``truediv`` node with a single MVN node.

        The MVN node receives five inputs, in this order: the first input of
        ``mean``, the second input of ``mean``, the second input of
        ``variance``, the second input of ``pow``, and the input of ``add``
        selected as described in the body.

        NOTE(review): the signature has no ``self``/``cls`` -- presumably the
        method is meant to be used as a static method; confirm how it is
        declared and called in the enclosing class.

        :param graph: graph in which the MVN node is created.
        :param match: mapping from pattern node names to matched nodes; the
            keys ``truediv``, ``mean``, ``variance``, ``pow`` and ``add`` are
            read here.
        :return: None.
        """
        truediv = match['truediv']
        mvn = MVN(graph, {
            'name': truediv.name + '/MVN_',
            'eps_mode': 'outside_sqrt',
            'normalize_variance': 1,
        })
        # Remember the op's own infer callback under 'old_infer', then install
        # the enclosing class's `infer` in its place.
        mvn.attrs['old_infer'] = mvn.attrs['infer']
        mvn.attrs['infer'] = __class__.infer

        mean = match['mean']
        reduce_for_mean = mean.in_node(1)
        reduce_for_variance = match['variance'].in_node(1)
        exponent = match['pow'].in_node(1)

        # Take input 0 of `add` unless that input is the `variance` node
        # itself, in which case take input 1.
        add_node = match['add']
        if add_node.in_node(0).id != match['variance'].id:
            eps_port = 0
        else:
            eps_port = 1
        eps = add_node.in_node(eps_port)

        mvn_inputs = [mean.in_node(0), reduce_for_mean, reduce_for_variance, exponent, eps]
        replacement = mvn.create_node(mvn_inputs)

        truediv.replace_node(replacement)
예제 #4
0
    def replace_sub_graph(graph: Graph, match: dict):
        """Replace the matched ``truediv`` node with a single MVN node.

        The MVN operation is created with ``required_reduction_indices`` set
        to ``[1, 2]`` when ``graph.graph['layout']`` equals ``'NHWC'`` and to
        ``[2, 3]`` otherwise. Its five inputs are, in order: the first input of
        ``mean``, the second input of ``mean``, the second input of
        ``variance``, the second input of ``pow``, and the input of ``add``
        selected as described in the body.

        NOTE(review): the signature has no ``self``/``cls`` -- presumably the
        method is meant to be used as a static method; confirm how it is
        declared and called in the enclosing class.

        :param graph: graph in which the MVN node is created; its ``layout``
            entry is consulted.
        :param match: mapping from pattern node names to matched nodes; the
            keys ``truediv``, ``mean``, ``variance``, ``pow`` and ``add`` are
            read here.
        :return: None.
        """
        truediv = match['truediv']
        mvn_name = truediv.name + '/MVN_'

        if graph.graph['layout'] == 'NHWC':
            reduction_indices = [1, 2]
        else:
            reduction_indices = [2, 3]

        mvn = MVN(graph, {
            'name': mvn_name,
            'required_reduction_indices': reduction_indices,
        })
        # Remember the op's own infer callback under 'old_infer', then install
        # the enclosing class's `infer` in its place.
        mvn.attrs['old_infer'] = mvn.attrs['infer']
        mvn.attrs['infer'] = __class__.infer

        mean = match['mean']
        reduce_for_mean = mean.in_node(1)
        reduce_for_variance = match['variance'].in_node(1)
        exponent = match['pow'].in_node(1)

        # Take input 0 of `add` unless that input is the `variance` node
        # itself, in which case take input 1.
        add_node = match['add']
        if add_node.in_node(0).id != match['variance'].id:
            eps_port = 0
        else:
            eps_port = 1
        eps = add_node.in_node(eps_port)

        mvn_inputs = [mean.in_node(0), reduce_for_mean, reduce_for_variance, exponent, eps]
        replacement = mvn.create_node(mvn_inputs)

        truediv.replace_node(replacement)