Example 1
# Excerpt of a nested helper: visited_tensors, axes, driver,
# merge_into_constants and merge_into_variables are closure variables
# from the enclosing function.
def is_boundary(tensor2, from_up, started_down):
    if tensor2 in visited_tensors:
        return False
    if started_down:
        perm = utils.inverse_permutation(axes)
    else:
        perm = axes
    if from_up:
        is_boundary_transpose = (
            tensor2.producer is not None
            and tensor2.producer.name == driver.transpose_op_name
            and driver.get_axes_from_transpose(tensor2.producer) == perm)
        is_boundary_squeeze = (
            tensor2.producer is not None
            and tensor2.producer.name == driver.squeeze_op_name
            and _is_squeeze_invariant_to_perm(
                driver.get_axes_from_squeeze(tensor2.producer), perm))

        return is_boundary_transpose or is_boundary_squeeze
    else:
        is_boundary_transpose = (
            tensor2.producer is not None
            and tensor2.producer.name == driver.transpose_op_name
            and driver.get_axes_from_transpose(tensor2.producer)
            == utils.inverse_permutation(perm))
        is_boundary_varlike = (
            ((merge_into_constants and tensor2.is_constant) or
             (merge_into_variables and tensor2.is_variable))
            and len(tensor2.shape) in [0, 1, len(axes)])

        return is_boundary_transpose or is_boundary_varlike
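
A note on the direction handling above: a transpose met on the boundary must carry exactly the permutation that cancels the recorded axes, which is what utils.inverse_permutation computes. A self-contained sanity check (illustrative only):

perm = [0, 2, 3, 1]                                     # recorded going down
cancelling = [perm.index(i) for i in range(len(perm))]  # inverse permutation
assert cancelling == [0, 3, 1, 2]
assert [perm[p] for p in cancelling] == [0, 1, 2, 3]    # composition is identity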
Example 2
def evaluate_shape_of_transpose_grad(op):
    # type: (TFOperation)->None
    grad_shape = op.inputs[2].shape  # grad
    perm = op.attribs["orig_perm"]
    if perm is None:
        op.output.shape = _unify_shape(reversed(grad_shape))
    else:
        op.output.shape = _unify_shape(_np_permute(grad_shape, utils.inverse_permutation(perm)))
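
A worked example of the shape rule above, assuming _np_permute(values, perm) returns [values[p] for p in perm]: a forward transpose with perm=[0, 2, 3, 1] maps shape [1, 2, 3, 4] to [1, 3, 4, 2], so permuting the incoming grad shape by the inverse permutation recovers the input shape:

grad_shape = [1, 3, 4, 2]
perm = [0, 2, 3, 1]
perm_inv = [perm.index(i) for i in range(len(perm))]  # [0, 3, 1, 2]
assert [grad_shape[p] for p in perm_inv] == [1, 2, 3, 4]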
Example 3
def transform_transpose_grad(g, op):
    # type: (TFGraph, TFOperation)->None

    grad = op.inputs[2]

    TFOperation(
        graph=g,
        name="tf.transpose",
        inputs=grad,
        attribs=dict(perm=utils.inverse_permutation(op.attribs["orig_perm"])),
        outputs=op.outputs)

    g.remove_operation(op, unlink=True)
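
The rewrite relies on the identity that the gradient of transpose(x, perm) with respect to x is transpose(grad, inverse_permutation(perm)). A quick NumPy shape check of that identity (illustrative sketch, not converter code):

import numpy as np

perm = [0, 2, 3, 1]
x = np.random.rand(1, 2, 3, 4)
g = np.random.rand(*np.transpose(x, perm).shape)      # upstream grad
perm_inv = [perm.index(i) for i in range(len(perm))]
assert np.transpose(g, perm_inv).shape == x.shape     # grad w.r.t. x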
Example 4
def _get_transform_for_box_or_max_pool(input_shape, active):
    # type: (List[int], List[bool])->Any
    assert len(input_shape) >= 3
    assert len(input_shape) == len(active)
    if sum(active) > 3:
        raise utils.NNEFToolsException(
            "Sliding window operations are not supported if they have more than 3 'active' dimensions. "
            "We have: {}".format(sum(active)))
    if 3 <= len(input_shape) <= 5 and not active[0] and not active[1]:
        # Direct support: the first two dims are inactive
        return None, None, None, None
    else:
        inactive_dims = [i for i, a in enumerate(active) if not a]
        active_dims = [i for i, a in enumerate(active) if a]
        inactive_shape = [s for i, s in enumerate(input_shape) if i not in active_dims]
        active_shape = [s for i, s in enumerate(input_shape) if i in active_dims]
        perm = inactive_dims + active_dims
        perm_inv = utils.inverse_permutation(perm)
    return perm, perm_inv, inactive_shape, active_shape
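
Assuming the function above is importable (it only depends on utils for inverse_permutation and the exception type), a rank-5 input whose dims 1 and 3 are active is partitioned like this:

input_shape = [2, 8, 3, 9, 5]
active = [False, True, False, True, False]
perm, perm_inv, inactive_shape, active_shape = \
    _get_transform_for_box_or_max_pool(input_shape, active)
assert perm == [0, 2, 4, 1, 3]       # inactive dims first, active dims last
assert perm_inv == [0, 3, 1, 4, 2]
assert inactive_shape == [2, 3, 5]
assert active_shape == [8, 9]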
Example 5
def transform_remove_inverse_transposes(
        g,  # type: BaseGraph
        transforms_by_name,  # type:typing.Dict[str, typing.List[Transform]]
        merge_into_constants,  # type: bool
        merge_into_variables,  # type: bool
        driver,  # type: DataFormatOptimizationDriver
        transposable_ops=None,  # type: typing.Optional[typing.List[TransposableOperation]]
):
    # type: (...)-> None

    if transposable_ops is None:
        transposable_ops = []

    transposable_op_by_name = {
        top.name: top for top in transposable_ops
    }  # type: typing.Dict[str, TransposableOperation]

    # Pad partial transpose axes to full rank, e.g. axes [1, 0] on a rank-4
    # output become [1, 0, 2, 3].
    for op in g.operations:
        if op.name == driver.transpose_op_name:
            axes = driver.get_axes_from_transpose(op)
            if op.output.rank > len(axes):
                driver.set_axes_on_transpose(
                    op, axes + list(range(len(axes), op.output.rank)))

    matches = _find_inverse_transposes(
        g,
        transposable_op_names=set(six.iterkeys(transposable_op_by_name)),
        merge_into_constants=merge_into_constants,
        merge_into_variables=merge_into_variables,
        driver=driver)

    for axes, subgraph in matches:
        upper_perm = (axes if subgraph.started_down
                      else utils.inverse_permutation(axes))
        lower_perm = utils.inverse_permutation(upper_perm)

        upper_boundary = [be for be in subgraph.boundary_elements if not be.from_up]
        lower_boundary = [be for be in subgraph.boundary_elements if be.from_up]

        for _, tensor in upper_boundary:
            if tensor.producer is not None and tensor.producer.name == driver.transpose_op_name:
                if tensor in g.outputs:
                    graph_output = driver.create_tensor(
                        graph=g,
                        name=tensor.name,
                        shape=utils.apply_permutation(
                            tensor.producer.input.shape, upper_perm),
                        dtype=tensor.producer.input.dtype)
                    driver.create_transpose_op(graph=g,
                                               input=tensor.producer.input,
                                               axes=list(upper_perm),
                                               output=graph_output)
                    graph_utils.replace_tensor_in_outputs(
                        g, tensor, graph_output)
                elif (len(tensor.producer.input.consumers) == 1
                      and tensor.producer.input not in g.inputs
                      and tensor.producer.input not in g.outputs):
                    tensor.producer.input.name = tensor.name
                    add_transform(transforms_by_name, tensor.producer.input,
                                  Transpose(lower_perm))
                remove_passthrough_ex(g, tensor.producer)
            else:
                assert (merge_into_variables and tensor.is_variable) \
                       or (merge_into_constants and tensor.is_constant)

                apply_transpose_to_varlike(tensor, lower_perm,
                                           transforms_by_name)

        skipped_ops = set(
            tensor.producer for tensor in
            subgraph.skipped_tensors)  # type: typing.Set[BaseOperation]
        for op in skipped_ops:
            assert op.name in transposable_op_by_name
            transposable_op_by_name[op.name].dg_transpose(
                _transposer, g, op, lower_perm)
            for output in op.outputs:
                if output in g.outputs:
                    graph_output = driver.create_tensor(graph=g,
                                                        name=output.name,
                                                        shape=output.shape,
                                                        dtype=output.dtype)
                    driver.create_transpose_op(graph=g,
                                               input=output,
                                               axes=list(upper_perm),
                                               output=graph_output)

                    graph_utils.replace_tensor_in_outputs(
                        g, output, graph_output)
                    output.name = None
                    output.shape = utils.apply_permutation(
                        output.shape, lower_perm)
                else:
                    output.shape = utils.apply_permutation(
                        output.shape, lower_perm)
                    add_transform(transforms_by_name, output,
                                  Transpose(lower_perm))

        for _, tensor in lower_boundary:
            if tensor.producer is not None and tensor.producer.name == driver.transpose_op_name:
                if tensor in g.outputs:
                    graph_output = driver.create_tensor(
                        graph=g,
                        name=tensor.name,
                        shape=tensor.producer.input.shape,
                        dtype=tensor.producer.input.dtype)

                    driver.create_copy_op(graph=g,
                                          input=tensor.producer.input,
                                          output=graph_output)

                    graph_utils.replace_tensor_in_outputs(
                        g, tensor, graph_output)
                remove_passthrough_ex(g, tensor.producer)
            elif tensor.producer is not None and tensor.producer.name == driver.squeeze_op_name:
                driver.set_axes_on_squeeze(
                    tensor.producer,
                    sorted(
                        Transposer.apply_permutation_to_axes(
                            driver.get_axes_from_squeeze(tensor.producer),
                            lower_perm)))
            else:
                assert False, "Unexpected producer on lower boundary"

    graph_utils.remove_unreachable(g)
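
The permutation bookkeeping above is easiest to verify on a concrete match (illustrative values, not repository code): the upper boundary absorbs upper_perm and the lower boundary its inverse, so applying both in sequence is the identity and the transpose pair can be removed:

upper_perm = [0, 2, 3, 1]                                   # started_down case
lower_perm = [upper_perm.index(i) for i in range(4)]        # [0, 3, 1, 2]
assert [upper_perm[p] for p in lower_perm] == [0, 1, 2, 3]  # the pair cancels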
Example 6
def transform_io(g, io_transform, transforms_by_name, driver):
    # type:(BaseGraph, TrafoOrTrafoDictType, typing.Dict[str, typing.List[Transform]], DataFormatOptimizationDriver)->None
    io_tensors_by_name = {t.name: t for t in list(g.inputs) + list(g.outputs)}

    transform_by_io_tensor = {}
    if isinstance(io_transform, dict):
        for k, v in six.iteritems(io_transform):
            assert isinstance(k, (str, driver.tensor_type)), \
                "io_transform: Key type must be {} or str".format(driver.tensor_type.__name__)
            assert isinstance(
                v, Transform), "io_transform: Value type must be Transform"

            if isinstance(k, BaseTensor):
                assert k in six.itervalues(io_tensors_by_name)
            else:
                assert k in io_tensors_by_name
                k = io_tensors_by_name[k]
            transform_by_io_tensor[k] = v
        for io_tensor in six.itervalues(io_tensors_by_name):
            assert io_tensor in transform_by_io_tensor, \
                "io_transform: Please specify transform for all io tensors. " \
                "You can use graph_optimizer.IDENTITY if no change is required."
    else:
        assert isinstance(io_transform, Transform), \
            "io_transform must be Transform or Dict[str, Transform] or Dict[NNEFTensor, Transform]"
        for t in six.itervalues(io_tensors_by_name):
            transform_by_io_tensor[t] = io_transform

    for tensor, transform in six.iteritems(transform_by_io_tensor):
        assert bool(tensor in g.inputs) != bool(tensor in g.outputs), \
            "Tensor must be input or output (and not both)"

        assert isinstance(transform, (Transpose, _CustomTransform)), \
            "Unsupported io_transform"

        if isinstance(transform, _Identity):
            continue

        if isinstance(transform, _SmartTFNCHWToNCHW):
            try:
                _transform_tf_filter_grad_to_nnef(g, tensor,
                                                  transforms_by_name, driver)
            except _TransformException:
                pass
            continue

        if isinstance(transform, _SmartNHWCToNCHW):
            if tensor.rank <= 2:
                continue
            transform = Transpose([0, tensor.rank - 1] +
                                  list(range(tensor.rank))[1:-1])

        if isinstance(transform, _SmartTFNHWCToNCHW):
            try:
                _transform_tf_filter_grad_to_nnef(g, tensor,
                                                  transforms_by_name, driver)
                continue
            except _TransformException:
                if tensor.rank <= 2:
                    continue
                transform = Transpose([0, tensor.rank - 1] +
                                      list(range(tensor.rank))[1:-1])

        if isinstance(transform, _SmartNCHWToTFNCHW):
            try:
                _transform_nnef_filter_grad_to_tf(g, tensor,
                                                  transforms_by_name, driver)
            except _TransformException:
                pass
            continue

        if isinstance(transform, _SmartNCHWToNHWC):
            if tensor.rank <= 2:
                continue
            transform = Transpose([0] + list(range(tensor.rank))[2:] + [1])

        if isinstance(transform, _SmartNCHWToTFNHWC):
            try:
                _transform_nnef_filter_grad_to_tf(g, tensor,
                                                  transforms_by_name, driver)
                continue
            except _TransformException:
                if tensor.rank <= 2:
                    continue
                transform = Transpose([0] + list(range(tensor.rank))[2:] + [1])

        if isinstance(transform, _TFFilterGradToNNEF):
            _transform_tf_filter_grad_to_nnef(g, tensor, transforms_by_name,
                                              driver)
            continue

        if isinstance(transform, _NNEFFilterGradToTF):
            _transform_nnef_filter_grad_to_tf(g, tensor, transforms_by_name,
                                              driver)
            continue

        assert isinstance(transform, Transpose), \
            "Unsupported io_transform: {}".format(transform)
        assert len(transform.axes) == tensor.rank, "Transpose: invalid rank"

        if transform.is_identity():
            continue

        if tensor in g.inputs:
            assert tensor.name

            new_input_tensor = driver.create_tensor(
                graph=g,
                name=tensor.name,
                shape=utils.apply_permutation(tensor.shape, transform.axes),
                dtype=tensor.dtype)
            add_transform(transforms_by_name, new_input_tensor, transform)

            transpose = driver.create_transpose_op(
                graph=g,
                input=new_input_tensor,
                axes=utils.inverse_permutation(transform.axes),
                output=driver.create_tensor(graph=g,
                                            name=None,
                                            shape=tensor.shape,
                                            dtype=tensor.dtype))

            graph_utils.replace_tensor_in_inputs(g, tensor, new_input_tensor)
            graph_utils.replace_tensor_in_consumers(g,
                                                    tensor,
                                                    transpose.output,
                                                    remove=True)
        else:  # output
            transpose = driver.create_transpose_op(
                graph=g,
                input=tensor,
                axes=transform.axes,
                output=driver.create_tensor(graph=g,
                                            name=tensor.name,
                                            shape=utils.apply_permutation(
                                                tensor.shape, transform.axes),
                                            dtype=tensor.dtype))
            add_transform(transforms_by_name, transpose.output, transform)
            tensor.name = None

            graph_utils.replace_tensor_in_outputs(g, tensor, transpose.output)
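
For a rank-4 tensor, the "smart" layout transforms above reduce to the familiar permutations:

rank = 4
assert [0, rank - 1] + list(range(rank))[1:-1] == [0, 3, 1, 2]  # NHWC -> NCHW
assert [0] + list(range(rank))[2:] + [1] == [0, 2, 3, 1]        # NCHW -> NHWC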
Example 7
def _transform_nnef_filter_grad_to_tf(g, tensor, transforms_by_name, driver):
    # type: (BaseGraph, BaseTensor, typing.Dict[str, typing.List[Transform]], DataFormatOptimizationDriver)->None

    assert driver.conv_grad_filter_op_names

    cgf1_output = matcher.Tensor()
    cgf1 = matcher.Operation(name=driver.conv_grad_filter_op_names,
                             outputs=cgf1_output)
    transpose1 = matcher.Operation(name=driver.transpose_op_name,
                                   inputs=cgf1_output)

    cgf2_output = matcher.Tensor()
    cgf2 = matcher.Operation(name=driver.conv_grad_filter_op_names,
                             outputs=cgf2_output)
    reshape2_output = matcher.Tensor()
    reshape2 = matcher.Operation(name=driver.reshape_op_name,
                                 inputs=cgf2_output,
                                 outputs=reshape2_output)
    transpose2 = matcher.Operation(name=driver.transpose_op_name,
                                   inputs=reshape2_output)

    if tensor.producer is None:
        raise _TransformException("Cannot apply TF_FILTER_GRAD_TO_NNEF")

    m = matcher.match(g, tensor.producer,
                      matcher.OrPattern(transpose1, transpose2))

    if transpose1 in m:
        cgf = m[cgf1]  # type: BaseOperation
        transpose = m[transpose1]  # type: BaseOperation
        if not (len(transpose.output.consumers) <= 1
                and cgf.output not in g.outputs):
            raise _TransformException("Cannot apply TF_FILTER_GRAD_TO_NNEF")
        cgf.output.name = transpose.output.name
        add_transform(
            transforms_by_name, cgf.output,
            Transpose(
                utils.inverse_permutation(
                    driver.get_axes_from_transpose(transpose))))
        graph_utils.replace_tensor_in_outputs(g, transpose.output, cgf.output)
        graph_utils.remove_subgraph(g, [transpose])
    elif transpose2 in m:
        cgf = m[cgf2]  # type: BaseOperation
        reshape = m[reshape2]  # type: BaseOperation
        transpose = m[transpose2]  # type: BaseOperation

        if not (len(reshape.output.consumers) <= 1
                and len(transpose.output.consumers) <= 1
                and cgf.output not in g.outputs):
            raise _TransformException("Cannot apply TF_FILTER_GRAD_TO_NNEF")

        cgf.output.name = transpose.output.name
        add_transform(
            transforms_by_name, cgf.output,
            Transpose(
                utils.inverse_permutation(
                    driver.get_axes_from_transpose(transpose))))
        add_transform(transforms_by_name, cgf.output,
                      Reshape(cgf.output.shape))

        graph_utils.replace_tensor_in_outputs(g, transpose.output, cgf.output)
        graph_utils.remove_subgraph(g, [transpose, reshape])
    else:
        raise _TransformException("Cannot apply TF_FILTER_GRAD_TO_NNEF")
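
Schematically, the two chains the matcher above recognizes are:

# pattern 1:  conv_grad_filter -> transpose
# pattern 2:  conv_grad_filter -> reshape -> transpose

In both cases the chain is folded back into the conv-grad-filter output, and the removed transpose (and reshape) are logged via add_transform against the surviving tensor's name.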
Example 8
def inverse_permutation(perm):
    return utils.inverse_permutation(perm)
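
The wrapper above only delegates; for reference, a minimal sketch of the expected semantics (the actual utils implementation may differ):

def inverse_permutation_sketch(perm):
    inverse = [0] * len(perm)
    for i, p in enumerate(perm):
        inverse[p] = i
    return inverse

assert inverse_permutation_sketch([0, 2, 3, 1]) == [0, 3, 1, 2]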
Example 9
def expand_softmax(tf_graph, tf_op):
    assert tf_op.input.rank != 0

    axis = tf_op.attribs.get('axis')
    if axis is None:
        axis = -1
    if axis < 0:
        axis += tf_op.input.rank

    tf_op.attribs['axis'] = -1

    if tf_op.input.rank == 2 and axis == 1:
        return

    if axis != tf_op.input.rank - 1:
        perm = utils.without(range(tf_op.input.rank), axis) + [axis]
        perm_inv = utils.inverse_permutation(perm)
        transpose = TFOperation(graph=tf_graph,
                                name="tf.transpose",
                                inputs=tf_op.input,
                                attribs=dict(perm=perm),
                                outputs=TFTensor(graph=tf_graph,
                                                 name=None,
                                                 shape=infer.transpose(
                                                     input=tf_op.input.shape,
                                                     axes=perm),
                                                 dtype=tf_op.input.dtype))
        tf_op.inputs = transpose.output
        old_output = tf_op.output
        tf_op.outputs = TFTensor(graph=tf_graph,
                                 name=None,
                                 shape=list(tf_op.input.shape),
                                 dtype=tf_op.input.dtype)
        TFOperation(graph=tf_graph,
                    name="tf.transpose",
                    inputs=tf_op.output,
                    attribs=dict(perm=perm_inv),
                    outputs=old_output)

    if tf_op.input.rank != 2:
        shape = [-1, tf_op.input.shape[-1]]
        reshape = TFOperation(graph=tf_graph,
                              name="tf.reshape",
                              inputs=tf_op.input,
                              attribs=dict(shape=shape),
                              outputs=TFTensor(graph=tf_graph,
                                               name=None,
                                               shape=infer.reshape(
                                                   input=tf_op.input.shape,
                                                   shape=shape),
                                               dtype=tf_op.input.dtype))
        tf_op.inputs = reshape.output
        old_output = tf_op.output
        tf_op.outputs = TFTensor(graph=tf_graph,
                                 name=None,
                                 shape=list(tf_op.input.shape),
                                 dtype=tf_op.input.dtype)
        TFOperation(graph=tf_graph,
                    name="tf.reshape",
                    inputs=tf_op.output,
                    attribs=dict(shape=old_output.shape),
                    outputs=old_output)
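
A worked example of the axis handling above: softmax over axis=1 of a rank-4 tensor first moves the active axis last, then flattens to 2-D before the softmax and undoes both steps afterwards (assuming utils.without(seq, x) returns seq without x):

rank, axis = 4, 1
perm = [i for i in range(rank) if i != axis] + [axis]  # utils.without(...) + [axis]
perm_inv = [perm.index(i) for i in range(rank)]
assert perm == [0, 2, 3, 1]
assert perm_inv == [0, 3, 1, 2]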