def test_update_policy_state_with_trajectories_that_do_not_terminate(
        observation_space, action_space, horizon):
    """
    The cross-entropy method policy state updater should reject trajectories that do not end in
    a terminal state.
    """
    policy_state_updater = CrossEntropyMethodPolicyStateUpdater(
        num_elites=1, learning_rate=0.1)
    policy_state = None
    trajectory = generate_dummy_trajectories(observation_space,
                                             action_space,
                                             batch_size=1,
                                             trajectory_length=2)
    with pytest.raises(AssertionError) as excinfo:
        policy_state_updater.update(policy_state, trajectory, 1)

    assert "must end in a terminal state" in str(excinfo)
def test_update_policy_state_with_trajectories_that_reset_mid_way(
        observation_space, action_space, horizon):
    """
    The cross-entropy method policy state updater should reject trajectories that reset part-way
    through, i.e. whose terminal step is not the final step.
    """
    policy_state_updater = CrossEntropyMethodPolicyStateUpdater(
        num_elites=1, learning_rate=0.1)
    policy_state = None
    trajectory = generate_dummy_trajectories(observation_space,
                                             action_space,
                                             batch_size=1,
                                             trajectory_length=2)
    trajectory = trajectory.replace(
        step_type=tf.constant([[StepType.LAST, StepType.FIRST]]))
    with pytest.raises(AssertionError) as excinfo:
        policy_state_updater.update(policy_state, trajectory, 1)

    assert "contain a terminal state" in str(excinfo)
def test_train_method_increments_counter_for_generic_background_planning(
        mocker, agent_class):
    """
    The docstring for the `_train` method of a TFAgent requires that the implementation increments
    the `train_step_counter`.
    """
    population_size = 1
    horizon = 10
    model_free_training_iterations = 1

    mf_agent = create_mock_model_free_agent(mocker, TIMESTEP_SPEC, ACTION_SPEC,
                                            agent_class)
    network = LinearTransitionNetwork(OBSERVATION_SPEC)
    transition_model = KerasTransitionModel([network], OBSERVATION_SPEC,
                                            ACTION_SPEC)
    reward_model = ConstantReward(OBSERVATION_SPEC, ACTION_SPEC)
    initial_state_model = create_uniform_initial_state_distribution(
        OBSERVATION_SPEC)

    train_step_counter = common.create_variable("train_step_counter",
                                                shape=(),
                                                dtype=tf.float64)
    model_based_agent = BackgroundPlanningAgent(
        (transition_model, TransitionModelTrainingSpec(1, 1)),
        reward_model,
        initial_state_model,
        mf_agent,
        population_size,
        horizon,
        model_free_training_iterations,
        train_step_counter=train_step_counter,
    )

    dummy_trajectories = generate_dummy_trajectories(
        OBSERVATION_SPEC,
        ACTION_SPEC,
        batch_size=population_size,
        trajectory_length=horizon)
    train_kwargs = {
        TRAIN_ARGSPEC_COMPONENT_ID: EnvironmentModelComponents.TRANSITION.value
    }
    model_based_agent.train(dummy_trajectories, **train_kwargs)

    assert train_step_counter.value() == 1
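

# The module-level constants used above (OBSERVATION_SPEC, ACTION_SPEC, TIMESTEP_SPEC) are not
# shown in this excerpt. A minimal sketch of how they could be built with tf_agents specs; the
# shapes and bounds are illustrative assumptions, and in the real suite these definitions would
# sit next to the imports rather than mid-module:
import tensorflow as tf
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

OBSERVATION_SPEC = tensor_spec.TensorSpec((1,), tf.float32, name="observation")
ACTION_SPEC = tensor_spec.BoundedTensorSpec(
    (1,), tf.float32, minimum=-1.0, maximum=1.0, name="action")
TIMESTEP_SPEC = ts.time_step_spec(OBSERVATION_SPEC)

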
def test_train_method_increments_counter_for_model_free_supported_agents(
    mocker, agent_class, train_component
):
    """
    The docstring for the `_train` method of a TFAgent requires that the implementation increments
    the `train_step_counter`.
    """
    population_size = 1
    number_of_particles = 1
    horizon = 10

    mf_agent = create_mock_model_free_agent(mocker, TIMESTEP_SPEC, ACTION_SPEC, agent_class)
    trajectory_optimiser = random_shooting_trajectory_optimisation(
        TIMESTEP_SPEC, ACTION_SPEC, horizon, population_size, number_of_particles
    )
    network = LinearTransitionNetwork(OBSERVATION_SPEC)
    transition_model = KerasTransitionModel([network], OBSERVATION_SPEC, ACTION_SPEC)
    reward_model = ConstantReward(OBSERVATION_SPEC, ACTION_SPEC)
    initial_state_model = create_uniform_initial_state_distribution(OBSERVATION_SPEC)

    train_step_counter = common.create_variable(
        "train_step_counter", shape=(), dtype=tf.float64
    )
    agent = ModelFreeSupportedDecisionTimePlanningAgent(
        TIMESTEP_SPEC,
        ACTION_SPEC,
        (transition_model, TransitionModelTrainingSpec(1, 1)),
        reward_model,
        initial_state_model,
        trajectory_optimiser,
        mf_agent,
        train_step_counter=train_step_counter,
    )

    dummy_trajectories = generate_dummy_trajectories(
        OBSERVATION_SPEC, ACTION_SPEC, batch_size=population_size, trajectory_length=horizon
    )
    train_kwargs = {TRAIN_ARGSPEC_COMPONENT_ID: train_component.value}
    agent.train(dummy_trajectories, **train_kwargs)

    assert train_step_counter.value() == 1
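

# The `agent_class` and `train_component` fixtures used above are also not part of this excerpt.
# A plausible sketch, parametrising over stock tf_agents agent classes and over the one
# environment-model component that appears in these tests; the concrete parameter lists in the
# real suite may well differ:
import pytest
from tf_agents.agents.ppo.ppo_agent import PPOAgent
from tf_agents.agents.sac.sac_agent import SacAgent


@pytest.fixture(name="agent_class", params=[PPOAgent, SacAgent])
def _agent_class_fixture(request):
    # Only the class object is needed here; `create_mock_model_free_agent` wraps it in a mock.
    return request.param


@pytest.fixture(name="train_component", params=[EnvironmentModelComponents.TRANSITION])
def _train_component_fixture(request):
    return request.param

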
def test_extract_transitions_from_trajectories(observation_space, action_space,
                                               batch_size, trajectory_length,
                                               predict_state_difference):
    """
    Transitions extracted from dummy trajectories should be batched tensors matching the
    observation, action and reward specs, with one transition per consecutive pair of time steps.
    """
    trajectories = generate_dummy_trajectories(observation_space, action_space,
                                               batch_size, trajectory_length)
    transitions = extract_transitions_from_trajectories(
        trajectories, observation_space, action_space,
        predict_state_difference)

    observation = transitions.observation
    action = transitions.action
    reward = transitions.reward
    next_observation = transitions.next_observation

    assert is_batched_nested_tensors(
        tensors=[observation, action, reward, next_observation],
        specs=[observation_space, action_space, RewardSpec, observation_space],
    )

    assert (observation.shape[0] == action.shape[0] == reward.shape[0] ==
            next_observation.shape[0] ==
            (batch_size * (trajectory_length - 1)))
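

# Why `batch_size * (trajectory_length - 1)`: each trajectory of length T yields T - 1
# (observation, action, reward, next_observation) transitions, one per consecutive pair of time
# steps, and the extracted transitions are flattened across the batch. For example, with
# batch_size=2 and trajectory_length=3 the extraction is expected to produce 2 * (3 - 1) == 4
# transitions.

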
def test_train_oracle_transition_model():
    """
    Ensure that a non-trainable oracle transition model does not cause the agent `train` method to
    fail.
    """
    population_size = 1
    number_of_particles = 1
    horizon = 10

    trajectory_optimiser = random_shooting_trajectory_optimisation(
        TIMESTEP_SPEC, ACTION_SPEC, horizon, population_size, number_of_particles
    )
    transition_model = StubTransitionModel(OBSERVATION_SPEC, ACTION_SPEC)
    reward_model = ConstantReward(OBSERVATION_SPEC, ACTION_SPEC)
    initial_state_model = create_uniform_initial_state_distribution(OBSERVATION_SPEC)

    train_step_counter = common.create_variable(
        "train_step_counter", shape=(), dtype=tf.float64
    )
    with pytest.warns(RuntimeWarning):
        agent = DecisionTimePlanningAgent(
            TIMESTEP_SPEC,
            ACTION_SPEC,
            transition_model,
            reward_model,
            initial_state_model,
            trajectory_optimiser,
            train_step_counter=train_step_counter,
        )

    dummy_trajectories = generate_dummy_trajectories(
        OBSERVATION_SPEC, ACTION_SPEC, batch_size=population_size, trajectory_length=horizon
    )
    train_kwargs = {TRAIN_ARGSPEC_COMPONENT_ID: EnvironmentModelComponents.TRANSITION.value}
    loss_info = agent.train(dummy_trajectories, **train_kwargs)

    assert loss_info.loss is None
    assert loss_info.extra is None