Ejemplo n.º 1
0
    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """Apply CutMix to a batch with probability ``self.p``.

        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer (torch.int64) tensor of size (B, )

        Returns:
            Tuple[Tensor, Tensor]: The batch (possibly with a rectangular
            region replaced by the rolled batch) and the target, one-hot
            encoded over ``self.num_classes`` and cast to the batch dtype.

        Raises:
            ValueError: If ``batch`` is not 4-dimensional or ``target`` is not
                1-dimensional.
            TypeError: If ``batch`` is not floating point or ``target`` is not
                ``torch.int64``.
        """
        # --- Input validation -------------------------------------------
        if batch.ndim != 4:
            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
        if target.ndim != 1:
            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
        if not batch.is_floating_point():
            raise TypeError(
                f"Batch dtype should be a float tensor. Got {batch.dtype}.")
        if target.dtype != torch.int64:
            raise TypeError(
                f"Target dtype should be torch.int64. Got {target.dtype}")

        # Work on copies unless the caller asked for in-place modification.
        if not self.inplace:
            batch, target = batch.clone(), target.clone()

        # Class indices -> one-hot rows in the same dtype as the images.
        if target.ndim == 1:
            one_hot = torch.nn.functional.one_hot(
                target, num_classes=self.num_classes)
            target = one_hot.to(dtype=batch.dtype)

        # Randomly skip the transform; the target is still one-hot encoded.
        skip_transform = torch.rand(1).item() >= self.p
        if skip_transform:
            return batch, target

        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        rolled_batch = batch.roll(1, 0)
        rolled_target = target.roll(1, 0)

        # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
        alpha_pair = torch.tensor([self.alpha, self.alpha])
        lambda_param = float(torch._sample_dirichlet(alpha_pair)[0])
        W, H = F.get_image_size(batch)

        # Random centre of the box to paste.
        center_x = torch.randint(W, (1, ))
        center_y = torch.randint(H, (1, ))

        # Half extents of the box, derived from the sampled lambda.
        half_ratio = 0.5 * math.sqrt(1.0 - lambda_param)
        half_w = int(half_ratio * W)
        half_h = int(half_ratio * H)

        # Clip the box to the image bounds.
        left = int(torch.clamp(center_x - half_w, min=0))
        top = int(torch.clamp(center_y - half_h, min=0))
        right = int(torch.clamp(center_x + half_w, max=W))
        bottom = int(torch.clamp(center_y + half_h, max=H))

        batch[:, :, top:bottom, left:right] = \
            rolled_batch[:, :, top:bottom, left:right]

        # Recompute lambda from the area actually pasted after clipping.
        pasted_area = (right - left) * (bottom - top)
        lambda_param = float(1.0 - pasted_area / (W * H))

        # Blend the one-hot targets with the same weights.
        rolled_target.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(rolled_target)

        return batch, target
Ejemplo n.º 2
0
    def forward(self, state):
        """Sample an action from a Dirichlet policy and return it with its log-probability.

        The concentration parameters come from ``self._get_concentration(state)``.

        Args:
            state: Input passed through to ``self._get_concentration``.

        Returns:
            rlt.ActorOutput: ``action`` is the sampled action and ``log_prob``
            is its Dirichlet log-probability with an extra dimension inserted
            at index 1.
        """
        concentration = self._get_concentration(state)

        if not self.training:
            # ONNX can't export Dirichlet()
            action = torch._sample_dirichlet(concentration)
        else:
            # PyTorch can't backwards pass _sample_dirichlet
            action = Dirichlet(concentration).rsample()

        # The log-probability is always evaluated through the distribution object.
        log_prob = Dirichlet(concentration).log_prob(action).unsqueeze(dim=1)
        return rlt.ActorOutput(action=action, log_prob=log_prob)
Ejemplo n.º 3
0
    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """Apply MixUp to a batch with probability ``self.p``.

        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer (torch.int64) tensor of size (B, )

        Returns:
            Tuple[Tensor, Tensor]: The batch (possibly blended with the batch
            rolled by one position) and the target, one-hot encoded over
            ``self.num_classes`` and cast to the batch dtype.

        Raises:
            ValueError: If ``batch`` is not 4-dimensional or ``target`` is not
                1-dimensional.
            TypeError: If ``batch`` is not floating point or ``target`` is not
                ``torch.int64``.
        """
        # Guard clauses; each one raises, so they are checked in sequence.
        if batch.ndim != 4:
            message = "Batch ndim should be 4. Got {}".format(batch.ndim)
            raise ValueError(message)
        if target.ndim != 1:
            message = "Target ndim should be 1. Got {}".format(target.ndim)
            raise ValueError(message)
        if not batch.is_floating_point():
            message = 'Batch dtype should be a float tensor. Got {}.'.format(
                batch.dtype)
            raise TypeError(message)
        if target.dtype != torch.int64:
            message = "Target dtype should be torch.int64. Got {}".format(
                target.dtype)
            raise TypeError(message)

        # Work on copies unless the caller asked for in-place modification.
        if not self.inplace:
            batch, target = batch.clone(), target.clone()

        # Class indices -> one-hot rows in the same dtype as the images.
        if target.ndim == 1:
            one_hot = torch.nn.functional.one_hot(
                target, num_classes=self.num_classes)
            target = one_hot.to(dtype=batch.dtype)

        # Randomly skip the transform; the target is still one-hot encoded.
        skip_transform = torch.rand(1).item() >= self.p
        if skip_transform:
            return batch, target

        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        rolled_batch = batch.roll(1, 0)
        rolled_target = target.roll(1, 0)

        # Implemented as on mixup paper, page 3.
        alpha_pair = torch.tensor([self.alpha, self.alpha])
        lambda_param = float(torch._sample_dirichlet(alpha_pair)[0])
        partner_weight = 1.0 - lambda_param

        # Convex combination of each sample with its rolled partner.
        rolled_batch.mul_(partner_weight)
        batch.mul_(lambda_param).add_(rolled_batch)

        # Blend the one-hot targets with the same weights.
        rolled_target.mul_(partner_weight)
        target.mul_(lambda_param).add_(rolled_target)

        return batch, target
Ejemplo n.º 4
0
def sample_mat_default(row_size=RANK,
                       col_size=RANK,
                       size=SINGLE_SAMPLE_SIZE,
                       device=DEVICE):
    """Default sampling function.

    Draws ``col_size * size`` Dirichlet samples of length ``row_size`` with
    every concentration parameter equal to 1 and arranges them as a stack of
    matrices.

    Args:
        row_size: Number of rows of each matrix (length of one Dirichlet draw).
        col_size: Number of columns of each matrix (draws per matrix).
        size: Number of matrices to sample.
        device: Device on which the concentration tensor is allocated.

    Returns:
        A ``torch.float64`` tensor of shape ``(size, row_size, col_size)``;
        each column of each matrix is one Dirichlet draw.
    """
    # Allocate the all-ones concentration matrix directly with the target
    # dtype/device instead of building nested Python lists and converting.
    concentration = torch.ones((col_size * size, row_size),
                               dtype=torch.float64,
                               device=device)
    samples = torch._sample_dirichlet(concentration)
    # Group the draws per matrix, then swap the last two axes so that each
    # draw runs down a column.
    return samples.reshape(-1, col_size, row_size).permute(0, 2, 1)
Ejemplo n.º 5
0
 def forward(ctx, concentration):
     """Draw a Dirichlet sample and record it on ``ctx``.

     Args:
         ctx: Context object; must provide ``save_for_backward``.
         concentration: Concentration tensor passed to
             ``torch._sample_dirichlet``.

     Returns:
         The sampled tensor.
     """
     sample = torch._sample_dirichlet(concentration)
     # Keep both the draw and its parameters available via the context.
     ctx.save_for_backward(sample, concentration)
     return sample
Ejemplo n.º 6
0
 def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
     """Return a Dirichlet sample for the concentration tensor ``params``.

     Thin wrapper around ``torch._sample_dirichlet``.
     """
     # Must be on a separate method so that we can overwrite it in tests.
     sample = torch._sample_dirichlet(params)
     return sample