class Trainer(object):
    """Iteration-based trainer for a single-output segmentation model.

    Construction builds the data loaders, the network, the loss, the SGD
    optimizer and the warm-up/poly LR schedule from ``args``. ``train`` runs
    the optimisation loop and ``validation`` scores the model on the
    validation loader.
    """

    def __init__(self, args):
        """Set up everything needed for training.

        Args:
            args: options namespace. ``args.iters_per_epoch`` and
                ``args.max_iters`` are computed here and written back to it.

        Raises:
            ValueError: if ``args.resume`` points to an existing file whose
                extension is neither ``.pkl`` nor ``.pth``.
        """
        self.args = args
        self.device = torch.device(args.device)

        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])

        # dataset and dataloader
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        trainset = get_segmentation_dataset(args.dataset,
                                            split='train',
                                            mode='train',
                                            **data_kwargs)
        args.iters_per_epoch = len(trainset) // (args.num_gpus *
                                                 args.batch_size)
        args.max_iters = args.epochs * args.iters_per_epoch

        train_sampler = make_data_sampler(trainset,
                                          shuffle=True,
                                          distributed=args.distributed)
        train_batch_sampler = make_batch_data_sampler(train_sampler,
                                                      args.batch_size,
                                                      args.max_iters)
        self.train_loader = data.DataLoader(dataset=trainset,
                                            batch_sampler=train_batch_sampler,
                                            num_workers=args.workers,
                                            pin_memory=True)

        if not args.skip_val:
            valset = get_segmentation_dataset(args.dataset,
                                              split='val',
                                              mode='val',
                                              **data_kwargs)
            val_sampler = make_data_sampler(valset, False, args.distributed)
            val_batch_sampler = make_batch_data_sampler(
                val_sampler, args.batch_size)
            self.val_loader = data.DataLoader(dataset=valset,
                                              batch_sampler=val_batch_sampler,
                                              num_workers=args.workers,
                                              pin_memory=True)

        # create network
        BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d
        self.model = get_segmentation_model(args.model,
                                            dataset=args.dataset,
                                            aux=args.aux,
                                            norm_layer=BatchNorm2d)
        # Move the parameters to the target device BEFORE wrapping: DDP with
        # ``device_ids`` expects the module to already live on that device.
        self.model = self.model.to(self.device)
        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank)

        # resume checkpoint if needed
        if args.resume:
            if os.path.isfile(args.resume):
                _, ext = os.path.splitext(args.resume)
                # The previous check (`ext == '.pkl' or '.pth'`) was always
                # true because a non-empty string is truthy.
                if ext not in ('.pkl', '.pth'):
                    raise ValueError(
                        'Sorry only .pth and .pkl files supported.')
                print('Resuming training, loading {}...'.format(args.resume))
                # NOTE(review): in distributed mode the weights are loaded
                # into the DDP wrapper, so the checkpoint keys presumably
                # need the wrapper's prefix -- confirm against
                # save_checkpoint.
                self.model.load_state_dict(
                    torch.load(args.resume,
                               map_location=lambda storage, loc: storage))

        # create criterion
        if args.ohem:
            min_kept = int(args.batch_size // args.num_gpus *
                           args.crop_size**2 // 16)
            self.criterion = MixSoftmaxCrossEntropyOHEMLoss(
                args.aux, args.aux_weight, min_kept=min_kept,
                ignore_index=-1).to(self.device)
        else:
            self.criterion = MixSoftmaxCrossEntropyLoss(
                args.aux, args.aux_weight, ignore_index=-1).to(self.device)

        # optimizer
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay)

        # lr scheduling
        self.lr_scheduler = WarmupPolyLR(self.optimizer,
                                         max_iters=args.max_iters,
                                         power=0.9,
                                         warmup_factor=args.warmup_factor,
                                         warmup_iters=args.warmup_iters,
                                         warmup_method=args.warmup_method)

        # evaluation metrics
        self.metric = SegmentationMetric(trainset.num_class)
        self.best_pred = 0.0

    def train(self):
        """Run the training loop over ``self.train_loader``.

        Logs every ``args.log_iter`` iterations, checkpoints every
        ``args.save_epoch`` epochs (rank 0 only) and validates every
        ``args.val_epoch`` epochs unless ``args.skip_val`` is set. A final
        checkpoint is written once the loader is exhausted.
        """
        save_to_disk = get_rank() == 0
        epochs, max_iters = self.args.epochs, self.args.max_iters
        log_per_iters = self.args.log_iter
        val_per_iters = self.args.val_epoch * self.args.iters_per_epoch
        save_per_iters = self.args.save_epoch * self.args.iters_per_epoch
        start_time = time.time()
        logger.info(
            'Start training, Total Epochs: {:d} = Total Iterations {:d}'.
            format(epochs, max_iters))

        self.model.train()
        for iteration, (images, targets) in enumerate(self.train_loader):
            iteration += 1  # 1-based so the modulo checks and ETA work
            self.lr_scheduler.step()

            images = images.to(self.device)
            targets = targets.to(self.device)

            outputs = self.model(images)
            loss_dict = self.criterion(outputs, targets)
            losses = sum(loss for loss in loss_dict.values())

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = reduce_loss_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())

            self.optimizer.zero_grad()
            losses.backward()
            self.optimizer.step()

            eta_seconds = ((time.time() - start_time) /
                           iteration) * (max_iters - iteration)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

            if iteration % log_per_iters == 0 and save_to_disk:
                logger.info(
                    "Iters: {:d}/{:d} || Lr: {:.6f} || Loss: {:.4f} || Cost Time: {} || Estimated Time: {}"
                    .format(
                        iteration, max_iters,
                        self.optimizer.param_groups[0]['lr'],
                        losses_reduced.item(),
                        str(
                            datetime.timedelta(seconds=int(time.time() -
                                                           start_time))),
                        eta_string))

            if iteration % save_per_iters == 0 and save_to_disk:
                save_checkpoint(self.model, self.args, is_best=False)

            if not self.args.skip_val and iteration % val_per_iters == 0:
                self.validation()
                # validation() switches to eval mode; switch back.
                self.model.train()

        save_checkpoint(self.model, self.args, is_best=False)
        total_training_time = time.time() - start_time
        total_training_str = str(
            datetime.timedelta(seconds=total_training_time))
        logger.info("Total training time: {} ({:.4f}s / it)".format(
            total_training_str, total_training_time / max_iters))

    def validation(self):
        """Score the model on ``self.val_loader`` and checkpoint it.

        The checkpoint is flagged as best when the mean of pixel accuracy
        and mIoU exceeds ``self.best_pred``.
        """
        is_best = False
        self.metric.reset()
        # Evaluate the underlying module rather than the DDP wrapper.
        if self.args.distributed:
            model = self.model.module
        else:
            model = self.model
        torch.cuda.empty_cache()  # TODO check if it helps
        model.eval()
        for i, (image, target) in enumerate(self.val_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                outputs = model(image)
            self.metric.update(outputs[0], target)
            pixAcc, mIoU = self.metric.get()
            logger.info(
                "Sample: {:d}, Validation pixAcc: {:.3f}, mIoU: {:.3f}".format(
                    i + 1, pixAcc, mIoU))

        new_pred = (pixAcc + mIoU) / 2
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
        save_checkpoint(self.model, self.args, is_best)
        synchronize()
class Trainer(object):
    """Trainer for a model evaluated at two scales (tagged "120" and "60").

    Batches from the loaders are sequences of tensors (one entry per scale)
    plus a third item that the training code ignores. Scalars are written to
    the ``writer`` passed to ``train`` / ``validate``.
    """

    def __init__(self, args):
        """Create loaders, network, optimizer, loss, LR schedule and metrics.

        Args:
            args: options namespace; ``iters_per_epoch`` and ``max_iters``
                are computed here and stored back on it.
        """
        self.args = args
        self.device = torch.device(args.device)

        # image transform
        normalize = transforms.Normalize([.485, .456, .406],
                                         [.229, .224, .225])
        input_transform = transforms.Compose(
            [transforms.ToTensor(), normalize])

        # Both splits take the same dataset options.
        shared_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size,
            're_size': args.re_size,
        }

        # training data
        trainset = get_segmentation_dataset(args.dataset,
                                            args=args,
                                            split='train',
                                            mode='train_onlyrs',
                                            **shared_kwargs)
        samples_per_step = args.num_gpus * args.batch_size
        args.iters_per_epoch = len(trainset) // samples_per_step
        args.max_iters = args.epochs * args.iters_per_epoch

        train_sampler = make_data_sampler(trainset,
                                          shuffle=True,
                                          distributed=args.distributed)
        train_batches = make_batch_data_sampler(train_sampler,
                                                args.batch_size,
                                                args.max_iters)
        self.train_loader = data.DataLoader(dataset=trainset,
                                            batch_sampler=train_batches,
                                            num_workers=args.workers,
                                            pin_memory=True)

        # validation data
        valset = get_segmentation_dataset(args.dataset,
                                          args=args,
                                          split='val',
                                          mode='val_onlyrs',
                                          **shared_kwargs)
        val_sampler = make_data_sampler(valset, True, args.distributed)
        val_batches = make_batch_data_sampler(val_sampler, args.batch_size)
        self.val60_loader = data.DataLoader(dataset=valset,
                                            batch_sampler=val_batches,
                                            num_workers=args.workers,
                                            pin_memory=True)

        # create network
        if args.distributed:
            BatchNorm2d = nn.SyncBatchNorm
        else:
            BatchNorm2d = nn.BatchNorm2d
        network = get_segmentation_model(args.model,
                                         dataset=args.dataset,
                                         args=self.args,
                                         norm_layer=BatchNorm2d)
        self.model = network.to(self.device)
        self.model = load_modules(args, self.model)
        self.model = fix_model(args, self.model)

        # optimizer: only parameters that still require gradients
        trainable = (p for p in self.model.parameters() if p.requires_grad)
        self.optimizer = torch.optim.SGD(trainable,
                                         lr=args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay)

        # create criterion
        if args.ohem:
            min_kept = int(args.batch_size // args.num_gpus *
                           args.crop_size**2 // 16)
            criterion = MixSoftmaxCrossEntropyOHEMLoss(args.aux,
                                                       args.aux_weight,
                                                       min_kept=min_kept,
                                                       ignore_index=-1)
        else:
            criterion = MixSoftmaxCrossEntropyLoss(args.aux,
                                                   args.aux_weight,
                                                   ignore_index=-1)
        self.criterion = criterion.to(self.device)

        # lr scheduling
        self.lr_scheduler = WarmupPolyLR(self.optimizer,
                                         max_iters=args.max_iters,
                                         power=0.9,
                                         warmup_factor=args.warmup_factor,
                                         warmup_iters=args.warmup_iters,
                                         warmup_method=args.warmup_method)

        # parallel wrappers (DataParallel takes precedence)
        if args.use_DataParallel:
            self.model = torch.nn.DataParallel(
                self.model, device_ids=range(torch.cuda.device_count()))
        elif args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                find_unused_parameters=True)

        # evaluation metrics, one per scale
        self.metric_120 = SegmentationMetric(trainset.num_class)
        self.metric_60 = SegmentationMetric(trainset.num_class)
        self.best_pred = 0.0

    def train(self, writer):
        """Run the optimisation loop.

        Args:
            writer: object exposing ``add_scalar(tag, value, step)``; it
                receives the learning rate and the reduced training loss on
                every iteration.
        """
        cfg = self.args
        save_to_disk = get_rank() == 0
        epochs = cfg.epochs
        max_iters = cfg.max_iters
        log_per_iters = cfg.log_iter
        val_per_iters = cfg.val_epoch * cfg.iters_per_epoch
        save_per_iters = cfg.save_epoch * cfg.iters_per_epoch
        start_time = time.time()
        logger.info(
            'Start training, Total Epochs: {:d} = Total Iterations {:d}'.
            format(epochs, max_iters))

        self.model.train()
        for step, (images, targets, _) in enumerate(self.train_loader):
            # The step counter is offset by the configured start step.
            iteration = step + cfg.start_step
            self.lr_scheduler.step()

            # Move every per-scale tensor to the device, in place.
            for pos, tensor in enumerate(images):
                images[pos] = tensor.to(self.device)
            for pos, tensor in enumerate(targets):
                targets[pos] = tensor.to(self.device)

            outputs = self.model(images)
            loss_dict = self.criterion(outputs, targets)
            losses = sum(loss_dict.values())

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = reduce_loss_dict(loss_dict)
            losses_reduced = sum(loss_dict_reduced.values())

            self.optimizer.zero_grad()
            losses.backward()
            self.optimizer.step()

            elapsed = time.time() - start_time
            eta_seconds = (elapsed / iteration) * (max_iters - iteration)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

            current_lr = self.optimizer.param_groups[0]['lr']
            loss_value = losses_reduced.item()
            writer.add_scalar("learning_rate", current_lr, iteration)
            writer.add_scalar("Loss/train_loss", loss_value, iteration)

            if iteration % log_per_iters == 0 and save_to_disk:
                logger.info(
                    "Iters: {:d}/{:d} || Lr: {:.6f} || Loss: {:.4f} || Estimated Time: {}"
                    .format(iteration, max_iters, current_lr, loss_value,
                            eta_string))

            if iteration % save_per_iters == 0 and save_to_disk:
                print('saving......')
                save_checkpoint(self.model,
                                cfg,
                                iteration=iteration,
                                is_best=False)
                print('save over!')

            if iteration % val_per_iters == 0:
                print('evaluating...')
                self.validate(iteration, writer)
                # validate() leaves the network in eval mode.
                self.model.train()
                print('eval over!')

        total_training_time = time.time() - start_time
        total_training_str = str(
            datetime.timedelta(seconds=total_training_time))
        logger.info("Total training time: {} ({:.4f}s / it)".format(
            total_training_str, total_training_time / max_iters))

    def validate(self, iteration, writer):
        """Evaluate both scales on ``self.val60_loader`` and checkpoint.

        Args:
            iteration: step index used for the scalar logs and the
                checkpoint.
            writer: object exposing ``add_scalar(tag, value, step)``.
        """
        self.metric_120.reset()
        self.metric_60.reset()
        # Unwrap only the DDP case, as the training setup does.
        net = self.model.module if self.args.distributed else self.model
        torch.cuda.empty_cache()  # TODO check if it helps
        net.eval()

        losses_120 = []
        losses_60 = []
        for image, target, _ in self.val60_loader:
            for pos, tensor in enumerate(image):
                image[pos] = tensor.to(self.device)
            for pos, tensor in enumerate(target):
                target[pos] = tensor.to(self.device)

            with torch.no_grad():
                outputs = net(image)
            self.metric_120.update(outputs[0][0], target[0])
            self.metric_60.update(outputs[0][1], target[1])

            loss_dict = self.criterion(outputs, target)
            losses_120.append(reduce_loss_dict(loss_dict['loss_120']))
            losses_60.append(reduce_loss_dict(loss_dict['loss_60']))

        # ---- scale 120 ----
        val_mpixAcc_120, val_mIou_120, Iou_120 = self.metric_120.get()
        val_loss_120 = sum(losses_120) / len(losses_120)
        logger.info(
            "120 Loss: {:.3f}, Validation mpixAcc: {:.3f}, mIoU: {:.3f}".
            format(val_loss_120, val_mpixAcc_120, val_mIou_120))
        writer.add_scalar("Loss/val120_loss", val_loss_120, iteration)
        writer.add_scalar("Result/val120_mIou", val_mIou_120, iteration)
        writer.add_scalar("Result/val120_Acc", val_mpixAcc_120, iteration)
        for cls_idx, cls_iou in enumerate(Iou_120):
            logger.info("class {:d} : {:.3f}".format(cls_idx, cls_iou))
            writer.add_scalar("Class120/class_{}".format(cls_idx),
                              Iou_120[cls_idx], iteration)

        # ---- scale 60 ----
        val_mpixAcc_60, val_mIou_60, Iou_60 = self.metric_60.get()
        val_loss_60 = sum(losses_60) / len(losses_60)
        logger.info(
            "60 Loss: {:.3f}, Validation mpixAcc: {:.3f}, mIoU: {:.3f}".
            format(val_loss_60, val_mpixAcc_60, val_mIou_60))
        writer.add_scalar("Loss/val60_loss", val_loss_60, iteration)
        writer.add_scalar("Result/val60_mIou", val_mIou_60, iteration)
        writer.add_scalar("Result/val60_Acc", val_mpixAcc_60, iteration)
        for cls_idx, cls_iou in enumerate(Iou_60):
            logger.info("class {:d} : {:.3f}".format(cls_idx, cls_iou))
            writer.add_scalar("Class60/class_{}".format(cls_idx),
                              Iou_60[cls_idx], iteration)

        # Best model = highest mean mIoU across the two scales.
        new_pred = (val_mIou_60 + val_mIou_120) / 2.0
        is_best = False
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
        save_checkpoint(self.model, self.args, iteration, is_best)
        synchronize()
class Evaluator(object):
    """Evaluate a pretrained model on the 'eyes' validation split and dump
    a stacked input / mask / overlay image per sample."""

    def __init__(self, args):
        """Build the validation loader, the pretrained network and metric.

        Args:
            args: options namespace (``device``, ``distributed``,
                ``workers``, ``model``, ``dataset``, ``aux`` are read).
        """
        self.args = args
        self.device = torch.device(args.device)

        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])

        # dataset and dataloader
        val_dataset = get_segmentation_dataset('eyes',
                                               split='val',
                                               mode='testval',
                                               transform=input_transform)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(val_sampler,
                                                    images_per_batch=1)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=args.workers,
                                          pin_memory=True)

        # create network
        self.model = get_segmentation_model(model=args.model,
                                            dataset=args.dataset,
                                            aux=args.aux,
                                            pretrained=True,
                                            pretrained_base=False)
        if args.distributed:
            # The model is never wrapped in this class, so a bare `.module`
            # access raised AttributeError; unwrap only when a wrapper is
            # actually present.
            self.model = getattr(self.model, 'module', self.model)
        self.model.to(self.device)

        self.metric = SegmentationMetric(val_dataset.num_class)

    def eval(self):
        """Run the model over the validation loader.

        Logs the running pixel accuracy / mIoU after each sample and writes
        a vertically stacked image (input, colour mask, 50/50 blend) to the
        results directory.
        """
        self.metric.reset()
        self.model.eval()
        if self.args.distributed:
            model = getattr(self.model, 'module', self.model)
        else:
            model = self.model
        logger.info("Start validation, Total sample: {:d}".format(
            len(self.val_loader)))

        # One colour per class index; constant, so built once.
        colors = np.array([[255, 0, 0], [0, 255, 0], [0, 0, 255]])

        for i, (image, target) in enumerate(self.val_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                outputs = model(image)
            self.metric.update(outputs[0], target)
            pixAcc, mIoU = self.metric.get()
            logger.info(
                "Sample: {:d}, validation pixAcc: {:.3f}, mIoU: {:.3f}".format(
                    i + 1, pixAcc * 100, mIoU * 100))

            pred = torch.argmax(outputs[0], 1)
            pred = pred.cpu().data.numpy()
            predict = pred.squeeze(0)

            # CHW tensor -> HWC array in BGR channel order for OpenCV.
            image = image.cpu().data.numpy().squeeze(0).transpose((1, 2, 0))
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            height = image.shape[0]
            res = np.zeros((height * 3, image.shape[1], 3))
            # NOTE(review): this rescaling assumes inputs in [-1, 1], while
            # the transform above applies mean/std normalisation -- confirm.
            inp = ((image + 1) * 127.5).astype('int')
            msk = colors[predict]
            res[0:height, :, :] = inp
            res[height:height * 2, :, :] = msk
            res[height * 2:, :, :] = cv2.addWeighted(inp, 0.5, msk, 0.5, 0)
            out_path = f'/root/mitya/Lightweight-Segmentation/results/{i}.png'
            # cv2.imwrite reports failure via its return value, not by
            # raising; surface it instead of dropping it.
            if not cv2.imwrite(out_path, res):
                logger.warning('Could not write %s', out_path)
        synchronize()
class Evaluator(object):
    """Evaluate a two-scale ("120" / "60") segmentation model and optionally
    save colourised predictions."""

    def __init__(self, args):
        """Build the validation loader, load the checkpoint and metrics.

        Args:
            args: options namespace; the checkpoint path is read from
                ``args.resume``.
        """
        self.args = args
        self.device = torch.device(args.device)

        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])

        # dataset and dataloader
        val60_data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size,
            're_size': args.re_size,
        }
        valset = get_segmentation_dataset(args.dataset,
                                          args=args,
                                          split='val',
                                          mode='val_onlyrs',
                                          **val60_data_kwargs)
        val_sampler = make_data_sampler(valset, True, args.distributed)
        val_batch_sampler = make_batch_data_sampler(val_sampler,
                                                    args.batch_size)
        self.val60_loader = data.DataLoader(dataset=valset,
                                            batch_sampler=val_batch_sampler,
                                            num_workers=args.workers,
                                            pin_memory=True)

        # create network
        BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d
        self.model = get_segmentation_model(args.model,
                                            dataset=args.dataset,
                                            args=self.args,
                                            norm_layer=BatchNorm2d).to(
                                                self.device)
        self.model = load_model(args.resume, self.model)

        # evaluation metrics
        self.metric_120 = SegmentationMetric(valset.num_class)
        self.metric_60 = SegmentationMetric(valset.num_class)
        self.best_pred = 0.0

    def evaluate(self):
        """Score both scales over the validation loader and log the results.

        When ``args.save_pre`` is set, predictions for every batch are also
        written to disk through ``save_pred``.
        """
        self.metric_120.reset()
        self.metric_60.reset()
        if self.args.distributed:
            model = self.model.module
        else:
            model = self.model
        torch.cuda.empty_cache()  # TODO check if it helps
        model.eval()
        for image, target, image_name in self.val60_loader:
            # Move each per-scale tensor to the device, in place.
            for index, tensor in enumerate(image):
                image[index] = tensor.to(self.device)
            for index, tensor in enumerate(target):
                target[index] = tensor.to(self.device)

            with torch.no_grad():
                outputs = model(image)
            self.metric_120.update(outputs[0][0], target[0])
            self.metric_60.update(outputs[0][1], target[1])

            if self.args.save_pre:
                self.save_pred(image, target, image_name, outputs)

        pixAcc_120, mIoU_120, Iou_120 = self.metric_120.get()
        logger.info("120 Validation: mpixAcc: {:.3f}, mIoU: {:.3f}".format(
            pixAcc_120, mIoU_120))
        for cls_idx, cls_iou in enumerate(Iou_120):
            logger.info("class {:d} : {:.3f}".format(cls_idx, cls_iou))

        pixAcc_60, mIoU_60, Iou_60 = self.metric_60.get()
        logger.info("60 Validation: mpixAcc: {:.3f}, mIoU: {:.3f}".format(
            pixAcc_60, mIoU_60))
        for cls_idx, cls_iou in enumerate(Iou_60):
            logger.info("class {:d} : {:.3f}".format(cls_idx, cls_iou))

        synchronize()

    def save_pred(self, image, target, image_name, outputs):
        """Write the colourised prediction of each scale to disk.

        Args:
            image: per-scale input tensors (index 0 -> "120", 1 -> "60").
            target: per-scale label tensors, same indexing.
            image_name: per-scale names as yielded by the loader.
            outputs: model output; ``outputs[0][k]`` holds scale ``k``.

        Each scale's mask is saved under that scale's own name. Previously
        the "120" mask was saved under ``image_name[1]``, the same path as
        the "60" mask, and was therefore always overwritten.
        """
        mean = np.array([.485, .456, .406])
        std = np.array([.229, .224, .225])
        # index 0 is the "120" branch, index 1 the "60" branch
        for scale_idx in (0, 1):
            self._save_scale_pred(scale_idx, image, target, image_name,
                                  outputs, mean, std)

    @staticmethod
    def _unnormalize(img, mean, std):
        """Undo mean/std normalisation on an HxWxC array, scaled to 0-255."""
        # Broadcasting over the trailing channel axis is equivalent to
        # tiling mean/std to the full image shape.
        return (img * std + mean) * 255.

    def _save_scale_pred(self, scale_idx, image, target, image_name, outputs,
                         mean, std):
        """Save the mask (and optional 2x2 comparison) for one scale."""
        pred = torch.argmax(outputs[0][scale_idx], 1)
        predict = pred.cpu().data.numpy().squeeze(0)
        mask = get_color_pallete(predict, self.args.dataset)
        mask = np.asarray(mask.convert('RGB'))

        # str(name)[2:-2] strips the two leading/trailing characters of the
        # name's string form, as the original code did.
        stem = str(image_name[scale_idx])[2:-2] + '_' + self.args.model_mode
        misc.imsave(os.path.join(self.args.save_pre_path, stem + '.png'),
                    mask)

        if self.args.combined:
            img = image[scale_idx].cpu().data.numpy()[0].transpose(1, 2, 0)
            img = np.array(self._unnormalize(img, mean, std), dtype=np.int32)
            label = target[scale_idx].cpu().data.numpy().squeeze(0)
            label = get_color_pallete(label, self.args.dataset)
            label = np.asarray(label.convert('RGB'))
            # top row: input | input/prediction blend
            # bottom row: ground truth | prediction
            top = np.concatenate((img, img * 0.5 + mask * 0.5), axis=1)
            bottom = np.concatenate((label, mask), axis=1)
            combined = np.concatenate((top, bottom), axis=0)
            misc.imsave(
                os.path.join(self.args.save_pre_path, stem + '_4.png'),
                combined)
class Evaluator(object):
    """Evaluate a pretrained model on the teeth validation set and run it
    on a folder of unlabeled test images."""

    def __init__(self, args):
        """Build the validation loader, the pretrained network and metric.

        Args:
            args: options namespace (``device``, ``distributed``,
                ``workers``, ``model``, ``dataset``, ``aux`` are read).
        """
        self.args = args
        self.device = torch.device(args.device)
        self.nb_classes = 3

        valid_img_dir = '/home/wangjialei/teeth_dataset/new_data_20190621/valid_new/images'
        valid_mask_dir = '/home/wangjialei/teeth_dataset/new_data_20190621/valid_new/masks'

        # dataset and dataloader
        # No transform is passed to the dataset; eval() runs each sample
        # through data_process() instead.
        valid_set = SegmentationData(images_dir=valid_img_dir,
                                     masks_dir=valid_mask_dir,
                                     nb_classes=self.nb_classes,
                                     mode='valid',
                                     transform=None)
        valid_sampler = make_data_sampler(valid_set, False, args.distributed)
        valid_batch_sampler = make_batch_data_sampler(valid_sampler,
                                                      images_per_batch=1)
        self.val_loader = data.DataLoader(dataset=valid_set,
                                          batch_sampler=valid_batch_sampler,
                                          num_workers=args.workers,
                                          pin_memory=True)

        # create network
        self.model = get_segmentation_model(model=args.model,
                                            dataset=args.dataset,
                                            aux=args.aux,
                                            pretrained=True,
                                            pretrained_base=False)
        if args.distributed:
            # The model is never wrapped in this class, so a bare `.module`
            # access raised AttributeError; unwrap only when a wrapper is
            # actually present.
            self.model = getattr(self.model, 'module', self.model)
        self.model.to(self.device)

        self.metric = SegmentationMetric(valid_set.num_class)

    def eval(self):
        """Score the model on the validation loader.

        Logs running pixel accuracy / mIoU per sample. When
        ``args.save_pred`` is set, an image / mask / prediction figure is
        saved to ``save_fig_val/<index>.png``.
        """
        self.metric.reset()
        self.model.eval()
        if self.args.distributed:
            model = getattr(self.model, 'module', self.model)
        else:
            model = self.model
        logger.info("Start validation, Total sample: {:d}".format(
            len(self.val_loader)))
        for i, (image, target) in enumerate(self.val_loader):
            img = data_process(image)
            img = img.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                outputs = model(img)
            # NOTE(review): the metric receives `outputs` while the argmax
            # below uses `outputs[0]` -- confirm which one is intended.
            self.metric.update(outputs, target)
            pixAcc, mIoU = self.metric.get()
            logger.info(
                "Sample: {:d}, validation pixAcc: {:.3f}, mIoU: {:.3f}".format(
                    i + 1, pixAcc * 100, mIoU * 100))

            if self.args.save_pred:
                pred = torch.argmax(outputs[0], 1)
                pred = pred.cpu().data.numpy()
                predict = pred.squeeze(0)

                img_show = image[0].numpy()
                img_show = img_show.astype('uint8')
                plt.subplot(1, 3, 1)
                plt.title('image')
                plt.imshow(img_show)

                mask = target.cpu().data.numpy()
                mask = mask.reshape(mask.shape[1], mask.shape[2])
                mask = mask_to_image(mask)
                plt.subplot(1, 3, 2)
                plt.title('mask')
                plt.imshow(mask)

                predict = mask_to_image(predict)
                plt.subplot(1, 3, 3)
                plt.title('pred')
                plt.imshow(predict)

                save_file = "save_fig_val"
                os.makedirs(save_file, exist_ok=True)
                plt.savefig(os.path.join(save_file, str(i) + '.png'))
                # Close the figure so artists do not pile up across samples.
                plt.close()
        synchronize()

    def test(self):
        """Run the model on every file in the hard-coded test folder.

        When ``args.save_pred`` is set, an image / prediction figure is
        saved to ``save_fig_test/<index>.png``.
        """
        self.model.eval()
        if self.args.distributed:
            model = getattr(self.model, 'module', self.model)
        else:
            model = self.model

        test_img_dir = '/home/wangjialei/projects/teeth_bad_case/'
        img_folder = os.listdir(test_img_dir)
        for index, img_file in enumerate(img_folder):
            img_name = test_img_dir + img_file
            image = Image.open(img_name)
            print(type(image))
            img = data_process(image)
            img = img.to(self.device)

            with torch.no_grad():
                outputs = model(img)

            if self.args.save_pred:
                pred = torch.argmax(outputs[0], 1)
                pred = pred.cpu().data.numpy()
                predict = pred.squeeze(0)

                img_show = image
                plt.subplot(1, 2, 1)
                plt.title('image')
                plt.imshow(img_show)

                predict = mask_to_image(predict)
                plt.subplot(1, 2, 2)
                plt.title('pred')
                plt.imshow(predict)

                save_file = "save_fig_test"
                os.makedirs(save_file, exist_ok=True)
                plt.savefig(os.path.join(save_file, str(index) + '.png'))
                # Close the figure so artists do not pile up across images.
                plt.close()
class Evaluator(object):
    """Evaluate a pretrained segmentation model on the validation split of
    ``args.dataset``."""

    def __init__(self, args):
        """Build the validation loader, the pretrained network and metric.

        Args:
            args: options namespace (``device``, ``dataset``,
                ``distributed``, ``workers``, ``model``, ``aux`` are read).
        """
        self.args = args
        self.device = torch.device(args.device)

        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])

        # dataset and dataloader
        val_dataset = get_segmentation_dataset(args.dataset,
                                               split='val',
                                               mode='testval',
                                               transform=input_transform)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(val_sampler,
                                                    images_per_batch=1)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=args.workers,
                                          pin_memory=True)

        # create network
        self.model = get_segmentation_model(model=args.model,
                                            dataset=args.dataset,
                                            aux=args.aux,
                                            pretrained=True,
                                            pretrained_base=False)
        if args.distributed:
            # The model is never wrapped in this class, so a bare `.module`
            # access raised AttributeError; unwrap only when a wrapper is
            # actually present.
            self.model = getattr(self.model, 'module', self.model)
        self.model.to(self.device)

        self.metric = SegmentationMetric(val_dataset.num_class)

    def eval(self):
        """Score the model on the validation loader.

        Logs the running pixel accuracy / mIoU (as percentages) after each
        sample.
        """
        self.metric.reset()
        self.model.eval()
        if self.args.distributed:
            model = getattr(self.model, 'module', self.model)
        else:
            model = self.model
        logger.info("Start validation, Total sample: {:d}".format(
            len(self.val_loader)))
        for i, (image, target) in enumerate(self.val_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                outputs = model(image)
            self.metric.update(outputs[0], target)
            pixAcc, mIoU = self.metric.get()
            logger.info(
                "Sample: {:d}, validation pixAcc: {:.3f}, mIoU: {:.3f}".format(
                    i + 1, pixAcc * 100, mIoU * 100))

            if self.args.save_pred:
                pred = torch.argmax(outputs[0], 1)
                pred = pred.cpu().data.numpy()
                predict = pred.squeeze(0)
                # NOTE(review): the colour mask is built but never written;
                # the save call was already disabled in the original code.
                mask = get_color_pallete(predict, self.args.dataset)
        synchronize()