def test_list_predictors(valid_simple_ml_predictor_data, valid_expression_predictor_data,
                         basic_predictor_report_data):
    # Given
    session = FakeSession()
    collection = PredictorCollection(uuid.uuid4(), session)
    session.set_responses(
        {
            'entries': [valid_simple_ml_predictor_data, valid_expression_predictor_data],
            'next': ''
        },
        basic_predictor_report_data,
        basic_predictor_report_data
    )

    # When
    predictors = list(collection.list(per_page=20))

    # Then
    expected_call = FakeCall(
        method='GET',
        path='/projects/{}/modules'.format(collection.project_id),
        params={'per_page': 20, 'module_type': 'PREDICTOR'}
    )
    # This is a little strange: the report is fetched eagerly for each predictor
    assert 3 == session.num_calls, session.calls
    assert expected_call == session.calls[0]
    assert len(predictors) == 2

def test_create_default(predictor_evaluation_workflow_dict: dict,
                        workflow: PredictorEvaluationWorkflow):
    session = FakeSession()
    session.set_response(predictor_evaluation_workflow_dict)
    collection = PredictorEvaluationWorkflowCollection(project_id=uuid.uuid4(), session=session)
    default_workflow = collection.create_default(uuid.uuid4())
    assert default_workflow.dump() == workflow.dump()

def test_create_default(valid_product_design_space_data, valid_product_design_space):
    # The instance field isn't renamed to config in objects returned from this route,
    # so rename the config key to instance to match the data we get from the API.
    data_with_instance = deepcopy(valid_product_design_space_data)
    data_with_instance['instance'] = data_with_instance.pop('config')

    session = FakeSession()
    session.set_response(data_with_instance)
    collection = DesignSpaceCollection(project_id=uuid.uuid4(), session=session)
    default_design_space = collection.create_default(uuid.uuid4())
    assert default_design_space.dump() == valid_product_design_space.dump()

def test_check_update_none():
    """Test that check-for-updates makes the expected calls and parses output for no update."""
    # Given
    session = FakeSession()
    session.set_response({"updatable": False})
    pc = PredictorCollection(uuid.uuid4(), session)
    predictor_id = uuid.uuid4()

    # When
    update_check = pc.check_for_update(predictor_id)

    # Then
    assert update_check is None
    expected_call = FakeCall(
        method='GET',
        path='/projects/{}/predictors/{}/check-for-update'.format(pc.project_id, predictor_id)
    )
    assert session.calls[0] == expected_call

def dataset():
    dataset = DatasetFactory(name='Test Dataset')
    dataset.project_id = UUID('6b608f78-e341-422c-8076-35adc8828545')
    dataset.uid = UUID("503d7bf6-8e2d-4d29-88af-257af0d4fe4a")
    dataset.session = FakeSession()
    return dataset

def test_unexpected_pattern():
    """Check that unexpected patterns result in a ValueError."""
    # Given
    session = FakeSession()
    pc = PredictorCollection(uuid.uuid4(), session)

    # Then
    with pytest.raises(ValueError):
        pc.auto_configure(GemTableDataSource(uuid.uuid4(), 0), "yogurt")

def test_mark_predictor_invalid(valid_simple_ml_predictor_data, valid_predictor_report_data):
    # Given
    session = FakeSession()
    collection = PredictorCollection(uuid.uuid4(), session)
    predictor = SimpleMLPredictor.build(valid_simple_ml_predictor_data)
    session.set_responses(valid_simple_ml_predictor_data, valid_predictor_report_data)

    # When
    predictor.archived = False
    collection.update(predictor)

    # Then
    assert 1 == session.num_calls, session.calls
    first_call = session.calls[0]  # First call is the update
    assert first_call.method == 'PUT'
    assert first_call.path == '/projects/{}/modules/{}'.format(collection.project_id, predictor.uid)
    assert not first_call.json['archived']

def test_list_workflows(basic_design_workflow_data, basic_performance_workflow_data):
    # Given
    session = FakeSession()
    workflow_collection = WorkflowCollection(project_id=uuid.uuid4(), session=session)
    session.set_responses(
        {
            'entries': [basic_design_workflow_data],
            'next': ''
        },
        {
            'entries': [basic_performance_workflow_data],
            'next': ''
        },
    )

    # When
    workflows = list(workflow_collection.list(per_page=20))

    # Then
    expected_design_call = FakeCall(
        method='GET',
        path='/projects/{}/modules'.format(workflow_collection.project_id),
        params={'per_page': 20, 'module_type': 'DESIGN_WORKFLOW'}
    )
    expected_performance_call = FakeCall(
        method='GET',
        path='/projects/{}/modules'.format(workflow_collection.project_id),
        params={'per_page': 20, 'module_type': 'PERFORMANCE_WORKFLOW'}
    )
    assert 2 == session.num_calls
    assert expected_design_call == session.calls[0]
    assert expected_performance_call == session.calls[1]
    assert len(workflows) == 2
    assert isinstance(workflows[0], DesignWorkflow)
    assert isinstance(workflows[1], PerformanceWorkflow)

def test_from_predictor_responses():
    session = FakeSession()
    col = 'smiles'
    response_json = {
        'responses': [
            # shortened sample response
            {
                'category': 'Real',
                'descriptor_key': 'khs.sNH3 KierHallSmarts for {}'.format(col),
                'units': '',
                'lower_bound': 0,
                'upper_bound': 1000000000
            },
            {
                'category': 'Real',
                'descriptor_key': 'khs.dsN KierHallSmarts for {}'.format(col),
                'units': '',
                'lower_bound': 0,
                'upper_bound': 1000000000
            },
        ]
    }
    session.set_response(response_json)
    descriptors = DescriptorMethods(uuid4(), session)
    featurizer = MolecularStructureFeaturizer(
        name="Molecule featurizer",
        description="description",
        descriptor=MolecularStructureDescriptor(col),
        features=["all"],
        excludes=["standard"]
    )
    results = descriptors.from_predictor_responses(
        featurizer, [MolecularStructureDescriptor(col)]
    )
    assert results == [
        RealDescriptor(
            key=r['descriptor_key'],
            lower_bound=r['lower_bound'],
            upper_bound=r['upper_bound'],
        ) for r in response_json['responses']
    ]
    assert session.last_call.path == '/projects/{}/material-descriptors/predictor-responses' \
        .format(descriptors.project_id)
    assert session.last_call.method == 'POST'

def test_returned_predictor(valid_graph_predictor_data):
    """Check that auto_configure works on the happy path."""
    # Given
    session = FakeSession()
    # Set up a response that includes "instance" instead of "config"
    response = deepcopy(valid_graph_predictor_data)
    response["instance"] = response["config"]
    del response["config"]
    session.set_response(response)
    pc = PredictorCollection(uuid.uuid4(), session)

    # When
    result = pc.auto_configure(GemTableDataSource(uuid.uuid4(), 0), "PLAIN")

    # Then the response is parsed into a predictor
    assert result.name == valid_graph_predictor_data["display_name"]
    assert isinstance(result, GraphPredictor)
    # including nested predictors
    assert len(result.predictors) == 2
    assert isinstance(result.predictors[0], uuid.UUID)
    assert isinstance(result.predictors[1], DeprecatedExpressionPredictor)

def test_descriptors_from_data_source():
    session = FakeSession()
    col = 'smiles'
    response_json = {
        'descriptors': [
            # shortened sample response
            {
                'category': 'Real',
                'descriptor_key': 'khs.sNH3 KierHallSmarts for {}'.format(col),
                'units': '',
                'lower_bound': 0,
                'upper_bound': 1000000000
            },
            {
                'category': 'Real',
                'descriptor_key': 'khs.dsN KierHallSmarts for {}'.format(col),
                'units': '',
                'lower_bound': 0,
                'upper_bound': 1000000000
            },
        ]
    }
    session.set_response(response_json)
    descriptors = DescriptorMethods(uuid4(), session)
    data_source = GemTableDataSource('43357a66-3644-4959-8115-77b2630aca45', 123)
    results = descriptors.descriptors_from_data_source(data_source)
    assert results == [
        RealDescriptor(
            key=r['descriptor_key'],
            lower_bound=r['lower_bound'],
            upper_bound=r['upper_bound'],
            units=r['units']
        ) for r in response_json['descriptors']
    ]
    assert session.last_call.path == '/projects/{}/material-descriptors/from-data-source' \
        .format(descriptors.project_id)
    assert session.last_call.method == 'POST'

def test_check_update_some():
    """Test that the update check correctly builds a module."""
    # Given
    session = FakeSession()
    desc = RealDescriptor("spam", 0, 1, "kg")
    response = {
        "type": "AnalyticExpression",
        "name": "foo",
        "description": "bar",
        "expression": "2 * x",
        "output": RealDescriptor("spam", 0, 1, "kg").dump(),
        "aliases": {}
    }
    session.set_response({"updatable": True, "update": response})
    pc = PredictorCollection(uuid.uuid4(), session)
    predictor_id = uuid.uuid4()

    # When
    update_check = pc.check_for_update(predictor_id)

    # Then
    expected = ExpressionPredictor("foo", "bar", "2 * x", desc, {})
    assert update_check.dump() == expected.dump()
    assert update_check.uid == predictor_id

def test_design_space_limits():
    """Test that the validation logic is triggered before post/put-ing enumerated design spaces."""
    # Given
    session = FakeSession()
    collection = DesignSpaceCollection(uuid.uuid4(), session)

    too_big = EnumeratedDesignSpace(
        "foo",
        "bar",
        descriptors=[RealDescriptor("R-{}".format(i), 0, 1, "") for i in range(128)],
        data=[{"R-{}".format(i): random() for i in range(128)} for _ in range(2001)]
    )
    just_right = EnumeratedDesignSpace(
        "foo",
        "bar",
        descriptors=[RealDescriptor("R-{}".format(i), 0, 1, "") for i in range(128)],
        data=[{"R-{}".format(i): random() for i in range(128)} for _ in range(2000)]
    )

    # Create a mock post response by setting the status
    mock_response = just_right.dump()
    mock_response["status"] = "READY"
    session.responses.append(mock_response)

    # Then
    with pytest.raises(ValueError) as excinfo:
        collection.register(too_big)
    assert "only supports" in str(excinfo.value)

    # Test register
    collection.register(just_right)

    # Add back the response for the next test
    session.responses.append(mock_response)

    # Test update
    collection.update(just_right)

def session() -> FakeSession:
    return FakeSession()

def test_user_collection_creation():
    session = FakeSession()
    assert session == UserCollection(session).session

def __init__(self, num_properties):
    self.project_id = uuid4()
    self.session = FakeSession()
    self.num_properties = num_properties