def train(self):
    """Run the off-policy training loop with a freshly created replay memory.

    For each of ``self.n_epochs`` epochs, ``self.epoch_length`` environment
    steps are taken using actions from ``self.strategy``.  Transitions are
    stored in a ``ReplayMem`` and, once it holds at least
    ``self.memory_start_size`` samples, every step is followed by
    ``self.n_updates_per_sample`` calls to ``self.do_update``.  At the end of
    an epoch ``self.evaluate`` is called (same size condition) and the
    tabular log is dumped.
    """
    replay = ReplayMem(
        obs_dim=self.env.observation_space.flat_dim,
        act_dim=self.env.action_space.flat_dim,
        memory_size=self.memory_size)

    itr = 0
    path_length = 0
    path_return = 0
    end = False
    obs = self.env.reset()

    for epoch in xrange(self.n_epochs):
        logger.push_prefix("epoch #%d | " % epoch)
        logger.log("Training started")

        for _step in pyprind.prog_bar(range(self.epoch_length)):
            if end:
                # The previous step finished an episode: start a new one and
                # book-keep the finished episode's return.
                obs = self.env.reset()
                self.strategy.reset()
                self.strategy_path_returns.append(path_return)
                path_length = 0
                path_return = 0

            # The action comes from the exploration strategy wrapped around
            # the current policy (not the target policy).
            act = self.strategy.get_action(obs, self.policy)
            nxt, rwd, end, _ = self.env.step(act)
            path_length += 1
            path_return += rwd

            # An episode cut off by the horizon is flagged as ended; such a
            # transition is only stored when include_horizon_terminal is set.
            horizon_hit = not end and path_length >= self.max_path_length
            if horizon_hit:
                end = True
            if not horizon_hit or self.include_horizon_terminal:
                replay.add_sample(obs, act, rwd, end)

            obs = nxt

            if replay.size >= self.memory_start_size:
                for _update in xrange(self.n_updates_per_sample):
                    self.do_update(itr, replay.get_batch(self.batch_size))

            itr += 1

        logger.log("Training finished")
        if replay.size >= self.memory_start_size:
            self.evaluate(epoch, replay)
        logger.dump_tabular(with_prefix=False)
        logger.pop_prefix()
def train(self):
    """Optimise the policy parameters with the cross-entropy method.

    Each iteration samples parameter vectors around the current mean,
    collects rollouts through ``stateful_pool.singleton_pool.run_collect``,
    keeps the ``n_best`` samples with the highest discounted return and
    refits the mean / standard deviation to them.  The best sample of the
    iteration is written into ``self.policy`` and a snapshot is saved.
    """
    parallel_sampler.populate_task(self.env, self.policy)
    if self.plot:
        plotter.init_plot(self.env, self.policy)

    cur_std = self.init_std
    cur_mean = self.policy.get_param_values()
    # K = cur_mean.size
    n_best = max(1, int(self.n_samples * self.best_frac))

    for itr in range(self.n_itr):
        # sample around the current distribution; the extra variance term
        # decays linearly to zero over ``extra_decay_time`` iterations
        extra_var_mult = max(1.0 - itr / self.extra_decay_time, 0)
        sample_std = np.sqrt(
            np.square(cur_std) + np.square(self.extra_std) * extra_var_mult)

        # Collect either a number of paths or a number of samples.
        if self.batch_size is None:
            criterion = 'paths'
            threshold = self.n_samples
        else:
            criterion = 'samples'
            threshold = self.batch_size

        infos = stateful_pool.singleton_pool.run_collect(
            _worker_rollout_policy,
            threshold=threshold,
            args=(dict(cur_mean=cur_mean,
                       sample_std=sample_std,
                       max_path_length=self.max_path_length,
                       discount=self.discount,
                       criterion=criterion),)
        )
        # Each info is indexed as (sampled parameters, path).
        xs = np.asarray([info[0] for info in infos])
        paths = [info[1] for info in infos]

        # Fitness of a sample = first entry of the path's 'returns'.
        fs = np.array([path['returns'][0] for path in paths])
        print((xs.shape, fs.shape))

        # Elite set: indices sorted by descending fitness.
        best_inds = (-fs).argsort()[:n_best]
        best_xs = xs[best_inds]
        cur_mean = best_xs.mean(axis=0)
        cur_std = best_xs.std(axis=0)
        best_x = best_xs[0]

        logger.push_prefix('itr #%d | ' % itr)
        logger.record_tabular('Iteration', itr)
        logger.record_tabular('CurStdMean', np.mean(cur_std))
        undiscounted_returns = np.array(
            [path['undiscounted_return'] for path in paths])
        logger.record_tabular('AverageReturn',
                              np.mean(undiscounted_returns))
        # Fixed: this column previously repeated the mean instead of
        # reporting the standard deviation.
        logger.record_tabular('StdReturn',
                              np.std(undiscounted_returns))
        logger.record_tabular('MaxReturn',
                              np.max(undiscounted_returns))
        logger.record_tabular('MinReturn',
                              np.min(undiscounted_returns))
        logger.record_tabular('AverageDiscountedReturn',
                              np.mean(fs))
        logger.record_tabular('AvgTrajLen',
                              np.mean([len(path['returns'])
                                       for path in paths]))
        logger.record_tabular('NumTrajs', len(paths))

        self.policy.set_param_values(best_x)
        self.env.log_diagnostics(paths)
        self.policy.log_diagnostics(paths)
        logger.save_itr_params(itr, dict(
            itr=itr,
            policy=self.policy,
            env=self.env,
            cur_mean=cur_mean,
            cur_std=cur_std,
        ))
        logger.dump_tabular(with_prefix=False)
        logger.pop_prefix()
        if self.plot:
            plotter.update_plot(self.policy, self.max_path_length)

    parallel_sampler.terminate_task()
def _train(self, env, policy, pool):
    """Perform RL training.

    Interleaves environment sampling and training steps for
    ``self._n_epochs + 1`` epochs of ``self._epoch_length`` steps each,
    evaluating, snapshotting and logging timing/return statistics after
    every epoch.

    Args:
        env (`rllab.Env`): Environment used for training
        policy (`Policy`): Policy used for training
        pool (`PoolBase`): Sample pool to add samples to
    """
    self._init_training(env, policy, pool)

    with self._sess.as_default():
        observation = env.reset()
        policy.reset()

        # Running statistics of the episode in progress and of past episodes.
        path_length = 0
        path_return = 0
        last_path_return = 0
        max_path_return = -np.inf
        n_episodes = 0

        # Timer setup; the stamps below ('sample', 'train', 'eval') are read
        # back per iteration when the epoch is logged.
        gt.rename_root('RLAlgorithm')
        gt.reset()
        gt.set_def_unique(False)

        for epoch in gt.timed_for(
                range(self._n_epochs + 1), save_itrs=True):
            logger.push_prefix('Epoch #%d | ' % epoch)

            # NOTE: the callback receives this function's locals(), so the
            # local variable names here are visible to it.
            if self.iter_callback is not None:
                self.iter_callback(locals(), globals())

            for t in range(self._epoch_length):
                # Global step counter across epochs.
                iteration = t + epoch * self._epoch_length

                action, _ = policy.get_action(observation)
                next_ob, reward, terminal, info = env.step(action)
                path_length += 1
                path_return += reward

                # NOTE(review): samples go to self.pool rather than the
                # ``pool`` argument - presumably _init_training stores it
                # there; confirm.
                self.pool.add_sample(
                    observation,
                    action,
                    reward,
                    terminal,
                    next_ob,
                )

                if terminal or path_length >= self._max_path_length:
                    # Episode over (terminal state or horizon reached):
                    # reset and update the return statistics.
                    observation = env.reset()
                    policy.reset()
                    path_length = 0
                    max_path_return = max(max_path_return, path_return)
                    last_path_return = path_return

                    path_return = 0
                    n_episodes += 1
                else:
                    observation = next_ob
                gt.stamp('sample')

                # Train only once the pool holds enough samples.
                if self.pool.size >= self._min_pool_size:
                    for i in range(self._n_train_repeat):
                        batch = self.pool.random_batch(self._batch_size)
                        self._do_training(iteration, batch)

                gt.stamp('train')

            self._evaluate(epoch)

            params = self.get_snapshot(epoch)
            logger.save_itr_params(epoch, params)
            times_itrs = gt.get_times().stamps.itrs

            # The 'eval' stamp is taken at the very end of the epoch body, so
            # the latest recorded value belongs to an earlier epoch; 0 is
            # reported while epoch <= 1.
            eval_time = times_itrs['eval'][-1] if epoch > 1 else 0
            total_time = gt.get_times().total
            logger.record_tabular('time-train', times_itrs['train'][-1])
            logger.record_tabular('time-eval', eval_time)
            logger.record_tabular('time-sample', times_itrs['sample'][-1])
            logger.record_tabular('time-total', total_time)
            logger.record_tabular('epoch', epoch)
            logger.record_tabular('episodes', n_episodes)
            logger.record_tabular('max-path-return', max_path_return)
            logger.record_tabular('last-path-return', last_path_return)
            logger.record_tabular('pool-size', self.pool.size)

            logger.dump_tabular(with_prefix=False)
            logger.pop_prefix()

            gt.stamp('eval')

        env.terminate()
# Build a TRPO learner from env / policy / baseline objects defined outside
# this fragment, train it, then detach the logger outputs and prefix.
algo = TRPO(
    env=env,
    policy=policy,
    baseline=baseline,
    # batch_size=4000,
    batch_size=1000,
    max_path_length=env.horizon,
    n_itr=500,
    discount=0.99,
    step_size=0.0025,
    # Uncomment both lines (this and the plot parameter below) to enable plotting
    # plot=True,
)
algo.train()

# Logger teardown: remove the tabular / text outputs and pop one prefix.
# NOTE(review): presumably mirrors setup done earlier in the script - confirm.
logger.remove_tabular_output(tabular_log_file)
logger.remove_text_output(text_log_file)
logger.pop_prefix()

# Alternative launcher, currently disabled:
# run_experiment_lite(
#     algo.train(),
#     # Number of parallel workers for sampling
#     n_parallel=1,
#     # Only keep the snapshot parameters for the last iteration
#     snapshot_mode="last",
#     # Specifies the seed for the experiment. If this is not provided, a random seed
#     # will be used
#     seed=1,
#     # plot=True,
# )
def train(self):
    """Run the off-policy training loop with a new ``SimpleReplayPool``.

    Every environment step uses an action from the exploration strategy
    ``self.es`` applied to a pickled copy of the policy.  Flattened
    observations/actions and the scaled reward are stored in the pool; once
    the pool holds ``self.min_pool_size`` samples each step is followed by
    ``self.n_updates_per_sample`` training calls and the sampling policy is
    synchronised with ``self.policy``.  After each epoch the algorithm is
    evaluated, snapshotted and, optionally, plotted.
    """
    replay_pool = SimpleReplayPool(
        max_pool_size=self.replay_pool_size,
        observation_dim=self.env.observation_space.flat_dim,
        action_dim=self.env.action_space.flat_dim,
    )
    self.start_worker()
    self.init_opt()

    itr = 0
    path_length = 0
    path_return = 0
    terminal = False
    observation = self.env.reset()

    # Independent copy of the policy used only for acting.
    sample_policy = pickle.loads(pickle.dumps(self.policy))

    for epoch in xrange(self.n_epochs):
        logger.push_prefix('epoch #%d | ' % epoch)
        logger.log("Training started")

        for _step in pyprind.prog_bar(xrange(self.epoch_length)):
            if terminal:
                # Previous step ended an episode: reset everything.
                observation = self.env.reset()
                self.es.reset()
                sample_policy.reset()
                self.es_path_returns.append(path_return)
                path_length = 0
                path_return = 0

            action = self.es.get_action(
                itr, observation, policy=sample_policy)
            # This environment's step() also receives the current observation.
            next_observation, reward, terminal, _ = self.env.step(
                action, observation)
            path_length += 1
            path_return += reward

            # Horizon cut-offs are marked terminal and stored only when
            # include_horizon_terminal_transitions is set.
            horizon_reached = (
                not terminal and path_length >= self.max_path_length)
            if horizon_reached:
                terminal = True
            if (not horizon_reached
                    or self.include_horizon_terminal_transitions):
                flat_obs = self.env.observation_space.flatten(observation)
                flat_act = self.env.action_space.flatten(action)
                replay_pool.add_sample(
                    flat_obs, flat_act,
                    reward * self.scale_reward, terminal)

            observation = next_observation

            if replay_pool.size >= self.min_pool_size:
                for _update in xrange(self.n_updates_per_sample):
                    self.do_training(
                        itr, replay_pool.random_batch(self.batch_size))
                sample_policy.set_param_values(
                    self.policy.get_param_values())

            itr += 1

        # NOTE(review): assumed to run once per epoch; the flattened source
        # does not show whether this sat inside the step loop - confirm.
        self.pool = replay_pool

        logger.log("Training finished")
        if replay_pool.size >= self.min_pool_size:
            self.evaluate(epoch, replay_pool)
            snapshot = self.get_epoch_snapshot(epoch)
            logger.save_itr_params(epoch, snapshot)
        logger.dump_tabular(with_prefix=False)
        logger.pop_prefix()

        if self.plot:
            self.update_plot()
            if self.pause_for_plot:
                raw_input("Plotting evaluation run: Press Enter to "
                          "continue...")

    self.env.terminate()
    self.policy.terminate()
def train(self):
    """Optimise the policy parameters with CMA-ES.

    Repeatedly asks the ``CMAEvolutionStrategy`` for candidate parameter
    vectors, evaluates each with a rollout (``sample_return``) on the worker
    pool, and tells the strategy the negated discounted returns (CMA-ES
    minimises).  Stops after ``self.n_itr`` iterations or when ``es.stop()``
    becomes truthy, then writes ``es.result()[0]`` into the policy.
    """
    cur_std = self.sigma0
    cur_mean = self.policy.get_param_values()
    es = cma_es_lib.CMAEvolutionStrategy(cur_mean, cur_std)

    parallel_sampler.populate_task(self.env, self.policy)
    if self.plot:
        plotter.init_plot(self.env, self.policy)

    itr = 0
    while itr < self.n_itr and not es.stop():

        if self.batch_size is None:
            # Sample from multivariate normal distribution.
            xs = np.asarray(es.ask())
            # For each sample, do a rollout.
            infos = stateful_pool.singleton_pool.run_map(
                sample_return,
                [(x, self.max_path_length, self.discount) for x in xs])
        else:
            # Keep asking for candidates until the rollouts add up to at
            # least ``batch_size`` time steps.
            cum_len = 0
            infos = []
            xss = []
            done = False
            while not done:
                sbs = stateful_pool.singleton_pool.n_parallel * 2
                # Sample from multivariate normal distribution.
                # You want to ask for sbs samples here.
                xs = np.asarray(es.ask(sbs))
                xss.append(xs)

                sinfos = stateful_pool.singleton_pool.run_map(
                    sample_return,
                    [(x, self.max_path_length, self.discount) for x in xs])
                for info in sinfos:
                    infos.append(info)
                    cum_len += len(info['returns'])
                    if cum_len >= self.batch_size:
                        xs = np.concatenate(xss)
                        done = True
                        break

        # Evaluate fitness of samples (negative as it is minimization
        # problem).
        fs = - np.array([info['returns'][0] for info in infos])
        # When batching, you could have generated too many samples compared
        # to the actual evaluations. So we cut it off in this case.
        xs = xs[:len(fs)]
        # Update CMA-ES params based on sample fitness.
        es.tell(xs, fs)

        logger.push_prefix('itr #%d | ' % itr)
        logger.record_tabular('Iteration', itr)
        # NOTE(review): cur_std is never updated in this method, so this
        # column always reports sigma0.
        logger.record_tabular('CurStdMean', np.mean(cur_std))
        undiscounted_returns = np.array(
            [info['undiscounted_return'] for info in infos])
        logger.record_tabular('AverageReturn',
                              np.mean(undiscounted_returns))
        # Fixed: previously logged the mean under the 'StdReturn' key.
        logger.record_tabular('StdReturn',
                              np.std(undiscounted_returns))
        logger.record_tabular('MaxReturn',
                              np.max(undiscounted_returns))
        logger.record_tabular('MinReturn',
                              np.min(undiscounted_returns))
        # Fixed: ``fs`` holds *negated* returns, so undo the negation to
        # report the actual average discounted return.
        logger.record_tabular('AverageDiscountedReturn',
                              np.mean(-fs))
        logger.record_tabular('AvgTrajLen',
                              np.mean([len(info['returns'])
                                       for info in infos]))
        self.env.log_diagnostics(infos)
        self.policy.log_diagnostics(infos)

        logger.save_itr_params(itr, dict(
            itr=itr,
            policy=self.policy,
            env=self.env,
        ))
        logger.dump_tabular(with_prefix=False)
        if self.plot:
            plotter.update_plot(self.policy, self.max_path_length)
        logger.pop_prefix()
        # Update iteration.
        itr += 1

    # Set final params.
    self.policy.set_param_values(es.result()[0])
    parallel_sampler.terminate_task()
def close_logger(log_dir):
    """Detach the logger outputs located in ``log_dir`` and pop one prefix.

    Removes the tabular output ``progress.csv`` and the text output
    ``debug.log`` (both resolved relative to ``log_dir``) from the logger,
    then pops the most recently pushed log prefix.
    """
    tabular_path = osp.join(log_dir, 'progress.csv')
    logger.remove_tabular_output(tabular_path)

    text_path = osp.join(log_dir, 'debug.log')
    logger.remove_text_output(text_path)

    logger.pop_prefix()
def train(self):
    """Run the off-policy training loop inside a new TensorFlow session.

    Creates a ``SimpleReplayPool``, initialises the optimiser and a cloned
    sampling policy, then alternates environment steps (exploration via
    ``self.es``) with ``self.do_training`` updates once the pool holds
    ``self.min_pool_size`` samples.  Each stored sample carries an
    ``initial`` flag that is true when the step followed an environment
    reset.  After each epoch the algorithm is evaluated, snapshotted and,
    optionally, plotted.
    """
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        replay_pool = SimpleReplayPool(
            max_pool_size=self.replay_pool_size,
            observation_dim=self.env.observation_space.flat_dim,
            action_dim=self.env.action_space.flat_dim,
            replacement_prob=self.replacement_prob,
        )
        self.start_worker()
        self.init_opt()
        # Run the initialiser again after init_opt (original note: "This
        # initializes the optimizer parameters").
        sess.run(tf.global_variables_initializer())

        itr = 0
        path_length = 0
        path_return = 0
        terminal = False
        initial = False
        observation = self.env.reset()

        with tf.variable_scope("sample_policy"):
            sample_policy = Serializable.clone(self.policy)

        for epoch in range(self.n_epochs):
            logger.push_prefix('epoch #%d | ' % epoch)
            logger.log("Training started")
            train_qf_itr = 0
            train_policy_itr = 0

            for _step in pyprind.prog_bar(range(self.epoch_length)):
                # A step is "initial" exactly when the previous one ended
                # an episode, in which case everything is reset first.
                initial = True if terminal else False
                if initial:
                    observation = self.env.reset()
                    self.es.reset()
                    sample_policy.reset()
                    self.es_path_returns.append(path_return)
                    path_length = 0
                    path_return = 0

                action = self.es.get_action(
                    itr, observation, policy=sample_policy)
                next_observation, reward, terminal, _ = self.env.step(
                    action)
                path_length += 1
                path_return += reward

                # Horizon cut-offs are marked terminal and stored only when
                # include_horizon_terminal_transitions is set.
                horizon_reached = (
                    not terminal and path_length >= self.max_path_length)
                if horizon_reached:
                    terminal = True
                if (not horizon_reached
                        or self.include_horizon_terminal_transitions):
                    replay_pool.add_sample(
                        observation, action,
                        reward * self.scale_reward, terminal, initial)

                observation = next_observation

                if replay_pool.size >= self.min_pool_size:
                    for _update in range(self.n_updates_per_sample):
                        minibatch = replay_pool.random_batch(
                            self.batch_size)
                        # do_training reports how many (qf, policy) steps
                        # it performed.
                        step_counts = self.do_training(
                            itr, epoch, minibatch)
                        train_qf_itr += step_counts[0]
                        train_policy_itr += step_counts[1]
                    sample_policy.set_param_values(
                        self.policy.get_param_values())

                itr += 1

            logger.log("Training finished")
            logger.log("Trained qf %d steps, policy %d steps" %
                       (train_qf_itr, train_policy_itr))
            if replay_pool.size >= self.min_pool_size:
                self.evaluate(epoch, replay_pool)
                snapshot = self.get_epoch_snapshot(epoch)
                logger.save_itr_params(epoch, snapshot)
            logger.dump_tabular(with_prefix=False)
            logger.pop_prefix()

            if self.plot:
                self.update_plot()
                if self.pause_for_plot:
                    input("Plotting evaluation run: Press Enter to "
                          "continue...")

        self.env.terminate()
        self.policy.terminate()
def run_experiment(argv):
    """Command-line entry point that runs one pickled experiment.

    Parses ``argv[1:]``, initialises the parallel sampler (and optionally the
    seed and the plotter), decodes the experiment description from
    ``--args_data``, configures the logger outputs under ``--log_dir`` and
    executes the experiment.  The logger's previous snapshot settings are
    restored and the added outputs / prefix removed afterwards, including
    when the experiment raises.

    Args:
        argv: Full argument vector; the first element (program name) is
            skipped.
    """
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_parallel', type=int, default=1,
                        help='Number of parallel workers to perform rollouts.')
    parser.add_argument('--exp_name', type=str, default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--log_dir', type=str, default=default_log_dir,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                             '(all iterations will be saved), "last" (only '
                             'the last iteration will be saved), or "none" '
                             '(do not save snapshots)')
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--plot', type=ast.literal_eval, default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only', type=ast.literal_eval, default=False,
        help='Whether to only print the tabular log information (in a horizontal format)'
    )
    parser.add_argument('--seed', type=int,
                        help='Random seed for numpy')
    parser.add_argument('--args_data', type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--use_cloudpickle', type=bool,
                        help='NOT USED')

    args = parser.parse_args(argv[1:])

    from sandbox.vime.sampler import parallel_sampler_expl as parallel_sampler
    parallel_sampler.initialize(n_parallel=args.n_parallel)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    # The experiment description arrives base64-encoded in --args_data.
    # SECURITY: unpickling executes arbitrary code; only pass --args_data
    # produced by a trusted launcher.
    data = pickle.loads(base64.b64decode(args.args_data))

    log_dir = args.log_dir
    # exp_dir = osp.join(log_dir, args.exp_name)
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        # Concretising the stub may yield an iterable; if so, exhaust it.
        maybe_iter = concretize(data)
        if is_iterable(maybe_iter):
            for _ in maybe_iter:
                pass
    finally:
        # Undo the logger configuration even if the experiment failed, so
        # outputs are detached and the previous snapshot settings return.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
def _train(self, env, policy, pool):
    """Train a skill-conditioned policy on augmented observations.

    The policy acts on the environment observation concatenated with the
    current skill ``z`` (``utils.concat_obs_z``).  A skill is sampled once
    per epoch.  When ``self._learn_p_z`` is set, the discriminator's
    log-probability of the active skill is collected per step and used at
    the end of each epoch to re-estimate ``self._p_z``.

    Args:
        env: Environment used for training.
        policy: Policy used for training; receives augmented observations.
        pool: Sample pool, forwarded to ``self._init_training``.
    """
    self._init_training(env, policy, pool)

    with self._sess.as_default():
        observation = env.reset()
        policy.reset()

        path_length = 0
        path_return = 0
        last_path_return = 0
        max_path_return = -np.inf
        n_episodes = 0

        if self._learn_p_z:
            # One bounded history of log p(z | s, a) per skill.
            log_p_z_list = [deque(maxlen=self._max_path_length)
                            for _ in range(self._num_skills)]

        gt.rename_root('RLAlgorithm')
        gt.reset()
        gt.set_def_unique(False)

        for epoch in gt.timed_for(range(self._n_epochs + 1),
                                  save_itrs=True):
            logger.push_prefix('Epoch #%d | ' % epoch)

            path_length_list = []
            z = self._sample_z()
            aug_obs = utils.concat_obs_z(observation, z, self._num_skills)

            for t in range(self._epoch_length):
                iteration = t + epoch * self._epoch_length

                action, _ = policy.get_action(aug_obs)

                if self._learn_p_z:
                    # Score the (observation, action) pair with the
                    # discriminator and keep log p(z) for the active skill.
                    (obs, _) = utils.split_aug_obs(aug_obs,
                                                   self._num_skills)
                    feed_dict = {
                        self._discriminator._obs_pl: obs[None],
                        self._discriminator._action_pl: action[None],
                    }
                    logits = tf_utils.get_default_session().run(
                        self._discriminator._output_t, feed_dict)[0]
                    log_p_z = np.log(utils._softmax(logits)[z])
                    log_p_z_list[z].append(log_p_z)

                next_ob, reward, terminal, info = env.step(action)
                aug_next_ob = utils.concat_obs_z(next_ob, z,
                                                 self._num_skills)
                path_length += 1
                path_return += reward

                self._pool.add_sample(
                    aug_obs,
                    action,
                    reward,
                    terminal,
                    aug_next_ob,
                )

                if terminal or path_length >= self._max_path_length:
                    path_length_list.append(path_length)
                    observation = env.reset()
                    policy.reset()
                    # Fixed: rebuild the augmented observation from the
                    # reset state.  Previously ``aug_obs`` kept the stale
                    # pre-reset value, so the next action and the next
                    # stored transition used an observation that no longer
                    # matched the environment.
                    aug_obs = utils.concat_obs_z(observation, z,
                                                 self._num_skills)
                    path_length = 0
                    max_path_return = max(max_path_return, path_return)
                    last_path_return = path_return

                    path_return = 0
                    n_episodes += 1
                else:
                    aug_obs = aug_next_ob
                gt.stamp('sample')

                if self._pool.size >= self._min_pool_size:
                    for i in range(self._n_train_repeat):
                        batch = self._pool.random_batch(self._batch_size)
                        self._do_training(iteration, batch)

                gt.stamp('train')

            if self._learn_p_z:
                print('learning p(z)')
                for skill in range(self._num_skills):
                    history = log_p_z_list[skill]
                    if history:
                        print('\t skill = %d, min=%.2f, max=%.2f, mean=%.2f, len=%d' % (
                            skill, np.min(history), np.max(history),
                            np.mean(history), len(history)))
                # Skills without any data fall back to the uniform prior.
                log_p_z = [
                    np.mean(history) if history
                    else np.log(1.0 / self._num_skills)
                    for history in log_p_z_list
                ]
                print('log_p_z: %s' % log_p_z)
                self._p_z = utils._softmax(log_p_z)

            self._evaluate(epoch)

            params = self.get_snapshot(epoch)
            logger.save_itr_params(epoch, params)
            times_itrs = gt.get_times().stamps.itrs

            eval_time = times_itrs['eval'][-1] if epoch > 1 else 0
            total_time = gt.get_times().total
            logger.record_tabular('time-train', times_itrs['train'][-1])
            logger.record_tabular('time-eval', eval_time)
            logger.record_tabular('time-sample', times_itrs['sample'][-1])
            logger.record_tabular('time-total', total_time)
            logger.record_tabular('epoch', epoch)
            logger.record_tabular('episodes', n_episodes)
            logger.record_tabular('max-path-return', max_path_return)
            logger.record_tabular('last-path-return', last_path_return)
            logger.record_tabular('pool-size', self._pool.size)
            # No episode may have finished this epoch; report NaN directly
            # rather than calling np.mean on an empty list.
            mean_path_length = (np.mean(path_length_list)
                                if path_length_list else np.nan)
            logger.record_tabular('path-length', mean_path_length)

            logger.dump_tabular(with_prefix=False)
            logger.pop_prefix()

            gt.stamp('eval')

        env.terminate()
logger.log( 'WARNING: Training with uniform sampling. Testing is done BEFORE the optimization !!!!' ) algo.train_brownian_with_uniform() else: algo.train_brownian() elif params['use_hide_alg'] == 2: if train_mode == 0: algo.train_brownian_with_goals() elif train_mode == 1: algo.train_brownian_reverse_repeat() elif train_mode == 2: algo.train_brownian_multiseed() elif train_mode == 3: algo.train_brownian_multiseed_swap_every_update_period() else: raise NotImplementedError else: #my version of the alg algo.train_hide_seek() else: algo.train_seek() logger.log('Experiment finished ...') logger.set_snapshot_mode(prev_mode) logger.set_snapshot_dir(prev_snapshot_dir) logger.remove_tabular_output(tabular_log_file_fullpath) logger.remove_text_output(text_log_file_fullpath) logger.pop_prefix() print('Tabular log file:', tabular_log_file_fullpath)
def train(self):
    """Run batch policy optimisation with a BNN dynamics model.

    Builds ``self.bnn`` (and, if ``self.use_replay_pool``, ``self.pool``),
    then for every iteration in ``[self.start_itr, self.n_itr)``: collects
    and processes samples, adds them to the replay pool and fits the BNN on
    normalised (observation, action) -> next observation batches, logs the
    squared loss before/after fitting, optimises the policy and saves a
    snapshot that also carries the accumulated episode rewards and lengths.
    """

    def _mean_sq_loss(inputs_list, targets_list):
        # Average over batches of the mean squared prediction error.
        total = 0.
        for batch_in, batch_out in zip(inputs_list, targets_list):
            prediction = self.bnn.pred_fn(batch_in)
            total += np.mean(np.square(prediction - batch_out))
        total /= len(inputs_list)
        return total

    # Bayesian neural network (BNN) initialization.
    # ------------------------------------------------
    batch_size = 1  # Redundant
    n_batches = 5  # Hardcode or annealing scheme \pi_i.

    # MDP observation and action dimensions.
    obs_dim = np.prod(self.env.observation_space.shape)
    act_dim = np.prod(self.env.action_space.shape)

    logger.log("Building BNN model (eta={}) ...".format(self.eta))
    start_time = time.time()

    self.bnn = bnn.BNN(
        n_in=(obs_dim + act_dim),
        n_hidden=self.unn_n_hidden,
        n_out=obs_dim,
        n_batches=n_batches,
        layers_type=self.unn_layers_type,
        trans_func=lasagne.nonlinearities.rectify,
        out_func=lasagne.nonlinearities.linear,
        batch_size=batch_size,
        n_samples=self.snn_n_samples,
        prior_sd=self.prior_sd,
        use_reverse_kl_reg=self.use_reverse_kl_reg,
        reverse_kl_reg_factor=self.reverse_kl_reg_factor,
        second_order_update=self.second_order_update,
        learning_rate=self.unn_learning_rate,
        compression=self.compression,
        information_gain=self.information_gain
    )

    build_seconds = time.time() - start_time
    logger.log("Model built ({:.1f} sec).".format(build_seconds))

    if self.use_replay_pool:
        self.pool = SimpleReplayPool(
            max_pool_size=self.replay_pool_size,
            observation_shape=self.env.observation_space.shape,
            action_dim=act_dim
        )
    # ------------------------------------------------

    self.start_worker()
    self.init_opt()
    episode_rewards = []
    episode_lengths = []

    for itr in xrange(self.start_itr, self.n_itr):
        logger.push_prefix('itr #%d | ' % itr)

        paths = self.obtain_samples(itr)
        samples_data = self.process_samples(itr, paths)

        # Exploration code
        # ----------------
        if self.use_replay_pool:
            # Fill replay pool; the last step of each path is terminal.
            logger.log("Fitting dynamics model using replay pool ...")
            for path in samples_data['paths']:
                path_len = len(path['rewards'])
                final_step = path_len - 1
                for step in xrange(path_len):
                    self.pool.add_sample(
                        path['observations'][step],
                        path['actions'][step],
                        path['rewards'][step],
                        (step == final_step))

            # The dynamics model is only trained once the pool is large
            # enough.
            if self.pool.size >= self.min_pool_size:
                obs_mean, obs_std, act_mean, act_std = \
                    self.pool.mean_obs_act()

                # Draw all batches up front so the before/after losses are
                # measured on the same data the model is trained on.
                model_inputs = []
                model_targets = []
                for _ in xrange(self.n_updates_per_sample):
                    batch = self.pool.random_batch(self.pool_batch_size)
                    norm_obs = (batch['observations'] - obs_mean) / \
                        (obs_std + 1e-8)
                    norm_next_obs = (
                        batch['next_observations'] - obs_mean) / \
                        (obs_std + 1e-8)
                    norm_act = (batch['actions'] - act_mean) / \
                        (act_std + 1e-8)
                    model_inputs.append(np.hstack([norm_obs, norm_act]))
                    model_targets.append(norm_next_obs)

                old_acc = _mean_sq_loss(model_inputs, model_targets)

                for batch_in, batch_out in zip(model_inputs,
                                               model_targets):
                    self.bnn.train_fn(batch_in, batch_out)

                new_acc = _mean_sq_loss(model_inputs, model_targets)

                logger.record_tabular(
                    'BNN_DynModelSqLossBefore', old_acc)
                logger.record_tabular(
                    'BNN_DynModelSqLossAfter', new_acc)
        # ----------------

        self.env.log_diagnostics(paths)
        self.policy.log_diagnostics(paths)
        self.baseline.log_diagnostics(paths)
        self.optimize_policy(itr, samples_data)

        logger.log("saving snapshot...")
        params = self.get_itr_snapshot(itr, samples_data)
        paths = samples_data["paths"]
        if self.store_paths:
            params["paths"] = paths
        # Episode statistics accumulate across iterations.
        episode_rewards.extend(sum(p["rewards"]) for p in paths)
        episode_lengths.extend(len(p["rewards"]) for p in paths)
        params["episode_rewards"] = np.array(episode_rewards)
        params["episode_lengths"] = np.array(episode_lengths)
        params["algo"] = self
        logger.save_itr_params(itr, params)
        logger.log("saved")
        logger.dump_tabular(with_prefix=False)
        logger.pop_prefix()

        if self.plot:
            self.update_plot()
            if self.pause_for_plot:
                raw_input("Plotting evaluation run: Press Enter to "
                          "continue...")

    self.shutdown_worker()
def train(self):
    """Run the off-policy training loop with a new ``SimpleReplayPool``.

    Steps the environment with actions from the exploration strategy
    ``self.es`` applied to a pickled copy of the policy, stores transitions
    (with the reward scaled by ``self.scale_reward``) and, once the pool
    holds ``self.min_pool_size`` samples, performs
    ``self.n_updates_per_sample`` training calls per step before syncing the
    sampling policy.  Each epoch ends with evaluation, a snapshot and,
    optionally, a plot update.
    """
    # Note from the original author: this is a rather sequential method.
    replay_pool = SimpleReplayPool(
        max_pool_size=self.replay_pool_size,
        observation_dim=self.env.observation_space.flat_dim,
        action_dim=self.env.action_space.flat_dim,
    )
    self.start_worker()
    self.init_opt()

    itr = 0
    path_length = 0
    path_return = 0
    terminal = False
    observation = self.env.reset()

    # Independent copy of the policy used only for acting.
    sample_policy = pickle.loads(pickle.dumps(self.policy))

    for epoch in range(self.n_epochs):
        logger.push_prefix('epoch #%d | ' % epoch)
        logger.log("Training started")

        for _step in pyprind.prog_bar(range(self.epoch_length)):
            if terminal:
                # The previous step ended an episode.  Its final next
                # observation is dropped; a fresh episode starts here.
                observation = self.env.reset()
                self.es.reset()
                sample_policy.reset()
                self.es_path_returns.append(path_return)
                path_length = 0
                path_return = 0

            action = self.es.get_action(
                itr, observation, policy=sample_policy)
            next_observation, reward, terminal, _ = self.env.step(action)
            path_length += 1
            path_return += reward

            # Horizon cut-offs are marked terminal and stored only when
            # include_horizon_terminal_transitions is set.
            horizon_reached = (
                not terminal and path_length >= self.max_path_length)
            if horizon_reached:
                terminal = True
            if (not horizon_reached
                    or self.include_horizon_terminal_transitions):
                replay_pool.add_sample(
                    observation, action,
                    reward * self.scale_reward, terminal)

            observation = next_observation

            if replay_pool.size >= self.min_pool_size:
                for _update in range(self.n_updates_per_sample):
                    self.do_training(
                        itr, replay_pool.random_batch(self.batch_size))
                sample_policy.set_param_values(
                    self.policy.get_param_values())

            itr += 1

        logger.log("Training finished")
        if replay_pool.size >= self.min_pool_size:
            self.evaluate(epoch, replay_pool)
            snapshot = self.get_epoch_snapshot(epoch)
            logger.save_itr_params(epoch, snapshot)
        logger.dump_tabular(with_prefix=False)
        logger.pop_prefix()

        if self.plot:
            self.update_plot()
            if self.pause_for_plot:
                input("Plotting evaluation run: Press Enter to "
                      "continue...")

    self.env.terminate()
    self.policy.terminate()
def train(self):
    """Run the training loop with an exploration phase before every epoch.

    At the start of each epoch ``self.lp.lp_exploration()`` is called; the
    Q-function and policy it returns replace ``self.qf`` / ``self.policy``
    and sampling continues from the end state of its trajectory.  The rest
    of the epoch steps the environment with ``self.es``, stores transitions
    together with an ``initial`` flag, and trains once the pool holds
    ``self.min_pool_size`` samples.
    """
    pool = SimpleReplayPool(
        max_pool_size=self.replay_pool_size,
        observation_dim=self.env.observation_space.flat_dim,
        action_dim=self.env.action_space.flat_dim,
    )
    self.start_worker()
    self.init_opt()
    itr = 0
    path_length = 0
    path_return = 0
    terminal = False
    # True only for the first step taken after an environment reset.
    initial = False
    observation = self.env.reset()

    # Independent copy of the policy used only for acting.
    sample_policy = pickle.loads(pickle.dumps(self.policy))

    # Loop for the number of training epochs
    for epoch in range(self.n_epochs):
        print("Number of Training Epochs", epoch)
        logger.push_prefix('epoch #%d | ' % epoch)
        logger.log("Training started")

        print("Q Network Weights before LP",
              self.qf.get_param_values(regularizable=True))

        (updated_q_network, updated_policy_network, _, _,
         end_trajectory_action, end_trajectory_state) = \
            self.lp.lp_exploration()

        self.qf = updated_q_network
        self.policy = updated_policy_network
        # Continue sampling from where the exploration trajectory ended.
        observation = end_trajectory_state

        for epoch_itr in pyprind.prog_bar(range(self.epoch_length)):
            # Execute policy
            if terminal:  # or path_length > self.max_path_length:
                # Note that if the last time step ends an episode, the very
                # last state and observation will be ignored and not added
                # to the replay pool
                observation = self.env.reset()
                self.es.reset()
                sample_policy.reset()
                self.es_path_returns.append(path_return)
                path_length = 0
                path_return = 0
                # Fixed: the flag was never set to True, so every sample was
                # stored as a non-initial step even right after a reset.
                initial = True
            else:
                initial = False

            action = self.es.get_action(
                itr, observation, policy=sample_policy)  # qf=qf)

            # next states, and reward based on the chosen action
            next_observation, reward, terminal, _ = self.env.step(action)
            path_length += 1
            # accumulation of total rewards
            path_return += reward

            if not terminal and path_length >= self.max_path_length:
                terminal = True
                # only include the terminal transition in this case if the
                # flag was set
                if self.include_horizon_terminal_transitions:
                    pool.add_sample(observation, action,
                                    reward * self.scale_reward, terminal,
                                    initial)
            else:
                pool.add_sample(observation, action,
                                reward * self.scale_reward, terminal,
                                initial)

            observation = next_observation

            if pool.size >= self.min_pool_size:
                for update_itr in range(self.n_updates_per_sample):
                    # print ("Replay Buffer, iterations", update_itr)
                    # Train policy
                    batch = pool.random_batch(self.batch_size)
                    self.do_training(itr, batch)
                sample_policy.set_param_values(
                    self.policy.get_param_values())

            itr += 1

        logger.log("Training finished")
        print("DDPG, Updated Q Network Weights",
              self.qf.get_param_values(regularizable=True))

        if pool.size >= self.min_pool_size:
            self.evaluate(epoch, pool)
            params = self.get_epoch_snapshot(epoch)
            logger.save_itr_params(epoch, params)
        logger.dump_tabular(with_prefix=False)
        logger.pop_prefix()
        if self.plot:
            self.update_plot()
            if self.pause_for_plot:
                input("Plotting evaluation run: Press Enter to "
                      "continue...")
    self.env.terminate()
    self.policy.terminate()
def train(self):
    """Alternate environment sampling and training for ``n_epochs`` epochs.

    A new ``SimpleReplayPool`` collects the transitions produced by the
    exploration strategy ``self.es`` acting through a pickled copy of the
    policy.  When the pool holds at least ``self.min_pool_size`` samples,
    each step triggers ``self.n_updates_per_sample`` calls to
    ``self.do_training`` and the acting copy receives the trained policy's
    parameters.  Evaluation, snapshotting and optional plotting happen at
    the end of each epoch.
    """
    # Note from the original author: this is a rather sequential method.
    buffer = SimpleReplayPool(
        max_pool_size=self.replay_pool_size,
        observation_dim=self.env.observation_space.flat_dim,
        action_dim=self.env.action_space.flat_dim,
    )
    self.start_worker()
    self.init_opt()

    itr = 0
    path_length = 0
    path_return = 0
    terminal = False
    observation = self.env.reset()

    # Deep copy (via pickle round-trip) of the policy, used for acting.
    sample_policy = pickle.loads(pickle.dumps(self.policy))

    for epoch in range(self.n_epochs):
        logger.push_prefix('epoch #%d | ' % epoch)
        logger.log("Training started")

        for _tick in pyprind.prog_bar(range(self.epoch_length)):
            if terminal:
                # Episode finished on the previous step: its last next
                # observation is discarded and a new episode begins.
                observation = self.env.reset()
                self.es.reset()
                sample_policy.reset()
                self.es_path_returns.append(path_return)
                path_length = 0
                path_return = 0

            action = self.es.get_action(
                itr, observation, policy=sample_policy)
            next_observation, reward, terminal, _ = self.env.step(action)
            path_length += 1
            path_return += reward

            # A path that hits max_path_length is flagged terminal; that
            # transition is kept only if the horizon flag allows it.
            cut_by_horizon = (
                not terminal and path_length >= self.max_path_length)
            if cut_by_horizon:
                terminal = True
            keep_sample = (
                not cut_by_horizon
                or self.include_horizon_terminal_transitions)
            if keep_sample:
                buffer.add_sample(
                    observation, action,
                    reward * self.scale_reward, terminal)

            observation = next_observation

            if buffer.size >= self.min_pool_size:
                for _update in range(self.n_updates_per_sample):
                    minibatch = buffer.random_batch(self.batch_size)
                    self.do_training(itr, minibatch)
                sample_policy.set_param_values(
                    self.policy.get_param_values())

            itr += 1

        logger.log("Training finished")
        if buffer.size >= self.min_pool_size:
            self.evaluate(epoch, buffer)
            epoch_params = self.get_epoch_snapshot(epoch)
            logger.save_itr_params(epoch, epoch_params)
        logger.dump_tabular(with_prefix=False)
        logger.pop_prefix()

        if self.plot:
            self.update_plot()
            if self.pause_for_plot:
                input("Plotting evaluation run: Press Enter to "
                      "continue...")

    self.env.terminate()
    self.policy.terminate()
def train(self):
    """Train the policy while fitting a BNN dynamics model.

    Builds ``self.bnn`` (and, when ``self.use_replay_pool`` is set, a replay
    pool in ``self.pool``), then for every iteration in
    ``range(self.start_itr, self.n_itr)``:

    1. obtains and processes sample paths;
    2. if the replay pool is enabled, adds every transition of the new paths
       to it and, once it holds at least ``self.min_pool_size`` samples,
       trains the BNN on ``self.n_updates_per_sample`` normalised random
       batches, logging the mean squared prediction error before and after;
    3. logs diagnostics, optimizes the policy and saves a snapshot that also
       carries the accumulated episode rewards/lengths and the algorithm;
    4. optionally refreshes the plot.

    The worker is shut down after the last iteration.
    """
    # Bayesian neural network (BNN) initialization.
    # ------------------------------------------------
    batch_size = 1  # Redundant
    n_batches = 5  # Hardcode or annealing scheme \pi_i.

    # MDP observation and action dimensions.
    obs_dim = self.env.observation_space.flat_dim
    act_dim = self.env.action_space.flat_dim

    logger.log("Building BNN model (eta={}) ...".format(self.eta))
    start_time = time.time()

    self.bnn = bnn.BNN(
        n_in=(obs_dim + act_dim),
        n_hidden=self.unn_n_hidden,
        n_out=obs_dim,
        n_batches=n_batches,
        layers_type=self.unn_layers_type,
        trans_func=lasagne.nonlinearities.rectify,
        out_func=lasagne.nonlinearities.linear,
        batch_size=batch_size,
        n_samples=self.snn_n_samples,
        prior_sd=self.prior_sd,
        likelihood_sd=self.likelihood_sd,
        use_reverse_kl_reg=self.use_reverse_kl_reg,
        reverse_kl_reg_factor=self.reverse_kl_reg_factor,
        learning_rate=self.unn_learning_rate)

    logger.log("Model built ({:.1f} sec).".format(
        (time.time() - start_time)))

    if self.use_replay_pool:
        self.pool = SimpleReplayPool(
            max_pool_size=self.replay_pool_size,
            observation_shape=self.env.observation_space.shape,
            action_dim=act_dim)
    # ------------------------------------------------

    def mean_sq_loss(batches):
        """Return the BNN's mean squared error averaged over `batches`."""
        total = 0.
        for inputs, targets in batches:
            prediction = self.bnn.pred_fn(inputs)
            total += np.mean(np.square(prediction - targets))
        return total / len(batches)

    self.start_worker()
    self.init_opt()
    episode_rewards = []
    episode_lengths = []
    for itr in range(self.start_itr, self.n_itr):
        logger.push_prefix('itr #%d | ' % itr)

        paths = self.obtain_samples(itr)
        samples_data = self.process_samples(itr, paths)

        # Exploration code
        # ----------------
        if self.use_replay_pool:
            # Fill replay pool.
            logger.log("Fitting dynamics model using replay pool ...")
            for path in samples_data['paths']:
                path_len = len(path['rewards'])
                for i in range(path_len):
                    obs = path['observations'][i]
                    act = path['actions'][i]
                    rew = path['rewards'][i]
                    # Only the last step of a path is marked terminal.
                    term = (i == path_len - 1)
                    self.pool.add_sample(obs, act, rew, term)

            # Now we train the dynamics model using the replay self.pool;
            # only if self.pool is large enough.
            if self.pool.size >= self.min_pool_size:
                obs_mean, obs_std, act_mean, act_std = \
                    self.pool.mean_obs_act()

                # Draw all batches up front so the "before" and "after"
                # losses are measured on exactly the data trained on.
                batches = []
                for _ in range(self.n_updates_per_sample):
                    batch = self.pool.random_batch(self.pool_batch_size)
                    # The 1e-8 keeps the division finite for zero std.
                    obs = (batch['observations'] - obs_mean) / \
                        (obs_std + 1e-8)
                    next_obs = (batch['next_observations'] - obs_mean) / \
                        (obs_std + 1e-8)
                    act = (batch['actions'] - act_mean) / \
                        (act_std + 1e-8)
                    batches.append((np.hstack([obs, act]), next_obs))

                # With zero updates requested there is nothing to train on
                # and the averages below would divide by zero.
                if batches:
                    old_acc = mean_sq_loss(batches)
                    for _inputs, _targets in batches:
                        self.bnn.train_fn(_inputs, _targets)
                    new_acc = mean_sq_loss(batches)

                    logger.record_tabular(
                        'BNN_DynModelSqLossBefore', old_acc)
                    logger.record_tabular(
                        'BNN_DynModelSqLossAfter', new_acc)
        # ----------------

        self.env.log_diagnostics(paths)
        self.policy.log_diagnostics(paths)
        self.baseline.log_diagnostics(paths)
        self.optimize_policy(itr, samples_data)
        logger.log("saving snapshot...")
        params = self.get_itr_snapshot(itr, samples_data)
        paths = samples_data["paths"]
        if self.store_paths:
            params["paths"] = paths
        episode_rewards.extend(sum(p["rewards"]) for p in paths)
        episode_lengths.extend(len(p["rewards"]) for p in paths)
        params["episode_rewards"] = np.array(episode_rewards)
        params["episode_lengths"] = np.array(episode_lengths)
        params["algo"] = self
        logger.save_itr_params(itr, params)
        logger.log("saved")
        logger.dump_tabular(with_prefix=False)
        logger.pop_prefix()
        if self.plot:
            self.update_plot()
            if self.pause_for_plot:
                # `raw_input` does not exist on Python 3; `input` is the
                # equivalent blocking prompt there.
                input("Plotting evaluation run: Press Enter to "
                      "continue...")
    self.shutdown_worker()
def run_experiment(argv):
    """Run one pickled experiment described by command-line arguments.

    Parses `argv` (``argv[0]`` is skipped), initializes the parallel
    sampler and optional plotter, decodes the experiment from
    ``--args_data``, points the logger at the requested log directory and
    files, and then concretizes the experiment (exhausting the result when
    it is iterable). The logger's previous snapshot directory/mode, outputs
    and prefix are restored afterwards, even if the experiment raises.

    Args:
        argv: Full argument vector, program name first.
    """
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_parallel', type=int, default=1,
                        help='Number of parallel workers to perform rollouts.')
    parser.add_argument(
        '--exp_name', type=str, default=default_exp_name,
        help='Name of the experiment.')
    parser.add_argument('--log_dir', type=str, default=default_log_dir,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                             '(all iterations will be saved), "last" (only '
                             'the last iteration will be saved), or "none" '
                             '(do not save snapshots)')
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--plot', type=ast.literal_eval, default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False,
                        help='Whether to only print the tabular log information (in a horizontal format)')
    parser.add_argument('--seed', type=int,
                        help='Random seed for numpy')
    parser.add_argument('--args_data', type=str,
                        help='Pickled data for stub objects')

    args = parser.parse_args(argv[1:])

    from sandbox.vime.sampler import parallel_sampler_expl as parallel_sampler
    parallel_sampler.initialize(n_parallel=args.n_parallel)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    # The experiment arrives base64-encoded in --args_data.
    # NOTE(security): unpickling executes arbitrary code; only pass
    # --args_data that comes from a trusted launcher.
    data = pickle.loads(base64.b64decode(args.args_data))

    log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        maybe_iter = concretize(data)
        if is_iterable(maybe_iter):
            # Drain the iterator purely for its side effects.
            for _ in maybe_iter:
                pass
    finally:
        # Undo the logger changes even when the experiment fails, so a
        # caller running several experiments in-process starts clean.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
def _train(self, env, policy, pool):
    """Perform RL training.

    Runs ``self._n_epochs + 1`` epochs of ``self._epoch_length`` environment
    steps each. Every step samples an action from `policy`, stores the
    transition in ``self.pool`` and, once the pool holds at least
    ``self._min_pool_size`` samples, performs ``self._n_train_repeat``
    training updates. At the end of each epoch the algorithm is evaluated,
    a snapshot is saved and timing/return statistics are logged. The
    environment is terminated after the last epoch.

    Args:
        env (`rllab.Env`): Environment used for training
        policy (`Policy`): Policy used for training
        pool (`PoolBase`): Sample pool to add samples to
    """

    self._init_training(env, policy, pool)

    with self._sess.as_default():
        observation = env.reset()
        policy.reset()

        path_length = 0
        path_return = 0
        last_path_return = 0
        max_path_return = -np.inf
        n_episodes = 0
        gt.rename_root('RLAlgorithm')
        gt.reset()
        gt.set_def_unique(False)

        for epoch in gt.timed_for(range(self._n_epochs + 1), save_itrs=True):
            logger.push_prefix('Epoch #%d | ' % epoch)

            # NOTE(review): the callback is handed this frame's locals(),
            # so the local variable names in this method are effectively
            # part of the callback's input - rename with care.
            if self.iter_callback is not None:
                self.iter_callback(locals(), globals())

            for t in range(self._epoch_length):
                # Global step count across all epochs.
                iteration = t + epoch * self._epoch_length

                action, _ = policy.get_action(observation)
                next_ob, reward, terminal, info = env.step(action)
                path_length += 1
                path_return += reward

                # Samples go to self.pool rather than the `pool` argument;
                # presumably _init_training stores it there - TODO confirm.
                self.pool.add_sample(
                    observation,
                    action,
                    reward,
                    terminal,
                    next_ob,
                )

                # An episode ends on a terminal step or at the length cap.
                if terminal or path_length >= self._max_path_length:
                    observation = env.reset()
                    policy.reset()
                    path_length = 0
                    max_path_return = max(max_path_return, path_return)
                    last_path_return = path_return

                    path_return = 0
                    n_episodes += 1

                else:
                    observation = next_ob
                gt.stamp('sample')

                if self.pool.size >= self._min_pool_size:
                    for i in range(self._n_train_repeat):
                        batch = self.pool.random_batch(self._batch_size)
                        self._do_training(iteration, batch)

                gt.stamp('train')

            self._evaluate(epoch)

            params = self.get_snapshot(epoch)
            logger.save_itr_params(epoch, params)
            times_itrs = gt.get_times().stamps.itrs

            # The 'eval' stamp is only taken at the very end of an epoch
            # (below); epochs 0 and 1 report an eval time of 0.
            eval_time = times_itrs['eval'][-1] if epoch > 1 else 0
            total_time = gt.get_times().total
            logger.record_tabular('time-train', times_itrs['train'][-1])
            logger.record_tabular('time-eval', eval_time)
            logger.record_tabular('time-sample', times_itrs['sample'][-1])
            logger.record_tabular('time-total', total_time)
            logger.record_tabular('epoch', epoch)
            logger.record_tabular('episodes', n_episodes)
            logger.record_tabular('max-path-return', max_path_return)
            logger.record_tabular('last-path-return', last_path_return)
            logger.record_tabular('pool-size', self.pool.size)

            logger.dump_tabular(with_prefix=False)
            logger.pop_prefix()

            # Everything since the last 'train' stamp (evaluation, snapshot
            # and logging) is accounted to 'eval'.
            gt.stamp('eval')

        env.terminate()
def run_experiment(argv):
    """Run or resume one experiment described by command-line arguments.

    Parses `argv` (``argv[0]`` is skipped), seeds the random generators,
    optionally starts the parallel sampler and plotter, and points the
    logger at the log directory (``--log_dir``, or ``config.LOG_DIR`` joined
    with the experiment name when it is not given). The experiment is then
    either resumed from ``--resume_from`` (a joblib file holding an
    ``'algo'`` entry whose ``train()`` is called), invoked as a cloudpickled
    callable with the variant data, or concretized from pickled stub data.
    The logger's previous snapshot directory/mode, outputs and prefix are
    restored afterwards, even if the experiment raises.

    Args:
        argv: Full argument vector, program name first.
    """
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_parallel', type=int, default=1,
                        help='Number of parallel workers to perform rollouts. 0 => don\'t start any workers')
    parser.add_argument(
        '--exp_name', type=str, default=default_exp_name,
        help='Name of the experiment.')
    parser.add_argument('--log_dir', type=str, default=None,
                        help='Path to save the log and iteration snapshot.')
    # The space after "every" was missing between the adjacent literals.
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                             '(all iterations will be saved), "last" (only '
                             'the last iteration will be saved), "gap" (every '
                             '`snapshot_gap` iterations are saved), or "none" '
                             '(do not save snapshots)')
    parser.add_argument('--snapshot_gap', type=int, default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file', type=str, default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument('--resume_from', type=str, default=None,
                        help='Name of the pickle file to resume experiment from.')
    parser.add_argument('--plot', type=ast.literal_eval, default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False,
                        help='Whether to only print the tabular log information (in a horizontal format)')
    parser.add_argument('--seed', type=int,
                        help='Random seed for numpy')
    parser.add_argument('--args_data', type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--variant_data', type=str,
                        help='Pickled data for variant configuration')
    parser.add_argument('--use_cloudpickle', type=ast.literal_eval, default=False)

    args = parser.parse_args(argv[1:])

    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    # NOTE(security): the pickle/cloudpickle/joblib loads in this function
    # execute arbitrary code; only feed them data from a trusted launcher.
    if args.variant_data is not None:
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        logger.log_parameters_lite(params_log_file, args)

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        if args.resume_from is not None:
            data = joblib.load(args.resume_from)
            assert 'algo' in data
            algo = data['algo']
            algo.train()
        elif args.use_cloudpickle:
            import cloudpickle
            method_call = cloudpickle.loads(base64.b64decode(args.args_data))
            method_call(variant_data)
        else:
            # The experiment arrives base64-encoded in --args_data.
            data = pickle.loads(base64.b64decode(args.args_data))
            maybe_iter = concretize(data)
            if is_iterable(maybe_iter):
                # Drain the iterator purely for its side effects.
                for _ in maybe_iter:
                    pass
    finally:
        # Undo the logger changes even when the experiment fails, so a
        # caller running several experiments in-process starts clean.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
def run_experiment(argv):
    """Run or resume one experiment described by command-line arguments.

    Parses `argv` (``argv[0]`` is skipped), seeds the random generators,
    optionally starts the parallel sampler and plotter, and points both the
    project logger and the baselines logger at the log directory
    (``--log_dir``, or ``config.LOG_DIR`` joined with the experiment name
    when it is not given). The experiment is then either resumed from
    ``--resume_from`` (a joblib file holding an ``'algo'`` entry whose
    ``train()`` is called), invoked as a cloudpickled callable with the
    variant data, or concretized from pickled stub data. The project
    logger's previous snapshot directory/mode, outputs and prefix are
    restored afterwards, even if the experiment raises.

    Args:
        argv: Full argument vector, program name first.
    """
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--n_parallel',
        type=int,
        default=1,
        help=
        'Number of parallel workers to perform rollouts. 0 => don\'t start any workers'
    )
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--log_dir',
                        type=str,
                        default=None,
                        help='Path to save the log and iteration snapshot.')
    # The space after "every" was missing between the adjacent literals.
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), "gap" (every '
                        '`snapshot_gap` iterations are saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument('--snapshot_gap',
                        type=int,
                        default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file',
                        type=str,
                        default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument(
        '--resume_from',
        type=str,
        default=None,
        help='Name of the pickle file to resume experiment from.')
    parser.add_argument('--plot',
                        type=ast.literal_eval,
                        default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--variant_data',
                        type=str,
                        help='Pickled data for variant configuration')
    parser.add_argument('--use_cloudpickle',
                        type=ast.literal_eval,
                        default=False)

    args = parser.parse_args(argv[1:])

    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    # NOTE(security): the pickle/cloudpickle/joblib loads in this function
    # execute arbitrary code; only feed them data from a trusted launcher.
    if args.variant_data is not None:
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        logger.log_parameters_lite(params_log_file, args)

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        # baselines logger
        baselines_logger.configure(dir=log_dir)

        if args.resume_from is not None:
            data = joblib.load(args.resume_from)
            assert 'algo' in data
            algo = data['algo']
            algo.train()
        elif args.use_cloudpickle:
            import cloudpickle
            method_call = cloudpickle.loads(base64.b64decode(args.args_data))
            method_call(variant_data)
        else:
            # The experiment arrives base64-encoded in --args_data.
            data = pickle.loads(base64.b64decode(args.args_data))
            maybe_iter = concretize(data)
            if is_iterable(maybe_iter):
                # Drain the iterator purely for its side effects.
                for _ in maybe_iter:
                    pass
    finally:
        # Undo the logger changes even when the experiment fails, so a
        # caller running several experiments in-process starts clean.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()