def to_device(b, device=None):
    """Recursively put every tensor in `b` on `device`.

    Args:
        b: a tensor or a (possibly nested) collection; `apply` does the
            recursion, and `_inner` is called on each leaf.
        device: target device. `None` (the default) means `defaults.device`,
            looked up at *call* time so later reassignments of
            `defaults.device` are honoured.

    Returns:
        `b` with every tensor leaf moved to `device`; non-tensor leaves are
        returned unchanged.
    """
    # A default of `defaults.device` in the signature would be frozen when
    # this `def` executes; resolving it here picks up the current setting.
    if device is None:
        device = defaults.device

    def _inner(o):
        # Only tensors know how to move; everything else passes through.
        return o.to(device, non_blocking=True) if isinstance(o, Tensor) else o

    return apply(_inner, b)
def to_device(b, device=None):
    """
    purpose:
    0. we often want to put everything (tensors at every nesting level) on
       the default device, or on an explicitly chosen one
    1. Recursively put every tensor in `b` on `device`
       1.1 `apply(_inner, b)` walks nested containers
       1.2 `_inner` moves tensors and returns non-tensors unchanged
    2. `device=None` (the default) means `defaults.device`, read each time the
       function is called, so reassigning `defaults.device` later
       (e.g. to switch to CPU) takes effect
    """
    # Resolve the default here rather than in the signature: a signature
    # default is evaluated only once, when this `def` runs.
    if device is None:
        device = defaults.device

    def _inner(o):
        # non_blocking=True allows an asynchronous copy where PyTorch supports
        # it (e.g. pinned CPU memory -> CUDA); otherwise it has no effect.
        return o.to(device, non_blocking=True) if isinstance(o, Tensor) else o

    return apply(_inner, b)
def to_float(b):
    """
    purpose:
    1. we sometimes want to convert tensors to float32 dtype through all levels
       of a nested data object
    2. `apply(_inner, b)` does the recursion; `_inner` handles one leaf
    3. each leaf falls into one of three cases:
       3.1 wide int tensors: torch.int64, ...32, ...16 => returned unchanged
       3.2 every other tensor dtype (torch.int8, torch.uint8, torch.bool,
           torch.float16, torch.float64, ...) => converted by `x.float()`,
           i.e. to float32
       3.3 non-tensor leaves (e.g. None, Python numbers) => returned
           unchanged, like the `_inner` helpers of `to_device`/`to_detach`;
           they have no usable `.dtype`, so inspecting it would raise
    """
    # Built once per call instead of once per leaf.
    keep_dtypes = (torch.int64, torch.int32, torch.int16)

    def _inner(x):
        if not isinstance(x, torch.Tensor):
            return x
        return x if x.dtype in keep_dtypes else x.float()

    return apply(_inner, b)
def to_detach(b, cpu=True):
    """Detach every tensor in `b` from the autograd graph, at all nesting levels.

    `tensor.detach()` only handles a single tensor; this function relies on
    `apply` to recurse through a complex data object and calls
    `_detach_leaf` on each leaf.

    Args:
        b: a tensor or a (possibly nested) collection containing tensors.
        cpu: when True, each detached tensor is also moved with `.cpu()`.

    Returns:
        `b` with each tensor leaf replaced by its detached version (moved to
        CPU when `cpu` is True); non-tensor leaves are returned unchanged.
    """
    def _detach_leaf(t, cpu=True):
        # The keyword must stay named `cpu`: `apply` forwards `cpu=cpu` to it.
        if isinstance(t, Tensor):
            detached = t.detach()  # drop gradient history
            return detached.cpu() if cpu else detached
        return t

    return apply(_detach_leaf, b, cpu=cpu)
# `to_float(b)` = recursively convert the tensors in `b` to float32, except
# int64/int32/int16 tensors (and non-tensor leaves), which are left unchanged.
def to_float(b):
    """
    purpose:
    1. we sometimes want to convert tensors to float32 dtype through all levels
       of a nested data object
    2. `apply(_inner, b)` does the recursion; `_inner` handles one leaf
    3. each leaf falls into one of three cases:
       3.1 wide int tensors: torch.int64, ...32, ...16 => returned unchanged
       3.2 every other tensor dtype (torch.int8, torch.uint8, torch.bool,
           torch.float16, torch.float64, ...) => converted by `x.float()`,
           i.e. to float32
       3.3 non-tensor leaves (e.g. None, Python numbers) => returned
           unchanged, like the `_inner` helpers of `to_device`/`to_detach`;
           they have no usable `.dtype`, so inspecting it would raise
    """
    # Built once per call instead of once per leaf.
    keep_dtypes = (torch.int64, torch.int32, torch.int16)

    def _inner(x):
        if not isinstance(x, torch.Tensor):
            return x
        return x if x.dtype in keep_dtypes else x.float()

    return apply(_inner, b)

# Demo: inspect the dtypes of a list of tensors before and after `to_half`,
# then before and after `to_float`.
# integer dtypes for reference: torch.int8, 16, 32, 64
b = [tensor(1, 2), tensor(1.3394858, 3.59483)]
apply(lambda x: x.dtype, b)
b = to_half(b)
apply(lambda x: x.dtype, b)
b = [tensor(1, 2), tensor(3.3, 4.1)]
apply(lambda x: x.dtype, b)
b = to_float(b)
apply(lambda x: x.dtype, b)
show_doc(torch.float)