def test_responsibleai_adult_with_ill_defined_cohorts(
        self, create_rai_insights_object_classification):
    """Check that the dashboard rejects malformed ``cohort_list`` values.

    Two cases are exercised: a ``cohort_list`` that is not a list, and a
    list holding an entry that is not a ``Cohort``. Each is expected to
    raise ``UserConfigValidationException`` with a specific message.
    """
    rai_insights = create_rai_insights_object_classification

    age_filter = CohortFilter(
        method=CohortFilterMethods.METHOD_LESS,
        arg=[65],
        column='Age')
    hours_filter = CohortFilter(
        method=CohortFilterMethods.METHOD_GREATER,
        arg=[40],
        column='Hours per week')
    valid_cohort = Cohort(name='Cohort Continuous')
    for cohort_filter in (age_filter, hours_filter):
        valid_cohort.add_cohort_filter(cohort_filter)

    # Case 1: a dict is passed where a list of cohorts is required.
    not_a_list_msg = "cohort_list parameter should be a list."
    with pytest.raises(UserConfigValidationException,
                       match=not_a_list_msg):
        ResponsibleAIDashboard(rai_insights, cohort_list={})

    # Case 2: the list mixes a real cohort with a non-Cohort entry.
    bad_entry_msg = "All entries in cohort_list should be of type Cohort."
    with pytest.raises(UserConfigValidationException,
                       match=bad_entry_msg):
        ResponsibleAIDashboard(
            rai_insights, cohort_list=[valid_cohort, {}])
def test_cohort_serialization_single_value_method(self, method):
    """Serialize a cohort with one single-valued filter and inspect the JSON.

    ``method`` is the filter method under test. It is not defined in this
    function, so it is presumably injected by parametrization or a
    fixture declared elsewhere -- confirm against the enclosing class.
    """
    single_value_filter = CohortFilter(
        method=method, arg=[65], column='age')
    cohort = Cohort(name="Cohort New")
    cohort.add_cohort_filter(single_value_filter)

    serialized = cohort.to_json()

    # Name, method, argument list and column must all appear in the output.
    for fragment in ('Cohort New', method, '[65]', 'age'):
        assert fragment in serialized
def test_cohort_configuration_validations(self):
    """Check type validation of the cohort name and of added filters.

    A non-string name passed to ``Cohort`` and a non-``CohortFilter``
    object passed to ``add_cohort_filter`` are each expected to raise
    ``UserConfigValidationException``.
    """
    bad_name_msg = (
        "Got unexpected type <class 'int'> for cohort name. "
        "Expected string type.")
    with pytest.raises(UserConfigValidationException, match=bad_name_msg):
        Cohort(name=1)

    bad_filter_msg = (
        "Got unexpected type <class 'list'> for cohort filter. "
        "Expected CohortFilter type")
    with pytest.raises(UserConfigValidationException,
                       match=bad_filter_msg):
        # The cohort itself is valid; only the filter argument is wrong.
        Cohort(name="Cohort New").add_cohort_filter(cohort_filter=[])
def test_responsibleai_adult_duplicate_cohort_names(
        self, create_rai_insights_object_classification):
    """Check that the dashboard rejects cohorts sharing the same name.

    The same ``Cohort`` object is passed twice in ``cohort_list``, which
    is expected to raise ``UserConfigValidationException``.
    """
    rai_insights = create_rai_insights_object_classification

    age_filter = CohortFilter(
        method=CohortFilterMethods.METHOD_LESS,
        arg=[65],
        column='Age')
    hours_filter = CohortFilter(
        method=CohortFilterMethods.METHOD_GREATER,
        arg=[40],
        column='Hours per week')
    repeated_cohort = Cohort(name='Cohort Continuous')
    for cohort_filter in (age_filter, hours_filter):
        repeated_cohort.add_cohort_filter(cohort_filter)

    duplicate_msg = (
        "Found cohorts with duplicate names. "
        "All pre-defined cohorts need to have distinct names.")
    with pytest.raises(UserConfigValidationException,
                       match=duplicate_msg):
        ResponsibleAIDashboard(
            rai_insights,
            cohort_list=[repeated_cohort, repeated_cohort])
def test_cohort_list_serialization(self):
    """Serialize a list of two cohorts with ``json.dumps`` and inspect it.

    Both cohorts share one filter; ``cohort_filter_json_converter`` is
    supplied as the ``default`` hook for the non-JSON-native objects.
    """
    shared_filter = CohortFilter(
        method=CohortFilterMethods.METHOD_LESS,
        arg=[65],
        column='age')

    cohorts = []
    for cohort_name in ("Cohort New", "Cohort Old"):
        cohort = Cohort(name=cohort_name)
        cohort.add_cohort_filter(shared_filter)
        cohorts.append(cohort)

    serialized = json.dumps(cohorts, default=cohort_filter_json_converter)

    expected_fragments = (
        'Cohort Old',
        'Cohort New',
        CohortFilterMethods.METHOD_LESS,
        '[65]',
        'age',
    )
    for fragment in expected_fragments:
        assert fragment in serialized
def test_cohort_serialization_deserialization_in_range_method(self):
    """Round-trip a cohort with a range filter through JSON.

    The serialized text is checked for the expected fragments, then the
    cohort rebuilt by ``Cohort.from_json`` is compared with the original
    on name, number of filters and the first filter's method.
    """
    range_filter = CohortFilter(
        method=CohortFilterMethods.METHOD_RANGE,
        arg=[65.0, 70.0],
        column='age')
    original = Cohort(name="Cohort New")
    original.add_cohort_filter(range_filter)

    serialized = original.to_json()
    expected_fragments = (
        'Cohort New',
        CohortFilterMethods.METHOD_RANGE,
        '65.0',
        '70.0',
        'age',
    )
    for fragment in expected_fragments:
        assert fragment in serialized

    restored = Cohort.from_json(serialized)
    assert restored.name == original.name

    restored_filters = restored.cohort_filter_list
    original_filters = original.cohort_filter_list
    assert len(restored_filters) == len(original_filters)
    assert restored_filters[0].method == original_filters[0].method
def test_responsibleai_housing_with_pre_defined_cohorts(
        self, create_rai_insights_object_regression):
    """Build the dashboard for the regression fixture with four cohorts.

    The cohorts filter on two feature columns, on 'Index', on
    'Predicted Y' and on 'True Y'. The resulting widget is handed to
    ``self.validate_rai_dashboard_data``.
    """
    def build_cohort(name, *filters):
        # Filters are created by the caller first, then attached in order.
        cohort = Cohort(name=name)
        for cohort_filter in filters:
            cohort.add_cohort_filter(cohort_filter)
        return cohort

    rai_insights = create_rai_insights_object_regression
    less = CohortFilterMethods.METHOD_LESS
    greater = CohortFilterMethods.METHOD_GREATER

    continuous_cohort = build_cohort(
        'Cohort Continuous',
        CohortFilter(method=less, arg=[30.5], column='HouseAge'),
        CohortFilter(method=greater, arg=[3.0], column='AveRooms'))
    index_cohort = build_cohort(
        'Cohort Index',
        CohortFilter(method=less, arg=[20], column='Index'))
    predicted_y_cohort = build_cohort(
        'Cohort Predicted Y',
        CohortFilter(method=less, arg=[5.0], column='Predicted Y'))
    true_y_cohort = build_cohort(
        'Cohort True Y',
        CohortFilter(method=greater, arg=[1.0], column='True Y'))

    widget = ResponsibleAIDashboard(
        rai_insights,
        cohort_list=[
            continuous_cohort,
            index_cohort,
            predicted_y_cohort,
            true_y_cohort,
        ])
    self.validate_rai_dashboard_data(widget)
def test_responsibleai_adult_with_pre_defined_cohorts(
        self, create_rai_insights_object_classification):
    """Build the dashboard for the classification fixture with three cohorts.

    One cohort combines two numeric filters, one uses an "includes"
    filter on 'Marital Status', and one filters on 'Index'. The resulting
    widget is handed to ``self.validate_rai_dashboard_data``.
    """
    def build_cohort(name, *filters):
        # Filters are created by the caller first, then attached in order.
        cohort = Cohort(name=name)
        for cohort_filter in filters:
            cohort.add_cohort_filter(cohort_filter)
        return cohort

    rai_insights = create_rai_insights_object_classification

    continuous_cohort = build_cohort(
        'Cohort Continuous',
        CohortFilter(
            method=CohortFilterMethods.METHOD_LESS,
            arg=[65],
            column='Age'),
        CohortFilter(
            method=CohortFilterMethods.METHOD_GREATER,
            arg=[40],
            column='Hours per week'))
    categorical_cohort = build_cohort(
        'Cohort Categorical',
        CohortFilter(
            method=CohortFilterMethods.METHOD_INCLUDES,
            arg=[2, 6, 4],
            column='Marital Status'))
    index_cohort = build_cohort(
        'Cohort Index',
        CohortFilter(
            method=CohortFilterMethods.METHOD_LESS,
            arg=[20],
            column='Index'))

    widget = ResponsibleAIDashboard(
        rai_insights,
        cohort_list=[continuous_cohort, categorical_cohort, index_cohort])
    self.validate_rai_dashboard_data(widget)
def test_cohort_deserialization_error_conditions(self):
    """Check the errors raised by ``Cohort.from_json`` on incomplete input.

    Each case serializes a dict missing (or mistyping) one more required
    field than the previous one and expects
    ``UserConfigValidationException`` with the matching message.
    """
    # (payload to serialize, expected error message), in escalating order.
    failing_payloads = [
        (
            {},
            "No name field found for cohort deserialization",
        ),
        (
            {'name': 'Cohort New'},
            "No cohort_filter_list field found for "
            "cohort deserialization",
        ),
        (
            {'name': 'Cohort New', 'cohort_filter_list': {}},
            "Field cohort_filter_list not of type list "
            "for cohort deserialization",
        ),
        (
            {'name': 'Cohort New', 'cohort_filter_list': [{}]},
            "No method field found for cohort filter "
            "deserialization",
        ),
        (
            {
                'name': 'Cohort New',
                'cohort_filter_list': [{"method": "fake_method"}],
            },
            "No arg field found for cohort filter deserialization",
        ),
        (
            {
                'name': 'Cohort New',
                'cohort_filter_list': [
                    {"method": "fake_method", "arg": []}],
            },
            "No column field found for cohort filter "
            "deserialization",
        ),
    ]

    for payload, expected_message in failing_payloads:
        with pytest.raises(UserConfigValidationException,
                           match=expected_message):
            Cohort.from_json(json.dumps(payload))
def test_cohort_serialization_deserialization_include_exclude_methods(
        self, method):
    """Round-trip cohorts with multi-valued filters through JSON.

    The same checks run once with string values and once with integer
    values: the serialized text must contain the method, every value and
    the column name, and the cohort rebuilt by ``Cohort.from_json`` must
    compare equal to the original.

    ``method`` is not defined in this function, so it is presumably
    injected by parametrization or a fixture -- confirm against the
    enclosing class.
    """
    scenarios = (
        ("Cohort New Str", ('val1', 'val2', 'val3')),
        ("Cohort New Int", (1, 2, 3)),
    )

    for cohort_name, values in scenarios:
        value_filter = CohortFilter(
            method=method, arg=list(values), column='age')
        original = Cohort(name=cohort_name)
        original.add_cohort_filter(value_filter)

        serialized = original.to_json()
        assert method in serialized
        for value in values:
            # Integer values are looked up by their decimal text.
            assert str(value) in serialized
        assert 'age' in serialized

        restored = Cohort.from_json(serialized)
        assert original == restored
def test_cohort_validate_with_test_data(self):
    """Check argument validation in ``Cohort._validate_with_test_data``.

    Each case passes one invalid argument (test data, target column or
    categorical features) and expects ``UserConfigValidationException``
    with the matching message.
    """
    age_filter = CohortFilter(
        method=CohortFilterMethods.METHOD_LESS,
        arg=[65],
        column='age')
    cohort = Cohort(name="Cohort New")
    cohort.add_cohort_filter(age_filter)

    dataset = get_toy_binary_classification_dataset()

    # (expected error message, keyword arguments for the validation call)
    invalid_calls = (
        ("The test_data should be a pandas DataFrame",
         dict(test_data=[], target_column='target',
              categorical_features=[])),
        ("The target_column should be string.",
         dict(test_data=dataset, target_column=1,
              categorical_features=[])),
        ("The target_column fake_target "
         "was not found in test_data.",
         dict(test_data=dataset, target_column="fake_target",
              categorical_features=[])),
        ("Expected a list type for "
         "categorical columns.",
         dict(test_data=dataset, target_column="target",
              categorical_features={})),
        ("Feature 1 in categorical_features need to be of "
         "string type.",
         dict(test_data=dataset, target_column="target",
              categorical_features=[1, 2])),
        ("Found categorical feature hours-per-week which is not"
         " present in test data.",
         dict(test_data=dataset, target_column="target",
              categorical_features=["hours-per-week"])),
    )

    for expected_message, call_kwargs in invalid_calls:
        with pytest.raises(UserConfigValidationException,
                           match=expected_message):
            cohort._validate_with_test_data(**call_kwargs)