Example #1
def validate(cfg, model_path):
    assert model_path is not None, 'model_path must not be None'
    use_cuda = False
    if cfg.get("cuda", None) is not None:
        if cfg.get("cuda", None) != "all":
            os.environ["CUDA_VISIBLE_DEVICES"] = cfg.get("cuda", None)
        use_cuda = torch.cuda.is_available()

    # Setup Dataloader
    train_loader, val_loader = get_loader(cfg)

    loss_fn = get_loss_fn(cfg)

    # Load Model
    model = get_model(cfg)
    if use_cuda:
        model.cuda()
        loss_fn.cuda()
        checkpoint = torch.load(model_path)
        if torch.cuda.device_count() > 1:  # multi gpus
            model = torch.nn.DataParallel(
                model, device_ids=list(range(torch.cuda.device_count())))
            state = checkpoint["state_dict"]
        else:  # 1 gpu
            state = convert_state_dict(checkpoint["state_dict"])
    else:  # cpu
        checkpoint = torch.load(model_path, map_location='cpu')
        state = convert_state_dict(checkpoint["state_dict"])
    model.load_state_dict(state)

    validate_epoch(val_loader, model, loss_fn, use_cuda)
Example #2
File: test.py Project: syt2/SKNet
def test(cfg, img_path, model_path):
    assert img_path is not None, 'img_path must not be None'
    assert model_path is not None, 'model_path must not be None'
    use_cuda = False
    if cfg.get("cuda", None) is not None:
        if cfg.get("cuda", None) != "all":
            os.environ["CUDA_VISIBLE_DEVICES"] = cfg.get("cuda", None)
        use_cuda = torch.cuda.is_available()

    # Load Model
    model = get_model(cfg)
    if use_cuda:
        model.cuda()
        checkpoint = torch.load(model_path)
        if torch.cuda.device_count() > 1:  # multi gpus
            model = torch.nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
            state = checkpoint["state_dict"]
        else:  # 1 gpu
            state = convert_state_dict(checkpoint["state_dict"])
    else:  # cpu
        checkpoint = torch.load(model_path, map_location='cpu')
        state = convert_state_dict(checkpoint["state_dict"])
    model.load_state_dict(state)

    model.eval()
    input = get_img(img_path)  # read img
    with torch.no_grad():
        if use_cuda:
            input = input.cuda()
        output = model(input)  # model output
        idx = torch.argmax(output, 1)

    idx = int(idx.data)
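    # idx2label is assumed to be defined elsewhere in this file, mapping a class index to its label name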
    cls = idx2label[idx]
    print(idx, cls)
Example #3
def test(img_path, model_path, show=False):
    imgorg = cv2.imread(img_path)
    imgorg = cv2.cvtColor(imgorg, cv2.COLOR_BGR2RGB)
    w, h = imgorg.shape[0], imgorg.shape[1]
    img = cv2.resize(imgorg, (256, 256))
    img = img[:, :, ::-1]
    img = img.astype(float) / 255
    img = img.transpose(2, 0, 1)
    img = np.expand_dims(img, 0)
    img = torch.from_numpy(img).float()

    htan = nn.Hardtanh(0.0, 1.0)
    model = get_model('unetnc', n_classes=3, in_channels=3)
    state = convert_state_dict(torch.load(model_path)['model_state'])

    model.load_state_dict(state)
    model.eval()
    model.cuda()
    images = Variable(img.cuda())

    with torch.no_grad():
        output = model(images)
        pred = htan(output)

    pred = pred.cpu().detach().numpy()[0]

    if show:
        pred = pred.transpose((1, 2, 0))
        pred = cv2.resize(pred, (h, w), interpolation=cv2.INTER_NEAREST)
        print(pred)
        _, axis = plt.subplots(1, 2)
        axis[0].imshow(imgorg)
        axis[1].imshow(pred)
        plt.show()
Example #4
def main(args, logger):
    # ================ seed and device ===================
    np.random.seed(42)
    torch.manual_seed(42)
    # if args.cuda:
    if True:
        torch.cuda.manual_seed_all(42)
        device = 'cuda'
    else:
        device = 'cpu'
    # ================= data ====================
    mnist = Mnist(args.data_dir, mode='test')
    val_loader = torch.utils.data.DataLoader(mnist.test_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False)
    # ================== Load Model ===================
    model = models.get_model(name=args.model, n_class=mnist.n_class)
    model.to(device)
    best_model_path = pjoin(args.save_dir, args.model + '_best_model.pkl')
    state = convert_state_dict(torch.load(best_model_path)["model_state"])
    model.load_state_dict(state)
    # ================== Testing ======================
    model.eval()
    val_acc = 0
    with torch.no_grad():
        for idx, (img, lab) in enumerate(val_loader):
            img = img.to(device)
            lab = lab.to(device)
            out = model(img)
            pred = out.argmax(dim=1, keepdim=True)
            val_acc += pred.eq(lab.view_as(pred)).sum().item()
    val_acc /= len(val_loader.dataset)
    logger.write(f'Model {best_model_path}, Acc: {val_acc:.3f}')
Example #5
def init_model(self, model_path):
    n_classes = 19
    # Setup Model
    model = get_model({"arch": "hardnet"}, n_classes)
    state = convert_state_dict(
        torch.load(model_path, map_location=self.device)["model_state"])
    model.load_state_dict(state)
    model.eval()
    model.to(self.device)
    return model
Example #6
def main(args):
    # ========= Setup device and seed ============
    np.random.seed(42)
    torch.manual_seed(42)
    if args.cuda:
        torch.cuda.manual_seed_all(42)
    device = 'cuda' if args.cuda else 'cpu'
    logger = Logger(pjoin(args.save_dir, args.model, 'test.log'))
    logger.write(f'\nTesting configs: {args}')

    # ================= Load processed data ===================
    val_dataset = VOC12(args.data_dir, img_size=args.img_size, split='test')
    val_loader = DataLoader(val_dataset, num_workers=8, batch_size=1)
    n_classes = val_dataset.n_classes

    # ================= Init model ====================
    model = models.get_model(name=args.model, n_classes=n_classes)
    model = model.to(device)
    state = convert_state_dict(torch.load(args.model_path)["model_state"])
    model.load_state_dict(state)
    model.eval()

    # ====================== Only one image ==========================
    if args.eval:
        with torch.no_grad():
            img = Image.open(args.img_path)
            origin = img.size
            if args.img_size:
                img = img.resize(
                    (val_dataset.img_size[0], val_dataset.img_size[1]))
            img = val_dataset.input_transform(img).unsqueeze(0).to(device)
            out = model(img)
            pred = np.squeeze(out.data.max(1)[1].cpu().numpy(), axis=0)
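            # decode_segmap (provided by the dataset class) is assumed to map each class index to its RGB colour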
            decoded = val_dataset.decode_segmap(pred)
            img_out = ToPILImage()(decoded).resize(origin)
            img_out.save(
                pjoin(args.save_dir, args.model, f'eval_{args.img_size}.png'))
        return

    # ====================== Testing Many images ==============================
    with torch.no_grad():
        for idx, (name, img) in enumerate(val_loader):
            img = img.to(device)
            out = model(img)
            pred = out.data.max(1)[1].squeeze_(1).squeeze_(0).cpu().numpy()
            decoded = val_dataset.decode_segmap(pred)
            ToPILImage()(decoded).save(
                pjoin(args.save_dir, args.model,
                      f'{name[0]}_{args.img_size}.png'))
Example #7
def main(args):
    # ================ seed and device ===================
    np.random.seed(42)
    torch.manual_seed(42)
    if args.cuda:
        torch.cuda.manual_seed_all(42)
        device = 'cuda'
    else:
        device = 'cpu'
    logger = Logger(pjoin(args.save_dir, args.model + '_test.log'))
    logger.write(f'\nConfig: {args}')
    # ================= data ====================
    mnist = Mnist(args.data_dir, mode='test')
    val_loader = torch.utils.data.DataLoader(mnist.test_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False)
    # ================== Load Model ===================
    model = models.get_model(name=args.model, n_class=mnist.n_class)
    model.to(device)
    best_model_path = args.model_path
    state = convert_state_dict(torch.load(best_model_path)["model_state"])
    model.load_state_dict(state)
    # ================== Testing ======================
    model.eval()
    val_acc = 0
    res = []
    with torch.no_grad():
        for idx, (img, lab) in enumerate(val_loader):
            img = img.to(device)
            lab = lab.to(device)
            out, tmp = model(img)
            res.append(tmp.cpu().numpy())
            pred = out.argmax(dim=1, keepdim=True)
            val_acc += pred.eq(lab.view_as(pred)).sum().item()
    val_acc /= len(val_loader.dataset)
    logger.write(f'Model {best_model_path}, Acc: {val_acc:.3f}')
    with open('test.npy', 'wb') as fout:
        np.save(fout, np.array(res))
Example #8
File: test.py Project: ycpan1597/MRI-AI
def test(args):

    data_loader = get_loader(args.dataset)
    data_path = get_data_path()
    loader = data_loader(data_path, is_transform=True)
    n_classes = loader.n_classes

    # Setup Model
    model = unet(n_classes=n_classes, in_channels=1)
    state = convert_state_dict(torch.load(args.model_path)['model_state'])
    model.load_state_dict(state)
    model.eval()
    model.cuda(0)

    # Shannon and Preston's way of processing test files
    test_dataset = MRI(args.test_root,
                       img_size=(args.img_rows, args.img_cols),
                       mode='test')
    testLoader = DataLoader(test_dataset, batch_size=1)

    for (img, gt) in testLoader:
        img = img.numpy()
        img = img.astype(np.float64)
        # img -= loader.mean
        img -= 128
        img = img.astype(float) / 255.0
        img = np.expand_dims(img, axis=2)
        # NHWC -> NCHW
        # N = number of images in the batch, H = image height, W = image width, C = number of channels
        # https://stackoverflow.com/questions/37689423/convert-between-nhwc-and-nchw-in-tensorflow
        img = img.transpose(2, 0, 1)
        img = np.expand_dims(img, 0)
        img = torch.from_numpy(img).float()

        images = Variable(img.cuda(0), volatile=True)

    print("files are read!")
Example #9
def test(cfg):

    device = torch.device("cuda:{}".format(cfg["training"]["gpu_idx"])
                          if torch.cuda.is_available() else "cpu")

    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]
    v_loader = data_loader(data_path, split='val')

    n_classes = v_loader.n_classes
    n_val = len(v_loader.files['val'])
    valLoader = data.DataLoader(v_loader,
                                batch_size=1,
                                num_workers=cfg["training"]["n_workers"])

    model = get_model(cfg["model"], n_classes).to(device)
    state = convert_state_dict(
        torch.load(cfg["testing"]["trained_model"],
                   map_location=device)["model_state"])
    model.load_state_dict(state)
    model.eval()
    model.to(device)

    running_metrics_val = runningScore(n_classes, n_val)
    with torch.no_grad():
        for i_val, (images_val, labels_val,
                    img_name_val) in tqdm(enumerate(valLoader)):
            images_val = images_val.to(device)
            labels_val = labels_val.to(device)

            outputs = model(images_val)

            pred = np.squeeze(outputs.data.max(1)[1].cpu().numpy())
            gt = np.squeeze(labels_val.data.cpu().numpy())

            running_metrics_val.update(gt, pred, i_val)

            decoded = v_loader.decode_segmap(pred, plot=False)
            m.imsave(
                pjoin(cfg["testing"]["path"],
                      '{}.bmp'.format(img_name_val[0])), decoded)

    score = running_metrics_val.get_scores()
    acc_all, dsc_cls = running_metrics_val.get_list()
    for k, v in score[0].items():
        print(k, v)

    if cfg["testing"]["boxplot"] == True:
        sns.set_style("whitegrid")
        labels = ['CSF', 'Gray Matter', 'White Matter']
        fig1, ax1 = plt.subplots()
        ax1.set_title('Basic Plot')
        # ax1.boxplot(dsc_cls.transpose()[:,1:n_classes], showfliers=False, labels=labels)
        ax1 = sns.boxplot(data=dsc_cls.transpose()[:, 1:n_classes])

        # ax1.yaxis.grid(True)
        ax1.set_xlabel('Three separate samples')
        ax1.set_ylabel('Dice Score')

        # path to save boxplot
        plt.savefig('/home/jwliu/disk/kxie/CNN_LSTM/test_results/box.pdf')
Example #10
File: train.py Project: syt2/CRA
def train(cfg, writer, logger):
    # CUDA_VISIBLE_DEVICES must be set before PyTorch touches the GPU, so handle it first
    use_cuda = False
    if cfg.get("cuda", None) is not None:
        if cfg.get("cuda", None) != "all":
            os.environ["CUDA_VISIBLE_DEVICES"] = cfg.get("cuda", None)
        use_cuda = torch.cuda.is_available()

    # Setup random seed
    seed = cfg["training"].get("seed", random.randint(1, 10000))
    torch.manual_seed(seed)
    if use_cuda:
        torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Setup Dataloader
    train_loader, val_loader = get_loader(cfg)

    # Setup Model
    model = get_model(cfg)
    # writer.add_graph(model, torch.rand([1, 3, 224, 224]))
    if use_cuda and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(
                                          range(torch.cuda.device_count())))

    # Setup optimizer, lr_scheduler and loss function
    optimizer = get_optimizer(model.parameters(), cfg)
    scheduler = get_scheduler(optimizer, cfg)
    loss_fn = get_loss_fn(cfg)

    # Setup Metrics
    epochs = cfg["training"]["epochs"]
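    # RecorderMeter is assumed to track per-epoch train/val loss and accuracy (see recorder.update below)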
    recorder = RecorderMeter(epochs)
    start_epoch = 0

    # save model parameters every <n> epochs
    save_interval = cfg["training"]["save_interval"]

    if use_cuda:
        model.cuda()
        loss_fn.cuda()

    # Resume Trained Model
    resume_path = os.path.join(writer.file_writer.get_logdir(),
                               cfg["training"]["resume"])
    best_path = os.path.join(writer.file_writer.get_logdir(),
                             cfg["training"]["best_model"])

    if cfg["training"]["resume"] is not None:
        if os.path.isfile(resume_path):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    resume_path))
            checkpoint = torch.load(resume_path)
            state = checkpoint["state_dict"]
            if torch.cuda.device_count() <= 1:
                state = convert_state_dict(state)
            model.load_state_dict(state)
            optimizer.load_state_dict(checkpoint["optimizer"])
            scheduler.load_state_dict(checkpoint["scheduler"])
            start_epoch = checkpoint["epoch"]
            recorder = checkpoint['recorder']
            logger.info("Loaded checkpoint '{}' (epoch {})".format(
                resume_path, checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(resume_path))

    epoch_time = AverageMeter()
    for epoch in range(start_epoch, epochs):
        start_time = time.time()
        need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg *
                                                            (epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)
        logger.info(
            '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:8.6f}]'.
            format(time_string(), epoch, epochs, need_time, optimizer.
                   param_groups[0]['lr']) +  # or scheduler.get_last_lr() on PyTorch >= 1.4
            ' [Best : Accuracy={:.2f}]'.format(recorder.max_accuracy(False)))
        train_acc, train_los = train_epoch(train_loader, model, loss_fn,
                                           optimizer, use_cuda, logger)
        val_acc, val_los = validate_epoch(val_loader, model, loss_fn, use_cuda,
                                          logger)
        scheduler.step()

        is_best = recorder.update(epoch, train_los, train_acc, val_los,
                                  val_acc)
        if is_best or epoch % save_interval == 0 or epoch == epochs - 1:  # save model (resume model and best model)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'recorder': recorder,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                }, is_best, best_path, resume_path)

            for name, param in model.named_parameters():  # save histogram
                writer.add_histogram(name,
                                     param.clone().cpu().data.numpy(), epoch)

        writer.add_scalar('Train/loss', train_los, epoch)  # save curves
        writer.add_scalar('Train/acc', train_acc, epoch)
        writer.add_scalar('Val/loss', val_los, epoch)
        writer.add_scalar('Val/acc', val_acc, epoch)

        epoch_time.update(time.time() - start_time)

    writer.close()
Example #11
def test(args, img_path, fname):
    wc_model_file_name = os.path.split(args.wc_model_path)[1]
    wc_model_name = wc_model_file_name[:wc_model_file_name.find('_')]

    bm_model_file_name = os.path.split(args.bm_model_path)[1]
    bm_model_name = bm_model_file_name[:bm_model_file_name.find('_')]

    wc_n_classes = 3
    bm_n_classes = 2

    wc_img_size = (256, 256)
    bm_img_size = (128, 128)

    # Setup image
    print("Read Input Image from : {}".format(img_path))
    imgorg = m.imread(img_path, mode='RGB')
    img = m.imresize(imgorg, wc_img_size)
    img = img[:, :, ::-1]
    img = img.astype(float) / 255.0
    img = img.transpose(2, 0, 1)  # NHWC -> NCHW
    img = np.expand_dims(img, 0)
    img = torch.from_numpy(img).float()

    # Predict
    htan = nn.Hardtanh(0, 1.0)
    wc_model = get_model(wc_model_name, wc_n_classes, in_channels=3)
    wc_state = convert_state_dict(
        torch.load(args.wc_model_path)['model_state'])
    wc_model.load_state_dict(wc_state)
    wc_model.eval()
    bm_model = get_model(bm_model_name, bm_n_classes, in_channels=3)
    bm_state = convert_state_dict(
        torch.load(args.bm_model_path)['model_state'])
    bm_model.load_state_dict(bm_state)
    bm_model.eval()

    if torch.cuda.is_available():
        wc_model.cuda()
        bm_model.cuda()
        images = Variable(img.cuda())
    else:
        images = Variable(img)

    with torch.no_grad():
        wc_outputs = wc_model(images)
        pred_wc = htan(wc_outputs)
        bm_input = F.interpolate(pred_wc, bm_img_size)
        outputs_bm = bm_model(bm_input)

    # call unwarp
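    # unwarp (defined elsewhere) is assumed to resample imgorg with the predicted backward map, e.g. via F.grid_sample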
    uwpred = unwarp(imgorg, outputs_bm)

    if args.show:
        f1, axarr1 = plt.subplots(1, 2)
        axarr1[0].imshow(imgorg)
        axarr1[1].imshow(uwpred)
        plt.show()

    # Save the output
    outp = os.path.join(args.out_path, fname)
    cv2.imwrite(outp, uwpred[:, :, ::-1] * 255)
Example #12
def test(args, img_path, fname):
    wc_model_file_name = os.path.split(args.wc_model_path)[1]
    wc_model_name = wc_model_file_name[:wc_model_file_name.find('_')]

    bm_model_file_name = os.path.split(args.bm_model_path)[1]
    bm_model_name = bm_model_file_name[:bm_model_file_name.find('_')]

    wc_n_classes = 3
    bm_n_classes = 2

    wc_img_size = (256, 256)
    bm_img_size = (128, 128)

    # Setup image
    print("Read Input Image from : {}".format(img_path))
    imgorg = cv2.imread(img_path)
    imgorg = cv2.cvtColor(imgorg, cv2.COLOR_BGR2RGB)
    img = cv2.resize(imgorg, wc_img_size)
    img = img[:, :, ::-1]
    img = img.astype(float) / 255.0
    img = img.transpose(2, 0, 1)  # NHWC -> NCHW
    img = np.expand_dims(img, 0)
    img = torch.from_numpy(img).float()

    # Predict
    htan = nn.Hardtanh(0, 1.0)
    wc_model = get_model(wc_model_name, wc_n_classes, in_channels=3)
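    # DEVICE is assumed to be a module-level torch.device chosen from torch.cuda.is_available()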
    if DEVICE.type == 'cpu':
        wc_state = convert_state_dict(torch.load(
            args.wc_model_path, map_location='cpu')['model_state'])
    else:
        wc_state = convert_state_dict(
            torch.load(args.wc_model_path)['model_state'])
    wc_model.load_state_dict(wc_state)
    wc_model.eval()
    bm_model = get_model(bm_model_name, bm_n_classes, in_channels=3)
    if DEVICE.type == 'cpu':
        bm_state = convert_state_dict(torch.load(
            args.bm_model_path, map_location='cpu')['model_state'])
    else:
        bm_state = convert_state_dict(
            torch.load(args.bm_model_path)['model_state'])
    bm_model.load_state_dict(bm_state)
    bm_model.eval()

    if torch.cuda.is_available():
        wc_model.cuda()
        bm_model.cuda()
        images = Variable(img.cuda())
    else:
        images = Variable(img)

    with torch.no_grad():
        wc_outputs = wc_model(images)
        pred_wc = htan(wc_outputs)
        bm_input = F.interpolate(pred_wc, bm_img_size,
                                 mode='bilinear', align_corners=True)
        outputs_bm = bm_model(bm_input)

        # call unwarp
        wc = pred_wc[0].cpu().detach().numpy().transpose((1, 2, 0))
        wc = cv2.resize(wc, (imgorg.shape[1], imgorg.shape[0]))
        mask = (wc * 255).astype(np.uint8)
        _, binary_im = cv2.threshold(mask, 1, 255, cv2.THRESH_BINARY)
        # binary_im = binary_im.astype(np.float32) / 255.0
        out = cv2.bitwise_and(imgorg, binary_im)
        bm = outputs_bm[0].cpu().detach().numpy().transpose((1, 2, 0))
        bm = cv2.resize(bm, (imgorg.shape[1], imgorg.shape[0]))
        uwpred = unwarp(imgorg, outputs_bm)

    if args.show:
        f1, axarr1 = plt.subplots(1, 6)
        axarr1[0].imshow(imgorg)
        axarr1[1].imshow(wc)
        axarr1[2].imshow(bm[:, :, 0])
        axarr1[3].imshow(uwpred)
        axarr1[4].imshow(binary_im, cmap='gray')
        axarr1[5].imshow(out)
        plt.show()

    # Save the output
    mask_out = os.path.join(args.out_path, fname[:-4] + '___.png')
    img_out = os.path.join(args.out_path, fname)
    wc_out = os.path.join(args.out_path, fname[:-4] + '__.png')
    cv2.imwrite(img_out, uwpred[:, :, ::-1]*255)
    cv2.imwrite(wc_out, wc[:, :, ::-1]*255)
    cv2.imwrite(mask_out, out[:, :, ::-1])
Example #13
    img = Image.fromarray(img, "YCbCr").convert("RGB")
    return img


opt = parser.parse_args()
cuda = opt.cuda

if cuda:
    print("=> use gpu id: '{}'".format(opt.gpus))
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
    if not torch.cuda.is_available():
        raise Exception(
            "No GPU found or Wrong gpu id, please run without --cuda")

model = MemNet(1, 64, 6, 6)
state = convert_state_dict(torch.load(opt.model)['model'])
# The model was trained with multiple GPUs (DataParallel), so convert_state_dict() is used to strip the 'module.' prefix from parameter names before loading
model.load_state_dict(state)

im_gt_ycbcr = imread("data/SuperResolution/Set5/" + opt.image + ".bmp",
                     mode="YCbCr")
im_b_ycbcr = imread("data/SuperResolution/Set5/" + opt.image + "_scale_" +
                    str(opt.scale) + ".bmp",
                    mode="YCbCr")

im_gt_y = im_gt_ycbcr[:, :, 0].astype(float)
im_b_y = im_b_ycbcr[:, :, 0].astype(float)

psnr_bicubic = PSNR(im_gt_y, im_b_y, shave_border=opt.scale)

im_input = im_b_y / 255.
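
A minimal sketch of what convert_state_dict is assumed to do (the project's actual helper may differ): it strips the 'module.' prefix that torch.nn.DataParallel prepends to parameter names, so a checkpoint saved from a multi-GPU run can be loaded into a plain single-GPU model.

from collections import OrderedDict

def convert_state_dict(state_dict):
    # Illustrative sketch only, not necessarily the project's implementation.
    # 'module.conv1.weight' -> 'conv1.weight'
    new_state = OrderedDict()
    for k, v in state_dict.items():
        new_state[k[len('module.'):] if k.startswith('module.') else k] = v
    return new_state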
Example #14
if __name__ == "__main__":

    transformations = transforms.Compose([transforms.ToTensor()])

    custom_mnist_from_images =  \
        CustomDatasetFromImagesEval('../data/iris_segmentation/lists/eval.csv')

    mn_dataset_loader = torch.utils.data.DataLoader(
        dataset=custom_mnist_from_images,
        batch_size=1,
        shuffle=False,
        num_workers=8)

    model = segnet()
    ckpt = torch.load("./best_model_22_1200.pkl")
    state = convert_state_dict(ckpt['model_state'])
    model.load_state_dict(state)
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count() -
                                                   1))
    model.cuda()
    with torch.no_grad():
        for i, (images, labels, name) in enumerate(mn_dataset_loader):
            images = Variable(images.cuda())
            labels = Variable(labels.cuda())
            #images = Variable(images)
            #labels = Variable(labels)
            # Forward pass
            outputs = model(images)
            #print(outputs.size())
            img_array = outputs.cpu().numpy()
Example #15
def test(args):

    # Setup image
    print("Read Input Image from : {}".format(args.img_path))
    img = misc.imread(args.img_path)

    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    loader = data_loader(data_path, is_transform=True)
    n_classes = loader.n_classes

    resized_img = misc.imresize(img, (loader.img_size[0], loader.img_size[1]),
                                interp='bicubic')

    img = img[:, :, ::-1]
    img = img.astype(np.float64)
    img -= loader.mean
    img = misc.imresize(img, (loader.img_size[0], loader.img_size[1]))
    img = img.astype(float) / 255.0
    # NHWC -> NCHW
    img = img.transpose(2, 0, 1)
    img = np.expand_dims(img, 0)
    img = torch.from_numpy(img).float()

    # Setup Model
    model = get_model('fcn8s', n_classes)
    state = convert_state_dict(torch.load(args.model_path)['model_state'])
    model.load_state_dict(state)
    model.eval()

    model.cuda(0)
    images = Variable(img.cuda(0), volatile=True)

    outputs = F.softmax(model(images), dim=1)

    if args.dcrf == "True":
        unary = outputs.data.cpu().numpy()
        unary = np.squeeze(unary, 0)
        unary = -np.log(unary)
        unary = unary.transpose(2, 1, 0)
        w, h, c = unary.shape
        unary = unary.transpose(2, 0, 1).reshape(loader.n_classes, -1)
        unary = np.ascontiguousarray(unary)

        resized_img = np.ascontiguousarray(resized_img)

        d = dcrf.DenseCRF2D(w, h, loader.n_classes)
        d.setUnaryEnergy(unary)
        d.addPairwiseBilateral(sxy=5, srgb=3, rgbim=resized_img, compat=1)

        q = d.inference(50)
        mask = np.argmax(q, axis=0).reshape(w, h).transpose(1, 0)
        decoded_crf = loader.decode_segmap(np.array(mask, dtype=np.uint8))
        dcrf_path = args.out_path[:-4] + '_drf.png'
        misc.imsave(dcrf_path, decoded_crf)
        print("Dense CRF Processed Mask Saved at: {}".format(dcrf_path))

    if torch.cuda.is_available():
        model.cuda(0)
        images = Variable(img.cuda(0), volatile=True)
    else:
        images = Variable(img, volatile=True)

    pred = np.squeeze(outputs.data.max(1)[1].cpu().numpy(), axis=0)
    decoded = loader.decode_segmap(pred)
    print('Classes found: ', np.unique(pred))
    misc.imsave(args.out_path, decoded)
    print("Segmentation Mask Saved at: {}".format(args.out_path))