Exemplo n.º 1
0
def searchsorted(a: NdarrayOrTensor, v: NdarrayOrTensor, right=False, sorter=None):
    """
    `np.searchsorted` with an equivalent implementation for torch.

    Args:
        a: numpy array or torch tensor holding the sequence to search. It must be in ascending
            order, unless `sorter` is given, in which case `sorter` must sort it into ascending
            order. For a torch tensor the search runs along the last dimension.
        v: values to find insertion indices for.
        right: if False, return the first suitable insertion index (``side="left"``); if True,
            return the last such index (``side="right"``).
        sorter: optional integer indices that sort `a` into ascending order. It is honoured for
            both numpy and torch inputs.

    Returns:
        The insertion indices. When `a` is a numpy array they come from `np.searchsorted`;
        otherwise they are a torch tensor on `a`'s device.
    """
    side = "right" if right else "left"
    if isinstance(a, np.ndarray):
        return np.searchsorted(a, v, side, sorter)  # type: ignore
    if hasattr(torch, "searchsorted"):
        # torch.searchsorted only accepts tensor (or scalar) values, so move numpy values onto a's device.
        if isinstance(v, np.ndarray):
            v = torch.as_tensor(v, device=a.device)
        if sorter is not None:
            # Searching `a` with `sorter` is the same as searching the sorted view `a[sorter]`.
            # Gathering along the last dim also covers batched (N-D) `a`, and avoids relying on a
            # `sorter` keyword that not every torch version provides.
            a = torch.gather(a, -1, torch.as_tensor(sorter, device=a.device, dtype=torch.long))
        return torch.searchsorted(a, v, right=right)  # type: ignore
    # Old PyTorch without torch.searchsorted: compute with numpy, then return a tensor on a's device.
    a_np = a.cpu().numpy()  # type: ignore
    v_np = v.cpu().numpy() if isinstance(v, torch.Tensor) else v
    sorter_np = sorter.cpu().numpy() if isinstance(sorter, torch.Tensor) else sorter
    ret = np.searchsorted(a_np, v_np, side, sorter_np)
    # Keep the integer index dtype, matching the output of the torch.searchsorted branch.
    return torch.as_tensor(ret, device=a.device)  # type: ignore
Exemplo n.º 2
0
def assert_allclose(
    actual: NdarrayOrTensor,
    desired: NdarrayOrTensor,
    type_test: bool = True,
    device_test: bool = False,
    *args,
    **kwargs,
):
    """
    Assert that types and all values of two data objects are close.

    Args:
        actual: PyTorch tensor or numpy array for comparison.
        desired: PyTorch tensor or numpy array to compare against.
        type_test: whether to test that `actual` and `desired` are both numpy arrays or both torch tensors.
        device_test: whether to test that `actual` and `desired` report the same device. The check
            only runs when at least one of them is a torch tensor.
        args: extra positional arguments to pass on to `np.testing.assert_allclose`.
        kwargs: extra keyword arguments to pass on to `np.testing.assert_allclose`.

    Raises:
        AssertionError: if any enabled check fails or the values are not close.
    """
    if type_test:
        # check both actual and desired are of the same type
        np.testing.assert_equal(isinstance(actual, np.ndarray), isinstance(desired, np.ndarray), "numpy type")
        np.testing.assert_equal(isinstance(actual, torch.Tensor), isinstance(desired, torch.Tensor), "torch type")

    if isinstance(desired, torch.Tensor) or isinstance(actual, torch.Tensor):
        if device_test:
            # With type_test=False one side may be a non-tensor without `.device`. Use getattr so a
            # mismatch is reported as an assertion failure rather than an AttributeError.
            np.testing.assert_equal(
                str(getattr(actual, "device", None)),
                str(getattr(desired, "device", None)),
                "torch device check",
            )
        # detach() so tensors that require grad can still be converted with .numpy()
        if isinstance(actual, torch.Tensor):
            actual = actual.detach().cpu().numpy()
        if isinstance(desired, torch.Tensor):
            desired = desired.detach().cpu().numpy()
    np.testing.assert_allclose(actual, desired, *args, **kwargs)