def test_where_input_validation(condition, x, y):
    """``where`` should validate its inputs exactly like ``numpy.where``.

    If ``numpy.where`` raises for the given arguments, ``where`` must raise an
    exception of the same type (or a subclass, per ``pytest.raises``).
    If ``numpy.where`` accepts them, ``where`` must accept them too.

    ``x`` / ``y`` may be ``None``, in which case that operand is omitted from
    the call.
    """
    args = [i for i in (x, y) if i is not None]

    try:
        np.where(condition, *args)
    except Exception as e:
        with pytest.raises(type(e)):
            where(condition, *args)
    else:
        # numpy accepted these inputs, so mygrad must as well; without this
        # branch the test would pass vacuously whenever numpy does not raise.
        where(condition, *args)
def test_where_condition_only_fwd(condition):
    """mygrad.where should merely mirror numpy.where when only `where(condition)`
    is specified.

    ndarray conditions are wrapped in ``mg.Tensor`` before being passed to
    ``where``; other condition values are passed through unchanged.
    """
    tensor_condition = (
        mg.Tensor(condition) if isinstance(condition, np.ndarray) else condition
    )
    mygrad_out = where(tensor_condition)
    numpy_out = np.where(condition)

    # zip() stops at the shorter sequence, so compare lengths explicitly
    # rather than letting a missing/extra index array go unnoticed.
    assert len(mygrad_out) == len(numpy_out)
    for actual, expected in zip(mygrad_out, numpy_out):
        # assert_array_equal also checks shape, unlike np.all(a == b), which
        # broadcasts (and is vacuously True for empty comparisons).
        np.testing.assert_array_equal(actual, expected)
def mygrad_where(x, y, condition, constant=False):
    """Call ``where(condition, x, y, constant=constant)`` with ``x``/``y`` first.

    Only the argument order is changed: ``x`` and ``y`` lead and ``condition``
    trails. ``constant`` defaults to ``False`` and is forwarded unchanged.
    NOTE(review): presumably this ordering suits a test helper that treats
    leading positional args as the differentiable operands — confirm at caller.
    """
    operands = (condition, x, y)
    return where(*operands, constant=constant)