예제 #1
0
파일: basic_test.py 프로젝트: lehr-fa/heat
    def setUpClass(cls):
        """
        Select the devices used by the tests from the environment variable 'HEAT_TEST_USE_DEVICE'.

        The selection is stored on the class: ``cls.device`` is the default device,
        ``cls.other_device`` is an alternative device (it differs from the default one
        whenever CUDA is available, otherwise both are the CPU) and ``cls.envar`` is
        the value of the environment variable.

        Supported values
            - cpu: Use the CPU as default device (default)
            - gpu: Use the GPU as default device; requires CUDA to be available

        Raises
        ------
        RuntimeError
            If the value of 'HEAT_TEST_USE_DEVICE' is not recognized, or if it is
            'gpu' while CUDA is not available.

        """
        requested = os.getenv("HEAT_TEST_USE_DEVICE", "cpu")

        if requested == "cpu":
            ht.use_device("cpu")
            primary = ht.cpu
            if torch.cuda.is_available():
                # a GPU exists: make it the current CUDA device and offer it as the alternative
                torch.cuda.set_device(torch.device(ht.gpu.torch_device))
                secondary = ht.gpu
            else:
                secondary = ht.cpu
        elif requested == "gpu" and torch.cuda.is_available():
            ht.use_device("gpu")
            torch.cuda.set_device(torch.device(ht.gpu.torch_device))
            primary, secondary = ht.gpu, ht.cpu
        else:
            message = "Value '{}' of environment variable 'HEAT_TEST_USE_DEVICE' is unsupported"
            raise RuntimeError(message.format(requested))

        cls.device = primary
        cls.other_device = secondary
        cls.envar = requested
예제 #2
0
    def test_set_default_device_cpu(self):
        # Selecting the CPU by name, by device object, and then passing None must
        # each leave ht.cpu as the reported default device.
        for selector in ("cpu", ht.cpu, None):
            ht.use_device(selector)
            self.assertIs(ht.get_device(), ht.cpu)

        # An unknown device name and a non-device value must both be rejected.
        for invalid in ("fpu", 1):
            with self.assertRaises(ValueError):
                ht.use_device(invalid)
예제 #3
0
    def test_set_default_device_gpu(self):
        # The GPU selections are only exercised when CUDA is available.
        if ht.torch.cuda.is_available():
            # select by name first, then by device object, then pass None;
            # the GPU must stay the reported default device throughout
            ht.use_device("gpu")
            self.assertIs(ht.get_device(), ht.gpu)
            for selector in (ht.gpu, None):
                ht.use_device(selector)
                self.assertIs(ht.get_device(), ht.gpu)

        # Invalid selections must be rejected regardless of CUDA availability.
        for invalid in ("fpu", 1):
            with self.assertRaises(ValueError):
                ht.use_device(invalid)
예제 #4
0
    def test_sanitize_device(self):
        # The expected device follows the 'DEVICE' environment variable: the GPU
        # when it is set to "gpu", the CPU otherwise. Names are checked in several
        # spellings (mixed case, surrounding whitespace), along with the device
        # object itself and None.
        if os.environ.get("DEVICE") == "gpu":
            ht.use_device(os.environ.get("DEVICE"))
            self.assertIs(ht.sanitize_device("gpu"), ht.gpu)
            self.assertIs(ht.sanitize_device("gPu"), ht.gpu)
            self.assertIs(ht.sanitize_device("  GPU  "), ht.gpu)
            self.assertIs(ht.sanitize_device(ht.gpu), ht.gpu)
            self.assertIs(ht.sanitize_device(None), ht.gpu)
        else:
            self.assertIs(ht.sanitize_device("cpu"), ht.cpu)
            self.assertIs(ht.sanitize_device("cPu"), ht.cpu)
            self.assertIs(ht.sanitize_device("  CPU  "), ht.cpu)
            self.assertIs(ht.sanitize_device(ht.cpu), ht.cpu)
            self.assertIs(ht.sanitize_device(None), ht.cpu)

        # Invalid inputs must raise ValueError. The calls are made bare: an identity
        # assertion around them could never run when the test passes and would only
        # suggest that a return value is being checked.
        with self.assertRaises(ValueError):
            ht.sanitize_device("fpu")
        with self.assertRaises(ValueError):
            ht.sanitize_device(1)
예제 #5
0
파일: test_logical.py 프로젝트: melven/heat
import torch
import unittest
import os
import heat as ht

# Choose heat's default device from the 'DEVICE' environment variable: the GPU is
# used only when it is requested and CUDA is available, the CPU in every other case.
if os.environ.get("DEVICE") != "gpu" or not torch.cuda.is_available():
    ht.use_device("cpu")
else:
    ht.use_device("gpu")
    torch.cuda.set_device(torch.device(ht.get_device().torch_device))

# torch device of the default heat device; ht_device stays None unless "lgpu" is requested
# (presumably None makes heat fall back to its default device — confirm against heat's factories)
device = ht.get_device().torch_device
ht_device = None

# "lgpu" mode: the default device was set to the CPU above, while the GPU is
# handed to the tests explicitly through ht_device / device
if os.environ.get("DEVICE") == "lgpu" and torch.cuda.is_available():
    ht_device = ht.gpu
    device = ht_device.torch_device
    torch.cuda.set_device(device)


class TestLogical(unittest.TestCase):
    """Tests for heat's logical operations (only the ``all`` check is visible in this excerpt)."""

    def test_all(self):
        # Checks the result of calling .all() without arguments on a comparison of a 1-D array of ones.
        array_len = 9

        # check all over all float elements of 1d tensor locally
        # (device=ht_device is None unless the module-level setup selected the GPU explicitly)
        ones_noaxis = ht.ones(array_len, device=ht_device)
        x = (ones_noaxis == 1).all()

        # the result is expected to be a boolean DNDarray of shape (1,), globally and locally
        self.assertIsInstance(x, ht.DNDarray)
        self.assertEqual(x.shape, (1, ))
        self.assertEqual(x.lshape, (1, ))
        self.assertEqual(x.dtype, ht.bool)
        # _DNDarray__array is the name-mangled private attribute `__array` of DNDarray
        # (presumably the underlying torch tensor — confirm against DNDarray)
        self.assertEqual(x._DNDarray__array.dtype, torch.bool)
예제 #6
0
    def test_set_default_device(self):
        # The device under test follows the 'DEVICE' environment variable: the GPU
        # when it is set to "gpu", the CPU otherwise. It is selected by name first.
        if os.environ.get("DEVICE") == "gpu":
            ht.use_device("gpu")
            expected = ht.gpu
        else:
            ht.use_device("cpu")
            expected = ht.cpu
        self.assertIs(ht.get_device(), expected)

        # Selecting by device object and then passing None must keep the same default device.
        for selector in (expected, None):
            ht.use_device(selector)
            self.assertIs(ht.get_device(), expected)

        # An unknown device name and a non-device value must both be rejected.
        for invalid in ("fpu", 1):
            with self.assertRaises(ValueError):
                ht.use_device(invalid)
예제 #7
0
 def test_get_default_device(self):
     # Without a GPU request in 'DEVICE', the reported default device must be the CPU.
     requested = os.environ.get("DEVICE")
     if requested != "gpu":
         self.assertIs(ht.get_device(), ht.cpu)
         return

     # With DEVICE=gpu, the GPU is selected first and must then be reported as default.
     ht.use_device(requested)
     self.assertIs(ht.get_device(), ht.gpu)
예제 #8
0
파일: imagenet.py 프로젝트: lehr-fa/heat
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

sys.path.append("../../")
import heat as ht

# Select the CPU as heat's default device.
# NOTE(review): the original remark stated that available GPUs are detected
# automatically — confirm against ht.use_device before relying on it.
ht.use_device("cpu")

# Sorted names of the lowercase, non-dunder, callable entries of torchvision.models;
# they are the accepted values of the --arch option.
model_names = sorted(
    name
    for name, entry in vars(models).items()
    if name.islower() and not name.startswith("__") and callable(entry)
)

parser = argparse.ArgumentParser(description="PyTorch ImageNet Training")
# positional argument: dataset directory
parser.add_argument("data", help="path to dataset", metavar="DIR")
parser.add_argument(
    "-a",
    "--arch",
    metavar="ARCH",
    default="resnet18",
    choices=model_names,
    help="model architecture: " + " | ".join(model_names) +
    " (default: resnet18)",
예제 #9
0
 def setUp(self):
     """Switch heat's default device to the CPU."""
     # the printing tests run on the CPU only; otherwise the expected comparison strings would become messy
     ht.use_device("cpu")
예제 #10
0
 def tearDown(self):
     """Restore the default print options and the default device."""
     # reset the print options back to default after each test run
     ht.set_printoptions(profile="default")
     # reset the default device to self.device
     # (presumably assigned during class setup — not visible in this excerpt)
     ht.use_device(self.device)