Example #1
class StochasticGaussianMLPPolicy(StochasticPolicy, LasagnePowered,
                                  Serializable):
    def __init__(
        self,
        env_spec,
        input_latent_vars=None,
        hidden_sizes=(32, 32),
        hidden_latent_vars=None,
        learn_std=True,
        init_std=1.0,
        hidden_nonlinearity=NL.tanh,
        output_nonlinearity=None,
    ):
        Serializable.quick_init(self, locals())
        assert isinstance(env_spec.action_space, Box)

        obs_dim = env_spec.observation_space.flat_dim
        action_dim = env_spec.action_space.flat_dim

        # create network
        mean_network = StochasticMLP(
            input_shape=(obs_dim, ),
            input_latent_vars=input_latent_vars,
            output_dim=action_dim,
            hidden_sizes=hidden_sizes,
            hidden_latent_vars=hidden_latent_vars,
            hidden_nonlinearity=hidden_nonlinearity,
            output_nonlinearity=output_nonlinearity,
        )

        l_mean = mean_network.output_layer
        obs_var = mean_network.input_layer.input_var

        l_log_std = ParamLayer(
            mean_network.input_layer,
            num_units=action_dim,
            param=lasagne.init.Constant(np.log(init_std)),
            name="output_log_std",
            trainable=learn_std,
        )

        self._mean_network = mean_network
        self._n_latent_layers = len(mean_network.latent_layers)
        self._l_mean = l_mean
        self._l_log_std = l_log_std

        LasagnePowered.__init__(self, [l_mean, l_log_std])
        super(StochasticGaussianMLPPolicy, self).__init__(env_spec)

        outputs = self.dist_info_sym(mean_network.input_var)
        latent_keys = sorted(
            set(outputs.keys()).difference({"mean", "log_std"}))

        extras = get_full_output(
            [self._l_mean, self._l_log_std] + self._mean_network.latent_layers
        )[1]
        latent_distributions = [
            extras[layer]["distribution"]
            for layer in self._mean_network.latent_layers
        ]

        self._latent_keys = latent_keys
        self._latent_distributions = latent_distributions
        self._dist = DiagonalGaussian(action_dim)

        self._f_dist_info = ext.compile_function(
            inputs=[obs_var],
            outputs=outputs,
        )
        self._f_dist_info_givens = None

    @property
    def latent_layers(self):
        return self._mean_network.latent_layers

    @property
    def latent_dims(self):
        return self._mean_network.latent_dims

    def dist_info(self, obs, state_infos=None):
        if state_infos is None or len(state_infos) == 0:
            return self._f_dist_info(obs)
        if self._f_dist_info_givens is None:
            # compile function
            obs_var = self._mean_network.input_var
            latent_keys = [
                "latent_%d" % idx for idx in range(self._n_latent_layers)
            ]
            latent_vars = [
                TT.matrix("latent_%d" % idx)
                for idx in range(self._n_latent_layers)
            ]
            latent_dict = dict(zip(latent_keys, latent_vars))
            self._f_dist_info_givens = ext.compile_function(
                inputs=[obs_var] + latent_vars,
                outputs=self.dist_info_sym(obs_var, latent_dict),
            )
        latent_vals = []
        for idx in range(self._n_latent_layers):
            latent_vals.append(state_infos["latent_%d" % idx])
        return self._f_dist_info_givens(*[obs] + latent_vals)

    def reset(self):
        # Here we would sample a latent variable and store it on the policy so
        # that it is available to the other methods.
        pass

    def dist_info_sym(self, obs_var, state_info_vars=None):
        if state_info_vars is not None:
            latent_givens = {
                latent_layer: state_info_vars["latent_%d" % idx]
                for idx, latent_layer in enumerate(
                    self._mean_network.latent_layers)
            }
            latent_dist_infos = dict()
            for idx, latent_layer in enumerate(
                    self._mean_network.latent_layers):
                cur_dist_info = dict()
                prefix = "latent_%d_" % idx
                for k, v in state_info_vars.items():
                    if k.startswith(prefix):
                        cur_dist_info[k[len(prefix):]] = v
                latent_dist_infos[latent_layer] = cur_dist_info
        else:
            latent_givens = dict()
            latent_dist_infos = dict()
        all_outputs, extras = get_full_output(
            [self._l_mean, self._l_log_std] + self._mean_network.latent_layers,
            inputs={self._mean_network._l_in: obs_var},
            latent_givens=latent_givens,
            latent_dist_infos=latent_dist_infos,
        )

        mean_var = all_outputs[0]
        log_std_var = all_outputs[1]
        latent_vars = all_outputs[2:]
        latent_dist_info_list = []
        for latent_layer in self._mean_network.latent_layers:
            latent_dist_info_list.append(extras[latent_layer]["dist_info"])

        output_dict = dict(mean=mean_var, log_std=log_std_var)
        for idx, latent_var, latent_dist_info in zip(itertools.count(),
                                                     latent_vars,
                                                     latent_dist_info_list):
            output_dict["latent_%d" % idx] = latent_var
            for k, v in latent_dist_info.items():
                output_dict["latent_%d_%s" % (idx, k)] = v

        return output_dict

    def kl_sym(self, old_dist_info_vars, new_dist_info_vars):
        """
        Compute the symbolic KL divergence of the distributions over both the actions and the latent variables.
        """
        kl = self._dist.kl_sym(old_dist_info_vars, new_dist_info_vars)
        for idx, latent_dist in enumerate(self._latent_distributions):
            # collect dist info for each latent variable
            prefix = "latent_%d_" % idx
            old_latent_dist_info = {
                k[len(prefix):]: v
                for k, v in old_dist_info_vars.items() if k.startswith(prefix)
            }
            new_latent_dist_info = {
                k[len(prefix):]: v
                for k, v in new_dist_info_vars.items() if k.startswith(prefix)
            }
            kl += latent_dist.kl_sym(old_latent_dist_info,
                                     new_latent_dist_info)
        return kl

    def likelihood_ratio_sym(self, action_var, old_dist_info_vars,
                             new_dist_info_vars):
        """
        Compute the symbolic likelihood ratio of both the actions and the latent variables.
        """
        lr = self._dist.likelihood_ratio_sym(action_var, old_dist_info_vars,
                                             new_dist_info_vars)
        for idx, latent_dist in enumerate(self._latent_distributions):
            latent_var = old_dist_info_vars["latent_%d" % idx]
            prefix = "latent_%d_" % idx
            old_latent_dist_info = {
                k[len(prefix):]: v
                for k, v in old_dist_info_vars.items() if k.startswith(prefix)
            }
            new_latent_dist_info = {
                k[len(prefix):]: v
                for k, v in new_dist_info_vars.items() if k.startswith(prefix)
            }
            lr *= latent_dist.likelihood_ratio_sym(latent_var,
                                                   old_latent_dist_info,
                                                   new_latent_dist_info)
        return lr

    def log_likelihood(self, actions, dist_info, action_only=False):
        """
        Compute the log likelihood of both the actions and the latent variables, unless action_only is set to
        True, in which case only the log likelihood of the actions is computed.
        """
        logli = self._dist.log_likelihood(actions, dist_info)
        if not action_only:
            for idx, latent_dist in enumerate(self._latent_distributions):
                latent_var = dist_info["latent_%d" % idx]
                prefix = "latent_%d_" % idx
                latent_dist_info = {
                    k[len(prefix):]: v
                    for k, v in dist_info.items() if k.startswith(prefix)
                }
                logli += latent_dist.log_likelihood(latent_var,
                                                    latent_dist_info)
        return logli

    def log_likelihood_sym(self, action_var, dist_info_vars):
        logli = self._dist.log_likelihood_sym(action_var, dist_info_vars)
        for idx, latent_dist in enumerate(self._latent_distributions):
            latent_var = dist_info_vars["latent_%d" % idx]
            prefix = "latent_%d_" % idx
            latent_dist_info = {
                k[len(prefix):]: v
                for k, v in dist_info_vars.items() if k.startswith(prefix)
            }
            logli += latent_dist.log_likelihood_sym(latent_var,
                                                    latent_dist_info)
        return logli

    def entropy(self, dist_info):
        ent = self._dist.entropy(dist_info)
        for idx, latent_dist in enumerate(self._latent_distributions):
            prefix = "latent_%d_" % idx
            latent_dist_info = {
                k[len(prefix):]: v
                for k, v in dist_info.items() if k.startswith(prefix)
            }
            ent += latent_dist.entropy(latent_dist_info)
        return ent

    @property
    def dist_info_keys(self):
        return ["mean", "log_std"] + self._latent_keys

    @overrides
    def get_action(self, observation):
        actions, outputs = self.get_actions([observation])
        return actions[0], {k: v[0] for k, v in outputs.items()}

    def get_actions(self, observations):
        outputs = self._f_dist_info(observations)
        mean = outputs["mean"]
        log_std = outputs["log_std"]
        rnd = np.random.normal(size=mean.shape)
        actions = rnd * np.exp(log_std) + mean
        return actions, outputs

    def log_diagnostics(self, paths):
        log_stds = np.vstack(
            [path["agent_infos"]["log_std"] for path in paths])
        logger.record_tabular('AveragePolicyStd', np.mean(np.exp(log_stds)))

    @property
    def distribution(self):
        """
        We set the distribution to the policy itself since we need some behavior different from a usual diagonal
        Gaussian distribution.
        """
        return self

    @property
    def state_info_keys(self):
        return self._latent_keys
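
A minimal usage sketch for the policy above, assuming an rllab-style environment with a Box action space; the environment class `MyBoxEnv` and the argument values below are illustrative, not part of the original code:

env = MyBoxEnv()              # hypothetical rllab-style env with a Box action space
policy = StochasticGaussianMLPPolicy(
    env_spec=env.spec,
    input_latent_vars=None,   # latent specification consumed by StochasticMLP
    hidden_sizes=(32, 32),
    init_std=1.0,
)
obs = env.reset()
policy.reset()                # would (re)sample latents once implemented
action, agent_info = policy.get_action(obs)
# agent_info contains "mean", "log_std" and any "latent_*" entries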
Example #2
class Hippo(BatchPolopt):
    def __init__(
            self,
            optimizer=None,
            optimizer_args=None,
            step_size=0.0003,
            latents=None,  # an iterable of the actual latent vectors
            average_period=10,  # average over all the periods
            truncate_local_is_ratio=None,
            epsilon=0.1,
            train_pi_iters=80,
            use_skill_dependent_baseline=False,
            mlp_skill_dependent_baseline=False,
            **kwargs):
        if optimizer is None:
            if optimizer_args is None:
                optimizer_args = dict(batch_size=None)
            optimizer = FirstOrderOptimizer(learning_rate=step_size,
                                            max_epochs=train_pi_iters,
                                            **optimizer_args)
        self.optimizer = optimizer
        self.step_size = step_size
        self.truncate_local_is_ratio = truncate_local_is_ratio
        self.epsilon = epsilon

        super(Hippo,
              self).__init__(**kwargs)  # not sure if this line is correct
        self.num_latents = kwargs['policy'].latent_dim
        self.latents = latents
        self.average_period = average_period

        # import pdb; pdb.set_trace()
        # self.sampler = BatchSampler(self)
        self.sampler = HierBatchSampler(self, period=None)

        # diagonal Gaussian over the low-level policy's action space
        self.diagonal = DiagonalGaussian(
            self.policy.low_policy.action_space.flat_dim)
        self.debug_fns = []

        assert isinstance(self.policy, HierarchicalPolicyRandomTime)
        # self.old_policy = copy.deepcopy(self.policy)

        # skill dependent baseline
        self.use_skill_dependent_baseline = use_skill_dependent_baseline
        self.mlp_skill_dependent_baseline = mlp_skill_dependent_baseline
        if use_skill_dependent_baseline:
            curr_env = kwargs['env']
            skill_dependent_action_space = curr_env.action_space
            new_obs_space_no_bi = curr_env.observation_space.shape[
                0] + 1  # 1 for the t_remaining
            skill_dependent_obs_space_dim = (new_obs_space_no_bi *
                                             (self.num_latents + 1) +
                                             self.num_latents, )
            skill_dependent_obs_space = Box(
                -1.0, 1.0, shape=skill_dependent_obs_space_dim)
            skill_dependent_env_spec = EnvSpec(skill_dependent_obs_space,
                                               skill_dependent_action_space)
            if self.mlp_skill_dependent_baseline:
                self.skill_dependent_baseline = GaussianMLPBaseline(
                    env_spec=skill_dependent_env_spec)
            else:
                self.skill_dependent_baseline = LinearFeatureBaseline(
                    env_spec=skill_dependent_env_spec)

    def init_opt(self):
        obs_var = ext.new_tensor(
            'obs', ndim=2, dtype=theano.config.floatX)  # todo: check the dtype

        manager_obs_var = ext.new_tensor('manager_obs',
                                         ndim=2,
                                         dtype=theano.config.floatX)

        action_var = self.env.action_space.new_tensor_variable(
            'action',
            extra_dims=1,
        )

        # this will have to be the advantage every time the manager makes a decision
        manager_advantage_var = ext.new_tensor('manager_advantage',
                                               ndim=1,
                                               dtype=theano.config.floatX)

        skill_advantage_var = ext.new_tensor('skill_advantage',
                                             ndim=1,
                                             dtype=theano.config.floatX)

        latent_var_sparse = ext.new_tensor('sparse_latent',
                                           ndim=2,
                                           dtype=theano.config.floatX)

        latent_var = ext.new_tensor('latents',
                                    ndim=2,
                                    dtype=theano.config.floatX)

        mean_var = ext.new_tensor('mean', ndim=2, dtype=theano.config.floatX)

        log_std_var = ext.new_tensor('log_std',
                                     ndim=2,
                                     dtype=theano.config.floatX)

        manager_prob_var = ext.new_tensor('manager_prob',
                                          ndim=2,
                                          dtype=theano.config.floatX)

        assert isinstance(self.policy, HierarchicalPolicy)

        #############################################################
        ### calculating the manager portion of the surrogate loss ###
        #############################################################

        # entry (i, j) is the probability of latent j at the i-th manager decision step
        # shape: (number of manager decisions) x (number of latents)
        latent_probs = self.policy.manager.dist_info_sym(
            manager_obs_var)['prob']
        # old_latent_probs = self.old_policy.manager.dist_info_sym(manager_obs_var)['prob']

        actual_latent_probs = TT.sum(latent_probs * latent_var_sparse, axis=1)
        old_actual_latent_probs = TT.sum(manager_prob_var * latent_var_sparse,
                                         axis=1)
        lr = TT.exp(
            TT.log(actual_latent_probs) - TT.log(old_actual_latent_probs))
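        # PPO clipped surrogate: E[min(r * A, clip(r, 1 - eps, 1 + eps) * A)],
        # negated below because the optimizer minimizes a loss.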
        manager_surr_loss_vector = TT.minimum(
            lr * manager_advantage_var,
            TT.clip(lr, 1 - self.epsilon, 1 + self.epsilon) *
            manager_advantage_var)
        manager_surr_loss = -TT.mean(manager_surr_loss_vector)

        ############################################################
        ### calculating the skills portion of the surrogate loss ###
        ############################################################

        dist_info_var = self.policy.low_policy.dist_info_sym(
            obs_var, state_info_var=latent_var)
        old_dist_info_var = dict(mean=mean_var, log_std=log_std_var)
        skill_lr = self.diagonal.likelihood_ratio_sym(action_var,
                                                      old_dist_info_var,
                                                      dist_info_var)

        skill_surr_loss_vector = TT.minimum(
            skill_lr * skill_advantage_var,
            TT.clip(skill_lr, 1 - self.epsilon, 1 + self.epsilon) *
            skill_advantage_var)
        skill_surr_loss = -TT.mean(skill_surr_loss_vector)

        surr_loss = manager_surr_loss / self.average_period + skill_surr_loss

        input_list = [
            obs_var, manager_obs_var, action_var, manager_advantage_var,
            skill_advantage_var, latent_var, latent_var_sparse, mean_var,
            log_std_var, manager_prob_var
        ]

        self.optimizer.update_opt(loss=surr_loss,
                                  target=self.policy,
                                  inputs=input_list)
        return dict()

    # do the optimization
    def optimize_policy(self, itr, samples_data):
        # print(len(samples_data['observations']), self.period)
        # assert len(samples_data['observations']) % self.period == 0

        # note that I have to do extra preprocessing to the advantages, and also create obs_var_sparse
        if self.use_skill_dependent_baseline:
            input_values = tuple(
                ext.extract(samples_data, "observations", "actions",
                            "advantages", "agent_infos", "skill_advantages"))
        else:
            input_values = tuple(
                ext.extract(samples_data, "observations", "actions",
                            "advantages", "agent_infos"))

        time_remaining = input_values[3]['time_remaining']
        resampled_period = input_values[3]['resampled_period']
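        # insert the remaining-time feature right after the robot part of the observation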
        obs_var = np.insert(input_values[0],
                            self.policy.obs_robot_dim,
                            time_remaining,
                            axis=1)
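        # keep only the timesteps at which the manager re-sampled a latent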
        manager_obs_var = obs_var[resampled_period]
        action_var = input_values[1]
        manager_adv_var = input_values[2][resampled_period]

        latent_var = input_values[3]['latents']
        latent_var_sparse = latent_var[resampled_period]
        mean = input_values[3]['mean']
        log_std = input_values[3]['log_std']
        prob = input_values[3]['prob'][resampled_period]
        if self.use_skill_dependent_baseline:
            skill_adv_var = input_values[4]
            all_input_values = (obs_var, manager_obs_var, action_var,
                                manager_adv_var, skill_adv_var, latent_var,
                                latent_var_sparse, mean, log_std, prob)
        else:
            skill_adv_var = input_values[2]
            all_input_values = (obs_var, manager_obs_var, action_var,
                                manager_adv_var, skill_adv_var, latent_var,
                                latent_var_sparse, mean, log_std, prob)

        # todo: assign current parameters to old policy; does this work?
        # old_param_values = self.policy.get_param_values()
        # self.old_policy.set_param_values(old_param_values)
        loss_before = self.optimizer.loss(all_input_values)
        self.optimizer.optimize(all_input_values)
        loss_after = self.optimizer.loss(all_input_values)
        logger.record_tabular('LossBefore', loss_before)
        logger.record_tabular('LossAfter', loss_after)
        logger.record_tabular('dLoss', loss_before - loss_after)
        return dict()

    def get_itr_snapshot(self, itr, samples_data):
        return dict(itr=itr,
                    policy=self.policy,
                    baseline=self.baseline,
                    env=self.env)

    def log_diagnostics(self, paths):
        # paths obtained by self.sampler.obtain_samples
        BatchPolopt.log_diagnostics(self, paths)
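
A construction sketch for Hippo, assuming an rllab-style environment and a pre-built HierarchicalPolicyRandomTime instance; the variable names `env` and `policy` and the numeric values are illustrative:

algo = Hippo(
    env=env,                      # hypothetical environment instance
    policy=policy,                # a HierarchicalPolicyRandomTime
    baseline=LinearFeatureBaseline(env_spec=env.spec),
    batch_size=10000,
    max_path_length=500,
    n_itr=100,
    step_size=0.0003,
    average_period=10,
    train_pi_iters=80,
    use_skill_dependent_baseline=False,
)
algo.train()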
Example #3
class ConcurrentContinuousPPO(BatchPolopt):
    """
    Designed to concurrently train an SNN that parameterizes the skills and the manager
    that selects among them.

    Note that, unless we do a sample approximation of the log-of-sum term, we do not need
    to know which skill was picked, only the action that was taken.
    """

    # double check this constructor later
    def __init__(
            self,
            optimizer=None,
            optimizer_args=None,
            step_size=0.003,
            num_latents=6,
            latents=None,  # an iterable of the actual latent vectors
            period=10,  # how often I choose a latent
            truncate_local_is_ratio=None,
            epsilon=0.1,
            train_pi_iters=10,
            use_skill_dependent_baseline=False,
            mlp_skill_dependent_baseline=False,
            freeze_manager=False,
            freeze_skills=False,
            **kwargs):
        if optimizer is None:
            if optimizer_args is None:
                optimizer_args = dict(batch_size=None)
            optimizer = FirstOrderOptimizer(learning_rate=step_size,
                                            max_epochs=train_pi_iters,
                                            **optimizer_args)
        self.optimizer = optimizer
        self.step_size = step_size
        self.truncate_local_is_ratio = truncate_local_is_ratio
        self.epsilon = epsilon

        super(ConcurrentContinuousPPO,
              self).__init__(**kwargs)  # not sure if this line is correct
        self.num_latents = kwargs['policy'].latent_dim
        self.latents = latents
        self.period = period
        self.freeze_manager = freeze_manager
        self.freeze_skills = freeze_skills
        assert (not freeze_manager) or (not freeze_skills)

        # todo: fix this sampler stuff
        # import pdb; pdb.set_trace()
        self.sampler = HierBatchSampler(self, self.period)
        # self.sampler = BatchSampler(self)
        # diagonal Gaussian over the low-level policy's action space
        self.diagonal = DiagonalGaussian(
            self.policy.low_policy.action_space.flat_dim)
        self.debug_fns = []

        assert isinstance(self.policy, HierarchicalPolicy)
        assert self.policy.period == self.period  # the period passed in must match the policy's
        self.period = self.policy.period
        self.continuous_latent = self.policy.continuous_latent
        assert self.continuous_latent
        # self.old_policy = copy.deepcopy(self.policy)

        # skill dependent baseline
        self.use_skill_dependent_baseline = use_skill_dependent_baseline
        self.mlp_skill_dependent_baseline = mlp_skill_dependent_baseline
        if use_skill_dependent_baseline:
            curr_env = kwargs['env']
            skill_dependent_action_space = curr_env.action_space
            new_obs_space_no_bi = curr_env.observation_space.shape[
                0] + 1  # 1 for the t_remaining
            skill_dependent_obs_space_dim = (new_obs_space_no_bi *
                                             (self.num_latents + 1) +
                                             self.num_latents, )
            skill_dependent_obs_space = Box(
                -1.0, 1.0, shape=skill_dependent_obs_space_dim)
            skill_dependent_env_spec = EnvSpec(skill_dependent_obs_space,
                                               skill_dependent_action_space)
            if self.mlp_skill_dependent_baseline:
                self.skill_dependent_baseline = GaussianMLPBaseline(
                    env_spec=skill_dependent_env_spec)
            else:
                self.skill_dependent_baseline = LinearFeatureBaseline(
                    env_spec=skill_dependent_env_spec)

    # initialize the computation graph
    # optimize is run on >= 1 trajectory at a time
    # assumptions: a single trajectory whose length is a multiple of the period p, and that obs_var_probs is valid
    def init_opt(self):
        assert isinstance(self.policy, HierarchicalPolicy)
        assert not self.freeze_manager and not self.freeze_skills
        manager_surr_loss = 0
        # skill_surr_loss = 0

        obs_var_sparse = ext.new_tensor('sparse_obs',
                                        ndim=2,
                                        dtype=theano.config.floatX)
        obs_var_raw = ext.new_tensor(
            'obs', ndim=3, dtype=theano.config.floatX)  # todo: check the dtype
        action_var = self.env.action_space.new_tensor_variable(
            'action',
            extra_dims=1,
        )
        advantage_var = ext.new_tensor('advantage',
                                       ndim=1,
                                       dtype=theano.config.floatX)
        # latent_var = ext.new_tensor('latents', ndim=2, dtype=theano.config.floatX)
        mean_var = ext.new_tensor('mean', ndim=2, dtype=theano.config.floatX)
        log_std_var = ext.new_tensor('log_std',
                                     ndim=2,
                                     dtype=theano.config.floatX)

        # undoing the reshape, so that batch sampling is ok
        obs_var = TT.reshape(obs_var_raw, [
            obs_var_raw.shape[0] * obs_var_raw.shape[1], obs_var_raw.shape[2]
        ])

        ############################################################
        ### calculating the skills portion of the surrogate loss ###
        ############################################################
        latent_var_sparse = self.policy.manager.dist_info_sym(
            obs_var_sparse)['mean']
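        # repeat each per-period latent `period` times so it lines up with the
        # per-timestep observations fed to the low-level policy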
        latent_var = TT.extra_ops.repeat(latent_var_sparse,
                                         self.period,
                                         axis=0)  #.dimshuffle(0, 'x')
        dist_info_var = self.policy.low_policy.dist_info_sym(
            obs_var, state_info_var=latent_var)
        old_dist_info_var = dict(mean=mean_var, log_std=log_std_var)
        skill_lr = self.diagonal.likelihood_ratio_sym(action_var,
                                                      old_dist_info_var,
                                                      dist_info_var)
        skill_surr_loss_vector = TT.minimum(
            skill_lr * advantage_var,
            TT.clip(skill_lr, 1 - self.epsilon, 1 + self.epsilon) *
            advantage_var)
        skill_surr_loss = -TT.mean(skill_surr_loss_vector)

        surr_loss = skill_surr_loss  # so that the relative magnitudes are correct

        if self.freeze_skills and not self.freeze_manager:
            raise NotImplementedError
        elif self.freeze_manager and not self.freeze_skills:
            raise NotImplementedError
        else:
            assert (not self.freeze_manager) or (not self.freeze_skills)
            input_list = [
                obs_var_raw, obs_var_sparse, action_var, advantage_var,
                mean_var, log_std_var
            ]

        self.optimizer.update_opt(loss=surr_loss,
                                  target=self.policy,
                                  inputs=input_list)
        return dict()

    # do the optimization
    def optimize_policy(self, itr, samples_data):
        print(len(samples_data['observations']), self.period)
        assert len(samples_data['observations']) % self.period == 0

        # note that I have to do extra preprocessing to the advantages, and also create obs_var_sparse

        if self.use_skill_dependent_baseline:
            input_values = tuple(
                ext.extract(samples_data, "observations", "actions",
                            "advantages", "agent_infos", "skill_advantages"))
        else:
            input_values = tuple(
                ext.extract(samples_data, "observations", "actions",
                            "advantages", "agent_infos"))

        obs_raw = input_values[0].reshape(
            input_values[0].shape[0] // self.period, self.period,
            input_values[0].shape[1])

        obs_sparse = input_values[0].take(
            [i for i in range(0, input_values[0].shape[0], self.period)],
            axis=0)
        if not self.continuous_latent:
            advantage_sparse = input_values[2].reshape(
                [input_values[2].shape[0] // self.period, self.period])[:, 0]
            latents = input_values[3]['latents']
            latents_sparse = latents.take(
                [i for i in range(0, latents.shape[0], self.period)], axis=0)
            prob = np.array(list(input_values[3]['prob'].take(
                [i for i in range(0, latents.shape[0], self.period)], axis=0)),
                            dtype=np.float32)
        mean = input_values[3]['mean']
        log_std = input_values[3]['log_std']

        if self.use_skill_dependent_baseline:
            advantage_var = input_values[4]
        else:
            advantage_var = input_values[2]
        # import ipdb; ipdb.set_trace()
        if self.freeze_skills and not self.freeze_manager:
            raise NotImplementedError
        elif self.freeze_manager and not self.freeze_skills:
            raise NotImplementedError
        else:
            assert (not self.freeze_manager) or (not self.freeze_skills)
            all_input_values = (obs_raw, obs_sparse, input_values[1],
                                advantage_var, mean, log_std)

        # todo: assign current parameters to old policy; does this work?
        # old_param_values = self.policy.get_param_values(trainable=True)
        # self.old_policy.set_param_values(old_param_values, trainable=True)
        # old_param_values = self.policy.get_param_values()
        # self.old_policy.set_param_values(old_param_values)
        loss_before = self.optimizer.loss(all_input_values)
        self.optimizer.optimize(all_input_values)
        loss_after = self.optimizer.loss(all_input_values)
        logger.record_tabular('LossBefore', loss_before)
        logger.record_tabular('LossAfter', loss_after)
        logger.record_tabular('dLoss', loss_before - loss_after)
        return dict()

    def get_itr_snapshot(self, itr, samples_data):
        return dict(itr=itr,
                    policy=self.policy,
                    baseline=self.baseline,
                    env=self.env)

    def log_diagnostics(self, paths):
        # paths obtained by self.sampler.obtain_samples
        BatchPolopt.log_diagnostics(self, paths)
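
The analogous sketch for ConcurrentContinuousPPO, assuming a HierarchicalPolicy built with continuous_latent=True; names and values are illustrative:

algo = ConcurrentContinuousPPO(
    env=env,
    policy=policy,                        # HierarchicalPolicy with continuous_latent=True
    baseline=LinearFeatureBaseline(env_spec=env.spec),
    batch_size=10000,
    max_path_length=500,
    n_itr=100,
    period=10,                            # should match policy.period
    epsilon=0.1,
    train_pi_iters=10,
    use_skill_dependent_baseline=True,    # builds a skill-conditioned baseline
    mlp_skill_dependent_baseline=True,    # GaussianMLPBaseline instead of LinearFeatureBaseline
)
algo.train()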