def test_aum_finalize(tmp_path, aum_data):
    """Check the CSV files read back from ``save_dir`` after ``finalize()``.

    Every input batch from the ``aum_data`` fixture is fed to an
    ``AUMCalculator`` created with ``compressed=False``. After ``finalize()``
    the test loads ``aum_values.csv`` and ``full_aum_records.csv`` from the
    save directory and compares them against frames rebuilt from the
    fixture's expected records.

    Args:
        tmp_path: pytest temporary directory used as the save directory.
        aum_data: fixture unpacked as ``(inputs, outputs)``; each input is
            indexed by ``'logits'``, ``'targets'`` and ``'sample_ids'``, and
            each output is a mapping whose values are dataclass records with
            ``sample_id`` and ``aum`` attributes.
    """
    batches, expected_outputs = aum_data
    save_dir = tmp_path.as_posix()

    aum_calculator = AUMCalculator(save_dir=save_dir, compressed=False)
    for batch in batches:
        aum_calculator.update(
            batch['logits'], batch['targets'], batch['sample_ids'])
    aum_calculator.finalize()

    final_vals = pd.read_csv(os.path.join(save_dir, 'aum_values.csv'))
    detailed_vals = pd.read_csv(
        os.path.join(save_dir, 'full_aum_records.csv'))

    # Flatten the expected records of every update step, in step order.
    records = [
        record
        for step_output in expected_outputs
        for record in step_output.values()
    ]

    # First, verify the detailed per-measurement values.
    detailed_frame = pd.DataFrame([asdict(record) for record in records])
    detailed_frame = detailed_frame.sort_values(
        by=['sample_id', 'num_measurements'])
    expected_detailed_vals = detailed_frame.reset_index(drop=True)
    assert detailed_vals.equals(expected_detailed_vals)

    # Then verify the final values: a later record for the same sample_id
    # overwrites an earlier one, so only the last aum per sample remains.
    latest_aum = {record.sample_id: record.aum for record in records}
    final_rows = [
        {'sample_id': sample_id, 'aum': aum}
        for sample_id, aum in latest_aum.items()
    ]
    final_frame = pd.DataFrame(final_rows).sort_values(
        by='aum', ascending=False)
    expected_final_vals = final_frame.reset_index(drop=True)
    assert final_vals.equals(expected_final_vals)
def test_aum_update(aum_data):
    """Check the value returned by two consecutive ``update()`` calls.

    An ``AUMCalculator`` created with ``save_dir=None`` is updated with the
    first and then the second input batch from the ``aum_data`` fixture; the
    return value of each call must equal the corresponding expected output.

    Args:
        aum_data: fixture unpacked as ``(inputs, outputs)``; each input is
            indexed by ``'logits'``, ``'targets'`` and ``'sample_ids'``.
    """
    batches, expected_outputs = aum_data
    aum_calculator = AUMCalculator(save_dir=None)

    # Only the first two steps are exercised, in order, on the same
    # calculator instance.
    for step in (0, 1):
        batch = batches[step]
        actual = aum_calculator.update(
            batch['logits'], batch['targets'], batch['sample_ids'])
        assert actual == expected_outputs[step]