def eval(self, y_pred, y_true, metadata, prediction_fn=None):
    """
    Computes all evaluation metrics.
    Args:
        - y_pred (Tensor): Predictions from a model
        - y_true (Tensor): Ground-truth wealth-index values
        - metadata (Tensor): Metadata
        - prediction_fn (function): Only None is supported
    Output:
        - results (dict): Dictionary of evaluation metrics
        - results_str (str): String summarizing the evaluation metrics
    """
    assert prediction_fn is None, "PovertyMapDataset.eval() does not support prediction_fn"
    metrics = [MSE(), PearsonCorrelation()]

    all_results = {}
    all_results_str = ''
    for metric in metrics:
        results, results_str = self.standard_group_eval(
            metric,
            self._eval_grouper,
            y_pred, y_true, metadata)
        all_results.update(results)
        all_results_str += results_str
    return all_results, all_results_str
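# --- Hedged usage sketch (not part of the WILDS source) ----------------------
# Shows how eval() is typically fed: run a model over one split, concatenate
# predictions, labels, and metadata, then evaluate in a single call. `model`
# is an assumed regression network; get_subset() and get_eval_loader() are
# real WILDS helpers.
import torch
from wilds.common.data_loaders import get_eval_loader

def evaluate_split(dataset, model, split='val', batch_size=64):
    subset = dataset.get_subset(split)
    loader = get_eval_loader('standard', subset, batch_size=batch_size)
    y_pred, y_true, meta = [], [], []
    model.eval()
    for x, y, metadata in loader:
        with torch.no_grad():
            y_pred.append(model(x))
        y_true.append(y)
        meta.append(metadata)
    # Returns (results dict, summary string) with per-group MSE and Pearson r.
    return dataset.eval(torch.cat(y_pred), torch.cat(y_true), torch.cat(meta))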
import torch.nn as nn
import torch
import sys, os

# metrics
from wilds.common.metrics.loss import ElementwiseLoss, Loss, MultiTaskLoss
from wilds.common.metrics.all_metrics import (
    Accuracy, MultiTaskAccuracy, MSE,
    multiclass_logits_to_pred, binary_logits_to_pred)

losses = {
    'cross_entropy': ElementwiseLoss(loss_fn=nn.CrossEntropyLoss(reduction='none')),
    'lm_cross_entropy': MultiTaskLoss(loss_fn=nn.CrossEntropyLoss(reduction='none')),
    'mse': MSE(name='loss'),
    'multitask_bce': MultiTaskLoss(loss_fn=nn.BCEWithLogitsLoss(reduction='none')),
}

algo_log_metrics = {
    'accuracy': Accuracy(prediction_fn=multiclass_logits_to_pred),
    'mse': MSE(),
    'multitask_accuracy': MultiTaskAccuracy(prediction_fn=multiclass_logits_to_pred),
    'multitask_binary_accuracy': MultiTaskAccuracy(prediction_fn=binary_logits_to_pred),
    None: None,
}
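# --- Hedged sketch (assumed wiring, not the repo's exact training loop) ------
# A config names its loss and logging metric by string, and the trainer
# resolves both through the registries above; this assumes the WILDS
# Metric.compute(y_pred, y_true) interface, which returns a results dict.
example_loss_name = 'mse'               # e.g. the poverty regression task
loss_fn = losses[example_loss_name]
log_metric = algo_log_metrics['mse']

_y_pred = torch.randn(8, 1)
_y_true = torch.randn(8, 1)
print(loss_fn.compute(_y_pred, _y_true))     # aggregate training loss (dict)
print(log_metric.compute(_y_pred, _y_true))  # aggregate logged metric (dict)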
datasets = {
    'camelyon17': Camelyon17Dataset,
    'celebA': CelebADataset,
    'civilcomments': CivilCommentsDataset,
    'iwildcam': IWildCamDataset,
    'waterbirds': WaterbirdsDataset,
    'yelp': YelpDataset,
    'ogb-molpcba': OGBPCBADataset,
    'poverty': PovertyMapDataset,
    'fmow': FMoWDataset,
    'bdd100k': BDD100KDataset,
}

losses = {
    'cross_entropy': ElementwiseLoss(loss_fn=nn.CrossEntropyLoss(reduction='none')),
    'mse': MSE(name='loss'),
    'multitask_bce': MultiTaskLoss(loss_fn=nn.BCEWithLogitsLoss(reduction='none')),
}

algo_log_metrics = {
    'accuracy': Accuracy(),
    'mse': MSE(),
    'multitask_accuracy': MultiTaskAccuracy(),
    None: None,
}

# see initialize_*() functions for correspondence
transforms = ['bert', 'image_base', 'image_resize_and_center_crop', 'poverty_train']
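# --- Hedged sketch (assumed helper name; the released wilds package exposes a
# similar wilds.get_dataset() entry point) ------------------------------------
# The `datasets` registry lets a config select a dataset class by string:
def build_dataset(name, **kwargs):
    if name not in datasets:
        raise ValueError(f"Unsupported dataset: {name}. Supported: {list(datasets)}")
    return datasets[name](**kwargs)

# e.g.: build_dataset('poverty', root_dir='data', download=True, fold='A')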
def __init__(self, root_dir='data', download=False, split_scheme='official',
             no_nl=True, fold='A', oracle_training_set=False, use_ood_val=False):
    self._compressed_size = 18_630_656_000
    self._data_dir = self.initialize_data_dir(root_dir, download)

    self._split_dict = {'train': 0, 'id_val': 1, 'id_test': 2, 'val': 3, 'test': 4}
    self._split_names = {'train': 'Train', 'id_val': 'ID Val', 'id_test': 'ID Test',
                         'val': 'OOD Val', 'test': 'OOD Test'}

    if split_scheme == 'official':
        split_scheme = 'countries'
    self._split_scheme = split_scheme
    if self._split_scheme != 'countries':
        raise ValueError("Split scheme not recognized")

    self.oracle_training_set = oracle_training_set
    self.no_nl = no_nl
    if fold not in {'A', 'B', 'C', 'D', 'E'}:
        raise ValueError("Fold must be A, B, C, D, or E")

    self.root = Path(self._data_dir)
    self.metadata = pd.read_csv(self.root / 'dhs_metadata.csv')
    # country folds, split off OOD
    country_folds = SURVEY_NAMES[f'2009-17{fold}']

    self._split_array = -1 * np.ones(len(self.metadata))

    incountry_folds_split = np.arange(len(self.metadata))
    # take the test countries to be OOD
    idxs_id, idxs_ood_test = split_by_countries(
        incountry_folds_split, country_folds['test'], self.metadata)
    # also create a validation OOD set
    idxs_id, idxs_ood_val = split_by_countries(
        idxs_id, country_folds['val'], self.metadata)

    for split in ['test', 'val', 'id_test', 'id_val', 'train']:
        # keep OOD data for test and val; otherwise throw away OOD data
        if split == 'test':
            idxs = idxs_ood_test
        elif split == 'val':
            idxs = idxs_ood_val
        else:
            idxs = idxs_id
            num_eval = 2000
            # if oracle, do a 50-50 split between OOD and ID
            if split == 'train' and self.oracle_training_set:
                idxs = subsample_idxs(incountry_folds_split, num=len(idxs_id),
                                      seed=ord(fold))[num_eval:]
            elif split != 'train' and self.oracle_training_set:
                eval_idxs = subsample_idxs(incountry_folds_split, num=len(idxs_id),
                                           seed=ord(fold))[:num_eval]
            elif split == 'train':
                idxs = subsample_idxs(idxs, take_rest=True, num=num_eval, seed=ord(fold))
            else:
                eval_idxs = subsample_idxs(idxs, take_rest=False, num=num_eval, seed=ord(fold))

            if split != 'train':
                # split the held-out ID examples evenly into id_val and id_test
                if split == 'id_val':
                    idxs = eval_idxs[:num_eval // 2]
                else:
                    idxs = eval_idxs[num_eval // 2:]
        self._split_array[idxs] = self._split_dict[split]

    if not use_ood_val:
        # repurpose split index 1 as the (ID) val split
        self._split_dict = {'train': 0, 'val': 1, 'id_test': 2, 'ood_val': 3, 'test': 4}
        self._split_names = {'train': 'Train', 'val': 'ID Val', 'id_test': 'ID Test',
                             'ood_val': 'OOD Val', 'test': 'OOD Test'}

    self.imgs = np.load(self.root / 'landsat_poverty_imgs.npy', mmap_mode='r')
    self.imgs = self.imgs.transpose((0, 3, 1, 2))  # NHWC -> NCHW
    self._y_array = torch.from_numpy(
        np.asarray(self.metadata['wealthpooled'])[:, np.newaxis]).float()
    self._y_size = 1

    # add country group field
    country_to_idx = {country: i for i, country in enumerate(DHS_COUNTRIES)}
    self.metadata['country'] = [
        country_to_idx[country] for country in self.metadata['country'].tolist()]
    self._metadata_map = {'country': DHS_COUNTRIES}
    self._metadata_array = torch.from_numpy(
        self.metadata[['urban', 'wealthpooled', 'country']].astype(float).to_numpy())
    # rename wealthpooled to y
    self._metadata_fields = ['urban', 'y', 'country']

    self._eval_grouper = CombinatorialGrouper(dataset=self, groupby_fields=['urban'])
    self._metrics = [MSE(), PearsonCorrelation()]
    self.cache_counter = 0

    super().__init__(root_dir, download, split_scheme)
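# --- Hedged end-to-end sketch (illustrative values; assumes the data has
# already been downloaded to root_dir) -----------------------------------------
# Instantiate the dataset, grab the OOD-country test split, and iterate once.
from wilds.common.data_loaders import get_eval_loader

dataset = PovertyMapDataset(root_dir='data', download=False,
                            fold='A', use_ood_val=True)
test_data = dataset.get_subset('test')  # countries held out by fold 'A'
test_loader = get_eval_loader('standard', test_data, batch_size=64)
for x, y, metadata in test_loader:
    # x: (B, 8, 224, 224) multispectral Landsat patches (NCHW after the
    # transpose above); y: (B, 1) 'wealthpooled' regression targets
    break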