def test_list_datasets_infinite_loop_detect(paginated_collection, paginated_session):
    """Pagination halts when a page repeats: duplicate UIDs end iteration."""
    # Given: an API whose second page is identical to the first.
    batch_size = 100
    datasets_data = DatasetDataFactory.create_batch(batch_size)
    datasets_data.extend(datasets_data)  # simulate the API returning page 1 forever
    paginated_session.set_response(datasets_data)

    # When
    datasets = list(paginated_collection.list())

    # Then: the duplicate UID detected on page 2 stops after exactly two calls.
    assert paginated_session.num_calls == 2
    datasets_path = 'projects/{}/datasets'.format(paginated_collection.project_id)
    expected_first_call = FakeCall(method='GET',
                                   path=datasets_path,
                                   params={'per_page': batch_size})
    expected_last_call = FakeCall(method='GET',
                                  path=datasets_path,
                                  params={'page': 2, 'per_page': batch_size})
    assert paginated_session.calls[0] == expected_first_call
    assert paginated_session.last_call == expected_last_call
    assert len(datasets) == batch_size
    # Only the first (unique) page of data should have been yielded, in order.
    expected_uids = [d['id'] for d in datasets_data[:batch_size]]
    assert [str(d.uid) for d in datasets] == expected_uids
def test_search_projects_with_pagination(paginated_collection, paginated_session):
    """Search paginates: 35 matches at 10 per page -> 4 POSTs, all matches returned.

    Fix: removed the unused local `more_data` (a second batch of 35 projects
    that was never placed in any response nor referenced by any assertion).
    """
    # Given
    common_name = "same name"
    same_name_projects_data = ProjectDataFactory.create_batch(35, name=common_name)
    per_page = 10
    paginated_session.set_response({'projects': same_name_projects_data})
    # NOTE(review): the key here is 'status' although the value is a project
    # name — looks like it should be 'name'; confirm against the search API
    # before changing, since both the request and the expectation use it.
    search_params = {'status': {
        'value': common_name,
        'search_method': 'EXACT'}}

    # When
    projects = list(paginated_collection.search(per_page=per_page,
                                                search_params=search_params))

    # Then: 35 results at 10 per page requires 4 calls.
    assert paginated_session.num_calls == 4
    expected_first_call = FakeCall(method='POST',
                                   path='/projects/search',
                                   params={'per_page': per_page},
                                   json={'search_params': search_params})
    expected_last_call = FakeCall(method='POST',
                                  path='/projects/search',
                                  params={'page': 4, 'per_page': per_page},
                                  json={'search_params': search_params})
    assert paginated_session.calls[0] == expected_first_call
    assert paginated_session.last_call == expected_last_call
    project_ids = [str(p.uid) for p in projects]
    expected_ids = [p['id'] for p in same_name_projects_data]
    assert project_ids == expected_ids
def test_list_datasets(paginated_collection, paginated_session):
    """Listing 50 datasets at 20 per page makes three GET calls and keeps order."""
    # Given
    datasets_data = DatasetDataFactory.create_batch(50)
    paginated_session.set_response(datasets_data)

    # When
    datasets = list(paginated_collection.list(per_page=20))

    # Then
    assert paginated_session.num_calls == 3
    datasets_path = 'projects/{}/datasets'.format(paginated_collection.project_id)
    assert paginated_session.calls[0] == FakeCall(
        method='GET', path=datasets_path, params={'per_page': 20})
    assert paginated_session.last_call == FakeCall(
        method='GET', path=datasets_path, params={'page': 3, 'per_page': 20})
    assert len(datasets) == 50
    # Every dataset comes back, in the same order the API returned it.
    assert [str(d.uid) for d in datasets] == [d['id'] for d in datasets_data]
def test_default_for_material(collection: AraDefinitionCollection, session):
    """Test that default for material hits the right route."""
    # Given
    project_id = '6b608f78-e341-422c-8076-35adc8828545'
    dummy_resp = {
        'config': TableConfig(
            name='foo',
            description='foo',
            variables=[],
            columns=[],
            rows=[],
            datasets=[],
        ).dump(),
        'ambiguous': [
            [
                RootIdentifier(name='foo', headers=['foo'], scope='id').dump(),
                IdentityColumn(data_source='foo').dump(),
            ]
        ],
    }
    default_path = "projects/{}/table-configs/default".format(project_id)

    # When called with a raw id, the id and scope args are forwarded verbatim.
    session.responses.append(dummy_resp)
    collection.default_for_material(
        material='my_id',
        scope='my_scope',
        name='my_name',
        description='my_description',
    )
    assert session.num_calls == 1
    assert session.last_call == FakeCall(
        method="GET",
        path=default_path,
        params={
            'id': 'my_id',
            'scope': 'my_scope',
            'name': 'my_name',
            'description': 'my_description'
        })

    # When called with a MaterialRun, its own uid/scope are used and the
    # explicit scope argument is ignored.
    session.calls.clear()
    session.responses.append(dummy_resp)
    collection.default_for_material(
        material=MaterialRun('foo', uids={'scope': 'id'}),
        scope='ignored',
        name='my_name',
        description='my_description',
    )
    assert session.num_calls == 1
    assert session.last_call == FakeCall(
        method="GET",
        path=default_path,
        params={
            'id': 'id',
            'scope': 'scope',
            'name': 'my_name',
            'description': 'my_description'
        })
def test_get_table_metadata(collection, session):
    """get() with a version hits the versioned route; without one, the newest wins."""
    # Given
    project_id = '6b608f78-e341-422c-8076-35adc8828545'
    gem_table = GemTableDataFactory()
    session.set_response(gem_table)

    # When a specific version is requested
    retrieved_table: GemTable = collection.get(gem_table["id"], gem_table["version"])

    # Then the versioned route is used and the table fields round-trip.
    assert session.num_calls == 1
    versioned_path = "projects/{}/display-tables/{}/versions/{}".format(
        project_id, gem_table["id"], gem_table["version"])
    assert session.last_call == FakeCall(method="GET", path=versioned_path)
    assert str(retrieved_table.uid) == gem_table["id"]
    assert retrieved_table.version == gem_table["version"]
    assert retrieved_table.download_url == gem_table["signed_download_url"]

    # Given a listing of all versions of a table
    gem_tables = ListGemTableVersionsDataFactory()
    session.set_response(gem_tables)
    newest_version = max(table["version"] for table in gem_tables["tables"])

    # When no version is requested
    retrieved_table = collection.get(gem_table["id"])

    # Then the highest version number is resolved.
    assert retrieved_table.version == newest_version
def run_noop_gemd_relation_search_test(search_for, search_with, collection,
                                       search_fn, per_page=100):
    """Test that relation searches hit the correct endpoint."""
    collection.session.set_response({'contents': []})
    test_id = 'foo-id'
    test_scope = 'foo-scope'
    result = search_fn(test_id, scope=test_scope)
    if isinstance(result, Iterator):
        # evaluate iterator to make calls happen
        list(result)
    assert collection.session.num_calls == 1
    relation_path = "projects/{}/{}/{}/{}/{}".format(
        collection.project_id, search_with, test_scope, test_id, search_for)
    assert collection.session.last_call == FakeCall(
        method="GET",
        path=relation_path,
        params={
            "dataset_id": str(collection.dataset_id),
            "forward": True,
            "ascending": True,
            "per_page": per_page,
        })
def test_list(collection: PredictorEvaluationExecutionCollection, session):
    """list() forwards paging args plus predictor and workflow ids as params."""
    session.set_response({
        "page": 2,
        "per_page": 4,
        "next": "foo",
        "response": []
    })
    predictor_id = uuid.uuid4()

    executions = list(collection.list(2, 4, predictor_id=predictor_id))

    assert executions == []
    listing_path = '/projects/{}/predictor-evaluation-executions'.format(
        collection.project_id)
    assert session.last_call == FakeCall(
        method='GET',
        path=listing_path,
        params={
            "page": 2,
            "per_page": 4,
            "predictor_id": str(predictor_id),
            "workflow_id": str(collection.workflow_id),
        })
def test_file_download(mock_write_file_locally, collection, session): """ Test that downloading a file works as expected. It should make the full file path if only a directory is given, make the directory if it does not exist, make a call to get the pre-signed URL, and another to download. """ # Given filename = 'diagram.pdf' url = "http://citrine.com/api/files/123/versions/456" file = FileLink.build(FileLinkDataFactory(url=url, filename=filename)) pre_signed_url = "http://files.citrine.io/secret-codes/jiifema987pjfsda" # arbitrary session.set_response({ 'pre_signed_read_link': pre_signed_url, }) local_path = 'Users/me/some/new/directory/' with requests_mock.mock() as mock_get: mock_get.get(pre_signed_url, text='0101001') # When collection.download(file, local_path) # When assert mock_get.call_count == 1 expected_call = FakeCall(method='GET', path=url + '/content-link') assert expected_call == session.last_call assert mock_write_file_locally.call_count == 1 assert mock_write_file_locally.call_args == call( b'0101001', local_path + file.filename)
def test_restore(workflow, collection):
    """restore() issues a PUT with an empty body to the workflow's restore route."""
    collection.restore(workflow.uid)
    restore_path = '/projects/{}/design-workflows/{}/restore'.format(
        collection.project_id, workflow.uid)
    assert collection.session.last_call == FakeCall(
        method='PUT', path=restore_path, json={})
def test_delete_contents_ok(dataset):
    """delete_contents() polls the job and returns [] when nothing failed.

    Fix: renamed the local `failed_job_resp` — the payload describes a
    *successful* batch-delete job with zero failures, so the old name was
    misleading.
    """
    # Given: the job submission response, then the completed-job poll response.
    job_resp = {'job_id': '1234'}
    completed_job_resp = {
        'job_type': 'batch_delete',
        'status': 'Success',
        'tasks': [],
        'output': {
            # Keep in mind this is a stringified JSON value. Eww.
            'failures': '[]'
        }
    }
    session = dataset.session
    session.set_responses(job_resp, completed_job_resp)

    # When
    del_resp = dataset.delete_contents()

    # Then: no failures are reported.
    assert len(del_resp) == 0
    # Ensure we made the expected delete call (first of the two calls;
    # the second is the job-status poll).
    expected_call = FakeCall(
        method='DELETE',
        path='projects/{}/datasets/{}/contents'.format(dataset.project_id,
                                                       dataset.uid))
    assert len(session.calls) == 2
    assert session.calls[0] == expected_call
def test_filter_by_tags(collection, session):
    """filter_by_tags() GETs material-runs with tag params; >1 tag is unsupported."""
    # Given
    sample_run = MaterialRunDataFactory()
    session.set_response({'contents': [sample_run]})

    # When
    runs = collection.filter_by_tags(tags=["color"], page=1, per_page=10)

    # Then
    assert session.num_calls == 1
    expected_call = FakeCall(
        method='GET',
        path='projects/{}/material-runs'.format(collection.project_id),
        params={
            'dataset_id': str(collection.dataset_id),
            'tags': ["color"],
            'page': 1,
            'per_page': 10,
        })
    assert session.last_call == expected_call
    assert len(runs) == 1
    assert runs[0].uids == sample_run['uids']

    # When user gives a single string for tags, it should still work.
    session.set_response({'contents': [sample_run]})
    collection.filter_by_tags(tags="color", page=1, per_page=10)

    # Then the call is identical to the single-element-list form.
    assert session.num_calls == 2
    assert session.last_call == expected_call

    # When user gives multiple tags, should raise NotImplemented Error
    with pytest.raises(NotImplementedError):
        collection.filter_by_tags(tags=["color", "shape"])
def test_filter_by_spec(collection: MaterialRunCollection, session):
    """Test that MaterialRunCollection.filter_by_spec() hits the expected endpoint."""
    # Given
    project_id = '6b608f78-e341-422c-8076-35adc8828545'
    material_spec = MaterialSpecFactory()
    test_scope = 'id'
    test_id = material_spec.uids[test_scope]
    sample_run = MaterialRunDataFactory(spec=material_spec)
    session.set_response({'contents': [sample_run]})

    # When
    runs = list(collection.filter_by_spec(test_id, per_page=20))

    # Then
    assert session.num_calls == 1
    expected_call = FakeCall(
        method="GET",
        path="projects/{}/material-specs/{}/{}/material-runs".format(
            project_id, test_scope, test_id),
        # per_page will be ignored
        params={
            "dataset_id": str(collection.dataset_id),
            "forward": True,
            "ascending": True,
            "per_page": 100,
        })
    assert session.last_call == expected_call
    assert runs == [collection.build(sample_run)]
def test_user_registration(collection, session):
    """register() POSTs the user fields and builds the user from the response."""
    # given
    user = UserDataFactory()
    session.set_response({'user': user})

    # When
    created_user = collection.register(screen_name=user["screen_name"],
                                       email=user["email"],
                                       position=user["position"],
                                       is_admin=user["is_admin"])

    # Then
    assert session.num_calls == 1
    expected_call = FakeCall(method='POST', path='/users', json={
        'screen_name': user["screen_name"],
        'position': user["position"],
        'email': user["email"],
        'is_admin': user["is_admin"],
    })
    # NOTE(review): this only checks the expected payload against the returned
    # user; it never asserts expected_call == session.last_call. Consider
    # tightening — left as-is here to preserve the test's behavior.
    assert created_user.screen_name == expected_call.json['screen_name']
    assert created_user.email == expected_call.json['email']
    assert created_user.position == expected_call.json['position']
    assert created_user.is_admin == expected_call.json['is_admin']
def test_register_dataset_with_idempotent_put(collection, session):
    """With idempotent PUT enabled, register() PUTs to the collection route."""
    # Given
    name = 'Test Dataset'
    summary = 'testing summary'
    description = 'testing description'
    unique_name = 'foo'
    session.set_response(DatasetDataFactory(name=name,
                                            summary=summary,
                                            description=description,
                                            unique_name=unique_name))

    # When the session opts in to idempotent dataset PUTs
    session.use_idempotent_dataset_put = True
    dataset = collection.register(DatasetFactory(name=name,
                                                 summary=summary,
                                                 description=description,
                                                 unique_name=unique_name))

    # Then a single PUT (no dataset id in the path) is made.
    expected_call = FakeCall(
        method='PUT',
        path='projects/{}/datasets'.format(collection.project_id),
        json={
            'name': name,
            'summary': summary,
            'description': description,
            'unique_name': unique_name,
        })
    assert session.num_calls == 1
    assert session.last_call == expected_call
    assert dataset.name == name
def test_validate_templates_successful_all_params(collection, session):
    """
    Test that DataObjectCollection.validate_templates() handles a successful
    return value when passing in all params.
    """
    # Given
    project_id = '6b608f78-e341-422c-8076-35adc8828545'
    run = MaterialRunFactory(name="validate_templates_successful")
    template = MaterialTemplateFactory()
    unused_process_template = ProcessTemplateFactory()

    # When the API reports no validation problems (empty body)
    session.set_response("")
    errors = collection.validate_templates(run, template, unused_process_template)

    # Then every supplied object is serialized (None-scrubbed) into the payload.
    assert session.num_calls == 1
    expected_call = FakeCall(
        method="PUT",
        path="projects/{}/material-runs/validate-templates".format(project_id),
        json={
            "dataObject": scrub_none(run.dump()),
            "objectTemplate": scrub_none(template.dump()),
            "ingredientProcessTemplate": scrub_none(unused_process_template.dump()),
        })
    assert session.last_call == expected_call
    assert errors == []
def test_register_dataset_with_existing_id(collection, session):
    """register() on a dataset that already has a uid PUTs to the id-specific route."""
    # Given
    name = 'Test Dataset'
    summary = 'testing summary'
    description = 'testing description'
    session.set_response(DatasetDataFactory(name=name,
                                            summary=summary,
                                            description=description))

    # When a dataset with a pre-assigned uid is registered
    dataset = DatasetFactory(name=name, summary=summary, description=description)
    ds_uid = UUID('cafebeef-e341-422c-8076-35adc8828545')
    dataset.uid = ds_uid
    dataset = collection.register(dataset)

    # Then the PUT path contains the uid and the payload carries it as 'id'.
    expected_call = FakeCall(
        method='PUT',
        path='projects/{}/datasets/{}'.format(collection.project_id, ds_uid),
        json={
            'name': name,
            'summary': summary,
            'description': description,
            'id': str(ds_uid),
        })
    assert session.num_calls == 1
    assert session.last_call == expected_call
    assert dataset.name == name
def test_validate_templates_errors(collection, session):
    """
    Test that DataObjectCollection.validate_templates() handles validation
    errors.
    """
    # Given
    project_id = '6b608f78-e341-422c-8076-35adc8828545'
    run = MaterialRunFactory(name="")

    # When the fake session responds with a BadRequest carrying one error
    validation_error = ValidationError(failure_message="you failed",
                                       failure_id="failure_id")
    session.set_response(BadRequest(
        "path",
        FakeRequestResponseApiError(400, "Bad Request", [validation_error])))
    errors = collection.validate_templates(run)

    # Then the error is surfaced rather than raised.
    assert session.num_calls == 1
    expected_call = FakeCall(
        method="PUT",
        path="projects/{}/material-runs/validate-templates".format(project_id),
        json={"dataObject": scrub_none(run.dump())})
    assert session.last_call == expected_call
    assert errors == [validation_error]
def test_project_registration(collection, session):
    """register() POSTs only the name; the rest is filled from the response."""
    # Given
    create_time = parse('2019-09-10T00:00:00+00:00')
    project_data = ProjectDataFactory(
        name='testing',
        description='A sample project',
        created_at=int(create_time.timestamp() * 1000)  # The lib expects ms since epoch, which is really odd
    )
    session.set_response({'project': project_data})

    # When
    created_project = collection.register('testing')

    # Then the POST body carries the name and null placeholders for the rest.
    assert session.num_calls == 1
    expected_call = FakeCall(method='POST', path='/projects', json={
        'name': 'testing',
        'description': None,
        'id': None,
        'status': None,
        'created_at': None,
    })
    assert session.last_call == expected_call
    assert created_project.description == 'A sample project'
    assert created_project.status == 'CREATED'
    assert created_project.created_at == create_time
def test_search_projects(collection, session):
    """search() POSTs the search_params and yields only matching projects."""
    # Given two projects, only one of which matches the searched name
    projects_data = ProjectDataFactory.create_batch(2)
    project_name_to_match = projects_data[0]['name']
    expected_response = [p for p in projects_data
                         if p["name"] == project_name_to_match]
    session.set_response({'projects': expected_response})
    search_params = {
        'name': {
            'value': project_name_to_match,
            'search_method': 'EXACT'
        }
    }

    # When
    projects = list(collection.search(search_params=search_params))

    # Then
    assert session.num_calls == 1
    expected_call = FakeCall(method='POST',
                             path='/projects/search',
                             params={'per_page': 100},
                             json={'search_params': search_params})
    assert session.last_call == expected_call
    assert len(projects) == len(expected_response)
def test_filter_by_attribute_bounds(collection, session):
    """filter_by_attribute_bounds() POSTs serialized bounds keyed by the link id."""
    # Given
    sample_run = MaterialRunDataFactory()
    session.set_response({'contents': [sample_run]})
    link = LinkByUIDFactory()
    bounds = {link: IntegerBounds(1, 5)}

    # When
    runs = collection.filter_by_attribute_bounds(bounds, page=1, per_page=10)

    # Then
    assert session.num_calls == 1
    expected_call = FakeCall(
        method='POST',
        path='projects/{}/material-runs/filter-by-attribute-bounds'.format(
            collection.project_id),
        params={
            "page": 1,
            "per_page": 10,
            "dataset_id": str(collection.dataset_id),
        },
        json={
            'attribute_bounds': {
                link.id: {
                    'lower_bound': 1,
                    'upper_bound': 5,
                    'type': 'integer_bounds',
                }
            }
        })
    assert session.last_call == expected_call
    assert len(runs) == 1
    assert runs[0].uids == sample_run['uids']
def test_list_predictors(valid_simple_ml_predictor_data,
                         valid_expression_predictor_data,
                         basic_predictor_report_data):
    """Listing predictors pages the modules route and eagerly fetches reports."""
    # Given two predictors plus one report response per predictor
    session = FakeSession()
    collection = PredictorCollection(uuid.uuid4(), session)
    session.set_responses(
        {
            'entries': [valid_simple_ml_predictor_data,
                        valid_expression_predictor_data],
            'next': ''
        },
        basic_predictor_report_data,
        basic_predictor_report_data)

    # When
    predictors = list(collection.list(per_page=20))

    # Then
    expected_call = FakeCall(
        method='GET',
        path='/projects/{}/modules'.format(collection.project_id),
        params={
            'per_page': 20,
            'module_type': 'PREDICTOR',
        })
    assert 3 == session.num_calls, session.calls
    # This is a little strange, the report is fetched eagerly
    assert session.calls[0] == expected_call
    assert len(predictors) == 2
def test_archive(workflow, collection):
    """archive() PUTs the workflow's uid to the module archive route."""
    collection.archive(workflow.uid)
    archive_path = '/projects/{}/predictor-evaluation-workflows/archive'.format(
        collection.project_id)
    assert collection.session.last_call == FakeCall(
        method='PUT', path=archive_path, json={"module_uid": str(workflow.uid)})
def test_restore(workflow_execution, collection):
    """restore() PUTs the execution's uid to the executions restore route."""
    collection.restore(workflow_execution.uid)
    restore_path = '/projects/{}/predictor-evaluation-executions/restore'.format(
        collection.project_id)
    assert collection.session.last_call == FakeCall(
        method='PUT',
        path=restore_path,
        json={"module_uid": str(workflow_execution.uid)})
def test_get_table_config(collection, session):
    """Get table config, with or without version."""
    # Given
    project_id = '6b608f78-e341-422c-8076-35adc8828545'
    table_config_response = TableConfigResponseDataFactory()
    session.set_response(table_config_response)
    defn_id = table_config_response["definition"]["id"]
    ver_number = table_config_response["version"]["version_number"]

    # When a version is supplied, the versioned route is used.
    retrieved_table_config: TableConfig = collection.get(defn_id, ver_number)

    # Then
    assert session.num_calls == 1
    assert session.last_call == FakeCall(
        method="GET",
        path="projects/{}/ara-definitions/{}/versions/{}".format(
            project_id, defn_id, ver_number))
    assert str(retrieved_table_config.config_uid) == defn_id
    assert retrieved_table_config.version_number == ver_number

    # Given a listing of all versions
    table_configs_response = ListTableConfigResponseDataFactory()
    defn_id = table_configs_response["definition"]["id"]
    version_number = max(
        version_dict["version_number"]
        for version_dict in table_configs_response["versions"])
    session.set_response(table_configs_response)

    # When no version is supplied, the latest version is resolved.
    retrieved_table_config: TableConfig = collection.get(defn_id)

    # Then
    assert session.num_calls == 2
    assert session.last_call == FakeCall(
        method="GET",
        path="projects/{}/ara-definitions/{}".format(project_id, defn_id))
    assert str(retrieved_table_config.config_uid) == defn_id
    assert retrieved_table_config.version_number == version_number
def test_delete_project(collection, session):
    """delete() issues a single DELETE against the project's route."""
    # Given
    uid = '151199ec-e9aa-49a1-ac8e-da722aaf74c4'

    # When
    resp = collection.delete(uid)

    # Then
    assert session.num_calls == 1
    assert session.last_call == FakeCall(method='DELETE',
                                         path='/projects/{}'.format(uid))
def test_list_workflows(basic_design_workflow_data, basic_performance_workflow_data):
    """list() queries the modules route once per workflow module type."""
    # Given
    session = FakeSession()
    workflow_collection = WorkflowCollection(project_id=uuid.uuid4(),
                                             session=session)
    session.set_responses(
        {
            'entries': [basic_design_workflow_data],
            'next': ''
        },
        {
            'entries': [basic_performance_workflow_data],
            'next': ''
        },
    )

    # When
    workflows = list(workflow_collection.list(per_page=20))

    # Then: one call per module type, results typed accordingly.
    modules_path = '/projects/{}/modules'.format(workflow_collection.project_id)
    expected_design_call = FakeCall(
        method='GET',
        path=modules_path,
        params={'per_page': 20, 'module_type': 'DESIGN_WORKFLOW'})
    expected_performance_call = FakeCall(
        method='GET',
        path=modules_path,
        params={'per_page': 20, 'module_type': 'PERFORMANCE_WORKFLOW'})
    assert session.num_calls == 2
    assert session.calls[0] == expected_design_call
    assert session.calls[1] == expected_performance_call
    assert len(workflows) == 2
    assert isinstance(workflows[0], DesignWorkflow)
    assert isinstance(workflows[1], PerformanceWorkflow)
def test_list_projects_with_page_params(collection, session):
    """list() passes explicit page/per_page through as query params."""
    # Given
    project_data = ProjectDataFactory()
    session.set_response({'projects': [project_data]})

    # When
    list(collection.list(page=3, per_page=10))

    # Then
    assert session.num_calls == 1
    assert session.last_call == FakeCall(method='GET',
                                         path='/projects',
                                         params={'page': 3, 'per_page': 10})
def test_list_members(project, session):
    """list_members() GETs the project's users route exactly once."""
    # Given
    user = UserDataFactory()
    session.set_response({'users': [user]})

    # When
    project.list_members()

    # Then
    assert session.num_calls == 1
    assert session.last_call == FakeCall(
        method='GET', path='/projects/{}/users'.format(project.uid))
def test_list_projects(collection, session):
    """list() with defaults GETs /projects once with the default page size."""
    # Given
    projects_data = ProjectDataFactory.create_batch(5)
    session.set_response({'projects': projects_data})

    # When
    projects = list(collection.list())

    # Then
    assert session.num_calls == 1
    assert session.last_call == FakeCall(method='GET',
                                         path='/projects',
                                         params={'per_page': 1000})
    assert len(projects) == 5
def test_creator(project, session):
    """creator() GETs the project's creator route and returns the email."""
    # Given
    email = '*****@*****.**'
    session.set_response({'email': email})

    # When
    creator = project.creator()

    # Then
    assert session.num_calls == 1
    assert session.last_call == FakeCall(
        method='GET', path='/projects/{}/creator'.format(project.uid))
    assert creator == email