def train_agt(self, iters, cliprange, lr, obs_n, actions_n, mb_rewards,
              values_n, advs_n, returns_n, dones_n):
    """Run the inner ADMM-style training loop for the good (non-adversary) agents.

    For ``self.admm_iter`` inner iterations, one edge between good agents is
    sampled from the communication matrix ``self.world.A``. Both endpoint
    agents take a policy-gradient step, exchange their action/probability
    estimates and fit their value functions. After each inner iteration the
    per-agent loss change relative to the start of the call is printed.

    Args:
        iters: unused; kept for interface compatibility.
        cliprange: clipping range forwarded to the loss, policy and value
            training functions.
        lr: learning rate forwarded to the policy and value training functions.
        obs_n, actions_n, values_n, advs_n, returns_n: per-agent rollout data,
            indexed by global agent id ``0 .. self.world.n - 1``.
        mb_rewards, dones_n: unused; kept for interface compatibility.
    """
    eps = 1e-8  # guards the advantage normalisation against a zero std
    A = self.world.A
    n_adv = self.world.n_adv
    good_agents = range(n_adv, self.world.n)

    # Only edges (rows of A) that touch good-agent columns are sampled, so the
    # local indices ``k - n_adv`` / ``j - n_adv`` below are non-negative.
    agt_edges = A[np.unique(np.nonzero(A[:, n_adv:])[0])]

    # Set old parameter values to new parameter values
    self.assign_old_eq_new()

    # Normalise advantages per agent and bundle the per-agent loss arguments.
    advs_n = [(advs_n[i] - advs_n[i].mean()) / (advs_n[i].std() + eps)
              for i in range(self.world.n)]
    args_n = [(obs_n[i], actions_n[i], advs_n[i], cliprange)
              for i in range(self.world.n)]

    # Each losses_n[i] call yields a pair; only the first element (the loss)
    # is tracked to report the improvement.
    loss_bf, _ = zip(*[self.losses_n[i](*args_n[i]) for i in good_agents])

    for itr in range(self.admm_iter):
        edge = agt_edges[np.random.choice(len(agt_edges))]
        q = np.where(edge != 0)[0]
        k, j = q[0], q[-1]
        # Local (good-agent-relative) indices of the two endpoints.
        nk, nj = k - n_adv, j - n_adv

        self.pg_train_n[k](*args_n[k], edge[k], nj, lr)
        self.pg_train_n[j](*args_n[j], edge[j], nk, lr)
        if len(q) > 1:
            a_k, p_k = self.to_exchange_n[k](obs_n[k], nj)
            a_j, p_j = self.to_exchange_n[j](obs_n[j], nk)
            self.exchange_n[k](j, a_j, p_j, edge[j], obs_n[k], edge[k])
            self.exchange_n[j](k, a_k, p_k, edge[k], obs_n[j], edge[j])

        # Train value function: for every vf iteration, agent k then agent j.
        for _ in range(self.vf_iters):
            for agent in (k, j):
                vf_args = (obs_n[agent], returns_n[agent], values_n[agent])
                for mbob, mbret, mbvl in dataset.iterbatches(
                        vf_args, include_final_partial_batch=False,
                        batch_size=64):
                    self.vf_train_n[agent](mbob, mbret, mbvl, cliprange, lr)

        # Same unpacking as loss_bf so both arrays hold only the losses.
        loss_itr, _ = zip(*[self.losses_n[i](*args_n[i]) for i in good_agents])
        imp = np.array(loss_bf).ravel() - np.array(loss_itr).ravel()
        print(' Inner iteration {}: {}'.format(itr, imp))
def train(self, ob, ac, batch_size=32, lr=0.001, iter=200):
    """Fit the RND critic on paired observations and actions.

    Makes ``iter`` passes over ``[ob, ac]``. In each pass ``iterbatches``
    splits the data into minibatches of ``batch_size`` rows, keeping the
    trailing partial batch. Every minibatch is handed to ``self._train``
    together with the learning rate ``lr``.
    """
    logger.info("Training RND Critic")
    train_step = self._train
    for _epoch in range(iter):
        minibatches = iterbatches([ob, ac], batch_size=batch_size,
                                  include_final_partial_batch=True)
        for minibatch in minibatches:
            train_step(*minibatch, lr)
def train(self, ob, ac, batch_size=32, lr=0.0001, iter=200):
    """Fit the RND critic on paired observations and actions, with a progress bar.

    Makes ``iter`` passes over ``[ob, ac]``, wrapped in ``tqdm`` for progress
    display. In each pass ``iterbatches`` yields minibatches of ``batch_size``
    rows, including the trailing partial batch. Each minibatch goes to
    ``self._train`` with learning rate ``lr``.
    """
    logger.info("Training RND Critic")
    train_step = self._train
    for _epoch in tqdm(range(iter)):
        minibatches = iterbatches([ob, ac], batch_size=batch_size,
                                  include_final_partial_batch=True)
        for minibatch in minibatches:
            train_step(*minibatch, lr)
def update(self, obs, actions, atarg, returns, vpredbefore, nb):
    """Take one constrained policy step with a synchronisation term, then fit the value function.

    The search direction comes from ``cg`` applied to the damped
    curvature-vector product ``hvp``. A backtracking line search (up to 10
    halvings) accepts a step only if:
      * all losses stay finite,
      * KL stays within ``1.5 * self.max_kl``, and
      * the first loss returned by ``compute_losses`` (``lagrange``) does not increase.
    If no step is accepted, the parameters are restored. The value function is
    then trained for ``self.vf_iters`` epochs of 64-sample minibatches.

    Args:
        obs, actions, atarg, returns: rollout arrays; wrapped in ``tf.constant``.
        vpredbefore: value predictions from before this update; only used for
            the ``ev_tdlam_before`` log entry.
        nb: column of ``self.comm_matrix`` and index into ``self.estimates`` /
            ``self.multipliers`` (presumably a neighbouring agent id -- TODO
            confirm against callers).
    """
    # Prepare data
    obs = tf.constant(obs)
    actions = tf.constant(actions)
    atarg = tf.constant(atarg)
    returns = tf.constant(returns)
    estimates = tf.constant(self.estimates[nb])
    multipliers = tf.constant(self.multipliers[nb])
    # This agent's coefficient in the first comm_matrix row that has a nonzero
    # entry in column ``nb``.
    comm = self.comm_matrix[self.comm_matrix[:, nb] != 0][0, self.agent.id]
    args, synargs = (obs, actions, atarg), (estimates, multipliers, comm)
    # Curvature products use every sample (stride 1, no subsampling).
    fvpargs = [arr[::1] for arr in args]

    def hvp(p):
        fvp = self.compute_fvp(p, *fvpargs).numpy()
        jjvp = self.compute_jjvp(p, *fvpargs).numpy()
        return self.allmean(fvp + jjvp) + self.cg_damping * p

    self.assign_new_eq_old()  # set old parameter values to new parameter values
    with self.timed("computegrad"):
        g = self.compute_vjp(*args, *synargs).numpy()
        g = self.allmean(g)
        lossbefore = self.allmean(
            np.array(self.compute_losses(*args, *synargs)))
    if np.allclose(g, 0):
        logger.log("Got zero gradient. not updating")
    else:
        with self.timed("cg"):
            stepdir = cg(hvp, g, cg_iters=self.cg_iters)
        assert np.isfinite(stepdir).all()
        # If cg solved hvp(stepdir) == g exactly, g.dot(stepdir) would equal
        # stepdir^T H stepdir, so this avoids one extra curvature product.
        shs = 0.5 * g.dot(stepdir)
        lm = np.sqrt(shs / self.max_kl)
        logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
        fullstep = stepdir / lm
        expectedimprove = g.dot(fullstep)
        lagrangebefore, surrbefore, syncbefore, _, _, meanentbefore = lossbefore
        stepsize = 1.0
        thbefore = self.get_flat()
        for _ in range(10):
            thnew = thbefore + fullstep * stepsize
            self.set_from_flat(thnew)
            meanlosses = self.allmean(
                np.array(self.compute_losses(*args, *synargs)))
            lagrange, surr, syncloss, kl, _, meanent = meanlosses
            improve = lagrangebefore - lagrange
            logger.log("losses before (lagrange, surr, sync, ent):",
                       lagrangebefore, surrbefore, syncbefore, meanentbefore)
            logger.log("losses after (lagrange, surr, sync, ent):",
                       lagrange, surr, syncloss, meanent)
            logger.log("Expected: %.3f Actual: %.3f" %
                       (expectedimprove, improve))
            if not np.isfinite(meanlosses).all():
                logger.log("Got non-finite value of losses -- bad!")
            elif kl > self.max_kl * 1.5:
                logger.log("violated KL constraint. shrinking step.")
            elif improve < 0:
                logger.log("surrogate didn't improve. shrinking step.")
            else:
                logger.log("Stepsize OK!")
                break
            stepsize *= .5
        else:
            logger.log("couldn't compute a good step")
            self.set_from_flat(thbefore)

    for _ in range(self.vf_iters):
        for (mbob, mbret) in dataset.iterbatches(
                (obs, returns), include_final_partial_batch=False,
                batch_size=64):
            vg = self.allmean(
                self.compute_vflossandgrad(mbob, mbret).numpy())
            self.vfadam.update(vg, self.vf_stepsize)
    logger.record_tabular("ev_tdlam_before",
                          explained_variance(vpredbefore, returns))
def learn(env, policy_func, *,
          timesteps_per_batch,  # what to train on
          max_kl, cg_iters,
          gamma, lam,  # advantage estimation
          entcoeff=0.0,
          cg_damping=1e-2,
          vf_stepsize=3e-4,
          vf_iters=3,
          max_timesteps=0, max_episodes=0, max_iters=0,  # time constraint
          callback=None
          ):
    """Train a policy on ``env`` with MPI-parallel TRPO.

    Each iteration:
      1. collects ``timesteps_per_batch`` steps;
      2. computes advantages with ``add_vtarg_and_adv(seg, gamma, lam)``;
      3. takes one conjugate-gradient step (``cg_iters`` iterations,
         Fisher damping ``cg_damping``);
      4. runs a backtracking line search that rejects steps with
         KL > ``1.5 * max_kl`` or a decrease in the surrogate objective;
      5. fits the value function for ``vf_iters`` epochs of Adam at
         ``vf_stepsize``.

    Exactly one of ``max_timesteps``, ``max_episodes`` or ``max_iters`` must
    be positive; it is the stopping criterion. ``callback`` is invoked as
    ``callback(locals(), globals())`` at the top of every iteration.
    """
    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space, ac_space)
    oldpi = policy_func("oldpi", ob_space, ac_space)
    atarg = tf.placeholder(dtype=tf.float32, shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = U.mean(kloldnew)
    meanent = U.mean(ent)
    entbonus = entcoeff * meanent

    vferr = U.mean(tf.square(pi.vpred - ret))

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = U.mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    all_var_list = pi.get_trainable_variables()
    var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("pol")]
    vf_var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("vf")]
    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32, shape=[None], name="flat_tan")
    # Split the flat tangent vector back into per-variable shaped pieces.
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    gvp = tf.add_n([U.sum(g * tangent) for (g, tangent) in zipsame(klgrads, tangents)])  # pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function([], [], updates=[
        tf.assign(oldv, newv)
        for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())])
    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses + [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret], U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(colorize("done in %.3f seconds" % (time.time() - tstart), color='magenta'))
        else:
            yield

    def allmean(x):
        # Average a numpy array over all MPI workers.
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()
    # Start every worker from rank 0's initial policy parameters.
    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, timesteps_per_batch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    while True:
        if callback:
            callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************" % iters_so_far)

        with timed("sampling"):
            seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()) / atarg.std()  # standardized advantage function estimate

        if hasattr(pi, "ret_rms"):
            pi.ret_rms.update(tdlamret)
        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        args = seg["ob"], seg["ac"], atarg
        # Fisher-vector products are evaluated on every 5th sample only.
        fvpargs = [arr[::5] for arr in args]

        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        assign_old_eq_new()  # set old parameter values to new parameter values
        with timed("computegrad"):
            *lossbefore, g = compute_lossandgrad(*args)
        lossbefore = allmean(np.array(lossbefore))
        g = allmean(g)
        if np.allclose(g, 0):
            logger.log("Got zero gradient. not updating")
            # Parameters are unchanged, so the pre-update losses are the
            # current ones; without this, meanlosses would be undefined on the
            # first iteration (or stale later) when logged below.
            meanlosses = lossbefore
        else:
            with timed("cg"):
                stepdir = cg(fisher_vector_product, g, cg_iters=cg_iters, verbose=rank == 0)
            assert np.isfinite(stepdir).all()
            shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
            lm = np.sqrt(shs / max_kl)
            fullstep = stepdir / lm
            expectedimprove = g.dot(fullstep)
            surrbefore = lossbefore[0]
            stepsize = 1.0
            thbefore = get_flat()
            for _ in range(10):
                thnew = thbefore + fullstep * stepsize
                set_from_flat(thnew)
                meanlosses = surr, kl, *_ = allmean(np.array(compute_losses(*args)))
                improve = surr - surrbefore
                logger.log("Expected: %.3f Actual: %.3f" % (expectedimprove, improve))
                if not np.isfinite(meanlosses).all():
                    logger.log("Got non-finite value of losses -- bad!")
                elif kl > max_kl * 1.5:
                    logger.log("violated KL constraint. shrinking step.")
                elif improve < 0:
                    logger.log("surrogate didn't improve. shrinking step.")
                else:
                    logger.log("Stepsize OK!")
                    break
                stepsize *= .5
            else:
                logger.log("couldn't compute a good step")
                set_from_flat(thbefore)
            if nworkers > 1 and iters_so_far % 20 == 0:
                # Sanity check that all workers still hold identical parameters.
                paramsums = MPI.COMM_WORLD.allgather((thnew.sum(), vfadam.getflat().sum()))  # list of tuples
                assert all(np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)

        with timed("vf"):
            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches((seg["ob"], seg["tdlamret"]),
                                                         include_final_partial_batch=False,
                                                         batch_size=64):
                    g = allmean(compute_vflossandgrad(mbob, mbret))
                    vfadam.update(g, vf_stepsize)

        logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if rank == 0:
            logger.dump_tabular()
def learn(env,
          policy_fn,
          *,
          timesteps_per_batch,  # what to train on
          max_kl,
          cg_iters,
          gamma,
          lam,  # advantage estimation
          entc=0.5,
          cg_damping=1e-2,
          vf_stepsize=3e-4,
          vf_iters=3,
          max_timesteps=0,
          max_episodes=0,
          max_iters=0,  # time constraint
          callback=None,
          i_trial):
    """Train a policy on ``env`` with MPI-parallel TRPO and a fed-in entropy coefficient.

    The entropy coefficient is a placeholder (``entp``) supplied at every
    loss evaluation. It is currently fixed at ``0.0`` in the loop.

    NOTE(review): ``entc`` is therefore unused; the annealed schedule that
    used it is commented out below.

    Each iteration:
      * collects ``timesteps_per_batch`` steps (the generator also receives
        ``gamma``);
      * computes advantages with ``add_vtarg_and_adv``;
      * takes one conjugate-gradient step (``cg_iters``, ``cg_damping``)
        with a backtracking line search bounded by ``1.5 * max_kl``;
      * fits the value function for ``vf_iters`` epochs at ``vf_stepsize``.

    Episode lengths, returns and ``ep_drwds`` statistics are logged together
    with ``i_trial``. Exactly one of ``max_timesteps``, ``max_episodes`` or
    ``max_iters`` must be positive.
    """
    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space, ac_space)
    oldpi = policy_fn("oldpi", ob_space, ac_space)
    atarg = tf.placeholder(dtype=tf.float32, shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    # Scalar entropy coefficient, fed on every loss/gradient evaluation.
    entp = tf.placeholder(dtype=tf.float32, shape=[])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = entp * meanent

    vferr = tf.reduce_mean(tf.square(pi.vpred - ret))

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "loss_ent"]

    dist = meankl

    all_var_list = pi.get_trainable_variables()
    var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("pol")]
    vf_var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("vf")]
    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32, shape=[None], name="flat_tan")
    # Split the flat tangent vector back into per-variable shaped pieces.
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(klgrads, tangents)
    ])  # pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, entp], losses)
    compute_lossandgrad = U.function([ob, ac, atarg, entp],
                                     losses + [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret], U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(colorize("done in %.3f seconds" % (time.time() - tstart),
                           color='magenta'))
        else:
            yield

    def allmean(x):
        # Average a numpy array over all MPI workers.
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()
    # Start every worker from rank 0's initial policy parameters.
    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, timesteps_per_batch,
                                     stochastic=True, gamma=gamma)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards
    drwdsbuffer = deque(maxlen=40)  # rolling buffer for seg["ep_drwds"] values

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    while True:
        if callback:
            callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************" % iters_so_far)

        with timed("sampling"):
            seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # entcoeff = max(entc - float(iters_so_far) / float(max_iters), 0.01)
        entcoeff = 0.0

        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()) / atarg.std()  # standardized advantage function estimate

        if hasattr(pi, "ret_rms"):
            pi.ret_rms.update(tdlamret)
        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        args = seg["ob"], seg["ac"], atarg
        # Fisher-vector products are evaluated on every 5th sample only.
        fvpargs = [arr[::5] for arr in args]

        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        assign_old_eq_new()  # set old parameter values to new parameter values
        with timed("computegrad"):
            *lossbefore, g = compute_lossandgrad(*args, entcoeff)
        lossbefore = allmean(np.array(lossbefore))
        g = allmean(g)
        if np.allclose(g, 0):
            print("Got zero gradient. not updating")
            # Parameters are unchanged; log the pre-update losses instead of
            # an undefined (first iteration) or stale meanlosses.
            meanlosses = lossbefore
        else:
            with timed("cg"):
                stepdir = cg(fisher_vector_product, g, cg_iters=cg_iters,
                             verbose=rank == 0)
            assert np.isfinite(stepdir).all()
            shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
            lm = np.sqrt(shs / max_kl)
            fullstep = stepdir / lm
            expectedimprove = g.dot(fullstep)
            surrbefore = lossbefore[0]
            stepsize = 1.0
            thbefore = get_flat()
            for _ in range(10):
                thnew = thbefore + fullstep * stepsize
                set_from_flat(thnew)
                meanlosses = surr, kl, *_ = allmean(
                    np.array(compute_losses(*args, entcoeff)))
                improve = surr - surrbefore
                print("Expected: %.3f Actual: %.3f" % (expectedimprove, improve))
                if not np.isfinite(meanlosses).all():
                    print("Got non-finite value of losses -- bad!")
                elif kl > max_kl * 1.5:
                    print("violated KL constraint. shrinking step.")
                elif improve < 0:
                    print("surrogate didn't improve. shrinking step.")
                else:
                    print("Stepsize OK!")
                    break
                stepsize *= .5
            else:
                print("couldn't compute a good step")
                set_from_flat(thbefore)
            if nworkers > 1 and iters_so_far % 20 == 0:
                # Sanity check that all workers still hold identical parameters.
                paramsums = MPI.COMM_WORLD.allgather(
                    (thnew.sum(), vfadam.getflat().sum()))  # list of tuples
                assert all(np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.logkv(lossname, lossval)

        with timed("vf"):
            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches(
                        (seg["ob"], seg["tdlamret"]),
                        include_final_partial_batch=False,
                        batch_size=64):
                    g = allmean(compute_vflossandgrad(mbob, mbret))
                    vfadam.update(g, vf_stepsize)

        logger.logkv("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"], seg["ep_drwds"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews, drwds = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        drwdsbuffer.extend(drwds)

        logger.logkv("EpLenMean", np.mean(lenbuffer))
        logger.logkv("EpRewMean", np.mean(rewbuffer))
        logger.logkv("EpThisIter", len(lens))
        logger.logkv("EpDRewMean", np.mean(drwdsbuffer))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.logkv("EpisodesSoFar", episodes_so_far)
        logger.logkv("TimestepsSoFar", timesteps_so_far)
        logger.logkv("TimeElapsed", time.time() - tstart)
        logger.logkv('trial', i_trial)
        logger.logkv("Iteration", iters_so_far)
        logger.logkv("Name", 'TRPO')

        if rank == 0:
            logger.dump_tabular()
def _add_scalar_summary(writer, tag, value, step):
    """Write a single scalar ``value`` under ``tag`` at global ``step`` to ``writer``."""
    summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
    writer.add_summary(summary, step)


def _save_checkpoint(saver, ckpt_dir, iteration):
    """Save the default session to ``ckpt_dir/<iteration>`` and return the saver.

    A ``tf.train.Saver`` is created on the first call (``saver is None``) and
    then reused, so repeated checkpoints do not add new save ops to the graph.
    """
    fname = os.path.join(ckpt_dir, str(iteration))
    os.makedirs(os.path.dirname(fname), exist_ok=True)
    if saver is None:
        saver = tf.train.Saver()
    saver.save(tf.get_default_session(), fname)
    return saver


def learn(env,
          policy_func,
          reward_giver,
          expert_dataset,
          rank,
          pretrained,
          pretrained_weight,
          *,
          g_step,
          d_step,
          entcoeff,
          save_per_iter,
          ckpt_dir,
          log_dir,
          timesteps_per_batch,
          task_name,
          robot_name,
          gamma,
          lam,
          max_kl,
          cg_iters,
          cg_damping=1e-2,
          vf_stepsize=3e-4,
          d_stepsize=3e-4,
          vf_iters=3,
          max_timesteps=0,
          max_episodes=0,
          max_iters=0,
          callback=None):
    """Adversarial imitation training: TRPO generator steps alternated with discriminator updates.

    Each outer iteration runs ``g_step`` TRPO policy updates on rollouts
    rewarded by ``reward_giver``. It then updates ``reward_giver`` (the
    discriminator) with Adam at ``d_stepsize`` on minibatches of size
    ``len(ob) // d_step``, paired with samples from ``expert_dataset``.
    Scalars are written to a TensorBoard directory selected by
    ``robot_name``. Checkpoints are saved every ``save_per_iter`` iterations
    to ``ckpt_dir``; for ``'scara'`` only when the reward filter below passes.

    Args:
        rank: overwritten by the MPI rank; the passed value is ignored.
        pretrained: unused.
        pretrained_weight: if not None, a checkpoint loaded into ``pi`` with
            ``U.load_state`` before training.
        log_dir, task_name: unused.
        robot_name: ``'scara'`` (stops after iteration 2000) or ``'mara'``
            (stops after iteration 5000).

    Raises:
        ValueError: if ``robot_name`` is not ``'scara'`` or ``'mara'``.
    """
    # NOTE(review): hard-coded, user-specific TensorBoard paths; consider log_dir.
    tensorboard_dirs = {
        'scara': '/home/yue/gym-gazebo/Tensorboard/scara',
        'mara': '/home/yue/gym-gazebo/Tensorboard/mara/collisions_model/',
    }
    if robot_name not in tensorboard_dirs:
        raise ValueError("unsupported robot_name %r; expected 'scara' or 'mara'"
                         % (robot_name,))

    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space, ac_space,
                     reuse=(pretrained_weight is not None))
    oldpi = policy_func("oldpi", ob_space, ac_space)
    atarg = tf.placeholder(dtype=tf.float32, shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = entcoeff * meanent

    vferr = tf.reduce_mean(tf.square(pi.vpred - ret))

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    all_var_list = pi.get_trainable_variables()
    var_list = [
        v for v in all_var_list
        if v.name.startswith("pi/pol") or v.name.startswith("pi/logstd")
    ]
    vf_var_list = [v for v in all_var_list if v.name.startswith("pi/vff")]
    assert len(var_list) == len(vf_var_list) + 1
    d_adam = MpiAdam(reward_giver.get_trainable_variables())
    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32, shape=[None], name="flat_tan")
    # Split the flat tangent vector back into per-variable shaped pieces.
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(klgrads, tangents)
    ])  # pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg],
                                     losses + [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret], U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(colorize("done in %.3f seconds" % (time.time() - tstart),
                           color='magenta'))
        else:
            yield

    def allmean(x):
        # Average a numpy array over all MPI workers.
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()
    # Start every worker from rank 0's initial policy parameters.
    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    d_adam.sync()
    vfadam.sync()
    if rank == 0:
        print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, reward_giver, timesteps_per_batch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards
    true_rewbuffer = deque(maxlen=40)  # rolling buffer for seg["ep_true_rets"]

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    # if provide pretrained weight
    if pretrained_weight is not None:
        U.load_state(pretrained_weight, var_list=pi.get_variables())

    summary_writer = tf.summary.FileWriter(tensorboard_dirs[robot_name],
                                           graph=tf.get_default_graph())
    saver = None  # created lazily by _save_checkpoint, then reused

    while True:
        if callback:
            callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************" % iters_so_far)

        # ------------------ Update G ------------------
        logger.log("Optimizing Policy...")
        for _ in range(g_step):
            with timed("sampling"):
                seg = seg_gen.__next__()
            add_vtarg_and_adv(seg, gamma, lam)
            ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
            vpredbefore = seg["vpred"]  # predicted value function before update
            atarg = (atarg - atarg.mean()) / atarg.std()  # standardized advantage function estimate

            if hasattr(pi, "ob_rms"):
                pi.ob_rms.update(ob)  # update running mean/std for policy

            args = seg["ob"], seg["ac"], atarg
            # Fisher-vector products are evaluated on every 5th sample only.
            fvpargs = [arr[::5] for arr in args]

            def fisher_vector_product(p):
                return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

            assign_old_eq_new()  # set old parameter values to new parameter values
            with timed("computegrad"):
                *lossbefore, g = compute_lossandgrad(*args)
            lossbefore = allmean(np.array(lossbefore))
            g = allmean(g)
            if np.allclose(g, 0):
                logger.log("Got zero gradient. not updating")
                # Parameters are unchanged; report the pre-update losses rather
                # than an undefined or stale meanlosses.
                meanlosses = lossbefore
            else:
                with timed("cg"):
                    stepdir = cg(fisher_vector_product, g, cg_iters=cg_iters,
                                 verbose=rank == 0)
                assert np.isfinite(stepdir).all()
                shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
                lm = np.sqrt(shs / max_kl)
                fullstep = stepdir / lm
                expectedimprove = g.dot(fullstep)
                surrbefore = lossbefore[0]
                stepsize = 1.0
                thbefore = get_flat()
                for _ in range(10):
                    thnew = thbefore + fullstep * stepsize
                    set_from_flat(thnew)
                    meanlosses = surr, kl, *_ = allmean(
                        np.array(compute_losses(*args)))
                    improve = surr - surrbefore
                    logger.log("Expected: %.3f Actual: %.3f" %
                               (expectedimprove, improve))
                    if not np.isfinite(meanlosses).all():
                        logger.log("Got non-finite value of losses -- bad!")
                    elif kl > max_kl * 1.5:
                        logger.log("violated KL constraint. shrinking step.")
                    elif improve < 0:
                        logger.log("surrogate didn't improve. shrinking step.")
                    else:
                        logger.log("Stepsize OK!")
                        break
                    stepsize *= .5
                else:
                    logger.log("couldn't compute a good step")
                    set_from_flat(thbefore)
                if nworkers > 1 and iters_so_far % 20 == 0:
                    # Sanity check that all workers still hold identical parameters.
                    paramsums = MPI.COMM_WORLD.allgather(
                        (thnew.sum(), vfadam.getflat().sum()))  # list of tuples
                    assert all(
                        np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

            with timed("vf"):
                for _ in range(vf_iters):
                    for (mbob, mbret) in dataset.iterbatches(
                            (seg["ob"], seg["tdlamret"]),
                            include_final_partial_batch=False,
                            batch_size=128):
                        if hasattr(pi, "ob_rms"):
                            pi.ob_rms.update(mbob)  # update running mean/std for policy
                        if nworkers != 1:
                            g = allmean(compute_vflossandgrad(mbob, mbret))
                        else:
                            g = compute_vflossandgrad(mbob, mbret)
                        vfadam.update(g, vf_stepsize)

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))

        # ------------------ Update D ------------------
        logger.log("Optimizing Discriminator...")
        logger.log(fmt_row(13, reward_giver.loss_name))
        # The result is discarded (it is re-fetched per minibatch below); the
        # call is kept in case get_next_batch advances internal state -- TODO confirm.
        ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob))
        batch_size = len(ob) // d_step
        d_losses = []  # list of tuples, each of which gives the loss for a minibatch
        for ob_batch, ac_batch in dataset.iterbatches(
                (ob, ac), include_final_partial_batch=False,
                batch_size=batch_size):
            ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob_batch))
            # update running mean/std for reward_giver
            if hasattr(reward_giver, "obs_rms"):
                reward_giver.obs_rms.update(np.concatenate((ob_batch, ob_expert), 0))
            *newlosses, g = reward_giver.lossandgrad(ob_batch, ac_batch,
                                                     ob_expert, ac_expert)
            if nworkers != 1:
                d_adam.update(allmean(g), d_stepsize)
            else:
                d_adam.update(g, d_stepsize)
            d_losses.append(newlosses)
        # Mean of each discriminator loss over all minibatches; used for both
        # the text log and TensorBoard (previously only minibatch 0 was logged).
        d_loss_means = np.mean(d_losses, axis=0)
        logger.log(fmt_row(13, d_loss_means))
        d_tags = ("g_loss", "d_loss", "entropy", "entropy_loss", "g_acc",
                  "expert_acc")
        for idx, tag in enumerate(d_tags):
            _add_scalar_summary(summary_writer, tag, d_loss_means[idx],
                                timesteps_so_far)

        lrlocal = (seg["ep_lens"], seg["ep_rets"], seg["ep_true_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews, true_rets = map(flatten_lists, zip(*listoflrpairs))
        true_rewbuffer.extend(true_rets)
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        rollout_scalars = [
            ("MeanDiscriminator", np.mean(rewbuffer)),
            ("MeanGenerator", np.mean(true_rewbuffer)),
            ("Generator", np.mean(true_rets)),
            ("Length", np.mean(lenbuffer)),
            ("Optimgain", np.mean(meanlosses[0])),
            ("Meankl", np.mean(meanlosses[1])),
            ("Entloss", np.mean(meanlosses[2])),
            ("Surrgain", np.mean(meanlosses[3])),
            ("Entropy", np.mean(meanlosses[4])),
            ("EpThisIter", np.mean(len(lens))),
        ]
        for tag, value in rollout_scalars:
            _add_scalar_summary(summary_writer, tag, value, timesteps_so_far)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("MeanDiscriminator", np.mean(rewbuffer))

        # Save model
        if robot_name == 'scara':
            if iters_so_far % save_per_iter == 0:
                if np.mean(rewbuffer) <= 200 or np.mean(true_rewbuffer) >= -100:
                    saver = _save_checkpoint(saver, ckpt_dir, iters_so_far)
            if iters_so_far == 2000:
                break
        elif robot_name == 'mara':
            if iters_so_far % save_per_iter == 0:
                saver = _save_checkpoint(saver, ckpt_dir, iters_so_far)
            if iters_so_far == 5000:
                break

        logger.record_tabular("MeanGenerator", np.mean(true_rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if rank == 0:
            logger.dump_tabular()
def learn(*,
          network,
          env,
          seed=None,
          beta,
          total_timesteps,
          sil_update,
          sil_loss,
          timesteps_per_batch=2048,  # what to train on
          epsilon=0.01,
          cg_iters=10,
          gamma=0.99,
          lam=0.98,  # advantage estimation
          entcoeff=0.0,
          lr=3e-4,
          cg_damping=0.1,
          vf_stepsize=1e-3,
          vf_iters=5,
          sil_value=0.01,
          sil_alpha=0.6,
          sil_beta=0.1,
          max_episodes=0,
          max_iters=0,  # time constraint
          callback=None,
          save_interval=0,
          load_path=None,
          model_fn=None,
          update_fn=None,
          init_fn=None,
          mpi_rank_weight=1,
          comm=None,
          vf_coef=0.5,
          max_grad_norm=0.5,
          log_interval=1,
          nminibatches=4,
          noptepochs=4,
          cliprange=0.2,
          TRPO=False,
          **network_kwargs):
    """Train a policy with COPOS (or TRPO when ``TRPO=True``) plus a SIL step.

    Each iteration samples a rollout batch and computes a natural-gradient
    direction with conjugate gradient. It then applies either a TRPO line
    search (``TRPO=True``) or the COPOS eta/omega update. Finally it fits the
    value function with ``MpiAdam`` and runs ``model.sil_train``.

    Exactly one of ``total_timesteps``, ``max_episodes`` and ``max_iters``
    must be positive (asserted). It decides when the loop stops.

    Args:
        network: network spec forwarded to ``build_policy``.
        env: vectorized environment (``env.num_envs`` is read).
        beta: forwarded to the eta/omega optimizer constructor.
        epsilon: forwarded to the eta/omega optimizer. Also the KL bound of
            the TRPO line search.
        lr: float or callable. Floats are wrapped with ``constfn``. It is
            evaluated with the remaining-progress fraction for ``sil_train``.
        cliprange: float or callable (validated here but not otherwise used).
        TRPO: use the TRPO line search instead of the COPOS update.
        Remaining arguments are forwarded to ``Model`` / ``build_policy`` or
        control the loop as named. Several (e.g. ``save_interval``,
        ``update_fn``) are accepted for interface compatibility but unused.

    Returns:
        None. Progress is reported through ``logger``.
    """
    set_global_seeds(seed)

    if isinstance(lr, float):
        lr = constfn(lr)
    else:
        assert callable(lr)
    if isinstance(cliprange, float):
        cliprange = constfn(cliprange)
    else:
        assert callable(cliprange)

    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()

    policy = build_policy(env, network, value_network='copy', copos=True, **network_kwargs)
    nenvs = env.num_envs

    np.set_printoptions(precision=3)

    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space

    nbatch = nenvs * timesteps_per_batch
    nbatch_train = nbatch // nminibatches
    if model_fn is None:
        model_fn = Model

    discrete_ac_space = isinstance(ac_space, gym.spaces.Discrete)

    ob = observation_placeholder(ob_space)
    with tf.variable_scope("pi", reuse=tf.AUTO_REUSE):
        pi = policy(observ_placeholder=ob)

    def _build_model(prev_pi, silm):
        """Build a SIL-enabled ``Model`` bound to the policy in scope ``prev_pi``."""
        return Model(
            policy=policy,
            ob_space=ob_space,
            ac_space=ac_space,
            nbatch_act=nenvs,
            nbatch_train=nbatch_train,
            nsteps=timesteps_per_batch,
            ent_coef=entcoeff,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            sil_update=sil_update,
            sil_value=sil_value,
            sil_alpha=sil_alpha,
            sil_beta=sil_beta,
            sil_loss=sil_loss,
            fn_reward=None,
            fn_obs=None,
            ppo=False,
            prev_pi=prev_pi,
            silm=silm)

    model = _build_model('pi', pi)
    if load_path is not None:
        model.load(load_path)

    with tf.variable_scope("oldpi", reuse=tf.AUTO_REUSE):
        oldpi = policy(observ_placeholder=ob)
    old_model = _build_model('oldpi', oldpi)

    atarg = tf.placeholder(dtype=tf.float32, shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = entcoeff * meanent

    vferr = tf.reduce_mean(tf.square(pi.vf - ret))

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "Entropy"]

    dist = meankl

    var_list = get_pi_trainable_variables("pi")
    vf_var_list = get_vf_trainable_variables("pi")

    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32, shape=[None], name="flat_tan")
    # Slice the flat tangent vector back into one tensor per policy variable.
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    # gvp = grad(KL) . t; differentiating it once more w.r.t. the policy
    # variables yields the KL-Hessian (Fisher) vector product used by CG.
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(klgrads, tangents)
    ])  # pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv, newv) in zipsame(get_variables("oldpi"), get_variables("pi"))
        ])
    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg],
                                     losses + [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret], U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        """Print ``msg`` and the elapsed wall time, on rank 0 only."""
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(colorize("done in %.3f seconds" % (time.time() - tstart),
                           color='magenta'))
        else:
            yield

    def allmean(x):
        """Average a numpy array across all MPI workers."""
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()
    if load_path is not None:
        pi.load(load_path)

    # Make every worker start from rank 0's policy parameters.
    th_init = get_flat()
    if MPI is not None:
        MPI.COMM_WORLD.Bcast(th_init, root=0)

    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Initialize eta, omega optimizer
    if discrete_ac_space:
        init_eta = 1
        init_omega = 0.5
        eta_omega_optimizer = EtaOmegaOptimizerDiscrete(beta, epsilon, init_eta, init_omega)
    else:
        init_eta = 0.5
        init_omega = 2.0
        eta_omega_optimizer = EtaOmegaOptimizer(beta, epsilon, init_eta, init_omega)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(env, timesteps_per_batch, model, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    assert sum([max_iters > 0, total_timesteps > 0, max_episodes > 0]) == 1

    while True:
        if callback:
            callback(locals(), globals())
        if total_timesteps and timesteps_so_far >= total_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************" % iters_so_far)

        with timed("sampling"):
            seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        model = seg["model"]
        # Standardize advantages; the epsilon avoids 0/0 -> NaN when all
        # advantages in the batch are equal.
        atarg = (atarg - atarg.mean()) / (atarg.std() + 1e-8)

        if hasattr(pi, "ret_rms"):
            pi.ret_rms.update(tdlamret)
        if hasattr(pi, "rms"):
            pi.rms.update(ob)  # update running mean/std for policy

        args = seg["ob"], seg["ac"], atarg
        # Fisher-vector products are estimated on every 5th sample only.
        fvpargs = [arr[::5] for arr in args]

        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        assign_old_eq_new()  # set old parameter values to new parameter values
        with timed("computegrad"):
            *lossbefore, g = compute_lossandgrad(*args)
        lossbefore = allmean(np.array(lossbefore))
        g = allmean(g)
        # Losses to log when no policy update happens this iteration.
        meanlosses = lossbefore
        if np.allclose(g, 0):
            logger.log("Got zero gradient. not updating")
        else:
            with timed("cg"):
                stepdir = cg(fisher_vector_product, g, cg_iters=cg_iters, verbose=rank == 0)
            assert np.isfinite(stepdir).all()

            if TRPO:
                # TRPO: scale the CG direction to the KL bound, then
                # backtrack until the surrogate improves within the bound.
                shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
                lm = np.sqrt(shs / epsilon)
                fullstep = stepdir / lm
                expectedimprove = g.dot(fullstep)
                surrbefore = lossbefore[0]
                stepsize = 1.0
                thbefore = get_flat()
                for _ in range(10):
                    thnew = thbefore + fullstep * stepsize
                    set_from_flat(thnew)
                    meanlosses = surr, kl, *_ = allmean(np.array(compute_losses(*args)))
                    improve = surr - surrbefore
                    logger.log("Expected: %.3f Actual: %.3f" % (expectedimprove, improve))
                    if not np.isfinite(meanlosses).all():
                        logger.log("Got non-finite value of losses -- bad!")
                    elif kl > epsilon * 1.5:
                        logger.log("violated KL constraint. shrinking step.")
                    elif improve < 0:
                        logger.log("surrogate didn't improve. shrinking step.")
                    else:
                        logger.log("Stepsize OK!")
                        break
                    stepsize *= .5
                else:
                    logger.log("couldn't compute a good step")
                    set_from_flat(thbefore)
            else:
                # COPOS specific implementation.
                copos_update_dir = stepdir

                # Split direction into log-linear 'w_theta' and non-linear 'w_beta' parts
                w_theta, w_beta = pi.split_w(copos_update_dir)

                # Optimize eta and omega
                if discrete_ac_space:
                    entropy = lossbefore[4]
                    eta, omega = eta_omega_optimizer.optimize(
                        pi.compute_F_w(ob, copos_update_dir),
                        pi.get_log_action_prob(ob), timesteps_per_batch, entropy)
                else:
                    Waa, Wsa = pi.w2W(w_theta)
                    wa = pi.get_wa(ob, w_beta)
                    varphis = pi.get_varphis(ob)
                    old_ent = lossbefore[4]
                    eta, omega = eta_omega_optimizer.optimize(
                        w_theta, Waa, Wsa, wa, varphis, pi.get_kt(),
                        pi.get_prec_matrix(), pi.is_new_policy_valid, old_ent)
                logger.log("Initial eta: " + str(eta) + " and omega: " + str(omega))

                current_theta_beta = get_flat()
                prev_theta, prev_beta = pi.all_to_theta_beta(current_theta_beta)

                if discrete_ac_space:
                    # Do a line search for both theta and beta parameters by adjusting only eta
                    eta = eta_search(w_theta, w_beta, eta, omega, allmean, compute_losses,
                                     get_flat, set_from_flat, pi, epsilon, args,
                                     discrete_ac_space)
                    logger.log("Updated eta, eta: " + str(eta))
                    set_from_flat(pi.theta_beta_to_all(prev_theta, prev_beta))
                    # Find proper omega for new eta. Use old policy parameters first.
                    eta, omega = eta_omega_optimizer.optimize(
                        pi.compute_F_w(ob, copos_update_dir),
                        pi.get_log_action_prob(ob), timesteps_per_batch, entropy, eta)
                    logger.log("Updated omega, eta: " + str(eta) + " and omega: " + str(omega))

                    # The beta-ratio line search is disabled, so the
                    # non-linear part is applied with ratio 1.
                    ratio = 1
                    cur_theta = (eta * prev_theta + w_theta.reshape(-1, )) / (eta + omega)
                    cur_beta = prev_beta + ratio * w_beta.reshape(-1, ) / eta
                else:
                    for _ in range(2):
                        # Do a line search for both theta and beta parameters by adjusting only eta
                        eta = eta_search(w_theta, w_beta, eta, omega, allmean, compute_losses,
                                         get_flat, set_from_flat, pi, epsilon, args)
                        logger.log("Updated eta, eta: " + str(eta) + " and omega: " + str(omega))

                        # Find proper omega for new eta. Use old policy parameters first.
                        set_from_flat(pi.theta_beta_to_all(prev_theta, prev_beta))
                        eta, omega = eta_omega_optimizer.optimize(
                            w_theta, Waa, Wsa, wa, varphis, pi.get_kt(),
                            pi.get_prec_matrix(), pi.is_new_policy_valid, old_ent, eta)
                        logger.log("Updated omega, eta: " + str(eta) + " and omega: " + str(omega))

                    # Use final policy
                    logger.log("Final eta: " + str(eta) + " and omega: " + str(omega))
                    cur_theta = (eta * prev_theta + w_theta.reshape(-1, )) / (eta + omega)
                    cur_beta = prev_beta + w_beta.reshape(-1, ) / eta

                set_from_flat(pi.theta_beta_to_all(cur_theta, cur_beta))
                meanlosses = surr, kl, *_ = allmean(np.array(compute_losses(*args)))

            if nworkers > 1 and iters_so_far % 20 == 0:
                # Check that all workers hold identical parameters. Read the
                # current flat params: `thnew` only exists on the TRPO path.
                paramsums = MPI.COMM_WORLD.allgather(
                    (get_flat().sum(), vfadam.getflat().sum()))  # list of tuples
                assert all(np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)

        with timed("vf"):
            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches(
                        (seg["ob"], seg["tdlamret"]),
                        include_final_partial_batch=False,
                        batch_size=64):
                    g = allmean(compute_vflossandgrad(mbob, mbret))
                    vfadam.update(g, vf_stepsize)

        with timed('SIL'):
            # total_timesteps is 0 when max_iters/max_episodes bound training;
            # keep the full learning rate then instead of dividing by zero.
            frac = 1.0 - timesteps_so_far / total_timesteps if total_timesteps else 1.0
            lrnow = lr(frac)
            l_loss, sil_adv, sil_samples, sil_nlogp = model.sil_train(lrnow)

        logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        # max()/min() raise on an empty buffer (no episode finished yet).
        if rewbuffer:
            print("Reward max: " + str(max(rewbuffer)))
            print("Reward min: " + str(min(rewbuffer)))

        logger.record_tabular(
            "EpLenMean", np.mean(lenbuffer) if np.sum(lenbuffer) != 0.0 else 0.0)
        logger.record_tabular(
            "EpRewMean", np.mean(rewbuffer) if np.sum(rewbuffer) != 0.0 else 0.0)
        logger.record_tabular(
            "AverageReturn", np.mean(rewbuffer) if np.sum(rewbuffer) != 0.0 else 0.0)
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if sil_update > 0:
            logger.record_tabular("SilSamples", sil_samples)

        if rank == 0:
            logger.dump_tabular()
def learn(
        env,
        agent,
        reward_giver,
        expert_dataset,
        g_step,
        d_step,
        d_stepsize=3e-4,
        timesteps_per_batch=1024,
        nb_train_steps=50,
        max_timesteps=0,
        max_iters=0,  # TODO: max_episodes
        callback=None,
        d_adam=None,
        sess=None,
        saver=None):
    """Alternate generator (policy) and discriminator updates, GAIL-style.

    Each iteration calls ``train_one_batch`` to collect
    ``timesteps_per_batch * g_step`` policy samples and update the agent. It
    then runs discriminator steps: one ``reward_giver.lossandgrad`` +
    ``d_adam.update`` per policy minibatch (about ``d_step`` of them), each
    paired with an equally sized expert batch.

    Exactly one of ``max_timesteps`` and ``max_iters`` must be positive
    (asserted). ``sess`` and ``saver`` are accepted but unused here.

    Returns:
        None. Progress is reported through ``logger``. Accumulated loss
        histories are visible to ``callback`` through ``locals()``.
    """
    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)

    # Prepare for rollouts
    # ----------------------------------------
    timesteps_so_far = 0
    iters_so_far = 0
    assert sum([max_iters > 0, max_timesteps > 0]) == 1
    # TODO: implicit policy does not admit pretraining?

    # Loss histories accumulated across iterations.
    policy_losses_record = {}
    discriminator_losses_record = {}

    while True:
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_iters and iters_so_far >= max_iters:
            break

        logger.log("********** Iteration %i ************" % iters_so_far)
        logger.log("********** Steps %i ************" % timesteps_so_far)

        # ------------------ Update G ------------------
        logger.log("Optimizing Policy...")
        ob_policy, ac_policy, losses_record = train_one_batch(
            env, agent, reward_giver, timesteps_per_batch, nb_train_steps, g_step)
        assert len(ob_policy) == len(ac_policy) == timesteps_per_batch * g_step

        # ------------------ Update D ------------------
        logger.log("Optimizing Discriminator...")
        logger.log(fmt_row(13, reward_giver.loss_name))
        # Guard against a zero batch size when d_step exceeds the sample count.
        batch_size = max(1, len(ob_policy) // d_step)
        d_losses = []  # list of tuples, each of which gives the loss for a minibatch
        for ob_batch, ac_batch in dataset.iterbatches(
                (ob_policy, ac_policy),
                include_final_partial_batch=False,
                batch_size=batch_size):
            # One expert batch per policy minibatch, of the same size.
            ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob_batch))
            # update running mean/std for reward_giver
            if hasattr(reward_giver, "obs_rms"):
                reward_giver.obs_rms.update(np.concatenate((ob_batch, ob_expert), 0))
            *newlosses, g = reward_giver.lossandgrad(ob_batch, ac_batch, ob_expert, ac_expert)
            d_adam.update(allmean(g, nworkers), d_stepsize)
            d_losses.append(newlosses)
        mean_d_losses = np.mean(d_losses, axis=0)
        logger.log(fmt_row(13, mean_d_losses))

        timesteps_so_far += timesteps_per_batch * g_step
        iters_so_far += 1

        # record
        for k, v in losses_record.items():
            if k in policy_losses_record:
                # Non in-place `+` so the object taken from an earlier
                # iteration's losses_record is never mutated.
                policy_losses_record[k] = policy_losses_record[k] + v
            else:
                policy_losses_record[k] = v
        for idx, k in enumerate(reward_giver.loss_name):
            discriminator_losses_record.setdefault(k, []).append(mean_d_losses[idx])

        # logging
        logger.record_tabular("Epoch Actor Losses", np.mean(losses_record['actor_loss']))
        logger.record_tabular("Epoch Critic Losses", np.mean(losses_record['critic_loss']))
        logger.record_tabular("Epoch Classifier Losses",
                              np.mean(losses_record['classifier_loss']))
        logger.record_tabular("Epoch Entropy", np.mean(losses_record['entropy']))

        if rank == 0:
            logger.dump_tabular()

        # Call callback
        if callback is not None:
            callback(locals(), globals())
def learn( env, policy_func, reward_giver, semi_dataset, rank, pretrained_weight, *, g_step, d_step, entcoeff, save_per_iter, ckpt_dir, log_dir, timesteps_per_batch, task_name, gamma, lam, max_kl, cg_iters, cg_damping=1e-2, vf_stepsize=3e-4, d_stepsize=3e-4, vf_iters=3, max_timesteps=0, max_episodes=0, max_iters=0, vf_batchsize=128, callback=None, freeze_g=False, freeze_d=False, pretrained_il=None, pretrained_semi=None, semi_loss=False, expert_reward_threshold=None, # filter experts based on reward expert_label=get_semi_prefix(), sparse_reward=False # filter experts based on success flag (sparse reward) ): semi_loss = semi_loss and semi_dataset is not None l2_w = 0.1 nworkers = MPI.COMM_WORLD.Get_size() rank = MPI.COMM_WORLD.Get_rank() np.set_printoptions(precision=3) if rank == 0: writer = U.file_writer(log_dir) # print all the hyperparameters in the log... log_dict = { # "expert trajectories": expert_dataset.num_traj, "expert model": pretrained_semi, "algo": "trpo", "threads": nworkers, "timesteps_per_batch": timesteps_per_batch, "timesteps_per_thread": -(-timesteps_per_batch // nworkers), "entcoeff": entcoeff, "vf_iters": vf_iters, "vf_batchsize": vf_batchsize, "vf_stepsize": vf_stepsize, "d_stepsize": d_stepsize, "g_step": g_step, "d_step": d_step, "max_kl": max_kl, "gamma": gamma, "lam": lam, } if semi_dataset is not None: log_dict["semi trajectories"] = semi_dataset.num_traj if hasattr(semi_dataset, 'info'): log_dict["semi_dataset_info"] = semi_dataset.info if expert_reward_threshold is not None: log_dict["expert reward threshold"] = expert_reward_threshold log_dict["sparse reward"] = sparse_reward # print them all together for csv logger.log(",".join([str(elem) for elem in log_dict])) logger.log(",".join([str(elem) for elem in log_dict.values()])) # also print them separately for easy reading: for elem in log_dict: logger.log(str(elem) + ": " + str(log_dict[elem])) # divide the timesteps to the threads timesteps_per_batch = -(-timesteps_per_batch // nworkers 
) # get ceil of division # Setup losses and stuff # ---------------------------------------- ob_space = OrderedDict([(label, env[label].observation_space) for label in env]) if semi_dataset and get_semi_prefix() in env: # semi ob space is different semi_obs_space = semi_ob_space(env[get_semi_prefix()], semi_size=semi_dataset.semi_size) ob_space[get_semi_prefix()] = semi_obs_space else: print("no semi dataset") # raise RuntimeError vf_stepsize = {label: vf_stepsize for label in env} ac_space = {label: env[label].action_space for label in ob_space} pi = { label: policy_func("pi", ob_space=ob_space[label], ac_space=ac_space[label], prefix=label) for label in ob_space } oldpi = { label: policy_func("oldpi", ob_space=ob_space[label], ac_space=ac_space[label], prefix=label) for label in ob_space } atarg = { label: tf.placeholder(dtype=tf.float32, shape=[None]) for label in ob_space } # Target advantage function (if applicable) ret = { label: tf.placeholder(dtype=tf.float32, shape=[None]) for label in ob_space } # Empirical return ob = { label: U.get_placeholder_cached(name=label + "ob") for label in ob_space } ac = { label: pi[label].pdtype.sample_placeholder([None]) for label in ob_space } kloldnew = {label: oldpi[label].pd.kl(pi[label].pd) for label in ob_space} ent = {label: pi[label].pd.entropy() for label in ob_space} meankl = {label: tf.reduce_mean(kloldnew[label]) for label in ob_space} meanent = {label: tf.reduce_mean(ent[label]) for label in ob_space} entbonus = {label: entcoeff * meanent[label] for label in ob_space} vferr = { label: tf.reduce_mean(tf.square(pi[label].vpred - ret[label])) for label in ob_space } ratio = { label: tf.exp(pi[label].pd.logp(ac[label]) - oldpi[label].pd.logp(ac[label])) for label in ob_space } # advantage * pnew / pold surrgain = { label: tf.reduce_mean(ratio[label] * atarg[label]) for label in ob_space } optimgain = { label: surrgain[label] + entbonus[label] for label in ob_space } losses = { label: [ optimgain[label], 
meankl[label], entbonus[label], surrgain[label], meanent[label] ] for label in ob_space } loss_names = { label: [ label + name for name in ["optimgain", "meankl", "entloss", "surrgain", "entropy"] ] for label in ob_space } vf_losses = {label: [vferr[label]] for label in ob_space} vf_loss_names = {label: [label + "vf_loss"] for label in ob_space} dist = {label: meankl[label] for label in ob_space} all_var_list = { label: pi[label].get_trainable_variables() for label in ob_space } var_list = { label: [ v for v in all_var_list[label] if "pol" in v.name or "logstd" in v.name ] for label in ob_space } vf_var_list = { label: [v for v in all_var_list[label] if "vf" in v.name] for label in ob_space } for label in ob_space: assert len(var_list[label]) == len(vf_var_list[label]) + 1 get_flat = {label: U.GetFlat(var_list[label]) for label in ob_space} set_from_flat = { label: U.SetFromFlat(var_list[label]) for label in ob_space } klgrads = { label: tf.gradients(dist[label], var_list[label]) for label in ob_space } flat_tangent = { label: tf.placeholder(dtype=tf.float32, shape=[None], name=label + "flat_tan") for label in ob_space } fvp = {} for label in ob_space: shapes = [var.get_shape().as_list() for var in var_list[label]] start = 0 tangents = [] for shape in shapes: sz = U.intprod(shape) tangents.append( tf.reshape(flat_tangent[label][start:start + sz], shape)) start += sz gvp = tf.add_n([ tf.reduce_sum(g * tangent) for (g, tangent) in zipsame(klgrads[label], tangents) ]) # pylint: disable=E1111 fvp[label] = U.flatgrad(gvp, var_list[label]) assign_old_eq_new = { label: U.function([], [], updates=[ tf.assign(oldv, newv) for (oldv, newv) in zipsame(oldpi[label].get_variables(), pi[label].get_variables()) ]) for label in ob_space } compute_losses = { label: U.function([ob[label], ac[label], atarg[label]], losses[label]) for label in ob_space } compute_vf_losses = { label: U.function([ob[label], ac[label], atarg[label], ret[label]], losses[label] + vf_losses[label]) for label 
in ob_space } compute_lossandgrad = { label: U.function([ob[label], ac[label], atarg[label]], losses[label] + [U.flatgrad(optimgain[label], var_list[label])]) for label in ob_space } compute_fvp = { label: U.function([flat_tangent[label], ob[label], ac[label], atarg[label]], fvp[label]) for label in ob_space } compute_vflossandgrad = { label: U.function([ob[label], ret[label]], vf_losses[label] + [U.flatgrad(vferr[label], vf_var_list[label])]) for label in ob_space } @contextmanager def timed(msg): if rank == 0: print(colorize(msg, color='magenta')) tstart = time.time() yield print( colorize("done in %.3f seconds" % (time.time() - tstart), color='magenta')) else: yield def allmean(x): assert isinstance(x, np.ndarray) out = np.empty_like(x) MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM) out /= nworkers return out episodes_so_far = {label: 0 for label in ob_space} timesteps_so_far = {label: 0 for label in ob_space} iters_so_far = 0 tstart = time.time() lenbuffer = {label: deque(maxlen=40) for label in ob_space} # rolling buffer for episode lengths rewbuffer = {label: deque(maxlen=40) for label in ob_space} # rolling buffer for episode rewards true_rewbuffer = {label: deque(maxlen=40) for label in ob_space} success_buffer = {label: deque(maxlen=40) for label in ob_space} # L2 only for semi network l2_rewbuffer = deque( maxlen=40) if semi_loss and semi_dataset is not None else None total_rewbuffer = deque( maxlen=40) if semi_loss and semi_dataset is not None else None not_update = 1 if not freeze_d else 0 # do not update G before D the first time loaded = False # if provide pretrained weight if not U.load_checkpoint_variables(pretrained_weight, include_no_prefix_vars=True): # if no general checkpoint available, check sub-checkpoints for both networks if U.load_checkpoint_variables(pretrained_il, prefix=get_il_prefix(), include_no_prefix_vars=False): if rank == 0: logger.log("loaded checkpoint variables from " + pretrained_il) loaded = True elif expert_label == 
get_il_prefix(): logger.log("ERROR no available cat_dauggi expert model in ", pretrained_il) exit(1) if U.load_checkpoint_variables(pretrained_semi, prefix=get_semi_prefix(), include_no_prefix_vars=False): if rank == 0: logger.log("loaded checkpoint variables from " + pretrained_semi) loaded = True elif expert_label == get_semi_prefix(): if rank == 0: logger.log("ERROR no available semi expert model in ", pretrained_semi) exit(1) else: loaded = True if rank == 0: logger.log("loaded checkpoint variables from " + pretrained_weight) if loaded: not_update = 0 if any( [x.op.name.find("adversary") != -1 for x in U.ALREADY_INITIALIZED]) else 1 if pretrained_weight and pretrained_weight.rfind("iter_") and \ pretrained_weight[pretrained_weight.rfind("iter_") + len("iter_"):].isdigit(): curr_iter = int( pretrained_weight[pretrained_weight.rfind("iter_") + len("iter_"):]) + 1 if rank == 0: print("loaded checkpoint at iteration: " + str(curr_iter)) iters_so_far = curr_iter for label in timesteps_so_far: timesteps_so_far[label] = iters_so_far * timesteps_per_batch d_adam = MpiAdam(reward_giver.get_trainable_variables()) vfadam = {label: MpiAdam(vf_var_list[label]) for label in ob_space} U.initialize() d_adam.sync() for label in ob_space: th_init = get_flat[label]() MPI.COMM_WORLD.Bcast(th_init, root=0) set_from_flat[label](th_init) vfadam[label].sync() if rank == 0: print(label + "Init param sum", th_init.sum(), flush=True) # Prepare for rollouts # ---------------------------------------- seg_gen = { label: traj_segment_generator( pi[label], env[label], reward_giver, timesteps_per_batch, stochastic=True, semi_dataset=semi_dataset if label == get_semi_prefix() else None, semi_loss=semi_loss, reward_threshold=expert_reward_threshold if label == expert_label else None, sparse_reward=sparse_reward if label == expert_label else False) for label in ob_space } g_losses = {} assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1 g_loss_stats = { label: 
stats(loss_names[label] + vf_loss_names[label]) for label in ob_space if label != expert_label } d_loss_stats = stats(reward_giver.loss_name) ep_names = ["True_rewards", "Rewards", "Episode_length", "Success"] ep_stats = {label: None for label in ob_space} # cat_dauggi network stats if get_il_prefix() in ep_stats: ep_stats[get_il_prefix()] = stats([name for name in ep_names]) # semi network stats if get_semi_prefix() in ep_stats: if semi_loss and semi_dataset is not None: ep_names.append("L2_loss") ep_names.append("total_rewards") ep_stats[get_semi_prefix()] = stats( [get_semi_prefix() + name for name in ep_names]) if rank == 0: start_time = time.time() ch_count = 0 env_type = env[expert_label].env.env.__class__.__name__ while True: if callback: callback(locals(), globals()) if max_timesteps and any( [timesteps_so_far[label] >= max_timesteps for label in ob_space]): break elif max_episodes and any( [episodes_so_far[label] >= max_episodes for label in ob_space]): break elif max_iters and iters_so_far >= max_iters: break # Save model if rank == 0 and iters_so_far % save_per_iter == 0 and ckpt_dir is not None: fname = os.path.join(ckpt_dir, task_name) if env_type.find("Pendulum") != -1 or save_per_iter != 1: fname = os.path.join(ckpt_dir, 'iter_' + str(iters_so_far), 'iter_' + str(iters_so_far)) os.makedirs(os.path.dirname(fname), exist_ok=True) saver = tf.train.Saver() saver.save(tf.get_default_session(), fname, write_meta_graph=False) if rank == 0 and time.time( ) - start_time >= 3600 * ch_count: # save a different checkpoint every hour fname = os.path.join(ckpt_dir, 'hour' + str(ch_count).zfill(3)) fname = os.path.join(fname, 'iter_' + str(iters_so_far)) os.makedirs(os.path.dirname(fname), exist_ok=True) saver = tf.train.Saver() saver.save(tf.get_default_session(), fname, write_meta_graph=False) ch_count += 1 logger.log("********** Iteration %i ************" % iters_so_far) def fisher_func_builder(label): def fisher_vector_product(p): return 
allmean(compute_fvp[label](p, * fvpargs)) + cg_damping * p return fisher_vector_product # ------------------ Update G ------------------ d = {label: None for label in ob_space} segs = {label: None for label in ob_space} logger.log("Optimizing Policy...") for curr_step in range(g_step): for label in ob_space: if curr_step and label == expert_label: # get expert trajectories only for one g_step which is same as d_step continue logger.log("Optimizing Policy " + label + "...") with timed("sampling"): segs[label] = seg = seg_gen[label].__next__() seg["rew"] = seg["rew"] - seg["l2_loss"] * l2_w add_vtarg_and_adv(seg, gamma, lam) ob, ac, atarg, tdlamret, full_ob = seg["ob"], seg["ac"], seg[ "adv"], seg["tdlamret"], seg["full_ob"] vpredbefore = seg[ "vpred"] # predicted value function before udpate atarg = (atarg - atarg.mean()) / atarg.std( ) # standardized advantage function estimate d[label] = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret), shuffle=True) if not_update or label == expert_label: continue # stop G from updating if hasattr(pi[label], "ob_rms"): pi[label].ob_rms.update( full_ob) # update running mean/std for policy args = seg["full_ob"], seg["ac"], atarg fvpargs = [arr[::5] for arr in args] assign_old_eq_new[label]( ) # set old parameter values to new parameter values with timed("computegrad"): *lossbefore, g = compute_lossandgrad[label](*args) lossbefore = allmean(np.array(lossbefore)) g = allmean(g) if np.allclose(g, 0): logger.log("Got zero gradient. 
not updating") else: with timed("cg"): stepdir = cg(fisher_func_builder(label), g, cg_iters=cg_iters, verbose=rank == 0) assert np.isfinite(stepdir).all() shs = .5 * stepdir.dot(fisher_func_builder(label)(stepdir)) lm = np.sqrt(shs / max_kl) # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g)) fullstep = stepdir / lm expectedimprove = g.dot(fullstep) surrbefore = lossbefore[0] stepsize = 1.0 thbefore = get_flat[label]() for _ in range(10): thnew = thbefore + fullstep * stepsize set_from_flat[label](thnew) meanlosses = surr, kl, *_ = allmean( np.array(compute_losses[label](*args))) if rank == 0: print("Generator entropy " + str(meanlosses[4]) + ", loss " + str(meanlosses[2])) improve = surr - surrbefore logger.log("Expected: %.3f Actual: %.3f" % (expectedimprove, improve)) if not np.isfinite(meanlosses).all(): logger.log( "Got non-finite value of losses -- bad!") elif kl > max_kl * 1.5: logger.log( "violated KL constraint. shrinking step.") elif improve < 0: logger.log( "surrogate didn't improve. 
shrinking step.") else: logger.log("Stepsize OK!") break stepsize *= .5 else: logger.log("couldn't compute a good step") set_from_flat[label](thbefore) if nworkers > 1 and iters_so_far % 20 == 0: paramsums = MPI.COMM_WORLD.allgather( (thnew.sum(), vfadam[label].getflat().sum())) # list of tuples assert all( np.allclose(ps, paramsums[0]) for ps in paramsums[1:]) expert_dataset = d[expert_label] if not_update: break for label in ob_space: if label == expert_label: continue with timed("vf"): logger.log(fmt_row(13, vf_loss_names[label])) for _ in range(vf_iters): vf_b_losses = [] for batch in d[label].iterate_once(vf_batchsize): mbob = batch["ob"] mbret = batch["vtarg"] *newlosses, g = compute_vflossandgrad[label](mbob, mbret) g = allmean(g) newlosses = allmean(np.array(newlosses)) vfadam[label].update(g, vf_stepsize[label]) vf_b_losses.append(newlosses) logger.log(fmt_row(13, np.mean(vf_b_losses, axis=0))) logger.log("Evaluating losses...") losses = [] for batch in d[label].iterate_once(vf_batchsize): newlosses = compute_vf_losses[label](batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"]) losses.append(newlosses) g_losses[label], _, _ = mpi_moments(losses, axis=0) ######################### for ob_batch, ac_batch, full_ob_batch in dataset.iterbatches( (segs[label]["ob"], segs[label]["ac"], segs[label]["full_ob"]), include_final_partial_batch=False, batch_size=len(ob)): expert_batch = expert_dataset.next_batch(len(ob)) ob_expert, ac_expert = expert_batch["ob"], expert_batch[ "ac"] exp_rew = 0 exp_rews = None for obs, acs in zip(ob_expert, ac_expert): curr_rew = reward_giver.get_reward(obs, acs)[0][0] \ if not hasattr(reward_giver, '_labels') else \ reward_giver.get_reward(obs, acs, label) if isinstance(curr_rew, tuple): curr_rew, curr_rews = curr_rew exp_rews = 1 - np.exp( -curr_rews ) if exp_rews is None else exp_rews + 1 - np.exp( -curr_rews) exp_rew += 1 - np.exp(-curr_rew) mean_exp_rew = exp_rew / len(ob_expert) mean_exp_rews = exp_rews / len( ob_expert) if 
exp_rews is not None else None gen_rew = 0 gen_rews = None for obs, acs, full_obs in zip(ob_batch, ac_batch, full_ob_batch): curr_rew = reward_giver.get_reward(obs, acs)[0][0] \ if not hasattr(reward_giver, '_labels') else \ reward_giver.get_reward(obs, acs, label) if isinstance(curr_rew, tuple): curr_rew, curr_rews = curr_rew gen_rews = 1 - np.exp( -curr_rews ) if gen_rews is None else gen_rews + 1 - np.exp( -curr_rews) gen_rew += 1 - np.exp(-curr_rew) mean_gen_rew = gen_rew / len(ob_batch) mean_gen_rews = gen_rews / len( ob_batch) if gen_rews is not None else None if rank == 0: msg = "Network " + label + \ " Generator step " + str(curr_step) + ": Dicriminator reward of expert traj " \ + str(mean_exp_rew) + " vs gen traj " + str(mean_gen_rew) if mean_exp_rews is not None and mean_gen_rews is not None: msg += "\nDiscriminator multi rewards of expert " + str(mean_exp_rews) + " vs gen " \ + str(mean_gen_rews) logger.log(msg) ######################### if not not_update: for label in g_losses: for (lossname, lossval) in zip(loss_names[label] + vf_loss_names[label], g_losses[label]): logger.record_tabular(lossname, lossval) logger.record_tabular( label + "ev_tdlam_before", explained_variance(segs[label]["vpred"], segs[label]["tdlamret"])) # ------------------ Update D ------------------ if not freeze_d: logger.log("Optimizing Discriminator...") batch_size = len(list(segs.values())[0]['ob']) // d_step expert_dataset = d[expert_label] batch_gen = { label: dataset.iterbatches( (segs[label]["ob"], segs[label]["ac"]), include_final_partial_batch=False, batch_size=batch_size) for label in segs if label != expert_label } d_losses = [ ] # list of tuples, each of which gives the loss for a minibatch for step in range(d_step): g_ob = {} g_ac = {} for label in batch_gen: # get batches for different gens g_ob[label], g_ac[label] = batch_gen[label].__next__() expert_batch = expert_dataset.next_batch(batch_size) ob_expert, ac_expert = expert_batch["ob"], expert_batch["ac"] for label 
in g_ob: ######################### exp_rew = 0 exp_rews = None for obs, acs in zip(ob_expert, ac_expert): curr_rew = reward_giver.get_reward(obs, acs)[0][0] \ if not hasattr(reward_giver, '_labels') else \ reward_giver.get_reward(obs, acs, label) if isinstance(curr_rew, tuple): curr_rew, curr_rews = curr_rew exp_rews = 1 - np.exp( -curr_rews ) if exp_rews is None else exp_rews + 1 - np.exp( -curr_rews) exp_rew += 1 - np.exp(-curr_rew) mean_exp_rew = exp_rew / len(ob_expert) mean_exp_rews = exp_rews / len( ob_expert) if exp_rews is not None else None gen_rew = 0 gen_rews = None for obs, acs in zip(g_ob[label], g_ac[label]): curr_rew = reward_giver.get_reward(obs, acs)[0][0] \ if not hasattr(reward_giver, '_labels') else \ reward_giver.get_reward(obs, acs, label) if isinstance(curr_rew, tuple): curr_rew, curr_rews = curr_rew gen_rews = 1 - np.exp( -curr_rews ) if gen_rews is None else gen_rews + 1 - np.exp( -curr_rews) gen_rew += 1 - np.exp(-curr_rew) mean_gen_rew = gen_rew / len(g_ob[label]) mean_gen_rews = gen_rews / len( g_ob[label]) if gen_rews is not None else None if rank == 0: msg = "Dicriminator reward of expert traj " + str(mean_exp_rew) + " vs " + label + \ "gen traj " + str(mean_gen_rew) if mean_exp_rews is not None and mean_gen_rews is not None: msg += "\nDiscriminator multi expert rewards " + str(mean_exp_rews) + " vs " + label + \ "gen " + str(mean_gen_rews) logger.log(msg) ######################### # update running mean/std for reward_giver if hasattr(reward_giver, "obs_rms"): reward_giver.obs_rms.update( np.concatenate(list(g_ob.values()) + [ob_expert], 0)) *newlosses, g = reward_giver.lossandgrad( *(list(g_ob.values()) + list(g_ac.values()) + [ob_expert] + [ac_expert])) d_adam.update(allmean(g), d_stepsize) d_losses.append(newlosses) logger.log(fmt_row(13, reward_giver.loss_name)) logger.log(fmt_row(13, np.mean(d_losses, axis=0))) for label in ob_space: lrlocal = (segs[label]["ep_lens"], segs[label]["ep_rets"], segs[label]["ep_true_rets"], 
segs[label]["ep_success"], segs[label]["ep_semi_loss"]) # local values listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples lens, rews, true_rets, success, semi_losses = map( flatten_lists, zip(*listoflrpairs)) # success success = [ float(elem) for elem in success if isinstance(elem, (int, float, bool)) ] # remove potential None types if not success: success = [-1] # set success to -1 if env has no success flag success_buffer[label].extend(success) true_rewbuffer[label].extend(true_rets) lenbuffer[label].extend(lens) rewbuffer[label].extend(rews) if semi_loss and semi_dataset is not None and label == get_semi_prefix( ): semi_losses = [elem * l2_w for elem in semi_losses] total_rewards = rews total_rewards = [ re_elem - l2_elem for re_elem, l2_elem in zip(total_rewards, semi_losses) ] l2_rewbuffer.extend(semi_losses) total_rewbuffer.extend(total_rewards) logger.record_tabular(label + "EpLenMean", np.mean(lenbuffer[label])) logger.record_tabular(label + "EpRewMean", np.mean(rewbuffer[label])) logger.record_tabular(label + "EpTrueRewMean", np.mean(true_rewbuffer[label])) logger.record_tabular(label + "EpSuccess", np.mean(success_buffer[label])) if semi_loss and semi_dataset is not None and label == get_semi_prefix( ): logger.record_tabular(label + "EpSemiLoss", np.mean(l2_rewbuffer)) logger.record_tabular(label + "EpTotalLoss", np.mean(total_rewbuffer)) logger.record_tabular(label + "EpThisIter", len(lens)) episodes_so_far[label] += len(lens) timesteps_so_far[label] += sum(lens) logger.record_tabular(label + "EpisodesSoFar", episodes_so_far[label]) logger.record_tabular(label + "TimestepsSoFar", timesteps_so_far[label]) logger.record_tabular("TimeElapsed", time.time() - tstart) iters_so_far += 1 logger.record_tabular("ItersSoFar", iters_so_far) if rank == 0: logger.dump_tabular() if not not_update: for label in g_loss_stats: g_loss_stats[label].add_all_summary( writer, g_losses[label], iters_so_far) if not freeze_d: d_loss_stats.add_all_summary(writer, 
np.mean(d_losses, axis=0), iters_so_far) for label in ob_space: # default buffers ep_buffers = [ np.mean(true_rewbuffer[label]), np.mean(rewbuffer[label]), np.mean(lenbuffer[label]), np.mean(success_buffer[label]) ] if semi_loss and semi_dataset is not None and label == get_semi_prefix( ): ep_buffers.append(np.mean(l2_rewbuffer)) ep_buffers.append(np.mean(total_rewbuffer)) ep_stats[label].add_all_summary(writer, ep_buffers, iters_so_far) if not_update and not freeze_g: not_update -= 1
def learn(
        env,
        policy_fn,
        *,
        timesteps_per_batch,  # what to train on
        epsilon,
        beta,
        cg_iters,
        gamma,
        lam,  # advantage estimation
        trial,
        method,
        entcoeff=0.0,
        cg_damping=1e-2,
        vf_stepsize=3e-4,
        vf_iters=3,
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,  # time constraint
        callback=None,
        TRPO=False):
    """Train a policy with COPOS, or with plain TRPO when ``TRPO=True``, over MPI.

    Each iteration:

    1. Samples a segment from ``traj_segment_generator``.
    2. Computes advantages with ``add_vtarg_and_adv``.
    3. Solves for a natural-gradient direction with conjugate gradient.
    4. Applies either the TRPO backtracking line search or the COPOS eta/omega
       update.
    5. Fits the value function with ``MpiAdam`` and logs episode statistics.

    Args:
        env: environment exposing ``observation_space`` / ``action_space``.
        policy_fn: factory called as ``policy_fn(name, ob_space, ac_space)``.
        timesteps_per_batch: steps requested from the segment generator per
            iteration.
        epsilon: KL bound. The TRPO line search rejects steps with
            ``kl > epsilon * 1.5``. It is also passed to
            ``EtaOmegaOptimizer`` and ``eta_search``.
        beta: passed to ``EtaOmegaOptimizer``.
        cg_iters: conjugate-gradient iterations.
        gamma, lam: passed to ``add_vtarg_and_adv``.
        trial, method: labels written to the tabular log.
        entcoeff: coefficient of the entropy bonus in the surrogate.
        cg_damping: damping added to the Fisher-vector product.
        vf_stepsize, vf_iters: value-function Adam step size and epochs.
        max_timesteps, max_episodes, max_iters: stopping criteria. Exactly one
            must be positive.
        callback: optional hook called as ``callback(locals(), globals())`` at
            the start of every iteration.
        TRPO: if true, use the TRPO line search instead of the COPOS update.
    """
    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space, ac_space)
    oldpi = policy_fn("oldpi", ob_space, ac_space)
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    old_entropy = oldpi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = entcoeff * meanent

    vferr = tf.reduce_mean(tf.square(pi.vpred - ret))

    # Log-likelihood surrogate. At the old parameters its gradient equals that
    # of the importance-ratio surrogate mean(pi / pi_old * A).
    surrgain = tf.reduce_mean(pi.pd.logp(ac) * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    # Keep only variables under a "pi*" scope, then split them by the second
    # scope component into policy ("pol*") and value-function ("vf*") parts.
    all_var_list = pi.get_trainable_variables()
    all_var_list = [
        v for v in all_var_list if v.name.split("/")[0].startswith("pi")
    ]
    var_list = [
        v for v in all_var_list if v.name.split("/")[1].startswith("pol")
    ]
    vf_var_list = [
        v for v in all_var_list if v.name.split("/")[1].startswith("vf")
    ]
    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32, shape=[None], name="flat_tan")
    # Cut the flat tangent vector back into pieces shaped like each variable.
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(klgrads, tangents)
    ])  # pylint: disable=E1111
    # Fisher-vector product: gradient of <grad KL, tangent>.
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function(
        [ob, ac, atarg], losses + [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret], U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        # Only rank 0 prints timing information.
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(colorize("done in %.3f seconds" % (time.time() - tstart),
                           color='magenta'))
        else:
            yield

    def allmean(x):
        # Average a numpy array across all MPI workers.
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()
    # Start every worker from rank 0's policy parameters.
    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Initialize eta, omega optimizer
    init_eta = 0.5
    init_omega = 2.0
    eta_omega_optimizer = EtaOmegaOptimizer(beta, epsilon, init_eta, init_omega)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, timesteps_per_batch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    while True:
        if callback:
            callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************" % iters_so_far)

        with timed("sampling"):
            seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()) / atarg.std()  # standardized advantage function estimate

        if hasattr(pi, "ret_rms"):
            pi.ret_rms.update(tdlamret)
        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        args = seg["ob"], seg["ac"], atarg
        # The Fisher-vector product is evaluated on every 5th sample only.
        fvpargs = [arr[::5] for arr in args]

        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        assign_old_eq_new()  # set old parameter values to new parameter values
        with timed("computegrad"):
            *lossbefore, g = compute_lossandgrad(*args)
        lossbefore = allmean(np.array(lossbefore))
        g = allmean(g)
        if np.allclose(g, 0):
            logger.log("Got zero gradient. not updating")
            # The parameters are unchanged, so the pre-update losses are the
            # current ones. Without this, the logging below would hit an
            # undefined or stale ``meanlosses``.
            meanlosses = lossbefore
        else:
            with timed("cg"):
                stepdir = cg(fisher_vector_product, g, cg_iters=cg_iters,
                             verbose=rank == 0)
            assert np.isfinite(stepdir).all()

            if TRPO:
                #
                # TRPO specific code.
                # Find correct step size using line search
                #
                shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
                lm = np.sqrt(shs / epsilon)
                fullstep = stepdir / lm
                expectedimprove = g.dot(fullstep)
                surrbefore = lossbefore[0]
                stepsize = 1.0
                thbefore = get_flat()
                for _ in range(10):
                    thnew = thbefore + fullstep * stepsize
                    set_from_flat(thnew)
                    meanlosses = surr, kl, *_ = allmean(
                        np.array(compute_losses(*args)))
                    improve = surr - surrbefore
                    logger.log("Expected: %.3f Actual: %.3f" %
                               (expectedimprove, improve))
                    if not np.isfinite(meanlosses).all():
                        logger.log("Got non-finite value of losses -- bad!")
                    elif kl > epsilon * 1.5:
                        logger.log("violated KL constraint. shrinking step.")
                    elif improve < 0:
                        logger.log("surrogate didn't improve. shrinking step.")
                    else:
                        logger.log("Stepsize OK!")
                        break
                    stepsize *= .5
                else:
                    logger.log("couldn't compute a good step")
                    set_from_flat(thbefore)
            else:
                #
                # COPOS specific implementation.
                #
                copos_update_dir = stepdir

                # Split direction into log-linear 'w_theta' and non-linear 'w_beta' parts
                w_theta, w_beta = pi.split_w(copos_update_dir)

                # q_beta(s,a) = \grad_beta \log \pi(a|s) * w_beta
                #             = features_beta(s) * K^T * Prec * a
                # q_beta = self.target.get_q_beta(features_beta, actions)

                Waa, Wsa = pi.w2W(w_theta)
                wa = pi.get_wa(ob, w_beta)

                varphis = pi.get_varphis(ob)

                # Optimize eta and omega
                tmp_ob = np.zeros(
                    (1, ) + env.observation_space.shape
                )  # We assume that entropy does not depend on the NN
                old_ent = old_entropy.eval({oldpi.ob: tmp_ob})[0]
                eta, omega = eta_omega_optimizer.optimize(
                    w_theta, Waa, Wsa, wa, varphis, pi.get_kt(),
                    pi.get_prec_matrix(), pi.is_new_policy_valid, old_ent)
                logger.log("Initial eta: " + str(eta) + " and omega: " + str(omega))

                current_theta_beta = get_flat()
                prev_theta, prev_beta = pi.all_to_theta_beta(current_theta_beta)

                for _ in range(2):
                    # Do a line search for both theta and beta parameters by adjusting only eta
                    eta = eta_search(w_theta, w_beta, eta, omega, allmean,
                                     compute_losses, get_flat, set_from_flat,
                                     pi, epsilon, args)
                    logger.log("Updated eta, eta: " + str(eta) +
                               " and omega: " + str(omega))

                    # Find proper omega for new eta. Use old policy parameters first.
                    set_from_flat(pi.theta_beta_to_all(prev_theta, prev_beta))
                    eta, omega = eta_omega_optimizer.optimize(
                        w_theta, Waa, Wsa, wa, varphis, pi.get_kt(),
                        pi.get_prec_matrix(), pi.is_new_policy_valid,
                        old_ent, eta)
                    logger.log("Updated omega, eta: " + str(eta) +
                               " and omega: " + str(omega))

                # Use final policy
                logger.log("Final eta: " + str(eta) + " and omega: " + str(omega))
                cur_theta = (eta * prev_theta + w_theta.reshape(-1, )) / (eta + omega)
                cur_beta = prev_beta + w_beta.reshape(-1, ) / eta
                set_from_flat(pi.theta_beta_to_all(cur_theta, cur_beta))

                meanlosses = surr, kl, *_ = allmean(
                    np.array(compute_losses(*args)))

            if nworkers > 1 and iters_so_far % 20 == 0:
                # Check that the workers' parameters are still identical. Use the
                # parameters actually in place. ``thnew`` exists only on the TRPO
                # path and is stale after a failed line search.
                paramsums = MPI.COMM_WORLD.allgather(
                    (get_flat().sum(), vfadam.getflat().sum()))  # list of tuples
                assert all(
                    np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)

        with timed("vf"):
            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches(
                        (seg["ob"], seg["tdlamret"]),
                        include_final_partial_batch=False,
                        batch_size=64):
                    g = allmean(compute_vflossandgrad(mbob, mbret))
                    vfadam.update(g, vf_stepsize)

        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        logger.record_tabular("Name", method)
        logger.record_tabular("Iteration", iters_so_far)
        logger.record_tabular("trial", trial)

        if rank == 0:
            logger.dump_tabular()
def learn(env, last_ob, last_jpos, run_reach, policy_func, reward_giver, expert_dataset, rank,
          pretrained, pretrained_weight, *,
          g_step, d_step, entcoeff, save_per_iter,
          ckpt_dir, log_dir, timesteps_per_batch, task_name,
          gamma, lam,
          max_kl, cg_iters, cg_damping=1e-2,
          vf_stepsize=3e-4, d_stepsize=3e-4, vf_iters=3,
          max_timesteps=0, max_episodes=0, max_iters=0,
          callback=None
          ):
    """Train the "pi_grasp" policy with GAIL: a TRPO generator plus a discriminator, over MPI.

    Each outer iteration:

    1. Optionally checkpoints the default session (rank 0 only).
    2. Runs ``g_step`` TRPO policy updates. Each update uses a freshly sampled
       segment and is followed by value-function fitting.
    3. Updates the discriminator on minibatches of the last segment against
       expert batches.
    4. Logs episode statistics.

    Args:
        env: environment exposing ``observation_space`` / ``action_space``.
        last_ob, last_jpos, run_reach: forwarded to ``traj_segment_generator``.
        policy_func: policy factory. It builds ``pi``/``oldpi`` and is also
            forwarded to ``traj_segment_generator``.
        reward_giver: discriminator. It must provide
            ``get_trainable_variables()``, ``lossandgrad(...)`` and
            ``loss_name``, and optionally ``obs_rms``.
        expert_dataset: object whose ``get_next_batch(n)`` returns
            ``(obs, acs)``.
        rank: ignored; it is overwritten with the MPI rank.
        pretrained, log_dir: not used by this function.
        pretrained_weight: if not None, ``pi`` is built with ``reuse=True`` and
            its variables are loaded from this path via ``U.load_state``.
        g_step: policy updates per iteration.
        d_step: number of discriminator minibatches per iteration. The batch
            size is ``len(ob) // d_step``.
        entcoeff: coefficient of the entropy bonus in the surrogate.
        save_per_iter, ckpt_dir, task_name: every ``save_per_iter`` iterations,
            rank 0 saves to ``ckpt_dir/<task_name>_<iteration>`` (skipped when
            ``ckpt_dir`` is None).
        timesteps_per_batch: forwarded to ``traj_segment_generator``.
        gamma, lam: passed to ``add_vtarg_and_adv``.
        max_kl, cg_iters, cg_damping: TRPO trust-region settings.
        vf_stepsize, d_stepsize, vf_iters: Adam step sizes and value epochs.
        max_timesteps, max_episodes, max_iters: stopping criteria. Exactly one
            must be positive.
        callback: optional hook called as ``callback(locals(), globals())``.
    """
    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi_grasp", ob_space, ac_space,
                     reuse=(pretrained_weight is not None))
    oldpi = policy_func("oldpi", ob_space, ac_space)
    atarg = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None])  # Target advantage function (if applicable)
    ret = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = entcoeff * meanent

    vferr = tf.reduce_mean(tf.square(pi.vpred - ret))

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    all_var_list = pi.get_trainable_variables()
    var_list = [v for v in all_var_list
                if v.name.startswith(("pi_grasp/pol", "pi_grasp/logstd"))]
    vf_var_list = [v for v in all_var_list if v.name.startswith("pi_grasp/vff")]
    assert len(var_list) == len(vf_var_list) + 1
    d_adam = MpiAdam(reward_giver.get_trainable_variables())
    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None], name="flat_tan")
    # Cut the flat tangent vector back into pieces shaped like each variable.
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    gvp = tf.add_n([tf.reduce_sum(g * tangent)
                    for (g, tangent) in zipsame(klgrads, tangents)])  # pylint: disable=E1111
    # Fisher-vector product: gradient of <grad KL, tangent>.
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[tf.compat.v1.assign(oldv, newv)
                 for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())])
    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses + [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret], U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        # Only rank 0 prints timing information.
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(colorize("done in %.3f seconds" % (time.time() - tstart), color='magenta'))
        else:
            yield

    def allmean(x):
        # Average a numpy array across all MPI workers.
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()
    # Start every worker from rank 0's policy parameters.
    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    d_adam.sync()
    vfadam.sync()
    if rank == 0:
        print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, last_ob, last_jpos, run_reach, policy_func, env,
                                     reward_giver, timesteps_per_batch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards
    true_rewbuffer = deque(maxlen=40)

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    g_loss_stats = stats(loss_names)
    d_loss_stats = stats(reward_giver.loss_name)
    ep_stats = stats(["True_rewards", "Rewards", "Episode_length"])
    # if provide pretrained weight
    if pretrained_weight is not None:
        U.load_state(pretrained_weight, var_list=pi.get_variables())

    # Built once, on first use. Constructing a Saver adds save/restore ops to
    # the graph, so building one per checkpoint would keep growing the graph.
    saver = None

    while True:
        if callback:
            callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break

        # Save model
        if rank == 0 and iters_so_far % save_per_iter == 0 and ckpt_dir is not None:
            t_name = task_name + "_" + str(iters_so_far)
            fname = os.path.join(ckpt_dir, t_name)
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            if saver is None:
                saver = tf.compat.v1.train.Saver()
            saver.save(tf.compat.v1.get_default_session(), fname)

        logger.log("********** Iteration %i ************" % iters_so_far)

        # ------------------ Update G ------------------
        logger.log("Optimizing Policy...")
        for _ in range(g_step):
            with timed("sampling"):
                seg = seg_gen.__next__()
            add_vtarg_and_adv(seg, gamma, lam)
            ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
            vpredbefore = seg["vpred"]  # predicted value function before update
            atarg = (atarg - atarg.mean()) / atarg.std()  # standardized advantage function estimate

            if hasattr(pi, "ob_rms"):
                pi.ob_rms.update(ob)  # update running mean/std for policy

            args = seg["ob"], seg["ac"], atarg
            # The Fisher-vector product is evaluated on every 5th sample only.
            fvpargs = [arr[::5] for arr in args]

            def fisher_vector_product(p):
                return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

            assign_old_eq_new()  # set old parameter values to new parameter values
            with timed("computegrad"):
                *lossbefore, g = compute_lossandgrad(*args)
            lossbefore = allmean(np.array(lossbefore))
            g = allmean(g)
            if np.allclose(g, 0):
                logger.log("Got zero gradient. not updating")
                # The parameters are unchanged, so the pre-update losses are the
                # current ones. Without this, the logging below would hit an
                # undefined or stale ``meanlosses``.
                meanlosses = lossbefore
            else:
                with timed("cg"):
                    stepdir = cg(fisher_vector_product, g, cg_iters=cg_iters, verbose=rank == 0)
                assert np.isfinite(stepdir).all()
                shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
                lm = np.sqrt(shs / max_kl)
                fullstep = stepdir / lm
                expectedimprove = g.dot(fullstep)
                surrbefore = lossbefore[0]
                stepsize = 1.0
                thbefore = get_flat()
                for _ in range(10):
                    thnew = thbefore + fullstep * stepsize
                    set_from_flat(thnew)
                    meanlosses = surr, kl, *_ = allmean(np.array(compute_losses(*args)))
                    improve = surr - surrbefore
                    logger.log("Expected: %.3f Actual: %.3f" % (expectedimprove, improve))
                    if not np.isfinite(meanlosses).all():
                        logger.log("Got non-finite value of losses -- bad!")
                    elif kl > max_kl * 1.5:
                        logger.log("violated KL constraint. shrinking step.")
                    elif improve < 0:
                        logger.log("surrogate didn't improve. shrinking step.")
                    else:
                        logger.log("Stepsize OK!")
                        break
                    stepsize *= .5
                else:
                    logger.log("couldn't compute a good step")
                    set_from_flat(thbefore)
                if nworkers > 1 and iters_so_far % 20 == 0:
                    # Check that the workers' parameters are still identical,
                    # using the parameters actually in place (``thnew`` is stale
                    # after a failed line search).
                    paramsums = MPI.COMM_WORLD.allgather(
                        (get_flat().sum(), vfadam.getflat().sum()))  # list of tuples
                    assert all(np.allclose(ps, paramsums[0]) for ps in paramsums[1:])
            with timed("vf"):
                for _ in range(vf_iters):
                    for (mbob, mbret) in dataset.iterbatches((seg["ob"], seg["tdlamret"]),
                                                             include_final_partial_batch=False,
                                                             batch_size=128):
                        if hasattr(pi, "ob_rms"):
                            pi.ob_rms.update(mbob)  # update running mean/std for policy
                        g = allmean(compute_vflossandgrad(mbob, mbret))
                        vfadam.update(g, vf_stepsize)

        g_losses = meanlosses
        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)
        logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))

        # ------------------ Update D ------------------
        logger.log("Optimizing Discriminator...")
        logger.log(fmt_row(13, reward_giver.loss_name))
        # Minibatches come from the segment of the last g-step.
        batch_size = len(ob) // d_step
        d_losses = []  # list of tuples, each of which gives the loss for a minibatch
        for ob_batch, ac_batch in dataset.iterbatches((ob, ac),
                                                      include_final_partial_batch=False,
                                                      batch_size=batch_size):
            # Pair every generator minibatch with a same-sized expert batch.
            ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob_batch))
            # update running mean/std for reward_giver
            if hasattr(reward_giver, "obs_rms"):
                reward_giver.obs_rms.update(np.concatenate((ob_batch, ob_expert), 0))
            *newlosses, g = reward_giver.lossandgrad(ob_batch, ac_batch, ob_expert, ac_expert)
            d_adam.update(allmean(g), d_stepsize)
            d_losses.append(newlosses)
        logger.log(fmt_row(13, np.mean(d_losses, axis=0)))

        lrlocal = (seg["ep_lens"], seg["ep_rets"], seg["ep_true_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews, true_rets = map(flatten_lists, zip(*listoflrpairs))
        true_rewbuffer.extend(true_rets)
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        # np.mean of an empty buffer is NaN, so these entries stay NaN until
        # at least one episode has been reported in seg["ep_*"].
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpTrueRewMean", np.mean(true_rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        # Count sampled steps from the segment rather than from finished
        # episodes, so progress is tracked even when no episode ends.
        timesteps_so_far += seg["steps"]
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if rank == 0:
            logger.dump_tabular()
def learn(env, policy_fn, *,
          batch_size,  # what to train on
          task_horizon,
          max_kl, cg_iters,
          gamma, lam,  # advantage estimation
          entcoeff=0.0,
          cg_damping=1e-2,
          vf_stepsize=3e-4,
          vf_iters=3,
          max_timesteps=0, max_episodes=0, max_iters=0,  # time constraint
          callback=None,
          weights_dir='.',
          per_decision=True,
          normalize=False,
          truncate_at=np.inf
          ):
    """Run TRPO policy optimisation and dump per-iteration diagnostics.

    Each iteration:
      1. samples ``batch_size * task_horizon`` timesteps with
         ``traj_segment_generator``;
      2. saves the current policy parameters to
         ``<weights_dir>/weights_<iter>.npy``;
      3. takes a natural-gradient step (conjugate gradient + backtracking line
         search constrained by ``max_kl``) on the policy variables;
      4. fits the value function with MpiAdam for ``vf_iters`` epochs;
      5. saves simple importance weights (``iws_<iter>.npy``) and discounted
         episode returns (``rets_<iter>.npy``) to ``weights_dir``, and logs
         ``J_hat`` / ``Var_J`` / ``SampleRenyi2`` among other statistics.

    Parameters
    ----------
    env : environment providing ``observation_space`` / ``action_space``.
    policy_fn : callable ``(name, ob_space, ac_space) -> policy``.
    batch_size : number of episodes per batch (multiplied by ``task_horizon``
        to get the timesteps per batch).
    task_horizon : episode horizon, forwarded to ``traj_segment_generator``.
    max_kl, cg_iters, cg_damping : TRPO step constraints / CG settings.
    gamma, lam : discount and GAE parameters for ``add_vtarg_and_adv``; gamma
        is also used to discount the saved per-episode returns.
    entcoeff : entropy bonus coefficient in the surrogate objective.
    vf_stepsize, vf_iters : Adam step size / epochs for the value function.
    max_timesteps, max_episodes, max_iters : stopping criteria; exactly one
        must be positive (asserted).
    callback : optional ``callback(locals(), globals())`` run each iteration.
    weights_dir : directory where the ``.npy`` diagnostics are written.
    per_decision, normalize, truncate_at : forwarded unchanged to
        ``pi.eval_J`` / ``pi.eval_var_J``.
    """
    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)

    timesteps_per_batch = batch_size * task_horizon

    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space, ac_space)
    oldpi = policy_fn("oldpi", ob_space, ac_space)
    atarg = tf.placeholder(dtype=tf.float32, shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = entcoeff * meanent

    vferr = tf.reduce_mean(tf.square(pi.vpred - ret))

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    all_var_list = pi.get_trainable_variables()
    var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("pol")]
    vf_var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("vf")]
    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32, shape=[None], name="flat_tan")
    # Split the flat tangent vector back into per-variable shaped pieces.
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    gvp = tf.add_n([tf.reduce_sum(g * tangent) for (g, tangent) in zipsame(klgrads, tangents)])  # pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function([], [], updates=[tf.assign(oldv, newv)
        for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())])
    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses + [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret], U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        # Only rank 0 prints timing information.
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(colorize("done in %.3f seconds" % (time.time() - tstart), color='magenta'))
        else:
            yield

    def allmean(x):
        # Average a numpy array over all MPI workers.
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()
    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, task_horizon, timesteps_per_batch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    while True:
        if callback:
            callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************" % iters_so_far)

        with timed("sampling"):
            seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        #Params
        #"""
        params = pi.eval_param()
        #print(params)
        np.save(weights_dir + '/weights_' + str(iters_so_far), params)
        #"""

        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        # Standardized advantage estimate; the epsilon avoids 0/0 -> NaN when
        # every advantage in the batch is identical.
        atarg = (atarg - atarg.mean()) / (atarg.std() + 1e-8)

        if hasattr(pi, "ret_rms"):
            pi.ret_rms.update(tdlamret)
        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        args = seg["ob"], seg["ac"], atarg
        # Every 5th sample is used for the Fisher-vector products to save time.
        fvpargs = [arr[::5] for arr in args]

        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        assign_old_eq_new()  # set old parameter values to new parameter values
        with timed("computegrad"):
            *lossbefore, g = compute_lossandgrad(*args)
        lossbefore = allmean(np.array(lossbefore))
        g = allmean(g)
        if np.allclose(g, 0):
            logger.log("Got zero gradient. not updating")
            # Parameters are unchanged, so the current losses equal the
            # pre-update ones. Binding these keeps the logging below from
            # raising NameError (first iteration) or logging stale values.
            meanlosses = lossbefore
            stepsize = 0.0
        else:
            with timed("cg"):
                stepdir = cg(fisher_vector_product, g, cg_iters=cg_iters, verbose=rank == 0)
            assert np.isfinite(stepdir).all()
            shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
            lm = np.sqrt(shs / max_kl)
            print('DOT: %s' % np.dot(stepdir, g))
            # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
            fullstep = stepdir / lm
            expectedimprove = g.dot(fullstep)
            surrbefore = lossbefore[0]
            stepsize = 1.0
            thbefore = get_flat()
            # Backtracking line search: halve the step up to 10 times.
            for _ in range(10):
                thnew = thbefore + fullstep * stepsize
                set_from_flat(thnew)
                meanlosses = surr, kl, *_ = allmean(np.array(compute_losses(*args)))
                improve = surr - surrbefore
                logger.log("Expected: %.3f Actual: %.3f" % (expectedimprove, improve))
                if not np.isfinite(meanlosses).all():
                    logger.log("Got non-finite value of losses -- bad!")
                elif kl > max_kl * 1.5:
                    logger.log("violated KL constraint. shrinking step.")
                elif improve < 0:
                    logger.log("surrogate didn't improve. shrinking step.")
                else:
                    logger.log("Stepsize OK!")
                    break
                stepsize *= .5
            else:
                logger.log("couldn't compute a good step")
                set_from_flat(thbefore)
            if nworkers > 1 and iters_so_far % 20 == 0:
                # Sanity check that all workers hold identical parameters.
                paramsums = MPI.COMM_WORLD.allgather((thnew.sum(), vfadam.getflat().sum()))  # list of tuples
                assert all(np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)

        with timed("vf"):
            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches((seg["ob"], seg["tdlamret"]),
                                                         include_final_partial_batch=False, batch_size=64):
                    g = allmean(compute_vflossandgrad(mbob, mbret))
                    vfadam.update(g, vf_stepsize)

        logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"], seg["ob"], seg["ac"], seg["rew"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews, states, actions, rewards = map(flatten_lists, zip(*listoflrpairs))

        # Discounted return of each episode: sum_t gamma^t * r_t.
        disc_rews = []
        start = 0
        for ep_len in lens:
            end = start + ep_len
            disc = gamma + np.zeros(ep_len)
            disc[0] = 1
            disc = np.cumprod(disc)
            disc_rewards = np.array(rewards[start:end]) * disc
            disc_rews.append(np.sum(disc_rewards))
            start = end

        #Save importance weights
        simple_iw = pi.eval_simple_iw(states, actions, rewards, lens, gamma=gamma, behavioral=oldpi)
        np.save(weights_dir + '/iws_' + str(iters_so_far), simple_iw)
        #print(len(simple_iw), simple_iw)

        #Save returns
        ep_rets = np.array(disc_rews)
        np.save(weights_dir + '/rets_' + str(iters_so_far), ep_rets)
        #print(len(ep_rets), ep_rets)

        #lenbuffer.extend(lens)
        #rewbuffer.extend(rews)

        # The sections below are toggled on/off by the author: a leading
        # `"""` disables a section (it becomes a string expression closed by
        # the trailing `#"""`), a leading `#"""` enables it.

        #Renyi
        """
        renyi_4 = np.mean(pi.eval_renyi(states, oldpi, 4))
        #print('Renyi:', renyi)
        #"""

        #Importance weights stats
        """
        avg_iw, var_iw, max_iw, ess = pi.eval_iw_stats(states, actions, rewards, lens, gamma=gamma, behavioral=oldpi,
                                   per_decision=per_decision, normalize=normalize, truncate_at=truncate_at)
        #"""

        #Returns stats
        """
        avg_ret, var_ret, max_ret = pi.eval_ret_stats(states, actions, rewards, lens, gamma=gamma, behavioral=oldpi,
                                   per_decision=per_decision, normalize=normalize, truncate_at=truncate_at)
        #"""

        #Performance
        #"""
        bound_delta = .2
        batch_size = len(lens)
        J = pi.eval_J(states, actions, rewards, lens, gamma=gamma, behavioral=oldpi,
                      per_decision=per_decision, normalize=normalize, truncate_at=truncate_at)
        var_J = pi.eval_var_J(states, actions, rewards, lens, gamma=gamma, behavioral=oldpi,
                              per_decision=per_decision, normalize=normalize, truncate_at=truncate_at)
        """
        bound = pi.eval_bound(states, actions, rewards, lens, gamma=gamma, behavioral=oldpi, per_decision=per_decision, normalize=normalize, truncate_at=truncate_at, delta=bound_delta, use_ess=True)
        #"""

        #Sample Renyi
        d2s = pi.eval_renyi(states, oldpi, 2)
        # Sum the per-step values over each episode. Previously the loop
        # rebound `d2s_by_episode` instead of appending, so only the last
        # episode reached the mean below.
        d2s_by_episode = []
        start = 0
        for ep_len in lens:
            end = start + ep_len
            d2s_by_episode.append(np.sum(d2s[start:end]))
            start = end
        sample_d2 = np.mean(np.exp(d2s_by_episode))

        """
        grad_bound = pi.eval_grad_bound(states, actions, rewards, lens, gamma=gamma, behavioral=oldpi, per_decision=per_decision, normalize=normalize, truncate_at=truncate_at, delta=bound_delta, use_ess=True)
        print(grad_bound)
        #print('Target performance', J, '+-', np.sqrt(var_J/len(lens)))
        #"""

        #Gradients
        """
        grad_J = pi.eval_grad_J(states, actions, rewards, lens, behavioral=oldpi, per_decision=True)
        grad_var_J = pi.eval_grad_var_J(states, actions, rewards, lens, behavioral=oldpi, per_decision=True)
        print('Target performance grads', grad_J, grad_var_J)
        #"""

        #Student-t bound
        """
        bound = pi.eval_bound(states, actions, rewards, lens, behavioral=oldpi, per_decision=True)
        #print('Bound comp. time', time.time() - checkpoint)
        print("StudentTBound", bound)
        #"""

        #Student-t bound grad
        """
        bound_grad = pi.eval_bound_grad(states, actions, rewards, lens, behavioral=oldpi, per_decision=True)
        print("StudentTBound grad", bound_grad)
        #"""

        #Fisher
        """
        checkpoint = time.time()
        fisher = oldpi.eval_fisher(states, actions, lens, behavioral=None)
        #print(fisher)
        assert np.array_equal(fisher, fisher.T)
        print('Fisher comp. time', time.time() - checkpoint)
        checkpoint = time.time()
        natural = np.linalg.solve(fisher + 1e-12*np.eye(fisher.shape[0]), grad_J)
        print(natural)
        #print('Fisher vector product time:', time.time() - checkpoint)
        #"""

        #Logging
        logger.record_tabular("Step_size", stepsize)
        #logger.record_tabular("Our_bound", bound)
        #logger.record_tabular("Reny_4", renyi_4)
        logger.record_tabular("SampleRenyi2", sample_d2)
        #logger.record_tabular("Max_iw", max_iw)
        #logger.record_tabular("Ess", ess)
        #logger.record_tabular("Avg_iw", avg_iw)
        #logger.record_tabular("Var_iw", var_iw)
        #logger.record_tabular("Max_ret", max_ret)
        #logger.record_tabular("Avg_ret", avg_ret)
        #logger.record_tabular("Var_ret", var_ret)
        logger.record_tabular("EpLenMean", np.mean(lens))
        logger.record_tabular("DiscEpRewMean", np.mean(disc_rews))
        logger.record_tabular("EpRewMean", np.mean(rews))
        logger.record_tabular("EpThisIter", len(lens))
        logger.record_tabular("J_hat", J)
        logger.record_tabular("Var_J", var_J)
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if rank == 0:
            logger.dump_tabular()
def learn(*,
        network,
        env,
        total_timesteps,
        timesteps_per_batch=1024,  # what to train on
        max_kl=0.001,
        cg_iters=10,
        gamma=0.99,
        lam=1.0,  # advantage estimation
        seed=None,
        ent_coef=0.0,
        cg_damping=1e-2,
        vf_stepsize=3e-4,
        vf_iters=3,
        max_episodes=0, max_iters=0,  # time constraint
        callback=None,
        load_path=None,
        **network_kwargs
        ):
    '''
    learn a policy function with TRPO algorithm

    Parameters:
    ----------

    network                 neural network to learn. Can be either string ('mlp', 'cnn', 'lstm', 'lnlstm' for basic types)
                            or function that takes input placeholder and returns tuple (output, None) for feedforward nets
                            or (output, (state_placeholder, state_output, mask_placeholder)) for recurrent nets

    env                     environment (one of the gym environments or wrapped via baselines.common.vec_env.VecEnv-type class

    timesteps_per_batch     timesteps per gradient estimation batch

    max_kl                  max KL divergence between old policy and new policy ( KL(pi_old || pi) )

    ent_coef                coefficient of policy entropy term in the optimization objective

    cg_iters                number of iterations of conjugate gradient algorithm

    cg_damping              conjugate gradient damping

    vf_stepsize             learning rate for adam optimizer used to optimize value function loss

    vf_iters                number of iterations of value function optimization iterations per each policy optimization step

    total_timesteps           max number of timesteps

    max_episodes            max number of episodes

    max_iters               maximum number of policy optimization iterations

    callback                function to be called with (locals(), globals()) each policy optimization step

    load_path               str, path to load the model from (default: None, i.e. no model is loaded)

    **network_kwargs        keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and arguments to a particular type of network

    Returns:
    -------

    learnt model

    Notes:
    -----
    Works without MPI (``MPI is None``): averaging then degenerates to a copy
    and the cross-worker parameter check is skipped. If none of
    total_timesteps / max_episodes / max_iters is positive the untrained
    policy is returned immediately; at most one of them may be positive.
    '''

    if MPI is not None:
        nworkers = MPI.COMM_WORLD.Get_size()
        rank = MPI.COMM_WORLD.Get_rank()
    else:
        nworkers = 1
        rank = 0

    cpus_per_worker = 1
    U.get_session(config=tf.ConfigProto(
            allow_soft_placement=True,
            inter_op_parallelism_threads=cpus_per_worker,
            intra_op_parallelism_threads=cpus_per_worker
    ))

    policy = build_policy(env, network, value_network='copy', **network_kwargs)
    set_global_seeds(seed)

    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space

    ob = observation_placeholder(ob_space)
    with tf.variable_scope("pi"):
        pi = policy(observ_placeholder=ob)
    with tf.variable_scope("oldpi"):
        oldpi = policy(observ_placeholder=ob)

    atarg = tf.placeholder(dtype=tf.float32, shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = ent_coef * meanent

    vferr = tf.reduce_mean(tf.square(pi.vf - ret))

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    all_var_list = get_trainable_variables("pi")
    # var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("pol")]
    # vf_var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("vf")]
    var_list = get_pi_trainable_variables("pi")
    vf_var_list = get_vf_trainable_variables("pi")

    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32, shape=[None], name="flat_tan")
    # Split the flat tangent vector back into per-variable shaped pieces.
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    gvp = tf.add_n([tf.reduce_sum(g * tangent) for (g, tangent) in zipsame(klgrads, tangents)])  # pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function([], [], updates=[tf.assign(oldv, newv)
        for (oldv, newv) in zipsame(get_variables("oldpi"), get_variables("pi"))])

    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses + [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret], U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        # Only rank 0 prints timing information.
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(colorize("done in %.3f seconds" % (time.time() - tstart), color='magenta'))
        else:
            yield

    def allmean(x):
        # Average a numpy array over MPI workers; plain copy without MPI.
        assert isinstance(x, np.ndarray)
        if MPI is not None:
            out = np.empty_like(x)
            MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
            out /= nworkers
        else:
            out = np.copy(x)

        return out

    U.initialize()
    if load_path is not None:
        pi.load(load_path)

    th_init = get_flat()
    if MPI is not None:
        MPI.COMM_WORLD.Bcast(th_init, root=0)

    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, timesteps_per_batch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    if sum([max_iters > 0, total_timesteps > 0, max_episodes > 0]) == 0:
        # nothing to be done
        return pi

    assert sum([max_iters > 0, total_timesteps > 0, max_episodes > 0]) < 2, \
        'out of max_iters, total_timesteps, and max_episodes only one should be specified'

    while True:
        if callback:
            callback(locals(), globals())
        if total_timesteps and timesteps_so_far >= total_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************" % iters_so_far)

        with timed("sampling"):
            seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        # Standardized advantage estimate; the epsilon avoids 0/0 -> NaN when
        # every advantage in the batch is identical.
        atarg = (atarg - atarg.mean()) / (atarg.std() + 1e-8)

        if hasattr(pi, "ret_rms"):
            pi.ret_rms.update(tdlamret)
        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        args = seg["ob"], seg["ac"], atarg
        # Every 5th sample is used for the Fisher-vector products to save time.
        fvpargs = [arr[::5] for arr in args]

        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        assign_old_eq_new()  # set old parameter values to new parameter values
        with timed("computegrad"):
            *lossbefore, g = compute_lossandgrad(*args)
        lossbefore = allmean(np.array(lossbefore))
        g = allmean(g)
        if np.allclose(g, 0):
            logger.log("Got zero gradient. not updating")
            # Parameters are unchanged, so the current losses equal the
            # pre-update ones. Without this, the loss logging below raised
            # NameError on the first iteration or logged stale values.
            meanlosses = lossbefore
        else:
            with timed("cg"):
                stepdir = cg(fisher_vector_product, g, cg_iters=cg_iters, verbose=rank == 0)
            assert np.isfinite(stepdir).all()
            shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
            lm = np.sqrt(shs / max_kl)
            # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
            fullstep = stepdir / lm
            expectedimprove = g.dot(fullstep)
            surrbefore = lossbefore[0]
            stepsize = 1.0
            thbefore = get_flat()
            # Backtracking line search: halve the step up to 10 times.
            for _ in range(10):
                thnew = thbefore + fullstep * stepsize
                set_from_flat(thnew)
                meanlosses = surr, kl, *_ = allmean(np.array(compute_losses(*args)))
                improve = surr - surrbefore
                logger.log("Expected: %.3f Actual: %.3f" % (expectedimprove, improve))
                if not np.isfinite(meanlosses).all():
                    logger.log("Got non-finite value of losses -- bad!")
                elif kl > max_kl * 1.5:
                    logger.log("violated KL constraint. shrinking step.")
                elif improve < 0:
                    logger.log("surrogate didn't improve. shrinking step.")
                else:
                    logger.log("Stepsize OK!")
                    break
                stepsize *= .5
            else:
                logger.log("couldn't compute a good step")
                set_from_flat(thbefore)
            if nworkers > 1 and iters_so_far % 20 == 0:
                # Sanity check that all workers hold identical parameters.
                paramsums = MPI.COMM_WORLD.allgather((thnew.sum(), vfadam.getflat().sum()))  # list of tuples
                assert all(np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)

        with timed("vf"):
            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches((seg["ob"], seg["tdlamret"]),
                                                         include_final_partial_batch=False, batch_size=64):
                    g = allmean(compute_vflossandgrad(mbob, mbret))
                    vfadam.update(g, vf_stepsize)

        logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        if MPI is not None:
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        else:
            listoflrpairs = [lrlocal]

        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if rank == 0:
            logger.dump_tabular()

    return pi
def learn(env, policy_func, reward_giver, expert_dataset, rank,
          g_step, d_step, entcoeff, save_per_iter,
          timesteps_per_batch, ckpt_dir, log_dir, task_name,
          gamma, lam,
          max_kl, cg_iters, cg_damping=1e-2,
          vf_stepsize=3e-4, d_stepsize=3e-4, vf_iters=3,
          max_timesteps=0, max_episodes=0, max_iters=0,
          callback=None):
    """Train a policy with TRPO against a learned reward, alternating with discriminator updates.

    The ``pi`` variables are first restored from
    ``U_.getPath() + '/model/bc.ckpt'``. Each outer iteration then:

    1. checkpoints the session to ``os.path.join(ckpt_dir, task_name)``
       (rank 0 only, every ``save_per_iter`` iterations);
    2. runs ``g_step`` TRPO updates, each on a fresh segment sampled through
       ``traj_segment_generator(pi, env, reward_giver, ...)``, followed by
       ``vf_iters`` epochs of value-function Adam updates;
    3. runs discriminator (``reward_giver``) Adam updates on the last sampled
       batch against equally sized batches from ``expert_dataset``, split
       into about ``d_step`` minibatches.

    The ``rank`` argument is overwritten by ``MPI.COMM_WORLD.Get_rank()``, and
    ``log_dir`` is unused. Exactly one of max_timesteps / max_episodes /
    max_iters must be positive (asserted).

    NOTE(review): ``episodes_so_far`` is never incremented in this function,
    so the ``max_episodes`` stopping criterion can never trigger. Confirm
    whether the segment exposes ``ep_lens`` to count episodes.
    NOTE(review): ``saver.restore`` runs before ``U.initialize()``. If
    ``U.initialize`` initializes every variable it has not seen before (as
    baselines' tf_util does), the restored ``pi`` weights may be
    re-initialized. Verify against the ``U`` implementation in use.
    """
    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space, ac_space)
    saver = tf.train.Saver(
        var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='pi'))
    saver.restore(tf.get_default_session(), U_.getPath() + '/model/bc.ckpt')
    oldpi = policy_func("oldpi", ob_space, ac_space)
    atarg = tf.placeholder(
        dtype=tf.float32, shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = entcoeff * meanent

    vferr = tf.reduce_mean(tf.square(pi.vpred - ret))

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    all_var_list = pi.get_trainable_variables()
    var_list = [
        v for v in all_var_list
        if v.name.startswith("pi/pol") or v.name.startswith("pi/logstd")
    ]
    vf_var_list = [v for v in all_var_list if v.name.startswith("pi/vff")]
    assert len(var_list) == len(vf_var_list) + 1
    d_adam = MpiAdam(reward_giver.get_trainable_variables())
    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32, shape=[None], name="flat_tan")
    # Split the flat tangent vector back into per-variable shaped pieces.
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent) for (g, tangent) in zipsame(klgrads, tangents)
    ])  # pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg],
                                     losses + [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret], U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        # Only rank 0 prints timing information.
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='magenta'))
        else:
            yield

    def allmean(x):
        # Average a numpy array over all MPI workers.
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()
    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    d_adam.sync()
    vfadam.sync()
    if rank == 0:
        print("Init param sum", th_init.sum())

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, reward_giver, timesteps_per_batch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards
    true_rewbuffer = deque(maxlen=40)

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    g_loss_stats = stats(loss_names)
    d_loss_stats = stats(reward_giver.loss_name)
    ep_stats = stats(["True_rewards", "Rewards", "Episode_length"])
    # Saver for full-session checkpoints. Built once on first use: building a
    # tf.train.Saver adds save/restore ops to the graph, so creating one per
    # checkpoint (as before) grew the graph on every save.
    ckpt_saver = None
    # if provide pretrained weight
    while True:
        if callback:
            callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break

        # Save model
        if rank == 0 and iters_so_far % save_per_iter == 0 and ckpt_dir is not None:
            fname = os.path.join(ckpt_dir, task_name)
            print('save model as ', fname)
            try:
                os.makedirs(os.path.dirname(fname))
            except OSError:  # folder already exists
                pass
            if ckpt_saver is None:
                ckpt_saver = tf.train.Saver()
            ckpt_saver.save(tf.get_default_session(), fname)

        print("********** Iteration %i ************" % iters_so_far)

        # `fvpargs` is bound per segment inside the g_step loop below; the
        # closure looks it up at call time.
        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        # ------------------ Update G ------------------
        print("Optimizing Policy...")
        for _ in range(g_step):
            with timed("sampling"):
                # `next()` works on Python 3 generators; the Python-2 style
                # `seg_gen.next()` raised AttributeError.
                seg = next(seg_gen)
            add_vtarg_and_adv(seg, gamma, lam)
            # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
            ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
            vpredbefore = seg["vpred"]  # predicted value function before update
            # Standardized advantage estimate; the epsilon avoids 0/0 -> NaN
            # when every advantage in the batch is identical.
            atarg = (atarg - atarg.mean()) / (atarg.std() + 1e-8)

            if hasattr(pi, "ob_rms"):
                pi.ob_rms.update(ob)  # update running mean/std for policy

            args = seg["ob"], seg["ac"], atarg
            # Every 5th sample is used for the Fisher-vector products.
            fvpargs = [arr[::5] for arr in args]

            assign_old_eq_new()  # set old parameter values to new parameter values
            with timed("computegrad"):
                tmp_result = compute_lossandgrad(seg["ob"], seg["ac"], atarg)
                lossbefore = tmp_result[:-1]
                g = tmp_result[-1]
            lossbefore = allmean(np.array(lossbefore))
            g = allmean(g)
            if np.allclose(g, 0):
                print("Got zero gradient. not updating")
            else:
                with timed("cg"):
                    stepdir = cg(fisher_vector_product, g, cg_iters=cg_iters, verbose=rank == 0)
                assert np.isfinite(stepdir).all()
                shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
                lm = np.sqrt(shs / max_kl)
                # print("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
                fullstep = stepdir / lm
                expectedimprove = g.dot(fullstep)
                surrbefore = lossbefore[0]
                stepsize = 1.0
                thbefore = get_flat()
                # Backtracking line search: halve the step up to 10 times.
                for _ in range(10):
                    thnew = thbefore + fullstep * stepsize
                    set_from_flat(thnew)
                    meanlosses = allmean(
                        np.array(compute_losses(seg["ob"], seg["ac"], atarg)))
                    surr = meanlosses[0]
                    kl = meanlosses[1]
                    improve = surr - surrbefore
                    print("Expected: %.3f Actual: %.3f" % (expectedimprove, improve))
                    if not np.isfinite(meanlosses).all():
                        print("Got non-finite value of losses -- bad!")
                    elif kl > max_kl * 1.5:
                        print("violated KL constraint. shrinking step.")
                    elif improve < 0:
                        print("surrogate didn't improve. shrinking step.")
                    else:
                        print("Stepsize OK!")
                        break
                    stepsize *= .5
                else:
                    print("couldn't compute a good step")
                    set_from_flat(thbefore)
                if nworkers > 1 and iters_so_far % 20 == 0:
                    # Sanity check that all workers hold identical parameters.
                    paramsums = MPI.COMM_WORLD.allgather(
                        (thnew.sum(), vfadam.getflat().sum()))  # list of tuples
                    assert all(
                        np.allclose(ps, paramsums[0]) for ps in paramsums[1:])
            with timed("vf"):
                for _ in range(vf_iters):
                    for (mbob, mbret) in dataset.iterbatches(
                            (seg["ob"], seg["tdlamret"]),
                            include_final_partial_batch=False,
                            batch_size=128):
                        if hasattr(pi, "ob_rms"):
                            pi.ob_rms.update(mbob)  # update running mean/std for policy
                        g = allmean(compute_vflossandgrad(mbob, mbret))
                        vfadam.update(g, vf_stepsize)

        print("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))

        # ------------------ Update D ------------------
        print("Optimizing Discriminator...")
        print(fmt_row(13, reward_giver.loss_name))
        ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob))
        batch_size = len(ob) // d_step
        d_losses = []  # list of tuples, each of which gives the loss for a minibatch
        for ob_batch, ac_batch in tqdm(
                dataset.iterbatches((ob, ac),
                                    include_final_partial_batch=False,
                                    batch_size=batch_size)):
            ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob_batch))
            # update running mean/std for reward_giver
            if hasattr(reward_giver, "obs_rms"):
                reward_giver.obs_rms.update(
                    np.concatenate((ob_batch, ob_expert), 0))
            tmp_result = reward_giver.lossandgrad(ob_batch, ac_batch, ob_expert, ac_expert)
            newlosses = tmp_result[:-1]
            g = tmp_result[-1]
            d_adam.update(allmean(g), d_stepsize)
            d_losses.append(newlosses)
        print(fmt_row(13, np.mean(d_losses, axis=0)))

        timesteps_so_far += len(seg['ob'])
        iters_so_far += 1

        print("EpisodesSoFar", episodes_so_far)
        print("TimestepsSoFar", timesteps_so_far)
        print("TimeElapsed", time.time() - tstart)
def learn(*, network, env, reward_giver, expert_dataset, g_step, d_step,
          d_stepsize=3e-4, total_timesteps, eval_env=None, seed=None,
          nsteps=2048, ent_coef=0.0, lr=3e-4, vf_coef=0.5, max_grad_norm=0.5,
          gamma=0.99, lam=0.95, log_interval=10, nminibatches=4, noptepochs=4,
          cliprange=0.2, save_interval=0, load_path=None, model_fn=None,
          update_fn=None, init_fn=None, mpi_rank_weight=1, comm=None,
          **network_kwargs):
    """Train a policy with PPO2 updates, alternating with discriminator updates.

    Each of the ``total_timesteps // (nsteps * nenvs)`` updates:
      1. runs ``g_step`` rounds of PPO2 optimisation. Each round collects
         ``nsteps`` transitions with ``Runner(..., reward_giver=reward_giver)``
         and trains the model for ``noptepochs`` epochs of ``nminibatches``
         minibatches;
      2. trains ``reward_giver`` with MpiAdam (step size ``d_stepsize``) on
         the most recent rollout versus batches drawn from ``expert_dataset``,
         using minibatches of ``len(obs) // d_step`` samples;
      3. logs statistics every ``log_interval`` updates and optionally saves a
         checkpoint every ``save_interval`` updates (MPI root only).

    ``lr`` and ``cliprange`` may be floats (held constant) or callables of the
    remaining-progress fraction. ``nenvs`` is fixed to 1. Works with or
    without MPI: when ``MPI is None`` the discriminator gradient is used
    as-is instead of being averaged across workers.

    Returns the trained model.
    """
    # ---- setup (from PPO2 learn) ----
    set_global_seeds(seed)

    if isinstance(lr, float):
        lr = constfn(lr)
    else:
        assert callable(lr)
    if isinstance(cliprange, float):
        cliprange = constfn(cliprange)
    else:
        assert callable(cliprange)
    total_timesteps = int(total_timesteps)

    policy = build_policy(env, network, **network_kwargs)

    # nenvs = env.num_envs
    nenvs = 1
    ob_space = env.observation_space
    ac_space = env.action_space

    nbatch = nenvs * nsteps
    nbatch_train = nbatch // nminibatches
    is_mpi_root = (MPI is None or MPI.COMM_WORLD.Get_rank() == 0)

    if model_fn is None:
        from baselines.ppo2.model import Model
        model_fn = Model

    model = model_fn(policy=policy, ob_space=ob_space, ac_space=ac_space, nbatch_act=nenvs,
                     nbatch_train=nbatch_train, nsteps=nsteps, ent_coef=ent_coef,
                     vf_coef=vf_coef, max_grad_norm=max_grad_norm, comm=comm,
                     mpi_rank_weight=mpi_rank_weight)

    if load_path is not None:
        model.load(load_path)

    runner = Runner(env=env, model=model, nsteps=nsteps, gamma=gamma, lam=lam,
                    reward_giver=reward_giver)
    if eval_env is not None:
        eval_runner = Runner(env=eval_env, model=model, nsteps=nsteps, gamma=gamma, lam=lam)

    epinfobuf = deque(maxlen=100)
    if eval_env is not None:
        eval_epinfobuf = deque(maxlen=100)

    if init_fn is not None:
        init_fn()

    tfirststart = time.perf_counter()
    nupdates = total_timesteps // nbatch

    # ---- discriminator optimiser (from TRPO-MPI) ----
    # MPI is optional here (see `is_mpi_root`), so don't dereference
    # MPI.COMM_WORLD unless it is available.
    nworkers = MPI.COMM_WORLD.Get_size() if MPI is not None else 1
    d_adam = MpiAdam(reward_giver.get_trainable_variables())

    def allmean(x):
        # Average a numpy array over MPI workers; plain copy without MPI.
        assert isinstance(x, np.ndarray)
        if MPI is not None:
            out = np.empty_like(x)
            MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
            out /= nworkers
        else:
            out = np.copy(x)
        return out

    # ---- main loop (from PPO2) ----
    for update in range(1, nupdates + 1):
        assert nbatch % nminibatches == 0
        tstart = time.perf_counter()
        # Remaining-progress fraction drives the lr / cliprange schedules.
        frac = 1.0 - (update - 1.0) / nupdates
        lrnow = lr(frac)
        cliprangenow = cliprange(frac)

        logger.log("Optimizing Policy...")
        for _ in range(g_step):
            if update % log_interval == 0 and is_mpi_root:
                logger.info('Stepping environment...')
            obs, returns, masks, actions, values, neglogpacs, states, epinfos = runner.run()
            if eval_env is not None:
                eval_obs, eval_returns, eval_masks, eval_actions, eval_values, eval_neglogpacs, eval_states, eval_epinfos = eval_runner.run()

            if update % log_interval == 0 and is_mpi_root:
                logger.info('Done.')

            epinfobuf.extend(epinfos)
            if eval_env is not None:
                eval_epinfobuf.extend(eval_epinfos)

            mblossvals = []
            if states is None:  # nonrecurrent version
                inds = np.arange(nbatch)
                for _ in range(noptepochs):
                    np.random.shuffle(inds)
                    for start in range(0, nbatch, nbatch_train):
                        end = start + nbatch_train
                        mbinds = inds[start:end]
                        slices = (arr[mbinds] for arr in (obs, returns, masks, actions, values, neglogpacs))
                        mblossvals.append(model.train(lrnow, cliprangenow, *slices))
            else:  # recurrent version
                assert False  # make sure we're not going here, so any bugs aren't from here
                assert nenvs % nminibatches == 0
                envsperbatch = nenvs // nminibatches
                envinds = np.arange(nenvs)
                flatinds = np.arange(nenvs * nsteps).reshape(nenvs, nsteps)
                for _ in range(noptepochs):
                    np.random.shuffle(envinds)
                    for start in range(0, nenvs, envsperbatch):
                        end = start + envsperbatch
                        mbenvinds = envinds[start:end]
                        mbflatinds = flatinds[mbenvinds].ravel()
                        slices = (arr[mbflatinds] for arr in (obs, returns, masks, actions, values, neglogpacs))
                        mbstates = states[mbenvinds]
                        mblossvals.append(model.train(lrnow, cliprangenow, *slices, mbstates))

            lossvals = np.mean(mblossvals, axis=0)
        tnow = time.perf_counter()
        fps = int(nbatch / (tnow - tstart))

        # ---- discriminator update (from TRPO-MPI) ----
        logger.log("Optimizing Disciminator...")
        logger.log(fmt_row(13, reward_giver.loss_name))
        ob_expert, ac_expert = expert_dataset.get_next_batch(len(obs))
        batch_size = len(obs) // d_step
        d_losses = []
        for ob_batch, ac_batch in dataset.iterbatches(
                (obs, actions), include_final_partial_batch=False, batch_size=batch_size):
            ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob_batch))
            if hasattr(reward_giver, "obs_rms"):
                reward_giver.obs_rms.update(np.concatenate((ob_batch, ob_expert), 0))
            *newlosses, g = reward_giver.lossandgrad(ob_batch, ac_batch, ob_expert, ac_expert)
            d_adam.update(allmean(g), d_stepsize)
            d_losses.append(newlosses)
        logger.log(fmt_row(13, np.mean(d_losses, axis=0)))

        if update_fn is not None:
            update_fn(update)

        if update % log_interval == 0 or update == 1:
            # Explained variance of the value predictions w.r.t. the returns.
            ev = explained_variance(values, returns)
            logger.logkv("misc/serial_timesteps", update * nsteps)
            logger.logkv("misc/nupdates", update)
            logger.logkv("misc/total_timesteps", update * nbatch)
            logger.logkv("fps", fps)
            logger.logkv("misc/explained_variance", float(ev))
            logger.logkv("eprewmean", safemean([epinfo['r'] for epinfo in epinfobuf]))
            logger.logkv("eplenmean", safemean([epinfo['l'] for epinfo in epinfobuf]))
            if eval_env is not None:
                logger.logkv("eval_eprewmean", safemean([epinfo['r'] for epinfo in eval_epinfobuf]))
                logger.logkv("eval_eplenmean", safemean([epinfo['l'] for epinfo in eval_epinfobuf]))
            logger.logkv("misc/time_elapsed", tnow - tfirststart)
            for (lossval, lossname) in zip(lossvals, model.loss_names):
                logger.logkv("loss/" + lossname, lossval)

            logger.dumpkvs()
        if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir() and is_mpi_root:
            checkdir = osp.join(logger.get_dir(), 'checkpoints')
            os.makedirs(checkdir, exist_ok=True)
            savepath = osp.join(checkdir, '%.5i' % update)
            print("Saving to", savepath)
            model.save(savepath)

    return model
def learn(env, policy_func, reward_giver, expert_dataset, rank,
          pretrained, pretrained_weight, *,
          clip_param, g_step, d_step, entcoeff, save_per_iter,
          optim_epochs, optim_stepsize, optim_batchsize,  # optimization hypers
          ckpt_dir, log_dir, timesteps_per_batch, task_name,
          gamma, lam,
          d_stepsize=3e-4, adam_epsilon=1e-5,
          max_timesteps=0, max_episodes=0, max_iters=0,
          mix_reward=False, r_lambda=0.44,
          callback=None,
          schedule='constant',  # annealing for stepsize parameters (epsilon and adam)
          frame_stack=1
          ):
    """Train a policy with PPO (clipped surrogate) against a GAIL-style discriminator.

    Each outer iteration:
      1. on rank 0, checkpoints the session every ``save_per_iter`` iterations
         (when ``ckpt_dir`` is given);
      2. runs ``g_step`` rounds of: sample a segment of ``timesteps_per_batch``
         steps, compute GAE targets and run ``optim_epochs`` PPO epochs with
         MPI-averaged Adam;
      3. runs ``optim_epochs // 10`` discriminator epochs on the last policy
         batch against fresh batches from ``expert_dataset``;
      4. records episode statistics, and every 10 iterations on rank 0 writes
         TensorBoard summaries under ``log_dir``.

    Exactly one of ``max_iters``, ``max_timesteps`` or ``max_episodes`` must be
    positive; it decides when training stops. ``rank`` is ignored and replaced
    by the MPI rank; ``pretrained`` is unused.

    Note: ``env.observation_space.shape`` is mutated in place to account for
    ``frame_stack``.
    """
    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    # Stacked frames are concatenated along the flat observation axis.
    # This mutates the environment's observation space object in place.
    ob_space.shape = (ob_space.shape[0] * frame_stack,)
    print(ob_space)
    ac_space = env.action_space
    pi = policy_func("pi", ob_space, ac_space, reuse=(pretrained_weight is not None))
    oldpi = policy_func("oldpi", ob_space, ac_space)
    atarg = tf.placeholder(dtype=tf.float32, shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(name='lrmult', dtype=tf.float32, shape=[])  # learning rate multiplier, updated with schedule
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg
    pol_surr = -tf.reduce_mean(tf.minimum(surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult], losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)
    d_adam = MpiAdam(reward_giver.get_trainable_variables())

    assign_old_eq_new = U.function([], [], updates=[
        tf.assign(oldv, newv)
        for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    if rank == 0:
        # Scalar placeholders fed with host-side statistics for TensorBoard.
        generator_loss = tf.placeholder(tf.float32, [], name='generator_loss')
        expert_loss = tf.placeholder(tf.float32, [], name='expert_loss')
        entropy = tf.placeholder(tf.float32, [], name='entropy')
        entropy_loss = tf.placeholder(tf.float32, [], name='entropy_loss')
        generator_acc = tf.placeholder(tf.float32, [], name='genrator_acc')
        expert_acc = tf.placeholder(tf.float32, [], name='expert_acc')
        eplenmean = tf.placeholder(tf.int32, [], name='eplenmean')
        eprewmean = tf.placeholder(tf.float32, [], name='eprewmean')
        eptruerewmean = tf.placeholder(tf.float32, [], name='eptruerewmean')
        _ops_to_merge = [generator_loss, expert_loss, entropy, entropy_loss,
                         generator_acc, expert_acc, eplenmean, eprewmean,
                         eptruerewmean]
        ops_to_merge = [tf.summary.scalar(op.name, op) for op in _ops_to_merge]
        merged = tf.summary.merge(ops_to_merge)

    @contextmanager
    def timed(msg):
        """Print `msg` and the elapsed wall time of the block (rank 0 only)."""
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(colorize("done in %.3f seconds" % (time.time() - tstart), color='magenta'))
        else:
            yield

    def allmean(x):
        """Average a numpy array element-wise across all MPI workers."""
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()
    adam.sync()
    d_adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, reward_giver, timesteps_per_batch,
                                     mix_reward, r_lambda, stochastic=True,
                                     frame_stack=frame_stack)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards
    true_rewbuffer = deque(maxlen=100)

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    # if provide pretrained weight
    if pretrained_weight is not None:
        U.load_state(pretrained_weight, var_list=pi.get_variables())

    if rank == 0:
        # Each run writes to a fresh "logs-<n>" directory inside log_dir.
        filenames = [f for f in os.listdir(log_dir) if 'logs' in f]
        writer = tf.summary.FileWriter('{}/logs-{}'.format(log_dir, len(filenames)))

    # Built lazily on the first save so the save/restore ops are added to the
    # graph once, instead of once per checkpoint.
    saver = None
    while True:
        if callback:
            callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break

        # Save model
        if rank == 0 and iters_so_far % save_per_iter == 0 and ckpt_dir is not None:
            fname = os.path.join(ckpt_dir, task_name)
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            if saver is None:
                from tensorflow.core.protobuf import saver_pb2
                saver = tf.train.Saver(write_version=saver_pb2.SaverDef.V1)
            saver.save(tf.get_default_session(), fname)

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        # ------------------ Update G ------------------
        logger.log("Optimizing Policy...")
        for _ in range(g_step):
            with timed("sampling"):
                seg = seg_gen.__next__()
            add_vtarg_and_adv(seg, gamma, lam)

            ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
            vpredbefore = seg["vpred"]  # predicted value function before update
            atarg = (atarg - atarg.mean()) / atarg.std()  # standardized advantage function estimate
            d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret), shuffle=not pi.recurrent)
            optim_batchsize = optim_batchsize or ob.shape[0]

            if hasattr(pi, "ob_rms"):
                pi.ob_rms.update(ob)  # update running mean/std for policy

            assign_old_eq_new()  # set old parameter values to new parameter values
            with timed("policy optimization"):
                logger.log("Optimizing...")
                logger.log(fmt_row(13, loss_names))
                # Here we do a bunch of optimization epochs over the data
                for _ in range(optim_epochs):
                    batch_losses = []  # one tuple of losses per minibatch
                    for batch in d.iterate_once(optim_batchsize):
                        *newlosses, g = lossandgrad(batch["ob"], batch["ac"], batch["atarg"],
                                                    batch["vtarg"], cur_lrmult)
                        adam.update(g, optim_stepsize * cur_lrmult)
                        batch_losses.append(newlosses)
                    logger.log(fmt_row(13, np.mean(batch_losses, axis=0)))

            logger.log("Evaluating losses...")
            batch_losses = []
            for batch in d.iterate_once(optim_batchsize):
                newlosses = compute_losses(batch["ob"], batch["ac"], batch["atarg"],
                                           batch["vtarg"], cur_lrmult)
                batch_losses.append(newlosses)
            meanlosses, _, _ = mpi_moments(batch_losses, axis=0)
            logger.log(fmt_row(13, meanlosses))
            for (lossval, name) in zipsame(meanlosses, loss_names):
                logger.record_tabular("loss_" + name, lossval)
            logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))

        # ------------------ Update D ------------------
        logger.log("Optimizing Discriminator...")
        logger.log(fmt_row(13, reward_giver.loss_name))
        # This expert batch is not used below; the draw is kept so the expert
        # dataset is consumed exactly as before.
        ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob))
        batch_size = len(ob) // d_step
        d_losses = []  # list of tuples, each of which gives the loss for a minibatch
        for _ in range(optim_epochs // 10):
            for ob_batch, ac_batch in dataset.iterbatches((ob, ac), include_final_partial_batch=False,
                                                          batch_size=batch_size):
                ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob_batch))
                # Keep the last ob_expert.shape[1] columns of the (possibly
                # frame-stacked) policy observation so its width matches the
                # expert's. A former trailing row slice `[:-30]` silently
                # discarded 30 policy samples per minibatch and was removed.
                ob_batch = ob_batch[:, -ob_expert.shape[1]:]
                # The last 30 observation features are excluded from the
                # discriminator's normalisation and input.
                # NOTE(review): why those 30 features are dropped is not visible here.
                if hasattr(reward_giver, "obs_rms"):
                    # update running mean/std for reward_giver
                    reward_giver.obs_rms.update(np.concatenate((ob_batch, ob_expert), 0)[:, :-30])
                *newlosses, g = reward_giver.lossandgrad(ob_batch[:, :-30], ob_expert[:, :-30])
                d_adam.update(allmean(g), d_stepsize)
                d_losses.append(newlosses)
        # With optim_epochs < 10 no discriminator epoch runs; np.mean over an
        # empty list would yield a scalar NaN that cannot be logged or indexed.
        if d_losses:
            logger.log(fmt_row(13, np.mean(d_losses, axis=0)))
        else:
            logger.log("No discriminator update this iteration (optim_epochs // 10 == 0)")

        lrlocal = (seg["ep_lens"], seg["ep_rets"], seg["ep_true_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews, true_rets = map(flatten_lists, zip(*listoflrpairs))
        true_rewbuffer.extend(true_rets)
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpTrueRewMean", np.mean(true_rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if rank == 0 and iters_so_far % 10 == 0 and d_losses:
            disc_losses = np.mean(d_losses, axis=0)
            res = tf.get_default_session().run(merged, feed_dict={
                generator_loss: disc_losses[0],
                expert_loss: disc_losses[1],
                entropy: disc_losses[2],
                entropy_loss: disc_losses[3],
                generator_acc: disc_losses[4],
                expert_acc: disc_losses[5],
                eplenmean: np.mean(lenbuffer),
                eprewmean: np.mean(rewbuffer),
                eptruerewmean: np.mean(true_rewbuffer),
            })
            writer.add_summary(res, iters_so_far)
            writer.flush()

        if rank == 0:
            logger.dump_tabular()
def learn(env, policy_func, reward_giver, expert_dataset, rank,
          pretrained, pretrained_weight, *,
          g_step, d_step, entcoeff, ckpt_dir,
          timesteps_per_batch, task_name,
          gamma, lam,
          max_kl, cg_iters, cg_damping=1e-2,
          vf_stepsize=3e-4, d_stepsize=3e-4, vf_iters=3,
          max_timesteps=0, max_episodes=0, max_iters=0, rnd_iter=200,
          callback=None, dyn_norm=False, mmd=False):
    """Train a policy with TRPO using rewards produced by ``reward_giver``.

    Before training, unless ``mmd`` is set, ``reward_giver.train`` is fitted on
    ``expert_dataset`` (an ``(obs, acs)`` pair) for ``rnd_iter`` iterations.
    With ``mmd`` set, each segment's rewards are instead recomputed through
    ``reward_giver.set_b2`` / ``reward_giver.get_reward``.

    Observation normalisation: with ``dyn_norm=False`` (and no pretrained
    weights) the policy's ``ob_rms`` is set once from the expert observations;
    with ``dyn_norm=True`` it is updated from every rollout.

    Each outer iteration runs ``g_step`` rounds of: sample a segment, record
    episode statistics, save the policy whenever the rolling true return
    matches or beats the best so far (cycling through 3 file slots), take one
    TRPO step (conjugate gradient plus backtracking line search under
    ``max_kl``), then fit the value function for ``vf_iters`` epochs.

    Exactly one of ``max_iters``/``max_timesteps``/``max_episodes`` must be
    positive. ``rank`` is replaced by the MPI rank; ``pretrained``, ``d_step``
    and ``d_stepsize`` are unused.
    """
    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space, ac_space)
    oldpi = policy_func("oldpi", ob_space, ac_space)
    atarg = tf.placeholder(dtype=tf.float32, shape=[None])  # Target advantage function (if applicable)

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = entcoeff * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)  # importance-weighted advantage

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    all_var_list = pi.get_trainable_variables()
    # Policy parameters are updated by TRPO; value-function parameters by Adam.
    var_list = [v for v in all_var_list
                if v.name.startswith("pi/pol") or v.name.startswith("pi/logstd")]
    vf_var_list = [v for v in all_var_list if v.name.startswith("pi/vff")]
    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    # Fisher-vector product graph: gradient of (grad KL . tangent), with the
    # flat tangent vector split back into per-variable shapes.
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32, shape=[None], name="flat_tan")
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    gvp = tf.add_n([tf.reduce_sum(g * tangent)
                    for (g, tangent) in zipsame(klgrads, tangents)])  # pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function([], [], updates=[
        tf.assign(oldv, newv)
        for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())])
    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses + [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = pi.vlossandgrad

    def allmean(x):
        """Average a numpy array element-wise across all MPI workers."""
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    def fisher_vector_product(p):
        """Damped Fisher-vector product on the current (subsampled) batch."""
        # `fvpargs` is rebound for every segment below; the closure reads the
        # value current at call time.
        return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

    U.initialize()
    # Start every worker from rank 0's policy parameters.
    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    vfadam.sync()
    if rank == 0:
        print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, reward_giver, timesteps_per_batch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards
    true_rewbuffer = deque(maxlen=40)

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    # if provide pretrained weight
    if pretrained_weight is not None:
        U.load_variables(pretrained_weight, variables=pi.get_variables())
    elif not dyn_norm:
        # Fixed observation normalisation taken from the expert observations.
        pi.ob_rms.update(expert_dataset[0])

    # Pretrained weights only cover the policy's variables, so the reward model
    # is fitted on expert data in either case.
    if not mmd:
        reward_giver.train(*expert_dataset, iter=rnd_iter)

    best = -2000  # policies are saved once the rolling true return reaches this
    save_ind = 0
    max_save = 3
    while True:
        if callback:
            callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************" % iters_so_far)

        # ------------------ Update G ------------------
        for _ in range(g_step):
            seg = seg_gen.__next__()

            # mmd reward
            if mmd:
                reward_giver.set_b2(seg["ob"], seg["ac"])
                seg["rew"] = reward_giver.get_reward(seg["ob"], seg["ac"])

            # report stats and save policy if any good
            lrlocal = (seg["ep_lens"], seg["ep_rets"], seg["ep_true_rets"])  # local values
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
            lens, rews, true_rets = map(flatten_lists, zip(*listoflrpairs))
            true_rewbuffer.extend(true_rets)
            lenbuffer.extend(lens)
            rewbuffer.extend(rews)

            true_rew_avg = np.mean(true_rewbuffer)
            logger.record_tabular("EpLenMean", np.mean(lenbuffer))
            logger.record_tabular("EpRewMean", np.mean(rewbuffer))
            logger.record_tabular("EpTrueRewMean", true_rew_avg)
            logger.record_tabular("EpThisIter", len(lens))
            episodes_so_far += len(lens)
            timesteps_so_far += sum(lens)
            iters_so_far += 1

            logger.record_tabular("EpisodesSoFar", episodes_so_far)
            logger.record_tabular("TimestepsSoFar", timesteps_so_far)
            logger.record_tabular("TimeElapsed", time.time() - tstart)
            logger.record_tabular("Best so far", best)

            # Save model
            if ckpt_dir is not None and true_rew_avg >= best:
                best = true_rew_avg
                fname = os.path.join(ckpt_dir, task_name)
                os.makedirs(os.path.dirname(fname), exist_ok=True)
                pi.save_policy(fname + "_" + str(save_ind))
                save_ind = (save_ind + 1) % max_save

            # compute gradient towards next policy
            add_vtarg_and_adv(seg, gamma, lam)
            ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
            vpredbefore = seg["vpred"]  # predicted value function before update
            atarg = (atarg - atarg.mean()) / atarg.std()  # standardized advantage function estimate

            if hasattr(pi, "ob_rms") and dyn_norm:
                pi.ob_rms.update(ob)  # update running mean/std for policy

            args = seg["ob"], seg["ac"], atarg
            fvpargs = [arr[::5] for arr in args]  # every 5th sample, for a cheaper FVP

            assign_old_eq_new()  # set old parameter values to new parameter values
            *lossbefore, g = compute_lossandgrad(*args)
            lossbefore = allmean(np.array(lossbefore))
            g = allmean(g)
            if np.allclose(g, 0):
                logger.log("Got zero gradient. not updating")
                # Parameters are unchanged, so the pre-step losses are current.
                # (Previously `meanlosses` stayed unbound here -> NameError below.)
                meanlosses = lossbefore
            else:
                stepdir = cg(fisher_vector_product, g, cg_iters=cg_iters, verbose=False)
                assert np.isfinite(stepdir).all()
                shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
                lm = np.sqrt(shs / max_kl)
                fullstep = stepdir / lm
                expectedimprove = g.dot(fullstep)
                surrbefore = lossbefore[0]
                stepsize = 1.0
                thbefore = get_flat()
                # Backtracking line search: halve the step until the KL bound
                # holds and the surrogate improves.
                for _ in range(10):
                    thnew = thbefore + fullstep * stepsize
                    set_from_flat(thnew)
                    meanlosses = surr, kl, *_ = allmean(np.array(compute_losses(*args)))
                    improve = surr - surrbefore
                    logger.log("Expected: %.3f Actual: %.3f" % (expectedimprove, improve))
                    if not np.isfinite(meanlosses).all():
                        logger.log("Got non-finite value of losses -- bad!")
                    elif kl > max_kl * 1.5:
                        logger.log("violated KL constraint. shrinking step.")
                    elif improve < 0:
                        logger.log("surrogate didn't improve. shrinking step.")
                    else:
                        logger.log("Stepsize OK!")
                        break
                    stepsize *= .5
                else:
                    logger.log("couldn't compute a good step")
                    set_from_flat(thbefore)
                if nworkers > 1 and iters_so_far % 20 == 0:
                    # Check the live parameters (not the last trial step) are in sync.
                    paramsums = MPI.COMM_WORLD.allgather(
                        (get_flat().sum(), vfadam.getflat().sum()))  # list of tuples
                    assert all(np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

            if pi.use_popart:
                pi.update_popart(tdlamret)
            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches((seg["ob"], seg["tdlamret"]),
                                                         include_final_partial_batch=False,
                                                         batch_size=128):
                    if hasattr(pi, "ob_rms") and dyn_norm:
                        pi.ob_rms.update(mbob)  # update running mean/std for policy
                    vfadam.update(allmean(compute_vflossandgrad(mbob, mbret)), vf_stepsize)

            for (lossname, lossval) in zip(loss_names, meanlosses):
                logger.record_tabular(lossname, lossval)
            logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))

        if rank == 0:
            logger.dump_tabular()
def hybrid_learn(env, policy_func, reward_giver, rank, *,
                 policy_step, boundary_step, entcoeff, save_per_iter,
                 ckpt_dir, log_dir, timesteps_per_batch,
                 task_name_1, task_name_2,
                 gamma, lam,
                 max_kl, cg_iters, cg_damping=1e-2,
                 vf_stepsize=3e-4, d_stepsize=3e-4, vf_iters=3,
                 max_timesteps=0, max_episodes=0, max_iters=0,
                 callback=None
                 ):
    """Train two TRPO sub-policies that act in one composed environment.

    Two policies are built, ``pi_task1`` and ``pi_task2``, each with its own
    "old" copy, TRPO graph and Adam-trained value function. Every outer
    iteration samples ``policy_step`` composed segments from
    ``traj_segment_generator_composed`` and, for each sub-policy, takes one
    TRPO step (conjugate gradient plus backtracking line search under
    ``max_kl``) and ``vf_iters`` value-function epochs on the samples stored
    under the segment keys ``"<field>_task1"`` / ``"<field>_task2"``.

    Checkpoints of the whole session are written on rank 0 to
    ``ckpt_dir/<task_name_1>_<task_name_2>`` every ``save_per_iter`` iterations.
    The latest checkpoint in ``ckpt_dir`` is restored at start-up.

    Exactly one of ``max_iters``/``max_timesteps``/``max_episodes`` must be
    positive. ``rank`` is replaced by the MPI rank.

    NOTE(review): the boundary optimisation suggested by ``boundary_step`` and
    the discriminator update are not implemented. The previous body
    referenced names this function never defines (``expert_dataset``,
    ``d_adam``, ``ob``, ``ac``) and, together with a syntax error, could not
    run. ``reward_giver``, ``d_stepsize``, ``boundary_step`` and ``log_dir``
    are currently unused.
    """
    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)

    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    # Per-task losses are recorded as "<name>_<suffix>", e.g. "optimgain_task1".
    loss_names = ["optimgain", "meankl", "entbonus", "surrgain", "meanent"]

    def build_task(suffix):
        """Build the TRPO graph of sub-policy ``pi_<suffix>``; return its callables."""
        scope = "pi_" + suffix
        pi = policy_func(scope, ob_space, ac_space)
        oldpi = policy_func("old" + scope, ob_space, ac_space)
        atarg = tf.placeholder(dtype=tf.float32, shape=[None])  # Target advantage function
        ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return
        # NOTE(review): assumes policy_func registers an observation placeholder
        # named "ob_<suffix>" (e.g. "ob_task1") -- confirm against the policy class.
        ob = U.get_placeholder_cached(name="ob_" + suffix)
        ac = pi.pdtype.sample_placeholder([None])

        meankl = tf.reduce_mean(oldpi.pd.kl(pi.pd))
        meanent = tf.reduce_mean(pi.pd.entropy())
        entbonus = entcoeff * meanent
        vferr = tf.reduce_mean(tf.square(pi.vpred - ret))
        ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
        surrgain = tf.reduce_mean(ratio * atarg)
        optimgain = surrgain + entbonus
        losses = [optimgain, meankl, entbonus, surrgain, meanent]

        all_var_list = pi.get_trainable_variables()
        # Variables live under this policy's own scope ("pi_task1/..."), not "pi/".
        var_list = [v for v in all_var_list
                    if v.name.startswith(scope + "/pol") or v.name.startswith(scope + "/logstd")]
        vf_var_list = [v for v in all_var_list if v.name.startswith(scope + "/vff")]
        assert len(var_list) == len(vf_var_list) + 1

        # Fisher-vector product graph: gradient of (grad KL . tangent), with the
        # flat tangent vector split back into per-variable shapes.
        klgrads = tf.gradients(meankl, var_list)
        flat_tangent = tf.placeholder(dtype=tf.float32, shape=[None], name="flat_tan_" + suffix)
        tangents = []
        start = 0
        for var in var_list:
            shape = var.get_shape().as_list()
            sz = U.intprod(shape)
            tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
            start += sz
        gvp = tf.add_n([tf.reduce_sum(g * tangent)
                        for (g, tangent) in zipsame(klgrads, tangents)])  # pylint: disable=E1111
        fvp = U.flatgrad(gvp, var_list)

        return dict(
            pi=pi,
            get_flat=U.GetFlat(var_list),
            set_from_flat=U.SetFromFlat(var_list),
            vfadam=MpiAdam(vf_var_list),
            assign_old_eq_new=U.function([], [], updates=[
                tf.assign(oldv, newv)
                for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())]),
            compute_losses=U.function([ob, ac, atarg], losses),
            compute_lossandgrad=U.function([ob, ac, atarg],
                                           losses + [U.flatgrad(optimgain, var_list)]),
            compute_fvp=U.function([flat_tangent, ob, ac, atarg], fvp),
            compute_vflossandgrad=U.function([ob, ret], U.flatgrad(vferr, vf_var_list)),
        )

    # A list (not a dict) so every MPI worker visits the tasks in the same order.
    tasks = [("task1", build_task("task1")), ("task2", build_task("task2"))]

    @contextmanager
    def timed(msg):
        """Print `msg` and the elapsed wall time of the block (rank 0 only)."""
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(colorize("done in %.3f seconds" % (time.time() - tstart), color='magenta'))
        else:
            yield

    def allmean(x):
        """Average a numpy array element-wise across all MPI workers."""
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    def trpo_update(task, ob, ac, atarg, check_sync):
        """Take one TRPO step for `task`; return the MPI-averaged losses."""
        args = ob, ac, atarg
        fvpargs = [arr[::5] for arr in args]  # every 5th sample, for a cheaper FVP

        def fisher_vector_product(p):
            return allmean(task["compute_fvp"](p, *fvpargs)) + cg_damping * p

        task["assign_old_eq_new"]()  # set old parameter values to new parameter values
        with timed("computegrad"):
            *lossbefore, g = task["compute_lossandgrad"](*args)
        lossbefore = allmean(np.array(lossbefore))
        g = allmean(g)
        if np.allclose(g, 0):
            logger.log("Got zero gradient. not updating")
            return lossbefore
        with timed("cg"):
            stepdir = cg(fisher_vector_product, g, cg_iters=cg_iters, verbose=rank == 0)
        assert np.isfinite(stepdir).all()
        shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
        lm = np.sqrt(shs / max_kl)
        fullstep = stepdir / lm
        expectedimprove = g.dot(fullstep)
        surrbefore = lossbefore[0]
        stepsize = 1.0
        thbefore = task["get_flat"]()
        # Backtracking line search: halve the step until the KL bound holds
        # and the surrogate improves.
        for _ in range(10):
            thnew = thbefore + fullstep * stepsize
            task["set_from_flat"](thnew)
            meanlosses = surr, kl, *_ = allmean(np.array(task["compute_losses"](*args)))
            improve = surr - surrbefore
            logger.log("Expected: %.3f Actual: %.3f" % (expectedimprove, improve))
            if not np.isfinite(meanlosses).all():
                logger.log("Got non-finite value of losses -- bad!")
            elif kl > max_kl * 1.5:
                logger.log("violated KL constraint. shrinking step.")
            elif improve < 0:
                logger.log("surrogate didn't improve. shrinking step.")
            else:
                logger.log("Stepsize OK!")
                break
            stepsize *= .5
        else:
            logger.log("couldn't compute a good step")
            task["set_from_flat"](thbefore)
        if nworkers > 1 and check_sync:
            paramsums = MPI.COMM_WORLD.allgather(
                (task["get_flat"]().sum(), task["vfadam"].getflat().sum()))  # list of tuples
            assert all(np.allclose(ps, paramsums[0]) for ps in paramsums[1:])
        return meanlosses

    def fit_value_function(task, ob, tdlamret):
        """Regress `task`'s value function onto the TD(lambda) returns."""
        pi = task["pi"]
        with timed("vf"):
            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches((ob, tdlamret),
                                                         include_final_partial_batch=False,
                                                         batch_size=128):
                    if hasattr(pi, "ob_rms"):
                        pi.ob_rms.update(mbob)  # update running mean/std for policy
                    g = allmean(task["compute_vflossandgrad"](mbob, mbret))
                    task["vfadam"].update(g, vf_stepsize)

    U.initialize()
    # Start every worker from rank 0's policy parameters.
    for _, task in tasks:
        th_init = task["get_flat"]()
        MPI.COMM_WORLD.Bcast(th_init, root=0)
        task["set_from_flat"](th_init)
        task["vfadam"].sync()
        if rank == 0:
            print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    # NOTE(review): `traj_segment_generator_composed` and `boundary_condition`
    # are not defined in this function; presumably module-level -- confirm.
    seg_gen = traj_segment_generator_composed(tasks[0][1]["pi"], tasks[1][1]["pi"], env,
                                              boundary_condition, timesteps_per_batch,
                                              stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards
    true_rewbuffer = deque(maxlen=40)

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    # NOTE(review): the original referenced an undefined `task_name`; the
    # combined checkpoint name below is an assumption.
    task_name = "{}_{}".format(task_name_1, task_name_2)
    weight_file = None
    if ckpt_dir is not None:
        weight_file = tf.train.latest_checkpoint(ckpt_dir)
        print("fname: {} weight_file: {}".format(os.path.join(ckpt_dir, task_name), weight_file))
    if weight_file is not None:
        U.load_state(weight_file)
        tf.logging.info('%s loaded' % weight_file)
    else:
        print("from scratch")
        tf.logging.info('Training from the scratch (no pre-trained weight_filets)..')

    # Built lazily on the first save so the save ops are added to the graph once.
    saver = None
    while True:
        if callback:
            callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break

        # Save model
        if rank == 0 and iters_so_far % save_per_iter == 0 and ckpt_dir is not None:
            fname = os.path.join(ckpt_dir, task_name)
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            if saver is None:
                saver = tf.train.Saver()
            saver.save(tf.get_default_session(), fname)
            print("============= SAVE ===============")

        logger.log("********** Iteration %i ************" % iters_so_far)

        # ------------------ Update G ------------------
        logger.log("Optimizing Policy...")
        # First, we optimize each policy. The boundary for the given policies is
        # meant to be optimized in a following step (not implemented yet).
        for _ in range(policy_step):
            with timed("sampling"):
                seg = seg_gen.__next__()
            # NOTE(review): assumes add_vtarg_and_adv fills the per-task
            # "adv_<suffix>" / "tdlamret_<suffix>" keys read below -- confirm.
            add_vtarg_and_adv(seg, gamma, lam)

            for suffix, task in tasks:
                ob = seg["ob_" + suffix]
                ac = seg["ac_" + suffix]
                tdlamret = seg["tdlamret_" + suffix]
                vpredbefore = seg["vpred_" + suffix]  # predicted value function before update
                atarg = seg["adv_" + suffix]
                atarg = (atarg - atarg.mean()) / atarg.std()  # standardized advantage function estimate

                pi = task["pi"]
                if hasattr(pi, "ob_rms"):
                    pi.ob_rms.update(ob)  # update running mean/std for policy

                meanlosses = trpo_update(task, ob, ac, atarg,
                                         check_sync=(iters_so_far % 20 == 0))
                fit_value_function(task, ob, tdlamret)

                for (lossname, lossval) in zip(loss_names, meanlosses):
                    logger.record_tabular("%s_%s" % (lossname, suffix), lossval)
                logger.record_tabular("ev_tdlam_before_" + suffix,
                                      explained_variance(vpredbefore, tdlamret))

        # NOTE(review): seg["ep_lens"] / ["ep_rets"] / ["ep_true_rets"] are
        # assumed to be produced by traj_segment_generator_composed -- confirm.
        lrlocal = (seg["ep_lens"], seg["ep_rets"], seg["ep_true_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews, true_rets = map(flatten_lists, zip(*listoflrpairs))
        true_rewbuffer.extend(true_rets)
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpTrueRewMean", np.mean(true_rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if rank == 0:
            logger.dump_tabular()
def learn(env, policy_func, med_func, expert_dataset, pretrained, pretrained_weight,
          g_step, m_step, e_step, inner_iters, save_per_iter, ckpt_dir, log_dir,
          timesteps_per_batch, task_name, max_kl=0.01, max_timesteps=0, max_episodes=0,
          max_iters=0, batch_size=64, med_stepsize=1e-3, pi_stepsize=1e-3,
          callback=None, writer=None):
    """Alternately train a policy ``pi`` and a "mediator" ``med`` against expert data.

    Every outer iteration:
      1. samples one trajectory segment from ``pi`` in ``env``;
      2. takes one MPI-averaged Adam step on ``pi`` (one minibatch) minimising
         ``cross_entropy(pi, med.g_pd) - entropy(pi)``, i.e. KL(pi || med.g_pd);
      3. takes MPI-averaged Adam steps on ``med`` over full minibatches of the
         segment, maximising the log-likelihood of generated actions under
         ``med.g_pd`` and of expert actions under ``med.e_pd``;
      4. logs episode statistics (and TensorBoard summaries when ``writer`` is set).

    Args:
        env: gym-style environment used for rollouts.
        policy_func, med_func: builders called as ``f(name, ob_space, ac_space)``.
        expert_dataset: object exposing ``get_next_batch(n) -> (obs, acs)``.
        pretrained, g_step, m_step, e_step, inner_iters, log_dir, max_kl:
            accepted for interface compatibility; not used in this function.
        pretrained_weight: optional checkpoint path loaded into ``pi``'s variables.
        save_per_iter: checkpoint every this many iterations on rank 0;
            0/None disables saving.
        ckpt_dir, task_name: checkpoint location is ``ckpt_dir/task_name``.
        timesteps_per_batch: rollout segment length.
        max_timesteps, max_episodes, max_iters: exactly one must be positive.
        batch_size: minibatch size for both policy and mediator updates.
        med_stepsize, pi_stepsize: Adam step sizes.
        callback: called as ``callback(locals(), globals())`` each iteration.
        writer: optional summary writer for ``stats.add_all_summary``.
    """
    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)
    ob_space = env.observation_space
    ac_space = env.action_space

    # Reuse the "pi" scope when pretrained weights will be loaded into it.
    pi = policy_func("pi", ob_space, ac_space, reuse=(pretrained_weight is not None))
    # ``oldpi`` is not used by the updates below; it is kept because it
    # presumably creates variables that the default tf.train.Saver() saves,
    # so dropping it would change the checkpoint contents.
    oldpi = policy_func("oldpi", ob_space, ac_space)
    med = med_func("mediator", ob_space, ac_space)
    pi_var_list = pi.get_trainable_variables()
    med_var_list = med.get_trainable_variables()

    # U.get_placeholder caches by name, so these must keep exactly these names.
    g_ob = U.get_placeholder(name="g_ob", dtype=tf.float32, shape=[None] + list(ob_space.shape))
    g_ac = U.get_placeholder(name='g_ac', dtype=tf.float32, shape=[None] + list(ac_space.shape))
    e_ob = U.get_placeholder(name='e_ob', dtype=tf.float32, shape=[None] + list(ob_space.shape))
    e_ac = U.get_placeholder(name='e_ac', dtype=tf.float32, shape=[None] + list(ac_space.shape))

    # Mediator: mean negative log-likelihood of generated actions under its
    # generator head and of expert actions under its expert head.
    med_loss = -tf.reduce_mean(med.g_pd.logp(g_ac) + med.e_pd.logp(e_ac)) * 0.5
    # Policy: cross-entropy to the mediator's generator head minus the policy's
    # own entropy (== KL(pi || med.g_pd)) for diagonal Gaussians.
    g_pdf = tfd.MultivariateNormalDiag(loc=pi.pd.mean, scale_diag=pi.pd.std)
    m_pdf = tfd.MultivariateNormalDiag(loc=med.g_pd.mean, scale_diag=med.g_pd.std)
    pi_loss = tf.reduce_mean(g_pdf.cross_entropy(m_pdf) - g_pdf.entropy())

    compute_med_lossandgrad = U.function(
        [g_ob, g_ac, e_ob, e_ac], [med_loss, U.flatgrad(med_loss, med_var_list)])
    compute_pi_lossandgrad = U.function(
        [g_ob], [pi_loss, U.flatgrad(pi_loss, pi_var_list)])

    med_adam = MpiAdam(med_var_list)
    pi_adam = MpiAdam(pi_var_list)

    def allmean(x):
        """Element-wise mean of a numpy array across all MPI workers."""
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()
    med_adam.sync()
    pi_adam.sync()

    seg_gen = traj_segment_generator(pi, env, timesteps_per_batch, stochastic=True)
    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    true_rewbuffer = deque(maxlen=40)  # rolling buffer for true episode returns
    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    loss_stats = stats(["med_loss", "pi_loss"])
    ep_stats = stats(["True_rewards", "Episode_length"])

    if pretrained_weight is not None:
        U.load_state(pretrained_weight, var_list=pi_var_list)

    # One mean loss per iteration, accumulated over the whole run.
    med_losses = []
    pi_losses = []
    # Built once on first save: every tf.train.Saver() adds save/restore ops
    # to the graph, so constructing one per checkpoint makes the graph grow.
    saver = None
    while True:
        if callback:
            callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break

        # Save model (rank 0 only); a falsy save_per_iter disables saving
        # instead of raising ZeroDivisionError.
        if (rank == 0 and save_per_iter and iters_so_far % save_per_iter == 0
                and ckpt_dir is not None):
            fname = os.path.join(ckpt_dir, task_name)
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            if saver is None:
                saver = tf.train.Saver()
            saver.save(tf.get_default_session(), fname)

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        seg_ob, seg_ac = seg['ob'], seg['ac']

        # ======= Optimize Generator (policy): a single minibatch step =======
        d = dataset.Dataset(dict(ob=seg_ob, ac=seg_ac))
        optim_batchsize = min(batch_size, len(seg_ob))
        logger.log("Optimizing Generator...")
        g_batch_ob = d.next_batch(optim_batchsize)['ob']
        if hasattr(pi, "obs_rms"):
            pi.obs_rms.update(g_batch_ob)
        pi_loss_val, g = compute_pi_lossandgrad(g_batch_ob)
        pi_adam.update(allmean(g), pi_stepsize)
        g_loss = [pi_loss_val]
        pi_losses.append(np.mean(np.array(g_loss)))

        # ======= Optimize Mediator: one pass over full minibatches =======
        med_loss_vals = []
        logger.log("Optimizing Mediator...")
        for g_ob_batch, g_ac_batch in dataset.iterbatches(
                (seg_ob, seg_ac), include_final_partial_batch=False, batch_size=batch_size):
            e_ob_batch, e_ac_batch = expert_dataset.get_next_batch(optim_batchsize)
            if hasattr(med, "obs_rms"):
                med.obs_rms.update(np.concatenate((g_ob_batch, e_ob_batch), 0))
            newlosses, g = compute_med_lossandgrad(g_ob_batch, g_ac_batch, e_ob_batch, e_ac_batch)
            med_adam.update(allmean(g), med_stepsize)
            med_loss_vals.append(newlosses)
        med_losses.append(np.mean(np.array(med_loss_vals)))

        # NOTE(review): despite the "_each_iter" names, these are means over all
        # iterations so far (med_losses / pi_losses are never reset) — confirm intent.
        logger.record_tabular("med_loss_each_iter", np.mean(np.array(med_losses)))
        logger.record_tabular("gen_loss_each_iter", np.mean(np.array(pi_losses)))

        lrlocal = (seg["ep_lens"], seg["ep_true_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, true_rets = map(flatten_lists, zip(*listoflrpairs))
        true_rewbuffer.extend(true_rets)
        lenbuffer.extend(lens)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpTrueRewMean", np.mean(true_rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if writer is not None:
            loss_stats.add_all_summary(
                writer, [np.mean(np.array(med_losses)), np.mean(np.array(pi_losses))],
                episodes_so_far)
            ep_stats.add_all_summary(
                writer, [np.mean(true_rewbuffer), np.mean(lenbuffer)], episodes_so_far)
        if rank == 0:
            logger.dump_tabular()
def update(self, obs, actions, atarg, returns, vpredbefore, nb):
    """Take one trust-region step on the Lagrangian objective for ``nb``, then fit the value function.

    Policy step (TRPO recipe):
      * conjugate gradient on the damped, worker-averaged ``self.compute_hvp``
        product gives a search direction for the gradient ``self.compute_vjp``;
      * the step is scaled to the ``self.max_kl`` trust region;
      * up to 10 halvings of the step are tried.
    A step is accepted when:
      * the losses are finite;
      * the KL (4th entry of ``self.compute_losses``) is at most ``1.5 * self.max_kl``;
      * the Lagrangian (1st entry) did not increase.
    If no step is accepted, the previous parameters are restored.

    Value step: ``self.vf_iters`` passes over 64-sample minibatches of
    ``(obs, returns)`` using ``self.vfadam``.

    Args:
        obs, actions, atarg, returns: rollout arrays (converted to tf constants).
        vpredbefore: value predictions made before this update; used only for
            the ``ev_tdlam_before`` explained-variance diagnostic.
        nb: key into ``self.estimates`` / ``self.multipliers`` (presumably the
            neighbour/edge being synchronised — TODO confirm).
    """
    obs = tf.constant(obs)
    actions = tf.constant(actions)
    atarg = tf.constant(atarg)
    returns = tf.constant(returns)
    estimates = tf.constant(self.estimates[nb])
    multipliers = tf.constant(self.multipliers[nb])
    args = obs, actions, atarg, estimates, multipliers

    # Curvature products use every sample (no subsampling).
    fvpargs = (obs, actions)

    def hvp(p):
        """Damped, worker-averaged curvature-vector product."""
        return self.allmean(self.compute_hvp(p, *fvpargs).numpy()) + self.cg_damping * p

    # NOTE(review): the original comment says "set old parameter values to new
    # parameter values", but the method name reads new := old — confirm which
    # direction the assignment goes.
    self.assign_new_eq_old()

    lossbefore = self.compute_losses(*args, nb)
    g = self.compute_vjp(*args, nb)
    lossbefore = self.allmean(np.array(lossbefore))
    g = self.allmean(g.numpy())

    if np.allclose(g, 0):
        logger.log("Got zero gradient. not updating")
    else:
        stepdir = cg(hvp, g, cg_iters=self.cg_iters, verbose=self.rank == 0)
        assert np.isfinite(stepdir).all()
        shs = .5 * stepdir.dot(hvp(stepdir))
        lm = np.sqrt(shs / self.max_kl)
        logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
        fullstep = stepdir / lm
        expectedimprove = g.dot(fullstep)
        lagrangebefore, surrbefore, syncbefore, *_ = lossbefore
        stepsize = 1.0
        thbefore = self.get_flat()
        for _ in range(10):
            thnew = thbefore + fullstep * stepsize
            self.set_from_flat(thnew)
            meanlosses = lagrange, surr, syncloss, kl, *_ = self.allmean(
                np.array(self.compute_losses(*args, nb)))
            # Improvement means the Lagrangian went *down*.
            improve = lagrangebefore - lagrange
            # Former stdout debug prints, now routed through the logger.
            logger.log("Lagrangian: %.3f -> %.3f, surrogate: %.3f -> %.3f, sync loss: %.3f -> %.3f"
                       % (lagrangebefore, lagrange, surrbefore, surr, syncbefore, syncloss))
            logger.log("Expected: %.3f Actual: %.3f" % (expectedimprove, improve))
            if not np.isfinite(meanlosses).all():
                logger.log("Got non-finite value of losses -- bad!")
            elif kl > self.max_kl * 1.5:
                logger.log("violated KL constraint. shrinking step.")
            elif improve < 0:
                logger.log("surrogate didn't improve. shrinking step.")
            else:
                logger.log("Stepsize OK!")
                break
            stepsize *= .5
        else:
            logger.log("couldn't compute a good step")
            self.set_from_flat(thbefore)

    for _ in range(self.vf_iters):
        for mbob, mbret in dataset.iterbatches(
                (obs, returns), include_final_partial_batch=False, batch_size=64):
            vg = self.allmean(self.compute_vflossandgrad(mbob, mbret).numpy())
            self.vfadam.update(vg, self.vf_stepsize)
    logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, returns))
def learn(
        *,
        network,
        env,
        eval_env,
        make_eval_env,
        env_id,
        total_timesteps,
        timesteps_per_batch,
        sil_update,
        sil_loss,  # what to train on
        max_kl=0.001,
        cg_iters=10,
        gamma=0.99,
        lam=1.0,  # advantage estimation
        seed=None,
        ent_coef=0.0,
        lr=3e-4,
        cg_damping=1e-2,
        vf_stepsize=3e-4,
        vf_iters=5,
        sil_value=0.01,
        sil_alpha=0.6,
        sil_beta=0.1,
        max_episodes=0,
        max_iters=0,  # time constraint
        callback=None,
        save_interval=0,
        load_path=None,
        # MBL
        # For train mbl
        mbl_train_freq=5,
        # For eval
        num_eval_episodes=5,
        eval_freq=5,
        vis_eval=False,
        eval_targs=('mbmf', ),
        quant=2,
        # For mbl.step
        num_samples=(1, ),
        horizon=(2, ),
        num_elites=(1, ),
        mbl_lamb=(1.0, ),
        mbl_gamma=0.99,
        # Number of steps that use stochastic (rather than deterministic) sampling
        mbl_sh=10000,
        reset_per_step=(0, ),
        # For get_model
        num_fc=2,
        num_fwd_hidden=500,
        use_layer_norm=False,
        # For MBL
        num_warm_start=int(1e4),
        init_epochs=10,
        update_epochs=5,
        batch_size=512,
        update_with_validation=False,
        use_mean_elites=1,
        use_ent_adjust=0,
        adj_std_scale=0.5,
        # For data loading
        validation_set_path=None,
        # For data collect
        collect_val_data=False,
        # For traj collect
        traj_collect='mf',
        # For profile
        measure_time=True,
        eval_val_err=False,
        measure_rew=True,
        model_fn=None,
        update_fn=None,
        init_fn=None,
        mpi_rank_weight=1,
        comm=None,
        vf_coef=0.5,
        max_grad_norm=0.5,
        log_interval=1,
        nminibatches=4,
        noptepochs=4,
        cliprange=0.2,
        **network_kwargs):
    '''
    learn a policy function with TRPO algorithm (plus self-imitation learning
    and model-based lookahead evaluation)

    Parameters:
    ----------

    network                 neural network to learn. Can be either string ('mlp', 'cnn', 'lstm', 'lnlstm' for basic types)
                            or function that takes input placeholder and returns tuple (output, None) for feedforward nets
                            or (output, (state_placeholder, state_output, mask_placeholder)) for recurrent nets

    env                     environment (one of the gym environments or wrapped via baselines.common.vec_env.VecEnv-type class

    timesteps_per_batch     timesteps per gradient estimation batch

    max_kl                  max KL divergence between old policy and new policy ( KL(pi_old || pi) )

    ent_coef                coefficient of policy entropy term in the optimization objective

    cg_iters                number of iterations of conjugate gradient algorithm

    cg_damping              conjugate gradient damping

    vf_stepsize             learning rate for adam optimizer used to optimize value function loss

    vf_iters                number of iterations of value function optimization iterations per each policy optimization step

    total_timesteps           max number of timesteps

    max_episodes            max number of episodes

    max_iters               maximum number of policy optimization iterations

    callback                function to be called with (locals(), globals()) each policy optimization step

    load_path               str, path to load the model from (default: None, i.e. no model is loaded)

    **network_kwargs        keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and arguments to a particular type of network

    Yields:
    ------

    pi after every training iteration (this function is a generator).
    '''
    if not isinstance(num_samples, tuple):
        num_samples = (num_samples, )
    if not isinstance(horizon, tuple):
        horizon = (horizon, )
    if not isinstance(num_elites, tuple):
        num_elites = (num_elites, )
    if not isinstance(mbl_lamb, tuple):
        mbl_lamb = (mbl_lamb, )
    if not isinstance(reset_per_step, tuple):
        reset_per_step = (reset_per_step, )
    if validation_set_path is None:
        if collect_val_data:
            validation_set_path = os.path.join(logger.get_dir(), 'val.pkl')
        else:
            validation_set_path = os.path.join('dataset', '{}-val.pkl'.format(env_id))
    if eval_val_err:
        eval_val_err_path = os.path.join('dataset', '{}-combine-val.pkl'.format(env_id))
    logger.log(locals())
    logger.log('MBL_SH', mbl_sh)
    logger.log('Traj_collect', traj_collect)

    if MPI is not None:
        nworkers = MPI.COMM_WORLD.Get_size()
        rank = MPI.COMM_WORLD.Get_rank()
    else:
        nworkers = 1
        rank = 0
    cpus_per_worker = 1
    U.get_session(
        config=tf.ConfigProto(allow_soft_placement=True,
                              inter_op_parallelism_threads=cpus_per_worker,
                              intra_op_parallelism_threads=cpus_per_worker))
    set_global_seeds(seed)
    if isinstance(lr, float):
        lr = constfn(lr)
    else:
        assert callable(lr)
    if isinstance(cliprange, float):
        cliprange = constfn(cliprange)
    else:
        assert callable(cliprange)

    policy = build_policy(env, network, value_network='copy', **network_kwargs)
    nenvs = env.num_envs
    np.set_printoptions(precision=3)

    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    nbatch = nenvs * timesteps_per_batch
    nbatch_train = nbatch // nminibatches
    is_mpi_root = (MPI is None or MPI.COMM_WORLD.Get_rank() == 0)

    ob = observation_placeholder(ob_space)
    with tf.variable_scope("pi"):
        pi = policy(observ_placeholder=ob)
    # NOTE(review): the collapsed source does not show whether the SIL Model
    # below was built inside the "pi" variable scope; it is built outside here.
    # This affects variable names/checkpoints — confirm against the original.
    make_model = lambda: Model(
        policy=policy, ob_space=ob_space, ac_space=ac_space, nbatch_act=nenvs,
        nbatch_train=nbatch_train, nsteps=timesteps_per_batch, ent_coef=ent_coef,
        vf_coef=vf_coef, max_grad_norm=max_grad_norm, sil_update=sil_update,
        sil_value=sil_value, sil_alpha=sil_alpha, sil_beta=sil_beta, sil_loss=sil_loss,
        fn_reward=None, fn_obs=None, ppo=False, prev_pi='pi', silm=pi)
    model = make_model()

    with tf.variable_scope("oldpi"):
        oldpi = policy(observ_placeholder=ob)
    make_old_model = lambda: Model(
        policy=policy, ob_space=ob_space, ac_space=ac_space, nbatch_act=nenvs,
        nbatch_train=nbatch_train, nsteps=timesteps_per_batch, ent_coef=ent_coef,
        vf_coef=vf_coef, max_grad_norm=max_grad_norm, sil_update=sil_update,
        sil_value=sil_value, sil_alpha=sil_alpha, sil_beta=sil_beta, sil_loss=sil_loss,
        fn_reward=None, fn_obs=None, ppo=False, prev_pi='oldpi', silm=oldpi)
    old_model = make_old_model()

    # MBL
    # ---------------------------------------
    eval_targs = list(eval_targs)
    logger.log(eval_targs)

    make_model_f = get_make_mlp_model(num_fc=num_fc, num_fwd_hidden=num_fwd_hidden,
                                      layer_norm=use_layer_norm)
    mbl = MBL(env=eval_env, env_id=env_id, make_model=make_model_f,
              num_warm_start=num_warm_start, init_epochs=init_epochs,
              update_epochs=update_epochs, batch_size=batch_size, **network_kwargs)

    val_dataset = {'ob': None, 'ac': None, 'ob_next': None}
    if update_with_validation:
        logger.log('Update with validation')
        val_dataset = load_val_data(validation_set_path)
    if eval_val_err:
        logger.log('Log val error')
        eval_val_dataset = load_val_data(eval_val_err_path)
    if collect_val_data:
        logger.log('Collect validation data')
        val_dataset_collect = []

    def _mf_pi(ob, t=None):
        """Stochastic model-free action and value."""
        ac, vpred, _, _ = pi.step(ob, stochastic=True)
        return ac, vpred

    def _mf_det_pi(ob, t=None):
        """Deterministic (mode) model-free action and value."""
        ac, vpred = pi._evaluate([pi.pd.mode(), pi.vf], ob)
        return ac, vpred

    def _mf_ent_pi(ob, t=None):
        """Action sampled around the mode with std scaled by adj_std_scale."""
        mean, std, vpred = pi._evaluate([pi.pd.mode(), pi.pd.std, pi.vf], ob)
        ac = np.random.normal(mean, std * adj_std_scale, size=mean.shape)
        return ac, vpred

    def _mbmf_inner_pi(ob, t=0):
        """Policy used inside MBL lookahead: stochastic for t < mbl_sh, else deterministic."""
        if use_ent_adjust:
            return _mf_ent_pi(ob)
        if t < mbl_sh:
            return _mf_pi(ob)
        return _mf_det_pi(ob)

    # ---------------------------------------
    # Run multiple configuration once
    all_eval_descs = []

    def make_mbmf_pi(n, h, e, l):
        def _mbmf_pi(ob):
            ac, rew = mbl.step(ob=ob, pi=_mbmf_inner_pi, horizon=h, num_samples=n,
                               num_elites=e, gamma=mbl_gamma, lamb=l,
                               use_mean_elites=use_mean_elites)
            return ac[None], rew
        return Policy(step=_mbmf_pi, reset=None)

    for n in num_samples:
        for h in horizon:
            for l in mbl_lamb:
                for e in num_elites:
                    if 'mbmf' in eval_targs:
                        all_eval_descs.append(('MeanRew', 'MBL_TRPO_SIL', make_mbmf_pi(n, h, e, l)))
    if 'mf' in eval_targs:
        all_eval_descs.append(('MeanRew', 'TRPO_SIL', Policy(step=_mf_pi, reset=None)))

    logger.log('List of evaluation targets')
    for it in all_eval_descs:
        logger.log(it[0])

    # (A multiprocessing Pool used to be created here; it was never used and
    # its worker processes were never closed, so it has been removed.)
    # ----------------------------------------

    atarg = tf.placeholder(dtype=tf.float32, shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = ent_coef * meanent

    vferr = tf.reduce_mean(tf.square(pi.vf - ret))

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    all_var_list = get_trainable_variables("pi")
    var_list = get_pi_trainable_variables("pi")
    vf_var_list = get_vf_trainable_variables("pi")

    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32, shape=[None], name="flat_tan")
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent) for (g, tangent) in zipsame(klgrads, tangents)
    ])  #pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[tf.assign(oldv, newv)
                 for (oldv, newv) in zipsame(get_variables("oldpi"), get_variables("pi"))])

    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses + [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret], U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(colorize("done in %.3f seconds" % (time.time() - tstart), color='magenta'))
        else:
            yield

    def allmean(x):
        """Element-wise mean of a numpy array across all MPI workers."""
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()
    if load_path is not None:
        pi.load(load_path)

    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    # NOTE(review): only 'mf' defines seg_gen; 'mf-random' / 'mf-mb' also need
    # seg_gen_mbl, which is never defined here, so those modes fail at runtime.
    if traj_collect == 'mf':
        seg_gen = traj_segment_generator(env, timesteps_per_batch, model, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    if sum([max_iters > 0, total_timesteps > 0, max_episodes > 0]) == 0:
        # nothing to be done
        return pi

    assert sum([max_iters > 0, total_timesteps > 0, max_episodes > 0]) < 2, \
        'out of max_iters, total_timesteps, and max_episodes only one should be specified'

    while True:
        if callback:
            callback(locals(), globals())
        if total_timesteps and timesteps_so_far >= total_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break

        logger.log("********** Iteration %i ************" % iters_so_far)

        with timed("sampling"):
            seg = seg_gen.__next__()
            if traj_collect == 'mf-random' or traj_collect == 'mf-mb':
                seg_mbl = seg_gen_mbl.__next__()
            else:
                seg_mbl = seg
        add_vtarg_and_adv(seg, gamma, lam)

        # Rebinds ob/ac/atarg (previously placeholders) to numpy batch data;
        # the U.function callables above already captured the placeholders.
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]

        if collect_val_data:
            # Val data collection: (ob, ac, next ob) transitions of env 0.
            for ob_, ac_, ob_next_ in zip(ob[:-1, 0, ...], ac[:-1, ...], ob[1:, 0, ...]):
                val_dataset_collect.append(
                    (copy.copy(ob_), copy.copy(ac_), copy.copy(ob_next_)))
        else:
            # MBL update: add transitions and (periodically) refit the dynamics model.
            ob_mbl, ac_mbl = seg_mbl["ob"], seg_mbl["ac"]
            mbl.add_data_batch(ob_mbl[:-1, 0, ...], ac_mbl[:-1, ...], ob_mbl[1:, 0, ...])
            mbl.update_forward_dynamic(require_update=iters_so_far % mbl_train_freq == 0,
                                       ob_val=val_dataset['ob'],
                                       ac_val=val_dataset['ac'],
                                       ob_next_val=val_dataset['ob_next'])
        # -----------------------------

        if traj_collect == 'mf':
            vpredbefore = seg["vpred"]  # predicted value function before update
            model = seg["model"]
            atarg = (atarg - atarg.mean()) / atarg.std()  # standardized advantage function estimate
            if hasattr(pi, "ret_rms"):
                pi.ret_rms.update(tdlamret)
            # NOTE(review): checks "rms" while baselines policies expose "ob_rms" — confirm.
            if hasattr(pi, "rms"):
                pi.rms.update(ob)  # update running mean/std for policy

            args = seg["ob"], seg["ac"], atarg
            fvpargs = [arr[::5] for arr in args]  # subsample every 5th step for FVPs

            def fisher_vector_product(p):
                return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

            assign_old_eq_new()  # set old parameter values to new parameter values
            with timed("computegrad"):
                *lossbefore, g = compute_lossandgrad(*args)
            lossbefore = allmean(np.array(lossbefore))
            g = allmean(g)
            if np.allclose(g, 0):
                logger.log("Got zero gradient. not updating")
                # Parameters are unchanged, so the pre-update losses are the
                # current ones; previously meanlosses was unbound (or stale).
                meanlosses = lossbefore
            else:
                with timed("cg"):
                    stepdir = cg(fisher_vector_product, g, cg_iters=cg_iters, verbose=rank == 0)
                assert np.isfinite(stepdir).all()
                shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
                lm = np.sqrt(shs / max_kl)
                fullstep = stepdir / lm
                expectedimprove = g.dot(fullstep)
                surrbefore = lossbefore[0]
                stepsize = 1.0
                thbefore = get_flat()
                for _ in range(10):
                    thnew = thbefore + fullstep * stepsize
                    set_from_flat(thnew)
                    meanlosses = surr, kl, *_ = allmean(np.array(compute_losses(*args)))
                    improve = surr - surrbefore
                    logger.log("Expected: %.3f Actual: %.3f" % (expectedimprove, improve))
                    if not np.isfinite(meanlosses).all():
                        logger.log("Got non-finite value of losses -- bad!")
                    elif kl > max_kl * 1.5:
                        logger.log("violated KL constraint. shrinking step.")
                    elif improve < 0:
                        logger.log("surrogate didn't improve. shrinking step.")
                    else:
                        logger.log("Stepsize OK!")
                        break
                    stepsize *= .5
                else:
                    logger.log("couldn't compute a good step")
                    set_from_flat(thbefore)
                if nworkers > 1 and iters_so_far % 20 == 0:
                    paramsums = MPI.COMM_WORLD.allgather(
                        (thnew.sum(), vfadam.getflat().sum()))  # list of tuples
                    assert all(np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

            for (lossname, lossval) in zip(loss_names, meanlosses):
                logger.record_tabular(lossname, lossval)

            with timed("vf"):
                for _ in range(vf_iters):
                    for (mbob, mbret) in dataset.iterbatches(
                            (seg["ob"], seg["tdlamret"]),
                            include_final_partial_batch=False, batch_size=64):
                        g = allmean(compute_vflossandgrad(mbob, mbret))
                        vfadam.update(g, vf_stepsize)

            with timed("SIL"):
                # Fraction of training remaining, for annealed learning rates.
                # total_timesteps may be 0 when the run is bounded by
                # max_iters / max_episodes; dividing by it used to crash.
                if total_timesteps:
                    frac = 1.0 - timesteps_so_far / total_timesteps
                elif max_iters:
                    frac = 1.0 - iters_so_far / max_iters
                else:
                    frac = 1.0
                lrnow = lr(frac)
                l_loss, sil_adv, sil_samples, sil_nlogp = model.sil_train(lrnow)

            logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        if MPI is not None:
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        else:
            listoflrpairs = [lrlocal]
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if sil_update > 0:
            logger.record_tabular("SilSamples", sil_samples)

        if rank == 0:
            # MBL evaluation
            if not collect_val_data:
                default_sess = tf.get_default_session()

                def multithread_eval_policy(env_, pi_, num_episodes_, vis_eval_, seed):
                    """Evaluate pi_ on env_ in the training session; always tries to close env_."""
                    with default_sess.as_default():
                        if hasattr(env, 'ob_rms') and hasattr(env_, 'ob_rms'):
                            env_.ob_rms = env.ob_rms
                        try:
                            return eval_policy(env_, pi_, num_episodes_, vis_eval_, seed,
                                               measure_time, measure_rew)
                        finally:
                            try:
                                env_.close()
                            except Exception:
                                # Closing the throw-away eval env is best effort.
                                pass

                if mbl.is_warm_start_done() and iters_so_far % eval_freq == 0:
                    if num_eval_episodes > 0:
                        targs_names = {}
                        with timed('eval'):
                            num_descs = len(all_eval_descs)
                            list_field_names = [e[0] for e in all_eval_descs]
                            list_legend_names = [e[1] for e in all_eval_descs]
                            list_pis = [e[2] for e in all_eval_descs]
                            list_eval_envs = [make_eval_env() for _ in range(num_descs)]
                            print(list_field_names)
                            print(list_legend_names)

                            for field_name, legend_name, eval_pi, eval_env_ in zip(
                                    list_field_names, list_legend_names, list_pis,
                                    list_eval_envs):
                                perf, elapsed_time, eval_rew = multithread_eval_policy(
                                    eval_env_, eval_pi, num_eval_episodes, vis_eval, seed)
                                logger.record_tabular(field_name, perf)
                                if measure_time:
                                    logger.record_tabular('Time-%s' % (field_name), elapsed_time)
                                if measure_rew:
                                    logger.record_tabular('SimRew-%s' % (field_name), eval_rew)
                                targs_names[field_name] = legend_name

                if eval_val_err:
                    fwd_dynamics_err = mbl.eval_forward_dynamic(
                        obs=eval_val_dataset['ob'], acs=eval_val_dataset['ac'],
                        obs_next=eval_val_dataset['ob_next'])
                    logger.record_tabular('FwdValError', fwd_dynamics_err)

                # NOTE(review): placement reconstructed from collapsed source —
                # with collect_val_data set, tabular logs are never dumped; confirm.
                logger.dump_tabular()
        yield pi

    if collect_val_data:
        with open(validation_set_path, 'wb') as f:
            pickle.dump(val_dataset_collect, f)
        logger.log('Save {} validation data'.format(len(val_dataset_collect)))
def learn(*,
          network,
          env,
          total_timesteps,
          timesteps_per_batch=1024,  # what to train on
          max_kl=0.002,
          cg_iters=10,
          gamma=0.99,
          lam=1.0,  # advantage estimation
          seed=None,
          ent_coef=0.00,
          cg_damping=1e-2,
          vf_stepsize=3e-4,
          vf_iters=3,
          max_episodes=0,
          max_iters=0,  # time constraint
          callback=None,
          load_path=None,
          num_reward=1,
          **network_kwargs):
    '''
    Learn a policy with a multi-reward variant of the TRPO algorithm.

    For each of the ``num_reward`` reward channels a separate surrogate
    gradient and conjugate-gradient step direction are computed.  They are
    stacked into ``G`` (gradients) and ``S`` (step directions),
    ``get_coefficient(G, S)`` returns combination weights ``coe``, and the
    combined direction ``coe @ S`` is scaled to the KL bound and
    line-searched like a single-reward TRPO step.

    Parameters:
    ----------

    network                 neural network to learn. Can be either string ('mlp', 'cnn', 'lstm', 'lnlstm' for basic types)
                            or function that takes input placeholder and returns tuple (output, None) for feedforward nets
                            or (output, (state_placeholder, state_output, mask_placeholder)) for recurrent nets

    env                     environment (one of the gym environments or wrapped via baselines.common.vec_env.VecEnv-type class

    timesteps_per_batch     timesteps per gradient estimation batch

    max_kl                  max KL divergence between old policy and new policy ( KL(pi_old || pi) )

    ent_coef                coefficient of policy entropy term in the optimization objective

    cg_iters                number of iterations of conjugate gradient algorithm

    cg_damping              conjugate gradient damping

    vf_stepsize             learning rate for adam optimizer used to optimize value function loss

    vf_iters                number of iterations of value function optimization iterations per each policy optimization step

    total_timesteps         max number of timesteps

    max_episodes            max number of episodes

    max_iters               maximum number of policy optimization iterations

    callback                function to be called with (locals(), globals()) each policy optimization step

    load_path               str, path to load the model from (default: None, i.e. no model is loaded)

    num_reward              number of reward channels. The returns placeholder has
                            ``num_reward`` columns and advantages are standardized
                            and optimized column by column (``atarg[:, i]``).

    **network_kwargs        keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and arguments to a particular type of network

    Returns:
    -------

    learnt model (``pi``).  Returns immediately, untrained, when none of
    ``max_iters``, ``total_timesteps`` and ``max_episodes`` is positive.

    Side effects:
    ------------

    Saves ``coe.npy`` (combination weights), ``grad.npy`` (per-channel and
    combined full-step norms), ``improve.npy`` (per-channel and combined
    change of ``optimgain`` for accepted steps) and ``adj.npy`` (KL-bound
    multipliers) into ``logger.get_dir().split('Data')[-2] + 'log/l1/seed<k>/'``
    where ``k`` is the last character of the log dir.  This raises
    IndexError if the log dir does not contain ``'Data'``.
    '''
    import traceback

    if MPI is not None:
        nworkers = MPI.COMM_WORLD.Get_size()
        rank = MPI.COMM_WORLD.Get_rank()
    else:
        nworkers = 1
        rank = 0

    cpus_per_worker = 1
    U.get_session(
        config=tf.ConfigProto(allow_soft_placement=True,
                              inter_op_parallelism_threads=cpus_per_worker,
                              intra_op_parallelism_threads=cpus_per_worker))

    set_global_seeds(seed)
    # Build the policy constructor.
    policy = build_policy(env,
                          network,
                          value_network='copy',
                          num_reward=num_reward,
                          **network_kwargs)
    # Diagnostics directory derived from the logger dir (see docstring).
    process_dir = logger.get_dir()
    save_dir = process_dir.split(
        'Data')[-2] + 'log/l1/seed' + process_dir[-1] + '/'
    os.makedirs(save_dir, exist_ok=True)
    coe_save = []
    impro_save = []
    grad_save = []
    adj_save = []
    coe = np.ones((num_reward)) / num_reward

    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space

    #################################################################
    # ob, ac, ret and atarg are placeholders; ret holds one column per reward.
    ob = observation_placeholder(ob_space)

    # Current policy and the frozen "old" policy used for the ratio / KL.
    with tf.variable_scope("pi"):
        pi = policy(observ_placeholder=ob)
    with tf.variable_scope("oldpi"):
        oldpi = policy(observ_placeholder=ob)

    # A single advantage column is fed at a time (one per reward channel).
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32,
                         shape=[None, num_reward])  # Empirical return

    ac = pi.pdtype.sample_placeholder([None])

    # KL divergence and entropy do not depend on the reward channel.
    ##################################
    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    # entbonus is the entropy term of the objective.
    entbonus = ent_coef * meanent
    #################################

    ###########################################################
    # vferr trains the value network (all reward columns at once).
    vferr = tf.reduce_mean(tf.square(pi.vf - ret))

    ratio = tf.exp(pi.pd.logp(ac) -
                   oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    # optimgain is the policy objective; it is evaluated once per reward channel.
    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]
    ###########################################################

    dist = meankl

    # Policy / value-function variables and the value-function Adam optimizer.
    var_list = get_pi_trainable_variables("pi")
    vf_var_list = get_vf_trainable_variables("pi")

    vfadam = MpiAdam(vf_var_list)

    # Flattens the policy variables into one vector ...
    get_flat = U.GetFlat(var_list)
    # ... and scatters a flat vector back into them.
    set_from_flat = U.SetFromFlat(var_list)
    # Gradient of the KL divergence.
    klgrads = tf.gradients(dist, var_list)
    ####################################################################
    # Flat tangent vector, split back into per-variable tensors.
    flat_tangent = tf.placeholder(dtype=tf.float32,
                                  shape=[None],
                                  name="flat_tan")
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    ####################################################################

    ####################################################################
    # gvp = <grad KL, tangent>; its gradient is the Fisher-vector product.
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(klgrads, tangents)
    ])  #pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)
    ####################################################################

    # Copy the current policy parameters into the old policy.
    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv) for (oldv, newv) in zipsame(
                get_variables("oldpi"), get_variables("pi"))
        ])

    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses +
                                     [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret],
                                       U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        """Print ``msg`` and the elapsed wall time on rank 0 only."""
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='magenta'))
        else:
            yield

    def allmean(x):
        """Average a numpy array across MPI workers (copy when MPI is absent)."""
        assert isinstance(x, np.ndarray)
        if MPI is not None:
            out = np.empty_like(x)
            MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
            out /= nworkers
        else:
            out = np.copy(x)

        return out

    U.initialize()
    if load_path is not None:
        pi.load(load_path)

    # Broadcast rank 0's initial parameters so all workers start identical.
    th_init = get_flat()
    if MPI is not None:
        MPI.COMM_WORLD.Bcast(th_init, root=0)

    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi,
                                     env,
                                     timesteps_per_batch,
                                     stochastic=True,
                                     num_reward=num_reward)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    if sum([max_iters > 0, total_timesteps > 0, max_episodes > 0]) == 0:
        # nothing to be done
        return pi

    assert sum([max_iters>0, total_timesteps>0, max_episodes>0]) < 2, \
        'out of max_iters, total_timesteps, and max_episodes only one should be specified'

    while True:
        if callback: callback(locals(), globals())
        if total_timesteps and timesteps_so_far >= total_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************" % iters_so_far)

        with timed("sampling"):
            seg = seg_gen.__next__()
        # Compute advantages / lambda-returns for every reward channel.
        add_vtarg_and_adv(seg, gamma, lam, num_reward=num_reward)

        # Numpy data for this batch.  The observations stay in seg["ob"] so the
        # name ``ob`` keeps referring to the placeholder.
        _, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg[
            "tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        # Standardize each reward channel's advantages independently.
        atarg = (atarg - np.mean(atarg, axis=0)) / np.std(
            atarg, axis=0)  # standardized advantage function estimate

        if hasattr(pi, "ret_rms"): pi.ret_rms.update(tdlamret)
        # Feed the sampled observations, not the ``ob`` placeholder.
        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(seg["ob"])  # update running mean/std for policy

        ## set old parameter values to new parameter values
        assign_old_eq_new()
        G = None  # stacked per-channel gradients
        S = None  # stacked per-channel CG step directions
        mr_lossbefore = np.zeros((num_reward, len(loss_names)))
        grad_norm = np.zeros((num_reward + 1))
        for i in range(num_reward):
            args = seg["ob"], seg["ac"], atarg[:, i]
            # Subsample every 5th timestep for the Fisher-vector products.
            fvpargs = [arr[::5] for arr in args]

            # Fisher matrix times vector p, with damping.
            def fisher_vector_product(p):
                return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

            with timed("computegrad of " + str(i + 1) + ".th reward"):
                *lossbefore, g = compute_lossandgrad(*args)
            lossbefore = allmean(np.array(lossbefore))
            mr_lossbefore[i] = lossbefore
            g = allmean(g)
            if isinstance(G, np.ndarray):
                G = np.vstack((G, g))
            else:
                G = g
            # g is the gradient of the objective; CG gives the update direction.
            if np.allclose(g, 0):
                # NOTE(review): no row is added to S here, so S can end up with
                # fewer rows than G.
                logger.log("Got zero gradient. not updating")
            else:
                with timed("cg of " + str(i + 1) + ".th reward"):
                    # stepdir is the update direction for this channel
                    stepdir = cg(fisher_vector_product,
                                 g,
                                 cg_iters=cg_iters,
                                 verbose=rank == 0)
                shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
                lm = np.sqrt(shs / max_kl)
                fullstep = stepdir / lm
                grad_norm[i] = np.linalg.norm(fullstep)
                assert np.isfinite(stepdir).all()
                if isinstance(S, np.ndarray):
                    S = np.vstack((S, stepdir))
                else:
                    S = stepdir

        new_coe = get_coefficient(G, S)
        #coe = 0.99 * coe + 0.01 * new_coe
        coe = new_coe
        coe_save.append(coe)
        # Scale the KL bound by the mean pairwise cosine of the step directions.
        try:
            GG = np.dot(S, S.T)
            D = np.sqrt(np.diag(1 / np.diag(GG)))
            GG = np.dot(np.dot(D, GG), D)
            adj = np.sum(GG) / (num_reward**2)
        except Exception:
            # e.g. num_reward == 1, where S is 1-D and np.diag cannot be applied
            adj = 1
        try:
            # NOTE(review): this overrides the adj computed above, so the KL
            # bound is always max_kl; kept as in the original.
            adj = 1
            adj_save.append(adj)
            adj_max_kl = adj * max_kl
            #################################################################
            grad_norm = grad_norm * np.sqrt(adj)
            # Combine the per-channel quantities with the coefficients.
            stepdir = np.dot(coe, S)
            g = np.dot(coe, G)
            lossbefore = np.dot(coe, mr_lossbefore)
            #################################################################
            shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
            lm = np.sqrt(shs / adj_max_kl)
            fullstep = stepdir / lm
            grad_norm[num_reward] = np.linalg.norm(fullstep)
            grad_save.append(grad_norm)
            expectedimprove = g.dot(fullstep)
            surrbefore = lossbefore[0]
            stepsize = 1.0
            thbefore = get_flat()

            def compute_mr_losses():
                """Return (coe-weighted losses, per-channel loss matrix)."""
                mr_losses = np.zeros((num_reward, len(loss_names)))
                for i in range(num_reward):
                    args = seg["ob"], seg["ac"], atarg[:, i]
                    one_reward_loss = allmean(
                        np.array(compute_losses(*args)))
                    mr_losses[i] = one_reward_loss
                mr_loss = np.dot(coe, mr_losses)
                return mr_loss, mr_losses

            # Backtracking line search: at most 10 halvings.
            for _ in range(10):
                thnew = thbefore + fullstep * stepsize
                set_from_flat(thnew)
                mr_loss_new, mr_losses_new = compute_mr_losses()
                mr_impro = mr_losses_new - mr_lossbefore
                meanlosses = surr, kl, *_ = allmean(np.array(mr_loss_new))
                improve = surr - surrbefore
                logger.log("Expected: %.3f Actual: %.3f" %
                           (expectedimprove, improve))
                if not np.isfinite(meanlosses).all():
                    logger.log("Got non-finite value of losses -- bad!")
                elif kl > adj_max_kl * 1.5:
                    logger.log("violated KL constraint. shrinking step.")
                elif improve < 0:
                    logger.log("surrogate didn't improve. shrinking step.")
                else:
                    logger.log("Stepsize OK!")
                    impro_save.append(np.hstack((mr_impro[:, 0], improve)))
                    break
                stepsize *= .5
            else:
                logger.log("couldn't compute a good step")
                set_from_flat(thbefore)
            if nworkers > 1 and iters_so_far % 20 == 0:
                paramsums = MPI.COMM_WORLD.allgather(
                    (thnew.sum(), vfadam.getflat().sum()))  # list of tuples
                assert all(
                    np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

            for (lossname, lossval) in zip(loss_names, meanlosses):
                logger.record_tabular(lossname, lossval)

            with timed("vf"):
                for _ in range(vf_iters):
                    for (mbob, mbret) in dataset.iterbatches(
                        (seg["ob"], seg["tdlamret"]),
                            include_final_partial_batch=False,
                            batch_size=64):
                        g = allmean(compute_vflossandgrad(mbob, mbret))
                        vfadam.update(g, vf_stepsize)
        except Exception:
            # Best-effort: keep training, but do not hide what went wrong.
            print('error')
            traceback.print_exc()

        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        if MPI is not None:
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        else:
            listoflrpairs = [lrlocal]

        lens, rews = map(flatten_lists, zip(*listoflrpairs))

        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if rank == 0:
            logger.dump_tabular()

    np.save(save_dir + 'coe.npy', coe_save)
    np.save(save_dir + 'grad.npy', grad_save)
    np.save(save_dir + 'improve.npy', impro_save)
    np.save(save_dir + 'adj.npy', adj_save)
    return pi
def learn(env,
          policy_func,
          reward_giver,
          expert_dataset,
          rank,
          pretrained,
          pretrained_weight,
          *,
          g_step,
          d_step,
          entcoeff,
          save_per_iter,
          ckpt_dir,
          log_dir,
          timesteps_per_batch,
          task_name,
          gamma,
          lam,
          max_kl,
          cg_iters,
          cg_damping=1e-2,
          vf_stepsize=3e-4,
          d_stepsize=3e-4,
          vf_iters=3,
          max_timesteps=0,
          max_episodes=0,
          max_iters=0,
          vf_batchsize=128,
          callback=None,
          freeze_g=False,
          freeze_d=False,
          semi_dataset=None,
          semi_loss=False):
    """Train a GAIL agent: TRPO generator updates against a learned discriminator.

    Each outer iteration runs up to ``g_step`` TRPO policy updates and
    value-function updates on freshly sampled segments.  Unless ``freeze_d``
    is set, it then updates the discriminator (``reward_giver``) on
    minibatches that pair generator samples with batches from
    ``expert_dataset``.

    Args:
        env: environment passed to ``traj_segment_generator``.
        policy_func: builds a policy as ``policy_func(name, ob_space=..., ac_space=..., ...)``.
        reward_giver: discriminator; must provide ``get_trainable_variables``,
            ``get_reward``, ``lossandgrad`` and ``loss_name`` (``obs_rms`` is optional).
        expert_dataset: provides ``num_traj`` and ``get_next_batch(n)``.
        rank: ignored; the MPI rank is read from ``MPI.COMM_WORLD``.
        pretrained: unused in this function.
        pretrained_weight: checkpoint to restore.  If the path ends in
            ``iter_<N>``, the iteration/timestep counters resume from N + 1.
        g_step: number of generator updates per outer iteration.
        d_step: approximate number of discriminator minibatches per iteration.
        ckpt_dir: checkpoint directory; ``None`` disables checkpointing.
        log_dir: summary-writer directory (rank 0 only).
        timesteps_per_batch: total timesteps per batch, split (ceil) across workers.
        max_timesteps, max_episodes, max_iters: stopping criteria; exactly one
            must be positive.
        freeze_g, freeze_d: see the ``not_update`` logic below; ``freeze_d``
            skips discriminator updates.
        semi_dataset, semi_loss: optional semi-supervised data; the weighted
            L2 term is reported only when both are set.
        Remaining arguments are TRPO / optimizer hyperparameters.
    """
    semi_loss = semi_loss and semi_dataset is not None
    l2_w = 0.1  # weight of the L2 (semi) loss subtracted from rewards

    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()  # overrides the ``rank`` argument
    np.set_printoptions(precision=3)

    if rank == 0:
        writer = U.file_writer(log_dir)

        # print all the hyperparameters in the log...
        log_dict = {
            "expert trajectories": expert_dataset.num_traj,
            "algo": "trpo",
            "threads": nworkers,
            "timesteps_per_batch": timesteps_per_batch,
            "timesteps_per_thread": -(-timesteps_per_batch // nworkers),
            "entcoeff": entcoeff,
            "vf_iters": vf_iters,
            "vf_batchsize": vf_batchsize,
            "vf_stepsize": vf_stepsize,
            "d_stepsize": d_stepsize,
            "g_step": g_step,
            "d_step": d_step,
            "max_kl": max_kl,
            "gamma": gamma,
            "lam": lam,
            "l2_weight": l2_w
        }

        if semi_dataset is not None:
            log_dict["semi trajectories"] = semi_dataset.num_traj
        if hasattr(semi_dataset, 'info'):
            log_dict["semi_dataset_info"] = semi_dataset.info

        # print them all together for csv
        logger.log(",".join([str(elem) for elem in log_dict]))
        logger.log(",".join([str(elem) for elem in log_dict.values()]))

        # also print them separately for easy reading:
        for elem in log_dict:
            logger.log(str(elem) + ": " + str(log_dict[elem]))

    # divide the timesteps to the threads
    timesteps_per_batch = -(-timesteps_per_batch // nworkers
                            )  # get ceil of division

    # Setup losses and stuff
    # ----------------------------------------
    if semi_dataset:
        ob_space = semi_ob_space(env, semi_size=semi_dataset.semi_size)
    else:
        ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi",
                     ob_space=ob_space,
                     ac_space=ac_space,
                     reuse=(pretrained_weight is not None))
    oldpi = policy_func("oldpi", ob_space=ob_space, ac_space=ac_space)
    atarg = tf.placeholder(
        dtype=tf.float32,
        shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = entcoeff * meanent

    vferr = tf.reduce_mean(tf.square(pi.vpred - ret))

    ratio = tf.exp(pi.pd.logp(ac) -
                   oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    vf_losses = [vferr]
    vf_loss_names = ["vf_loss"]

    dist = meankl

    # Split the policy's variables into policy ("pi/pol", "pi/logstd") and
    # value-function ("pi/vf") parts by name prefix.
    all_var_list = pi.get_trainable_variables()
    var_list = [
        v for v in all_var_list
        if v.name.startswith("pi/pol") or v.name.startswith("pi/logstd")
    ]
    vf_var_list = [v for v in all_var_list if v.name.startswith("pi/vf")]
    assert len(var_list) == len(vf_var_list) + 1
    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32,
                                  shape=[None],
                                  name="flat_tan")
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    # Gradient of <grad KL, tangent> gives the Fisher-vector product.
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(klgrads, tangents)
    ])  # pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv) for (oldv, newv) in zipsame(
                oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg], losses)
    compute_vf_losses = U.function([ob, ac, atarg, ret], losses + vf_losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses +
                                     [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret], vf_losses +
                                       [U.flatgrad(vferr, vf_var_list)])

    @contextmanager
    def timed(msg):
        """Print ``msg`` and the elapsed wall time on rank 0 only."""
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='magenta'))
        else:
            yield

    def allmean(x):
        """Average a numpy array across all MPI workers."""
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards
    true_rewbuffer = deque(maxlen=40)
    success_buffer = deque(maxlen=40)
    l2_rewbuffer = deque(
        maxlen=40) if semi_loss and semi_dataset is not None else None
    total_rewbuffer = deque(
        maxlen=40) if semi_loss and semi_dataset is not None else None

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    # G updates are skipped while not_update is non-zero.
    not_update = 1 if not freeze_d else 0  # do not update G before D the first time
    # if provide pretrained weight
    loaded = False
    if not U.load_checkpoint_variables(pretrained_weight):
        if U.load_checkpoint_variables(pretrained_weight,
                                       check_prefix=get_il_prefix()):
            if rank == 0:
                logger.log("loaded checkpoint variables from " +
                           pretrained_weight)
            loaded = True
    else:
        loaded = True
    if loaded:
        # Allow G to update immediately if any already-initialized variable
        # name contains "adversary" (presumably restored discriminator weights).
        not_update = 0 if any([
            x.op.name.find("adversary") != -1 for x in U.ALREADY_INITIALIZED
        ]) else 1
        # Resume the counters from an ".../iter_<N>" checkpoint path.
        # str.rfind returns -1 when the marker is absent, so compare explicitly
        # instead of relying on truthiness (0 is a valid position).
        iter_pos = pretrained_weight.rfind("iter_") if pretrained_weight else -1
        iter_suffix = (pretrained_weight[iter_pos + len("iter_"):]
                       if iter_pos != -1 else "")
        if iter_suffix.isdigit():
            curr_iter = int(iter_suffix) + 1
            print("loaded checkpoint at iteration: " + str(curr_iter))
            iters_so_far = curr_iter
            timesteps_so_far = iters_so_far * timesteps_per_batch

    d_adam = MpiAdam(reward_giver.get_trainable_variables())
    vfadam = MpiAdam(vf_var_list)

    U.initialize()

    # Broadcast rank 0's policy parameters so all workers start identical.
    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    d_adam.sync()
    vfadam.sync()
    if rank == 0:
        print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(
        pi,
        env,
        reward_giver,
        timesteps_per_batch,
        stochastic=True,
        semi_dataset=semi_dataset,
        semi_loss=semi_loss)  # ADD L2 loss to semi trajectories

    g_loss_stats = stats(loss_names + vf_loss_names)
    d_loss_stats = stats(reward_giver.loss_name)
    ep_names = ["True_rewards", "Rewards", "Episode_length", "Success"]
    if semi_loss and semi_dataset is not None:
        ep_names.append("L2_loss")
        ep_names.append("total_rewards")
    ep_stats = stats(ep_names)

    if rank == 0:
        start_time = time.time()
        ch_count = 0
        env_type = env.env.env.__class__.__name__
        # Reuse one Saver: building a new tf.train.Saver per save adds ops to
        # the graph every time.  max_to_keep=None keeps every checkpoint, as
        # the previous one-Saver-per-save code never deleted old ones.
        saver = (tf.train.Saver(max_to_keep=None)
                 if ckpt_dir is not None else None)

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break

        # Save model
        if rank == 0 and iters_so_far % save_per_iter == 0 and ckpt_dir is not None:
            fname = os.path.join(ckpt_dir, task_name)
            if env_type.find(
                    "Pendulum"
            ) != -1 or save_per_iter != 1:  # allow pendulum to save all iterations
                fname = os.path.join(ckpt_dir, 'iter_' + str(iters_so_far),
                                     'iter_' + str(iters_so_far))
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            saver.save(tf.get_default_session(),
                       fname,
                       write_meta_graph=False)
        # Save a different checkpoint every hour (only if checkpointing is on).
        if rank == 0 and ckpt_dir is not None and \
                time.time() - start_time >= 3600 * ch_count:
            fname = os.path.join(ckpt_dir, 'hour' + str(ch_count).zfill(3))
            fname = os.path.join(fname, 'iter_' + str(iters_so_far))
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            saver.save(tf.get_default_session(),
                       fname,
                       write_meta_graph=False)
            ch_count += 1

        logger.log("********** Iteration %i ************" % iters_so_far)

        # Fisher-vector product on the current subsampled batch (fvpargs is
        # looked up at call time, after it is set in the G loop below).
        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        # ------------------ Update G ------------------
        logger.log("Optimizing Policy...")
        for curr_step in range(g_step):
            with timed("sampling"):
                seg = seg_gen.__next__()

            # Subtract the weighted L2 (semi) loss from the rewards before
            # advantage estimation.
            seg["rew"] = seg["rew"] - seg["l2_loss"] * l2_w

            add_vtarg_and_adv(seg, gamma, lam)
            ob, ac, atarg, tdlamret, full_ob = seg["ob"], seg["ac"], seg[
                "adv"], seg["tdlamret"], seg["full_ob"]
            vpredbefore = seg[
                "vpred"]  # predicted value function before update
            atarg = (atarg - atarg.mean()) / atarg.std(
            )  # standardized advantage function estimate

            d = Dataset(dict(ob=full_ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                        shuffle=True)

            if not_update:
                break  # stop G from updating

            if hasattr(pi, "ob_rms"):
                pi.ob_rms.update(full_ob)  # update running mean/std for policy

            args = seg["full_ob"], seg["ac"], atarg
            fvpargs = [arr[::5] for arr in args]

            assign_old_eq_new(
            )  # set old parameter values to new parameter values
            with timed("computegrad"):
                *lossbefore, g = compute_lossandgrad(*args)
            lossbefore = allmean(np.array(lossbefore))
            g = allmean(g)
            if np.allclose(g, 0):
                logger.log("Got zero gradient. not updating")
            else:
                with timed("cg"):
                    stepdir = cg(fisher_vector_product,
                                 g,
                                 cg_iters=cg_iters,
                                 verbose=rank == 0)
                assert np.isfinite(stepdir).all()
                shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
                lm = np.sqrt(shs / max_kl)
                fullstep = stepdir / lm
                expectedimprove = g.dot(fullstep)
                surrbefore = lossbefore[0]
                stepsize = 1.0
                thbefore = get_flat()
                # Backtracking line search: at most 10 halvings.
                for _ in range(10):
                    thnew = thbefore + fullstep * stepsize
                    set_from_flat(thnew)
                    meanlosses = surr, kl, *_ = allmean(
                        np.array(compute_losses(*args)))
                    if rank == 0:
                        print("Generator entropy " + str(meanlosses[4]) +
                              ", loss " + str(meanlosses[2]))
                    improve = surr - surrbefore
                    logger.log("Expected: %.3f Actual: %.3f" %
                               (expectedimprove, improve))
                    if not np.isfinite(meanlosses).all():
                        logger.log("Got non-finite value of losses -- bad!")
                    elif kl > max_kl * 1.5:
                        logger.log("violated KL constraint. shrinking step.")
                    elif improve < 0:
                        logger.log("surrogate didn't improve. shrinking step.")
                    else:
                        logger.log("Stepsize OK!")
                        break
                    stepsize *= .5
                else:
                    logger.log("couldn't compute a good step")
                    set_from_flat(thbefore)
                if nworkers > 1 and iters_so_far % 20 == 0:
                    paramsums = MPI.COMM_WORLD.allgather(
                        (thnew.sum(),
                         vfadam.getflat().sum()))  # list of tuples
                    assert all(
                        np.allclose(ps, paramsums[0])
                        for ps in paramsums[1:])
            with timed("vf"):
                logger.log(fmt_row(13, vf_loss_names))
                for _ in range(vf_iters):
                    vf_b_losses = []
                    for batch in d.iterate_once(vf_batchsize):
                        mbob = batch["ob"]
                        mbret = batch["vtarg"]
                        if hasattr(pi, "ob_rms"):
                            pi.ob_rms.update(
                                mbob)  # update running mean/std for policy
                        *newlosses, g = compute_vflossandgrad(mbob, mbret)
                        g = allmean(g)
                        newlosses = allmean(np.array(newlosses))
                        vfadam.update(g, vf_stepsize)
                        vf_b_losses.append(newlosses)
                    logger.log(fmt_row(13, np.mean(vf_b_losses, axis=0)))

            logger.log("Evaluating losses...")
            losses = []
            for batch in d.iterate_once(vf_batchsize):
                newlosses = compute_vf_losses(batch["ob"], batch["ac"],
                                              batch["atarg"], batch["vtarg"])
                losses.append(newlosses)
            meanlosses, _, _ = mpi_moments(losses, axis=0)

            #########################
            # For evaluation during training.
            # Can be commented out for faster training...
            for ob_batch, ac_batch, full_ob_batch in dataset.iterbatches(
                (ob, ac, full_ob),
                    include_final_partial_batch=False,
                    batch_size=len(ob)):
                ob_expert, ac_expert = expert_dataset.get_next_batch(
                    len(ob_batch))
                exp_rew = 0
                for obs, acs in zip(ob_expert, ac_expert):
                    exp_rew += 1 - np.exp(
                        -reward_giver.get_reward(obs, acs)[0][0])
                mean_exp_rew = exp_rew / len(ob_expert)
                gen_rew = 0
                for obs, acs, full_obs in zip(ob_batch, ac_batch,
                                              full_ob_batch):
                    gen_rew += 1 - np.exp(
                        -reward_giver.get_reward(obs, acs)[0][0])
                mean_gen_rew = gen_rew / len(ob_batch)
                if rank == 0:
                    logger.log("Generator step " + str(curr_step) +
                               ": Dicriminator reward of expert traj " +
                               str(mean_exp_rew) + " vs gen traj " +
                               str(mean_gen_rew))
            #########################

        if not not_update:
            g_losses = meanlosses
            for (lossname, lossval) in zip(loss_names + vf_loss_names,
                                           meanlosses):
                logger.record_tabular(lossname, lossval)
            logger.record_tabular("ev_tdlam_before",
                                  explained_variance(vpredbefore, tdlamret))

        # ------------------ Update D ------------------
        if not freeze_d:
            logger.log("Optimizing Discriminator...")
            # Clamp to 1 so a batch shorter than d_step cannot yield batch size 0.
            batch_size = max(1, len(ob) // d_step)
            d_losses = [
            ]  # list of tuples, each of which gives the loss for a minibatch
            for ob_batch, ac_batch, full_ob_batch in dataset.iterbatches(
                (ob, ac, full_ob),
                    include_final_partial_batch=False,
                    batch_size=batch_size):
                ob_expert, ac_expert = expert_dataset.get_next_batch(
                    len(ob_batch))
                #########################
                # For evaluation during training.
                # Can be commented out for faster training...
                exp_rew = 0
                for obs, acs in zip(ob_expert, ac_expert):
                    exp_rew += 1 - np.exp(
                        -reward_giver.get_reward(obs, acs)[0][0])
                mean_exp_rew = exp_rew / len(ob_expert)
                gen_rew = 0
                for obs, acs in zip(ob_batch, ac_batch):
                    gen_rew += 1 - np.exp(
                        -reward_giver.get_reward(obs, acs)[0][0])
                mean_gen_rew = gen_rew / len(ob_batch)
                if rank == 0:
                    logger.log("Dicriminator reward of expert traj " +
                               str(mean_exp_rew) + " vs gen traj " +
                               str(mean_gen_rew))
                #########################
                # update running mean/std for reward_giver
                if hasattr(reward_giver, "obs_rms"):
                    reward_giver.obs_rms.update(
                        np.concatenate((ob_batch, ob_expert), 0))
                loss_input = (ob_batch, ac_batch, ob_expert, ac_expert)
                *newlosses, g = reward_giver.lossandgrad(*loss_input)
                d_adam.update(allmean(g), d_stepsize)
                d_losses.append(newlosses)
            logger.log(fmt_row(13, reward_giver.loss_name))
            logger.log(fmt_row(13, np.mean(d_losses, axis=0)))

        lrlocal = (seg["ep_lens"], seg["ep_rets"], seg["ep_true_rets"],
                   seg["ep_success"], seg["ep_semi_loss"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews, true_rets, success, semi_losses = map(
            flatten_lists, zip(*listoflrpairs))

        # success
        success = [
            float(elem) for elem in success
            if isinstance(elem, (int, float, bool))
        ]  # remove potential None types
        if not success:
            success = [-1]  # set success to -1 if env has no success flag
        success_buffer.extend(success)

        true_rewbuffer.extend(true_rets)
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        if semi_loss and semi_dataset is not None:
            semi_losses = [elem * l2_w for elem in semi_losses]
            total_rewards = rews
            total_rewards = [
                re_elem - l2_elem
                for re_elem, l2_elem in zip(total_rewards, semi_losses)
            ]
            l2_rewbuffer.extend(semi_losses)
            total_rewbuffer.extend(total_rewards)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpTrueRewMean", np.mean(true_rewbuffer))
        logger.record_tabular("EpSuccess", np.mean(success_buffer))

        if semi_loss and semi_dataset is not None:
            logger.record_tabular("EpSemiLoss", np.mean(l2_rewbuffer))
            logger.record_tabular("EpTotalReward", np.mean(total_rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        logger.record_tabular("ItersSoFar", iters_so_far)

        if rank == 0:
            logger.dump_tabular()
            if not not_update:
                g_loss_stats.add_all_summary(writer, g_losses, iters_so_far)
            if not freeze_d:
                d_loss_stats.add_all_summary(writer,
                                             np.mean(d_losses, axis=0),
                                             iters_so_far)

            # default buffers
            ep_buffers = [
                np.mean(true_rewbuffer),
                np.mean(rewbuffer),
                np.mean(lenbuffer),
                np.mean(success_buffer)
            ]

            if semi_loss and semi_dataset is not None:
                ep_buffers.append(np.mean(l2_rewbuffer))
                ep_buffers.append(np.mean(total_rewbuffer))

            ep_stats.add_all_summary(writer, ep_buffers, iters_so_far)

        # Count down the G warm-up on every rank (unless G is frozen).
        if not_update and not freeze_g:
            not_update -= 1
def learn(
        env,
        policy_fn,
        reward_giver,
        expert_dataset,
        *,
        timesteps_per_actorbatch,  # timesteps per actor per update
        clip_param,
        entcoeff,  # clipping parameter epsilon, entropy coeff
        optim_epochs,
        optim_stepsize,
        optim_batchsize,  # optimization hypers
        gamma,
        lam,  # advantage estimation
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,
        max_seconds=0,  # time constraint
        callback=None,  # you can do anything in the callback, since it takes locals(), globals()
        adam_epsilon=1e-5,
        schedule='constant'  # annealing for stepsize parameters (epsilon and adam)
):
    """Train a policy with PPO's clipped surrogate and a learned discriminator.

    Each iteration:
      1. samples ``timesteps_per_actorbatch`` steps with ``pi`` via
         ``traj_segment_generator`` (which also receives ``reward_giver`` and a
         ``mujoco_py`` viewer; presumably it uses the discriminator to produce
         rewards -- confirm in that helper),
      2. runs ``optim_epochs`` epochs of minibatch Adam on
         ``pol_surr + pol_entpen + vf_loss``,
      3. takes one discriminator step on ``reward_giver`` with an expert batch
         of the same size as the rollout,
      4. logs losses and episode statistics (rank 0 dumps the table).

    Parameters
    ----------
    env : environment exposing ``observation_space``, ``action_space`` and
        ``sim`` (the latter is handed to ``mujoco_py.MjViewer``).
    policy_fn : callable ``(name, ob_space, ac_space) -> policy``.
    reward_giver : discriminator providing ``get_trainable_variables()``,
        ``lossandgrad(ob, ac, ob_expert, ac_expert)``, ``loss_name`` and
        optionally ``obs_rms``.
    expert_dataset : object with ``get_next_batch(n) -> (obs, acs)``.
    timesteps_per_actorbatch, clip_param, entcoeff, optim_epochs,
    optim_stepsize, optim_batchsize, gamma, lam : PPO / GAE hyperparameters.
        ``optim_batchsize`` falls back to the full batch when falsy.
    max_timesteps, max_episodes, max_iters, max_seconds : stopping criteria;
        exactly one must be positive.
    callback : called with ``(locals(), globals())`` at the top of every
        iteration.
    adam_epsilon : epsilon of the policy ``MpiAdam``.
    schedule : ``'constant'`` or ``'linear'``. ``'linear'`` anneals the
        learning-rate / clip multiplier from 1 to 0 over ``max_timesteps`` and
        therefore requires ``max_timesteps > 0``.

    Returns
    -------
    The trained policy ``pi``.

    Raises
    ------
    ValueError
        If ``schedule == 'linear'`` and ``max_timesteps`` is not positive.
    NotImplementedError
        If ``schedule`` is not one of the supported values.
    """
    if schedule == 'linear' and not max_timesteps:
        # The linear schedule divides by max_timesteps; fail fast with a clear
        # message instead of a ZeroDivisionError on the first iteration.
        raise ValueError("schedule='linear' requires max_timesteps > 0")

    # Setup losses and stuff
    # ----------------------------------------
    d_stepsize = 3e-4
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_fn("pi", ob_space, ac_space)  # Construct network for new policy
    oldpi = policy_fn("oldpi", ob_space, ac_space)  # Network for old policy
    atarg = tf.placeholder(dtype=tf.float32, shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    # learning rate multiplier, updated with schedule
    lrmult = tf.placeholder(name='lrmult', dtype=tf.float32, shape=[])
    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    pol_entpen = (-entcoeff) * meanent

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # pnew / pold
    surr1 = ratio * atarg  # surrogate from conservative policy iteration
    surr2 = tf.clip_by_value(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg
    pol_surr = -tf.reduce_mean(tf.minimum(surr1, surr2))  # PPO's pessimistic surrogate (L^CLIP)
    vf_loss = tf.reduce_mean(tf.square(pi.vpred - ret))
    total_loss = pol_surr + pol_entpen + vf_loss
    losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
    loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

    var_list = pi.get_trainable_variables()
    lossandgrad = U.function([ob, ac, atarg, ret, lrmult],
                             losses + [U.flatgrad(total_loss, var_list)])
    adam = MpiAdam(var_list, epsilon=adam_epsilon)
    d_adam = MpiAdam(reward_giver.get_trainable_variables())

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)

    U.initialize()
    adam.sync()
    d_adam.sync()

    # Prepare for rollouts
    # ----------------------------------------
    viewer = mujoco_py.MjViewer(env.sim)
    seg_gen = traj_segment_generator(pi, env, viewer, reward_giver,
                                     timesteps_per_actorbatch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0, max_seconds > 0]) == 1, \
        "Only one time constraint permitted"

    while True:
        if callback:
            callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        elif max_seconds and time.time() - tstart >= max_seconds:
            break

        if schedule == 'constant':
            cur_lrmult = 1.0
        elif schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret), shuffle=not pi.recurrent)
        optim_batchsize = optim_batchsize or ob.shape[0]

        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(optim_epochs):
            losses = []  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = lossandgrad(batch["ob"], batch["ac"], batch["atarg"],
                                            batch["vtarg"], cur_lrmult)
                adam.update(g, optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        # ------------------ Update D ------------------
        logger.log("Optimizing Discriminator...")
        logger.log(fmt_row(13, reward_giver.loss_name))
        # One minibatch the size of the whole rollout; the expert batch is drawn
        # inside the loop so it always matches the agent batch size.
        batch_size = len(ob)
        d_losses = []  # list of tuples, each of which gives the loss for a minibatch
        for ob_batch, ac_batch in dataset.iterbatches(
                (ob, ac), include_final_partial_batch=False, batch_size=batch_size):
            ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob_batch))
            # update running mean/std for reward_giver
            if hasattr(reward_giver, "obs_rms"):
                reward_giver.obs_rms.update(np.concatenate((ob_batch, ob_expert), 0))
            *newlosses, g = reward_giver.lossandgrad(ob_batch, ac_batch, ob_expert, ac_expert)
            d_adam.update(g, d_stepsize)
            d_losses.append(newlosses)
        if d_losses:
            # Report the values that go with the header row logged above.
            logger.log(fmt_row(13, np.mean(d_losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = compute_losses(batch["ob"], batch["ac"], batch["atarg"],
                                       batch["vtarg"], cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()

    return pi
def run_hoof_no_lamgam(
        network,
        env,
        total_timesteps,
        timesteps_per_batch,  # what to train on
        kl_range,
        gamma_range,
        lam_range,  # advantage estimation
        num_kl,
        num_gamma_lam,
        cg_iters=10,
        seed=None,
        ent_coef=0.0,
        cg_damping=1e-2,
        vf_stepsize=3e-4,
        vf_iters=3,
        max_episodes=0,
        max_iters=0,  # time constraint
        callback=None,
        load_path=None,
        **network_kwargs):
    '''
    Learn a policy with TRPO while choosing (KL bound, gamma, lambda) per iteration.

    Each iteration samples ``num_gamma_lam`` (gamma, lambda) pairs and
    ``num_kl`` KL bounds uniformly from the given ranges. For every pair, an
    advantage estimate is built, a natural-gradient direction is computed, and
    one full (un-line-searched) step per KL bound is scored with
    ``wis_estimate`` on the likelihood ratios of the sampled batch. The
    highest-scoring step is kept. If no candidate is evaluated (e.g. every
    gradient is zero), the parameters are left unchanged.

    The policy input has 2 extra dimensions. They are filled with zeros for the
    policy step ("no_lamgam"). The value function is trained on the
    observations returned by ``add_vtarg_and_adv_without_gl``. Presumably these
    carry gamma/lambda in those 2 columns; confirm against that helper.

    Parameters:
    ----------

    network                 neural network to learn. Can be either string ('mlp', 'cnn', 'lstm', 'lnlstm' for basic types)
                            or function that takes input placeholder and returns tuple (output, None) for feedforward nets
                            or (output, (state_placeholder, state_output, mask_placeholder)) for recurrent nets

    env                     environment (one of the gym environments or wrapped via baselines.common.vec_env.VecEnv-type class

    total_timesteps         max number of timesteps

    timesteps_per_batch     timesteps per gradient estimation batch

    kl_range                [low, high] (or a scalar) from which the KL-bound candidates are sampled uniformly

    gamma_range             [low, high] (or a scalar) from which discount-factor candidates are sampled uniformly

    lam_range               [low, high] (or a scalar) from which GAE-lambda candidates are sampled uniformly

    num_kl                  number of KL-bound candidates per iteration

    num_gamma_lam           number of (gamma, lambda) candidates per iteration (must be >= 1)

    cg_iters                number of iterations of conjugate gradient algorithm

    seed                    random seed passed to set_global_seeds

    ent_coef                coefficient of policy entropy term in the optimization objective

    cg_damping              conjugate gradient damping

    vf_stepsize             learning rate for adam optimizer used to optimize value function loss

    vf_iters                number of iterations of value function optimization iterations per each policy optimization step

    max_episodes            max number of episodes

    max_iters               maximum number of policy optimization iterations

    callback                function to be called with (locals(), globals()) each policy optimization step

    load_path               str, path to load the model from (default: None, i.e. no model is loaded)

    **network_kwargs        keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and arguments to a particular type of network

    Returns:
    -------

    learnt model

    '''
    # This variant always runs single-process: MPI is deliberately disabled.
    MPI = None
    nworkers = 1
    rank = 0

    cpus_per_worker = 1
    U.get_session(config=tf.ConfigProto(
        allow_soft_placement=True,
        inter_op_parallelism_threads=cpus_per_worker,
        intra_op_parallelism_threads=cpus_per_worker))

    policy = build_policy(env, network, value_network='copy', **network_kwargs)
    set_global_seeds(seed)

    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    # +2 for gamma, lambda
    ob = tf.placeholder(shape=(None, env.observation_space.shape[0] + 2),
                        dtype=env.observation_space.dtype,
                        name='Ob')
    with tf.variable_scope("pi"):
        pi = policy(observ_placeholder=ob)
    with tf.variable_scope("oldpi"):
        oldpi = policy(observ_placeholder=ob)

    atarg = tf.placeholder(dtype=tf.float32, shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = ent_coef * meanent

    vferr = tf.reduce_mean(tf.square(pi.vf - ret))

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    var_list = get_pi_trainable_variables("pi")
    vf_var_list = get_vf_trainable_variables("pi")

    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32, shape=[None], name="flat_tan")
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(klgrads, tangents)
    ])  #pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv, newv) in zipsame(get_variables("oldpi"), get_variables("pi"))
        ])

    compute_ratio = U.function([ob, ac, atarg], ratio)  # IS ratio - used for computing IS weights
    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg],
                                     losses + [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret], U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(colorize("done in %.3f seconds" % (time.time() - tstart), color='magenta'))
        else:
            yield

    def allmean(x):
        assert isinstance(x, np.ndarray)
        if MPI is not None:
            out = np.empty_like(x)
            MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
            out /= nworkers
        else:
            out = np.copy(x)
        return out

    U.initialize()
    if load_path is not None:
        pi.load(load_path)

    th_init = get_flat()
    if MPI is not None:
        MPI.COMM_WORLD.Bcast(th_init, root=0)

    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator_with_gl(pi, env, timesteps_per_batch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    if sum([max_iters > 0, total_timesteps > 0, max_episodes > 0]) == 0:
        # nothing to be done
        return pi

    assert sum([max_iters > 0, total_timesteps > 0, max_episodes > 0]) < 2, \
        'out of max_iters, total_timesteps, and max_episodes only one should be specified'

    kl_range = np.atleast_1d(kl_range)
    gamma_range = np.atleast_1d(gamma_range)
    lam_range = np.atleast_1d(lam_range)

    while True:
        if callback:
            callback(locals(), globals())
        if total_timesteps and timesteps_so_far >= total_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************" % iters_so_far)

        with timed("sampling"):
            seg = seg_gen.__next__()

        thbefore = get_flat()

        rand_gamma = gamma_range[0] + (gamma_range[-1] - gamma_range[0]) * np.random.rand(num_gamma_lam)
        rand_lam = lam_range[0] + (lam_range[-1] - lam_range[0]) * np.random.rand(num_gamma_lam)
        rand_kl = kl_range[0] + (kl_range[-1] - kl_range[0]) * np.random.rand(num_kl)

        opt_polval = -10**8
        est_polval = np.zeros((num_gamma_lam, num_kl))
        # Reset the per-iteration optimum. Without this, an iteration in which
        # no candidate is evaluated would reuse (or fail to find) opt_th from a
        # previous iteration. Default: keep the current parameters.
        opt_th = thbefore
        opt_kl = opt_gamma = opt_lam = np.nan
        opt_meanlosses = None
        lossbefore = None

        ob_lam_gam = []
        tdlamret = []
        vpred = []
        for gl in range(num_gamma_lam):
            oblg, vpredbefore, atarg, tdlr = add_vtarg_and_adv_without_gl(
                pi, seg, rand_gamma[gl], rand_lam[gl])
            ob_lam_gam += [oblg]
            tdlamret += [tdlr]
            vpred += [vpredbefore]
            atarg = (atarg - atarg.mean()) / atarg.std()  # standardized advantage function estimate

            # The policy sees zeros in the two extra (gamma, lambda) columns.
            pol_ob = np.concatenate(
                (seg['ob'], np.zeros(seg['ob'].shape[:-1] + (2, ))), axis=-1)
            args = pol_ob, seg["ac"], atarg
            fvpargs = [arr[::5] for arr in args]

            def fisher_vector_product(p):
                return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

            assign_old_eq_new()  # set old parameter values to new parameter values
            with timed("computegrad"):
                *lossbefore, g = compute_lossandgrad(*args)
            lossbefore = allmean(np.array(lossbefore))
            g = allmean(g)
            if np.allclose(g, 0):
                logger.log("Got zero gradient. not updating")
                continue

            with timed("cg"):
                stepdir = cg(fisher_vector_product, g, cg_iters=cg_iters, verbose=False)
            assert np.isfinite(stepdir).all()
            shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
            surrbefore = lossbefore[0]
            for m, kl_bound in enumerate(rand_kl):
                lm = np.sqrt(shs / kl_bound)
                fullstep = stepdir / lm
                thnew = thbefore + fullstep
                set_from_flat(thnew)
                # compute the IS estimates
                lik_ratio = compute_ratio(*args)
                est_polval[gl, m] = wis_estimate(seg, lik_ratio)
                meanlosses = allmean(np.array(compute_losses(*args)))
                # update best policy found so far
                if est_polval[gl, m] > opt_polval:
                    opt_polval = est_polval[gl, m]
                    opt_th = thnew
                    opt_kl = kl_bound
                    opt_gamma = rand_gamma[gl]
                    opt_lam = rand_lam[gl]
                    opt_meanlosses = meanlosses
                improve = meanlosses[0] - surrbefore
                expectedimprove = g.dot(fullstep)
                # Every candidate is a step from the same starting point.
                set_from_flat(thbefore)
                logger.log("Expected: %.3f Actual: %.3f" % (expectedimprove, improve))

        set_from_flat(opt_th)
        # Log the losses of the candidate actually selected; if none was
        # evaluated the parameters are unchanged, so the pre-step losses apply.
        if opt_meanlosses is None:
            opt_meanlosses = lossbefore
        for (lossname, lossval) in zip(loss_names, opt_meanlosses):
            logger.record_tabular(lossname, lossval)

        ob_lam_gam = np.concatenate(ob_lam_gam, axis=0)
        tdlamret = np.concatenate(tdlamret, axis=0)
        vpred = np.concatenate(vpred, axis=0)
        with timed("vf"):
            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches(
                        (ob_lam_gam, tdlamret),
                        include_final_partial_batch=False,
                        batch_size=num_gamma_lam * 64):
                    g = allmean(compute_vflossandgrad(mbob, mbret))
                    vfadam.update(g, vf_stepsize)

        logger.record_tabular("ev_tdlam_before", explained_variance(vpred, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        if MPI is not None:
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        else:
            listoflrpairs = [lrlocal]

        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        logger.record_tabular("Opt_KL", opt_kl)
        logger.record_tabular("gamma", opt_gamma)
        logger.record_tabular("lam", opt_lam)

        if rank == 0:
            logger.dump_tabular()

    return pi
def train(self, ob, ac, atarg, tdlamret):
    """Run one TRPO policy update on a batch, then fit the value function.

    Parameters
    ----------
    ob, ac : observations and actions of the batch. They are fed to the
        compiled ``compute_*`` functions and to ``dataset.iterbatches``.
    atarg : advantage estimates, standardized here with mean/std along axis 0.
    tdlamret : value-function regression targets. Also fed to
        ``self.pi.ret_rms`` when the policy has one.

    Policy step:
      - conjugate gradient on the damped Fisher-vector product (computed on
        every 5th sample),
      - a step scaled to ``self.max_kl``,
      - a backtracking line search of at most 10 halvings that rejects
        non-finite losses, KL above ``1.5 * self.max_kl`` and non-improving
        surrogates.

    Error handling is best-effort. Any ``Exception`` raised after the gradient
    is computed is reported (message plus traceback) instead of propagating.
    If it interrupted the line search, the policy parameters are rolled back to
    their pre-update values. ``KeyboardInterrupt``/``SystemExit`` propagate.
    """
    import traceback

    # Standardize the advantage estimates.
    atarg = (atarg - np.mean(atarg, axis=0)) / np.std(atarg, axis=0)
    if hasattr(self.pi, "ret_rms"):
        self.pi.ret_rms.update(tdlamret)
    if hasattr(self.pi, "ob_rms"):
        self.pi.ob_rms.update(ob)  # update running mean/std for policy

    # Set old parameter values to new parameter values.
    self.assign_old_eq_new()
    args = ob, ac, atarg
    # Sub-sample the batch (every 5th element) for the Fisher-vector products.
    fvpargs = [arr[::5] for arr in args]

    def fisher_vector_product(p):
        """Damped product of the Fisher matrix with vector ``p``."""
        return allmean(self.compute_fvp(p, *fvpargs)) + self.cg_damping * p

    with self.timed("computegrad of " + str(self.index + 1) + ".th reward"):
        *lossbefore, g = self.compute_lossandgrad(*args)
    lossbefore = allmean(np.array(lossbefore))
    g = allmean(g)  # gradient of the surrogate objective

    # Parameters to roll back to while a line search is in progress; None
    # whenever the current parameters are in a consistent state.
    thbefore = None
    try:
        if np.allclose(g, 0):
            logger.log("Got zero gradient. not updating")
        else:
            # Use conjugate gradient to obtain the update direction.
            with self.timed("cg of " + str(self.index + 1) + ".th reward"):
                stepdir = cg(fisher_vector_product, g,
                             cg_iters=self.cg_iters, verbose=self.rank == 0)
            assert np.isfinite(stepdir).all()
            shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
            lm = np.sqrt(shs / self.max_kl)
            fullstep = stepdir / lm
            expectedimprove = g.dot(fullstep)
            surrbefore = lossbefore[0]
            stepsize = 1.0
            thbefore = self.get_flat()
            # Backtracking line search: at most 10 step halvings.
            for _ in range(10):
                thnew = thbefore + fullstep * stepsize
                self.set_from_flat(thnew)
                meanlosses = surr, kl, *_ = allmean(np.array(self.compute_losses(*args)))
                improve = surr - surrbefore
                logger.log("Expected: %.3f Actual: %.3f" % (expectedimprove, improve))
                if not np.isfinite(meanlosses).all():
                    logger.log("Got non-finite value of losses -- bad!")
                elif kl > self.max_kl * 1.5:
                    logger.log("violated KL constraint. shrinking step.")
                elif improve < 0:
                    logger.log("surrogate didn't improve. shrinking step.")
                else:
                    logger.log("Stepsize OK!")
                    break
                stepsize *= .5
            else:
                logger.log("couldn't compute a good step")
                self.set_from_flat(thbefore)
            thbefore = None  # line search finished; nothing to roll back

        with self.timed("vf"):
            for _ in range(self.vf_iters):
                for (mbob, mbret) in dataset.iterbatches(
                        (ob, tdlamret), include_final_partial_batch=False, batch_size=64):
                    g = allmean(self.compute_vflossandgrad(mbob, mbret))
                    self.vfadam.update(g, self.vf_stepsize)
    except Exception:
        # Best-effort update: report the failure instead of hiding it.
        print("can't learn")
        traceback.print_exc()
        if thbefore is not None:
            # Do not leave the policy at an unvalidated trial step.
            self.set_from_flat(thbefore)
def learn(env,
          policy_func,
          rank,
          pretrained,
          pretrained_weight,
          *,
          g_step,
          d_step,
          entcoeff,
          save_per_iter,
          ckpt_dir,
          log_dir,
          timesteps_per_batch,
          task_name,
          gamma,
          lam,
          max_kl,
          cg_iters,
          cg_damping=1e-2,
          vf_stepsize=1e-4,
          vf_iters=3,
          max_timesteps=0,
          max_episodes=0,
          max_iters=0,
          callback=None):
    """Train a TRPO policy whose observations are dicts of joint/target/obstacle.

    Each observation in a segment is split into ``o["joint"]``,
    ``o["target"]`` and ``o["obstacle_pos1"]``. These are fed to the cached
    placeholders ``"ob"``, ``"goal"`` and ``"obs_pos"``.

    Each outer iteration:
      - checkpoints the session on rank 0 when the rolling true-reward mean
        beats the best seen so far,
      - runs ``g_step`` TRPO policy updates plus value-function fitting,
      - logs losses and episode statistics.

    Parameters
    ----------
    env : environment exposing ``observation_space`` and ``action_space``.
    policy_func : callable ``(name, ob_space, ac_space) -> policy``.
    rank : ignored; overwritten with ``MPI.COMM_WORLD.Get_rank()``.
    pretrained, d_step, save_per_iter, log_dir : accepted for interface
        compatibility but not used by this function.
    pretrained_weight : checkpoint path restored with ``tf.train.Saver`` before
        training, or None.
    g_step : number of policy updates per iteration.
    entcoeff, max_kl, cg_iters, cg_damping, vf_stepsize, vf_iters, gamma, lam,
    timesteps_per_batch : TRPO / GAE hyperparameters.
    ckpt_dir, task_name : checkpoint location ``ckpt_dir/task_name``. Saving is
        disabled when ``ckpt_dir`` is None.
    max_timesteps, max_episodes, max_iters : stopping criteria; exactly one
        must be positive.
    callback : called with ``(locals(), globals())`` at the top of every
        iteration.
    """
    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space, ac_space)
    oldpi = policy_func("oldpi", ob_space, ac_space)
    atarg = tf.placeholder(dtype=tf.float32, shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ob_config = U.get_placeholder_cached(name="ob")
    ob_target = U.get_placeholder_cached(name="goal")
    obs_pos = U.get_placeholder_cached(name="obs_pos")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = entcoeff * meanent

    vferr = tf.reduce_mean(tf.square(pi.vpred - ret))

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    # Policy step acts on policy + shared observation-encoder weights; the
    # value function on value + shared observation-encoder weights.
    all_var_list = pi.get_trainable_variables()
    var_list = [
        v for v in all_var_list
        if v.name.startswith(("pi/pol", "pi/logstd", "pi/obs"))
    ]
    vf_var_list = [
        v for v in all_var_list
        if v.name.startswith(("pi/vf", "pi/obs"))
    ]

    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32, shape=[None], name="flat_tan")
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(klgrads, tangents)
    ])  # pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])
    compute_losses = U.function([ob_config, ob_target, obs_pos, ac, atarg], losses)
    compute_lossandgrad = U.function(
        [ob_config, ob_target, obs_pos, ac, atarg],
        losses + [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function(
        [flat_tangent, ob_config, ob_target, obs_pos, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob_config, ob_target, obs_pos, ret],
                                       U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(colorize("done in %.3f seconds" % (time.time() - tstart), color='magenta'))
        else:
            yield

    def allmean(x):
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()
    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    vfadam.sync()
    if rank == 0:
        print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, timesteps_per_batch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards
    true_rewbuffer = deque(maxlen=40)
    max_trm = -5  # best rolling true-reward mean checkpointed so far
    true_reward_mean = 0

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    # A single Saver is built lazily and reused: every tf.train.Saver()
    # adds new save/restore ops to the graph, so creating one per checkpoint
    # makes the graph grow without bound.
    saver = None
    # if provide pretrained weight
    if pretrained_weight is not None:
        saver = tf.train.Saver()
        saver.restore(tf.get_default_session(), pretrained_weight)

    while True:
        if callback:
            callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break

        # Save model
        if rank == 0 and ckpt_dir is not None and true_reward_mean > max_trm:
            fname = os.path.join(ckpt_dir, task_name)
            ckpt_parent = os.path.dirname(fname)
            if ckpt_parent:  # os.makedirs('') raises
                os.makedirs(ckpt_parent, exist_ok=True)
            if saver is None:
                saver = tf.train.Saver()
            saver.save(tf.get_default_session(), fname)
            max_trm = true_reward_mean

        logger.log("********** Iteration %i ************" % iters_so_far)

        # ------------------ Update G ------------------
        logger.log("Optimizing Policy...")
        for _ in range(g_step):
            with timed("sampling"):
                seg = seg_gen.__next__()
            add_vtarg_and_adv(seg, gamma, lam)
            ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
            vpredbefore = seg["vpred"]  # predicted value function before update
            atarg = (atarg - atarg.mean()) / atarg.std()  # standardized advantage function estimate

            if hasattr(pi, "ob_rms"):
                pi.ob_rms.update(ob)  # update running mean/std for policy

            # Split the dict observations into the three network inputs.
            config, goal, obstacle_pos = [], [], []
            for o in seg["ob"]:
                config.append(o["joint"])
                goal.append(o["target"])
                obstacle_pos.append(o["obstacle_pos1"])
            config, goal, obstacle_pos = map(np.array, [config, goal, obstacle_pos])
            args = config, goal, obstacle_pos, seg["ac"], atarg
            fvpargs = [arr[::5] for arr in args]

            def fisher_vector_product(p):
                # Damped Fisher-vector product, averaged across workers.
                return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

            assign_old_eq_new()  # set old parameter values to new parameter values
            with timed("computegrad"):
                *lossbefore, g = compute_lossandgrad(*args)
            lossbefore = allmean(np.array(lossbefore))
            g = allmean(g)
            if np.allclose(g, 0):
                logger.log("Got zero gradient. not updating")
                # Parameters are unchanged, so the pre-step losses are current.
                meanlosses = lossbefore
            else:
                with timed("cg"):
                    stepdir = cg(fisher_vector_product, g, cg_iters=cg_iters, verbose=rank == 0)
                # Residual of the system CG actually solved (damped,
                # MPI-averaged FVP); reused below for the step-size scale.
                fvp_stepdir = fisher_vector_product(stepdir)
                logger.log(
                    'iter:{:d}, norm of g: {:.4f}, error of cg: {:.4f}'.format(
                        cg_iters, np.linalg.norm(g), np.linalg.norm(g - fvp_stepdir)))
                assert np.isfinite(stepdir).all()
                shs = .5 * stepdir.dot(fvp_stepdir)
                lm = np.sqrt(shs / max_kl)
                fullstep = stepdir / lm
                expectedimprove = g.dot(fullstep)
                surrbefore = lossbefore[0]
                stepsize = 1.0
                thbefore = get_flat()
                # Backtracking line search: at most 10 step halvings.
                for _ in range(10):
                    thnew = thbefore + fullstep * stepsize
                    set_from_flat(thnew)
                    meanlosses = surr, kl, *_ = allmean(np.array(compute_losses(*args)))
                    improve = surr - surrbefore
                    logger.log("Expected: %.3f Actual: %.3f" % (expectedimprove, improve))
                    if not np.isfinite(meanlosses).all():
                        logger.log("Got non-finite value of losses -- bad!")
                    elif kl > max_kl * 1.5:
                        logger.log("violated KL constraint. shrinking step.")
                    elif improve < 0:
                        logger.log("surrogate didn't improve. shrinking step.")
                    else:
                        logger.log("Stepsize OK!")
                        break
                    stepsize *= .5
                else:
                    logger.log("couldn't compute a good step")
                    set_from_flat(thbefore)
            if nworkers > 1 and iters_so_far % 20 == 0:
                # Check the live parameters (not the last trial step) are in sync.
                paramsums = MPI.COMM_WORLD.allgather(
                    (get_flat().sum(), vfadam.getflat().sum()))  # list of tuples
                assert all(np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

            with timed("vf"):
                for _ in range(vf_iters):
                    for (mbob, mbg, mbop, mbret) in dataset.iterbatches(
                            (config, goal, obstacle_pos, seg["tdlamret"]),
                            include_final_partial_batch=False,
                            batch_size=128):
                        if hasattr(pi, "ob_rms"):
                            pi.ob_rms.update(mbob)  # update running mean/std for policy
                        g = allmean(compute_vflossandgrad(mbob, mbg, mbop, mbret))
                        vfadam.update(g, vf_stepsize)

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)
        logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"], seg["ep_true_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews, true_rets = map(flatten_lists, zip(*listoflrpairs))
        true_rewbuffer.extend(true_rets)
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpTrueRewMean", np.mean(true_rewbuffer))
        true_reward_mean = np.mean(true_rewbuffer)
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if rank == 0:
            logger.dump_tabular()
def learn(self):
    """Run the TRPO training loop configured on ``self``.

    Stops when exactly one of ``self.max_iters``, ``self.max_timesteps``,
    ``self.max_episodes`` (asserted to be the only positive one) is reached.

    Each iteration:
      - samples ``self.timesteps_per_batch`` steps,
      - computes GAE with ``self.gamma``/``self.lam``,
      - takes one KL-constrained natural-gradient step with a backtracking line
        search,
      - fits the value function for ``self.vf_iters`` epochs,
      - logs losses and episode statistics (rank 0 dumps the table).

    Side effects: stores the latest batch on ``self`` (``ob``, ``ac``,
    ``atarg``, ``tdlamret``, ``vpredbefore``, ``fvpargs``). ``fvpargs`` is
    read by ``self.fisher_vector_product``.
    """
    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(self.pi, self.env, self.timesteps_per_batch,
                                     stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    assert sum([
        self.max_iters > 0, self.max_timesteps > 0, self.max_episodes > 0
    ]) == 1

    while True:
        if self.max_timesteps and timesteps_so_far >= self.max_timesteps:
            break
        elif self.max_episodes and episodes_so_far >= self.max_episodes:
            break
        elif self.max_iters and iters_so_far >= self.max_iters:
            break
        logger.log("********** Iteration %i ************" % iters_so_far)

        with self.timed("sampling"):
            seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, self.gamma, self.lam)

        self.ob, self.ac, self.atarg, self.tdlamret = (
            seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"])
        self.vpredbefore = seg["vpred"]  # predicted value function before update
        self.atarg = (self.atarg - self.atarg.mean()) / self.atarg.std()  # standardized advantage function estimate

        if hasattr(self.pi, "ret_rms"):
            self.pi.ret_rms.update(self.tdlamret)
        if hasattr(self.pi, "ob_rms"):
            self.pi.ob_rms.update(self.ob)  # update running mean/std for policy

        args = seg["ob"], seg["ac"], self.atarg
        self.fvpargs = [arr[::5] for arr in args]

        self.assign_old_eq_new()  # set old parameter values to new parameter values
        with self.timed("computegrad"):
            *lossbefore, g = self.compute_lossandgrad(*args)
        lossbefore = self.allmean(np.array(lossbefore))
        g = self.allmean(g)
        if np.allclose(g, 0):
            logger.log("Got zero gradient. not updating")
            # Parameters are unchanged, so the pre-step losses are current;
            # without this the loss logging below raised NameError.
            meanlosses = lossbefore
        else:
            with self.timed("cg"):
                stepdir = cg(self.fisher_vector_product, g,
                             cg_iters=self.cg_iters, verbose=self.rank == 0)
            assert np.isfinite(stepdir).all()
            shs = .5 * stepdir.dot(self.fisher_vector_product(stepdir))
            lm = np.sqrt(shs / self.max_kl)
            fullstep = stepdir / lm
            expectedimprove = g.dot(fullstep)
            surrbefore = lossbefore[0]
            stepsize = 1.0
            thbefore = self.get_flat()
            # Backtracking line search: at most 10 step halvings.
            for _ in range(10):
                thnew = thbefore + fullstep * stepsize
                self.set_from_flat(thnew)
                meanlosses = surr, kl, *_ = self.allmean(
                    np.array(self.compute_losses(*args)))
                improve = surr - surrbefore
                logger.log("Expected: %.3f Actual: %.3f" % (expectedimprove, improve))
                if not np.isfinite(meanlosses).all():
                    logger.log("Got non-finite value of losses -- bad!")
                elif kl > self.max_kl * 1.5:
                    logger.log("violated KL constraint. shrinking step.")
                elif improve < 0:
                    logger.log("surrogate didn't improve. shrinking step.")
                else:
                    logger.log("Stepsize OK!")
                    break
                stepsize *= .5
            else:
                logger.log("couldn't compute a good step")
                self.set_from_flat(thbefore)
        if self.nworkers > 1 and iters_so_far % 20 == 0:
            # Check the live parameters (not the last trial step) are in sync.
            paramsums = MPI.COMM_WORLD.allgather(
                (self.get_flat().sum(), self.vfadam.getflat().sum()))  # list of tuples
            assert all(np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

        for (lossname, lossval) in zip(self.loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)

        with self.timed("vf"):
            for _ in range(self.vf_iters):
                for (mbob, mbret) in dataset.iterbatches(
                        (seg["ob"], seg["tdlamret"]),
                        include_final_partial_batch=False,
                        batch_size=64):
                    g = self.allmean(self.compute_vflossandgrad(mbob, mbret))
                    self.vfadam.update(g, self.vf_stepsize)

        logger.record_tabular("ev_tdlam_before",
                              explained_variance(self.vpredbefore, self.tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if self.rank == 0:
            logger.dump_tabular()
def learn(
        *,
        network,
        env,
        eval_env,
        total_timesteps,
        timesteps_per_batch=1024,  # what to train on
        max_kl=0.001,
        cg_iters=10,
        gamma=0.99,
        lam=1.0,  # advantage estimation
        seed=None,
        ent_coef=0.0,
        cg_damping=1e-2,
        vf_stepsize=3e-4,
        vf_iters=3,
        log_path=None,
        max_episodes=0,
        max_iters=0,  # time constraint
        callback=None,
        load_path=None,
        **network_kwargs):
    '''
    learn a policy function with TRPO algorithm (tf.function / GradientTape variant)

    Parameters:
    ----------

    network                 neural network to learn. Can be either string ('mlp', 'cnn', 'lstm', 'lnlstm' for basic types)
                            or function that takes input placeholder and returns tuple (output, None) for feedforward nets
                            or (output, (state_placeholder, state_output, mask_placeholder)) for recurrent nets

    env                     environment (one of the gym environments or wrapped via baselines.common.vec_env.VecEnv-type class

    eval_env                environment handed to the Evaluator, which is run periodically during training

    timesteps_per_batch     timesteps per gradient estimation batch

    max_kl                  max KL divergence between old policy and new policy ( KL(pi_old || pi) )

    ent_coef                coefficient of policy entropy term in the optimization objective

    cg_iters                number of iterations of conjugate gradient algorithm

    cg_damping              conjugate gradient damping

    vf_stepsize             learning rate for adam optimizer used to optimize value function loss

    vf_iters                number of iterations of value function optimization iterations per each policy optimization step

    log_path                str, output directory (required). Evaluator logs are written to <log_path>/evaluator and
                            checkpoints to <log_path>/models; both directories are created if missing.

    total_timesteps         max number of timesteps

    max_episodes            max number of episodes

    max_iters               maximum number of policy optimization iterations

                            NOTE: the training loop does not enforce total_timesteps / max_episodes / max_iters.
                            They are only used to return early when all three are 0 and to assert that at most
                            one of them is set. The number of iterations is derived from a fixed budget of
                            value-function minibatch updates (500000 for InvertedDoublePendulum-v2, 3000000
                            otherwise).

    callback                function to be called with (locals(), globals()) each policy optimization step

    load_path               str, path to load the model from (default: None, i.e. no model is loaded)

    **network_kwargs        keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and arguments to a particular type of network

    Returns:
    -------

    learnt model

    Raises:
    ------

    ValueError              if log_path is None (outputs have nowhere to go).
    '''
    # Fail fast, before any networks are built: evaluator logs and
    # checkpoints are written under log_path unconditionally.
    if log_path is None:
        raise ValueError("log_path must be provided: evaluator logs and model "
                         "checkpoints are written under it")

    if MPI is not None:
        nworkers = MPI.COMM_WORLD.Get_size()
        rank = MPI.COMM_WORLD.Get_rank()
    else:
        nworkers = 1
        rank = 0

    set_global_seeds(seed)

    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space

    if isinstance(network, str):
        network = get_network_builder(network)(**network_kwargs)

    with tf.name_scope("pi"):
        pi_policy_network = network(ob_space.shape)
        pi_value_network = network(ob_space.shape)
        pi = PolicyWithValue(ac_space, pi_policy_network, pi_value_network)
    with tf.name_scope("oldpi"):
        old_pi_policy_network = network(ob_space.shape)
        old_pi_value_network = network(ob_space.shape)
        oldpi = PolicyWithValue(ac_space, old_pi_policy_network,
                                old_pi_value_network)

    pi_var_list = pi_policy_network.trainable_variables + list(
        pi.pdtype.trainable_variables)
    old_pi_var_list = old_pi_policy_network.trainable_variables + list(
        oldpi.pdtype.trainable_variables)
    vf_var_list = pi_value_network.trainable_variables + pi.value_fc.trainable_variables
    old_vf_var_list = old_pi_value_network.trainable_variables + oldpi.value_fc.trainable_variables

    if load_path is not None:
        load_path = osp.expanduser(load_path)
        ckpt = tf.train.Checkpoint(model=pi)
        manager = tf.train.CheckpointManager(ckpt, load_path, max_to_keep=None)
        ckpt.restore(manager.latest_checkpoint)

    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(pi_var_list)
    set_from_flat = U.SetFromFlat(pi_var_list)
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]
    shapes = [var.get_shape().as_list() for var in pi_var_list]

    def assign_old_eq_new():
        # Copy current policy/value weights into the "old" networks.
        for pi_var, old_pi_var in zip(pi_var_list, old_pi_var_list):
            old_pi_var.assign(pi_var)
        for vf_var, old_vf_var in zip(vf_var_list, old_vf_var_list):
            old_vf_var.assign(vf_var)

    @tf.function
    def compute_lossandgrad(ob, ac, atarg):
        # Surrogate losses plus the flat gradient of optimgain w.r.t. the policy vars.
        with tf.GradientTape() as tape:
            old_policy_latent = oldpi.policy_network(ob)
            old_pd, _ = oldpi.pdtype.pdfromlatent(old_policy_latent)
            policy_latent = pi.policy_network(ob)
            pd, _ = pi.pdtype.pdfromlatent(policy_latent)
            kloldnew = old_pd.kl(pd)
            ent = pd.entropy()
            meankl = tf.reduce_mean(kloldnew)
            meanent = tf.reduce_mean(ent)
            entbonus = ent_coef * meanent
            ratio = tf.exp(pd.logp(ac) - old_pd.logp(ac))
            surrgain = tf.reduce_mean(ratio * atarg)
            optimgain = surrgain + entbonus
            losses = [optimgain, meankl, entbonus, surrgain, meanent]
        gradients = tape.gradient(optimgain, pi_var_list)
        return losses + [U.flatgrad(gradients, pi_var_list)]

    @tf.function
    def compute_losses(ob, ac, atarg):
        # Same quantities as compute_lossandgrad, without the gradient.
        old_policy_latent = oldpi.policy_network(ob)
        old_pd, _ = oldpi.pdtype.pdfromlatent(old_policy_latent)
        policy_latent = pi.policy_network(ob)
        pd, _ = pi.pdtype.pdfromlatent(policy_latent)
        kloldnew = old_pd.kl(pd)
        ent = pd.entropy()
        meankl = tf.reduce_mean(kloldnew)
        meanent = tf.reduce_mean(ent)
        entbonus = ent_coef * meanent
        ratio = tf.exp(pd.logp(ac) - old_pd.logp(ac))
        surrgain = tf.reduce_mean(ratio * atarg)
        optimgain = surrgain + entbonus
        losses = [optimgain, meankl, entbonus, surrgain, meanent]
        return losses

    #ob shape should be [batch_size, ob_dim], merged nenv
    #ret shape should be [batch_size]
    @tf.function
    def compute_vflossandgrad(ob, ret):
        # Flat gradient of the mean squared value error w.r.t. the value vars.
        with tf.GradientTape() as tape:
            pi_vf = pi.value(ob)
            vferr = tf.reduce_mean(tf.square(pi_vf - ret))
        return U.flatgrad(tape.gradient(vferr, vf_var_list), vf_var_list)

    @tf.function
    def compute_fvp(flat_tangent, ob, ac, atarg):
        # Fisher-vector product: gradient of (grad(meankl) . tangent), using
        # nested tapes (inner for grad(meankl), outer for the second derivative).
        with tf.GradientTape() as outter_tape:
            with tf.GradientTape() as inner_tape:
                old_policy_latent = oldpi.policy_network(ob)
                old_pd, _ = oldpi.pdtype.pdfromlatent(old_policy_latent)
                policy_latent = pi.policy_network(ob)
                pd, _ = pi.pdtype.pdfromlatent(policy_latent)
                kloldnew = old_pd.kl(pd)
                meankl = tf.reduce_mean(kloldnew)
            klgrads = inner_tape.gradient(meankl, pi_var_list)
            # Split the flat tangent back into per-variable tensors.
            start = 0
            tangents = []
            for shape in shapes:
                sz = U.intprod(shape)
                tangents.append(
                    tf.reshape(flat_tangent[start:start + sz], shape))
                start += sz
            gvp = tf.add_n([
                tf.reduce_sum(g * tangent)
                for (g, tangent) in zipsame(klgrads, tangents)
            ])
        hessians_products = outter_tape.gradient(gvp, pi_var_list)
        fvp = U.flatgrad(hessians_products, pi_var_list)
        return fvp

    @contextmanager
    def timed(msg):
        # Print wall time of the enclosed section on rank 0 only.
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='magenta'))
        else:
            yield

    def allmean(x):
        # Average an ndarray across MPI workers (plain copy without MPI).
        assert isinstance(x, np.ndarray)
        if MPI is not None:
            out = np.empty_like(x)
            MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
            out /= nworkers
        else:
            out = np.copy(x)

        return out

    th_init = get_flat()
    if MPI is not None:
        MPI.COMM_WORLD.Bcast(th_init, root=0)

    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, timesteps_per_batch)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    logdir = osp.join(log_path, 'evaluator')
    modeldir = osp.join(log_path, 'models')
    os.makedirs(logdir, exist_ok=True)
    os.makedirs(modeldir, exist_ok=True)
    evaluator = Evaluator(env=eval_env, model=pi, logdir=logdir)

    # Training length is a fixed budget of value-function minibatch updates,
    # converted into a number of outer (policy) iterations.
    max_inner_iter = 500000 if env.spec.id == 'InvertedDoublePendulum-v2' else 3000000
    epoch = vf_iters
    batch_size = timesteps_per_batch
    mb_size = 256  # value-function minibatch size
    # Clamp to 1 so a batch smaller than mb_size (or vf_iters == 0) does not
    # divide by zero below.
    inner_iter_per_iter = max(1, epoch * int(batch_size / mb_size))
    max_iter = int(max_inner_iter / inner_iter_per_iter)
    eval_num = 150  # approximate number of evaluations/checkpoints over training
    # Clamp to 1 so the modulo checks in the loop never divide by zero.
    eval_interval = save_interval = max(
        1, int(int(max_inner_iter / eval_num) / inner_iter_per_iter))

    if sum([max_iters > 0, total_timesteps > 0, max_episodes > 0]) == 0:
        # nothing to be done
        return pi

    assert sum([max_iters > 0, total_timesteps > 0, max_episodes > 0]) < 2, \
        'out of max_iters, total_timesteps, and max_episodes only one should be specified'

    for update in range(1, max_iter + 1):
        if callback:
            callback(locals(), globals())
        # NOTE: in this variant the loop is not terminated by total_timesteps,
        # max_episodes or max_iters; its length is fixed by max_iter above.
        logger.log("********** Iteration %i ************" % iters_so_far)

        # Evaluation and checkpointing happen before this iteration's update,
        # so update == 1 evaluates/saves the initial policy.
        if (update - 1) % eval_interval == 0:
            evaluator.run_evaluation(update - 1)
        if (update - 1) % save_interval == 0:
            ckpt = tf.train.Checkpoint(model=pi)
            ckpt.save(osp.join(modeldir, 'ckpt_ite' + str(update - 1)))

        with timed("sampling"):
            seg = next(seg_gen)
        add_vtarg_and_adv(seg, gamma, lam)

        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        ob = sf01(ob)
        vpredbefore = seg["vpred"]  # predicted value function before update
        # Standardized advantage estimate; the epsilon avoids 0/0 -> NaN when
        # all advantages in the batch are equal.
        atarg = (atarg - atarg.mean()) / (atarg.std() + 1e-8)

        if hasattr(pi, "ret_rms"):
            pi.ret_rms.update(tdlamret)
        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        args = ob, ac, atarg
        # Every 5th sample is used for the Fisher-vector products.
        fvpargs = [arr[::5] for arr in args]

        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs).numpy()) + cg_damping * p

        assign_old_eq_new()  # set old parameter values to new parameter values

        with timed("computegrad"):
            *lossbefore, g = compute_lossandgrad(*args)
        lossbefore = allmean(np.array(lossbefore))
        g = allmean(g.numpy())

        if np.allclose(g, 0):
            logger.log("Got zero gradient. not updating")
            # No step is taken: report the pre-update losses so the logging
            # below never reads an unbound (first iteration) or stale value.
            meanlosses = lossbefore
        else:
            with timed("cg"):
                stepdir = cg(fisher_vector_product, g, cg_iters=cg_iters,
                             verbose=rank == 0)
            assert np.isfinite(stepdir).all()
            shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
            lm = np.sqrt(shs / max_kl)
            fullstep = stepdir / lm
            expectedimprove = g.dot(fullstep)
            surrbefore = lossbefore[0]
            stepsize = 1.0
            thbefore = get_flat()
            # Backtracking line search: halve the step until the losses are
            # finite, KL <= 1.5 * max_kl and the surrogate improves (max 10 tries).
            for _ in range(10):
                thnew = thbefore + fullstep * stepsize
                set_from_flat(thnew)
                meanlosses = surr, kl, *_ = allmean(
                    np.array(compute_losses(*args)))
                improve = surr - surrbefore
                logger.log("Expected: %.3f Actual: %.3f" %
                           (expectedimprove, improve))
                if not np.isfinite(meanlosses).all():
                    logger.log("Got non-finite value of losses -- bad!")
                elif kl > max_kl * 1.5:
                    logger.log("violated KL constraint. shrinking step.")
                elif improve < 0:
                    logger.log("surrogate didn't improve. shrinking step.")
                else:
                    logger.log("Stepsize OK!")
                    break
                stepsize *= .5
            else:
                logger.log("couldn't compute a good step")
                set_from_flat(thbefore)
            # Periodic sanity check that all workers hold identical parameters.
            if nworkers > 1 and iters_so_far % 20 == 0:
                paramsums = MPI.COMM_WORLD.allgather(
                    (thnew.sum(), vfadam.getflat().sum()))  # list of tuples
                assert all(
                    np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)

        with timed("vf"):
            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches(
                        (seg["ob"], seg["tdlamret"]),
                        include_final_partial_batch=False,
                        batch_size=mb_size):
                    mbob = sf01(mbob)
                    g = allmean(compute_vflossandgrad(mbob, mbret).numpy())
                    vfadam.update(g, vf_stepsize)

        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        if MPI is not None:
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        else:
            listoflrpairs = [lrlocal]

        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if rank == 0:
            logger.dump_tabular()

    return pi
def learn(
        *,
        network,
        env,
        total_timesteps,
        timesteps_per_batch=1024,  # what to train on
        max_kl=0.001,
        cg_iters=10,
        gamma=0.99,
        lam=1.0,  # advantage estimation
        seed=None,
        ent_coef=0.0,
        cg_damping=1e-2,
        vf_stepsize=3e-4,
        vf_iters=3,
        max_episodes=0,
        max_iters=0,  # time constraint
        callback=None,
        load_path=None,
        **network_kwargs):
    '''
    learn a policy function with TRPO algorithm (TF1 graph / placeholder variant)

    Parameters:
    ----------

    network                 neural network to learn. Can be either string ('mlp', 'cnn', 'lstm', 'lnlstm' for basic types)
                            or function that takes input placeholder and returns tuple (output, None) for feedforward nets
                            or (output, (state_placeholder, state_output, mask_placeholder)) for recurrent nets

    env                     environment (one of the gym environments or wrapped via baselines.common.vec_env.VecEnv-type class

    timesteps_per_batch     timesteps per gradient estimation batch

    max_kl                  max KL divergence between old policy and new policy ( KL(pi_old || pi) )

    ent_coef                coefficient of policy entropy term in the optimization objective

    cg_iters                number of iterations of conjugate gradient algorithm

    cg_damping              conjugate gradient damping

    vf_stepsize             learning rate for adam optimizer used to optimize value function loss

    vf_iters                number of iterations of value function optimization iterations per each policy optimization step

    total_timesteps         max number of timesteps

    max_episodes            max number of episodes

    max_iters               maximum number of policy optimization iterations

                            At most one of total_timesteps / max_episodes / max_iters may be > 0 (asserted);
                            if all are 0 the untrained policy is returned immediately.

    callback                function to be called with (locals(), globals()) each policy optimization step

    load_path               str, path to load the model from (default: None, i.e. no model is loaded)

    **network_kwargs        keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and arguments to a particular type of network

    Returns:
    -------

    learnt model

    '''

    if MPI is not None:
        nworkers = MPI.COMM_WORLD.Get_size()
        rank = MPI.COMM_WORLD.Get_rank()
    else:
        nworkers = 1
        rank = 0

    cpus_per_worker = 1
    U.get_session(
        config=tf.ConfigProto(allow_soft_placement=True,
                              inter_op_parallelism_threads=cpus_per_worker,
                              intra_op_parallelism_threads=cpus_per_worker))

    policy = build_policy(env, network, value_network='copy', **network_kwargs)
    set_global_seeds(seed)

    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space

    ob = observation_placeholder(ob_space)
    with tf.variable_scope("pi"):
        pi = policy(observ_placeholder=ob)
    with tf.variable_scope("oldpi"):
        oldpi = policy(observ_placeholder=ob)

    atarg = tf.placeholder(
        dtype=tf.float32, shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = ent_coef * meanent

    vferr = tf.reduce_mean(tf.square(pi.vf - ret))

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    var_list = get_pi_trainable_variables("pi")
    vf_var_list = get_vf_trainable_variables("pi")

    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32, shape=[None], name="flat_tan")
    # Split the flat tangent placeholder back into per-variable tensors.
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    gvp = tf.add_n([
        tf.reduce_sum(g * tangent)
        for (g, tangent) in zipsame(klgrads, tangents)
    ])  #pylint: disable=E1111
    # Fisher-vector product: gradient of (grad(meankl) . tangent).
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv) for (oldv, newv) in zipsame(
                get_variables("oldpi"), get_variables("pi"))
        ])

    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg],
                                     losses + [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret], U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        # Print wall time of the enclosed section on rank 0 only.
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(
                colorize("done in %.3f seconds" % (time.time() - tstart),
                         color='magenta'))
        else:
            yield

    def allmean(x):
        # Average an ndarray across MPI workers (plain copy without MPI).
        assert isinstance(x, np.ndarray)
        if MPI is not None:
            out = np.empty_like(x)
            MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
            out /= nworkers
        else:
            out = np.copy(x)

        return out

    U.initialize()
    if load_path is not None:
        pi.load(load_path)

    th_init = get_flat()
    if MPI is not None:
        MPI.COMM_WORLD.Bcast(th_init, root=0)

    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, timesteps_per_batch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    if sum([max_iters > 0, total_timesteps > 0, max_episodes > 0]) == 0:
        # nothing to be done
        return pi

    assert sum([max_iters > 0, total_timesteps > 0, max_episodes > 0]) < 2, \
        'out of max_iters, total_timesteps, and max_episodes only one should be specified'

    while True:
        if callback:
            callback(locals(), globals())
        if total_timesteps and timesteps_so_far >= total_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************" % iters_so_far)

        with timed("sampling"):
            seg = next(seg_gen)
        add_vtarg_and_adv(seg, gamma, lam)

        # NOTE: ob / ac / atarg are rebound to batch data here. The U.function
        # callables above already captured the placeholder objects, so this
        # rebinding does not affect them.
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        # Standardized advantage estimate; the epsilon avoids 0/0 -> NaN when
        # all advantages in the batch are equal.
        atarg = (atarg - atarg.mean()) / (atarg.std() + 1e-8)

        if hasattr(pi, "ret_rms"):
            pi.ret_rms.update(tdlamret)
        if hasattr(pi, "ob_rms"):
            pi.ob_rms.update(ob)  # update running mean/std for policy

        args = seg["ob"], seg["ac"], atarg
        # Every 5th sample is used for the Fisher-vector products.
        fvpargs = [arr[::5] for arr in args]

        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        assign_old_eq_new()  # set old parameter values to new parameter values
        with timed("computegrad"):
            *lossbefore, g = compute_lossandgrad(*args)
        lossbefore = allmean(np.array(lossbefore))
        g = allmean(g)

        if np.allclose(g, 0):
            logger.log("Got zero gradient. not updating")
            # No step is taken: report the pre-update losses so the logging
            # below never reads an unbound (first iteration) or stale value.
            meanlosses = lossbefore
        else:
            with timed("cg"):
                stepdir = cg(fisher_vector_product, g, cg_iters=cg_iters,
                             verbose=rank == 0)
            assert np.isfinite(stepdir).all()
            shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
            lm = np.sqrt(shs / max_kl)
            fullstep = stepdir / lm
            expectedimprove = g.dot(fullstep)
            surrbefore = lossbefore[0]
            stepsize = 1.0
            thbefore = get_flat()
            # Backtracking line search: halve the step until the losses are
            # finite, KL <= 1.5 * max_kl and the surrogate improves (max 10 tries).
            for _ in range(10):
                thnew = thbefore + fullstep * stepsize
                set_from_flat(thnew)
                meanlosses = surr, kl, *_ = allmean(
                    np.array(compute_losses(*args)))
                improve = surr - surrbefore
                logger.log("Expected: %.3f Actual: %.3f" %
                           (expectedimprove, improve))
                if not np.isfinite(meanlosses).all():
                    logger.log("Got non-finite value of losses -- bad!")
                elif kl > max_kl * 1.5:
                    logger.log("violated KL constraint. shrinking step.")
                elif improve < 0:
                    logger.log("surrogate didn't improve. shrinking step.")
                else:
                    logger.log("Stepsize OK!")
                    break
                stepsize *= .5
            else:
                logger.log("couldn't compute a good step")
                set_from_flat(thbefore)
            # Periodic sanity check that all workers hold identical parameters.
            if nworkers > 1 and iters_so_far % 20 == 0:
                paramsums = MPI.COMM_WORLD.allgather(
                    (thnew.sum(), vfadam.getflat().sum()))  # list of tuples
                assert all(
                    np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)

        with timed("vf"):
            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches(
                        (seg["ob"], seg["tdlamret"]),
                        include_final_partial_batch=False,
                        batch_size=64):
                    g = allmean(compute_vflossandgrad(mbob, mbret))
                    vfadam.update(g, vf_stepsize)

        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        if MPI is not None:
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        else:
            listoflrpairs = [lrlocal]

        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if rank == 0:
            logger.dump_tabular()

    return pi
def learn( env, policy_func, *, timesteps_per_batch, max_kl, cg_iters, gamma, lam, entcoeff=0.0, cg_damping=1e-2, vf_stepsize=3e-4, vf_iters=3, max_timesteps=0, max_episodes=0, max_iters=0, callback=None, # GAIL Params pretrained_weight=None, reward_giver=None, expert_dataset=None, rank=0, save_per_iter=1, ckpt_dir="/tmp/gail/ckpt/", g_step=1, d_step=1, task_name="task_name", d_stepsize=3e-4, using_gail=True): """ learns a GAIL policy using the given environment :param env: (Gym Environment) the environment :param policy_func: (function (str, Gym Space, Gym Space, bool): MLPPolicy) policy generator :param timesteps_per_batch: (int) the number of timesteps to run per batch (horizon) :param max_kl: (float) the kullback leiber loss threashold :param cg_iters: (int) the number of iterations for the conjugate gradient calculation :param gamma: (float) the discount value :param lam: (float) GAE factor :param entcoeff: (float) the weight for the entropy loss :param cg_damping: (float) the compute gradient dampening factor :param vf_stepsize: (float) the value function stepsize :param vf_iters: (int) the value function's number iterations for learning :param max_timesteps: (int) the maximum number of timesteps before halting :param max_episodes: (int) the maximum number of episodes before halting :param max_iters: (int) the maximum number of training iterations before halting :param callback: (function (dict, dict)) the call back function, takes the local and global attribute dictionary :param pretrained_weight: (str) the save location for the pretrained weights :param reward_giver: (TransitionClassifier) the reward predicter from obsevation and action :param expert_dataset: (MujocoDset) the dataset manager :param rank: (int) the rank of the mpi thread :param save_per_iter: (int) the number of iterations before saving :param ckpt_dir: (str) the location for saving checkpoints :param g_step: (int) number of steps to train policy in each epoch :param d_step: (int) number of 
steps to train discriminator in each epoch :param task_name: (str) the name of the task (can be None) :param d_stepsize: (float) the reward giver stepsize :param using_gail: (bool) using the GAIL model """ nworkers = MPI.COMM_WORLD.Get_size() rank = MPI.COMM_WORLD.Get_rank() np.set_printoptions(precision=3) sess = tf_util.single_threaded_session() # Setup losses and stuff # ---------------------------------------- ob_space = env.observation_space ac_space = env.action_space policy = policy_func("pi", ob_space, ac_space, sess=sess) old_policy = policy_func("oldpi", ob_space, ac_space, sess=sess, placeholders={ "obs": policy.obs_ph, "stochastic": policy.stochastic_ph }) atarg = tf.placeholder( dtype=tf.float32, shape=[None]) # Target advantage function (if applicable) ret = tf.placeholder(dtype=tf.float32, shape=[None]) # Empirical return observation = policy.obs_ph action = policy.pdtype.sample_placeholder([None]) kloldnew = old_policy.proba_distribution.kl(policy.proba_distribution) ent = policy.proba_distribution.entropy() meankl = tf.reduce_mean(kloldnew) meanent = tf.reduce_mean(ent) entbonus = entcoeff * meanent vferr = tf.reduce_mean(tf.square(policy.vpred - ret)) # advantage * pnew / pold ratio = tf.exp( policy.proba_distribution.logp(action) - old_policy.proba_distribution.logp(action)) surrgain = tf.reduce_mean(ratio * atarg) optimgain = surrgain + entbonus losses = [optimgain, meankl, entbonus, surrgain, meanent] loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"] dist = meankl all_var_list = policy.get_trainable_variables() if using_gail: var_list = [ v for v in all_var_list if v.name.startswith("pi/pol") or v.name.startswith("pi/logstd") ] vf_var_list = [v for v in all_var_list if v.name.startswith("pi/vff")] assert len(var_list) == len(vf_var_list) + 1 d_adam = MpiAdam(reward_giver.get_trainable_variables()) else: var_list = [ v for v in all_var_list if v.name.split("/")[1].startswith("pol") ] vf_var_list = [ v for v in all_var_list 
if v.name.split("/")[1].startswith("vf") ] vfadam = MpiAdam(vf_var_list, sess=sess) get_flat = tf_util.GetFlat(var_list, sess=sess) set_from_flat = tf_util.SetFromFlat(var_list, sess=sess) klgrads = tf.gradients(dist, var_list) flat_tangent = tf.placeholder(dtype=tf.float32, shape=[None], name="flat_tan") shapes = [var.get_shape().as_list() for var in var_list] start = 0 tangents = [] for shape in shapes: var_size = tf_util.intprod(shape) tangents.append(tf.reshape(flat_tangent[start:start + var_size], shape)) start += var_size gvp = tf.add_n([ tf.reduce_sum(grad * tangent) for (grad, tangent) in zipsame(klgrads, tangents) ]) # pylint: disable=E1111 fvp = tf_util.flatgrad(gvp, var_list) assign_old_eq_new = tf_util.function( [], [], updates=[ tf.assign(oldv, newv) for (oldv, newv) in zipsame( old_policy.get_variables(), policy.get_variables()) ]) compute_losses = tf_util.function([observation, action, atarg], losses) compute_lossandgrad = tf_util.function( [observation, action, atarg], losses + [tf_util.flatgrad(optimgain, var_list)]) compute_fvp = tf_util.function([flat_tangent, observation, action, atarg], fvp) compute_vflossandgrad = tf_util.function([observation, ret], tf_util.flatgrad( vferr, vf_var_list)) @contextmanager def timed(msg): if rank == 0: print(colorize(msg, color='magenta')) start_time = time.time() yield print( colorize("done in %.3f seconds" % (time.time() - start_time), color='magenta')) else: yield def allmean(arr): assert isinstance(arr, np.ndarray) out = np.empty_like(arr) MPI.COMM_WORLD.Allreduce(arr, out, op=MPI.SUM) out /= nworkers return out tf_util.initialize(sess=sess) th_init = get_flat() MPI.COMM_WORLD.Bcast(th_init, root=0) set_from_flat(th_init) if using_gail: d_adam.sync() vfadam.sync() if rank == 0: print("Init param sum", th_init.sum(), flush=True) # Prepare for rollouts # ---------------------------------------- if using_gail: seg_gen = traj_segment_generator(policy, env, timesteps_per_batch, stochastic=True, 
reward_giver=reward_giver, gail=True) else: seg_gen = traj_segment_generator(policy, env, timesteps_per_batch, stochastic=True) episodes_so_far = 0 timesteps_so_far = 0 iters_so_far = 0 t_start = time.time() lenbuffer = deque(maxlen=40) # rolling buffer for episode lengths rewbuffer = deque(maxlen=40) # rolling buffer for episode rewards assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1 if using_gail: true_rewbuffer = deque(maxlen=40) # Stats not used for now # g_loss_stats = Stats(loss_names) # d_loss_stats = Stats(reward_giver.loss_name) # ep_stats = Stats(["True_rewards", "Rewards", "Episode_length"]) # if provide pretrained weight if pretrained_weight is not None: tf_util.load_state(pretrained_weight, var_list=policy.get_variables()) while True: if callback: callback(locals(), globals()) if max_timesteps and timesteps_so_far >= max_timesteps: break elif max_episodes and episodes_so_far >= max_episodes: break elif max_iters and iters_so_far >= max_iters: break # Save model if using_gail and rank == 0 and iters_so_far % save_per_iter == 0 and ckpt_dir is not None: fname = os.path.join(ckpt_dir, task_name) os.makedirs(os.path.dirname(fname), exist_ok=True) saver = tf.train.Saver() saver.save(sess, fname) logger.log("********** Iteration %i ************" % iters_so_far) def fisher_vector_product(vec): return allmean(compute_fvp(vec, *fvpargs, sess=sess)) + cg_damping * vec # ------------------ Update G ------------------ logger.log("Optimizing Policy...") # g_step = 1 when not using GAIL for _ in range(g_step): with timed("sampling"): seg = seg_gen.__next__() add_vtarg_and_adv(seg, gamma, lam) # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets)) observation, action, atarg, tdlamret = seg["ob"], seg["ac"], seg[ "adv"], seg["tdlamret"] vpredbefore = seg[ "vpred"] # predicted value function before udpate atarg = (atarg - atarg.mean()) / atarg.std( ) # standardized advantage function estimate if hasattr(policy, 
"ret_rms"): policy.ret_rms.update(tdlamret) if hasattr(policy, "ob_rms"): policy.ob_rms.update( observation) # update running mean/std for policy args = seg["ob"], seg["ac"], atarg fvpargs = [arr[::5] for arr in args] assign_old_eq_new(sess=sess) with timed("computegrad"): *lossbefore, grad = compute_lossandgrad(*args, sess=sess) lossbefore = allmean(np.array(lossbefore)) grad = allmean(grad) if np.allclose(grad, 0): logger.log("Got zero gradient. not updating") else: with timed("cg"): stepdir = conjugate_gradient(fisher_vector_product, grad, cg_iters=cg_iters, verbose=rank == 0) assert np.isfinite(stepdir).all() shs = .5 * stepdir.dot(fisher_vector_product(stepdir)) # abs(shs) to avoid taking square root of negative values lagrange_multiplier = np.sqrt(abs(shs) / max_kl) # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g)) fullstep = stepdir / lagrange_multiplier expectedimprove = grad.dot(fullstep) surrbefore = lossbefore[0] stepsize = 1.0 thbefore = get_flat() for _ in range(10): thnew = thbefore + fullstep * stepsize set_from_flat(thnew) mean_losses = surr, kl_loss, *_ = allmean( np.array(compute_losses(*args, sess=sess))) improve = surr - surrbefore logger.log("Expected: %.3f Actual: %.3f" % (expectedimprove, improve)) if not np.isfinite(mean_losses).all(): logger.log("Got non-finite value of losses -- bad!") elif kl_loss > max_kl * 1.5: logger.log("violated KL constraint. shrinking step.") elif improve < 0: logger.log("surrogate didn't improve. 
shrinking step.") else: logger.log("Stepsize OK!") break stepsize *= .5 else: logger.log("couldn't compute a good step") set_from_flat(thbefore) if nworkers > 1 and iters_so_far % 20 == 0: paramsums = MPI.COMM_WORLD.allgather( (thnew.sum(), vfadam.getflat().sum())) # list of tuples assert all( np.allclose(ps, paramsums[0]) for ps in paramsums[1:]) with timed("vf"): for _ in range(vf_iters): for (mbob, mbret) in dataset.iterbatches( (seg["ob"], seg["tdlamret"]), include_final_partial_batch=False, batch_size=128): if hasattr(policy, "ob_rms"): policy.ob_rms.update( mbob) # update running mean/std for policy grad = allmean( compute_vflossandgrad(mbob, mbret, sess=sess)) vfadam.update(grad, vf_stepsize) for (loss_name, loss_val) in zip(loss_names, mean_losses): logger.record_tabular(loss_name, loss_val) logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret)) if using_gail: # ------------------ Update D ------------------ logger.log("Optimizing Discriminator...") logger.log(fmt_row(13, reward_giver.loss_name)) ob_expert, ac_expert = expert_dataset.get_next_batch( len(observation)) batch_size = len(observation) // d_step d_losses = [ ] # list of tuples, each of which gives the loss for a minibatch for ob_batch, ac_batch in dataset.iterbatches( (observation, action), include_final_partial_batch=False, batch_size=batch_size): ob_expert, ac_expert = expert_dataset.get_next_batch( len(ob_batch)) # update running mean/std for reward_giver if hasattr(reward_giver, "obs_rms"): reward_giver.obs_rms.update( np.concatenate((ob_batch, ob_expert), 0)) *newlosses, grad = reward_giver.lossandgrad( ob_batch, ac_batch, ob_expert, ac_expert) d_adam.update(allmean(grad), d_stepsize) d_losses.append(newlosses) logger.log(fmt_row(13, np.mean(d_losses, axis=0))) lrlocal = (seg["ep_lens"], seg["ep_rets"], seg["ep_true_rets"] ) # local values listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples lens, rews, true_rets = map(flatten_lists, 
zip(*listoflrpairs)) true_rewbuffer.extend(true_rets) else: lrlocal = (seg["ep_lens"], seg["ep_rets"]) # local values listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples lens, rews = map(flatten_lists, zip(*listoflrpairs)) lenbuffer.extend(lens) rewbuffer.extend(rews) logger.record_tabular("EpLenMean", np.mean(lenbuffer)) logger.record_tabular("EpRewMean", np.mean(rewbuffer)) if using_gail: logger.record_tabular("EpTrueRewMean", np.mean(true_rewbuffer)) logger.record_tabular("EpThisIter", len(lens)) episodes_so_far += len(lens) timesteps_so_far += sum(lens) iters_so_far += 1 logger.record_tabular("EpisodesSoFar", episodes_so_far) logger.record_tabular("TimestepsSoFar", timesteps_so_far) logger.record_tabular("TimeElapsed", time.time() - t_start) if rank == 0: logger.dump_tabular()