def _compute_discrim_reward(self, storage, step, add_info):
    state = rutils.get_def_obs(storage.get_obs(step))
    next_state = rutils.get_def_obs(storage.get_obs(step + 1))
    masks = storage.masks[step + 1]
    finished_episodes = [i for i in range(len(masks)) if masks[i] == 0.0]
    add_inputs = {k: v[(step + 1) - 1] for k, v in add_info.items()}

    obsfilt = self.get_env_ob_filt()
    for i in finished_episodes:
        next_state[i] = add_inputs['final_obs'][i]
        if obsfilt is not None:
            next_state[i] = torch.FloatTensor(
                obsfilt(next_state[i].cpu().numpy(),
                        update=False)).to(self.args.device)

    d_val = self.discrim_net(state, next_state)
    s = torch.sigmoid(d_val)
    eps = 1e-20
    if self.args.reward_type == 'airl':
        reward = (s + eps).log() - (1 - s + eps).log()
    elif self.args.reward_type == 'gail':
        reward = (s + eps).log()
    elif self.args.reward_type == 'raw':
        reward = d_val
    elif self.args.reward_type == 'gaifo':
        reward = -1.0 * (s + eps).log()
    else:
        raise ValueError(
            f"Unrecognized reward type {self.args.reward_type}")
    return reward
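# A standalone sketch (not part of the class above) checking the reward
# variants: with s = sigmoid(d), the 'airl' reward log(s) - log(1 - s)
# recovers the raw logit d, while 'gail' uses log(s) and 'gaifo' uses -log(s).
import torch

d = torch.randn(5)
s = torch.sigmoid(d)
airl_reward = s.log() - (1 - s).log()
assert torch.allclose(airl_reward, d, atol=1e-4)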
def pre_main(self, log, env_interface):
    """
    Gathers the random experience and trains the inverse model on it.
    """
    n_steps = self.args.bco_expl_steps // self.args.num_processes
    base_data_dir = 'data/traj/bco'
    if not osp.exists(base_data_dir):
        os.makedirs(base_data_dir)

    loaded_traj = None
    if self.args.bco_expl_load is not None:
        load_path = osp.join(base_data_dir, self.args.bco_expl_load)
        if osp.exists(load_path) and not self.args.bco_expl_refresh:
            loaded_traj = torch.load(load_path)
            states = loaded_traj['states']
            actions = loaded_traj['actions']
            dones = loaded_traj['dones']
            print(f"Loaded expl trajectories from {load_path}")

    if loaded_traj is None:
        envs = make_vec_envs_easy(self.args.env_name,
                                  self.args.num_processes, env_interface,
                                  self.get_env_settings(self.args), self.args)
        policy = RandomPolicy()
        policy.init(envs.observation_space, envs.action_space, self.args)

        rutils.pstart_sep()
        print('Collecting exploration experience')
        states = []
        actions = []
        state = rutils.get_def_obs(envs.reset())
        states.extend(state)
        dones = [True]
        for _ in tqdm(range(n_steps)):
            ac_info = policy.get_action(state, None, None, None, None)
            state, reward, done, info = envs.step(ac_info.take_action)
            state = rutils.get_def_obs(state)

            actions.extend(ac_info.action)
            dones.extend(done)
            states.extend(state)
        rutils.pend_sep()
        envs.close()

    if self.args.bco_expl_load is not None and loaded_traj is None:
        # Save the data.
        torch.save({
            'states': states,
            'actions': actions,
            'dones': dones,
        }, load_path)
        print(f"Saved data to {load_path}")

    if self.args.bco_inv_load is not None:
        self.inv_func.load_state_dict(torch.load(self.args.bco_inv_load))

    self._update_all(states, actions, dones)
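# A hypothetical sketch of the kind of inverse model trained on the random
# exploration data gathered above: given a (state, next_state) pair, it
# predicts the action that caused the transition (the core idea behind BCO).
# InverseModel, obs_dim, action_dim, and hidden are illustrative names, not
# part of the codebase above.
import torch
import torch.nn as nn

class InverseModel(nn.Module):
    def __init__(self, obs_dim, action_dim, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2 * obs_dim, hidden), nn.Tanh(),
            nn.Linear(hidden, action_dim))

    def forward(self, state, next_state):
        # Predict the action from the concatenated transition.
        return self.net(torch.cat([state, next_state], dim=-1))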
def traj_to_tensor(trajs, device):
    """
    - trajs: [B, N, 5]. Batch size of B, trajectory length of N. 5 for state,
      action, mask, info, reward.
    """
    # Get the data into a format we can work with.
    max_traj_len = max([len(traj) for traj in trajs])
    n_trajs = len(trajs)
    ob_dim = rutils.get_def_obs(trajs[0][0][0]).shape
    ac_dim = trajs[0][0][1].shape
    other_obs_info = {
        k: x.shape
        for k, x in rutils.get_other_obs(trajs[0][0][0]).items()
    }

    obs = torch.zeros(n_trajs, max_traj_len, *ob_dim).to(device)
    obs_add = {
        k: torch.zeros(n_trajs, max_traj_len, *shp).to(device)
        for k, shp in other_obs_info.items()
    }
    actions = torch.zeros(n_trajs, max_traj_len, *ac_dim).to(device)
    masks = torch.zeros(n_trajs, max_traj_len, 1).to(device)
    rewards = torch.zeros(n_trajs, max_traj_len, 1).to(device)
    traj_mask = torch.zeros(n_trajs, max_traj_len).to(device)
    add_infos = {
        k: torch.zeros(n_trajs, max_traj_len, *v.shape)
        for k, v in trajs[0][-1][3].items()
    }

    for i in range(n_trajs):
        traj_len = len(trajs[i])
        o, a, m, infos, r = list(zip(*trajs[i]))
        for j, inf in enumerate(infos):
            for k, v in inf.items():
                add_infos[k][i, j] = v

        traj_mask[i, :traj_len] = 1.0
        obs[i, :traj_len] = torch.stack(
            [rutils.get_def_obs(o_i) for o_i in o]).to(device)
        for k in obs_add:
            obs_add[k][i, :traj_len] = torch.stack(
                [o_i[k] for o_i in o]).to(device)
        actions[i, :traj_len] = torch.stack(a).to(device)
        masks[i, :traj_len] = torch.tensor(m).unsqueeze(-1).to(device)
        rewards[i, :traj_len] = torch.stack(r).to(device)

    for k in add_infos:
        add_infos[k] = add_infos[k].to(device)

    return obs, obs_add, actions, masks, add_infos, rewards
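# A minimal, self-contained sketch of the padding scheme used above:
# variable-length trajectories are stacked into [B, max_N, ...] tensors, with
# unused slots left at zero and a per-step validity mask (like traj_mask)
# marking the real steps. The names here are illustrative only.
import torch

def pad_trajs_example():
    trajs = [torch.randn(3, 4), torch.randn(5, 4)]  # two trajectories, obs dim 4
    max_len = max(t.shape[0] for t in trajs)
    obs = torch.zeros(len(trajs), max_len, 4)
    valid = torch.zeros(len(trajs), max_len)
    for i, t in enumerate(trajs):
        obs[i, :t.shape[0]] = t
        valid[i, :t.shape[0]] = 1.0
    return obs, valid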
def training_iter(self, update_iter):
    self.log.start_interval_log()
    self.updater.pre_update(update_iter)

    for step in range(self.args.num_steps):
        # Sample actions
        obs = self.storage.get_obs(step)

        step_info = get_step_info(update_iter, step, self.episode_count,
                                  self.args)
        with torch.no_grad():
            ac_info = self.policy.get_action(
                utils.get_def_obs(obs, self.args.policy_ob_key),
                utils.get_other_obs(obs),
                self.storage.get_hidden_state(step),
                self.storage.get_masks(step), step_info)
            if self.args.clip_actions:
                ac_info.clip_action(*self.ac_tensor)

        next_obs, reward, done, infos = self.envs.step(ac_info.take_action)
        reward += ac_info.add_reward

        step_log_vals = utils.agg_ep_log_stats(infos, ac_info.extra)
        self.episode_count += sum([int(d) for d in done])
        self.log.collect_step_info(step_log_vals)

        self.storage.insert(obs, next_obs, reward, done, infos, ac_info)

    updater_log_vals = self.updater.update(self.storage)
    self.storage.after_update()

    return updater_log_vals
def step_wait(self):
    obs, rews, news, infos = self.venv.step_wait()
    stacked_obs, infos = self.stacked_obs.update_obs(
        rutils.get_def_obs(obs), news, infos)
    obs = rutils.set_def_obs(obs, stacked_obs)
    return obs, rews, news, infos
def get_action(self, state, add_state, hxs, masks, step_info):
    n_procs = rutils.get_def_obs(state).shape[0]
    action = torch.tensor([
        self.action_space.sample() for _ in range(n_procs)
    ]).to(self.args.device)
    if isinstance(self.action_space, spaces.Discrete):
        action = action.unsqueeze(-1)
    return create_simple_action_data(action, hxs)
def observation(self, obs):
    frame = rutils.get_def_obs(obs)
    if self.grayscale:
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
    frame = cv2.resize(frame, (self.width, self.height),
                       interpolation=cv2.INTER_AREA)
    if self.grayscale:
        frame = np.expand_dims(frame, -1)
    return rutils.set_def_obs(obs, frame)
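# A self-contained sketch of the same preprocessing applied to a dummy RGB
# frame. The 84x84 target size is an assumption for illustration; the wrapper
# above uses self.width / self.height instead.
import cv2
import numpy as np

raw_frame = np.random.randint(0, 256, (210, 160, 3), dtype=np.uint8)
gray = cv2.cvtColor(raw_frame, cv2.COLOR_RGB2GRAY)
resized = cv2.resize(gray, (84, 84), interpolation=cv2.INTER_AREA)
processed = np.expand_dims(resized, -1)  # shape (84, 84, 1)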
def init_storage(self, obs):
    super().init_storage(obs)
    batch_size = rutils.get_def_obs(obs).shape[0]
    hxs = {}
    for k, dim in self.hidden_state_dims.items():
        hxs[k] = torch.zeros(batch_size, dim)
    self.last_seen = {
        'obs': obs,
        'masks': torch.zeros(batch_size, 1),
        'hxs': hxs,
    }
def _get_next_value(self, rollouts):
    """
    Gets the value of the final observations. Needed to estimate the returns
    of any partial trajectories.
    """
    with torch.no_grad():
        next_value = self.policy.get_value(
            rutils.get_def_obs(rollouts.get_obs(-1), self.args.policy_ob_key),
            rutils.get_other_obs(rollouts.get_obs(-1),
                                 self.args.policy_ob_key),
            rollouts.get_hidden_state(-1), rollouts.masks[-1]).detach()
    return next_value
def mod_render_frames(self, frame, env_cur_obs, env_cur_action,
                      env_cur_reward, env_next_obs, **kwargs):
    use_cur_obs = rutils.get_def_obs(env_cur_obs)
    use_cur_obs = torch.FloatTensor(use_cur_obs).unsqueeze(0).to(
        self.args.device)

    if env_cur_action is not None:
        use_action = torch.FloatTensor(env_cur_action).unsqueeze(0).to(
            self.args.device)
        disc_val = self._compute_disc_val(use_cur_obs, use_action).item()
    else:
        disc_val = 0.0

    frame = append_text_to_image(frame, [
        "Discrim: %.3f" % disc_val,
        "Reward: %.3f" %
        (env_cur_reward if env_cur_reward is not None else 0.0),
    ])
    return frame
def insert(self, obs, next_obs, reward, done, infos, ac_info):
    super().insert(obs, next_obs, reward, done, infos, ac_info)
    masks, bad_masks = self.compute_masks(done, infos)

    batch_size = rutils.get_def_obs(obs).shape[0]
    for i in range(batch_size):
        if bad_masks[i] == 0 and self.args.use_proper_time_limits:
            # The episode was cut off by a time limit, so we are not actually
            # done; keep the mask at 1.0 so the value is still bootstrapped.
            masks[i] = 1.0
        self._push_transition({
            'action': ac_info.action[i],
            'state': rutils.obs_select(obs, i),
            'reward': reward[i],
            'next_state': rutils.obs_select(next_obs, i),
            'mask': masks[i],
            'hxs': rutils.deep_dict_select(ac_info.hxs, i),
        })

    self.last_seen = {
        'obs': next_obs,
        'masks': masks,
        'hxs': ac_info.hxs,
    }
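# A minimal sketch of why the mask is forced back to 1.0 on time-limit
# cutoffs: the stored mask multiplies the bootstrap term, so only true
# terminations should zero it out. td_target is an illustrative helper, not
# part of the storage class above.
import torch

def td_target(reward, next_value, mask, gamma=0.99):
    # mask == 0.0 only for genuine environment terminations; time-limit
    # truncations keep mask == 1.0 so the return still bootstraps from V(s').
    return reward + gamma * mask * next_value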
def _trans_agent_state(self, state, other_state=None):
    if not self.args.gail_state_norm:
        if other_state is None:
            return state['raw_obs']
        return other_state['raw_obs']
    return rutils.get_def_obs(state)
def get_def_obs_seq(self):
    if isinstance(self.obs, dict):
        return rutils.get_def_obs(self.obs)
    else:
        return self.obs
def evaluate(args, alg_env_settings, policy, true_vec_norm, env_interface,
             num_steps, mode, eval_envs, log, create_traj_saver_fn):
    if args.eval_num_processes is None:
        num_processes = args.num_processes
    else:
        num_processes = args.eval_num_processes

    if eval_envs is None:
        eval_envs = make_vec_envs(args.env_name, args.seed + num_steps,
                                  num_processes, args.gamma, args.device,
                                  True, env_interface, args,
                                  alg_env_settings, set_eval=True)
    assert get_vec_normalize(eval_envs) is None, 'Norm is manually applied'

    if true_vec_norm is not None:
        obfilt = true_vec_norm._obfilt
    else:
        def obfilt(x, update):
            return x

    eval_episode_rewards = []
    eval_def_stats = defaultdict(list)
    ep_stats = defaultdict(list)

    obs = eval_envs.reset()

    hidden_states = {}
    for k, dim in policy.get_storage_hidden_states().items():
        hidden_states[k] = torch.zeros(num_processes, dim).to(args.device)
    eval_masks = torch.zeros(num_processes, 1, device=args.device)

    frames = []
    infos = None

    policy.eval()

    if args.eval_save and create_traj_saver_fn is not None:
        traj_saver = create_traj_saver_fn(
            osp.join(args.traj_dir, args.env_name, args.prefix))
    else:
        assert not args.eval_save, (
            'Cannot save evaluation without '
            'specifying the eval saver creator function')

    # Measure the number of episodes completed.
    total_num_eval = num_processes * args.num_eval
    pbar = tqdm(total=total_num_eval)
    evaluated_episode_count = 0

    n_succs = 0
    n_fails = 0
    succ_frames = []
    fail_frames = []

    if args.render_succ_fails and args.eval_num_processes > 1:
        raise ValueError("""
            Can only render successes and failures when the number of
            processes is 1.
            """)

    if args.num_render is None or args.num_render > 0:
        frames.extend(
            get_render_frames(eval_envs, env_interface, None, None, None,
                              None, None, args, evaluated_episode_count))

    while evaluated_episode_count < total_num_eval:
        step_info = get_empty_step_info()
        with torch.no_grad():
            act_obs = obfilt(utils.ob_to_np(obs), update=False)
            act_obs = utils.ob_to_tensor(act_obs, args.device)
            ac_info = policy.get_action(utils.get_def_obs(act_obs),
                                        utils.get_other_obs(obs),
                                        hidden_states, eval_masks, step_info)
            hidden_states = ac_info.hxs

        # Observe reward and next obs.
        next_obs, _, done, infos = eval_envs.step(ac_info.take_action)
        if args.eval_save:
            finished_count = traj_saver.collect(obs, next_obs, done,
                                                ac_info.take_action, infos)
        else:
            finished_count = sum([int(d) for d in done])
        pbar.update(finished_count)
        evaluated_episode_count += finished_count

        cur_frame = None
        eval_masks = torch.tensor([[0.0] if done_ else [1.0]
                                   for done_ in done],
                                  dtype=torch.float32,
                                  device=args.device)

        should_render = (args.num_render is None
                         or evaluated_episode_count < args.num_render)
        if args.render_succ_fails:
            should_render = (n_succs < args.num_render
                             or n_fails < args.num_render)

        if should_render:
            frames.extend(
                get_render_frames(eval_envs, env_interface, obs, next_obs,
                                  ac_info.take_action, eval_masks, infos,
                                  args, evaluated_episode_count))
        obs = next_obs

        step_log_vals = utils.agg_ep_log_stats(infos, ac_info.extra)
        for k, v in step_log_vals.items():
            ep_stats[k].extend(v)

        if 'ep_success' in step_log_vals and args.render_succ_fails:
            is_succ = step_log_vals['ep_success'][0]
            if is_succ == 1.0:
                if n_succs < args.num_render:
                    succ_frames.extend(frames)
                n_succs += 1
            else:
                if n_fails < args.num_render:
                    fail_frames.extend(frames)
                n_fails += 1
            frames = []

    pbar.close()
    info = {}
    if args.eval_save:
        traj_saver.save()

    ret_info = {}
    print(" Evaluation using %i episodes:" % len(ep_stats['r']))
    for k, v in ep_stats.items():
        print('  - %s: %.5f' % (k, np.mean(v)))
        ret_info[k] = np.mean(v)

    if args.render_succ_fails:
        # Render the successes and failures to two separate files.
        save_frames(succ_frames, "succ_" + mode, num_steps, args)
        save_frames(fail_frames, "fail_" + mode, num_steps, args)
    else:
        save_file = save_frames(frames, mode, num_steps, args)
        if save_file is not None:
            log.log_video(save_file, num_steps, args.vid_fps)

    # Switch the policy back to train mode.
    policy.train()

    return ret_info, eval_envs
def mod_env_ob_filt(state, update=True):
    state = obfilt(state, update)
    state = rutils.get_def_obs(state)
    return state
def reset(self):
    obs = self.venv.reset()
    stacked_obs = self.stacked_obs.reset(rutils.get_def_obs(obs))
    obs = rutils.set_def_obs(obs, stacked_obs)
    return obs
def init_storage(self, obs):
    self.traj_storage = [
        [] for _ in range(rutils.get_def_obs(obs).shape[0])
    ]
def _norm_state(self, x):
    obs_x = torch.clamp(
        (rutils.get_def_obs(x) - self.norm_mean) /
        torch.pow(self.norm_var + 1e-8, 0.5), -10.0, 10.0)
    if isinstance(x, dict):
        x['observation'] = obs_x
        return x
    # For non-dict observations, return the normalized tensor directly.
    return obs_x
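# A self-contained sketch of the normalization above: standardize with a
# (running) mean and variance and clamp to [-10, 10] to limit the effect of
# outlier observations. norm_mean / norm_var here are placeholders for
# whatever statistics the class tracks.
import torch

def norm_obs_example(obs, norm_mean, norm_var, clip=10.0, eps=1e-8):
    return torch.clamp((obs - norm_mean) / torch.sqrt(norm_var + eps),
                       -clip, clip)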