Example #1
0
class AIRL(SingleTimestepIRL):
    """Adversarial IRL discriminator whose log p_tau is the (unshaped) reward.

    A single "energy" network is trained as a discriminator between policy
    samples and expert demonstrations. The discriminator's reward is
    ``r = -energy`` (plus an optional action-magnitude term). Value-function
    shaping is disabled in this variant: ``value_fn`` is a constant zero
    tensor and ``log p_tau = r``.

    Args:
        env: Environment; ``env.spec`` supplies observation/action dimensions.
        expert_trajs: Expert demonstration paths, passed to ``set_demos``.
        discrim_arch: Callable ``(input, dout=1, **kwargs)`` building the
            energy network.
        discrim_arch_args (dict): Extra keyword arguments for
            ``discrim_arch``. Defaults to an empty dict.
        normalize_reward (bool): Stored as ``self.normalize_reward``; not read
            by this class.
        score_dtau (bool): If True, ``eval`` scores with ``log D - log(1 - D)``;
            otherwise with ``-energy``.
        init_itrs: Stored as ``self.init_itrs``; not read by this class.
        discount (float): Stored as ``self.gamma``.
        l2_reg (float): Weight of an L2 penalty on variables in the
            ``'reg_vars'`` collection under the discriminator scope (off at 0).
        state_only (bool): If True the energy depends on observations only.
        shaping_with_actions (bool): Accepted for backward compatibility; it
            has no effect because value-function shaping is disabled.
        max_itrs (int): Discriminator updates per ``fit`` call.
        fusion (bool): If True, mix in paths from earlier iterations via
            ``RamFusionDistr``.
        fusion_subsample (float): Subsample ratio for the fusion buffer.
        action_penalty (float): If > 0, ``action_penalty * sum(a**2)`` is
            added to the discriminator's reward ``r``.
        name (str): TensorFlow variable scope for all created ops.

    Raises:
        ValueError: If ``state_only`` is False and the action space is not a
            ``Box``.
    """

    def __init__(self,
                 env,
                 expert_trajs=None,
                 discrim_arch=relu_net,
                 discrim_arch_args=None,
                 normalize_reward=False,
                 score_dtau=False,
                 init_itrs=None,
                 discount=1.0,
                 l2_reg=0,
                 state_only=False,
                 shaping_with_actions=False,
                 max_itrs=100,
                 fusion=False,
                 fusion_subsample=0.5,
                 action_penalty=0.0,
                 name='trajprior'):
        super(AIRL, self).__init__()
        if discrim_arch_args is None:
            # Fresh dict per instance (avoids a shared mutable default).
            discrim_arch_args = {}
        env_spec = env.spec
        if fusion:
            self.fusion = RamFusionDistr(100, subsample_ratio=fusion_subsample)
        else:
            self.fusion = None
        self.dO = env_spec.observation_space.flat_dim
        self.dU = env_spec.action_space.flat_dim
        self.continuous = isinstance(env.action_space, Box)
        self.normalize_reward = normalize_reward
        self.score_dtau = score_dtau
        self.init_itrs = init_itrs
        self.gamma = discount
        self.set_demos(expert_trajs)
        self.state_only = state_only
        self.max_itrs = max_itrs

        # build energy model
        with tf.variable_scope(name) as _vs:
            # One row per timestep: (batch, dO) observations, (batch, dU) actions.
            self.obs_t = tf.placeholder(tf.float32, [None, self.dO],
                                        name='obs')
            self.nobs_t = tf.placeholder(tf.float32, [None, self.dO],
                                         name='nobs')
            self.act_t = tf.placeholder(tf.float32, [None, self.dU],
                                        name='act')
            self.nact_t = tf.placeholder(tf.float32, [None, self.dU],
                                         name='nact')
            self.labels = tf.placeholder(tf.float32, [None, 1], name='labels')
            self.lprobs = tf.placeholder(tf.float32, [None, 1],
                                         name='log_probs')
            self.lr = tf.placeholder(tf.float32, (), name='lr')

            with tf.variable_scope('discrim') as dvs:
                if self.state_only:
                    energy_input = self.obs_t
                elif self.continuous:
                    energy_input = tf.concat([self.obs_t, self.act_t], axis=1)
                else:
                    raise ValueError(
                        'AIRL requires a continuous (Box) action space '
                        'unless state_only=True')
                with tf.variable_scope('energy'):
                    # reward function (or q-function)
                    self.energy = discrim_arch(energy_input,
                                               dout=1,
                                               **discrim_arch_args)

                # Value-function shaping is disabled: V(s) == 0, so
                # log p_tau(a|s) = r + gamma * V(s') - V(s) reduces to r.
                self.value_fn = tf.zeros(shape=[])

                if action_penalty > 0:
                    # NOTE(review): the term is *added* to r, which only feeds
                    # the discriminator loss; eval() scores with -energy and so
                    # excludes it. Confirm the sign is intended.
                    self.r = r = -self.energy + action_penalty * tf.reduce_sum(
                        tf.square(self.act_t), axis=1, keepdims=True)
                else:
                    self.r = r = -self.energy

                self.qfn = r
                log_p_tau = r
                discrim_vars = tf.get_collection('reg_vars', scope=dvs.name)

            log_q_tau = self.lprobs

            if l2_reg > 0:
                reg_loss = l2_reg * tf.reduce_sum(
                    [tf.reduce_sum(tf.square(var)) for var in discrim_vars])
            else:
                reg_loss = 0

            # D(tau) = p / (p + q), computed in log space for stability.
            log_pq = tf.reduce_logsumexp([log_p_tau, log_q_tau], axis=0)
            self.d_tau = tf.exp(log_p_tau - log_pq)
            # Cross-entropy; fit() labels expert rows 1 and policy rows 0.
            cent_loss = -tf.reduce_mean(self.labels * (log_p_tau - log_pq) +
                                        (1 - self.labels) *
                                        (log_q_tau - log_pq))

            self.loss = cent_loss
            tot_loss = self.loss + reg_loss
            self.step = tf.train.AdamOptimizer(
                learning_rate=self.lr).minimize(tot_loss)
            self._make_param_ops(_vs)

    def fit(self,
            paths,
            policy=None,
            batch_size=32,
            logger=None,
            lr=1e-3,
            last_timestep_only=False,
            **kwargs):
        """Run ``max_itrs`` discriminator updates on policy vs. expert data.

        Args:
            paths: Policy sample paths (mutated: log-probs and next-state keys
                are inserted).
            policy: Policy used to evaluate expert action log-probs.
            batch_size (int): Samples drawn from each of policy and expert
                data per update.
            logger: Optional tabular logger for diagnostics.
            lr (float): Adam learning rate.
            last_timestep_only (bool): Forwarded to ``extract_paths``.

        Returns:
            The most recent ``it.pop_mean('loss')`` value, or NaN if no
            heartbeat fired during training.
        """
        if self.fusion is not None:
            old_paths = self.fusion.sample_paths(n=len(paths))
            self.fusion.add_paths(paths)
            paths = paths + old_paths

        self._compute_path_probs(paths, insert=True)

        # Remove previously inserted policy info from the demos so that
        # eval_expert_probs (insert=True) re-inserts it for the current policy.
        for traj in self.expert_trajs:
            if 'agent_infos' in traj:
                del traj['agent_infos']
                traj.pop('a_logprobs', None)
        self.eval_expert_probs(self.expert_trajs, policy, insert=True)

        self._insert_next_state(paths)
        self._insert_next_state(self.expert_trajs)

        obs, obs_next, acts, acts_next, path_probs = self.extract_paths(
            paths,
            keys=('observations', 'observations_next', 'actions',
                  'actions_next', 'a_logprobs'),
            last_timestep_only=last_timestep_only)
        expert_obs, expert_obs_next, expert_acts, expert_acts_next, expert_probs = self.extract_paths(
            self.expert_trajs,
            keys=('observations', 'observations_next', 'actions',
                  'actions_next', 'a_logprobs'),
            last_timestep_only=last_timestep_only)

        # Defined up front so the return below never hits an unbound name.
        mean_loss = float('nan')

        # Train discriminator
        for it in TrainingIterator(self.max_itrs, heartbeat=5):
            nobs_batch, obs_batch, nact_batch, act_batch, lprobs_batch = \
                self.sample_batch(obs_next, obs, acts_next, acts, path_probs, batch_size=batch_size)

            nexpert_obs_batch, expert_obs_batch, nexpert_act_batch, expert_act_batch, expert_lprobs_batch = \
                self.sample_batch(expert_obs_next, expert_obs, expert_acts_next, expert_acts, expert_probs, batch_size=batch_size)

            # Policy rows first (label 0), expert rows second (label 1).
            labels = np.zeros((batch_size * 2, 1))
            labels[batch_size:] = 1.0
            obs_batch = np.concatenate([obs_batch, expert_obs_batch], axis=0)
            nobs_batch = np.concatenate([nobs_batch, nexpert_obs_batch],
                                        axis=0)
            act_batch = np.concatenate([act_batch, expert_act_batch], axis=0)
            nact_batch = np.concatenate([nact_batch, nexpert_act_batch],
                                        axis=0)
            lprobs_batch = np.expand_dims(
                np.concatenate([lprobs_batch, expert_lprobs_batch], axis=0),
                axis=1).astype(np.float32)

            feed_dict = {
                self.act_t: act_batch,
                self.obs_t: obs_batch,
                self.nobs_t: nobs_batch,
                self.nact_t: nact_batch,
                self.labels: labels,
                self.lprobs: lprobs_batch,
                self.lr: lr
            }
            loss, _ = tf.get_default_session().run([self.loss, self.step],
                                                   feed_dict=feed_dict)

            it.record('loss', loss)
            if it.heartbeat:
                print(it.itr_message())
                mean_loss = it.pop_mean('loss')
                print('\tLoss:%f' % mean_loss)

        if logger:
            logger.record_tabular('GCLDiscrimLoss', mean_loss)
            energy, logZ, dtau = tf.get_default_session().run(
                [self.energy, self.value_fn, self.d_tau],
                feed_dict={
                    self.act_t: acts,
                    self.obs_t: obs,
                    self.nobs_t: obs_next,
                    self.nact_t: acts_next,
                    self.lprobs: np.expand_dims(path_probs, axis=1)
                })
            logger.record_tabular('GCLLogZ', np.mean(logZ))
            logger.record_tabular('GCLAverageEnergy', np.mean(energy))
            logger.record_tabular('GCLAverageLogPtau', np.mean(-energy - logZ))
            logger.record_tabular('GCLAverageLogQtau', np.mean(path_probs))
            logger.record_tabular('GCLMedianLogQtau', np.median(path_probs))
            logger.record_tabular('GCLAverageDtau', np.mean(dtau))

            energy, logZ, dtau = tf.get_default_session().run(
                [self.energy, self.value_fn, self.d_tau],
                feed_dict={
                    self.act_t: expert_acts,
                    self.obs_t: expert_obs,
                    self.nobs_t: expert_obs_next,
                    self.nact_t: expert_acts_next,
                    self.lprobs: np.expand_dims(expert_probs, axis=1)
                })
            logger.record_tabular('GCLAverageExpertEnergy', np.mean(energy))
            logger.record_tabular('GCLAverageExpertLogPtau',
                                  np.mean(-energy - logZ))
            logger.record_tabular('GCLAverageExpertLogQtau',
                                  np.mean(expert_probs))
            logger.record_tabular('GCLMedianExpertLogQtau',
                                  np.median(expert_probs))
            logger.record_tabular('GCLAverageExpertDtau', np.mean(dtau))
        return mean_loss

    def eval(self, paths, gamma=1.0, **kwargs):
        """Return the per-step reward bonus for ``paths``, unpacked per path.

        With ``score_dtau`` the score is ``log D - log(1 - D)``. D is clipped
        away from 0 and 1 to avoid inf/nan, which matches the other AIRL
        variant. Otherwise the score is ``-energy``. ``gamma`` is unused.
        """
        if self.score_dtau:
            self._compute_path_probs(paths, insert=True)
            # Same call fit() makes; ensures 'observations_next' exists.
            self._insert_next_state(paths)
            obs, obs_next, acts, path_probs = self.extract_paths(
                paths,
                keys=('observations', 'observations_next', 'actions',
                      'a_logprobs'))
            path_probs = np.expand_dims(path_probs, axis=1)
            scores = tf.get_default_session().run(self.d_tau,
                                                  feed_dict={
                                                      self.act_t: acts,
                                                      self.obs_t: obs,
                                                      self.nobs_t: obs_next,
                                                      self.lprobs: path_probs
                                                  })
            score = (np.log(np.maximum(scores, 1e-8)) -
                     np.log(np.maximum(1 - scores, 1e-8)))
            score = score[:, 0]
        else:
            obs, acts = self.extract_paths(paths)
            energy = tf.get_default_session().run(self.energy,
                                                  feed_dict={
                                                      self.act_t: acts,
                                                      self.obs_t: obs
                                                  })
            score = (-energy)[:, 0]
        return self.unpack(score, paths)

    def eval_discrim(self, paths):
        """Return the raw discriminator output D(tau) per step, unpacked."""
        self._compute_path_probs(paths, insert=True)
        # Same call fit() makes; ensures 'observations_next' exists.
        self._insert_next_state(paths)
        obs, obs_next, acts, path_probs = self.extract_paths(
            paths,
            keys=('observations', 'observations_next', 'actions',
                  'a_logprobs'))
        path_probs = np.expand_dims(path_probs, axis=1)
        scores = tf.get_default_session().run(self.d_tau,
                                              feed_dict={
                                                  self.act_t: acts,
                                                  self.obs_t: obs,
                                                  self.nobs_t: obs_next,
                                                  self.lprobs: path_probs
                                              })
        score = scores[:, 0]
        return self.unpack(score, paths)

    def eval_single(self, obs):
        """Return ``-energy`` for a batch of observations (feeds obs only).

        Only ``obs_t`` is fed, so this works when the energy does not depend
        on actions (``state_only=True``).
        """
        energy = tf.get_default_session().run(self.energy,
                                              feed_dict={self.obs_t: obs})
        score = (-energy)[:, 0]
        return score

    def debug_eval(self, paths, **kwargs):
        """Return reward (``-energy``), value and Q tensors for ``paths``."""
        obs, acts = self.extract_paths(paths)
        energy, v, qfn = tf.get_default_session().run(
            [self.energy, self.value_fn, self.qfn],
            feed_dict={
                self.act_t: acts,
                self.obs_t: obs
            })
        return {
            'reward': -energy,
            'value': v,
            'qfn': qfn,
        }
Example #2
0
class AIRL(SingleTimestepIRL):
    """Adversarial IRL with a learned reward and value-function shaping.

    The discriminator uses
    ``log p_tau = r(s[, a]) + gamma * V(s') - V(s)``
    against the policy's log-probabilities ``log q_tau``.

    Args:
        env: Environment; ``env.spec`` supplies observation/action dims and
            ``env.action_space`` must be a ``Box``.
        expert_trajs: Expert demonstration paths, passed to ``set_demos``.
        reward_arch: Callable building the reward network.
        reward_arch_args (dict): Extra kwargs for ``reward_arch``.
        value_fn_arch: Callable building the shaping value network.
        score_discrim (bool): Use log D - log 1-D as reward (if true you should
            not need to use an entropy bonus).
        discount (float): Discount ``gamma`` used in the shaping term.
        state_only (bool): Fix the learned reward to only depend on state.
        max_itrs (int): Number of training iterations to run per fit step.
        fusion (bool): Use trajectories from old iterations to train.
        name (str): TensorFlow variable scope for all created ops.
    """

    def __init__(self, env,
                 expert_trajs=None,
                 reward_arch=relu_net,
                 reward_arch_args=None,
                 value_fn_arch=relu_net,
                 score_discrim=False,
                 discount=1.0,
                 state_only=False,
                 max_itrs=100,
                 fusion=False,
                 name='airl'):
        super(AIRL, self).__init__()
        env_spec = env.spec
        if reward_arch_args is None:
            reward_arch_args = {}

        if fusion:
            self.fusion = RamFusionDistr(100, subsample_ratio=0.5)
        else:
            self.fusion = None
        self.dO = env_spec.observation_space.flat_dim
        self.dU = env_spec.action_space.flat_dim
        assert isinstance(env.action_space, Box)
        self.score_discrim = score_discrim
        self.gamma = discount
        assert value_fn_arch is not None
        self.set_demos(expert_trajs)
        self.state_only = state_only
        self.max_itrs = max_itrs

        # build energy model
        with tf.variable_scope(name) as _vs:
            # One row per timestep: (batch, dO) observations, (batch, dU) actions.
            self.obs_t = tf.placeholder(tf.float32, [None, self.dO], name='obs')
            self.nobs_t = tf.placeholder(tf.float32, [None, self.dO], name='nobs')
            self.act_t = tf.placeholder(tf.float32, [None, self.dU], name='act')
            self.nact_t = tf.placeholder(tf.float32, [None, self.dU], name='nact')
            self.labels = tf.placeholder(tf.float32, [None, 1], name='labels')
            self.lprobs = tf.placeholder(tf.float32, [None, 1], name='log_probs')
            self.lr = tf.placeholder(tf.float32, (), name='lr')

            with tf.variable_scope('discrim'):
                rew_input = self.obs_t
                if not self.state_only:
                    rew_input = tf.concat([self.obs_t, self.act_t], axis=1)
                with tf.variable_scope('reward'):
                    self.reward = reward_arch(rew_input, dout=1, **reward_arch_args)

                # value function shaping: the same 'vfn' weights are applied
                # to s' and (with reuse=True) to s.
                with tf.variable_scope('vfn'):
                    fitted_value_fn_n = value_fn_arch(self.nobs_t, dout=1)
                with tf.variable_scope('vfn', reuse=True):
                    self.value_fn = fitted_value_fn = value_fn_arch(self.obs_t, dout=1)

                # Define log p_tau(a|s) = r + gamma * V(s') - V(s)
                self.qfn = self.reward + self.gamma*fitted_value_fn_n
                log_p_tau = self.reward + self.gamma*fitted_value_fn_n - fitted_value_fn

            log_q_tau = self.lprobs

            # D(tau) = p / (p + q), computed in log space for stability.
            log_pq = tf.reduce_logsumexp([log_p_tau, log_q_tau], axis=0)
            self.discrim_output = tf.exp(log_p_tau-log_pq)
            # Cross-entropy; fit() labels expert rows 1 and policy rows 0.
            cent_loss = -tf.reduce_mean(self.labels*(log_p_tau-log_pq) + (1-self.labels)*(log_q_tau-log_pq))

            self.loss = cent_loss
            tot_loss = self.loss
            self.step = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(tot_loss)
            self._make_param_ops(_vs)

    def fit(self, paths, policy=None, batch_size=32, logger=None, lr=1e-3, **kwargs):
        """Run ``max_itrs`` discriminator updates on policy vs. expert data.

        Returns:
            The most recent ``it.pop_mean('loss')`` value, or NaN if no
            heartbeat fired during training.
        """
        if self.fusion is not None:
            old_paths = self.fusion.sample_paths(n=len(paths))
            self.fusion.add_paths(paths)
            paths = paths+old_paths

        # eval samples under current policy
        self._compute_path_probs(paths, insert=True)

        # eval expert log probs under current policy
        self.eval_expert_probs(self.expert_trajs, policy, insert=True)

        self._insert_next_state(paths)
        self._insert_next_state(self.expert_trajs)
        obs, obs_next, acts, acts_next, path_probs = \
            self.extract_paths(paths,
                               keys=('observations', 'observations_next', 'actions', 'actions_next', 'a_logprobs'))
        expert_obs, expert_obs_next, expert_acts, expert_acts_next, expert_probs = \
            self.extract_paths(self.expert_trajs,
                               keys=('observations', 'observations_next', 'actions', 'actions_next', 'a_logprobs'))

        # Defined up front so the return below never hits an unbound name.
        mean_loss = float('nan')

        # Train discriminator
        for it in TrainingIterator(self.max_itrs, heartbeat=5):
            nobs_batch, obs_batch, nact_batch, act_batch, lprobs_batch = \
                self.sample_batch(obs_next, obs, acts_next, acts, path_probs, batch_size=batch_size)

            nexpert_obs_batch, expert_obs_batch, nexpert_act_batch, expert_act_batch, expert_lprobs_batch = \
                self.sample_batch(expert_obs_next, expert_obs, expert_acts_next, expert_acts, expert_probs, batch_size=batch_size)

            # Build feed dict: policy rows first (label 0), expert rows second (label 1).
            labels = np.zeros((batch_size*2, 1))
            labels[batch_size:] = 1.0
            obs_batch = np.concatenate([obs_batch, expert_obs_batch], axis=0)
            nobs_batch = np.concatenate([nobs_batch, nexpert_obs_batch], axis=0)
            act_batch = np.concatenate([act_batch, expert_act_batch], axis=0)
            nact_batch = np.concatenate([nact_batch, nexpert_act_batch], axis=0)
            lprobs_batch = np.expand_dims(np.concatenate([lprobs_batch, expert_lprobs_batch], axis=0), axis=1).astype(np.float32)
            feed_dict = {
                self.act_t: act_batch,
                self.obs_t: obs_batch,
                self.nobs_t: nobs_batch,
                self.nact_t: nact_batch,
                self.labels: labels,
                self.lprobs: lprobs_batch,
                self.lr: lr
            }

            loss, _ = tf.get_default_session().run([self.loss, self.step], feed_dict=feed_dict)
            it.record('loss', loss)
            if it.heartbeat:
                print(it.itr_message())
                mean_loss = it.pop_mean('loss')
                print('\tLoss:%f' % mean_loss)

        if logger:
            logger.record_tabular('GCLDiscrimLoss', mean_loss)
            energy, logZ, dtau = tf.get_default_session().run(
                [self.reward, self.value_fn, self.discrim_output],
                feed_dict={self.act_t: acts, self.obs_t: obs, self.nobs_t: obs_next,
                           self.nact_t: acts_next,
                           self.lprobs: np.expand_dims(path_probs, axis=1)})
            energy = -energy
            logger.record_tabular('GCLLogZ', np.mean(logZ))
            logger.record_tabular('GCLAverageEnergy', np.mean(energy))
            logger.record_tabular('GCLAverageLogPtau', np.mean(-energy-logZ))
            logger.record_tabular('GCLAverageLogQtau', np.mean(path_probs))
            logger.record_tabular('GCLMedianLogQtau', np.median(path_probs))
            logger.record_tabular('GCLAverageDtau', np.mean(dtau))

            energy, logZ, dtau = tf.get_default_session().run(
                [self.reward, self.value_fn, self.discrim_output],
                feed_dict={self.act_t: expert_acts, self.obs_t: expert_obs, self.nobs_t: expert_obs_next,
                           self.nact_t: expert_acts_next,
                           self.lprobs: np.expand_dims(expert_probs, axis=1)})
            energy = -energy
            logger.record_tabular('GCLAverageExpertEnergy', np.mean(energy))
            logger.record_tabular('GCLAverageExpertLogPtau', np.mean(-energy-logZ))
            logger.record_tabular('GCLAverageExpertLogQtau', np.mean(expert_probs))
            logger.record_tabular('GCLMedianExpertLogQtau', np.median(expert_probs))
            logger.record_tabular('GCLAverageExpertDtau', np.mean(dtau))
        return mean_loss

    def eval(self, paths, **kwargs):
        """Return the per-step reward bonus for ``paths``, unpacked per path.

        With ``score_discrim`` the score is ``log D - log(1 - D)``, with D
        clipped to at least 1e-8 on both sides. Otherwise it is the learned
        reward.
        """
        if self.score_discrim:
            self._compute_path_probs(paths, insert=True)
            # Same call fit() makes; ensures 'observations_next' exists.
            self._insert_next_state(paths)
            obs, obs_next, acts, path_probs = self.extract_paths(paths, keys=('observations', 'observations_next', 'actions', 'a_logprobs'))
            path_probs = np.expand_dims(path_probs, axis=1)
            scores = tf.get_default_session().run(self.discrim_output,
                                                  feed_dict={self.act_t: acts, self.obs_t: obs,
                                                             self.nobs_t: obs_next,
                                                             self.lprobs: path_probs})
            score = np.log(np.maximum(scores, 1e-8)) - np.log(np.maximum(1 - scores, 1e-8))
            score = score[:, 0]
        else:
            obs, acts = self.extract_paths(paths)
            reward = tf.get_default_session().run(self.reward,
                                                  feed_dict={self.act_t: acts, self.obs_t: obs})
            score = reward[:, 0]
        return self.unpack(score, paths)

    def eval_single(self, obs):
        """Return the learned reward for a batch of observations (obs fed only)."""
        reward = tf.get_default_session().run(self.reward,
                                              feed_dict={self.obs_t: obs})
        score = reward[:, 0]
        return score

    def debug_eval(self, paths, **kwargs):
        """Return reward, V(s) and Q = r + gamma * V(s') for ``paths``.

        ``self.qfn`` depends on the ``nobs_t`` placeholder through V(s'), so
        next observations are inserted (as ``fit`` does) and fed.
        """
        self._insert_next_state(paths)
        obs, obs_next, acts = self.extract_paths(
            paths, keys=('observations', 'observations_next', 'actions'))
        reward, v, qfn = tf.get_default_session().run(
            [self.reward, self.value_fn, self.qfn],
            feed_dict={self.act_t: acts, self.obs_t: obs, self.nobs_t: obs_next})
        return {
            'reward': reward,
            'value': v,
            'qfn': qfn,
        }
Example #3
0
class Empowerment(SingleTimestepIRL):
    """Learned per-state empowerment estimate Phi(s).

    The network ``self.empwerment`` is regressed (mean-squared error) so that
    ``Phi(s) + act_policy`` matches ``act_qvar``. Both targets are fed as
    placeholders and computed in ``fit`` from log-likelihoods under the policy
    and under ``qvar_model``.

    Args:
        env: Environment; ``env.spec`` supplies observation/action dims and
            ``env.action_space`` must be a ``Box``.
        emp_fn_arch: Callable building the empowerment network.
        scope (str): Variable scope (inside ``name``) of the network.
        max_itrs (int): Number of training iterations to run per fit step.
        fusion (bool): Use trajectories from old iterations to train.
        name (str): Outer TensorFlow variable scope.
    """

    def __init__(
            self,
            env,
            emp_fn_arch=relu_net,
            scope='efn',
            max_itrs=100,
            fusion=False,
            name='empowerment'):
        super(Empowerment, self).__init__()
        env_spec = env.spec

        if fusion:
            self.fusion = RamFusionDistr(100, subsample_ratio=0.5)
        else:
            self.fusion = None
        self.dO = env_spec.observation_space.flat_dim
        self.dU = env_spec.action_space.flat_dim
        assert isinstance(env.action_space, Box)
        assert emp_fn_arch is not None
        self.max_itrs = max_itrs

        with tf.variable_scope(name) as _vs:
            # One row per timestep: (batch, dO) observations.
            self.obs_t = tf.placeholder(tf.float32, [None, self.dO],
                                        name='obs')
            self.act_qvar = tf.placeholder(tf.float32, [None, 1],
                                           name='act_qvar')
            self.act_policy = tf.placeholder(tf.float32, [None, 1],
                                             name='act_policy')
            self.lr = tf.placeholder(tf.float32, (), name='lr')

            # empowerment function
            with tf.variable_scope(scope):
                self.empwerment = emp_fn_arch(self.obs_t, dout=1)

            mse_loss = tf.losses.mean_squared_error(
                predictions=(self.empwerment + self.act_policy),
                labels=self.act_qvar)

            self.loss_emp = mse_loss
            tot_loss_emp = self.loss_emp
            self.step_emp = tf.train.AdamOptimizer(
                learning_rate=self.lr).minimize(tot_loss_emp)

            self._make_param_ops(_vs)

    def fit(self,
            paths,
            irl_model=None,
            tempw=None,
            policy=None,
            qvar_model=None,
            batch_size=32,
            logger=None,
            lr=1e-3,
            **kwargs):
        """Run ``max_itrs`` regression steps of the empowerment network.

        Args:
            paths: Sample paths; must already contain 'observations_next'.
            irl_model, tempw, logger: Accepted for interface compatibility;
                not used here.
            policy: Policy providing ``dist_info_sym`` and ``distribution``.
            qvar_model: Model providing ``dist_info_sym`` on concatenated
                (obs, next_obs) inputs (presumably a ``Qvar`` instance).
            batch_size (int): Samples drawn per iteration.
            lr (float): Adam learning rate.

        Returns:
            The most recent ``it.pop_mean('loss_emp')`` value, or NaN if no
            heartbeat fired.
        """
        if self.fusion is not None:
            old_paths = self.fusion.sample_paths(n=len(paths))
            self.fusion.add_paths(paths)
            paths = paths + old_paths

        obs, ac, obs_next = self.extract_paths(paths,
                                               keys=('observations', 'actions',
                                                     'observations_next'))

        # Defined up front so the return below never hits an unbound name.
        mean_loss_emp = float('nan')

        for it in TrainingIterator(self.max_itrs, heartbeat=1):

            nobs_batch, obs_batch, ac_batch = self.sample_batch(
                obs_next, obs, ac, batch_size=batch_size)

            # NOTE(review): these symbolic ops are rebuilt on every iteration,
            # so the TF graph grows with max_itrs. Consider building them once.
            dist_info_vars = policy.dist_info_sym(obs_batch, None)

            dist_s = policy.distribution.log_likelihood(
                policy.distribution.sample(dist_info_vars), dist_info_vars)

            q_input = tf.concat([obs_batch, nobs_batch], axis=1)
            q_dist_info_vars = qvar_model.dist_info_sym(q_input, None)

            # NOTE(review): uses the *policy's* distribution object on the
            # q-model's dist info; assumes both share a distribution type.
            q_dist_s = policy.distribution.log_likelihood(
                policy.distribution.sample(q_dist_info_vars), q_dist_info_vars)

            # Build feed dict
            feed_dict = {
                self.obs_t: obs_batch,
                self.act_qvar: q_dist_s.eval(),
                self.act_policy: dist_s.eval(),
                self.lr: lr
            }

            loss_emp, _ = tf.get_default_session().run(
                [self.loss_emp, self.step_emp], feed_dict=feed_dict)
            it.record('loss_emp', loss_emp)
            if it.heartbeat:
                print(it.itr_message())
                mean_loss_emp = it.pop_mean('loss_emp')
                print('\tLoss_emp:%f' % mean_loss_emp)

        return mean_loss_emp

    def eval(self, obs, **kwargs):
        """Return the empowerment estimate Phi(s) for a batch of observations."""
        empw = tf.get_default_session().run(self.empwerment,
                                            feed_dict={self.obs_t: obs})
        return empw
Example #4
0
class Qvar(SingleTimestepIRL):
    """ 


    Args:
        fusion (bool): Use trajectories from old iterations to train.
        state_only (bool): Fix the learned reward to only depend on state.
        score_discrim (bool): Use log D - log 1-D as reward (if true you should not need to use an entropy bonus)
        max_itrs (int): Number of training iterations to run per fit step.
    """
    def __init__(self, env,
                 expert_trajs=None,
                 qvar=None,
                 score_discrim=False,
                 discount=1.0,
                 state_only=False,
                 max_itrs=100,
                 fusion=False,
                 name='qvar'):
        """Build a regressor predicting the action from (s_t, s_{t+1}).

        Args:
            env: Environment; ``env.spec`` supplies observation/action dims and
                ``env.action_space`` must be a ``Box``.
            expert_trajs: Passed to ``set_demos``.
            qvar: Model exposing ``dist_info_sym(input, state_info_vars)``
                whose result has a ``"mean"`` entry. Required unless
                ``self.qvar`` is already available (e.g. from a subclass).
            score_discrim, discount, state_only: Accepted for interface
                compatibility; not used here.
            max_itrs (int): Number of training iterations per ``fit`` call.
            fusion (bool): Use trajectories from old iterations to train.
            name (str): TensorFlow variable scope.

        Raises:
            ValueError: If no ``qvar`` model is available.
        """
        super(Qvar, self).__init__()
        env_spec = env.spec
        if qvar is not None:
            self.qvar = qvar
        # Fail early with a clear message instead of an AttributeError when
        # the graph below calls self.qvar.dist_info_sym.
        if getattr(self, 'qvar', None) is None:
            raise ValueError('Qvar requires a `qvar` model providing dist_info_sym')

        if fusion:
            self.fusion = RamFusionDistr(100, subsample_ratio=0.5)
        else:
            self.fusion = None
        self.dO = env_spec.observation_space.flat_dim
        self.dU = env_spec.action_space.flat_dim
        assert isinstance(env.action_space, Box)
        self.set_demos(expert_trajs)
        self.max_itrs = max_itrs

        with tf.variable_scope(name) as _vs:
            # One row per timestep: (batch, dU) actions, (batch, dO) observations.
            self.act_t = tf.placeholder(tf.float32, [None, self.dU], name='act')
            self.obs_t = tf.placeholder(tf.float32, [None, self.dO], name='obs')
            self.nobs_t = tf.placeholder(tf.float32, [None, self.dO], name='nobs')
            self.lr = tf.placeholder(tf.float32, (), name='lr')

            with tf.variable_scope('q_var'):
                # Input is [s_t, s_{t+1}] concatenated along features.
                q_input = tf.concat([self.obs_t, self.nobs_t], axis=1)
                self.act_predicted = self.qvar.dist_info_sym(q_input, None)

            # Regress the predicted mean onto the taken action.
            self.loss_q = tf.losses.mean_squared_error(predictions=self.act_predicted["mean"], labels=self.act_t)
            tot_loss_q = self.loss_q

            self.step_q = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(tot_loss_q)
            self._make_param_ops(_vs)

    def fit(self, paths, batch_size=32, logger=None, lr=1e-3, **kwargs):
        """Train the q_var regressor on (s, s', a) samples drawn from ``paths``.

        Args:
            paths: Sample paths; must contain 'observations_next'.
            batch_size (int): Samples drawn per iteration.
            logger: Accepted for interface compatibility; not used here.
            lr (float): Adam learning rate.

        Returns:
            The most recent ``it.pop_mean('loss_q')`` value, or NaN if no
            heartbeat fired.
        """
        if self.fusion is not None:
            old_paths = self.fusion.sample_paths(n=len(paths))
            self.fusion.add_paths(paths)
            paths = paths+old_paths

        obs, obs_next, acts = \
            self.extract_paths(paths,
                               keys=('observations', 'observations_next', 'actions'))

        # Defined up front so the return below never hits an unbound name.
        mean_loss_q = float('nan')

        for it in TrainingIterator(self.max_itrs, heartbeat=5):
            nobs_batch, obs_batch, act_batch = \
                self.sample_batch(obs_next, obs, acts, batch_size=batch_size)

            feed_dict = {
                self.act_t: act_batch,
                self.obs_t: obs_batch,
                self.nobs_t: nobs_batch,
                self.lr: lr
            }

            loss_q, _ = tf.get_default_session().run([self.loss_q, self.step_q], feed_dict=feed_dict)
            it.record('loss_q', loss_q)
            if it.heartbeat:
                mean_loss_q = it.pop_mean('loss_q')
                print('\tLoss_q:%f' % mean_loss_q)

        return mean_loss_q

    def dist_info_sym(self, q_input, state_info_vars):
        """Delegate to the wrapped ``qvar`` model's ``dist_info_sym``.

        ``state_info_vars`` is accepted for interface compatibility but is not
        forwarded; ``None`` is always passed to the wrapped model instead.
        """
        wrapped_model = self.qvar
        return wrapped_model.dist_info_sym(q_input, None)

    '''def set_params(self, params):
Example #5
0
class AIRL_Nstep(SingleTimestepIRL):
    """
    AIRL discriminator with an N-step reward term.

    For a window of N = ``number_obs`` consecutive steps starting at t the
    discriminator uses

        log p_tau = sum_{k=0}^{N-1} gamma^k r_{t+k} + gamma^N V(s_{t+N}) - V(s_t)

    and log q_tau is the policy log-probability fed through ``lprobs``.

    Args:
        fusion (bool): Use trajectories from old iterations to train.
        state_only (bool): Fix the learned reward to only depend on state.
        score_discrim (bool): Use log D - log 1-D as reward (if true you should not need to use an entropy bonus)
        max_itrs (int): Number of training iterations to run per fit step.
        number_obs (int): Window length N used for the discounted reward sum.
    """
    def __init__(self, env,
                 expert_trajs=None,
                 reward_arch=relu_net,
                 reward_arch_args=None,
                 value_fn_arch=relu_net,
                 score_discrim=False,
                 discount=1.0,
                 number_obs = 3,
                 state_only=False,
                 max_itrs=100,
                 fusion=False,
                 name='airl'):
        super(AIRL_Nstep, self).__init__()
        env_spec = env.spec
        if reward_arch_args is None:
            reward_arch_args = {}

        if fusion:
            # Store of past paths that fit() mixes into each training call.
            self.fusion = RamFusionDistr(100, subsample_ratio=0.5)
        else:
            self.fusion = None
        self.dO = env_spec.observation_space.flat_dim
        self.dU = env_spec.action_space.flat_dim
        assert isinstance(env.action_space, Box)
        self.score_discrim = score_discrim
        self.gamma = discount
        assert value_fn_arch is not None
        self.set_demos(expert_trajs)
        self.state_only=state_only
        self.max_itrs=max_itrs
        self.number_obs = number_obs

        # build energy model
        with tf.variable_scope(name) as _vs:
            # obs_t / act_t: (batch, number_obs, dO/dU) windows of steps t..t+N-1.
            # nobs_t / nact_t: (batch, dO/dU) at step t+N (see _reorganize_states).
            self.obs_t = tf.placeholder(tf.float32, [None, self.number_obs, self.dO], name='obs')
            self.nobs_t = tf.placeholder(tf.float32, [None, self.dO], name='nobs')
            self.act_t = tf.placeholder(tf.float32, [None, self.number_obs,self.dU], name='act')
            self.nact_t = tf.placeholder(tf.float32, [None, self.dU], name='nact')
            self.labels = tf.placeholder(tf.float32, [None, 1], name='labels')
            self.lprobs = tf.placeholder(tf.float32, [None, 1], name='log_probs')
            self.lr = tf.placeholder(tf.float32, (), name='lr')

            with tf.variable_scope('discrim') as dvs:
                rew_input = self.obs_t
                if not self.state_only:
                    rew_input = tf.concat([self.obs_t, self.act_t], axis=2)
                with tf.variable_scope('reward'):
                    # Apply the reward net per step by flattening the window
                    # axis into the batch, then restore (batch, N, 1).
                    self.reward = reward_arch(tf.reshape(rew_input, [-1, rew_input.shape[2] ]), dout=1, **reward_arch_args)
                    self.reward = tf.reshape(self.reward, [-1, self.number_obs, 1])

                # value function shaping (same weights for V(s_t) and V(s_{t+N}))
                with tf.variable_scope('vfn'):
                    fitted_value_fn_n = value_fn_arch(self.nobs_t, dout=1)
                with tf.variable_scope('vfn', reuse=True):
                    self.value_fn = fitted_value_fn = value_fn_arch(self.obs_t[:,0,:], dout=1)

                # Discount coefficients [1, gamma, ..., gamma^(N-1)], shape (N, 1).
                gamma_coefs = np.ones([self.number_obs], dtype = np.float32)
                gamma_coefs[1:] *=  self.gamma
                gamma_coefs = np.cumprod(gamma_coefs)
                gamma_coefs = np.expand_dims(gamma_coefs, axis=1)
                self.qfn = tf.reduce_sum(self.reward*gamma_coefs, axis=1) + (self.gamma**self.number_obs)*fitted_value_fn_n
                log_p_tau = self.qfn - fitted_value_fn

            log_q_tau = self.lprobs

            # D = p_tau / (p_tau + q_tau), computed in log space.
            log_pq = tf.reduce_logsumexp([log_p_tau, log_q_tau], axis=0)
            self.discrim_output = tf.exp(log_p_tau-log_pq)
            cent_loss = -tf.reduce_mean(self.labels*(log_p_tau-log_pq) + (1-self.labels)*(log_q_tau-log_pq))

            self.loss = cent_loss
            tot_loss = self.loss
            self.step = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(tot_loss)
            self._make_param_ops(_vs)

    @staticmethod
    def _sliding_windows(values, width, dim):
        """Return a float32 (T, width, dim) array whose row t is values[t:t+width].

        Windows that run past the end of ``values`` are padded with ones.
        """
        values = np.asarray(values)
        pad = np.ones(shape=[max(width - 1, 0), dim], dtype=np.float32)
        padded = np.r_[values, pad]
        # index[t, k] = t + k selects row t+k of the padded array.
        index = np.arange(values.shape[0])[:, None] + np.arange(width)[None, :]
        return padded[index].astype(np.float32)

    def _reorganize_states(self, paths, pad_val=0.0):
        """Attach the N-step inputs to each path in place.

        Paths that already contain 'observations_next' are left untouched, so
        calling this repeatedly on the same paths is safe. For other paths it adds:

          * 'observations_next' / 'actions_next': rows shifted forward by
            ``number_obs`` steps, with the last ``number_obs`` rows set to
            ``pad_val``;
          * 'multi_observations' / 'multi_actions': for every step t the
            window of rows t .. t+number_obs-1, padded with ones past the end.

        Returns:
            The same ``paths`` list.
        """
        n = self.number_obs
        for path in paths:
            if 'observations_next' in path:
                continue
            observations = path['observations']
            actions = path['actions']
            path['observations_next'] = np.r_[observations[n:], pad_val*np.ones(shape=[n, self.dO], dtype=np.float32)]
            path['actions_next'] = np.r_[actions[n:], pad_val*np.ones(shape=[n, self.dU], dtype=np.float32)]
            # NOTE(review): window padding uses ones, not pad_val; kept as-is
            # so learned rewards stay comparable with earlier runs.
            path['multi_observations'] = self._sliding_windows(observations, n, self.dO)
            path['multi_actions'] = self._sliding_windows(actions, n, self.dU)

        return paths

    def fit(self, paths, policy=None, batch_size=32, logger=None, lr=1e-3,**kwargs):
        """Train the discriminator on policy ``paths`` versus expert trajectories.

        Args:
            paths: policy rollouts (mutated in place with log-probs and
                N-step inputs).
            policy: used to evaluate expert action log-probabilities.
            batch_size: samples per class per step (2 * batch_size per step).
            logger: optional; when given, GCL* diagnostics are recorded.
            lr: Adam learning rate.

        Returns:
            The mean discriminator loss at the last heartbeat, or None if no
            heartbeat fired.
        """
        if self.fusion is not None:
            old_paths = self.fusion.sample_paths(n=len(paths))
            self.fusion.add_paths(paths)
            paths = paths+old_paths

        # eval samples under current policy
        self._compute_path_probs(paths, insert=True)

        # eval expert log probs under current policy
        self.eval_expert_probs(self.expert_trajs, policy, insert=True)

        self._reorganize_states(paths)
        self._reorganize_states(self.expert_trajs)
        obs, obs_next, acts, acts_next, path_probs = \
            self.extract_paths(paths,
                               keys=('multi_observations', 'observations_next', 'multi_actions', 'actions_next', 'a_logprobs'))
        expert_obs, expert_obs_next, expert_acts, expert_acts_next, expert_probs = \
            self.extract_paths(self.expert_trajs,
                               keys=('multi_observations', 'observations_next', 'multi_actions', 'actions_next', 'a_logprobs'))

        # Bound up front so the return cannot raise UnboundLocalError.
        mean_loss = None

        # Train discriminator
        for it in TrainingIterator(self.max_itrs, heartbeat=5):
            nobs_batch, obs_batch, nact_batch, act_batch, lprobs_batch = \
                self.sample_batch(obs_next, obs, acts_next, acts, path_probs, batch_size=batch_size)

            nexpert_obs_batch, expert_obs_batch, nexpert_act_batch, expert_act_batch, expert_lprobs_batch = \
                self.sample_batch(expert_obs_next, expert_obs, expert_acts_next, expert_acts, expert_probs, batch_size=batch_size)

            # Build feed dict: policy samples first (label 0), experts second (label 1).
            labels = np.zeros((batch_size*2, 1))
            labels[batch_size:] = 1.0
            obs_batch = np.concatenate([obs_batch, expert_obs_batch], axis=0)
            nobs_batch = np.concatenate([nobs_batch, nexpert_obs_batch], axis=0)
            act_batch = np.concatenate([act_batch, expert_act_batch], axis=0)
            nact_batch = np.concatenate([nact_batch, nexpert_act_batch], axis=0)
            lprobs_batch = np.expand_dims(np.concatenate([lprobs_batch, expert_lprobs_batch], axis=0), axis=1).astype(np.float32)
            feed_dict = {
                self.act_t: act_batch,
                self.obs_t: obs_batch,
                self.nobs_t: nobs_batch,
                self.nact_t: nact_batch,
                self.labels: labels,
                self.lprobs: lprobs_batch,
                self.lr: lr
                }

            loss, _ = tf.get_default_session().run([self.loss, self.step], feed_dict=feed_dict)
            it.record('loss', loss)
            if it.heartbeat:
                print(it.itr_message())
                mean_loss = it.pop_mean('loss')
                print('\tLoss:%f' % mean_loss)

        if logger:
            logger.record_tabular('GCLDiscrimLoss', mean_loss)
            energy, logZ, dtau = tf.get_default_session().run([self.reward, self.value_fn, self.discrim_output],
                                                               feed_dict={self.act_t: acts, self.obs_t: obs, self.nobs_t: obs_next,
                                                                   self.nact_t: acts_next,
                                                               self.lprobs: np.expand_dims(path_probs, axis=1)})
            energy = -energy
            logger.record_tabular('GCLLogZ', np.mean(logZ))
            logger.record_tabular('GCLAverageEnergy', np.mean(energy))
            # Policy-sample statistic; previously mislabelled with the expert key,
            # which collided with the expert entry recorded below.
            logger.record_tabular('GCLAverageLogPtau', np.mean(-np.mean(energy, axis=1) - logZ))
            logger.record_tabular('GCLAverageLogQtau', np.mean(path_probs))
            logger.record_tabular('GCLMedianLogQtau', np.median(path_probs))
            logger.record_tabular('GCLAverageDtau', np.mean(dtau))

            energy, logZ, dtau = tf.get_default_session().run([self.reward, self.value_fn, self.discrim_output],
                    feed_dict={self.act_t: expert_acts, self.obs_t: expert_obs, self.nobs_t: expert_obs_next,
                                    self.nact_t: expert_acts_next,
                                    self.lprobs: np.expand_dims(expert_probs, axis=1)})
            energy = -energy
            logger.record_tabular('GCLAverageExpertEnergy', np.mean(energy))
            logger.record_tabular('GCLAverageExpertLogPtau', np.mean(-np.mean(energy, axis=1) - logZ))
            logger.record_tabular('GCLAverageExpertLogQtau', np.mean(expert_probs))
            logger.record_tabular('GCLMedianExpertLogQtau', np.median(expert_probs))
            logger.record_tabular('GCLAverageExpertDtau', np.mean(dtau))
        return mean_loss

    def _compute_path_probs_modified(self, paths, pol_dist_type=None, insert=True,
                            insert_key='a_logprobs'):
        """
        Return per-timestep action log-probabilities, one row per path.

        If ``paths[0]`` already holds ``insert_key`` the stored values are
        returned unchanged. Otherwise each timestep's action is scored under
        that timestep's recorded policy distribution (the same quantity
        ``fit`` trains the discriminator with) and, if ``insert``, stored
        under ``insert_key``.

        Raises:
            NotImplementedError: if the distribution type cannot be inferred
                or is unknown.
        """
        if insert_key in paths[0]:
            return np.array([path[insert_key] for path in paths])

        if pol_dist_type is None:
            # try to  infer distribution type
            path0 = paths[0]
            if 'log_std' in path0['agent_infos']:
                pol_dist_type = DIST_GAUSSIAN
            elif 'prob' in path0['agent_infos']:
                pol_dist_type = DIST_CATEGORICAL
            else:
                raise NotImplementedError()

        # compute path probs; score every timestep's own action (previously
        # only the path's first action was scored against all timesteps).
        Npath = len(paths)
        actions = [path['actions'] for path in paths]
        if pol_dist_type == DIST_GAUSSIAN:
            params = [(path['agent_infos']['mean'], path['agent_infos']['log_std']) for path in paths]
            path_probs = [gauss_log_pdf(params[i], actions[i]) for i in range(Npath)]
        elif pol_dist_type == DIST_CATEGORICAL:
            params = [(path['agent_infos']['prob'],) for path in paths]
            path_probs = [categorical_log_pdf(params[i], actions[i]) for i in range(Npath)]
        else:
            raise NotImplementedError("Unknown distribution type")

        if insert:
            for i, path in enumerate(paths):
                path[insert_key] = path_probs[i]

        return np.array(path_probs)

    def eval(self, paths, **kwargs):
        """
        Return bonus: log D - log(1 - D) if ``score_discrim``, else the raw
        reward of the first step of each N-step window.
        """
        # Make sure the N-step inputs exist even if fit() has not seen these paths.
        self._reorganize_states(paths)
        if self.score_discrim:
            self._compute_path_probs_modified(paths, insert=True)
            obs, obs_next, acts, path_probs = self.extract_paths(paths, keys=('multi_observations', 'observations_next', 'multi_actions', 'a_logprobs'))
            path_probs = np.expand_dims(path_probs, axis=1)
            scores = tf.get_default_session().run(self.discrim_output,
                                              feed_dict={self.act_t: acts, self.obs_t: obs,
                                                         self.nobs_t: obs_next,
                                                         self.lprobs: path_probs})
            score = np.log(scores) - np.log(1-scores)
            score = score[:,0]
        else:
            obs, acts = self.extract_paths(paths, keys=('multi_observations', 'multi_actions'))
            reward = tf.get_default_session().run(self.reward,
                                              feed_dict={self.act_t: acts, self.obs_t: obs})
            score = reward[:,0,0]
        return self.unpack(score, paths)

    def eval_single(self, obs, acts=None):
        """Reward of the first step of each window in ``obs``.

        Args:
            obs: batch of observation windows fed to ``obs_t``.
            acts: optional matching action windows; required unless the model
                was built with ``state_only=True`` (the reward then reads actions).
        """
        feed_dict = {self.obs_t: obs}
        if acts is not None:
            feed_dict[self.act_t] = acts
        reward = tf.get_default_session().run(self.reward, feed_dict=feed_dict)
        # self.reward is reshaped to (batch, number_obs, 1); take step 0.
        score = reward[:, 0, 0]
        return score

    def debug_eval(self, paths, **kwargs):
        """Return raw reward, V(s_t) and qfn for ``paths`` (for inspection)."""
        self._reorganize_states(paths)
        obs, obs_next, acts = self.extract_paths(paths, keys=('multi_observations', 'observations_next', 'multi_actions'))
        # qfn depends on V(s_{t+N}), so nobs_t must be fed as well.
        reward, v, qfn = tf.get_default_session().run([self.reward, self.value_fn,
                                                            self.qfn],
                                                      feed_dict={self.act_t: acts, self.obs_t: obs,
                                                                 self.nobs_t: obs_next})
        return {
            'reward': reward,
            'value': v,
            'qfn': qfn,
        }
Exemple #6
0
class EAIRL(SingleTimestepIRL):
    """
    EAIRL discriminator with an externally supplied shaping potential.

    Only the reward network is learned here. V(s) and V(s') are computed by
    external models (``empw_model`` / ``t_empw_model`` arguments of ``fit``
    and ``eval``) and fed through the ``vs`` / ``vsp`` placeholders:

        log p_tau = r(s, a) + gamma * V(s') - V(s)

    Args:
        fusion (bool): Use trajectories from old iterations to train.
        state_only (bool): Fix the learned reward to only depend on state.
        score_discrim (bool): Use log D - log 1-D as reward (if true you should not need to use an entropy bonus)
        max_itrs (int): Number of training iterations to run per fit step.
    """
    def __init__(self, env,
                 expert_trajs=None,
                 reward_arch=relu_net,
                 reward_arch_args=None,
                 score_discrim=False,
                 discount=1.0,
                 state_only=False,
                 max_itrs=100,
                 fusion=False,
                 name='eairl'):
        super(EAIRL, self).__init__()
        env_spec = env.spec
        if reward_arch_args is None:
            reward_arch_args = {}

        if fusion:
            # Store of past paths that fit() mixes into each training call.
            self.fusion = RamFusionDistr(100, subsample_ratio=0.5)
        else:
            self.fusion = None
        self.dO = env_spec.observation_space.flat_dim
        self.dU = env_spec.action_space.flat_dim
        assert isinstance(env.action_space, Box)
        self.score_discrim = score_discrim
        self.gamma = discount
        self.set_demos(expert_trajs)
        self.state_only=state_only
        self.max_itrs=max_itrs

        # build energy model
        with tf.variable_scope(name) as _vs:
            # Single-step inputs: (batch, dO) / (batch, dU).
            self.obs_t = tf.placeholder(tf.float32, [None, self.dO], name='obs')
            self.nobs_t = tf.placeholder(tf.float32, [None, self.dO], name='nobs')
            self.act_t = tf.placeholder(tf.float32, [None, self.dU], name='act')
            self.nact_t = tf.placeholder(tf.float32, [None, self.dU], name='nact')
            self.labels = tf.placeholder(tf.float32, [None, 1], name='labels')
            self.lprobs = tf.placeholder(tf.float32, [None, 1], name='log_probs')
            self.lr = tf.placeholder(tf.float32, (), name='lr')
            # Externally computed potentials V(s) and V(s').
            self.vs = tf.placeholder(tf.float32, [None, 1], name='vs')
            self.vsp = tf.placeholder(tf.float32, [None, 1], name='vsp')

            with tf.variable_scope('discrim') as dvs:
                rew_input = self.obs_t
                if not self.state_only:
                    rew_input = tf.concat([self.obs_t, self.act_t], axis=1)
                with tf.variable_scope('reward'):
                    self.reward = reward_arch(rew_input, dout=1, **reward_arch_args)

                # Define log p_tau(a|s) = r + gamma * V(s') - V(s)
                self.qfn = self.reward + self.gamma*self.vsp
                log_p_tau = self.reward  + self.gamma*self.vsp-self.vs

            log_q_tau = self.lprobs

            # D = p_tau / (p_tau + q_tau), computed in log space.
            log_pq = tf.reduce_logsumexp([log_p_tau, log_q_tau], axis=0)
            self.discrim_output = tf.exp(log_p_tau-log_pq)
            cent_loss = -tf.reduce_mean(self.labels*(log_p_tau-log_pq) + (1-self.labels)*(log_q_tau-log_pq))

            self.loss_irl = cent_loss
            tot_loss_irl = self.loss_irl
            self.step_irl = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(tot_loss_irl)

            self._make_param_ops(_vs)

    def fit(self, paths, policy=None,empw_model=None,t_empw_model=None, batch_size=32, logger=None, lr=1e-3,**kwargs):
        """Train the reward/discriminator on policy ``paths`` versus experts.

        Args:
            paths: policy rollouts (mutated in place with log-probs and next states).
            policy: used to evaluate expert action log-probabilities.
            empw_model: model whose ``eval`` gives V(s) for each minibatch.
            t_empw_model: model whose ``eval`` gives V(s') for each minibatch.
            batch_size: samples per class per step (2 * batch_size per step).
            logger: accepted for interface compatibility; not used here.
            lr: Adam learning rate.

        Returns:
            The mean IRL loss at the last heartbeat, or None if no heartbeat fired.
        """
        if self.fusion is not None:
            old_paths = self.fusion.sample_paths(n=len(paths))
            self.fusion.add_paths(paths)
            paths = paths+old_paths

        # eval samples under current policy
        self._compute_path_probs(paths, insert=True)

        # eval expert log probs under current policy
        self.eval_expert_probs(self.expert_trajs, policy, insert=True)

        self._insert_next_state(paths)
        self._insert_next_state(self.expert_trajs)
        obs, obs_next, acts, acts_next, path_probs = \
            self.extract_paths(paths,
                               keys=('observations', 'observations_next', 'actions', 'actions_next', 'a_logprobs'))
        expert_obs, expert_obs_next, expert_acts, expert_acts_next, expert_probs = \
            self.extract_paths(self.expert_trajs,
                               keys=('observations', 'observations_next', 'actions', 'actions_next', 'a_logprobs'))

        # Bound up front so the return cannot raise UnboundLocalError.
        mean_loss_irl = None

        # Train discriminator
        for it in TrainingIterator(self.max_itrs, heartbeat=5):
            nobs_batch, obs_batch, nact_batch, act_batch, lprobs_batch = \
                self.sample_batch(obs_next, obs, acts_next, acts, path_probs, batch_size=batch_size)

            nexpert_obs_batch, expert_obs_batch, nexpert_act_batch, expert_act_batch, expert_lprobs_batch = \
                self.sample_batch(expert_obs_next, expert_obs, expert_acts_next, expert_acts, expert_probs, batch_size=batch_size)

            # Build feed dict: policy samples first (label 0), experts second (label 1).
            labels = np.zeros((batch_size*2, 1))
            labels[batch_size:] = 1.0
            obs_batch = np.concatenate([obs_batch, expert_obs_batch], axis=0)
            nobs_batch = np.concatenate([nobs_batch, nexpert_obs_batch], axis=0)
            act_batch = np.concatenate([act_batch, expert_act_batch], axis=0)
            nact_batch = np.concatenate([nact_batch, nexpert_act_batch], axis=0)
            lprobs_batch = np.expand_dims(np.concatenate([lprobs_batch, expert_lprobs_batch], axis=0), axis=1).astype(np.float32)
            # Reshape to the (batch, 1) layout of the vs/vsp placeholders, as eval() does.
            vs = np.reshape(empw_model.eval(obs_batch), (-1, 1))
            vsp = np.reshape(t_empw_model.eval(nobs_batch), (-1, 1))
            feed_dict = {
                self.act_t: act_batch,
                self.obs_t: obs_batch,
                self.nobs_t: nobs_batch,
                self.nact_t: nact_batch,
                self.labels: labels,
                self.lprobs: lprobs_batch,
                self.lr: lr,
                self.vs: vs,
                self.vsp: vsp
                }

            loss_irl, _ = tf.get_default_session().run([self.loss_irl, self.step_irl], feed_dict=feed_dict)
            it.record('loss_irl', loss_irl)

            if it.heartbeat:
                print(it.itr_message())
                mean_loss_irl = it.pop_mean('loss_irl')
                print('\tLoss_irl:%f' % mean_loss_irl)

        return mean_loss_irl

    def next_state(self, paths,**kwargs):
        """Add 'observations_next' / 'actions_next' to ``paths`` in place."""
        self._insert_next_state(paths)

    def eval(self, paths,empw_model=None,t_empw_model=None, **kwargs):
        """
        Return bonus: log D - log(1 - D) if ``score_discrim``, else the raw reward.

        In the ``score_discrim`` branch ``paths`` must already carry
        'observations_next' (e.g. via ``next_state`` or ``fit``).
        """
        if self.score_discrim:
            self._compute_path_probs(paths, insert=True)
            obs, obs_next, acts, path_probs = self.extract_paths(paths, keys=('observations', 'observations_next', 'actions', 'a_logprobs'))
            path_probs = np.reshape(path_probs, (-1, 1))
            # NOTE(review): fit() takes V(s') from t_empw_model, here from
            # empw_model; confirm this asymmetry is intended.
            vs=empw_model.eval(obs).reshape(-1,1)
            vsp=empw_model.eval(obs_next).reshape(-1,1)
            scores = tf.get_default_session().run(self.discrim_output,
                                              feed_dict={self.act_t: acts, self.obs_t: obs,
                                                         self.nobs_t: obs_next,
                                                         self.lprobs: path_probs, self.vs:vs, self.vsp:vsp})
            score = np.log(scores) - np.log(1-scores)
            score = score[:,0]
        else:
            obs, acts = self.extract_paths(paths)
            reward = tf.get_default_session().run(self.reward,
                                              feed_dict={self.act_t: acts, self.obs_t: obs})
            score = reward[:,0]
        return self.unpack(score, paths)

    def eval_single(self, obs, acts):
        """Raw learned reward for a batch of (obs, acts)."""
        reward = tf.get_default_session().run(self.reward,
                                              feed_dict={self.act_t: acts, self.obs_t: obs})
        score = reward[:, 0]
        return score

    def debug_eval(self, paths, empw_model=None, t_empw_model=None, **kwargs):
        """Return raw reward, V(s) and qfn for ``paths`` (for inspection).

        This class has no learned value network, so V(s) and V(s') are taken
        from ``empw_model`` as in ``eval``; ``t_empw_model`` is accepted for
        signature parity and not used.

        Raises:
            ValueError: if ``empw_model`` is not supplied.
        """
        if empw_model is None:
            raise ValueError('debug_eval requires empw_model to compute the value terms')
        self._insert_next_state(paths)
        obs, obs_next, acts = self.extract_paths(paths, keys=('observations', 'observations_next', 'actions'))
        v = empw_model.eval(obs).reshape(-1, 1)
        vsp = empw_model.eval(obs_next).reshape(-1, 1)
        reward, qfn = tf.get_default_session().run([self.reward, self.qfn],
                                                   feed_dict={self.act_t: acts, self.obs_t: obs,
                                                              self.vsp: vsp})
        return {
            'reward': reward,
            'value': v,
            'qfn': qfn,
        }
class AIRL_Bootstrap(SingleTimestepIRL):
    """ 
    Args:
        fusion (bool): Use trajectories from old iterations to train.
        state_only (bool): Fix the learned reward to only depend on state.
        score_discrim (bool): Use log D - log 1-D as reward (if true you should not need to use an entropy bonus)
        max_itrs (int): Number of training iterations to run per fit step.
    """
    def __init__(self,
                 env,
                 expert_trajs=None,
                 reward_arch=relu_net,
                 reward_arch_args=None,
                 value_fn_arch=relu_net,
                 score_discrim=False,
                 sess=None,
                 discount=1.0,
                 max_nstep=10,
                 n_value_funct=1,
                 n_rew_funct=1,
                 state_only=False,
                 max_itrs=100,
                 fusion=False,
                 debug=False,
                 score_method=None,
                 name='airl'):
        """Build an ensemble of N-step AIRL discriminators.

        One reward network per ``n_rew_funct`` and one value network per
        ``n_value_funct`` are created. Every (reward i, value j) pair gets its
        own q value, discriminator output, cross-entropy loss and Adam step
        (``self.qfn[i][j]``, ``self.discrim_output[i][j]``, ``self.loss[i][j]``,
        ``self.step[i][j]``). For these per-pair heads the window length N is
        read from ``tf.shape(obs_t)[1]`` at run time.

        Args:
            env: environment whose ``spec`` gives observation/action sizes;
                its action space must be a ``Box``.
            expert_trajs: expert paths passed to ``set_demos``.
            reward_arch, reward_arch_args: reward network builder and kwargs.
            value_fn_arch: value network builder (must not be None).
            score_discrim: stored as ``self.score_discrim``.
            sess: accepted but not used by this constructor.
            discount: gamma for the N-step sum and the bootstrap term.
            max_nstep: stored as ``self.max_nstep`` (window size used by
                ``_combine_predictions``).
            n_value_funct, n_rew_funct: ensemble sizes.
            state_only: if True the reward sees observations only.
            max_itrs: stored as ``self.max_itrs``.
            fusion: if True keep a ``RamFusionDistr`` of past paths.
            debug, score_method: stored on ``self``.
            name: variable scope for everything built here.
        """
        super(AIRL_Bootstrap, self).__init__()
        env_spec = env.spec
        if reward_arch_args is None:
            reward_arch_args = {}

        if fusion:
            self.fusion = RamFusionDistr(100, subsample_ratio=0.5)
        else:
            self.fusion = None
        self.dO = env_spec.observation_space.flat_dim
        self.dU = env_spec.action_space.flat_dim
        assert isinstance(env.action_space, Box)
        self.score_discrim = score_discrim
        self.gamma = discount
        assert value_fn_arch is not None
        self.set_demos(expert_trajs)
        self.state_only = state_only
        self.max_itrs = max_itrs
        self.max_nstep = max_nstep
        self.n_value_funct = n_value_funct
        self.n_rew_funct = n_rew_funct

        self.reward_arch = reward_arch
        self.reward_arch_args = reward_arch_args
        self.value_fn_arch = value_fn_arch

        self.score_method = score_method

        self.debug = debug

        # build energy model
        with tf.variable_scope(name) as _vs:
            # obs_t / act_t: (batch, N, dO/dU) windows with N left dynamic;
            # nobs_t / nact_t: (batch, dO/dU).
            self.obs_t = tf.placeholder(tf.float32, [None, None, self.dO],
                                        name='obs')
            self.nobs_t = tf.placeholder(tf.float32, [None, self.dO],
                                         name='nobs')
            self.act_t = tf.placeholder(tf.float32, [None, None, self.dU],
                                        name='act')
            self.nact_t = tf.placeholder(tf.float32, [None, self.dU],
                                         name='nact')
            self.labels = tf.placeholder(tf.float32, [None, 1], name='labels')
            self.lprobs = tf.placeholder(tf.float32, [None, 1],
                                         name='log_probs')
            self.lr = tf.placeholder(tf.float32, (), name='lr')

            # Window length N, known only when the graph is run.
            number_obs = tf.shape(self.obs_t)[1]

            with tf.variable_scope('discrim') as dvs:
                rew_input = self.obs_t
                if not self.state_only:
                    rew_input = tf.concat([self.obs_t, self.act_t], axis=2)

                self.reward = [None for i in range(self.n_rew_funct)]
                self.value_fn = [None for i in range(self.n_value_funct)]
                fitted_value_fn_n = [None for i in range(self.n_value_funct)]
                self.qfn = [[None for i in range(self.n_value_funct)]
                            for j in range(self.n_rew_funct)]
                self.discrim_output = [[
                    None for i in range(self.n_value_funct)
                ] for j in range(self.n_rew_funct)]
                self.loss = [[None for i in range(self.n_value_funct)]
                             for j in range(self.n_rew_funct)]
                self.step = [[None for i in range(self.n_value_funct)]
                             for j in range(self.n_rew_funct)]

                log_q_tau = self.lprobs

                for i in range(self.n_rew_funct):
                    with tf.variable_scope('reward_%d' % (i),
                                           reuse=tf.AUTO_REUSE):
                        # Apply the reward net per step by flattening the
                        # window axis into the batch, then restore (batch, N, 1).
                        self.reward[i] = self.reward_arch(
                            tf.reshape(rew_input, [-1, rew_input.shape[2]]),
                            dout=1,
                            **self.reward_arch_args)
                        self.reward[i] = tf.reshape(self.reward[i],
                                                    [-1, number_obs, 1])
                for j in range(self.n_value_funct):
                    # value function shaping; AUTO_REUSE makes V(s_{t+N}) and
                    # V(s_t) share the weights of value network j.
                    with tf.variable_scope('vfn_%d' % (j),
                                           reuse=tf.AUTO_REUSE):
                        fitted_value_fn_n[j] = self.value_fn_arch(self.nobs_t,
                                                                  dout=1)
                    with tf.variable_scope('vfn_%d' % (j),
                                           reuse=tf.AUTO_REUSE):
                        self.value_fn[j] = self.value_fn_arch(self.obs_t[:,
                                                                         0, :],
                                                              dout=1)

                self.avg_reward = tf.reduce_mean(tf.stack(self.reward), axis=0)

                # Discount coefficients [1, gamma, ..., gamma^(N-1)], shape (N, 1).
                gamma_coefs = tf.concat([
                    tf.ones([1], dtype=tf.float32),
                    self.gamma * tf.ones([number_obs - 1], dtype=tf.float32)
                ],
                                        axis=0)
                gamma_coefs = tf.cumprod(gamma_coefs)
                gamma_coefs = tf.expand_dims(gamma_coefs, axis=1)

                for i in range(self.n_rew_funct):
                    for j in range(self.n_value_funct):
                        # Define log p_tau(a|s) = sum_k gamma^k r_{t+k}
                        #   + gamma^N V(s_{t+N}) - V(s_t)
                        # float32 dtype so an integer discount (e.g. 1) does not
                        # produce an int32 base incompatible with the float exponent.
                        self.qfn[i][j] = tf.reduce_sum(
                            self.reward[i] * gamma_coefs,
                            axis=1) + tf.math.pow(
                                tf.constant(self.gamma, dtype=tf.float32),
                                tf.to_float(number_obs)) * fitted_value_fn_n[j]
                        log_p_tau = self.qfn[i][j] - self.value_fn[j]

                        # D = p_tau / (p_tau + q_tau), computed in log space.
                        log_pq = tf.reduce_logsumexp([log_p_tau, log_q_tau],
                                                     axis=0)
                        self.discrim_output[i][j] = tf.exp(log_p_tau - log_pq)
                        cent_loss = -tf.reduce_mean(self.labels *
                                                    (log_p_tau - log_pq) +
                                                    (1 - self.labels) *
                                                    (log_q_tau - log_pq))

                        self.loss[i][j] = cent_loss

                        self.step[i][j] = tf.train.AdamOptimizer(
                            learning_rate=self.lr).minimize(self.loss[i][j])

                self._combine_predictions()
            self._make_param_ops(_vs)

    def _combine_predictions(self):
        t0_reward = [None for i in range(self.n_rew_funct)]
        t0_value = [None for i in range(self.n_value_funct)]

        reward = [None for i in range(self.n_rew_funct)]
        fitted_value_fn_n = [None for i in range(self.n_value_funct)]

        log_p_tau = [[
            None for i in range(self.n_rew_funct * self.n_value_funct)
        ] for j in range(self.max_nstep)]

        rew_input = self.obs_t
        if not self.state_only:
            rew_input = tf.concat([self.obs_t, self.act_t], axis=2)
        for i in range(self.n_rew_funct):
            with tf.variable_scope('reward_%d' % (i), reuse=tf.AUTO_REUSE):
                reward[i] = self.reward_arch(tf.reshape(
                    rew_input, [-1, rew_input.shape[2]]),
                                             dout=1,
                                             **self.reward_arch_args)
                reward[i] = tf.reshape(reward[i], [-1, self.max_nstep, 1])
                t0_reward[i] = reward[i][:, 0]

        #Master-student score method
        with tf.variable_scope('student_reward'):
            self.student_reward = self.reward_arch(rew_input[:, 0],
                                                   dout=1,
                                                   **self.reward_arch_args)

        v_input = tf.concat(
            [self.obs_t[:, 1:, :],
             tf.expand_dims(self.nobs_t, axis=1)],
            axis=1)
        for j in range(self.n_value_funct):
            # value function shaping
            with tf.variable_scope('vfn_%d' % (j), reuse=tf.AUTO_REUSE):
                fitted_value_fn_n[j] = self.value_fn_arch(tf.reshape(
                    v_input, [-1, v_input.shape[2]]),
                                                          dout=1)
                fitted_value_fn_n[j] = tf.reshape(fitted_value_fn_n[j],
                                                  [-1, self.max_nstep, 1])
            with tf.variable_scope('vfn_%d' % (j), reuse=tf.AUTO_REUSE):
                t0_value[j] = self.value_fn_arch(self.obs_t[:, 0, :], dout=1)

        #Master-student score method
        with tf.variable_scope('student_value', reuse=tf.AUTO_REUSE):
            self.student_value_n = self.value_fn_arch(v_input[:, 0], dout=1)
        with tf.variable_scope('student_value', reuse=tf.AUTO_REUSE):
            self.student_value = self.value_fn_arch(self.obs_t[:, 0, :],
                                                    dout=1)

        gamma_coefs = np.ones([self.max_nstep], dtype=np.float32)
        gamma_coefs[1:] *= self.gamma
        gamma_coefs = np.cumprod(gamma_coefs)
        gamma_coefs = np.expand_dims(gamma_coefs, axis=1)

        log_q_tau = self.lprobs

        for i in range(self.n_rew_funct):
            for j in range(self.n_value_funct):
                for single_nsteps in range(1, self.max_nstep + 1):
                    # Define log p_tau(a|s) = r + gamma * V(s') - V(s)
                    qfn = tf.reduce_sum(
                        reward[i][:, :single_nsteps] *
                        gamma_coefs[:single_nsteps],
                        axis=1) + (self.gamma**single_nsteps) * (
                            fitted_value_fn_n[j][:, (single_nsteps - 1)])
                    log_p_tau[single_nsteps - 1][i * self.n_value_funct +
                                                 j] = qfn - t0_value[j]

        #Master-student score method
        student_qfn = self.student_reward + self.gamma * self.student_value_n
        student_log_p_tau = student_qfn - self.student_value

        mean_list = [None for i in range(self.max_nstep)]
        variance_list = [None for i in range(self.max_nstep)]

        for i in range(self.max_nstep):
            mean_list[i] = tf.reduce_mean(tf.stack(log_p_tau[i]), axis=0)
            variance_list[i] = tf.math.reduce_variance(tf.stack(log_p_tau[i]),
                                                       axis=0)

        self.weights = tf.concat(variance_list, axis=1)
        self.weights = tf.nn.softmax(1. / (self.weights + 1e-8), axis=1)

        # self.weights = tf.constant( ([0.0]*(self.max_nstep-1)) + [1.0], dtype=tf.float32 )

        log_p_tau = tf.reduce_sum(self.weights * tf.concat(mean_list, axis=1),
                                  axis=1,
                                  keepdims=True)

        log_pq = tf.reduce_logsumexp([log_p_tau, log_q_tau], axis=0)
        self.ensemble_discrim_output = tf.exp(log_p_tau - log_pq)

        #Master-student score method
        student_log_pq = tf.reduce_logsumexp([student_log_p_tau, log_q_tau],
                                             axis=0)
        self.student_discrim_output = tf.exp(student_log_p_tau -
                                             student_log_pq)
        self.student_loss = -tf.reduce_mean(tf.stop_gradient(self.ensemble_discrim_output)*(student_log_p_tau-student_log_pq) + \
                             (1-tf.stop_gradient(self.ensemble_discrim_output))*(log_q_tau-student_log_pq))
        self.student_step = tf.train.AdamOptimizer(
            learning_rate=self.lr).minimize(self.student_loss)
        self.student_absolute_loss = -tf.reduce_mean(
            self.labels * (student_log_p_tau - student_log_pq) +
            (1 - self.labels) * (log_q_tau - student_log_pq))

        self.ensemble_loss = -tf.reduce_mean(self.labels *
                                             (log_p_tau - log_pq) +
                                             (1 - self.labels) *
                                             (log_q_tau - log_pq))

        self.t0_reward = tf.reduce_mean(tf.concat(t0_reward, axis=1),
                                        axis=1,
                                        keepdims=True)
        self.t0_value = tf.reduce_mean(tf.concat(t0_value, axis=1),
                                       axis=1,
                                       keepdims=True)

    def _reorganize_states(self, paths, number_obs=1, pad_val=0.0):
        """
        Add multi-step views of observations and actions to each path, in place.

        Paths that already contain ``'observations_next'`` are skipped. For
        every other path with ``T`` observations the following keys are added:

        * ``'observations_next'`` (``T x dO``) / ``'actions_next'``
          (``T x dU``): row ``t`` holds the observation/action ``number_obs``
          steps after ``t``. Rows that would fall past the end of the path are
          filled with ``pad_val``.
        * ``'multi_observations'`` (``T x number_obs x dO``) /
          ``'multi_actions'`` (``T x number_obs x dU``), float32: row ``t`` is
          the window of ``number_obs`` consecutive steps starting at ``t``.
          Steps past the end of the path are filled with ``1.0``.

        Args:
            paths: list of path dicts holding ``'observations'`` (``T x dO``)
                and ``'actions'`` (``T x dU``) arrays.
            number_obs: window length, which is also the look-ahead distance
                for the ``*_next`` arrays.
            pad_val: fill value for the padded rows of the ``*_next`` arrays.

        Returns:
            ``paths`` itself; its dicts are mutated in place.
        """
        for path in paths:
            if 'observations_next' in path:
                continue
            observations = path['observations']
            actions = path['actions']
            horizon = observations.shape[0]

            # Shift by number_obs, then pad back up to `horizon` rows. Padding
            # min(number_obs, horizon) rows (not always number_obs) keeps the
            # row count aligned with the other keys even for paths shorter
            # than number_obs.
            n_pad = min(number_obs, horizon)
            path['observations_next'] = np.r_[
                observations[number_obs:],
                pad_val * np.ones(shape=[n_pad, self.dO], dtype=np.float32)]
            path['actions_next'] = np.r_[
                actions[number_obs:],
                pad_val * np.ones(shape=[n_pad, self.dU], dtype=np.float32)]

            # window_idx[t, k] == t + k. Appending number_obs rows keeps every
            # index in range, so windows near the end pick up padding rows.
            # NOTE(review): window padding is 1.0 rather than pad_val;
            # preserved from the original implementation.
            window_idx = (np.arange(horizon)[:, None] +
                          np.arange(number_obs)[None, :])
            padded_obs = np.r_[
                observations,
                np.ones(shape=[number_obs, self.dO], dtype=np.float32)]
            padded_act = np.r_[
                actions,
                np.ones(shape=[number_obs, self.dU], dtype=np.float32)]
            path['multi_observations'] = padded_obs[window_idx].astype(
                np.float32)
            path['multi_actions'] = padded_act[window_idx].astype(np.float32)

        return paths

    def fit(self,
            paths,
            policy=None,
            batch_size=32,
            logger=None,
            lr=1e-3,
            **kwargs):
        """
        Train the discriminator ensemble on policy samples versus expert demos.

        Every training iteration pairs each reward head ``i`` with a value
        head ``j`` (see ``_fit_head_pairs``). For each pair it takes one
        ``self.step[i][j]`` optimizer step per horizon ``1..self.max_nstep``.
        When the teacher/student score method is active, one extra
        ``self.student_step`` is taken per iteration on full-horizon batches.

        Args:
            paths: list of path dicts sampled from the current policy. The
                dicts are mutated in place: log-probs and multi-step views are
                inserted.
            policy: forwarded to ``self.eval_expert_probs``.
            batch_size: number of samples drawn from each of the policy and
                the expert data per step, so every feed has
                ``2 * batch_size`` rows.
            logger: optional object exposing ``record_tabular``. When given,
                diagnostic statistics are recorded after training.
            lr: learning rate fed to ``self.lr``.
            **kwargs: ignored.

        Returns:
            The value of ``it.pop_mean('loss')`` at the last
            ``TrainingIterator`` heartbeat, or ``None`` if no heartbeat fired.
        """
        if self.fusion is not None:
            old_paths = self.fusion.sample_paths(n=len(paths))
            self.fusion.add_paths(paths)
            paths = paths + old_paths

        # eval samples under current policy
        self._compute_path_probs(paths, insert=True)

        # eval expert log probs under current policy
        self.eval_expert_probs(self.expert_trajs, policy, insert=True)

        self._reorganize_states(paths, number_obs=self.max_nstep)
        self._reorganize_states(self.expert_trajs, number_obs=self.max_nstep)
        keys = ('multi_observations', 'observations_next', 'multi_actions',
                'actions_next', 'a_logprobs')
        obs, obs_next, acts, acts_next, path_probs = \
            self.extract_paths(paths, keys=keys)
        expert_obs, expert_obs_next, expert_acts, expert_acts_next, expert_probs = \
            self.extract_paths(self.expert_trajs, keys=keys)

        # Tuples in the positional order that self.sample_batch expects.
        policy_data = (obs_next, obs, acts_next, acts, path_probs)
        expert_data = (expert_obs_next, expert_obs, expert_acts_next,
                       expert_acts, expert_probs)

        use_student = (self.score_discrim is False
                       and self.score_method == 'teacher_student')
        sess = tf.get_default_session()

        # Pre-bound so the logging/return below cannot hit an unbound name
        # when no heartbeat or student step ran (e.g. self.max_itrs == 0).
        mean_loss = None
        rel_loss = abs_loss = ens_loss = None

        # Train discriminator
        for it in TrainingIterator(self.max_itrs, heartbeat=5):
            rew_idxs, val_idxs = self._fit_head_pairs()
            for i, j in zip(rew_idxs, val_idxs):
                for single_nstep in range(1, self.max_nstep + 1):
                    (obs_batch, nobs_batch, act_batch, nact_batch, labels,
                     lprobs_batch) = self._fit_feed_batch(
                         policy_data, expert_data, batch_size)
                    # A truncated horizon uses the step right after the
                    # window as its "next" state/action. The full horizon
                    # uses the precomputed *_next arrays.
                    full_horizon = single_nstep == self.max_nstep
                    feed_dict = {
                        self.act_t: act_batch[:, :single_nstep],
                        self.obs_t: obs_batch[:, :single_nstep],
                        self.nobs_t: (nobs_batch if full_horizon else
                                      obs_batch[:, single_nstep]),
                        self.nact_t: (nact_batch if full_horizon else
                                      act_batch[:, single_nstep]),
                        self.labels: labels,
                        self.lprobs: lprobs_batch,
                        self.lr: lr
                    }
                    loss, _ = sess.run([self.loss[i][j], self.step[i][j]],
                                       feed_dict=feed_dict)
                    it.record('loss', loss)

            if use_student:
                # Fit the student to the ensemble on a full-horizon batch.
                (obs_batch, nobs_batch, act_batch, nact_batch, labels,
                 lprobs_batch) = self._fit_feed_batch(policy_data, expert_data,
                                                      batch_size)
                feed_dict = {
                    self.act_t: act_batch,
                    self.obs_t: obs_batch,
                    self.nobs_t: nobs_batch,
                    self.nact_t: nact_batch,
                    self.labels: labels,
                    self.lprobs: lprobs_batch,
                    self.lr: lr
                }
                rel_loss, abs_loss, ens_loss, _ = sess.run(
                    [
                        self.student_loss, self.student_absolute_loss,
                        self.ensemble_loss, self.student_step
                    ],
                    feed_dict=feed_dict)

            if it.heartbeat:
                print(it.itr_message())
                mean_loss = it.pop_mean('loss')
                print('\tLoss:%f' % mean_loss)

        if logger:
            logger.record_tabular('GCLMeanDiscrimLoss', mean_loss)
            if use_student:
                logger.record_tabular('GCLStudentRelativeLoss', rel_loss)
                logger.record_tabular('GCLStudentAbsoluteLoss', abs_loss)
                logger.record_tabular('GCLEnsembleLoss', ens_loss)
            else:
                # Ensemble loss over all policy (label 0) + expert (label 1) data.
                all_labels = np.zeros((obs.shape[0] + expert_obs.shape[0], 1))
                all_labels[obs.shape[0]:] = 1.0
                e_loss = sess.run(
                    self.ensemble_loss,
                    feed_dict={
                        self.act_t:
                        np.concatenate([acts, expert_acts], axis=0),
                        self.obs_t:
                        np.concatenate([obs, expert_obs], axis=0),
                        self.nobs_t:
                        np.concatenate([obs_next, expert_obs_next], axis=0),
                        self.nact_t:
                        np.concatenate([acts_next, expert_acts_next], axis=0),
                        self.labels:
                        all_labels,
                        self.lprobs:
                        np.expand_dims(np.concatenate(
                            [path_probs, expert_probs], axis=0),
                                       axis=1)
                    })
                logger.record_tabular('GCLEnsembleDiscrimLoss', e_loss)

            # Statistics on the policy samples.
            energy, logZ, dtau, s_rew, s_val, s_dtau = self._fit_sample_stats(
                sess, obs, obs_next, acts, acts_next, path_probs, use_student)
            energy = -energy
            logger.record_tabular('GCLLogZ', np.mean(logZ))
            logger.record_tabular('GCLAverageEnergy', np.mean(energy))
            logger.record_tabular('GCLAverageLogPtau', np.mean(-energy - logZ))
            logger.record_tabular('GCLAverageLogQtau', np.mean(path_probs))
            logger.record_tabular('GCLMedianLogQtau', np.median(path_probs))
            logger.record_tabular('GCLAverageDtau', np.mean(dtau))
            if use_student:
                logger.record_tabular('GCLAverageStudentEnergy',
                                      np.mean(-s_rew))
                logger.record_tabular('GCLAverageStudentLogPtau',
                                      np.mean(s_rew - s_val))
                logger.record_tabular('GCLAverageStudentDtau', np.mean(s_dtau))

            # Statistics on the expert demonstrations.
            energy, logZ, dtau, s_rew, s_val, s_dtau = self._fit_sample_stats(
                sess, expert_obs, expert_obs_next, expert_acts,
                expert_acts_next, expert_probs, use_student)
            energy = -energy
            logger.record_tabular('GCLAverageExpertEnergy', np.mean(energy))
            logger.record_tabular('GCLAverageExpertLogPtau',
                                  np.mean(-energy - logZ))
            logger.record_tabular('GCLAverageExpertLogQtau',
                                  np.mean(expert_probs))
            logger.record_tabular('GCLMedianExpertLogQtau',
                                  np.median(expert_probs))
            logger.record_tabular('GCLAverageExpertDtau', np.mean(dtau))
            if use_student:
                logger.record_tabular('GCLAverageStudentExpertEnergy',
                                      np.mean(-s_rew))
                logger.record_tabular('GCLAverageStudentExpertLogPtau',
                                      np.mean(s_rew - s_val))
                logger.record_tabular('GCLAverageStudentExpertDtau',
                                      np.mean(s_dtau))

        return mean_loss

    def _fit_head_pairs(self):
        """
        Pair reward heads with value heads for one ``fit`` iteration.

        Returns two equal-length index arrays ``(rew_idxs, val_idxs)``. The
        larger ensemble's heads are listed in order. The other ensemble is
        covered by a random permutation of all its heads, topped up with heads
        drawn uniformly with replacement. On a tie, the value heads are
        permuted.
        """
        if self.n_rew_funct < self.n_value_funct:
            delta = self.n_value_funct - self.n_rew_funct
            perm = np.arange(self.n_rew_funct)
            np.random.shuffle(perm)
            rew_idxs = np.r_[perm,
                             np.random.randint(self.n_rew_funct, size=delta)]
            val_idxs = np.arange(self.n_value_funct)
        else:
            delta = self.n_rew_funct - self.n_value_funct
            perm = np.arange(self.n_value_funct)
            np.random.shuffle(perm)
            val_idxs = np.r_[perm,
                             np.random.randint(self.n_value_funct, size=delta)]
            rew_idxs = np.arange(self.n_rew_funct)
        return rew_idxs, val_idxs

    def _fit_feed_batch(self, policy_data, expert_data, batch_size):
        """
        Sample a balanced policy/expert batch for one discriminator step.

        ``policy_data`` and ``expert_data`` are ``(obs_next, obs, acts_next,
        acts, logprobs)`` tuples, passed positionally to ``self.sample_batch``.
        Policy rows come first (label 0), followed by expert rows (label 1).

        Returns:
            ``(obs, nobs, acts, nacts, labels, lprobs)``, where ``lprobs`` has
            a trailing singleton axis and is float32.
        """
        nobs_p, obs_p, nact_p, act_p, lprobs_p = self.sample_batch(
            *policy_data, batch_size=batch_size)
        nobs_e, obs_e, nact_e, act_e, lprobs_e = self.sample_batch(
            *expert_data, batch_size=batch_size)

        labels = np.zeros((batch_size * 2, 1))
        labels[batch_size:] = 1.0
        obs_batch = np.concatenate([obs_p, obs_e], axis=0)
        nobs_batch = np.concatenate([nobs_p, nobs_e], axis=0)
        act_batch = np.concatenate([act_p, act_e], axis=0)
        nact_batch = np.concatenate([nact_p, nact_e], axis=0)
        lprobs_batch = np.expand_dims(np.concatenate([lprobs_p, lprobs_e],
                                                     axis=0),
                                      axis=1).astype(np.float32)
        return obs_batch, nobs_batch, act_batch, nact_batch, labels, lprobs_batch

    def _fit_sample_stats(self, sess, obs, obs_next, acts, acts_next, probs,
                          use_student):
        """
        Evaluate the t=0 reward/value and discriminator outputs on one dataset.

        Returns:
            ``(t0_reward, t0_value, ensemble_dtau, student_reward,
            student_value, student_dtau)``. The three student entries are
            ``None`` unless ``use_student`` is true.
        """
        fetches = [self.t0_reward, self.t0_value, self.ensemble_discrim_output]
        if use_student:
            fetches += [
                self.student_reward, self.student_value,
                self.student_discrim_output
            ]
        values = sess.run(fetches,
                          feed_dict={
                              self.act_t: acts,
                              self.obs_t: obs,
                              self.nobs_t: obs_next,
                              self.nact_t: acts_next,
                              self.lprobs: np.expand_dims(probs, axis=1)
                          })
        values = list(values)
        if not use_student:
            values += [None, None, None]
        return tuple(values)

    def _compute_path_probs_modified(self,
                                     paths,
                                     pol_dist_type=None,
                                     insert=True,
                                     insert_key='a_logprobs'):
        """
        Compute, or fetch already-stored, action log-probabilities per path.

        If ``paths[0]`` already has ``insert_key``, the stored value of every
        path is returned via ``np.array`` and nothing is recomputed.

        Otherwise, when ``pol_dist_type`` is not given, it is inferred from
        ``paths[0]['agent_infos']``: ``'log_std'`` selects the Gaussian and
        ``'prob'`` selects the categorical distribution. Each path is then
        scored with ``gauss_log_pdf`` / ``categorical_log_pdf``. When
        ``insert`` is true, each path's result is stored under ``insert_key``.

        Returns:
            ``np.array`` of the per-path results (presumably N x T when all
            paths have T steps — not checked here).

        Raises:
            NotImplementedError: if the distribution type cannot be inferred
                or is not one of the supported types.
        """
        # Already computed (e.g. by an earlier fit()/eval()): reuse it.
        if insert_key in paths[0]:
            return np.array([p[insert_key] for p in paths])

        if pol_dist_type is None:
            first_infos = paths[0]['agent_infos']
            if 'log_std' in first_infos:
                pol_dist_type = DIST_GAUSSIAN
            elif 'prob' in first_infos:
                pol_dist_type = DIST_CATEGORICAL
            else:
                raise NotImplementedError()

        # NOTE(review): only the FIRST action of each path is scored against
        # that path's full distribution parameters. The unmodified variant
        # pairs every timestep with its own action. Confirm this is intended.
        first_actions = [p['actions'][0] for p in paths]
        if pol_dist_type == DIST_GAUSSIAN:
            log_pdf = gauss_log_pdf
            dist_params = [(p['agent_infos']['mean'],
                            p['agent_infos']['log_std']) for p in paths]
        elif pol_dist_type == DIST_CATEGORICAL:
            log_pdf = categorical_log_pdf
            dist_params = [(p['agent_infos']['prob'], ) for p in paths]
        else:
            raise NotImplementedError("Unknown distribution type")

        path_probs = [
            log_pdf(params, action)
            for params, action in zip(dist_params, first_actions)
        ]

        if insert:
            for p, log_prob in zip(paths, path_probs):
                p[insert_key] = log_prob

        return np.array(path_probs)

    def eval(self, paths, **kwargs):
        """
        Return bonus: a per-timestep score for every path in ``paths``.

        With ``self.score_discrim`` truthy, the score is the ensemble
        discriminator logit ``log D - log(1 - D)``, with ``D`` clipped to at
        least 1e-8 on both sides. Otherwise ``self.score_method`` selects a
        reward evaluated on the first step of each multi-step window:

        * ``'sample_rewards'``: one reward head drawn uniformly at random;
        * ``'average_rewards'``: ``self.avg_reward``;
        * ``'teacher_student'``: the distilled ``self.student_reward``.

        This method does not call ``_reorganize_states``. The multi-step keys
        it reads (``'multi_observations'`` etc.) must already be on the paths;
        ``fit`` adds them.

        Returns:
            The flat scores split back per path by ``self.unpack``.

        Raises:
            ValueError: if ``self.score_discrim`` is falsy and
                ``self.score_method`` is not one of the methods listed above.
        """
        if self.score_discrim:
            self._compute_path_probs_modified(paths, insert=True)
            obs, obs_next, acts, path_probs = self.extract_paths(
                paths,
                keys=('multi_observations', 'observations_next',
                      'multi_actions', 'a_logprobs'))
            path_probs = np.expand_dims(path_probs, axis=1)
            # NOTE(review): self.nact_t is not fed here, unlike in fit();
            # confirm ensemble_discrim_output does not depend on it.
            scores = tf.get_default_session().run(self.ensemble_discrim_output,
                                                  feed_dict={
                                                      self.act_t: acts,
                                                      self.obs_t: obs,
                                                      self.nobs_t: obs_next,
                                                      self.lprobs: path_probs
                                                  })
            # Clip away from 0 and 1 so neither log term is log(0).
            score = np.log(np.maximum(scores, 1e-8)) - np.log(
                np.maximum(1 - scores, 1e-8))
            score = score[:, 0]
            return self.unpack(score, paths)

        score_method = self.score_method
        # Fail loudly instead of leaving `score` unbound for an unknown method.
        if score_method not in ('sample_rewards', 'average_rewards',
                                'teacher_student'):
            raise ValueError('Unknown score_method: %r' % (score_method, ))

        obs, acts = self.extract_paths(paths,
                                       keys=('multi_observations',
                                             'multi_actions'))
        if score_method == 'sample_rewards':
            sampled_idx = np.random.randint(self.n_rew_funct)
            reward_op = self.reward[sampled_idx]
        elif score_method == 'average_rewards':
            reward_op = self.avg_reward
        else:
            reward_op = self.student_reward

        # Only the first step of each window is fed.
        reward = tf.get_default_session().run(reward_op,
                                              feed_dict={
                                                  self.act_t: acts[:, :1],
                                                  self.obs_t: obs[:, :1]
                                              })
        # The ensemble reward outputs are indexed as rank 3 here; the student
        # reward output is indexed as rank 2.
        if score_method == 'teacher_student':
            score = reward[:, 0]
        else:
            score = reward[:, 0, 0]
        return self.unpack(score, paths)