def _restore(self, checkpoint_dir):
    """Restore full training state from a checkpoint directory.

    Rebuilds the environments, replay pool, sampler, Q-functions, policy
    and algorithm from a pickled snapshot plus a TF checkpoint, then
    re-syncs the target Q-networks from the restored online Q-networks.

    Args:
        checkpoint_dir: Path to the checkpoint directory (trailing
            slashes are ignored).

    Raises:
        TypeError: If `checkpoint_dir` is not a string.
    """
    # Validate explicitly rather than with `assert`, which is stripped
    # when Python runs with optimizations (-O).
    if not isinstance(checkpoint_dir, str):
        raise TypeError(
            'checkpoint_dir must be a str, got: {!r}'.format(checkpoint_dir))
    checkpoint_dir = checkpoint_dir.rstrip('/')

    with self._session.as_default():
        pickle_path = self._pickle_path(checkpoint_dir)
        # NOTE(security): pickle.load can execute arbitrary code; only
        # restore checkpoints from trusted sources.
        with open(pickle_path, 'rb') as f:
            picklable = pickle.load(f)

        training_environment = self.training_environment = picklable[
            'training_environment']
        evaluation_environment = self.evaluation_environment = picklable[
            'evaluation_environment']

        replay_pool = self.replay_pool = (
            get_replay_pool_from_variant(self._variant, training_environment))

        # The pool contents are only checkpointed when explicitly enabled.
        if self._variant['run_params'].get('checkpoint_replay_pool', False):
            self._restore_replay_pool(checkpoint_dir)

        sampler = self.sampler = picklable['sampler']
        Qs = self.Qs = get_Q_function_from_variant(
            self._variant, training_environment)
        self._restore_value_functions(checkpoint_dir)
        policy = self.policy = (
            get_policy_from_variant(self._variant, training_environment))
        self.policy.set_weights(picklable['policy_weights'])
        initial_exploration_policy = self.initial_exploration_policy = (
            get_policy_from_params(
                self._variant['exploration_policy_params'],
                training_environment))

        self.algorithm = get_algorithm_from_variant(
            variant=self._variant,
            training_environment=training_environment,
            evaluation_environment=evaluation_environment,
            policy=policy,
            initial_exploration_policy=initial_exploration_policy,
            Qs=Qs,
            pool=replay_pool,
            sampler=sampler,
            session=self._session)
        # Transfer the pickled algorithm's internal state onto the freshly
        # constructed algorithm instance.
        self.algorithm.__setstate__(picklable['algorithm'].__getstate__())

        tf_checkpoint = self._get_tf_checkpoint()
        status = tf_checkpoint.restore(tf.train.latest_checkpoint(
            os.path.split(self._tf_checkpoint_prefix(checkpoint_dir))[0]))
        # Fail loudly if any checkpointed variable was not matched.
        status.assert_consumed().run_restore_ops(self._session)
        initialize_tf_variables(self._session, only_uninitialized=True)

        # TODO(hartikainen): target Qs should either be checkpointed or
        # pickled.
        for Q, Q_target in zip(
                self.algorithm._Qs, self.algorithm._Q_targets):
            Q_target.set_weights(Q.get_weights())

        self._built = True
def run_experiment(variant, reporter):
    """Train on the MultiGoal environment, reporting diagnostics.

    Builds the training/evaluation environment pair, a replay pool,
    sampler, Q-functions, policy and a Q-value/policy plotter, then runs
    the algorithm's training loop and forwards each result dict to
    `reporter`.
    """
    multigoal_params = {
        'actuation_cost_coeff': 30,
        'distance_cost_coeff': 1,
        'goal_reward': 10,
        'init_sigma': 0.1,
    }
    training_environment = get_environment(
        'gym', 'MultiGoal', 'Default-v0', multigoal_params)
    evaluation_environment = training_environment.copy()

    replay_pool = SimpleReplayPool(
        environment=training_environment,
        max_size=1e6)
    sampler = SimpleSampler(max_path_length=30)

    Qs = get_Q_function_from_variant(variant, training_environment)
    policy = get_policy_from_variant(variant, training_environment)

    # Fixed probe states at which the plotter visualizes Q-values and
    # policy samples.
    probe_observations = np.array(
        ((-2.5, 0.0), (0.0, 0.0), (2.5, 2.5), (-2.5, -2.5)))
    plotter = QFPolicyPlotter(
        Q=Qs[0],
        policy=policy,
        obs_lst=probe_observations,
        default_action=(np.nan, np.nan),
        n_samples=100)

    algorithm = get_algorithm_from_variant(
        variant=variant,
        training_environment=training_environment,
        evaluation_environment=evaluation_environment,
        policy=policy,
        Qs=Qs,
        pool=replay_pool,
        sampler=sampler,
        min_pool_size=100,
        batch_size=46,
        plotter=plotter,
    )

    initialize_tf_variables(algorithm._session, only_uninitialized=True)

    for diagnostics in algorithm.train():
        reporter(**diagnostics)
def _build(self):
    """Construct all training components from `self._variant`.

    Builds the environments, replay pool, sampler, Q-functions, policy,
    exploration policy and algorithm, initializes any uninitialized TF
    variables, and marks the runner as built.

    Works on a deep copy of the variant so that construction-time
    mutations by the factory functions cannot leak back into
    `self._variant`.
    """
    variant = copy.deepcopy(self._variant)

    environment_params = variant['environment_params']
    training_environment = self.training_environment = (
        get_environment_from_params(environment_params['training']))
    # Fall back to the training environment when no dedicated
    # evaluation environment is configured.
    evaluation_environment = self.evaluation_environment = (
        get_environment_from_params(environment_params['evaluation'])
        if 'evaluation' in environment_params
        else training_environment)

    replay_pool = self.replay_pool = (
        get_replay_pool_from_variant(variant, training_environment))
    sampler = self.sampler = get_sampler_from_variant(variant)
    Qs = self.Qs = get_Q_function_from_variant(
        variant, training_environment)
    policy = self.policy = get_policy_from_variant(
        variant, training_environment)
    initial_exploration_policy = self.initial_exploration_policy = (
        get_policy_from_params(
            variant['exploration_policy_params'], training_environment))

    # Fix: previously this passed the original `self._variant` while every
    # other factory received the deep copy, defeating the copy's purpose;
    # pass the copy consistently.
    self.algorithm = get_algorithm_from_variant(
        variant=variant,
        training_environment=training_environment,
        evaluation_environment=evaluation_environment,
        policy=policy,
        initial_exploration_policy=initial_exploration_policy,
        Qs=Qs,
        pool=replay_pool,
        sampler=sampler,
        session=self._session)

    initialize_tf_variables(self._session, only_uninitialized=True)

    self._built = True