def test_modify_values_yaml_without_pod_count(mocker):
    """Check modify_values_yaml on a values file that has no podCount entry.

    The input values file (TEST_YAML_FILE_WITHOUT_POD_COUNT) is served through a
    mocked ``open``. ``shutil.move`` and ``yaml.safe_dump`` are mocked too, so
    nothing touches disk. The test inspects the dict handed to ``safe_dump``:
    pack parameters must be applied, and podCount / workersCount /
    pServersCount must come out as 3 / 2 / 1. Presumably podCount is derived
    from the other two counts when absent — confirm against the implementation.
    """
    open_mock = mocker.patch("builtins.open", new_callable=mock.mock_open,
                             read_data=TEST_YAML_FILE_WITHOUT_POD_COUNT)
    sh_move_mock = mocker.patch("shutil.move")
    yaml_dump_mock = mocker.patch("yaml.safe_dump")

    tf_training.modify_values_yaml(experiment_folder=EXPERIMENT_FOLDER,
                                   script_location=SCRIPT_LOCATION,
                                   script_parameters=SCRIPT_PARAMETERS,
                                   pack_params=PACK_PARAMETERS,
                                   experiment_name='test-experiment',
                                   pack_type=EXAMPLE_PACK_TYPE,
                                   cluster_registry_port=1111,
                                   env_variables=None,
                                   username='******')

    assert sh_move_mock.call_count == 1, "job yaml file wasn't moved."
    # Verify the dump happened before reading call_args: if safe_dump was never
    # called, call_args is None and indexing it would fail with a TypeError
    # instead of this descriptive assertion message.
    assert yaml_dump_mock.call_count == 1, "job yaml wasn't modified"
    output = yaml_dump_mock.call_args[0][0]

    compare_yaml(output["commandline"]["args"], SCRIPT_LOCATION)
    # Each key is checked explicitly; `'key1' and 'key2' in output` would only
    # have tested 'key2' because of operator precedence.
    assert 'key1' in output and 'key2' in output
    assert output['key1'] == 'val1'
    assert output['key2'] == ["a", "b"]
    assert open_mock.call_count == 2, "files weren't read/written"
    assert all(EXAMPLE_PACK_TYPE in call[0][0] for call in open_mock.call_args_list)
    # Counts may be emitted as ints or as numeric strings; accept either.
    assert output['podCount'] == 3 or int(output['podCount']) == 3
    assert output['workersCount'] == 2 or int(output['workersCount']) == 2
    assert output['pServersCount'] == 1 or int(output['pServersCount']) == 1
def test_modify_values_yaml_raise_error_if_bad_argument(mocker):
    """Check that a malformed pack parameter aborts modify_values_yaml cleanly.

    The value ``"{ bad list"`` is expected to make modify_values_yaml raise
    AttributeError. When it does, the job yaml must not be moved and nothing
    may be written out through ``yaml.safe_dump``.
    """
    open_mock = mocker.patch("builtins.open", new_callable=mock.mock_open,
                             read_data=TEST_YAML_FILE)
    sh_move_mock = mocker.patch("shutil.move")
    # Patch safe_dump (the writer the implementation uses, as the sibling test
    # shows). Patching yaml.dump would leave the "not modified" assertion
    # vacuously true, because safe_dump does not go through yaml.dump.
    yaml_dump_mock = mocker.patch("yaml.safe_dump")

    wrong_pack_params = [("key1", "{ bad list")]

    with pytest.raises(AttributeError):
        tf_training.modify_values_yaml(experiment_folder=EXPERIMENT_FOLDER,
                                       script_location=SCRIPT_LOCATION,
                                       script_parameters=SCRIPT_PARAMETERS,
                                       pack_params=wrong_pack_params,
                                       experiment_name='test-experiment',
                                       username='******',
                                       pack_type=EXAMPLE_PACK_TYPE,
                                       cluster_registry_port=1111,
                                       env_variables=None)

    assert sh_move_mock.call_count == 0, "job yaml should not be moved."
    assert yaml_dump_mock.call_count == 0, "yaml should not be modified."
    assert all(EXAMPLE_PACK_TYPE in call[0][0] for call in open_mock.call_args_list)