def test_random_shape_per_tensor_seed(batch_size, min_shape, max_shape):
    """Check that ``random_shape_per_tensor`` is driven by the seeded RNG.

    Three properties are asserted, in this order:

    * two consecutive draws after one seeding differ in most entries;
    * re-seeding with the same value reproduces the first draw exactly;
    * seeding with a different value yields a mostly different draw.

    "Most entries" means strictly more than 90% of
    ``batch_size * len(max_shape)`` elements.
    """

    def draw():
        # Every sample uses the same arguments; only the RNG state varies.
        return random_shape_per_tensor(batch_size, min_shape, max_shape)

    def num_differing(left, right):
        return torch.sum(left != right)

    # presumably each draw holds batch_size * len(max_shape) entries -- the
    # threshold is derived from that count.
    entry_count = batch_size * len(max_shape)
    min_differing = entry_count * 0.9

    manual_seed(0)
    first_draw = draw()
    second_draw = draw()
    # The RNG has advanced, so the second draw should not repeat the first.
    assert num_differing(first_draw, second_draw) > min_differing

    manual_seed(0)
    replayed_draw = draw()
    # Same seed, same position in the stream: identical output expected.
    assert torch.equal(first_draw, replayed_draw)

    manual_seed(1)
    other_seed_draw = draw()
    assert num_differing(first_draw, other_seed_draw) > min_differing
def test_random_shape_per_tensor(batch_size, min_shape, max_shape):
    """Check that ``random_shape_per_tensor`` respects batch size and bounds.

    The generated tensor must have ``batch_size`` rows, and every entry must
    lie within ``[min_shape, max_shape]`` element-wise. When ``min_shape`` is
    ``None`` the test uses a lower bound of 1 for every dimension.

    The global torch seed is fixed to 0 for the draw and then re-seeded with
    the previously captured initial seed, so the hard-coded seed does not
    leak into tests that run afterwards.
    """
    old_seed = torch.initial_seed()
    torch.manual_seed(0)
    try:
        shape_per_tensor = random_shape_per_tensor(batch_size, min_shape, max_shape)
    finally:
        # Re-seed even if the call above raises. Note this restarts the
        # generator from ``old_seed`` rather than restoring its exact state.
        torch.manual_seed(old_seed)

    if min_shape is None:
        min_shape = (1,) * len(max_shape)
    # Add a leading axis so the bounds broadcast against every row.
    lower_bound = torch.tensor(min_shape).unsqueeze(0)
    upper_bound = torch.tensor(max_shape).unsqueeze(0)

    assert shape_per_tensor.shape[0] == batch_size
    # Separate asserts so a failure reports which bound was violated.
    assert (lower_bound <= shape_per_tensor).all()
    assert (shape_per_tensor <= upper_bound).all()
def shape_per_tensor(self, batch_size, min_shape, max_shape):
    """Return a random shape-per-tensor for the given batch size and bounds.

    Forwards to ``random_shape_per_tensor``, passing ``min_shape`` and
    ``max_shape`` as keyword arguments. ``self`` is not used.
    """
    bounds = {"min_shape": min_shape, "max_shape": max_shape}
    generated = random_shape_per_tensor(batch_size, **bounds)
    return generated