def get_input_pair(self, data_info_row):
    """Load one image/mask pair, randomly augment it and package it.

    The file stem is ``"<name>_<position>"`` built from *data_info_row*. The
    image and mask paths come from ``get_filepath`` called with
    ``self.dataset_path``, the row's ``'name'``, ``self.images_folder`` /
    ``self.masks_folder`` and ``self.image_type`` / ``self.mask_type``.
    The training augmentation pipeline is applied unconditionally.

    Args:
        data_info_row: row-like mapping with at least ``'name'`` (used in
            ``str.join``, so it must be a ``str``) and ``'position'``.

    Returns:
        If ``self.classification_head`` is truthy: ``(image, [mask, label])``,
        where ``label`` is a shape-(1,) float tensor equal to 1.0 when the
        augmented mask sums to a positive value and 0.0 otherwise.
        Otherwise: ``{'features': image, 'targets': mask}``.

    Raises:
        Exception: if ``self.channels`` is empty.
    """
    if not len(self.channels):
        raise Exception('You have to specify at least one channel.')

    row_name = data_info_row['name']
    stem = '_'.join([row_name, str(data_info_row['position'])])

    image_file = get_filepath(self.dataset_path, row_name, self.images_folder,
                              stem, file_type=self.image_type)
    mask_file = get_filepath(self.dataset_path, row_name, self.masks_folder,
                             stem, file_type=self.mask_type)

    image = filter_by_channels(read_tensor(image_file), self.channels,
                               self.neighbours)
    # A single selected channel comes back 2-D; add a trailing channel axis
    # (presumably so the transforms always see H x W x C).
    image = np.expand_dims(image, -1) if image.ndim == 2 else image
    mask = read_tensor(mask_file)

    crop_floor = int(self.image_size * 0.7)
    pipeline = Compose([
        RandomRotate90(),
        Flip(),
        # With probability 0.8, apply exactly one of these transforms.
        OneOf([
            RandomSizedCrop(min_max_height=(crop_floor, self.image_size),
                            height=self.image_size,
                            width=self.image_size),
            RandomBrightnessContrast(brightness_limit=0.15,
                                     contrast_limit=0.15),
            MaskDropout(p=0.6),
            ElasticTransform(alpha=15, sigma=5, alpha_affine=5),
            GridDistortion(p=0.6),
        ], p=0.8),
        ToTensor(),
    ])

    result = pipeline(image=image, mask=mask)
    aug_image, aug_mask = result['image'], result['mask']

    if not self.classification_head:
        return {'features': aug_image, 'targets': aug_mask}

    # Image-level target: does the tile contain any positive mask mass?
    has_mask = ((aug_mask.sum() > 0) * 1).unsqueeze(-1).float()
    return aug_image, [aug_mask, has_mask]
# Example #2
# 0
def predict(data_path, model_weights_path, network, test_df_path, save_path,
            size, channels, neighbours, classification_head):
    """Predict a mask for every tile listed in a CSV and save it as a PNG.

    Args:
        data_path: root passed to ``get_filepath`` when locating input images.
        model_weights_path: checkpoint file. With ``classification_head`` it
            is loaded directly as a state dict; otherwise the state dict is
            taken from its ``'model_state_dict'`` entry.
        network: architecture name forwarded to ``get_model``.
        test_df_path: CSV with ``name``, ``position`` and ``dataset_folder``
            columns, one row per tile.
        save_path: predictions are written to ``<save_path>/predictions``,
            which is created if missing.
        size: tile height and width; each prediction must hold exactly
            ``size * size`` values (it is reshaped to ``(size, size)``).
        channels: channel selection passed to ``filter_by_channels`` and
            ``count_channels``.
        neighbours: multiplier on the input channel count, also passed to
            ``filter_by_channels``.
        classification_head: if truthy, ``model.predict`` is unpacked as a
            ``(mask, label)`` pair and only the mask is saved.

    Returns:
        None. Writes one PNG per CSV row, named ``<name>_<position>``,
        containing the prediction multiplied by 255.
    """
    model = get_model(network, classification_head)
    input_channels = count_channels(channels) * neighbours
    # Replace the encoder's first convolution so it accepts input_channels
    # (presumably a ResNet-style 7x7 / stride-2 stem with 64 filters).
    model.encoder.conv1 = nn.Conv2d(input_channels,
                                    64,
                                    kernel_size=(7, 7),
                                    stride=(2, 2),
                                    padding=(3, 3),
                                    bias=False)

    model, device = UtilsFactory.prepare_model(model)

    # Deserialize onto the CPU for both checkpoint layouts so weights saved on
    # a GPU also load on CPU-only hosts; load_state_dict then copies the
    # values into the model's parameters on whatever device they live.
    checkpoint = torch.load(model_weights_path, map_location='cpu')
    if classification_head:
        model.load_state_dict(checkpoint)
    else:
        model.load_state_dict(checkpoint['model_state_dict'])

    test_df = pd.read_csv(test_df_path)

    predictions_path = os.path.join(save_path, "predictions")
    if not os.path.exists(predictions_path):
        os.makedirs(predictions_path, exist_ok=True)
        print("Prediction directory created.")

    to_tensor = transforms.ToTensor()
    for _, image_info in tqdm(test_df.iterrows(), total=len(test_df)):
        # pd.read_csv parses an all-numeric column as int and str.join
        # rejects non-str items; str() is a no-op for values already str.
        filename = '_'.join(
            [str(image_info['name']), str(image_info['position'])])
        image_path = get_filepath(data_path,
                                  image_info['dataset_folder'],
                                  'images',
                                  filename,
                                  file_type='tiff')

        image_tensor = filter_by_channels(read_tensor(image_path), channels,
                                          neighbours)
        if image_tensor.ndim == 2:
            # Give single-channel rasters an explicit trailing channel axis.
            image_tensor = np.expand_dims(image_tensor, -1)

        batch = to_tensor(image_tensor).view(1, input_channels, size, size)
        batch = batch.to(device, dtype=torch.float)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            output = model.predict(batch)
        if classification_head:
            prediction, _label = output
        else:
            prediction = output

        result = prediction.view(size, size).detach().cpu().numpy()

        cv.imwrite(get_filepath(predictions_path, filename, file_type='png'),
                   result * 255)
    def __getitem__(self, idx):
        """Return the multi-image sample stored at row *idx* of ``self.df``.

        For each k in ``self.num_images, ..., 1`` (descending), the image and
        the mask with stem ``<name>_<position>_<k>`` are read. All images are
        concatenated along their last axis, and likewise all masks, so that
        one call of the augmentation pipeline transforms every image and mask
        together. The result is then split back per image.

        Returns:
            dict with
              ``'features'``: ``torch.stack`` of ``self.num_images`` slices of
                  ``count_channels(self.channels)`` channels each, in the
                  descending-k read order (presumably shaped
                  (num_images, C, H, W) -- confirm against ToTensor);
              ``'targets'``: every mask (stacked, then squeezed) when
                  ``self.all_masks`` is truthy, otherwise only the last mask
                  read (k == 1);
              ``'name'``, ``'position'``: copied from the dataframe row.

        Raises:
            Exception: if fewer than two channels are configured.
        """
        if len(self.channels) < 2:
            raise Exception('You have to specify at least two channels.')

        data_info_row = self.df.iloc[idx]
        # str() so a numeric column parsed by pandas does not make str.join
        # raise TypeError; it is a no-op for values that are already str.
        instance_name = '_'.join(
            [str(data_info_row['name']),
             str(data_info_row['position'])])

        images_array, masks_array = [], []
        # Descending k: this order decides which 'features' slice is which
        # image, and which mask is "last" in the non-all_masks branch.
        for k in range(self.num_images, 0, -1):
            image_path = get_filepath(self.dataset_path,
                                      data_info_row['dataset_folder'],
                                      self.images_folder,
                                      instance_name + f'_{k}',
                                      file_type=self.image_type)

            img = filter_by_channels(read_tensor(image_path), self.channels, 1)
            images_array.append(img)

            mask_path = get_filepath(self.dataset_path,
                                     data_info_row['dataset_folder'],
                                     self.masks_folder,
                                     instance_name + f'_{k}',
                                     file_type=self.mask_type)
            msk = read_tensor(mask_path)
            # Trailing axis so masks can be concatenated like image channels.
            masks_array.append(np.expand_dims(msk, axis=-1))

        aug = Compose([
            RandomRotate90(),
            Flip(),
            OneOf([
                RandomSizedCrop(min_max_height=(int(
                    self.image_size * 0.7), self.image_size),
                                height=self.image_size,
                                width=self.image_size),
                RandomBrightnessContrast(brightness_limit=0.15,
                                         contrast_limit=0.15),
                ElasticTransform(alpha=15, sigma=5, alpha_affine=5),
                GridDistortion(p=0.6)
            ],
                  p=0.8),
            ToTensor()
        ])

        augmented = aug(image=np.concatenate(images_array, axis=-1),
                        mask=np.concatenate(masks_array, axis=-1))

        # Channels per image; hoisted out of the per-image slicing below.
        n_channels = count_channels(self.channels)
        augmented_images = torch.stack([
            augmented['image'][num_img * n_channels:(num_img + 1) *
                               n_channels, :, :]
            for num_img in range(self.num_images)
        ])

        aug_mask = augmented['mask']
        if self.all_masks:
            augmented_masks = torch.stack([
                aug_mask[:, :, :, i] for i in range(aug_mask.shape[-1])
            ]).squeeze()
        else:
            augmented_masks = torch.stack([aug_mask[:, :, :, -1]])

        return {
            'features': augmented_images,
            'targets': augmented_masks,
            'name': data_info_row['name'],
            'position': data_info_row['position']
        }
    def __getitem__(self, idx):
        """Return the multi-image sample stored at row *idx* of ``self.df``.

        Reads ``self.num_images`` images with stems ``<name>_<position>_<k>``
        (k = 1, ..., num_images) and a single mask with stem
        ``<name>_<position>``. In the ``'train'`` phase a random augmentation
        pipeline is applied; in any other phase only ``ToTensor``. All images
        are concatenated along their last axis so one augmentation call
        transforms them together with the mask. The result is then split back
        into one tensor per image.

        Returns:
            dict with
              ``'features'``: list of ``self.num_images`` tensors, each with
                  ``count_channels(self.channels)`` leading channels, in
                  ascending-k order;
              ``'targets'``: one-element list holding the mask tensor;
              ``'name'``, ``'position'``: copied from the dataframe row.

        Raises:
            Exception: if fewer than two channels are configured.
        """
        if len(self.channels) < 2:
            raise Exception('You have to specify at least two channels.')

        data_info_row = self.df.iloc[idx]
        # str() so a numeric column parsed by pandas does not make str.join
        # raise TypeError; it is a no-op for values that are already str.
        instance_name = '_'.join(
            [str(data_info_row['name']),
             str(data_info_row['position'])])

        images_array = []
        for k in range(1, self.num_images + 1):
            image_path = get_filepath(self.dataset_path,
                                      data_info_row['dataset_folder'],
                                      self.images_folder,
                                      instance_name + f'_{k}',
                                      file_type=self.image_type)

            img = filter_by_channels(read_tensor(image_path), self.channels, 1)
            images_array.append(img)

        mask_path = get_filepath(self.dataset_path,
                                 data_info_row['dataset_folder'],
                                 self.masks_folder,
                                 instance_name,
                                 file_type=self.mask_type)
        masks_array = read_tensor(mask_path)

        if self.phase == 'train':
            aug = Compose([
                RandomRotate90(),
                Flip(),
                OneOf(
                    [
                        RandomSizedCrop(min_max_height=(int(
                            self.image_size * 0.7), self.image_size),
                                        height=self.image_size,
                                        width=self.image_size),
                        RandomBrightnessContrast(brightness_limit=0.15,
                                                 contrast_limit=0.15),
                        MaskDropout(p=0.6),
                        ElasticTransform(alpha=15, sigma=5, alpha_affine=5),
                        GridDistortion(p=0.6)
                    ],
                    p=0.8),
                ToTensor()
            ])
        else:
            aug = ToTensor()

        # Augment every loaded image, not just the first two. This avoids
        # silently dropping images when num_images > 2 and an IndexError
        # when num_images == 1.
        augmented = aug(image=np.concatenate(images_array, axis=-1),
                        mask=masks_array)

        n_channels = count_channels(self.channels)
        augmented_images = [
            augmented['image'][num_img * n_channels:(num_img + 1) *
                               n_channels, :, :]
            for num_img in range(self.num_images)
        ]
        augmented_masks = [augmented['mask']]

        return {
            'features': augmented_images,
            'targets': augmented_masks,
            'name': data_info_row['name'],
            'position': data_info_row['position']
        }