Example #1
    def setUp(self):
        self.env = get_environment('gym', 'Swimmer', 'v3', {})
        self.policy = get_policy_from_params({'type': 'UniformPolicy'},
                                             env=self.env)
        self.pool = SimpleReplayPool(max_size=100, environment=self.env)
        self.remote_sampler = RemoteSampler(max_path_length=10,
                                            min_pool_size=10,
                                            batch_size=10)
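
The setUp above only constructs the environment, policy, replay pool, and remote sampler; it does not connect them. A minimal sketch of a follow-up test method, assuming ray has already been initialized and that RemoteSampler follows softlearning's base sampler interface of initialize(env, policy, pool) (both assumptions, not shown in the example):

    def test_initialize(self):
        # Hypothetical continuation of the test case above: wire the remote
        # sampler to the objects built in setUp(). Assumes ray.init() has
        # already run and that initialize(env, policy, pool) matches the
        # base sampler interface.
        self.remote_sampler.initialize(self.env, self.policy, self.pool)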
Example #2
    def _restore(self, checkpoint_dir):
        assert isinstance(checkpoint_dir, str), checkpoint_dir

        checkpoint_dir = checkpoint_dir.rstrip('/')

        with self._session.as_default():
            pickle_path = self._pickle_path(checkpoint_dir)
            with open(pickle_path, 'rb') as f:
                picklable = pickle.load(f)

        training_environment = self.training_environment = picklable[
            'training_environment']
        evaluation_environment = self.evaluation_environment = picklable[
            'evaluation_environment']

        replay_pool = self.replay_pool = (get_replay_pool_from_variant(
            self._variant, training_environment))

        if self._variant['run_params'].get('checkpoint_replay_pool', False):
            self._restore_replay_pool(checkpoint_dir)

        sampler = self.sampler = picklable['sampler']
        Qs = self.Qs = get_Q_function_from_variant(self._variant,
                                                   training_environment)
        self._restore_value_functions(checkpoint_dir)
        policy = self.policy = (get_policy_from_variant(
            self._variant, training_environment))
        self.policy.set_weights(picklable['policy_weights'])
        initial_exploration_policy = self.initial_exploration_policy = (
            get_policy_from_params(self._variant['exploration_policy_params'],
                                   training_environment))

        self.algorithm = get_algorithm_from_variant(
            variant=self._variant,
            training_environment=training_environment,
            evaluation_environment=evaluation_environment,
            policy=policy,
            initial_exploration_policy=initial_exploration_policy,
            Qs=Qs,
            pool=replay_pool,
            sampler=sampler,
            session=self._session)
        self.algorithm.__setstate__(picklable['algorithm'].__getstate__())

        tf_checkpoint = self._get_tf_checkpoint()
        status = tf_checkpoint.restore(
            tf.train.latest_checkpoint(
                os.path.split(self._tf_checkpoint_prefix(checkpoint_dir))[0]))

        status.assert_consumed().run_restore_ops(self._session)
        initialize_tf_variables(self._session, only_uninitialized=True)

        # TODO(hartikainen): target Qs should either be checkpointed or pickled.
        for Q, Q_target in zip(self.algorithm._Qs, self.algorithm._Q_targets):
            Q_target.set_weights(Q.get_weights())

        self._built = True
Example #3
    def _build(self):
        variant = copy.deepcopy(self._variant)

        environment_params = variant['environment_params']
        training_environment = self.training_environment = (
            get_environment_from_params(environment_params['training']))
        evaluation_environment = self.evaluation_environment = (
            get_environment_from_params(environment_params['evaluation'])
            if 'evaluation' in environment_params
            else training_environment)

        training_environment.seed(variant['run_params']['seed'])
        evaluation_environment.seed(variant['run_params']['seed'])

        replay_pool = self.replay_pool = (
            get_replay_pool_from_variant(variant, training_environment))
        sampler = self.sampler = get_sampler_from_variant(variant)
        Qs = self.Qs = get_Q_function_from_variant(
            variant, training_environment)
        policy = self.policy = get_policy_from_variant(
            variant, training_environment)

        initial_exploration_policy = self.initial_exploration_policy = (
            get_policy_from_params(
                variant['exploration_policy_params'], training_environment))

        self.algorithm = get_algorithm_from_variant(
            variant=self._variant,
            training_environment=training_environment,
            evaluation_environment=evaluation_environment,
            policy=policy,
            initial_exploration_policy=initial_exploration_policy,
            Qs=Qs,
            pool=replay_pool,
            sampler=sampler,
            session=self._session)

        initialize_tf_variables(self._session, only_uninitialized=True)

        self._built = True
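
The _build methods are driven entirely by the variant dictionary. The sketch below lists the keys accessed in these examples (environment_params, run_params, exploration_policy_params, and so on); the concrete values are hypothetical placeholders rather than softlearning defaults:

# Hypothetical variant skeleton covering the keys read by the builders above.
# The values are illustrative placeholders, not softlearning defaults.
variant = {
    'environment_params': {
        'training': {
            'universe': 'gym',
            'domain': 'Swimmer',
            'task': 'v3',
            'kwargs': {},
        },
        # 'evaluation' is optional; the training environment is reused
        # when it is absent (see the _build above).
    },
    'run_params': {
        'seed': 0,
        'checkpoint_replay_pool': False,
    },
    'exploration_policy_params': {'type': 'UniformPolicy'},
    'algorithm_params': {'type': 'SAC'},
    # Consumed by get_policy_from_variant, get_Q_function_from_variant,
    # get_replay_pool_from_variant, and get_sampler_from_variant;
    # their contents depend on the softlearning version in use.
    'policy_params': {},
    'Q_params': {},
    'replay_pool_params': {},
    'sampler_params': {},
}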
Example #4
    def _build(self):
        variant = copy.deepcopy(self._variant)

        #training_environment = self.training_environment = (
        #    get_goal_example_environment_from_variant(
        #        variant['task'], gym_adapter=False))

        training_environment = self.training_environment = (GymAdapter(
            domain=variant['domain'],
            task=variant['task'],
            **variant['env_params']))

        #evaluation_environment = self.evaluation_environment = (
        #    get_goal_example_environment_from_variant(
        #        variant['task_evaluation'], gym_adapter=False))
        evaluation_environment = self.evaluation_environment = (GymAdapter(
            domain=variant['domain'],
            task=variant['task_evaluation'],
            **variant['env_params']))

        # training_environment = self.training_environment = (
        #     flatten_multiworld_env(self.training_environment))
        # evaluation_environment = self.evaluation_environment = (
        #     flatten_multiworld_env(self.evaluation_environment))
        #training_environment = self.training_environment = (
        #        GymAdapter(env=training_environment))
        #evaluation_environment = self.evaluation_environment = (
        #        GymAdapter(env=evaluation_environment))

        # make sure this is the HER replay pool
        replay_pool = self.replay_pool = (get_replay_pool_from_variant(
            variant, training_environment))
        sampler = self.sampler = get_sampler_from_variant(variant)
        Qs = self.Qs = get_Q_function_from_variant(variant,
                                                   training_environment)
        policy = self.policy = get_policy_from_variant(variant,
                                                       training_environment)
        initial_exploration_policy = self.initial_exploration_policy = (
            get_policy_from_params(variant['exploration_policy_params'],
                                   training_environment))

        algorithm_kwargs = {
            'variant': self._variant,
            'training_environment': self.training_environment,
            'evaluation_environment': self.evaluation_environment,
            'policy': policy,
            'initial_exploration_policy': initial_exploration_policy,
            'Qs': Qs,
            'pool': replay_pool,
            'sampler': sampler,
            'session': self._session,
        }

        if self._variant['algorithm_params']['type'] in [
                'VICEGoalConditioned', 'VICEGANGoalConditioned'
        ]:
            reward_classifier = self.reward_classifier = (
                get_reward_classifier_from_variant(self._variant,
                                                   training_environment))
            algorithm_kwargs['classifier'] = reward_classifier

            # goal_examples_train, goal_examples_validation = \
            #     get_goal_example_from_variant(variant)
            algorithm_kwargs['goal_examples'] = np.empty((1, 1))
            algorithm_kwargs['goal_examples_validation'] = np.empty((1, 1))

        # RND
        if variant['algorithm_params']['rnd_params']:
            from softlearning.rnd.utils import get_rnd_networks_from_variant
            rnd_networks = get_rnd_networks_from_variant(
                variant, training_environment)
        else:
            rnd_networks = ()
        algorithm_kwargs['rnd_networks'] = rnd_networks

        self.algorithm = get_algorithm_from_variant(**algorithm_kwargs)

        initialize_tf_variables(self._session, only_uninitialized=True)

        self._built = True
Example #5
    def _restore(self, checkpoint_dir):
        assert isinstance(checkpoint_dir, str), checkpoint_dir

        checkpoint_dir = checkpoint_dir.rstrip('/')

        with self._session.as_default():
            pickle_path = self._pickle_path(checkpoint_dir)
            with open(pickle_path, 'rb') as f:
                picklable = pickle.load(f)

        training_environment = self.training_environment = picklable[
            'training_environment']
        evaluation_environment = self.evaluation_environment = picklable[
            'evaluation_environment']

        replay_pool = self.replay_pool = (get_replay_pool_from_variant(
            self._variant, training_environment))

        if self._variant['run_params'].get('checkpoint_replay_pool', False):
            self._restore_replay_pool(checkpoint_dir)

        sampler = self.sampler = picklable['sampler']
        Qs = self.Qs = picklable['Qs']
        # policy = self.policy = picklable['policy']
        policy = self.policy = (get_policy_from_variant(
            self._variant, training_environment, Qs))
        self.policy.set_weights(picklable['policy_weights'])

        initial_exploration_policy = self.initial_exploration_policy = (
            get_policy_from_params(self._variant['exploration_policy_params'],
                                   training_environment))

        algorithm_kwargs = {
            'variant': self._variant,
            'training_environment': self.training_environment,
            'evaluation_environment': self.evaluation_environment,
            'policy': policy,
            'initial_exploration_policy': initial_exploration_policy,
            'Qs': Qs,
            'pool': replay_pool,
            'sampler': sampler,
            'session': self._session,
        }

        if self._variant['algorithm_params']['type'] in ('SACClassifier',
                                                         'RAQ', 'VICE',
                                                         'VICERAQ'):
            reward_classifier = self.reward_classifier = picklable[
                'reward_classifier']
            algorithm_kwargs['classifier'] = reward_classifier

            # goal_examples_train, goal_examples_validation = \
            #     get_goal_example_from_variant(variant)
            # algorithm_kwargs['goal_examples'] = goal_examples_train
            # algorithm_kwargs['goal_examples_validation'] = \
            #     goal_examples_validation

        self.algorithm = get_algorithm_from_variant(**algorithm_kwargs)
        self.algorithm.__setstate__(picklable['algorithm'].__getstate__())

        tf_checkpoint = self._get_tf_checkpoint()
        status = tf_checkpoint.restore(
            tf.train.latest_checkpoint(
                os.path.split(self._tf_checkpoint_prefix(checkpoint_dir))[0]))

        status.assert_consumed().run_restore_ops(self._session)
        initialize_tf_variables(self._session, only_uninitialized=True)

        # TODO(hartikainen): target Qs should either be checkpointed or pickled.
        for Q, Q_target in zip(self.algorithm._Qs, self.algorithm._Q_targets):
            Q_target.set_weights(Q.get_weights())

        self._built = True
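
Both _restore examples end by copying the live Q weights into the target Qs, since (per the TODO) the targets are neither checkpointed nor pickled. One possible way to close that gap, sketched under the assumption that the algorithm exposes _Q_targets with Keras-style get_weights()/set_weights() as already used above:

# Sketch only: include the target-Q weights in the picklable state at save
# time, then restore them directly instead of copying from the live Qs.
# _Q_targets / get_weights() / set_weights() are taken from the loop above;
# the 'Q_target_weights' key is hypothetical.
def save_q_target_weights(algorithm, picklable):
    picklable['Q_target_weights'] = [
        Q_target.get_weights() for Q_target in algorithm._Q_targets]

def restore_q_target_weights(algorithm, picklable):
    for Q_target, weights in zip(algorithm._Q_targets,
                                 picklable['Q_target_weights']):
        Q_target.set_weights(weights)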