示例#1
0
def test_bad_config():
    """Check that a sequence length greater than the batch size is rejected.

    This test asserts that building ``TrainerSettings`` with a recurrent
    memory ``sequence_length`` of 64 and a PPO ``batch_size`` of 32 raises
    ``TrainerConfigError``.
    """
    # Test that we throw an error if we have sequence length greater than batch size.
    # Only the construction belongs inside ``pytest.raises``: any statement
    # after it would be unreachable once it raises. The former trailing
    # ``PPOTrainer(..., dummy_config, ...)`` call never used these settings
    # and could mask a regression by raising for an unrelated reason.
    with pytest.raises(TrainerConfigError):
        TrainerSettings(
            network_settings=NetworkSettings(
                memory=NetworkSettings.MemorySettings(sequence_length=64)
            ),
            hyperparameters=PPOSettings(batch_size=32),
        )
    EnvironmentParametersChannel,
)
from mlagents_envs.communicator_objects.demonstration_meta_pb2 import (
    DemonstrationMetaProto,
)
from mlagents_envs.communicator_objects.brain_parameters_pb2 import BrainParametersProto
from mlagents_envs.communicator_objects.space_type_pb2 import discrete, continuous

# Brain (behavior) name constant for this module.
BRAIN_NAME = "1D"


# PPO trainer settings: non-threaded, max_steps=3000, summary_freq=500,
# a single hidden layer of 32 units, and PPO hyperparameters with a
# constant learning-rate schedule (lr = 5e-3, batch 16, buffer 64).
PPO_CONFIG = TrainerSettings(
    trainer_type=TrainerType.PPO,
    threaded=False,
    max_steps=3000,
    summary_freq=500,
    network_settings=NetworkSettings(hidden_units=32, num_layers=1),
    hyperparameters=PPOSettings(
        buffer_size=64,
        batch_size=16,
        learning_rate_schedule=ScheduleType.CONSTANT,
        learning_rate=0.005,
    ),
)

SAC_CONFIG = TrainerSettings(
    trainer_type=TrainerType.SAC,
    hyperparameters=SACSettings(
        learning_rate=5.0e-3,
        learning_rate_schedule=ScheduleType.CONSTANT,
        batch_size=8,
        buffer_init_steps=100,