def test_random_tensor_seed(low, high, shape):
    """Check that ``random_tensor`` output is controlled by ``manual_seed``.

    Checks three things:
    * two consecutive draws after seeding with 0 differ in more than 90% of
      their elements;
    * re-seeding with 0 reproduces the first draw exactly;
    * seeding with 1 gives a draw differing from the first in more than 90%
      of elements.

    Args:
        low: lower bound forwarded to ``random_tensor``.
        high: upper bound forwarded to ``random_tensor``.
        shape: output shape forwarded to ``random_tensor``. It may have any
            number of dimensions.
    """
    manual_seed(0)
    tensor1 = random_tensor(low, high, shape)
    tensor2 = random_tensor(low, high, shape)
    # Base the threshold on the full element count rather than
    # shape[0] * shape[1]. This way shapes of any rank are handled: 1-D no
    # longer raises IndexError, and N-D uses the true element count.
    threshold = tensor1.numel() * 0.9
    assert torch.sum(tensor1 != tensor2) > threshold

    # The same seed must reproduce the same values.
    manual_seed(0)
    tensor3 = random_tensor(low, high, shape)
    assert torch.equal(tensor1, tensor3)

    # A different seed must produce (mostly) different values.
    manual_seed(1)
    tensor4 = random_tensor(low, high, shape)
    assert torch.sum(tensor1 != tensor4) > threshold
def test_random_shape_per_tensor_seed(batch_size, min_shape, max_shape):
    """Verify that ``random_shape_per_tensor`` is reproducible under seeding.

    After ``manual_seed(0)``, two back-to-back draws must differ in more than
    ``batch_size * len(max_shape) * 0.9`` entries. Re-seeding with 0 must
    reproduce the first draw exactly. Seeding with 1 must again give a draw
    that differs from the first in more than that many entries.
    """
    def draw():
        # All draws use the same arguments; only the RNG state varies.
        return random_shape_per_tensor(batch_size, min_shape, max_shape)

    min_mismatches = batch_size * len(max_shape) * 0.9

    manual_seed(0)
    first = draw()
    second = draw()
    assert torch.sum(first != second) > min_mismatches

    manual_seed(0)
    assert torch.equal(first, draw())

    manual_seed(1)
    assert torch.sum(first != draw()) > min_mismatches
def orig_test_wrapper(*args, **kwargs):
    """Run ``orig_test`` under fixed seeds, then restore the prior RNG state.

    The three states returned by ``random.get_state()`` are captured before
    seeding. They are put back with ``random.set_state`` even when
    ``orig_test`` (or the seeding call) raises. Without that, a failing test
    would leave its seeded state behind for later tests.

    NOTE(review): ``random`` here cannot be the stdlib module, which has no
    ``manual_seed`` and whose ``get_state`` returns a single object.
    Presumably it is a project RNG helper covering torch / Python / numpy,
    as the variable names suggest. Confirm against its definition.

    Args:
        *args: positional arguments forwarded to ``orig_test``.
        **kwargs: keyword arguments forwarded to ``orig_test``.

    Returns:
        Whatever ``orig_test`` returns.
    """
    torch_state, random_state, np_state = random.get_state()
    try:
        # Seeding sits inside the try so a partial failure is also undone.
        random.manual_seed(torch_seed, numpy_seed, random_seed)
        return orig_test(*args, **kwargs)
    finally:
        random.set_state(torch_state, random_state, np_state)