def test_training_scheduler_calls_more_than_one_train():
    """
    Ensure the training scheduler trains each component in a multi-component
    agent.
    """
    components = [
        MultiComponentAgent.COMPONENT_1,
        MultiComponentAgent.COMPONENT_2,
        MultiComponentAgent.COMPONENT_3,
    ]
    # Every component is scheduled to train on every environment step.
    schedule = {
        component: TrainingDefinition(
            1, IdentifiableComponentTrainer(component.name))
        for component in components
    }
    training_scheduler = TFTrainingScheduler(schedule)

    loss_dictionary = training_scheduler.maybe_train(
        tf.ones(tuple(), dtype=tf.int64))

    # Each component's trainer tags its loss with the component name.
    for component in components:
        assert loss_dictionary[component].extra == component.name
def create_training_scheduler(
        self, agent: TFAgent,
        replay_buffer: ReplayBuffer) -> TFTrainingScheduler:
    """
    Build a scheduler that trains the three components of a multi-component
    agent at intervals of 1, 2 and 3 environment steps respectively.
    """
    intervals = {
        MultiComponentAgent.COMPONENT_1: 1,
        MultiComponentAgent.COMPONENT_2: 2,
        MultiComponentAgent.COMPONENT_3: 3,
    }
    schedule = {
        component: TrainingDefinition(
            interval, IdentifiableComponentTrainer(component.name))
        for component, interval in intervals.items()
    }
    return TFTrainingScheduler(cast(Dict[Enum, TrainingDefinition], schedule))
def create_training_scheduler(
        self, agent: TFAgent,
        replay_buffer: ReplayBuffer) -> TFTrainingScheduler:
    """
    Build a scheduler for an off-policy model-free agent that trains on
    mini-batches sampled from the replay buffer every
    `self._steps_per_policy_update` environment steps.
    """
    assert isinstance(agent, (DdpgAgent, Td3Agent, SacAgent)), "Expect an off-policy model free agent."

    def _not_boundary(trajectories, _):
        # Drop sampled transitions that straddle an episode boundary.
        return ~trajectories.is_boundary()[0]

    dataset = (replay_buffer.as_dataset(
        num_steps=2).filter(_not_boundary).batch(
            self._training_data_batch_size))
    iterator = iter(dataset)

    def train_step():
        # BUG FIX: the previous guard used
        # `tf.data.experimental.cardinality(dataset).numpy()`, but any
        # dataset containing a `filter` transformation reports
        # UNKNOWN_CARDINALITY (-2), so the comparison against a positive
        # batch size could never hold and the agent was never trained.
        # Check the replay buffer's frame count directly instead.
        if replay_buffer.num_frames() >= self._training_data_batch_size:
            experience, _ = next(iterator)
            return agent.train(experience)
        # Not enough data collected yet: report empty losses.
        return LossInfo(None, None)

    schedule = {
        ModelFreeAgentComponent.MODEL_FREE_AGENT:
            TrainingDefinition(self._steps_per_policy_update, train_step)
    }
    return TFTrainingScheduler(
        cast(Dict[Enum, TrainingDefinition], schedule))
def test_training_scheduler_resets_step_counter():
    """
    Ensure the training scheduler trains a single component agent at the
    specified schedule.
    """
    component = SingleComponentAgent.COMPONENT
    # Train every second environment step.
    schedule = {
        component: TrainingDefinition(
            2, IdentifiableComponentTrainer(component.name))
    }
    training_scheduler = TFTrainingScheduler(schedule)

    results = [
        training_scheduler.maybe_train(step * tf.ones(tuple(), dtype=tf.int64))
        for step in (1, 2, 3, 4)
    ]

    # Odd steps fall between scheduled updates: nothing is trained.
    assert not results[0]
    assert not results[2]
    # Steps 2 and 4 both trigger training, proving the counter was reset
    # after the first update.
    assert results[1][component].extra == component.name
    assert results[3][component].extra == component.name
def test_training_scheduler_resets_one_step_counter_of_several():
    """
    Ensure that the training scheduler trains each component of a
    multi-component agent at the specified schedule.
    """
    intervals = {
        MultiComponentAgent.COMPONENT_1: 1,
        MultiComponentAgent.COMPONENT_2: 2,
        MultiComponentAgent.COMPONENT_3: 3,
    }
    schedule = {
        component: TrainingDefinition(
            interval, IdentifiableComponentTrainer(component.name))
        for component, interval in intervals.items()
    }
    training_scheduler = TFTrainingScheduler(schedule)

    loss_dictionary_1 = training_scheduler.maybe_train(
        1 * tf.ones(tuple(), dtype=tf.int64))
    loss_dictionary_2 = training_scheduler.maybe_train(
        2 * tf.ones(tuple(), dtype=tf.int64))
    loss_dictionary_3 = training_scheduler.maybe_train(
        3 * tf.ones(tuple(), dtype=tf.int64))

    # Step 1: only the every-step component trains.
    assert len(loss_dictionary_1) == 1
    assert (loss_dictionary_1[MultiComponentAgent.COMPONENT_1].extra ==
            MultiComponentAgent.COMPONENT_1.name)
    # Step 2: components with intervals 1 and 2 train.
    assert len(loss_dictionary_2) == 2
    assert (loss_dictionary_2[MultiComponentAgent.COMPONENT_1].extra ==
            MultiComponentAgent.COMPONENT_1.name)
    assert (loss_dictionary_2[MultiComponentAgent.COMPONENT_2].extra ==
            MultiComponentAgent.COMPONENT_2.name)
    # Step 3: components 1 and 3 train; component 2's counter was reset at
    # step 2, so it does not.
    assert len(loss_dictionary_3) == 2
    assert (loss_dictionary_3[MultiComponentAgent.COMPONENT_1].extra ==
            MultiComponentAgent.COMPONENT_1.name)
    assert (loss_dictionary_3[MultiComponentAgent.COMPONENT_3].extra ==
            MultiComponentAgent.COMPONENT_3.name)
def create_training_scheduler(
        self, agent: TFAgent,
        replay_buffer: ReplayBuffer) -> TFTrainingScheduler:
    """
    Build a scheduler that trains both the transition model and the
    model-free component of a background planning agent, each at its own
    interval. The model-free component only starts training once the
    transition model has been trained at least once.
    """
    assert isinstance(
        agent, BackgroundPlanningAgent), "Expect a `BackgroundPlanningAgent`."

    def train_transition_model_step() -> LossInfo:
        # Mark the transition model as trained so the model-free agent may
        # start its own updates from now on.
        self._has_transition_model_been_trained = True
        return agent.train(
            replay_buffer.gather_all(),
            **{TRAIN_ARGSPEC_COMPONENT_ID:
               EnvironmentModelComponents.TRANSITION.value})

    def train_model_free_agent_step() -> LossInfo:
        # Until the transition model has been trained once, report empty
        # losses instead of training on an untrained model.
        if not self._has_transition_model_been_trained:
            return LossInfo(None, None)
        return agent.train(
            replay_buffer.gather_all(),
            **{TRAIN_ARGSPEC_COMPONENT_ID:
               ModelFreeAgentComponent.MODEL_FREE_AGENT.value})

    schedule = {
        EnvironmentModelComponents.TRANSITION:
            TrainingDefinition(self._steps_per_transition_model_update,
                               train_transition_model_step),
        ModelFreeAgentComponent.MODEL_FREE_AGENT:
            TrainingDefinition(self._steps_per_model_free_agent_update,
                               train_model_free_agent_step),
    }
    return TFTrainingScheduler(
        cast(Dict[Enum, TrainingDefinition], schedule))
def create_training_scheduler(
        self, agent: TFAgent,
        replay_buffer: ReplayBuffer) -> TFTrainingScheduler:
    """
    Build a scheduler that trains the single component every
    `self._interval` environment steps.
    """
    trainer = IdentifiableComponentTrainer(SingleComponentAgent.COMPONENT.name)
    schedule = {
        SingleComponentAgent.COMPONENT:
            TrainingDefinition(self._interval, trainer)
    }
    return TFTrainingScheduler(cast(Dict[Enum, TrainingDefinition], schedule))
def test_training_info_does_not_contain_none_losses():
    """
    A train step that yields `LossInfo(None, None)` must not appear in the
    loss dictionary returned by the scheduler.
    """
    def _none_returning_train_step():
        return LossInfo(loss=None, extra=None)

    training_scheduler = TFTrainingScheduler({
        SingleComponentAgent.COMPONENT:
            TrainingDefinition(1, _none_returning_train_step)
    })

    loss_dictionary = training_scheduler.maybe_train(
        tf.ones(tuple(), dtype=tf.int64))

    # `None` losses are filtered out, leaving an empty dictionary.
    assert not loss_dictionary
def test_training_scheduler_calls_train():
    """
    Ensure that the training scheduler trains a single component agent.
    """
    component = SingleComponentAgent.COMPONENT
    training_scheduler = TFTrainingScheduler({
        component: TrainingDefinition(
            1, IdentifiableComponentTrainer(component.name))
    })

    loss_dictionary = training_scheduler.maybe_train(
        tf.ones(tuple(), dtype=tf.int64))

    # The trainer tags its loss with the component name.
    assert loss_dictionary[component].extra == component.name
def create_training_scheduler(
        self, agent: TFAgent,
        replay_buffer: ReplayBuffer) -> TFTrainingScheduler:
    """
    Build a scheduler for an on-policy model-free agent that trains on the
    full replay buffer every `self._steps_per_policy_update` steps.
    """
    assert isinstance(
        agent, (TRPOAgent, PPOAgent)), "Expect an on-policy model free agent."

    def train_step() -> LossInfo:
        # On-policy training consumes the whole buffer, which must then be
        # cleared so stale experience is not reused.
        experience = replay_buffer.gather_all()
        replay_buffer.clear()
        return agent.train(experience)

    return TFTrainingScheduler(
        cast(Dict[Enum, TrainingDefinition], {
            ModelFreeAgentComponent.MODEL_FREE_AGENT:
                TrainingDefinition(self._steps_per_policy_update, train_step)
        }))
def create_training_scheduler(
        self, agent: TFAgent,
        replay_buffer: ReplayBuffer) -> TFTrainingScheduler:
    """
    Build a scheduler that trains the transition model of a decision-time
    planning agent every `self._steps_per_transition_model_update` steps.
    """
    assert isinstance(
        agent,
        DecisionTimePlanningAgent), "Expect a `DecisionTimePlanningAgent`."

    def train_step() -> LossInfo:
        # Train only the transition-model component on all collected data.
        return agent.train(
            replay_buffer.gather_all(),
            **{TRAIN_ARGSPEC_COMPONENT_ID:
               EnvironmentModelComponents.TRANSITION.value})

    schedule = {
        EnvironmentModelComponents.TRANSITION:
            TrainingDefinition(self._steps_per_transition_model_update,
                               train_step)
    }
    return TFTrainingScheduler(
        cast(Dict[Enum, TrainingDefinition], schedule))