def _find_samples_in_subfolders(self, dir):
     """
     Finds the class folders in a dataset.
     Args:
         dir (string): Root directory path.
     Returns:
         tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
     Ensures:
         No class is a subdirectory of another.
     """
     if sys.version_info >= (3, 5):
         # Faster and available in Python 3.5 and above
         classes = [d.name for d in os.scandir(dir) if d.is_dir()]
     else:
         classes = [
             d for d in os.listdir(dir)
             if os.path.isdir(os.path.join(dir, d))
         ]
     classes.sort()
     class_to_idx = {classes[i]: i for i in range(len(classes))}
     samples = []
     for target in sorted(class_to_idx.keys()):
         d = os.path.join(dir, target)
         if not os.path.isdir(d):
             continue
         for root, _, fnames in sorted(os.walk(d)):
             for fname in sorted(fnames):
                 if is_image_file(fname):
                     path = os.path.join(root, fname)
                     # item = (path, class_to_idx[target])
                     # samples.append(item)
                     samples.append(path)
     return samples
Пример #2
0
 def __init__(self, TYPE='bottle', isTrain='train'):
     """Index the image files of one split of an MVTec category.

     Args:
         TYPE: Category folder name inserted into the hard-coded dataset paths.
         isTrain: 'train' selects the training folder, 'val' the validation
             folder; any other value selects the test folder.
     """
     root = '/root/AFS/Corn/AEGAN/MVTec/'
     self.gt_path = root + TYPE + '/ground_truth_resize/all'
     self.train_path = root + TYPE + '/train_resize/train'
     self.val_path = root + TYPE + '/train_resize/validation'
     self.test_path = root + TYPE + '/test_resize/all'

     # Pick the folder that backs this dataset instance.
     if isTrain == 'train':
         self.data_path = self.train_path
     elif isTrain == 'val':
         self.data_path = self.val_path
     else:
         self.data_path = self.test_path

     # Samples are bare file names (not joined with data_path), images only.
     image_names = []
     for entry in os.listdir(self.data_path):
         if is_image_file(entry):
             image_names.append(entry)
     self.samples = image_names
     self.isTrain = isTrain
Пример #3
0
    def __init__(self, TYPE='bottle', isTrain=True):
        """Index the image files of the train or test split of an MVTec category.

        Args:
            TYPE: Category folder name inserted into the relative dataset paths.
            isTrain: Truthy selects the training folder, falsy the test folder.
        """
        root = '../MVTec/'
        self.gt_path = root + TYPE + '/ground_truth_resize/all'
        self.train_path = root + TYPE + '/train_resize/'
        self.test_path = root + TYPE + '/test_resize/all'

        if isTrain:
            self.data_path = self.train_path
        else:
            self.data_path = self.test_path

        # Samples are bare file names (not joined with data_path), images only.
        image_names = []
        for entry in os.listdir(self.data_path):
            if is_image_file(entry):
                image_names.append(entry)
        self.samples = image_names
        self.isTrain = isTrain
Пример #4
0
 def __init__(self, data_path, image_shape, with_subfolder=False, random_crop=True, return_name=False):
     """Build the sample list for an image folder.

     Args:
         data_path: Directory that holds the images (or sub-folders of images).
         image_shape: Sequence whose last element is dropped before storing.
         with_subfolder: When true, samples come from
             ``_find_samples_in_subfolders``; otherwise they are the image
             file names found directly in ``data_path``.
         random_crop: Stored as-is on the instance.
         return_name: Stored as-is on the instance.
     """
     super(Dataset, self).__init__()

     # Gather the samples first, exactly as before any attribute is stored.
     self.samples = (
         self._find_samples_in_subfolders(data_path)
         if with_subfolder
         else [name for name in listdir(data_path) if is_image_file(name)]
     )

     self.with_subfolder = with_subfolder
     self.data_path = data_path
     # Keep everything but the trailing entry of the configured shape.
     self.image_shape = image_shape[:-1]
     self.random_crop = random_crop
     self.return_name = return_name
    def __init__(self,
                 data_path,
                 image_shape,
                 with_subfolder=False,
                 random_crop=True,
                 return_name=False):
        """Build the sample list for an image folder, storing joined paths.

        Args:
            data_path: Directory that holds the images (or sub-folders of images).
            image_shape: Sequence whose last element is dropped before storing.
            with_subfolder: When true, samples come from
                ``_find_samples_in_subfolders``; otherwise every image file
                name in ``data_path`` is joined with ``data_path``.
            random_crop: Stored as-is on the instance.
            return_name: Stored as-is on the instance.
        """
        super(Dataset, self).__init__()
        self.data_path = data_path

        if not with_subfolder:
            # Only file names are inspected here; non-image entries are dropped.
            image_paths = []
            for name in listdir(data_path):
                if is_image_file(name):
                    image_paths.append(os.path.join(self.data_path, name))
            self.samples = image_paths
        else:
            self.samples = self._find_samples_in_subfolders(self.data_path)

        # Keep everything but the trailing entry of the configured shape.
        self.image_shape = image_shape[:-1]
        self.random_crop = random_crop
        self.return_name = return_name
Пример #6
0
    def __init__(self,
                 data_path,
                 image_shape,
                 with_subfolder=False,
                 random_crop=True,
                 return_name=False):
        """Build the sample list for an image folder and report what was found.

        Args:
            data_path: Directory that holds the images (or sub-folders of images).
            image_shape: Sequence whose last element is dropped before storing.
            with_subfolder: When true, samples come from
                ``_find_samples_in_subfolders``; otherwise they are the image
                file names found directly in ``data_path``.
            random_crop: Stored as-is on the instance.
            return_name: Stored as-is on the instance.
        """
        super(Dataset, self).__init__()

        # Echo the inputs so a misconfigured path is visible in the log.
        print(f"data_path: {data_path}")
        print(f"with_subfolder: {with_subfolder}")

        self.samples = (
            self._find_samples_in_subfolders(data_path)
            if with_subfolder
            else [name for name in listdir(data_path) if is_image_file(name)]
        )
        print(f"Found files: {len(self.samples)}")

        self.data_path = data_path
        # Keep everything but the trailing entry of the configured shape.
        self.image_shape = image_shape[:-1]
        self.random_crop = random_crop
        self.return_name = return_name
Пример #7
0
    def __init__(self,
                 data_path,
                 gt_path,
                 image_shape,
                 with_subfolder=False,
                 random_crop=True,
                 return_name=False):
        """Build the sample list from the ground-truth folder.

        Args:
            data_path: Directory of the input images; stored, not scanned here.
            gt_path: Directory of the ground-truth images; this is the folder
                that is scanned to build ``self.samples``.
            image_shape: Sequence whose last element is dropped before storing.
            with_subfolder: When true, samples come from
                ``_find_samples_in_subfolders``; otherwise they are the image
                file names found directly in ``gt_path``.
            random_crop: Stored as-is on the instance.
            return_name: Stored as-is on the instance.
        """
        super(Parse_Dataset, self).__init__()

        # The sample list is driven by gt_path rather than data_path: gt has
        # fewer images, so the sample count is fixed to the gt images.
        self.samples = (
            self._find_samples_in_subfolders(gt_path)
            if with_subfolder
            else [name for name in listdir(gt_path) if is_image_file(name)]
        )

        self.data_path = data_path
        self.gt_path = gt_path
        # Keep everything but the trailing entry of the configured shape.
        self.image_shape = image_shape[:-1]
        self.random_crop = random_crop
        self.return_name = return_name
        print(str(len(self.samples)) + "  items found")

        # --------- SEGMENTATION STATS
        self.n_classes = 17
Пример #8
0
def generate(img, img_mask_path, model_path):
    """Inpaint a single image and return the converted network output.

    Relies on the module-level ``config``, ``cuda`` and ``device_ids``.

    Args:
        img: The image to inpaint, either an array accepted by
            ``Image.fromarray`` or an already-built PIL image.
        img_mask_path: Path of a mask image. When falsy, a random box mask is
            generated with ``random_bbox`` / ``mask_image`` instead.
        model_path: Checkpoint directory. When falsy, the directory is derived
            from ``config`` (``checkpoints/<dataset_name>/<mask_type>_<expname>``).

    Returns:
        Whatever ``from_torch_img_to_numpy`` returns for the composited result
        (network output inside the mask, input outside of it).

    Raises:
        TypeError: If ``img_mask_path`` is given but is not an image file.
    """
    # Reject a bad mask path before doing any image work.
    if img_mask_path and not is_image_file(img_mask_path):
        raise TypeError("{} is not an image file.".format(img_mask_path))

    with torch.no_grad():   # enter no grad context
        image_size = config['image_shape'][:-1]

        # Both branches run the same Resize/CenterCrop/ToTensor pipeline, so
        # convert array input to a PIL image once for both of them (the
        # random-mask branch previously passed the raw input straight in).
        if not isinstance(img, Image.Image):
            img = Image.fromarray(img)
        img = transforms.Resize(image_size)(img)
        img = transforms.CenterCrop(image_size)(img)
        img = normalize(transforms.ToTensor()(img))

        if img_mask_path:
            # Test a single masked image with a given mask
            mask = default_loader(img_mask_path)
            mask = transforms.Resize(image_size)(mask)
            mask = transforms.CenterCrop(image_size)(mask)
            # Only the first channel of the mask image is used.
            mask = transforms.ToTensor()(mask)[0].unsqueeze(dim=0)
            x = (img * (1. - mask)).unsqueeze(dim=0)
            mask = mask.unsqueeze(dim=0)
        else:
            # Test a single ground-truth image with a random mask
            ground_truth = img.unsqueeze(dim=0)
            bboxes = random_bbox(config, batch_size=ground_truth.size(0))
            x, mask = mask_image(ground_truth, bboxes, config)

        # Set checkpoint path
        if model_path:
            checkpoint_path = model_path
        else:
            checkpoint_path = os.path.join('checkpoints',
                                           config['dataset_name'],
                                           config['mask_type'] + '_' + config['expname'])

        # Define the trainer
        netG = Generator(config['netG'], cuda, device_ids)
        # Resume weight; on a CPU-only run the tensors are remapped to the CPU.
        last_model_name = get_model_list(checkpoint_path, "gen", iteration=0)
        netG.load_state_dict(
            torch.load(last_model_name, map_location=None if cuda else 'cpu'))

        # presumably checkpoint names end in "<8 digits>.pt" - TODO confirm
        model_iteration = int(last_model_name[-11:-3])
        print("Resume from {} at iteration {}".format(checkpoint_path, model_iteration))

        if cuda:
            netG = nn.parallel.DataParallel(netG, device_ids=device_ids)
            x = x.cuda()
            mask = mask.cuda()

        # Inference: only the second output is composited into the result.
        _, x2, _ = netG(x, mask)
        inpainted_result = x2 * mask + x * (1. - mask)
        inpainted_result = from_torch_img_to_numpy(inpainted_result, 'output.png', padding=0, normalize=True)

        return inpainted_result
def main():
    """Inpaint one image given on the command line and save the result.

    Reads ``args.image`` and, optionally, ``args.mask`` (otherwise a random
    box mask is generated), restores the generator from the checkpoint
    directory and writes the composited result to ``args.output``. When
    ``args.flow`` is set, the offset flow is saved there as well.

    Raises:
        TypeError: If ``args.image`` is not an image file, or ``args.mask``
            is given but is not an image file. Any exception is printed and
            then re-raised.
    """
    args = parser.parse_args()
    config = get_config(args.config)

    # CUDA configuration
    cuda = config['cuda']
    device_ids = config['gpu_ids']
    if cuda:
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
            str(i) for i in device_ids)
        # Re-index from 0 to match the devices exposed via CUDA_VISIBLE_DEVICES.
        device_ids = list(range(len(device_ids)))
        config['gpu_ids'] = device_ids
        cudnn.benchmark = True

    print("Arguments: {}".format(args))

    # Set random seed
    if args.seed is None:
        args.seed = random.randint(1, 10000)
    print("Random seed: {}".format(args.seed))
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed_all(args.seed)

    print("Configuration: {}".format(config))

    try:  # for unexpected error logging
        with torch.no_grad():   # enter no grad context
            if not is_image_file(args.image):
                # The offending path is now actually formatted into the message.
                raise TypeError(
                    "{} is not an image file.".format(args.image))

            image_size = config['image_shape'][:-1]
            if args.mask and is_image_file(args.mask):
                # Test a single masked image with a given mask
                x = default_loader(args.image)
                mask = default_loader(args.mask)
                x = transforms.Resize(image_size)(x)
                x = transforms.CenterCrop(image_size)(x)
                mask = transforms.Resize(image_size)(mask)
                mask = transforms.CenterCrop(image_size)(mask)
                x = transforms.ToTensor()(x)
                # Only the first channel of the mask image is used.
                mask = transforms.ToTensor()(mask)[0].unsqueeze(dim=0)
                x = normalize(x)
                x = x * (1. - mask)
                x = x.unsqueeze(dim=0)
                mask = mask.unsqueeze(dim=0)
            elif args.mask:
                raise TypeError(
                    "{} is not an image file.".format(args.mask))
            else:
                # Test a single ground-truth image with a random mask
                ground_truth = default_loader(args.image)
                ground_truth = transforms.Resize(image_size)(ground_truth)
                ground_truth = transforms.CenterCrop(image_size)(ground_truth)
                ground_truth = transforms.ToTensor()(ground_truth)
                ground_truth = normalize(ground_truth)
                ground_truth = ground_truth.unsqueeze(dim=0)
                bboxes = random_bbox(
                    config, batch_size=ground_truth.size(0))
                x, mask = mask_image(ground_truth, bboxes, config)

            # Set checkpoint path
            if args.checkpoint_path:
                checkpoint_path = args.checkpoint_path
            else:
                checkpoint_path = os.path.join('checkpoints',
                                               config['dataset_name'],
                                               config['mask_type'] + '_' + config['expname'])

            # Define the trainer
            netG = Generator(config['netG'], cuda, device_ids)
            # Resume weight; on a CPU-only run the tensors are remapped to the CPU.
            last_model_name = get_model_list(
                checkpoint_path, "gen", iteration=args.iter)
            netG.load_state_dict(
                torch.load(last_model_name,
                           map_location=None if cuda else 'cpu'))
            # presumably checkpoint names end in "<8 digits>.pt" - TODO confirm
            model_iteration = int(last_model_name[-11:-3])
            print("Resume from {} at iteration {}".format(
                checkpoint_path, model_iteration))

            if cuda:
                netG = nn.parallel.DataParallel(
                    netG, device_ids=device_ids)
                x = x.cuda()
                mask = mask.cuda()

            # Inference
            _, x2, offset_flow = netG(x, mask)
            inpainted_result = x2 * mask + x * (1. - mask)

            vutils.save_image(inpainted_result, args.output,
                              padding=0, normalize=True)
            print("Saved the inpainted result to {}".format(args.output))
            if args.flow:
                vutils.save_image(offset_flow, args.flow,
                                  padding=0, normalize=True)
                print("Saved offset flow to {}".format(args.flow))
        # exit no grad context
    except Exception as e:  # for unexpected error logging
        print("Error: {}".format(e))
        raise  # bare raise keeps the original traceback intact
Пример #10
0
def main():
    """Inpaint a large image chunk by chunk and save the stitched result.

    The image ``args.image`` is split by ``ImageChunker`` into pieces of the
    configured size, the boxes listed in the ``.txt`` file next to the image
    are turned into a mask, every chunk whose mask reaches 1 is run through
    the generator, and the chunks are stitched back together and written to
    ``args.output``. When ``args.flow`` is set, the offset flow of the last
    inpainted chunk is saved there.

    Raises:
        TypeError: If ``args.image`` is not an image file. Any exception is
            printed and then re-raised.
    """
    args = parser.parse_args()
    config = get_config(args.config)

    # CUDA configuration
    cuda = config['cuda']
    device_ids = config['gpu_ids']
    if cuda:
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
            str(i) for i in device_ids)
        # Re-index from 0 to match the devices exposed via CUDA_VISIBLE_DEVICES.
        device_ids = list(range(len(device_ids)))
        config['gpu_ids'] = device_ids
        cudnn.benchmark = True

    # Set random seed
    if args.seed is None:
        args.seed = random.randint(1, 10000)
    print("Random seed: {}".format(args.seed))
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed_all(args.seed)

    chunker = ImageChunker(config['image_shape'][0], config['image_shape'][1],
                           args.overlap)
    try:  # for unexpected error logging
        with torch.no_grad():  # enter no grad context
            if not is_image_file(args.image):
                # The offending path is now actually formatted into the message.
                raise TypeError(
                    "{} is not an image file.".format(args.image))

            print("Loading image...")
            img_ori = default_loader(args.image)
            img_w, img_h = img_ori.size
            # Load mask txt file: same path as the image with the extension
            # swapped, whatever the image extension is.
            fname = os.path.splitext(args.image)[0] + '.txt'
            bboxes, _ = load_bbox_txt(fname, img_w, img_h)
            mask_ori = create_mask(bboxes, img_w, img_h)
            chunked_images = chunker.dimension_preprocess(
                np.array(deepcopy(img_ori)))
            chunked_masks = chunker.dimension_preprocess(
                np.array(deepcopy(mask_ori)))

            imgs, masks = [], []
            for (chunk, msk) in zip(chunked_images, chunked_masks):
                x = transforms.ToTensor()(chunk)
                # Only the first channel of the mask chunk is used.
                mask = transforms.ToTensor()(msk)[0].unsqueeze(dim=0)
                # x = normalize(x)
                x = x * (1. - mask)
                imgs.append(x.unsqueeze(dim=0))
                masks.append(mask.unsqueeze(dim=0))

            # Set checkpoint path
            if args.checkpoint_path:
                checkpoint_path = args.checkpoint_path
            else:
                checkpoint_path = os.path.join(
                    'checkpoints', config['dataset_name'],
                    config['mask_type'] + '_' + config['expname'])

            # Define the trainer
            netG = Generator(config['netG'], cuda, device_ids)
            # Resume weight; on a CPU-only run the tensors are remapped to the CPU.
            last_model_name = get_model_list(checkpoint_path,
                                             "gen",
                                             iteration=args.iter)
            netG.load_state_dict(
                torch.load(last_model_name,
                           map_location=None if cuda else 'cpu'))
            # presumably checkpoint names end in "<8 digits>.pt" - TODO confirm
            model_iteration = int(last_model_name[-11:-3])
            print("Resume from {} at iteration {}".format(
                checkpoint_path, model_iteration))

            # Wrap the network once, instead of re-wrapping the already
            # wrapped module on every masked chunk.
            if cuda:
                netG = nn.parallel.DataParallel(netG, device_ids=device_ids)

            pred_imgs = []
            offset_flow = None  # stays None when no chunk needed inpainting
            for (x, mask) in zip(imgs, masks):
                if torch.max(mask) == 1:
                    if cuda:
                        x = x.cuda()
                        mask = mask.cuda()

                    # Inference
                    _, x2, offset_flow = netG(x, mask)
                    inpainted_result = x2 * mask + x * (1. - mask)
                    inpainted_result = inpainted_result.squeeze(
                        dim=0).permute(1, 2, 0).cpu()
                    pred_imgs.append(inpainted_result.numpy())
                else:
                    # Chunk is passed through without running the network.
                    pred_imgs.append(
                        x.squeeze(dim=0).permute(1, 2, 0).numpy())

            pred_imgs = np.asarray(pred_imgs, dtype=np.float32)
            reconstructed_image = chunker.dimension_postprocess(
                pred_imgs, np.array(img_ori))
            # plt.imshow(reconstructed_image); plt.show()
            reconstructed_image = torch.tensor(
                reconstructed_image).permute(2, 0, 1).unsqueeze(dim=0)
            vutils.save_image(reconstructed_image,
                              args.output,
                              padding=0,
                              normalize=True)
            print("Saved the inpainted result to {}".format(args.output))
            if args.flow:
                if offset_flow is None:
                    # Previously this case crashed on an unbound variable.
                    print("No chunk was inpainted; offset flow not saved.")
                else:
                    vutils.save_image(offset_flow,
                                      args.flow,
                                      padding=0,
                                      normalize=True)
                    print("Saved offset flow to {}".format(args.flow))
        # exit no grad context
    except Exception as e:  # for unexpected error logging
        print("Error: {}".format(e))
        raise  # bare raise keeps the original traceback intact
Пример #11
0
def generateInpaintedImage(args, netG, imagePath):
    """Inpaint three horizontal crops of an image and paste them back.

    When ``args.mask`` is an image file, the image is resized to
    ``[512, 1024]``, cut into three crops along its width, each crop is
    center-cropped to the configured image shape, masked, run through
    ``netG`` and pasted onto a fresh copy of the image loaded from
    ``imagePath``. When ``args.mask`` is falsy, nothing is inpainted and the
    loaded image is returned unchanged.

    Args:
        args: Parsed arguments; ``g_config``, ``seed`` and ``mask`` are read.
        netG: Generator called as ``netG(occlusion, mask)``.
        imagePath: Path of the image to inpaint.

    Returns:
        The PIL image loaded from ``imagePath`` with the inpainted crops pasted in.

    Raises:
        TypeError: If ``imagePath`` is not an image file, or ``args.mask`` is
            given but is not an image file. Any exception is printed and
            then re-raised.
    """
    config = get_config(args.g_config)
    occlusions = []

    # CUDA configuration
    cuda = config['cuda']
    device_ids = config['gpu_ids']
    if cuda:
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
            str(i) for i in device_ids)
        # Re-index from 0 to match the devices exposed via CUDA_VISIBLE_DEVICES.
        device_ids = list(range(len(device_ids)))
        config['gpu_ids'] = device_ids
        cudnn.benchmark = True

    print("Arguments: {}".format(args))

    # Set random seed
    if args.seed is None:
        args.seed = random.randint(1, 10000)
    print("Random seed: {}".format(args.seed))
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed_all(args.seed)

    try:  # for unexpected error logging
        with torch.no_grad():  # enter no grad context
            if not is_image_file(imagePath):
                # The offending path is now actually formatted into the message.
                raise TypeError(
                    "{} is not an image file.".format(imagePath))

            image_size = config['image_shape'][:-1]
            if args.mask and is_image_file(args.mask):
                # Test a multiple masked image with a given mask
                x = default_loader(imagePath)
                x = transforms.Resize([512, 1024])(x)

                mask = default_loader(args.mask)
                mask = transforms.Resize(image_size)(mask)
                mask = transforms.CenterCrop(image_size)(mask)
                # Only the first channel of the mask image is used; two
                # leading singleton dims are added so it broadcasts below.
                mask = transforms.ToTensor()(mask)[0].unsqueeze(dim=0)
                mask = mask.unsqueeze(dim=0)

                w, h = x.size
                third = w // 3
                crops = [
                    x.crop((0, 0, third, h)),
                    # the middle crop is 2 pixels wider than the first one
                    x.crop((third, 0, (third * 2) + 2, h)),
                    x.crop((third * 2, 0, w, h)),
                ]

                for y in crops:
                    y = transforms.CenterCrop(image_size)(y)
                    y = transforms.ToTensor()(y)
                    y = normalize(y)
                    y = y * (1. - mask)
                    occlusions.append(y)

            elif args.mask:
                raise TypeError("{} is not an image file.".format(
                    args.mask))

            default_image = default_loader(imagePath)
            di_w, di_h = default_image.size

            # Paste positions for the first, second and third crop; they do
            # not change between iterations, so compute them once.
            inp_hw = config['image_shape'][1]
            top = (di_h - inp_hw) // 2
            offsets = [
                ((di_w // 3 - inp_hw) // 2, top),
                ((di_w - inp_hw) // 2, top),
                ((((di_w - inp_hw) // 2) + (di_w // 3)), top),
            ]

            for idx, occlusion in enumerate(occlusions):
                if cuda:
                    occlusion = occlusion.cuda()
                    mask = mask.cuda()

                # Inference
                _, x2, _ = netG(occlusion, mask)
                inpainted_result = x2 * mask + occlusion * (1. - mask)

                grid = vutils.make_grid(inpainted_result, normalize=True)

                # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
                ndarr = grid.mul_(255).add_(0.5).clamp_(0, 255).permute(
                    1, 2, 0).to('cpu', torch.uint8).numpy()
                im = Image.fromarray(ndarr)

                im = transforms.CenterCrop(config['mask_shape'])(im)
                im = transforms.Resize(image_size)(im)
                default_image.paste(im, offsets[idx])

            return default_image
        # exit no grad context
    except Exception as e:  # for unexpected error logging
        print("Error: {}".format(e))
        raise  # bare raise keeps the original traceback intact
def main():
    """Inpaint every test image with every mask and save the results.

    Images are the ``*.jpg`` files returned by ``dataset_files`` for
    ``args.test_root``; masks are the ``*.png`` files for ``args.mask_dir``.
    For mask number ``j`` (0-based) and image number ``i`` the result is
    written to ``<args.output_dir>output_<j+1>/output_<i>.png``. The
    generator is restored from ``<checkpoint_path>/gen.pt``.

    Raises:
        TypeError: If a listed image or mask is not an image file. Any
            exception is printed and then re-raised.
    """
    args = parser.parse_args()
    config = get_config(args.config)

    # CUDA configuration
    cuda = config['cuda']
    device_ids = config['gpu_ids']
    if cuda:
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
            str(i) for i in device_ids)
        # Re-index from 0 to match the devices exposed via CUDA_VISIBLE_DEVICES.
        device_ids = list(range(len(device_ids)))
        config['gpu_ids'] = device_ids
        cudnn.benchmark = True

    print("Arguments: {}".format(args))

    # Set random seed
    if args.seed is None:
        args.seed = random.randint(1, 10000)
    print("Random seed: {}".format(args.seed))
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed_all(args.seed)

    print("Configuration: {}".format(config))

    try:  # for unexpected error logging
        with torch.no_grad():  # enter no grad context
            image_paths = dataset_files(args.test_root, "*.jpg")
            mask_paths = dataset_files(args.mask_dir, "*.png")
            image_size = config['image_shape'][:-1]

            # The generator does not depend on the image or the mask, so it
            # is built and restored once (on first use) rather than once per
            # image/mask pair.
            netG = None

            for j, mask_path in enumerate(mask_paths):
                out_dir = args.output_dir + 'output_{}/'.format(j + 1)
                for i, image_path in enumerate(image_paths):
                    if not is_image_file(image_path):
                        # The offending path is now formatted into the message.
                        raise TypeError(
                            "{} is not an image file.".format(image_path))

                    if is_image_file(mask_path):
                        # Test a single masked image with a given mask
                        x = default_loader(image_path)
                        mask = default_loader(mask_path)
                        x = transforms.Resize(image_size)(x)
                        x = transforms.CenterCrop(image_size)(x)
                        # The mask is used at its stored size (not resized).
                        x = transforms.ToTensor()(x)
                        # Only the first channel of the mask image is used.
                        mask = transforms.ToTensor()(mask)[0].unsqueeze(
                            dim=0)
                        x = normalize(x)
                        x = x * (1. - mask)
                        x = x.unsqueeze(dim=0)
                        mask = mask.unsqueeze(dim=0)
                    elif mask_path:
                        raise TypeError("{} is not an image file.".format(
                            mask_path))
                    else:
                        # Test a single ground-truth image with a random mask
                        ground_truth = default_loader(image_path)
                        ground_truth = transforms.Resize(
                            image_size)(ground_truth)
                        ground_truth = transforms.CenterCrop(
                            image_size)(ground_truth)
                        ground_truth = transforms.ToTensor()(ground_truth)
                        ground_truth = normalize(ground_truth)
                        ground_truth = ground_truth.unsqueeze(dim=0)
                        bboxes = test_bbox(config,
                                           batch_size=ground_truth.size(0),
                                           t=50,
                                           l=50)
                        x, mask = mask_image(ground_truth, bboxes, config)

                    if netG is None:
                        # Set checkpoint path
                        if args.checkpoint_path:
                            checkpoint_path = args.checkpoint_path
                        else:
                            checkpoint_path = os.path.join(
                                'checkpoints', config['dataset_name'],
                                config['mask_type'] + '_' + config['expname'])

                        # Define the trainer
                        netG = Generator(config['netG'], cuda, device_ids)
                        # Resume weight; on a CPU-only run the tensors are
                        # remapped to the CPU.
                        g_checkpoint = torch.load(
                            f'{checkpoint_path}/gen.pt',
                            map_location=None if cuda else 'cpu')
                        netG.load_state_dict(g_checkpoint)
                        print("Model Resumed")

                        if cuda:
                            netG = nn.parallel.DataParallel(
                                netG, device_ids=device_ids)

                    if cuda:
                        x = x.cuda()
                        mask = mask.cuda()

                    # Inference: this generator returns two outputs and only
                    # the second one is composited into the result.
                    _, x2 = netG(x, mask)
                    inpainted_result = x2 * mask + x * (1. - mask)

                    # Make sure the per-mask output folder exists before saving.
                    os.makedirs(out_dir, exist_ok=True)
                    vutils.save_image(inpainted_result,
                                      out_dir + 'output_{}.png'.format(i),
                                      padding=0,
                                      normalize=True)
                    print("{}th image saved".format(i))
            # exit no grad context
    except Exception as e:  # for unexpected error logging
        print("Error: {}".format(e))
        raise  # bare raise keeps the original traceback intact
Пример #13
0
# Log the parsed command-line arguments for this run.
print("Arguments: {}".format(args))

# Set random seed
# When no seed was supplied, draw one and store it back on args so the value
# printed below is the one actually used.
if args.seed is None:
    args.seed = random.randint(1, 10000)
print("Random seed: {}".format(args.seed))
# Seed Python's RNG and torch's RNG with the same value.
random.seed(args.seed)
torch.manual_seed(args.seed)
if cuda:
    # Also seed every CUDA device when running on GPU.
    torch.cuda.manual_seed_all(args.seed)

# Log the loaded configuration.
print("Configuration: {}".format(config))

try:  # for unexpected error logging
    with torch.no_grad():  # enter no grad context
        if is_image_file(args.image):
            if args.mask and is_image_file(args.mask):
                # Test a single masked image with a given mask
                x = tif_loader(args.image)
                ## center crop
                x = x[110:366, 110:366, :]
                x = torch.from_numpy(x)
                mask = default_loader(args.mask)  ## 476 --> 256
                #x = x[110:366, 110:366,:]
                #x = transforms.Resize(config['image_shape'][:-1])(x)
                #x = transforms.CenterCrop(config['image_shape'][:-1])(x)
                mask = transforms.Resize(config['image_shape'][:-1])(mask)
                mask = transforms.CenterCrop(config['image_shape'][:-1])(mask)
                #x = transforms.ToTensor()(x)
                mask = transforms.ToTensor()(mask)[0].unsqueeze(dim=0)
                x = normalize(x)