def get_policy(observations, hparams, action_space):
  """Get a policy network.

  Instantiates the model registered under ``hparams.policy_network``, feeds it
  the observations together with zero-filled dummy action/reward/target
  tensors (one target frame), and reads the policy logits and value estimate
  out of the model's targets.

  Args:
    observations: observations
    hparams: parameters
    action_space: action space

  Returns:
    Tuple (action logits, value).
  """
  if not isinstance(action_space, gym.spaces.Discrete):
    raise ValueError("Expecting discrete action space.")

  policy_problem = DummyPolicyProblem(action_space)
  trainer_lib.add_problem_hparams(hparams, policy_problem)
  hparams.force_full_predict = True
  model = registry.model(hparams.policy_network)(
      hparams, tf.estimator.ModeKeys.TRAIN)

  shape = common_layers.shape_list(observations)
  batch_time = shape[:2]  # leading [batch, time] dims of the observations
  batch = shape[:1]

  # All non-observation features are zero placeholders; the model only needs
  # their shapes to build the graph in force_full_predict mode.
  features = {
      "inputs": observations,
      "input_action": tf.zeros(batch_time + [1], dtype=tf.int32),
      "input_reward": tf.zeros(batch_time + [1], dtype=tf.int32),
      "targets": tf.zeros(batch + [1] + shape[2:]),
      "target_action": tf.zeros(batch + [1, 1], dtype=tf.int32),
      "target_reward": tf.zeros(batch + [1, 1], dtype=tf.int32),
      "target_policy": tf.zeros(batch + [1] + [action_space.n]),
      "target_value": tf.zeros(batch + [1]),
  }
  with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
    t2t_model.create_dummy_vars()
    (targets, _) = model(features)
  return (targets["target_policy"], targets["target_value"])
def get_policy(observations, hparams, action_space):
  """Get a policy network.

  Builds the policy problem (a dummy problem used mainly to supply hparams),
  instantiates the model registered under ``hparams.policy_network``, feeds it
  the observations together with zero-filled dummy action/reward/target
  tensors, and returns the first target frame's policy logits and value.

  Args:
    observations: observations
    hparams: parameters
    action_space: action space

  Returns:
    Tuple (action logits, value).

  Raises:
    ValueError: if the action space is not discrete.
  """
  if not isinstance(action_space, gym.spaces.Discrete):
    raise ValueError("Expecting discrete action space.")

  obs_shape = common_layers.shape_list(observations)
  (frame_height, frame_width) = obs_shape[2:4]

  # TODO(afrozm): We have these dummy problems mainly for hparams, so cleanup
  # when possible and do this properly.
  if hparams.policy_problem_name == "dummy_policy_problem_ttt":
    tf.logging.info("Using DummyPolicyProblemTTT for the policy.")
    policy_problem = tic_tac_toe_env.DummyPolicyProblemTTT()
  else:
    tf.logging.info("Using DummyPolicyProblem for the policy.")
    policy_problem = DummyPolicyProblem(action_space, frame_height, frame_width)

  trainer_lib.add_problem_hparams(hparams, policy_problem)
  hparams.force_full_predict = True
  model = registry.model(hparams.policy_network)(
      hparams, tf.estimator.ModeKeys.TRAIN)

  # Not all policy networks define this (video-specific) hparam; default to a
  # single target frame. getattr replaces the try/except AttributeError dance.
  num_target_frames = getattr(hparams, "video_num_target_frames", 1)

  # All non-observation features are zero placeholders; the model only needs
  # their shapes to build the graph in force_full_predict mode.
  features = {
      "inputs": observations,
      "input_action": tf.zeros(obs_shape[:2] + [1], dtype=tf.int32),
      "input_reward": tf.zeros(obs_shape[:2] + [1], dtype=tf.int32),
      "targets": tf.zeros(
          obs_shape[:1] + [num_target_frames] + obs_shape[2:]),
      "target_action": tf.zeros(
          obs_shape[:1] + [num_target_frames, 1], dtype=tf.int32),
      "target_reward": tf.zeros(
          obs_shape[:1] + [num_target_frames, 1], dtype=tf.int32),
      "target_policy": tf.zeros(
          obs_shape[:1] + [num_target_frames] + [action_space.n]),
      "target_value": tf.zeros(obs_shape[:1] + [num_target_frames]),
  }
  with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
    t2t_model.create_dummy_vars()
    (targets, _) = model(features)
  # Only the first target frame's policy and value are returned to the caller.
  return (targets["target_policy"][:, 0, :], targets["target_value"][:, 0])