Ejemplo n.º 1
0
    def _create_experiment(self, environment, environment_params, agent_name,
                           agent_builder_params):
        """
        Build a BenchmarkExperiment for one environment/agent pair.

        Args:
            environment (str): environment name, optionally written as
                ``<name>.<env_id>``. In that case ``<name>`` selects the
                environment builder and ``<env_id>`` is added to the
                environment parameters under the key ``env_id``.
            environment_params (dict): keyword parameters forwarded to the
                EnvironmentBuilder.
            agent_name (str): name of the agent; the class
                ``<agent_name>Builder`` is looked up in
                ``mushroom_rl_benchmark.builders``.
            agent_builder_params (dict): keyword parameters passed to the
                builder's ``default`` method.

        Returns:
            BenchmarkExperiment: experiment combining the agent builder, the
            environment builder and a logger stored under
            ``<environment>/<agent_name>`` inside this object's log path.

        Raises:
            AttributeError: if ``mushroom_rl_benchmark.builders`` has no
                attribute ``<agent_name>Builder`` (the error is logged first).

        """
        separator = '.'
        if separator in environment:
            environment_name, environment_id = environment.split(separator)
            # Build a new dict so the caller's environment_params is not mutated.
            environment_params = dict(env_id=environment_id,
                                      **environment_params)
            # The underscore form is what ends up in the log id below.
            environment = environment.replace(separator, '_')
        else:
            environment_name = environment

        logger = BenchmarkLogger(log_dir=self.logger.get_path(),
                                 log_id='{}/{}'.format(environment,
                                                       agent_name),
                                 use_timestamp=False)

        try:
            builder = getattr(mushroom_rl_benchmark.builders,
                              '{}Builder'.format(agent_name))
        except AttributeError as e:
            logger.exception(e)
            # Without a builder there is nothing to construct. Propagate the
            # real cause instead of failing below with UnboundLocalError.
            raise

        agent_builder = builder.default(**agent_builder_params)
        env_builder = EnvironmentBuilder(environment_name, environment_params)

        exp = BenchmarkExperiment(agent_builder, env_builder, logger)

        return exp
Ejemplo n.º 2
0
    def from_path(cls, path):
        """
        Create a BenchmarkVisualizer from a path.

        The last component of ``path`` becomes the log id and its parent
        directory becomes the log directory. ``False`` is passed as the third
        positional argument of BenchmarkLogger (presumably ``use_timestamp``,
        so no timestamp is appended; confirm against BenchmarkLogger).

        Args:
            path (str or os.PathLike): location of the benchmark log.

        Returns:
            A new instance of ``cls`` built around that BenchmarkLogger.

        """
        log_path = Path(path)
        log_dir = log_path.parent
        log_id = log_path.name
        benchmark_logger = BenchmarkLogger(log_dir, log_id, False)
        return cls(benchmark_logger)
Ejemplo n.º 3
0
    def __init__(self,
                 log_dir=None,
                 log_id=None,
                 use_timestamp=True,
                 parallel=None,
                 slurm=None):
        """
        Set up an empty benchmark and the logger it writes to.

        Args:
            log_dir (str): path to the log directory (Default: ./logs or /work/scratch/$USER)
            log_id (str): log id (Default: benchmark[_YYYY-mm-dd-HH-MM-SS])
            use_timestamp (bool): select if a timestamp should be appended to the log id
            parallel (dict, None): parameters that are passed to the run_parallel method of the experiment
            slurm (dict, None): parameters that are passed to the run_slurm method of the experiment

        """
        # Empty containers for the benchmark's structure, environments and agents.
        self._experiment_structure = {}
        self._environment_dict = {}
        self._agent_list = list()

        # Execution options are stored unchanged.
        self._parallel, self._slurm = parallel, slurm

        self.logger = BenchmarkLogger(
            log_dir=log_dir,
            log_id=log_id,
            use_timestamp=use_timestamp,
        )
Ejemplo n.º 4
0
    def __init__(self,
                 log_dir=None,
                 log_id=None,
                 use_timestamp=True,
                 **run_params):
        """
        Set up an empty benchmark, remember the run parameters and create the logger.

        Kwargs:
            log_dir (str): path to the log directory (Default: ./logs or /work/scratch/$USER)
            log_id (str): log id (Default: benchmark[_YY-mm-ddTHH:MM:SS.zzz])
            use_timestamp (bool): select if a timestamp should be appended to the log id
            **run_params (dict): parameters that are passed to the run method of the experiment
        """
        # Empty containers for the experiment structure, environments and agents.
        self.experiment_structure = {}
        self.environment_list = list()
        self.agent_list = list()

        # Extra keyword arguments are kept as-is for the experiment run.
        self.run_params = run_params

        self.logger = BenchmarkLogger(
            log_dir=log_dir,
            log_id=log_id,
            use_timestamp=use_timestamp,
        )