Example #1
0
 def get_nn_model(self, config, num_classes, mode='erm'):
     """Build and return the classification model specified in `config`.

     Args:
         config(dict): Full experiment config; model options are read from
             config['classification_config'] with the mode-specific
             ('erm_config'/'gdro_config') overrides merged in.
         num_classes(int): Number of superclasses. Binary problems
             (num_classes <= 2) use a single output logit.
         mode(str): 'erm' selects the nested 'erm_config' overrides; any
             other mode selects 'gdro_config'.

     Returns:
         The instantiated model, wrapped in DataParallel and moved to the
         GPU when CUDA is in use.

     Raises:
         ValueError: If cl_config['model'] names an unknown architecture.
     """
     cl_config = config['classification_config']
     # Overlay the mode-specific sub-config on top of the shared one.
     if mode == 'erm':
         mode_config = cl_config['erm_config']
     else:
         mode_config = cl_config['gdro_config']
     cl_config = merge_dicts(cl_config, mode_config)
     if cl_config['bit_pretrained']:
         model_cls = BiTResNet
     else:
         models = {
             'lenet4': LeNet4,
             'resnet50': PyTorchResNet,
             'shallow_cnn': ShallowCNN,
             'mp64x64': Mp64x64Net,
         }
         try:
             model_cls = models[cl_config['model']]
         except KeyError as e:
             # Chain the original KeyError for easier debugging.
             raise ValueError('Unsupported model architecture') from e
     # Binary classification uses a single output logit.
     out_dim = num_classes if num_classes > 2 else 1
     model = model_cls(num_classes=out_dim)
     if self.use_cuda:
         model = torch.nn.DataParallel(model).cuda()
     self.logger.info('Model:')
     self.logger.info(str(model))
     return model
 def get_nn_model(self, config, num_classes, mode="erm"):
     """Build and return the classification model specified in `config`.

     Args:
         config(dict): Full experiment config; model options are read from
             config["classification_config"] with the mode-specific
             ("erm_config"/"gdro_config") overrides merged in.
         num_classes(int): Output dimension of the model head.
         mode(str): "erm" selects "erm_config"; anything else selects
             "gdro_config".

     Returns:
         The instantiated model, wrapped in DataParallel and moved to the
         GPU when CUDA is in use.

     Raises:
         ValueError: If cl_config["model"] names an unknown architecture.
     """
     cl_config = config["classification_config"]
     # Overlay the mode-specific sub-config on top of the shared one.
     if mode == "erm":
         mode_config = cl_config["erm_config"]
     else:
         mode_config = cl_config["gdro_config"]
     cl_config = merge_dicts(cl_config, mode_config)
     if cl_config["bit_pretrained"]:
         model_cls = BiTResNet
     else:
         models = {
             "lenet4": LeNet4,
             "resnet50": PyTorchResNet,
             "shallow_cnn": ShallowCNN,
         }
         try:
             model_cls = models[cl_config["model"]]
         except KeyError as e:
             # Chain the original KeyError for easier debugging.
             raise ValueError("Unsupported model architecture") from e
     model = model_cls(num_classes=num_classes)
     if self.use_cuda:
         model = torch.nn.DataParallel(model).cuda()
     self.logger.info("Model:")
     self.logger.info(str(model))
     return model
Example #3
0
    def get_dataloaders(
        self,
        config: Dict[str, Any],
        data_config: BaseConfig,
        use_cuda: bool,
        mode="erm",
        subclass_labels: Optional[str] = None,
    ):
        """Build the DataLoaders for every split in DATA_SPLITS.

        Wraps the FDM dataset triplet (train/context/test) in
        FdmDatasetWrapper, attaches subclass labels, and - when
        `data_config.predict_context` is set - fits a classifier on the
        train set to predict labels for the context set.

        Args:
            config: Full experiment config; classification options are read
                from config['classification_config'] with the mode-specific
                ('erm_config'/'gdro_config') overrides merged in.
            data_config: Dataset/loader configuration (batch sizes, workers,
                predict_context flag, ...).
            use_cuda: Whether to place work on the current CUDA device.
            mode: 'erm', 'george', or a '<label>_gdro' variant; determines
                the source of subclass labels.
            subclass_labels: Path to a saved '.pt' labels file ('george'
                mode) or None (derived from `mode` otherwise).

        Returns:
            Dict[str, DataLoader]: One DataLoader per split.
        """
        device = (torch.device(f"cuda:{torch.cuda.current_device()}")
                  if use_cuda else torch.device("cpu"))

        seed = config['seed']
        config = config['classification_config']
        if mode == 'george':
            # file path to subclass labels specified
            assert isinstance(subclass_labels,
                              str) and '.pt' in subclass_labels
        elif mode != 'erm':
            assert subclass_labels is None
            # Strip the '_gdro' suffix (e.g. 'random_gdro' -> 'random').
            # str.rstrip('_gdro') would strip a character *set*, not a
            # suffix, silently corrupting stems ending in those characters.
            suffix = '_gdro'
            subclass_labels = (mode[:-len(suffix)]
                               if mode.endswith(suffix) else mode)

        if subclass_labels is None:
            # subclass labels default to superclass labels if none are given
            subclass_labels = 'superclass'

        if '.pt' in subclass_labels:  # file path specified
            subclass_labels = torch.load(subclass_labels)
        else:  # keyword specified: every split maps to the same keyword
            kw = subclass_labels
            subclass_labels = defaultdict(lambda: kw)

        # Overlay the mode-specific sub-config on top of the shared one.
        if mode == 'erm':
            mode_config = config['erm_config']
        else:
            mode_config = config['gdro_config']
        config = merge_dicts(config, mode_config)

        dataset_triplet = load_dataset(cfg=data_config)
        dataset_class = FdmDatasetWrapper
        batch_size = config['batch_size']

        context_labels = None
        # Check whether to use the ground-truth labels for the context-set or whether the labels
        # need to be predicted
        if data_config.predict_context:
            # trained on the training set
            # 1. step: train classifier on triplet.train
            # 2. step: predict labels on triplet.context
            # 3. step: combine datasets
            train_data, test_data = dataset_triplet.train, dataset_triplet.context
            y_dim = dataset_triplet.y_dim
            input_shape = next(iter(train_data))[0].shape

            train_loader = DataLoader(
                train_data,
                batch_size=data_config.batch_size,
                pin_memory=True,
                shuffle=True,
                num_workers=data_config.num_workers,
            )
            test_loader = DataLoader(
                test_data,
                batch_size=data_config.test_batch_size,
                shuffle=False,
                pin_memory=True,
                num_workers=data_config.num_workers,
            )

            clf: Classifier = fit_classifier(
                data_config,
                input_shape,
                train_data=train_loader,
                train_on_recon=False,
                pred_s=False,
                test_data=test_loader,
                target_dim=y_dim,
                device=device,
            )

            context_labels, actual, _ = clf.predict_dataset(test_loader,
                                                            device=device)
            wandb.log(
                {"context.acc": (context_labels == actual).float().mean()})

        dataloaders = {}
        for split in DATA_SPLITS:
            key = 'train' if 'train' in split else split
            split_subclass_labels = subclass_labels[key]
            shared_dl_args = {
                'batch_size': batch_size,
                'num_workers': config['workers']
            }

            if split in ("train", "train_clean", "val"):
                # The ground-truth labels are being used for the context-set
                if context_labels is None:
                    dataset = ConcatDataset(
                        [dataset_triplet.train, dataset_triplet.context])
                # The predicted labels are being used for the context set
                else:
                    dataset = TrainContextWrapper(
                        dataset_triplet.train,
                        context=dataset_triplet.context,
                        new_context_labels=context_labels,
                    )
            else:
                dataset = dataset_triplet.test
            if split == 'train':
                dataset = dataset_class(
                    cfg=data_config,
                    split=split,
                    dataset=dataset,
                    device=device,
                )
                dataset.add_subclass_labels(split_subclass_labels, seed=seed)
                if config.get('uniform_group_sampling', False):
                    sampler, group_weights = self._get_uniform_group_sampler(
                        dataset)
                    self.logger.info(
                        f'Resampling training data with subclass weights:\n{group_weights}'
                    )
                    dataloaders[split] = DataLoader(dataset,
                                                    **shared_dl_args,
                                                    shuffle=False,
                                                    sampler=sampler)
                else:
                    dataloaders[split] = DataLoader(dataset,
                                                    **shared_dl_args,
                                                    shuffle=True)
            else:
                # Evaluation dataloaders (including for the training set) are "clean" - no data augmentation or shuffling
                dataset = dataset_class(
                    cfg=data_config,
                    dataset=dataset,
                    split=key,
                    device=device,
                )
                dataset.add_subclass_labels(split_subclass_labels, seed=seed)
                dataloaders[split] = DataLoader(dataset,
                                                **shared_dl_args,
                                                shuffle=False)

            self.logger.info(f'{split} split:')
            # log class counts for each label type
            for label_type, labels in dataset.Y_dict.items():
                self.logger.info(
                    f'{label_type.capitalize()} counts: {np.bincount(labels)}')

        return dataloaders
Example #4
0
    def classify(self, classification_config, model, dataloaders, mode):
        """Runs the initial representation learning stage of the GEORGE pipeline.

        Note:
            This function handles much of the pre- and post-processing needed to transition
            from stage to stage (i.e. modifying datasets with subclass labels and formatting
            the outputs in a manner that is compatible with GEORGEHarness.cluster).
            For more direct interaction with the classification procedure, see the
            GEORGEClassification class in classification.george_classification.

        Args:
            classification_config(dict): Contains args for the criterion, optimizer,
                scheduler, metrics. Optional nested `{mode}_config` dictionaries can
                add and/or replace arguments in classification config.
            model(nn.Module): A PyTorch model.
            dataloaders(Dict[str, DataLoader]): a dictionary mapping a data split
                to its given DataLoader (note that all data splits in DATA_SPLITS
                must be specified). More information can be found in
                classification.datasets.
            mode(str): The type of optimization to run. `erm` trains with vanilla
                Cross Entropy Loss. 'george' trains DRO using given cluster labels.
                `random_gdro` trains DRO with random cluster labels. `superclass_gdro`
                trains DRO using the given superclass labels. Implementation of DRO
                from Sagawa et al. (2020).

        Returns:
            save_dir(str): The subdirectory within `exp_dir` that contains model
                checkpoints, best model outputs, and best model metrics.
        """
        # overwrite args in classification_config with mode specific params
        if mode == 'erm':
            mode_config = classification_config['erm_config']
        else:
            mode_config = classification_config['gdro_config']
        classification_config = merge_dicts(classification_config, mode_config)

        # Evaluation-only runs reuse the experiment dir; training runs get a
        # fresh uniquely-named subdirectory with the config snapshot saved.
        if classification_config['eval_only'] or classification_config[
                'save_act_only']:
            save_dir = self.exp_dir
        else:
            save_dir = os.path.join(self.exp_dir, f'{mode}_{get_unique_str()}')
            self._save_config(save_dir, classification_config)
        robust = self._get_robust_status(mode)

        # (1) train
        trainer = GEORGEClassification(
            classification_config,
            save_dir=save_dir,
            use_cuda=self.use_cuda,
            log_format=self.log_format,
            has_estimated_subclasses=mode not in ['erm', 'true_subclass_gdro'],
        )
        if not (classification_config['eval_only']
                or classification_config['save_act_only']
                or classification_config['bit_pretrained']):
            trainer.train(model,
                          dataloaders['train'],
                          dataloaders['val'],
                          robust=robust)

        # (2) evaluate
        split_to_outputs = {}
        split_to_metrics = {}
        for split, dataloader in dataloaders.items():
            # The augmented 'train' split is skipped; 'train_clean' stands in
            # for it under the 'train' key.
            if split == 'train':
                continue
            key = 'train' if split == 'train_clean' else split
            if classification_config['eval_only'] and key != 'test':
                continue
            self.logger.basic_info(f'Evaluating on {key} split...')
            metrics, outputs = trainer.evaluate(
                model,
                dataloaders,
                split,
                robust=robust,
                save_activations=True,
                bit_pretrained=classification_config['bit_pretrained'],
                adv_metrics=classification_config['eval_only'],
                ban_reweight=classification_config['ban_reweight'],
            )
            split_to_metrics[key] = metrics
            split_to_outputs[key] = outputs

        # (3) save everything
        if not classification_config['eval_only']:
            self._save_json(os.path.join(save_dir, 'metrics.json'),
                            split_to_metrics)
            self._save_torch(os.path.join(save_dir, 'outputs.pt'),
                             split_to_outputs)
            wandb.log(split_to_metrics)
        return save_dir
    def get_dataloaders(
        self, config, mode="erm", transforms=None, subclass_labels=None
    ):
        """Build DataLoaders for every split in DATA_SPLITS.

        Args:
            config(dict): Full experiment config; loader options are read
                from config["classification_config"] with the mode-specific
                ("erm_config"/"gdro_config") overrides merged in.
            mode(str): "erm", "george", or a "<label>_gdro" variant; it
                determines the source of subclass labels.
            transforms: Unused; kept for interface compatibility.
            subclass_labels: Path to a saved ".pt" labels file ("george"
                mode) or None (derived from `mode` otherwise).

        Returns:
            Dict[str, DataLoader]: One DataLoader per split.
        """
        dataset_name = config["dataset"]
        seed = config["seed"]
        config = config["classification_config"]
        if mode == "george":
            assert ".pt" in subclass_labels  # file path to subclass labels specified
        elif mode != "erm":
            assert subclass_labels is None
            # Strip the "_gdro" suffix (e.g. "random_gdro" -> "random").
            # str.rstrip("_gdro") would strip a character *set*, not a
            # suffix, corrupting stems that end in those characters.
            suffix = "_gdro"
            subclass_labels = (
                mode[: -len(suffix)] if mode.endswith(suffix) else mode
            )

        if subclass_labels is None:
            # subclass labels default to superclass labels if none are given
            subclass_labels = "superclass"

        if ".pt" in subclass_labels:  # file path specified
            subclass_labels = torch.load(subclass_labels)
        else:  # keyword specified: every split maps to the same keyword
            kw = subclass_labels
            subclass_labels = defaultdict(lambda: kw)

        # Overlay the mode-specific sub-config on top of the shared one.
        if mode == "erm":
            mode_config = config["erm_config"]
        else:
            mode_config = config["gdro_config"]
        config = merge_dicts(config, mode_config)

        dataset_name = dataset_name.lower()
        d = {
            "celeba": CelebADataset,
            "isic": ISICDataset,
            "mnist": MNISTDataset,
            "waterbirds": WaterbirdsDataset,
            "pmx": PmxDataset,
        }
        dataset_class = d[dataset_name]
        batch_size = config["batch_size"]

        dataloaders = {}
        for split in DATA_SPLITS:
            key = "train" if "train" in split else split
            split_subclass_labels = subclass_labels[key]
            shared_dl_args = {
                "batch_size": batch_size,
                "num_workers": config["workers"],
            }
            if split == "train":
                # The training set gets augmentation and (optionally)
                # uniform-group resampling.
                dataset = dataset_class(
                    root="./data",
                    split=split,
                    download=True,
                    augment=True,
                    **config["dataset_config"],
                )
                dataset.add_subclass_labels(split_subclass_labels, seed=seed)
                if config.get("uniform_group_sampling", False):
                    sampler, group_weights = self._get_uniform_group_sampler(dataset)
                    self.logger.info(
                        f"Resampling training data with subclass weights:\n{group_weights}"
                    )
                    dataloaders[split] = DataLoader(
                        dataset, **shared_dl_args, shuffle=False, sampler=sampler
                    )
                else:
                    dataloaders[split] = DataLoader(
                        dataset, **shared_dl_args, shuffle=True
                    )
            else:
                # Evaluation dataloaders (including for the training set) are "clean" - no data augmentation or shuffling
                dataset = dataset_class(
                    root="./data", split=key, **config["dataset_config"]
                )
                dataset.add_subclass_labels(split_subclass_labels, seed=seed)
                dataloaders[split] = DataLoader(
                    dataset, **shared_dl_args, shuffle=False
                )

            self.logger.info(f"{split} split:")
            # log class counts for each label type
            for label_type, labels in dataset.Y_dict.items():
                self.logger.info(
                    f"{label_type.capitalize()} counts: {np.bincount(labels)}"
                )

        return dataloaders
Example #6
0
    def get_dataloaders(self, config, mode='erm', transforms=None, subclass_labels=None):
        """Build DataLoaders for every split in DATA_SPLITS.

        Args:
            config(dict): Full experiment config; loader options are read
                from config['classification_config'] with the mode-specific
                ('erm_config'/'gdro_config') overrides merged in.
            mode(str): 'erm', 'george', or a '<label>_gdro' variant; it
                determines the source of subclass labels.
            transforms: Unused; kept for interface compatibility.
            subclass_labels: Path to a saved '.pt' labels file ('george'
                mode) or None (derived from `mode` otherwise).

        Returns:
            Dict[str, DataLoader]: One DataLoader per split.
        """
        dataset_name = config['dataset']
        seed = config['seed']
        config = config['classification_config']
        if mode == 'george':
            assert ('.pt' in subclass_labels)  # file path to subclass labels specified
        elif mode != 'erm':
            assert (subclass_labels is None)
            # Strip the '_gdro' suffix (e.g. 'random_gdro' -> 'random').
            # str.rstrip('_gdro') would strip a character *set*, not a
            # suffix, corrupting stems that end in those characters.
            suffix = '_gdro'
            subclass_labels = mode[:-len(suffix)] if mode.endswith(suffix) else mode

        if subclass_labels is None:
            # subclass labels default to superclass labels if none are given
            subclass_labels = 'superclass'

        if '.pt' in subclass_labels:  # file path specified
            subclass_labels = torch.load(subclass_labels)
        else:  # keyword specified: every split maps to the same keyword
            kw = subclass_labels
            subclass_labels = defaultdict(lambda: kw)

        # Overlay the mode-specific sub-config on top of the shared one.
        if mode == 'erm':
            mode_config = config['erm_config']
        else:
            mode_config = config['gdro_config']
        config = merge_dicts(config, mode_config)

        dataset_name = dataset_name.lower()
        d = {
            'celeba': CelebADataset,
            'isic': ISICDataset,
            'mnist': MNISTDataset,
            'waterbirds': WaterbirdsDataset
        }
        dataset_class = d[dataset_name]
        batch_size = config['batch_size']

        dataloaders = {}
        for split in DATA_SPLITS:
            key = 'train' if 'train' in split else split
            split_subclass_labels = subclass_labels[key]
            shared_dl_args = {'batch_size': batch_size, 'num_workers': config['workers']}
            if split == 'train':
                # The training set gets augmentation and (optionally)
                # uniform-group resampling.
                dataset = dataset_class(root='./data', split=split, download=True, augment=True,
                                        **config['dataset_config'])
                dataset.add_subclass_labels(split_subclass_labels, seed=seed)
                if config.get('uniform_group_sampling', False):
                    sampler, group_weights = self._get_uniform_group_sampler(dataset)
                    self.logger.info(
                        f'Resampling training data with subclass weights:\n{group_weights}')
                    dataloaders[split] = DataLoader(dataset, **shared_dl_args, shuffle=False,
                                                    sampler=sampler)
                else:
                    dataloaders[split] = DataLoader(dataset, **shared_dl_args, shuffle=True)
            else:
                # Evaluation dataloaders (including for the training set) are "clean" - no data augmentation or shuffling
                dataset = dataset_class(root='./data', split=key, **config['dataset_config'])
                dataset.add_subclass_labels(split_subclass_labels, seed=seed)
                dataloaders[split] = DataLoader(dataset, **shared_dl_args, shuffle=False)

            self.logger.info(f'{split} split:')
            # log class counts for each label type
            for label_type, labels in dataset.Y_dict.items():
                self.logger.info(f'{label_type.capitalize()} counts: {np.bincount(labels)}')

        return dataloaders