Exemplo n.º 1
0
def compare_with_tensorflow(device_type, in_shape, axis, k, data_type, sorted):
    """Check the indices of OneFlow's ``flow.math.top_k`` against TensorFlow.

    A random array of shape ``in_shape`` and dtype ``data_type`` is run through
    a OneFlow job (mirrored logical view), and the returned indices must equal
    the indices TensorFlow computes for the same array.

    Args:
        device_type: "gpu" or "cpu"; placement of the OneFlow op.
        in_shape: shape of the random test input.
        axis: axis along which the top-k is taken.
        k: number of entries to select.
        data_type: one of "float32", "double", "int8", "int32", "int64".
        sorted: forwarded to ``flow.math.top_k`` and (in the top_k branch)
            to ``tf.math.top_k``.
    """
    assert device_type in ["gpu", "cpu"]
    assert data_type in ["float32", "double", "int8", "int32", "int64"]
    flow.clear_default_session()
    func_config = flow.FunctionConfig()
    func_config.default_logical_view(flow.scope.mirrored_view())
    func_config.default_data_type(flow.float)

    # Placeholder is declared 10 larger per dim than the real input.
    # NOTE(review): presumably a ListNumpy placeholder shape is an upper
    # bound on the fed shape — confirm.
    placeholder_shape = tuple(dim + 10 for dim in in_shape)
    flow_dtype = type_name_to_flow_type[data_type]

    @flow.global_function(function_config=func_config)
    def TopKJob(input: oft.ListNumpy.Placeholder(placeholder_shape, dtype=flow_dtype)):
        with flow.scope.placement(device_type, "0:0"):
            return flow.math.top_k(input, axis, k, sorted)

    np_dtype = type_name_to_np_type[data_type]
    np_input = (np.random.random(in_shape) * 100).astype(np_dtype)

    # OneFlow: the job takes and returns lists; use the single element.
    of_out = TopKJob([np_input]).get().numpy_list()[0]

    # TensorFlow reference.
    if k > in_shape[axis]:
        # k exceeds the axis length: compare against a full, stable,
        # descending argsort along `axis`.
        tf_out = tf.argsort(np_input, axis, direction="DESCENDING", stable=True)
    else:
        # tf.math.top_k selects along the last dimension, so move `axis`
        # there, take top-k, and move the indices back.
        perm = get_perm_when_transpose_axis_to_last_dim(len(in_shape), axis)
        _, top_indices = tf.math.top_k(tf.transpose(np_input, perm), k, sorted)
        tf_out = tf.transpose(top_indices, get_inversed_perm(perm))

    assert np.array_equal(of_out, tf_out.numpy())
Exemplo n.º 2
0
    def forward(self, input):
        """Return ``(values, indices)`` of the top-k entries of ``input`` along ``self.dim``.

        A ``self.dim`` of ``None`` means the last axis. When ``self.largest`` is
        falsy, the top-k op is run on the negated tensor so the smallest entries
        are selected. The values are always gathered from the original
        ``input`` using the selected indices, so they keep their original sign.

        Raises:
            AssertionError: if the resolved axis is out of range.
        """
        # Resolve the axis into a local instead of writing -1 back into
        # self.dim, so forward() does not mutate the module's configuration.
        dim = -1 if self.dim is None else self.dim
        num_axes = len(input.shape)
        axis = dim if dim >= 0 else dim + num_axes
        assert 0 <= axis < num_axes, "axis out of range"

        # _op_topk_last_dim is only applied along the last axis here, so move
        # `axis` there first when it is not already last.
        needs_transpose = axis != num_axes - 1
        if needs_transpose:
            perm = get_perm_when_transpose_axis_to_last_dim(num_axes, axis)
            x = flow.F.transpose(input, perm=perm)
        else:
            x = input

        if not self.largest:
            # Top-k of the negated tensor picks the k smallest entries.
            x = flow.experimental.mul(x, -1)
        indices = self._op_topk_last_dim(x)[0]

        if needs_transpose:
            # Restore the original axis order for the indices.
            indices = flow.F.transpose(indices, perm=get_inversed_perm(perm))
        return (flow.experimental.gather(input, indices, dim=axis), indices)
Exemplo n.º 3
0
def sort(
    input: remote_blob_util.BlobDef,
    axis: int = -1,
    direction: str = "ASCENDING",
    name: Optional[str] = None,
) -> remote_blob_util.BlobDef:
    """Sort the input Blob along the given axis.

    Args:
        input (remote_blob_util.BlobDef): A Blob.
        axis (int, optional): The dimension to sort along; negative values count
            from the end. Defaults to the last dim (-1).
        direction (str, optional): "ASCENDING" sorts values in ascending order;
            "DESCENDING" sorts them in descending order. Defaults to "ASCENDING".
        name (Optional[str], optional): The name for the operation. Defaults to
            None, in which case a unique "Sort_" name is generated.

    Returns:
        remote_blob_util.BlobDef: The sorted Blob.

    For example:

    .. code-block:: python

        import oneflow as flow
        import numpy as np
        import oneflow.typing as tp


        @flow.global_function()
        def sort_Job(x: tp.Numpy.Placeholder((5, ))
        ) -> tp.Numpy:
            return flow.sort(input=x)

        x = np.array([10, 2, 9, 3, 7]).astype("float32")
        out = sort_Job(x)

        # out [ 2.  3.  7.  9. 10.]

    """
    assert direction in ["ASCENDING", "DESCENDING"]
    if name is None:
        name = id_util.UniqueStr("Sort_")

    num_axes = len(input.shape)
    if axis < 0:
        axis += num_axes
    assert 0 <= axis < num_axes, "axis out of range"

    if axis != num_axes - 1:
        # The sort helper is applied on the last axis: move `axis` there,
        # sort, then restore the original layout.
        perm = get_perm_when_transpose_axis_to_last_dim(num_axes, axis)
        moved = flow.transpose(input, perm, False, True, name + "_transpose")
        moved = _sort_at_last_dim(moved, direction, name)
        return flow.transpose(
            moved, get_inversed_perm(perm), False, True, name + "_inverse_transpose"
        )

    return _sort_at_last_dim(input, direction, name)
Exemplo n.º 4
0
    def forward(self, input):
        """Apply ``self._op_softmax_last_dim`` to ``input`` along ``self.dim``.

        If ``self.dim`` is ``None``, ``input`` is flattened first and the op is
        applied along axis 0 of the flattened tensor. The op's output along the
        reduced axis is re-expanded with ``self._expand_op``. That axis is kept
        (size 1) when ``self.keepdim`` is truthy and squeezed away otherwise.

        NOTE(review): despite its name, ``_op_softmax_last_dim`` appears to
        reduce away the last axis (its output is re-expanded before being
        transposed back), presumably an argmax — confirm.

        Raises:
            AssertionError: if the resolved axis is out of range.
        """
        dim = self.dim
        if dim is None:
            # Keep the resolved dim local. Writing 0 back into self.dim made
            # every later call skip this flatten step.
            input = self._flatten(input)[0]
            dim = 0

        num_axes = len(input.shape)
        axis = dim if dim >= 0 else dim + num_axes
        assert 0 <= axis < num_axes, "axis out of range"

        if axis == num_axes - 1:
            x = self._op_softmax_last_dim(input)[0]
            if self.keepdim:
                # Take output [0], as the other branch and every other
                # self._* op call in this method do.
                x = self._expand_op(x)[0]
            return x

        # Move `axis` to the last position, apply the op there, restore the
        # size-1 axis, and transpose back to the original axis order.
        perm = get_perm_when_transpose_axis_to_last_dim(num_axes, axis)
        x = flow.tmp.transpose(input, perm=perm)
        x = self._op_softmax_last_dim(x)[0]
        x = self._expand_op(x)[0]
        x = flow.tmp.transpose(x, perm=get_inversed_perm(perm))
        if not self.keepdim:
            x = flow.tmp.squeeze(x, axis=[axis])
        return x
Exemplo n.º 5
0
    def forward(self, input):
        """Return ``flow.F.argmax`` of ``input`` along ``self.dim``.

        If ``self.dim`` is ``None``, ``input`` is flattened with
        ``flow.F.flatten`` first and axis 0 of the flattened tensor is used.
        The result keeps the reduced axis (size 1) when ``self.keepdim`` is
        truthy; otherwise that axis is removed.

        Raises:
            AssertionError: if the resolved axis is out of range.
        """
        dim = self.dim
        if dim is None:
            # Keep the resolved dim local. Writing 0 back into self.dim made
            # every later call skip the flatten step.
            input = flow.F.flatten(input)
            dim = 0

        num_axes = len(input.shape)
        axis = dim if dim >= 0 else dim + num_axes
        assert 0 <= axis < num_axes, "axis out of range"

        if axis == num_axes - 1:
            x = flow.F.argmax(input)
            if self.keepdim:
                x = flow.experimental.unsqueeze(x, -1)
            return x

        # flow.F.argmax is only applied along the last axis here. Move `axis`
        # there, restore it as a size-1 axis, then transpose back.
        perm = get_perm_when_transpose_axis_to_last_dim(num_axes, axis)
        x = flow.F.transpose(input, perm=perm)
        x = flow.F.argmax(x)
        x = flow.experimental.unsqueeze(x, -1)
        x = flow.F.transpose(x, perm=get_inversed_perm(perm))
        if not self.keepdim:
            x = x.squeeze(dim=[axis])
        return x