def eval_performance(policy, env, period, max_path_length, num_rollouts, seed=0):
    """Roll out ``policy`` with a fixed period and collect the episode returns.

    The global RNGs are reset with ``ext.set_seed(seed)`` first so repeated
    evaluations see the same randomness.

    Args:
        policy: policy to evaluate. A ``HierarchicalPolicyRandomTime`` gets its
            period fixed only for the duration of ``policy.fix_period(period)``;
            any other policy has ``policy.period`` overwritten (and it stays set).
        env: environment to roll out in.
        period: period imposed on the policy.
        max_path_length: maximum number of steps per rollout.
        num_rollouts: number of rollouts to run.
        seed: seed passed to ``ext.set_seed`` before rolling out.

    Returns:
        list with ``np.sum`` of each rollout's ``'rewards'``, one entry per rollout.
    """
    ext.set_seed(seed)

    def _collect_returns():
        episode_returns = []
        for _ in trange(num_rollouts):
            path = rollout(env, policy, max_path_length=max_path_length)
            episode_returns.append(np.sum(path['rewards']))
        return episode_returns

    if not isinstance(policy, HierarchicalPolicyRandomTime):
        policy.period = period
        return _collect_returns()

    with policy.fix_period(period):
        return _collect_returns()
def get_velocities(policy, env, max_path_length, num_rollouts, seed=0):
    """Collect the ``joint_angles`` env info from ``num_rollouts`` rollouts.

    Note: despite the function name, what is returned is each rollout's
    ``rollout(...)['env_infos']['joint_angles']`` entry.

    Args:
        policy: policy to roll out.
        env: environment to roll out in.
        max_path_length: maximum number of steps per rollout.
        num_rollouts: number of rollouts to run.
        seed: seed passed to ``ext.set_seed`` before rolling out.

    Returns:
        list with one ``joint_angles`` entry per rollout.
    """
    ext.set_seed(seed)
    return [
        rollout(env, policy, max_path_length=max_path_length)['env_infos']['joint_angles']
        for _ in trange(num_rollouts)
    ]
def create_rllab_env(env_name, init_seed):
    """Instantiate an rllab environment from its name, then seed the global RNGs.

    Args:
        env_name: expression naming an environment class reachable from this
            module's namespace; it is evaluated and the result called with no
            arguments.
        init_seed: seed passed to ``ext.set_seed`` after the environment is built.

    Returns:
        the new environment instance.
    """
    # SECURITY: ``eval`` executes arbitrary code -- only pass trusted, hard-coded names.
    env_cls = eval(env_name)
    environment = env_cls()
    ext.set_seed(init_seed)
    return environment
def __init__(self, env, args):
    """Store ``env``/``args``, start the parallel sampler, and seed it if requested.

    Args:
        env: environment, stored as ``self.env``.
        args: parsed options. ``args.n_parallel`` is the number of sampler
            workers; ``args.seed`` (may be None) seeds this process and the workers.
    """
    self.env, self.args = env, args

    # Parallel setup
    parallel_sampler.initialize(n_parallel=args.n_parallel)

    seed = args.seed
    if seed is None:
        return
    set_seed(seed)
    parallel_sampler.set_seed(seed)
def set_seed(seed, env, framework):
    """Seed either a gym environment or the rllab global RNGs.

    Args:
        seed: seed value.
        env: environment. It is only touched (via ``env.unwrapped.seed``) when
            ``framework == 'gym'``.
        framework: ``'gym'`` or ``'rllab'``.

    Returns:
        ``env`` unchanged, so the call can be chained.

    Raises:
        ValueError: if ``framework`` is neither ``'gym'`` nor ``'rllab'``.
    """
    if framework == 'gym':
        env.unwrapped.seed(seed)
    elif framework == 'rllab':
        ext.set_seed(seed)
    else:
        # The previous ``raise("...")`` raised a str, which Python 3 turns into a
        # TypeError ("exceptions must derive from BaseException") that hid the
        # actual problem.
        raise ValueError("framework not supported: {!r}".format(framework))
    return env
def initialize_worker(group, rank, seed, cpu):
    """Set up one sampler worker: pin it to ``cpu``, optionally seed it, and log it.

    Args:
        group: not used by this function.
        rank: worker index; only used in the log message.
        seed: seed for ``ext.set_seed``, or None to leave the RNGs untouched.
        cpu: CPU index to pin the current process to via ``psutil``.
    """
    message_parts = ["MbSampler rank: {} initialized".format(rank)]

    try:
        proc = psutil.Process()
        proc.cpu_affinity([cpu])
        message_parts.append(", CPU Affinity: {}".format(proc.cpu_affinity()))
    except AttributeError:
        # ``Process.cpu_affinity`` is not provided by psutil on every platform;
        # affinity is then simply skipped.
        pass

    if seed is not None:
        ext.set_seed(seed)
        time.sleep(0.3)  # (so the printing from set_seed is not intermixed)
        message_parts.append(", Seed: {}".format(seed))

    logger.log("".join(message_parts))
def get_latent_info(policy, env, period, max_path_length, num_rollouts, seed=0, stride=10):
    """Collect a subsampled trace of the policy's ``latents`` agent info per rollout.

    Args:
        policy: policy to roll out; its ``period`` attribute is overwritten with
            ``period`` and stays set afterwards.
        env: environment to roll out in.
        period: period to assign to ``policy.period``.
        max_path_length: maximum number of steps per rollout.
        num_rollouts: number of rollouts to run.
        seed: seed passed to ``ext.set_seed`` before rolling out (previously
            hard-coded to 0, which remains the default).
        stride: keep every ``stride``-th latent entry (previously hard-coded to 10,
            which remains the default).

    Returns:
        list with one subsampled ``agent_infos['latents']`` array per rollout.
    """
    # change the policy period
    policy.period = period
    ext.set_seed(seed)
    latents = []
    for _ in trange(num_rollouts):
        latent_infos = rollout(
            env, policy, max_path_length=max_path_length)['agent_infos']['latents']
        # Integer-array indexing (rather than a slice) yields a copy, so the full
        # per-step latent array of the rollout is not kept alive.
        latents.append(latent_infos[np.arange(0, len(latent_infos), stride)])
    return latents
def init_rank(self, rank):
    """Bind this object and its sub-components to worker ``rank`` and seed its RNGs.

    The seed applied is ``base + rank``. ``base`` is the seed already known to
    ``ext`` or, when none has been set, a random integer in [0, 1e6) drawn from
    numpy's global RNG.

    Args:
        rank: worker index.
    """
    self.rank = rank
    if self.set_cpu_affinity:
        self._set_affinity(rank)

    self.baseline.init_rank(rank)
    self.optimizer.init_rank(rank)
    exemplar = self.exemplar
    if exemplar is not None:
        exemplar.init_rank(rank)

    base_seed = ext.get_seed()
    if base_seed is None:
        # NOTE: Not sure if this is a good source for seed?
        base_seed = int(np.random.rand() * 1e6)
    ext.set_seed(base_seed + rank)
def eval_performance(policy, env, max_path_length, num_rollouts, seed=0):
    """Collect episode returns with the policy's manager under ``set_std_to_0()``.

    Args:
        policy: policy whose ``manager`` provides a ``set_std_to_0()`` context
            manager (presumably zeroing the manager's exploration std for a
            deterministic evaluation -- confirm against the manager class).
        env: environment to roll out in.
        max_path_length: maximum number of steps per rollout.
        num_rollouts: number of rollouts to run.
        seed: seed passed to ``ext.set_seed`` before rolling out (previously
            hard-coded to 0, which remains the default).

    Returns:
        list with ``np.sum`` of each rollout's ``'rewards'``.
    """
    ext.set_seed(seed)
    returns = []
    with policy.manager.set_std_to_0():
        for _ in trange(num_rollouts):
            path = rollout(env, policy, max_path_length=max_path_length)
            returns.append(np.sum(path['rewards']))
    return returns
def __init__(self, env, args):
    """Store inputs, start and optionally seed the parallel sampler, then build ``self.algo``.

    Args:
        env: environment, stored as ``self.env`` and also passed through
            ``self.parse_env_args`` (which may return a different env).
        args: parsed options. ``args.n_parallel`` is the number of sampler
            workers; ``args.seed`` (may be None) seeds this process and the workers.
    """
    self.env, self.args = env, args

    # Parallel setup
    parallel_sampler.initialize(n_parallel=args.n_parallel)
    seed = args.seed
    if seed is not None:
        set_seed(seed)
        parallel_sampler.set_seed(seed)

    # ``setup`` is always asked to start from iteration 0.
    env, policy = self.parse_env_args(env, args)
    self.algo = self.setup(env, policy, start_itr=0)
def eval_performance(policy, env, to_remove, max_path_length=5000, num_rollouts=1000, seed=0):
    """Collect episode returns after setting ``policy.manager.to_remove``.

    Args:
        policy: policy to evaluate. ``policy.manager.to_remove`` is overwritten
            with ``to_remove`` and is not restored afterwards; its meaning is
            defined by the manager class (not visible here).
        env: environment to roll out in.
        to_remove: value assigned to ``policy.manager.to_remove``.
        max_path_length: maximum number of steps per rollout.
        num_rollouts: number of rollouts to run.
        seed: seed passed to ``ext.set_seed`` before rolling out (previously
            hard-coded to 0, which remains the default).

    Returns:
        list with ``np.sum`` of each rollout's ``'rewards'``.
    """
    policy.manager.to_remove = to_remove
    ext.set_seed(seed)
    returns = []
    for _ in trange(num_rollouts):
        path = rollout(env, policy, max_path_length=max_path_length)
        returns.append(np.sum(path['rewards']))
    return returns
def setup(seed, n_parallel, log_dir):
    """Seed RNGs, optionally start parallel sampler workers, and direct logger output to ``log_dir``.

    Args:
        seed: seed for this process and, if workers are started, for the
            workers; None skips seeding.
        n_parallel: number of sampler workers; ``<= 0`` starts none.
        log_dir: directory for snapshots and ``progress.csv``; created if missing.
    """
    if seed is not None:
        set_seed(seed)
    if n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=n_parallel)
        if seed is not None:
            parallel_sampler.set_seed(seed)
    # exist_ok=True already tolerates an existing directory, so the previous
    # ``isdir(...) == False`` pre-check was redundant.
    os.makedirs(log_dir, exist_ok=True)
    logger.set_snapshot_dir(log_dir)
    logger.add_tabular_output(os.path.join(log_dir, 'progress.csv'))
def eval_p(policy, env, max_path_length, num_rollouts, seed):
    """Seed the RNGs and return the summed reward of each of ``num_rollouts`` rollouts.

    Progress is shown with a ``trange`` bar labelled "Rollouts".

    Args:
        policy: policy to roll out.
        env: environment to roll out in.
        max_path_length: maximum number of steps per rollout.
        num_rollouts: number of rollouts to run.
        seed: seed passed to ``ext.set_seed`` first.

    Returns:
        list with ``np.sum`` of each rollout's ``'rewards'``.
    """
    ext.set_seed(seed)
    rollout_iter = trange(num_rollouts, desc="Rollouts", ncols=80)
    return [
        np.sum(rollout(env, policy, max_path_length=max_path_length)['rewards'])
        for _ in rollout_iter
    ]
def run_experiment(algo, n_parallel=0, seed=0, plot=False, log_dir=None, exp_name=None,
                   snapshot_mode='last', snapshot_gap=1, exp_prefix='experiment',
                   log_tabular_only=False):
    """Run ``algo.train()`` in-process with seeding, optional workers, and logger outputs configured.

    Args:
        algo: object whose ``train()`` is called.
        n_parallel: number of parallel sampler workers; ``<= 0`` starts none.
        seed: seed for this process (and the workers, if started).
        plot: if true, start the rllab plotter worker.
        log_dir: output directory; defaults to
            ``config.LOG_DIR + "/local/" + exp_prefix`` joined with ``exp_name``.
        exp_name: experiment name; defaults to a local-time timestamp.
        snapshot_mode / snapshot_gap / log_tabular_only: forwarded to ``logger``.
        exp_prefix: sub-directory of the default log root.
    """
    default_log_dir = config.LOG_DIR + "/local/" + exp_prefix
    set_seed(seed)
    if exp_name is None:
        now = datetime.datetime.now(dateutil.tz.tzlocal())
        timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
        exp_name = 'experiment_%s' % (timestamp)
    if n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=n_parallel)
        parallel_sampler.set_seed(seed)
    if plot:
        from rllab.plotter import plotter
        plotter.init_worker()
    if log_dir is None:
        log_dir = osp.join(default_log_dir, exp_name)
    tabular_log_file = osp.join(log_dir, 'progress.csv')
    text_log_file = osp.join(log_dir, 'debug.log')

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(snapshot_mode)
    logger.set_snapshot_gap(snapshot_gap)
    logger.set_log_tabular_only(log_tabular_only)
    logger.push_prefix("[%s] " % exp_name)
    try:
        algo.train()
    finally:
        # Restore the process-global logger even when training raises, so a later
        # run in the same process does not inherit this run's outputs and prefix.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
def eval_policy(self, itr, gpu_device=None, gpu_frac=None):
    """Load the policy snapshot for ``itr`` and save ``self._num_rollouts`` evaluation rollouts.

    Args:
        itr: iteration to evaluate. ``-1`` selects the highest iteration ``k`` such
            that ``self._itr_file(0)`` ... ``self._itr_file(k)`` all exist.
        gpu_device: device for the new session; defaults to
            ``self.params['policy']['gpu_device']``.
        gpu_frac: GPU memory fraction; defaults to
            ``self.params['policy']['gpu_frac']``.

    Raises:
        FileNotFoundError: if ``itr == -1`` and ``self._itr_file(0)`` does not exist.
    """
    if itr == -1:
        itr = 0
        while os.path.exists(self._itr_file(itr)):
            itr += 1
        itr -= 1
        if itr < 0:
            # Previously this fell through and tried to load iteration -1.
            raise FileNotFoundError(
                'No policy snapshot found (expected {0})'.format(self._itr_file(0)))

    if self.params['seed'] is not None:
        set_seed(self.params['seed'])

    if gpu_device is None:
        gpu_device = self.params['policy']['gpu_device']
    if gpu_frac is None:
        gpu_frac = self.params['policy']['gpu_frac']
    sess, graph = MACPolicy.create_session_and_graph(gpu_device=gpu_device, gpu_frac=gpu_frac)

    with graph.as_default(), sess.as_default():
        policy = self._load_itr_policy(itr)

        logger.log('Evaluating policy for itr {0}'.format(itr))
        n_envs = 1
        if 'max_path_length' in self.params['alg']:
            max_path_length = self.params['alg']['max_path_length']
        else:
            max_path_length = self.env.horizon

        sampler = RNNCriticSampler(
            policy=policy,
            env=self.env,
            n_envs=n_envs,
            replay_pool_size=int(1e4),
            max_path_length=max_path_length,
            save_rollouts=True,
            sampling_method=self.params['alg']['replay_pool_sampling'])

        rollouts = []
        step = 0
        logger.log('Starting rollout {0}'.format(len(rollouts)))
        # Step the sampler until it has produced enough finished paths.
        while len(rollouts) < self._num_rollouts:
            sampler.step(step)
            step += n_envs
            new_rollouts = sampler.get_recent_paths()
            if len(new_rollouts) > 0:
                rollouts += new_rollouts
                logger.log('Starting rollout {0}'.format(len(rollouts)))

        self.save_eval_rollouts(itr, rollouts)
def create_env(env_str, is_normalize=True, seed=None):
    """Build a ``TfEnv`` from a Python expression, optionally normalized and seeded.

    Args:
        env_str: Python expression constructing the environment, e.g. a call to
            one of the classes imported below. It is passed to ``eval``.
        is_normalize: wrap the environment with ``normalize`` before ``TfEnv``.
        seed: if not None, seed the global RNGs and, for a ``GymEnv``, the
            underlying gym env.

    Returns:
        the ``TfEnv``-wrapped environment.
    """
    # These imports look unused but put the class names in local scope for ``eval``.
    from rllab.envs.gym_env import GymEnv, FixedIntervalVideoSchedule
    from sandbox.gkahn.gcg.envs.rccar.square_env import SquareEnv
    from sandbox.gkahn.gcg.envs.rccar.square_cluttered_env import SquareClutteredEnv
    from sandbox.gkahn.gcg.envs.rccar.cylinder_env import CylinderEnv

    # SECURITY: eval executes arbitrary code; env_str must come from trusted config.
    raw_env = eval(env_str)
    inner_env = normalize(raw_env) if is_normalize else raw_env
    env = TfEnv(inner_env)

    # set seed
    if seed is not None:
        set_seed(seed)
        # Test the un-normalized env: the normalize() wrapper is not a GymEnv, so
        # checking the wrapper skipped gym-level seeding whenever is_normalize=True.
        if isinstance(raw_env, GymEnv):
            raw_env.env.seed(seed)

    return env
def startup(self, master=True):
    """Seed, initialize the sampler, policy and algorithm, then (on the master) set iterations and logging.

    A seed is generated with ``make_seed()`` when ``self.seed`` is None. The
    sampler is initialized with ``self.seed + 1``.

    Args:
        master: when true, also compute and set the iteration count and call
            ``self.init_logging()``.

    Returns:
        the number of training iterations when ``master`` is true, else None.
    """
    if self.seed is None:
        self.seed = make_seed()
    set_seed(self.seed)

    env_spec, sample_size, horizon, mid_batch_reset = self.sampler.initialize(
        seed=self.seed + 1,
        affinities=self.affinities,
        discount=getattr(self.algo, "discount", None),
        need_extra_obs=self.algo.need_extra_obs,
    )
    self.init_policy(env_spec)
    self.algo.initialize(
        policy=self.policy,
        env_spec=env_spec,
        sample_size=sample_size,
        horizon=horizon,
        mid_batch_reset=mid_batch_reset,
    )
    self.sampler.policy_init(self.policy)

    if not master:
        return None

    n_itr = self.get_n_itr(sample_size)
    self.algo.set_n_itr(n_itr)
    self.init_logging()
    return n_itr
def _worker_set_seed(_, seed):
    """Worker-side hook: log the seed, then apply it with ``ext.set_seed``.

    Args:
        _: ignored first positional argument.
        seed: seed value; formatted with ``%d`` in the log message.
    """
    message = "Setting seed to %d" % seed
    logger.log(message)
    ext.set_seed(seed)
# parser.add_argument("type", help="Type of DDPG to run: ['unified-decaying', 'unified-gated-decaying', 'unified', 'unified-gated', 'regular']") parser.add_argument("env", help="The environment name from OpenAIGym environments") parser.add_argument("--num_epochs", default=100, type=int) parser.add_argument("--data_dir", default="./data/") parser.add_argument("--use_ec2", action="store_true", help="Use your ec2 instances if configured") parser.add_argument( "--dont_terminate_machine", action="store_false", help="Whether to terminate your spot instance or not. Be careful.") args = parser.parse_args() stub(globals()) ext.set_seed(1) supported_gym_envs = [ "MountainCarContinuous-v0", "InvertedPendulum-v1", "InvertedDoublePendulum-v1", "Hopper-v1", "Walker2d-v1", "Humanoid-v1", "Reacher-v1", "HalfCheetah-v1", "Swimmer-v1", "HumanoidStandup-v1" ] other_env_class_map = {"Cartpole": CartpoleEnv} if args.env in supported_gym_envs: gymenv = GymEnv(args.env, force_reset=True, record_video=False, record_log=False) # gymenv.env.seed(1)
def _worker_set_seed(_, seed):
    """Worker-side hook: seed this process via ``ext.set_seed``, logging before and after.

    Args:
        _: ignored first positional argument.
        seed: seed value; formatted with ``%d`` in the log messages.
    """
    # The start message used to be logged twice in a row.
    logger.log("Setting seed to %d" % seed)
    ext.set_seed(seed)
    logger.log("Done Setting seed to %d" % seed)
parser.add_argument("-is", '--init_state', type=str, help='vector of init_state') parser.add_argument( "-cf", '--collection_file', type=str, help='path to the pkl file with start positions Collection') args = parser.parse_args() policy = None env = None if args.seed >= 0: set_seed(args.seed) if args.collection_file: all_feasible_starts = pickle.load(open(args.collection_file, 'rb')) with tf.Session() as sess: data = joblib.load(args.file) if "algo" in data: policy = data["algo"].policy env = data["algo"].env else: policy = data['policy'] env = data['env'] # easiest to hardest init_pos = [[0, 0]] # init_pos = [[0, 0],
def run_experiment(argv):
    """Parse command-line options, configure seeding/sampling/logging, and run one experiment.

    The experiment comes from one of three sources:

    * ``--resume_from``: a joblib snapshot whose ``'algo'`` is trained again;
    * ``--args_data`` with ``--use_cloudpickle True``: a cloudpickled callable,
      invoked with the decoded ``--variant_data``;
    * ``--args_data`` otherwise: a pickled stub object passed to ``concretize``.

    If the training call returns an iterable, it is drained.

    Args:
        argv: full argument vector; ``argv[0]`` (the program name) is skipped.
    """
    # e2crawfo: These imports, in this order, were necessary for fixing issues on cedar.
    import rllab.mujoco_py.mjlib
    import tensorflow

    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--n_parallel', type=int, default=1,
        help='Number of parallel workers to perform rollouts. 0 => don\'t start any workers')
    parser.add_argument('--exp_name', type=str, default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--log_dir', type=str, default=None,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                             '(all iterations will be saved), "last" (only '
                             'the last iteration will be saved), or "none" '
                             '(do not save snapshots)')
    parser.add_argument('--snapshot_gap', type=int, default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file', type=str, default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument('--resume_from', type=str, default=None,
                        help='Name of the pickle file to resume experiment from.')
    parser.add_argument('--plot', type=ast.literal_eval, default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only', type=ast.literal_eval, default=False,
        help='Whether to only print the tabular log information (in a horizontal format)')
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data', type=str, help='Pickled data for stub objects')
    parser.add_argument('--variant_data', type=str,
                        help='Pickled data for variant configuration')
    parser.add_argument('--use_cloudpickle', type=ast.literal_eval, default=False)

    args = parser.parse_args(argv[1:])

    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    if args.variant_data is not None:
        # SECURITY: pickle.loads can execute arbitrary code -- only pass trusted data.
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        logger.log_parameters_lite(params_log_file, args)

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_tf_summary_dir(osp.join(log_dir, "tf_summary"))
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        if args.resume_from is not None:
            data = joblib.load(args.resume_from)
            assert 'algo' in data
            algo = data['algo']
            maybe_iter = algo.train()
            if is_iterable(maybe_iter):
                for _ in maybe_iter:
                    pass
        else:
            # read from stdin
            if args.use_cloudpickle:
                import cloudpickle
                # SECURITY: unpickling runs arbitrary code -- trusted launchers only.
                method_call = cloudpickle.loads(base64.b64decode(args.args_data))
                method_call(variant_data)
            else:
                data = pickle.loads(base64.b64decode(args.args_data))
                maybe_iter = concretize(data)
                if is_iterable(maybe_iter):
                    for _ in maybe_iter:
                        pass
    finally:
        # Restore the process-global logger even if the experiment raises.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
def _worker_set_seed(_, seed):
    """Worker-side hook that applies ``seed`` via ``ext.set_seed``; nothing is logged.

    Args:
        _: ignored first positional argument.
        seed: seed value.
    """
    apply_seed = ext.set_seed
    apply_seed(seed)
def _graph_setup(self):
    """Build the policy's TensorFlow graph and return handles to its key tensors/ops.

    The graph contains:

    * input placeholders;
    * a training-mode and a test-mode value head under the ``'policy'`` scope;
    * action selection and exploration ops;
    * an optional target-network evaluation plus copy ops;
    * the cost and the optimizer.

    The default TF session is reused if one exists; otherwise a session/graph
    is created via ``MACPolicy.create_session_and_graph``. Variables are
    initialized through ``self._graph_init_vars``.

    Returns:
        dict mapping names ('sess', 'graph', the placeholders, 'get_action',
        'cost', 'opt', 'policy_vars', 'target_vars', ...) to the corresponding
        TF objects. 'update_target_fn' and 'target_vars' are None unless a
        separate target network is used.
    """
    ### create session and graph
    tf_sess = tf.get_default_session()
    if tf_sess is None:
        tf_sess, tf_graph = MACPolicy.create_session_and_graph(
            gpu_device=self._gpu_device, gpu_frac=self._gpu_frac)
    tf_graph = tf_sess.graph

    with tf_sess.as_default(), tf_graph.as_default():
        # Re-apply the already-chosen seed now that this graph is the default one.
        # Presumably so that ext.set_seed also sets the TF graph-level seed for
        # this graph -- confirm against ext.set_seed.
        if ext.get_seed() is not None:
            ext.set_seed(ext.get_seed())

        ### create input output placeholders
        tf_obs_ph, tf_actions_ph, tf_dones_ph, tf_rewards_ph, tf_obs_target_ph, \
            tf_test_es_ph_dict, tf_episode_timesteps_ph = self._graph_input_output_placeholders()
        self.global_step = tf.Variable(0, trainable=False, name='global_step')

        ### policy
        policy_scope = 'policy'
        with tf.variable_scope(policy_scope):
            ### create preprocess placeholders
            tf_preprocess = self._graph_preprocess_placeholders()
            ### process obs to lowd
            tf_obs_lowd = self._graph_obs_to_lowd(tf_obs_ph, tf_preprocess, is_training=True)
            ### create training policy
            # Only the first self._H entries along axis 1 of the actions are used.
            tf_train_values, tf_train_values_softmax, _, _ = \
                self._graph_inference(tf_obs_lowd, tf_actions_ph[:, :self._H, :],
                                      self._values_softmax, tf_preprocess, is_training=True)

        # Same variables (reuse=True), evaluated in test mode over the test horizon.
        with tf.variable_scope(policy_scope, reuse=True):
            tf_train_values_test, tf_train_values_softmax_test, _, _ = \
                self._graph_inference(tf_obs_lowd, tf_actions_ph[:, :self._get_action_test['H'], :],
                                      self._values_softmax, tf_preprocess, is_training=False)
            # Softmax-weighted sum of the per-step values along axis 1.
            tf_get_value = tf.reduce_sum(tf_train_values_softmax_test * tf_train_values_test, reduction_indices=1)

        ### action selection
        tf_get_action, tf_get_action_value, tf_get_action_reset_ops = \
            self._graph_get_action(tf_obs_ph, self._get_action_test,
                                   policy_scope, True, policy_scope, True,
                                   tf_episode_timesteps_ph)
        ### exploration strategy and logprob
        tf_get_action_explore = self._graph_get_action_explore(
            tf_get_action, tf_test_es_ph_dict)

        ### get policy variables
        # Sorted by name so they can be zipped one-to-one with the target variables.
        tf_policy_vars = sorted(tf.get_collection(
            xplatform.global_variables_collection_name(), scope=policy_scope), key=lambda v: v.name)
        tf_trainable_policy_vars = sorted(tf.get_collection(
            xplatform.trainable_variables_collection_name(), scope=policy_scope), key=lambda v: v.name)

        ### create target network
        if self._use_target:
            # Without separate target params the "target" is the policy itself.
            target_scope = 'target' if self._separate_target_params else 'policy'
            ### action selection
            # N+1 sliding windows of obs_history_len target observations, stacked
            # along axis 0 so that one _graph_get_action call evaluates all of them.
            tf_obs_target_ph_packed = xplatform.concat([
                tf_obs_target_ph[:, h - self._obs_history_len:h, :]
                for h in range(self._obs_history_len, self._obs_history_len + self._N + 1)
            ], 0)
            tf_target_get_action, tf_target_get_action_values, _ = self._graph_get_action(
                tf_obs_target_ph_packed, self._get_action_target,
                scope_select=policy_scope, reuse_select=True,
                scope_eval=target_scope, reuse_eval=(target_scope == policy_scope),
                tf_episode_timesteps_ph=None)  # TODO would need to fill in
            # Unstack to one column per window, then drop column 0 (the first
            # window), leaving self._N target values per sample.
            tf_target_get_action_values = tf.transpose(
                tf.reshape(tf_target_get_action_values, (self._N + 1, -1)))[:, 1:]
        else:
            tf_target_get_action_values = tf.zeros(
                [tf.shape(tf_train_values)[0], self._N])

        ### update target network
        if self._use_target and self._separate_target_params:
            # Drop variables whose names contain 'biased' or 'local_step'. Presumably
            # these are moving-average bookkeeping variables with no counterpart in
            # the target scope (the local name suggests batch-norm) -- confirm.
            tf_policy_vars_nobatchnorm = list(
                filter(
                    lambda v: 'biased' not in v.name and 'local_step' not in v.name,
                    tf_policy_vars))
            tf_target_vars = sorted(tf.get_collection(
                xplatform.global_variables_collection_name(), scope=target_scope), key=lambda v: v.name)
            assert (len(tf_policy_vars_nobatchnorm) == len(tf_target_vars))
            tf_update_target_fn = []
            for var, var_target in zip(tf_policy_vars_nobatchnorm, tf_target_vars):
                # Sanity check: paired variables differ only in their scope prefix.
                assert (var.name.replace(policy_scope, '') == var_target.name.replace(
                    target_scope, ''))
                tf_update_target_fn.append(var_target.assign(var))
            # Single op that copies every policy variable into its target twin.
            tf_update_target_fn = tf.group(*tf_update_target_fn)
        else:
            tf_target_vars = None
            tf_update_target_fn = None

        ### optimization
        tf_cost, tf_mse = self._graph_cost(tf_train_values, tf_train_values_softmax,
                                           tf_rewards_ph, tf_dones_ph,
                                           tf_target_get_action_values)
        tf_opt, tf_lr_ph = self._graph_optimize(tf_cost, tf_trainable_policy_vars)

        ### initialize
        self._graph_init_vars(tf_sess)

    ### what to return
    return {
        'sess': tf_sess,
        'graph': tf_graph,
        'obs_ph': tf_obs_ph,
        'actions_ph': tf_actions_ph,
        'dones_ph': tf_dones_ph,
        'rewards_ph': tf_rewards_ph,
        'obs_target_ph': tf_obs_target_ph,
        'test_es_ph_dict': tf_test_es_ph_dict,
        'episode_timesteps_ph': tf_episode_timesteps_ph,
        'preprocess': tf_preprocess,
        'get_value': tf_get_value,
        'get_action': tf_get_action,
        'get_action_explore': tf_get_action_explore,
        'get_action_value': tf_get_action_value,
        'get_action_reset_ops': tf_get_action_reset_ops,
        'update_target_fn': tf_update_target_fn,
        'cost': tf_cost,
        'mse': tf_mse,
        'opt': tf_opt,
        'lr_ph': tf_lr_ph,
        'policy_vars': tf_policy_vars,
        'target_vars': tf_target_vars
    }
def main():
    """Train a TRPO policy on the ``PursuitEvade`` environment from command-line options.

    Flow:

    1. Parse the options and start the parallel sampler (seeded if ``--seed``
       is given).
    2. Load ``policy`` and ``env`` from ``--checkpoint``, or build a new env
       and a categorical MLP/GRU/LSTM policy.
    3. Configure the rllab logger under ``--log_dir`` (default
       ``config.LOG_DIR/<exp_name>``) and run ``TRPO.train()``.
    """
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name', type=str, default=default_exp_name,
                        help='Name of the experiment.')

    parser.add_argument('--discount', type=float, default=0.99)
    parser.add_argument('--gae_lambda', type=float, default=1.0)
    parser.add_argument('--reward_scale', type=float, default=1.0)

    parser.add_argument('--n_iter', type=int, default=250)
    parser.add_argument('--sampler_workers', type=int, default=1)
    parser.add_argument('--max_traj_len', type=int, default=250)
    parser.add_argument('--update_curriculum', action='store_true', default=False)
    parser.add_argument('--n_timesteps', type=int, default=8000)
    parser.add_argument('--control', type=str, default='centralized')

    parser.add_argument('--rectangle', type=str, default='10,10')
    parser.add_argument('--map_type', type=str, default='rectangle')
    parser.add_argument('--n_evaders', type=int, default=5)
    parser.add_argument('--n_pursuers', type=int, default=2)
    parser.add_argument('--obs_range', type=int, default=3)
    parser.add_argument('--n_catch', type=int, default=2)
    parser.add_argument('--urgency', type=float, default=0.0)
    parser.add_argument('--pursuit', dest='train_pursuit', action='store_true')
    parser.add_argument('--evade', dest='train_pursuit', action='store_false')
    parser.set_defaults(train_pursuit=True)
    parser.add_argument('--surround', action='store_true', default=False)
    parser.add_argument('--constraint_window', type=float, default=1.0)
    parser.add_argument('--sample_maps', action='store_true', default=False)
    parser.add_argument('--map_file', type=str, default='../maps/map_pool.npy')
    parser.add_argument('--flatten', action='store_true', default=False)
    parser.add_argument('--reward_mech', type=str, default='global')
    parser.add_argument('--catchr', type=float, default=0.1)
    parser.add_argument('--term_pursuit', type=float, default=5.0)

    parser.add_argument('--recurrent', type=str, default=None)
    parser.add_argument('--policy_hidden_sizes', type=str, default='128,128')
    # The misspelled flag is kept for existing callers (it also fixes dest=
    # 'baselin_hidden_sizes'); the correctly spelled one is accepted as an alias.
    parser.add_argument('--baselin_hidden_sizes', '--baseline_hidden_sizes',
                        type=str, default='128,128')
    parser.add_argument('--baseline_type', type=str, default='linear')
    parser.add_argument('--conv', action='store_true', default=False)

    parser.add_argument('--max_kl', type=float, default=0.01)

    parser.add_argument('--checkpoint', type=str, default=None)

    parser.add_argument('--log_dir', type=str, required=False)
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data', type=str, help='Pickled data for stub objects')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                             '(all iterations will be saved), "last" (only '
                             'the last iteration will be saved), or "none" '
                             '(do not save snapshots)')
    parser.add_argument(
        '--log_tabular_only', type=ast.literal_eval, default=False,
        help='Whether to only print the tabular log information (in a horizontal format)')

    args = parser.parse_args()

    parallel_sampler.initialize(n_parallel=args.sampler_workers)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    args.hidden_sizes = tuple(map(int, args.policy_hidden_sizes.split(',')))

    if args.checkpoint:
        with tf.Session() as sess:
            data = joblib.load(args.checkpoint)
            policy = data['policy']
            env = data['env']
    else:
        if args.sample_maps:
            map_pool = np.load(args.map_file)
        else:
            if args.map_type == 'rectangle':
                env_map = TwoDMaps.rectangle_map(*map(int, args.rectangle.split(',')))
            elif args.map_type == 'complex':
                env_map = TwoDMaps.complex_map(*map(int, args.rectangle.split(',')))
            else:
                raise NotImplementedError()
            map_pool = [env_map]

        env = PursuitEvade(map_pool, n_evaders=args.n_evaders, n_pursuers=args.n_pursuers,
                           obs_range=args.obs_range, n_catch=args.n_catch,
                           train_pursuit=args.train_pursuit, urgency_reward=args.urgency,
                           surround=args.surround, sample_maps=args.sample_maps,
                           constraint_window=args.constraint_window, flatten=args.flatten,
                           reward_mech=args.reward_mech, catchr=args.catchr,
                           term_pursuit=args.term_pursuit)

        env = TfEnv(
            RLLabEnv(
                StandardizedEnv(env, scale_reward=args.reward_scale, enable_obsnorm=False),
                mode=args.control))

        if args.recurrent:
            if args.conv:
                feature_network = ConvNetwork(
                    name='feature_net',
                    # Was ``emv.spec`` (NameError whenever --recurrent and --conv were set).
                    input_shape=env.spec.observation_space.shape,
                    output_dim=5,
                    conv_filters=(16, 32, 32),
                    conv_filter_sizes=(3, 3, 3),
                    conv_strides=(1, 1, 1),
                    conv_pads=('VALID', 'VALID', 'VALID'),
                    hidden_sizes=(64, ),
                    hidden_nonlinearity=tf.nn.relu,
                    output_nonlinearity=tf.nn.softmax)
            else:
                feature_network = MLP(
                    name='feature_net',
                    input_shape=(env.spec.observation_space.flat_dim +
                                 env.spec.action_space.flat_dim, ),
                    output_dim=5,
                    hidden_sizes=(256, 128, 64),
                    hidden_nonlinearity=tf.nn.tanh,
                    output_nonlinearity=None)
            # NOTE(review): int() requires --policy_hidden_sizes to be a single
            # integer for recurrent policies; the default '128,128' raises ValueError.
            if args.recurrent == 'gru':
                policy = CategoricalGRUPolicy(env_spec=env.spec,
                                              feature_network=feature_network,
                                              hidden_dim=int(args.policy_hidden_sizes),
                                              name='policy')
            elif args.recurrent == 'lstm':
                policy = CategoricalLSTMPolicy(env_spec=env.spec,
                                               feature_network=feature_network,
                                               hidden_dim=int(args.policy_hidden_sizes),
                                               name='policy')
            else:
                # Previously fell through with ``policy`` unbound (NameError later).
                raise ValueError("Unknown --recurrent value: {!r} (expected 'gru' or 'lstm')"
                                 .format(args.recurrent))
        elif args.conv:
            feature_network = ConvNetwork(
                name='feature_net',
                input_shape=env.spec.observation_space.shape,
                output_dim=5,
                conv_filters=(8, 16),
                conv_filter_sizes=(3, 3),
                conv_strides=(2, 1),
                conv_pads=('VALID', 'VALID'),
                hidden_sizes=(32, ),
                hidden_nonlinearity=tf.nn.relu,
                output_nonlinearity=tf.nn.softmax)
            policy = CategoricalMLPPolicy(name='policy', env_spec=env.spec,
                                          prob_network=feature_network)
        else:
            policy = CategoricalMLPPolicy(name='policy', env_spec=env.spec,
                                          hidden_sizes=args.hidden_sizes)

    if args.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(env_spec=env.spec)
    else:
        baseline = ZeroBaseline(env_spec=env.spec)

    # logger
    default_log_dir = config.LOG_DIR
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    algo = TRPO(
        env=env,
        policy=policy,
        baseline=baseline,
        batch_size=args.n_timesteps,
        max_path_length=args.max_traj_len,
        n_itr=args.n_iter,
        discount=args.discount,
        gae_lambda=args.gae_lambda,
        step_size=args.max_kl,
        # Recurrent policies use a finite-difference Hessian-vector product.
        optimizer=ConjugateGradientOptimizer(hvp_approach=FiniteDifferenceHvp(
            base_eps=1e-5)) if args.recurrent else None,
        mode=args.control,
    )

    algo.train()
def run_experiment(argv):
    """Parse command-line options, configure seeding/sampling/logging, and run one experiment.

    The experiment comes from one of three sources:

    * ``--resume_from``: a joblib snapshot whose ``'algo'`` is trained again;
    * ``--args_data`` with ``--use_cloudpickle True``: a cloudpickled callable,
      invoked with the decoded ``--variant_data``;
    * ``--args_data`` otherwise: a pickled stub object passed to ``concretize``.

    An iterable returned by the training call is drained.

    Args:
        argv: full argument vector; ``argv[0]`` (the program name) is skipped.
    """
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_parallel', type=int, default=1,
                        help='Number of parallel workers to perform rollouts. 0 => don\'t start any workers')
    parser.add_argument(
        '--exp_name', type=str, default=default_exp_name, help='Name of the experiment.')
    parser.add_argument('--log_dir', type=str, default=None,
                        help='Path to save the log and iteration snapshot.')
    # Help text previously read "(every`snapshot_gap`" -- a missing space between
    # the concatenated literals.
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                             '(all iterations will be saved), "last" (only '
                             'the last iteration will be saved), "gap" (every '
                             '`snapshot_gap` iterations are saved), or "none" '
                             '(do not save snapshots)')
    parser.add_argument('--snapshot_gap', type=int, default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file', type=str, default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument('--resume_from', type=str, default=None,
                        help='Name of the pickle file to resume experiment from.')
    parser.add_argument('--plot', type=ast.literal_eval, default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False,
                        help='Whether to only print the tabular log information (in a horizontal format)')
    parser.add_argument('--seed', type=int,
                        help='Random seed for numpy')
    parser.add_argument('--args_data', type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--variant_data', type=str,
                        help='Pickled data for variant configuration')
    parser.add_argument('--use_cloudpickle', type=ast.literal_eval, default=False)

    args = parser.parse_args(argv[1:])

    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    if args.variant_data is not None:
        # SECURITY: pickle.loads can execute arbitrary code -- only pass trusted data.
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        logger.log_parameters_lite(params_log_file, args)

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        if args.resume_from is not None:
            data = joblib.load(args.resume_from)
            assert 'algo' in data
            algo = data['algo']
            # Drain a generator-style train() as the stub path below does;
            # previously such a train() was created but never run.
            maybe_iter = algo.train()
            if is_iterable(maybe_iter):
                for _ in maybe_iter:
                    pass
        else:
            # read from stdin
            if args.use_cloudpickle:
                import cloudpickle
                # SECURITY: unpickling runs arbitrary code -- trusted launchers only.
                method_call = cloudpickle.loads(base64.b64decode(args.args_data))
                method_call(variant_data)
            else:
                data = pickle.loads(base64.b64decode(args.args_data))
                maybe_iter = concretize(data)
                if is_iterable(maybe_iter):
                    for _ in maybe_iter:
                        pass
    finally:
        # Restore the process-global logger even if the experiment raises.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
def main():
    """Train a TRPO policy on the multi-agent WaterWorld environment.

    Parses command-line hyper-parameters, builds ``MAWaterWorld`` wrapped in
    ``StandardizedEnv`` / ``RLLabEnv`` / ``TfEnv`` (plus ``ObservationBuffer``
    when ``--buffer_size > 1``), a Gaussian MLP/GRU/LSTM policy and a
    baseline, configures the global ``logger`` under ``--log_dir`` (or
    ``config.LOG_DIR/<exp_name>``) and runs ``TRPO.train()``.

    Raises:
        ValueError: if ``--sensor_range`` has neither 1 nor ``n_pursuers``
            comma-separated values.
        NotImplementedError: for ``--baseline_type mlp`` or an unsupported
            ``--recurrent`` cell type (only 'gru' and 'lstm' are handled).
    """
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name', type=str, default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--discount', type=float, default=0.95)
    parser.add_argument('--gae_lambda', type=float, default=0.99)
    parser.add_argument('--reward_scale', type=float, default=1.0)
    parser.add_argument('--enable_obsnorm', action='store_true', default=False)
    parser.add_argument('--chunked', action='store_true', default=False)
    parser.add_argument('--n_iter', type=int, default=250)
    parser.add_argument('--sampler_workers', type=int, default=1)
    parser.add_argument('--max_traj_len', type=int, default=250)
    parser.add_argument('--update_curriculum', action='store_true', default=False)
    parser.add_argument('--anneal_step_size', type=int, default=0)
    parser.add_argument('--n_timesteps', type=int, default=8000)
    parser.add_argument('--control', type=str, default='centralized')
    parser.add_argument('--buffer_size', type=int, default=1)
    parser.add_argument('--radius', type=float, default=0.015)
    parser.add_argument('--n_evaders', type=int, default=10)
    parser.add_argument('--n_pursuers', type=int, default=8)
    parser.add_argument('--n_poison', type=int, default=10)
    parser.add_argument('--n_coop', type=int, default=4)
    parser.add_argument('--n_sensors', type=int, default=30)
    parser.add_argument('--sensor_range', type=str, default='0.2')
    parser.add_argument('--food_reward', type=float, default=5)
    parser.add_argument('--poison_reward', type=float, default=-1)
    parser.add_argument('--encounter_reward', type=float, default=0.05)
    parser.add_argument('--reward_mech', type=str, default='local')
    parser.add_argument('--recurrent', type=str, default=None)
    parser.add_argument('--baseline_type', type=str, default='linear')
    parser.add_argument('--policy_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--baseline_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--max_kl', type=float, default=0.01)
    parser.add_argument('--log_dir', type=str, required=False)
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data', type=str, help='Pickled data for stub objects')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument(
        '--log_tabular_only', type=ast.literal_eval, default=False,
        help='Whether to only print the tabular log information (in a horizontal format)')
    args = parser.parse_args()

    parallel_sampler.initialize(n_parallel=args.sampler_workers)
    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    # Stored on ``args`` so it is recorded by ``logger.log_parameters_lite``.
    args.hidden_sizes = tuple(map(int, args.policy_hidden_sizes.split(',')))

    # Materialize the values first: on Python 3 ``map`` is lazy, so
    # ``np.array(map(...))`` is a 0-d object array and ``len()`` on it fails.
    sensor_range = np.array([float(r) for r in args.sensor_range.split(',')])
    if len(sensor_range) == 1:
        sensor_range = sensor_range[0]
    elif sensor_range.shape != (args.n_pursuers, ):
        raise ValueError('--sensor_range needs 1 or {} comma-separated values, got {}'.format(
            args.n_pursuers, len(sensor_range)))

    env = MAWaterWorld(args.n_pursuers, args.n_evaders, args.n_coop, args.n_poison,
                       radius=args.radius, n_sensors=args.n_sensors,
                       food_reward=args.food_reward, poison_reward=args.poison_reward,
                       encounter_reward=args.encounter_reward,
                       reward_mech=args.reward_mech, sensor_range=sensor_range,
                       obstacle_loc=None)
    env = TfEnv(
        RLLabEnv(StandardizedEnv(env, scale_reward=args.reward_scale,
                                 enable_obsnorm=args.enable_obsnorm),
                 mode=args.control))
    if args.buffer_size > 1:
        env = ObservationBuffer(env, args.buffer_size)

    if args.recurrent:
        # The feature network consumes the flattened observation and action.
        feature_network = MLP(
            name='feature_net',
            input_shape=(env.spec.observation_space.flat_dim +
                         env.spec.action_space.flat_dim, ),
            output_dim=16, hidden_sizes=(128, 64, 32),
            hidden_nonlinearity=tf.nn.tanh, output_nonlinearity=None)
        # Recurrent policies take a single hidden size, so
        # --policy_hidden_sizes must be one integer (e.g. "128") here.
        if args.recurrent == 'gru':
            policy = GaussianGRUPolicy(env_spec=env.spec, feature_network=feature_network,
                                       hidden_dim=int(args.policy_hidden_sizes),
                                       name='policy')
        elif args.recurrent == 'lstm':
            policy = GaussianLSTMPolicy(env_spec=env.spec, feature_network=feature_network,
                                        hidden_dim=int(args.policy_hidden_sizes),
                                        name='policy')
        else:
            raise NotImplementedError(args.recurrent)
    else:
        policy = GaussianMLPPolicy(
            name='policy', env_spec=env.spec,
            hidden_sizes=tuple(map(int, args.policy_hidden_sizes.split(','))),
            min_std=10e-5)

    if args.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(env_spec=env.spec)
    elif args.baseline_type == 'mlp':
        # An MLP baseline (using --baseline_hidden_sizes) is not wired up yet.
        raise NotImplementedError()
    else:
        baseline = ZeroBaseline(env_spec=env.spec)

    # logger
    default_log_dir = config.LOG_DIR
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    algo = TRPO(
        env=env,
        policy=policy,
        baseline=baseline,
        batch_size=args.n_timesteps,
        max_path_length=args.max_traj_len,
        update_max_path_length=args.update_curriculum,
        anneal_step_size=args.anneal_step_size,
        n_itr=args.n_iter,
        discount=args.discount,
        gae_lambda=args.gae_lambda,
        step_size=args.max_kl,
        # Recurrent policies use a finite-difference Hessian-vector product.
        optimizer=ConjugateGradientOptimizer(hvp_approach=FiniteDifferenceHvp(
            base_eps=1e-5)) if args.recurrent else None,
        mode=args.control if not args.chunked else 'chunk_{}'.format(args.control),
    )
    algo.train()
def run_experiment(argv):
    """Launch an rllab experiment described by command-line flags.

    Parses ``argv[1:]``, seeds the RNGs, optionally starts parallel sampler
    workers and the plotter, and configures the global ``logger``: text and
    tabular outputs, tensorboard/checkpoint/obs dirs, snapshot policy and
    prefix. Everything is placed under ``--log_dir`` (or
    ``config.LOG_DIR/<exp_name>``). It then logs the git commit, writes the
    git diff, and logs host info. Finally it either resumes ``data['algo']``
    from a joblib file (``--resume_from``) or runs the base64-encoded
    ``--args_data`` payload.

    The logger's snapshot mode/dir, text/tabular outputs and prefix are
    restored when the run finishes *or raises*.

    Args:
        argv: full argument vector; ``argv[0]`` is skipped (e.g. ``sys.argv``).
    """
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--n_parallel', type=int, default=1,
        help='Number of parallel workers to perform rollouts. 0 => don\'t start any workers')
    parser.add_argument('--exp_name', type=str, default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--log_dir', type=str, default=None,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), "gap" (every '
                        '`snapshot_gap` iterations are saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument('--snapshot_gap', type=int, default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--tensorboard_log_dir', type=str, default='tb',
                        help='Name of the folder for tensorboard_summary.')
    parser.add_argument(
        '--tensorboard_step_key', type=str, default=None,
        help='Name of the step key in log data which shows the step in tensorboard_summary.')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file', type=str, default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument('--resume_from', type=str, default=None,
                        help='Name of the pickle file to resume experiment from.')
    parser.add_argument('--plot', type=ast.literal_eval, default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only', type=ast.literal_eval, default=False,
        help='Whether to only print the tabular log information (in a horizontal format)')
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data', type=str, help='Pickled data for stub objects')
    parser.add_argument('--variant_data', type=str,
                        help='Pickled data for variant configuration')
    parser.add_argument('--use_cloudpickle', type=ast.literal_eval, default=False)
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
                        help='Name of the folder for checkpoints.')
    parser.add_argument('--obs_dir', type=str, default='obs',
                        help='Name of the folder for original observations.')
    args = parser.parse_args(argv[1:])

    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)
    tensorboard_log_dir = osp.join(log_dir, args.tensorboard_log_dir)
    checkpoint_dir = osp.join(log_dir, args.checkpoint_dir)
    obs_dir = osp.join(log_dir, args.obs_dir)

    # SECURITY: --variant_data / --args_data are unpickled, which can run
    # arbitrary code; only pass payloads produced by a trusted launcher.
    if args.variant_data is not None:
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        logger.log_parameters_lite(params_log_file, args)

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    logger.set_tensorboard_dir(tensorboard_log_dir)
    logger.set_checkpoint_dir(checkpoint_dir)
    logger.set_obs_dir(obs_dir)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.set_tensorboard_step_key(args.tensorboard_step_key)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        git_commit = get_git_commit_hash()
        logger.log('Git commit: {}'.format(git_commit))
        git_diff_file_path = osp.join(log_dir, 'git_diff_{}.patch'.format(git_commit))
        save_git_diff_to_file(git_diff_file_path)
        logger.log('hostname: {}, pid: {}, tmux session: {}'.format(
            socket.gethostname(), os.getpid(), get_tmux_session_name()))

        if args.resume_from is not None:
            data = joblib.load(args.resume_from)
            assert 'algo' in data
            algo = data['algo']
            algo.train()
        elif args.use_cloudpickle:
            # read from stdin
            import cloudpickle
            method_call = cloudpickle.loads(base64.b64decode(args.args_data))
            method_call(variant_data)
        else:
            data = pickle.loads(base64.b64decode(args.args_data))
            maybe_iter = concretize(data)
            # Drain generator-style experiments so every step actually runs.
            if is_iterable(maybe_iter):
                for _ in maybe_iter:
                    pass
    finally:
        # Undo the global logger configuration even when training raises, so
        # a process running several experiments is not left with stale state.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
if pickled_mode: run_experiment_lite( algo.train_seek(), exp_prefix="trpo-expl", n_parallel=2, snapshot_mode="last", seed=seed, mode="local", script="rllab/run_experiment_lite.py", ) else: from sandbox.vime.sampler import parallel_sampler_expl as parallel_sampler parallel_sampler.initialize(n_parallel=n_parallel) if seed is not None: set_seed(seed) parallel_sampler.set_seed(seed) if plot: from rllab.plotter import plotter plotter.init_worker() tabular_log_file_fullpath = osp.join(log_dir, tabular_log_file) text_log_file_fullpath = osp.join(log_dir, text_log_file) # params_log_file_fullpath = osp.join(log_dir, params_log_file) params_all_log_file_fullpath = osp.join(log_dir, params_all_log_file) # logger.log_parameters_lite(params_log_file, args) logger.add_text_output(text_log_file_fullpath) logger.add_tabular_output(tabular_log_file_fullpath) prev_snapshot_dir = logger.get_snapshot_dir()
def __init__(self, env, args):
    """Set up sampling, environment/policy, baseline, logger and algorithm.

    Args:
        env: environment handed to ``rllab_envpolicy_parser`` together with
            ``args``; the env that call returns is the one used for training.
        args: parsed option namespace. Reads ``n_parallel``, ``seed``,
            ``algo`` ('tftrpo' or 'thddpg'), logging options
            (``log_dir``, ``exp_name``, ``*_log_file``, ``snapshot_mode``,
            ``log_tabular_only``) and the chosen algorithm's hyper-parameters.

    Sets ``self.args`` and ``self.algo``.

    Raises:
        NotImplementedError: for an unknown ``baseline_type`` (TRPO only),
            ``exp_strategy`` (DDPG only) or ``algo``.
    """
    self.args = args

    # Parallel setup
    parallel_sampler.initialize(n_parallel=args.n_parallel)
    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    env, policy = rllab_envpolicy_parser(env, args)

    if args.algo != 'thddpg':
        # Baseline (the DDPG branch below builds a Q-function instead)
        if args.baseline_type == 'linear':
            baseline = LinearFeatureBaseline(env_spec=env.spec)
        elif args.baseline_type == 'zero':
            baseline = ZeroBaseline(env_spec=env.spec)
        else:
            raise NotImplementedError(args.baseline_type)

    # Logger
    default_log_dir = config.LOG_DIR
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    if args.algo == 'tftrpo':
        self.algo = TRPO(
            env=env,
            policy=policy,
            baseline=baseline,
            batch_size=args.batch_size,
            max_path_length=args.max_path_length,
            n_itr=args.n_iter,
            discount=args.discount,
            gae_lambda=args.gae_lambda,
            step_size=args.step_size,
            # Recurrent policies use a finite-difference Hessian-vector product.
            optimizer=ConjugateGradientOptimizer(
                hvp_approach=FiniteDifferenceHvp(base_eps=1e-5)) if args.recurrent else None,
            mode=args.control)
    elif args.algo == 'thddpg':
        qfunc = thContinuousMLPQFunction(env_spec=env.spec)
        if args.exp_strategy == 'ou':
            es = OUStrategy(env_spec=env.spec)
        elif args.exp_strategy == 'gauss':
            es = GaussianStrategy(env_spec=env.spec)
        else:
            raise NotImplementedError(args.exp_strategy)
        self.algo = thDDPG(
            env=env,
            policy=policy,
            qf=qfunc,
            es=es,
            batch_size=args.batch_size,
            max_path_length=args.max_path_length,
            epoch_length=args.epoch_length,
            min_pool_size=args.min_pool_size,
            replay_pool_size=args.replay_pool_size,
            n_epochs=args.n_iter,
            discount=args.discount,
            scale_reward=0.01,
            qf_learning_rate=args.qfunc_lr,
            policy_learning_rate=args.policy_lr,
            eval_samples=args.eval_samples,
            mode=args.control)
    else:
        # Fail here rather than later with an AttributeError on ``self.algo``.
        raise NotImplementedError(args.algo)
def main():
    """Train a TRPO policy on the PursuitEvade grid-world environment.

    Parses command-line hyper-parameters and builds the map pool, either
    loaded from ``--map_file`` or a single rectangle/complex map from
    ``--rectangle``. It then creates a ``PursuitEvade`` env wrapped in
    ``StandardizedEnv`` / ``RLLabEnv``, a categorical MLP/conv/GRU policy and
    a baseline. Finally it configures the global ``logger`` under
    ``--log_dir`` (or ``config.LOG_DIR/<exp_name>``) and runs
    ``TRPO.train()``.

    Raises:
        NotImplementedError: for an unknown ``--map_type`` or an unsupported
            ``--recurrent`` cell type (only 'gru' is handled).
    """
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name', type=str, default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--discount', type=float, default=0.99)
    parser.add_argument('--gae_lambda', type=float, default=1.0)
    parser.add_argument('--reward_scale', type=float, default=1.0)
    parser.add_argument('--n_iter', type=int, default=250)
    parser.add_argument('--sampler_workers', type=int, default=1)
    parser.add_argument('--max_traj_len', type=int, default=250)
    parser.add_argument('--update_curriculum', action='store_true', default=False)
    parser.add_argument('--n_timesteps', type=int, default=8000)
    parser.add_argument('--control', type=str, default='centralized')
    parser.add_argument('--rectangle', type=str, default='10,10')
    parser.add_argument('--map_type', type=str, default='rectangle')
    parser.add_argument('--n_evaders', type=int, default=5)
    parser.add_argument('--n_pursuers', type=int, default=2)
    parser.add_argument('--obs_range', type=int, default=3)
    parser.add_argument('--n_catch', type=int, default=2)
    parser.add_argument('--urgency', type=float, default=0.0)
    # --pursuit / --evade toggle the same flag; pursuit training is the default.
    parser.add_argument('--pursuit', dest='train_pursuit', action='store_true')
    parser.add_argument('--evade', dest='train_pursuit', action='store_false')
    parser.set_defaults(train_pursuit=True)
    parser.add_argument('--surround', action='store_true', default=False)
    parser.add_argument('--constraint_window', type=float, default=1.0)
    parser.add_argument('--sample_maps', action='store_true', default=False)
    parser.add_argument('--map_file', type=str, default='../maps/map_pool.npy')
    parser.add_argument('--flatten', action='store_true', default=False)
    parser.add_argument('--reward_mech', type=str, default='global')
    parser.add_argument('--catchr', type=float, default=0.1)
    parser.add_argument('--term_pursuit', type=float, default=5.0)
    parser.add_argument('--recurrent', type=str, default=None)
    parser.add_argument('--policy_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--baselin_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--baseline_type', type=str, default='linear')
    parser.add_argument('--conv', action='store_true', default=False)
    parser.add_argument('--max_kl', type=float, default=0.01)
    parser.add_argument('--log_dir', type=str, required=False)
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data', type=str, help='Pickled data for stub objects')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument(
        '--log_tabular_only', type=ast.literal_eval, default=False,
        help='Whether to only print the tabular log information (in a horizontal format)')
    args = parser.parse_args()

    parallel_sampler.initialize(n_parallel=args.sampler_workers)
    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    args.hidden_sizes = tuple(map(int, args.policy_hidden_sizes.split(',')))

    if args.sample_maps:
        map_pool = np.load(args.map_file)
    else:
        if args.map_type == 'rectangle':
            env_map = TwoDMaps.rectangle_map(*map(int, args.rectangle.split(',')))
        elif args.map_type == 'complex':
            env_map = TwoDMaps.complex_map(*map(int, args.rectangle.split(',')))
        else:
            raise NotImplementedError()
        map_pool = [env_map]

    env = PursuitEvade(map_pool, n_evaders=args.n_evaders, n_pursuers=args.n_pursuers,
                       obs_range=args.obs_range, n_catch=args.n_catch,
                       train_pursuit=args.train_pursuit, urgency_reward=args.urgency,
                       surround=args.surround, sample_maps=args.sample_maps,
                       constraint_window=args.constraint_window, flatten=args.flatten,
                       reward_mech=args.reward_mech, catchr=args.catchr,
                       term_pursuit=args.term_pursuit)
    env = RLLabEnv(
        StandardizedEnv(env, scale_reward=args.reward_scale, enable_obsnorm=False),
        mode=args.control)

    if args.recurrent:
        if args.conv:
            # Pads use the same lowercase 'valid' spelling as the non-recurrent
            # conv branch below (Lasagne rejects 'VALID').
            feature_network = ConvNetwork(
                input_shape=env.spec.observation_space.shape,
                output_dim=5,
                conv_filters=(8, 16, 16), conv_filter_sizes=(3, 3, 3),
                conv_strides=(1, 1, 1), conv_pads=('valid', 'valid', 'valid'),
                hidden_sizes=(64,), hidden_nonlinearity=NL.rectify,
                output_nonlinearity=NL.softmax)
        else:
            feature_network = MLP(
                input_shape=(env.spec.observation_space.flat_dim +
                             env.spec.action_space.flat_dim,),
                output_dim=5, hidden_sizes=(128, 128, 128),
                hidden_nonlinearity=NL.tanh, output_nonlinearity=None)
        if args.recurrent == 'gru':
            # GRU takes a single hidden size: --policy_hidden_sizes must be one integer.
            policy = CategoricalGRUPolicy(env_spec=env.spec,
                                          feature_network=feature_network,
                                          hidden_dim=int(args.policy_hidden_sizes))
        else:
            raise NotImplementedError(args.recurrent)
    elif args.conv:
        feature_network = ConvNetwork(
            input_shape=env.spec.observation_space.shape,
            output_dim=5,
            conv_filters=(8, 16, 16), conv_filter_sizes=(3, 3, 3),
            conv_strides=(1, 1, 1), conv_pads=('valid', 'valid', 'valid'),
            hidden_sizes=(64,), hidden_nonlinearity=NL.rectify,
            output_nonlinearity=NL.softmax)
        policy = CategoricalMLPPolicy(env_spec=env.spec, prob_network=feature_network)
    else:
        policy = CategoricalMLPPolicy(env_spec=env.spec, hidden_sizes=args.hidden_sizes)

    if args.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(env_spec=env.spec)
    else:
        # Built from the env spec, like the linear baseline.
        baseline = ZeroBaseline(env_spec=env.spec)

    # logger
    default_log_dir = config.LOG_DIR
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    algo = TRPO(
        env=env,
        policy=policy,
        baseline=baseline,
        batch_size=args.n_timesteps,
        max_path_length=args.max_traj_len,
        n_itr=args.n_iter,
        discount=args.discount,
        gae_lambda=args.gae_lambda,
        step_size=args.max_kl,
        mode=args.control,
    )
    algo.train()
# Script fragment: sets up a Gym environment, a deterministic MLP policy and an
# OUStrategy exploration strategy (DDPG-style; the algorithm itself is not
# constructed in this part of the script).
from sandbox.rocky.tf.policies.deterministic_mlp_policy import \
    DeterministicMLPPolicy
from sandbox.rocky.tf.q_functions.continuous_mlp_q_function import \
    ContinuousMLPQFunction

parser = argparse.ArgumentParser()
parser.add_argument("env", help="The environment name from OpenAIGym environments")
parser.add_argument("--num_epochs", default=250, type=int)
parser.add_argument("--data_dir", default="./data_ddpg/")
parser.add_argument("--reward_scale", default=1.0, type=float)
parser.add_argument("--use_ec2", action="store_true", help="Use your ec2 instances if configured")
# NOTE(review): action="store_false" means args.dont_terminate_machine defaults to
# True and becomes False when the flag is passed; confirm the downstream use
# matches the flag's name.
parser.add_argument("--dont_terminate_machine", action="store_false", help="Whether to terminate your spot instance or not. Be careful.")
args = parser.parse_args()

# NOTE(review): presumably puts rllab into stub mode so the objects created below
# are recorded for later serialization rather than built eagerly — confirm.
stub(globals())

# Fixed seed (1) passed to rllab's ext.set_seed.
ext.set_seed(1)

# Gym env with forced resets and video/log recording, normalized, then wrapped
# for the TF sandbox.
gymenv = GymEnv(args.env, force_reset=True, record_video=True, record_log=True)
env = TfEnv(normalize(gymenv))

policy = DeterministicMLPPolicy(
    env_spec=env.spec,
    name="policy",
    # Three hidden layers of 100, 50 and 25 units with ReLU activations.
    hidden_sizes=(100, 50, 25),
    hidden_nonlinearity=tf.nn.relu,
)

# Exploration strategy for the deterministic policy, built from the env spec.
es = OUStrategy(env_spec=env.spec)
def run_experiment(argv):
    """Launch a stubbed rllab experiment using the VASE exploration sampler.

    Parses ``argv[1:]``, seeds the RNGs, optionally starts
    ``sandbox.vase.sampler.parallel_sampler_expl`` workers and the plotter,
    and decodes the pickled ``--args_data`` payload. It configures the
    global ``logger`` (text/tabular outputs, snapshot dir/mode/gap, prefix)
    under ``--log_dir``, then concretizes the payload and drains it if it is
    iterable.

    The logger's snapshot mode/dir, text/tabular outputs and prefix are
    restored when the run finishes *or raises*.

    Args:
        argv: full argument vector; ``argv[0]`` is skipped (e.g. ``sys.argv``).
    """
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument('--n_parallel', type=int, default=1,
                        help='Number of parallel workers to perform rollouts.')
    parser.add_argument('--exp_name', type=str, default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--log_dir', type=str, default=default_log_dir,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument('--snapshot_gap', type=int, default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--plot', type=ast.literal_eval, default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only', type=ast.literal_eval, default=False,
        help='Whether to only print the tabular log information (in a horizontal format)')
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data', type=str, help='Pickled data for stub objects')
    # Accepted for launcher compatibility and recorded in params.json; the
    # payload below is always decoded with ``pickle``.
    parser.add_argument('--use_cloudpickle', type=ast.literal_eval, default=False,
                        help='Whether --args_data was serialized with cloudpickle')
    args = parser.parse_args(argv[1:])

    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from sandbox.vase.sampler import parallel_sampler_expl as parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    # read from stdin
    # SECURITY: unpickling can run arbitrary code; only pass payloads produced
    # by a trusted launcher.
    data = pickle.loads(base64.b64decode(args.args_data))

    log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        maybe_iter = concretize(data)
        # Drain generator-style experiments so every step actually runs.
        if is_iterable(maybe_iter):
            for _ in maybe_iter:
                pass
    finally:
        # Undo the global logger configuration even when training raises.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
def main():
    """Train a TRPO policy on the multi-agent WaterWorld environment.

    Parses command-line hyper-parameters, builds ``MAWaterWorld`` wrapped in
    ``StandardizedEnv`` / ``RLLabEnv`` / ``TfEnv`` (plus ``ObservationBuffer``
    when ``--buffer_size > 1``), a Gaussian MLP/GRU/LSTM policy and a
    baseline, configures the global ``logger`` under ``--log_dir`` (or
    ``config.LOG_DIR/<exp_name>``) and runs ``TRPO.train()``.

    Raises:
        ValueError: if ``--sensor_range`` has neither 1 nor ``n_pursuers``
            comma-separated values.
        NotImplementedError: for ``--baseline_type mlp`` or an unsupported
            ``--recurrent`` cell type (only 'gru' and 'lstm' are handled).
    """
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name', type=str, default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--discount', type=float, default=0.95)
    parser.add_argument('--gae_lambda', type=float, default=0.99)
    parser.add_argument('--reward_scale', type=float, default=1.0)
    parser.add_argument('--enable_obsnorm', action='store_true', default=False)
    parser.add_argument('--chunked', action='store_true', default=False)
    parser.add_argument('--n_iter', type=int, default=250)
    parser.add_argument('--sampler_workers', type=int, default=1)
    parser.add_argument('--max_traj_len', type=int, default=250)
    parser.add_argument('--update_curriculum', action='store_true', default=False)
    parser.add_argument('--anneal_step_size', type=int, default=0)
    parser.add_argument('--n_timesteps', type=int, default=8000)
    parser.add_argument('--control', type=str, default='centralized')
    parser.add_argument('--buffer_size', type=int, default=1)
    parser.add_argument('--radius', type=float, default=0.015)
    parser.add_argument('--n_evaders', type=int, default=10)
    parser.add_argument('--n_pursuers', type=int, default=8)
    parser.add_argument('--n_poison', type=int, default=10)
    parser.add_argument('--n_coop', type=int, default=4)
    parser.add_argument('--n_sensors', type=int, default=30)
    parser.add_argument('--sensor_range', type=str, default='0.2')
    parser.add_argument('--food_reward', type=float, default=5)
    parser.add_argument('--poison_reward', type=float, default=-1)
    parser.add_argument('--encounter_reward', type=float, default=0.05)
    parser.add_argument('--reward_mech', type=str, default='local')
    parser.add_argument('--recurrent', type=str, default=None)
    parser.add_argument('--baseline_type', type=str, default='linear')
    parser.add_argument('--policy_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--baseline_hidden_sizes', type=str, default='128,128')
    parser.add_argument('--max_kl', type=float, default=0.01)
    parser.add_argument('--log_dir', type=str, required=False)
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data', type=str, help='Pickled data for stub objects')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument(
        '--log_tabular_only', type=ast.literal_eval, default=False,
        help='Whether to only print the tabular log information (in a horizontal format)')
    args = parser.parse_args()

    parallel_sampler.initialize(n_parallel=args.sampler_workers)
    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    # Stored on ``args`` so it is recorded by ``logger.log_parameters_lite``.
    args.hidden_sizes = tuple(map(int, args.policy_hidden_sizes.split(',')))

    # Materialize the values first: on Python 3 ``map`` is lazy, so
    # ``np.array(map(...))`` is a 0-d object array and ``len()`` on it fails.
    sensor_range = np.array([float(r) for r in args.sensor_range.split(',')])
    if len(sensor_range) == 1:
        sensor_range = sensor_range[0]
    elif sensor_range.shape != (args.n_pursuers,):
        raise ValueError('--sensor_range needs 1 or {} comma-separated values, got {}'.format(
            args.n_pursuers, len(sensor_range)))

    env = MAWaterWorld(args.n_pursuers, args.n_evaders, args.n_coop, args.n_poison,
                       radius=args.radius, n_sensors=args.n_sensors,
                       food_reward=args.food_reward, poison_reward=args.poison_reward,
                       encounter_reward=args.encounter_reward,
                       reward_mech=args.reward_mech, sensor_range=sensor_range,
                       obstacle_loc=None)
    env = TfEnv(
        RLLabEnv(
            StandardizedEnv(env, scale_reward=args.reward_scale,
                            enable_obsnorm=args.enable_obsnorm),
            mode=args.control))
    if args.buffer_size > 1:
        env = ObservationBuffer(env, args.buffer_size)

    if args.recurrent:
        # The feature network consumes the flattened observation and action.
        feature_network = MLP(
            name='feature_net',
            input_shape=(env.spec.observation_space.flat_dim +
                         env.spec.action_space.flat_dim,),
            output_dim=16, hidden_sizes=(128, 64, 32),
            hidden_nonlinearity=tf.nn.tanh, output_nonlinearity=None)
        # Recurrent policies take a single hidden size, so
        # --policy_hidden_sizes must be one integer (e.g. "128") here.
        if args.recurrent == 'gru':
            policy = GaussianGRUPolicy(env_spec=env.spec, feature_network=feature_network,
                                       hidden_dim=int(args.policy_hidden_sizes),
                                       name='policy')
        elif args.recurrent == 'lstm':
            policy = GaussianLSTMPolicy(env_spec=env.spec, feature_network=feature_network,
                                        hidden_dim=int(args.policy_hidden_sizes),
                                        name='policy')
        else:
            raise NotImplementedError(args.recurrent)
    else:
        policy = GaussianMLPPolicy(
            name='policy', env_spec=env.spec,
            hidden_sizes=tuple(map(int, args.policy_hidden_sizes.split(','))),
            min_std=10e-5)

    if args.baseline_type == 'linear':
        baseline = LinearFeatureBaseline(env_spec=env.spec)
    elif args.baseline_type == 'mlp':
        # An MLP baseline (using --baseline_hidden_sizes) is not wired up yet.
        raise NotImplementedError()
    else:
        baseline = ZeroBaseline(env_spec=env.spec)

    # logger
    default_log_dir = config.LOG_DIR
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    algo = TRPO(
        env=env,
        policy=policy,
        baseline=baseline,
        batch_size=args.n_timesteps,
        max_path_length=args.max_traj_len,
        update_max_path_length=args.update_curriculum,
        anneal_step_size=args.anneal_step_size,
        n_itr=args.n_iter,
        discount=args.discount,
        gae_lambda=args.gae_lambda,
        step_size=args.max_kl,
        # Recurrent policies use a finite-difference Hessian-vector product.
        optimizer=ConjugateGradientOptimizer(
            hvp_approach=FiniteDifferenceHvp(base_eps=1e-5)) if args.recurrent else None,
        mode=args.control if not args.chunked else 'chunk_{}'.format(args.control),
    )
    algo.train()
def run_experiment(argv):
    """Parse launcher arguments, set up logging, and run a pickled stub experiment.

    Args:
        argv: ``sys.argv``-style list; ``argv[0]`` (the program name) is skipped.

    The experiment is passed as a base64-encoded pickle in ``--args_data`` and
    is executed through ``concretize``. If that returns an iterable, it is
    exhausted so generator-style experiments run to completion. The global
    ``logger`` state changed here (outputs, snapshot dir/mode, prefix) is
    restored before returning, even when the experiment raises.
    """
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument('--n_parallel', type=int, default=1,
                        help='Number of parallel workers to perform rollouts.')
    parser.add_argument(
        '--exp_name', type=str, default=default_exp_name, help='Name of the experiment.')
    parser.add_argument('--log_dir', type=str, default=default_log_dir,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                             '(all iterations will be saved), "last" (only '
                             'the last iteration will be saved), or "none" '
                             '(do not save snapshots)')
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--plot', type=ast.literal_eval, default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False,
                        help='Whether to only print the tabular log information (in a horizontal format)')
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data', type=str, help='Pickled data for stub objects')
    args = parser.parse_args(argv[1:])

    # Fail fast (before any sampler workers are started) instead of crashing
    # later inside b64decode(None) with an unhelpful TypeError.
    if args.args_data is None:
        parser.error('--args_data is required')

    from sandbox.vime.sampler import parallel_sampler_expl as parallel_sampler
    parallel_sampler.initialize(n_parallel=args.n_parallel)

    if args.seed is not None:
        set_seed(args.seed)
        parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    # The experiment arrives base64-encoded on the command line (--args_data).
    # NOTE(review): pickle.loads executes arbitrary code; this is only safe
    # because the payload is produced by our own launcher. Never feed it
    # untrusted input.
    data = pickle.loads(base64.b64decode(args.args_data))

    log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        maybe_iter = concretize(data)
        # Generator-style experiments only run when iterated.
        if is_iterable(maybe_iter):
            for _ in maybe_iter:
                pass
    finally:
        # Restore the global logger even if the experiment raised.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
def _parse_resume_spec(resume_from):
    """Split a ``--resume_from`` value into ``(snapshot_path, n_itr, batch_size)``.

    The value is either a bare snapshot path or ``path&|&n_itr[&|&batch_size]``.
    For a bare path, ``n_itr`` and ``batch_size`` are ``None``. For the
    separated form without a third field, ``batch_size`` is ``None``.
    """
    if '&|&' not in resume_from:
        return resume_from, None, None
    vals = resume_from.split('&|&')  # dirRes | numItrs to go | new batchSize
    snapshot_path = vals[0]
    n_itr = int(vals[1])
    batch_size = int(vals[2]) if len(vals) > 2 else None
    return snapshot_path, n_itr, batch_size


def _resume_from_snapshot(resume_from):
    """Load a pickled algorithm snapshot, optionally override its settings, and train it."""
    snapshot_path, n_itr, batch_size = _parse_resume_spec(resume_from)
    print("resuming from :{}".format(snapshot_path))
    # data is a dict with keys such as 'baseline', 'algo', 'itr', 'policy', 'env'
    data = joblib.load(snapshot_path)
    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    for required in ('algo', 'policy'):
        if required not in data:
            raise KeyError('snapshot {} has no {!r} entry'.format(snapshot_path, required))
    algo = data['algo']

    if n_itr is not None:
        old_batch_size = algo.batch_size
        algo.n_itr = n_itr
        if batch_size is not None:
            algo.batch_size = batch_size
            print(
                'algo iters : {} cur iter :{} oldBatchSize : {} newBatchSize : {}'
                .format(algo.n_itr, algo.current_itr, old_batch_size, algo.batch_size))
        else:
            print('algo iters : {} cur iter :{} '.format(
                algo.n_itr, algo.current_itr))
    algo.train()


def _run_pickled_experiment(args, variant_data):
    """Decode the experiment passed in ``--args_data`` and run it.

    NOTE(review): pickle/cloudpickle ``loads`` execute arbitrary code. The
    payload must come from our own launcher and never from untrusted input.
    """
    print('Not resuming - building new exp')
    if args.args_data is None:
        raise ValueError('--args_data is required when --resume_from is not given')
    if args.use_cloudpickle:
        import cloudpickle
        method_call = cloudpickle.loads(base64.b64decode(args.args_data))
        method_call(variant_data)
    else:
        print('not use cloud pickle')
        data = pickle.loads(base64.b64decode(args.args_data))
        maybe_iter = concretize(data)
        # Generator-style experiments only run when iterated.
        if is_iterable(maybe_iter):
            for _ in maybe_iter:
                pass


def run_experiment(argv):
    """Parse launcher arguments, set up logging, then resume or start an experiment.

    Args:
        argv: ``sys.argv``-style list; ``argv[0]`` (the program name) is skipped.

    If ``--resume_from`` is given, the pickled algorithm is loaded and trained.
    The value may be a bare path or ``path&|&n_itr[&|&batch_size]`` to also
    override the iteration count and batch size. Otherwise the experiment in
    ``--args_data`` is decoded and run. The global ``logger`` state changed
    here is restored before returning, even when training raises.
    """
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--n_parallel', type=int, default=1,
        help="Number of parallel workers to perform rollouts. 0 => don't start any workers")
    parser.add_argument('--exp_name', type=str, default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--log_dir', type=str, default=None,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                             '(all iterations will be saved), "last" (only '
                             'the last iteration will be saved), "gap" (every '
                             '`snapshot_gap` iterations are saved), or "none" '
                             '(do not save snapshots)')
    parser.add_argument('--snapshot_gap', type=int, default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file', type=str, default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument(
        '--resume_from', type=str, default=None,
        help='Name of the pickle file to resume experiment from. Optionally '
             '"<file>&|&<n_itr>[&|&<batch_size>]" to also override the '
             'iteration count and batch size.')
    parser.add_argument('--plot', type=ast.literal_eval, default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only', type=ast.literal_eval, default=False,
        help='Whether to only print the tabular log information (in a horizontal format)')
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data', type=str, help='Pickled data for stub objects')
    parser.add_argument('--variant_data', type=str,
                        help='Pickled data for variant configuration')
    parser.add_argument('--use_cloudpickle', type=ast.literal_eval, default=False)
    args = parser.parse_args(argv[1:])

    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    # variant_data is the variant dictionary sent by the launcher.
    if args.variant_data is not None:
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        logger.log_parameters_lite(params_log_file, args)

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    try:
        if args.resume_from is not None:
            _resume_from_snapshot(args.resume_from)
        else:
            _run_pickled_experiment(args, variant_data)
    finally:
        # Restore the global logger even if training raised.
        logger.set_snapshot_mode(prev_mode)
        logger.set_snapshot_dir(prev_snapshot_dir)
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()