def test_overall_metric_results(true_sentiment_key: str, 
                                include_metadata: bool,
                                strict_accuracy_metrics: bool):
    """Tests `util.overall_metric_results` over multiple models, with and
    without metadata, and with/without the strict accuracy metric columns.

    :param true_sentiment_key: Key holding the gold sentiments, or None to
                               exercise the function's default-argument path.
    :param include_metadata: Whether the collection carries metadata (adds
                             the `CWR` column and a dataset name).
    :param strict_accuracy_metrics: Whether the STAC metric columns are added.
    """
    # Remember whether the caller relied on the default key BEFORE we
    # substitute the concrete default; otherwise the default-argument
    # branches below would be unreachable dead code.
    use_default_key = true_sentiment_key is None
    if use_default_key:
        true_sentiment_key = 'target_sentiments'
    model_1_collection = passable_example_multiple_preds(true_sentiment_key, 'model_1')
    model_1_collection.add(TargetText(text='a', text_id='200', spans=[Span(0,1)], targets=['a'],
                                      model_1=[['pos'], ['neg']], **{f'{true_sentiment_key}': ['pos']}))
    model_1_collection.add(TargetText(text='a', text_id='201', spans=[Span(0,1)], targets=['a'],
                                      model_1=[['pos'], ['neg']], **{f'{true_sentiment_key}': ['neg']}))
    model_1_collection.add(TargetText(text='a', text_id='202', spans=[Span(0,1)], targets=['a'],
                                      model_1=[['pos'], ['neg']], **{f'{true_sentiment_key}': ['neu']}))
    model_2_collection = passable_example_multiple_preds(true_sentiment_key, 'model_2')
    combined_collection = TargetTextCollection()
    
    standard_columns = ['Dataset', 'Macro F1', 'Accuracy', 'run number', 
                        'prediction key']
    if strict_accuracy_metrics:
        standard_columns = standard_columns + ['STAC', 'STAC 1', 'STAC Multi']
    if include_metadata:
        metadata = {'predicted_target_sentiment_key': {'model_1': {'CWR': True},
                                                       'model_2': {'CWR': False}}}
        combined_collection.name = 'name'
        combined_collection.metadata = metadata
        standard_columns.append('CWR')
    number_df_columns = len(standard_columns)

    # Merge both models' predictions into one collection; the extra
    # '200'-'202' examples get fixed model_2 predictions.
    for key, value in model_1_collection.items():
        if key in ['200', '201', '202']:
            combined_collection.add(value)
            combined_collection[key]['model_2'] = [['neg'], ['pos']]
            continue
        combined_collection.add(value)
        combined_collection[key]['model_2'] = model_2_collection[key]['model_2']
    if use_default_key:
        # Exercise the default value of `true_sentiment_key`.
        result_df = util.overall_metric_results(combined_collection, 
                                                ['model_1', 'model_2'], 
                                                strict_accuracy_metrics=strict_accuracy_metrics)
    else:
        result_df = util.overall_metric_results(combined_collection, 
                                                ['model_1', 'model_2'],
                                                true_sentiment_key, 
                                                strict_accuracy_metrics=strict_accuracy_metrics)
    assert (4, number_df_columns) == result_df.shape
    assert set(standard_columns) == set(result_df.columns)
    if include_metadata:
        assert ['name'] * 4 == result_df['Dataset'].tolist()
    else:
        assert [''] * 4 == result_df['Dataset'].tolist()
    # Test the case where only one model is used
    if use_default_key:
        result_df = util.overall_metric_results(combined_collection, 
                                                ['model_1'], 
                                                strict_accuracy_metrics=strict_accuracy_metrics)
    else:
        result_df = util.overall_metric_results(combined_collection, 
                                                ['model_1'],
                                                true_sentiment_key, 
                                                strict_accuracy_metrics=strict_accuracy_metrics)
    assert (2, number_df_columns) == result_df.shape
    # Test the case where the model names come from the metadata
    if include_metadata:
        result_df = util.overall_metric_results(combined_collection, 
                                                true_sentiment_key=true_sentiment_key, 
                                                strict_accuracy_metrics=strict_accuracy_metrics)
        assert (4, number_df_columns) == result_df.shape
    else:
        # Without metadata there is nowhere to look the model names up.
        with pytest.raises(KeyError):
            util.overall_metric_results(combined_collection, 
                                        true_sentiment_key=true_sentiment_key, 
                                        strict_accuracy_metrics=strict_accuracy_metrics)
def test_add_metadata_to_df():
    """Tests `util.add_metadata_to_df`: no-op without metadata, joining
    metadata columns onto the metric DataFrame, NaN-filling for models that
    lack a given metadata key, and the KeyError failure cases.
    """
    model_1_collection = passable_example_multiple_preds('true_sentiments', 'model_1')
    model_2_collection = passable_example_multiple_preds('true_sentiments', 'model_2')
    combined_collection = TargetTextCollection()
    for key, value in model_1_collection.items():
        combined_collection.add(value)
        combined_collection[key]['model_2'] = model_2_collection[key]['model_2']
    # get test metric_df
    combined_collection.metadata = None
    metric_df = util.metric_df(combined_collection, sentiment_metrics.accuracy, 
                               'true_sentiments', 
                               predicted_sentiment_keys=['model_1', 'model_2'], 
                               average=False, array_scores=True, metric_name='metric')
    # Test the case where the TargetTextCollection has no metadata, should 
    # just return the dataframe as is without change
    test_df = util.add_metadata_to_df(metric_df.copy(deep=True), combined_collection, 
                                      'non-existing-key')
    assert metric_df.equals(test_df)
    assert combined_collection.metadata is None
    # Testing the case where the metadata is not None but does not contain the 
    # `metadata_prediction_key` = `non-existing-key`
    combined_collection.name = 'combined collection'
    assert combined_collection.metadata is not None
    test_df = util.add_metadata_to_df(metric_df.copy(deep=True), combined_collection, 
                                      'non-existing-key')
    assert metric_df.equals(test_df)

    # Test the normal cases where we are adding metadata
    key_metadata_normal = {'model_1': {'embedding': True, 'value': '10'}, 
                           'model_2': {'embedding': False, 'value': '5'}}
    key_metadata_alt = {'model_1': {'embedding': True, 'value': '10'}, 
                        'model_2': {'embedding': False, 'value': '5'},
                        'model_3': {'embedding': 'low', 'value': '12'}}
    key_metadata_diff = {'model_1': {'embedding': True, 'value': '10', 'special': 12}, 
                         'model_2': {'embedding': False, 'value': '5', 'diff': 30.0}}
    key_metadataer = [key_metadata_normal, key_metadata_alt, key_metadata_diff]
    for key_metadata in key_metadataer:
        combined_collection.metadata['non-existing-key'] = key_metadata
        if 'special' in key_metadata['model_1']:
            # Restrict to a subset of the metadata keys so the per-model-only
            # keys ('special'/'diff') are not joined in this pass.
            test_df = util.add_metadata_to_df(metric_df.copy(deep=True), combined_collection, 
                                              'non-existing-key', 
                                              metadata_keys=['embedding', 'value'])
        else:
            test_df = util.add_metadata_to_df(metric_df.copy(deep=True), combined_collection, 
                                              'non-existing-key')
        assert not metric_df.equals(test_df)
        assert (4, 4) == test_df.shape
        assert [True, True] == test_df.loc[test_df['prediction key']=='model_1']['embedding'].to_list()
        assert [False, False] == test_df.loc[test_df['prediction key']=='model_2']['embedding'].to_list()
        assert ['10', '10'] == test_df.loc[test_df['prediction key']=='model_1']['value'].to_list()
        assert ['5', '5'] == test_df.loc[test_df['prediction key']=='model_2']['value'].to_list()
    # Test the case where some of the metadata exists for some of the models in
    # the collection but not all of them
    combined_collection.metadata['non-existing-key'] = key_metadata_diff
    test_df = util.add_metadata_to_df(metric_df.copy(deep=True), combined_collection, 
                                      'non-existing-key')
    assert not metric_df.equals(test_df)
    assert (4, 6) == test_df.shape
    assert [True, True] == test_df.loc[test_df['prediction key']=='model_1']['embedding'].to_list()
    assert [False, False] == test_df.loc[test_df['prediction key']=='model_2']['embedding'].to_list()
    assert ['10', '10'] == test_df.loc[test_df['prediction key']=='model_1']['value'].to_list()
    assert ['5', '5'] == test_df.loc[test_df['prediction key']=='model_2']['value'].to_list()
    assert [12, 12] == test_df.loc[test_df['prediction key']=='model_1']['special'].to_list()
    # 'special' only exists for model_1; model_2's rows must be NaN-filled.
    nan_values = test_df.loc[test_df['prediction key']=='model_2']['special'].to_list()
    assert 2 == len(nan_values)
    for test_value in nan_values:
        assert math.isnan(test_value)
    nan_values = test_df.loc[test_df['prediction key']=='model_1']['diff'].to_list()
    assert 2 == len(nan_values)
    for test_value in nan_values:
        assert math.isnan(test_value)
    assert [30.0, 30.0] == test_df.loc[test_df['prediction key']=='model_2']['diff'].to_list()

    # Test the KeyError cases
    # Prediction keys that exist in the metric df but not in the collection
    metric_copy_df = metric_df.copy(deep=True)
    alt_metric_df = pd.DataFrame({'prediction key': ['model_3', 'model_3'], 
                                  'metric': [0.4, 0.5]})
    # DataFrame.append was removed in pandas 2.0; pd.concat with default
    # arguments preserves append's behavior (indices are kept, not reset).
    metric_copy_df = pd.concat([metric_copy_df, alt_metric_df])
    assert (6, 2) == metric_copy_df.shape
    with pytest.raises(KeyError):
        util.add_metadata_to_df(metric_copy_df, combined_collection, 'non-existing-key')
    # Prediction keys exist in the dataframe and target texts but not in the 
    # metadata
    combined_collection.metadata['non-existing-key'] = {'model_1': {'embedding': True, 'value': '10'}}
    with pytest.raises(KeyError):
        util.add_metadata_to_df(metric_df.copy(deep=True), combined_collection, 'non-existing-key')