# Example #1
def test_list_datasets_infinite_loop_detect(paginated_collection,
                                            paginated_session):
    """Listing stops as soon as the API starts repeating a page.

    The fake session is primed with the same batch of datasets twice in a row,
    mimicking a server that keeps returning the first page. The duplicate UIDs
    are detected on the second request, so only the first batch is yielded.
    """
    # Given
    per_page = 100
    first_page = DatasetDataFactory.create_batch(per_page)
    # The second "page" is an exact copy of the first one.
    paginated_session.set_response(first_page + first_page)
    datasets_path = 'projects/{}/datasets'.format(paginated_collection.project_id)

    # When
    datasets = list(paginated_collection.list())

    # Then
    assert 2 == paginated_session.num_calls  # duplicate UID detected on the second call
    assert FakeCall(method='GET',
                    path=datasets_path,
                    params={'per_page': per_page}) == paginated_session.calls[0]
    assert FakeCall(method='GET',
                    path=datasets_path,
                    params={'page': 2, 'per_page': per_page}) == paginated_session.last_call
    assert len(datasets) == per_page

    listed_ids = [str(ds.uid) for ds in datasets]
    assert listed_ids == [record['id'] for record in first_page]
# Example #2
def test_search_projects_with_pagination(paginated_collection, paginated_session):
    """Searching projects walks every page of matching results.

    35 projects share a name. At 10 per page the collection must issue four
    POST requests, each carrying the same search body.
    """
    # Given
    common_name = "same name"
    per_page = 10
    same_name_projects_data = ProjectDataFactory.create_batch(35, name=common_name)
    paginated_session.set_response({'projects': same_name_projects_data})

    # Search on the project *name*, which is what the fixture data has in common.
    search_params = {
        'name': {
            'value': common_name,
            'search_method': 'EXACT',
        }
    }

    # When
    projects = list(paginated_collection.search(per_page=per_page,
                                                search_params=search_params))

    # Then
    assert 4 == paginated_session.num_calls  # ceil(35 / 10) pages
    expected_first_call = FakeCall(method='POST', path='/projects/search',
                                   params={'per_page': per_page},
                                   json={'search_params': search_params})
    expected_last_call = FakeCall(method='POST', path='/projects/search',
                                  params={'page': 4, 'per_page': per_page},
                                  json={'search_params': search_params})
    assert expected_first_call == paginated_session.calls[0]
    assert expected_last_call == paginated_session.last_call

    project_ids = [str(p.uid) for p in projects]
    expected_ids = [p['id'] for p in same_name_projects_data]
    assert project_ids == expected_ids
# Example #3
def test_list_datasets(paginated_collection, paginated_session):
    """Listing 50 datasets at 20 per page takes three GET requests."""
    # Given
    page_size = 20
    datasets_data = DatasetDataFactory.create_batch(50)
    paginated_session.set_response(datasets_data)
    datasets_path = 'projects/{}/datasets'.format(paginated_collection.project_id)

    # When
    datasets = list(paginated_collection.list(per_page=page_size))

    # Then
    assert 3 == paginated_session.num_calls
    first_call = FakeCall(method='GET', path=datasets_path,
                          params={'per_page': page_size})
    last_call = FakeCall(method='GET', path=datasets_path,
                         params={'page': 3, 'per_page': page_size})
    assert first_call == paginated_session.calls[0]
    assert last_call == paginated_session.last_call
    assert 50 == len(datasets)

    assert [str(ds.uid) for ds in datasets] == [record['id'] for record in datasets_data]
# Example #4
def test_default_for_material(collection: AraDefinitionCollection, session):
    """Test that default for material hits the right route.

    The material is given first as a bare id string, then as a MaterialRun.
    In the second case the request's id/scope come from the run's ``uids``
    entry, and the explicit ``scope`` argument is not used.
    """
    # Given
    project_id = '6b608f78-e341-422c-8076-35adc8828545'
    default_path = "projects/{}/table-configs/default".format(project_id)
    config_dump = TableConfig(
        name='foo',
        description='foo',
        variables=[],
        columns=[],
        rows=[],
        datasets=[]
    ).dump()
    ambiguous_pair = [
        RootIdentifier(name='foo', headers=['foo'], scope='id').dump(),
        IdentityColumn(data_source='foo').dump(),
    ]
    dummy_resp = {'config': config_dump, 'ambiguous': [ambiguous_pair]}

    # When: the material is a plain id, so the given scope is forwarded
    session.responses.append(dummy_resp)
    collection.default_for_material(
        material='my_id',
        scope='my_scope',
        name='my_name',
        description='my_description',
    )

    # Then
    assert 1 == session.num_calls
    assert session.last_call == FakeCall(
        method="GET",
        path=default_path,
        params={
            'id': 'my_id',
            'scope': 'my_scope',
            'name': 'my_name',
            'description': 'my_description'
        }
    )

    # When: the material is a run whose uids map scope 'scope' -> id 'id'
    session.calls.clear()
    session.responses.append(dummy_resp)
    collection.default_for_material(
        material=MaterialRun('foo', uids={'scope': 'id'}),
        scope='ignored',
        name='my_name',
        description='my_description',
    )

    # Then: id and scope come from the run, not from the scope argument
    assert 1 == session.num_calls
    assert session.last_call == FakeCall(
        method="GET",
        path=default_path,
        params={
            'id': 'id',
            'scope': 'scope',
            'name': 'my_name',
            'description': 'my_description'
        }
    )
# Example #5
def test_get_table_metadata(collection, session):
    """Fetch a GEM table by id and version, then by id alone (highest version)."""
    # Given
    project_id = '6b608f78-e341-422c-8076-35adc8828545'
    gem_table = GemTableDataFactory()
    table_id, table_version = gem_table["id"], gem_table["version"]
    session.set_response(gem_table)

    # When
    retrieved_table: GemTable = collection.get(table_id, table_version)

    # Then
    assert 1 == session.num_calls
    versioned_path = "projects/{}/display-tables/{}/versions/{}".format(
        project_id, table_id, table_version)
    assert session.last_call == FakeCall(method="GET", path=versioned_path)
    assert str(retrieved_table.uid) == table_id
    assert retrieved_table.version == table_version
    assert retrieved_table.download_url == gem_table["signed_download_url"]

    # Given: a listing of several versions of the table
    gem_tables = ListGemTableVersionsDataFactory()
    session.set_response(gem_tables)
    highest_version = max(entry["version"] for entry in gem_tables["tables"])

    # When
    latest_table = collection.get(table_id)

    # Then: the highest version number is the one returned
    assert latest_table.version == highest_version
# Example #6
def run_noop_gemd_relation_search_test(search_for,
                                       search_with,
                                       collection,
                                       search_fn,
                                       per_page=100):
    """Test that relation searches hit the correct endpoint.

    Primes the session with an empty ``contents`` page and runs ``search_fn``
    for a fixed scope/id pair. It then checks that exactly one GET was issued
    to ``projects/<project_id>/<search_with>/<scope>/<id>/<search_for>``.

    :param search_for: path segment expected after the scope/id pair
    :param search_with: path segment expected before the scope/id pair
    :param collection: collection providing ``session``, ``project_id`` and ``dataset_id``
    :param search_fn: callable invoked as ``search_fn(id, scope=...)``; may return an iterator
    :param per_page: page size expected in the request parameters
    """
    session = collection.session
    session.set_response({'contents': []})
    scope, uid = 'foo-scope', 'foo-id'

    outcome = search_fn(uid, scope=scope)
    if isinstance(outcome, Iterator):
        # Lazy searches only call the session once they are consumed.
        list(outcome)

    assert session.num_calls == 1
    expected_path = "projects/{}/{}/{}/{}/{}".format(
        collection.project_id, search_with, scope, uid, search_for)
    expected_params = {
        "dataset_id": str(collection.dataset_id),
        "forward": True,
        "ascending": True,
        "per_page": per_page,
    }
    assert session.last_call == FakeCall(method="GET",
                                         path=expected_path,
                                         params=expected_params)
def test_list(collection: PredictorEvaluationExecutionCollection, session):
    """Listing executions forwards page, per_page, predictor and workflow ids."""
    # Given
    session.set_response({
        "page": 2,
        "per_page": 4,
        "next": "foo",
        "response": []
    })
    predictor_id = uuid.uuid4()

    # When
    executions = list(collection.list(2, 4, predictor_id=predictor_id))

    # Then
    assert executions == []

    expected_path = '/projects/{}/predictor-evaluation-executions'.format(
        collection.project_id)
    expected_params = {
        "page": 2,
        "per_page": 4,
        "predictor_id": str(predictor_id),
        "workflow_id": str(collection.workflow_id),
    }
    assert session.last_call == FakeCall(method='GET',
                                         path=expected_path,
                                         params=expected_params)
def test_file_download(mock_write_file_locally, collection, session):
    """
    Test that downloading a file works as expected.

    When only a directory is given as the local path, the download should:
    request a pre-signed URL from the file's ``/content-link`` endpoint; GET the
    bytes from that URL exactly once; and hand them to the (mocked) local writer
    together with ``local_path + file.filename``.
    """
    # Given
    filename = 'diagram.pdf'
    file_url = "http://citrine.com/api/files/123/versions/456"
    file_link = FileLink.build(FileLinkDataFactory(url=file_url, filename=filename))
    signed_url = "http://files.citrine.io/secret-codes/jiifema987pjfsda"  # arbitrary
    session.set_response({'pre_signed_read_link': signed_url})
    target_dir = 'Users/me/some/new/directory/'

    with requests_mock.mock() as mocked_http:
        mocked_http.get(signed_url, text='0101001')

        # When
        collection.download(file_link, target_dir)

        # Then
        assert mocked_http.call_count == 1
        content_link_call = FakeCall(method='GET', path=file_url + '/content-link')
        assert content_link_call == session.last_call
        assert mock_write_file_locally.call_count == 1
        assert mock_write_file_locally.call_args == call(
            b'0101001', target_dir + file_link.filename)
# Example #9
def test_restore(workflow, collection):
    """Restoring a design workflow PUTs an empty body to its ``restore`` route."""
    # NOTE(review): a second ``test_restore`` is defined later in this listing; if
    # both end up in one module, the later definition shadows this one.
    collection.restore(workflow.uid)

    restore_path = '/projects/{}/design-workflows/{}/restore'.format(
        collection.project_id, workflow.uid)
    expected = FakeCall(method='PUT', path=restore_path, json={})
    assert collection.session.last_call == expected
def test_delete_contents_ok(dataset):
    """Deleting a dataset's contents with no reported failures yields an empty result."""
    # Given
    job_id_resp = {'job_id': '1234'}
    job_status_resp = {
        'job_type': 'batch_delete',
        'status': 'Success',
        'tasks': [],
        'output': {
            # Keep in mind this is a stringified JSON value. Eww.
            'failures': '[]'
        }
    }
    session = dataset.session
    session.set_responses(job_id_resp, job_status_resp)

    # When
    result = dataset.delete_contents()

    # Then
    assert len(result) == 0

    # Two calls in total; the first must be the DELETE on the contents route.
    contents_path = 'projects/{}/datasets/{}/contents'.format(
        dataset.project_id, dataset.uid)
    assert len(session.calls) == 2
    assert session.calls[0] == FakeCall(method='DELETE', path=contents_path)
# Example #11
def test_filter_by_tags(collection, session):
    """Filter material runs by one tag, given either as a list or as a bare string."""
    # Given
    sample_run = MaterialRunDataFactory()
    session.set_response({'contents': [sample_run]})
    runs_path = 'projects/{}/material-runs'.format(collection.project_id)
    expected_call = FakeCall(method='GET',
                             path=runs_path,
                             params={
                                 'dataset_id': str(collection.dataset_id),
                                 'tags': ["color"],
                                 'page': 1,
                                 'per_page': 10
                             })

    # When
    runs = collection.filter_by_tags(tags=["color"], page=1, per_page=10)

    # Then
    assert 1 == session.num_calls
    assert expected_call == session.last_call
    assert 1 == len(runs)
    assert sample_run['uids'] == runs[0].uids

    # When: a single tag passed as a plain string produces the same request
    session.set_response({'contents': [sample_run]})
    collection.filter_by_tags(tags="color", page=1, per_page=10)

    # Then
    assert session.num_calls == 2
    assert session.last_call == expected_call

    # When/Then: more than one tag is rejected with NotImplementedError
    with pytest.raises(NotImplementedError):
        collection.filter_by_tags(tags=["color", "shape"])
# Example #12
def test_filter_by_spec(collection: MaterialRunCollection, session):
    """
    Test that MaterialRunCollection.filter_by_spec() hits the expected endpoint
    """
    # Given
    project_id = '6b608f78-e341-422c-8076-35adc8828545'
    material_spec = MaterialSpecFactory()
    scope = 'id'
    spec_id = material_spec.uids[scope]
    sample_run = MaterialRunDataFactory(spec=material_spec)
    session.set_response({'contents': [sample_run]})

    # When
    runs = list(collection.filter_by_spec(spec_id, per_page=20))

    # Then
    assert 1 == session.num_calls
    spec_runs_path = "projects/{}/material-specs/{}/{}/material-runs".format(
        project_id, scope, spec_id)
    # The requested per_page of 20 is not forwarded; the request carries 100.
    expected_params = {
        "dataset_id": str(collection.dataset_id),
        "forward": True,
        "ascending": True,
        "per_page": 100,
    }
    assert session.last_call == FakeCall(method="GET",
                                         path=spec_runs_path,
                                         params=expected_params)
    assert runs == [collection.build(sample_run)]
# Example #13
def test_user_registration(collection, session):
    """Registering a user POSTs its fields to ``/users`` and returns the created user."""
    # Given
    user = UserDataFactory()
    session.set_response({'user': user})

    # When
    created_user = collection.register(screen_name=user["screen_name"],
                                       email=user["email"],
                                       position=user["position"],
                                       is_admin=user["is_admin"])

    # Then
    assert 1 == session.num_calls
    expected_json = {
        'screen_name': user["screen_name"],
        'position': user["position"],
        'email': user["email"],
        'is_admin': user["is_admin"],
    }

    # Verify the request that was actually sent, not just the parsed reply.
    # The body is compared field by field so any additional serialized keys
    # (e.g. an unset id) do not make the test brittle.
    sent = session.last_call
    assert sent.method == 'POST'
    assert sent.path == '/users'
    for key, value in expected_json.items():
        assert sent.json[key] == value, key

    # The returned user reflects the registered fields.
    assert expected_json['screen_name'] == created_user.screen_name
    assert expected_json['email'] == created_user.email
    assert expected_json['position'] == created_user.position
    assert expected_json['is_admin'] == created_user.is_admin
def test_register_dataset_with_idempotent_put(collection, session):
    """With idempotent PUT enabled, a new dataset is PUT to the datasets route."""
    # Given
    fields = {
        'name': 'Test Dataset',
        'summary': 'testing summary',
        'description': 'testing description',
        'unique_name': 'foo',
    }
    session.set_response(DatasetDataFactory(**fields))
    session.use_idempotent_dataset_put = True

    # When
    dataset = collection.register(DatasetFactory(**fields))

    # Then
    datasets_path = 'projects/{}/datasets'.format(collection.project_id)
    expected_call = FakeCall(method='PUT', path=datasets_path, json=dict(fields))
    assert session.num_calls == 1
    assert expected_call == session.last_call
    assert fields['name'] == dataset.name
# Example #15
def test_validate_templates_successful_all_params(collection, session):
    """
    Test that DataObjectCollection.validate_templates() handles a successful return value when
    passing in all params
    """
    # Given
    project_id = '6b608f78-e341-422c-8076-35adc8828545'
    run = MaterialRunFactory(name="validate_templates_successful")
    template = MaterialTemplateFactory()
    unused_process_template = ProcessTemplateFactory()

    # When: an empty response body comes back
    session.set_response("")
    errors = collection.validate_templates(run, template, unused_process_template)

    # Then: every object is sent (scrubbed of None values) and no errors come back
    assert 1 == session.num_calls
    validate_path = "projects/{}/material-runs/validate-templates".format(project_id)
    expected_body = {
        "dataObject": scrub_none(run.dump()),
        "objectTemplate": scrub_none(template.dump()),
        "ingredientProcessTemplate": scrub_none(unused_process_template.dump()),
    }
    assert session.last_call == FakeCall(method="PUT", path=validate_path, json=expected_body)
    assert errors == []
def test_register_dataset_with_existing_id(collection, session):
    """A dataset that already has a uid is PUT to its own route, with the id in the body."""
    # Given
    name = 'Test Dataset'
    summary = 'testing summary'
    description = 'testing description'
    session.set_response(
        DatasetDataFactory(name=name, summary=summary, description=description))
    local_dataset = DatasetFactory(name=name, summary=summary, description=description)
    ds_uid = UUID('cafebeef-e341-422c-8076-35adc8828545')
    local_dataset.uid = ds_uid

    # When
    registered = collection.register(local_dataset)

    # Then
    dataset_path = 'projects/{}/datasets/{}'.format(collection.project_id, ds_uid)
    expected_body = {
        'name': name,
        'summary': summary,
        'description': description,
        'id': str(ds_uid),
    }
    assert session.num_calls == 1
    assert FakeCall(method='PUT', path=dataset_path, json=expected_body) == session.last_call
    assert name == registered.name
# Example #17
def test_validate_templates_errors(collection, session):
    """
    Test that DataObjectCollection.validate_templates() handles validation errors
    """
    # Given
    project_id = '6b608f78-e341-422c-8076-35adc8828545'
    run = MaterialRunFactory(name="")

    # When: the queued response is a 400 BadRequest carrying one validation error
    validation_error = ValidationError(failure_message="you failed",
                                       failure_id="failure_id")
    api_error = FakeRequestResponseApiError(400, "Bad Request", [validation_error])
    session.set_response(BadRequest("path", api_error))
    errors = collection.validate_templates(run)

    # Then: the error is surfaced as the return value
    assert 1 == session.num_calls
    validate_path = "projects/{}/material-runs/validate-templates".format(project_id)
    assert session.last_call == FakeCall(method="PUT",
                                         path=validate_path,
                                         json={"dataObject": scrub_none(run.dump())})
    assert errors == [validation_error]
# Example #18
def test_project_registration(collection, session):
    """Registering a project by name POSTs it and parses the returned project."""
    # Given
    create_time = parse('2019-09-10T00:00:00+00:00')
    # The lib expects ms since epoch, which is really odd
    created_at_ms = int(create_time.timestamp() * 1000)
    project_data = ProjectDataFactory(name='testing',
                                      description='A sample project',
                                      created_at=created_at_ms)
    session.set_response({'project': project_data})

    # When
    created_project = collection.register('testing')

    # Then: only the name is populated in the request body
    assert 1 == session.num_calls
    expected_body = {
        'name': 'testing',
        'description': None,
        'id': None,
        'status': None,
        'created_at': None,
    }
    assert FakeCall(method='POST', path='/projects', json=expected_body) == session.last_call

    assert 'A sample project' == created_project.description
    assert 'CREATED' == created_project.status
    assert create_time == created_project.created_at
# Example #19
def test_search_projects(collection, session):
    """An exact-name project search POSTs the search params once and returns the matches."""
    # Given
    projects_data = ProjectDataFactory.create_batch(2)
    target_name = projects_data[0]['name']
    matching = [project for project in projects_data if project["name"] == target_name]
    session.set_response({'projects': matching})
    search_params = {
        'name': {
            'value': target_name,
            'search_method': 'EXACT'
        }
    }

    # When
    projects = list(collection.search(search_params=search_params))

    # Then
    assert 1 == session.num_calls
    search_call = FakeCall(method='POST',
                           path='/projects/search',
                           params={'per_page': 100},
                           json={'search_params': search_params})
    assert search_call == session.last_call
    assert len(matching) == len(projects)
# Example #20
def test_filter_by_attribute_bounds(collection, session):
    """Filtering runs by attribute bounds POSTs the serialized bounds keyed by link id."""
    # Given
    sample_run = MaterialRunDataFactory()
    session.set_response({'contents': [sample_run]})
    link = LinkByUIDFactory()
    bounds = {link: IntegerBounds(1, 5)}

    # When
    runs = collection.filter_by_attribute_bounds(bounds, page=1, per_page=10)

    # Then
    assert 1 == session.num_calls
    bounds_path = 'projects/{}/material-runs/filter-by-attribute-bounds'.format(
        collection.project_id)
    serialized_bounds = {
        link.id: {'lower_bound': 1, 'upper_bound': 5, 'type': 'integer_bounds'}
    }
    expected_call = FakeCall(
        method='POST',
        path=bounds_path,
        params={"page": 1, "per_page": 10, "dataset_id": str(collection.dataset_id)},
        json={'attribute_bounds': serialized_bounds})
    assert expected_call == session.last_call
    assert 1 == len(runs)
    assert sample_run['uids'] == runs[0].uids
# Example #21
def test_list_predictors(valid_simple_ml_predictor_data,
                         valid_expression_predictor_data,
                         basic_predictor_report_data):
    """Listing predictors returns every entry on the single page of modules."""
    # Given
    session = FakeSession()
    collection = PredictorCollection(uuid.uuid4(), session)
    modules_page = {
        'entries': [valid_simple_ml_predictor_data, valid_expression_predictor_data],
        'next': '',
    }
    session.set_responses(modules_page,
                          basic_predictor_report_data,
                          basic_predictor_report_data)

    # When
    predictors = list(collection.list(per_page=20))

    # Then
    modules_call = FakeCall(method='GET',
                            path='/projects/{}/modules'.format(collection.project_id),
                            params={'per_page': 20, 'module_type': 'PREDICTOR'})
    # This is a little strange, the report is fetched eagerly: the module listing
    # plus the two queued report responses account for three calls.
    assert 3 == session.num_calls, session.calls
    assert modules_call == session.calls[0]
    assert len(predictors) == 2
def test_archive(workflow, collection):
    """Archiving a predictor evaluation workflow PUTs its uid to the archive route."""
    collection.archive(workflow.uid)

    archive_path = '/projects/{}/predictor-evaluation-workflows/archive'.format(
        collection.project_id)
    expected = FakeCall(method='PUT',
                        path=archive_path,
                        json={"module_uid": str(workflow.uid)})
    assert collection.session.last_call == expected
def test_restore(workflow_execution, collection):
    """Restoring a predictor evaluation execution PUTs its uid to the restore route."""
    # NOTE(review): an earlier ``test_restore`` exists in this listing; if both are
    # in one module, this later definition shadows the earlier one.
    collection.restore(workflow_execution.uid)

    restore_path = '/projects/{}/predictor-evaluation-executions/restore'.format(
        collection.project_id)
    expected = FakeCall(method='PUT',
                        path=restore_path,
                        json={"module_uid": str(workflow_execution.uid)})
    assert collection.session.last_call == expected
# Example #24
def test_get_table_config(collection, session):
    """Get table config, with or without version"""
    project_id = '6b608f78-e341-422c-8076-35adc8828545'

    # Given: a response for one specific version
    single_version = TableConfigResponseDataFactory()
    session.set_response(single_version)
    config_id = single_version["definition"]["id"]
    version = single_version["version"]["version_number"]

    # When
    by_version: TableConfig = collection.get(config_id, version)

    # Then
    assert 1 == session.num_calls
    versioned_path = "projects/{}/ara-definitions/{}/versions/{}".format(
        project_id, config_id, version)
    assert session.last_call == FakeCall(method="GET", path=versioned_path)
    assert str(by_version.config_uid) == config_id
    assert by_version.version_number == version

    # Given: a response listing every version of a definition
    all_versions = ListTableConfigResponseDataFactory()
    listed_config_id = all_versions["definition"]["id"]
    highest_version = max(entry["version_number"] for entry in all_versions["versions"])
    session.set_response(all_versions)

    # When
    latest: TableConfig = collection.get(listed_config_id)

    # Then: the unversioned route is used and the highest version is returned
    assert 2 == session.num_calls
    unversioned_path = "projects/{}/ara-definitions/{}".format(project_id, listed_config_id)
    assert session.last_call == FakeCall(method="GET", path=unversioned_path)
    assert str(latest.config_uid) == listed_config_id
    assert latest.version_number == highest_version
# Example #25
def test_delete_project(collection, session):
    """Deleting a project issues a single DELETE on its route."""
    # Given
    uid = '151199ec-e9aa-49a1-ac8e-da722aaf74c4'

    # When
    collection.delete(uid)

    # Then
    assert 1 == session.num_calls
    assert FakeCall(method='DELETE', path='/projects/{}'.format(uid)) == session.last_call
def test_list_workflows(basic_design_workflow_data,
                        basic_performance_workflow_data):
    """Listing workflows queries design workflows, then performance workflows."""
    # Given
    session = FakeSession()
    workflow_collection = WorkflowCollection(project_id=uuid.uuid4(),
                                             session=session)
    design_page = {'entries': [basic_design_workflow_data], 'next': ''}
    performance_page = {'entries': [basic_performance_workflow_data], 'next': ''}
    session.set_responses(design_page, performance_page)

    # When
    workflows = list(workflow_collection.list(per_page=20))

    # Then
    modules_path = '/projects/{}/modules'.format(workflow_collection.project_id)

    def modules_call(module_type):
        # One GET on the modules route per workflow type, 20 per page.
        return FakeCall(method='GET',
                        path=modules_path,
                        params={'per_page': 20, 'module_type': module_type})

    assert 2 == session.num_calls
    assert modules_call('DESIGN_WORKFLOW') == session.calls[0]
    assert modules_call('PERFORMANCE_WORKFLOW') == session.calls[1]
    assert len(workflows) == 2
    assert isinstance(workflows[0], DesignWorkflow)
    assert isinstance(workflows[1], PerformanceWorkflow)
# Example #27
def test_list_projects_with_page_params(collection, session):
    """Explicit page and per_page arguments are passed straight through to GET /projects."""
    # Given
    session.set_response({'projects': [ProjectDataFactory()]})

    # When
    list(collection.list(page=3, per_page=10))

    # Then
    assert 1 == session.num_calls
    paged_call = FakeCall(method='GET', path='/projects',
                          params={'page': 3, 'per_page': 10})
    assert paged_call == session.last_call
# Example #28
def test_list_members(project, session):
    """Listing a project's members issues one GET on its users route."""
    # Given
    session.set_response({'users': [UserDataFactory()]})

    # When
    project.list_members()

    # Then
    assert 1 == session.num_calls
    users_path = '/projects/{}/users'.format(project.uid)
    assert FakeCall(method='GET', path=users_path) == session.last_call
# Example #29
def test_list_projects(collection, session):
    """Listing projects without arguments requests 1000 per page and returns them all."""
    # Given
    session.set_response({'projects': ProjectDataFactory.create_batch(5)})

    # When
    projects = list(collection.list())

    # Then
    assert 1 == session.num_calls
    default_call = FakeCall(method='GET', path='/projects', params={'per_page': 1000})
    assert default_call == session.last_call
    assert 5 == len(projects)
# Example #30
def test_creator(project, session):
    """``project.creator()`` GETs the creator route and returns the email it reports."""
    # Given
    email = '*****@*****.**'
    session.set_response({'email': email})

    # When
    creator = project.creator()

    # Then
    assert 1 == session.num_calls
    creator_path = '/projects/{}/creator'.format(project.uid)
    assert FakeCall(method='GET', path=creator_path) == session.last_call
    assert creator == email