def learn(policy, env, nsteps, total_timesteps, ent_coef, lr, vf_coef=0.5,
          max_grad_norm=0.5, gamma=0.99, lam=0.95, log_interval=10,
          nminibatches=4, noptepochs=4, cliprange=0.2, save_interval=0,
          load_path=None, train_callback=None, eval_callback=None,
          cloud_sync_callback=None, cloud_sync_interval=1000, workdir='',
          use_curiosity=False, curiosity_strength=0.01,
          forward_inverse_ratio=0.2, curiosity_loss_strength=10,
          random_state_predictor=False):
  """PPO training loop (baselines-ppo2 style) with optional curiosity.

  Repeatedly collects `nsteps`-long rollouts from the vectorized `env`,
  then runs `noptepochs` epochs of minibatch gradient steps per rollout.
  The per-update progress fraction `frac` (1.0 -> ~0.0) is fed to the
  `lr` and `cliprange` schedules.

  Args:
    policy: policy constructor, forwarded to `Model`.
    env: vectorized environment exposing `num_envs`, `observation_space`,
      `action_space`, and `close()`.
    nsteps: rollout length per environment per update.
    total_timesteps: total number of environment steps to train for.
    ent_coef: entropy bonus coefficient (forwarded to `Model`).
    lr: learning rate — a float (wrapped via `constfn`) or a callable of
      the remaining-progress fraction.
    vf_coef: value-function loss coefficient.
    max_grad_norm: gradient-clipping norm.
    gamma: discount factor (forwarded to `Runner`).
    lam: GAE lambda (forwarded to `Runner`).
    log_interval: log every this many updates (update 1 always logs).
    nminibatches: number of minibatches; must divide `nenvs * nsteps`.
    noptepochs: optimization epochs per rollout.
    cliprange: PPO clip range — float or callable, like `lr`.
    save_interval: checkpoint every this many updates (0 disables).
    load_path: optional model checkpoint to load before training.
    train_callback: if set, called as (eplenmean, eprewmean, timesteps)
      at each logging step.
    eval_callback: forwarded to `Runner`.
    cloud_sync_callback: if set, called every `cloud_sync_interval` updates.
    cloud_sync_interval: see above (0/None disables).
    workdir: directory for `make_model.pkl` and checkpoints.
    use_curiosity: enable the curiosity module in `Model`.
    curiosity_strength: curiosity reward scale (forwarded to `Model`).
    forward_inverse_ratio: forward/inverse loss mix (forwarded to `Model`).
    curiosity_loss_strength: curiosity loss scale (forwarded to `Model`).
    random_state_predictor: use a random state predictor (forwarded).

  Returns:
    The trained `Model` instance.
  """
  # Normalize scalar hyperparameters into schedules (callables of frac).
  if isinstance(lr, float):
    lr = constfn(lr)
  else:
    assert callable(lr)
  if isinstance(cliprange, float):
    cliprange = constfn(cliprange)
  else:
    assert callable(cliprange)
  total_timesteps = int(total_timesteps)

  nenvs = env.num_envs
  ob_space = env.observation_space
  ac_space = env.action_space
  nbatch = nenvs * nsteps                 # transitions gathered per update
  nbatch_train = nbatch // nminibatches   # minibatch size
  # pylint: disable=g-long-lambda
  make_model = lambda: Model(policy=policy, ob_space=ob_space,
                             ac_space=ac_space, nbatch_act=nenvs,
                             nbatch_train=nbatch_train,
                             nsteps=nsteps, ent_coef=ent_coef,
                             vf_coef=vf_coef,
                             max_grad_norm=max_grad_norm,
                             use_curiosity=use_curiosity,
                             curiosity_strength=curiosity_strength,
                             forward_inverse_ratio=forward_inverse_ratio,
                             curiosity_loss_strength=curiosity_loss_strength,
                             random_state_predictor=random_state_predictor)
  # pylint: enable=g-long-lambda
  if save_interval and workdir:
    # Persist the model factory so checkpoints can be rebuilt later.
    with tf.gfile.Open(osp.join(workdir, 'make_model.pkl'), 'wb') as fh:
      fh.write(dill.dumps(make_model))
  model = make_model()
  if load_path is not None:
    model.load(load_path)
  runner = Runner(env=env, model=model, nsteps=nsteps, gamma=gamma, lam=lam,
                  eval_callback=eval_callback)

  epinfobuf = deque(maxlen=100)  # rolling window of recent episode stats
  tfirststart = time.time()

  nupdates = total_timesteps // nbatch
  for update in range(1, nupdates + 1):
    assert nbatch % nminibatches == 0
    # NOTE(review): recomputed each iteration; identical to the value
    # computed before the loop.
    nbatch_train = nbatch // nminibatches
    tstart = time.time()
    # Remaining-progress fraction: 1.0 on the first update, -> ~0 at the end.
    frac = 1.0 - (update - 1.0) / nupdates
    lrnow = lr(frac)
    cliprangenow = cliprange(frac)
    # Collect one rollout. `states` is None for non-recurrent policies.
    (obs, next_obs, returns, masks, actions, values,
     neglogpacs), states, epinfos = runner.run()
    epinfobuf.extend(epinfos)
    mblossvals = []
    if states is None:  # nonrecurrent version
      # Shuffle individual transitions each epoch.
      inds = np.arange(nbatch)
      for _ in range(noptepochs):
        np.random.shuffle(inds)
        for start in range(0, nbatch, nbatch_train):
          end = start + nbatch_train
          mbinds = inds[start:end]
          slices = [
              arr[mbinds]
              for arr in (obs, returns, masks, actions, values, neglogpacs,
                          next_obs)
          ]
          mblossvals.append(
              model.train(lrnow, cliprangenow, slices[0], slices[6],
                          slices[1], slices[2], slices[3], slices[4],
                          slices[5]))
    else:  # recurrent version
      # Shuffle whole environments so each minibatch keeps contiguous
      # per-env time sequences (needed for recurrent state).
      assert nenvs % nminibatches == 0
      envsperbatch = nenvs // nminibatches
      envinds = np.arange(nenvs)
      flatinds = np.arange(nenvs * nsteps).reshape(nenvs, nsteps)
      # NOTE(review): overwrites the value above; equal when
      # nbatch % nminibatches == 0.
      envsperbatch = nbatch_train // nsteps
      for _ in range(noptepochs):
        np.random.shuffle(envinds)
        for start in range(0, nenvs, envsperbatch):
          end = start + envsperbatch
          mbenvinds = envinds[start:end]
          mbflatinds = flatinds[mbenvinds].ravel()
          slices = [
              arr[mbflatinds]
              for arr in (obs, returns, masks, actions, values, neglogpacs,
                          next_obs)
          ]
          mbstates = states[mbenvinds]
          mblossvals.append(
              model.train(lrnow, cliprangenow, slices[0], slices[6],
                          slices[1], slices[2], slices[3], slices[4],
                          slices[5], mbstates))

    # Mean of each loss component across all minibatch steps this update.
    lossvals = np.mean(mblossvals, axis=0)
    tnow = time.time()
    fps = int(nbatch / (tnow - tstart))
    if update % log_interval == 0 or update == 1:
      ev = explained_variance(values, returns)
      logger.logkv('serial_timesteps', update * nsteps)
      logger.logkv('nupdates', update)
      logger.logkv('total_timesteps', update * nbatch)
      logger.logkv('fps', fps)
      logger.logkv('explained_variance', float(ev))
      logger.logkv('eprewmean',
                   safemean([epinfo['r'] for epinfo in epinfobuf]))
      logger.logkv('eplenmean',
                   safemean([epinfo['l'] for epinfo in epinfobuf]))
      if train_callback:
        train_callback(safemean([epinfo['l'] for epinfo in epinfobuf]),
                       safemean([epinfo['r'] for epinfo in epinfobuf]),
                       update * nbatch)
      logger.logkv('time_elapsed', tnow - tfirststart)
      for (lossval, lossname) in zip(lossvals, model.loss_names):
        logger.logkv(lossname, lossval)
      logger.dumpkvs()
    if (save_interval and (update % save_interval == 0 or update == 1) and
        workdir):
      checkdir = osp.join(workdir, 'checkpoints')
      if not tf.gfile.Exists(checkdir):
        tf.gfile.MakeDirs(checkdir)
      savepath = osp.join(checkdir, '%.5i' % update)
      print('Saving to', savepath)
      model.save(savepath)
    if (cloud_sync_interval and update % cloud_sync_interval == 0 and
        cloud_sync_callback):
      cloud_sync_callback()
  env.close()
  return model
def eval_callback_on_test(eprewmean, global_step_val):
  """Reports a test-set mean episode reward.

  Creates a measurement on the module-level `test_measurements` recorder
  when one is configured, and always logs `eprewmean_test`.
  """
  recorder = test_measurements
  if recorder:
    recorder.create_measurement(
        objective_value=eprewmean, step=global_step_val)
  logger.logkv('eprewmean_test', eprewmean)
def eval_callback_on_valid(eprewmean, global_step_val):
  """Reports a validation-set mean episode reward.

  Creates a measurement on the module-level `valid_measurements` recorder
  when one is configured, and always logs `eprewmean_valid`.
  """
  recorder = valid_measurements
  if recorder:
    recorder.create_measurement(
        objective_value=eprewmean, step=global_step_val)
  logger.logkv('eprewmean_valid', eprewmean)
def measurement_callback(unused_eplenmean, eprewmean, global_step_val):
  """Reports a training mean episode reward (episode length is ignored).

  Creates a measurement on the module-level `train_measurements` recorder
  when one is configured, and always logs `eprewmean_train`.
  """
  recorder = train_measurements
  if recorder:
    recorder.create_measurement(
        objective_value=eprewmean, step=global_step_val)
  logger.logkv('eprewmean_train', eprewmean)
def learn(policy, env, nsteps, total_timesteps, ent_coef, lr, vf_coef=0.5,
          max_grad_norm=0.5, gamma=0.99, lam=0.95, log_interval=10,
          nminibatches=4, noptepochs=4, cliprange=0.2, save_interval=0,
          load_path=None, train_callback=None, eval_callback=None,
          cloud_sync_callback=None, cloud_sync_interval=1000, workdir='',
          use_curiosity=False, curiosity_strength=0.01,
          forward_inverse_ratio=0.2, curiosity_loss_strength=10,
          random_state_predictor=False, use_rlb=False,
          checkpoint_path_for_debugging=None):
  """PPO training loop with optional curiosity and RLB instrumentation.

  Variant of the plain `learn` above: additionally supports an RLB model
  (`use_rlb`), full TF-session checkpoints, per-update histogram/scalar
  statistics from `model.train`, intrinsic-reward logging, and an optional
  TF timeline dump for profiling.

  Args (beyond the plain `learn`):
    use_rlb: forwarded to `Model`; enables the RLB module.
    checkpoint_path_for_debugging: if set, restores variables in the
      'rlb_model' scope from this checkpoint after building the model.

  Returns:
    The trained `Model` instance.
  """
  # Normalize scalar hyperparameters into schedules (callables of frac).
  if isinstance(lr, float):
    lr = constfn(lr)
  else:
    assert callable(lr)
  if isinstance(cliprange, float):
    cliprange = constfn(cliprange)
  else:
    assert callable(cliprange)
  total_timesteps = int(total_timesteps)

  nenvs = env.num_envs
  ob_space = env.observation_space
  ac_space = env.action_space
  nbatch = nenvs * nsteps                 # transitions gathered per update
  nbatch_train = nbatch // nminibatches   # minibatch size
  # pylint: disable=g-long-lambda
  make_model = lambda: Model(policy=policy, ob_space=ob_space,
                             ac_space=ac_space, nbatch_act=nenvs,
                             nbatch_train=nbatch_train,
                             nsteps=nsteps, ent_coef=ent_coef,
                             vf_coef=vf_coef,
                             max_grad_norm=max_grad_norm,
                             use_curiosity=use_curiosity,
                             curiosity_strength=curiosity_strength,
                             forward_inverse_ratio=forward_inverse_ratio,
                             curiosity_loss_strength=curiosity_loss_strength,
                             random_state_predictor=random_state_predictor,
                             use_rlb=use_rlb)
  # pylint: enable=g-long-lambda
  if save_interval and workdir:
    # Persist the model factory so checkpoints can be rebuilt later.
    with tf.gfile.Open(osp.join(workdir, 'make_model.pkl'), 'wb') as fh:
      fh.write(dill.dumps(make_model))

  # Saver for full-session checkpoints (kept forever: huge max_to_keep).
  saver = tf.train.Saver(max_to_keep=10000000)

  def save_state(fname):
    # Saves the entire default TF session to `fname`, creating parent dirs.
    if not osp.exists(osp.dirname(fname)):
      os.makedirs(osp.dirname(fname))
    saver.save(tf.get_default_session(), fname)

  with tf.device('/gpu:0'):
    model = make_model()
  if load_path is not None:
    model.load(load_path)
  runner = Runner(env=env, model=model, nsteps=nsteps, gamma=gamma, lam=lam,
                  eval_callback=eval_callback)

  if checkpoint_path_for_debugging is not None:
    # Restore only the RLB-scope variables for debugging runs.
    tf_util.load_state(
        checkpoint_path_for_debugging,
        var_list=tf.get_collection(
            tf.GraphKeys.GLOBAL_VARIABLES, scope='rlb_model'))

  epinfobuf = deque(maxlen=100)  # rolling window of recent episode stats
  tfirststart = time.time()

  nupdates = total_timesteps // nbatch
  for update in range(1, nupdates + 1):
    assert nbatch % nminibatches == 0
    # NOTE(review): recomputed each iteration; identical to the value
    # computed before the loop.
    nbatch_train = nbatch // nminibatches
    tstart = time.time()
    # Remaining-progress fraction: 1.0 on the first update, -> ~0 at the end.
    frac = 1.0 - (update - 1.0) / nupdates
    lrnow = lr(frac)
    cliprangenow = cliprange(frac)
    # Rollout plus reward diagnostics (extrinsic/intrinsic components).
    (obs, next_obs, returns, masks, actions, values, neglogpacs), states, \
        epinfos, (rewards, rewards_ext, rewards_int, rewards_int_raw,
                  selected_infos, dones) = runner.run()
    epinfobuf.extend(epinfos)
    mblossvals = []
    mbhistos = []  # histogram stats, gathered only on the last epoch
    mbscs = []     # scalar stats, gathered once per update
    #if model.all_rlb_args.debug_args['debug_tf_timeline'] and update % 5 == 0:
    if model.all_rlb_args.debug_args[
        'debug_tf_timeline'] and update % 1 == 0:
      debug_timeliner = logger.TimeLiner()
    else:
      debug_timeliner = None
    if states is None:  # nonrecurrent version
      inds = np.arange(nbatch)
      for oe in range(noptepochs):
        # Gather histograms on every minibatch of the final epoch; gather
        # scalars only on the very last minibatch.
        gather_histo = (oe == noptepochs - 1)
        np.random.shuffle(inds)
        for start in range(0, nbatch, nbatch_train):
          gather_sc = ((oe == noptepochs - 1) and
                       (start + nbatch_train >= nbatch))
          end = start + nbatch_train
          mbinds = inds[start:end]
          slices = [
              arr[mbinds]
              for arr in (obs, returns, masks, actions, values, neglogpacs,
                          next_obs)
          ]
          with logger.ProfileKV('train'):
            fetches = model.train(lrnow, cliprangenow, slices[0], slices[6],
                                  slices[1], slices[2], slices[3], slices[4],
                                  slices[5], gather_histo=gather_histo,
                                  gather_sc=gather_sc,
                                  debug_timeliner=debug_timeliner)
          mblossvals.append(fetches['losses'])
          if gather_histo:
            mbhistos.append(fetches['stats_histo'])
          if gather_sc:
            mbscs.append(fetches['stats_sc'])
    else:  # recurrent version
      # Shuffle whole environments so each minibatch keeps contiguous
      # per-env time sequences (needed for recurrent state).
      assert nenvs % nminibatches == 0
      envsperbatch = nenvs // nminibatches
      envinds = np.arange(nenvs)
      flatinds = np.arange(nenvs * nsteps).reshape(nenvs, nsteps)
      # NOTE(review): overwrites the value above; equal when
      # nbatch % nminibatches == 0.
      envsperbatch = nbatch_train // nsteps
      for oe in range(noptepochs):
        gather_histo = (oe == noptepochs - 1)
        np.random.shuffle(envinds)
        for start in range(0, nenvs, envsperbatch):
          # NOTE(review): `start` indexes environments here (max ~nenvs),
          # while `nbatch_train`/`nbatch` count transitions — this looks
          # inconsistent with the nonrecurrent branch and may leave
          # gather_sc always False (breaking the `len(mbscs) == 1` assert
          # below) unless nsteps is small. Confirm intent.
          gather_sc = ((oe == noptepochs - 1) and
                       (start + nbatch_train >= nbatch))
          end = start + envsperbatch
          mbenvinds = envinds[start:end]
          mbflatinds = flatinds[mbenvinds].ravel()
          slices = [
              arr[mbflatinds]
              for arr in (obs, returns, masks, actions, values, neglogpacs,
                          next_obs)
          ]
          mbstates = states[mbenvinds]
          fetches = model.train(lrnow, cliprangenow, slices[0], slices[6],
                                slices[1], slices[2], slices[3], slices[4],
                                slices[5], mbstates,
                                gather_histo=gather_histo,
                                gather_sc=gather_sc,
                                debug_timeliner=debug_timeliner)
          mblossvals.append(fetches['losses'])
          if gather_histo:
            mbhistos.append(fetches['stats_histo'])
          if gather_sc:
            mbscs.append(fetches['stats_sc'])

    if debug_timeliner is not None:
      with logger.ProfileKV("save_timeline_json"):
        debug_timeliner.save(
            osp.join(workdir, 'timeline_{}.json'.format(update)))

    # Mean of each loss component across all minibatch steps this update.
    lossvals = np.mean(mblossvals, axis=0)
    # Scalars must have been gathered exactly once (last minibatch of the
    # last epoch).
    assert len(mbscs) == 1
    scalars = mbscs[0]
    # Concatenate per-minibatch histogram arrays along the batch axis.
    histograms = {
        n: np.concatenate([f[n] for f in mbhistos], axis=0)
        for n in model.stats_histo_names
    }
    logger.info('Histograms: {}'.format([(n, histograms[n].shape)
                                         for n in histograms.keys()]))
    #for v in histograms.values():
    #  assert len(v) == nbatch

    tnow = time.time()
    fps = int(nbatch / (tnow - tstart))
    if update % log_interval == 0 or update == 1:
      fps_total = int((update * nbatch) / (tnow - tfirststart))
      #tf_op_names = [i.name for i in tf.get_default_graph().get_operations()]
      #logger.info('#################### tf_op_names: {}'.format(tf_op_names))
      # Graph-growth canary: op count should stay constant across updates.
      tf_num_ops = len(tf.get_default_graph().get_operations())
      logger.info(
          '#################### tf_num_ops: {}'.format(tf_num_ops))
      logger.logkv('tf_num_ops', tf_num_ops)
      ev = explained_variance(values, returns)
      logger.logkv('serial_timesteps', update * nsteps)
      logger.logkv('nupdates', update)
      logger.logkv('total_timesteps', update * nbatch)
      logger.logkv('fps', fps)
      logger.logkv('fps_total', fps_total)
      logger.logkv(
          'remaining_time',
          float(tnow - tfirststart) / float(update) *
          float(nupdates - update))
      logger.logkv('explained_variance', float(ev))
      logger.logkv('eprewmean',
                   safemean([epinfo['r'] for epinfo in epinfobuf]))
      logger.logkv('eplenmean',
                   safemean([epinfo['l'] for epinfo in epinfobuf]))
      if train_callback:
        train_callback(safemean([epinfo['l'] for epinfo in epinfobuf]),
                       safemean([epinfo['r'] for epinfo in epinfobuf]),
                       update * nbatch)
      logger.logkv('time_elapsed', tnow - tfirststart)
      for (lossval, lossname) in zip(lossvals, model.loss_names):
        logger.logkv(lossname, lossval)
      for n, v in scalars.items():
        logger.logkv(n, v)
      for n, v in histograms.items():
        logger.logkv(n, v)
        logger.logkv('mean_' + n, np.mean(v))
        logger.logkv('std_' + n, np.std(v))
        logger.logkv('max_' + n, np.max(v))
        logger.logkv('min_' + n, np.min(v))
      # Log reward-component summaries by picking them out of locals().
      for n, v in locals().items():
        if n in ['rewards_int', 'rewards_int_raw']:
          logger.logkv(n, v)
        if n in [
            'rewards', 'rewards_ext', 'rewards_int', 'rewards_int_raw'
        ]:
          logger.logkv('mean_' + n, np.mean(v))
          logger.logkv('std_' + n, np.std(v))
          logger.logkv('max_' + n, np.max(v))
          logger.logkv('min_' + n, np.min(v))
      if model.rlb_model:
        if model.all_rlb_args.outer_args['rlb_normalize_ir']:
          # Running statistics of the intrinsic-reward normalizer.
          logger.logkv('rlb_ir_running_mean', runner.irff_rms.mean)
          logger.logkv('rlb_ir_running_std', np.sqrt(runner.irff_rms.var))
      logger.dumpkvs()
    if (save_interval and (update % save_interval == 0 or update == 1) and
        workdir):
      # Model-only checkpoint.
      checkdir = osp.join(workdir, 'checkpoints')
      if not tf.gfile.Exists(checkdir):
        tf.gfile.MakeDirs(checkdir)
      savepath = osp.join(checkdir, '%.5i' % update)
      print('Saving to', savepath)
      model.save(savepath)
      # Full TF-session checkpoint (includes optimizer state etc.).
      checkdir = osp.join(workdir, 'full_checkpoints')
      if not tf.gfile.Exists(checkdir):
        tf.gfile.MakeDirs(checkdir)
      savepath = osp.join(checkdir, '%.5i' % update)
      print('Saving to', savepath)
      save_state(savepath)
    if (cloud_sync_interval and update % cloud_sync_interval == 0 and
        cloud_sync_callback):
      cloud_sync_callback()
  env.close()
  return model
def train(self, batch_gen, steps_per_epoch, num_epochs):
  """Runs `num_epochs` x `steps_per_epoch` training steps and logs stats.

  Pulls (obs, obs_next, acs) batches from `batch_gen` and feeds them to
  `self._train`. Histogram statistics are gathered on every step of the
  final epoch; scalar (and additional scalar) statistics only on the very
  last step. Afterwards, mean losses, scalars, and per-histogram summary
  statistics are written to the logger.
  """
  loss_acc = []
  histo_acc = []
  sc_acc = []
  asc_acc = []
  for epoch in range(num_epochs):
    final_epoch = epoch == num_epochs - 1
    gather_histo = final_epoch
    for step in range(steps_per_epoch):
      gather_sc = final_epoch and step == steps_per_epoch - 1
      obs, obs_next, acs = next(batch_gen)
      with logger.ProfileKV('train_ot_inner'):
        fetches = self._train(
            obs, obs_next, acs,
            gather_histo=gather_histo,
            gather_sc=gather_sc)
      loss_acc.append(fetches['losses'])
      if gather_histo:
        histo_acc.append(fetches['stats_histo'])
      if gather_sc:
        sc_acc.append(fetches['stats_sc'])
        asc_acc.append(fetches['additional_sc'])

  # Average each loss component over every step taken above.
  lossvals = np.mean(loss_acc, axis=0)
  # Scalars must have been gathered exactly once (last step of last epoch).
  assert len(sc_acc) == 1
  assert len(asc_acc) == 1
  scalars = sc_acc[0]
  additional_scalars = asc_acc[0]
  # Stitch the final epoch's per-step histogram arrays back together.
  histograms = {
      name: np.concatenate([f[name] for f in histo_acc], axis=0)
      for name in self._stats_histo_names
  }
  logger.info('RLBModelWrapper.train histograms: {}'.format(
      [(n, histograms[n].shape) for n in histograms.keys()]))

  for lossname, lossval in zip(self._loss_names, lossvals):
    logger.logkv(lossname, lossval)
  for name, value in scalars.items():
    logger.logkv(name, value)
  for name, value in additional_scalars.items():
    logger.logkv(name, value)
  for name, value in histograms.items():
    logger.logkv(name, value)
    logger.logkv('mean_' + name, np.mean(value))
    logger.logkv('std_' + name, np.std(value))
    logger.logkv('max_' + name, np.max(value))
    logger.logkv('min_' + name, np.min(value))