def main(checkpoint_filename, input_image, output_image):
    """Segment one image with a 3-class DeepLab checkpoint and save the result.

    Args:
        checkpoint_filename: Path of a checkpoint holding a ``'state_dict'`` entry.
        input_image: Path of the image to segment; it is opened and converted to RGB.
        output_image: Path the object returned by ``predict`` is saved to.
    """
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # Define network
    model = DeepLab(num_classes=3,
                    backbone='resnet',
                    output_stride=16,
                    sync_bn=False,
                    freeze_bn=False)

    saved = torch.load(checkpoint_filename, map_location=device)

    # The model was saved through DataParallel, so stored keys carry a
    # "module." prefix (7 characters) that has to be stripped before loading.
    weights = {}
    for name, tensor in saved['state_dict'].items():
        if name.startswith('module.'):
            name = name[7:]
        weights[name] = tensor

    model.load_state_dict(weights)
    model.eval()

    rgb_image = Image.open(input_image).convert('RGB')
    predict(model, rgb_image).save(output_image)
def load_model(model_path, num_classes=14, backbone='resnet', output_stride=16):
    """Build a DeepLab model and load every compatible weight from ``model_path``.

    Entries of the loaded dict are used only when the key exists in the fresh
    model and the tensor shapes agree; all other parameters keep their
    initial values.

    Args:
        model_path: Path of the saved weights (a mapping of name -> tensor).
        num_classes: Forwarded to ``DeepLab``.
        backbone: Forwarded to ``DeepLab``.
        output_stride: Forwarded to ``DeepLab``.

    Returns:
        The model in eval mode, wrapped in ``nn.DataParallel`` when more than
        one GPU is visible.
    """
    print(f"Loading model from {model_path}")
    model = DeepLab(num_classes=num_classes,
                    backbone=backbone,
                    output_stride=output_stride)

    # Keep every tensor on the storage it is deserialised to (CPU).
    loaded = torch.load(model_path, map_location=lambda storage, loc: storage)
    target_state = model.state_dict()

    # 1. keep only entries whose key and shape match the fresh model
    compatible = {}
    for name, tensor in loaded.items():
        if name not in target_state:
            continue
        if target_state[name].shape != tensor.shape:
            continue
        compatible[name] = tensor

    # 2. overlay them on the model's own state, 3. load the merged result
    target_state.update(compatible)
    model.load_state_dict(target_state)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    gpu_count = torch.cuda.device_count()
    if gpu_count > 1:
        print("Load model in ", gpu_count, " GPUs!")
        model = nn.DataParallel(model)

    model.to(device)
    model.eval()
    return model
class RAN():
    """Two-class MobileNet DeepLab wrapper that runs inference on one GPU."""

    def __init__(self, weight, gpu_ids):
        """Build the network on the selected GPU and load its weights.

        Args:
            weight: Path of a checkpoint holding a ``'state_dict'`` entry.
            gpu_ids: Value passed to ``torch.cuda.set_device``.

        Raises:
            RuntimeError: If ``weight`` does not point to an existing file.
        """
        # Per-channel statistics applied in inference() after scaling to [0, 1].
        self.mean = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)

        network = DeepLab(num_classes=2, backbone='mobilenet', output_stride=16)
        torch.cuda.set_device(gpu_ids)
        self.model = network.cuda()

        assert weight is not None
        if not os.path.isfile(weight):
            raise RuntimeError("=> no checkpoint found at '{}'".format(weight))

        saved = torch.load(weight)
        self.model.load_state_dict(saved['state_dict'])
        self.model.eval()

    def _to_batch(self, img):
        """Resize to 480x480, normalise and return a 1xCxHxW float CUDA tensor."""
        pixels = cv2.resize(img, (480, 480)).astype(np.float32)
        # In-place arithmetic keeps the array in float32 throughout.
        pixels /= 255.0
        pixels -= self.mean
        pixels /= self.std
        channels_first = pixels.transpose((2, 0, 1))
        batch = channels_first[np.newaxis, :, :, :]
        return torch.from_numpy(batch).float().cuda()

    def inference(self, img):
        """Return the raw network output for one H x W x C image array."""
        batch = self._to_batch(img)
        with torch.no_grad():
            return self.model(batch)
def inference_A_sample_image(img_path, model_path, num_classes, backbone, output_stride, sync_bn, freeze_bn):
    """Run DeepLab on a single image and write the class-index map to ``output.jpg``.

    Args:
        img_path: Path of the image to segment, read with ``cv2.imread``.
        model_path: Path of a checkpoint holding a ``'state_dict'`` entry.
        num_classes: Forwarded to the ``DeepLab`` constructor.
        backbone: Forwarded to the ``DeepLab`` constructor.
        output_stride: Forwarded to the ``DeepLab`` constructor.
        sync_bn: Forwarded to the ``DeepLab`` constructor.
        freeze_bn: Forwarded to the ``DeepLab`` constructor.

    Raises:
        FileNotFoundError: If ``img_path`` cannot be read as an image.
    """
    # read image
    image = cv2.imread(img_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError("could not read image '{}'".format(img_path))
    image = image.astype(np.float32)

    # Normalize pascal image (mean and std is from pascal.py)
    # NOTE(review): cv2.imread yields BGR channel order; confirm the statistics
    # below are meant to be applied in that order.
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    image /= 255
    image -= mean
    image /= std

    # numpy image is H x W x C, torch image is C x H x W; add the batch axis (N=1)
    image = np.ascontiguousarray(image.transpose((2, 0, 1))[np.newaxis])
    image = torch.from_numpy(image)

    model = DeepLab(num_classes=num_classes,
                    backbone=backbone,
                    output_stride=output_stride,
                    sync_bn=sync_bn,
                    freeze_bn=freeze_bn,
                    pretrained=True)

    # Neither the model nor the input is moved to a GPU, so always map the
    # checkpoint tensors to the CPU. Load the checkpoint the caller asked for
    # rather than a hard-coded backbone weight file.
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['state_dict'])

    # Put dropout and batch-norm layers into evaluation mode before inference;
    # otherwise results are inconsistent between runs.
    model.eval()
    with torch.no_grad():
        output = model(image)

    # Per-pixel class index of the single image in the batch.
    pred = np.argmax(output.cpu().numpy(), axis=1)[0]

    # save result: argmax produces int64, which cv2.imwrite cannot encode.
    # NOTE(review): JPEG is lossy, so stored label values may not round-trip.
    cv2.imwrite('output.jpg', pred.astype(np.uint8))
def main():
    """Evaluate a MobileNet DeepLab checkpoint under AIMET quantization simulation.

    Reads settings from ``arguments()``, folds batch norms, replaces ReLU6 by
    ReLU, loads the checkpoint, computes quantization encodings on the
    validation loader and prints the post-quantization score.

    Raises:
        ValueError: If no checkpoint path is given, or if ``quant_scheme`` is
            missing or not one of the supported names.
    """
    from aimet_common.defs import QuantScheme
    from aimet_torch import batch_norm_fold
    from aimet_torch import utils
    from aimet_torch.quantsim import QuantizationSimModel

    args = arguments()
    seed(args)

    # Fail fast, before the model is built and folded.
    if not args.checkpoint_path:
        raise ValueError('checkpoint path {} must be specified'.format(
            args.checkpoint_path))

    # Resolve the scheme up front. The original left ``quant_scheme`` unbound
    # (NameError) when ``args`` had no such attribute.
    scheme_attr_by_name = {
        'range_learning_tf': 'training_range_learning_with_tf_init',
        'range_learning_tfe': 'training_range_learning_with_tf_enhanced_init',
        'tf': 'post_training_tf',
        'tf_enhanced': 'post_training_tf_enhanced',
    }
    requested_scheme = getattr(args, 'quant_scheme', None)
    if requested_scheme is None:
        raise ValueError("quant_scheme must be specified")
    if requested_scheme not in scheme_attr_by_name:
        raise ValueError("Got unrecognized quant_scheme: " + str(requested_scheme))
    # Look the enum member up by name so only the requested one is accessed.
    quant_scheme = getattr(QuantScheme, scheme_attr_by_name[requested_scheme])

    model = DeepLab(backbone='mobilenet',
                    output_stride=16,
                    num_classes=21,
                    sync_bn=False)
    model.eval()

    args.input_shape = (1, 3, 513, 513)
    # The checkpoint is loaded only after folding and the ReLU6 -> ReLU swap,
    # i.e. into the already transformed module structure.
    batch_norm_fold.fold_all_batch_norms(model, args.input_shape)
    utils.replace_modules_of_type1_with_type2(model, torch.nn.ReLU6, torch.nn.ReLU)
    model.load_state_dict(torch.load(args.checkpoint_path))

    data_loader_kwargs = {'worker_init_fn': work_init, 'num_workers': 0}
    _, val_loader, _, _ = make_data_loader(args, **data_loader_kwargs)
    eval_func_quant = model_eval(args, val_loader)
    eval_func = model_eval(args, val_loader)

    kwargs = {
        'quant_scheme': quant_scheme,
        'default_param_bw': args.default_param_bw,
        'default_output_bw': args.default_output_bw,
        'config_file': args.config_file
    }
    print(kwargs)

    sim = QuantizationSimModel(model.cpu(),
                               input_shapes=args.input_shape,
                               **kwargs)
    sim.compute_encodings(eval_func_quant, (1024, True))
    post_quant_miou = eval_func(sim.model.cuda(), (99999999, True))
    print("Post Quant mIoU :", post_quant_miou)
def main(args):
    """Run the validation set through a 1-class DeepLab model and dump PNGs.

    For every sample three files are written into ``args.output_folder``:
    the thresholded prediction, the target mask and the de-normalised input.

    Args:
        args: Namespace providing ``root_folder``, ``data_label``,
            ``checkpoint``, ``device`` and ``output_folder``.
    """
    vali_dataset = MRIBrainSegmentation(root_folder=args.root_folder,
                                        image_label=args.data_label,
                                        is_train=False)
    vali_loader = torch.utils.data.DataLoader(vali_dataset,
                                              batch_size=16,
                                              shuffle=False,
                                              num_workers=4,
                                              drop_last=False)

    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')

    # Init and load model; map_location makes the CPU fallback work even for
    # checkpoints that were saved from a GPU.
    model = DeepLab(num_classes=1,
                    backbone='resnet',
                    output_stride=8,
                    sync_bn=None,
                    freeze_bn=False)
    checkpoint = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model = model.to(device)
    model.eval()

    # Statistics used to undo the input normalisation before saving.
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    with torch.no_grad():
        for i, sample in enumerate(vali_loader):
            print(i)
            data = sample['image'].to(device)
            target = sample['mask'].cpu().numpy()
            output = model(data).cpu().numpy()
            data = data.cpu().numpy()

            # Binarise channel 0 of the network output at 0.5.
            # NOTE(review): assumes the model output is already a probability;
            # no sigmoid is applied here — confirm.
            pred = (output[:, 0] > 0.5).astype(output.dtype)

            for j in range(len(target)):
                prefix = "{}/{:06d}_{:06d}".format(args.output_folder, i, j)
                cv2.imwrite(prefix + "_predict.png",
                            (pred[j] * 255).astype(np.uint8))
                cv2.imwrite(prefix + "_target.png",
                            (target[j] * 255).astype(np.uint8))

                # C x H x W -> H x W x C, then undo the normalisation.
                img = data[j].transpose([1, 2, 0])
                img *= std
                img += mean
                img *= 255.0
                # Clip so rounding overshoot cannot wrap around in uint8.
                # The original format string here was "{}}/..." which raised
                # ValueError (single '}' in format string).
                cv2.imwrite(prefix + "_origin.png",
                            np.clip(img, 0, 255).astype(np.uint8))
def test(args):
    """Feed every test batch through a DeepLab model, optionally with mixup.

    Args:
        args: Namespace providing ``backbone``, ``out_stride``, ``pretrained``,
            ``use_mixup`` and ``mixup_alpha`` plus whatever
            ``make_data_loader`` reads.

    The model output of each batch is computed and discarded; nothing is
    returned. ``device`` is not defined in this function and is resolved from
    the enclosing scope.
    """
    loader_options = {'num_workers': 1, 'pin_memory': True}
    _, _, test_loader, nclass = make_data_loader(args, **loader_options)

    model = DeepLab(num_classes=nclass,
                    backbone=args.backbone,
                    output_stride=args.out_stride,
                    sync_bn=False)
    saved = torch.load(args.pretrained, map_location=device)
    model.load_state_dict(saved['state_dict'])
    model.eval()

    # train test dev
    for sample in tqdm(test_loader):
        image = sample['image']
        target = sample['label']
        if args.use_mixup:
            # Only the mixed image is used; the remaining outputs are ignored.
            image, _targets_a, _targets_b, _lam = mixup_data(
                image, target, args.mixup_alpha, use_cuda=False)
        output = model(image)
def test(args):
    """Segment each validation crop and save it blended onto the full image.

    For every sample the prediction is decoded to a colour map, resized to
    the bounding box, pasted into an empty canvas of the full image's shape
    and blended with the image. Results are written to
    ``run/<dataset>/deeplab-resnet/output/<index>``.

    Args:
        args: Namespace providing ``ckpt`` and ``dataset`` plus whatever
            ``make_data_loader`` reads.

    Raises:
        ValueError: If the checkpoint file deserialises to ``None``.
    """
    kwargs = {'num_workers': 1, 'pin_memory': True}
    _, val_loader, _, nclass = make_data_loader(args, **kwargs)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # map_location keeps the CPU fallback usable for GPU-saved checkpoints.
    checkpoint = torch.load(args.ckpt, map_location=device)
    if checkpoint is None:
        raise ValueError("checkpoint '{}' is empty".format(args.ckpt))

    model = DeepLab(num_classes=nclass,
                    backbone='resnet',
                    output_stride=16,
                    sync_bn=True,
                    freeze_bn=False)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    model.to(device)

    output_dir = os.path.join('run', args.dataset, 'deeplab-resnet', 'output')
    os.makedirs(output_dir, exist_ok=True)

    # Scoped no_grad instead of switching gradients off for the whole process.
    with torch.no_grad():
        for i, sample in enumerate(tqdm(val_loader)):
            # bbox coord
            x1, x2, y1, y2 = (int(item)
                              for item in sample['img_meta']['bbox_coord'])
            w, h = x2 - x1, y2 - y1
            img = sample['img_meta']['image'].squeeze().cpu().numpy()

            # Send the input to the same device as the model (was .cuda()).
            inputs = sample['image'].to(device)
            output = model(inputs).squeeze().cpu().numpy()
            pred = np.argmax(output, axis=0)
            result = decode_segmap(pred, dataset=args.dataset, plot=False)

            # NOTE(review): the size is passed as (w, h) while the slice below
            # needs an array of h rows by w columns — confirm the argument
            # order expected by `imresize`.
            result = imresize(result, (w, h))
            result_padding = np.zeros(img.shape, dtype=np.uint8)
            result_padding[y1:y2, x1:x2] = result

            # NOTE(review): result_padding is uint8, so `* 127` wraps modulo
            # 256; the clamp only has an effect if `img` has a wider dtype.
            result = img // 2 + result_padding * 127
            result[result > 255] = 255
            plt.imsave(os.path.join(output_dir, str(i)), result)
def main():
    """Segment the bundled test image with DeepLab, whole and patch by patch.

    Parses the command line, loads the checkpoint, then
    (1) segments the whole image with ``process_single_large_image`` and
    saves a mask and a composite, and
    (2) decomposes the image into patches and saves a mask and a composite
    for each patch.

    Raises:
        FileNotFoundError: If an input image cannot be read.
    """
    import os  # local import: only needed here for path handling

    parser = argparse.ArgumentParser(
        description="PyTorch DeeplabV3Plus Training")
    parser.add_argument('--backbone',
                        type=str,
                        default='resnet',
                        choices=['resnet', 'xception', 'drn', 'mobilenet'],
                        help='backbone name (default: resnet)')
    parser.add_argument('--out_stride',
                        type=int,
                        default=16,
                        help='network output stride (default: 16)')
    parser.add_argument('--rip_mode', type=str, default='patches-level2')
    parser.add_argument('--use_sbd',
                        action='store_true',
                        default=True,
                        help='whether to use SBD dataset (default: True)')
    parser.add_argument('--workers',
                        type=int,
                        default=8,
                        metavar='N',
                        help='dataloader threads')
    parser.add_argument('--base_size',
                        type=int,
                        default=800,
                        help='base image size')
    parser.add_argument('--crop_size',
                        type=int,
                        default=800,
                        help='crop image size')
    # NOTE: argparse's type=bool turns every non-empty string (even "False")
    # into True; kept as-is so existing command lines behave the same.
    parser.add_argument('--sync_bn',
                        type=bool,
                        default=None,
                        help='whether to use sync bn (default: auto)')
    parser.add_argument(
        '--freeze_bn',
        type=bool,
        default=False,
        help='whether to freeze bn parameters (default: False)')
    # cuda, seed and logging
    parser.add_argument('--gpus',
                        type=int,
                        default=1,
                        help='how many gpus to use (default=1)')
    parser.add_argument('--seed',
                        type=int,
                        default=123,
                        metavar='S',
                        help='random seed (default: 123)')
    # checking point
    parser.add_argument('--resume',
                        type=str,
                        default=None,
                        help='put the path to resuming file if needed')
    parser.add_argument('--checkname',
                        type=str,
                        default=None,
                        help='set the checkpoint name')
    parser.add_argument('--exp_root', type=str, default='')
    args = parser.parse_args()
    args.device, args.cuda = get_available_device(args.gpus)

    nclass = 3
    model = DeepLab(num_classes=nclass,
                    backbone=args.backbone,
                    output_stride=args.out_stride,
                    sync_bn=args.sync_bn,
                    freeze_bn=args.freeze_bn)

    # Fall back to the historical hard-coded checkpoint only when none was
    # given; previously a --checkname from the command line was overwritten.
    if args.checkname is None:
        args.checkname = '/data2/data2/zewei/exp/RipData/DeepLabV3/patches/level2/CV5-1/model_best.pth.tar'
    # Deserialize on the CPU; the model is moved to args.device afterwards.
    ckpt = torch.load(args.checkname, map_location='cpu')
    model.load_state_dict(ckpt['state_dict'])
    model.eval()
    model = model.to(args.device)

    img_files = ['doc/tests/img_cv.png']
    out_file = 'doc/tests/img_seg.png'
    to_model_input = Compose([
        ToTensor(),
        Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    ])
    color_map = get_rip_labels()

    with torch.no_grad():
        # --- whole-image segmentation --------------------------------------
        img_cv = cv2.imread(img_files[0])
        if img_cv is None:
            raise FileNotFoundError(f"could not read image '{img_files[0]}'")
        pred = process_single_large_image(model, img_cv)
        mask = gen_mask(pred, nclass, color_map)
        out_img = composite_image(img_cv, mask, alpha=0.2)

        out_stem = os.path.splitext(out_file)[0]
        save_image(mask, out_stem + '_mask.png')
        save_image(out_img, out_stem + '_com.png')
        print(f'saved image {out_stem}_mask.png')
        print(f'saved image {out_stem}_com.png')

        # --- patch-wise segmentation ---------------------------------------
        for img_file in img_files:
            # splitext copes with dots in directory names; ext keeps its dot.
            name, ext = os.path.splitext(img_file)
            img_cv = cv2.imread(img_file)
            if img_cv is None:
                raise FileNotFoundError(f"could not read image '{img_file}'")
            patches = decompose_image(img_cv, None, (800, 800), (300, 700))
            print(f'Decompose input image into {len(patches)} patches.')

            for i, patch in patches.items():
                # Add the batch axis and send to the model's device
                # (was hard-coded .cuda()).
                img = to_model_input(patch.image).unsqueeze(0).to(args.device)
                output = model(img).cpu().numpy()
                pred = np.argmax(output, axis=1)
                # The former `torch.zeros()` call here had no size argument
                # and raised TypeError; its result was never used.
                mask = gen_mask(pred[0], nclass, color_map)
                out_img = composite_image(patch.image, mask, alpha=0.2)

                seg_path = f'{name}_patch{i:02d}_seg{ext}'
                seg_img_path = f'{name}_patch{i:02d}_seg_img{ext}'
                save_image(mask, seg_path)
                save_image(out_img, seg_img_path)
                print(f'saved image {seg_path}')
from modeling.deeplab import DeepLab
from dataloaders.utils import decode_segmap

# Use the first GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Restore the 21-class ResNet DeepLab from the 'state_dict_G' checkpoint entry.
checkpoint = torch.load("./run/pascal/deeplab-resnet/model_best.pth")
model = DeepLab(num_classes=21,
                backbone='resnet',
                output_stride=16,
                sync_bn=True,
                freeze_bn=False)
model.load_state_dict(checkpoint['state_dict_G'])
model.eval()
model.to(device)


def transform(image):
    """Resize and centre-crop ``image`` to 513, convert to tensor, normalise."""
    pipeline = tr.Compose([
        tr.Resize(513),
        tr.CenterCrop(513),
        tr.ToTensor(),
        tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    ])
    return pipeline(image)


# Inference only: disable gradient tracking for everything that follows.
torch.set_grad_enabled(False)
image = Image.open('sample_image.jpg')
class Trainer:
    """Train and validate a DeepLab model configured by an instructions mapping."""

    def __init__(self, data_train, data_valid, image_base_dir, instructions):
        """Create data loaders, model, optimizer, loss, evaluator and scheduler.

        :param data_train: training data handed to ``DictArrayDataSet``
        :param data_valid: validation data handed to ``DictArrayDataSet``
        :param image_base_dir: image base directory handed to both datasets
        :param instructions: mapping of ``STR.*`` keys to settings; must
            contain ``MODEL_NAME``, ``NN_INPUT_SIZE``, ``BATCH_SIZE`` and
            (when the lr scheduler is used) ``EPOCHS``; other keys are optional
        """
        import warnings  # local import so the warning below is really emitted

        self.image_base_dir = image_base_dir
        self.data_valid = data_valid
        self.instructions = instructions

        # specify model save dir
        self.model_name = instructions[STR.MODEL_NAME]
        experiment_folder_path = os.path.join(paths.MODELS_FOLDER_PATH,
                                              self.model_name)
        if os.path.exists(experiment_folder_path):
            # The original built a Warning(...) object without emitting it.
            warnings.warn(
                "Experiment folder exists already. Files might be overwritten")
        os.makedirs(experiment_folder_path, exist_ok=True)

        # define saver and save instructions
        self.saver = Saver(folder_path=experiment_folder_path,
                           instructions=instructions)
        self.saver.save_instructions()

        # define Tensorboard Summary
        self.writer = SummaryWriter(log_dir=experiment_folder_path)

        nn_input_size = instructions[STR.NN_INPUT_SIZE]
        state_dict_file_path = instructions.get(STR.STATE_DICT_FILE_PATH, None)
        self.colour_mapping = mapping.get_colour_mapping()
        num_colour_classes = len(self.colour_mapping.keys())

        # define transformers for training; random cropping is enabled only
        # when both CROPS_PER_IMAGE and IMAGES_PER_BATCH are configured
        crops_per_image = instructions.get(STR.CROPS_PER_IMAGE, 10)
        apply_random_cropping = (STR.CROPS_PER_IMAGE in instructions.keys()) and \
                                (STR.IMAGES_PER_BATCH in instructions.keys())
        print("{}applying random cropping".format(
            "" if apply_random_cropping else "_NOT_ "))

        t = [Normalize()]
        if apply_random_cropping:
            t.append(
                RandomCrop(min_size=instructions.get(STR.CROP_SIZE_MIN, 400),
                           max_size=instructions.get(STR.CROP_SIZE_MAX, 1000),
                           crop_count=crops_per_image))
        t += [
            Resize(nn_input_size),
            Flip(p_vertical=0.2, p_horizontal=0.5),
            ToTensor()
        ]
        transformations_train = transforms.Compose(t)

        # define transformers for validation
        transformations_valid = transforms.Compose(
            [Normalize(), Resize(nn_input_size),
             ToTensor()])

        # set up data loaders
        dataset_train = DictArrayDataSet(image_base_dir=image_base_dir,
                                         data=data_train,
                                         num_classes=num_colour_classes,
                                         transformation=transformations_train)

        # define batch sizes: with random cropping a batch counts source
        # images (IMAGES_PER_BATCH), otherwise BATCH_SIZE is used directly
        self.batch_size = instructions[STR.BATCH_SIZE]
        if apply_random_cropping:
            train_batch_size = instructions[STR.IMAGES_PER_BATCH]
        else:
            train_batch_size = self.batch_size
        self.data_loader_train = DataLoader(dataset=dataset_train,
                                            batch_size=train_batch_size,
                                            shuffle=True,
                                            collate_fn=custom_collate)

        dataset_valid = DictArrayDataSet(image_base_dir=image_base_dir,
                                         data=data_valid,
                                         num_classes=num_colour_classes,
                                         transformation=transformations_valid)
        self.data_loader_valid = DataLoader(dataset=dataset_valid,
                                            batch_size=self.batch_size,
                                            shuffle=False,
                                            collate_fn=custom_collate)

        self.num_classes = dataset_train.num_classes()

        # define model
        print("Building model")
        self.model = DeepLab(num_classes=self.num_classes,
                             backbone=instructions.get(STR.BACKBONE, "resnet"),
                             output_stride=instructions.get(
                                 STR.DEEPLAB_OUTPUT_STRIDE, 16))

        # load weights
        if state_dict_file_path is not None:
            print("loading state_dict from:")
            print(state_dict_file_path)
            load_state_dict(self.model, state_dict_file_path)

        learning_rate = instructions.get(STR.LEARNING_RATE, 1e-5)
        # NOTE(review): both parameter groups receive the same learning rate
        # here, despite the "1x"/"10x" getter names — confirm this is intended.
        # The groups are collected before any DataParallel wrapping.
        train_params = [{
            'params': self.model.get_1x_lr_params(),
            'lr': learning_rate
        }, {
            'params': self.model.get_10x_lr_params(),
            'lr': learning_rate
        }]

        # choose gpu or cpu
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        if instructions.get(STR.MULTI_GPU, False):
            if torch.cuda.device_count() > 1:
                print("Using ", torch.cuda.device_count(), " GPUs!")
                self.model = nn.DataParallel(self.model)
        self.model.to(self.device)

        # Define Optimizer
        self.optimizer = torch.optim.SGD(train_params,
                                         momentum=0.9,
                                         weight_decay=5e-4,
                                         nesterov=False)

        # calculate class weights
        if instructions.get(STR.CLASS_STATS_FILE_PATH, None):
            class_weights = calculate_class_weights(
                instructions[STR.CLASS_STATS_FILE_PATH],
                self.colour_mapping,
                modifier=instructions.get(STR.LOSS_WEIGHT_MODIFIER, 1.01))
            class_weights = torch.from_numpy(class_weights.astype(np.float32))
        else:
            class_weights = None
        self.criterion = SegmentationLosses(
            weight=class_weights,
            cuda=self.device.type != "cpu").build_loss()

        # Define Evaluator
        self.evaluator = Evaluator(self.num_classes)

        # Define lr scheduler
        self.scheduler = None
        if instructions.get(STR.USE_LR_SCHEDULER, True):
            self.scheduler = LR_Scheduler(mode="cos",
                                          base_lr=learning_rate,
                                          num_epochs=instructions[STR.EPOCHS],
                                          iters_per_epoch=len(
                                              self.data_loader_train))

        # print information before training start
        print("-" * 60)
        print("instructions")
        pprint(instructions)
        model_parameters = sum(p.nelement() for p in self.model.parameters())
        print("Model parameters: {:.2E}".format(model_parameters))

        # Best validation mIoU seen so far (updated in validation()).
        self.best_prediction = 0.0

    def train(self, epoch):
        """Run one training epoch and log per-iteration and per-epoch loss.

        :param epoch: epoch index, used for the scheduler and as logging step
        """
        self.model.train()
        train_loss = 0.0

        # create a progress bar
        pbar = tqdm(self.data_loader_train)
        num_batches_train = len(self.data_loader_train)

        # go through each item in the training data
        for i, sample in enumerate(pbar):
            # set input and target
            nn_input = sample[STR.NN_INPUT].to(self.device)
            nn_target = sample[STR.NN_TARGET].to(self.device,
                                                 dtype=torch.float)

            if self.scheduler:
                self.scheduler(self.optimizer, i, epoch, self.best_prediction)

            # run model and calc loss
            output = self.model(nn_input)
            loss = self.criterion(output, nn_target)

            train_loss += loss.item()
            pbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_batches_train * epoch)

            # calculate gradient and update model weights; gradients are
            # cleared after the step so the next iteration starts from zero
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print("[Epoch: {}, num images/crops: {}]".format(
            epoch, num_batches_train * self.batch_size))
        print("Loss: {:.2f}".format(train_loss))

    def validation(self, epoch):
        """Evaluate on the validation loader, log metrics and save a checkpoint.

        :param epoch: epoch index, used as logging step and passed to the saver
        """
        self.model.eval()
        self.evaluator.reset()
        test_loss = 0.0

        pbar = tqdm(self.data_loader_valid, desc='\r')
        num_batches_val = len(self.data_loader_valid)

        for i, sample in enumerate(pbar):
            # set input and target
            nn_input = sample[STR.NN_INPUT].to(self.device)
            nn_target = sample[STR.NN_TARGET].to(self.device,
                                                 dtype=torch.float)

            with torch.no_grad():
                output = self.model(nn_input)
                loss = self.criterion(output, nn_target)

            test_loss += loss.item()
            pbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))

            # Class index with the highest score along axis 1 of the output.
            pred = np.argmax(output.cpu().numpy(), axis=1)
            nn_target = nn_target.cpu().numpy()

            # Add batch sample into evaluator
            self.evaluator.add_batch(nn_target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()

        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)

        print('Validation:')
        print("[Epoch: {}, num crops: {}]".format(
            epoch, num_batches_val * self.batch_size))
        print(
            "Acc:{:.2f}, Acc_class:{:.2f}, mIoU:{:.2f}, fwIoU: {:.2f}".format(
                Acc, Acc_class, mIoU, FWIoU))
        print("Loss: {:.2f}".format(test_loss))

        # mIoU decides whether this epoch is the best one so far; the flag is
        # handed to the saver together with the model.
        new_pred = mIoU
        is_best = new_pred > self.best_prediction
        if is_best:
            self.best_prediction = new_pred
        self.saver.save_checkpoint(self.model, is_best, epoch)
from modeling.deeplab import DeepLab import kornia from PIL import Image import torch import torchvision.transforms.functional as TF import numpy as np # this example only uses 1 image, so cpu is fine device = torch.device("cpu") # load pre-trained weights, set network to inference mode network = DeepLab(num_classes=18) network.load_state_dict( torch.load("segmentation-model/epoch-14", map_location="cpu")) network.eval() network.to(device) # load example image. the image is resized because DeepLab uses # a lot of dilated convolutions and doesn't work very well for # low resolution images. image = Image.open("nate.jpg") scaled_image = image.resize((418, 512), resample=Image.LANCZOS) image_tensor = TF.to_tensor(scaled_image) # send the input through the network. unsqueeze is used to # add a batch dimension, because torch always expects a batch # but in this case it's just one image # I then use Kornia to resize the mask back to 218x178 then # squeeze to remove the batch channel again (kornia also # always expects a batch dimension) with torch.no_grad():
default='/Users/yulian/Downloads/mixup_model_best.pth.tar', help='pretrained model') parser.add_argument('--color', type=str, default='purple', choices=['purple', 'green', 'blue', 'red'], help='Color your hair (default: purple)') args = parser.parse_args() device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") net = DeepLab(backbone=args.backbone, output_stride=16, num_classes=2, sync_bn=False).to(device) net.load_state_dict( torch.load(args.pretrained, map_location=device)['state_dict']) net.eval() cam = cv2.VideoCapture(0) if not cam.isOpened(): raise Exception("webcam is not detected") while (True): # ret : frame capture(boolean) # frame : Capture frame ret, image = cam.read() if (ret): image, mask = get_image_mask(image, net) # print(image.shape, mask.shape) add = color_image(image, mask, args.color) cv2.imshow('frame', add) if cv2.waitKey(1) & 0xFF == ord(chr(27)):
class DeeplabRos:
    """ROS node wrapper: segments incoming camera images with DeepLab.

    Subscribes to ``/cam2/pylon_camera_node/image_raw`` and publishes the
    colour-coded prediction on ``segmentation_image``.
    """

    def __init__(self):
        """Load the 4-class MobileNet DeepLab model and set up ROS I/O."""
        #GPU assignment
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')

        #Load checkpoint; map_location keeps the CPU fallback usable for
        #checkpoints that were saved from a GPU
        self.checkpoint = torch.load(
            "./src/deeplab_ros/data/model_best.pth.tar",
            map_location=self.device)

        #Load Model
        self.model = DeepLab(num_classes=4,
                             backbone='mobilenet',
                             output_stride=16,
                             sync_bn=True,
                             freeze_bn=False)
        self.model.load_state_dict(self.checkpoint['state_dict'])
        self.model = self.model.to(self.device)
        self.model.eval()

        # Preprocessing is identical for every frame, so build it once.
        self.tfms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        #ROS init
        self.bridge = CvBridge()
        self.image_sub = rospy.Subscriber("/cam2/pylon_camera_node/image_raw",
                                          ImageMsg,
                                          self.callback,
                                          queue_size=1,
                                          buff_size=2**24)
        self.image_pub = rospy.Publisher("segmentation_image",
                                         ImageMsg,
                                         queue_size=1)

    def callback(self, data):
        """Segment one image message and publish the coloured label image.

        Args:
            data: Incoming image message, converted to a ``bgr8`` array.
        """
        cv_image = self.bridge.imgmsg_to_cv2(data, "bgr8")

        start_time = time.time()
        self.model.eval()
        # Scoped no_grad instead of disabling gradients process-wide.
        with torch.no_grad():
            inputs = self.tfms(cv_image).to(self.device)
            output = self.model(inputs.unsqueeze(0)).squeeze().cpu().numpy()
        pred = np.argmax(output, axis=0)
        pred_img = self.label_to_color_image(pred)
        msg = self.bridge.cv2_to_imgmsg(pred_img, "bgr8")

        inference_time = time.time() - start_time
        print("inference time: ", inference_time)

        self.image_pub.publish(msg)

    def label_to_color_image(self, pred, class_num=4):
        """Map a 2-D array of class indices to an H x W x 3 uint8 colour image.

        Args:
            pred: 2-D array of class indices.
            class_num: Number of leading classes to colour; indices outside
                ``range(class_num)`` stay black.

        Returns:
            ``uint8`` array of shape ``pred.shape + (3,)``.
        """
        # Channel order is BGR: Unlabeled, Building, Lane-marking, Fence
        label_colors = np.array([(0, 0, 0), (0, 0, 128), (0, 128, 0),
                                 (128, 0, 0)], dtype=np.uint8)

        color_image = np.zeros(pred.shape + (3,), dtype=np.uint8)
        for class_id in range(class_num):
            color_image[pred == class_id] = label_colors[class_id]
        return color_image