コード例 #1
0
def test_list_predictors(valid_simple_ml_predictor_data,
                         valid_expression_predictor_data,
                         basic_predictor_report_data):
    """Listing predictors queries the modules route and yields both entries."""
    # Given: one page holding two predictors, followed by two report payloads
    # (presumably one per listed predictor -- see the eager-fetch note below).
    session = FakeSession()
    collection = PredictorCollection(uuid.uuid4(), session)
    single_page = {
        'entries': [valid_simple_ml_predictor_data, valid_expression_predictor_data],
        'next': '',
    }
    session.set_responses(single_page,
                          basic_predictor_report_data,
                          basic_predictor_report_data)

    # When
    predictors = list(collection.list(per_page=20))

    # Then
    modules_path = '/projects/{}/modules'.format(collection.project_id)
    query = {'per_page': 20, 'module_type': 'PREDICTOR'}
    expected_call = FakeCall(method='GET', path=modules_path, params=query)
    # Three calls rather than one: a little strange, but the report is fetched eagerly.
    assert 3 == session.num_calls, session.calls
    assert expected_call == session.calls[0]
    assert len(predictors) == 2
def test_create_default(predictor_evaluation_workflow_dict: dict,
                        workflow: PredictorEvaluationWorkflow):
    """create_default turns the canned response into a workflow that dumps like the fixture."""
    # Given: a session primed with the serialized workflow
    fake_session = FakeSession()
    fake_session.set_response(predictor_evaluation_workflow_dict)
    workflows = PredictorEvaluationWorkflowCollection(
        project_id=uuid.uuid4(),
        session=fake_session,
    )

    # When: any id will do, the fake session ignores it
    arbitrary_id = uuid.uuid4()
    created = workflows.create_default(arbitrary_id)

    # Then
    assert created.dump() == workflow.dump()
コード例 #3
0
def test_create_default(valid_product_design_space_data,
                        valid_product_design_space):
    """create_default parses an 'instance'-keyed response into the expected design space."""
    # Objects returned from this route keep the field name 'instance' instead of
    # 'config', so move the fixture's 'config' entry to mimic the API payload.
    api_payload = deepcopy(valid_product_design_space_data)
    config = api_payload.pop('config')
    api_payload['instance'] = config

    fake_session = FakeSession()
    fake_session.set_response(api_payload)
    design_spaces = DesignSpaceCollection(project_id=uuid.uuid4(), session=fake_session)

    created = design_spaces.create_default(uuid.uuid4())

    assert created.dump() == valid_product_design_space.dump()
コード例 #4
0
def test_check_update_none():
    """A non-updatable response yields None; the first call is a GET to check-for-update."""
    # Given
    fake_session = FakeSession()
    fake_session.set_response({"updatable": False})
    predictors = PredictorCollection(uuid.uuid4(), fake_session)
    predictor_id = uuid.uuid4()

    # When
    outcome = predictors.check_for_update(predictor_id)

    # Then
    assert outcome is None
    expected_path = '/projects/{}/predictors/{}/check-for-update'.format(
        predictors.project_id, predictor_id)
    expected_call = FakeCall(method='GET', path=expected_path)
    assert fake_session.calls[0] == expected_call
コード例 #5
0
def dataset():
    """Return a factory-built 'Test Dataset' with fixed ids and its own FakeSession."""
    built = DatasetFactory(name='Test Dataset')
    # Pin the identifiers to fixed values instead of whatever the factory assigned.
    built.project_id = UUID('6b608f78-e341-422c-8076-35adc8828545')
    built.uid = UUID("503d7bf6-8e2d-4d29-88af-257af0d4fe4a")
    built.session = FakeSession()
    return built
コード例 #6
0
def test_unexpected_pattern():
    """auto_configure must reject an unrecognised pattern name with a ValueError."""
    # Given
    predictors = PredictorCollection(uuid.uuid4(), FakeSession())
    data_source = GemTableDataSource(uuid.uuid4(), 0)

    # Then: only the auto_configure call itself is expected to raise
    with pytest.raises(ValueError):
        predictors.auto_configure(data_source, "yogurt")
コード例 #7
0
def test_mark_predictor_invalid(valid_simple_ml_predictor_data, valid_predictor_report_data):
    """Updating a predictor with archived=False sends exactly one PUT carrying that flag."""
    # Given
    fake_session = FakeSession()
    predictors = PredictorCollection(uuid.uuid4(), fake_session)
    predictor = SimpleMLPredictor.build(valid_simple_ml_predictor_data)
    fake_session.set_responses(valid_simple_ml_predictor_data, valid_predictor_report_data)

    # When
    predictor.archived = False
    predictors.update(predictor)

    # Then: the update is the only call that reached the session
    assert 1 == fake_session.num_calls, fake_session.calls

    update_call = fake_session.calls[0]
    expected_path = '/projects/{}/modules/{}'.format(predictors.project_id, predictor.uid)
    assert update_call.method == 'PUT'
    assert update_call.path == expected_path
    assert not update_call.json['archived']
コード例 #8
0
def test_list_workflows(basic_design_workflow_data,
                        basic_performance_workflow_data):
    """Listing workflows queries design then performance modules and returns one of each."""
    # Given: one single-entry page per workflow type, served in that order
    fake_session = FakeSession()
    workflow_collection = WorkflowCollection(project_id=uuid.uuid4(),
                                             session=fake_session)
    design_page = {'entries': [basic_design_workflow_data], 'next': ''}
    performance_page = {'entries': [basic_performance_workflow_data], 'next': ''}
    fake_session.set_responses(design_page, performance_page)

    # When
    listed = list(workflow_collection.list(per_page=20))

    # Then
    modules_path = '/projects/{}/modules'.format(workflow_collection.project_id)

    def expected_get(module_type):
        """Build the GET call expected for one module type."""
        return FakeCall(method='GET',
                        path=modules_path,
                        params={'per_page': 20, 'module_type': module_type})

    expected_design_call = expected_get('DESIGN_WORKFLOW')
    expected_performance_call = expected_get('PERFORMANCE_WORKFLOW')
    assert 2 == fake_session.num_calls
    assert expected_design_call == fake_session.calls[0]
    assert expected_performance_call == fake_session.calls[1]
    assert len(listed) == 2
    assert isinstance(listed[0], DesignWorkflow)
    assert isinstance(listed[1], PerformanceWorkflow)
コード例 #9
0
def test_from_predictor_responses():
    """from_predictor_responses parses the canned response into RealDescriptors via a POST.

    The expected descriptors are built from every field of each canned response
    entry, including ``units``, so the comparison does not rely on a constructor
    default happening to match the payload.
    """
    session = FakeSession()
    col = 'smiles'
    response_json = {
        'responses': [  # shortened sample response
            {
                'category': 'Real',
                'descriptor_key': 'khs.sNH3 KierHallSmarts for {}'.format(col),
                'units': '',
                'lower_bound': 0,
                'upper_bound': 1000000000
            },
            {
                'category': 'Real',
                'descriptor_key': 'khs.dsN KierHallSmarts for {}'.format(col),
                'units': '',
                'lower_bound': 0,
                'upper_bound': 1000000000
            },
        ]
    }
    session.set_response(response_json)
    descriptors = DescriptorMethods(uuid4(), session)
    featurizer = MolecularStructureFeaturizer(
        name="Molecule featurizer",
        description="description",
        descriptor=MolecularStructureDescriptor(col),
        features=["all"],
        excludes=["standard"])
    results = descriptors.from_predictor_responses(
        featurizer, [MolecularStructureDescriptor(col)])
    assert results == [
        RealDescriptor(
            key=r['descriptor_key'],
            lower_bound=r['lower_bound'],
            upper_bound=r['upper_bound'],
            # Carry the units through, mirroring test_descriptors_from_data_source.
            units=r['units'],
        ) for r in response_json['responses']
    ]
    assert session.last_call.path == '/projects/{}/material-descriptors/predictor-responses'\
        .format(descriptors.project_id)
    assert session.last_call.method == 'POST'
コード例 #10
0
def test_returned_predictor(valid_graph_predictor_data):
    """Check that auto_configure works on the happy path."""
    # Given
    fake_session = FakeSession()

    # The canned response carries its payload under 'instance' instead of 'config'
    api_payload = deepcopy(valid_graph_predictor_data)
    api_payload["instance"] = api_payload.pop("config")
    fake_session.set_response(api_payload)

    predictors = PredictorCollection(uuid.uuid4(), fake_session)
    data_source = GemTableDataSource(uuid.uuid4(), 0)

    # When
    configured = predictors.auto_configure(data_source, "PLAIN")

    # Then the response is parsed into a predictor
    assert configured.name == valid_graph_predictor_data["display_name"]
    assert isinstance(configured, GraphPredictor)
    # including nested predictors
    nested = configured.predictors
    assert len(nested) == 2
    assert isinstance(nested[0], uuid.UUID)
    assert isinstance(nested[1], DeprecatedExpressionPredictor)
コード例 #11
0
def test_descriptors_from_data_source():
    """descriptors_from_data_source parses the canned response into RealDescriptors via a POST."""
    fake_session = FakeSession()
    column = 'smiles'
    descriptor_keys = [
        'khs.sNH3 KierHallSmarts for {}'.format(column),
        'khs.dsN KierHallSmarts for {}'.format(column),
    ]
    # Shortened sample response: the entries differ only in their descriptor key.
    canned = {
        'descriptors': [
            {
                'category': 'Real',
                'descriptor_key': key,
                'units': '',
                'lower_bound': 0,
                'upper_bound': 1000000000,
            }
            for key in descriptor_keys
        ]
    }
    fake_session.set_response(canned)
    descriptors = DescriptorMethods(uuid4(), fake_session)
    source = GemTableDataSource('43357a66-3644-4959-8115-77b2630aca45', 123)

    parsed = descriptors.descriptors_from_data_source(source)

    expected = [
        RealDescriptor(key=entry['descriptor_key'],
                       lower_bound=entry['lower_bound'],
                       upper_bound=entry['upper_bound'],
                       units=entry['units'])
        for entry in canned['descriptors']
    ]
    assert parsed == expected
    expected_path = '/projects/{}/material-descriptors/from-data-source'.format(
        descriptors.project_id)
    assert fake_session.last_call.path == expected_path
    assert fake_session.last_call.method == 'POST'
コード例 #12
0
def test_check_update_some():
    """Test the update check correctly builds a module."""
    # Given
    fake_session = FakeSession()
    output_descriptor = RealDescriptor("spam", 0, 1, "kg")
    update_payload = {
        "type": "AnalyticExpression",
        "name": "foo",
        "description": "bar",
        "expression": "2 * x",
        "output": output_descriptor.dump(),
        "aliases": {}
    }
    fake_session.set_response({"updatable": True, "update": update_payload})
    predictors = PredictorCollection(uuid.uuid4(), fake_session)
    predictor_id = uuid.uuid4()

    # When
    outcome = predictors.check_for_update(predictor_id)

    # Then: the payload is built into a predictor that carries the queried id
    expected = ExpressionPredictor("foo", "bar", "2 * x", output_descriptor, {})
    assert outcome.dump() == expected.dump()
    assert outcome.uid == predictor_id
コード例 #13
0
def test_design_space_limits():
    """Test that the validation logic is triggered before post/put-ing enumerated design spaces."""
    # Given
    fake_session = FakeSession()
    design_spaces = DesignSpaceCollection(uuid.uuid4(), fake_session)
    column_names = ["R-{}".format(i) for i in range(128)]

    def enumerated_space(row_count):
        """Build a 128-column design space with row_count rows of random() values."""
        return EnumeratedDesignSpace(
            "foo",
            "bar",
            descriptors=[RealDescriptor(name, 0, 1, "") for name in column_names],
            data=[{name: random() for name in column_names}
                  for _ in range(row_count)])

    too_big = enumerated_space(2001)
    just_right = enumerated_space(2000)

    # Create a mock post response by setting the status
    ready_response = just_right.dump()
    ready_response["status"] = "READY"
    fake_session.responses.append(ready_response)

    # Then: the oversized space is rejected with a ValueError
    with pytest.raises(ValueError) as excinfo:
        design_spaces.register(too_big)
    assert "only supports" in str(excinfo.value)

    # Test register
    design_spaces.register(just_right)

    # Add back the response for the next call
    fake_session.responses.append(ready_response)

    # Test update
    design_spaces.update(just_right)
コード例 #14
0
def session() -> FakeSession:
    """Return a newly constructed FakeSession."""
    fake_session = FakeSession()
    return fake_session
コード例 #15
0
def test_user_collection_creation():
    """A UserCollection exposes the session it was constructed with."""
    fake_session = FakeSession()
    users = UserCollection(fake_session)

    assert fake_session == users.session
コード例 #16
0
 def __init__(self, num_properties):
     """Record the property count alongside a uuid4() project id and a new FakeSession."""
     self.num_properties = num_properties
     self.project_id = uuid4()
     self.session = FakeSession()