def init_sampler_class(
    name: str, params: Dict[str, Any], seed: Optional[int] = None
) -> Sampler:
    """
    Initializes the sampler class associated with the name with the params

    :param name: Name of the sampler in the factory to initialize
    :param params: Parameters associated to the sampler attached to the name.
        The dict itself is not modified; the seed is merged into a copy.
    :param seed: Random seed to be used to set deterministic random draws for the sampler
    :return: A new instance of the class registered under ``name`` in
        ``SamplerFactory.NAME_TO_CLASS``, built with ``params`` plus ``seed``.
    :raises SamplerException: If ``name`` is not registered, or if the sampler
        class raises ``TypeError`` while being constructed (typically missing or
        unexpected keyword arguments).
    """
    if name not in SamplerFactory.NAME_TO_CLASS:
        raise SamplerException(
            name + " sampler is not registered in the SamplerFactory."
            " Use the register_sample method to register the string"
            " associated to your sampler in the SamplerFactory."
        )
    sampler_cls = SamplerFactory.NAME_TO_CLASS[name]
    # Merge the seed into a copy so the caller's config dict is left untouched.
    sampler_kwargs = {**params, "seed": seed}
    try:
        return sampler_cls(**sampler_kwargs)
    except TypeError as err:
        # Chain the original error: a TypeError raised *inside* the sampler's
        # constructor is also caught here, and the cause must stay visible.
        raise SamplerException(
            "The sampler class associated to the " + name + " key in the factory "
            "was not provided the required arguments. Please ensure that the sampler "
            "config file consists of the appropriate keys for this sampler class."
        ) from err
def create_sampler_manager(sampler_config, run_seed=None):
    """
    Build a SamplerManager from a sampler config and extract its resampling interval.

    :param sampler_config: Sampler configuration dict, or None. When not None it
        must contain a "resampling-interval" key. That key is popped from the dict
        (the dict is modified in place) before the remainder is passed to
        SamplerManager.
    :param run_seed: Seed forwarded to SamplerManager.
    :return: Tuple ``(sampler_manager, resample_interval)``. ``resample_interval``
        is None when ``sampler_config`` is None.
    :raises SamplerException: If "resampling-interval" is missing, or is not a
        positive integer.
    """
    resample_interval = None
    if sampler_config is not None:
        if "resampling-interval" not in sampler_config:
            raise SamplerException(
                "Resampling interval was not specified in the sampler file."
                " Please specify it with the 'resampling-interval' key in the sampler config file."
            )
        # Pop rather than read: SamplerManager treats every remaining top-level
        # entry as a per-parameter sampler definition.
        resample_interval = sampler_config.pop("resampling-interval")
        # Check the type before comparing, otherwise a non-numeric value would
        # raise a bare TypeError on `<= 0`. bool is an int subclass, so it is
        # excluded explicitly.
        if (
            isinstance(resample_interval, bool)
            or not isinstance(resample_interval, int)
            or resample_interval <= 0
        ):
            raise SamplerException(
                "Specified resampling-interval is not valid. Please provide"
                " a positive integer value for resampling-interval"
            )
    sampler_manager = SamplerManager(sampler_config, run_seed)
    return sampler_manager, resample_interval
def __init__(
    self, reset_param_dict: Optional[Dict[str, Any]], seed: Optional[int] = None
) -> None:
    """
    :param reset_param_dict: Arguments needed for initializing the samplers.
        Maps each parameter name to a dict that must contain a "sampler-type"
        key; the remaining keys are passed to the sampler class. The
        "sampler-type" key is popped from each nested dict. None or an empty
        dict yields no samplers.
    :param seed: Random seed to be used for drawing samples from the samplers
    :raises SamplerException: If reset_param_dict is not a dict, or if an
        entry lacks the "sampler-type" key.
    """
    self.reset_param_dict = reset_param_dict if reset_param_dict else {}
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(self.reset_param_dict, dict):
        raise SamplerException(
            "Sampler config must be a dict mapping parameter names to sampler "
            "settings, got {0}.".format(type(self.reset_param_dict).__name__)
        )
    self.samplers: Dict[str, Sampler] = {}
    for param_name, cur_param_dict in self.reset_param_dict.items():
        if "sampler-type" not in cur_param_dict:
            raise SamplerException(
                "'sampler-type' argument hasn't been supplied for the {0} parameter".format(
                    param_name
                )
            )
        # Remove the type key so only constructor kwargs reach the sampler class.
        sampler_name = cur_param_dict.pop("sampler-type")
        param_sampler = SamplerFactory.init_sampler_class(
            sampler_name, cur_param_dict, seed
        )
        self.samplers[param_name] = param_sampler