Example #1
def rollout_path(env, task_params, obs_task_params, post_cond_policy):
    cur_eval_path_builder = PathBuilder()
    
    # reset the env using the params
    observation = env.reset(task_params=task_params, obs_task_params=obs_task_params)
    terminal = False
    task_identifier = env.task_identifier

    while (not terminal) and len(cur_eval_path_builder) < MAX_PATH_LENGTH:
        agent_obs = observation['obs']
        action, agent_info = post_cond_policy.get_action(agent_obs)
        
        next_ob, raw_reward, terminal, env_info = (env.step(action))
        # deliberately ignore the env's terminal signal so the rollout only
        # ends when max_path_length is reached
        terminal = False
        
        reward = raw_reward
        terminal = np.array([terminal])
        reward = np.array([reward])
        cur_eval_path_builder.add_all(
            observations=observation,
            actions=action,
            rewards=reward,
            next_observations=next_ob,
            terminals=terminal,
            agent_infos=agent_info,
            env_infos=env_info,
            task_identifiers=task_identifier
        )
        observation = next_ob

    return cur_eval_path_builder.get_all_stacked()
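The PathBuilder class itself is not shown in these examples. As a rough mental model only (an assumption, not the actual rlkit implementation; SimplePathBuilder is an illustrative name), it behaves like a dict of per-step lists: add_all appends one value per keyword argument, len() returns the number of steps added so far, and get_all_stacked returns the collected values, stacking numeric entries into arrays (compare the unit test in Example #8).

import numpy as np

class SimplePathBuilder(dict):
    """Minimal sketch of the behavior the surrounding examples rely on.
    Illustration only; not the actual PathBuilder implementation."""

    def __init__(self):
        super().__init__()
        self._num_steps = 0

    def add_all(self, **key_to_value):
        # Append one value per key for the current time step.
        for key, value in key_to_value.items():
            self.setdefault(key, []).append(value)
        self._num_steps += 1

    def __len__(self):
        # Number of time steps added; used for max_path_length checks.
        return self._num_steps

    def get_all_stacked(self):
        # Stack numeric entries into arrays; keep dict-valued entries
        # (agent_infos, env_infos) as plain lists.
        stacked = {}
        for key, values in self.items():
            if values and isinstance(values[0], dict):
                stacked[key] = values
            else:
                stacked[key] = np.array(values)
        return stacked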
Example #2
    def load_paths(self):
        paths = []
        for i in range(len(self.data)):
            p = self.data[i]
            H = len(p["observations"]) - 1

            path_builder = PathBuilder()

            for t in range(H):
                ob = p["observations"][t, :]
                action = p["actions"][t, :]
                reward = p["rewards"][t]
                next_ob = p["observations"][t + 1, :]
                terminal = 0
                agent_info = {}  # todo (need to unwrap each key)
                env_info = {}  # todo (need to unwrap each key)

                path_builder.add_all(
                    observations=ob,
                    actions=action,
                    rewards=reward,
                    next_observations=next_ob,
                    terminals=terminal,
                    agent_infos=agent_info,
                    env_infos=env_info,
                )

            path = path_builder.get_all_stacked()
            paths.append(path)
        return paths
Example #3
    def load_path(self, path, replay_buffer, obs_dict=None):
        rewards = []
        path_builder = PathBuilder()

        print("loading path, length", len(path["observations"]),
              len(path["actions"]))
        H = min(len(path["observations"]), len(path["actions"]))
        print("actions", np.min(path["actions"]), np.max(path["actions"]))

        for i in range(H):
            if obs_dict:
                ob = path["observations"][i][self.obs_key]
                next_ob = path["next_observations"][i][self.obs_key]
            else:
                ob = path["observations"][i]
                next_ob = path["next_observations"][i]

            if i == 0:
                current_obs = np.zeros((self.stack_obs + 1, len(ob)))
                current_obs[-2, :] = ob
                current_obs[-1, :] = next_ob
            else:
                current_obs = np.vstack((current_obs[1:, :], next_ob))
                assert (current_obs[-2, :] == ob
                        ).all(), "mismatch between obs and next_obs"
            obs1 = current_obs[:self.stack_obs, :].flatten()
            obs2 = current_obs[1:, :].flatten()

            action = path["actions"][i]
            reward = path["rewards"][i]
            terminal = path["terminals"][i]
            if not self.load_terminals:
                terminal = np.zeros(terminal.shape)
            agent_info = path["agent_infos"][i]
            env_info = path["env_infos"][i]

            if self.recompute_reward:
                reward = self.env.compute_reward(
                    action,
                    next_ob,
                )

            reward = np.array([reward])
            rewards.append(reward)
            terminal = np.array([terminal]).reshape((1, ))
            path_builder.add_all(
                observations=obs1,
                actions=action,
                rewards=reward,
                next_observations=obs2,
                terminals=terminal,
                agent_infos=agent_info,
                env_infos=env_info,
            )
        self.demo_trajectory_rewards.append(rewards)
        path = path_builder.get_all_stacked()
        replay_buffer.add_path(path)
        print("path sum rewards", sum(rewards), len(rewards))
Example #4
    def load_path(self, path, replay_buffer, obs_dict=None):
        rewards = []
        path_builder = PathBuilder()
        H = min(len(path["observations"]), len(path["actions"]))

        if obs_dict:
            traj_obs = self.preprocess(path["observations"])
            next_traj_obs = self.preprocess(path["next_observations"])
        else:
            traj_obs = self.env.encode(path["observations"])
            next_traj_obs = self.env.encode(path["next_observations"])

        for i in range(H):
            ob = traj_obs[i]
            next_ob = next_traj_obs[i]
            action = path["actions"][i]

            # #temp fix#
            # ob['state_desired_goal'] = np.zeros_like(ob['state_desired_goal'])
            # ob['latent_desired_goal'] = np.zeros_like(ob['latent_desired_goal'])

            # next_ob['state_desired_goal'] = np.zeros_like(next_ob['state_desired_goal'])
            # next_ob['latent_desired_goal'] = np.zeros_like(next_ob['latent_desired_goal'])

            # action[3] /= 5
            # #temp fix#

            reward = path["rewards"][i]
            terminal = path["terminals"][i]
            if not self.load_terminals:
                terminal = np.zeros(terminal.shape)
            agent_info = path["agent_infos"][i]
            env_info = path["env_infos"][i]
            if self.reward_fn:
                reward = self.reward_fn(ob, action, next_ob, next_ob)

            reward = np.array([reward]).flatten()
            rewards.append(reward)
            terminal = np.array([terminal]).reshape((1, ))
            path_builder.add_all(
                observations=ob,
                actions=action,
                rewards=reward,
                next_observations=next_ob,
                terminals=terminal,
                agent_infos=agent_info,
                env_infos=env_info,
            )
        self.demo_trajectory_rewards.append(rewards)
        path = path_builder.get_all_stacked()
        replay_buffer.add_path(path)
        print("rewards", np.min(rewards), np.max(rewards))
        print("loading path, length", len(path["observations"]),
              len(path["actions"]))
        print("actions", np.min(path["actions"]), np.max(path["actions"]))
        print("path sum rewards", sum(rewards), len(rewards))
Example #5
    def load_path(self, path, replay_buffer, obs_dict=None):
        # Filter data #
        if not self.data_filter_fn(path): return

        rewards = []
        path_builder = PathBuilder()

        print("loading path, length", len(path["observations"]),
              len(path["actions"]))
        H = min(len(path["observations"]), len(path["actions"]))
        print("actions", np.min(path["actions"]), np.max(path["actions"]))

        for i in range(H):
            if obs_dict:
                ob = path["observations"][i][self.obs_key]
                next_ob = path["next_observations"][i][self.obs_key]
            else:
                ob = path["observations"][i]
                next_ob = path["next_observations"][i]
            action = path["actions"][i]
            reward = path["rewards"][i]
            terminal = path["terminals"][i]
            if not self.load_terminals:
                terminal = np.zeros(terminal.shape)
            agent_info = path["agent_infos"][i]
            env_info = path["env_infos"][i]

            if self.recompute_reward:
                reward = self.env.compute_reward(
                    action,
                    next_ob,
                )

            reward = np.array([reward]).flatten()
            rewards.append(reward)
            terminal = np.array([terminal]).reshape((1, ))
            path_builder.add_all(
                observations=ob,
                actions=action,
                rewards=reward,
                next_observations=next_ob,
                terminals=terminal,
                agent_infos=agent_info,
                env_infos=env_info,
            )
        self.demo_trajectory_rewards.append(rewards)
        path = path_builder.get_all_stacked()
        replay_buffer.add_path(path)
        print("path sum rewards", sum(rewards), len(rewards))
Example #6
 def collect_new_steps(
         self,
         max_path_length,
         num_steps,
         discard_incomplete_paths,
         random=False,
 ):
     steps_collector = PathBuilder()
     for _ in range(num_steps):
         self.collect_one_step(
             max_path_length,
             discard_incomplete_paths,
             steps_collector,
             random,
         )
     return [steps_collector.get_all_stacked()]
Example #7
def rollout_path(env, task_params, obs_task_params, post_cond_policy, max_path_length, task_idx):
    cur_eval_path_builder = PathBuilder()
    
    # reset the env using the params
    observation = env.reset(task_params=task_params, obs_task_params=obs_task_params)
    terminal = False
    task_identifier = env.task_identifier

    while (not terminal) and len(cur_eval_path_builder) < max_path_length:
        agent_obs = observation['obs']
        action, agent_info = post_cond_policy.get_action(agent_obs)
        
        next_ob, raw_reward, terminal, env_info = (env.step(action))
        # img = env.render(mode='rgb_array', width=200, height=200)
        if len(cur_eval_path_builder) % 10 == 0:
            # img = env.render(mode='rgb_array')

            env._wrapped_env._get_viewer('rgb_array').render(200, 200, camera_id=0)
            # window size used for old mujoco-py:
            data = env._wrapped_env._get_viewer('rgb_array').read_pixels(200, 200, depth=False)
            # original image is upside-down, so flip it
            img = data[::-1, :, :]
            imsave('plots/walker_irl_frames/walker_task_%02d_step_%03d.png' % (task_idx, len(cur_eval_path_builder)), img)
        terminal = False

        # print(env_info['l2_dist'])
        # print('{}: {}'.format(agent_obs[-3:], env_info['l2_dist']))
        # print(agent_obs)
        # print(env_info['l2_dist'])
        
        reward = raw_reward
        terminal = np.array([terminal])
        reward = np.array([reward])
        cur_eval_path_builder.add_all(
            observations=observation,
            actions=action,
            rewards=reward,
            next_observations=next_ob,
            terminals=terminal,
            agent_infos=agent_info,
            env_infos=env_info,
            task_identifiers=task_identifier
        )
        observation = next_ob

    return cur_eval_path_builder.get_all_stacked()
Example #8
 def test_add_and_get_all(self):
     path = PathBuilder()
     path.add_all(
         action=np.array([1, 2, 3]),
         obs=-np.array([1, 2, 3]),
     )
     path.add_all(
         action=np.array([10, 2, 3]),
         obs=-np.array([10, 2, 3]),
     )
     result = path.get_all_stacked()
     self.assertNpArraysEqual(result['action'],
                              np.array([
                                  [1, 2, 3],
                                  [10, 2, 3],
                              ]))
     self.assertNpArraysEqual(result['obs'], -np.array([
         [1, 2, 3],
         [10, 2, 3],
     ]))
Example #9
    def load_path(self, path, replay_buffer):
        rewards = []
        path_builder = PathBuilder()

        print("loading path, length", len(path["observations"]),
              len(path["actions"]))
        H = min(len(path["observations"]), len(path["actions"]))
        print("actions", np.min(path["actions"]), np.max(path["actions"]))

        for i in range(H):
            ob = path["observations"][i]
            action = path["actions"][i]
            reward = path["rewards"][i]
            next_ob = path["next_observations"][i]
            terminal = path["terminals"][i]
            agent_info = path["agent_infos"][i]
            env_info = path["env_infos"][i]

            if self.recompute_reward:
                reward = self.env.compute_reward(
                    action,
                    next_ob,
                )

            reward = np.array([reward])
            rewards.append(reward)
            terminal = np.array([terminal]).reshape((1, ))
            path_builder.add_all(
                observations=ob,
                actions=action,
                rewards=reward,
                next_observations=next_ob,
                terminals=terminal,
                agent_infos=agent_info,
                env_infos=env_info,
            )
        self.demo_trajectory_rewards.append(rewards)
        path = path_builder.get_all_stacked()
        replay_buffer.add_path(path)
Example #10
def rollout_path(env, task_params, obs_task_params, post_cond_policy,
                 max_path_length, eval_expert, render):
    cur_eval_path_builder = PathBuilder()

    # reset the env using the params
    observation = env.reset(task_params=task_params,
                            obs_task_params=obs_task_params)
    terminal = False
    task_identifier = env.task_identifier

    while (not terminal) and len(cur_eval_path_builder) < max_path_length:
        agent_obs = observation['obs']
        action, agent_info = post_cond_policy.get_action(agent_obs)

        next_ob, raw_reward, terminal, env_info = (env.step(action))
        terminal = False

        # print(env_info['l2_dist'])
        # print('{}: {}'.format(agent_obs[-3:], env_info['l2_dist']))
        # print(agent_obs)
        # print(env_info['l2_dist'])

        reward = raw_reward
        terminal = np.array([terminal])
        reward = np.array([reward])
        cur_eval_path_builder.add_all(observations=observation,
                                      actions=action,
                                      rewards=reward,
                                      next_observations=next_ob,
                                      terminals=terminal,
                                      agent_infos=agent_info,
                                      env_infos=env_info,
                                      task_identifiers=task_identifier)
        observation = next_ob

        if render: env.render()

    return cur_eval_path_builder.get_all_stacked()
Example #11
 def load_path(self, path, replay_buffer):
     path_builder = PathBuilder()
     for (
             ob,
             action,
             reward,
             next_ob,
             terminal,
             agent_info,
             env_info,
     ) in zip(
             path["observations"],
             path["actions"],
             path["rewards"],
             path["next_observations"],
             path["terminals"],
             path["agent_infos"],
             path["env_infos"],
     ):
         # goal = path["goal"]["state_desired_goal"][0, :]
         # import pdb; pdb.set_trace()
         # print(goal.shape, ob["state_observation"])
         # state_observation = np.concatenate((ob["state_observation"], goal))
         action = action[:2]
         reward = np.array([reward])
         terminal = np.array([terminal])
         path_builder.add_all(
             observations=ob,
             actions=action,
             rewards=reward,
             next_observations=next_ob,
             terminals=terminal,
             agent_infos=agent_info,
             env_infos=env_info,
         )
     path = path_builder.get_all_stacked()
     replay_buffer.add_path(path)
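Many of the loaders in these examples finish with replay_buffer.add_path(path_builder.get_all_stacked()), but the buffer side is never shown. As a hedged sketch only (the helper name add_path_to_buffer is illustrative, and the real EnvReplayBuffer.add_path may differ), the consuming side can be modeled on the _handle_path loop in Example #13 below, which pushes one transition at a time via add_sample and then terminates the episode:

def add_path_to_buffer(replay_buffer, path):
    # Walk the stacked path one transition at a time, mirroring the
    # _handle_path pattern in Example #13.
    for ob, action, reward, next_ob, terminal, agent_info, env_info in zip(
            path["observations"],
            path["actions"],
            path["rewards"],
            path["next_observations"],
            path["terminals"],
            path["agent_infos"],
            path["env_infos"],
    ):
        replay_buffer.add_sample(
            observation=ob,
            action=action,
            reward=reward,
            next_observation=next_ob,
            terminal=terminal,
            agent_info=agent_info,
            env_info=env_info,
        )
    replay_buffer.terminate_episode()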
Example #12
def take_step_in_env_per_thread(pid, queue, env, policy, render, reward_scale,
                                steps, max_path_length, n_env_steps_total):
    set_seed(pid)
    n_rollouts_total = 0
    current_path_builder = PathBuilder()
    exploration_paths = []
    replay_samples = {
        'observations': [],
        'actions': [],
        'rewards': [],
        'next_observations': [],
        'terminals': [],
        'agent_infos': [],
        'env_infos': [],
    }

    policy.reset()
    observation = env.reset()
    policy.set_num_steps_total(n_env_steps_total)

    for _ in range(steps):

        action, agent_info = policy.get_action(observation)
        if pid == 0 and render:
            env.render()
        next_ob, raw_reward, terminal, env_info = env.step(action)
        reward = np.array([raw_reward * reward_scale])
        terminal = np.array([terminal])

        replay_samples['observations'].append(observation)
        replay_samples['actions'].append(action)
        replay_samples['rewards'].append(reward)
        replay_samples['next_observations'].append(next_ob)
        replay_samples['terminals'].append(terminal)
        replay_samples['agent_infos'].append(agent_info)
        replay_samples['env_infos'].append(env_info)

        current_path_builder.add_all(
            observations=observation,
            actions=action,
            rewards=reward,
            next_observations=next_ob,
            terminals=terminal,
            agent_infos=agent_info,
            env_infos=env_info,
        )

        if terminal or len(current_path_builder) >= max_path_length:
            # cannot let replay buffer terminate episode
            n_rollouts_total += 1
            if len(current_path_builder) > 0:
                exploration_paths.append(
                    current_path_builder.get_all_stacked())
                current_path_builder = PathBuilder()
            policy.reset()
            observation = env.reset()
        else:
            observation = next_ob

    if queue is None:
        return exploration_paths, replay_samples, n_rollouts_total
    else:
        queue.put([pid, exploration_paths, replay_samples, n_rollouts_total])
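Example #12's worker either returns its results directly or puts them on a queue keyed by pid. A hypothetical driver for it is sketched below (collect_in_parallel is an illustrative name, not part of the original code; it assumes env and policy can be pickled so each child process gets its own copy):

import multiprocessing as mp

def collect_in_parallel(num_workers, env, policy, steps, max_path_length,
                        n_env_steps_total, reward_scale=1.0, render=False):
    # Launch one worker per pid, drain the queue, then join the processes.
    queue = mp.Queue()
    workers = [
        mp.Process(
            target=take_step_in_env_per_thread,
            args=(pid, queue, env, policy, render, reward_scale,
                  steps, max_path_length, n_env_steps_total),
        )
        for pid in range(num_workers)
    ]
    for w in workers:
        w.start()

    all_paths, all_samples, total_rollouts = [], [], 0
    for _ in workers:
        pid, exploration_paths, replay_samples, n_rollouts = queue.get()
        all_paths.extend(exploration_paths)
        all_samples.append(replay_samples)
        total_rollouts += n_rollouts

    for w in workers:
        w.join()
    return all_paths, all_samples, total_rollouts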
Example #13
class RLAlgorithm(metaclass=abc.ABCMeta):
    def __init__(
            self,
            env,
            exploration_policy: ExplorationPolicy,
            training_env=None,
            num_epochs=100,
            num_steps_per_epoch=10000,
            num_steps_per_eval=1000,
            num_updates_per_env_step=1,
            max_num_episodes=None,
            batch_size=1024,
            max_path_length=1000,
            discount=0.99,
            replay_buffer_size=1000000,
            reward_scale=1,
            render=False,
            save_replay_buffer=False,
            save_algorithm=False,
            save_environment=False,
            save_best=False,
            save_best_starting_from_epoch=0,
            eval_sampler=None,
            eval_policy=None,
            replay_buffer=None,
            # for compatibility with deepmind control suite
            # Right now the semantics are: if the observation is not a
            # dictionary, the policy just uses it directly. If it is a
            # dictionary, policy_uses_pixels decides whether the policy
            # takes 'pixels' or 'obs' from the dictionary.
            policy_uses_pixels=False,
            freq_saving=1,
            # for meta-learning
            policy_uses_task_params=False, # whether the policy uses the task parameters
            concat_task_params_to_policy_obs=False, # how the policy sees the task parameters
            # this is useful when you want to generate trajectories from the expert using the
            # exploration policy
            do_not_train=False,
            # some environments, like halfcheetah_v2, have a time limit that defines the terminal;
            # this is used as a minor hack to turn off time limits
            no_terminal=False,
            **kwargs
    ):
        """
        Base class for RL Algorithms
        :param env: Environment used to evaluate.
        :param exploration_policy: Policy used to explore
        :param training_env: Environment used by the algorithm. By default, a
        copy of `env` will be made.
        :param num_epochs:
        :param num_steps_per_epoch:
        :param num_steps_per_eval:
        :param num_updates_per_env_step: Used by online training mode.
        :param num_updates_per_epoch: Used by batch training mode.
        :param batch_size:
        :param max_path_length:
        :param discount:
        :param replay_buffer_size:
        :param reward_scale:
        :param render:
        :param save_replay_buffer:
        :param save_algorithm:
        :param save_environment:
        :param eval_sampler:
        :param eval_policy: Policy to evaluate with.
        :param replay_buffer:
        """
        self.training_env = training_env or pickle.loads(pickle.dumps(env))
        # self.training_env = training_env or deepcopy(env)
        self.exploration_policy = exploration_policy
        self.num_epochs = num_epochs
        self.num_env_steps_per_epoch = num_steps_per_epoch
        self.num_steps_per_eval = num_steps_per_eval
        self.num_updates_per_train_call = num_updates_per_env_step
        self.batch_size = batch_size
        self.max_path_length = max_path_length
        self.discount = discount
        self.replay_buffer_size = replay_buffer_size
        self.reward_scale = reward_scale
        self.render = render
        self.save_replay_buffer = save_replay_buffer
        self.save_algorithm = save_algorithm
        self.save_environment = save_environment
        self.save_best = save_best
        self.save_best_starting_from_epoch = save_best_starting_from_epoch
        self.policy_uses_pixels = policy_uses_pixels
        self.policy_uses_task_params = policy_uses_task_params
        self.concat_task_params_to_policy_obs = concat_task_params_to_policy_obs
        self.freq_saving = freq_saving
        if eval_sampler is None:
            if eval_policy is None:
                eval_policy = exploration_policy
            eval_sampler = InPlacePathSampler(
                env=env,
                policy=eval_policy,
                max_samples=self.num_steps_per_eval + self.max_path_length,
                max_path_length=self.max_path_length, policy_uses_pixels=policy_uses_pixels,
                policy_uses_task_params=policy_uses_task_params,
                concat_task_params_to_policy_obs=concat_task_params_to_policy_obs
            )
        self.eval_policy = eval_policy
        self.eval_sampler = eval_sampler

        self.action_space = env.action_space
        self.obs_space = env.observation_space
        self.env = env
        if replay_buffer is None:
            replay_buffer = EnvReplayBuffer(
                self.replay_buffer_size,
                self.env,
                policy_uses_pixels=self.policy_uses_pixels,
                policy_uses_task_params=self.policy_uses_task_params,
                concat_task_params_to_policy_obs=self.concat_task_params_to_policy_obs
            )
        self.replay_buffer = replay_buffer

        self._n_env_steps_total = 0
        self._n_train_steps_total = 0
        self._n_rollouts_total = 0
        self._do_train_time = 0
        self._epoch_start_time = None
        self._algo_start_time = None
        self._old_table_keys = None
        self._current_path_builder = PathBuilder()
        self._exploration_paths = []
        self.do_not_train = do_not_train
        self.num_episodes = 0
        self.max_num_episodes = max_num_episodes if max_num_episodes is not None else float('inf')
        self.no_terminal = no_terminal

    def train(self, start_epoch=0):
        self.pretrain()
        if start_epoch == 0:
            params = self.get_epoch_snapshot(-1)
            logger.save_itr_params(-1, params)
        self.training_mode(False)
        self._n_env_steps_total = start_epoch * self.num_env_steps_per_epoch
        gt.reset()
        gt.set_def_unique(False)
        self.train_online(start_epoch=start_epoch)

    def pretrain(self):
        """
        Do anything before the main training phase.
        """
        pass

    def train_online(self, start_epoch=0):
        self._current_path_builder = PathBuilder()
        observation = self._start_new_rollout()
        for epoch in gt.timed_for(
                range(start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            self._start_epoch(epoch)
            for _ in range(self.num_env_steps_per_epoch):
                # we are assuming that if it's a dict then it has
                # pixels and obs, and maybe obs_task_params
                if isinstance(self.obs_space, Dict):
                    if self.policy_uses_pixels:
                        agent_obs = observation['pixels']
                    else:
                        agent_obs = observation['obs']
                else:
                    agent_obs = observation
                if self.policy_uses_task_params:
                    task_params = observation['obs_task_params']
                    if self.concat_task_params_to_policy_obs:
                        agent_obs = np.concatenate((agent_obs, task_params), -1)
                    else:
                        agent_obs = {'obs': agent_obs, 'obs_task_params': task_params}
                action, agent_info = self._get_action_and_info(
                    agent_obs,
                )
                if self.render:
                    self.training_env.render()
                next_ob, raw_reward, terminal, env_info = (
                    self.training_env.step(action)
                )
                if self.no_terminal:
                    terminal = False
                self._n_env_steps_total += 1
                reward = raw_reward * self.reward_scale
                terminal = np.array([terminal])
                reward = np.array([reward])
                self._handle_step(
                    observation,
                    action,
                    reward,
                    next_ob,
                    terminal,
                    agent_info=agent_info,
                    env_info=env_info,
                )
                if terminal or len(self._current_path_builder) >= self.max_path_length:
                    self._handle_rollout_ending()
                    observation = self._start_new_rollout()
                else:
                    observation = next_ob

                gt.stamp('sample')
                if not self.do_not_train: self._try_to_train()
                gt.stamp('train')

                if self.num_episodes > self.max_num_episodes:
                    self._try_to_eval(epoch)
                    gt.stamp('eval')
                    self._end_epoch()
                    return

            self._try_to_eval(epoch)
            gt.stamp('eval')
            self._end_epoch()

    def _try_to_train(self):
        if self._can_train():
            self.training_mode(True)
            for i in range(self.num_updates_per_train_call):
                self._do_training()
                self._n_train_steps_total += 1
            self.training_mode(False)

    def _try_to_eval(self, epoch):
        if epoch % self.freq_saving == 0:
            logger.save_extra_data(self.get_extra_data_to_save(epoch))
        if self._can_evaluate():
            self.evaluate(epoch)

            if epoch % self.freq_saving == 0:
                params = self.get_epoch_snapshot(epoch)
                logger.save_itr_params(epoch, params)
            table_keys = logger.get_table_key_set()
            if self._old_table_keys is not None:
                # print('$$$$$$$$$$$$$$$')
                # print(table_keys)
                # print('\n'*4)
                # print(self._old_table_keys)
                # print('$$$$$$$$$$$$$$$')
                # print(set(table_keys) - set(self._old_table_keys))
                # print(set(self._old_table_keys) - set(table_keys))
                assert table_keys == self._old_table_keys, (
                    "Table keys cannot change from iteration to iteration."
                )
            self._old_table_keys = table_keys

            logger.record_tabular(
                "Number of train steps total",
                self._n_train_steps_total,
            )
            logger.record_tabular(
                "Number of env steps total",
                self._n_env_steps_total,
            )
            logger.record_tabular(
                "Number of rollouts total",
                self._n_rollouts_total,
            )

            times_itrs = gt.get_times().stamps.itrs
            train_time = times_itrs['train'][-1]
            sample_time = times_itrs['sample'][-1]
            eval_time = times_itrs['eval'][-1] if epoch > 0 else 0
            epoch_time = train_time + sample_time + eval_time
            total_time = gt.get_times().total

            logger.record_tabular('Train Time (s)', train_time)
            logger.record_tabular('(Previous) Eval Time (s)', eval_time)
            logger.record_tabular('Sample Time (s)', sample_time)
            logger.record_tabular('Epoch Time (s)', epoch_time)
            logger.record_tabular('Total Train Time (s)', total_time)

            logger.record_tabular("Epoch", epoch)
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        else:
            logger.log("Skipping eval for now.")

    def _can_evaluate(self):
        """
        One annoying thing about the logger table is that the keys at each
        iteration need to be the exact same. So unless you can compute
        everything, skip evaluation.

        A common example for why you might want to skip evaluation is that at
        the beginning of training, you may not have enough data for a
        validation and training set.

        :return:
        """
        return (
            len(self._exploration_paths) > 0
            and self.replay_buffer.num_steps_can_sample() >= self.batch_size
        )

    def _can_train(self):
        return self.replay_buffer.num_steps_can_sample() >= self.batch_size

    def _get_action_and_info(self, observation):
        """
        Get an action to take in the environment.
        :param observation:
        :return:
        """
        self.exploration_policy.set_num_steps_total(self._n_env_steps_total)
        return self.exploration_policy.get_action(
            observation,
        )

    def _start_epoch(self, epoch):
        self._epoch_start_time = time.time()
        self._exploration_paths = []
        self._do_train_time = 0
        logger.push_prefix('Iteration #%d | ' % epoch)

    def _end_epoch(self):
        logger.log("Epoch Duration: {0}".format(
            time.time() - self._epoch_start_time
        ))
        logger.log("Started Training: {0}".format(self._can_train()))
        logger.pop_prefix()

    def _start_new_rollout(self):
        self.num_episodes += 1
        self.exploration_policy.reset()
        return self.training_env.reset()

    def _handle_path(self, path):
        """
        Naive implementation: just loop through each transition.
        :param path:
        :return:
        """
        for (
            ob,
            action,
            reward,
            next_ob,
            terminal,
            agent_info,
            env_info
        ) in zip(
            path["observations"],
            path["actions"],
            path["rewards"],
            path["next_observations"],
            path["terminals"],
            path["agent_infos"],
            path["env_infos"],
        ):
            self._handle_step(
                ob,
                action,
                reward,
                next_ob,
                terminal,
                agent_info=agent_info,
                env_info=env_info,
            )
        self._handle_rollout_ending()

    def _handle_step(
            self,
            observation,
            action,
            reward,
            next_observation,
            terminal,
            agent_info,
            env_info,
    ):
        """
        Implement anything that needs to happen after every step
        :return:
        """
        self._current_path_builder.add_all(
            observations=observation,
            actions=action,
            rewards=reward,
            next_observations=next_observation,
            terminals=terminal,
            agent_infos=agent_info,
            env_infos=env_info,
        )
        self.replay_buffer.add_sample(
            observation=observation,
            action=action,
            reward=reward,
            terminal=terminal,
            next_observation=next_observation,
            agent_info=agent_info,
            env_info=env_info,
        )

    def _handle_rollout_ending(self):
        """
        Implement anything that needs to happen after every rollout.
        """
        self.replay_buffer.terminate_episode()
        self._n_rollouts_total += 1
        if len(self._current_path_builder) > 0:
            self._exploration_paths.append(
                self._current_path_builder.get_all_stacked()
            )
            self._current_path_builder = PathBuilder()

    def get_epoch_snapshot(self, epoch):
        data_to_save = dict(
            epoch=epoch,
            exploration_policy=self.exploration_policy,
        )
        if self.save_environment:
            data_to_save['env'] = self.training_env
        return data_to_save

    def get_extra_data_to_save(self, epoch):
        """
        Save things that shouldn't be saved every snapshot but rather
        overwritten every time.
        :param epoch:
        :return:
        """
        if self.render:
            self.training_env.render(close=True)
        data_to_save = dict(
            epoch=epoch,
        )
        if self.save_environment:
            data_to_save['env'] = self.training_env
        if self.save_replay_buffer:
            data_to_save['replay_buffer'] = self.replay_buffer
        if self.save_algorithm:
            data_to_save['algorithm'] = self
        return data_to_save

    @abc.abstractmethod
    def training_mode(self, mode):
        """
        Set training mode to `mode`.
        :param mode: If True, training will happen (e.g. set the dropout
        probabilities to not all ones).
        """
        pass

    @abc.abstractmethod
    def cuda(self):
        """
        Turn cuda on.
        :return:
        """
        pass

    @abc.abstractmethod
    def evaluate(self, epoch):
        """
        Evaluate the policy, e.g. save/print progress.
        :param epoch:
        :return:
        """
        pass

    @abc.abstractmethod
    def _do_training(self):
        """
        Perform some update, e.g. perform one gradient step.
        :return:
        """
        pass
Example #14
class RLAlgorithm(metaclass=abc.ABCMeta):
    def __init__(
        self,
        env,
        exploration_policy: ExplorationPolicy,
        training_env=None,
        num_epochs=100,
        num_steps_per_epoch=10000,
        num_steps_per_eval=1000,
        num_updates_per_env_step=1,
        batch_size=1024,
        max_path_length=1000,
        discount=0.99,
        replay_buffer_size=1000000,
        reward_scale=1,
        render=False,
        save_replay_buffer=False,
        save_algorithm=False,
        save_environment=True,
        eval_sampler=None,
        eval_policy=None,
        replay_buffer=None,
    ):
        """
        Base class for RL Algorithms
        :param env: Environment used to evaluate.
        :param exploration_policy: Policy used to explore
        :param training_env: Environment used by the algorithm. By default, a
        copy of `env` will be made.
        :param num_epochs:
        :param num_steps_per_epoch:
        :param num_steps_per_eval:
        :param num_updates_per_env_step: Used by online training mode.
        :param num_updates_per_epoch: Used by batch training mode.
        :param batch_size:
        :param max_path_length:
        :param discount:
        :param replay_buffer_size:
        :param reward_scale:
        :param render:
        :param save_replay_buffer:
        :param save_algorithm:
        :param save_environment:
        :param eval_sampler:
        :param eval_policy: Policy to evaluate with.
        :param replay_buffer:
        """
        self.training_env = training_env or pickle.loads(pickle.dumps(env))
        self.exploration_policy = exploration_policy
        self.num_epochs = num_epochs
        self.num_env_steps_per_epoch = num_steps_per_epoch
        self.num_steps_per_eval = num_steps_per_eval
        self.num_updates_per_train_call = num_updates_per_env_step
        self.batch_size = batch_size
        self.max_path_length = max_path_length
        self.discount = discount
        self.replay_buffer_size = replay_buffer_size
        self.reward_scale = reward_scale
        self.render = render
        self.save_replay_buffer = save_replay_buffer
        self.save_algorithm = save_algorithm
        self.save_environment = save_environment
        self.num_skills = 5  # added the num skills right here!!
        self.pz = np.full(self.num_skills, 1. / self.num_skills)
        #self.curr_z = self.sample_z()
        if eval_sampler is None:
            if eval_policy is None:
                eval_policy = exploration_policy
            eval_sampler = InPlacePathSampler(
                env=env,
                policy=eval_policy,
                max_samples=self.num_steps_per_eval + self.max_path_length,
                max_path_length=self.max_path_length,
            )
        self.eval_policy = eval_policy
        self.eval_sampler = eval_sampler

        self.action_space = env.action_space
        self.obs_space = env.observation_space
        self.env = env
        if replay_buffer is None:
            replay_buffer = EnvReplayBuffer(
                self.replay_buffer_size,
                self.env,
            )
        self.replay_buffer = replay_buffer

        self._n_env_steps_total = 0
        self._n_train_steps_total = 0
        self._n_rollouts_total = 0
        self._do_train_time = 0
        self._epoch_start_time = None
        self._algo_start_time = None
        self._old_table_keys = None
        self._current_path_builder = PathBuilder()
        self._exploration_paths = []

    def train(self, start_epoch=0):
        self.pretrain()
        if start_epoch == 0:
            params = self.get_epoch_snapshot(-1)
            logger.save_itr_params(-1, params)
        self.training_mode(False)
        self._n_env_steps_total = start_epoch * self.num_env_steps_per_epoch
        gt.reset()
        gt.set_def_unique(False)
        self.train_online(start_epoch=start_epoch)

    def pretrain(self):
        """
        Do anything before the main training phase.
        """
        pass

    ''' TODO: Write a function for sample z here'''

    def sample_z(self):
        ''' sample z'''
        dummy = np.zeros((self.num_skills))
        dummy[np.random.choice(self.num_skills, p=self.pz)] = 1
        # pdb.set_trace()

        return dummy

    ''' TODO: concat function'''

    def concat_state_z(self, state, z):
        return np.concatenate([state, z], axis=0)

    def train_online(self, start_epoch=0):
        self._current_path_builder = PathBuilder()
        observation = self._start_new_rollout()
        #observation = self.concat_state_z(state, self.curr_z)

        for epoch in gt.timed_for(
                range(start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            self._start_epoch(epoch)
            for _ in range(self.num_env_steps_per_epoch):
                ''' TODO'''
                ''' append the latent variable here'''
                action, agent_info = self._get_action_and_info(observation, )
                if self.render:
                    self.training_env.render()

                next_state, raw_reward, terminal, env_info = (
                    self.training_env.step(action))

                # print (terminal)
                next_ob = self.concat_state_z(next_state, self.curr_z)
                self._n_env_steps_total += 1
                reward = raw_reward * self.reward_scale
                terminal = np.array([terminal])
                reward = np.array([reward])
                self._handle_step(
                    observation,
                    action,
                    reward,
                    next_ob,
                    terminal,
                    agent_info=agent_info,
                    env_info=env_info,
                )
                if terminal or len(
                        self._current_path_builder) >= self.max_path_length:
                    self._handle_rollout_ending()
                    observation = self._start_new_rollout()
                    #print ('starting new rollout')
                else:
                    observation = next_ob

                gt.stamp('sample')
                self._try_to_train()
                gt.stamp('train')

            # need to fix the evaluation here..figure this out!!
            # self._try_to_eval(epoch)
            # gt.stamp('eval')
            # self._end_epoch()

    def _try_to_train(self):
        if self._can_train():
            self.training_mode(True)
            for i in range(self.num_updates_per_train_call):
                self._do_training()
                self._n_train_steps_total += 1
            self.training_mode(False)

    def _try_to_eval(self, epoch):
        logger.save_extra_data(self.get_extra_data_to_save(epoch))
        if self._can_evaluate():
            self.evaluate(epoch)

            params = self.get_epoch_snapshot(epoch)
            logger.save_itr_params(epoch, params)
            table_keys = logger.get_table_key_set()
            if self._old_table_keys is not None:
                assert table_keys == self._old_table_keys, (
                    "Table keys cannot change from iteration to iteration.")
            self._old_table_keys = table_keys

            logger.record_tabular(
                "Number of train steps total",
                self._n_train_steps_total,
            )
            logger.record_tabular(
                "Number of env steps total",
                self._n_env_steps_total,
            )
            logger.record_tabular(
                "Number of rollouts total",
                self._n_rollouts_total,
            )

            times_itrs = gt.get_times().stamps.itrs
            train_time = times_itrs['train'][-1]
            sample_time = times_itrs['sample'][-1]
            eval_time = times_itrs['eval'][-1] if epoch > 0 else 0
            epoch_time = train_time + sample_time + eval_time
            total_time = gt.get_times().total

            logger.record_tabular('Train Time (s)', train_time)
            logger.record_tabular('(Previous) Eval Time (s)', eval_time)
            logger.record_tabular('Sample Time (s)', sample_time)
            logger.record_tabular('Epoch Time (s)', epoch_time)
            logger.record_tabular('Total Train Time (s)', total_time)

            logger.record_tabular("Epoch", epoch)
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        else:
            logger.log("Skipping eval for now.")

    def _can_evaluate(self):
        """
        One annoying thing about the logger table is that the keys at each
        iteration need to be the exact same. So unless you can compute
        everything, skip evaluation.

        A common example for why you might want to skip evaluation is that at
        the beginning of training, you may not have enough data for a
        validation and training set.

        :return:
        """
        return (len(self._exploration_paths) > 0 and
                self.replay_buffer.num_steps_can_sample() >= self.batch_size)

    def _can_train(self):
        return self.replay_buffer.num_steps_can_sample() >= self.batch_size

    def _get_action_and_info(self, observation):
        """
        Get an action to take in the environment.
        :param observation:
        :return:
        """
        #print (observation.shape)
        self.exploration_policy.set_num_steps_total(self._n_env_steps_total)
        return self.exploration_policy.get_action(observation, )

    def _start_epoch(self, epoch):
        self._epoch_start_time = time.time()
        self._exploration_paths = []
        self._do_train_time = 0
        logger.push_prefix('Iteration #%d | ' % epoch)

    def _end_epoch(self):
        logger.log("Epoch Duration: {0}".format(time.time() -
                                                self._epoch_start_time))
        logger.log("Started Training: {0}".format(self._can_train()))
        logger.pop_prefix()

    def _start_new_rollout(self):
        self.curr_z = self.sample_z()
        self.exploration_policy.reset()
        return self.concat_state_z(self.training_env.reset(), self.curr_z)

    def _handle_path(self, path):
        """
        Naive implementation: just loop through each transition.
        :param path:
        :return:
        """
        for (ob, action, reward, next_ob, terminal, agent_info,
             env_info) in zip(
                 path["observations"],
                 path["actions"],
                 path["rewards"],
                 path["next_observations"],
                 path["terminals"],
                 path["agent_infos"],
                 path["env_infos"],
             ):
            self._handle_step(
                ob,
                action,
                reward,
                next_ob,
                terminal,
                agent_info=agent_info,
                env_info=env_info,
            )
        self._handle_rollout_ending()

    def _handle_step(
        self,
        observation,
        action,
        reward,
        next_observation,
        terminal,
        agent_info,
        env_info,
    ):
        """
        Implement anything that needs to happen after every step
        :return:
        """
        self._current_path_builder.add_all(
            observations=observation,
            actions=action,
            rewards=reward,
            next_observations=next_observation,
            terminals=terminal,
            agent_infos=agent_info,
            env_infos=env_info,
        )
        self.replay_buffer.add_sample(
            observation=observation,
            action=action,
            reward=reward,
            terminal=terminal,
            next_observation=next_observation,
            agent_info=agent_info,
            env_info=env_info,
        )

    def _handle_rollout_ending(self):
        """
        Implement anything that needs to happen after every rollout.
        """
        self.replay_buffer.terminate_episode()
        self._n_rollouts_total += 1
        if len(self._current_path_builder) > 0:
            self._exploration_paths.append(
                self._current_path_builder.get_all_stacked())
            self._current_path_builder = PathBuilder()

    def get_epoch_snapshot(self, epoch):
        if self.render:
            self.training_env.render(close=True)
        data_to_save = dict(
            epoch=epoch,
            exploration_policy=self.exploration_policy,
        )
        if self.save_environment:
            data_to_save['env'] = self.training_env
        return data_to_save

    def get_extra_data_to_save(self, epoch):
        """
        Save things that shouldn't be saved every snapshot but rather
        overwritten every time.
        :param epoch:
        :return:
        """
        if self.render:
            self.training_env.render(close=True)
        data_to_save = dict(epoch=epoch, )
        if self.save_environment:
            data_to_save['env'] = self.training_env
        if self.save_replay_buffer:
            data_to_save['replay_buffer'] = self.replay_buffer
        if self.save_algorithm:
            data_to_save['algorithm'] = self
        return data_to_save

    @abc.abstractmethod
    def training_mode(self, mode):
        """
        Set training mode to `mode`.
        :param mode: If True, training will happen (e.g. set the dropout
        probabilities to not all ones).
        """
        pass

    @abc.abstractmethod
    def cuda(self):
        """
        Turn cuda on.
        :return:
        """
        pass

    @abc.abstractmethod
    def evaluate(self, epoch):
        """
        Evaluate the policy, e.g. save/print progress.
        :param epoch:
        :return:
        """
        pass

    @abc.abstractmethod
    def _do_training(self):
        """
        Perform some update, e.g. perform one gradient step.
        :return:
        """
        pass
Example #15
class RLAlgorithm(metaclass=abc.ABCMeta):
    def __init__(
        self,
        env,
        exploration_policy: ExplorationPolicy,
        training_env=None,
        num_epochs=100,
        num_steps_per_epoch=10000,
        num_steps_per_eval=1000,
        num_updates_per_env_step=1,
        batch_size=1024,
        max_path_length=1000,
        discount=0.99,
        replay_buffer_size=1000000,
        reward_scale=1,
        render=False,
        save_replay_buffer=False,
        save_algorithm=False,
        save_environment=False,
        eval_sampler=None,
        eval_policy=None,
        replay_buffer=None,
        demo_path=None,
        action_skip=1,
        experiment_name="default",
        mix_demo=False,
    ):
        """
        Base class for RL Algorithms
        :param env: Environment used to evaluate.
        :param exploration_policy: Policy used to explore
        :param training_env: Environment used by the algorithm. By default, a
        copy of `env` will be made.
        :param num_epochs:
        :param num_steps_per_epoch:
        :param num_steps_per_eval:
        :param num_updates_per_env_step: Used by online training mode.
        :param num_updates_per_epoch: Used by batch training mode.
        :param batch_size:
        :param max_path_length:
        :param discount:
        :param replay_buffer_size:
        :param reward_scale:
        :param render:
        :param save_replay_buffer:
        :param save_algorithm:
        :param save_environment:
        :param eval_sampler:
        :param eval_policy: Policy to evaluate with.
        :param replay_buffer:
        """

        ### TODO: look at NormalizedBoxEnv, do we need it? ###

        # self.training_env = training_env or gym.make("HalfCheetah-v2")
        self.training_env = training_env or MujocoManipEnv(
            env.env.__class__.__name__)
        self.exploration_policy = exploration_policy
        self.num_epochs = num_epochs
        self.num_env_steps_per_epoch = num_steps_per_epoch
        self.num_steps_per_eval = num_steps_per_eval
        self.num_updates_per_train_call = num_updates_per_env_step
        self.batch_size = batch_size
        self.max_path_length = max_path_length
        self.discount = discount
        self.replay_buffer_size = replay_buffer_size
        self.reward_scale = reward_scale
        self.render = render
        self.save_replay_buffer = save_replay_buffer
        self.save_algorithm = save_algorithm
        self.save_environment = save_environment
        if eval_sampler is None:
            if eval_policy is None:
                eval_policy = exploration_policy
            eval_sampler = InPlacePathSampler(
                env=env,
                policy=eval_policy,
                max_samples=self.num_steps_per_eval + self.max_path_length,
                max_path_length=self.max_path_length,
            )
        self.eval_policy = eval_policy
        self.eval_sampler = eval_sampler

        self.action_space = env.action_space
        self.obs_space = env.observation_space
        self.env = env
        if replay_buffer is None:
            replay_buffer = EnvReplayBuffer(
                self.replay_buffer_size,
                self.env,
            )
        self.replay_buffer = replay_buffer

        self.demo_sampler = None
        self.mix_demo = mix_demo
        if demo_path is not None:
            self.demo_sampler = DemoSampler(
                demo_path=demo_path,
                observation_dim=self.obs_space.shape[0],
                action_dim=self.action_space.shape[0],
                preload=True)
        self.action_skip = action_skip
        self.action_skip_count = 0

        self._n_env_steps_total = 0
        self._n_train_steps_total = 0
        self._n_rollouts_total = 0
        self._do_train_time = 0
        self._epoch_start_time = None
        self._algo_start_time = None
        self._old_table_keys = None
        self._current_path_builder = PathBuilder()
        self._exploration_paths = []

        t_now = time.time()
        time_str = datetime.datetime.fromtimestamp(t_now).strftime(
            '%Y%m%d%H%M%S')
        os.makedirs(os.path.join(LOCAL_EXP_PATH, experiment_name, time_str))
        self._writer = SummaryWriter(
            os.path.join(LOCAL_EXP_PATH, experiment_name, time_str))

    def train(self, start_epoch=0):
        self.pretrain()
        if start_epoch == 0:
            params = self.get_epoch_snapshot(-1)
            #logger.save_itr_params(-1, params)
        self.training_mode(False)
        self._n_env_steps_total = start_epoch * self.num_env_steps_per_epoch
        gt.reset()
        gt.set_def_unique(False)
        self.train_online(start_epoch=start_epoch)

    def pretrain(self):
        """
        Do anything before the main training phase.
        """
        pass

    def train_online(self, start_epoch=0):
        self._current_path_builder = PathBuilder()
        observation = self._start_new_rollout()
        for epoch in gt.timed_for(
                range(start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            self._start_epoch(epoch)
            for _ in range(self.num_env_steps_per_epoch):
                action, agent_info = self._get_action_and_info(observation, )
                if self.render:
                    self.training_env.render()
                next_ob, raw_reward, terminal, env_info = (
                    self.training_env.step(action))
                self._n_env_steps_total += 1
                reward = raw_reward * self.reward_scale
                terminal = np.array([terminal])
                reward = np.array([reward])
                self._handle_step(
                    observation,
                    action,
                    reward,
                    next_ob,
                    terminal,
                    agent_info=agent_info,
                    env_info=env_info,
                )
                if terminal or len(
                        self._current_path_builder) >= self.max_path_length:
                    self._handle_rollout_ending()
                    observation = self._start_new_rollout()
                else:
                    observation = next_ob

                gt.stamp('sample')
                self._try_to_train()
                gt.stamp('train')

            self._try_to_eval(epoch)
            gt.stamp('eval')
            self._end_epoch()

    def _try_to_train(self):
        if self._can_train():
            self.training_mode(True)
            for i in range(self.num_updates_per_train_call):
                self._do_training()
                self._n_train_steps_total += 1
            self.training_mode(False)

    def _try_to_eval(self, epoch):
        logger.save_extra_data(self.get_extra_data_to_save(epoch))
        if self._can_evaluate():
            self.evaluate(epoch)

            params = self.get_epoch_snapshot(epoch)
            logger.save_itr_params(epoch, params)
            table_keys = logger.get_table_key_set()
            #print("TABLE KEYS")
            #print(table_keys)
            #if self._old_table_keys is not None:
            #    assert table_keys == self._old_table_keys, (
            #        "Table keys cannot change from iteration to iteration."
            #    )
            self._old_table_keys = table_keys

            logger.record_tabular(
                "Number of train steps total",
                self._n_train_steps_total,
            )
            logger.record_tabular(
                "Number of env steps total",
                self._n_env_steps_total,
            )
            logger.record_tabular(
                "Number of rollouts total",
                self._n_rollouts_total,
            )

            times_itrs = gt.get_times().stamps.itrs
            train_time = times_itrs['train'][-1]
            sample_time = times_itrs['sample'][-1]
            eval_time = times_itrs['eval'][-1] if epoch > 0 else 0
            epoch_time = train_time + sample_time + eval_time
            total_time = gt.get_times().total

            logger.record_tabular('Train Time (s)', train_time)
            logger.record_tabular('(Previous) Eval Time (s)', eval_time)
            logger.record_tabular('Sample Time (s)', sample_time)
            logger.record_tabular('Epoch Time (s)', epoch_time)
            logger.record_tabular('Total Train Time (s)', total_time)

            logger.record_tabular("Epoch", epoch)

            # TensorBoard logging: route each tabular entry to a scalar group
            # based on its key. Non-numeric values are skipped.
            _writer = self._writer
            for k, v_str in logger._tabular:
                if k == 'Epoch':
                    continue
                try:
                    v = float(v_str)
                except (TypeError, ValueError):
                    continue
                if k.endswith('Loss'):
                    _writer.add_scalar('Loss/{}'.format(k), v, epoch)
                elif k.endswith(('Max', 'Min', 'Std')):
                    prefix = k[:-4]
                    _writer.add_scalar('{}/{}'.format(prefix, k), v, epoch)
                elif k.endswith('Mean'):
                    prefix = k[:-5]
                    _writer.add_scalar('{}/{}'.format(prefix, k), v, epoch)
                elif 'Time' in k:
                    _writer.add_scalar('Time/{}'.format(k), v, epoch)
                elif k.startswith('Num'):
                    _writer.add_scalar('Number/{}'.format(k), v, epoch)
                elif k.startswith('Exploration'):
                    _writer.add_scalar('Exploration/{}'.format(k), v, epoch)
                elif k.startswith('Test'):
                    _writer.add_scalar('Test/{}'.format(k), v, epoch)
                else:
                    _writer.add_scalar(k, v, epoch)

            _writer.file_writer.flush()

            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        else:
            logger.log("Skipping eval for now.")

    def _can_evaluate(self):
        """
        One annoying thing about the logger table is that the keys at each
        iteration need to be the exact same. So unless you can compute
        everything, skip evaluation.

        A common example for why you might want to skip evaluation is that at
        the beginning of training, you may not have enough data for a
        validation and training set.

        :return:
        """
        return (len(self._exploration_paths) > 0 and
                self.replay_buffer.num_steps_can_sample() >= self.batch_size)

    def _can_train(self):
        return self.replay_buffer.num_steps_can_sample() >= self.batch_size

    def _get_action_and_info(self, observation):
        """
        Get an action to take in the environment.
        :param observation:
        :return:
        """
        self.exploration_policy.set_num_steps_total(self._n_env_steps_total)

        # Action skipping: only query the exploration policy for a new action
        # every `action_skip` timesteps and repeat the last action in between
        # (see the standalone sketch after this class).
        if self.action_skip_count % self.action_skip == 0:
            self.action_skip_action = self.exploration_policy.get_action(
                observation)
        self.action_skip_count += 1

        return self.action_skip_action

    def _start_epoch(self, epoch):
        self._epoch_start_time = time.time()
        self._exploration_paths = []
        self._do_train_time = 0
        self.action_skip_count = 0
        logger.push_prefix('Iteration #%d | ' % epoch)

    def _end_epoch(self):
        logger.log("Epoch Duration: {0}".format(time.time() -
                                                self._epoch_start_time))
        logger.log("Started Training: {0}".format(self._can_train()))
        logger.pop_prefix()

    def _start_new_rollout(self):
        self.exploration_policy.reset()
        self.action_skip_count = 0
        return self.training_env.reset()

    def _handle_path(self, path):
        """
        Naive implementation: just loop through each transition.
        :param path:
        :return:
        """
        for (ob, action, reward, next_ob, terminal, agent_info,
             env_info) in zip(
                 path["observations"],
                 path["actions"],
                 path["rewards"],
                 path["next_observations"],
                 path["terminals"],
                 path["agent_infos"],
                 path["env_infos"],
             ):
            self._handle_step(
                ob,
                action,
                reward,
                next_ob,
                terminal,
                agent_info=agent_info,
                env_info=env_info,
            )
        self._handle_rollout_ending()

    def _handle_step(
        self,
        observation,
        action,
        reward,
        next_observation,
        terminal,
        agent_info,
        env_info,
    ):
        """
        Implement anything that needs to happen after every step
        :return:
        """
        self._current_path_builder.add_all(
            observations=observation,
            actions=action,
            rewards=reward,
            next_observations=next_observation,
            terminals=terminal,
            agent_infos=agent_info,
            env_infos=env_info,
        )
        self.replay_buffer.add_sample(
            observation=observation,
            action=action,
            reward=reward,
            terminal=terminal,
            next_observation=next_observation,
            agent_info=agent_info,
            env_info=env_info,
        )

    def _handle_rollout_ending(self):
        """
        Implement anything that needs to happen after every rollout.
        """
        self.replay_buffer.terminate_episode()
        self._n_rollouts_total += 1
        if len(self._current_path_builder) > 0:
            self._exploration_paths.append(
                self._current_path_builder.get_all_stacked())
            self._current_path_builder = PathBuilder()

    def get_epoch_snapshot(self, epoch):
        data_to_save = dict(
            epoch=epoch,
            exploration_policy=self.exploration_policy,
        )
        if self.save_environment:
            data_to_save['env'] = self.training_env
        return data_to_save

    def get_extra_data_to_save(self, epoch):
        """
        Save things that shouldn't be saved every snapshot but rather
        overwritten every time.
        :param epoch:
        :return:
        """
        if self.render:
            self.training_env.render(close=True)
        data_to_save = dict(epoch=epoch, )
        if self.save_environment:
            data_to_save['env'] = self.training_env
        if self.save_replay_buffer:
            data_to_save['replay_buffer'] = self.replay_buffer
        if self.save_algorithm:
            data_to_save['algorithm'] = self
        return data_to_save

    @abc.abstractmethod
    def training_mode(self, mode):
        """
        Set training mode to `mode`.
        :param mode: If True, training will happen (e.g. set the dropout
        probabilities to not all ones).
        """
        pass

    @abc.abstractmethod
    def cuda(self):
        """
        Turn cuda on.
        :return:
        """
        pass

    @abc.abstractmethod
    def evaluate(self, epoch):
        """
        Evaluate the policy, e.g. save/print progress.
        :param epoch:
        :return:
        """
        pass

    @abc.abstractmethod
    def _do_training(self):
        """
        Perform some update, e.g. perform one gradient step.
        :return:
        """
        pass
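
# --- Standalone sketch (not part of the example above) ---
# A minimal, self-contained sketch of the action-skip logic used in
# _get_action_and_info above: the wrapped policy is queried only every
# `action_skip` steps and the last action is repeated in between.
# ActionSkipPolicy and RandomPolicy are illustrative names, not part of
# the original code.
import numpy as np


class ActionSkipPolicy:
    def __init__(self, policy, action_skip=1):
        self.policy = policy
        self.action_skip = action_skip
        self._count = 0
        self._last = None

    def reset(self):
        self._count = 0
        self._last = None

    def get_action(self, observation):
        # Re-query the wrapped policy only every `action_skip` steps.
        if self._count % self.action_skip == 0:
            self._last = self.policy.get_action(observation)
        self._count += 1
        return self._last


class RandomPolicy:
    """Hypothetical stand-in for an ExplorationPolicy."""
    def __init__(self, action_dim):
        self.action_dim = action_dim

    def get_action(self, observation):
        return np.random.uniform(-1, 1, self.action_dim), {}


if __name__ == "__main__":
    policy = ActionSkipPolicy(RandomPolicy(action_dim=2), action_skip=4)
    policy.reset()
    actions = [policy.get_action(np.zeros(3))[0] for _ in range(8)]
    # The first four actions are identical, then the policy is re-queried.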
Example #16
class IRLAlgorithm(metaclass=abc.ABCMeta):
    '''
    Generic IRL algorithm class
    Structure:
    while True:
        generate trajectories
        update reward
        fit policy
    '''
    def __init__(
            self,
            env,
            exploration_policy: ExplorationPolicy,
            expert_replay_buffer,
            training_env=None,
            num_epochs=100,
            num_steps_per_epoch=10000,
            num_steps_per_eval=1000,
            num_steps_between_updates=1000,
            min_steps_before_training=1000,
            max_path_length=1000,
            discount=0.99,
            replay_buffer_size=10000,
            render=False,
            save_replay_buffer=False,
            save_algorithm=False,
            save_environment=False,
            save_best=False,
            save_best_starting_from_epoch=0,
            eval_sampler=None,
            eval_policy=None,
            replay_buffer=None,
            policy_uses_pixels=False,
            wrap_absorbing=False,
            freq_saving=1,
            # Some environments (e.g. halfcheetah_v2) have a time limit that sets
            # the terminal flag; `no_terminal` is a minor hack to turn that off.
            no_terminal=False,
            policy_uses_task_params=False,
            concat_task_params_to_policy_obs=False
        ):
        """
        Base class for RL Algorithms
        :param env: Environment used to evaluate.
        :param exploration_policy: Policy used to explore
        :param training_env: Environment used by the algorithm. By default, a
        copy of `env` will be made.
        :param num_epochs:
        :param num_steps_per_epoch:
        :param num_steps_per_eval:
        :param num_steps_between_updates: Env steps collected between training calls.
        :param min_steps_before_training: Minimum number of steps in the replay
        buffer before training starts.
        :param max_path_length:
        :param discount:
        :param replay_buffer_size:
        :param render:
        :param save_replay_buffer:
        :param save_algorithm:
        :param save_environment:
        :param eval_sampler:
        :param eval_policy: Policy to evaluate with.
        :param replay_buffer:
        """
        self.training_env = training_env or pickle.loads(pickle.dumps(env))
        # self.training_env = training_env or deepcopy(env)
        self.exploration_policy = exploration_policy
        self.expert_replay_buffer = expert_replay_buffer
        self.num_epochs = num_epochs
        self.num_env_steps_per_epoch = num_steps_per_epoch
        self.num_steps_per_eval = num_steps_per_eval
        self.num_steps_between_updates = num_steps_between_updates
        self.min_steps_before_training = min_steps_before_training
        self.max_path_length = max_path_length
        self.discount = discount
        self.replay_buffer_size = replay_buffer_size
        self.render = render
        self.save_replay_buffer = save_replay_buffer
        self.save_algorithm = save_algorithm
        self.save_environment = save_environment
        self.save_best = save_best
        self.save_best_starting_from_epoch = save_best_starting_from_epoch
        self.policy_uses_pixels = policy_uses_pixels
        self.policy_uses_task_params = policy_uses_task_params
        self.concat_task_params_to_policy_obs = concat_task_params_to_policy_obs
        if eval_sampler is None:
            if eval_policy is None:
                eval_policy = exploration_policy
            eval_sampler = InPlacePathSampler(
                env=env,
                policy=eval_policy,
                max_samples=self.num_steps_per_eval + self.max_path_length,
                max_path_length=self.max_path_length, policy_uses_pixels=policy_uses_pixels,
                policy_uses_task_params=policy_uses_task_params,
                concat_task_params_to_policy_obs=concat_task_params_to_policy_obs
            )
        self.eval_policy = eval_policy
        self.eval_sampler = eval_sampler

        self.action_space = env.action_space
        self.obs_space = env.observation_space
        self.env = env
        if replay_buffer is None:
            replay_buffer = EnvReplayBuffer(
                self.replay_buffer_size,
                self.env,
                policy_uses_pixels=self.policy_uses_pixels,
                policy_uses_task_params=self.policy_uses_task_params,
                concat_task_params_to_policy_obs=self.concat_task_params_to_policy_obs
            )
        self.replay_buffer = replay_buffer

        self._n_env_steps_total = 0
        self._n_train_steps_total = 0
        self._n_rollouts_total = 0
        self._do_train_time = 0
        self._epoch_start_time = None
        self._algo_start_time = None
        self._old_table_keys = None
        self._current_path_builder = PathBuilder()
        self._exploration_paths = []
        self.wrap_absorbing = wrap_absorbing
        self.freq_saving = freq_saving
        self.no_terminal = no_terminal


    def train(self, start_epoch=0):
        self.pretrain()
        if start_epoch == 0:
            params = self.get_epoch_snapshot(-1)
            logger.save_itr_params(-1, params)
        self.training_mode(False)
        self._n_env_steps_total = start_epoch * self.num_env_steps_per_epoch
        gt.reset()
        gt.set_def_unique(False)
        self.train_online(start_epoch=start_epoch)


    def pretrain(self):
        """
        Do anything before the main training phase.
        """
        pass


    def train_online(self, start_epoch=0):
        self._current_path_builder = PathBuilder()
        observation = self._start_new_rollout()
        for epoch in gt.timed_for(
                range(start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            self._start_epoch(epoch)
            steps_this_epoch = 0
            while steps_this_epoch < self.num_env_steps_per_epoch:
                # print(steps_this_epoch)
                for _ in range(self.num_steps_between_updates):
                    if isinstance(self.obs_space, Dict):
                        if self.policy_uses_pixels:
                            agent_obs = observation['pixels']
                        else:
                            agent_obs = observation['obs']
                    else:
                        agent_obs = observation
                    if self.policy_uses_task_params:
                        task_params = observation['obs_task_params']
                        if self.concat_task_params_to_policy_obs:
                            agent_obs = np.concatenate((agent_obs, task_params), -1)
                        else:
                            agent_obs = {'obs': agent_obs, 'obs_task_params': task_params}
                    action, agent_info = self._get_action_and_info(
                        agent_obs,
                    )
                    if self.render:
                        self.training_env.render()
                    next_ob, raw_reward, terminal, env_info = (
                        self.training_env.step(action)
                    )
                    if self.no_terminal:
                        terminal = False
                    self._n_env_steps_total += 1
                    reward = raw_reward
                    terminal = np.array([terminal])
                    reward = np.array([reward])
                    self._handle_step(
                        observation,
                        action,
                        reward,
                        next_ob,
                        np.array([False]) if self.wrap_absorbing else terminal,
                        absorbing=np.array([0., 0.]),
                        agent_info=agent_info,
                        env_info=env_info,
                    )
                    if terminal:
                        if self.wrap_absorbing:
                            '''
                            If we wrap absorbing states, two additional
                            transitions must be added: (s_T, s_abs) and
                            (s_abs, s_abs). In the Discriminator-Actor-Critic
                            paper, s_abs is a vector of zeros with the last
                            dimension set to 1. Here we instead add:
                            ([next_ob, 0], random_action, [next_ob, 1]) and
                            ([next_ob, 1], random_action, [next_ob, 1]),
                            which handles varying types of terminal states.
                            (See the standalone sketch after this class.)
                            '''
                            # next_ob is the absorbing state
                            # for now just taking the previous action
                            self._handle_step(
                                next_ob,
                                action,
                                # env.action_space.sample(),
                                # the reward doesn't matter
                                reward,
                                next_ob,
                                np.array([False]),
                                absorbing=np.array([0.0, 1.0]),
                                agent_info=agent_info,
                                env_info=env_info
                            )
                            self._handle_step(
                                next_ob,
                                action,
                                # env.action_space.sample(),
                                # the reward doesn't matter
                                reward,
                                next_ob,
                                np.array([False]),
                                absorbing=np.array([1.0, 1.0]),
                                agent_info=agent_info,
                                env_info=env_info
                            )
                        self._handle_rollout_ending()
                        observation = self._start_new_rollout()
                    elif len(self._current_path_builder) >= self.max_path_length:
                        self._handle_rollout_ending()
                        observation = self._start_new_rollout()
                    else:
                        observation = next_ob

                    steps_this_epoch += 1

                gt.stamp('sample')
                self._try_to_train(epoch)
                gt.stamp('train')

            self._try_to_eval(epoch)
            gt.stamp('eval')
            self._end_epoch()

    def _try_to_train(self, epoch):
        if self._can_train():
            self.training_mode(True)
            self._do_training(epoch)
            self._n_train_steps_total += 1
            self.training_mode(False)

    def _try_to_eval(self, epoch):
        if epoch % self.freq_saving == 0:
            logger.save_extra_data(self.get_extra_data_to_save(epoch))
        if self._can_evaluate():
            self.evaluate(epoch)

            if epoch % self.freq_saving == 0:
                params = self.get_epoch_snapshot(epoch)
                logger.save_itr_params(epoch, params)
            table_keys = logger.get_table_key_set()
            # The strict check that table keys cannot change from iteration to
            # iteration is intentionally disabled here.
            self._old_table_keys = table_keys

            logger.record_tabular(
                "Number of train steps total",
                self._n_train_steps_total,
            )
            logger.record_tabular(
                "Number of env steps total",
                self._n_env_steps_total,
            )
            logger.record_tabular(
                "Number of rollouts total",
                self._n_rollouts_total,
            )

            times_itrs = gt.get_times().stamps.itrs
            train_time = times_itrs['train'][-1]
            sample_time = times_itrs['sample'][-1]
            eval_time = times_itrs['eval'][-1] if epoch > 0 else 0
            epoch_time = train_time + sample_time + eval_time
            total_time = gt.get_times().total

            logger.record_tabular('Train Time (s)', train_time)
            logger.record_tabular('(Previous) Eval Time (s)', eval_time)
            logger.record_tabular('Sample Time (s)', sample_time)
            logger.record_tabular('Epoch Time (s)', epoch_time)
            logger.record_tabular('Total Train Time (s)', total_time)

            logger.record_tabular("Epoch", epoch)
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        else:
            logger.log("Skipping eval for now.")

    def _can_evaluate(self):
        """
        One annoying thing about the logger table is that the keys at each
        iteration need to be the exact same. So unless you can compute
        everything, skip evaluation.

        A common example for why you might want to skip evaluation is that at
        the beginning of training, you may not have enough data for a
        validation and training set.

        :return:
        """
        return (
            len(self._exploration_paths) > 0
            and self.replay_buffer.num_steps_can_sample() >= self.min_steps_before_training
        )

    def _can_train(self):
        return self.replay_buffer.num_steps_can_sample() >= self.min_steps_before_training

    def _get_action_and_info(self, observation):
        """
        Get an action to take in the environment.
        :param observation:
        :return:
        """
        self.exploration_policy.set_num_steps_total(self._n_env_steps_total)
        return self.exploration_policy.get_action(
            observation,
        )

    def _start_epoch(self, epoch):
        self._epoch_start_time = time.time()
        self._exploration_paths = []
        self._do_train_time = 0
        logger.push_prefix('Iteration #%d | ' % epoch)

    def _end_epoch(self):
        logger.log("Epoch Duration: {0}".format(
            time.time() - self._epoch_start_time
        ))
        logger.log("Started Training: {0}".format(self._can_train()))
        logger.pop_prefix()

    def _start_new_rollout(self):
        self.exploration_policy.reset()
        return self.training_env.reset()

    def _handle_path(self, path):
        """
        Naive implementation: just loop through each transition.
        Not supported here because it does not handle absorbing states.
        :param path:
        :return:
        """
        raise NotImplementedError('Does not handle absorbing states')
        for (
            ob,
            action,
            reward,
            next_ob,
            terminal,
            agent_info,
            env_info
        ) in zip(
            path["observations"],
            path["actions"],
            path["rewards"],
            path["next_observations"],
            path["terminals"],
            path["agent_infos"],
            path["env_infos"],
        ):
            self._handle_step(
                ob,
                action,
                reward,
                next_ob,
                terminal,
                agent_info=agent_info,
                env_info=env_info,
            )
        self._handle_rollout_ending()

    def _handle_step(
            self,
            observation,
            action,
            reward,
            next_observation,
            terminal,
            absorbing,
            agent_info,
            env_info,
    ):
        """
        Implement anything that needs to happen after every step
        :return:
        """
        self._current_path_builder.add_all(
            observations=observation,
            actions=action,
            rewards=reward,
            next_observations=next_observation,
            terminals=terminal,
            absorbing=absorbing,
            agent_infos=agent_info,
            env_infos=env_info,
        )
        self.replay_buffer.add_sample(
            observation=observation,
            action=action,
            reward=reward,
            terminal=terminal,
            next_observation=next_observation,
            absorbing=absorbing,
            agent_info=agent_info,
            env_info=env_info,
        )

    def _handle_rollout_ending(self):
        """
        Implement anything that needs to happen after every rollout.
        """
        self.replay_buffer.terminate_episode()
        self._n_rollouts_total += 1
        if len(self._current_path_builder) > 0:
            self._exploration_paths.append(
                self._current_path_builder.get_all_stacked()
            )
            self._current_path_builder = PathBuilder()

    def get_epoch_snapshot(self, epoch):
        data_to_save = dict(
            epoch=epoch,
            exploration_policy=self.exploration_policy,
        )
        if self.save_environment:
            data_to_save['env'] = self.training_env
        return data_to_save

    def get_extra_data_to_save(self, epoch):
        """
        Save things that shouldn't be saved every snapshot but rather
        overwritten every time.
        :param epoch:
        :return:
        """
        if self.render:
            self.training_env.render(close=True)
        data_to_save = dict(
            epoch=epoch,
        )
        if self.save_environment:
            data_to_save['env'] = self.training_env
        if self.save_replay_buffer:
            data_to_save['replay_buffer'] = self.replay_buffer
        if self.save_algorithm:
            data_to_save['algorithm'] = self
        return data_to_save

    @abc.abstractmethod
    def training_mode(self, mode):
        """
        Set training mode to `mode`.
        :param mode: If True, training will happen (e.g. set the dropout
        probabilities to not all ones).
        """
        pass

    @abc.abstractmethod
    def cuda(self):
        """
        Turn cuda on.
        :return:
        """
        pass

    @abc.abstractmethod
    def evaluate(self, epoch):
        """
        Evaluate the policy, e.g. save/print progress.
        :param epoch:
        :return:
        """
        pass

    @abc.abstractmethod
    def _do_training(self, epoch):
        """
        Perform some update, e.g. perform one gradient step.
        :return:
        """
        pass
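
# --- Standalone sketch (not part of the example above) ---
# A minimal sketch of the absorbing-state wrapping described in
# IRLAlgorithm.train_online above (the wrap_absorbing branch), assuming
# flat observation vectors. `make_absorbing_transitions` is an
# illustrative helper, not part of the original code: for a transition
# that ends an episode in next_ob, it returns the two extra transitions
# (next_ob -> next_ob) with absorbing flags [0, 1] and [1, 1], matching
# the comments in the original method.
import numpy as np


def make_absorbing_transitions(next_ob, action, reward):
    base = dict(
        observation=next_ob,
        action=action,
        reward=reward,          # the reward value is ignored downstream
        next_observation=next_ob,
        terminal=np.array([False]),
    )
    first = dict(base, absorbing=np.array([0.0, 1.0]))
    second = dict(base, absorbing=np.array([1.0, 1.0]))
    return [first, second]


if __name__ == "__main__":
    extra = make_absorbing_transitions(
        next_ob=np.zeros(4), action=np.zeros(2), reward=np.array([0.0])
    )
    assert extra[0]["absorbing"][1] == 1.0
    assert extra[1]["absorbing"][0] == 1.0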
Example #17
class RLAlgorithm(metaclass=abc.ABCMeta):
    def __init__(
            self,
            env,
            exploration_policy: ExplorationPolicy,
            training_env=None,
            num_epochs=100,
            num_steps_per_epoch=10000,
            num_steps_per_eval=1000,
            num_updates_per_env_step=1,
            num_updates_per_epoch=None,
            batch_size=1024,
            max_path_length=1000,
            discount=0.99,
            replay_buffer_size=1000000,
            reward_scale=1,
            min_num_steps_before_training=None,
            render=False,
            save_replay_buffer=False,
            save_algorithm=False,
            save_environment=True,
            eval_sampler=None,
            eval_policy=None,
            replay_buffer=None,
            collection_mode='online',
            save_extra_data_interval=100000,
            num_gpus=1,
            num_epochs_per_eval=10,
            num_epochs_per_param_save=100,
            **kwargs
    ):
        """
        Base class for RL Algorithms

        :param env: Environment used to evaluate.
        :param exploration_policy: Policy used to explore
        :param training_env: Environment used by the algorithm. By default, a
        copy of `env` will be made for training, so that training and
        evaluation are completely independent.
        :param num_epochs:
        :param num_steps_per_epoch:
        :param num_steps_per_eval:
        :param num_updates_per_env_step: Used by online training mode.
        :param num_updates_per_epoch: Used by batch training mode.
        :param batch_size:
        :param max_path_length:
        :param discount:
        :param replay_buffer_size:
        :param reward_scale:
        :param min_num_steps_before_training:
        :param render:
        :param save_replay_buffer:
        :param save_algorithm:
        :param save_environment:
        :param eval_sampler:
        :param eval_policy: Policy to evaluate with.
        :param replay_buffer:
        :param collection_mode: String determining how training happens
         - 'online': Train after every step taken in the environment.
         - 'batch': Train after every epoch.
        """
        assert collection_mode in ['online', 'batch']
        if collection_mode == 'batch':
            assert num_updates_per_epoch is not None

        self.training_env = training_env or pickle.loads(pickle.dumps(env))
        self.exploration_policy = exploration_policy
        self.num_epochs = num_epochs
        self.num_env_steps_per_epoch = num_steps_per_epoch
        self.num_steps_per_eval = num_steps_per_eval
        if collection_mode == 'online':
            self.num_updates_per_train_call = num_updates_per_env_step
        else:
            self.num_updates_per_train_call = num_updates_per_epoch
        self.batch_size = batch_size
        self.max_path_length = max_path_length
        self.discount = discount
        self.replay_buffer_size = replay_buffer_size
        self.reward_scale = reward_scale
        self.render = render
        self.collection_mode = collection_mode
        self.save_replay_buffer = save_replay_buffer
        self.save_algorithm = save_algorithm
        self.save_environment = save_environment
        if min_num_steps_before_training is None:
            min_num_steps_before_training = self.num_env_steps_per_epoch
        self.min_num_steps_before_training = min_num_steps_before_training
        if eval_sampler is None:
            if eval_policy is None:
                eval_policy = exploration_policy
            eval_sampler = InPlacePathSampler(
                env=env,
                policy=eval_policy,
                max_samples=self.num_steps_per_eval + self.max_path_length,
                max_path_length=self.max_path_length,
            )
        self.eval_policy = eval_policy
        self.eval_sampler = eval_sampler
        self.eval_statistics = OrderedDict()
        self.need_to_update_eval_statistics = True

        self.action_space = env.action_space
        self.obs_space = env.observation_space
        self.env = env
        if replay_buffer is None:
            replay_buffer = EnvReplayBuffer(
                self.replay_buffer_size,
                self.env,
            )
        self.replay_buffer = replay_buffer

        self._n_env_steps_total = 0
        self._n_train_steps_total = 0
        self._n_rollouts_total = 0
        self._do_train_time = 0
        self._epoch_start_time = None
        self._algo_start_time = None
        self._old_table_keys = None
        self._current_path_builder = PathBuilder()
        self._exploration_paths = []
        self.post_epoch_funcs = []
        self.save_extra_data_interval = save_extra_data_interval

        # MPI stuff
        if MPI and ptu.get_mode():
            self.gpu_id = MPI.COMM_WORLD.Get_rank() % num_gpus

        self.num_epochs_per_eval = num_epochs_per_eval
        assert num_epochs_per_param_save % num_epochs_per_eval == 0
        self.num_epochs_per_param_save = num_epochs_per_param_save

    def train(self, start_epoch=0):
        self.pretrain()
        if start_epoch == 0 and MPI and MPI.COMM_WORLD.Get_rank() == 0:
            params = self.get_epoch_snapshot(-1)
            logger.save_itr_params(-1, params)
        self.training_mode(False)
        self._n_env_steps_total = start_epoch * self.num_env_steps_per_epoch
        gt.reset()
        gt.set_def_unique(False)
        if self.collection_mode == 'online':
            self.train_online(start_epoch=start_epoch)
        elif self.collection_mode == 'batch':
            self.train_batch(start_epoch=start_epoch)
        else:
            raise TypeError("Invalid collection_mode: {}".format(
                self.collection_mode
            ))

    def pretrain(self):
        pass

    def train_online(self, start_epoch=0):
        self._current_path_builder = PathBuilder()
        for epoch in gt.timed_for(
                range(start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            self._start_epoch(epoch)
            set_to_train_mode(self.training_env)
            observation = self._start_new_rollout()
            for _ in range(self.num_env_steps_per_epoch):
                observation = self._take_step_in_env(observation)
                gt.stamp('sample')

                self._try_to_train()
                gt.stamp('train')

            set_to_eval_mode(self.env)
            self._try_to_eval(epoch)
            gt.stamp('eval')
            self._end_epoch(epoch)

    def train_batch(self, start_epoch):
        self._current_path_builder = PathBuilder()
        for epoch in gt.timed_for(
                range(start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            self._start_epoch(epoch)
            set_to_train_mode(self.training_env)
            observation = self._start_new_rollout()
            # This implementation is rather naive. If you want to (e.g.)
            # parallelize data collection, this would be the place to do it.
            for _ in range(self.num_env_steps_per_epoch):
                observation = self._take_step_in_env(observation)
            gt.stamp('sample')

            self._try_to_train()
            gt.stamp('train')

            set_to_eval_mode(self.env)
            if epoch % self.num_epochs_per_eval == 0:
                self._try_to_eval(epoch)
                gt.stamp('eval')
            self._end_epoch(epoch)

    def _take_step_in_env(self, observation):
        action, agent_info = self._get_action_and_info(
            observation,
        )

        if self.render:
            self.training_env.render()
        next_ob, raw_reward, terminal, env_info = (
            self.training_env.step(action)
        )
        self._n_env_steps_total += 1
        reward = raw_reward * self.reward_scale
        terminal = np.array([terminal])
        reward = np.array([reward])
        self._handle_step(
            observation,
            action,
            reward,
            next_ob,
            terminal,
            agent_info=agent_info,
            env_info=env_info,
            mask=get_masks(self.training_env.unwrapped.num_blocks, self.replay_buffer.max_num_blocks, 1)
        )
        # print(F"cpb len {len(self._current_path_builder)}")
        # print(F"terminal {terminal}")
        if terminal or len(self._current_path_builder) >= self.max_path_length:
            self._handle_rollout_ending()
            new_observation = self._start_new_rollout()
        else:
            new_observation = next_ob
        return new_observation

    def _try_to_train(self):
        if ptu.get_mode() == "gpu_opt":
            # Move the networks to this worker's GPU before doing updates.
            ptu.set_device(device_id=self.gpu_id, device_type="gpu")
            self.to(device=torch.device(F"cuda:{self.gpu_id}"))

        if self._can_train():
            self.training_mode(True)
            for i in range(self.num_updates_per_train_call):
                self._do_training()
                self._n_train_steps_total += 1
            self.training_mode(False)

        if ptu.get_mode() == "gpu_opt":
            # Move the networks back to the CPU once the updates are done.
            ptu.set_device(device_type="cpu")
            self.to(device=torch.device("cpu"))

    def _try_to_eval(self, epoch, eval_paths=None):
        if MPI and MPI.COMM_WORLD.Get_rank() == 0:
            if epoch % self.save_extra_data_interval == 0:
                logger.save_extra_data(self.get_extra_data_to_save(epoch))

            if epoch % self.num_epochs_per_param_save == 0:
                print("Attemping itr param save...")
                params = self.get_epoch_snapshot(epoch)
                logger.save_itr_params(epoch, params)
                print(F"Itr{epoch} param saved!")

        if self._can_evaluate():
            self.evaluate(epoch, eval_paths=eval_paths)

            logger.record_tabular(
                "Number of train steps total",
                self._n_train_steps_total,
            )
            logger.record_tabular(
                "Number of env steps total",
                self._n_env_steps_total,
            )
            logger.record_tabular(
                "Number of rollouts total",
                self._n_rollouts_total,
            )

            times_itrs = gt.get_times().stamps.itrs
            # Sum the per-phase training stamps (emitted inside _do_training) to
            # get the total train time; summing over all stamps would also count
            # 'sample' and 'eval'. (See the standalone gtimer sketch after this
            # class.)
            training_loops = [
                'get_batch', 'update_normalizer', 'forward', 'compute_losses',
                'qf1_loop', 'policy_loss_forward', 'policy_loop', 'vf_loop',
            ]
            train_time = sum(
                times_itrs[loop][-1] for loop in training_loops
                if loop in times_itrs
            )

            sample_time = times_itrs['sample'][-1]

            if epoch > 0:
                eval_time = times_itrs['eval'][-1]
            else:
                # Ensure 'eval' exists so the loop over times_itrs below works.
                times_itrs['eval'] = [0]
                eval_time = 0
            epoch_time = train_time + sample_time + eval_time
            total_time = gt.get_times().total

            # Record each individual timing stamp, then the aggregate times.
            for key in times_itrs.keys():
                logger.record_tabular(key.title(), times_itrs[key][-1])

            logger.record_tabular('Train Time (s) ---', train_time)
            logger.record_tabular('(Previous) Eval Time (s) ---', eval_time)
            logger.record_tabular('Sample Time (s) ---', sample_time)
            logger.record_tabular('Epoch Time (s)', epoch_time)
            logger.record_tabular('Total Train Time (s)', total_time)
            logger.record_tabular("Epoch", epoch)

            table_keys = logger.get_table_key_set()
            if self._old_table_keys is not None and table_keys != self._old_table_keys:
                print("Table keys have changed. Rewriting the header is not supported yet.")
                logger.update_header()
                raise NotImplementedError(
                    "Table keys cannot change from iteration to iteration."
                )
            self._old_table_keys = table_keys

            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        else:
            logger.log("Skipping eval for now.")

    def _can_evaluate(self):
        """
        One annoying thing about the logger table is that the keys at each
        iteration need to be the exact same. So unless you can compute
        everything, skip evaluation.

        A common example for why you might want to skip evaluation is that at
        the beginning of training, you may not have enough data for a
        validation and training set.
        """
        return (
            len(self._exploration_paths) > 0
            and not self.need_to_update_eval_statistics
        )

    def _can_train(self):
        return (
            self.replay_buffer.num_steps_can_sample() >=
            self.min_num_steps_before_training
        )

    def _get_action_and_info(self, observation):
        """
        Get an action to take in the environment.
        :param observation:
        :return:
        """
        self.exploration_policy.set_num_steps_total(self._n_env_steps_total)
        return self.exploration_policy.get_action(
            observation,
        )

    def _start_epoch(self, epoch):
        self._epoch_start_time = time.time()
        self._exploration_paths = []
        self._do_train_time = 0
        logger.push_prefix('Iteration #%d | ' % epoch)

    def _end_epoch(self, epoch):
        logger.log("Epoch Duration: {0}".format(
            time.time() - self._epoch_start_time
        ))
        logger.log("Started Training: {0}".format(self._can_train()))
        logger.pop_prefix()

        for post_epoch_func in self.post_epoch_funcs:
            post_epoch_func(self, epoch)

    def _start_new_rollout(self):
        self.exploration_policy.reset()
        return self.training_env.reset()

    def _handle_path(self, path):
        """
        Naive implementation: just loop through each transition.
        :param path:
        :return:
        """
        for (
            ob,
            action,
            reward,
            next_ob,
            terminal,
            agent_info,
            env_info
        ) in zip(
            path["observations"],
            path["actions"],
            path["rewards"],
            path["next_observations"],
            path["terminals"],
            path["agent_infos"],
            path["env_infos"],
        ):
            self._handle_step(
                ob,
                action,
                reward,
                next_ob,
                terminal,
                agent_info=agent_info,
                env_info=env_info,
            )
        self._handle_rollout_ending()

    def _handle_step(
            self,
            observation,
            action,
            reward,
            next_observation,
            terminal,
            agent_info,
            env_info,
            # full_observations
    ):
        """
        Implement anything that needs to happen after every step
        :return:
        """
        self._current_path_builder.add_all(
            observations=observation,
            actions=action,
            rewards=reward,
            next_observations=next_observation,
            terminals=terminal,
            agent_infos=agent_info,
            env_infos=env_info,
            # full_observations=full_observations,
        )
        self.replay_buffer.add_sample(
            observation=observation,
            action=action,
            reward=reward,
            terminal=terminal,
            next_observation=next_observation,
            agent_info=agent_info,
            env_info=env_info,
            # full_observations=full_observations,
        )

    def _handle_rollout_ending(self):
        """
        Implement anything that needs to happen after every rollout.
        """
        self.replay_buffer.terminate_episode()
        self._n_rollouts_total += 1

        if len(self._current_path_builder) > 0:
            path = self._current_path_builder.get_all_stacked()
            self._exploration_paths.append(path)
            self._current_path_builder = PathBuilder()

    def get_epoch_snapshot(self, epoch):
        data_to_save = dict(
            epoch=epoch,
            exploration_policy=self.exploration_policy,
            eval_policy=self.eval_policy,
        )
        if self.save_environment:
            data_to_save['env'] = self.training_env
        return data_to_save

    def get_extra_data_to_save(self, epoch):
        """
        Save things that shouldn't be saved every snapshot but rather
        overwritten every time.
        :param epoch:
        :return:
        """
        if self.render:
            self.training_env.render(close=True)
        data_to_save = dict(
            epoch=epoch,
        )
        if self.save_environment:
            data_to_save['env'] = self.training_env
        if self.save_replay_buffer:
            data_to_save['replay_buffer'] = self.replay_buffer
        if self.save_algorithm:
            data_to_save['algorithm'] = self
        return data_to_save

    @abc.abstractmethod
    def training_mode(self, mode):
        """
        Set training mode to `mode`.
        :param mode: If True, training will happen (e.g. set the dropout
        probabilities to not all ones).
        """
        pass

    def evaluate(self, epoch, eval_paths=None):
        statistics = OrderedDict()
        statistics.update(self.eval_statistics)

        logger.log("Collecting samples for evaluation")
        if eval_paths:
            test_paths = eval_paths
        else:
            test_paths = self.get_eval_paths()
        if hasattr(self.env.unwrapped, "num_blocks"):
            statistics.update(eval_util.get_generic_path_information(
                test_paths, stat_prefix="Test", num_blocks=self.env.unwrapped.num_blocks
            ))
            if len(self._exploration_paths) > 0:
                statistics.update(eval_util.get_generic_path_information(
                    self._exploration_paths, stat_prefix="Exploration"
                ))
        else:
            statistics.update(eval_util.get_generic_path_information(
                test_paths, stat_prefix="Test", num_blocks=None
            ))
            if len(self._exploration_paths) > 0:
                statistics.update(eval_util.get_generic_path_information(
                    self._exploration_paths, stat_prefix="Exploration", num_blocks=None
                ))
        if hasattr(self.env, "log_diagnostics"):
            self.env.log_diagnostics(test_paths, logger=logger)
        if hasattr(self.env, "get_diagnostics"):
            statistics.update(self.env.get_diagnostics(test_paths))

        average_returns = eval_util.get_average_returns(test_paths)
        statistics['AverageReturn'] = average_returns
        for key, value in statistics.items():
            logger.record_tabular(key, value)
        self.need_to_update_eval_statistics = True

    def get_eval_paths(self):
        return self.eval_sampler.obtain_samples()

    @abc.abstractmethod
    def _do_training(self):
        """
        Perform some update, e.g. perform one gradient step.
        :return:
        """
        pass
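
# --- Standalone sketch (not part of the example above) ---
# A minimal sketch of how the gtimer ("gt") calls used in the training
# loops above fit together: gt.timed_for times each iteration, gt.stamp
# marks named phases, and the per-iteration times are read back from
# gt.get_times().stamps.itrs, exactly as _try_to_eval does. The sleeps
# are placeholders for real sampling and training work.
import time

import gtimer as gt

gt.reset()
gt.set_def_unique(False)
for epoch in gt.timed_for(range(3), save_itrs=True):
    time.sleep(0.01)            # stands in for environment sampling
    gt.stamp('sample')
    time.sleep(0.02)            # stands in for gradient updates
    gt.stamp('train')
    times_itrs = gt.get_times().stamps.itrs
    print(epoch,
          'sample (s):', times_itrs['sample'][-1],
          'train (s):', times_itrs['train'][-1],
          'total (s):', gt.get_times().total)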
Example #18
class RLAlgorithm(metaclass=abc.ABCMeta):
    def __init__(
        self,
        env,
        exploration_policy: ExplorationPolicy,
        training_env=None,
        num_epochs=100,
        num_steps_per_epoch=10000,
        num_steps_per_eval=1000,
        num_updates_per_env_step=1,
        num_updates_per_epoch=None,
        batch_size=1024,
        max_path_length=1000,
        discount=0.99,
        replay_buffer_size=1000000,
        reward_scale=1,
        min_num_steps_before_training=None,
        render=False,
        save_replay_buffer=False,
        save_algorithm=False,
        save_environment=True,
        eval_sampler=None,
        eval_policy=None,
        replay_buffer=None,
        collection_mode='online',
    ):
        """
        Base class for RL Algorithms

        :param env: Environment used to evaluate.
        :param exploration_policy: Policy used to explore
        :param training_env: Environment used by the algorithm. By default, a
        copy of `env` will be made for training, so that training and
        evaluation are completely independent.
        :param num_epochs:
        :param num_steps_per_epoch:
        :param num_steps_per_eval:
        :param num_updates_per_env_step: Used by online training mode.
        :param num_updates_per_epoch: Used by batch training mode.
        :param batch_size:
        :param max_path_length:
        :param discount:
        :param replay_buffer_size:
        :param reward_scale:
        :param min_num_steps_before_training:
        :param render:
        :param save_replay_buffer:
        :param save_algorithm:
        :param save_environment:
        :param eval_sampler:
        :param eval_policy: Policy to evaluate with.
        :param replay_buffer:
        :param collection_mode: String determining how training happens
         - 'online': Train after every step taken in the environment.
         - 'batch': Train after every epoch.
        """
        assert collection_mode in ['online', 'batch']
        if collection_mode == 'batch':
            assert num_updates_per_epoch is not None

        self.training_env = training_env  #or pickle.loads(pickle.dumps(env))
        self.exploration_policy = exploration_policy
        self.num_epochs = num_epochs
        self.num_env_steps_per_epoch = num_steps_per_epoch
        self.num_steps_per_eval = num_steps_per_eval
        if collection_mode == 'online':
            self.num_updates_per_train_call = num_updates_per_env_step
        else:
            self.num_updates_per_train_call = num_updates_per_epoch
        self.batch_size = batch_size
        self.max_path_length = max_path_length
        self.discount = discount
        self.replay_buffer_size = replay_buffer_size
        self.reward_scale = reward_scale
        self.render = render
        self.collection_mode = collection_mode
        self.save_replay_buffer = save_replay_buffer
        self.save_algorithm = save_algorithm
        self.save_environment = save_environment
        if min_num_steps_before_training is None:
            min_num_steps_before_training = self.num_env_steps_per_epoch
        self.min_num_steps_before_training = min_num_steps_before_training
        if eval_sampler is None:
            if eval_policy is None:
                eval_policy = exploration_policy
            eval_sampler = InPlacePathSampler(
                env=env,
                policy=eval_policy,
                max_samples=self.num_steps_per_eval + self.max_path_length,
                max_path_length=self.max_path_length,
            )
        self.eval_policy = eval_policy
        self.eval_sampler = eval_sampler
        self.eval_statistics = OrderedDict()
        self.need_to_update_eval_statistics = True

        self.action_space = env.action_space
        self.obs_space = env.observation_space
        self.env = env
        if replay_buffer is None:
            replay_buffer = EnvReplayBuffer(
                self.replay_buffer_size,
                self.env,
            )
        self.replay_buffer = replay_buffer

        self._n_env_steps_total = 0
        self._n_train_steps_total = 0
        self._n_rollouts_total = 0
        self._do_train_time = 0
        self._epoch_start_time = None
        self._algo_start_time = None
        self._old_table_keys = None
        self._current_path_builder = PathBuilder()
        self._exploration_paths = []
        self.post_epoch_funcs = []

    def train(self, start_epoch=0):
        self.pretrain()
        if start_epoch == 0:
            params = self.get_epoch_snapshot(-1)
            logger.save_itr_params(-1, params)
        self.training_mode(False)
        self._n_env_steps_total = start_epoch * self.num_env_steps_per_epoch
        gt.reset()
        gt.set_def_unique(False)
        if self.collection_mode == 'online':
            self.train_online(start_epoch=start_epoch)
        elif self.collection_mode == 'batch':
            self.train_batch(start_epoch=start_epoch)
        else:
            raise TypeError("Invalid collection_mode: {}".format(
                self.collection_mode))

    def pretrain(self):
        pass

    def train_online(self, start_epoch=0):
        self._current_path_builder = PathBuilder()
        for epoch in gt.timed_for(
                range(start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            self._start_epoch(epoch)
            set_to_train_mode(self.training_env)
            observation = self._start_new_rollout()
            for _ in range(self.num_env_steps_per_epoch):
                observation = self._take_step_in_env(observation)
                gt.stamp('sample')

                self._try_to_train()
                gt.stamp('train')

            set_to_eval_mode(self.env)
            self._try_to_eval(epoch)
            gt.stamp('eval')
            self._end_epoch(epoch)

    def train_batch(self, start_epoch):
        self._current_path_builder = PathBuilder()
        for epoch in gt.timed_for(
                range(start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            self._start_epoch(epoch)
            set_to_train_mode(self.training_env)
            observation = self._start_new_rollout()
            # This implementation is rather naive. If you want to (e.g.)
            # parallelize data collection, this would be the place to do it.
            for _ in range(self.num_env_steps_per_epoch):
                observation = self._take_step_in_env(observation)
            gt.stamp('sample')

            self._try_to_train()
            gt.stamp('train')

            set_to_eval_mode(self.env)
            self._try_to_eval(epoch)
            gt.stamp('eval')
            self._end_epoch(epoch)

    def _take_step_in_env(self, observation):
        action, agent_info = self._get_action_and_info(observation, )
        if self.render:
            self.training_env.render()
        next_ob, raw_reward, terminal, env_info = (
            self.training_env.step(action))
        self._n_env_steps_total += 1
        reward = raw_reward * self.reward_scale
        terminal = np.array([terminal])
        reward = np.array([reward])
        self._handle_step(
            observation,
            action,
            reward,
            next_ob,
            terminal,
            agent_info=agent_info,
            env_info=env_info,
        )
        if terminal or len(self._current_path_builder) >= self.max_path_length:
            self._handle_rollout_ending()
            new_observation = self._start_new_rollout()
        else:
            new_observation = next_ob
        return new_observation

    def _try_to_train(self):
        if self._can_train():
            self.training_mode(True)
            for i in range(self.num_updates_per_train_call):
                self._do_training()
                self._n_train_steps_total += 1
            self.training_mode(False)

    def _try_to_eval(self, epoch, eval_paths=None):
        logger.save_extra_data(self.get_extra_data_to_save(epoch))
        if self._can_evaluate():
            self.evaluate(epoch, eval_paths=eval_paths)

            params = self.get_epoch_snapshot(epoch)
            logger.save_itr_params(epoch, params)
            table_keys = logger.get_table_key_set()
            if self._old_table_keys is not None:
                assert table_keys == self._old_table_keys, (
                    "Table keys cannot change from iteration to iteration.")
            self._old_table_keys = table_keys

            logger.record_tabular(
                "Number of train steps total",
                self._n_train_steps_total,
            )
            logger.record_tabular(
                "Number of env steps total",
                self._n_env_steps_total,
            )
            logger.record_tabular(
                "Number of rollouts total",
                self._n_rollouts_total,
            )

            times_itrs = gt.get_times().stamps.itrs
            train_time = times_itrs['train'][-1]
            sample_time = times_itrs['sample'][-1]
            eval_time = times_itrs['eval'][-1] if epoch > 0 else 0
            epoch_time = train_time + sample_time + eval_time
            total_time = gt.get_times().total

            logger.record_tabular('Train Time (s)', train_time)
            logger.record_tabular('(Previous) Eval Time (s)', eval_time)
            logger.record_tabular('Sample Time (s)', sample_time)
            logger.record_tabular('Epoch Time (s)', epoch_time)
            logger.record_tabular('Total Train Time (s)', total_time)
            logger.record_tabular("Epoch", epoch)
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        else:
            logger.log("Skipping eval for now.")

    def _can_evaluate(self):
        """
        One annoying thing about the logger table is that the keys at each
        iteration need to be the exact same. So unless you can compute
        everything, skip evaluation.

        A common example for why you might want to skip evaluation is that at
        the beginning of training, you may not have enough data for a
        validation and training set.
        """
        return (len(self._exploration_paths) > 0
                and not self.need_to_update_eval_statistics)

    def _can_train(self):
        return (self.replay_buffer.num_steps_can_sample() >=
                self.min_num_steps_before_training)

    def _get_action_and_info(self, observation):
        """
        Get an action to take in the environment.
        :param observation:
        :return:
        """
        self.exploration_policy.set_num_steps_total(self._n_env_steps_total)
        return self.exploration_policy.get_action(observation)

    def _start_epoch(self, epoch):
        self._epoch_start_time = time.time()
        self._exploration_paths = []
        self._do_train_time = 0
        logger.push_prefix('Iteration #%d | ' % epoch)

    def _end_epoch(self, epoch):
        logger.log("Epoch Duration: {0}".format(time.time() -
                                                self._epoch_start_time))
        logger.log("Started Training: {0}".format(self._can_train()))
        logger.pop_prefix()

        for post_epoch_func in self.post_epoch_funcs:
            post_epoch_func(self, epoch)

    def _start_new_rollout(self):
        self.exploration_policy.reset()
        return self.training_env.reset()

    def _handle_path(self, path):
        """
        Naive implementation: just loop through each transition.
        :param path:
        :return:
        """
        for (ob, action, reward, next_ob, terminal, agent_info,
             env_info) in zip(
                 path["observations"],
                 path["actions"],
                 path["rewards"],
                 path["next_observations"],
                 path["terminals"],
                 path["agent_infos"],
                 path["env_infos"],
             ):
            self._handle_step(
                ob,
                action,
                reward,
                next_ob,
                terminal,
                agent_info=agent_info,
                env_info=env_info,
            )
        self._handle_rollout_ending()

    def _handle_step(
        self,
        observation,
        action,
        reward,
        next_observation,
        terminal,
        agent_info,
        env_info,
    ):
        """
        Implement anything that needs to happen after every step
        :return:
        """
        self._current_path_builder.add_all(
            observations=observation,
            actions=action,
            rewards=reward,
            next_observations=next_observation,
            terminals=terminal,
            agent_infos=agent_info,
            env_infos=env_info,
        )
        self.replay_buffer.add_sample(
            observation=observation,
            action=action,
            reward=reward,
            terminal=terminal,
            next_observation=next_observation,
            agent_info=agent_info,
            env_info=env_info,
        )

    def _handle_rollout_ending(self):
        """
        Implement anything that needs to happen after every rollout.
        """
        self.replay_buffer.terminate_episode()
        self._n_rollouts_total += 1
        if len(self._current_path_builder) > 0:
            self._exploration_paths.append(
                self._current_path_builder.get_all_stacked())
            self._current_path_builder = PathBuilder()

    def get_epoch_snapshot(self, epoch):
        data_to_save = dict(
            epoch=epoch,
            exploration_policy=self.exploration_policy,
            eval_policy=self.eval_policy,
        )
        if self.save_environment:
            data_to_save['env'] = self.training_env
        return data_to_save

    def get_extra_data_to_save(self, epoch):
        """
        Save things that shouldn't be saved every snapshot but rather
        overwritten every time.
        :param epoch:
        :return:
        """
        if self.render:
            self.training_env.render(close=True)
        data_to_save = dict(epoch=epoch, )
        if self.save_environment:
            data_to_save['env'] = self.training_env
        if self.save_replay_buffer:
            data_to_save['replay_buffer'] = self.replay_buffer
        if self.save_algorithm:
            data_to_save['algorithm'] = self
        return data_to_save

    @abc.abstractmethod
    def training_mode(self, mode):
        """
        Set training mode to `mode`.
        :param mode: If True, training will happen (e.g. set the dropout
        probabilities to not all ones).
        """
        pass

    def evaluate(self, epoch, eval_paths=None):
        statistics = OrderedDict()
        statistics.update(self.eval_statistics)

        logger.log("Collecting samples for evaluation")
        if eval_paths:
            test_paths = eval_paths
        else:
            test_paths = self.get_eval_paths()
        statistics.update(
            eval_util.get_generic_path_information(
                test_paths,
                stat_prefix="Test",
            ))
        if len(self._exploration_paths) > 0:
            statistics.update(
                eval_util.get_generic_path_information(
                    self._exploration_paths,
                    stat_prefix="Exploration",
                ))
        if hasattr(self.env, "log_diagnostics"):
            self.env.log_diagnostics(test_paths, logger=logger)
        if hasattr(self.env, "get_diagnostics"):
            statistics.update(self.env.get_diagnostics(test_paths))

        average_returns = eval_util.get_average_returns(test_paths)
        statistics['AverageReturn'] = average_returns
        for key, value in statistics.items():
            logger.record_tabular(key, value)
        self.need_to_update_eval_statistics = True

    def get_eval_paths(self):
        return self.eval_sampler.obtain_samples()

    @abc.abstractmethod
    def _do_training(self):
        """
        Perform some update, e.g. perform one gradient step.
        :return:
        """
        pass
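The class above leaves training_mode and _do_training abstract. As a rough standalone sketch (not part of the original code base), a concrete subclass typically implements training_mode by flipping its networks between train and eval behaviour and _do_training by taking one gradient step on a sampled batch; SketchAlgorithm, its network, and the batch layout below are assumptions made purely for illustration.

# Hedged sketch of the two abstract hooks; everything here is hypothetical.
import numpy as np
import torch
import torch.nn as nn


class SketchAlgorithm:
    def __init__(self, obs_dim=4, action_dim=2):
        self.qf = nn.Sequential(nn.Linear(obs_dim + action_dim, 64),
                                nn.ReLU(), nn.Linear(64, 1))
        self.optimizer = torch.optim.Adam(self.qf.parameters(), lr=3e-4)

    def training_mode(self, mode):
        # Mirrors the abstract hook: toggles dropout/batch-norm behaviour.
        self.qf.train(mode)

    def _do_training(self, batch):
        # One gradient step on a batch dict of numpy arrays (assumed layout).
        obs = torch.as_tensor(batch['observations'])
        acts = torch.as_tensor(batch['actions'])
        rewards = torch.as_tensor(batch['rewards'])
        pred = self.qf(torch.cat([obs, acts], dim=-1))
        loss = ((pred - rewards) ** 2).mean()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()


if __name__ == '__main__':
    algo = SketchAlgorithm()
    algo.training_mode(True)
    fake_batch = {
        'observations': np.random.randn(32, 4).astype(np.float32),
        'actions': np.random.randn(32, 2).astype(np.float32),
        'rewards': np.random.randn(32, 1).astype(np.float32),
    }
    print('loss:', algo._do_training(fake_batch))
    algo.training_mode(False)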
Example #19
0
class MetaIRLAlgorithm(metaclass=abc.ABCMeta):
    '''
        While True:
            generate trajectories for a batch of different task settings
            update the models
    '''
    def __init__(
            self,
            env,
            train_context_expert_replay_buffer,
            train_test_expert_replay_buffer,
            test_context_expert_replay_buffer,
            test_test_expert_replay_buffer,
            train_task_params_sampler,
            test_task_params_sampler,
            training_env=None,
            num_epochs=100,
            num_rollouts_per_epoch=10,
            num_rollouts_between_updates=10,
            num_initial_rollouts_for_all_train_tasks=0,
            min_rollouts_before_training=10,
            max_path_length=1000,
            discount=0.99,
            replay_buffer_size_per_task=20000,
            render=False,
            save_replay_buffer=False,
            save_algorithm=False,
            save_environment=False,
            replay_buffer=None,
            policy_uses_pixels=False,
            wrap_absorbing=False,
            freq_saving=1,
            do_not_train=False,
            do_not_eval=False,
            # some environments, like halfcheetah_v2, have a time limit that
            # produces a terminal signal; no_terminal is a minor hack to
            # disable such time limits
            no_terminal=False,
            save_best=False,
            save_best_after_epoch=0,
            custom_save_epoch=[],
            use_env_getter=False,
            training_env_getter=None,
            test_env_getter=None,
            get_full_obs_dict=False,
            **kwargs):
        self.use_env_getter = use_env_getter
        self.training_env_getter = training_env_getter
        self.test_env_getter = test_env_getter
        self.get_full_obs_dict = get_full_obs_dict
        if self.use_env_getter:
            cur_task_params, cur_obs_task_params = (
                train_task_params_sampler.sample())
            self.training_env = self.training_env_getter(cur_obs_task_params)
        else:
            self.training_env = training_env or pickle.loads(pickle.dumps(env))
        # self.training_env = training_env or deepcopy(env)
        self.train_context_expert_replay_buffer = train_context_expert_replay_buffer
        self.train_test_expert_replay_buffer = train_test_expert_replay_buffer
        self.test_context_expert_replay_buffer = test_context_expert_replay_buffer
        self.test_test_expert_replay_buffer = test_test_expert_replay_buffer
        self.num_epochs = num_epochs
        self.num_rollouts_per_epoch = num_rollouts_per_epoch
        self.num_rollouts_between_updates = num_rollouts_between_updates
        self.num_initial_rollouts_for_all_train_tasks = num_initial_rollouts_for_all_train_tasks
        self.min_rollouts_before_training = min_rollouts_before_training
        self.max_path_length = max_path_length
        self.discount = discount
        self.replay_buffer_size_per_task = replay_buffer_size_per_task
        self.render = render
        self.save_replay_buffer = save_replay_buffer
        self.save_algorithm = save_algorithm
        self.save_environment = save_environment
        self.policy_uses_pixels = policy_uses_pixels

        if self.use_env_getter:
            cur_task_params, cur_obs_task_params = (
                test_task_params_sampler.sample())
            self.env = test_env_getter(cur_obs_task_params)
        else:
            self.env = env
        self.action_space = self.env.action_space
        self.obs_space = self.env.observation_space
        if replay_buffer is None:
            replay_buffer = MetaEnvReplayBuffer(
                self.replay_buffer_size_per_task,
                self.training_env,
                policy_uses_pixels=self.policy_uses_pixels,
            )
        self.replay_buffer = replay_buffer

        self._n_env_steps_total = 0
        self._n_train_steps_total = 0
        self._n_rollouts_total = 0
        self._do_train_time = 0
        self._epoch_start_time = None
        self._algo_start_time = None
        self._old_table_keys = None
        self._current_path_builder = PathBuilder()
        self._exploration_paths = []
        self.wrap_absorbing = wrap_absorbing
        if self.wrap_absorbing:
            assert isinstance(env, WrappedAbsorbingEnv), 'Env is not wrapped!'
        self.freq_saving = freq_saving
        self.no_terminal = no_terminal
        if self.no_terminal:
            print('\n\nDOING NO TERMINAL!\n\n')

        self.train_task_params_sampler = train_task_params_sampler
        self.test_task_params_sampler = test_task_params_sampler
        self.do_not_train = do_not_train
        self.do_not_eval = do_not_eval
        self.best_meta_test = -np.inf  # np.float was removed in recent NumPy versions
        self.save_best = save_best
        self.save_best_after_epoch = save_best_after_epoch
        self.custom_save_epoch = custom_save_epoch

    def train(self, start_epoch=0):
        self.pretrain()
        if start_epoch == 0:
            params = self.get_epoch_snapshot(-1)
            logger.save_itr_params(-1, params)
        self.training_mode(False)
        # self._n_env_steps_total = start_epoch * self.num_env_steps_per_epoch
        gt.reset()
        gt.set_def_unique(False)
        self.train_online(start_epoch=start_epoch)

    def pretrain(self):
        """
        Do anything before the main training phase.
        """
        if self.num_initial_rollouts_for_all_train_tasks > 0:
            self.generate_rollouts_for_all_train_tasks(
                self.num_initial_rollouts_for_all_train_tasks)
            print('\nGenerated Initial Task Rollouts\n')
            gt.stamp('initial_task_rollouts')

    def generate_rollouts_for_all_train_tasks(self, num_rollouts_per_task):
        '''
        This is a simple work-around for a problem that arises when sampling
        batches for NP-AIRL because you need to be able to sample a minimum
        number of trajectories per train task.
        I will try to replace this with a better fix later.
        '''
        i = 0
        for task_params, obs_task_params in self.train_task_params_sampler:
            print('rollouts for task %d' % i)
            # print('new task rollout')
            for _ in range(num_rollouts_per_task):
                self.generate_exploration_rollout(
                    task_params=task_params, obs_task_params=obs_task_params)
            i += 1
        # _exploration_paths holds the exploration paths gathered during one
        # epoch so we can analyze properties of those trajectories if we want
        # to. These initial rollouts should not count towards that, so clear it.
        self._exploration_paths = []

    def generate_exploration_rollout(self,
                                     task_params=None,
                                     obs_task_params=None):
        observation, task_identifier = self._start_new_rollout(
            task_params=task_params, obs_task_params=obs_task_params)

        # _current_path_builder is re-initialized every time
        # _handle_rollout_ending is called.
        # When a new rollout starts, self.exploration_policy
        # is set to the policy for the current task.
        terminal = False
        while (not terminal) and len(
                self._current_path_builder) < self.max_path_length:
            if isinstance(self.obs_space, Dict):
                if self.get_full_obs_dict:
                    agent_obs = observation
                else:
                    if self.policy_uses_pixels:
                        agent_obs = observation['pixels']
                    else:
                        agent_obs = observation['obs']
            else:
                agent_obs = observation

            action, agent_info = self._get_action_and_info(agent_obs)
            if self.render:
                self.training_env.render()

            next_ob, raw_reward, terminal, env_info = (
                self.training_env.step(action))
            if self.no_terminal:
                terminal = False

            self._n_env_steps_total += 1
            reward = raw_reward
            terminal = np.array([terminal])
            reward = np.array([reward])
            self._handle_step(
                observation,
                action,
                reward,
                next_ob,
                np.array([False]) if self.wrap_absorbing else terminal,
                task_identifier,
                agent_info=agent_info,
                env_info=env_info,
            )
            observation = next_ob

        if terminal and self.wrap_absorbing:
            raise NotImplementedError("I think they used 0 actions for this")
            # NOTE: unreachable until the NotImplementedError above is removed.
            # next_ob is the absorbing state; for now just reuse the action
            # from the previous timestep, as well as agent_info and env_info.
            self._handle_step(
                next_ob,
                action,
                # the reward doesn't matter cause it will be
                # overwritten by the model that defines the reward
                # e.g. the discriminator in GAIL
                reward,
                next_ob,
                terminal,
                task_identifier,
                agent_info=agent_info,
                env_info=env_info)

        self._handle_rollout_ending(task_identifier)

    def train_online(self, start_epoch=0):
        # No need for training mode to be True while generating trajectories:
        # training mode is automatically set to True in _try_to_train and
        # reverted to False before that function returns.
        self.training_mode(False)
        self._current_path_builder = PathBuilder()
        self._n_rollouts_total = 0

        for epoch in gt.timed_for(
                range(start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            self._start_epoch(epoch)
            print('EPOCH STARTED')
            # print('epoch')
            for _ in range(self.num_rollouts_per_epoch):
                # print('rollout')
                task_params, obs_task_params = (
                    self.train_task_params_sampler.sample())
                self.generate_exploration_rollout(
                    task_params=task_params, obs_task_params=obs_task_params)

                # print(self._n_rollouts_total)
                if self._n_rollouts_total % self.num_rollouts_between_updates == 0:
                    gt.stamp('sample')
                    # print('train')
                    if not self.do_not_train: self._try_to_train(epoch)
                    gt.stamp('train')

            if not self.do_not_eval:
                self._try_to_eval(epoch)
                gt.stamp('eval')

            self._end_epoch()

    def _try_to_train(self, epoch):
        if self._can_train():
            self.training_mode(True)
            self._do_training(epoch)
            self._n_train_steps_total += 1
            self.training_mode(False)

    def _try_to_eval(self, epoch):
        if epoch % self.freq_saving == 0:
            logger.save_extra_data(self.get_extra_data_to_save(epoch))
        if self._can_evaluate():
            self.evaluate(epoch)

            if epoch % self.freq_saving == 0:
                params = self.get_epoch_snapshot(epoch)
                logger.save_itr_params(epoch, params)
            table_keys = logger.get_table_key_set()

            # logger.record_tabular(
            #     "Number of train steps total",
            #     self._n_policy_train_steps_total,
            # )
            logger.record_tabular(
                "Number of env steps total",
                self._n_env_steps_total,
            )
            logger.record_tabular(
                "Number of rollouts total",
                self._n_rollouts_total,
            )

            times_itrs = gt.get_times().stamps.itrs
            train_time = times_itrs['train'][-1]
            sample_time = times_itrs['sample'][-1]
            eval_time = times_itrs['eval'][-1] if epoch > 0 else 0
            epoch_time = train_time + sample_time + eval_time
            total_time = gt.get_times().total

            logger.record_tabular('Train Time (s)', train_time)
            logger.record_tabular('(Previous) Eval Time (s)', eval_time)
            logger.record_tabular('Sample Time (s)', sample_time)
            logger.record_tabular('Epoch Time (s)', epoch_time)
            logger.record_tabular('Total Train Time (s)', total_time)

            logger.record_tabular("Epoch", epoch)
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        else:
            logger.log("Skipping eval for now.")

    def _can_evaluate(self):
        """
        One annoying thing about the logger table is that the keys at each
        iteration need to be the exact same. So unless you can compute
        everything, skip evaluation.

        A common example for why you might want to skip evaluation is that at
        the beginning of training, you may not have enough data for a
        validation and training set.

        :return:
        """
        return (len(self._exploration_paths) > 0 and
                self._n_rollouts_total >= self.min_rollouts_before_training)

    def _can_train(self):
        return self._n_rollouts_total >= self.min_rollouts_before_training

    def _get_action_and_info(self, observation):
        """
        Get an action to take in the environment.
        :param observation:
        :return:
        """
        self.exploration_policy.set_num_steps_total(self._n_env_steps_total)
        return self.exploration_policy.get_action(observation)

    def _start_epoch(self, epoch):
        self._epoch_start_time = time.time()
        self._exploration_paths = []
        self._do_train_time = 0
        logger.push_prefix('Iteration #%d | ' % epoch)

    def _end_epoch(self):
        logger.log("Epoch Duration: {0}".format(time.time() -
                                                self._epoch_start_time))
        logger.log("Started Training: {0}".format(self._can_train()))
        logger.pop_prefix()

    def _start_new_rollout(self, task_params=None, obs_task_params=None):
        if self.use_env_getter:
            self.training_env = self.training_env_getter(obs_task_params)
            obs_from_reset = self.training_env.reset()
            observation = self.training_env._get_obs()
        else:
            if task_params is None:
                task_params, obs_task_params = (
                    self.train_task_params_sampler.sample())
            observation = self.training_env.reset(
                task_params=task_params, obs_task_params=obs_task_params)
        task_id = self.training_env.task_identifier

        self.exploration_policy = self.get_exploration_policy(task_id)
        self.exploration_policy.reset()

        return observation, task_id

    def _handle_path(self, path, task_identifier):
        """
        Naive implementation: just loop through each transition.
        :param path:
        :return:
        """
        for (ob, action, reward, next_ob, terminal, agent_info,
             env_info) in zip(
                 path["observations"],
                 path["actions"],
                 path["rewards"],
                 path["next_observations"],
                 path["terminals"],
                 path["agent_infos"],
                 path["env_infos"],
             ):
            self._handle_step(
                ob,
                action,
                reward,
                next_ob,
                terminal,
                task_identifier,
                agent_info=agent_info,
                env_info=env_info,
            )
        self._handle_rollout_ending(task_identifier)

    def _handle_step(
        self,
        observation,
        action,
        reward,
        next_observation,
        terminal,
        task_identifier,
        agent_info,
        env_info,
    ):
        """
        Implement anything that needs to happen after every step
        :return:
        """
        self._current_path_builder.add_all(observations=observation,
                                           actions=action,
                                           rewards=reward,
                                           next_observations=next_observation,
                                           terminals=terminal,
                                           agent_infos=agent_info,
                                           env_infos=env_info,
                                           task_identifiers=task_identifier)
        self.replay_buffer.add_sample(
            observation=observation,
            action=action,
            reward=reward,
            terminal=terminal,
            next_observation=next_observation,
            task_identifier=task_identifier,
            agent_info=agent_info,
            env_info=env_info,
        )

    def _handle_rollout_ending(self, task_identifier):
        """
        Implement anything that needs to happen after every rollout.
        """
        self.replay_buffer.terminate_episode(task_identifier)
        self._n_rollouts_total += 1
        if len(self._current_path_builder) > 0:
            self._exploration_paths.append(
                self._current_path_builder.get_all_stacked())
            self._current_path_builder = PathBuilder()

    def get_epoch_snapshot(self, epoch):
        data_to_save = dict(epoch=epoch, )
        if self.save_environment:
            data_to_save['env'] = self.training_env
        return data_to_save

    def get_extra_data_to_save(self, epoch):
        """
        Save things that shouldn't be saved every snapshot but rather
        overwritten every time.
        :param epoch:
        :return:
        """
        if self.render:
            self.training_env.render(close=True)
        data_to_save = dict(epoch=epoch, )
        if self.save_environment:
            data_to_save['env'] = self.training_env
        if self.save_replay_buffer:
            data_to_save['replay_buffer'] = self.replay_buffer
        if self.save_algorithm:
            data_to_save['algorithm'] = self
        return data_to_save

    @abc.abstractmethod
    def get_exploration_policy(self, task_identifier):
        '''
            Since for each task a meta-irl algorithm needs to make use of some
            expert demonstrations, this is a convenience method for getting a
            version of the policy that handles this internally.

            Example:
            In the neural process meta-irl method, for a given task we need to
            take some demonstrations, infer the posterior, sample from the posterior,
            then condition the policy by concatenating the sample to any observations
            that are passed to the policy. So internally, in np_bc and np_airl, when
            we call get_exploration_policy we set the latent sample for a
            PostCondReparamTanhMultivariateGaussianPolicy and return that. From then on,
            whenever we call get_action on the policy, it internally concatenates the
            latent to the observation passed to it. (A simplified standalone sketch
            of such a latent-conditioned wrapper is given after this class.)
        '''
        pass

    @abc.abstractmethod
    def get_eval_policy(self, task_identifier):
        '''
            Since for each task a meta-irl algorithm needs to make use of some
            expert demonstrations, this is a convenience method for getting a
            version of the policy that handles this internally.

            Example:
            In the neural process meta-irl method, for a given task we need to
            take some demonstrations, infer the posterior, sample from the posterior,
            then condition the policy by concatenating the sample to any observations
            that are passed to the policy. So internally, in np_bc and np_airl, when
            we call get_eval_policy we set the latent sample for a
            PostCondReparamTanhMultivariateGaussianPolicy and return that. From then on,
            whenever we call get_action on the policy, it internally concatenates the
            latent to the observation passed to it.
        '''
        pass

    @abc.abstractmethod
    def obtain_eval_samples(self, epoch):
        pass

    @abc.abstractmethod
    def training_mode(self, mode):
        """
        Set training mode to `mode`.
        :param mode: If True, training will happen (e.g. set the dropout
        probabilities to not all ones).
        """
        pass

    @abc.abstractmethod
    def cuda(self):
        """
        Turn cuda on.
        :return:
        """
        pass

    @abc.abstractmethod
    def cpu(self):
        """
        Turn cuda off.
        :return:
        """
        pass

    @abc.abstractmethod
    def evaluate(self, epoch):
        """
        Evaluate the policy, e.g. save/print progress.
        :param epoch:
        :return:
        """
        pass

    @abc.abstractmethod
    def _do_training(self):
        """
        Perform some update, e.g. perform one gradient step.
        :return:
        """
        pass
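The get_exploration_policy and get_eval_policy docstrings above describe conditioning a policy on a latent sample by concatenating it to every observation passed to get_action. Below is a simplified standalone sketch of that idea (an illustration only, not the actual PostCondReparamTanhMultivariateGaussianPolicy); RandomBasePolicy and LatentConditionedPolicy are hypothetical names, and only the get_action(obs) -> (action, agent_info) interface mirrors the code above.

# Hypothetical illustration of a latent-conditioned policy wrapper.
import numpy as np


class RandomBasePolicy:
    """Stand-in for a policy exposing get_action(obs) -> (action, info)."""

    def __init__(self, action_dim):
        self.action_dim = action_dim

    def get_action(self, obs):
        return np.random.uniform(-1, 1, self.action_dim), {}


class LatentConditionedPolicy:
    """Concatenates a fixed latent sample to every observation."""

    def __init__(self, base_policy):
        self.base_policy = base_policy
        self.latent = None

    def set_latent(self, z):
        self.latent = np.asarray(z)

    def get_action(self, obs):
        conditioned_obs = np.concatenate([np.asarray(obs), self.latent])
        return self.base_policy.get_action(conditioned_obs)

    def reset(self):
        pass


if __name__ == '__main__':
    policy = LatentConditionedPolicy(RandomBasePolicy(action_dim=3))
    policy.set_latent(np.zeros(8))  # e.g. a posterior sample for one task
    action, agent_info = policy.get_action(np.ones(11))
    print(action.shape)  # (3,)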
Example #20
0
    def obtain_eval_samples(self, epoch, mode='meta_train'):
        self.training_mode(False)
        self.policy.eval()

        
        if mode == 'meta_train':
            params_samples = self.train_task_params_sampler.sample_unique(self.num_tasks_per_eval)
        else:
            params_samples = self.test_task_params_sampler.sample_unique(self.num_tasks_per_eval)
        all_eval_tasks_paths = []
        eval_task_num = -1
        for task_params, obs_task_params in params_samples:
            eval_task_num += 1
            saved_task_gif = False
            cur_eval_task_paths = []
            if mode == 'meta_train':
                self.env = self.training_env_getter(obs_task_params)
            else:
                self.env = self.test_env_getter(obs_task_params)
            self.env.reset()
            task_identifier = self.env.task_identifier

            for _ in range(self.num_diff_context_per_eval_task):
                eval_policy, context = self.get_eval_policy(task_identifier, mode=mode, return_context=True)

                for _ in range(self.num_eval_trajs_per_post_sample):
                    cur_eval_path_builder = PathBuilder()
                    observation = self.env.reset()
                    # from scipy.misc import imsave
                    # imsave('plots/junk_vis/val_check_obtain_eval.png', observation['image'].transpose(1,2,0))
                    # if mode == 'meta_test':
                    #     1/0
                    terminal = False

                    while (not terminal) and len(cur_eval_path_builder) < self.max_path_length:
                        agent_obs = observation
                        action, agent_info = self._get_action_and_info(agent_obs)

                        # print(self.env)
                        # print(action)
                        next_ob, raw_reward, terminal, env_info = (self.env.step(action))
                        if self.no_terminal:
                            terminal = False
                        
                        reward = raw_reward
                        terminal = np.array([terminal])
                        reward = np.array([reward])
                        cur_eval_path_builder.add_all(
                            observations=observation,
                            actions=action,
                            rewards=reward,
                            next_observations=next_ob,
                            terminals=terminal,
                            agent_infos=agent_info,
                            env_infos=env_info,
                            task_identifiers=task_identifier
                        )
                        observation = next_ob

                    if terminal and self.wrap_absorbing:
                        raise NotImplementedError("I think they used 0 actions for this")
                        cur_eval_path_builder.add_all(
                            observations=next_ob,
                            actions=action,
                            rewards=reward,
                            next_observations=next_ob,
                            terminals=terminal,
                            agent_infos=agent_info,
                            env_infos=env_info,
                            task_identifiers=task_identifier
                        )
                    
                    if len(cur_eval_path_builder) > 0:
                        cur_eval_task_paths.append(
                            cur_eval_path_builder.get_all_stacked()
                        )
                        if not saved_task_gif:
                            saved_task_gif = True
                            if eval_task_num < 2:
                                path = cur_eval_task_paths[-1]
                                gif_frames = [d["image"]for d in path["observations"]]
                                for frame_num, frame in enumerate(gif_frames):
                                    if frame_num % 4 == 3:
                                        imsave(osp.join(self.log_dir, mode+'task_%d_frame_%d.png'%(eval_task_num, frame_num)), frame.transpose(1,2,0))
                                # print(gif_frames)
                                # for img in gif_frames:
                                #     print(np.max(img), np.min(img))
                                # write_gif(gif_frames, osp.join(self.log_dir, mode+'_%d.gif'%eval_task_num) , fps=20)
                                if self.easy_context or self.last_image_is_context:
                                    context_img = ptu.get_numpy(context)[0].transpose(1,2,0)
                                    imsave(osp.join(self.log_dir, mode+'task_%d_context_%d.png'%(eval_task_num, eval_task_num)), context_img)
                                if self.using_all_context:
                                    context_img = ptu.get_numpy(context['image'][0,-1]).transpose(1,2,0)
                                    imsave(osp.join(self.log_dir, mode+'task_%d_context_%d.png'%(eval_task_num, eval_task_num)), context_img)
                                print('Saved the gifs')
            all_eval_tasks_paths.extend(cur_eval_task_paths)
        
        # all_eval_tasks_paths is already a flat list of paths (extended per task above)
        self.policy.train()
        return all_eval_tasks_paths
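obtain_eval_samples above saves every fourth frame of the first two eval tasks as PNGs, while the write_gif call is left commented out. If an animated GIF is wanted instead, the same frames can be bundled with Pillow. This is a rough sketch under the assumption (suggested by the transpose(1, 2, 0) calls above) that each frame is a (C, H, W) float array with values in [0, 1]; frames_to_gif is a hypothetical helper, not part of the original code.

# Hedged sketch: bundle (C, H, W) float frames into an animated GIF.
import numpy as np
from PIL import Image


def frames_to_gif(frames, out_path, ms_per_frame=50):
    images = [
        Image.fromarray((np.transpose(f, (1, 2, 0)) * 255).astype(np.uint8))
        for f in frames
    ]
    images[0].save(out_path, save_all=True, append_images=images[1:],
                   duration=ms_per_frame, loop=0)


if __name__ == '__main__':
    dummy_frames = [np.random.rand(3, 32, 32) for _ in range(8)]
    frames_to_gif(dummy_frames, 'rollout.gif')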
Example #21
0
class NPMetaRLAlgorithm(metaclass=abc.ABCMeta):
    def __init__(
            self,
            env_sampler,
            exploration_policy: ExplorationPolicy,
            neural_process,
            train_neural_process=False,
            latent_repr_mode='concat_params',  # OR concat_samples
            num_latent_samples=5,
            num_epochs=100,
            num_steps_per_epoch=10000,
            num_steps_per_eval=1000,
            num_updates_per_env_step=1,
            batch_size=1024,
            max_path_length=1000,
            discount=0.99,
            replay_buffer_size=1000000,
            reward_scale=1,
            render=False,
            save_replay_buffer=False,
            save_algorithm=False,
            save_environment=False,
            eval_sampler=None,
            eval_policy=None,
            replay_buffer=None,
            epoch_to_start_training=0):
        """
        Base class for RL Algorithms
        :param env: Environment used to evaluate.
        :param exploration_policy: Policy used to explore
        :param training_env: Environment used by the algorithm. By default, a
        copy of `env` will be made.
        :param num_epochs:
        :param num_steps_per_epoch:
        :param num_steps_per_eval:
        :param num_updates_per_env_step: Used by online training mode.
        :param num_updates_per_epoch: Used by batch training mode.
        :param batch_size:
        :param max_path_length:
        :param discount:
        :param replay_buffer_size:
        :param reward_scale:
        :param render:
        :param save_replay_buffer:
        :param save_algorithm:
        :param save_environment:
        :param eval_sampler:
        :param eval_policy: Policy to evaluate with.
        :param replay_buffer:
        """
        assert not train_neural_process, 'Have not implemented it yet! Remember to set it to train mode when training'
        self.neural_process = neural_process
        self.neural_process.set_mode('eval')
        self.latent_repr_mode = latent_repr_mode
        self.num_latent_samples = num_latent_samples
        self.env_sampler = env_sampler
        env, env_specs = env_sampler()
        self.training_env, _ = env_sampler(env_specs)
        # self.training_env = training_env or pickle.loads(pickle.dumps(env))
        # self.training_env = training_env or deepcopy(env)
        self.exploration_policy = exploration_policy
        self.num_epochs = num_epochs
        self.num_env_steps_per_epoch = num_steps_per_epoch
        self.num_steps_per_eval = num_steps_per_eval
        self.num_updates_per_train_call = num_updates_per_env_step
        self.batch_size = batch_size
        self.max_path_length = max_path_length
        self.discount = discount
        self.replay_buffer_size = replay_buffer_size
        self.reward_scale = reward_scale
        self.render = render
        self.save_replay_buffer = save_replay_buffer
        self.save_algorithm = save_algorithm
        self.save_environment = save_environment
        self.epoch_to_start_training = epoch_to_start_training

        if self.latent_repr_mode == 'concat_params':

            def get_latent_repr(posterior_state):
                z_mean, z_cov = self.neural_process.get_posterior_params(
                    posterior_state)
                return np.concatenate([z_mean, z_cov])

            self.extra_obs_dim = 2 * self.neural_process.z_dim
        else:

            def get_latent_repr(posterior_state):
                z_mean, z_cov = self.neural_process.get_posterior_params(
                    posterior_state)
                samples = np.random.multivariate_normal(
                    z_mean, np.diag(z_cov), self.num_latent_samples)
                samples = samples.flatten()
                return samples

            self.extra_obs_dim = self.num_latent_samples * self.neural_process.z_dim
        self.get_latent_repr = get_latent_repr

        if eval_sampler is None:
            if eval_policy is None:
                eval_policy = exploration_policy
            eval_sampler = InPlacePathSampler(
                env=env,
                policy=eval_policy,
                max_samples=self.num_steps_per_eval + self.max_path_length,
                max_path_length=self.max_path_length,
                neural_process=neural_process,
                latent_repr_fn=get_latent_repr,
                reward_scale=reward_scale)
        self.eval_policy = eval_policy
        self.eval_sampler = eval_sampler

        self.action_space = env.action_space
        self.obs_space = env.observation_space

        self.env = env
        obs_space_dim = gym_get_dim(self.obs_space)
        act_space_dim = gym_get_dim(self.action_space)
        if replay_buffer is None:
            replay_buffer = SimpleReplayBuffer(
                self.replay_buffer_size,
                obs_space_dim + self.extra_obs_dim,
                act_space_dim,
                discrete_action_dim=isinstance(self.action_space, Discrete))
        self.replay_buffer = replay_buffer

        self._n_env_steps_total = 0
        self._n_train_steps_total = 0
        self._n_rollouts_total = 0
        self._do_train_time = 0
        self._epoch_start_time = None
        self._algo_start_time = None
        self._old_table_keys = None
        self._current_path_builder = PathBuilder()
        self._exploration_paths = []

    def train(self, start_epoch=0):
        self.pretrain()
        if start_epoch == 0:
            params = self.get_epoch_snapshot(-1)
            logger.save_itr_params(-1, params)
        self.training_mode(False)
        self._n_env_steps_total = start_epoch * self.num_env_steps_per_epoch
        gt.reset()
        gt.set_def_unique(False)
        self.train_online(start_epoch=start_epoch)

    def pretrain(self):
        """
        Do anything before the main training phase.
        """
        pass

    def train_online(self, start_epoch=0):
        self._current_path_builder = PathBuilder()
        observation = self._start_new_rollout()
        for epoch in gt.timed_for(
                range(start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            self._start_epoch(epoch)
            for _ in range(self.num_env_steps_per_epoch):
                action, agent_info = self._get_action_and_info(observation)
                if self.render:
                    self.training_env.render()
                next_ob, raw_reward, terminal, env_info = (
                    self.training_env.step(action))
                self._n_env_steps_total += 1
                reward = raw_reward * self.reward_scale
                terminal = np.array([terminal])
                reward = np.array([reward])

                self.posterior_state = self.neural_process.update_posterior_state(
                    self.posterior_state, observation[self.extra_obs_dim:],
                    action, reward, next_ob)
                next_ob = np.concatenate(
                    [self.get_latent_repr(self.posterior_state), next_ob])

                self._handle_step(
                    observation,
                    action,
                    reward,
                    next_ob,
                    terminal,
                    agent_info=agent_info,
                    env_info=env_info,
                )
                if terminal or len(
                        self._current_path_builder) >= self.max_path_length:
                    self._handle_rollout_ending()
                    observation = self._start_new_rollout()
                else:
                    observation = next_ob

                gt.stamp('sample')
                if epoch >= self.epoch_to_start_training:
                    self._try_to_train()
                gt.stamp('train')

            if epoch >= self.epoch_to_start_training:
                self._try_to_eval(epoch)
            gt.stamp('eval')
            self._end_epoch()

    def _try_to_train(self):
        if self._can_train():
            self.training_mode(True)
            for i in range(self.num_updates_per_train_call):
                self._do_training()
                self._n_train_steps_total += 1
            self.training_mode(False)

    def _try_to_eval(self, epoch):
        logger.save_extra_data(self.get_extra_data_to_save(epoch))
        if self._can_evaluate():
            self.evaluate(epoch)

            params = self.get_epoch_snapshot(epoch)
            logger.save_itr_params(epoch, params)
            table_keys = logger.get_table_key_set()
            if self._old_table_keys is not None:
                print('$$$$$$$$$$$$$$$')
                print(table_keys)
                print('\n' * 4)
                print(self._old_table_keys)
                print('$$$$$$$$$$$$$$$')
                print(set(table_keys) - set(self._old_table_keys))
                print(set(self._old_table_keys) - set(table_keys))
                assert table_keys == self._old_table_keys, (
                    "Table keys cannot change from iteration to iteration.")
            self._old_table_keys = table_keys

            logger.record_tabular(
                "Number of train steps total",
                self._n_train_steps_total,
            )
            logger.record_tabular(
                "Number of env steps total",
                self._n_env_steps_total,
            )
            logger.record_tabular(
                "Number of rollouts total",
                self._n_rollouts_total,
            )

            times_itrs = gt.get_times().stamps.itrs
            train_time = times_itrs['train'][-1]
            sample_time = times_itrs['sample'][-1]
            eval_time = times_itrs['eval'][-1] if epoch > 0 else 0
            epoch_time = train_time + sample_time + eval_time
            total_time = gt.get_times().total

            logger.record_tabular('Train Time (s)', train_time)
            logger.record_tabular('(Previous) Eval Time (s)', eval_time)
            logger.record_tabular('Sample Time (s)', sample_time)
            logger.record_tabular('Epoch Time (s)', epoch_time)
            logger.record_tabular('Total Train Time (s)', total_time)

            logger.record_tabular("Epoch", epoch)
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        else:
            logger.log("Skipping eval for now.")

    def _can_evaluate(self):
        """
        One annoying thing about the logger table is that the keys at each
        iteration need to be the exact same. So unless you can compute
        everything, skip evaluation.

        A common example for why you might want to skip evaluation is that at
        the beginning of training, you may not have enough data for a
        validation and training set.

        :return:
        """
        return (len(self._exploration_paths) > 0 and
                self.replay_buffer.num_steps_can_sample() >= self.batch_size)

    def _can_train(self):
        return self.replay_buffer.num_steps_can_sample() >= self.batch_size

    def _get_action_and_info(self, observation):
        """
        Get an action to take in the environment.
        :param observation:
        :return:
        """
        self.exploration_policy.set_num_steps_total(self._n_env_steps_total)
        return self.exploration_policy.get_action(observation)

    def _start_epoch(self, epoch):
        self._epoch_start_time = time.time()
        self._exploration_paths = []
        self._do_train_time = 0
        logger.push_prefix('Iteration #%d | ' % epoch)

    def _end_epoch(self):
        logger.log("Epoch Duration: {0}".format(time.time() -
                                                self._epoch_start_time))
        logger.log("Started Training: {0}".format(self._can_train()))
        logger.pop_prefix()

    def _start_new_rollout(self):
        self.exploration_policy.reset()
        self.env, env_specs = self.env_sampler()
        self.training_env, _ = self.env_sampler(env_specs)

        obs = self.training_env.reset()

        self.posterior_state = self.neural_process.reset_posterior_state()
        latent_repr = self.get_latent_repr(self.posterior_state)

        obs = np.concatenate([latent_repr, obs])
        return obs

    def _handle_path(self, path):
        """
        Naive implementation: just loop through each transition.
        :param path:
        :return:
        """
        for (ob, action, reward, next_ob, terminal, agent_info,
             env_info) in zip(
                 path["observations"],
                 path["actions"],
                 path["rewards"],
                 path["next_observations"],
                 path["terminals"],
                 path["agent_infos"],
                 path["env_infos"],
             ):
            self._handle_step(
                ob,
                action,
                reward,
                next_ob,
                terminal,
                agent_info=agent_info,
                env_info=env_info,
            )
        self._handle_rollout_ending()

    def _handle_step(
        self,
        observation,
        action,
        reward,
        next_observation,
        terminal,
        agent_info,
        env_info,
    ):
        """
        Implement anything that needs to happen after every step
        :return:
        """
        self._current_path_builder.add_all(
            observations=observation,
            actions=action,
            rewards=reward,
            next_observations=next_observation,
            terminals=terminal,
            agent_infos=agent_info,
            env_infos=env_info,
        )
        self.replay_buffer.add_sample(
            observation=observation,
            action=action,
            reward=reward,
            terminal=terminal,
            next_observation=next_observation,
            agent_info=agent_info,
            env_info=env_info,
        )

    def _handle_rollout_ending(self):
        """
        Implement anything that needs to happen after every rollout.
        """
        self.replay_buffer.terminate_episode()
        self._n_rollouts_total += 1
        if len(self._current_path_builder) > 0:
            self._exploration_paths.append(
                self._current_path_builder.get_all_stacked())
            self._current_path_builder = PathBuilder()

    def get_epoch_snapshot(self, epoch):
        data_to_save = dict(
            epoch=epoch,
            exploration_policy=self.exploration_policy,
        )
        if self.save_environment:
            data_to_save['env'] = self.training_env
        return data_to_save

    def get_extra_data_to_save(self, epoch):
        """
        Save things that shouldn't be saved every snapshot but rather
        overwritten every time.
        :param epoch:
        :return:
        """
        if self.render:
            self.training_env.render(close=True)
        data_to_save = dict(epoch=epoch, )
        if self.save_environment:
            data_to_save['env'] = self.training_env
        if self.save_replay_buffer:
            data_to_save['replay_buffer'] = self.replay_buffer
        if self.save_algorithm:
            data_to_save['algorithm'] = self
        return data_to_save

    @abc.abstractmethod
    def training_mode(self, mode):
        """
        Set training mode to `mode`.
        :param mode: If True, training will happen (e.g. set the dropout
        probabilities to not all ones).
        """
        pass

    @abc.abstractmethod
    def cuda(self):
        """
        Turn cuda on.
        :return:
        """
        pass

    @abc.abstractmethod
    def evaluate(self, epoch):
        """
        Evaluate the policy, e.g. save/print progress.
        :param epoch:
        :return:
        """
        pass

    @abc.abstractmethod
    def _do_training(self):
        """
        Perform some update, e.g. perform one gradient step.
        :return:
        """
        pass
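In the constructor above, latent_repr_mode chooses between appending the posterior parameters (mean and diagonal covariance, giving extra_obs_dim = 2 * z_dim) and appending num_latent_samples flattened samples drawn from that posterior (extra_obs_dim = num_latent_samples * z_dim). The short standalone check below reproduces the two shapes, with a toy (z_mean, z_cov) pair standing in for neural_process.get_posterior_params.

# Toy shape check for the two latent-representation modes; all values here
# are placeholders for illustration only.
import numpy as np

z_dim = 4
num_latent_samples = 5
z_mean, z_cov = np.zeros(z_dim), np.ones(z_dim)

concat_params = np.concatenate([z_mean, z_cov])
samples = np.random.multivariate_normal(z_mean, np.diag(z_cov),
                                        num_latent_samples).flatten()

assert concat_params.shape == (2 * z_dim,)
assert samples.shape == (num_latent_samples * z_dim,)
print(concat_params.shape, samples.shape)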
Example #22
0
class MetaRLAlgorithm(metaclass=abc.ABCMeta):
    def __init__(
        self,
        env,
        policy,
        train_tasks,
        eval_tasks,
        meta_batch=64,
        num_iterations=100,
        num_train_steps_per_itr=1000,
        num_tasks_sample=100,
        num_steps_per_task=100,
        num_evals=10,
        num_steps_per_eval=1000,
        batch_size=1024,
        embedding_batch_size=1024,
        embedding_mini_batch_size=1024,
        max_path_length=1000,
        discount=0.99,
        replay_buffer_size=1000000,
        reward_scale=1,
        train_embedding_source='posterior_only',
        eval_embedding_source='initial_pool',
        eval_deterministic=True,
        render=False,
        save_replay_buffer=False,
        save_algorithm=False,
        save_environment=False,
    ):
        """
        Base class for Meta RL Algorithms
        :param env: training env
        :param policy: policy that is conditioned on a latent variable z that rl_algorithm is responsible for feeding in
        :param train_tasks: list of tasks used for training
        :param eval_tasks: list of tasks used for eval
        :param meta_batch: number of tasks used for meta-update
        :param num_iterations: number of meta-updates taken
        :param num_train_steps_per_itr: number of meta-updates performed per iteration
        :param num_tasks_sample: number of train tasks to sample to collect data for
        :param num_steps_per_task: number of transitions to collect per task
        :param num_evals: number of independent evaluation runs, with separate task encodings
        :param num_steps_per_eval: number of transitions to sample for evaluation
        :param batch_size: size of batches used to compute RL update
        :param embedding_batch_size: size of batches used to compute embedding
        :param embedding_mini_batch_size: size of batch used for encoder update
        :param max_path_length: max episode length
        :param discount:
        :param replay_buffer_size: max replay buffer size
        :param reward_scale:
        :param render:
        :param save_replay_buffer:
        :param save_algorithm:
        :param save_environment:
        """
        self.env = env
        self.policy = policy
        self.exploration_policy = policy  # Can potentially use a different policy purely for exploration rather than also solving tasks, currently not being used
        self.train_tasks = train_tasks
        self.eval_tasks = eval_tasks
        self.meta_batch = meta_batch
        self.num_iterations = num_iterations
        self.num_train_steps_per_itr = num_train_steps_per_itr
        self.num_tasks_sample = num_tasks_sample
        self.num_steps_per_task = num_steps_per_task
        self.num_evals = num_evals
        self.num_steps_per_eval = num_steps_per_eval
        self.batch_size = batch_size
        self.embedding_batch_size = embedding_batch_size
        self.embedding_mini_batch_size = embedding_mini_batch_size
        self.max_path_length = max_path_length
        self.discount = discount
        self.replay_buffer_size = replay_buffer_size
        self.reward_scale = reward_scale
        self.train_embedding_source = train_embedding_source
        self.eval_embedding_source = eval_embedding_source  # TODO: add options for computing embeddings on train tasks too
        self.eval_deterministic = eval_deterministic
        self.render = render
        self.save_replay_buffer = save_replay_buffer
        self.save_algorithm = save_algorithm
        self.save_environment = save_environment

        self.eval_sampler = InPlacePathSampler(
            env=env,
            policy=policy,
            max_samples=self.num_steps_per_eval,
            max_path_length=self.max_path_length,
        )

        # separate replay buffers for
        # - training RL update
        # - training encoder update
        # - testing encoder
        self.replay_buffer = MultiTaskReplayBuffer(
            self.replay_buffer_size,
            env,
            self.train_tasks,
        )

        self.enc_replay_buffer = MultiTaskReplayBuffer(
            self.replay_buffer_size,
            env,
            self.train_tasks,
        )
        self.eval_enc_replay_buffer = MultiTaskReplayBuffer(
            self.replay_buffer_size, env, self.eval_tasks)

        self._n_env_steps_total = 0
        self._n_train_steps_total = 0
        self._n_rollouts_total = 0
        self._do_train_time = 0
        self._epoch_start_time = None
        self._algo_start_time = None
        self._old_table_keys = None
        self._current_path_builder = PathBuilder()
        self._exploration_paths = []

    def make_exploration_policy(self, policy):
        return policy

    def make_eval_policy(self, policy):
        return policy

    def sample_task(self, is_eval=False):
        '''
        sample task randomly
        '''
        if is_eval:
            idx = np.random.randint(len(self.eval_tasks))
        else:
            idx = np.random.randint(len(self.train_tasks))
        return idx

    def train(self):
        '''
        meta-training loop
        '''
        self.pretrain()
        params = self.get_epoch_snapshot(-1)
        logger.save_itr_params(-1, params)
        gt.reset()
        gt.set_def_unique(False)
        self._current_path_builder = PathBuilder()
        self.train_obs = self._start_new_rollout()

        # at each iteration, we first collect data from tasks, perform meta-updates, then try to evaluate
        for it_ in gt.timed_for(
                range(self.num_iterations),
                save_itrs=True,
        ):
            self._start_epoch(it_)
            self.training_mode(True)
            if it_ == 0:
                print('collecting initial pool of data for train and eval')
                # temp for evaluating
                for idx in self.train_tasks:
                    self.task_idx = idx
                    self.env.reset_task(idx)
                    self.collect_data_sampling_from_prior(
                        num_samples=self.max_path_length * 10,
                        resample_z_every_n=self.max_path_length,
                        eval_task=False)
                """
                for idx in self.eval_tasks:
                    self.task_idx = idx
                    self.env.reset_task(idx)
                    # TODO: make number of initial trajectories a parameter
                    self.collect_data_sampling_from_prior(num_samples=self.max_path_length * 20,
                                                          resample_z_every_n=self.max_path_length,
                                                          eval_task=True)
                """

            # Sample data from train tasks.
            for i in range(self.num_tasks_sample):
                idx = np.random.randint(len(self.train_tasks))
                self.task_idx = idx
                self.env.reset_task(idx)

                # TODO: there may be more permutations of sampling/adding to encoding buffer we may wish to try
                if self.train_embedding_source == 'initial_pool':
                    # embeddings are computed using only the initial pool of data
                    # sample data from posterior to train RL algorithm
                    self.collect_data_from_task_posterior(
                        idx=idx,
                        num_samples=self.num_steps_per_task,
                        add_to_enc_buffer=False)
                elif self.train_embedding_source == 'posterior_only':
                    self.collect_data_from_task_posterior(
                        idx=idx,
                        num_samples=self.num_steps_per_task,
                        eval_task=False,
                        add_to_enc_buffer=True)
                elif self.train_embedding_source == 'online_exploration_trajectories':
                    # embeddings are computed using only data collected using the prior
                    # sample data from posterior to train RL algorithm
                    self.enc_replay_buffer.task_buffers[idx].clear()
                    # resamples using current policy, conditioned on prior
                    self.collect_data_sampling_from_prior(
                        num_samples=self.num_steps_per_task,
                        resample_z_every_n=self.max_path_length,
                        add_to_enc_buffer=True)

                    self.collect_data_from_task_posterior(
                        idx=idx,
                        num_samples=self.num_steps_per_task,
                        add_to_enc_buffer=False)
                elif self.train_embedding_source == 'online_on_policy_trajectories':
                    # sample from prior, then sample more from the posterior
                    # embeddings computed from both prior and posterior data
                    self.enc_replay_buffer.task_buffers[idx].clear()
                    self.collect_data_online(
                        idx=idx,
                        num_samples=self.num_steps_per_task,
                        add_to_enc_buffer=True)
                else:
                    raise Exception(
                        "Invalid option for computing train embedding {}".
                        format(self.train_embedding_source))

            # Sample train tasks and compute gradient updates on parameters.
            for train_step in range(self.num_train_steps_per_itr):
                indices = np.random.choice(self.train_tasks, self.meta_batch)
                self._do_training(indices)
                self._n_train_steps_total += 1
            gt.stamp('train')

            #self.training_mode(False)

            # eval
            self._try_to_eval(it_)
            gt.stamp('eval')

            self._end_epoch()

    def pretrain(self):
        """
        Do anything before the main training phase.
        """
        pass

    def sample_z_from_prior(self):
        """
        Samples z from the prior distribution, which can be either a delta function at 0 or a standard Gaussian
        depending on whether we use the information bottleneck.
        :return: latent z as a Numpy array
        """
        pass
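    # A minimal sketch (an assumption, not part of the original class) of what a
    # concrete subclass might return here, assuming a hypothetical
    # `self.latent_dim` attribute and `self.use_information_bottleneck` flag:
    #
    #     def sample_z_from_prior(self):
    #         if self.use_information_bottleneck:
    #             # standard Gaussian prior over z
    #             return np.random.normal(size=self.latent_dim)
    #         # no information bottleneck: delta function at 0
    #         return np.zeros(self.latent_dim)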

    def sample_z_from_posterior(self, idx, eval_task):
        """
        Samples z from the posterior distribution given data from task idx, where data comes from the encoding buffer
        :param idx: index of the task from which to compute the posterior
        :param eval_task: whether or not the task is an eval task
        :return: latent z as a Numpy array
        """
        pass
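    # A minimal sketch (an assumption) of a concrete implementation: draw context
    # from the task's encoding buffer and condition the policy's latent on it.
    # `random_batch`, `embedding_batch_size`, `infer_posterior` and `sample_z`
    # are hypothetical names, not defined in this class:
    #
    #     def sample_z_from_posterior(self, idx, eval_task):
    #         buffer = (self.eval_enc_replay_buffer if eval_task
    #                   else self.enc_replay_buffer)
    #         context = buffer.random_batch(idx, self.embedding_batch_size)
    #         self.policy.infer_posterior(context)
    #         self.policy.sample_z()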

    # TODO: maybe find a better name for resample_z_every_n?
    def collect_data_sampling_from_prior(self,
                                         num_samples=1,
                                         resample_z_every_n=None,
                                         eval_task=False,
                                         add_to_enc_buffer=True):
        # do not resample z if resample_z_every_n is None
        if resample_z_every_n is None:
            self.policy.clear_z()
            self.collect_data(self.policy,
                              num_samples=num_samples,
                              eval_task=eval_task,
                              add_to_enc_buffer=add_to_enc_buffer)
        else:
            # collects more data in batches of resample_z_every_n until done
            while num_samples > 0:
                self.collect_data_sampling_from_prior(
                    num_samples=min(resample_z_every_n, num_samples),
                    resample_z_every_n=None,
                    eval_task=eval_task,
                    add_to_enc_buffer=add_to_enc_buffer)
                num_samples -= resample_z_every_n
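    # Worked example of the batching above: calling
    # collect_data_sampling_from_prior(num_samples=1000, resample_z_every_n=200)
    # makes 5 recursive calls of 200 steps each, resampling z from the prior
    # before every batch. If num_samples is not a multiple of
    # resample_z_every_n, the last call collects only the remainder
    # (min(resample_z_every_n, num_samples)).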

    def collect_data_from_task_posterior(self,
                                         idx,
                                         num_samples=1,
                                         resample_z_every_n=None,
                                         eval_task=False,
                                         add_to_enc_buffer=True):
        # do not resample z if resample_z_every_n is None
        if resample_z_every_n is None:
            self.sample_z_from_posterior(idx, eval_task=eval_task)
            self.collect_data(self.policy,
                              num_samples=num_samples,
                              eval_task=eval_task,
                              add_to_enc_buffer=add_to_enc_buffer)
        else:
            # collects more data in batches of resample_z_every_n until done
            while num_samples > 0:
                self.collect_data_from_task_posterior(
                    idx=idx,
                    num_samples=min(resample_z_every_n, num_samples),
                    resample_z_every_n=None,
                    eval_task=eval_task,
                    add_to_enc_buffer=add_to_enc_buffer)
                num_samples -= resample_z_every_n

    # collect num_samples steps from the prior and num_samples steps from the task posterior (2 * num_samples in total)
    def collect_data_online(self,
                            idx,
                            num_samples,
                            eval_task=False,
                            add_to_enc_buffer=True):
        self.collect_data_sampling_from_prior(
            num_samples=num_samples,
            resample_z_every_n=self.max_path_length,
            eval_task=eval_task,
            add_to_enc_buffer=True)
        self.collect_data_from_task_posterior(
            idx=idx,
            num_samples=num_samples,
            resample_z_every_n=self.max_path_length,
            eval_task=eval_task,
            add_to_enc_buffer=add_to_enc_buffer)

    # TODO: since switching tasks now resets the environment, episodes are not being terminated correctly.
    # We also aren't using the episodes anywhere, but we should probably change this to gather paths until
    # we have more samples than num_samples, so that every episode terminates cleanly when intended.
    def collect_data(self,
                     agent,
                     num_samples=1,
                     eval_task=False,
                     add_to_enc_buffer=True):
        '''
        collect data from current env in batch mode
        with given policy
        '''
        for _ in range(num_samples):
            action, agent_info = self._get_action_and_info(
                agent, self.train_obs)
            if self.render:
                self.env.render()
            next_ob, raw_reward, terminal, env_info = (self.env.step(action))
            reward = raw_reward
            terminal = np.array([terminal])
            reward = np.array([reward])
            self._handle_step(
                self.task_idx,
                self.train_obs,
                action,
                reward,
                next_ob,
                terminal,
                eval_task=eval_task,
                add_to_enc_buffer=add_to_enc_buffer,
                agent_info=agent_info,
                env_info=env_info,
            )
            if terminal or len(
                    self._current_path_builder) >= self.max_path_length:
                self._handle_rollout_ending(eval_task=eval_task)
                self.train_obs = self._start_new_rollout()
            else:
                self.train_obs = next_ob

        if not eval_task:
            self._n_env_steps_total += num_samples
            gt.stamp('sample')

    def _try_to_eval(self, epoch):
        logger.save_extra_data(self.get_extra_data_to_save(epoch))
        if self._can_evaluate():
            self.evaluate(epoch)

            params = self.get_epoch_snapshot(epoch)
            logger.save_itr_params(epoch, params)
            table_keys = logger.get_table_key_set()
            if self._old_table_keys is not None:
                assert table_keys == self._old_table_keys, (
                    "Table keys cannot change from iteration to iteration.")
            self._old_table_keys = table_keys

            logger.record_tabular(
                "Number of train steps total",
                self._n_train_steps_total,
            )
            logger.record_tabular(
                "Number of env steps total",
                self._n_env_steps_total,
            )
            logger.record_tabular(
                "Number of rollouts total",
                self._n_rollouts_total,
            )

            times_itrs = gt.get_times().stamps.itrs
            train_time = times_itrs['train'][-1]
            sample_time = times_itrs['sample'][-1]
            eval_time = times_itrs['eval'][-1] if epoch > 0 else 0
            epoch_time = train_time + sample_time + eval_time
            total_time = gt.get_times().total

            logger.record_tabular('Train Time (s)', train_time)
            logger.record_tabular('(Previous) Eval Time (s)', eval_time)
            logger.record_tabular('Sample Time (s)', sample_time)
            logger.record_tabular('Epoch Time (s)', epoch_time)
            logger.record_tabular('Total Train Time (s)', total_time)

            logger.record_tabular("Epoch", epoch)
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        else:
            logger.log("Skipping eval for now.")

    def _can_evaluate(self):
        """
        One annoying thing about the logger table is that the keys at each
        iteration need to be the exact same. So unless you can compute
        everything, skip evaluation.

        A common example for why you might want to skip evaluation is that at
        the beginning of training, you may not have enough data for a
        validation and training set.

        :return:
        """
        return (len(self._exploration_paths) > 0
                and self.replay_buffer.num_steps_can_sample(
                    self.task_idx) >= self.batch_size)

    def _can_train(self):
        return all([
            self.replay_buffer.num_steps_can_sample(idx) >= self.batch_size
            for idx in self.train_tasks
        ])

    def _get_action_and_info(self, agent, observation):
        """
        Get an action to take in the environment.
        :param observation:
        :return:
        """
        agent.set_num_steps_total(self._n_env_steps_total)
        return agent.get_action(observation)

    def _start_epoch(self, epoch):
        self._epoch_start_time = time.time()
        self._exploration_paths = []
        self._do_train_time = 0
        logger.push_prefix('Iteration #%d | ' % epoch)

    def _end_epoch(self):
        logger.log("Epoch Duration: {0}".format(time.time() -
                                                self._epoch_start_time))
        logger.log("Started Training: {0}".format(self._can_train()))
        logger.pop_prefix()

    def _start_new_rollout(self):
        return self.env.reset()

    # not used
    def _handle_path(self, path):
        """
        Naive implementation: just loop through each transition.
        :param path:
        :return:
        """
        for (ob, action, reward, next_ob, terminal, agent_info,
             env_info) in zip(
                 path["observations"],
                 path["actions"],
                 path["rewards"],
                 path["next_observations"],
                 path["terminals"],
                 path["agent_infos"],
                 path["env_infos"],
             ):
            self._handle_step(
                ob,
                action,
                reward,
                next_ob,
                terminal,
                agent_info=agent_info,
                env_info=env_info,
            )
        self._handle_rollout_ending()

    def _handle_step(
        self,
        task_idx,
        observation,
        action,
        reward,
        next_observation,
        terminal,
        agent_info,
        env_info,
        eval_task=False,
        add_to_enc_buffer=True,
    ):
        """
        Implement anything that needs to happen after every step
        :return:
        """
        self._current_path_builder.add_all(
            task=task_idx,
            observations=observation,
            actions=action,
            rewards=reward,
            next_observations=next_observation,
            terminals=terminal,
            agent_infos=agent_info,
            env_infos=env_info,
        )
        if eval_task:
            self.eval_enc_replay_buffer.add_sample(
                task=task_idx,
                observation=observation,
                action=action,
                reward=reward,
                terminal=terminal,
                next_observation=next_observation,
                agent_info=agent_info,
                env_info=env_info,
            )
        else:
            self.replay_buffer.add_sample(
                task=task_idx,
                observation=observation,
                action=action,
                reward=reward,
                terminal=terminal,
                next_observation=next_observation,
                agent_info=agent_info,
                env_info=env_info,
            )
            if add_to_enc_buffer:
                self.enc_replay_buffer.add_sample(
                    task=task_idx,
                    observation=observation,
                    action=action,
                    reward=reward,
                    terminal=terminal,
                    next_observation=next_observation,
                    agent_info=agent_info,
                    env_info=env_info,
                )

    def _handle_rollout_ending(self, eval_task=False):
        """
        Implement anything that needs to happen after every rollout.
        """
        if eval_task:
            self.eval_enc_replay_buffer.terminate_episode(self.task_idx)
        else:
            self.replay_buffer.terminate_episode(self.task_idx)
            self.enc_replay_buffer.terminate_episode(self.task_idx)

        self._n_rollouts_total += 1
        if len(self._current_path_builder) > 0:
            self._exploration_paths.append(
                self._current_path_builder.get_all_stacked())
            self._current_path_builder = PathBuilder()

    def get_epoch_snapshot(self, epoch):
        data_to_save = dict(
            epoch=epoch,
            exploration_policy=self.exploration_policy,
        )
        if self.save_environment:
            data_to_save['env'] = self.training_env
        return data_to_save

    def get_extra_data_to_save(self, epoch):
        """
        Save things that shouldn't be saved every snapshot but rather
        overwritten every time.
        :param epoch:
        :return:
        """
        if self.render:
            self.training_env.render(close=True)
        data_to_save = dict(epoch=epoch, )
        if self.save_environment:
            data_to_save['env'] = self.training_env
        if self.save_replay_buffer:
            data_to_save['replay_buffer'] = self.replay_buffer
        if self.save_algorithm:
            data_to_save['algorithm'] = self
        return data_to_save

    @abc.abstractmethod
    def training_mode(self, mode):
        """
        Set training mode to `mode`.
        :param mode: If True, training will happen (e.g. set the dropout
        probabilities to not all ones).
        """
        pass
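    # A minimal sketch (an assumption) of a typical override in a subclass that
    # keeps its torch modules in a hypothetical `self.networks` collection:
    #
    #     def training_mode(self, mode):
    #         for net in self.networks:
    #             net.train(mode)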

    @abc.abstractmethod
    def evaluate(self, epoch):
        """
        Evaluate the policy, e.g. save/print progress.
        :param epoch:
        :return:
        """
        pass

    @abc.abstractmethod
    def _do_training(self, indices):
        """
        Perform some update, e.g. perform one gradient step.
        :return:
        """
        pass
Example #23
0
    def obtain_eval_samples(self, epoch, mode='meta_train'):
        self.training_mode(False)

        if mode == 'meta_train':
            params_samples = self.train_task_params_sampler.sample_unique(
                self.num_tasks_per_eval)
        else:
            params_samples = self.test_task_params_sampler.sample_unique(
                self.num_tasks_per_eval)
        all_eval_tasks_paths = []
        for task_params, obs_task_params in params_samples:
            cur_eval_task_paths = []
            self.env.reset(task_params=task_params,
                           obs_task_params=obs_task_params)
            task_identifier = self.env.task_identifier

            for _ in range(self.num_diff_context_per_eval_task):
                eval_policy = self.get_eval_policy(task_identifier, mode=mode)

                for _ in range(self.num_eval_trajs_per_post_sample):
                    cur_eval_path_builder = PathBuilder()
                    observation = self.env.reset(
                        task_params=task_params,
                        obs_task_params=obs_task_params)
                    terminal = False

                    while (not terminal) and len(
                            cur_eval_path_builder) < self.max_path_length:
                        if isinstance(self.obs_space, Dict):
                            if self.policy_uses_pixels:
                                agent_obs = observation['pixels']
                            else:
                                agent_obs = observation['obs']
                        else:
                            agent_obs = observation
                        action, agent_info = eval_policy.get_action(agent_obs)

                        next_ob, raw_reward, terminal, env_info = (
                            self.env.step(action))
                        if self.no_terminal:
                            terminal = False

                        reward = raw_reward
                        terminal = np.array([terminal])
                        reward = np.array([reward])
                        cur_eval_path_builder.add_all(
                            observations=observation,
                            actions=action,
                            rewards=reward,
                            next_observations=next_ob,
                            terminals=terminal,
                            agent_infos=agent_info,
                            env_infos=env_info,
                            task_identifiers=task_identifier)
                        observation = next_ob

                    if terminal and self.wrap_absorbing:
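                        # absorbing-state wrapping is not implemented; the
                        # add_all call below is unreachable until the raise is removed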
                        raise NotImplementedError(
                            "I think they used 0 actions for this")
                        cur_eval_path_builder.add_all(
                            observations=next_ob,
                            actions=action,
                            rewards=reward,
                            next_observations=next_ob,
                            terminals=terminal,
                            agent_infos=agent_info,
                            env_infos=env_info,
                            task_identifiers=task_identifier)

                    if len(cur_eval_path_builder) > 0:
                        cur_eval_task_paths.append(
                            cur_eval_path_builder.get_all_stacked())
            all_eval_tasks_paths.extend(cur_eval_task_paths)

        # all_eval_tasks_paths was already flattened above, one task at a time, via extend
        return all_eval_tasks_paths
Example #24
0
class MdpStepCollector(StepCollector):
    def __init__(
        self,
        env,
        policy,
        max_num_epoch_paths_saved=None,
        render=False,
        render_kwargs=None,
    ):
        if render_kwargs is None:
            render_kwargs = {}
        self._env = env
        self._policy = policy
        self._max_num_epoch_paths_saved = max_num_epoch_paths_saved
        self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved)
        self._render = render
        self._render_kwargs = render_kwargs

        self._num_steps_total = 0
        self._num_paths_total = 0
        self._obs = None  # cache variable

    def get_epoch_paths(self):
        return self._epoch_paths

    def end_epoch(self, epoch):
        self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved)
        self._obs = None

    def get_diagnostics(self):
        path_lens = [len(path['actions']) for path in self._epoch_paths]
        stats = OrderedDict([
            ('num steps total', self._num_steps_total),
            ('num paths total', self._num_paths_total),
        ])
        stats.update(
            create_stats_ordered_dict(
                "path length",
                path_lens,
                always_show_all_stats=True,
            ))
        return stats

    def get_snapshot(self):
        return dict(
            env=self._env,
            policy=self._policy,
        )

    def collect_new_steps(
        self,
        max_path_length,
        num_steps,
        discard_incomplete_paths,
    ):
        for _ in range(num_steps):
            self.collect_one_step(max_path_length, discard_incomplete_paths)

    def collect_one_step(
        self,
        max_path_length,
        discard_incomplete_paths,
    ):
        if self._obs is None:
            self._start_new_rollout()

        action, agent_info = self._policy.get_action(self._obs)
        next_ob, reward, terminal, env_info = (self._env.step(action))
        if self._render:
            self._env.render(**self._render_kwargs)
        terminal = np.array([terminal])
        reward = np.array([reward])
        # store path obs
        self._current_path_builder.add_all(
            observations=self._obs,
            actions=action,
            rewards=reward,
            next_observations=next_ob,
            terminals=terminal,
            agent_infos=agent_info,
            env_infos=env_info,
        )
        if terminal or len(self._current_path_builder) >= max_path_length:
            self._handle_rollout_ending(max_path_length,
                                        discard_incomplete_paths)
            self._start_new_rollout()
        else:
            self._obs = next_ob

    def _start_new_rollout(self):
        self._current_path_builder = PathBuilder()
        self._obs = self._env.reset()

    def _handle_rollout_ending(self, max_path_length,
                               discard_incomplete_paths):
        if len(self._current_path_builder) > 0:
            path = self._current_path_builder.get_all_stacked()
            path_len = len(path['actions'])
            if (path_len != max_path_length and not path['terminals'][-1]
                    and discard_incomplete_paths):
                return
            self._epoch_paths.append(path)
            self._num_paths_total += 1
            self._num_steps_total += path_len
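# A minimal usage sketch (an assumption; `make_env` and `make_policy` stand in
# for whatever environment and policy constructors the surrounding project
# provides):
#
#     collector = MdpStepCollector(make_env(), make_policy(),
#                                  max_num_epoch_paths_saved=10)
#     collector.collect_new_steps(max_path_length=100, num_steps=1000,
#                                 discard_incomplete_paths=False)
#     paths = collector.get_epoch_paths()   # deque of completed paths
#     print(collector.get_diagnostics())    # step/path counters, path lengths
#     collector.end_epoch(0)                # reset per-epoch path storage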
Example #25
0
class RLAlgorithm(metaclass=abc.ABCMeta):
    def __init__(self,
                 env,
                 exploration_policy: ExplorationPolicy,
                 training_env=None,
                 num_epochs=100,
                 num_steps_per_epoch=10000,
                 num_steps_per_eval=1000,
                 num_updates_per_env_step=1,
                 batch_size=1024,
                 max_path_length=1000,
                 discount=0.99,
                 replay_buffer_size=1000000,
                 reward_scale=1,
                 render=False,
                 save_replay_buffer=False,
                 save_algorithm=False,
                 save_environment=True,
                 eval_sampler=None,
                 eval_policy=None,
                 replay_buffer=None,
                 environment_farming=False,
                 farmlist_base=None):
        """
        Base class for RL Algorithms
        :param env: Environment used to evaluate.
        :param exploration_policy: Policy used to explore
        :param training_env: Environment used by the algorithm. By default, a
        copy of `env` will be made.
        :param num_epochs:
        :param num_steps_per_epoch:
        :param num_steps_per_eval:
        :param num_updates_per_env_step: Used by online training mode.
        :param batch_size:
        :param max_path_length:
        :param discount:
        :param replay_buffer_size:
        :param reward_scale:
        :param render:
        :param save_replay_buffer:
        :param save_algorithm:
        :param save_environment:
        :param eval_sampler:
        :param eval_policy: Policy to evaluate with.
        :param replay_buffer:
        """
        self.training_env = training_env or pickle.loads(pickle.dumps(env))
        self.exploration_policy = exploration_policy
        self.num_epochs = num_epochs
        self.num_env_steps_per_epoch = num_steps_per_epoch
        self.num_steps_per_eval = num_steps_per_eval
        self.num_updates_per_train_call = num_updates_per_env_step
        self.batch_size = batch_size
        self.max_path_length = max_path_length
        self.discount = discount
        self.replay_buffer_size = replay_buffer_size
        self.reward_scale = reward_scale
        self.render = render
        self.save_replay_buffer = save_replay_buffer
        self.save_algorithm = save_algorithm
        self.save_environment = save_environment
        self.environment_farming = environment_farming
        if self.environment_farming:
            if farmlist_base is None:
                raise ValueError(
                    'RLAlgorithm: the environment_farming option requires farmlist_base!')
            self.farmlist_base = farmlist_base
            self.farmer = Farmer(self.farmlist_base)

        if eval_sampler is None:
            if eval_policy is None:
                eval_policy = exploration_policy
            if not self.environment_farming:
                eval_sampler = InPlacePathSampler(
                    env=env,
                    policy=eval_policy,
                    max_samples=self.num_steps_per_eval + self.max_path_length,
                    max_path_length=self.max_path_length,
                )
            # With environment_farming, environments are managed dynamically, so the eval_sampler is created at sampling time instead.

        self.eval_policy = eval_policy
        self.eval_sampler = eval_sampler

        self.action_space = env.action_space
        self.obs_space = env.observation_space
        self.env = env
        if replay_buffer is None:
            replay_buffer = EnvReplayBuffer(
                self.replay_buffer_size,
                self.env,
            )
        self.replay_buffer = replay_buffer

        self._n_env_steps_total = 0
        self._n_train_steps_total = 0
        self._n_rollouts_total = 0
        self._do_train_time = 0
        self._epoch_start_time = None
        self._algo_start_time = None
        self._old_table_keys = None
        self._current_path_builder = PathBuilder()
        self._exploration_paths = []

    def refarm(self):
        if self.environment_farming:
            del self.farmer
            self.farmer = Farmer(self.farmlist_base)

    def train(self, start_epoch=0):
        self.pretrain()
        if start_epoch == 0:
            params = self.get_epoch_snapshot(-1)
            logger.save_itr_params(-1, params)
        self.training_mode(False)
        self._n_env_steps_total = start_epoch * self.num_env_steps_per_epoch
        gt.reset()
        gt.set_def_unique(False)
        self.train_online(start_epoch=start_epoch)

    def pretrain(self):
        """
        Do anything before the main training phase.
        """
        pass

    def play_one_step(self, observation=None, env=None):
        if self.environment_farming:
            observation = env.get_last_observation()

        if env is None:
            env = self.training_env

        action, agent_info = self._get_action_and_info(observation)

        if self.render and not self.environment_farming:
            env.render()

        next_ob, raw_reward, terminal, env_info = (env.step(action))

        self._n_env_steps_total += 1
        reward = raw_reward * self.reward_scale
        terminal = np.array([terminal])
        reward = np.array([reward])
        self._handle_step(observation,
                          action,
                          reward,
                          next_ob,
                          terminal,
                          agent_info=agent_info,
                          env_info=env_info,
                          env=env)

        if not self.environment_farming:
            current_path_builder = self._current_path_builder
        else:
            current_path_builder = env.get_current_path_builder()

        if terminal or len(current_path_builder) >= self.max_path_length:
            self._handle_rollout_ending(env)
            observation = self._start_new_rollout(env)
        else:
            observation = next_ob

        if self.environment_farming:
            self.farmer.add_free_env(env)

        gt.stamp('sample')

        return observation

    def play_ignore(self, env):
        print("Number of active threads: " + str(th.active_count()))
        t = th.Thread(target=self.play_one_step,
                      args=(
                          None,
                          env,
                      ),
                      daemon=True)
        t.start()
        # return immediately and let the thread run on its own.

    def train_online(self, start_epoch=0):
        if not self.environment_farming:
            observation = self._start_new_rollout()
        self._current_path_builder = PathBuilder()
        for epoch in gt.timed_for(
                range(start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            self._start_epoch(epoch)

            for _ in range(self.num_env_steps_per_epoch):
                if not self.environment_farming:
                    observation = self.play_one_step(observation)
                else:
                    # acquire a remote environment
                    remote_env = self.farmer.force_acq_env()
                    self.play_ignore(remote_env)

                # Training happens here in the main thread, outside the sampling threads
                self._try_to_train()
                gt.stamp('train')

            if epoch % 10 == 0:
                self._try_to_eval(epoch)
            gt.stamp('eval')
            self._end_epoch()

    def _try_to_train(self):
        if self._can_train():
            self.training_mode(True)
            for i in range(self.num_updates_per_train_call):
                self._do_training()
                self._n_train_steps_total += 1
            self.training_mode(False)

    def _try_to_eval(self, epoch):
        logger.save_extra_data(self.get_extra_data_to_save(epoch))
        if self._can_evaluate():
            if self.environment_farming:
                # Create a new eval_sampler at each evaluation to avoid using a released environment
                env_for_eval_sampler = self.farmer.force_acq_env()
                print(env_for_eval_sampler)
                self.eval_sampler = InPlacePathSampler(
                    env=env_for_eval_sampler,
                    policy=self.eval_policy,
                    max_samples=self.num_steps_per_eval + self.max_path_length,
                    max_path_length=self.max_path_length,
                )

            self.evaluate(epoch)

            if self.environment_farming:
                # Add the environment back to the free-env list
                self.farmer.add_free_env(env_for_eval_sampler)

            params = self.get_epoch_snapshot(epoch)
            logger.save_itr_params(epoch, params)
            table_keys = logger.get_table_key_set()
            if self._old_table_keys is not None:
                assert table_keys == self._old_table_keys, (
                    "Table keys cannot change from iteration to iteration.")
            self._old_table_keys = table_keys

            logger.record_tabular(
                "Number of train steps total",
                self._n_train_steps_total,
            )
            logger.record_tabular(
                "Number of env steps total",
                self._n_env_steps_total,
            )
            logger.record_tabular(
                "Number of rollouts total",
                self._n_rollouts_total,
            )

            times_itrs = gt.get_times().stamps.itrs
            train_time = times_itrs['train'][-1]
            sample_time = times_itrs['sample'][-1]
            eval_time = times_itrs['eval'][-1] if epoch > 0 else 0
            epoch_time = train_time + sample_time + eval_time
            total_time = gt.get_times().total

            logger.record_tabular('Train Time (s)', train_time)
            logger.record_tabular('(Previous) Eval Time (s)', eval_time)
            logger.record_tabular('Sample Time (s)', sample_time)
            logger.record_tabular('Epoch Time (s)', epoch_time)
            logger.record_tabular('Total Train Time (s)', total_time)

            logger.record_tabular("Epoch", epoch)
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        else:
            logger.log("Skipping eval for now.")

    def _can_evaluate(self):
        """
        One annoying thing about the logger table is that the keys at each
        iteration need to be the exact same. So unless you can compute
        everything, skip evaluation.
        A common example for why you might want to skip evaluation is that at
        the beginning of training, you may not have enough data for a
        validation and training set.
        :return:
        """
        return (len(self._exploration_paths) > 0 and
                self.replay_buffer.num_steps_can_sample() >= self.batch_size)

    def _can_train(self):
        return self.replay_buffer.num_steps_can_sample() >= self.batch_size

    def _get_action_and_info(self, observation):
        """
        Get an action to take in the environment.
        :param observation:
        :return:
        """
        self.exploration_policy.set_num_steps_total(self._n_env_steps_total)
        return self.exploration_policy.get_action(observation)

    def _start_epoch(self, epoch):
        self._epoch_start_time = time.time()
        self._exploration_paths = []
        self._do_train_time = 0
        logger.push_prefix('Iteration #%d | ' % epoch)

    def _end_epoch(self):
        logger.log("Epoch Duration: {0}".format(time.time() -
                                                self._epoch_start_time))
        logger.log("Started Training: {0}".format(self._can_train()))
        logger.pop_prefix()

    def _start_new_rollout(self, env=None):
        # WARNING: exploration_policy.reset() is usually a no-op, so it has not been adapted for farming.
        self.exploration_policy.reset()
        if not self.environment_farming:
            return self.training_env.reset()
        elif env:
            return env.reset()
        else:
            raise ValueError('_start_new_rollout: an environment must be given in farming mode!')

    def _handle_path(self, path):
        """
        Naive implementation: just loop through each transition.
        :param path:
        :return:
        """
        for (ob, action, reward, next_ob, terminal, agent_info,
             env_info) in zip(
                 path["observations"],
                 path["actions"],
                 path["rewards"],
                 path["next_observations"],
                 path["terminals"],
                 path["agent_infos"],
                 path["env_infos"],
             ):
            self._handle_step(
                ob,
                action,
                reward,
                next_ob,
                terminal,
                agent_info=agent_info,
                env_info=env_info,
            )
        self._handle_rollout_ending()

    def _handle_step(self,
                     observation,
                     action,
                     reward,
                     next_observation,
                     terminal,
                     agent_info,
                     env_info,
                     env=None):
        """
        Implement anything that needs to happen after every step
        :return:
        """
        if not self.environment_farming:
            self._current_path_builder.add_all(
                observations=observation,
                actions=action,
                rewards=reward,
                next_observations=next_observation,
                terminals=terminal,
                agent_infos=agent_info,
                env_infos=env_info,
            )
        elif env is not None:
            _current_path_builder = env.get_current_path_builder()
            if _current_path_builder is None:
                raise ValueError(
                    '_handle_step: env object must have a current_path_builder!')
            _current_path_builder.add_all(
                observations=observation,
                actions=action,
                rewards=reward,
                next_observations=next_observation,
                terminals=terminal,
                agent_infos=agent_info,
                env_infos=env_info,
            )
        else:
            raise ValueError('_handle_step: an env object must be given in farming mode!')

        self.replay_buffer.add_sample(
            observation=observation,
            action=action,
            reward=reward,
            terminal=terminal,
            next_observation=next_observation,
            agent_info=agent_info,
            env_info=env_info,
        )

    def _handle_rollout_ending(self, env=None):
        """
        Implement anything that needs to happen after every rollout.
        """
        # WARNING: terminate_episode is a no-op, so it has not been adapted for farming
        self.replay_buffer.terminate_episode()
        self._n_rollouts_total += 1
        if not self.environment_farming:
            if len(self._current_path_builder) > 0:
                self._exploration_paths.append(
                    self._current_path_builder.get_all_stacked())
                self._current_path_builder = PathBuilder()
        elif env:
            _current_path_builder = env.get_current_path_builder()
            if _current_path_builder is None:
                raise ValueError(
                    '_handle_rollout_ending: env object must have a current_path_builder!')
            self._exploration_paths.append(
                _current_path_builder.get_all_stacked())
            env.newPathBuilder()
        else:
            raise ValueError('_handle_rollout_ending: an env object must be given in farming mode!')

    def get_epoch_snapshot(self, epoch):
        if self.render:
            self.training_env.render(close=True)
        data_to_save = dict(
            epoch=epoch,
            exploration_policy=self.exploration_policy,
        )
        if self.save_environment:
            data_to_save['env'] = self.training_env
        return data_to_save

    def get_extra_data_to_save(self, epoch):
        """
        Save things that shouldn't be saved every snapshot but rather
        overwritten every time.
        :param epoch:
        :return:
        """
        if self.render:
            self.training_env.render(close=True)
        data_to_save = dict(epoch=epoch, )
        if self.save_environment:
            data_to_save['env'] = self.training_env
        if self.save_replay_buffer:
            data_to_save['replay_buffer'] = self.replay_buffer
        if self.save_algorithm:
            data_to_save['algorithm'] = self
        return data_to_save

    @abc.abstractmethod
    def training_mode(self, mode):
        """
        Set training mode to `mode`.
        :param mode: If True, training will happen (e.g. set the dropout
        probabilities to not all ones).
        """
        pass

    @abc.abstractmethod
    def cuda(self):
        """
        Turn cuda on.
        :return:
        """
        pass

    @abc.abstractmethod
    def evaluate(self, epoch):
        """
        Evaluate the policy, e.g. save/print progress.
        :param epoch:
        :return:
        """
        pass

    @abc.abstractmethod
    def _do_training(self):
        """
        Perform some update, e.g. perform one gradient step.
        :return:
        """
        pass