Exemplo n.º 1
0
def test_random_shape_per_tensor_seed(batch_size, min_shape, max_shape):
    """Check that seeding controls the output of ``random_shape_per_tensor``.

    Three properties are verified, in order:

    1. Two consecutive draws under the same seed differ in more than 90%
       of ``batch_size * len(max_shape)`` entries.
    2. Re-seeding with the same value reproduces the first draw exactly.
    3. A draw under a different seed differs from the first draw in more
       than 90% of ``batch_size * len(max_shape)`` entries.

    NOTE(review): the threshold presumably corresponds to the number of
    entries in the returned tensor -- confirm against the helper.
    """
    def draw():
        # Same positional call as every sample in this test.
        return random_shape_per_tensor(batch_size, min_shape, max_shape)

    def count_mismatches(lhs, rhs):
        return torch.sum(lhs != rhs)

    min_mismatches = batch_size * len(max_shape) * 0.9

    # Consecutive draws after a single seeding must not repeat.
    manual_seed(0)
    reference = draw()
    follow_up = draw()
    assert count_mismatches(reference, follow_up) > min_mismatches

    # Re-seeding with the same value must replay the first draw.
    manual_seed(0)
    replayed = draw()
    assert torch.equal(reference, replayed)

    # A different seed must give a different draw.
    manual_seed(1)
    reseeded = draw()
    assert count_mismatches(reference, reseeded) > min_mismatches
Exemplo n.º 2
0
def test_random_shape_per_tensor(batch_size, min_shape, max_shape):
    """Check the batch size and value bounds of ``random_shape_per_tensor``.

    The helper is called under a fixed torch seed (0). The result must have
    ``batch_size`` rows and every entry must lie in
    ``[min_shape, max_shape]`` element-wise. When ``min_shape`` is ``None``
    the lower bound checked is 1 for every dimension.

    The seed reported by ``torch.initial_seed()`` before the test is
    re-applied afterwards, so the fixed seed does not leak into other
    tests. Note that this restarts the generator from that seed; it does
    not restore the exact generator state that was active on entry.
    """
    old_seed = torch.initial_seed()
    torch.manual_seed(0)
    try:
        shape_per_tensor = random_shape_per_tensor(batch_size, min_shape,
                                                   max_shape)
    finally:
        # Undo the global seeding even if the call under test raises.
        torch.manual_seed(old_seed)

    if min_shape is None:
        min_shape = (1,) * len(max_shape)
    # Add a leading axis so the bounds broadcast over the batch dimension.
    lower_bound = torch.tensor(min_shape).unsqueeze(0)
    upper_bound = torch.tensor(max_shape).unsqueeze(0)

    assert shape_per_tensor.shape[0] == batch_size
    # Separate assertions so a failure shows which bound was violated.
    assert (lower_bound <= shape_per_tensor).all()
    assert (shape_per_tensor <= upper_bound).all()
Exemplo n.º 3
0
 def shape_per_tensor(self, batch_size, min_shape, max_shape):
     """Return ``random_shape_per_tensor`` output for the given bounds.

     ``batch_size`` is forwarded positionally; ``min_shape`` and
     ``max_shape`` are forwarded as keyword arguments.
     """
     bounds = {'min_shape': min_shape, 'max_shape': max_shape}
     return random_shape_per_tensor(batch_size, **bounds)