Exemplo n.º 1
0
    def test_cg_prod_to_device(self, device1, device2):
        """``CGProduct.to(device=...)`` should relocate the whole product.

        After the move, the product itself, its ``cg_dict`` container and
        every value stored in that container must all report ``device2``.
        """
        product = CGProduct(maxl=1, device=device1)
        product.to(device=device2)

        assert product.device == device2

        cg_dict = product.cg_dict
        assert cg_dict.device == device2

        value_devices = [value.device for value in cg_dict.values()]
        assert all(dev == device2 for dev in value_devices)
Exemplo n.º 2
0
    def test_cg_prod_to(self, dtype1, dtype2, device1, device2):
        """Check that ``.to(device, dtype)`` both moves and casts the product.

        The product itself, its ``cg_dict`` container and every tensor stored
        in the container must report ``dtype2`` and ``device2`` afterwards.
        """
        cg_prod = CGProduct(maxl=1, dtype=dtype1, device=device1)

        cg_prod.to(device2, dtype2)
        assert cg_prod.dtype == dtype2
        assert cg_prod.cg_dict.dtype == dtype2
        assert all(t.dtype == dtype2 for t in cg_prod.cg_dict.values())
        assert cg_prod.device == device2
        assert cg_prod.cg_dict.device == device2
        assert all(t.device == device2 for t in cg_prod.cg_dict.values())

    # The tests below are defined at class level (not nested inside another
    # test) so that pytest actually collects and runs them.
    # NOTE(review): ``maxl`` and ``devices`` are assumed to be provided by a
    # fixture / module-level definition outside this view -- confirm.

    @pytest.mark.parametrize('dtype',
                             [None, torch.half, torch.float, torch.double])
    def test_cg_prod_half(self, maxl, dtype):
        """Check that ``.half()`` casts the product to ``torch.half``."""
        cg_prod = CGProduct(maxl=maxl, dtype=dtype)
        cg_prod.half()
        assert cg_prod.dtype == torch.half
        assert cg_prod.cg_dict.dtype == torch.half
        # Compare each tensor's *dtype*; a device can never equal a dtype.
        assert all(t.dtype == torch.half for t in cg_prod.cg_dict.values())

    @pytest.mark.parametrize('dtype',
                             [None, torch.half, torch.float, torch.double])
    def test_cg_prod_float(self, maxl, dtype):
        """Check that ``.float()`` casts the product to ``torch.float``."""
        cg_prod = CGProduct(maxl=maxl, dtype=dtype)
        cg_prod.float()
        assert cg_prod.dtype == torch.float
        assert cg_prod.cg_dict.dtype == torch.float
        assert all(t.dtype == torch.float for t in cg_prod.cg_dict.values())

    @pytest.mark.parametrize('dtype',
                             [None, torch.half, torch.float, torch.double])
    def test_cg_prod_double(self, maxl, dtype):
        """Check that ``.double()`` casts the product to ``torch.double``."""
        cg_prod = CGProduct(maxl=maxl, dtype=dtype)
        cg_prod.double()
        assert cg_prod.dtype == torch.double
        assert cg_prod.cg_dict.dtype == torch.double
        assert all(t.dtype == torch.double for t in cg_prod.cg_dict.values())

    @pytest.mark.parametrize('device', devices)
    def test_cg_prod_cpu(self, maxl, device):
        """Check that ``.cpu()`` moves the product to the CPU."""
        cg_prod = CGProduct(maxl=maxl, device=device)
        cg_prod.cpu()
        cpu = torch.device('cpu')
        assert cg_prod.device == cpu
        assert cg_prod.cg_dict.device == cpu
        assert all(t.device == cpu for t in cg_prod.cg_dict.values())

    @pytest.mark.parametrize('device', devices)
    def test_cg_prod_cuda(self, maxl, device):
        """Check that ``.cuda()`` moves the product to a CUDA device."""
        if not torch.cuda.is_available():
            # Report as skipped rather than silently passing.
            pytest.skip('CUDA is not available')

        cg_prod = CGProduct(maxl=maxl, device=device)
        cg_prod.cuda()
        # Compare device *types*: tensors moved to CUDA report an indexed
        # device such as ``cuda:0``, which is not equal to
        # ``torch.device('cuda')``.
        assert cg_prod.device.type == 'cuda'
        assert cg_prod.cg_dict.device.type == 'cuda'
        assert all(t.device.type == 'cuda' for t in cg_prod.cg_dict.values())