def test_tensors_to_device_cuda(tensors):
    """Verify that ``tensors_to_device(tensors, "cuda")`` yields CUDA tensors.

    The result is inspected in one of three shapes: a single ``Tensor``, a
    dict (its values are checked), or any other iterable (its elements are
    checked). Entries that are not ``Tensor`` instances are skipped.

    :param tensors: input handed to ``tensors_to_device``; supplied by the
        test's parametrization/fixtures, which are not visible in this block
    """
    out = tensors_to_device(tensors, "cuda")

    if isinstance(out, Tensor):
        assert out.is_cuda
    elif isinstance(out, Dict):
        # Must be tested before Iterable: a dict is itself iterable and
        # iterating it yields keys, so its tensor values would never be
        # checked if the Iterable branch ran first.
        for tens in out.values():
            if isinstance(tens, Tensor):
                assert tens.is_cuda
    elif isinstance(out, Iterable):
        for tens in out:
            if isinstance(tens, Tensor):
                assert tens.is_cuda
def test_tensors_to_precision_half_cuda(tensors):
    """Verify that ``tensors_to_precision(..., False)`` yields float16 on CUDA.

    The input is first moved to ``"cuda"`` and then converted. The result is
    inspected as a single ``Tensor``, a dict (its values are checked), or any
    other iterable (its elements are checked). Entries that are not
    ``Tensor`` instances are skipped.

    :param tensors: input handed to ``tensors_to_device``; supplied by the
        test's parametrization/fixtures, which are not visible in this block
    """
    tensors = tensors_to_device(tensors, "cuda")
    # NOTE(review): the positional ``False`` is presumably the
    # "full precision" flag -- confirm against tensors_to_precision's
    # signature. The assertions below expect torch.float16 for it.
    out = tensors_to_precision(tensors, False)

    if isinstance(out, Tensor):
        assert out.dtype == torch.float16
    elif isinstance(out, Dict):
        # Must be tested before Iterable: a dict is itself iterable and
        # iterating it yields keys, so its tensor values would never be
        # checked if the Iterable branch ran first.
        for tens in out.values():
            if isinstance(tens, Tensor):
                assert tens.dtype == torch.float16
    elif isinstance(out, Iterable):
        for tens in out:
            if isinstance(tens, Tensor):
                assert tens.dtype == torch.float16
def test_tensors_module_forward_cuda(module, tensors, check_feat_lab_inp):
    """Verify that ``tensors_module_forward`` produces a result on CUDA.

    Both the module and the input are moved to ``"cuda"`` before the call.

    :param module: object exposing ``.to(device)``; moved to ``"cuda"``
    :param tensors: input handed to ``tensors_to_device`` and then to
        ``tensors_module_forward``
    :param check_feat_lab_inp: forwarded unchanged as the third argument of
        ``tensors_module_forward``
    """
    module = module.to("cuda")
    tensors = tensors_to_device(tensors, "cuda")
    out = tensors_module_forward(tensors, module, check_feat_lab_inp)

    # Compare against None instead of relying on truthiness: bool() of a
    # torch.Tensor with more than one element raises RuntimeError, and a
    # one-element tensor holding zero is falsy even though a result exists.
    assert out is not None