Example #1
class FCClassifierTest(object):
    def __init__(self, configer):
        self.configer = configer
        self.blob_helper = BlobHelper(configer)
        self.cls_model_manager = ClsModelManager(configer)
        self.cls_data_loader = ClsDataLoader(configer)
        self.module_utilizer = ModuleUtilizer(configer)
        self.cls_parser = ClsParser(configer)
        self.device = torch.device(
            'cpu' if self.configer.get('gpu') is None else 'cuda')
        self.cls_net = None
        if self.configer.get('dataset') == 'imagenet':
            with open(
                    os.path.join(
                        self.configer.get('project_dir'),
                        'datasets/cls/imagenet/imagenet_class_index.json')
            ) as json_stream:
                name_dict = json.load(json_stream)
                name_seq = [
                    name_dict[str(i)][1]
                    for i in range(self.configer.get('data', 'num_classes'))
                ]
                self.configer.add_key_value(['details', 'name_seq'], name_seq)

        self._init_model()

    def _init_model(self):
        self.cls_net = self.cls_model_manager.image_classifier()
        self.cls_net = self.module_utilizer.load_net(self.cls_net)
        self.cls_net.eval()

    def __test_img(self, image_path, json_path, raw_path, vis_path):
        Log.info('Image Path: {}'.format(image_path))
        img = ImageHelper.read_image(
            image_path,
            tool=self.configer.get('data', 'image_tool'),
            mode=self.configer.get('data', 'input_mode'))

        trans = None
        if self.configer.get('dataset') == 'imagenet':
            if self.configer.get('data', 'image_tool') == 'cv2':
                img = Image.fromarray(img)

            trans = transforms.Compose([
                transforms.Resize(256),  # torchvision renamed transforms.Scale to Resize
                transforms.CenterCrop(224),
            ])

        assert trans is not None
        img = trans(img)

        ori_img_bgr = ImageHelper.get_cv2_bgr(img,
                                              mode=self.configer.get(
                                                  'data', 'input_mode'))

        inputs = self.blob_helper.make_input(img,
                                             input_size=self.configer.get(
                                                 'test', 'input_size'),
                                             scale=1.0)

        with torch.no_grad():
            outputs = self.cls_net(inputs)

        json_dict = self.__get_info_tree(outputs, image_path)

        image_canvas = self.cls_parser.draw_label(ori_img_bgr.copy(),
                                                  json_dict['label'])
        cv2.imwrite(vis_path, image_canvas)
        cv2.imwrite(raw_path, ori_img_bgr)

        Log.info('Json Path: {}'.format(json_path))
        JsonHelper.save_file(json_dict, json_path)
        return json_dict

    def __get_info_tree(self, outputs, image_path=None):
        json_dict = dict()
        if image_path is not None:
            json_dict['image_path'] = image_path

        topk = (1, 3, 5)
        maxk = max(topk)

        _, pred = outputs.topk(maxk, 1, True, True)
        pred = pred.t()
        for k in topk:
            if k == 1:
                json_dict['label'] = pred[0][0]

            else:
                # pred is (maxk, batch) after .t(); take the top-k classes of sample 0.
                json_dict['label_top{}'.format(k)] = pred[:k, 0]

        return json_dict

    def test(self):
        base_dir = os.path.join(self.configer.get('project_dir'),
                                'val/results/cls',
                                self.configer.get('dataset'))

        test_img = self.configer.get('test_img')
        test_dir = self.configer.get('test_dir')
        if test_img is None and test_dir is None:
            Log.error('Neither test_img nor test_dir is specified.')
            exit(1)

        if test_img is not None and test_dir is not None:
            Log.error('Specify either test_img or test_dir, not both.')
            exit(1)

        if test_img is not None:
            base_dir = os.path.join(base_dir, 'test_img')
            filename = test_img.rstrip().split('/')[-1]
            json_path = os.path.join(
                base_dir, 'json',
                '{}.json'.format('.'.join(filename.split('.')[:-1])))
            raw_path = os.path.join(base_dir, 'raw', filename)
            vis_path = os.path.join(
                base_dir, 'vis',
                '{}_vis.png'.format('.'.join(filename.split('.')[:-1])))
            FileHelper.make_dirs(json_path, is_file=True)
            FileHelper.make_dirs(raw_path, is_file=True)
            FileHelper.make_dirs(vis_path, is_file=True)

            self.__test_img(test_img, json_path, raw_path, vis_path)

        else:
            base_dir = os.path.join(base_dir, 'test_dir',
                                    test_dir.rstrip('/').split('/')[-1])
            FileHelper.make_dirs(base_dir)

            for filename in FileHelper.list_dir(test_dir):
                image_path = os.path.join(test_dir, filename)
                json_path = os.path.join(
                    base_dir, 'json',
                    '{}.json'.format('.'.join(filename.split('.')[:-1])))
                raw_path = os.path.join(base_dir, 'raw', filename)
                vis_path = os.path.join(
                    base_dir, 'vis',
                    '{}_vis.png'.format('.'.join(filename.split('.')[:-1])))
                FileHelper.make_dirs(json_path, is_file=True)
                FileHelper.make_dirs(raw_path, is_file=True)
                FileHelper.make_dirs(vis_path, is_file=True)

                self.__test_img(image_path, json_path, raw_path, vis_path)

    def debug(self):
        base_dir = os.path.join(self.configer.get('project_dir'),
                                'vis/results/cls',
                                self.configer.get('dataset'), 'debug')

        if not os.path.exists(base_dir):
            os.makedirs(base_dir)

        count = 0
        for i, data_dict in enumerate(self.cls_data_loader.get_trainloader()):
            inputs = data_dict['img']
            labels = data_dict['label']
            # Convert integer labels into one-hot targets for visualization.
            eye_matrix = torch.eye(self.configer.get('data', 'num_classes'))
            labels_target = eye_matrix[labels.view(-1)].view(
                inputs.size(0), self.configer.get('data', 'num_classes'))

            for j in range(inputs.size(0)):
                count = count + 1
                # Only inspect the first 20 samples in debug mode.
                if count > 20:
                    exit(1)

                ori_img_bgr = self.blob_helper.tensor2bgr(inputs[j])

                # Pass only the j-th one-hot target, keeping a batch dim of 1
                # since __get_info_tree runs topk along dim 1.
                json_dict = self.__get_info_tree(labels_target[j:j + 1])
                image_canvas = self.cls_parser.draw_label(
                    ori_img_bgr.copy(), json_dict['label'])

                cv2.imwrite(
                    os.path.join(base_dir, '{}_{}_vis.png'.format(i, j)),
                    image_canvas)
                cv2.imshow('main', image_canvas)
                cv2.waitKey()
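
A minimal usage sketch for this test runner. The Configer constructor, the hypes file path, and the import location below are assumptions for illustration, not the project's confirmed API; the real entry script wires these up.

# Hypothetical wiring; adjust import paths and config keys to the actual project.
from utils.helpers.configer import Configer  # assumed location of Configer

configer = Configer(hypes_file='hypes/cls/imagenet/fc_cls.json')  # assumed config file
configer.add_key_value(['gpu'], [0])                      # or None to force CPU
configer.add_key_value(['test_img'], 'samples/cat.jpg')   # single-image mode; set test_dir for a folder

tester = FCClassifierTest(configer)
tester.test()  # writes json/, raw/ and vis/ outputs under <project_dir>/val/results/cls/<dataset>/test_img
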
Example #2
class FCClassifier(object):
    """
      The class for the training phase of image classification.
    """
    def __init__(self, configer):
        self.configer = configer
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.train_losses = AverageMeter()
        self.val_losses = AverageMeter()
        self.cls_loss_manager = ClsLossManager(configer)
        self.cls_model_manager = ClsModelManager(configer)
        self.cls_data_loader = ClsDataLoader(configer)
        self.module_utilizer = ModuleUtilizer(configer)
        self.optim_scheduler = OptimScheduler(configer)
        self.cls_running_score = ClsRunningScore(configer)

        self.cls_net = None
        self.train_loader = None
        self.val_loader = None
        self.optimizer = None
        self.scheduler = None

        self._init_model()

    def _init_model(self):
        self.cls_net = self.cls_model_manager.image_classifier()
        self.cls_net = self.module_utilizer.load_net(self.cls_net)
        self.optimizer, self.scheduler = self.optim_scheduler.init_optimizer(
            self._get_parameters())

        self.train_loader = self.cls_data_loader.get_trainloader()
        self.val_loader = self.cls_data_loader.get_valloader()

        self.ce_loss = self.cls_loss_manager.get_cls_loss('cross_entropy_loss')

    def _get_parameters(self):

        return self.cls_net.parameters()

    def __train(self):
        """
          Train function for each epoch during the train phase.
        """
        self.cls_net.train()
        start_time = time.time()
        # Adjust the learning rate after every epoch.
        self.configer.plus_one('epoch')
        self.scheduler.step(self.configer.get('epoch'))

        for i, data_dict in enumerate(self.train_loader):
            inputs = data_dict['img']
            labels = data_dict['label']
            self.data_time.update(time.time() - start_time)
            # Change the data type.
            inputs, labels = self.module_utilizer.to_device(inputs, labels)
            # Forward pass.
            outputs = self.cls_net(inputs)
            outputs = self.module_utilizer.gather(outputs)
            # Compute the loss of the train batch & backward.

            loss = self.ce_loss(outputs, labels)

            self.train_losses.update(loss.item(), inputs.size(0))
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Update the vars of the train phase.
            self.batch_time.update(time.time() - start_time)
            start_time = time.time()
            self.configer.plus_one('iters')

            # Print the log info & reset the states.
            if self.configer.get('iters') % self.configer.get(
                    'solver', 'display_iter') == 0:
                Log.info(
                    'Train Epoch: {0}\tTrain Iteration: {1}\t'
                    'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:.3f})\n'
                    'Learning rate = {3}\tLoss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'
                    .format(self.configer.get('epoch'),
                            self.configer.get('iters'),
                            self.configer.get('solver', 'display_iter'),
                            self.scheduler.get_lr(),
                            batch_time=self.batch_time,
                            data_time=self.data_time,
                            loss=self.train_losses))

                self.batch_time.reset()
                self.data_time.reset()
                self.train_losses.reset()

            # Check to val the current model.
            if self.val_loader is not None and \
               self.configer.get('iters') % self.configer.get('solver', 'test_interval') == 0:
                self.__val()

    def __val(self):
        """
          Validation function during the train phase.
        """
        self.cls_net.eval()
        start_time = time.time()

        with torch.no_grad():
            for j, data_dict in enumerate(self.val_loader):
                inputs = data_dict['img']
                labels = data_dict['label']
                # Change the data type.
                inputs, labels = self.module_utilizer.to_device(inputs, labels)
                # Forward pass.
                outputs = self.cls_net(inputs)
                outputs = self.module_utilizer.gather(outputs)
                # Compute the loss of the val batch.
                loss = self.ce_loss(outputs, labels)
                self.cls_running_score.update(outputs, labels)
                self.val_losses.update(loss.item(), inputs.size(0))

                # Update the vars of the val phase.
                self.batch_time.update(time.time() - start_time)
                start_time = time.time()

            self.module_utilizer.save_net(self.cls_net, save_mode='iters')

            # Print the log info & reset the states.
            Log.info('Test Time {batch_time.sum:.3f}s'.format(
                batch_time=self.batch_time))
            Log.info('TestLoss = {loss.avg:.8f}'.format(loss=self.val_losses))
            Log.info('Top1 ACC = {}'.format(
                self.cls_running_score.get_top1_acc()))
            Log.info('Top5 ACC = {}'.format(
                self.cls_running_score.get_top5_acc()))
            self.batch_time.reset()
            self.val_losses.reset()
            self.cls_running_score.reset()
            self.cls_net.train()

    def train(self):
        cudnn.benchmark = True
        if self.configer.get('network',
                             'resume') is not None and self.configer.get(
                                 'network', 'resume_val'):
            self.__val()

        while self.configer.get('epoch') < self.configer.get(
                'solver', 'max_epoch'):
            self.__train()
            if self.configer.get('epoch') == self.configer.get(
                    'solver', 'max_epoch'):
                break
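
A sketch of how the trainer might be driven, under the same assumptions about Configer as in the sketch after Example #1.

# Hypothetical driver; the project's own main script normally does this.
from utils.helpers.configer import Configer  # assumed location

configer = Configer(hypes_file='hypes/cls/imagenet/fc_cls.json')  # assumed config file
trainer = FCClassifier(configer)
trainer.train()  # alternates __train() and __val() until 'solver.max_epoch' epochs are reached
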
Example #3
class FCClassifierTest(object):
    def __init__(self, configer):
        self.configer = configer
        self.blob_helper = BlobHelper(configer)
        self.cls_model_manager = ClsModelManager(configer)
        self.cls_data_loader = DataLoader(configer)
        self.cls_parser = ClsParser(configer)
        self.device = torch.device(
            'cpu' if self.configer.get('gpu') is None else 'cuda')
        self.cls_net = None
        if self.configer.get('dataset') == 'imagenet':
            with open(
                    os.path.join(
                        self.configer.get('project_dir'),
                        'datasets/cls/imagenet/imagenet_class_index.json')
            ) as json_stream:
                name_dict = json.load(json_stream)
                name_seq = [
                    name_dict[str(i)][1]
                    for i in range(self.configer.get('data', 'num_classes'))
                ]
                self.configer.add(['details', 'name_seq'], name_seq)

        self._init_model()

    def _init_model(self):
        self.cls_net = self.cls_model_manager.image_classifier()
        self.cls_net = RunnerHelper.load_net(self, self.cls_net)
        self.cls_net.eval()

    def __test_img(self, image_path, json_path, raw_path, vis_path):
        Log.info('Image Path: {}'.format(image_path))
        img = ImageHelper.read_image(
            image_path,
            tool=self.configer.get('data', 'image_tool'),
            mode=self.configer.get('data', 'input_mode'))

        trans = None
        if self.configer.get('dataset') == 'imagenet':
            if self.configer.get('data', 'image_tool') == 'cv2':
                img = Image.fromarray(img)

            trans = transforms.Compose([
                transforms.Resize(256),  # torchvision renamed transforms.Scale to Resize
                transforms.CenterCrop(224),
            ])

        assert trans is not None
        img = trans(img)

        ori_img_bgr = ImageHelper.get_cv2_bgr(img,
                                              mode=self.configer.get(
                                                  'data', 'input_mode'))

        inputs = self.blob_helper.make_input(img,
                                             input_size=self.configer.get(
                                                 'test', 'input_size'),
                                             scale=1.0)

        with torch.no_grad():
            outputs = self.cls_net(inputs)

        # __get_info_tree here expects 1-D logits (topk runs along dim 0),
        # so pass the single sample's output.
        json_dict = self.__get_info_tree(outputs[0], image_path)

        image_canvas = self.cls_parser.draw_label(ori_img_bgr.copy(),
                                                  json_dict['label'])
        cv2.imwrite(vis_path, image_canvas)
        cv2.imwrite(raw_path, ori_img_bgr)

        Log.info('Json Path: {}'.format(json_path))
        JsonHelper.save_file(json_dict, json_path)
        return json_dict

    def __get_info_tree(self, outputs, image_path=None):
        json_dict = dict()
        if image_path is not None:
            json_dict['image_path'] = image_path

        topk = (1, 3, 5)
        maxk = max(topk)

        _, pred = outputs.topk(maxk, 0, True, True)
        for k in topk:
            if k == 1:
                json_dict['label'] = pred[0]

            else:
                json_dict['label_top{}'.format(k)] = pred[:k]

        return json_dict

    def debug(self, vis_dir):
        count = 0
        for i, data_dict in enumerate(self.cls_data_loader.get_trainloader()):
            inputs = data_dict['img']
            labels = data_dict['label']
            eye_matrix = torch.eye(self.configer.get('data', 'num_classes'))
            labels_target = eye_matrix[labels.view(-1)].view(
                inputs.size(0), self.configer.get('data', 'num_classes'))

            for j in range(inputs.size(0)):
                count = count + 1
                if count > 20:
                    exit(1)

                ori_img_bgr = self.blob_helper.tensor2bgr(inputs[j])

                json_dict = self.__get_info_tree(labels_target[j])
                image_canvas = self.cls_parser.draw_label(
                    ori_img_bgr.copy(), json_dict['label'])

                cv2.imwrite(
                    os.path.join(vis_dir, '{}_{}_vis.png'.format(i, j)),
                    image_canvas)
                cv2.imshow('main', image_canvas)
                cv2.waitKey()
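
A sketch of this variant's debug entry point, which, unlike Example #1, takes the output directory as an argument. The directory path below is an assumption; the directory must exist before calling, and the loop displays each canvas with cv2.imshow and stops the process via exit(1) after 20 samples, as written above.

import os

vis_dir = 'vis/results/cls/imagenet/debug'   # assumed output directory
os.makedirs(vis_dir, exist_ok=True)
tester = FCClassifierTest(configer)          # configer built as in the sketch after Example #1
tester.debug(vis_dir)                        # writes {i}_{j}_vis.png per sample and displays it
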