Example #1
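# Note: this listing is an excerpt, so its imports are not shown. The imports
# below are a best-effort reconstruction from how names are used in the class
# and are not part of the original example; project-local helpers (RLAlgorithm,
# SimpleReplayPool, construct_model, format_samples_for_training, FakeEnv,
# Writer, Progress, visualize_policy, td_target, filesystem) still need to be
# imported from the MBPO/softlearning codebase.
import os
import math
import pickle
from collections import OrderedDict
from itertools import count
from numbers import Number

import numpy as np
import tensorflow as tf  # TF1-style API (tf.placeholder, tf.contrib, ...)
import gtimer as gt
from tensorflow.python.training import training_util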
class MBPO(RLAlgorithm):
    """Model-Based Policy Optimization (MBPO)

    References
    ----------
        Michael Janner, Justin Fu, Marvin Zhang, Sergey Levine. 
        When to Trust Your Model: Model-Based Policy Optimization. 
        arXiv preprint arXiv:1906.08253. 2019.
    """
    def __init__(
        self,
        training_environment,
        evaluation_environment,
        policy,
        Qs,
        pool,
        static_fns,
        plotter=None,
        tf_summaries=False,
        lr=3e-4,
        reward_scale=1.0,
        target_entropy='auto',
        discount=0.99,
        tau=5e-3,
        target_update_interval=1,
        action_prior='uniform',
        reparameterize=False,
        store_extra_policy_info=False,
        deterministic=False,
        model_train_freq=250,
        num_networks=7,
        num_elites=5,
        model_retain_epochs=20,
        load_model_dir=None,
        rollout_batch_size=100e3,
        real_ratio=0.1,
        rollout_schedule=[20, 100, 1, 1],
        hidden_dim=200,
        max_model_t=None,
        **kwargs,
    ):
        """
        Args:
            training_environment (`SoftlearningEnv`): Environment used for training.
            evaluation_environment (`SoftlearningEnv`): Environment used for evaluation.
            policy: A policy function approximator.
            Qs: Q-function approximators. The minimum over these
                approximators is used. Using at least two Q-functions
                improves performance by reducing overestimation bias.
            pool (`PoolBase`): Replay pool to add gathered samples to.
            static_fns: Environment-specific static functions (e.g. the
                termination function) used by the model-based `FakeEnv`.
            plotter (`QFPolicyPlotter`): Plotter instance used for
                visualizing the Q-function during training.
            lr (`float`): Learning rate used for the function approximators.
            discount (`float`): Discount factor for Q-function updates.
            tau (`float`): Soft value function target update weight.
            target_update_interval (`int`): Frequency (in iterations) at which
                target network updates occur.
            reparameterize (`bool`): If True, use a gradient estimator for the
                policy derived via the reparameterization trick; otherwise use
                a likelihood-ratio-based estimator.
        """

        super(MBPO, self).__init__(**kwargs)

        obs_dim = np.prod(training_environment.observation_space.shape)
        act_dim = np.prod(training_environment.action_space.shape)

        self._model_params = dict(obs_dim=obs_dim,
                                  act_dim=act_dim,
                                  hidden_dim=hidden_dim,
                                  num_networks=num_networks,
                                  num_elites=num_elites)
        if load_model_dir is not None:
            self._model_params['load_model'] = True
            self._model_params['model_dir'] = load_model_dir
        self._model = construct_model(**self._model_params)

        self._static_fns = static_fns
        self.fake_env = FakeEnv(self._model, self._static_fns)

        self._rollout_schedule = rollout_schedule
        self._max_model_t = max_model_t

        self._model_retain_epochs = model_retain_epochs

        self._model_train_freq = model_train_freq
        self._rollout_batch_size = int(rollout_batch_size)
        self._deterministic = deterministic
        self._real_ratio = real_ratio

        self._log_dir = os.getcwd()
        self._writer = Writer(self._log_dir)

        self._training_environment = training_environment
        self._evaluation_environment = evaluation_environment
        self._policy = policy

        self._Qs = Qs
        self._Q_targets = tuple(tf.keras.models.clone_model(Q) for Q in Qs)

        self._pool = pool
        self._plotter = plotter
        self._tf_summaries = tf_summaries

        self._policy_lr = lr
        self._Q_lr = lr

        self._reward_scale = reward_scale
        self._target_entropy = (
            -np.prod(self._training_environment.action_space.shape)
            if target_entropy == 'auto' else target_entropy)
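        # The 'auto' setting uses the common SAC heuristic: target entropy equals
        # minus the action dimensionality, e.g. -6 for a 6-dimensional action space.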
        print('[ MBPO ] Target entropy: {}'.format(self._target_entropy))

        self._discount = discount
        self._tau = tau
        self._target_update_interval = target_update_interval
        self._action_prior = action_prior

        self._reparameterize = reparameterize
        self._store_extra_policy_info = store_extra_policy_info

        observation_shape = self._training_environment.active_observation_shape
        action_shape = self._training_environment.action_space.shape

        ### @anyboby: use a fixed model pool size; reallocating the pool causes a memory leak
        obs_space = self._pool._observation_space
        act_space = self._pool._action_space
        rollouts_per_epoch = self._rollout_batch_size * self._epoch_length / self._model_train_freq
        model_steps_per_epoch = int(self._rollout_schedule[-1] *
                                    rollouts_per_epoch)
        mpool_size = self._model_retain_epochs * model_steps_per_epoch
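        # Worked example with the defaults above, assuming the base class's
        # epoch_length is 1000 (not shown in this excerpt):
        #   rollouts_per_epoch    = 100e3 * 1000 / 250 = 400e3
        #   model_steps_per_epoch = 1 * 400e3          = 400e3   (schedule [20, 100, 1, 1])
        #   mpool_size            = 20 * 400e3         = 8e6 transitions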

        self._model_pool = SimpleReplayPool(obs_space, act_space, mpool_size)

        assert len(observation_shape) == 1, observation_shape
        self._observation_shape = observation_shape
        assert len(action_shape) == 1, action_shape
        self._action_shape = action_shape

        self._build()

    def _build(self):
        self._training_ops = {}

        self._init_global_step()
        self._init_placeholders()
        self._init_actor_update()
        self._init_critic_update()

    def _train(self):
        """Return a generator that performs RL training.

        Args:
            env (`SoftlearningEnv`): Environment used for training.
            policy (`Policy`): Policy used for training
            initial_exploration_policy ('Policy'): Policy used for exploration
                If None, then all exploration is done using policy
            pool (`PoolBase`): Sample pool to add samples to
        """
        training_environment = self._training_environment
        evaluation_environment = self._evaluation_environment
        policy = self._policy
        pool = self._pool
        model_metrics = {}

        if not self._training_started:
            self._init_training()

            self._initial_exploration_hook(training_environment,
                                           self._initial_exploration_policy,
                                           pool)

        self.sampler.initialize(training_environment, policy, pool)

        gt.reset_root()
        gt.rename_root('RLAlgorithm')
        gt.set_def_unique(False)

        self._training_before_hook()

        for self._epoch in gt.timed_for(range(self._epoch, self._n_epochs)):

            self._epoch_before_hook()
            gt.stamp('epoch_before_hook')

            self._training_progress = Progress(self._epoch_length *
                                               self._n_train_repeat)
            start_samples = self.sampler._total_samples
            for i in count():
                samples_now = self.sampler._total_samples
                self._timestep = samples_now - start_samples

                if (samples_now >= start_samples + self._epoch_length
                        and self.ready_to_train):
                    break

                self._timestep_before_hook()
                gt.stamp('timestep_before_hook')

                if self._timestep % self._model_train_freq == 0 and self._real_ratio < 1.0:
                    self._training_progress.pause()
                    print('[ MBPO ] log_dir: {} | ratio: {}'.format(
                        self._log_dir, self._real_ratio))
                    print(
                        '[ MBPO ] Training model at epoch {} | freq {} | timestep {} (total: {}) | epoch train steps: {} (total: {})'
                        .format(self._epoch, self._model_train_freq,
                                self._timestep, self._total_timestep,
                                self._train_steps_this_epoch,
                                self._num_train_steps))

                    model_train_metrics = self._train_model(
                        batch_size=256,
                        max_epochs=None,
                        holdout_ratio=0.2,
                        max_t=self._max_model_t)
                    model_metrics.update(model_train_metrics)
                    gt.stamp('epoch_train_model')

                    self._set_rollout_length()
                    if self._rollout_batch_size > 30000:
                        factor = self._rollout_batch_size // 30000 + 1
                        mini_batch = self._rollout_batch_size // factor
                        for _ in range(factor):  # '_' avoids shadowing the outer loop index i
                            model_rollout_metrics = self._rollout_model(
                                rollout_batch_size=mini_batch,
                                deterministic=self._deterministic)
                    else:
                        model_rollout_metrics = self._rollout_model(
                            rollout_batch_size=self._rollout_batch_size,
                            deterministic=self._deterministic)
                    model_metrics.update(model_rollout_metrics)

                    gt.stamp('epoch_rollout_model')
                    # self._visualize_model(self._evaluation_environment, self._total_timestep)
                    self._training_progress.resume()

                self._do_sampling(timestep=self._total_timestep)
                gt.stamp('sample')

                if self.ready_to_train:
                    self._do_training_repeats(timestep=self._total_timestep)
                gt.stamp('train')

                self._timestep_after_hook()
                gt.stamp('timestep_after_hook')

            training_paths = self.sampler.get_last_n_paths(
                math.ceil(self._epoch_length / self.sampler._max_path_length))
            gt.stamp('training_paths')
            evaluation_paths = self._evaluation_paths(policy,
                                                      evaluation_environment)
            gt.stamp('evaluation_paths')

            training_metrics = self._evaluate_rollouts(training_paths,
                                                       training_environment)
            gt.stamp('training_metrics')
            if evaluation_paths:
                evaluation_metrics = self._evaluate_rollouts(
                    evaluation_paths, evaluation_environment)
                gt.stamp('evaluation_metrics')
            else:
                evaluation_metrics = {}

            self._epoch_after_hook(training_paths)
            gt.stamp('epoch_after_hook')

            sampler_diagnostics = self.sampler.get_diagnostics()

            diagnostics = self.get_diagnostics(
                iteration=self._total_timestep,
                batch=self._evaluation_batch(),
                training_paths=training_paths,
                evaluation_paths=evaluation_paths)

            time_diagnostics = gt.get_times().stamps.itrs

            diagnostics.update(
                OrderedDict((
                    *((f'evaluation/{key}', evaluation_metrics[key])
                      for key in sorted(evaluation_metrics.keys())),
                    *((f'training/{key}', training_metrics[key])
                      for key in sorted(training_metrics.keys())),
                    *((f'times/{key}', time_diagnostics[key][-1])
                      for key in sorted(time_diagnostics.keys())),
                    *((f'sampler/{key}', sampler_diagnostics[key])
                      for key in sorted(sampler_diagnostics.keys())),
                    *((f'model/{key}', model_metrics[key])
                      for key in sorted(model_metrics.keys())),
                    ('epoch', self._epoch),
                    ('timestep', self._timestep),
                    ('timesteps_total', self._total_timestep),
                    ('train-steps', self._num_train_steps),
                )))

            if self._eval_render_mode is not None and hasattr(
                    evaluation_environment, 'render_rollouts'):
                training_environment.render_rollouts(evaluation_paths)

            yield diagnostics

        self.sampler.terminate()

        self._training_after_hook()

        self._training_progress.close()

        yield {'done': True, **diagnostics}

    def train(self, *args, **kwargs):
        return self._train(*args, **kwargs)

    def _log_policy(self):
        save_path = os.path.join(self._log_dir, 'models')
        filesystem.mkdir(save_path)
        weights = self._policy.get_weights()
        data = {'policy_weights': weights}
        full_path = os.path.join(save_path,
                                 'policy_{}.pkl'.format(self._total_timestep))
        print('Saving policy to: {}'.format(full_path))
        pickle.dump(data, open(full_path, 'wb'))

    def _log_model(self):
        save_path = os.path.join(self._log_dir, 'models')
        filesystem.mkdir(save_path)
        print('Saving model to: {}'.format(save_path))
        self._model.save(save_path, self._total_timestep)

    def _set_rollout_length(self):
        min_epoch, max_epoch, min_length, max_length = self._rollout_schedule
        if self._epoch <= min_epoch:
            y = min_length
        else:
            dx = (self._epoch - min_epoch) / (max_epoch - min_epoch)
            dx = min(dx, 1)
            y = dx * (max_length - min_length) + min_length

        self._rollout_length = int(y)
        print(
            '[ Model Length ] Epoch: {} (min: {}, max: {}) | Length: {} (min: {} , max: {})'
            .format(self._epoch, min_epoch, max_epoch, self._rollout_length,
                    min_length, max_length))
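        # Illustration of the linear schedule with a hypothetical
        # rollout_schedule of [20, 100, 1, 15]:
        #   epoch <= 20  -> length 1
        #   epoch 60     -> dx = (60 - 20) / (100 - 20) = 0.5, length = int(0.5 * 14 + 1) = 8
        #   epoch >= 100 -> length 15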

    def _reallocate_model_pool(self):
        obs_space = self._pool._observation_space
        act_space = self._pool._action_space

        rollouts_per_epoch = self._rollout_batch_size * self._epoch_length / self._model_train_freq
        model_steps_per_epoch = int(self._rollout_length * rollouts_per_epoch)
        new_pool_size = self._model_retain_epochs * model_steps_per_epoch

        if not hasattr(self, '_model_pool'):
            print(
                '[ MBPO ] Initializing new model pool with size {:.2e}'.format(
                    new_pool_size))
            self._model_pool = SimpleReplayPool(obs_space, act_space,
                                                new_pool_size)

        elif self._model_pool._max_size != new_pool_size:
            print('[ MBPO ] Updating model pool | {:.2e} --> {:.2e}'.format(
                self._model_pool._max_size, new_pool_size))
            samples = self._model_pool.return_all_samples()
            new_pool = SimpleReplayPool(obs_space, act_space, new_pool_size)
            new_pool.add_samples(samples)
            assert self._model_pool.size == new_pool.size
            self._model_pool = new_pool

    def _train_model(self, **kwargs):
        env_samples = self._pool.return_all_samples()
        train_inputs, train_outputs = format_samples_for_training(env_samples)
        model_metrics = self._model.train(train_inputs, train_outputs,
                                          **kwargs)
        return model_metrics

    def _rollout_model(self, rollout_batch_size, **kwargs):
        print(
            '[ Model Rollout ] Starting | Epoch: {} | Rollout length: {} | Batch size: {}'
            .format(self._epoch, self._rollout_length, rollout_batch_size))
        batch = self.sampler.random_batch(rollout_batch_size)
        obs = batch['observations']
        steps_added = []
        for i in range(self._rollout_length):
            act = self._policy.actions_np(obs)

            next_obs, rew, term, info = self.fake_env.step(obs, act, **kwargs)
            steps_added.append(len(obs))

            samples = {
                'observations': obs,
                'actions': act,
                'next_observations': next_obs,
                'rewards': rew,
                'terminals': term
            }
            self._model_pool.add_samples(samples)

            nonterm_mask = ~term.squeeze(-1)
            if nonterm_mask.sum() == 0:
                print('[ Model Rollout ] Breaking early: {} | {} / {}'.format(
                    i, nonterm_mask.sum(), nonterm_mask.shape))
                break

            obs = next_obs[nonterm_mask]
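            # Only non-terminal states are carried over to the next rollout step,
            # so the effective batch shrinks as model trajectories terminate.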

        mean_rollout_length = sum(steps_added) / rollout_batch_size
        rollout_stats = {'mean_rollout_length': mean_rollout_length}
        print(
            '[ Model Rollout ] Added: {:.1e} | Model pool: {:.1e} (max {:.1e}) | Length: {} | Train rep: {}'
            .format(sum(steps_added), self._model_pool.size,
                    self._model_pool._max_size, mean_rollout_length,
                    self._n_train_repeat))
        return rollout_stats

    def _visualize_model(self, env, timestep):
        ## save env state
        state = env.unwrapped.state_vector()
        qpos_dim = len(env.unwrapped.sim.data.qpos)
        qpos = state[:qpos_dim]
        qvel = state[qpos_dim:]

        print('[ Visualization ] Starting | Epoch {} | Log dir: {}\n'.format(
            self._epoch, self._log_dir))
        visualize_policy(env, self.fake_env, self._policy, self._writer,
                         timestep)
        print('[ Visualization ] Done')
        ## set env state
        env.unwrapped.set_state(qpos, qvel)

    def _training_batch(self, batch_size=None):
        batch_size = batch_size or self.sampler._batch_size
        env_batch_size = int(batch_size * self._real_ratio)
        model_batch_size = batch_size - env_batch_size
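        # Example: with batch_size=256 and the default real_ratio=0.1,
        # env_batch_size = int(256 * 0.1) = 25 real transitions and
        # model_batch_size = 256 - 25 = 231 model-generated transitions.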

        ## can sample from the env pool even if env_batch_size == 0
        env_batch = self._pool.random_batch(env_batch_size)

        if model_batch_size > 0:
            model_batch = self._model_pool.random_batch(model_batch_size)

            keys = env_batch.keys()
            batch = {
                k: np.concatenate((env_batch[k], model_batch[k]), axis=0)
                for k in keys
            }
        else:
            ## if real_ratio == 1.0, no model pool was ever allocated,
            ## so skip the model pool sampling
            batch = env_batch
        return batch

    def _init_global_step(self):
        self.global_step = training_util.get_or_create_global_step()
        self._training_ops.update(
            {'increment_global_step': training_util._increment_global_step(1)})

    def _init_placeholders(self):
        """Create input placeholders for the SAC algorithm.

        Creates `tf.placeholder`s for:
            - observation
            - next observation
            - action
            - reward
            - terminals
        """
        self._iteration_ph = tf.placeholder(tf.int64,
                                            shape=None,
                                            name='iteration')

        self._observations_ph = tf.placeholder(
            tf.float32,
            shape=(None, *self._observation_shape),
            name='observation',
        )

        self._next_observations_ph = tf.placeholder(
            tf.float32,
            shape=(None, *self._observation_shape),
            name='next_observation',
        )

        self._actions_ph = tf.placeholder(
            tf.float32,
            shape=(None, *self._action_shape),
            name='actions',
        )

        self._rewards_ph = tf.placeholder(
            tf.float32,
            shape=(None, 1),
            name='rewards',
        )

        self._terminals_ph = tf.placeholder(
            tf.float32,
            shape=(None, 1),
            name='terminals',
        )

        if self._store_extra_policy_info:
            self._log_pis_ph = tf.placeholder(
                tf.float32,
                shape=(None, 1),
                name='log_pis',
            )
            self._raw_actions_ph = tf.placeholder(
                tf.float32,
                shape=(None, *self._action_shape),
                name='raw_actions',
            )

    def _get_Q_target(self):
        next_actions = self._policy.actions([self._next_observations_ph])
        next_log_pis = self._policy.log_pis([self._next_observations_ph],
                                            next_actions)

        next_Qs_values = tuple(
            Q([self._next_observations_ph, next_actions])
            for Q in self._Q_targets)

        min_next_Q = tf.reduce_min(next_Qs_values, axis=0)
        next_value = min_next_Q - self._alpha * next_log_pis
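        # Soft Bellman backup (SAC-style), assuming td_target computes
        # reward + discount * next_value:
        #   Q_target = reward_scale * r
        #              + discount * (1 - done) * (min_i Q_target_i(s', a') - alpha * log pi(a'|s'))
        # with a' sampled from the current policy at s'.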

        Q_target = td_target(reward=self._reward_scale * self._rewards_ph,
                             discount=self._discount,
                             next_value=(1 - self._terminals_ph) * next_value)

        return Q_target

    def _init_critic_update(self):
        """Create minimization operation for critic Q-function.

        Creates a `tf.optimizer.minimize` operation for updating
        critic Q-function with gradient descent, and appends it to
        `self._training_ops` attribute.
        """
        Q_target = tf.stop_gradient(self._get_Q_target())

        assert Q_target.shape.as_list() == [None, 1]

        Q_values = self._Q_values = tuple(
            Q([self._observations_ph, self._actions_ph]) for Q in self._Qs)

        Q_losses = self._Q_losses = tuple(
            tf.losses.mean_squared_error(
                labels=Q_target, predictions=Q_value, weights=0.5)
            for Q_value in Q_values)

        self._Q_optimizers = tuple(
            tf.train.AdamOptimizer(learning_rate=self._Q_lr,
                                   name='{}_{}_optimizer'.format(Q._name, i))
            for i, Q in enumerate(self._Qs))
        Q_training_ops = tuple(
            tf.contrib.layers.optimize_loss(Q_loss,
                                            self.global_step,
                                            learning_rate=self._Q_lr,
                                            optimizer=Q_optimizer,
                                            variables=Q.trainable_variables,
                                            increment_global_step=False,
                                            summaries=((
                                                "loss", "gradients",
                                                "gradient_norm",
                                                "global_gradient_norm"
                                            ) if self._tf_summaries else ()))
            for i, (Q, Q_loss, Q_optimizer) in enumerate(
                zip(self._Qs, Q_losses, self._Q_optimizers)))

        self._training_ops.update({'Q': tf.group(Q_training_ops)})

    def _init_actor_update(self):
        """Create minimization operations for policy and entropy.

        Creates a `tf.optimizer.minimize` operations for updating
        policy and entropy with gradient descent, and adds them to
        `self._training_ops` attribute.
        """

        actions = self._policy.actions([self._observations_ph])
        log_pis = self._policy.log_pis([self._observations_ph], actions)

        assert log_pis.shape.as_list() == [None, 1]

        log_alpha = self._log_alpha = tf.get_variable('log_alpha',
                                                      dtype=tf.float32,
                                                      initializer=0.0)
        alpha = tf.exp(log_alpha)

        if isinstance(self._target_entropy, Number):
            alpha_loss = -tf.reduce_mean(
                log_alpha * tf.stop_gradient(log_pis + self._target_entropy))

            self._alpha_optimizer = tf.train.AdamOptimizer(
                self._policy_lr, name='alpha_optimizer')
            self._alpha_train_op = self._alpha_optimizer.minimize(
                loss=alpha_loss, var_list=[log_alpha])

            self._training_ops.update(
                {'temperature_alpha': self._alpha_train_op})

        self._alpha = alpha

        if self._action_prior == 'normal':
            policy_prior = tf.contrib.distributions.MultivariateNormalDiag(
                loc=tf.zeros(self._action_shape),
                scale_diag=tf.ones(self._action_shape))
            policy_prior_log_probs = policy_prior.log_prob(actions)
        elif self._action_prior == 'uniform':
            policy_prior_log_probs = 0.0

        Q_log_targets = tuple(
            Q([self._observations_ph, actions]) for Q in self._Qs)
        min_Q_log_target = tf.reduce_min(Q_log_targets, axis=0)

        if self._reparameterize:
            policy_kl_losses = (alpha * log_pis - min_Q_log_target -
                                policy_prior_log_probs)
        else:
            raise NotImplementedError

        assert policy_kl_losses.shape.as_list() == [None, 1]

        policy_loss = tf.reduce_mean(policy_kl_losses)

        self._policy_optimizer = tf.train.AdamOptimizer(
            learning_rate=self._policy_lr, name="policy_optimizer")
        policy_train_op = tf.contrib.layers.optimize_loss(
            policy_loss,
            self.global_step,
            learning_rate=self._policy_lr,
            optimizer=self._policy_optimizer,
            variables=self._policy.trainable_variables,
            increment_global_step=False,
            summaries=("loss", "gradients", "gradient_norm",
                       "global_gradient_norm") if self._tf_summaries else ())

        self._training_ops.update({'policy_train_op': policy_train_op})

    def _init_training(self):
        self._update_target(tau=1.0)

    def _update_target(self, tau=None):
        tau = tau or self._tau
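        # Polyak averaging: theta_target <- tau * theta + (1 - tau) * theta_target.
        # Note that `tau or self._tau` also falls back to self._tau when tau=0.0,
        # since 0.0 is falsy in Python.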

        for Q, Q_target in zip(self._Qs, self._Q_targets):
            source_params = Q.get_weights()
            target_params = Q_target.get_weights()
            Q_target.set_weights([
                tau * source + (1.0 - tau) * target
                for source, target in zip(source_params, target_params)
            ])

    def _do_training(self, iteration, batch):
        """Runs the operations for updating training and target ops."""

        self._training_progress.update()
        self._training_progress.set_description()

        feed_dict = self._get_feed_dict(iteration, batch)

        self._session.run(self._training_ops, feed_dict)

        if iteration % self._target_update_interval == 0:
            # Run target ops here.
            self._update_target()

    def _get_feed_dict(self, iteration, batch):
        """Construct TensorFlow feed_dict from sample batch."""

        feed_dict = {
            self._observations_ph: batch['observations'],
            self._actions_ph: batch['actions'],
            self._next_observations_ph: batch['next_observations'],
            self._rewards_ph: batch['rewards'],
            self._terminals_ph: batch['terminals'],
        }

        if self._store_extra_policy_info:
            feed_dict[self._log_pis_ph] = batch['log_pis']
            feed_dict[self._raw_actions_ph] = batch['raw_actions']

        if iteration is not None:
            feed_dict[self._iteration_ph] = iteration

        return feed_dict

    def get_diagnostics(self, iteration, batch, training_paths,
                        evaluation_paths):
        """Return diagnostic information as ordered dictionary.

        Records mean and standard deviation of Q-function and state
        value function, and TD-loss (mean squared Bellman error)
        for the sample batch.

        Also calls the `draw` method of the plotter, if plotter defined.
        """

        feed_dict = self._get_feed_dict(iteration, batch)

        (Q_values, Q_losses, alpha, global_step) = self._session.run(
            (self._Q_values, self._Q_losses, self._alpha, self.global_step),
            feed_dict)

        diagnostics = OrderedDict({
            'Q-avg': np.mean(Q_values),
            'Q-std': np.std(Q_values),
            'Q_loss': np.mean(Q_losses),
            'alpha': alpha,
        })

        policy_diagnostics = self._policy.get_diagnostics(
            batch['observations'])
        diagnostics.update({
            f'policy/{key}': value
            for key, value in policy_diagnostics.items()
        })

        if self._plotter:
            self._plotter.draw()

        return diagnostics

    @property
    def tf_saveables(self):
        saveables = {
            '_policy_optimizer': self._policy_optimizer,
            **{
                f'Q_optimizer_{i}': optimizer
                for i, optimizer in enumerate(self._Q_optimizers)
            },
            '_log_alpha': self._log_alpha,
        }

        if hasattr(self, '_alpha_optimizer'):
            saveables['_alpha_optimizer'] = self._alpha_optimizer

        return saveables

    def save_model(self, dir):
        self._model.save(savedir=dir, timestep=self._epoch + 1)
Example #2
    def train(self,
              inputs,
              targets,
              batch_size=32,
              max_epochs=None,
              max_epochs_since_update=5,
              hide_progress=False,
              holdout_ratio=0.0,
              max_logging=5000,
              max_grad_updates=None,
              timer=None,
              max_t=None):
        """Trains/Continues network training

        Arguments:
            inputs (np.ndarray): Network inputs in the training dataset in rows.
            targets (np.ndarray): Network target outputs in the training dataset in rows corresponding
                to the rows in inputs.
            batch_size (int): The minibatch size to be used for training.
            epochs (int): Number of epochs (full network passes that will be done.
            hide_progress (bool): If True, hides the progress bar shown at the beginning of training.

        Returns: None
        """
        self._max_epochs_since_update = max_epochs_since_update
        self._start_train()
        break_train = False

        def shuffle_rows(arr):
            idxs = np.argsort(np.random.uniform(size=arr.shape), axis=-1)
            return arr[np.arange(arr.shape[0])[:, None], idxs]
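        # shuffle_rows permutes each row of `arr` independently: argsort over
        # i.i.d. uniform noise yields a fresh random permutation per row, so each
        # ensemble member visits its bootstrap indices in a different order.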

        # Split into training and holdout sets
        num_holdout = min(int(inputs.shape[0] * holdout_ratio), max_logging)
        permutation = np.random.permutation(inputs.shape[0])
        inputs, holdout_inputs = inputs[permutation[num_holdout:]], inputs[
            permutation[:num_holdout]]
        targets, holdout_targets = targets[permutation[num_holdout:]], targets[
            permutation[:num_holdout]]
        holdout_inputs = np.tile(holdout_inputs[None], [self.num_nets, 1, 1])
        holdout_targets = np.tile(holdout_targets[None], [self.num_nets, 1, 1])
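        # The holdout set is tiled to shape (num_nets, num_holdout, dim) so that
        # every network in the ensemble is evaluated on the same validation data.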

        print('[ BNN ] Training {} | Holdout: {}'.format(
            inputs.shape, holdout_inputs.shape))
        with self.sess.as_default():
            self.scaler.fit(inputs)

        idxs = np.random.randint(inputs.shape[0],
                                 size=[self.num_nets, inputs.shape[0]])
        if hide_progress:
            progress = Silent()
        else:
            progress = Progress(max_epochs)

        if max_epochs:
            epoch_iter = range(max_epochs)
        else:
            epoch_iter = itertools.count()

        # else:
        #     epoch_range = trange(epochs, unit="epoch(s)", desc="Network training")

        t0 = time.time()
        grad_updates = 0
        for epoch in epoch_iter:
            for batch_num in range(int(np.ceil(idxs.shape[-1] / batch_size))):
                batch_idxs = idxs[:, batch_num * batch_size:(batch_num + 1) *
                                  batch_size]
                self.sess.run(self.train_op,
                              feed_dict={
                                  self.sy_train_in: inputs[batch_idxs],
                                  self.sy_train_targ: targets[batch_idxs]
                              })
                grad_updates += 1

            idxs = shuffle_rows(idxs)
            if not hide_progress:
                if holdout_ratio < 1e-12:
                    losses = self.sess.run(self.mse_loss,
                                           feed_dict={
                                               self.sy_train_in:
                                               inputs[idxs[:, :max_logging]],
                                               self.sy_train_targ:
                                               targets[idxs[:, :max_logging]]
                                           })
                    named_losses = [['M{}'.format(i), losses[i]]
                                    for i in range(len(losses))]
                    progress.set_description(named_losses)
                else:
                    losses = self.sess.run(self.mse_loss,
                                           feed_dict={
                                               self.sy_train_in:
                                               inputs[idxs[:, :max_logging]],
                                               self.sy_train_targ:
                                               targets[idxs[:, :max_logging]]
                                           })
                    holdout_losses = self.sess.run(self.mse_loss,
                                                   feed_dict={
                                                       self.sy_train_in:
                                                       holdout_inputs,
                                                       self.sy_train_targ:
                                                       holdout_targets
                                                   })
                    named_losses = [['M{}'.format(i), losses[i]]
                                    for i in range(len(losses))]
                    named_holdout_losses = [[
                        'V{}'.format(i), holdout_losses[i]
                    ] for i in range(len(holdout_losses))]
                    named_losses = named_losses + named_holdout_losses + [[
                        'T', time.time() - t0
                    ]]
                    progress.set_description(named_losses)

                    break_train = self._save_best(epoch, holdout_losses)

            progress.update()
            t = time.time() - t0
            if break_train or (max_grad_updates
                               and grad_updates > max_grad_updates):
                break
            if max_t and t > max_t:
                descr = 'Breaking because of timeout: {}! (max: {})'.format(
                    t, max_t)
                progress.append_description(descr)
                # print('Breaking because of timeout: {}! | (max: {})\n'.format(t, max_t))
                # time.sleep(5)
                break

        progress.stamp()
        if timer: timer.stamp('bnn_train')

        self._set_state()
        if timer: timer.stamp('bnn_set_state')

        holdout_losses = self.sess.run(self.mse_loss,
                                       feed_dict={
                                           self.sy_train_in: holdout_inputs,
                                           self.sy_train_targ: holdout_targets
                                       })

        if timer: timer.stamp('bnn_holdout')

        self._end_train(holdout_losses)
        if timer: timer.stamp('bnn_end')

        val_loss = (np.sort(holdout_losses)[:self.num_elites]).mean()
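        # Validation loss is the mean holdout MSE over the `num_elites` best
        # ensemble members, i.e. the subset typically used for predictions.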
        model_metrics = {'val_loss': val_loss}
        print('[ BNN ] Holdout', np.sort(holdout_losses), model_metrics)
        return OrderedDict(model_metrics)
Example #4
class CMBPO(RLAlgorithm):
    """Model-Based Policy Optimization (MBPO)

    References
    ----------
        Michael Janner, Justin Fu, Marvin Zhang, Sergey Levine. 
        When to Trust Your Model: Model-Based Policy Optimization. 
        arXiv preprint arXiv:1906.08253. 2019.
    """
    def __init__(
        self,
        training_environment,
        evaluation_environment,
        mjc_model_environment,
        policy,
        Qs,
        pool,
        static_fns,
        plotter=None,
        tf_summaries=False,
        n_env_interacts=1e7,
        lr=3e-4,
        reward_scale=1.0,
        target_entropy='auto',
        discount=0.99,
        tau=5e-3,
        target_update_interval=1,
        action_prior='uniform',
        reparameterize=False,
        store_extra_policy_info=False,
        eval_every_n_steps=5e3,
        deterministic=False,
        model_train_freq=250,
        num_networks=7,
        num_elites=5,
        model_retain_epochs=20,
        rollout_batch_size=100e3,
        real_ratio=0.1,
        dyn_model_train_schedule=[20, 100, 1, 5],
        cost_model_train_schedule=[20, 100, 1, 30],
        cares_about_cost=False,
        policy_alpha=1,
        max_uncertainty_rew=None,
        max_uncertainty_c=None,
        rollout_mode='schedule',
        rollout_schedule=[20, 100, 1, 1],
        maxroll=80,
        max_tddyn_err=1e-5,
        max_tddyn_err_decay=.995,
        min_real_samples_per_epoch=1000,
        batch_size_policy=5000,
        hidden_dims=(200, 200, 200, 200),
        max_model_t=None,
        use_mjc_state_model=False,
        model_std_inc=0.02,
        **kwargs,
    ):
        """
        Args:
            training_environment (`SoftlearningEnv`): Environment used for training.
            evaluation_environment (`SoftlearningEnv`): Environment used for evaluation.
            mjc_model_environment: MuJoCo-based environment wrapped by
                `PerturbedEnv`, available as an alternative to the learned model.
            policy: A policy function approximator.
            Qs: Q-function approximators. The minimum over these
                approximators is used. Using at least two Q-functions
                improves performance by reducing overestimation bias.
            pool (`PoolBase`): Replay pool to add gathered samples to.
            plotter (`QFPolicyPlotter`): Plotter instance used for
                visualizing the Q-function during training.
            lr (`float`): Learning rate used for the function approximators.
            discount (`float`): Discount factor for Q-function updates.
            tau (`float`): Soft value function target update weight.
            target_update_interval (`int`): Frequency (in iterations) at which
                target network updates occur.
            reparameterize (`bool`): If True, use a gradient estimator for the
                policy derived via the reparameterization trick; otherwise use
                a likelihood-ratio-based estimator.
        """

        super(CMBPO, self).__init__(**kwargs)

        self.obs_space = training_environment.observation_space
        self.act_space = training_environment.action_space
        self.obs_dim = np.prod(training_environment.observation_space.shape)

        self.act_dim = np.prod(training_environment.action_space.shape)
        self.n_env_interacts = n_env_interacts
        #### determine unstacked obs dim
        self.num_stacks = training_environment.stacks
        self.stacking_axis = training_environment.stacking_axis
        self.active_obs_dim = int(self.obs_dim / self.num_stacks)
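        # With frame stacking, only obs_dim / num_stacks dimensions correspond to
        # a single (unstacked) observation along the stacking axis.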

        self.policy_alpha = policy_alpha
        self.cares_about_cost = cares_about_cost

        ## create fake environment for model
        self.fake_env = FakeEnv(training_environment,
                                static_fns,
                                num_networks=7,
                                num_elites=3,
                                hidden_dims=hidden_dims,
                                cares_about_cost=cares_about_cost,
                                session=self._session)

        self.use_mjc_state_model = use_mjc_state_model

        self._rollout_schedule = rollout_schedule
        self._max_model_t = max_model_t

        self._model_retain_epochs = model_retain_epochs
        self.eval_every_n_steps = eval_every_n_steps

        self._dyn_model_train_schedule = dyn_model_train_schedule
        self._cost_model_train_schedule = cost_model_train_schedule

        self._dyn_model_train_freq = 1
        self._cost_model_train_freq = 1
        self._rollout_batch_size = int(rollout_batch_size)
        self._max_uncertainty_rew = max_uncertainty_rew
        self._max_uncertainty_c = max_uncertainty_c
        self._deterministic = deterministic
        self._real_ratio = real_ratio

        self._log_dir = os.getcwd()
        self._writer = Writer(self._log_dir)

        self._training_environment = training_environment
        self._evaluation_environment = evaluation_environment
        self._mjc_model_environment = mjc_model_environment
        self.perturbed_env = PerturbedEnv(self._mjc_model_environment,
                                          std_inc=model_std_inc)
        self._policy = policy
        self._initial_exploration_policy = policy  # overwrite the initial exploration policy with the main policy; a separate exploration policy is not implemented for CPO yet

        self._pool = pool
        self._plotter = plotter
        self._tf_summaries = tf_summaries

        # set up pool
        pi_info_shapes = {
            k: v.shape.as_list()[1:]
            for k, v in self._policy.pi_info_phs.items()
        }
        self._pool.initialize(pi_info_shapes,
                              gamma=self._policy.gamma,
                              lam=self._policy.lam,
                              cost_gamma=self._policy.cost_gamma,
                              cost_lam=self._policy.cost_lam)

        self._policy_lr = lr

        self._reward_scale = reward_scale
        self._target_entropy = (
            -np.prod(self._training_environment.action_space.shape)
            if target_entropy == 'auto' else target_entropy)
        print('[ MBPO ] Target entropy: {}'.format(self._target_entropy))

        self._discount = discount
        self._tau = tau
        self._target_update_interval = target_update_interval
        self._action_prior = action_prior

        self._reparameterize = reparameterize
        self._store_extra_policy_info = store_extra_policy_info

        observation_shape = self._training_environment.active_observation_shape
        action_shape = self._training_environment.action_space.shape

        assert len(observation_shape) == 1, observation_shape
        self._observation_shape = observation_shape
        assert len(action_shape) == 1, action_shape
        self._action_shape = action_shape

        ### model sampler and buffer
        self.rollout_mode = rollout_mode
        self.max_tddyn_err = max_tddyn_err
        self.max_tddyn_err_decay = max_tddyn_err_decay
        self.min_real_samples_per_epoch = min_real_samples_per_epoch
        self.batch_size_policy = batch_size_policy

        self.model_pool = ModelBuffer(
            batch_size=self._rollout_batch_size,
            max_path_length=maxroll,
            env=self.fake_env,
            ensemble_size=num_networks,
            rollout_mode=self.rollout_mode,
            cares_about_cost=cares_about_cost,
            max_uncertainty_c=self._max_uncertainty_c,
            max_uncertainty_r=self._max_uncertainty_rew,
        )
        self.model_pool.initialize(
            pi_info_shapes,
            gamma=self._policy.gamma,
            lam=self._policy.lam,
            cost_gamma=self._policy.cost_gamma,
            cost_lam=self._policy.cost_lam,
        )
        #@anyboby debug
        self.model_sampler = ModelSampler(
            max_path_length=maxroll,
            batch_size=self._rollout_batch_size,
            store_last_n_paths=10,
            cares_about_cost=cares_about_cost,
            max_uncertainty_c=self._max_uncertainty_c,
            max_uncertainty_r=self._max_uncertainty_rew,
            logger=None,
            rollout_mode=self.rollout_mode,
        )

        # provide policy and sampler with the same logger
        self.logger = EpochLogger()
        self._policy.set_logger(self.logger)
        self.sampler.set_logger(self.logger)
        #self.model_sampler.set_logger(self.logger)

    def _train(self):
        """Return a generator that performs RL training.

        Args:
            env (`SoftlearningEnv`): Environment used for training.
            policy (`Policy`): Policy used for training
            initial_exploration_policy ('Policy'): Policy used for exploration
                If None, then all exploration is done using policy
            pool (`PoolBase`): Sample pool to add samples to
        """

        #### pool is e.g. simple_replay_pool
        training_environment = self._training_environment
        evaluation_environment = self._evaluation_environment
        policy = self._policy
        pool = self._pool

        if not self._training_started:
            #### perform some initial steps (gather samples) using initial policy
            ######  fills pool with _n_initial_exploration_steps samples
            self._initial_exploration_hook(training_environment, self._policy,
                                           pool)

        #### set up sampler with train env and actual policy (may be different from initial exploration policy)
        ######## note: sampler is set up with the pool that may be already filled from initial exploration hook
        self.sampler.initialize(training_environment, policy, pool)
        self.model_sampler.initialize(self.fake_env, policy, self.model_pool)
        rollout_dkl_lim = self.model_sampler.compute_dynamics_dkl(
            obs_batch=self._pool.rand_batch_from_archive(
                5000, fields=['observations'])['observations'],
            depth=self._rollout_schedule[2])
        self.model_sampler.set_rollout_dkl(rollout_dkl_lim)
        self.initial_model_dkl = self.model_sampler.dyn_dkl
        #### reset gtimer so the per-phase timestamps below start from a clean root
        gt.reset_root()
        gt.rename_root('RLAlgorithm')
        gt.set_def_unique(False)
        self.policy_epoch = 0  ### count policy updates
        self.new_real_samples = 0
        self.last_eval_step = 0
        self.diag_counter = 0
        running_diag = {}
        self.approx_model_batch = self.batch_size_policy - self.min_real_samples_per_epoch  ### some size to start off
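        # Approximate budget of model-generated samples per epoch; e.g. with the
        # defaults above, 5000 - 1000 = 4000 as a starting point.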

        #### not implemented, could train policy before hook
        self._training_before_hook()

        #### iterate over epochs, gt.timed_for to create loop with gt timestamps
        for self._epoch in gt.timed_for(range(self._epoch, self._n_epochs)):

            #### do something at beginning of epoch (in this case reset self._train_steps_this_epoch=0)
            self._epoch_before_hook()
            gt.stamp('epoch_before_hook')

            #### util class Progress, e.g. for plotting a progress bar
            #######   note: sampler may already contain samples in its pool from initial_exploration_hook or previous epochs
            self._training_progress = Progress(self._epoch_length *
                                               self._n_train_repeat /
                                               self._train_every_n_steps)

            samples_added = 0
            #=====================================================================#
            #            Rollout model                                            #
            #=====================================================================#
            model_samples = None
            keep_rolling = True
            model_metrics = {}
            #### start model rollout
            if self._real_ratio < 1.0:  #if self._timestep % self._model_train_freq == 0 and self._real_ratio < 1.0:
                #=====================================================================#
                #                           Model Rollouts                            #
                #=====================================================================#
                if self.rollout_mode == 'schedule':
                    self._set_rollout_length()

                while keep_rolling:
                    ep_b = self._pool.epoch_batch(
                        batch_size=self._rollout_batch_size,
                        epochs=self._pool.epochs_list,
                        fields=['observations', 'pi_infos'])
                    kls = np.clip(self._policy.compute_DKL(
                        ep_b['observations'], ep_b['mu'], ep_b['log_std']),
                                  a_min=0,
                                  a_max=None)
                    btz_dist = self._pool.boltz_dist(kls,
                                                     alpha=self.policy_alpha)
                    btz_b = self._pool.distributed_batch_from_archive(
                        self._rollout_batch_size,
                        btz_dist,
                        fields=['observations', 'pi_infos'])
                    start_states, mus, logstds = btz_b['observations'], btz_b[
                        'mu'], btz_b['log_std']
                    btz_kl = np.clip(self._policy.compute_DKL(
                        start_states, mus, logstds),
                                     a_min=0,
                                     a_max=None)

                    self.model_sampler.reset(start_states)
                    if self.rollout_mode == 'uncertainty':
                        self.model_sampler.set_max_uncertainty(
                            self.max_tddyn_err)

                    for i in count():
                        # print(f'Model Sampling step Nr. {i+1}')

                        _, _, _, info = self.model_sampler.sample(
                            max_samples=int(self.approx_model_batch -
                                            samples_added))

                        if self.model_sampler._total_samples + samples_added >= .99 * self.approx_model_batch:
                            keep_rolling = False
                            break

                        if info['alive_ratio'] <= 0.1: break

                    ### diagnostics for rollout ###
                    rollout_diagnostics = self.model_sampler.finish_all_paths()
                    if self.rollout_mode == 'iv_gae':
                        keep_rolling = self.model_pool.size + samples_added <= .99 * self.approx_model_batch

                    ######################################################################
                    ### get model_samples, get() invokes the inverse variance rollouts ###
                    model_samples_new, buffer_diagnostics_new = self.model_pool.get(
                    )
                    model_samples = [
                        np.concatenate((o, n), axis=0)
                        for o, n in zip(model_samples, model_samples_new)
                    ] if model_samples else model_samples_new

                    ######################################################################
                    ### diagnostics
                    new_n_samples = len(model_samples_new[0]) + EPS
                    diag_weight_old = samples_added / (new_n_samples +
                                                       samples_added)
                    diag_weight_new = new_n_samples / (new_n_samples +
                                                       samples_added)
                    model_metrics = update_dict(model_metrics,
                                                rollout_diagnostics,
                                                weight_a=diag_weight_old,
                                                weight_b=diag_weight_new)
                    model_metrics = update_dict(model_metrics,
                                                buffer_diagnostics_new,
                                                weight_a=diag_weight_old,
                                                weight_b=diag_weight_new)
                    ### run diagnostics on model data
                    if buffer_diagnostics_new['poolm_batch_size'] > 0:
                        model_data_diag = self._policy.run_diagnostics(
                            model_samples_new)
                        model_data_diag = {
                            k + '_m': v
                            for k, v in model_data_diag.items()
                        }
                        model_metrics = update_dict(model_metrics,
                                                    model_data_diag,
                                                    weight_a=diag_weight_old,
                                                    weight_b=diag_weight_new)

                    samples_added += new_n_samples
                    model_metrics.update({'samples_added': samples_added})
                    ######################################################################

                ## for debugging
                model_metrics.update({
                    'cached_var':
                    np.mean(self.fake_env._model.scaler_out.cached_var)
                })
                model_metrics.update({
                    'cached_mu':
                    np.mean(self.fake_env._model.scaler_out.cached_mu)
                })

                print('Rollouts finished')
                gt.stamp('epoch_rollout_model')

            #=====================================================================#
            #  Sample                                                             #
            #=====================================================================#
            n_real_samples = self.model_sampler.dyn_dkl / self.initial_model_dkl * self.min_real_samples_per_epoch
            n_real_samples = max(n_real_samples, 1000)
            # n_real_samples = self.min_real_samples_per_epoch ### for ablation

            model_metrics.update({'n_real_samples': n_real_samples})
            start_samples = self.sampler._total_samples
            ### train for epoch_length ###
            for i in count():

                #### _timestep is within an epoch
                samples_now = self.sampler._total_samples
                self._timestep = samples_now - start_samples

                #### not implemented atm
                self._timestep_before_hook()
                gt.stamp('timestep_before_hook')

                ##### sampling from the real world! #####
                _, _, _, _ = self._do_sampling(timestep=self.policy_epoch)
                gt.stamp('sample')

                self._timestep_after_hook()
                gt.stamp('timestep_after_hook')

                if self.ready_to_train or self._timestep > n_real_samples:
                    self.sampler.finish_all_paths(append_val=True,
                                                  append_cval=True,
                                                  reset_path=False)
                    self.new_real_samples += self._timestep
                    break

            #=====================================================================#
            #  Train model                                                        #
            #=====================================================================#
            if self.new_real_samples > 2048 and self._real_ratio < 1.0:
                model_diag = self.train_model(min_epochs=1, max_epochs=10)
                self.new_real_samples = 0
                model_metrics.update(model_diag)

            #=====================================================================#
            #  Get Buffer Data                                                    #
            #=====================================================================#
            real_samples, buf_diag = self._pool.get()

            ### run diagnostics on real data
            policy_diag = self._policy.run_diagnostics(real_samples)
            policy_diag = {k + '_r': v for k, v in policy_diag.items()}
            model_metrics.update(policy_diag)
            model_metrics.update(buf_diag)

            #=====================================================================#
            #  Update Policy                                                      #
            #=====================================================================#
            train_samples = [
                np.concatenate((r, m), axis=0)
                for r, m in zip(real_samples, model_samples)
            ] if model_samples else real_samples
            self._policy.update_real_c(real_samples)
            self._policy.update_policy(train_samples)
            self._policy.update_critic(
                train_samples,
                train_vc=(train_samples[-3] >
                          0).any())  ### only train vc if there are any costs

            if self._real_ratio < 1.0:
                self.approx_model_batch = self.batch_size_policy - n_real_samples  #self.model_sampler.dyn_dkl/self.initial_model_dkl * self.min_real_samples_per_epoch

            self.policy_epoch += 1
            self.max_tddyn_err *= self.max_tddyn_err_decay
            #### log policy diagnostics
            self._policy.log()

            gt.stamp('train')
            #=====================================================================#
            #  Log performance and stats                                          #
            #=====================================================================#

            self.sampler.log()
            # write results to file, ray prints for us, so no need to print from logger
            logger_diagnostics = self.logger.dump_tabular(
                output_dir=self._log_dir, print_out=False)
            #=====================================================================#

            if self._total_timestep // self.eval_every_n_steps > self.last_eval_step:
                evaluation_paths = self._evaluation_paths(
                    policy, evaluation_environment)
                gt.stamp('evaluation_paths')

                self.last_eval_step = self._total_timestep // self.eval_every_n_steps
            else:
                evaluation_paths = []

            if evaluation_paths:
                evaluation_metrics = self._evaluate_rollouts(
                    evaluation_paths, evaluation_environment)
                gt.stamp('evaluation_metrics')
                diag_obs_batch = np.concatenate(([
                    evaluation_paths[i]['observations']
                    for i in range(len(evaluation_paths))
                ]),
                                                axis=0)
            else:
                evaluation_metrics = {}
                diag_obs_batch = []

            gt.stamp('epoch_after_hook')

            new_diagnostics = {}

            time_diagnostics = gt.get_times().stamps.itrs

            # add diagnostics from logger
            new_diagnostics.update(logger_diagnostics)

            new_diagnostics.update(
                OrderedDict((
                    *((f'evaluation/{key}', evaluation_metrics[key])
                      for key in sorted(evaluation_metrics.keys())),
                    *((f'times/{key}', time_diagnostics[key][-1])
                      for key in sorted(time_diagnostics.keys())),
                    *((f'model/{key}', model_metrics[key])
                      for key in sorted(model_metrics.keys())),
                )))

            if self._eval_render_mode is not None and hasattr(
                    evaluation_environment, 'render_rollouts'):
                training_environment.render_rollouts(evaluation_paths)

            #### updating and averaging
            old_ts_diag = running_diag.get('timestep', 0)
            new_ts_diag = self._total_timestep - self.diag_counter - old_ts_diag
            w_olddiag = old_ts_diag / (new_ts_diag + old_ts_diag)
            w_newdiag = new_ts_diag / (new_ts_diag + old_ts_diag)
            running_diag = update_dict(running_diag,
                                       new_diagnostics,
                                       weight_a=w_olddiag,
                                       weight_b=w_newdiag)
            running_diag.update({'timestep': new_ts_diag + old_ts_diag})
            ####

            if new_ts_diag + old_ts_diag > self.eval_every_n_steps:
                running_diag.update({
                    'epoch': self._epoch,
                    'timesteps_total': self._total_timestep,
                    'train-steps': self._num_train_steps,
                })
                self.diag_counter = self._total_timestep
                diag = running_diag.copy()
                running_diag = {}
                yield diag

            if self._total_timestep >= self.n_env_interacts:
                self.sampler.terminate()

                self._training_after_hook()

                self._training_progress.close()
                print("###### DONE ######")
                yield {'done': True, **running_diag}

                break
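
    # Note on the diagnostics averaging above: `update_dict` blends the running
    # diagnostics with each epoch's new values using weights proportional to the
    # timesteps each side covers. For example, 3000 accumulated timesteps plus
    # 1000 new ones give weight_a = 3000/4000 = 0.75 for the old values and
    # weight_b = 1000/4000 = 0.25 for the new ones; the blended dict is yielded
    # once more than `eval_every_n_steps` timesteps have accumulated.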

    def train(self, *args, **kwargs):
        return self._train(*args, **kwargs)

    def _initial_exploration_hook(self, env, initial_exploration_policy, pool):
        if self._n_initial_exploration_steps < 1: return

        if not initial_exploration_policy:
            raise ValueError("Initial exploration policy must be provided when"
                             " n_initial_exploration_steps > 0.")

        self.sampler.initialize(env, initial_exploration_policy, pool)
        while True:
            self.sampler.sample(timestep=0)
            if self.sampler._total_samples >= self._n_initial_exploration_steps:
                self.sampler.finish_all_paths(append_val=True,
                                              append_cval=True,
                                              reset_path=False)
                pool.get()  # moves policy samples to archive
                break

        ### train model
        if self._real_ratio < 1.0:
            self.train_model(min_epochs=150, max_epochs=500)

    def train_model(self, min_epochs=5, max_epochs=100, batch_size=2048):
        self._dyn_model_train_freq = self._set_model_train_freq(
            self._dyn_model_train_freq,
            self._dyn_model_train_schedule)  ## set current train freq.
        self._cost_model_train_freq = self._set_model_train_freq(
            self._cost_model_train_freq, self._cost_model_train_schedule)

        print('[ MBPO ] log_dir: {} | ratio: {}'.format(
            self._log_dir, self._real_ratio))
        print(
            '[ MBPO ] Training model at epoch {} | freq {} | timestep {} (total: {}) | train steps (total: {})'
            .format(self._epoch, self._dyn_model_train_freq, self._timestep,
                    self._total_timestep, self._num_train_steps))

        model_samples = self._pool.get_archive([
            'observations',
            'actions',
            'next_observations',
            'rewards',
            'costs',
            'terminals',
            'epochs',
        ])

        diag_dyn = {}  # stays empty if the dynamics model is not retrained this epoch
        if self._epoch % self._dyn_model_train_freq == 0:
            diag_dyn = self.fake_env.train_dyn_model(
                model_samples,
                batch_size=batch_size,
                max_epochs=max_epochs,
                min_epoch_before_break=min_epochs,
                holdout_ratio=0.2,
                max_t=self._max_model_t)

        if self._epoch % self._cost_model_train_freq == 0 and self.fake_env.cares_about_cost:
            diag_c = self.fake_env.train_cost_model(
                model_samples,
                batch_size=batch_size,
                min_epoch_before_break=min_epochs,
                max_epochs=max_epochs,
                holdout_ratio=0.2,
                max_t=self._max_model_t)
            diag_dyn.update(diag_c)

        return diag_dyn

    @property
    def _total_timestep(self):
        total_timestep = self.sampler._total_samples
        return total_timestep

    def _set_rollout_length(self):
        min_epoch, max_epoch, min_length, max_length = self._rollout_schedule
        if self._epoch <= min_epoch:
            y = min_length
        else:
            dx = (self._epoch - min_epoch) / (max_epoch - min_epoch)
            dx = min(dx, 1)
            y = dx * (max_length - min_length) + min_length

        self._rollout_length = int(y)
        self.model_sampler.set_max_path_length(self._rollout_length)
        print(
            '[ Model Length ] Epoch: {} (min: {}, max: {}) | Length: {} (min: {} , max: {})'
            .format(self._epoch, min_epoch, max_epoch, self._rollout_length,
                    min_length, max_length))
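
    # Illustrative sketch (hypothetical helper, not part of the original class):
    # `_set_rollout_length` and `_set_model_train_freq` both interpolate linearly
    # over epochs. The helper below mirrors that arithmetic so it can be checked
    # in isolation, e.g. _linear_schedule_value(60, [20, 100, 1, 25]) == 13.
    @staticmethod
    def _linear_schedule_value(epoch, schedule):
        min_epoch, max_epoch, min_value, max_value = schedule
        if epoch <= min_epoch:
            return int(min_value)
        dx = min((epoch - min_epoch) / (max_epoch - min_epoch), 1)
        return int(dx * (max_value - min_value) + min_value)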

    def _do_sampling(self, timestep):
        return self.sampler.sample(timestep=timestep)

    def _set_model_train_freq(self, var, schedule):
        min_epoch, max_epoch, min_freq, max_freq = schedule
        if self._epoch <= min_epoch:
            y = min_freq
        else:
            dx = (self._epoch - min_epoch) / (max_epoch - min_epoch)
            dx = min(dx, 1)
            y = dx * (max_freq - min_freq) + min_freq

        var = int(y)
        print(
            '[ Model Train Frequency ] Epoch: {} (min: {}, max: {}) | Frequency: {} (min: {} , max: {})'
            .format(self._epoch, min_epoch, max_epoch, var, min_freq,
                    max_freq))
        return var

    def _evaluate_rollouts(self, paths, env):
        """Compute evaluation metrics for the given rollouts."""

        total_returns = [path['rewards'].sum() for path in paths]
        episode_lengths = [len(p['rewards']) for p in paths]
        total_cost = [path['cost'].sum() for path in paths]
        diagnostics = OrderedDict((
            ('return-average', np.mean(total_returns)),
            ('return-min', np.min(total_returns)),
            ('return-max', np.max(total_returns)),
            ('return-std', np.std(total_returns)),
            ('episode-length-avg', np.mean(episode_lengths)),
            ('episode-length-min', np.min(episode_lengths)),
            ('episode-length-max', np.max(episode_lengths)),
            ('episode-length-std', np.std(episode_lengths)),
            ('creturn-average', np.mean(total_cost)),
            ('creturn-fullep-average', np.mean(total_cost) /
             np.mean(episode_lengths) * self.sampler.max_path_length),
            ('creturn-min', np.min(total_cost)),
            ('creturn-max', np.max(total_cost)),
            ('creturn-std', np.std(total_cost)),
        ))

        env_infos = env.get_path_infos(paths)
        for key, value in env_infos.items():
            diagnostics[f'env_infos/{key}'] = value

        return diagnostics
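
    # Note: 'creturn-fullep-average' above rescales the mean episode cost to a
    # full-length episode. For example, a mean cost of 4.0 over a mean episode
    # length of 200 with max_path_length = 1000 reports 4.0 / 200 * 1000 = 20.0.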

    def _visualize_model(self, env, timestep):
        ## save env state
        state = env.unwrapped.state_vector()
        qpos_dim = len(env.unwrapped.sim.data.qpos)
        qpos = state[:qpos_dim]
        qvel = state[qpos_dim:]

        print('[ Visualization ] Starting | Epoch {} | Log dir: {}\n'.format(
            self._epoch, self._log_dir))
        visualize_policy(env, self.fake_env, self._policy, self._writer,
                         timestep)
        print('[ Visualization ] Done')
        ## set env state
        env.unwrapped.set_state(qpos, qvel)

    def get_diagnostics(self,
                        iteration,
                        obs_batch=None,
                        training_paths=None,
                        evaluation_paths=None):
        """Return diagnostic information as ordered dictionary.

        Records state value function, and TD-loss (mean squared Bellman error)
        for the sample batch.

        Also calls the `draw` method of the plotter, if plotter defined.
        """

        # @anyboby
        warnings.warn('diagnostics not implemented yet!')

        # diagnostics = OrderedDict({
        #     })

        # policy_diagnostics = self._policy.get_diagnostics(                  #use eval paths
        #     obs_batch[:,-self.active_obs_dim:])
        # diagnostics.update({
        #     f'policy/{key}': value
        #     for key, value in policy_diagnostics.items()
        # })

        # if self._plotter:
        #     self._plotter.draw()

        diagnostics = {}

        return diagnostics

    def save(self, savedir):
        self.fake_env._model.save(savedir, self._epoch)

    @property
    def tf_saveables(self):
        # assumes the policy exposes its saveables as a dict
        saveables = {**self._policy.tf_saveables}

        # saveables = {
        #     '_policy_optimizer': self._policy_optimizer,
        #     **{
        #         f'Q_optimizer_{i}': optimizer
        #         for i, optimizer in enumerate(self._Q_optimizers)
        #     },
        #     '_log_alpha': self._log_alpha,
        # }

        # if hasattr(self, '_alpha_optimizer'):
        #     saveables['_alpha_optimizer'] = self._alpha_optimizer

        return saveables
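
The rollout loop in `_train` above selects model-rollout start states from the replay archive with weights derived from the policy's KL divergence to the stored behavior (`compute_DKL` followed by `boltz_dist`). Those pool and policy methods are not shown in this example; the snippet below is only a minimal sketch of such a weighting, assuming `boltz_dist` amounts to a temperature-scaled softmax over the clipped KL values.

import numpy as np

def boltzmann_weights(kls, alpha=1.0):
    # Hypothetical stand-in for pool.boltz_dist: states whose stored behavior
    # diverges more from the current policy (larger KL) are sampled more often;
    # alpha acts as an inverse temperature.
    logits = alpha * np.clip(np.asarray(kls, dtype=np.float64), 0, None)
    logits -= logits.max()  # shift for numerical stability
    weights = np.exp(logits)
    return weights / weights.sum()

# usage: draw rollout start indices in proportion to the weights
kls = np.array([0.01, 0.2, 0.05, 0.8])
start_idx = np.random.choice(len(kls), size=3, p=boltzmann_weights(kls, alpha=2.0))
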
Beispiel #6
0
class MBPO(RLAlgorithm):
    """Model-Based Policy Optimization (MBPO)

    References
    ----------
        Michael Janner, Justin Fu, Marvin Zhang, Sergey Levine. 
        When to Trust Your Model: Model-Based Policy Optimization. 
        arXiv preprint arXiv:1906.08253. 2019.
    """

    def __init__(
            self,
            training_environment,
            evaluation_environment,
            policy,
            Qs,
            pool,
            static_fns,
            plotter=None,
            tf_summaries=False,

            lr=3e-4,
            reward_scale=1.0,
            target_entropy='auto',
            discount=0.99,
            tau=5e-3,
            target_update_interval=1,
            action_prior='uniform',
            reparameterize=False,
            store_extra_policy_info=False,

            deterministic=False,
            model_train_freq=250,
            model_train_slower=1,
            num_networks=7,
            num_elites=5,
            num_Q_elites=2,  # The size of the Q ensemble is set on the command line
            model_retain_epochs=20,
            rollout_batch_size=100e3,
            real_ratio=0.1,
            critic_same_as_actor=True,
            rollout_schedule=[20,100,1,1],
            hidden_dim=200,
            max_model_t=None,
            dir_name=None,
            evaluate_explore_freq=0,
            num_Q_per_grp=2,
            num_Q_grp=1,
            cross_grp_diff_batch=False,

            model_load_dir=None,
            model_load_index=None,
            model_log_freq=0,
            **kwargs,
    ):
        """
        Args:
            env (`SoftlearningEnv`): Environment used for training.
            policy: A policy function approximator.
            initial_exploration_policy: ('Policy'): A policy that we use
                for initial exploration which is not trained by the algorithm.
            Qs: Q-function approximators. The min of these
                approximators will be used. Usage of at least two Q-functions
                improves performance by reducing overestimation bias.
            pool (`PoolBase`): Replay pool to add gathered samples to.
            plotter (`QFPolicyPlotter`): Plotter instance to be used for
                visualizing Q-function during training.
            lr (`float`): Learning rate used for the function approximators.
            discount (`float`): Discount factor for Q-function updates.
            tau (`float`): Soft value function target update weight.
            target_update_interval ('int'): Frequency at which target network
                updates occur in iterations.
            reparameterize ('bool'): If True, we use a gradient estimator for
                the policy derived using the reparameterization trick. We use
                a likelihood ratio based estimator otherwise.
            critic_same_as_actor ('bool'): If True, use the same sampling schema
                (model free or model based) as the actor in critic training. 
                Otherwise, use model free sampling to train critic.
        """

        super(MBPO, self).__init__(**kwargs)

        if training_environment.unwrapped.spec.id.find("Fetch") != -1:
            # Fetch env
            obs_dim = sum([i.shape[0] for i in training_environment.observation_space.spaces.values()]) 
            self.multigoal = 1
        else:
            obs_dim = np.prod(training_environment.observation_space.shape)
            self.multigoal = 0
        # print("====", obs_dim, "========")

        act_dim = np.prod(training_environment.action_space.shape)

        # TODO: add variable scope to directly extract model parameters
        self._model_load_dir = model_load_dir
        print("============Model dir: ", self._model_load_dir)
        if model_load_index:
            latest_model_index = model_load_index
        else:
            latest_model_index = self._get_latest_index()
        self._model = construct_model(obs_dim=obs_dim,
                                      act_dim=act_dim,
                                      hidden_dim=hidden_dim,
                                      num_networks=num_networks,
                                      num_elites=num_elites,
                                      model_dir=self._model_load_dir,
                                      model_load_timestep=latest_model_index,
                                      load_model=True if model_load_dir else False)
        self._static_fns = static_fns
        self.fake_env = FakeEnv(self._model, self._static_fns)

        model_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self._model.name)
        all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

        self._rollout_schedule = rollout_schedule
        self._max_model_t = max_model_t

        # self._model_pool_size = model_pool_size
        # print('[ MBPO ] Model pool size: {:.2E}'.format(self._model_pool_size))
        # self._model_pool = SimpleReplayPool(pool._observation_space, pool._action_space, self._model_pool_size)

        self._model_retain_epochs = model_retain_epochs

        self._model_train_freq = model_train_freq
        self._rollout_batch_size = int(rollout_batch_size)
        self._deterministic = deterministic
        self._real_ratio = real_ratio

        self._log_dir = os.getcwd()
        self._writer = Writer(self._log_dir)

        self._training_environment = training_environment
        self._evaluation_environment = evaluation_environment
        self._policy = policy

        self._Qs = Qs
        self._Q_ensemble = len(Qs)
        self._Q_elites = num_Q_elites
        self._Q_targets = tuple(tf.keras.models.clone_model(Q) for Q in Qs)

        self._pool = pool
        self._plotter = plotter
        self._tf_summaries = tf_summaries

        self._policy_lr = lr
        self._Q_lr = lr

        self._reward_scale = reward_scale
        self._target_entropy = (
            -np.prod(self._training_environment.action_space.shape)
            if target_entropy == 'auto'
            else target_entropy)
        print('[ MBPO ] Target entropy: {}'.format(self._target_entropy))

        self._discount = discount
        self._tau = tau
        self._target_update_interval = target_update_interval
        self._action_prior = action_prior

        self._reparameterize = reparameterize
        self._store_extra_policy_info = store_extra_policy_info

        observation_shape = self._training_environment.active_observation_shape
        action_shape = self._training_environment.action_space.shape

        assert len(observation_shape) == 1, observation_shape
        self._observation_shape = observation_shape
        assert len(action_shape) == 1, action_shape
        self._action_shape = action_shape

        # self._critic_train_repeat = kwargs["critic_train_repeat"]
        # actor UTD should be n times larger or smaller than critic UTD
        assert self._actor_train_repeat % self._critic_train_repeat == 0 or \
               self._critic_train_repeat % self._actor_train_repeat == 0

        self._critic_train_freq = self._n_train_repeat // self._critic_train_repeat
        self._actor_train_freq = self._n_train_repeat // self._actor_train_repeat
        self._critic_same_as_actor = critic_same_as_actor
        self._model_train_slower = model_train_slower
        self._origin_model_train_epochs = 0

        self._dir_name = dir_name
        self._evaluate_explore_freq = evaluate_explore_freq

        # Qs within the same group are trained on the same data; Qs in different groups use different batches.
        self._num_Q_per_grp = num_Q_per_grp
        self._num_Q_grp = num_Q_grp
        self._cross_grp_diff_batch = cross_grp_diff_batch

        self._model_log_freq = model_log_freq

        self._build()
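
    # Note: with target_entropy='auto' above, the target is the negative action
    # dimensionality, e.g. a 6-dimensional action space gives a target entropy
    # of -6.0.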

    def _build(self):
        self._training_ops = {}
        self._actor_training_ops = {}
        self._critic_training_ops = {} if not self._cross_grp_diff_batch else \
                                    [{} for _ in range(self._num_Q_grp)]
        self._misc_training_ops = {}  # ops that need essentially no feed_dict
        # device = "/device:GPU:1"
        # with tf.device(device):
        #     self._init_global_step()
        #     self._init_placeholders()
        #     self._init_actor_update()
        #     self._init_critic_update()
        self._init_global_step()
        self._init_placeholders()
        self._init_actor_update()
        self._init_critic_update()

    def _train(self):
        
        """Return a generator that performs RL training.

        Args:
            env (`SoftlearningEnv`): Environment used for training.
            policy (`Policy`): Policy used for training
            initial_exploration_policy ('Policy'): Policy used for exploration
                If None, then all exploration is done using policy
            pool (`PoolBase`): Sample pool to add samples to
        """
        training_environment = self._training_environment
        evaluation_environment = self._evaluation_environment
        policy = self._policy
        pool = self._pool
        model_metrics = {}

        if not self._training_started:
            self._init_training()

            self._initial_exploration_hook(
                training_environment, self._initial_exploration_policy, pool)

        self.sampler.initialize(training_environment, policy, pool)

        gt.reset_root()
        gt.rename_root('RLAlgorithm')
        gt.set_def_unique(False)

        self._training_before_hook()

        for self._epoch in gt.timed_for(range(self._epoch, self._n_epochs)):

            self._epoch_before_hook()
            gt.stamp('epoch_before_hook')

            if self._evaluate_explore_freq != 0 and self._epoch % self._evaluate_explore_freq == 0:
                self._evaluate_exploration()

            self._training_progress = Progress(self._epoch_length * self._n_train_repeat)
            start_samples = self.sampler._total_samples
            for i in count():
                samples_now = self.sampler._total_samples
                self._timestep = samples_now - start_samples

                if (samples_now >= start_samples + self._epoch_length
                    and self.ready_to_train):
                    break
                self._timestep_before_hook()
                gt.stamp('timestep_before_hook')

                if self._timestep % self._model_train_freq == 0 and self._real_ratio < 1.0:
                    
                    self._training_progress.pause()
                    print('[ MBPO ] log_dir: {} | ratio: {}'.format(self._log_dir, self._real_ratio))
                    print('[ MBPO ] Training model at epoch {} | freq {} | timestep {} (total: {}) | epoch train steps: {} (total: {}) | times slower: {}'.format(
                        self._epoch, self._model_train_freq, self._timestep, self._total_timestep, self._train_steps_this_epoch, self._num_train_steps, self._model_train_slower)
                    )

                    if self._origin_model_train_epochs % self._model_train_slower == 0:
                        model_train_metrics = self._train_model(batch_size=256, max_epochs=None, holdout_ratio=0.2, max_t=self._max_model_t)
                        model_metrics.update(model_train_metrics)
                        gt.stamp('epoch_train_model')
                    else:
                        print('[ MBPO ] Skipping model training due to slowed training setting')
                    self._origin_model_train_epochs += 1
                    
                    self._set_rollout_length()
                    self._reallocate_model_pool()
                    model_rollout_metrics = self._rollout_model(rollout_batch_size=self._rollout_batch_size, deterministic=self._deterministic)
                    model_metrics.update(model_rollout_metrics)
                    
                    if self._model_log_freq != 0 and self._timestep % self._model_log_freq == 0:
                        self._log_model()

                    gt.stamp('epoch_rollout_model')
                    # self._visualize_model(self._evaluation_environment, self._total_timestep)
                    self._training_progress.resume()

                self._do_sampling(timestep=self._total_timestep)
                gt.stamp('sample')

                if self.ready_to_train:
                    self._do_training_repeats(timestep=self._total_timestep)
                gt.stamp('train')

                self._timestep_after_hook()
                gt.stamp('timestep_after_hook')

            training_paths = self.sampler.get_last_n_paths(
                math.ceil(self._epoch_length / self.sampler._max_path_length))
            gt.stamp('training_paths')
            evaluation_paths = self._evaluation_paths(
                policy, evaluation_environment)
            gt.stamp('evaluation_paths')

            training_metrics = self._evaluate_rollouts(
                training_paths, training_environment)
            gt.stamp('training_metrics')
            if evaluation_paths:
                evaluation_metrics = self._evaluate_rollouts(
                    evaluation_paths, evaluation_environment)
                gt.stamp('evaluation_metrics')
            else:
                evaluation_metrics = {}

            self._epoch_after_hook(training_paths)
            gt.stamp('epoch_after_hook')

            sampler_diagnostics = self.sampler.get_diagnostics()

            diagnostics = self.get_diagnostics(
                iteration=self._total_timestep,
                batch=self._evaluation_batch(),
                training_paths=training_paths,
                evaluation_paths=evaluation_paths)

            time_diagnostics = gt.get_times().stamps.itrs

            diagnostics.update(OrderedDict((
                *(
                    (f'evaluation/{key}', evaluation_metrics[key])
                    for key in sorted(evaluation_metrics.keys())
                ),
                *(
                    (f'training/{key}', training_metrics[key])
                    for key in sorted(training_metrics.keys())
                ),
                *(
                    (f'times/{key}', time_diagnostics[key][-1])
                    for key in sorted(time_diagnostics.keys())
                ),
                *(
                    (f'sampler/{key}', sampler_diagnostics[key])
                    for key in sorted(sampler_diagnostics.keys())
                ),
                *(
                    (f'model/{key}', model_metrics[key])
                    for key in sorted(model_metrics.keys())
                ),
                ('epoch', self._epoch),
                ('timestep', self._timestep),
                ('timesteps_total', self._total_timestep),
                ('train-steps', self._num_train_steps),
            )))

            if self._eval_render_mode is not None and hasattr(
                    evaluation_environment, 'render_rollouts'):
                training_environment.render_rollouts(evaluation_paths)


            yield diagnostics

        self.sampler.terminate()

        self._training_after_hook()

        self._training_progress.close()

        yield {'done': True, **diagnostics}
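
    # Minimal consumption sketch (assuming `algorithm` is an MBPO instance built
    # by the surrounding experiment runner, which is not shown here): _train() is
    # a generator, so each iteration advances training by one epoch and yields
    # that epoch's diagnostics, ending with a {'done': True, ...} dict.
    #
    #     for diagnostics in algorithm.train():
    #         if diagnostics.get('done', False):
    #             break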
    
    def _evaluate_exploration(self):
        print("=============evaluate exploration=========")

        # specify data dir
        base_dir = "/home/linus/Research/mbpo/mbpo_experiment/exploration_eval"
        if not self._dir_name:
            return
        data_dir = os.path.join(base_dir, self._dir_name)
        if not os.path.isdir(data_dir):
            os.mkdir(data_dir)

        # specify data name
        exp_name = "%d.pkl"%self._epoch
        path = os.path.join(data_dir, exp_name)

        evaluation_size = 3000
        action_repeat = 20
        batch = self.sampler.random_batch(evaluation_size)
        obs = batch['observations']
        actions_repeat = [self._policy.actions_np(obs) for _ in range(action_repeat)]

        Qs = []
        policy_std = []
        for action in actions_repeat:
            Q = []
            for (s,a) in zip(obs, action):
                s, a = np.array(s).reshape(1, -1), np.array(a).reshape(1, -1)
                Q.append(
                    self._session.run(
                        self._Q_values,
                        feed_dict = {
                            self._observations_ph: s,
                            self._actions_ph: a
                        }
                    )
                )
            Qs.append(Q)
        Qs = np.array(Qs).squeeze()
        Qs_mean_action = np.mean(Qs, axis = 0) # Compute mean across different actions of one given state.
        if self._cross_grp_diff_batch:
            inter_grp_q_stds = [np.std(Qs_mean_action[:, i * self._num_Q_per_grp:(i+1) * self._num_Q_per_grp], axis = 1) for i in range(self._num_Q_grp)]
            mean_inter_grp_q_std = np.mean(np.array(inter_grp_q_stds), axis = 0)
            min_qs_per_grp = [np.mean(Qs_mean_action[:, i * self._num_Q_per_grp:(i+1) * self._num_Q_per_grp], axis = 1) for i in range(self._num_Q_grp)]
            cross_grp_std = np.std(np.array(min_qs_per_grp), axis = 0)
        else:
            q_std = np.std(Qs_mean_action, axis=1)  # effectively the std of the state-value estimate across the Q ensemble

        policy_std = [np.prod(np.exp(self._policy.policy_log_scale_model.predict(np.array(s).reshape(1,-1)))) for s in obs]

        if self._cross_grp_diff_batch:
            data = { 
                'obs': obs,
                'inter_q_std': mean_inter_grp_q_std,
                'cross_q_std': cross_grp_std,
                'pi_std': policy_std
            }
        else:
            data = {
                'obs': obs,
                'q_std': q_std,
                'pi_std': policy_std
            }
        with open(path, 'wb') as f:
            pickle.dump(data, f)
        print("==========================================")


    def train(self, *args, **kwargs):
        return self._train(*args, **kwargs)

    def _log_policy(self):
        save_path = os.path.join(self._log_dir, 'models')
        filesystem.mkdir(save_path)
        weights = self._policy.get_weights()
        data = {'policy_weights': weights}
        full_path = os.path.join(save_path, 'policy_{}.pkl'.format(self._total_timestep))
        print('Saving policy to: {}'.format(full_path))
        pickle.dump(data, open(full_path, 'wb'))

    # TODO: use this function to save model
    def _log_model(self):
        save_path = os.path.join(self._log_dir, 'models')
        filesystem.mkdir(save_path)
        print('Saving model to: {}'.format(save_path))
        self._model.save(save_path, self._total_timestep)

    def _set_rollout_length(self):
        min_epoch, max_epoch, min_length, max_length = self._rollout_schedule
        if self._epoch <= min_epoch:
            y = min_length
        else:
            dx = (self._epoch - min_epoch) / (max_epoch - min_epoch)
            dx = min(dx, 1)
            y = dx * (max_length - min_length) + min_length

        self._rollout_length = int(y)
        print('[ Model Length ] Epoch: {} (min: {}, max: {}) | Length: {} (min: {} , max: {})'.format(
            self._epoch, min_epoch, max_epoch, self._rollout_length, min_length, max_length
        ))

    def _reallocate_model_pool(self):
        obs_space = self._pool._observation_space
        act_space = self._pool._action_space

        rollouts_per_epoch = self._rollout_batch_size * self._epoch_length / self._model_train_freq
        model_steps_per_epoch = int(self._rollout_length * rollouts_per_epoch)
        new_pool_size = self._model_retain_epochs * model_steps_per_epoch

        if not hasattr(self, '_model_pool'):
            print('[ MBPO ] Initializing new model pool with size {:.2e}'.format(
                new_pool_size
            ))
            self._model_pool = SimpleReplayPool(obs_space, act_space, new_pool_size)
        
        elif self._model_pool._max_size != new_pool_size:
            print('[ MBPO ] Updating model pool | {:.2e} --> {:.2e}'.format(
                self._model_pool._max_size, new_pool_size
            ))
            samples = self._model_pool.return_all_samples()
            new_pool = SimpleReplayPool(obs_space, act_space, new_pool_size)
            new_pool.add_samples(samples)
            assert self._model_pool.size == new_pool.size
            self._model_pool = new_pool
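
    # Illustrative sketch (hypothetical helper, not part of the original class):
    # the pool sizing used in `_reallocate_model_pool`, factored out so the
    # arithmetic is easy to check. For example, rollout_batch_size=100e3,
    # epoch_length=1000, model_train_freq=250, rollout_length=1 and
    # model_retain_epochs=20 give a capacity of 8_000_000.
    @staticmethod
    def _model_pool_capacity(rollout_batch_size, epoch_length, model_train_freq,
                             rollout_length, model_retain_epochs):
        rollouts_per_epoch = rollout_batch_size * epoch_length / model_train_freq
        model_steps_per_epoch = int(rollout_length * rollouts_per_epoch)
        return model_retain_epochs * model_steps_per_epoch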

    def _train_model(self, **kwargs):
        env_samples = self._pool.return_all_samples()
        # train_inputs, train_outputs = format_samples_for_training(env_samples, self.multigoal)
        train_inputs, train_outputs = format_samples_for_training(env_samples)
        model_metrics = self._model.train(train_inputs, train_outputs, **kwargs)
        return model_metrics

    def _rollout_model(self, rollout_batch_size, **kwargs):
        print('[ Model Rollout ] Starting | Epoch: {} | Rollout length: {} | Batch size: {}'.format(
            self._epoch, self._rollout_length, rollout_batch_size
        ))

        # Keep total rollout sample complexity unchanged
        batch = self.sampler.random_batch(rollout_batch_size // self._sample_repeat)
        obs = batch['observations']
        steps_added = []
        sampled_actions = []
        for _ in range(self._sample_repeat):
            for i in range(self._rollout_length):
                # TODO: alter policy distribution in different times of sample repeating
                # self._policy: softlearning.policies.gaussian_policy.FeedforwardGaussianPolicy
                # self._policy._deterministic: False
                # print("=====================================")
                # print(self._policy._deterministic)
                # print("=====================================")
                act = self._policy.actions_np(obs)
                sampled_actions.append(act)
                
                next_obs, rew, term, info = self.fake_env.step(obs, act, **kwargs)
                steps_added.append(len(obs))

                samples = {
                    'observations': obs,
                    'actions': act,
                    'next_observations': next_obs,
                    'rewards': rew,
                    'terminals': term,
                }
                self._model_pool.add_samples(samples)

                nonterm_mask = ~term.squeeze(-1)
                if nonterm_mask.sum() == 0:
                    print('[ Model Rollout ] Breaking early: {} | {} / {}'.format(i, nonterm_mask.sum(), nonterm_mask.shape))
                    break

                obs = next_obs[nonterm_mask]
        # print(sampled_actions)

        mean_rollout_length = sum(steps_added) / rollout_batch_size
        rollout_stats = {'mean_rollout_length': mean_rollout_length}
        print('[ Model Rollout ] Added: {:.1e} | Model pool: {:.1e} (max {:.1e}) | Length: {} | Train rep: {}'.format(
            sum(steps_added), self._model_pool.size, self._model_pool._max_size, mean_rollout_length, self._n_train_repeat
        ))
        return rollout_stats
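    # Rollout structure: each model rollout branches off a *real* state drawn from
    # the env pool, steps the learned dynamics (fake_env) with on-policy actions,
    # and pushes the imagined transitions into the model pool; rollouts whose
    # predicted `term` is True are dropped from further propagation via
    # nonterm_mask. Note on _sample_repeat > 1: as written, each repeat continues
    # from whatever `obs` the previous repeat ended on rather than restarting from
    # batch['observations'].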

    def _visualize_model(self, env, timestep):
        ## save env state
        state = env.unwrapped.state_vector()
        qpos_dim = len(env.unwrapped.sim.data.qpos)
        qpos = state[:qpos_dim]
        qvel = state[qpos_dim:]

        print('[ Visualization ] Starting | Epoch {} | Log dir: {}\n'.format(self._epoch, self._log_dir))
        visualize_policy(env, self.fake_env, self._policy, self._writer, timestep)
        print('[ Visualization ] Done')
        ## set env state
        env.unwrapped.set_state(qpos, qvel)

    def _training_batch(self, batch_size=None):
        batch_size = batch_size or self.sampler._batch_size
        env_batch_size = int(batch_size*self._real_ratio)
        model_batch_size = batch_size - env_batch_size

        ## can sample from the env pool even if env_batch_size == 0
        if self._cross_grp_diff_batch:
            env_batch = [self._pool.random_batch(env_batch_size) for _ in range(self._num_Q_grp)]
        else:
            env_batch = self._pool.random_batch(env_batch_size)

        if model_batch_size > 0:
            model_batch = self._model_pool.random_batch(model_batch_size)

            keys = env_batch.keys()
            batch = {k: np.concatenate((env_batch[k], model_batch[k]), axis=0) for k in keys}
        else:
            ## if real_ratio == 1.0, no model pool was ever allocated,
            ## so skip the model pool sampling
            batch = env_batch
        return batch, env_batch
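    # Mixing sketch (hypothetical numbers): with batch_size=256 and the default
    # real_ratio=0.1, env_batch_size = int(256 * 0.1) = 25 real transitions and
    # model_batch_size = 231 model-generated transitions per gradient step; with
    # real_ratio=1.0 the batch is purely real and the model pool is never sampled.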

    def _init_global_step(self):
        self.global_step = training_util.get_or_create_global_step()
        self._training_ops.update({
            'increment_global_step': training_util._increment_global_step(1)
        })
        self._misc_training_ops.update({
            'increment_global_step': training_util._increment_global_step(1)
        })

    def _init_placeholders(self):
        """Create input placeholders for the SAC algorithm.

        Creates `tf.placeholder`s for:
            - observation
            - next observation
            - action
            - reward
            - terminals
        """
        self._iteration_ph = tf.placeholder(
            tf.int64, shape=None, name='iteration')

        self._observations_ph = tf.placeholder(
            tf.float32,
            shape=(None, *self._observation_shape),
            name='observation',
        )

        self._next_observations_ph = tf.placeholder(
            tf.float32,
            shape=(None, *self._observation_shape),
            name='next_observation',
        )

        self._actions_ph = tf.placeholder(
            tf.float32,
            shape=(None, *self._action_shape),
            name='actions',
        )

        self._rewards_ph = tf.placeholder(
            tf.float32,
            shape=(None, 1),
            name='rewards',
        )

        self._terminals_ph = tf.placeholder(
            tf.float32,
            shape=(None, 1),
            name='terminals',
        )

        if self._store_extra_policy_info:
            self._log_pis_ph = tf.placeholder(
                tf.float32,
                shape=(None, 1),
                name='log_pis',
            )
            self._raw_actions_ph = tf.placeholder(
                tf.float32,
                shape=(None, *self._action_shape),
                name='raw_actions',
            )

    def _get_Q_target(self):
        next_actions = self._policy.actions([self._next_observations_ph])
        next_log_pis = self._policy.log_pis(
            [self._next_observations_ph], next_actions)

        next_Qs_values = tuple(
            Q([self._next_observations_ph, next_actions])
            for Q in self._Q_targets)
        # Line 8 of REDQ: min over a random subset of M (= _Q_elites) target critics.
        # Note: the subset is drawn once with np.random.choice at graph-construction
        # time, so the same M target Qs are used for every update.
        Qs_subset = np.random.choice(
            next_Qs_values, self._Q_elites, replace=False).tolist()
        min_next_Q = tf.reduce_min(Qs_subset, axis=0)
        next_value = min_next_Q - self._alpha * next_log_pis

        Q_target = td_target(
            reward=self._reward_scale * self._rewards_ph,
            discount=self._discount,
            next_value=(1 - self._terminals_ph) * next_value)

        return Q_target
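    # Target written out for reference (REDQ-style):
    #     y = reward_scale * r + gamma * (1 - done) * ( min_{i in M} Q_target_i(s', a') - alpha * log pi(a'|s') )
    # where a' ~ pi(.|s') and M is the random subset of `_Q_elites` target critics
    # selected above.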

    def _init_critic_update(self):
        """Create minimization operation for critic Q-function.

        Creates a `tf.optimizer.minimize` operation for updating
        critic Q-function with gradient descent, and appends it to
        `self._training_ops` attribute.
        """
        Q_target = tf.stop_gradient(self._get_Q_target())

        assert Q_target.shape.as_list() == [None, 1]

        Q_values = self._Q_values = tuple(
            Q([self._observations_ph, self._actions_ph])
            for Q in self._Qs)

        Q_losses = self._Q_losses = tuple(
            tf.losses.mean_squared_error(
                labels=Q_target, predictions=Q_value, weights=0.5)
            for Q_value in Q_values)

        self._Q_optimizers = tuple(
            tf.train.AdamOptimizer(
                learning_rate=self._Q_lr,
                name='{}_{}_optimizer'.format(Q._name, i)
            ) for i, Q in enumerate(self._Qs))

        # TODO: split this into N separate ops, where N is the number of Q groups
        Q_training_ops = tuple(
            tf.contrib.layers.optimize_loss(
                Q_loss,
                self.global_step,
                learning_rate=self._Q_lr,
                optimizer=Q_optimizer,
                variables=Q.trainable_variables,
                increment_global_step=False,
                summaries=((
                    "loss", "gradients", "gradient_norm", "global_gradient_norm"
                ) if self._tf_summaries else ()))
            for i, (Q, Q_loss, Q_optimizer)
            in enumerate(zip(self._Qs, Q_losses, self._Q_optimizers)))

        self._training_ops.update({'Q': tf.group(Q_training_ops)})
        if self._cross_grp_diff_batch:
            assert len(Q_training_ops) >= self._num_Q_grp * self._num_Q_per_grp
            # Slice the per-Q training ops into contiguous groups of
            # `_num_Q_per_grp`, one tf.group per Q group.
            for i in range(self._num_Q_grp - 1):
                self._critic_training_ops[i].update({
                    'Q': tf.group(Q_training_ops[i * self._num_Q_per_grp: (i + 1) * self._num_Q_per_grp])
                })

            self._critic_training_ops[self._num_Q_grp - 1].update({
                'Q': tf.group(Q_training_ops[(self._num_Q_grp - 1) * self._num_Q_per_grp:])
            })
        else:
            self._critic_training_ops.update({'Q': tf.group(Q_training_ops)})
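    # Each critic is regressed onto the shared stop-gradient target with a
    # 0.5-weighted MSE, i.e. loss_i = 0.5 * mean((y - Q_i(s, a))^2), the usual SAC
    # critic objective.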

    def _init_actor_update(self):
        """Create minimization operations for policy and entropy.

        Creates a `tf.optimizer.minimize` operations for updating
        policy and entropy with gradient descent, and adds them to
        `self._training_ops` attribute.
        """

        actions = self._policy.actions([self._observations_ph])
        log_pis = self._policy.log_pis([self._observations_ph], actions)

        assert log_pis.shape.as_list() == [None, 1]

        log_alpha = self._log_alpha = tf.get_variable(
            'log_alpha',
            dtype=tf.float32,
            initializer=0.0)
        alpha = tf.exp(log_alpha)

        if isinstance(self._target_entropy, Number):
            alpha_loss = -tf.reduce_mean(
                log_alpha * tf.stop_gradient(log_pis + self._target_entropy))

            self._alpha_optimizer = tf.train.AdamOptimizer(
                self._policy_lr, name='alpha_optimizer')
            self._alpha_train_op = self._alpha_optimizer.minimize(
                loss=alpha_loss, var_list=[log_alpha])

            self._training_ops.update({
                'temperature_alpha': self._alpha_train_op
            })
            self._actor_training_ops.update({
                'temperature_alpha': self._alpha_train_op
            })

        self._alpha = alpha

        if self._action_prior == 'normal':
            policy_prior = tf.contrib.distributions.MultivariateNormalDiag(
                loc=tf.zeros(self._action_shape),
                scale_diag=tf.ones(self._action_shape))
            policy_prior_log_probs = policy_prior.log_prob(actions)
        elif self._action_prior == 'uniform':
            policy_prior_log_probs = 0.0

        Q_log_targets = tuple(
            Q([self._observations_ph, actions])
            for Q in self._Qs)
        assert len(Q_log_targets) == self._Q_ensemble

        min_Q_log_target = tf.reduce_min(Q_log_targets, axis=0)
        mean_Q_log_target = tf.reduce_mean(Q_log_targets, axis=0)
        Q_target = min_Q_log_target if self._Q_ensemble == 2 else mean_Q_log_target

        if self._reparameterize:
            policy_kl_losses = (
                alpha * log_pis
                - Q_target
                - policy_prior_log_probs)
        else:
            raise NotImplementedError

        assert policy_kl_losses.shape.as_list() == [None, 1]

        policy_loss = tf.reduce_mean(policy_kl_losses)

        self._policy_optimizer = tf.train.AdamOptimizer(
            learning_rate=self._policy_lr,
            name="policy_optimizer")
        policy_train_op = tf.contrib.layers.optimize_loss(
            policy_loss,
            self.global_step,
            learning_rate=self._policy_lr,
            optimizer=self._policy_optimizer,
            variables=self._policy.trainable_variables,
            increment_global_step=False,
            summaries=(
                "loss", "gradients", "gradient_norm", "global_gradient_norm"
            ) if self._tf_summaries else ())

        self._training_ops.update({'policy_train_op': policy_train_op})
        self._actor_training_ops.update({'policy_train_op': policy_train_op})
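    # Objectives written out for reference (the uniform action prior contributes 0):
    #     J_pi        = E[ alpha * log pi(a|s) - Q(s, a) ],  a ~ pi(.|s)
    #     J_log_alpha = -E[ log_alpha * stop_grad(log pi(a|s) + target_entropy) ]
    # where Q(s, a) is the ensemble min when _Q_ensemble == 2 (standard SAC) and the
    # ensemble mean otherwise (REDQ-style actor target).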

    def _init_training(self):
        self._update_target(tau=1.0)

    def _update_target(self, tau=None):
        tau = tau or self._tau

        for Q, Q_target in zip(self._Qs, self._Q_targets):
            source_params = Q.get_weights()
            target_params = Q_target.get_weights()
            Q_target.set_weights([
                tau * source + (1.0 - tau) * target
                for source, target in zip(source_params, target_params)
            ])
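    # Polyak averaging: theta_target <- tau * theta + (1 - tau) * theta_target,
    # applied weight-by-weight; _init_training calls this once with tau=1.0 to copy
    # the online critics into their targets before training starts.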

    def _do_training(self, iteration, batch):
        """Runs the operations for updating training and target ops."""
        mix_batch, mf_batch = batch

        self._training_progress.update()
        self._training_progress.set_description()

        if self._cross_grp_diff_batch:
            assert len(mix_batch) == self._num_Q_grp
            if self._real_ratio != 1:
                assert 0, "Currently different batch is not supported in MBPO"
            mix_feed_dict = [self._get_feed_dict(iteration, i) for i in mix_batch]
            single_mix_feed_dict = mix_feed_dict[0]
        else:
            mix_feed_dict = self._get_feed_dict(iteration, mix_batch)
            single_mix_feed_dict = mix_feed_dict


        if self._critic_same_as_actor:
            critic_feed_dict = mix_feed_dict
        else:
            critic_feed_dict = self._get_feed_dict(iteration, mf_batch)

        self._session.run(self._misc_training_ops, single_mix_feed_dict)

        if iteration % self._actor_train_freq == 0:
            self._session.run(self._actor_training_ops, single_mix_feed_dict)
        if iteration % self._critic_train_freq == 0:
            if self._cross_grp_diff_batch:
                assert len(self._critic_training_ops) == len(critic_feed_dict)
                [
                    self._session.run(op, feed_dict)
                    for (op, feed_dict) in zip(self._critic_training_ops, critic_feed_dict)
                ]
            else:
                self._session.run(self._critic_training_ops, critic_feed_dict)

        if iteration % self._target_update_interval == 0:
            # Run target ops here.
            self._update_target()
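    # Update cadence: the misc ops (global-step increment) run on every call, actor
    # and temperature ops every `_actor_train_freq` iterations, critic ops every
    # `_critic_train_freq` iterations, and the target networks are soft-updated
    # every `_target_update_interval` iterations.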

    def _get_feed_dict(self, iteration, batch):
        """Construct TensorFlow feed_dict from sample batch."""

        feed_dict = {
            self._observations_ph: batch['observations'],
            self._actions_ph: batch['actions'],
            self._next_observations_ph: batch['next_observations'],
            self._rewards_ph: batch['rewards'],
            self._terminals_ph: batch['terminals'],
        }

        if self._store_extra_policy_info:
            feed_dict[self._log_pis_ph] = batch['log_pis']
            feed_dict[self._raw_actions_ph] = batch['raw_actions']

        if iteration is not None:
            feed_dict[self._iteration_ph] = iteration

        return feed_dict

    def get_diagnostics(self,
                        iteration,
                        batch,
                        training_paths,
                        evaluation_paths):
        """Return diagnostic information as ordered dictionary.

        Records mean and standard deviation of Q-function and state
        value function, and TD-loss (mean squared Bellman error)
        for the sample batch.

        Also calls the `draw` method of the plotter, if plotter defined.
        """
        mix_batch, _ = batch
        if self._cross_grp_diff_batch:
            mix_batch = mix_batch[0]
        mix_feed_dict = self._get_feed_dict(iteration, mix_batch)

        # (Q_values, Q_losses, alpha, global_step) = self._session.run(
        #     (self._Q_values,
        #      self._Q_losses,
        #      self._alpha,
        #      self.global_step),
        #     feed_dict)
        Q_values, Q_losses = self._session.run(
            [self._Q_values, self._Q_losses],
            mix_feed_dict
        )

        alpha, global_step = self._session.run(
            [self._alpha, self.global_step],
            mix_feed_dict
        )

        diagnostics = OrderedDict({
            'Q-avg': np.mean(Q_values),
            'Q-std': np.std(Q_values),
            'Q_loss': np.mean(Q_losses),
            'alpha': alpha,
        })

        policy_diagnostics = self._policy.get_diagnostics(
            mix_batch['observations'])
        diagnostics.update({
            f'policy/{key}': value
            for key, value in policy_diagnostics.items()
        })

        if self._plotter:
            self._plotter.draw()

        return diagnostics

    @property
    def tf_saveables(self):
        saveables = {
            '_policy_optimizer': self._policy_optimizer,
            **{
                f'Q_optimizer_{i}': optimizer
                for i, optimizer in enumerate(self._Q_optimizers)
            },
            '_log_alpha': self._log_alpha,
        }

        if hasattr(self, '_alpha_optimizer'):
            saveables['_alpha_optimizer'] = self._alpha_optimizer

        return saveables

    def _get_latest_index(self):
        if self._model_load_dir is None:
            return
        return max([int(i.split("_")[1].split(".")[0]) for i in os.listdir(self._model_load_dir)])
Beispiel #7
0
    def _train(self):
        """Return a generator that performs RL training.

        Args:
            env (`SoftlearningEnv`): Environment used for training.
            policy (`Policy`): Policy used for training
            initial_exploration_policy ('Policy'): Policy used for exploration
                If None, then all exploration is done using policy
            pool (`PoolBase`): Sample pool to add samples to
        """

        #### pool is e.g. simple_replay_pool
        training_environment = self._training_environment
        evaluation_environment = self._evaluation_environment
        policy = self._policy
        pool = self._pool
        model_metrics = {}

        #### SAC setup: copy the online Q weights into the target Qs
        if not self._training_started:
            self._init_training()

            #### perform some initial steps (gather samples) using initial policy
            ######  fills pool with _n_initial_exploration_steps samples
            self._initial_exploration_hook(training_environment,
                                           self._initial_exploration_policy,
                                           pool)

        #### set up sampler with train env and actual policy (may be different from initial exploration policy)
        ######## note: sampler is set up with the pool that may be already filled from initial exploration hook
        self.sampler.initialize(training_environment, policy, pool)

        #### reset gtimer (used to timestamp the phases of each epoch)
        gt.reset_root()
        gt.rename_root('RLAlgorithm')
        gt.set_def_unique(False)

        #### not implemented, could train policy before hook
        self._training_before_hook()

        #### iterate over epochs, gt.timed_for to create loop with gt timestamps
        for self._epoch in gt.timed_for(range(self._epoch, self._n_epochs)):
            #### do something at beginning of epoch (in this case reset self._train_steps_this_epoch=0)
            self._epoch_before_hook()
            gt.stamp('epoch_before_hook')

            #### util class Progress, e.g. for plotting a progress bar
            #######   note: sampler may already contain samples in its pool from initial_exploration_hook or previous epochs
            self._training_progress = Progress(self._epoch_length *
                                               self._n_train_repeat /
                                               self._train_every_n_steps)
            start_samples = self.sampler._total_samples

            ### train for epoch_length ###
            for i in count():

                #### _timestep is within an epoch
                samples_now = self.sampler._total_samples
                self._timestep = samples_now - start_samples

                #### check if you're at the end of an epoch to train
                if (samples_now >= start_samples + self._epoch_length
                        and self.ready_to_train):
                    break

                #### not implemented atm
                self._timestep_before_hook()
                gt.stamp('timestep_before_hook')

                #### start model rollout
                if self._timestep % self._model_train_freq == 0 and self._real_ratio < 1.0:
                    self._training_progress.pause()
                    print('[ MBPO ] log_dir: {} | ratio: {}'.format(
                        self._log_dir, self._real_ratio))
                    print(
                        '[ MBPO ] Training model at epoch {} | freq {} | timestep {} (total: {}) | epoch train steps: {} (total: {})'
                        .format(self._epoch, self._model_train_freq,
                                self._timestep, self._total_timestep,
                                self._train_steps_this_epoch,
                                self._num_train_steps))

                    #### train the model with input:(obs, act), outputs: (rew, delta_obs), inputs are divided into sets with holdout_ratio
                    #@anyboby debug
                    samples = self._pool.return_all_samples()
                    self.fake_env.reset_model()
                    model_train_metrics = self.fake_env.train(
                        samples,
                        batch_size=512,
                        max_epochs=None,
                        holdout_ratio=0.2,
                        max_t=self._max_model_t)
                    model_metrics.update(model_train_metrics)
                    gt.stamp('epoch_train_model')

                    #### rollout model env ####
                    self._set_rollout_length()
                    self._reallocate_model_pool(
                        use_mjc_model_pool=self.use_mjc_state_model)
                    model_rollout_metrics = self._rollout_model(
                        rollout_batch_size=self._rollout_batch_size,
                        deterministic=self._deterministic)
                    model_metrics.update(model_rollout_metrics)
                    ###########################

                    gt.stamp('epoch_rollout_model')
                    # self._visualize_model(self._evaluation_environment, self._total_timestep)
                    self._training_progress.resume()

                ##### sampling from the real world #####
                ##### _total_timestep % train_every_n_steps is checked inside _do_sampling
                self._do_sampling(timestep=self._total_timestep)
                gt.stamp('sample')

                ### n_train_repeat from config ###
                if self.ready_to_train:
                    self._do_training_repeats(timestep=self._total_timestep)
                gt.stamp('train')

                self._timestep_after_hook()
                gt.stamp('timestep_after_hook')

            training_paths = self.sampler.get_last_n_paths(
                math.ceil(self._epoch_length / self.sampler._max_path_length))
            gt.stamp('training_paths')
            evaluation_paths = self._evaluation_paths(policy,
                                                      evaluation_environment)
            gt.stamp('evaluation_paths')

            training_metrics = self._evaluate_rollouts(training_paths,
                                                       training_environment)
            gt.stamp('training_metrics')
            if evaluation_paths:
                evaluation_metrics = self._evaluate_rollouts(
                    evaluation_paths, evaluation_environment)
                gt.stamp('evaluation_metrics')
            else:
                evaluation_metrics = {}

            self._epoch_after_hook(training_paths)
            gt.stamp('epoch_after_hook')

            sampler_diagnostics = self.sampler.get_diagnostics()

            diagnostics = self.get_diagnostics(
                iteration=self._total_timestep,
                batch=self._evaluation_batch(),
                training_paths=training_paths,
                evaluation_paths=evaluation_paths)

            time_diagnostics = gt.get_times().stamps.itrs

            diagnostics.update(
                OrderedDict((
                    *((f'evaluation/{key}', evaluation_metrics[key])
                      for key in sorted(evaluation_metrics.keys())),
                    *((f'training/{key}', training_metrics[key])
                      for key in sorted(training_metrics.keys())),
                    *((f'times/{key}', time_diagnostics[key][-1])
                      for key in sorted(time_diagnostics.keys())),
                    *((f'sampler/{key}', sampler_diagnostics[key])
                      for key in sorted(sampler_diagnostics.keys())),
                    *((f'model/{key}', model_metrics[key])
                      for key in sorted(model_metrics.keys())),
                    ('epoch', self._epoch),
                    ('timestep', self._timestep),
                    ('timesteps_total', self._total_timestep),
                    ('train-steps', self._num_train_steps),
                )))

            if self._eval_render_mode is not None and hasattr(
                    evaluation_environment, 'render_rollouts'):
                training_environment.render_rollouts(evaluation_paths)

            yield diagnostics

        self.sampler.terminate()

        self._training_after_hook()

        self._training_progress.close()

        ### this is where we yield the episode diagnostics to tune trial runner ###
        yield {'done': True, **diagnostics}
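    # Minimal driver sketch (hypothetical; the actual experiment runner lives
    # outside this class). The generator yields per-epoch diagnostics and finishes
    # with a {'done': True, ...} dict:
    #
    #     for diagnostics in algorithm._train():
    #         report(diagnostics)          # `report` is a stand-in for the runner's logging
    #         if diagnostics.get('done'):
    #             break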