def test_inference_override_configuration_values(grim_config):
    """Test for correct handling of the '--inference' argument override.

    The '--inference' override is tested by itself because it affects the result of the
    '--export-path' argument: although an export path is configured below, the expected
    command line contains '--inference' and no '--export-path'.
    """

    grim_config['--export-path'] = 'UnitySDK/Assets/ML-Agents/Examples/3DBall/ImportedModels'

    override_args = Namespace(
        configuration_file='config/3DBall_grimagents.json',
        trainer_config=None,
        resume=False,
        env=None,
        sampler=None,
        lesson=None,
        run_id=None,
        base_port=None,
        num_envs=None,
        inference=True,
        graphics=None,
        no_graphics=None,
        timestamp=None,
        no_timestamp=None,
        multi_gpu=None,
        no_multi_gpu=None,
        args=[],
    )

    arguments = TrainingWrapperArguments(grim_config)
    arguments.apply_argument_overrides(override_args)

    result = arguments.get_arguments()

    # The absolute path to training_wrapper.py differs between systems, so only check that the
    # expected script is invoked, then normalize it so the full comparison below is portable.
    assert len(result) > 3, result
    assert str(result[3]).endswith('training_wrapper.py'), result[3]
    result[3] = 'grimagents/training_wrapper.py'

    # The '--inference' argument should be present and '--export-path' should not.
    assert result == [
        'pipenv',
        'run',
        'python',
        'grimagents/training_wrapper.py',
        'config/3DBall.yaml',
        '--run-id',
        '3DBall',
        '--inference',
    ]


def test_resume_override_configuration_values(grim_config):
    """Test for correct handling of the '--resume' argument.

    The '--resume' override is tested by itself because it affects the result of the
    timestamp and inference arguments: although inference and timestamp overrides are both
    requested below, the expected command line contains only '--resume'.
    """

    override_args = Namespace(
        configuration_file='config/3DBall_grimagents.json',
        trainer_config=None,
        resume=True,
        inference=True,
        timestamp=True,
        env=None,
        sampler=None,
        lesson=None,
        run_id=None,
        base_port=None,
        num_envs=None,
        graphics=None,
        no_graphics=None,
        no_timestamp=None,
        multi_gpu=None,
        no_multi_gpu=None,
        args=[],
    )

    arguments = TrainingWrapperArguments(grim_config)
    arguments.apply_argument_overrides(override_args)

    result = arguments.get_arguments()

    # The absolute path to training_wrapper.py differs between systems, so only check that the
    # expected script is invoked, then normalize it so the full comparison below is portable.
    assert len(result) > 3, result
    assert str(result[3]).endswith('training_wrapper.py'), result[3]
    result[3] = 'grimagents/training_wrapper.py'

    # The inference argument should not be present and a timestamp should not be applied
    # (the run id stays '3DBall' rather than gaining a timestamp suffix).
    assert result == [
        'pipenv',
        'run',
        'python',
        'grimagents/training_wrapper.py',
        'config/3DBall.yaml',
        '--run-id',
        '3DBall',
        '--resume',
    ]


def test_override_configuration_values(grim_config):
    """Test that TrainingWrapperArguments correctly applies argument overrides, including:
    --base-port
    --env
    --multi-gpu
    --no-graphics
    --num-envs
    --run-id
    --timestamp
    --trainer-config

    Each configuration value set below is replaced or cancelled by the corresponding
    command line override, and the expected command line reflects only the overrides.
    """

    grim_config['--base-port'] = 5010
    grim_config['--env'] = 'builds/3DBall/3DBall.exe'
    grim_config['--multi-gpu'] = True
    grim_config['--no-graphics'] = False
    grim_config['--num-envs'] = 2
    grim_config['--timestamp'] = True

    override_args = Namespace(
        args=[],
        base_port=6010,
        configuration_file='config/3DBall_grimagents.json',
        env='builds/PushBlock/PushBlock.exe',
        graphics=None,
        inference=False,
        multi_gpu=None,
        no_graphics=True,
        no_multi_gpu=True,
        no_timestamp=True,
        num_envs=4,
        resume=False,
        run_id='PushBlock',
        timestamp=None,
        trainer_config='config/PushBlock_grimagents.json',
    )

    arguments = TrainingWrapperArguments(grim_config)
    arguments.set_additional_arguments(override_args.args)
    arguments.apply_argument_overrides(override_args)

    result = arguments.get_arguments()

    # The absolute path to training_wrapper.py differs between systems, so only check that the
    # expected script is invoked, then normalize it so the full comparison below is portable.
    assert len(result) > 3, result
    assert str(result[3]).endswith('training_wrapper.py'), result[3]
    result[3] = 'grimagents/training_wrapper.py'

    # '--multi-gpu' and the timestamp enabled in the configuration are cancelled by the
    # 'no_multi_gpu' and 'no_timestamp' overrides, so neither appears in the expected output.
    # NOTE(review): '--num-envs' is expected as the string '4' while '--base-port' stays the
    # int 6010; presumably this mirrors how TrainingWrapperArguments formats each value — confirm.
    assert result == [
        'pipenv',
        'run',
        'python',
        'grimagents/training_wrapper.py',
        'config/PushBlock_grimagents.json',
        '--run-id',
        'PushBlock',
        '--env',
        'builds/PushBlock/PushBlock.exe',
        '--base-port',
        6010,
        '--num-envs',
        '4',
        '--no-graphics',
    ]