Example #1
def _is_squeeze_invariant_to_perm(squeeze_axes, perm):
    dummy_shape = list(range(len(perm)))
    perm_dummy_shape = utils.apply_permutation(dummy_shape, perm)
    perm_squeeze_axes = Transposer.apply_permutation_to_axes(
        squeeze_axes, perm)
    shape1 = squeezed_shape(dummy_shape,
                            squeeze_axes,
                            can_squeeze_not_one=True)
    shape2 = squeezed_shape(perm_dummy_shape,
                            perm_squeeze_axes,
                            can_squeeze_not_one=True)

    return shape1 == shape2
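A note on what this check computes: it squeezes a dummy shape twice, once directly and once after permuting, and compares the results. A minimal standalone sketch, assuming utils.apply_permutation reorders a list as [xs[i] for i in perm], Transposer.apply_permutation_to_axes maps each axis through perm.index, and squeezed_shape drops the listed axes (all three re-implemented here for illustration only, with can_squeeze_not_one ignored):

def apply_permutation(xs, perm):
    return [xs[i] for i in perm]                # assumed semantics of utils.apply_permutation

def apply_permutation_to_axes(axes, perm):
    return sorted(perm.index(a) for a in axes)  # assumed semantics of the Transposer helper

def squeezed_shape(shape, axes, can_squeeze_not_one=True):
    # Simplified: just drops the listed axes from the shape.
    return [d for i, d in enumerate(shape) if i not in axes]

dummy, perm, axes = [0, 1, 2], [1, 0, 2], [0, 1]   # perm only shuffles the squeezed axes
print(squeezed_shape(dummy, axes))                            # [2]
print(squeezed_shape(apply_permutation(dummy, perm),
                     apply_permutation_to_axes(axes, perm)))  # [2] -> invariant

The squeeze is invariant to the permutation exactly when, as here, the permutation only shuffles the squeezed axes among themselves and leaves the surviving axes in order.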
Example #2
def apply_transpose_to_varlike(tensor, axes, transforms_by_name):
    # type: (BaseTensor, typing.List[int], typing.Dict[str, typing.List[Transform]])->None

    if tensor.rank <= 1:
        return

    old_shape = tensor.shape
    tensor.shape = utils.apply_permutation(old_shape, axes)
    if tensor.is_variable and tensor.data.size > 0:
        tensor.data = np.transpose(tensor.data, axes)
    elif tensor.is_constant:
        if len(tensor.data) > 1:
            tensor.data = (np.array(tensor.data).reshape(old_shape).transpose(
                axes).flatten().tolist())

    add_transform(transforms_by_name, tensor, Transpose(axes))
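The constant branch transposes a flat value list by round-tripping it through NumPy. A quick standalone illustration of that one step, with hypothetical data:

import numpy as np

flat, old_shape, axes = [1, 2, 3, 4, 5, 6], [2, 3], [1, 0]   # transpose a 2x3 to a 3x2
print(np.array(flat).reshape(old_shape).transpose(axes).flatten().tolist())
# [1, 4, 2, 5, 3, 6]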
Example #3
def transform_remove_inverse_transposes(
        g,  # type: BaseGraph
        transforms_by_name,  # type:typing.Dict[str, typing.List[Transform]]
        merge_into_constants,  # type: bool
        merge_into_variables,  # type: bool
        driver,  # type: DataFormatOptimizationDriver
        transposable_ops=None,  # type: typing.Optional[typing.List[TransposableOperation]]
):
    # type: (...)-> None

    if transposable_ops is None:
        transposable_ops = []

    transposable_op_by_name = {
        top.name: top for top in transposable_ops
    }  # type: typing.Dict[str, TransposableOperation]

    # Pad partial transpose axes up to the output rank,
    # e.g. [1, 0] on a rank-4 output becomes [1, 0, 2, 3].
    for op in g.operations:
        if op.name == driver.transpose_op_name:
            axes = driver.get_axes_from_transpose(op)
            if op.output.rank > len(axes):
                driver.set_axes_on_transpose(
                    op, axes + list(range(len(axes), op.output.rank)))

    matches = _find_inverse_transposes(
        g,
        transposable_op_names=set(six.iterkeys(transposable_op_by_name)),
        merge_into_constants=merge_into_constants,
        merge_into_variables=merge_into_variables,
        driver=driver)

    for axes, subgraph in matches:
        upper_perm = (axes if subgraph.started_down
                      else utils.inverse_permutation(axes))
        lower_perm = utils.inverse_permutation(upper_perm)

        upper_boundary = [
            be for be in subgraph.boundary_elements if not be.from_up
        ]
        lower_boundary = [
            be for be in subgraph.boundary_elements if be.from_up
        ]

        for _, tensor in upper_boundary:
            if tensor.producer is not None and tensor.producer.name == driver.transpose_op_name:
                if tensor in g.outputs:
                    graph_output = driver.create_tensor(
                        graph=g,
                        name=tensor.name,
                        shape=utils.apply_permutation(
                            tensor.producer.input.shape, upper_perm),
                        dtype=tensor.producer.input.dtype)
                    driver.create_transpose_op(graph=g,
                                               input=tensor.producer.input,
                                               axes=list(upper_perm),
                                               output=graph_output)
                    graph_utils.replace_tensor_in_outputs(
                        g, tensor, graph_output)
                elif (len(tensor.producer.input.consumers) == 1
                      and tensor.producer.input not in g.inputs
                      and tensor.producer.input not in g.outputs):
                    tensor.producer.input.name = tensor.name
                    add_transform(transforms_by_name, tensor.producer.input,
                                  Transpose(lower_perm))
                remove_passthrough_ex(g, tensor.producer)
            else:
                assert (merge_into_variables and tensor.is_variable) \
                       or (merge_into_constants and tensor.is_constant)

                apply_transpose_to_varlike(tensor, lower_perm,
                                           transforms_by_name)

        skipped_ops = set(
            tensor.producer for tensor in
            subgraph.skipped_tensors)  # type: typing.Set[BaseOperation]
        for op in skipped_ops:
            assert op.name in transposable_op_by_name
            transposable_op_by_name[op.name].dg_transpose(
                _transposer, g, op, lower_perm)
            for output in op.outputs:
                if output in g.outputs:
                    graph_output = driver.create_tensor(graph=g,
                                                        name=output.name,
                                                        shape=output.shape,
                                                        dtype=output.dtype)
                    driver.create_transpose_op(graph=g,
                                               input=output,
                                               axes=list(upper_perm),
                                               output=graph_output)

                    graph_utils.replace_tensor_in_outputs(
                        g, output, graph_output)
                    output.name = None
                    output.shape = utils.apply_permutation(
                        output.shape, lower_perm)
                else:
                    output.shape = utils.apply_permutation(
                        output.shape, lower_perm)
                    add_transform(transforms_by_name, output,
                                  Transpose(lower_perm))

        for _, tensor in lower_boundary:
            if tensor.producer is not None and tensor.producer.name == driver.transpose_op_name:
                if tensor in g.outputs:
                    graph_output = driver.create_tensor(
                        graph=g,
                        name=tensor.name,
                        shape=tensor.producer.input.shape,
                        dtype=tensor.producer.input.dtype)

                    driver.create_copy_op(graph=g,
                                          input=tensor.producer.input,
                                          output=graph_output)

                    graph_utils.replace_tensor_in_outputs(
                        g, tensor, graph_output)
                remove_passthrough_ex(g, tensor.producer)
            elif tensor.producer is not None and tensor.producer.name == driver.squeeze_op_name:
                driver.set_axes_on_squeeze(
                    tensor.producer,
                    sorted(
                        Transposer.apply_permutation_to_axes(
                            driver.get_axes_from_squeeze(tensor.producer),
                            lower_perm)))
            else:
                assert False

    graph_utils.remove_unreachable(g)
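The whole pass rests on the identity that a transpose composed with its inverse permutation is a no-op, which is what lets the matched transpose pairs be removed. A quick NumPy check of that invariant, with inverse_permutation re-implemented here under its assumed semantics:

import numpy as np

def inverse_permutation(perm):
    # Assumed semantics of utils.inverse_permutation.
    inv = [0] * len(perm)
    for i, p in enumerate(perm):
        inv[p] = i
    return inv

x = np.arange(24).reshape(2, 3, 4)
perm = [2, 0, 1]
assert np.array_equal(x.transpose(perm).transpose(inverse_permutation(perm)), x)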
Example #4
def transform_io(g, io_transform, transforms_by_name, driver):
    # type:(BaseGraph, TrafoOrTrafoDictType, typing.Dict[str, typing.List[Transform]], DataFormatOptimizationDriver)->None
    io_tensors_by_name = {t.name: t for t in list(g.inputs) + list(g.outputs)}

    transform_by_io_tensor = {}
    if isinstance(io_transform, dict):
        for k, v in six.iteritems(io_transform):
            assert isinstance(k, (str, driver.tensor_type)), \
                "io_transform: Key type must be {} or str".format(driver.tensor_type.__name__)
            assert isinstance(
                v, Transform), "io_transform: Value type must be Transform"

            if isinstance(k, BaseTensor):
                assert k in six.itervalues(io_tensors_by_name)
            else:
                assert k in io_tensors_by_name
                k = io_tensors_by_name[k]
            transform_by_io_tensor[k] = v
        for io_tensor in six.itervalues(io_tensors_by_name):
            assert io_tensor in transform_by_io_tensor, \
                "io_transform: Please specify transform for all io tensors. " \
                "You can use graph_optimizer.IDENTITY if no change is required."
    else:
        assert isinstance(io_transform, Transform), \
            "io_transform must be Transform or Dict[str, Transform] or Dict[NNEFTensor, Transform]"
        for t in six.itervalues(io_tensors_by_name):
            transform_by_io_tensor[t] = io_transform

    for tensor, transform in six.iteritems(transform_by_io_tensor):
        assert bool(tensor in g.inputs) != bool(tensor in g.outputs), \
            "Tensor must be input or output (and not both)"

        assert isinstance(transform, (Transpose, _CustomTransform)), \
            "Unsupported io_transform"

        if isinstance(transform, _Identity):
            continue

        if isinstance(transform, _SmartTFNCHWToNCHW):
            try:
                _transform_tf_filter_grad_to_nnef(g, tensor,
                                                  transforms_by_name, driver)
            except _TransformException:
                pass
            continue

        if isinstance(transform, _SmartNHWCToNCHW):
            if tensor.rank <= 2:
                continue
            transform = Transpose([0, tensor.rank - 1] +
                                  list(range(tensor.rank))[1:-1])

        if isinstance(transform, _SmartTFNHWCToNCHW):
            try:
                _transform_tf_filter_grad_to_nnef(g, tensor,
                                                  transforms_by_name, driver)
                continue
            except _TransformException:
                if tensor.rank <= 2:
                    continue
                transform = Transpose([0, tensor.rank - 1] +
                                      list(range(tensor.rank))[1:-1])

        if isinstance(transform, _SmartNCHWToTFNCHW):
            try:
                _transform_nnef_filter_grad_to_tf(g, tensor,
                                                  transforms_by_name, driver)
            except _TransformException:
                pass
            continue

        if isinstance(transform, _SmartNCHWToNHWC):
            if tensor.rank <= 2:
                continue
            transform = Transpose([0] + list(range(tensor.rank))[2:] + [1])

        if isinstance(transform, _SmartNCHWToTFNHWC):
            try:
                _transform_nnef_filter_grad_to_tf(g, tensor,
                                                  transforms_by_name, driver)
                continue
            except _TransformException:
                if tensor.rank <= 2:
                    continue
                transform = Transpose([0] + list(range(tensor.rank))[2:] + [1])

        if isinstance(transform, _TFFilterGradToNNEF):
            _transform_tf_filter_grad_to_nnef(g, tensor, transforms_by_name,
                                              driver)
            continue

        if isinstance(transform, _NNEFFilterGradToTF):
            _transform_nnef_filter_grad_to_tf(g, tensor, transforms_by_name,
                                              driver)
            continue

        assert isinstance(
            transform,
            Transpose), "Unsupported io_transform: {}".format(transform)
        assert len(transform.axes) == tensor.rank, "Transpose: invalid rank"

        if transform.is_identity():
            continue

        if tensor in g.inputs:
            assert tensor.name

            new_input_tensor = driver.create_tensor(
                graph=g,
                name=tensor.name,
                shape=utils.apply_permutation(tensor.shape, transform.axes),
                dtype=tensor.dtype)
            add_transform(transforms_by_name, new_input_tensor, transform)

            transpose = driver.create_transpose_op(
                graph=g,
                input=new_input_tensor,
                axes=utils.inverse_permutation(transform.axes),
                output=driver.create_tensor(graph=g,
                                            name=None,
                                            shape=tensor.shape,
                                            dtype=tensor.dtype))

            graph_utils.replace_tensor_in_inputs(g, tensor, new_input_tensor)
            graph_utils.replace_tensor_in_consumers(g,
                                                    tensor,
                                                    transpose.output,
                                                    remove=True)
        else:  # output
            transpose = driver.create_transpose_op(
                graph=g,
                input=tensor,
                axes=transform.axes,
                output=driver.create_tensor(graph=g,
                                            name=tensor.name,
                                            shape=utils.apply_permutation(
                                                tensor.shape, transform.axes),
                                            dtype=tensor.dtype))
            add_transform(transforms_by_name, transpose.output, transform)
            tensor.name = None

            graph_utils.replace_tensor_in_outputs(g, tensor, transpose.output)
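For reference, the permutations built by the smart NHWC/NCHW cases above are inverses of each other; for a rank-4 tensor they are the familiar [0, 3, 1, 2] and [0, 2, 3, 1]:

rank = 4
to_nchw = [0, rank - 1] + list(range(rank))[1:-1]   # [0, 3, 1, 2]
to_nhwc = [0] + list(range(rank))[2:] + [1]         # [0, 2, 3, 1]
assert [to_nchw[i] for i in to_nhwc] == list(range(rank))  # mutual inverses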
Example #5
def apply_permutation(list, perm):
    # Thin wrapper that delegates to the shared utility.
    return utils.apply_permutation(list, perm)
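utils.apply_permutation itself is not shown on this page, but every call site above is consistent with it placing items[perm[i]] at position i. A minimal sketch under that assumption:

def apply_permutation(items, perm):
    assert sorted(perm) == list(range(len(items)))   # perm must be a permutation
    return [items[i] for i in perm]                  # result[i] = items[perm[i]]

print(apply_permutation(['N', 'H', 'W', 'C'], [0, 3, 1, 2]))  # ['N', 'C', 'H', 'W']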
Example #6
def _box_or_max_pool(
        input,  # type: torch.Tensor
        size,  # type: List[int]
        border='constant',  # type: str
        padding=None,  # type: Optional[List[Tuple[int, int]]]
        stride=None,  # type: Optional[List[int]]
        dilation=None,  # type: Optional[List[int]]
        normalize=False,  # type: bool
        is_max_pool=False,  # type: bool
):
    assert not (normalize and is_max_pool)

    rank = len(input.shape)
    padding, stride, dilation = _evaluate_max_pool_or_box_params(
        input_shape=list(input.shape),
        size=size,
        padding=padding,
        stride=stride,
        dilation=dilation)
    active = [
        size_ != 1 or padding_ != (0, 0) or stride_ != 1 or dilation_ != 1
        for size_, padding_, stride_, dilation_
        in zip(size, padding, stride, dilation)
    ]

    if sum(active) == 0:
        return input

    if rank < 3:
        perm, perm_inv, inactive_shape, active_shape = None, None, None, None
    else:
        perm, perm_inv, inactive_shape, active_shape = _get_transform_for_box_or_max_pool(
            list(input.shape), active)

    if rank < 3:
        input = input.unsqueeze(0).unsqueeze(0)
        size = [1, 1] + size
        padding = [(0, 0), (0, 0)] + padding
        stride = [1, 1] + stride
        dilation = [1, 1] + dilation
    elif perm is not None:
        input = input.permute(*perm)
        size = utils.apply_permutation(size, perm)
        padding = utils.apply_permutation(padding, perm)
        stride = utils.apply_permutation(stride, perm)
        dilation = utils.apply_permutation(dilation, perm)

        active_rank = len(active_shape)
        input = input.reshape(*[utils.product(inactive_shape), 1] +
                              active_shape)
        size = [1, 1] + size[-active_rank:]
        padding = [(0, 0), (0, 0)] + padding[-active_rank:]
        stride = [1, 1] + stride[-active_rank:]
        dilation = [1, 1] + dilation[-active_rank:]

    if is_max_pool:
        output = _max_pool_impl(input=input,
                                size=size,
                                border=border,
                                padding=padding,
                                stride=stride,
                                dilation=dilation,
                                with_index=False)
    else:
        output = _box_impl(input=input,
                           size=size,
                           border=border,
                           padding=padding,
                           stride=stride,
                           dilation=dilation,
                           normalize=normalize)

    if rank < 3:
        output = output.squeeze(0).squeeze(0)
    elif perm is not None:
        active_rank = len(active_shape)
        output = output.reshape(inactive_shape +
                                list(output.shape)[-active_rank:])
        output = output.permute(*perm_inv)

    return output
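The permute-then-reshape trick above folds all pooling-inactive dimensions into a single leading batch axis, so the pooling kernel only ever sees a [batch, 1, *active_dims] tensor. A small PyTorch sketch of that grouping with hypothetical shapes (this is not the library's _get_transform_for_box_or_max_pool, just the underlying idea):

import math
import torch

x = torch.arange(2 * 3 * 4 * 5, dtype=torch.float32).reshape(2, 3, 4, 5)
active = [False, True, False, True]      # pool only over dims 1 and 3

# Move inactive dims to the front and active dims to the back, then fold
# the inactive dims into one batch axis with a dummy channel axis.
perm = ([i for i, a in enumerate(active) if not a] +
        [i for i, a in enumerate(active) if a])
split = len(perm) - sum(active)
inactive_shape = [x.shape[i] for i in perm[:split]]
active_shape = [x.shape[i] for i in perm[split:]]
y = x.permute(*perm).reshape(math.prod(inactive_shape), 1, *active_shape)
print(y.shape)                           # torch.Size([8, 1, 3, 5])

Pooling then runs on y, after which the epilogue in the example reverses the reshape and applies perm_inv to restore the original layout.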