示例#1
0
 def init_sampler_class(
     name: str, params: Dict[str, Any], seed: Optional[int] = None
 ) -> Sampler:
     """
     Initialize the sampler class registered under ``name`` with ``params``.

     :param name: Name of the sampler in the factory to initialize
     :param params: Keyword arguments for the sampler attached to the name.
         This dict is not modified; a copy with ``"seed"`` set to ``seed``
         (overriding any existing ``"seed"`` entry) is passed to the constructor.
     :param seed: Random seed to be used to set deterministic random draws for the sampler
     :return: The instance built by the class registered under ``name``.
     :raises SamplerException: If ``name`` is not registered in
         ``SamplerFactory.NAME_TO_CLASS``, or if constructing the sampler
         raises ``TypeError`` (e.g. missing or unexpected keyword arguments).
     """
     if name not in SamplerFactory.NAME_TO_CLASS:
         # NOTE(review): the message refers to a "register_sample" method;
         # confirm this matches the factory's actual registration method name.
         raise SamplerException(
             name + " sampler is not registered in the SamplerFactory."
             " Use the register_sample method to register the string"
             " associated to your sampler in the SamplerFactory."
         )
     sampler_cls = SamplerFactory.NAME_TO_CLASS[name]
     # Build a fresh kwargs dict rather than writing "seed" into the caller's
     # config dict (which previously leaked into e.g. SamplerManager's stored config).
     sampler_kwargs = {**params, "seed": seed}
     try:
         return sampler_cls(**sampler_kwargs)
     except TypeError as err:
         # This also catches TypeErrors raised inside the sampler's own
         # constructor, so keep the original error chained for debugging.
         raise SamplerException(
             "The sampler class associated to the " + name + " key in the factory "
             "was not provided the required arguments. Please ensure that the sampler "
             "config file consists of the appropriate keys for this sampler class."
         ) from err
示例#2
0
def create_sampler_manager(sampler_config, run_seed=None):
    """
    Build a SamplerManager from a sampler config and extract its resampling interval.

    :param sampler_config: Sampler configuration dict, or None. When not None it
        must contain a "resampling-interval" key holding a positive int; the
        remaining entries are passed on to SamplerManager. The caller's dict is
        not modified.
    :param run_seed: Seed forwarded to SamplerManager.
    :return: Tuple ``(sampler_manager, resample_interval)``; ``resample_interval``
        is None when ``sampler_config`` is None.
    :raises SamplerException: If "resampling-interval" is missing or is not a
        positive integer.
    """
    resample_interval = None
    if sampler_config is not None:
        if "resampling-interval" not in sampler_config:
            raise SamplerException(
                "Resampling interval was not specified in the sampler file."
                " Please specify it with the 'resampling-interval' key in the sampler config file."
            )
        # Shallow-copy so removing the interval key does not mutate the caller's
        # config (a second call with the same dict would otherwise fail).
        sampler_config = dict(sampler_config)
        # The interval is not a sampler entry, so strip it before SamplerManager sees it.
        resample_interval = sampler_config.pop("resampling-interval")
        # Check the type before comparing: "<= 0" on a str/None would raise a raw
        # TypeError instead of SamplerException. bool is rejected explicitly
        # because it is a subclass of int.
        if (
            not isinstance(resample_interval, int)
            or isinstance(resample_interval, bool)
            or resample_interval <= 0
        ):
            raise SamplerException(
                "Specified resampling-interval is not valid. Please provide"
                " a positive integer value for resampling-interval"
            )

    sampler_manager = SamplerManager(sampler_config, run_seed)
    return sampler_manager, resample_interval
示例#3
0
    def __init__(
        self, reset_param_dict: Dict[str, Any], seed: Optional[int] = None
    ) -> None:
        """
        Create one sampler per entry of ``reset_param_dict``.

        :param reset_param_dict: Arguments needed for initializing the samplers.
            Maps each parameter name to a dict containing a "sampler-type" key
            plus the keyword arguments for that sampler. A falsy value (None or
            empty) means no samplers. The nested dicts are not modified.
        :param seed: Random seed to be used for drawing samples from the samplers
        :raises SamplerException: If ``reset_param_dict`` is not a dict, or if an
            entry lacks the "sampler-type" key.
        """
        self.reset_param_dict = reset_param_dict if reset_param_dict else {}
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not isinstance(self.reset_param_dict, dict):
            raise SamplerException(
                "The sampler config must be a dict mapping parameter names to "
                "sampler settings, got {0}.".format(type(self.reset_param_dict).__name__)
            )
        self.samplers: Dict[str, Sampler] = {}
        for param_name, cur_param_dict in self.reset_param_dict.items():
            if "sampler-type" not in cur_param_dict:
                # The message names the key exactly as it is looked up above.
                raise SamplerException(
                    "'sampler-type' argument hasn't been supplied for the {0} parameter".format(
                        param_name
                    )
                )
            # Copy before popping so the caller's config still holds "sampler-type"
            # afterwards and can be reused to build another manager.
            sampler_params = dict(cur_param_dict)
            sampler_name = sampler_params.pop("sampler-type")
            param_sampler = SamplerFactory.init_sampler_class(
                sampler_name, sampler_params, seed
            )

            self.samplers[param_name] = param_sampler