예제 #1
0
def get_one(sample, new_size, args):
    """Load one training sample and convert it into network inputs.

    Args:
        sample: Sequence of file paths. With exactly four entries it is read
            as (guide image, guide label, image, label) and the guide label
            also serves as the reference label. Otherwise it is read as
            (guide image, guide label, reference label, image, label), with
            an optional object id at index 5.
        new_size: Size handed to ``resize`` for the image, its label and the
            reference label.
        args: Options object. Attributes read here: guide_size,
            data_aug_flip, vg_keep_aspect_ratio, vg_random_crop_ratio,
            vg_random_rotate_angle, vg_color_aug, use_original_mask,
            sg_center_perturb_ratio, sg_std_perturb_ratio, dilate_structure,
            mean_value and scale_value.

    Returns:
        Tuple ``(guide_image_data, gb_image, image_data, label_data)``.
    """
    # The first two entries are always the visual-guide image and its mask.
    appearance_img = Image.open(sample[0])
    appearance_mask = Image.open(sample[1])
    if len(sample) == 4:
        # The guide pair supplies both appearance and location guidance.
        frame = Image.open(sample[2])
        frame_mask = Image.open(sample[3])
        location_mask = appearance_mask
    else:
        # Appearance comes from the guide pair, location from its own mask.
        location_mask = Image.open(sample[2])
        frame = Image.open(sample[3])
        frame_mask = Image.open(sample[4])
    object_id = sample[5] if len(sample) > 5 else 0

    frame = frame.resize(new_size, Image.BILINEAR)
    frame_mask = frame_mask.resize(new_size, Image.NEAREST)
    location_mask = location_mask.resize(new_size, Image.NEAREST)
    # The guide mask is matched to the guide image's own size, not new_size.
    appearance_mask = appearance_mask.resize(appearance_img.size,
                                             Image.NEAREST)

    if object_id > 0:
        # A positive id means the masks are reduced to that single object.
        appearance_mask, location_mask, frame_mask = [
            _get_obj_mask(mask, object_id)
            for mask in (appearance_mask, location_mask, frame_mask)
        ]

    # Crop the guide pair to the mask's bounding box before augmenting it.
    crop_box = get_mask_bbox(np.array(appearance_mask))
    appearance_img, appearance_mask = data_augmentation(
        appearance_img.crop(crop_box),
        appearance_mask.crop(crop_box),
        args.guide_size,
        data_aug_flip=args.data_aug_flip,
        keep_aspect_ratio=args.vg_keep_aspect_ratio,
        random_crop_ratio=args.vg_random_crop_ratio,
        random_rotate_angle=args.vg_random_rotate_angle,
        color_aug=args.vg_color_aug)

    # Spatial guide: either a perturbed + dilated mask or the output of
    # get_gb_image, depending on args.use_original_mask.
    location_data = np.array(location_mask)
    if args.use_original_mask:
        perturbed = perturb_mask(location_data)
        gb_image = ndimage.morphology.binary_dilation(
            perturbed, structure=args.dilate_structure) * 255
    else:
        gb_image = get_gb_image(location_data,
                                center_perturb=args.sg_center_perturb_ratio,
                                std_perturb=args.sg_std_perturb_ratio)

    def _normalize(pil_image):
        # float32 array -> to_bgr -> subtract mean -> scale
        bgr = to_bgr(np.array(pil_image, dtype=np.float32))
        return (bgr - args.mean_value) * args.scale_value

    image_data = _normalize(frame)
    label_data = np.array(frame_mask, dtype=np.uint8) > 0
    guide_mask_data = np.array(appearance_mask, dtype=np.uint8)
    guide_image_data = mask_image(_normalize(appearance_img), guide_mask_data)
    return guide_image_data, gb_image, image_data, label_data
예제 #2
0
    def next_batch(self, batch_size, phase):
        """Return the next batch for the requested phase.

        Args:
            batch_size: Number of samples to fetch. The 'test' phase asserts
                that this is 1.
            phase: 'train' or 'test'. Any other value yields four ``None``s.

        Returns:
            A 4-tuple ``(guide_images, gb_images, images, extra)``.

            Training: every entry is a NumPy array and ``extra`` holds the
            boolean labels; ``gb_images`` and the labels get a trailing axis
            of length 1.

            Testing: ``extra`` is a one-element list with the relative output
            name. When the sample's guide path (``sample[0]``) is ``None``
            the result is ``(None, gb_images, images, names)``; otherwise it
            is ``(guide_images, None, None, names)``.
        """
        if phase == 'train':
            if self.train_ptr + batch_size <= self.train_size:
                idx = np.array(self.train_idx[self.train_ptr:self.train_ptr +
                                              batch_size])
                self.train_ptr += batch_size
            else:
                # Not enough samples left: reshuffle and restart at the front.
                np.random.shuffle(self.train_idx)
                new_ptr = batch_size
                idx = np.array(self.train_idx[:new_ptr])
                self.train_ptr = new_ptr
            guide_images = []
            gb_images = []
            images = []
            labels = []
            # One scale is drawn per batch so all samples share a size.
            if self.data_aug_scales:
                scale = random.choice(self.data_aug_scales)
                new_size = (int(self.size[0] * scale),
                            int(self.size[1] * scale))
            else:
                # Without scale augmentation use the configured size as-is;
                # previously new_size was left unbound in this case.
                new_size = (int(self.size[0]), int(self.size[1]))
            for i in idx:
                sample = self.train_list[i]
                if len(sample) == 4:
                    # guide image is both for appearance and location guidance
                    guide_image = Image.open(sample[0])
                    guide_label = Image.open(sample[1])
                    image = Image.open(sample[2])
                    label = Image.open(sample[3])
                    ref_label = guide_label
                else:
                    # guide image is only for appearance guidance, ref label
                    # is only for location guidance
                    guide_image = Image.open(sample[0])
                    guide_label = Image.open(sample[1])
                    ref_label = Image.open(sample[2])
                    image = Image.open(sample[3])
                    label = Image.open(sample[4])
                image = image.resize(new_size, Image.BILINEAR)
                label = label.resize(new_size, Image.NEAREST)
                ref_label = ref_label.resize(new_size, Image.NEAREST)
                # The guide label follows the guide image's own size.
                guide_label = guide_label.resize(guide_image.size,
                                                 Image.NEAREST)
                # Crop the guide pair to the mask bounding box, then augment.
                bbox = get_mask_bbox(np.array(guide_label))
                guide_image = guide_image.crop(bbox)
                guide_label = guide_label.crop(bbox)
                guide_image, guide_label = data_augmentation(
                    guide_image,
                    guide_label,
                    self.guide_size,
                    data_aug_flip=self.data_aug_flip,
                    keep_aspect_ratio=self.vg_keep_aspect_ratio,
                    random_crop_ratio=self.vg_random_crop_ratio,
                    random_rotate_angle=self.vg_random_rotate_angle,
                    color_aug=self.vg_color_aug)
                if not self.use_original_mask:
                    gb_image = get_gb_image(
                        np.array(ref_label),
                        center_perturb=self.sg_center_perturb_ratio,
                        std_perturb=self.sg_std_perturb_ratio)
                else:
                    gb_image = perturb_mask(np.array(ref_label))
                    gb_image = ndimage.morphology.binary_dilation(
                        gb_image, structure=self.dilate_structure) * 255
                image_data = np.array(image, dtype=np.float32)
                label_data = np.array(label, dtype=np.uint8) > 0
                image_data = to_bgr(image_data)
                image_data -= self.mean_value
                guide_label_data = np.array(guide_label, dtype=np.uint8)
                guide_image_data = np.array(guide_image, dtype=np.float32)
                guide_image_data = to_bgr(guide_image_data)
                guide_image_data -= self.mean_value
                if not self.bbox_sup:
                    guide_image_data = mask_image(guide_image_data,
                                                  guide_label_data)
                guide_images.append(guide_image_data)
                gb_images.append(gb_image)
                images.append(image_data)
                labels.append(label_data)
            images = np.array(images)
            gb_images = np.array(gb_images)[..., np.newaxis]
            labels = np.array(labels)[..., np.newaxis]
            guide_images = np.array(guide_images)
            return guide_images, gb_images, images, labels
        elif phase == 'test':
            guide_images = []
            gb_images = []
            images = []
            image_paths = []
            self.crop_boxes = []
            self.images = []
            assert batch_size == 1, "Only allow batch size = 1 for testing"
            if self.test_ptr + batch_size < self.test_size:
                idx = np.array(self.test_idx[self.test_ptr:self.test_ptr +
                                             batch_size])
                self.test_ptr += batch_size
            else:
                # Wrap around to the start of the test index list.
                new_ptr = (self.test_ptr + batch_size) % self.test_size
                idx = np.hstack(
                    (self.test_idx[self.test_ptr:], self.test_idx[:new_ptr]))
                self.test_ptr = new_ptr
            i = idx[0]
            sample = self.test_list[i]
            if sample[0] is None:
                # visual guide image / mask is none, only read spatial guide
                # and input image
                first_frame = False
                ref_label = Image.open(sample[2])
                image = Image.open(sample[3])
                frame_name = sample[3].split('/')[-1].split('.')[0] + '.png'
                if self.multiclass:
                    # seq_name/label_id/frame_name
                    ref_name = os.path.join(*(sample[2].split('/')[-3:-1] +
                                              [frame_name]))
                else:
                    # seq_name/frame_name
                    ref_name = os.path.join(sample[2].split('/')[-2],
                                            frame_name)
            else:
                # only process visual guide image / mask
                first_frame = True
                guide_image = Image.open(sample[0])
                guide_label = Image.open(sample[1])
                if self.multiclass:
                    # seq_name/label_id/frame_name
                    ref_name = os.path.join(*(sample[1].split('/')[-3:]))
                else:
                    # seq_name/frame_name
                    ref_name = os.path.join(*(sample[1].split('/')[-2:]))
            if not first_frame:
                if len(self.size) == 2:
                    self.new_size = self.size
                else:
                    # resize short size of image to self.size[0]
                    resize_ratio = max(
                        float(self.size[0]) / image.size[0],
                        float(self.size[0]) / image.size[1])
                    self.new_size = (int(resize_ratio * image.size[0]),
                                     int(resize_ratio * image.size[1]))
                ref_label = ref_label.resize(self.new_size, Image.NEAREST)
                ref_label_data = np.array(ref_label) / 255
                # Keep an un-normalised resized copy of the frame on self.
                image_ref_crf = image.resize(self.new_size, Image.BILINEAR)
                self.images.append(np.array(image_ref_crf))
                image = image.resize(self.new_size, Image.BILINEAR)
                # Only build the spatial guide that is actually returned.
                if self.use_original_mask:
                    gb_image = ndimage.morphology.binary_dilation(
                        ref_label_data, structure=self.dilate_structure) * 255
                else:
                    gb_image = get_gb_image(ref_label_data,
                                            center_perturb=0,
                                            std_perturb=0)
                image_data = np.array(image, dtype=np.float32)
                image_data = to_bgr(image_data)
                image_data -= self.mean_value
                gb_images.append(gb_image)
                images.append(image_data)
                images = np.array(images)
                gb_images = np.array(gb_images)[..., np.newaxis]
                guide_images = None
            else:
                # process visual guide images
                guide_label = guide_label.resize(guide_image.size,
                                                 Image.NEAREST)
                bbox = get_mask_bbox(np.array(guide_label))
                guide_image = guide_image.crop(bbox)
                guide_label = guide_label.crop(bbox)
                guide_image, guide_label = data_augmentation(
                    guide_image,
                    guide_label,
                    self.guide_size,
                    data_aug_flip=False,
                    pad_ratio=self.vg_pad_ratio,
                    keep_aspect_ratio=self.vg_keep_aspect_ratio)

                guide_image_data = np.array(guide_image, dtype=np.float32)
                guide_image_data = to_bgr(guide_image_data)
                guide_image_data -= self.mean_value
                guide_label_data = np.array(guide_label, dtype=np.uint8)
                if not self.bbox_sup:
                    guide_image_data = mask_image(guide_image_data,
                                                  guide_label_data)
                guide_images.append(guide_image_data)
                guide_images = np.array(guide_images)
                images = None
                gb_images = None
            image_paths.append(ref_name)
            return guide_images, gb_images, images, image_paths
        else:
            return None, None, None, None
예제 #3
0
    def next_batch(self, batch_size, phase):
        """Return the next batch for the requested phase.

        Args:
            batch_size: Number of samples to fetch. The 'test' phase asserts
                that this is 1.
            phase: 'train' or 'test'. Any other value yields four ``None``s.

        Returns:
            A 4-tuple ``(guide_images, gb_images, images, extra)``.

            Training: every entry is a NumPy array built from the per-sample
            results of ``get_one``; ``extra`` holds the labels, and
            ``gb_images`` and the labels get a trailing axis of length 1.

            Testing: ``extra`` is a one-element list with the relative output
            name. When the sample's guide path (``sample[0]``) is ``None``
            the result is ``(None, gb_images, images, names)``; otherwise it
            is ``(guide_images, None, None, names)``.
        """
        if phase == 'train':
            if self.train_ptr + batch_size <= self.train_size:
                idx = np.array(self.train_idx[self.train_ptr:self.train_ptr +
                                              batch_size])
                self.train_ptr += batch_size
            else:
                # Not enough samples left: reshuffle and restart at the front.
                np.random.shuffle(self.train_idx)
                new_ptr = batch_size
                idx = np.array(self.train_idx[:new_ptr])
                self.train_ptr = new_ptr
            guide_images = []
            gb_images = []
            images = []
            labels = []
            # One scale is drawn per batch so all samples share a size.
            if self.data_aug_scales:
                scale = random.choice(self.data_aug_scales)
                new_size = (int(self.size[0] * scale),
                            int(self.size[1] * scale))
            else:
                # Without scale augmentation use the configured size as-is;
                # previously new_size was left unbound in this case.
                new_size = (int(self.size[0]), int(self.size[1]))
            if self.args.num_loader == 1:
                batch = [
                    get_one(self.train_list[i], new_size, self.args)
                    for i in idx
                ]
            else:
                # Submit every sample before waiting on any result. Calling
                # pool.apply per sample blocks until that sample is done, so
                # the workers would never run concurrently.
                # NOTE(review): assumes self.pool is a multiprocessing-style
                # pool exposing apply_async - confirm against its creation.
                pending = [
                    self.pool.apply_async(get_one,
                                          args=(self.train_list[i], new_size,
                                                self.args)) for i in idx
                ]
                # Results are collected in submission order.
                batch = [job.get() for job in pending]
            for guide_image_data, gb_image, image_data, label_data in batch:
                guide_images.append(guide_image_data)
                gb_images.append(gb_image)
                images.append(image_data)
                labels.append(label_data)
            images = np.array(images)
            gb_images = np.array(gb_images)[..., np.newaxis]
            labels = np.array(labels)[..., np.newaxis]
            guide_images = np.array(guide_images)
            return guide_images, gb_images, images, labels
        elif phase == 'test':
            guide_images = []
            gb_images = []
            images = []
            image_paths = []
            self.crop_boxes = []
            self.images = []
            assert batch_size == 1, "Only allow batch size = 1 for testing"
            if self.test_ptr + batch_size < self.test_size:
                idx = np.array(self.test_idx[self.test_ptr:self.test_ptr +
                                             batch_size])
                self.test_ptr += batch_size
            else:
                # Wrap around to the start of the test index list.
                new_ptr = (self.test_ptr + batch_size) % self.test_size
                idx = np.hstack(
                    (self.test_idx[self.test_ptr:], self.test_idx[:new_ptr]))
                self.test_ptr = new_ptr
            i = idx[0]
            sample = self.test_list[i]
            # Optional fifth entry selects one object id inside the masks.
            if len(sample) > 4:
                label_id = sample[4]
            else:
                label_id = 0

            if sample[0] is None:
                # visual guide image / mask is none, only read spatial guide
                # and input image
                first_frame = False
                ref_label = Image.open(sample[2])
                image = Image.open(sample[3])
                frame_name = sample[3].split('/')[-1].split('.')[0] + '.png'
                if len(sample) > 5:
                    # vid_path/label_id/frame_name
                    ref_name = os.path.join(sample[5], frame_name)
                elif self.multiclass:
                    # seq_name/label_id/frame_name
                    ref_name = os.path.join(*(sample[2].split('/')[-3:-1] +
                                              [frame_name]))
                else:
                    # seq_name/frame_name
                    ref_name = os.path.join(sample[2].split('/')[-2],
                                            frame_name)
            else:
                # only process visual guide image / mask
                first_frame = True
                guide_image = Image.open(sample[0])
                guide_label = Image.open(sample[1])
                if len(sample) > 5:
                    # vid_path/label_id/frame_name
                    ref_name = os.path.join(sample[5],
                                            sample[1].split('/')[-1])
                elif self.multiclass:
                    # seq_name/label_id/frame_name
                    ref_name = os.path.join(*(sample[1].split('/')[-3:]))
                else:
                    # seq_name/frame_name
                    ref_name = os.path.join(*(sample[1].split('/')[-2:]))
            if not first_frame:
                if len(self.size) == 2:
                    self.new_size = self.size
                else:
                    # resize short size of image to self.size[0]
                    resize_ratio = max(
                        float(self.size[0]) / image.size[0],
                        float(self.size[0]) / image.size[1])
                    self.new_size = (int(resize_ratio * image.size[0]),
                                     int(resize_ratio * image.size[1]))
                ref_label = ref_label.resize(self.new_size, Image.NEAREST)
                if label_id > 0:
                    ref_label = _get_obj_mask(ref_label, label_id)
                ref_label_data = np.array(ref_label)
                # Keep an un-normalised resized copy of the frame on self.
                image_ref_crf = image.resize(self.new_size, Image.BILINEAR)
                self.images.append(np.array(image_ref_crf))
                image = image.resize(self.new_size, Image.BILINEAR)
                if self.use_original_mask:
                    gb_image = ndimage.morphology.binary_dilation(
                        ref_label_data,
                        structure=self.args.dilate_structure) * 255
                else:
                    gb_image = get_gb_image(ref_label_data,
                                            center_perturb=0,
                                            std_perturb=0)
                image_data = np.array(image, dtype=np.float32)
                image_data = to_bgr(image_data)
                image_data = (image_data - self.mean_value) * self.scale_value
                gb_images.append(gb_image)
                images.append(image_data)
                images = np.array(images)
                gb_images = np.array(gb_images)[..., np.newaxis]
                guide_images = None
            else:
                # process visual guide images
                # resize to same size of guide_image first, in case of full
                # resolution input
                guide_label = guide_label.resize(guide_image.size,
                                                 Image.NEAREST)
                if label_id > 0:
                    guide_label = _get_obj_mask(guide_label, label_id)
                bbox = get_mask_bbox(np.array(guide_label))
                guide_image = guide_image.crop(bbox)
                guide_label = guide_label.crop(bbox)
                guide_image, guide_label = data_augmentation(
                    guide_image,
                    guide_label,
                    self.args.guide_size,
                    data_aug_flip=False,
                    pad_ratio=self.vg_pad_ratio,
                    keep_aspect_ratio=self.vg_keep_aspect_ratio)

                guide_image_data = np.array(guide_image, dtype=np.float32)
                guide_image_data = to_bgr(guide_image_data)
                guide_image_data = (guide_image_data -
                                    self.mean_value) * self.scale_value
                guide_label_data = np.array(guide_label, dtype=np.uint8)
                if not self.bbox_sup:
                    guide_image_data = mask_image(guide_image_data,
                                                  guide_label_data)
                guide_images.append(guide_image_data)
                guide_images = np.array(guide_images)
                images = None
                gb_images = None
            image_paths.append(ref_name)
            return guide_images, gb_images, images, image_paths
        else:
            return None, None, None, None
예제 #4
0
    def next_batch(self, batch_size, phase):
        """Return the next batch built from annotation records.

        Args:
            batch_size: Number of annotations to fetch.
            phase: 'train' or 'test'. Any other value yields four ``None``s.

        Returns:
            A 4-tuple ``(guide_images, gb_images, images, extra)`` of NumPy
            arrays, where ``gb_images`` gets a trailing axis of length 1.

            Training: ``extra`` is the float32 label array, also with a
            trailing axis of length 1.

            Testing: ``extra`` is a list of file names of the form
            ``'%06d.png' % index``, one per sample.
        """
        if phase == 'train':
            # "<=" so that a batch ending exactly on the last sample is still
            # served instead of being dropped by an early reshuffle.
            if self.train_ptr + batch_size <= self.train_size:
                idx = np.array(self.train_idx[self.train_ptr:self.train_ptr +
                                              batch_size])
                self.train_ptr += batch_size
            else:
                # Not enough samples left: reshuffle and restart at the front.
                np.random.shuffle(self.train_idx)
                new_ptr = batch_size
                idx = np.array(self.train_idx[:new_ptr])
                self.train_ptr = new_ptr
            images = []
            labels = []
            guide_images = []
            gb_images = []

            # One scale is drawn per batch so all samples share a size.
            if self.data_aug_scales:
                scale = random.choice(self.data_aug_scales)
                new_size = (int(self.size[0] * scale),
                            int(self.size[1] * scale))
            else:
                new_size = self.size
            for i in idx:
                anno = self.train_annos[i]
                image_path = self.train_image_path.format(anno['image_id'])
                image = Image.open(image_path)
                label_data = self.train_data.annToMask(anno).astype(np.uint8)
                label = Image.fromarray(label_data)

                # NOTE(review): PIL's crop expects (left, upper, right,
                # lower); confirm anno['bbox'] is already in that form and
                # not an [x, y, width, height] box.
                guide_image = image.crop(anno['bbox'])
                guide_label = label.crop(anno['bbox'])
                guide_image, guide_label = data_augmentation(
                    guide_image,
                    guide_label,
                    self.guide_size,
                    data_aug_flip=self.data_aug_flip,
                    keep_aspect_ratio=self.vg_keep_aspect_ratio,
                    random_crop_ratio=self.vg_random_crop_ratio,
                    random_rotate_angle=self.vg_random_rotate_angle,
                    color_aug=self.vg_color_aug)

                image, label = data_augmentation(
                    image,
                    label,
                    new_size,
                    data_aug_flip=self.data_aug_flip,
                    random_crop_ratio=self.random_crop_ratio)
                image_data = np.array(image, dtype=np.float32)
                label_data = np.array(label, dtype=np.float32)
                guide_image_data = np.array(guide_image, dtype=np.float32)
                guide_label_data = np.array(guide_label, dtype=np.uint8)
                # The spatial guide is derived from the augmented label.
                if self.use_original_mask:
                    gb_image = perturb_mask(label_data)
                    gb_image = ndimage.morphology.binary_dilation(
                        gb_image, structure=self.dilate_structure) * 255
                else:
                    gb_image = get_gb_image(
                        label_data,
                        center_perturb=self.sg_center_perturb_ratio,
                        std_perturb=self.sg_std_perturb_ratio)
                image_data = to_bgr(image_data)
                guide_image_data = to_bgr(guide_image_data)
                image_data -= self.mean_value
                guide_image_data -= self.mean_value
                # masking
                if not self.bbox_sup:
                    guide_image_data = mask_image(guide_image_data,
                                                  guide_label_data)
                images.append(image_data)
                labels.append(label_data)
                guide_images.append(guide_image_data)
                gb_images.append(gb_image)
            images = np.array(images)
            labels = np.array(labels)
            gb_images = np.array(gb_images)
            labels = labels[..., np.newaxis]
            gb_images = gb_images[..., np.newaxis]
            guide_images = np.array(guide_images)
            return guide_images, gb_images, images, labels
        elif phase == 'test':
            guide_images = []
            gb_images = []
            images = []
            image_paths = []
            if self.test_ptr + batch_size < self.test_size:
                idx = np.array(self.test_idx[self.test_ptr:self.test_ptr +
                                             batch_size])
                self.test_ptr += batch_size
            else:
                # Wrap around to the start of the test index list.
                new_ptr = (self.test_ptr + batch_size) % self.test_size
                idx = np.hstack(
                    (self.test_idx[self.test_ptr:], self.test_idx[:new_ptr]))
                self.test_ptr = new_ptr
            for i in idx:
                anno = self.test_annos[i]
                image_path = self.test_image_path.format(anno['image_id'])
                image = Image.open(image_path)
                label_data = self.test_data.annToMask(anno).astype(np.uint8)
                label = Image.fromarray(label_data)

                # NOTE(review): same bbox-format question as in training.
                guide_image = image.crop(anno['bbox'])
                guide_label = label.crop(anno['bbox'])
                guide_image, guide_label = data_augmentation(
                    guide_image,
                    guide_label,
                    self.guide_size,
                    keep_aspect_ratio=self.vg_keep_aspect_ratio)

                image, label = data_augmentation(image,
                                                 label,
                                                 self.size,
                                                 data_aug_flip=False)
                image_data = np.array(image, dtype=np.float32)
                guide_image_data = np.array(guide_image, dtype=np.float32)
                image_data = to_bgr(image_data)
                guide_image_data = to_bgr(guide_image_data)
                image_data -= self.mean_value
                guide_image_data -= self.mean_value
                label_data = np.array(label, dtype=np.uint8)
                # No perturbation at test time.
                if self.use_original_mask:
                    gb_image = ndimage.morphology.binary_dilation(
                        label_data, structure=self.dilate_structure) * 255
                else:
                    gb_image = get_gb_image(label_data,
                                            center_perturb=0,
                                            std_perturb=0)
                guide_label_data = np.array(guide_label, dtype=np.uint8)
                # masking
                if not self.bbox_sup:
                    guide_image_data = mask_image(guide_image_data,
                                                  guide_label_data)
                images.append(image_data)
                gb_images.append(gb_image)
                # only need file name for result saving
                image_paths.append('%06d.png' % i)
                guide_images.append(guide_image_data)
            images = np.array(images)
            gb_images = np.array(gb_images)
            gb_images = gb_images[..., np.newaxis]
            guide_images = np.array(guide_images)

            return guide_images, gb_images, images, image_paths
        else:
            return None, None, None, None