def setUp(self):
    """Build the shared fixtures: a Swimmer-v3 gym env, a uniform
    exploration policy over it, a small replay pool, and a remote sampler
    with short paths so tests stay fast."""
    self.env = get_environment('gym', 'Swimmer', 'v3', {})
    self.policy = get_policy_from_params(
        {'type': 'UniformPolicy'}, env=self.env)
    self.pool = SimpleReplayPool(max_size=100, environment=self.env)
    self.remote_sampler = RemoteSampler(
        max_path_length=10,
        min_pool_size=10,
        batch_size=10)
def _restore(self, checkpoint_dir):
    """Restore the full experiment state from `checkpoint_dir`.

    Environments and the sampler come from the pickle; the replay pool,
    Q-functions, and policy are rebuilt from the variant and then have
    their saved weights loaded. TF variables are restored from the
    latest tf checkpoint under the same directory.
    """
    assert isinstance(checkpoint_dir, str), checkpoint_dir
    checkpoint_dir = checkpoint_dir.rstrip('/')

    with self._session.as_default():
        pickle_path = self._pickle_path(checkpoint_dir)
        with open(pickle_path, 'rb') as f:
            picklable = pickle.load(f)

        training_environment = self.training_environment = (
            picklable['training_environment'])
        evaluation_environment = self.evaluation_environment = (
            picklable['evaluation_environment'])

        # Pool is rebuilt from the variant; its contents are only
        # restored when the run checkpointed them.
        replay_pool = self.replay_pool = get_replay_pool_from_variant(
            self._variant, training_environment)
        if self._variant['run_params'].get('checkpoint_replay_pool', False):
            self._restore_replay_pool(checkpoint_dir)

        sampler = self.sampler = picklable['sampler']

        Qs = self.Qs = get_Q_function_from_variant(
            self._variant, training_environment)
        self._restore_value_functions(checkpoint_dir)

        policy = self.policy = get_policy_from_variant(
            self._variant, training_environment)
        self.policy.set_weights(picklable['policy_weights'])

        initial_exploration_policy = self.initial_exploration_policy = (
            get_policy_from_params(
                self._variant['exploration_policy_params'],
                training_environment))

        self.algorithm = get_algorithm_from_variant(
            variant=self._variant,
            training_environment=training_environment,
            evaluation_environment=evaluation_environment,
            policy=policy,
            initial_exploration_policy=initial_exploration_policy,
            Qs=Qs,
            pool=replay_pool,
            sampler=sampler,
            session=self._session)
        self.algorithm.__setstate__(picklable['algorithm'].__getstate__())

        tf_checkpoint = self._get_tf_checkpoint()
        status = tf_checkpoint.restore(
            tf.train.latest_checkpoint(
                os.path.split(
                    self._tf_checkpoint_prefix(checkpoint_dir))[0]))
        status.assert_consumed().run_restore_ops(self._session)
        initialize_tf_variables(self._session, only_uninitialized=True)

        # TODO(hartikainen): target Qs should either be checkpointed or
        # pickled.
        for Q, Q_target in zip(
                self.algorithm._Qs, self.algorithm._Q_targets):
            Q_target.set_weights(Q.get_weights())

        self._built = True
def _build(self):
    """Construct environments, pool, sampler, Qs, policy, and algorithm
    from scratch according to `self._variant`, then initialize any
    uninitialized TF variables."""
    variant = copy.deepcopy(self._variant)
    environment_params = variant['environment_params']

    training_environment = self.training_environment = (
        get_environment_from_params(environment_params['training']))
    # Fall back to the training env when no separate evaluation env is
    # configured.
    if 'evaluation' in environment_params:
        evaluation_environment = self.evaluation_environment = (
            get_environment_from_params(environment_params['evaluation']))
    else:
        evaluation_environment = self.evaluation_environment = (
            training_environment)

    seed = variant['run_params']['seed']
    training_environment.seed(seed)
    evaluation_environment.seed(seed)

    replay_pool = self.replay_pool = get_replay_pool_from_variant(
        variant, training_environment)
    sampler = self.sampler = get_sampler_from_variant(variant)
    Qs = self.Qs = get_Q_function_from_variant(
        variant, training_environment)
    policy = self.policy = get_policy_from_variant(
        variant, training_environment)
    initial_exploration_policy = self.initial_exploration_policy = (
        get_policy_from_params(
            variant['exploration_policy_params'], training_environment))

    self.algorithm = get_algorithm_from_variant(
        variant=self._variant,
        training_environment=training_environment,
        evaluation_environment=evaluation_environment,
        policy=policy,
        initial_exploration_policy=initial_exploration_policy,
        Qs=Qs,
        pool=replay_pool,
        sampler=sampler,
        session=self._session)

    initialize_tf_variables(self._session, only_uninitialized=True)
    self._built = True
def _build(self):
    """Construct the VICE-style experiment from `self._variant`.

    Builds training/evaluation environments via `GymAdapter` (sharing
    `env_params`, differing only in task), the replay pool, sampler,
    Q-functions, policy, and exploration policy, then assembles the
    algorithm. Goal-conditioned VICE variants additionally get a reward
    classifier and placeholder goal examples; RND networks are attached
    when configured.
    """
    variant = copy.deepcopy(self._variant)

    training_environment = self.training_environment = GymAdapter(
        domain=variant['domain'],
        task=variant['task'],
        **variant['env_params'])
    evaluation_environment = self.evaluation_environment = GymAdapter(
        domain=variant['domain'],
        task=variant['task_evaluation'],
        **variant['env_params'])

    # NOTE(review): the replay pool type comes from the variant; HER-style
    # goal relabeling (if intended) must be configured there — TODO confirm.
    replay_pool = self.replay_pool = get_replay_pool_from_variant(
        variant, training_environment)
    sampler = self.sampler = get_sampler_from_variant(variant)
    Qs = self.Qs = get_Q_function_from_variant(
        variant, training_environment)
    policy = self.policy = get_policy_from_variant(
        variant, training_environment)
    initial_exploration_policy = self.initial_exploration_policy = (
        get_policy_from_params(
            variant['exploration_policy_params'], training_environment))

    algorithm_kwargs = {
        'variant': self._variant,
        'training_environment': self.training_environment,
        'evaluation_environment': self.evaluation_environment,
        'policy': policy,
        'initial_exploration_policy': initial_exploration_policy,
        'Qs': Qs,
        'pool': replay_pool,
        'sampler': sampler,
        'session': self._session,
    }

    if self._variant['algorithm_params']['type'] in [
            'VICEGoalConditioned', 'VICEGANGoalConditioned']:
        reward_classifier = self.reward_classifier = (
            get_reward_classifier_from_variant(
                self._variant, training_environment))
        algorithm_kwargs['classifier'] = reward_classifier
        # Placeholder goal examples; real goal examples are collected
        # online by the goal-conditioned algorithm.
        algorithm_kwargs['goal_examples'] = np.empty((1, 1))
        algorithm_kwargs['goal_examples_validation'] = np.empty((1, 1))

    # RND exploration bonus. Use .get() so variants that predate the
    # `rnd_params` key don't raise KeyError.
    if variant['algorithm_params'].get('rnd_params'):
        from softlearning.rnd.utils import get_rnd_networks_from_variant
        rnd_networks = get_rnd_networks_from_variant(
            variant, training_environment)
    else:
        rnd_networks = ()
    algorithm_kwargs['rnd_networks'] = rnd_networks

    self.algorithm = get_algorithm_from_variant(**algorithm_kwargs)

    initialize_tf_variables(self._session, only_uninitialized=True)
    self._built = True
def _restore(self, checkpoint_dir):
    """Restore the VICE-style experiment state from `checkpoint_dir`.

    Environments, sampler, and Qs come from the pickle; the replay pool
    and policy are rebuilt from the variant (policy weights are then
    loaded from the pickle). Classifier-based algorithm types also get
    their pickled reward classifier. TF variables are restored from the
    latest tf checkpoint under the same directory.

    Fixes a NameError in the original: the exploration policy lookup
    referenced an undefined local `variant` instead of `self._variant`.
    """
    assert isinstance(checkpoint_dir, str), checkpoint_dir
    checkpoint_dir = checkpoint_dir.rstrip('/')

    with self._session.as_default():
        pickle_path = self._pickle_path(checkpoint_dir)
        with open(pickle_path, 'rb') as f:
            picklable = pickle.load(f)

        training_environment = self.training_environment = picklable[
            'training_environment']
        evaluation_environment = self.evaluation_environment = picklable[
            'evaluation_environment']

        replay_pool = self.replay_pool = (
            get_replay_pool_from_variant(
                self._variant, training_environment))

        if self._variant['run_params'].get('checkpoint_replay_pool', False):
            self._restore_replay_pool(checkpoint_dir)

        sampler = self.sampler = picklable['sampler']
        Qs = self.Qs = picklable['Qs']
        # The policy is rebuilt from the variant (not unpickled) so that
        # it is wired to the freshly-loaded Qs; weights come from the
        # pickle below.
        policy = self.policy = (
            get_policy_from_variant(
                self._variant, training_environment, Qs))
        self.policy.set_weights(picklable['policy_weights'])
        # BUG FIX: was `variant[...]` (undefined local) in the original.
        initial_exploration_policy = self.initial_exploration_policy = (
            get_policy_from_params(
                self._variant['exploration_policy_params'],
                training_environment))

        algorithm_kwargs = {
            'variant': self._variant,
            'training_environment': self.training_environment,
            'evaluation_environment': self.evaluation_environment,
            'policy': policy,
            'initial_exploration_policy': initial_exploration_policy,
            'Qs': Qs,
            'pool': replay_pool,
            'sampler': sampler,
            'session': self._session,
        }

        if self._variant['algorithm_params']['type'] in (
                'SACClassifier', 'RAQ', 'VICE', 'VICERAQ'):
            reward_classifier = self.reward_classifier = picklable[
                'reward_classifier']
            algorithm_kwargs['classifier'] = reward_classifier

        self.algorithm = get_algorithm_from_variant(**algorithm_kwargs)
        self.algorithm.__setstate__(picklable['algorithm'].__getstate__())

        tf_checkpoint = self._get_tf_checkpoint()
        status = tf_checkpoint.restore(
            tf.train.latest_checkpoint(
                os.path.split(
                    self._tf_checkpoint_prefix(checkpoint_dir))[0]))

        status.assert_consumed().run_restore_ops(self._session)
        initialize_tf_variables(self._session, only_uninitialized=True)

        # TODO(hartikainen): target Qs should either be checkpointed or
        # pickled.
        for Q, Q_target in zip(
                self.algorithm._Qs, self.algorithm._Q_targets):
            Q_target.set_weights(Q.get_weights())

        self._built = True