Example #1
0
def test_training_scheduler_calls_more_than_one_train():
    """
    Ensure the training scheduler trains each component in a multi-component agent.
    """
    components = [
        MultiComponentAgent.COMPONENT_1,
        MultiComponentAgent.COMPONENT_2,
        MultiComponentAgent.COMPONENT_3,
    ]
    # Every component trains on every step (period 1), each with a trainer that
    # tags its loss with the component's name.
    schedule = {
        component: TrainingDefinition(
            1, IdentifiableComponentTrainer(component.name))
        for component in components
    }

    scheduler = TFTrainingScheduler(schedule)
    loss_dictionary = scheduler.maybe_train(tf.ones(tuple(), dtype=tf.int64))

    # Each component must have been trained and reported under its own name.
    for component in components:
        assert loss_dictionary[component].extra == component.name
Example #2
0
 def create_training_scheduler(
     self, agent: TFAgent, replay_buffer: ReplayBuffer
 ) -> TFTrainingScheduler:
     """
     Build a scheduler that trains the three agent components at training
     periods 1, 2 and 3 respectively, each with a trainer that tags its loss
     with the component's name.
     """
     components_and_periods = (
         (MultiComponentAgent.COMPONENT_1, 1),
         (MultiComponentAgent.COMPONENT_2, 2),
         (MultiComponentAgent.COMPONENT_3, 3),
     )
     schedule = {
         component: TrainingDefinition(
             period, IdentifiableComponentTrainer(component.name)
         )
         for component, period in components_and_periods
     }
     return TFTrainingScheduler(cast(Dict[Enum, TrainingDefinition], schedule))
Example #3
0
    def create_training_scheduler(
            self, agent: TFAgent,
            replay_buffer: ReplayBuffer) -> TFTrainingScheduler:
        """
        Build a scheduler that trains an off-policy model-free agent from the
        replay buffer every `self._steps_per_policy_update` steps.
        """
        assert isinstance(agent,
                          (DdpgAgent, Td3Agent,
                           SacAgent)), "Expect an off-policy model free agent."

        def _not_boundary(trajectories, _):
            # Keep only samples whose first step is not an episode boundary,
            # so a two-step sample never straddles two episodes.
            return ~trajectories.is_boundary()[0]

        # Sample two-step transitions from the replay buffer, drop
        # boundary-straddling ones, and batch them for training.
        dataset = (replay_buffer.as_dataset(
            num_steps=2).filter(_not_boundary).batch(
                self._training_data_batch_size))
        iterator = iter(dataset)

        def train_step():
            # NOTE(review): `cardinality` returns negative sentinel values
            # (UNKNOWN/INFINITE) for filtered or replay-buffer-backed
            # datasets, in which case this guard would never pass — confirm
            # this behaves as intended for the buffers used here.
            if (tf.data.experimental.cardinality(dataset).numpy() >=
                    self._training_data_batch_size):
                experience, _ = next(iterator)
                return agent.train(experience)
            # Not enough data collected yet: report empty losses instead of
            # training.
            return LossInfo(None, None)

        schedule = {
            ModelFreeAgentComponent.MODEL_FREE_AGENT:
            TrainingDefinition(self._steps_per_policy_update, train_step)
        }
        return TFTrainingScheduler(
            cast(Dict[Enum, TrainingDefinition], schedule))
Example #4
0
def test_training_scheduler_resets_step_counter():
    """
    Ensure the training scheduler trains a single component agent at the specified schedule.
    """
    component = SingleComponentAgent.COMPONENT
    # Train the single component every second step.
    schedule = {
        component: TrainingDefinition(
            2, IdentifiableComponentTrainer(component.name))
    }

    scheduler = TFTrainingScheduler(schedule)
    losses_per_step = [
        scheduler.maybe_train(step * tf.ones(tuple(), dtype=tf.int64))
        for step in (1, 2, 3, 4)
    ]

    # Steps 1 and 3 fall between training intervals, so nothing is trained;
    # steps 2 and 4 trigger training, proving the counter resets after step 2.
    assert not losses_per_step[0]
    assert not losses_per_step[2]
    assert losses_per_step[1][component].extra == component.name
    assert losses_per_step[3][component].extra == component.name
Example #5
0
def test_training_scheduler_resets_one_step_counter_of_several():
    """
    Ensure that the training scheduler trains each component of a multi-component agent at the
    specified schedule.
    """
    periods = {
        MultiComponentAgent.COMPONENT_1: 1,
        MultiComponentAgent.COMPONENT_2: 2,
        MultiComponentAgent.COMPONENT_3: 3,
    }
    schedule = {
        component: TrainingDefinition(
            period, IdentifiableComponentTrainer(component.name))
        for component, period in periods.items()
    }

    scheduler = TFTrainingScheduler(schedule)
    losses_per_step = [
        scheduler.maybe_train(step * tf.ones(tuple(), dtype=tf.int64))
        for step in (1, 2, 3)
    ]

    # With periods 1/2/3, component 1 trains every step, component 2 on step 2
    # only, and component 3 on step 3 only.
    expected_per_step = [
        [MultiComponentAgent.COMPONENT_1],
        [MultiComponentAgent.COMPONENT_1, MultiComponentAgent.COMPONENT_2],
        [MultiComponentAgent.COMPONENT_1, MultiComponentAgent.COMPONENT_3],
    ]
    for loss_dictionary, expected in zip(losses_per_step, expected_per_step):
        assert len(loss_dictionary) == len(expected)
        for component in expected:
            assert loss_dictionary[component].extra == component.name
Example #6
0
    def create_training_scheduler(
            self, agent: TFAgent,
            replay_buffer: ReplayBuffer) -> TFTrainingScheduler:
        """
        Build a scheduler that alternates between training the transition
        model and the model-free agent of a background planning agent, each
        at its own period. The model-free agent is only trained once the
        transition model has been trained at least once.
        """
        assert isinstance(
            agent,
            BackgroundPlanningAgent), "Expect a `BackgroundPlanningAgent`."

        def _train_transition_model() -> LossInfo:
            # Record that the transition model was trained so the model-free
            # agent is allowed to start training.
            self._has_transition_model_been_trained = True
            return agent.train(
                replay_buffer.gather_all(),
                **{
                    TRAIN_ARGSPEC_COMPONENT_ID:
                    EnvironmentModelComponents.TRANSITION.value
                })

        def _train_model_free_agent() -> LossInfo:
            # Skip training until the transition model has been trained once.
            if not self._has_transition_model_been_trained:
                return LossInfo(None, None)
            return agent.train(
                replay_buffer.gather_all(),
                **{
                    TRAIN_ARGSPEC_COMPONENT_ID:
                    ModelFreeAgentComponent.MODEL_FREE_AGENT.value
                })

        schedule = {
            EnvironmentModelComponents.TRANSITION:
            TrainingDefinition(self._steps_per_transition_model_update,
                               _train_transition_model),
            ModelFreeAgentComponent.MODEL_FREE_AGENT:
            TrainingDefinition(self._steps_per_model_free_agent_update,
                               _train_model_free_agent),
        }
        return TFTrainingScheduler(
            cast(Dict[Enum, TrainingDefinition], schedule))
Example #7
0
 def create_training_scheduler(
     self, agent: TFAgent, replay_buffer: ReplayBuffer
 ) -> TFTrainingScheduler:
     """
     Build a scheduler that trains the single agent component every
     `self._interval` steps with a trainer tagged by the component's name.
     """
     component = SingleComponentAgent.COMPONENT
     definition = TrainingDefinition(
         self._interval, IdentifiableComponentTrainer(component.name)
     )
     return TFTrainingScheduler(
         cast(Dict[Enum, TrainingDefinition], {component: definition})
     )
Example #8
0
def test_training_info_does_not_contain_none_losses():
    """A train step returning all-`None` losses must not appear in the result."""
    def _none_returning_train_step():
        return LossInfo(loss=None, extra=None)

    scheduler = TFTrainingScheduler({
        SingleComponentAgent.COMPONENT:
        TrainingDefinition(1, _none_returning_train_step)
    })

    loss_dictionary = scheduler.maybe_train(tf.ones(tuple(), dtype=tf.int64))
    assert not loss_dictionary  # loss_dictionary should be empty
Example #9
0
def test_training_scheduler_calls_train():
    """
    Ensure that the training scheduler trains a single component agent.
    """
    component = SingleComponentAgent.COMPONENT
    scheduler = TFTrainingScheduler({
        component:
        TrainingDefinition(1, IdentifiableComponentTrainer(component.name))
    })

    loss_dictionary = scheduler.maybe_train(tf.ones(tuple(), dtype=tf.int64))

    # The trainer tags its loss with the component's name, proving it ran.
    assert loss_dictionary[component].extra == component.name
Example #10
0
    def create_training_scheduler(
            self, agent: TFAgent,
            replay_buffer: ReplayBuffer) -> TFTrainingScheduler:
        """
        Build a scheduler that trains an on-policy model-free agent on the
        whole replay buffer every `self._steps_per_policy_update` steps.
        """
        assert isinstance(
            agent,
            (TRPOAgent, PPOAgent)), "Expect an on-policy model free agent."

        def _train_on_collected_experience() -> LossInfo:
            # On-policy training consumes the whole buffer and then discards
            # it, since stale experience cannot be reused.
            experience = replay_buffer.gather_all()
            replay_buffer.clear()
            return agent.train(experience)

        return TFTrainingScheduler(
            cast(
                Dict[Enum, TrainingDefinition], {
                    ModelFreeAgentComponent.MODEL_FREE_AGENT:
                    TrainingDefinition(self._steps_per_policy_update,
                                       _train_on_collected_experience)
                }))
    def create_training_scheduler(
            self, agent: TFAgent,
            replay_buffer: ReplayBuffer) -> TFTrainingScheduler:
        """
        Build a scheduler that trains the transition model of a decision-time
        planning agent on the whole replay buffer every
        `self._steps_per_transition_model_update` steps.
        """
        assert isinstance(
            agent,
            DecisionTimePlanningAgent), "Expect a `DecisionTimePlanningAgent`."

        def _train_transition_model() -> LossInfo:
            # Route the call to the transition-model component of the agent.
            kwargs = {
                TRAIN_ARGSPEC_COMPONENT_ID:
                EnvironmentModelComponents.TRANSITION.value
            }
            return agent.train(replay_buffer.gather_all(), **kwargs)

        schedule = {
            EnvironmentModelComponents.TRANSITION:
            TrainingDefinition(self._steps_per_transition_model_update,
                               _train_transition_model)
        }
        return TFTrainingScheduler(
            cast(Dict[Enum, TrainingDefinition], schedule))