def test_tensors():
    """Exercise key selection in ``Message.tensors()``.

    Three cases are checked:
    - With no ``keys`` argument, the result equals ``TensorMessage(tensors)``.
    - ``keys=['a']`` narrows the result to the single entry ``'a'``.
    - ``keys=['a', 'c']`` mixes a ``tensors`` entry with a ``vectors`` entry.
      The ``vectors`` value is expected back as ``torch.Tensor(vectors['c'])``.

    NOTE(review): ``tensors`` and ``vectors`` are presumably module-level
    fixtures defined elsewhere in this file.
    """
    msg = Message(tensors, vectors)

    # No filter: the whole `tensors` mapping comes back.
    assert msg.tensors() == TensorMessage(tensors)

    # A single requested key.
    assert msg.tensors(keys=['a']) == TensorMessage({'a': tensors['a']})

    # Requesting 'c' (from `vectors`) alongside 'a'; the vector value is
    # expected to be materialised as a torch.Tensor.
    assert msg.tensors(keys=['a', 'c']) == TensorMessage(
        {'a': tensors['a'], 'c': torch.Tensor(vectors['c'])}
    )
def test_cpu_gpu():
    """Check that ``Message.cpu()`` / ``Message.cuda()`` move every tensor.

    The test always checks the CPU placement. The CPU -> CUDA -> CPU round
    trip runs only when ``torch.cuda.is_available()`` is true.

    The test relies on ``cpu()`` / ``cuda()`` acting on ``m`` in place: their
    return values are ignored, and devices are re-read from ``m.tensors()``
    after each call.
    """
    m = Message(tensors, vectors)

    def assert_all_on(device_type):
        # Name the offending key so a failure says which tensor did not move.
        for key, tensor in m.tensors().items():
            assert tensor.device.type == device_type, (
                f"tensor {key!r} is on {tensor.device}, expected {device_type!r}"
            )

    m.cpu()
    # With no key filter, exactly 'a' and 'b' are returned.
    assert set(m.tensors().keys()) == {'a', 'b'}
    assert_all_on('cpu')

    if torch.cuda.is_available():
        m.cuda()
        assert_all_on('cuda')
        # Moving back must work too, not just the first transfer.
        m.cpu()
        assert_all_on('cpu')