Example #1
    def __init__(
        self,
        data_source: Sized,
        *,
        batch_size: int,
        training_mode: TrainingMode | str = TrainingMode.step,
        shuffle: bool = True,
        drop_last: bool = False,
        generator: torch.Generator | None = None,
    ) -> None:
        self.data_source = data_source
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.generator = generator
        if isinstance(training_mode, str):
            training_mode = str_to_enum(str_=training_mode, enum=TrainingMode)
        self.training_mode = training_mode
        if self.training_mode is TrainingMode.epoch:
            epoch_length = num_batches_per_epoch(
                num_samples=len(self.data_source),
                batch_size=self.batch_size,
                drop_last=self.drop_last,
            )
        else:
            epoch_length = None
        super().__init__(epoch_length=epoch_length)
Example #2
    def __init__(
        self,
        lambda_sampler: LS,
        *,
        mode: MixUpMode | str = MixUpMode.linear,
        p: float = 1.0,
        num_classes: int | None = None,
        featurewise: bool = False,
        inplace: bool = False,
    ) -> None:
        """
        :param lambda_sampler: The distribution from which to sample lambda (the mixup interpolation
            parameter).

        :param mode: Which mode to use to mix up samples: geometric or linear.

        .. note::
            The (weighted) geometric mean, enabled by ``mode=geometric``, is only valid for positive
            inputs.

        :param p: The probability with which the transform will be applied to a given sample.
        :param num_classes: The total number of classes in the dataset, which must be specified when
            mixing up targets that are label-encoded. Passing label-encoded targets without
            specifying ``num_classes`` will result in a RuntimeError.

        :param featurewise: Whether to sample lambda feature-wise instead of sample-wise.

        .. note::
            If the ``lambda_sampler`` is a Bernoulli distribution (``td.Bernoulli``), then featurewise sampling will
            always be enabled.

        :param inplace: Whether the transform should be performed in-place.

        :raises ValueError: if ``p`` is not in the range [0, 1] or ``num_classes < 1``.
        """
        super().__init__()
        self.lambda_sampler = lambda_sampler
        if not 0 <= p <= 1:
            raise ValueError("'p' must be in the range [0, 1].")
        self.p = p
        if isinstance(mode, str):
            mode = str_to_enum(str_=mode, enum=MixUpMode)
        self.mode = mode
        if (num_classes is not None) and num_classes < 1:
            raise ValueError(f"{ num_classes } must be greater than 1.")
        self.num_classes = num_classes
        self.featurewise = featurewise or isinstance(lambda_sampler,
                                                     td.Bernoulli)
        self.inplace = inplace
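
A minimal usage sketch based solely on the parameters documented above; the enclosing class name (written here as 'MixUp') is an assumption, since the snippet only shows its constructor.

import torch.distributions as td

# Hypothetical construction: sample lambda from a Beta distribution and mix up
# label-encoded targets drawn from 10 classes (so 'num_classes' is required).
lambda_sampler = td.Beta(concentration1=0.2, concentration0=0.2)
transform = MixUp(  # assumed class name for the constructor shown above
    lambda_sampler=lambda_sampler,
    mode="linear",
    p=0.5,
    num_classes=10,
    featurewise=False,
    inplace=False,
)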
Example #3
    def __init__(
        self,
        root: Union[str, Path],
        *,
        download: bool = True,
        transform: Optional[ImageTform] = None,
        split: Optional[Union[Camelyon17Split, str]] = None,
        split_scheme: Union[Camelyon17SplitScheme,
                            str] = Camelyon17SplitScheme.official,
        superclass: Union[Camelyon17Attr, str] = Camelyon17Attr.tumor,
        subclass: Union[Camelyon17Attr, str] = Camelyon17Attr.center,
    ) -> None:

        self.superclass = str_to_enum(str_=superclass, enum=Camelyon17Attr)
        self.subclass = str_to_enum(str_=subclass, enum=Camelyon17Attr)

        self.split = (str_to_enum(str_=split, enum=Camelyon17Split)
                      if isinstance(split, str) else split)
        self.split_scheme = (str_to_enum(
            str_=split_scheme, enum=Camelyon17SplitScheme) if isinstance(
                split_scheme, str) else split_scheme)
        self.root = Path(root)
        self._base_dir = self.root / "camelyon17_v1.0"
        self.download = download
        if self.download:
            download_from_url(
                file_info=self._FILE_INFO,
                root=self.root,
                logger=self.logger,
                remove_finished=True,
            )
        # Only raise if the extracted data directory is actually missing.
        elif not self._base_dir.exists():
            raise FileNotFoundError(
                f"Data not found at location {self._base_dir.resolve()}. Have you downloaded it?"
            )

        # Read in metadata
        # Note: metadata is one-indexed.
        self.metadata = pd.read_csv(self._base_dir / 'metadata.csv',
                                    index_col=0,
                                    dtype={"patient": "str"})
        if self.split_scheme is Camelyon17SplitScheme.mixed_to_test:
            # For the mixed-to-test setting,
            # we move slide 23 (corresponding to patient 042, node 3 in the original dataset)
            # from the test set to the training set
            slide_mask = self.metadata["slide"] == 23
            self.metadata.loc[slide_mask,
                              "split"] = Camelyon17Split.train.value
        # Use an official split of the data, if 'split' is specified, else just use all
        # of the data
        val_center_mask = self.metadata["center"] == self._VAL_CENTER
        test_center_mask = self.metadata["center"] == self._TEST_CENTER
        self.metadata.loc[val_center_mask, "split"] = Camelyon17Split.val.value
        self.metadata.loc[test_center_mask,
                          "split"] = Camelyon17Split.test.value

        if self.split is not None:
            split_indices = self.metadata["split"] == self.split.value
            self.metadata = cast(pd.DataFrame, self.metadata[split_indices])

        # Construct filepaths from metadata
        def build_fp(row: pd.Series) -> str:
            return "patches/patient_{0}_node_{1}/patch_patient_{0}_node_{1}_x_{2}_y_{3}.png".format(
                *row)

        x = (self.metadata[["patient", "node", "x_coord",
                            "y_coord"]].apply(build_fp, axis=1).to_numpy())
        # Extract superclass labels
        y = torch.as_tensor(self.metadata[str(self.superclass)].to_numpy(),
                            dtype=torch.long)
        # Extract subclass labels
        s = torch.as_tensor(self.metadata[str(self.subclass)].to_numpy(),
                            dtype=torch.long)

        super().__init__(x=x,
                         y=y,
                         s=s,
                         transform=transform,
                         image_dir=self._base_dir)
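
For illustration, a hedged instantiation sketch of the constructor above; the class name 'Camelyon17' is an assumption inferred from the "camelyon17_v1.0" directory and enum names, and the string values mirror the enum members referenced in the snippet.

# Hypothetical usage: official split scheme, training split, tumour status as
# the superclass label and hospital/center as the subclass label.
dataset = Camelyon17(
    root="data",
    download=True,
    split="train",
    split_scheme="official",
    superclass="tumor",
    subclass="center",
)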
Example #4
class DINO(MomentumTeacherModel):
    _ft_clf: DINOLinearClassifier

    @parsable
    def __init__(
        self,
        *,
        backbone: Union[nn.Module, vit.VitArch, vit.VisionTransformer,
                        ResNetArch, str] = vit.VitArch.small,
        out_dim: int = 65_536,
        lr: float = 5.0e-4,
        warmup_iters: int = 10,
        weight_decay: float = 4.0e-2,
        min_lr: float = 1.0e-6,
        weight_decay_final: float = 0.4,
        freeze_last_layer: int = 1,
        patch_size: int = 16,
        drop_path_rate: float = 0.1,
        norm_last_layer: bool = True,
        use_bn_in_head: bool = False,
        momentum_teacher: float = 0.996,
        momentum_center: float = 0.9,
        teacher_temp: float = 0.04,
        warmup_teacher_temp: float = 0.04,
        student_temp: float = 0.1,
        warmup_teacher_temp_iters: int = 30,
        num_eval_blocks: int = 1,
        lr_eval: float = 1.0e-4,
        global_crop_size: Optional[Union[Tuple[int, int], int]] = None,
        local_crop_size: Union[Tuple[float, float], float] = 0.43,
        global_crops_scale: Tuple[float, float] = (0.4, 1.0),
        local_crops_scale: Tuple[float, float] = (0.05, 0.4),
        local_crops_number: int = 8,
        batch_transforms: Optional[BatchTransform] = None,
        eval_epochs: int = 100,
        eval_batch_size: Optional[int] = None,
    ) -> None:
        super().__init__(
            lr=lr,
            weight_decay=weight_decay,
            eval_epochs=eval_epochs,
            eval_batch_size=eval_batch_size,
            batch_transforms=batch_transforms,
            global_crop_size=global_crop_size,
            local_crop_size=local_crop_size,
            global_crops_scale=global_crops_scale,
            local_crops_scale=local_crops_scale,
            local_crops_number=local_crops_number,
        )
        if isinstance(backbone, str):
            backbone = str_to_enum(str_=backbone, enum=vit.VitArch)
        self.backbone = backbone
        self.num_eval_blocks = num_eval_blocks
        self.warmup_iters = warmup_iters
        self.min_weight_decay = weight_decay_final
        self.min_lr = min_lr
        self.freeze_last_layer = freeze_last_layer

        self.out_dim = out_dim
        self.norm_last_layer = norm_last_layer
        self.use_bn_in_head = use_bn_in_head
        self.momentum_teacher = momentum_teacher
        self.momentum_center = momentum_center
        self.teacher_temp = teacher_temp
        self.warmup_teacher_temp = warmup_teacher_temp
        self.warmup_teacher_temp_iters = warmup_teacher_temp_iters
        self.student_temp = student_temp
        self.eval_lr = lr_eval

        # ViT-specific arguments
        self.patch_size = patch_size
        self.drop_path_rate = drop_path_rate

        self._loss_fn = DINOLoss(
            student_temp=student_temp,
            teacher_temp=teacher_temp,
            warmup_teacher_temp=warmup_teacher_temp,
            warmup_teacher_temp_iters=warmup_teacher_temp_iters,
        )
Example #5
    def __init__(
        self,
        root: Union[str, Path],
        *,
        transform: Optional[AudioTform] = None,
        download: bool = True,
        target_attrs: Union[Union[SoundscapeAttr, str],
                            List[Union[SoundscapeAttr,
                                       str]]] = SoundscapeAttr.habitat,
        segment_len: Optional[float] = 15,
        preprocess: bool = False,
    ) -> None:

        self.root = Path(root).expanduser()
        self.download = download
        self.base_dir = self.root / self.__class__.__name__
        self.labels_dir = self.base_dir / self.INDICES_DIR
        self.segment_len = segment_len
        self.preprocess = preprocess
        self._metadata_path = self.base_dir / self.METADATA_FILENAME
        self.ec_labels_path = self.labels_dir / self._EC_LABELS_FILENAME
        self.uk_labels_path = self.labels_dir / self._UK_LABELS_FILENAME

        if not isinstance(target_attrs, list):
            target_attrs = [target_attrs]
        self.target_attrs = [
            str(str_to_enum(str_=elem, enum=SoundscapeAttr))
            for elem in target_attrs
        ]

        if self.download:
            self._download_files()
        self._check_files()

        # Extract labels from indices files.
        if not self._metadata_path.exists():
            self._extract_metadata()

        self.metadata = pd.read_csv(self.base_dir / self.METADATA_FILENAME)

        if self.preprocess and (self.num_frames_in_segment is not None):
            # data directory depends on the segment length
            processed_audio_dir = self.base_dir / f"segment_len={self.segment_len}"
            if not (processed_audio_dir / "filepaths.csv").exists():
                self._preprocess_files()

            filepaths = pd.read_csv(processed_audio_dir / "filepaths.csv")
            self.metadata.drop("filePath", inplace=True, axis=1)
            self.metadata = filepaths.merge(self.metadata,
                                            how='left',
                                            on='fileName')

        x = self.metadata["filePath"].to_numpy()
        y = torch.as_tensor(
            self._label_encode(self.metadata[self.target_attrs],
                               inplace=True).to_numpy())

        super().__init__(x=x,
                         y=y,
                         transform=transform,
                         audio_dir=self.base_dir)
Example #6
class MoCoV2(MomentumTeacherModel):
    _ft_clf: FineTuner
    use_ddp: bool

    @parsable
    def __init__(
        self,
        *,
        backbone: Union[nn.Module, ResNetArch, str] = ResNetArch.resnet18,
        head: Optional[nn.Module] = None,
        out_dim: int = 128,
        num_negatives: int = 65_536,
        momentum_teacher: float = 0.999,
        temp: float = 0.07,
        lr: float = 0.03,
        momentum_sgd: float = 0.9,
        weight_decay: float = 1.0e-4,
        use_mlp: bool = False,
        instance_transforms: Optional[MultiCropTransform] = None,
        batch_transforms: Optional[BatchTransform] = None,
        global_crop_size: Optional[Union[Tuple[int, int], int]] = None,
        local_crop_size: Union[Tuple[float, float], float] = 0.43,
        global_crops_scale: Tuple[float, float] = (0.4, 1.0),
        local_crops_scale: Tuple[float, float] = (0.05, 0.4),
        local_crops_number: int = 0,
        eval_epochs: int = 100,
        eval_batch_size: Optional[int] = None,
    ) -> None:
        """
        PyTorch Lightning implementation of `MoCo <https://arxiv.org/abs/2003.04297>`_
        Paper authors: Xinlei Chen, Haoqi Fan, Ross Girshick, Kaiming He.

        :param backbone: Backbone of the encoder. Can be any nn.Module with a unary forward method
        (accepting a single input Tensor and returning a single output Tensor), in which case the
        embedding dimension will be inferred by passing a dummy input through the backbone, or a
        'ResNetArch' instance whose value is a resnet builder function which will be called with
        num_classes=out_dim.

        :param out_dim: Output size of the encoder; only applicable when backbone is a 'ResNetArch' enum.

        :param num_negatives: queue size; number of negative keys.
        :param momentum_teacher: Momentum (what fraction of the previous iterate's parameters to
        interpolate with) for the teacher update.

        :param temp: Softmax temperature.
        :param lr: Learning rate for the student model.
        :param momentum_sgd: Optimizer momentum.
        :param weight_decay: Optimizer weight decay.
        :param use_mlp: Whether to add an MLP head to the encoders (instead of a single linear layer).

        :param instance_transforms: Instance-wise image-transforms to use to generate the positive pairs
        for instance-discrimination.

        :param batch_transforms: Batch-wise image-transforms to use to generate the positive pairs for
        instance-discrimination.

        .. note::
            A multi-crop augmentation policy randomly crops the same image into a pair of
            high-resolution (global) views along with 'local_crops_number' lower-resolution views
            (each generally covering less than 50% of the image).

        :param global_crops_scale: Scale range of the cropped image before resizing, relative to the
        original image. Used for large, global view cropping.

        :param local_crops_number: Number of small local views to generate.
        :param local_crops_scale: Scale range of the cropped image before resizing, relative to the
        original image. Used for small, local view cropping.

        :param eval_epochs: Number of epochs to train the post-hoc classifier for during validation/testing.
        :param eval_batch_size: Batch size to use when training the post-hoc classifier during validation/testing.
        """
        super().__init__(
            lr=lr,
            weight_decay=weight_decay,
            eval_epochs=eval_epochs,
            eval_batch_size=eval_batch_size,
            instance_transforms=instance_transforms,
            batch_transforms=batch_transforms,
            global_crop_size=global_crop_size,
            local_crop_size=local_crop_size,
            global_crops_scale=global_crops_scale,
            local_crops_scale=local_crops_scale,
            local_crops_number=local_crops_number,
        )
        if isinstance(backbone, str):
            backbone = str_to_enum(str_=backbone, enum=ResNetArch)
        self.backbone = backbone
        self.head = head

        self.out_dim = out_dim
        self.temp = temp
        self.lr = lr
        self.weight_decay = weight_decay
        self.momentum_teacher = momentum_teacher
        self.momentum_sgd = momentum_sgd
        self.num_negatives = num_negatives
        self.use_mlp = use_mlp

        # View-generation settings
        self._global_crop_size = global_crop_size
        self._local_crop_size = local_crop_size
        self.local_crops_number = local_crops_number
        self.local_crops_scale = local_crops_scale
        self.global_crops_scale = global_crops_scale

        # create the queue
        self.mb = MemoryBank(dim=out_dim, capacity=num_negatives)
        self._loss_fn = CrossEntropyLoss(reduction=ReductionType.mean)
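
A construction sketch for the constructor documented above; it assumes the class is importable as 'MoCoV2' and that the string "resnet18" resolves to the corresponding 'ResNetArch' member (as the default value suggests).

# Hypothetical usage: ResNet-18 backbone selected by name, an MLP projection
# head, and a smaller negative queue than the default.
model = MoCoV2(
    backbone="resnet18",
    out_dim=128,
    num_negatives=4_096,
    momentum_teacher=0.999,
    temp=0.07,
    lr=0.03,
    use_mlp=True,
    eval_epochs=20,
)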
Example #7
    def mixup_data(
        *,
        batch: TernarySample,
        device: torch.device,
        mix_lambda: Optional[float],
        alpha: float,
        fairness: Union[FairnessType, str],
    ) -> Mixed:
        '''Returns mixed inputs, pairs of targets, and lambda'''
        assert isinstance(batch.x, Tensor)
        fairness = str_to_enum(str_=fairness, enum=FairnessType)
        lam = (
            Beta(
                torch.tensor([alpha]).to(device),
                torch.tensor([alpha]).to(device)
                # Potentially change alpha from a=1.0 to account for class imbalance?
            ).sample()
            if mix_lambda is None
            else torch.tensor([mix_lambda]).to(device)
        )

        batches = {
            "x_s0": batch.x[batch.s.view(-1) == 0].to(device),
            "x_s1": batch.x[batch.s.view(-1) == 1].to(device),
            "y_s0": batch.y[batch.s.view(-1) == 0].to(device),
            "y_s1": batch.y[batch.s.view(-1) == 1].to(device),
            "x_s0_y0": batch.x[(batch.s.view(-1) == 0) & (batch.y.view(-1) == 0)].to(device),
            "x_s1_y0": batch.x[(batch.s.view(-1) == 1) & (batch.y.view(-1) == 0)].to(device),
            "x_s0_y1": batch.x[(batch.s.view(-1) == 0) & (batch.y.view(-1) == 1)].to(device),
            "x_s1_y1": batch.x[(batch.s.view(-1) == 1) & (batch.y.view(-1) == 1)].to(device),
            "s_s0_y0": batch.s[(batch.s.view(-1) == 0) & (batch.y.view(-1) == 0)].to(device),
            "s_s1_y0": batch.s[(batch.s.view(-1) == 1) & (batch.y.view(-1) == 0)].to(device),
            "s_s0_y1": batch.s[(batch.s.view(-1) == 0) & (batch.y.view(-1) == 1)].to(device),
            "s_s1_y1": batch.s[(batch.s.view(-1) == 1) & (batch.y.view(-1) == 1)].to(device),
        }
        xal = []
        xbl = []
        sal = []
        sbl = []
        yal = []
        ybl = []

        for x_a, s_a, y_a in zip(batch.x, batch.s, batch.y):
            xal.append(x_a)
            sal.append(s_a.unsqueeze(-1).float())
            yal.append(y_a.unsqueeze(-1).float())
            if (fairness is FairnessType.EqOp and y_a == 0) or fairness is FairnessType.No:
                xbl.append(x_a)
                sbl.append(s_a.unsqueeze(-1))
                ybl.append(y_a.unsqueeze(-1))
            elif fairness is FairnessType.EqOp:
                idx = torch.randint(batches[f"x_s{1 - int(s_a)}_y1"].size(0), (1,))
                x_b = batches[f"x_s{1 - int(s_a)}_y1"][idx, :].squeeze(0)
                xbl.append(x_b)
                sbl.append((torch.ones_like(s_a) * (1 - s_a)).unsqueeze(-1).float())
                y_b = torch.ones_like(y_a).unsqueeze(-1)
                ybl.append(y_b)
            elif fairness is FairnessType.DP:
                idx = torch.randint(batches[f"x_s{1-int(s_a)}"].size(0), (1,))
                x_b = batches[f"x_s{1-int(s_a)}"][idx, :].squeeze(0)
                xbl.append(x_b)
                sbl.append((torch.ones_like(s_a) * (1 - s_a)).unsqueeze(-1).float())
                y_b = batches[f"y_s{1-int(s_a)}"][idx].float()
                ybl.append(y_b)
            elif fairness is FairnessType.EO:
                idx = torch.randint(batches[f"x_s{1-int(s_a)}_y{int(y_a)}"].size(0), (1,))
                x_b = batches[f"x_s{1-int(s_a)}_y{int(y_a)}"][idx, :].squeeze(0)
                xbl.append(x_b)
                sbl.append((torch.ones_like(s_a) * (1 - s_a)).unsqueeze(-1).float())
                y_b = (torch.ones_like(y_a) * y_a).unsqueeze(-1)
                ybl.append(y_b)
        x_a = torch.stack(xal, dim=0).to(device)
        x_b = torch.stack(xbl, dim=0).to(device)

        s_a = torch.stack(sal, dim=0).to(device)
        s_b = torch.stack(sbl, dim=0).to(device)

        y_a = torch.stack(yal, dim=0).to(device)
        y_b = torch.stack(ybl, dim=0).to(device)

        mix_stats = {
            "batch_stats/S0=sS0": sum(a == b for a, b in zip(s_a, s_b)) / batch.s.size(0),
            "batch_stats/S0!=sS0": sum(a != b for a, b in zip(s_a, s_b)) / batch.s.size(0),
            "batch_stats/all_s0": sum(a + b == 0 for a, b in zip(s_a, s_b)) / batch.s.size(0),
            "batch_stats/all_s1": sum(a + b == 2 for a, b in zip(s_a, s_b)) / batch.s.size(0),
        }

        mixed_x = lam * x_a + (1 - lam) * x_b
        return Mixed(
            x=mixed_x.requires_grad_(True),
            xa=x_a,
            xb=x_b,
            sa=s_a,
            sb=s_b,
            ya=y_a,
            yb=y_b,
            lam=lam,
            stats=mix_stats,
        )
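
The heart of the function is the convex combination computed at the end; a self-contained sketch of that step, using plain tensors in place of the TernarySample container, is:

import torch
from torch.distributions import Beta

alpha = 1.0
# Sample the interpolation coefficient lambda ~ Beta(alpha, alpha), as above.
lam = Beta(torch.tensor([alpha]), torch.tensor([alpha])).sample()
x_a = torch.randn(8, 3, 32, 32)  # original inputs
x_b = torch.randn(8, 3, 32, 32)  # fairness-matched counterparts
mixed_x = lam * x_a + (1 - lam) * x_b  # same mixing rule as in mixup_data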
Example #8
    def __init__(
        self,
        root: Union[str, Path],
        *,
        download: bool = True,
        transform: Optional[ImageTform] = None,
        label_map: Optional[Dict[str, int]] = None,
        colors: Optional[List[int]] = None,
        num_colors: int = 10,
        scale: float = 0.2,
        correlation: Optional[float] = None,
        binarize: bool = False,
        greyscale: bool = False,
        background: bool = False,
        black: bool = True,
        split: Optional[Union[ColoredMNISTSplit, str]] = None,
        seed: Optional[int] = 42,
    ) -> None:
        self.split = (str_to_enum(str_=split, enum=ColoredMNISTSplit)
                      if isinstance(split, str) else split)
        self.label_map = label_map
        self.scale = scale
        self.num_colors = num_colors
        self.colors = colors
        self.background = background
        self.binarize = binarize
        self.black = black
        self.greyscale = greyscale
        self.seed = seed
        # Note: a correlation coefficient of '1' corresponds to perfect correlation between
        # digit and class while a correlation coefficient of '-1' corresponds to perfect
        # anti-correlation.
        if correlation is None:
            correlation = 1.0 if self.split is ColoredMNISTSplit.train else 0.5
        if not 0 <= correlation <= 1:
            raise ValueError(
                "Strength of correlation between colour and targets must be between 0 and 1."
            )
        self.correlation = correlation

        if self.split is None:
            x_ls, y_ls = [], []
            for _split in ColoredMNISTSplit:
                base_dataset = MNIST(root=str(root),
                                     download=download,
                                     train=_split is ColoredMNISTSplit.train)
                x_ls.append(base_dataset.data)
                y_ls.append(base_dataset.targets)
            x = torch.cat(x_ls, dim=0)
            y = torch.cat(y_ls, dim=0)
        else:
            base_dataset = MNIST(root=str(root),
                                 download=download,
                                 train=self.split is ColoredMNISTSplit.train)
            x = base_dataset.data
            y = base_dataset.targets

        if self.label_map is not None:
            x, y = _filter_data_by_labels(data=x,
                                          targets=y,
                                          label_map=self.label_map)
        s = y % self.num_colors
        s_unique, s_unique_inv = s.unique(return_inverse=True)

        generator = (torch.default_generator if self.seed is None else
                     torch.Generator().manual_seed(self.seed))
        inv_card_s = 1 / len(s_unique)
        if self.correlation < 1:
            flip_prop = self.correlation * (1.0 - inv_card_s) + inv_card_s
            # Change the values of randomly-selected labels to values other than their original ones
            num_to_flip = round((1 - flip_prop) * len(s))
            to_flip = torch.randperm(len(s), generator=generator)[:num_to_flip]
            s_unique_inv[to_flip] += torch.randint(
                low=1, high=len(s_unique), size=(num_to_flip,), generator=generator
            )
            # s labels live inside the Z/(num_colors * Z) ring
            s_unique_inv[to_flip] %= len(s_unique)
            s = s_unique[s_unique_inv]

        # Convert the greyscale images of shape (H, W) into 'colour' images of shape (C, H, W)
        colorizer = MNISTColorizer(
            scale=self.scale,
            background=self.background,
            black=self.black,
            binarize=self.binarize,
            greyscale=self.greyscale,
            color_indices=self.colors,
            seed=self.seed,
        )
        x_colorized = colorizer(images=x, labels=s)
        # Convert to HWC format for compatibility with transforms
        x_colorized = x_colorized.movedim(1, -1).numpy().astype(np.uint8)

        super().__init__(x=x_colorized,
                         y=y,
                         s=s,
                         transform=transform,
                         image_dir=root)
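
To make the correlation handling concrete, a small worked sketch of the flip computation used above (note that the snippet's 'flip_prop' is the proportion of colour labels that are kept, not flipped):

# With num_colors = 10 and correlation = 0.5, the proportion of colour labels
# left untouched is 0.5 * (1 - 1/10) + 1/10 = 0.55, so 45% are reassigned.
num_colors = 10
correlation = 0.5
inv_card_s = 1 / num_colors
keep_prop = correlation * (1.0 - inv_card_s) + inv_card_s  # 0.55
num_samples = 60_000
num_to_flip = round((1 - keep_prop) * num_samples)  # 27_000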
Example #9
    def __init__(
        self,
        group_ids: Sequence[int],
        *,
        num_samples_per_group: int,
        multipliers: dict[int, int] | None = None,
        base_sampler: BaseSampler | str = BaseSampler.sequential,
        training_mode: TrainingMode | str = TrainingMode.step,
        replacement: bool = True,
        shuffle: bool = False,
        drop_last: bool = True,
        generator: torch.Generator | None = None,
    ) -> None:
        if (not isinstance(num_samples_per_group, int)
                or isinstance(num_samples_per_group, bool)
                or num_samples_per_group <= 0):
            raise ValueError(
                f"num_samples_per_group should be a positive integer; got {num_samples_per_group}"
            )
        if not isinstance(replacement, bool):
            raise ValueError(
                f"replacement should be a boolean value, but got replacement={replacement}"
            )
        if isinstance(base_sampler, str):
            base_sampler = str_to_enum(str_=base_sampler, enum=BaseSampler)
        if isinstance(training_mode, str):
            training_mode = str_to_enum(str_=training_mode, enum=TrainingMode)

        self.num_samples_per_group = num_samples_per_group
        multipliers_ = {} if multipliers is None else multipliers

        group_ids_t = torch.as_tensor(group_ids, dtype=torch.int64)
        # find all unique IDs
        groups: list[int] = group_ids_t.unique().tolist()

        # get the indexes for each group separately and compute the effective number of groups
        groupwise_idxs: list[tuple[Tensor, int]] = []
        num_groups_effective = 0
        for group in groups:
            # Idxs needs to be 1 dimensional
            idxs = (group_ids_t == group).nonzero(as_tuple=False).view(-1)
            multiplier = multipliers_.get(group, 1)
            assert isinstance(
                multiplier,
                int) and multiplier >= 0, "multiplier has to be >= 0"
            groupwise_idxs.append((idxs, multiplier))
            num_groups_effective += multiplier

            if not replacement and len(
                    idxs) < num_samples_per_group * multiplier:
                raise ValueError(
                    f"Not enough samples in group {group} to sample {num_samples_per_group}."
                )

        self.groupwise_idxs = groupwise_idxs
        self.num_groups_effective = num_groups_effective
        self.batch_size = self.num_groups_effective * self.num_samples_per_group
        self.sampler = base_sampler
        self.replacement = replacement
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.generator = generator
        self.training_mode = training_mode

        if self.training_mode is TrainingMode.epoch:
            # We define the length of the sampler to be the maximum number of steps
            # needed to do a complete pass of a group's data
            groupwise_epoch_length = [
                num_batches_per_epoch(
                    num_samples=len(idxs),
                    batch_size=mult * num_samples_per_group,
                    drop_last=self.drop_last,
                ) for idxs, mult in self.groupwise_idxs
            ]
            # Sort the groupwise-idxs by their associated epoch-length
            sorted_idxs_desc = np.argsort(groupwise_epoch_length)[::-1]
            self.groupwise_idxs = [
                self.groupwise_idxs[idx] for idx in sorted_idxs_desc
            ]
            max_epoch_length = groupwise_epoch_length[sorted_idxs_desc[0]]
        else:
            max_epoch_length = None

        super().__init__(epoch_length=max_epoch_length)
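
A worked sketch of the epoch-length computation described in the comments above, under the assumption that num_batches_per_epoch (not shown in the snippet) is floor division when drop_last=True and ceiling division otherwise:

import math

def num_batches_per_epoch(num_samples: int, batch_size: int, drop_last: bool) -> int:
    # Assumed behaviour of the helper used above.
    return num_samples // batch_size if drop_last else math.ceil(num_samples / batch_size)

# Two groups with 100 and 37 samples, 8 samples drawn per group per step, and
# drop_last=True: the groupwise epoch lengths are 12 and 4, so the sampler's
# epoch length is 12 -- the number of steps needed for a full pass of the
# group that takes longest to traverse.
assert num_batches_per_epoch(100, 8, True) == 12
assert num_batches_per_epoch(37, 8, True) == 4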