def rotate(self, X, Y):
    """Rotate ``X`` and ``Y`` by the same random multiple of 90 degrees.

    The number of quarter turns is drawn from {0, 1, 2, 3} and the rotation
    plane is a random ordered pair of distinct dims from {1, 2, 3}, so both
    tensors need at least 4 dimensions.

    Returns:
        (rotated X, rotated Y)
    """
    quarter_turns = random.randint(0, 3)
    plane = random.sample([1, 2, 3], k=2)
    return tuple(torch.rot90(t, quarter_turns, plane) for t in (X, Y))
def visulization(render_norm, render_tex=None):
    """Convert normal and/or texture renders to 256x256 numpy images plus a mask.

    Each given render (presumably an H x W x 3 tensor in [0, 1] -- TODO confirm)
    is scaled by 255, rotated 90 degrees in its first two dims, resized with
    ``F.interpolate``'s default mode and returned as an HWC numpy array.

    Returns:
        (render_norm, render_tex, mask). Entries that were not provided stay
        None; all three are None when neither render is given. ``mask`` has
        shape (256, 256, 1) and is True wherever the reference render (the
        texture if given, else the normal) is not pure white (255, 255, 255).
    """
    if render_norm is None and render_tex is None:
        return None, None, None

    size = 256

    def _to_numpy(render):
        scaled = render.detach() * 255.0
        chw = torch.rot90(scaled, 1, [0, 1]).permute(2, 0, 1).unsqueeze(0)
        resized = F.interpolate(chw, size=(size, size))
        return resized[0].cpu().numpy().transpose(1, 2, 0)

    if render_norm is not None:
        render_norm = _to_numpy(render_norm)
        reference = render_norm
    if render_tex is not None:
        render_tex = _to_numpy(render_tex)
        reference = render_tex

    is_white = (
        (reference[:, :, 0] == 255)
        & (reference[:, :, 1] == 255)
        & (reference[:, :, 2] == 255)
    )
    mask = ~is_white.reshape(size, size, 1)
    return render_norm, render_tex, mask
def test_step(self, batch, batch_idx): # OPTIONAL data = batch t1 = time.time() if (self.hparams.self_ensemble == True): image = data['input'] B, C, H, W = image.shape new_image = torch.zeros((B * 4, C, H, W), dtype=image.dtype, device=image.device) new_image[0:B] = image new_image[B:B * 2] = torch.fliplr(image) image = torch.rot90(image, 2, [2, 3]) new_image[B * 2:B * 3] = image new_image[B * 3:B * 4] = torch.fliplr(image) new_image2 = torch.rot90(new_image, 1, [2, 3]) data['input'] = new_image out1 = self.forward(data)['image'] data['input'] = new_image2 out2 = self.forward(data)['image'] out2 = torch.rot90(out2, 3, [2, 3]) tempout = (out1 + out2) / 2 tempout[B:B * 2] = torch.fliplr(tempout[B:B * 2]) tempout[B * 2:B * 3] = torch.rot90(tempout[B * 2:B * 3], 2, [2, 3]) tempout[B * 3:B * 4] = torch.rot90( torch.fliplr(tempout[B * 3:B * 4]), 2, [2, 3]) for i in range(B): image[i] = torch.mean(tempout[i::B], 0) out = image else: out = self.forward(data)['image'] t2 = time.time() return {'image': out, 'name': data['name'], 't': t2 - t1}
def get_augmented(volume):
    """Return the 16 flip/rotation variants of a 5-D ``volume`` (as in Kisuk's thesis).

    Index ``4 * g + r`` holds the volume turned ``r`` successive quarter turns
    in the plane of dims (3, 4), then:
      g = 0: nothing more; g = 1: flipped along dim 3;
      g = 2: flipped along dim 2; g = 3: flipped along dim 3, then dim 2.
    Element 0 is ``volume`` itself (not a copy).
    """
    rotations = [volume]
    for _ in range(3):
        rotations.append(torch.rot90(rotations[-1], 1, [3, 4]))
    flipped = [torch.flip(r, [3]) for r in rotations]
    return (
        rotations
        + flipped
        + [torch.flip(r, [2]) for r in rotations]
        + [torch.flip(f, [2]) for f in flipped]
    )
def __call__(self, images, masks=None):
    """Rotate ``images`` (and ``masks`` if given) by k * 90 degrees in dims (0, 1).

    ``k`` is ``self.k`` when it is set, otherwise drawn uniformly from {0, 1, 2, 3}.

    Args:
        images: 3-D tensor of shape [height, width, channel]
        masks: 2-D tensor of shape [height, width]

    Returns:
        ``images_tensor`` alone, or ``(images_tensor, masks_tensor)`` when
        ``masks`` is given. For k == 0 the inputs are returned unchanged (same objects).
    """
    k = self.k if self.k is not None else int(np.random.choice([0, 1, 2, 3], 1)[0])
    if k != 0:
        images = torch.rot90(images, k, [0, 1])
        if masks is not None:
            masks = torch.rot90(masks, k, [0, 1])
    return images if masks is None else (images, masks)
def argument_image_rotation_and_fake_mix(X, X_f, ridx=None):
    """Sample a mixed rotation / fake-detection batch.

    Builds a pool of ``X`` rotated by 0, 90, 180 and 270 degrees in dims (2, 3),
    followed by ``X_f``, labelled 0, 1, 2, 3 and 4 respectively. It then picks
    ``n = X.size(0)`` pool entries.

    Args:
        X: real images indexed along dim 0; the rotations act on dims 2 and 3.
        X_f: fake images. Presumably n of them, so pool rows line up with the
            labels -- TODO confirm.
        ridx: optional pool indices. When None, n distinct indices are drawn
            at random with numpy.

    Returns:
        (selected samples, ``one_hot(labels, 5)`` cast to float64, ridx).
        All tensors live on ``X``'s device.
    """
    n = X.size()[0]
    device = X.device
    Xarg = torch.cat([
        X,
        torch.rot90(X, 1, [2, 3]),
        torch.rot90(X, 2, [2, 3]),
        torch.rot90(X, 3, [2, 3]),
        X_f,
    ])
    # Labels follow X's device. The legacy torch.cuda.LongTensor constructor
    # required CUDA and ignored where X actually lives.
    larg = torch.arange(5, dtype=torch.long, device=device).repeat_interleave(n)

    if ridx is None:
        ridx = np.arange(5 * n)
        np.random.shuffle(ridx)
        ridx = ridx[0:n]

    # Gather the selected rows in one indexing op instead of a Python loop.
    index = torch.as_tensor(ridx, dtype=torch.long, device=device)
    X_out = Xarg[index]
    l_out = larg[index]

    rot_labels = one_hot(l_out, 5)
    return X_out, rot_labels.double(), ridx
def randomly_flip_and_rotate_images(images):
    r"""
    Info:
        Apply one shared random geometric transform to every image in ``images``.
        First, with probability 1/2, a horizontal flip along dim 2. Then always a
        rotation in the (1, 2) plane of 90, 180 or 270 degrees, chosen uniformly,
        so the unrotated orientation is never produced. Elements are replaced in
        place and the same ``images`` object is returned.
    """
    count = len(images)
    if np.random.randint(low=0, high=2):
        for pos in range(count):
            images[pos] = torch.flip(images[pos], dims=[2])
    quarter_turns = int(np.random.randint(low=0, high=3)) + 1
    for pos in range(count):
        images[pos] = torch.rot90(images[pos], k=quarter_turns, dims=[1, 2])
    return images
def symmetry(x, mode="real"):
    """Copy values of each 2-D slice of ``x`` onto their point-mirrored positions.

    Works in a frame rotated by 90 degrees in dims (1, 2). In that frame, values at
    the strictly-lower-triangular positions (r, c), plus the diagonal positions
    (k, k) for k >= center, are written to (2*center - r, 2*center - c). The
    result is then rotated back.
    NOTE(review): presumably this enforces Hermitian symmetry of a Fourier-space
    image ("imag" negates) -- confirm.

    Args:
        x: tensor indexed as x[:, row, col]; presumably (channels, N, N) with
            N == x.shape[1] (square) -- TODO confirm.
        mode: "real" copies values unchanged, "imag" copies them negated; any
            other value applies no mirroring.

    Returns:
        Tensor with the same shape as ``x``.
    """
    center = (x.shape[1]) // 2
    # NOTE(review): u and v are reassigned below before being read; these two lines have no effect.
    u = torch.arange(center)
    v = torch.arange(center)
    # Diagonal positions (k, k) for k in [center, N).
    diag1 = torch.arange(center, x.shape[1])
    diag2 = torch.arange(center, x.shape[1])
    diag_indices = torch.stack((diag1, diag2))
    # Strictly-lower-triangular (row, col) pairs of an N x N grid.
    grid = torch.tril_indices(x.shape[1], x.shape[1], -1)
    # Source coordinates as (M, 1) column vectors: lower triangle, then the diagonal half.
    x_sym = torch.cat(
        (grid[0].reshape(-1, 1), diag_indices[0].reshape(-1, 1)),
    )
    y_sym = torch.cat(
        (grid[1].reshape(-1, 1), diag_indices[1].reshape(-1, 1)),
    )
    # Mirror in a frame rotated 90 degrees; undone by the final 270-degree rotation.
    # NOTE(review): assumes torch.rot90 returns a new tensor, so the caller's x is not written to -- confirm.
    x = torch.rot90(x, 1, dims=(1, 2))
    # Destination indices: point reflection of the source through (center, center).
    i = center + (center - x_sym)
    j = center + (center - y_sym)
    # Source indices (algebraically equal to x_sym and y_sym).
    u = center - (center - x_sym)
    v = center - (center - y_sym)
    # NOTE(review): grid contains the pair (1, 0), giving j == 2 * center. For even N that equals N,
    # which is out of range, so this appears to assume odd N -- confirm.
    if mode == "real":
        x[:, i, j] = x[:, u, v]
    if mode == "imag":
        x[:, i, j] = -x[:, u, v]
    return torch.rot90(x, 3, dims=(1, 2))
def rearrange_features(features, img1_2_flip, img1_2_rot, img2_2_flip, img2_2_rot, args):
    '''
    Reorder the features to align features of the corresponding similar pixels.

    With B = args['batch_size']:
      - rows [2B, 3B) hold the features of the augmented first views;
      - rows [3B, 4B) hold those of the augmented second views.
    For each such row the augmentation is undone in place. First a flip over
    dims (1, 2) is applied when the flag is set, then the row is rotated back by
    ``rot`` quarter turns in dims (1, 2). Flipping both axes is a 180-degree
    rotation, so it commutes with the rotation. Returns ``features`` (the same
    object, modified).
    '''
    batch = args['batch_size']
    groups = (
        (2 * batch, img1_2_flip, img1_2_rot),
        (3 * batch, img2_2_flip, img2_2_rot),
    )
    for start, flips, rots in groups:
        for offset, (flipped, quarter_turns) in enumerate(zip(flips, rots)):
            row = start + offset
            if flipped:
                features[row] = torch.flip(features[row], dims=(1, 2))
            if quarter_turns:
                features[row] = torch.rot90(features[row], k=-quarter_turns, dims=(1, 2))
    return features
def dump_grasp_tuple(left_finger, grasp_object, right_finger, path):
    """Write the assembled finger/object/finger TSDF as a mesh at ``path``.

    The left finger is turned 180 degrees in the (3, 4) plane and the right
    finger 180 degrees in the (2, 4) plane. The three volumes are then
    concatenated along dim 2, and the first batch item / first channel is meshed
    with a 0.015 voxel size.
    """
    oriented_left = rot90(left_finger, 2, (3, 4))
    oriented_right = rot90(right_finger, 2, (2, 4))
    stacked = cat((oriented_left, grasp_object, oriented_right), dim=2)
    volume = stacked[0, 0, :].cpu().numpy()
    TSDFHelper.to_mesh(tsdf=volume, voxel_size=0.015, path=path)
def process_data(supp, query, train=True, config=gconfig):
    """Build a training batch from an episode's support/query sets, or pass them through.

    Args:
        supp, query: indexable episode sets. Element 0 holds the samples and
            element 2 the labels; element 1 is not read here.
        train: when true, assemble and return ``[x, y]``; otherwise return
            ``[supp, query]`` unchanged.
        config: mapping read for ``'pretrain_shot'`` and ``'rotation'``.
            NOTE: the default ``gconfig`` is evaluated once, at definition time.

    Returns:
        ``[x, y]`` in training mode, else ``[supp, query]``.
    """
    if train:
        # return [supp, query]
        # load train data
        # NOTE(review): `way` is len(supp[0]), which presumes one support sample per class, and
        # `number` presumes the query set is balanced across classes -- confirm.
        way, number = len(supp[0]), len(query[0]) // len(supp[0]) + 1
        others = supp[0].size()[1:]  # per-sample shape
        if config['pretrain_shot'] == 1:
            x = supp[0]
            y = supp[2]
        else:
            x = torch.cat([supp[0], query[0]])
            y = torch.cat([supp[2], query[2]])
            # Sort by label so each class is contiguous, then view as (way, number, ...).
            y, slices = y.sort()
            x = x[slices].reshape(way, number, *others)
            y = y.reshape(way, number)
            # Keep the same `pretrain_shot` random positions within every class.
            randidx = torch.randperm(number)[:config['pretrain_shot']]
            x, y = x[:, randidx, :].reshape(way * config['pretrain_shot'], *others), y[:, randidx].reshape(-1)
        if config['rotation']:
            # x.shape in CIFAR100: [64=32way*2shot, 3, 28, 28]
            # Append 90/180/270-degree rotations in dims (2, 3); the concatenation
            # needs square spatial dims. Rotated copies keep their class label.
            x90 = torch.rot90(x, 1, [2, 3])
            x180 = torch.rot90(x90, 1, [2, 3])
            x270 = torch.rot90(x180, 1, [2, 3])
            x_ = torch.cat((x, x90, x180, x270), 0)
            y_ = torch.cat((y, y, y, y), 0)
            x = x_
            y = y_
        return [x, y]
    else:
        # load valid data
        return [supp, query]
def forward(self, x):
    """Scale ``x`` by a spatial attention map built from three rotation-stacked branches.

    Each branch compresses ``x``, stacks its 0/90/180/270-degree rotations
    (dims 2, 3, so H must equal W) along the channel axis, and feeds the stack to
    its spatial head. The three head outputs are summed and passed through a
    sigmoid. The result is repeated ``int(self.channel / 8)`` times along the
    channel axis and multiplied with ``x``.
    """
    compressed = [self.compress1(x), self.compress2(x), self.compress3(x)]
    heads = [self.spatial1, self.spatial2, self.spatial3]
    branch_outputs = []
    for feat, head in zip(compressed, heads):
        stacked = torch.cat(
            [torch.rot90(feat, quarter, dims=[2, 3]) for quarter in range(4)], dim=1)
        branch_outputs.append(head(stacked))
    attention = branch_outputs[0] + branch_outputs[1] + branch_outputs[2]
    scale = torch.sigmoid(attention).repeat(1, int(self.channel / 8), 1, 1)
    return x * scale
def rand_90(img: torch.Tensor, img2: torch.Tensor = None, prob: float = 0.5) -> torch.Tensor:
    """
    Randomly rotate the given Image Tensor 90 degrees clockwise or counterclockwise.

    A single uniform draw decides the outcome:
      - with probability ``prob / 2``: counter-clockwise (``torch.rot90`` k=1);
      - with probability ``prob / 2``: clockwise (k=-1);
      - otherwise: unchanged.
    The rotation acts on the last two dims, so both [C, H, W] and [N, C, H, W]
    inputs work.

    Args:
        img: Image Tensor to be rotated, in the form [..., H, W].
        img2: Second image Tensor to be rotated the same way (optional).
        prob (float): Total probability of rotation; both directions are equally likely.

    Returns:
        Tensor: Rotated image Tensor, or ``(img, img2)`` when img2 is given.
        (Careful if image dimensions are not square.)
    """
    # One draw: the old second independent draw made clockwise rotations more
    # likely than counter-clockwise ones, i.e. (1 - prob/2) * prob instead of prob/2.
    r = np.random.random()
    if r < prob / 2.:
        k = 1
    elif r < prob:
        k = -1
    else:
        k = 0
    if k:
        img = torch.rot90(img, k, dims=[-2, -1])
        if img2 is not None:
            img2 = torch.rot90(img2, k, dims=[-2, -1])
    if img2 is not None:
        return img, img2
    return img
def TTA(net, image, mode='cls'):
    """
    Do test-time augmentation for classification or segmentation.

    Six views of ``image`` are evaluated in a single ``net`` call:
      - rotations by 0/90/180/270 degrees (``torch.rot90`` with dims (3, 2));
      - a flip along H;
      - a flip along W.
    H must equal W so the views can be concatenated.

    Args:
        net: callable mapping a [6N, C, H, W] batch to predictions.
        image: [N, C, H, W] tensor of images (already transformed).
        mode: 'cls' for classification, 'seg' for segmentation.

    Returns:
        'cls': [N, ...] mean of the six view predictions.
        'seg': [N, C', H, W] mean of the six predictions after undoing each
        view's transform. Previously this mode returned the sum, 6x the scale
        of a single prediction.

    Raises:
        ValueError: for an unknown ``mode``.
    """
    if mode not in ('cls', 'seg'):
        raise ValueError(f"mode must be 'cls' or 'seg', got {mode!r}")
    n = image.shape[0]
    views = [torch.rot90(image, k, dims=(3, 2)) for k in range(4)]
    views.append(torch.flip(image, [2]))  # flip H
    views.append(torch.flip(image, [3]))  # flip W
    outputs = net(torch.cat(views, dim=0))
    chunks = torch.split(outputs, n, dim=0)  # one chunk of N rows per view
    if mode == 'cls':
        return torch.stack(chunks).mean(dim=0)
    # Undo each view: rotate back with dims (2, 3); flips are self-inverse.
    restored = [torch.rot90(chunks[k], k, dims=(2, 3)) for k in range(4)]
    restored.append(torch.flip(chunks[4], [2]))
    restored.append(torch.flip(chunks[5], [3]))
    return torch.stack(restored).mean(dim=0)
def __getitem__(self, index):
    """Return a ``(target, input, filename)`` training sample as aligned random patches.

    Both images are reflect-padded up to ``self.ps`` when smaller and converted
    to tensors. They are cropped to the same random ``ps x ps`` window and given
    the same random element of the 8 flip/rotation symmetries of the square
    (identity included), each with probability 1/8.
    """
    index_ = index % self.sizex
    ps = self.ps

    inp_path = self.inp_filenames[index_]
    tar_path = self.tar_filenames[index_]

    # Context managers release the file handles once the pixels are loaded.
    with Image.open(inp_path) as inp_file, Image.open(tar_path) as tar_file:
        w, h = tar_file.size
        padw = ps - w if w < ps else 0
        padh = ps - h if h < ps else 0

        inp_pil, tar_pil = inp_file, tar_file
        # Reflect Pad in case image is smaller than patch_size
        if padw != 0 or padh != 0:
            inp_pil = TF.pad(inp_pil, (0, 0, padw, padh), padding_mode='reflect')
            tar_pil = TF.pad(tar_pil, (0, 0, padw, padh), padding_mode='reflect')

        inp_img = TF.to_tensor(inp_pil)
        tar_img = TF.to_tensor(tar_pil)

    hh, ww = tar_img.shape[1], tar_img.shape[2]

    rr = random.randint(0, hh - ps)
    cc = random.randint(0, ww - ps)
    # randint includes both bounds: 0..7 is identity plus the 7 cases below,
    # each exactly once. (0..8 gave the identity twice the weight.)
    aug = random.randint(0, 7)

    # Crop patch
    inp_img = inp_img[:, rr:rr + ps, cc:cc + ps]
    tar_img = tar_img[:, rr:rr + ps, cc:cc + ps]

    # Data Augmentations (aug == 0 keeps the patch as-is)
    augmentations = {
        1: lambda t: t.flip(1),
        2: lambda t: t.flip(2),
        3: lambda t: torch.rot90(t, dims=(1, 2)),
        4: lambda t: torch.rot90(t, dims=(1, 2), k=2),
        5: lambda t: torch.rot90(t, dims=(1, 2), k=3),
        6: lambda t: torch.rot90(t.flip(1), dims=(1, 2)),
        7: lambda t: torch.rot90(t.flip(2), dims=(1, 2)),
    }
    transform = augmentations.get(aug)
    if transform is not None:
        inp_img = transform(inp_img)
        tar_img = transform(tar_img)

    filename = os.path.splitext(os.path.split(tar_path)[-1])[0]

    return tar_img, inp_img, filename
def forward(self, input):
    """Rotation-pooled linear layer.

    The weight is viewed as (out_features, grid_with, grid_with, in_features).
    The input is projected with that kernel rotated by 0/90/180/270 degrees in
    its grid plane (dims 1, 2), bias included, and the element-wise maximum of
    the four responses is returned.

    Returns:
        Tensor of shape (*input.shape[:-1], out_features). The batch dimension
        is kept even when it has size 1; the previous ``.squeeze()`` removed it.
    """
    flat = self.grid_with * self.grid_with * self.in_features
    weight = self.weight.reshape(
        self.out_features, self.grid_with, self.grid_with, self.in_features)
    responses = [
        F.linear(input, torch.rot90(weight, k, [1, 2]).reshape(self.out_features, flat), self.bias)
        for k in range(4)
    ]
    # Max over the rotation axis. This replaces a MaxPool1d that was rebuilt on
    # every call, plus a squeeze that also removed size-1 batch/feature dims.
    return torch.stack(responses, dim=0).max(dim=0).values
def combine_augmented(outputs):
    """Map 16 augmented predictions back to the reference orientation and merge by element-wise min.

    ``outputs[4 * g + r]`` is inverted as follows:
      - flip along dim 2 when g is 2 or 3;
      - flip along dim 3 when g is 1 or 3;
      - rotate back ``r`` quarter turns in dims (3, 4).
    Each inverted entry gets a new leading dim of size 1 and is written back into
    ``outputs`` (the list is mutated).

    Returns:
        (merged, outputs): the element-wise minimum over the 16 entries, and the
        mutated list.
    """
    assert len(outputs) == 16
    for idx in range(16):
        group, turns = divmod(idx, 4)
        restored = outputs[idx]
        if group >= 2:
            restored = torch.flip(restored, [2])
        if group % 2 == 1:
            restored = torch.flip(restored, [3])
        if turns:
            restored = torch.rot90(restored, -turns, [3, 4])
        outputs[idx] = restored.unsqueeze(0)
    output = torch.min(torch.cat(outputs, 0), 0).values
    return output, outputs
def augment_data(x, y, n=None):
    """
    Generate an augmented training dataset with random reflections and 90 degree rotations.

    Each output pair is a uniformly chosen sample (with replacement). It receives
    the same random left-right flip, up-down flip, and 0-3 quarter turns (dims
    0, 1) on both x and y.

    x, y : Image sets of shape (Samples, Width, Height, Channels)
        training images and next images
    n : number of training examples to generate (falsy -> x.shape[0])
    """
    n_data = x.shape[0]
    if not n:
        n = n_data

    x_out, y_out = [], []
    for _ in range(n):
        # randrange excludes n_data; randint(0, n_data) could pick an out-of-range index.
        r = random.randrange(n_data)
        x_r, y_r = x[r], y[r]
        if random.random() < 0.5:
            x_r = torch.fliplr(x_r)
            y_r = torch.fliplr(y_r)
        if random.random() < 0.5:
            x_r = torch.flipud(x_r)
            y_r = torch.flipud(y_r)
        # 0-3 quarter turns; 4 would duplicate the identity and bias the distribution.
        num_rots = random.randrange(4)
        x_r = torch.rot90(x_r, k=num_rots)
        y_r = torch.rot90(y_r, k=num_rots)
        x_out.append(x_r)
        y_out.append(y_r)

    return torch.stack(x_out), torch.stack(y_out)
def forward(self, input):
    """Convolve with the kernel and its 90/180/270-degree rotations.

    The four kernels (rotated in dims 2, 3) are stacked along the output-channel
    axis in the order 0, 90, 180, 270 degrees, each block sharing the same bias.
    The result therefore has 4x the kernel's output channels.
    ``self.padding`` is forwarded to ``F.conv2d``.
    """
    # Use the Parameter itself: ``self.weight.data`` detached the kernel from
    # autograd, so it never received gradients.
    rotations = [self.weight]
    for _ in range(3):
        rotations.append(torch.rot90(rotations[-1], 1, [2, 3]))
    weight_all = torch.cat(rotations, dim=0)
    bias_all = None if self.bias is None else self.bias.repeat(4)
    return F.conv2d(input, weight=weight_all, bias=bias_all, padding=self.padding)
def rotate(x, angle):
    """Rotate ``x`` by ``angle`` degrees in the plane of dims (3, 2).

    Positive quarter turns go from dim 3 toward dim 2 (the ``torch.rot90``
    convention), the opposite direction of ``torch.rot90(x, k, dims=(2, 3))``.
    Any multiple of 90 is accepted (e.g. -90 or 360). Angles equivalent to 0
    return ``x`` itself.

    Raises:
        ValueError: if ``angle`` is not a multiple of 90 degrees. Previously such
            angles silently returned None.
    """
    if angle % 90 != 0:
        raise ValueError(f"angle must be a multiple of 90 degrees, got {angle!r}")
    k = int(angle // 90) % 4
    if k == 0:
        return x
    return torch.rot90(x, k=k, dims=(3, 2))
def __getitem__(self, idex):
    """Build two augmented views of each of two randomly drawn grayscale images.

    ``idex`` is ignored: both source images are drawn uniformly at random from
    ``self.data_list``. ``transform_list[0]`` is applied first.
    ``transform_list[1]`` then produces two views per image. The second view of
    each image optionally gets a flip over dims (1, 2) (``self.if_flip``) and a
    random number of quarter turns (``self.if_rot_90``). With ``self.if_mixup``,
    a blend ``img3`` of the first views is also produced. ``transform_list[2]``
    finalizes every returned image.

    Returns:
        mixup:    img1_1, img1_2, img2_1, img2_2, img3, img1_weight,
                  img1_2_flip, img1_2_rot, img2_2_flip, img2_2_rot
        no mixup: img1_1, img1_2, img2_1, img2_2,
                  img1_2_flip, img1_2_rot, img2_2_flip, img2_2_rot
        The flip/rot values describe the transforms actually applied: they are
        0 when that augmentation is disabled, so downstream inversion does not
        undo transforms that never happened.
    """
    def _load_random_image():
        choosen_index = random.randint(0, len(self.data_list) - 1)
        path = os.path.join(self.data_path, self.data_list[choosen_index])
        # convert() returns a loaded copy, so the file can be closed right away.
        with Image.open(path) as raw:
            return raw.convert('L')

    def _random_flip_rot(img):
        # Always draw both values so the random stream is independent of the flags.
        do_flip = random.randint(0, 1)
        n_rot = random.randint(0, 3)
        if not self.if_flip:
            do_flip = 0
        if not self.if_rot_90:
            n_rot = 0
        if do_flip:
            # Flipping both spatial axes is equivalent to a 180-degree rotation.
            img = torch.flip(img, dims=(1, 2))
        if n_rot:
            img = torch.rot90(img, k=n_rot, dims=(1, 2))
        return img, do_flip, n_rot

    # Randomly choose two images
    raw1 = _load_random_image()
    raw2 = _load_random_image()
    img1 = self.transform_list[0](raw1)
    img2 = self.transform_list[0](raw2)

    # Generate similar pairs
    img1_1 = self.transform_list[1](img1)
    img2_1 = self.transform_list[1](img2)
    img1_2 = self.transform_list[1](img1)
    img2_2 = self.transform_list[1](img2)

    img1_2, img1_2_flip, img1_2_rot = _random_flip_rot(img1_2)
    img2_2, img2_2_flip, img2_2_rot = _random_flip_rot(img2_2)

    finalize = self.transform_list[2]
    if self.if_mixup:
        # IF if_mixup is True, mixup images will be created.
        img1_weight = random.random() * 0.6 + 0.2
        img3 = finalize(img1_weight * img1_1 + (1 - img1_weight) * img2_1)
        img1_weight = torch.tensor(img1_weight)
        img1_1 = finalize(img1_1)
        img1_2 = finalize(img1_2)
        img2_1 = finalize(img2_1)
        img2_2 = finalize(img2_2)
        return img1_1, img1_2, img2_1, img2_2, img3, img1_weight, img1_2_flip, img1_2_rot, img2_2_flip, img2_2_rot

    img1_1 = finalize(img1_1)
    img1_2 = finalize(img1_2)
    img2_1 = finalize(img2_1)
    img2_2 = finalize(img2_2)
    return img1_1, img1_2, img2_1, img2_2, img1_2_flip, img1_2_rot, img2_2_flip, img2_2_rot
def data_augmentation(image, mask):
    """Randomly augment a channel-first image together with its 2-D mask.

    Geometric ops applied to both image and mask:
      - an optional left-right flip (width axis);
      - a random number of quarter turns in the spatial plane;
      - an optional up-down flip (height axis).
    With ``intensity >= 1``: ``random_crop`` to half the width, then optional
    additive per-pixel, per-channel and per-element noise (image only), scaled
    by the image std. With ``intensity >= 2``: an optional random channel
    permutation (image only).

    NOTE(review): ``intensity`` and ``random_crop`` are not parameters; they are
    looked up in the enclosing module scope -- confirm they are defined there.

    Args:
        image: array-like converted with ``torch.Tensor``, indexed as (C, H, W).
        mask: array-like converted with ``torch.Tensor``, indexed as (H, W).

    Returns:
        (image, mask) tensors.
    """
    image = torch.Tensor(image)
    # Leading singleton dim so the mask's spatial axes are dims 1 and 2, like the image.
    mask = torch.Tensor(mask).unsqueeze(0)

    # Explicit spatial dims are used for the flips. On a CHW tensor,
    # torch.fliplr flips dim 1 (height), and torch.flipud flips dim 0, which
    # reversed the channel order and left the [1, H, W] mask untouched.
    if random.random() < 0.5:  # flip left right (width axis)
        image = torch.flip(image, [2])
        mask = torch.flip(mask, [2])
    rot = int(np.random.choice([0, 1, 2, 3]))
    image = torch.rot90(image, rot, [1, 2])
    mask = torch.rot90(mask, rot, [1, 2])
    if random.random() < 0.5:  # flip up-down (height axis)
        image = torch.flip(image, [1])
        mask = torch.flip(mask, [1])

    if intensity >= 1:
        # random crop
        cropsize = image.shape[2] // 2
        image, mask = random_crop(image, mask, cropsize=cropsize)
        std_noise = 1 * image.std()
        if random.random() < 0.5:
            # add noise per pixel and per channel
            pixel_noise = torch.rand(image.shape[1], image.shape[2])
            pixel_noise = torch.repeat_interleave(pixel_noise.unsqueeze(0), image.size(0), dim=0)
            image = image + pixel_noise * std_noise
        if random.random() < 0.5:
            channel_noise = torch.rand(image.shape[0]).unsqueeze(1).unsqueeze(2)
            channel_noise = torch.repeat_interleave(
                torch.repeat_interleave(channel_noise, image.shape[1], 1), image.shape[2], 2)
            image = image + channel_noise * std_noise
        if random.random() < 0.5:
            # add noise
            noise = torch.rand(image.shape[0], image.shape[1], image.shape[2]) * std_noise
            image = image + noise

    if intensity >= 2:
        # channel shuffle
        if random.random() < 0.5:
            idxs = np.arange(image.shape[0])
            np.random.shuffle(idxs)  # random band indices
            image = image[idxs]

    mask = mask.squeeze(0)
    return image, mask
def __call__(self, image):
    """Rotate by a random multiple of 90 degrees (0-3 quarter turns) in dims (1, 2).

    A single tensor is rotated and returned. A ``list`` (exact type) has every
    element rotated by the same amount, in place, and the same list is returned.
    """
    quarter_turns = random.randint(0, 3)
    if type(image) is list:
        for pos, item in enumerate(image):
            image[pos] = torch.rot90(item, quarter_turns, [1, 2])
        return image
    return torch.rot90(image, quarter_turns, [1, 2])
def revert_tta_factory(flip, rot):
    """Build a function that undoes a test-time flip and/or rotation.

    Args:
        flip: dims passed to ``Tensor.flip``; falsy for no flip.
        rot: quarter turns passed to ``torch.rot90`` in dims (3, 4), so inputs
            need at least 5 dims; falsy for no rotation. When both are given,
            the flip is applied before the rotation.

    Returns:
        A callable mapping a tensor to its reverted version.

    Raises:
        ValueError: if neither a flip nor a rotation is requested.
    """
    if flip and rot:
        return lambda x: torch.rot90(x.flip(flip), rot, dims=(3, 4))
    if flip:
        return lambda x: x.flip(flip)
    if rot:
        return lambda x: torch.rot90(x, rot, dims=(3, 4))
    # A bare ``raise`` with no active exception surfaced as an opaque RuntimeError.
    raise ValueError("revert_tta_factory needs a flip and/or a rotation to revert")
def val_pred(MaskCN, EnhanceSN, image, coarsemask):
    """Six-view test-time-augmented prediction of ``EnhanceSN`` guided by ``MaskCN`` CAMs.

    ``image`` and ``coarsemask`` are each expanded into six views along dim 0:
      - identity;
      - 90/180/270-degree rotations in dims (2, 3);
      - a flip of the last dim;
      - a flip of the second-to-last dim.
    The channel-wise concatenation of both is passed to ``cam`` (defined
    elsewhere), whose stacked result is moved to CUDA and fed to ``EnhanceSN``
    with the image views. Each prediction is mapped back to the original
    orientation and the six are SUMMED (not averaged).

    NOTE(review): the slicing pred[0:1] ... pred[5:6] presumes a batch of one image -- confirm.
    """
    def _six_views(t):
        return torch.cat([
            t,
            torch.rot90(t, 1, [2, 3]),
            torch.rot90(t, 2, [2, 3]),
            torch.rot90(t, 3, [2, 3]),
            torch.flip(t, [-1]),
            torch.flip(t, [-2]),
        ], dim=0)

    image = _six_views(image)
    coarsemask = _six_views(coarsemask)

    EnhanceSN.eval()
    with torch.no_grad():
        cla_cam = cam(MaskCN, torch.cat((image, coarsemask), dim=1))
        cla_cam = torch.from_numpy(np.stack(cla_cam)).unsqueeze(1).cuda()
        pred = EnhanceSN(image, cla_cam)
        aligned = [
            pred[0:1],
            torch.rot90(pred[1:2], 3, [2, 3]),
            torch.rot90(pred[2:3], 2, [2, 3]),
            torch.rot90(pred[3:4], 1, [2, 3]),
            torch.flip(pred[4:5], [-1]),
            torch.flip(pred[5:6], [-2]),
        ]
        pred = aligned[0] + aligned[1] + aligned[2] + aligned[3] + aligned[4] + aligned[5]
    return pred
def rotx4_forward(self, lq):
    """Rotation self-ensemble: average generator outputs over 0/90/180/270-degree inputs.

    For k in 0..3, ``lq`` is rotated k quarter turns in dims (-1, -2) and passed
    through ``self.generator``. The output is rotated back k quarter turns in
    dims (-2, -1), and the four aligned outputs are averaged.

    Args:
        lq (Tensor): input batch; rotations act on its last two dims.

    Returns:
        output (Tensor): mean of the four aligned generator outputs.
    """
    total = self.generator(lq)
    for quarter_turns in (1, 2, 3):
        rotated = torch.rot90(lq, quarter_turns, [-1, -2])
        restored = torch.rot90(self.generator(rotated), quarter_turns, [-2, -1])
        total = total + restored
    return total / 4
def forward(self, x, idx_scale):
    """Run the wrapped model (training) or an inference ensemble (eval).

    Training: ``self.model(x)``, or ``P.data_parallel`` across ``self.n_GPUs``
    devices when there is more than one.

    Eval: one pass per entry i of ``self.chop_patch_size``, averaged at the end.
      - With ``self.self_ensemble``: each pass is ``self.forward_x8`` with that
        pass's min_size (patch squared) and shave size.
      - Otherwise: pass i feeds ``x`` rotated by i quarter turns in dims (2, 3)
        to ``self.forward_chop`` (when ``self.chop``) or ``self.model.forward``.
        Every result after the first is rotated back before being accumulated.
    """
    self.idx_scale = idx_scale
    if hasattr(self.model, 'set_scale'):
        self.model.set_scale(idx_scale)

    if self.training:
        if self.n_GPUs > 1:
            return P.data_parallel(self.model, x, range(self.n_GPUs))
        return self.model(x)

    run = self.forward_chop if self.chop else self.model.forward
    n_passes = len(self.chop_patch_size)
    y = None
    for i in range(n_passes):
        if self.self_ensemble:
            patch = self.chop_patch_size[i]
            result = self.forward_x8(x, forward_function=run, chop=self.chop,
                                     min_size=patch * patch,
                                     shave_size=self.shave_size[i])
            if y is None:
                y = result
            else:
                y += result
            continue

        rotated = torch.rot90(x, i, [2, 3])
        if self.chop:
            patch = self.chop_patch_size[i]
            result = run(rotated, shave=self.shave_size[i], min_size=patch * patch)
        else:
            result = run(rotated)
        if y is None:
            y = result
        else:
            y += torch.rot90(result, -i, [2, 3])
    return y.div(float(n_passes))
def apply_tta(input):
    """Return eight test-time views of an NCHW-style ``input`` (spatial dims 2, 3).

    Order:
      0. the input itself;
      1. flip along dim 2;
      2. flip along dim 3;
      3-5. rotations by 1, 2 and 3 quarter turns in dims (2, 3);
      6-7. the dim-2 flip rotated by 1 and 3 quarter turns.
    """
    flipped_dim2 = torch.flip(input, dims=[2])
    return [
        input,
        flipped_dim2,
        torch.flip(input, dims=[3]),
        *(torch.rot90(input, k=k, dims=[2, 3]) for k in (1, 2, 3)),
        torch.rot90(flipped_dim2, k=1, dims=[2, 3]),
        torch.rot90(flipped_dim2, k=3, dims=[2, 3]),
    ]
def forward(self, input):
    """Re-orient the finger volumes and run ``self.net`` on the assembled TSDF.

    ``input`` must be 5-D. Channel 0 is treated as the grasp object and channels
    1 and 2 as the left and right fingers. The left finger is turned 180 degrees
    in the (3, 4) plane, the right finger 180 degrees in the (2, 4) plane, and
    the three single-channel volumes are concatenated along dim 2.
    """
    assert len(input.shape) == 5
    grasp_object, left_finger, right_finger = (
        input[:, c, :, :, :].unsqueeze(dim=1) for c in range(3))
    combined = cat([
        rot90(left_finger, 2, (3, 4)),
        grasp_object,
        rot90(right_finger, 2, (2, 4)),
    ], dim=2)
    return self.net(combined)
def before_batch(self):
    """Randomly rotate channels 0 and 1 of every sample in the current batch.

    Each sample gets its own number of quarter turns, drawn from {0, 1, 2, 3}.
    The same rotation is applied to that sample's input (``self.xb[0]``) and
    target (``self.yb[0]``). Clones are modified, and the results are stored as
    ``self.learn.xb`` / ``self.learn.yb``.
    """
    xb = self.xb[0].clone()
    yb = self.yb[0].clone()
    turns = np.random.randint(0, 4, xb.shape[0])
    for sample, k in enumerate(turns.tolist()):
        for batch_tensor in (xb, yb):
            for channel in (0, 1):
                batch_tensor[sample, channel] = torch.rot90(batch_tensor[sample, channel], k)
    self.learn.xb = [xb]
    self.learn.yb = [yb]