Example #1
    def _get_algorithm_kwargs(self, variant):
        algorithm_kwargs = super()._get_algorithm_kwargs(variant)
        algorithm_type = variant['algorithm_params']['type']

        # TODO: Replace this with a common API for single vs multigoal
        # === SINGLE GOAL POOL ===
        if algorithm_type in ('SACClassifier', 'RAQ', 'VICE', 'VICEGAN',
                              'VICERAQ', 'VICEDynamicsAware',
                              'DynamicsAwareEmbeddingVICE'):

            reward_classifier = self.reward_classifier = (
                get_reward_classifier_from_variant(
                    self._variant, algorithm_kwargs['training_environment']))
            algorithm_kwargs['classifier'] = reward_classifier

            goal_examples_train, goal_examples_validation = (
                get_goal_example_from_variant(variant))
            algorithm_kwargs['goal_examples'] = goal_examples_train
            algorithm_kwargs['goal_examples_validation'] = (
                goal_examples_validation)

            if algorithm_type == 'VICEDynamicsAware':
                algorithm_kwargs['dynamics_model'] = (
                    get_dynamics_model_from_variant(
                        self._variant,
                        algorithm_kwargs['training_environment']))

            elif algorithm_type == 'DynamicsAwareEmbeddingVICE':
                # TODO(justinvyu): Get this working for any environment
                self.distance_fn = algorithm_kwargs['distance_fn'] = (
                    reward_classifier.
                    observations_preprocessors['state_observation'])
                # TODO(justinvyu): include goal state as one of the VICE goal examples?
                algorithm_kwargs['goal_state'] = None

        # === LOAD GOAL POOLS FOR MULTI GOAL ===
        elif algorithm_type in ('VICEGANMultiGoal', 'MultiVICEGAN'):
            goal_pools_train, goal_pools_validation = (
                get_example_pools_from_variant(variant))
            num_goals = len(goal_pools_train)

            reward_classifiers = self.reward_classifiers = tuple(
                get_reward_classifier_from_variant(
                    variant, algorithm_kwargs['training_environment'])
                for _ in range(num_goals))

            algorithm_kwargs['classifiers'] = reward_classifiers
            algorithm_kwargs['goal_example_pools'] = goal_pools_train
            algorithm_kwargs[
                'goal_example_validation_pools'] = goal_pools_validation

        elif algorithm_type == 'SQIL':
            goal_transitions = get_goal_transitions_from_variant(variant)
            algorithm_kwargs['goal_transitions'] = goal_transitions

        return algorithm_kwargs
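The dispatch above keys entirely off the nested `variant['algorithm_params']['type']` entry. Below is a minimal sketch of that lookup, with a made-up variant dict standing in for the real softlearning configuration (which contains many more fields):

variant = {
    'algorithm_params': {
        'type': 'VICEGAN',  # hypothetical choice; any of the names checked above would work
    },
}
algorithm_type = variant['algorithm_params']['type']

if algorithm_type in ('SACClassifier', 'RAQ', 'VICE', 'VICEGAN',
                      'VICERAQ', 'VICEDynamicsAware',
                      'DynamicsAwareEmbeddingVICE'):
    print('single-goal classifier setup')
elif algorithm_type in ('VICEGANMultiGoal', 'MultiVICEGAN'):
    print('multi-goal classifier setup')
elif algorithm_type == 'SQIL':
    print('SQIL goal-transition setup')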
Example #2
    def _build(self):
        variant = copy.deepcopy(self._variant)

        training_environment = self.training_environment = (
            get_goal_example_environment_from_variant(variant))
        evaluation_environment = self.evaluation_environment = (
            get_goal_example_environment_from_variant(variant))
        replay_pool = self.replay_pool = (get_replay_pool_from_variant(
            variant, training_environment))
        sampler = self.sampler = get_sampler_from_variant(variant)
        # Build the dense Q-networks: inputs are [state, action], output size is 1
        Qs = self.Qs = get_Q_function_from_variant(variant,
                                                   training_environment)
        policy = self.policy = get_policy_from_variant(variant,
                                                       training_environment,
                                                       Qs)
        initial_exploration_policy = self.initial_exploration_policy = (
            get_policy('UniformPolicy', training_environment))

        algorithm_kwargs = {
            'variant': self._variant,
            'training_environment': self.training_environment,
            'evaluation_environment': self.evaluation_environment,
            'policy': policy,
            'initial_exploration_policy': initial_exploration_policy,
            'Qs': Qs,
            'pool': replay_pool,
            'sampler': sampler,
            'session': self._session,
        }

        if self._variant['algorithm_params']['type'] in [
                'SACClassifier', 'RAQ', 'VICE', 'VICEGAN', 'VICERAQ'
        ]:
            reward_classifier = self.reward_classifier \
                = get_reward_classifier_from_variant(self._variant, training_environment)
            algorithm_kwargs['classifier'] = reward_classifier

            goal_examples_train, goal_examples_validation = \
                get_goal_example_from_variant(variant)
            algorithm_kwargs['goal_examples'] = goal_examples_train
            algorithm_kwargs['goal_examples_validation'] = \
                goal_examples_validation

        self.algorithm = get_algorithm_from_variant(**algorithm_kwargs)

        initialize_tf_variables(self._session, only_uninitialized=True)

        self._built = True
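The `_build` methods in these examples all follow the same pattern: collect components into an `algorithm_kwargs` dict, extend it based on the algorithm type, then expand it with `**` into a single constructor call. Here is a self-contained sketch of that pattern, using placeholder values rather than the real softlearning objects:

def build_algorithm(policy, Qs, pool, sampler, classifier=None, **kwargs):
    # Stand-in for get_algorithm_from_variant: just echo what it received.
    return {'policy': policy, 'Qs': Qs, 'pool': pool,
            'sampler': sampler, 'classifier': classifier, **kwargs}

algorithm_kwargs = {
    'policy': 'gaussian_policy',
    'Qs': ('Q1', 'Q2'),
    'pool': 'replay_pool',
    'sampler': 'simple_sampler',
}
algorithm_kwargs['classifier'] = 'reward_classifier'  # added only for classifier-based variants
algorithm = build_algorithm(**algorithm_kwargs)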
Example #3
    def _restore_algorithm_kwargs(self, picklable, checkpoint_dir, variant):
        algorithm_kwargs = super()._restore_algorithm_kwargs(picklable, checkpoint_dir, variant)

        if 'reward_classifier' in picklable.keys():
            reward_classifier = self.reward_classifier = picklable[
                'reward_classifier']

            algorithm_kwargs['classifier'] = reward_classifier

            goal_examples_train, goal_examples_validation = (
                get_goal_example_from_variant(variant))
            algorithm_kwargs['goal_examples'] = goal_examples_train
            algorithm_kwargs['goal_examples_validation'] = (
                goal_examples_validation)

        if 'distance_estimator' in picklable.keys():
            distance_fn = self.distance_fn = picklable['distance_estimator']
            algorithm_kwargs['distance_fn'] = distance_fn
            algorithm_kwargs['goal_state'] = None

        return algorithm_kwargs
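Restoring only re-attaches the components that were actually checkpointed, which is why the method probes `picklable` for optional keys. A tiny illustration of that membership check (for a dict, `key in picklable` is equivalent to `key in picklable.keys()`):

picklable = {'reward_classifier': object()}  # hypothetical checkpoint contents

if 'reward_classifier' in picklable:
    print('restoring reward classifier')
if 'distance_estimator' in picklable:
    print('restoring distance estimator')  # skipped here; the key is absent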
Example #4
    def _build(self):
        variant = copy.deepcopy(self._variant)
        print(variant.keys())  # debug: show the top-level variant keys
        env = self.env = get_environment_from_params(
            variant['environment_params']['training'])
        replay_pool = self.replay_pool = (get_replay_pool_from_variant(
            variant, env))
        sampler = self.sampler = get_sampler_from_variant(variant)
        Qs = self.Qs = get_Q_function_from_variant(variant, env)
        policy = self.policy = get_policy_from_variant(variant, env, Qs)
        initial_exploration_policy = self.initial_exploration_policy = (
            get_policy('UniformPolicy', env))

        algorithm_kwargs = {
            'variant': self._variant,
            'env': self.env,
            'policy': policy,
            'initial_exploration_policy': initial_exploration_policy,
            'Qs': Qs,
            'pool': replay_pool,
            'sampler': sampler,
            'session': self._session,
        }

        if self._variant['algorithm_params']['type'] in CLASSIFIER_RL_ALGS:
            reward_classifier = self.reward_classifier \
                = get_reward_classifier_from_variant(self._variant, env)
            algorithm_kwargs['classifier'] = reward_classifier

            goal_examples_train, goal_examples_validation = \
                get_goal_example_from_variant(variant)
            algorithm_kwargs['goal_examples'] = goal_examples_train
            algorithm_kwargs['goal_examples_validation'] = \
                goal_examples_validation

        self.algorithm = get_algorithm_from_variant(**algorithm_kwargs)

        initialize_tf_variables(self._session, only_uninitialized=True)

        self._built = True
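Each `_build` starts by deep-copying `self._variant` so the `get_*_from_variant` helpers can consume or mutate the copy without touching the stored configuration. A minimal demonstration of why the deep copy matters:

import copy

original = {'environment_params': {'training': {'domain': 'Point2D'}}}
variant = copy.deepcopy(original)
variant['environment_params']['training']['domain'] = 'Hopper'

assert original['environment_params']['training']['domain'] == 'Point2D'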
Example #5
    def _restore(self, checkpoint_dir):
        assert isinstance(checkpoint_dir, str), checkpoint_dir

        checkpoint_dir = checkpoint_dir.rstrip('/')

        with self._session.as_default():
            pickle_path = self._pickle_path(checkpoint_dir)
            with open(pickle_path, 'rb') as f:
                picklable = pickle.load(f)

        training_environment = self.training_environment = picklable[
            'training_environment']
        evaluation_environment = self.evaluation_environment = picklable[
            'evaluation_environment']

        replay_pool = self.replay_pool = (get_replay_pool_from_variant(
            self._variant, training_environment))

        if self._variant['run_params'].get('checkpoint_replay_pool', False):
            self._restore_replay_pool(checkpoint_dir)

        sampler = self.sampler = picklable['sampler']
        Qs = self.Qs = picklable['Qs']
        # policy = self.policy = picklable['policy']
        policy = self.policy = (get_policy_from_variant(
            self._variant, training_environment, Qs))
        self.policy.set_weights(picklable['policy_weights'])
        initial_exploration_policy = self.initial_exploration_policy = (
            get_policy('UniformPolicy', training_environment))

        algorithm_kwargs = {
            'variant': self._variant,
            'training_environment': self.training_environment,
            'evaluation_environment': self.evaluation_environment,
            'policy': policy,
            'initial_exploration_policy': initial_exploration_policy,
            'Qs': Qs,
            'pool': replay_pool,
            'sampler': sampler,
            'session': self._session,
        }

        if self._variant['algorithm_params']['type'] in [
                'SACClassifier', 'RAQ', 'VICE', 'VICEGAN', 'VICERAQ'
        ]:
            reward_classifier = self.reward_classifier = picklable[
                'reward_classifier']
            algorithm_kwargs['classifier'] = reward_classifier

            goal_examples_train, goal_examples_validation = \
                get_goal_example_from_variant(self._variant)
            algorithm_kwargs['goal_examples'] = goal_examples_train
            algorithm_kwargs['goal_examples_validation'] = \
                goal_examples_validation

        self.algorithm = get_algorithm_from_variant(**algorithm_kwargs)
        self.algorithm.__setstate__(picklable['algorithm'].__getstate__())

        tf_checkpoint = self._get_tf_checkpoint()
        status = tf_checkpoint.restore(
            tf.train.latest_checkpoint(
                os.path.split(self._tf_checkpoint_prefix(checkpoint_dir))[0]))

        status.assert_consumed().run_restore_ops(self._session)
        initialize_tf_variables(self._session, only_uninitialized=True)

        # TODO(hartikainen): target Qs should either be checkpointed or pickled.
        for Q, Q_target in zip(self.algorithm._Qs, self.algorithm._Q_targets):
            Q_target.set_weights(Q.get_weights())

        self._built = True
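As the final TODO notes, the target Q-networks are neither checkpointed nor pickled, so after restoring they are hard-copied from the online Q-networks. A toy sketch of that sync step, with plain lists standing in for the Keras weight tensors accessed through `get_weights`/`set_weights`:

online_Qs = [[0.1, 0.2], [0.3, 0.4]]
target_Qs = [[0.0, 0.0], [0.0, 0.0]]

for Q, Q_target in zip(online_Qs, target_Qs):
    Q_target[:] = Q  # analogous to Q_target.set_weights(Q.get_weights())

assert target_Qs == online_Qs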