Example #1
    def fit(self,
            paths,
            irl_model=None,
            tempw=None,
            policy=None,
            qvar_model=None,
            batch_size=32,
            logger=None,
            lr=1e-3,
            **kwargs):

        if self.fusion is not None:
            old_paths = self.fusion.sample_paths(n=len(paths))
            self.fusion.add_paths(paths)
            paths = paths + old_paths

        #self._insert_next_state(paths)
        obs, ac, obs_next = self.extract_paths(paths,
                                               keys=('observations', 'actions',
                                                     'observations_next'))

        for it in TrainingIterator(self.max_itrs, heartbeat=1):

            nobs_batch, obs_batch, ac_batch = self.sample_batch(
                obs_next, obs, ac)
            dist_info_vars = policy.dist_info_sym(obs_batch, None)

            dist_s = policy.distribution.log_likelihood(
                policy.distribution.sample(dist_info_vars), dist_info_vars)

            q_input = tf.concat([obs_batch, nobs_batch], axis=1)
            q_dist_info_vars = qvar_model.dist_info_sym(q_input, None)

            q_dist_s = policy.distribution.log_likelihood(
                policy.distribution.sample(q_dist_info_vars), q_dist_info_vars)

            # Build feed dict; .eval() runs the sampled log-likelihood tensors
            # in the default session so they can be fed as NumPy arrays
            feed_dict = {
                self.obs_t: obs_batch,
                self.act_qvar: q_dist_s.eval(),
                self.act_policy: dist_s.eval(),
                self.lr: lr
            }

            loss_emp, _ = tf.get_default_session().run(
                [self.loss_emp, self.step_emp], feed_dict=feed_dict)
            it.record('loss_emp', loss_emp)
            if it.heartbeat:
                print(it.itr_message())
                mean_loss_emp = it.pop_mean('loss_emp')
                print('\tLoss_emp:%f' % mean_loss_emp)

        return mean_loss_emp
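
Note: all of these fit() methods lean on a sample_batch helper defined elsewhere in the class. A minimal sketch of what such a helper plausibly does (hypothetical, shown only so the snippets read self-contained) is to draw one set of random row indices and apply it to every array, keeping the (s', s, a, ...) tuples aligned:

import numpy as np

def sample_batch(*arrays, batch_size=32):
    # Hypothetical stand-in for self.sample_batch: use the same random row
    # indices for every input array so corresponding entries stay paired.
    n = arrays[0].shape[0]
    idx = np.random.randint(0, n, size=batch_size)
    return tuple(arr[idx] for arr in arrays)
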
Example #2
def tabular_maxent_irl(env,
                       demo_visitations,
                       num_itrs=50,
                       ent_wt=1.0,
                       lr=1e-3,
                       state_only=False,
                       discount=0.99,
                       T=5):
    dim_obs = env.observation_space.flat_dim
    dim_act = env.action_space.flat_dim

    # Initialize policy and reward function
    reward_fn = np.zeros((dim_obs, dim_act))
    q_rew = np.zeros((dim_obs, dim_act))

    update = adam_optimizer(lr)

    for it in TrainingIterator(num_itrs, heartbeat=1.0):
        q_itrs = 20 if it.itr > 5 else 100
        ### compute policy in closed form
        q_rew = q_iteration(env,
                            reward_matrix=reward_fn,
                            ent_wt=ent_wt,
                            warmstart_q=q_rew,
                            K=q_itrs,
                            gamma=discount)

        ### update reward
        # need to count how often the policy will visit a particular (s, a) pair
        pol_visitations = compute_visitation(env,
                                             q_rew,
                                             ent_wt=ent_wt,
                                             T=T,
                                             discount=discount)

        grad = -(demo_visitations - pol_visitations)
        it.record('VisitationInfNorm', np.max(np.abs(grad)))
        if state_only:
            grad = np.sum(grad, axis=1, keepdims=True)
        reward_fn = update(reward_fn, grad)

        if it.heartbeat:
            print(it.itr_message())
            print('\t', it.pop_mean('VisitationInfNorm'))
    return reward_fn, q_rew
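
The adam_optimizer(lr) factory is external to this snippet. A minimal sketch of a closure with the same calling convention (reward_fn = update(reward_fn, grad), taking a descent step on the supplied gradient) might look like the following; the hyperparameter names and defaults are assumptions for illustration only:

import numpy as np

def adam_optimizer(lr, beta1=0.9, beta2=0.999, eps=1e-8):
    # Returns update(params, grad) -> new_params, keeping Adam moment
    # estimates in the enclosing scope between calls.
    state = {'m': None, 'v': None, 't': 0}

    def update(params, grad):
        if state['m'] is None:
            state['m'] = np.zeros_like(params)
            state['v'] = np.zeros_like(params)
        state['t'] += 1
        state['m'] = beta1 * state['m'] + (1 - beta1) * grad
        state['v'] = beta2 * state['v'] + (1 - beta2) * grad ** 2
        m_hat = state['m'] / (1 - beta1 ** state['t'])
        v_hat = state['v'] / (1 - beta2 ** state['t'])
        return params - lr * m_hat / (np.sqrt(v_hat) + eps)

    return update
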
Example #3
    def fit(self, paths, batch_size=32, logger=None, lr=1e-3, **kwargs):

        if self.fusion is not None:
            old_paths = self.fusion.sample_paths(n=len(paths))
            self.fusion.add_paths(paths)
            paths = paths+old_paths


        obs, obs_next, acts = \
            self.extract_paths(paths,
                               keys=('observations', 'observations_next', 'actions'))
        '''expert_obs, expert_obs_next, expert_acts = \
            self.extract_paths(self.expert_trajs,
                               keys=('observations', 'observations_next', 'actions'))'''


        # Train discriminator
        for it in TrainingIterator(self.max_itrs, heartbeat=5):
            nobs_batch, obs_batch, act_batch = \
                self.sample_batch(obs_next, obs, acts, batch_size=batch_size)


            feed_dict = {
                self.act_t: act_batch,
                self.obs_t: obs_batch,
                self.nobs_t: nobs_batch,
                self.lr: lr
                }

            loss_q, _ = tf.get_default_session().run([self.loss_q, self.step_q], feed_dict=feed_dict)
            it.record('loss_q', loss_q)
            if it.heartbeat:
                mean_loss_q = it.pop_mean('loss_q')
                print('\tLoss_q:%f' % mean_loss_q)


        return mean_loss_q
Example #4
    def fit(self, paths, policy=None, batch_size=32, logger=None, lr=1e-3, **kwargs):

        if self.fusion is not None:
            old_paths = self.fusion.sample_paths(n=len(paths))
            self.fusion.add_paths(paths)
            paths = paths+old_paths

        # eval samples under current policy
        self._compute_path_probs(paths, insert=True)

        # eval expert log probs under current policy
        self.eval_expert_probs(self.expert_trajs, policy, insert=True)

        self._insert_next_state(paths)
        self._insert_next_state(self.expert_trajs)
        obs, obs_next, acts, acts_next, path_probs = \
            self.extract_paths(paths,
                               keys=('observations', 'observations_next', 'actions', 'actions_next', 'a_logprobs'))
        expert_obs, expert_obs_next, expert_acts, expert_acts_next, expert_probs = \
            self.extract_paths(self.expert_trajs,
                               keys=('observations', 'observations_next', 'actions', 'actions_next', 'a_logprobs'))


        # Train discriminator
        for it in TrainingIterator(self.max_itrs, heartbeat=5):
            nobs_batch, obs_batch, nact_batch, act_batch, lprobs_batch = \
                self.sample_batch(obs_next, obs, acts_next, acts, path_probs, batch_size=batch_size)

            nexpert_obs_batch, expert_obs_batch, nexpert_act_batch, expert_act_batch, expert_lprobs_batch = \
                self.sample_batch(expert_obs_next, expert_obs, expert_acts_next, expert_acts, expert_probs, batch_size=batch_size)

            # Build feed dict
            labels = np.zeros((batch_size*2, 1))
            labels[batch_size:] = 1.0
            obs_batch = np.concatenate([obs_batch, expert_obs_batch], axis=0)
            nobs_batch = np.concatenate([nobs_batch, nexpert_obs_batch], axis=0)
            act_batch = np.concatenate([act_batch, expert_act_batch], axis=0)
            nact_batch = np.concatenate([nact_batch, nexpert_act_batch], axis=0)
            lprobs_batch = np.expand_dims(np.concatenate([lprobs_batch, expert_lprobs_batch], axis=0), axis=1).astype(np.float32)
            feed_dict = {
                self.act_t: act_batch,
                self.obs_t: obs_batch,
                self.nobs_t: nobs_batch,
                self.nact_t: nact_batch,
                self.labels: labels,
                self.lprobs: lprobs_batch,
                self.lr: lr
                }

            loss, _ = tf.get_default_session().run([self.loss, self.step], feed_dict=feed_dict)
            it.record('loss', loss)
            if it.heartbeat:
                print(it.itr_message())
                mean_loss = it.pop_mean('loss')
                print('\tLoss:%f' % mean_loss)

        if logger:
            logger.record_tabular('GCLDiscrimLoss', mean_loss)
            #obs_next = np.r_[obs_next, np.expand_dims(obs_next[-1], axis=0)]
            energy, logZ, dtau = tf.get_default_session().run(
                [self.reward, self.value_fn, self.discrim_output],
                feed_dict={
                    self.act_t: acts,
                    self.obs_t: obs,
                    self.nobs_t: obs_next,
                    self.nact_t: acts_next,
                    self.lprobs: np.expand_dims(path_probs, axis=1)
                })
            energy = -energy
            logger.record_tabular('GCLLogZ', np.mean(logZ))
            logger.record_tabular('GCLAverageEnergy', np.mean(energy))
            logger.record_tabular('GCLAverageLogPtau', np.mean(-energy-logZ))
            logger.record_tabular('GCLAverageLogQtau', np.mean(path_probs))
            logger.record_tabular('GCLMedianLogQtau', np.median(path_probs))
            logger.record_tabular('GCLAverageDtau', np.mean(dtau))


            #expert_obs_next = np.r_[expert_obs_next, np.expand_dims(expert_obs_next[-1], axis=0)]
            energy, logZ, dtau = tf.get_default_session().run(
                [self.reward, self.value_fn, self.discrim_output],
                feed_dict={
                    self.act_t: expert_acts,
                    self.obs_t: expert_obs,
                    self.nobs_t: expert_obs_next,
                    self.nact_t: expert_acts_next,
                    self.lprobs: np.expand_dims(expert_probs, axis=1)
                })
            energy = -energy
            logger.record_tabular('GCLAverageExpertEnergy', np.mean(energy))
            logger.record_tabular('GCLAverageExpertLogPtau', np.mean(-energy-logZ))
            logger.record_tabular('GCLAverageExpertLogQtau', np.mean(expert_probs))
            logger.record_tabular('GCLMedianExpertLogQtau', np.median(expert_probs))
            logger.record_tabular('GCLAverageExpertDtau', np.mean(dtau))
        return mean_loss
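
The placeholders fed above (obs_t, act_t, nobs_t, nact_t, labels, lprobs, lr) drive a binary discriminator that also exposes self.reward, self.value_fn and self.discrim_output for logging. A hedged reconstruction of how such an AIRL-style discriminator head is typically assembled (a sketch under those assumptions, not the exact graph built by this class):

import tensorflow as tf  # TF1-style graph mode, as in the snippets above

def build_discriminator(reward, value, value_next, lprobs, labels, gamma=0.99):
    # reward/value/value_next: per-sample tensors from learned networks;
    # lprobs: log pi(a|s) under the sampling policy; labels: 1 for expert
    # samples, 0 for policy samples (matching the feed dict above).
    log_p_tau = reward + gamma * value_next - value      # learned trajectory term
    log_q_tau = lprobs                                    # policy likelihood term
    log_pq = tf.reduce_logsumexp(tf.stack([log_p_tau, log_q_tau], axis=0), axis=0)
    discrim_output = tf.exp(log_p_tau - log_pq)           # D(s, a, s') in [0, 1]
    # binary cross-entropy between D and the expert/policy labels
    loss = -tf.reduce_mean(labels * (log_p_tau - log_pq) +
                           (1 - labels) * (log_q_tau - log_pq))
    return discrim_output, loss
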
Example #5
    def fit(self,
            paths,
            expert_traj_batch=None,
            policy=None,
            batch_size=32,
            logger=None,
            lr=1e-3,
            **kwargs):
        meta_batch_size = self.meta_batch_size
        if self.fusion is not None:
            old_paths = self.fusion.sample_paths(expert_traj_batch,
                                                 n=len(paths[0]))
            self.fusion.add_paths(paths, expert_traj_batch, subsample=True)
            if old_paths is not None:
                for key in paths.keys():
                    paths[key] += old_paths[key]

        # Do we need to recalculate path probabilities every iteration, since the context encoder is being updated?
        # eval samples under current policy
        # TODO: fix this with dict
        self._compute_path_probs_dict(paths, insert=True)

        self._insert_next_state(paths)
        self._insert_next_state(self.expert_trajs)

        obs, obs_next, acts, acts_next, path_probs = \
            self.extract_paths(paths,
                               keys=('observations', 'observations_next', 'actions', 'actions_next', 'a_logprobs'), T=self.T)
        # TODO: we may need to assume that expert_trajs is also a dict
        expert_obs, expert_obs_next, expert_acts, expert_acts_next, expert_contexts = \
            self.extract_paths(self.expert_trajs,
                               keys=('observations', 'observations_next', 'actions', 'actions_next', 'contexts'), T=self.T)

        # eval expert log probs under current policy
        expert_trajs = np.concatenate([expert_obs, expert_acts], axis=-1)
        m_hat_expert = self.context_encoder.get_actions(
            expert_trajs.reshape(-1, self.T * (self.dO + self.dU)))[0]
        self.eval_expert_probs(self.expert_trajs,
                               policy,
                               insert=True,
                               context=m_hat_expert)

        expert_probs = self.extract_paths(self.expert_trajs,
                                          keys=('a_logprobs', ),
                                          T=self.T)[0]

        # Train discriminator
        expert_traj_batch_tile = np.tile(
            expert_traj_batch.reshape(meta_batch_size, 1, self.T, -1),
            [1, batch_size, 1, 1])

        for it in TrainingIterator(self.max_itrs, heartbeat=5):
            # TODO: implement sample_batch in imitation_learning.py
            nobs_batch, obs_batch, nact_batch, act_batch, lprobs_batch = \
                self.sample_batch(obs_next, obs, acts_next, acts, path_probs, batch_size=batch_size)
            if obs_batch.shape[-1] == self.dO + self.latent_dim:
                nobs_batch = nobs_batch[..., :-self.latent_dim]
                obs_batch = obs_batch[..., :-self.latent_dim]

            # First half of the batch is used for inferring m_hat
            nexpert_obs_batch, expert_obs_batch, nexpert_act_batch, expert_act_batch, expert_lprobs_batch = \
                self.sample_batch(expert_obs_next, expert_obs, expert_acts_next, expert_acts, expert_probs, batch_size=self.meta_batch_size*batch_size)
            if expert_obs_batch.shape[-1] == self.dO + self.latent_dim:
                nexpert_obs_batch = nexpert_obs_batch[..., :-self.latent_dim]
                expert_obs_batch = expert_obs_batch[..., :-self.latent_dim]

            # Build feed dict
            labels = np.zeros((meta_batch_size, batch_size * 2, 1, 1))
            labels[:, batch_size:, ...] = 1.0
            imitation_expert_obses_input = expert_traj_batch.reshape(
                meta_batch_size, 1, self.T, -1)[:, :, :, :self.dO]
            imitation_expert_acts_input = expert_traj_batch.reshape(
                meta_batch_size, 1, self.T, -1)[:, :, :, self.dO:]
            expert_traj_batch_input = np.concatenate([
                expert_traj_batch_tile,
                np.concatenate([expert_obs_batch, expert_act_batch],
                               axis=-1).reshape(meta_batch_size, batch_size,
                                                self.T, -1)
            ], axis=1)
            sample_traj_batch = np.concatenate([obs_batch, act_batch], axis=-1)
            obs_batch = np.concatenate([
                obs_batch,
                expert_obs_batch.reshape(meta_batch_size, batch_size, self.T, -1)
            ], axis=1)
            nobs_batch = np.concatenate([
                nobs_batch,
                nexpert_obs_batch.reshape(meta_batch_size, batch_size, self.T, -1)
            ], axis=1)
            act_batch = np.concatenate([
                act_batch,
                expert_act_batch.reshape(meta_batch_size, batch_size, self.T, -1)
            ], axis=1)
            nact_batch = np.concatenate([
                nact_batch,
                nexpert_act_batch.reshape(meta_batch_size, batch_size, self.T, -1)
            ], axis=1)
            lprobs_batch = np.concatenate([
                lprobs_batch,
                expert_lprobs_batch.reshape(meta_batch_size, batch_size, self.T, -1)
            ], axis=1).astype(np.float32)
            feed_dict = {
                self.expert_traj_var: expert_traj_batch_input,
                self.sample_traj_var: sample_traj_batch,
                self.act_t: act_batch,
                self.obs_t: obs_batch,
                self.nobs_t: nobs_batch,
                self.nact_t: nact_batch,
                self.labels: labels,
                self.lprobs: lprobs_batch,
                self.imitation_expert_obses: imitation_expert_obses_input,
                self.imitation_expert_acts: imitation_expert_acts_input,
                self.lr: lr
            }

            loss, _ = tf.get_default_session().run([self.loss, self.step],
                                                   feed_dict=feed_dict)
            it.record('loss', loss)
            if it.heartbeat:
                print(it.itr_message())
                mean_loss = it.pop_mean('loss')
                print('\tLoss:%f' % mean_loss)

        if logger:
            logger.record_tabular('GCLDiscrimLoss', mean_loss)
            #obs_next = np.r_[obs_next, np.expand_dims(obs_next[-1], axis=0)]
            # TODO: fix this
            expert_traj_logging = np.tile(
                expert_traj_batch.reshape(meta_batch_size, 1, self.T, -1),
                [1, acts.shape[1], 1, 1])
            imitation_expert_obses_input = expert_traj_batch.reshape(
                meta_batch_size, 1, self.T, -1)[:, :, :, :self.dO]
            imitation_expert_acts_input = expert_traj_batch.reshape(
                meta_batch_size, 1, self.T, -1)[:, :, :, self.dO:]

            energy, logZ, dtau, info_loss, imit_loss = tf.get_default_session().run(
                [
                    self.reward, self.value_fn, self.discrim_output,
                    self.info_loss, self.policy_likelihood_loss
                ],
                feed_dict={
                    self.expert_traj_var: expert_traj_logging,
                    self.sample_traj_var: np.concatenate(
                        [obs[..., :-self.latent_dim], acts], axis=-1),
                    self.act_t: acts,
                    self.obs_t: obs[..., :-self.latent_dim],
                    self.nobs_t: obs_next[..., :-self.latent_dim],
                    self.nact_t: acts_next,
                    self.imitation_expert_obses: imitation_expert_obses_input,
                    self.imitation_expert_acts: imitation_expert_acts_input,
                    self.labels: np.zeros([meta_batch_size, acts.shape[1], 1, 1]),
                    self.lprobs: path_probs
                })
            energy = -energy
            logger.record_tabular('GCLLogZ', np.mean(logZ))
            logger.record_tabular('GCLAverageEnergy', np.mean(energy))
            logger.record_tabular('GCLAverageLogPtau', np.mean(-energy - logZ))
            logger.record_tabular('GCLAverageLogQtau', np.mean(path_probs))
            logger.record_tabular('GCLMedianLogQtau', np.median(path_probs))
            logger.record_tabular('GCLAverageDtau', np.mean(dtau))
            logger.record_tabular('GCLAverageMutualInfo', np.mean(info_loss))
            logger.record_tabular('GCLAverageImitationLoss',
                                  np.mean(imit_loss))

            #expert_obs_next = np.r_[expert_obs_next, np.expand_dims(expert_obs_next[-1], axis=0)]
            # Not sure if using expert trajectories for expert_traj_var and sample_traj_var makes sense
            # energy, logZ, dtau, info_loss = tf.get_default_session().run([self.reward, self.value_fn, self.discrim_output, self.info_loss],
            #         feed_dict={self.expert_traj_var: np.concatenate([expert_obs, expert_acts], axis=-1),
            #                     self.sample_traj_var: np.concatenate([expert_obs, expert_acts], axis=-1),
            #                         self.act_t: expert_acts, self.obs_t: expert_obs, self.nobs_t: expert_obs_next,
            #                         self.nact_t: expert_acts_next,
            #                         self.labels: np.zeros([meta_batch_size, acts.shape[1], 1, 1]),
            #                         self.lprobs: expert_probs})
            # energy = -energy
            # logger.record_tabular('GCLAverageExpertEnergy', np.mean(energy))
            # logger.record_tabular('GCLAverageExpertLogPtau', np.mean(-energy-logZ))
            # logger.record_tabular('GCLAverageExpertLogQtau', np.mean(expert_probs))
            # logger.record_tabular('GCLMedianExpertLogQtau', np.median(expert_probs))
            # logger.record_tabular('GCLAverageExpertDtau', np.mean(dtau))
            # logger.record_tabular('GCLAverageExpertMutualInfo', np.mean(info_loss))
        return mean_loss
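
For reference, the expert_traj_batch_tile built before the discriminator loop above simply repeats each task's expert trajectory once per sampled batch element. A small shape check with illustrative sizes only:

import numpy as np

meta_batch_size, batch_size, T, D = 4, 32, 20, 10   # illustrative sizes
expert_traj_batch = np.zeros((meta_batch_size, T * D))
tiled = np.tile(expert_traj_batch.reshape(meta_batch_size, 1, T, D),
                [1, batch_size, 1, 1])
assert tiled.shape == (meta_batch_size, batch_size, T, D)
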
Example #6
def tabular_gcl_irl(env,
                    demo_visitations,
                    irl_model,
                    num_itrs=50,
                    ent_wt=1.0,
                    lr=1e-3,
                    state_only=False,
                    discount=0.99,
                    batch_size=20024):
    dim_obs = env.observation_space.flat_dim
    dim_act = env.action_space.flat_dim

    states_all = []
    actions_all = []
    for s in range(dim_obs):
        for a in range(dim_act):
            states_all.append(flat_to_one_hot(s, dim_obs))
            actions_all.append(flat_to_one_hot(a, dim_act))
    states_all = np.array(states_all)
    actions_all = np.array(actions_all)
    path_all = {'observations': states_all, 'actions': actions_all}

    # Initialize policy and reward function
    reward_fn = np.zeros((dim_obs, dim_act))
    q_rew = np.zeros((dim_obs, dim_act))

    update = adam_optimizer(lr)

    for it in TrainingIterator(num_itrs, heartbeat=1.0):
        q_itrs = 20 if it.itr > 5 else 100
        ### compute policy in closed form
        q_rew = q_iteration(env,
                            reward_matrix=reward_fn,
                            ent_wt=ent_wt,
                            warmstart_q=q_rew,
                            K=q_itrs,
                            gamma=discount)
        pol_rew = get_policy(q_rew, ent_wt=ent_wt)

        ### update reward
        # need to count how often the policy will visit a particular (s, a) pair
        pol_visitations = compute_visitation(env,
                                             q_rew,
                                             ent_wt=ent_wt,
                                             T=5,
                                             discount=discount)

        # now we need to sample states and actions, and give them to the discriminator
        demo_path = sample_states(env, q_rew, demo_visitations, batch_size,
                                  ent_wt)
        irl_model.set_demos([demo_path])
        path = sample_states(env, q_rew, pol_visitations, batch_size, ent_wt)
        irl_model.fit([path],
                      policy=pol_rew,
                      max_itrs=200,
                      lr=1e-3,
                      batch_size=1024)

        rew_stack = irl_model.eval([path_all])[0]
        reward_fn = np.zeros_like(q_rew)
        i = 0
        for s in range(dim_obs):
            for a in range(dim_act):
                reward_fn[s, a] = rew_stack[i]
                i += 1

        diff_visit = np.abs(demo_visitations - pol_visitations)
        it.record('VisitationDiffInfNorm', np.max(diff_visit))
        it.record('VisitationDiffAvg', np.mean(diff_visit))

        if it.heartbeat:
            print(it.itr_message())
            print('\tVisitationDiffInfNorm:',
                  it.pop_mean('VisitationDiffInfNorm'))
            print('\tVisitationDiffAvg:', it.pop_mean('VisitationDiffAvg'))

            print('visitations', pol_visitations)
            print('diff_visit', diff_visit)
            adjusted_rew = reward_fn - np.mean(reward_fn) + np.mean(
                env.rew_matrix)
            print('adjusted_rew', adjusted_rew)
    return reward_fn, q_rew
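
flat_to_one_hot is an external helper used to enumerate every (state, action) pair for evaluation. A minimal sketch of the obvious implementation (hypothetical):

import numpy as np

def flat_to_one_hot(index, dim):
    # Hypothetical helper: one-hot encode a flat integer state/action index.
    vec = np.zeros(dim)
    vec[index] = 1.0
    return vec
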
Example #7
    def fit(self, paths, policy=None, empw_model=None, t_empw_model=None,
            batch_size=32, logger=None, lr=1e-3, **kwargs):

        if self.fusion is not None:
            old_paths = self.fusion.sample_paths(n=len(paths))
            self.fusion.add_paths(paths)
            paths = paths+old_paths

        # eval samples under current policy
        self._compute_path_probs(paths, insert=True)

        # eval expert log probs under current policy
        self.eval_expert_probs(self.expert_trajs, policy, insert=True)

        self._insert_next_state(paths)
        self._insert_next_state(self.expert_trajs)
        obs, obs_next, acts, acts_next, path_probs = \
            self.extract_paths(paths,
                               keys=('observations', 'observations_next', 'actions', 'actions_next', 'a_logprobs'))
        expert_obs, expert_obs_next, expert_acts, expert_acts_next, expert_probs = \
            self.extract_paths(self.expert_trajs,
                               keys=('observations', 'observations_next', 'actions', 'actions_next', 'a_logprobs'))


        # Train discriminator
        for it in TrainingIterator(self.max_itrs, heartbeat=5):
            nobs_batch, obs_batch, nact_batch, act_batch, lprobs_batch = \
                self.sample_batch(obs_next, obs, acts_next, acts, path_probs, batch_size=batch_size)

            nexpert_obs_batch, expert_obs_batch, nexpert_act_batch, expert_act_batch, expert_lprobs_batch = \
                self.sample_batch(expert_obs_next, expert_obs, expert_acts_next, expert_acts, expert_probs, batch_size=batch_size)

            # Build feed dict
            labels = np.zeros((batch_size*2, 1))
            labels[batch_size:] = 1.0
            obs_batch = np.concatenate([obs_batch, expert_obs_batch], axis=0)
            nobs_batch = np.concatenate([nobs_batch, nexpert_obs_batch], axis=0)
            act_batch = np.concatenate([act_batch, expert_act_batch], axis=0)
            nact_batch = np.concatenate([nact_batch, nexpert_act_batch], axis=0)
            lprobs_batch = np.expand_dims(np.concatenate([lprobs_batch, expert_lprobs_batch], axis=0), axis=1).astype(np.float32)
            # evaluate the empowerment model on the current observations and the
            # target empowerment model on the next observations
            vs = empw_model.eval(obs_batch)
            vsp = t_empw_model.eval(nobs_batch)
            feed_dict = {
                self.act_t: act_batch,
                self.obs_t: obs_batch,
                self.nobs_t: nobs_batch,
                self.nact_t: nact_batch,
                self.labels: labels,
                self.lprobs: lprobs_batch,
                self.lr: lr,
                self.vs: vs,
                self.vsp: vsp
                }


            loss_irl, _ = tf.get_default_session().run([self.loss_irl, self.step_irl], feed_dict=feed_dict)
            it.record('loss_irl', loss_irl)

            if it.heartbeat:
                print(it.itr_message())
                mean_loss_irl = it.pop_mean('loss_irl')
                print('\tLoss_irl:%f' % mean_loss_irl)

         

        return mean_loss_irl
Example #8
    def fit(self,
            paths,
            policy=None,
            batch_size=32,
            logger=None,
            lr=1e-3,
            **kwargs):

        if self.fusion is not None:
            old_paths = self.fusion.sample_paths(n=len(paths))
            self.fusion.add_paths(paths)
            paths = paths + old_paths

        # eval samples under current policy
        self._compute_path_probs(paths, insert=True)

        # eval expert log probs under current policy
        self.eval_expert_probs(self.expert_trajs, policy, insert=True)

        self._reorganize_states(paths, number_obs=self.max_nstep)
        self._reorganize_states(self.expert_trajs, number_obs=self.max_nstep)
        obs, obs_next, acts, acts_next, path_probs = \
            self.extract_paths(paths,
                               keys=('multi_observations', 'observations_next', 'multi_actions', 'actions_next', 'a_logprobs'))
        expert_obs, expert_obs_next, expert_acts, expert_acts_next, expert_probs = \
            self.extract_paths(self.expert_trajs,
                               keys=('multi_observations', 'observations_next', 'multi_actions', 'actions_next', 'a_logprobs'))

        all_obs = np.concatenate([obs, expert_obs], axis=0)
        all_nobs = np.concatenate([obs_next, expert_obs_next], axis=0)
        all_acts = np.concatenate([acts, expert_acts], axis=0)
        all_nacts = np.concatenate([acts_next, expert_acts_next], axis=0)
        all_probs = np.concatenate([path_probs, expert_probs], axis=0)
        all_labels = np.zeros((all_obs.shape[0], 1))
        all_labels[obs.shape[0]:] = 1.0

        # Train discriminator
        for it in TrainingIterator(self.max_itrs, heartbeat=5):

            if self.n_rew_funct < self.n_value_funct:
                delta = self.n_value_funct - self.n_rew_funct
                temp = np.arange(self.n_rew_funct)
                np.random.shuffle(temp)
                rew_idxs = np.r_[
                    temp,
                    np.random.randint(self.n_rew_funct, size=delta)]
                val_idxs = np.arange(self.n_value_funct)
            else:
                delta = self.n_rew_funct - self.n_value_funct
                temp = np.arange(self.n_value_funct)
                np.random.shuffle(temp)
                val_idxs = np.r_[
                    temp,
                    np.random.randint(self.n_value_funct, size=delta)]
                rew_idxs = np.arange(self.n_rew_funct)

            for idx in range(val_idxs.shape[0]):
                i = rew_idxs[idx]
                j = val_idxs[idx]
                for single_nstep in range(1, self.max_nstep + 1):
                    nobs_batch, obs_batch, nact_batch, act_batch, lprobs_batch = \
                        self.sample_batch(obs_next, obs, acts_next, acts, path_probs, batch_size=batch_size)

                    nexpert_obs_batch, expert_obs_batch, nexpert_act_batch, expert_act_batch, expert_lprobs_batch = \
                        self.sample_batch(expert_obs_next, expert_obs, expert_acts_next, expert_acts, expert_probs, batch_size=batch_size)

                    # Build feed dict
                    labels = np.zeros((batch_size * 2, 1))
                    labels[batch_size:] = 1.0
                    obs_batch = np.concatenate([obs_batch, expert_obs_batch],
                                               axis=0)
                    nobs_batch = np.concatenate(
                        [nobs_batch, nexpert_obs_batch], axis=0)
                    act_batch = np.concatenate([act_batch, expert_act_batch],
                                               axis=0)
                    nact_batch = np.concatenate(
                        [nact_batch, nexpert_act_batch], axis=0)
                    lprobs_batch = np.expand_dims(np.concatenate(
                        [lprobs_batch, expert_lprobs_batch], axis=0),
                                                  axis=1).astype(np.float32)
                    feed_dict = {
                        self.act_t: act_batch[:, :single_nstep],
                        self.obs_t: obs_batch[:, :single_nstep],
                        self.nobs_t: (nobs_batch if self.max_nstep == single_nstep
                                      else obs_batch[:, single_nstep]),
                        self.nact_t: (nact_batch if self.max_nstep == single_nstep
                                      else act_batch[:, single_nstep]),
                        self.labels: labels,
                        self.lprobs: lprobs_batch,
                        self.lr: lr
                    }

                    loss, _ = tf.get_default_session().run(
                        [self.loss[i][j], self.step[i][j]],
                        feed_dict=feed_dict)
                    it.record('loss', loss)

            if self.score_discrim is False and self.score_method == 'teacher_student':
                nobs_batch, obs_batch, nact_batch, act_batch, lprobs_batch = \
                    self.sample_batch(obs_next, obs, acts_next, acts, path_probs, batch_size=batch_size)

                nexpert_obs_batch, expert_obs_batch, nexpert_act_batch, expert_act_batch, expert_lprobs_batch = \
                    self.sample_batch(expert_obs_next, expert_obs, expert_acts_next, expert_acts, expert_probs, batch_size=batch_size)

                # Build feed dict
                labels = np.zeros((batch_size * 2, 1))
                labels[batch_size:] = 1.0
                obs_batch = np.concatenate([obs_batch, expert_obs_batch],
                                           axis=0)
                nobs_batch = np.concatenate([nobs_batch, nexpert_obs_batch],
                                            axis=0)
                act_batch = np.concatenate([act_batch, expert_act_batch],
                                           axis=0)
                nact_batch = np.concatenate([nact_batch, nexpert_act_batch],
                                            axis=0)
                lprobs_batch = np.expand_dims(np.concatenate(
                    [lprobs_batch, expert_lprobs_batch], axis=0),
                                              axis=1).astype(np.float32)
                feed_dict = {
                    self.act_t: act_batch,
                    self.obs_t: obs_batch,
                    self.nobs_t: nobs_batch,
                    self.nact_t: nact_batch,
                    self.labels: labels,
                    self.lprobs: lprobs_batch,
                    self.lr: lr
                }

                rel_loss, abs_loss, ens_loss, _ = tf.get_default_session().run(
                    [
                        self.student_loss, self.student_absolute_loss,
                        self.ensemble_loss, self.student_step
                    ],
                    feed_dict=feed_dict)

            if it.heartbeat:
                print(it.itr_message())
                mean_loss = it.pop_mean('loss')
                print('\tLoss:%f' % mean_loss)

        if logger:
            logger.record_tabular('GCLMeanDiscrimLoss', mean_loss)
            sess = tf.get_default_session()
            if self.score_discrim is False and self.score_method == 'teacher_student':
                logger.record_tabular('GCLStudentRelativeLoss', rel_loss)
                logger.record_tabular('GCLStudentAbsoluteLoss', abs_loss)
                logger.record_tabular('GCLEnsembleLoss', ens_loss)
            else:
                e_loss, weights = sess.run(
                    [self.ensemble_loss, self.weights],
                    feed_dict={
                        self.act_t: all_acts,
                        self.obs_t: all_obs,
                        self.nobs_t: all_nobs,
                        self.nact_t: all_nacts,
                        self.labels: all_labels,
                        self.lprobs: np.expand_dims(all_probs, axis=1)
                    })
                logger.record_tabular('GCLEnsembleDiscrimLoss', e_loss)
                # logger.record_tabular('TimeWeights', weights)

            if self.score_discrim is False and self.score_method == 'teacher_student':
                energy, logZ, dtau, s_rew, s_val, s_dtau = sess.run(
                    [
                        self.t0_reward, self.t0_value,
                        self.ensemble_discrim_output, self.student_reward,
                        self.student_value, self.student_discrim_output
                    ],
                    feed_dict={
                        self.act_t: acts,
                        self.obs_t: obs,
                        self.nobs_t: obs_next,
                        self.nact_t: acts_next,
                        self.lprobs: np.expand_dims(path_probs, axis=1)
                    })
            else:
                energy, logZ, dtau = sess.run(
                    [
                        self.t0_reward, self.t0_value,
                        self.ensemble_discrim_output
                    ],
                    feed_dict={
                        self.act_t: acts,
                        self.obs_t: obs,
                        self.nobs_t: obs_next,
                        self.nact_t: acts_next,
                        self.lprobs: np.expand_dims(path_probs, axis=1)
                    })

            energy = -energy
            logger.record_tabular('GCLLogZ', np.mean(logZ))
            logger.record_tabular('GCLAverageEnergy', np.mean(energy))
            logger.record_tabular('GCLAverageLogPtau', np.mean(-energy - logZ))
            logger.record_tabular('GCLAverageLogQtau', np.mean(path_probs))
            logger.record_tabular('GCLMedianLogQtau', np.median(path_probs))
            logger.record_tabular('GCLAverageDtau', np.mean(dtau))

            if self.score_discrim is False and self.score_method == 'teacher_student':
                logger.record_tabular('GCLAverageStudentEnergy',
                                      np.mean(-s_rew))
                logger.record_tabular('GCLAverageStudentLogPtau',
                                      np.mean(s_rew - s_val))
                logger.record_tabular('GCLAverageStudentDtau', np.mean(s_dtau))

            if self.score_discrim is False and self.score_method == 'teacher_student':
                energy, logZ, dtau, s_rew, s_val, s_dtau = sess.run(
                    [
                        self.t0_reward, self.t0_value,
                        self.ensemble_discrim_output, self.student_reward,
                        self.student_value, self.student_discrim_output
                    ],
                    feed_dict={
                        self.act_t: expert_acts,
                        self.obs_t: expert_obs,
                        self.nobs_t: expert_obs_next,
                        self.nact_t: expert_acts_next,
                        self.lprobs: np.expand_dims(expert_probs, axis=1)
                    })
            else:
                energy, logZ, dtau = sess.run(
                    [
                        self.t0_reward, self.t0_value,
                        self.ensemble_discrim_output
                    ],
                    feed_dict={
                        self.act_t: expert_acts,
                        self.obs_t: expert_obs,
                        self.nobs_t: expert_obs_next,
                        self.nact_t: expert_acts_next,
                        self.lprobs: np.expand_dims(expert_probs, axis=1)
                    })

            energy = -energy
            logger.record_tabular('GCLAverageExpertEnergy', np.mean(energy))
            logger.record_tabular('GCLAverageExpertLogPtau',
                                  np.mean(-energy - logZ))
            logger.record_tabular('GCLAverageExpertLogQtau',
                                  np.mean(expert_probs))
            logger.record_tabular('GCLMedianExpertLogQtau',
                                  np.median(expert_probs))
            logger.record_tabular('GCLAverageExpertDtau', np.mean(dtau))

            if self.score_discrim is False and self.score_method == 'teacher_student':
                logger.record_tabular('GCLAverageStudentExpertEnergy',
                                      np.mean(-s_rew))
                logger.record_tabular('GCLAverageStudentExpertLogPtau',
                                      np.mean(s_rew - s_val))
                logger.record_tabular('GCLAverageStudentExpertDtau',
                                      np.mean(s_dtau))

        return mean_loss
Example #9
    def fit(self,
            paths,
            policy=None,
            batch_size=32,
            logger=None,
            lr=1e-3,
            last_timestep_only=False,
            max_itrs=100,
            **kwargs):

        if self.frozen:
            return 0

        if self.fusion is not None:
            old_paths = self.fusion.sample_paths(n=len(paths))
            self.fusion.add_paths(paths)
            paths = paths + old_paths
            # log fusion stats
            fstats = self.fusion.compute_age_stats()
            logger.record_tabular('FusionAgeMean', fstats['mean'])
            logger.record_tabular('FusionAgeMed', fstats['med'])
            logger.record_tabular('FusionAgeStd', fstats['std'])
            logger.record_tabular('FusionAgeMax', fstats['max'])
            logger.record_tabular('FusionAgeMin', fstats['min'])
            logger.record_tabular('FusionAgePFresh', fstats['pfresh'])

        self._compute_path_probs(paths, insert=True)

        # self.eval_expert_probs(paths, policy, insert=True)
        for traj in self.expert_trajs:
            if 'agent_infos' in traj:
                # print('deleting agent_infos')
                del traj['agent_infos']
            if 'a_logprobs' in traj:
                del traj['a_logprobs']
        self.eval_expert_probs(self.expert_trajs, policy, insert=True)

        self._insert_next_state(paths)
        self._insert_next_state(self.expert_trajs)

        obs, obs_next, acts, acts_next, path_probs = self.extract_paths2(
            paths,
            keys=('observations', 'observations_next', 'actions',
                  'actions_next', 'a_logprobs'),
            last_timestep_only=last_timestep_only)
        (expert_obs, expert_obs_next, expert_acts, expert_acts_next,
         expert_probs) = self.extract_paths2(
             self.expert_trajs,
             keys=('observations', 'observations_next', 'actions',
                   'actions_next', 'a_logprobs'),
             last_timestep_only=last_timestep_only)

        # Train discriminator
        for it in TrainingIterator(max_itrs, heartbeat=5):
            nobs_batch, obs_batch, nact_batch, act_batch, lprobs_batch = \
                self.sample_batch(obs_next, obs, acts_next, acts, path_probs,
                                  batch_size=batch_size)

            (nexpert_obs_batch, expert_obs_batch, nexpert_act_batch,
             expert_act_batch,
             expert_lprobs_batch) = self.sample_batch(expert_obs_next,
                                                      expert_obs,
                                                      expert_acts_next,
                                                      expert_acts,
                                                      expert_probs,
                                                      batch_size=batch_size)

            labels = np.zeros((batch_size * 2, 1))
            labels[batch_size:] = 1.0
            obs_batch = np.concatenate([obs_batch, expert_obs_batch], axis=0)
            nobs_batch = np.concatenate([nobs_batch, nexpert_obs_batch],
                                        axis=0)
            act_batch = np.concatenate([act_batch, expert_act_batch], axis=0)
            nact_batch = np.concatenate([nact_batch, nexpert_act_batch],
                                        axis=0)
            lprobs_batch = np.expand_dims(np.concatenate(
                [lprobs_batch, expert_lprobs_batch], axis=0),
                                          axis=1).astype(np.float32)

            learn_step_feed_dict = {
                self.act_t: act_batch,
                self.obs_t: obs_batch,
                self.nobs_t: nobs_batch,
                self.nact_t: nact_batch,
                self.labels: labels,
                self.lprobs: lprobs_batch,
                self.lr: lr,
                # we only enable noise during training
                self.is_train_t: True,
            }
            sess = tf.get_default_session()
            loss, tot_kl, _ = sess.run(
                [self.loss, self.tot_kl_loss, self.step],
                feed_dict=learn_step_feed_dict)
            if self.vairl and self.vairl_adaptive_beta:
                beta, _ = sess.run(
                    [self.vairl_beta, self.vairl_beta_update_op],
                    feed_dict={self.vairl_mean_kl: tot_kl})

            it.record('loss', loss)
            it.record('tot_kl', tot_kl)
            if self.vairl and self.vairl_adaptive_beta:
                it.record('beta', beta)
            if it.heartbeat:
                print(it.itr_message())
                mean_loss = it.pop_mean('loss')
                print('\tLoss:%f' % mean_loss)
                mean_tot_kl = it.pop_mean('tot_kl')
                print('\tKL:%f' % mean_tot_kl)
                if self.vairl and self.vairl_adaptive_beta:
                    mean_beta = it.pop_mean('beta')
                    print('\tBeta:%f' % mean_beta)

        if logger:
            logger.record_tabular('GCLDiscrimLoss', mean_loss)
            # the 'DiscrimVAIRLKL' one is just retained so I don't break my
            # parsing scripts :)
            logger.record_tabular('GCLDiscrimVAIRLKL', mean_tot_kl)
            logger.record_tabular('GCLVAIRLKL', mean_tot_kl)
            if self.vairl and self.vairl_adaptive_beta:
                logger.record_tabular('GCLVAIRLBeta', mean_beta)
            # obs_next = np.r_[obs_next, np.expand_dims(obs_next[-1], axis=0)]
            # logZ,
            for is_train in [True, False]:
                # make sure to keep stats about test-mode configuration as well
                # as train-mode configuration, in case we have something like
                # dropout or VDB noise that affects discriminator results
                prefix = '' if is_train else 'NotIsTrain'
                fake_in_dict = {
                    'energy': self.energy,
                    'logZ': self.value_fn,
                    'dtau_fake': self.d_tau
                }
                real_in_dict = {
                    'energy': self.energy,
                    'logZ': self.value_fn,
                    'dtau_real': self.d_tau
                }
                if self.gp_value is not None:
                    fake_in_dict['gp_value'] = real_in_dict[
                        'gp_value'] = self.gp_value
                fake_out_dict = tf.get_default_session().run(
                    fake_in_dict,
                    feed_dict={
                        self.act_t: acts,
                        self.obs_t: obs,
                        self.nobs_t: obs_next,
                        self.nact_t: acts_next,
                        self.lprobs: np.expand_dims(path_probs, axis=1),
                        self.is_train_t: is_train,
                        self.labels: np.zeros((len(acts), 1)),
                    })
                energy = fake_out_dict['energy']
                logZ = fake_out_dict['logZ']
                dtau_fake = fake_out_dict['dtau_fake']
                logger.record_tabular(prefix + 'GCLLogZ', np.mean(logZ))
                logger.record_tabular(prefix + 'GCLAverageEnergy',
                                      np.mean(energy))
                logger.record_tabular(prefix + 'GCLAverageLogPtau',
                                      np.mean(-energy - logZ))
                logger.record_tabular(prefix + 'GCLAverageLogQtau',
                                      np.mean(path_probs))
                logger.record_tabular(prefix + 'GCLMedianLogQtau',
                                      np.median(path_probs))
                logger.record_tabular(prefix + 'GCLAverageDtau',
                                      np.mean(dtau_fake))

                # expert_obs_next = np.r_[expert_obs_next,
                # np.expand_dims(expert_obs_next[-1], axis=0)]
                real_out_dict = tf.get_default_session().run(
                    real_in_dict,
                    feed_dict={
                        self.act_t: expert_acts,
                        self.obs_t: expert_obs,
                        self.nobs_t: expert_obs_next,
                        self.nact_t: expert_acts_next,
                        self.lprobs: np.expand_dims(expert_probs, axis=1),
                        self.is_train_t: is_train,
                        self.labels: np.ones((len(expert_acts), 1)),
                    })
                energy = real_out_dict['energy']
                logZ = real_out_dict['logZ']
                dtau_real = real_out_dict['dtau_real']
                logger.record_tabular(prefix + 'GCLAverageExpertEnergy',
                                      np.mean(energy))
                logger.record_tabular(prefix + 'GCLAverageExpertLogPtau',
                                      np.mean(-energy - logZ))
                logger.record_tabular(prefix + 'GCLAverageExpertLogQtau',
                                      np.mean(expert_probs))
                logger.record_tabular(prefix + 'GCLMedianExpertLogQtau',
                                      np.median(expert_probs))
                logger.record_tabular(prefix + 'GCLAverageExpertDtau',
                                      np.mean(dtau_real))

                # 1 real, 0 fake
                disc_true_nfake = len(dtau_fake)
                disc_true_nreal = len(dtau_real)
                disc_true_pos = np.sum(dtau_real >= 0.5)
                disc_false_neg = disc_true_nreal - disc_true_pos
                assert disc_false_neg == np.sum(dtau_real < 0.5)
                disc_true_neg = np.sum(dtau_fake < 0.5)
                disc_false_pos = disc_true_nfake - disc_true_neg
                assert disc_false_pos == np.sum(dtau_fake >= 0.5)
                disc_total = disc_true_nfake + disc_true_nreal
                assert 0 <= disc_true_pos and 0 <= disc_false_neg \
                    and 0 <= disc_true_neg and 0 <= disc_false_pos
                assert disc_true_pos + disc_false_neg + disc_true_neg \
                    + disc_false_pos == disc_total
                # acc = (tp+tn)/(tp+fp+tn+fn)
                disc_acc = (disc_true_pos + disc_true_neg) / disc_total
                # precision = |relevant&retrieved|/|retrieved| = tp/(tp+fp)
                disc_prec = disc_true_pos / (disc_true_pos + disc_false_pos)
                # recall = |relevant&retrieved|/|relevant| = tp/(tp+fn)
                disc_recall = disc_true_pos / (disc_true_pos + disc_false_neg)
                # tpr = tp/(tp+fn) = recall
                disc_tpr = disc_true_pos / (disc_true_pos + disc_false_neg)
                assert disc_tpr == disc_recall
                # tnr = tn/(tn+fp) (specificity)
                disc_tnr = disc_true_neg / (disc_true_neg + disc_false_pos)
                assert 0 <= disc_prec <= 1 and 0 <= disc_recall <= 1 and \
                    0 <= disc_acc <= 1 and 0 <= disc_tpr <= 1 and \
                    0 <= disc_tnr <= 1
                disc_f1 \
                    = 2 * disc_prec * disc_recall / (disc_prec + disc_recall)
                assert 0 <= disc_f1 <= 1
                logger.record_tabular(prefix + 'GCLDiscAcc', disc_acc)
                logger.record_tabular(prefix + 'GCLDiscF1', disc_f1)
                # TPR is accuracy when predicting reals
                logger.record_tabular(prefix + 'GCLDiscTPR', disc_tpr)
                # TNR is accuracy when predicting fakes
                logger.record_tabular(prefix + 'GCLDiscTNR', disc_tnr)
                logger.record_tabular(prefix + 'GCLDiscNFake', disc_true_nfake)
                logger.record_tabular(prefix + 'GCLDiscNReal', disc_true_nreal)

                if self.gp_value is not None:
                    gp_value = 0.5 * (real_out_dict['gp_value'] +
                                      fake_out_dict['gp_value'])
                    logger.record_tabular('GCLDiscGradPenaltyUnscaled',
                                          gp_value)

        return mean_loss
Example #10
    def fit(self,
            paths,
            policy=None,
            batch_size=256,
            logger=None,
            lr=1e-3,
            itr=0,
            **kwargs):
        if isinstance(self.expert_trajs[0], dict):
            print("Warning: Processing state out of dictionary")
            self._insert_next_state(self.expert_trajs)
            expert_obs_base, expert_obs_next_base, expert_acts, expert_acts_next = \
                self.extract_paths(self.expert_trajs, keys=(
                    'observations', 'observations_next',
                    'actions', 'actions_next'
                ))
        else:
            expert_obs_base, expert_obs_next_base, expert_acts, expert_acts_next, _ = \
                self.expert_trajs

        #expert_probs = paths.sampler.get_a_logprobs(
        obs, obs_next, acts, acts_next, path_probs = paths.extract_paths(
            ('observations', 'observations_next', 'actions', 'actions_next',
             'a_logprobs'),
            obs_modifier=self.modify_obs)

        expert_obs = expert_obs_base
        expert_obs_next = expert_obs_next_base

        raw_discrim_scores = []
        # Train discriminator
        for it in TrainingIterator(self.max_itrs, heartbeat=5):
            nobs_batch, obs_batch, nact_batch, act_batch, lprobs_batch = \
                self.sample_batch(obs_next, obs, acts_next, acts, path_probs, batch_size=batch_size)

            nexpert_obs_batch, expert_obs_batch, nexpert_act_batch, expert_act_batch = \
                self.sample_batch(
                    expert_obs_next,
                    expert_obs,
                    expert_acts_next,
                    expert_acts,
                    # expert_probs,
                    batch_size=batch_size
                )
            expert_lprobs_batch = paths.sampler.get_a_logprobs(
                expert_obs_batch, expert_act_batch)

            expert_obs_batch = self.modify_obs(expert_obs_batch)
            nexpert_obs_batch = self.modify_obs(nexpert_obs_batch)
            if self.encoder:
                expert_obs_batch = self.encode_fn(
                    expert_obs_batch, expert_act_batch.argmax(axis=1))
                nexpert_obs_batch = self.encode_fn(
                    nexpert_obs_batch, nexpert_act_batch.argmax(axis=1))

            # Build feed dict
            labels = np.zeros((batch_size * 2, 1))
            labels[batch_size:] = 1.0
            obs_batch = np.concatenate([obs_batch, expert_obs_batch], axis=0)
            nobs_batch = np.concatenate([nobs_batch, nexpert_obs_batch],
                                        axis=0)
            act_batch = np.concatenate([act_batch, expert_act_batch], axis=0)
            nact_batch = np.concatenate([nact_batch, nexpert_act_batch],
                                        axis=0)
            lprobs_batch = np.expand_dims(np.concatenate(
                [lprobs_batch, expert_lprobs_batch], axis=0),
                                          axis=1).astype(np.float32)

            feed_dict = {
                self.act_t: act_batch,
                self.obs_t: obs_batch,
                self.nobs_t: nobs_batch,
                self.nact_t: nact_batch,
                self.labels: labels,
                self.lprobs: lprobs_batch,
                self.lr: lr
            }

            loss, _, acc, scores = tf.get_default_session().run(
                [
                    self.loss, self.step, self.update_accuracy,
                    self.discrim_output
                ],
                feed_dict=feed_dict)
            # we only want the average score for the non-expert demos
            non_expert_slice = slice(0, batch_size)
            score, raw_score = self._process_discrim_output(
                scores[non_expert_slice])
            assert len(score) == batch_size
            assert np.sum(labels[non_expert_slice]) == 0
            raw_discrim_scores.append(raw_score)

            it.record('loss', loss)
            it.record('accuracy', acc)
            it.record('avg_score', np.mean(score))
            if it.heartbeat:
                print(it.itr_message())
                mean_loss = it.pop_mean('loss')
                print('\tLoss:%f' % mean_loss)
                mean_acc = it.pop_mean('accuracy')
                print('\tAccuracy:%f' % mean_acc)
                mean_score = it.pop_mean('avg_score')

        if logger:
            logger.record_tabular('GCLDiscrimLoss', mean_loss)
            logger.record_tabular('GCLDiscrimAccuracy', mean_acc)
            logger.record_tabular('GCLMeanScore', mean_score)

        # set the center for our normal distribution
        scores = np.hstack(raw_discrim_scores)
        self.score_std = np.std(scores)
        self.score_mean = np.mean(scores)

        return mean_loss