Beispiel #1
0
    def get_transform_val(self, size):
        """Build the evaluation-time preprocessing pipeline for ``size`` inputs.

        Both variants first apply ``transforms.Resize(int(size[0] * 1.14))``.

        * When ``self.crop`` is ``'five'`` or ``'multi'`` the image is then
          split with ``transforms.FiveCrop(size)``. Every crop is converted
          with ``ToTensor`` and normalized, and the five crops are stacked
          into one tensor along a new leading dimension.
        * Otherwise a single ``CenterCrop(size)`` is taken, converted with
          ``ToTensor`` and normalized.

        Args:
            size: crop size passed to FiveCrop/CenterCrop; ``size[0]`` also
                drives the initial resize.

        Returns:
            A ``transforms.Compose`` pipeline.
        """
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        steps = [transforms.Resize(int(size[0] * 1.14))]

        if self.crop not in ('five', 'multi'):
            # Single-view evaluation: one centered crop.
            steps += [
                transforms.CenterCrop(size),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
            return transforms.Compose(steps)

        # Multi-view evaluation: FiveCrop yields a tuple of crops, so the
        # tensor conversion and normalization are mapped over each crop.
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(mean, std)
        steps += [
            transforms.FiveCrop(size),
            transforms.Lambda(
                lambda crops: torch.stack([to_tensor(view) for view in crops])),
            transforms.Lambda(
                lambda crops: torch.stack([normalize(view) for view in crops])),
        ]
        return transforms.Compose(steps)
Beispiel #2
0
def train(args):
    """Write five-crop (LR, HR) patch pairs for every image to an HDF5 file.

    Processing per image:

    * Each file matched by ``<args.images_dir>/*`` (sorted) is opened as RGB.
    * It is split with ``transforms.FiveCrop`` into five crops of half its
      height and width.
    * Each crop is bicubically resized so both sides are multiples of
      ``args.scale``; this is the HR patch.
    * The HR patch is then bicubically downscaled by ``args.scale``; this is
      the LR patch.

    Patch number ``k`` is stored as dataset ``str(k)`` in the ``lr`` and ``hr``
    groups of ``args.output_path``, which is opened in ``'w'`` mode.

    Args:
        args: namespace providing ``output_path``, ``images_dir`` and ``scale``.
    """
    # The context manager closes the HDF5 file even if an image fails to
    # load part-way through, instead of leaving the handle open.
    with h5py.File(args.output_path, 'w') as h5_file:
        lr_group = h5_file.create_group('lr')
        hr_group = h5_file.create_group('hr')

        image_list = sorted(glob.glob('{}/*'.format(args.images_dir)))
        patch_idx = 0

        for i, image_path in enumerate(image_list):
            # convert() returns an independent, fully loaded image, so the
            # source file handle can be released right away.
            with pil_image.open(image_path) as source:
                image = source.convert('RGB')

            crops = transforms.FiveCrop(size=(image.height // 2,
                                              image.width // 2))(image)
            for crop in crops:
                # Snap the HR patch to a multiple of the scale so that the
                # LR patch is exactly 1/scale of it.
                hr = crop.resize(((crop.width // args.scale) * args.scale,
                                  (crop.height // args.scale) * args.scale),
                                 resample=pil_image.BICUBIC)
                lr = hr.resize((hr.width // args.scale,
                                hr.height // args.scale),
                               resample=pil_image.BICUBIC)

                lr_group.create_dataset(str(patch_idx), data=np.array(lr))
                hr_group.create_dataset(str(patch_idx), data=np.array(hr))
                patch_idx += 1

            print(i, patch_idx, image_path)
Beispiel #3
0
def train(args):
    """Write five-crop (LR, HR) patch pairs from paired HR/LR folders to HDF5.

    Inputs and processing:

    * HR images are the files matched by ``<args.images_dir>//HR/*`` (sorted).
    * The LR counterpart of each HR image is the file with the same name in
      ``<args.images_dir>/LR``.
    * Both images are opened as RGB and split with ``transforms.FiveCrop``
      into five crops of half their own height and width.
    * The i-th HR crop is paired with the i-th LR crop.

    Patch number ``k`` is stored as dataset ``str(k)`` in the ``lr`` and ``hr``
    groups of ``args.output_path``, which is opened in ``'w'`` mode.

    Args:
        args: namespace providing ``output_path`` and ``images_dir``.
    """
    import os

    with h5py.File(args.output_path, 'w') as h5_file:
        lr_group = h5_file.create_group('lr')
        hr_group = h5_file.create_group('hr')

        image_list = sorted(
            glob.glob('{}/*'.format(args.images_dir + "//HR")))
        patch_idx = 0

        for i, image_path in enumerate(image_list):
            # Same file name inside the sibling "LR" directory. The previous
            # fixed-width slicing only worked for 9-character file names.
            lr_path = os.path.join(args.images_dir, "LR",
                                   os.path.basename(image_path))
            with pil_image.open(image_path) as source:
                hr = source.convert('RGB')
            with pil_image.open(lr_path) as source:
                lr = source.convert('RGB')

            # Crop while still PIL images: FiveCrop needs .height/.width,
            # which numpy arrays do not have.
            # NOTE(review): assumes LR is a uniformly downscaled copy of HR,
            # so half-size five-crops of each cover the same regions — confirm.
            hr_crops = transforms.FiveCrop(size=(hr.height // 2,
                                                 hr.width // 2))(hr)
            lr_crops = transforms.FiveCrop(size=(lr.height // 2,
                                                 lr.width // 2))(lr)

            for hr_crop, lr_crop in zip(hr_crops, lr_crops):
                lr_group.create_dataset(str(patch_idx), data=np.array(lr_crop))
                hr_group.create_dataset(str(patch_idx), data=np.array(hr_crop))
                patch_idx += 1

            print(i, patch_idx, image_path)
Beispiel #4
0
    def __call__(self, img):
        """Return multi-scale five-crop views of a ``(C, H, W)`` tensor image.

        For each ``size`` in ``self.sizes``, with ``scale = size / min(H, W)``:

        * the image is resized to ``(int(H * scale * 256/224),
          int(W * scale * 256/224))``;
        * it is then five-cropped to ``(int(H * scale), int(W * scale))``;
        * the five crops (each with a leading batch dim of 1) are concatenated
          along dim 0.

        Args:
            img: tensor whose ``size()`` unpacks to three values (C, H, W).

        Returns:
            A list with one tensor per entry of ``self.sizes``.
        """
        _, h, w = img.size()
        short_side = min(h, w)
        ratio = 256 / 224  # resize margin before cropping

        views = []
        # Resize the source image directly for every scale. The old
        # img.repeat(3, ...) copy gave identical slices and raised IndexError
        # once self.sizes had more than three entries.
        for size in self.sizes:
            scale = size / short_side
            resized = transforms.Resize(
                (int(h * scale * ratio), int(w * scale * ratio)))(img)
            crops = transforms.FiveCrop(
                (int(h * scale), int(w * scale)))(resized.unsqueeze(0))
            views.append(torch.cat(crops))

        return views
Beispiel #5
0
    def __init__(
        self,
        epoch,
        dataset_path='./drive/My Drive/datasets/car classification/train_dataset',
        val_path='./drive/My Drive/datasets/car classification/val_data',
        batch_size=128,
        model_name='tf_efficientnet_b0_ns',
        ckpt_path='./drive/My Drive/ckpt/190.pth',
        test_number=5000,
        pseudo_test=True,
        crop='five',
        csv_path='',
        mode='fix',
        sizes=(680, 600, 528)):
        """Set up preprocessing, (pseudo-)test loaders, the model and labels.

        Args:
            epoch: stored as ``self.epoch``.
            dataset_path: ``ImageFolder`` root for the pseudo-test loader(s).
            val_path: directory whose sub-directory names (sorted) become
                ``self.label_texts``.
            batch_size: ``DataLoader`` batch size.
            model_name: key into the supported-model table below; selects
                ``self.input_size`` and is passed to ``create_model``.
            ckpt_path: checkpoint file loaded into the model.
            test_number: stored as ``self.test_number``.
            pseudo_test: when True, build ``self.dataset`` /
                ``self.dataloader``.
            crop: ``'center'``, ``'five'`` or ``'multi'``.
            csv_path: stored as ``self.csv_path``.
            mode: ``'fix'`` loads ``ckpt['model']``; anything else loads
                ``ckpt['model_state_dict']``.
            sizes: per-scale square sizes used when ``crop == 'multi'``.

        Raises:
            ValueError: ``model_name`` is not supported (subclass of
                ``Exception``, so existing handlers still catch it).
        """
        import math

        self.epoch = epoch
        self.dataset_path = dataset_path
        self.val_path = val_path
        self.batch_size = batch_size
        self.model_name = model_name
        self.ckpt_path = ckpt_path
        self.test_number = test_number
        self.pseudo_test = pseudo_test
        self.crop = crop
        self.csv_path = csv_path
        self.mode = mode
        self.sizes = sizes

        # Network input resolution for each supported backbone.
        input_sizes = {
            'tf_efficientnet_b0_ns': (224, 224),
            'tf_efficientnet_b3_ns': (300, 300),
            'tf_efficientnet_b4_ns': (480, 480),
            'tf_efficientnet_b6_ns': (680, 680),  # NOTE: 528 was an alternative
        }
        if model_name not in input_sizes:
            raise ValueError('non-valid model name')
        self.input_size = input_sizes[model_name]

        # The original comparisons paired input_size[0] with image width and
        # input_size[1] with image height; keep that pairing.
        input_w, input_h = self.input_size

        def fill(img):
            """Upscale ``img`` by the smallest power of two that makes both
            sides at least the input size; return it unchanged otherwise."""
            # Assumes a PIL image, whose .size is (width, height).
            width, height = img.size
            if width >= input_w and height >= input_h:
                return img
            # Use the larger ratio so *both* sides reach the target (the old
            # code used only the height ratio). The factor is an int >= 2, so
            # Resize receives plain ints rather than float tensors.
            factor = 2 ** math.ceil(
                math.log2(max(input_w / width, input_h / height)))
            # Resize takes (height, width).
            return transforms.Resize((height * factor, width * factor))(img)

        # Compose transforms
        transform = []
        if crop == 'center':
            transform.append(transforms.CenterCrop(self.input_size[0]))
            transform.append(transforms.ToTensor())
            transform.append(
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]))
        elif crop == 'five':
            # Ensure the image is large enough before FiveCrop.
            transform.append(transforms.Lambda(fill))
            transform.append(transforms.FiveCrop(self.input_size[0]))
            transform.append(
                transforms.Lambda(lambda crops: torch.stack(
                    [transforms.ToTensor()(crop) for crop in crops])))
            transform.append(
                transforms.Lambda(lambda crops: torch.stack([
                    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
                    (crop) for crop in crops
                ])))
        # NOTE(review): crop == 'multi' leaves this transform empty.
        self.transform = transforms.Compose(transform)

        if self.pseudo_test:
            if crop == 'multi':
                # One pipeline, dataset and loader per scale in self.sizes.
                self.transform_val = [
                    self.get_transform_val((size, size)) for size in self.sizes
                ]
                self.dataset = [
                    ImageFolder(self.dataset_path, transform=tf)
                    for tf in self.transform_val
                ]
                self.dataloader = [
                    DataLoader(ds,
                               batch_size=self.batch_size,
                               num_workers=1,
                               shuffle=False) for ds in self.dataset
                ]
            else:
                # __init__ never assigned self.transform_val on this path, so
                # reading it raised AttributeError unless the class defines it.
                # Build the single-scale pipeline the same way as 'multi' does.
                if getattr(self, 'transform_val', None) is None:
                    self.transform_val = self.get_transform_val(
                        self.input_size)
                self.dataset = ImageFolder(self.dataset_path,
                                           transform=self.transform_val)
                self.dataloader = DataLoader(self.dataset,
                                             batch_size=self.batch_size,
                                             num_workers=1,
                                             shuffle=False)

        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.model = create_model(model_name, num_classes=196).to(self.device)
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        ckpt = torch.load(self.ckpt_path, map_location=self.device)
        state_key = 'model' if self.mode == 'fix' else 'model_state_dict'
        self.model.load_state_dict(ckpt[state_key])
        self.start_epoch = 0

        with os.scandir(self.val_path) as entries:
            label_texts = sorted(
                entry.name for entry in entries if entry.is_dir())
        # '/' cannot appear in a directory name, so this class folder is
        # stored without it; restore the real label text.
        label_texts[label_texts.index(
            'Ram CV Cargo Van Minivan 2012')] = 'Ram C/V Cargo Van Minivan 2012'
        self.label_texts = label_texts
Beispiel #6
0
 def FiveCrop(self, *args, **kwargs):
     """Wrap ``transforms.FiveCrop(*args, **kwargs)`` and pass it to ``self._add``.

     Positional arguments are now forwarded too, so ``FiveCrop(224)`` works
     as well as ``FiveCrop(size=224)``.

     Returns:
         Whatever ``self._add`` returns.
     """
     return self._add(transforms.FiveCrop(*args, **kwargs))
Beispiel #7
0
    def __init__(self, gpu=False):
        """Load three models and their preprocessing pipelines.

        Models:
            * A DenseNet121 over the 14 labels in ``CLASS_NAMES``
              (``self.model_dense`` / ``self.transform_dense``).
            * An EfficientNet ``efficientnet-b0`` with 2 output classes
              (``self.model_eff`` / ``self.transform_eff``).
            * A ResNet-18 with a 2-way head that detects X-ray images among
              random images (``self.model_resnet`` /
              ``self.transform_resnet``).

        All weight files are read from the directory containing this module,
        and every model is left in eval mode.

        Args:
            gpu: if True, enable cuDNN benchmark mode and use ``cuda:0``. In
                that case the DenseNet is wrapped in ``DataParallel`` and
                loaded from ``gpu_weight.pth``. Otherwise the CPU is used and
                ``cpu_weight.pth`` is loaded with ``map_location=device``. The
                chosen device is stored as ``self.device``.
        """
        # Weight files live next to this source file.
        models_directory = os.path.dirname(os.path.abspath(__file__))
        # DENSENET: 14-label classifier.
        self.N_CLASSES = 14
        self.CLASS_NAMES = [
            'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass',
            'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation', 'Edema',
            'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia'
        ]
        if gpu:
            import torch.backends.cudnn as cudnn
            # Let cuDNN pick the fastest convolution algorithms.
            cudnn.benchmark = True
            device = torch.device("cuda:0")
        else:
            device = torch.device("cpu")

        # initialize and load the model
        model_dense = DenseNet121(self.N_CLASSES).to(device).eval()
        if gpu:
            # Wrapped before loading. Presumably gpu_weight.pth was saved from
            # a DataParallel model (keys prefixed with 'module.') — TODO confirm.
            model_dense = torch.nn.DataParallel(model_dense).to(device).eval()
            checkpoint = torch.load(
                os.path.join(models_directory, "gpu_weight.pth"))
        else:
            checkpoint = torch.load(os.path.join(models_directory,
                                                 "cpu_weight.pth"),
                                    map_location=device)

        model_dense.load_state_dict(checkpoint)

        self.normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                              [0.229, 0.224, 0.225])
        # Resize, take five 224x224 crops, then convert and normalize each
        # crop. The result is a stack of five crop tensors.
        self.transform_dense = transforms.Compose([
            transforms.Resize(256),
            transforms.FiveCrop(224),
            transforms.Lambda(lambda crops: torch.stack(
                [transforms.ToTensor()(crop) for crop in crops])),
            transforms.Lambda(lambda crops: torch.stack(
                [self.normalize(crop) for crop in crops]))
        ])

        self.model_dense = model_dense.to(device).eval()
        self.device = device

        # EFFNET: efficientnet-b0 built with a 2-class head, weights from
        # effnet_weight.pth.
        model_eff = EfficientNet.from_name(model_name="efficientnet-b0",
                                           params=[1.0, 1.0, 224, 0.2],
                                           override_params={'num_classes': 2})
        state_dict = torch.load(os.path.join(models_directory,
                                             "effnet_weight.pth"),
                                map_location=device)
        model_eff.load_state_dict(state_dict)

        self.model_eff = model_eff.to(device).eval()

        # Single view (no crop); Grayscale(3) replicates the gray channel
        # into three channels.
        self.transform_eff = transforms.Compose([
            transforms.Resize(224),
            transforms.Grayscale(3),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        # ResNet-18 with a 2-way head, used to tell X-ray images apart from
        # random (non-X-ray) images.
        resnet_state_dict = torch.load(os.path.join(models_directory,
                                                    "nonxray.pth"),
                                       map_location=device)
        model_resnet = resnet18()
        # Replace the final layer with a 2-output head before loading weights.
        model_resnet.fc = torch.nn.Linear(model_resnet.fc.in_features, 2)
        model_resnet.load_state_dict(resnet_state_dict)

        self.model_resnet = model_resnet.to(device).eval()

        # Small input: resize to 100, center-crop 64x64.
        self.transform_resnet = transforms.Compose([
            transforms.Resize(100),
            transforms.CenterCrop(64),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])