def __init__(self, dimension, scramble=False, seed=None):
    """Initialise the engine for ``dimension``-dimensional Sobol points.

    Args:
        dimension: Dimensionality of the generated points; must lie in
            ``[1, self.MAXDIM]``.
        scramble: When true, draw a random shift and random lower-triangular
            0/1 matrices and scramble the direction numbers with them.
        seed: Seed for the scrambling generator. ``None`` uses torch's global
            default generator, so ``torch.manual_seed`` makes scrambling
            reproducible.

    Raises:
        ValueError: If ``dimension`` is outside ``[1, self.MAXDIM]``.
    """
    if dimension > self.MAXDIM or dimension < 1:
        raise ValueError("Supported range of dimensionality "
                         f"for SobolEngine is [1, {self.MAXDIM}]")

    self.seed = seed
    self.scramble = scramble
    self.dimension = dimension

    # Create every tensor explicitly on the CPU, where the in-place
    # _sobol_engine_* helpers are used, so a non-CPU default device
    # (e.g. one set via torch.set_default_device) cannot place them elsewhere.
    cpu = torch.device("cpu")

    # Direction numbers: a (dimension, MAXBIT) int64 table filled in place.
    self.sobolstate = torch.zeros(dimension, self.MAXBIT, device=cpu, dtype=torch.long)
    torch._sobol_engine_initialize_state_(self.sobolstate, self.dimension)

    if self.scramble:
        # Without an explicit seed, use the global default generator
        # (generator=None) rather than a freshly entropy-seeded one, so that
        # torch.manual_seed controls the scrambling.
        g: Optional[torch.Generator] = None
        if self.seed is not None:
            g = torch.Generator()
            g.manual_seed(self.seed)

        # Random digital shift: MAXBIT random bits per dimension packed into
        # one integer via the weights 2**0 .. 2**(MAXBIT - 1).
        shift_ints = torch.randint(2, (self.dimension, self.MAXBIT), device=cpu, generator=g)
        self.shift = torch.mv(shift_ints, torch.pow(2, torch.arange(0, self.MAXBIT, device=cpu)))

        # One random lower-triangular 0/1 matrix per dimension, used to
        # scramble sobolstate in place.
        ltm_dims = (self.dimension, self.MAXBIT, self.MAXBIT)
        ltm = torch.randint(2, ltm_dims, device=cpu, generator=g).tril()
        torch._sobol_engine_scramble_(self.sobolstate, ltm, self.dimension)
    else:
        self.shift = torch.zeros(self.dimension, device=cpu, dtype=torch.long)

    # The running point state starts as a contiguous copy of the shift, and
    # the generation counter starts at zero.
    self.quasi = self.shift.clone(memory_format=torch.contiguous_format)
    self.num_generated = 0
def _scramble(self):
    """Draw a random digital shift and scramble the direction numbers.

    A private generator seeded with ``self.seed`` is used when a seed is set;
    otherwise the global default generator is used. Assigns ``self.shift`` and
    modifies ``self.sobolstate`` in place.
    """
    host = torch.device("cpu")
    rng = None if self.seed is None else torch.Generator().manual_seed(self.seed)
    ndim, nbits = self.dimension, self.MAXBIT

    # Digital shift: nbits random bits per dimension packed into one integer
    # by weighting them with 2**0 .. 2**(nbits - 1).
    powers = torch.pow(2, torch.arange(0, nbits, device=host))
    random_bits = torch.randint(2, (ndim, nbits), device=host, generator=rng)
    self.shift = torch.mv(random_bits, powers)

    # One random lower-triangular 0/1 matrix per dimension. It is drawn after
    # the shift bits so the generator stream is consumed in a fixed order.
    scramblers = torch.randint(2, (ndim, nbits, nbits), device=host, generator=rng).tril()
    torch._sobol_engine_scramble_(self.sobolstate, scramblers, ndim)
def __init__(self, dimension, scramble=False, seed=None):
    """Prepare direction numbers and the starting point of a Sobol sequence.

    Args:
        dimension: Number of coordinates per point; must be within
            ``[1, self.MAXDIM]``.
        scramble: If true, apply a random digital shift and random
            lower-triangular scrambling to the direction numbers.
        seed: Optional seed for the scrambling generator; when ``None`` the
            global default generator is used.

    Raises:
        ValueError: If ``dimension`` is out of the supported range.
    """
    if dimension > self.MAXDIM or dimension < 1:
        raise ValueError(
            f"Supported range of dimensionality for SobolEngine is [1, {self.MAXDIM}]"
        )

    self.seed, self.scramble, self.dimension = seed, scramble, dimension

    host = torch.device("cpu")
    nbits = self.MAXBIT

    # (dimension, MAXBIT) int64 table of direction numbers, filled in place.
    state = torch.zeros(dimension, nbits, device=host, dtype=torch.long)
    torch._sobol_engine_initialize_state_(state, self.dimension)
    self.sobolstate = state

    if not self.scramble:
        self.shift = torch.zeros(self.dimension, device=host, dtype=torch.long)
    else:
        rng = None if self.seed is None else torch.Generator().manual_seed(self.seed)

        # Digital shift: random bits packed into one integer per dimension.
        random_bits = torch.randint(2, (self.dimension, nbits), device=host, generator=rng)
        powers = torch.pow(2, torch.arange(0, nbits, device=host))
        self.shift = torch.mv(random_bits, powers)

        # Per-dimension random lower-triangular 0/1 matrices, drawn after the
        # shift bits, used to scramble the direction numbers in place.
        scramblers = torch.randint(
            2, (self.dimension, nbits, nbits), device=host, generator=rng
        ).tril()
        torch._sobol_engine_scramble_(self.sobolstate, scramblers, self.dimension)

    # Running point state begins as a contiguous copy of the shift.
    self.quasi = self.shift.clone(memory_format=torch.contiguous_format)
    self.num_generated = 0