def forward(self, img: Tensor) -> Tensor:
    """Pick one operation from the augmentation space at random and apply it.

    Args:
        img (PIL Image or Tensor): Image to be transformed.

    Returns:
        PIL Image or Tensor: Transformed image.
    """
    fill_value = self.fill
    if isinstance(img, Tensor):
        # For tensor inputs the fill is normalised to a per-channel list of floats.
        if isinstance(fill_value, (int, float)):
            fill_value = [float(fill_value)] * F.get_image_num_channels(img)
        elif fill_value is not None:
            fill_value = [float(v) for v in fill_value]

    space = self._augmentation_space(self.num_magnitude_bins)
    names = list(space.keys())
    chosen = names[int(torch.randint(len(space), (1, )).item())]

    magnitudes, signed = space[chosen]
    if magnitudes.ndim > 0:
        # Draw one magnitude uniformly from the available bins.
        pick = torch.randint(len(magnitudes), (1,), dtype=torch.long)
        magnitude = float(magnitudes[pick].item())
    else:
        # Zero-dim magnitudes tensor: the op takes no magnitude.
        magnitude = 0.0

    # The sign coin is only drawn for signed ops (short-circuit keeps RNG usage unchanged).
    if signed and torch.randint(2, (1, )):
        magnitude = -magnitude

    return _apply_op(img, chosen, magnitude, interpolation=self.interpolation, fill=fill_value)
def forward(self, img: Tensor, img2: Tensor) -> Tensor:
    """Apply one randomly selected sub-policy from ``self.policies`` to ``img``.

    Args:
        img (PIL Image or Tensor): Image to be transformed.
        img2 (PIL Image or Tensor): Secondary image forwarded unchanged to
            every ``_apply_op`` call; how it is used depends on ``_apply_op``.

    Returns:
        PIL Image or Tensor: AutoAugmented image.
    """
    fill = self.fill
    # Resolve the fill exactly once. The previous code converted it for
    # ``img`` and then reset ``fill = self.fill`` and converted again for
    # ``img2``, so the ``img`` result was always thrown away (and a Tensor
    # ``img`` paired with a non-Tensor ``img2`` got an unconverted fill).
    # Size the fill for the image actually being transformed; fall back to
    # ``img2`` so a Tensor ``img2`` still yields a per-channel list.
    if isinstance(img, Tensor):
        reference = img
    elif isinstance(img2, Tensor):
        reference = img2
    else:
        reference = None
    if reference is not None:
        if isinstance(fill, (int, float)):
            fill = [float(fill)] * F.get_image_num_channels(reference)
        elif fill is not None:
            fill = [float(f) for f in fill]

    transform_id, probs, signs = self.get_params(len(self.policies))

    for i, (op_name, p, magnitude_id) in enumerate(self.policies[transform_id]):
        if probs[i] <= p:
            # Kept inside the loop: ``img`` is reassigned by each applied op,
            # so its size is re-read before building the op table.
            op_meta = self._augmentation_space(10, F.get_image_size(img))
            magnitudes, signed = op_meta[op_name]
            magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0
            if signed and signs[i] == 0:
                magnitude *= -1.0
            img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill, img2=img2)

    return img
def forward(
    self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
    """Randomly apply a sequence of photometric steps to ``image``.

    Seven uniform draws ``r`` gate the steps in order:

    - ``r[0]`` gates ``_brightness``.
    - ``r[1] < 0.5`` places ``_contrast`` before saturation/hue (gated by
      ``r[2]``); otherwise it runs after them (gated by ``r[5]``).
    - ``r[3]`` gates ``_saturation``.
    - ``r[4]`` gates ``_hue``.
    - ``r[6]`` gates a random channel permutation.

    Each gated step runs when its draw is below ``self.p``.

    Args:
        image: PIL Image or Tensor. A 2-D tensor gets a leading channel
            dimension added before processing.
        target: Returned unchanged.

    Returns:
        Tuple of the (possibly) transformed image and ``target``.

    Raises:
        ValueError: If ``image`` is a tensor that is not 2- or 3-dimensional.
    """
    if isinstance(image, torch.Tensor):
        if image.ndimension() not in {2, 3}:
            raise ValueError(
                f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions."
            )
        elif image.ndimension() == 2:
            image = image.unsqueeze(0)

    r = torch.rand(7)

    if r[0] < self.p:
        image = self._brightness(image)

    contrast_before = r[1] < 0.5
    if contrast_before:
        if r[2] < self.p:
            image = self._contrast(image)

    if r[3] < self.p:
        image = self._saturation(image)

    if r[4] < self.p:
        image = self._hue(image)

    if not contrast_before:
        if r[5] < self.p:
            image = self._contrast(image)

    if r[6] < self.p:
        channels = F.get_image_num_channels(image)
        permutation = torch.randperm(channels)

        is_pil = F._is_pil_image(image)
        if is_pil:
            # Permuting channels does not depend on dtype, so keep the native
            # dtype from pil_to_tensor. Converting to float first (as before)
            # forced to_pil_image through its truncating float->uint8 path,
            # which is not a guaranteed-exact round trip and mangles non-8-bit
            # PIL modes.
            image = F.pil_to_tensor(image)
        image = image[..., permutation, :, :]
        if is_pil:
            image = F.to_pil_image(image)

    return image, target