Пример #1
0
    def eval(self, y_pred, y_true, metadata, prediction_fn=None):
        """
        Run every evaluation metric over the predictions and merge the output.

        Each metric (MSE, then PearsonCorrelation) is passed to
        self.standard_group_eval() together with self._eval_grouper; the
        per-metric result dictionaries are merged into one dictionary and the
        per-metric summary strings are concatenated in the same order.

        Args:
            - y_pred (Tensor): Predictions from a model
            - y_true (Tensor): Ground-truth values
            - metadata (Tensor): Metadata
            - prediction_fn (function): Must be None; any other value trips
              the assertion below
        Output:
            - results (dictionary): Merged dictionary of evaluation metrics
            - results_str (str): Concatenated summary of the evaluation metrics
        """
        assert prediction_fn is None, "PovertyMapDataset.eval() does not support prediction_fn"

        merged_results = {}
        summary_parts = []
        for metric in (MSE(), PearsonCorrelation()):
            metric_results, metric_summary = self.standard_group_eval(
                metric, self._eval_grouper, y_pred, y_true, metadata)
            # Later metrics overwrite earlier ones on a key clash.
            merged_results.update(metric_results)
            summary_parts.append(metric_summary)
        return merged_results, ''.join(summary_parts)
Пример #2
0
import torch.nn as nn
import torch
import sys, os

# metrics
from wilds.common.metrics.loss import ElementwiseLoss, Loss, MultiTaskLoss
from wilds.common.metrics.all_metrics import Accuracy, MultiTaskAccuracy, MSE, multiclass_logits_to_pred, binary_logits_to_pred

# Loss objects keyed by name. Every torch criterion is constructed with
# reduction='none', i.e. the criterion itself applies no reduction.
losses = dict(
    cross_entropy=ElementwiseLoss(
        loss_fn=nn.CrossEntropyLoss(reduction='none')),
    lm_cross_entropy=MultiTaskLoss(
        loss_fn=nn.CrossEntropyLoss(reduction='none')),
    mse=MSE(name='loss'),
    multitask_bce=MultiTaskLoss(
        loss_fn=nn.BCEWithLogitsLoss(reduction='none')),
)

# Metric objects keyed by name. The accuracy metrics are given the function
# that turns logits into predictions (multiclass or binary).
algo_log_metrics = dict(
    accuracy=Accuracy(prediction_fn=multiclass_logits_to_pred),
    mse=MSE(),
    multitask_accuracy=MultiTaskAccuracy(
        prediction_fn=multiclass_logits_to_pred),
    multitask_binary_accuracy=MultiTaskAccuracy(
        prediction_fn=binary_logits_to_pred),
)
# A None key cannot be written as a keyword argument, so it is added
# afterwards; it maps to None (presumably "no metric selected" -- confirm
# against callers).
algo_log_metrics[None] = None
Пример #3
0
    'camelyon17': Camelyon17Dataset,
    'celebA': CelebADataset,
    'civilcomments': CivilCommentsDataset,
    'iwildcam': IWildCamDataset,
    'waterbirds': WaterbirdsDataset,
    'yelp': YelpDataset,
    'ogb-molpcba': OGBPCBADataset,
    'poverty': PovertyMapDataset,
    'fmow': FMoWDataset,
    'bdd100k': BDD100KDataset,
}

# Loss objects keyed by name. Every torch criterion is constructed with
# reduction='none', i.e. the criterion itself applies no reduction.
losses = dict(
    cross_entropy=ElementwiseLoss(
        loss_fn=nn.CrossEntropyLoss(reduction='none')),
    mse=MSE(name='loss'),
    multitask_bce=MultiTaskLoss(
        loss_fn=nn.BCEWithLogitsLoss(reduction='none')),
)

# Metric objects keyed by name, each built with its default arguments.
algo_log_metrics = dict(
    accuracy=Accuracy(),
    mse=MSE(),
    multitask_accuracy=MultiTaskAccuracy(),
)
# A None key cannot be written as a keyword argument, so it is added
# afterwards; it maps to None (presumably "no metric selected" -- confirm
# against callers).
algo_log_metrics[None] = None

# Recognised transform names; the initialize_*() functions show which
# function each name corresponds to.
transforms = [
    'bert',
    'image_base',
    'image_resize_and_center_crop',
    'poverty_train',
]
Пример #4
0
    def __init__(self,
                 root_dir='data',
                 download=False,
                 split_scheme='official',
                 no_nl=True,
                 fold='A',
                 oracle_training_set=False,
                 use_ood_val=False):
        """
        Build the dataset: read the metadata CSV and the image array, assign
        every metadata row to a split, and set up the label tensor, metadata
        fields, evaluation grouper and metrics.

        Args:
            - root_dir (str): Passed to initialize_data_dir() and to the
              parent constructor.
            - download (bool): Passed to initialize_data_dir() and to the
              parent constructor.
            - split_scheme (str): 'official' is treated as an alias for
              'countries'; no other scheme is accepted.
            - no_nl (bool): Stored as self.no_nl; not otherwise used in this
              method.
            - fold (str): One of 'A'-'E'. Selects the SURVEY_NAMES entry
              '2009-17<fold>' and seeds the subsampling through ord(fold).
            - oracle_training_set (bool): If True, the train and ID eval
              indices are subsampled from all rows instead of only from the
              in-distribution rows.
            - use_ood_val (bool): If False, the split names are remapped so
              that 'val' names the ID validation split (code 1) and the OOD
              validation split (code 3) is exposed as 'ood_val'.
        Raises:
            - ValueError: If split_scheme is neither 'official' nor
              'countries', or if fold is not one of A, B, C, D, E.
        """

        # NOTE(review): presumably the size of the compressed download in
        # bytes -- confirm against the parent class.
        self._compressed_size = 18_630_656_000
        self._data_dir = self.initialize_data_dir(root_dir, download)

        # Integer code written into self._split_array for each split name.
        self._split_dict = {
            'train': 0,
            'id_val': 1,
            'id_test': 2,
            'val': 3,
            'test': 4
        }
        # Display name for each split.
        self._split_names = {
            'train': 'Train',
            'id_val': 'ID Val',
            'id_test': 'ID Test',
            'val': 'OOD Val',
            'test': 'OOD Test'
        }

        # 'official' is an alias for the only supported scheme, 'countries'.
        if split_scheme == 'official':
            split_scheme = 'countries'
        self._split_scheme = split_scheme
        if self._split_scheme != 'countries':
            raise ValueError("Split scheme not recognized")

        self.oracle_training_set = oracle_training_set

        self.no_nl = no_nl
        if fold not in {'A', 'B', 'C', 'D', 'E'}:
            raise ValueError("Fold must be A, B, C, D, or E")

        self.root = Path(self._data_dir)
        self.metadata = pd.read_csv(self.root / 'dhs_metadata.csv')
        # country folds, split off OOD
        country_folds = SURVEY_NAMES[f'2009-17{fold}']

        # Every row starts at -1; rows never assigned to a split below keep
        # that value. np.ones() makes this a float array.
        self._split_array = -1 * np.ones(len(self.metadata))

        # Positional indices of all metadata rows.
        incountry_folds_split = np.arange(len(self.metadata))
        # take the test countries to be ood
        # NOTE(review): split_by_countries is defined elsewhere; from its use
        # here it returns (remaining indices, indices in the given countries)
        # -- confirm.
        idxs_id, idxs_ood_test = split_by_countries(incountry_folds_split,
                                                    country_folds['test'],
                                                    self.metadata)
        # also create a validation OOD set
        idxs_id, idxs_ood_val = split_by_countries(idxs_id,
                                                   country_folds['val'],
                                                   self.metadata)
        # Splits are written in this order; where index sets overlap (possible
        # in the oracle case, which samples from all rows) a later split
        # overwrites an earlier one.
        for split in ['test', 'val', 'id_test', 'id_val', 'train']:
            # keep ood for test, otherwise throw away ood data
            if split == 'test':
                idxs = idxs_ood_test
            elif split == 'val':
                idxs = idxs_ood_val
            else:
                idxs = idxs_id
                # Total number of ID eval examples, halved between id_val and
                # id_test further down.
                num_eval = 2000
                # if oracle, do 50-50 split between OOD and ID
                # NOTE(review): subsample_idxs is defined elsewhere. The same
                # seed is used in every call, presumably so that the train and
                # eval selections are complementary -- confirm.
                if split == 'train' and self.oracle_training_set:
                    # Oracle train: sample len(idxs_id) rows from all rows and
                    # drop the first num_eval of them.
                    idxs = subsample_idxs(incountry_folds_split,
                                          num=len(idxs_id),
                                          seed=ord(fold))[num_eval:]
                elif split != 'train' and self.oracle_training_set:
                    # Oracle eval: the first num_eval rows of that same sample.
                    eval_idxs = subsample_idxs(incountry_folds_split,
                                               num=len(idxs_id),
                                               seed=ord(fold))[:num_eval]
                elif split == 'train':
                    # Non-oracle train: called with take_rest=True on the ID
                    # rows.
                    idxs = subsample_idxs(idxs,
                                          take_rest=True,
                                          num=num_eval,
                                          seed=ord(fold))
                else:
                    # Non-oracle eval: called with take_rest=False on the ID
                    # rows.
                    eval_idxs = subsample_idxs(idxs,
                                               take_rest=False,
                                               num=num_eval,
                                               seed=ord(fold))

                # id_val takes the first half of eval_idxs, id_test the rest.
                if split != 'train':
                    if split == 'id_val':
                        idxs = eval_idxs[:num_eval // 2]
                    else:
                        idxs = eval_idxs[num_eval // 2:]
            self._split_array[idxs] = self._split_dict[split]

        # The codes already stored in self._split_array are unchanged; only
        # the names change, so 'val' now names code 1 (ID Val) and code 3
        # (OOD Val) becomes 'ood_val'.
        if not use_ood_val:
            self._split_dict = {
                'train': 0,
                'val': 1,
                'id_test': 2,
                'ood_val': 3,
                'test': 4
            }
            self._split_names = {
                'train': 'Train',
                'val': 'ID Val',
                'id_test': 'ID Test',
                'ood_val': 'OOD Val',
                'test': 'OOD Test'
            }

        # Memory-map the image array read-only rather than loading it whole.
        self.imgs = np.load(self.root / 'landsat_poverty_imgs.npy',
                            mmap_mode='r')

        # Move axis 3 to position 1.
        # NOTE(review): presumably channels-last -> channels-first -- confirm
        # the stored layout.
        self.imgs = self.imgs.transpose((0, 3, 1, 2))
        # Labels: the 'wealthpooled' column as a float tensor with a trailing
        # axis of size 1.
        self._y_array = torch.from_numpy(
            np.asarray(self.metadata['wealthpooled'])[:, np.newaxis]).float()
        self._y_size = 1

        # add country group field
        country_to_idx = {
            country: i
            for i, country in enumerate(DHS_COUNTRIES)
        }
        # Replace each country value with its position in DHS_COUNTRIES; a
        # country missing from DHS_COUNTRIES raises KeyError here.
        self.metadata['country'] = [
            country_to_idx[country]
            for country in self.metadata['country'].tolist()
        ]
        self._metadata_map = {'country': DHS_COUNTRIES}
        # Metadata columns cast to float, in the order given by
        # self._metadata_fields below.
        self._metadata_array = torch.from_numpy(
            self.metadata[['urban', 'wealthpooled',
                           'country']].astype(float).to_numpy())
        # rename wealthpooled to y
        self._metadata_fields = ['urban', 'y', 'country']

        # Evaluation results are grouped by the 'urban' metadata field.
        self._eval_grouper = CombinatorialGrouper(dataset=self,
                                                  groupby_fields=['urban'])

        self._metrics = [MSE(), PearsonCorrelation()]
        self.cache_counter = 0

        # The parent constructor runs last, after the attributes above are set.
        super().__init__(root_dir, download, split_scheme)