Ejemplo n.º 1
0
    def rotate(self, X, Y):
        """Apply one random 90-degree rotation jointly to ``X`` and ``Y``.

        The number of quarter turns (0-3) and an ordered pair of distinct axes
        taken from dims 1, 2 and 3 are sampled once with ``random``. The same
        rotation is then applied to both tensors, so paired data stays
        aligned. Both tensors therefore need at least 4 dimensions.

        Returns:
            Tuple ``(X, Y)`` of rotated tensors.
        """
        quarter_turns = random.randint(0, 3)
        # The pair is ordered: swapping the two axes reverses the direction.
        axes = random.sample([1, 2, 3], k=2)
        rotated = tuple(torch.rot90(t, quarter_turns, axes) for t in (X, Y))
        return rotated
Ejemplo n.º 2
0
def _render_to_image(render, render_size):
    """Convert one ``(H, W, C)`` render tensor to a resized numpy image.

    The tensor is detached and scaled by 255. It is then rotated 90 degrees in
    the plane of dims 0 and 1 and resized to ``render_size x render_size``
    with ``F.interpolate`` (default mode, i.e. nearest neighbour, so pixel
    values are copied rather than blended). The result is returned as a
    ``(render_size, render_size, C)`` numpy array.
    """
    image = render.detach() * 255.0
    image = torch.rot90(image, 1, [0, 1]).permute(2, 0, 1).unsqueeze(0)
    image = F.interpolate(image, size=(render_size, render_size))
    return image[0].cpu().numpy().transpose(1, 2, 0)


def visulization(render_norm, render_tex=None, render_size=256):
    """Prepare normal/texture renders for display and compute a foreground mask.

    Each given render (channels-last tensor, presumably with values in
    [0, 1] — TODO confirm) is converted with ``_render_to_image``.

    Args:
        render_norm: normal render tensor or ``None``.
        render_tex: texture render tensor or ``None``.
        render_size: side length of the square output images (default 256).

    Returns:
        ``(render_norm, render_tex, mask)``. A missing render stays ``None``.
        ``mask`` is a boolean ``(render_size, render_size, 1)`` array that is
        False where the reference render has value 255 in its first three
        channels. The texture render is the reference when given, otherwise
        the normal render. Returns ``(None, None, None)`` when both renders
        are ``None``.
    """
    if render_norm is None and render_tex is None:
        return None, None, None

    if render_norm is not None:
        render_norm = _render_to_image(render_norm, render_size)
    if render_tex is not None:
        render_tex = _render_to_image(render_tex, render_size)

    # The texture render, when available, defines the background.
    reference = render_tex if render_tex is not None else render_norm

    bg = ((reference[:, :, 0] == 255)
          & (reference[:, :, 1] == 255)
          & (reference[:, :, 2] == 255)).reshape(render_size, render_size, 1)
    mask = ~bg

    return render_norm, render_tex, mask
Ejemplo n.º 3
0
 def test_step(self, batch, batch_idx):
     # OPTIONAL
     data = batch
     t1 = time.time()
     if (self.hparams.self_ensemble == True):
         image = data['input']
         B, C, H, W = image.shape
         new_image = torch.zeros((B * 4, C, H, W),
                                 dtype=image.dtype,
                                 device=image.device)
         new_image[0:B] = image
         new_image[B:B * 2] = torch.fliplr(image)
         image = torch.rot90(image, 2, [2, 3])
         new_image[B * 2:B * 3] = image
         new_image[B * 3:B * 4] = torch.fliplr(image)
         new_image2 = torch.rot90(new_image, 1, [2, 3])
         data['input'] = new_image
         out1 = self.forward(data)['image']
         data['input'] = new_image2
         out2 = self.forward(data)['image']
         out2 = torch.rot90(out2, 3, [2, 3])
         tempout = (out1 + out2) / 2
         tempout[B:B * 2] = torch.fliplr(tempout[B:B * 2])
         tempout[B * 2:B * 3] = torch.rot90(tempout[B * 2:B * 3], 2, [2, 3])
         tempout[B * 3:B * 4] = torch.rot90(
             torch.fliplr(tempout[B * 3:B * 4]), 2, [2, 3])
         for i in range(B):
             image[i] = torch.mean(tempout[i::B], 0)
         out = image
     else:
         out = self.forward(data)['image']
     t2 = time.time()
     return {'image': out, 'name': data['name'], 't': t2 - t1}
def get_augmented(volume):
    """Return the 16 flip/rotation variants of ``volume`` (Kisuk's thesis scheme).

    Rotations are quarter turns in the plane of dims 3 and 4. Flips are along
    dim 3 and dim 2. The 16 tensors are returned in this order:
    - ``[r0, r90, r180, r270]`` (``r0`` is ``volume`` itself);
    - each of those flipped on dim 3;
    - each flipped on dim 2;
    - each flipped on dim 3 and then dim 2.
    """
    rotations = [volume]
    while len(rotations) < 4:
        rotations.append(torch.rot90(rotations[-1], 1, [3, 4]))

    flipped_d3 = [torch.flip(v, [3]) for v in rotations]
    flipped_d2 = [torch.flip(v, [2]) for v in rotations]
    flipped_d3_d2 = [torch.flip(v, [2]) for v in flipped_d3]

    return rotations + flipped_d3 + flipped_d2 + flipped_d3_d2
Ejemplo n.º 5
0
    def __call__(self, images, masks=None):
        """ Rotate ``images`` (and optionally ``masks``) by ``k`` quarter turns.

        ``k`` is ``self.k`` when it is set. Otherwise it is drawn uniformly
        from 0-3 with ``np.random.choice``. The rotation acts on dims 0 and 1.

        Args:
            images: 3-D tensor of shape [height, width, channel]
            masks: 2-D tensor of shape [height, width]

        Returns:
            The rotated images tensor, or ``(images_tensor, masks_tensor)``
            when masks are given. For ``k == 0`` the inputs themselves are
            returned.
        """
        if self.k is not None:
            k = self.k
        else:
            k = int(np.random.choice([0, 1, 2, 3], 1)[0])

        if k == 0:
            outputs = [images] if masks is None else [images, masks]
        else:
            outputs = [torch.rot90(images, k, [0, 1])]
            if masks is not None:
                outputs.append(torch.rot90(masks, k, [0, 1]))

        return outputs[0] if len(outputs) == 1 else tuple(outputs)
Ejemplo n.º 6
0
def argument_image_rotation_and_fake_mix(X, X_f, ridx=None):
    """Pool rotated real images with fake images and sample a labelled batch.

    The pool is ``[X, rot90(X), rot180(X), rot270(X), X_f]`` (rotations in
    dims 2 and 3). Labels are 0, 1, 2, 3 for the rotations and 4 for the
    fakes. ``n = X.size(0)`` entries are taken from the pool at positions
    ``ridx``.

    Args:
        X: real image batch.
        X_f: fake image batch. It gets ``n`` fake labels, so it is presumably
            the same size as ``X`` — TODO confirm.
        ridx: pool indices to select. When ``None``, ``n`` distinct indices
            are drawn from ``range(5 * n)`` via ``np.random.shuffle``.

    Returns:
        ``(images, one_hot(labels, 5).double(), ridx)``.
    """
    n = X.size(0)
    # Labels are created on X's device. torch.cuda.LongTensor failed for CPU
    # inputs and always used the current CUDA device.
    larg = torch.cat([
        torch.full((n,), label, dtype=torch.long, device=X.device)
        for label in range(5)
    ])
    Xarg = torch.cat([X,
                      torch.rot90(X, 1, [2, 3]),
                      torch.rot90(X, 2, [2, 3]),
                      torch.rot90(X, 3, [2, 3]),
                      X_f])

    if ridx is None:
        ridx = np.arange(5 * n)
        np.random.shuffle(ridx)
        ridx = ridx[0:n]

    # A single gather replaces the per-index Python loop + torch.stack.
    index = torch.as_tensor(ridx, dtype=torch.long, device=Xarg.device)
    rot_labels = one_hot(larg[index], 5)

    return Xarg[index], rot_labels.double(), ridx
Ejemplo n.º 7
0
def randomly_flip_and_rotate_images(images):
    r"""
    Info:
        Randomly perform horizontal flipping and/or rotation of 90/180/270 degree.

        A single flip decision (probability 1/2, ``torch.flip`` over dim 2)
        and a single rotation count are drawn with ``np.random``. They are
        applied to every tensor in ``images``. The rotation count is 0-3
        quarter turns over dims (1, 2), with 0 meaning no rotation. All 8
        flip/rotation combinations are therefore equally likely. ``images``
        is updated in place and returned.
    """
    flip = np.random.randint(low=0, high=2)
    # high=4 includes "no rotation". With high=3 a rotation was always
    # applied, so the unchanged and flip-only variants never occurred.
    quarter_turns = np.random.randint(low=0, high=4)
    for idx, image in enumerate(images):
        if flip:
            # Horizontal flip
            image = torch.flip(image, dims=[2])
        if quarter_turns:
            image = torch.rot90(image, k=int(quarter_turns), dims=[1, 2])
        images[idx] = image
    return images
Ejemplo n.º 8
0
def symmetry(x, mode="real"):
    """Enforce point symmetry about the centre of dims 1 and 2 of ``x``.

    ``x`` is rotated 90 degrees in dims (1, 2). Source entries are the
    strictly lower triangle plus the main-diagonal entries from the centre
    onwards. Each source entry is copied to its point reflection
    ``(2c - r, 2c - col)``, where ``c = x.shape[1] // 2``. With
    ``mode == "imag"`` the copied values are negated; this presumably
    imposes Hermitian symmetry on real/imaginary parts of a centred spectrum
    — confirm with callers. With any other ``mode`` the values are left
    unchanged. The result is rotated back.

    NOTE(review): column 0 reflects to index ``2c``, which is only in range
    when ``x.shape[1]`` is odd. Index ranges are taken from ``x.shape[1]``
    only, so dims 1 and 2 are presumably equal (square).
    """
    n = x.shape[1]
    center = n // 2
    # Source positions as (M, 1) row/column index tensors.
    lower = torch.tril_indices(n, n, -1)
    diag = torch.arange(center, n)
    rows = torch.cat((lower[0], diag)).reshape(-1, 1)
    cols = torch.cat((lower[1], diag)).reshape(-1, 1)

    x = torch.rot90(x, 1, dims=(1, 2))
    # Point reflection through the centre: index k -> 2 * center - k.
    mirror_rows = center + (center - rows)
    mirror_cols = center + (center - cols)
    if mode == "real":
        x[:, mirror_rows, mirror_cols] = x[:, rows, cols]
    if mode == "imag":
        x[:, mirror_rows, mirror_cols] = -x[:, rows, cols]
    return torch.rot90(x, 3, dims=(1, 2))
def rearrange_features(features, img1_2_flip, img1_2_rot, img2_2_flip,
                       img2_2_rot, args):
    '''
    Reorder the features to align features of the corresponding similar pixels.

    ``features`` is indexed in blocks of ``args['batch_size']``. Block 2 is
    handled with the ``img1_2_*`` flags and block 3 with the ``img2_2_*``
    flags. For each sample in those blocks:
    - a truthy flip flag applies ``torch.flip`` over dims (1, 2);
    - a non-zero rotation index ``k`` then applies ``torch.rot90`` with
      ``-k`` over dims (1, 2).
    ``features`` is modified in place and returned.
    '''
    batch_size = args['batch_size']
    for block, flips, rots in ((2, img1_2_flip, img1_2_rot),
                               (3, img2_2_flip, img2_2_rot)):
        offset = block * batch_size
        for i, (flip, rot) in enumerate(zip(flips, rots)):
            pos = offset + i
            if flip:
                features[pos] = torch.flip(features[pos], dims=(1, 2))
            if rot:
                features[pos] = torch.rot90(features[pos], k=-rot, dims=(1, 2))
    return features
Ejemplo n.º 10
0
def dump_grasp_tuple(left_finger, grasp_object, right_finger, path):
    """Export the combined finger/object TSDF of one grasp as a mesh.

    Both finger volumes are rotated by 180 degrees (the left one in dims
    3-4, the right one in dims 2-4). The volumes are then concatenated
    along dim 2 in the order left, object, right. Index ``[0, 0]`` of the
    result is written with ``TSDFHelper.to_mesh`` (voxel size 0.015) to
    ``path``.
    """
    volumes = [
        rot90(left_finger, 2, (3, 4)),
        grasp_object,
        rot90(right_finger, 2, (2, 4)),
    ]
    tsdf = cat(volumes, dim=2)[0, 0, :].cpu().numpy()
    TSDFHelper.to_mesh(tsdf=tsdf, voxel_size=0.015, path=path)
Ejemplo n.º 11
0
def process_data(supp, query, train=True, config=gconfig):
    """Prepare a support/query episode for pre-training.

    Validation (``train=False``): returns ``[supp, query]`` unchanged.

    Training: ``[0]`` of ``supp``/``query`` holds images and ``[2]`` holds
    labels.
    - With ``config['pretrain_shot'] == 1`` only the support set is used.
    - Otherwise support and query are merged and sorted by label. They are
      reshaped to ``way`` classes of ``number`` samples each, and
      ``pretrain_shot`` random sample positions (shared by all classes) are
      kept.
    - With ``config['rotation']`` the batch is extended with copies rotated
      by 90/180/270 degrees in dims 2-3, and the labels are repeated to
      match.

    Returns:
        ``[x, y]`` for training, ``[supp, query]`` otherwise.
    """
    if not train:
        # load valid data
        return [supp, query]

    way = len(supp[0])
    number = len(query[0]) // way + 1
    others = supp[0].size()[1:]
    shot = config['pretrain_shot']

    if shot == 1:
        x, y = supp[0], supp[2]
    else:
        y, order = torch.cat([supp[2], query[2]]).sort()
        x = torch.cat([supp[0], query[0]])[order].reshape(way, number, *others)
        y = y.reshape(way, number)
        keep = torch.randperm(number)[:shot]
        x = x[:, keep, :].reshape(way * shot, *others)
        y = y[:, keep].reshape(-1)

    if config['rotation']:
        # Rotated copies keep the class label of their source image.
        x = torch.cat([torch.rot90(x, k, [2, 3]) for k in range(4)], 0)
        y = torch.cat([y] * 4, 0)
    return [x, y]
Ejemplo n.º 12
0
    def forward(self, x):
        """Rotation-aware spatial attention built from three compression branches.

        Each branch (``compress1..3``) output is concatenated along channels
        with its 90/180/270-degree rotations (dims 2-3) and fed to the
        matching ``spatial1..3`` layer. The three results are summed. The
        sigmoid of the sum is tiled ``int(self.channel / 8)`` times along the
        channel dim and multiplied with ``x``.
        """
        compressed = [self.compress1(x), self.compress2(x), self.compress3(x)]
        spatial_layers = [self.spatial1, self.spatial2, self.spatial3]

        x_out = None
        for feat, spatial in zip(compressed, spatial_layers):
            rotations = [torch.rot90(feat, k, dims=[2, 3]) for k in range(4)]
            branch = spatial(torch.cat(rotations, dim=1))
            x_out = branch if x_out is None else x_out + branch

        attention = torch.sigmoid(x_out)  # broadcast over batch/space below
        attention = attention.repeat(1, int(self.channel / 8), 1, 1)
        return x * attention
Ejemplo n.º 13
0
def rand_90(img: torch.Tensor,
            img2: torch.Tensor = None,
            prob: float = 0.5) -> torch.Tensor:
    """ Randomly rotate the given Image Tensor by +90 or -90 degrees.

    Args:
        img: Image Tensor to be rotated, in the form [C, H, W] (rotated over
          dims 1-2) or [N, C, H, W] (rotated over dims 2-3).
        img2: Second image Tensor rotated the same way (optional).
        prob (float): Probability of rotation. The +90 and -90 directions
          each happen with probability ``prob / 2``.
    Returns:
        Rotated ``img``, or the tuple ``(img, img2)`` when ``img2`` is given.
        (Careful if image dimensions are not square)
    """
    def _rotate(t, k):
        # 3-D [C, H, W] -> dims 1-2; batched tensors keep the dims 2-3 used before.
        dims = [1, 2] if t.dim() == 3 else [2, 3]
        return torch.rot90(t, k, dims=dims)

    # One draw decides the outcome: [0, prob/2) -> +90, [prob/2, prob) -> -90.
    # Drawing a second number for the elif gave -90 probability
    # (1 - prob/2) * prob instead of prob/2.
    draw = np.random.random()
    if draw < prob / 2.:
        k = 1
    elif draw < prob:
        k = -1
    else:
        k = 0

    if k:
        img = _rotate(img, k)
        if img2 is not None:
            img2 = _rotate(img2, k)
    if img2 is not None:
        return img, img2
    return img
Ejemplo n.º 14
0
def TTA(net, image, mode='cls'):
    """
    Do test time augmentations for classification or segmentation.

    Six views are run through ``net`` in one batch:
    - rotations by 0/1/2/3 quarter turns (``torch.rot90`` over dims (3, 2));
    - a flip of dim 2 (H);
    - a flip of dim 3 (W).
    Args:
        net: callable applied to the ``[6N, C, H, W]`` augmented batch.
        image: [N, C, H, W] tensor of images (already transformed). H must
            equal W because the rotated views are concatenated.
        mode: 'cls' for classification, 'seg' for segmentation.
    Returns:
        'cls': mean of the six predictions per image, shape [N, ...].
        'seg': SUM (not mean) of the six predictions after undoing each
            augmentation, shape [N, C, H, W].
    Raises:
        ValueError: if ``mode`` is neither 'cls' nor 'seg'.
    """
    if mode not in ('cls', 'seg'):
        raise ValueError(f"mode must be 'cls' or 'seg', got {mode!r}")
    n = image.shape[0]
    aug_imgs = [torch.rot90(image, i, dims=(3, 2)) for i in range(4)]
    aug_imgs.append(torch.flip(image, [2]))  # flip H
    aug_imgs.append(torch.flip(image, [3]))  # flip W
    outputs = net(torch.cat(aug_imgs, dim=0))
    # Rows [k*n, (k+1)*n) of ``outputs`` belong to augmentation k.
    if mode == 'cls':
        return outputs.reshape(6, n, *outputs.shape[1:]).mean(dim=0)

    views = outputs.split(n, dim=0)
    predict = torch.flip(views[5], [3])
    predict = predict + torch.flip(views[4], [2])
    for i in range(4):
        # rot90 over (2, 3) turns the opposite way to (3, 2), undoing view i.
        predict = predict + torch.rot90(views[i], i, dims=(2, 3))
    return predict
Ejemplo n.º 15
0
    def __getitem__(self, index):
        """Return a random training patch pair ``(tar_img, inp_img, filename)``.

        Steps applied to the pair at ``index % self.sizex``:
        - if smaller than ``self.ps``, reflect-pad on the right/bottom;
        - convert to tensors;
        - crop a random ``ps x ps`` patch (same location in both images);
        - apply one of 8 flip/rotation variants (0 = unchanged) identically
          to both.
        ``filename`` is the target file's base name without extension.
        """
        index_ = index % self.sizex
        ps = self.ps

        inp_path = self.inp_filenames[index_]
        tar_path = self.tar_filenames[index_]

        # ``with`` closes the file handles Image.open keeps open. Pixel data
        # is read by TF.pad / TF.to_tensor inside the block.
        with Image.open(inp_path) as inp_file, Image.open(tar_path) as tar_file:
            inp_pil, tar_pil = inp_file, tar_file
            w, h = tar_pil.size
            padw = ps - w if w < ps else 0
            padh = ps - h if h < ps else 0

            # Reflect Pad in case image is smaller than patch_size
            if padw != 0 or padh != 0:
                inp_pil = TF.pad(inp_pil, (0, 0, padw, padh),
                                 padding_mode='reflect')
                tar_pil = TF.pad(tar_pil, (0, 0, padw, padh),
                                 padding_mode='reflect')

            inp_img = TF.to_tensor(inp_pil)
            tar_img = TF.to_tensor(tar_pil)

        hh, ww = tar_img.shape[1], tar_img.shape[2]

        rr = random.randint(0, hh - ps)
        cc = random.randint(0, ww - ps)
        # random.randint is inclusive, so 0..7 picks each variant with equal
        # probability. randint(0, 8) also produced 8, which matched no branch
        # and doubled the chance of "no augmentation".
        aug = random.randint(0, 7)

        # Crop patch
        inp_img = inp_img[:, rr:rr + ps, cc:cc + ps]
        tar_img = tar_img[:, rr:rr + ps, cc:cc + ps]

        def augment(img):
            # Identical transform for input and target keeps the pair aligned.
            if aug == 1:
                return img.flip(1)
            if aug == 2:
                return img.flip(2)
            if aug == 3:
                return torch.rot90(img, dims=(1, 2))
            if aug == 4:
                return torch.rot90(img, dims=(1, 2), k=2)
            if aug == 5:
                return torch.rot90(img, dims=(1, 2), k=3)
            if aug == 6:
                return torch.rot90(img.flip(1), dims=(1, 2))
            if aug == 7:
                return torch.rot90(img.flip(2), dims=(1, 2))
            return img

        # Data Augmentations
        inp_img = augment(inp_img)
        tar_img = augment(tar_img)

        filename = os.path.splitext(os.path.split(tar_path)[-1])[0]

        return tar_img, inp_img, filename
Ejemplo n.º 16
0
    def forward(self, input):
        """Linear layer max-pooled over the four 90-degree rotations of its weight.

        ``self.weight`` is reshaped to ``(out_features, grid_with, grid_with,
        in_features)``. For k = 0..3 the two grid axes are rotated by k
        quarter turns, flattened back, and applied with ``F.linear`` (all
        sharing ``self.bias``). The element-wise maximum of the four
        responses is returned.

        Returns:
            Tensor of shape ``input.shape[:-1] + (out_features,)``. The former
            ``.squeeze()`` also dropped batch or feature dims of size 1.
        """
        weight = self.weight.reshape(self.out_features, self.grid_with,
                                     self.grid_with, self.in_features)
        responses = [
            F.linear(input,
                     torch.rot90(weight, k, [1, 2]).reshape(self.out_features, -1),
                     self.bias)
            for k in range(4)
        ]
        # Max over the rotation axis. This equals MaxPool1d(4, stride=4) over
        # a length-4 axis, without building a module on every call.
        return torch.stack(responses, dim=-1).max(dim=-1).values
def combine_augmented(outputs):
    """Undo 16 flip/rotation augmentations and reduce them with an element-wise min.

    Entry ``i`` is treated as group ``i // 4`` with ``i % 4`` quarter turns
    in dims (3, 4). The groups are:
    - 0: rotation only;
    - 1: also flipped on dim 3;
    - 2: also flipped on dim 2;
    - 3: flipped on dims 3 and 2.
    Flips are undone first, then the rotation.

    ``outputs`` is modified in place: each entry is replaced by its restored
    tensor with an extra leading dim of size 1.

    Returns:
        ``(element-wise minimum over the 16 restored tensors, outputs)``.
    """
    assert len(outputs) == 16
    for idx, tensor in enumerate(outputs):
        group, quarter_turns = divmod(idx, 4)
        if group >= 2:
            tensor = torch.flip(tensor, [2])
        if group % 2 == 1:
            tensor = torch.flip(tensor, [3])
        if quarter_turns:
            tensor = torch.rot90(tensor, -quarter_turns, [3, 4])
        outputs[idx] = tensor.unsqueeze(0)

    stacked = torch.cat(outputs, 0)
    return torch.min(stacked, 0)[0], outputs
Ejemplo n.º 18
0
def augment_data(x, y, n=None):
    """
    Generate an augmented training dataset with random reflections
    and 90 degree rotations
    x, y : Image sets of shape (Samples, Width, Height, Channels)
        training images and next images
    n : number of training examples (``x.shape[0]`` when falsy)

    Each output example is a uniformly sampled pair ``(x[r], y[r])``.
    Both tensors get the same random ``fliplr`` and ``flipud`` (each with
    probability 1/2). They then get the same number (0-3) of ``rot90``
    quarter turns over dims 0 and 1.
    Returns the stacked ``(x_aug, y_aug)``.
    """
    n_data = x.shape[0]

    if not n:
        n = n_data
    x_out, y_out = [], []

    for _ in range(n):
        # randrange excludes n_data. randint(0, n_data) is inclusive and
        # could index one past the end.
        r = random.randrange(n_data)
        x_r, y_r = x[r], y[r]

        if random.random() < 0.5:
            x_r, y_r = torch.fliplr(x_r), torch.fliplr(y_r)
        if random.random() < 0.5:
            x_r, y_r = torch.flipud(x_r), torch.flipud(y_r)

        # 0-3 quarter turns. randint(0, 4) also produced 4 (same as 0),
        # making the unrotated case twice as likely.
        num_rots = random.randrange(4)
        x_out.append(torch.rot90(x_r, k=num_rots))
        y_out.append(torch.rot90(y_r, k=num_rots))
    return torch.stack(x_out), torch.stack(y_out)
Ejemplo n.º 19
0
 def forward(self, input):
     """Convolve ``input`` with the kernel and its 90/180/270-degree rotations.

     The four kernels (rotated over dims 2-3) are concatenated along the
     output-channel dim, so the result has four times ``weight.shape[0]``
     channels. ``self.bias``, when present, is repeated once per rotation.
     """
     # Use the parameter itself, not ``.data``: ``.data`` detaches it from
     # autograd, so the weight never received gradients.
     weight = self.weight
     weight_all = torch.cat([torch.rot90(weight, k, [2, 3]) for k in range(4)],
                            dim=0)
     bias_all = None if self.bias is None else self.bias.repeat(4)
     return F.conv2d(input, weight=weight_all, bias=bias_all, padding=self.padding)
Ejemplo n.º 20
0
def rotate(x, angle):
    """Rotate ``x`` by ``angle`` degrees in the plane of dims 3 and 2.

    ``angle`` must be a multiple of 90. Negative angles and angles of 360 or
    more are reduced modulo 360, with ``torch.rot90`` over dims (3, 2)
    applied ``(angle // 90) % 4`` times. A rotation of 0 (mod 360) returns
    ``x`` itself.

    Raises:
        ValueError: if ``angle`` is not a multiple of 90 (previously such
            angles silently returned ``None``).
    """
    if angle % 90 != 0:
        raise ValueError(f"angle must be a multiple of 90, got {angle!r}")
    k = int(angle // 90) % 4
    if k == 0:
        return x
    return torch.rot90(x, k=k, dims=(3, 2))
Ejemplo n.º 21
0
    def __getitem__(self, idex):
        """Build augmented views of two randomly chosen images.

        ``idex`` is not used: both images are drawn at random from
        ``self.data_list`` and loaded as grayscale ('L').
        - After ``self.transform_list[0]``, each image gets two views from
          ``self.transform_list[1]``.
        - The second view of each image may be flipped over dims (1, 2)
          and/or rotated by 1-3 quarter turns, controlled by
          ``self.if_flip`` / ``self.if_rot_90``.
        - With ``self.if_mixup`` a weighted mix of the two first views is
          also produced.
        - Every returned image passes through ``self.transform_list[2]``.

        Returns:
            ``(img1_1, img1_2, img2_1, img2_2[, img3, img1_weight],
            img1_2_flip, img1_2_rot, img2_2_flip, img2_2_rot)``.
            Flip/rot values are 0 when that augmentation is disabled, so they
            always describe the transform actually applied. Presumably a
            caller uses them to undo the transform on features — confirm.
        """
        def load_random_image():
            # Pick a random file. ``with`` closes the handle Image.open keeps.
            chosen_index = random.randint(0, len(self.data_list) - 1)
            path = os.path.join(self.data_path, self.data_list[chosen_index])
            with Image.open(path) as img:
                return img.convert('L')

        def random_flip_rot(img):
            # Both draws always happen, so the random stream is the same as
            # before. The reported values are zeroed when the augmentation is
            # disabled: previously a non-applied flip/rotation could still be
            # reported.
            flip = random.randint(0, 1)
            rot = random.randint(0, 3)
            if not self.if_flip:
                flip = 0
            if not self.if_rot_90:
                rot = 0
            if flip:
                img = torch.flip(img, dims=(1, 2))
            if rot:
                img = torch.rot90(img, k=rot, dims=(1, 2))
            return img, flip, rot

        # Randomly choose two images
        img1 = load_random_image()
        img2 = load_random_image()
        img1 = self.transform_list[0](img1)
        img2 = self.transform_list[0](img2)

        # Generate similar pairs
        img1_1 = self.transform_list[1](img1)
        img2_1 = self.transform_list[1](img2)
        img1_2 = self.transform_list[1](img1)
        img2_2 = self.transform_list[1](img2)

        img1_2, img1_2_flip, img1_2_rot = random_flip_rot(img1_2)
        img2_2, img2_2_flip, img2_2_rot = random_flip_rot(img2_2)

        if self.if_mixup:
            # IF if_mixup is True, mixup images will be created.
            # The weight of image 1 is drawn from [0.2, 0.8).
            img1_weight = random.random() * 0.6 + 0.2
            img3 = img1_weight * img1_1 + (1 - img1_weight) * img2_1
            img3 = self.transform_list[2](img3)
            img1_weight = torch.tensor(img1_weight)

        img1_1 = self.transform_list[2](img1_1)
        img1_2 = self.transform_list[2](img1_2)
        img2_1 = self.transform_list[2](img2_1)
        img2_2 = self.transform_list[2](img2_2)

        if self.if_mixup:
            return (img1_1, img1_2, img2_1, img2_2, img3, img1_weight,
                    img1_2_flip, img1_2_rot, img2_2_flip, img2_2_rot)
        return (img1_1, img1_2, img2_1, img2_2,
                img1_2_flip, img1_2_rot, img2_2_flip, img2_2_rot)
Ejemplo n.º 22
0
    def data_augmentation(image, mask):
        """Randomly augment an image/mask pair.

        Geometric steps, each applied identically to image and mask:
        - optional left-right flip;
        - 0-3 random 90-degree rotations;
        - optional up-down flip.
        ``intensity`` and ``random_crop`` come from the enclosing scope:
        - ``intensity >= 1``: ``random_crop`` is called with
          ``cropsize = image.shape[2] // 2``, and three kinds of uniform
          noise scaled by the image std may be added.
        - ``intensity >= 2``: the channel order may also be shuffled.

        Args:
            image: array-like converted with ``torch.Tensor``; presumably
                ``(C, H, W)`` — TODO confirm.
            mask: array-like, presumably ``(H, W)``. A leading dim is added
                for processing and removed again before returning.

        Returns:
            ``(image, mask)`` tensors.
        """
        image = torch.Tensor(image)
        mask = torch.Tensor(mask)
        mask = mask.unsqueeze(0)

        if random.random() < 0.5:
            # flip left right: reverse the width axis (dim 2). torch.fliplr
            # reverses dim 1, which is the height of a (C, H, W) tensor.
            image = torch.flip(image, [2])
            mask = torch.flip(mask, [2])

        rot = np.random.choice([0, 1, 2, 3])
        image = torch.rot90(image, rot, [1, 2])
        mask = torch.rot90(mask, rot, [1, 2])

        if random.random() < 0.5:
            # flip up-down: reverse the height axis (dim 1). torch.flipud
            # reverses dim 0, i.e. the channel order of a (C, H, W) image
            # (and is a no-op on the 1-channel mask), not a spatial flip.
            image = torch.flip(image, [1])
            mask = torch.flip(mask, [1])

        if intensity >= 1:

            # random crop
            cropsize = image.shape[2] // 2
            image, mask = random_crop(image, mask, cropsize=cropsize)

            std_noise = 1 * image.std()
            if random.random() < 0.5:
                # add noise per pixel and per channel
                pixel_noise = torch.rand(image.shape[1], image.shape[2])
                pixel_noise = torch.repeat_interleave(pixel_noise.unsqueeze(0),
                                                      image.size(0),
                                                      dim=0)
                image = image + pixel_noise * std_noise

            if random.random() < 0.5:
                # One noise value per channel, broadcast over the spatial dims.
                channel_noise = torch.rand(
                    image.shape[0]).unsqueeze(1).unsqueeze(2)
                channel_noise = torch.repeat_interleave(
                    torch.repeat_interleave(channel_noise, image.shape[1], 1),
                    image.shape[2], 2)
                image = image + channel_noise * std_noise

            if random.random() < 0.5:
                # add noise
                noise = torch.rand(image.shape[0], image.shape[1],
                                   image.shape[2]) * std_noise
                image = image + noise

        if intensity >= 2:
            # channel shuffle
            if random.random() < 0.5:
                idxs = np.arange(image.shape[0])
                np.random.shuffle(idxs)  # random band indices
                image = image[idxs]

        mask = mask.squeeze(0)
        return image, mask
Ejemplo n.º 23
0
    def __call__(self, image):
        """Rotate ``image`` by a random multiple of 90 degrees.

        One ``k`` in 0-3 is drawn with ``random.randint`` and applied with
        ``torch.rot90`` over dims 1 and 2. If ``image`` is exactly a ``list``,
        every element is rotated by the same ``k`` in place and the list is
        returned.
        """
        quarter_turns = random.randint(0, 3)
        if type(image) is list:
            for idx, item in enumerate(image):
                image[idx] = torch.rot90(item, quarter_turns, [1, 2])
            return image
        return torch.rot90(image, quarter_turns, [1, 2])
def revert_tta_factory(flip, rot):
    """Build a function that applies a flip and/or rotation, e.g. to revert a TTA view.

    Args:
        flip: dims passed to ``Tensor.flip`` (falsy = no flip).
        rot: quarter turns for ``torch.rot90`` over dims (3, 4)
            (falsy = no rotation). When both are set, the flip is applied
            before the rotation.

    Raises:
        ValueError: if neither ``flip`` nor ``rot`` is set. The previous bare
            ``raise`` either failed with "No active exception to reraise" or,
            when called inside an ``except`` block, re-raised an unrelated
            exception.
    """
    if not flip and not rot:
        raise ValueError("revert_tta_factory needs a flip and/or a rotation")
    if flip and rot:
        return lambda x: torch.rot90(x.flip(flip), rot, dims=(3, 4))
    if flip:
        return lambda x: x.flip(flip)
    return lambda x: torch.rot90(x, rot, dims=(3, 4))
Ejemplo n.º 25
0
def val_pred(MaskCN, EnhanceSN, image, coarsemask):
    """Predict with 6-view test-time augmentation and merge the results.

    ``image`` and ``coarsemask`` are each expanded to six views along dim 0:
    identity, rotations by 90/180/270 degrees in dims 2-3, a flip of the
    last dim, and a flip of dim -2.
    - ``cam(MaskCN, ...)`` is computed on the channel-wise concatenation of
      image and coarse-mask views.
    - ``EnhanceSN`` (switched to eval mode) predicts from the image views
      and that CAM.
    - Each view's prediction is mapped back with the inverse transform, and
      the six results are SUMMED (not averaged).

    NOTE(review): only ``EnhanceSN`` is put into eval mode here, not ``MaskCN``.

    Args:
        image, coarsemask: batches with the same leading size ``N``;
            presumably NCHW with H == W (rotated views are concatenated).
    Returns:
        Tensor with ``N`` rows: the sum of the de-augmented predictions.
    """
    n = image.shape[0]

    def six_views(t):
        # identity, +90, +180, +270 degrees (dims 2-3), flip dim -1, flip dim -2
        return torch.cat([t,
                          torch.rot90(t, 1, [2, 3]),
                          torch.rot90(t, 2, [2, 3]),
                          torch.rot90(t, 3, [2, 3]),
                          torch.flip(t, [-1]),
                          torch.flip(t, [-2])], dim=0)

    image = six_views(image)
    coarsemask = six_views(coarsemask)

    EnhanceSN.eval()
    with torch.no_grad():
        data_cla = torch.cat((image, coarsemask), dim=1)
        cla_cam = cam(MaskCN, data_cla)
        # Follow the input's device instead of hard-coding ``.cuda()``.
        cla_cam = torch.from_numpy(np.stack(cla_cam)).unsqueeze(1).to(image.device)
        pred = EnhanceSN(image, cla_cam)

    # ``pred`` holds six consecutive chunks of n rows, one per view. Slicing
    # single rows only worked for n == 1.
    views = pred.split(n, dim=0)
    pred = (views[0]
            + torch.rot90(views[1], 3, [2, 3])
            + torch.rot90(views[2], 2, [2, 3])
            + torch.rot90(views[3], 1, [2, 3])
            + torch.flip(views[4], [-1])
            + torch.flip(views[5], [-2]))

    return pred
Ejemplo n.º 26
0
    def rotx4_forward(self, lq):
        """Rotation self-ensemble: average generator outputs over 0/90/180/270 degrees.

        For k = 0..3 the input is rotated k quarter turns with ``torch.rot90``
        over dims [-1, -2] and passed through ``self.generator``. The output
        is rotated back with dims [-2, -1] (the opposite direction).

        Args:
            lq (Tensor): input whose last two dims are rotated.

        Returns:
            output (Tensor): mean of the four de-rotated generator outputs.
        """
        total = None
        for k in range(4):
            if k == 0:
                # normal
                restored = self.generator(lq)
            else:
                # counter-clockwise k * 90 degrees, then rotate back
                rotated_in = torch.rot90(lq, k, [-1, -2])
                restored = torch.rot90(self.generator(rotated_in), k, [-2, -1])
            total = restored if total is None else total + restored
        return total / 4
Ejemplo n.º 27
0
    def forward(self, x, idx_scale):
        """Run the wrapped model; at eval time average several predictions.

        Training: returns ``self.model(x)``, via ``P.data_parallel`` when
        ``self.n_GPUs > 1``.

        Evaluation: loops over the entries of ``self.chop_patch_size`` and
        averages one prediction per entry.
        - With ``self.self_ensemble``, each prediction comes from
          ``self.forward_x8``.
        - Otherwise iteration ``i`` feeds the input rotated by ``i`` quarter
          turns (dims 2-3) and rotates the output back by ``-i``.
        - When ``self.chop`` is set, ``self.forward_chop`` is used with the
          i-th ``shave_size`` and ``chop_patch_size[i] ** 2`` as
          ``min_size``.

        NOTE(review): the rotation count equals the index into
        ``chop_patch_size`` (i >= 4 wraps around). An empty
        ``chop_patch_size`` leaves ``y`` as None, so ``y.div`` fails.
        """
        self.idx_scale = idx_scale
        # Tell multi-scale models which scale is active.
        if hasattr(self.model, 'set_scale'):
            self.model.set_scale(idx_scale)

        if self.training:
            if self.n_GPUs > 1:
                return P.data_parallel(self.model, x, range(self.n_GPUs))
            else:
                return self.model(x)
        else:
            if self.chop:
                forward_function = self.forward_chop
            else:
                forward_function = self.model.forward

            # Running sum of one prediction per chop_patch_size entry.
            y = None
            for i in range(len(self.chop_patch_size)):
                if self.self_ensemble:
                    if y is None:
                        y = self.forward_x8(x,
                                            forward_function=forward_function,
                                            chop=self.chop,
                                            min_size=self.chop_patch_size[i] *
                                            self.chop_patch_size[i],
                                            shave_size=self.shave_size[i])
                    else:
                        y += self.forward_x8(x,
                                             forward_function=forward_function,
                                             chop=self.chop,
                                             min_size=self.chop_patch_size[i] *
                                             self.chop_patch_size[i],
                                             shave_size=self.shave_size[i])
                else:
                    # Iteration i uses the input rotated by i quarter turns.
                    rot_x = torch.rot90(x, i, [2, 3])
                    if y is None:
                        # Normally reached only for i == 0 (no rotation), so
                        # no back-rotation is applied here.
                        if self.chop:
                            y = forward_function(
                                rot_x,
                                shave=self.shave_size[i],
                                min_size=self.chop_patch_size[i] *
                                self.chop_patch_size[i])
                        else:
                            y = forward_function(rot_x)
                    else:
                        if self.chop:
                            rot_y = forward_function(
                                rot_x,
                                shave=self.shave_size[i],
                                min_size=self.chop_patch_size[i] *
                                self.chop_patch_size[i])
                        else:
                            rot_y = forward_function(rot_x)
                        # Undo the input rotation before accumulating.
                        y += torch.rot90(rot_y, -i, [2, 3])
            # Mean over the accumulated predictions.
            return y.div(float(len(self.chop_patch_size)))
Ejemplo n.º 28
0
def apply_tta(input):
    """Return 8 test-time-augmented views of an image batch.

    The views are, in order:
    - the input itself;
    - a flip of dim 2;
    - a flip of dim 3;
    - rotations by k = 1, 2, 3 quarter turns over dims 2-3;
    - the dim-2 flip rotated by k = 1 and by k = 3.
    """
    flipped = torch.flip(input, dims=[2])
    rotations = [torch.rot90(input, k=k, dims=[2, 3]) for k in (1, 2, 3)]
    flipped_rotations = [torch.rot90(flipped, k=k, dims=[2, 3]) for k in (1, 3)]
    return [input, flipped, torch.flip(input, dims=[3]),
            *rotations, *flipped_rotations]
Ejemplo n.º 29
0
    def forward(self, input):
        """Run ``self.net`` on the object TSDF placed between the two finger TSDFs.

        ``input`` must be 5-D. Channels 0, 1 and 2 are taken as the grasp
        object, left finger and right finger volumes. Both fingers are
        rotated by 180 degrees (left in dims 3-4, right in dims 2-4). The
        three single-channel volumes are concatenated along dim 2 in the
        order left, object, right.
        """
        assert len(input.shape) == 5
        grasp_object, left_finger, right_finger = (
            input[:, channel].unsqueeze(dim=1) for channel in range(3))
        stacked = cat([rot90(left_finger, 2, (3, 4)),
                       grasp_object,
                       rot90(right_finger, 2, (2, 4))], dim=2)
        return self.net(stacked)
Ejemplo n.º 30
0
 def before_batch(self):
     """Rotate every sample by its own random multiple of 90 degrees.

     One ``k`` in 0-3 is drawn per sample with ``np.random.randint``. The same
     rotation is applied to all channels of that sample in both the input
     ``self.xb[0]`` and the target ``self.yb[0]``. It acts on dims 1 and 2
     of each sample, the two axes after the channel axis. The rotated copies
     are stored in ``self.learn.xb`` / ``self.learn.yb``; the original
     tensors are not modified.
     """
     x = self.xb[0].clone()
     y = self.yb[0].clone()
     quarter_turns = np.random.randint(0, 4, x.shape[0])
     for i, k in enumerate(quarter_turns):
         # Rotate all channels. Previously only channels 0 and 1 were
         # rotated, so any further channels went out of spatial alignment.
         x[i] = torch.rot90(x[i], int(k), [1, 2])
         y[i] = torch.rot90(y[i], int(k), [1, 2])
     self.learn.xb = [x]
     self.learn.yb = [y]