def __init__(self, configer):
        """Keep the configuration and construct the helper objects.

        Args:
            configer: configuration object; stored on the instance and passed
                to each helper constructor below.
        """
        self.configer = configer

        self.seg_vis = SegVisualizer(configer)
        self.seg_model_manager = SegModelManager(configer)
        self.module_utilizer = ModuleUtilizer(configer)
        # No network is created here; the attribute starts out as None.
        self.seg_net = None
    def __init__(self, configer):
        """Build the helper objects, choose the device and initialise the model.

        Args:
            configer: configuration object shared with every helper.
        """
        self.configer = configer

        # Helper objects, all constructed from the same configuration.
        self.blob_helper = BlobHelper(configer)
        self.seg_visualizer = SegVisualizer(configer)
        self.seg_parser = SegParser(configer)
        self.seg_model_manager = SegModelManager(configer)
        self.seg_data_loader = DataLoader(configer)

        # The CPU is selected only when the 'gpu' entry is None.
        if self.configer.get('gpu') is None:
            device_name = 'cpu'
        else:
            device_name = 'cuda'
        self.device = torch.device(device_name)

        # Placeholder that _init_model() is expected to fill in.
        self.seg_net = None
        self._init_model()
Exemple #3
0
    def __init__(self, configer):
        """Build the helper objects and choose the torch device.

        Args:
            configer: configuration object shared with every helper.
        """
        self.configer = configer

        # Helper objects, all constructed from the same configuration.
        self.seg_visualizer = SegVisualizer(configer)
        self.seg_parser = SegParser(configer)
        self.seg_model_manager = SegModelManager(configer)
        self.seg_data_loader = SegDataLoader(configer)
        self.module_utilizer = ModuleUtilizer(configer)

        # The CPU is selected only when the 'gpu' entry is None.
        if self.configer.get('gpu') is None:
            device_name = 'cpu'
        else:
            device_name = 'cuda'
        self.device = torch.device(device_name)

        # No network is created here; the attribute starts out as None.
        self.seg_net = None
    def __init__(self, configer):
        """Keep the configuration and construct the helper objects.

        Args:
            configer: configuration object; stored on the instance and passed
                to each helper constructor below.
        """
        self.configer = configer

        self.seg_vis = SegVisualizer(configer)
        self.seg_model_manager = SegModelManager(configer)
        self.module_utilizer = ModuleUtilizer(configer)
        # No network is created here; the attribute starts out as None.
        self.seg_net = None
Exemple #5
0
    def __init__(self, configer):
        """Create the meters and helper objects; runtime state starts empty.

        Args:
            configer: configuration object shared with every helper.
        """
        self.configer = configer

        # Running meters for timing and for the train / val losses.
        self.batch_time, self.data_time = AverageMeter(), AverageMeter()
        self.train_losses, self.val_losses = AverageMeter(), AverageMeter()

        # Helper objects, all constructed from the same configuration.
        self.seg_visualizer = SegVisualizer(configer)
        self.seg_loss_manager = SegLossManager(configer)
        self.module_utilizer = ModuleUtilizer(configer)
        self.seg_model_manager = SegModelManager(configer)
        self.seg_data_loader = SegDataLoader(configer)

        # Runtime state; nothing is populated in the constructor.
        self.seg_net = None
        self.train_loader = self.val_loader = None
        self.optimizer = None
        self.lr = self.iters = None
    def __init__(self, configer):
        """Create the meters and helper objects, then initialise the model.

        Args:
            configer: configuration object shared with every helper.
        """
        self.configer = configer

        # Running meters for timing and for the train / val losses.
        self.batch_time, self.data_time = AverageMeter(), AverageMeter()
        self.train_losses, self.val_losses = AverageMeter(), AverageMeter()

        # Helper objects, all constructed from the same configuration.
        self.seg_running_score = SegRunningScore(configer)
        self.seg_visualizer = SegVisualizer(configer)
        self.seg_loss_manager = LossManager(configer)
        self.seg_model_manager = SegModelManager(configer)
        self.seg_data_loader = DataLoader(configer)

        # Runtime state; starts empty before _init_model() runs.
        self.seg_net = None
        self.train_loader = self.val_loader = None
        self.optimizer = self.scheduler = None
        self.runner_state = {}

        self._init_model()
Exemple #7
0
    def __init__(self, configer):
        """Create the meters and helper objects, then initialise the model.

        Args:
            configer: configuration object shared with every helper.
        """
        self.configer = configer

        # Running meters for timing and for the train / val losses.
        self.batch_time, self.data_time = AverageMeter(), AverageMeter()
        self.train_losses, self.val_losses = AverageMeter(), AverageMeter()

        # Helper objects, all constructed from the same configuration.
        self.seg_running_score = SegRunningScore(configer)
        self.seg_visualizer = SegVisualizer(configer)
        self.seg_loss_manager = SegLossManager(configer)
        self.module_utilizer = ModuleUtilizer(configer)
        self.data_transformer = DataTransformer(configer)
        self.seg_model_manager = SegModelManager(configer)
        self.seg_data_loader = SegDataLoader(configer)
        self.optim_scheduler = OptimScheduler(configer)

        # Runtime state; starts empty before _init_model() runs.
        self.seg_net = None
        self.train_loader = self.val_loader = None
        self.optimizer = self.scheduler = None

        self._init_model()
    def __init__(self, configer):
        """Create the meters and helper objects; runtime state starts empty.

        Args:
            configer: configuration object shared with every helper.
        """
        self.configer = configer

        # Running meters for timing and for the train / val losses.
        self.batch_time, self.data_time = AverageMeter(), AverageMeter()
        self.train_losses, self.val_losses = AverageMeter(), AverageMeter()

        # Helper objects, all constructed from the same configuration.
        self.seg_visualizer = SegVisualizer(configer)
        self.seg_loss_manager = SegLossManager(configer)
        self.module_utilizer = ModuleUtilizer(configer)
        self.seg_model_manager = SegModelManager(configer)
        self.seg_data_loader = SegDataLoader(configer)

        # Runtime state; nothing is populated in the constructor.
        self.seg_net = None
        self.train_loader = self.val_loader = None
        self.optimizer = None
        self.lr = self.iters = None
class FCNSegmentorTest(object):
    """Run a segmentation network on image files and save the predictions.

    Public entry points are ``init_model``, ``forward``, ``test`` and
    ``create_submission``.
    """

    def __init__(self, configer):
        """Store the configuration and construct the helper objects.

        Args:
            configer: configuration object passed to every helper.
        """
        self.configer = configer

        self.seg_vis = SegVisualizer(configer)
        self.seg_model_manager = SegModelManager(configer)
        self.module_utilizer = ModuleUtilizer(configer)
        self.seg_net = None

    def init_model(self):
        """Create the network, pass it through ``load_net`` and set eval mode."""
        self.seg_net = self.seg_model_manager.seg_net()
        self.seg_net, _ = self.module_utilizer.load_net(self.seg_net)
        self.seg_net.eval()

    def forward(self, image_path):
        """Predict a label map for a single image file.

        Args:
            image_path: path of the image to segment.

        Returns:
            The arg-max over axis 1 of the network output for the single
            input image, squeezed, as a numpy array.
        """
        image = Image.open(image_path).convert('RGB')
        image = RandomResize(size=self.configer.get('data', 'input_size'), is_base=False)(image)
        image = ToTensor()(image)
        image = Normalize(mean=[128.0, 128.0, 128.0], std=[256.0, 256.0, 256.0])(image)
        # The input is always moved to the GPU; volatile=True disables autograd.
        inputs = Variable(image.unsqueeze(0).cuda(), volatile=True)
        results = self.seg_net.forward(inputs)
        return results.data.cpu().numpy().argmax(axis=1)[0].squeeze()

    def __test_img(self, image_path, save_path):
        """Dispatch to the dataset-specific single-image routine.

        Logs an error and exits with status 1 for an unknown dataset.
        """
        if self.configer.get('dataset') == 'cityscape':
            self.__test_cityscape_img(image_path, save_path)
        elif self.configer.get('dataset') == 'laneline':
            self.__test_laneline_img(image_path, save_path)
        else:
            Log.error('Dataset: {} is not valid.'.format(self.configer.get('dataset')))
            exit(1)

    def __test_cityscape_img(self, img_path, save_path):
        """Segment *img_path* and save a colour-coded RGB image to *save_path*."""
        # One RGB colour per predicted class index.
        color_list = [(128, 64, 128), (244, 35, 232), (70, 70, 70), (102, 102, 156), (190, 153, 153),
                      (153, 153, 153), (250, 170, 30), (220, 220, 0), (107, 142, 35), (152, 251, 152),
                      (70, 130, 180), (220, 20, 60), (255, 0, 0), (0, 0, 142), (0, 0, 70), (0, 60, 100),
                      (0, 80, 100), (0, 0, 230), (119, 11, 32)]

        result = self.forward(img_path)
        width, height = self.__output_size()
        color_dst = np.zeros((height, width, 3), dtype=np.uint8)
        for i in range(self.configer.get('data', 'num_classes')):
            color_dst[result == i] = color_list[i]

        color_img = Image.fromarray(color_dst, 'RGB')
        color_img.save(save_path)

    def __test_laneline_img(self, img_path, save_path):
        """Placeholder for the 'laneline' dataset; currently does nothing."""
        pass

    def test(self):
        """Segment ``test_img`` or every image found under ``test_dir``.

        Exactly one of the two configuration entries must be set; otherwise
        an error is logged and the process exits with status 1.  Results are
        written below ``<project_dir>/val/results/seg/<dataset>/test``.
        """
        base_dir = os.path.join(self.configer.get('project_dir'),
                                'val/results/seg', self.configer.get('dataset'), 'test')
        self.__ensure_dir(base_dir)

        test_img = self.configer.get('test_img')
        test_dir = self.configer.get('test_dir')
        if test_img is None and test_dir is None:
            Log.error('test_img & test_dir not exists.')
            exit(1)

        if test_img is not None and test_dir is not None:
            Log.error('Either test_img or test_dir.')
            exit(1)

        if test_img is not None:
            filename = test_img.rstrip().split('/')[-1]
            save_path = os.path.join(base_dir, filename)
            self.__test_img(test_img, save_path)

        else:
            for filename in self.__list_dir(test_dir):
                image_path = os.path.join(test_dir, filename)
                save_path = os.path.join(base_dir, filename)
                # Files found in a sub-directory need their output folder.
                self.__ensure_dir(os.path.dirname(save_path))
                self.__test_img(image_path, save_path)

    def __create_cityscape_submission(self, test_dir=None, base_dir=None):
        """Write one label image per file in *test_dir* into *base_dir*.

        Predicted class indices are replaced by the ids in ``label_list``;
        pixels matching no class index keep the value 255.
        """
        label_list = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]

        # The output resolution is the same for every image.
        width, height = self.__output_size()
        num_classes = self.configer.get('data', 'num_classes')

        for filename in self.__list_dir(test_dir):
            image_path = os.path.join(test_dir, filename)
            save_path = os.path.join(base_dir, filename)
            # Files found in a sub-directory need their output folder.
            self.__ensure_dir(os.path.dirname(save_path))

            result = self.forward(image_path)
            label_dst = np.ones((height, width), dtype=np.uint8) * 255
            for i in range(num_classes):
                label_dst[result == i] = label_list[i]

            label_img = np.array(label_dst, dtype=np.uint8)
            label_img = Image.fromarray(label_img, 'P')
            label_img.save(save_path)

    def create_submission(self, test_dir=None):
        """Create submission label images for every file in *test_dir*.

        Only the 'cityscape' dataset is supported; any other value logs an
        error and exits with status 1.  Output goes below
        ``<project_dir>/val/results/seg/<dataset>/submission``.
        """
        base_dir = os.path.join(self.configer.get('project_dir'),
                                'val/results/seg', self.configer.get('dataset'), 'submission')
        self.__ensure_dir(base_dir)

        if self.configer.get('dataset') == 'cityscape':
            self.__create_cityscape_submission(test_dir, base_dir)

        else:
            Log.error('Dataset: {} is not valid.'.format(self.configer.get('dataset')))
            exit(1)

    def __output_size(self):
        """Return ``(width, height)``: the configured input size // network stride."""
        input_size = self.configer.get('data', 'input_size')
        stride = self.configer.get('network', 'stride')
        return input_size[0] // stride, input_size[1] // stride

    @staticmethod
    def __ensure_dir(dir_path):
        """Create *dir_path* (including parents) when it does not exist yet."""
        if not os.path.exists(dir_path):
            os.makedirs(dir_path)

    def __list_dir(self, dir_name):
        """List the files in *dir_name*, descending one level into sub-directories.

        Args:
            dir_name: directory to scan.

        Returns:
            Names relative to *dir_name*; entries of a sub-directory are
            reported as ``'<subdir>/<entry>'``.
        """
        filename_list = list()
        for item in os.listdir(dir_name):
            # isdir needs the full path: a bare entry name would be resolved
            # against the current working directory instead of dir_name.
            item_path = os.path.join(dir_name, item)
            if os.path.isdir(item_path):
                for filename in os.listdir(item_path):
                    filename_list.append('{}/{}'.format(item, filename))

            else:
                filename_list.append(item)

        return filename_list
Exemple #10
0
class FCNSegmentorTest(object):
    """Test-time driver: segments images and saves label maps to disk."""

    def __init__(self, configer):
        """Build the helper objects and choose the torch device.

        Args:
            configer: configuration object shared with every helper.
        """
        self.configer = configer

        self.seg_visualizer = SegVisualizer(configer)
        self.seg_parser = SegParser(configer)
        self.seg_model_manager = SegModelManager(configer)
        self.seg_data_loader = SegDataLoader(configer)
        self.module_utilizer = ModuleUtilizer(configer)

        # The CPU is selected only when the 'gpu' entry is None.
        if self.configer.get('gpu') is None:
            device_name = 'cpu'
        else:
            device_name = 'cuda'
        self.device = torch.device(device_name)

        self.seg_net = None

    def init_model(self):
        """Create the network, pass it through ``load_net`` and set eval mode."""
        self.seg_net = self.seg_model_manager.semantic_segmentor()
        restored = self.module_utilizer.load_net(self.seg_net)
        self.seg_net = restored
        self.seg_net.eval()

    def __test_img(self, image_path, save_path):
        """Segment one image and save the label map at the original size.

        Args:
            image_path: image file to read.
            save_path: destination of the palette-mode label image.
        """
        pil_image = ImageHelper.pil_open_rgb(image_path)
        ori_width, ori_height = pil_image.size

        # Pre-processing: resize, convert to tensor, normalise.
        scaled = Scale(size=self.configer.get('data', 'input_size'))(pil_image)
        tensor = ToTensor()(scaled)
        normalize = Normalize(mean=self.configer.get('trans_params', 'mean'),
                              std=self.configer.get('trans_params', 'std'))
        tensor = normalize(tensor)

        with torch.no_grad():
            batch = tensor.unsqueeze(0).to(self.device)
            results = self.seg_net.forward(batch)

        # Arg-max over the class axis of the single batch entry.
        label_map = results.data.cpu().numpy().argmax(axis=1)[0].squeeze()
        label_img = np.array(label_map, dtype=np.uint8)

        # Remap class indices only when a label list is configured.
        if not self.configer.is_empty('details', 'label_list'):
            label_img = self.__relabel(label_img)

        out_image = Image.fromarray(label_img, 'P')
        out_image = out_image.resize((ori_width, ori_height), Image.NEAREST)
        out_image.save(save_path)

    def __relabel(self, label_map):
        """Replace each class index by its entry in the configured label list.

        Args:
            label_map: 2-D array of class indices.

        Returns:
            A new uint8 array of the same shape; values that match no class
            index are left at 0.
        """
        rows, cols = label_map.shape
        remapped = np.zeros((rows, cols), dtype=np.uint8)
        for class_idx in range(self.configer.get('data', 'num_classes')):
            target_id = self.configer.get('details', 'label_list')[class_idx]
            remapped[label_map == class_idx] = target_id

        return np.array(remapped, dtype=np.uint8)

    def test(self):
        """Segment ``test_img`` or every file under ``test_dir``.

        Exactly one of the two configuration entries must be set; otherwise
        an error is logged and the process exits with status 1.
        """
        result_root = os.path.join(self.configer.get('output_dir'),
                                   'val/results/seg',
                                   self.configer.get('dataset'))

        test_img = self.configer.get('test_img')
        test_dir = self.configer.get('test_dir')
        has_img = test_img is not None
        has_dir = test_dir is not None

        if not (has_img or has_dir):
            Log.error('test_img & test_dir not exists.')
            exit(1)

        if has_img and has_dir:
            Log.error('Either test_img or test_dir.')
            exit(1)

        if has_img:
            # Single image: results go to <root>/test_img/<file name>.
            out_dir = os.path.join(result_root, 'test_img')
            if not os.path.exists(out_dir):
                os.makedirs(out_dir)

            file_name = test_img.rstrip().rsplit('/', 1)[-1]
            self.__test_img(test_img, os.path.join(out_dir, file_name))
            return

        # Directory: results go to <root>/test_dir/<last path component>.
        dir_tag = test_dir.rstrip('/').rsplit('/', 1)[-1]
        out_dir = os.path.join(result_root, 'test_dir', dir_tag)
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        for rel_name in FileHelper.list_dir(test_dir):
            image_path = os.path.join(test_dir, rel_name)
            save_path = os.path.join(out_dir, rel_name)
            parent = os.path.dirname(save_path)
            if not os.path.exists(parent):
                os.makedirs(os.path.dirname(save_path))

            self.__test_img(image_path, save_path)

    def debug(self):
        """Visualise up to 20 validation samples, then exit the process.

        Each sample is de-normalised, overlaid with its colourised target,
        written to ``<project_dir>/vis/results/seg/<dataset>/debug`` and
        shown in a window that waits for a key press.
        """
        out_dir = os.path.join(self.configer.get('project_dir'),
                               'vis/results/seg',
                               self.configer.get('dataset'), 'debug')
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        loader = self.seg_data_loader.get_valloader()

        shown = 0
        for batch_idx, (inputs, targets) in enumerate(loader):
            for sample_idx in range(inputs.size(0)):
                shown += 1
                if shown > 20:
                    exit(1)

                denormalize = DeNormalize(
                    mean=self.configer.get('trans_params', 'mean'),
                    std=self.configer.get('trans_params', 'std'))
                rgb = denormalize(inputs[sample_idx])
                rgb = rgb.numpy().transpose(1, 2, 0).astype(np.uint8)

                image_bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
                label_map = targets[sample_idx].numpy()
                canvas = self.seg_parser.colorize(label_map,
                                                  image_canvas=image_bgr)

                out_name = '{}_{}_vis.png'.format(batch_idx, sample_idx)
                cv2.imwrite(os.path.join(out_dir, out_name), canvas)
                cv2.imshow('main', canvas)
                cv2.waitKey()
class FCNSegmentor(object):
    """
      The class for Semantic Segmentation. Includes the train & val loops.
    """
    def __init__(self, configer):
        """Create the meters and helper objects; runtime state starts empty.

        Args:
            configer: configuration object shared with every helper.
        """
        self.configer = configer
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.train_losses = AverageMeter()
        self.val_losses = AverageMeter()
        self.seg_running_score = SegRunningScore(configer)
        self.seg_visualizer = SegVisualizer(configer)
        self.seg_loss_manager = SegLossManager(configer)
        self.module_utilizer = ModuleUtilizer(configer)
        self.seg_model_manager = SegModelManager(configer)
        self.seg_data_loader = SegDataLoader(configer)
        self.optim_scheduler = OptimScheduler(configer)

        # Populated by init_model().
        self.seg_net = None
        self.train_loader = None
        self.val_loader = None
        self.optimizer = None
        self.scheduler = None
        self.pixel_loss = None

    def init_model(self):
        """Create the network, optimizer, scheduler, data loaders and loss."""
        self.seg_net = self.seg_model_manager.semantic_segmentor()
        self.seg_net = self.module_utilizer.load_net(self.seg_net)

        self.optimizer, self.scheduler = self.optim_scheduler.init_optimizer(
            self._get_parameters())

        self.train_loader = self.seg_data_loader.get_trainloader()
        self.val_loader = self.seg_data_loader.get_valloader()

        self.pixel_loss = self.seg_loss_manager.get_seg_loss(
            'cross_entropy_loss')

    def _get_parameters(self):
        """Return the parameters handed to the optimizer."""
        return self.seg_net.parameters()

    def __train(self):
        """
          Train function of every epoch during train phase.
        """
        # When resuming at iteration 0, validate once before training.
        if self.configer.get(
                'network',
                'resume') is not None and self.configer.get('iters') == 0:
            self.__val()

        self.seg_net.train()
        start_time = time.time()
        # Adjust the learning rate after every epoch.
        self.configer.plus_one('epoch')
        self.scheduler.step(self.configer.get('epoch'))

        for i, (inputs, targets) in enumerate(self.train_loader):
            self.data_time.update(time.time() - start_time)
            # Change the data type.

            inputs, targets = self.module_utilizer.to_device(inputs, targets)

            # Forward pass.
            outputs = self.seg_net(inputs)

            # Compute the loss of the train batch & backward.
            loss = self.pixel_loss(outputs, targets)
            self.train_losses.update(loss.item(), inputs.size(0))
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.configer.plus_one('iters')

            # Print the log info & reset the states.
            if self.configer.get('iters') % self.configer.get(
                    'solver', 'display_iter') == 0:
                # data_time.avg uses '.3f' like the other timings; the former
                # '3f' spec was a minimum width, not a precision.
                Log.info(
                    'Train Epoch: {0}\tTrain Iteration: {1}\t'
                    'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:.3f})\n'
                    'Learning rate = {3}\tLoss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'
                    .format(self.configer.get('epoch'),
                            self.configer.get('iters'),
                            self.configer.get('solver', 'display_iter'),
                            self.scheduler.get_lr(),
                            batch_time=self.batch_time,
                            data_time=self.data_time,
                            loss=self.train_losses))
                self.batch_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            # Check to val the current model.
            if self.val_loader is not None and \
               self.configer.get('iters') % self.configer.get('solver', 'test_interval') == 0:
                self.__val()

    def __val(self):
        """
          Validation function during the train phase.
        """
        self.seg_net.eval()
        start_time = time.time()
        with torch.no_grad():
            for j, (inputs, targets) in enumerate(self.val_loader):
                # Change the data type.
                inputs, targets = self.module_utilizer.to_device(
                    inputs, targets)
                # Forward pass.
                outputs = self.seg_net(inputs)
                # Compute the loss of the val batch.
                loss = self.pixel_loss(outputs, targets)

                self.val_losses.update(loss.item(), inputs.size(0))
                self.seg_running_score.update(
                    outputs.max(1)[1].unsqueeze(1).data, targets.data)

                # Update the vars of the val phase.
                self.batch_time.update(time.time() - start_time)
                start_time = time.time()

            mean_iou = self.seg_running_score.get_mean_iou()
            # The key must match the metric name given to save_net() below;
            # it used to be misspelled as 'performace'.
            self.configer.update_value(['performance'], mean_iou)
            self.configer.update_value(['val_loss'], self.val_losses.avg)
            self.module_utilizer.save_net(self.seg_net, metric='performance')
            self.module_utilizer.save_net(self.seg_net, metric='val_loss')

            # Print the log info & reset the states.
            Log.info(
                'Test Time {batch_time.sum:.3f}s, ({batch_time.avg:.3f})\t'
                'Loss {loss.avg:.8f}\n'.format(batch_time=self.batch_time,
                                               loss=self.val_losses))
            Log.info('Mean IOU: {}'.format(mean_iou))
            self.batch_time.reset()
            self.val_losses.reset()
            self.seg_running_score.reset()
            # Return to training mode for the caller.
            self.seg_net.train()

    def train(self):
        """Run training epochs until 'epoch' reaches ('solver', 'max_epoch')."""
        cudnn.benchmark = True
        # __train() increments the epoch counter, so the loop condition alone
        # terminates the run.
        while self.configer.get('epoch') < self.configer.get(
                'solver', 'max_epoch'):
            self.__train()
class FCNSegmentorTest(object):
    """Test-time driver supporting single/multi-scale and cropped inference.

    ``test_img`` selects one of ``ss_test``, ``sscrop_test``, ``ms_test`` or
    ``mscrop_test`` from the ('test', 'mode') configuration entry.
    """

    def __init__(self, configer):
        """Build the helper objects, choose the device and initialise the model.

        Args:
            configer: configuration object shared with every helper.
        """
        self.configer = configer
        self.blob_helper = BlobHelper(configer)
        self.seg_visualizer = SegVisualizer(configer)
        self.seg_parser = SegParser(configer)
        self.seg_model_manager = SegModelManager(configer)
        self.seg_data_loader = DataLoader(configer)
        self.device = torch.device('cpu' if self.configer.get('gpu') is None else 'cuda')
        self.seg_net = None

        self._init_model()

    def _init_model(self):
        """Create the network, pass it through ``load_net`` and set eval mode."""
        self.seg_net = self.seg_model_manager.semantic_segmentor()
        self.seg_net = RunnerHelper.load_net(self, self.seg_net)
        self.seg_net.eval()

    def _get_blob(self, ori_image, scale=None):
        """Turn an image into a network input, optionally padded to a stride.

        Args:
            ori_image: image as returned by the image helper.
            scale: scale factor forwarded to ``make_input``; must not be None.

        Returns:
            ``(image, border_hw)`` where ``image`` is the 4-D input tensor
            (zero padded on the right/bottom when ('test', 'fit_stride') is
            configured) and ``border_hw`` is ``[h, w]`` before padding.
        """
        assert scale is not None

        # 'input_size' takes precedence; otherwise use whichever of the
        # min/max side lengths are configured.
        if self.configer.exists('test', 'input_size'):
            size_kwargs = {'input_size': self.configer.get('test', 'input_size')}
        else:
            size_kwargs = {}
            for key in ('min_side_length', 'max_side_length'):
                if self.configer.exists('test', key):
                    size_kwargs[key] = self.configer.get('test', key)

            if not size_kwargs:
                Log.error('Test setting error')
                exit(1)

        image = self.blob_helper.make_input(image=ori_image, scale=scale, **size_kwargs)

        b, c, h, w = image.size()
        border_hw = [h, w]
        if self.configer.exists('test', 'fit_stride'):
            stride = self.configer.get('test', 'fit_stride')

            pad_w = 0 if (w % stride == 0) else stride - (w % stride)  # right
            pad_h = 0 if (h % stride == 0) else stride - (h % stride)  # down

            expand_image = torch.zeros((b, c, h + pad_h, w + pad_w)).to(image.device)
            expand_image[:, :, 0:h, 0:w] = image
            image = expand_image

        return image, border_hw

    def test_img(self, image_path, label_path, vis_path, raw_path):
        """Segment one image and save the visualisation, raw image and labels.

        Args:
            image_path: image file to read.
            label_path: destination of the palette-mode label image.
            vis_path: destination of the colourised overlay.
            raw_path: destination of a copy of the input image.
        """
        Log.info('Image Path: {}'.format(image_path))
        ori_image = ImageHelper.read_image(image_path,
                                           tool=self.configer.get('data', 'image_tool'),
                                           mode=self.configer.get('data', 'input_mode'))

        testers = {
            'ss_test': self.ss_test,
            'sscrop_test': self.sscrop_test,
            'ms_test': self.ms_test,
            'mscrop_test': self.mscrop_test,
        }
        mode = self.configer.get('test', 'mode')
        if mode not in testers:
            Log.error('Invalid test mode:{}'.format(mode))
            exit(1)

        total_logits = testers[mode](ori_image)

        label_map = np.argmax(total_logits, axis=-1)
        label_img = np.array(label_map, dtype=np.uint8)
        ori_img_bgr = ImageHelper.get_cv2_bgr(ori_image, mode=self.configer.get('data', 'input_mode'))
        image_canvas = self.seg_parser.colorize(label_img, image_canvas=ori_img_bgr)
        ImageHelper.save(image_canvas, save_path=vis_path)
        ImageHelper.save(ori_image, save_path=raw_path)

        if self.configer.exists('data', 'label_list'):
            label_img = self.__relabel(label_img)

        # Shift every label up by one when 'reduce_zero_label' is enabled.
        if self.configer.exists('data', 'reduce_zero_label') and self.configer.get('data', 'reduce_zero_label'):
            label_img = label_img + 1
            label_img = label_img.astype(np.uint8)

        label_img = Image.fromarray(label_img, 'P')
        Log.info('Label Path: {}'.format(label_path))
        ImageHelper.save(label_img, label_path)

    def _new_logits(self, ori_image):
        """Return ``(ori_width, ori_height, zero logits)`` for *ori_image*."""
        ori_width, ori_height = ImageHelper.get_size(ori_image)
        num_classes = self.configer.get('data', 'num_classes')
        return ori_width, ori_height, np.zeros((ori_height, ori_width, num_classes), np.float32)

    def _predict_maybe_cropped(self, image, border_hw):
        """Predict on *image* (in crops when it exceeds the crop size).

        The result is trimmed to *border_hw* so that stride padding added by
        ``_get_blob`` does not leak into the later resize.
        """
        crop_size = self.configer.get('test', 'crop_size')
        if image.size()[3] > crop_size[0] and image.size()[2] > crop_size[1]:
            results = self._crop_predict(image, crop_size)
        else:
            results = self._predict(image)

        return results[:border_hw[0], :border_hw[1]]

    def ss_test(self, ori_image):
        """Single-scale inference; returns logits at the original image size."""
        ori_width, ori_height, total_logits = self._new_logits(ori_image)
        image, border_hw = self._get_blob(ori_image, scale=1.0)
        results = self._predict(image)
        results = cv2.resize(results[:border_hw[0], :border_hw[1]],
                             (ori_width, ori_height), interpolation=cv2.INTER_CUBIC)
        total_logits += results
        return total_logits

    def sscrop_test(self, ori_image):
        """Single-scale inference using crops when the input exceeds crop_size."""
        ori_width, ori_height, total_logits = self._new_logits(ori_image)
        image, border_hw = self._get_blob(ori_image, scale=1.0)
        results = self._predict_maybe_cropped(image, border_hw)
        results = cv2.resize(results, (ori_width, ori_height), interpolation=cv2.INTER_CUBIC)
        total_logits += results
        return total_logits

    def mscrop_test(self, ori_image):
        """Sum cropped predictions over every scale in ('test', 'scale_search')."""
        ori_width, ori_height, total_logits = self._new_logits(ori_image)
        for scale in self.configer.get('test', 'scale_search'):
            image, border_hw = self._get_blob(ori_image, scale=scale)
            results = self._predict_maybe_cropped(image, border_hw)
            results = cv2.resize(results, (ori_width, ori_height), interpolation=cv2.INTER_CUBIC)
            total_logits += results

        return total_logits

    def ms_test(self, ori_image):
        """Sum predictions over all scales plus one mirrored pass at scale 1."""
        ori_width, ori_height, total_logits = self._new_logits(ori_image)
        for scale in self.configer.get('test', 'scale_search'):
            image, border_hw = self._get_blob(ori_image, scale=scale)
            results = self._predict(image)
            results = cv2.resize(results[:border_hw[0], :border_hw[1]],
                                 (ori_width, ori_height), interpolation=cv2.INTER_CUBIC)
            total_logits += results

        # Horizontal flip, using whichever image library loaded the image.
        if self.configer.get('data', 'image_tool') == 'cv2':
            mirror_image = cv2.flip(ori_image, 1)
        else:
            mirror_image = ori_image.transpose(Image.FLIP_LEFT_RIGHT)

        image, border_hw = self._get_blob(mirror_image, scale=1.0)
        results = self._predict(image)
        results = results[:border_hw[0], :border_hw[1]]
        # Flip the prediction back before accumulating it.
        results = cv2.resize(results[:, ::-1], (ori_width, ori_height), interpolation=cv2.INTER_CUBIC)
        total_logits += results
        return total_logits

    def _crop_predict(self, image, crop_size):
        """Predict on overlapping crops and add them back into one map.

        Args:
            image: input tensor with a leading batch dimension of one.
            crop_size: ``(crop_width, crop_height)``.

        Returns:
            Float32 array of shape ``(height, width, channels)``.
        """
        height, width = image.size()[2:]
        np_image = image.squeeze(0).permute(1, 2, 0).cpu().numpy()
        height_starts = self._decide_intersection(height, crop_size[1])
        width_starts = self._decide_intersection(width, crop_size[0])

        # Row-major list of crop origins, reused for re-assembly below.
        origins = [(top, left) for top in height_starts for left in width_starts]
        split_crops = [np_image[top:top + crop_size[1], left:left + crop_size[0]][np.newaxis, :]
                       for top, left in origins]

        split_crops = np.concatenate(split_crops, axis=0)  # (n, crop_h, crop_w, channels)
        inputs = torch.from_numpy(split_crops).permute(0, 3, 1, 2).to(self.device)
        with torch.no_grad():
            results = self.seg_net.forward(inputs)
            results = results[-1].permute(0, 2, 3, 1).cpu().numpy()

        # NOTE(review): overlapping regions are summed, not averaged.
        reassemble = np.zeros((np_image.shape[0], np_image.shape[1], results.shape[-1]), np.float32)
        for index, (top, left) in enumerate(origins):
            reassemble[top:top + crop_size[1], left:left + crop_size[0]] += results[index]

        return reassemble

    def _decide_intersection(self, total_length, crop_length):
        """Return the crop start offsets along one axis.

        Offsets advance by ``crop_length * ('test', 'crop_stride_ratio')``;
        a final offset is appended when needed so the last crop ends exactly
        at *total_length*.
        """
        stride = int(crop_length * self.configer.get('test', 'crop_stride_ratio'))
        times = (total_length - crop_length) // stride + 1
        cropped_starting = [stride * i for i in range(times)]

        if total_length - cropped_starting[-1] > crop_length:
            cropped_starting.append(total_length - crop_length)  # must cover the total image

        return cropped_starting

    def _predict(self, inputs):
        """Run the network and return the last output as an (H, W, C) array."""
        with torch.no_grad():
            results = self.seg_net.forward(inputs)
            results = results[-1].squeeze(0).permute(1, 2, 0).cpu().numpy()

        return results

    def __relabel(self, label_map):
        """Replace each class index by its entry in ('data', 'label_list').

        Values that match no class index are left at 0.
        """
        height, width = label_map.shape
        label_dst = np.zeros((height, width), dtype=np.uint8)
        label_list = self.configer.get('data', 'label_list')
        for i in range(self.configer.get('data', 'num_classes')):
            label_dst[label_map == i] = label_list[i]

        return np.array(label_dst, dtype=np.uint8)

    def debug(self, vis_dir):
        """Visualise up to 20 training samples in *vis_dir*, then exit.

        Each sample is written as ``<batch>_<index>_vis.png`` and shown in a
        window that waits for a key press.
        """
        count = 0
        for i, data_dict in enumerate(self.seg_data_loader.get_trainloader()):
            inputs = data_dict['img']
            targets = data_dict['labelmap']
            for j in range(inputs.size(0)):
                count = count + 1
                if count > 20:
                    exit(1)

                image_bgr = self.blob_helper.tensor2bgr(inputs[j])
                label_map = targets[j].numpy()
                image_canvas = self.seg_parser.colorize(label_map, image_canvas=image_bgr)
                cv2.imwrite(os.path.join(vis_dir, '{}_{}_vis.png'.format(i, j)), image_canvas)
                cv2.imshow('main', image_canvas)
                cv2.waitKey()
class FCNSegmentor(object):
    """
      Trainer for an FCN semantic segmentation network. Include train & val.
    """
    def __init__(self, configer):
        self.configer = configer
        # Running meters for timing and loss statistics.
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.train_losses = AverageMeter()
        self.val_losses = AverageMeter()
        self.seg_visualizer = SegVisualizer(configer)
        self.seg_loss_manager = SegLossManager(configer)
        self.module_utilizer = ModuleUtilizer(configer)
        self.seg_model_manager = SegModelManager(configer)
        self.seg_data_loader = SegDataLoader(configer)

        # Everything below is populated by init_model().
        self.seg_net = None
        self.train_loader = None
        self.val_loader = None
        self.optimizer = None
        self.lr = None
        self.iters = None
        self.pixel_loss = None

    def init_model(self):
        """
          Build the network, optimizer, data loaders and loss.
          Has to run before train(), which relies on self.iters being a number.
        """
        self.seg_net = self.seg_model_manager.seg_net()
        self.iters = 0
        self.seg_net, _ = self.module_utilizer.load_net(self.seg_net)

        self.optimizer, self.lr = self.module_utilizer.update_optimizer(self.seg_net, self.iters)

        # Only the 'cityscape' dataset is wired up here.
        if self.configer.get('dataset') == 'cityscape':
            self.train_loader = self.seg_data_loader.get_trainloader(FSCityScapeLoader)
            self.val_loader = self.seg_data_loader.get_valloader(FSCityScapeLoader)

        else:
            Log.error('Dataset: {} is not valid!'.format(self.configer.get('dataset')))
            exit(1)

        self.pixel_loss = self.seg_loss_manager.get_seg_loss('cross_entropy_loss')

    def __train(self):
        """
          Train function of every epoch during train phase.
        """
        self.seg_net.train()
        start_time = time.time()

        # data_tuple: only the first two entries (inputs, targets) are used.
        for data_tuple in self.train_loader:
            self.data_time.update(time.time() - start_time)
            if len(data_tuple) < 2:
                Log.error('Train Loader Error!')
                # A broken loader is a failure: exit non-zero like the other error paths.
                exit(1)

            # `async` became a reserved word in Python 3.7; PyTorch renamed the
            # keyword to `non_blocking`.
            inputs = data_tuple[0].cuda(non_blocking=True)
            targets = data_tuple[1].cuda(non_blocking=True)

            # Forward pass.
            outputs = self.seg_net(inputs)

            # Compute the loss of the train batch & backward.
            loss = self.pixel_loss(outputs, targets)
            # loss is a 0-dim tensor; .item() replaces the removed loss.data[0] idiom.
            self.train_losses.update(loss.item(), inputs.size(0))
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.iters += 1

            # Print the log info & reset the states.
            if self.iters % self.configer.get('solver', 'display_iter') == 0:
                Log.info('Train Iteration: {0}\t'
                         'Time {batch_time.sum:.3f}s / {1}iters, ({batch_time.avg:.3f})\t'
                         'Data load {data_time.sum:.3f}s / {1}iters, ({data_time.avg:.3f})\n'
                         'Learning rate = {2}\n'
                         'Loss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'.format(
                         self.iters, self.configer.get('solver', 'display_iter'),
                         self.lr, batch_time=self.batch_time,
                         data_time=self.data_time, loss=self.train_losses))
                self.batch_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            # Check to val the current model.
            if self.val_loader is not None and \
               self.iters % self.configer.get('solver', 'test_interval') == 0:
                self.__val()

            # The optimizer / learning rate are refreshed after every iteration.
            self.optimizer, self.lr = self.module_utilizer.update_optimizer(self.seg_net, self.iters)

    def __val(self):
        """
          Validation function during the train phase.
        """
        # Local import keeps this block self-contained; torch.no_grad() replaces
        # Variable(..., volatile=True), which no longer disables autograd.
        import torch

        self.seg_net.eval()
        start_time = time.time()

        with torch.no_grad():
            for data_tuple in self.val_loader:
                inputs = data_tuple[0].cuda(non_blocking=True)
                targets = data_tuple[1].cuda(non_blocking=True)
                # Forward pass.
                outputs = self.seg_net(inputs)
                # Compute the loss of the val batch.
                loss = self.pixel_loss(outputs, targets)

                self.val_losses.update(loss.item(), inputs.size(0))

                # Update the vars of the val phase.
                self.batch_time.update(time.time() - start_time)
                start_time = time.time()

        self.module_utilizer.save_net(self.seg_net, self.iters)
        # Print the log info & reset the states.
        Log.info(
            'Test Time {batch_time.sum:.3f}s, ({batch_time.avg:.3f})\t'
            'Loss {loss.avg:.8f}\n'.format(
            batch_time=self.batch_time, loss=self.val_losses))
        self.batch_time.reset()
        self.val_losses.reset()
        # Hand the network back to the training loop in train mode.
        self.seg_net.train()

    def train(self):
        """
          Run training epochs until self.iters reaches ('solver', 'max_iter').
          The bound is only checked between epochs, so the final epoch may overshoot it.
        """
        cudnn.benchmark = True
        while self.iters < self.configer.get('solver', 'max_iter'):
            self.__train()
# Example #14 (0)
class FCNSegmentor(object):
    """
      Trainer for an FCN semantic segmentation network. Include train & val.
    """
    def __init__(self, configer):
        self.configer = configer
        # Running meters for timing and loss statistics.
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.train_losses = AverageMeter()
        self.val_losses = AverageMeter()
        self.seg_running_score = SegRunningScore(configer)
        self.seg_visualizer = SegVisualizer(configer)
        self.seg_loss_manager = SegLossManager(configer)
        self.module_utilizer = ModuleUtilizer(configer)
        self.data_transformer = DataTransformer(configer)
        self.seg_model_manager = SegModelManager(configer)
        self.seg_data_loader = SegDataLoader(configer)
        self.optim_scheduler = OptimScheduler(configer)

        # Populated by _init_model() below.
        self.seg_net = None
        self.train_loader = None
        self.val_loader = None
        self.optimizer = None
        self.scheduler = None
        self.pixel_loss = None

        self._init_model()

    def _init_model(self):
        """
          Build the network, optimizer/scheduler, data loaders and loss.
        """
        self.seg_net = self.seg_model_manager.semantic_segmentor()
        self.seg_net = self.module_utilizer.load_net(self.seg_net)

        self.optimizer, self.scheduler = self.optim_scheduler.init_optimizer(
            self._get_parameters())

        self.train_loader = self.seg_data_loader.get_trainloader()
        self.val_loader = self.seg_data_loader.get_valloader()

        self.pixel_loss = self.seg_loss_manager.get_seg_loss('fcn_seg_loss')

        # With synchronized BN the criterion is wrapped and moved to CUDA.
        if self.configer.get('network', 'bn_type') == 'syncbn':
            self.pixel_loss = DataParallelCriterion(self.pixel_loss).cuda()

    def _get_parameters(self):
        """
          Split the network parameters into two optimizer groups.

          The first group holds parameters whose name contains 'backbone.',
          the second group holds all the others. Both groups currently get
          ('lr', 'base_lr'); the second one is multiplied by 1.0.
        """
        lr_1 = []
        lr_10 = []
        params_dict = dict(self.seg_net.named_parameters())
        for key, value in params_dict.items():
            if 'backbone.' not in key:
                lr_10.append(value)
            else:
                lr_1.append(value)

        base_lr = self.configer.get('lr', 'base_lr')
        return [
            {'params': lr_1, 'lr': base_lr},
            {'params': lr_10, 'lr': base_lr * 1.0},
        ]

    def __train(self):
        """
          Train function of every epoch during train phase.
        """
        self.seg_net.train()
        start_time = time.time()

        # Adjust the learning rate once per epoch, before the batches run.
        self.scheduler.step(self.configer.get('epoch'))

        for data_dict in self.train_loader:
            inputs = data_dict['img']
            targets = data_dict['labelmap']
            self.data_time.update(time.time() - start_time)

            # Change the data type.
            inputs, targets = self.module_utilizer.to_device(inputs, targets)

            # Forward pass.
            outputs = self.seg_net(inputs)

            # Compute the loss of the train batch & backward.
            loss = self.pixel_loss(outputs, targets)
            self.train_losses.update(loss.item(), inputs.size(0))
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.configer.plus_one('iters')

            # Print the log info & reset the states.
            if self.configer.get('iters') % self.configer.get(
                    'solver', 'display_iter') == 0:
                # data_time.avg uses '.3f' like its neighbours (was the typo '3f').
                Log.info(
                    'Train Epoch: {0}\tTrain Iteration: {1}\t'
                    'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:.3f})\n'
                    'Learning rate = {3}\tLoss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'
                    .format(self.configer.get('epoch'),
                            self.configer.get('iters'),
                            self.configer.get('solver', 'display_iter'),
                            self.scheduler.get_lr(),
                            batch_time=self.batch_time,
                            data_time=self.data_time,
                            loss=self.train_losses))
                self.batch_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            # Check to val the current model.
            if self.val_loader is not None and \
               self.configer.get('iters') % self.configer.get('solver', 'test_interval') == 0:
                self.__val()

        self.configer.plus_one('epoch')

    def __val(self):
        """
          Validation function during the train phase.
        """
        self.seg_net.eval()
        start_time = time.time()

        for data_dict in self.val_loader:
            inputs = data_dict['img']
            targets = data_dict['labelmap']

            with torch.no_grad():
                # Change the data type.
                inputs, targets = self.module_utilizer.to_device(
                    inputs, targets)
                # Forward pass.
                outputs = self.seg_net(inputs)
                # Compute the loss of the val batch.
                loss = self.pixel_loss(outputs, targets)

                outputs = self.module_utilizer.gather(outputs)
                pred = outputs[0]

            self.val_losses.update(loss.item(), inputs.size(0))
            # Arg-max over dim 1 of the prediction is scored against the targets.
            self.seg_running_score.update(
                pred.max(1)[1].cpu().numpy(),
                targets.cpu().numpy())

            # Update the vars of the val phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()

        # Publish both metrics, then let save_net act on each save mode.
        self.configer.update_value(['performance'],
                                   self.seg_running_score.get_mean_iou())
        self.configer.update_value(['val_loss'], self.val_losses.avg)
        self.module_utilizer.save_net(self.seg_net, save_mode='performance')
        self.module_utilizer.save_net(self.seg_net, save_mode='val_loss')

        # Print the log info & reset the states.
        Log.info('Test Time {batch_time.sum:.3f}s, ({batch_time.avg:.3f})\t'
                 'Loss {loss.avg:.8f}\n'.format(batch_time=self.batch_time,
                                                loss=self.val_losses))
        Log.info('Mean IOU: {}\n'.format(
            self.seg_running_score.get_mean_iou()))
        self.batch_time.reset()
        self.val_losses.reset()
        self.seg_running_score.reset()
        self.seg_net.train()

    def train(self):
        """
          Run training epochs until the 'epoch' counter reaches ('solver', 'max_epoch').
          Validates once up front when resuming with ('network', 'resume_val') set.
        """
        cudnn.benchmark = True
        if self.configer.get('network', 'resume') is not None \
                and self.configer.get('network', 'resume_val'):
            self.__val()

        while self.configer.get('epoch') < self.configer.get('solver', 'max_epoch'):
            self.__train()
class FCNSegmentorTest(object):
    """
      Inference runner for an FCN semantic segmentation network. Include test & debug.
    """
    def __init__(self, configer):
        self.configer = configer
        self.blob_helper = BlobHelper(configer)
        self.seg_visualizer = SegVisualizer(configer)
        self.seg_parser = SegParser(configer)
        self.seg_model_manager = SegModelManager(configer)
        self.seg_data_loader = SegDataLoader(configer)
        self.module_utilizer = ModuleUtilizer(configer)
        self.data_transformer = DataTransformer(configer)
        # CUDA is used unless the 'gpu' config entry is None.
        self.device = torch.device('cpu' if self.configer.get('gpu') is None else 'cuda')
        self.seg_net = None

        self._init_model()

    def _init_model(self):
        """Build the network, pass it through load_net and switch to eval mode."""
        self.seg_net = self.seg_model_manager.semantic_segmentor()
        self.seg_net = self.module_utilizer.load_net(self.seg_net)
        self.seg_net.eval()

    def _infer(self, image):
        """Predict one network input, tiled or whole.

        Tiling is used only when ('test', 'crop_test') is set and the input is
        strictly larger than ('test', 'crop_size') in both dimensions.
        """
        if self.configer.get('test', 'crop_test'):
            crop_size = self.configer.get('test', 'crop_size')
            if image.size()[3] > crop_size[0] and image.size()[2] > crop_size[1]:
                return self._crop_predict(image, crop_size)

        return self._predict(image)

    def __test_img(self, image_path, label_path, vis_path, raw_path):
        """Segment one image and save the label map, overlay and raw copy.

        Predictions for every scale in ('test', 'scale_search'), plus an
        optional mirrored pass, are resized to the original image size and
        summed before the per-pixel arg-max.

        Args:
            image_path: Image to read.
            label_path: Destination of the palette-mode label image.
            vis_path: Destination of the colorized overlay.
            raw_path: Destination of the copy of the input image.
        """
        Log.info('Image Path: {}'.format(image_path))
        ori_image = ImageHelper.read_image(image_path,
                                           tool=self.configer.get('data', 'image_tool'),
                                           mode=self.configer.get('data', 'input_mode'))

        ori_width, ori_height = ImageHelper.get_size(ori_image)
        input_size = self.configer.get('test', 'input_size')

        total_logits = np.zeros((ori_height, ori_width, self.configer.get('data', 'num_classes')), np.float32)
        for scale in self.configer.get('test', 'scale_search'):
            image = self.blob_helper.make_input(image=ori_image, input_size=input_size, scale=scale)
            results = self._infer(image)
            total_logits += cv2.resize(results, (ori_width, ori_height), interpolation=cv2.INTER_LINEAR)

        if self.configer.get('test', 'mirror'):
            if self.configer.get('data', 'image_tool') == 'cv2':
                image = cv2.flip(ori_image, 1)
            else:
                image = ori_image.transpose(Image.FLIP_LEFT_RIGHT)

            image = self.blob_helper.make_input(image, input_size=input_size, scale=1.0)
            # Flip the prediction back so it lines up with the un-mirrored image.
            results = self._infer(image)[:, ::-1]
            total_logits += cv2.resize(results, (ori_width, ori_height), interpolation=cv2.INTER_LINEAR)

        label_map = np.argmax(total_logits, axis=-1)
        label_img = np.array(label_map, dtype=np.uint8)
        # NOTE(review): assumes ori_image is RGB here - confirm for image_tool == 'cv2'.
        image_bgr = cv2.cvtColor(np.array(ori_image), cv2.COLOR_RGB2BGR)
        image_canvas = self.seg_parser.colorize(label_img, image_canvas=image_bgr)
        ImageHelper.save(image_canvas, save_path=vis_path)
        ImageHelper.save(ori_image, save_path=raw_path)

        if not self.configer.is_empty('data', 'label_list'):
            label_img = self.__relabel(label_img)

        label_img = Image.fromarray(label_img, 'P')
        Log.info('Label Path: {}'.format(label_path))
        ImageHelper.save(label_img, label_path)

    def _crop_predict(self, image, crop_size):
        """Predict a large input as overlapping crops and sum the crop outputs.

        Args:
            image: Tensor indexed as (1, C, H, W) by the code below.
            crop_size: (crop_width, crop_height).

        Returns:
            numpy.ndarray: (H, W, num_outputs) float32 array; overlapping
            regions hold the sum of every crop that covers them.
        """
        height, width = image.size()[2:]
        np_image = image.squeeze(0).permute(1, 2, 0).cpu().numpy()
        height_starts = self._decide_intersection(height, crop_size[1])
        width_starts = self._decide_intersection(width, crop_size[0])
        split_crops = []
        for top in height_starts:
            for left in width_starts:
                image_crop = np_image[top:top + crop_size[1], left:left + crop_size[0]]
                split_crops.append(image_crop[np.newaxis, :])

        split_crops = np.concatenate(split_crops, axis=0)  # (n, crop_image_size, crop_image_size, 3)
        inputs = torch.from_numpy(split_crops).permute(0, 3, 1, 2).to(self.device)
        with torch.no_grad():
            results = self.seg_net.forward(inputs)
            results = results[0].permute(0, 2, 3, 1).cpu().numpy()

        reassemble = np.zeros((np_image.shape[0], np_image.shape[1], results.shape[-1]), np.float32)
        # Crops come back in the same row-major order they were cut in.
        index = 0
        for top in height_starts:
            for left in width_starts:
                reassemble[top:top + crop_size[1], left:left + crop_size[0]] += results[index]
                index += 1

        return reassemble

    def _decide_intersection(self, total_length, crop_length):
        """Compute the start offsets of sliding-window crops along one axis.

        The stride is ``int(crop_length * ('test', 'crop_stride_ratio'))``. If
        the regular grid leaves a tail uncovered, one extra window flush with
        the end of the axis is appended.

        Returns:
            list[int]: Ascending start offsets; never empty.
        """
        # A tiny ratio would truncate the stride to 0 (ZeroDivisionError below).
        stride = max(1, int(crop_length * self.configer.get('test', 'crop_stride_ratio')))

        # Axis not longer than one window: a single crop at the origin
        # (this case used to index an empty list).
        if total_length <= crop_length:
            return [0]

        times = (total_length - crop_length) // stride + 1
        cropped_starting = [stride * i for i in range(times)]

        if total_length - cropped_starting[-1] > crop_length:
            cropped_starting.append(total_length - crop_length)  # must cover the total image

        return cropped_starting

    def _predict(self, inputs):
        """Run one gradient-free forward pass on the whole input.

        Returns:
            numpy.ndarray: First network output with the leading dimension
            squeezed and its first remaining axis moved last.
        """
        with torch.no_grad():
            results = self.seg_net.forward(inputs)
            # squeeze(0) drops only the leading dim; a bare squeeze() would also
            # drop any other size-1 dim and break the 3-axis permute.
            results = results[0].squeeze(0).permute(1, 2, 0).cpu().numpy()

        return results

    def __relabel(self, label_map):
        """Map class indices in a 2-D label map to ('data', 'label_list') ids."""
        height, width = label_map.shape
        label_dst = np.zeros((height, width), dtype=np.uint8)
        for i in range(self.configer.get('data', 'num_classes')):
            label_dst[label_map == i] = self.configer.get('data', 'label_list')[i]

        label_dst = np.array(label_dst, dtype=np.uint8)

        return label_dst

    @staticmethod
    def _output_paths(base_dir, filename):
        """Return (label_path, vis_path, raw_path) for one input file name.

        The stem comes from os.path.splitext so a name without an extension
        keeps its name instead of collapsing to an empty stem.
        """
        stem = os.path.splitext(filename)[0]
        label_path = os.path.join(base_dir, 'label', '{}.png'.format(stem))
        vis_path = os.path.join(base_dir, 'vis', '{}_vis.png'.format(stem))
        raw_path = os.path.join(base_dir, 'raw', filename)
        return label_path, vis_path, raw_path

    def _run_on_file(self, image_path, base_dir, filename):
        """Create the output directories for one file and segment it."""
        label_path, vis_path, raw_path = self._output_paths(base_dir, filename)
        FileHelper.make_dirs(label_path, is_file=True)
        FileHelper.make_dirs(raw_path, is_file=True)
        FileHelper.make_dirs(vis_path, is_file=True)

        self.__test_img(image_path, label_path, vis_path, raw_path)

    def test(self):
        """Segment either the configured 'test_img' or every file in 'test_dir'.

        Exactly one of the two config entries must be set; otherwise an error
        is logged and the process exits with status 1. Results go below
        ``<project_dir>/val/results/seg/<dataset>``.
        """
        base_dir = os.path.join(self.configer.get('project_dir'),
                                'val/results/seg', self.configer.get('dataset'))

        test_img = self.configer.get('test_img')
        test_dir = self.configer.get('test_dir')
        if test_img is None and test_dir is None:
            Log.error('test_img & test_dir not exists.')
            exit(1)

        if test_img is not None and test_dir is not None:
            Log.error('Either test_img or test_dir.')
            exit(1)

        if test_img is not None:
            base_dir = os.path.join(base_dir, 'test_img')
            filename = test_img.rstrip().split('/')[-1]
            self._run_on_file(test_img, base_dir, filename)

        else:
            base_dir = os.path.join(base_dir, 'test_dir', test_dir.rstrip('/').split('/')[-1])
            FileHelper.make_dirs(base_dir)

            for filename in FileHelper.list_dir(test_dir):
                self._run_on_file(os.path.join(test_dir, filename), base_dir, filename)

    def debug(self):
        """Write and display label overlays for the first 20 training samples.

        Overlays are written below ``<project_dir>/vis/results/seg/<dataset>/debug``
        and shown in an OpenCV window; the process exits with status 1 as soon
        as a 21st sample is reached.
        """
        base_dir = os.path.join(self.configer.get('project_dir'),
                                'vis/results/seg', self.configer.get('dataset'), 'debug')

        if not os.path.exists(base_dir):
            os.makedirs(base_dir)

        count = 0
        for i, data_dict in enumerate(self.seg_data_loader.get_trainloader()):
            inputs = data_dict['img']
            targets = data_dict['labelmap']
            for j in range(inputs.size(0)):
                count = count + 1
                if count > 20:
                    exit(1)

                image_bgr = self.blob_helper.tensor2bgr(inputs[j])
                label_map = targets[j].numpy()
                image_canvas = self.seg_parser.colorize(label_map, image_canvas=image_bgr)
                cv2.imwrite(os.path.join(base_dir, '{}_{}_vis.png'.format(i, j)), image_canvas)
                cv2.imshow('main', image_canvas)
                cv2.waitKey()
class FCNSegmentor(object):
    """
      Trainer for an FCN semantic segmentation network. Include train & val.
    """
    def __init__(self, configer):
        self.configer = configer
        # Running meters for timing and loss statistics.
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.train_losses = AverageMeter()
        self.val_losses = AverageMeter()
        self.seg_running_score = SegRunningScore(configer)
        self.seg_visualizer = SegVisualizer(configer)
        self.seg_loss_manager = LossManager(configer)
        self.seg_model_manager = SegModelManager(configer)
        self.seg_data_loader = DataLoader(configer)

        # Populated by _init_model() below.
        self.seg_net = None
        self.train_loader = None
        self.val_loader = None
        self.optimizer = None
        self.scheduler = None
        self.pixel_loss = None
        # Holds 'iters', 'epoch', 'performance' and 'val_loss'.
        # NOTE(review): 'iters'/'epoch' are presumably seeded by RunnerHelper.load_net
        # or Trainer.init - confirm, since train() increments them without a default.
        self.runner_state = dict()

        self._init_model()

    def _init_model(self):
        """
          Build the network, optimizer/scheduler, data loaders and loss.
        """
        self.seg_net = self.seg_model_manager.semantic_segmentor()
        self.seg_net = RunnerHelper.load_net(self, self.seg_net)

        self.optimizer, self.scheduler = Trainer.init(self,
                                                      self._get_parameters())

        self.train_loader = self.seg_data_loader.get_trainloader()
        self.val_loader = self.seg_data_loader.get_valloader()

        self.pixel_loss = self.seg_loss_manager.get_seg_loss()

    def _get_parameters(self):
        """
          Split the network parameters into two optimizer groups.

          The first group holds parameters whose name contains 'backbone',
          the second group holds all the others. Both groups currently get
          ('lr', 'base_lr'); the second one is multiplied by 1.0.
        """
        lr_1 = []
        lr_10 = []
        params_dict = dict(self.seg_net.named_parameters())
        for key, value in params_dict.items():
            if 'backbone' not in key:
                lr_10.append(value)
            else:
                lr_1.append(value)

        base_lr = self.configer.get('lr', 'base_lr')
        return [
            {'params': lr_1, 'lr': base_lr},
            {'params': lr_10, 'lr': base_lr * 1.0},
        ]

    def train(self):
        """
          Train function of every epoch during train phase.
        """
        self.seg_net.train()
        start_time = time.time()

        for data_dict in self.train_loader:
            # Per-iteration trainer hook (called before the batch is processed).
            Trainer.update(self, backbone_list=(0, ))
            inputs = data_dict['img']
            targets = data_dict['labelmap']
            self.data_time.update(time.time() - start_time)

            # Change the data type.
            inputs, targets = RunnerHelper.to_device(self, inputs, targets)

            # Forward pass.
            outputs = self.seg_net(inputs)
            # Compute the loss of the train batch & backward.
            loss = self.pixel_loss(outputs,
                                   targets,
                                   gathered=self.configer.get(
                                       'network', 'gathered'))
            self.train_losses.update(loss.item(), inputs.size(0))
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.runner_state['iters'] += 1

            # Print the log info & reset the states.
            # The counter lives in runner_state; the configer 'iters' entry is
            # never advanced by this class, so the display check must not use it.
            if self.runner_state['iters'] % self.configer.get(
                    'solver', 'display_iter') == 0:
                # data_time.avg uses '.3f' like its neighbours (was the typo '3f').
                Log.info(
                    'Train Epoch: {0}\tTrain Iteration: {1}\t'
                    'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:.3f})\n'
                    'Learning rate = {3}\tLoss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'
                    .format(self.runner_state['epoch'],
                            self.runner_state['iters'],
                            self.configer.get('solver', 'display_iter'),
                            RunnerHelper.get_lr(self.optimizer),
                            batch_time=self.batch_time,
                            data_time=self.data_time,
                            loss=self.train_losses))
                self.batch_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            # Iteration-based schedules stop mid-epoch once max_iters is hit.
            if self.configer.get('lr', 'metric') == 'iters' \
                    and self.runner_state['iters'] == self.configer.get('solver', 'max_iters'):
                break

            # Check to val the current model.
            if self.runner_state['iters'] % self.configer.get(
                    'solver', 'test_interval') == 0:
                self.val()

        self.runner_state['epoch'] += 1

    def val(self, data_loader=None):
        """
          Validation function during the train phase.
          Uses self.val_loader unless another data_loader is passed in.
        """
        self.seg_net.eval()
        start_time = time.time()

        data_loader = self.val_loader if data_loader is None else data_loader
        for data_dict in data_loader:
            inputs = data_dict['img']
            targets = data_dict['labelmap']

            with torch.no_grad():
                # Change the data type.
                inputs, targets = RunnerHelper.to_device(self, inputs, targets)
                # Forward pass.
                outputs = self.seg_net(inputs)
                # Compute the loss of the val batch.
                loss = self.pixel_loss(outputs,
                                       targets,
                                       gathered=self.configer.get(
                                           'network', 'gathered'))
                outputs = RunnerHelper.gather(self, outputs)

            self.val_losses.update(loss.item(), inputs.size(0))
            # Only the last network output is scored.
            self._update_running_score(outputs[-1], data_dict['meta'])

            # Update the vars of the val phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()

        mean_iou = self.seg_running_score.get_mean_iou()
        self.runner_state['performance'] = mean_iou
        self.runner_state['val_loss'] = self.val_losses.avg
        RunnerHelper.save_net(
            self,
            self.seg_net,
            performance=mean_iou,
            val_loss=self.val_losses.avg)

        # Print the log info & reset the states.
        Log.info('Test Time {batch_time.sum:.3f}s, ({batch_time.avg:.3f})\t'
                 'Loss {loss.avg:.8f}\n'.format(batch_time=self.batch_time,
                                                loss=self.val_losses))
        Log.info('Mean IOU: {}\n'.format(mean_iou))
        Log.info('Pixel ACC: {}\n'.format(
            self.seg_running_score.get_pixel_acc()))
        self.batch_time.reset()
        self.val_losses.reset()
        self.seg_running_score.reset()
        self.seg_net.train()

    def _update_running_score(self, pred, metas):
        """
          Score each prediction of a batch against its original-size target.

          pred: 4-D tensor; axis 1 is moved last before resizing.
          metas: per-sample dicts with 'ori_img_size', 'border_size' and 'ori_target'.
        """
        pred = pred.permute(0, 2, 3, 1)
        for i in range(pred.size(0)):
            ori_img_size = metas[i]['ori_img_size']
            border_size = metas[i]['border_size']
            ori_target = metas[i]['ori_target']
            # Keep only the region inside border_size (indexed as [1] rows, [0]
            # columns), then resize it to the original image size.
            total_logits = cv2.resize(
                pred[i, :border_size[1], :border_size[0]].cpu().numpy(),
                tuple(ori_img_size),
                interpolation=cv2.INTER_CUBIC)
            labelmap = np.argmax(total_logits, axis=-1)
            self.seg_running_score.update(labelmap[None], ori_target[None])
# Example #17 (0)
class FCNSegmentor(object):
    """
      Trainer for an FCN semantic segmentation network. Include train & val.
    """
    def __init__(self, configer):
        self.configer = configer
        # Running meters for timing and loss statistics.
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.train_losses = AverageMeter()
        self.val_losses = AverageMeter()
        self.seg_visualizer = SegVisualizer(configer)
        self.seg_loss_manager = SegLossManager(configer)
        self.module_utilizer = ModuleUtilizer(configer)
        self.seg_model_manager = SegModelManager(configer)
        self.seg_data_loader = SegDataLoader(configer)

        # Everything below is populated by init_model().
        self.seg_net = None
        self.train_loader = None
        self.val_loader = None
        self.optimizer = None
        self.lr = None
        self.iters = None
        self.pixel_loss = None

    def init_model(self):
        """
          Build the network, optimizer, data loaders and loss.
          Has to run before train(), which relies on self.iters being a number.
        """
        self.seg_net = self.seg_model_manager.seg_net()
        self.iters = 0
        self.seg_net, _ = self.module_utilizer.load_net(self.seg_net)

        self.optimizer, self.lr = self.module_utilizer.update_optimizer(
            self.seg_net, self.iters)

        # Only the 'cityscape' dataset is wired up here.
        if self.configer.get('dataset') == 'cityscape':
            self.train_loader = self.seg_data_loader.get_trainloader(
                FSCityScapeLoader)
            self.val_loader = self.seg_data_loader.get_valloader(
                FSCityScapeLoader)

        else:
            Log.error('Dataset: {} is not valid!'.format(
                self.configer.get('dataset')))
            exit(1)

        self.pixel_loss = self.seg_loss_manager.get_seg_loss(
            'cross_entropy_loss')

    def __train(self):
        """
          Train function of every epoch during train phase.
        """
        self.seg_net.train()
        start_time = time.time()

        # data_tuple: only the first two entries (inputs, targets) are used.
        for data_tuple in self.train_loader:
            self.data_time.update(time.time() - start_time)
            if len(data_tuple) < 2:
                Log.error('Train Loader Error!')
                # A broken loader is a failure: exit non-zero like the other error paths.
                exit(1)

            # `async` became a reserved word in Python 3.7; PyTorch renamed the
            # keyword to `non_blocking`.
            inputs = data_tuple[0].cuda(non_blocking=True)
            targets = data_tuple[1].cuda(non_blocking=True)

            # Forward pass.
            outputs = self.seg_net(inputs)

            # Compute the loss of the train batch & backward.
            loss = self.pixel_loss(outputs, targets)
            # loss is a 0-dim tensor; .item() replaces the removed loss.data[0] idiom.
            self.train_losses.update(loss.item(), inputs.size(0))
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.iters += 1

            # Print the log info & reset the states.
            if self.iters % self.configer.get('solver', 'display_iter') == 0:
                # data_time.avg uses '.3f' like its neighbours (was the typo '3f').
                Log.info(
                    'Train Iteration: {0}\t'
                    'Time {batch_time.sum:.3f}s / {1}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {1}iters, ({data_time.avg:.3f})\n'
                    'Learning rate = {2}\n'
                    'Loss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'.format(
                        self.iters,
                        self.configer.get('solver', 'display_iter'),
                        self.lr,
                        batch_time=self.batch_time,
                        data_time=self.data_time,
                        loss=self.train_losses))
                self.batch_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            # Check to val the current model.
            if self.val_loader is not None and \
               self.iters % self.configer.get('solver', 'test_interval') == 0:
                self.__val()

            # The optimizer / learning rate are refreshed after every iteration.
            self.optimizer, self.lr = self.module_utilizer.update_optimizer(
                self.seg_net, self.iters)

    def __val(self):
        """
          Validation function during the train phase.
        """
        # Local import keeps this block self-contained; torch.no_grad() replaces
        # Variable(..., volatile=True), which no longer disables autograd.
        import torch

        self.seg_net.eval()
        start_time = time.time()

        with torch.no_grad():
            for data_tuple in self.val_loader:
                inputs = data_tuple[0].cuda(non_blocking=True)
                targets = data_tuple[1].cuda(non_blocking=True)
                # Forward pass.
                outputs = self.seg_net(inputs)
                # Compute the loss of the val batch.
                loss = self.pixel_loss(outputs, targets)

                self.val_losses.update(loss.item(), inputs.size(0))

                # Update the vars of the val phase.
                self.batch_time.update(time.time() - start_time)
                start_time = time.time()

        self.module_utilizer.save_net(self.seg_net, self.iters)
        # Print the log info & reset the states.
        Log.info('Test Time {batch_time.sum:.3f}s, ({batch_time.avg:.3f})\t'
                 'Loss {loss.avg:.8f}\n'.format(batch_time=self.batch_time,
                                                loss=self.val_losses))
        self.batch_time.reset()
        self.val_losses.reset()
        # Hand the network back to the training loop in train mode.
        self.seg_net.train()

    def train(self):
        """
          Run training epochs until self.iters reaches ('solver', 'max_iter').
          The bound is only checked between epochs, so the final epoch may overshoot it.
        """
        cudnn.benchmark = True
        while self.iters < self.configer.get('solver', 'max_iter'):
            self.__train()
# Example #18 (0)
class FCNSegmentorTest(object):
    """Test-phase driver for the FCN segmentation network.

    Loads the network through the model manager / module utilizer, runs it on
    single images or on every image listed under a directory, and writes the
    predictions as colour images or as label images for a submission.
    """

    # RGB colour painted for each predicted class index in the cityscape
    # visualisation; indexed by class index.
    _CITYSCAPE_COLORS = [(128, 64, 128), (244, 35, 232), (70, 70, 70),
                         (102, 102, 156), (190, 153, 153), (153, 153, 153),
                         (250, 170, 30), (220, 220, 0), (107, 142, 35),
                         (152, 251, 152), (70, 130, 180), (220, 20, 60),
                         (255, 0, 0), (0, 0, 142), (0, 0, 70), (0, 60, 100),
                         (0, 80, 100), (0, 0, 230), (119, 11, 32)]

    # Label value written for each predicted class index in a cityscape
    # submission (presumably the official Cityscapes label ids -- confirm).
    _CITYSCAPE_LABEL_IDS = [
        7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31,
        32, 33
    ]

    def __init__(self, configer):
        """Store the configer and build the helper objects.

        The network itself is not created here; call ``init_model`` first.
        """
        self.configer = configer

        self.seg_vis = SegVisualizer(configer)
        self.seg_model_manager = SegModelManager(configer)
        self.module_utilizer = ModuleUtilizer(configer)
        self.seg_net = None

    def init_model(self):
        """Create the network, load it via the module utilizer and set eval mode."""
        self.seg_net = self.seg_model_manager.seg_net()
        self.seg_net, _ = self.module_utilizer.load_net(self.seg_net)
        self.seg_net.eval()

    def forward(self, image_path):
        """Run the network on one image file and return the argmax map.

        The image is opened as RGB, resized to the configured
        ``data.input_size``, converted to a tensor, normalised with mean 128
        and std 256 per channel, and sent to the GPU as a batch of one.

        Returns:
            A numpy array: argmax over axis 1 of the network output for the
            first batch element, squeezed (assumes axis 1 is the class
            dimension -- confirm against the network).
        """
        # Imported locally so this method does not depend on a module-level
        # ``torch`` name being in scope.
        import torch

        image = Image.open(image_path).convert('RGB')
        image = RandomResize(size=self.configer.get('data', 'input_size'),
                             is_base=False)(image)
        image = ToTensor()(image)
        image = Normalize(mean=[128.0, 128.0, 128.0],
                          std=[256.0, 256.0, 256.0])(image)
        # ``Variable(..., volatile=True)`` no longer disables autograd
        # (PyTorch >= 0.4); ``torch.no_grad()`` is the supported replacement.
        with torch.no_grad():
            results = self.seg_net.forward(image.unsqueeze(0).cuda())
        return results.data.cpu().numpy().argmax(axis=1)[0].squeeze()

    def __output_size(self):
        """Return ``(height, width)``: the configured input size divided by the stride."""
        input_size = self.configer.get('data', 'input_size')
        stride = self.configer.get('network', 'stride')
        return input_size[1] // stride, input_size[0] // stride

    @staticmethod
    def __make_parent_dir(path):
        """Create the directory that will contain ``path`` if it is missing."""
        parent = os.path.dirname(path)
        if parent and not os.path.exists(parent):
            os.makedirs(parent)

    def __test_img(self, image_path, save_path):
        """Dispatch one image to the handler of the configured dataset.

        Logs an error and exits with status 1 for an unknown dataset.
        """
        if self.configer.get('dataset') == 'cityscape':
            self.__test_cityscape_img(image_path, save_path)
        elif self.configer.get('dataset') == 'laneline':
            self.__test_laneline_img(image_path, save_path)
        else:
            Log.error('Dataset: {} is not valid.'.format(
                self.configer.get('dataset')))
            exit(1)

    def __test_cityscape_img(self, img_path, save_path):
        """Predict ``img_path`` and save the colour-coded result to ``save_path``."""
        result = self.forward(img_path)
        height, width = self.__output_size()
        # Pixels whose class index is not painted below stay black.
        color_dst = np.zeros((height, width, 3), dtype=np.uint8)
        for i in range(self.configer.get('data', 'num_classes')):
            color_dst[result == i] = self._CITYSCAPE_COLORS[i]

        # Entries found in sub-directories carry a relative directory part.
        self.__make_parent_dir(save_path)
        Image.fromarray(color_dst, 'RGB').save(save_path)

    def __test_laneline_img(self, img_path, save_path):
        """Placeholder: laneline testing is not implemented."""
        pass

    def test(self):
        """Run the test phase on ``test_img`` or on every entry of ``test_dir``.

        Exactly one of the two config values must be set; otherwise an error
        is logged and the process exits with status 1. Results are written
        under ``<project_dir>/val/results/seg/<dataset>/test``.
        """
        base_dir = os.path.join(self.configer.get('project_dir'),
                                'val/results/seg',
                                self.configer.get('dataset'), 'test')
        if not os.path.exists(base_dir):
            os.makedirs(base_dir)

        test_img = self.configer.get('test_img')
        test_dir = self.configer.get('test_dir')
        if test_img is None and test_dir is None:
            Log.error('test_img & test_dir not exists.')
            exit(1)

        if test_img is not None and test_dir is not None:
            Log.error('Either test_img or test_dir.')
            exit(1)

        if test_img is not None:
            filename = test_img.rstrip().split('/')[-1]
            save_path = os.path.join(base_dir, filename)
            self.__test_img(test_img, save_path)

        else:
            for filename in self.__list_dir(test_dir):
                image_path = os.path.join(test_dir, filename)
                save_path = os.path.join(base_dir, filename)
                self.__test_img(image_path, save_path)

    def __create_cityscape_submission(self, test_dir=None, base_dir=None):
        """Predict every entry of ``test_dir`` and save label images in ``base_dir``.

        Each predicted class index is replaced by the matching entry of
        ``_CITYSCAPE_LABEL_IDS``; pixels that match no class index keep 255.
        """
        for filename in self.__list_dir(test_dir):
            image_path = os.path.join(test_dir, filename)
            save_path = os.path.join(base_dir, filename)
            result = self.forward(image_path)
            height, width = self.__output_size()
            label_dst = np.full((height, width), 255, dtype=np.uint8)
            for i in range(self.configer.get('data', 'num_classes')):
                label_dst[result == i] = self._CITYSCAPE_LABEL_IDS[i]

            # Entries found in sub-directories carry a relative directory part.
            self.__make_parent_dir(save_path)
            Image.fromarray(label_dst, 'P').save(save_path)

    def create_submission(self, test_dir=None):
        """Write submission files for the configured dataset.

        Output goes under ``<project_dir>/val/results/seg/<dataset>/submission``.
        Only the 'cityscape' dataset is handled; any other value logs an
        error and exits with status 1.
        """
        base_dir = os.path.join(self.configer.get('project_dir'),
                                'val/results/seg',
                                self.configer.get('dataset'), 'submission')
        if not os.path.exists(base_dir):
            os.makedirs(base_dir)

        if self.configer.get('dataset') == 'cityscape':
            self.__create_cityscape_submission(test_dir, base_dir)

        else:
            Log.error('Dataset: {} is not valid.'.format(
                self.configer.get('dataset')))
            exit(1)

    def __list_dir(self, dir_name):
        """List the entries of ``dir_name``, expanding one level of sub-directories.

        Plain entries are returned by name; entries inside a direct
        sub-directory are returned as ``'<subdir>/<name>'``. Deeper levels
        are not expanded.
        """
        filename_list = list()
        for item in os.listdir(dir_name):
            # ``os.listdir`` yields bare names, so the directory check has to
            # be made relative to ``dir_name``, not the working directory.
            item_path = os.path.join(dir_name, item)
            if os.path.isdir(item_path):
                for filename in os.listdir(item_path):
                    filename_list.append('{}/{}'.format(item, filename))

            else:
                filename_list.append(item)

        return filename_list