Example #1
    def _compute_discrim_reward(self, storage, step, add_info):
        state = rutils.get_def_obs(storage.get_obs(step))

        next_state = rutils.get_def_obs(storage.get_obs(step + 1))
        masks = storage.masks[step + 1]
        finished_episodes = [i for i in range(len(masks)) if masks[i] == 0.0]
        add_inputs = {k: v[(step + 1) - 1] for k, v in add_info.items()}
        obsfilt = self.get_env_ob_filt()
        for i in finished_episodes:
            next_state[i] = add_inputs['final_obs'][i]
            if obsfilt is not None:
                next_state[i] = torch.FloatTensor(
                    obsfilt(next_state[i].cpu().numpy(),
                            update=False)).to(self.args.device)

        d_val = self.discrim_net(state, next_state)
        s = torch.sigmoid(d_val)
        eps = 1e-20
        if self.args.reward_type == 'airl':
            reward = (s + eps).log() - (1 - s + eps).log()
        elif self.args.reward_type == 'gail':
            reward = (s + eps).log()
        elif self.args.reward_type == 'raw':
            reward = d_val
        elif self.args.reward_type == 'gaifo':
            reward = -1.0 * (s + eps).log()
        else:
            raise ValueError(
                f"Unrecognized reward type {self.args.reward_type}")
        return reward
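
The four reward transforms above are the standard GAIL-family choices. A minimal standalone sketch of them, using a dummy tensor in place of the discriminator output self.discrim_net(state, next_state) (the logit values below are illustrative only):

import torch

# Stand-in for discriminator logits d_val = discrim_net(state, next_state).
d_val = torch.tensor([2.0, -1.0, 0.5])
s = torch.sigmoid(d_val)
eps = 1e-20
airl_reward = (s + eps).log() - (1 - s + eps).log()  # log-odds; equals d_val up to eps
gail_reward = (s + eps).log()                        # always <= 0
gaifo_reward = -1.0 * (s + eps).log()                # always >= 0
raw_reward = d_val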
Example #2
    def pre_main(self, log, env_interface):
        """
        Gathers the random experience and trains the inverse model on it.
        """
        n_steps = self.args.bco_expl_steps // self.args.num_processes
        base_data_dir = 'data/traj/bco'
        if not osp.exists(base_data_dir):
            os.makedirs(base_data_dir)

        loaded_traj = None
        if self.args.bco_expl_load is not None:
            load_path = osp.join(base_data_dir, self.args.bco_expl_load)
            if osp.exists(load_path) and not self.args.bco_expl_refresh:
                loaded_traj = torch.load(load_path)
                states = loaded_traj['states']
                actions = loaded_traj['actions']
                dones = loaded_traj['dones']
                print(f"Loaded expl trajectories from {load_path}")

        if loaded_traj is None:
            envs = make_vec_envs_easy(self.args.env_name, self.args.num_processes,
                                      env_interface, self.get_env_settings(self.args), self.args)
            policy = RandomPolicy()
            policy.init(envs.observation_space, envs.action_space, self.args)
            rutils.pstart_sep()
            print('Collecting exploration experience')
            states = []
            actions = []
            state = rutils.get_def_obs(envs.reset())
            states.extend(state)
            dones = [True]
            for _ in tqdm(range(n_steps)):
                ac_info = policy.get_action(state, None, None, None, None)
                state, reward, done, info = envs.step(ac_info.take_action)
                state = rutils.get_def_obs(state)
                actions.extend(ac_info.action)
                dones.extend(done)
                states.extend(state)
            rutils.pend_sep()
            envs.close()

        if self.args.bco_expl_load is not None and loaded_traj is None:
            # Save the data.
            torch.save({
                'states': states,
                'actions': actions,
                'dones': dones,
            }, load_path)
            print(f"Saved data to {load_path}")

        if self.args.bco_inv_load is not None:
            self.inv_func.load_state_dict(torch.load(self.args.bco_inv_load))

        self._update_all(states, actions, dones)
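
The inverse model mentioned in the docstring (self.inv_func) maps (state, next_state) pairs gathered by the random policy to predicted actions, as in Behavioral Cloning from Observation. A rough sketch of such a model, assuming flat continuous observations and actions; the class name and layer sizes are illustrative, not the library's actual inv_func:

import torch
import torch.nn as nn

class InverseModel(nn.Module):
    """Predicts the action that moved the agent from state to next_state."""

    def __init__(self, ob_dim, ac_dim, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2 * ob_dim, hidden_dim), nn.Tanh(),
            nn.Linear(hidden_dim, ac_dim))

    def forward(self, state, next_state):
        # Concatenate the transition endpoints and regress the action.
        return self.net(torch.cat([state, next_state], dim=-1))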
Example #3
def traj_to_tensor(trajs, device):
    """
    - trajs: [B, N, 5]. Batch size of B, trajectory length of N. 5 for state,
      action, mask, info, reward.
    """
    # Get the data into a format we can work with.
    max_traj_len = max([len(traj) for traj in trajs])
    n_trajs = len(trajs)
    ob_dim = rutils.get_def_obs(trajs[0][0][0]).shape
    ac_dim = trajs[0][0][1].shape
    other_obs_info = {
        k: x.shape
        for k, x in rutils.get_other_obs(trajs[0][0][0]).items()
    }

    obs = torch.zeros(n_trajs, max_traj_len, *ob_dim).to(device)
    obs_add = {
        k: torch.zeros(n_trajs, max_traj_len, *shp).to(device)
        for k, shp in other_obs_info.items()
    }
    actions = torch.zeros(n_trajs, max_traj_len, *ac_dim).to(device)
    masks = torch.zeros(n_trajs, max_traj_len, 1).to(device)
    rewards = torch.zeros(n_trajs, max_traj_len, 1).to(device)

    traj_mask = torch.zeros(n_trajs, max_traj_len).to(device)

    add_infos = {
        k: torch.zeros(len(trajs), max_traj_len, *v.shape)
        for k, v in trajs[0][-1][3].items()
    }

    for i in range(len(trajs)):
        traj_len = len(trajs[i])
        o, a, m, infos, r = list(zip(*trajs[i]))
        for j, inf in enumerate(infos):
            for k, v in inf.items():
                add_infos[k][i, j] = v

        traj_mask[i, :traj_len] = 1.0
        obs[i, :traj_len] = torch.stack([rutils.get_def_obs(o_i)
                                         for o_i in o]).to(device)
        for k in obs_add:
            obs_add[k][i, :traj_len] = torch.stack([o_i[k]
                                                    for o_i in o]).to(device)
        actions[i, :traj_len] = torch.stack(a).to(device)
        masks[i, :traj_len] = torch.tensor(m).unsqueeze(-1).to(device)
        rewards[i, :traj_len] = torch.stack(r).to(device)

    for k in add_infos:
        add_infos[k] = add_infos[k].to(device)
    return obs, obs_add, actions, masks, add_infos, rewards
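
The zero-padding and traj_mask pattern used above, shown in isolation on plain reward tensors (the trajectory lengths below are made up); real calls go through rutils.get_def_obs and the full (state, action, mask, info, reward) tuples described in the docstring:

import torch

# Two trajectories of different lengths, rewards only.
traj_rewards = [torch.ones(3, 1), 2.0 * torch.ones(5, 1)]
n_trajs = len(traj_rewards)
max_traj_len = max(len(t) for t in traj_rewards)

rewards = torch.zeros(n_trajs, max_traj_len, 1)
traj_mask = torch.zeros(n_trajs, max_traj_len)
for i, r in enumerate(traj_rewards):
    rewards[i, :len(r)] = r
    traj_mask[i, :len(r)] = 1.0  # 1 marks real steps, 0 marks padding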
Example #4
    def training_iter(self, update_iter):
        self.log.start_interval_log()
        self.updater.pre_update(update_iter)
        for step in range(self.args.num_steps):
            # Sample actions
            obs = self.storage.get_obs(step)

            step_info = get_step_info(update_iter, step, self.episode_count,
                                      self.args)
            with torch.no_grad():
                ac_info = self.policy.get_action(
                    utils.get_def_obs(obs, self.args.policy_ob_key),
                    utils.get_other_obs(obs),
                    self.storage.get_hidden_state(step),
                    self.storage.get_masks(step), step_info)
                if self.args.clip_actions:
                    ac_info.clip_action(*self.ac_tensor)

            next_obs, reward, done, infos = self.envs.step(ac_info.take_action)

            reward += ac_info.add_reward

            step_log_vals = utils.agg_ep_log_stats(infos, ac_info.extra)

            self.episode_count += sum([int(d) for d in done])
            self.log.collect_step_info(step_log_vals)

            self.storage.insert(obs, next_obs, reward, done, infos, ac_info)

        updater_log_vals = self.updater.update(self.storage)
        self.storage.after_update()

        return updater_log_vals
Example #5
    def step_wait(self):
        obs, rews, news, infos = self.venv.step_wait()

        stacked_obs, infos = self.stacked_obs.update_obs(
            rutils.get_def_obs(obs), news, infos)

        obs = rutils.set_def_obs(obs, stacked_obs)
        return obs, rews, news, infos
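
update_obs above comes from the library's StackedObs helper; a rough sketch of the usual frame-stacking update it presumably performs (the stack size, frame shape, and channel-last layout here are assumptions, not taken from the library):

import numpy as np

num_stack, h, w, c = 4, 84, 84, 1
stacked = np.zeros((h, w, c * num_stack), dtype=np.float32)

def update_stack(stacked, new_frame, done):
    # Shift the oldest frame out of the channel axis.
    stacked = np.roll(stacked, shift=-c, axis=-1)
    if done:
        stacked[...] = 0.0  # reset the stack at episode boundaries
    stacked[..., -c:] = new_frame
    return stacked

stacked = update_stack(stacked, np.ones((h, w, c), dtype=np.float32), done=False)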
Example #6
    def get_action(self, state, add_state, hxs, masks, step_info):
        n_procs = rutils.get_def_obs(state).shape[0]
        action = torch.tensor([
            self.action_space.sample() for _ in range(n_procs)
        ]).to(self.args.device)
        if isinstance(self.action_space, spaces.Discrete):
            action = action.unsqueeze(-1)

        return create_simple_action_data(action, hxs)
Example #7
    def observation(self, obs):
        frame = rutils.get_def_obs(obs)
        if self.grayscale:
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)

        frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
        if self.grayscale:
            frame = np.expand_dims(frame, -1)

        return rutils.set_def_obs(obs, frame)
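
The grayscale/resize preprocessing above, run standalone on a dummy RGB frame (the 210x160 input and 84x84 target sizes are illustrative, not values taken from the wrapper):

import cv2
import numpy as np

frame = np.random.randint(0, 255, (210, 160, 3), dtype=np.uint8)  # dummy RGB frame
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
frame = cv2.resize(frame, (84, 84), interpolation=cv2.INTER_AREA)
frame = np.expand_dims(frame, -1)  # back to channel-last: (84, 84, 1)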
Example #8
    def init_storage(self, obs):
        super().init_storage(obs)
        batch_size = rutils.get_def_obs(obs).shape[0]
        hxs = {}
        for k, dim in self.hidden_state_dims.items():
            hxs[k] = torch.zeros(batch_size, dim)
        self.last_seen = {
            'obs': obs,
            'masks': torch.zeros(batch_size, 1),
            'hxs': hxs,
        }
Example #9
    def _get_next_value(self, rollouts):
        """
        Gets the value of the final observations, which is needed to
        estimate the returns of any partial trajectories.
        """
        with torch.no_grad():
            next_value = self.policy.get_value(
                rutils.get_def_obs(rollouts.get_obs(-1),
                                   self.args.policy_ob_key),
                rutils.get_other_obs(rollouts.get_obs(-1),
                                     self.args.policy_ob_key),
                rollouts.get_hidden_state(-1), rollouts.masks[-1]).detach()
        return next_value
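
The next_value returned above is typically used to bootstrap the returns of trajectories cut off at the end of a rollout. A minimal sketch of that discounted recursion, with made-up shapes and a gamma of 0.99 (the library's own return computation may differ):

import torch

num_steps, num_procs, gamma = 4, 2, 0.99
rewards = torch.rand(num_steps, num_procs, 1)
masks = torch.ones(num_steps + 1, num_procs, 1)  # 0 where an episode ended
next_value = torch.rand(num_procs, 1)            # V(s_T) from the policy

returns = torch.zeros(num_steps + 1, num_procs, 1)
returns[-1] = next_value
for step in reversed(range(num_steps)):
    returns[step] = rewards[step] + gamma * masks[step + 1] * returns[step + 1]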
Example #10
    def mod_render_frames(self, frame, env_cur_obs, env_cur_action,
                          env_cur_reward, env_next_obs, **kwargs):
        use_cur_obs = rutils.get_def_obs(env_cur_obs)
        use_cur_obs = torch.FloatTensor(use_cur_obs).unsqueeze(0).to(
            self.args.device)

        if env_cur_action is not None:
            use_action = torch.FloatTensor(env_cur_action).unsqueeze(0).to(
                self.args.device)
            disc_val = self._compute_disc_val(use_cur_obs, use_action).item()
        else:
            disc_val = 0.0

        frame = append_text_to_image(frame, [
            "Discrim: %.3f" % disc_val,
            "Reward: %.3f" %
            (env_cur_reward if env_cur_reward is not None else 0.0)
        ])
        return frame
Example #11
    def insert(self, obs, next_obs, reward, done, infos, ac_info):
        super().insert(obs, next_obs, reward, done, infos, ac_info)

        masks, bad_masks = self.compute_masks(done, infos)

        batch_size = rutils.get_def_obs(obs).shape[0]
        for i in range(batch_size):
            if bad_masks[i] == 0 and self.args.use_proper_time_limits:
                # we are not actually done.
                masks[i] = 1.0
            self._push_transition({
                    'action': ac_info.action[i],
                    'state': rutils.obs_select(obs, i),
                    'reward': reward[i],
                    'next_state': rutils.obs_select(next_obs, i),
                    'mask': masks[i],
                    'hxs': rutils.deep_dict_select(ac_info.hxs, i),
                    })

        self.last_seen = {
                'obs': next_obs,
                'masks': masks,
                'hxs': ac_info.hxs,
                }
Example #12
    def _trans_agent_state(self, state, other_state=None):
        if not self.args.gail_state_norm:
            if other_state is None:
                return state['raw_obs']
            return other_state['raw_obs']
        return rutils.get_def_obs(state)
Example #13
    def get_def_obs_seq(self):
        if isinstance(self.obs, dict):
            return rutils.get_def_obs(self.obs)
        else:
            return self.obs
Example #14
def evaluate(args, alg_env_settings, policy, true_vec_norm, env_interface,
             num_steps, mode, eval_envs, log, create_traj_saver_fn):
    if args.eval_num_processes is None:
        num_processes = args.num_processes
    else:
        num_processes = args.eval_num_processes

    if eval_envs is None:
        eval_envs = make_vec_envs(args.env_name,
                                  args.seed + num_steps,
                                  num_processes,
                                  args.gamma,
                                  args.device,
                                  True,
                                  env_interface,
                                  args,
                                  alg_env_settings,
                                  set_eval=True)

    assert get_vec_normalize(eval_envs) is None, 'Norm is manually applied'

    if true_vec_norm is not None:
        obfilt = true_vec_norm._obfilt
    else:

        def obfilt(x, update):
            return x

    eval_episode_rewards = []
    eval_def_stats = defaultdict(list)
    ep_stats = defaultdict(list)

    obs = eval_envs.reset()

    hidden_states = {}
    for k, dim in policy.get_storage_hidden_states().items():
        hidden_states[k] = torch.zeros(num_processes, dim).to(args.device)
    eval_masks = torch.zeros(num_processes, 1, device=args.device)

    frames = []
    infos = None

    policy.eval()
    if args.eval_save and create_traj_saver_fn is not None:
        traj_saver = create_traj_saver_fn(
            osp.join(args.traj_dir, args.env_name, args.prefix))
    else:
        assert not args.eval_save, (
            'Cannot save evaluation without '
            'specifying the eval saver creator function')

    total_num_eval = num_processes * args.num_eval

    # Measure the number of episodes completed
    pbar = tqdm(total=total_num_eval)
    evaluated_episode_count = 0
    n_succs = 0
    n_fails = 0
    succ_frames = []
    fail_frames = []
    if args.render_succ_fails and num_processes > 1:
        raise ValueError("""
                Can only render successes and failures when the number of
                processes is 1.
                """)

    if args.num_render is None or args.num_render > 0:
        frames.extend(
            get_render_frames(eval_envs, env_interface, None, None, None, None,
                              None, args, evaluated_episode_count))

    while evaluated_episode_count < total_num_eval:
        step_info = get_empty_step_info()
        with torch.no_grad():
            act_obs = obfilt(utils.ob_to_np(obs), update=False)
            act_obs = utils.ob_to_tensor(act_obs, args.device)

            ac_info = policy.get_action(utils.get_def_obs(act_obs),
                                        utils.get_other_obs(obs),
                                        hidden_states, eval_masks, step_info)

            hidden_states = ac_info.hxs

        # Observe reward and next obs
        next_obs, _, done, infos = eval_envs.step(ac_info.take_action)
        if args.eval_save:
            finished_count = traj_saver.collect(obs, next_obs, done,
                                                ac_info.take_action, infos)
        else:
            finished_count = sum([int(d) for d in done])

        pbar.update(finished_count)
        evaluated_episode_count += finished_count

        cur_frame = None

        eval_masks = torch.tensor([[0.0] if done_ else [1.0]
                                   for done_ in done],
                                  dtype=torch.float32,
                                  device=args.device)

        should_render = (args.num_render is None
                         or evaluated_episode_count < args.num_render)
        if args.render_succ_fails:
            should_render = n_succs < args.num_render or n_fails < args.num_render

        if should_render:
            frames.extend(
                get_render_frames(eval_envs, env_interface, obs, next_obs,
                                  ac_info.take_action, eval_masks, infos, args,
                                  evaluated_episode_count))
        obs = next_obs

        step_log_vals = utils.agg_ep_log_stats(infos, ac_info.extra)
        for k, v in step_log_vals.items():
            ep_stats[k].extend(v)

        if 'ep_success' in step_log_vals and args.render_succ_fails:
            is_succ = step_log_vals['ep_success'][0]
            if is_succ == 1.0:
                if n_succs < args.num_render:
                    succ_frames.extend(frames)
                n_succs += 1
            else:
                if n_fails < args.num_render:
                    fail_frames.extend(frames)
                n_fails += 1
            frames = []

    pbar.close()
    info = {}
    if args.eval_save:
        traj_saver.save()

    ret_info = {}

    print(" Evaluation using %i episodes:" % len(ep_stats['r']))
    for k, v in ep_stats.items():
        print(' - %s: %.5f' % (k, np.mean(v)))
        ret_info[k] = np.mean(v)

    if args.render_succ_fails:
        # Render the success and failures to two separate files.
        save_frames(succ_frames, "succ_" + mode, num_steps, args)
        save_frames(fail_frames, "fail_" + mode, num_steps, args)
    else:
        save_file = save_frames(frames, mode, num_steps, args)
        if save_file is not None:
            log.log_video(save_file, num_steps, args.vid_fps)

    # Switch policy back to train mode
    policy.train()

    return ret_info, eval_envs
Example #15
    def mod_env_ob_filt(state, update=True):
        state = obfilt(state, update)
        state = rutils.get_def_obs(state)
        return state
Example #16
    def reset(self):
        obs = self.venv.reset()
        stacked_obs = self.stacked_obs.reset(rutils.get_def_obs(obs))
        obs = rutils.set_def_obs(obs, stacked_obs)
        return obs
Example #17
    def init_storage(self, obs):
        self.traj_storage = [
            [] for _ in range(rutils.get_def_obs(obs).shape[0])]
Example #18
    def _norm_state(self, x):
        obs_x = torch.clamp((rutils.get_def_obs(x) - self.norm_mean) /
                            torch.pow(self.norm_var + 1e-8, 0.5), -10.0, 10.0)
        if isinstance(x, dict):
            x['observation'] = obs_x
            return x
        # Non-dict observations: return the normalized tensor directly.
        return obs_x
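
The same clamp-and-normalize transform in isolation, on a made-up batch (statistics are computed on the spot here rather than taken from tracked running values like self.norm_mean / self.norm_var):

import torch

x = 5.0 * torch.randn(8, 4)
norm_mean, norm_var = x.mean(dim=0), x.var(dim=0)
x_norm = torch.clamp((x - norm_mean) / torch.pow(norm_var + 1e-8, 0.5),
                     -10.0, 10.0)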