def stack_conv2d_layers(self, embedding):
    """Stack the Conv2D layers"""
    conv2dz = zipsame(self.hps.nums_filters, self.hps.filter_shapes, self.hps.stride_shapes)
    for conv2d_layer_index, zipped_conv2d in enumerate(conv2dz, start=1):
        conv2d_layer_id = "conv2d{}".format(conv2d_layer_index)
        num_filters, filter_shape, stride_shape = zipped_conv2d  # unpack
        # Add conv2d hidden layer and non-linearity
        embedding = tf.layers.conv2d(inputs=embedding,
                                     filters=num_filters,
                                     kernel_size=filter_shape,
                                     strides=stride_shape,
                                     padding='valid',
                                     data_format='channels_last',
                                     dilation_rate=(1, 1),
                                     activation=parse_nonlin(self.hps.hid_nonlin),
                                     use_bias=True,
                                     kernel_initializer=self.hid_initializers['w'],
                                     bias_initializer=self.hid_initializers['b'],
                                     kernel_regularizer=self.hid_regularizers['w'],
                                     bias_regularizer=None,
                                     activity_regularizer=None,
                                     kernel_constraint=None,
                                     bias_constraint=None,
                                     trainable=True,
                                     name=conv2d_layer_id,
                                     reuse=tf.AUTO_REUSE)
    # Flatten between the conv2d layers and the fully-connected layers
    embedding = tf.layers.flatten(inputs=embedding, name='flatten')
    return embedding
def zipsame_prune(l1, l2):
    out = []
    for a, b in zipsame(l1, l2):
        if b is None:
            continue
        out.append((a, b))
    return out
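# `zipsame` is used throughout but not shown here. A minimal sketch, assuming it
# mirrors the baselines helper of the same name: a `zip` that asserts all input
# sequences have the same length instead of silently truncating to the shortest.
def zipsame(*seqs):
    length = len(seqs[0])
    assert all(len(seq) == length for seq in seqs[1:]), "sequences differ in length"
    return zip(*seqs)

# Example: `zipsame_prune` then simply drops the pairs whose second element is None,
# e.g. zipsame_prune(['a', 'b', 'c'], [1, None, 3]) -> [('a', 1), ('c', 3)]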
def update_priorities(self, idxs, priorities):
    # Update priorities via the legacy method, used w/ the ranked approach
    idxs, priorities = super().update_priorities(idxs, priorities)
    # Register whether a transition is bad/good in the UNREAL-specific sum trees
    for idx, priority in zipsame(idxs, priorities):
        # Decide whether the transition to be added is good or bad
        if idx < self.num_demos:
            # When the transition is from the demos, always set it as 'good' regardless
            self.b_sum_st[idx] = 0
            self.g_sum_st[idx] = 1
        else:
            # Get the rank back from the priority
            # Note: UnrealRB inherits from PER w/ 'ranked' set to True
            rank = (1. / priority) - 1
            thres = floor(.5 * self.num_entries)
            is_g = int(rank < thres)  # cast the bool into an int
            # Fill the good and bad sum segment trees w/ the obtained value
            self.b_sum_st[idx] = 1 - is_g
            self.g_sum_st[idx] = is_g
    if debug:
        # Verify the updates by computing the cardinalities of the virtual sub-buffers
        b_num_entries = self.b_sum_st.sum(end=self.num_entries)
        g_num_entries = self.g_sum_st.sum(end=self.num_entries)
        print("[num entries] b: {} | g: {}".format(b_num_entries, g_num_entries))
        print("total num entries: {}".format(self.num_entries))
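# Worked example of the good/bad split above (illustrative values only): with
# ranked priorities of the form p = 1 / (rank + 1), the rank is recovered as
# rank = 1 / p - 1, and a transition is 'good' iff it ranks in the top half.
from math import floor
num_entries = 10
thres = floor(.5 * num_entries)  # -> 5
for priority in [1., .2, 1. / 6.]:
    rank = (1. / priority) - 1   # -> 0, 4, 5
    print(int(rank < thres))     # -> 1 (good), 1 (good), 0 (bad)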
def run(args):
    """Spawn jobs"""
    # Create meta-experiment identifier
    meta = rand_id()
    # Define experiment type
    if args.rand:
        type_exp = 'hpsearch'
    else:
        type_exp = 'sweep'
    # Get hyperparameter configurations
    if args.rand:
        # Get a number of random hyperparameter configurations
        hpmaps = [get_rand_hps(args, meta, args.num_seeds)
                  for _ in range(args.num_rand_trials)]
        # Flatten into a 1-dim list
        hpmaps = flatten_lists(hpmaps)
    else:
        # Get the deterministic spectrum of specified hyperparameters
        hpmaps = get_spectrum_hps(args, meta, args.num_seeds)
    # Create the associated task strings
    exp_strs = [format_exp_str(args, hpmap) for hpmap in hpmaps]
    if not len(exp_strs) == len(set(exp_strs)):
        # Terminate in case of duplicate experiments (extremely unlikely though)
        raise ValueError("bad luck, there are dupes -> try again :)")
    # Create the job maps
    job_maps = [get_job_map(args, meta, i, hpmap['env_id'], hpmap['seed'],
                            '0' if args.task == 'ppo' else hpmap['num_demos'],
                            type_exp)
                for i, hpmap in enumerate(hpmaps)]
    # Finally get all the required job strings
    job_strs = [format_job_str(args, jm, es) for jm, es in zipsame(job_maps, exp_strs)]
    # Spawn the jobs
    for i, (jm, js) in enumerate(zipsame(job_maps, job_strs)):
        print('-' * 10 + "> job #{} launcher content:".format(i))
        print(js + "\n")
        job_name = "{}.sh".format(jm['job-name'])
        with open(job_name, 'w') as f:
            f.write(js)
        if args.call:
            # Spawn the job!
            call(["sbatch", "./{}".format(job_name)])
    # Summarize the number of jobs spawned
    print("total num jobs (successfully) spawned: {}".format(len(job_strs)))
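# `flatten_lists` is assumed but not shown. A minimal sketch, assuming it
# flattens exactly one level of nesting:
def flatten_lists(listoflists):
    return [el for list_ in listoflists for el in list_]

# e.g. flatten_lists([[1, 2], [3]]) -> [1, 2, 3]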
def __init__(self, limit, ob_shape, ac_shape):
    self.limit = limit
    self.ob_shape = ob_shape
    self.ac_shape = ac_shape
    self.num_demos = 0
    self.atom_names = ['obs0', 'acs', 'rews', 'dones1', 'obs1']
    self.atom_shapes = [self.ob_shape, self.ac_shape, (1,), (1,), self.ob_shape]
    # Create one `RingBuffer` object for every atom in a transition
    self.ring_buffers = {atom_name: RingBuffer(self.limit, atom_shape)
                         for atom_name, atom_shape in zipsame(self.atom_names,
                                                              self.atom_shapes)}
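# `RingBuffer` is assumed but not shown. A minimal sketch, assuming a fixed-size
# circular numpy buffer that overwrites the oldest entry once `limit` is reached
# (names and behavior here are assumptions, not the repo's actual implementation):
import numpy as np

class RingBuffer(object):

    def __init__(self, limit, shape):
        self.limit = limit
        self.data = np.zeros((limit,) + shape, dtype=np.float32)
        self.start = 0   # index of the oldest entry
        self.length = 0  # current number of entries

    def append(self, v):
        if self.length < self.limit:
            self.length += 1
        else:
            # Buffer full -> advance the start to overwrite the oldest entry
            self.start = (self.start + 1) % self.limit
        self.data[(self.start + self.length - 1) % self.limit] = v

    def __getitem__(self, idx):
        assert 0 <= idx < self.length, "index out of bounds"
        return self.data[(self.start + idx) % self.limit]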
def update_priorities(self, idxs, priorities):
    """Update priorities according to the PER paper, i.e. by updating only the
    priorities of the sampled transitions. The priority priorities[i] is assigned
    to the transition at index idxs[i].
    Note: not in use in the vanilla setting, but here if needed in extensions.
    """
    global debug
    if self.ranked:
        # Override the priorities to be 1 / (rank(priority) + 1)
        # Add the new index, priority pairs to the dict
        self.i_p.update({i: p for i, p in zipsame(idxs, priorities)})
        # Sort the indices by priority, in decreasing order
        i_sorted_by_p = sorted(self.i_p.items(), key=lambda t: t[1], reverse=True)
        # Create the index, rank dict
        i_r = {i: i_sorted_by_p.index((i, p)) for i, p in self.i_p.items()}
        # Unpack indices and ranks
        _idxs, ranks = zipsame(*i_r.items())
        # Override the indices and priorities
        idxs = list(_idxs)
        priorities = [1. / (rank + 1) for rank in ranks]  # start ranks at 1
        if debug:
            # Verify that the priorities have been properly overridden
            for idx, priority in zipsame(idxs, priorities):
                print("index: {} | priority: {}".format(idx, priority))
    assert len(idxs) == len(priorities), "the two arrays must be the same length"
    for idx, priority in zipsame(idxs, priorities):
        assert priority > 0, "priorities must be positive"
        assert 0 <= idx < self.num_entries, "no element in buffer associated w/ index"
        if idx < self.num_demos:
            # Add a priority bonus when replaying a demo
            priority += self.demos_eps
        self.sum_st[idx] = priority ** self.alpha
        self.min_st[idx] = priority ** self.alpha
        # Update the max priority currently in the buffer
        self.max_priority = max(priority, self.max_priority)
    if self.ranked:
        # Return the indices and the associated overridden priorities
        # Note: the returned values are only used in the UNREAL priority update function
        return idxs, priorities
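# Worked example of the 'ranked' override above (illustrative values only): the
# raw priorities are replaced by 1 / (rank + 1), so only their ordering matters,
# as in the rank-based variant of the PER paper.
i_p = {7: 0.9, 2: 3.2, 5: 1.1}
i_sorted_by_p = sorted(i_p.items(), key=lambda t: t[1], reverse=True)
ranks = {i: r for r, (i, _) in enumerate(i_sorted_by_p)}
overridden = {i: 1. / (rank + 1) for i, rank in ranks.items()}
print(overridden)  # -> {2: 1.0, 5: 0.5, 7: 0.333...}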
def log_module_info(logger, name, *components):
    assert len(components) > 0, "components list is empty"
    for component in components:
        logger.info("logging {}/{} specs".format(name, component.name))
        names = [var.name for var in component.trainable_vars]
        shapes = [var_shape(var) for var in component.trainable_vars]
        num_paramss = [numel(var) for var in component.trainable_vars]
        zipped_info = zipsame(names, shapes, num_paramss)
        logger.info(columnize(names=['name', 'shape', 'num_params'],
                              tuples=zipped_info,
                              widths=[40, 16, 10]))
        logger.info("  total num params: {}".format(sum(num_paramss)))
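# `var_shape` and `numel` are assumed but not shown. Minimal sketches, assuming
# the usual baselines-style TF1 helpers (`intprod` is also used further down):
import numpy as np

def var_shape(var):
    out = var.get_shape().as_list()
    assert all(isinstance(a, int) for a in out), "shape must be fully determined"
    return out

def intprod(shape):
    return int(np.prod(shape))

def numel(var):
    return intprod(var_shape(var))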
def columnize(names, tuples, widths, indent=2):
    """Generate and return the content of a table (w/o logging or printing anything)

    Args:
        names (list): Column names, one per column
        tuples (list): Rows of the table, one tuple of values per row
        widths (list): Width of each cell in the table, one per column
        indent (int): Indentation spacing prepended to every row in the table
    """
    indent_space = indent * ' '
    # Add the row containing the column names
    table = indent_space + " | ".join(cell(name, width)
                                      for name, width in zipsame(names, widths))
    table_width = len(table)
    # Add the header hline
    table += '\n' + indent_space + ('-' * table_width)
    for tuple_ in tuples:
        # Add a new row
        table += '\n' + indent_space
        table += " | ".join(cell(value, width)
                            for value, width in zipsame(tuple_, widths))
    # Add the closing hline
    table += '\n' + indent_space + ('-' * table_width)
    return table
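# Example usage of `columnize` (the exact output is illustrative and assumes the
# helper `cell` pads/truncates a value to the given cell width):
# print(columnize(names=['name', 'shape', 'num_params'],
#                 tuples=[('pi/w', [64, 64], 4096), ('pi/b', [64], 64)],
#                 widths=[10, 12, 10]))
#   name       | shape        | num_params
#   ---------------------------------------
#   pi/w       | [64, 64]     | 4096
#   pi/b       | [64]         | 64
#   ---------------------------------------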
def add_demo_transitions_to_mem(self, dset):
    """Add transitions from expert demonstration trajectories to memory"""
    # Ensure the replay buffer is empty as demos need to be first
    assert self.num_entries == 0 and self.num_demos == 0
    logger.info("adding demonstrations to memory")
    # Zip the transition atoms
    # Note: careful w/ the order, it should correspond to the order in the `append` signature
    transitions = zipsame(dset.obs0, dset.acs, dset.env_rews, dset.obs1, dset.dones1)
    for transition in transitions:
        self.append(*transition, is_demo=True)
        self.num_demos += 1
    assert self.num_demos == self.num_entries
    logger.info("  num entries in memory after addition: {}".format(self.num_entries))
def flatgrad(loss, var_list, clip_norm=None):
    """Return the gradients of `loss` w.r.t. the variables in `var_list`,
    flattened and concatenated into a single rank-1 tensor.
    Clipping is done by global norm (paper: https://arxiv.org/abs/1211.5063)
    """
    grads = tf.gradients(loss, var_list)
    if clip_norm is not None:
        grads, _ = tf.clip_by_global_norm(grads, clip_norm=clip_norm)
    vars_and_grads = zipsame(var_list, grads)  # zip with extra security
    for index, (var, grad) in enumerate(vars_and_grads):
        # If the gradient gets stopped for some obscure reason, use a zero vector instead
        _grad = grad if grad is not None else tf.zeros_like(var)
        # Reshape the grad into a vector
        grads[index] = tf.reshape(_grad, [numel(var)])
    return tf.concat(grads, axis=0)
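# Usage sketch for `flatgrad` (TF1-style graph mode; the names are illustrative):
# w = tf.get_variable('w', shape=[3, 2])
# b = tf.get_variable('b', shape=[2])
# loss = tf.reduce_sum(tf.square(tf.matmul(x, w) + b))
# flat_g = flatgrad(loss, [w, b], clip_norm=5.0)
# `flat_g` is a rank-1 tensor of size numel(w) + numel(b) = 6 + 2 = 8, i.e. all
# per-variable gradients reshaped into vectors and concatenated end to end.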
def get_target_updates(vars_, targ_vars, polyak):
    """Return assignment ops for target network updates.
    Hard updates are used for initialization only, while soft updates are used
    throughout the training process, at every iteration.
    Note that DQN uses hard updates while training, but those updates are not
    performed every iteration (only once every XX iterations).
    """
    logger.info("setting up target updates")
    hard_updates = []
    soft_updates = []
    assert len(vars_) == len(targ_vars)
    for var_, targ_var in zipsame(vars_, targ_vars):
        logger.info('  {} <- {}'.format(targ_var.name, var_.name))
        hard_updates.append(tf.assign(targ_var, var_))
        soft_updates.append(tf.assign(targ_var, (1. - polyak) * targ_var + polyak * var_))
    assert len(hard_updates) == len(vars_)
    assert len(soft_updates) == len(vars_)
    # Return ops that group the assignment ops
    return tf.group(*hard_updates), tf.group(*soft_updates)
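# Numeric sketch of the soft ("polyak") update rule above, in plain numpy with
# illustrative values: targ <- (1 - polyak) * targ + polyak * var
import numpy as np
polyak = 0.005
var, targ = np.array([1.]), np.array([0.])
for _ in range(3):
    targ = (1. - polyak) * targ + polyak * var
print(targ)  # -> [0.0149...]: the target slowly tracks `var`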
def get_p_actor_updates(actor, perturbed_actor, pn_std):
    """Return assignment ops for actor parameter noise perturbations.
    The perturbations consist of applying additive Gaussian noise to the
    perturbable actor variables, while simply leaving the non-perturbable
    ones untouched.
    """
    assert len(actor.vars) == len(perturbed_actor.vars)
    assert len(actor.perturbable_vars) == len(perturbed_actor.perturbable_vars)
    updates = []
    for var_, perturbed_var in zipsame(actor.vars, perturbed_actor.vars):
        if var_ in actor.perturbable_vars:
            logger.info("  {} <- {} + noise".format(perturbed_var.name, var_.name))
            noised_up_var = var_ + tf.random_normal(tf.shape(var_), mean=0., stddev=pn_std)
            updates.append(tf.assign(perturbed_var, noised_up_var))
        else:
            logger.info("  {} <- {}".format(perturbed_var.name, var_.name))
            updates.append(tf.assign(perturbed_var, var_))
    assert len(updates) == len(actor.vars)
    return tf.group(*updates)
def learn(comm,
          env,
          xpo_agent_wrapper,
          sample_or_mode,
          gamma,
          max_kl,
          save_frequency,
          ckpt_dir,
          summary_dir,
          timesteps_per_batch,
          batch_size,
          experiment_name,
          ent_reg_scale,
          gae_lambda,
          cg_iters,
          cg_damping,
          vf_iters,
          vf_lr,
          max_iters):

    rank = comm.Get_rank()

    # Create policies
    pi = xpo_agent_wrapper('pi')
    old_pi = xpo_agent_wrapper('old_pi')

    # Create and retrieve already-existing placeholders
    ob = get_placeholder_cached(name='ob')
    ac = pi.pd_type.sample_placeholder([None])
    adv = tf.placeholder(name='adv', dtype=tf.float32, shape=[None])
    ret = tf.placeholder(name='ret', dtype=tf.float32, shape=[None])
    flat_tangent = tf.placeholder(name='flat_tan', dtype=tf.float32, shape=[None])

    # Build graphs
    kl_mean = tf.reduce_mean(old_pi.pd_pred.kl(pi.pd_pred))
    ent_mean = tf.reduce_mean(pi.pd_pred.entropy())
    ent_bonus = ent_reg_scale * ent_mean
    vf_err = tf.reduce_mean(tf.square(pi.v_pred - ret))  # MC error
    # The surrogate objective is defined as: advantage * pnew / pold
    ratio = tf.exp(pi.pd_pred.logp(ac) - old_pi.pd_pred.logp(ac))  # IS
    surr_gain = tf.reduce_mean(ratio * adv)  # surrogate objective (CPI)
    # Add entropy bonus
    optim_gain = surr_gain + ent_bonus

    losses = OrderedDict()
    # Add losses
    losses.update({'pol_kl_mean': kl_mean,
                   'pol_ent_mean': ent_mean,
                   'pol_ent_bonus': ent_bonus,
                   'pol_surr_gain': surr_gain,
                   'pol_optim_gain': optim_gain,
                   'pol_vf_err': vf_err})

    # Build natural gradient material
    get_flat = GetFlat(pi.pol_trainable_vars)
    set_from_flat = SetFromFlat(pi.pol_trainable_vars)
    kl_grads = tf.gradients(kl_mean, pi.pol_trainable_vars)
    shapes = [var.get_shape().as_list() for var in pi.pol_trainable_vars]
    start = 0
    tangents = []
    for shape in shapes:
        sz = intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    # Create the gradient vector product
    gvp = tf.add_n([tf.reduce_sum(g * tangent)
                    for (g, tangent) in zipsame(kl_grads, tangents)])
    # Create the Fisher vector product
    fvp = flatgrad(gvp, pi.pol_trainable_vars)

    # Make the current `pi` become the next `old_pi`
    zipped = zipsame(old_pi.vars, pi.vars)
    updates_op = []
    for k, v in zipped:
        # Populate the list of assignment operations
        logger.info("assignment: {} <- {}".format(k, v))
        assign_op = tf.assign(k, v)
        updates_op.append(assign_op)
    assert len(updates_op) == len(pi.vars)

    # Create the mpi adam optimizer for the value function
    vf_optimizer = MpiAdamOptimizer(comm=comm,
                                    clip_norm=5.0,
                                    learning_rate=vf_lr,
                                    name='vf_adam')
    optimize_vf = vf_optimizer.minimize(loss=vf_err, var_list=pi.vf_trainable_vars)

    # Create gradients
    grads = flatgrad(optim_gain, pi.pol_trainable_vars)

    # Create callable objects
    assign_old_eq_new = TheanoFunction(inputs=[], outputs=updates_op)
    compute_losses = TheanoFunction(inputs=[ob, ac, adv, ret],
                                    outputs=list(losses.values()))
    compute_losses_grads = TheanoFunction(inputs=[ob, ac, adv, ret],
                                          outputs=list(losses.values()) + [grads])
    compute_fvp = TheanoFunction(inputs=[flat_tangent, ob, ac, adv], outputs=fvp)
    optimize_vf = TheanoFunction(inputs=[ob, ret], outputs=optimize_vf)

    # Initialise variables
    initialize()

    # Sync the params of all processes with the params of the root process
    theta_init = get_flat()
    comm.Bcast(theta_init, root=0)
    set_from_flat(theta_init)
    vf_optimizer.sync_from_root(pi.vf_trainable_vars)

    # Create a context manager that records the time taken by encapsulated ops
    timed = timed_cm_wrapper(comm, logger)

    if rank == 0:
        # Create summary writer
        summary_writer = tf.summary.FileWriterCache.get(summary_dir)

    # Create the segment generator
    seg_gen = traj_segment_generator(env, pi, timesteps_per_batch, sample_or_mode)

    eps_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()

    # Define rolling buffers for recent stats aggregation
    maxlen = 100
    len_buffer = deque(maxlen=maxlen)
    env_ret_buffer = deque(maxlen=maxlen)
    pol_losses_buffer = deque(maxlen=maxlen)

    while iters_so_far <= max_iters:

        pretty_iter(logger, iters_so_far)
        pretty_elapsed(logger, tstart)

        # Verify that the processes are still in sync
        if iters_so_far > 0 and iters_so_far % 10 == 0:
            vf_optimizer.check_synced(pi.vf_trainable_vars)
            logger.info("vf params still in sync across processes")

        # Save the model
        if rank == 0 and iters_so_far % save_frequency == 0 and ckpt_dir is not None:
            model_path = osp.join(ckpt_dir, experiment_name)
            save_state(model_path, iters_so_far=iters_so_far)
            logger.info("saving model")
            logger.info("  @: {}".format(model_path))

        with timed("sampling mini-batch"):
            seg = seg_gen.__next__()

        augment_segment_gae_stats(seg, gamma, gae_lambda, rew_key="env_rews")

        # Standardize the advantage function estimate
        seg['advs'] = (seg['advs'] - seg['advs'].mean()) / (seg['advs'].std() + 1e-8)

        # Update the running mean and std
        if hasattr(pi, 'obs_rms'):
            with timed("normalizing obs via rms"):
                pi.obs_rms.update(seg['obs'], comm)

        def fisher_vector_product(p):
            computed_fvp = compute_fvp({flat_tangent: p,
                                        ob: seg['obs'],
                                        ac: seg['acs'],
                                        adv: seg['advs']})
            return mpi_mean_like(computed_fvp, comm) + cg_damping * p

        assign_old_eq_new({})

        # Compute gradients
        with timed("computing gradients"):
            *loss_before, g = compute_losses_grads({ob: seg['obs'],
                                                    ac: seg['acs'],
                                                    adv: seg['advs'],
                                                    ret: seg['td_lam_rets']})
        loss_before = mpi_mean_like(loss_before, comm)
        g = mpi_mean_like(g, comm)

        if np.allclose(g, 0):
            logger.info("got zero gradient -> not updating")
        else:
            with timed("performing conjugate gradient procedure"):
                step_direction = conjugate_gradient(f_Ax=fisher_vector_product,
                                                    b=g,
                                                    cg_iters=cg_iters,
                                                    verbose=(rank == 0))
            assert np.isfinite(step_direction).all()
            # shs is (1/2)*s^T*A*s in the paper
            shs = 0.5 * step_direction.dot(fisher_vector_product(step_direction))
            # lm is 1/beta in the paper (max_kl is the user-specified delta)
            lm = np.sqrt(shs / max_kl)
            full_step = step_direction / lm  # beta*s
            expected_improve = g.dot(full_step)  # project s on g
            surr_before = loss_before[4]  # 5-th in the loss list
            step_size = 1.0
            theta_before = get_flat()

            with timed("updating policy"):
                for _ in range(10):  # trying (10 times max) until the step size is OK
                    # Update the policy parameters
                    theta_new = theta_before + full_step * step_size
                    set_from_flat(theta_new)
                    pol_losses = compute_losses({ob: seg['obs'],
                                                 ac: seg['acs'],
                                                 adv: seg['advs'],
                                                 ret: seg['td_lam_rets']})
                    pol_losses_buffer.append(pol_losses)
                    pol_losses_mpi_mean = mpi_mean_like(pol_losses, comm)
                    surr = pol_losses_mpi_mean[4]
                    kl = pol_losses_mpi_mean[0]
                    actual_improve = surr - surr_before
                    logger.info("  expected: {:.3f} | actual: {:.3f}".format(expected_improve,
                                                                             actual_improve))
                    if not np.isfinite(pol_losses_mpi_mean).all():
                        logger.info("  got non-finite value of losses :(")
                    elif kl > max_kl * 1.5:
                        logger.info("  violated KL constraint -> shrinking step.")
                    elif actual_improve < 0:
                        logger.info("  surrogate didn't improve -> shrinking step.")
                    else:
                        logger.info("  stepsize fine :)")
                        break
                    # Backtrack when the step size is deemed inappropriate
                    step_size *= 0.5
                else:
                    logger.info("  couldn't compute a good step")
                    set_from_flat(theta_before)

        # Create a Feeder object to iterate over (ob, ret) pairs
        feeder = Feeder(data_map={'obs': seg['obs'],
                                  'td_lam_rets': seg['td_lam_rets']},
                        enable_shuffle=True)

        # Update the state-value function
        with timed("updating value function"):
            for _ in range(vf_iters):
                for minibatch in feeder.get_feed(batch_size=batch_size):
                    optimize_vf({ob: minibatch['obs'],
                                 ret: minibatch['td_lam_rets']})

        # Log policy update statistics
        logger.info("logging pol training losses (log)")
        pol_losses_np_mean = np.mean(pol_losses_buffer, axis=0)
        pol_losses_mpi_mean = mpi_mean_reduce(pol_losses_buffer, comm, axis=0)
        zipped_pol_losses = zipsame(list(losses.keys()),
                                    pol_losses_np_mean,
                                    pol_losses_mpi_mean)
        logger.info(columnize(names=['name', 'local', 'global'],
                              tuples=zipped_pol_losses,
                              widths=[20, 16, 16]))

        # Log statistics
        logger.info("logging misc training stats (log + csv)")
        # Gather statistics across workers
        local_lens_rets = (seg['ep_lens'], seg['ep_env_rets'])
        gathered_lens_rets = comm.allgather(local_lens_rets)
        lens, env_rets = map(flatten_lists, zip(*gathered_lens_rets))
        # Extend the deques of recorded statistics
        len_buffer.extend(lens)
        env_ret_buffer.extend(env_rets)
        ep_len_mpi_mean = np.mean(len_buffer)
        ep_env_ret_mpi_mean = np.mean(env_ret_buffer)
        logger.record_tabular('ep_len_mpi_mean', ep_len_mpi_mean)
        logger.record_tabular('ep_env_ret_mpi_mean', ep_env_ret_mpi_mean)
        eps_this_iter = len(lens)
        timesteps_this_iter = sum(lens)
        eps_so_far += eps_this_iter
        timesteps_so_far += timesteps_this_iter
        eps_this_iter_mpi_mean = mpi_mean_like(eps_this_iter, comm)
        timesteps_this_iter_mpi_mean = mpi_mean_like(timesteps_this_iter, comm)
        eps_so_far_mpi_mean = mpi_mean_like(eps_so_far, comm)
        timesteps_so_far_mpi_mean = mpi_mean_like(timesteps_so_far, comm)
        logger.record_tabular('eps_this_iter_mpi_mean', eps_this_iter_mpi_mean)
        logger.record_tabular('timesteps_this_iter_mpi_mean', timesteps_this_iter_mpi_mean)
        logger.record_tabular('eps_so_far_mpi_mean', eps_so_far_mpi_mean)
        logger.record_tabular('timesteps_so_far_mpi_mean', timesteps_so_far_mpi_mean)
        logger.record_tabular('elapsed time', prettify_time(time.time() - tstart))  # no mpi mean
        logger.record_tabular('ev_td_lam_before', explained_variance(seg['vs'],
                                                                     seg['td_lam_rets']))
        iters_so_far += 1

        if rank == 0:
            logger.dump_tabular()

        if rank == 0:
            # Add summaries
            summary = tf.summary.Summary()
            tab = 'trpo'
            # Episode stats
            summary.value.add(tag="{}/{}".format(tab, 'mean_ep_len'),
                              simple_value=ep_len_mpi_mean)
            summary.value.add(tag="{}/{}".format(tab, 'mean_ep_env_ret'),
                              simple_value=ep_env_ret_mpi_mean)
            # Losses
            for name, loss in zipsame(list(losses.keys()), pol_losses_mpi_mean):
                summary.value.add(tag="{}/{}".format(tab, name), simple_value=loss)
            summary_writer.add_summary(summary, iters_so_far)
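# `conjugate_gradient` is assumed but not shown. A minimal sketch, assuming the
# standard matrix-free CG used in TRPO: it solves A x = b given only the
# matrix-vector product `f_Ax` (here, the damped Fisher-vector product).
import numpy as np

def conjugate_gradient(f_Ax, b, cg_iters=10, residual_tol=1e-10, verbose=False):
    x = np.zeros_like(b)
    r = b.copy()  # residual b - A x (x starts at 0)
    p = r.copy()  # search direction
    rdotr = r.dot(r)
    for i in range(cg_iters):
        z = f_Ax(p)
        alpha = rdotr / p.dot(z)
        x += alpha * p
        r -= alpha * z
        new_rdotr = r.dot(r)
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
        if verbose:
            print("cg iter {} | residual: {:.3e}".format(i, rdotr))
        if rdotr < residual_tol:
            break
    return x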
def learn(comm,
          env,
          xpo_agent_wrapper,
          sample_or_mode,
          gamma,
          save_frequency,
          ckpt_dir,
          summary_dir,
          timesteps_per_batch,
          batch_size,
          optim_epochs_per_iter,
          lr,
          experiment_name,
          ent_reg_scale,
          clipping_eps,
          gae_lambda,
          schedule,
          max_iters):

    rank = comm.Get_rank()

    # Create policies
    pi = xpo_agent_wrapper('pi')
    old_pi = xpo_agent_wrapper('old_pi')

    # Create and retrieve already-existing placeholders
    ob = get_placeholder_cached(name='ob')
    ac = pi.pd_type.sample_placeholder([None])
    adv = tf.placeholder(name='adv', dtype=tf.float32, shape=[None])
    ret = tf.placeholder(name='ret', dtype=tf.float32, shape=[None])
    # Adaptive learning rate multiplier, updated with the schedule
    lr_mult = tf.placeholder(name='lr_mult', dtype=tf.float32, shape=[])

    # Build graphs
    kl_mean = tf.reduce_mean(old_pi.pd_pred.kl(pi.pd_pred))
    ent_mean = tf.reduce_mean(pi.pd_pred.entropy())
    ent_pen = (-ent_reg_scale) * ent_mean
    vf_err = tf.reduce_mean(tf.square(pi.v_pred - ret))  # MC error
    # The surrogate objective is defined as: advantage * pnew / pold
    ratio = tf.exp(pi.pd_pred.logp(ac) - old_pi.pd_pred.logp(ac))  # IS
    surr_gain = ratio * adv  # surrogate objective (CPI)
    # Annealed clipping parameter epsilon
    clipping_eps = clipping_eps * lr_mult
    surr_gain_w_clipping = tf.clip_by_value(ratio,
                                            1.0 - clipping_eps,
                                            1.0 + clipping_eps) * adv
    # PPO's pessimistic surrogate (L^CLIP in the paper)
    surr_loss = -tf.reduce_mean(tf.minimum(surr_gain, surr_gain_w_clipping))
    # Assemble the total loss (including the value function loss)
    loss = surr_loss + ent_pen + vf_err

    losses = OrderedDict()
    # Add losses
    losses.update({'pol_kl_mean': kl_mean,
                   'pol_ent_mean': ent_mean,
                   'pol_ent_pen': ent_pen,
                   'pol_surr_loss': surr_loss,
                   'pol_vf_err': vf_err,
                   'pol_total_loss': loss})

    # Make the current `pi` become the next `old_pi`
    zipped = zipsame(old_pi.vars, pi.vars)
    updates_op = []
    for k, v in zipped:
        # Populate the list of assignment operations
        logger.info("assignment: {} <- {}".format(k, v))
        assign_op = tf.assign(k, v)
        updates_op.append(assign_op)
    assert len(updates_op) == len(pi.vars)

    # Create the mpi adam optimizer
    optimizer = MpiAdamOptimizer(comm=comm,
                                 clip_norm=5.0,
                                 learning_rate=lr * lr_mult,
                                 name='adam')
    optimize = optimizer.minimize(loss=loss, var_list=pi.trainable_vars)

    # Create callable objects
    assign_old_eq_new = TheanoFunction(inputs=[], outputs=updates_op)
    compute_losses = TheanoFunction(inputs=[ob, ac, adv, ret, lr_mult],
                                    outputs=list(losses.values()))
    optimize = TheanoFunction(inputs=[ob, ac, adv, ret, lr_mult], outputs=optimize)

    # Initialise variables
    initialize()

    # Sync the params of all processes with the params of the root process
    optimizer.sync_from_root(pi.trainable_vars)

    # Create a context manager that records the time taken by encapsulated ops
    timed = timed_cm_wrapper(comm, logger)

    if rank == 0:
        # Create summary writer
        summary_writer = tf.summary.FileWriterCache.get(summary_dir)

    # Create the segment generator
    seg_gen = traj_segment_generator(env, pi, timesteps_per_batch, sample_or_mode)

    eps_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()

    # Define rolling buffers for recent stats aggregation
    maxlen = 100
    len_buffer = deque(maxlen=maxlen)
    env_ret_buffer = deque(maxlen=maxlen)
    pol_losses_buffer = deque(maxlen=maxlen)

    while iters_so_far <= max_iters:

        pretty_iter(logger, iters_so_far)
        pretty_elapsed(logger, tstart)

        # Verify that the processes are still in sync
        if iters_so_far > 0 and iters_so_far % 10 == 0:
            optimizer.check_synced(pi.trainable_vars)
            logger.info("params still in sync across processes")

        # Manage the lr multiplier schedule
        if schedule == 'constant':
            curr_lr_mult = 1.0
        elif schedule == 'linear':
            # Anneal linearly w/ the fraction of total timesteps elapsed
            curr_lr_mult = max(1.0 - float(iters_so_far * timesteps_per_batch) /
                               (max_iters * timesteps_per_batch), 0)
        else:
            raise NotImplementedError

        # Save the model
        if rank == 0 and iters_so_far % save_frequency == 0 and ckpt_dir is not None:
            model_path = osp.join(ckpt_dir, experiment_name)
            save_state(model_path, iters_so_far=iters_so_far)
            logger.info("saving model")
            logger.info("  @: {}".format(model_path))

        with timed("sampling mini-batch"):
            seg = seg_gen.__next__()

        augment_segment_gae_stats(seg, gamma, gae_lambda, rew_key="env_rews")

        # Standardize the advantage function estimate
        seg['advs'] = (seg['advs'] - seg['advs'].mean()) / (seg['advs'].std() + 1e-8)

        # Update the running mean and std
        if hasattr(pi, 'obs_rms'):
            with timed("normalizing obs via rms"):
                pi.obs_rms.update(seg['obs'], comm)

        assign_old_eq_new({})

        # Create a Feeder object to iterate over (ob, ac, adv, td_lam_ret) tuples
        data_map = {'obs': seg['obs'],
                    'acs': seg['acs'],
                    'advs': seg['advs'],
                    'td_lam_rets': seg['td_lam_rets']}
        feeder = Feeder(data_map=data_map, enable_shuffle=True)

        # Update the policy and the state-value function
        with timed("updating policy and value function"):
            for _ in range(optim_epochs_per_iter):
                for minibatch in feeder.get_feed(batch_size=batch_size):
                    feeds = {ob: minibatch['obs'],
                             ac: minibatch['acs'],
                             adv: minibatch['advs'],
                             ret: minibatch['td_lam_rets'],
                             lr_mult: curr_lr_mult}
                    # Compute losses
                    pol_losses = compute_losses(feeds)
                    # Update the policy and value function
                    optimize(feeds)
                    # Store the losses
                    pol_losses_buffer.append(pol_losses)

        # Log policy update statistics
        logger.info("logging training losses (log)")
        pol_losses_np_mean = np.mean(pol_losses_buffer, axis=0)
        pol_losses_mpi_mean = mpi_mean_reduce(pol_losses_buffer, comm, axis=0)
        zipped_pol_losses = zipsame(list(losses.keys()),
                                    pol_losses_np_mean,
                                    pol_losses_mpi_mean)
        logger.info(columnize(names=['name', 'local', 'global'],
                              tuples=zipped_pol_losses,
                              widths=[20, 16, 16]))

        # Log statistics
        logger.info("logging misc training stats (log + csv)")
        # Gather statistics across workers
        local_lens_rets = (seg['ep_lens'], seg['ep_env_rets'])
        gathered_lens_rets = comm.allgather(local_lens_rets)
        lens, env_rets = map(flatten_lists, zip(*gathered_lens_rets))
        # Extend the deques of recorded statistics
        len_buffer.extend(lens)
        env_ret_buffer.extend(env_rets)
        ep_len_mpi_mean = np.mean(len_buffer)
        ep_env_ret_mpi_mean = np.mean(env_ret_buffer)
        logger.record_tabular('ep_len_mpi_mean', ep_len_mpi_mean)
        logger.record_tabular('ep_env_ret_mpi_mean', ep_env_ret_mpi_mean)
        eps_this_iter = len(lens)
        timesteps_this_iter = sum(lens)
        eps_so_far += eps_this_iter
        timesteps_so_far += timesteps_this_iter
        eps_this_iter_mpi_mean = mpi_mean_like(eps_this_iter, comm)
        timesteps_this_iter_mpi_mean = mpi_mean_like(timesteps_this_iter, comm)
        eps_so_far_mpi_mean = mpi_mean_like(eps_so_far, comm)
        timesteps_so_far_mpi_mean = mpi_mean_like(timesteps_so_far, comm)
        logger.record_tabular('eps_this_iter_mpi_mean', eps_this_iter_mpi_mean)
        logger.record_tabular('timesteps_this_iter_mpi_mean', timesteps_this_iter_mpi_mean)
        logger.record_tabular('eps_so_far_mpi_mean', eps_so_far_mpi_mean)
        logger.record_tabular('timesteps_so_far_mpi_mean', timesteps_so_far_mpi_mean)
        logger.record_tabular('elapsed time', prettify_time(time.time() - tstart))  # no mpi mean
        logger.record_tabular('ev_td_lam_before', explained_variance(seg['vs'],
                                                                     seg['td_lam_rets']))
        iters_so_far += 1

        if rank == 0:
            logger.dump_tabular()

        if rank == 0:
            # Add summaries
            summary = tf.summary.Summary()
            tab = 'ppo'
            # Episode stats
            summary.value.add(tag="{}/{}".format(tab, 'mean_ep_len'),
                              simple_value=ep_len_mpi_mean)
            summary.value.add(tag="{}/{}".format(tab, 'mean_ep_env_ret'),
                              simple_value=ep_env_ret_mpi_mean)
            # Losses
            for name, loss in zipsame(list(losses.keys()), pol_losses_mpi_mean):
                summary.value.add(tag="{}/{}".format(tab, name), simple_value=loss)
            summary_writer.add_summary(summary, iters_so_far)
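# Numeric sketch of the pessimistic clipped surrogate built above (plain numpy,
# illustrative values):
import numpy as np
eps = 0.2
ratio = np.array([0.5, 1.0, 1.5])  # pi_new / pi_old, per sample
adv = np.array([1.0, 1.0, 1.0])
clipped = np.clip(ratio, 1. - eps, 1. + eps) * adv
surr_loss = -np.mean(np.minimum(ratio * adv, clipped))
print(surr_loss)  # -> -0.9: the 1.5 ratio is capped at 1 + eps = 1.2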