def test_multiple_device_transfer(self, device, dtype, module_info):
    """Verify a module keeps working as it is moved between devices.

    For each sample input pair (device / CPU variants of the same sample):
    instantiate the module, run forward on ``device``, move to CPU and run
    forward, move back to the GPU (explicit ``.cuda()``) and run forward
    again, and — when a second GPU is present — transfer to ``cuda:1`` and
    run forward there.  After every move, parameters and buffers are
    asserted to live on the expected device with the expected dtype.

    NOTE(review): the explicit ``m.cuda()`` / ``.cuda(1)`` calls imply this
    test is meant to run only when ``device`` is a CUDA device — presumably
    enforced by a decorator outside this view; confirm at the test class.
    """
    module_cls = module_info.module_cls
    # Same samples materialized twice: once on the target device, once on CPU,
    # so forward inputs match the module's current placement at each step.
    module_inputs_device = module_info.module_inputs_func(
        module_info, device=device, dtype=dtype, requires_grad=False)
    module_inputs_cpu = module_info.module_inputs_func(module_info,
                                                       device="cpu",
                                                       dtype=dtype,
                                                       requires_grad=False)
    for module_input_device, module_input_cpu in zip(
            module_inputs_device, module_inputs_cpu):
        if module_input_device.forward_input is None:
            continue
        # Freeze the RNG for the whole sequence so any stochastic behavior is
        # reproducible across the device moves (state restored on exit).
        with freeze_rng_state():
            # === Instantiate the module. ===
            args, kwargs = module_input_device.constructor_input.args, module_input_device.constructor_input.kwargs
            m = module_cls(*args, **kwargs)
            m.to(device).to(dtype)
            # === Do forward pass on GPU ===
            input_device_args = module_input_device.forward_input.args
            input_device_kwargs = module_input_device.forward_input.kwargs
            m(*input_device_args, **input_device_kwargs)
            self._assert_module_parameters_and_buffer_are(m, device, dtype)
            # === Move to CPU ===
            input_cpu_args = module_input_cpu.forward_input.args
            input_cpu_kwargs = module_input_cpu.forward_input.kwargs
            m.cpu()
            m(*input_cpu_args, **input_cpu_kwargs)
            self._assert_module_parameters_and_buffer_are(m, "cpu", dtype)
            # === Move back to GPU and forward pass ===
            m.cuda()
            m(*input_device_args, **input_device_kwargs)
            self._assert_module_parameters_and_buffer_are(m, device, dtype)
            if torch.cuda.device_count() >= 2:
                # === test cross-GPU transfer works
                # Recursively move every tensor in a nested tuple/list/dict
                # structure to cuda:1, leaving non-tensors untouched.
                def _to_device1(objs):
                    if isinstance(objs, (tuple, list)):
                        return type(objs)(_to_device1(item) for item in objs)
                    elif isinstance(objs, dict):
                        return {
                            name: _to_device1(item)
                            for name, item in objs.items()
                        }
                    elif isinstance(objs, torch.Tensor):
                        return objs.cuda(1)
                    else:
                        return objs

                input_device_1_args = _to_device1(input_device_args)
                input_device_1_kwargs = _to_device1(input_device_kwargs)
                m.cuda(1)
                with torch.cuda.device(1):
                    m(*input_device_1_args, **input_device_1_kwargs)
                self._assert_module_parameters_and_buffer_are(
                    m, torch.device("cuda:1"), dtype)
def test_non_contiguous_tensors(self, device, dtype, module_info):
    """Check modules work with non-contiguous tensors.

    For every sample input that can be made non-contiguous, this compares
    the forward output, input gradients, and parameter gradients computed
    with the original (contiguous) tensors against those computed with
    non-contiguous copies of the same values — for all four combinations
    of (inputs, grad_output) x (contiguous, non-contiguous).
    """
    # Check modules work with non-contiguous tensors
    module_cls = module_info.module_cls
    module_inputs = module_info.module_inputs_func(module_info,
                                                   device=device,
                                                   dtype=dtype,
                                                   requires_grad=True)

    def _make_non_contiguous(obj):
        # Replace every tensor inside `obj` (nested structures traversed by
        # self._traverse_obj) with a same-valued non-contiguous tensor.
        def inner_make_non_contiguous(obj):
            # Scalar tensors can not be made non-contiguous
            if not isinstance(obj, torch.Tensor) or obj.dim() == 0:
                return obj
            # Duplicate each element along the last dim, then take every
            # second element: same values, stride 2 -> non-contiguous.
            out = torch.repeat_interleave(obj, 2, dim=-1)
            out = out[..., ::2].detach()
            out.requires_grad = obj.requires_grad
            return out

        return self._traverse_obj(obj, inner_make_non_contiguous)

    def _can_be_noncontiguous(obj):
        # True if at least one tensor in the (possibly nested) structure can
        # be made non-contiguous.
        if isinstance(obj, (tuple, list)):
            return any(_can_be_noncontiguous(o) for o in obj)
        elif isinstance(obj, dict):
            return any(_can_be_noncontiguous(o) for o in obj.values())
        # scalar tensors can not be non-contiguous
        if not isinstance(obj, torch.Tensor) or obj.dim() == 0:
            return False
        return True

    for module_input in module_inputs:
        if module_input.forward_input is None:
            continue
        input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
        if not (_can_be_noncontiguous(input_args) or _can_be_noncontiguous(input_kwargs)):
            continue
        # === Instantiate the module. ===
        args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
        m = module_cls(*args, **kwargs)
        m.to(device).to(dtype)
        self._retain_grad((input_args, input_kwargs))
        # === Forward with default input
        with freeze_rng_state():
            default_output = m(*input_args, **input_kwargs)
            grad_output = default_output.clone().detach_().normal_()
            default_output.backward(grad_output, retain_graph=True)
        # Snapshot the reference grads before they get zeroed below.
        default_input_args_grad, default_input_kwargs_grad = deepcopy(
            self._get_grads((input_args, input_kwargs)))
        default_param_grad = deepcopy([p.grad for p in m.parameters()])
        # === Construct non-contiguous tensors ===
        nc_input_args, nc_input_kwargs = _make_non_contiguous(
            (input_args, input_kwargs))
        nc_grad_output = _make_non_contiguous(grad_output)
        # === Compare results with non-contiguous and contiguous tensors ===
        inputs = [(input_args, input_kwargs), (nc_input_args, nc_input_kwargs)]
        grads = [grad_output, nc_grad_output]
        for (in_args, in_kwargs), g_out in product(inputs, grads):
            g_out_copy = deepcopy(g_out)
            self._zero_grad((in_args, in_kwargs))
            self._zero_grad(m.parameters())
            # Same frozen RNG as the reference run so stochastic modules
            # produce comparable outputs.
            with freeze_rng_state():
                out = m(*in_args, **in_kwargs)
                out.backward(g_out_copy, retain_graph=True)
            input_args_grad, input_kwargs_grad = self._get_grads(
                (in_args, in_kwargs))
            self.assertEqual(out, default_output)
            self.assertEqual(input_args_grad, default_input_args_grad,
                             atol=1e-4, rtol=0)
            self.assertEqual(input_kwargs_grad, default_input_kwargs_grad,
                             atol=1e-4, rtol=0)
            param_grad = [p.grad for p in m.parameters()]
            self.assertEqual(param_grad, default_param_grad)
def runAndSaveRNG(self, func, inputs, kwargs=None): kwargs = kwargs if kwargs else {} with freeze_rng_state(): results = func(*inputs, **kwargs) return results
def gen_data(): with freeze_rng_state(): return torch.randn(10), torch.randint(10, (20, )), torch.randn(20)
def test_memory_format(self, device, dtype, module_info, training):
    """Check module outputs across input/module memory-format combinations.

    For each sample input, computes a reference output in the (contiguous
    input, contiguous module) configuration, then re-runs the forward pass
    for every combination of input memory format and module memory format
    (adding channels_last / channels_last_3d when a 4D / 5D tensor input is
    present), asserting that values match the reference and that the output
    memory format is the expected one.
    """
    module_cls = module_info.module_cls
    module_inputs = module_info.module_inputs_func(module_info,
                                                   device=device,
                                                   dtype=dtype,
                                                   requires_grad=False,
                                                   training=training)
    # Whether converting the *module* to a memory format is expected to
    # change the output's memory format (per the ModuleInfo entry).
    module_memformat_affects_out = module_info.module_memformat_affects_out

    def _get_mem_formats(channels_last=False, channels_last_3d=False):
        # Returns (input memory formats, module memory formats) to test.
        if channels_last:
            return ([torch.contiguous_format, torch.channels_last], [
                torch.preserve_format, torch.contiguous_format,
                torch.channels_last
            ])
        elif channels_last_3d:
            return ([torch.contiguous_format, torch.channels_last_3d], [
                torch.preserve_format, torch.contiguous_format,
                torch.channels_last_3d
            ])
        else:
            return ([torch.contiguous_format],
                    [torch.preserve_format, torch.contiguous_format])

    # Check that at least one Tensor input has dim == n
    def _check_dims(obj, n):
        if isinstance(obj, torch.Tensor):
            return obj.dim() == n
        elif isinstance(obj, (tuple, list)):
            return any(_check_dims(o, n) for o in obj)
        else:
            return False

    # Called after _check_dims, when we know that >= 1 tensor can be converted to mem_format
    def _to_mem_format(mem_format, obj):
        def inner_to_mem_format(obj):
            d = obj.dim()
            # channels_last only applies to 4D tensors, channels_last_3d to 5D;
            # leave tensors of other ranks untouched.
            if ((mem_format == torch.channels_last and d != 4)
                    or (mem_format == torch.channels_last_3d and d != 5)):
                return obj
            return obj.to(memory_format=mem_format)

        return self._traverse_obj(obj, inner_to_mem_format)

    def _check_out_mem_format(output, input_mem_format, module_mem_format):
        # Assert each output tensor is contiguous in the memory format implied
        # by the input format and (when it affects the output) the module format.
        def inner_check_out_mem_format(output):
            d = output.dim()
            if (d == 4 and ((input_mem_format == torch.channels_last) or
                            (module_mem_format == torch.channels_last
                             and module_memformat_affects_out))):
                self.assertTrue(
                    output.is_contiguous(memory_format=torch.channels_last))
            elif (d == 5 and ((input_mem_format == torch.channels_last_3d) or
                              (module_mem_format == torch.channels_last_3d
                               and module_memformat_affects_out))):
                self.assertTrue(
                    output.is_contiguous(memory_format=torch.channels_last_3d))
            else:
                self.assertTrue(output.is_contiguous())

        return self._traverse_obj(output, inner_check_out_mem_format)

    for module_input in module_inputs:
        if module_input.forward_input is None:
            continue
        supports_channels_last = _check_dims(
            module_input.forward_input.args, 4)
        supports_channels_last_3d = _check_dims(
            module_input.forward_input.args, 5)
        input_mem_formats, module_mem_formats = _get_mem_formats(
            supports_channels_last, supports_channels_last_3d)
        with freeze_rng_state():
            # === Instantiate the module. ===
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            m = module_cls(*args, **kwargs)
            m.to(device).to(dtype)
            m.train(training)
            # === Get output in (contiguous, contiguous) configuration. ===
            args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
            desired_outputs = m(*args, **kwargs)
            for input_mem_format in input_mem_formats:
                # === Change memformat of input. ===
                module_input.forward_input.args = _to_mem_format(
                    input_mem_format, module_input.forward_input.args)
                module_input.forward_input.kwargs = _to_mem_format(
                    input_mem_format, module_input.forward_input.kwargs)
                for module_mem_format in module_mem_formats:
                    # === Change memformat of module ===
                    m.to(memory_format=module_mem_format)
                    # === Do forward pass. ===
                    args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
                    outputs = m(*args, **kwargs)
                    # === Compare outputs to (contiguous, contiguous) output. ===
                    # BUG FIX: the original compared `module_mem_formats` (the
                    # list of formats) against torch.contiguous_format, which is
                    # always unequal, making the guard vacuously true.  The
                    # intended comparand is the current `module_mem_format`.
                    if input_mem_format != torch.contiguous_format or module_mem_format != torch.contiguous_format:
                        self.assertEqual(outputs, desired_outputs)
                    # === Check mem format of output. ===
                    _check_out_mem_format(outputs, input_mem_format,
                                          module_mem_format)
def test_dropout(self, device, dtype):
    """Test dropout on nested tensors.

    Covers: an empty nested tensor, invalid dropout probabilities, the
    p=0.0 (identity) and p=1.0 (all-zero) edge cases, a normal p where
    surviving elements must be scaled by 1/(1-p), parity between the
    Dropout module and the functional form under a frozen RNG, and the
    in-place functional variant.
    """
    # edge case: empty nested tensor
    nt0 = torch.nested_tensor([])
    y = torch.nn.functional.dropout(nt0, 0.5)
    self.nt_equal(nt0, y)
    # normal nested tensor
    ntensors = 4
    nt = self.random_nt(device, dtype, ntensors, (4, 4))
    # edge case: invalid dropout
    self.assertRaises(ValueError, lambda: torch.nn.Dropout(-0.1))
    self.assertRaises(ValueError, lambda: torch.nn.Dropout(1.1))
    self.assertRaises(ValueError,
                      lambda: torch.nn.functional.dropout(nt, -0.1))
    self.assertRaises(ValueError,
                      lambda: torch.nn.functional.dropout(nt, 1.1))
    # edge case: no dropout
    dropouter = torch.nn.Dropout(0.0)
    y0 = dropouter(nt)
    y1 = torch.nn.functional.dropout(nt, 0.0)
    self.nt_equal(nt, y0)
    self.nt_equal(nt, y1)
    # edge case: all dropout
    dropouter = torch.nn.Dropout(1.0)
    y0 = dropouter(nt)
    y1 = torch.nn.functional.dropout(nt, 1.0)
    nt0 = nt.clone()
    for i in range(ntensors):
        nt0[i].fill_(0.0)
    self.nt_equal(nt0, y0)
    self.nt_equal(nt0, y1)
    # normal case: normal dropout
    p = 0.2
    y = torch.nn.functional.dropout(nt, p)
    # Build the expected result from the observed mask: wherever the output
    # is zero the element was dropped; every surviving element must equal
    # the input scaled by 1/(1-p).
    expect = nt.clone()
    for i in range(ntensors):
        actual_tensor = y[i].view(-1)
        expect_tensor = expect[i].view(-1)
        for j in range(actual_tensor.shape[0]):
            if actual_tensor[j].item() == 0.0:
                expect_tensor[j] = 0.0
            else:
                expect_tensor[j] /= 1.0 - p
    self.nt_equal(y, expect)
    # Module and functional forms must agree when run from the same RNG state.
    with freeze_rng_state():
        dropouter = torch.nn.Dropout(p)
        y0 = dropouter(nt)
    with freeze_rng_state():
        y1 = torch.nn.functional.dropout(nt, p)
    self.nt_equal(y0, y1)
    # inplace
    # in principle, since we have established the correctness of functional, we could simply compare inplace vs functional
    # in practice, cuda functional has its own implementation to skip `bernoulli_`
    # so cuda functional will differ from cuda inplace causing test failure
    # in `test_dropout_cuda_float64 (__main__.TestNestedTensorDeviceTypeCUDA)`
    # on `linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 2, 4, linux.4xlarge.nvidia.gpu)`
    expect = nt.clone()
    torch.nn.functional.dropout(nt, p, inplace=True)
    for i in range(ntensors):
        actual_tensor = nt[i].view(-1)
        expect_tensor = expect[i].view(-1)
        for j in range(actual_tensor.shape[0]):
            if actual_tensor[j].item() == 0.0:
                expect_tensor[j] = 0.0
            else:
                expect_tensor[j] /= 1.0 - p
    self.nt_equal(nt, expect)
def test_scaled_dot_product_attention(self, device, input_dim, attn_mask_dim,
                                      is_causal, dropout_p):
    """Compare the aten fused SDP op against the Python reference impl.

    Builds random 3D or 4D query/key/value tensors, optionally a boolean
    attention mask (2D or matching input_dim; lower-triangular when
    is_causal), runs the Python reference (which here is fed a float mask
    and 3D-reshaped inputs) and `torch.ops.aten._scaled_dot_product_attention`
    under frozen RNG states, and checks they agree.  Also checks the error
    raised when both an explicit attn_mask and is_causal=True are given.
    """
    # TODO: Support cross-device / dtype testing properly when instantiate_device_type_tests() is used.
    dtypes = [torch.double, torch.float]
    for dtype in dtypes:

        def rand_tensor(*shape):
            return torch.randn(shape, device=device, dtype=dtype)

        # This test compares python and C++ implementations of SDP.
        N, N_prime, L, S, E = 5, 2, 4, 3, 6
        if input_dim == 3:
            query = rand_tensor(N, L, E)
            key = rand_tensor(N, S, E)
            value = rand_tensor(N, S, E)
        elif input_dim == 4:
            query = rand_tensor(N, N_prime, L, E)
            key = rand_tensor(N, N_prime, S, E)
            value = rand_tensor(N, N_prime, S, E)
        else:
            self.fail(
                f'Invalid input_dim {input_dim} encountered in SDP test')

        attn_mask = None
        if attn_mask_dim is not None:
            assert attn_mask_dim in [2, input_dim]
            mask_size = (L, S) if attn_mask_dim == 2 else ((
                N, L, S) if input_dim == 3 else (N, N_prime, L, S))
            # Causal masks are lower-triangular; otherwise use a random
            # boolean mask.
            attn_mask = (torch.ones(
                mask_size, device=device,
                dtype=torch.bool).tril() if is_causal else torch.randint(
                    0, 2, size=mask_size, device=device, dtype=torch.bool))

        with freeze_rng_state():
            # Python impl only supports float mask and 3D inputs.
            attn_mask_float = attn_mask
            if attn_mask_float is not None:
                # Convert the boolean mask to an additive float mask:
                # masked-out positions become -inf.
                attn_mask_float = torch.zeros_like(attn_mask,
                                                   dtype=query.dtype)
                attn_mask_float.masked_fill_(attn_mask.logical_not(),
                                             float("-inf"))
            # Collapse leading batch dims to feed the 3D-only reference.
            q, k, v = query.view(-1, L, E), key.view(-1, S, E), value.view(
                -1, S, E)
            a = attn_mask_float
            if a is not None and attn_mask_dim > 3:
                a = a.view(-1, L, S)
            expected = F._scaled_dot_product_attention(q,
                                                       k,
                                                       v,
                                                       attn_mask=a,
                                                       dropout_p=dropout_p)
            if input_dim > 3:
                # Restore the 4D batch structure of (output, attn_weights).
                expected = (expected[0].view(-1, N_prime, L, E),
                            expected[1].view(-1, N_prime, L, S))

        need_attn_weights: bool = True
        # Same frozen RNG so dropout (if any) matches the reference run.
        with freeze_rng_state():
            if is_causal:
                # NB: Don't pass attn_mask here
                actual = torch.ops.aten._scaled_dot_product_attention(
                    query, key, value, None, dropout_p, need_attn_weights,
                    is_causal)

                # Error case: both explicit attn_mask and is_causal are set
                with self.assertRaisesRegex(
                        RuntimeError,
                        "Explicit attn_mask should not be set when is_causal=True"
                ):
                    torch.ops.aten._scaled_dot_product_attention(
                        query, key, value, attn_mask, dropout_p,
                        need_attn_weights, is_causal)
            else:
                actual = torch.ops.aten._scaled_dot_product_attention(
                    query, key, value, attn_mask, dropout_p,
                    need_attn_weights, is_causal)

        # freeze_rng_state() doesn't seem to work outside of CPU, so dropout makes the results incomparable.
        # TODO: Do this skipping in a nicer way once the granular test skipping logic lands.
        if dropout_p == 0.0 or device == 'cpu':
            self.assertEqual(actual, expected)