Пример #1
0
 def test1VisibleNotInitialized(self):
     """Check device selection with one visible device and CUDA not initialized.

     Sets ``CUDA_VISIBLE_DEVICES`` to ``"0"``, patches
     ``torch.cuda.is_initialized`` to return ``False``, and then asserts that
     ``_try_reserve_and_set_cuda`` calls ``_set_cuda_device("0")`` and leaves
     a one-character value in ``CUDA_VISIBLE_DEVICES``.
     """
     # patch.dict restores os.environ when the block exits, so neither the
     # value set here nor anything written by the code under test leaks
     # into other tests.
     with patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"}):
         with patch("torch.cuda.is_initialized", return_value=False):
             mock_runner = MagicMock()
             mock_runner._set_cuda_device = MagicMock()
             # Call through the class so the MagicMock stands in for self.
             LocalDistributedRunner._try_reserve_and_set_cuda(mock_runner)
             mock_runner._set_cuda_device.assert_called_with("0")
             # assertEqual: the assertEquals alias was removed in Python 3.12.
             self.assertEqual(len(os.environ["CUDA_VISIBLE_DEVICES"]), 1)
Пример #2
0
 def test2VisibleNotInitialized(self):
     """Check device selection with two visible devices and CUDA not initialized.

     Sets ``CUDA_VISIBLE_DEVICES`` to ``"2,3"``, patches
     ``torch.cuda.is_initialized`` to return ``False``, and then asserts that
     ``_try_reserve_and_set_cuda`` passes ``"0"`` or ``"1"`` positionally to
     ``_set_cuda_device`` and leaves a one-character value in
     ``CUDA_VISIBLE_DEVICES``.
     """
     # patch.dict restores os.environ when the block exits, so neither the
     # value set here nor anything written by the code under test leaks
     # into other tests.
     with patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "2,3"}):
         with patch("torch.cuda.is_initialized", return_value=False):
             mock_runner = MagicMock()
             mock_runner._set_cuda_device = MagicMock()
             # Call through the class so the MagicMock stands in for self.
             LocalDistributedRunner._try_reserve_and_set_cuda(mock_runner)
             args, _ = mock_runner._set_cuda_device.call_args
             self.assertTrue("1" in args or "0" in args)
             # assertEqual: the assertEquals alias was removed in Python 3.12.
             self.assertEqual(len(os.environ["CUDA_VISIBLE_DEVICES"]), 1)
Пример #3
0
    def _testWithInitialized(self, init_mock):
        """Run ``_try_reserve_and_set_cuda`` on a mock runner and check the result.

        Asserts that ``_set_cuda_device`` was called and that
        ``CUDA_VISIBLE_DEVICES`` holds a one-character value afterwards. If
        the variable was non-empty before the call, the new value must be
        one of the preset comma-separated entries and the device passed to
        ``_set_cuda_device`` must be ``"0"``; otherwise the passed device
        must equal the new environment value.

        Args:
            init_mock: Not used in this body. Presumably the patched
                ``torch.cuda.is_initialized`` supplied by the caller;
                TODO confirm against the call sites.
        """
        mock_runner = MagicMock()
        mock_runner._set_cuda_device = MagicMock()
        # Snapshot before the call: the variable is re-read afterwards and
        # compared against this earlier value.
        preset_devices = os.environ.get("CUDA_VISIBLE_DEVICES")

        # Call through the class so the MagicMock stands in for self.
        LocalDistributedRunner._try_reserve_and_set_cuda(mock_runner)

        self.assertTrue(mock_runner._set_cuda_device.called)
        # First positional argument of the most recent call.
        local_device = mock_runner._set_cuda_device.call_args[0][0]
        env_set_device = os.environ["CUDA_VISIBLE_DEVICES"]
        # assertEqual: the assertEquals alias was removed in Python 3.12.
        self.assertEqual(len(env_set_device), 1)

        if preset_devices:
            self.assertIn(env_set_device, preset_devices.split(","))
            self.assertEqual(local_device, "0")
        else:
            self.assertEqual(local_device, env_set_device)
Пример #4
0
 def _testNotInitialized(self, init_mock):
     """Run ``_try_reserve_and_set_cuda`` on a mock runner and check the result.

     Asserts that ``_set_cuda_device`` was called with ``"0"`` and that
     ``CUDA_VISIBLE_DEVICES`` holds a one-character value afterwards.

     Args:
         init_mock: Not used in this body. Presumably the patched
             ``torch.cuda.is_initialized`` supplied by the caller;
             TODO confirm against the call sites.
     """
     mock_runner = MagicMock()
     mock_runner._set_cuda_device = MagicMock()
     # Call through the class so the MagicMock stands in for self.
     LocalDistributedRunner._try_reserve_and_set_cuda(mock_runner)
     mock_runner._set_cuda_device.assert_called_with("0")
     # assertEqual: the assertEquals alias was removed in Python 3.12.
     self.assertEqual(len(os.environ["CUDA_VISIBLE_DEVICES"]), 1)