Exemple #1
0
class ListDataset(data.Dataset):
    '''Detection dataset backed by a plain-text index file.

    Every non-blank line of the index file describes one image:

        <fname> <num_objs> <xmin> <ymin> <xmax> <ymax> <class> ...

    where the five-field object record is repeated ``num_objs`` times.
    '''
    # Side length (pixels) every image is resized to in __getitem__.
    img_size = 300

    def __init__(self, root, list_file, train, transform):
        '''
        Args:
          root: (str) directory to images.
          list_file: (str) path to index file.
          train: (boolean) train or test.
          transform: ([transforms]) image transforms.
        '''
        self.root = root
        self.train = train
        self.transform = transform

        self.fnames = []
        self.boxes = []
        self.labels = []

        self.data_encoder = DataEncoder()

        with open(list_file) as f:
            for line in f:
                # Blank lines (e.g. a stray empty line at the end of the
                # file) carry no sample and would otherwise fail to parse.
                if not line.strip():
                    continue
                fname, box, label = self._parse_line(line)
                self.fnames.append(fname)
                self.boxes.append(box)
                self.labels.append(label)

        # Count parsed samples rather than raw lines so that __len__ always
        # agrees with the lists that __getitem__ indexes.
        self.num_samples = len(self.fnames)

    @staticmethod
    def _parse_line(line):
        '''Parse one index-file line into a file name, boxes and labels.

        Args:
          line: (str) non-blank line of the index file.

        Returns:
          fname: (str) image file name, relative to the dataset root.
          boxes: (tensor) bbox locations, sized [#obj, 4].
          labels: (tensor) class labels, sized [#obj,].
        '''
        fields = line.split()
        fname = fields[0]
        num_objs = int(fields[1])

        box = []
        label = []
        for i in range(num_objs):
            base = 2 + 5 * i
            c = fields[base + 4]
            box.append([float(v) for v in fields[base:base + 4]])
            label.append(int(c))

        # Force a [#obj, 4] shape so that images without objects produce a
        # [0, 4] tensor that still supports column indexing.
        boxes = torch.Tensor(box).reshape(len(box), 4)
        return fname, boxes, torch.LongTensor(label)

    def __getitem__(self, idx):
        '''Load a image, and encode its bbox locations and class labels.

        Args:
          idx: (int) image index.

        Returns:
          img: (tensor) image tensor.
          loc_target: (tensor) location targets, sized [8732,4].
          conf_target: (tensor) label targets, sized [8732,].
        '''
        # Load image and bbox locations. The boxes are cloned because the
        # augmentation and scaling below modify them in place.
        fname = self.fnames[idx]
        img = Image.open(os.path.join(self.root, fname))
        boxes = self.boxes[idx].clone()
        labels = self.labels[idx]

        # Data augmentation while training.
        if self.train:
            img, boxes = self.random_flip(img, boxes)
            img, boxes, labels = self.random_crop(img, boxes, labels)

        # Scale bbox locations to [0,1].
        w, h = img.size
        boxes /= torch.Tensor([w, h, w, h]).expand_as(boxes)

        img = img.resize((self.img_size, self.img_size))
        img = self.transform(img)

        # Encode loc & conf targets.
        loc_target, conf_target = self.data_encoder.encode(boxes, labels)
        return img, loc_target, conf_target

    def random_flip(self, img, boxes):
        '''Randomly flip the image and adjust the bbox locations.

        For bbox (xmin, ymin, xmax, ymax), the flipped bbox is:
        (w-xmax, ymin, w-xmin, ymax).

        Note that ``boxes`` is modified in place when the flip is applied.

        Args:
          img: (PIL.Image) image.
          boxes: (tensor) bbox locations, sized [#obj, 4].

        Returns:
          img: (PIL.Image) randomly flipped image.
          boxes: (tensor) randomly flipped bbox locations, sized [#obj, 4].
        '''
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            w = img.width
            # Both new columns are computed before either is written back,
            # because each one depends on the other's original values.
            xmin = w - boxes[:, 2]
            xmax = w - boxes[:, 0]
            boxes[:, 0] = xmin
            boxes[:, 2] = xmax
        return img, boxes

    def random_crop(self, img, boxes, labels):
        '''Randomly crop the image and adjust the bbox locations.

        For more details, see 'Chapter2.2: Data augmentation' of the paper.

        A minimum IoU is drawn at random; ``None`` means "keep the image
        unchanged". Otherwise up to 100 candidate patches are sampled, and
        if none qualifies a new minimum IoU is drawn.

        Args:
          img: (PIL.Image) image.
          boxes: (tensor) bbox locations, sized [#obj, 4].
          labels: (tensor) bbox labels, sized [#obj,].

        Returns:
          img: (PIL.Image) cropped image.
          selected_boxes: (tensor) selected bbox locations.
          labels: (tensor) selected bbox labels.
        '''
        imw, imh = img.size
        while True:
            min_iou = random.choice([None, 0.1, 0.3, 0.5, 0.7, 0.9])
            if min_iou is None:
                return img, boxes, labels

            # Box centres do not depend on the sampled patch, so compute
            # them once per round instead of once per attempt.
            center = (boxes[:, :2] + boxes[:, 2:]) / 2  # [N,2]

            for _ in range(100):
                w = random.randrange(int(0.1 * imw), imw)
                h = random.randrange(int(0.1 * imh), imh)

                # Reject patches whose aspect ratio is outside [1/2, 2].
                if h > 2 * w or w > 2 * h:
                    continue

                x = random.randrange(imw - w)
                y = random.randrange(imh - h)
                roi = torch.Tensor([[x, y, x + w, y + h]])

                # Keep only the boxes whose centre lies strictly inside the
                # patch.
                roi2 = roi.expand(len(center), 4)  # [N,4]
                mask = (center > roi2[:, :2]) & (center < roi2[:, 2:])  # [N,2]
                mask = mask[:, 0] & mask[:, 1]  # [N,]
                if not mask.any():
                    continue

                selected_boxes = boxes.index_select(0,
                                                    mask.nonzero().squeeze(1))

                iou = self.data_encoder.iou(selected_boxes, roi)
                if iou.min() < min_iou:
                    continue

                img = img.crop((x, y, x + w, y + h))
                # Shift the boxes into patch coordinates and clip them to
                # the patch borders.
                selected_boxes[:, 0].add_(-x).clamp_(min=0, max=w)
                selected_boxes[:, 1].add_(-y).clamp_(min=0, max=h)
                selected_boxes[:, 2].add_(-x).clamp_(min=0, max=w)
                selected_boxes[:, 3].add_(-y).clamp_(min=0, max=h)
                return img, selected_boxes, labels[mask]

    def __len__(self):
        '''Return the number of samples parsed from the index file.'''
        return self.num_samples
Exemple #2
0
class BBoxDataset(torch.utils.data.Dataset):
    """
    define the labels that are used in object detection model like RetinaNet

    Each sample is a (signal, boxes, labels, peaks) tuple; the targets are
    encoded per batch in ``collate_fn``.
    """
    def __init__(self, denoise=True):
        """
        Args:
            denoise: (bool) denoise the signals using wavelet thresholding or not
        """
        self.signals = []
        self.boxes = []
        self.labels = []
        self.peaks = []

        self.num_samples = 0

        self.raw_dataset_path = config["General"]["LUDB_path"]

        self.encoder = DataEncoder()
        self.get_signal_annotations(
            leads_seperate=True,
            normalize=True,
            denoise=denoise,
            gaussian_noise_sigma=wandb.config.augmentation_gaussian_noise_sigma,
            data_augmentation=wandb.config.data_augmentation)

    def get_signal_annotations(self, leads_seperate=True, normalize=True, denoise=True, gaussian_noise_sigma=0.1, data_augmentation=True):
        """
        compute and save the bbox result in dataset

        The preprocessed data is cached on disk under
        config["RetinaNet"]["output_path"]; the cache file depends only on
        ``denoise``, not on the other arguments.

        Args:
            leads_seperate: (bool) seperate leads or not
            normalize: (bool) normalize the signal or not
            denoise: (bool) denoise the signal using wavelet thresholding or not
            gaussian_noise_sigma: (float) the noise sigma add to data augmentation, if equals 0, then there will be no data augmentation
            data_augmentation: (bool) use data augmentation that scale the signal on different segments or not.

        signals:    (Tensor) with sized [#signal, signal_length]
        boxes:      (list) with sized [#signal, #objs, 2]
        labels:     (list) with sized [#signal, #objs, ]
        peaks:      (list) with sized [#signal, #objs, ]
        """
        cache_name = "LUDB_preprocessed_data_denoise.pt" if denoise else "LUDB_preprocessed_data.pt"
        cache_file = config["RetinaNet"]["output_path"] + cache_name

        if path.exists(cache_file):
            # NOTE(review): torch.load unpickles the file; only load caches
            # that this code wrote itself.
            self.signals, self.boxes, self.labels, self.peaks, self.num_samples = torch.load(cache_file)
        else:
            signals, bboxes, labels, peaks = load_raw_dataset_and_bbox_labels(self.raw_dataset_path)
            signals_, bboxes_, labels_, peaks_ = load_raw_dataset_and_bbox_labels_CAL()

            signals.extend(signals_)
            bboxes.extend(bboxes_)
            labels.extend(labels_)
            peaks.extend(peaks_)

            # with sized [#subjects, #leads, signal_length] [#subjects, #leads, #objs, 2] [#subjects, #leads, #objs]
            if gaussian_noise_sigma != 0.0:
                # NOTE(review): gaussian_noise_sigma is not forwarded to
                # signal_augmentation; presumably the helper returns the
                # originals followed by noisy copies, which is why the
                # annotations are duplicated below - confirm.
                signals = signal_augmentation(signals)
                bboxes = [*bboxes, *bboxes]
                labels = [*labels, *labels]
                peaks = [*peaks, *peaks]

            if leads_seperate:
                for i, subject_signals in enumerate(signals):
                    # Denoise all leads of a subject in one call.
                    leads = ekg_denoise(subject_signals) if denoise else subject_signals
                    for j in range(len(subject_signals)):
                        self.signals.append(torch.Tensor(leads[j]))
                        self.boxes.append(torch.Tensor(bboxes[i][j]))
                        self.labels.append(torch.Tensor(labels[i][j]))
                        self.peaks.append(torch.Tensor(peaks[i][j]))

                        self.num_samples += 1
            # NOTE(review): with leads_seperate=False nothing is collected,
            # so the torch.stack below would fail on an empty list.
            if normalize:
                self.signals = Normalize(torch.stack(self.signals), instance=True)

            # Save to the same location the cache is looked up in;
            # otherwise the cache is never found on the next run.
            torch.save((self.signals, self.boxes, self.labels, self.peaks, self.num_samples), cache_file)

        if data_augmentation:
            self._augment_p_wave_amplitude()

    def _augment_p_wave_amplitude(self, scale=0.8):
        """
        Append one augmented copy of every sample, with the label-0
        (p duration) segments scaled by ``scale``.

        Boxes, labels and peaks of the copy are clones of the original
        ones. ``self.signals`` keeps its container type: a stacked Tensor
        gets the new rows concatenated, a list gets them appended.

        Args:
            scale: (float) amplitude factor applied inside the p segments
        """
        num_original = len(self.signals)
        augmented = []
        for i in range(num_original):
            sig = self.signals[i].clone()
            length = sig.shape[-1]
            for box, label in zip(self.boxes[i], self.labels[i]):
                if label == 0: # p duration
                    # Box bounds are stored as floats; convert them to
                    # sample indices and keep them inside the signal.
                    start = max(int(box[0]), 0)
                    end = min(int(box[1]), length)
                    sig[..., start:end] *= scale
            augmented.append(sig)
            self.boxes.append(self.boxes[i].clone())
            self.labels.append(self.labels[i].clone())
            self.peaks.append(self.peaks[i].clone())
            self.num_samples += 1

        if isinstance(self.signals, torch.Tensor):
            if augmented:
                self.signals = torch.cat([self.signals, torch.stack(augmented)])
        else:
            self.signals.extend(augmented)

    def __getitem__(self,idx):
        """
        Load signal

        Args:
            idx: (int) signal index

        Returns:
            sig:    (Tensor) signal tensor
            boxes:  (Tensor) bbox locations of the signal
            labels: (Tensor) class labels of the boxes
            peaks:  (Tensor) peaks of the signal
        """
        sig = self.signals[idx]
        boxes = self.boxes[idx]
        labels = self.labels[idx]
        peaks = self.peaks[idx]

        return sig, boxes, labels, peaks

    def collate_fn(self, batch, input_size=3968):
        """
        Encode targets

        Args:
            batch: (list) of (sig, boxes, labels, peaks) tuples as returned by __getitem__
            input_size: (int) data length every signal in the batch must have

        Returns:
            inputs:      (Tensor) signals, sized [#signal, input_size]
            loc_targets: (Tensor) stacked location targets from the encoder
            cls_targets: (Tensor) stacked class label targets from the encoder
            boxes:       (list) bbox locations per signal
            labels:      (list) class labels per signal
            peaks:       (list) peaks per signal
        """
        sigs = [x[0] for x in batch]
        boxes = [x[1] for x in batch]
        labels = [x[2] for x in batch]
        peaks = [x[3] for x in batch]

        num_sigs = len(sigs)

        inputs = torch.zeros(num_sigs, input_size)

        loc_targets = []
        cls_targets = []
        for i in range(num_sigs):
            inputs[i] = sigs[i]

            loc_target, cls_target = self.encoder.encode(boxes[i], labels[i], input_size)

            loc_targets.append(loc_target)
            cls_targets.append(cls_target)

        return inputs, torch.stack(loc_targets), torch.stack(cls_targets), boxes, labels, peaks

    def __len__(self):
        """Return the number of samples, including augmented copies."""
        return self.num_samples