def _get_algorithm_kwargs(self, variant):
    """Extend the base algorithm kwargs with classifier / goal-example entries.

    Dispatches on ``variant['algorithm_params']['type']``:

    * Single-goal classifier algorithms get a ``classifier`` plus train /
      validation goal examples (and, for the dynamics-aware variants, a
      dynamics model or a distance function).
    * Multi-goal algorithms get one classifier per goal pool plus the pools.
    * SQIL gets goal transitions instead of a classifier.

    Args:
        variant: experiment configuration dict; must contain
            ``algorithm_params.type``.

    Returns:
        dict of keyword arguments for the algorithm constructor, built on top
        of ``super()._get_algorithm_kwargs(variant)``.
    """
    algorithm_kwargs = super()._get_algorithm_kwargs(variant)
    algorithm_type = variant['algorithm_params']['type']

    # TODO: Replace this with a common API for single vs multigoal
    # === SINGLE GOAL POOL ===
    if algorithm_type in ('SACClassifier', 'RAQ', 'VICE', 'VICEGAN',
                          'VICERAQ', 'VICEDynamicsAware',
                          'DynamicsAwareEmbeddingVICE'):
        # NOTE(review): classifier is built from self._variant while the goal
        # examples below use the `variant` argument — presumably equivalent
        # here, but confirm before unifying.
        reward_classifier = self.reward_classifier = (
            get_reward_classifier_from_variant(
                self._variant, algorithm_kwargs['training_environment']))
        algorithm_kwargs['classifier'] = reward_classifier

        goal_examples_train, goal_examples_validation = (
            get_goal_example_from_variant(variant))
        algorithm_kwargs['goal_examples'] = goal_examples_train
        algorithm_kwargs['goal_examples_validation'] = (
            goal_examples_validation)

        if algorithm_type == 'VICEDynamicsAware':
            algorithm_kwargs['dynamics_model'] = (
                get_dynamics_model_from_variant(
                    self._variant, algorithm_kwargs['training_environment']))
        elif algorithm_type == 'DynamicsAwareEmbeddingVICE':
            # TODO(justinvyu): Get this working for any environment
            # Reuse the classifier's state-observation preprocessor as the
            # embedding distance function.
            self.distance_fn = algorithm_kwargs['distance_fn'] = (
                reward_classifier.
                observations_preprocessors['state_observation'])
            # TODO(justinvyu): include goal state as one of the VICE goal examples?
            algorithm_kwargs['goal_state'] = None

    # === LOAD GOAL POOLS FOR MULTI GOAL ===
    elif algorithm_type in ('VICEGANMultiGoal', 'MultiVICEGAN'):
        goal_pools_train, goal_pools_validation = (
            get_example_pools_from_variant(variant))
        num_goals = len(goal_pools_train)

        # One independent classifier per goal pool.
        reward_classifiers = self.reward_classifiers = tuple(
            get_reward_classifier_from_variant(
                variant, algorithm_kwargs['training_environment'])
            for _ in range(num_goals))

        algorithm_kwargs['classifiers'] = reward_classifiers
        algorithm_kwargs['goal_example_pools'] = goal_pools_train
        algorithm_kwargs[
            'goal_example_validation_pools'] = goal_pools_validation

    elif algorithm_type == 'SQIL':
        goal_transitions = get_goal_transitions_from_variant(variant)
        algorithm_kwargs['goal_transitions'] = goal_transitions

    return algorithm_kwargs
def _build(self):
    """Construct environments, networks, and the RL algorithm, then mark built."""
    variant = copy.deepcopy(self._variant)

    self.training_environment = (
        get_goal_example_environment_from_variant(variant))
    self.evaluation_environment = (
        get_goal_example_environment_from_variant(variant))
    self.replay_pool = get_replay_pool_from_variant(
        variant, self.training_environment)
    self.sampler = get_sampler_from_variant(variant)

    # Build the Q networks: dense nets taking [state, action] as inputs and
    # producing a single scalar output.
    self.Qs = get_Q_function_from_variant(variant, self.training_environment)
    self.policy = get_policy_from_variant(
        variant, self.training_environment, self.Qs)
    self.initial_exploration_policy = get_policy(
        'UniformPolicy', self.training_environment)

    algorithm_kwargs = {
        'variant': self._variant,
        'training_environment': self.training_environment,
        'evaluation_environment': self.evaluation_environment,
        'policy': self.policy,
        'initial_exploration_policy': self.initial_exploration_policy,
        'Qs': self.Qs,
        'pool': self.replay_pool,
        'sampler': self.sampler,
        'session': self._session,
    }

    # Classifier-based algorithms additionally need a reward classifier and
    # train/validation goal example sets.
    classifier_algorithms = (
        'SACClassifier', 'RAQ', 'VICE', 'VICEGAN', 'VICERAQ')
    if self._variant['algorithm_params']['type'] in classifier_algorithms:
        self.reward_classifier = get_reward_classifier_from_variant(
            self._variant, self.training_environment)
        algorithm_kwargs['classifier'] = self.reward_classifier

        train_goals, validation_goals = get_goal_example_from_variant(variant)
        algorithm_kwargs['goal_examples'] = train_goals
        algorithm_kwargs['goal_examples_validation'] = validation_goals

    self.algorithm = get_algorithm_from_variant(**algorithm_kwargs)
    initialize_tf_variables(self._session, only_uninitialized=True)

    self._built = True
def _restore_algorithm_kwargs(self, picklable, checkpoint_dir, variant):
    """Rebuild algorithm kwargs from checkpointed state.

    Starts from the base class restoration, then re-attaches the pickled
    reward classifier (with fresh goal examples) and/or the pickled distance
    estimator when they are present in the checkpoint payload.

    Args:
        picklable: dict of unpickled checkpoint state.
        checkpoint_dir: directory the checkpoint was loaded from.
        variant: experiment configuration dict.

    Returns:
        dict of keyword arguments for the algorithm constructor.
    """
    algorithm_kwargs = super()._restore_algorithm_kwargs(
        picklable, checkpoint_dir, variant)

    if 'reward_classifier' in picklable:
        self.reward_classifier = picklable['reward_classifier']
        train_goals, validation_goals = get_goal_example_from_variant(variant)
        algorithm_kwargs.update({
            'classifier': self.reward_classifier,
            'goal_examples': train_goals,
            'goal_examples_validation': validation_goals,
        })

    if 'distance_estimator' in picklable:
        self.distance_fn = picklable['distance_estimator']
        algorithm_kwargs.update({
            'distance_fn': self.distance_fn,
            'goal_state': None,
        })

    return algorithm_kwargs
def _build(self):
    """Construct the environment, networks, replay pool, and algorithm.

    Builds everything from a deep copy of ``self._variant``, wires the pieces
    into ``algorithm_kwargs``, instantiates the algorithm, initializes any
    uninitialized TF variables, and marks the runner as built.
    """
    variant = copy.deepcopy(self._variant)
    # FIX: removed leftover debug statement `print(variant.keys())`.

    env = self.env = get_environment_from_params(
        variant['environment_params']['training'])
    replay_pool = self.replay_pool = get_replay_pool_from_variant(
        variant, env)
    sampler = self.sampler = get_sampler_from_variant(variant)
    Qs = self.Qs = get_Q_function_from_variant(variant, env)
    policy = self.policy = get_policy_from_variant(variant, env, Qs)
    initial_exploration_policy = self.initial_exploration_policy = (
        get_policy('UniformPolicy', env))

    algorithm_kwargs = {
        'variant': self._variant,
        'env': self.env,
        'policy': policy,
        'initial_exploration_policy': initial_exploration_policy,
        'Qs': Qs,
        'pool': replay_pool,
        'sampler': sampler,
        'session': self._session,
    }

    # Classifier-based algorithms additionally need a reward classifier and
    # train/validation goal example sets.
    if self._variant['algorithm_params']['type'] in CLASSIFIER_RL_ALGS:
        reward_classifier = self.reward_classifier = (
            get_reward_classifier_from_variant(self._variant, env))
        algorithm_kwargs['classifier'] = reward_classifier

        goal_examples_train, goal_examples_validation = (
            get_goal_example_from_variant(variant))
        algorithm_kwargs['goal_examples'] = goal_examples_train
        algorithm_kwargs['goal_examples_validation'] = (
            goal_examples_validation)

    self.algorithm = get_algorithm_from_variant(**algorithm_kwargs)
    initialize_tf_variables(self._session, only_uninitialized=True)

    self._built = True
def _restore(self, checkpoint_dir):
    """Restore the full trainer state from a checkpoint directory.

    Unpickles environments, sampler, and Q functions; rebuilds the replay
    pool and policy from the variant (restoring only the policy weights);
    reconstructs the algorithm and restores its internal state plus the TF
    checkpoint; finally syncs target Q weights from the restored Qs.

    Args:
        checkpoint_dir: path to the checkpoint directory (trailing slash ok).
    """
    assert isinstance(checkpoint_dir, str), checkpoint_dir
    checkpoint_dir = checkpoint_dir.rstrip('/')

    with self._session.as_default():
        pickle_path = self._pickle_path(checkpoint_dir)
        # NOTE: pickle.load executes arbitrary code; only restore
        # checkpoints from trusted sources.
        with open(pickle_path, 'rb') as f:
            picklable = pickle.load(f)

        training_environment = self.training_environment = picklable[
            'training_environment']
        evaluation_environment = self.evaluation_environment = picklable[
            'evaluation_environment']

        # The replay pool is rebuilt fresh; its contents are restored
        # separately below only if checkpoint_replay_pool was enabled.
        replay_pool = self.replay_pool = get_replay_pool_from_variant(
            self._variant, training_environment)

        if self._variant['run_params'].get('checkpoint_replay_pool', False):
            self._restore_replay_pool(checkpoint_dir)

        sampler = self.sampler = picklable['sampler']
        Qs = self.Qs = picklable['Qs']
        # The policy object is rebuilt from the variant rather than
        # unpickled directly; only its weights come from the checkpoint.
        policy = self.policy = get_policy_from_variant(
            self._variant, training_environment, Qs)
        self.policy.set_weights(picklable['policy_weights'])
        initial_exploration_policy = self.initial_exploration_policy = (
            get_policy('UniformPolicy', training_environment))

        algorithm_kwargs = {
            'variant': self._variant,
            'training_environment': self.training_environment,
            'evaluation_environment': self.evaluation_environment,
            'policy': policy,
            'initial_exploration_policy': initial_exploration_policy,
            'Qs': Qs,
            'pool': replay_pool,
            'sampler': sampler,
            'session': self._session,
        }

        if self._variant['algorithm_params']['type'] in [
                'SACClassifier', 'RAQ', 'VICE', 'VICEGAN', 'VICERAQ']:
            reward_classifier = self.reward_classifier = picklable[
                'reward_classifier']
            algorithm_kwargs['classifier'] = reward_classifier

            # FIX: was `get_goal_example_from_variant(variant)` — `variant`
            # is undefined in this method (NameError); use self._variant.
            goal_examples_train, goal_examples_validation = (
                get_goal_example_from_variant(self._variant))
            algorithm_kwargs['goal_examples'] = goal_examples_train
            algorithm_kwargs['goal_examples_validation'] = (
                goal_examples_validation)

        self.algorithm = get_algorithm_from_variant(**algorithm_kwargs)
        self.algorithm.__setstate__(picklable['algorithm'].__getstate__())

        tf_checkpoint = self._get_tf_checkpoint()
        status = tf_checkpoint.restore(
            tf.train.latest_checkpoint(
                os.path.split(
                    self._tf_checkpoint_prefix(checkpoint_dir))[0]))
        status.assert_consumed().run_restore_ops(self._session)
        initialize_tf_variables(self._session, only_uninitialized=True)

        # TODO(hartikainen): target Qs should either be checkpointed or
        # pickled.
        for Q, Q_target in zip(
                self.algorithm._Qs, self.algorithm._Q_targets):
            Q_target.set_weights(Q.get_weights())

    self._built = True