def on_batch_start(self, state: RunnerState):
        """Mix every sample with a randomly chosen partner of the same class.

        Draws a mixing coefficient ``self.lam`` from Beta(alpha, alpha), or
        uses ``1`` (no mixing) when ``alpha <= 0``. Then, for each distinct
        value found in ``state.input[self.input_key]``, the samples carrying
        that value are permuted among themselves and every tensor named in
        ``self.fields`` is blended in place:

            x[i] <- lam * x[i] + (1 - lam) * x[perm(i)]

        Partners always share the same ``input_key`` value.

        Args:
            state: runner state whose ``input`` dict holds the batch tensors;
                the entries listed in ``self.fields`` are modified in place.
        """
        if not self.is_needed:
            return

        if self.alpha > 0:
            self.lam = np.random.beta(self.alpha, self.alpha)
        else:
            self.lam = 1

        # Copy the labels to host memory once, before any field is mixed
        # (the label tensor itself may be one of ``self.fields``).
        # reshape(-1) instead of squeeze() keeps the mask 1-D even when the
        # batch holds a single sample.
        labels = state.input[self.input_key].cpu().numpy().reshape(-1)
        batch_idx = np.arange(state.input[self.fields[0]].shape[0])

        # Build an independent permutation within each class.
        for image_class in np.unique(labels):
            class_idx = batch_idx[labels == image_class]
            index = np.random.permutation(class_idx)
            # Create the index tensors directly on the target device;
            # the original ``tensor.to(device)`` calls discarded their result.
            index = torch.tensor(index, dtype=torch.long, device=state.device)
            class_idx = torch.tensor(
                class_idx, dtype=torch.long, device=state.device)
            for f in self.fields:
                # Advanced indexing on the right-hand side copies, so the
                # in-place write below does not alias the values being read.
                state.input[f][class_idx] = (
                    self.lam * state.input[f][class_idx]
                    + (1 - self.lam) * state.input[f][index]
                )
Ejemplo n.º 2
0
    def on_batch_start(self, state: RunnerState):
        """Apply mixup to the batch before it is processed.

        Samples ``self.lam`` from Beta(alpha, alpha), or uses ``1`` when
        ``alpha <= 0``. Draws a random permutation ``self.index`` of the
        batch and replaces every tensor named in ``self.fields`` with
        ``lam * x + (1 - lam) * x[self.index]``.

        ``self.lam`` and ``self.index`` are stored on the callback. Presumably
        this lets other hooks (e.g. the loss) mix targets consistently;
        confirm against the rest of the class.

        Args:
            state: runner state whose ``input`` dict holds the batch tensors;
                the entries listed in ``self.fields`` are replaced.
        """
        if not self.is_needed:
            return

        if self.alpha > 0:
            self.lam = np.random.beta(self.alpha, self.alpha)
        else:
            self.lam = 1

        batch_size = state.input[self.fields[0]].shape[0]
        # Tensor.to() returns a new tensor rather than moving in place, so the
        # result must be assigned back for the permutation to live on device.
        self.index = torch.randperm(batch_size).to(state.device)

        for f in self.fields:
            state.input[f] = (
                self.lam * state.input[f]
                + (1 - self.lam) * state.input[f][self.index]
            )
Ejemplo n.º 3
0
    def on_batch_start(self, state: RunnerState):
        """Apply CutMix to the batch before it is processed.

        Samples ``self.lam`` from Beta(alpha, alpha), or uses ``1`` when
        ``alpha <= 0``. Draws a random permutation ``self.index`` of the
        batch. Asks ``rand_bbox`` for a box computed from the size of
        ``state.input[self.input_key]`` and ``self.lam``. For every tensor
        named in ``self.fields``, copies that box from the permuted batch
        into each sample in place.

        Args:
            state: runner state whose ``input`` dict holds the batch tensors;
                the entries listed in ``self.fields`` are modified in place.
        """
        if not self.is_needed:
            return

        if self.alpha > 0:
            self.lam = np.random.beta(self.alpha, self.alpha)
        else:
            self.lam = 1

        batch_size = state.input[self.fields[0]].shape[0]
        # Tensor.to() is not in-place: keep the returned tensor so the
        # permutation actually lives on the target device.
        self.index = torch.randperm(batch_size).to(state.device)

        # The box slices dims 2 and 3, so the fields are assumed to be
        # (N, C, H, W)-like tensors; TODO confirm against rand_bbox.
        bbx1, bby1, bbx2, bby2 = rand_bbox(
            state.input[self.input_key].size(), self.lam)
        # NOTE(review): self.lam is not re-derived from the actual box area
        # returned by rand_bbox. If rand_bbox clips the box, anything that
        # weights by self.lam will be slightly off; confirm.
        for f in self.fields:
            # Indexing with the LongTensor self.index copies the right-hand
            # side, so writing into the same tensor does not alias it.
            state.input[f][:, :, bbx1:bbx2, bby1:bby2] = \
                state.input[f][self.index, :, bbx1:bbx2, bby1:bby2]