示例#1
0
文件: random.py 项目: Guillemdb/judo
 def choice(a, size=None, replace=True, p=None):
     """Draw random samples from ``a``, mirroring ``numpy.random.choice``.

     Args:
         a: Population to sample from; samples are taken along its first
             dimension. If an ``int``, samples are drawn from
             ``torch.arange(a)``.
         size: Output shape as an int, tuple or list of ints. ``None``
             means 1, which yields a 1-element result rather than a scalar.
         replace: Whether the same entry may be drawn more than once.
         p: Optional probabilities (non-negative weights) for each entry of
             ``a``. ``None`` means uniform sampling.

     Returns:
         The selected entries of ``a`` (shape ``size + a.shape[1:]``),
         passed through ``to_backend``.

     Raises:
         ValueError: if ``replace`` is False and more samples are requested
             than ``a`` has entries, or if ``p`` does not have ``len(a)``
             entries.
     """
     if isinstance(a, int):
         a = torch.arange(a)
     a = to_backend(a)
     size = size if size is not None else 1
     shape = tuple(size) if isinstance(size, (tuple, list)) else (size,)
     n_samples = torch.Size(shape).numel()
     n_items = len(a)
     if not replace and n_samples > n_items:
         # Previously the slice below silently returned fewer samples.
         raise ValueError(
             "Cannot take a larger sample than population when replace=False"
         )
     if p is None:
         if replace:
             indices = torch.randint(n_items, shape)
         else:
             # randperm gives distinct indices; keep the first n_samples.
             indices = torch.randperm(n_items)[:n_samples].reshape(shape)
     else:
         weights = torch.as_tensor(p, dtype=torch.float64).flatten()
         if weights.numel() != n_items:
             raise ValueError("a and p must have same size")
         # multinomial treats p as (unnormalized) weights and honours replace.
         indices = torch.multinomial(
             weights, n_samples, replacement=replace
         ).reshape(shape)
     samples = a[to_backend(indices)]
     return to_backend(samples)
示例#2
0
文件: random.py 项目: Guillemdb/judo
 def randint(low, high, size=None, dtype=None):
     """Return random integers drawn uniformly from ``[low, high)``.

     Args:
         low: Lowest integer that can be drawn (inclusive).
         high: Upper bound (exclusive), as in ``torch.randint``.
         size: Output shape as an int or tuple; ``None`` means ``(1,)``.
         dtype: Optional dtype passed to ``Tensor.to`` after sampling.

     Returns:
         The sampled tensor passed through ``to_backend``.
     """
     if size is None:
         shape = (1,)
     elif isinstance(size, tuple):
         shape = size
     else:
         shape = (size,)
     values = torch.randint(low, high, shape)
     return to_backend(values if dtype is None else values.to(dtype))
示例#3
0
文件: random.py 项目: Guillemdb/judo
 def uniform(
     low=0.0, high=1.0, size=None, dtype=None,
 ):
     """Sample from a continuous uniform distribution over ``[low, high)``.

     Args:
         low: Lower bound of the distribution.
         high: Upper bound of the distribution.
         size: Sample shape as an int or tuple. When ``None``, a single draw
             shaped like the distribution's batch shape is returned (0-d for
             scalar bounds).
         dtype: Optional dtype passed to ``Tensor.to`` after sampling.

     Returns:
         The samples passed through ``to_backend``.
     """
     if size is None:
         sample_shape = ()
     elif isinstance(size, tuple):
         sample_shape = size
     else:
         sample_shape = (size,)
     dist = torch.distributions.uniform.Uniform(low, high)
     draws = dist.sample(sample_shape)
     if dtype is not None:
         draws = draws.to(dtype)
     return to_backend(draws)
示例#4
0
文件: random.py 项目: Guillemdb/judo
 def normal(loc=0, scale=1.0, size=None):
     """Draw samples from a normal distribution with mean ``loc`` and std ``scale``.

     Args:
         loc: Mean of the distribution.
         scale: Standard deviation of the distribution.
         size: Output shape as an int or tuple; ``None`` means ``(1,)``.

     Returns:
         The sampled tensor passed through ``to_backend``.
     """
     if size is None:
         size = (1,)
     elif not isinstance(size, tuple):
         size = (size,)
     draws = torch.normal(mean=loc, std=scale, size=size)
     return to_backend(draws)
示例#5
0
文件: random.py 项目: Guillemdb/judo
 def random_sample(*args, **kwargs):
     """Return uniform samples in ``[0, 1)``.

     All positional and keyword arguments are forwarded unchanged to
     ``torch.rand``; the result is passed through ``to_backend``.
     """
     return to_backend(torch.rand(*args, **kwargs))
示例#6
0
文件: random.py 项目: Guillemdb/judo
 def permutation(x):
     """Return a randomly permuted copy of ``x`` along its first dimension.

     Mirrors ``numpy.random.permutation``: when ``x`` is an ``int``, the
     result is a shuffled ``torch.arange(x)``.

     Args:
         x: Data to shuffle (shuffled along the first dimension), or an int.

     Returns:
         The permuted data, moved to ``Backend.get_device()``, detached, and
         passed through ``to_backend``.
     """
     if isinstance(x, int):
         x = torch.arange(x)
     # Normalize the input first, as ``choice`` does; presumably to_backend
     # yields a torch tensor here so the ``.to``/``.detach`` calls below work
     # for non-tensor inputs too -- confirm against its definition.
     x = to_backend(x)
     idx = torch.randperm(x.shape[0])
     sample = x[idx].to(Backend.get_device()).detach()
     return to_backend(sample)