Example #1
def get_policy(observations, hparams, action_space):
    """Get a policy network.

  Args:
    observations: observations
    hparams: parameters
    action_space: action space

  Returns:
    Tuple (action logits, value).
  """
    if not isinstance(action_space, gym.spaces.Discrete):
        raise ValueError("Expecting discrete action space.")

    policy_problem = DummyPolicyProblem(action_space)
    trainer_lib.add_problem_hparams(hparams, policy_problem)
    hparams.force_full_predict = True
    model = registry.model(hparams.policy_network)(hparams,
                                                   tf.estimator.ModeKeys.TRAIN)
    obs_shape = common_layers.shape_list(observations)
    features = {
        "inputs": observations,
        "input_action": tf.zeros(obs_shape[:2] + [1], dtype=tf.int32),
        "input_reward": tf.zeros(obs_shape[:2] + [1], dtype=tf.int32),
        "targets": tf.zeros(obs_shape[:1] + [1] + obs_shape[2:]),
        "target_action": tf.zeros(obs_shape[:1] + [1, 1], dtype=tf.int32),
        "target_reward": tf.zeros(obs_shape[:1] + [1, 1], dtype=tf.int32),
        "target_policy": tf.zeros(obs_shape[:1] + [1] + [action_space.n]),
        "target_value": tf.zeros(obs_shape[:1] + [1])
    }
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        t2t_model.create_dummy_vars()
        (targets, _) = model(features)
    return (targets["target_policy"], targets["target_value"])
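
A minimal call sketch for the snippet above, assuming it sits in its original tensor2tensor module so that DummyPolicyProblem and the module-level imports (registry, trainer_lib, common_layers, t2t_model) are in scope. The observation shape, the action count, and make_policy_hparams are illustrative assumptions only; just the get_policy call itself comes from the code above.

import gym
import tensorflow as tf

# Hypothetical stand-in for however the surrounding setup builds its RL hparams;
# it must return an HParams whose `policy_network` names a registered T2T model.
hparams = make_policy_hparams()

action_space = gym.spaces.Discrete(6)  # e.g. a small Atari-like action set
# TF1 graph mode: a [batch, history, height, width, channels] frame-stack placeholder.
observations = tf.placeholder(tf.float32, [None, 4, 84, 84, 3])

logits, value = get_policy(observations, hparams, action_space)
# Both outputs still carry the single-frame time axis here (Example #2 below
# slices it off); an action can be sampled from the logits for frame 0.
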
Example #2
def get_policy(observations, hparams, action_space):
    """Get a policy network.

  Args:
    observations: observations
    hparams: parameters
    action_space: action space

  Returns:
    Tuple (action logits, value).
  """
    if not isinstance(action_space, gym.spaces.Discrete):
        raise ValueError("Expecting discrete action space.")

    obs_shape = common_layers.shape_list(observations)
    (frame_height, frame_width) = obs_shape[2:4]

    # TODO(afrozm): We have these dummy problems mainly for hparams, so cleanup
    # when possible and do this properly.
    if hparams.policy_problem_name == "dummy_policy_problem_ttt":
        tf.logging.info("Using DummyPolicyProblemTTT for the policy.")
        policy_problem = tic_tac_toe_env.DummyPolicyProblemTTT()
    else:
        tf.logging.info("Using DummyPolicyProblem for the policy.")
        policy_problem = DummyPolicyProblem(action_space, frame_height,
                                            frame_width)

    trainer_lib.add_problem_hparams(hparams, policy_problem)
    hparams.force_full_predict = True
    model = registry.model(hparams.policy_network)(hparams,
                                                   tf.estimator.ModeKeys.TRAIN)
    # Some hparams sets do not define video_num_target_frames; fall back to a
    # single target frame in that case.
    try:
        num_target_frames = hparams.video_num_target_frames
    except AttributeError:
        num_target_frames = 1
    features = {
        "inputs": observations,
        "input_action": tf.zeros(obs_shape[:2] + [1], dtype=tf.int32),
        "input_reward": tf.zeros(obs_shape[:2] + [1], dtype=tf.int32),
        "targets": tf.zeros(
            obs_shape[:1] + [num_target_frames] + obs_shape[2:]),
        "target_action": tf.zeros(
            obs_shape[:1] + [num_target_frames, 1], dtype=tf.int32),
        "target_reward": tf.zeros(
            obs_shape[:1] + [num_target_frames, 1], dtype=tf.int32),
        "target_policy": tf.zeros(
            obs_shape[:1] + [num_target_frames] + [action_space.n]),
        "target_value": tf.zeros(obs_shape[:1] + [num_target_frames]),
    }
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        t2t_model.create_dummy_vars()
        (targets, _) = model(features)
    return (targets["target_policy"][:, 0, :], targets["target_value"][:, 0])
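
Compared with Example #1, this variant passes the observed frame height and width to DummyPolicyProblem and can switch to tic_tac_toe_env.DummyPolicyProblemTTT via hparams.policy_problem_name. It also sizes the zero-filled target features by hparams.video_num_target_frames (falling back to a single frame when the field is missing) and slices out target frame 0 before returning, so the caller gets logits and value without the time axis.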