コード例 #1
0
def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
    """Non-maximum suppression via the builtin ``nms`` op.

    Boxes are first reordered by descending ``scores``. Then the ``nms`` op is
    run with the given ``iou_threshold`` and ``keep_n=-1``.

    Args:
        boxes: candidate boxes, one per row along dim 0 (box encoding is
            whatever the ``nms`` op expects -- not visible here).
        scores: one score per box along dim 0.
        iou_threshold: IoU threshold forwarded to the ``nms`` op.

    Returns:
        Positions into the original ``boxes``/``scores`` of the kept boxes.
        They are taken from the descending-score order, presumably
        highest-scoring first (depends on ``argwhere`` ordering -- confirm).
    """
    # Rank boxes so that the op sees them from highest to lowest score.
    order = flow_exp.argsort(scores, dim=0, descending=True)
    ranked_boxes = flow._C.gather(boxes, order, axis=0)

    # Build the op step by step; keep_n=-1 is passed straight through.
    builder = flow_exp.builtin_op("nms").Input("in").Output("out")
    builder = builder.Attr("iou_threshold", iou_threshold).Attr("keep_n", -1)
    nms_kernel = builder.Build()
    keep_mask = nms_kernel(ranked_boxes)[0]

    # argwhere yields an (N, 1) index tensor here; drop the trailing axis.
    kept_positions = flow_exp.squeeze(flow_exp.argwhere(keep_mask), dim=[1])
    # Map positions in the ranked order back to original box indices.
    return flow._C.gather(order, kept_positions, axis=0)
コード例 #2
0
def _test_argwhere_with_random_data(test_case, ndim, placement, sbp):
    """Build a dual OneFlow/PyTorch result for a global-tensor argwhere check.

    A random ``ndim``-dimensional tensor is created and made global with the
    given ``placement`` and ``sbp``. The returned object carries the OneFlow
    ``argwhere`` result in ``.oneflow`` and the PyTorch reference in
    ``.pytorch``; presumably the surrounding autotest harness compares the two.

    ``test_case`` is accepted for the harness' calling convention but unused.
    """
    shape = [random(1, 3) * 8 for _ in range(ndim)]
    x = random_tensor(ndim, *shape).to_global(placement=placement, sbp=sbp)
    result = x.clone()
    # PyTorch only gained torch.argwhere in v1.11, so the reference side uses
    # torch.nonzero, which yields the same per-element index rows.
    result.oneflow = flow.argwhere(x.oneflow)
    result.pytorch = torch_ori.nonzero(x.pytorch)
    return result
コード例 #3
0
def _test_argwhere(test_case, shape, device):
    """Check ``flow.argwhere`` against ``np.argwhere`` on a random input.

    About half of the entries are forced to zero. A plain ``randn`` input is
    almost surely all nonzero, so argwhere would just return every index and
    the zero-filtering path would never be exercised.

    Args:
        test_case: test-case object providing ``assertTrue``.
        shape: shape of the random input.
        device: device name passed to ``flow.device``.
    """
    # Generate in float32 up front so NumPy and OneFlow see identical values.
    # np.asarray also handles shape == () where randn returns a Python float.
    np_input = np.asarray(np.random.randn(*shape), dtype=np.float32)
    # Zero out a random ~50% of entries so some indices must be filtered out.
    np_input = np.where(np.random.rand(*shape) < 0.5, np.float32(0), np_input)
    input_tensor = flow.tensor(np_input,
                               dtype=flow.float32,
                               device=flow.device(device))
    of_out = flow.argwhere(input_tensor).numpy()
    np_out = np.argwhere(np_input)
    # Compare shapes first so a different number of hits is reported as such
    # instead of as a (possibly broadcast) value mismatch.
    test_case.assertTrue(
        of_out.shape == np_out.shape,
        f"shape mismatch: {of_out.shape} vs {np_out.shape}",
    )
    # argwhere returns integer indices, so compare exactly, not with tolerance.
    test_case.assertTrue(np.array_equal(of_out, np_out))
コード例 #4
0
ファイル: masked_select.py プロジェクト: zzk0/oneflow
def masked_select_op(input, mask):
    """Select the elements of ``input`` where ``mask`` is nonzero.

    Returns a new 1-D tensor which indexes the input tensor according to the boolean mask ``mask``, which is a BoolTensor (in OneFlow, BoolTensor is replaced by Int8Tensor).

    The shapes of the mask tensor and the input tensor don't need to match, but they must be broadcastable. In this implementation both tensors must also have the same number of dimensions.

    Every position where ``mask`` is nonzero is selected, including positions where ``input`` itself is zero.

    Args:
        input (Tensor): the input tensor.
        mask (Tensor): the tensor containing the binary mask to index with

    For example:

    .. code-block:: python

        >>> import oneflow as flow
        >>> import numpy as np

        >>> input = flow.tensor(np.array([[-0.4620, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]), dtype=flow.float32)
        >>> mask = input.gt(0.05)
        >>> out = flow.masked_select(input, mask)
        >>> out
        tensor([0.3139, 0.3898], dtype=oneflow.float32)
    """

    assert len(input.shape) == len(
        mask.shape
    ), "The dim of masked_select module's inputs can not match, please check!"
    # Work out the common broadcast shape and, per tensor, the axes along which
    # it has to be expanded to reach that shape.
    broadcast_like_shape = []
    broadcast_x_axes = []
    broadcast_mask_axes = []
    for i, (input_dim, mask_dim) in enumerate(zip(input.shape, mask.shape)):
        max_dim = max(input_dim, mask_dim)
        broadcast_like_shape.append(max_dim)
        if max_dim != input_dim:
            broadcast_x_axes.append(i)
        if max_dim != mask_dim:
            broadcast_mask_axes.append(i)
    # Shape template for broadcast_like; only its shape (and grad flag) matter.
    broadcast_like_tensor = flow.zeros(tuple(broadcast_like_shape),
                                       dtype=flow.float32,
                                       device=input.device)
    broadcast_like_tensor.requires_grad = input.requires_grad or mask.requires_grad
    if broadcast_x_axes:
        input = flow.broadcast_like(input,
                                    broadcast_like_tensor,
                                    broadcast_axes=tuple(broadcast_x_axes))
    if broadcast_mask_axes:
        mask = flow.broadcast_like(mask,
                                   broadcast_like_tensor,
                                   broadcast_axes=tuple(broadcast_mask_axes))
    mask = mask.to(dtype=input.dtype)
    # The product is kept as the gather source (as before), presumably so the
    # selected values remain connected to both operands in autograd.
    res = flow._C.mul(input, mask)
    # Take the positions from the mask, not from ``res``: a selected element
    # whose input value is 0 would otherwise be dropped, and a masked-out NaN
    # (NaN * 0 == NaN) would otherwise be wrongly selected.
    indices = flow.argwhere(mask)
    gather_res = flow._C.gather_nd(res, indices)
    return gather_res.flatten()
コード例 #5
0
def _argwhere(self):
    """Tensor-method form of ``flow.argwhere``: forwards ``self`` unchanged."""
    argwhere = flow.argwhere
    return argwhere(self)
コード例 #6
0
ファイル: nms.py プロジェクト: zzk0/oneflow
def nms_op(boxes, scores, iou_threshold: float):
    """Non-maximum suppression via ``flow._C.nms``.

    Args:
        boxes: candidate boxes, one per row along dim 0 (encoding is whatever
            ``flow._C.nms`` expects -- not visible here).
        scores: one score per box along dim 0.
        iou_threshold: IoU threshold forwarded to ``flow._C.nms``.

    Returns:
        Positions into the original ``boxes``/``scores`` of the kept boxes,
        taken from the descending-score order.
    """
    # Present the boxes to the kernel from highest to lowest score.
    order = flow.argsort(scores, dim=0, descending=True)
    ranked_boxes = flow._C.nms.__self__ if False else None  # placeholder never used
    ranked_boxes = flow._C.gather(boxes, order, axis=0)
    keep_mask = flow._C.nms(ranked_boxes, iou_threshold)
    # argwhere yields an (N, 1) index tensor here; drop the trailing axis.
    kept_rows = flow.argwhere(keep_mask)
    kept_positions = flow.squeeze(kept_rows, dim=[1])
    # Map positions in the ranked order back to original box indices.
    return flow._C.gather(order, kept_positions, axis=0)
コード例 #7
0
ファイル: test_argwhere.py プロジェクト: zyg11/oneflow
 def argwhere_fn(x: flow.typing.Numpy.Placeholder(
     x.shape, dtype=data_type)) -> flow.typing.ListNumpy:
     """Return ``flow.argwhere`` of the placeholder input ``x``.

     Presumably decorated as a OneFlow job function; the decorator and the
     enclosing function are not visible here. The parameter annotation is
     evaluated when this ``def`` executes (unless annotations are postponed),
     so its ``x.shape`` refers to an ``x`` from the enclosing scope, not to this
     parameter. ``data_type`` and ``out_data_type`` also come from that scope.
     """
     indices = flow.argwhere(x, dtype=out_data_type)
     return indices
コード例 #8
0
 def do_argwhere(x_blob):
     """Compute ``flow.argwhere`` of ``x_blob`` inside a placement scope.

     The placement is ``(device_type, "0:0")``, presumably machine 0 / device 0.
     ``device_type`` and ``out_data_type`` come from the enclosing scope, which
     is not visible here.
     """
     placement_scope = flow.scope.placement(device_type, "0:0")
     with placement_scope:
         indices = flow.argwhere(x_blob, dtype=out_data_type)
     return indices