Example #1
    def __init__(self, load_path):
        trajs = torch.load(load_path)

        rutils.pstart_sep()
        self._setup(trajs)

        trajs = self._generate_trajectories(trajs)

        assert len(trajs) != 0, 'No trajectories found to load!'

        self.n_trajs = len(trajs)
        print('Collected %i trajectories' % len(trajs))
        # Compute statistics across the trajectories.
        all_obs = torch.cat([t[0] for t in trajs])
        all_actions = torch.cat([t[1] for t in trajs])
        self.state_mean = torch.mean(all_obs, dim=0)
        self.state_std = torch.std(all_obs, dim=0)
        self.action_mean = torch.mean(all_actions, dim=0)
        self.action_std = torch.std(all_actions, dim=0)

        self.data = self._gen_data(trajs)
        self.traj_lens = [len(traj[0]) for traj in trajs]
        self.trajs = trajs
        self.holdout_idxs = []

        rutils.pend_sep()
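
A minimal usage sketch, assuming this __init__ belongs to a trajectory dataset class (named TrajDataset here purely for illustration) and that load_path points to a file saved with torch.save in whatever layout _setup and _generate_trajectories expect:

# Hypothetical usage; the class name and the file path are assumptions.
dataset = TrajDataset('data/traj/expert_demos.pt')
print('Trajectories:', dataset.n_trajs)
print('State mean/std shapes:', dataset.state_mean.shape, dataset.state_std.shape)
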
Example #2
    def create_runner(self, add_args=None, ray_create=False):
        add_args = {} if add_args is None else add_args  # avoid a shared mutable default
        policy = self.get_policy()
        algo = self.get_algo()

        args, log = self._sys_setup(add_args, ray_create, algo, policy)
        if args is None:
            return None

        env_interface = self._get_env_interface(args)

        checkpointer = Checkpointer(args)

        alg_env_settings = algo.get_env_settings(args)

        # Setup environment
        envs = make_vec_envs(args.env_name,
                             args.seed,
                             args.num_processes,
                             args.gamma,
                             args.device,
                             args.eval_only,
                             env_interface,
                             args,
                             alg_env_settings,
                             set_eval=args.eval_only)

        rutils.pstart_sep()
        print('Action space:', envs.action_space)
        if isinstance(envs.action_space, Box):
            print('Action range:',
                  (envs.action_space.low, envs.action_space.high))
        print('Observation space:', envs.observation_space)
        rutils.pend_sep()

        # Setup policy
        policy_args = (envs.observation_space, envs.action_space, args)
        policy.init(*policy_args)
        policy = policy.to(args.device)
        policy.watch(log)
        policy.set_env_ref(envs)

        # Setup algo
        algo.set_get_policy(self.get_policy, policy_args)
        algo.init(policy, args)
        algo.set_env_ref(envs)

        # Setup storage buffer
        storage = algo.get_storage_buffer(policy, envs, args)
        for ik, get_shape in alg_env_settings.include_info_keys:
            storage.add_info_key(ik, get_shape(envs))
        storage.to(args.device)
        storage.init_storage(envs.reset())
        storage.set_traj_done_callback(algo.on_traj_finished)

        runner = Runner(envs, storage, policy, log, env_interface,
                        checkpointer, args, algo)
        return runner
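
A hedged driver sketch, assuming a RunSettings-style subclass (MyRunSettings is hypothetical) that implements get_policy, get_algo, and _sys_setup as used above:

# Hypothetical usage; MyRunSettings is an assumed RunSettings subclass.
settings = MyRunSettings()
runner = settings.create_runner()  # returns None if _sys_setup aborts
if runner is not None:
    runner.setup()
    runner.close()
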
Example #3
    def pre_main(self, log, env_interface):
        """
        Gathers the random experience and trains the inverse model on it.
        """
        n_steps = self.args.bco_expl_steps // self.args.num_processes
        base_data_dir = 'data/traj/bco'
        if not osp.exists(base_data_dir):
            os.makedirs(base_data_dir)

        loaded_traj = None
        if self.args.bco_expl_load is not None:
            load_path = osp.join(base_data_dir, self.args.bco_expl_load)
            if osp.exists(load_path) and not self.args.bco_expl_refresh:
                loaded_traj = torch.load(load_path)
                states = loaded_traj['states']
                actions = loaded_traj['actions']
                dones = loaded_traj['dones']
                print(f"Loaded expl trajectories from {load_path}")

        if loaded_traj is None:
            envs = make_vec_envs_easy(self.args.env_name, self.args.num_processes,
                                      env_interface, self.get_env_settings(self.args), self.args)
            policy = RandomPolicy()
            policy.init(envs.observation_space, envs.action_space, self.args)
            rutils.pstart_sep()
            print('Collecting exploration experience')
            states = []
            actions = []
            state = rutils.get_def_obs(envs.reset())
            states.extend(state)
            dones = [True]
            for _ in tqdm(range(n_steps)):
                ac_info = policy.get_action(state, None, None, None, None)
                state, reward, done, info = envs.step(ac_info.take_action)
                state = rutils.get_def_obs(state)
                actions.extend(ac_info.action)
                dones.extend(done)
                states.extend(state)
            rutils.pend_sep()
            envs.close()

        if self.args.bco_expl_load is not None and loaded_traj is None:
            # Save the data.
            torch.save({
                'states': states,
                'actions': actions,
                'dones': dones,
            }, load_path)
            print(f"Saved data to {load_path}")

        if self.args.bco_inv_load is not None:
            self.inv_func.load_state_dict(torch.load(self.args.bco_inv_load))

        self._update_all(states, actions, dones)
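
For reference, a small sketch of the exploration cache this method writes under data/traj/bco (the dictionary keys match the snippet; the file name is an assumption). With num_processes equal to 1, states and dones hold one more entry than actions, matching the list[N+1]/list[N] convention documented in _update_all:

import torch

# Inspect a cached exploration rollout; 'expl.pt' is a hypothetical file name.
cache = torch.load('data/traj/bco/expl.pt')
states, actions, dones = cache['states'], cache['actions'], cache['dones']
print(len(states), len(actions), len(dones))
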
Example #4
    def first_train(self, log, eval_policy):
        rutils.pstart_sep()
        print('Pre-training policy with BC')
        self.bc.full_train()

        bc_eval_args = copy.copy(self.args)
        bc_eval_args.eval_num_processes = 32
        bc_eval_args.num_eval = 5
        bc_eval_args.num_render = 0
        tmp_env = eval_policy(self.policy, 0, bc_eval_args)
        if tmp_env is not None:
            tmp_env.close()

        rutils.pend_sep()

        super().first_train(log, eval_policy)
Example #5
def run_policy(run_settings, runner=None):
    if runner is None:
        runner = run_settings.create_runner()
    end_update = runner.updater.get_num_updates()
    args = runner.args

    if args.ray:
        from ray import tune
        import ray

        # Release resources as they will be recreated by Ray
        runner.close()

        use_config = eval(args.ray_config)
        use_config['cwd'] = os.getcwd()
        use_config = run_settings.get_add_ray_config(use_config)

        rutils.pstart_sep()
        print('Running ray for %i updates per run' % end_update)
        rutils.pend_sep()

        ray.init(local_mode=args.ray_debug)
        tune.run(type(run_settings),
                 resources_per_trial={
                     'cpu': args.ray_cpus,
                     "gpu": args.ray_gpus
                 },
                 stop={'training_iteration': end_update},
                 num_samples=args.ray_nsamples,
                 global_checkpoint_period=np.inf,
                 config=use_config,
                 **run_settings.get_add_ray_kwargs())
    else:
        args = runner.args

        if runner.should_load_from_checkpoint():
            runner.load_from_checkpoint()

        if args.eval_only:
            return runner.full_eval(run_settings.create_traj_saver)

        start_update = 0
        if args.resume:
            start_update = runner.resume()

        runner.setup()
        print('RL Training (%d/%d)' % (start_update, end_update))

        # Initialize outside the loop just in case there are no updates.
        j = 0
        for j in range(start_update, end_update):
            updater_log_vals = runner.training_iter(j)
            if args.log_interval > 0 and (j + 1) % args.log_interval == 0:
                log_dict = runner.log_vals(updater_log_vals, j)
            if args.save_interval > 0 and (j + 1) % args.save_interval == 0:
                runner.save(j)
            if args.eval_interval > 0 and (j + 1) % args.eval_interval == 0:
                runner.eval(j)

        if args.eval_interval > 0:
            runner.eval(j + 1)
        if args.save_interval > 0:
            runner.save(j + 1)

        runner.close()
        # Return the W&B run prefix so the logged data can be fetched later.
        return RunResult(args.prefix)
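
A minimal entry-point sketch, assuming run_settings is a RunSettings-style object such as the one whose create_runner appears in Example #2 (MyRunSettings is hypothetical):

if __name__ == '__main__':
    # Hypothetical driver; the Ray branch returns None, otherwise a RunResult.
    result = run_policy(MyRunSettings())
    print('Run finished:', result)
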
Example #6
def eval_from_file(plot_cfg_path, load_dir, get_run_settings, args):
    with open(plot_cfg_path) as f:
        eval_settings = yaml.load(f, Loader=yaml.FullLoader)
        config_mgr.init(eval_settings['config_yaml'])
        eval_key = eval_settings['eval_key']
        scale_factor = eval_settings['scale_factor']
        rename_sections = eval_settings['rename_sections']
        wb_proj_name = config_mgr.get_prop('proj_name')
        wb_entity = config_mgr.get_prop('wb_entity')
        api = wandb.Api()
        all_run_names = []
        for eval_section in eval_settings['eval_sections']:
            report_name = eval_section['report_name']
            eval_sections = eval_section['eval_sections']
            cacher = CacheHelper(report_name, eval_sections)
            if cacher.exists() and not eval_section['force_reload']:
                run_names = cacher.load()
            else:
                run_ids = get_run_ids_from_report(wb_entity, wb_proj_name,
                                                  report_name, eval_sections,
                                                  api)
                run_names = convert_to_prefix(run_ids,
                                              {'report_name': report_name})
                cacher.save(run_names)
            all_run_names.extend(run_names)

    full_load_name = osp.join(load_dir, 'data/trained_models')
    full_log_name = osp.join(load_dir, 'data/log')
    method_names = defaultdict(list)
    for name, method_name, env_name, info in all_run_names:
        model_dir = osp.join(full_load_name, env_name, name)
        cmd_path = osp.join(full_log_name, env_name, name)

        if not osp.exists(model_dir):
            raise ValueError(f"Model {model_dir} does not exist", info)

        if not osp.exists(cmd_path):
            raise ValueError(f"Model {cmd_path} does not exist")

        model_nums = [
            int(x.split('_')[1].split('.')[0]) for x in os.listdir(model_dir)
            if 'model_' in x
        ]
        if len(model_nums) == 0:
            raise ValueError(f"Model {model_dir} is empty", info)

        max_idx = max(model_nums)
        use_model = osp.join(model_dir, f"model_{max_idx}.pt")

        with open(osp.join(cmd_path, 'cmd.txt'), 'r') as f:
            cmd = f.read()

        method_names[method_name].append((use_model, cmd, env_name, info))

    env_results = defaultdict(lambda: defaultdict(list))
    NUM_PROCS = 20

    total_count = sum([len(x) for x in method_names.values()])

    done_count = 0
    for method_name, runs in method_names.items():
        for use_model, cmd, env_name, info in runs:
            print(f"({done_count}/{total_count})")
            done_count += 1
            cache_result = CacheHelper(
                f"result_{method_name}_{use_model.replace('/', '_')}_{args.num_eval}",
                cmd)
            if cache_result.exists() and not args.override:
                eval_result = cache_result.load()
            else:
                if args.table_only and not args.override:
                    break
                cmd = cmd.split(' ')[2:]
                cmd.append('--no-wb')
                cmd.append('--eval-only')
                cmd.extend(['--cuda', 'False'])
                cmd.extend(['--num-render', '0'])
                cmd.extend(['--eval-num-processes', str(NUM_PROCS)])
                cmd.extend(["--num-eval", f"{args.num_eval // NUM_PROCS}"])
                cmd.extend(["--load-file", use_model])
                run_settings = get_run_settings(cmd)
                run_settings.setup()
                eval_result = run_settings.eval_result
                cache_result.save(eval_result)
            store_num = eval_result[args.get_key]
            env_results[info['report_name']][method_name].append(store_num)
            rutils.pstart_sep()
            print(f"Result for {use_model}: {store_num}")
            rutils.pend_sep()
    print(generate_eval_table(env_results, scale_factor, rename_sections))
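
A hedged sketch of the plot-config file this function expects, reconstructed only from the keys it reads above and shown here as an equivalent Python dictionary (every value is a placeholder):

# Placeholder values; only the key names are taken from eval_from_file.
eval_settings = {
    'config_yaml': 'config.yaml',
    'eval_key': 'avg_reward',
    'scale_factor': 1.0,
    'rename_sections': {},
    'eval_sections': [{
        'report_name': 'my-report',
        'eval_sections': ['section-a', 'section-b'],
        'force_reload': False,
    }],
}
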
Example #7
    def _update_all(self, states, actions, dones):
        """
        - states (list[N+1])
        - masks (list[N+1])
        - actions (list[N])
        Performs a complete update of the model by following these steps:
            1. Train inverse function with ground truth data provided.
            2. Infer actions in expert dataset
            3. Train BC
        """
        dataset = [{
            's0': states[i],
            's1': states[i+1],
            'action': actions[i]
        } for i in range(len(actions))
            if not dones[i+1]]

        rutils.pstart_sep()
        print(f"BCO Update {self.update_i}/{self.args.bco_alpha}")
        print('---')

        print('Training inverse function')
        dataset_idxs = list(range(len(dataset)))
        np.random.shuffle(dataset_idxs)

        eval_len = int(len(dataset_idxs) * self.args.bco_inv_eval_holdout)
        if eval_len != 0:
            train_trans_sampler = BatchSampler(SubsetRandomSampler(
                dataset_idxs[:-eval_len]), self.args.bco_inv_batch_size, drop_last=False)
            val_trans_sampler = BatchSampler(SubsetRandomSampler(
                dataset_idxs[-eval_len:]), self.args.bco_inv_batch_size, drop_last=False)
        else:
            train_trans_sampler = BatchSampler(SubsetRandomSampler(
                dataset_idxs), self.args.bco_inv_batch_size, drop_last=False)

        if self.args.bco_inv_load is None or self.update_i > 0:
            infer_ac_losses = self._train_inv_func(train_trans_sampler, dataset)
            rutils.plot_line(infer_ac_losses, f"ac_inv_loss_{self.update_i}.png",
                             self.args, not self.args.no_wb,
                             self.get_completed_update_steps(self.update_i))
            if self.update_i == 0:
                # Only save the inverse model on the first epoch for debugging
                # purposes
                rutils.save_model(self.inv_func, f"inv_func_{self.update_i}.pt",
                        self.args)

        if eval_len != 0:
            if not isinstance(self.policy.action_space, spaces.Discrete):
                raise ValueError('Evaluating the holdout accuracy is only '
                                 'supported for discrete action spaces right now')
            accuracy = self._infer_inv_accuracy(val_trans_sampler, dataset)
            print('Inferred actions with %.2f accuracy' % accuracy)

        if isinstance(self.expert_dataset, torch.utils.data.Subset):
            s0 = self.expert_dataset.dataset.trajs['obs'].to(self.args.device).float()
            s1 = self.expert_dataset.dataset.trajs['next_obs'].to(self.args.device).float()
            dataset_device = self.expert_dataset.dataset.trajs['obs'].device
        else:
            s0 = self.expert_dataset.trajs['obs'].to(self.args.device).float()
            s1 = self.expert_dataset.trajs['next_obs'].to(self.args.device).float()
            dataset_device = self.expert_dataset.trajs['obs'].device

        # Perform inference on the expert states
        with torch.no_grad():
            pred_actions = self.inv_func(s0, s1).to(dataset_device)
            pred_actions = rutils.get_ac_compact(self.policy.action_space,
                                                 pred_actions)
            if not self.args.bco_oracle_actions:
                if isinstance(self.expert_dataset, torch.utils.data.Subset):
                    self.expert_dataset.dataset.trajs['actions'] = pred_actions
                else:
                    self.expert_dataset.trajs['actions'] = pred_actions
        # Recreate the dataset for BC training so we can be sure it has the
        # most recent data.
        self._create_train_loader(self.args)

        print('Training Policy')
        self.full_train(self.update_i)
        self.update_i += 1
        rutils.pend_sep()
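
A small worked example of the transition pairing at the top of _update_all, using toy lists (all values are illustrative):

# Toy data: 4 states, 3 actions, 4 done flags; dones[2] marks an episode boundary.
states = ['s0', 's1', 's2', 's3']
actions = ['a0', 'a1', 'a2']
dones = [True, False, True, False]

dataset = [{'s0': states[i], 's1': states[i + 1], 'action': actions[i]}
           for i in range(len(actions)) if not dones[i + 1]]
# Transitions that cross an episode boundary (dones[i + 1] is True) are dropped:
# [{'s0': 's0', 's1': 's1', 'action': 'a0'}, {'s0': 's2', 's1': 's3', 'action': 'a2'}]
print(dataset)
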