コード例 #1
0
def hash_type():
    funcs = {
        "numpy":
        lambda x: numpy.dtype("<U64")
        if Backend.use_true_hash() else numpy.uint64,
        "torch":
        lambda x: torch.int64,
    }
    return Backend.execute(None, funcs)
コード例 #2
0
ファイル: judo_tensor.py プロジェクト: Guillemdb/judo
def copy(x, requires_grad: bool = None):
    if x is None:
        return
    if not dtype.is_tensor(x):
        x = JudoTensor(x)

    funcs = {
        "numpy": lambda x: x.copy(),
        "torch": lambda x: copy_torch(x, requires_grad),
    }
    return Backend.execute(x, funcs)
コード例 #3
0
ファイル: random.py プロジェクト: Guillemdb/judo
 def __getattr__(cls, item):
     funcs = {
         "numpy": lambda x: getattr(cls._numpy_random_state, x),
         "torch": lambda x: getattr(cls._torch_random_state, x),
     }
     return Backend.execute(item, funcs)
コード例 #4
0
def float64():
    funcs = {
        "numpy": lambda x: numpy.float64,
        "torch": lambda x: torch.float64
    }
    return Backend.execute(None, funcs)
コード例 #5
0
def int32():
    funcs = {"numpy": lambda x: numpy.int32, "torch": lambda x: torch.int32}
    return Backend.execute(None, funcs)
コード例 #6
0
def uint8():
    funcs = {"numpy": lambda x: numpy.uint8, "torch": lambda x: torch.uint8}
    return Backend.execute(None, funcs)
コード例 #7
0
def bool():
    funcs = {"numpy": lambda x: numpy.bool_, "torch": lambda x: torch.bool}
    return Backend.execute(None, funcs)
コード例 #8
0
ファイル: judo_tensor.py プロジェクト: Guillemdb/judo
def as_tensor(x, *args, **kwargs):
    funcs = {
        "numpy": lambda x: numpy.ascontiguousarray(x, *args, **kwargs),
        "torch": lambda x: torch.as_tensor(x, *args, **kwargs),
    }
    return Backend.execute(x, funcs)
コード例 #9
0
ファイル: judo_tensor.py プロジェクト: Guillemdb/judo
def astype(x, dtype):
    funcs = {
        "numpy": lambda x: x.astype(dtype),
        "torch": lambda x: x.to(dtype),
    }
    return Backend.execute(x, funcs)
コード例 #10
0
ファイル: hashing.py プロジェクト: FragileTech/judo
 def true_hash_tensor(cls, x):
     funcs = {
         "numpy": cls.hash_numpy,
         "torch": cls.hash_torch,
     }
     return Backend.execute(x, funcs)