def choice(a, size=None, replace=True, p=None):
    """Draw random samples from ``a``, in the style of ``numpy.random.choice``.

    Args:
        a: Population to sample from. Must be accepted by ``to_backend`` and
            support ``len`` and indexing with an index tensor. An ``int`` ``n``
            is treated as the population ``arange(n)``.
        size: Output shape, as an int or a tuple. Defaults to ``1``, so the
            result always has at least one dimension.
        replace: Whether the same element may be drawn more than once.
        p: Optional sampling weights, one per element of ``a``. They are passed
            to ``torch.multinomial``, which normalises them, so they must be
            non-negative but need not sum to exactly 1.

    Returns:
        The sampled elements of ``a`` with shape ``size``, passed through
        ``to_backend``.

    Raises:
        ValueError: If ``p`` does not have one weight per element of ``a``, or
            if ``replace`` is False and more samples are requested than ``a``
            has elements.
    """
    if isinstance(a, int):
        # numpy semantics: an integer population means arange(a).
        a = torch.arange(a)
    a = to_backend(a)
    n = len(a)

    size = size if size is not None else 1
    shape = size if isinstance(size, tuple) else (size,)
    count = torch.Size(shape).numel()

    if not replace and count > n:
        # Previously the result was silently truncated to len(a) samples.
        raise ValueError(
            "Cannot take a larger sample than population when replace=False"
        )

    if p is not None:
        weights = torch.as_tensor(p, dtype=torch.float64).flatten()
        if weights.numel() != n:
            raise ValueError("a and p must have the same size")
        indices = torch.multinomial(weights, count, replacement=replace)
        indices = indices.reshape(shape)
    elif replace:
        indices = torch.randint(n, shape)
    else:
        # Take the first `count` entries of a random permutation, then reshape
        # so tuple sizes work too (slicing with a tuple used to raise).
        indices = torch.randperm(n)[:count].reshape(shape)

    samples = a[to_backend(indices)]
    return to_backend(samples)
def randint(low, high, size=None, dtype=None):
    """Return integers drawn uniformly from the half-open range ``[low, high)``.

    ``size`` may be an int or a tuple and defaults to ``(1,)``. When ``dtype``
    is given, the sampled tensor is cast to it before ``to_backend`` conversion.
    """
    if size is None:
        shape = (1,)
    elif isinstance(size, tuple):
        shape = size
    else:
        shape = (size,)
    values = torch.randint(low, high, shape)
    return to_backend(values if dtype is None else values.to(dtype))
def uniform(
    low=0.0,
    high=1.0,
    size=None,
    dtype=None,
):
    """Sample from a continuous uniform distribution between ``low`` and ``high``.

    With ``size`` omitted a single sample of the distribution's batch shape is
    drawn; otherwise ``size`` (int or tuple) gives the sample shape. When
    ``dtype`` is given, the result is cast to it before ``to_backend``.
    """
    dist = torch.distributions.uniform.Uniform(low, high)
    if size is None:
        drawn = dist.sample()
    else:
        drawn = dist.sample(size if isinstance(size, tuple) else (size,))
    if dtype is not None:
        drawn = drawn.to(dtype)
    return to_backend(drawn)
def normal(loc=0, scale=1.0, size=None):
    """Draw samples from a normal distribution with mean ``loc`` and std ``scale``.

    ``size`` may be an int or a tuple and defaults to ``(1,)``; the result is
    returned through ``to_backend``.
    """
    if size is None:
        size = (1,)
    elif not isinstance(size, tuple):
        size = (size,)
    drawn = torch.normal(mean=loc, std=scale, size=size)
    return to_backend(drawn)
def random_sample(*args, **kwargs):
    """Return uniform floats in ``[0, 1)`` via ``torch.rand``.

    All positional and keyword arguments are forwarded unchanged to
    ``torch.rand``; the result is passed through ``to_backend``.
    """
    return to_backend(torch.rand(*args, **kwargs))
def permutation(x):
    """Return a randomly shuffled copy of ``x`` along its first axis.

    Mirrors ``numpy.random.permutation``: if ``x`` is an int ``n``, a random
    permutation of ``0 .. n-1`` is returned instead (previously an int argument
    raised ``AttributeError`` on ``x.shape``).

    The result is moved to ``Backend.get_device()`` and detached from the
    autograd graph before ``to_backend`` conversion.
    """
    if isinstance(x, int):
        sample = torch.randperm(x)
    else:
        idx = torch.randperm(x.shape[0])
        # Advanced indexing produces a new tensor, so `x` itself is unchanged.
        sample = x[idx]
    sample = sample.to(Backend.get_device()).detach()
    return to_backend(sample)