def fit(self, paths, irl_model=None, tempw=None, policy=None, qvar_model=None,
        batch_size=32, logger=None, lr=1e-3, **kwargs):
    if self.fusion is not None:
        old_paths = self.fusion.sample_paths(n=len(paths))
        self.fusion.add_paths(paths)
        paths = paths + old_paths
    # self._insert_next_state(paths)
    obs, ac, obs_next = self.extract_paths(
        paths, keys=('observations', 'actions', 'observations_next'))

    for it in TrainingIterator(self.max_itrs, heartbeat=1):
        nobs_batch, obs_batch, ac_batch = self.sample_batch(obs_next, obs, ac)

        # log-likelihood of actions sampled from the current policy pi(a|s)
        dist_info_vars = policy.dist_info_sym(obs_batch, None)
        dist_s = policy.distribution.log_likelihood(
            policy.distribution.sample(dist_info_vars), dist_info_vars)

        # log-likelihood of actions sampled from the inverse model q(a|s, s');
        # the qvar model shares the policy's distribution class, so the
        # policy's log_likelihood applies to its dist_info as well
        q_input = tf.concat([obs_batch, nobs_batch], axis=1)
        q_dist_info_vars = qvar_model.dist_info_sym(q_input, None)
        q_dist_s = policy.distribution.log_likelihood(
            policy.distribution.sample(q_dist_info_vars), q_dist_info_vars)

        # Build feed dict (the .eval() calls run the sampling ops in the
        # default session)
        feed_dict = {
            self.obs_t: obs_batch,
            self.act_qvar: q_dist_s.eval(),
            self.act_policy: dist_s.eval(),
            self.lr: lr
        }

        loss_emp, _ = tf.get_default_session().run(
            [self.loss_emp, self.step_emp], feed_dict=feed_dict)
        it.record('loss_emp', loss_emp)
        if it.heartbeat:
            print(it.itr_message())
            mean_loss_emp = it.pop_mean('loss_emp')
            print('\tLoss_emp:%f' % mean_loss_emp)
    return mean_loss_emp
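# All of the fit() loops in this file assume a small `TrainingIterator`
# utility exposing `itr`, `heartbeat`, `record`, `pop_mean`, and
# `itr_message`. The sketch below is a minimal stand-in inferred from those
# call sites, not the codebase's actual implementation; the real class may
# differ in details such as exactly when the heartbeat fires.
import time


class TrainingIterator(object):
    def __init__(self, itrs, heartbeat=float('inf')):
        self.itrs = itrs
        self.heartbeat_time = heartbeat
        self.itr = 0
        self._vals = {}
        self._heartbeat = False

    @property
    def heartbeat(self):
        return self._heartbeat

    def itr_message(self):
        return '==> Itr %d/%d' % (self.itr + 1, self.itrs)

    def record(self, key, value):
        self._vals.setdefault(key, []).append(value)

    def pop_mean(self, key):
        vals = self._vals.pop(key, [])
        return sum(vals) / max(len(vals), 1)

    def __iter__(self):
        last_beat = time.time()
        for i in range(self.itrs):
            self.itr = i
            # Fire on a timer, and always on the final iteration, so that
            # values popped inside `if it.heartbeat:` (e.g. mean_loss) are
            # defined before the post-loop logging/return.
            now = time.time()
            self._heartbeat = (now - last_beat > self.heartbeat_time
                               or i == self.itrs - 1)
            if self._heartbeat:
                last_beat = now
            yield self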
def tabular_maxent_irl(env, demo_visitations, num_itrs=50, ent_wt=1.0, lr=1e-3,
                       state_only=False, discount=0.99, T=5):
    dim_obs = env.observation_space.flat_dim
    dim_act = env.action_space.flat_dim

    # Initialize policy and reward function
    reward_fn = np.zeros((dim_obs, dim_act))
    q_rew = np.zeros((dim_obs, dim_act))

    update = adam_optimizer(lr)

    for it in TrainingIterator(num_itrs, heartbeat=1.0):
        q_itrs = 20 if it.itr > 5 else 100
        ### compute policy in closed form
        q_rew = q_iteration(env, reward_matrix=reward_fn, ent_wt=ent_wt,
                            warmstart_q=q_rew, K=q_itrs, gamma=discount)

        ### update reward
        # need to count how often the policy will visit a particular (s, a) pair
        pol_visitations = compute_visitation(env, q_rew, ent_wt=ent_wt, T=T,
                                             discount=discount)

        grad = -(demo_visitations - pol_visitations)
        it.record('VisitationInfNorm', np.max(np.abs(grad)))
        if state_only:
            grad = np.sum(grad, axis=1, keepdims=True)
        reward_fn = update(reward_fn, grad)

        if it.heartbeat:
            print(it.itr_message())
            print('\t', it.pop_mean('VisitationInfNorm'))
    return reward_fn, q_rew
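# `adam_optimizer(lr)` above is assumed to return a stateful
# `update(params, grad)` closure that applies one Adam *descent* step and
# returns the new parameters. A minimal sketch under that assumption; the
# codebase's actual helper may differ in defaults or sign convention.
import numpy as np


def adam_optimizer(lr, beta1=0.9, beta2=0.999, eps=1e-8):
    state = {'m': 0.0, 'v': 0.0, 't': 0}

    def update(params, grad):
        state['t'] += 1
        state['m'] = beta1 * state['m'] + (1 - beta1) * grad
        state['v'] = beta2 * state['v'] + (1 - beta2) * grad ** 2
        m_hat = state['m'] / (1 - beta1 ** state['t'])
        v_hat = state['v'] / (1 - beta2 ** state['t'])
        # descent step; in tabular_maxent_irl, grad = -(demo - pol), so this
        # raises the reward where demonstrations visit more than the policy
        return params - lr * m_hat / (np.sqrt(v_hat) + eps)

    return update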
def fit(self, paths, batch_size=32, logger=None, lr=1e-3, **kwargs):
    if self.fusion is not None:
        old_paths = self.fusion.sample_paths(n=len(paths))
        self.fusion.add_paths(paths)
        paths = paths + old_paths
    obs, obs_next, acts = \
        self.extract_paths(paths,
                           keys=('observations', 'observations_next', 'actions'))
    # expert_obs, expert_obs_next, expert_acts = \
    #     self.extract_paths(self.expert_trajs,
    #                        keys=('observations', 'observations_next', 'actions'))

    # Train discriminator
    for it in TrainingIterator(self.max_itrs, heartbeat=5):
        nobs_batch, obs_batch, act_batch = \
            self.sample_batch(obs_next, obs, acts, batch_size=batch_size)

        feed_dict = {
            self.act_t: act_batch,
            self.obs_t: obs_batch,
            self.nobs_t: nobs_batch,
            self.lr: lr
        }

        loss_q, _ = tf.get_default_session().run([self.loss_q, self.step_q],
                                                 feed_dict=feed_dict)
        it.record('loss_q', loss_q)
        if it.heartbeat:
            mean_loss_q = it.pop_mean('loss_q')
            print('\tLoss_q:%f' % mean_loss_q)
    return mean_loss_q
def fit(self, paths, policy=None, batch_size=32, logger=None, lr=1e-3, **kwargs):
    if self.fusion is not None:
        old_paths = self.fusion.sample_paths(n=len(paths))
        self.fusion.add_paths(paths)
        paths = paths + old_paths

    # eval samples under current policy
    self._compute_path_probs(paths, insert=True)

    # eval expert log probs under current policy
    self.eval_expert_probs(self.expert_trajs, policy, insert=True)

    self._insert_next_state(paths)
    self._insert_next_state(self.expert_trajs)
    obs, obs_next, acts, acts_next, path_probs = \
        self.extract_paths(paths,
                           keys=('observations', 'observations_next', 'actions',
                                 'actions_next', 'a_logprobs'))
    expert_obs, expert_obs_next, expert_acts, expert_acts_next, expert_probs = \
        self.extract_paths(self.expert_trajs,
                           keys=('observations', 'observations_next', 'actions',
                                 'actions_next', 'a_logprobs'))

    # Train discriminator
    for it in TrainingIterator(self.max_itrs, heartbeat=5):
        nobs_batch, obs_batch, nact_batch, act_batch, lprobs_batch = \
            self.sample_batch(obs_next, obs, acts_next, acts, path_probs,
                              batch_size=batch_size)

        nexpert_obs_batch, expert_obs_batch, nexpert_act_batch, expert_act_batch, expert_lprobs_batch = \
            self.sample_batch(expert_obs_next, expert_obs, expert_acts_next,
                              expert_acts, expert_probs, batch_size=batch_size)

        # Build feed dict
        labels = np.zeros((batch_size * 2, 1))
        labels[batch_size:] = 1.0
        obs_batch = np.concatenate([obs_batch, expert_obs_batch], axis=0)
        nobs_batch = np.concatenate([nobs_batch, nexpert_obs_batch], axis=0)
        act_batch = np.concatenate([act_batch, expert_act_batch], axis=0)
        nact_batch = np.concatenate([nact_batch, nexpert_act_batch], axis=0)
        lprobs_batch = np.expand_dims(
            np.concatenate([lprobs_batch, expert_lprobs_batch], axis=0),
            axis=1).astype(np.float32)

        feed_dict = {
            self.act_t: act_batch,
            self.obs_t: obs_batch,
            self.nobs_t: nobs_batch,
            self.nact_t: nact_batch,
            self.labels: labels,
            self.lprobs: lprobs_batch,
            self.lr: lr
        }

        loss, _ = tf.get_default_session().run([self.loss, self.step],
                                               feed_dict=feed_dict)
        it.record('loss', loss)
        if it.heartbeat:
            print(it.itr_message())
            mean_loss = it.pop_mean('loss')
            print('\tLoss:%f' % mean_loss)

    if logger:
        logger.record_tabular('GCLDiscrimLoss', mean_loss)
        # obs_next = np.r_[obs_next, np.expand_dims(obs_next[-1], axis=0)]
        energy, logZ, dtau = tf.get_default_session().run(
            [self.reward, self.value_fn, self.discrim_output],
            feed_dict={self.act_t: acts,
                       self.obs_t: obs,
                       self.nobs_t: obs_next,
                       self.nact_t: acts_next,
                       self.lprobs: np.expand_dims(path_probs, axis=1)})
        energy = -energy
        logger.record_tabular('GCLLogZ', np.mean(logZ))
        logger.record_tabular('GCLAverageEnergy', np.mean(energy))
        logger.record_tabular('GCLAverageLogPtau', np.mean(-energy - logZ))
        logger.record_tabular('GCLAverageLogQtau', np.mean(path_probs))
        logger.record_tabular('GCLMedianLogQtau', np.median(path_probs))
        logger.record_tabular('GCLAverageDtau', np.mean(dtau))

        # expert_obs_next = np.r_[expert_obs_next,
        #                         np.expand_dims(expert_obs_next[-1], axis=0)]
        energy, logZ, dtau = tf.get_default_session().run(
            [self.reward, self.value_fn, self.discrim_output],
            feed_dict={self.act_t: expert_acts,
                       self.obs_t: expert_obs,
                       self.nobs_t: expert_obs_next,
                       self.nact_t: expert_acts_next,
                       self.lprobs: np.expand_dims(expert_probs, axis=1)})
        energy = -energy
        logger.record_tabular('GCLAverageExpertEnergy', np.mean(energy))
        logger.record_tabular('GCLAverageExpertLogPtau', np.mean(-energy - logZ))
        logger.record_tabular('GCLAverageExpertLogQtau', np.mean(expert_probs))
        logger.record_tabular('GCLMedianExpertLogQtau', np.median(expert_probs))
        logger.record_tabular('GCLAverageExpertDtau', np.mean(dtau))

    return mean_loss
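# For reference on the quantities logged above: in the AIRL construction the
# discriminator is D(s,a,s') = exp(f) / (exp(f) + pi(a|s)), where f (the
# "log p_tau" built from `self.reward` and `self.value_fn`) competes against
# the policy likelihood fed via `self.lprobs`, which is why the logs report
# log p_tau = -energy - logZ. A standalone sketch of that head (not this
# class's exact graph):
import tensorflow as tf


def airl_discriminator_head(log_p_tau, log_q_tau):
    # log_p_tau: f(s, a, s') from the reward/value networks
    # log_q_tau: log pi(a|s) of the sampling policy
    log_pq = tf.reduce_logsumexp([log_p_tau, log_q_tau], axis=0)
    return tf.exp(log_p_tau - log_pq)  # numerically stable exp(f)/(exp(f)+pi)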
def fit(self, paths, expert_traj_batch=None, policy=None, batch_size=32,
        logger=None, lr=1e-3, **kwargs):
    meta_batch_size = self.meta_batch_size
    if self.fusion is not None:
        old_paths = self.fusion.sample_paths(expert_traj_batch, n=len(paths[0]))
        self.fusion.add_paths(paths, expert_traj_batch, subsample=True)
        if old_paths is not None:
            for key in paths.keys():
                paths[key] += old_paths[key]

    # Do we need to recalculate path probabilities every iteration since the
    # context encoder is being updated?
    # eval samples under current policy
    # TODO: fix this with dict
    self._compute_path_probs_dict(paths, insert=True)

    self._insert_next_state(paths)
    self._insert_next_state(self.expert_trajs)
    obs, obs_next, acts, acts_next, path_probs = \
        self.extract_paths(paths,
                           keys=('observations', 'observations_next', 'actions',
                                 'actions_next', 'a_logprobs'),
                           T=self.T)
    # TODO: we may need to assume that expert_trajs is also a dict
    expert_obs, expert_obs_next, expert_acts, expert_acts_next, expert_contexts = \
        self.extract_paths(self.expert_trajs,
                           keys=('observations', 'observations_next', 'actions',
                                 'actions_next', 'contexts'),
                           T=self.T)

    # eval expert log probs under current policy
    expert_trajs = np.concatenate([expert_obs, expert_acts], axis=-1)
    m_hat_expert = self.context_encoder.get_actions(
        expert_trajs.reshape(-1, self.T * (self.dO + self.dU)))[0]
    self.eval_expert_probs(self.expert_trajs, policy, insert=True,
                           context=m_hat_expert)
    expert_probs = self.extract_paths(self.expert_trajs,
                                      keys=('a_logprobs',), T=self.T)[0]

    # Train discriminator
    expert_traj_batch_tile = np.tile(
        expert_traj_batch.reshape(meta_batch_size, 1, self.T, -1),
        [1, batch_size, 1, 1])
    for it in TrainingIterator(self.max_itrs, heartbeat=5):
        # TODO: implement sample_batch in imitation_learning.py
        nobs_batch, obs_batch, nact_batch, act_batch, lprobs_batch = \
            self.sample_batch(obs_next, obs, acts_next, acts, path_probs,
                              batch_size=batch_size)
        if obs_batch.shape[-1] == self.dO + self.latent_dim:
            nobs_batch = nobs_batch[..., :-self.latent_dim]
            obs_batch = obs_batch[..., :-self.latent_dim]

        # First half of the batch is used for inferring m_hat
        nexpert_obs_batch, expert_obs_batch, nexpert_act_batch, expert_act_batch, expert_lprobs_batch = \
            self.sample_batch(expert_obs_next, expert_obs, expert_acts_next,
                              expert_acts, expert_probs,
                              batch_size=self.meta_batch_size * batch_size)
        if expert_obs_batch.shape[-1] == self.dO + self.latent_dim:
            nexpert_obs_batch = nexpert_obs_batch[..., :-self.latent_dim]
            expert_obs_batch = expert_obs_batch[..., :-self.latent_dim]

        # Build feed dict
        labels = np.zeros((meta_batch_size, batch_size * 2, 1, 1))
        labels[:, batch_size:, ...] = 1.0
        imitation_expert_obses_input = expert_traj_batch.reshape(
            meta_batch_size, 1, self.T, -1)[:, :, :, :self.dO]
        imitation_expert_acts_input = expert_traj_batch.reshape(
            meta_batch_size, 1, self.T, -1)[:, :, :, self.dO:]
        expert_traj_batch_input = np.concatenate([
            expert_traj_batch_tile,
            np.concatenate([expert_obs_batch, expert_act_batch],
                           axis=-1).reshape(meta_batch_size, batch_size,
                                            self.T, -1)
        ], axis=1)
        sample_traj_batch = np.concatenate([obs_batch, act_batch], axis=-1)
        obs_batch = np.concatenate([
            obs_batch,
            expert_obs_batch.reshape(meta_batch_size, batch_size, self.T, -1)
        ], axis=1)
        nobs_batch = np.concatenate([
            nobs_batch,
            nexpert_obs_batch.reshape(meta_batch_size, batch_size, self.T, -1)
        ], axis=1)
        act_batch = np.concatenate([
            act_batch,
            expert_act_batch.reshape(meta_batch_size, batch_size, self.T, -1)
        ], axis=1)
        nact_batch = np.concatenate([
            nact_batch,
            nexpert_act_batch.reshape(meta_batch_size, batch_size, self.T, -1)
        ], axis=1)
        lprobs_batch = np.concatenate([
            lprobs_batch,
            expert_lprobs_batch.reshape(meta_batch_size, batch_size, self.T, -1)
        ], axis=1).astype(np.float32)

        feed_dict = {
            self.expert_traj_var: expert_traj_batch_input,
            self.sample_traj_var: sample_traj_batch,
            self.act_t: act_batch,
            self.obs_t: obs_batch,
            self.nobs_t: nobs_batch,
            self.nact_t: nact_batch,
            self.labels: labels,
            self.lprobs: lprobs_batch,
            self.imitation_expert_obses: imitation_expert_obses_input,
            self.imitation_expert_acts: imitation_expert_acts_input,
            self.lr: lr
        }

        loss, _ = tf.get_default_session().run([self.loss, self.step],
                                               feed_dict=feed_dict)
        it.record('loss', loss)
        if it.heartbeat:
            print(it.itr_message())
            mean_loss = it.pop_mean('loss')
            print('\tLoss:%f' % mean_loss)

    if logger:
        logger.record_tabular('GCLDiscrimLoss', mean_loss)
        # obs_next = np.r_[obs_next, np.expand_dims(obs_next[-1], axis=0)]
        # TODO: fix this
        expert_traj_logging = np.tile(
            expert_traj_batch.reshape(meta_batch_size, 1, self.T, -1),
            [1, acts.shape[1], 1, 1])
        imitation_expert_obses_input = expert_traj_batch.reshape(
            meta_batch_size, 1, self.T, -1)[:, :, :, :self.dO]
        imitation_expert_acts_input = expert_traj_batch.reshape(
            meta_batch_size, 1, self.T, -1)[:, :, :, self.dO:]
        energy, logZ, dtau, info_loss, imit_loss = tf.get_default_session().run(
            [self.reward, self.value_fn, self.discrim_output, self.info_loss,
             self.policy_likelihood_loss],
            feed_dict={
                self.expert_traj_var: expert_traj_logging,
                self.sample_traj_var: np.concatenate(
                    [obs[..., :-self.latent_dim], acts], axis=-1),
                self.act_t: acts,
                self.obs_t: obs[..., :-self.latent_dim],
                self.nobs_t: obs_next[..., :-self.latent_dim],
                self.nact_t: acts_next,
                self.imitation_expert_obses: imitation_expert_obses_input,
                self.imitation_expert_acts: imitation_expert_acts_input,
                self.labels: np.zeros([meta_batch_size, acts.shape[1], 1, 1]),
                self.lprobs: path_probs
            })
        energy = -energy
        logger.record_tabular('GCLLogZ', np.mean(logZ))
        logger.record_tabular('GCLAverageEnergy', np.mean(energy))
        logger.record_tabular('GCLAverageLogPtau', np.mean(-energy - logZ))
        logger.record_tabular('GCLAverageLogQtau', np.mean(path_probs))
        logger.record_tabular('GCLMedianLogQtau', np.median(path_probs))
        logger.record_tabular('GCLAverageDtau', np.mean(dtau))
        logger.record_tabular('GCLAverageMutualInfo', np.mean(info_loss))
        logger.record_tabular('GCLAverageImitationLoss', np.mean(imit_loss))

        # expert_obs_next = np.r_[expert_obs_next,
        #                         np.expand_dims(expert_obs_next[-1], axis=0)]
        # Not sure if using expert trajectories for expert_traj_var and
        # sample_traj_var makes sense
        # energy, logZ, dtau, info_loss = tf.get_default_session().run(
        #     [self.reward, self.value_fn, self.discrim_output, self.info_loss],
        #     feed_dict={self.expert_traj_var: np.concatenate([expert_obs, expert_acts], axis=-1),
        #                self.sample_traj_var: np.concatenate([expert_obs, expert_acts], axis=-1),
        #                self.act_t: expert_acts, self.obs_t: expert_obs,
        #                self.nobs_t: expert_obs_next,
        #                self.nact_t: expert_acts_next,
        #                self.labels: np.zeros([meta_batch_size, acts.shape[1], 1, 1]),
        #                self.lprobs: expert_probs})
        # energy = -energy
        # logger.record_tabular('GCLAverageExpertEnergy', np.mean(energy))
        # logger.record_tabular('GCLAverageExpertLogPtau', np.mean(-energy - logZ))
        # logger.record_tabular('GCLAverageExpertLogQtau', np.mean(expert_probs))
        # logger.record_tabular('GCLMedianExpertLogQtau', np.median(expert_probs))
        # logger.record_tabular('GCLAverageExpertDtau', np.mean(dtau))
        # logger.record_tabular('GCLAverageExpertMutualInfo', np.mean(info_loss))

    return mean_loss
def tabular_gcl_irl(env, demo_visitations, irl_model, num_itrs=50, ent_wt=1.0,
                    lr=1e-3, state_only=False, discount=0.99, batch_size=20024):
    dim_obs = env.observation_space.flat_dim
    dim_act = env.action_space.flat_dim

    # enumerate every (s, a) pair once, as one-hot vectors, for reward evaluation
    states_all = []
    actions_all = []
    for s in range(dim_obs):
        for a in range(dim_act):
            states_all.append(flat_to_one_hot(s, dim_obs))
            actions_all.append(flat_to_one_hot(a, dim_act))
    states_all = np.array(states_all)
    actions_all = np.array(actions_all)
    path_all = {'observations': states_all, 'actions': actions_all}

    # Initialize policy and reward function
    reward_fn = np.zeros((dim_obs, dim_act))
    q_rew = np.zeros((dim_obs, dim_act))

    update = adam_optimizer(lr)

    for it in TrainingIterator(num_itrs, heartbeat=1.0):
        q_itrs = 20 if it.itr > 5 else 100
        ### compute policy in closed form
        q_rew = q_iteration(env, reward_matrix=reward_fn, ent_wt=ent_wt,
                            warmstart_q=q_rew, K=q_itrs, gamma=discount)
        pol_rew = get_policy(q_rew, ent_wt=ent_wt)

        ### update reward
        # need to count how often the policy will visit a particular (s, a) pair
        pol_visitations = compute_visitation(env, q_rew, ent_wt=ent_wt, T=5,
                                             discount=discount)

        # now we need to sample states and actions, and give them to the
        # discriminator
        demo_path = sample_states(env, q_rew, demo_visitations, batch_size,
                                  ent_wt)
        irl_model.set_demos([demo_path])
        path = sample_states(env, q_rew, pol_visitations, batch_size, ent_wt)
        irl_model.fit([path], policy=pol_rew, max_itrs=200, lr=1e-3,
                      batch_size=1024)
        rew_stack = irl_model.eval([path_all])[0]
        reward_fn = np.zeros_like(q_rew)
        i = 0
        for s in range(dim_obs):
            for a in range(dim_act):
                reward_fn[s, a] = rew_stack[i]
                i += 1

        diff_visit = np.abs(demo_visitations - pol_visitations)
        it.record('VisitationDiffInfNorm', np.max(diff_visit))
        it.record('VisitationDiffAvg', np.mean(diff_visit))

        if it.heartbeat:
            print(it.itr_message())
            print('\tVisitationDiffInfNorm:',
                  it.pop_mean('VisitationDiffInfNorm'))
            print('\tVisitationDiffAvg:', it.pop_mean('VisitationDiffAvg'))
            print('visitations', pol_visitations)
            print('diff_visit', diff_visit)
            adjusted_rew = reward_fn - np.mean(reward_fn) + np.mean(
                env.rew_matrix)
            print('adjusted_rew', adjusted_rew)
    return reward_fn, q_rew
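# `flat_to_one_hot` above is the usual one-hot encoder over a flat index;
# a minimal sketch consistent with its usage:
import numpy as np


def flat_to_one_hot(idx, dim):
    v = np.zeros(dim)
    v[idx] = 1.0
    return v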
def fit(self, paths, policy=None, empw_model=None, t_empw_model=None,
        batch_size=32, logger=None, lr=1e-3, **kwargs):
    if self.fusion is not None:
        old_paths = self.fusion.sample_paths(n=len(paths))
        self.fusion.add_paths(paths)
        paths = paths + old_paths

    # eval samples under current policy
    self._compute_path_probs(paths, insert=True)

    # eval expert log probs under current policy
    self.eval_expert_probs(self.expert_trajs, policy, insert=True)

    self._insert_next_state(paths)
    self._insert_next_state(self.expert_trajs)
    obs, obs_next, acts, acts_next, path_probs = \
        self.extract_paths(paths,
                           keys=('observations', 'observations_next', 'actions',
                                 'actions_next', 'a_logprobs'))
    expert_obs, expert_obs_next, expert_acts, expert_acts_next, expert_probs = \
        self.extract_paths(self.expert_trajs,
                           keys=('observations', 'observations_next', 'actions',
                                 'actions_next', 'a_logprobs'))

    # Train discriminator
    for it in TrainingIterator(self.max_itrs, heartbeat=5):
        nobs_batch, obs_batch, nact_batch, act_batch, lprobs_batch = \
            self.sample_batch(obs_next, obs, acts_next, acts, path_probs,
                              batch_size=batch_size)

        nexpert_obs_batch, expert_obs_batch, nexpert_act_batch, expert_act_batch, expert_lprobs_batch = \
            self.sample_batch(expert_obs_next, expert_obs, expert_acts_next,
                              expert_acts, expert_probs, batch_size=batch_size)

        # Build feed dict
        labels = np.zeros((batch_size * 2, 1))
        labels[batch_size:] = 1.0
        obs_batch = np.concatenate([obs_batch, expert_obs_batch], axis=0)
        nobs_batch = np.concatenate([nobs_batch, nexpert_obs_batch], axis=0)
        act_batch = np.concatenate([act_batch, expert_act_batch], axis=0)
        nact_batch = np.concatenate([nact_batch, nexpert_act_batch], axis=0)
        lprobs_batch = np.expand_dims(
            np.concatenate([lprobs_batch, expert_lprobs_batch], axis=0),
            axis=1).astype(np.float32)
        vs = empw_model.eval(obs_batch)
        vsp = t_empw_model.eval(nobs_batch)
        feed_dict = {
            self.act_t: act_batch,
            self.obs_t: obs_batch,
            self.nobs_t: nobs_batch,
            self.nact_t: nact_batch,
            self.labels: labels,
            self.lprobs: lprobs_batch,
            self.lr: lr,
            self.vs: vs,
            self.vsp: vsp
        }

        loss_irl, _ = tf.get_default_session().run(
            [self.loss_irl, self.step_irl], feed_dict=feed_dict)
        it.record('loss_irl', loss_irl)
        if it.heartbeat:
            print(it.itr_message())
            mean_loss_irl = it.pop_mean('loss_irl')
            print('\tLoss_irl:%f' % mean_loss_irl)
    return mean_loss_irl
def fit(self, paths, policy=None, batch_size=32, logger=None, lr=1e-3, **kwargs):
    if self.fusion is not None:
        old_paths = self.fusion.sample_paths(n=len(paths))
        self.fusion.add_paths(paths)
        paths = paths + old_paths

    # eval samples under current policy
    self._compute_path_probs(paths, insert=True)

    # eval expert log probs under current policy
    self.eval_expert_probs(self.expert_trajs, policy, insert=True)

    self._reorganize_states(paths, number_obs=self.max_nstep)
    self._reorganize_states(self.expert_trajs, number_obs=self.max_nstep)
    obs, obs_next, acts, acts_next, path_probs = \
        self.extract_paths(paths,
                           keys=('multi_observations', 'observations_next',
                                 'multi_actions', 'actions_next', 'a_logprobs'))
    expert_obs, expert_obs_next, expert_acts, expert_acts_next, expert_probs = \
        self.extract_paths(self.expert_trajs,
                           keys=('multi_observations', 'observations_next',
                                 'multi_actions', 'actions_next', 'a_logprobs'))

    all_obs = np.concatenate([obs, expert_obs], axis=0)
    all_nobs = np.concatenate([obs_next, expert_obs_next], axis=0)
    all_acts = np.concatenate([acts, expert_acts], axis=0)
    all_nacts = np.concatenate([acts_next, expert_acts_next], axis=0)
    all_probs = np.concatenate([path_probs, expert_probs], axis=0)
    all_labels = np.zeros((all_obs.shape[0], 1))
    all_labels[obs.shape[0]:] = 1.0

    # Train discriminator
    for it in TrainingIterator(self.max_itrs, heartbeat=5):
        # pair up reward and value-function indices; if the ensemble sizes
        # differ, pad the smaller side with random resamples
        if self.n_rew_funct < self.n_value_funct:
            delta = self.n_value_funct - self.n_rew_funct
            temp = np.arange(self.n_rew_funct)
            np.random.shuffle(temp)
            rew_idxs = np.r_[temp, np.random.randint(self.n_rew_funct,
                                                     size=delta)]
            val_idxs = np.arange(self.n_value_funct)
        else:
            delta = self.n_rew_funct - self.n_value_funct
            temp = np.arange(self.n_value_funct)
            np.random.shuffle(temp)
            val_idxs = np.r_[temp, np.random.randint(self.n_value_funct,
                                                     size=delta)]
            rew_idxs = np.arange(self.n_rew_funct)

        for idx in range(val_idxs.shape[0]):
            i = rew_idxs[idx]
            j = val_idxs[idx]
            for single_nstep in range(1, self.max_nstep + 1):
                nobs_batch, obs_batch, nact_batch, act_batch, lprobs_batch = \
                    self.sample_batch(obs_next, obs, acts_next, acts,
                                      path_probs, batch_size=batch_size)

                nexpert_obs_batch, expert_obs_batch, nexpert_act_batch, expert_act_batch, expert_lprobs_batch = \
                    self.sample_batch(expert_obs_next, expert_obs,
                                      expert_acts_next, expert_acts,
                                      expert_probs, batch_size=batch_size)

                # Build feed dict
                labels = np.zeros((batch_size * 2, 1))
                labels[batch_size:] = 1.0
                obs_batch = np.concatenate([obs_batch, expert_obs_batch],
                                           axis=0)
                nobs_batch = np.concatenate([nobs_batch, nexpert_obs_batch],
                                            axis=0)
                act_batch = np.concatenate([act_batch, expert_act_batch],
                                           axis=0)
                nact_batch = np.concatenate([nact_batch, nexpert_act_batch],
                                            axis=0)
                lprobs_batch = np.expand_dims(
                    np.concatenate([lprobs_batch, expert_lprobs_batch],
                                   axis=0),
                    axis=1).astype(np.float32)

                feed_dict = {
                    self.act_t: act_batch[:, :single_nstep],
                    self.obs_t: obs_batch[:, :single_nstep],
                    self.nobs_t: nobs_batch
                    if self.max_nstep == single_nstep
                    else obs_batch[:, single_nstep],
                    self.nact_t: nact_batch
                    if self.max_nstep == single_nstep
                    else act_batch[:, single_nstep],
                    self.labels: labels,
                    self.lprobs: lprobs_batch,
                    self.lr: lr
                }

                loss, _ = tf.get_default_session().run(
                    [self.loss[i][j], self.step[i][j]], feed_dict=feed_dict)
                it.record('loss', loss)

        if self.score_discrim is False and self.score_method == 'teacher_student':
            nobs_batch, obs_batch, nact_batch, act_batch, lprobs_batch = \
                self.sample_batch(obs_next, obs, acts_next, acts, path_probs,
                                  batch_size=batch_size)

            nexpert_obs_batch, expert_obs_batch, nexpert_act_batch, expert_act_batch, expert_lprobs_batch = \
                self.sample_batch(expert_obs_next, expert_obs,
                                  expert_acts_next, expert_acts, expert_probs,
                                  batch_size=batch_size)

            # Build feed dict
            labels = np.zeros((batch_size * 2, 1))
            labels[batch_size:] = 1.0
            obs_batch = np.concatenate([obs_batch, expert_obs_batch], axis=0)
            nobs_batch = np.concatenate([nobs_batch, nexpert_obs_batch], axis=0)
            act_batch = np.concatenate([act_batch, expert_act_batch], axis=0)
            nact_batch = np.concatenate([nact_batch, nexpert_act_batch], axis=0)
            lprobs_batch = np.expand_dims(
                np.concatenate([lprobs_batch, expert_lprobs_batch], axis=0),
                axis=1).astype(np.float32)

            feed_dict = {
                self.act_t: act_batch,
                self.obs_t: obs_batch,
                self.nobs_t: nobs_batch,
                self.nact_t: nact_batch,
                self.labels: labels,
                self.lprobs: lprobs_batch,
                self.lr: lr
            }
            rel_loss, abs_loss, ens_loss, _ = tf.get_default_session().run(
                [self.student_loss, self.student_absolute_loss,
                 self.ensemble_loss, self.student_step],
                feed_dict=feed_dict)

        if it.heartbeat:
            print(it.itr_message())
            mean_loss = it.pop_mean('loss')
            print('\tLoss:%f' % mean_loss)

    if logger:
        logger.record_tabular('GCLMeanDiscrimLoss', mean_loss)
        sess = tf.get_default_session()
        if self.score_discrim is False and self.score_method == 'teacher_student':
            logger.record_tabular('GCLStudentRelativeLoss', rel_loss)
            logger.record_tabular('GCLStudentAbsoluteLoss', abs_loss)
            logger.record_tabular('GCLEnsembleLoss', ens_loss)
        else:
            e_loss, weights = sess.run(
                [self.ensemble_loss, self.weights],
                feed_dict={
                    self.act_t: all_acts,
                    self.obs_t: all_obs,
                    self.nobs_t: all_nobs,
                    self.nact_t: all_nacts,
                    self.labels: all_labels,
                    self.lprobs: np.expand_dims(all_probs, axis=1)
                })
            logger.record_tabular('GCLEnsembleDiscrimLoss', e_loss)
            # logger.record_tabular('TimeWeights', weights)

        if self.score_discrim is False and self.score_method == 'teacher_student':
            energy, logZ, dtau, s_rew, s_val, s_dtau = sess.run(
                [self.t0_reward, self.t0_value, self.ensemble_discrim_output,
                 self.student_reward, self.student_value,
                 self.student_discrim_output],
                feed_dict={self.act_t: acts,
                           self.obs_t: obs,
                           self.nobs_t: obs_next,
                           self.nact_t: acts_next,
                           self.lprobs: np.expand_dims(path_probs, axis=1)})
        else:
            energy, logZ, dtau = sess.run(
                [self.t0_reward, self.t0_value,
                 self.ensemble_discrim_output],
                feed_dict={self.act_t: acts,
                           self.obs_t: obs,
                           self.nobs_t: obs_next,
                           self.nact_t: acts_next,
                           self.lprobs: np.expand_dims(path_probs, axis=1)})
        energy = -energy
        logger.record_tabular('GCLLogZ', np.mean(logZ))
        logger.record_tabular('GCLAverageEnergy', np.mean(energy))
        logger.record_tabular('GCLAverageLogPtau', np.mean(-energy - logZ))
        logger.record_tabular('GCLAverageLogQtau', np.mean(path_probs))
        logger.record_tabular('GCLMedianLogQtau', np.median(path_probs))
        logger.record_tabular('GCLAverageDtau', np.mean(dtau))
        if self.score_discrim is False and self.score_method == 'teacher_student':
            logger.record_tabular('GCLAverageStudentEnergy', np.mean(-s_rew))
            logger.record_tabular('GCLAverageStudentLogPtau',
                                  np.mean(s_rew - s_val))
            logger.record_tabular('GCLAverageStudentDtau', np.mean(s_dtau))

        if self.score_discrim is False and self.score_method == 'teacher_student':
            energy, logZ, dtau, s_rew, s_val, s_dtau = sess.run(
                [self.t0_reward, self.t0_value, self.ensemble_discrim_output,
                 self.student_reward, self.student_value,
                 self.student_discrim_output],
                feed_dict={self.act_t: expert_acts,
                           self.obs_t: expert_obs,
                           self.nobs_t: expert_obs_next,
                           self.nact_t: expert_acts_next,
                           self.lprobs: np.expand_dims(expert_probs, axis=1)})
        else:
            energy, logZ, dtau = sess.run(
                [self.t0_reward, self.t0_value,
                 self.ensemble_discrim_output],
                feed_dict={self.act_t: expert_acts,
                           self.obs_t: expert_obs,
                           self.nobs_t: expert_obs_next,
                           self.nact_t: expert_acts_next,
                           self.lprobs: np.expand_dims(expert_probs, axis=1)})
        energy = -energy
        logger.record_tabular('GCLAverageExpertEnergy', np.mean(energy))
        logger.record_tabular('GCLAverageExpertLogPtau',
                              np.mean(-energy - logZ))
        logger.record_tabular('GCLAverageExpertLogQtau', np.mean(expert_probs))
        logger.record_tabular('GCLMedianExpertLogQtau',
                              np.median(expert_probs))
        logger.record_tabular('GCLAverageExpertDtau', np.mean(dtau))
        if self.score_discrim is False and self.score_method == 'teacher_student':
            logger.record_tabular('GCLAverageStudentExpertEnergy',
                                  np.mean(-s_rew))
            logger.record_tabular('GCLAverageStudentExpertLogPtau',
                                  np.mean(s_rew - s_val))
            logger.record_tabular('GCLAverageStudentExpertDtau',
                                  np.mean(s_dtau))

    return mean_loss
def fit(self, paths, policy=None, batch_size=32, logger=None, lr=1e-3,
        last_timestep_only=False, max_itrs=100, **kwargs):
    if self.frozen:
        return 0

    if self.fusion is not None:
        old_paths = self.fusion.sample_paths(n=len(paths))
        self.fusion.add_paths(paths)
        paths = paths + old_paths

        # log fusion stats
        fstats = self.fusion.compute_age_stats()
        logger.record_tabular('FusionAgeMean', fstats['mean'])
        logger.record_tabular('FusionAgeMed', fstats['med'])
        logger.record_tabular('FusionAgeStd', fstats['std'])
        logger.record_tabular('FusionAgeMax', fstats['max'])
        logger.record_tabular('FusionAgeMin', fstats['min'])
        logger.record_tabular('FusionAgePFresh', fstats['pfresh'])

    self._compute_path_probs(paths, insert=True)

    # self.eval_expert_probs(paths, policy, insert=True)
    # drop stale log-probs so they get re-evaluated under the current policy
    for traj in self.expert_trajs:
        if 'agent_infos' in traj:
            # print('deleting agent_infos')
            del traj['agent_infos']
        if 'a_logprobs' in traj:
            del traj['a_logprobs']
    self.eval_expert_probs(self.expert_trajs, policy, insert=True)

    self._insert_next_state(paths)
    self._insert_next_state(self.expert_trajs)
    obs, obs_next, acts, acts_next, path_probs = self.extract_paths2(
        paths,
        keys=('observations', 'observations_next', 'actions', 'actions_next',
              'a_logprobs'),
        last_timestep_only=last_timestep_only)
    (expert_obs, expert_obs_next, expert_acts, expert_acts_next,
     expert_probs) = self.extract_paths2(
         self.expert_trajs,
         keys=('observations', 'observations_next', 'actions', 'actions_next',
               'a_logprobs'),
         last_timestep_only=last_timestep_only)

    # Train discriminator
    for it in TrainingIterator(max_itrs, heartbeat=5):
        nobs_batch, obs_batch, nact_batch, act_batch, lprobs_batch = \
            self.sample_batch(obs_next, obs, acts_next, acts, path_probs,
                              batch_size=batch_size)

        (nexpert_obs_batch, expert_obs_batch, nexpert_act_batch,
         expert_act_batch, expert_lprobs_batch) = self.sample_batch(
             expert_obs_next, expert_obs, expert_acts_next, expert_acts,
             expert_probs, batch_size=batch_size)

        labels = np.zeros((batch_size * 2, 1))
        labels[batch_size:] = 1.0
        obs_batch = np.concatenate([obs_batch, expert_obs_batch], axis=0)
        nobs_batch = np.concatenate([nobs_batch, nexpert_obs_batch], axis=0)
        act_batch = np.concatenate([act_batch, expert_act_batch], axis=0)
        nact_batch = np.concatenate([nact_batch, nexpert_act_batch], axis=0)
        lprobs_batch = np.expand_dims(
            np.concatenate([lprobs_batch, expert_lprobs_batch], axis=0),
            axis=1).astype(np.float32)

        learn_step_feed_dict = {
            self.act_t: act_batch,
            self.obs_t: obs_batch,
            self.nobs_t: nobs_batch,
            self.nact_t: nact_batch,
            self.labels: labels,
            self.lprobs: lprobs_batch,
            self.lr: lr,
            # we only enable noise during training
            self.is_train_t: True,
        }
        sess = tf.get_default_session()
        loss, tot_kl, _ = sess.run([self.loss, self.tot_kl_loss, self.step],
                                   feed_dict=learn_step_feed_dict)
        if self.vairl and self.vairl_adaptive_beta:
            beta, _ = sess.run([self.vairl_beta, self.vairl_beta_update_op],
                               feed_dict={self.vairl_mean_kl: tot_kl})

        it.record('loss', loss)
        it.record('tot_kl', tot_kl)
        if self.vairl and self.vairl_adaptive_beta:
            it.record('beta', beta)
        if it.heartbeat:
            print(it.itr_message())
            mean_loss = it.pop_mean('loss')
            print('\tLoss:%f' % mean_loss)
            mean_tot_kl = it.pop_mean('tot_kl')
            print('\tKL:%f' % mean_tot_kl)
            if self.vairl and self.vairl_adaptive_beta:
                mean_beta = it.pop_mean('beta')
                print('\tBeta:%f' % mean_beta)

    if logger:
        logger.record_tabular('GCLDiscrimLoss', mean_loss)
        # the 'DiscrimVAIRLKL' one is just retained so I don't break my
        # parsing scripts :)
        logger.record_tabular('GCLDiscrimVAIRLKL', mean_tot_kl)
        logger.record_tabular('GCLVAIRLKL', mean_tot_kl)
        if self.vairl and self.vairl_adaptive_beta:
            logger.record_tabular('GCLVAIRLBeta', mean_beta)

        # obs_next = np.r_[obs_next, np.expand_dims(obs_next[-1], axis=0)]
        for is_train in [True, False]:
            # make sure to keep stats about test-mode configuration as well
            # as train-mode configuration, in case we have something like
            # dropout or VDB noise that affects discriminator results
            prefix = '' if is_train else 'NotIsTrain'
            fake_in_dict = {
                'energy': self.energy,
                'logZ': self.value_fn,
                'dtau_fake': self.d_tau
            }
            real_in_dict = {
                'energy': self.energy,
                'logZ': self.value_fn,
                'dtau_real': self.d_tau
            }
            if self.gp_value is not None:
                fake_in_dict['gp_value'] = real_in_dict['gp_value'] \
                    = self.gp_value

            fake_out_dict = tf.get_default_session().run(
                fake_in_dict,
                feed_dict={
                    self.act_t: acts,
                    self.obs_t: obs,
                    self.nobs_t: obs_next,
                    self.nact_t: acts_next,
                    self.lprobs: np.expand_dims(path_probs, axis=1),
                    self.is_train_t: is_train,
                    self.labels: np.zeros((len(acts), 1)),
                })
            energy = fake_out_dict['energy']
            logZ = fake_out_dict['logZ']
            dtau_fake = fake_out_dict['dtau_fake']
            logger.record_tabular(prefix + 'GCLLogZ', np.mean(logZ))
            logger.record_tabular(prefix + 'GCLAverageEnergy', np.mean(energy))
            logger.record_tabular(prefix + 'GCLAverageLogPtau',
                                  np.mean(-energy - logZ))
            logger.record_tabular(prefix + 'GCLAverageLogQtau',
                                  np.mean(path_probs))
            logger.record_tabular(prefix + 'GCLMedianLogQtau',
                                  np.median(path_probs))
            logger.record_tabular(prefix + 'GCLAverageDtau',
                                  np.mean(dtau_fake))

            # expert_obs_next = np.r_[expert_obs_next,
            #                         np.expand_dims(expert_obs_next[-1], axis=0)]
            real_out_dict = tf.get_default_session().run(
                real_in_dict,
                feed_dict={
                    self.act_t: expert_acts,
                    self.obs_t: expert_obs,
                    self.nobs_t: expert_obs_next,
                    self.nact_t: expert_acts_next,
                    self.lprobs: np.expand_dims(expert_probs, axis=1),
                    self.is_train_t: is_train,
                    self.labels: np.ones((len(expert_acts), 1)),
                })
            energy = real_out_dict['energy']
            logZ = real_out_dict['logZ']
            dtau_real = real_out_dict['dtau_real']
            logger.record_tabular(prefix + 'GCLAverageExpertEnergy',
                                  np.mean(energy))
            logger.record_tabular(prefix + 'GCLAverageExpertLogPtau',
                                  np.mean(-energy - logZ))
            logger.record_tabular(prefix + 'GCLAverageExpertLogQtau',
                                  np.mean(expert_probs))
            logger.record_tabular(prefix + 'GCLMedianExpertLogQtau',
                                  np.median(expert_probs))
            logger.record_tabular(prefix + 'GCLAverageExpertDtau',
                                  np.mean(dtau_real))

            # 1 real, 0 fake
            disc_true_nfake = len(dtau_fake)
            disc_true_nreal = len(dtau_real)
            disc_true_pos = np.sum(dtau_real >= 0.5)
            disc_false_neg = disc_true_nreal - disc_true_pos
            assert disc_false_neg == np.sum(dtau_real < 0.5)
            disc_true_neg = np.sum(dtau_fake < 0.5)
            disc_false_pos = disc_true_nfake - disc_true_neg
            assert disc_false_pos == np.sum(dtau_fake >= 0.5)
            disc_total = disc_true_nfake + disc_true_nreal
            assert 0 <= disc_true_pos and 0 <= disc_false_neg \
                and 0 <= disc_true_neg and 0 <= disc_false_pos
            assert disc_true_pos + disc_false_neg + disc_true_neg \
                + disc_false_pos == disc_total
            # acc = (tp+tn)/(tp+fp+tn+fn)
            disc_acc = (disc_true_pos + disc_true_neg) / disc_total
            # precision = |relevant&retrieved|/|retrieved| = tp/(tp+fp)
            disc_prec = disc_true_pos / (disc_true_pos + disc_false_pos)
            # recall = |relevant&retrieved|/|relevant| = tp/(tp+fn)
            disc_recall = disc_true_pos / (disc_true_pos + disc_false_neg)
            # tpr = tp/(tp+fn) = recall
            disc_tpr = disc_true_pos / (disc_true_pos + disc_false_neg)
            assert disc_tpr == disc_recall
            # tnr = tn/(tn+fp) (specificity, not recall)
            disc_tnr = disc_true_neg / (disc_true_neg + disc_false_pos)
            assert 0 <= disc_prec <= 1 and 0 <= disc_recall <= 1 and \
                0 <= disc_acc <= 1 and 0 <= disc_tpr <= 1 and \
                0 <= disc_tnr <= 1
            disc_f1 = 2 * disc_prec * disc_recall / (disc_prec + disc_recall)
            assert 0 <= disc_f1 <= 1
            logger.record_tabular(prefix + 'GCLDiscAcc', disc_acc)
            logger.record_tabular(prefix + 'GCLDiscF1', disc_f1)
            # TPR is accuracy when predicting reals
            logger.record_tabular(prefix + 'GCLDiscTPR', disc_tpr)
            # TNR is accuracy when predicting fakes
            logger.record_tabular(prefix + 'GCLDiscTNR', disc_tnr)
            logger.record_tabular(prefix + 'GCLDiscNFake', disc_true_nfake)
            logger.record_tabular(prefix + 'GCLDiscNReal', disc_true_nreal)

        if self.gp_value is not None:
            gp_value = 0.5 * (real_out_dict['gp_value'] +
                              fake_out_dict['gp_value'])
            logger.record_tabular('GCLDiscGradPenaltyUnscaled', gp_value)

    return mean_loss
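# Quick numeric sanity check of the confusion-matrix math used above (with
# made-up scores; labels are 1 = real/expert, 0 = fake/policy):
import numpy as np

dtau_real = np.array([0.9, 0.6, 0.4])  # D's scores on 3 expert samples
dtau_fake = np.array([0.2, 0.7])       # D's scores on 2 policy samples
tp = np.sum(dtau_real >= 0.5)          # 2
fn = len(dtau_real) - tp               # 1
tn = np.sum(dtau_fake < 0.5)           # 1
fp = len(dtau_fake) - tn               # 1
acc = (tp + tn) / (tp + fn + tn + fp)  # 3/5 = 0.6
prec = tp / (tp + fp)                  # 2/3
rec = tp / (tp + fn)                   # 2/3
f1 = 2 * prec * rec / (prec + rec)     # 2/3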
def fit(self, paths, policy=None, batch_size=256, logger=None, lr=1e-3, itr=0,
        **kwargs):
    if isinstance(self.expert_trajs[0], dict):
        print("Warning: Processing state out of dictionary")
        self._insert_next_state(self.expert_trajs)
        expert_obs_base, expert_obs_next_base, expert_acts, expert_acts_next = \
            self.extract_paths(self.expert_trajs,
                               keys=('observations', 'observations_next',
                                     'actions', 'actions_next'))
    else:
        expert_obs_base, expert_obs_next_base, expert_acts, expert_acts_next, _ = \
            self.expert_trajs

    # expert_probs = paths.sampler.get_a_logprobs(
    obs, obs_next, acts, acts_next, path_probs = paths.extract_paths(
        ('observations', 'observations_next', 'actions', 'actions_next',
         'a_logprobs'),
        obs_modifier=self.modify_obs)

    expert_obs = expert_obs_base
    expert_obs_next = expert_obs_next_base

    raw_discrim_scores = []
    # Train discriminator
    for it in TrainingIterator(self.max_itrs, heartbeat=5):
        nobs_batch, obs_batch, nact_batch, act_batch, lprobs_batch = \
            self.sample_batch(obs_next, obs, acts_next, acts, path_probs,
                              batch_size=batch_size)

        nexpert_obs_batch, expert_obs_batch, nexpert_act_batch, expert_act_batch = \
            self.sample_batch(
                expert_obs_next,
                expert_obs,
                expert_acts_next,
                expert_acts,
                # expert_probs,
                batch_size=batch_size)
        expert_lprobs_batch = paths.sampler.get_a_logprobs(
            expert_obs_batch, expert_act_batch)
        expert_obs_batch = self.modify_obs(expert_obs_batch)
        nexpert_obs_batch = self.modify_obs(nexpert_obs_batch)
        if self.encoder:
            expert_obs_batch = self.encode_fn(expert_obs_batch,
                                              expert_act_batch.argmax(axis=1))
            nexpert_obs_batch = self.encode_fn(
                nexpert_obs_batch, nexpert_act_batch.argmax(axis=1))

        # Build feed dict
        labels = np.zeros((batch_size * 2, 1))
        labels[batch_size:] = 1.0
        obs_batch = np.concatenate([obs_batch, expert_obs_batch], axis=0)
        nobs_batch = np.concatenate([nobs_batch, nexpert_obs_batch], axis=0)
        act_batch = np.concatenate([act_batch, expert_act_batch], axis=0)
        nact_batch = np.concatenate([nact_batch, nexpert_act_batch], axis=0)
        lprobs_batch = np.expand_dims(
            np.concatenate([lprobs_batch, expert_lprobs_batch], axis=0),
            axis=1).astype(np.float32)

        feed_dict = {
            self.act_t: act_batch,
            self.obs_t: obs_batch,
            self.nobs_t: nobs_batch,
            self.nact_t: nact_batch,
            self.labels: labels,
            self.lprobs: lprobs_batch,
            self.lr: lr
        }

        loss, _, acc, scores = tf.get_default_session().run(
            [self.loss, self.step, self.update_accuracy, self.discrim_output],
            feed_dict=feed_dict)

        # we only want the average score for the non-expert demos
        non_expert_slice = slice(0, batch_size)
        score, raw_score = self._process_discrim_output(
            scores[non_expert_slice])
        assert len(score) == batch_size
        assert np.sum(labels[non_expert_slice]) == 0
        raw_discrim_scores.append(raw_score)

        it.record('loss', loss)
        it.record('accuracy', acc)
        it.record('avg_score', np.mean(score))
        if it.heartbeat:
            print(it.itr_message())
            mean_loss = it.pop_mean('loss')
            print('\tLoss:%f' % mean_loss)
            mean_acc = it.pop_mean('accuracy')
            print('\tAccuracy:%f' % mean_acc)
            mean_score = it.pop_mean('avg_score')

    if logger:
        logger.record_tabular('GCLDiscrimLoss', mean_loss)
        logger.record_tabular('GCLDiscrimAccuracy', mean_acc)
        logger.record_tabular('GCLMeanScore', mean_score)

    # set the center for our normal distribution
    scores = np.hstack(raw_discrim_scores)
    self.score_std = np.std(scores)
    self.score_mean = np.mean(scores)

    return mean_loss
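# `_process_discrim_output` above is assumed to map raw discriminator
# outputs D in (0, 1) to per-sample scores plus a raw score used for the
# normalization statistics at the end of fit(). A plausible but hypothetical
# implementation clips D away from {0, 1} and returns the AIRL-style
# log-odds log D - log(1 - D); the class's real method may differ.
import numpy as np


def _process_discrim_output(self, score):
    score = np.clip(score, 1e-7, 1 - 1e-7)         # avoid log(0)
    raw_score = np.log(score) - np.log(1 - score)  # log D - log(1 - D)
    return score, raw_score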