def _create_dqn_agent(time_step_spec, action_spec, policy_network):
    """Builds a DQN agent around a network produced by `policy_network`.

    The observation-processing layer creator from `feature_ops` is applied
    to every leaf of the observation spec. The resulting structure is passed
    to `policy_network` as its preprocessing layers.

    Args:
      time_step_spec: Time step spec. Its `.observation` field feeds both the
        preprocessing-layer construction and the network input spec.
      action_spec: Action spec, forwarded to the network and to the agent.
      policy_network: Callable that builds the Q-network. It is invoked as
        `policy_network(observation_spec, action_spec,
        preprocessing_layers=..., name='QNetwork')`.

    Returns:
      A `dqn_agent.DqnAgent` that uses the constructed network as its
      `q_network`.
    """
    make_layer = feature_ops.get_observation_processing_layer_creator()
    observation_spec = time_step_spec.observation
    preprocessing = tf.nest.map_structure(make_layer, observation_spec)

    q_net = policy_network(
        observation_spec,
        action_spec,
        preprocessing_layers=preprocessing,
        name='QNetwork',
    )
    # NOTE(review): no optimizer is passed explicitly; presumably it is bound
    # through external configuration (e.g. gin) -- confirm.
    return dqn_agent.DqnAgent(time_step_spec, action_spec, q_network=q_net)
def _create_behavioral_cloning_agent(time_step_spec, action_spec,
                                     policy_network):
    """Builds a behavioral-cloning agent around a `policy_network` network.

    Per-leaf preprocessing layers are derived from the observation spec via
    the `feature_ops` layer creator. They are then handed to `policy_network`
    to construct the cloning network.

    Args:
      time_step_spec: Time step spec. Its `.observation` field is used for the
        preprocessing layers and as the network input spec.
      action_spec: Action spec, forwarded to the network and to the agent.
      policy_network: Callable that builds the cloning network. It is invoked
        as `policy_network(observation_spec, action_spec,
        preprocessing_layers=..., name='QNetwork')`.

    Returns:
      A `behavioral_cloning_agent.BehavioralCloningAgent` built with
      `num_outer_dims=2`.
    """
    make_layer = feature_ops.get_observation_processing_layer_creator()
    observation_spec = time_step_spec.observation
    preprocessing = tf.nest.map_structure(make_layer, observation_spec)

    # The network name is 'QNetwork' even though this is a cloning network.
    # It is kept unchanged, possibly for variable/checkpoint name
    # compatibility -- confirm before renaming.
    cloning_net = policy_network(
        observation_spec,
        action_spec,
        preprocessing_layers=preprocessing,
        name='QNetwork',
    )
    # NOTE(review): num_outer_dims=2 presumably means inputs carry both batch
    # and time dimensions -- confirm against the training data pipeline.
    return behavioral_cloning_agent.BehavioralCloningAgent(
        time_step_spec,
        action_spec,
        cloning_network=cloning_net,
        num_outer_dims=2,
    )
def _create_ppo_agent(time_step_spec, action_spec, policy_network):
    """Builds a PPO agent with a `policy_network` actor and a constant critic.

    The actor receives per-leaf preprocessing layers derived from the
    observation spec via the `feature_ops` layer creator. The critic is a
    `ConstantValueNetwork` over the same observation spec and is given no
    preprocessing layers.

    Args:
      time_step_spec: Time step spec. Its `.observation` field is used for the
        preprocessing layers, the actor input spec and the critic input spec.
      action_spec: Action spec, forwarded to the actor and to the agent.
      policy_network: Callable that builds the actor. It is invoked as
        `policy_network(observation_spec, action_spec,
        preprocessing_layers=..., name='ActorDistributionNetwork')`.

    Returns:
      A `ppo_agent.PPOAgent` using the constructed actor as `actor_net` and
      the constant-value network as `value_net`.
    """
    make_layer = feature_ops.get_observation_processing_layer_creator()
    observation_spec = time_step_spec.observation
    preprocessing = tf.nest.map_structure(make_layer, observation_spec)

    actor = policy_network(
        observation_spec,
        action_spec,
        preprocessing_layers=preprocessing,
        name='ActorDistributionNetwork',
    )
    critic = constant_value_network.ConstantValueNetwork(
        observation_spec, name='ConstantValueNetwork')

    return ppo_agent.PPOAgent(
        time_step_spec,
        action_spec,
        actor_net=actor,
        value_net=critic,
    )