# Example 1
def reshape_shape(op):
    # type: (Caffe2Operation)->ShapeResult
    """Infer the outputs of a Caffe2 Reshape op and fold its shape into an attrib."""

    input_count = len(op.inputs)
    if input_count == 1:
        target_shape = op.attribs['shape']
    elif input_count == 2:
        shape_tensor = op.inputs[1]
        if shape_tensor.data is None:
            raise utils.NNEFToolsException(
                'Reshape is not supported with calculated shape.')
        target_shape = shape_tensor.data.tolist()
    else:
        assert False

    # The second output (the input's old shape) is replaced in all consumers
    # by an equivalent int64 constant tensor; the op itself keeps the output.
    old_shape_constant = Caffe2Tensor(graph=op.graph,
                                      shape=[op.inputs[0].rank],
                                      data=np.array(op.inputs[0].shape,
                                                    dtype=np.int64),
                                      dtype=DTYPE_INT64)
    graph_utils.replace_tensor_in_consumers(op.graph,
                                            op.outputs[1],
                                            old_shape_constant,
                                            remove=False)

    op.attribs['shape'] = target_shape
    op.inputs = (op.inputs[0], )

    new_shape = infer.reshape(op.inputs[0].shape,
                              shape=target_shape,
                              zero_means_same=True)
    return (new_shape, [op.inputs[0].rank]), (op.inputs[0].dtype, DTYPE_INT64)
# Example 2
def remove_passthrough_ex(g, op):
    # type: (BaseGraph, BaseOperation)->None
    """Remove a pass-through op, rewiring its consumers to read its input directly."""
    source = op.inputs[0]
    sink = op.outputs[0]

    # Unlink the op first, then redirect every consumer of its now-dangling
    # output tensor to the op's input; the orphaned output is removed.
    g.remove_operation(op, unlink=True)
    graph_utils.replace_tensor_in_consumers(g, sink, source, remove=True)
# Example 3
def shape_shape(op):
    # type: (Caffe2Operation)->ShapeResult
    """Infer the output of a Shape op and constant-fold it for all consumers."""
    # Consumers can use the statically-known shape directly, so hand them an
    # int64 constant tensor instead of the op's output (op itself is kept).
    shape_constant = Caffe2Tensor(graph=op.graph,
                                  shape=[op.input.rank],
                                  data=np.array(op.input.shape,
                                                dtype=np.int64),
                                  dtype=DTYPE_INT64)
    graph_utils.replace_tensor_in_consumers(op.graph,
                                            op.output,
                                            shape_constant,
                                            remove=False)
    # The output is a rank-1 int64 tensor holding the input's shape.
    return [op.input.rank], DTYPE_INT64
# Example 4
def size_shape(op):
    # type: (Caffe2Operation)->ShapeResult
    """Infer the output of a Size op and constant-fold it for all consumers."""

    # The element count is statically known, so consumers get a 0-d int64
    # constant in place of the op's output (the op itself is kept).
    count_constant = Caffe2Tensor(graph=op.graph,
                                  shape=[],
                                  data=np.array(op.inputs[0].count,
                                                dtype=np.int64),
                                  dtype=DTYPE_INT64)
    graph_utils.replace_tensor_in_consumers(op.graph,
                                            op.outputs[0],
                                            count_constant,
                                            remove=False)

    return one_element_0d_shape(op)
# Example 5
def concat_shape(op):
    # type: (Caffe2Operation)->ShapeResult
    """Infer the outputs of a Caffe2 Concat op (a stack when add_axis is set)."""
    axis = op.attribs['axis']
    input_shapes = [tensor.shape for tensor in op.inputs]

    if op.attribs['add_axis']:
        output_shape = infer.stack(input_shapes, axis=axis)
    else:
        output_shape = infer.concat(input_shapes, axis=axis)

    # The second output (per-input split sizes along the concat axis) is
    # statically known; replace it in all consumers with an int32 constant.
    split_sizes = Caffe2Tensor(graph=op.graph,
                               shape=[len(op.inputs)],
                               data=np.array([shape[axis]
                                              for shape in input_shapes],
                                             dtype=np.int32),
                               dtype=DTYPE_INT32)
    graph_utils.replace_tensor_in_consumers(op.graph,
                                            op.outputs[1],
                                            split_sizes,
                                            remove=False)

    return (output_shape, [len(op.inputs)]), (op.inputs[0].dtype, DTYPE_INT32)
# Example 6
def transform_io(g, io_transform, transforms_by_name, driver):
    # type:(BaseGraph, TrafoOrTrafoDictType, typing.Dict[str, typing.List[Transform]], DataFormatOptimizationDriver)->None
    """Apply a layout transform to each input/output tensor of the graph.

    ``io_transform`` is either one Transform applied to every io tensor, or a
    dict (keyed by tensor or tensor name) that must cover all io tensors.
    "Smart" transform markers are resolved here to concrete Transposes or to
    the filter-gradient conversions; applied transforms are recorded in
    ``transforms_by_name`` via ``add_transform``.
    """
    io_tensors_by_name = {t.name: t for t in list(g.inputs) + list(g.outputs)}

    # Normalize io_transform into a tensor -> Transform mapping.
    transform_by_io_tensor = {}
    if isinstance(io_transform, dict):
        for k, v in six.iteritems(io_transform):
            assert isinstance(k, (str, driver.tensor_type)), \
                "io_transform: Key type must be {} or str".format(driver.tensor_type.__name__)
            assert isinstance(
                v, Transform), "io_transform: Value type must be Transform"

            # Keys may be given by name; resolve them to tensor objects.
            if isinstance(k, BaseTensor):
                assert k in six.itervalues(io_tensors_by_name)
            else:
                assert k in io_tensors_by_name
                k = io_tensors_by_name[k]
            transform_by_io_tensor[k] = v
        for io_tensor in six.itervalues(io_tensors_by_name):
            assert io_tensor in transform_by_io_tensor, \
                "io_transform: Please specify transform for all io tensors. " \
                "You can use graph_optimizer.IDENTITY if no change is required."
    else:
        # A single Transform applies to every io tensor.
        assert isinstance(io_transform, Transform), \
            "io_transform must be Transform or Dict[str, Transform] or Dict[NNEFTensor, Transform]"
        for t in six.itervalues(io_tensors_by_name):
            transform_by_io_tensor[t] = io_transform

    for tensor, transform in six.iteritems(transform_by_io_tensor):
        assert bool(tensor in g.inputs) != bool(tensor in g.outputs), \
            "Tensor must be input or output (and not both)"

        assert isinstance(transform, (Transpose, _CustomTransform)), \
            "Unsupported io_transform"

        if isinstance(transform, _Identity):
            continue

        # NOTE(review): the "_Smart*" markers below either resolve to a
        # filter-grad conversion (handled entirely by the helper) or fall
        # through with `transform` rebound to a concrete Transpose; the
        # order of these isinstance checks is significant.
        if isinstance(transform, _SmartTFNCHWToNCHW):
            # Best-effort: only filter gradients need conversion here;
            # failure is deliberately ignored.
            try:
                _transform_tf_filter_grad_to_nnef(g, tensor,
                                                  transforms_by_name, driver)
            except _TransformException:
                pass
            continue

        if isinstance(transform, _SmartNHWCToNCHW):
            # Rank <= 2 tensors have no spatial dims to permute.
            if tensor.rank <= 2:
                continue
            # NHWC -> NCHW permutation: [0, C, spatial...].
            transform = Transpose([0, tensor.rank - 1] +
                                  list(range(tensor.rank))[1:-1])

        if isinstance(transform, _SmartTFNHWCToNCHW):
            # Prefer the filter-grad conversion; fall back to a plain
            # NHWC -> NCHW transpose if the tensor is not a filter grad.
            try:
                _transform_tf_filter_grad_to_nnef(g, tensor,
                                                  transforms_by_name, driver)
                continue
            except _TransformException:
                if tensor.rank <= 2:
                    continue
                transform = Transpose([0, tensor.rank - 1] +
                                      list(range(tensor.rank))[1:-1])

        if isinstance(transform, _SmartNCHWToTFNCHW):
            # Best-effort reverse direction; failure is deliberately ignored.
            try:
                _transform_nnef_filter_grad_to_tf(g, tensor,
                                                  transforms_by_name, driver)
            except _TransformException:
                pass
            continue

        if isinstance(transform, _SmartNCHWToNHWC):
            if tensor.rank <= 2:
                continue
            # NCHW -> NHWC permutation: [0, spatial..., C].
            transform = Transpose([0] + list(range(tensor.rank))[2:] + [1])

        if isinstance(transform, _SmartNCHWToTFNHWC):
            try:
                _transform_nnef_filter_grad_to_tf(g, tensor,
                                                  transforms_by_name, driver)
                continue
            except _TransformException:
                if tensor.rank <= 2:
                    continue
                transform = Transpose([0] + list(range(tensor.rank))[2:] + [1])

        if isinstance(transform, _TFFilterGradToNNEF):
            # Explicit (non-"smart") request: failure propagates.
            _transform_tf_filter_grad_to_nnef(g, tensor, transforms_by_name,
                                              driver)
            continue

        if isinstance(transform, _NNEFFilterGradToTF):
            _transform_nnef_filter_grad_to_tf(g, tensor, transforms_by_name,
                                              driver)
            continue

        # By now every remaining transform must be a concrete Transpose.
        assert isinstance(
            transform,
            Transpose), "Unsupported io_transform: {}".format(transform)
        assert len(transform.axes) == tensor.rank, "Transpose: invalid rank"

        if transform.is_identity():
            continue

        if tensor in g.inputs:
            assert tensor.name

            # Create a new graph input with the permuted shape (taking over
            # the old tensor's name), then transpose it back so internal
            # consumers still see the original layout.
            new_input_tensor = driver.create_tensor(
                graph=g,
                name=tensor.name,
                shape=utils.apply_permutation(tensor.shape, transform.axes),
                dtype=tensor.dtype)
            add_transform(transforms_by_name, new_input_tensor, transform)

            transpose = driver.create_transpose_op(
                graph=g,
                input=new_input_tensor,
                axes=utils.inverse_permutation(transform.axes),
                output=driver.create_tensor(graph=g,
                                            name=None,
                                            shape=tensor.shape,
                                            dtype=tensor.dtype))

            graph_utils.replace_tensor_in_inputs(g, tensor, new_input_tensor)
            graph_utils.replace_tensor_in_consumers(g,
                                                    tensor,
                                                    transpose.output,
                                                    remove=True)
        else:  # output
            # Insert a transpose producing the permuted tensor, which takes
            # over the old tensor's name and becomes the graph output.
            transpose = driver.create_transpose_op(
                graph=g,
                input=tensor,
                axes=transform.axes,
                output=driver.create_tensor(graph=g,
                                            name=tensor.name,
                                            shape=utils.apply_permutation(
                                                tensor.shape, transform.axes),
                                            dtype=tensor.dtype))
            add_transform(transforms_by_name, transpose.output, transform)
            tensor.name = None

            graph_utils.replace_tensor_in_outputs(g, tensor, transpose.output)
# Example 7
def _replace_saved_moments(g, op, mean_replacement, var_replacement):
    # type: (TFGraph, TFOperation, TFTensor, TFTensor)->list
    # Gradient graphs add two extra outputs (saved mean / saved variance);
    # redirect their consumers to the given replacements and return the
    # now-unused tensors so the caller can remove them.
    if len(op.outputs) <= 3:
        return []
    out_saved_mean = op.outputs[3]
    out_saved_var = op.outputs[4]
    graph_utils.replace_tensor_in_consumers(g, out_saved_mean,
                                            mean_replacement)
    graph_utils.replace_tensor_in_consumers(g, out_saved_var, var_replacement)
    return [out_saved_mean, out_saved_var]


def transform_fused_batch_norm(g, op):
    # type: (TFGraph, TFOperation)->None
    """Replace fused_batch_norm with tf.nn.moments + tf.nn.batch_normalization.

    Training mode: batch statistics are computed with tf.nn.moments; since
    moments returns the biased variance, it is rescaled by N / max(N - 1, 1)
    to the unbiased estimate when VARIANCE_CORRECTION_ENABLED.
    Inference mode: the supplied running mean/variance inputs are used and
    the statistic outputs are forwarded to them.
    The original op is removed; orphaned output tensors are deleted.
    """
    VARIANCE_CORRECTION_ENABLED = True

    in_input = op.inputs[0]
    in_scale = op.inputs[1]
    in_offset = op.inputs[2]

    epsilon = op.attribs["epsilon"]

    out_y = op.outputs[0]
    out_batch_mean = op.outputs[1]
    out_batch_var = op.outputs[2]

    data_format = op.attribs["data_format"].upper(
    ) if op.attribs["data_format"] else "NHWC"
    channel_dim = 1 if data_format == "NCHW" else in_input.rank - 1
    # Number of elements reduced per channel, used for the Bessel correction.
    # BUGFIX: previously this divided by the channel *axis index*
    # (channel_dim) instead of the channel extent, which made the
    # N / (N - 1) unbiasing factor wrong.
    rest_count = int(in_input.count / in_input.shape[channel_dim])
    tensors_to_remove = []

    if op.attribs["is_training"]:
        # Reduce over every axis except the channel axis.
        reduce_axes = utils.without(range(in_input.rank), channel_dim)
        if VARIANCE_CORRECTION_ENABLED:
            # moments yields the biased variance; multiply by N/max(N-1,1)
            # to obtain the unbiased estimate that fused_batch_norm reports.
            biased_batch_var = TFTensor(graph=g,
                                        shape=list(out_batch_var.shape),
                                        dtype=out_batch_var.dtype)
            correction = TFTensor(graph=g,
                                  shape=[],
                                  dtype=in_input.dtype,
                                  data=float(rest_count) /
                                  max(rest_count - 1, 1))
            TFOperation(graph=g,
                        name="tf.nn.moments",
                        inputs=in_input,
                        attribs=dict(axes=reduce_axes, keep_dims=False),
                        outputs=(out_batch_mean, biased_batch_var))
            TFOperation(graph=g,
                        name="tf.multiply",
                        inputs=(biased_batch_var, correction),
                        outputs=out_batch_var)
        else:  # not VARIANCE_CORRECTION_ENABLED
            TFOperation(graph=g,
                        name="tf.nn.moments",
                        inputs=in_input,
                        attribs=dict(axes=reduce_axes, keep_dims=False),
                        outputs=(out_batch_mean, out_batch_var))
        TFOperation(graph=g,
                    name="tf.nn.batch_normalization",
                    inputs=(in_input, out_batch_mean, out_batch_var,
                            in_offset, in_scale),
                    attribs=dict(variance_epsilon=epsilon,
                                 _data_format=data_format),
                    outputs=out_y)
        tensors_to_remove += _replace_saved_moments(g, op, out_batch_mean,
                                                    out_batch_var)
    else:  # inference: use the provided running statistics
        in_mean = op.inputs[3]
        in_variance = op.inputs[4]
        graph_utils.replace_tensor_in_consumers(g, out_batch_mean, in_mean)
        graph_utils.replace_tensor_in_consumers(g, out_batch_var, in_variance)
        tensors_to_remove += [out_batch_mean, out_batch_var]
        tensors_to_remove += _replace_saved_moments(g, op, in_mean,
                                                    in_variance)
        TFOperation(graph=g,
                    name="tf.nn.batch_normalization",
                    inputs=(in_input, in_mean, in_variance, in_offset,
                            in_scale),
                    attribs=dict(variance_epsilon=epsilon,
                                 _data_format=data_format),
                    outputs=out_y)
    g.remove_operation(op, unlink=True)
    for t in tensors_to_remove:
        g.remove_tensor(t)