def __init__(
    self,
    data_source: Sized,
    *,
    batch_size: int,
    training_mode: TrainingMode | str = TrainingMode.step,
    shuffle: bool = True,
    drop_last: bool = False,
    generator: torch.Generator | None = None,
) -> None:
    """Configure a sampler that batches the indices of ``data_source``.

    :param data_source: Any sized collection; only its ``len`` is used here.
    :param batch_size: Number of indices per batch.
    :param training_mode: ``TrainingMode`` member or its name. In ``epoch`` mode the
        parent class receives a finite epoch length computed from the data size; in any
        other mode it receives ``None``.
    :param shuffle: Stored on the instance (presumably consumed when iterating).
    :param drop_last: Stored on the instance and used for the epoch-length computation.
    :param generator: Optional RNG stored on the instance (presumably used when iterating).
    """
    self.data_source = data_source
    self.batch_size = batch_size
    self.shuffle = shuffle
    self.drop_last = drop_last
    self.generator = generator
    # Allow the mode to be given by name, e.g. from a config file.
    self.training_mode = (
        str_to_enum(str_=training_mode, enum=TrainingMode)
        if isinstance(training_mode, str)
        else training_mode
    )

    epoch_length = None
    if self.training_mode is TrainingMode.epoch:
        epoch_length = num_batches_per_epoch(
            num_samples=len(self.data_source),
            batch_size=self.batch_size,
            drop_last=self.drop_last,
        )
    super().__init__(epoch_length=epoch_length)
def __init__(
    self,
    lambda_sampler: LS,
    *,
    mode: MixUpMode | str = MixUpMode.linear,
    p: float = 1.0,
    num_classes: int | None = None,
    featurewise: bool = False,
    inplace: bool = False,
) -> None:
    """
    :param lambda_sampler: The distribution from which to sample lambda (the mixup
        interpolation parameter).
    :param mode: Which mode to use to mix up samples: geometric or linear.

        .. note::
            The (weighted) geometric mean, enabled by ``mode=geometric``, is only valid
            for positive inputs.

    :param p: The probability with which the transform will be applied to a given sample.
    :param num_classes: The total number of classes in the dataset; needs to be specified
        if wanting to mix up targets that are label-encoded. Passing label-encoded targets
        without specifying ``num_classes`` will result in a RuntimeError.
    :param featurewise: Whether to sample lambda feature-wise instead of sample-wise.

        .. note::
            If the ``lambda_sampler`` is a ``torch.distributions.Bernoulli`` instance,
            then featurewise sampling will always be enabled.

    :param inplace: Whether the transform should be performed in-place.

    :raises ValueError: if ``p`` is not in the range [0, 1] or ``num_classes < 1``.
    """
    super().__init__()
    self.lambda_sampler = lambda_sampler
    if not 0 <= p <= 1:
        raise ValueError("'p' must be in the range [0, 1].")
    self.p = p
    if isinstance(mode, str):
        mode = str_to_enum(str_=mode, enum=MixUpMode)
    self.mode = mode
    if (num_classes is not None) and num_classes < 1:
        # The check admits num_classes == 1, so the message says "at least 1"; it also
        # names the offending parameter instead of only echoing its value.
        raise ValueError(f"'num_classes' must be at least 1; got {num_classes}.")
    self.num_classes = num_classes
    # Feature-wise sampling is forced on whenever the sampler is a Bernoulli distribution.
    self.featurewise = featurewise or isinstance(lambda_sampler, td.Bernoulli)
    self.inplace = inplace
def __init__(
    self,
    root: Union[str, Path],
    *,
    download: bool = True,
    transform: Optional[ImageTform] = None,
    split: Optional[Union[Camelyon17Split, str]] = None,
    split_scheme: Union[Camelyon17SplitScheme, str] = Camelyon17SplitScheme.official,
    superclass: Union[Camelyon17Attr, str] = Camelyon17Attr.tumor,
    subclass: Union[Camelyon17Attr, str] = Camelyon17Attr.center,
) -> None:
    """Load (and optionally download) the Camelyon17 patch dataset.

    :param root: Directory containing (or to receive) the ``camelyon17_v1.0`` folder.
    :param download: Whether to fetch the data with ``download_from_url``. If ``False``,
        the data must already be present under ``root``.
    :param transform: Image transform forwarded to the parent dataset.
    :param split: If given, keep only metadata rows whose ``split`` column equals this
        split's value; otherwise all rows are used.
    :param split_scheme: ``official`` or ``mixed_to_test``; the latter moves slide 23
        into the training split.
    :param superclass: Metadata column used for the ``y`` labels.
    :param subclass: Metadata column used for the ``s`` labels.

    :raises FileNotFoundError: if ``download`` is ``False`` and ``metadata.csv`` is not
        present in the dataset directory.
    """
    self.superclass = str_to_enum(str_=superclass, enum=Camelyon17Attr)
    self.subclass = str_to_enum(str_=subclass, enum=Camelyon17Attr)
    self.split = (
        str_to_enum(str_=split, enum=Camelyon17Split) if isinstance(split, str) else split
    )
    self.split_scheme = (
        str_to_enum(str_=split_scheme, enum=Camelyon17SplitScheme)
        if isinstance(split_scheme, str)
        else split_scheme
    )

    self.root = Path(root)
    self._base_dir = self.root / "camelyon17_v1.0"
    self.download = download
    metadata_path = self._base_dir / "metadata.csv"
    if self.download:
        download_from_url(
            file_info=self._FILE_INFO,
            root=self.root,
            logger=self.logger,
            remove_finished=True,
        )
    elif not metadata_path.exists():
        # Only fail when the data is actually missing; previously ``download=False``
        # raised unconditionally, so already-downloaded data could never be loaded.
        raise FileNotFoundError(
            f"Data not found at location {self._base_dir.resolve()}. Have you downloaded it?"
        )

    # Read in metadata
    # Note: metadata is one-indexed.
    self.metadata = pd.read_csv(metadata_path, index_col=0, dtype={"patient": "str"})

    # Assign the held-out centers to the validation and test splits; rows from the other
    # centers keep the split recorded in the metadata file.
    val_center_mask = self.metadata["center"] == self._VAL_CENTER
    test_center_mask = self.metadata["center"] == self._TEST_CENTER
    self.metadata.loc[val_center_mask, "split"] = Camelyon17Split.val.value
    self.metadata.loc[test_center_mask, "split"] = Camelyon17Split.test.value

    if self.split_scheme is Camelyon17SplitScheme.mixed_to_test:
        # For the mixed-to-test setting,
        # we move slide 23 (corresponding to patient 042, node 3 in the original dataset)
        # from the test set to the training set.
        # This must come *after* the center-based assignment above; done before it, the
        # test-center assignment overwrote the move and the scheme had no effect.
        slide_mask = self.metadata["slide"] == 23
        self.metadata.loc[slide_mask, "split"] = Camelyon17Split.train.value

    # Use an official split of the data, if 'split' is specified, else just use all
    # of the data
    if self.split is not None:
        split_indices = self.metadata["split"] == self.split.value
        self.metadata = cast(pd.DataFrame, self.metadata[split_indices])

    # Construct filepaths from metadata
    def build_fp(row: pd.Series) -> str:
        return "patches/patient_{0}_node_{1}/patch_patient_{0}_node_{1}_x_{2}_y_{3}.png".format(
            *row
        )

    x = (
        self.metadata[["patient", "node", "x_coord", "y_coord"]]
        .apply(build_fp, axis=1)
        .to_numpy()
    )
    # Extract superclass labels
    y = torch.as_tensor(self.metadata[str(self.superclass)].to_numpy(), dtype=torch.long)
    # Extract subclass labels
    s = torch.as_tensor(self.metadata[str(self.subclass)].to_numpy(), dtype=torch.long)

    super().__init__(x=x, y=y, s=s, transform=transform, image_dir=self._base_dir)
class DINO(MomentumTeacherModel):
    """DINO model built on ``MomentumTeacherModel``.

    The constructor records the hyperparameters and builds the ``DINOLoss`` stored as
    ``self._loss_fn``.
    """

    _ft_clf: DINOLinearClassifier

    @parsable
    def __init__(
        self,
        *,
        backbone: Union[
            nn.Module, vit.VitArch, vit.VisionTransformer, ResNetArch, str
        ] = vit.VitArch.small,
        out_dim: int = 65_536,
        lr: float = 5.0e-4,
        warmup_iters: int = 10,
        weight_decay: float = 4.0e-2,
        min_lr: float = 1.0e-6,
        weight_decay_final: float = 0.4,
        freeze_last_layer: int = 1,
        patch_size: int = 16,
        drop_path_rate: float = 0.1,
        norm_last_layer: bool = True,
        use_bn_in_head: bool = False,
        momentum_teacher: float = 0.996,
        momentum_center: float = 0.9,
        teacher_temp: float = 0.04,
        warmup_teacher_temp: float = 0.04,
        student_temp: float = 0.1,
        warmup_teacher_temp_iters: int = 30,
        num_eval_blocks: int = 1,
        lr_eval: float = 1.0e-4,
        global_crop_size: Optional[Union[Tuple[int, int], int]] = None,
        local_crop_size: Union[Tuple[float, float], float] = 0.43,
        global_crops_scale: Tuple[float, float] = (0.4, 1.0),
        local_crops_scale: Tuple[float, float] = (0.05, 0.4),
        local_crops_number: int = 8,
        batch_transforms: Optional[BatchTransform] = None,
        eval_epochs: int = 100,
        eval_batch_size: Optional[int] = None,
    ) -> None:
        """
        :param backbone: Backbone module or architecture identifier; a string is resolved
            to a ``vit.VitArch`` member.
        :param weight_decay_final: Stored as ``self.min_weight_decay``.
        :param lr_eval: Stored as ``self.eval_lr``.
        :param student_temp: Forwarded to ``DINOLoss`` (and stored).
        :param teacher_temp: Forwarded to ``DINOLoss`` (and stored).
        :param warmup_teacher_temp: Forwarded to ``DINOLoss`` (and stored).
        :param warmup_teacher_temp_iters: Forwarded to ``DINOLoss`` (and stored).
        :param patch_size: ViT-specific; stored on the instance.
        :param drop_path_rate: ViT-specific; stored on the instance.

        ``lr``, ``weight_decay``, ``eval_epochs``, ``eval_batch_size``,
        ``batch_transforms`` and the crop settings are forwarded to the parent class.
        The remaining arguments are stored under their own names.
        """
        super().__init__(
            lr=lr,
            weight_decay=weight_decay,
            eval_epochs=eval_epochs,
            eval_batch_size=eval_batch_size,
            batch_transforms=batch_transforms,
            global_crop_size=global_crop_size,
            local_crop_size=local_crop_size,
            global_crops_scale=global_crops_scale,
            local_crops_scale=local_crops_scale,
            local_crops_number=local_crops_number,
        )
        # Strings are resolved against the ViT architectures only.
        self.backbone = (
            str_to_enum(str_=backbone, enum=vit.VitArch)
            if isinstance(backbone, str)
            else backbone
        )
        # Evaluation / optimisation schedule
        self.num_eval_blocks = num_eval_blocks
        self.warmup_iters = warmup_iters
        self.min_weight_decay = weight_decay_final
        self.min_lr = min_lr
        self.freeze_last_layer = freeze_last_layer
        # Projection head
        self.out_dim = out_dim
        self.norm_last_layer = norm_last_layer
        self.use_bn_in_head = use_bn_in_head
        # Teacher / centering momenta and temperatures
        self.momentum_teacher = momentum_teacher
        self.momentum_center = momentum_center
        self.teacher_temp = teacher_temp
        self.warmup_teacher_temp = warmup_teacher_temp
        self.warmup_teacher_temp_iters = warmup_teacher_temp_iters
        self.student_temp = student_temp
        self.eval_lr = lr_eval
        # ViT-specific arguments
        self.patch_size = patch_size
        self.drop_path_rate = drop_path_rate

        self._loss_fn = DINOLoss(
            student_temp=student_temp,
            teacher_temp=teacher_temp,
            warmup_teacher_temp=warmup_teacher_temp,
            warmup_teacher_temp_iters=warmup_teacher_temp_iters,
        )
def __init__(
    self,
    root: Union[str, Path],
    *,
    transform: Optional[AudioTform] = None,
    download: bool = True,
    target_attrs: Union[
        Union[SoundscapeAttr, str], List[Union[SoundscapeAttr, str]]
    ] = SoundscapeAttr.habitat,
    segment_len: Optional[float] = 15,
    preprocess: bool = False,
) -> None:
    """Build the soundscape audio dataset.

    :param root: Root directory (``~`` is expanded); the data lives in a sub-directory
        named after this class.
    :param transform: Audio transform forwarded to the parent dataset.
    :param download: If ``True``, ``_download_files`` is called before the file check.
    :param target_attrs: One attribute or a list of attributes (``SoundscapeAttr`` members
        or their names). They are resolved to metadata column names and label-encoded
        into ``y``.
    :param segment_len: Stored on the instance and used to name the directory holding
        preprocessed segments.
    :param preprocess: If ``True`` (and ``self.num_frames_in_segment`` is not ``None``),
        file paths are taken from the preprocessed ``filepaths.csv``. That file is created
        via ``_preprocess_files`` when missing.
    """
    self.root = Path(root).expanduser()
    self.download = download
    self.base_dir = self.root / self.__class__.__name__
    self.labels_dir = self.base_dir / self.INDICES_DIR
    self.segment_len = segment_len
    self.preprocess = preprocess

    self._metadata_path = self.base_dir / self.METADATA_FILENAME
    self.ec_labels_path = self.labels_dir / self._EC_LABELS_FILENAME
    self.uk_labels_path = self.labels_dir / self._UK_LABELS_FILENAME

    # Normalise to a list, then map each entry to its column name.
    attrs = target_attrs if isinstance(target_attrs, list) else [target_attrs]
    self.target_attrs = [str(str_to_enum(str_=attr, enum=SoundscapeAttr)) for attr in attrs]

    if self.download:
        self._download_files()
    self._check_files()

    # Extract labels from indices files.
    if not self._metadata_path.exists():
        self._extract_metadata()
    self.metadata = pd.read_csv(self._metadata_path)

    if self.preprocess and self.num_frames_in_segment is not None:
        # The processed-audio directory is keyed on the segment length.
        processed_audio_dir = self.base_dir / f"segment_len={self.segment_len}"
        filepaths_csv = processed_audio_dir / "filepaths.csv"
        if not filepaths_csv.exists():
            self._preprocess_files()
        filepaths = pd.read_csv(filepaths_csv)
        # Swap the raw-file paths for the per-segment ones, joined on the file name.
        self.metadata.drop("filePath", inplace=True, axis=1)
        self.metadata = filepaths.merge(self.metadata, how='left', on='fileName')

    x = self.metadata["filePath"].to_numpy()
    encoded_targets = self._label_encode(self.metadata[self.target_attrs], inplace=True)
    y = torch.as_tensor(encoded_targets.to_numpy())
    super().__init__(x=x, y=y, transform=transform, audio_dir=self.base_dir)
class MoCoV2(MomentumTeacherModel):
    _ft_clf: FineTuner
    use_ddp: bool

    @parsable
    def __init__(
        self,
        *,
        backbone: Union[nn.Module, ResNetArch, str] = ResNetArch.resnet18,
        head: Optional[nn.Module] = None,
        out_dim: int = 128,
        num_negatives: int = 65_536,
        momentum_teacher: float = 0.999,
        temp: float = 0.07,
        lr: float = 0.03,
        momentum_sgd: float = 0.9,
        weight_decay: float = 1.0e-4,
        use_mlp: bool = False,
        instance_transforms: Optional[MultiCropTransform] = None,
        batch_transforms: Optional[BatchTransform] = None,
        global_crop_size: Optional[Union[Tuple[int, int], int]] = None,
        local_crop_size: Union[Tuple[float, float], float] = 0.43,
        global_crops_scale: Tuple[float, float] = (0.4, 1.0),
        local_crops_scale: Tuple[float, float] = (0.05, 0.4),
        local_crops_number: int = 0,
        eval_epochs: int = 100,
        eval_batch_size: Optional[int] = None,
    ) -> None:
        """
        PyTorch Lightning implementation of `MoCo <https://arxiv.org/abs/2003.04297>`_

        Paper authors: Xinlei Chen, Haoqi Fan, Ross Girshick, Kaiming He.

        :param backbone: Backbone of the encoder. Can be any nn.Module with a unary forward
            method (accepting a single input Tensor and returning a single output Tensor),
            in which case embed_dim will be inferred by passing a dummy input through the
            backbone, or a 'ResNetArch' instance whose value is a resnet builder function
            which will be called with num_classes=out_dim. A string is converted to the
            corresponding 'ResNetArch' member.
        :param head: Optional head module; stored as ``self.head``.
        :param out_dim: Output size of the encoder; only applicable when backbone is a
            'ResNetArch' enum. Also used as the dimension of the memory bank.
        :param num_negatives: queue size; number of negative keys.
        :param momentum_teacher: Momentum (what fraction of the previous iterates parameters
            to interpolate with) for the teacher update.
        :param temp: Softmax temperature.
        :param lr: Learning rate for the student model.
        :param momentum_sgd: Optimizer momentum.
        :param weight_decay: Optimizer weight decay.
        :param use_mlp: Whether to add an MLP head to the decoders (instead of a single
            linear layer).
        :param instance_transforms: Instance-wise image-transforms to use to generate the
            positive pairs for instance-discrimination.
        :param batch_transforms: Batch-wise image-transforms to use to generate the positive
            pairs for instance-discrimination.
        :param global_crop_size: Size of the global crops; forwarded to the parent class.
        :param local_crop_size: Size of the local crops; forwarded to the parent class.
        :param global_crops_scale: Scale range of the cropped image before resizing,
            relative to the origin image. Used for large global view cropping.
        :param local_crops_scale: Scale range of the cropped image before resizing,
            relative to the origin image. Used for small, local cropping.
        :param local_crops_number: Number of small local views to generate. With a
            multi-crop policy, the same image is randomly cropped to get a pair of high
            resolution (global) images along with this many lower resolution (generally
            covering less than 50% of the image) local images.
        :param eval_epochs: Number of epochs to train the post-hoc classifier for during
            validation/testing.
        :param eval_batch_size: Batch size to use when training the post-hoc classifier
            during validation/testing.
        """
        super().__init__(
            lr=lr,
            weight_decay=weight_decay,
            eval_epochs=eval_epochs,
            eval_batch_size=eval_batch_size,
            instance_transforms=instance_transforms,
            batch_transforms=batch_transforms,
            global_crop_size=global_crop_size,
            local_crop_size=local_crop_size,
            global_crops_scale=global_crops_scale,
            # Previously ``global_crops_scale`` was passed here by mistake, silently
            # ignoring the caller's ``local_crops_scale``.
            local_crops_scale=local_crops_scale,
            local_crops_number=local_crops_number,
        )
        if isinstance(backbone, str):
            backbone = str_to_enum(str_=backbone, enum=ResNetArch)
        self.backbone = backbone
        self.head = head
        self.out_dim = out_dim
        self.temp = temp
        self.lr = lr
        self.weight_decay = weight_decay
        self.momentum_teacher = momentum_teacher
        self.momentum_sgd = momentum_sgd
        self.num_negatives = num_negatives
        self.use_mlp = use_mlp

        # View-generation settings
        self._global_crop_size = global_crop_size
        self._local_crop_size = local_crop_size
        self.local_crops_number = local_crops_number
        self.local_crops_scale = local_crops_scale
        self.global_crops_scale = global_crops_scale

        # create the queue
        self.mb = MemoryBank(dim=out_dim, capacity=num_negatives)
        self._loss_fn = CrossEntropyLoss(reduction=ReductionType.mean)
def _sample_partner_idx(pool: Tensor, **group: int) -> Tensor:
    """Draw one uniformly random row index into ``pool``, failing clearly if it is empty."""
    if pool.size(0) == 0:
        desc = ", ".join(f"{key}={value}" for key, value in group.items())
        raise RuntimeError(
            f"Cannot draw a mixup partner: the batch contains no samples with {desc}."
        )
    return torch.randint(pool.size(0), (1,))


def mixup_data(
    *,
    batch: TernarySample,
    device: torch.device,
    mix_lambda: Optional[float],
    alpha: float,
    fairness: Union[FairnessType, str],
) -> Mixed:
    '''Returns mixed inputs, pairs of targets, and lambda.

    Each sample ``a`` is paired with a partner ``b`` according to ``fairness``:

    * ``No``: ``b`` is ``a`` itself.
    * ``EqOp``: ``b`` is ``a`` itself when ``a``'s target is 0; otherwise ``b`` is drawn
      uniformly from the samples with the opposite ``s`` and target 1.
    * ``DP``: ``b`` is drawn uniformly from the samples with the opposite ``s``.
    * ``EO``: ``b`` is drawn uniformly from the samples with the opposite ``s`` and the
      same target.

    The mixed input is ``lam * x_a + (1 - lam) * x_b``.

    :param batch: Batch with ``x``, ``s`` and ``y`` tensors. ``s`` and ``y`` are compared
        against 0/1, so they are presumably binary — TODO confirm.
    :param device: Device the returned tensors are moved to.
    :param mix_lambda: Fixed interpolation weight; if ``None``, lambda is sampled from
        ``Beta(alpha, alpha)``.
    :param alpha: Beta concentration used when ``mix_lambda`` is ``None``.
    :param fairness: ``FairnessType`` member (or its name) selecting the pairing rule.
    :returns: ``Mixed`` holding the mixed inputs (with ``requires_grad`` set), both sides
        of every pair, lambda, and pairing statistics over ``s``.
    :raises RuntimeError: if a sample needs a partner from a group absent from the batch.
    '''
    assert isinstance(batch.x, Tensor)
    fairness = str_to_enum(str_=fairness, enum=FairnessType)
    if mix_lambda is None:
        # Potentially change alpha from a=1.0 to account for class imbalance?
        lam = Beta(torch.tensor([alpha]).to(device), torch.tensor([alpha]).to(device)).sample()
    else:
        lam = torch.tensor([mix_lambda]).to(device)

    # Compute each group mask once and only build the partner pools that are used.
    s_flat = batch.s.view(-1)
    y_flat = batch.y.view(-1)
    s_masks = {0: s_flat == 0, 1: s_flat == 1}
    y_masks = {0: y_flat == 0, 1: y_flat == 1}
    x_by_s = {s_val: batch.x[mask].to(device) for s_val, mask in s_masks.items()}
    y_by_s = {s_val: batch.y[mask].to(device) for s_val, mask in s_masks.items()}
    x_by_sy = {
        (s_val, y_val): batch.x[s_mask & y_mask].to(device)
        for s_val, s_mask in s_masks.items()
        for y_val, y_mask in y_masks.items()
    }

    xal = []
    xbl = []
    sal = []
    sbl = []
    yal = []
    ybl = []
    for x_a, s_a, y_a in zip(batch.x, batch.s, batch.y):
        xal.append(x_a)
        sal.append(s_a.unsqueeze(-1).float())
        yal.append(y_a.unsqueeze(-1).float())
        if (fairness is FairnessType.EqOp and y_a == 0) or fairness is FairnessType.No:
            # Paired with itself.
            xbl.append(x_a)
            sbl.append(s_a.unsqueeze(-1))
            ybl.append(y_a.unsqueeze(-1))
        elif fairness is FairnessType.EqOp:
            s_other = 1 - int(s_a)
            pool = x_by_sy[(s_other, 1)]
            idx = _sample_partner_idx(pool, s=s_other, y=1)
            xbl.append(pool[idx, :].squeeze(0))
            sbl.append((torch.ones_like(s_a) * (1 - s_a)).unsqueeze(-1).float())
            ybl.append(torch.ones_like(y_a).unsqueeze(-1))
        elif fairness is FairnessType.DP:
            s_other = 1 - int(s_a)
            pool = x_by_s[s_other]
            idx = _sample_partner_idx(pool, s=s_other)
            xbl.append(pool[idx, :].squeeze(0))
            sbl.append((torch.ones_like(s_a) * (1 - s_a)).unsqueeze(-1).float())
            ybl.append(y_by_s[s_other][idx].float())
        elif fairness is FairnessType.EO:
            s_other = 1 - int(s_a)
            y_same = int(y_a)
            pool = x_by_sy[(s_other, y_same)]
            idx = _sample_partner_idx(pool, s=s_other, y=y_same)
            xbl.append(pool[idx, :].squeeze(0))
            sbl.append((torch.ones_like(s_a) * (1 - s_a)).unsqueeze(-1).float())
            ybl.append((torch.ones_like(y_a) * y_a).unsqueeze(-1))

    x_a = torch.stack(xal, dim=0).to(device)
    x_b = torch.stack(xbl, dim=0).to(device)
    s_a = torch.stack(sal, dim=0).to(device)
    s_b = torch.stack(sbl, dim=0).to(device)
    y_a = torch.stack(yal, dim=0).to(device)
    y_b = torch.stack(ybl, dim=0).to(device)

    mix_stats = {
        "batch_stats/S0=sS0": sum(a == b for a, b in zip(s_a, s_b)) / batch.s.size(0),
        "batch_stats/S0!=sS0": sum(a != b for a, b in zip(s_a, s_b)) / batch.s.size(0),
        "batch_stats/all_s0": sum(a + b == 0 for a, b in zip(s_a, s_b)) / batch.s.size(0),
        "batch_stats/all_s1": sum(a + b == 2 for a, b in zip(s_a, s_b)) / batch.s.size(0),
    }

    mixed_x = lam * x_a + (1 - lam) * x_b
    return Mixed(
        x=mixed_x.requires_grad_(True),
        xa=x_a,
        xb=x_b,
        sa=s_a,
        sb=s_b,
        ya=y_a,
        yb=y_b,
        lam=lam,
        stats=mix_stats,
    )
def __init__(
    self,
    root: Union[str, Path],
    *,
    download: bool = True,
    transform: Optional[ImageTform] = None,
    label_map: Optional[Dict[str, int]] = None,
    colors: Optional[List[int]] = None,
    num_colors: int = 10,
    scale: float = 0.2,
    correlation: Optional[float] = None,
    binarize: bool = False,
    greyscale: bool = False,
    background: bool = False,
    black: bool = True,
    split: Optional[Union[ColoredMNISTSplit, str]] = None,
    seed: Optional[int] = 42,
) -> None:
    """Build a coloured-MNIST dataset whose colour label ``s`` is correlated with the digit.

    :param root: Directory for the underlying MNIST data.
    :param download: Forwarded to ``MNIST``.
    :param transform: Image transform forwarded to the parent dataset.
    :param label_map: If given, data is filtered/remapped via ``_filter_data_by_labels``.
    :param colors: Colour indices forwarded to ``MNISTColorizer``.
    :param num_colors: The colour label is initially ``digit % num_colors``.
    :param scale: Forwarded to ``MNISTColorizer``.
    :param correlation: Strength in [0, 1] of the digit/colour correlation. Defaults to
        1.0 for the train split and 0.5 otherwise.
    :param binarize: Forwarded to ``MNISTColorizer``.
    :param greyscale: Forwarded to ``MNISTColorizer``.
    :param background: Forwarded to ``MNISTColorizer``.
    :param black: Forwarded to ``MNISTColorizer``.
    :param split: Train/test split (member or name); ``None`` concatenates all splits.
    :param seed: Seed for the label flipping and the colorizer. ``None`` uses the global
        torch RNG.

    :raises ValueError: if ``correlation`` is outside [0, 1].
    """
    self.split = (
        str_to_enum(str_=split, enum=ColoredMNISTSplit) if isinstance(split, str) else split
    )
    self.label_map = label_map
    self.scale = scale
    self.num_colors = num_colors
    self.colors = colors
    self.background = background
    self.binarize = binarize
    self.black = black
    self.greyscale = greyscale
    self.seed = seed

    # Note: a correlation coefficient of '1' corresponds to perfect correlation between
    # digit and colour, while a correlation coefficient of '0' corresponds to colours being
    # assigned uniformly at random.
    if correlation is None:
        # Compare the *resolved* split: the raw argument may be a name such as "train".
        correlation = 1.0 if self.split is ColoredMNISTSplit.train else 0.5
    if not 0 <= correlation <= 1:
        raise ValueError(
            "Strength of correlation between colour and targets must be between 0 and 1."
        )
    self.correlation = correlation

    if self.split is None:
        x_ls, y_ls = [], []
        for _split in ColoredMNISTSplit:
            base_dataset = MNIST(
                root=str(root), download=download, train=_split is ColoredMNISTSplit.train
            )
            x_ls.append(base_dataset.data)
            y_ls.append(base_dataset.targets)
        x = torch.cat(x_ls, dim=0)
        y = torch.cat(y_ls, dim=0)
    else:
        base_dataset = MNIST(
            root=str(root), download=download, train=self.split is ColoredMNISTSplit.train
        )
        x = base_dataset.data
        y = base_dataset.targets

    if self.label_map is not None:
        x, y = _filter_data_by_labels(data=x, targets=y, label_map=self.label_map)
    s = y % self.num_colors
    s_unique, s_unique_inv = s.unique(return_inverse=True)

    generator = (
        torch.default_generator
        if self.seed is None
        else torch.Generator().manual_seed(self.seed)
    )
    num_s = len(s_unique)
    inv_card_s = 1 / num_s
    # With a single distinct colour there is nothing to flip to
    # (``randint(low=1, high=1)`` would fail).
    if self.correlation < 1 and num_s > 1:
        # Proportion of samples that keep their digit-determined colour.
        keep_prop = self.correlation * (1.0 - inv_card_s) + inv_card_s
        # Change the values of randomly-selected labels to values other than their original ones
        num_to_flip = round((1 - keep_prop) * len(s))
        to_flip = torch.randperm(len(s), generator=generator)[:num_to_flip]
        # Use the same generator as ``randperm`` so that ``seed`` makes the result reproducible.
        s_unique_inv[to_flip] += torch.randint(
            low=1, high=num_s, size=(num_to_flip,), generator=generator
        )
        # Indices into ``s_unique`` live inside the ring Z/(num_s * Z).
        s_unique_inv[to_flip] %= num_s
        s = s_unique[s_unique_inv]

    # Convert the greyscale images of shape ( H, W ) into 'colour' images of shape ( C, H, W )
    colorizer = MNISTColorizer(
        scale=self.scale,
        background=self.background,
        black=self.black,
        binarize=self.binarize,
        greyscale=self.greyscale,
        color_indices=self.colors,
        seed=self.seed,
    )
    x_colorized = colorizer(images=x, labels=s)
    # Convert to HWC format for compatibility with transforms
    x_colorized = x_colorized.movedim(1, -1).numpy().astype(np.uint8)

    super().__init__(x=x_colorized, y=y, s=s, transform=transform, image_dir=root)
def __init__(
    self,
    group_ids: Sequence[int],
    *,
    num_samples_per_group: int,
    multipliers: dict[int, int] | None = None,
    base_sampler: BaseSampler | str = BaseSampler.sequential,
    training_mode: TrainingMode | str = TrainingMode.step,
    replacement: bool = True,
    shuffle: bool = False,
    drop_last: bool = True,
    generator: torch.Generator | None = None,
) -> None:
    """Configure a group-stratified batch sampler.

    :param group_ids: Group id of every sample; indices are grouped by equal id.
    :param num_samples_per_group: Positive ``int``; each group contributes
        ``num_samples_per_group * multiplier`` slots to the batch size.
    :param multipliers: Optional per-group multipliers (groups not listed get 1). A
        multiplier of 0 contributes no slots to the batch.
    :param base_sampler: ``BaseSampler`` member or its name; stored as ``self.sampler``.
    :param training_mode: ``TrainingMode`` member or its name. In ``epoch`` mode the epoch
        length is the largest number of batches any group needs for a full pass, and the
        groups are reordered by that length, descending.
    :param replacement: Must be a ``bool``. Without replacement, every group must hold at
        least as many samples as it contributes to a batch.
    :param shuffle: Stored on the instance.
    :param drop_last: Stored; also used for the epoch-length computation.
    :param generator: Optional RNG stored on the instance.

    :raises ValueError: on an invalid ``num_samples_per_group`` or ``replacement``, or if
        a group is too small to be sampled without replacement.
    """
    if (
        not isinstance(num_samples_per_group, int)
        or isinstance(num_samples_per_group, bool)
        or num_samples_per_group <= 0
    ):
        raise ValueError(
            f"num_samples_per_group should be a positive integer; got {num_samples_per_group}"
        )
    if not isinstance(replacement, bool):
        raise ValueError(
            f"replacement should be a boolean value, but got replacement={replacement}"
        )
    if isinstance(base_sampler, str):
        base_sampler = str_to_enum(str_=base_sampler, enum=BaseSampler)
    if isinstance(training_mode, str):
        training_mode = str_to_enum(str_=training_mode, enum=TrainingMode)

    self.num_samples_per_group = num_samples_per_group
    multipliers_ = {} if multipliers is None else multipliers

    group_ids_t = torch.as_tensor(group_ids, dtype=torch.int64)
    # find all unique IDs
    groups: list[int] = group_ids_t.unique().tolist()

    # get the indexes for each group separately and compute the effective number of groups
    groupwise_idxs: list[tuple[Tensor, int]] = []
    num_groups_effective = 0
    for group in groups:
        # Idxs needs to be 1 dimensional
        idxs = (group_ids_t == group).nonzero(as_tuple=False).view(-1)
        multiplier = multipliers_.get(group, 1)
        assert isinstance(multiplier, int) and multiplier >= 0, "multiplier has to be >= 0"
        groupwise_idxs.append((idxs, multiplier))
        num_groups_effective += multiplier

        # The group must supply its share of the batch, which scales with the multiplier.
        num_required = num_samples_per_group * multiplier
        if not replacement and len(idxs) < num_required:
            raise ValueError(
                f"Not enough samples in group {group} to sample {num_required} without "
                f"replacement ({len(idxs)} available)."
            )

    self.groupwise_idxs = groupwise_idxs
    self.num_groups_effective = num_groups_effective
    self.batch_size = self.num_groups_effective * self.num_samples_per_group
    self.sampler = base_sampler
    self.replacement = replacement
    self.shuffle = shuffle
    self.drop_last = drop_last
    self.generator = generator
    self.training_mode = training_mode

    if self.training_mode is TrainingMode.epoch:
        # We define the length of the sampler to be the maximum number of steps
        # needed to do a complete pass of a group's data. A group with multiplier 0
        # contributes nothing per batch, so it gets length 0 rather than being passed
        # to ``num_batches_per_epoch`` with a batch size of 0.
        groupwise_epoch_length = [
            num_batches_per_epoch(
                num_samples=len(idxs),
                batch_size=mult * num_samples_per_group,
                drop_last=self.drop_last,
            )
            if mult > 0
            else 0
            for idxs, mult in self.groupwise_idxs
        ]
        # Sort the groupwise-idxs by their associated epoch-length
        sorted_idxs_desc = np.argsort(groupwise_epoch_length)[::-1]
        self.groupwise_idxs = [self.groupwise_idxs[idx] for idx in sorted_idxs_desc]
        max_epoch_length = groupwise_epoch_length[sorted_idxs_desc[0]]
    else:
        max_epoch_length = None

    super().__init__(epoch_length=max_epoch_length)