def test_initialize_trainers(mock_communicator, mock_launcher, dummy_config,
                             dummy_offline_bc_config, dummy_online_bc_config,
                             dummy_bad_config):
    """Verify _initialize_trainers builds the trainer class the config names.

    ``yaml.load`` and the controller module's ``open`` are patched, so each
    scenario injects its trainer config through the patched loader's
    ``return_value``. Three scenarios are exercised in order: a PPO config,
    an online behavior-cloning config, and a config with an incorrect
    trainer name, which must raise ``UnityEnvironmentException``.
    """
    patched_open = 'mlagents.trainers.trainer_controller' + '.open'
    with mock.patch('yaml.load') as yaml_load_mock, \
            mock.patch(patched_open, create=True):
        mock_communicator.return_value = MockCommunicator(
            discrete_action=True, visual_inputs=1)
        controller = TrainerController(
            ' ', ' ', 1, None, True, False, False, 1, 1, 1, 1, '',
            "tests/test_mlagents.trainers.py", False)

        def prepare(trainer_config):
            # Route the config through the patched YAML loader, then start
            # from a fresh TF graph so trainers from a previous scenario do
            # not leave variables behind.
            yaml_load_mock.return_value = trainer_config
            parsed = controller._load_config()
            tf.reset_default_graph()
            return parsed

        # PPO trainer.
        controller._initialize_trainers(prepare(dummy_config))
        assert len(controller.trainers) == 1
        assert isinstance(controller.trainers['RealFakeBrain'], PPOTrainer)

        # Online behavior cloning trainer.
        controller._initialize_trainers(prepare(dummy_online_bc_config))
        assert isinstance(controller.trainers['RealFakeBrain'],
                          OnlineBCTrainer)

        # An incorrect trainer name must be rejected; only the
        # initialization call sits inside the expected-exception scope.
        bad_config = prepare(dummy_bad_config)
        with pytest.raises(UnityEnvironmentException):
            controller._initialize_trainers(bad_config)


def test_load_config(mock_communicator, mock_launcher, dummy_config):
    """Verify _load_config returns the config produced by the YAML loader.

    ``yaml.load`` is patched to return ``dummy_config``, and the controller
    module's ``open`` is patched as well. The test checks that the loaded
    config has a single top-level entry whose ``default`` section selects
    the ``ppo`` trainer.
    """
    open_name = 'mlagents.trainers.trainer_controller' + '.open'
    with mock.patch('yaml.load') as mock_load:
        with mock.patch(open_name, create=True) as _:
            # Configure the patched loader exactly once, before the
            # controller is constructed.
            mock_load.return_value = dummy_config
            mock_communicator.return_value = MockCommunicator(
                discrete_action=True, visual_inputs=1)
            tc = TrainerController(' ', ' ', 1, None, True, True, False, 1, 1,
                                   1, 1, '', '', False)
            config = tc._load_config()
            assert len(config) == 1
            assert config['default']['trainer'] == "ppo"


def test_initialize_offline_trainers(mock_communicator, mock_launcher,
                                     dummy_config, dummy_offline_bc_config,
                                     dummy_online_bc_config, dummy_bad_config):
    """Verify an offline behavior-cloning config yields an OfflineBCTrainer.

    The mocked communicator is built with ``discrete_action=False``,
    ``stack=False``, ``visual_inputs=0``, ``brain_name="Ball3DBrain"`` and
    ``vec_obs_size=8``. ``yaml.load`` and the controller module's ``open``
    are patched, and the offline BC config is injected through the patched
    loader.
    """
    module_open = 'mlagents.trainers.trainer_controller' + '.open'
    with mock.patch('yaml.load') as yaml_load_mock, \
            mock.patch(module_open, create=True):
        mock_communicator.return_value = MockCommunicator(
            discrete_action=False, stack=False, visual_inputs=0,
            brain_name="Ball3DBrain", vec_obs_size=8)
        controller = TrainerController(
            ' ', ' ', 1, None, True, False, False, 1, 1, 1, 1, '',
            "tests/test_mlagents.trainers.py", False)

        # Offline behavior cloning trainer.
        yaml_load_mock.return_value = dummy_offline_bc_config
        loaded = controller._load_config()
        tf.reset_default_graph()
        controller._initialize_trainers(loaded)
        built = controller.trainers['Ball3DBrain']
        assert isinstance(built, OfflineBCTrainer)