Example #1
0
def _is_squeeze_invariant_to_perm(squeeze_axes, perm):
    """Tell whether squeezing ``squeeze_axes`` commutes with ``perm``.

    A dummy shape whose entries are their own indices (``0 .. len(perm)-1``)
    is squeezed once directly and once after permuting it with ``perm``.
    In the permuted case, the squeeze axes are first mapped through
    ``Transposer.apply_permutation_to_axes``. Because every entry is a
    distinct label, the two results are equal only if both paths keep the
    same dimensions in the same order.

    :param squeeze_axes: axes that would be squeezed.
    :param perm: permutation applied to the shape.
    :return: True if both squeezed label sequences are equal.
    """
    labels = list(range(len(perm)))
    permuted_labels = utils.apply_permutation(labels, perm)
    permuted_axes = Transposer.apply_permutation_to_axes(squeeze_axes, perm)
    # can_squeeze_not_one=True: the dummy entries are labels, not real sizes,
    # so they must be squeezable even though they are not 1.
    return (squeezed_shape(labels, squeeze_axes, can_squeeze_not_one=True)
            == squeezed_shape(permuted_labels, permuted_axes,
                              can_squeeze_not_one=True))
 def add_squeeze(caffe2_graph, caffe2_tensor, dims):
     # type: (Caffe2Graph, Caffe2Tensor, typing.List[int])->Caffe2Tensor
     """Create a Caffe2 ``Squeeze`` operation on ``caffe2_tensor``.

     The output tensor is created in ``caffe2_graph``. It has the dtype of
     ``caffe2_tensor`` and the shape ``transforms.squeezed_shape`` computes
     for ``dims``. ``dims`` is also stored as the operation's ``dims``
     attribute.

     :return: the output tensor of the new operation.
     """
     squeezed = Caffe2Tensor(graph=caffe2_graph,
                             shape=transforms.squeezed_shape(caffe2_tensor.shape, dims),
                             dtype=caffe2_tensor.dtype)
     squeeze_op = Caffe2Operation(graph=caffe2_graph,
                                  name="Squeeze",
                                  inputs=caffe2_tensor,
                                  attribs={"dims": dims},
                                  outputs=squeezed)
     return squeeze_op.output
 def add_squeeze(onnx_graph, onnx_tensor, axes, squeezed_tensor=None):
     # type: (ONNXGraph, ONNXTensor, typing.List[int], typing.Optional[ONNXTensor])->ONNXTensor
     """Create an ONNX ``Squeeze`` operation on ``onnx_tensor`` in ``onnx_graph``.

     :param axes: axes to squeeze; stored as the operation's ``axes`` attribute.
     :param squeezed_tensor: optional existing tensor to use as the output.
         When omitted, a new tensor is created with ``onnx_tensor``'s dtype
         and the shape given by ``transforms.squeezed_shape``.
     :return: the output tensor of the new operation.
     """
     output_tensor = (squeezed_tensor if squeezed_tensor is not None
                      else ONNXTensor(graph=onnx_graph,
                                      shape=transforms.squeezed_shape(onnx_tensor.shape, axes),
                                      dtype=onnx_tensor.dtype))
     return ONNXOperation(graph=onnx_graph,
                          name="Squeeze",
                          inputs=onnx_tensor,
                          attribs={"axes": axes},
                          outputs=output_tensor).output
Example #4
0
def generic_convert_argminmax(op, target_name):
    # type: (TFOperation, str)->None
    """Rewrite an ARG_MIN/ARG_MAX operation in place as ``target_name``.

    The constant axis input (``op.inputs[1]``) is folded into the ``axis``
    attribute and removed from the inputs. The ``output_type`` code is
    translated with a fixed lookup table.

    The operation's result is redirected to a new tensor whose shape is the
    original output shape squeezed over ``axis``. A ``tf.expand_dims`` over
    the same axis is then added, writing into the original output tensor, so
    the tensor seen by downstream consumers is unchanged.

    :param op: operation to convert; modified in place.
    :param target_name: new name for the operation.
    :raises AssertionError: if the axis input is not constant, or holds more
        than one element.
    :raises KeyError: if ``output_type`` is not in the lookup table.
    """
    # NOTE(review): these look like TFLite TensorType codes (INT64=4, INT32=2)
    # mapped to TF DataType enum values (DT_INT64=9, DT_INT32=3) — confirm.
    tflite_to_tf_dtype = {4: 9, 2: 3}
    op.name = target_name
    assert op.inputs[
        1].data is not None, "ARG_MIN/ARG_MAX is only supported with constant axis (inputs[1])"
    axis = op.inputs[1].data.tolist()
    # For a 0-d array .tolist() yields a plain int; for a single-element 1-D
    # array it yields a one-item list. Normalise to an int so that
    # `axes=[axis]` below is a flat list of ints and the attribute is scalar.
    if isinstance(axis, list):
        assert len(axis) == 1, "ARG_MIN/ARG_MAX axis must contain exactly one element"
        axis = axis[0]
    op.attribs = dict(
        axis=axis, output_type=tflite_to_tf_dtype[op.attribs["output_type"]])
    op.inputs = (op.inputs[0], )

    # Redirect the op to a fresh, axis-squeezed tensor, then restore the
    # original output tensor with an expand_dims over the same axis.
    output_tensor = op.output
    op.outputs = (TFTensor(graph=op.graph,
                           shape=squeezed_shape(shape=output_tensor.shape,
                                                axes=[axis]),
                           dtype=output_tensor.dtype), )
    TFOperation(graph=op.graph,
                name="tf.expand_dims",
                inputs=op.output,
                outputs=output_tensor,
                attribs=dict(axis=axis))
Example #5
0
def merge_transforms_into_varlikes(g, transforms_by_name, merge_into_constants,
                                   merge_into_variables, driver):
    # type: (BaseGraph, typing.Dict[str, typing.List[Transform]], bool, bool, DataFormatOptimizationDriver)->None
    """Fold shape-only transform ops directly into variable/constant tensors.

    A tensor in ``g`` is considered if either holds:

    * it is a variable and ``merge_into_variables`` is set; or
    * it is a constant and ``merge_into_constants`` is set.

    For each such tensor, this repeatedly checks whether its consumers are
    all the same transform op (squeeze, unsqueeze, reshape, transpose or
    copy, as named by ``driver``) with the same parameter. If so:

    * the transform is applied to the tensor itself;
    * it is recorded in ``transforms_by_name``;
    * the consumer ops are removed with ``remove_passthrough_ex``.

    Merging for a tensor stops as soon as any of these fails:

    * its consumers differ from one another;
    * its first consumer is not a transform;
    * the tensor is not a consumer's first input;
    * the tensor or a consumer's output is a graph output.

    :param g: graph to modify in place.
    :param transforms_by_name: transform log, updated via ``add_transform`` /
        ``apply_transpose_to_varlike``.
    :param merge_into_constants: allow merging into constant tensors.
    :param merge_into_variables: allow merging into variable tensors.
    :param driver: supplies the transform op names and parameter getters.
    """
    transform_ops = [
        driver.squeeze_op_name, driver.unsqueeze_op_name,
        driver.reshape_op_name, driver.transpose_op_name, driver.copy_op_name
    ]

    def get_param(op):
        # Must read the parameter from `op` itself so that the "all consumers
        # agree" check below compares each consumer's own parameter.
        if op.name == driver.squeeze_op_name:
            return driver.get_axes_from_squeeze(op)
        elif op.name == driver.unsqueeze_op_name:
            return driver.get_axes_from_unsqueeze(op)
        elif op.name == driver.transpose_op_name:
            return driver.get_axes_from_transpose(op)
        elif op.name == driver.reshape_op_name:
            return driver.get_shape_from_reshape(op)
        elif op.name == driver.copy_op_name:
            return None
        else:
            assert False

    def is_mergeable_consumer(tensor, op):
        # The tensor must be the op's first input (needed for ONNX graphs),
        # and neither the tensor nor the op's output may be a graph output.
        return (tensor is op.inputs[0]
                and tensor not in g.outputs
                and op.output not in g.outputs)

    for tensor in list(g.tensors):
        while (((merge_into_variables and tensor.is_variable) or
                (merge_into_constants and tensor.is_constant))
               and len(tensor.consumers) >= 1
               and tensor.consumers[0].name in transform_ops
               and is_mergeable_consumer(tensor, tensor.consumers[0])):

            op_name = tensor.consumers[0].name
            op_param = get_param(tensor.consumers[0])

            # One rewrite of the tensor must serve every consumer, so all of
            # them have to be the same transform with the same parameter.
            if not all(op.name == op_name
                       and is_mergeable_consumer(tensor, op)
                       and get_param(op) == op_param
                       for op in tensor.consumers[1:]):
                break

            # Constant tensors keep their data as is; only non-empty variable
            # data is rewritten to match the new shape.
            if op_name == driver.squeeze_op_name:
                axes = op_param
                tensor.shape = squeezed_shape(tensor.shape, axes)
                if tensor.is_variable and tensor.data.size > 0:
                    tensor.data = np.squeeze(tensor.data, tuple(axes))
                add_transform(transforms_by_name, tensor, Squeeze(axes))
            elif op_name == driver.unsqueeze_op_name:
                axes = op_param
                tensor.shape = unsqueezed_shape(tensor.shape, axes)
                if tensor.is_variable and tensor.data.size > 0:
                    tensor.data = np.reshape(tensor.data, tensor.shape)
                add_transform(transforms_by_name, tensor, Unsqueeze(axes))
            elif op_name == driver.reshape_op_name:
                tensor.shape = _reshaped_shape(tensor.shape, op_param)
                if tensor.is_variable and tensor.data.size > 0:
                    tensor.data = np.reshape(tensor.data, tensor.shape)
                add_transform(transforms_by_name, tensor,
                              Reshape(tensor.shape))
            elif op_name == driver.transpose_op_name:
                apply_transpose_to_varlike(tensor, op_param,
                                           transforms_by_name)
            elif op_name == driver.copy_op_name:
                pass
            else:
                assert False

            # The consumers are now identity ops on the rewritten tensor.
            for op in list(tensor.consumers):
                remove_passthrough_ex(g, op)