def evaluate_model(self):
    """Evaluate the simulated environment and write the resulting summaries.

    Re-initializes ``self._sim_env``, draws one trajectory per simulated
    batch slot (sampling with replacement via ``random.choice``) and hands
    them to ``simple.evaluate_model``. When that call returns something other
    than ``None``, each returned item is written as a scalar under the
    ``simple/`` prefix, the plot is written, and summaries are flushed.
    The wall-clock duration is logged at verbosity level 1 either way.
    """
    logging.info('SimPLe epoch [% 6d]: evaluating model.', self._simple_epoch)
    started_at = time.time()

    self._sim_env.initialize(
        batch_size=self._simulated_batch_size,
        history_stream=itertools.repeat(None),
    )

    # If we have any trajectories collected in this run, evaluate on them.
    # Otherwise, use the initial dataset.
    use_initial_data = not self._has_own_data
    _, eval_pool = self._load_trajectories(initial=use_initial_data)

    # One independently drawn trajectory per slot of the simulated batch.
    sampled = []
    for _ in range(self._sim_env.batch_size):
        sampled.append(random.choice(eval_pool))

    summaries = simple.evaluate_model(self._sim_env, sampled, plt)
    if summaries is not None:
        for metric_name, metric_value in summaries.items():
            tag = 'simple/{}'.format(metric_name)
            self._summary_writer.scalar(
                tag, metric_value, step=self._simple_epoch)
        self._summary_writer.plot(
            'simple/model_eval_plot', plt, step=self._simple_epoch)
        self.flush_summaries()

    elapsed = time.time() - started_at
    logging.vlog(1, 'Evaluating model took %0.2f sec.', elapsed)
def test_fails_to_evaluate_model_with_matrix_observation_space(self):
    """Checks that `simple.evaluate_model` returns None for a (2, 2) Box space."""
    with backend.use_backend('numpy'):
        matrix_space = gym.spaces.Box(shape=(2, 2), low=0, high=1)
        single_action_space = gym.spaces.Discrete(n=1)
        env = self._make_env(  # pylint: disable=no-value-for-parameter
            observation_space=matrix_space,
            action_space=single_action_space,
            max_trajectory_length=2,
            batch_size=1,
        )

        # A single trajectory holding one 2x2 observation and one action.
        observations = np.array([[0, 1], [2, 3]])
        actions = np.array([0])
        trajectories = [self._make_trajectory(observations, actions)]

        result = simple.evaluate_model(env, trajectories, plt)
        self.assertIsNone(result)
def test_evaluates_model_with_vector_observation_space(self):
    """Checks that `simple.evaluate_model` yields two metrics for a (2,) Box space."""
    with backend.use_backend('numpy'):
        vector_space = gym.spaces.Box(shape=(2,), low=0, high=1)
        single_action_space = gym.spaces.Discrete(n=1)
        env = self._make_env(  # pylint: disable=no-value-for-parameter
            observation_space=vector_space,
            action_space=single_action_space,
            max_trajectory_length=2,
            batch_size=3,
        )

        # (observations, actions) pairs with one, two and three steps each.
        raw_pairs = [
            (np.array([[0, 1]]), np.array([0])),
            (np.array([[1, 2], [3, 4]]), np.array([0, 0])),
            (np.array([[1, 2], [3, 4], [5, 6]]), np.array([0, 0, 0])),
        ]
        trajectories = [
            self._make_trajectory(obs, acts) for (obs, acts) in raw_pairs
        ]

        result = simple.evaluate_model(env, trajectories, plt)
        self.assertIsNotNone(result)
        self.assertEqual(len(result), 2)