Code Example #1
File: agent_base.py  Project: nicolascastanet/mrl
class Agent():
    """
  The base agent class. Important: Agents should almost always be generated from a config_dict
  using mrl.util.config_to_agent(config_dict). See configs folder for default configs / examples.
  
  Agent is a flat collection of mrl.Module, which may include:
    - environments (train/eval)
    - replay buffer(s)
    - new task function
    - action function  (exploratory + greedy) 
    - loss function
    - intrinsic curiosity module 
    - value / policy networks and other models (e.g. goal generation)
    - planner (e.g., MCTS)
    - logger
    - anything else you want (image tagger, human interface, etc.)

  Agent has some lifecycle methods (process_experience, optimize, save, load) that call the 
  corresponding lifecycle hooks on modules that declare them.

  Modules have a reference to the Agent so that they can access each other via the Agent. Actually,
  modules use __getattr__ to access the agent directly (via self.*), so they are effectively agent
  methods that are defined in separate files / have their own initialize/save/load functions.

  Modules are registered and saved/restored individually. This lets you swap out / tweak individual
  agent methods without subclassing the agent. Individual saves let you swap out saved modules via
  the filesystem (good for, e.g., BatchRL) and avoid pickling problems caused by non-picklable modules.
  """
    def __init__(
        self,
        module_list: Iterable,  # list of mrl.Modules (possibly nested)
        config: AttrDict):  # hyperparameters and module settings

        self.config = config
        parent_folder = config.parent_folder
        assert parent_folder, "Setting the agent's parent folder is required!"
        self.agent_name = config.get('agent_name') or 'agent_' + short_timestamp()
        self.agent_folder = os.path.join(parent_folder, self.agent_name)
        load_agent = False
        if os.path.exists(self.agent_folder):
            print('Detected existing agent! Loading agent from checkpoint...')
            load_agent = True
        else:
            os.makedirs(self.agent_folder, exist_ok=True)

        self._process_experience_registry = []  # modules that define _process_experience
        self._optimize_registry = []  # modules that define _optimize
        self.config.env_steps = 0
        self.config.opt_steps = 0

        module_list = flatten_modules(module_list)
        self.module_dict = AttrDict()
        for module in module_list:
            assert module.module_name
            setattr(self, module.module_name, module)
            self.module_dict[module.module_name] = module
        for module in module_list:
            self._register_module(module)

        self.training = True

        if load_agent:
            self.load()
            print('Successfully loaded saved agent!')
        else:
            self.save()

    def train_mode(self):
        """Set agent to train mode; exploration / use dropout / etc. As in Pytorch."""
        self.training = True

    def eval_mode(self):
        """Set agent to eval mode; act deterministically / don't use dropout / etc."""
        self.training = False

    def process_experience(self, experience: AttrDict):
        """Calls the _process_experience function of each relevant module
    (typically, these will include a replay buffer and one or more logging modules)"""
        self.config.env_steps += self.env.num_envs if hasattr(self, 'env') else 1
        for module in self._process_experience_registry:
            module._process_experience(experience)

    def optimize(self):
        """Calls the _optimize function of each relevant module
    (typically, this will be the main algorithm; but may include others)"""
        self.config.opt_steps += 1
        for module in self._optimize_registry:
            module._optimize()

    def _register_module(self, module):
        """
    Provides module with a reference to agent so that modules can interact; e.g., 
    allows agent's policy to reference the value function.

    Then, calls each module's _setup and verify methods to _setup the module and
    verify that agent has all required modules.
    """
        self.module_dict[module.module_name] = module

        module.agent = self
        module.verify_agent_compatibility()
        module._setup()
        module.new_task()
        if hasattr(module, '_process_experience'):
            self._process_experience_registry.append(module)
        if hasattr(module, '_optimize'):
            self._optimize_registry.append(module)

    def set_module(self, module_name, module):
        """
    Sets a module (can be used to switch environments / policies)
    """
        setattr(self, module_name, module)
        self._register_module(module)

    def save(self, subfolder: Optional[str] = None):
        """
    The state of all stateful modules is saved to the agent's folder.
    The agent itself is NOT saved, and should be (1) rebuilt, and (2) restored using self.load().
    Subfolder can be used to save multiple checkpoints of the same agent.
    """
        save_folder = self.agent_folder
        subfolder = subfolder or 'checkpoint'
        save_folder = os.path.join(save_folder, subfolder)

        if not os.path.exists(save_folder):
            os.makedirs(save_folder)

        for module in self.module_dict.values():
            module.save(save_folder)

        with open(os.path.join(save_folder, 'config.pickle'), 'wb') as f:
            pickle.dump(self.config, f)

    def load(self, subfolder: Optional[str] = None):
        """
    Restores state of stateful modules from the agent's folder[/subfolder].
    """
        save_folder = self.agent_folder
        subfolder = subfolder or 'checkpoint'
        save_folder = os.path.join(save_folder, subfolder)

        assert os.path.exists(save_folder), "load path does not exist!"

        with open(os.path.join(save_folder, 'config.pickle'), 'rb') as f:
            self.config = pickle.load(f)

        for module in self.module_dict.values():
            print("Loading module {}".format(module.module_name))
            module.load(save_folder)

    def save_checkpoint(self, checkpoint_dir):
        """
    Saves the agent together with its replay buffer, regardless of the save_replay_buf setting.
    Keeps two saves in the checkpoint folder, in case the job is killed and the latest
    checkpoint is corrupted.

    NOTE: You should call agent.save to save to the main folder BEFORE calling this.
    """
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        with open(os.path.join(checkpoint_dir, 'INITIALIZED'), 'w') as f:
            f.write('INITIALIZED')

        subfolder1 = os.path.join(checkpoint_dir, '1')
        subfolder2 = os.path.join(checkpoint_dir, '2')

        os.makedirs(os.path.join(subfolder1, 'checkpoint'), exist_ok=True)
        os.makedirs(os.path.join(subfolder2, 'checkpoint'), exist_ok=True)

        done1 = os.path.join(subfolder1, 'DONE')
        done2 = os.path.join(subfolder2, 'DONE')

        if not os.path.exists(done1):
            savedir = subfolder1
            done_file = done1
        elif not os.path.exists(done2):
            savedir = subfolder2
            done_file = done2
        else:
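            # Both slots already contain a completed save; overwrite the older one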
            modtime1 = os.path.getmtime(done1)
            modtime2 = os.path.getmtime(done2)
            if modtime1 < modtime2:
                savedir = subfolder1
                done_file = done1
            else:
                savedir = subfolder2
                done_file = done2

            os.remove(done_file)

        savedir_checkpoint = os.path.join(savedir, 'checkpoint')
        # First save all modules, including replay buffer
        old_save_replay_buf = self.config.save_replay_buf
        self.config.save_replay_buf = True
        for module in self.module_dict.values():
            module.save(savedir_checkpoint)
        self.config.save_replay_buf = old_save_replay_buf

        # Now save the config also
        with open(os.path.join(savedir_checkpoint, 'config.pickle'),
                  'wb') as f:
            pickle.dump(self.config, f)

        # Now copy over the config and results files from the agent_folder
        files_and_folders = glob.glob(os.path.join(self.agent_folder, '*'))
        for file_or_folder in files_and_folders:
            if os.path.isfile(file_or_folder):
                shutil.copy(file_or_folder, savedir)

        # Finally, write the DONE file.
        with open(done_file, 'w') as f:
            f.write('DONE')

    def load_from_checkpoint(self, checkpoint_dir):
        """
    This loads an agent from a checkpoint_dir to which it was saved using the `save_checkpoint` method.
    """
        subfolder1 = os.path.join(checkpoint_dir, '1')
        subfolder2 = os.path.join(checkpoint_dir, '2')
        done1 = os.path.join(subfolder1, 'DONE')
        done2 = os.path.join(subfolder2, 'DONE')

        if not os.path.exists(done1):
            assert os.path.exists(done2)
            savedir = subfolder2
        elif not os.path.exists(done2):
            savedir = subfolder1
        else:
            modtime1 = os.path.getmtime(done1)
            modtime2 = os.path.getmtime(done2)
            if modtime1 > modtime2:
                savedir = subfolder1
            else:
                savedir = subfolder2

        savedir_checkpoint = os.path.join(savedir, 'checkpoint')

        # First load the agent
        with open(os.path.join(savedir_checkpoint, 'config.pickle'),
                  'rb') as f:
            self.config = pickle.load(f)

        for module in self.module_dict.values():
            print("Loading module {}".format(module.module_name))
            module.load(savedir_checkpoint)

        # Then copy over the config and results file to the agent_folder
        files_and_folders = glob.glob(os.path.join(savedir, '*'))
        for file_or_folder in files_and_folders:
            if os.path.isfile(file_or_folder):
                shutil.copy(file_or_folder, self.agent_folder)

    def torch(self, x):
        """Converts x to a FloatTensor on the configured device (no-op if already a tensor)."""
        if isinstance(x, torch.Tensor): return x
        return torch.FloatTensor(x).to(self.config.device)

    def numpy(self, x):
        """Converts a tensor to a detached numpy array on the CPU."""
        return x.cpu().detach().numpy()
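
To make the module/lifecycle design described in the Agent docstring concrete, the following is a small self-contained sketch of the same dispatch pattern: each module gets a back-reference to its owner, and only modules that define a given lifecycle hook are registered for, and called by, that hook. All names below (StubModule, StubReplayBuffer, StubAlgorithm, StubAgent) are illustrative stand-ins rather than mrl classes; the real Agent above additionally handles configs, folders, setup/verification, and saving.

class StubModule:
    """Stand-in for mrl.Module; everything here is illustrative only."""

    def __init__(self, module_name):
        self.module_name = module_name
        self.agent = None  # filled in by the owning agent, as in Agent._register_module


class StubReplayBuffer(StubModule):
    def __init__(self):
        super().__init__('replay_buffer')
        self.transitions = []

    def _process_experience(self, experience):  # hook picked up by process_experience
        self.transitions.append(experience)


class StubAlgorithm(StubModule):
    def __init__(self):
        super().__init__('algorithm')
        self.opt_steps = 0

    def _optimize(self):  # hook picked up by optimize
        self.opt_steps += 1


class StubAgent:
    """Mirrors Agent's registration and dispatch logic, minus config and file I/O."""

    def __init__(self, modules):
        self._process_experience_registry = []
        self._optimize_registry = []
        for module in modules:
            setattr(self, module.module_name, module)  # modules become agent attributes
            module.agent = self                        # modules can reach each other via the agent
            if hasattr(module, '_process_experience'):
                self._process_experience_registry.append(module)
            if hasattr(module, '_optimize'):
                self._optimize_registry.append(module)

    def process_experience(self, experience):
        for module in self._process_experience_registry:
            module._process_experience(experience)

    def optimize(self):
        for module in self._optimize_registry:
            module._optimize()


agent = StubAgent([StubReplayBuffer(), StubAlgorithm()])
agent.process_experience({'state': 0, 'action': 1, 'reward': 0.0})
agent.optimize()
print(len(agent.replay_buffer.transitions), agent.algorithm.opt_steps)  # -> 1 1

The payoff of this pattern, as the docstring notes, is that individual pieces (environment, replay buffer, algorithm, logger) can be swapped or reloaded from disk without subclassing the agent.
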
Code Example #2
class EntropyPrioritizedOnlineHERBuffer(mrl.Module):

  def __init__(
      self,
      module_name='prioritized_replay',
      rank_method='dense',
      temperature=1.0
  ):
    """
    Buffer that stores entropy of trajectories for prioritized replay
    """

    super().__init__(module_name, required_agent_modules=['env','replay_buffer'], locals=locals())

    self.goal_space = None
    self.buffer = None
    self.rank_method = rank_method
    self.temperature = temperature
    self.traj_len = None

  def _setup(self):
    self.ag_buffer = self.replay_buffer.buffer.BUFF.buffer_ag

    env = self.env
    assert type(env.observation_space) == gym.spaces.Dict
    self.goal_space = env.observation_space.spaces["desired_goal"]

    # Note: for now we apply entropy estimation on the achieved goal (ag) space
    # Define the buffers to store for prioritization
    items = [("entropy", (1,)), ("priority", (1,))]
    self.buffer = AttrDict()
    for name, shape in items:
      self.buffer['buffer_' + name] = RingBuffer(self.ag_buffer.maxlen, shape=shape)

    self._subbuffers = [[] for _ in range(self.env.num_envs)]
    self.n_envs = self.env.num_envs

    # Placeholder for the mixture model used to estimate trajectory density
    # (remains 0 until fit_density_model is first called)
    self.clf = 0

  def fit_density_model(self):
    # Flatten each trajectory of achieved goals into a single feature vector
    ag = self.ag_buffer.data[0:self.size].copy()
    X_train = ag.reshape(-1, self.traj_len * ag.shape[-1])  # [num_episodes, episode_len * goal_dim]

    self.clf = mixture.BayesianGaussianMixture(weight_concentration_prior_type="dirichlet_distribution", n_components=3)
    self.clf.fit(X_train)
    # Negative log-likelihood under the fitted mixture is used as the trajectory's entropy score
    pred = -self.clf.score_samples(X_train)

    # Shift, clip and normalize the scores, remembering the statistics so that
    # _process_experience can normalize new trajectories the same way
    self.pred_min = pred.min()
    pred = pred - self.pred_min
    pred = np.clip(pred, 0, None)
    self.pred_sum = pred.sum()
    pred = pred / self.pred_sum
    self.pred_avg = (1 / pred.shape[0])
    # Broadcast each per-trajectory score to all of its transitions
    pred = np.repeat(pred, self.traj_len, axis=0)

    self.buffer.buffer_entropy.data[:self.size] = pred.reshape(-1, 1).copy()

  def _process_experience(self, exp):
    # Compute the entropy 
    # TODO: Include previous achieved goal too? or use that instead of ag?
    achieved = exp.next_state['achieved_goal']
    for i in range(self.n_envs):
      self._subbuffers[i].append([achieved[i]])
    
    for i in range(self.n_envs):
      if exp.trajectory_over[i]:
        # TODO: Compute the entropy of the trajectory
        traj_len = len(self._subbuffers[i])
        if self.traj_len is None:
          self.traj_len = traj_len
        else:
          # Current implementation assumes the same length for all trajectories
          assert(traj_len == self.traj_len)

        if not isinstance(self.clf, int):
          ag = [np.stack(a) for a in zip(*self._subbuffers[i])][0] # [episode_len, goal_dim]
          X = ag.reshape(-1, ag.shape[0]*ag.shape[1])
          pred = -self.clf.score_samples(X)

          pred = pred - self.pred_min
          pred = np.clip(pred, 0, None)
          pred = pred / self.pred_sum # Shape (1,)

          entropy = np.ones((traj_len,1)) * pred
        else:
          # Not enough data to train mixture density yet, set entropy to be zero
          entropy = np.zeros((traj_len, 1))
        
        priority = np.zeros((traj_len,1))
        trajectory = [entropy, priority]
        
        # TODO: Update the trajectory with entropy
        self.add_trajectory(*trajectory)

        self._subbuffers[i] = []

        # TODO: Update the rank here before adding it to the trajectory?
        self.update_priority()

  def add_trajectory(self, *items):
    """
    Append a trajectory of transitions to the buffer.

    :param items: a list of batched transition values to append to the replay buffer,
        in the item order that we initialized the ReplayBuffer with.
    """
    for buffer, batched_values in zip(self.buffer.values(), items):
      buffer.append_batch(batched_values)

  def update_priority(self):
    """
    After adding a trajectory to the replay buffer, update the ranking of transitions
    """
    # Note: 'dense' assigns the next highest element with the rank immediately 
    # after those assigned to the tied elements.
    entropy_transition_total = self.buffer.buffer_entropy.data[:self.size]
    entropy_rank = rankdata(entropy_transition_total, method=self.rank_method)
    entropy_rank = (entropy_rank - 1).reshape(-1, 1)
    self.buffer.buffer_priority.data[:self.size] = entropy_rank

  def __call__(self, batch_size):
    """
    Samples batch_size number of indices from main replay_buffer.

    Args:
      batch_size (int): size of the batch to sample
    
    Returns:
      batch_idxs: a 1-D numpy array of length batch_size containing transition indices
                  sampled in a prioritized manner
    """
    if self.rank_method == 'none':
      entropy_trajectory = self.buffer.buffer_entropy.data[:self.size]
    else:
      entropy_trajectory = self.buffer.buffer_priority.data[:self.size]
    
    # Sampling is factorized: first sample a trajectory according to its priority/entropy,
    # then sample a timestep within it uniformly and independently
    entropy_trajectory = entropy_trajectory.reshape(-1, self.traj_len)[:,0]
    p_trajectory = np.power(entropy_trajectory, 1/(self.temperature+1e-2))
    p_trajectory = p_trajectory / p_trajectory.sum()
    
    num_trajectories = p_trajectory.shape[0]
    batch_tidx = np.random.choice(num_trajectories, size=batch_size, p=p_trajectory)
    batch_idxs = self.traj_len * batch_tidx + np.random.choice(self.traj_len, size=batch_size)

    return batch_idxs

  @property
  def size(self):
    return len(self.ag_buffer)

  def save(self, save_folder):
    if self.config.save_replay_buf:
      state = self.buffer._get_state()
      with open(os.path.join(save_folder, "{}.pickle".format(self.module_name)), 'wb') as f:
        pickle.dump(state, f)

  def load(self, save_folder):
    load_path = os.path.join(save_folder, "{}.pickle".format(self.module_name))
    if os.path.exists(load_path):
      with open(load_path, 'rb') as f:
        state = pickle.load(f)
      self.buffer._set_state(state)
    else:
      self.logger.log_color('###############################################################', '', color='red')
      self.logger.log_color('WARNING', 'Replay buffer is not being loaded / was not saved.', color='cyan')
      self.logger.log_color('WARNING', 'Replay buffer is not being loaded / was not saved.', color='red')
      self.logger.log_color('WARNING', 'Replay buffer is not being loaded / was not saved.', color='yellow')
      self.logger.log_color('###############################################################', '', color='red')
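
To make the sampling rule in __call__ concrete, here is a standalone numpy sketch of the same two-stage scheme: a trajectory is drawn with probability proportional to priority ** (1 / (temperature + 1e-2)), then a timestep within it is drawn uniformly. The priority values, trajectory length, and batch size below are made-up illustrative numbers; only the formula mirrors the buffer code above.

import numpy as np

traj_len = 4          # transitions per trajectory (assumed fixed, as in the buffer)
temperature = 1.0
batch_size = 8

# Dense entropy ranks per trajectory, as produced by update_priority (rankdata - 1)
priority = np.array([0., 1., 2., 3.])

p_trajectory = np.power(priority, 1 / (temperature + 1e-2))
p_trajectory = p_trajectory / p_trajectory.sum()   # approx [0.00, 0.17, 0.33, 0.50]

batch_tidx = np.random.choice(len(p_trajectory), size=batch_size, p=p_trajectory)
batch_idxs = traj_len * batch_tidx + np.random.choice(traj_len, size=batch_size)
print(batch_idxs)     # flat transition indices into the main replay buffer

Because the dense ranks produced by update_priority start at zero, the lowest-entropy trajectory gets zero sampling probability under this rule, and lowering the temperature sharpens the distribution further toward high-entropy trajectories.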