Ejemplo n.º 1
0
 def add_unsqueeze(onnx_graph, onnx_tensor, axes, unsqueezed_tensor=None):
     # type: (ONNXGraph, ONNXTensor, typing.List[int], typing.Optional[ONNXTensor])->ONNXTensor
     """Add an ONNX ``Unsqueeze`` operation to ``onnx_graph`` and return its output.

     If ``unsqueezed_tensor`` is not given, a new output tensor is created in
     ``onnx_graph`` with the dtype of ``onnx_tensor`` and the shape returned by
     ``transforms.unsqueezed_shape(onnx_tensor.shape, axes)``. The ``axes`` list
     is stored as the operation's ``axes`` attribute.
     """
     target = unsqueezed_tensor
     if target is None:
         new_shape = transforms.unsqueezed_shape(onnx_tensor.shape, axes)
         target = ONNXTensor(graph=onnx_graph, shape=new_shape, dtype=onnx_tensor.dtype)

     unsqueeze_op = ONNXOperation(graph=onnx_graph,
                                  name="Unsqueeze",
                                  inputs=onnx_tensor,
                                  attribs={"axes": axes},
                                  outputs=target)
     return unsqueeze_op.output
Ejemplo n.º 2
0
def generic_convert_argminmax(op, target_name):
    # type: (TFOperation, str)->None
    """Rewrite an arg-min/arg-max ``op`` in place into ``target_name`` + SQUEEZE.

    The ``axis`` attribute is moved into a new scalar INT32 constant input, and
    the attributes are replaced by ``output_type`` set to TFLite's INT64 type
    code. ``op`` then writes into a new intermediate tensor whose shape is the
    original output shape unsqueezed at ``axis``. A new SQUEEZE op over that
    axis feeds the original output tensor, so downstream consumers are
    unaffected.
    """
    int64_type_code = 4  # kTfLiteInt64
    graph = op.graph
    axis = op.attribs["axis"]

    op.name = target_name
    data_input = op.input
    axis_input = TFTensor(graph=graph, shape=[], dtype='INT32', data=[axis])
    op.inputs = (data_input, axis_input)
    op.attribs = dict(output_type=int64_type_code)

    final_output = op.output
    intermediate_shape = unsqueezed_shape(shape=final_output.shape, axes=[axis])
    intermediate = TFTensor(graph=graph, shape=intermediate_shape, dtype=final_output.dtype)
    op.outputs = (intermediate,)

    # Drop the re-inserted axis so the original output tensor keeps its shape.
    TFOperation(graph=graph,
                name="SQUEEZE",
                inputs=op.output,
                outputs=final_output,
                attribs=dict(squeeze_dims=[axis]))
Ejemplo n.º 3
0
def merge_transforms_into_varlikes(g, transforms_by_name, merge_into_constants,
                                   merge_into_variables, driver):
    # type: (BaseGraph, typing.Dict[str, typing.List[Transform]], bool, bool, DataFormatOptimizationDriver)->None
    """Fold squeeze/unsqueeze/reshape/transpose/copy ops into the tensors they read.

    Eligible tensors are:

    * variables, when ``merge_into_variables`` is set;
    * constants, when ``merge_into_constants`` is set;
    * in either case, only tensors that are not graph outputs.

    A transform is folded into an eligible tensor only while every consumer of
    the tensor:

    * is the same transform op (same op name and the same parameter, as
      reported by ``driver``);
    * reads the tensor as its first input;
    * does not produce a graph output.

    When folded, the transform is applied to the tensor itself and all
    consumers are removed with ``remove_passthrough_ex``:

    * squeeze / unsqueeze / reshape update ``tensor.shape`` and update the data
      of non-empty variables (constant data is left untouched). The step is
      recorded with ``add_transform``.
    * transpose is delegated to ``apply_transpose_to_varlike``.
    * copy only removes the consumers.

    Args:
        g: graph whose ``tensors`` are scanned; its ``outputs`` are never folded.
        transforms_by_name: passed to ``add_transform`` /
            ``apply_transpose_to_varlike`` to record the applied transforms.
        merge_into_constants: whether constant tensors may be rewritten.
        merge_into_variables: whether variable tensors may be rewritten.
        driver: supplies the transform op names and their parameter getters.
    """
    transform_ops = [
        driver.squeeze_op_name, driver.unsqueeze_op_name,
        driver.reshape_op_name, driver.transpose_op_name, driver.copy_op_name
    ]

    def get_param(op):
        # The parameter must come from ``op`` itself; comparing the params of
        # all consumers is what makes removing them as one transform safe.
        if op.name == driver.squeeze_op_name:
            return driver.get_axes_from_squeeze(op)
        elif op.name == driver.unsqueeze_op_name:
            return driver.get_axes_from_unsqueeze(op)
        elif op.name == driver.transpose_op_name:
            return driver.get_axes_from_transpose(op)
        elif op.name == driver.reshape_op_name:
            return driver.get_shape_from_reshape(op)
        elif op.name == driver.copy_op_name:
            return None
        raise AssertionError("Unexpected transform op: {}".format(op.name))

    def is_foldable_consumer(tensor, op):
        # ``tensor`` must be the data input (e.g. ONNX ops can take it as a
        # secondary input such as a shape), and removing ``op`` must not drop
        # a graph output.
        return op.inputs[0] is tensor and op.output not in g.outputs

    def can_fold_first_consumer(tensor):
        # True when ``tensor`` may be rewritten and its first consumer is a
        # transform that is a candidate for folding into it.
        if not ((merge_into_variables and tensor.is_variable) or
                (merge_into_constants and tensor.is_constant)):
            return False
        if not tensor.consumers or tensor in g.outputs:
            return False
        first = tensor.consumers[0]
        return first.name in transform_ops and is_foldable_consumer(tensor, first)

    for tensor in list(g.tensors):
        while can_fold_first_consumer(tensor):
            op_name = tensor.consumers[0].name
            op_param = get_param(tensor.consumers[0])

            # All consumers are removed below, so all of them must apply the
            # very same transform to ``tensor``.
            if not all(op.name == op_name
                       and is_foldable_consumer(tensor, op)
                       and get_param(op) == op_param
                       for op in tensor.consumers[1:]):
                break

            if op_name == driver.squeeze_op_name:
                axes = op_param
                tensor.shape = squeezed_shape(tensor.shape, axes)
                # Constant data is kept as it is; only non-empty variables are rewritten.
                if tensor.is_variable and tensor.data.size > 0:
                    tensor.data = np.squeeze(tensor.data, tuple(axes))
                add_transform(transforms_by_name, tensor, Squeeze(axes))
            elif op_name == driver.unsqueeze_op_name:
                axes = op_param
                tensor.shape = unsqueezed_shape(tensor.shape, axes)
                # Constant data is kept as it is; only non-empty variables are rewritten.
                if tensor.is_variable and tensor.data.size > 0:
                    tensor.data = np.reshape(tensor.data, tensor.shape)
                add_transform(transforms_by_name, tensor, Unsqueeze(axes))
            elif op_name == driver.reshape_op_name:
                tensor.shape = _reshaped_shape(tensor.shape, op_param)
                # Constant data is kept as it is; only non-empty variables are rewritten.
                if tensor.is_variable and tensor.data.size > 0:
                    tensor.data = np.reshape(tensor.data, tensor.shape)
                add_transform(transforms_by_name, tensor, Reshape(tensor.shape))
            elif op_name == driver.transpose_op_name:
                apply_transpose_to_varlike(tensor, op_param, transforms_by_name)
            elif op_name == driver.copy_op_name:
                pass  # a copy does not change the tensor
            else:
                raise AssertionError("Unexpected transform op: {}".format(op_name))

            for op in list(tensor.consumers):
                remove_passthrough_ex(g, op)