def __init__(self, load_path):
    trajs = torch.load(load_path)

    rutils.pstart_sep()
    self._setup(trajs)

    trajs = self._generate_trajectories(trajs)
    assert len(trajs) != 0, 'No trajectories found to load!'
    self.n_trajs = len(trajs)
    print('Collected %i trajectories' % len(trajs))

    # Compute statistics across the trajectories.
    all_obs = torch.cat([t[0] for t in trajs])
    all_actions = torch.cat([t[1] for t in trajs])
    self.state_mean = torch.mean(all_obs, dim=0)
    self.state_std = torch.std(all_obs, dim=0)
    self.action_mean = torch.mean(all_actions, dim=0)
    self.action_std = torch.std(all_actions, dim=0)

    self.data = self._gen_data(trajs)
    self.traj_lens = [len(traj[0]) for traj in trajs]
    self.trajs = trajs
    self.holdout_idxs = []
    rutils.pend_sep()
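# A minimal sketch (not part of the dataset class above) of how the computed
# state statistics could be used to whiten observations before they are fed
# to a policy. The helper name and the epsilon guard are illustrative
# assumptions, not framework API.
import torch

def normalize_obs(obs, state_mean, state_std, eps=1e-8):
    # Whiten each observation dimension; eps protects against dimensions
    # that have zero variance in the demonstration data.
    return (obs - state_mean) / (state_std + eps)

# Example with dummy tensors standing in for self.state_mean / self.state_std:
obs_batch = torch.randn(8, 4)
normed = normalize_obs(obs_batch, obs_batch.mean(dim=0), obs_batch.std(dim=0))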
def create_runner(self, add_args={}, ray_create=False):
    policy = self.get_policy()
    algo = self.get_algo()
    args, log = self._sys_setup(add_args, ray_create, algo, policy)
    if args is None:
        return None
    env_interface = self._get_env_interface(args)
    checkpointer = Checkpointer(args)
    alg_env_settings = algo.get_env_settings(args)

    # Setup environment
    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, args.device, args.eval_only,
                         env_interface, args, alg_env_settings,
                         set_eval=args.eval_only)

    rutils.pstart_sep()
    print('Action space:', envs.action_space)
    if isinstance(envs.action_space, Box):
        print('Action range:',
              (envs.action_space.low, envs.action_space.high))
    print('Observation space:', envs.observation_space)
    rutils.pend_sep()

    # Setup policy
    policy_args = (envs.observation_space, envs.action_space, args)
    policy.init(*policy_args)
    policy = policy.to(args.device)
    policy.watch(log)
    policy.set_env_ref(envs)

    # Setup algo
    algo.set_get_policy(self.get_policy, policy_args)
    algo.init(policy, args)
    algo.set_env_ref(envs)

    # Setup storage buffer
    storage = algo.get_storage_buffer(policy, envs, args)
    for ik, get_shape in alg_env_settings.include_info_keys:
        storage.add_info_key(ik, get_shape(envs))
    storage.to(args.device)
    storage.init_storage(envs.reset())
    storage.set_traj_done_callback(algo.on_traj_finished)

    runner = Runner(envs, storage, policy, log, env_interface,
                    checkpointer, args, algo)
    return runner
def pre_main(self, log, env_interface):
    """
    Gathers random exploration experience and trains the inverse dynamics
    model on it.
    """
    n_steps = self.args.bco_expl_steps // self.args.num_processes
    base_data_dir = 'data/traj/bco'
    if not osp.exists(base_data_dir):
        os.makedirs(base_data_dir)

    loaded_traj = None
    if self.args.bco_expl_load is not None:
        load_path = osp.join(base_data_dir, self.args.bco_expl_load)
        if osp.exists(load_path) and not self.args.bco_expl_refresh:
            loaded_traj = torch.load(load_path)
            states = loaded_traj['states']
            actions = loaded_traj['actions']
            dones = loaded_traj['dones']
            print(f"Loaded expl trajectories from {load_path}")

    if loaded_traj is None:
        envs = make_vec_envs_easy(self.args.env_name, self.args.num_processes,
                                  env_interface,
                                  self.get_env_settings(self.args), self.args)
        policy = RandomPolicy()
        policy.init(envs.observation_space, envs.action_space, self.args)

        rutils.pstart_sep()
        print('Collecting exploration experience')
        states = []
        actions = []
        state = rutils.get_def_obs(envs.reset())
        states.extend(state)
        dones = [True]
        for _ in tqdm(range(n_steps)):
            ac_info = policy.get_action(state, None, None, None, None)
            state, reward, done, info = envs.step(ac_info.take_action)
            state = rutils.get_def_obs(state)
            actions.extend(ac_info.action)
            dones.extend(done)
            states.extend(state)
        rutils.pend_sep()
        envs.close()

    if self.args.bco_expl_load is not None and loaded_traj is None:
        # Save the collected exploration data so it can be reused next run.
        torch.save({
            'states': states,
            'actions': actions,
            'dones': dones,
        }, load_path)
        print(f"Saved data to {load_path}")

    if self.args.bco_inv_load is not None:
        self.inv_func.load_state_dict(torch.load(self.args.bco_inv_load))

    self._update_all(states, actions, dones)
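# A small, self-contained illustration (with dummy data, not the real
# environment) of how the exploration buffers collected above line up:
# `states` has one more entry than `actions`, and dones[i + 1] marks whether
# the step from states[i] to states[i + 1] crossed an episode boundary, which
# is exactly how _update_all filters transitions below.
import torch

states = [torch.randn(4) for _ in range(4)]       # s_0 .. s_3
actions = [torch.tensor([a]) for a in range(3)]   # a_0 .. a_2
dones = [True, False, False, True]                # dones[0] comes from the reset

transitions = [
    (states[i], actions[i], states[i + 1])
    for i in range(len(actions)) if not dones[i + 1]
]
print(len(transitions))  # 2: the pair that straddles the auto-reset is dropped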
def first_train(self, log, eval_policy):
    rutils.pstart_sep()
    print('Pre-training policy with BC')
    self.bc.full_train()

    # Run a quick evaluation of the BC pre-trained policy before RL training
    # starts.
    bc_eval_args = copy.copy(self.args)
    bc_eval_args.eval_num_processes = 32
    bc_eval_args.num_eval = 5
    bc_eval_args.num_render = 0
    tmp_env = eval_policy(self.policy, 0, bc_eval_args)
    if tmp_env is not None:
        tmp_env.close()
    rutils.pend_sep()

    super().first_train(log, eval_policy)
def run_policy(run_settings, runner=None):
    if runner is None:
        runner = run_settings.create_runner()
    end_update = runner.updater.get_num_updates()
    args = runner.args

    if args.ray:
        import ray
        from ray import tune

        # Release resources as they will be recreated by Ray.
        runner.close()

        use_config = eval(args.ray_config)
        use_config['cwd'] = os.getcwd()
        use_config = run_settings.get_add_ray_config(use_config)

        rutils.pstart_sep()
        print('Running ray for %i updates per run' % end_update)
        rutils.pend_sep()

        ray.init(local_mode=args.ray_debug)
        tune.run(type(run_settings),
                 resources_per_trial={
                     'cpu': args.ray_cpus,
                     'gpu': args.ray_gpus,
                 },
                 stop={'training_iteration': end_update},
                 num_samples=args.ray_nsamples,
                 global_checkpoint_period=np.inf,
                 config=use_config,
                 **run_settings.get_add_ray_kwargs())
    else:
        if runner.should_load_from_checkpoint():
            runner.load_from_checkpoint()

        if args.eval_only:
            return runner.full_eval(run_settings.create_traj_saver)

        start_update = 0
        if args.resume:
            start_update = runner.resume()

        runner.setup()
        print('RL Training (%d/%d)' % (start_update, end_update))

        # Initialize outside the loop just in case there are no updates.
        j = 0
        for j in range(start_update, end_update):
            updater_log_vals = runner.training_iter(j)

            if args.log_interval > 0 and (j + 1) % args.log_interval == 0:
                runner.log_vals(updater_log_vals, j)
            if args.save_interval > 0 and (j + 1) % args.save_interval == 0:
                runner.save(j)
            if args.eval_interval > 0 and (j + 1) % args.eval_interval == 0:
                runner.eval(j)

        if args.eval_interval > 0:
            runner.eval(j + 1)
        if args.save_interval > 0:
            runner.save(j + 1)

        runner.close()

    # W&B prefix of the run so we can later fetch the data.
    return RunResult(args.prefix)
def eval_from_file(plot_cfg_path, load_dir, get_run_settings, args):
    with open(plot_cfg_path) as f:
        eval_settings = yaml.safe_load(f)
    config_mgr.init(eval_settings['config_yaml'])

    eval_key = eval_settings['eval_key']
    scale_factor = eval_settings['scale_factor']
    rename_sections = eval_settings['rename_sections']

    wb_proj_name = config_mgr.get_prop('proj_name')
    wb_entity = config_mgr.get_prop('wb_entity')
    api = wandb.Api()

    # Resolve the run prefixes for every requested W&B report section, caching
    # the lookups so repeated invocations are fast.
    all_run_names = []
    for eval_section in eval_settings['eval_sections']:
        report_name = eval_section['report_name']
        eval_sections = eval_section['eval_sections']
        cacher = CacheHelper(report_name, eval_sections)
        if cacher.exists() and not eval_section['force_reload']:
            run_names = cacher.load()
        else:
            run_ids = get_run_ids_from_report(wb_entity, wb_proj_name,
                                              report_name, eval_sections, api)
            run_names = convert_to_prefix(run_ids,
                                          {'report_name': report_name})
            cacher.save(run_names)
        all_run_names.extend(run_names)

    full_load_name = osp.join(load_dir, 'data/trained_models')
    full_log_name = osp.join(load_dir, 'data/log')

    # Group the runs by method and find the latest saved checkpoint and the
    # original training command for each run.
    method_names = defaultdict(list)
    for name, method_name, env_name, info in all_run_names:
        model_dir = osp.join(full_load_name, env_name, name)
        cmd_path = osp.join(full_log_name, env_name, name)
        if not osp.exists(model_dir):
            raise ValueError(f"Model {model_dir} does not exist", info)
        if not osp.exists(cmd_path):
            raise ValueError(f"Command log {cmd_path} does not exist")
        model_nums = [
            int(x.split('_')[1].split('.')[0])
            for x in os.listdir(model_dir) if 'model_' in x
        ]
        if len(model_nums) == 0:
            raise ValueError(f"Model {model_dir} is empty", info)
        max_idx = max(model_nums)
        use_model = osp.join(model_dir, f"model_{max_idx}.pt")
        with open(osp.join(cmd_path, 'cmd.txt'), 'r') as f:
            cmd = f.read()
        method_names[method_name].append((use_model, cmd, env_name, info))

    env_results = defaultdict(lambda: defaultdict(list))
    NUM_PROCS = 20
    total_count = sum([len(x) for x in method_names.values()])
    done_count = 0
    for method_name, runs in method_names.items():
        for use_model, cmd, env_name, info in runs:
            print(f"({done_count}/{total_count})")
            done_count += 1
            cache_result = CacheHelper(
                f"result_{method_name}_{use_model.replace('/', '_')}_{args.num_eval}",
                cmd)
            if cache_result.exists() and not args.override:
                eval_result = cache_result.load()
            else:
                if args.table_only and not args.override:
                    break
                # Re-run the original training command in evaluation-only mode
                # on the latest checkpoint.
                cmd = cmd.split(' ')[2:]
                cmd.append('--no-wb')
                cmd.append('--eval-only')
                cmd.extend(['--cuda', 'False'])
                cmd.extend(['--num-render', '0'])
                cmd.extend(['--eval-num-processes', str(NUM_PROCS)])
                cmd.extend(['--num-eval', f"{args.num_eval // NUM_PROCS}"])
                cmd.extend(['--load-file', use_model])
                run_settings = get_run_settings(cmd)
                run_settings.setup()
                eval_result = run_settings.eval_result
                cache_result.save(eval_result)
            store_num = eval_result[args.get_key]
            env_results[info['report_name']][method_name].append(store_num)
            rutils.pstart_sep()
            print(f"Result for {use_model}: {store_num}")
            rutils.pend_sep()

    print(generate_eval_table(env_results, scale_factor, rename_sections))
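# A minimal sketch of the plot-config file that eval_from_file expects,
# reconstructed only from the keys read above; every concrete value is a
# placeholder assumption, not a real project setting.
import yaml

example_cfg = {
    'config_yaml': 'config.yaml',       # passed to config_mgr.init
    'eval_key': 'final_train_success',  # placeholder metric name
    'scale_factor': 100.0,              # forwarded to generate_eval_table
    'rename_sections': {},              # optional renaming for table sections
    'eval_sections': [
        {
            'report_name': 'my_report',  # W&B report to pull run IDs from
            'eval_sections': ['ours'],   # report sections to evaluate
            'force_reload': False,       # set True to bypass the CacheHelper cache
        },
    ],
}

# The dict can be written out as YAML and passed in as plot_cfg_path:
print(yaml.safe_dump(example_cfg, default_flow_style=False))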
def _update_all(self, states, actions, dones):
    """
    - states (list[N+1])
    - dones (list[N+1])
    - actions (list[N])
    Performs a complete update of the model by following these steps:
    1. Train the inverse function with the ground truth data provided.
    2. Infer the actions in the expert dataset.
    3. Train the policy with BC.
    """
    # Build (s_t, s_{t+1}, a_t) transitions, skipping pairs that straddle an
    # episode boundary.
    dataset = [{
        's0': states[i],
        's1': states[i + 1],
        'action': actions[i],
    } for i in range(len(actions)) if not dones[i + 1]]

    rutils.pstart_sep()
    print(f"BCO Update {self.update_i}/{self.args.bco_alpha}")
    print('---')
    print('Training inverse function')

    dataset_idxs = list(range(len(dataset)))
    np.random.shuffle(dataset_idxs)

    eval_len = int(len(dataset_idxs) * self.args.bco_inv_eval_holdout)
    if eval_len > 0:
        train_trans_sampler = BatchSampler(
            SubsetRandomSampler(dataset_idxs[:-eval_len]),
            self.args.bco_inv_batch_size, drop_last=False)
        val_trans_sampler = BatchSampler(
            SubsetRandomSampler(dataset_idxs[-eval_len:]),
            self.args.bco_inv_batch_size, drop_last=False)
    else:
        train_trans_sampler = BatchSampler(
            SubsetRandomSampler(dataset_idxs),
            self.args.bco_inv_batch_size, drop_last=False)

    if self.args.bco_inv_load is None or self.update_i > 0:
        infer_ac_losses = self._train_inv_func(train_trans_sampler, dataset)
        rutils.plot_line(infer_ac_losses, f"ac_inv_loss_{self.update_i}.png",
                         self.args, not self.args.no_wb,
                         self.get_completed_update_steps(self.update_i))

    if self.update_i == 0:
        # Only save the inverse model on the first update, for debugging
        # purposes.
        rutils.save_model(self.inv_func, f"inv_func_{self.update_i}.pt",
                          self.args)

    if eval_len > 0:
        if not isinstance(self.policy.action_space, spaces.Discrete):
            raise ValueError('Evaluating the holdout accuracy is only'
                             ' supported for discrete action spaces right now')
        accuracy = self._infer_inv_accuracy(val_trans_sampler, dataset)
        print('Inferred actions with %.2f accuracy' % accuracy)

    if isinstance(self.expert_dataset, torch.utils.data.Subset):
        expert_trajs = self.expert_dataset.dataset.trajs
    else:
        expert_trajs = self.expert_dataset.trajs
    s0 = expert_trajs['obs'].to(self.args.device).float()
    s1 = expert_trajs['next_obs'].to(self.args.device).float()
    dataset_device = expert_trajs['obs'].device

    # Perform inference on the expert states.
    with torch.no_grad():
        pred_actions = self.inv_func(s0, s1).to(dataset_device)
        pred_actions = rutils.get_ac_compact(self.policy.action_space,
                                             pred_actions)

    if not self.args.bco_oracle_actions:
        expert_trajs['actions'] = pred_actions

    # Recreate the dataset for BC training so we can be sure it has the
    # most recent data.
    self._create_train_loader(self.args)

    print('Training Policy')
    self.full_train(self.update_i)
    self.update_i += 1
    rutils.pend_sep()
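# A minimal sketch (an assumption for illustration, not the repository's
# actual inverse model) of the interface _update_all relies on: a module that
# is called as inv_func(s0, s1) and maps a pair of consecutive observations to
# a predicted action (logits for a discrete action space).
import torch
import torch.nn as nn

class SimpleInvFunc(nn.Module):
    def __init__(self, obs_dim, n_actions, hidden_dim=64):
        super().__init__()
        # Concatenate s_t and s_{t+1} and predict which action caused the
        # transition.
        self.net = nn.Sequential(
            nn.Linear(2 * obs_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, n_actions),
        )

    def forward(self, s0, s1):
        return self.net(torch.cat([s0, s1], dim=-1))

# Example forward pass with dummy observations:
inv = SimpleInvFunc(obs_dim=4, n_actions=3)
logits = inv(torch.randn(2, 4), torch.randn(2, 4))  # shape (2, 3)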