Ejemplo n.º 1
0
def apply_transpose_to_varlike(tensor, axes, transforms_by_name):
    # type: (BaseTensor, typing.List[int], typing.Dict[str, typing.List[Transform]])->None
    """Permute the axes of a variable or constant tensor in place.

    Tensors of rank 0 or 1 are left untouched. Otherwise ``tensor.shape`` is
    permuted by ``axes``, and the data is updated as follows:

    * A variable with non-empty ``data`` (an array-like exposing ``.size``)
      gets ``np.transpose(data, axes)``.
    * A constant whose ``data`` holds more than one value is reshaped to the
      old shape, transposed, and flattened back into a plain list.
    * Single-value constants keep their data unchanged; presumably these are
      broadcast scalars (TODO confirm).

    Finally ``Transpose(axes)`` is recorded for the tensor in
    ``transforms_by_name``.
    """
    if tensor.rank > 1:
        previous_shape = tensor.shape
        tensor.shape = utils.apply_permutation(previous_shape, axes)

        if tensor.is_variable and tensor.data.size > 0:
            tensor.data = np.transpose(tensor.data, axes)
        elif tensor.is_constant and len(tensor.data) > 1:
            # Constant data is stored flat, so rebuild the old-shaped array first.
            as_array = np.array(tensor.data).reshape(previous_shape)
            tensor.data = as_array.transpose(axes).flatten().tolist()

        add_transform(transforms_by_name, tensor, Transpose(axes))
Ejemplo n.º 2
0
class IOTransform(object):
    """Catalogue of the io transforms accepted by ``transform_io``.

    The plain layout conversions are ``Transpose`` objects. Their axes follow
    the ``numpy.transpose`` convention: entry ``i`` is the source axis that
    becomes output axis ``i``. For example, NHWC -> NCHW is ``[0, 3, 1, 2]``.

    The remaining members are instances of custom transform classes defined
    elsewhere. ``transform_io`` resolves them per tensor.
    """

    # Re-exported so callers can build arbitrary permutations.
    Transpose = Transpose

    # Rank 5 (three spatial dimensions).
    NDHWC_TO_NCDHW = Transpose([0, 4, 1, 2, 3])
    NCDHW_TO_NDHWC = Transpose([0, 2, 3, 4, 1])

    # Rank 4 (two spatial dimensions).
    NHWC_TO_NCHW = Transpose([0, 3, 1, 2])
    NCHW_TO_NHWC = Transpose([0, 2, 3, 1])

    # Rank 3 (one spatial dimension). Swapping the last two axes is its own
    # inverse, so both directions are [0, 2, 1]. Axis 3 does not exist on a
    # rank-3 tensor.
    NWC_TO_NCW = Transpose([0, 2, 1])
    NCW_TO_NWC = Transpose([0, 2, 1])

    # No-op: transform_io skips tensors mapped to this.
    IDENTITY = _Identity()

    # Rewrite the conv-grad-filter subgraph that produces the io tensor.
    TF_FILTER_GRAD_TO_NNEF = _TFFilterGradToNNEF()
    NNEF_FILTER_GRAD_TO_TF = _NNEFFilterGradToTF()

    # Rank-adaptive: tensors of rank <= 2 are skipped; otherwise the channel
    # axis is moved between position 1 and the last position.
    SMART_NCHW_TO_NHWC = _SmartNCHWToNHWC()
    SMART_NHWC_TO_NCHW = _SmartNHWCToNCHW()

    # Try the filter-gradient rewrite first; if it does not apply, fall back
    # to the rank-adaptive transpose.
    SMART_TF_NHWC_TO_NCHW = _SmartTFNHWCToNCHW()

    # Try the filter-gradient rewrite; otherwise leave the tensor unchanged.
    SMART_TF_NCHW_TO_NCHW = _SmartTFNCHWToNCHW()

    # Try the filter-gradient rewrite first; if it does not apply, fall back
    # to the rank-adaptive transpose.
    SMART_NCHW_TO_TF_NHWC = _SmartNCHWToTFNHWC()

    # Try the filter-gradient rewrite; otherwise leave the tensor unchanged.
    SMART_NCHW_TO_TF_NCHW = _SmartNCHWToTFNCHW()
Ejemplo n.º 3
0
def transform_remove_inverse_transposes(
        g,  # type: BaseGraph
        transforms_by_name,  # type:typing.Dict[str, typing.List[Transform]]
        merge_into_constants,  # type: bool
        merge_into_variables,  # type: bool
        driver,  # type: DataFormatOptimizationDriver
        transposable_ops=None,  # type: typing.Optional[typing.List[TransposableOperation]]
):
    # type: (...)-> None
    """Remove transposes that, per ``_find_inverse_transposes``, cancel across a subgraph.

    ``_find_inverse_transposes`` yields ``(axes, subgraph)`` matches. For each
    match the permutation is pushed through the subgraph:

    * Transposes on the subgraph boundary are removed. Graph outputs keep
      their name via a replacement transpose or copy op.
    * Variables and constants on the upper boundary are permuted in place.
    * Transposable ops inside the subgraph are re-laid-out through their
      ``dg_transpose`` callback, and their outputs are re-shaped.
    * Squeezes on the lower boundary get their axes remapped.

    Re-named or re-shaped tensors get a ``Transpose`` recorded in
    ``transforms_by_name``. Unreachable parts of ``g`` are pruned at the end.

    Args:
        g: The graph, modified in place.
        transforms_by_name: Per-tensor-name transform lists, appended to via
            ``add_transform``.
        merge_into_constants: Forwarded to ``_find_inverse_transposes``. A
            non-transpose upper-boundary tensor must then be a constant (see
            the assert in the loop).
        merge_into_variables: As ``merge_into_constants``, for variables.
        driver: Backend adapter supplying op names, axis getters/setters and
            tensor/op factories.
        transposable_ops: Operations (looked up by ``name``) that may sit
            inside a matched subgraph and can transpose themselves. Defaults
            to none.
    """

    if transposable_ops is None:
        transposable_ops = []

    transposable_op_by_name = {
    }  # type: typing.Dict[str, TransposableOperation]
    transposable_op_by_name.update({top.name: top for top in transposable_ops})

    # Normalize transposes whose axes list is shorter than the output rank:
    # append the missing trailing axes unchanged (identity on those axes).
    for op in g.operations:
        if op.name == driver.transpose_op_name and op.output.rank > len(
                driver.get_axes_from_transpose(op)):
            driver.set_axes_on_transpose(
                op,
                driver.get_axes_from_transpose(op) + list(range(
                    op.output.rank))[len(driver.get_axes_from_transpose(op)):])

    matches = _find_inverse_transposes(
        g,
        transposable_op_names=set(six.iterkeys(transposable_op_by_name)),
        merge_into_constants=merge_into_constants,
        merge_into_variables=merge_into_variables,
        driver=driver)

    for axes, subgraph in matches:
        # Orient the match: upper_perm is `axes` or its inverse depending on
        # subgraph.started_down; lower_perm undoes upper_perm.
        upper_perm = axes if subgraph.started_down else utils.inverse_permutation(
            axes)
        lower_perm = utils.inverse_permutation(upper_perm)

        # Boundary elements unpack as (_, tensor) pairs; `from_up` splits them
        # into the upper and lower boundary.
        upper_boundary = [
            be for be in subgraph.boundary_elements if not be.from_up
        ]
        lower_boundary = [
            be for be in subgraph.boundary_elements if be.from_up
        ]

        for _, tensor in upper_boundary:
            if tensor.producer is not None and tensor.producer.name == driver.transpose_op_name:
                if tensor in g.outputs:
                    # Preserve the graph output under the same name: recompute
                    # it from the transpose's input with upper_perm.
                    graph_output = driver.create_tensor(
                        graph=g,
                        name=tensor.name,
                        shape=utils.apply_permutation(
                            tensor.producer.input.shape, upper_perm),
                        dtype=tensor.producer.input.dtype)
                    driver.create_transpose_op(graph=g,
                                               input=tensor.producer.input,
                                               axes=list(upper_perm),
                                               output=graph_output)
                    graph_utils.replace_tensor_in_outputs(
                        g, tensor, graph_output)
                elif (len(tensor.producer.input.consumers) == 1
                      and tensor.producer.input not in g.inputs
                      and tensor.producer.input not in g.outputs):
                    # The transpose's input is private to it, so the input
                    # inherits the removed tensor's name, and the
                    # lower_perm-permuted layout is recorded.
                    tensor.producer.input.name = tensor.name
                    add_transform(transforms_by_name, tensor.producer.input,
                                  Transpose(lower_perm))
                # Drop the boundary transpose itself. Presumably
                # remove_passthrough_ex reconnects its consumers to its input;
                # verify against its definition.
                remove_passthrough_ex(g, tensor.producer)
            else:
                # A non-transpose upper-boundary tensor must be a variable or
                # constant that absorbs the permutation directly.
                assert (merge_into_variables and tensor.is_variable) \
                       or (merge_into_constants and tensor.is_constant)

                apply_transpose_to_varlike(tensor, lower_perm,
                                           transforms_by_name)

        # Ops strictly inside the subgraph (producers of skipped tensors).
        skipped_ops = set(
            tensor.producer for tensor in
            subgraph.skipped_tensors)  # type: typing.Set[BaseOperation]
        for op in skipped_ops:
            # Each must be registered as transposable; let it re-lay itself out.
            assert op.name in transposable_op_by_name
            transposable_op_by_name[op.name].dg_transpose(
                _transposer, g, op, lower_perm)
            for output in op.outputs:
                if output in g.outputs:
                    # Graph outputs keep their name and original (pre-permutation)
                    # shape via a compensating upper_perm transpose. The
                    # internal tensor becomes anonymous and permuted.
                    graph_output = driver.create_tensor(graph=g,
                                                        name=output.name,
                                                        shape=output.shape,
                                                        dtype=output.dtype)
                    driver.create_transpose_op(graph=g,
                                               input=output,
                                               axes=list(upper_perm),
                                               output=graph_output)

                    graph_utils.replace_tensor_in_outputs(
                        g, output, graph_output)
                    output.name = None
                    output.shape = utils.apply_permutation(
                        output.shape, lower_perm)
                else:
                    output.shape = utils.apply_permutation(
                        output.shape, lower_perm)
                    add_transform(transforms_by_name, output,
                                  Transpose(lower_perm))

        for _, tensor in lower_boundary:
            if tensor.producer is not None and tensor.producer.name == driver.transpose_op_name:
                # The lower transpose is removed. Presumably the pushed-down
                # permutation cancels it, so a graph output is replaced by a
                # same-named copy of the transpose's input.
                if tensor in g.outputs:
                    graph_output = driver.create_tensor(
                        graph=g,
                        name=tensor.name,
                        shape=tensor.producer.input.shape,
                        dtype=tensor.producer.input.dtype)

                    driver.create_copy_op(graph=g,
                                          input=tensor.producer.input,
                                          output=graph_output)

                    graph_utils.replace_tensor_in_outputs(
                        g, tensor, graph_output)
                remove_passthrough_ex(g, tensor.producer)
            elif tensor.producer is not None and tensor.producer.name == driver.squeeze_op_name:
                # Squeeze below the subgraph: remap its axes through
                # lower_perm and keep them sorted.
                driver.set_axes_on_squeeze(
                    tensor.producer,
                    sorted(
                        Transposer.apply_permutation_to_axes(
                            driver.get_axes_from_squeeze(tensor.producer),
                            lower_perm)))
            else:
                # Lower-boundary tensors are expected to come from a transpose
                # or a squeeze only.
                assert False

    # Prune ops and tensors disconnected by the removals above.
    graph_utils.remove_unreachable(g)
Ejemplo n.º 4
0
def transform_io(g, io_transform, transforms_by_name, driver):
    # type:(BaseGraph, TrafoOrTrafoDictType, typing.Dict[str, typing.List[Transform]], DataFormatOptimizationDriver)->None
    """Apply layout transforms to the inputs and outputs of ``g`` in place.

    ``io_transform`` takes one of two forms:

    * A single Transform, used for every io tensor.
    * A dict mapping each io tensor (or its name) to a Transform. The dict
      must cover every io tensor.

    Each io tensor's transform is then handled as follows:

    * Smart transforms are resolved per tensor. They either skip tensors of
      rank <= 2 or try the filter-gradient rewrite first.
    * Filter-gradient transforms rewrite the producing subgraph.
    * Any remaining non-identity Transpose becomes an explicit transpose op.
      For an input, a new input with the permuted shape replaces the old one
      and is transposed back for the existing consumers. For an output, the
      tensor is transposed and the result becomes the new, same-named output.

    The applied transform is recorded in ``transforms_by_name`` for every new
    io tensor.
    """
    skip = object()  # sentinel: nothing more to do for this tensor

    def _channels_first(rank):
        # Move the last axis to position 1, e.g. rank 4 -> [0, 3, 1, 2].
        return Transpose([0, rank - 1] + list(range(1, rank - 1)))

    def _channels_last(rank):
        # Move axis 1 to the end, e.g. rank 4 -> [0, 2, 3, 1].
        return Transpose([0] + list(range(2, rank)) + [1])

    def _resolve(tensor, transform):
        # Return the Transform still to apply to `tensor`, or `skip`. Checks
        # run in a fixed order; a smart transform rewritten into a Transpose
        # simply falls through the remaining checks.
        if isinstance(transform, _Identity):
            return skip
        if isinstance(transform, _SmartTFNCHWToNCHW):
            try:
                _transform_tf_filter_grad_to_nnef(g, tensor,
                                                  transforms_by_name, driver)
            except _TransformException:
                pass  # best effort: leave the tensor unchanged
            return skip
        if isinstance(transform, _SmartNHWCToNCHW):
            if tensor.rank <= 2:
                return skip
            transform = _channels_first(tensor.rank)
        if isinstance(transform, _SmartTFNHWCToNCHW):
            try:
                _transform_tf_filter_grad_to_nnef(g, tensor,
                                                  transforms_by_name, driver)
                return skip
            except _TransformException:
                if tensor.rank <= 2:
                    return skip
                transform = _channels_first(tensor.rank)
        if isinstance(transform, _SmartNCHWToTFNCHW):
            try:
                _transform_nnef_filter_grad_to_tf(g, tensor,
                                                  transforms_by_name, driver)
            except _TransformException:
                pass  # best effort: leave the tensor unchanged
            return skip
        if isinstance(transform, _SmartNCHWToNHWC):
            if tensor.rank <= 2:
                return skip
            transform = _channels_last(tensor.rank)
        if isinstance(transform, _SmartNCHWToTFNHWC):
            try:
                _transform_nnef_filter_grad_to_tf(g, tensor,
                                                  transforms_by_name, driver)
                return skip
            except _TransformException:
                if tensor.rank <= 2:
                    return skip
                transform = _channels_last(tensor.rank)
        if isinstance(transform, _TFFilterGradToNNEF):
            _transform_tf_filter_grad_to_nnef(g, tensor, transforms_by_name,
                                              driver)
            return skip
        if isinstance(transform, _NNEFFilterGradToTF):
            _transform_nnef_filter_grad_to_tf(g, tensor, transforms_by_name,
                                              driver)
            return skip
        return transform

    def _permute_input(tensor, transform):
        # Swap in a permuted input and transpose it back for existing consumers.
        assert tensor.name
        new_input_tensor = driver.create_tensor(
            graph=g,
            name=tensor.name,
            shape=utils.apply_permutation(tensor.shape, transform.axes),
            dtype=tensor.dtype)
        add_transform(transforms_by_name, new_input_tensor, transform)

        back_axes = utils.inverse_permutation(transform.axes)
        restored = driver.create_tensor(graph=g,
                                        name=None,
                                        shape=tensor.shape,
                                        dtype=tensor.dtype)
        transpose = driver.create_transpose_op(graph=g,
                                               input=new_input_tensor,
                                               axes=back_axes,
                                               output=restored)

        graph_utils.replace_tensor_in_inputs(g, tensor, new_input_tensor)
        graph_utils.replace_tensor_in_consumers(g,
                                                tensor,
                                                transpose.output,
                                                remove=True)

    def _permute_output(tensor, transform):
        # Append a transpose whose result takes over the output's name and slot.
        permuted = driver.create_tensor(
            graph=g,
            name=tensor.name,
            shape=utils.apply_permutation(tensor.shape, transform.axes),
            dtype=tensor.dtype)
        transpose = driver.create_transpose_op(graph=g,
                                               input=tensor,
                                               axes=transform.axes,
                                               output=permuted)
        add_transform(transforms_by_name, transpose.output, transform)
        tensor.name = None

        graph_utils.replace_tensor_in_outputs(g, tensor, transpose.output)

    io_tensors_by_name = {t.name: t for t in list(g.inputs) + list(g.outputs)}

    if isinstance(io_transform, dict):
        transform_by_io_tensor = {}
        for key, value in six.iteritems(io_transform):
            assert isinstance(key, (str, driver.tensor_type)), \
                "io_transform: Key type must be {} or str".format(driver.tensor_type.__name__)
            assert isinstance(
                value, Transform), "io_transform: Value type must be Transform"

            if not isinstance(key, BaseTensor):
                # Name key: translate to the io tensor carrying that name.
                assert key in io_tensors_by_name
                key = io_tensors_by_name[key]
            else:
                assert key in six.itervalues(io_tensors_by_name)
            transform_by_io_tensor[key] = value
        for io_tensor in six.itervalues(io_tensors_by_name):
            assert io_tensor in transform_by_io_tensor, \
                "io_transform: Please specify transform for all io tensors. " \
                "You can use graph_optimizer.IDENTITY if no change is required."
    else:
        assert isinstance(io_transform, Transform), \
            "io_transform must be Transform or Dict[str, Transform] or Dict[NNEFTensor, Transform]"
        transform_by_io_tensor = {
            io_tensor: io_transform
            for io_tensor in six.itervalues(io_tensors_by_name)
        }

    for tensor, requested in six.iteritems(transform_by_io_tensor):
        assert bool(tensor in g.inputs) != bool(tensor in g.outputs), \
            "Tensor must be input or output (and not both)"

        assert isinstance(requested, (Transpose, _CustomTransform)), \
            "Unsupported io_transform"

        transform = _resolve(tensor, requested)
        if transform is skip:
            continue

        assert isinstance(
            transform,
            Transpose), "Unsupported io_transform: {}".format(transform)
        assert len(transform.axes) == tensor.rank, "Transpose: invalid rank"

        if transform.is_identity():
            continue

        if tensor in g.inputs:
            _permute_input(tensor, transform)
        else:  # output
            _permute_output(tensor, transform)
Ejemplo n.º 5
0
def _transform_nnef_filter_grad_to_tf(g, tensor, transforms_by_name, driver):
    # type: (BaseGraph, BaseTensor, typing.Dict[str, typing.List[Transform]], DataFormatOptimizationDriver)->None
    """Expose a conv-grad-filter result directly as io tensor ``tensor``.

    Matches one of two producer chains ending in ``tensor``:

    * ``conv_grad_filter -> transpose -> tensor``
    * ``conv_grad_filter -> reshape -> transpose -> tensor``

    On a match:

    * The conv-grad-filter output takes over ``tensor``'s name and its
      graph-output slot.
    * The trailing transpose (and reshape) are removed.
    * The transforms leading from the old output to the new one are recorded
      in ``transforms_by_name``: the inverse transpose, followed by a reshape
      to the conv-grad-filter shape for the second chain.

    Raises:
        _TransformException: If ``tensor`` has no producer or neither chain
            matches. Also raised if a removed op's output has more than one
            consumer, or if the conv-grad-filter output is already a graph
            output.
    """
    # The message names the io transform this helper implements
    # (NNEF_FILTER_GRAD_TO_TF, also used by the SMART_NCHW_TO_TF_* variants).
    # It previously said TF_FILTER_GRAD_TO_NNEF.
    error_message = "Cannot apply NNEF_FILTER_GRAD_TO_TF"

    assert driver.conv_grad_filter_op_names

    # Chain 1: conv_grad_filter -> transpose
    cgf1_output = matcher.Tensor()
    cgf1 = matcher.Operation(name=driver.conv_grad_filter_op_names,
                             outputs=cgf1_output)
    transpose1 = matcher.Operation(name=driver.transpose_op_name,
                                   inputs=cgf1_output)

    # Chain 2: conv_grad_filter -> reshape -> transpose
    cgf2_output = matcher.Tensor()
    cgf2 = matcher.Operation(name=driver.conv_grad_filter_op_names,
                             outputs=cgf2_output)
    reshape2_output = matcher.Tensor()
    reshape2 = matcher.Operation(name=driver.reshape_op_name,
                                 inputs=cgf2_output,
                                 outputs=reshape2_output)
    transpose2 = matcher.Operation(name=driver.transpose_op_name,
                                   inputs=reshape2_output)

    if tensor.producer is None:
        raise _TransformException(error_message)

    m = matcher.match(g, tensor.producer,
                      matcher.OrPattern(transpose1, transpose2))

    if transpose1 in m:
        cgf = m[cgf1]  # type: BaseOperation
        transpose = m[transpose1]  # type: BaseOperation
        if not (len(transpose.output.consumers) <= 1
                and cgf.output not in g.outputs):
            raise _TransformException(error_message)
        # The conv-grad-filter result replaces the transposed output; record
        # how to go from the old output back to it.
        cgf.output.name = transpose.output.name
        add_transform(
            transforms_by_name, cgf.output,
            Transpose(
                utils.inverse_permutation(
                    driver.get_axes_from_transpose(transpose))))
        graph_utils.replace_tensor_in_outputs(g, transpose.output, cgf.output)
        graph_utils.remove_subgraph(g, [transpose])
    elif transpose2 in m:
        cgf = m[cgf2]  # type: BaseOperation
        reshape = m[reshape2]  # type: BaseOperation
        transpose = m[transpose2]  # type: BaseOperation

        if not (len(reshape.output.consumers) <= 1
                and len(transpose.output.consumers) <= 1
                and cgf.output not in g.outputs):
            raise _TransformException(error_message)

        # Old output = transpose(reshape(cgf)), so undo in reverse order:
        # the inverse transpose first, then reshape back to the cgf shape.
        cgf.output.name = transpose.output.name
        add_transform(
            transforms_by_name, cgf.output,
            Transpose(
                utils.inverse_permutation(
                    driver.get_axes_from_transpose(transpose))))
        add_transform(transforms_by_name, cgf.output,
                      Reshape(cgf.output.shape))

        graph_utils.replace_tensor_in_outputs(g, transpose.output, cgf.output)
        graph_utils.remove_subgraph(g, [transpose, reshape])
    else:
        raise _TransformException(error_message)