Example #1
    def test_multiple_device_transfer(self, device, dtype, module_info):
        module_cls = module_info.module_cls
        module_inputs_device = module_info.module_inputs_func(
            module_info, device=device, dtype=dtype, requires_grad=False)
        module_inputs_cpu = module_info.module_inputs_func(module_info,
                                                           device="cpu",
                                                           dtype=dtype,
                                                           requires_grad=False)
        for module_input_device, module_input_cpu in zip(
                module_inputs_device, module_inputs_cpu):
            if module_input_device.forward_input is None:
                continue

            with freeze_rng_state():
                # === Instantiate the module. ===
                args, kwargs = module_input_device.constructor_input.args, module_input_device.constructor_input.kwargs
                m = module_cls(*args, **kwargs)
                m.to(device).to(dtype)

                # === Do forward pass on GPU ===
                input_device_args = module_input_device.forward_input.args
                input_device_kwargs = module_input_device.forward_input.kwargs
                m(*input_device_args, **input_device_kwargs)
                self._assert_module_parameters_and_buffer_are(m, device, dtype)

                # === Move to CPU ===
                input_cpu_args = module_input_cpu.forward_input.args
                input_cpu_kwargs = module_input_cpu.forward_input.kwargs
                m.cpu()
                m(*input_cpu_args, **input_cpu_kwargs)
                self._assert_module_parameters_and_buffer_are(m, "cpu", dtype)

                # === Move back to GPU and forward pass ===
                m.cuda()
                m(*input_device_args, **input_device_kwargs)
                self._assert_module_parameters_and_buffer_are(m, device, dtype)

                if torch.cuda.device_count() >= 2:
                    # === Test cross-GPU transfer works ===
                    def _to_device1(objs):
                        if isinstance(objs, (tuple, list)):
                            return type(objs)(_to_device1(item)
                                              for item in objs)
                        elif isinstance(objs, dict):
                            return {
                                name: _to_device1(item)
                                for name, item in objs.items()
                            }
                        elif isinstance(objs, torch.Tensor):
                            return objs.cuda(1)
                        else:
                            return objs

                    input_device_1_args = _to_device1(input_device_args)
                    input_device_1_kwargs = _to_device1(input_device_kwargs)

                    m.cuda(1)
                    with torch.cuda.device(1):
                        m(*input_device_1_args, **input_device_1_kwargs)
                    self._assert_module_parameters_and_buffer_are(
                        m, torch.device("cuda:1"), dtype)
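The `_assert_module_parameters_and_buffer_are` helper used above is defined elsewhere on the test class and is not shown in this listing. A minimal sketch of what such a check might look like, written as a standalone function (the name, signature, and the decision to skip non-floating-point buffers are assumptions, not the real implementation):

import itertools

import torch


def assert_params_and_buffers_are(test_case, module, device, dtype):
    # Confirm every parameter and buffer lives on the expected device
    # with the expected dtype.
    device = torch.device(device)
    for t in itertools.chain(module.parameters(), module.buffers()):
        if device.index is None:
            test_case.assertEqual(t.device.type, device.type)
        else:
            test_case.assertEqual(t.device, device)
        # Integer buffers (e.g. BatchNorm's num_batches_tracked) keep their
        # own dtype, so only floating-point tensors are checked here.
        if t.is_floating_point():
            test_case.assertEqual(t.dtype, dtype)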
Example #2
    def test_non_contiguous_tensors(self, device, dtype, module_info):
        # Check modules work with non-contiguous tensors

        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info,
                                                       device=device,
                                                       dtype=dtype,
                                                       requires_grad=True)

        def _make_non_contiguous(obj):
            def inner_make_non_contiguous(obj):
                # Scalar tensors cannot be made non-contiguous
                if not isinstance(obj, torch.Tensor) or obj.dim() == 0:
                    return obj

                out = torch.repeat_interleave(obj, 2, dim=-1)
                out = out[..., ::2].detach()
                out.requires_grad = obj.requires_grad
                return out

            return self._traverse_obj(obj, inner_make_non_contiguous)

        def _can_be_noncontiguous(obj):
            if isinstance(obj, (tuple, list)):
                return any(_can_be_noncontiguous(o) for o in obj)
            elif isinstance(obj, dict):
                return any(_can_be_noncontiguous(o) for o in obj.values())
            # Scalar tensors cannot be non-contiguous
            if not isinstance(obj, torch.Tensor) or obj.dim() == 0:
                return False
            return True

        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
            if not (_can_be_noncontiguous(input_args)
                    or _can_be_noncontiguous(input_kwargs)):
                continue

            # === Instantiate the module. ===
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            m = module_cls(*args, **kwargs)
            m.to(device).to(dtype)

            self._retain_grad((input_args, input_kwargs))

            # === Forward with default input ===
            with freeze_rng_state():
                default_output = m(*input_args, **input_kwargs)
                grad_output = default_output.clone().detach_().normal_()
                default_output.backward(grad_output, retain_graph=True)

            default_input_args_grad, default_input_kwargs_grad = deepcopy(
                self._get_grads((input_args, input_kwargs)))
            default_param_grad = deepcopy([p.grad for p in m.parameters()])

            # === Construct non-contiguous tensors ===
            nc_input_args, nc_input_kwargs = _make_non_contiguous(
                (input_args, input_kwargs))
            nc_grad_output = _make_non_contiguous(grad_output)

            # === Compare results with non-contiguous and contiguous tensors ===
            inputs = [(input_args, input_kwargs),
                      (nc_input_args, nc_input_kwargs)]
            grads = [grad_output, nc_grad_output]

            for (in_args, in_kwargs), g_out in product(inputs, grads):
                g_out_copy = deepcopy(g_out)
                self._zero_grad((in_args, in_kwargs))
                self._zero_grad(m.parameters())

                with freeze_rng_state():
                    out = m(*in_args, **in_kwargs)
                    out.backward(g_out_copy, retain_graph=True)

                input_args_grad, input_kwargs_grad = self._get_grads(
                    (in_args, in_kwargs))
                self.assertEqual(out, default_output)
                self.assertEqual(input_args_grad,
                                 default_input_args_grad,
                                 atol=1e-4,
                                 rtol=0)
                self.assertEqual(input_kwargs_grad,
                                 default_input_kwargs_grad,
                                 atol=1e-4,
                                 rtol=0)

                param_grad = [p.grad for p in m.parameters()]
                self.assertEqual(param_grad, default_param_grad)
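Examples #2 and #5 both rely on a `self._traverse_obj` helper that is not shown in this listing. A hedged reconstruction of the traversal pattern it appears to implement, as a standalone function (the assumption here is that the callback is applied to tensor leaves only, with containers rebuilt around the results):

import torch


def traverse_obj(obj, func):
    # Recurse through tuples, lists and dicts, apply `func` to each tensor
    # leaf, and rebuild the same container structure on the way back up.
    if isinstance(obj, (tuple, list)):
        return type(obj)(traverse_obj(o, func) for o in obj)
    elif isinstance(obj, dict):
        return {name: traverse_obj(o, func) for name, o in obj.items()}
    elif isinstance(obj, torch.Tensor):
        return func(obj)
    return obj

Example #1's inline `_to_device1` follows the same shape, with the callback fixed to moving each tensor via `tensor.cuda(1)`.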
Example #3
 def runAndSaveRNG(self, func, inputs, kwargs=None):
     kwargs = kwargs if kwargs else {}
     with freeze_rng_state():
         results = func(*inputs, **kwargs)
     return results
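Example #3 wraps an arbitrary call so that it neither perturbs nor depends on the surrounding RNG stream: `freeze_rng_state` (an internal test utility, not a public API) saves the RNG state on entry and restores it on exit. A minimal illustration of that behaviour:

import torch
from torch.testing._internal.common_utils import freeze_rng_state

torch.manual_seed(0)
before = torch.get_rng_state()

with freeze_rng_state():
    a = torch.randn(3)
with freeze_rng_state():
    b = torch.randn(3)

# Both blocks start from the same saved state, so the draws match,
# and the global RNG state is left untouched afterwards.
assert torch.equal(a, b)
assert torch.equal(before, torch.get_rng_state())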
Example #4
 def gen_data():
     with freeze_rng_state():
         return torch.randn(10), torch.randint(10, (20,)), torch.randn(20)
Example #5
    def test_memory_format(self, device, dtype, module_info, training):
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info,
                                                       device=device,
                                                       dtype=dtype,
                                                       requires_grad=False,
                                                       training=training)
        module_memformat_affects_out = module_info.module_memformat_affects_out

        def _get_mem_formats(channels_last=False, channels_last_3d=False):
            if channels_last:
                return ([torch.contiguous_format, torch.channels_last], [
                    torch.preserve_format, torch.contiguous_format,
                    torch.channels_last
                ])
            elif channels_last_3d:
                return ([torch.contiguous_format, torch.channels_last_3d], [
                    torch.preserve_format, torch.contiguous_format,
                    torch.channels_last_3d
                ])
            else:
                return ([torch.contiguous_format],
                        [torch.preserve_format, torch.contiguous_format])

        # Check that at least one Tensor input has dim == n
        def _check_dims(obj, n):
            if isinstance(obj, torch.Tensor):
                return obj.dim() == n
            elif isinstance(obj, (tuple, list)):
                return any(_check_dims(o, n) for o in obj)
            else:
                return False

        # Called after _check_dims, when we know that >= 1 tensor can be converted to mem_format
        def _to_mem_format(mem_format, obj):
            def inner_to_mem_format(obj):
                d = obj.dim()
                if ((mem_format == torch.channels_last and d != 4)
                        or (mem_format == torch.channels_last_3d and d != 5)):
                    return obj
                return obj.to(memory_format=mem_format)

            return self._traverse_obj(obj, inner_to_mem_format)

        def _check_out_mem_format(output, input_mem_format, module_mem_format):
            def inner_check_out_mem_format(output):
                d = output.dim()
                if (d == 4 and ((input_mem_format == torch.channels_last) or
                                (module_mem_format == torch.channels_last
                                 and module_memformat_affects_out))):
                    self.assertTrue(
                        output.is_contiguous(
                            memory_format=torch.channels_last))
                elif (d == 5
                      and ((input_mem_format == torch.channels_last_3d) or
                           (module_mem_format == torch.channels_last_3d
                            and module_memformat_affects_out))):
                    self.assertTrue(
                        output.is_contiguous(
                            memory_format=torch.channels_last_3d))
                else:
                    self.assertTrue(output.is_contiguous())

            return self._traverse_obj(output, inner_check_out_mem_format)

        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            supports_channels_last = _check_dims(
                module_input.forward_input.args, 4)
            supports_channels_last_3d = _check_dims(
                module_input.forward_input.args, 5)
            input_mem_formats, module_mem_formats = _get_mem_formats(
                supports_channels_last, supports_channels_last_3d)

            with freeze_rng_state():
                # === Instantiate the module. ===
                args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

                m = module_cls(*args, **kwargs)
                m.to(device).to(dtype)
                m.train(training)

                # === Get output in (contiguous, contiguous) configuration. ===
                args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
                desired_outputs = m(*args, **kwargs)

                for input_mem_format in input_mem_formats:
                    # === Change memformat of input. ===
                    module_input.forward_input.args = _to_mem_format(
                        input_mem_format, module_input.forward_input.args)
                    module_input.forward_input.kwargs = _to_mem_format(
                        input_mem_format, module_input.forward_input.kwargs)

                    for module_mem_format in module_mem_formats:
                        # === Change memformat of module ===
                        m.to(memory_format=module_mem_format)

                        # === Do forward pass. ===
                        args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
                        outputs = m(*args, **kwargs)

                        # === Compare outputs to (contiguous, contiguous) output. ===
                        if (input_mem_format != torch.contiguous_format
                                or module_mem_format != torch.contiguous_format):
                            self.assertEqual(outputs, desired_outputs)

                        # === Check mem format of output. ===
                        _check_out_mem_format(outputs, input_mem_format,
                                              module_mem_format)
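For reference, the memory-format round trip this test exercises can be seen on a plain 4D tensor (the shape below is arbitrary): converting to channels_last changes the strides, not the values, and `is_contiguous(memory_format=...)` is how the output check above inspects the result.

import torch

x = torch.randn(2, 3, 4, 5)                     # contiguous (NCHW) layout
x_cl = x.to(memory_format=torch.channels_last)  # same values, NHWC strides

assert torch.equal(x, x_cl)
assert x_cl.is_contiguous(memory_format=torch.channels_last)
assert not x_cl.is_contiguous()  # no longer contiguous in the default format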
Example #6
 def test_dropout(self, device, dtype):
     # edge case: empty nested tensor
     nt0 = torch.nested_tensor([])
     y = torch.nn.functional.dropout(nt0, 0.5)
     self.nt_equal(nt0, y)
     # normal nested tensor
     ntensors = 4
     nt = self.random_nt(device, dtype, ntensors, (4, 4))
     # edge case: invalid dropout
     self.assertRaises(ValueError, lambda: torch.nn.Dropout(-0.1))
     self.assertRaises(ValueError, lambda: torch.nn.Dropout(1.1))
     self.assertRaises(ValueError,
                       lambda: torch.nn.functional.dropout(nt, -0.1))
     self.assertRaises(ValueError,
                       lambda: torch.nn.functional.dropout(nt, 1.1))
     # edge case: no dropout
     dropouter = torch.nn.Dropout(0.0)
     y0 = dropouter(nt)
     y1 = torch.nn.functional.dropout(nt, 0.0)
     self.nt_equal(nt, y0)
     self.nt_equal(nt, y1)
     # edge case: all dropout
     dropouter = torch.nn.Dropout(1.0)
     y0 = dropouter(nt)
     y1 = torch.nn.functional.dropout(nt, 1.0)
     nt0 = nt.clone()
     for i in range(ntensors):
         nt0[i].fill_(0.0)
     self.nt_equal(nt0, y0)
     self.nt_equal(nt0, y1)
     # normal case: normal dropout
     p = 0.2
     y = torch.nn.functional.dropout(nt, p)
     expect = nt.clone()
     for i in range(ntensors):
         actual_tensor = y[i].view(-1)
         expect_tensor = expect[i].view(-1)
         for j in range(actual_tensor.shape[0]):
             if actual_tensor[j].item() == 0.0:
                 expect_tensor[j] = 0.0
             else:
                 expect_tensor[j] /= 1.0 - p
     self.nt_equal(y, expect)
     with freeze_rng_state():
         dropouter = torch.nn.Dropout(p)
         y0 = dropouter(nt)
     with freeze_rng_state():
         y1 = torch.nn.functional.dropout(nt, p)
     self.nt_equal(y0, y1)
     # inplace
     # in principle, since we have established the correctness of functional, we could simply compare inplace vs functional
     # in practice, cuda functional has its own implementation to skip `bernoulli_`
     # so cuda functional will differ from cuda inplace causing test failure
     # in `test_dropout_cuda_float64 (__main__.TestNestedTensorDeviceTypeCUDA)`
     # on `linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 2, 4, linux.4xlarge.nvidia.gpu)`
     expect = nt.clone()
     torch.nn.functional.dropout(nt, p, inplace=True)
     for i in range(ntensors):
         actual_tensor = nt[i].view(-1)
         expect_tensor = expect[i].view(-1)
         for j in range(actual_tensor.shape[0]):
             if actual_tensor[j].item() == 0.0:
                 expect_tensor[j] = 0.0
             else:
                 expect_tensor[j] /= 1.0 - p
     self.nt_equal(nt, expect)
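The element-by-element loops above can also be written per component; a hedged sketch of the same check, reusing `nt`, `p`, and `ntensors` from the test body (and ignoring the measure-zero case of an input element that is exactly 0.0):

y = torch.nn.functional.dropout(nt, p)
for i in range(ntensors):
    out = y[i].view(-1)
    component = nt[i].view(-1)
    kept = out != 0
    # Dropped positions are exactly zero; kept ones are rescaled by 1 / (1 - p).
    torch.testing.assert_close(out[kept], component[kept] / (1.0 - p))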
Example #7
    def test_scaled_dot_product_attention(self, device, input_dim,
                                          attn_mask_dim, is_causal, dropout_p):
        # TODO: Support cross-device / dtype testing properly when instantiate_device_type_tests() is used.
        dtypes = [torch.double, torch.float]
        for dtype in dtypes:

            def rand_tensor(*shape):
                return torch.randn(shape, device=device, dtype=dtype)

            # This test compares python and C++ implementations of SDP.
            N, N_prime, L, S, E = 5, 2, 4, 3, 6
            if input_dim == 3:
                query = rand_tensor(N, L, E)
                key = rand_tensor(N, S, E)
                value = rand_tensor(N, S, E)
            elif input_dim == 4:
                query = rand_tensor(N, N_prime, L, E)
                key = rand_tensor(N, N_prime, S, E)
                value = rand_tensor(N, N_prime, S, E)
            else:
                self.fail(
                    f'Invalid input_dim {input_dim} encountered in SDP test')

            attn_mask = None
            if attn_mask_dim is not None:
                assert attn_mask_dim in [2, input_dim]
                mask_size = (L, S) if attn_mask_dim == 2 else ((
                    N, L, S) if input_dim == 3 else (N, N_prime, L, S))
                attn_mask = (torch.ones(
                    mask_size, device=device,
                    dtype=torch.bool).tril() if is_causal else torch.randint(
                        0, 2, size=mask_size, device=device, dtype=torch.bool))

            with freeze_rng_state():
                # Python impl only supports float mask and 3D inputs.
                attn_mask_float = attn_mask
                if attn_mask_float is not None:
                    attn_mask_float = torch.zeros_like(attn_mask,
                                                       dtype=query.dtype)
                    attn_mask_float.masked_fill_(attn_mask.logical_not(),
                                                 float("-inf"))
                q, k, v = query.view(-1, L,
                                     E), key.view(-1, S,
                                                  E), value.view(-1, S, E)
                a = attn_mask_float
                if a is not None and attn_mask_dim > 3:
                    a = a.view(-1, L, S)
                expected = F._scaled_dot_product_attention(q,
                                                           k,
                                                           v,
                                                           attn_mask=a,
                                                           dropout_p=dropout_p)
                if input_dim > 3:
                    expected = (expected[0].view(-1, N_prime, L, E),
                                expected[1].view(-1, N_prime, L, S))

            need_attn_weights: bool = True
            with freeze_rng_state():
                if is_causal:
                    # NB: Don't pass attn_mask here
                    actual = torch.ops.aten._scaled_dot_product_attention(
                        query, key, value, None, dropout_p, need_attn_weights,
                        is_causal)

                    # Error case: both explicit attn_mask and is_causal are set
                    with self.assertRaisesRegex(
                            RuntimeError,
                            "Explicit attn_mask should not be set when is_causal=True"
                    ):
                        torch.ops.aten._scaled_dot_product_attention(
                            query, key, value, attn_mask, dropout_p,
                            need_attn_weights, is_causal)
                else:
                    actual = torch.ops.aten._scaled_dot_product_attention(
                        query, key, value, attn_mask, dropout_p,
                        need_attn_weights, is_causal)

            # freeze_rng_state() doesn't seem to work outside of CPU, so dropout makes the results incomparable.
            # TODO: Do this skipping in a nicer way once the granular test skipping logic lands.
            if dropout_p == 0.0 or device == 'cpu':
                self.assertEqual(actual, expected)