def setUpClass(cls):
  """Parses the classic four-rooms gin config and registers the environment.

  The id returned by `mon_minigrid.register_environment()` is stored on the
  class as `env_id`.
  """
  super(MonMiniGridEnvTest, cls).setUpClass()
  fourrooms_config = os.path.join(mon_minigrid.GIN_FILES_PREFIX,
                                  'classic_fourrooms.gin')
  gin.parse_config_files_and_bindings(
      [fourrooms_config], bindings=[], skip_unknown=False)
  cls.env_id = mon_minigrid.register_environment()
def main(_):
  """Builds the output directory and environment, then runs `TrainRunner`.

  The output directory is
  `<base_dir>/<env_name or '<num_states>_<num_actions>'>/PVF/<estimator>`,
  where `base_dir` is first extended with either
  `custom_base_dir_from_hparams` or, if that is unset and a positive `xm_wid`
  flag exists, the work-unit id.

  When `env_name` is set (directly or through the JSON in `xm_parameters`),
  the matching gin file is parsed and a wrapped MiniGrid environment is
  created; otherwise a `random_mdp.RandomMDP` is used.

  Args:
    _: Unused positional command-line arguments.
  """
  flags.mark_flags_as_required(['base_dir'])
  if FLAGS.custom_base_dir_from_hparams is not None:
    FLAGS.base_dir = os.path.join(FLAGS.base_dir,
                                  FLAGS.custom_base_dir_from_hparams)
  elif 'xm_wid' in FLAGS and FLAGS.xm_wid > 0:
    # Add Work unit to base directory path, if it exists.
    FLAGS.base_dir = os.path.join(FLAGS.base_dir, str(FLAGS.xm_wid))

  # The `xm_parameters` flag may not be defined at all; when present it holds
  # a JSON object that can override the environment name.
  xm_parameters = FLAGS.xm_parameters if 'xm_parameters' in FLAGS else None
  if xm_parameters:
    xm_params = json.loads(xm_parameters)
    if 'env_name' in xm_params:
      FLAGS.env_name = xm_params['env_name']

  if FLAGS.env_name is None:
    env_dir = f'{FLAGS.num_states}_{FLAGS.num_actions}'
  else:
    env_dir = FLAGS.env_name
  base_dir = os.path.join(FLAGS.base_dir, env_dir, 'PVF', FLAGS.estimator)
  if not tf.io.gfile.exists(base_dir):
    tf.io.gfile.makedirs(base_dir)

  if FLAGS.env_name is None:
    env = random_mdp.RandomMDP(FLAGS.num_states, FLAGS.num_actions)
  else:
    gin.add_config_file_search_path(_ENV_CONFIG_PATH)
    gin.parse_config_files_and_bindings(
        config_files=[f'{FLAGS.env_name}.gin'],
        bindings=FLAGS.gin_bindings,
        skip_unknown=False)
    env_id = mon_minigrid.register_environment()
    env = gym.make(env_id)
    env = RGBImgObsWrapper(env)  # Get pixel observations
    # Get tabular observation and drop the 'mission' field:
    env = mdp_wrapper.MDPWrapper(env, get_rgb=False)
    env = coloring_wrapper.ColoringWrapper(env)
  # We add the discount factor to the environment.
  env.gamma = FLAGS.gamma

  logging.set_verbosity(logging.INFO)
  runner = TrainRunner(base_dir, env, FLAGS.epochs, FLAGS.lr, FLAGS.estimator,
                       FLAGS.alpha, FLAGS.optimizer, FLAGS.use_l2_reg,
                       FLAGS.reg_coeff, FLAGS.use_penalty, FLAGS.j,
                       FLAGS.num_rows, jax.random.PRNGKey(0),
                       FLAGS.epochs - 1, FLAGS.epochs - 1)
  runner.train()
def main(argv):
  """Plays one keyboard-driven episode and saves every observed frame.

  Parses `<FLAGS.env_name>.gin`, builds the wrapped MiniGrid environment, then
  repeatedly reads an action key from stdin, steps the environment and writes
  the RGB observation to `<FLAGS.file_path>/obs_<t>.png`. The episode stops
  after 500 accepted actions or when the environment reports `done`.

  Args:
    argv: Positional command-line arguments; only the program name is allowed.

  Raises:
    app.UsageError: If extra positional command-line arguments are given.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  gin.parse_config_files_and_bindings(
      [os.path.join(mon_minigrid.GIN_FILES_PREFIX,
                    '{}.gin'.format(FLAGS.env_name))],
      bindings=FLAGS.gin_bindings,
      skip_unknown=False)
  env_id = mon_minigrid.register_environment()
  env = gym.make(env_id)
  env = RGBImgObsWrapper(env)  # Get pixel observations
  # Get tabular observation and drop the 'mission' field:
  env = tabular_wrapper.TabularWrapper(env, get_rgb=True)
  # The loop below is usually left via Ctrl-C / EOF at the prompt, so the
  # environment is released in a `finally` clause rather than on the happy
  # path only.
  try:
    env.reset()
    num_frames = 0
    max_num_frames = 500
    if not tf.io.gfile.exists(FLAGS.file_path):
      tf.io.gfile.makedirs(FLAGS.file_path)
    print('Available actions:')
    for key in ACTION_MAPPINGS:
      print('\t{}: "{}"'.format(ACTION_MAPPINGS[key], key))
    print()
    undisc_return = 0
    while num_frames < max_num_frames:
      draw_ascii_view(env)
      a = input('action: ')
      if a not in ACTION_MAPPINGS:
        # Unknown keys do not consume a frame; just prompt again.
        print('Unrecognized action.')
        continue
      action = env.DirectionalActions[ACTION_MAPPINGS[a]].value
      obs, reward, done, _ = env.step(action)
      undisc_return += reward
      num_frames += 1
      print('t:', num_frames, '   s:', obs['state'])
      # Draw environment frame just for simple visualization
      plt.imshow(obs['image'])
      path = os.path.join(FLAGS.file_path, 'obs_{}.png'.format(num_frames))
      plt.savefig(path)
      plt.clf()
      if done:
        break
    print('Undiscounted return: %.2f' % undisc_return)
  finally:
    env.close()
def main(argv):
  """Runs value iteration on the configured MiniGrid MDP and reports V*.

  Parses `<FLAGS.env>.gin`, builds the wrapped environment, then applies
  Bellman optimality backups using `env.rewards` and `env.transition_probs`
  until the largest per-state change is at most `FLAGS.tolerance`. The
  resulting values are printed and, if `FLAGS.values_image_file` is set,
  rendered onto the grid and saved to that file.

  Args:
    argv: Positional command-line arguments; only the program name is allowed.

  Raises:
    app.UsageError: If extra positional command-line arguments are given.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  gin.parse_config_files_and_bindings(
      [os.path.join(mon_minigrid.GIN_FILES_PREFIX,
                    '{}.gin'.format(FLAGS.env))],
      bindings=FLAGS.gin_bindings,
      skip_unknown=False)
  env_id = mon_minigrid.register_environment()
  env = gym.make(env_id)
  env = RGBImgObsWrapper(env)  # Get pixel observations
  # Get tabular observation and drop the 'mission' field:
  env = mdp_wrapper.MDPWrapper(env)
  env = coloring_wrapper.ColoringWrapper(env)
  try:
    values = np.zeros(env.num_states)
    # Start above the tolerance so that at least one sweep is performed.
    error = FLAGS.tolerance * 2
    i = 0
    while error > FLAGS.tolerance:
      new_values = np.copy(values)
      for s in range(env.num_states):
        action_values = [
            env.rewards[s, a] +
            FLAGS.gamma * np.matmul(env.transition_probs[s, a, :], values)
            for a in range(env.num_actions)
        ]
        # Take the true maximum over actions. Seeding the maximum with 0 would
        # clamp V(s) to 0 whenever every action value is negative. The default
        # only applies when there are no actions at all.
        new_values[s] = max(action_values, default=0.)
      error = np.max(np.abs(new_values - values))
      values = new_values
      i += 1
      if i % 1000 == 0:
        print('Error after {} iterations: {}'.format(i, error))
    print('Found V* in {} iterations'.format(i))
    print(values)
    if FLAGS.values_image_file is not None:
      # NOTE(review): `cm.get_cmap` is deprecated in recent Matplotlib
      # releases; confirm the pinned Matplotlib version still provides it.
      cmap = cm.get_cmap('plasma', 256)
      norm = colors.Normalize(vmin=min(values), vmax=max(values))
      obs_image = env.render_custom_observation(
          env.reset(), values, cmap, boundary_values=[1.0, 4.5])
      # The mappable only exists to give the colorbar the value range.
      m = cm.ScalarMappable(cmap=cmap, norm=norm)
      m.set_array(obs_image)
      plt.imshow(obs_image)
      plt.colorbar(m)
      plt.savefig(FLAGS.values_image_file)
      plt.clf()
  finally:
    env.close()
def main(argv):
  """Runs a uniformly random agent in classic four-rooms and saves frames.

  Parses `classic_fourrooms.gin`, builds the wrapped MiniGrid environment and
  samples actions from `env.action_space` for at most 500 steps (or until the
  environment reports `done`). Each RGB observation is written to
  `<FLAGS.file_path>/<t>.png`.

  Args:
    argv: Positional command-line arguments; only the program name is allowed.

  Raises:
    app.UsageError: If extra positional command-line arguments are given.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  gin.parse_config_files_and_bindings(
      [os.path.join(mon_minigrid.GIN_FILES_PREFIX, 'classic_fourrooms.gin')],
      bindings=FLAGS.gin_bindings,
      skip_unknown=False)
  env_id = mon_minigrid.register_environment()
  env = gym.make(env_id)
  env = RGBImgObsWrapper(env)  # Get pixel observations
  # Get tabular observation and drop the 'mission' field:
  env = tabular_wrapper.TabularWrapper(env, get_rgb=True)
  try:
    env.reset()
    num_frames = 0
    max_num_frames = 500
    if not tf.io.gfile.exists(FLAGS.file_path):
      tf.io.gfile.makedirs(FLAGS.file_path)
    undisc_return = 0
    while num_frames < max_num_frames:
      # Act randomly
      obs, reward, done, _ = env.step(env.action_space.sample())
      undisc_return += reward
      num_frames += 1
      print('t:', num_frames, '   s:', obs['state'])
      # Draw environment frame just for simple visualization
      plt.imshow(obs['image'])
      # Join with a path separator so the frame lands inside the directory
      # created above even when `file_path` has no trailing slash.
      path = os.path.join(FLAGS.file_path, '{}.png'.format(num_frames))
      plt.savefig(path)
      plt.clf()
      if done:
        break
    print('Undiscounted return: %.2f' % undisc_return)
  finally:
    env.close()
def main(_):
  """Trains a representation against Psi and saves diagnostics to disk.

  Steps, in order:
    1. Resolve the output directory
       `<base_dir>/<env>/PVF/<estimator>/lr_<lr>` and create it if missing.
    2. Build either the gin-configured MiniGrid environment (when `env_name`
       is set) or a `random_mdp.RandomMDP`.
    3. Compute Psi = (I - gamma * P)^-1 for the transition matrix P returned by
       `utils.transition_matrix` under `rl_basics.policy_random`, and divide
       each column by its maximum.
    4. Run `train` from a random normal initialisation and compute the
       per-checkpoint metrics with the `calc_*` helpers.
    5. Save all metrics as a pickled dict in a `.npy` file under the output
       directory.

  Args:
    _: Unused positional command-line arguments.
  """
  flags.mark_flags_as_required(['base_dir'])
  if FLAGS.custom_base_dir_from_hparams is not None:
    FLAGS.base_dir = os.path.join(FLAGS.base_dir,
                                  FLAGS.custom_base_dir_from_hparams)
  elif 'xm_wid' in FLAGS and FLAGS.xm_wid > 0:
    # Add Work unit to base directory path, if it exists.
    FLAGS.base_dir = os.path.join(FLAGS.base_dir, str(FLAGS.xm_wid))

  xm_parameters = FLAGS.xm_parameters if 'xm_parameters' in FLAGS else None
  if xm_parameters:
    xm_params = json.loads(xm_parameters)
    if 'env_name' in xm_params:
      FLAGS.env_name = xm_params['env_name']

  if FLAGS.env_name is None:
    env_dir = '{}_{}'.format(FLAGS.num_states, FLAGS.num_actions)
  else:
    env_dir = FLAGS.env_name
  base_dir = os.path.join(FLAGS.base_dir, env_dir)
  base_dir = os.path.join(base_dir, 'PVF', FLAGS.estimator, f'lr_{FLAGS.lr}')
  if not tf.io.gfile.exists(base_dir):
    tf.io.gfile.makedirs(base_dir)

  if FLAGS.env_name is None:
    env = random_mdp.RandomMDP(FLAGS.num_states, FLAGS.num_actions)
  else:
    gin.add_config_file_search_path(_ENV_CONFIG_PATH)
    gin.parse_config_files_and_bindings(
        config_files=[f'{FLAGS.env_name}.gin'],
        bindings=FLAGS.gin_bindings,
        skip_unknown=False)
    env_id = mon_minigrid.register_environment()
    env = gym.make(env_id)
    env = RGBImgObsWrapper(env)  # Get pixel observations
    # Get tabular observation and drop the 'mission' field:
    env = mdp_wrapper.MDPWrapper(env, get_rgb=False)
    env = coloring_wrapper.ColoringWrapper(env)
  # We add the discount factor to the environment.
  env.gamma = FLAGS.gamma

  transition = utils.transition_matrix(env, rl_basics.policy_random(env))
  num_states = transition.shape[0]
  psi = jnp.linalg.solve(
      jnp.eye(num_states) - env.gamma * transition, jnp.eye(num_states))
  # Normalize tasks so that they have maximum value 1.
  max_task_value = np.max(psi, axis=0)
  psi /= max_task_value

  left_vectors, _, _ = jnp.linalg.svd(psi)
  approx_error = utils.approx_error(left_vectors, FLAGS.d, psi)

  # Initialization of Phi
  representation_init = jax.random.normal(
      jax.random.PRNGKey(0), (num_states, FLAGS.d), dtype=jnp.float64)
  representations, grads = train(
      representation_init, psi, FLAGS.epochs, FLAGS.lr, jax.random.PRNGKey(0),
      FLAGS.estimator, FLAGS.alpha, FLAGS.optimizer, FLAGS.use_l2_reg,
      FLAGS.reg_coeff, FLAGS.use_penalty, FLAGS.j, FLAGS.num_rows,
      FLAGS.skipsize_train)

  # Metrics are compared against the first `d` left singular vectors of Psi.
  top_left_vectors = left_vectors[:, :FLAGS.d]
  gm_distances = calc_gm_distances(representations, top_left_vectors,
                                   FLAGS.skipsize)
  x_len = len(gm_distances)
  frob_norms = calc_frob_norms(representations, psi, FLAGS.skipsize)
  # Dot products are only computed for d == 1; otherwise a zero placeholder of
  # matching length is stored.
  dot_products = (
      calc_dot_products(representations, top_left_vectors, FLAGS.skipsize)
      if FLAGS.d == 1 else np.zeros((x_len,)))
  grad_norms = calc_grad_norms(grads, FLAGS.skipsize)
  phi_norms = calc_Phi_norm(representations, FLAGS.skipsize)
  phi_ranks = calc_sranks(representations, FLAGS.skipsize)

  results = {
      'gm_distances': gm_distances,
      'dot_products': dot_products,
      'frob_norms': frob_norms,
      'approx_error': approx_error,
      'grad_norms': grad_norms,
      'representations': representations,
      'phi_norms': phi_norms,
      'phi_ranks': phi_ranks
  }
  prefix = f'alpha{FLAGS.alpha}_j{FLAGS.j}_d{FLAGS.d}_regcoeff{FLAGS.reg_coeff}'
  with tf.io.gfile.GFile(osp.join(base_dir, f'{prefix}.npy'), 'wb') as f:
    np.save(f, results, allow_pickle=True)