def _test_ensemble_model_close_to_actuals(
    trajectories, tf_env, trajectory_sampling_strategy
):
    keras_transition_networks = [
        LinearTransitionNetwork(tf_env.observation_spec(), True)
        for _ in range(_ENSEMBLE_SIZE)
    ]
    model = KerasTransitionModel(
        keras_transition_networks,
        tf_env.observation_spec(),
        tf_env.action_spec(),
        predict_state_difference=False,
        trajectory_sampling_strategy=trajectory_sampling_strategy,
    )
    training_spec = KerasTrainingSpec(
        epochs=1000,
        training_batch_size=256,
        callbacks=[tf.keras.callbacks.EarlyStopping(monitor="loss", patience=20)],
    )

    model.train(trajectories, training_spec)

    assert_rollouts_are_close_to_actuals(model, max_steps=1)


def test_fit_mountain_car_data(
    mountain_car_data, transition_network, bootstrap_data, batch_size, ensemble_size
):
    tf_env, trajectories = mountain_car_data
    network_list = [
        transition_network(tf_env.observation_spec(), bootstrap_data=bootstrap_data)
        for _ in range(ensemble_size)
    ]
    transition_model = KerasTransitionModel(
        network_list,
        tf_env.observation_spec(),
        tf_env.action_spec(),
        predict_state_difference=False,
        trajectory_sampling_strategy=OneStepTrajectorySampling(batch_size, ensemble_size),
    )
    training_spec = KerasTrainingSpec(
        epochs=10,
        training_batch_size=256,
        callbacks=[],
    )

    history = transition_model.train(trajectories, training_spec)

    assert history.history["loss"][-1] < history.history["loss"][0]


def test_step_call_shape(
    transition_network, observation_space, action_space, batch_size, ensemble_size
):
    network_list = [
        transition_network(observation_space, bootstrap_data=True)
        for _ in range(ensemble_size)
    ]
    transition_model = KerasTransitionModel(
        network_list,
        observation_space,
        action_space,
        predict_state_difference=True,
        trajectory_sampling_strategy=OneStepTrajectorySampling(batch_size, ensemble_size),
    )
    observation_distribution = create_uniform_distribution_from_spec(observation_space)
    observations = observation_distribution.sample((batch_size,))
    action_distribution = create_uniform_distribution_from_spec(action_space)
    actions = action_distribution.sample((batch_size,))

    next_observations = transition_model.step(observations, actions)

    assert next_observations.shape == (batch_size,) + observation_space.shape
    assert observation_space.is_compatible_with(next_observations[0])


def test_incorrect_termination_model():
    """
    The generic model-based agent should only allow a ConstantFalseTermination model.
    """
    # setup arguments for the model-based agent constructor
    py_env = suite_gym.load("MountainCarContinuous-v0")
    tf_env = TFPyEnvironment(py_env)
    time_step_spec = tf_env.time_step_spec()
    observation_spec = tf_env.observation_spec()
    action_spec = tf_env.action_spec()

    network = LinearTransitionNetwork(observation_spec)
    transition_model = KerasTransitionModel([network], observation_spec, action_spec)
    reward_model = MountainCarReward(observation_spec, action_spec)
    initial_state_distribution_model = MountainCarInitialState(observation_spec)
    termination_model = MountainCarTermination(observation_spec)
    policy = RandomTFPolicy(time_step_spec, action_spec)

    with pytest.raises(AssertionError) as excinfo:
        ModelBasedAgent(
            time_step_spec,
            action_spec,
            transition_model,
            reward_model,
            termination_model,
            initial_state_distribution_model,
            policy,
            policy,
        )

    assert "Only constant false termination supported" in str(excinfo.value)


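# Hedged companion sketch, not part of the original tests: the docstring above states
# that the generic model-based agent only accepts a ConstantFalseTermination model, so
# the same constructor call is expected to succeed once the termination model is swapped
# for one. The argument order simply mirrors the call in the test above; the helper name
# is illustrative only.
def _example_agent_with_constant_false_termination(
    time_step_spec,
    action_spec,
    transition_model,
    reward_model,
    initial_state_distribution_model,
    policy,
    observation_spec,
):
    return ModelBasedAgent(
        time_step_spec,
        action_spec,
        transition_model,
        reward_model,
        ConstantFalseTermination(observation_spec),
        initial_state_distribution_model,
        policy,
        policy,
    )

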
def test_batched_environment_model(observation_space, action_space, batch_size):
    transition_network = DummyEnsembleTransitionNetwork(observation_space)
    transition_model = KerasTransitionModel(
        [transition_network],
        observation_space,
        action_space,
    )
    reward = ConstantReward(observation_space, action_space, 0.0)
    termination = ConstantFalseTermination(observation_space)
    initial_state_sampler = create_uniform_initial_state_distribution(observation_space)

    env_model = EnvironmentModel(
        transition_model, reward, termination, initial_state_sampler, batch_size
    )

    action_distr = create_uniform_distribution_from_spec(action_space)
    single_action = action_distr.sample()
    batch_actions = tf.convert_to_tensor([single_action for _ in range(batch_size)])

    first_step = env_model.reset()
    assert (
        first_step.step_type == [StepType.FIRST for _ in range(batch_size)]
    ).numpy().all()
    assert first_step.observation.shape == [batch_size] + list(observation_space.shape)

    next_step = env_model.step(batch_actions)
    assert (
        next_step.step_type == [StepType.MID for _ in range(batch_size)]
    ).numpy().all()
    assert next_step.observation.shape == [batch_size] + list(observation_space.shape)
    assert next_step.reward.shape == [batch_size]


def _create_env_model(observation_space, action_space):
    batch_size = 3
    time_limit = 5

    terminations = MutableBatchConstantTermination(observation_space, batch_size)
    observation = create_uniform_distribution_from_spec(observation_space).sample()
    network = DummyEnsembleTransitionNetwork(observation_space)
    model = KerasTransitionModel([network], observation_space, action_space)
    env_model = TFTimeLimit(
        EnvironmentModel(
            transition_model=model,
            reward_model=ConstantReward(observation_space, action_space, -1.0),
            termination_model=terminations,
            initial_state_distribution_model=DeterministicInitialStateModel(observation),
            batch_size=batch_size,
        ),
        duration=time_limit,
    )

    actions = create_uniform_distribution_from_spec(action_space).sample((batch_size,))

    # Initial time step
    env_model.reset()

    observations = np.squeeze(
        np.repeat(np.expand_dims(observation, axis=0), batch_size, axis=0)
    )
    return terminations, observations, actions, env_model


def test_replay_actions_across_batches(
    observation_space, action_space, horizon, batch_size
):
    transition_network = DummyEnsembleTransitionNetwork(observation_space)
    transition_model = KerasTransitionModel(
        [transition_network],
        observation_space,
        action_space,
    )
    reward = ConstantReward(observation_space, action_space, 0.0)
    termination = ConstantFalseTermination(observation_space)
    initial_state_sampler = create_uniform_initial_state_distribution(observation_space)
    env_model = TFTimeLimit(
        EnvironmentModel(
            transition_model, reward, termination, initial_state_sampler, batch_size
        ),
        horizon,
    )

    # sample a sequence of actions from the action spec, one per planning step
    actions_distribution = create_uniform_distribution_from_spec(action_space)
    actions = actions_distribution.sample((horizon,))

    trajectory = replay_actions_across_batch_transition_models(env_model, actions)

    assert (
        trajectory.observation.shape
        == (batch_size, horizon) + observation_space.shape
    )


def get_optimiser_and_environment_model(
    time_step_space,
    observation_space,
    action_space,
    population_size,
    number_of_particles,
    horizon,
    optimiser_policy_trajectory_optimiser_factory,
    sample_shape=(),
):
    reward = ConstantReward(observation_space, action_space, -1.0)

    batched_transition_network = DummyEnsembleTransitionNetwork(observation_space)
    batched_transition_model = KerasTransitionModel(
        [batched_transition_network],
        observation_space,
        action_space,
    )

    observation = create_uniform_distribution_from_spec(observation_space).sample(
        sample_shape=sample_shape
    )
    environment_model = EnvironmentModel(
        transition_model=batched_transition_model,
        reward_model=reward,
        termination_model=ConstantFalseTermination(observation_space),
        initial_state_distribution_model=DeterministicInitialStateModel(observation),
        batch_size=population_size,
    )
    trajectory_optimiser = optimiser_policy_trajectory_optimiser_factory(
        time_step_space, action_space, horizon, population_size, number_of_particles
    )
    return trajectory_optimiser, environment_model


def test_generate_virtual_rollouts(observation_space, action_space, batch_size, horizon):
    observation = create_uniform_distribution_from_spec(observation_space).sample()
    network = DummyEnsembleTransitionNetwork(observation_space)
    model = KerasTransitionModel([network], observation_space, action_space)
    env_model = EnvironmentModel(
        transition_model=model,
        reward_model=ConstantReward(observation_space, action_space, -1.0),
        termination_model=ConstantFalseTermination(observation_space),
        initial_state_distribution_model=DeterministicInitialStateModel(observation),
        batch_size=batch_size,
    )
    random_policy = RandomTFPolicy(time_step_spec(observation_space), action_space)

    replay_buffer, driver, wrapped_env_model = virtual_rollouts_buffer_and_driver(
        env_model, random_policy, horizon
    )

    driver.run(wrapped_env_model.reset())
    trajectory = replay_buffer.gather_all()

    mid_steps = repeat(1, horizon - 1)
    expected_step_types = tf.constant(list(chain([0], mid_steps, [2])))
    batched_step_types = replicate(expected_step_types, (batch_size,))
    np.testing.assert_array_equal(batched_step_types, trajectory.step_type)


def _create_wrapped_environment(observation_space, action_space, reward):
    network = LinearTransitionNetwork(observation_space)
    model = KerasTransitionModel([network], observation_space, action_space)
    return EnvironmentModel(
        model,
        ConstantReward(observation_space, action_space, reward),
        ConstantFalseTermination(observation_space),
        create_uniform_initial_state_distribution(observation_space),
    )


def test_step_call_goal_state_transform(
    transition_network,
    observation_space_latent_obs,
    action_space_latent_obs,
    batch_size,
    ensemble_size,
):
    latent_observation_space_spec = BoundedTensorSpec(
        shape=observation_space_latent_obs.shape[:-1]
        + [observation_space_latent_obs.shape[-1] - 1],
        dtype=observation_space_latent_obs.dtype,
        minimum=observation_space_latent_obs.minimum,
        maximum=observation_space_latent_obs.maximum,
        name=observation_space_latent_obs.name,
    )
    network_list = [
        transition_network(latent_observation_space_spec, bootstrap_data=True)
        for _ in range(ensemble_size)
    ]
    observation_transformation = GoalStateObservationTransformation(
        latent_observation_space_spec=latent_observation_space_spec,
        goal_state_start_index=-1,
    )
    transition_model = KerasTransitionModel(
        network_list,
        observation_space_latent_obs,
        action_space_latent_obs,
        predict_state_difference=True,
        trajectory_sampling_strategy=OneStepTrajectorySampling(batch_size, ensemble_size),
        observation_transformation=observation_transformation,
    )
    observation_distribution = create_uniform_distribution_from_spec(
        observation_space_latent_obs
    )
    observations = observation_distribution.sample((batch_size,))
    action_distribution = create_uniform_distribution_from_spec(action_space_latent_obs)
    actions = action_distribution.sample((batch_size,))

    next_observations = transition_model.step(observations, actions)

    assert next_observations.shape == (batch_size,) + observation_space_latent_obs.shape
    assert observation_space_latent_obs.is_compatible_with(next_observations[0])
    tf.assert_equal(next_observations[..., -1], observations[..., -1])


def get_cross_entropy_policy(observation_space, action_space, horizon, batch_size):
    time_step_space = time_step_spec(observation_space)
    network = LinearTransitionNetwork(observation_space)
    model = KerasTransitionModel([network], observation_space, action_space)
    env_model = EnvironmentModel(
        model,
        ConstantReward(observation_space, action_space),
        ConstantFalseTermination(observation_space),
        create_uniform_initial_state_distribution(observation_space),
        batch_size,
    )
    policy = CrossEntropyMethodPolicy(time_step_space, action_space, horizon, batch_size)
    return env_model, policy


def _transition_fixture(mountain_car_environment, batch_size):
    network = LinearTransitionNetwork(mountain_car_environment.observation_spec())
    transition_model = KerasTransitionModel(
        [network],
        mountain_car_environment.observation_spec(),
        mountain_car_environment.action_spec(),
    )
    reward_model = ConstantReward(
        mountain_car_environment.observation_spec(),
        mountain_car_environment.action_spec(),
    )
    transition = sample_uniformly_distributed_transitions(
        transition_model, 2 * batch_size, reward_model
    )
    return mountain_car_environment, transition


def test_mismatch_ensemble_size(
    observation_space, action_space, trajectory_sampling_strategy_factory, batch_size
):
    """
    Ensure that the ensemble size specified in the trajectory sampling strategy is equal
    to the number of networks in the model.
    """
    strategy = trajectory_sampling_strategy_factory(batch_size, 2)
    if isinstance(strategy, SingleFunction):
        pytest.skip("SingleFunction strategy is not an ensemble strategy.")

    with pytest.raises(AssertionError):
        KerasTransitionModel(
            [LinearTransitionNetwork(observation_space)],
            observation_space,
            action_space,
            trajectory_sampling_strategy=strategy,
        )


def test_train_method_increments_counter_for_generic_background_planning(
    mocker, agent_class
):
    """
    The docstring for the `_train` method of a TFAgent requires that the implementation
    increments the `train_step_counter`.
    """
    population_size = 1
    horizon = 10
    model_free_training_iterations = 1
    mf_agent = create_mock_model_free_agent(mocker, TIMESTEP_SPEC, ACTION_SPEC, agent_class)
    network = LinearTransitionNetwork(OBSERVATION_SPEC)
    transition_model = KerasTransitionModel([network], OBSERVATION_SPEC, ACTION_SPEC)
    reward_model = ConstantReward(OBSERVATION_SPEC, ACTION_SPEC)
    initial_state_model = create_uniform_initial_state_distribution(OBSERVATION_SPEC)
    train_step_counter = common.create_variable(
        "train_step_counter", shape=(), dtype=tf.float64
    )
    model_based_agent = BackgroundPlanningAgent(
        (transition_model, TransitionModelTrainingSpec(1, 1)),
        reward_model,
        initial_state_model,
        mf_agent,
        population_size,
        horizon,
        model_free_training_iterations,
        train_step_counter=train_step_counter,
    )
    dummy_trajectories = generate_dummy_trajectories(
        OBSERVATION_SPEC,
        ACTION_SPEC,
        batch_size=population_size,
        trajectory_length=horizon,
    )
    train_kwargs = {
        TRAIN_ARGSPEC_COMPONENT_ID: EnvironmentModelComponents.TRANSITION.value
    }

    model_based_agent.train(dummy_trajectories, **train_kwargs)

    assert train_step_counter.value() == 1


def test_planning_policy_batch_environment_model():
    """
    Ensure that the planning policy is operational.
    """
    # number of trajectories for planning and planning horizon
    population_size = 3
    planner_horizon = 5
    number_of_particles = 1

    # setup the environment and a model of it
    py_env = suite_gym.load("MountainCar-v0")
    tf_env = TFPyEnvironment(py_env)
    reward = MountainCarReward(tf_env.observation_spec(), tf_env.action_spec())
    terminates = MountainCarTermination(tf_env.observation_spec())
    network = LinearTransitionNetwork(tf_env.observation_spec())
    transition_model = KerasTransitionModel(
        [network],
        tf_env.observation_spec(),
        tf_env.action_spec(),
    )
    initial_state = MountainCarInitialState(tf_env.observation_spec())
    environment_model = EnvironmentModel(
        transition_model=transition_model,
        reward_model=reward,
        termination_model=terminates,
        initial_state_distribution_model=initial_state,
    )

    # setup the trajectory optimiser
    random_policy = RandomTFPolicy(tf_env.time_step_spec(), tf_env.action_spec())
    trajectory_optimiser = PolicyTrajectoryOptimiser(
        random_policy, planner_horizon, population_size, number_of_particles
    )
    planning_policy = PlanningPolicy(environment_model, trajectory_optimiser)

    # test whether it runs
    collect_driver_planning_policy = DynamicEpisodeDriver(
        tf_env, planning_policy, num_episodes=1
    )
    time_step = tf_env.reset()
    collect_driver_planning_policy.run(time_step)


def _wrapped_environment_fixture(observation_space, action_space, batch_size):
    observation = create_uniform_distribution_from_spec(observation_space).sample()
    network = DummyEnsembleTransitionNetwork(observation_space)
    model = KerasTransitionModel([network], observation_space, action_space)
    env_model = EnvironmentModel(
        transition_model=model,
        reward_model=ConstantReward(observation_space, action_space, -1.0),
        termination_model=ConstantFalseTermination(observation_space),
        initial_state_distribution_model=DeterministicInitialStateModel(observation),
        batch_size=batch_size,
    )
    wrapped_environment_model = TFTimeLimit(env_model, 2)

    action = create_uniform_distribution_from_spec(action_space).sample((batch_size,))

    return wrapped_environment_model, action


def test_random_shooting_with_dynamic_step_driver(observation_space, action_space):
    """
    This test uses the environment wrapper as an adapter so that a driver from TF-Agents
    can be used to generate a rollout. It also serves as an example of how to construct
    "random shooting" rollouts from an environment model.

    The assertion in this test is that the selected action has the expected log_prob
    value, consistent with sampling from a uniform distribution. All this really checks
    is that the preceding code has run successfully.
    """
    network = LinearTransitionNetwork(observation_space)
    environment = KerasTransitionModel([network], observation_space, action_space)
    wrapped_environment = EnvironmentModel(
        environment,
        ConstantReward(observation_space, action_space, 0.0),
        ConstantFalseTermination(observation_space),
        create_uniform_initial_state_distribution(observation_space),
    )
    random_policy = RandomTFPolicy(
        wrapped_environment.time_step_spec(), action_space, emit_log_probability=True
    )
    transition_observer = _RecordLastLogProbTransitionObserver()
    driver = DynamicStepDriver(
        env=wrapped_environment,
        policy=random_policy,
        transition_observers=[transition_observer],
    )
    driver.run()

    last_log_prob = transition_observer.last_log_probability
    uniform_distribution = create_uniform_distribution_from_spec(action_space)
    action_log_prob = uniform_distribution.log_prob(transition_observer.action)
    expected = np.sum(action_log_prob.numpy().astype(np.float32))
    actual = np.sum(last_log_prob.numpy())
    np.testing.assert_array_almost_equal(actual, expected, decimal=4)


def test_train_method_increments_counter_for_model_free_supported_agents(
    mocker, agent_class, train_component
):
    """
    The docstring for the `_train` method of a TFAgent requires that the implementation
    increments the `train_step_counter`.
    """
    population_size = 1
    number_of_particles = 1
    horizon = 10
    mf_agent = create_mock_model_free_agent(mocker, TIMESTEP_SPEC, ACTION_SPEC, agent_class)
    trajectory_optimiser = random_shooting_trajectory_optimisation(
        TIMESTEP_SPEC, ACTION_SPEC, horizon, population_size, number_of_particles
    )
    network = LinearTransitionNetwork(OBSERVATION_SPEC)
    transition_model = KerasTransitionModel([network], OBSERVATION_SPEC, ACTION_SPEC)
    reward_model = ConstantReward(OBSERVATION_SPEC, ACTION_SPEC)
    initial_state_model = create_uniform_initial_state_distribution(OBSERVATION_SPEC)
    train_step_counter = common.create_variable(
        "train_step_counter", shape=(), dtype=tf.float64
    )
    agent = ModelFreeSupportedDecisionTimePlanningAgent(
        TIMESTEP_SPEC,
        ACTION_SPEC,
        (transition_model, TransitionModelTrainingSpec(1, 1)),
        reward_model,
        initial_state_model,
        trajectory_optimiser,
        mf_agent,
        train_step_counter=train_step_counter,
    )
    dummy_trajectories = generate_dummy_trajectories(
        OBSERVATION_SPEC,
        ACTION_SPEC,
        batch_size=population_size,
        trajectory_length=horizon,
    )
    train_kwargs = {TRAIN_ARGSPEC_COMPONENT_ID: train_component.value}

    agent.train(dummy_trajectories, **train_kwargs)

    assert train_step_counter.value() == 1


def test_sample_trajectory_for_mountain_car():
    tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load("MountainCar-v0"))

    network = LinearTransitionNetwork(tf_env.observation_spec())
    model = KerasTransitionModel(
        [network],
        tf_env.observation_spec(),
        tf_env.action_spec(),
    )
    reward = ConstantReward(tf_env.observation_spec(), tf_env.action_spec(), -1.0)
    terminates = MountainCarTermination(tf_env.observation_spec())
    initial_state_sampler = MountainCarInitialState(tf_env.observation_spec())
    environment = TFTimeLimit(
        EnvironmentModel(model, reward, terminates, initial_state_sampler), duration=200
    )

    collect_policy = RandomTFPolicy(tf_env.time_step_spec(), tf_env.action_spec())
    replay_buffer_capacity = 1001
    policy_training_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_policy.trajectory_spec, batch_size=1, max_length=replay_buffer_capacity
    )

    collect_episodes_per_iteration = 2
    collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        environment,
        collect_policy,
        observers=[policy_training_buffer.add_batch],
        num_episodes=collect_episodes_per_iteration,
    )

    collect_driver.run()

    trajectory = policy_training_buffer.gather_all()
    first_batch_step_type = trajectory.step_type[0, :]
    assert (
        first_batch_step_type[0] == StepType.FIRST
        and first_batch_step_type[-1] == StepType.LAST
    )


def test_tf_time_limit_wrapper_with_environment_model(
    observation_space, action_space, trajectory_length
):
    """
    This test checks that the environment wrapper can in turn be wrapped by the
    `TimeLimit` environment wrapper from TF-Agents.
    """
    ts_spec = time_step_spec(observation_space)

    network = LinearTransitionNetwork(observation_space)
    environment = KerasTransitionModel([network], observation_space, action_space)
    wrapped_environment = TFTimeLimit(
        EnvironmentModel(
            environment,
            ConstantReward(observation_space, action_space, 0.0),
            ConstantFalseTermination(observation_space),
            create_uniform_initial_state_distribution(observation_space),
        ),
        trajectory_length,
    )

    collect_policy = RandomTFPolicy(ts_spec, action_space)
    replay_buffer_capacity = 1001
    policy_training_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_policy.trajectory_spec, batch_size=1, max_length=replay_buffer_capacity
    )

    collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        wrapped_environment,
        collect_policy,
        observers=[policy_training_buffer.add_batch],
        num_episodes=1,
    )
    collect_driver.run()

    trajectories = policy_training_buffer.gather_all()

    assert trajectories.step_type.shape == (1, trajectory_length + 1)


def test_invalid_num_elites(observation_space, action_space, horizon):
    # some fixed parameters
    population_size = 10
    number_of_particles = 1

    # set up the environment model
    network = LinearTransitionNetwork(observation_space)
    model = KerasTransitionModel([network], observation_space, action_space)
    environment_model = EnvironmentModel(
        model,
        ConstantReward(observation_space, action_space),
        ConstantFalseTermination(observation_space),
        create_uniform_initial_state_distribution(observation_space),
        population_size,
    )

    # set up the trajectory optimiser
    time_step_space = time_step_spec(observation_space)
    optimiser = cross_entropy_method_trajectory_optimisation(
        time_step_space,
        action_space,
        horizon=horizon,
        population_size=population_size,
        number_of_particles=number_of_particles,
        num_elites=population_size + 1,
        learning_rate=0.1,
        max_iterations=1,
    )

    # remember the time step comes from the real environment with batch size 1
    observation = create_uniform_distribution_from_spec(observation_space).sample(
        sample_shape=(1,)
    )
    initial_time_step = restart(observation, batch_size=1)

    # run
    with pytest.raises(AssertionError) as excinfo:
        optimiser.optimise(initial_time_step, environment_model)

    assert "num_elites" in str(excinfo)


""" # %% batch_size = 64 training_spec = KerasTrainingSpec( epochs=5000, training_batch_size=256, callbacks=[tf.keras.callbacks.EarlyStopping(monitor="loss", patience=3)], verbose=0, ) linear_transition_network = LinearTransitionNetwork(tf_env.observation_spec()) trajectory_sampling_strategy = InfiniteHorizonTrajectorySampling(batch_size, 1) transition_model = KerasTransitionModel( [linear_transition_network], tf_env.observation_spec(), tf_env.action_spec(), ) reward_model = ConstantReward(tf_env.observation_spec(), tf_env.action_spec()) sample_transitions = sample_uniformly_distributed_transitions( transition_model, 1000, reward_model ) # %% plot_mountain_car_transitions( sample_transitions.observation.numpy(), sample_transitions.action.numpy(), sample_transitions.next_observation.numpy(), ) # %% [markdown]
def build_transition_model_and_training_spec_from_type(
    observation_spec: types.NestedTensorSpec,
    action_spec: types.NestedTensorSpec,
    transition_model_type: TransitionModelType,
    num_hidden_layers: int,
    num_hidden_nodes: int,
    activation_function: Callable,
    ensemble_size: int,
    predict_state_difference: bool,
    epochs: int,
    training_batch_size: int,
    callbacks: List[tf.keras.callbacks.Callback],
    trajectory_sampler: TrajectorySamplingStrategy,
    observation_transformation: Optional[ObservationTransformation] = None,
    verbose: int = 0,
) -> Tuple[KerasTransitionModel, KerasTrainingSpec]:
    """
    Custom function to build a Keras transition model plus a training spec from arguments.

    :param observation_spec: A nest of BoundedTensorSpec representing the observations.
    :param action_spec: A nest of BoundedTensorSpec representing the actions.
    :param transition_model_type: An indicator of which of the available transition models
        should be used - the list can be found in `TransitionModelType`. A component of the
        environment model that describes the transition dynamics.
    :param num_hidden_layers: A transition model parameter, used for constructing a neural
        network. The number of hidden layers in the neural network.
    :param num_hidden_nodes: A transition model parameter, used for constructing a neural
        network. The number of nodes in each hidden layer; the parameter is shared across
        all layers.
    :param activation_function: A transition model parameter, used for constructing a
        neural network. The activation function of the hidden nodes.
    :param ensemble_size: A transition model parameter, used for constructing a neural
        network. The number of networks in the ensemble.
    :param predict_state_difference: A transition model parameter, used for constructing
        a neural network. A boolean indicating whether the transition model predicts the
        difference between the current and the next state, or the next state directly.
    :param epochs: A transition model parameter, used by the Keras fit method. The number
        of epochs used for training the neural network.
    :param training_batch_size: A transition model parameter, used by the Keras fit
        method. The batch size used for training the neural network.
    :param callbacks: A transition model parameter, used by the Keras fit method. A list
        of Keras callbacks used for training the neural network.
    :param trajectory_sampler: The trajectory sampler determines how predictions from an
        ensemble of neural networks that model the transition dynamics are sampled. Works
        only with ensemble types of transition models.
    :param observation_transformation: To transform observations to latent observations
        that are used by the transition model, and back. None will internally create an
        identity transform.
    :param verbose: A transition model parameter, used by the Keras fit method. The level
        of detail of the output to the console/logger during training.

    :return: The Keras transition model object and the corresponding training spec.
""" networks: List[KerasTransitionNetwork] = None if transition_model_type == TransitionModelType.Deterministic: networks = [ MultilayerFcTransitionNetwork( observation_spec, num_hidden_layers, [num_hidden_nodes] * num_hidden_layers, [activation_function] * num_hidden_layers, ) ] transition_model = KerasTransitionModel( networks, observation_spec, action_spec, predict_state_difference=predict_state_difference, observation_transformation=observation_transformation, ) elif transition_model_type == TransitionModelType.DeterministicEnsemble: networks = [ MultilayerFcTransitionNetwork( observation_spec, num_hidden_layers, [num_hidden_nodes] * num_hidden_layers, [activation_function] * num_hidden_layers, bootstrap_data=True, ) for _ in range(ensemble_size) ] transition_model = KerasTransitionModel( networks, observation_spec, action_spec, predict_state_difference=predict_state_difference, observation_transformation=observation_transformation, trajectory_sampling_strategy=trajectory_sampler, ) elif transition_model_type == TransitionModelType.Probabilistic: networks = [ DiagonalGaussianTransitionNetwork( observation_spec, num_hidden_layers, [num_hidden_nodes] * num_hidden_layers, [activation_function] * num_hidden_layers, ) ] transition_model = KerasTransitionModel( networks, observation_spec, action_spec, predict_state_difference=predict_state_difference, observation_transformation=observation_transformation, ) elif transition_model_type == TransitionModelType.ProbabilisticEnsemble: networks = [ DiagonalGaussianTransitionNetwork( observation_spec, num_hidden_layers, [num_hidden_nodes] * num_hidden_layers, [activation_function] * num_hidden_layers, bootstrap_data=True, ) for _ in range(ensemble_size) ] transition_model = KerasTransitionModel( networks, observation_spec, action_spec, predict_state_difference=predict_state_difference, observation_transformation=observation_transformation, trajectory_sampling_strategy=trajectory_sampler, ) else: raise RuntimeError("Unknown transition model") training_spec = KerasTrainingSpec( epochs=epochs, training_batch_size=training_batch_size, callbacks=callbacks, verbose=verbose, ) return transition_model, training_spec