def test_training_arguments_add_additional_args(grim_config):
    """Additional arguments set on TrainingWrapperArguments appear at the end of the command."""

    wrapper_arguments = TrainingWrapperArguments(grim_config)
    wrapper_arguments.set_additional_arguments(['--debug'])

    command = wrapper_arguments.get_arguments()

    # Element 3 is the absolute path to training_wrapper.py, which depends on the
    # machine running the test, so it is replaced with a fixed relative path.
    command[3] = 'grimagents/training_wrapper.py'

    expected = [
        'pipenv',
        'run',
        'python',
        'grimagents/training_wrapper.py',
        'config/3DBall.yaml',
        '--run-id',
        '3DBall',
        '--debug',
    ]
    assert command == expected
def test_override_configuration_values(grim_config):
    """Command-line overrides take precedence over values stored in the configuration.

    Options exercised by this test:
        --base-port
        --env
        --multi-gpu
        --no-graphics
        --num-envs
        --run-id
        --timestamp
        --trainer-config
    """

    # Values placed in the configuration before any override is applied.
    configured_values = {
        '--base-port': 5010,
        '--env': 'builds/3DBall/3DBall.exe',
        '--multi-gpu': True,
        '--no-graphics': False,
        '--num-envs': 2,
        '--timestamp': True,
    }
    for option, value in configured_values.items():
        grim_config[option] = value

    # Parsed command-line values that are expected to win over the configuration.
    parsed_values = {
        'args': [],
        'base_port': 6010,
        'configuration_file': 'config/3DBall_grimagents.json',
        'env': 'builds/PushBlock/PushBlock.exe',
        'graphics': None,
        'inference': False,
        'multi_gpu': None,
        'no_graphics': True,
        'no_multi_gpu': True,
        'no_timestamp': True,
        'num_envs': 4,
        'resume': False,
        'run_id': 'PushBlock',
        'timestamp': None,
        'trainer_config': 'config/PushBlock_grimagents.json',
    }
    override_args = Namespace(**parsed_values)

    wrapper_arguments = TrainingWrapperArguments(grim_config)
    wrapper_arguments.set_additional_arguments(override_args.args)
    wrapper_arguments.apply_argument_overrides(override_args)

    command = wrapper_arguments.get_arguments()

    # Element 3 is the absolute path to training_wrapper.py, which depends on the
    # machine running the test, so it is replaced with a fixed relative path.
    command[3] = 'grimagents/training_wrapper.py'

    # '--multi-gpu' and '--timestamp' are enabled in the configuration but switched
    # off by no_multi_gpu / no_timestamp, so neither flag is expected below, while
    # '--no-graphics' is expected because the override turns it on.
    # NOTE(review): the base port is compared as an int while num-envs is compared
    # as a str; presumably this mirrors what get_arguments() emits - confirm.
    expected = [
        'pipenv',
        'run',
        'python',
        'grimagents/training_wrapper.py',
        'config/PushBlock_grimagents.json',
        '--run-id',
        'PushBlock',
        '--env',
        'builds/PushBlock/PushBlock.exe',
        '--base-port',
        6010,
        '--num-envs',
        '4',
        '--no-graphics',
    ]
    assert command == expected