def columnize(names, tuples, widths, indent=2):
    """Generate and return the content of a table (w/o logging or printing anything)

    The table consists of a header row built from `names`, a horizontal rule,
    one row per element of `tuples`, and a closing horizontal rule. Every
    value is rendered with `cell(value, width)` and the cells of a row are
    joined with " | ".

    Args:
        names (iterable): Column titles, one per column
        tuples (iterable): Rows of the table; each row is an iterable holding
            one value per column. It is iterated exactly once, so a one-shot
            iterator is acceptable.
        widths (iterable of int): Width of each column, passed to `cell`
            together with the corresponding value
        indent (int): Indentation spacing prepended to every row in the table

    Returns:
        str: The rows of the table joined with newlines (no trailing newline)
    """
    indent_space = indent * ' '

    def format_row(values):
        # `zipsame` pairs every value with the width of its column
        return indent_space + " | ".join(
            cell(value, width) for value, width in zipsame(values, widths))

    # Row containing the names
    header = format_row(names)
    # The rule is as long as the *indented* header and is itself indented,
    # so it extends `indent` characters past the end of the rows.
    hline = indent_space + ('-' * len(header))

    lines = [header, hline]
    lines.extend(format_row(tuple_) for tuple_ in tuples)
    # Closing hline
    lines.append(hline)
    return '\n'.join(lines)
def log_module_info(self, *components):
    """Log a table describing the trainable variables of each component.

    For every component, three messages are sent to `logger.info`: a title
    containing `self.name` and `component.name`, a table (built with
    `columnize`) listing the name, shape and number of parameters of each
    entry of `component.trainable_vars`, and the total parameter count.

    Args:
        *components: Objects exposing `name` and `trainable_vars`.
            At least one must be given.

    Raises:
        AssertionError: If no component is passed.
    """
    assert len(components) > 0, "components list is empty"
    column_titles = ['name', 'shape', 'num_params']
    column_widths = [36, 16, 10]
    for component in components:
        title = "logging {}/{} specs".format(self.name, component.name)
        logger.info(title)
        # Gather the per-variable information, one list per table column
        var_names = [v.name for v in component.trainable_vars]
        var_shapes = [U.var_shape(v) for v in component.trainable_vars]
        var_sizes = [U.numel(v) for v in component.trainable_vars]
        rows = zipsame(var_names, var_shapes, var_sizes)
        table = columnize(names=column_titles,
                          tuples=rows,
                          widths=column_widths)
        logger.info(table)
        total = sum(var_sizes)
        logger.info(" total num params: {}".format(total))
def flatgrad(loss, var_list, clip_norm=None):
    """Return the gradient of `loss` w.r.t. `var_list` as a single flat tensor.

    The per-variable gradients given by `tf.gradients` are each reshaped into
    a vector of `numel(var)` elements and concatenated along axis 0, in the
    order of `var_list`.

    Clipping is done by global norm (paper: https://arxiv.org/abs/1211.5063)

    Args:
        loss: Tensor to differentiate.
        var_list (list): Variables to differentiate with respect to.
        clip_norm: If not None, the gradients are clipped with
            `tf.clip_by_global_norm(grads, clip_norm=clip_norm)` before
            being flattened.

    Returns:
        The concatenation of all the flattened gradients.
    """
    grads = tf.gradients(loss, var_list)
    if clip_norm is not None:
        grads, _ = tf.clip_by_global_norm(grads, clip_norm=clip_norm)
    # Build a fresh list instead of overwriting `grads` while iterating it.
    # `zipsame` is a zip with extra security on the lengths of its arguments.
    flat_grads = [
        # A missing (None) gradient is replaced by zeros shaped like the
        # variable, then every gradient is reshaped into a vector.
        tf.reshape(grad if grad is not None else tf.zeros_like(var),
                   [numel(var)])
        for var, grad in zipsame(var_list, grads)
    ]
    return tf.concat(flat_grads, axis=0)
def learn(env, policy_func, reward_giver, reward_guidance, expert_dataset,
          rank, pretrained, pretrained_weight, *, g_step, d_step, entcoeff,
          save_per_iter, ckpt_dir, log_dir, timesteps_per_batch, task_name,
          gamma, lam, algo, max_kl, cg_iters, cg_damping=1e-2,
          vf_stepsize=3e-4, d_stepsize=1e-4, vf_iters=3, max_timesteps=0,
          max_episodes=0, max_iters=0, loss_percent=0.0, callback=None):
    """Train a policy with trust-region updates alongside two learned reward models.

    Every iteration of the main loop:
      1. runs `g_step` rounds of: sample a segment, take a KL-constrained
         policy step (conjugate gradient + backtracking line search) and
         fit the value function with Adam;
      2. updates `reward_giver` on minibatches of agent and expert data;
      3. updates `reward_guidance` on the expert samples selected through
         `process_expert` and `loss_percent`;
      4. records losses and episode statistics with `logger`.
    Gradients and losses are averaged over all MPI workers.

    Args:
        env: Environment; its `observation_space` is used to build the
            observation placeholder, and it is handed to `build_policy` and
            `traj_segment_generator`.
        policy_func: Not used by this function.
        reward_giver: Must expose `get_trainable_variables()`, `loss_name`
            and `lossandgrad(ob_batch, ob_expert)` (returning the losses
            followed by the flat gradient); `obs_rms` is updated if present.
        reward_guidance: Must expose `get_trainable_variables()`,
            `loss_name` and `lossandgrad(ob_expert, agent_ac, expert_ac)`;
            `obs_rms` is updated if present.
        expert_dataset: Must expose `get_next_batch(batch_size=...)`
            returning an `(observations, actions)` pair.
        rank: Ignored; immediately overwritten with the MPI rank.
        pretrained: Not used by this function.
        pretrained_weight: If not None, passed to `U.load_state` to load
            the variables of the policy.
        g_step (int): Number of policy update rounds per iteration.
        d_step: Not used by this function (minibatch size is fixed to 128).
        entcoeff (float): Weight of the mean entropy in the policy objective.
        save_per_iter (int): A checkpoint is saved by rank 0 whenever the
            iteration count is a multiple of this value.
        ckpt_dir: Checkpoint directory; no checkpoint is saved if None.
        log_dir: Not used by this function.
        timesteps_per_batch: Passed to `traj_segment_generator`.
        task_name (str): Joined to `ckpt_dir` to form the checkpoint path.
        gamma, lam: Passed to `add_vtarg_and_adv`.
        algo: Passed to `traj_segment_generator`.
        max_kl (float): KL bound used to scale the step; a candidate step is
            rejected when its KL exceeds `1.5 * max_kl`.
        cg_iters (int): Passed to `cg`.
        cg_damping (float): Multiple of the input vector added to the
            Fisher-vector product.
        vf_stepsize (float): Step size of the value function optimizer.
        d_stepsize (float): Step size of both reward model optimizers.
        vf_iters (int): Number of passes over the segment when fitting the
            value function.
        max_timesteps, max_episodes, max_iters (int): Stopping criteria;
            exactly one of them must be positive.
        loss_percent (float): Passed to `traj_segment_generator`; also the
            threshold applied to the output of `process_expert`.
        callback: If truthy, called as `callback(locals(), globals())` at
            the start of every iteration.

    Raises:
        AssertionError: If not exactly one stopping criterion is positive,
            if the conjugate gradient result is not finite, or if the
            parameter sums differ across workers.
    """
    nworkers = MPI.COMM_WORLD.Get_size()
    # NOTE: the `rank` argument is discarded in favour of the MPI rank
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)

    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    policy = build_policy(env, 'mlp', value_network='copy')
    ob = observation_placeholder(ob_space)
    with tf.variable_scope('pi'):
        pi = policy(observ_placeholder=ob)
    with tf.variable_scope('oldpi'):
        oldpi = policy(observ_placeholder=ob)
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = entcoeff * meanent

    vferr = tf.reduce_mean(tf.square(pi.vf - ret))

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)  # advantage * pnew / pold

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    all_var_list = get_trainable_variables('pi')
    var_list = get_pi_trainable_variables("pi")
    vf_var_list = get_vf_trainable_variables("pi")

    d_adam = MpiAdam(reward_giver.get_trainable_variables())
    guidance_adam = MpiAdam(reward_guidance.get_trainable_variables())
    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32,
                                  shape=[None],
                                  name="flat_tan")
    # Slice the flat tangent vector into pieces shaped like the variables
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    # Gradient of the KL dotted with the tangent; differentiating it once
    # more yields the Fisher-vector product
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(klgrads, tangents)
    ])  # pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv, newv) in zipsame(get_variables('oldpi'),
                                        get_variables('pi'))
        ])
    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg],
                                     losses + [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret],
                                       U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        # Print `msg` and the elapsed time of the wrapped block on rank 0;
        # a no-op on every other rank
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='magenta'))
        else:
            yield

    def allmean(x):
        # Average the array `x` over all the MPI workers
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()
    # Start every worker from the parameters of rank 0
    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    d_adam.sync()
    guidance_adam.sync()
    vfadam.sync()
    if rank == 0:
        print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     reward_giver,
                                     reward_guidance,
                                     timesteps_per_batch,
                                     stochastic=True,
                                     algo=algo,
                                     loss_percent=loss_percent)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards
    true_rewbuffer = deque(maxlen=40)

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    g_loss_stats = stats(loss_names)
    d_loss_stats = stats(reward_giver.loss_name)
    ep_stats = stats(["True_rewards", "Rewards", "Episode_length"])
    # if provide pretrained weight
    if pretrained_weight is not None:
        U.load_state(pretrained_weight, var_list=pi.get_variables())

    while True:
        if callback:
            callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break

        # Save model
        if rank == 0 and iters_so_far % save_per_iter == 0 and ckpt_dir is not None:
            fname = os.path.join(ckpt_dir, task_name)
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            # NOTE(review): a new Saver is built at every save; consider
            # creating it once if the set of variables is fixed -- confirm
            saver = tf.train.Saver()
            saver.save(tf.get_default_session(), fname)

        logger.log("********** Iteration %i ************" % iters_so_far)

        def fisher_vector_product(p):
            # `fvpargs` is (re)bound in the policy update loop below
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        # ------------------ Update G ------------------
        logger.log("Optimizing Policy...")
        for _ in range(g_step):
            with timed("sampling"):
                seg = seg_gen.__next__()
            # NOTE(review): looks like leftover debug output -- confirm
            print('rewards', seg['rew'])
            add_vtarg_and_adv(seg, gamma, lam)
            ob, ac, atarg, tdlamret = (seg["ob"], seg["ac"], seg["adv"],
                                       seg["tdlamret"])
            vpredbefore = seg["vpred"]  # predicted value function before update
            # standardized advantage function estimate
            atarg = (atarg - atarg.mean()) / atarg.std()

            if hasattr(pi, "ob_rms"):
                pi.ob_rms.update(ob)  # update running mean/std for policy

            args = seg["ob"], seg["ac"], atarg
            # The Fisher-vector product only uses every 5th sample
            fvpargs = [arr[::5] for arr in args]

            assign_old_eq_new()  # set old parameter values to new parameter values
            with timed("computegrad"):
                *lossbefore, g = compute_lossandgrad(*args)
            lossbefore = allmean(np.array(lossbefore))
            g = allmean(g)
            if np.allclose(g, 0):
                logger.log("Got zero gradient. not updating")
                # The parameters are left untouched, so the losses computed
                # before the (skipped) update are the current ones. Binding
                # them here keeps the logging after this loop from failing
                # on an unbound `meanlosses`.
                meanlosses = lossbefore
            else:
                with timed("cg"):
                    stepdir = cg(fisher_vector_product,
                                 g,
                                 cg_iters=cg_iters,
                                 verbose=rank == 0)
                assert np.isfinite(stepdir).all()
                shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
                lm = np.sqrt(shs / max_kl)
                fullstep = stepdir / lm
                expectedimprove = g.dot(fullstep)
                surrbefore = lossbefore[0]
                stepsize = 1.0
                thbefore = get_flat()
                # Backtracking line search: halve the step until the losses
                # are finite, the KL is acceptable and the surrogate improves
                for _ in range(10):
                    thnew = thbefore + fullstep * stepsize
                    set_from_flat(thnew)
                    meanlosses = surr, kl, *_ = allmean(
                        np.array(compute_losses(*args)))
                    improve = surr - surrbefore
                    logger.log("Expected: %.3f Actual: %.3f" %
                               (expectedimprove, improve))
                    if not np.isfinite(meanlosses).all():
                        logger.log("Got non-finite value of losses -- bad!")
                    elif kl > max_kl * 1.5:
                        logger.log("violated KL constraint. shrinking step.")
                    elif improve < 0:
                        logger.log("surrogate didn't improve. shrinking step.")
                    else:
                        logger.log("Stepsize OK!")
                        break
                    stepsize *= .5
                else:
                    # No acceptable step found: restore the old parameters
                    logger.log("couldn't compute a good step")
                    set_from_flat(thbefore)
                if nworkers > 1 and iters_so_far % 20 == 0:
                    # Check that all the workers hold the same parameters
                    paramsums = MPI.COMM_WORLD.allgather(
                        (thnew.sum(), vfadam.getflat().sum()))  # list of tuples
                    assert all(
                        np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

            with timed("vf"):
                for _ in range(vf_iters):
                    for (mbob, mbret) in dataset.iterbatches(
                            (seg["ob"], seg["tdlamret"]),
                            include_final_partial_batch=False,
                            batch_size=128):
                        if hasattr(pi, "ob_rms"):
                            # update running mean/std for policy
                            pi.ob_rms.update(mbob)
                        g = allmean(compute_vflossandgrad(mbob, mbret))
                        vfadam.update(g, vf_stepsize)

        g_losses = meanlosses
        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))

        # ------------------ Update D ------------------
        logger.log("Optimizing Discriminator...")
        logger.log(fmt_row(13, reward_giver.loss_name))
        # NOTE(review): this batch is overwritten inside the loop below before
        # being used; the call is kept in case advancing the dataset matters
        ob_expert, ac_expert = expert_dataset.get_next_batch(
            batch_size=len(ob))
        batch_size = 128
        d_losses = []  # list of tuples, each of which gives the loss for a minibatch
        with timed("Discriminator"):
            for (ob_batch, ac_batch) in dataset.iterbatches(
                    (ob, ac),
                    include_final_partial_batch=False,
                    batch_size=batch_size):
                ob_expert, ac_expert = expert_dataset.get_next_batch(
                    batch_size=batch_size)
                # update running mean/std for reward_giver
                if hasattr(reward_giver, "obs_rms"):
                    reward_giver.obs_rms.update(
                        np.concatenate((ob_batch, ob_expert), 0))
                *newlosses, g = reward_giver.lossandgrad(ob_batch, ob_expert)
                d_adam.update(allmean(g), d_stepsize)
                d_losses.append(newlosses)
        logger.log(fmt_row(13, np.mean(d_losses, axis=0)))

        # ------------------ Update Guidance ------------
        logger.log("Optimizing Guidance...")
        logger.log(fmt_row(13, reward_guidance.loss_name))
        batch_size = 128
        guidance_losses = []  # list of tuples, each of which gives the loss for a minibatch
        with timed("Guidance"):
            for ob_batch, ac_batch in dataset.iterbatches(
                    (ob, ac),
                    include_final_partial_batch=False,
                    batch_size=batch_size):
                ob_expert, ac_expert = expert_dataset.get_next_batch(
                    batch_size=batch_size)
                # Keep only the expert samples whose `process_expert` value
                # reaches `loss_percent`
                idx_condition = process_expert(ob_expert, ac_expert)
                pick_idx = (idx_condition >= loss_percent)
                ob_expert_p = ob_expert[pick_idx]
                ac_expert_p = ac_expert[pick_idx]

                # Actions of the current policy on the selected expert
                # observations
                ac_batch_p = []
                for each_ob in ob_expert_p:
                    tmp_ac, _, _, _ = pi.step(each_ob, stochastic=True)
                    ac_batch_p.append(tmp_ac)

                # update running mean/std for reward_guidance
                if hasattr(reward_guidance, "obs_rms"):
                    reward_guidance.obs_rms.update(ob_expert_p)
                *newlosses, g = reward_guidance.lossandgrad(
                    ob_expert_p, ac_batch_p, ac_expert_p)
                guidance_adam.update(allmean(g), d_stepsize)
                guidance_losses.append(newlosses)
        logger.log(fmt_row(13, np.mean(guidance_losses, axis=0)))

        # Episode statistics come from the last sampled segment only
        lrlocal = (seg["ep_lens"], seg["ep_rets"], seg["ep_true_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews, true_rets = map(flatten_lists, zip(*listoflrpairs))
        true_rewbuffer.extend(true_rets)
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpTrueRewMean", np.mean(true_rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens) * g_step
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if rank == 0:
            logger.dump_tabular()