def compute_gradients(self, loss, var_list, **kwargs):
    grads_and_vars = tf.train.AdamOptimizer.compute_gradients(
        self, loss, var_list, **kwargs)
    grads_and_vars = [(g, v) for g, v in grads_and_vars if g is not None]

    # Flatten all gradients into a single vector so one Allreduce suffices.
    flat_grad = tf.concat(
        [tf.reshape(g, (-1,)) for g, v in grads_and_vars], axis=0)

    # Test ranks contribute zero gradient; they only evaluate.
    if Config.is_test_rank():
        flat_grad = tf.zeros_like(flat_grad)

    shapes = [v.shape.as_list() for g, v in grads_and_vars]
    sizes = [int(np.prod(s)) for s in shapes]

    num_tasks = self.comm.Get_size()
    buf = np.zeros(sum(sizes), np.float32)

    def _collect_grads(flat_grad):
        # Sum gradients across all MPI ranks, then average over the
        # training ranks only (test ranks sent zeros).
        self.comm.Allreduce(flat_grad, buf, op=MPI.SUM)
        np.divide(buf, float(num_tasks) * self.train_frac, out=buf)
        return buf

    avg_flat_grad = tf.compat.v1.py_func(_collect_grads, [flat_grad], tf.float32)
    avg_flat_grad.set_shape(flat_grad.shape)

    # Unflatten back into per-variable gradients.
    avg_grads = tf.split(avg_flat_grad, sizes, axis=0)
    avg_grads_and_vars = [(tf.reshape(g, v.shape), v)
                          for g, (_, v) in zip(avg_grads, grads_and_vars)]

    return avg_grads_and_vars
def restore_file(restore_id, base_name=None, overlap_config=None, load_key='default'):
    """`overlap_config` lets you override entries of the saved config,
    e.g. the test seed."""
    if restore_id is not None:
        load_file = Config.get_load_filename(restore_id=restore_id,
                                             base_name=base_name)
        filepath = file_to_path(load_file)
        assert os.path.exists(filepath), 'file does not exist: %s' % filepath
        load_data = joblib.load(filepath)
        Config.set_load_data(load_data, load_key=load_key)

        restored_args = load_data['args']
        sub_dict = {}
        res_keys = Config.RES_KEYS
        for key in res_keys:
            if key in restored_args:
                sub_dict[key] = restored_args[key]
            else:
                print('warning: key %s not restored' % key)
        Config.parse_args_dict(sub_dict)

    if overlap_config is not None:
        Config.parse_args_dict(overlap_config)

    print(Config.SET_SEED, Config.NUM_LEVELS)
    print("Init coinrun env threads and env args")
    init_args_and_threads(4)

    if restore_id is None:
        return None
    return load_file
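# Hypothetical usage sketch (not from the original source): restore a saved
# run and override part of its saved config via `overlap_config`. The restore
# id and the config key shown here are made up for illustration.
#
#   load_file = restore_file('myrun-00', overlap_config={'set_seed': 13})
#   args = load_args()  # restored args dict, or False if nothing was loaded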
def load_model(sess, base_name=None):
    filename = Config.get_save_file(base_name=base_name)
    print(filename)
    utils.load_params_for_scope(sess, 'model',
                                load_path=filename, load_key='default')
    datapoints = utils.load_datapoints(load_path=filename)
    return datapoints
def save_model(sess, datapoints=None, base_name=None):
    base_dict = {}
    if datapoints is not None:
        base_dict['datapoints'] = datapoints
    # save_params_in_scopes(sess, scopes, filename, base_dict=None)
    utils.save_params_in_scopes(
        sess, ['model'], Config.get_save_file(base_name=base_name), base_dict)
def load_args(load_key='default'):
    """Get the training args of a restored id."""
    load_data = Config.get_load_data(load_key)
    if load_data is None:
        return False
    args_dict = load_data['args']
    # Config.parse_args_dict(args_dict)
    return args_dict
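# A minimal save/load round-trip sketch, assuming an open session and an
# already-built 'model' variable scope (illustrative, not from the source):
#
#   with tf.Session() as sess:
#       sess.run(tf.global_variables_initializer())
#       save_model(sess, datapoints=[(0, 0.0)])  # params + datapoints to save file
#       datapoints = load_model(sess)            # restores params, returns datapoints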
def main():
    setup_utils.setup_and_load()
    wandb_log = True
    if wandb_log:
        wandb.init(project="coinrun",
                   name=Config.RESTORE_ID + 'test',
                   config=Config.get_args_dict())
    with tf.Session() as sess:
        for i in range(0, 256, 8):
            enjoy_env_sess(sess, i, wandb_log)
def load_model(sess, base_name=None):
    filename = Config.get_save_file(base_name=base_name)
    is_loaded = utils.load_params_for_scope(sess, 'model',
                                            load_path=filename,
                                            load_key='default')
    # datapoints = utils.load_datapoints(load_path=filename)
    if is_loaded:
        return filename
    return is_loaded
def load_datapoints(load_path=None, load_key=None):
    load_data = None
    if load_path is None:
        load_data = Config.get_load_data(load_key)
    else:
        load_path = file_to_path(load_path)
        if os.path.exists(load_path):
            load_data = joblib.load(load_path)
            print('Load file', load_path)

    if load_data is None:
        return False
    return load_data['datapoints']
def restore_file_back(restore_id, load_key='default'):
    if restore_id is not None:
        load_file = Config.get_load_filename(restore_id=restore_id)
        filepath = file_to_path(load_file)
        load_data = joblib.load(filepath)
        Config.set_load_data(load_data, load_key=load_key)

        restored_args = load_data['args']
        sub_dict = {}
        res_keys = Config.RES_KEYS
        for key in res_keys:
            if key in restored_args:
                sub_dict[key] = restored_args[key]
            else:
                print('warning: key %s not restored' % key)
        Config.parse_args_dict(sub_dict)

    from coinrun.coinrunenv import init_args_and_threads
    init_args_and_threads(4)
def setup_and_load(use_cmd_line_args=True, **kwargs):
    """
    Initialize the global config using command line options,
    defaulting to the values in `config.py`.

    `use_cmd_line_args`: set to False to ignore command line arguments
        passed to the program
    `**kwargs`: override the defaults from `config.py` with these values
    """
    args = Config.initialize_args(use_cmd_line_args=use_cmd_line_args, **kwargs)
    load_for_setup_if_necessary()
    return args
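# Typical entry-point usage from a script (sketch; the `num_envs` kwarg is an
# assumption and stands for any option defined in `config.py`):
#
#   from coinrun import setup_utils
#   args = setup_utils.setup_and_load(use_cmd_line_args=False, num_envs=16)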
def save_params_in_scopes(sess, scopes, filename, base_dict=None):
    data_dict = {}
    if base_dict is not None:
        data_dict.update(base_dict)

    save_path = file_to_path(filename)
    data_dict['args'] = Config.get_args_dict()

    param_dict = {}
    for scope in scopes:
        params = tf.trainable_variables(scope)
        if len(params) > 0:
            print('saving scope', scope, filename)
            ps = sess.run(params)
            param_dict[scope] = ps

    data_dict['params'] = param_dict
    joblib.dump(data_dict, save_path)
def load_params_for_scope(sess, scope, load_key='default', load_path=None):
    if load_path is None:
        load_data = Config.get_load_data(load_key)
    else:
        load_path = file_to_path(load_path)
        if os.path.exists(load_path):
            load_data = joblib.load(load_path)
            print('Load file', load_path)
        else:
            raise ValueError('load path does not exist: %s' % load_path)

    if load_data is None:
        return False

    params_dict = load_data['params']
    if scope in params_dict:
        print('Loading saved file for scope', scope)
        loaded_params = params_dict[scope]
        loaded_params, params = get_savable_params(loaded_params, scope,
                                                   keep_heads=True)
        restore_params(sess, loaded_params, params)
    return True
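# Sketch of the save/load pair above for a single scope (assumes trainable
# variables exist under 'model'; the filename is illustrative):
#
#   save_params_in_scopes(sess, ['model'], 'my_checkpoint')
#   ok = load_params_for_scope(sess, 'model', load_path='my_checkpoint')
#   assert ok, 'no saved params found'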
def __init__(self, comm, **kwargs):
    self.comm = comm
    self.train_frac = 1.0 - Config.get_test_frac()
    tf.train.AdamOptimizer.__init__(self, **kwargs)
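# This __init__ and the compute_gradients above presumably belong to an
# MPI-aware AdamOptimizer subclass (the class name below is an assumption).
# Minimal usage sketch: every rank builds the same graph, compute_gradients()
# averages gradients across ranks via Allreduce, and apply_gradients() then
# runs locally with identical averaged gradients on every rank.
#
#   optimizer = MpiAdamOptimizer(MPI.COMM_WORLD, learning_rate=3e-4)
#   grads_and_vars = optimizer.compute_gradients(loss, tf.trainable_variables())
#   train_op = optimizer.apply_gradients(grads_and_vars)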
def setup_and_load(use_cmd=True, **kwargs):
    args = Config.initialize_args(use_cmd=use_cmd, **kwargs)
    init_args_and_threads(4)
    return args
def __init__(self, sess):
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    clean_tb_dir()

    tb_writer = tf.summary.FileWriter(
        Config.TB_DIR + '/' + Config.RUN_ID + '_' + str(rank), sess.graph)
    total_steps = [0]

    should_log = (rank == 0 or Config.LOG_ALL_MPI)

    if should_log:
        hyperparams = np.array(Config.get_arg_text())
        hyperparams_tensor = tf.constant(hyperparams)
        summary_op = tf.summary.text("hyperparameters info", hyperparams_tensor)
        summary = sess.run(summary_op)
        tb_writer.add_summary(summary)

    def add_summary(_merged, interval=1):
        if should_log:
            total_steps[0] += 1
            if total_steps[0] % interval == 0:
                tb_writer.add_summary(_merged, total_steps[0])
                tb_writer.flush()

    tuples = []

    def make_scalar_graph(name):
        scalar_ph = tf.placeholder(name='scalar_' + name, dtype=tf.float32)
        scalar_summary = tf.compat.v1.summary.scalar(name, scalar_ph)
        merged = tf.compat.v1.summary.merge([scalar_summary])
        tuples.append((scalar_ph, merged))

    name_dict = {}
    curr_name_idx = [0]

    def log_scalar(x, name, step=-1):
        if name not in name_dict:
            name_dict[name] = curr_name_idx[0]
            tf_name = (name + '_' + Config.RUN_ID) if curr_name_idx[0] == 0 else name
            make_scalar_graph(tf_name)
            curr_name_idx[0] += 1

        idx = name_dict[name]
        scalar_ph, merged = tuples[idx]

        if should_log:
            if step == -1:
                step = total_steps[0]
                total_steps[0] += 1
            _merged = sess.run(merged, {scalar_ph: x})
            tb_writer.add_summary(_merged, step)
            tb_writer.flush()

    self.add_summary = add_summary
    self.log_scalar = log_scalar
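# Usage sketch for this TensorBoard writer (the enclosing class name is an
# assumption; only rank 0 logs unless Config.LOG_ALL_MPI is set):
#
#   logger = TB_Writer(sess)
#   logger.log_scalar(mean_reward, 'rew_mean')           # auto-incrementing step
#   logger.log_scalar(test_score, 'test_score', step=t)  # explicit step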
def main():
    args = setup_and_load()
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()

    seed = int(time.time()) % 10000
    utils.mpi_print(seed * 100 + rank)
    set_global_seeds(seed * 100 + rank)

    # For the wandb package to visualize result curves.
    config = Config.get_args_dict()
    config['global_seed'] = seed
    wandb.init(name=config['run_id'],
               project="coinrun",
               notes="GARL generate seed",
               tags=["try"],
               config=config)

    utils.setup_mpi_gpus()
    utils.mpi_print('Set up gpu', args)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # pylint: disable=E1101

    eval_limit = Config.EVAL_STEP * 10**6
    phase_eval_limit = int(eval_limit // Config.TRAIN_ITER)
    total_timesteps = int(Config.TOTAL_STEP * 10**6)
    phase_timesteps = int((total_timesteps - eval_limit) // Config.TRAIN_ITER)

    with tf.Session(config=config):
        sess = tf.get_default_session()

        # Init env.
        nenv = Config.NUM_ENVS
        env = make_general_env(nenv, rand_seed=seed)
        utils.mpi_print('Set up env')

        policy = policies_back.get_policy()
        utils.mpi_print('Set up policy')

        optimizer = SeedOptimizer(env=env,
                                  logdir=Config.LOGDIR,
                                  spare_size=Config.SPA_LEVELS,
                                  ini_size=Config.INI_LEVELS,
                                  eval_limit=phase_eval_limit,
                                  train_set_limit=Config.NUM_LEVELS,
                                  load_seed=Config.LOAD_SEED,
                                  rand_seed=seed,
                                  rep=1,
                                  log=True)

        step_elapsed = 0
        t = 0
        if args.restore_id is not None:
            datapoints = Config.get_load_data('default')['datapoints']
            step_elapsed = datapoints[-1][0]
            optimizer.load()
            seed = optimizer.hist[-1]
            env.set_seed(seed)
            t = 16
            print('loadrestore')
            Config.RESTORE_ID = Config.get_load_data('default')['args']['run_id']
            Config.RUN_ID = Config.get_load_data('default')['args']['run_id'].replace('-', '_')

        last_set = None
        while step_elapsed < (Config.TOTAL_STEP - 1) * 10**6:
            # ============ GARL =================
            # Optimize policy.
            mean_rewards, datapoints = learn_func(
                sess=sess,
                policy=policy,
                env=env,
                log_interval=args.log_interval,
                save_interval=args.save_interval,
                nsteps=Config.NUM_STEPS,
                nminibatches=Config.NUM_MINIBATCHES,
                lam=Config.GAE_LAMBDA,
                gamma=Config.GAMMA,
                noptepochs=Config.PPO_EPOCHS,
                ent_coef=Config.ENTROPY_COEFF,
                vf_coef=Config.VF_COEFF,
                max_grad_norm=Config.MAX_GRAD_NORM,
                lr=lambda f: f * Config.LEARNING_RATE,
                cliprange=lambda f: f * Config.CLIP_RANGE,
                start_timesteps=step_elapsed,
                total_timesteps=phase_timesteps,
                index=t)

            # Test catastrophic forgetting: score the previous phase's seed
            # set against the current one before it gets replaced again.
            if 'Forget' in Config.RUN_ID:
                if t > 0 and last_set is not None:
                    curr_set = list(env.get_seed_set())
                    last_scores, _ = eval_test(sess, nenv, last_set, train=True,
                                               idx=None, rep_count=len(last_set))
                    curr_scores, _ = eval_test(sess, nenv, curr_set, train=True,
                                               idx=None, rep_count=len(curr_set))
                    tmp = set(curr_set).difference(set(last_set))
                    mpi_print("Forgetting Exp")
                    mpi_print("Last setsize", len(last_set))
                    mpi_print("Last scores", np.mean(last_scores),
                              "Curr scores", np.mean(curr_scores))
                    mpi_print("Replace count", len(tmp))
                # Snapshot the current seed set for the next phase's comparison.
                last_set = list(env.get_seed_set())

            # Optimize env.
            step_elapsed = datapoints[-1][0]
            if t < Config.TRAIN_ITER:
                best_rew_mean = max(mean_rewards)
                env, step_elapsed = optimizer.run(sess, env, step_elapsed,
                                                  best_rew_mean)
            t += 1

        save_final_test = True
        if save_final_test:
            final_test = {}
            final_test['step_elapsed'] = step_elapsed
            train_set = env.get_seed()
            final_test['train_set_size'] = len(train_set)

            eval_log = eval_test(sess, nenv, train_set, train=True,
                                 is_high=False, rep_count=1000, log=True)
            final_test['Train_set'] = eval_log
            eval_log = eval_test(sess, nenv, None, train=False,
                                 is_high=True, rep_count=1000, log=True)
            final_test['Test_set'] = eval_log

            joblib.dump(final_test, setup_utils.file_to_path('final_test'))

        env.close()