Code example #1
File: demo.py Project: HeadReaper-hc/LEDNet_ros
class semantic:
    def __init__(self, image_topic, device, pretrained):
        self.image_pub = rospy.Publisher("semantic_img", Image, queue_size=10)

        self.bridge = CvBridge()
        self.image_sub = rospy.Subscriber(image_topic,
                                          Image,
                                          self.callback,
                                          queue_size=1)

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        self.device = device
        self.model = LEDNet(19).to(device)
        self.model.load_state_dict(torch.load(pretrained))
        self.model.eval()

    def callback(self, data):
        try:
            cv_image = self.bridge.imgmsg_to_cv2(data, "bgr8")
        except CvBridgeError as e:
            print(e)
            return  # without an early return, cv_image would be undefined below

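        # cv2PIL appears to be a project helper that converts the OpenCV BGR image to a PIL RGB image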
        pilImg = cv2PIL(cv_image, cv2.COLOR_BGR2RGB)
        img = self.transform(pilImg).unsqueeze(0).to(self.device)
        with torch.no_grad():
            output = self.model(img)
        predict = torch.argmax(output, 1).squeeze(0).cpu().numpy()
        mask = ptutil.get_color_pallete(predict, 'citys')
        mask.save(os.path.join(cur_path, 'png/output.png'))
        mmask = cv2.imread(os.path.join(cur_path, 'png/output.png'))
        # plt.imshow(mmask)
        # plt.show()
        # cv2.imshow("OpenCV",mmask)
        # cv2.waitKey(1)

        try:
            self.image_pub.publish(self.bridge.cv2_to_imgmsg(mmask, "bgr8"))
        except CvBridgeError as e:
            print(e)
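
A minimal sketch of how this class might be wired into a ROS node; the node name, image topic, and checkpoint path below are placeholders, not taken from the repository:

if __name__ == '__main__':
    rospy.init_node('semantic_node')  # hypothetical node name
    node = semantic('/camera/image_raw', torch.device('cuda'), 'lednet.pth')
    try:
        rospy.spin()  # hand control to ROS until shutdown
    except KeyboardInterrupt:
        pass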
Code example #2
def get_model(name):
    if name == 'hlnet':
        model = HLNet(input_shape=(IMG_SIZE, IMG_SIZE, 3), cls_num=CLS_NUM)
    elif name == 'fastscnn':
        model = Fast_SCNN(num_classes=CLS_NUM,
                          input_shape=(IMG_SIZE, IMG_SIZE, 3)).model()
    elif name == 'lednet':
        model = LEDNet(groups=2,
                       classes=CLS_NUM,
                       input_shape=(IMG_SIZE, IMG_SIZE, 3)).model()
    elif name == 'dfanet':
        model = DFANet(input_shape=(IMG_SIZE, IMG_SIZE, 3),
                       cls_num=CLS_NUM,
                       size_factor=2)
    elif name == 'enet':
        model = ENet(input_shape=(IMG_SIZE, IMG_SIZE, 3), cls_num=CLS_NUM)
    elif name == 'mobilenet':
        model = MobileNet(input_shape=(IMG_SIZE, IMG_SIZE, 3), cls_num=CLS_NUM)
    else:
        raise ValueError("No corresponding model...")

    return model
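
A sketch of calling get_model; IMG_SIZE and CLS_NUM are module-level globals the function expects, and the values here are assumptions:

IMG_SIZE, CLS_NUM = 256, 2      # hypothetical settings for the expected globals
model = get_model('lednet')     # any of the names handled above
model.summary()                 # the builders appear to return Keras models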
Code example #3
File: train.py Project: lvcat/LEDNet
class Trainer(object):
    def __init__(self, args):
        self.device = torch.device(args.device)
        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])
        # dataset and dataloader
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        trainset = get_segmentation_dataset(args.dataset,
                                            split=args.train_split,
                                            mode='train',
                                            **data_kwargs)
        args.per_iter = len(trainset) // (args.num_gpus * args.batch_size)
        args.max_iter = args.epochs * args.per_iter
        if args.distributed:
            sampler = data.DistributedSampler(trainset)
        else:
            sampler = data.RandomSampler(trainset)
        train_sampler = data.sampler.BatchSampler(sampler, args.batch_size,
                                                  True)
        train_sampler = IterationBasedBatchSampler(
            train_sampler, num_iterations=args.max_iter)
        self.train_loader = data.DataLoader(trainset,
                                            batch_sampler=train_sampler,
                                            pin_memory=True,
                                            num_workers=args.workers)
        if not args.skip_eval or 0 < args.eval_epochs < args.epochs:
            valset = get_segmentation_dataset(args.dataset,
                                              split='val',
                                              mode='val',
                                              **data_kwargs)
            val_sampler = make_data_sampler(valset, False, args.distributed)
            val_batch_sampler = data.sampler.BatchSampler(
                val_sampler, args.test_batch_size, False)
            self.valid_loader = data.DataLoader(
                valset,
                batch_sampler=val_batch_sampler,
                num_workers=args.workers,
                pin_memory=True)

        # create network
        self.net = LEDNet(trainset.NUM_CLASS)

        if args.distributed:
            self.net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.net)
        self.net.to(self.device)
        # resume checkpoint if needed
        if args.resume is not None:
            if os.path.isfile(args.resume):
                self.net.load_state_dict(torch.load(args.resume))
            else:
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))

        # create criterion
        if args.ohem:
            min_kept = args.batch_size * args.crop_size**2 // 16
            self.criterion = OHEMSoftmaxCrossEntropyLoss(thresh=0.7,
                                                         min_kept=min_kept,
                                                         use_weight=False)
        else:
            self.criterion = MixSoftmaxCrossEntropyLoss()

        # optimizer and lr scheduling
        self.optimizer = optim.SGD(self.net.parameters(),
                                   lr=args.lr,
                                   momentum=args.momentum,
                                   weight_decay=args.weight_decay)
        self.scheduler = WarmupPolyLR(self.optimizer,
                                      T_max=args.max_iter,
                                      warmup_factor=args.warmup_factor,
                                      warmup_iters=args.warmup_iters,
                                      power=0.9)

        if args.distributed:
            self.net = torch.nn.parallel.DistributedDataParallel(
                self.net,
                device_ids=[args.local_rank],
                output_device=args.local_rank)

        # evaluation metrics
        self.metric = SegmentationMetric(trainset.num_class)
        self.args = args

    def training(self):
        self.net.train()
        save_to_disk = ptutil.get_rank() == 0
        start_training_time = time.time()
        trained_time = 0
        tic = time.time()
        end = time.time()
        iteration, max_iter = 0, self.args.max_iter
        save_iter = self.args.per_iter * self.args.save_epoch
        eval_iter = self.args.per_iter * self.args.eval_epochs
        # save_iter, eval_iter = 10, 10

        logger.info(
            "Start training, total epochs {:3d} = total iteration: {:6d}".
            format(self.args.epochs, max_iter))

        for i, (image, target) in enumerate(self.train_loader):
            iteration += 1
            self.scheduler.step()
            self.optimizer.zero_grad()
            image, target = image.to(self.device), target.to(self.device)
            outputs = self.net(image)
            loss_dict = self.criterion(outputs, target)
            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = ptutil.reduce_loss_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())

            loss = sum(loss for loss in loss_dict.values())
            loss.backward()
            self.optimizer.step()
            trained_time += time.time() - end
            end = time.time()
            if iteration % self.args.log_step == 0:
                eta_seconds = int(
                    (trained_time / iteration) * (max_iter - iteration))
                log_str = [
                    "Iteration {:06d} , Lr: {:.5f}, Cost: {:.2f}s, Eta: {}".
                    format(iteration, self.optimizer.param_groups[0]['lr'],
                           time.time() - tic,
                           str(datetime.timedelta(seconds=eta_seconds))),
                    "total_loss: {:.3f}".format(losses_reduced.item())
                ]
                log_str = ', '.join(log_str)
                logger.info(log_str)
                tic = time.time()
            if save_to_disk and iteration % save_iter == 0:
                model_path = os.path.join(
                    self.args.save_dir,
                    "{}_iter_{:06d}.pth".format('LEDNet', iteration))
                self.save_model(model_path)
            # Evaluate periodically during training to track how pixAcc / mIoU evolve
            if self.args.eval_epochs > 0 and iteration % eval_iter == 0 and iteration != max_iter:
                metrics = self.validate()
                ptutil.synchronize()
                pixAcc, mIoU = ptutil.accumulate_metric(metrics)
                if pixAcc is not None:
                    logger.info('pixAcc: {:.4f}, mIoU: {:.4f}'.format(
                        pixAcc, mIoU))
                self.net.train()
        if save_to_disk:
            model_path = os.path.join(
                self.args.save_dir,
                "{}_iter_{:06d}.pth".format('LEDNet', max_iter))
            self.save_model(model_path)
        # compute training time
        total_training_time = int(time.time() - start_training_time)
        total_time_str = str(datetime.timedelta(seconds=total_training_time))
        logger.info("Total training time: {} ({:.4f} s / it)".format(
            total_time_str, total_training_time / max_iter))
        # eval after training
        if not self.args.skip_eval:
            metrics = self.validate()
            ptutil.synchronize()
            pixAcc, mIoU = ptutil.accumulate_metric(metrics)
            if pixAcc is not None:
                logger.info(
                    'After training, pixAcc: {:.4f}, mIoU: {:.4f}'.format(
                        pixAcc, mIoU))

    def validate(self):
        # total_inter, total_union, total_correct, total_label = 0, 0, 0, 0
        self.metric.reset()
        torch.cuda.empty_cache()
        if isinstance(self.net, torch.nn.parallel.DistributedDataParallel):
            model = self.net.module
        else:
            model = self.net
        model.eval()
        tbar = tqdm(self.valid_loader)
        for i, (image, target) in enumerate(tbar):
            # if i == 10: break
            image, target = image.to(self.device), target.to(self.device)
            with torch.no_grad():
                outputs = model(image)
            self.metric.update(target, outputs)
        return self.metric

    def save_model(self, model_path):
        if isinstance(self.net, torch.nn.parallel.DistributedDataParallel):
            model = self.net.module
        else:
            model = self.net
        torch.save(model.state_dict(), model_path)
        logger.info("Saved checkpoint to {}".format(model_path))
Code example #4
File: eval.py Project: HeadReaper-hc/LEDNet_ros
    num_gpus = int(
        os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    distributed = num_gpus > 1
    if args.cuda and torch.cuda.is_available():
        torch.backends.cudnn.benchmark = args.mode != 'testval'
        device = torch.device('cuda')
    else:
        distributed = False
        device = torch.device('cpu')  # otherwise `device` would be undefined when used below

    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method=args.init_method)

    # Load Model
    model = LEDNet(19)
    model.load_state_dict(torch.load(args.pretrained))
    model.keep_shape = (args.mode == 'testval')
    model.to(device)

    # testing data
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
    ])

    data_kwargs = {
        'base_size': args.base_size,
        'crop_size': args.crop_size,
        'transform': input_transform
    }
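
In the full script these kwargs would presumably feed the dataset factory, mirroring the Trainer code above; the split and mode values here are assumptions:

    valset = get_segmentation_dataset(args.dataset,
                                      split='val',
                                      mode=args.mode,
                                      **data_kwargs)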
Code example #5
    parser.add_argument('--cuda',
                        type=ptutil.str2bool,
                        default='true',
                        help='demo with GPU')

    opt = parser.parse_args()
    return opt


if __name__ == '__main__':
    args = parse_args()
    device = torch.device('cpu')
    if args.cuda and torch.cuda.is_available():
        device = torch.device('cuda')
    # Load Model
    model = LEDNet(19).to(device)
    model.load_state_dict(torch.load(args.pretrained))
    model.eval()

    # Load Images
    img = Image.open(args.input_pic).convert('RGB')  # force 3 channels to match the 3-channel normalization

    # Transform
    transform_fn = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    img = transform_fn(img).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(img)
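
The raw logits can then be converted to a color mask, mirroring the callback in code example #1; the output filename is a placeholder:

    predict = torch.argmax(output, 1).squeeze(0).cpu().numpy()
    mask = ptutil.get_color_pallete(predict, 'citys')
    mask.save('output.png')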
Code example #6
    parser = argparse.ArgumentParser()
    parser.add_argument("--image_size",
                        help="size of image", type=int, default=256)
    parser.add_argument("--model_path",
                        help="the path of model", type=str,
                        default='./weights/celebhair/exper/fastscnn/model.h5')
    args = parser.parse_args()

    IMG_SIZE = args.image_size
    MODEL_PATH = args.model_path

    if MODEL_PATH.split('/')[-2] == 'lednet':
        from model.lednet import LEDNet

        model = LEDNet(2, 3, (IMG_SIZE, IMG_SIZE, 3)).model()  # honor --image_size instead of hardcoding 256
        model.load_weights(MODEL_PATH)

    else:
        model = load_model(MODEL_PATH, custom_objects={'mean_accuracy': mean_accuracy,
                                                       'mean_iou': mean_iou,
                                                       'frequency_weighted_iou': frequency_weighted_iou,
                                                       'pixel_accuracy': pixel_accuracy,
                                                       'categorical_crossentropy_plus_dice_loss': cce_dice_loss,
                                                       'resize_image': resize_image})

    data_name = MODEL_PATH.split('/')[2]

    for img_path in glob.glob(os.path.join("./demo", data_name, "*.jpg")):
        img_basename = os.path.basename(img_path)
        name = os.path.splitext(img_basename)[0]