예제 #1
0
def _create_dqn_agent(time_step_spec, action_spec, policy_network):
    """Builds a `DqnAgent` whose Q-network is produced by `policy_network`.

    Args:
      time_step_spec: Time-step spec; its `observation` attribute is used to
        build the preprocessing layers and the network, and the spec itself
        is forwarded to the agent.
      action_spec: Action spec forwarded to both the network and the agent.
      policy_network: Callable invoked as
        `policy_network(observation_spec, action_spec,
        preprocessing_layers=..., name='QNetwork')`; its result is used as
        the agent's `q_network`.

    Returns:
      The object returned by `dqn_agent.DqnAgent`.
    """
    # One preprocessing layer per leaf of the observation spec structure.
    layer_creator = feature_ops.get_observation_processing_layer_creator()
    observation_spec = time_step_spec.observation
    preprocessing_layers = tf.nest.map_structure(layer_creator,
                                                 observation_spec)

    q_network = policy_network(
        observation_spec,
        action_spec,
        preprocessing_layers=preprocessing_layers,
        name='QNetwork',
    )
    return dqn_agent.DqnAgent(
        time_step_spec,
        action_spec,
        q_network=q_network,
    )
예제 #2
0
def _create_behavioral_cloning_agent(time_step_spec, action_spec,
                                     policy_network):
    """Builds a `BehavioralCloningAgent` around a `policy_network` instance.

    Args:
      time_step_spec: Time-step spec; its `observation` attribute is used to
        build the preprocessing layers and the network, and the spec itself
        is forwarded to the agent.
      action_spec: Action spec forwarded to both the network and the agent.
      policy_network: Callable invoked as
        `policy_network(observation_spec, action_spec,
        preprocessing_layers=..., name='QNetwork')`; its result is used as
        the agent's `cloning_network`.

    Returns:
      The object returned by
      `behavioral_cloning_agent.BehavioralCloningAgent`, constructed with
      `num_outer_dims=2`.
    """
    # One preprocessing layer per leaf of the observation spec structure.
    layer_creator = feature_ops.get_observation_processing_layer_creator()
    observation_spec = time_step_spec.observation
    preprocessing_layers = tf.nest.map_structure(layer_creator,
                                                 observation_spec)

    # NOTE(review): the cloning network is named 'QNetwork', same as in the
    # DQN factory -- confirm this is intentional before renaming.
    cloning_network = policy_network(
        observation_spec,
        action_spec,
        preprocessing_layers=preprocessing_layers,
        name='QNetwork',
    )
    return behavioral_cloning_agent.BehavioralCloningAgent(
        time_step_spec,
        action_spec,
        cloning_network=cloning_network,
        num_outer_dims=2,
    )
예제 #3
0
def _create_ppo_agent(time_step_spec, action_spec, policy_network):
    """Builds a `PPOAgent` with a `policy_network` actor and a constant critic.

    Args:
      time_step_spec: Time-step spec; its `observation` attribute is used to
        build the preprocessing layers and both networks, and the spec itself
        is forwarded to the agent.
      action_spec: Action spec forwarded to both the actor network and the
        agent.
      policy_network: Callable invoked as
        `policy_network(observation_spec, action_spec,
        preprocessing_layers=..., name='ActorDistributionNetwork')`; its
        result is used as the agent's `actor_net`.

    Returns:
      The object returned by `ppo_agent.PPOAgent`.
    """
    # One preprocessing layer per leaf of the observation spec structure.
    layer_creator = feature_ops.get_observation_processing_layer_creator()
    observation_spec = time_step_spec.observation
    preprocessing_layers = tf.nest.map_structure(layer_creator,
                                                 observation_spec)

    actor_net = policy_network(
        observation_spec,
        action_spec,
        preprocessing_layers=preprocessing_layers,
        name='ActorDistributionNetwork',
    )

    # The value network is a ConstantValueNetwork built from the observation
    # spec alone; it is not given the preprocessing layers.
    value_net = constant_value_network.ConstantValueNetwork(
        observation_spec,
        name='ConstantValueNetwork',
    )

    return ppo_agent.PPOAgent(
        time_step_spec,
        action_spec,
        actor_net=actor_net,
        value_net=value_net,
    )