Ejemplo n.º 1
0
 def get_ground_truth(self):
     """Return the safety ground truth for ``self.env``, cached on disk.

     Side effect: sets ``self.ground_truth_path`` to
     ``self.local_models_path / 'safety_ground_truth.npz'``.

     If that file exists, it is loaded with ``SafetyTruth.load``. If the
     file is missing, or loading raises ``ValueError``, a new
     ``SafetyTruth(self.env)`` is computed, saved to the same path, and
     returned. Any other exception raised while loading propagates.
     """
     self.ground_truth_path = self.local_models_path / 'safety_ground_truth.npz'
     if self.ground_truth_path.exists():
         try:
             return SafetyTruth.load(self.ground_truth_path, self.env)
         except ValueError:
             # Loading failed with ValueError: fall through and rebuild the
             # cached file from scratch.
             pass
     truth = SafetyTruth(self.env)
     truth.compute()
     truth.save(self.ground_truth_path)
     return truth
Ejemplo n.º 2
0
class SafetyTruthComputation(TruthComputationSimulation):
    """Simulation that computes and saves the ground-truth viable set.

    The environment is built with the given discretization, wrapped in a
    ``SafetyTruth``, and ``run`` fills that truth in from a previously
    computed transition map (``<Q_map_name(env_name)>.npy`` inside
    ``self.output_directory``).
    """

    def __init__(self, name, env_name, discretization_shape, *args, **kwargs):
        """Build the environment and the (not yet computed) safety truth.

        Args:
            name: Simulation name, forwarded to ``TruthComputationSimulation``.
            env_name: Environment identifier. Only ``'cartpole'`` is
                supported; it is built with ``ContinuousCartPole``.
            discretization_shape: Forwarded to the environment constructor
                as the ``discretization_shape`` keyword.
            *args: Extra positional arguments for the environment constructor.
            **kwargs: Extra keyword arguments for the environment constructor.

        Raises:
            ValueError: If ``env_name`` is not supported.
        """
        if env_name == 'cartpole':
            env_builder = ContinuousCartPole
        else:
            raise ValueError(f'Environment {env_name} is not supported')
        # The directory containing this module is handed to the base class;
        # ``self.output_directory`` used below is presumably derived from it
        # by the base class (not visible here -- confirm).
        output_directory = Path(__file__).parent.resolve()
        super().__init__(output_directory, name, safety_name(env_name))

        self.env = env_builder(discretization_shape=discretization_shape,
                               *args,
                               **kwargs)
        self.truth = SafetyTruth(self.env)

        # Precomputed transition map consumed by ``run``.
        self.Q_map_path = self.output_directory / (str(Q_map_name(env_name)) +
                                                   '.npy')
        self.save_path = self.output_directory / safety_name(env_name)

        logger.info(config_msg(f"env_name='{env_name}'"))
        logger.info(
            config_msg(f"discretization_shape='{discretization_shape}'"))
        logger.info(config_msg(f"args={args}"))
        logger.info(config_msg(f"kwargs={kwargs}"))

    def run(self):
        """Compute the viable set from the transition map and save it.

        Logs the elapsed computation time, then saves the result to
        ``self.save_path``.

        Raises:
            FileNotFoundError: If no transition map exists at
                ``self.Q_map_path``; it must be computed beforehand.
        """
        logger.info('Launched computation of viable set')
        if not self.Q_map_path.exists():
            errormsg = ('The transition map could not be found at '
                        f'{self.Q_map_path}. Please compute it first.')
            logger.critical(errormsg)
            raise FileNotFoundError(errormsg)
        # perf_counter is monotonic and high-resolution, so the measured
        # duration is not distorted by system-clock adjustments the way a
        # difference of time.time() values can be.
        tick = time.perf_counter()
        self.truth.compute(self.Q_map_path)
        elapsed = time.perf_counter() - tick
        logger.info(f'Done in {elapsed:.2f} s.')
        self.truth.save(str(self.save_path))
        logger.info(f'Output saved in {self.save_path}')