Example #1
 def stack_conv2d_layers(self, embedding):
     """Stack the Conv2D layers"""
     conv2dz = zipsame(self.hps.nums_filters, self.hps.filter_shapes,
                       self.hps.stride_shapes)
     for conv2d_layer_index, zipped_conv2d in enumerate(conv2dz, start=1):
         conv2d_layer_id = "conv2d{}".format(conv2d_layer_index)
         num_filters, filter_shape, stride_shape = zipped_conv2d  # unpack
         # Add conv2d hidden layer and non-linearity
         embedding = tf.layers.conv2d(
             inputs=embedding,
             filters=num_filters,
             kernel_size=filter_shape,
             strides=stride_shape,
             padding='valid',
             data_format='channels_last',
             dilation_rate=(1, 1),
             activation=parse_nonlin(self.hps.hid_nonlin),
             use_bias=True,
             kernel_initializer=self.hid_initializers['w'],
             bias_initializer=self.hid_initializers['b'],
             kernel_regularizer=self.hid_regularizers['w'],
             bias_regularizer=None,
             activity_regularizer=None,
             kernel_constraint=None,
             bias_constraint=None,
             trainable=True,
             name=conv2d_layer_id,
             reuse=tf.AUTO_REUSE)
     # Flatten between conv2d layers and fully-connected layers
     embedding = tf.layers.flatten(inputs=embedding, name='flatten')
     return embedding
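Note: every example on this page relies on a helper named `zipsame`, whose definition is not shown here. As a rough mental model (an assumption, not the projects' actual code), it behaves like a strict `zip` that refuses length mismatches:

def zipsame(*seqs):
    """Minimal sketch: zip sequences, asserting they all have the same length."""
    length = len(seqs[0])
    assert all(len(seq) == length for seq in seqs), "sequences differ in length"
    return zip(*seqs)

# e.g. list(zipsame([32, 64], [(3, 3), (3, 3)], [(1, 1), (2, 2)])) pairs up hyperparameters
# as in Example #1, while lists of different lengths raise an AssertionError instead of
# being silently truncated like with the built-in zip.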
Example #2
def zipsame_prune(l1, l2):
    out = []
    for a, b in zipsame(l1, l2):
        if b is None:
            continue
        out.append((a, b))
    return out
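A quick usage note for `zipsame_prune` (assuming the `zipsame` helper is in scope): it pairs the two lists element-wise and drops every pair whose second element is None, e.g.:

names = ['w', 'b', 'extra']
grads = [0.1, None, 0.3]               # None marks pairs to discard
print(zipsame_prune(names, grads))     # [('w', 0.1), ('extra', 0.3)]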
Example #3
 def update_priorities(self, idxs, priorities):
     # Update priorities via the legacy method, used w/ ranked approach
     idxs, priorities = super().update_priorities(idxs, priorities)
     # Register whether a transition is b/g in the UNREAL-specific sum trees
     for idx, priority in zipsame(idxs, priorities):
         # Decide whether the transition to be added is good or bad
         # Get the rank from the priority
         # Note: UnrealRB inherits from PER w/ 'ranked' set to True
         if idx < self.num_demos:
             # When the transition is from the demos, always set it as 'good' regardless
             self.b_sum_st[idx] = 0
             self.g_sum_st[idx] = 1
         else:
             rank = (1. / priority) - 1
             thres = floor(.5 * self.num_entries)
             is_g = int(rank < thres)  # 1 if the rank falls in the better half ('good'), else 0
             # Fill the good and bad sum segment trees w/ the obtained value
             self.b_sum_st[idx] = 1 - is_g
             self.g_sum_st[idx] = is_g
     if debug:
         # Verify updates
         # Compute the cardinalities of virtual sub-buffers
         b_num_entries = self.b_sum_st.sum(end=self.num_entries)
         g_num_entries = self.g_sum_st.sum(end=self.num_entries)
         print("[num entries]    b: {}    | g: {}".format(b_num_entries, g_num_entries))
         print("total num entries: {}".format(self.num_entries))
Example #4
def run(args):
    """Spawn jobs"""
    # Create meta-experiment identifier
    meta = rand_id()
    # Define experiment type
    if args.rand:
        type_exp = 'hpsearch'
    else:
        type_exp = 'sweep'
    # Get hyperparameter configurations
    if args.rand:
        # Get a number of random hyperparameter configurations
        hpmaps = [get_rand_hps(args, meta, args.num_seeds) for _ in range(args.num_rand_trials)]
        # Flatten into a 1-dim list
        hpmaps = flatten_lists(hpmaps)
    else:
        # Get the deterministic spectrum of specified hyperparameters
        hpmaps = get_spectrum_hps(args, meta, args.num_seeds)
    # Create associated task strings
    exp_strs = [format_exp_str(args, hpmap) for hpmap in hpmaps]
    if not len(exp_strs) == len(set(exp_strs)):
        # Terminate in case of duplicate experiment (extremely unlikely though)
        raise ValueError("bad luck, there are dupes -> Try again :)")
    # Create the job maps
    job_maps = [get_job_map(args,
                            meta,
                            i,
                            hpmap['env_id'],
                            hpmap['seed'],
                            '0' if args.task == 'ppo' else hpmap['num_demos'],
                            type_exp)
                for i, hpmap in enumerate(hpmaps)]
    # Finally get all the required job strings
    job_strs = [format_job_str(args, jm, es) for jm, es in zipsame(job_maps, exp_strs)]
    # Spawn the jobs
    for i, (jm, js) in enumerate(zipsame(job_maps, job_strs)):
        print('-' * 10 + "> job #{} launcher content:".format(i))
        print(js + "\n")
        job_name = "{}.sh".format(jm['job-name'])
        with open(job_name, 'w') as f:
            f.write(js)
        if args.call:
            # Spawn the job!
            call(["sbatch", "./{}".format(job_name)])
    # Summarize the number of jobs spawned
    print("total num job (successfully) spawned: {}".format(len(job_strs)))
Example #5
 def __init__(self, limit, ob_shape, ac_shape):
     self.limit = limit
     self.ob_shape = ob_shape
     self.ac_shape = ac_shape
     self.num_demos = 0
     self.atom_names = ['obs0', 'acs', 'rews', 'dones1', 'obs1']
     self.atom_shapes = [self.ob_shape, self.ac_shape, (1,), (1,), self.ob_shape]
     # Create one `RingBuffer` object for every atom in a transition
     self.ring_buffers = {atom_name: RingBuffer(self.limit, atom_shape)
                          for atom_name, atom_shape in zipsame(self.atom_names,
                                                               self.atom_shapes)}
Example #6
    def update_priorities(self, idxs, priorities):
        """Update priorities according to the PER paper, i.e. by updating
        only the priority of sampled transitions. A priority priorities[i] is
        assigned to the transition at index indices[i].
        Note: not in use in the vanilla setting, but here if needed in extensions.
        """
        global debug
        if self.ranked:
            # Override the priorities to be 1 / (rank(priority) + 1)
            # Add new index, priority pairs to the list
            self.i_p.update({i: p for i, p in zipsame(idxs, priorities)})
            # Rank the indices by priorities
            i_sorted_by_p = sorted(self.i_p.items(), key=lambda t: t[1], reverse=True)
            # Create the index, rank dict
            i_r = {i: i_sorted_by_p.index((i, p)) for i, p in self.i_p.items()}
            # Unpack indices and ranks
            _idxs, ranks = zipsame(*i_r.items())
            # Override the indices and priorities
            idxs = list(_idxs)
            priorities = [1. / (rank + 1) for rank in ranks]  # start ranks at 1
            if debug:
                # Verify that the priorities have been properly overridden
                for idx, priority in zipsame(idxs, priorities):
                    print("index: {}    | priority: {}".format(idx, priority))

        assert len(idxs) == len(priorities), "the two arrays must be the same length"
        for idx, priority in zipsame(idxs, priorities):
            assert priority > 0, "priorities must be positive"
            assert 0 <= idx < self.num_entries, "no element in buffer associated w/ index"
            if idx < self.num_demos:
                # Add a priority bonus when replaying a demo
                priority += self.demos_eps
            self.sum_st[idx] = priority ** self.alpha
            self.min_st[idx] = priority ** self.alpha
            # Update max priority currently in the buffer
            self.max_priority = max(priority, self.max_priority)

        if self.ranked:
            # Return indices and associated overridden priorities
            # Note: returned values are only used in the UNREAL priority update function
            return idxs, priorities
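When `ranked` is enabled above, the raw priorities are discarded and replaced by 1 / (rank + 1), with rank 0 assigned to the largest raw priority. A minimal standalone illustration with made-up numbers:

i_p = {7: 0.9, 3: 0.1, 5: 0.4}                                  # index -> raw priority
i_sorted_by_p = sorted(i_p.items(), key=lambda t: t[1], reverse=True)
i_r = {i: i_sorted_by_p.index((i, p)) for i, p in i_p.items()}  # index -> rank
overridden = {i: 1. / (r + 1) for i, r in i_r.items()}
print(overridden)                                               # {7: 1.0, 3: 0.333..., 5: 0.5}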
Example #7
def log_module_info(logger, name, *components):
    assert len(components) > 0, "components list is empty"
    for component in components:
        logger.info("logging {}/{} specs".format(name, component.name))
        names = [var.name for var in component.trainable_vars]
        shapes = [var_shape(var) for var in component.trainable_vars]
        num_paramss = [numel(var) for var in component.trainable_vars]
        zipped_info = zipsame(names, shapes, num_paramss)
        logger.info(columnize(names=['name', 'shape', 'num_params'],
                              tuples=zipped_info,
                              widths=[40, 16, 10]))
        logger.info("  total num params: {}".format(sum(num_paramss)))
Example #8
def columnize(names, tuples, widths, indent=2):
    """Generate and return the content of table
    (w/o logging or printing anything)

    Args:
        width (int): Width of each cell in the table
        indent (int): Indentation spacing prepended to every row in the table
    """
    indent_space = indent * ' '
    # Add row containing the names
    table = indent_space + " | ".join(cell(name, width) for name, width in zipsame(names, widths))
    table_width = len(table)
    # Add header hline
    table += '\n' + indent_space + ('-' * table_width)
    for tuple_ in tuples:
        # Add a new row
        table += '\n' + indent_space
        table += " | ".join(cell(value, width) for value, width in zipsame(tuple_, widths))
    # Add closing hline
    table += '\n' + indent_space + ('-' * table_width)
    return table
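For illustration, `columnize` could be exercised as below (assuming `columnize` and `zipsame` above are in scope); the `cell` helper it calls is not shown on this page, so the stand-in here (left-justify and truncate to the given width) is only a guess at its behavior:

def cell(value, width):
    # Hypothetical stand-in for the real `cell` helper
    return str(value).ljust(width)[:width]

rows = [('conv2d1/kernel:0', '(3, 3, 4, 32)', 1152),
        ('conv2d1/bias:0', '(32,)', 32)]
print(columnize(names=['name', 'shape', 'num_params'],
                tuples=rows,
                widths=[20, 16, 10]))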
Example #9
 def add_demo_transitions_to_mem(self, dset):
     """Add transitions from expert demonstration trajectories to memory"""
     # Ensure the replay buffer is empty, as demos need to be added first
     assert self.num_entries == 0 and self.num_demos == 0
     logger.info("adding demonstrations to memory")
     # Zip transition atoms
     transitions = zipsame(dset.obs0, dset.acs, dset.env_rews, dset.obs1, dset.dones1)
     # Note: careful w/ the order, it should correspond to the order in `append` signature
     for transition in transitions:
         self.append(*transition, is_demo=True)
         self.num_demos += 1
     assert self.num_demos == self.num_entries
     logger.info("  num entries in memory after addition: {}".format(self.num_entries))
Example #10
def flatgrad(loss, var_list, clip_norm=None):
    """Returns a list of sum(dy/dx) for each x in `var_list`
    Clipping is done by global norm (paper: https://arxiv.org/abs/1211.5063)
    """
    grads = tf.gradients(loss, var_list)
    if clip_norm is not None:
        grads, _ = tf.clip_by_global_norm(grads, clip_norm=clip_norm)
    vars_and_grads = zipsame(var_list, grads)  # zip with extra security
    for index, (var, grad) in enumerate(vars_and_grads):
        # If the gradient gets stopped for some obscure reason, set the grad to the zero vector
        _grad = grad if grad is not None else tf.zeros_like(var)
        # Reshape the grad into a vector
        grads[index] = tf.reshape(_grad, [numel(var)])
    return tf.concat(grads, axis=0)
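A pure-NumPy analogue of the flatten-and-concatenate step that `flatgrad` performs (illustration only, not TensorFlow code):

import numpy as np

grads = [np.ones((2, 3)), None, np.ones((4,))]   # per-variable grads; None = stopped gradient
vars_ = [np.zeros((2, 3)), np.zeros((5,)), np.zeros((4,))]
flat = np.concatenate([
    (g if g is not None else np.zeros_like(v)).reshape(-1)
    for g, v in zip(grads, vars_)
])
print(flat.shape)                                # (15,), i.e. 2*3 + 5 + 4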
Example #11
def get_target_updates(vars_, targ_vars, polyak):
    """Return assignment ops for target network updates.
    Hard updates are used for initialization only, while soft updates are
    used throughout the training process, at every iteration.
    Note that DQN uses hard updates while training, but those updates
    are not performed every iteration (only once every XX iterations).
    """
    logger.info("setting up target updates")
    hard_updates = []
    soft_updates = []
    assert len(vars_) == len(targ_vars)
    for var_, targ_var in zipsame(vars_, targ_vars):
        logger.info('  {} <- {}'.format(targ_var.name, var_.name))
        hard_updates.append(tf.assign(targ_var, var_))
        soft_updates.append(
            tf.assign(targ_var, (1. - polyak) * targ_var + polyak * var_))
    assert len(hard_updates) == len(vars_)
    assert len(soft_updates) == len(vars_)
    return tf.group(*hard_updates), tf.group(
        *soft_updates)  # ops that group ops
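The soft update assigned above is an exponential moving average of the online parameters, so repeated applications pull the target toward the online value. A scalar illustration with made-up values:

polyak = 0.005
var_, targ = 1.0, 0.0
for _ in range(1000):
    targ = (1. - polyak) * targ + polyak * var_
print(round(targ, 3))   # 0.993: after 1000 soft updates the target has almost caught up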
Example #12
def get_p_actor_updates(actor, perturbed_actor, pn_std):
    """Return assignment ops for actor parameters noise perturbations.
    The perturbations consist in applying additive gaussian noise the the perturbable
    actor variables, while simply leaving the non-perturbable ones untouched.
    """
    assert len(actor.vars) == len(perturbed_actor.vars)
    assert len(actor.perturbable_vars) == len(perturbed_actor.perturbable_vars)

    updates = []
    for var_, perturbed_var in zipsame(actor.vars, perturbed_actor.vars):
        if var_ in actor.perturbable_vars:
            logger.info("  {} <- {} + noise".format(perturbed_var.name,
                                                    var_.name))
            noised_up_var = var_ + tf.random_normal(
                tf.shape(var_), mean=0., stddev=pn_std)
            updates.append(tf.assign(perturbed_var, noised_up_var))
        else:
            logger.info("  {} <- {}".format(perturbed_var.name, var_.name))
            updates.append(tf.assign(perturbed_var, var_))
    assert len(updates) == len(actor.vars)
    return tf.group(*updates)
Example #13
def learn(comm, env, xpo_agent_wrapper, sample_or_mode, gamma, max_kl,
          save_frequency, ckpt_dir, summary_dir, timesteps_per_batch,
          batch_size, experiment_name, ent_reg_scale, gae_lambda, cg_iters,
          cg_damping, vf_iters, vf_lr, max_iters):

    rank = comm.Get_rank()

    # Create policies
    pi = xpo_agent_wrapper('pi')
    old_pi = xpo_agent_wrapper('old_pi')

    # Create and retrieve already-existing placeholders
    ob = get_placeholder_cached(name='ob')
    ac = pi.pd_type.sample_placeholder([None])
    adv = tf.placeholder(name='adv', dtype=tf.float32, shape=[None])
    ret = tf.placeholder(name='ret', dtype=tf.float32, shape=[None])
    flat_tangent = tf.placeholder(name='flat_tan',
                                  dtype=tf.float32,
                                  shape=[None])

    # Build graphs
    kl_mean = tf.reduce_mean(old_pi.pd_pred.kl(pi.pd_pred))
    ent_mean = tf.reduce_mean(pi.pd_pred.entropy())
    ent_bonus = ent_reg_scale * ent_mean
    vf_err = tf.reduce_mean(tf.square(pi.v_pred - ret))  # MC error
    # The surrogate objective is defined as: advantage * pnew / pold
    ratio = tf.exp(pi.pd_pred.logp(ac) - old_pi.pd_pred.logp(ac))  # IS
    surr_gain = tf.reduce_mean(ratio * adv)  # surrogate objective (CPI)
    # Add entropy bonus
    optim_gain = surr_gain + ent_bonus

    losses = OrderedDict()

    # Add losses
    losses.update({
        'pol_kl_mean': kl_mean,
        'pol_ent_mean': ent_mean,
        'pol_ent_bonus': ent_bonus,
        'pol_surr_gain': surr_gain,
        'pol_optim_gain': optim_gain,
        'pol_vf_err': vf_err
    })

    # Build natural gradient material
    get_flat = GetFlat(pi.pol_trainable_vars)
    set_from_flat = SetFromFlat(pi.pol_trainable_vars)
    kl_grads = tf.gradients(kl_mean, pi.pol_trainable_vars)
    shapes = [var.get_shape().as_list() for var in pi.pol_trainable_vars]
    start = 0
    tangents = []
    for shape in shapes:
        sz = intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    # Create the gradient vector product
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(kl_grads, tangents)
    ])
    # Create the Fisher vector product
    fvp = flatgrad(gvp, pi.pol_trainable_vars)

    # Make the current `pi` become the next `old_pi`
    zipped = zipsame(old_pi.vars, pi.vars)
    updates_op = []
    for k, v in zipped:
        # Populate list of assignment operations
        logger.info("assignment: {} <- {}".format(k, v))
        assign_op = tf.assign(k, v)
        updates_op.append(assign_op)
    assert len(updates_op) == len(pi.vars)

    # Create mpi adam optimizer for the value function
    vf_optimizer = MpiAdamOptimizer(comm=comm,
                                    clip_norm=5.0,
                                    learning_rate=vf_lr,
                                    name='vf_adam')
    optimize_vf = vf_optimizer.minimize(loss=vf_err,
                                        var_list=pi.vf_trainable_vars)

    # Create gradients
    grads = flatgrad(optim_gain, pi.pol_trainable_vars)

    # Create callable objects
    assign_old_eq_new = TheanoFunction(inputs=[], outputs=updates_op)
    compute_losses = TheanoFunction(inputs=[ob, ac, adv, ret],
                                    outputs=list(losses.values()))
    compute_losses_grads = TheanoFunction(inputs=[ob, ac, adv, ret],
                                          outputs=list(losses.values()) +
                                          [grads])
    compute_fvp = TheanoFunction(inputs=[flat_tangent, ob, ac, adv],
                                 outputs=fvp)
    optimize_vf = TheanoFunction(inputs=[ob, ret], outputs=optimize_vf)

    # Initialise variables
    initialize()

    # Sync params of all processes with the params of the root process
    theta_init = get_flat()
    comm.Bcast(theta_init, root=0)
    set_from_flat(theta_init)

    vf_optimizer.sync_from_root(pi.vf_trainable_vars)

    # Create context manager that records the time taken by encapsulated ops
    timed = timed_cm_wrapper(comm, logger)

    if rank == 0:
        # Create summary writer
        summary_writer = tf.summary.FileWriterCache.get(summary_dir)

    # Create segment generator
    seg_gen = traj_segment_generator(env, pi, timesteps_per_batch,
                                     sample_or_mode)

    eps_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()

    # Define rolling buffers for recent stats aggregation
    maxlen = 100
    len_buffer = deque(maxlen=maxlen)
    env_ret_buffer = deque(maxlen=maxlen)
    pol_losses_buffer = deque(maxlen=maxlen)

    while iters_so_far <= max_iters:

        pretty_iter(logger, iters_so_far)
        pretty_elapsed(logger, tstart)

        # Verify that the processes are still in sync
        if iters_so_far > 0 and iters_so_far % 10 == 0:
            vf_optimizer.check_synced(pi.vf_trainable_vars)
            logger.info("vf params still in sync across processes")

        # Save the model
        if rank == 0 and iters_so_far % save_frequency == 0 and ckpt_dir is not None:
            model_path = osp.join(ckpt_dir, experiment_name)
            save_state(model_path, iters_so_far=iters_so_far)
            logger.info("saving model")
            logger.info("  @: {}".format(model_path))

        with timed("sampling mini-batch"):
            seg = seg_gen.__next__()

        augment_segment_gae_stats(seg, gamma, gae_lambda, rew_key="env_rews")

        # Standardize advantage function estimate
        seg['advs'] = (seg['advs'] - seg['advs'].mean()) / (seg['advs'].std() +
                                                            1e-8)

        # Update running mean and std
        if hasattr(pi, 'obs_rms'):
            with timed("normalizing obs via rms"):
                pi.obs_rms.update(seg['obs'], comm)

        def fisher_vector_product(p):
            computed_fvp = compute_fvp({
                flat_tangent: p,
                ob: seg['obs'],
                ac: seg['acs'],
                adv: seg['advs']
            })
            return mpi_mean_like(computed_fvp, comm) + cg_damping * p

        assign_old_eq_new({})

        # Compute gradients
        with timed("computing gradients"):
            *loss_before, g = compute_losses_grads({
                ob: seg['obs'],
                ac: seg['acs'],
                adv: seg['advs'],
                ret: seg['td_lam_rets']
            })

        loss_before = mpi_mean_like(loss_before, comm)

        g = mpi_mean_like(g, comm)

        if np.allclose(g, 0):
            logger.info("got zero gradient -> not updating")
        else:
            with timed("performing conjugate gradient procedure"):
                step_direction = conjugate_gradient(f_Ax=fisher_vector_product,
                                                    b=g,
                                                    cg_iters=cg_iters,
                                                    verbose=(rank == 0))
            assert np.isfinite(step_direction).all()
            shs = 0.5 * step_direction.dot(
                fisher_vector_product(step_direction))
            # shs is (1/2)*s^T*A*s in the paper
            lm = np.sqrt(shs / max_kl)
            # lm is 1/beta in the paper (max_kl is user-specified delta)
            full_step = step_direction / lm  # beta*s
            expected_improve = g.dot(full_step)  # project s on g
            surr_before = loss_before[4]  # 5th entry in the loss list, i.e. 'pol_optim_gain'
            step_size = 1.0
            theta_before = get_flat()

            with timed("updating policy"):
                for _ in range(10):  # try up to 10 times until the step size is deemed OK
                    # Update the policy parameters
                    theta_new = theta_before + full_step * step_size
                    set_from_flat(theta_new)
                    pol_losses = compute_losses({
                        ob: seg['obs'],
                        ac: seg['acs'],
                        adv: seg['advs'],
                        ret: seg['td_lam_rets']
                    })

                    pol_losses_buffer.append(pol_losses)

                    pol_losses_mpi_mean = mpi_mean_like(pol_losses, comm)
                    surr = pol_losses_mpi_mean[4]
                    kl = pol_losses_mpi_mean[0]
                    actual_improve = surr - surr_before
                    logger.info("  expected: {:.3f} | actual: {:.3f}".format(
                        expected_improve, actual_improve))
                    if not np.isfinite(pol_losses_mpi_mean).all():
                        logger.info("  got non-finite value of losses :(")
                    elif kl > max_kl * 1.5:
                        logger.info(
                            "  violated KL constraint -> shrinking step.")
                    elif actual_improve < 0:
                        logger.info(
                            "  surrogate didn't improve -> shrinking step.")
                    else:
                        logger.info("  stepsize fine :)")
                        break
                    step_size *= 0.5  # backtracking when the step size is deemed inappropriate
                else:
                    logger.info("  couldn't compute a good step")
                    set_from_flat(theta_before)

        # Create Feeder object to iterate over (ob, ret) pairs
        feeder = Feeder(data_map={
            'obs': seg['obs'],
            'td_lam_rets': seg['td_lam_rets']
        },
                        enable_shuffle=True)

        # Update state-value function
        with timed("updating value function"):
            for _ in range(vf_iters):
                for minibatch in feeder.get_feed(batch_size=batch_size):
                    optimize_vf({
                        ob: minibatch['obs'],
                        ret: minibatch['td_lam_rets']
                    })

        # Log policy update statistics
        logger.info("logging pol training losses (log)")
        pol_losses_np_mean = np.mean(pol_losses_buffer, axis=0)
        pol_losses_mpi_mean = mpi_mean_reduce(pol_losses_buffer, comm, axis=0)
        zipped_pol_losses = zipsame(list(losses.keys()), pol_losses_np_mean,
                                    pol_losses_mpi_mean)
        logger.info(
            columnize(names=['name', 'local', 'global'],
                      tuples=zipped_pol_losses,
                      widths=[20, 16, 16]))

        # Log statistics

        logger.info("logging misc training stats (log + csv)")
        # Gather statistics across workers
        local_lens_rets = (seg['ep_lens'], seg['ep_env_rets'])
        gathered_lens_rets = comm.allgather(local_lens_rets)
        lens, env_rets = map(flatten_lists, zip(*gathered_lens_rets))
        # Extend the deques of recorded statistics
        len_buffer.extend(lens)
        env_ret_buffer.extend(env_rets)
        ep_len_mpi_mean = np.mean(len_buffer)
        ep_env_ret_mpi_mean = np.mean(env_ret_buffer)
        logger.record_tabular('ep_len_mpi_mean', ep_len_mpi_mean)
        logger.record_tabular('ep_env_ret_mpi_mean', ep_env_ret_mpi_mean)
        eps_this_iter = len(lens)
        timesteps_this_iter = sum(lens)
        eps_so_far += eps_this_iter
        timesteps_so_far += timesteps_this_iter
        eps_this_iter_mpi_mean = mpi_mean_like(eps_this_iter, comm)
        timesteps_this_iter_mpi_mean = mpi_mean_like(timesteps_this_iter, comm)
        eps_so_far_mpi_mean = mpi_mean_like(eps_so_far, comm)
        timesteps_so_far_mpi_mean = mpi_mean_like(timesteps_so_far, comm)
        logger.record_tabular('eps_this_iter_mpi_mean', eps_this_iter_mpi_mean)
        logger.record_tabular('timesteps_this_iter_mpi_mean',
                              timesteps_this_iter_mpi_mean)
        logger.record_tabular('eps_so_far_mpi_mean', eps_so_far_mpi_mean)
        logger.record_tabular('timesteps_so_far_mpi_mean',
                              timesteps_so_far_mpi_mean)
        logger.record_tabular('elapsed time',
                              prettify_time(time.time() -
                                            tstart))  # no mpi mean
        logger.record_tabular(
            'ev_td_lam_before',
            explained_variance(seg['vs'], seg['td_lam_rets']))
        iters_so_far += 1

        if rank == 0:
            logger.dump_tabular()

        if rank == 0:
            # Add summaries
            summary = tf.summary.Summary()
            tab = 'trpo'
            # Episode stats
            summary.value.add(tag="{}/{}".format(tab, 'mean_ep_len'),
                              simple_value=ep_len_mpi_mean)
            summary.value.add(tag="{}/{}".format(tab, 'mean_ep_env_ret'),
                              simple_value=ep_env_ret_mpi_mean)
            # Losses
            for name, loss in zipsame(list(losses.keys()),
                                      pol_losses_mpi_mean):
                summary.value.add(tag="{}/{}".format(tab, name),
                                  simple_value=loss)

            summary_writer.add_summary(summary, iters_so_far)
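A small standalone illustration (made-up numbers, NumPy only) of the step-size scaling used in the block above, where `shs` is (1/2) s^T A s and the step is rescaled so the quadratic KL estimate equals `max_kl`:

import numpy as np

max_kl = 0.01                                   # user-specified KL budget (delta)
step_direction = np.array([0.2, -0.1, 0.05])    # s, from conjugate gradient
fvp_s = np.array([0.4, -0.3, 0.1])              # A s, from the Fisher-vector product
shs = 0.5 * step_direction.dot(fvp_s)           # (1/2) s^T A s
lm = np.sqrt(shs / max_kl)                      # 1 / beta
full_step = step_direction / lm                 # beta * s
print(0.5 * full_step.dot(fvp_s / lm))          # ~0.01, i.e. the step saturates max_kl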
Example #14
def learn(comm,
          env,
          xpo_agent_wrapper,
          sample_or_mode,
          gamma,
          save_frequency,
          ckpt_dir,
          summary_dir,
          timesteps_per_batch,
          batch_size,
          optim_epochs_per_iter,
          lr,
          experiment_name,
          ent_reg_scale,
          clipping_eps,
          gae_lambda,
          schedule,
          max_iters):

    rank = comm.Get_rank()

    # Create policies
    pi = xpo_agent_wrapper('pi')
    old_pi = xpo_agent_wrapper('old_pi')

    # Create and retrieve already-existing placeholders
    ob = get_placeholder_cached(name='ob')
    ac = pi.pd_type.sample_placeholder([None])
    adv = tf.placeholder(name='adv', dtype=tf.float32, shape=[None])
    ret = tf.placeholder(name='ret', dtype=tf.float32, shape=[None])
    # Adaptive learning rate multiplier, updated with schedule
    lr_mult = tf.placeholder(name='lr_mult', dtype=tf.float32, shape=[])

    # Build graphs
    kl_mean = tf.reduce_mean(old_pi.pd_pred.kl(pi.pd_pred))
    ent_mean = tf.reduce_mean(pi.pd_pred.entropy())
    ent_pen = (-ent_reg_scale) * ent_mean
    vf_err = tf.reduce_mean(tf.square(pi.v_pred - ret))  # MC error
    # The surrogate objective is defined as: advantage * pnew / pold
    ratio = tf.exp(pi.pd_pred.logp(ac) - old_pi.pd_pred.logp(ac))  # IS
    surr_gain = ratio * adv  # surrogate objective (CPI)
    # Annealed clipping parameter epsilon
    clipping_eps = clipping_eps * lr_mult
    surr_gain_w_clipping = tf.clip_by_value(ratio,
                                            1.0 - clipping_eps,
                                            1.0 + clipping_eps) * adv
    # PPO's pessimistic surrogate (L^CLIP in paper)
    surr_loss = -tf.reduce_mean(tf.minimum(surr_gain, surr_gain_w_clipping))
    # Assemble losses (including the value function loss)
    loss = surr_loss + ent_pen + vf_err

    losses = OrderedDict()

    # Add losses
    losses.update({'pol_kl_mean': kl_mean,
                   'pol_ent_mean': ent_mean,
                   'pol_ent_pen': ent_pen,
                   'pol_surr_loss': surr_loss,
                   'pol_vf_err': vf_err,
                   'pol_total_loss': loss})

    # Make the current `pi` become the next `old_pi`
    zipped = zipsame(old_pi.vars, pi.vars)
    updates_op = []
    for k, v in zipped:
        # Populate list of assignment operations
        logger.info("assignment: {} <- {}".format(k, v))
        assign_op = tf.assign(k, v)
        updates_op.append(assign_op)
    assert len(updates_op) == len(pi.vars)

    # Create mpi adam optimizer
    optimizer = MpiAdamOptimizer(comm=comm,
                                 clip_norm=5.0,
                                 learning_rate=lr * lr_mult,
                                 name='adam')
    optimize = optimizer.minimize(loss=loss, var_list=pi.trainable_vars)

    # Create callable objects
    assign_old_eq_new = TheanoFunction(inputs=[], outputs=updates_op)
    compute_losses = TheanoFunction(inputs=[ob, ac, adv, ret, lr_mult],
                                    outputs=list(losses.values()))
    optimize = TheanoFunction(inputs=[ob, ac, adv, ret, lr_mult],
                              outputs=optimize)

    # Initialise variables
    initialize()

    # Sync params of all processes with the params of the root process
    optimizer.sync_from_root(pi.trainable_vars)

    # Create context manager that records the time taken by encapsulated ops
    timed = timed_cm_wrapper(comm, logger)

    if rank == 0:
        # Create summary writer
        summary_writer = tf.summary.FileWriterCache.get(summary_dir)

    seg_gen = traj_segment_generator(env, pi, timesteps_per_batch, sample_or_mode)

    eps_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()

    # Define rolling buffers for recent stats aggregation
    maxlen = 100
    len_buffer = deque(maxlen=maxlen)
    env_ret_buffer = deque(maxlen=maxlen)
    pol_losses_buffer = deque(maxlen=maxlen)

    while iters_so_far <= max_iters:

        pretty_iter(logger, iters_so_far)
        pretty_elapsed(logger, tstart)

        # Verify that the processes are still in sync
        if iters_so_far > 0 and iters_so_far % 10 == 0:
            optimizer.check_synced(pi.trainable_vars)
            logger.info("params still in sync across processes")

        # Manage lr multiplier schedule
        if schedule == 'constant':
            curr_lr_mult = 1.0
        elif schedule == 'linear':
            curr_lr_mult = max(1.0 - float(iters_so_far * timesteps_per_batch) /
                               (max_iters * timesteps_per_batch), 0)
        else:
            raise NotImplementedError

        # Save the model
        if rank == 0 and iters_so_far % save_frequency == 0 and ckpt_dir is not None:
            model_path = osp.join(ckpt_dir, experiment_name)
            save_state(model_path, iters_so_far=iters_so_far)
            logger.info("saving model")
            logger.info("  @: {}".format(model_path))

        with timed("sampling mini-batch"):
            seg = seg_gen.__next__()

        augment_segment_gae_stats(seg, gamma, gae_lambda, rew_key="env_rews")

        # Standardize advantage function estimate
        seg['advs'] = (seg['advs'] - seg['advs'].mean()) / (seg['advs'].std() + 1e-8)

        # Update running mean and std
        if hasattr(pi, 'obs_rms'):
            with timed("normalizing obs via rms"):
                pi.obs_rms.update(seg['obs'], comm)

        assign_old_eq_new({})

        # Create Feeder object to iterate over (ob, ac, adv, td_lam_ret) tuples
        data_map = {'obs': seg['obs'],
                    'acs': seg['acs'],
                    'advs': seg['advs'],
                    'td_lam_rets': seg['td_lam_rets']}
        feeder = Feeder(data_map=data_map, enable_shuffle=True)

        # Update policy and state-value function
        with timed("updating policy and value function"):
            for _ in range(optim_epochs_per_iter):
                for minibatch in feeder.get_feed(batch_size=batch_size):

                    feeds = {ob: minibatch['obs'],
                             ac: minibatch['acs'],
                             adv: minibatch['advs'],
                             ret: minibatch['td_lam_rets'],
                             lr_mult: curr_lr_mult}

                    # Compute losses
                    pol_losses = compute_losses(feeds)

                    # Update the policy and value function
                    optimize(feeds)

                    # Store the losses
                    pol_losses_buffer.append(pol_losses)

        # Log policy update statistics
        logger.info("logging training losses (log)")
        pol_losses_np_mean = np.mean(pol_losses_buffer, axis=0)
        pol_losses_mpi_mean = mpi_mean_reduce(pol_losses_buffer, comm, axis=0)
        zipped_pol_losses = zipsame(list(losses.keys()), pol_losses_np_mean, pol_losses_mpi_mean)
        logger.info(columnize(names=['name', 'local', 'global'],
                              tuples=zipped_pol_losses,
                              widths=[20, 16, 16]))

        # Log statistics

        logger.info("logging misc training stats (log + csv)")
        # Gather statistics across workers
        local_lens_rets = (seg['ep_lens'], seg['ep_env_rets'])
        gathered_lens_rets = comm.allgather(local_lens_rets)
        lens, env_rets = map(flatten_lists, zip(*gathered_lens_rets))
        # Extend the deques of recorded statistics
        len_buffer.extend(lens)
        env_ret_buffer.extend(env_rets)
        ep_len_mpi_mean = np.mean(len_buffer)
        ep_env_ret_mpi_mean = np.mean(env_ret_buffer)
        logger.record_tabular('ep_len_mpi_mean', ep_len_mpi_mean)
        logger.record_tabular('ep_env_ret_mpi_mean', ep_env_ret_mpi_mean)
        eps_this_iter = len(lens)
        timesteps_this_iter = sum(lens)
        eps_so_far += eps_this_iter
        timesteps_so_far += timesteps_this_iter
        eps_this_iter_mpi_mean = mpi_mean_like(eps_this_iter, comm)
        timesteps_this_iter_mpi_mean = mpi_mean_like(timesteps_this_iter, comm)
        eps_so_far_mpi_mean = mpi_mean_like(eps_so_far, comm)
        timesteps_so_far_mpi_mean = mpi_mean_like(timesteps_so_far, comm)
        logger.record_tabular('eps_this_iter_mpi_mean', eps_this_iter_mpi_mean)
        logger.record_tabular('timesteps_this_iter_mpi_mean', timesteps_this_iter_mpi_mean)
        logger.record_tabular('eps_so_far_mpi_mean', eps_so_far_mpi_mean)
        logger.record_tabular('timesteps_so_far_mpi_mean', timesteps_so_far_mpi_mean)
        logger.record_tabular('elapsed time', prettify_time(time.time() - tstart))  # no mpi mean
        logger.record_tabular('ev_td_lam_before', explained_variance(seg['vs'],
                                                                     seg['td_lam_rets']))
        iters_so_far += 1

        if rank == 0:
            logger.dump_tabular()

        if rank == 0:
            # Add summaries
            summary = tf.summary.Summary()
            tab = 'ppo'
            # Episode stats
            summary.value.add(tag="{}/{}".format(tab, 'mean_ep_len'),
                              simple_value=ep_len_mpi_mean)
            summary.value.add(tag="{}/{}".format(tab, 'mean_ep_env_ret'),
                              simple_value=ep_env_ret_mpi_mean)
            # Losses
            for name, loss in zipsame(list(losses.keys()), pol_losses_mpi_mean):
                summary.value.add(tag="{}/{}".format(tab, name), simple_value=loss)

            summary_writer.add_summary(summary, iters_so_far)
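To close Example #14, a minimal NumPy illustration of the clipped (pessimistic) surrogate assembled above, with hypothetical ratios and advantages:

import numpy as np

ratio = np.array([0.7, 1.0, 1.4])     # pi_new(a|s) / pi_old(a|s)
adv = np.array([1.0, -0.5, 2.0])      # advantage estimates
eps = 0.2                             # clipping_eps
unclipped = ratio * adv
clipped = np.clip(ratio, 1.0 - eps, 1.0 + eps) * adv
surr_loss = -np.mean(np.minimum(unclipped, clipped))  # L^CLIP, negated for minimization
print(surr_loss)                      # ~ -0.867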