Example no. 1
0
    def __init__(self, configer):
        """Build augmentation and tensor transforms for train/val splits.

        Chooses the PIL- or cv2-backed augmentation pipeline based on the
        'data.image_tool' config entry and aborts on any other value.
        """
        self.configer = configer

        image_tool = self.configer.get('data', 'image_tool')
        if image_tool == 'pil':
            self.aug_train_transform = pil_aug_trans.PILAugCompose(
                self.configer, split='train')
            self.aug_val_transform = pil_aug_trans.PILAugCompose(
                self.configer, split='val')
        elif image_tool == 'cv2':
            self.aug_train_transform = cv2_aug_trans.CV2AugCompose(
                self.configer, split='train')
            self.aug_val_transform = cv2_aug_trans.CV2AugCompose(
                self.configer, split='val')
        else:
            # Unsupported backend is a hard configuration error.
            Log.error('Not support {} image tool.'.format(image_tool))
            exit(1)

        # Image tensor conversion + normalization from the 'normalize' config.
        normalize = trans.Normalize(
            div_value=self.configer.get('normalize', 'div_value'),
            mean=self.configer.get('normalize', 'mean'),
            std=self.configer.get('normalize', 'std'))
        self.img_transform = trans.Compose([trans.ToTensor(), normalize])

        # Label tensor conversion; 255 is remapped to -1 (presumably the
        # ignore index for the loss — confirm against the training code).
        self.label_transform = trans.Compose(
            [trans.ToLabel(), trans.ReLabel(255, -1)])
Example no. 2
0
    def __init__(self, configer):
        """Set up cv2-based augmentations plus image/label tensor transforms."""
        self.configer = configer

        from lib.datasets.tools import cv2_aug_transforms
        self.aug_train_transform = cv2_aug_transforms.CV2AugCompose(
            self.configer, split='train')
        self.aug_val_transform = cv2_aug_transforms.CV2AugCompose(
            self.configer, split='val')

        # Image tensor conversion + normalization from the 'normalize' config.
        normalize = trans.Normalize(
            div_value=self.configer.get('normalize', 'div_value'),
            mean=self.configer.get('normalize', 'mean'),
            std=self.configer.get('normalize', 'std'))
        self.img_transform = trans.Compose([trans.ToTensor(), normalize])

        # Label tensor conversion; 255 is remapped to -1 (presumably the
        # ignore index for the loss — confirm against the training code).
        self.label_transform = trans.Compose(
            [trans.ToLabel(), trans.ReLabel(255, -1)])
Example no. 3
0
    def __init__(self, configer):
        """Set up cv2 augmentations, a torchvision color jitter, and tensor
        transforms for images and labels."""
        self.configer = configer

        from lib.datasets.tools import cv2_aug_transforms
        self.aug_train_transform = cv2_aug_transforms.CV2AugCompose(
            self.configer, split='train')
        self.aug_val_transform = cv2_aug_transforms.CV2AugCompose(
            self.configer, split='val')

        # One config value ('train_trans.color_aug') drives all four
        # jitter strengths.
        color_aug = self.configer.get('train_trans', 'color_aug')
        self.torch_img_transform = torch_trans.ColorJitter(
            brightness=color_aug,
            contrast=color_aug,
            saturation=color_aug,
            hue=color_aug)

        # Image tensor conversion + normalization from the 'normalize' config.
        normalize = trans.Normalize(
            div_value=self.configer.get('normalize', 'div_value'),
            mean=self.configer.get('normalize', 'mean'),
            std=self.configer.get('normalize', 'std'))
        self.img_transform = trans.Compose([trans.ToTensor(), normalize])

        # Label tensor conversion; 255 is remapped to -1 (presumably the
        # ignore index for the loss — confirm against the training code).
        self.label_transform = trans.Compose(
            [trans.ToLabel(), trans.ReLabel(255, -1)])
Example no. 4
0
    def test(self, img_path=None, output_dir=None, data_loader=None):
        """Run single-image segmentation inference and save the label map.

        Reads one image, applies the validation augmentation + normalization,
        forwards it through the network, and writes the argmax label map as a
        palettized PNG named 'label_<input-name>.png'.

        Args:
            img_path: path of the image to segment; when None, falls back to
                the 'input_image' config entry.
            output_dir: directory for the output PNG; defaults to
                ``self.save_dir``.
            data_loader: unused; kept for interface compatibility.
        """
        self.seg_net.eval()
        start_time = time.time()

        Log.info('save dir {}'.format(self.save_dir))
        FileHelper.make_dirs(self.save_dir, is_file=False)

        # Reader.
        input_path = img_path if img_path is not None else \
            self.configer.get('input_image')
        input_image = cv2.imread(input_path)

        transform = trans.Compose([
            trans.ToTensor(),
            trans.Normalize(div_value=self.configer.get('normalize', 'div_value'),
                            mean=self.configer.get('normalize', 'mean'),
                            std=self.configer.get('normalize', 'std')), ])
        aug_val_transform = cv2_aug_transforms.CV2AugCompose(self.configer,
                                                             split='val')

        # cv2.resize takes (width, height), hence the [w, h] order.
        h, w, _ = input_image.shape
        ori_img_size = [w, h]

        # The augmentation returns a sequence; the image is its first element.
        input_image = aug_val_transform(input_image)[0]
        input_image = transform(input_image)

        with torch.no_grad():
            # Forward pass.
            outputs = self.ss_test([input_image])

            if isinstance(outputs, torch.Tensor):
                outputs = outputs.permute(0, 2, 3, 1).cpu().numpy()
            else:
                outputs = [output.permute(0, 2, 3, 1).cpu().numpy().squeeze()
                           for output in outputs]

            # Upsample logits back to the original image size, then argmax
            # over the class dimension.
            logits = cv2.resize(outputs[0],
                                tuple(ori_img_size), interpolation=cv2.INTER_CUBIC)
            label_img = np.asarray(np.argmax(logits, axis=-1), dtype=np.uint8)
            if self.configer.exists('data', 'reduce_zero_label') and \
                    self.configer.get('data', 'reduce_zero_label'):
                # Shift predictions up by one — presumably undoing the
                # zero-label reduction applied at training time; confirm.
                label_img = (label_img + 1).astype(np.uint8)
            if self.configer.exists('data', 'label_list'):
                label_img_ = self.__relabel(label_img)
            else:
                label_img_ = label_img
            label_img_ = Image.fromarray(label_img_, 'P')

            input_name = '.'.join(os.path.basename(input_path).split('.')[:-1])
            out_dir = self.save_dir if output_dir is None else output_dir
            label_path = os.path.join(out_dir, 'label_{}.png'.format(input_name))
            FileHelper.make_dirs(label_path, is_file=True)
            ImageHelper.save(label_img_, label_path)

        self.batch_time.update(time.time() - start_time)

        # Print the log info & reset the states.
        Log.info('Test Time {batch_time.sum:.3f}s'.format(batch_time=self.batch_time))
Example no. 5
0
    def test(self, data_loader=None):
        """Segment every frame of the configured input video and write an
        annotated output video.

        Each frame is segmented, blended with its colorized prediction, and
        concatenated with a fixed-width "ratio" side panel; the result is
        written to '<input-name>_out.avi' in the current directory.

        Args:
            data_loader: unused; kept for interface compatibility.
        """
        self.seg_net.eval()
        start_time = time.time()

        Log.info('save dir {}'.format(self.save_dir))
        FileHelper.make_dirs(self.save_dir, is_file=False)

        # Reader.
        input_path = self.configer.get('input_video')
        cap = cv2.VideoCapture(input_path)
        total_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        v_w = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
        v_h = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
        fps = cap.get(cv2.CAP_PROP_FPS)

        # Writer: each output frame is the input frame widened by the
        # fixed-size ratio panel.
        output_name = '.'.join(os.path.basename(input_path).split('.')[:-1])
        output_name = output_name + '_out.avi'
        RATIO_IMG_W = 200
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        out = cv2.VideoWriter(output_name, fourcc, fps,
                              (int(v_w + RATIO_IMG_W), int(v_h)))

        transform = trans.Compose([
            trans.ToTensor(),
            trans.Normalize(div_value=self.configer.get(
                'normalize', 'div_value'),
                            mean=self.configer.get('normalize', 'mean'),
                            std=self.configer.get('normalize', 'std')),
        ])

        aug_val_transform = cv2_aug_transforms.CV2AugCompose(self.configer,
                                                             split='val')

        try:
            for _ in tqdm(range(int(total_frames))):
                ret, img = cap.read()
                if not ret:
                    break

                ori_img = img.copy()

                # cv2.resize takes (width, height), hence the [w, h] order.
                h, w, _ = img.shape
                ori_img_size = [w, h]

                # The augmentation returns a sequence; the image is its
                # first element.
                img = aug_val_transform(img)[0]
                img = transform(img)

                with torch.no_grad():
                    # Forward pass.
                    outputs = self.ss_test([img])

                    if isinstance(outputs, torch.Tensor):
                        outputs = outputs.permute(0, 2, 3, 1).cpu().numpy()
                    else:
                        outputs = [
                            output.permute(0, 2, 3, 1).cpu().numpy().squeeze()
                            for output in outputs
                        ]

                    # Upsample logits back to the frame size, then argmax
                    # over the class dimension.
                    logits = cv2.resize(outputs[0],
                                        tuple(ori_img_size),
                                        interpolation=cv2.INTER_CUBIC)
                    label_img = np.asarray(np.argmax(logits, axis=-1),
                                           dtype=np.uint8)
                    if self.configer.exists(
                            'data', 'reduce_zero_label') and self.configer.get(
                                'data', 'reduce_zero_label'):
                        # Shift predictions up by one — presumably undoing
                        # the zero-label reduction applied at training time.
                        label_img = (label_img + 1).astype(np.uint8)
                    if self.configer.exists('data', 'label_list'):
                        label_img_ = self.__relabel(label_img)
                    else:
                        label_img_ = label_img

                    lines = self.get_ratio_all(label_img_)
                    vis_img = self.visualize(label_img_)

                    # Alpha-blend the colorized prediction onto the frame
                    # (in place into ori_img).
                    alpha = 0.5
                    cv2.addWeighted(vis_img, alpha, ori_img, 1 - alpha, 0,
                                    ori_img)

                    ratio_img = self.visualize_ratio(lines, (v_h, v_w),
                                                     RATIO_IMG_W)
                    target_img = cv2.hconcat([ori_img, ratio_img])
                    target_img = cv2.cvtColor(target_img, cv2.COLOR_RGB2BGR)
                    out.write(target_img)

                self.batch_time.update(time.time() - start_time)
                start_time = time.time()
        finally:
            # Always release the capture/writer handles so the output file
            # is finalized even if a frame fails.
            cap.release()
            out.release()

        # Print the log info & reset the states.
        Log.info('Test Time {batch_time.sum:.3f}s'.format(
            batch_time=self.batch_time))