class TestHelperFunctions(common_utils.TestCase):
    """Checks when ``symbolic_helper.check_training_mode`` emits a warning.

    Every test overwrites ``GLOBALS.training_mode``; the value seen at the
    start of the test is saved in ``setUp`` and written back in ``tearDown``.
    """

    def setUp(self):
        super().setUp()
        # Snapshot the process-wide export mode before the test mutates it.
        self._initial_training_mode = GLOBALS.training_mode

    def tearDown(self):
        # Put the global back first, then let the base class run its own
        # teardown (mirrors the super().setUp() call above).
        GLOBALS.training_mode = self._initial_training_mode
        super().tearDown()

    @common_utils.parametrize(
        "op_train_mode,export_mode",
        [
            common_utils.subtest(
                [1, torch.onnx.TrainingMode.PRESERVE], name="export_mode_is_preserve"
            ),
            common_utils.subtest(
                [0, torch.onnx.TrainingMode.EVAL],
                name="modes_match_op_train_mode_0_export_mode_eval",
            ),
            common_utils.subtest(
                [1, torch.onnx.TrainingMode.TRAINING],
                name="modes_match_op_train_mode_1_export_mode_training",
            ),
        ],
    )
    def test_check_training_mode_does_not_warn_when(
        self, op_train_mode: int, export_mode: torch.onnx.TrainingMode
    ):
        """No warning is expected for the PRESERVE case or when the modes agree."""
        GLOBALS.training_mode = export_mode
        self.assertNotWarn(
            lambda: symbolic_helper.check_training_mode(op_train_mode, "testop")
        )

    @common_utils.parametrize(
        "op_train_mode,export_mode",
        [
            common_utils.subtest(
                [0, torch.onnx.TrainingMode.TRAINING],
                name="modes_do_not_match_op_train_mode_0_export_mode_training",
            ),
            common_utils.subtest(
                [1, torch.onnx.TrainingMode.EVAL],
                name="modes_do_not_match_op_train_mode_1_export_mode_eval",
            ),
        ],
    )
    def test_check_training_mode_warns_when(
        self,
        op_train_mode: int,
        export_mode: torch.onnx.TrainingMode,
    ):
        """A ``UserWarning`` naming the export mode is expected when the modes differ."""
        # Set the global outside the context manager so that the only code
        # being checked for warnings is the call under test.
        GLOBALS.training_mode = export_mode
        with self.assertWarnsRegex(
            UserWarning, f"ONNX export mode is set to {export_mode}"
        ):
            symbolic_helper.check_training_mode(op_train_mode, "testop")
class TestInput(FSDPTest):
    """FSDP smoke test for non-tensor (list / dict) forward inputs."""

    @property
    def world_size(self):
        # This test always uses a world size of one.
        return 1

    @skip_if_lt_x_gpu(1)
    @parametrize("input_cls",
                 [subtest(dict, name="dict"),
                  subtest(list, name="list")])
    def test_input_type(self, input_cls):
        """Run five forward/backward/step rounds with the batch wrapped in ``input_cls``."""
        class Model(Module):
            """Single linear layer that unwraps its tensor from a list or dict."""

            def __init__(self):
                super().__init__()
                self.layer = Linear(4, 4)

            def forward(self, batch):
                if not isinstance(batch, list):
                    # Anything that is not a list must be a dict keyed by "in".
                    assert isinstance(batch, dict), batch
                    return self.layer(batch["in"])
                return self.layer(batch[0])

        model = FSDP(Model()).cuda()
        optim = SGD(model.parameters(), lr=0.1)

        for _ in range(5):
            sample = torch.rand(64, 4).cuda()
            sample.requires_grad = True
            if input_cls is not list:
                self.assertTrue(input_cls is dict)
                wrapped = {"in": sample}
            else:
                wrapped = [sample]

            loss = model(wrapped).sum()
            loss.backward()
            optim.step()
            optim.zero_grad()
# Example #3
# 0
class TestUtils(TestCase):
    """Tests for the ``_apply_to_tensors`` container-walking helper."""

    @parametrize("devices", [["cpu"], ["cuda"],
                             subtest(["cpu", "cuda"], name="cpu_cuda")])
    def test_apply_to_tensors(self, devices):
        """Every tensor in a nested mixed container is visited and top-level item types are kept.

        ``expected`` accumulates the element count of each tensor as it is
        created; ``total`` accumulates the element count of each tensor the
        callback is handed. The two must match.
        """
        # Imported locally so this test does not depend on a module-level import.
        from collections import OrderedDict

        if "cuda" in devices and (not torch.cuda.is_available()
                                  or torch.cuda.device_count() < 1):
            raise unittest.SkipTest("Skipped due to lack of GPU")

        expected = 0

        def get_a_tensor():
            """Return a random tensor on random device."""
            nonlocal expected
            dev = random.choice(devices)
            # (1,) is spelled as a real tuple; torch.rand yields the same
            # one-element tensor as it does for the bare int 1.
            shape = random.choice(((1,), (2, 3), (4, 5, 6), (7, 8, 9, 10)))
            t = torch.rand(shape).to(dev)
            expected += t.numel()
            return t

        # create a mixed bag of data.
        data = [1, "str"]
        data.append({
            "key1": get_a_tensor(),
            "key2": {
                1: get_a_tensor()
            },
            "key3": 3
        })
        data.insert(0, {"x", get_a_tensor(), get_a_tensor()})
        data.append(([1], get_a_tensor(), 1, [get_a_tensor()], {1, 2}))
        # Use a real OrderedDict so the type check below covers it as a
        # distinct container type rather than as a plain dict.
        od = OrderedDict()
        od["k"] = "value"
        data.append(od)

        total = 0

        def fn(t):
            nonlocal total
            total += t.numel()
            return t

        new_data = _apply_to_tensors(fn, data)
        self.assertEqual(total, expected)
        # Same number of top-level items, each with its original container type.
        self.assertEqual(len(new_data), len(data))
        for new_item, old_item in zip(new_data, data):
            self.assertIs(type(new_item), type(old_item))
# Example #4
# 0
from torch.nn.modules.lazy import LazyModuleMixin
from torch.testing._internal.common_utils import (
    TestCase, run_tests, parametrize, subtest, instantiate_parametrized_tests)
from torch.testing._internal.common_subclass import subclass_db, DiagTensorBelow
from torch.testing._internal.logging_tensor import LoggingTensor
from torch.utils._pytree import tree_map
from unittest import expectedFailure

# The current test methodology in this file is to test a variety of real use cases
# with a set of fully-fledged tensor subclasses. In the future, this may change
# to more narrowly specify toy subclasses for each of the specific invariants under
# test, avoiding the need to maintain the set of fully-fledged tensor subclasses.

# Decorator that parametrizes a test over every tensor class registered in
# subclass_db; each generated subtest is named after that entry's `name`.
parametrize_tensor_cls = parametrize(
    "tensor_cls",
    [subtest(cls, name=cls_info.name) for cls, cls_info in subclass_db.items()],
)


class TestSubclass(TestCase):
    """Tests run against each tensor subclass registered in ``subclass_db``."""

    def _create_tensor(self, tensor_cls):
        """Build a ``tensor_cls`` instance by calling its registered ``create_fn`` with 3."""
        return subclass_db[tensor_cls].create_fn(3)

    @parametrize_tensor_cls
    @parametrize("tensor_requires_grad", [False, True])
    def test_param_invariants(self, tensor_cls, tensor_requires_grad):
        """Wrapping a subclass tensor in ``nn.Parameter`` honors the constructor flag.

        The Parameter is deliberately built with the opposite ``requires_grad``
        of the source tensor so the two settings can be told apart.
        """
        x = self._create_tensor(tensor_cls).requires_grad_(
            tensor_requires_grad)
        param = nn.Parameter(x, requires_grad=(not tensor_requires_grad))

        self.assertIsInstance(param, nn.Parameter)
        # The flag passed to the Parameter constructor must win over the
        # flag already set on the wrapped tensor.
        self.assertEqual(param.requires_grad, not tensor_requires_grad)

        # Constructing the Parameter must not alter the source tensor.
        self.assertNotIsInstance(x, nn.Parameter)
        self.assertEqual(x.requires_grad, tensor_requires_grad)
# Example #5
# 0
class TestPytree(TestCase):
    """Tests for pytree flatten/unflatten, ``tree_map``, spec equality/repr and broadcasting."""

    def test_treespec_equality(self):
        """Specs compare equal only when container type and children match."""
        # ``==`` / ``!=`` are written out explicitly so that both comparison
        # operators of the spec classes are exercised.
        self.assertTrue(LeafSpec() == LeafSpec())
        self.assertTrue(TreeSpec(list, None, []) == TreeSpec(list, None, []))
        self.assertTrue(TreeSpec(list, None, [LeafSpec()]) == TreeSpec(list, None, [LeafSpec()]))
        self.assertFalse(TreeSpec(tuple, None, []) == TreeSpec(list, None, []))
        self.assertTrue(TreeSpec(tuple, None, []) != TreeSpec(list, None, []))

    def test_flatten_unflatten_leaf(self):
        """A non-container value flattens to ``[value]`` with a ``LeafSpec`` and round-trips."""
        def run_test_with_leaf(leaf):
            values, treespec = tree_flatten(leaf)
            self.assertEqual(values, [leaf])
            self.assertEqual(treespec, LeafSpec())

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, leaf)

        run_test_with_leaf(1)
        run_test_with_leaf(1.)
        run_test_with_leaf(None)
        run_test_with_leaf(bool)
        run_test_with_leaf(torch.randn(3, 3))

    def test_flatten_unflatten_list(self):
        """Flat lists flatten to their elements and unflatten back to a list."""
        def run_test(lst):
            expected_spec = TreeSpec(list, None, [LeafSpec() for _ in lst])
            values, treespec = tree_flatten(lst)
            self.assertIsInstance(values, list)
            self.assertEqual(values, lst)
            self.assertEqual(treespec, expected_spec)

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, lst)
            self.assertIsInstance(unflattened, list)

        run_test([])
        run_test([1., 2])
        run_test([torch.tensor([1., 2]), 2, 10, 9, 11])

    def test_flatten_unflatten_tuple(self):
        """Flat tuples flatten to a list of their elements and unflatten back to a tuple."""
        def run_test(tup):
            expected_spec = TreeSpec(tuple, None, [LeafSpec() for _ in tup])
            values, treespec = tree_flatten(tup)
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(tup))
            self.assertEqual(treespec, expected_spec)

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, tup)
            self.assertIsInstance(unflattened, tuple)

        run_test(())
        run_test((1.,))
        run_test((1., 2))
        run_test((torch.tensor([1., 2]), 2, 10, 9, 11))

    def test_flatten_unflatten_namedtuple(self):
        """Namedtuples round-trip and come back as the same namedtuple class."""
        Point = namedtuple('Point', ['x', 'y'])

        def run_test(tup):
            expected_spec = TreeSpec(namedtuple, Point, [LeafSpec() for _ in tup])
            values, treespec = tree_flatten(tup)
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(tup))
            self.assertEqual(treespec, expected_spec)

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, tup)
            self.assertIsInstance(unflattened, Point)

        run_test(Point(1., 2))
        run_test(Point(torch.tensor(1.), 2))

    @parametrize("op", [
        subtest(torch.max, name='max'),
        subtest(torch.min, name='min'),
    ])
    def test_flatten_unflatten_return_type(self, op):
        """The structured result of ``op(x, dim=0)`` flattens to tensors and round-trips."""
        x = torch.randn(3, 3)
        expected = op(x, dim=0)

        values, spec = tree_flatten(expected)
        # Check that values is actually List[Tensor] and not (ReturnType(...),)
        for value in values:
            self.assertIsInstance(value, torch.Tensor)
        result = tree_unflatten(values, spec)

        self.assertEqual(type(result), type(expected))
        self.assertEqual(result, expected)

    def test_flatten_unflatten_dict(self):
        """Flat dicts flatten to their values, with the keys carried in the spec context."""
        def run_test(dct):
            expected_spec = TreeSpec(dict, list(dct.keys()),
                                     [LeafSpec() for _ in dct.values()])
            values, treespec = tree_flatten(dct)
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(dct.values()))
            self.assertEqual(treespec, expected_spec)

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, dct)
            self.assertIsInstance(unflattened, dict)

        run_test({})
        run_test({'a': 1})
        run_test({'abcdefg': torch.randn(2, 3)})
        run_test({1: torch.randn(2, 3)})
        run_test({'a': 1, 'b': 2, 'c': torch.randn(2, 3)})

    def test_flatten_unflatten_nested(self):
        """Nested containers round-trip and report a matching ``num_leaves``."""
        def run_test(pytree):
            values, treespec = tree_flatten(pytree)
            self.assertIsInstance(values, list)
            self.assertEqual(len(values), treespec.num_leaves)

            # NB: python basic data structures (dict list tuple) all have
            # contents equality defined on them, so the following works for them.
            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, pytree)

        cases = [
            [()],
            ([],),
            {'a': ()},
            {'a': 0, 'b': [{'c': 1}]},
            {'a': 0, 'b': [1, {'c': 2}, torch.randn(3)], 'c': (torch.randn(2, 3), 1)},
        ]
        # Actually run the helper on every case; without this loop the test
        # would pass without checking anything.
        for case in cases:
            run_test(case)

    def test_treemap(self):
        """``tree_map`` applies a function to every leaf and keeps the structure."""
        def run_test(pytree):
            def f(x):
                return x * 3

            # Mapping over the flattened leaves and flattening the mapped
            # tree must visit the same values.
            sm1 = sum(map(f, tree_flatten(pytree)[0]))
            sm2 = sum(tree_flatten(tree_map(f, pytree))[0])
            self.assertEqual(sm1, sm2)

            def invf(x):
                return x // 3

            # Applying f and then its inverse must give back the original tree.
            self.assertEqual(tree_map(invf, tree_map(f, pytree)), pytree)

        # The cases and the driving loop live at method level (not inside
        # run_test) so that run_test is really invoked.
        cases = [
            [()],
            ([],),
            {'a': ()},
            {'a': 1, 'b': [{'c': 2}]},
            {'a': 0, 'b': [2, {'c': 3}, 4], 'c': (5, 6)},
        ]
        for case in cases:
            run_test(case)

    def test_treespec_repr(self):
        """``repr`` of a nested spec prints leaves as ``*``."""
        # Check that it looks sane
        pytree = (0, [0, 0, 0])
        _, spec = tree_flatten(pytree)
        self.assertEqual(
            repr(spec), 'TreeSpec(tuple, None, [*, TreeSpec(list, None, [*, *, *])])')

    def test_broadcast_to_and_flatten(self):
        """``_broadcast_to_and_flatten`` returns the expected flat list for each case.

        Each case is ``(pytree, to_pytree, expected)``; ``expected`` is ``None``
        for the structure-mismatch cases.
        """
        cases = [
            (1, (), []),

            # Same (flat) structures
            ((1,), (0,), [1]),
            ([1], [0], [1]),
            ((1, 2, 3), (0, 0, 0), [1, 2, 3]),
            ({'a': 1, 'b': 2}, {'a': 0, 'b': 0}, [1, 2]),

            # Mismatched (flat) structures
            ([1], (0,), None),
            ((1,), [0], None),
            ((1, 2, 3), (0, 0), None),
            ({'a': 1, 'b': 2}, {'a': 0}, None),
            ({'a': 1, 'b': 2}, {'a': 0, 'c': 0}, None),
            ({'a': 1, 'b': 2}, {'a': 0, 'b': 0, 'c': 0}, None),

            # Same (nested) structures
            ((1, [2, 3]), (0, [0, 0]), [1, 2, 3]),
            ((1, [(2, 3), 4]), (0, [(0, 0), 0]), [1, 2, 3, 4]),

            # Mismatched (nested) structures
            ((1, [2, 3]), (0, (0, 0)), None),
            ((1, [2, 3]), (0, [0, 0, 0]), None),

            # Broadcasting single value
            (1, (0, 0, 0), [1, 1, 1]),
            (1, [0, 0, 0], [1, 1, 1]),
            (1, {'a': 0, 'b': 0}, [1, 1]),
            (1, (0, [0, [0]], 0), [1, 1, 1, 1]),
            (1, (0, [0, [0, [], [[[0]]]]], 0), [1, 1, 1, 1, 1]),

            # Broadcast multiple things
            ((1, 2), ([0, 0, 0], [0, 0]), [1, 1, 1, 2, 2]),
            ((1, 2), ([0, [0, 0], 0], [0, 0]), [1, 1, 1, 1, 2, 2]),
            (([1, 2, 3], 4), ([0, [0, 0], 0], [0, 0]), [1, 2, 2, 3, 4, 4]),
        ]
        for pytree, to_pytree, expected in cases:
            _, to_spec = tree_flatten(to_pytree)
            result = _broadcast_to_and_flatten(pytree, to_spec)
            self.assertEqual(result, expected, msg=str([pytree, to_spec, expected]))
# Example #6
# 0
class TestUtils(TestCase):
    """Tests for the ``_apply_to_tensors`` and ``_replace_by_prefix`` helpers."""

    @parametrize("devices", [["cpu"], ["cuda"],
                             subtest(["cpu", "cuda"], name="cpu_cuda")])
    def test_apply_to_tensors(self, devices):
        """Every tensor in a nested mixed container is visited and top-level item types are kept.

        ``expected`` accumulates the element count of each tensor as it is
        created; ``total`` accumulates the element count of each tensor the
        callback is handed. The two must match.
        """
        if "cuda" in devices and (not torch.cuda.is_available()
                                  or torch.cuda.device_count() < 1):
            raise unittest.SkipTest("Skipped due to lack of GPU")

        expected = 0

        def get_a_tensor():
            """Return a random tensor on random device."""
            nonlocal expected
            dev = random.choice(devices)
            # (1,) is spelled as a real tuple; torch.rand yields the same
            # one-element tensor as it does for the bare int 1.
            shape = random.choice(((1,), (2, 3), (4, 5, 6), (7, 8, 9, 10)))
            t = torch.rand(shape).to(dev)
            expected += t.numel()
            return t

        # create a mixed bag of data.
        data = [1, "str"]
        data.append({
            "key1": get_a_tensor(),
            "key2": {
                1: get_a_tensor()
            },
            "key3": 3
        })
        data.insert(0, {"x", get_a_tensor(), get_a_tensor()})
        data.append(([1], get_a_tensor(), 1, [get_a_tensor()], {1, 2}))
        od = OrderedDict()
        od["k"] = "value"
        data.append(od)

        total = 0

        def fn(t):
            nonlocal total
            total += t.numel()
            return t

        new_data = _apply_to_tensors(fn, data)
        self.assertEqual(total, expected)
        # Same number of top-level items, each with its original container type.
        self.assertEqual(len(new_data), len(data))
        for new_item, old_item in zip(new_data, data):
            self.assertIs(type(new_item), type(old_item))

    def test_replace_by_prefix(self):
        """Keys starting with the old prefix are renamed in place; other keys are untouched."""
        state_dict = {
            "layer.a": torch.tensor(1),
            "abc.layer.def": torch.tensor(2),
            "layer.b": torch.tensor(3),
        }
        original_state_dict = state_dict.copy()
        _replace_by_prefix(state_dict, "layer.", "module.layer.")
        # assertEqual rather than a bare ``assert``: it still runs under
        # ``python -O`` and reports a readable diff on failure.
        self.assertEqual(
            state_dict,
            {
                "module.layer.a": torch.tensor(1),
                "abc.layer.def": torch.tensor(2),
                "module.layer.b": torch.tensor(3),
            },
        )
        # Applying the reverse rename must restore the original mapping.
        _replace_by_prefix(state_dict, "module.layer.", "layer.")
        self.assertEqual(state_dict, original_state_dict)

    def test_packed_sequence(self):
        """Test to ensure RNN packed sequences are modified correctly."""
        rnn = nn.RNN(5, 5)

        x = torch.rand((5, 1, 5), dtype=torch.float)
        seq_length = torch.tensor([4], dtype=torch.int)

        def fill_fn(t):
            # Zero the tensor in place and hand it back, like the callback in
            # test_apply_to_tensors does.
            # NOTE(review): whether _apply_to_tensors uses the return value for
            # packed sequences is not visible here; returning it is harmless.
            return t.fill_(0)

        x = nn.utils.rnn.pack_padded_sequence(x, seq_length)
        x, _ = rnn(x)
        x = _apply_to_tensors(fill_fn, x)
        x, _ = nn.utils.rnn.pad_packed_sequence(x)
        self.assertEqual(torch.sum(x), 0)