Example #1
    def _reallocate_model_pool(self):
        obs_space = self._pool._observation_space
        act_space = self._pool._action_space

        rollouts_per_epoch = self._rollout_batch_size * self._epoch_length / self._model_train_freq
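        # Pool sizing: each epoch triggers epoch_length / model_train_freq rollout
        # batches of rollout_batch_size start states, each rolled out for
        # (forward + backward) model steps; the pool retains the last
        # model_retain_epochs epochs of generated transitions.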
        model_steps_per_epoch = int(
            (self._forward_rollout_length + self._backward_rollout_length) *
            rollouts_per_epoch)
        new_pool_size = self._model_retain_epochs * model_steps_per_epoch

        if not hasattr(self, '_model_pool'):
            print(
                '[ Allocate Model Pool ] Initializing new model pool with size {:.2e}'
                .format(new_pool_size))
            self._model_pool = SimpleReplayPool(obs_space, act_space,
                                                new_pool_size)

        elif self._model_pool._max_size != new_pool_size:
            print(
                '[ Reallocate Model Pool ] Updating model pool | {:.2e} --> {:.2e}'
                .format(self._model_pool._max_size, new_pool_size))
            samples = self._model_pool.return_all_samples()
            new_pool = SimpleReplayPool(obs_space, act_space, new_pool_size)
            new_pool.add_samples(samples)
            assert self._model_pool.size == new_pool.size
            self._model_pool = new_pool
Example #2
class MBPO(RLAlgorithm):
    """Model-Based Policy Optimization (MBPO)

    References
    ----------
        Michael Janner, Justin Fu, Marvin Zhang, Sergey Levine. 
        When to Trust Your Model: Model-Based Policy Optimization. 
        arXiv preprint arXiv:1906.08253. 2019.
    """
    def __init__(
        self,
        training_environment,
        evaluation_environment,
        policy,
        Qs,
        pool,
        static_fns,
        plotter=None,
        tf_summaries=False,
        lr=3e-4,
        reward_scale=1.0,
        target_entropy='auto',
        discount=0.99,
        tau=5e-3,
        target_update_interval=1,
        action_prior='uniform',
        reparameterize=False,
        store_extra_policy_info=False,
        deterministic=False,
        model_train_freq=250,
        num_networks=7,
        num_elites=5,
        model_retain_epochs=20,
        load_model_dir=None,
        rollout_batch_size=100e3,
        real_ratio=0.1,
        rollout_schedule=[20, 100, 1, 1],
        hidden_dim=200,
        max_model_t=None,
        **kwargs,
    ):
        """
        Args:
            training_environment (`SoftlearningEnv`): Environment used for training.
            evaluation_environment (`SoftlearningEnv`): Environment used for evaluation.
            policy: A policy function approximator.
            initial_exploration_policy (`Policy`): A policy used for initial
                exploration that is not trained by the algorithm.
            Qs: Q-function approximators. The min of these
                approximators will be used. Usage of at least two Q-functions
                improves performance by reducing overestimation bias.
            pool (`PoolBase`): Replay pool to add gathered samples to.
            plotter (`QFPolicyPlotter`): Plotter instance to be used for
                visualizing Q-function during training.
            lr (`float`): Learning rate used for the function approximators.
            discount (`float`): Discount factor for Q-function updates.
            tau (`float`): Soft value function target update weight.
            target_update_interval (`int`): Frequency, in iterations, at which
                target network updates occur.
            reparameterize (`bool`): If True, we use a gradient estimator for
                the policy derived using the reparameterization trick; otherwise
                we use a likelihood-ratio-based estimator.
        """

        super(MBPO, self).__init__(**kwargs)

        obs_dim = np.prod(training_environment.observation_space.shape)
        act_dim = np.prod(training_environment.action_space.shape)

        self._model_params = dict(obs_dim=obs_dim,
                                  act_dim=act_dim,
                                  hidden_dim=hidden_dim,
                                  num_networks=num_networks,
                                  num_elites=num_elites)
        if load_model_dir is not None:
            self._model_params['load_model'] = True
            self._model_params['model_dir'] = load_model_dir
        self._model = construct_model(**self._model_params)

        self._static_fns = static_fns
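        # FakeEnv pairs the learned dynamics model with the environment's static
        # functions so model rollouts can be stepped like a real environment.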
        self.fake_env = FakeEnv(self._model, self._static_fns)

        self._rollout_schedule = rollout_schedule
        self._max_model_t = max_model_t

        self._model_retain_epochs = model_retain_epochs

        self._model_train_freq = model_train_freq
        self._rollout_batch_size = int(rollout_batch_size)
        self._deterministic = deterministic
        self._real_ratio = real_ratio

        self._log_dir = os.getcwd()
        self._writer = Writer(self._log_dir)

        self._training_environment = training_environment
        self._evaluation_environment = evaluation_environment
        self._policy = policy

        self._Qs = Qs
        self._Q_targets = tuple(tf.keras.models.clone_model(Q) for Q in Qs)

        self._pool = pool
        self._plotter = plotter
        self._tf_summaries = tf_summaries

        self._policy_lr = lr
        self._Q_lr = lr

        self._reward_scale = reward_scale
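        # 'auto' uses the standard SAC heuristic: target entropy = -|action dims|.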
        self._target_entropy = (
            -np.prod(self._training_environment.action_space.shape)
            if target_entropy == 'auto' else target_entropy)
        print('[ MBPO ] Target entropy: {}'.format(self._target_entropy))

        self._discount = discount
        self._tau = tau
        self._target_update_interval = target_update_interval
        self._action_prior = action_prior

        self._reparameterize = reparameterize
        self._store_extra_policy_info = store_extra_policy_info

        observation_shape = self._training_environment.active_observation_shape
        action_shape = self._training_environment.action_space.shape

        ### @anyboby: use a fixed model pool size; reallocating the pool causes a memory leak
        obs_space = self._pool._observation_space
        act_space = self._pool._action_space
        rollouts_per_epoch = self._rollout_batch_size * self._epoch_length / self._model_train_freq
        model_steps_per_epoch = int(self._rollout_schedule[-1] *
                                    rollouts_per_epoch)
        mpool_size = self._model_retain_epochs * model_steps_per_epoch

        self._model_pool = SimpleReplayPool(obs_space, act_space, mpool_size)

        assert len(observation_shape) == 1, observation_shape
        self._observation_shape = observation_shape
        assert len(action_shape) == 1, action_shape
        self._action_shape = action_shape

        self._build()

    def _build(self):
        self._training_ops = {}

        self._init_global_step()
        self._init_placeholders()
        self._init_actor_update()
        self._init_critic_update()

    def _train(self):
        """Return a generator that performs RL training.

        Args:
            env (`SoftlearningEnv`): Environment used for training.
            policy (`Policy`): Policy used for training
            initial_exploration_policy ('Policy'): Policy used for exploration
                If None, then all exploration is done using policy
            pool (`PoolBase`): Sample pool to add samples to
        """
        training_environment = self._training_environment
        evaluation_environment = self._evaluation_environment
        policy = self._policy
        pool = self._pool
        model_metrics = {}

        if not self._training_started:
            self._init_training()

            self._initial_exploration_hook(training_environment,
                                           self._initial_exploration_policy,
                                           pool)

        self.sampler.initialize(training_environment, policy, pool)

        gt.reset_root()
        gt.rename_root('RLAlgorithm')
        gt.set_def_unique(False)

        self._training_before_hook()

        for self._epoch in gt.timed_for(range(self._epoch, self._n_epochs)):

            self._epoch_before_hook()
            gt.stamp('epoch_before_hook')

            self._training_progress = Progress(self._epoch_length *
                                               self._n_train_repeat)
            start_samples = self.sampler._total_samples
            for i in count():
                samples_now = self.sampler._total_samples
                self._timestep = samples_now - start_samples

                if (samples_now >= start_samples + self._epoch_length
                        and self.ready_to_train):
                    break

                self._timestep_before_hook()
                gt.stamp('timestep_before_hook')
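                # Every model_train_freq environment steps (and only when model
                # data is actually used, i.e. real_ratio < 1), retrain the
                # dynamics model on the real pool and refill the model pool
                # with short branched rollouts from real states.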

                if self._timestep % self._model_train_freq == 0 and self._real_ratio < 1.0:
                    self._training_progress.pause()
                    print('[ MBPO ] log_dir: {} | ratio: {}'.format(
                        self._log_dir, self._real_ratio))
                    print(
                        '[ MBPO ] Training model at epoch {} | freq {} | timestep {} (total: {}) | epoch train steps: {} (total: {})'
                        .format(self._epoch, self._model_train_freq,
                                self._timestep, self._total_timestep,
                                self._train_steps_this_epoch,
                                self._num_train_steps))

                    model_train_metrics = self._train_model(
                        batch_size=256,
                        max_epochs=None,
                        holdout_ratio=0.2,
                        max_t=self._max_model_t)
                    model_metrics.update(model_train_metrics)
                    gt.stamp('epoch_train_model')

                    self._set_rollout_length()
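                    # Split very large rollout batches into chunks of at most
                    # ~30000 start states, likely to bound memory use during rollouts.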
                    if self._rollout_batch_size > 30000:
                        factor = self._rollout_batch_size // 30000 + 1
                        mini_batch = self._rollout_batch_size // factor
                        for _ in range(factor):
                            model_rollout_metrics = self._rollout_model(
                                rollout_batch_size=mini_batch,
                                deterministic=self._deterministic)
                    else:
                        model_rollout_metrics = self._rollout_model(
                            rollout_batch_size=self._rollout_batch_size,
                            deterministic=self._deterministic)
                    model_metrics.update(model_rollout_metrics)

                    gt.stamp('epoch_rollout_model')
                    # self._visualize_model(self._evaluation_environment, self._total_timestep)
                    self._training_progress.resume()

                self._do_sampling(timestep=self._total_timestep)
                gt.stamp('sample')

                if self.ready_to_train:
                    self._do_training_repeats(timestep=self._total_timestep)
                gt.stamp('train')

                self._timestep_after_hook()
                gt.stamp('timestep_after_hook')

            training_paths = self.sampler.get_last_n_paths(
                math.ceil(self._epoch_length / self.sampler._max_path_length))
            gt.stamp('training_paths')
            evaluation_paths = self._evaluation_paths(policy,
                                                      evaluation_environment)
            gt.stamp('evaluation_paths')

            training_metrics = self._evaluate_rollouts(training_paths,
                                                       training_environment)
            gt.stamp('training_metrics')
            if evaluation_paths:
                evaluation_metrics = self._evaluate_rollouts(
                    evaluation_paths, evaluation_environment)
                gt.stamp('evaluation_metrics')
            else:
                evaluation_metrics = {}

            self._epoch_after_hook(training_paths)
            gt.stamp('epoch_after_hook')

            sampler_diagnostics = self.sampler.get_diagnostics()

            diagnostics = self.get_diagnostics(
                iteration=self._total_timestep,
                batch=self._evaluation_batch(),
                training_paths=training_paths,
                evaluation_paths=evaluation_paths)

            time_diagnostics = gt.get_times().stamps.itrs

            diagnostics.update(
                OrderedDict((
                    *((f'evaluation/{key}', evaluation_metrics[key])
                      for key in sorted(evaluation_metrics.keys())),
                    *((f'training/{key}', training_metrics[key])
                      for key in sorted(training_metrics.keys())),
                    *((f'times/{key}', time_diagnostics[key][-1])
                      for key in sorted(time_diagnostics.keys())),
                    *((f'sampler/{key}', sampler_diagnostics[key])
                      for key in sorted(sampler_diagnostics.keys())),
                    *((f'model/{key}', model_metrics[key])
                      for key in sorted(model_metrics.keys())),
                    ('epoch', self._epoch),
                    ('timestep', self._timestep),
                    ('timesteps_total', self._total_timestep),
                    ('train-steps', self._num_train_steps),
                )))

            if self._eval_render_mode is not None and hasattr(
                    evaluation_environment, 'render_rollouts'):
                training_environment.render_rollouts(evaluation_paths)

            yield diagnostics

        self.sampler.terminate()

        self._training_after_hook()

        self._training_progress.close()

        yield {'done': True, **diagnostics}

    def train(self, *args, **kwargs):
        return self._train(*args, **kwargs)

    def _log_policy(self):
        save_path = os.path.join(self._log_dir, 'models')
        filesystem.mkdir(save_path)
        weights = self._policy.get_weights()
        data = {'policy_weights': weights}
        full_path = os.path.join(save_path,
                                 'policy_{}.pkl'.format(self._total_timestep))
        print('Saving policy to: {}'.format(full_path))
        pickle.dump(data, open(full_path, 'wb'))

    def _log_model(self):
        save_path = os.path.join(self._log_dir, 'models')
        filesystem.mkdir(save_path)
        print('Saving model to: {}'.format(save_path))
        self._model.save(save_path, self._total_timestep)

    def _set_rollout_length(self):
        min_epoch, max_epoch, min_length, max_length = self._rollout_schedule
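        # Linearly anneal the rollout length from min_length to max_length
        # between min_epoch and max_epoch (clamped at both ends).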
        if self._epoch <= min_epoch:
            y = min_length
        else:
            dx = (self._epoch - min_epoch) / (max_epoch - min_epoch)
            dx = min(dx, 1)
            y = dx * (max_length - min_length) + min_length

        self._rollout_length = int(y)
        print(
            '[ Model Length ] Epoch: {} (min: {}, max: {}) | Length: {} (min: {} , max: {})'
            .format(self._epoch, min_epoch, max_epoch, self._rollout_length,
                    min_length, max_length))

    def _reallocate_model_pool(self):
        obs_space = self._pool._observation_space
        act_space = self._pool._action_space

        rollouts_per_epoch = self._rollout_batch_size * self._epoch_length / self._model_train_freq
        model_steps_per_epoch = int(self._rollout_length * rollouts_per_epoch)
        new_pool_size = self._model_retain_epochs * model_steps_per_epoch

        if not hasattr(self, '_model_pool'):
            print(
                '[ MBPO ] Initializing new model pool with size {:.2e}'.format(
                    new_pool_size))
            self._model_pool = SimpleReplayPool(obs_space, act_space,
                                                new_pool_size)

        elif self._model_pool._max_size != new_pool_size:
            print('[ MBPO ] Updating model pool | {:.2e} --> {:.2e}'.format(
                self._model_pool._max_size, new_pool_size))
            samples = self._model_pool.return_all_samples()
            new_pool = SimpleReplayPool(obs_space, act_space, new_pool_size)
            new_pool.add_samples(samples)
            assert self._model_pool.size == new_pool.size
            self._model_pool = new_pool

    def _train_model(self, **kwargs):
        env_samples = self._pool.return_all_samples()
        train_inputs, train_outputs = format_samples_for_training(env_samples)
        model_metrics = self._model.train(train_inputs, train_outputs,
                                          **kwargs)
        return model_metrics

    def _rollout_model(self, rollout_batch_size, **kwargs):
        print(
            '[ Model Rollout ] Starting | Epoch: {} | Rollout length: {} | Batch size: {}'
            .format(self._epoch, self._rollout_length, rollout_batch_size))
        batch = self.sampler.random_batch(rollout_batch_size)
        obs = batch['observations']
        steps_added = []
        for i in range(self._rollout_length):
            act = self._policy.actions_np(obs)

            next_obs, rew, term, info = self.fake_env.step(obs, act, **kwargs)
            steps_added.append(len(obs))

            samples = {
                'observations': obs,
                'actions': act,
                'next_observations': next_obs,
                'rewards': rew,
                'terminals': term
            }
            self._model_pool.add_samples(samples)

            nonterm_mask = ~term.squeeze(-1)
            if nonterm_mask.sum() == 0:
                print('[ Model Rollout ] Breaking early: {} | {} / {}'.format(
                    i, nonterm_mask.sum(), nonterm_mask.shape))
                break

            obs = next_obs[nonterm_mask]

        mean_rollout_length = sum(steps_added) / rollout_batch_size
        rollout_stats = {'mean_rollout_length': mean_rollout_length}
        print(
            '[ Model Rollout ] Added: {:.1e} | Model pool: {:.1e} (max {:.1e}) | Length: {} | Train rep: {}'
            .format(sum(steps_added), self._model_pool.size,
                    self._model_pool._max_size, mean_rollout_length,
                    self._n_train_repeat))
        return rollout_stats

    def _visualize_model(self, env, timestep):
        ## save env state
        state = env.unwrapped.state_vector()
        qpos_dim = len(env.unwrapped.sim.data.qpos)
        qpos = state[:qpos_dim]
        qvel = state[qpos_dim:]

        print('[ Visualization ] Starting | Epoch {} | Log dir: {}\n'.format(
            self._epoch, self._log_dir))
        visualize_policy(env, self.fake_env, self._policy, self._writer,
                         timestep)
        print('[ Visualization ] Done')
        ## set env state
        env.unwrapped.set_state(qpos, qvel)

    def _training_batch(self, batch_size=None):
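        # Each SAC training batch mixes a `real_ratio` fraction of real
        # transitions from the environment pool with model-generated
        # transitions from the model pool.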
        batch_size = batch_size or self.sampler._batch_size
        env_batch_size = int(batch_size * self._real_ratio)
        model_batch_size = batch_size - env_batch_size

        ## can sample from the env pool even if env_batch_size == 0
        env_batch = self._pool.random_batch(env_batch_size)

        if model_batch_size > 0:
            model_batch = self._model_pool.random_batch(model_batch_size)

            keys = env_batch.keys()
            batch = {
                k: np.concatenate((env_batch[k], model_batch[k]), axis=0)
                for k in keys
            }
        else:
            ## if real_ratio == 1.0, no model pool was ever allocated,
            ## so skip the model pool sampling
            batch = env_batch
        return batch

    def _init_global_step(self):
        self.global_step = training_util.get_or_create_global_step()
        self._training_ops.update(
            {'increment_global_step': training_util._increment_global_step(1)})

    def _init_placeholders(self):
        """Create input placeholders for the SAC algorithm.

        Creates `tf.placeholder`s for:
            - observation
            - next observation
            - action
            - reward
            - terminals
        """
        self._iteration_ph = tf.placeholder(tf.int64,
                                            shape=None,
                                            name='iteration')

        self._observations_ph = tf.placeholder(
            tf.float32,
            shape=(None, *self._observation_shape),
            name='observation',
        )

        self._next_observations_ph = tf.placeholder(
            tf.float32,
            shape=(None, *self._observation_shape),
            name='next_observation',
        )

        self._actions_ph = tf.placeholder(
            tf.float32,
            shape=(None, *self._action_shape),
            name='actions',
        )

        self._rewards_ph = tf.placeholder(
            tf.float32,
            shape=(None, 1),
            name='rewards',
        )

        self._terminals_ph = tf.placeholder(
            tf.float32,
            shape=(None, 1),
            name='terminals',
        )

        if self._store_extra_policy_info:
            self._log_pis_ph = tf.placeholder(
                tf.float32,
                shape=(None, 1),
                name='log_pis',
            )
            self._raw_actions_ph = tf.placeholder(
                tf.float32,
                shape=(None, *self._action_shape),
                name='raw_actions',
            )

    def _get_Q_target(self):
        next_actions = self._policy.actions([self._next_observations_ph])
        next_log_pis = self._policy.log_pis([self._next_observations_ph],
                                            next_actions)
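        # Soft Bellman backup: take the minimum over the target critics,
        # subtract the entropy term alpha * log_pi, and form
        # reward + discount * (1 - done) * next_value.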

        next_Qs_values = tuple(
            Q([self._next_observations_ph, next_actions])
            for Q in self._Q_targets)

        min_next_Q = tf.reduce_min(next_Qs_values, axis=0)
        next_value = min_next_Q - self._alpha * next_log_pis

        Q_target = td_target(reward=self._reward_scale * self._rewards_ph,
                             discount=self._discount,
                             next_value=(1 - self._terminals_ph) * next_value)

        return Q_target

    def _init_critic_update(self):
        """Create minimization operation for critic Q-function.

        Creates a `tf.optimizer.minimize` operation for updating
        critic Q-function with gradient descent, and appends it to
        `self._training_ops` attribute.
        """
        Q_target = tf.stop_gradient(self._get_Q_target())

        assert Q_target.shape.as_list() == [None, 1]

        Q_values = self._Q_values = tuple(
            Q([self._observations_ph, self._actions_ph]) for Q in self._Qs)

        Q_losses = self._Q_losses = tuple(
            tf.losses.mean_squared_error(
                labels=Q_target, predictions=Q_value, weights=0.5)
            for Q_value in Q_values)

        self._Q_optimizers = tuple(
            tf.train.AdamOptimizer(learning_rate=self._Q_lr,
                                   name='{}_{}_optimizer'.format(Q._name, i))
            for i, Q in enumerate(self._Qs))
        Q_training_ops = tuple(
            tf.contrib.layers.optimize_loss(Q_loss,
                                            self.global_step,
                                            learning_rate=self._Q_lr,
                                            optimizer=Q_optimizer,
                                            variables=Q.trainable_variables,
                                            increment_global_step=False,
                                            summaries=((
                                                "loss", "gradients",
                                                "gradient_norm",
                                                "global_gradient_norm"
                                            ) if self._tf_summaries else ()))
            for i, (Q, Q_loss, Q_optimizer) in enumerate(
                zip(self._Qs, Q_losses, self._Q_optimizers)))

        self._training_ops.update({'Q': tf.group(Q_training_ops)})

    def _init_actor_update(self):
        """Create minimization operations for policy and entropy.

        Creates a `tf.optimizer.minimize` operations for updating
        policy and entropy with gradient descent, and adds them to
        `self._training_ops` attribute.
        """

        actions = self._policy.actions([self._observations_ph])
        log_pis = self._policy.log_pis([self._observations_ph], actions)

        assert log_pis.shape.as_list() == [None, 1]

        log_alpha = self._log_alpha = tf.get_variable('log_alpha',
                                                      dtype=tf.float32,
                                                      initializer=0.0)
        alpha = tf.exp(log_alpha)
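        # Automatic entropy tuning: when a numeric target entropy is given,
        # learn the temperature by minimizing
        # -log_alpha * stop_gradient(log_pi + target_entropy).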

        if isinstance(self._target_entropy, Number):
            alpha_loss = -tf.reduce_mean(
                log_alpha * tf.stop_gradient(log_pis + self._target_entropy))

            self._alpha_optimizer = tf.train.AdamOptimizer(
                self._policy_lr, name='alpha_optimizer')
            self._alpha_train_op = self._alpha_optimizer.minimize(
                loss=alpha_loss, var_list=[log_alpha])

            self._training_ops.update(
                {'temperature_alpha': self._alpha_train_op})

        self._alpha = alpha

        if self._action_prior == 'normal':
            policy_prior = tf.contrib.distributions.MultivariateNormalDiag(
                loc=tf.zeros(self._action_shape),
                scale_diag=tf.ones(self._action_shape))
            policy_prior_log_probs = policy_prior.log_prob(actions)
        elif self._action_prior == 'uniform':
            policy_prior_log_probs = 0.0

        Q_log_targets = tuple(
            Q([self._observations_ph, actions]) for Q in self._Qs)
        min_Q_log_target = tf.reduce_min(Q_log_targets, axis=0)

        if self._reparameterize:
            policy_kl_losses = (alpha * log_pis - min_Q_log_target -
                                policy_prior_log_probs)
        else:
            raise NotImplementedError

        assert policy_kl_losses.shape.as_list() == [None, 1]

        policy_loss = tf.reduce_mean(policy_kl_losses)

        self._policy_optimizer = tf.train.AdamOptimizer(
            learning_rate=self._policy_lr, name="policy_optimizer")
        policy_train_op = tf.contrib.layers.optimize_loss(
            policy_loss,
            self.global_step,
            learning_rate=self._policy_lr,
            optimizer=self._policy_optimizer,
            variables=self._policy.trainable_variables,
            increment_global_step=False,
            summaries=("loss", "gradients", "gradient_norm",
                       "global_gradient_norm") if self._tf_summaries else ())

        self._training_ops.update({'policy_train_op': policy_train_op})

    def _init_training(self):
        self._update_target(tau=1.0)

    def _update_target(self, tau=None):
        tau = tau or self._tau
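        # Polyak-average the online critic weights into the target critics.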

        for Q, Q_target in zip(self._Qs, self._Q_targets):
            source_params = Q.get_weights()
            target_params = Q_target.get_weights()
            Q_target.set_weights([
                tau * source + (1.0 - tau) * target
                for source, target in zip(source_params, target_params)
            ])

    def _do_training(self, iteration, batch):
        """Runs the operations for updating training and target ops."""

        self._training_progress.update()
        self._training_progress.set_description()

        feed_dict = self._get_feed_dict(iteration, batch)

        self._session.run(self._training_ops, feed_dict)

        if iteration % self._target_update_interval == 0:
            # Run target ops here.
            self._update_target()

    def _get_feed_dict(self, iteration, batch):
        """Construct TensorFlow feed_dict from sample batch."""

        feed_dict = {
            self._observations_ph: batch['observations'],
            self._actions_ph: batch['actions'],
            self._next_observations_ph: batch['next_observations'],
            self._rewards_ph: batch['rewards'],
            self._terminals_ph: batch['terminals'],
        }

        if self._store_extra_policy_info:
            feed_dict[self._log_pis_ph] = batch['log_pis']
            feed_dict[self._raw_actions_ph] = batch['raw_actions']

        if iteration is not None:
            feed_dict[self._iteration_ph] = iteration

        return feed_dict

    def get_diagnostics(self, iteration, batch, training_paths,
                        evaluation_paths):
        """Return diagnostic information as ordered dictionary.

        Records mean and standard deviation of Q-function and state
        value function, and TD-loss (mean squared Bellman error)
        for the sample batch.

        Also calls the `draw` method of the plotter, if plotter defined.
        """

        feed_dict = self._get_feed_dict(iteration, batch)

        (Q_values, Q_losses, alpha, global_step) = self._session.run(
            (self._Q_values, self._Q_losses, self._alpha, self.global_step),
            feed_dict)

        diagnostics = OrderedDict({
            'Q-avg': np.mean(Q_values),
            'Q-std': np.std(Q_values),
            'Q_loss': np.mean(Q_losses),
            'alpha': alpha,
        })

        policy_diagnostics = self._policy.get_diagnostics(
            batch['observations'])
        diagnostics.update({
            f'policy/{key}': value
            for key, value in policy_diagnostics.items()
        })

        if self._plotter:
            self._plotter.draw()

        return diagnostics

    @property
    def tf_saveables(self):
        saveables = {
            '_policy_optimizer': self._policy_optimizer,
            **{
                f'Q_optimizer_{i}': optimizer
                for i, optimizer in enumerate(self._Q_optimizers)
            },
            '_log_alpha': self._log_alpha,
        }

        if hasattr(self, '_alpha_optimizer'):
            saveables['_alpha_optimizer'] = self._alpha_optimizer

        return saveables

    def save_model(self, dir):
        self._model.save(savedir=dir, timestep=self._epoch + 1)
Example #3
class BMPO(RLAlgorithm):
    def __init__(
        self,
        training_environment,
        evaluation_environment,
        policy,
        Qs,
        pool,
        static_fns,
        log_file=None,
        plotter=None,
        tf_summaries=False,
        lr=3e-4,
        reward_scale=1.0,
        target_entropy='auto',
        discount=0.99,
        tau=5e-3,
        target_update_interval=1,
        action_prior='uniform',
        reparameterize=False,
        store_extra_policy_info=False,
        deterministic=False,
        model_train_freq=250,
        num_networks=7,
        num_elites=5,
        model_retain_epochs=20,
        rollout_batch_size=100e3,
        real_ratio=0.1,
        forward_rollout_schedule=[20, 100, 1, 1],
        backward_rollout_schedule=[20, 100, 1, 1],
        beta_schedule=[0, 100, 0, 0],
        last_n_epoch=10,
        planning_horizon=0,
        backward_policy_var=0,
        hidden_dim=200,
        max_model_t=None,
        **kwargs,
    ):
        """
        Args:
            training_environment (`SoftlearningEnv`): Environment used for training.
            evaluation_environment (`SoftlearningEnv`): Environment used for evaluation.
            policy: A policy function approximator.
            initial_exploration_policy (`Policy`): A policy used for initial
                exploration that is not trained by the algorithm.
            Qs: Q-function approximators. The min of these
                approximators will be used. Usage of at least two Q-functions
                improves performance by reducing overestimation bias.
            pool (`PoolBase`): Replay pool to add gathered samples to.
            plotter (`QFPolicyPlotter`): Plotter instance to be used for
                visualizing Q-function during training.
            lr (`float`): Learning rate used for the function approximators.
            discount (`float`): Discount factor for Q-function updates.
            tau (`float`): Soft value function target update weight.
            target_update_interval (`int`): Frequency, in iterations, at which
                target network updates occur.
            reparameterize (`bool`): If True, we use a gradient estimator for
                the policy derived using the reparameterization trick; otherwise
                we use a likelihood-ratio-based estimator.
        """

        super(BMPO, self).__init__(**kwargs)

        obs_dim = np.prod(training_environment.observation_space.shape)
        act_dim = np.prod(training_environment.action_space.shape)
        self._obs_dim = obs_dim
        self._act_dim = act_dim
        self._forward_model = construct_forward_model(
            obs_dim=obs_dim,
            act_dim=act_dim,
            hidden_dim=hidden_dim,
            num_networks=num_networks,
            num_elites=num_elites)
        self._backward_model = construct_backward_model(
            obs_dim=obs_dim,
            act_dim=act_dim,
            hidden_dim=hidden_dim,
            num_networks=num_networks,
            num_elites=num_elites)
        self._static_fns = static_fns
        self.f_fake_env = Forward_FakeEnv(self._forward_model,
                                          self._static_fns)
        self.b_fake_env = Backward_FakeEnv(self._backward_model,
                                           self._static_fns)

        self._forward_rollout_schedule = forward_rollout_schedule
        self._backward_rollout_schedule = backward_rollout_schedule
        self._beta_schedule = beta_schedule
        self._max_model_t = max_model_t

        self._model_retain_epochs = model_retain_epochs

        self._model_train_freq = model_train_freq
        self._rollout_batch_size = int(rollout_batch_size)
        self._deterministic = deterministic
        self._real_ratio = real_ratio

        self._log_dir = os.getcwd()

        self._training_environment = training_environment
        self._evaluation_environment = evaluation_environment
        self._policy = policy

        self._Qs = Qs
        self._Q_targets = tuple(tf.keras.models.clone_model(Q) for Q in Qs)

        self._pool = pool
        self._last_n_epoch = int(last_n_epoch)
        self._planning_horizon = int(planning_horizon)
        self._backward_policy_var = backward_policy_var

        self._plotter = plotter
        self._tf_summaries = tf_summaries

        self._policy_lr = lr
        self._Q_lr = lr

        self._reward_scale = reward_scale
        self._target_entropy = (
            -np.prod(self._training_environment.action_space.shape)
            if target_entropy == 'auto' else target_entropy)
        print('Target entropy: {}'.format(self._target_entropy))

        self._discount = discount
        self._tau = tau
        self._target_update_interval = target_update_interval
        self._action_prior = action_prior

        self._reparameterize = reparameterize
        self._store_extra_policy_info = store_extra_policy_info

        observation_shape = self._training_environment.active_observation_shape
        action_shape = self._training_environment.action_space.shape

        assert len(observation_shape) == 1, observation_shape
        self._observation_shape = observation_shape
        assert len(action_shape) == 1, action_shape
        self._action_shape = action_shape
        self.log_file = log_file

        self._build()

    def _build(self):
        self._training_ops = {}

        self._init_global_step()
        self._init_placeholders()
        self._init_actor_update()
        self._init_critic_update()
        self._build_backward_policy(self._act_dim)

    def _build_backward_policy(self, act_dim):
        self._max_logvar = tf.Variable(np.ones([1, act_dim]),
                                       dtype=tf.float32,
                                       name="max_log_var")
        self._min_logvar = tf.Variable(-np.ones([1, act_dim]) * 10.,
                                       dtype=tf.float32,
                                       name="min_log_var")
        self._before_action_mean, self._before_action_logvar = self._backward_policy_net(
            'backward_policy', self._next_observations_ph, act_dim)
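        # Softly clamp the predicted log-variance to [min_logvar, max_logvar]
        # via the softplus bounds, a common trick in probabilistic dynamics models.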
        action_logvar = self._max_logvar - tf.nn.softplus(
            self._max_logvar - self._before_action_logvar)
        action_logvar = self._min_logvar + tf.nn.softplus(action_logvar -
                                                          self._min_logvar)
        self._before_action_var = tf.exp(action_logvar)
        self._backward_policy_params = tf.get_collection(
            tf.GraphKeys.GLOBAL_VARIABLES, scope='backward_policy')
        loss1 = tf.reduce_mean(
            tf.square(self._before_action_mean - self._actions_ph) /
            self._before_action_var)
        loss2 = tf.reduce_mean(tf.log(self._before_action_var))
        self._backward_policy_loss = loss1 + loss2
        self._backward_policy_optimizer = tf.train.AdamOptimizer(
            self._policy_lr).minimize(loss=self._backward_policy_loss,
                                      var_list=self._backward_policy_params)

    def _backward_policy_net(self, scope, state, action_dim, hidden_dim=256):
        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            hidden_layer1 = tf.layers.dense(state, hidden_dim, tf.nn.relu)
            hidden_layer2 = tf.layers.dense(hidden_layer1, hidden_dim,
                                            tf.nn.relu)
            return tf.tanh(tf.layers.dense(hidden_layer2, action_dim)), \
                   tf.layers.dense(hidden_layer2, action_dim)

    def _get_before_action(self, obs):
        before_action_mean, before_action_var = self._session.run(
            [self._before_action_mean, self._before_action_var],
            feed_dict={self._next_observations_ph: obs})
        if (self._backward_policy_var != 0):
            before_action_var = self._backward_policy_var
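        # Sample backward actions from a truncated normal (+/- 2 std) around the
        # predicted mean with the predicted (or fixed) variance, clipped to [-1, 1].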
        X = stats.truncnorm(-2,
                            2,
                            loc=np.zeros_like(before_action_mean),
                            scale=np.ones_like(before_action_mean))
        before_actions = X.rvs(size=np.shape(before_action_mean)) * np.sqrt(
            before_action_var) + before_action_mean
        act = np.clip(before_actions, -1, 1)
        return act

    def _train(self):

        training_environment = self._training_environment
        evaluation_environment = self._evaluation_environment
        policy = self._policy
        pool = self._pool
        f_model_metrics, b_model_metrics = {}, {}

        if not self._training_started:
            self._init_training()

            self._initial_exploration_hook(training_environment,
                                           self._initial_exploration_policy,
                                           pool)

        self.sampler.initialize(training_environment, policy, pool)

        self._training_before_hook()

        for self._epoch in range(self._epoch, self._n_epochs):

            self._epoch_before_hook()

            start_samples = self.sampler._total_samples
            print("\033[0;31m%s%d\033[0m" % ('epoch: ', self._epoch))
            print('[ True Env Buffer Size ]', pool.size)
            for i in count():
                samples_now = self.sampler._total_samples
                self._timestep = samples_now - start_samples

                if samples_now >= start_samples + self._epoch_length and self.ready_to_train:
                    break

                self._timestep_before_hook()

                if self._timestep % self._model_train_freq == 0:
                    f_model_train_metrics, b_model_train_metrics = self._train_model(
                        batch_size=256,
                        max_epochs=None,
                        holdout_ratio=0.2,
                        max_t=self._max_model_t)
                    f_model_metrics.update(f_model_train_metrics)
                    b_model_metrics.update(b_model_train_metrics)
                    self._set_beta()
                    self._set_rollout_length()
                    self._reallocate_model_pool()
                    self._rollout_model(
                        rollout_batch_size=self._rollout_batch_size,
                        deterministic=self._deterministic)

                self._do_sampling(timestep=self._total_timestep)

                if self.ready_to_train:
                    self._do_training_repeats(timestep=self._total_timestep)

                self._timestep_after_hook()

            training_paths = self.sampler.get_last_n_paths(
                math.ceil(self._epoch_length / self.sampler._max_path_length))
            evaluation_paths = self._evaluation_paths(policy,
                                                      evaluation_environment)

            training_metrics = self._evaluate_rollouts(training_paths,
                                                       training_environment)
            if evaluation_paths:
                evaluation_metrics = self._evaluate_rollouts(
                    evaluation_paths, evaluation_environment)
            else:
                evaluation_metrics = {}

            self._epoch_after_hook(training_paths)

            sampler_diagnostics = self.sampler.get_diagnostics()

            diagnostics = self.get_diagnostics(
                iteration=self._total_timestep,
                batch=self._evaluation_batch(),
                training_paths=training_paths,
                evaluation_paths=evaluation_paths)

            diagnostics.update(
                OrderedDict((
                    *((f'evaluation/{key}', evaluation_metrics[key])
                      for key in sorted(evaluation_metrics.keys())),
                    *((f'training/{key}', training_metrics[key])
                      for key in sorted(training_metrics.keys())),
                    *((f'sampler/{key}', sampler_diagnostics[key])
                      for key in sorted(sampler_diagnostics.keys())),
                    *((f'forward-model/{key}', f_model_metrics[key])
                      for key in sorted(f_model_metrics.keys())),
                    *((f'backward-model/{key}', b_model_metrics[key])
                      for key in sorted(b_model_metrics.keys())),
                    ('epoch', self._epoch),
                    ('timestep', self._timestep),
                    ('timesteps_total', self._total_timestep),
                    ('train-steps', self._num_train_steps),
                )))

            if self._eval_render_mode is not None and hasattr(
                    evaluation_environment, 'render_rollouts'):
                training_environment.render_rollouts(evaluation_paths)
            print(diagnostics)
            f_log = open(self.log_file, 'a')
            f_log.write('epoch: %d\n' % self._epoch)
            f_log.write('total time steps: %d\n' % self._total_timestep)
            f_log.write('evaluation return: %f\n' %
                        evaluation_metrics['return-average'])
            f_log.close()

        self.sampler.terminate()

        self._training_after_hook()

    def train(self, *args, **kwargs):
        return self._train(*args, **kwargs)

    def _set_beta(self):
        min_epoch, max_epoch, min_beta, max_beta = self._beta_schedule
        if self._epoch <= min_epoch:
            y = min_beta
        else:
            dx = (self._epoch - min_epoch) / (max_epoch - min_epoch)
            dx = min(dx, 1)
            y = dx * (max_beta - min_beta) + min_beta
        self._beta = y

    def _set_rollout_length(self):
        # set backward rollout length
        min_epoch, max_epoch, min_length, max_length = self._backward_rollout_schedule
        if self._epoch <= min_epoch:
            y = min_length
        else:
            dx = (self._epoch - min_epoch) / (max_epoch - min_epoch)
            dx = min(dx, 1)
            y = dx * (max_length - min_length) + min_length

        self._backward_rollout_length = int(y)
        print(
            '[Backward Model Length ] Epoch: {} (min: {}, max: {}) | Length: {} (min: {} , max: {})'
            .format(self._epoch, min_epoch, max_epoch,
                    self._backward_rollout_length, min_length, max_length))
        # set forward rollout length
        min_epoch, max_epoch, min_length, max_length = self._forward_rollout_schedule
        if self._epoch <= min_epoch:
            y = min_length
        else:
            dx = (self._epoch - min_epoch) / (max_epoch - min_epoch)
            dx = min(dx, 1)
            y = dx * (max_length - min_length) + min_length

        self._forward_rollout_length = int(y)
        print(
            '[Forward Model Length ] Epoch: {} (min: {}, max: {}) | Length: {} (min: {} , max: {})'
            .format(self._epoch, min_epoch, max_epoch,
                    self._forward_rollout_length, min_length, max_length))

    def _get_start_obs(self, rollout_batch_size):
        batch = self.sampler.random_batch(rollout_batch_size)
        temp_obs = batch['observations']
        beta = self._beta
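        # beta == 0: use the sampled states directly; otherwise resample start
        # states with probability proportional to exp(beta * value), a Boltzmann
        # distribution over the summed online and target value estimates.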
        if (beta == 0):  # randomly sample
            start_obs = temp_obs
        else:  # sample from a Boltzmann distribution
            values = np.squeeze(
                self._session.run(self._value,
                                  feed_dict={self._observations_ph: temp_obs}))
            values += np.squeeze(
                self._session.run(self._target_value,
                                  feed_dict={self._observations_ph: temp_obs}))
            beta = min(beta,
                       10 / (np.max(values) - np.min(values)))  # prevent NAN
            values = np.exp(values * beta)
            prob = values / np.sum(values)
            index = np.array(range(0, rollout_batch_size))
            start_obs = temp_obs[np.random.choice(a=index,
                                                  size=rollout_batch_size,
                                                  replace=True,
                                                  p=prob)]
        return start_obs

    def _reallocate_model_pool(self):
        obs_space = self._pool._observation_space
        act_space = self._pool._action_space

        rollouts_per_epoch = self._rollout_batch_size * self._epoch_length / self._model_train_freq
        model_steps_per_epoch = int(
            (self._forward_rollout_length + self._backward_rollout_length) *
            rollouts_per_epoch)
        new_pool_size = self._model_retain_epochs * model_steps_per_epoch

        if not hasattr(self, '_model_pool'):
            print(
                '[ Allocate Model Pool ] Initializing new model pool with size {:.2e}'
                .format(new_pool_size))
            self._model_pool = SimpleReplayPool(obs_space, act_space,
                                                new_pool_size)

        elif self._model_pool._max_size != new_pool_size:
            print(
                '[ Reallocate Model Pool ] Updating model pool | {:.2e} --> {:.2e}'
                .format(self._model_pool._max_size, new_pool_size))
            samples = self._model_pool.return_all_samples()
            new_pool = SimpleReplayPool(obs_space, act_space, new_pool_size)
            new_pool.add_samples(samples)
            assert self._model_pool.size == new_pool.size
            self._model_pool = new_pool

    def _train_model(self, **kwargs):
        env_samples = self._pool.return_all_samples()
        print('Training forward model:')
        train_inputs, train_outputs = format_samples_for_forward_training(
            env_samples)
        f_model_metrics = self._forward_model.train(train_inputs,
                                                    train_outputs, **kwargs)
        print('Training backward model:')
        train_inputs, train_outputs = format_samples_for_backward_training(
            env_samples)
        b_model_metrics = self._backward_model.train(train_inputs,
                                                     train_outputs, **kwargs)
        return f_model_metrics, b_model_metrics

    def _rollout_model(self, rollout_batch_size, **kwargs):
        print('[ Model Rollout ] Starting | Epoch: {} | Batch size: {}'.format(
            self._epoch, rollout_batch_size))
        start_obs = self._get_start_obs(rollout_batch_size)
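        # Roll out backward from the start states with the backward model and
        # backward policy, then forward with the current policy; both directions
        # add transitions to the same model pool.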

        # perform backward rollout
        obs = start_obs
        for i in range(self._backward_rollout_length):
            act = self._get_before_action(obs)

            before_obs, rew, term, info = self.b_fake_env.step(
                obs, act, **kwargs)

            samples = {
                'observations': before_obs,
                'actions': act,
                'next_observations': obs,
                'rewards': rew,
                'terminals': term
            }
            self._model_pool.add_samples(samples)

            nonterm_mask = ~term.squeeze(-1)
            if nonterm_mask.sum() == 0:
                print('[ Model Rollout ] Breaking early: {} | {} / {}'.format(
                    i, nonterm_mask.sum(), nonterm_mask.shape))
                break
            obs = before_obs[nonterm_mask]

        # perform forward rollout
        obs = start_obs
        for i in range(self._forward_rollout_length):
            act = self._policy.actions_np(obs)

            next_obs, rew, term, info = self.f_fake_env.step(
                obs, act, **kwargs)

            samples = {
                'observations': obs,
                'actions': act,
                'next_observations': next_obs,
                'rewards': rew,
                'terminals': term
            }
            self._model_pool.add_samples(samples)

            nonterm_mask = ~term.squeeze(-1)
            if nonterm_mask.sum() == 0:
                print('[ Model Rollout ] Breaking early: {} | {} / {}'.format(
                    i, nonterm_mask.sum(), nonterm_mask.shape))
                break

            obs = next_obs[nonterm_mask]

        print(
            '[ Model Rollout ] Added: {:.1e} | Model pool: {:.1e} (max {:.1e}) | Train rep: {}'
            .format(
                (self._forward_rollout_length + self._backward_rollout_length)
                * rollout_batch_size, self._model_pool.size,
                self._model_pool._max_size, self._n_train_repeat))

    def _training_batch(self, batch_size=None):
        batch_size = batch_size or self.sampler._batch_size
        env_batch_size = int(batch_size * self._real_ratio)
        model_batch_size = batch_size - env_batch_size

        ## can sample from the env pool even if env_batch_size == 0
        env_batch = self._pool.random_batch(env_batch_size)

        if model_batch_size > 0:
            model_batch = self._model_pool.random_batch(model_batch_size)

            keys = env_batch.keys()
            batch = {
                k: np.concatenate((env_batch[k], model_batch[k]), axis=0)
                for k in keys
            }
        else:
            ## if real_ratio == 1.0, no model pool was ever allocated,
            ## so skip the model pool sampling
            batch = env_batch
        return batch

    def _init_global_step(self):
        self.global_step = training_util.get_or_create_global_step()
        self._training_ops.update(
            {'increment_global_step': training_util._increment_global_step(1)})

    def _init_placeholders(self):
        """Create input placeholders for the SAC algorithm.

        Creates `tf.placeholder`s for:
            - observation
            - next observation
            - action
            - reward
            - terminals
        """
        self._iteration_ph = tf.placeholder(tf.int64,
                                            shape=None,
                                            name='iteration')

        self._observations_ph = tf.placeholder(
            tf.float32,
            shape=(None, *self._observation_shape),
            name='observation',
        )

        self._next_observations_ph = tf.placeholder(
            tf.float32,
            shape=(None, *self._observation_shape),
            name='next_observation',
        )

        self._actions_ph = tf.placeholder(
            tf.float32,
            shape=(None, *self._action_shape),
            name='actions',
        )

        self._rewards_ph = tf.placeholder(
            tf.float32,
            shape=(None, 1),
            name='rewards',
        )

        self._terminals_ph = tf.placeholder(
            tf.float32,
            shape=(None, 1),
            name='terminals',
        )

        if self._store_extra_policy_info:
            self._log_pis_ph = tf.placeholder(
                tf.float32,
                shape=(None, 1),
                name='log_pis',
            )
            self._raw_actions_ph = tf.placeholder(
                tf.float32,
                shape=(None, *self._action_shape),
                name='raw_actions',
            )

    def _get_Q_target(self):
        next_actions = self._policy.actions([self._next_observations_ph])
        next_log_pis = self._policy.log_pis([self._next_observations_ph],
                                            next_actions)

        next_Qs_values = tuple(
            Q([self._next_observations_ph, next_actions])
            for Q in self._Q_targets)

        min_next_Q = tf.reduce_min(next_Qs_values, axis=0)
        next_value = min_next_Q - self._alpha * next_log_pis

        Q_target = td_target(reward=self._reward_scale * self._rewards_ph,
                             discount=self._discount,
                             next_value=(1 - self._terminals_ph) * next_value)

        return Q_target

    def _init_critic_update(self):
        """Create minimization operation for critic Q-function.

        Creates a `tf.optimizer.minimize` operation for updating
        critic Q-function with gradient descent, and appends it to
        `self._training_ops` attribute.
        """
        Q_target = tf.stop_gradient(self._get_Q_target())

        assert Q_target.shape.as_list() == [None, 1]

        Q_values = self._Q_values = tuple(
            Q([self._observations_ph, self._actions_ph]) for Q in self._Qs)

        Q_losses = self._Q_losses = tuple(
            tf.losses.mean_squared_error(
                labels=Q_target, predictions=Q_value, weights=0.5)
            for Q_value in Q_values)

        self._Q_optimizers = tuple(
            tf.train.AdamOptimizer(learning_rate=self._Q_lr,
                                   name='{}_{}_optimizer'.format(Q._name, i))
            for i, Q in enumerate(self._Qs))
        Q_training_ops = tuple(
            tf.contrib.layers.optimize_loss(Q_loss,
                                            self.global_step,
                                            learning_rate=self._Q_lr,
                                            optimizer=Q_optimizer,
                                            variables=Q.trainable_variables,
                                            increment_global_step=False,
                                            summaries=((
                                                "loss", "gradients",
                                                "gradient_norm",
                                                "global_gradient_norm"
                                            ) if self._tf_summaries else ()))
            for i, (Q, Q_loss, Q_optimizer) in enumerate(
                zip(self._Qs, Q_losses, self._Q_optimizers)))

        self._training_ops.update({'Q': tf.group(Q_training_ops)})

    def _init_actor_update(self):
        """Create minimization operations for policy and entropy.

        Creates a `tf.optimizer.minimize` operations for updating
        policy and entropy with gradient descent, and adds them to
        `self._training_ops` attribute.
        """

        actions = self._policy.actions([self._observations_ph])
        log_pis = self._policy.log_pis([self._observations_ph], actions)
        self._actions = actions

        assert log_pis.shape.as_list() == [None, 1]

        log_alpha = self._log_alpha = tf.get_variable('log_alpha',
                                                      dtype=tf.float32,
                                                      initializer=0.0)
        alpha = tf.exp(log_alpha)

        if isinstance(self._target_entropy, Number):
            alpha_loss = -tf.reduce_mean(
                log_alpha * tf.stop_gradient(log_pis + self._target_entropy))

            self._alpha_optimizer = tf.train.AdamOptimizer(
                self._policy_lr, name='alpha_optimizer')
            self._alpha_train_op = self._alpha_optimizer.minimize(
                loss=alpha_loss, var_list=[log_alpha])

            self._training_ops.update(
                {'temperature_alpha': self._alpha_train_op})

        self._alpha = alpha

        if self._action_prior == 'normal':
            policy_prior = tf.contrib.distributions.MultivariateNormalDiag(
                loc=tf.zeros(self._action_shape),
                scale_diag=tf.ones(self._action_shape))
            policy_prior_log_probs = policy_prior.log_prob(actions)
        elif self._action_prior == 'uniform':
            policy_prior_log_probs = 0.0

        Q_log_targets = tuple(
            Q([self._observations_ph, actions]) for Q in self._Qs)
        min_Q_log_target = tf.reduce_min(Q_log_targets, axis=0)

        self._value = tf.reduce_mean(Q_log_targets, axis=0)
        self._target_value = tf.reduce_mean(tuple(
            Q([self._observations_ph, actions]) for Q in self._Q_targets),
                                            axis=0)

        if self._reparameterize:
            policy_kl_losses = (alpha * log_pis - min_Q_log_target -
                                policy_prior_log_probs)
        else:
            raise NotImplementedError

        assert policy_kl_losses.shape.as_list() == [None, 1]

        policy_loss = tf.reduce_mean(policy_kl_losses)

        self._policy_optimizer = tf.train.AdamOptimizer(
            learning_rate=self._policy_lr, name="policy_optimizer")
        policy_train_op = tf.contrib.layers.optimize_loss(
            policy_loss,
            self.global_step,
            learning_rate=self._policy_lr,
            optimizer=self._policy_optimizer,
            variables=self._policy.trainable_variables,
            increment_global_step=False,
            summaries=("loss", "gradients", "gradient_norm",
                       "global_gradient_norm") if self._tf_summaries else ())

        self._training_ops.update({'policy_train_op': policy_train_op})
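
When `target_entropy` is a number, the temperature `alpha` is tuned automatically through the dual objective above so that the policy entropy tracks the target. A small NumPy sketch of one gradient step on `log_alpha` (toy log-probabilities, not taken from a real run):

import numpy as np

target_entropy = -6.0                    # e.g. -|A| for a 6-dim action space
log_pis = np.array([-4.0, -5.5, -7.0])   # sampled log pi(a|s), illustrative
log_alpha, lr = 0.0, 3e-4

# gradient of  -mean(log_alpha * (log_pi + target_entropy))  w.r.t. log_alpha
grad = -np.mean(log_pis + target_entropy)
log_alpha -= lr * grad
alpha = np.exp(log_alpha)
# alpha grows when the policy entropy drops below the target, and shrinks otherwise
print(alpha)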

    def _init_training(self):
        self._update_target()

    def _update_target(self):
        tau = self._tau

        for Q, Q_target in zip(self._Qs, self._Q_targets):
            source_params = Q.get_weights()
            target_params = Q_target.get_weights()
            Q_target.set_weights([
                tau * source + (1.0 - tau) * target
                for source, target in zip(source_params, target_params)
            ])
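
`_update_target` is a Polyak (exponential moving-average) update: each target weight moves a fraction `tau` toward the corresponding online weight. Equivalent NumPy sketch for a single weight tensor:

import numpy as np

tau = 5e-3
source = np.array([1.0, 2.0, 3.0])   # online Q-function weights
target = np.array([0.0, 0.0, 0.0])   # target Q-function weights

target = tau * source + (1.0 - tau) * target
print(target)  # [0.005, 0.01, 0.015]; repeated calls slowly track the online weights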

    def _do_training(self, iteration, batch):
        """Runs the operations for updating training and target ops."""

        # self._training_progress.update()
        # self._training_progress.set_description()

        feed_dict = self._get_feed_dict(iteration, batch)
        self._session.run(self._training_ops, feed_dict)

        if iteration % self._target_update_interval == 0:
            # Run target ops here.
            self._update_target()

    def _do_training_repeats(self, timestep, backward_policy_train_repeat=1):
        """Repeat training _n_train_repeat times every _train_every_n_steps"""
        if timestep % self._train_every_n_steps > 0: return
        trained_enough = (self._train_steps_this_epoch >
                          self._max_train_repeat_per_timestep * self._timestep)
        if trained_enough: return

        for i in range(self._n_train_repeat):
            self._do_training(iteration=timestep, batch=self._training_batch())

        for i in range(backward_policy_train_repeat):
            batch = self._pool.last_n_random_batch(last_n=self._epoch_length *
                                                   self._last_n_epoch,
                                                   batch_size=256)
            next_observations = np.array(batch['next_observations'])
            actions = np.array(batch['actions'])
            feed_dict = {
                self._actions_ph: actions,
                self._next_observations_ph: next_observations,
            }
            self._session.run(self._backward_policy_optimizer, feed_dict)

        self._num_train_steps += self._n_train_repeat
        self._train_steps_this_epoch += self._n_train_repeat

    def _get_feed_dict(self, iteration, batch):
        """Construct TensorFlow feed_dict from sample batch."""

        feed_dict = {
            self._observations_ph: batch['observations'],
            self._actions_ph: batch['actions'],
            self._next_observations_ph: batch['next_observations'],
            self._rewards_ph: batch['rewards'],
            self._terminals_ph: batch['terminals'],
        }

        if self._store_extra_policy_info:
            feed_dict[self._log_pis_ph] = batch['log_pis']
            feed_dict[self._raw_actions_ph] = batch['raw_actions']

        if iteration is not None:
            feed_dict[self._iteration_ph] = iteration

        return feed_dict

    def get_diagnostics(self, iteration, batch, training_paths,
                        evaluation_paths):
        """Return diagnostic information as ordered dictionary.

        Records mean and standard deviation of Q-function and state
        value function, and TD-loss (mean squared Bellman error)
        for the sample batch.

        Also calls the `draw` method of the plotter, if plotter defined.
        """

        feed_dict = self._get_feed_dict(iteration, batch)

        (Q_values, Q_losses, alpha, global_step) = self._session.run(
            (self._Q_values, self._Q_losses, self._alpha, self.global_step),
            feed_dict)

        diagnostics = OrderedDict({
            'Q-avg': np.mean(Q_values),
            'Q-std': np.std(Q_values),
            'Q_loss': np.mean(Q_losses),
            'alpha': alpha,
        })

        policy_diagnostics = self._policy.get_diagnostics(
            batch['observations'])
        diagnostics.update({
            f'policy/{key}': value
            for key, value in policy_diagnostics.items()
        })

        if self._plotter:
            self._plotter.draw()

        return diagnostics

    @property
    def tf_saveables(self):
        saveables = {
            '_policy_optimizer': self._policy_optimizer,
            **{
                f'Q_optimizer_{i}': optimizer
                for i, optimizer in enumerate(self._Q_optimizers)
            },
            '_log_alpha': self._log_alpha,
        }

        if hasattr(self, '_alpha_optimizer'):
            saveables['_alpha_optimizer'] = self._alpha_optimizer

        return saveables
Beispiel #4
0
class MOPO(RLAlgorithm):
    """Model-based Offline Policy Optimization (MOPO)

    References
    ----------
        Tianhe Yu, Garrett Thomas, Lantao Yu, Stefano Ermon, James Zou, Sergey Levine, Chelsea Finn, Tengyu Ma. 
        MOPO: Model-based Offline Policy Optimization. 
        arXiv preprint arXiv:2005.13239. 2020.
    """
    def __init__(
            self,
            training_environment,
            evaluation_environment,
            policy,
            Qs,
            pool,
            static_fns,
            plotter=None,
            tf_summaries=False,
            lr=3e-4,
            reward_scale=1.0,
            target_entropy='auto',
            discount=0.99,
            tau=5e-3,
            target_update_interval=1,
            action_prior='uniform',
            reparameterize=False,
            store_extra_policy_info=False,
            adapt=False,
            gru_state_dim=256,
            network_kwargs=None,
            deterministic=False,
            rollout_random=False,
            model_train_freq=250,
            num_networks=7,
            num_elites=5,
            model_retain_epochs=20,
            rollout_batch_size=100e3,
            real_ratio=0.1,
            # rollout_schedule=[20,100,1,1],
            rollout_length=1,
            hidden_dim=200,
            max_model_t=None,
            model_type='mlp',
            separate_mean_var=False,
            identity_terminal=0,
            pool_load_path='',
            pool_load_max_size=0,
            model_name=None,
            model_load_dir=None,
            penalty_coeff=0.,
            penalty_learned_var=False,
            **kwargs):
        """
        Args:
            env (`SoftlearningEnv`): Environment used for training.
            policy: A policy function approximator.
            initial_exploration_policy: ('Policy'): A policy that we use
                for initial exploration which is not trained by the algorithm.
            Qs: Q-function approximators. The min of these
                approximators will be used. Usage of at least two Q-functions
                improves performance by reducing overestimation bias.
            pool (`PoolBase`): Replay pool to add gathered samples to.
            plotter (`QFPolicyPlotter`): Plotter instance to be used for
                visualizing Q-function during training.
            lr (`float`): Learning rate used for the function approximators.
            discount (`float`): Discount factor for Q-function updates.
            tau (`float`): Soft value function target update weight.
            target_update_interval ('int'): Frequency at which target network
                updates occur in iterations.
            reparameterize ('bool'): If True, we use a gradient estimator for
                the policy derived using the reparameterization trick. We use
                a likelihood ratio based estimator otherwise.
        """

        super(MOPO, self).__init__(**kwargs)
        print("[ DEBUG ]: model name: {}".format(model_name))
        if '_smv' in model_name:
            self._env_name = model_name[:-8] + '-v0'
        else:
            self._env_name = model_name[:-4] + '-v0'
        if self._env_name in infos.REF_MIN_SCORE:
            self.min_ret = infos.REF_MIN_SCORE[self._env_name]
            self.max_ret = infos.REF_MAX_SCORE[self._env_name]
        else:
            self.min_ret = self.max_ret = 0
        obs_dim = np.prod(training_environment.active_observation_shape)
        act_dim = np.prod(training_environment.action_space.shape)
        self._model_type = model_type
        self._identity_terminal = identity_terminal
        self._model = construct_model(obs_dim=obs_dim,
                                      act_dim=act_dim,
                                      hidden_dim=hidden_dim,
                                      num_networks=num_networks,
                                      num_elites=num_elites,
                                      model_type=model_type,
                                      separate_mean_var=separate_mean_var,
                                      name=model_name,
                                      load_dir=model_load_dir,
                                      deterministic=deterministic)
        print('[ MOPO ]: got self._model')
        self._static_fns = static_fns
        self.fake_env = FakeEnv(self._model,
                                self._static_fns,
                                penalty_coeff=penalty_coeff,
                                penalty_learned_var=penalty_learned_var)

        self._rollout_schedule = [20, 100, rollout_length, rollout_length]
        self._max_model_t = max_model_t

        self._model_retain_epochs = model_retain_epochs

        self._model_train_freq = model_train_freq
        self._rollout_batch_size = int(rollout_batch_size)
        self._deterministic = deterministic
        self._rollout_random = rollout_random
        self._real_ratio = real_ratio
        # TODO: RLA writer (implemented with tf) should be compatible with the Writer object (implemented with tbx)
        self._log_dir = tester.log_dir
        # self._writer = tester.writer
        self._writer = Writer(self._log_dir)

        self._training_environment = training_environment
        self._evaluation_environment = evaluation_environment
        self.gru_state_dim = gru_state_dim
        self.network_kwargs = network_kwargs
        self.adapt = adapt
        self.optim_alpha = False
        # self._policy = policy

        # self._Qs = Qs
        # self._Q_targets = tuple(tf.keras.models.clone_model(Q) for Q in Qs)

        self._pool = pool
        self._plotter = plotter
        self._tf_summaries = tf_summaries

        self._policy_lr = lr
        self._Q_lr = lr

        self._reward_scale = reward_scale
        self._target_entropy = (
            -np.prod(self._training_environment.action_space.shape)
            if target_entropy == 'auto' else target_entropy)
        print('[ MOPO ] Target entropy: {}'.format(self._target_entropy))

        self._discount = discount
        self._tau = tau
        self._target_update_interval = target_update_interval
        self._action_prior = action_prior

        self._reparameterize = reparameterize
        self._store_extra_policy_info = store_extra_policy_info

        observation_shape = self._training_environment.active_observation_shape
        action_shape = self._training_environment.action_space.shape

        assert len(observation_shape) == 1, observation_shape
        self._observation_shape = observation_shape
        assert len(action_shape) == 1, action_shape
        self._action_shape = action_shape

        self._build()

        #### load replay pool data
        self._pool_load_path = pool_load_path
        self._pool_load_max_size = pool_load_max_size

        loader.restore_pool(self._pool,
                            self._pool_load_path,
                            self._pool_load_max_size,
                            save_path=self._log_dir)
        self._init_pool_size = self._pool.size
        print('[ MOPO ] Starting with pool size: {}'.format(
            self._init_pool_size))
        ####
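
The `FakeEnv` constructed above is where MOPO's pessimism enters: rollout rewards produced by the learned dynamics model are reduced by `penalty_coeff` times an ensemble-based uncertainty estimate before being added to the model pool. The exact penalty lives inside the `FakeEnv` implementation; below is only a hedged sketch of the idea, assuming the maximum predictive standard deviation across the ensemble is used as the uncertainty measure:

import numpy as np

def penalized_reward(ens_mean_rew, ens_std, penalty_coeff):
    """ens_mean_rew: (batch,) mean predicted reward over the elite models,
    ens_std: (num_models, batch) per-model predictive std norms (assumed shapes)."""
    penalty = np.max(ens_std, axis=0)          # pessimistic uncertainty estimate
    return ens_mean_rew - penalty_coeff * penalty

rew = penalized_reward(np.array([1.0, 0.2]),
                       np.array([[0.1, 0.5],
                                 [0.3, 0.4]]),
                       penalty_coeff=1.0)
print(rew)  # [0.7, -0.3]: high ensemble disagreement drags the reward down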

    def _build(self):

        self._training_ops = {}
        # global step and placeholders
        self.global_step = training_util.get_or_create_global_step()
        self._training_ops.update(
            {'increment_global_step': training_util._increment_global_step(1)})

        self._iteration_ph = tf.placeholder(tf.int64,
                                            shape=None,
                                            name='iteration')

        self._observations_ph = tf.placeholder(
            tf.float32,
            shape=(None, None, *self._observation_shape),
            name='observation',
        )

        self._next_observations_ph = tf.placeholder(
            tf.float32,
            shape=(None, None, *self._observation_shape),
            name='next_observation',
        )

        self._actions_ph = tf.placeholder(
            tf.float32,
            shape=(None, None, *self._action_shape),
            name='actions',
        )

        self._prev_state_p_ph = tf.placeholder(
            tf.float32,
            shape=(None, self.gru_state_dim),
            name='prev_state_p',
        )
        self._prev_state_v_ph = tf.placeholder(
            tf.float32,
            shape=(None, self.gru_state_dim),
            name='prev_state_v',
        )

        self.seq_len = tf.placeholder(tf.float32, shape=[None], name="seq_len")

        self._rewards_ph = tf.placeholder(
            tf.float32,
            shape=(None, None, 1),
            name='rewards',
        )

        self._terminals_ph = tf.placeholder(
            tf.float32,
            shape=(None, None, 1),
            name='terminals',
        )

        if self._store_extra_policy_info:
            self._log_pis_ph = tf.placeholder(
                tf.float32,
                shape=(None, None, 1),
                name='log_pis',
            )
            self._raw_actions_ph = tf.placeholder(
                tf.float32,
                shape=(None, None, *self._action_shape),
                name='raw_actions',
            )

        # inner functions
        LOG_STD_MAX = 2
        LOG_STD_MIN = -20
        EPS = 1e-8

        def mlp(x,
                hidden_sizes=(32, ),
                activation=tf.tanh,
                output_activation=None,
                kernel_initializer=None):
            print('[ DEBUG ], hidden layer size: ', hidden_sizes)
            for h in hidden_sizes[:-1]:
                x = tf.layers.dense(x,
                                    units=h,
                                    activation=activation,
                                    kernel_initializer=kernel_initializer)
            return tf.layers.dense(x,
                                   units=hidden_sizes[-1],
                                   activation=output_activation,
                                   kernel_initializer=kernel_initializer)

        def gaussian_likelihood(x, mu, log_std):
            pre_sum = -0.5 * (
                ((x - mu) /
                 (tf.exp(log_std) + EPS))**2 + 2 * log_std + np.log(2 * np.pi))
            return tf.reduce_sum(pre_sum, axis=-1)

        def apply_squashing_func(mu, pi, logp_pi):
            # Adjustment to log prob
            # NOTE: This formula is a little bit magic. To get an understanding of where it
            # comes from, check out the original SAC paper (arXiv 1801.01290) and look in
            # appendix C. This is a more numerically-stable equivalent to Eq 21.
            # Try deriving it yourself as a (very difficult) exercise. :)
            logp_pi -= tf.reduce_sum(
                2 * (np.log(2) - pi - tf.nn.softplus(-2 * pi)), axis=-1)
            # Squash those unbounded actions!
            mu = tf.tanh(mu)
            pi = tf.tanh(pi)
            return mu, pi, logp_pi

        def mlp_gaussian_policy(x, a, hidden_sizes, activation,
                                output_activation):
            print('[ DEBUG ]: output activation: ', output_activation,
                  ', activation: ', activation)
            act_dim = a.shape.as_list()[-1]
            net = mlp(x, list(hidden_sizes), activation, activation)
            mu = tf.layers.dense(net, act_dim, activation=output_activation)
            log_std = tf.layers.dense(net, act_dim, activation=None)
            log_std = tf.clip_by_value(log_std, LOG_STD_MIN, LOG_STD_MAX)
            std = tf.exp(log_std)
            pi = mu + tf.random_normal(tf.shape(mu)) * std
            logp_pi = gaussian_likelihood(pi, mu, log_std)
            return mu, pi, logp_pi, std

        def mlp_actor_critic(x,
                             x_v,
                             a,
                             hidden_sizes=(256, 256),
                             activation=tf.nn.relu,
                             output_activation=None,
                             policy=mlp_gaussian_policy):
            # policy
            with tf.variable_scope('pi'):
                mu, pi, logp_pi, std = policy(x, a, hidden_sizes, activation,
                                              output_activation)
                mu, pi, logp_pi = apply_squashing_func(mu, pi, logp_pi)

            # vfs
            vf_mlp = lambda x: tf.squeeze(
                mlp(x,
                    list(hidden_sizes) + [1], activation, None), axis=-1)

            with tf.variable_scope('q1'):
                q1 = vf_mlp(tf.concat([x_v, a], axis=-1))
            with tf.variable_scope('q2'):
                q2 = vf_mlp(tf.concat([x_v, a], axis=-1))
            return mu, pi, logp_pi, q1, q2, std

        policy_state1 = self._observations_ph
        value_state1 = self._observations_ph
        policy_state2 = value_state2 = self._next_observations_ph

        ac_kwargs = {
            "hidden_sizes": self.network_kwargs["hidden_sizes"],
            "activation": self.network_kwargs["activation"],
            "output_activation": self.network_kwargs["output_activation"]
        }

        with tf.variable_scope('main', reuse=False):
            self.mu, self.pi, logp_pi, q1, q2, std = mlp_actor_critic(
                policy_state1, value_state1, self._actions_ph, **ac_kwargs)

        pi_entropy = tf.reduce_sum(tf.log(std + 1e-8) +
                                   0.5 * tf.log(2 * np.pi * np.e),
                                   axis=-1)
        with tf.variable_scope('main', reuse=True):
            # compose q with pi, for pi-learning
            _, _, _, q1_pi, q2_pi, _ = mlp_actor_critic(
                policy_state1, value_state1, self.pi, **ac_kwargs)
            # get actions and log probs of actions for next states, for Q-learning
            _, pi_next, logp_pi_next, _, _, _ = mlp_actor_critic(
                policy_state2, value_state2, self._actions_ph, **ac_kwargs)

        with tf.variable_scope('target'):
            # target q values, using actions from *current* policy
            _, _, _, q1_targ, q2_targ, _ = mlp_actor_critic(
                policy_state2, value_state2, pi_next, **ac_kwargs)

        # actions = self._policy.actions([self._observations_ph])
        # log_pis = self._policy.log_pis([self._observations_ph], actions)
        # assert log_pis.shape.as_list() == [None, 1]

        # alpha optimizer
        log_alpha = self._log_alpha = tf.get_variable('log_alpha',
                                                      dtype=tf.float32,
                                                      initializer=0.0)
        alpha = tf.exp(log_alpha)

        self._alpha = alpha
        assert self._action_prior == 'uniform'
        policy_prior_log_probs = 0.0

        min_q_pi = tf.minimum(q1_pi, q2_pi)
        min_q_targ = tf.minimum(q1_targ, q2_targ)

        if self._reparameterize:
            policy_kl_losses = (tf.stop_gradient(alpha) * logp_pi - min_q_pi -
                                policy_prior_log_probs)
        else:
            raise NotImplementedError

        policy_loss = tf.reduce_mean(policy_kl_losses)

        # Q
        next_log_pis = logp_pi_next
        min_next_Q = min_q_targ
        next_value = min_next_Q - self._alpha * next_log_pis

        q_target = td_target(
            reward=self._reward_scale * self._rewards_ph[..., 0],
            discount=self._discount,
            next_value=(1 - self._terminals_ph[..., 0]) * next_value)

        print('q1_pi: {}, q2_pi: {}, policy_state1: {}, policy_state2: {}, '
              'pi_next: {}, q1_targ: {}, mu: {}, reward: {}, '
              'terminal: {}, target_q: {}, next_value: {}, '
              'q1: {}, logp_pi: {}, min_q_pi: {}'.format(
                  q1_pi, q2_pi, policy_state1, policy_state2, pi_next, q1_targ,
                  self.mu, self._rewards_ph[..., 0], self._terminals_ph[...,
                                                                        0],
                  q_target, next_value, q1, logp_pi, min_q_pi))
        # assert q_target.shape.as_list() == [None, 1]
        # (self._Q_values,
        #  self._Q_losses,
        #  self._alpha,
        #  self.global_step),
        self.Q1 = q1
        self.Q2 = q2
        q_target = tf.stop_gradient(q_target)
        q1_loss = tf.losses.mean_squared_error(labels=q_target,
                                               predictions=q1,
                                               weights=0.5)
        q2_loss = tf.losses.mean_squared_error(labels=q_target,
                                               predictions=q2,
                                               weights=0.5)
        self.Q_loss = (q1_loss + q2_loss) / 2

        value_optimizer1 = tf.train.AdamOptimizer(learning_rate=self._Q_lr)
        value_optimizer2 = tf.train.AdamOptimizer(learning_rate=self._Q_lr)
        print('[ DEBUG ]: Q lr is {}'.format(self._Q_lr))

        # train_value_op = value_optimizer.apply_gradients(zip(grads, variables))
        pi_optimizer = tf.train.AdamOptimizer(learning_rate=self._policy_lr)
        print('[ DEBUG ]: policy lr is {}'.format(self._policy_lr))

        pi_var_list = get_vars('main/pi')
        if self.adapt:
            pi_var_list += get_vars("lstm_net_pi")
        train_pi_op = pi_optimizer.minimize(policy_loss, var_list=pi_var_list)
        pgrads, variables = zip(
            *pi_optimizer.compute_gradients(policy_loss, var_list=pi_var_list))

        _, pi_global_norm = tf.clip_by_global_norm(pgrads, 2000)

        with tf.control_dependencies([train_pi_op]):
            value_params1 = get_vars('main/q1')
            value_params2 = get_vars('main/q2')
            if self.adapt:
                value_params1 += get_vars("lstm_net_v")
                value_params2 += get_vars("lstm_net_v")

            grads, variables = zip(*value_optimizer1.compute_gradients(
                self.Q_loss, var_list=value_params1))
            _, q_global_norm = tf.clip_by_global_norm(grads, 2000)
            train_value_op1 = value_optimizer1.minimize(q1_loss,
                                                        var_list=value_params1)
            train_value_op2 = value_optimizer2.minimize(q2_loss,
                                                        var_list=value_params2)
            with tf.control_dependencies([train_value_op1, train_value_op2]):
                if isinstance(self._target_entropy, Number):
                    alpha_loss = -tf.reduce_mean(
                        log_alpha *
                        tf.stop_gradient(logp_pi + self._target_entropy))
                    self._alpha_optimizer = tf.train.AdamOptimizer(
                        self._policy_lr, name='alpha_optimizer')
                    self._alpha_train_op = self._alpha_optimizer.minimize(
                        loss=alpha_loss, var_list=[log_alpha])
                else:
                    self._alpha_train_op = tf.no_op()

        self.target_update = tf.group([
            tf.assign(v_targ, (1 - self._tau) * v_targ + self._tau * v_main)
            for v_main, v_targ in zip(get_vars('main'), get_vars('target'))
        ])

        self.target_init = tf.group([
            tf.assign(v_targ, v_main)
            for v_main, v_targ in zip(get_vars('main'), get_vars('target'))
        ])

        # construct opt
        self._training_ops = [
            tf.group((train_value_op2, train_value_op1, train_pi_op,
                      self._alpha_train_op)), {
                          "sac_pi/pi_global_norm": pi_global_norm,
                          "sac_Q/q_global_norm": q_global_norm,
                          "Q/q1_loss": q1_loss,
                          "sac_Q/q2_loss": q2_loss,
                          "sac_Q/q1": q1,
                          "sac_Q/q2": q2,
                          "sac_pi/alpha": alpha,
                          "sac_pi/pi_entropy": pi_entropy,
                          "sac_pi/logp_pi": logp_pi,
                          "sac_pi/std": logp_pi,
                      }
        ]

        self._session.run(tf.global_variables_initializer())
        self._session.run(self.target_init)

    def get_action_meta(self, state, hidden, deterministic=False):
        with self._session.as_default():
            state_dim = len(np.shape(state))
            if state_dim == 2:
                state = state[None]
            feed_dict = {
                self._observations_ph: state,
                self._prev_state_p_ph: hidden
            }
            mu, pi = self._session.run([self.mu, self.pi], feed_dict=feed_dict)
            if state_dim == 2:
                mu = mu[0]
                pi = pi[0]
            # print(f"[ DEBUG ]: pi_shape: {pi.shape}, mu_shape: {mu.shape}")
            if deterministic:
                return mu, hidden
            else:
                return pi, hidden

    def make_init_hidden(self, batch_size=1):
        return np.zeros((batch_size, self.gru_state_dim))
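
A short usage sketch of the two helpers above (hypothetical `agent` and `env` objects, not defined in this file): the GRU hidden state is threaded through the calls even though this non-adaptive variant never updates it.

# hedged usage sketch; `agent` is a built MOPO instance, `env` a gym-style environment
obs = env.reset()
hidden = agent.make_init_hidden(batch_size=1)
for _ in range(1000):
    action, hidden = agent.get_action_meta(obs[None], hidden, deterministic=True)
    obs, reward, done, _ = env.step(action[0])
    if done:
        break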

    def _train(self):
        """Return a generator that performs RL training.

        Args:
            env (`SoftlearningEnv`): Environment used for training.
            policy (`Policy`): Policy used for training
            initial_exploration_policy ('Policy'): Policy used for exploration
                If None, then all exploration is done using policy
            pool (`PoolBase`): Sample pool to add samples to
        """
        training_environment = self._training_environment
        evaluation_environment = self._evaluation_environment
        # policy = self._policy
        pool = self._pool
        model_metrics = {}

        # if not self._training_started:
        self._init_training()

        # TODO: change policy to placeholder or a function
        def get_action(state, hidden, deterministic=False):
            return self.get_action_meta(state, hidden, deterministic)

        def make_init_hidden(batch_size=1):
            return self.make_init_hidden(batch_size)

        self.sampler.initialize(training_environment,
                                (get_action, make_init_hidden), pool)

        gt.reset_root()
        gt.rename_root('RLAlgorithm')
        gt.set_def_unique(False)

        # self._training_before_hook()

        #### model training
        print('[ MOPO ] log_dir: {} | ratio: {}'.format(
            self._log_dir, self._real_ratio))
        print(
            '[ MOPO ] Training model at epoch {} | freq {} | timestep {} (total: {})'
            .format(self._epoch, self._model_train_freq, self._timestep,
                    self._total_timestep))
        # train dynamics model offline
        max_epochs = 1 if self._model.model_loaded else None
        model_train_metrics = self._train_model(batch_size=256,
                                                max_epochs=max_epochs,
                                                holdout_ratio=0.2,
                                                max_t=self._max_model_t)

        model_metrics.update(model_train_metrics)
        self._log_model()
        gt.stamp('epoch_train_model')
        ####
        tester.time_step_holder.set_time(0)
        for self._epoch in gt.timed_for(range(self._epoch, self._n_epochs)):

            self._epoch_before_hook()
            gt.stamp('epoch_before_hook')

            self._training_progress = Progress(self._epoch_length *
                                               self._n_train_repeat)
            start_samples = self.sampler._total_samples
            training_logs = {}
            for timestep in count():
                self._timestep = timestep
                if (timestep >= self._epoch_length and self.ready_to_train):
                    break

                self._timestep_before_hook()
                gt.stamp('timestep_before_hook')

                ## model rollouts
                if timestep % self._model_train_freq == 0 and self._real_ratio < 1.0:
                    self._training_progress.pause()
                    self._set_rollout_length()
                    self._reallocate_model_pool()
                    model_rollout_metrics = self._rollout_model(
                        rollout_batch_size=self._rollout_batch_size,
                        deterministic=self._deterministic)
                    model_metrics.update(model_rollout_metrics)

                    gt.stamp('epoch_rollout_model')
                    self._training_progress.resume()

                ## train actor and critic
                if self.ready_to_train:
                    # print('[ DEBUG ]: ready to train at timestep: {}'.format(timestep))
                    training_logs = self._do_training_repeats(
                        timestep=timestep)
                gt.stamp('train')

                self._timestep_after_hook()
                gt.stamp('timestep_after_hook')

            training_paths = self.sampler.get_last_n_paths(
                math.ceil(self._epoch_length / self.sampler._max_path_length))
            # evaluate the policies
            evaluation_paths = self._evaluation_paths(
                (lambda _state, _hidden: get_action(_state, _hidden, True),
                 make_init_hidden), evaluation_environment)
            gt.stamp('evaluation_paths')
            if evaluation_paths:
                evaluation_metrics = self._evaluate_rollouts(
                    evaluation_paths, evaluation_environment)
                gt.stamp('evaluation_metrics')
            else:
                evaluation_metrics = {}

            gt.stamp('epoch_after_hook')

            sampler_diagnostics = self.sampler.get_diagnostics()

            diagnostics = self.get_diagnostics(
                iteration=self._total_timestep,
                batch=self._evaluation_batch(),
                training_paths=training_paths,
                evaluation_paths=evaluation_paths)

            time_diagnostics = gt.get_times().stamps.itrs

            diagnostics.update(
                OrderedDict(
                    (*(('evaluation/{}'.format(key), evaluation_metrics[key])
                       for key in sorted(evaluation_metrics.keys())),
                     *(('times/{}'.format(key), time_diagnostics[key][-1])
                       for key in sorted(time_diagnostics.keys())),
                     *(('sampler/{}'.format(key), sampler_diagnostics[key])
                       for key in sorted(sampler_diagnostics.keys())),
                     *(('model/{}'.format(key), model_metrics[key])
                       for key in sorted(model_metrics.keys())),
                     ('epoch', self._epoch), ('timestep', self._timestep),
                     ('timesteps_total',
                      self._total_timestep), ('train-steps',
                                              self._num_train_steps),
                     *(('training/{}'.format(key), training_logs[key])
                       for key in sorted(training_logs.keys())))))
            diagnostics['perf/AverageReturn'] = diagnostics[
                'evaluation/return-average']
            diagnostics['perf/AverageLength'] = diagnostics[
                'evaluation/episode-length-avg']
            if self.min_ret != self.max_ret:
                diagnostics['perf/NormalizedReturn'] = (diagnostics['perf/AverageReturn'] - self.min_ret) \
                                                       / (self.max_ret - self.min_ret)
            # diagnostics['keys/logp_pi'] =  diagnostics['training/sac_pi/logp_pi']
            if self._eval_render_mode is not None and hasattr(
                    evaluation_environment, 'render_rollouts'):
                training_environment.render_rollouts(evaluation_paths)

            ## ensure we did not collect any more data
            assert self._pool.size == self._init_pool_size
            for k, v in diagnostics.items():
                # print('[ DEBUG ] epoch: {} diagnostics k: {}, v: {}'.format(self._epoch, k, v))
                self._writer.add_scalar(k, v, self._epoch)
            yield diagnostics

        self.sampler.terminate()

        self._training_after_hook()

        self._training_progress.close()

        yield {'done': True, **diagnostics}

    def train(self, *args, **kwargs):
        return self._train(*args, **kwargs)

    def _log_policy(self):
        # TODO: how to save models
        save_path = os.path.join(self._log_dir, 'models')
        filesystem.mkdir(save_path)
        weights = self._policy.get_weights()
        data = {'policy_weights': weights}
        full_path = os.path.join(save_path,
                                 'policy_{}.pkl'.format(self._total_timestep))
        print('Saving policy to: {}'.format(full_path))
        pickle.dump(data, open(full_path, 'wb'))

    def _log_model(self):
        print('[ MODEL ]: {}'.format(self._model_type))
        if self._model_type == 'identity':
            print('[ MOPO ] Identity model, skipping save')
        elif self._model.model_loaded:
            print('[ MOPO ] Loaded model, skipping save')
        else:
            save_path = os.path.join(self._log_dir, 'models')
            filesystem.mkdir(save_path)
            print('[ MOPO ] Saving model to: {}'.format(save_path))
            self._model.save(save_path, self._total_timestep)

    def _set_rollout_length(self):
        min_epoch, max_epoch, min_length, max_length = self._rollout_schedule
        if self._epoch <= min_epoch:
            y = min_length
        else:
            dx = (self._epoch - min_epoch) / (max_epoch - min_epoch)
            dx = min(dx, 1)
            y = dx * (max_length - min_length) + min_length

        self._rollout_length = int(y)
        print(
            '[ Model Length ] Epoch: {} (min: {}, max: {}) | Length: {} (min: {} , max: {})'
            .format(self._epoch, min_epoch, max_epoch, self._rollout_length,
                    min_length, max_length))
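
The schedule above linearly anneals the rollout length from `min_length` to `max_length` as the epoch moves from `min_epoch` to `max_epoch`. A standalone sketch of the same interpolation:

def rollout_length(epoch, schedule):
    min_epoch, max_epoch, min_length, max_length = schedule
    if epoch <= min_epoch:
        return min_length
    dx = min((epoch - min_epoch) / (max_epoch - min_epoch), 1)
    return int(dx * (max_length - min_length) + min_length)

# with the schedule [20, 100, 1, 1] used above the length stays at 1;
# with e.g. [20, 100, 1, 15] it grows from 1 at epoch 20 to 15 at epoch 100
print([rollout_length(e, [20, 100, 1, 15]) for e in (10, 20, 60, 100, 200)])
# -> [1, 1, 8, 15, 15]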

    def _reallocate_model_pool(self):
        obs_space = self._pool._observation_space
        act_space = self._pool._action_space

        rollouts_per_epoch = self._rollout_batch_size * self._epoch_length / self._model_train_freq
        model_steps_per_epoch = int(self._rollout_length * rollouts_per_epoch)
        new_pool_size = self._model_retain_epochs * model_steps_per_epoch

        if not hasattr(self, '_model_pool'):
            print(
                '[ MOPO ] Initializing new model pool with size {:.2e}'.format(
                    new_pool_size))
            self._model_pool = SimpleReplayPool(obs_space, act_space,
                                                new_pool_size)

        elif self._model_pool._max_size != new_pool_size:
            print('[ MOPO ] Updating model pool | {:.2e} --> {:.2e}'.format(
                self._model_pool._max_size, new_pool_size))
            samples = self._model_pool.return_all_samples()
            new_pool = SimpleReplayPool(obs_space, act_space, new_pool_size)
            new_pool.add_samples(samples)
            assert self._model_pool.size == new_pool.size
            self._model_pool = new_pool
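
The pool size computed above is chosen to retain roughly `model_retain_epochs` epochs' worth of generated transitions. A worked example with the defaults in this file, assuming an epoch length of 1000 environment steps:

rollout_batch_size  = int(100e3)
epoch_length        = 1000      # assumed value for illustration
model_train_freq    = 250
rollout_length      = 1
model_retain_epochs = 20

rollouts_per_epoch = rollout_batch_size * epoch_length / model_train_freq
model_steps_per_epoch = int(rollout_length * rollouts_per_epoch)
new_pool_size = model_retain_epochs * model_steps_per_epoch
print('{:.2e}'.format(new_pool_size))  # 8.00e+06 with these settings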

    def _train_model(self, **kwargs):
        if self._model_type == 'identity':
            print('[ MOPO ] Identity model, skipping model training')
            model_metrics = {}
        else:
            env_samples = self._pool.return_all_samples()
            train_inputs, train_outputs = format_samples_for_training(
                env_samples)
            model_metrics = self._model.train(train_inputs, train_outputs,
                                              **kwargs)
        return model_metrics

    def _rollout_model(self, rollout_batch_size, **kwargs):
        print(
            '[ Model Rollout ] Starting | Epoch: {} | Rollout length: {} | Batch size: {} | Type: {}'
            .format(self._epoch, self._rollout_length, rollout_batch_size,
                    self._model_type))
        batch = self.sampler.random_batch(rollout_batch_size)
        obs = batch['observations']
        steps_added = []
        for i in range(self._rollout_length):
            hidden = self.make_init_hidden(1)
            if not self._rollout_random:
                # act = self._policy.actions_np(obs)
                act, hidden = self.get_action_meta(obs, hidden)
            else:
                # act_ = self._policy.actions_np(obs)
                act_, hidden = self.get_action_meta(obs, hidden)
                act = np.random.uniform(low=-1, high=1, size=act_.shape)

            if self._model_type == 'identity':
                next_obs = obs
                rew = np.zeros((len(obs), 1))
                term = (np.ones(
                    (len(obs), 1)) * self._identity_terminal).astype(np.bool)
                info = {}
            else:
                # print("act: {}, obs: {}".format(act.shape, obs.shape))
                next_obs, rew, term, info = self.fake_env.step(
                    obs, act, **kwargs)
            steps_added.append(len(obs))

            samples = {
                'observations': obs,
                'actions': act,
                'next_observations': next_obs,
                'rewards': rew,
                'terminals': term
            }
            self._model_pool.add_samples(samples)

            nonterm_mask = ~term.squeeze(-1)
            if nonterm_mask.sum() == 0:
                print('[ Model Rollout ] Breaking early: {} | {} / {}'.format(
                    i, nonterm_mask.sum(), nonterm_mask.shape))
                break

            obs = next_obs[nonterm_mask]

        mean_rollout_length = sum(steps_added) / rollout_batch_size
        rollout_stats = {'mean_rollout_length': mean_rollout_length}
        print(
            '[ Model Rollout ] Added: {:.1e} | Model pool: {:.1e} (max {:.1e}) | Length: {} | Train rep: {}'
            .format(sum(steps_added), self._model_pool.size,
                    self._model_pool._max_size, mean_rollout_length,
                    self._n_train_repeat))
        return rollout_stats

    def _visualize_model(self, env, timestep):
        ## save env state
        state = env.unwrapped.state_vector()
        qpos_dim = len(env.unwrapped.sim.data.qpos)
        qpos = state[:qpos_dim]
        qvel = state[qpos_dim:]

        print('[ Visualization ] Starting | Epoch {} | Log dir: {}\n'.format(
            self._epoch, self._log_dir))
        visualize_policy(env, self.fake_env, self._policy, self._writer,
                         timestep)
        print('[ Visualization ] Done')
        ## set env state
        env.unwrapped.set_state(qpos, qvel)

    def _do_training_repeats(self, timestep):
        """Repeat training _n_train_repeat times every _train_every_n_steps"""
        if timestep % self._train_every_n_steps > 0: return
        trained_enough = (self._train_steps_this_epoch >
                          self._max_train_repeat_per_timestep * self._timestep)
        if trained_enough: return
        log_buffer = []
        logs = {}
        # print('[ DEBUG ]: {}'.format(self._training_batch()))
        for i in range(self._n_train_repeat):
            logs = self._do_training(iteration=timestep,
                                     batch=self._training_batch())
            log_buffer.append(logs)
        logs_buffer = {
            k: np.mean([item[k] for item in log_buffer])
            for k in logs
        }

        self._num_train_steps += self._n_train_repeat
        self._train_steps_this_epoch += self._n_train_repeat
        return logs_buffer

    def _training_batch(self, batch_size=None):
        batch_size = batch_size or self.sampler._batch_size
        env_batch_size = int(batch_size * self._real_ratio)
        model_batch_size = batch_size - env_batch_size
        # TODO: how to set terminal state.
        # TODO: how to set model pool.

        ## can sample from the env pool even if env_batch_size == 0
        env_batch = self._pool.random_batch(env_batch_size)

        if model_batch_size > 0:
            model_batch = self._model_pool.random_batch(model_batch_size)

            # keys = env_batch.keys()
            keys = set(env_batch.keys()) & set(model_batch.keys())
            batch = {
                k: np.concatenate((env_batch[k], model_batch[k]), axis=0)
                for k in keys
            }
        else:
            ## if real_ratio == 1.0, no model pool was ever allocated,
            ## so skip the model pool sampling
            batch = env_batch
        return batch
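
`_training_batch` mixes a fraction `real_ratio` of real transitions with model-generated ones. A minimal NumPy sketch of the same split and concatenation, with plain dicts of arrays standing in for the replay pools:

import numpy as np

def mixed_batch(env_pool, model_pool, batch_size, real_ratio):
    env_n = int(batch_size * real_ratio)
    model_n = batch_size - env_n
    env_idx = np.random.randint(len(env_pool['observations']), size=env_n)
    model_idx = np.random.randint(len(model_pool['observations']), size=model_n)
    return {
        k: np.concatenate((env_pool[k][env_idx], model_pool[k][model_idx]), axis=0)
        for k in set(env_pool) & set(model_pool)
    }

env_pool = {'observations': np.zeros((50, 3)), 'rewards': np.zeros((50, 1))}
model_pool = {'observations': np.ones((200, 3)), 'rewards': np.ones((200, 1))}
batch = mixed_batch(env_pool, model_pool, batch_size=10, real_ratio=0.1)
print(batch['observations'].shape)  # (10, 3): 1 real + 9 model transitions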

    # def _init_global_step(self):
    #     self.global_step = training_util.get_or_create_global_step()
    #     self._training_ops.update({
    #         'increment_global_step': training_util._increment_global_step(1)
    #     })
    #

    def _init_training(self):
        self._session.run(self.target_init)
        # self._update_target(tau=1.0)

    def _do_training(self, iteration, batch):
        """Runs the operations for updating training and target ops."""

        self._training_progress.update()
        self._training_progress.set_description()

        feed_dict = self._get_feed_dict(iteration, batch)

        res = self._session.run(self._training_ops, feed_dict)
        if iteration % self._target_update_interval == 0:
            # Run target ops here.
            self._update_target()
        logs = {k: np.mean(res[1][k]) for k in res[1]}
        # for k, v in logs.items():
        #     print("[ DEBUG ] k: {}, v: {}".format(k, v))
        #     self._writer.add_scalar(k, v, iteration)
        return logs

    def _update_target(self):
        self._session.run(self.target_update)

    def _get_feed_dict(self, iteration, batch):
        """Construct TensorFlow feed_dict from sample batch."""
        state_dim = len(batch['observations'].shape)
        resize = lambda x: x[None] if state_dim == 2 else x
        feed_dict = {
            self._observations_ph: resize(batch['observations']),
            self._actions_ph: resize(batch['actions']),
            self._next_observations_ph: resize(batch['next_observations']),
            self._rewards_ph: resize(batch['rewards']),
            self._terminals_ph: resize(batch['terminals']),
        }

        if self._store_extra_policy_info:
            feed_dict[self._log_pis_ph] = resize(batch['log_pis'])
            feed_dict[self._raw_actions_ph] = resize(batch['raw_actions'])

        if iteration is not None:
            feed_dict[self._iteration_ph] = iteration

        return feed_dict

    def get_diagnostics(self, iteration, batch, training_paths,
                        evaluation_paths):
        """Return diagnostic information as ordered dictionary.

        Records mean and standard deviation of Q-function and state
        value function, and TD-loss (mean squared Bellman error)
        for the sample batch.

        Also calls the `draw` method of the plotter, if plotter defined.
        """

        feed_dict = self._get_feed_dict(iteration, batch)

        (Q_value1, Q_value2, Q_losses, alpha, global_step) = self._session.run(
            (self.Q1, self.Q2, self.Q_loss, self._alpha, self.global_step),
            feed_dict)
        Q_values = np.concatenate((Q_value1, Q_value2), axis=0)
        diagnostics = OrderedDict({
            'Q-avg': np.mean(Q_values),
            'Q-std': np.std(Q_values),
            'Q_loss': np.mean(Q_losses),
            'alpha': alpha,
        })

        # TODO (luofm): policy diagnostics
        # policy_diagnostics = self._policy.get_diagnostics(
        #     batch['observations'])
        # diagnostics.update({
        #     'policy/{}'.format(key): value
        #     for key, value in policy_diagnostics.items()
        # })

        if self._plotter:
            self._plotter.draw()

        return diagnostics

    @property
    def tf_saveables(self):
        saveables = {
            '_policy_optimizer': self._policy_optimizer,
            **{
                'Q_optimizer_{}'.format(i): optimizer
                for i, optimizer in enumerate(self._Q_optimizers)
            },
            '_log_alpha': self._log_alpha,
        }

        if hasattr(self, '_alpha_optimizer'):
            saveables['_alpha_optimizer'] = self._alpha_optimizer

        return saveables
Beispiel #5
0
class MBPO(RLAlgorithm):
    """Model-Based Policy Optimization (MBPO)

    References
    ----------
        Michael Janner, Justin Fu, Marvin Zhang, Sergey Levine. 
        When to Trust Your Model: Model-Based Policy Optimization. 
        arXiv preprint arXiv:1906.08253. 2019.
    """

    def __init__(
            self,
            training_environment,
            evaluation_environment,
            policy,
            Qs,
            pool,
            static_fns,
            plotter=None,
            tf_summaries=False,

            lr=3e-4,
            reward_scale=1.0,
            target_entropy='auto',
            discount=0.99,
            tau=5e-3,
            target_update_interval=1,
            action_prior='uniform',
            reparameterize=False,
            store_extra_policy_info=False,

            deterministic=False,
            model_train_freq=250,
            model_train_slower=1,
            num_networks=7,
            num_elites=5,
            num_Q_elites=2,  # The size of the Q ensemble is set via the command line
            model_retain_epochs=20,
            rollout_batch_size=100e3,
            real_ratio=0.1,
            critic_same_as_actor=True,
            rollout_schedule=[20,100,1,1],
            hidden_dim=200,
            max_model_t=None,
            dir_name=None,
            evaluate_explore_freq=0,
            num_Q_per_grp=2,
            num_Q_grp=1,
            cross_grp_diff_batch=False,

            model_load_dir=None,
            model_load_index=None,
            model_log_freq=0,
            **kwargs,
    ):
        """
        Args:
            env (`SoftlearningEnv`): Environment used for training.
            policy: A policy function approximator.
            initial_exploration_policy: ('Policy'): A policy that we use
                for initial exploration which is not trained by the algorithm.
            Qs: Q-function approximators. The min of these
                approximators will be used. Usage of at least two Q-functions
                improves performance by reducing overestimation bias.
            pool (`PoolBase`): Replay pool to add gathered samples to.
            plotter (`QFPolicyPlotter`): Plotter instance to be used for
                visualizing Q-function during training.
            lr (`float`): Learning rate used for the function approximators.
            discount (`float`): Discount factor for Q-function updates.
            tau (`float`): Soft value function target update weight.
            target_update_interval ('int'): Frequency at which target network
                updates occur in iterations.
            reparameterize ('bool'): If True, we use a gradient estimator for
                the policy derived using the reparameterization trick. We use
                a likelihood ratio based estimator otherwise.
            critic_same_as_actor ('bool'): If True, use the same sampling scheme
                (model-free or model-based) as the actor when training the critic.
                Otherwise, use model-free sampling to train the critic.
        """

        super(MBPO, self).__init__(**kwargs)

        if training_environment.unwrapped.spec.id.find("Fetch") != -1:
            # Fetch env
            obs_dim = sum([i.shape[0] for i in training_environment.observation_space.spaces.values()]) 
            self.multigoal = 1
        else:
            obs_dim = np.prod(training_environment.observation_space.shape)
        # print("====", obs_dim, "========")

        act_dim = np.prod(training_environment.action_space.shape)

        # TODO: add variable scope to directly extract model parameters
        self._model_load_dir = model_load_dir
        print("============Model dir: ", self._model_load_dir)
        if model_load_index:
            latest_model_index = model_load_index
        else:
            latest_model_index = self._get_latest_index()
        self._model = construct_model(obs_dim=obs_dim, act_dim=act_dim, hidden_dim=hidden_dim, num_networks=num_networks, num_elites=num_elites,
                                      model_dir=self._model_load_dir, model_load_timestep=latest_model_index, load_model=True if model_load_dir else False)
        self._static_fns = static_fns
        self.fake_env = FakeEnv(self._model, self._static_fns)

        model_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self._model.name)
        all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

        self._rollout_schedule = rollout_schedule
        self._max_model_t = max_model_t

        # self._model_pool_size = model_pool_size
        # print('[ MBPO ] Model pool size: {:.2E}'.format(self._model_pool_size))
        # self._model_pool = SimpleReplayPool(pool._observation_space, pool._action_space, self._model_pool_size)

        self._model_retain_epochs = model_retain_epochs

        self._model_train_freq = model_train_freq
        self._rollout_batch_size = int(rollout_batch_size)
        self._deterministic = deterministic
        self._real_ratio = real_ratio

        self._log_dir = os.getcwd()
        self._writer = Writer(self._log_dir)

        self._training_environment = training_environment
        self._evaluation_environment = evaluation_environment
        self._policy = policy

        self._Qs = Qs
        self._Q_ensemble = len(Qs)
        self._Q_elites = num_Q_elites
        self._Q_targets = tuple(tf.keras.models.clone_model(Q) for Q in Qs)

        self._pool = pool
        self._plotter = plotter
        self._tf_summaries = tf_summaries

        self._policy_lr = lr
        self._Q_lr = lr

        self._reward_scale = reward_scale
        self._target_entropy = (
            -np.prod(self._training_environment.action_space.shape)
            if target_entropy == 'auto'
            else target_entropy)
        print('[ MBPO ] Target entropy: {}'.format(self._target_entropy))

        self._discount = discount
        self._tau = tau
        self._target_update_interval = target_update_interval
        self._action_prior = action_prior

        self._reparameterize = reparameterize
        self._store_extra_policy_info = store_extra_policy_info

        observation_shape = self._training_environment.active_observation_shape
        action_shape = self._training_environment.action_space.shape

        assert len(observation_shape) == 1, observation_shape
        self._observation_shape = observation_shape
        assert len(action_shape) == 1, action_shape
        self._action_shape = action_shape

        # self._critic_train_repeat = kwargs["critic_train_repeat"]
        # actor UTD should be n times larger or smaller than critic UTD
        assert self._actor_train_repeat % self._critic_train_repeat == 0 or \
               self._critic_train_repeat % self._actor_train_repeat == 0

        self._critic_train_freq = self._n_train_repeat // self._critic_train_repeat
        self._actor_train_freq = self._n_train_repeat // self._actor_train_repeat
        self._critic_same_as_actor = critic_same_as_actor
        self._model_train_slower = model_train_slower
        self._origin_model_train_epochs = 0

        self._dir_name = dir_name
        self._evaluate_explore_freq = evaluate_explore_freq

        # Qs within a group are trained on the same data; Qs across groups on different data.
        self._num_Q_per_grp = num_Q_per_grp
        self._num_Q_grp = num_Q_grp
        self._cross_grp_diff_batch = cross_grp_diff_batch

        self._model_log_freq = model_log_freq

        self._build()

    def _build(self):
        self._training_ops = {}
        self._actor_training_ops = {}
        self._critic_training_ops = {} if not self._cross_grp_diff_batch else \
                                    [{} for _ in range(self._num_Q_grp)]
        self._misc_training_ops = {}  # basically no feed_dict is needed
        # device = "/device:GPU:1"
        # with tf.device(device):
        #     self._init_global_step()
        #     self._init_placeholders()
        #     self._init_actor_update()
        #     self._init_critic_update()
        self._init_global_step()
        self._init_placeholders()
        self._init_actor_update()
        self._init_critic_update()

    def _train(self):
        
        """Return a generator that performs RL training.

        Args:
            env (`SoftlearningEnv`): Environment used for training.
            policy (`Policy`): Policy used for training
            initial_exploration_policy ('Policy'): Policy used for exploration
                If None, then all exploration is done using policy
            pool (`PoolBase`): Sample pool to add samples to
        """
        training_environment = self._training_environment
        evaluation_environment = self._evaluation_environment
        policy = self._policy
        pool = self._pool
        model_metrics = {}

        if not self._training_started:
            self._init_training()

            self._initial_exploration_hook(
                training_environment, self._initial_exploration_policy, pool)

        self.sampler.initialize(training_environment, policy, pool)

        gt.reset_root()
        gt.rename_root('RLAlgorithm')
        gt.set_def_unique(False)

        self._training_before_hook()

        for self._epoch in gt.timed_for(range(self._epoch, self._n_epochs)):

            self._epoch_before_hook()
            gt.stamp('epoch_before_hook')

            if self._evaluate_explore_freq != 0 and self._epoch % self._evaluate_explore_freq == 0:
                self._evaluate_exploration()

            self._training_progress = Progress(self._epoch_length * self._n_train_repeat)
            start_samples = self.sampler._total_samples
            for i in count():
                samples_now = self.sampler._total_samples
                self._timestep = samples_now - start_samples

                if (samples_now >= start_samples + self._epoch_length
                    and self.ready_to_train):
                    break
                self._timestep_before_hook()
                gt.stamp('timestep_before_hook')

                if self._timestep % self._model_train_freq == 0 and self._real_ratio < 1.0:
                    
                    self._training_progress.pause()
                    print('[ MBPO ] log_dir: {} | ratio: {}'.format(self._log_dir, self._real_ratio))
                    print('[ MBPO ] Training model at epoch {} | freq {} | timestep {} (total: {}) | epoch train steps: {} (total: {}) | times slower: {}'.format(
                        self._epoch, self._model_train_freq, self._timestep, self._total_timestep, self._train_steps_this_epoch, self._num_train_steps, self._model_train_slower)
                    )

                    if self._origin_model_train_epochs % self._model_train_slower == 0:
                        model_train_metrics = self._train_model(batch_size=256, max_epochs=None, holdout_ratio=0.2, max_t=self._max_model_t)
                        model_metrics.update(model_train_metrics)
                        gt.stamp('epoch_train_model')
                    else:
                        print('[ MBPO ] Skipping model training (training the model {}x slower)'.format(self._model_train_slower))
                    self._origin_model_train_epochs += 1
                    
                    self._set_rollout_length()
                    self._reallocate_model_pool()
                    model_rollout_metrics = self._rollout_model(rollout_batch_size=self._rollout_batch_size, deterministic=self._deterministic)
                    model_metrics.update(model_rollout_metrics)
                    
                    if self._model_log_freq != 0 and self._timestep % self._model_log_freq == 0:
                        self._log_model()

                    gt.stamp('epoch_rollout_model')
                    # self._visualize_model(self._evaluation_environment, self._total_timestep)
                    self._training_progress.resume()

                self._do_sampling(timestep=self._total_timestep)
                gt.stamp('sample')

                if self.ready_to_train:
                    self._do_training_repeats(timestep=self._total_timestep)
                gt.stamp('train')

                self._timestep_after_hook()
                gt.stamp('timestep_after_hook')

            training_paths = self.sampler.get_last_n_paths(
                math.ceil(self._epoch_length / self.sampler._max_path_length))
            gt.stamp('training_paths')
            evaluation_paths = self._evaluation_paths(
                policy, evaluation_environment)
            gt.stamp('evaluation_paths')

            training_metrics = self._evaluate_rollouts(
                training_paths, training_environment)
            gt.stamp('training_metrics')
            if evaluation_paths:
                evaluation_metrics = self._evaluate_rollouts(
                    evaluation_paths, evaluation_environment)
                gt.stamp('evaluation_metrics')
            else:
                evaluation_metrics = {}

            self._epoch_after_hook(training_paths)
            gt.stamp('epoch_after_hook')

            sampler_diagnostics = self.sampler.get_diagnostics()

            diagnostics = self.get_diagnostics(
                iteration=self._total_timestep,
                batch=self._evaluation_batch(),
                training_paths=training_paths,
                evaluation_paths=evaluation_paths)

            time_diagnostics = gt.get_times().stamps.itrs

            diagnostics.update(OrderedDict((
                *(
                    (f'evaluation/{key}', evaluation_metrics[key])
                    for key in sorted(evaluation_metrics.keys())
                ),
                *(
                    (f'training/{key}', training_metrics[key])
                    for key in sorted(training_metrics.keys())
                ),
                *(
                    (f'times/{key}', time_diagnostics[key][-1])
                    for key in sorted(time_diagnostics.keys())
                ),
                *(
                    (f'sampler/{key}', sampler_diagnostics[key])
                    for key in sorted(sampler_diagnostics.keys())
                ),
                *(
                    (f'model/{key}', model_metrics[key])
                    for key in sorted(model_metrics.keys())
                ),
                ('epoch', self._epoch),
                ('timestep', self._timestep),
                ('timesteps_total', self._total_timestep),
                ('train-steps', self._num_train_steps),
            )))

            if self._eval_render_mode is not None and hasattr(
                    evaluation_environment, 'render_rollouts'):
                training_environment.render_rollouts(evaluation_paths)


            yield diagnostics

        self.sampler.terminate()

        self._training_after_hook()

        self._training_progress.close()

        yield {'done': True, **diagnostics}
    
    def _evaluate_exploration(self):
        print("=============evaluate exploration=========")

        # specify data dir
        base_dir = "/home/linus/Research/mbpo/mbpo_experiment/exploration_eval"
        if not self._dir_name:
            return
        data_dir = os.path.join(base_dir, self._dir_name)
        if not os.path.isdir(data_dir):
            os.mkdir(data_dir)

        # specify data name
        exp_name = "%d.pkl"%self._epoch
        path = os.path.join(data_dir, exp_name)

        evaluation_size = 3000
        action_repeat = 20
        batch = self.sampler.random_batch(evaluation_size)
        obs = batch['observations']
        actions_repeat = [self._policy.actions_np(obs) for _ in range(action_repeat)]

        Qs = []
        policy_std = []
        for action in actions_repeat:
            Q = []
            for (s,a) in zip(obs, action):
                s, a = np.array(s).reshape(1, -1), np.array(a).reshape(1, -1)
                Q.append(
                    self._session.run(
                        self._Q_values,
                        feed_dict = {
                            self._observations_ph: s,
                            self._actions_ph: a
                        }
                    )
                )
            Qs.append(Q)
        Qs = np.array(Qs).squeeze()
        Qs_mean_action = np.mean(Qs, axis = 0) # Compute mean across different actions of one given state.
        if self._cross_grp_diff_batch:
            inter_grp_q_stds = [np.std(Qs_mean_action[:, i * self._num_Q_per_grp:(i+1) * self._num_Q_per_grp], axis = 1) for i in range(self._num_Q_grp)]
            mean_inter_grp_q_std = np.mean(np.array(inter_grp_q_stds), axis = 0)
            min_qs_per_grp = [np.mean(Qs_mean_action[:, i * self._num_Q_per_grp:(i+1) * self._num_Q_per_grp], axis = 1) for i in range(self._num_Q_grp)]
            cross_grp_std = np.std(np.array(min_qs_per_grp), axis = 0)
        else:
            q_std = np.std(Qs_mean_action, axis=1)  # effectively the std of the state-value estimate V(s) across Q heads

        policy_std = [np.prod(np.exp(self._policy.policy_log_scale_model.predict(np.array(s).reshape(1,-1)))) for s in obs]

        if self._cross_grp_diff_batch:
            data = { 
                'obs': obs,
                'inter_q_std': mean_inter_grp_q_std,
                'cross_q_std': cross_grp_std,
                'pi_std': policy_std
            }
        else:
            data = {
                'obs': obs,
                'q_std': q_std,
                'pi_std': policy_std
            }
        with open(path, 'wb') as f:
            pickle.dump(data, f)
        print("==========================================")


    def train(self, *args, **kwargs):
        return self._train(*args, **kwargs)

    def _log_policy(self):
        save_path = os.path.join(self._log_dir, 'models')
        filesystem.mkdir(save_path)
        weights = self._policy.get_weights()
        data = {'policy_weights': weights}
        full_path = os.path.join(save_path, 'policy_{}.pkl'.format(self._total_timestep))
        print('Saving policy to: {}'.format(full_path))
        pickle.dump(data, open(full_path, 'wb'))

    # TODO: use this function to save model
    def _log_model(self):
        save_path = os.path.join(self._log_dir, 'models')
        filesystem.mkdir(save_path)
        print('Saving model to: {}'.format(save_path))
        self._model.save(save_path, self._total_timestep)

    def _set_rollout_length(self):
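        # Linearly anneal the rollout length from min_length to max_length as the epoch
        # goes from min_epoch to max_epoch (clamped at both ends).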
        min_epoch, max_epoch, min_length, max_length = self._rollout_schedule
        if self._epoch <= min_epoch:
            y = min_length
        else:
            dx = (self._epoch - min_epoch) / (max_epoch - min_epoch)
            dx = min(dx, 1)
            y = dx * (max_length - min_length) + min_length

        self._rollout_length = int(y)
        print('[ Model Length ] Epoch: {} (min: {}, max: {}) | Length: {} (min: {} , max: {})'.format(
            self._epoch, min_epoch, max_epoch, self._rollout_length, min_length, max_length
        ))

    def _reallocate_model_pool(self):
        obs_space = self._pool._observation_space
        act_space = self._pool._action_space

        rollouts_per_epoch = self._rollout_batch_size * self._epoch_length / self._model_train_freq
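        # Keep enough capacity for `model_retain_epochs` epochs of model data: each epoch
        # triggers (epoch_length / model_train_freq) rollout batches of `rollout_batch_size`
        # starting states, each rolled out for `_rollout_length` steps.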
        model_steps_per_epoch = int(self._rollout_length * rollouts_per_epoch)
        new_pool_size = self._model_retain_epochs * model_steps_per_epoch

        if not hasattr(self, '_model_pool'):
            print('[ MBPO ] Initializing new model pool with size {:.2e}'.format(
                new_pool_size
            ))
            self._model_pool = SimpleReplayPool(obs_space, act_space, new_pool_size)
        
        elif self._model_pool._max_size != new_pool_size:
            print('[ MBPO ] Updating model pool | {:.2e} --> {:.2e}'.format(
                self._model_pool._max_size, new_pool_size
            ))
            samples = self._model_pool.return_all_samples()
            new_pool = SimpleReplayPool(obs_space, act_space, new_pool_size)
            new_pool.add_samples(samples)
            assert self._model_pool.size == new_pool.size
            self._model_pool = new_pool

    def _train_model(self, **kwargs):
        env_samples = self._pool.return_all_samples()
        # train_inputs, train_outputs = format_samples_for_training(env_samples, self.multigoal)
        train_inputs, train_outputs = format_samples_for_training(env_samples)
        model_metrics = self._model.train(train_inputs, train_outputs, **kwargs)
        return model_metrics

    def _rollout_model(self, rollout_batch_size, **kwargs):
        print('[ Model Rollout ] Starting | Epoch: {} | Rollout length: {} | Batch size: {}'.format(
            self._epoch, self._rollout_length, rollout_batch_size
        ))

        # Keep total rollout sample complexity unchanged
        batch = self.sampler.random_batch(rollout_batch_size // self._sample_repeat)
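        # `_sample_repeat` rollouts are generated from each starting state, so the number of
        # starting states is divided by the same factor to keep the total number of
        # model-generated transitions unchanged.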
        obs = batch['observations']
        steps_added = []
        sampled_actions = []
        for _ in range(self._sample_repeat):
            for i in range(self._rollout_length):
                # TODO: vary the policy distribution across the sample-repeat iterations
                # self._policy: softlearning.policies.gaussian_policy.FeedforwardGaussianPolicy
                # self._policy._deterministic: False
                # print("=====================================")
                # print(self._policy._deterministic)
                # print("=====================================")
                act = self._policy.actions_np(obs)
                sampled_actions.append(act)
                
                next_obs, rew, term, info = self.fake_env.step(obs, act, **kwargs)
                steps_added.append(len(obs))

                samples = {'observations': obs, 'actions': act, 'next_observations': next_obs, 'rewards': rew, 'terminals': term}
                self._model_pool.add_samples(samples)

                nonterm_mask = ~term.squeeze(-1)
                if nonterm_mask.sum() == 0:
                    print('[ Model Rollout ] Breaking early: {} | {} / {}'.format(i, nonterm_mask.sum(), nonterm_mask.shape))
                    break

                obs = next_obs[nonterm_mask]
        # print(sampled_actions)

        mean_rollout_length = sum(steps_added) / rollout_batch_size
        rollout_stats = {'mean_rollout_length': mean_rollout_length}
        print('[ Model Rollout ] Added: {:.1e} | Model pool: {:.1e} (max {:.1e}) | Length: {} | Train rep: {}'.format(
            sum(steps_added), self._model_pool.size, self._model_pool._max_size, mean_rollout_length, self._n_train_repeat
        ))
        return rollout_stats

    def _visualize_model(self, env, timestep):
        ## save env state
        state = env.unwrapped.state_vector()
        qpos_dim = len(env.unwrapped.sim.data.qpos)
        qpos = state[:qpos_dim]
        qvel = state[qpos_dim:]

        print('[ Visualization ] Starting | Epoch {} | Log dir: {}\n'.format(self._epoch, self._log_dir))
        visualize_policy(env, self.fake_env, self._policy, self._writer, timestep)
        print('[ Visualization ] Done')
        ## set env state
        env.unwrapped.set_state(qpos, qvel)

    def _training_batch(self, batch_size=None):
        batch_size = batch_size or self.sampler._batch_size
        env_batch_size = int(batch_size*self._real_ratio)
        model_batch_size = batch_size - env_batch_size
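        # A fraction `_real_ratio` of each training batch comes from the real replay pool;
        # the remainder is drawn from the model-generated pool.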

        ## can sample from the env pool even if env_batch_size == 0
        if self._cross_grp_diff_batch:
            env_batch = [self._pool.random_batch(env_batch_size) for _ in range(self._num_Q_grp)]
        else:
            env_batch = self._pool.random_batch(env_batch_size)

        if model_batch_size > 0:
            model_batch = self._model_pool.random_batch(model_batch_size)

            keys = env_batch.keys()
            batch = {k: np.concatenate((env_batch[k], model_batch[k]), axis=0) for k in keys}
        else:
            ## if real_ratio == 1.0, no model pool was ever allocated,
            ## so skip the model pool sampling
            batch = env_batch
        return batch, env_batch

    def _init_global_step(self):
        self.global_step = training_util.get_or_create_global_step()
        self._training_ops.update({
            'increment_global_step': training_util._increment_global_step(1)
        })
        self._misc_training_ops.update({
            'increment_global_step': training_util._increment_global_step(1)
        })

    def _init_placeholders(self):
        """Create input placeholders for the SAC algorithm.

        Creates `tf.placeholder`s for:
            - observation
            - next observation
            - action
            - reward
            - terminals
        """
        self._iteration_ph = tf.placeholder(
            tf.int64, shape=None, name='iteration')

        self._observations_ph = tf.placeholder(
            tf.float32,
            shape=(None, *self._observation_shape),
            name='observation',
        )

        self._next_observations_ph = tf.placeholder(
            tf.float32,
            shape=(None, *self._observation_shape),
            name='next_observation',
        )

        self._actions_ph = tf.placeholder(
            tf.float32,
            shape=(None, *self._action_shape),
            name='actions',
        )

        self._rewards_ph = tf.placeholder(
            tf.float32,
            shape=(None, 1),
            name='rewards',
        )

        self._terminals_ph = tf.placeholder(
            tf.float32,
            shape=(None, 1),
            name='terminals',
        )

        if self._store_extra_policy_info:
            self._log_pis_ph = tf.placeholder(
                tf.float32,
                shape=(None, 1),
                name='log_pis',
            )
            self._raw_actions_ph = tf.placeholder(
                tf.float32,
                shape=(None, *self._action_shape),
                name='raw_actions',
            )

    def _get_Q_target(self):
        next_actions = self._policy.actions([self._next_observations_ph])
        next_log_pis = self._policy.log_pis(
            [self._next_observations_ph], next_actions)

        next_Qs_values = tuple(
            Q([self._next_observations_ph, next_actions])
            for Q in self._Q_targets)
        Qs_subset = np.random.choice(next_Qs_values, self._Q_elites, replace=False).tolist()
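        # NOTE: np.random.choice runs at graph-construction time, so the subset of
        # `_Q_elites` target Qs used for the min is fixed once when the graph is built
        # rather than resampled at every update.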
        
        # Line 8 of REDQ: min over M random indices
        min_next_Q = tf.reduce_min(Qs_subset, axis=0)
        next_value = min_next_Q - self._alpha * next_log_pis

        Q_target = td_target(
            reward=self._reward_scale * self._rewards_ph,
            discount=self._discount,
            next_value=(1 - self._terminals_ph) * next_value)

        return Q_target

    def _init_critic_update(self):
        """Create minimization operation for critic Q-function.

        Creates a `tf.optimizer.minimize` operation for updating
        critic Q-function with gradient descent, and appends it to
        `self._training_ops` attribute.
        """
        Q_target = tf.stop_gradient(self._get_Q_target())

        assert Q_target.shape.as_list() == [None, 1]

        Q_values = self._Q_values = tuple(
            Q([self._observations_ph, self._actions_ph])
            for Q in self._Qs)

        Q_losses = self._Q_losses = tuple(
            tf.losses.mean_squared_error(
                labels=Q_target, predictions=Q_value, weights=0.5)
            for Q_value in Q_values)

        self._Q_optimizers = tuple(
            tf.train.AdamOptimizer(
                learning_rate=self._Q_lr,
                name='{}_{}_optimizer'.format(Q._name, i)
            ) for i, Q in enumerate(self._Qs))

        # TODO: split this into N separate ops, where N is the number of Q groups
        Q_training_ops = tuple(
            tf.contrib.layers.optimize_loss(
                Q_loss,
                self.global_step,
                learning_rate=self._Q_lr,
                optimizer=Q_optimizer,
                variables=Q.trainable_variables,
                increment_global_step=False,
                summaries=((
                    "loss", "gradients", "gradient_norm", "global_gradient_norm"
                ) if self._tf_summaries else ()))
            for i, (Q, Q_loss, Q_optimizer)
            in enumerate(zip(self._Qs, Q_losses, self._Q_optimizers)))

        self._training_ops.update({'Q': tf.group(Q_training_ops)})
        if self._cross_grp_diff_batch:
            assert len(Q_training_ops) >= self._num_Q_grp * self._num_Q_per_grp
            for i in range(self._num_Q_grp - 1):
                self._critic_training_ops[i].update({
                    'Q': tf.group(Q_training_ops[i * self._num_Q_per_grp: (i+1) * self._num_Q_per_grp])
                })

            self._critic_training_ops[self._num_Q_grp - 1].update({
                'Q': tf.group(Q_training_ops[(self._num_Q_grp - 1) * self._num_Q_per_grp:])
            })
        else:
            self._critic_training_ops.update({'Q': tf.group(Q_training_ops)})

    def _init_actor_update(self):
        """Create minimization operations for policy and entropy.

        Creates `tf.optimizer.minimize` operations for updating
        the policy and entropy with gradient descent, and adds them to
        the `self._training_ops` attribute.
        """

        actions = self._policy.actions([self._observations_ph])
        log_pis = self._policy.log_pis([self._observations_ph], actions)

        assert log_pis.shape.as_list() == [None, 1]

        log_alpha = self._log_alpha = tf.get_variable(
            'log_alpha',
            dtype=tf.float32,
            initializer=0.0)
        alpha = tf.exp(log_alpha)

        if isinstance(self._target_entropy, Number):
            alpha_loss = -tf.reduce_mean(
                log_alpha * tf.stop_gradient(log_pis + self._target_entropy))

            self._alpha_optimizer = tf.train.AdamOptimizer(
                self._policy_lr, name='alpha_optimizer')
            self._alpha_train_op = self._alpha_optimizer.minimize(
                loss=alpha_loss, var_list=[log_alpha])

            self._training_ops.update({
                'temperature_alpha': self._alpha_train_op
            })
            self._actor_training_ops.update({
                'temperature_alpha': self._alpha_train_op
            })

        self._alpha = alpha

        if self._action_prior == 'normal':
            policy_prior = tf.contrib.distributions.MultivariateNormalDiag(
                loc=tf.zeros(self._action_shape),
                scale_diag=tf.ones(self._action_shape))
            policy_prior_log_probs = policy_prior.log_prob(actions)
        elif self._action_prior == 'uniform':
            policy_prior_log_probs = 0.0

        Q_log_targets = tuple(
            Q([self._observations_ph, actions])
            for Q in self._Qs)
        assert len(Q_log_targets) == self._Q_ensemble

        min_Q_log_target = tf.reduce_min(Q_log_targets, axis=0)
        mean_Q_log_target = tf.reduce_mean(Q_log_targets, axis=0)
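        # With exactly two Qs use the clipped double-Q minimum (standard SAC); with a
        # larger ensemble use the ensemble mean for the policy objective.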
        Q_target = min_Q_log_target if self._Q_ensemble == 2 else mean_Q_log_target

        if self._reparameterize:
            policy_kl_losses = (
                alpha * log_pis
                - Q_target
                - policy_prior_log_probs)
        else:
            raise NotImplementedError

        assert policy_kl_losses.shape.as_list() == [None, 1]

        policy_loss = tf.reduce_mean(policy_kl_losses)

        self._policy_optimizer = tf.train.AdamOptimizer(
            learning_rate=self._policy_lr,
            name="policy_optimizer")
        policy_train_op = tf.contrib.layers.optimize_loss(
            policy_loss,
            self.global_step,
            learning_rate=self._policy_lr,
            optimizer=self._policy_optimizer,
            variables=self._policy.trainable_variables,
            increment_global_step=False,
            summaries=(
                "loss", "gradients", "gradient_norm", "global_gradient_norm"
            ) if self._tf_summaries else ())

        self._training_ops.update({'policy_train_op': policy_train_op})
        self._actor_training_ops.update({'policy_train_op': policy_train_op})

    def _init_training(self):
        self._update_target(tau=1.0)

    def _update_target(self, tau=None):
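        # Polyak-average the target Q weights toward the online Q weights:
        # target <- tau * online + (1 - tau) * target.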
        tau = tau or self._tau

        for Q, Q_target in zip(self._Qs, self._Q_targets):
            source_params = Q.get_weights()
            target_params = Q_target.get_weights()
            Q_target.set_weights([
                tau * source + (1.0 - tau) * target
                for source, target in zip(source_params, target_params)
            ])

    def _do_training(self, iteration, batch):
        """Runs the operations for updating training and target ops."""
        mix_batch, mf_batch = batch

        self._training_progress.update()
        self._training_progress.set_description()

        if self._cross_grp_diff_batch:
            assert len(mix_batch) == self._num_Q_grp
            if self._real_ratio != 1:
                assert 0, "Per-group critic batches are currently not supported together with model-generated data (real_ratio != 1)"
            mix_feed_dict = [self._get_feed_dict(iteration, i) for i in mix_batch]
            single_mix_feed_dict = mix_feed_dict[0]
        else:
            mix_feed_dict = self._get_feed_dict(iteration, mix_batch)
            single_mix_feed_dict = mix_feed_dict


        if self._critic_same_as_actor:
            critic_feed_dict = mix_feed_dict
        else:
            critic_feed_dict = self._get_feed_dict(iteration, mf_batch)

        self._session.run(self._misc_training_ops, single_mix_feed_dict)
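        # Actor and critic updates run at separate frequencies (derived from
        # `_n_train_repeat` and the per-component train-repeat settings).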

        if iteration % self._actor_train_freq == 0:
            self._session.run(self._actor_training_ops, single_mix_feed_dict)
        if iteration % self._critic_train_freq == 0:
            if self._cross_grp_diff_batch:
                assert len(self._critic_training_ops) == len(critic_feed_dict)
                [
                    self._session.run(op, feed_dict)
                    for (op, feed_dict) in zip(self._critic_training_ops, critic_feed_dict)
                ]
            else:
                self._session.run(self._critic_training_ops, critic_feed_dict)

        if iteration % self._target_update_interval == 0:
            # Run target ops here.
            self._update_target()

    def _get_feed_dict(self, iteration, batch):
        """Construct TensorFlow feed_dict from sample batch."""

        feed_dict = {
            self._observations_ph: batch['observations'],
            self._actions_ph: batch['actions'],
            self._next_observations_ph: batch['next_observations'],
            self._rewards_ph: batch['rewards'],
            self._terminals_ph: batch['terminals'],
        }

        if self._store_extra_policy_info:
            feed_dict[self._log_pis_ph] = batch['log_pis']
            feed_dict[self._raw_actions_ph] = batch['raw_actions']

        if iteration is not None:
            feed_dict[self._iteration_ph] = iteration

        return feed_dict

    def get_diagnostics(self,
                        iteration,
                        batch,
                        training_paths,
                        evaluation_paths):
        """Return diagnostic information as ordered dictionary.

        Records mean and standard deviation of Q-function and state
        value function, and TD-loss (mean squared Bellman error)
        for the sample batch.

        Also calls the `draw` method of the plotter, if a plotter is defined.
        """
        mix_batch, _ = batch
        if self._cross_grp_diff_batch:
            mix_batch = mix_batch[0]
        mix_feed_dict = self._get_feed_dict(iteration, mix_batch)

        # (Q_values, Q_losses, alpha, global_step) = self._session.run(
        #     (self._Q_values,
        #      self._Q_losses,
        #      self._alpha,
        #      self.global_step),
        #     feed_dict)
        Q_values, Q_losses = self._session.run(
            [self._Q_values, self._Q_losses],
            mix_feed_dict
        )

        alpha, global_step = self._session.run(
            [self._alpha, self.global_step],
            mix_feed_dict
        )

        diagnostics = OrderedDict({
            'Q-avg': np.mean(Q_values),
            'Q-std': np.std(Q_values),
            'Q_loss': np.mean(Q_losses),
            'alpha': alpha,
        })

        policy_diagnostics = self._policy.get_diagnostics(
            mix_batch['observations'])
        diagnostics.update({
            f'policy/{key}': value
            for key, value in policy_diagnostics.items()
        })

        if self._plotter:
            self._plotter.draw()

        return diagnostics

    @property
    def tf_saveables(self):
        saveables = {
            '_policy_optimizer': self._policy_optimizer,
            **{
                f'Q_optimizer_{i}': optimizer
                for i, optimizer in enumerate(self._Q_optimizers)
            },
            '_log_alpha': self._log_alpha,
        }

        if hasattr(self, '_alpha_optimizer'):
            saveables['_alpha_optimizer'] = self._alpha_optimizer

        return saveables

    def _get_latest_index(self):
        if self._model_load_dir is None:
            return
        return max([int(i.split("_")[1].split(".")[0]) for i in os.listdir(self._model_load_dir)])
Beispiel #6
0
class MOPO(RLAlgorithm):
    """Model-based Offline Policy Optimization (MOPO)

    References
    ----------
        Tianhe Yu, Garrett Thomas, Lantao Yu, Stefano Ermon, James Zou, Sergey Levine, Chelsea Finn, Tengyu Ma. 
        MOPO: Model-based Offline Policy Optimization. 
        arXiv preprint arXiv:2005.13239. 2020.
    """
    def __init__(
        self,
        training_environment,
        evaluation_environment,
        policy,
        Qs,
        pool,
        static_fns,
        plotter=None,
        tf_summaries=False,
        lr=3e-4,
        reward_scale=1.0,
        target_entropy='auto',
        discount=0.99,
        tau=5e-3,
        target_update_interval=1,
        action_prior='uniform',
        reparameterize=False,
        store_extra_policy_info=False,
        deterministic=False,
        rollout_random=False,
        model_train_freq=250,
        num_networks=7,
        num_elites=5,
        model_retain_epochs=20,
        rollout_batch_size=100e3,
        real_ratio=0.1,
        # rollout_schedule=[20,100,1,1],
        rollout_length=1,
        hidden_dim=200,
        max_model_t=None,
        model_type='mlp',
        separate_mean_var=False,
        identity_terminal=0,
        pool_load_path='',
        pool_load_max_size=0,
        model_name=None,
        model_load_dir=None,
        penalty_coeff=0.,
        penalty_learned_var=False,
        **kwargs,
    ):
        """
        Args:
            env (`SoftlearningEnv`): Environment used for training.
            policy: A policy function approximator.
            initial_exploration_policy: ('Policy'): A policy that we use
                for initial exploration which is not trained by the algorithm.
            Qs: Q-function approximators. The min of these
                approximators will be used. Usage of at least two Q-functions
                improves performance by reducing overestimation bias.
            pool (`PoolBase`): Replay pool to add gathered samples to.
            plotter (`QFPolicyPlotter`): Plotter instance to be used for
                visualizing Q-function during training.
            lr (`float`): Learning rate used for the function approximators.
            discount (`float`): Discount factor for Q-function updates.
            tau (`float`): Soft value function target update weight.
            target_update_interval ('int'): Frequency at which target network
                updates occur in iterations.
            reparameterize ('bool'): If True, we use a gradient estimator for
                the policy derived using the reparameterization trick. We use
                a likelihood ratio based estimator otherwise.
        """

        super(MOPO, self).__init__(**kwargs)

        obs_dim = np.prod(training_environment.active_observation_shape)
        act_dim = np.prod(training_environment.action_space.shape)
        self._model_type = model_type
        self._identity_terminal = identity_terminal
        self._model = construct_model(obs_dim=obs_dim,
                                      act_dim=act_dim,
                                      hidden_dim=hidden_dim,
                                      num_networks=num_networks,
                                      num_elites=num_elites,
                                      model_type=model_type,
                                      separate_mean_var=separate_mean_var,
                                      name=model_name,
                                      load_dir=model_load_dir,
                                      deterministic=deterministic)
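        # _modelr (below) is a deterministic model trained to estimate the density ratio
        # between real data and model rollouts, and _modelc a deterministic model trained
        # on the pointwise dynamics loss; both are used in _train_model to reweight the
        # training data and to build a learned reward penalty
        # (fake_env.another_reward_model).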
        self._modelr = construct_modelr(
            obs_dim=obs_dim,
            act_dim=act_dim,
            hidden_dim=hidden_dim,
            num_networks=num_networks,
            num_elites=num_elites,
            model_type=model_type,
            separate_mean_var=separate_mean_var,
            #name=model_name,
            load_dir=model_load_dir,
            deterministic=True)
        self._modelc = construct_modelc(
            obs_dim=obs_dim,
            act_dim=act_dim,
            hidden_dim=hidden_dim,
            num_networks=num_networks,
            num_elites=num_elites,
            model_type=model_type,
            separate_mean_var=separate_mean_var,
            #name=model_name,
            load_dir=model_load_dir,
            deterministic=True)
        self._static_fns = static_fns
        self.fake_env = FakeEnv(self._model,
                                self._static_fns,
                                penalty_coeff=penalty_coeff,
                                penalty_learned_var=penalty_learned_var)

        self._rollout_schedule = [20, 100, rollout_length, rollout_length]
        self._max_model_t = max_model_t

        self._model_retain_epochs = model_retain_epochs

        self._model_train_freq = model_train_freq
        self._rollout_batch_size = int(rollout_batch_size)
        self._deterministic = deterministic
        self._rollout_random = rollout_random
        self._real_ratio = real_ratio

        self._log_dir = os.getcwd()
        self._writer = Writer(self._log_dir)

        self._training_environment = training_environment
        self._evaluation_environment = evaluation_environment
        self._policy = policy

        self._Qs = Qs
        self._Q_targets = tuple(tf.keras.models.clone_model(Q) for Q in Qs)

        self._pool = pool
        self._plotter = plotter
        self._tf_summaries = tf_summaries

        self._policy_lr = lr
        self._Q_lr = lr

        self._reward_scale = reward_scale
        self._target_entropy = (
            -np.prod(self._training_environment.action_space.shape)
            if target_entropy == 'auto' else target_entropy)
        print('[ MOPO ] Target entropy: {}'.format(self._target_entropy))

        self._discount = discount
        self._tau = tau
        self._target_update_interval = target_update_interval
        self._action_prior = action_prior

        self._reparameterize = reparameterize
        self._store_extra_policy_info = store_extra_policy_info

        observation_shape = self._training_environment.active_observation_shape
        action_shape = self._training_environment.action_space.shape

        assert len(observation_shape) == 1, observation_shape
        self._observation_shape = observation_shape
        assert len(action_shape) == 1, action_shape
        self._action_shape = action_shape

        self._build()

        #### load replay pool data
        self._pool_load_path = pool_load_path
        self._pool_load_max_size = pool_load_max_size

        loader.restore_pool(self._pool,
                            self._pool_load_path,
                            self._pool_load_max_size,
                            save_path=self._log_dir)
        self._init_pool_size = self._pool.size
        print('[ MOPO ] Starting with pool size: {}'.format(
            self._init_pool_size))
        ####

    def _build(self):
        self._training_ops = {}

        self._init_global_step()
        self._init_placeholders()
        self._init_actor_update()
        self._init_critic_update()

    def _train(self):
        """Return a generator that performs RL training.

        Args:
            env (`SoftlearningEnv`): Environment used for training.
            policy (`Policy`): Policy used for training
            initial_exploration_policy ('Policy'): Policy used for exploration
                If None, then all exploration is done using policy
            pool (`PoolBase`): Sample pool to add samples to
        """
        training_environment = self._training_environment
        evaluation_environment = self._evaluation_environment
        policy = self._policy
        pool = self._pool
        model_metrics = {}

        if not self._training_started:
            self._init_training()

        self.sampler.initialize(training_environment, policy, pool)

        gt.reset_root()
        gt.rename_root('RLAlgorithm')
        gt.set_def_unique(False)

        self._training_before_hook()

        for self._epoch in gt.timed_for(range(self._epoch, self._n_epochs)):

            if self._epoch % 200 == 0:
                #### model training
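                #### a preloaded model is only fine-tuned for a single epoch
                #### (max_epochs=1 below); otherwise it is trained to convergence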
                print('[ MOPO ] log_dir: {} | ratio: {}'.format(
                    self._log_dir, self._real_ratio))
                print(
                    '[ MOPO ] Training model at epoch {} | freq {} | timestep {} (total: {})'
                    .format(self._epoch, self._model_train_freq,
                            self._timestep, self._total_timestep))
                max_epochs = 1 if self._model.model_loaded else None
                model_train_metrics = self._train_model(
                    batch_size=256,
                    max_epochs=max_epochs,
                    holdout_ratio=0.2,
                    max_t=self._max_model_t)
                model_metrics.update(model_train_metrics)
                self._log_model()
                gt.stamp('epoch_train_model')
                ####

            self._epoch_before_hook()
            gt.stamp('epoch_before_hook')

            self._training_progress = Progress(self._epoch_length *
                                               self._n_train_repeat)
            start_samples = self.sampler._total_samples
            for timestep in count():
                self._timestep = timestep

                if (timestep >= self._epoch_length and self.ready_to_train):
                    break

                self._timestep_before_hook()
                gt.stamp('timestep_before_hook')

                ## model rollouts
                if timestep % self._model_train_freq == 0 and self._real_ratio < 1.0:
                    self._training_progress.pause()
                    self._set_rollout_length()
                    self._reallocate_model_pool()
                    model_rollout_metrics = self._rollout_model(
                        rollout_batch_size=self._rollout_batch_size,
                        deterministic=self._deterministic)
                    model_metrics.update(model_rollout_metrics)

                    gt.stamp('epoch_rollout_model')
                    self._training_progress.resume()

                ## train actor and critic
                if self.ready_to_train:
                    self._do_training_repeats(timestep=timestep)
                gt.stamp('train')

                self._timestep_after_hook()
                gt.stamp('timestep_after_hook')

            training_paths = self.sampler.get_last_n_paths(
                math.ceil(self._epoch_length / self.sampler._max_path_length))

            evaluation_paths = self._evaluation_paths(policy,
                                                      evaluation_environment)
            gt.stamp('evaluation_paths')

            if evaluation_paths:
                evaluation_metrics = self._evaluate_rollouts(
                    evaluation_paths, evaluation_environment)
                gt.stamp('evaluation_metrics')
            else:
                evaluation_metrics = {}

            gt.stamp('epoch_after_hook')

            sampler_diagnostics = self.sampler.get_diagnostics()

            diagnostics = self.get_diagnostics(
                iteration=self._total_timestep,
                batch=self._evaluation_batch(),
                training_paths=training_paths,
                evaluation_paths=evaluation_paths)

            time_diagnostics = gt.get_times().stamps.itrs

            diagnostics.update(
                OrderedDict((
                    *((f'evaluation/{key}', evaluation_metrics[key])
                      for key in sorted(evaluation_metrics.keys())),
                    *((f'times/{key}', time_diagnostics[key][-1])
                      for key in sorted(time_diagnostics.keys())),
                    *((f'sampler/{key}', sampler_diagnostics[key])
                      for key in sorted(sampler_diagnostics.keys())),
                    *((f'model/{key}', model_metrics[key])
                      for key in sorted(model_metrics.keys())),
                    ('epoch', self._epoch),
                    ('timestep', self._timestep),
                    ('timesteps_total', self._total_timestep),
                    ('train-steps', self._num_train_steps),
                )))

            if self._eval_render_mode is not None and hasattr(
                    evaluation_environment, 'render_rollouts'):
                training_environment.render_rollouts(evaluation_paths)

            ## ensure we did not collect any more data
            assert self._pool.size == self._init_pool_size

            yield diagnostics

        epi_ret = self._rollout_model_for_eval(
            self._training_environment.reset)
        np.savetxt("EEepi_ret__fin.csv", epi_ret, delimiter=',')

        self.sampler.terminate()

        self._training_after_hook()

        self._training_progress.close()

        yield {'done': True, **diagnostics}

    def train(self, *args, **kwargs):
        return self._train(*args, **kwargs)

    def _log_policy(self):
        save_path = os.path.join(self._log_dir, 'models')
        filesystem.mkdir(save_path)
        weights = self._policy.get_weights()
        data = {'policy_weights': weights}
        full_path = os.path.join(save_path,
                                 'policy_{}.pkl'.format(self._total_timestep))
        print('Saving policy to: {}'.format(full_path))
        pickle.dump(data, open(full_path, 'wb'))

    def _log_model(self):
        print('MODEL: {}'.format(self._model_type))
        if self._model_type == 'identity':
            print('[ MOPO ] Identity model, skipping save')
        elif self._model.model_loaded:
            print('[ MOPO ] Loaded model, skipping save')
        else:
            save_path = os.path.join(self._log_dir, 'models')
            filesystem.mkdir(save_path)
            print('[ MOPO ] Saving model to: {}'.format(save_path))
            self._model.save(save_path, self._total_timestep)

    def _set_rollout_length(self):
        min_epoch, max_epoch, min_length, max_length = self._rollout_schedule
        if self._epoch <= min_epoch:
            y = min_length
        else:
            dx = (self._epoch - min_epoch) / (max_epoch - min_epoch)
            dx = min(dx, 1)
            y = dx * (max_length - min_length) + min_length

        self._rollout_length = int(y)
        print(
            '[ Model Length ] Epoch: {} (min: {}, max: {}) | Length: {} (min: {} , max: {})'
            .format(self._epoch, min_epoch, max_epoch, self._rollout_length,
                    min_length, max_length))

    def _reallocate_model_pool(self):
        obs_space = self._pool._observation_space
        act_space = self._pool._action_space

        rollouts_per_epoch = self._rollout_batch_size * self._epoch_length / self._model_train_freq
        model_steps_per_epoch = int(self._rollout_length * rollouts_per_epoch)
        new_pool_size = self._model_retain_epochs * model_steps_per_epoch

        if not hasattr(self, '_model_pool'):
            print(
                '[ MOPO ] Initializing new model pool with size {:.2e}'.format(
                    new_pool_size))
            self._model_pool = SimpleReplayPool(obs_space, act_space,
                                                new_pool_size)

        elif self._model_pool._max_size != new_pool_size:
            print('[ MOPO ] Updating model pool | {:.2e} --> {:.2e}'.format(
                self._model_pool._max_size, new_pool_size))
            samples = self._model_pool.return_all_samples()
            new_pool = SimpleReplayPool(obs_space, act_space, new_pool_size)
            new_pool.add_samples(samples)
            assert self._model_pool.size == new_pool.size
            self._model_pool = new_pool

    def _train_model(self, **kwargs):

        from copy import deepcopy

        # hyperparameters
        smth = 0.1
        B_dash = 200.
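        # `smth` mixes the learned density-ratio weights with uniform weights
        # (see `actual_dr_weights` below); `B_dash` scales the penalty coefficient
        # `b_coeff` computed from the model loss.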

        env_samples = self._pool.return_all_samples()
        train_inputs_master, train_outputs_master = format_samples_for_training(
            env_samples)

        splitnum = 100

        # for debug
        permutation = np.random.permutation(
            train_inputs_master.shape[0])[:int(train_inputs_master.shape[0] /
                                               20)]
        #train_inputs_master = train_inputs_master[permutation]
        #train_outputs_master = train_outputs_master[permutation]
        #np.savetxt("train_inputs.csv",train_inputs_master,delimiter=',')
        #np.savetxt("train_outputs.csv",train_outputs_master,delimiter=',')
        if debug_data:
            np.savetxt("reward_data.csv",
                       train_outputs_master[:, :1],
                       delimiter=',')

        def compute_dr_weights():
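            # Train the ratio model on real transitions (train_inputs) vs. model rollouts
            # (fake_inputs); retry up to 3 times if the predicted weights sum to zero,
            # then normalize the weights to have mean 1.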
            for j in range(3):
                train_inputs = deepcopy(train_inputs_master)
                if 200000 < train_inputs.shape[0]:
                    np.random.shuffle(train_inputs)
                    train_inputs = train_inputs[:200000]

                fake_inputs = self._rollout_model_for_dr(
                    self._training_environment.reset, train_inputs.shape[0])

                # train ratio model
                _ = self._modelr.train(train_inputs, fake_inputs, **kwargs)

                train_inputs = deepcopy(train_inputs_master)
                #dr_weights, _ = self._modelr.predict(train_inputs)

                train_inputs_list = np.array_split(train_inputs, splitnum)
                dr_weights, _ = self._modelr.predict(train_inputs_list[0])
                for i in range(1, splitnum):
                    temp_dr_weights, _ = self._modelr.predict(
                        train_inputs_list[i])
                    dr_weights = np.concatenate([dr_weights, temp_dr_weights],
                                                0)

                if dr_weights.sum() > 0:
                    break
                else:
                    np.savetxt("dr_weights_raw_for_debug" + str(j) + ".csv",
                               dr_weights,
                               delimiter=',')
            dr_weights *= dr_weights.shape[0] / dr_weights.sum()
            return dr_weights

        if self._model_type == 'identity':
            print('[ MOPO ] Identity model, skipping model training')
            model_metrics = {}
        else:
            if self._epoch > 0:
                epi_ret = self._rollout_model_for_eval(
                    self._training_environment.reset)
                np.savetxt("epi_ret__" + str(self._epoch) + ".csv",
                           epi_ret,
                           delimiter=',')

            # compute weight
            print("training weights model for training")
            if self._epoch > 0:
                dr_weights = compute_dr_weights()
                if debug_data:
                    np.savetxt("dr_weights_raw_" + str(self._epoch) + ".csv",
                               dr_weights,
                               delimiter=',')
            else:
                dr_weights = np.ones((train_inputs_master.shape[0], 1))

            # train dynamics model
            print("training dynamics model ")
            actual_dr_weights = dr_weights * smth + (1. - smth)
            if debug_data:
                np.savetxt("dr_weights_train_" + str(self._epoch) + ".csv",
                           actual_dr_weights,
                           delimiter=',')
            if (smth > -0.01) or (self._epoch == 0):
                train_inputs = deepcopy(train_inputs_master)
                train_outputs = deepcopy(train_outputs_master)
                model_metrics = self._model.train(train_inputs, train_outputs,
                                                  actual_dr_weights, **kwargs)
                self._model_metrics_prev = model_metrics
            else:
                model_metrics = self._model_metrics_prev

            # compute weight
            print("training weights model for evaluation")
            dr_weights = compute_dr_weights()
            if debug_data:
                np.savetxt("dr_weights_eval_" + str(self._epoch) + ".csv",
                           dr_weights,
                           delimiter=',')

            # compute loss
            print("compute pointwise loss for evaluation")
            train_inputs = deepcopy(train_inputs_master)
            train_outputs = deepcopy(train_outputs_master)
            train_inputs_list = np.array_split(train_inputs, splitnum)
            train_outputs_list = np.array_split(train_outputs, splitnum)
            loss_list = self._model.get_pointwise_loss(train_inputs_list[0],
                                                       train_outputs_list[0])
            for i in range(1, splitnum):
                temp_loss_list = self._model.get_pointwise_loss(
                    train_inputs_list[i], train_outputs_list[i])
                loss_list = np.concatenate([loss_list, temp_loss_list], 0)
            np.savetxt("loss_list" + str(self._epoch) + ".csv",
                       loss_list,
                       delimiter=',')
            losses = np.mean(loss_list * dr_weights)
            loss_min = np.min(loss_list)

            # compute penalty coefficient from the weighted mean loss and the minimum pointwise loss
            b_coeff = 0.5 * B_dash / np.sqrt(losses - loss_min)
            np.savetxt("Lloss_bcoeff_minloss_Bdash_smth" + str(self._epoch) +
                       ".csv",
                       np.array([losses, b_coeff, loss_min, B_dash, smth]),
                       delimiter=',')

            # loss model (_modelc): predicts the pointwise dynamics loss used as a penalty
            print("training loss model")

            #loss_list = ( b_coeff * (1.-self._discount) * (loss_list - loss_min) )

            loss_list = (loss_list - loss_min)
            #loss_list = - ( b_coeff * (1.-self._discount) * (loss_list) - train_outputs_master[:,:1] )

            loss_list2 = loss_list.reshape(loss_list.shape[0], )
            #q25_q75 = np.percentile(loss_list2, q=[25, 75])
            #iqr = q25_q75[1] - q25_q75[0]
            #cutoff_low = q25_q75[0] - iqr*1.5
            #cutoff_high = q25_q75[1] + iqr*1.5
            #idx = np.where((loss_list2 > cutoff_low) & (loss_list2 < cutoff_high))

            # Standard-deviation method for outlier detection (the resulting `idx` is
            # currently unused; see the commented-out selections below)
            loss_mean = np.mean(loss_list2)
            loss_std = np.std(loss_list2)
            cutoff_low = loss_mean - loss_std * 2.
            cutoff_high = loss_mean + loss_std * 2.
            idx = np.where((loss_list2 > cutoff_low)
                           & (loss_list2 < cutoff_high))

            if debug_data:
                np.savetxt("c_" + str(self._epoch) + ".csv",
                           loss_list,
                           delimiter=',')
            #loss_list = loss_list[idx]
            train_inputs = deepcopy(train_inputs_master)  #[idx]
            actual_dr_weights = (dr_weights * 0. + 1.)  #[idx]
            self._modelc.train(train_inputs, loss_list, actual_dr_weights,
                               **kwargs)

            train_inputs = deepcopy(train_inputs_master)
            train_inputs_list = np.array_split(train_inputs, splitnum)
            penalty_np, _ = self._modelc.predict(train_inputs_list[0])
            for i in range(1, splitnum):
                temp_penalty_np, _ = self._modelc.predict(train_inputs_list[i])
                #print("temp_penalty_np",temp_penalty_np.shape)
                penalty_np = np.concatenate([penalty_np, temp_penalty_np], 0)
            if debug_data:
                np.savetxt("Ppredict_BC_" + str(self._epoch) + ".csv",
                           penalty_np * (1. - self._discount),
                           delimiter=',')

            # implement loss model
            self.fake_env.another_reward_model = self._modelc
            self.fake_env.coeff = b_coeff * (1. - self._discount)

        return model_metrics

    def _rollout_model(self, rollout_batch_size, **kwargs):
        print(
            '[ Model Rollout ] Starting | Epoch: {} | Rollout length: {} | Batch size: {} | Type: {}'
            .format(self._epoch, self._rollout_length, rollout_batch_size,
                    self._model_type))
        batch = self.sampler.random_batch(rollout_batch_size)
        obs = batch['observations']
        steps_added = []
        for i in range(self._rollout_length):
            if not self._rollout_random:
                act = self._policy.actions_np(obs)
            else:
                act_ = self._policy.actions_np(obs)
                act = np.random.uniform(low=-1, high=1, size=act_.shape)

            if self._model_type == 'identity':
                next_obs = obs
                rew = np.zeros((len(obs), 1))
                term = (np.ones(
                    (len(obs), 1)) * self._identity_terminal).astype(bool)
                info = {}
            else:
                next_obs, rew, term, info = self.fake_env.step(
                    obs, act, **kwargs)
            steps_added.append(len(obs))
            print("rew_min, rew_mean, rew_max, ", np.min(rew), np.mean(rew),
                  np.max(rew))
            print("pen_min, pen_mean, pen_max, ", np.min(info['penalty']),
                  np.mean(info['penalty']), np.max(info['penalty']))

            samples = {
                'observations': obs,
                'actions': act,
                'next_observations': next_obs,
                'rewards': rew,
                'terminals': term
            }
            self._model_pool.add_samples(samples)

            nonterm_mask = ~term.squeeze(-1)
            if nonterm_mask.sum() == 0:
                print('[ Model Rollout ] Breaking early: {} | {} / {}'.format(
                    i, nonterm_mask.sum(), nonterm_mask.shape))
                break
            obs = next_obs[nonterm_mask]

        mean_rollout_length = sum(steps_added) / rollout_batch_size
        rollout_stats = {'mean_rollout_length': mean_rollout_length}
        print(
            '[ Model Rollout ] Added: {:.1e} | Model pool: {:.1e} (max {:.1e}) | Length: {} | Train rep: {}'
            .format(sum(steps_added), self._model_pool.size,
                    self._model_pool._max_size, mean_rollout_length,
                    self._n_train_repeat))
        return rollout_stats

    def _rollout_model_for_dr(self, env_reset, data_num, **kwargs):

        from copy import deepcopy
        print("rollout for ratio estimation")
        ob = np.array([env_reset()['observations'] for i in range(20)])
        ac = self._policy.actions_np(ob)
        ob_store = deepcopy(ob)
        ac_store = deepcopy(ac)
        total_ob_store = None
        total_ac_store = None
        while True:
            overflowFlag = False
            while True:
                ob, rew, term, info = self.fake_env.step(ob, ac, **kwargs)
                nonterm_mask = ~term.squeeze(-1)
                ob = ob[nonterm_mask]
                temp_rand = np.random.rand(ob.shape[0])
                ob = ob[np.where(temp_rand < self._discount)]
                if ob.shape[0] == 0:
                    break
                if np.count_nonzero(np.isnan(ob)) > 0 or (
                        np.nanmax(ob) > 1.e20) or (np.nanmin(ob) < -1.e20):
                    overflowFlag = True
                    break
                ac = self._policy.actions_np(ob)
                if np.count_nonzero(np.isnan(ac)) > 0 or (
                        np.nanmax(ac) > 1.e20) or (np.nanmin(ac) < -1.e20):
                    overflowFlag = True
                    break
                ob_store = np.concatenate([ob_store, ob])
                ac_store = np.concatenate([ac_store, ac])

            if overflowFlag is False:
                if total_ob_store is None:
                    total_ob_store = deepcopy(ob_store)
                    total_ac_store = deepcopy(ac_store)
                else:
                    total_ob_store = np.concatenate([total_ob_store, ob_store])
                    total_ac_store = np.concatenate([total_ac_store, ac_store])

            if total_ob_store is not None:
                print("total_ob_store.shape[0]", total_ob_store.shape[0],
                      "overflowFlag", overflowFlag)
                if total_ob_store.shape[0] > data_num:
                    break

            ob = np.array([env_reset()['observations'] for i in range(20)])
            ac = self._policy.actions_np(ob)
            ob_store = deepcopy(ob)
            ac_store = deepcopy(ac)

        ret_data = np.concatenate([total_ob_store, total_ac_store], axis=1)
        np.random.shuffle(ret_data)
        return ret_data[:int(data_num)]

    def _rollout_model_for_eval(self, env_reset, **kwargs):

        epi_ret_list = []
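        # Estimate the policy's discounted (unpenalized) return under the learned model
        # from 20 rollouts of at most 1000 steps each.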
        for j in range(20):
            ob = env_reset()['observations']
            ob = ob.reshape(1, ob.shape[0])
            temp_epi_ret = 0.
            temp_gamma = 1.
            for i in range(1000):
                ac = self._policy.actions_np(ob)
                ob, rew, term, info = self.fake_env.step(ob, ac, **kwargs)
                #temp_epi_ret += temp_gamma*rew
                temp_epi_ret += temp_gamma * info['unpenalized_rewards']
                temp_gamma *= self._discount
                if (True in term[0]):
                    break
                #if np.random.rand()>self._discount:
                #    break
            print("term", term, ", epi_ret", temp_epi_ret[0][0], ", last_rew",
                  rew, ", unpenalized_rewards", info['unpenalized_rewards'],
                  ", penalty", info['penalty'])
            #print("mean",info['mean'])
            if not np.isnan(temp_epi_ret[0][0]):
                epi_ret_list.append(temp_epi_ret[0][0])
        #return sum(epi_ret_list)/len(epi_ret_list)
        return np.array(epi_ret_list)

    def _visualize_model(self, env, timestep):
        ## save env state
        state = env.unwrapped.state_vector()
        qpos_dim = len(env.unwrapped.sim.data.qpos)
        qpos = state[:qpos_dim]
        qvel = state[qpos_dim:]

        print('[ Visualization ] Starting | Epoch {} | Log dir: {}\n'.format(
            self._epoch, self._log_dir))
        visualize_policy(env, self.fake_env, self._policy, self._writer,
                         timestep)
        print('[ Visualization ] Done')
        ## set env state
        env.unwrapped.set_state(qpos, qvel)

    def _training_batch(self, batch_size=None):
        batch_size = batch_size or self.sampler._batch_size
        env_batch_size = int(batch_size * self._real_ratio)
        model_batch_size = batch_size - env_batch_size
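        # e.g. with batch_size = 256 and real_ratio = 0.1 this yields
        # int(256 * 0.1) = 25 real-env samples and 231 model samples.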

        ## can sample from the env pool even if env_batch_size == 0
        env_batch = self._pool.random_batch(env_batch_size)

        if model_batch_size > 0:
            model_batch = self._model_pool.random_batch(model_batch_size)

            # keys = env_batch.keys()
            keys = set(env_batch.keys()) & set(model_batch.keys())
            batch = {
                k: np.concatenate((env_batch[k], model_batch[k]), axis=0)
                for k in keys
            }
        else:
            ## if real_ratio == 1.0, no model pool was ever allocated,
            ## so skip the model pool sampling
            batch = env_batch
        return batch

    def _init_global_step(self):
        self.global_step = training_util.get_or_create_global_step()
        self._training_ops.update(
            {'increment_global_step': training_util._increment_global_step(1)})

    def _init_placeholders(self):
        """Create input placeholders for the SAC algorithm.

        Creates `tf.placeholder`s for:
            - observation
            - next observation
            - action
            - reward
            - terminals
        """
        self._iteration_ph = tf.placeholder(tf.int64,
                                            shape=None,
                                            name='iteration')

        self._observations_ph = tf.placeholder(
            tf.float32,
            shape=(None, *self._observation_shape),
            name='observation',
        )

        self._next_observations_ph = tf.placeholder(
            tf.float32,
            shape=(None, *self._observation_shape),
            name='next_observation',
        )

        self._actions_ph = tf.placeholder(
            tf.float32,
            shape=(None, *self._action_shape),
            name='actions',
        )

        self._rewards_ph = tf.placeholder(
            tf.float32,
            shape=(None, 1),
            name='rewards',
        )

        self._terminals_ph = tf.placeholder(
            tf.float32,
            shape=(None, 1),
            name='terminals',
        )

        if self._store_extra_policy_info:
            self._log_pis_ph = tf.placeholder(
                tf.float32,
                shape=(None, 1),
                name='log_pis',
            )
            self._raw_actions_ph = tf.placeholder(
                tf.float32,
                shape=(None, *self._action_shape),
                name='raw_actions',
            )

    def _get_Q_target(self):
        next_actions = self._policy.actions([self._next_observations_ph])
        next_log_pis = self._policy.log_pis([self._next_observations_ph],
                                            next_actions)

        next_Qs_values = tuple(
            Q([self._next_observations_ph, next_actions])
            for Q in self._Q_targets)

        min_next_Q = tf.reduce_min(next_Qs_values, axis=0)
        next_value = min_next_Q - self._alpha * next_log_pis
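        # Assuming td_target computes reward + discount * next_value, the
        # resulting soft Bellman target is
        #   Q_target = reward_scale * r
        #              + gamma * (1 - done) * (min_i Q_i'(s', a') - alpha * log pi(a'|s')).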

        Q_target = td_target(reward=self._reward_scale * self._rewards_ph,
                             discount=self._discount,
                             next_value=(1 - self._terminals_ph) * next_value)

        return Q_target

    def _init_critic_update(self):
        """Create minimization operation for critic Q-function.

        Creates a `tf.optimizer.minimize` operation for updating
        critic Q-function with gradient descent, and appends it to
        `self._training_ops` attribute.
        """
        Q_target = tf.stop_gradient(self._get_Q_target())

        assert Q_target.shape.as_list() == [None, 1]

        Q_values = self._Q_values = tuple(
            Q([self._observations_ph, self._actions_ph]) for Q in self._Qs)

        Q_losses = self._Q_losses = tuple(
            tf.losses.mean_squared_error(
                labels=Q_target, predictions=Q_value, weights=0.5)
            for Q_value in Q_values)

        self._Q_optimizers = tuple(
            tf.train.AdamOptimizer(learning_rate=self._Q_lr,
                                   name='{}_{}_optimizer'.format(Q._name, i))
            for i, Q in enumerate(self._Qs))
        Q_training_ops = tuple(
            tf.contrib.layers.optimize_loss(Q_loss,
                                            self.global_step,
                                            learning_rate=self._Q_lr,
                                            optimizer=Q_optimizer,
                                            variables=Q.trainable_variables,
                                            increment_global_step=False,
                                            summaries=((
                                                "loss", "gradients",
                                                "gradient_norm",
                                                "global_gradient_norm"
                                            ) if self._tf_summaries else ()))
            for i, (Q, Q_loss, Q_optimizer) in enumerate(
                zip(self._Qs, Q_losses, self._Q_optimizers)))

        self._training_ops.update({'Q': tf.group(Q_training_ops)})

    def _init_actor_update(self):
        """Create minimization operations for policy and entropy.

        Creates a `tf.optimizer.minimize` operations for updating
        policy and entropy with gradient descent, and adds them to
        `self._training_ops` attribute.
        """

        actions = self._policy.actions([self._observations_ph])
        log_pis = self._policy.log_pis([self._observations_ph], actions)

        assert log_pis.shape.as_list() == [None, 1]

        log_alpha = self._log_alpha = tf.get_variable('log_alpha',
                                                      dtype=tf.float32,
                                                      initializer=0.0)
        alpha = tf.exp(log_alpha)
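        # SAC-style automatic temperature tuning: alpha is adapted so the policy
        # entropy tracks the target entropy by minimizing
        #   J(alpha) = -E[ log_alpha * stop_grad(log pi(a|s) + H_target) ].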

        if isinstance(self._target_entropy, Number):
            alpha_loss = -tf.reduce_mean(
                log_alpha * tf.stop_gradient(log_pis + self._target_entropy))

            self._alpha_optimizer = tf.train.AdamOptimizer(
                self._policy_lr, name='alpha_optimizer')
            self._alpha_train_op = self._alpha_optimizer.minimize(
                loss=alpha_loss, var_list=[log_alpha])

            self._training_ops.update(
                {'temperature_alpha': self._alpha_train_op})

        self._alpha = alpha

        if self._action_prior == 'normal':
            policy_prior = tf.contrib.distributions.MultivariateNormalDiag(
                loc=tf.zeros(self._action_shape),
                scale_diag=tf.ones(self._action_shape))
            policy_prior_log_probs = policy_prior.log_prob(actions)
        elif self._action_prior == 'uniform':
            policy_prior_log_probs = 0.0

        Q_log_targets = tuple(
            Q([self._observations_ph, actions]) for Q in self._Qs)
        min_Q_log_target = tf.reduce_min(Q_log_targets, axis=0)

        if self._reparameterize:
            policy_kl_losses = (alpha * log_pis - min_Q_log_target -
                                policy_prior_log_probs)
        else:
            raise NotImplementedError

        assert policy_kl_losses.shape.as_list() == [None, 1]

        policy_loss = tf.reduce_mean(policy_kl_losses)

        self._policy_optimizer = tf.train.AdamOptimizer(
            learning_rate=self._policy_lr, name="policy_optimizer")
        policy_train_op = tf.contrib.layers.optimize_loss(
            policy_loss,
            self.global_step,
            learning_rate=self._policy_lr,
            optimizer=self._policy_optimizer,
            variables=self._policy.trainable_variables,
            increment_global_step=False,
            summaries=("loss", "gradients", "gradient_norm",
                       "global_gradient_norm") if self._tf_summaries else ())

        self._training_ops.update({'policy_train_op': policy_train_op})

    def _init_training(self):
        self._update_target(tau=1.0)

    def _update_target(self, tau=None):
        tau = tau or self._tau
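        # Polyak averaging of the target Q-networks:
        #   theta_target <- tau * theta + (1 - tau) * theta_target
        # (called with tau=1.0 from _init_training to copy the weights exactly).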

        for Q, Q_target in zip(self._Qs, self._Q_targets):
            source_params = Q.get_weights()
            target_params = Q_target.get_weights()
            Q_target.set_weights([
                tau * source + (1.0 - tau) * target
                for source, target in zip(source_params, target_params)
            ])

    def _do_training(self, iteration, batch):
        """Runs the operations for updating training and target ops."""

        self._training_progress.update()
        self._training_progress.set_description()

        feed_dict = self._get_feed_dict(iteration, batch)

        self._session.run(self._training_ops, feed_dict)

        if iteration % self._target_update_interval == 0:
            # Run target ops here.
            self._update_target()

    def _get_feed_dict(self, iteration, batch):
        """Construct TensorFlow feed_dict from sample batch."""

        feed_dict = {
            self._observations_ph: batch['observations'],
            self._actions_ph: batch['actions'],
            self._next_observations_ph: batch['next_observations'],
            self._rewards_ph: batch['rewards'],
            self._terminals_ph: batch['terminals'],
        }

        if self._store_extra_policy_info:
            feed_dict[self._log_pis_ph] = batch['log_pis']
            feed_dict[self._raw_actions_ph] = batch['raw_actions']

        if iteration is not None:
            feed_dict[self._iteration_ph] = iteration

        return feed_dict

    def get_diagnostics(self, iteration, batch, training_paths,
                        evaluation_paths):
        """Return diagnostic information as ordered dictionary.

        Records mean and standard deviation of Q-function and state
        value function, and TD-loss (mean squared Bellman error)
        for the sample batch.

        Also calls the `draw` method of the plotter, if plotter defined.
        """

        feed_dict = self._get_feed_dict(iteration, batch)

        (Q_values, Q_losses, alpha, global_step) = self._session.run(
            (self._Q_values, self._Q_losses, self._alpha, self.global_step),
            feed_dict)

        diagnostics = OrderedDict({
            'Q-avg': np.mean(Q_values),
            'Q-std': np.std(Q_values),
            'Q_loss': np.mean(Q_losses),
            'alpha': alpha,
        })

        policy_diagnostics = self._policy.get_diagnostics(
            batch['observations'])
        diagnostics.update({
            f'policy/{key}': value
            for key, value in policy_diagnostics.items()
        })

        if self._plotter:
            self._plotter.draw()

        return diagnostics

    @property
    def tf_saveables(self):
        saveables = {
            '_policy_optimizer': self._policy_optimizer,
            **{
                f'Q_optimizer_{i}': optimizer
                for i, optimizer in enumerate(self._Q_optimizers)
            },
            '_log_alpha': self._log_alpha,
        }

        if hasattr(self, '_alpha_optimizer'):
            saveables['_alpha_optimizer'] = self._alpha_optimizer

        return saveables
Beispiel #7
0
class BMOPO(RLAlgorithm):
    def __init__(
        self,
        training_environment,
        evaluation_environment,
        policy,
        Qs,
        pool,  # used to store env samples
        static_fns,
        log_file=None,
        plotter=None,  # not used, can be used to draw Q function
        tf_summaries=False,  # not used here
        lr=3e-4,
        reward_scale=1.0,
        target_entropy='auto',
        discount=0.99,  # rewards discount
        tau=5e-3,  # ratio when updating the target_Q function.
        target_update_interval=1,  # Frequency at which target network updates occur in iterations.
        action_prior='uniform',
        reparameterize=False,  # If True, we use a gradient estimator for the policy derived using the reparameterization trick. We use a likelihood ratio based estimator otherwise.
        store_extra_policy_info=False,
        deterministic=False,  # whether to use deterministic model (mean) when unrolling model to collect samples
        model_train_freq=250,  # frequency (in timesteps) at which the model is retrained and unrolled
        num_networks=7,  # number of networks in the model ensemble
        num_elites=5,  # number of elite members selected from the ensemble
        model_retain_epochs=20,
        rollout_batch_size=100e3,  # batch size when unrolling the model; the total number of collected samples is batch * length
        real_ratio=0.1,  # The ratio of env_data/model_data when feeding training data
        forward_rollout_schedule=None,  # forward rollout schedule: [min_epoch, max_epoch, min_length, max_length]
        backward_rollout_schedule=None,
        last_n_epoch=10,  # TODO: how is this used? data from the last n epochs is collected for training the backward policy
        backward_policy_var=0,  # A fixed var for backward policy
        hidden_dim=200,  # hidden_dim of the nn model
        max_model_t=None,  # the maximal training time

        # TODO: these arguments are not in BMPO but come from MOPO
        pool_load_path='',  # the path to load the d4rl dataset from
        pool_load_max_size=0,  # the max size of loaded data
        # The penalty terms used for the offline setting
        separate_mean_var=False,  # TODO: use this later
        penalty_coeff=0.,
        penalty_learned_var=False,
        **kwargs,
    ):
        """
        Args:
            env (`SoftlearningEnv`): Environment used for training.
            policy: A policy function approximator.
            initial_exploration_policy: ('Policy'): A policy that we use
                for initial exploration which is not trained by the algorithm.
            Qs: Q-function approximators. The min of these
                approximators will be used. Usage of at least two Q-functions
                improves performance by reducing overestimation bias.
            pool (`PoolBase`): Replay pool to add gathered samples to.
            plotter (`QFPolicyPlotter`): Plotter instance to be used for
                visualizing Q-function during training.
            lr (`float`): Learning rate used for the function approximators.
            discount (`float`): Discount factor for Q-function updates.
            tau (`float`): Soft value function target update weight.
            target_update_interval ('int'): Frequency at which target network
                updates occur in iterations.
            reparameterize ('bool'): If True, we use a gradient estimator for
                the policy derived using the reparameterization trick. We use
                a likelihood ratio based estimator otherwise.
        """

        super(BMOPO, self).__init__(**kwargs)

        obs_dim = np.prod(training_environment.observation_space.shape)
        act_dim = np.prod(training_environment.action_space.shape)
        self._obs_dim = obs_dim
        self._act_dim = act_dim
        self._forward_model = construct_forward_model(
            obs_dim=obs_dim,
            act_dim=act_dim,
            hidden_dim=hidden_dim,
            num_networks=num_networks,
            num_elites=num_elites)
        self._backward_model = construct_backward_model(
            obs_dim=obs_dim,
            act_dim=act_dim,
            hidden_dim=hidden_dim,
            num_networks=num_networks,
            num_elites=num_elites)
        self._static_fns = static_fns
        self.f_fake_env = Forward_FakeEnv(
            self._forward_model,
            self._static_fns,
            penalty_coeff=penalty_coeff,
            penalty_learned_var=penalty_learned_var)
        self.b_fake_env = Backward_FakeEnv(
            self._backward_model,
            self._static_fns,
            penalty_coeff=penalty_coeff,
            penalty_learned_var=penalty_learned_var)
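        # Presumably the fake envs apply the MOPO-style reward penalty,
        # something like r_hat = r_model - penalty_coeff * model_uncertainty;
        # the exact form is defined inside Forward_FakeEnv / Backward_FakeEnv.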

        self._forward_rollout_schedule = forward_rollout_schedule
        self._backward_rollout_schedule = backward_rollout_schedule
        self._max_model_t = max_model_t

        self._model_retain_epochs = model_retain_epochs

        self._model_train_freq = model_train_freq
        self._rollout_batch_size = int(rollout_batch_size)
        self._deterministic = deterministic
        self._real_ratio = real_ratio

        self._log_dir = os.getcwd()

        self._training_environment = training_environment
        self._evaluation_environment = evaluation_environment
        self._policy = policy

        self._Qs = Qs
        self._Q_targets = tuple(tf.keras.models.clone_model(Q) for Q in Qs)

        self._pool = pool
        self._last_n_epoch = int(last_n_epoch)
        self._backward_policy_var = backward_policy_var

        self._plotter = plotter
        self._tf_summaries = tf_summaries

        self._policy_lr = lr
        self._Q_lr = lr

        self._reward_scale = reward_scale
        self._target_entropy = (
            -np.prod(self._training_environment.action_space.shape)
            if target_entropy == 'auto' else target_entropy)
        print('Target entropy: {}'.format(self._target_entropy))

        self._discount = discount
        self._tau = tau
        self._target_update_interval = target_update_interval
        self._action_prior = action_prior

        self._reparameterize = reparameterize
        self._store_extra_policy_info = store_extra_policy_info

        observation_shape = self._training_environment.active_observation_shape
        action_shape = self._training_environment.action_space.shape

        assert len(observation_shape) == 1, observation_shape
        self._observation_shape = observation_shape
        assert len(action_shape) == 1, action_shape
        self._action_shape = action_shape
        self.log_file = log_file

        self._build()

        ## Load replay pool data (load d4rl dataset) TODO: add pool_load_path, pool_load_max_size into params
        self._pool_load_path = pool_load_path
        self._pool_load_max_size = pool_load_max_size
        # load samples from d4rl
        restore_pool(self._pool,
                     self._pool_load_path,
                     self._pool_load_max_size,
                     save_path=self._log_dir)
        self._init_pool_size = self._pool.size
        print('[ BMOPO ] Starting with pool size: {}'.format(
            self._init_pool_size))

    def _build(self):
        self._training_ops = {}

        self._init_global_step()
        self._init_placeholders()
        self._init_actor_update()
        self._init_critic_update()
        self._build_backward_policy(self._act_dim)

    def _build_backward_policy(self, act_dim):
        self._max_logvar = tf.Variable(np.ones([1, act_dim]),
                                       dtype=tf.float32,
                                       name="max_log_var")
        self._min_logvar = tf.Variable(-np.ones([1, act_dim]) * 10.,
                                       dtype=tf.float32,
                                       name="min_log_var")
        self._before_action_mean, self._before_action_logvar = self._backward_policy_net(
            'backward_policy', self._next_observations_ph, act_dim)
        action_logvar = self._max_logvar - tf.nn.softplus(
            self._max_logvar - self._before_action_logvar)
        action_logvar = self._min_logvar + tf.nn.softplus(action_logvar -
                                                          self._min_logvar)
        self._before_action_var = tf.exp(action_logvar)
        self._backward_policy_params = tf.get_collection(
            tf.GraphKeys.GLOBAL_VARIABLES, scope='backward_policy')
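        # The softplus bounds above softly clamp the predicted log-variance to
        # [min_logvar, max_logvar]. loss1 + loss2 below is the Gaussian negative
        # log-likelihood of the stored actions (up to constants):
        #   E[(mu - a)^2 / sigma^2] + E[log sigma^2].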
        loss1 = tf.reduce_mean(
            tf.square(self._before_action_mean - self._actions_ph) /
            self._before_action_var)
        loss2 = tf.reduce_mean(tf.log(self._before_action_var))
        self._backward_policy_loss = loss1 + loss2
        self._backward_policy_optimizer = tf.train.AdamOptimizer(
            self._policy_lr).minimize(loss=self._backward_policy_loss,
                                      var_list=self._backward_policy_params)

    def _backward_policy_net(self, scope, state, action_dim, hidden_dim=256):
        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            hidden_layer1 = tf.layers.dense(state, hidden_dim, tf.nn.relu)
            hidden_layer2 = tf.layers.dense(hidden_layer1, hidden_dim,
                                            tf.nn.relu)
            return tf.tanh(tf.layers.dense(hidden_layer2, action_dim)), \
                   tf.layers.dense(hidden_layer2, action_dim)

    def _get_before_action(self, obs):
        # the backward policy
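        # Sample a ~ mean + sqrt(var) * z with z drawn from a standard normal
        # truncated to [-2, 2], then clip to the action range [-1, 1]. A nonzero
        # self._backward_policy_var overrides the learned variance.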
        before_action_mean, before_action_var = self._session.run(
            [self._before_action_mean, self._before_action_var],
            feed_dict={self._next_observations_ph: obs})
        if (self._backward_policy_var != 0):
            before_action_var = self._backward_policy_var
        X = stats.truncnorm(-2,
                            2,
                            loc=np.zeros_like(before_action_mean),
                            scale=np.ones_like(before_action_mean))
        before_actions = X.rvs(size=np.shape(before_action_mean)) * np.sqrt(
            before_action_var
        ) + before_action_mean  # sample from backward policy
        act = np.clip(before_actions, -1, 1)
        return act

    def _train(self):

        training_environment = self._training_environment
        evaluation_environment = self._evaluation_environment
        policy = self._policy
        pool = self._pool  # init the pool for env data (to store d4rl in offline setting)
        # to store the model metrics, for logging
        f_model_metrics, b_model_metrics = {}, {}

        if not self._training_started:
            self._init_training()

        self.sampler.initialize(training_environment, policy, pool)

        self._training_before_hook()

        max_epochs = 1 if (self._forward_model.model_loaded
                           and self._backward_model.model_loaded) else None
        # max_epochs = 2  # TODO: CHANGE IT
        # train the forward and the backward dynamic model
        f_model_train_metrics, b_model_train_metrics = self._train_model(
            batch_size=256,
            max_epochs=max_epochs,
            holdout_ratio=0.2,
            max_t=self._max_model_t)
        f_model_metrics.update(f_model_train_metrics)
        b_model_metrics.update(b_model_train_metrics)

        self._log_model()  # print and save model

        # collect samples via unrolling models and use them to update the policy
        for self._epoch in range(
                self._epoch, self._n_epochs
        ):  # self._n_epochs = n_epochs, and it's initialised in RLAlgo

            self._epoch_before_hook()

            start_samples = self.sampler._total_samples
            print('------- epoch: {} --------'.format(self._epoch),
                  "start_samples", start_samples)
            print('[ True Env Buffer Size ]', pool.size)

            for timestep in count():
                # samples_now = self.sampler._total_samples
                # self._timestep = samples_now - start_samples
                self._timestep = timestep

                # if samples_now >= start_samples + self._epoch_length and self.ready_to_train:
                if timestep >= self._epoch_length:
                    break

                self._timestep_before_hook()

                # Rollout the dynamic model to collect more data every self._model_train_freq (1000)
                if self._timestep % self._model_train_freq == 0:

                    # set rollout lengths for the backward and forward models
                    # according to self._backward/_forward_rollout_schedule
                    self._set_rollout_length()
                    # (re)allocate the data pool used to store samples from
                    # unrolling the dynamic model
                    self._reallocate_model_pool()
                    f_model_rollout_metrics, b_model_rollout_metrics = self._rollout_model(
                        rollout_batch_size=self._rollout_batch_size,
                        deterministic=self._deterministic)
                    f_model_metrics.update(f_model_rollout_metrics)
                    b_model_metrics.update(b_model_rollout_metrics)
                # Train the actor and the critic (forward and backward).
                if self.ready_to_train:
                    # For BMOPO this needs to train both the forward and the backward policy.
                    self._do_training_repeats(timestep=self._total_timestep)

                self._timestep_after_hook()

            training_paths = self.sampler.get_last_n_paths(
                math.ceil(self._epoch_length / self.sampler._max_path_length)
            )  # TODO: this seems wrong: param == 1, check the MOPO code
            evaluation_paths = self._evaluation_paths(policy,
                                                      evaluation_environment)

            # training_metrics = self._evaluate_rollouts(
            #     training_paths, training_environment)
            if evaluation_paths:
                evaluation_metrics = self._evaluate_rollouts(
                    evaluation_paths, evaluation_environment)
            else:
                evaluation_metrics = {}

            sampler_diagnostics = self.sampler.get_diagnostics()

            diagnostics = self.get_diagnostics(
                iteration=self._total_timestep,
                batch=self._evaluation_batch(),
                training_paths=training_paths,
                evaluation_paths=evaluation_paths)

            diagnostics.update(
                OrderedDict((
                    *((f'evaluation/{key}', evaluation_metrics[key])
                      for key in sorted(evaluation_metrics.keys())),
                    *((f'sampler/{key}', sampler_diagnostics[key])
                      for key in sorted(sampler_diagnostics.keys())),
                    *((f'forward-model/{key}', f_model_metrics[key])
                      for key in sorted(f_model_metrics.keys())),
                    *((f'backward-model/{key}', b_model_metrics[key])
                      for key in sorted(b_model_metrics.keys())),
                    ('epoch', self._epoch),
                    ('timestep', self._timestep),
                    ('timesteps_total', self._total_timestep),
                    ('train-steps', self._num_train_steps),
                )))

            # logging via wandb
            wandb.log(diagnostics)

            if self._eval_render_mode is not None and hasattr(
                    evaluation_environment, 'render_rollouts'):
                training_environment.render_rollouts(evaluation_paths)
            print(diagnostics)
            # f_log = open(self.log_file, 'a')
            # f_log.write('epoch: %d\n' % self._epoch)
            # f_log.write('total time steps: %d\n' % self._total_timestep)
            # f_log.write('evaluation return: %f\n' % evaluation_metrics['return-average'])
            # f_log.close()

        self.sampler.terminate()

        self._training_after_hook()

    def train(self, *args, **kwargs):
        return self._train(*args, **kwargs)

    def _log_policy(self):
        print('--------- log policy is passed ---------')
        pass

    def _log_model(self):
        if self._forward_model.model_loaded and self._backward_model.model_loaded:
            print('[ BMOPO ] Loaded model, skipping save')
        else:
            save_path = os.path.join(self._log_dir, 'models')
            mkdir(save_path)
            print('[ BMOPO ] Saving model to: {}'.format(save_path))

            # TODO: check the save function
            self._forward_model.save(save_path, self._total_timestep)
            self._backward_model.save(save_path, self._total_timestep)

    def _set_rollout_length(self):
        # set the rollout_length according to self._backward_rollout_schedule and self._forward_rollout_schedule.
        # the format of the rollout_schedule is [min_epoch, max_epoch, min_length, max_length]
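        # The length is interpolated linearly between min_length and max_length
        # as the epoch moves from min_epoch to max_epoch. For example, with a
        # schedule of [20, 100, 1, 15] at epoch 60:
        #   dx = (60 - 20) / (100 - 20) = 0.5  ->  length = int(0.5 * (15 - 1) + 1) = 8.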

        #set backward rollout length
        min_epoch, max_epoch, min_length, max_length = self._backward_rollout_schedule
        if self._epoch <= min_epoch:
            y = min_length
        else:
            dx = (self._epoch - min_epoch) / (max_epoch - min_epoch)
            dx = min(dx, 1)
            y = dx * (max_length - min_length) + min_length

        self._backward_rollout_length = int(y)
        print(
            '[ Set Backward Model Length ] Epoch: {} (min: {}, max: {}) | Length: {} (min: {} , max: {})'
            .format(self._epoch, min_epoch, max_epoch,
                    self._backward_rollout_length, min_length, max_length))
        # set forward rollout length
        min_epoch, max_epoch, min_length, max_length = self._forward_rollout_schedule
        if self._epoch <= min_epoch:
            y = min_length
        else:
            dx = (self._epoch - min_epoch) / (max_epoch - min_epoch)
            dx = min(dx, 1)
            y = dx * (max_length - min_length) + min_length

        self._forward_rollout_length = int(y)
        print(
            '[ Set Forward Model Length ] Epoch: {} (min: {}, max: {}) | Length: {} (min: {} , max: {})'
            .format(self._epoch, min_epoch, max_epoch,
                    self._forward_rollout_length, min_length, max_length))

    def _reallocate_model_pool(self):
        # init the data pool used to store samples from unrolling the dynamic model
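        # Pool size = model_retain_epochs * int(total_rollout_length * rollouts_per_epoch).
        # For example, with model_retain_epochs = 20, rollout_batch_size = 1e5,
        # epoch_length = 1000, model_train_freq = 250 and a combined rollout
        # length of 5, this gives 20 * int(5 * (1e5 * 1000 / 250)) = 4e7 transitions.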
        obs_space = self._pool._observation_space
        act_space = self._pool._action_space

        rollouts_per_epoch = self._rollout_batch_size * self._epoch_length / self._model_train_freq
        model_steps_per_epoch = int(
            (self._forward_rollout_length + self._backward_rollout_length) *
            rollouts_per_epoch)
        new_pool_size = self._model_retain_epochs * model_steps_per_epoch

        if not hasattr(self, '_model_pool'):
            print(
                '[ Allocate Model Pool ] Initializing new model pool with size {:.2e}'
                .format(new_pool_size))
            self._model_pool = SimpleReplayPool(obs_space, act_space,
                                                new_pool_size)

        elif self._model_pool._max_size != new_pool_size:
            print(
                '[ Reallocate Model Pool ] Updating model pool | {:.2e} --> {:.2e}'
                .format(self._model_pool._max_size, new_pool_size))
            samples = self._model_pool.return_all_samples()
            new_pool = SimpleReplayPool(obs_space, act_space, new_pool_size)
            new_pool.add_samples(samples)
            assert self._model_pool.size == new_pool.size
            self._model_pool = new_pool

    def _train_model(self, **kwargs):
        # get all the env samples and use them to train the backward and forward dynamic model
        env_samples = self._pool.return_all_samples()
        print('Training forward model:')
        train_inputs, train_outputs = format_samples_for_forward_training(
            env_samples)
        f_model_metrics = self._forward_model.train(train_inputs,
                                                    train_outputs, **kwargs)
        # TODO 1
        # print('Training backward model:')
        # train_inputs, train_outputs = format_samples_for_backward_training(env_samples)
        # b_model_metrics = self._backward_model.train(train_inputs, train_outputs, **kwargs)
        b_model_metrics = {}
        return f_model_metrics, b_model_metrics

    def _rollout_model(self, rollout_batch_size,
                       **kwargs):  # TODO: change the fake_env to add penalty
        # Rollout model using fake_env and add samples into _model_pool

        # print('[ Backward Model Rollout ] Starting | Epoch: {} | Rollout length: {} | Batch size: {}'.format(
        #     self._epoch, self._backward_rollout_schedule[-1], rollout_batch_size,
        # ))
        #
        batch = self.sampler.random_batch(
            rollout_batch_size)  # sample init states batch from env_pool
        start_obs = batch['observations']

        # to record the number of collected samples
        f_steps_added, b_steps_added = [], []

        # TODO: 2
        # # perform backward rollout
        # obs = start_obs
        # for i in range(self._backward_rollout_length):
        #     act = self._get_before_action(obs)
        #
        #     before_obs, rew, term, info = self.b_fake_env.step(obs, act, **kwargs)
        #
        #     samples = {'observations': before_obs, 'actions': act, 'next_observations': obs, 'rewards': rew,
        #                'terminals': term}
        #     self._model_pool.add_samples(samples)  # _model_pool is to add samples get from unrolling the learned dynamic model
        #     b_steps_added.append(len(obs))
        #
        #     nonterm_mask = ~term.squeeze(-1)
        #     if nonterm_mask.sum() == 0:
        #         print('[ Model Rollout ] Breaking early: {} | {} / {}'.format(i, nonterm_mask.sum(),
        #                                                                       nonterm_mask.shape))
        #         break
        #     obs = before_obs[nonterm_mask]

        # perform forward rollout
        print(
            '[ Forward Model Rollout ] Starting | Epoch: {} | Rollout length: {} | Batch size: {} '
            .format(
                self._epoch,
                self._forward_rollout_schedule[-1],
                rollout_batch_size,
            ))
        obs = start_obs
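        # Branched rollouts: starting from real states sampled above, unroll the
        # learned forward model for _forward_rollout_length steps, adding every
        # transition to the model pool and pruning rollouts once they terminate.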
        for i in range(self._forward_rollout_length):
            act = self._policy.actions_np(obs)

            next_obs, rew, term, info = self.f_fake_env.step(
                obs, act, **kwargs)

            samples = {
                'observations': obs,
                'actions': act,
                'next_observations': next_obs,
                'rewards': rew,
                'terminals': term
            }
            self._model_pool.add_samples(samples)
            f_steps_added.append(len(obs))

            nonterm_mask = ~term.squeeze(-1)
            if nonterm_mask.sum() == 0:
                print('[ Model Rollout ] Breaking early: {} | {} / {}'.format(
                    i, nonterm_mask.sum(), nonterm_mask.shape))
                break

            obs = next_obs[nonterm_mask]

        f_mean_rollout_length, b_mean_rollout_length = sum(
            f_steps_added) / rollout_batch_size, sum(
                b_steps_added) / rollout_batch_size

        print(
            '[ Model Rollout ] Added: {:.1e} | Model pool: {:.1e} (max {:.1e}) | Total Length: {} | Train rep: {}'
            .format(
                (self._forward_rollout_length + self._backward_rollout_length)
                * rollout_batch_size, self._model_pool.size,
                self._model_pool._max_size,
                f_mean_rollout_length + b_mean_rollout_length,
                self._n_train_repeat))

        return {
            'f_mean_rollout_length': f_mean_rollout_length
        }, {
            'b_mean_rollout_length': b_mean_rollout_length
        }  # for logging

    def _training_batch(self, batch_size=None):
        # get the training data used for policy and critic updates. Called in self._do_training_repeats()
        # batch samples are drawn from both env_pool and model_pool according to self._real_ratio

        batch_size = batch_size or self.sampler._batch_size
        env_batch_size = int(batch_size * self._real_ratio)
        model_batch_size = batch_size - env_batch_size

        ## can sample from the env pool even if env_batch_size == 0
        env_batch = self._pool.random_batch(env_batch_size)

        if model_batch_size > 0:
            model_batch = self._model_pool.random_batch(model_batch_size)

            keys = env_batch.keys()
            batch = {
                k: np.concatenate((env_batch[k], model_batch[k]), axis=0)
                for k in keys
            }
        else:
            ## if real_ratio == 1.0, no model pool was ever allocated,
            ## so skip the model pool sampling
            batch = env_batch
        return batch

    def _init_global_step(self):
        self.global_step = training_util.get_or_create_global_step()
        self._training_ops.update(
            {'increment_global_step': training_util._increment_global_step(1)})

    def _init_placeholders(self):
        """Create input placeholders for the SAC algorithm.

        Creates `tf.placeholder`s for:
            - observation
            - next observation
            - action
            - reward
            - terminals
        """
        self._iteration_ph = tf.placeholder(tf.int64,
                                            shape=None,
                                            name='iteration')

        self._observations_ph = tf.placeholder(
            tf.float32,
            shape=(None, *self._observation_shape),
            name='observation',
        )

        self._next_observations_ph = tf.placeholder(
            tf.float32,
            shape=(None, *self._observation_shape),
            name='next_observation',
        )

        self._actions_ph = tf.placeholder(
            tf.float32,
            shape=(None, *self._action_shape),
            name='actions',
        )

        self._rewards_ph = tf.placeholder(
            tf.float32,
            shape=(None, 1),
            name='rewards',
        )

        self._terminals_ph = tf.placeholder(
            tf.float32,
            shape=(None, 1),
            name='terminals',
        )

        if self._store_extra_policy_info:
            self._log_pis_ph = tf.placeholder(
                tf.float32,
                shape=(None, 1),
                name='log_pis',
            )
            self._raw_actions_ph = tf.placeholder(
                tf.float32,
                shape=(None, *self._action_shape),
                name='raw_actions',
            )

    def _get_Q_target(self):
        next_actions = self._policy.actions([self._next_observations_ph])
        next_log_pis = self._policy.log_pis([self._next_observations_ph],
                                            next_actions)

        next_Qs_values = tuple(
            Q([self._next_observations_ph, next_actions])
            for Q in self._Q_targets)

        min_next_Q = tf.reduce_min(next_Qs_values, axis=0)
        next_value = min_next_Q - self._alpha * next_log_pis

        Q_target = td_target(reward=self._reward_scale * self._rewards_ph,
                             discount=self._discount,
                             next_value=(1 - self._terminals_ph) * next_value)

        return Q_target

    def _init_critic_update(self):
        """Create minimization operation for critic Q-function.

        Creates a `tf.optimizer.minimize` operation for updating
        critic Q-function with gradient descent, and appends it to
        `self._training_ops` attribute.
        """
        Q_target = tf.stop_gradient(self._get_Q_target())

        assert Q_target.shape.as_list() == [None, 1]

        Q_values = self._Q_values = tuple(
            Q([self._observations_ph, self._actions_ph]) for Q in self._Qs)

        Q_losses = self._Q_losses = tuple(
            tf.losses.mean_squared_error(
                labels=Q_target, predictions=Q_value, weights=0.5)
            for Q_value in Q_values)

        self._Q_optimizers = tuple(
            tf.train.AdamOptimizer(learning_rate=self._Q_lr,
                                   name='{}_{}_optimizer'.format(Q._name, i))
            for i, Q in enumerate(self._Qs))
        Q_training_ops = tuple(
            tf.contrib.layers.optimize_loss(Q_loss,
                                            self.global_step,
                                            learning_rate=self._Q_lr,
                                            optimizer=Q_optimizer,
                                            variables=Q.trainable_variables,
                                            increment_global_step=False,
                                            summaries=((
                                                "loss", "gradients",
                                                "gradient_norm",
                                                "global_gradient_norm"
                                            ) if self._tf_summaries else ()))
            for i, (Q, Q_loss, Q_optimizer) in enumerate(
                zip(self._Qs, Q_losses, self._Q_optimizers)))

        self._training_ops.update({'Q': tf.group(Q_training_ops)})

    def _init_actor_update(self):
        """Create minimization operations for policy and entropy.

        Creates a `tf.optimizer.minimize` operations for updating
        policy and entropy with gradient descent, and adds them to
        `self._training_ops` attribute.
        """

        actions = self._policy.actions([self._observations_ph])
        log_pis = self._policy.log_pis([self._observations_ph], actions)
        self._actions = actions

        assert log_pis.shape.as_list() == [None, 1]

        log_alpha = self._log_alpha = tf.get_variable('log_alpha',
                                                      dtype=tf.float32,
                                                      initializer=0.0)
        alpha = tf.exp(log_alpha)

        if isinstance(self._target_entropy, Number):
            alpha_loss = -tf.reduce_mean(
                log_alpha * tf.stop_gradient(log_pis + self._target_entropy))

            self._alpha_optimizer = tf.train.AdamOptimizer(
                self._policy_lr, name='alpha_optimizer')
            self._alpha_train_op = self._alpha_optimizer.minimize(
                loss=alpha_loss, var_list=[log_alpha])

            self._training_ops.update(
                {'temperature_alpha': self._alpha_train_op})

        self._alpha = alpha

        if self._action_prior == 'normal':
            policy_prior = tf.contrib.distributions.MultivariateNormalDiag(
                loc=tf.zeros(self._action_shape),
                scale_diag=tf.ones(self._action_shape))
            policy_prior_log_probs = policy_prior.log_prob(actions)
        elif self._action_prior == 'uniform':
            policy_prior_log_probs = 0.0

        Q_log_targets = tuple(
            Q([self._observations_ph, actions]) for Q in self._Qs)
        min_Q_log_target = tf.reduce_min(Q_log_targets, axis=0)

        self._value = tf.reduce_mean(Q_log_targets, axis=0)
        self._target_value = tf.reduce_mean(tuple(
            Q([self._observations_ph, actions]) for Q in self._Q_targets),
                                            axis=0)

        if self._reparameterize:
            policy_kl_losses = (alpha * log_pis - min_Q_log_target -
                                policy_prior_log_probs)
        else:
            raise NotImplementedError

        assert policy_kl_losses.shape.as_list() == [None, 1]

        policy_loss = tf.reduce_mean(policy_kl_losses)

        self._policy_optimizer = tf.train.AdamOptimizer(
            learning_rate=self._policy_lr, name="policy_optimizer")
        policy_train_op = tf.contrib.layers.optimize_loss(
            policy_loss,
            self.global_step,
            learning_rate=self._policy_lr,
            optimizer=self._policy_optimizer,
            variables=self._policy.trainable_variables,
            increment_global_step=False,
            summaries=("loss", "gradients", "gradient_norm",
                       "global_gradient_norm") if self._tf_summaries else ())

        self._training_ops.update({'policy_train_op': policy_train_op})

    def _init_training(self):
        self._update_target()

    def _update_target(self):
        tau = self._tau

        for Q, Q_target in zip(self._Qs, self._Q_targets):
            source_params = Q.get_weights()
            target_params = Q_target.get_weights()
            Q_target.set_weights([
                tau * source + (1.0 - tau) * target
                for source, target in zip(source_params, target_params)
            ])

    def _do_training(self, iteration, batch):
        """Runs the operations for updating training and target ops.
        Called by self._do_training_repeats.
        """

        # self._training_progress.update()
        # self._training_progress.set_description()

        feed_dict = self._get_feed_dict(iteration, batch)
        self._session.run(self._training_ops,
                          feed_dict)  # update policy and critic

        # update target_Q according to self._tau
        if iteration % self._target_update_interval == 0:
            # Run target ops here.
            self._update_target()

    def _do_training_repeats(self, timestep, backward_policy_train_repeat=1):
        """Repeat training _n_train_repeat times every _train_every_n_steps,
        This method overrides the method in softlearning, since it needs to take care of the backward policy
        """
        if timestep % self._train_every_n_steps > 0: return
        trained_enough = (self._train_steps_this_epoch >
                          self._max_train_repeat_per_timestep * self._timestep)
        if trained_enough: return

        # Train forward policy and Q
        for i in range(
                self._n_train_repeat):  # the default self._n_train_repeat = 1
            self._do_training(iteration=timestep, batch=self._training_batch())

        # TODO: 3
        # # train backward policy -- via maximal likelihood. s'->a
        # for i in range(backward_policy_train_repeat):
        #     """ Our goal is to make the backward rollouts resemble the real trajectory sampled by the current forward policy.Thus
        #     when training the backward policy, we only use the recent trajectories sampled by the agent in the real environment."""
        #     batch = self._pool.last_n_random_batch(last_n=self._epoch_length * self._last_n_epoch, batch_size=256)  # TODO: This is incorrect, it uses the recent traj sampled from the real env.
        #     next_observations = np.array(batch['next_observations'])
        #     actions = np.array(batch['actions'])
        #     feed_dict = {
        #         self._actions_ph: actions,
        #         self._next_observations_ph: next_observations,
        #     }
        #     self._session.run(self._backward_policy_optimizer, feed_dict)

        self._num_train_steps += self._n_train_repeat
        self._train_steps_this_epoch += self._n_train_repeat

    def _get_feed_dict(self, iteration, batch):
        """Construct TensorFlow feed_dict from sample batch."""

        feed_dict = {
            self._observations_ph: batch['observations'],
            self._actions_ph: batch['actions'],
            self._next_observations_ph: batch['next_observations'],
            self._rewards_ph: batch['rewards'],
            self._terminals_ph: batch['terminals'],
        }

        if self._store_extra_policy_info:
            feed_dict[self._log_pis_ph] = batch['log_pis']
            feed_dict[self._raw_actions_ph] = batch['raw_actions']

        if iteration is not None:
            feed_dict[self._iteration_ph] = iteration

        return feed_dict

    def get_diagnostics(self, iteration, batch, training_paths,
                        evaluation_paths):
        """Return diagnostic information as ordered dictionary.

        Records mean and standard deviation of Q-function and state
        value function, and TD-loss (mean squared Bellman error)
        for the sample batch.

        Also calls the `draw` method of the plotter, if plotter defined.
        """

        feed_dict = self._get_feed_dict(iteration, batch)

        (Q_values, Q_losses, alpha, global_step) = self._session.run(
            (self._Q_values, self._Q_losses, self._alpha, self.global_step),
            feed_dict)

        diagnostics = OrderedDict({
            'Q-avg': np.mean(Q_values),
            'Q-std': np.std(Q_values),
            'Q_loss': np.mean(Q_losses),
            'alpha': alpha,
        })

        policy_diagnostics = self._policy.get_diagnostics(
            batch['observations'])
        diagnostics.update({
            f'policy/{key}': value
            for key, value in policy_diagnostics.items()
        })

        if self._plotter:
            self._plotter.draw()

        return diagnostics

    @property
    def tf_saveables(self):
        saveables = {
            '_policy_optimizer': self._policy_optimizer,
            **{
                f'Q_optimizer_{i}': optimizer
                for i, optimizer in enumerate(self._Q_optimizers)
            },
            '_log_alpha': self._log_alpha,
        }

        if hasattr(self, '_alpha_optimizer'):
            saveables['_alpha_optimizer'] = self._alpha_optimizer

        return saveables
Beispiel #8
0
class MOPAC(RLAlgorithm):
    """Model-Based Policy Optimization (MOPAC)
    """
    def __init__(
        self,
        training_environment,
        evaluation_environment,
        policy,
        Qs,
        Vs,
        pool,
        static_fns,
        plotter=None,
        tf_summaries=False,
        lr=3e-4,
        reward_scale=1.0,
        target_entropy='auto',
        discount=0.99,
        tau=5e-3,
        target_update_interval=1,
        action_prior='uniform',
        reparameterize=False,
        store_extra_policy_info=False,
        mopac=True,
        valuefunc=False,
        deterministic_obs=False,
        deterministic_rewards=False,
        model_train_freq=250,
        num_networks=7,
        num_elites=5,
        model_retain_epochs=20,
        model_train_end_epoch=-1,
        rollout_batch_size=100e3,
        real_ratio=0.1,
        ratio_schedule=[0, 100, 0.5, 0.5],
        rollout_schedule=[20, 100, 1, 1],
        hidden_dim=200,
        max_model_t=None,
        **kwargs,
    ):
        """
        Args:
            env (`SoftlearningEnv`): Environment used for training.
            policy: A policy function approximator.
            initial_exploration_policy: ('Policy'): A policy that we use
                for initial exploration which is not trained by the algorithm.
            Qs: Q-function approximators. The min of these
                approximators will be used. Usage of at least two Q-functions
                improves performance by reducing overestimation bias.
            pool (`PoolBase`): Replay pool to add gathered samples to.
            plotter (`QFPolicyPlotter`): Plotter instance to be used for
                visualizing Q-function during training.
            lr (`float`): Learning rate used for the function approximators.
            discount (`float`): Discount factor for Q-function updates.
            tau (`float`): Soft value function target update weight.
            target_update_interval ('int'): Frequency at which target network
                updates occur in iterations.
            reparameterize ('bool'): If True, we use a gradient estimator for
                the policy derived using the reparameterization trick. We use
                a likelihood ratio based estimator otherwise.
        """

        super(MOPAC, self).__init__(**kwargs)

        obs_dim = np.prod(training_environment.observation_space.shape)
        act_dim = np.prod(training_environment.action_space.shape)
        self._model = construct_model(obs_dim=obs_dim,
                                      act_dim=act_dim,
                                      hidden_dim=hidden_dim,
                                      num_networks=num_networks,
                                      num_elites=num_elites)
        self._static_fns = static_fns
        self.fake_env = FakeEnv(self._model, self._static_fns)

        self._rollout_schedule = rollout_schedule
        self._ratio_schedule = ratio_schedule
        self._max_model_t = max_model_t

        # self._model_pool_size = model_pool_size
        # print('[ MOPAC ] Model pool size: {:.2E}'.format(self._model_pool_size))
        # self._model_pool = SimpleReplayPool(pool._observation_space, pool._action_space, self._model_pool_size)

        self._mopac = mopac
        self._valuefunc = valuefunc

        self._model_retain_epochs = model_retain_epochs
        self._model_train_freq = model_train_freq
        self._model_train_end_epoch = model_train_end_epoch

        self._rollout_batch_size = int(rollout_batch_size)
        self._deterministic_obs = deterministic_obs
        self._deterministic_rewards = deterministic_rewards
        #self._real_ratio = real_ratio

        self._log_dir = os.getcwd()
        self._writer = Writer(self._log_dir)

        self._training_environment = training_environment
        self._evaluation_environment = evaluation_environment
        self._policy = policy

        self._Qs = Qs
        self._Q_targets = tuple(tf.keras.models.clone_model(Q) for Q in Qs)

        self._Vs = Vs
        self._V_target = tf.keras.models.clone_model(Vs)

        self._pool = pool
        self._plotter = plotter
        self._tf_summaries = tf_summaries

        self._policy_lr = lr
        self._Q_lr = lr
        self._V_lr = lr

        self._reward_scale = reward_scale
        self._target_entropy = (
            -np.prod(self._training_environment.action_space.shape)
            if target_entropy == 'auto' else target_entropy)
        print('[ MOPAC ] Target entropy: {}'.format(self._target_entropy))

        self._discount = discount
        self._tau = tau
        self._target_update_interval = target_update_interval
        self._action_prior = action_prior

        self._reparameterize = reparameterize
        self._store_extra_policy_info = store_extra_policy_info

        observation_shape = self._training_environment.active_observation_shape
        action_shape = self._training_environment.action_space.shape

        assert len(observation_shape) == 1, observation_shape
        self._observation_shape = observation_shape
        assert len(action_shape) == 1, action_shape
        self._action_shape = action_shape

        self._build()

    def _build(self):
        self._training_ops = {}

        self._init_global_step()
        self._init_placeholders()
        self._init_actor_update()
        self._init_critic_update()
        self._init_value_update()
        self._init_mppi()

    def _train(self):
        """Return a generator that performs RL training.

        Args:
            env (`SoftlearningEnv`): Environment used for training.
            policy (`Policy`): Policy used for training
            initial_exploration_policy ('Policy'): Policy used for exploration
                If None, then all exploration is done using policy
            pool (`PoolBase`): Sample pool to add samples to
        """
        training_environment = self._training_environment
        evaluation_environment = self._evaluation_environment
        policy = self._policy
        pool = self._pool
        model_metrics = {}

        if not self._training_started:
            self._init_training()

            self._initial_exploration_hook(training_environment,
                                           self._initial_exploration_policy,
                                           pool)

        self.sampler.initialize(training_environment, policy, pool)

        gt.reset_root()
        gt.rename_root('RLAlgorithm')
        gt.set_def_unique(False)

        self._training_before_hook()

        for self._epoch in gt.timed_for(range(self._epoch, self._n_epochs)):

            self._epoch_before_hook()
            gt.stamp('epoch_before_hook')

            self._training_progress = Progress(self._epoch_length *
                                               self._n_train_repeat)
            start_samples = self.sampler._total_samples
            obs = None

            self._set_rollout_length()
            # reset U and noise
            self._reset_mppi()

            for i in count():
                samples_now = self.sampler._total_samples
                self._timestep = samples_now - start_samples

                if (samples_now >= start_samples + self._epoch_length
                        and self.ready_to_train):
                    break

                self._timestep_before_hook()
                gt.stamp('timestep_before_hook')

                self._set_real_ratio()
                if self._timestep % self._model_train_freq == 0 and self._real_ratio < 1.0:
                    self._training_progress.pause()
                    print('[ MOPAC ] log_dir: {} | ratio: {}'.format(
                        self._log_dir, self._real_ratio))
                    print(
                        '[ MOPAC ] Training model at epoch {} | freq {} | timestep {} (total: {}) | epoch train steps: {} (total: {})'
                        .format(self._epoch, self._model_train_freq,
                                self._timestep, self._total_timestep,
                                self._train_steps_this_epoch,
                                self._num_train_steps))

                    if self._epoch < self._model_train_end_epoch or self._model_train_end_epoch == -1:
                        model_train_metrics = self._train_model(
                            batch_size=256,
                            max_epochs=None,
                            holdout_ratio=0.2,
                            max_t=self._max_model_t)
                        model_metrics.update(model_train_metrics)
                        gt.stamp('epoch_train_model')

                    self._reallocate_model_pool()

                    model_rollout_metrics = self._rollout_model(
                        gamma=self._discount,
                        mopac=self._mopac,
                        valuefunc=self._valuefunc,
                        deterministic_obs=self._deterministic_obs,
                        deterministic_rewards=self._deterministic_rewards)
                    model_metrics.update(model_rollout_metrics)

                    gt.stamp('epoch_rollout_model')
                    # self._visualize_model(self._evaluation_environment, self._total_timestep)
                    self._training_progress.resume()

                self._do_sampling(
                    timestep=self._total_timestep)  # steps the env!!
                gt.stamp('sample')

                if self.ready_to_train:
                    self._do_training_repeats(timestep=self._total_timestep)
                gt.stamp('train')

                self._timestep_after_hook()
                gt.stamp('timestep_after_hook')

            training_paths = self.sampler.get_last_n_paths(
                math.ceil(self._epoch_length / self.sampler._max_path_length))
            gt.stamp('training_paths')
            evaluation_paths = self._evaluation_paths(policy,
                                                      evaluation_environment)
            gt.stamp('evaluation_paths')

            training_metrics = self._evaluate_rollouts(training_paths,
                                                       training_environment)
            gt.stamp('training_metrics')
            if evaluation_paths:
                evaluation_metrics = self._evaluate_rollouts(
                    evaluation_paths, evaluation_environment)
                gt.stamp('evaluation_metrics')
            else:
                evaluation_metrics = {}

            self._epoch_after_hook(training_paths)
            gt.stamp('epoch_after_hook')

            sampler_diagnostics = self.sampler.get_diagnostics()

            diagnostics = self.get_diagnostics(
                iteration=self._total_timestep,
                batch=self._evaluation_batch(),
                training_paths=training_paths,
                evaluation_paths=evaluation_paths)

            time_diagnostics = gt.get_times().stamps.itrs

            diagnostics.update(
                OrderedDict((
                    *((f'evaluation/{key}', evaluation_metrics[key])
                      for key in sorted(evaluation_metrics.keys())),
                    *((f'training/{key}', training_metrics[key])
                      for key in sorted(training_metrics.keys())),
                    *((f'times/{key}', time_diagnostics[key][-1])
                      for key in sorted(time_diagnostics.keys())),
                    *((f'sampler/{key}', sampler_diagnostics[key])
                      for key in sorted(sampler_diagnostics.keys())),
                    *((f'model/{key}', model_metrics[key])
                      for key in sorted(model_metrics.keys())),
                    ('epoch', self._epoch),
                    ('timestep', self._timestep),
                    ('timesteps_total', self._total_timestep),
                    ('train-steps', self._num_train_steps),
                )))

            if self._eval_render_mode is not None and hasattr(
                    evaluation_environment, 'render_rollouts'):
                training_environment.render_rollouts(evaluation_paths)

            yield diagnostics

        self.sampler.terminate()

        self._training_after_hook()

        self._training_progress.close()

        yield {'done': True, **diagnostics}

    def train(self, *args, **kwargs):
        return self._train(*args, **kwargs)

    def _log_policy(self):
        save_path = os.path.join(self._log_dir, 'models')
        filesystem.mkdir(save_path)
        weights = self._policy.get_weights()
        data = {'policy_weights': weights}
        full_path = os.path.join(save_path,
                                 'policy_{}.pkl'.format(self._total_timestep))
        print('Saving policy to: {}'.format(full_path))
        pickle.dump(data, open(full_path, 'wb'))

    def _log_model(self):
        save_path = os.path.join(self._log_dir, 'models')
        filesystem.mkdir(save_path)
        print('Saving model to: {}'.format(save_path))
        self._model.save(save_path, self._total_timestep)

    def _set_rollout_length(self):
        min_epoch, max_epoch, min_length, max_length = self._rollout_schedule
        if self._epoch <= min_epoch:
            y = min_length
        else:
            dx = (self._epoch - min_epoch) / (max_epoch - min_epoch)
            dx = min(dx, 1)
            y = dx * (max_length - min_length) + min_length

        self._rollout_length = int(y)
        print(
            '[ Model Length ] Epoch: {} (min: {}, max: {}) | Length: {} (min: {} , max: {})'
            .format(self._epoch, min_epoch, max_epoch, self._rollout_length,
                    min_length, max_length))
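        # Rough worked example with a hypothetical schedule [20, 100, 1, 15]:
        # at epoch 60, dx = (60 - 20) / (100 - 20) = 0.5, so the rollout
        # length becomes int(0.5 * (15 - 1) + 1) = 8.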

    def _set_real_ratio(self):
        min_epoch, max_epoch, min_ratio, max_ratio = self._ratio_schedule
        if self._epoch <= min_epoch:
            y = min_ratio
        else:
            dx = (self._epoch - min_epoch) / (max_epoch - min_epoch)
            dx = min(dx, 1)
            y = dx * (max_ratio - min_ratio) + min_ratio

        self._real_ratio = y
        print(
            '[ Real Ratio ] Epoch: {} (min: {}, max: {}) | Ratio: {} (min: {} , max: {})'
            .format(self._epoch, min_epoch, max_epoch, self._real_ratio,
                    min_ratio, max_ratio))

    def _reallocate_model_pool(self):
        obs_space = self._pool._observation_space
        act_space = self._pool._action_space

        rollouts_per_epoch = self._rollout_batch_size * self._epoch_length / self._model_train_freq

        #rollouts_per_epoch = self._epoch_length / self._model_train_freq
        model_steps_per_epoch = int(self._rollout_length * rollouts_per_epoch)
        new_pool_size = self._model_retain_epochs * model_steps_per_epoch
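        # Rough sizing sketch with hypothetical numbers: rollout_batch_size
        # 1e5, epoch_length 1000, model_train_freq 250 and rollout_length 5
        # give rollouts_per_epoch = 1e5 * 1000 / 250 = 4e5 and
        # model_steps_per_epoch = 5 * 4e5 = 2e6 transitions retained per epoch.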

        if not hasattr(self, '_model_pool'):
            print('[ MOPAC ] Initializing new model pool with size {:.2e}'.
                  format(new_pool_size))
            self._model_pool = SimpleReplayPool(obs_space, act_space,
                                                new_pool_size)

        elif self._model_pool._max_size != new_pool_size:
            print('[ MOPAC ] Updating model pool | {:.2e} --> {:.2e}'.format(
                self._model_pool._max_size, new_pool_size))

            if new_pool_size < self._model_pool._size:
                # when shrinking, clamp the recorded size to the new capacity
                # before copying, so the size assertion below still holds
                self._model_pool._size = new_pool_size

            samples = self._model_pool.return_all_samples()
            new_pool = SimpleReplayPool(obs_space, act_space, new_pool_size)
            new_pool.add_samples(samples)
            assert self._model_pool.size == new_pool.size
            self._model_pool = new_pool

    def _train_model(self, **kwargs):
        env_samples = self._pool.return_all_samples()
        train_inputs, train_outputs = format_samples_for_training(env_samples)
        model_metrics = self._model.train(train_inputs, train_outputs,
                                          **kwargs)
        return model_metrics

    # TODO: refactor, extract functions
    def _rollout_model(self,
                       gamma=0.99,
                       lambda_=1.0,
                       mopac=False,
                       valuefunc=False,
                       deterministic_obs=False,
                       deterministic_rewards=False):
        print(
            '[ Model Rollout ] Starting | Epoch: {} | Rollout length: {} | Batch size: {}'
            .format(self._epoch, self._rollout_length,
                    self._rollout_batch_size))
        batch = self.sampler.random_batch(self._rollout_batch_size)
        obs = batch['observations']
        steps_added = []

        if mopac:
            # repeat initial states for mppi
            obs = np.repeat(obs, self.repeats, axis=0)

            x_acts = np.zeros((self._rollout_batch_size * self.repeats,
                               self._rollout_length, *self._action_shape))
            x_obs = np.zeros((self._rollout_batch_size * self.repeats,
                              self._rollout_length, *self._observation_shape))
            x_total_reward = np.zeros((self._rollout_batch_size * self.repeats,
                                       self._rollout_length, 1))

            # fix model inds across rollouts and initial state batch
            model_inds = self._model.random_inds(
                self._rollout_batch_size).repeat(self.repeats)

        # rollouts
        # in MOPAC the last step's reward is replaced by the value function
        horiz = self._rollout_length - 1 if valuefunc and mopac else self._rollout_length
        for t in range(horiz):
            if mopac:
                # first action from control sequence
                #act = self.U[:,t]
                act = self._policy.actions_np(obs)

                # add noise and clip
                act += self.noise[:, t]
                act = np.clip(act, -self.uclip, self.uclip)
            else:
                act = self._policy.actions_np(obs)

                # new random model inds on each step
                model_inds = self._model.random_inds(self._rollout_batch_size)

            next_obs, rew, term, info = self.fake_env.step(
                obs,
                act,
                model_inds,
                deterministic_obs=deterministic_obs,
                deterministic_rewards=deterministic_rewards)
            steps_added.append(len(obs))

            if mopac:
                # store reward (incl gamma decay) and observation
                x_total_reward[:, t] = (gamma**t) * rew

                x_obs[:, t] = obs
                x_acts[:, t] = act
            else:
                samples = {
                    'observations': obs,
                    'actions': act,
                    'next_observations': next_obs,
                    'rewards': rew,
                    'terminals': term
                }
                self._model_pool.add_samples(samples)

            nonterm_mask = ~term.squeeze(-1)
            if nonterm_mask.sum() == 0:
                print('[ Model Rollout ] Breaking early: {} | {} / {}'.format(
                    t, nonterm_mask.sum(), nonterm_mask.shape))
                break

            obs = next_obs if mopac else next_obs[
                nonterm_mask]  # masking changes the shape of the array!

        if mopac:
            # VF on final state, appends terminal reward
            if valuefunc:
                # previous next obs becomes last obs for storage
                x_obs[:, -1] = next_obs
                # predict and store reward (incl gamma decay) and observation
                x_total_reward[:, -1] = (gamma**horiz) * self._V_target.predict(
                    [x_obs[:, -1]])

            x_opt_acts = np.zeros(
                (self._rollout_batch_size, self._rollout_length,
                 self._action_shape[0]))
            x_opt_obs = np.zeros(
                (self._rollout_batch_size, *self._observation_shape))

            # mppi optimization
            for l in range(0, self._rollout_batch_size * self.repeats,
                           self.repeats):
                # selectors
                i = int(l / self.repeats)
                r = range(l, l + self.repeats)

                # cum reward of rollout
                s = np.sum(x_total_reward[r], axis=1)

                # normalize cum reward
                alpha = np.exp(1 / lambda_ * (s - np.max(s)))
                omega = alpha / (np.sum(alpha) + 1e-6)
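                # MPPI importance weights as computed above:
                #   omega_k = exp((S_k - max_j S_j) / lambda) / (sum_j exp(...) + 1e-6)
                # normalized over the `repeats` noisy rollouts that share this
                # start state; subtracting the max keeps the exponent stable.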

                # compute control offset (most important part in mppi)
                u_delta = np.sum((omega.squeeze() * self.noise[r].T).T, axis=0)
                # print(u_delta)

                # tweak control (duplicated across range)
                #self.U[r] += 1 * u_delta
                #self.U[r] = np.clip(self.U[r], -self.uclip, self.uclip)

                x_acts[r] += 1 * u_delta
                x_acts[r] = np.clip(x_acts[r], -self.uclip, self.uclip)

                # nan check
                nan_mask = np.isnan(x_acts[r])
                if np.any(nan_mask):
                    raise Exception("action contains nan value")

                # store the initial observation and optimized action sequence for this rollout group
                x_opt_obs[i] = x_obs[l][0]  # initial observation
                #x_opt_acts[i] = self.U[l][:self._rollout_length]  # truncate
                x_opt_acts[i] = x_acts[l][:self._rollout_length]  # truncate

                # shift all elements to the left along horizon (for next env step)
                #self.U[r] = np.roll(self.U[r], -1, axis=1)

            # rollout trajectories using mppi control action sequences to generate samples
            # fix model inds
            model_inds = self._model.random_inds(self._rollout_batch_size)

            samples = []
            obs = x_opt_obs  # initial obs from first rollout
            for t in range(self._rollout_length):
                act = x_opt_acts[:, t]
                next_obs, rew, term, info = self.fake_env.step(
                    obs,
                    act,
                    model_inds,
                    deterministic_obs=deterministic_obs,
                    deterministic_rewards=deterministic_rewards)

                # store sample
                samples = {
                    'observations': obs,
                    'actions': act,
                    'next_observations': next_obs,
                    'rewards': rew,
                    'terminals': term
                }
                self._model_pool.add_samples(samples)

                # rew *= (gamma**t)

                obs = next_obs
            # cum rewards: potential reward, decreases with every step in trajectory
            # for last step the cum reward becomes the reward of just that step
            # ex.: rew=(1,2,3,4,5) -> cumrewards=(15,14,12,9,5)
            # cumrews = np.array([s['cumrewards'] for s in samples])
            # cumrewards =  np.flip(np.cumsum(np.flip(cumrews), axis=0))
            # # normalize
            # cumrewards /= self._rollout_length

            # add samples to pool, together with cum reward
            # for s in zip(samples):
            # add to pool
            # self._model_pool.add_samples(s)

        mean_rollout_length = sum(steps_added) / self._rollout_batch_size
        rollout_stats = {'mean_rollout_length': mean_rollout_length}
        print(
            '[ Model Rollout ] Added: {:.1e} | Model pool: {:.1e} (max {:.1e}) | Length: {} | Train rep: {}'
            .format(sum(steps_added), self._model_pool.size,
                    self._model_pool._max_size, mean_rollout_length,
                    self._n_train_repeat))

        return rollout_stats

    def _visualize_model(self, env, timestep):
        ## save env state
        state = env.unwrapped.state_vector()
        qpos_dim = len(env.unwrapped.sim.data.qpos)
        qpos = state[:qpos_dim]
        qvel = state[qpos_dim:]

        print('[ Visualization ] Starting | Epoch {} | Log dir: {}\n'.format(
            self._epoch, self._log_dir))
        visualize_policy(env, self.fake_env, self._policy, self._writer,
                         timestep)
        print('[ Visualization ] Done')
        ## set env state
        env.unwrapped.set_state(qpos, qvel)

    def _training_batch(self, batch_size=None):
        batch_size = batch_size or self.sampler._batch_size
        env_batch_size = int(batch_size * self._real_ratio)
        model_batch_size = batch_size - env_batch_size
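        # e.g. with a hypothetical batch_size of 256 and real_ratio 0.1 this
        # yields env_batch_size = 25 real transitions and
        # model_batch_size = 231 model-generated transitions.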

        ## can sample from the env pool even if env_batch_size == 0
        env_batch = self._pool.random_batch(env_batch_size)

        if model_batch_size > 0:
            model_batch = self._model_pool.random_batch(model_batch_size)

            keys = env_batch.keys()
            batch = {
                k: np.concatenate((env_batch[k], model_batch[k]), axis=0)
                for k in keys
            }
        else:
            ## if real_ratio == 1.0, no model pool was ever allocated,
            ## so skip the model pool sampling
            batch = env_batch
        return batch

    def _init_global_step(self):
        self.global_step = training_util.get_or_create_global_step()
        self._training_ops.update(
            {'increment_global_step': training_util._increment_global_step(1)})

    def _init_placeholders(self):
        """Create input placeholders for the SAC algorithm.

        Creates `tf.placeholder`s for:
            - observation
            - next observation
            - action
            - reward
            - terminals
        """
        self._iteration_ph = tf.placeholder(tf.int64,
                                            shape=None,
                                            name='iteration')

        self._observations_ph = tf.placeholder(
            tf.float32,
            shape=(None, *self._observation_shape),
            name='observation',
        )

        self._next_observations_ph = tf.placeholder(
            tf.float32,
            shape=(None, *self._observation_shape),
            name='next_observation',
        )

        self._actions_ph = tf.placeholder(
            tf.float32,
            shape=(None, *self._action_shape),
            name='actions',
        )

        self._rewards_ph = tf.placeholder(
            tf.float32,
            shape=(None, 1),
            name='rewards',
        )

        self._terminals_ph = tf.placeholder(
            tf.float32,
            shape=(None, 1),
            name='terminals',
        )

        # self._cumrewards_ph = tf.placeholder(
        #     tf.float32,
        #     shape=(None, 1),
        #     name='cumrewards',
        # )

        if self._store_extra_policy_info:
            self._log_pis_ph = tf.placeholder(
                tf.float32,
                shape=(None, 1),
                name='log_pis',
            )
            self._raw_actions_ph = tf.placeholder(
                tf.float32,
                shape=(None, *self._action_shape),
                name='raw_actions',
            )

    def _get_Q_target(self):
        next_actions = self._policy.actions([self._next_observations_ph])
        next_log_pis = self._policy.log_pis([self._next_observations_ph],
                                            next_actions)

        next_Qs_values = tuple(
            Q([self._next_observations_ph, next_actions])
            for Q in self._Q_targets)

        min_next_Q = tf.reduce_min(next_Qs_values, axis=0)
        next_value = min_next_Q - self._alpha * next_log_pis

        Q_target = td_target(reward=self._reward_scale * self._rewards_ph,
                             discount=self._discount,
                             next_value=(1 - self._terminals_ph) * next_value)
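        # Assuming `td_target` computes reward + discount * next_value, this is
        # the soft Bellman backup:
        #   Q_target = reward_scale * r
        #              + discount * (1 - done) * (min_i Q_target_i(s', a') - alpha * log pi(a'|s'))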

        return Q_target

    def _get_V_target(self):
        actions = self._policy.actions([self._observations_ph])
        log_pis = self._policy.log_pis([self._observations_ph], actions)

        Qs_values = tuple(
            Q([self._observations_ph, actions]) for Q in self._Q_targets)

        min_Q = tf.reduce_min(Qs_values, axis=0)
        value = min_Q - self._alpha * log_pis
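        # i.e. the soft state value V(s) = min_i Q_i(s, a) - alpha * log pi(a|s),
        # with a sampled from the current policy.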

        return value

    def _init_critic_update(self):
        """Create minimization operation for critic Q-function.

        Creates a training op (via `tf.contrib.layers.optimize_loss`) for
        updating the critic Q-functions with gradient descent, and appends
        it to the `self._training_ops` attribute.
        """
        Q_target = tf.stop_gradient(self._get_Q_target())

        assert Q_target.shape.as_list() == [None, 1]

        Q_values = self._Q_values = tuple(
            Q([self._observations_ph, self._actions_ph]) for Q in self._Qs)

        Q_losses = self._Q_losses = tuple(
            tf.losses.mean_squared_error(
                labels=Q_target, predictions=Q_value, weights=0.5)
            for Q_value in Q_values)

        self._Q_optimizers = tuple(
            tf.train.AdamOptimizer(learning_rate=self._Q_lr,
                                   name='{}_{}_optimizer'.format(Q._name, i))
            for i, Q in enumerate(self._Qs))
        Q_training_ops = tuple(
            tf.contrib.layers.optimize_loss(Q_loss,
                                            self.global_step,
                                            learning_rate=self._Q_lr,
                                            optimizer=Q_optimizer,
                                            variables=Q.trainable_variables,
                                            increment_global_step=False,
                                            summaries=((
                                                "loss", "gradients",
                                                "gradient_norm",
                                                "global_gradient_norm"
                                            ) if self._tf_summaries else ()))
            for i, (Q, Q_loss, Q_optimizer) in enumerate(
                zip(self._Qs, Q_losses, self._Q_optimizers)))

        self._training_ops.update({'Q': tf.group(Q_training_ops)})

    def _init_value_update(self):
        """Create minimization operation for critic V-function.

        Creates a training op (via `tf.contrib.layers.optimize_loss`) for
        updating the critic V-function with gradient descent, and appends
        it to the `self._training_ops` attribute.
        """
        V_target = tf.stop_gradient(self._get_V_target())

        assert V_target.shape.as_list() == [None, 1]

        V_value = self._V_value = self._Vs([self._observations_ph])

        V_loss = self._V_loss = tf.losses.mean_squared_error(
            labels=V_target, predictions=V_value, weights=0.5)

        self._V_optimizers = tf.train.AdamOptimizer(learning_rate=self._V_lr,
                                                    name='{}_optimizer'.format(
                                                        self._Vs._name))
        V_training_ops = tf.contrib.layers.optimize_loss(
            V_loss,
            self.global_step,
            learning_rate=self._V_lr,
            optimizer=self._V_optimizers,
            variables=self._Vs.trainable_variables,
            increment_global_step=False,
            summaries=(("loss", "gradients", "gradient_norm",
                        "global_gradient_norm") if self._tf_summaries else ()))

        self._training_ops.update({'V': tf.group(V_training_ops)})

    def _init_actor_update(self):
        """Create minimization operations for policy and entropy.

        Creates training ops for updating the policy and the entropy
        temperature with gradient descent, and adds them to the
        `self._training_ops` attribute.
        """

        actions = self._policy.actions([self._observations_ph])
        log_pis = self._policy.log_pis([self._observations_ph], actions)

        assert log_pis.shape.as_list() == [None, 1]

        log_alpha = self._log_alpha = tf.get_variable('log_alpha',
                                                      dtype=tf.float32,
                                                      initializer=0.0)
        alpha = tf.exp(log_alpha)

        if isinstance(self._target_entropy, Number):
            alpha_loss = -tf.reduce_mean(
                log_alpha * tf.stop_gradient(log_pis + self._target_entropy))

            self._alpha_optimizer = tf.train.AdamOptimizer(
                self._policy_lr, name='alpha_optimizer')
            self._alpha_train_op = self._alpha_optimizer.minimize(
                loss=alpha_loss, var_list=[log_alpha])

            self._training_ops.update(
                {'temperature_alpha': self._alpha_train_op})
        print('alpha entropy parameter:', alpha)  # note: this prints the tf tensor, not its current value
        self._alpha = alpha

        if self._action_prior == 'normal':
            policy_prior = tf.contrib.distributions.MultivariateNormalDiag(
                loc=tf.zeros(self._action_shape),
                scale_diag=tf.ones(self._action_shape))
            policy_prior_log_probs = policy_prior.log_prob(actions)
        elif self._action_prior == 'uniform':
            policy_prior_log_probs = 0.0

        Q_log_targets = tuple(
            Q([self._observations_ph, actions]) for Q in self._Qs)
        min_Q_log_target = tf.reduce_min(Q_log_targets, axis=0)

        if self._reparameterize:
            policy_kl_losses = (alpha * log_pis - min_Q_log_target -
                                policy_prior_log_probs)
        else:
            raise NotImplementedError

        assert policy_kl_losses.shape.as_list() == [None, 1]

        policy_loss = tf.reduce_mean(policy_kl_losses)
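        # Reparameterized SAC policy objective as assembled above:
        #   J_pi = E_s[ alpha * log pi(a|s) - min_i Q_i(s, a) - log p_prior(a) ]
        # minimized over the policy parameters via the pathwise gradient.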

        self._policy_optimizer = tf.train.AdamOptimizer(
            learning_rate=self._policy_lr, name="policy_optimizer")
        policy_train_op = tf.contrib.layers.optimize_loss(
            policy_loss,
            self.global_step,
            learning_rate=self._policy_lr,
            optimizer=self._policy_optimizer,
            variables=self._policy.trainable_variables,
            increment_global_step=False,
            summaries=("loss", "gradients", "gradient_norm",
                       "global_gradient_norm") if self._tf_summaries else ())

        self._training_ops.update({'policy_train_op': policy_train_op})

    def _init_training(self):
        self._update_target(tau=1.0)

    def _init_mppi(self,
                   hl=0.4,
                   horiz=15,
                   noise_mu=0.,
                   noise_sigma=0.5,
                   uclip=1.4,
                   lambda_=1.0,
                   repeats=100):
        action_len = self._action_shape[0]
        obs_len = self._observation_shape[0]
        self.repeats = repeats
        self.U = np.random.uniform(low=-hl,
                                   high=hl,
                                   size=(self._rollout_batch_size * repeats,
                                         horiz, action_len))
        self.noise = np.random.normal(loc=noise_mu,
                                      scale=noise_sigma,
                                      size=(self._rollout_batch_size * repeats,
                                            horiz, action_len))
        self.uclip = uclip
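        # Sketch of the buffers set up here: U and noise both have shape
        # (rollout_batch_size * repeats, horiz, action_dim); each start state
        # is duplicated `repeats` times so MPPI can average over noisy rollouts.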
        #self.action_q = Queue()

    def _reset_mppi(self):
        self._init_mppi(horiz=self._rollout_length)

        # sample batch for init U with policy
        #batch = self.sampler.random_batch(self._rollout_batch_size)
        #obs = batch['observations'].repeat(self.repeats, axis=0)

        ## init every timestep with same action
        #for t in range(self._rollout_length):
        #    self.U[:,t] = self._policy.actions_np(obs)

    def _update_target(self, tau=None):
        tau = tau or self._tau

        for Q, Q_target in zip(self._Qs, self._Q_targets):
            source_params = Q.get_weights()
            target_params = Q_target.get_weights()
            Q_target.set_weights([
                tau * source + (1.0 - tau) * target
                for source, target in zip(source_params, target_params)
            ])

        source_params = self._Vs.get_weights()
        target_params = self._V_target.get_weights()
        self._V_target.set_weights([
            tau * source + (1.0 - tau) * target
            for source, target in zip(source_params, target_params)
        ])
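        # Polyak averaging: target <- tau * source + (1 - tau) * target; with
        # the default tau = 5e-3 the targets track the online networks with a
        # time constant of roughly 1 / tau = 200 updates.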

    def _do_training(self, iteration, batch):
        """Runs the operations for updating training and target ops."""

        self._training_progress.update()
        self._training_progress.set_description()

        feed_dict = self._get_feed_dict(iteration, batch)

        self._session.run(self._training_ops, feed_dict)

        if iteration % self._target_update_interval == 0:
            # Run target ops here.
            self._update_target()

    def _get_feed_dict(self, iteration, batch):
        """Construct TensorFlow feed_dict from sample batch."""

        feed_dict = {
            self._observations_ph: batch['observations'],
            self._actions_ph: batch['actions'],
            self._next_observations_ph: batch['next_observations'],
            self._rewards_ph: batch['rewards'],
            self._terminals_ph: batch['terminals']
        }

        # feed_dict[self._cumrewards_ph] = batch['cumrewards']

        if self._store_extra_policy_info:
            feed_dict[self._log_pis_ph] = batch['log_pis']
            feed_dict[self._raw_actions_ph] = batch['raw_actions']

        if iteration is not None:
            feed_dict[self._iteration_ph] = iteration

        return feed_dict

    def get_diagnostics(self, iteration, batch, training_paths,
                        evaluation_paths):
        """Return diagnostic information as ordered dictionary.

        Records mean and standard deviation of Q-function and state
        value function, and TD-loss (mean squared Bellman error)
        for the sample batch.

        Also calls the `draw` method of the plotter, if plotter defined.
        """

        feed_dict = self._get_feed_dict(iteration, batch)

        (Q_values, Q_losses, alpha, global_step, actions) = self._session.run(
            (self._Q_values, self._Q_losses, self._alpha, self.global_step,
             self._actions_ph), feed_dict)

        (V_value, V_loss, alpha, global_step) = self._session.run(
            (self._V_value, self._V_loss, self._alpha, self.global_step),
            feed_dict)

        diagnostics = OrderedDict({
            'Q-avg': np.mean(Q_values),
            'Q-std': np.std(Q_values),
            'Q_loss': np.mean(Q_losses),
            'V-avg': np.mean(V_value),
            'V-std': np.std(V_value),
            'V_loss': np.mean(V_loss),
            'alpha': alpha,
        })

        policy_diagnostics = self._policy.get_diagnostics(
            batch['observations'])
        diagnostics.update({
            f'policy/{key}': value
            for key, value in policy_diagnostics.items()
        })

        if self._plotter:
            self._plotter.draw()

        return diagnostics

    @property
    def tf_saveables(self):
        saveables = {
            '_policy_optimizer': self._policy_optimizer,
            **{
                f'Q_optimizer_{i}': optimizer
                for i, optimizer in enumerate(self._Q_optimizers)
            },
            '_log_alpha': self._log_alpha,
        }

        if hasattr(self, '_alpha_optimizer'):
            saveables['_alpha_optimizer'] = self._alpha_optimizer

        return saveables