def main():
    """Evaluate every saved checkpoint of ``Config.RESTORE_ID`` in turn.

    Loads the saved run's arguments and adjusts a few ``Config`` flags based
    on substrings of the restore id. It then starts a wandb run and draws a
    fresh random seed. Finally it calls ``enjoy_env_sess`` once per
    checkpoint (1..32), each inside its own TensorFlow session.
    """
    # Called for their side effects; the return values were never used.
    # presumably these populate Config from the restore file / CLI — TODO confirm
    utils.load_args()
    setup_utils.setup_and_load()

    # Restore ids carry substrings that describe the training variant.
    if 'NR' in Config.RESTORE_ID:
        Config.USE_LSTM = 2
    # NOTE(review): the original formatting was lost; both assignments are
    # assumed to belong to this branch — confirm.
    if 'dropout' in Config.RESTORE_ID:
        Config.DROPOUT = 0
        Config.USE_BATCH_NORM = 0

    wandb.init(project="coinrun", notes="test", tags=["baseline", "test"],
               config=Config.get_args_dict())

    # Let TF grow GPU memory on demand instead of reserving it all up front.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # Fresh random seed for this evaluation run.
    seed = np.random.randint(100000)
    Config.SET_SEED = seed

    # Settings handed to enjoy_env_sess; presumably used there to override
    # values stored in each checkpoint — verify against enjoy_env_sess.
    overlap = {
        'set_seed': Config.SET_SEED,
        'rep': Config.REP,
        'highd': Config.HIGH_DIFFICULTY,
        'num_levels': Config.NUM_LEVELS,
        'use_lstm': Config.USE_LSTM,
        'dropout': Config.DROPOUT,
        'use_batch_norm': Config.USE_BATCH_NORM,
    }

    load_file = Config.get_load_filename(restore_id=Config.RESTORE_ID)
    mpi_print('load file name', load_file)
    mpi_print('seed', seed)
    mpi_print("---------------------------------------")

    num_checkpoints = 32
    steps_per_checkpoint = 8000000
    for checkpoint in range(1, num_checkpoints + 1):
        # Pass the session config; previously it was built but never used,
        # so allow_growth had no effect.
        with tf.Session(config=config) as sess:
            steps_elapsed = checkpoint * steps_per_checkpoint
            mpi_print('steps_elapsed:', steps_elapsed)
            enjoy_env_sess(sess, checkpoint, overlap)
def restore_file_back(restore_id, load_key='default'):
    """Load saved run arguments for ``restore_id`` into ``Config``.

    When ``restore_id`` is not ``None``:

    - the save file is loaded with joblib and passed to
      ``Config.set_load_data`` along with ``load_key``;
    - every name in ``Config.RES_KEYS`` that is present in the saved
      ``'args'`` mapping is applied through ``Config.parse_args_dict``;
    - names that are absent only produce a printed warning.

    ``init_args_and_threads(4)`` is called in every case. Returns ``None``.
    """
    if restore_id is not None:
        load_file = Config.get_load_filename(restore_id=restore_id)
        load_data = joblib.load(file_to_path(load_file))
        Config.set_load_data(load_data, load_key=load_key)

        saved_args = load_data['args']
        restored = {}
        for name in Config.RES_KEYS:
            if name not in saved_args:
                print('warning key %s not restored' % name)
                continue
            restored[name] = saved_args[name]
        Config.parse_args_dict(restored)

    # Imported lazily here rather than at module level.
    from coinrun.coinrunenv import init_args_and_threads
    init_args_and_threads(4)
def restore_file(restore_id, base_name=None, overlap_config=None, load_key='default'):
    """Restore saved run arguments into ``Config`` and initialise coinrun.

    ``overlap_config`` lets the caller override values from the save file
    (e.g. a different test seed). It is applied with
    ``Config.parse_args_dict`` after the restored arguments.

    Args:
        restore_id: id of the saved run, or ``None`` to skip loading.
        base_name: forwarded to ``Config.get_load_filename``.
        overlap_config: optional dict of config overrides, applied last.
        load_key: forwarded to ``Config.set_load_data``.

    Returns:
        The filename from ``Config.get_load_filename``, or ``None`` when
        ``restore_id`` is ``None``.

    Raises:
        FileNotFoundError: if the save file for ``restore_id`` does not exist.
    """
    load_file = None
    if restore_id is not None:
        load_file = Config.get_load_filename(restore_id=restore_id, base_name=base_name)
        filepath = file_to_path(load_file)
        # Explicit check rather than `assert`, so it still runs under `python -O`.
        if not os.path.exists(filepath):
            raise FileNotFoundError("restore file does not exist: %s" % filepath)
        load_data = joblib.load(filepath)
        Config.set_load_data(load_data, load_key=load_key)

        restored_args = load_data['args']
        sub_dict = {}
        for key in Config.RES_KEYS:
            if key in restored_args:
                sub_dict[key] = restored_args[key]
            else:
                print('warning key %s not restored' % key)
        Config.parse_args_dict(sub_dict)
        print("Load params")

    # NOTE(review): the original formatting was lost; the overrides are
    # assumed to apply even when nothing was restored — confirm.
    if overlap_config is not None:
        Config.parse_args_dict(overlap_config)

    from coinrun.coinrunenv import init_args_and_threads
    print("Init coinrun env threads and env args")
    init_args_and_threads(4)
    return load_file