Code example #1
File: checkpoints.py  Project: wrzadkow/flax
def save_checkpoint(ckpt_dir, target, step, prefix='checkpoint_', keep=1):
    """Save a checkpoint of the model.

  Attempts to be pre-emption safe by writing to a temporary file before
  a final rename and cleanup of past files.

  Args:
    ckpt_dir: str: path to store checkpoint files in.
    target: serializable flax object, usually a flax optimizer.
    step: int or float: training step number or other metric number.
    prefix: str: checkpoint file name prefix.
    keep: number of past checkpoint files to keep.

  Returns:
    Filename of saved checkpoint.
  """
    # Write temporary checkpoint file.
    logging.info('Saving checkpoint at step: %s', step)
    ckpt_tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
    ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
    gfile.makedirs(os.path.dirname(ckpt_path))
    with gfile.GFile(ckpt_tmp_path, 'wb') as fp:
        fp.write(serialization.to_bytes(target))

    # Rename once serialization and writing finished.
    gfile.rename(ckpt_tmp_path, ckpt_path)
    logging.info('Saved checkpoint at %s', ckpt_path)

    # Remove old checkpoint files.
    base_path = os.path.join(ckpt_dir, prefix)
    checkpoint_files = natural_sort(gfile.glob(base_path + '*'))
    if len(checkpoint_files) > keep:
        old_ckpts = checkpoint_files[:-keep]
        for path in old_ckpts:
            logging.info('Removing checkpoint at %s', path)
            gfile.remove(path)

    return ckpt_path
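
A minimal usage sketch for the function above (not part of the original listing), assuming it is invoked through the flax.training.checkpoints module this file defines; the target dict, directory, and step values are placeholders.

from flax.training import checkpoints

# Any serializable pytree can be checkpointed; a flax optimizer is typical,
# but a plain dict works for illustration.
state = {'params': {'w': [1.0, 2.0, 3.0]}, 'step': 100}
ckpt_path = checkpoints.save_checkpoint('/tmp/ckpts', state, step=100, keep=3)

# Later, restore into an object with the same structure.
restored = checkpoints.restore_checkpoint('/tmp/ckpts', target=state)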
Code example #2
def log_video(writer,
              video,
              tb_key,
              name,
              step,
              work_unit_dir,
              save_raw=False,
              scale=False):
  """Save video frames to tensorboard and a file."""
  video_raw = video
  if scale:
    video = scale_depth(video)

  if writer is not None:
    logging.info('Logging video frames')
    writer.write_images(step, {f'{tb_key}/{name}': make_image_grid(video)})

  filename = f'{tb_key}_{name}_{step:05d}.mp4'
  local_path = os.path.join('/tmp', filename)
  logging.info('Writing video to %s', local_path)
  media.write_video(local_path, video, fps=30)

  wu_path = os.path.join(work_unit_dir, filename)
  logging.info('Copying video to %s', wu_path)
  gfile.copy(local_path, wu_path, overwrite=True)
  gfile.remove(local_path)

  if save_raw:
    # save raw floating point values to scale depth properly
    raw_filename = f'{tb_key}_{name}_{step:05d}.npy'
    raw_path = os.path.join(work_unit_dir, raw_filename)
    logging.info('Saving raw video to %s', raw_path)
    with gfile.GFile(raw_path, 'wb') as raw_f:
      onp.save(raw_f, video_raw)

  logging.info('Done logging video.')
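
A hedged usage sketch for log_video: it assumes `media` is the mediapy package and that make_image_grid and scale_depth are the project's own helpers, so writer is left as None to exercise only the file-writing path; all names and paths below are placeholders.

import numpy as onp

# Thirty 64x64 RGB frames of noise as a stand-in for a rendered video.
video = onp.random.uniform(size=(30, 64, 64, 3)).astype(onp.float32)

# With writer=None the tensorboard image grid is skipped; the .mp4 is
# written to /tmp, copied into work_unit_dir, and (save_raw=True) the raw
# float array is additionally stored as .npy for later rescaling.
log_video(writer=None,
          video=video,
          tb_key='eval',
          name='rgb',
          step=1000,
          work_unit_dir='/tmp/work_unit',
          save_raw=True,
          scale=False)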
Code example #3
def training_loop(
    env=None,
    epochs=EPOCHS,
    policy_and_value_net_fn=None,
    policy_and_value_optimizer_fn=None,
    batch_size=BATCH_TRAJECTORIES,
    n_optimizer_steps=N_OPTIMIZER_STEPS,
    print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
    target_kl=0.01,
    boundary=20,
    max_timestep=None,
    max_timestep_eval=20000,
    random_seed=None,
    gamma=GAMMA,
    lambda_=LAMBDA,
    epsilon=EPSILON,
    c1=1.0,
    c2=0.01,
    output_dir=None,
    eval_every_n=1000,
    eval_env=None,
    done_frac_for_policy_save=0.5,
    enable_early_stopping=True,
    env_name=None,
    n_evals=1,
    len_history_for_policy=4,
):
    """Runs the training loop for PPO, with fixed policy and value nets."""
    assert env
    assert output_dir
    assert env_name

    gfile.makedirs(output_dir)

    # Create summary writers and history.
    train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
    timing_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "timing"))
    eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

    train_sw.text("env_name", env_name)
    timing_sw.text("env_name", env_name)
    eval_sw.text("env_name", env_name)

    jax_rng_key = trax.get_random_number_generator_and_set_seed(random_seed)

    # Batch Observations Shape = [1, 1] + OBS, because we will eventually call
    # policy and value networks on shape [B, T] + OBS.
    batch_observations_shape = (1, 1) + env.observation_space.shape
    observations_dtype = env.observation_space.dtype

    assert isinstance(env.action_space, gym.spaces.Discrete)
    n_actions = env.action_space.n

    jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)

    # Initialize the policy and value network.
    policy_and_value_net_params, policy_and_value_net_apply = (
        policy_and_value_net_fn(key1, batch_observations_shape,
                                observations_dtype, n_actions))

    # Maybe restore the policy params. If there is nothing to restore, then
    # iteration = 0 and policy_and_value_net_params are returned as is.
    restore, policy_and_value_net_params, iteration = (maybe_restore_params(
        output_dir, policy_and_value_net_params))

    if restore:
        logging.info("Restored parameters from iteration [%d]", iteration)
        # We should start from the next iteration.
        iteration += 1

    policy_and_value_net_apply = jit(policy_and_value_net_apply)

    # Initialize the optimizers.
    policy_and_value_optimizer = (
        policy_and_value_optimizer_fn(policy_and_value_net_params))
    (policy_and_value_opt_state, policy_and_value_opt_update,
     policy_and_value_get_params) = policy_and_value_optimizer

    n_trajectories_done = 0
    last_saved_at = 0

    logging.info("Starting the PPO training loop.")
    for i in range(iteration, epochs):
        epoch_start_time = time.time()

        # Params we'll use to collect the trajectories.
        policy_and_value_net_params = policy_and_value_get_params(
            policy_and_value_opt_state)

        # A function to get the policy and value predictions.
        def get_predictions(observations, rng=None):
            """Returns log-probs, value predictions and key back."""
            key, key1 = jax_random.split(rng, num=2)

            log_probs, value_preds = policy_and_value_net_apply(
                observations, policy_and_value_net_params, rng=key1)

            return log_probs, value_preds, key

        # Evaluate the policy.
        policy_eval_start_time = time.time()
        if ((i + 1) % eval_every_n == 0) or (i == epochs - 1):
            jax_rng_key, key = jax_random.split(jax_rng_key, num=2)

            logging.vlog(1, "Epoch [% 6d] evaluating policy.", i)

            avg_reward, avg_reward_unclipped = evaluate_policy(
                eval_env,
                get_predictions,
                max_timestep=max_timestep_eval,
                n_evals=n_evals,
                len_history_for_policy=len_history_for_policy,
                rng=key)
            for k, v in avg_reward.items():
                eval_sw.scalar("eval/mean_reward/%s" % k, v, step=i)
                logging.info(
                    "Epoch [% 6d] Policy Evaluation (clipped) [%s] = %10.2f",
                    i, k, v)
            for k, v in avg_reward_unclipped.items():
                eval_sw.scalar("eval/mean_reward_unclipped/%s" % k, v, step=i)
                logging.info(
                    "Epoch [% 6d] Policy Evaluation (unclipped) [%s] = %10.2f",
                    i, k, v)
        policy_eval_time = get_time(policy_eval_start_time)

        trajectory_collection_start_time = time.time()
        logging.vlog(1, "Epoch [% 6d] collecting trajectories.", i)
        jax_rng_key, key = jax_random.split(jax_rng_key)
        trajs, n_done, timing_info = collect_trajectories(
            env,
            policy_fn=get_predictions,
            n_trajectories=batch_size,
            max_timestep=max_timestep,
            rng=key,
            len_history_for_policy=len_history_for_policy,
            reset=(i == 0) or restore,
            epsilon=(10.0 / (i + 10.0)))  # this is a different epsilon.
        trajectory_collection_time = get_time(trajectory_collection_start_time)

        logging.vlog(1, "Collecting trajectories took %0.2f msec.",
                     trajectory_collection_time)

        avg_reward = float(sum(np.sum(traj[2]) for traj in trajs)) / len(trajs)
        max_reward = max(np.sum(traj[2]) for traj in trajs)
        min_reward = min(np.sum(traj[2]) for traj in trajs)

        train_sw.scalar("train/mean_reward", avg_reward, step=i)

        logging.vlog(1,
                     "Rewards avg=[%0.2f], max=[%0.2f], min=[%0.2f], all=%s",
                     avg_reward, max_reward, min_reward,
                     [float(np.sum(traj[2])) for traj in trajs])

        logging.vlog(
            1, "Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]",
            float(sum(len(traj[0]) for traj in trajs)) / len(trajs),
            max(len(traj[0]) for traj in trajs),
            min(len(traj[0]) for traj in trajs))
        logging.vlog(2, "Trajectory Lengths: %s",
                     [len(traj[0]) for traj in trajs])

        padding_start_time = time.time()
        (_, reward_mask, padded_observations, padded_actions,
         padded_rewards) = pad_trajectories(trajs, boundary=boundary)
        padding_time = get_time(padding_start_time)

        logging.vlog(1, "Padding trajectories took %0.2f msec.",
                     get_time(padding_start_time))
        logging.vlog(1, "Padded Observations' shape [%s]",
                     str(padded_observations.shape))
        logging.vlog(1, "Padded Actions' shape [%s]",
                     str(padded_actions.shape))
        logging.vlog(1, "Padded Rewards' shape [%s]",
                     str(padded_rewards.shape))

        # Calculate log-probabilities and value predictions of the trajectories.
        # We'll pass these to the loss functions so as to not get recomputed.

        # NOTE:
        # There is a slight problem here, if the policy network contains
        # stochasticity in the log-probabilities (ex: dropout), then calculating
        # these again here is not going to be correct and should be done in the
        # collect function.

        log_prob_recompute_start_time = time.time()
        jax_rng_key, key = jax_random.split(jax_rng_key)
        log_probabs_traj, value_predictions_traj, _ = get_predictions(
            padded_observations, rng=key)
        log_prob_recompute_time = get_time(log_prob_recompute_start_time)

        # Some assertions.
        B, T = padded_actions.shape  # pylint: disable=invalid-name
        assert (B, T) == padded_rewards.shape
        assert (B, T) == reward_mask.shape
        assert (B, T + 1) == padded_observations.shape[:2]
        assert (B, T +
                1) + env.observation_space.shape == padded_observations.shape

        # Linear annealing from 0.1 to 0.0
        # epsilon_schedule = epsilon if epochs == 1 else epsilon * (1.0 -
        #                                                           (i /
        #                                                            (epochs - 1)))

        # Constant epsilon.
        epsilon_schedule = epsilon

        # Compute value and ppo losses.
        jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)
        logging.vlog(2, "Starting to compute P&V loss.")
        loss_compute_start_time = time.time()
        cur_combined_loss, cur_ppo_loss, cur_value_loss, entropy_bonus = (
            combined_loss(policy_and_value_net_params,
                          log_probabs_traj,
                          value_predictions_traj,
                          policy_and_value_net_apply,
                          padded_observations,
                          padded_actions,
                          padded_rewards,
                          reward_mask,
                          gamma=gamma,
                          lambda_=lambda_,
                          epsilon=epsilon_schedule,
                          c1=c1,
                          c2=c2,
                          rng=key1))
        loss_compute_time = get_time(loss_compute_start_time)
        logging.vlog(
            1,
            "Calculating P&V loss [%10.2f(%10.2f, %10.2f, %10.2f)] took %0.2f msec.",
            cur_combined_loss, cur_value_loss, cur_ppo_loss, entropy_bonus,
            get_time(loss_compute_start_time))

        jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)
        logging.vlog(1, "Policy and Value Optimization")
        optimization_start_time = time.time()
        keys = jax_random.split(key1, num=n_optimizer_steps)
        for j in range(n_optimizer_steps):
            k1, k2, k3 = jax_random.split(keys[j], num=3)
            t = time.time()
            # Update the optimizer state.
            policy_and_value_opt_state = policy_and_value_opt_step(
                j,
                policy_and_value_opt_state,
                policy_and_value_opt_update,
                policy_and_value_get_params,
                policy_and_value_net_apply,
                log_probabs_traj,
                value_predictions_traj,
                padded_observations,
                padded_actions,
                padded_rewards,
                reward_mask,
                c1=c1,
                c2=c2,
                gamma=gamma,
                lambda_=lambda_,
                epsilon=epsilon_schedule,
                rng=k1)

            # Compute the approx KL for early stopping.
            new_policy_and_value_net_params = policy_and_value_get_params(
                policy_and_value_opt_state)

            log_probab_actions_new, _ = policy_and_value_net_apply(
                padded_observations, new_policy_and_value_net_params, rng=k2)

            approx_kl = approximate_kl(log_probab_actions_new,
                                       log_probabs_traj, reward_mask)

            early_stopping = enable_early_stopping and approx_kl > 1.5 * target_kl
            if early_stopping:
                logging.vlog(
                    1,
                    "Early stopping policy and value optimization at iter: %d, "
                    "with approx_kl: %0.2f", j, approx_kl)
                # We don't return right-away, we want the below to execute on the last
                # iteration.

            t2 = time.time()
            if (((j + 1) % print_every_optimizer_steps == 0)
                    or (j == n_optimizer_steps - 1) or early_stopping):
                # Compute and log the loss.
                (loss_combined, loss_ppo, loss_value,
                 entropy_bonus) = (combined_loss(
                     new_policy_and_value_net_params,
                     log_probabs_traj,
                     value_predictions_traj,
                     policy_and_value_net_apply,
                     padded_observations,
                     padded_actions,
                     padded_rewards,
                     reward_mask,
                     gamma=gamma,
                     lambda_=lambda_,
                     epsilon=epsilon_schedule,
                     c1=c1,
                     c2=c2,
                     rng=k3))
                logging.vlog(
                    1, "One Policy and Value grad desc took: %0.2f msec",
                    get_time(t, t2))
                logging.vlog(
                    1, "Combined Loss(value, ppo, entropy_bonus) [%10.2f] ->"
                    " [%10.2f(%10.2f,%10.2f,%10.2f)]", cur_combined_loss,
                    loss_combined, loss_value, loss_ppo, entropy_bonus)

            if early_stopping:
                break

        optimization_time = get_time(optimization_start_time)

        logging.vlog(
            1, "Total Combined Loss reduction [%0.2f]%%",
            (100 *
             (cur_combined_loss - loss_combined) / np.abs(cur_combined_loss)))

        # Save parameters every time we see the end of at least a fraction of batch
        # number of trajectories that are done (not completed -- completed includes
        # truncated and done).
        # Also don't save too frequently, enforce a minimum gap.
        # Or if this is the last iteration.
        policy_save_start_time = time.time()
        n_trajectories_done += n_done
        # TODO(afrozm): Refactor to trax.save_state.
        if (((n_trajectories_done >= done_frac_for_policy_save * batch_size)
             and (i - last_saved_at > eval_every_n) and
             (((i + 1) % eval_every_n == 0))) or (i == epochs - 1)):
            logging.vlog(1, "Epoch [% 6d] saving model.", i)
            old_model_files = gfile.glob(
                os.path.join(output_dir, "model-??????.pkl"))
            params_file = os.path.join(output_dir, "model-%06d.pkl" % i)
            with gfile.GFile(params_file, "wb") as f:
                pickle.dump(policy_and_value_net_params, f)
            # Remove the old model files.
            for path in old_model_files:
                gfile.remove(path)
            # Reset this number.
            n_trajectories_done = 0
            last_saved_at = i
        policy_save_time = get_time(policy_save_start_time)

        epoch_time = get_time(epoch_start_time)

        logging.info(
            "Epoch [% 6d], Reward[min, max, avg] [%5.2f,%5.2f,%5.2f], Combined"
            " Loss(value, ppo, entropy) [%2.5f(%2.5f,%2.5f,%2.5f)]", i,
            min_reward, max_reward, avg_reward, loss_combined, loss_value,
            loss_ppo, entropy_bonus)

        timing_dict = {
            "epoch": epoch_time,
            "policy_eval": policy_eval_time,
            "trajectory_collection": trajectory_collection_time,
            "padding": padding_time,
            "log_prob_recompute": log_prob_recompute_time,
            "loss_compute": loss_compute_time,
            "optimization": optimization_time,
            "policy_save": policy_save_time,
        }

        timing_dict.update(timing_info)

        for k, v in timing_dict.items():
            timing_sw.scalar("timing/%s" % k, v, step=i)

        max_key_len = max(len(k) for k in timing_dict)
        timing_info_list = [
            "%s : % 10.2f" % (k.rjust(max_key_len + 1), v)
            for k, v in sorted(timing_dict.items())
        ]
        logging.info("Epoch [% 6d], Timings: \n%s", i,
                     "\n".join(timing_info_list))

        # Reset restore.
        restore = False

        # Flush summary writers once in a while.
        if (i + 1) % 1000 == 0 or i == epochs - 1:
            train_sw.flush()
            timing_sw.flush()
            eval_sw.flush()
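
The early-stopping test in the inner optimizer loop compares approximate_kl (defined elsewhere in this module) against 1.5 * target_kl. As a point of reference only, here is a hedged numpy sketch of a masked categorical KL with the same input shapes; the project's actual approximate_kl may differ in detail.

import numpy as np

def masked_kl_sketch(log_probs_new, log_probs_old, reward_mask):
    """Masked mean KL(old || new) over the real (B, T) time-steps.

    log_probs_*: [B, T+1, A] log-probabilities; reward_mask: [B, T] with
    1.0 on real steps and 0.0 on padding. The final observation has no
    action, so its step is dropped before averaging.
    """
    per_step_kl = np.sum(
        np.exp(log_probs_old) * (log_probs_old - log_probs_new), axis=-1)
    per_step_kl = per_step_kl[:, :-1] * reward_mask
    return np.sum(per_step_kl) / np.sum(reward_mask)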
Code example #4
File: checkpoints.py  Project: hmph/flax
def save_checkpoint(ckpt_dir: Union[str, os.PathLike],
                    target,
                    step,
                    prefix='checkpoint_',
                    keep=1,
                    overwrite=False):
    """Save a checkpoint of the model.

  Attempts to be pre-emption safe by writing to a temporary file before
  a final rename and cleanup of past files.

  Args:
    ckpt_dir: str or pathlib-like path to store checkpoint files in.
    target: serializable flax object, usually a flax optimizer.
    step: int or float: training step number or other metric number.
    prefix: str: checkpoint file name prefix.
    keep: number of past checkpoint files to keep.
    overwrite: overwrite existing checkpoint files if a checkpoint
      at the current or a later step already exists (default: False).
  Returns:
    Filename of saved checkpoint.
  """
    ckpt_dir = os.fspath(ckpt_dir)  # Pathlib -> str
    # Write temporary checkpoint file.
    logging.info('Saving checkpoint at step: %s', step)
    if ckpt_dir.startswith('./'):
        ckpt_dir = ckpt_dir[2:]  # gfile.glob() can remove leading './'
    ckpt_tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
    ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
    gfile.makedirs(os.path.dirname(ckpt_path))
    base_path = os.path.join(ckpt_dir, prefix)
    checkpoint_files = gfile.glob(base_path + '*')

    if ckpt_path in checkpoint_files:
        if not overwrite:
            raise errors.InvalidCheckpointError(ckpt_path, step)
    else:
        checkpoint_files.append(ckpt_path)

    checkpoint_files = natural_sort(checkpoint_files)
    if checkpoint_files[-1] == ckpt_tmp_path:
        checkpoint_files.pop(-1)
    if ckpt_path != checkpoint_files[-1]:
        if not overwrite:
            raise errors.InvalidCheckpointError(ckpt_path, step)

    with gfile.GFile(ckpt_tmp_path, 'wb') as fp:
        fp.write(serialization.to_bytes(target))

    # Rename once serialization and writing finished.
    gfile.rename(ckpt_tmp_path, ckpt_path, overwrite=overwrite)
    logging.info('Saved checkpoint at %s', ckpt_path)
    print(ckpt_path)

    # Remove newer checkpoints
    if overwrite:
        ind = checkpoint_files.index(ckpt_path) + 1
        newer_ckpts = checkpoint_files[ind:]
        checkpoint_files = checkpoint_files[:ind]
        for path in newer_ckpts:
            logging.info('Removing checkpoint at %s', path)
            gfile.remove(path)

    # Remove old checkpoint files.
    if len(checkpoint_files) > keep:
        old_ckpts = checkpoint_files[:-keep]
        for path in old_ckpts:
            logging.info('Removing checkpoint at %s', path)
            gfile.remove(path)

    return ckpt_path
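
Both checkpoint savers rely on natural_sort so that, for instance, checkpoint_9 orders before checkpoint_10. natural_sort is defined elsewhere in checkpoints.py; the following is only a hedged sketch of the usual digit-aware implementation.

import re

def natural_sort_sketch(file_list):
    """Hypothetical stand-in: sorts paths by comparing digit runs as integers."""
    def key(path):
        return [int(tok) if tok.isdigit() else tok.lower()
                for tok in re.split(r'(\d+)', path)]
    return sorted(file_list, key=key)

# ['checkpoint_2', 'checkpoint_10', 'checkpoint_1']
#   -> ['checkpoint_1', 'checkpoint_2', 'checkpoint_10']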
Code example #5
def run(workdir,
        data,
        strategy,
        architecture,
        n_layers,
        n_hiddens,
        activation,
        dropout_rate,
        l2_penalty,
        w_init_name,
        b_init_name,
        optimizer_name,
        learning_rate,
        n_epochs,
        epochs_between_checkpoints,
        init_stddev,
        cnn_stride,
        reduce_learningrate=False,
        verbosity=0):
    """Runs the whole training procedure."""
    data_tr, data_te, dataset_info = data
    n_outputs = dataset_info['num_classes']

    with strategy.scope():
        optimizer = tf.keras.optimizers.get(optimizer_name)
        optimizer.learning_rate = learning_rate
        w_init = tf.keras.initializers.get(w_init_name)
        if w_init_name.lower() in ['truncatednormal', 'randomnormal']:
            w_init.stddev = init_stddev
        b_init = tf.keras.initializers.get(b_init_name)
        if b_init_name.lower() in ['truncatednormal', 'randomnormal']:
            b_init.stddev = init_stddev
        w_reg = tf.keras.regularizers.l2(
            l2_penalty) if l2_penalty > 0 else None

        if architecture == 'cnn' or architecture == 'cnnbn':
            model = build_cnn(n_layers, n_hiddens, n_outputs, dropout_rate,
                              activation, cnn_stride, w_reg, w_init, b_init,
                              architecture == 'cnnbn')
        elif architecture == 'fcn':
            model = build_fcn(n_layers, n_hiddens, n_outputs, dropout_rate,
                              activation, w_reg, w_init, b_init, False)
        else:
            assert False, 'Unknown architecture: %s' % architecture

        model.compile(
            optimizer=optimizer,
            loss=tf.keras.losses.SparseCategoricalCrossentropy(
                from_logits=True),
            metrics=['accuracy', 'mse', 'sparse_categorical_crossentropy'])

    # force the model to set input shapes and init weights
    for x, _ in data_tr:
        model.predict(x)
        if verbosity:
            model.summary()
        break

    ckpt = tf.train.Checkpoint(step=optimizer.iterations,
                               optimizer=optimizer,
                               model=model)
    ckpt_dir = os.path.join(workdir, 'temporary-ckpt')
    ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
    if ckpt_manager.latest_checkpoint:
        logging.info('restoring checkpoint: %s',
                     ckpt_manager.latest_checkpoint)
        print('restoring from %s' % ckpt_manager.latest_checkpoint)
        with strategy.scope():
            ckpt.restore(ckpt_manager.latest_checkpoint)
        info = restore_results(
            os.path.join(workdir, '.intermediate-results.json'))
        print(info, flush=True)
    else:
        info = {
            'steps': 0,
            'start_time': time.time(),
            'train_loss': dict(),
            'train_accuracy': dict(),
            'test_loss': dict(),
            'test_accuracy': dict(),
        }
        info.update(_get_workunit_params())  # Add command line parameters.

    logger = None
    starting_epoch = len(info['train_loss'])
    cur_epoch = starting_epoch
    for cur_epoch in range(starting_epoch, n_epochs):
        if reduce_learningrate and cur_epoch == n_epochs - (n_epochs // 10):
            optimizer.learning_rate = learning_rate / 10
        elif reduce_learningrate and cur_epoch == n_epochs - 2:
            optimizer.learning_rate = learning_rate / 100

        # Train until we reach the criterion or get NaNs
        try:
            # always keep checkpoints for the first few epochs
            # we evaluate first and train afterwards so we have the at-init data
            if cur_epoch < 4 or (cur_epoch % epochs_between_checkpoints) == 0:
                eval_model(model, data_tr, data_te, info, logger, cur_epoch,
                           workdir)

            model.fit(data_tr, epochs=1, verbose=verbosity)
            ckpt_manager.save()
            store_results(info,
                          os.path.join(workdir, '.intermediate-results.json'))

            dt = time.time() - info['start_time']
            logging.info('epoch %d (%3.2fs)', cur_epoch, dt)

        except tf.errors.InvalidArgumentError as e:
            # We got NaN in the loss, most likely gradients resulted in NaNs
            logging.info(str(e))
            info['status'] = 'NaN'
            logging.info('Stop training because NaNs encountered')
            break

    eval_model(model, data_tr, data_te, info, logger, cur_epoch + 1, workdir)
    store_results(info, os.path.join(workdir, 'results.json'))

    # we don't need the temporary checkpoints anymore
    gfile.rmtree(os.path.join(workdir, 'temporary-ckpt'))
    gfile.remove(os.path.join(workdir, '.intermediate-results.json'))
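
A hedged sketch of how this training procedure might be driven, assuming it runs inside the project that defines build_fcn, eval_model, and the results helpers; the synthetic dataset, the default distribution strategy, and every hyperparameter value below are placeholders rather than the project's defaults.

import tensorflow as tf

# Tiny synthetic image-classification dataset standing in for real data.
x = tf.random.normal([256, 28, 28, 1])
y = tf.random.uniform([256], maxval=10, dtype=tf.int32)
train_ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(32)
data = (train_ds, train_ds, {'num_classes': 10})

strategy = tf.distribute.get_strategy()  # default (single-device) strategy

run(workdir='/tmp/run0',
    data=data,
    strategy=strategy,
    architecture='fcn',
    n_layers=3,
    n_hiddens=256,
    activation='relu',
    dropout_rate=0.1,
    l2_penalty=0.0,
    w_init_name='TruncatedNormal',
    b_init_name='zeros',
    optimizer_name='adam',
    learning_rate=1e-3,
    n_epochs=4,
    epochs_between_checkpoints=2,
    init_stddev=0.05,
    cnn_stride=2,
    reduce_learningrate=False,
    verbosity=1)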
Code example #6
File: ppo.py  Project: hubayirp/fabric-vsf
def training_loop(
        env,
        eval_env,
        env_name,
        policy_and_value_net_fn,
        policy_and_value_optimizer_fn,
        output_dir,
        epochs=EPOCHS,
        n_optimizer_steps=N_OPTIMIZER_STEPS,
        print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
        target_kl=0.01,
        boundary=20,
        max_timestep=None,
        max_timestep_eval=20000,
        random_seed=None,
        gamma=GAMMA,
        lambda_=LAMBDA,
        epsilon=EPSILON,
        c1=1.0,
        c2=0.01,
        eval_every_n=1000,
        done_frac_for_policy_save=0.5,
        enable_early_stopping=True,
        n_evals=1,
        len_history_for_policy=4,
        eval_temperatures=(1.0, 0.5),
):
    """Runs the training loop for PPO, with fixed policy and value nets.

  Args:
    env: gym.Env to use for training.
    eval_env: gym.Env to use for evaluation.
    env_name: Name of the environment.
    policy_and_value_net_fn: Function defining the policy and value network.
    policy_and_value_optimizer_fn: Function defining the optimizer.
    output_dir: Output dir.
    epochs: Number of epochs to run for.
    n_optimizer_steps: Number of optimizer steps.
    print_every_optimizer_steps: How often to log during the policy optimization
      process.
    target_kl: Policy iteration early stopping.
    boundary: We pad trajectories at integer multiples of this number.
    max_timestep: If set to an integer, maximum number of time-steps in
      a trajectory. Used in the collect procedure.
    max_timestep_eval: If set to an integer, maximum number of time-steps in an
      evaluation trajectory. Used in the collect procedure.
    random_seed: Random seed.
    gamma: Reward discount factor.
    lambda_: N-step TD-error discount factor in GAE.
    epsilon: Random action probability in epsilon-greedy sampling.
    c1: Value loss coefficient.
    c2: Entropy loss coefficient.
    eval_every_n: How frequently to eval the policy.
    done_frac_for_policy_save: Fraction of the trajectories that should be done
      to checkpoint the policy.
    enable_early_stopping: Whether to enable early stopping.
    n_evals: Number of times to evaluate.
    len_history_for_policy: How much of history to give to the policy.
    eval_temperatures: Sequence of temperatures to try for categorical sampling
      during evaluation.
  """
    gfile.makedirs(output_dir)

    # Create summary writers and history.
    train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
    timing_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "timing"))
    eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

    train_sw.text("env_name", env_name)
    timing_sw.text("env_name", env_name)
    eval_sw.text("env_name", env_name)

    jax_rng_key = trax.get_random_number_generator_and_set_seed(random_seed)

    # Batch Observations Shape = [1, 1] + OBS, because we will eventually call
    # policy and value networks on shape [B, T] + OBS.
    batch_observations_shape = (1, 1) + env.observation_space.shape
    observations_dtype = env.observation_space.dtype

    assert isinstance(env.action_space, gym.spaces.Discrete)
    n_actions = env.action_space.n

    jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)

    # Initialize the policy and value network.
    policy_and_value_net_params, policy_and_value_net_apply = (
        policy_and_value_net_fn(key1, batch_observations_shape,
                                observations_dtype, n_actions))

    # Maybe restore the policy params. If there is nothing to restore, then
    # iteration = 0 and policy_and_value_net_params are returned as is.
    restore, policy_and_value_net_params, iteration = (maybe_restore_params(
        output_dir, policy_and_value_net_params))

    if restore:
        logging.info("Restored parameters from iteration [%d]", iteration)
        # We should start from the next iteration.
        iteration += 1

    policy_and_value_net_apply = jit(policy_and_value_net_apply)

    # Initialize the optimizers.
    policy_and_value_optimizer = (
        policy_and_value_optimizer_fn(policy_and_value_net_params))
    (policy_and_value_opt_state, policy_and_value_opt_update,
     policy_and_value_get_params) = policy_and_value_optimizer

    n_trajectories_done = 0
    last_saved_at = 0

    logging.info("Starting the PPO training loop.")
    for i in range(iteration, epochs):
        epoch_start_time = time.time()

        # Params we'll use to collect the trajectories.
        policy_and_value_net_params = policy_and_value_get_params(
            policy_and_value_opt_state)

        # A function to get the policy and value predictions.
        def get_predictions(observations, rng=None):
            """Returns log-probs, value predictions and key back."""
            key, key1 = jax_random.split(rng, num=2)

            log_probs, value_preds = policy_and_value_net_apply(
                observations, policy_and_value_net_params, rng=key1)

            return log_probs, value_preds, key

        # Evaluate the policy.
        policy_eval_start_time = time.time()
        if ((i + 1) % eval_every_n == 0) or (i == epochs - 1):
            jax_rng_key, key = jax_random.split(jax_rng_key, num=2)

            logging.vlog(1, "Epoch [% 6d] evaluating policy.", i)

            reward_stats = evaluate_policy(
                eval_env,
                get_predictions,
                temperatures=eval_temperatures,
                max_timestep=max_timestep_eval,
                n_evals=n_evals,
                len_history_for_policy=len_history_for_policy,
                rng=key)
            write_eval_reward_summaries(reward_stats, eval_sw, epoch=i)
        policy_eval_time = get_time(policy_eval_start_time)

        trajectory_collection_start_time = time.time()
        logging.vlog(1, "Epoch [% 6d] collecting trajectories.", i)
        jax_rng_key, key = jax_random.split(jax_rng_key)
        trajs, n_done, timing_info = collect_trajectories(
            env,
            policy_fn=get_predictions,
            n_trajectories=env.batch_size,
            max_timestep=max_timestep,
            rng=key,
            len_history_for_policy=len_history_for_policy,
            reset=(i == 0) or restore,
            epsilon=(10.0 / (i + 10.0)))  # this is a different epsilon.
        trajectory_collection_time = get_time(trajectory_collection_start_time)

        logging.vlog(1, "Collecting trajectories took %0.2f msec.",
                     trajectory_collection_time)

        avg_reward = float(sum(np.sum(traj[2]) for traj in trajs)) / len(trajs)
        max_reward = max(np.sum(traj[2]) for traj in trajs)
        min_reward = min(np.sum(traj[2]) for traj in trajs)

        train_sw.scalar("train/reward_mean_truncated", avg_reward, step=i)

        logging.vlog(1,
                     "Rewards avg=[%0.2f], max=[%0.2f], min=[%0.2f], all=%s",
                     avg_reward, max_reward, min_reward,
                     [float(np.sum(traj[2])) for traj in trajs])

        logging.vlog(
            1, "Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]",
            float(sum(len(traj[0]) for traj in trajs)) / len(trajs),
            max(len(traj[0]) for traj in trajs),
            min(len(traj[0]) for traj in trajs))
        logging.vlog(2, "Trajectory Lengths: %s",
                     [len(traj[0]) for traj in trajs])

        padding_start_time = time.time()
        (_, reward_mask, padded_observations, padded_actions, padded_rewards,
         padded_infos) = pad_trajectories(trajs, boundary=boundary)
        padding_time = get_time(padding_start_time)

        logging.vlog(1, "Padding trajectories took %0.2f msec.",
                     get_time(padding_start_time))
        logging.vlog(1, "Padded Observations' shape [%s]",
                     str(padded_observations.shape))
        logging.vlog(1, "Padded Actions' shape [%s]",
                     str(padded_actions.shape))
        logging.vlog(1, "Padded Rewards' shape [%s]",
                     str(padded_rewards.shape))

        # Some assertions.
        B, T = padded_actions.shape  # pylint: disable=invalid-name
        assert (B, T) == padded_rewards.shape
        assert (B, T) == reward_mask.shape
        assert (B, T + 1) == padded_observations.shape[:2]
        assert (B, T +
                1) + env.observation_space.shape == padded_observations.shape

        log_prob_recompute_start_time = time.time()
        assert ("log_prob_actions" in padded_infos
                and "value_predictions" in padded_infos)
        # These are the actual log-probabs and value predictions seen while picking
        # the actions.
        actual_log_probabs_traj = padded_infos["log_prob_actions"]
        actual_value_predictions_traj = padded_infos["value_predictions"]

        assert (B, T) == actual_log_probabs_traj.shape[:2]
        A = actual_log_probabs_traj.shape[2]  # pylint: disable=invalid-name
        assert (B, T, 1) == actual_value_predictions_traj.shape

        # TODO(afrozm): log-probabs doesn't need to be (B, T+1, A) it can do with
        # (B, T, A), so make that change throughout.

        # NOTE: We don't have the log-probabs and value-predictions for the last
        # observation, so we re-calculate for everything, but use the original ones
        # for all but the last time-step.
        jax_rng_key, key = jax_random.split(jax_rng_key)
        log_probabs_traj, value_predictions_traj, _ = get_predictions(
            padded_observations, rng=key)

        assert (B, T + 1, A) == log_probabs_traj.shape
        assert (B, T + 1, 1) == value_predictions_traj.shape

        # Concatenate the last time-step's log-probabs and value predictions to the
        # actual log-probabs and value predictions and use those going forward.
        log_probabs_traj = np.concatenate(
            (actual_log_probabs_traj, log_probabs_traj[:, -1:, :]), axis=1)
        value_predictions_traj = np.concatenate(
            (actual_value_predictions_traj, value_predictions_traj[:, -1:, :]),
            axis=1)

        log_prob_recompute_time = get_time(log_prob_recompute_start_time)

        # Linear annealing from 0.1 to 0.0
        # epsilon_schedule = epsilon if epochs == 1 else epsilon * (1.0 -
        #                                                           (i /
        #                                                            (epochs - 1)))

        # Constant epsilon.
        epsilon_schedule = epsilon

        # Compute value and ppo losses.
        jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)
        logging.vlog(2, "Starting to compute P&V loss.")
        loss_compute_start_time = time.time()
        cur_combined_loss, cur_ppo_loss, cur_value_loss, entropy_bonus = (
            combined_loss(policy_and_value_net_params,
                          log_probabs_traj,
                          value_predictions_traj,
                          policy_and_value_net_apply,
                          padded_observations,
                          padded_actions,
                          padded_rewards,
                          reward_mask,
                          gamma=gamma,
                          lambda_=lambda_,
                          epsilon=epsilon_schedule,
                          c1=c1,
                          c2=c2,
                          rng=key1))
        loss_compute_time = get_time(loss_compute_start_time)
        logging.vlog(
            1,
            "Calculating P&V loss [%10.2f(%10.2f, %10.2f, %10.2f)] took %0.2f msec.",
            cur_combined_loss, cur_value_loss, cur_ppo_loss, entropy_bonus,
            get_time(loss_compute_start_time))

        jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)
        logging.vlog(1, "Policy and Value Optimization")
        optimization_start_time = time.time()
        keys = jax_random.split(key1, num=n_optimizer_steps)
        for j in range(n_optimizer_steps):
            k1, k2, k3 = jax_random.split(keys[j], num=3)
            t = time.time()
            # Update the optimizer state.
            policy_and_value_opt_state = policy_and_value_opt_step(
                j,
                policy_and_value_opt_state,
                policy_and_value_opt_update,
                policy_and_value_get_params,
                policy_and_value_net_apply,
                log_probabs_traj,
                value_predictions_traj,
                padded_observations,
                padded_actions,
                padded_rewards,
                reward_mask,
                c1=c1,
                c2=c2,
                gamma=gamma,
                lambda_=lambda_,
                epsilon=epsilon_schedule,
                rng=k1)

            # Compute the approx KL for early stopping.
            new_policy_and_value_net_params = policy_and_value_get_params(
                policy_and_value_opt_state)

            log_probab_actions_new, _ = policy_and_value_net_apply(
                padded_observations, new_policy_and_value_net_params, rng=k2)

            approx_kl = approximate_kl(log_probab_actions_new,
                                       log_probabs_traj, reward_mask)

            early_stopping = enable_early_stopping and approx_kl > 1.5 * target_kl
            if early_stopping:
                logging.vlog(
                    1,
                    "Early stopping policy and value optimization at iter: %d, "
                    "with approx_kl: %0.2f", j, approx_kl)
                # We don't return right-away, we want the below to execute on the last
                # iteration.

            t2 = time.time()
            if (((j + 1) % print_every_optimizer_steps == 0)
                    or (j == n_optimizer_steps - 1) or early_stopping):
                # Compute and log the loss.
                (loss_combined, loss_ppo, loss_value,
                 entropy_bonus) = (combined_loss(
                     new_policy_and_value_net_params,
                     log_probabs_traj,
                     value_predictions_traj,
                     policy_and_value_net_apply,
                     padded_observations,
                     padded_actions,
                     padded_rewards,
                     reward_mask,
                     gamma=gamma,
                     lambda_=lambda_,
                     epsilon=epsilon_schedule,
                     c1=c1,
                     c2=c2,
                     rng=k3))
                logging.vlog(
                    1, "One Policy and Value grad desc took: %0.2f msec",
                    get_time(t, t2))
                logging.vlog(
                    1, "Combined Loss(value, ppo, entropy_bonus) [%10.2f] ->"
                    " [%10.2f(%10.2f,%10.2f,%10.2f)]", cur_combined_loss,
                    loss_combined, loss_value, loss_ppo, entropy_bonus)

            if early_stopping:
                break

        optimization_time = get_time(optimization_start_time)

        logging.vlog(
            1, "Total Combined Loss reduction [%0.2f]%%",
            (100 *
             (cur_combined_loss - loss_combined) / np.abs(cur_combined_loss)))

        # Save parameters every time we see the end of at least a fraction of batch
        # number of trajectories that are done (not completed -- completed includes
        # truncated and done).
        # Also don't save too frequently, enforce a minimum gap.
        # Or if this is the last iteration.
        policy_save_start_time = time.time()
        n_trajectories_done += n_done
        # TODO(afrozm): Refactor to trax.save_state.
        if ((
            (n_trajectories_done >= done_frac_for_policy_save * env.batch_size)
                and (i - last_saved_at > eval_every_n) and
            (((i + 1) % eval_every_n == 0))) or (i == epochs - 1)):
            logging.vlog(1, "Epoch [% 6d] saving model.", i)
            old_model_files = gfile.glob(
                os.path.join(output_dir, "model-??????.pkl"))
            params_file = os.path.join(output_dir, "model-%06d.pkl" % i)
            with gfile.GFile(params_file, "wb") as f:
                pickle.dump(policy_and_value_net_params, f)
            # Remove the old model files.
            for path in old_model_files:
                gfile.remove(path)
            # Reset this number.
            n_trajectories_done = 0
            last_saved_at = i
        policy_save_time = get_time(policy_save_start_time)

        epoch_time = get_time(epoch_start_time)

        logging.info(
            "Epoch [% 6d], Reward[min, max, avg] [%5.2f,%5.2f,%5.2f], Combined"
            " Loss(value, ppo, entropy) [%2.5f(%2.5f,%2.5f,%2.5f)]", i,
            min_reward, max_reward, avg_reward, loss_combined, loss_value,
            loss_ppo, entropy_bonus)

        timing_dict = {
            "epoch": epoch_time,
            "policy_eval": policy_eval_time,
            "trajectory_collection": trajectory_collection_time,
            "padding": padding_time,
            "log_prob_recompute": log_prob_recompute_time,
            "loss_compute": loss_compute_time,
            "optimization": optimization_time,
            "policy_save": policy_save_time,
        }

        timing_dict.update(timing_info)

        for k, v in timing_dict.items():
            timing_sw.scalar("timing/%s" % k, v, step=i)

        max_key_len = max(len(k) for k in timing_dict)
        timing_info_list = [
            "%s : % 10.2f" % (k.rjust(max_key_len + 1), v)
            for k, v in sorted(timing_dict.items())
        ]
        logging.info("Epoch [% 6d], Timings: \n%s", i,
                     "\n".join(timing_info_list))

        # Reset restore.
        restore = False

        # Flush summary writers once in a while.
        if (i + 1) % 1000 == 0 or i == epochs - 1:
            train_sw.flush()
            timing_sw.flush()
            eval_sw.flush()
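
The main change relative to example #3 is that the log-probabilities and value predictions recorded while acting are reused, and only the final time-step (which has no recorded action) is recomputed and appended. A toy numpy sketch of that shape bookkeeping, with made-up sizes:

import numpy as np

B, T, A = 2, 5, 4                                 # toy batch/time/action sizes
actual_log_probs = np.zeros((B, T, A))            # recorded during collection
recomputed_log_probs = np.zeros((B, T + 1, A))    # from get_predictions

# Keep the recorded values; append only the recomputed last step.
log_probs_traj = np.concatenate(
    (actual_log_probs, recomputed_log_probs[:, -1:, :]), axis=1)
assert log_probs_traj.shape == (B, T + 1, A)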
Code example #7
File: io.py  Project: cybertrust1/DeepSpeech-1
def remove_remote(filename):
    """
    Wrapper that can remove local and remote files like `gs://...`
    """
    # Delegate to gfile, which transparently handles local and remote paths.
    return gfile.remove(filename)
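
A brief usage note: because gfile dispatches on the path scheme, the same call covers local files and remote objects; both paths below are placeholders.

remove_remote('/tmp/scratch/old_export.pbmm')     # local file
remove_remote('gs://my-bucket/exports/old.pbmm')  # remote object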