Ejemplo n.º 1
0
def transform_fuse_activations(tf_graph):
    # type: (TFGraph)->None
    """Fold trailing activations into the operation that produces their input.

    Two rewrite passes are run over ``tf_graph``:

    1. ``tf.nn.relu`` whose first input is produced by one of the fusable
       operations is replaced by a copy of that producer carrying
       ``fused_activation_function='RELU'``.
    2. ``tf.clip_by_value`` whose bounds compare equal to ``[0]`` and ``[6]``
       is replaced the same way with ``fused_activation_function='RELU6'``.

    In both passes a producer that already has a truthy
    ``fused_activation_function`` attribute is left alone.
    """

    fusable_op_names = [
        "tf.add",
        "tf.subtract",
        "tf.multiply",
        "tf.divide",
        "tf.nn.conv2d",
        "tf.nn.depthwise_conv2d",
        "tf.nn.max_pool",
        "tf.nn.avg_pool",
        # "tf.nn.conv2d_transpose" is left out on purpose (not working yet)
        "tf.matmul",
        "tf.nn.l2_normalize",
        # "tf.concat" is left out on purpose (not working yet)
    ]

    def build_fused(match, producer_pattern, activation_pattern, kind):
        # The new operation reuses the producer's name and inputs, takes over
        # the activation's outputs, and records which activation was fused.
        producer_op = match[producer_pattern]
        return TFOperation(
            graph=tf_graph,
            name=producer_op.name,
            attribs=utils.dict_union(producer_op.attribs,
                                     dict(fused_activation_function=kind)),
            inputs=producer_op.inputs,
            outputs=match[activation_pattern].outputs)

    # Pass 1: producer -> tf.nn.relu
    relu_input = matcher.Tensor()
    relu_producer = matcher.Operation(name=fusable_op_names, outputs=relu_input)
    relu = matcher.Operation(name="tf.nn.relu", inputs={0: relu_input})

    def fuse_relu(m):
        return build_fused(m, relu_producer, relu, 'RELU')

    def relu_is_fusable(m):
        return not m[relu_producer].attribs.get('fused_activation_function')

    matcher.replace(tf_graph, relu, fuse_relu, relu_is_fusable)

    # Pass 2: producer -> tf.clip_by_value (fused only when it clips to [0, 6])
    clip_input = matcher.Tensor()
    clip_producer = matcher.Operation(name=fusable_op_names, outputs=clip_input)
    clip = matcher.Operation(name="tf.clip_by_value",
                             inputs={0: clip_input})

    def fuse_relu6(m):
        return build_fused(m, clip_producer, clip, 'RELU6')

    def clip_is_fusable_relu6(m):
        clip_op = m[clip]
        return (clip_op.inputs[1].data == [0]
                and clip_op.inputs[2].data == [6]
                and not m[clip_producer].attribs.get('fused_activation_function'))

    matcher.replace(
        graph=tf_graph,
        pattern=clip,
        replacement=fuse_relu6,
        condition=clip_is_fusable_relu6)
Ejemplo n.º 2
0
def _merge_pads(g):
    # type: (NNEFGraph)->None
    """Merge a preceding ``pad``/``box`` operation into the padding of its consumer.

    The pattern is a ``pad`` or ``box`` operation whose output is input 0 of
    a pooling or ``conv`` operation. When ``condition`` accepts the match,
    the pad's padding is added to the consumer's ``padding`` attribute, the
    consumer's ``border`` is set accordingly, and the pad operation is removed
    with ``graph_utils.remove_passthrough``.

    Raises:
        utils.NNEFToolsException: if, after merging, any ``box`` or ``pad``
            operation that still carries a ``_value`` attribute remains in
            the graph.
    """

    t = matcher.Tensor()
    pad = matcher.Operation(name=['box', 'pad'], outputs=t)
    sliding = matcher.Operation(name=['argmax_pool', 'max_pool', 'max_pool_with_index', 'avg_pool', 'conv'],
                                inputs={0: t})

    def condition(m):
        # type: (matcher.Match)->bool
        # A 'box' only qualifies when size, stride and dilation are all ones
        # and it is not normalizing.
        if not (m[pad].name == 'pad' or
                (m[pad].name == 'box'
                 and all(s == 1 for s in m[pad].attribs.get('size', []))
                 and all(s == 1 for s in m[pad].attribs.get('stride', []))
                 and all(s == 1 for s in m[pad].attribs.get('dilation', []))
                 and not m[pad].attribs.get('normalize', False))):
            return False

        value = m[pad].attribs.get('_value', 0.0)

        # Only zero fill and -inf fill can be merged.
        if value not in [0.0, float('-inf')]:
            return False

        # -inf fill is only accepted in front of max-type pooling.
        if value == float('-inf'):
            if m[sliding].name not in ['argmax_pool', 'max_pool', 'max_pool_with_index']:
                return False

        if m[pad].attribs.get('border', 'constant') != 'constant':
            return False

        # A consumer that already pads with a non-constant border cannot
        # absorb additional constant padding.
        if (m[sliding].attribs.get('border', 'constant') != 'constant'
                and any(p != 0 or q != 0 for p, q in m[sliding].attribs.get('padding', []))):
            return False

        # For conv, the pad must leave the first two dimensions untouched,
        # because action() drops those two entries before merging.
        if m[sliding].name == 'conv' and any(p != 0 or q != 0 for p, q in m[pad].attribs.get('padding', [])[:2]):
            return False

        return True

    def action(m):
        # type: (matcher.Match)->None
        value = m[pad].attribs.get('_value', 0.0)

        # Default: one (0, 0) pair per dimension of the padded tensor.
        # Note the bracket placement: `[(0, 0)] * rank` yields `rank` pairs,
        # whereas `[(0, 0) * rank]` would yield a single flat tuple of zeros
        # that cannot be zipped and unpacked as (before, after) pairs below.
        pad_padding = m[pad].attribs.get('padding', [(0, 0)] * m[t].rank)
        sliding_padding = m[sliding].attribs.get('padding', [(0, 0)] * m[t].rank)

        if m[sliding].name == 'conv':
            # Drop the first two entries of the pad's padding so the lengths
            # line up with the conv's own padding list.
            pad_padding = pad_padding[2:]

        assert len(pad_padding) == len(sliding_padding)

        m[sliding].attribs['padding'] = [(p + pp, q + qq) for (p, q), (pp, qq) in zip(pad_padding, sliding_padding)]
        # -inf fill is expressed on the consumer as border='ignore'.
        m[sliding].attribs['border'] = 'ignore' if value == float('-inf') else 'constant'

        graph_utils.remove_passthrough(g, m[pad])

    matcher.for_each(graph=g, pattern=sliding, action=action, condition=condition)

    # Any remaining box/pad that still carries '_value' could not be merged
    # and is reported as not exportable.
    for op in g.operations:
        if op.name in ['box', 'pad'] and '_value' in op.attribs:
            raise utils.NNEFToolsException('Could not export {} with value={}'.format(op.name, op.attribs['_value']))
Ejemplo n.º 3
0
def transform_fuse_bias_add_to_conv(tf_graph):
    # type: (TFGraph)->None
    """Fold ``tf.nn.bias_add`` into the convolution that feeds it.

    Matches a ``tf.nn.conv2d`` or ``tf.nn.depthwise_conv2d`` whose output is
    input 0 of a ``tf.nn.bias_add``. When the convolution has exactly two
    inputs, the pair is replaced by a single operation with the convolution's
    name and attribs, the convolution's inputs followed by the bias tensor,
    and the bias_add's outputs.
    """

    conv_result = matcher.Tensor()
    conv_pattern = matcher.Operation(name=["tf.nn.conv2d", "tf.nn.depthwise_conv2d"], outputs=conv_result)
    bias_add_pattern = matcher.Operation(name="tf.nn.bias_add", inputs={0: conv_result})

    def conv_with_bias(m):
        conv_op = m[conv_pattern]
        bias_add_op = m[bias_add_pattern]
        # The bias tensor (input 1 of bias_add) is appended as a third input.
        return TFOperation(graph=tf_graph,
                           name=conv_op.name,
                           attribs=conv_op.attribs,
                           inputs=tuple(conv_op.inputs) + (bias_add_op.inputs[1],),
                           outputs=bias_add_op.outputs)

    def conv_has_two_inputs(m):
        # Only convolutions with exactly two inputs are rewritten.
        return len(m[conv_pattern].inputs) == 2

    matcher.replace(tf_graph, bias_add_pattern, conv_with_bias, conv_has_two_inputs)
Ejemplo n.º 4
0
def transform_bias_add_conv(g):
    # type: (TFGraph)->None
    """Replace the third input of a conv-like operation with a following bias_add's bias.

    Matches one of the internal conv-like operations (pattern declared with
    ``inputs={2: None}``) whose output 0 is input 0 of a ``tf.nn.bias_add``.
    The pair is replaced by one operation that keeps the conv's name, attribs
    and first two inputs, takes the bias_add's second input as its third
    input, and produces the bias_add's outputs.
    """

    conv_like_names = [
        "_conv", "_planewise_conv", "_separable_conv", "_deconv",
        "_planewise_deconv"
    ]

    conv_result = matcher.Tensor()
    conv_pattern = matcher.Operation(name=conv_like_names,
                                     inputs={2: None},
                                     outputs={0: conv_result})
    bias_add_pattern = matcher.Operation(name="tf.nn.bias_add", inputs={0: conv_result})

    def conv_with_bias(m):
        conv_op = m[conv_pattern]
        bias_add_op = m[bias_add_pattern]
        return TFOperation(graph=g,
                           name=conv_op.name,
                           inputs=(conv_op.inputs[0], conv_op.inputs[1], bias_add_op.inputs[1]),
                           attribs=conv_op.attribs,
                           outputs=bias_add_op.outputs)

    matcher.replace(g, bias_add_pattern, conv_with_bias)
Ejemplo n.º 5
0
def _transform_nnef_filter_grad_to_tf(g, tensor, transforms_by_name, driver):
    # type: (BaseGraph, BaseTensor, typing.Dict[str, typing.List[Transform]], DataFormatOptimizationDriver)->None
    """Remove the transpose (and optional reshape) that follows a conv-grad-filter op.

    The producer of ``tensor`` must match one of two shapes:

    * ``conv_grad_filter -> transpose``
    * ``conv_grad_filter -> reshape -> transpose``

    On a match, the grad-filter output takes over the transpose output's name
    and its place in the graph outputs, and the transpose (plus reshape, if
    present) is removed from ``g``. A ``Transpose`` with the inverse of the
    removed permutation is passed to ``add_transform`` for the grad-filter
    output. In the reshape variant a ``Reshape`` to the grad-filter output's
    shape is passed as well.

    Raises:
        _TransformException: if ``tensor`` has no producer, if neither shape
            matches, if the transpose (or reshape) output has more than one
            consumer, or if the grad-filter output is itself a graph output.
    """

    assert driver.conv_grad_filter_op_names

    # NOTE(review): the message names TF_FILTER_GRAD_TO_NNEF although this
    # function handles the NNEF-to-TF direction; possibly a copy-paste from a
    # sibling function. Confirm before changing the text.
    failure_message = "Cannot apply TF_FILTER_GRAD_TO_NNEF"

    # Shape 1: conv_grad_filter -> transpose
    direct_grad_output = matcher.Tensor()
    direct_grad = matcher.Operation(name=driver.conv_grad_filter_op_names,
                                    outputs=direct_grad_output)
    direct_transpose = matcher.Operation(name=driver.transpose_op_name,
                                         inputs=direct_grad_output)

    # Shape 2: conv_grad_filter -> reshape -> transpose
    reshaped_grad_output = matcher.Tensor()
    reshaped_grad = matcher.Operation(name=driver.conv_grad_filter_op_names,
                                      outputs=reshaped_grad_output)
    reshape_output = matcher.Tensor()
    reshape_pattern = matcher.Operation(name=driver.reshape_op_name,
                                        inputs=reshaped_grad_output,
                                        outputs=reshape_output)
    reshaped_transpose = matcher.Operation(name=driver.transpose_op_name,
                                           inputs=reshape_output)

    if tensor.producer is None:
        raise _TransformException(failure_message)

    found = matcher.match(g, tensor.producer,
                          matcher.OrPattern(direct_transpose, reshaped_transpose))

    # Work out which shape matched and collect the operations involved.
    if direct_transpose in found:
        has_reshape = False
        reshape = None
        cgf = found[direct_grad]
        transpose = found[direct_transpose]
    elif reshaped_transpose in found:
        has_reshape = True
        cgf = found[reshaped_grad]
        reshape = found[reshape_pattern]
        transpose = found[reshaped_transpose]
    else:
        raise _TransformException(failure_message)

    # The intermediate tensors may have at most one consumer, and the
    # grad-filter output must not already be a graph output.
    if has_reshape and len(reshape.output.consumers) > 1:
        raise _TransformException(failure_message)
    if len(transpose.output.consumers) > 1 or cgf.output in g.outputs:
        raise _TransformException(failure_message)

    cgf.output.name = transpose.output.name

    removed_axes = driver.get_axes_from_transpose(transpose)
    add_transform(transforms_by_name, cgf.output,
                  Transpose(utils.inverse_permutation(removed_axes)))
    if has_reshape:
        add_transform(transforms_by_name, cgf.output,
                      Reshape(cgf.output.shape))

    graph_utils.replace_tensor_in_outputs(g, transpose.output, cgf.output)

    obsolete_ops = [transpose, reshape] if has_reshape else [transpose]
    graph_utils.remove_subgraph(g, obsolete_ops)