Exemplo n.º 1
0
    def apply_mean_value(graph: Graph, input_node: Node,
                         node_mean_scale_values: dict):
        """
        Insert an Add node that subtracts user-supplied mean values from the
        output of ``input_node``.

        The Add is spliced between ``input_node`` and its output data node.
        The original edge is removed, and a new intermediate data node feeds
        the Add together with a constant holding the negated mean values. The
        Add then writes into the original output data node.

        The function returns without touching the graph when ``'mean'`` is
        absent, is ``None``, or all of its entries are zero.

        :param graph: graph modified in place; ``graph.graph['layout']`` is
            read to compute the dims count passed to ``Op.expand_node_shape``
        :param input_node: node whose output receives the mean subtraction
        :param node_mean_scale_values: per-node options; only ``'mean'`` is read
        :raises Error: if ``input_node`` has no valid ``shape`` attribute
        """
        if 'mean' not in node_mean_scale_values:
            return
        mean_values = node_mean_scale_values['mean']
        if mean_values is None:
            return
        if all([m == 0 for m in mean_values]):
            return

        consumer = input_node.out_node()
        if not input_node.has_valid('shape'):
            raise Error("Node {} has not valid shape attribute".format(
                input_node.id))
        shape = input_node.shape

        # Detach the consumer so that the Add node can sit in between.
        graph.remove_edge(input_node.id, consumer.id)

        negated_mean = (-1) * np.array(mean_values)

        add_op = Add(graph, dict(name="Add_"))
        const_data = Op.create_input_data_node(graph, "data_add_",
                                               np.array(negated_mean))
        # NCHW gets (rank - 2) extra dims, presumably so that a per-channel
        # constant broadcasts against C; other layouts get none.
        if graph.graph['layout'] == 'NCHW':
            extra_dims = len(shape) - 2
        else:
            extra_dims = 0
        Op.expand_node_shape(const_data, extra_dims)
        intermediate = Op.create_data_node(graph, input_node,
                                           {'shape': consumer.shape})

        add_op.create_node_with_data(inputs=[intermediate, const_data],
                                     data_nodes=consumer)
Exemplo n.º 2
0
    def apply_scale(graph: Graph, input_node: Node,
                    node_mean_scale_values: dict):
        """
        Insert a Mul node that divides the output of ``input_node`` by
        user-supplied scale values.

        The Mul is spliced between ``input_node`` and its output data node.
        The original edge is removed, and a new intermediate data node feeds
        the Mul together with a constant holding ``1 / scale``. The Mul then
        writes into the original output data node.

        Nothing is done when ``'scale'`` is absent, is ``None``, or all of
        its entries equal one.

        :param graph: graph modified in place; ``graph.graph['layout']`` is
            read to compute the dims count passed to ``Op.expand_node_shape``
        :param input_node: node whose output gets scaled
        :param node_mean_scale_values: per-node options; only ``'scale'`` is read
        :raises Error: if ``input_node`` has no valid ``shape`` attribute, or
            if any scale value is zero
        """
        if 'scale' not in node_mean_scale_values:
            return
        scale_values = node_mean_scale_values['scale']
        if scale_values is None:
            return
        scale = np.array(scale_values)
        # np.all also copes with scalar and multi-dimensional scale values,
        # which a Python-level loop over the elements does not.
        if np.all(scale == 1):
            return
        out_node = input_node.out_node()
        if not input_node.has_valid('shape'):
            raise Error("Node {} has not valid shape attribute".format(
                input_node.id))
        # Reject zero scales before the graph is modified: 1 / 0 would put
        # inf into the constant (numpy only warns about it).
        if np.any(scale == 0):
            raise Error("Node {} has zero scale value".format(input_node.id))
        input_shape = input_node.shape

        # Multiplier constant for the Mul node.
        value = 1 / scale
        graph.remove_edge(input_node.id, out_node.id)

        mul_node = Mul(graph, dict(name="Mul_"))
        mul_data = Op.create_input_data_node(graph, "data_mul_", value)
        Op.expand_node_shape(mul_data,
                             (len(input_shape) -
                              2 if graph.graph['layout'] == 'NCHW' else 0))
        mul_input = Op.create_data_node(graph, input_node,
                                        {'shape': out_node.shape})

        mul_node.create_node_with_data(inputs=[mul_input, mul_data],
                                       data_nodes=out_node)
Exemplo n.º 3
0
def _scale_input_action_mul(graph: nx.MultiDiGraph, match: dict, scale: float):
    """
    Divide the output of a matched placeholder by ``scale``.

    A Mul node with constant ``1 / scale`` is inserted between
    ``match['placeholder']`` and ``match['data']``. The direct edge is
    removed, a new intermediate data node feeds the Mul, and the Mul writes
    into the original data node.

    :param graph: graph modified in place
    :param match: pattern match providing the 'placeholder' and 'data' nodes
    :param scale: divisor applied to the placeholder output
    :raises Error: if the placeholder has no output nodes or no valid shape
    """
    tinput = match['placeholder']
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not tinput.out_nodes():
        raise Error("Node {} has no output nodes".format(tinput.id))
    if not tinput.has_valid('shape'):
        raise Error("Node {} has not valid shape attribute".format(tinput.id))

    input_shape = tinput.shape
    toutput = match['data']

    # Multiplier constant. It is computed before the graph is touched, so a
    # zero `scale` (ZeroDivisionError for Python numbers) leaves it intact.
    value = np.array([1 / scale])

    # Disconnect input from its data node
    graph.remove_edge(tinput.id, toutput.id)

    # Create Mul node
    mul_node = Mul(graph, dict(name="Mul1_"))
    mul_data = Op.create_input_data_node(graph, "data_mul_scale_", value)
    Op.expand_node_shape(
        mul_data,
        len(input_shape) - 2 if graph.graph['layout'] == 'NCHW' else 0)
    mul_input = Op.create_data_node(graph, tinput, {'shape': toutput.shape})

    mul_node.create_node_with_data(inputs=[mul_input, mul_data],
                                   data_nodes=toutput)
Exemplo n.º 4
0
def convert_batch_norm(graph: nx.MultiDiGraph):
    """
    Replace every FusedBatchNorm / BatchNorm / BatchNormalization node with
    a Mul->Add->Mul->Add sequence.

    Nodes whose inputs from port 1 onward are not all constant are skipped
    with a warning. For every other node:

    - all of the node's edges are removed;
    - ``scale = 1 / sqrt(variance + eps)`` and ``shift = -mean * scale`` are
      computed;
    - gamma (input 1), beta (input 2), scale and shift are handed to
      ``_fused_batch_norm_decomposition``, which builds the replacement.

    When ``fix_gamma`` is set, gamma's value is filled with ones in place.
    """
    bn_ops = ('FusedBatchNorm', 'BatchNorm', 'BatchNormalization')
    # The graph is mutated inside the loop, so iterate over a snapshot of ids.
    for node_id in list(graph.nodes()):
        bn = Node(graph, node_id)
        if not (bn.has_valid('op') and bn.op in bn_ops):
            continue

        out_data = bn.out_node()
        in_data = bn.in_node(0)

        non_const_flags = [bn.in_node(port).value is None
                           for port in range(1, len(bn.in_nodes()))]
        if any(non_const_flags):
            log.warning(
                'Cannot translate FusedBatchNorm {} node with non-constant weights'
                .format(bn.name if bn.has_valid('name') else '<UNKNOWN>'))
            continue

        gamma = bn.in_node(1)
        beta = bn.in_node(2)
        mean = bn.in_node(3)
        variance = bn.in_node(4)
        eps = bn.eps

        if bn.has_valid('fix_gamma') and bn.fix_gamma:
            gamma.value.fill(1.)

        can_be_fused = bool(bn.soft_get('can_be_fused'))

        # Detach the batch-norm node from every data node it touches.
        edges_to_drop = (
            (in_data.id, bn.id),
            (beta.id, bn.id),
            (gamma.id, bn.id),
            (mean.id, bn.id),
            (variance.id, bn.id),
            (bn.id, out_data.id),
        )
        for src_id, dst_id in edges_to_drop:
            graph.remove_edge(src_id, dst_id)

        scale = 1. / np.sqrt(variance.value + eps)
        shift = (mean.value * (-1)) * scale

        # NCHW gets (rank - 2) trailing dims; other layouts get none.
        if graph.graph['layout'] == 'NCHW':
            extra_dims = len(in_data.shape) - 2
        else:
            extra_dims = 0
        Op.expand_node_shape(gamma, extra_dims)
        Op.expand_node_shape(beta, extra_dims)

        for _ in range(extra_dims):
            scale = np.expand_dims(scale, axis=-1)
            shift = np.expand_dims(shift, axis=-1)

        _fused_batch_norm_decomposition(graph, in_data, out_data, gamma,
                                        beta, scale, shift, can_be_fused)
Exemplo n.º 5
0
def _bn_to_mul_add_action(graph: nx.MultiDiGraph, match: dict):
    """
    Replace a matched batch-norm node with an input -> Mul -> Add chain.

    The matched ``mean`` and ``variance`` data nodes are reused as the Mul
    and Add constants. Their values are overwritten with
    ``scale = 1 / sqrt(variance + epsilon)`` and ``shift = -mean * scale``.

    :param graph: graph modified in place
    :param match: pattern match holding the 'input', 'output', 'mean' and
        'variance' data nodes and the 'batch_norm' op node
    """
    src = match['input']
    dst = match['output']
    mean_data = match['mean']
    var_data = match['variance']
    bn = match['batch_norm']

    # Disconnect the batch-norm op from its input data nodes, then its output.
    for data in (src, mean_data, var_data):
        graph.remove_edge(data.node, bn.node)
    graph.remove_edge(bn.node, dst.node)

    scale = 1. / np.sqrt(var_data.value + bn.epsilon)
    shift = (mean_data.value * (-1)) * scale

    # NOTE(review): the matched mean/variance data nodes are overwritten in
    # place; any other consumer of them would observe the new values.
    mean_data.value = np.array(scale)
    var_data.value = np.array(shift)

    # NCHW gets (rank - 2) trailing dims; other layouts get none.
    if graph.graph['layout'] == 'NCHW':
        extra_dims = len(src.shape) - 2
    else:
        extra_dims = 0
    Op.expand_node_shape(mean_data, extra_dims)
    Op.expand_node_shape(var_data, extra_dims)

    can_be_fused = bool(bn.soft_get('can_be_fused'))

    mul_op = Mul(graph, dict(name="Mul_", can_be_fused=can_be_fused))
    add_op = Add(graph, dict(name="Add_", can_be_fused=can_be_fused))

    # Wire input -> Mul -> Add -> original output data node.
    scaled = mul_op.create_node_with_data(inputs=[src, mean_data])
    add_op.create_node_with_data(inputs=[scaled, var_data], data_nodes=dst)
Exemplo n.º 6
0
def _broadcast_scale_shift_operand(graph, node, data_node, input_node,
                                   broadcast_dims_cnt):
    """Make a ScaleShift operand broadcastable against ``input_node``'s data.

    Constant operands are expanded in place with ``Op.expand_node_shape``
    and returned unchanged. Non-constant operands get a Reshape that places
    their dims at ``node.axis`` of the input's rank, with size-1 dims
    elsewhere; the Reshape's output data node is returned.
    """
    if data_node.has_valid("value"):
        Op.expand_node_shape(data_node, broadcast_dims_cnt)
        return data_node
    reshape_dims = np.ones(len(input_node.shape), dtype=np.int64)
    for offset, dim in enumerate(data_node.shape):
        reshape_dims[node.axis + offset] = dim
    reshape = Reshape(
        graph, dict(name=data_node.name + "/Broadcast_", dim=reshape_dims))
    return reshape.create_node_with_data(inputs=[data_node])


def convert_scale_shift_to_mul_add(graph: nx.MultiDiGraph):
    """
    Decompose every ScaleShift node into Mul (scale) and Add (shift).

    - A shift that is absent (fewer than 3 inputs) or a constant of all zeros
      is dropped, so only the Mul is emitted. A non-constant shift cannot be
      checked and is always kept.
    - When, additionally, the scale is a constant of all ones, the ScaleShift
      is removed entirely by merging its input and output data nodes.
    - Constant operands get extra dims for NCHW via ``Op.expand_node_shape``.
      Non-constant operands are routed through a Reshape along ``node.axis``.

    Nodes whose ``can_be_fused`` attribute is explicitly ``False`` are skipped.
    """
    nodes = [
        Node(graph, node) for node in graph.nodes()
        if Node(graph, node).soft_get('op') == 'ScaleShift'
    ]
    for node in nodes:
        if node.soft_get('can_be_fused') is False:
            continue

        input_node = node.in_node(0)
        scale_node = node.in_node(1)
        shift_node = node.in_node(2) if len(node.in_nodes()) >= 3 else None
        output_node = node.out_node()

        # We don't need zero biases. Only a constant shift can be proven zero.
        # np.all also handles scalar and multi-dimensional values.
        has_biases = shift_node is not None and not (
            shift_node.has_valid("value")
            and np.all(np.asarray(shift_node.value) == 0))
        if not has_biases:
            shift_node = None
        has_weights = not (scale_node.has_valid("value")
                           and np.all(np.asarray(scale_node.value) == 1))

        mul_node = Mul(graph, dict(name=node.name + "/Mul_"))
        add_node = Add(graph, dict(name=node.name + "/Add_"))

        # Disconnect ScaleShift node
        graph.remove_edge(input_node.id, node.id)
        graph.remove_edge(node.id, output_node.id)

        # Expand dims for current layout
        broadcast_dims_cnt = len(
            input_node.shape) - 2 if graph.graph['layout'] == 'NCHW' else 0
        scale_node = _broadcast_scale_shift_operand(
            graph, node, scale_node, input_node, broadcast_dims_cnt)
        if shift_node is not None:
            shift_node = _broadcast_scale_shift_operand(
                graph, node, shift_node, input_node, broadcast_dims_cnt)

        # Connect input->mul->out->add->out
        if has_biases:
            scaled = mul_node.create_node_with_data(
                inputs=[input_node, scale_node])
            add_node.create_node_with_data(inputs=[scaled, shift_node],
                                           data_nodes=output_node)
        elif has_weights:
            mul_node.create_node_with_data(inputs=[input_node, scale_node],
                                           data_nodes=output_node)
        else:
            merge_data_nodes(graph, input_node, output_node)
            graph.remove_node(output_node.id)