def learn(
        *, network, env, eval_env, make_eval_env, env_id,
        total_timesteps, timesteps_per_batch, sil_update, sil_loss,  # what to train on
        max_kl=0.001, cg_iters=10,
        gamma=0.99, lam=1.0,  # advantage estimation
        seed=None, ent_coef=0.0, lr=3e-4, cg_damping=1e-2,
        vf_stepsize=3e-4, vf_iters=5,
        sil_value=0.01, sil_alpha=0.6, sil_beta=0.1,
        max_episodes=0, max_iters=0,  # time constraint
        callback=None, save_interval=0, load_path=None,
        # MBL
        # For training the MBL forward model
        mbl_train_freq=5,
        # For eval
        num_eval_episodes=5, eval_freq=5, vis_eval=False,
        eval_targs=('mbmf',),
        #eval_targs=('mf',),
        quant=2,
        # For mbl.step
        #num_samples=(1500,),
        num_samples=(1,),
        horizon=(2,),
        #horizon=(2,1),
        #num_elites=(10,),
        num_elites=(1,),
        mbl_lamb=(1.0,), mbl_gamma=0.99,
        #mbl_sh=1,
        # Number of steps of stochastic sampling
        mbl_sh=10000,
        #vf_lookahead=-1,
        #use_max_vf=False,
        reset_per_step=(0,),
        # For get_model
        num_fc=2, num_fwd_hidden=500, use_layer_norm=False,
        # For MBL
        num_warm_start=int(1e4), init_epochs=10, update_epochs=5, batch_size=512,
        update_with_validation=False, use_mean_elites=1,
        use_ent_adjust=0, adj_std_scale=0.5,
        # For data loading
        validation_set_path=None,
        # For data collection
        collect_val_data=False,
        # For trajectory collection
        traj_collect='mf',
        # For profiling
        measure_time=True, eval_val_err=False, measure_rew=True,
        model_fn=None, update_fn=None, init_fn=None,
        mpi_rank_weight=1, comm=None,
        vf_coef=0.5, max_grad_norm=0.5,
        log_interval=1, nminibatches=4, noptepochs=4, cliprange=0.2,
        **network_kwargs):
    '''
    Learn a policy function with the TRPO algorithm.

    Parameters:
    ----------

    network                 neural network to learn. Can be either string ('mlp', 'cnn', 'lstm', 'lnlstm' for basic types)
                            or function that takes input placeholder and returns tuple (output, None) for feedforward nets
                            or (output, (state_placeholder, state_output, mask_placeholder)) for recurrent nets

    env                     environment (one of the gym environments or wrapped via a baselines.common.vec_env.VecEnv-type class)

    timesteps_per_batch     timesteps per gradient estimation batch

    max_kl                  max KL divergence between old policy and new policy (KL(pi_old || pi))

    ent_coef                coefficient of policy entropy term in the optimization objective

    cg_iters                number of iterations of the conjugate gradient algorithm

    cg_damping              conjugate gradient damping

    vf_stepsize             learning rate for the adam optimizer used to optimize the value function loss

    vf_iters                number of iterations of value function optimization per policy optimization step

    total_timesteps         max number of timesteps

    max_episodes            max number of episodes

    max_iters               maximum number of policy optimization iterations

    callback                function to be called with (locals(), globals()) each policy optimization step

    load_path               str, path to load the model from (default: None, i.e. no model is loaded)

    **network_kwargs        keyword arguments to the policy / network builder.
                            See baselines.common/policies.py/build_policy and arguments to a particular type of network

    Returns:
    -------

    learnt model

    '''
    if not isinstance(num_samples, tuple): num_samples = (num_samples,)
    if not isinstance(horizon, tuple): horizon = (horizon,)
    if not isinstance(num_elites, tuple): num_elites = (num_elites,)
    if not isinstance(mbl_lamb, tuple): mbl_lamb = (mbl_lamb,)
    if not isinstance(reset_per_step, tuple): reset_per_step = (reset_per_step,)
    if validation_set_path is None:
        if collect_val_data:
            validation_set_path = os.path.join(logger.get_dir(), 'val.pkl')
        else:
            validation_set_path = os.path.join('dataset', '{}-val.pkl'.format(env_id))
    if eval_val_err:
        eval_val_err_path = os.path.join('dataset', '{}-combine-val.pkl'.format(env_id))
    logger.log(locals())
    logger.log('MBL_SH', mbl_sh)
    logger.log('Traj_collect', traj_collect)

    if MPI is not None:
        nworkers = MPI.COMM_WORLD.Get_size()
        rank = MPI.COMM_WORLD.Get_rank()
    else:
        nworkers = 1
        rank = 0
    cpus_per_worker = 1
    U.get_session(config=tf.ConfigProto(allow_soft_placement=True,
                                        inter_op_parallelism_threads=cpus_per_worker,
                                        intra_op_parallelism_threads=cpus_per_worker))

    set_global_seeds(seed)
    if isinstance(lr, float): lr = constfn(lr)
    else: assert callable(lr)
    if isinstance(cliprange, float): cliprange = constfn(cliprange)
    else: assert callable(cliprange)

    policy = build_policy(env, network, value_network='copy', **network_kwargs)
    nenvs = env.num_envs
    np.set_printoptions(precision=3)

    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    nbatch = nenvs * timesteps_per_batch
    nbatch_train = nbatch // nminibatches
    is_mpi_root = (MPI is None or MPI.COMM_WORLD.Get_rank() == 0)

    ob = observation_placeholder(ob_space)
    with tf.variable_scope("pi"):
        pi = policy(observ_placeholder=ob)
    make_model = lambda: Model(policy=policy, ob_space=ob_space, ac_space=ac_space,
                               nbatch_act=nenvs, nbatch_train=nbatch_train,
                               nsteps=timesteps_per_batch, ent_coef=ent_coef,
                               vf_coef=vf_coef, max_grad_norm=max_grad_norm,
                               sil_update=sil_update, sil_value=sil_value,
                               sil_alpha=sil_alpha, sil_beta=sil_beta, sil_loss=sil_loss,
                               # fn_reward=env.process_reward,
                               fn_reward=None,
                               # fn_obs=env.process_obs,
                               fn_obs=None,
                               ppo=False, prev_pi='pi', silm=pi)
    model = make_model()

    with tf.variable_scope("oldpi"):
        oldpi = policy(observ_placeholder=ob)
    make_old_model = lambda: Model(policy=policy, ob_space=ob_space, ac_space=ac_space,
                                   nbatch_act=nenvs, nbatch_train=nbatch_train,
                                   nsteps=timesteps_per_batch, ent_coef=ent_coef,
                                   vf_coef=vf_coef, max_grad_norm=max_grad_norm,
                                   sil_update=sil_update, sil_value=sil_value,
                                   sil_alpha=sil_alpha, sil_beta=sil_beta, sil_loss=sil_loss,
                                   # fn_reward=env.process_reward,
                                   fn_reward=None,
                                   # fn_obs=env.process_obs,
                                   fn_obs=None,
                                   ppo=False, prev_pi='oldpi', silm=oldpi)
    old_model = make_old_model()

    # MBL
    # ---------------------------------------
    #viz = Visdom(env=env_id)
    win = None
    eval_targs = list(eval_targs)
    logger.log(eval_targs)

    make_model_f = get_make_mlp_model(num_fc=num_fc, num_fwd_hidden=num_fwd_hidden,
                                      layer_norm=use_layer_norm)
    mbl = MBL(env=eval_env, env_id=env_id, make_model=make_model_f,
              num_warm_start=num_warm_start, init_epochs=init_epochs,
              update_epochs=update_epochs, batch_size=batch_size, **network_kwargs)

    val_dataset = {'ob': None, 'ac': None, 'ob_next': None}
    if update_with_validation:
        logger.log('Update with validation')
        val_dataset = load_val_data(validation_set_path)
    if eval_val_err:
        logger.log('Log val error')
        eval_val_dataset = load_val_data(eval_val_err_path)
    if collect_val_data:
        logger.log('Collect validation data')
        val_dataset_collect = []

    def _mf_pi(ob, t=None):
        stochastic = True
        ac, vpred, _, _ = pi.step(ob, stochastic=stochastic)
        return ac, vpred

    def _mf_det_pi(ob, t=None):
        #ac, vpred, _, _ = pi.step(ob, stochastic=False)
        ac, vpred = pi._evaluate([pi.pd.mode(), pi.vf], ob)
        return ac, vpred

    def _mf_ent_pi(ob, t=None):
        mean, std, vpred = pi._evaluate([pi.pd.mode(), pi.pd.std, pi.vf], ob)
        ac = np.random.normal(mean, std * adj_std_scale, size=mean.shape)
        return ac, vpred

    # When use_ent_adjust is set, actions are sampled around the policy mean with a
    # std rescaled by adj_std_scale; otherwise the stochastic policy is used for the
    # first mbl_sh steps and the deterministic (mode) policy afterwards.
    def _mbmf_inner_pi(ob, t=0):
        if use_ent_adjust:
            return _mf_ent_pi(ob)
        else:
            #return _mf_pi(ob)
            if t < mbl_sh:
                return _mf_pi(ob)
            else:
                return _mf_det_pi(ob)

    # ---------------------------------------
    # Run multiple configurations at once
    all_eval_descs = []

    def make_mbmf_pi(n, h, e, l):
        def _mbmf_pi(ob):
            ac, rew = mbl.step(ob=ob, pi=_mbmf_inner_pi, horizon=h, num_samples=n,
                               num_elites=e, gamma=mbl_gamma, lamb=l,
                               use_mean_elites=use_mean_elites)
            return ac[None], rew
        return Policy(step=_mbmf_pi, reset=None)

    for n in num_samples:
        for h in horizon:
            for l in mbl_lamb:
                for e in num_elites:
                    if 'mbmf' in eval_targs:
                        all_eval_descs.append(('MeanRew', 'MBL_TRPO_SIL', make_mbmf_pi(n, h, e, l)))
                    #if 'mbmf' in eval_targs: all_eval_descs.append(('MeanRew-n-{}-h-{}-e-{}-l-{}-sh-{}-me-{}'.format(n, h, e, l, mbl_sh, use_mean_elites), 'MBL_TRPO-n-{}-h-{}-e-{}-l-{}-sh-{}-me-{}'.format(n, h, e, l, mbl_sh, use_mean_elites), make_mbmf_pi(n, h, e, l)))
    if 'mf' in eval_targs:
        all_eval_descs.append(('MeanRew', 'TRPO_SIL', Policy(step=_mf_pi, reset=None)))

    logger.log('List of evaluation targets')
    for it in all_eval_descs:
        logger.log(it[0])

    pool = Pool(mp.cpu_count())
    warm_start_done = False

    # ----------------------------------------
    atarg = tf.placeholder(dtype=tf.float32, shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = ent_coef * meanent

    vferr = tf.reduce_mean(tf.square(pi.vf - ret))

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    all_var_list = get_trainable_variables("pi")
    # var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("pol")]
    # vf_var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("vf")]
    var_list = get_pi_trainable_variables("pi")
    vf_var_list = get_vf_trainable_variables("pi")

    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32, shape=[None], name="flat_tan")
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start + sz], shape))
        start += sz
    gvp = tf.add_n([tf.reduce_sum(g * tangent)
                    for (g, tangent) in zipsame(klgrads, tangents)])  #pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function([], [], updates=[tf.assign(oldv, newv)
        for (oldv, newv) in zipsame(get_variables("oldpi"), get_variables("pi"))])
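
    # Note on the graph above: `gvp` is the inner product between the gradient of the
    # mean KL and an arbitrary tangent vector, so differentiating it once more (`fvp`)
    # yields a Hessian-vector product of the KL, i.e. the Fisher matrix applied to the
    # tangent. The conjugate-gradient solver further below uses this operator to solve
    # F x = g without ever forming F explicitly.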
    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses + [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret], U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(colorize("done in %.3f seconds" % (time.time() - tstart), color='magenta'))
        else:
            yield

    def allmean(x):
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()
    if load_path is not None:
        pi.load(load_path)
    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    if traj_collect == 'mf':
        seg_gen = traj_segment_generator(env, timesteps_per_batch, model, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    if sum([max_iters > 0, total_timesteps > 0, max_episodes > 0]) == 0:
        # nothing to be done
        return pi

    assert sum([max_iters > 0, total_timesteps > 0, max_episodes > 0]) < 2, \
        'out of max_iters, total_timesteps, and max_episodes only one should be specified'

    while True:
        if callback: callback(locals(), globals())
        if total_timesteps and timesteps_so_far >= total_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************" % iters_so_far)

        with timed("sampling"):
            seg = seg_gen.__next__()
            if traj_collect == 'mf-random' or traj_collect == 'mf-mb':
                seg_mbl = seg_gen_mbl.__next__()
            else:
                seg_mbl = seg
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]

        # Val data collection
        if collect_val_data:
            for ob_, ac_, ob_next_ in zip(ob[:-1, 0, ...], ac[:-1, ...], ob[1:, 0, ...]):
                val_dataset_collect.append((copy.copy(ob_), copy.copy(ac_), copy.copy(ob_next_)))
        # -----------------------------
        # MBL update
        else:
            ob_mbl, ac_mbl = seg_mbl["ob"], seg_mbl["ac"]
            mbl.add_data_batch(ob_mbl[:-1, 0, ...], ac_mbl[:-1, ...], ob_mbl[1:, 0, ...])
            mbl.update_forward_dynamic(require_update=iters_so_far % mbl_train_freq == 0,
                                       ob_val=val_dataset['ob'],
                                       ac_val=val_dataset['ac'],
                                       ob_next_val=val_dataset['ob_next'])
        # -----------------------------

        if traj_collect == 'mf':
            #if traj_collect == 'mf' or traj_collect == 'mf-random' or traj_collect == 'mf-mb':
            vpredbefore = seg["vpred"]  # predicted value function before update
            model = seg["model"]
            atarg = (atarg - atarg.mean()) / atarg.std()  # standardized advantage function estimate
            if hasattr(pi, "ret_rms"): pi.ret_rms.update(tdlamret)
            if hasattr(pi, "rms"): pi.rms.update(ob)  # update running mean/std for policy

            args = seg["ob"], seg["ac"], atarg
            fvpargs = [arr[::5] for arr in args]

            def fisher_vector_product(p):
                return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

            assign_old_eq_new()  # set old parameter values to new parameter values
            with timed("computegrad"):
                *lossbefore, g = compute_lossandgrad(*args)
            lossbefore = allmean(np.array(lossbefore))
            g = allmean(g)
            if np.allclose(g, 0):
                logger.log("Got zero gradient. not updating")
not updating") else: with timed("cg"): stepdir = cg(fisher_vector_product, g, cg_iters=cg_iters, verbose=rank == 0) assert np.isfinite(stepdir).all() shs = .5 * stepdir.dot(fisher_vector_product(stepdir)) lm = np.sqrt(shs / max_kl) # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g)) fullstep = stepdir / lm expectedimprove = g.dot(fullstep) surrbefore = lossbefore[0] stepsize = 1.0 thbefore = get_flat() for _ in range(10): thnew = thbefore + fullstep * stepsize set_from_flat(thnew) meanlosses = surr, kl, *_ = allmean( np.array(compute_losses(*args))) improve = surr - surrbefore logger.log("Expected: %.3f Actual: %.3f" % (expectedimprove, improve)) if not np.isfinite(meanlosses).all(): logger.log("Got non-finite value of losses -- bad!") elif kl > max_kl * 1.5: logger.log("violated KL constraint. shrinking step.") elif improve < 0: logger.log("surrogate didn't improve. shrinking step.") else: logger.log("Stepsize OK!") break stepsize *= .5 else: logger.log("couldn't compute a good step") set_from_flat(thbefore) if nworkers > 1 and iters_so_far % 20 == 0: paramsums = MPI.COMM_WORLD.allgather( (thnew.sum(), vfadam.getflat().sum())) # list of tuples assert all( np.allclose(ps, paramsums[0]) for ps in paramsums[1:]) for (lossname, lossval) in zip(loss_names, meanlosses): logger.record_tabular(lossname, lossval) with timed("vf"): for _ in range(vf_iters): for (mbob, mbret) in dataset.iterbatches( (seg["ob"], seg["tdlamret"]), include_final_partial_batch=False, batch_size=64): g = allmean(compute_vflossandgrad(mbob, mbret)) vfadam.update(g, vf_stepsize) with timed("SIL"): lrnow = lr(1.0 - timesteps_so_far / total_timesteps) l_loss, sil_adv, sil_samples, sil_nlogp = model.sil_train( lrnow) logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret)) lrlocal = (seg["ep_lens"], seg["ep_rets"]) # local values if MPI is not None: listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples else: listoflrpairs = [lrlocal] lens, rews = map(flatten_lists, zip(*listoflrpairs)) lenbuffer.extend(lens) rewbuffer.extend(rews) logger.record_tabular("EpLenMean", np.mean(lenbuffer)) logger.record_tabular("EpRewMean", np.mean(rewbuffer)) logger.record_tabular("EpThisIter", len(lens)) episodes_so_far += len(lens) timesteps_so_far += sum(lens) iters_so_far += 1 logger.record_tabular("EpisodesSoFar", episodes_so_far) logger.record_tabular("TimestepsSoFar", timesteps_so_far) logger.record_tabular("TimeElapsed", time.time() - tstart) if sil_update > 0: logger.record_tabular("SilSamples", sil_samples) if rank == 0: # MBL evaluation if not collect_val_data: #set_global_seeds(seed) default_sess = tf.get_default_session() def multithread_eval_policy(env_, pi_, num_episodes_, vis_eval_, seed): with default_sess.as_default(): if hasattr(env, 'ob_rms') and hasattr(env_, 'ob_rms'): env_.ob_rms = env.ob_rms res = eval_policy(env_, pi_, num_episodes_, vis_eval_, seed, measure_time, measure_rew) try: env_.close() except: pass return res if mbl.is_warm_start_done() and iters_so_far % eval_freq == 0: warm_start_done = mbl.is_warm_start_done() if num_eval_episodes > 0: targs_names = {} with timed('eval'): num_descs = len(all_eval_descs) list_field_names = [e[0] for e in all_eval_descs] list_legend_names = [e[1] for e in all_eval_descs] list_pis = [e[2] for e in all_eval_descs] list_eval_envs = [ make_eval_env() for _ in range(num_descs) ] list_seed = [seed for _ in range(num_descs)] list_num_eval_episodes = [ num_eval_episodes for _ in range(num_descs) ] print(list_field_names) 
                            print(list_legend_names)
                            list_vis_eval = [vis_eval for _ in range(num_descs)]
                            for i in range(num_descs):
                                field_name, legend_name = list_field_names[i], list_legend_names[i]
                                res = multithread_eval_policy(list_eval_envs[i], list_pis[i],
                                                              list_num_eval_episodes[i],
                                                              list_vis_eval[i], seed)
                                #eval_results = pool.starmap(multithread_eval_policy, zip(list_eval_envs, list_pis, list_num_eval_episodes, list_vis_eval, list_seed))
                                #for field_name, legend_name, res in zip(list_field_names, list_legend_names, eval_results):
                                perf, elapsed_time, eval_rew = res
                                logger.record_tabular(field_name, perf)
                                if measure_time:
                                    logger.record_tabular('Time-%s' % (field_name), elapsed_time)
                                if measure_rew:
                                    logger.record_tabular('SimRew-%s' % (field_name), eval_rew)
                                targs_names[field_name] = legend_name

                    if eval_val_err:
                        fwd_dynamics_err = mbl.eval_forward_dynamic(obs=eval_val_dataset['ob'],
                                                                    acs=eval_val_dataset['ac'],
                                                                    obs_next=eval_val_dataset['ob_next'])
                        logger.record_tabular('FwdValError', fwd_dynamics_err)

                    logger.dump_tabular()
                    #print(logger.get_dir())
                    #print(targs_names)
                    #if num_eval_episodes > 0:
                    #    win = plot(viz, win, logger.get_dir(), targs_names=targs_names, quant=quant, opt='best')
            # -----------
        #logger.dump_tabular()
        yield pi

    if collect_val_data:
        with open(validation_set_path, 'wb') as f:
            pickle.dump(val_dataset_collect, f)
        logger.log('Saved {} validation transitions'.format(len(val_dataset_collect)))
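
# ----------------------------------------------------------------------------
# Usage sketch (not executed): learn() above is a generator that yields the
# current policy `pi` once per TRPO iteration. The environment id, the
# DummyVecEnv wrapping, and the SIL hyperparameters below are illustrative
# assumptions, not values prescribed by this module.
#
#   import gym
#   from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
#
#   env_id = 'HalfCheetah-v2'                      # assumed example task
#   make_env = lambda: DummyVecEnv([lambda: gym.make(env_id)])
#   for pi in learn(network='mlp',
#                   env=make_env(), eval_env=make_env(), make_eval_env=make_env,
#                   env_id=env_id, total_timesteps=int(1e6),
#                   timesteps_per_batch=1024, sil_update=10, sil_loss=0.1):
#       pass  # each yield marks one completed policy update; pi is the latest policy
# ----------------------------------------------------------------------------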
def learn(
        *, network, env, eval_env, make_eval_env, env_id,
        total_timesteps, seed=None, nsteps=2048,
        ent_coef=0.0, lr=3e-4, vf_coef=0.5, max_grad_norm=0.5,
        gamma=0.99, lam=0.95,
        log_interval=10, nminibatches=4, noptepochs=4, cliprange=0.2,
        # MBL
        # For training the MBL forward model
        mbl_train_freq=5,
        # For eval
        num_eval_episodes=5, eval_freq=5, vis_eval=False,
        #eval_targs=('mbmf',),
        eval_targs=('mf',),
        quant=2,
        # For mbl.step
        #num_samples=(1500,),
        num_samples=(1,),
        horizon=(2,),
        #horizon=(2,1),
        #num_elites=(10,),
        num_elites=(1,),
        mbl_lamb=(1.0,), mbl_gamma=0.99,
        #mbl_sh=1,
        # Number of steps of stochastic sampling
        mbl_sh=10000,
        #vf_lookahead=-1,
        #use_max_vf=False,
        reset_per_step=(0,),
        # For get_model
        num_fc=2, num_fwd_hidden=500, use_layer_norm=False,
        # For MBL
        num_warm_start=int(1e4), init_epochs=10, update_epochs=5, batch_size=512,
        update_with_validation=False, use_mean_elites=1,
        use_ent_adjust=0, adj_std_scale=0.5,
        # For data loading
        validation_set_path=None,
        # For data collection
        collect_val_data=False,
        # For trajectory collection
        traj_collect='mf',
        # For profiling
        measure_time=True, eval_val_err=False, measure_rew=True,
        save_interval=0, load_path=None,
        model_fn=None, update_fn=None, init_fn=None,
        mpi_rank_weight=1, comm=None,
        **network_kwargs):
    '''
    Learn a policy using the PPO algorithm (https://arxiv.org/abs/1707.06347).

    Parameters:
    ----------

    network:                          policy network architecture. Either string (mlp, lstm, lnlstm, cnn_lstm, cnn, cnn_small, conv_only - see baselines.common/models.py for full list)
                                      specifying the standard network architecture, or a function that takes a tensorflow tensor as input and returns
                                      tuple (output_tensor, extra_feed), where output_tensor is the last network layer output, extra_feed is None for
                                      feed-forward neural nets, and extra_feed is a dictionary describing how to feed state into the network for
                                      recurrent neural nets. See common/models.py/lstm for more details on using recurrent nets in policies

    env: baselines.common.vec_env.VecEnv     environment. Needs to be vectorized for parallel environment simulation.
                                      The environments produced by gym.make can be wrapped using the baselines.common.vec_env.DummyVecEnv class.

    nsteps: int                       number of steps of the vectorized environment per update (i.e. batch size is nsteps * nenv where
                                      nenv is the number of environment copies simulated in parallel)

    total_timesteps: int              number of timesteps (i.e. number of actions taken in the environment)

    ent_coef: float                   policy entropy coefficient in the optimization objective

    lr: float or function             learning rate, constant or a schedule function [0,1] -> R+ where 1 is the beginning of
                                      training and 0 is the end of training.

    vf_coef: float                    value function loss coefficient in the optimization objective

    max_grad_norm: float or None      gradient norm clipping coefficient

    gamma: float                      discounting factor

    lam: float                        advantage estimation discounting factor (lambda in the paper)

    log_interval: int                 number of timesteps between logging events

    nminibatches: int                 number of training minibatches per update. For recurrent policies,
                                      should be smaller than or equal to the number of environments run in parallel.

    noptepochs: int                   number of training epochs per update

    cliprange: float or function      clipping range, constant or schedule function [0,1] -> R+ where 1 is the beginning of
                                      training and 0 is the end of training

    save_interval: int                number of timesteps between saving events

    load_path: str                    path to load the model from

    **network_kwargs:                 keyword arguments to the policy / network builder.
                                      See baselines.common/policies.py/build_policy and arguments to a particular type of network.
                                      For instance, the 'mlp' network architecture has arguments num_hidden and num_layers.

    '''
    if not isinstance(num_samples, tuple): num_samples = (num_samples,)
    if not isinstance(horizon, tuple): horizon = (horizon,)
    if not isinstance(num_elites, tuple): num_elites = (num_elites,)
    if not isinstance(mbl_lamb, tuple): mbl_lamb = (mbl_lamb,)
    if not isinstance(reset_per_step, tuple): reset_per_step = (reset_per_step,)
    if validation_set_path is None:
        if collect_val_data:
            validation_set_path = os.path.join(logger.get_dir(), 'val.pkl')
        else:
            validation_set_path = os.path.join('dataset', '{}-val.pkl'.format(env_id))
    if eval_val_err:
        eval_val_err_path = os.path.join('dataset', '{}-combine-val.pkl'.format(env_id))
    logger.log(locals())
    logger.log('MBL_SH', mbl_sh)
    logger.log('Traj_collect', traj_collect)

    if MPI is not None:
        nworkers = MPI.COMM_WORLD.Get_size()
        rank = MPI.COMM_WORLD.Get_rank()
    else:
        nworkers = 1
        rank = 0
    cpus_per_worker = 1
    U.get_session(config=tf.ConfigProto(allow_soft_placement=True,
                                        inter_op_parallelism_threads=cpus_per_worker,
                                        intra_op_parallelism_threads=cpus_per_worker))

    set_global_seeds(seed)
    if isinstance(lr, float): lr = constfn(lr)
    else: assert callable(lr)
    if isinstance(cliprange, float): cliprange = constfn(cliprange)
    else: assert callable(cliprange)
    total_timesteps = int(total_timesteps)

    policy = build_policy(env, network, **network_kwargs)
    np.set_printoptions(precision=3)

    # Get the number of environments
    nenvs = env.num_envs
    # Get state_space and action_space
    ob_space = env.observation_space
    ac_space = env.action_space
    # Calculate the batch size
    nbatch = nenvs * nsteps
    nbatch_train = nbatch // nminibatches
    is_mpi_root = (MPI is None or MPI.COMM_WORLD.Get_rank() == 0)

    # Instantiate the model object (that creates act_model and train_model)
    if model_fn is None:
        model_fn = Model

    model = model_fn(policy=policy, ob_space=ob_space, ac_space=ac_space, nbatch_act=nenvs,
                     nbatch_train=nbatch_train, nsteps=nsteps, ent_coef=ent_coef,
                     vf_coef=vf_coef, max_grad_norm=max_grad_norm, comm=comm,
                     mpi_rank_weight=mpi_rank_weight, ppo=True, prev_pi=None)
    pi = model.act_model

    if load_path is not None:
        model.load(load_path)

    # MBL
    # ---------------------------------------
    #viz = Visdom(env=env_id)
    win = None
    eval_targs = list(eval_targs)
    logger.log(eval_targs)

    make_model = get_make_mlp_model(num_fc=num_fc, num_fwd_hidden=num_fwd_hidden,
                                    layer_norm=use_layer_norm)
    mbl = MBL(env=eval_env, env_id=env_id, make_model=make_model,
              num_warm_start=num_warm_start, init_epochs=init_epochs,
              update_epochs=update_epochs, batch_size=batch_size, **network_kwargs)

    val_dataset = {'ob': None, 'ac': None, 'ob_next': None}
    if update_with_validation:
        logger.log('Update with validation')
        val_dataset = load_val_data(validation_set_path)
    if eval_val_err:
        logger.log('Log val error')
        eval_val_dataset = load_val_data(eval_val_err_path)
    if collect_val_data:
        logger.log('Collect validation data')
        val_dataset_collect = []

    def _mf_pi(ob, t=None):
        stochastic = True
        ac, vpred, _, _ = pi.step(ob, stochastic=stochastic)
        return ac, vpred

    def _mf_det_pi(ob, t=None):
        #ac, vpred, _, _ = pi.step(ob, stochastic=False)
        ac, vpred = pi._evaluate([pi.pd.mode(), pi.vf], ob)
        return ac, vpred

    def _mf_ent_pi(ob, t=None):
        mean, std, vpred = pi._evaluate([pi.pd.mode(), pi.pd.std, pi.vf], ob)
        ac = np.random.normal(mean, std * adj_std_scale, size=mean.shape)
        return ac, vpred

    # When use_ent_adjust is set, actions are sampled around the policy mean with a
    # std rescaled by adj_std_scale; otherwise the stochastic policy is used for the
    # first mbl_sh steps and the deterministic (mode) policy afterwards.
    def _mbmf_inner_pi(ob, t=0):
        if use_ent_adjust:
            return _mf_ent_pi(ob)
        else:
            #return _mf_pi(ob)
            if t < mbl_sh:
                return _mf_pi(ob)
            else:
                return _mf_det_pi(ob)

    # ---------------------------------------
    # Run multiple configurations at once
    all_eval_descs = []

    def make_mbmf_pi(n, h, e, l):
        def _mbmf_pi(ob):
            ac, rew = mbl.step(ob=ob, pi=_mbmf_inner_pi, horizon=h, num_samples=n,
                               num_elites=e, gamma=mbl_gamma, lamb=l,
                               use_mean_elites=use_mean_elites)
            return ac[None], rew
        return Policy(step=_mbmf_pi, reset=None)

    for n in num_samples:
        for h in horizon:
            for l in mbl_lamb:
                for e in num_elites:
                    if 'mbmf' in eval_targs:
                        all_eval_descs.append(('MeanRew', 'MBL_PPO', make_mbmf_pi(n, h, e, l)))
                    #if 'mbmf' in eval_targs: all_eval_descs.append(('MeanRew-n-{}-h-{}-e-{}-l-{}-sh-{}-me-{}'.format(n, h, e, l, mbl_sh, use_mean_elites), 'MBL_TRPO-n-{}-h-{}-e-{}-l-{}-sh-{}-me-{}'.format(n, h, e, l, mbl_sh, use_mean_elites), make_mbmf_pi(n, h, e, l)))
    if 'mf' in eval_targs:
        all_eval_descs.append(('MeanRew', 'PPO', Policy(step=_mf_pi, reset=None)))

    logger.log('List of evaluation targets')
    for it in all_eval_descs:
        logger.log(it[0])

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(colorize("done in %.3f seconds" % (time.time() - tstart), color='magenta'))
        else:
            yield

    pool = Pool(mp.cpu_count())
    warm_start_done = False
    U.initialize()
    if load_path is not None:
        pi.load(load_path)

    # Instantiate the runner object
    runner = Runner(env=env, model=model, nsteps=nsteps, gamma=gamma, lam=lam)
    epinfobuf = deque(maxlen=40)

    if init_fn is not None:
        init_fn()

    if traj_collect == 'mf':
        obs = runner.run()[0]

    # Start total timer
    tfirststart = time.perf_counter()
    nupdates = total_timesteps // nbatch

    for update in range(1, nupdates + 1):
        assert nbatch % nminibatches == 0
        # Start timer
        if hasattr(model.train_model, "ret_rms"):
            model.train_model.ret_rms.update(returns)
        if hasattr(model.train_model, "rms"):
            model.train_model.rms.update(obs)
        tstart = time.perf_counter()
        frac = 1.0 - (update - 1.0) / nupdates
        # Calculate the learning rate
        lrnow = lr(frac)
        # Calculate the cliprange
        cliprangenow = cliprange(frac)

        if update % log_interval == 0 and is_mpi_root:
            logger.info('Stepping environment...')

        # Get minibatch
        obs, returns, masks, actions, values, neglogpacs, states, epinfos = runner.run()  #pylint: disable=E0632

        # Val data collection
        if collect_val_data:
            for ob_, ac_, ob_next_ in zip(obs[:-1, 0, ...], actions[:-1, ...], obs[1:, 0, ...]):
                val_dataset_collect.append((copy.copy(ob_), copy.copy(ac_), copy.copy(ob_next_)))
        # -----------------------------
        # MBL update
        else:
            ob_mbl, ac_mbl = obs.copy(), actions.copy()
            mbl.add_data_batch(ob_mbl[:-1, ...], ac_mbl[:-1, ...], ob_mbl[1:, ...])
            mbl.update_forward_dynamic(require_update=(update - 1) % mbl_train_freq == 0,
                                       ob_val=val_dataset['ob'],
                                       ac_val=val_dataset['ac'],
                                       ob_next_val=val_dataset['ob_next'])
        # -----------------------------

        if update % log_interval == 0 and is_mpi_root:
            logger.info('Done.')
        epinfobuf.extend(epinfos)

        # For each minibatch, compute the loss and append it.
        mblossvals = []
        if states is None:  # nonrecurrent version
            # Index of each element of batch_size
            # Create the indices array
            inds = np.arange(nbatch)
            for _ in range(noptepochs):
                # Randomize the indices
                np.random.shuffle(inds)
                # 0 to batch_size with batch_train_size step
                for start in range(0, nbatch, nbatch_train):
                    end = start + nbatch_train
                    mbinds = inds[start:end]
                    slices = (arr[mbinds] for arr in (obs, returns, masks, actions, values, neglogpacs))
                    mblossvals.append(model.train(lrnow, cliprangenow, *slices))
        else:  # recurrent version
            assert nenvs % nminibatches == 0
            envsperbatch = nenvs // nminibatches
            envinds = np.arange(nenvs)
            flatinds = np.arange(nenvs * nsteps).reshape(nenvs, nsteps)
            for _ in range(noptepochs):
                np.random.shuffle(envinds)
                for start in range(0, nenvs, envsperbatch):
                    end = start + envsperbatch
                    mbenvinds = envinds[start:end]
                    mbflatinds = flatinds[mbenvinds].ravel()
                    slices = (arr[mbflatinds] for arr in (obs, returns, masks, actions, values, neglogpacs))
                    mbstates = states[mbenvinds]
                    mblossvals.append(model.train(lrnow, cliprangenow, *slices, mbstates))

        # Feedforward --> get losses --> update
        lossvals = np.mean(mblossvals, axis=0)
        # End timer
        tnow = time.perf_counter()
        # Calculate the fps (frames per second)
        fps = int(nbatch / (tnow - tstart))

        if update_fn is not None:
            update_fn(update)

        if update % log_interval == 0 or update == 1:
            # Calculates whether the value function is a good predictor of the returns (ev > 1)
            # or if it's just worse than predicting nothing (ev =< 0)
            ev = explained_variance(values, returns)
            logger.logkv("misc/serial_timesteps", update * nsteps)
            logger.logkv("misc/nupdates", update)
            logger.logkv("misc/total_timesteps", update * nbatch)
            logger.logkv("fps", fps)
            logger.logkv("misc/explained_variance", float(ev))
            logger.logkv('eprewmean', safemean([epinfo['r'] for epinfo in epinfobuf]))
            logger.logkv("AverageReturn", safemean([epinfo['r'] for epinfo in epinfobuf]))
            logger.logkv('eplenmean', safemean([epinfo['l'] for epinfo in epinfobuf]))
            logger.logkv('misc/time_elapsed', tnow - tfirststart)
            for (lossval, lossname) in zip(lossvals, model.loss_names):
                logger.logkv('loss/' + lossname, lossval)

            if rank == 0:
                # MBL evaluation
                if not collect_val_data:
                    #set_global_seeds(seed)
                    default_sess = tf.get_default_session()

                    def multithread_eval_policy(env_, pi_, num_episodes_, vis_eval_, seed):
                        with default_sess.as_default():
                            if hasattr(env, 'ob_rms') and hasattr(env_, 'ob_rms'):
                                env_.ob_rms = env.ob_rms
                            res = eval_policy(env_, pi_, num_episodes_, vis_eval_, seed,
                                              measure_time, measure_rew)
                            try:
                                env_.close()
                            except:
                                pass
                        return res

                    #if mbl.forward_dynamic.memory.nb_entries >= mbl.num_warm_start and update % eval_freq == 0:
                    if mbl.is_warm_start_done() and update % eval_freq == 0:
                        warm_start_done = mbl.is_warm_start_done()
                        if num_eval_episodes > 0:
                            targs_names = {}
                            with timed('eval'):
                                num_descs = len(all_eval_descs)
                                list_field_names = [e[0] for e in all_eval_descs]
                                list_legend_names = [e[1] for e in all_eval_descs]
                                list_pis = [e[2] for e in all_eval_descs]
                                list_eval_envs = [make_eval_env() for _ in range(num_descs)]
                                list_seed = [seed for _ in range(num_descs)]
                                list_num_eval_episodes = [num_eval_episodes for _ in range(num_descs)]
                                print(list_field_names)
                                print(list_legend_names)
                                list_vis_eval = [vis_eval for _ in range(num_descs)]
                                for i in range(num_descs):
                                    field_name, legend_name = list_field_names[i], list_legend_names[i]
                                    res = multithread_eval_policy(list_eval_envs[i], list_pis[i],
                                                                  list_num_eval_episodes[i],
                                                                  list_vis_eval[i], seed)
                                    #eval_results = pool.starmap(multithread_eval_policy, zip(list_eval_envs, list_pis, list_num_eval_episodes, list_vis_eval, list_seed))
                                    #for field_name, legend_name, res in zip(list_field_names, list_legend_names, eval_results):
                                    perf, elapsed_time, eval_rew = res
                                    logger.logkv(field_name, perf)
                                    if measure_time:
                                        logger.logkv('Time-%s' % (field_name), elapsed_time)
                                    if measure_rew:
                                        logger.logkv('SimRew-%s' % (field_name), eval_rew)
                                    targs_names[field_name] = legend_name

                        if eval_val_err:
                            fwd_dynamics_err = mbl.eval_forward_dynamic(obs=eval_val_dataset['ob'],
                                                                        acs=eval_val_dataset['ac'],
                                                                        obs_next=eval_val_dataset['ob_next'])
                            logger.logkv('FwdValError', fwd_dynamics_err)

                        #logger.dump_tabular()
                        logger.dumpkvs()
                        #print(logger.get_dir())
                        #print(targs_names)
                        #if num_eval_episodes > 0:
                        #    win = plot(viz, win, logger.get_dir(), targs_names=targs_names, quant=quant, opt='best')
                        #else: logger.dumpkvs()
            # -----------
        yield pi

        if collect_val_data:
            with open(validation_set_path, 'wb') as f:
                pickle.dump(val_dataset_collect, f)
            logger.log('Saved {} validation transitions'.format(len(val_dataset_collect)))

        if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir() and is_mpi_root:
            checkdir = osp.join(logger.get_dir(), 'checkpoints')
            os.makedirs(checkdir, exist_ok=True)
            savepath = osp.join(checkdir, '%.5i' % update)
            print('Saving to', savepath)
            model.save(savepath)

    return model
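
# ----------------------------------------------------------------------------
# Usage sketch (not executed): this PPO variant of learn() is likewise a
# generator; iterating it runs one PPO update per step and yields the acting
# policy, and the final `return model` ends the generator. The environment id
# and the DummyVecEnv wrapping below are illustrative assumptions only.
#
#   import gym
#   from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
#
#   env_id = 'Hopper-v2'                           # assumed example task
#   make_env = lambda: DummyVecEnv([lambda: gym.make(env_id)])
#   for pi in learn(network='mlp',
#                   env=make_env(), eval_env=make_env(), make_eval_env=make_env,
#                   env_id=env_id, total_timesteps=int(1e6), nsteps=2048):
#       pass  # one PPO update (nsteps * nenvs timesteps) per yielded policy
# ----------------------------------------------------------------------------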