Example #1
    def dump_tabular(self):
        """
        Write all of the diagnostics from the current iteration.

        Writes both to stdout, and to the output file.
        """
        if proc_id() == 0:
            vals = []
            key_lens = [len(key) for key in self.log_headers]
            max_key_len = max(15, max(key_lens))
            keystr = '%' + '%d' % max_key_len
            fmt = "| " + keystr + "s | %15s |"
            n_slashes = 22 + max_key_len
            print("-" * n_slashes)
            for key in self.log_headers:
                val = self.log_current_row.get(key, "")
                valstr = "%8.3g" % val if hasattr(val, "__float__") else val
                print(fmt % (key, valstr))
                vals.append(val)
            print("-" * n_slashes)
            if self.output_file is not None:
                if self.first_row:
                    self.output_file.write("\t".join(self.log_headers) + "\n")
                self.output_file.write("\t".join(map(str, vals)) + "\n")
                self.output_file.flush()
        self.log_current_row.clear()
        self.first_row = False
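
A minimal usage sketch of the tabular API (the ``logx`` import path and the standalone-script context are assumptions; ``EpochLogger`` itself appears in Examples #3 and #7):

# Hypothetical illustration of the log_tabular / dump_tabular cycle.
from logx import EpochLogger  # import path is an assumption

logger = EpochLogger(output_dir='/tmp/experiments/demo', exp_name='demo')
for epoch in range(3):
    # Queue one value per column, then flush the whole row at once.
    logger.log_tabular('Epoch', epoch)
    logger.log_tabular('Loss', 1.0 / (epoch + 1))
    logger.dump_tabular()  # prints the table and appends a row to progress.txt
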
Example #2
    def save_state(self, state_dict, model, itr=None):
        """
        Saves the state of an experiment.

        To be clear: this is about saving *state*, not logging diagnostics.
        All diagnostic logging is separate from this function. This function
        will save whatever is in ``state_dict``---usually just a copy of the
        environment---and the most recent copy of the model via ``model``.

        Call with any frequency you prefer. If you only want to maintain a
        single state and overwrite it at each call with the most recent
        version, leave ``itr=None``. If you want to keep all of the states you
        save, provide unique (increasing) values for ``itr``.

        Args:
            state_dict (dict): Dictionary containing essential elements to
                describe the current state of training.
            model (nn.Module): A model which contains the policy.
            itr: An int, or None. Current iteration of training.
        """
        if proc_id() == 0:
            fname = 'vars.pkl' if itr is None else 'vars%d.pkl' % itr
            try:
                joblib.dump(state_dict, osp.join(self.output_dir, fname))
            except Exception:
                self.log('Warning: could not pickle state_dict.', color='red')
            self._torch_save(model, itr)
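
A hedged sketch of the checkpointing pattern (the ``logx`` import path is an assumption; a plain dict and a small ``nn.Module`` stand in for the real environment and actor-critic):

# Hypothetical illustration of save_state; requires joblib and torch.
import torch.nn as nn
from logx import Logger  # import path is an assumption

logger = Logger(output_dir='/tmp/experiments/demo')
model = nn.Linear(4, 2)  # placeholder for the actor-critic
for itr in range(3):
    # itr=None would overwrite a single vars.pkl / torch_save.pt on each call;
    # passing the iteration keeps vars0.pkl, vars1.pkl, ... instead.
    logger.save_state({'step': itr}, model, itr=itr)
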
Example #3
    def save_config(self, config):
        """
        Log an experiment configuration.

        Call this once at the top of your experiment, passing in all important
        config vars as a dict. This will serialize the config to JSON, while
        handling anything which can't be serialized in a graceful way (writing
        as informative a string as possible).

        Example use:

        .. code-block:: python

            logger = EpochLogger(**logger_kwargs)
            logger.save_config(locals())
        """
        config_json = convert_json(config)
        if self.exp_name is not None:
            config_json['exp_name'] = self.exp_name
        if proc_id() == 0:
            output = json.dumps(config_json,
                                separators=(',', ':\t'),
                                indent=4,
                                sort_keys=True)
            print(colorize('Saving config:\n', color='cyan', bold=True))
            print(output)
            with open(osp.join(self.output_dir, "config.json"), 'w') as out:
                out.write(output)
Example #4
    def __init__(self,
                 output_dir=None,
                 output_fname='progress.txt',
                 exp_name=None):
        """
        Initialize a Logger.

        Args:
            output_dir (string): A directory for saving results to. If
                ``None``, defaults to a temp directory of the form
                ``/tmp/experiments/somerandomnumber``.

            output_fname (string): Name for the tab-separated-value file
                containing metrics logged throughout a training run.
                Defaults to ``progress.txt``.

            exp_name (string): Experiment name. If you run multiple training
                runs and give them all the same ``exp_name``, the plotter
                will know to group them. (Use case: if you run the same
                hyperparameter configuration with multiple random seeds, you
                should give them all the same ``exp_name``.)
        """
        if proc_id() == 0:
            self.output_dir = output_dir or "/tmp/experiments/%i" % int(
                time.time())
            if osp.exists(self.output_dir):
                print(
                    "Warning: Log dir %s already exists! Storing info there anyway."
                    % self.output_dir)
            else:
                os.makedirs(self.output_dir)
            self.output_file = open(osp.join(self.output_dir, output_fname),
                                    'w')
            atexit.register(self.output_file.close)
            print(
                colorize("Logging data to %s" % self.output_file.name,
                         'green',
                         bold=True))
        else:
            self.output_dir = None
            self.output_file = None
        self.first_row = True
        self.log_headers = []
        self.log_current_row = {}
        self.exp_name = exp_name
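
One way to reuse the same ``exp_name`` across seeds so the plotter groups the runs, as the docstring suggests (the directory naming convention below is an assumption, not something the constructor enforces):

# Hypothetical multi-seed setup; only a shared exp_name matters for grouping.
from logx import Logger  # import path is an assumption

exp_name = 'ppo-demo'
for seed in (0, 10, 20):
    logger = Logger(
        output_dir='/tmp/experiments/%s/%s_s%d' % (exp_name, exp_name, seed),
        exp_name=exp_name,  # identical across seeds, so runs are grouped
    )
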
Example #5
    def log(self, msg, color='green'):
        """Print a colorized message to stdout."""
        if proc_id() == 0:
            print(colorize(msg, color, bold=True))
Example #6
    def _torch_save(self, model, itr=None):
        """Save the full model to ``torch_save.pt`` (or ``torch_save{itr}.pt``) in the output directory."""
        if proc_id() == 0:
            fname = 'torch_save.pt' if itr is None else 'torch_save%d.pt' % itr
            torch.save(model, osp.join(self.output_dir, fname))
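
For completeness, a checkpoint written by ``_torch_save`` can be restored with ``torch.load`` (a sketch; the directory is whatever ``output_dir`` the logger used, and the model's class must be importable when loading a whole pickled module this way):

# Hypothetical restore of the most recent checkpoint.
import os.path as osp
import torch

output_dir = '/tmp/experiments/demo'  # assumption: the logger's output_dir
# Newer PyTorch releases may additionally require weights_only=False here,
# since torch.save above pickles the entire module rather than a state_dict.
model = torch.load(osp.join(output_dir, 'torch_save.pt'))
model.eval()  # switch to inference mode
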
Example #7
def ppo(env_fn: Callable,
        actor_critic: Callable = ActorCriticWM,
        ac_kwargs: dict = dict(),
        preprocess_kwargs: dict = dict(),
        seed: int = 0,
        steps_per_epoch: int = 4000,
        epochs: int = 50,
        gamma: float = 0.99,
        clip_ratio: float = 0.2,
        pi_lr: float = 3e-4,
        vf_lr: float = 1e-3,
        train_pi_iters: int = 80,
        train_v_iters: int = 80,
        lam: float = 0.97,
        max_ep_len: int = 1000,
        target_kl: float = 0.01,
        logger_kwargs: dict = dict(),
        save_freq: int = 10):
    """
    Proximal Policy Optimization (by clipping) with early stopping based on
    approximate KL.

    Parameters
    ----------
    env_fn
        A function which creates a copy of the environment.
        The environment must satisfy the OpenAI Gym API.
    actor_critic
        The agent's main model, composed of a policy and a value function,
        where the policy takes a state ``x`` and an action ``a``, and the
        value function takes the state ``x``. The model returns a tuple of:
        ===========  ================  ======================================
        Symbol       Shape             Description
        ===========  ================  ======================================
        ``pi``       (batch, act_dim)  | Samples actions from policy given
                                       | states.
        ``logp``     (batch,)          | Gives log probability, according to
                                       | the policy, of taking actions ``a``
                                       | in states ``x``.
        ``logp_pi``  (batch,)          | Gives log probability, according to
                                       | the policy, of the action sampled by
                                       | ``pi``.
        ``v``        (batch,)          | Gives the value estimate for states
                                       | in ``x``. (Critical: make sure
                                       | to flatten this via .squeeze()!)
        ===========  ================  ======================================
    ac_kwargs
        Any kwargs appropriate for the actor_critic class you provided to PPO.
    preprocess_kwargs
        Any kwargs appropriate for the observation preprocessing function in utils.
    seed
        Seed for random number generators.
    steps_per_epoch
        Number of steps of interaction (state-action pairs) for the agent and the environment in each epoch.
    epochs
        Number of epochs of interaction (equivalent to number of policy updates) to perform.
    gamma
        Discount factor. (Always between 0 and 1.)
    clip_ratio
        Hyperparameter for clipping in the policy objective.
        Roughly: how far can the new policy go from the old policy while
        still profiting (improving the objective function)? The new policy
        can still go farther than the clip_ratio says, but it doesn't help
        on the objective anymore. (Usually small, 0.1 to 0.3.)
    pi_lr
        Learning rate for policy optimizer.
    vf_lr
        Learning rate for value function optimizer.
    train_pi_iters
        Maximum number of gradient descent steps to take on policy loss per epoch.
        (Early stopping may cause optimizer to take fewer than this.)
    train_v_iters
        Number of gradient descent steps to take on value function per epoch.
    lam
        Lambda for GAE-Lambda. (Always between 0 and 1, close to 1.)
    max_ep_len
        Maximum length of trajectory / episode / rollout.
    target_kl
        Roughly what KL divergence we think is appropriate between new and old policies
        after an update. This will get used for early stopping. (Usually small, 0.01 or 0.05.)
    logger_kwargs
        Keyword args for EpochLogger.
    save_freq
        How often (in terms of gap between epochs) to save the current policy and value function.
    """

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    seed += 10000 * proc_id()
    torch.manual_seed(seed)
    np.random.seed(seed)

    env = env_fn()
    try:  # (3, 64, 64)
        obs_dim = (
            env.observation_space.shape[-1], ) + preprocess_kwargs['resize']
    except KeyError:  # (3, 96, 96)
        obs_dim = (env.observation_space.shape[-1],
                   ) + env.observation_space.shape[:-1]
    act_dim = env.action_space.shape  # (3,)

    # Share information about action space with policy architecture
    ac_kwargs['action_space'] = env.action_space  # Box(3,)

    # Main model
    actor_critic = actor_critic(**ac_kwargs)

    # Experience buffer
    local_steps_per_epoch = int(steps_per_epoch / num_procs())
    buf = PPOBuffer(obs_dim, act_dim, local_steps_per_epoch, gamma, lam)

    # Count variables
    var_counts = tuple(
        count_vars(module)
        for module in [actor_critic.policy, actor_critic.value_function])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d\n' % var_counts)

    # Optimizers
    train_pi = torch.optim.Adam(actor_critic.policy.parameters(), lr=pi_lr)
    train_v = torch.optim.Adam(actor_critic.value_function.parameters(),
                               lr=vf_lr)

    # Sync params across processes
    sync_all_params(actor_critic.parameters())

    def update():
        obs, act, adv, ret, logp_old = [torch.Tensor(x) for x in buf.get()]

        # Training policy
        _, logp, _ = actor_critic.policy(obs, act)
        ratio = (logp - logp_old).exp()
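        # PPO-Clip surrogate: min(ratio * adv, clip(ratio, 1-eps, 1+eps) * adv).
        # The sign of adv decides which clip bound can bind, so the clipped term
        # collapses to the single tensor `min_adv` computed below.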
        min_adv = torch.where(adv > 0, (1 + clip_ratio) * adv,
                              (1 - clip_ratio) * adv)
        pi_l_old = -(torch.min(ratio * adv, min_adv)).mean()
        ent = (-logp).mean()  # a sample estimate for entropy

        for i in range(train_pi_iters):
            # Output from policy function graph
            _, logp, _ = actor_critic.policy(obs, act)
            # PPO policy objective
            ratio = (logp - logp_old).exp()
            min_adv = torch.where(adv > 0, (1 + clip_ratio) * adv,
                                  (1 - clip_ratio) * adv)
            pi_loss = -(torch.min(ratio * adv, min_adv)).mean()

            # Policy gradient step
            train_pi.zero_grad()
            pi_loss.backward()
            average_gradients(train_pi.param_groups)
            train_pi.step()

            _, logp, _ = actor_critic.policy(obs, act)
            kl = (logp_old - logp).mean()
            kl = mpi_avg(kl.item())
            if kl > 1.5 * target_kl:
                logger.log(
                    'Early stopping at step %d due to reaching max kl.' % i)
                break
        logger.store(StopIter=i)

        # Training value function
        v = actor_critic.value_function(obs, act)
        v_l_old = F.mse_loss(v, ret)
        for _ in range(train_v_iters):
            # Output from value function graph
            v = actor_critic.value_function(obs, act)
            # PPO value function objective
            v_loss = F.mse_loss(v, ret)

            # Value function gradient step
            train_v.zero_grad()
            v_loss.backward()
            average_gradients(train_v.param_groups)
            train_v.step()

        # Log changes from update
        _, logp, _, v = actor_critic(obs, act)
        ratio = (logp - logp_old).exp()
        min_adv = torch.where(adv > 0, (1 + clip_ratio) * adv,
                              (1 - clip_ratio) * adv)
        pi_l_new = -(torch.min(ratio * adv, min_adv)).mean()
        v_l_new = F.mse_loss(v, ret)
        kl = (logp_old - logp).mean()  # a sample estimate for KL-divergence
        clipped = (ratio > (1 + clip_ratio)) | (ratio < (1 - clip_ratio))
        cf = (clipped.float()).mean()
        logger.store(LossPi=pi_l_old,
                     LossV=v_l_old,
                     KL=kl,
                     Entropy=ent,
                     ClipFrac=cf,
                     DeltaLossPi=(pi_l_new - pi_l_old),
                     DeltaLossV=(v_l_new - v_l_old))

    start_time = time.time()
    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
    o = preprocess_obs(o, **preprocess_kwargs)
    o = np.transpose(o[np.newaxis, ...], axes=(0, 3, 1, 2))

    # Main loop: collect experience in env and update/log each epoch
    for epoch in range(epochs):
        a = env.action_space.sample()
        a = a.reshape(1, a.shape[0])
        actor_critic.eval()
        for t in range(local_steps_per_epoch):

            a, _, logp_t, v_t = actor_critic(torch.Tensor(o), torch.Tensor(a))
            a = a.detach().numpy()

            # save and log
            buf.store(o, a, r, v_t.item(), logp_t.detach().numpy())
            logger.store(VVals=v_t)

            o, r, d, _ = env.step(a[0])
            o = preprocess_obs(o, **preprocess_kwargs)
            o = np.transpose(o[np.newaxis, ...], axes=(0, 3, 1, 2))

            ep_ret += r
            ep_len += 1

            terminal = d or (ep_len == max_ep_len)
            if terminal or (t == local_steps_per_epoch - 1):
                if not terminal:
                    print('Warning: trajectory cut off by epoch at %d steps.' %
                          ep_len)
                # if trajectory didn't reach terminal state, bootstrap value target
                last_val = r if d else actor_critic.value_function(
                    torch.Tensor(o), torch.Tensor(a)).item()
                buf.finish_path(last_val)
                if terminal:  # only save EpRet / EpLen if trajectory finished
                    logger.store(EpRet=ep_ret, EpLen=ep_len)
                o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
                o = preprocess_obs(o, **preprocess_kwargs)
                o = np.transpose(o[np.newaxis, ...], axes=(0, 3, 1, 2))

        # Save model
        if (epoch % save_freq == 0) or (epoch == epochs - 1):
            logger.save_state({'env': env}, actor_critic, None)

        # Perform PPO update!
        actor_critic.train()
        update()

        # Log info about epoch
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('EpLen', average_only=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts', (epoch + 1) * steps_per_epoch)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('Entropy', average_only=True)
        logger.log_tabular('KL', average_only=True)
        logger.log_tabular('ClipFrac', average_only=True)
        logger.log_tabular('StopIter', average_only=True)
        logger.log_tabular('Time', time.time() - start_time)
        logger.dump_tabular()
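
The shape comments in the code (96x96x3 image observations, a 3-dimensional ``Box`` action space) suggest an environment such as Gym's ``CarRacing-v0``, so an invocation might look roughly like the sketch below. The environment choice, the ``resize`` key for the preprocessing function, and the script context are all assumptions:

# Hypothetical entry point; not confirmed by the source above.
import gym

if __name__ == '__main__':
    ppo(env_fn=lambda: gym.make('CarRacing-v0'),    # assumed environment
        preprocess_kwargs=dict(resize=(64, 64)),    # assumed; matches the (3, 64, 64) branch
        steps_per_epoch=4000,
        epochs=50,
        logger_kwargs=dict(output_dir='/tmp/experiments/ppo', exp_name='ppo'))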