def initialize_weights(model, name=None, seed=0, half=False, **kwargs):
    """Initialize the weights of a model with a registered initializer.

    Parameters
    ----------
    model: Module
        Model whose weights are initialized.

    name: str
        Name of the initializer, looked up in ``registered_initialization``.

    seed: int
        Seed passed to ``init_seed`` before the initializer is built and run.

    half: bool
        Accepted for interface compatibility; not used by this function.

    **kwargs
        Forwarded unchanged to the initializer builder.

    Returns
    -------
    Whatever the selected initializer returns when called on ``model``.

    Raises
    ------
    RegisteredInitNotFound
        If no (truthy) builder is registered under ``name``.

    """
    # TODO: remove dependency to global PRNG
    # Until then the PRNG is forked so that seeding here does not leak into
    # later calls that rely on the global torch generator.
    with fork_rng(enabled=True):
        init_seed(seed)

        builder = registered_initialization.get(name)
        if not builder:
            raise RegisteredInitNotFound(name)

        initializer = builder(**kwargs)
        initialized = initializer(model)

    return initialized
def __iter__(self):
    """Return an iterator over the wrapped sampler, driven by the stored RNG state.

    The global torch RNG is forked, ``self.rng_state`` is installed, the
    wrapped sampler's iterator is created, and the RNG state observed right
    after that is written back to ``self.rng_state``. Leaving the fork puts
    the global RNG state back to what it was before the call.

    Returns:
        The iterator produced by ``iter(self.sampler)``.
    """
    with fork_rng():
        # Replay from the state saved by the previous call (or by __init__).
        torch.set_rng_state(self.rng_state)
        iterator = iter(self.sampler)
        # NOTE(review): the state is captured immediately after iter(); this
        # presumes the sampler draws its random numbers when the iterator is
        # created rather than lazily while it is consumed - confirm.
        self.rng_state = torch.get_rng_state()

    return iterator
def __init__(
    self,
    dataset_size: int,
    item_shape: Tuple[int],
    distribution_function: torch.distributions.distribution.Distribution,
    seed: Optional[int] = None,
    device: torch.device = (
        torch.device('cuda' if torch.cuda.is_available() else 'cpu')),
) -> None:
    """Dataset of synthetic samples from a specified distribution with given shape.

    Args:
        dataset_size (int): Number of items to generate (N).
        item_shape (Tuple[int]): The shape of each item tensor returned on
            indexing the dataset. For example for 2D items with timeseries of
            3 timesteps and 5 features: (3, 5)
        distribution_function (torch.distributions.distribution.Distribution):
            An instance of the distribution class from which individual data
            items can be sampled by calling the .sample() method. This can
            either be an object that is directly initialised using a method
            from torch.distributions, such as,
            torch.distributions.normal.Normal(loc=0.0, scale=1.0), or from a
            factory using a keyword, for example,
            DISTRIBUTION_FUNCTION_FACTORY['normal'](loc=0.0, scale=1.0) is a
            valid argument since the factory (found in utils.factories.py)
            initialises the distribution class object based on a string
            keyword and passes the relevant arguments to that object.
        seed (Optional[int]): If passed (any integer, including 0), all items
            are generated once with this seed (using a local RNG only).
            Defaults to None, where individual items are generated when the
            DistributionalDataset is indexed (using the global RNG).
        device (torch.device): Device where the tensors are stored.
            Defaults to gpu, if available.
    """
    super(DistributionalDataset, self).__init__()

    self.dataset_size = dataset_size
    self.item_shape = item_shape
    self.seed = seed
    self.device = device
    self.data_sampler = distribution_function

    # Compare against None rather than relying on truthiness: 0 is a valid
    # seed and must also trigger the eager, reproducible code path.
    if self.seed is not None:
        # Eager dataset creation. The RNG is forked so that seeding here does
        # not disturb the global generator state seen by the caller.
        with fork_rng():
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            self.datasource = self.data_sampler.sample(
                (dataset_size, *item_shape))
        self.datasource = self.datasource.to(device)
    else:
        # Lazy mode: an item is sampled each time the dataset is indexed.
        self.datasource = StochasticItems(
            self.data_sampler, self.item_shape, self.device)
def decorator(self, *args, **kwargs) -> T:
    """Run ``decorated`` under a forked torch RNG seeded from a per-instance random source.

    ``decorated`` is the wrapped callable captured from the enclosing scope.

    Args:
        self: Instance the wrapped method is bound to. Its
            ``_random_instance`` attribute is created or replaced here.
        *args: Positional arguments forwarded to ``decorated``.
        **kwargs: Keyword arguments forwarded to ``decorated``. If it
            contains ``"random_state"``, that value is passed through
            ``check_random_state`` and replaces ``self._random_instance``;
            the key is still forwarded to ``decorated``.

    Returns:
        The value returned by ``decorated(self, *args, **kwargs)``.
    """
    if "random_state" in kwargs:
        # An explicit random_state always overrides the cached instance.
        self._random_instance = check_random_state(kwargs["random_state"])
    elif not hasattr(self, "_random_instance"):
        # First call without an explicit random_state: create a random
        # source from a freshly drawn seed and keep it on the instance.
        self._random_instance = check_random_state(
            randint(0, high=MAX_NUMPY_SEED_VALUE))

    # Fork so that seeding torch for this call does not leak into the
    # caller's global RNG state.
    with fork_rng():
        manual_seed(
            self._random_instance.randint(0, high=MAX_TORCH_SEED_VALUE))
        # Propagate the wrapped callable's result instead of dropping it.
        return decorated(self, *args, **kwargs)
def __init__(self, sampler: Sampler, seed: int = None, state=None):
    """Wrap ``sampler`` and choose the RNG state its iteration starts from.

    Args:
        sampler: The sampler to wrap; stored as ``self.sampler``.
        seed: Optional seed; stored as ``self.seed``. Only used to derive
            the starting RNG state when ``state`` is not given.
        state: Optional RNG state. When not None it is stored as-is and
            takes precedence over ``seed``.
    """
    super(ResumableSampler, self).__init__(data_source=None)
    self.sampler = sampler
    self.seed = seed

    # 1) An explicit state wins over everything else.
    if state is not None:
        self.rng_state = state
        return

    # 2) Neither state nor seed: start from the current global RNG state.
    if seed is None:
        self.rng_state = torch.get_rng_state()
        return

    # 3) Seed only: derive the state inside a fork so that the global
    #    generator is left as it was.
    with fork_rng():
        torch.manual_seed(seed)
        self.rng_state = torch.get_rng_state()
def get_matching_set(self, index: int, set_reference: Tensor) -> Tensor:
    """Gets the corresponding set to match to the reference set.

    Args:
        index (int): The index to be sampled. When ``self.seed`` is set,
            ``self.seed + index`` seeds the noise generation.
        set_reference (Tensor): Tensor that represents samples of the
            reference set.

    Returns:
        Tensor: ``set_reference`` with its first dimension indexed by
            ``self.permute_indices``, plus noise sampled from ``self.noise``.
    """
    use_local_rng = self.seed is not None

    # The fork is only active when a seed is configured; otherwise the noise
    # is drawn from (and advances) the global RNG.
    with fork_rng(enabled=use_local_rng):
        if use_local_rng:
            item_seed = self.seed + index
            torch.manual_seed(item_seed)
            torch.cuda.manual_seed(item_seed)
            torch.cuda.manual_seed_all(item_seed)

        noise_shape = set_reference.size()
        additive_noise = self.noise.sample(noise_shape)

    permuted_reference = set_reference[self.permute_indices, :]
    return permuted_reference + additive_noise
def __getitem__(self, index: int) -> Tuple:
    """Generates one sample from the dataset.

    Args:
        index (int): The index to be sampled.

    Returns:
        Tuple : Tuple containing sampled set1, sampled set2, hungarian
            matching indices of set1 vs set2 and set2 vs set1, followed by
            the length returned by the sub-sampling step. When ``self.pad``
            is truthy the first four entries are passed through ``pad_item``.
    """
    set_reference = self.dataset[index]

    use_local_rng = self.seed is not None
    with fork_rng(enabled=use_local_rng):
        if use_local_rng:
            item_seed = self.seed + index
            torch.manual_seed(item_seed)
            torch.cuda.manual_seed(item_seed)
            torch.cuda.manual_seed_all(item_seed)

        # Stored on the instance because get_matching_set() reads it back.
        self.permute_indices = self.permutation

        subsampling = get_subsampling_indexes(
            self.min_set_length,
            self.max_set_length,
            self.permute_indices,
            shuffle=self.shuffle,
        )
        indexes_reference, indexes_matching, length = subsampling

    set_matching = self.get_matching_set(index, set_reference)

    cropped_set_reference = set_reference[indexes_reference, :]
    cropped_set_matching = set_matching[indexes_matching, :]

    targets_12, targets_21 = hungarian_assignment(
        cropped_set_reference,
        cropped_set_matching,
        cost_metric_function=self.get_cost_matrix,
    )

    sample = (
        cropped_set_reference,
        cropped_set_matching,
        targets_12,
        targets_21,
    )

    if not self.pad:
        return (*sample, length)

    # Sets are padded with a constant value, target indices with a range.
    padded = pad_item(
        item=sample,
        padding_modes=['constant', 'constant', 'range', 'range'],
        padding_values=[
            self.set_padding_value,
            self.set_padding_value,
            range(self.max_set_length),
            range(self.max_set_length),
        ],
        max_length=self.max_set_length,
        device=self.device,
    )
    return (*padded, length)
def __call__(self, model):
    """Apply ``self.initializer`` to ``model`` after seeding with ``self.seed``.

    The RNG is forked around both the seeding and the initializer call, so
    the global torch RNG state is restored once the call is done.

    Returns whatever ``self.initializer(model)`` returns.
    """
    with fork_rng(enabled=True):
        init_seed(self.seed)
        initialized = self.initializer(model)

    return initialized