class FCClassifierTest(object): def __init__(self, configer): self.configer = configer self.blob_helper = BlobHelper(configer) self.cls_model_manager = ClsModelManager(configer) self.cls_data_loader = ClsDataLoader(configer) self.module_utilizer = ModuleUtilizer(configer) self.cls_parser = ClsParser(configer) self.device = torch.device( 'cpu' if self.configer.get('gpu') is None else 'cuda') self.cls_net = None if self.configer.get('dataset') == 'imagenet': with open( os.path.join( self.configer.get('project_dir'), 'datasets/cls/imagenet/imagenet_class_index.json') ) as json_stream: name_dict = json.load(json_stream) name_seq = [ name_dict[str(i)][1] for i in range(self.configer.get('data', 'num_classes')) ] self.configer.add_key_value(['details', 'name_seq'], name_seq) self._init_model() def _init_model(self): self.cls_net = self.cls_model_manager.image_classifier() self.cls_net = self.module_utilizer.load_net(self.cls_net) self.cls_net.eval() def __test_img(self, image_path, json_path, raw_path, vis_path): Log.info('Image Path: {}'.format(image_path)) img = ImageHelper.read_image( image_path, tool=self.configer.get('data', 'image_tool'), mode=self.configer.get('data', 'input_mode')) trans = None if self.configer.get('dataset') == 'imagenet': if self.configer.get('data', 'image_tool') == 'cv2': img = Image.fromarray(img) trans = transforms.Compose([ transforms.Scale(256), transforms.CenterCrop(224), ]) assert trans is not None img = trans(img) ori_img_bgr = ImageHelper.get_cv2_bgr(img, mode=self.configer.get( 'data', 'input_mode')) inputs = self.blob_helper.make_input(img, input_size=self.configer.get( 'test', 'input_size'), scale=1.0) with torch.no_grad(): outputs = self.cls_net(inputs) json_dict = self.__get_info_tree(outputs, image_path) image_canvas = self.cls_parser.draw_label(ori_img_bgr.copy(), json_dict['label']) cv2.imwrite(vis_path, image_canvas) cv2.imwrite(raw_path, ori_img_bgr) Log.info('Json Path: {}'.format(json_path)) JsonHelper.save_file(json_dict, json_path) return json_dict def __get_info_tree(self, outputs, image_path=None): json_dict = dict() if image_path is not None: json_dict['image_path'] = image_path topk = (1, 3, 5) maxk = max(topk) _, pred = outputs.topk(maxk, 1, True, True) pred = pred.t() for k in topk: if k == 1: json_dict['label'] = pred[0][0] else: json_dict['label_top{}'.format(k)] = pred[0][:k] return json_dict def test(self): base_dir = os.path.join(self.configer.get('project_dir'), 'val/results/cls', self.configer.get('dataset')) test_img = self.configer.get('test_img') test_dir = self.configer.get('test_dir') if test_img is None and test_dir is None: Log.error('test_img & test_dir not exists.') exit(1) if test_img is not None and test_dir is not None: Log.error('Either test_img or test_dir.') exit(1) if test_img is not None: base_dir = os.path.join(base_dir, 'test_img') filename = test_img.rstrip().split('/')[-1] json_path = os.path.join( base_dir, 'json', '{}.json'.format('.'.join(filename.split('.')[:-1]))) raw_path = os.path.join(base_dir, 'raw', filename) vis_path = os.path.join( base_dir, 'vis', '{}_vis.png'.format('.'.join(filename.split('.')[:-1]))) FileHelper.make_dirs(json_path, is_file=True) FileHelper.make_dirs(raw_path, is_file=True) FileHelper.make_dirs(vis_path, is_file=True) self.__test_img(test_img, json_path, raw_path, vis_path) else: base_dir = os.path.join(base_dir, 'test_dir', test_dir.rstrip('/').split('/')[-1]) FileHelper.make_dirs(base_dir) for filename in FileHelper.list_dir(test_dir): image_path = os.path.join(test_dir, filename) json_path = os.path.join( base_dir, 'json', '{}.json'.format('.'.join(filename.split('.')[:-1]))) raw_path = os.path.join(base_dir, 'raw', filename) vis_path = os.path.join( base_dir, 'vis', '{}_vis.png'.format('.'.join(filename.split('.')[:-1]))) FileHelper.make_dirs(json_path, is_file=True) FileHelper.make_dirs(raw_path, is_file=True) FileHelper.make_dirs(vis_path, is_file=True) self.__test_img(image_path, json_path, raw_path, vis_path) def debug(self): base_dir = os.path.join(self.configer.get('project_dir'), 'vis/results/cls', self.configer.get('dataset'), 'debug') if not os.path.exists(base_dir): os.makedirs(base_dir) count = 0 for i, data_dict in enumerate(self.cls_data_loader.get_trainloader()): inputs = data_dict['img'] labels = data_dict['label'] eye_matrix = torch.eye(self.configer.get('data', 'num_classes')) labels_target = eye_matrix[labels.view(-1)].view( inputs.size(0), self.configer.get('data', 'num_classes')) for j in range(inputs.size(0)): count = count + 1 if count > 20: exit(1) ori_img_bgr = self.blob_helper.tensor2bgr(inputs[j]) json_dict = self.__get_info_tree(labels_target) image_canvas = self.cls_parser.draw_label( ori_img_bgr.copy(), json_dict['label']) cv2.imwrite( os.path.join(base_dir, '{}_{}_vis.png'.format(i, j)), image_canvas) cv2.imshow('main', image_canvas) cv2.waitKey()
class FCClassifier(object): """ The class for the training phase of Image classification. """ def __init__(self, configer): self.configer = configer self.batch_time = AverageMeter() self.data_time = AverageMeter() self.train_losses = AverageMeter() self.val_losses = AverageMeter() self.cls_loss_manager = ClsLossManager(configer) self.cls_model_manager = ClsModelManager(configer) self.cls_data_loader = ClsDataLoader(configer) self.module_utilizer = ModuleUtilizer(configer) self.optim_scheduler = OptimScheduler(configer) self.cls_running_score = ClsRunningScore(configer) self.cls_net = None self.train_loader = None self.val_loader = None self.optimizer = None self.scheduler = None self._init_model() def _init_model(self): self.cls_net = self.cls_model_manager.image_classifier() self.cls_net = self.module_utilizer.load_net(self.cls_net) self.optimizer, self.scheduler = self.optim_scheduler.init_optimizer( self._get_parameters()) self.train_loader = self.cls_data_loader.get_trainloader() self.val_loader = self.cls_data_loader.get_valloader() self.ce_loss = self.cls_loss_manager.get_cls_loss('cross_entropy_loss') def _get_parameters(self): return self.cls_net.parameters() def __train(self): """ Train function of every epoch during train phase. """ self.cls_net.train() start_time = time.time() # Adjust the learning rate after every epoch. self.configer.plus_one('epoch') self.scheduler.step(self.configer.get('epoch')) for i, data_dict in enumerate(self.train_loader): inputs = data_dict['img'] labels = data_dict['label'] self.data_time.update(time.time() - start_time) # Change the data type. inputs, labels = self.module_utilizer.to_device(inputs, labels) # Forward pass. outputs = self.cls_net(inputs) outputs = self.module_utilizer.gather(outputs) # Compute the loss of the train batch & backward. loss = self.ce_loss(outputs, labels) self.train_losses.update(loss.item(), inputs.size(0)) self.optimizer.zero_grad() loss.backward() self.optimizer.step() # Update the vars of the train phase. self.batch_time.update(time.time() - start_time) start_time = time.time() self.configer.plus_one('iters') # Print the log info & reset the states. if self.configer.get('iters') % self.configer.get( 'solver', 'display_iter') == 0: Log.info( 'Train Epoch: {0}\tTrain Iteration: {1}\t' 'Time {batch_time.sum:.3f}s / {2}iters, ({batch_time.avg:.3f})\t' 'Data load {data_time.sum:.3f}s / {2}iters, ({data_time.avg:3f})\n' 'Learning rate = {3}\tLoss = {loss.val:.8f} (ave = {loss.avg:.8f})\n' .format(self.configer.get('epoch'), self.configer.get('iters'), self.configer.get('solver', 'display_iter'), self.scheduler.get_lr(), batch_time=self.batch_time, data_time=self.data_time, loss=self.train_losses)) self.batch_time.reset() self.data_time.reset() self.train_losses.reset() # Check to val the current model. if self.val_loader is not None and \ self.configer.get('iters') % self.configer.get('solver', 'test_interval') == 0: self.__val() def __val(self): """ Validation function during the train phase. """ self.cls_net.eval() start_time = time.time() with torch.no_grad(): for j, data_dict in enumerate(self.val_loader): inputs = data_dict['img'] labels = data_dict['label'] # Change the data type. inputs, labels = self.module_utilizer.to_device(inputs, labels) # Forward pass. outputs = self.cls_net(inputs) outputs = self.module_utilizer.gather(outputs) # Compute the loss of the val batch. loss = self.ce_loss(outputs, labels) self.cls_running_score.update(outputs, labels) self.val_losses.update(loss.item(), inputs.size(0)) # Update the vars of the val phase. self.batch_time.update(time.time() - start_time) start_time = time.time() self.module_utilizer.save_net(self.cls_net, save_mode='iters') # Print the log info & reset the states. Log.info('Test Time {batch_time.sum:.3f}s'.format( batch_time=self.batch_time)) Log.info('TestLoss = {loss.avg:.8f}'.format(loss=self.val_losses)) Log.info('Top1 ACC = {}'.format( self.cls_running_score.get_top1_acc())) Log.info('Top5 ACC = {}'.format( self.cls_running_score.get_top5_acc())) self.batch_time.reset() self.val_losses.reset() self.cls_running_score.reset() self.cls_net.train() def train(self): cudnn.benchmark = True if self.configer.get('network', 'resume') is not None and self.configer.get( 'network', 'resume_val'): self.__val() while self.configer.get('epoch') < self.configer.get( 'solver', 'max_epoch'): self.__train() if self.configer.get('epoch') == self.configer.get( 'solver', 'max_epoch'): break
class FCClassifierTest(object): def __init__(self, configer): self.configer = configer self.blob_helper = BlobHelper(configer) self.cls_model_manager = ClsModelManager(configer) self.cls_data_loader = DataLoader(configer) self.cls_parser = ClsParser(configer) self.device = torch.device( 'cpu' if self.configer.get('gpu') is None else 'cuda') self.cls_net = None if self.configer.get('dataset') == 'imagenet': with open( os.path.join( self.configer.get('project_dir'), 'datasets/cls/imagenet/imagenet_class_index.json') ) as json_stream: name_dict = json.load(json_stream) name_seq = [ name_dict[str(i)][1] for i in range(self.configer.get('data', 'num_classes')) ] self.configer.add(['details', 'name_seq'], name_seq) self._init_model() def _init_model(self): self.cls_net = self.cls_model_manager.image_classifier() self.cls_net = RunnerHelper.load_net(self, self.cls_net) self.cls_net.eval() def __test_img(self, image_path, json_path, raw_path, vis_path): Log.info('Image Path: {}'.format(image_path)) img = ImageHelper.read_image( image_path, tool=self.configer.get('data', 'image_tool'), mode=self.configer.get('data', 'input_mode')) trans = None if self.configer.get('dataset') == 'imagenet': if self.configer.get('data', 'image_tool') == 'cv2': img = Image.fromarray(img) trans = transforms.Compose([ transforms.Scale(256), transforms.CenterCrop(224), ]) assert trans is not None img = trans(img) ori_img_bgr = ImageHelper.get_cv2_bgr(img, mode=self.configer.get( 'data', 'input_mode')) inputs = self.blob_helper.make_input(img, input_size=self.configer.get( 'test', 'input_size'), scale=1.0) with torch.no_grad(): outputs = self.cls_net(inputs) json_dict = self.__get_info_tree(outputs, image_path) image_canvas = self.cls_parser.draw_label(ori_img_bgr.copy(), json_dict['label']) cv2.imwrite(vis_path, image_canvas) cv2.imwrite(raw_path, ori_img_bgr) Log.info('Json Path: {}'.format(json_path)) JsonHelper.save_file(json_dict, json_path) return json_dict def __get_info_tree(self, outputs, image_path=None): json_dict = dict() if image_path is not None: json_dict['image_path'] = image_path topk = (1, 3, 5) maxk = max(topk) _, pred = outputs.topk(maxk, 0, True, True) for k in topk: if k == 1: json_dict['label'] = pred[0] else: json_dict['label_top{}'.format(k)] = pred[:k] return json_dict def debug(self, vis_dir): count = 0 for i, data_dict in enumerate(self.cls_data_loader.get_trainloader()): inputs = data_dict['img'] labels = data_dict['label'] eye_matrix = torch.eye(self.configer.get('data', 'num_classes')) labels_target = eye_matrix[labels.view(-1)].view( inputs.size(0), self.configer.get('data', 'num_classes')) for j in range(inputs.size(0)): count = count + 1 if count > 20: exit(1) ori_img_bgr = self.blob_helper.tensor2bgr(inputs[j]) json_dict = self.__get_info_tree(labels_target[j]) image_canvas = self.cls_parser.draw_label( ori_img_bgr.copy(), json_dict['label']) cv2.imwrite( os.path.join(vis_dir, '{}_{}_vis.png'.format(i, j)), image_canvas) cv2.imshow('main', image_canvas) cv2.waitKey()