예제 #1
0
 def initialize_eval_grouper(self):
     """Attach the evaluation grouper that matches ``self.split_scheme``.

     The 'user' scheme groups evaluation by the 'user' field; the 'time'
     and 'time_baseline' schemes group it by the 'year' field. The grouper
     is stored on ``self._eval_grouper`` and nothing is returned.

     Raises:
         ValueError: if the split scheme is none of the values above.
     """
     scheme = self.split_scheme
     if scheme == 'user':
         fields = ['user']
     elif scheme in ('time', 'time_baseline'):
         fields = ['year']
     else:
         raise ValueError(f'Split scheme {scheme} not recognized')
     self._eval_grouper = CombinatorialGrouper(dataset=self,
                                               groupby_fields=fields)
예제 #2
0
 def initialize_eval_grouper(self):
     """Build and return the evaluation grouper for ``self.split_scheme``.

     Schemes whose name contains 'black' or 'race' are grouped by
     'suspect race'. Schemes whose name contains 'bronx', or that are exactly
     'all_borough', are grouped by 'borough'. The grouper is returned, not
     stored on ``self``.

     Raises:
         ValueError: if the split scheme matches none of the rules above.
     """
     scheme = self.split_scheme
     if any(token in scheme for token in ('black', 'race')):
         return CombinatorialGrouper(dataset=self,
                                     groupby_fields=['suspect race'])
     if 'bronx' in scheme or scheme == 'all_borough':
         return CombinatorialGrouper(dataset=self,
                                     groupby_fields=['borough'])
     raise ValueError(f'Split scheme {scheme} not recognized')
예제 #3
0
 def initialize_eval_grouper(self):
     """Attach the evaluation grouper that matches ``self.split_scheme``.

     Rules, checked in order:
       * 'user'                                   -> group by 'user'
       * '*generalization' or 'category_subpopulation' -> group by 'category'
       * 'time' or 'time_baseline'                -> group by 'year'
       * any other '*_baseline'                   -> group by 'user'
     The grouper is stored on ``self._eval_grouper``; nothing is returned.

     Raises:
         ValueError: if the split scheme matches none of the rules above.
     """
     scheme = self.split_scheme
     if scheme == 'user':
         fields = ['user']
     elif scheme == 'category_subpopulation' or scheme.endswith(
             'generalization'):
         fields = ['category']
     elif scheme in ('time', 'time_baseline'):
         fields = ['year']
     elif scheme.endswith('_baseline'):
         # Per-user baselines; 'time_baseline' was already handled above.
         fields = ['user']
     else:
         raise ValueError(f'Split scheme {scheme} not recognized')
     self._eval_grouper = CombinatorialGrouper(dataset=self,
                                               groupby_fields=fields)
    def __init__(self,
                 root_dir='data',
                 download=False,
                 split_scheme='official'):
        """Load the Waterbirds metadata and set up labels, metadata and splits.

        Args:
            root_dir: directory handed to ``initialize_data_dir`` to locate
                (or download) the dataset.
            download: forwarded to ``initialize_data_dir``.
            split_scheme: only 'official' is supported; the per-row split
                comes from the 'split' column of ``metadata.csv``.

        Raises:
            ValueError: if ``split_scheme`` is not 'official', or if the data
                directory does not exist after initialization.
        """
        # Validate arguments before touching the filesystem so an unsupported
        # scheme does not first trigger a download or a metadata read.
        if split_scheme != 'official':
            raise ValueError(f'Split scheme {split_scheme} not recognized')

        self._dataset_name = 'waterbirds'
        self._version = '1.0'
        self._download_url = 'https://worksheets.codalab.org/rest/bundles/0x505056d5cdea4e4eaa0e242cbfe2daa4/contents/blob/'
        self._data_dir = self.initialize_data_dir(root_dir, download)

        if not os.path.exists(self.data_dir):
            raise ValueError(
                f'{self.data_dir} does not exist yet. Please generate the dataset first.'
            )

        # Read in metadata
        # Note: metadata_df is one-indexed.
        metadata_df = pd.read_csv(os.path.join(self.data_dir, 'metadata.csv'))

        # Get the y values
        self._y_array = torch.LongTensor(metadata_df['y'].values)
        self._y_size = 1
        self._n_classes = 2

        # Metadata columns: the 'place' column (exposed as 'background'),
        # followed by the label.
        self._metadata_array = torch.stack(
            (torch.LongTensor(metadata_df['place'].values), self._y_array),
            dim=1)
        self._metadata_fields = ['background', 'y']
        self._metadata_map = {
            'background': [' land', 'water'],  # Padding for str formatting
            'y': [' landbird', 'waterbird']
        }

        # Extract filenames
        self._input_array = metadata_df['img_filename'].values
        self._original_resolution = (224, 224)

        # Extract splits (scheme already validated above)
        self._split_scheme = split_scheme
        self._split_array = metadata_df['split'].values

        # Evaluate over every (background, y) combination.
        self._eval_grouper = CombinatorialGrouper(
            dataset=self, groupby_fields=(['background', 'y']))
        self._metric = Accuracy()

        super().__init__(root_dir, download, split_scheme)
예제 #5
0
    def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'):
        """Load CivilComments metadata and build identity-based eval groupers.

        Args:
            version: stored on ``self._version`` before ``initialize_data_dir``.
            root_dir: directory in which the dataset is located/downloaded.
            download: forwarded to ``initialize_data_dir``.
            split_scheme: only 'official' is supported.

        Raises:
            ValueError: if ``split_scheme`` is not 'official'.
        """
        # Reject unsupported split schemes before any download or CSV read.
        if split_scheme != 'official':
            raise ValueError(f'Split scheme {split_scheme} not recognized')

        self._version = version
        self._data_dir = self.initialize_data_dir(root_dir, download)

        # Read in metadata
        self._metadata_df = pd.read_csv(
            os.path.join(self._data_dir, 'all_data_with_identities.csv'),
            index_col=0)

        # Binary label: 1 when the 'toxicity' column is >= 0.5.
        self._y_array = torch.LongTensor(self._metadata_df['toxicity'].values >= 0.5)
        self._y_size = 1
        self._n_classes = 2

        # Extract text
        self._text_array = list(self._metadata_df['comment_text'])

        # Extract splits (scheme already validated above)
        self._split_scheme = split_scheme
        # metadata_df contains split names in strings, so convert them to ints
        for split in self.split_dict:
            split_indices = self._metadata_df['split'] == split
            self._metadata_df.loc[split_indices, 'split'] = self.split_dict[split]
        self._split_array = self._metadata_df['split'].values

        # Extract metadata
        self._identity_vars = [
            'male',
            'female',
            'LGBTQ',
            'christian',
            'muslim',
            'other_religions',
            'black',
            'white'
        ]
        self._auxiliary_vars = [
            'identity_any',
            'severe_toxicity',
            'obscene',
            'threat',
            'insult',
            'identity_attack',
            'sexual_explicit'
        ]

        # Metadata columns: identity flags, auxiliary flags (both binarized at
        # 0.5), then the label — in the same order as _metadata_fields.
        self._metadata_array = torch.cat(
            (
                torch.LongTensor((self._metadata_df.loc[:, self._identity_vars] >= 0.5).values),
                torch.LongTensor((self._metadata_df.loc[:, self._auxiliary_vars] >= 0.5).values),
                self._y_array.reshape((-1, 1))
            ),
            dim=1
        )
        self._metadata_fields = self._identity_vars + self._auxiliary_vars + ['y']

        # One grouper per identity attribute, crossed with the label.
        self._eval_groupers = [
            CombinatorialGrouper(
                dataset=self,
                groupby_fields=[identity_var, 'y'])
            for identity_var in self._identity_vars]

        super().__init__(root_dir, download, split_scheme)
예제 #6
0
    def __init__(self,
                 root_dir='data',
                 download=False,
                 split_scheme='official'):
        """Load CelebA attributes with 'Blond_Hair' as target and 'Male' as confounder.

        Args:
            root_dir: directory in which the dataset is located/downloaded.
            download: forwarded to ``initialize_data_dir``.
            split_scheme: only 'official' is supported; per-row splits come
                from the 'partition' column of ``list_eval_partition.csv``.

        Raises:
            ValueError: if ``split_scheme`` is not 'official'.
        """
        # Reject unsupported split schemes before any download or CSV read.
        if split_scheme != 'official':
            raise ValueError(f'Split scheme {split_scheme} not recognized')

        self._dataset_name = 'celebA'
        self._version = '1.0'
        self._download_url = ''
        self._data_dir = self.initialize_data_dir(root_dir, download)
        target_name = 'Blond_Hair'
        confounder_names = ['Male']

        # Read in attributes
        attrs_df = pd.read_csv(
            os.path.join(self.data_dir, 'list_attr_celeba.csv'))

        # Split out filenames and attribute names
        # Note: idx and filenames are off by one.
        self._input_array = attrs_df['image_id'].values
        self._original_resolution = (178, 218)
        attrs_df = attrs_df.drop(labels='image_id', axis='columns')
        attr_names = attrs_df.columns.copy()

        # Cast attributes to a numpy array and map them to 0/1
        # (originally, they're -1 and 1). A separate name keeps the
        # DataFrame and the raw array distinct.
        attrs = attrs_df.values
        attrs[attrs == -1] = 0

        # Get the y values
        target_idx = attr_names.get_loc(target_name)
        self._y_array = torch.LongTensor(attrs[:, target_idx])
        self._y_size = 1
        self._n_classes = 2

        # Get metadata: confounder columns followed by the label.
        confounder_idx = [attr_names.get_loc(a) for a in confounder_names]
        confounders = attrs[:, confounder_idx]

        self._metadata_array = torch.cat(
            (torch.LongTensor(confounders), self._y_array.reshape((-1, 1))),
            dim=1)
        confounder_names = [s.lower() for s in confounder_names]
        self._metadata_fields = confounder_names + ['y']
        self._metadata_map = {
            'y': ['not blond', '    blond']  # Padding for str formatting
        }

        self._eval_grouper = CombinatorialGrouper(
            dataset=self, groupby_fields=(confounder_names + ['y']))
        self._metric = Accuracy()

        # Extract splits (scheme already validated above)
        self._split_scheme = split_scheme
        split_df = pd.read_csv(
            os.path.join(self.data_dir, 'list_eval_partition.csv'))
        self._split_array = split_df['partition'].values

        super().__init__(root_dir, download, split_scheme)
예제 #7
0
    def __init__(self,
                 version=None,
                 root_dir='data',
                 download=False,
                 split_scheme='official'):
        """Load Camelyon patch metadata and assign hospital-based splits.

        Args:
            version: stored on ``self._version`` before ``initialize_data_dir``.
            root_dir: directory in which the dataset is located/downloaded.
            download: forwarded to ``initialize_data_dir``.
            split_scheme: 'official', or 'in-dist' which additionally moves
                slide 23 into the training split.

        Raises:
            ValueError: if ``split_scheme`` is not 'official' or 'in-dist'.
        """
        # Reject unsupported split schemes before any download or CSV read.
        if split_scheme not in ('official', 'in-dist'):
            raise ValueError(f'Split scheme {split_scheme} not recognized')

        self._version = version
        self._data_dir = self.initialize_data_dir(root_dir, download)
        self._original_resolution = (96, 96)

        # Read in metadata
        self._metadata_df = pd.read_csv(os.path.join(self._data_dir,
                                                     'metadata.csv'),
                                        index_col=0,
                                        dtype={'patient': 'str'})

        # Get the y values
        self._y_array = torch.LongTensor(self._metadata_df['tumor'].values)
        self._y_size = 1
        self._n_classes = 2

        # Get filenames
        self._input_array = [
            f'patches/patient_{patient}_node_{node}/patch_patient_{patient}_node_{node}_x_{x}_y_{y}.png'
            for patient, node, x, y in self._metadata_df.
            loc[:, ['patient', 'node', 'x_coord', 'y_coord']].itertuples(
                index=False, name=None)
        ]

        # Extract splits
        # Note that the hospital numbering here is different from what's in the paper,
        # where to avoid confusing readers we used a 1-indexed scheme and just labeled the test hospital as 5.
        # Here, the numbers are 0-indexed.
        test_center = 2
        val_center = 1

        self._split_dict = {'train': 0, 'id_val': 1, 'test': 2, 'val': 3}
        self._split_names = {
            'train': 'Train',
            'id_val': 'Validation (ID)',
            'test': 'Test',
            'val': 'Validation (OOD)',
        }
        # Explicit 64-bit ints: 'long' is only 32 bits on some platforms.
        centers = self._metadata_df['center'].values.astype(np.int64)
        val_center_mask = (self._metadata_df['center'] == val_center)
        test_center_mask = (self._metadata_df['center'] == test_center)
        # Rows from the held-out centers override whatever 'split' the CSV had.
        self._metadata_df.loc[val_center_mask,
                              'split'] = self.split_dict['val']
        self._metadata_df.loc[test_center_mask,
                              'split'] = self.split_dict['test']

        self._split_scheme = split_scheme
        if self._split_scheme == 'in-dist':
            # For the in-distribution oracle,
            # we move slide 23 (corresponding to patient 042, node 3 in the original dataset)
            # from the test set to the training set
            slide_mask = (self._metadata_df['slide'] == 23)
            self._metadata_df.loc[slide_mask,
                                  'split'] = self.split_dict['train']
        self._split_array = self._metadata_df['split'].values

        # Metadata columns: hospital (center), slide, label.
        self._metadata_array = torch.stack(
            (torch.LongTensor(centers),
             torch.LongTensor(
                 self._metadata_df['slide'].values), self._y_array),
            dim=1)
        self._metadata_fields = ['hospital', 'slide', 'y']

        self._eval_grouper = CombinatorialGrouper(dataset=self,
                                                  groupby_fields=['slide'])

        super().__init__(root_dir, download, split_scheme)
예제 #8
0
    def __init__(self,
                 version=None,
                 root_dir='data',
                 download=False,
                 split_scheme='official',
                 oracle_training_set=False,
                 seed=111,
                 use_ood_val=False):
        """Load FMoW metadata and build time-based ID/OOD splits.

        Args:
            version: stored on ``self._version`` before ``initialize_data_dir``.
            root_dir: directory in which the dataset is located/downloaded.
            download: forwarded to ``initialize_data_dir``.
            split_scheme: 'official' (alias for 'time_after_2016') or a
                'time_after_<year>' scheme. Images timestamped from <year>
                onward form the OOD test pool; the three preceding years form
                the OOD validation pool.
            oracle_training_set: if True, the training split is replaced by
                half original train indices and half indices drawn from OOD
                rows outside the CSV 'test' split.
            seed: base seed for the oracle subsampling.
            use_ood_val: if False, the split dict is relabelled so that 'val'
                names the ID validation split and the OOD validation split is
                exposed as 'ood_val'.

        Raises:
            ValueError: if ``split_scheme`` is not a well-formed
                'time_after_<year>' scheme.
        """
        if split_scheme == 'official':
            split_scheme = 'time_after_2016'
        # Validate and parse the scheme before any download or CSV read.
        if not split_scheme.startswith('time_after'):
            raise ValueError(
                f"Not supported: self._split_scheme = {split_scheme}")
        try:
            year = int(split_scheme.split('_')[2])
        except (IndexError, ValueError):
            raise ValueError(
                f"Malformed split scheme {split_scheme!r}: expected "
                "'time_after_<year>'") from None

        self._version = version
        self._data_dir = self.initialize_data_dir(root_dir, download)

        self._split_dict = {
            'train': 0,
            'id_val': 1,
            'id_test': 2,
            'val': 3,
            'test': 4
        }
        self._split_names = {
            'train': 'Train',
            'id_val': 'ID Val',
            'id_test': 'ID Test',
            'val': 'OOD Val',
            'test': 'OOD Test'
        }
        self._split_scheme = split_scheme
        self.oracle_training_set = oracle_training_set

        self.root = Path(self._data_dir)
        self.seed = int(seed)
        self._original_resolution = (224, 224)

        self.category_to_idx = {cat: i for i, cat in enumerate(categories)}

        self.metadata = pd.read_csv(self.root / 'rgb_metadata.csv')
        country_codes_df = pd.read_csv(self.root / 'country_code_mapping.csv')
        countrycode_to_region = dict(
            zip(country_codes_df['alpha-3'], country_codes_df['region']))
        # Country codes missing from the mapping fall into 'Other'.
        self.metadata['region'] = [
            countrycode_to_region.get(code, 'Other')
            for code in self.metadata['country_code'].to_list()
        ]

        self.num_chunks = 101
        self.chunk_size = len(self.metadata) // (self.num_chunks - 1)

        # Parse timestamps once; reused for the OOD masks and the year column.
        timestamps = pd.to_datetime(self.metadata['timestamp'])
        year_dt = datetime.datetime(year, 1, 1, tzinfo=pytz.UTC)
        self.test_ood_mask = np.asarray(timestamps >= year_dt)
        # use 3 years of the training set as validation
        year_minus_3_dt = datetime.datetime(year - 3, 1, 1, tzinfo=pytz.UTC)
        self.val_ood_mask = np.asarray(
            timestamps >= year_minus_3_dt) & ~self.test_ood_mask
        self.ood_mask = self.test_ood_mask | self.val_ood_mask

        # Combine the CSV 'split' column with the time-based OOD masks.
        all_idxs = np.arange(len(self.metadata))
        raw_split = self.metadata['split']
        test_mask = np.asarray(raw_split == 'test')
        val_mask = np.asarray(raw_split == 'val')
        self._split_array = -1 * np.ones(len(self.metadata))
        for split in self._split_dict.keys():
            if split == 'test':
                idxs = all_idxs[self.test_ood_mask & test_mask]
            elif split == 'val':
                idxs = all_idxs[self.val_ood_mask & val_mask]
            elif split == 'id_test':
                idxs = all_idxs[~self.ood_mask & test_mask]
            elif split == 'id_val':
                idxs = all_idxs[~self.ood_mask & val_mask]
            else:
                split_mask = np.asarray(raw_split == split)
                idxs = all_idxs[~self.ood_mask & split_mask]

            if self.oracle_training_set and split == 'train':
                unused_ood_idxs = all_idxs[self.ood_mask & ~test_mask]
                subsample_unused_ood_idxs = subsample_idxs(unused_ood_idxs,
                                                           num=len(idxs) // 2,
                                                           seed=self.seed + 2)
                subsample_train_idxs = subsample_idxs(idxs.copy(),
                                                      num=len(idxs) // 2,
                                                      seed=self.seed + 3)
                idxs = np.concatenate(
                    [subsample_unused_ood_idxs, subsample_train_idxs])
            self._split_array[idxs] = self._split_dict[split]

        if not use_ood_val:
            self._split_dict = {
                'train': 0,
                'val': 1,
                'id_test': 2,
                'ood_val': 3,
                'test': 4
            }
            self._split_names = {
                'train': 'Train',
                'val': 'ID Val',
                'id_test': 'ID Test',
                'ood_val': 'OOD Val',
                'test': 'OOD Test'
            }

        # filter out sequestered images from full dataset
        seq_mask = np.asarray(self.metadata['split'] == 'seq')
        # take out the sequestered images
        self._split_array = self._split_array[~seq_mask]
        self.full_idxs = np.arange(len(self.metadata))[~seq_mask]

        self._y_array = np.asarray(
            [self.category_to_idx[y] for y in list(self.metadata['category'])])
        self.metadata['y'] = self._y_array
        self._y_array = torch.from_numpy(self._y_array).long()[~seq_mask]
        self._y_size = 1
        self._n_classes = 62

        # convert region to idxs
        all_regions = list(self.metadata['region'].unique())
        region_to_region_idx = {
            region: i
            for i, region in enumerate(all_regions)
        }
        self._metadata_map = {'region': all_regions}
        region_idxs = [
            region_to_region_idx[region]
            for region in self.metadata['region'].tolist()
        ]
        self.metadata['region'] = region_idxs

        # make a year column in metadata (-1 for timestamps outside 2002-2017)
        year_array = -1 * np.ones(len(self.metadata))
        for calendar_year in range(2002, 2018):
            year_mask = np.asarray(timestamps >= datetime.datetime(calendar_year, 1, 1, tzinfo=pytz.UTC)) \
                        & np.asarray(timestamps < datetime.datetime(calendar_year + 1, 1, 1, tzinfo=pytz.UTC))
            year_array[year_mask] = calendar_year - 2002
        self.metadata['year'] = year_array
        self._metadata_map['year'] = list(range(2002, 2018))

        self._metadata_fields = ['region', 'year', 'y']
        self._metadata_array = torch.from_numpy(self.metadata[
            self._metadata_fields].astype(int).to_numpy()).long()[~seq_mask]

        self._eval_groupers = {
            'year':
            CombinatorialGrouper(dataset=self, groupby_fields=['year']),
            'region':
            CombinatorialGrouper(dataset=self, groupby_fields=['region']),
        }

        super().__init__(root_dir, download, split_scheme)
예제 #9
0
def main():
    ''' set default hyperparams in default_hyperparams.py '''
    parser = argparse.ArgumentParser()

    # Required arguments
    parser.add_argument('-d',
                        '--dataset',
                        choices=wilds.supported_datasets,
                        required=True)
    parser.add_argument('--algorithm',
                        required=True,
                        choices=supported.algorithms)
    parser.add_argument(
        '--root_dir',
        required=True,
        help=
        'The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).'
    )
    parser.add_argument('--pretrained_model_path',
                        default=None,
                        type=str,
                        help="Specify a path to a pretrained model's weights")

    # Dataset
    parser.add_argument(
        '--split_scheme',
        help=
        'Identifies how the train/val/test split is constructed. Choices are dataset-specific.'
    )
    parser.add_argument('--dataset_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})
    parser.add_argument(
        '--download',
        default=False,
        type=parse_bool,
        const=True,
        nargs='?',
        help=
        'If true, tries to download the dataset if it does not exist in root_dir.'
    )
    parser.add_argument(
        '--frac',
        type=float,
        default=1.0,
        help=
        'Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes. Note that this also scales the test set down, so the reported numbers are not comparable with the full test set.'
    )
    parser.add_argument('--version', default=None, type=str)

    # Loaders
    parser.add_argument('--loader_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})
    parser.add_argument('--train_loader', choices=['standard', 'group'])
    parser.add_argument('--uniform_over_groups',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--distinct_groups',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--n_groups_per_batch', type=int)
    parser.add_argument('--unlabeled_n_groups_per_batch', type=int)
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--unlabeled_batch_size', type=int)
    parser.add_argument('--eval_loader',
                        choices=['standard'],
                        default='standard')
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        'Number of batches to process before stepping optimizer and/or schedulers. If > 1, we simulate having a larger effective batch size (though batchnorm behaves differently).'
    )

    # Active Learning
    parser.add_argument('--active_learning',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument(
        '--target_split',
        default="test",
        type=str,
        help=
        'Split from which to sample labeled examples and use as unlabeled data for self-training.'
    )
    parser.add_argument(
        '--use_target_labeled',
        type=parse_bool,
        const=True,
        nargs='?',
        default=True,
        help=
        "If false, we sample target labels and remove them from the eval set, but don't actually train on them."
    )
    parser.add_argument(
        '--use_source_labeled',
        type=parse_bool,
        const=True,
        nargs='?',
        default=False,
        help=
        "Train on labeled source examples (perhaps in addition to labeled target examples.)"
    )
    parser.add_argument(
        '--upsample_target_labeled',
        type=parse_bool,
        const=True,
        nargs='?',
        default=False,
        help=
        "If concatenating source labels, upsample target labels s.t. our labeled batches are 1/2 src and 1/2 tgt."
    )
    parser.add_argument('--selection_function',
                        choices=supported.selection_functions)
    parser.add_argument(
        '--selection_function_kwargs',
        nargs='*',
        action=ParseKwargs,
        default={},
        help=
        "keyword arguments for selection fn passed as key1=value1 key2=value2")
    parser.add_argument(
        '--selectby_fields',
        nargs='+',
        help=
        "If set, acts like a grouper and n_shots are acquired per selection group (e.g. y x hospital selects K examples per y x hospital)."
    )
    parser.add_argument('--n_shots',
                        type=int,
                        help="number of shots (labels) to actively acquire")

    # Model
    parser.add_argument('--model', choices=supported.models)
    parser.add_argument(
        '--model_kwargs',
        nargs='*',
        action=ParseKwargs,
        default={},
        help=
        'keyword arguments for model initialization passed as key1=value1 key2=value2'
    )
    parser.add_argument('--freeze_featurizer',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        help="Only train classifier weights")
    parser.add_argument(
        '--teacher_model_path',
        type=str,
        help=
        'Path to teacher model weights. If this is defined, pseudolabels will first be computed for unlabeled data before anything else runs.'
    )
    parser.add_argument('--dropout_rate', type=float)

    # Transforms
    parser.add_argument('--transform', choices=supported.transforms)
    parser.add_argument('--additional_labeled_transform',
                        type=parse_none,
                        choices=supported.additional_transforms)
    parser.add_argument('--additional_unlabeled_transform',
                        type=parse_none,
                        nargs='+',
                        choices=supported.additional_transforms)
    parser.add_argument(
        '--target_resolution',
        nargs='+',
        type=int,
        help=
        'The input resolution that images will be resized to before being passed into the model. For example, use --target_resolution 224 224 for a standard ResNet.'
    )
    parser.add_argument('--resize_scale', type=float)
    parser.add_argument('--max_token_length', type=int)
    parser.add_argument(
        '--randaugment_n',
        type=int,
        help=
        'N parameter of RandAugment - the number of transformations to apply.')

    # Objective
    parser.add_argument('--loss_function', choices=supported.losses)

    # Algorithm
    parser.add_argument('--groupby_fields', nargs='+')
    parser.add_argument('--group_dro_step_size', type=float)
    parser.add_argument('--coral_penalty_weight', type=float)
    parser.add_argument('--irm_lambda', type=float)
    parser.add_argument('--irm_penalty_anneal_iters', type=int)
    parser.add_argument('--maml_first_order',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--metalearning_k', type=int)
    parser.add_argument('--metalearning_adapt_lr', type=float)
    parser.add_argument('--metalearning_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})
    parser.add_argument('--self_training_labeled_weight',
                        type=float,
                        help='Weight of labeled loss')
    parser.add_argument('--self_training_unlabeled_weight',
                        type=float,
                        help='Weight of unlabeled loss')
    parser.add_argument('--self_training_threshold', type=float)
    parser.add_argument(
        '--pseudolabel_T2',
        type=float,
        help=
        'Percentage of total iterations at which to end linear scheduling and hold unlabeled weight at the max value'
    )
    parser.add_argument('--soft_pseudolabels',
                        default=False,
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--algo_log_metric')

    # Model selection
    parser.add_argument('--val_metric')
    parser.add_argument('--val_metric_decreasing',
                        type=parse_bool,
                        const=True,
                        nargs='?')

    # Optimization
    parser.add_argument('--n_epochs', type=int)
    parser.add_argument('--optimizer', choices=supported.optimizers)
    parser.add_argument('--lr', type=float)
    parser.add_argument('--weight_decay', type=float)
    parser.add_argument('--max_grad_norm', type=float)
    parser.add_argument('--optimizer_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})

    # Scheduler
    parser.add_argument('--scheduler', choices=supported.schedulers)
    parser.add_argument('--scheduler_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})
    parser.add_argument('--scheduler_metric_split',
                        choices=['train', 'val'],
                        default='val')
    parser.add_argument('--scheduler_metric_name')

    # Evaluation
    parser.add_argument('--process_outputs_function',
                        choices=supported.process_outputs_functions)
    parser.add_argument('--evaluate_all_splits',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument('--eval_splits', nargs='+', default=['val', 'test'])
    parser.add_argument(
        '--save_splits',
        nargs='+',
        default=['test'],
        help=
        'If save_pred_step or save_pseudo_step are set, then this sets which splits to save pred / pseudos for. Must be a subset of eval_splits.'
    )
    parser.add_argument('--eval_additional_every',
                        default=1,
                        type=int,
                        help='Eval additional splits every _ training epochs.')
    parser.add_argument('--eval_only',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument(
        '--eval_epoch',
        default=None,
        type=int,
        help=
        'If eval_only is set, then eval_epoch allows you to specify evaluating at a particular epoch. By default, it evaluates the best epoch by validation performance.'
    )

    # Misc
    parser.add_argument('--device', type=int, nargs='+', default=[0])
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--log_dir', default='./logs')
    parser.add_argument('--log_every', default=50, type=int)
    parser.add_argument('--save_model_step', type=int)
    parser.add_argument('--save_pred_step', type=int)
    parser.add_argument('--save_pseudo_step', type=int)
    parser.add_argument('--save_best',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=True)
    parser.add_argument('--save_last',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=True)
    parser.add_argument('--no_group_logging',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--progress_bar',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument(
        '--resume',
        type=parse_bool,
        const=True,
        nargs='?',
        default=False,
        help=
        'Whether to resume from the most recent saved model in the current log_dir.'
    )

    # Weights & Biases
    parser.add_argument('--use_wandb',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument(
        '--wandb_api_key_path',
        type=str,
        help=
        "Path to Weights & Biases API Key. If use_wandb is set to True and this argument is not specified, user will be prompted to authenticate."
    )
    parser.add_argument('--wandb_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={},
                        help="Will be passed directly into wandb.init().")

    config = parser.parse_args()
    config = populate_defaults(config)

    # Set device
    if torch.cuda.is_available():
        device_count = torch.cuda.device_count()
        if len(config.device) > device_count:
            raise ValueError(
                f"Specified {len(config.device)} devices, but only {device_count} devices found."
            )
        config.use_data_parallel = len(config.device) > 1
        try:
            device_str = ",".join(map(str, config.device))
            config.device = torch.device(f"cuda:{device_str}")
        except RuntimeError as e:
            print(
                f"Failed to initialize CUDA. Using torch.device('cuda') instead. Error: {str(e)}"
            )
            config.device = torch.device("cuda")
    else:
        config.use_data_parallel = False
        config.device = torch.device("cpu")

    ## Initialize logs
    if os.path.exists(config.log_dir) and config.resume:
        resume = True
        config.mode = 'a'
    elif os.path.exists(config.log_dir) and config.eval_only:
        resume = False
        config.mode = 'a'
    else:
        resume = False
        config.mode = 'w'

    if not os.path.exists(config.log_dir):
        os.makedirs(config.log_dir)
    logger = Logger(os.path.join(config.log_dir, 'log.txt'), config.mode)

    # Record config
    log_config(config, logger)

    # Set random seed
    set_seed(config.seed)

    # Algorithms that use unlabeled data must be run in active learning mode,
    # because otherwise we have no unlabeled data.
    if config.algorithm in ["PseudoLabel", "FixMatch", "NoisyStudent"]:
        assert config.active_learning

    # Data
    full_dataset = wilds.get_dataset(dataset=config.dataset,
                                     version=config.version,
                                     root_dir=config.root_dir,
                                     download=config.download,
                                     split_scheme=config.split_scheme,
                                     **config.dataset_kwargs)

    # In this project, we sometimes train on batches of mixed splits, e.g. some train labeled examples and test labeled examples
    # Within each batch, we may want to sample uniformly across split, or log the train v. test label balance
    # To facilitate this, we'll hack the WILDS dataset to include each point's split in the metadata array
    add_split_to_wilds_dataset_metadata_array(full_dataset)

    # To modify data augmentation, modify the following code block.
    # If you want to use transforms that modify both `x` and `y`,
    # set `do_transform_y` to True when initializing the `WILDSSubset` below.
    train_transform = initialize_transform(
        transform_name=config.transform,
        config=config,
        dataset=full_dataset,
        additional_transform=config.additional_labeled_transform,
        is_training=True)
    eval_transform = initialize_transform(transform_name=config.transform,
                                          config=config,
                                          dataset=full_dataset,
                                          is_training=False)

    # Define any special transforms for the algorithms that use unlabeled data
    # if config.algorithm == "FixMatch":
    #     # For FixMatch, we need our loader to return batches in the form ((x_weak, x_strong), m)
    #     # We do this by initializing a special transform function
    #     unlabeled_train_transform = initialize_transform(
    #         config.transform, config, full_dataset, is_training=True, additional_transform="fixmatch"
    #     )
    # else:
    unlabeled_train_transform = initialize_transform(
        config.transform,
        config,
        full_dataset,
        is_training=True,
        additional_transform=config.additional_unlabeled_transform)

    train_grouper = CombinatorialGrouper(dataset=full_dataset,
                                         groupby_fields=config.groupby_fields)

    datasets = defaultdict(dict)
    for split in full_dataset.split_dict.keys():
        if split == 'train':
            transform = train_transform
            verbose = True
        elif split == 'val':
            transform = eval_transform
            verbose = True
        else:
            transform = eval_transform
            verbose = False

        data = full_dataset.get_subset(split,
                                       frac=config.frac,
                                       transform=transform)

        datasets[split] = configure_split_dict(
            data=data,
            split=split,
            split_name=full_dataset.split_names[split],
            get_train=(split == 'train'),
            get_eval=(split != 'train'),
            verbose=verbose,
            grouper=train_grouper,
            batch_size=config.batch_size,
            config=config)

        pseudolabels = None
        if config.algorithm == "NoisyStudent" and config.target_split == split:
            # Infer teacher outputs on unlabeled examples in sequential order
            # During forward pass, ensure we are not shuffling and not applying strong augs
            print(
                f"Inferring teacher pseudolabels on {config.target_split} for Noisy Student"
            )
            assert config.teacher_model_path is not None
            if not config.teacher_model_path.endswith(".pth"):
                # Use the best model
                config.teacher_model_path = os.path.join(
                    config.teacher_model_path,
                    f"{config.dataset}_seed:{config.seed}_epoch:best_model.pth"
                )
            teacher_model = initialize_model(
                config, infer_d_out(full_dataset)).to(config.device)
            load(teacher_model,
                 config.teacher_model_path,
                 device=config.device)
            # Infer teacher outputs on weakly augmented unlabeled examples in sequential order
            weak_transform = initialize_transform(
                transform_name=config.transform,
                config=config,
                dataset=full_dataset,
                is_training=True,
                additional_transform="weak")
            unlabeled_split_dataset = full_dataset.get_subset(
                split, transform=weak_transform, frac=config.frac)
            sequential_loader = get_eval_loader(
                loader=config.eval_loader,
                dataset=unlabeled_split_dataset,
                grouper=train_grouper,
                batch_size=config.unlabeled_batch_size,
                **config.loader_kwargs)
            pseudolabels = infer_predictions(teacher_model, sequential_loader,
                                             config)
            del teacher_model

        if config.active_learning and config.target_split == split:
            datasets[split]['label_manager'] = LabelManager(
                subset=data,
                train_transform=train_transform,
                eval_transform=eval_transform,
                unlabeled_train_transform=unlabeled_train_transform,
                pseudolabels=pseudolabels)

    if config.use_wandb:
        initialize_wandb(config)

    # Logging dataset info
    # Show class breakdown if feasible
    if config.no_group_logging and full_dataset.is_classification and full_dataset.y_size == 1 and full_dataset.n_classes <= 10:
        log_grouper = CombinatorialGrouper(dataset=full_dataset,
                                           groupby_fields=['y'])
    elif config.no_group_logging:
        log_grouper = None
    else:
        log_grouper = train_grouper
    log_group_data(datasets, log_grouper, logger)

    ## Initialize algorithm
    ## Schedulers are initialized as if we will iterate over "train" split batches.
    ## If we train on another split (e.g. labeled test), we'll re-initialize schedulers later using algorithm.change_n_train_steps()
    algorithm = initialize_algorithm(config=config,
                                     datasets=datasets,
                                     train_grouper=train_grouper)
    if config.freeze_featurizer: freeze_features(algorithm)

    if config.active_learning:
        select_grouper = CombinatorialGrouper(
            dataset=full_dataset, groupby_fields=config.selectby_fields)
        selection_fn = initialize_selection_function(
            config, algorithm, select_grouper, algo_grouper=train_grouper)

    # Resume from most recent model in log_dir
    model_prefix = get_model_prefix(datasets['train'], config)
    if not config.eval_only:
        ## If doing active learning, expects to load a model trained on source
        resume_success = False
        if config.resume:
            save_path = model_prefix + 'epoch:last_model.pth'
            if not os.path.exists(save_path):
                epochs = [
                    int(file.split('epoch:')[1].split('_')[0])
                    for file in os.listdir(config.log_dir)
                    if file.endswith('.pth')
                ]
                if len(epochs) > 0:
                    latest_epoch = max(epochs)
                    save_path = model_prefix + f'epoch:{latest_epoch}_model.pth'
            try:
                prev_epoch, best_val_metric = load(algorithm, save_path,
                                                   config.device)
                # also load previous selections

                epoch_offset = prev_epoch + 1
                config.selection_function_kwargs[
                    'load_selection_path'] = config.log_dir
                logger.write(
                    f'Resuming from epoch {epoch_offset} with best val metric {best_val_metric}\n'
                )
                resume_success = True
            except FileNotFoundError:
                pass

        if resume_success == False:
            epoch_offset = 0
            best_val_metric = None

        # Log effective batch size
        logger.write((
            f'\nUsing gradient_accumulation_steps {config.gradient_accumulation_steps} means that'
        ) + (
            f' the effective labeled batch size is {config.batch_size * config.gradient_accumulation_steps}'
        ) + (
            f' and the effective unlabeled batch size is {config.unlabeled_batch_size * config.gradient_accumulation_steps}'
            if config.unlabeled_batch_size else '') + (
                '. Updates behave as if torch loaders have drop_last=False\n'))

        if config.active_learning:
            # create new labeled/unlabeled test splits
            train_split, unlabeled_split = run_active_learning(
                selection_fn=selection_fn,
                datasets=datasets,
                grouper=train_grouper,
                config=config,
                general_logger=logger,
                full_dataset=full_dataset)
            # reset schedulers, which were originally initialized to schedule based on the 'train' split
            # one epoch = one pass over labeled data
            algorithm.change_n_train_steps(
                new_n_train_steps=infer_n_train_steps(
                    datasets[train_split]['train_loader'], config),
                config=config)
        else:
            train_split = "train"
            unlabeled_split = None

        train(algorithm=algorithm,
              datasets=datasets,
              train_split=train_split,
              val_split="val",
              unlabeled_split=unlabeled_split,
              general_logger=logger,
              config=config,
              epoch_offset=epoch_offset,
              best_val_metric=best_val_metric)

    else:
        if config.eval_epoch is None:
            eval_model_path = model_prefix + 'epoch:best_model.pth'
        else:
            eval_model_path = model_prefix + f'epoch:{config.eval_epoch}_model.pth'
        best_epoch, best_val_metric = load(algorithm, eval_model_path,
                                           config.device)
        if config.eval_epoch is None:
            epoch = best_epoch
        else:
            epoch = config.eval_epoch

        if config.active_learning:
            # create new labeled/unlabeled test splits
            config.selection_function_kwargs[
                'load_selection_path'] = config.log_dir
            run_active_learning(selection_fn=selection_fn,
                                datasets=datasets,
                                grouper=train_grouper,
                                config=config,
                                general_logger=logger,
                                full_dataset=full_dataset)

        evaluate(algorithm=algorithm,
                 datasets=datasets,
                 epoch=epoch,
                 general_logger=logger,
                 config=config)

    if config.use_wandb:
        wandb.finish()
    logger.close()
    for split in datasets:
        datasets[split]['eval_logger'].close()
        datasets[split]['algo_logger'].close()
예제 #10
0
    def __init__(self,
                 root_dir='data',
                 download=False,
                 split_scheme='official'):
        """Load the 4-environment colored-MNIST (``cmnist4``) dataset.

        Labels are a 3-way classification derived from the digit:
        digit 6 -> class 1, digit 9 -> class 2, every other digit -> class 0.
        Environment 3 is relabelled as the OOD validation split ('val') and
        environment 4 as the test split; all other rows keep the 'split'
        value stored in ``metadata.csv``.

        Args:
            root_dir: Root directory handed to ``initialize_data_dir`` and
                to the base-class constructor.
            download: Forwarded to ``initialize_data_dir`` and the base-class
                constructor (presumably enables downloading missing data).
            split_scheme: Only ``'official'`` is supported.

        Raises:
            ValueError: If ``split_scheme`` is not ``'official'``.
        """
        # Validate first so an unsupported scheme fails before any disk/network work.
        if split_scheme != 'official':
            raise ValueError(f'Split scheme {split_scheme} not recognized')
        self._split_scheme = split_scheme

        self._dataset_name = 'cmnist4'
        self._version = '1.0'
        self._data_dir = self.initialize_data_dir(root_dir, download)
        self._original_resolution = (28, 28)

        # Read in metadata; the first CSV column is used as the index.
        self._metadata_df = pd.read_csv(
            os.path.join(self._data_dir, 'metadata.csv'), index_col=0)

        # Map digits to 3 classes: 6 -> 1, 9 -> 2, everything else -> 0.
        # Cast explicitly so the label dtype does not hinge on bool-tensor
        # type promotion (torch.stack below needs int64 to match envs).
        digits = torch.LongTensor(self._metadata_df['digit'].values)
        self._y_array = (digits == 6).long() + 2 * (digits == 9).long()
        # y_size is the number of target elements per example (one scalar
        # label here), not the number of classes. A value other than 1 makes
        # single-label-classification code paths (e.g. d_out inference and
        # class-breakdown logging that test `y_size == 1`) misbehave.
        self._y_size = 1
        self._n_classes = 3

        # Relative path (under the data dir) of each example's image file.
        self._input_array = [
            f'images/env_{env}/digit_{digit}/{image}.pt'
            for image, digit, env in self._metadata_df.loc[
                :, ['image', 'digit', 'env']].itertuples(index=False, name=None)
        ]

        # Extract splits. Environment ids are 0-indexed: env 3 is held out
        # as the OOD validation split and env 4 as the test split. Rows from
        # other envs keep the 'split' column from metadata.csv, which is
        # assumed to encode train=0 / id_val=1 -- TODO confirm against data.
        test_env = 4
        val_env = 3

        self._split_dict = {'train': 0, 'id_val': 1, 'test': 2, 'val': 3}
        self._split_names = {
            'train': 'Train',
            'id_val': 'Validation (ID)',
            'test': 'Test',
            'val': 'Validation (OOD)',
        }
        envs = self._metadata_df['env'].values.astype('long')
        val_env_mask = (self._metadata_df['env'] == val_env)
        test_env_mask = (self._metadata_df['env'] == test_env)
        self._metadata_df.loc[val_env_mask, 'split'] = self.split_dict['val']
        self._metadata_df.loc[test_env_mask, 'split'] = self.split_dict['test']

        self._split_array = self._metadata_df['split'].values

        # Per-example metadata: (environment id, class label).
        self._metadata_array = torch.stack(
            [torch.LongTensor(envs), self._y_array], dim=1)
        self._metadata_fields = ['env', 'y']

        # Evaluation metrics are broken down per environment.
        self._eval_grouper = CombinatorialGrouper(dataset=self,
                                                  groupby_fields=['env'])

        self._metric = Accuracy()

        super().__init__(root_dir, download, split_scheme)
예제 #11
0
def _latest_saved_epoch(filenames):
    """Return the largest numeric epoch among ``*epoch:<N>_*.pth`` filenames, or None.

    Checkpoints saved under non-numeric tags (``epoch:best_model.pth``,
    ``epoch:last_model.pth``) and ``.pth`` files without an ``epoch:`` marker
    are skipped instead of crashing ``int()`` / indexing.
    """
    epochs = []
    for filename in filenames:
        if not filename.endswith('.pth') or 'epoch:' not in filename:
            continue
        tag = filename.split('epoch:')[1].split('_')[0]
        # isdecimal() admits only characters that int() can parse.
        if tag.isdecimal():
            epochs.append(int(tag))
    return max(epochs, default=None)


def main():
    """Parse command-line arguments, then train or evaluate a model.

    Default hyperparameters are set in default_hyperparams.py;
    ``populate_defaults`` is expected to fill in arguments left unset here.
    With ``--eval_only`` the best (or ``--eval_epoch``) checkpoint is loaded
    and evaluated; otherwise training runs, optionally resuming from the most
    recent checkpoint in ``--log_dir`` when ``--resume`` is set.
    """
    parser = argparse.ArgumentParser()

    # Required arguments
    parser.add_argument('-d',
                        '--dataset',
                        choices=wilds.supported_datasets,
                        required=True)
    parser.add_argument('--algorithm',
                        required=True,
                        choices=supported.algorithms)
    parser.add_argument(
        '--root_dir',
        required=True,
        help=
        'The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).'
    )

    # Dataset
    parser.add_argument(
        '--split_scheme',
        help=
        'Identifies how the train/val/test split is constructed. Choices are dataset-specific.'
    )
    parser.add_argument('--dataset_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})
    parser.add_argument(
        '--download',
        default=False,
        type=parse_bool,
        const=True,
        nargs='?',
        help=
        'If true, tries to downloads the dataset if it does not exist in root_dir.'
    )
    parser.add_argument(
        '--frac',
        type=float,
        default=1.0,
        help=
        'Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes. Note that this also scales the test set down, so the reported numbers are not comparable with the full test set.'
    )
    parser.add_argument('--version', default=None, type=str)

    # Loaders
    parser.add_argument('--loader_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})
    parser.add_argument('--train_loader', choices=['standard', 'group'])
    parser.add_argument('--uniform_over_groups',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--distinct_groups',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--n_groups_per_batch', type=int)
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--eval_loader',
                        choices=['standard'],
                        default='standard')

    # Model
    parser.add_argument('--model', choices=supported.models)
    parser.add_argument(
        '--model_kwargs',
        nargs='*',
        action=ParseKwargs,
        default={},
        help=
        'keyword arguments for model initialization passed as key1=value1 key2=value2'
    )

    # Transforms
    parser.add_argument('--train_transform', choices=supported.transforms)
    parser.add_argument('--eval_transform', choices=supported.transforms)
    parser.add_argument(
        '--target_resolution',
        nargs='+',
        type=int,
        help=
        'The input resolution that images will be resized to before being passed into the model. For example, use --target_resolution 224 224 for a standard ResNet.'
    )
    parser.add_argument('--resize_scale', type=float)
    parser.add_argument('--max_token_length', type=int)

    # Objective
    parser.add_argument('--loss_function', choices=supported.losses)

    # Algorithm
    parser.add_argument('--groupby_fields', nargs='+')
    parser.add_argument('--group_dro_step_size', type=float)
    parser.add_argument('--coral_penalty_weight', type=float)
    parser.add_argument('--dann_lambda', type=float)
    parser.add_argument('--dann_domain_layers', type=int,
                        default=1)  # hidden layers
    parser.add_argument('--dann_label_layers', type=int,
                        default=1)  # hidden layers
    parser.add_argument('--domain_loss_function', choices=supported.losses)
    parser.add_argument('--irm_lambda', type=float)
    parser.add_argument('--irm_penalty_anneal_iters', type=int)
    parser.add_argument('--algo_log_metric')

    # Model selection
    parser.add_argument('--val_metric')
    parser.add_argument('--val_metric_decreasing',
                        type=parse_bool,
                        const=True,
                        nargs='?')

    # Optimization
    parser.add_argument('--n_epochs', type=int)
    parser.add_argument('--optimizer', choices=supported.optimizers)
    parser.add_argument('--lr', type=float)
    parser.add_argument('--weight_decay', type=float)
    parser.add_argument('--max_grad_norm', type=float)
    parser.add_argument('--optimizer_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})

    # Scheduler
    parser.add_argument('--scheduler', choices=supported.schedulers)
    parser.add_argument('--scheduler_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})
    parser.add_argument('--scheduler_metric_split',
                        choices=['train', 'val'],
                        default='val')
    parser.add_argument('--scheduler_metric_name')

    # Evaluation
    parser.add_argument('--process_outputs_function',
                        choices=supported.process_outputs_functions)
    parser.add_argument('--evaluate_all_splits',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=True)
    parser.add_argument('--eval_splits', nargs='+', default=[])
    parser.add_argument('--eval_only',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument(
        '--eval_epoch',
        default=None,
        type=int,
        help=
        'If eval_only is set, then eval_epoch allows you to specify evaluating at a particular epoch. By default, it evaluates the best epoch by validation performance.'
    )

    # Misc
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--log_dir', default='./logs')
    parser.add_argument('--log_every', default=50, type=int)
    parser.add_argument('--save_step', type=int)
    parser.add_argument('--save_best',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=True)
    parser.add_argument('--save_last',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=True)
    parser.add_argument('--save_pred',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=True)
    parser.add_argument('--no_group_logging',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--use_wandb',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument('--progress_bar',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument('--resume',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)

    config = parser.parse_args()
    config = populate_defaults(config)

    # Set device: the CUDA device with index --device, or CPU without CUDA.
    config.device = torch.device("cuda:" + str(
        config.device)) if torch.cuda.is_available() else torch.device("cpu")

    ## Initialize logs
    # Append to existing logs when resuming or only evaluating; otherwise
    # start fresh. `resume` is only True if there is a log_dir to resume from.
    if os.path.exists(config.log_dir) and config.resume:
        resume = True
        mode = 'a'
    elif os.path.exists(config.log_dir) and config.eval_only:
        resume = False
        mode = 'a'
    else:
        resume = False
        mode = 'w'

    os.makedirs(config.log_dir, exist_ok=True)
    logger = Logger(os.path.join(config.log_dir, 'log.txt'), mode)

    # Record config
    log_config(config, logger)

    # Set random seed
    set_seed(config.seed)

    # Data
    full_dataset = wilds.get_dataset(dataset=config.dataset,
                                     version=config.version,
                                     root_dir=config.root_dir,
                                     download=config.download,
                                     split_scheme=config.split_scheme,
                                     **config.dataset_kwargs)

    # To implement data augmentation (i.e., have different transforms
    # at training time vs. test time), modify these two lines:
    train_transform = initialize_transform(
        transform_name=config.train_transform,
        config=config,
        dataset=full_dataset)
    eval_transform = initialize_transform(transform_name=config.eval_transform,
                                          config=config,
                                          dataset=full_dataset)

    train_grouper = CombinatorialGrouper(dataset=full_dataset,
                                         groupby_fields=config.groupby_fields)

    datasets = defaultdict(dict)
    for split in full_dataset.split_dict.keys():
        # Only train uses the training transform; only train/val are verbose.
        if split == 'train':
            transform = train_transform
            verbose = True
        elif split == 'val':
            transform = eval_transform
            verbose = True
        else:
            transform = eval_transform
            verbose = False
        # Get subset
        datasets[split]['dataset'] = full_dataset.get_subset(
            split, frac=config.frac, transform=transform)

        if split == 'train':
            datasets[split]['loader'] = get_train_loader(
                loader=config.train_loader,
                dataset=datasets[split]['dataset'],
                batch_size=config.batch_size,
                uniform_over_groups=config.uniform_over_groups,
                grouper=train_grouper,
                distinct_groups=config.distinct_groups,
                n_groups_per_batch=config.n_groups_per_batch,
                **config.loader_kwargs)
        else:
            datasets[split]['loader'] = get_eval_loader(
                loader=config.eval_loader,
                dataset=datasets[split]['dataset'],
                grouper=train_grouper,
                batch_size=config.batch_size,
                **config.loader_kwargs)

        # Set fields
        datasets[split]['split'] = split
        datasets[split]['name'] = full_dataset.split_names[split]
        datasets[split]['verbose'] = verbose

        # Loggers
        datasets[split]['eval_logger'] = BatchLogger(
            os.path.join(config.log_dir, f'{split}_eval.csv'),
            mode=mode,
            use_wandb=(config.use_wandb and verbose))
        datasets[split]['algo_logger'] = BatchLogger(
            os.path.join(config.log_dir, f'{split}_algo.csv'),
            mode=mode,
            use_wandb=(config.use_wandb and verbose))

    # Initialize Weights & Biases once per run; this used to sit inside the
    # split loop and was re-run for every split.
    if config.use_wandb:
        initialize_wandb(config)

    # Logging dataset info
    # Show class breakdown if feasible
    if config.no_group_logging and full_dataset.is_classification and full_dataset.y_size == 1 and full_dataset.n_classes <= 10:
        log_grouper = CombinatorialGrouper(dataset=full_dataset,
                                           groupby_fields=['y'])
    elif config.no_group_logging:
        log_grouper = None
    else:
        log_grouper = train_grouper
    log_group_data(datasets, log_grouper, logger)

    ## Initialize algorithm
    algorithm = initialize_algorithm(config=config,
                                     datasets=datasets,
                                     train_grouper=train_grouper)

    model_prefix = get_model_prefix(datasets['train'], config)
    if not config.eval_only:
        ## Load saved results if resuming
        resume_success = False
        if resume:
            save_path = model_prefix + 'epoch:last_model.pth'
            if not os.path.exists(save_path):
                # No "last" checkpoint: fall back to the highest numbered epoch.
                latest_epoch = _latest_saved_epoch(os.listdir(config.log_dir))
                if latest_epoch is not None:
                    save_path = model_prefix + f'epoch:{latest_epoch}_model.pth'
            try:
                prev_epoch, best_val_metric = load(algorithm, save_path)
            except FileNotFoundError:
                # Nothing to resume from; train from scratch below.
                pass
            else:
                epoch_offset = prev_epoch + 1
                logger.write(
                    f'Resuming from epoch {epoch_offset} with best val metric {best_val_metric}\n'
                )
                resume_success = True

        if not resume_success:
            epoch_offset = 0
            best_val_metric = None

        train(algorithm=algorithm,
              datasets=datasets,
              general_logger=logger,
              config=config,
              epoch_offset=epoch_offset,
              best_val_metric=best_val_metric)
    else:
        if config.eval_epoch is None:
            eval_model_path = model_prefix + 'epoch:best_model.pth'
        else:
            eval_model_path = model_prefix + f'epoch:{config.eval_epoch}_model.pth'
        best_epoch, best_val_metric = load(algorithm, eval_model_path)
        if config.eval_epoch is None:
            epoch = best_epoch
        else:
            epoch = config.eval_epoch
        evaluate(algorithm=algorithm,
                 datasets=datasets,
                 epoch=epoch,
                 general_logger=logger,
                 config=config)

    logger.close()
    for split in datasets:
        datasets[split]['eval_logger'].close()
        datasets[split]['algo_logger'].close()
예제 #12
0
    def __init__(self,
                 root_dir='data',
                 download=False,
                 split_scheme='official',
                 no_nl=True,
                 fold='A',
                 oracle_training_set=False,
                 use_ood_val=False):

        self._compressed_size = 18_630_656_000
        self._data_dir = self.initialize_data_dir(root_dir, download)

        self._split_dict = {
            'train': 0,
            'id_val': 1,
            'id_test': 2,
            'val': 3,
            'test': 4
        }
        self._split_names = {
            'train': 'Train',
            'id_val': 'ID Val',
            'id_test': 'ID Test',
            'val': 'OOD Val',
            'test': 'OOD Test'
        }

        if split_scheme == 'official':
            split_scheme = 'countries'
        self._split_scheme = split_scheme
        if self._split_scheme != 'countries':
            raise ValueError("Split scheme not recognized")

        self.oracle_training_set = oracle_training_set

        self.no_nl = no_nl
        if fold not in {'A', 'B', 'C', 'D', 'E'}:
            raise ValueError("Fold must be A, B, C, D, or E")

        self.root = Path(self._data_dir)
        self.metadata = pd.read_csv(self.root / 'dhs_metadata.csv')
        # country folds, split off OOD
        country_folds = SURVEY_NAMES[f'2009-17{fold}']

        self._split_array = -1 * np.ones(len(self.metadata))

        incountry_folds_split = np.arange(len(self.metadata))
        # take the test countries to be ood
        idxs_id, idxs_ood_test = split_by_countries(incountry_folds_split,
                                                    country_folds['test'],
                                                    self.metadata)
        # also create a validation OOD set
        idxs_id, idxs_ood_val = split_by_countries(idxs_id,
                                                   country_folds['val'],
                                                   self.metadata)
        for split in ['test', 'val', 'id_test', 'id_val', 'train']:
            # keep ood for test, otherwise throw away ood data
            if split == 'test':
                idxs = idxs_ood_test
            elif split == 'val':
                idxs = idxs_ood_val
            else:
                idxs = idxs_id
                num_eval = 2000
                # if oracle, do 50-50 split between OOD and ID
                if split == 'train' and self.oracle_training_set:
                    idxs = subsample_idxs(incountry_folds_split,
                                          num=len(idxs_id),
                                          seed=ord(fold))[num_eval:]
                elif split != 'train' and self.oracle_training_set:
                    eval_idxs = subsample_idxs(incountry_folds_split,
                                               num=len(idxs_id),
                                               seed=ord(fold))[:num_eval]
                elif split == 'train':
                    idxs = subsample_idxs(idxs,
                                          take_rest=True,
                                          num=num_eval,
                                          seed=ord(fold))
                else:
                    eval_idxs = subsample_idxs(idxs,
                                               take_rest=False,
                                               num=num_eval,
                                               seed=ord(fold))

                if split != 'train':
                    if split == 'id_val':
                        idxs = eval_idxs[:num_eval // 2]
                    else:
                        idxs = eval_idxs[num_eval // 2:]
            self._split_array[idxs] = self._split_dict[split]

        if not use_ood_val:
            self._split_dict = {
                'train': 0,
                'val': 1,
                'id_test': 2,
                'ood_val': 3,
                'test': 4
            }
            self._split_names = {
                'train': 'Train',
                'val': 'ID Val',
                'id_test': 'ID Test',
                'ood_val': 'OOD Val',
                'test': 'OOD Test'
            }

        self.imgs = np.load(self.root / 'landsat_poverty_imgs.npy',
                            mmap_mode='r')

        self.imgs = self.imgs.transpose((0, 3, 1, 2))
        self._y_array = torch.from_numpy(
            np.asarray(self.metadata['wealthpooled'])[:, np.newaxis]).float()
        self._y_size = 1

        # add country group field
        country_to_idx = {
            country: i
            for i, country in enumerate(DHS_COUNTRIES)
        }
        self.metadata['country'] = [
            country_to_idx[country]
            for country in self.metadata['country'].tolist()
        ]
        self._metadata_map = {'country': DHS_COUNTRIES}
        self._metadata_array = torch.from_numpy(
            self.metadata[['urban', 'wealthpooled',
                           'country']].astype(float).to_numpy())
        # rename wealthpooled to y
        self._metadata_fields = ['urban', 'y', 'country']

        self._eval_grouper = CombinatorialGrouper(dataset=self,
                                                  groupby_fields=['urban'])

        self._metrics = [MSE(), PearsonCorrelation()]
        self.cache_counter = 0

        super().__init__(root_dir, download, split_scheme)
예제 #13
0
    def __init__(self, root_dir='data', download=False, split_scheme='official'):
        """Build the VLCS dataset from ``metadata.csv``.

        Inputs are relative image paths of the form ``{env}/{label}/{image}``.
        Targets are one of the five classes in ``_label_map``. Metadata is
        ``(env, y)``. Rows from ``val_env`` are forced into the OOD validation
        split.

        Args:
            root_dir: Root directory passed to ``initialize_data_dir``.
            download: Forwarded to ``initialize_data_dir`` and the base class.
            split_scheme: Only ``'official'`` is supported.

        Raises:
            ValueError: If ``split_scheme`` is not ``'official'``.
        """
        self._dataset_name = 'vlcs'
        self._version = '1.0'
        # self._download_url = 'https://drive.google.com/uc?id=1skwblH1_okBwxWxmRsp9_qi15hyPpxg8'

        # Validate the split scheme before any filesystem work (and before a
        # possible download) so an unsupported scheme fails fast.
        self._split_scheme = split_scheme
        if self._split_scheme != 'official':
            raise ValueError(f'Split scheme {self._split_scheme} not recognized')

        self._data_dir = self.initialize_data_dir(root_dir, download)
        self._resolution = (224, 224)

        # Read in metadata
        self._metadata_df = pd.read_csv(
            os.path.join(self._data_dir, 'metadata.csv'),
            index_col=0
        )

        # Get the y values
        self._label_map = {
            'bird': 0,
            'car': 1,
            'chair': 2,
            'dog': 3,
            'person': 4
        }
        self._label_array = self._metadata_df['label'].values
        self._y_array = torch.LongTensor([self._label_map[y] for y in self._label_array])
        self._y_size = 1
        # Derived from the label map so the two cannot drift apart.
        self._n_classes = len(self._label_map)

        # Get filenames
        self._input_array = [
            f'{env}/{label}/{image}'
            for image, label, env in
            self._metadata_df.loc[:, ['image', 'label', 'env']].itertuples(index=False, name=None)]

        # No environment is currently forced into the test split: '' matches
        # no row, because every env value must be a key of env_map below.
        test_env = ''  #'VOC2007'
        val_env = 'VOC2007'

        self._split_dict = {
            'train': 0,
            'id_val': 1,
            'test': 2,
            'val': 3
        }
        self._split_names = {
            'train': 'Train',
            'id_val': 'Validation (ID)',
            'test': 'Test',
            'val': 'Validation (OOD)',
        }

        env_map = {
            'SUN09': 0,
            'LabelMe': 1,
            'Caltech101': 2,
            'VOC2007': 3
        }
        env_names = self._metadata_df['env'].values
        envs = [env_map[name] for name in env_names]

        # Overwrite the split of rows that belong to the OOD val/test
        # environments.
        # NOTE(review): rows outside these environments keep the 'split' value
        # read from metadata.csv (assumed to be present); confirm.
        val_env_mask = (self._metadata_df['env'] == val_env)
        test_env_mask = (self._metadata_df['env'] == test_env)
        self._metadata_df.loc[val_env_mask, 'split'] = self.split_dict['val']
        self._metadata_df.loc[test_env_mask, 'split'] = self.split_dict['test']

        self._split_array = self._metadata_df['split'].values

        self._metadata_array = torch.stack(
            (torch.LongTensor(envs),
             self._y_array),
            dim=1)
        self._metadata_fields = ['env', 'y']

        self._eval_grouper = CombinatorialGrouper(
            dataset=self,
            groupby_fields=['env'])

        # NOTE(review): sibling datasets set a ``_metrics`` list; confirm the
        # base class used here reads the singular ``_metric``.
        self._metric = Accuracy()

        super().__init__(root_dir, download, split_scheme)
예제 #14
0
def run_active_learning(selection_fn,
                        datasets,
                        grouper,
                        config,
                        general_logger,
                        full_dataset=None):
    """Reveal target labels, then register the resulting target splits.

    1. ``selection_fn.select_and_reveal`` reveals ``config.n_shots`` labels in
       the target split's label manager.
    2. The labeled training set is built from the revealed target labels. It is
       optionally concatenated with (or replaced by) the labeled source
       ``'train'`` split.
    3. Entries ``labeled_<target>``, ``unlabeled_<target>_augmented`` and
       ``unlabeled_<target>`` are added to ``datasets``. For fMoW,
       ``unlabeled_<target>_disjoint`` is added as well.
    4. For NoisyStudent, the initial pseudolabels are saved.

    Args:
        selection_fn: Object exposing ``select_and_reveal(label_manager=, K=)``.
        datasets: Mapping of split name -> split dict; mutated in place.
        grouper: Grouper for the unlabeled splits and, unless
            ``config.upsample_target_labeled``, for the labeled split.
        config: Run configuration.
        general_logger: Logger exposing ``write``.
        full_dataset: Full dataset. Required when ``config.use_source_labeled``
            or ``config.upsample_target_labeled`` is set.

    Returns:
        ``(train_split_name, unlabeled_split_name)``, i.e.
        ``('labeled_<target>', 'unlabeled_<target>_augmented')``.

    Raises:
        ValueError: If ``full_dataset`` is required but is ``None``.
    """
    target = config.target_split

    # Validate before select_and_reveal mutates the label manager. An explicit
    # raise (not ``assert``) keeps the check active under ``python -O``.
    if (config.use_source_labeled
            or config.upsample_target_labeled) and full_dataset is None:
        raise ValueError(
            'full_dataset is required when use_source_labeled or '
            'upsample_target_labeled is set')

    label_manager = datasets[target]['label_manager']

    # First run selection function
    selection_fn.select_and_reveal(label_manager=label_manager,
                                   K=config.n_shots)
    general_logger.write(
        f"Total Labels Revealed: {label_manager.num_labeled}\n")

    # Concatenate labeled source examples to labeled target examples
    if config.use_source_labeled:
        # We allow optionally ignoring the target examples entirely
        if not config.use_target_labeled:
            indices = datasets['train']['dataset'].indices
        else:
            indices = np.concatenate(
                (label_manager.labeled_indices,
                 datasets['train']['dataset'].indices)).astype(
                     int)  # target points at front
        labeled_dataset = WILDSSubset(full_dataset, indices,
                                      label_manager.labeled_train_transform)
    else:
        labeled_dataset = label_manager.get_labeled_subset()

    if config.upsample_target_labeled:
        # upsample target labels (compared to src labels) using a weighted sampler
        # do this by grouping by split and then using --uniform_over_groups=True
        labeled_grouper = CombinatorialGrouper(dataset=full_dataset,
                                               groupby_fields=['split'])
        labeled_config = copy(config)
        labeled_config.uniform_over_groups = True
    else:
        labeled_config = config
        labeled_grouper = grouper

    # Dump unlabeled indices to file
    save_array(label_manager.unlabeled_indices,
               csv_path=f'{config.log_dir}/unlabeled_test_ids.csv')

    # Add new splits to datasets dict
    ## Training Splits
    ### Labeled test
    datasets[f'labeled_{target}'] = configure_split_dict(
        data=labeled_dataset,
        split=f'labeled_{target}',
        split_name=f'labeled_{target}',
        get_train=True,
        verbose=True,
        grouper=labeled_grouper,
        batch_size=config.batch_size,
        config=labeled_config)
    ### Unlabeled test
    datasets[f'unlabeled_{target}_augmented'] = configure_split_dict(
        data=label_manager.get_unlabeled_subset(train=True),
        split=f"unlabeled_{target}_augmented",
        split_name=f"unlabeled_{target}_augmented",
        get_train=True,
        get_eval=True,
        grouper=grouper,
        batch_size=config.unlabeled_batch_size,
        verbose=True,
        config=config)
    ## Eval Splits
    ### Unlabeled test, eval transform
    datasets[f'unlabeled_{target}'] = configure_split_dict(
        data=label_manager.get_unlabeled_subset(train=False,
                                                return_pseudolabels=False),
        split=f"unlabeled_{target}",
        split_name=f"unlabeled_{target}",
        get_eval=True,
        grouper=None,
        verbose=True,
        batch_size=config.unlabeled_batch_size,
        config=config)

    ## Special de-duplicated eval set for fmow
    # Stays None unless the disjoint split is actually built below.
    disjoint_unlabeled_indices = None
    if config.dataset == 'fmow':
        disjoint_unlabeled_indices = fmow_deduplicate_locations(
            negative_indices=label_manager.labeled_indices,
            superset_indices=label_manager.unlabeled_indices,
            config=config)
        save_array(disjoint_unlabeled_indices,
                   csv_path=f'{config.log_dir}/disjoint_ids.csv')
        # build disjoint split
        disjoint_eval_dataset = WILDSSubset(full_dataset,
                                            disjoint_unlabeled_indices,
                                            label_manager.eval_transform)
        datasets[f'unlabeled_{target}_disjoint'] = configure_split_dict(
            data=disjoint_eval_dataset,
            split=f'unlabeled_{target}_disjoint',
            split_name=f'unlabeled_{target}_disjoint',
            get_eval=True,
            grouper=None,
            verbose=True,
            batch_size=config.unlabeled_batch_size,
            config=config)

    # Save NoisyStudent pseudolabels initially
    if config.algorithm == 'NoisyStudent':
        save_pseudo_if_needed(label_manager.unlabeled_pseudolabel_array,
                              f'unlabeled_{target}',
                              datasets[f'unlabeled_{target}'],
                              None, config, None)
        if disjoint_unlabeled_indices is not None:
            # Map each unlabeled dataset index to its row in the pseudolabel
            # array in one pass (first occurrence wins, like list.index).
            # This replaces a linear .index() scan per element, and it also
            # works when unlabeled_indices is a numpy array, which has no
            # .index() method.
            row_of = {}
            for row, idx in enumerate(label_manager.unlabeled_indices):
                row_of.setdefault(idx, row)
            disjoint_rows = [row_of[i] for i in disjoint_unlabeled_indices]
            save_pseudo_if_needed(
                label_manager.unlabeled_pseudolabel_array[disjoint_rows],
                f'unlabeled_{target}_disjoint',
                datasets[f'unlabeled_{target}_disjoint'], None,
                config, None)

    # return names of train_split, unlabeled_split
    return f'labeled_{target}', f"unlabeled_{target}_augmented"
예제 #15
0
    def __init__(self,
                 version=None,
                 root_dir='data',
                 download=False,
                 split_scheme='official'):
        """Load the dataset from its five per-split CSV files.

        The inputs are ``train.csv``, the OOD ("trans") files ``val_trans.csv``
        and ``test_trans.csv``, and the ID ("cis") files ``val_cis.csv`` and
        ``test_cis.csv``; they are concatenated. Each row receives:

        * a split id;
        * a contiguous class label derived from ``category_id``;
        * a contiguous location group id;
        * the year/month/day/hour/minute/second of its ``datetime`` value.

        Together with ``y``, these make up ``_metadata_array``.

        Args:
            version: Stored as ``self._version``.
            root_dir: Root directory passed to ``initialize_data_dir``.
            download: Forwarded to ``initialize_data_dir`` and the base class.
            split_scheme: Only ``'official'`` is supported.

        Raises:
            ValueError: If ``split_scheme`` is not ``'official'``, or a
                ``datetime`` value does not match ``'%Y-%m-%d %H:%M:%S.%f'``.
        """
        self._version = version
        self._split_scheme = split_scheme
        if self._split_scheme != 'official':
            raise ValueError(
                f'Split scheme {self._split_scheme} not recognized')

        # path
        self._data_dir = Path(self.initialize_data_dir(root_dir, download))

        # Load splits
        train_df = pd.read_csv(self._data_dir / 'train.csv')
        val_trans_df = pd.read_csv(self._data_dir / 'val_trans.csv')
        test_trans_df = pd.read_csv(self._data_dir / 'test_trans.csv')
        val_cis_df = pd.read_csv(self._data_dir / 'val_cis.csv')
        test_cis_df = pd.read_csv(self._data_dir / 'test_cis.csv')

        # Merge all dfs ("trans" files are the OOD splits, "cis" the ID ones)
        train_df['split'] = 'train'
        val_trans_df['split'] = 'val'
        test_trans_df['split'] = 'test'
        val_cis_df['split'] = 'id_val'
        test_cis_df['split'] = 'id_test'
        df = pd.concat(
            [train_df, val_trans_df, test_trans_df, test_cis_df, val_cis_df])

        # Splits
        self._split_dict = {
            'train': 0,
            'val': 1,
            'test': 2,
            'id_val': 3,
            'id_test': 4
        }
        self._split_names = {
            'train': 'Train',
            'val': 'Validation (OOD/Trans)',
            'test': 'Test (OOD/Trans)',
            'id_val': 'Validation (ID/Cis)',
            'id_test': 'Test (ID/Cis)'
        }

        df['split_id'] = df['split'].apply(lambda x: self._split_dict[x])
        self._split_array = df['split_id'].values

        # Filenames
        self._input_array = df['filename'].values

        # Labels: raw category ids (sorted) -> contiguous 0..n_classes-1
        unique_categories = np.unique(df['category_id'])
        self._n_classes = len(unique_categories)
        category_to_label = {
            category: label
            for label, category in enumerate(unique_categories)
        }
        self._y_array = torch.tensor(
            df['category_id'].apply(lambda x: category_to_label[x]).values)
        self._y_size = 1

        # Location/group info: locations (sorted) -> contiguous group ids
        locations = np.unique(df['location'])
        n_groups = len(locations)
        location_to_group_id = {
            location: group_id
            for group_id, location in enumerate(locations)
        }
        df['group_id'] = df['location'].apply(
            lambda x: location_to_group_id[x])

        self._n_groups = n_groups

        # Parse the timestamps once, vectorized, and extract each component as
        # int64. This replaces a row-wise Python apply per component.
        timestamps = pd.to_datetime(df['datetime'],
                                    format='%Y-%m-%d %H:%M:%S.%f')
        for component in ('year', 'month', 'day', 'hour', 'minute', 'second'):
            df[component] = getattr(timestamps.dt,
                                    component).to_numpy(dtype=np.int64)

        self._metadata_array = torch.tensor(
            np.stack([
                df['group_id'].values, df['year'].values, df['month'].values,
                df['day'].values, df['hour'].values, df['minute'].values,
                df['second'].values, self.y_array
            ],
                     axis=1))
        self._metadata_fields = [
            'location', 'year', 'month', 'day', 'hour', 'minute', 'second', 'y'
        ]
        # eval grouper
        self._eval_grouper = CombinatorialGrouper(dataset=self,
                                                  groupby_fields=(['location'
                                                                   ]))

        super().__init__(root_dir, download, split_scheme)
예제 #16
0
    def __init__(self,
                 version=None,
                 root_dir='data',
                 download=False,
                 split_scheme='official'):
        """Load the dataset from a single ``metadata.csv``.

        Each row carries a ``split`` name, a ``y`` label, the pre-remapped
        integer ids ``location_remapped`` and ``sequence_remapped``, and a
        ``datetime`` string. Those id columns must be dense: the number of
        distinct values has to equal ``max + 1``. The timestamp is broken
        into year/month/day/hour/minute/second for ``_metadata_array``.

        Args:
            version: Stored as ``self._version``.
            root_dir: Root directory passed to ``initialize_data_dir``.
            download: Forwarded to ``initialize_data_dir`` and the base class.
            split_scheme: Only ``'official'`` is supported.

        Raises:
            ValueError: If ``split_scheme`` is not ``'official'``, or if one of
                ``y`` / ``location_remapped`` / ``sequence_remapped`` is not
                dense, or if a ``datetime`` value does not match
                ``'%Y-%m-%d %H:%M:%S.%f'``.
            KeyError: If a row's ``split`` is not a known split name.
        """
        self._version = version
        self._split_scheme = split_scheme
        if self._split_scheme != 'official':
            raise ValueError(
                f'Split scheme {self._split_scheme} not recognized')

        # path
        self._data_dir = Path(self.initialize_data_dir(root_dir, download))

        # Load splits
        df = pd.read_csv(self._data_dir / 'metadata.csv')

        # Splits
        self._split_dict = {
            'train': 0,
            'val': 1,
            'test': 2,
            'id_val': 3,
            'id_test': 4
        }
        self._split_names = {
            'train': 'Train',
            'val': 'Validation (OOD/Trans)',
            'test': 'Test (OOD/Trans)',
            'id_val': 'Validation (ID/Cis)',
            'id_test': 'Test (ID/Cis)'
        }

        df['split_id'] = df['split'].apply(lambda x: self._split_dict[x])
        self._split_array = df['split_id'].values

        # Filenames
        self._input_array = df['filename'].values

        # Labels
        self._y_array = torch.tensor(df['y'].values)
        self._n_classes = int(df['y'].max()) + 1
        self._y_size = 1

        # Location/group info
        self._n_groups = int(df['location_remapped'].max()) + 1

        # Sequence info
        self._n_sequences = int(df['sequence_remapped'].max()) + 1

        # Each id column must have exactly max + 1 distinct values. Raise
        # explicitly rather than assert, so the check still runs under
        # ``python -O``.
        for column, expected in (('y', self._n_classes),
                                 ('location_remapped', self._n_groups),
                                 ('sequence_remapped', self._n_sequences)):
            n_distinct = len(np.unique(df[column]))
            if n_distinct != expected:
                raise ValueError(
                    f'Column {column!r} has {n_distinct} distinct values '
                    f'but max + 1 = {expected}; expected dense ids')

        # Parse the timestamps once, vectorized, and extract each component as
        # int64. This replaces a row-wise Python apply per component.
        timestamps = pd.to_datetime(df['datetime'],
                                    format='%Y-%m-%d %H:%M:%S.%f')
        for component in ('year', 'month', 'day', 'hour', 'minute', 'second'):
            df[component] = getattr(timestamps.dt,
                                    component).to_numpy(dtype=np.int64)

        self._metadata_array = torch.tensor(
            np.stack([
                df['location_remapped'].values, df['sequence_remapped'].values,
                df['year'].values, df['month'].values, df['day'].values,
                df['hour'].values, df['minute'].values, df['second'].values,
                self.y_array
            ],
                     axis=1))
        self._metadata_fields = [
            'location', 'sequence', 'year', 'month', 'day', 'hour', 'minute',
            'second', 'y'
        ]

        # eval grouper
        self._eval_grouper = CombinatorialGrouper(dataset=self,
                                                  groupby_fields=(['location'
                                                                   ]))

        super().__init__(root_dir, download, split_scheme)
예제 #17
0
    def __init__(self,
                 root_dir='data',
                 download=False,
                 split_scheme='official'):
        """Load iWildCam from its five per-split CSV files.

        The inputs are ``train.csv``, the OOD ("trans") files ``val_trans.csv``
        and ``test_trans.csv``, and the ID ("cis") files ``val_cis.csv`` and
        ``test_cis.csv``; they are concatenated. Each row receives a split id,
        a contiguous class label derived from ``category_id``, and a contiguous
        location group id. Metadata is ``(location, y)``.

        Args:
            root_dir: Root directory passed to ``initialize_data_dir``.
            download: Forwarded to ``initialize_data_dir`` and the base class.
            split_scheme: Only ``'official'`` is supported.

        Raises:
            ValueError: If ``split_scheme`` is not ``'official'``.
        """
        self._dataset_name = 'iwildcam'
        self._version = '1.0'
        self._split_scheme = split_scheme
        if self._split_scheme != 'official':
            raise ValueError(
                f'Split scheme {self._split_scheme} not recognized')

        # path
        self._download_url = ''
        self._compressed_size = 90_094_666_806
        self._data_dir = Path(self.initialize_data_dir(root_dir, download))

        # Load splits
        train_df = pd.read_csv(self._data_dir / 'train.csv')
        val_trans_df = pd.read_csv(self._data_dir / 'val_trans.csv')
        test_trans_df = pd.read_csv(self._data_dir / 'test_trans.csv')
        val_cis_df = pd.read_csv(self._data_dir / 'val_cis.csv')
        test_cis_df = pd.read_csv(self._data_dir / 'test_cis.csv')

        # Merge all dfs ("trans" files are the OOD splits, "cis" the ID ones)
        train_df['split'] = 'train'
        val_trans_df['split'] = 'val'
        test_trans_df['split'] = 'test'
        val_cis_df['split'] = 'id_val'
        test_cis_df['split'] = 'id_test'
        df = pd.concat(
            [train_df, val_trans_df, test_trans_df, test_cis_df, val_cis_df])

        # Splits
        self._split_dict = {
            'train': 0,
            'val': 1,
            'test': 2,
            'id_val': 3,
            'id_test': 4
        }
        self._split_names = {
            'train': 'Train',
            'val': 'Validation (OOD/Trans)',
            'test': 'Test (OOD/Trans)',
            'id_val': 'Validation (ID/Cis)',
            'id_test': 'Test (ID/Cis)'
        }

        df['split_id'] = df['split'].apply(lambda x: self._split_dict[x])
        self._split_array = df['split_id'].values

        # Filenames
        self._input_array = df['filename'].values

        # Labels: raw category ids (sorted) -> contiguous 0..n_classes-1
        unique_categories = np.unique(df['category_id'])
        self._n_classes = len(unique_categories)
        category_to_label = {
            category: label
            for label, category in enumerate(unique_categories)
        }
        self._y_array = torch.tensor(
            df['category_id'].apply(lambda x: category_to_label[x]).values)
        self._y_size = 1

        # Location/group info: locations (sorted) -> contiguous group ids
        locations = np.unique(df['location'])
        n_groups = len(locations)
        location_to_group_id = {
            location: group_id
            for group_id, location in enumerate(locations)
        }
        df['group_id'] = df['location'].apply(
            lambda x: location_to_group_id[x])

        self._n_groups = n_groups
        self._metadata_array = torch.tensor(
            np.stack([df['group_id'].values, self.y_array], axis=1))
        self._metadata_fields = ['location', 'y']
        # eval grouper
        self._eval_grouper = CombinatorialGrouper(dataset=self,
                                                  groupby_fields=(['location'
                                                                   ]))

        self._metrics = [
            Accuracy(),
            Recall(average='macro'),
            Recall(average='weighted'),
            F1(average='macro'),
            F1(average='weighted')
        ]
        super().__init__(root_dir, download, split_scheme)
예제 #18
0
def main():
    ''' set default hyperparams in default_hyperparams.py '''
    parser = argparse.ArgumentParser()

    # Required arguments
    parser.add_argument('-d',
                        '--dataset',
                        choices=supported.datasets,
                        required=True)
    parser.add_argument('--algorithm',
                        required=True,
                        choices=supported.algorithms)
    parser.add_argument(
        '--root_dir',
        required=True,
        help=
        'The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).'
    )
    parser.add_argument('--analyze_sample', default=1)

    # Dataset
    parser.add_argument(
        '--split_scheme',
        help=
        'Identifies how the train/val/test split is constructed. Choices are dataset-specific.'
    )
    parser.add_argument('--dataset_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})
    parser.add_argument(
        '--download',
        default=False,
        type=parse_bool,
        const=True,
        nargs='?',
        help=
        'If true, tries to downloads the dataset if it does not exist in root_dir.'
    )
    parser.add_argument(
        '--frac',
        type=float,
        default=1.0,
        help=
        'Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.'
    )

    # Loaders
    parser.add_argument('--loader_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})
    parser.add_argument('--train_loader', choices=['standard', 'group'])
    parser.add_argument('--uniform_over_groups',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--distinct_groups',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--n_groups_per_batch', type=int)
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--eval_loader',
                        choices=['standard'],
                        default='standard')

    # Model
    parser.add_argument('--model', choices=supported.models)
    parser.add_argument(
        '--model_kwargs',
        nargs='*',
        action=ParseKwargs,
        default={},
        help=
        'keyword arguments for model initialization passed as key1=value1 key2=value2'
    )

    # Transforms
    parser.add_argument('--train_transform', choices=supported.transforms)
    parser.add_argument('--eval_transform', choices=supported.transforms)
    parser.add_argument(
        '--target_resolution',
        nargs='+',
        type=int,
        help=
        'target resolution. for example --target_resolution 224 224 for standard resnet.'
    )
    parser.add_argument('--resize_scale', type=float)
    parser.add_argument('--max_token_length', type=int)

    # Objective
    parser.add_argument('--loss_function', choices=supported.losses)

    # Algorithm
    parser.add_argument('--groupby_fields', nargs='+')
    parser.add_argument('--group_dro_step_size', type=float)
    parser.add_argument('--coral_penalty_weight', type=float)
    parser.add_argument('--irm_lambda', type=float)
    parser.add_argument('--irm_penalty_anneal_iters', type=int)
    parser.add_argument('--algo_log_metric')
    parser.add_argument('--hsic_beta', type=float)
    parser.add_argument('--grad_penalty_lamb', type=float)
    parser.add_argument(
        '--params_regex',
        type=str,
        help='Regular expression specifying which gradients to penalize.')
    parser.add_argument('--label_cond',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument('--dann_lamb', type=float)
    parser.add_argument('--dann_dc_name', type=str)

    # Model selection
    parser.add_argument('--val_metric')
    parser.add_argument('--val_metric_decreasing',
                        type=parse_bool,
                        const=True,
                        nargs='?')

    # Optimization
    parser.add_argument('--n_epochs', type=int)
    parser.add_argument('--optimizer', choices=supported.optimizers)
    parser.add_argument('--lr', type=float)
    parser.add_argument('--weight_decay', type=float)
    parser.add_argument('--max_grad_norm', type=float)
    parser.add_argument('--optimizer_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})

    # Scheduler
    parser.add_argument('--scheduler', choices=supported.schedulers)
    parser.add_argument('--scheduler_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})
    parser.add_argument('--scheduler_metric_split',
                        choices=['train', 'val'],
                        default='val')
    parser.add_argument('--scheduler_metric_name')

    # Evaluation
    parser.add_argument('--evaluate_all_splits',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=True)
    parser.add_argument('--eval_splits', nargs='+', default=[])
    parser.add_argument('--eval_only',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument('--eval_epoch', default=None, type=int)
    parser.add_argument('--save_z',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)

    # Misc
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--log_dir', default='./logs')
    parser.add_argument('--log_every', default=50, type=int)
    parser.add_argument('--save_step', type=int)
    parser.add_argument('--save_best',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=True)
    parser.add_argument('--save_last',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=True)
    parser.add_argument('--no_group_logging',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--use_wandb',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument('--progress_bar',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument('--resume',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)

    config = parser.parse_args()
    config = populate_defaults(config)

    # set device
    config.device = torch.device("cuda:" + str(
        config.device)) if torch.cuda.is_available() else torch.device("cpu")

    ## Initialize logs
    if os.path.exists(config.log_dir) and config.resume:
        resume = True
        mode = 'a'
    elif os.path.exists(config.log_dir) and config.eval_only:
        resume = False
        mode = 'a'
    else:
        resume = False
        mode = 'w'

    if not os.path.exists(config.log_dir):
        os.makedirs(config.log_dir)
    logger = Logger(os.path.join(config.log_dir, 'log.txt'), mode)

    # Record config
    log_config(config, logger)

    # Set random seed
    set_seed(config.seed)

    # Data
    full_dataset = supported.datasets[config.dataset](
        root_dir=config.root_dir,
        download=config.download,
        split_scheme=config.split_scheme,
        **config.dataset_kwargs)

    # To implement data augmentation (i.e., have different transforms
    # at training time vs. test time), modify these two lines:
    train_transform = initialize_transform(
        transform_name=config.train_transform,
        config=config,
        dataset=full_dataset)
    eval_transform = initialize_transform(transform_name=config.eval_transform,
                                          config=config,
                                          dataset=full_dataset)

    train_grouper = CombinatorialGrouper(dataset=full_dataset,
                                         groupby_fields=config.groupby_fields)

    datasets = defaultdict(dict)
    for split in full_dataset.split_dict.keys():
        if split == 'train':
            transform = train_transform
            verbose = True
        elif split == 'val':
            transform = eval_transform
            verbose = True
        else:
            transform = eval_transform
            verbose = False
        # Get subset
        datasets[split]['dataset'] = full_dataset.get_subset(
            split, frac=config.frac, transform=transform)

        if split == 'train':
            datasets[split]['loader'] = get_train_loader(
                loader=config.train_loader,
                dataset=datasets[split]['dataset'],
                batch_size=config.batch_size,
                uniform_over_groups=config.uniform_over_groups,
                grouper=train_grouper,
                distinct_groups=config.distinct_groups,
                n_groups_per_batch=config.n_groups_per_batch,
                **config.loader_kwargs)
        else:
            datasets[split]['loader'] = get_eval_loader(
                loader=config.eval_loader,
                dataset=datasets[split]['dataset'],
                grouper=train_grouper,
                batch_size=config.batch_size,
                **config.loader_kwargs)

        # Set fields
        datasets[split]['split'] = split
        datasets[split]['name'] = full_dataset.split_names[split]
        datasets[split]['verbose'] = verbose

        # Loggers
        datasets[split]['eval_logger'] = BatchLogger(
            os.path.join(config.log_dir, f'{split}_eval.csv'),
            mode=mode,
            use_wandb=(config.use_wandb and verbose))
        datasets[split]['algo_logger'] = BatchLogger(
            os.path.join(config.log_dir, f'{split}_algo.csv'),
            mode=mode,
            use_wandb=(config.use_wandb and verbose))

        if config.use_wandb:
            initialize_wandb(config)

    # Logging dataset info
    if config.no_group_logging and full_dataset.is_classification and full_dataset.y_size == 1:
        log_grouper = CombinatorialGrouper(dataset=full_dataset,
                                           groupby_fields=['y'])
    elif config.no_group_logging:
        log_grouper = None
    else:
        log_grouper = train_grouper
    log_group_data(datasets, log_grouper, logger)

    ## Initialize algorithm
    algorithm = initialize_algorithm(config=config,
                                     datasets=datasets,
                                     train_grouper=train_grouper)

    if config.eval_epoch is None:
        eval_model_path = os.path.join(config.log_dir, 'best_model.pth')
    else:
        eval_model_path = os.path.join(config.log_dir,
                                       f'{config.eval_epoch}_model.pth')
    best_epoch, best_val_metric = load(algorithm, eval_model_path)
    if config.eval_epoch is None:
        epoch = best_epoch
    else:
        epoch = config.eval_epoch

    results, z_splits, y_splits, c_splits = evaluate(algorithm=algorithm,
                                                     datasets=datasets,
                                                     epoch=epoch,
                                                     general_logger=logger,
                                                     config=config)

    include_test = config.evaluate_all_splits or 'test' in config.eval_splits

    logistics = all_logistics(z_splits,
                              c_splits,
                              y_splits,
                              epoch=epoch,
                              sample=int(config.analyze_sample),
                              include_test=include_test)

    logistics['G0'] = results['id_val']['acc_avg']
    logistics['G1'] = logistics['val_on_val']
    logistics['G2'] = logistics['trainval_on_val']
    logistics['G3'] = results['val']['acc_avg']

    logistics['I0'] = logistics['c_train']
    logistics['I1'] = logistics['c_val']
    per_class = torch.tensor(list(logistics['c_perclass'].values()))
    logistics['I2'] = torch.mean(per_class).item()

    if include_test:
        logistics['G1_test'] = logistics['test_on_test']
        logistics['G2_test'] = logistics['traintest_on_test']
        logistics['G3_test'] = results['test']['acc_avg']

        logistics['I1_test'] = logistics['c_test']
        per_class = torch.tensor(list(logistics['c_perclass_test'].values()))
        logistics['I2_test'] = torch.mean(per_class).item()

    with (open(os.path.join(config.log_dir, f'tests_epoch_{epoch}.pkl'),
               "wb")) as f:
        pickle.dump(logistics, f)

    logger.close()
    for split in datasets:
        datasets[split]['eval_logger'].close()
        datasets[split]['algo_logger'].close()