import os
import cv2
import rospy
import torch
from torchvision import transforms
from cv_bridge import CvBridge, CvBridgeError
from sensor_msgs.msg import Image
# repo-local helpers (LEDNet, ptutil, cv2PIL, cur_path) are assumed to be
# defined elsewhere in this package


class semantic:
    def __init__(self, image_topic, device, pretrained):
        self.image_pub = rospy.Publisher("semantic_img", Image, queue_size=10)
        self.bridge = CvBridge()
        self.image_sub = rospy.Subscriber(image_topic, Image, self.callback, queue_size=1)
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        self.device = device
        self.model = LEDNet(19).to(device)
        self.model.load_state_dict(torch.load(pretrained))
        self.model.eval()

    def callback(self, data):
        try:
            cv_image = self.bridge.imgmsg_to_cv2(data, "bgr8")
        except CvBridgeError as e:
            print(e)
            return  # skip this frame; cv_image is undefined on failure

        # convert the BGR OpenCV image to an RGB PIL image, then run inference
        pilImg = cv2PIL(cv_image, cv2.COLOR_BGR2RGB)
        img = self.transform(pilImg).unsqueeze(0).to(self.device)
        with torch.no_grad():
            output = self.model(img)
        predict = torch.argmax(output, 1).squeeze(0).cpu().numpy()

        # colorize the class map with the Cityscapes palette and republish it
        mask = ptutil.get_color_pallete(predict, 'citys')
        mask.save(os.path.join(cur_path, 'png/output.png'))
        mmask = cv2.imread(os.path.join(cur_path, 'png/output.png'))
        try:
            self.image_pub.publish(self.bridge.cv2_to_imgmsg(mmask, "bgr8"))
        except CvBridgeError as e:
            print(e)
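
# A minimal launch sketch for the node above; the node name, image topic, and
# weight path are illustrative assumptions, not taken from the original
# launch files.
if __name__ == '__main__':
    rospy.init_node('semantic_segmentation', anonymous=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    node = semantic('/camera/image_raw', device, 'weights/lednet.pth')
    try:
        rospy.spin()
    except KeyboardInterrupt:
        print("Shutting down")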
def get_model(name):
    """Return the segmentation model matching `name`; IMG_SIZE and CLS_NUM
    are module-level constants set from the command line."""
    if name == 'hlnet':
        model = HLNet(input_shape=(IMG_SIZE, IMG_SIZE, 3), cls_num=CLS_NUM)
    elif name == 'fastscnn':
        model = Fast_SCNN(num_classes=CLS_NUM, input_shape=(IMG_SIZE, IMG_SIZE, 3)).model()
    elif name == 'lednet':
        model = LEDNet(groups=2, classes=CLS_NUM, input_shape=(IMG_SIZE, IMG_SIZE, 3)).model()
    elif name == 'dfanet':
        model = DFANet(input_shape=(IMG_SIZE, IMG_SIZE, 3), cls_num=CLS_NUM, size_factor=2)
    elif name == 'enet':
        model = ENet(input_shape=(IMG_SIZE, IMG_SIZE, 3), cls_num=CLS_NUM)
    elif name == 'mobilenet':
        model = MobileNet(input_shape=(IMG_SIZE, IMG_SIZE, 3), cls_num=CLS_NUM)
    else:
        raise NameError("Unknown model name: '{}'".format(name))
    return model
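
# Illustrative usage of the factory above, assuming IMG_SIZE and CLS_NUM are
# set at module level as in the surrounding script.
model = get_model('fastscnn')
model.summary()  # inspect the resulting Keras graph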
class Trainer(object):
    def __init__(self, args):
        self.device = torch.device(args.device)

        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])

        # dataset and dataloader
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        trainset = get_segmentation_dataset(args.dataset, split=args.train_split,
                                            mode='train', **data_kwargs)
        args.per_iter = len(trainset) // (args.num_gpus * args.batch_size)
        args.max_iter = args.epochs * args.per_iter
        if args.distributed:
            sampler = data.DistributedSampler(trainset)
        else:
            sampler = data.RandomSampler(trainset)
        train_sampler = data.sampler.BatchSampler(sampler, args.batch_size, True)
        train_sampler = IterationBasedBatchSampler(train_sampler,
                                                   num_iterations=args.max_iter)
        self.train_loader = data.DataLoader(trainset, batch_sampler=train_sampler,
                                            pin_memory=True, num_workers=args.workers)
        if not args.skip_eval or 0 < args.eval_epochs < args.epochs:
            valset = get_segmentation_dataset(args.dataset, split='val',
                                              mode='val', **data_kwargs)
            val_sampler = make_data_sampler(valset, False, args.distributed)
            val_batch_sampler = data.sampler.BatchSampler(val_sampler,
                                                          args.test_batch_size, False)
            self.valid_loader = data.DataLoader(valset, batch_sampler=val_batch_sampler,
                                                num_workers=args.workers, pin_memory=True)

        # create network
        self.net = LEDNet(trainset.NUM_CLASS)
        if args.distributed:
            self.net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.net)
        self.net.to(self.device)

        # resume checkpoint if needed
        if args.resume is not None:
            if os.path.isfile(args.resume):
                self.net.load_state_dict(torch.load(args.resume))
            else:
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))

        # create criterion
        if args.ohem:
            min_kept = args.batch_size * args.crop_size ** 2 // 16
            self.criterion = OHEMSoftmaxCrossEntropyLoss(thresh=0.7, min_kept=min_kept,
                                                         use_weight=False)
        else:
            self.criterion = MixSoftmaxCrossEntropyLoss()

        # optimizer and lr scheduling
        self.optimizer = optim.SGD(self.net.parameters(), lr=args.lr,
                                   momentum=args.momentum,
                                   weight_decay=args.weight_decay)
        self.scheduler = WarmupPolyLR(self.optimizer, T_max=args.max_iter,
                                      warmup_factor=args.warmup_factor,
                                      warmup_iters=args.warmup_iters, power=0.9)

        if args.distributed:
            self.net = torch.nn.parallel.DistributedDataParallel(
                self.net, device_ids=[args.local_rank], output_device=args.local_rank)

        # evaluation metrics
        self.metric = SegmentationMetric(trainset.num_class)

        self.args = args

    def training(self):
        self.net.train()
        save_to_disk = ptutil.get_rank() == 0
        start_training_time = time.time()
        trained_time = 0
        tic = time.time()
        end = time.time()
        iteration, max_iter = 0, self.args.max_iter
        save_iter = self.args.per_iter * self.args.save_epoch
        eval_iter = self.args.per_iter * self.args.eval_epochs
        logger.info("Start training, total epochs {:3d} = total iteration: {:6d}".format(
            self.args.epochs, max_iter))

        for i, (image, target) in enumerate(self.train_loader):
            iteration += 1
            self.scheduler.step()
            self.optimizer.zero_grad()
            image, target = image.to(self.device), target.to(self.device)
            outputs = self.net(image)
            loss_dict = self.criterion(outputs, target)

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = ptutil.reduce_loss_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())

            loss = sum(loss for loss in loss_dict.values())
            loss.backward()
            self.optimizer.step()

            trained_time += time.time() - end
            end = time.time()
            if iteration % self.args.log_step == 0:
                eta_seconds = int((trained_time / iteration) * (max_iter - iteration))
                log_str = [
                    "Iteration {:06d}, Lr: {:.5f}, Cost: {:.2f}s, Eta: {}".format(
                        iteration, self.optimizer.param_groups[0]['lr'],
                        time.time() - tic,
                        str(datetime.timedelta(seconds=eta_seconds))),
                    "total_loss: {:.3f}".format(losses_reduced.item())
                ]
                logger.info(', '.join(log_str))
                tic = time.time()

            if save_to_disk and iteration % save_iter == 0:
                model_path = os.path.join(
                    self.args.save_dir,
                    "{}_iter_{:06d}.pth".format('LEDNet', iteration))
                self.save_model(model_path)

            # evaluate periodically during training to track how pixAcc and
            # mIoU evolve
            if self.args.eval_epochs > 0 and iteration % eval_iter == 0 \
                    and iteration != max_iter:
                metrics = self.validate()
                ptutil.synchronize()
                pixAcc, mIoU = ptutil.accumulate_metric(metrics)
                if pixAcc is not None:
                    logger.info('pixAcc: {:.4f}, mIoU: {:.4f}'.format(pixAcc, mIoU))
                self.net.train()

        if save_to_disk:
            model_path = os.path.join(
                self.args.save_dir,
                "{}_iter_{:06d}.pth".format('LEDNet', max_iter))
            self.save_model(model_path)

        # compute training time
        total_training_time = int(time.time() - start_training_time)
        total_time_str = str(datetime.timedelta(seconds=total_training_time))
        logger.info("Total training time: {} ({:.4f} s / it)".format(
            total_time_str, total_training_time / max_iter))

        # eval after training
        if not self.args.skip_eval:
            metrics = self.validate()
            ptutil.synchronize()
            pixAcc, mIoU = ptutil.accumulate_metric(metrics)
            if pixAcc is not None:
                logger.info('After training, pixAcc: {:.4f}, mIoU: {:.4f}'.format(
                    pixAcc, mIoU))

    def validate(self):
        self.metric.reset()
        torch.cuda.empty_cache()
        # unwrap DDP so evaluation runs on the bare module
        if isinstance(self.net, torch.nn.parallel.DistributedDataParallel):
            model = self.net.module
        else:
            model = self.net
        model.eval()
        tbar = tqdm(self.valid_loader)
        for i, (image, target) in enumerate(tbar):
            image, target = image.to(self.device), target.to(self.device)
            with torch.no_grad():
                outputs = model(image)
            self.metric.update(target, outputs)
        return self.metric

    def save_model(self, model_path):
        if isinstance(self.net, torch.nn.parallel.DistributedDataParallel):
            model = self.net.module
        else:
            model = self.net
        torch.save(model.state_dict(), model_path)
        logger.info("Saved checkpoint to {}".format(model_path))
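
# A minimal sketch of a script entry point for the Trainer; `parse_args`
# mirrors the argument parsers in the sibling scripts but is an assumption
# here.
if __name__ == '__main__':
    args = parse_args()
    trainer = Trainer(args)
    trainer.training()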
num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
distributed = num_gpus > 1
if args.cuda and torch.cuda.is_available():
    # 'testval' keeps original image sizes, where cudnn benchmarking does not help
    torch.backends.cudnn.benchmark = args.mode != 'testval'
    device = torch.device('cuda')
else:
    distributed = False
    device = torch.device('cpu')  # define device so model.to(device) always works
if distributed:
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend="nccl", init_method=args.init_method)

# load model
model = LEDNet(19)
model.load_state_dict(torch.load(args.pretrained, map_location=device))
model.keep_shape = args.mode == 'testval'
model.to(device)

# testing data
input_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
])
data_kwargs = {
    'base_size': args.base_size,
    'crop_size': args.crop_size,
    'transform': input_transform
}
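
# A hedged sketch of how the evaluation loader would be assembled from
# data_kwargs, reusing the same helpers as the trainer above; the split/mode
# values here are assumptions based on args.mode.
valset = get_segmentation_dataset(args.dataset, split='val', mode=args.mode, **data_kwargs)
val_sampler = make_data_sampler(valset, False, distributed)
val_batch_sampler = data.sampler.BatchSampler(val_sampler, args.test_batch_size, False)
val_loader = data.DataLoader(valset, batch_sampler=val_batch_sampler,
                             num_workers=args.workers, pin_memory=True)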
    parser.add_argument('--cuda', type=ptutil.str2bool, default='true',
                        help='demo with GPU')
    opt = parser.parse_args()
    return opt


if __name__ == '__main__':
    args = parse_args()
    device = torch.device('cpu')
    if args.cuda:
        device = torch.device('cuda')

    # load model; map_location lets the CPU demo load GPU-saved weights
    model = LEDNet(19).to(device)
    model.load_state_dict(torch.load(args.pretrained, map_location=device))
    model.eval()

    # load image
    img = Image.open(args.input_pic)

    # transform
    transform_fn = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    img = transform_fn(img).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(img)
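
    # A hedged sketch of turning `output` into a color mask, mirroring the
    # post-processing in the ROS node above; the output filename is an
    # illustrative assumption.
    predict = torch.argmax(output, 1).squeeze(0).cpu().numpy()
    mask = ptutil.get_color_pallete(predict, 'citys')
    mask.save('demo_output.png')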
parser = argparse.ArgumentParser()
parser.add_argument("--image_size", help="size of image", type=int, default=256)
parser.add_argument("--model_path", help="the path of model", type=str,
                    default='./weights/celebhair/exper/fastscnn/model.h5')
args = parser.parse_args()

IMG_SIZE = args.image_size
MODEL_PATH = args.model_path

if MODEL_PATH.split('/')[-2] == 'lednet':
    # LEDNet is rebuilt and loads only its weights; the other models are
    # deserialized whole, with their custom metrics and losses registered
    from model.lednet import LEDNet
    model = LEDNet(2, 3, (256, 256, 3)).model()
    model.load_weights(MODEL_PATH)
else:
    model = load_model(MODEL_PATH,
                       custom_objects={
                           'mean_accuracy': mean_accuracy,
                           'mean_iou': mean_iou,
                           'frequency_weighted_iou': frequency_weighted_iou,
                           'pixel_accuracy': pixel_accuracy,
                           'categorical_crossentropy_plus_dice_loss': cce_dice_loss,
                           'resize_image': resize_image
                       })

data_name = MODEL_PATH.split('/')[2]
for img_path in glob.glob(os.path.join("./demo", data_name, "*.jpg")):
    img_basename = os.path.basename(img_path)
    name = os.path.splitext(img_basename)[0]
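    # A hedged sketch of the rest of this loop body: run the image through the
    # Keras model and save a predicted mask. The preprocessing (RGB conversion,
    # [0, 1] scaling) and the output filename are assumptions, and `cv2` and
    # `numpy as np` imports are assumed.
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (IMG_SIZE, IMG_SIZE)).astype('float32') / 255.0
    pred = model.predict(np.expand_dims(img, axis=0))   # (1, H, W, CLS_NUM)
    mask = np.argmax(pred[0], axis=-1).astype('uint8')
    # binary (2-class) masks assumed here, so scaling to 0/255 stays in range
    cv2.imwrite(os.path.join("./demo", data_name, name + "_mask.png"), mask * 255)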