コード例 #1
0
ファイル: judo_tensor.py プロジェクト: Guillemdb/judo
def new_backend_tensor(
    x,
    dtype=None,
    device=None,
    requires_grad: bool = None,
    copy: bool = False,
    pin_memory: bool = False,
):
    """Create a new array/tensor from ``x`` using the currently active backend.

    ``requires_grad``, ``device``, ``copy`` and ``dtype`` are first passed
    through ``update_with_backend_values``. That presumably fills unset options
    with backend-wide defaults; confirm against its definition. The resolved
    values are then forwarded to the backend-specific constructor.

    Args:
        x: Data to build the new array/tensor from.
        dtype: Requested data type. Forwarded to both backend constructors.
        device: Requested device. Only forwarded to the torch constructor.
        requires_grad: Only forwarded to the torch constructor.
        copy: Forwarded to both backend constructors.
        pin_memory: Forwarded unchanged to the torch constructor only. It is
            not resolved through ``update_with_backend_values``.

    Returns:
        The result of ``new_numpy_array`` under the numpy backend, or the
        result of ``new_torch_tensor`` under the torch backend. Returns
        ``None`` when neither backend reports itself as active.
    """
    options = update_with_backend_values(
        requires_grad=requires_grad,
        device=device,
        copy=copy,
        dtype=dtype,
    )
    if Backend.is_numpy():
        # Only dtype and copy are forwarded to the numpy constructor.
        return new_numpy_array(
            x,
            dtype=options.get("dtype"),
            copy=options.get("copy"),
        )
    if Backend.is_torch():
        return new_torch_tensor(
            x,
            dtype=options.get("dtype"),
            copy=options.get("copy"),
            device=options.get("device"),
            requires_grad=options.get("requires_grad"),
            pin_memory=pin_memory,
        )
    # NOTE(review): any other backend falls through here and yields None.
    return None
コード例 #2
0
def __new_getattr(name):
    """Resolve an attribute ``name`` through the project's fallback chain.

    This is presumably installed as the module-level ``__getattr__`` hook;
    confirm at the installation site. Lookup order:

    1. ``name`` in ``DATA_TYPE_NAMES``: the matching attribute of
       ``_data_types`` is *called* and its result returned.
    2. ``name`` in ``AVAILABLE_FUNCTIONS``: the matching attribute of ``API``
       is returned.
    3. The previously installed lookup, ``__old_getattr``.
    4. The active backend module (``numpy`` or ``torch``). Callables are
       wrapped with ``__backend_wrap``; other values are returned as-is.

    Raises:
        AttributeError: the error raised by ``__old_getattr`` when none of the
            steps above resolves ``name``. This includes the case where the
            backend module lacks it, and the case where no backend is active.
    """
    if name in DATA_TYPE_NAMES:
        return getattr(_data_types, name)()
    if name in AVAILABLE_FUNCTIONS:
        return getattr(API, name)
    try:
        return __old_getattr(name)
    except AttributeError as err:
        if Backend.is_numpy():
            backend_module = numpy
        elif Backend.is_torch():
            backend_module = torch
        else:
            # No backend to fall back to: re-raise the original lookup error.
            raise
        try:
            val = getattr(backend_module, name)
        except AttributeError:
            # Report the original lookup failure instead of a confusing
            # "module 'numpy'/'torch' has no attribute ..." error chained
            # on top of it.
            raise err from None
        return __backend_wrap(val) if callable(val) else val
コード例 #3
0
ファイル: tree.py プロジェクト: FragileTech/judo
def to_node_id(x):
    """Convert ``x`` into the identifier used for tree nodes.

    Under the numpy backend the result is ``str(x)`` when
    ``Backend.use_true_hash()`` is true, and ``int(x)`` otherwise. Under the
    torch backend the result is always ``int(x)``; ``use_true_hash`` is not
    consulted there. Returns ``None`` when neither backend is active.
    """
    if Backend.is_numpy():
        if Backend.use_true_hash():
            return str(x)
        return int(x)
    if Backend.is_torch():
        # NOTE(review): the torch branch ignores use_true_hash; confirm intended.
        return int(x)
    return None