예제 #1
0
 def test_apply_on_recursive_dict(self):
     """Run ``torch_apply`` with a ``cpu()`` callback on a dict nesting another dict.

     The outer dict holds one mock tensor under ``'a'`` and, under ``'b'``,
     an inner dict with two more mock tensors. All three mocks are then
     handed to ``_test_method_calls``.
     """
     def to_cpu(tensor):
         return tensor.cpu()

     my_dict = {'a': MagicMock(spec=torch.Tensor), 'b': {}}
     my_dict['b'].update({key: MagicMock(spec=torch.Tensor) for key in ('c', 'd')})
     torch_apply(my_dict, to_cpu)
     # Flatten in insertion order: the top-level mock first, then the nested ones.
     expected_mocks = [my_dict['a']]
     expected_mocks.extend(my_dict['b'].values())
     self._test_method_calls(expected_mocks)
예제 #2
0
파일: test_utils.py 프로젝트: jtbai/poutyne
 def test_apply_on_recursive_data_structure(self):
     """Run ``torch_apply`` with a ``cpu()`` callback on mixed nested containers.

     The input dict mixes a list, a tuple, a nested dict holding a list, and
     a bare mock tensor. Every mock is then handed to ``_test_method_calls``
     in the order a, b, c/d, e.
     """
     def make_mocks(count):
         return [MagicMock(spec=torch.Tensor) for _ in range(count)]

     my_obj = {}
     my_obj['a'] = make_mocks(3)
     my_obj['b'] = tuple(make_mocks(2))
     my_obj['c'] = {'d': make_mocks(3)}
     my_obj['e'] = MagicMock(spec=torch.Tensor)
     torch_apply(my_obj, lambda t: t.cpu())
     expected_mocks = [*my_obj['a'], *my_obj['b'], *my_obj['c']['d'], my_obj['e']]
     self._test_method_calls(expected_mocks)
예제 #3
0
파일: test_utils.py 프로젝트: jtbai/poutyne
 def test_apply_on_object_with_no_tensor(self):
     """Check ``torch_apply`` on a nested structure that holds no tensors.

     The returned object must compare equal to the input while being a
     distinct object (not the very same dict instance).
     """
     my_obj = {
         'a': 5,
         'b': 3.141592,
         'c': {
             'd': [1, 2, 3]
         }
     }
     ret = torch_apply(my_obj, lambda t: t.cpu())
     self.assertEqual(ret, my_obj)
     # assertIsNot reports the offending object on failure, whereas
     # assertFalse(ret is my_obj) would only say "True is not false".
     self.assertIsNot(ret, my_obj)
예제 #4
0
 def test_apply_with_replacement_to_no_tensor(self):
     """Check that a callback returning a plain value replaces the mock tensor.

     The callback ignores its argument and returns ``123``; the result of
     ``torch_apply`` on a one-element list must equal ``[123]``.
     """
     replacement = 123
     my_obj = [MagicMock(spec=torch.Tensor)]
     ret = torch_apply(my_obj, lambda t: replacement)
     expected = [replacement]
     self.assertEqual(ret, expected)
예제 #5
0
 def test_apply_on_dict(self):
     """Run ``torch_apply`` with a ``cpu()`` callback on a flat dict of mock tensors.

     The dict's values are then handed to ``_test_method_calls`` as a list.
     """
     my_dict = {key: MagicMock(spec=torch.Tensor) for key in ('a', 'b', 'c')}
     torch_apply(my_dict, lambda t: t.cpu())
     mocks = [*my_dict.values()]
     self._test_method_calls(mocks)
예제 #6
0
 def test_apply_on_recursive_tuple(self):
     """Run ``torch_apply`` with a ``cpu()`` callback on a tuple nesting another tuple.

     Layout: two mock tensors, a nested tuple of three, then one more. The
     six mocks are flattened into one tuple for ``_test_method_calls``.
     """
     def mock_tuple(size):
         return tuple(MagicMock(spec=torch.Tensor) for _ in range(size))

     my_tuple = (*mock_tuple(2), mock_tuple(3), *mock_tuple(1))
     torch_apply(my_tuple, lambda t: t.cpu())
     # Index 2 is the nested tuple; splice its members between its neighbours.
     leading, nested, trailing = my_tuple[:2], my_tuple[2], my_tuple[3:]
     self._test_method_calls(leading + nested + trailing)
예제 #7
0
 def test_apply_on_tuple(self):
     """Run ``torch_apply`` with a ``cpu()`` callback on a flat tuple of ten mock tensors.

     The same tuple is then handed to ``_test_method_calls``.
     """
     def to_cpu(tensor):
         return tensor.cpu()

     mocks = [MagicMock(spec=torch.Tensor) for _ in range(10)]
     my_tuple = tuple(mocks)
     torch_apply(my_tuple, to_cpu)
     self._test_method_calls(my_tuple)
예제 #8
0
 def test_apply_on_recursive_list(self):
     """Run ``torch_apply`` with a ``cpu()`` callback on a list nesting another list.

     Layout: two mock tensors, a nested list of three, then one more. The
     six mocks are flattened into one list for ``_test_method_calls``.
     """
     def mock_list(size):
         return [MagicMock(spec=torch.Tensor) for _ in range(size)]

     my_list = [*mock_list(2), mock_list(3), *mock_list(1)]
     torch_apply(my_list, lambda t: t.cpu())
     # Index 2 is the nested list; splice its members between its neighbours.
     flattened = my_list[:2]
     flattened.extend(my_list[2])
     flattened.extend(my_list[3:])
     self._test_method_calls(flattened)
예제 #9
0
 def test_apply_on_list(self):
     """Run ``torch_apply`` with a ``cpu()`` callback on a flat list of ten mock tensors.

     The same list is then handed to ``_test_method_calls``.
     """
     def to_cpu(tensor):
         return tensor.cpu()

     my_list = [MagicMock(spec=torch.Tensor) for _index in range(10)]
     torch_apply(my_list, to_cpu)
     self._test_method_calls(my_list)