def main(checkpoint_filename, input_image, output_image):
    """Segment one image with a 3-class ResNet DeepLab and save the mask.

    Args:
        checkpoint_filename: checkpoint file whose 'state_dict' entry holds
            the model weights (possibly saved from nn.DataParallel).
        input_image: path of the image to segment; converted to RGB.
        output_image: destination path for the mask returned by ``predict``.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Define network
    model = DeepLab(num_classes=3,
                    backbone='resnet',
                    output_stride=16,
                    sync_bn=False,
                    freeze_bn=False)

    raw_state = torch.load(checkpoint_filename, map_location=device)['state_dict']

    # A checkpoint written through nn.DataParallel prefixes every key with
    # 'module.'; drop that prefix so the keys match the bare model.
    prefix = 'module.'
    cleaned_state = {}
    for name, tensor in raw_state.items():
        if name.startswith(prefix):
            name = name[len(prefix):]
        cleaned_state[name] = tensor

    model.load_state_dict(cleaned_state)
    model.eval()

    rgb_image = Image.open(input_image).convert('RGB')
    predict(model, rgb_image).save(output_image)
예제 #2
0
def load_model(model_path, num_classes=14, backbone='resnet', output_stride=16):
    """Build a DeepLab and copy in every compatible tensor from ``model_path``.

    Entries whose key is unknown to the freshly built model, or whose shape
    differs from the model's, are skipped, so those parameters keep their
    initial values. The model is placed on ``cuda:0`` when available, wrapped
    in ``nn.DataParallel`` when more than one GPU is visible, and switched to
    eval mode.

    Returns:
        The (possibly DataParallel-wrapped) model in eval mode.
    """
    print(f"Loading model from {model_path}")
    model = DeepLab(num_classes=num_classes,
                    backbone=backbone,
                    output_stride=output_stride)

    # Deserialize every storage onto the CPU; device placement happens below.
    checkpoint_tensors = torch.load(model_path,
                                    map_location=lambda storage, loc: storage)
    merged_state = model.state_dict()
    for key, tensor in checkpoint_tensors.items():
        # Accept only keys the model knows about and whose shapes agree.
        if key in merged_state and merged_state[key].shape == tensor.shape:
            merged_state[key] = tensor
    model.load_state_dict(merged_state)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    gpu_count = torch.cuda.device_count()
    if gpu_count > 1:
        print("Load model in  ", gpu_count, " GPUs!")
        model = nn.DataParallel(model)
    model.to(device)
    model.eval()

    return model
예제 #3
0
class RAN():
    """Two-class MobileNet DeepLab placed on one GPU for single-image inference."""

    def __init__(self, weight, gpu_ids):
        """Build the network on GPU ``gpu_ids`` and load the checkpoint ``weight``.

        Args:
            weight: path to a checkpoint file holding a 'state_dict' entry.
            gpu_ids: passed to ``torch.cuda.set_device``, so presumably a
                single device index despite the plural name - TODO confirm.

        Raises:
            ValueError: if ``weight`` is None.
            RuntimeError: if ``weight`` is not an existing file.
        """
        # Validate before touching the GPU. An explicit raise, unlike
        # ``assert``, is not stripped when Python runs with -O.
        if weight is None:
            raise ValueError("=> a checkpoint path must be given")
        if not os.path.isfile(weight):
            raise RuntimeError("=> no checkpoint found at '{}'".format(weight))

        self.model = DeepLab(num_classes=2,
                             backbone='mobilenet',
                             output_stride=16)

        torch.cuda.set_device(gpu_ids)
        self.model = self.model.cuda()

        # Deserialize on the CPU so the device recorded in the checkpoint does
        # not matter; load_state_dict copies the tensors into the parameters
        # that already live on the selected GPU.
        checkpoint = torch.load(weight, map_location='cpu')
        self.model.load_state_dict(checkpoint['state_dict'])
        self.model.eval()

        # Per-channel statistics used by ``inference`` for normalization.
        self.mean = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)

    def inference(self, img):
        """Return the raw network output for one H x W x 3 image.

        The image is resized to 480x480, divided by 255, normalized with
        ``self.mean``/``self.std`` and sent to the current CUDA device as a
        batch of one. No channel reordering is done here, so ``img`` is
        assumed to already be in the order the model was trained on (cv2
        decodes BGR) - TODO confirm.
        """
        img = cv2.resize(img, (480, 480))
        img = img.astype(np.float32)
        img /= 255.0
        img -= self.mean
        img /= self.std
        # H x W x C -> 1 x C x H x W
        img = img.transpose((2, 0, 1))[np.newaxis, :, :, :]
        tensor = torch.from_numpy(img).float().cuda()

        with torch.no_grad():
            output = self.model(tensor)
        return output
예제 #4
0
def inference_A_sample_image(img_path, model_path, num_classes, backbone,
                             output_stride, sync_bn, freeze_bn):
    """Segment one image with a trained DeepLab and save the label map.

    The per-pixel argmax class ids are written as an 8-bit image to
    'output.jpg' in the current working directory.

    Args:
        img_path: path of the image to segment (read with OpenCV).
        model_path: checkpoint file containing a 'state_dict' entry.
        num_classes, backbone, output_stride, sync_bn, freeze_bn: forwarded
            to the ``DeepLab`` constructor; must match the checkpoint.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
    """
    # cv2.imread returns None instead of raising on a missing/unreadable file.
    image = cv2.imread(img_path)
    if image is None:
        raise FileNotFoundError("cannot read image '{}'".format(img_path))

    # OpenCV decodes to BGR, while the mean/std below (from pascal.py) are the
    # ImageNet statistics in RGB order, so reorder channels first.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    image /= 255
    image -= mean
    image /= std

    # numpy image H x W x C -> torch batch 1 x C x H x W
    image = np.ascontiguousarray(image.transpose((2, 0, 1))[np.newaxis])
    image = torch.from_numpy(image)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NOTE(review): pretrained=True presumably fetches backbone weights that
    # the checkpoint below overwrites anyway; False may suffice - confirm.
    model = DeepLab(num_classes=num_classes,
                    backbone=backbone,
                    output_stride=output_stride,
                    sync_bn=sync_bn,
                    freeze_bn=freeze_bn,
                    pretrained=True)

    # Load the trained weights the caller asked for.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)

    # Set dropout and batch normalization layers to evaluation mode before
    # running inference; failing to do so yields inconsistent results.
    model.eval()

    with torch.no_grad():
        output = model(image.to(device))

    pred = np.argmax(output.cpu().numpy(), axis=1)[0]

    # argmax yields int64, which cv2.imwrite cannot encode; cast to uint8
    # (valid while num_classes <= 256).
    cv2.imwrite('output.jpg', pred.astype(np.uint8))
예제 #5
0
def main():
    """Quantize a BN-folded MobileNet DeepLab with AIMET and print its mIoU.

    Raises:
        ValueError: if no checkpoint path is given, or if ``args.quant_scheme``
            is missing or not one of 'range_learning_tf',
            'range_learning_tfe', 'tf', 'tf_enhanced'.
    """
    args = arguments()
    seed(args)

    model = DeepLab(backbone='mobilenet',
                    output_stride=16,
                    num_classes=21,
                    sync_bn=False)
    model.eval()

    from aimet_torch import batch_norm_fold
    from aimet_torch import utils
    args.input_shape = (1, 3, 513, 513)
    # Folding and the ReLU6 -> ReLU swap happen before the weights are loaded,
    # so the checkpoint is presumably saved from an already-folded model -
    # TODO confirm.
    batch_norm_fold.fold_all_batch_norms(model, args.input_shape)
    utils.replace_modules_of_type1_with_type2(model, torch.nn.ReLU6,
                                              torch.nn.ReLU)

    if not args.checkpoint_path:
        raise ValueError('checkpoint path {} must be specified'.format(
            args.checkpoint_path))
    model.load_state_dict(torch.load(args.checkpoint_path))

    data_loader_kwargs = {'worker_init_fn': work_init, 'num_workers': 0}
    _, val_loader, _, _ = make_data_loader(args, **data_loader_kwargs)
    eval_func_quant = model_eval(args, val_loader)
    eval_func = model_eval(args, val_loader)

    from aimet_common.defs import QuantScheme
    from aimet_torch.quantsim import QuantizationSimModel

    # Without this guard a missing quant_scheme only surfaced later as a
    # NameError on the undefined `kwargs`.
    if not hasattr(args, 'quant_scheme'):
        raise ValueError("args.quant_scheme must be specified")
    if args.quant_scheme == 'range_learning_tf':
        quant_scheme = QuantScheme.training_range_learning_with_tf_init
    elif args.quant_scheme == 'range_learning_tfe':
        quant_scheme = QuantScheme.training_range_learning_with_tf_enhanced_init
    elif args.quant_scheme == 'tf':
        quant_scheme = QuantScheme.post_training_tf
    elif args.quant_scheme == 'tf_enhanced':
        quant_scheme = QuantScheme.post_training_tf_enhanced
    else:
        raise ValueError("Got unrecognized quant_scheme: " +
                         str(args.quant_scheme))

    kwargs = {
        'quant_scheme': quant_scheme,
        'default_param_bw': args.default_param_bw,
        'default_output_bw': args.default_output_bw,
        'config_file': args.config_file
    }
    print(kwargs)
    sim = QuantizationSimModel(model.cpu(),
                               input_shapes=args.input_shape,
                               **kwargs)
    sim.compute_encodings(eval_func_quant, (1024, True))
    post_quant_miou = eval_func(sim.model.cuda(), (99999999, True))
    print("Post Quant mIoU :", post_quant_miou)
예제 #6
0
def main(args):
    """Run a 1-class ResNet DeepLab over the validation set and write PNGs.

    For batch ``i`` and sample ``j`` three files are written into
    ``args.output_folder``: ``{i:06d}_{j:06d}_predict.png`` (channel 0 of the
    output thresholded at 0.5, as 0/255), ``..._target.png`` (mask * 255) and
    ``..._origin.png`` (the de-normalized input image).
    """
    import os

    vali_dataset = MRIBrainSegmentation(root_folder=args.root_folder,
                                        image_label=args.data_label,
                                        is_train=False)
    vali_loader = torch.utils.data.DataLoader(vali_dataset, batch_size=16, shuffle=False,
                                              num_workers=4, drop_last=False)

    # Init and load model
    model = DeepLab(num_classes=1,
                    backbone='resnet',
                    output_stride=8,
                    sync_bn=None,
                    freeze_bn=False)

    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    checkpoint = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model = model.to(device)
    model.eval()

    # cv2.imwrite returns False (rather than raising) when the folder is missing.
    os.makedirs(args.output_folder, exist_ok=True)

    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    with torch.no_grad():
        for i, sample in enumerate(vali_loader):
            print(i)
            data = sample['image'].to(device)
            output = model(data).cpu().numpy()

            target = sample['mask'].cpu().numpy()
            data = data.cpu().numpy()
            # Binarize channel 0 of the output at 0.5.
            pred = output[:, 0] > 0.5

            for j in range(len(target)):
                output_image = (pred[j] * 255).astype(np.uint8)
                target_image = (target[j] * 255).astype(np.uint8)

                cv2.imwrite("{}/{:06d}_{:06d}_predict.png".format(args.output_folder, i, j), output_image)
                cv2.imwrite("{}/{:06d}_{:06d}_target.png".format(args.output_folder, i, j), target_image)

                # Undo the input normalization: C x H x W -> H x W x C.
                img = data[j].transpose([1, 2, 0])
                img *= std
                img += mean
                img *= 255.0
                # Clip so out-of-range values saturate instead of overflowing
                # in the uint8 cast.
                np.clip(img, 0, 255, out=img)
                cv2.imwrite(
                    "{}/{:06d}_{:06d}_origin.png".format(args.output_folder,
                        i, j), img.astype(np.uint8))
예제 #7
0
def test(args):
    """Run the trained DeepLab over the test split on the CPU.

    When ``args.use_mixup`` is set, each batch is first mixed with
    ``mixup_data``. NOTE(review): the model output is not consumed within
    this function yet.
    """
    kwargs = {'num_workers': 1, 'pin_memory': True}
    _, _, test_loader, nclass = make_data_loader(args, **kwargs)
    model = DeepLab(num_classes=nclass,
                    backbone=args.backbone,
                    output_stride=args.out_stride,
                    sync_bn=False)
    # The model is never moved off the CPU, so deserialize there as well; this
    # also removes the dependence on a `device` name not defined here.
    checkpoint = torch.load(args.pretrained, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    # Inference only: do not build autograd graphs.
    with torch.no_grad():
        for sample in tqdm(test_loader):
            image, target = sample['image'], sample['label']
            if args.use_mixup:
                image, _, _, _ = mixup_data(image, target,
                                            args.mixup_alpha, use_cuda=False)
            output = model(image)
예제 #8
0
def test(args):
    """Overlay DeepLab predictions on validation images and save them.

    Each prediction is decoded to a colour map with ``decode_segmap``, resized
    to the sample's bounding box, pasted into a blank canvas the size of the
    original image, blended with that image and written to
    ``run/<dataset>/deeplab-resnet/output/<i>``.

    Raises:
        ValueError: if the checkpoint at ``args.ckpt`` deserializes to None.
    """
    kwargs = {'num_workers': 1, 'pin_memory': True}
    _, val_loader, _, nclass = make_data_loader(args, **kwargs)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    checkpoint = torch.load(args.ckpt, map_location=device)
    if checkpoint is None:
        raise ValueError('checkpoint {} is empty'.format(args.ckpt))

    model = DeepLab(num_classes=nclass,
                    backbone='resnet',
                    output_stride=16,
                    sync_bn=True,
                    freeze_bn=False)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    model.to(device)

    output_dir = os.path.join('run', args.dataset, 'deeplab-resnet', 'output')
    os.makedirs(output_dir, exist_ok=True)

    # no_grad is scoped to this loop, unlike torch.set_grad_enabled(False),
    # which would leave autograd disabled after the function returns.
    with torch.no_grad():
        for i, sample in enumerate(tqdm(val_loader)):
            x1, x2, y1, y2 = [
                int(item) for item in sample['img_meta']['bbox_coord']
            ]  # bbox coord
            w, h = x2 - x1, y2 - y1
            img = sample['img_meta']['image'].squeeze().cpu().numpy()

            inputs = sample['image'].to(device)
            output = model(inputs).squeeze().cpu().numpy()
            pred = np.argmax(output, axis=0)
            result = decode_segmap(pred, dataset=args.dataset, plot=False)

            # The resized patch must fill the (h, w) slot below; assumes this
            # is scipy.misc.imresize, whose size tuple is (height, width) -
            # TODO confirm.
            result = imresize(result, (h, w))
            result_padding = np.zeros(img.shape, dtype=np.uint8)
            result_padding[y1:y2, x1:x2] = result
            # Blend in a wide integer type: in uint8 the sum wraps around, so a
            # `> 255` guard could never fire.
            blended = (img.astype(np.int32) // 2 +
                       result_padding.astype(np.int32) * 127)
            result = np.clip(blended, 0, 255).astype(np.uint8)
            plt.imsave(os.path.join(output_dir, str(i)), result)
def main():
    """Segment rip-current imagery with a 3-class DeepLab and save masks/overlays.

    Two passes over 'doc/tests/img_cv.png':

    1. the whole image via ``process_single_large_image``, saved as
       '<out>_mask.png' and '<out>_com.png';
    2. each 800x800 patch from ``decompose_image``, saved next to the input
       as '<name>_patchNN_seg<ext>' and '<name>_patchNN_seg_img<ext>'.
    """
    import os

    def str2bool(value):
        """Parse a CLI boolean; argparse's type=bool maps "False" to True."""
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value, got {!r}'.format(value))

    parser = argparse.ArgumentParser(
        description="PyTorch DeeplabV3Plus Training")
    parser.add_argument('--backbone',
                        type=str,
                        default='resnet',
                        choices=['resnet', 'xception', 'drn', 'mobilenet'],
                        help='backbone name (default: resnet)')
    parser.add_argument('--out_stride',
                        type=int,
                        default=16,
                        help='network output stride (default: 16)')
    parser.add_argument('--rip_mode', type=str, default='patches-level2')
    # NOTE(review): store_true combined with default=True makes this always True.
    parser.add_argument('--use_sbd',
                        action='store_true',
                        default=True,
                        help='whether to use SBD dataset (default: True)')
    parser.add_argument('--workers',
                        type=int,
                        default=8,
                        metavar='N',
                        help='dataloader threads')
    parser.add_argument('--base_size',
                        type=int,
                        default=800,
                        help='base image size')
    parser.add_argument('--crop_size',
                        type=int,
                        default=800,
                        help='crop image size')
    parser.add_argument('--sync_bn',
                        type=str2bool,
                        default=None,
                        help='whether to use sync bn (default: auto)')
    parser.add_argument(
        '--freeze_bn',
        type=str2bool,
        default=False,
        help='whether to freeze bn parameters (default: False)')
    # cuda, seed and logging
    parser.add_argument('--gpus',
                        type=int,
                        default=1,
                        help='how many gpus to use (default=1)')
    parser.add_argument('--seed',
                        type=int,
                        default=123,
                        metavar='S',
                        help='random seed (default: 123)')
    # checking point
    parser.add_argument('--resume',
                        type=str,
                        default=None,
                        help='put the path to resuming file if needed')
    parser.add_argument('--checkname',
                        type=str,
                        default=None,
                        help='set the checkpoint name')

    parser.add_argument('--exp_root', type=str, default='')
    args = parser.parse_args()

    args.device, args.cuda = get_available_device(args.gpus)

    nclass = 3

    model = DeepLab(num_classes=nclass,
                    backbone=args.backbone,
                    output_stride=args.out_stride,
                    sync_bn=args.sync_bn,
                    freeze_bn=args.freeze_bn)

    # Use the hard-coded checkpoint only when --checkname is not given.
    if args.checkname is None:
        args.checkname = '/data2/data2/zewei/exp/RipData/DeepLabV3/patches/level2/CV5-1/model_best.pth.tar'
    # Deserialize on the CPU; the model is moved to args.device right after.
    ckpt = torch.load(args.checkname, map_location='cpu')
    model.load_state_dict(ckpt['state_dict'])
    model.eval()
    model = model.to(args.device)

    img_files = ['doc/tests/img_cv.png']
    out_file = 'doc/tests/img_seg.png'
    out_stem = os.path.splitext(out_file)[0]

    to_input = Compose([
        ToTensor(),
        Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    ])
    color_map = get_rip_labels()

    # Pass 1: whole-image segmentation.
    img_cv = cv2.imread(img_files[0])
    pred = process_single_large_image(model, img_cv)
    mask = gen_mask(pred, nclass, color_map)
    out_img = composite_image(img_cv, mask, alpha=0.2)
    mask_file = out_stem + '_mask.png'
    com_file = out_stem + '_com.png'
    save_image(mask, mask_file)
    save_image(out_img, com_file)
    print(f'saved image {mask_file}')
    print(f'saved image {com_file}')

    # Pass 2: per-patch segmentation.
    with torch.no_grad():
        for img_file in img_files:
            name, ext = os.path.splitext(img_file)

            img_cv = cv2.imread(img_file)
            patches = decompose_image(img_cv, None, (800, 800), (300, 700))
            print(f'Decompose input image into {len(patches)} patches.')
            for i, patch in patches.items():
                img = to_input(patch.image).unsqueeze(0).to(args.device)

                output = model(img).cpu().numpy()
                pred = np.argmax(output, axis=1)

                mask = gen_mask(pred[0], nclass, color_map)
                out_img = composite_image(patch.image, mask, alpha=0.2)
                seg_file = name + f'_patch{i:02d}_seg' + ext
                seg_img_file = name + f'_patch{i:02d}_seg_img' + ext
                save_image(mask, seg_file)
                save_image(out_img, seg_img_file)
                print(f'saved image {seg_file}')
                print(f'saved image {seg_img_file}')
예제 #10
0
파일: inference.py 프로젝트: arunumd/test5
from modeling.deeplab import DeepLab
from dataloaders.utils import decode_segmap

# Run on the first GPU when available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
checkpoint = torch.load("./run/pascal/deeplab-resnet/model_best.pth",
                        map_location=device)

model = DeepLab(num_classes=21,
                backbone='resnet',
                output_stride=16,
                sync_bn=True,
                freeze_bn=False)

# NOTE(review): weights are read from the 'state_dict_G' entry rather than
# 'state_dict' - confirm this matches how the checkpoint was written.
model.load_state_dict(checkpoint['state_dict_G'])
model.eval()
model.to(device)


def transform(image):
    """Resize/centre-crop ``image`` to 513, convert it to a tensor and normalize it.

    Uses the Resize, CenterCrop, ToTensor, Normalize and Compose transforms
    from ``tr``; a new pipeline is assembled on every call.
    """
    steps = [
        tr.Resize(513),
        tr.CenterCrop(513),
        tr.ToTensor(),
        tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
    pipeline = tr.Compose(steps)
    return pipeline(image)


# Inference only: switch autograd off for the rest of this script.
torch.set_grad_enabled(mode=False)

# Example input image, read from the current working directory.
image = Image.open(fp='sample_image.jpg')
예제 #11
0
class Trainer:
    """Train and validate a DeepLab segmentation model configured by ``instructions``.

    Checkpoints (through ``Saver``) and TensorBoard summaries are written to
    ``os.path.join(paths.MODELS_FOLDER_PATH, instructions[STR.MODEL_NAME])``.
    """

    def __init__(self, data_train, data_valid, image_base_dir, instructions):
        """Set up data loaders, model, optimizer, loss, evaluator and LR scheduler.

        :param data_train: training samples, passed as ``data`` to ``DictArrayDataSet``
        :param data_valid: validation samples, passed as ``data`` to ``DictArrayDataSet``
        :param image_base_dir: passed as ``image_base_dir`` to both datasets
        :param instructions: configuration mapping keyed by ``STR.*`` constants;
            ``STR.MODEL_NAME``, ``STR.NN_INPUT_SIZE`` and ``STR.BATCH_SIZE`` are
            required, and ``STR.EPOCHS`` too while the LR scheduler is enabled
            (the default). Other keys fall back to the defaults used below.
        """
        import warnings

        self.image_base_dir = image_base_dir
        self.data_valid = data_valid
        self.instructions = instructions

        # specify model save dir
        self.model_name = instructions[STR.MODEL_NAME]
        experiment_folder_path = os.path.join(paths.MODELS_FOLDER_PATH,
                                              self.model_name)

        if os.path.exists(experiment_folder_path):
            # Constructing a Warning instance emits nothing; warn() does.
            warnings.warn(
                "Experiment folder exists already. Files might be overwritten")
        os.makedirs(experiment_folder_path, exist_ok=True)

        # define saver and save instructions
        self.saver = Saver(folder_path=experiment_folder_path,
                           instructions=instructions)
        self.saver.save_instructions()

        # define Tensorboard Summary
        self.writer = SummaryWriter(log_dir=experiment_folder_path)

        nn_input_size = instructions[STR.NN_INPUT_SIZE]
        state_dict_file_path = instructions.get(STR.STATE_DICT_FILE_PATH, None)

        self.colour_mapping = mapping.get_colour_mapping()

        # Random cropping is used only when both the crops-per-image and the
        # images-per-batch settings are present.
        crops_per_image = instructions.get(STR.CROPS_PER_IMAGE, 10)
        apply_random_cropping = (STR.CROPS_PER_IMAGE in instructions.keys()) and \
                                (STR.IMAGES_PER_BATCH in instructions.keys())

        print("{}applying random cropping".format(
            "" if apply_random_cropping else "_NOT_ "))

        # define transformers for training
        t = [Normalize()]
        if apply_random_cropping:
            t.append(
                RandomCrop(min_size=instructions.get(STR.CROP_SIZE_MIN, 400),
                           max_size=instructions.get(STR.CROP_SIZE_MAX, 1000),
                           crop_count=crops_per_image))
        t += [
            Resize(nn_input_size),
            Flip(p_vertical=0.2, p_horizontal=0.5),
            ToTensor()
        ]
        transformations_train = transforms.Compose(t)

        # define transformers for validation
        transformations_valid = transforms.Compose(
            [Normalize(), Resize(nn_input_size),
             ToTensor()])

        # set up data loaders
        dataset_train = DictArrayDataSet(image_base_dir=image_base_dir,
                                         data=data_train,
                                         num_classes=len(
                                             self.colour_mapping.keys()),
                                         transformation=transformations_train)

        # define batch sizes
        self.batch_size = instructions[STR.BATCH_SIZE]

        # With random cropping the training batch size is counted in source
        # images (STR.IMAGES_PER_BATCH); RandomCrop presumably expands each
        # image into several crops.
        train_batch_size = (instructions[STR.IMAGES_PER_BATCH]
                            if apply_random_cropping else self.batch_size)
        self.data_loader_train = DataLoader(dataset=dataset_train,
                                            batch_size=train_batch_size,
                                            shuffle=True,
                                            collate_fn=custom_collate)

        dataset_valid = DictArrayDataSet(image_base_dir=image_base_dir,
                                         data=data_valid,
                                         num_classes=len(
                                             self.colour_mapping.keys()),
                                         transformation=transformations_valid)

        self.data_loader_valid = DataLoader(dataset=dataset_valid,
                                            batch_size=self.batch_size,
                                            shuffle=False,
                                            collate_fn=custom_collate)

        self.num_classes = dataset_train.num_classes()

        # define model
        print("Building model")
        self.model = DeepLab(num_classes=self.num_classes,
                             backbone=instructions.get(STR.BACKBONE, "resnet"),
                             output_stride=instructions.get(
                                 STR.DEEPLAB_OUTPUT_STRIDE, 16))

        # load weights
        if state_dict_file_path is not None:
            print("loading state_dict from:")
            print(state_dict_file_path)
            load_state_dict(self.model, state_dict_file_path)

        learning_rate = instructions.get(STR.LEARNING_RATE, 1e-5)
        # NOTE(review): both groups start at the same rate even though the
        # second comes from get_10x_lr_params(); confirm whether
        # learning_rate * 10 was intended (LR_Scheduler may reset group
        # rates per iteration - TODO check).
        train_params = [{
            'params': self.model.get_1x_lr_params(),
            'lr': learning_rate
        }, {
            'params': self.model.get_10x_lr_params(),
            'lr': learning_rate
        }]

        # choose gpu or cpu
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        if instructions.get(STR.MULTI_GPU, False):
            if torch.cuda.device_count() > 1:
                print("Using ", torch.cuda.device_count(), " GPUs!")
                self.model = nn.DataParallel(self.model)

        self.model.to(self.device)

        # Define Optimizer
        self.optimizer = torch.optim.SGD(train_params,
                                         momentum=0.9,
                                         weight_decay=5e-4,
                                         nesterov=False)

        # calculate class weights
        if instructions.get(STR.CLASS_STATS_FILE_PATH, None):

            class_weights = calculate_class_weights(
                instructions[STR.CLASS_STATS_FILE_PATH],
                self.colour_mapping,
                modifier=instructions.get(STR.LOSS_WEIGHT_MODIFIER, 1.01))

            class_weights = torch.from_numpy(class_weights.astype(np.float32))
        else:
            class_weights = None
        self.criterion = SegmentationLosses(
            weight=class_weights, cuda=self.device.type != "cpu").build_loss()

        # Define Evaluator
        self.evaluator = Evaluator(self.num_classes)

        # Define lr scheduler
        self.scheduler = None
        if instructions.get(STR.USE_LR_SCHEDULER, True):
            self.scheduler = LR_Scheduler(mode="cos",
                                          base_lr=learning_rate,
                                          num_epochs=instructions[STR.EPOCHS],
                                          iters_per_epoch=len(
                                              self.data_loader_train))

        # print information before training start
        print("-" * 60)
        print("instructions")
        pprint(instructions)
        model_parameters = sum(p.nelement() for p in self.model.parameters())
        print("Model parameters: {:.2E}".format(model_parameters))

        self.best_prediction = 0.0

    def train(self, epoch):
        """Run one training epoch and log per-iteration and per-epoch loss.

        :param epoch: zero-based epoch index, used for the scheduler and the
            TensorBoard step.
        """
        self.model.train()
        train_loss = 0.0

        # create a progress bar
        pbar = tqdm(self.data_loader_train)
        num_batches_train = len(self.data_loader_train)

        # go through each item in the training data
        for i, sample in enumerate(pbar):
            # set input and target
            nn_input = sample[STR.NN_INPUT].to(self.device)
            nn_target = sample[STR.NN_TARGET].to(self.device,
                                                 dtype=torch.float)

            if self.scheduler:
                self.scheduler(self.optimizer, i, epoch, self.best_prediction)

            # run model
            output = self.model(nn_input)

            # calc losses
            loss = self.criterion(output, nn_target)

            train_loss += loss.item()
            pbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_batches_train * epoch)

            # calculate gradient and update model weights
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print("[Epoch: {}, num images/crops: {}]".format(
            epoch, num_batches_train * self.batch_size))

        # train_loss is the sum of the batch losses, not their mean.
        print("Loss: {:.2f}".format(train_loss))

    def validation(self, epoch):
        """Evaluate on the validation loader, log metrics and save a checkpoint.

        The checkpoint is flagged as best when the mIoU beats
        ``self.best_prediction``, which is then updated.

        :param epoch: epoch index used for logging and checkpoint naming.
        """
        self.model.eval()
        self.evaluator.reset()
        test_loss = 0.0

        pbar = tqdm(self.data_loader_valid, desc='\r')
        num_batches_val = len(self.data_loader_valid)

        for i, sample in enumerate(pbar):
            # set input and target
            nn_input = sample[STR.NN_INPUT].to(self.device)
            nn_target = sample[STR.NN_TARGET].to(self.device,
                                                 dtype=torch.float)

            with torch.no_grad():
                output = self.model(nn_input)
                loss = self.criterion(output, nn_target)

            test_loss += loss.item()
            pbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            nn_target = nn_target.cpu().numpy()
            # Add batch sample into evaluator
            self.evaluator.add_batch(nn_target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print("[Epoch: {}, num crops: {}]".format(
            epoch, num_batches_val * self.batch_size))
        print(
            "Acc:{:.2f}, Acc_class:{:.2f}, mIoU:{:.2f}, fwIoU: {:.2f}".format(
                Acc, Acc_class, mIoU, FWIoU))
        print("Loss: {:.2f}".format(test_loss))

        new_pred = mIoU
        is_best = new_pred > self.best_prediction
        if is_best:
            self.best_prediction = new_pred
        self.saver.save_checkpoint(self.model, is_best, epoch)
from modeling.deeplab import DeepLab
import kornia
from PIL import Image
import torch
import torchvision.transforms.functional as TF
import numpy as np

# this example only uses 1 image, so cpu is fine
device = torch.device("cpu")

# load pre-trained weights, set network to inference mode
network = DeepLab(num_classes=18)
network.load_state_dict(
    torch.load("segmentation-model/epoch-14", map_location=device))
network.eval()
network.to(device)

# load example image. the image is resized because DeepLab uses
# a lot of dilated convolutions and doesn't work very well for
# low resolution images.
# convert("RGB") guarantees 3 channels; greyscale or RGBA inputs would
# otherwise become 1- or 4-channel tensors.
image = Image.open("nate.jpg").convert("RGB")
scaled_image = image.resize((418, 512), resample=Image.LANCZOS)
image_tensor = TF.to_tensor(scaled_image)

# send the input through the network. unsqueeze is used to
# add a batch dimension, because torch always expects a batch
# but in this case it's just one image
# I then use Kornia to resize the mask back to 218x178 then
# squeeze to remove the batch channel again (kornia also
# always expects a batch dimension)
with torch.no_grad():
예제 #13
0
        default='/Users/yulian/Downloads/mixup_model_best.pth.tar',
        help='pretrained model')
    parser.add_argument('--color',
                        type=str,
                        default='purple',
                        choices=['purple', 'green', 'blue', 'red'],
                        help='Color your hair (default: purple)')
    args = parser.parse_args()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = DeepLab(backbone=args.backbone,
                  output_stride=16,
                  num_classes=2,
                  sync_bn=False).to(device)
    net.load_state_dict(
        torch.load(args.pretrained, map_location=device)['state_dict'])
    net.eval()
    cam = cv2.VideoCapture(0)
    if not cam.isOpened():
        raise Exception("webcam is not detected")

    while (True):
        # ret : frame capture(boolean)
        # frame : Capture frame
        ret, image = cam.read()

        if (ret):
            image, mask = get_image_mask(image, net)
            # print(image.shape, mask.shape)
            add = color_image(image, mask, args.color)
            cv2.imshow('frame', add)
            if cv2.waitKey(1) & 0xFF == ord(chr(27)):
class DeeplabRos:
    """ROS node that segments camera frames with a 4-class MobileNet DeepLab.

    Subscribes to /cam2/pylon_camera_node/image_raw and publishes a
    colour-coded label image on "segmentation_image".
    """

    def __init__(self):
        """Load the model, build the input transform and set up ROS topics."""
        # GPU assignment. CUDA reads this variable when it initializes, so it
        # must be set before any CUDA call in this process.
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')

        # Load checkpoint; map_location lets a GPU-saved checkpoint load on a
        # CPU-only host.
        self.checkpoint = torch.load(
            "./src/deeplab_ros/data/model_best.pth.tar",
            map_location=self.device)

        # Load Model
        self.model = DeepLab(num_classes=4,
                             backbone='mobilenet',
                             output_stride=16,
                             sync_bn=True,
                             freeze_bn=False)

        self.model.load_state_dict(self.checkpoint['state_dict'])
        self.model = self.model.to(self.device)
        self.model.eval()

        # Built once here instead of on every frame.
        self.tfms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        # ROS init. The publisher is created before the subscriber so the
        # callback never runs before self.image_pub exists.
        self.bridge = CvBridge()
        self.image_pub = rospy.Publisher("segmentation_image",
                                         ImageMsg,
                                         queue_size=1)
        self.image_sub = rospy.Subscriber("/cam2/pylon_camera_node/image_raw",
                                          ImageMsg,
                                          self.callback,
                                          queue_size=1,
                                          buff_size=2**24)

    def callback(self, data):
        """Segment one incoming ROS image message and publish the colour map."""
        cv_image = self.bridge.imgmsg_to_cv2(data, "bgr8")
        start_time = time.time()

        # no_grad is scoped to this call, unlike torch.set_grad_enabled(False),
        # which would leave autograd disabled on this thread afterwards.
        with torch.no_grad():
            inputs = self.tfms(cv_image).to(self.device)
            output = self.model(inputs.unsqueeze(0)).squeeze().cpu().numpy()
        pred = np.argmax(output, axis=0)
        pred_img = self.label_to_color_image(pred)

        msg = self.bridge.cv2_to_imgmsg(pred_img, "bgr8")

        inference_time = time.time() - start_time
        print("inference time: ", inference_time)

        self.image_pub.publish(msg)

    def label_to_color_image(self, pred, class_num=4):
        """Map a 2-D array of class ids to an H x W x 3 uint8 colour image.

        Colours are in BGR order (the image is published as "bgr8"):
        0 Unlabeled -> (0, 0, 0), 1 Building -> (0, 0, 128),
        2 Lane-marking -> (0, 128, 0), 3 Fence -> (128, 0, 0).
        Ids outside ``range(class_num)`` stay black.
        """
        label_colors = np.array([(0, 0, 0), (0, 0, 128), (0, 128, 0),
                                 (128, 0, 0)], dtype=np.uint8)  # bgr
        color_img = np.zeros(np.shape(pred) + (3,), dtype=np.uint8)
        for class_id in range(class_num):
            color_img[pred == class_id] = label_colors[class_id]
        return color_img