Exemplo n.º 1
0
def test(net, img, hyperparams):
    """
    Test a model on a specific image.

    The image is scanned with a sliding window; windows are fed to ``net`` in
    batches and the network outputs are summed into a per-pixel score map.

    Args:
        net: the model to evaluate; it is switched to eval mode here.
        img: the input image. Its first two axes are the spatial axes
            (the result is allocated as ``img.shape[:2] + (n_classes,)``).
        hyperparams: dict providing ``patch_size``, ``center_pixel``,
            ``batch_size``, ``device``, ``n_classes`` and ``test_stride``.

    Returns:
        numpy array of shape ``img.shape[:2] + (n_classes,)`` holding the
        accumulated (summed, not normalized) network outputs for each pixel.
    """
    net.eval()
    patch_size = hyperparams["patch_size"]
    center_pixel = hyperparams["center_pixel"]
    batch_size, device = hyperparams["batch_size"], hyperparams["device"]
    n_classes = hyperparams["n_classes"]

    kwargs = {
        "step": hyperparams["test_stride"],
        "window_size": (patch_size, patch_size),
    }
    probs = np.zeros(img.shape[:2] + (n_classes,))

    # Number of batches, rounded up: when the window count is not a multiple
    # of batch_size there is one extra batch with the remaining windows, and
    # floor division would make the progress bar overflow its total.
    n_windows = count_sliding_window(img, **kwargs)
    iterations = (n_windows + batch_size - 1) // batch_size
    for batch in tqdm(
        grouper(batch_size, sliding_window(img, **kwargs)),
        total=iterations,
        desc="Inference on the image",
    ):
        with torch.no_grad():
            # Each batch element is (patch, x, y, w, h).
            if patch_size == 1:
                # 1x1 windows: keep only the single pixel's feature vector.
                data = [b[0][0, 0] for b in batch]
                data = np.copy(data)
                data = torch.from_numpy(data)
            else:
                data = [b[0] for b in batch]
                data = np.copy(data)
                # Channels-last patches -> channels-first for the network.
                data = data.transpose(0, 3, 1, 2)
                data = torch.from_numpy(data)
                # Extra singleton axis at dim 1; presumably the model expects
                # a 5-D input (e.g. 3-D convolutions) -- confirm per model.
                data = data.unsqueeze(1)

            indices = [b[1:] for b in batch]
            data = data.to(device)
            output = net(data)
            # Some models return several outputs; the first one is the score.
            if isinstance(output, tuple):
                output = output[0]
            output = output.to("cpu")

            if patch_size == 1 or center_pixel:
                output = output.numpy()
            else:
                # Per-pixel outputs: move the class axis last to match probs.
                output = np.transpose(output.numpy(), (0, 2, 3, 1))
            for (x, y, w, h), out in zip(indices, output):
                if center_pixel:
                    # Only the window's central pixel receives the prediction.
                    probs[x + w // 2, y + h // 2] += out
                else:
                    probs[x : x + w, y : y + h] += out
    return probs
Exemplo n.º 2
0
def test(net, img, hyperparams):
    """
    Test a model on a specific image.

    The image is scanned with a sliding window; windows are fed to ``net`` in
    batches and the network outputs are summed into a per-pixel score map.

    Args:
        net: the model to evaluate; it is switched to eval mode here.
        img: the input image. Its first two axes are the spatial axes
            (the result is allocated as ``img.shape[:2] + (n_classes,)``).
        hyperparams: dict providing 'patch_size', 'center_pixel',
            'batch_size', 'device', 'n_classes' and 'test_stride'.

    Returns:
        numpy array of shape ``img.shape[:2] + (n_classes,)`` holding the
        accumulated (summed, not normalized) network outputs for each pixel.
    """
    net.eval()
    patch_size = hyperparams['patch_size']
    center_pixel = hyperparams['center_pixel']
    batch_size, device = hyperparams['batch_size'], hyperparams['device']
    n_classes = hyperparams['n_classes']

    kwargs = {
        'step': hyperparams['test_stride'],
        'window_size': (patch_size, patch_size)
    }
    probs = np.zeros(img.shape[:2] + (n_classes, ))

    # Number of batches, rounded up: when the window count is not a multiple
    # of batch_size there is one extra batch with the remaining windows, and
    # floor division would make the progress bar overflow its total.
    n_windows = count_sliding_window(img, **kwargs)
    iterations = (n_windows + batch_size - 1) // batch_size
    for batch in tqdm(grouper(batch_size, sliding_window(img, **kwargs)),
                      total=iterations,
                      desc="Inference on the image"):
        with torch.no_grad():
            # Each batch element is (patch, x, y, w, h).
            if patch_size == 1:
                # 1x1 windows: keep only the single pixel's feature vector.
                data = [b[0][0, 0] for b in batch]
                data = np.copy(data)
                data = torch.from_numpy(data)
            else:
                data = [b[0] for b in batch]
                data = np.copy(data)
                # Channels-last patches -> channels-first for the network.
                data = data.transpose(0, 3, 1, 2)
                data = torch.from_numpy(data)
                # data = data.unsqueeze(1)  # enable for 3-D convolution models

            indices = [b[1:] for b in batch]
            data = data.to(device)
            output = net(data)
            # Some models return several outputs; the first one is the score.
            if isinstance(output, tuple):
                output = output[0]
            # Bring the result back to host memory so it can become numpy.
            output = output.to('cpu')

            if patch_size == 1 or center_pixel:
                output = output.numpy()
            else:
                # Per-pixel outputs: move the class axis last to match probs.
                output = np.transpose(output.numpy(), (0, 2, 3, 1))
            for (x, y, w, h), out in zip(indices, output):
                if center_pixel:
                    # Only the window's central pixel receives the prediction.
                    probs[x + w // 2, y + h // 2] += out
                else:
                    probs[x:x + w, y:y + h] += out
    return probs
Exemplo n.º 3
0
def test(net, img, args):
    """
    Test a model on a specific image.

    The image is scanned with a sliding window; windows are fed to ``net`` in
    batches and the network outputs are summed into a per-pixel score map.

    Args:
        net: the model to evaluate; it is switched to eval mode here.
        img: the input image. Its first two axes are the spatial axes
            (the result is allocated as ``img.shape[:2] + (n_classes,)``).
        args: namespace providing ``patch_size``, ``center_pixel``,
            ``batch_size``, ``device`` (passed to ``torch.device``),
            ``n_classes`` and ``test_stride``.

    Returns:
        numpy array of shape ``img.shape[:2] + (n_classes,)`` holding the
        accumulated (summed, not normalized) network outputs for each pixel.
    """
    net.eval()
    patch_size = args.patch_size
    center_pixel = args.center_pixel
    batch_size, device = args.batch_size, torch.device(args.device)
    n_classes = args.n_classes

    kwargs = {
        'step': args.test_stride,
        'window_size': (patch_size, patch_size)
    }
    probs = np.zeros(img.shape[:2] + (n_classes, ))

    # Number of batches, rounded up: when the window count is not a multiple
    # of batch_size there is one extra batch with the remaining windows, and
    # floor division would make the progress bar overflow its total.
    n_windows = utils.count_sliding_window(img, **kwargs)
    iterations = (n_windows + batch_size - 1) // batch_size
    for batch in tqdm(utils.grouper(batch_size,
                                    utils.sliding_window(img, **kwargs)),
                      total=iterations,
                      desc="Inference on the image"):
        with torch.no_grad():
            # Each batch element is (patch, x, y, w, h).
            if patch_size == 1:
                # 1x1 windows: keep only the single pixel's feature vector.
                data = [b[0][0, 0] for b in batch]
                data = np.copy(data)
                data = torch.from_numpy(data)
            else:
                data = [b[0] for b in batch]
                data = np.copy(data)
                # Channels-last patches -> channels-first for the network.
                data = data.transpose(0, 3, 1, 2)
                data = torch.from_numpy(data)
                # Extra singleton axis at dim 1; presumably the model expects
                # a 5-D input (e.g. 3-D convolutions) -- confirm per model.
                data = data.unsqueeze(1)

            indices = [b[1:] for b in batch]
            data = data.to(device)
            output = net(data)
            # Some models return several outputs; the first one is the score.
            if isinstance(output, tuple):
                output = output[0]
            output = output.to('cpu')

            if patch_size == 1 or center_pixel:
                output = output.numpy()
            else:
                # Per-pixel outputs: move the class axis last to match probs.
                output = np.transpose(output.numpy(), (0, 2, 3, 1))
            for (x, y, w, h), out in zip(indices, output):
                if center_pixel:
                    # Only the window's central pixel receives the prediction.
                    probs[x + w // 2, y + h // 2] += out
                else:
                    probs[x:x + w, y:y + h] += out
    return probs
Exemplo n.º 4
0
def test(net,
         test_ids,
         all=False,
         stride=None,
         batch_size=None,
         window_size=None):
    """Run sliding-window inference on test tiles and compute metrics.

    For each id in ``test_ids``, the image is read from ``cfg.DATA_FOLDER``
    (scaled by 1/255) and the eroded label from ``cfg.ERODED_FOLDER``.
    Each image is scanned with a sliding window, the network outputs are
    summed per pixel, and the arg-max over classes is the prediction.

    Args:
        net: model called on a CUDA tensor of channels-first patches.
        test_ids: sized collection of tile ids (``len`` is taken).
        all: if True, also return the predictions and eroded ground truths.
            The name shadows the builtin but is kept for caller compatibility.
        stride: sliding-window step; defaults to ``cfg.WINDOW_WIDTH``.
        batch_size: windows per forward pass; defaults to ``cfg.BATCH_SIZE``.
        window_size: sliding-window size; defaults to ``cfg.WINDOW_SIZE``.

    Returns:
        The value returned by ``metrics`` on all tiles' predictions and
        eroded labels concatenated. With ``all=True``, the tuple
        ``(accuracy, all_preds, all_gts)``.

    Raises:
        ValueError: if ``test_ids`` is empty (there is nothing to score).
    """
    # Default params
    if stride is None:
        stride = cfg.WINDOW_WIDTH
    if batch_size is None:
        batch_size = cfg.BATCH_SIZE
    if window_size is None:
        window_size = cfg.WINDOW_SIZE

    if len(test_ids) == 0:
        raise ValueError("test_ids is empty: nothing to evaluate")

    all_preds = []
    all_gts = []

    # Switch the network to inference mode
    net.eval()

    # Tiles are read one at a time inside the loop so that only the current
    # tile is held in memory (under Python 2, zip() over generators would
    # build its whole result list up front).
    for tile_id in tqdm(test_ids, total=len(test_ids), leave=False):
        img = 1 / 255.0 * np.asarray(
            io.imread(cfg.DATA_FOLDER.format(tile_id)), dtype='float32')
        gt_e = convert_from_color(
            io.imread(cfg.ERODED_FOLDER.format(tile_id)))

        # Accumulator for the summed per-class network outputs
        pred = np.zeros(img.shape[:2] + (cfg.N_CLASSES, ))

        n_windows = count_sliding_window(img, step=stride,
                                         window_size=window_size)
        # Round up: the trailing partial batch is still a batch.
        total = (n_windows + batch_size - 1) // batch_size
        print("Total batches in image: {}".format(total))
        for i, coords in enumerate(
                tqdm(grouper(
                    batch_size,
                    sliding_window(img, step=stride, window_size=window_size)),
                     total=total,
                     leave=False)):
            # i batches have been processed before this one
            print("{} of {} done....".format(i, total))

            # Build the tensor of channels-first patches
            image_patches = np.asarray([
                np.copy(img[x:x + w, y:y + h]).transpose((2, 0, 1))
                for x, y, w, h in coords
            ])
            # NOTE(review): `volatile=True` is the pre-0.4 PyTorch way to
            # disable autograd; on PyTorch >= 0.4 it is ignored (with a
            # warning) and `with torch.no_grad():` should be used instead.
            image_patches = Variable(torch.from_numpy(image_patches).cuda(),
                                     volatile=True)

            # Do the inference
            outs = net(image_patches)
            outs = outs.data.cpu().numpy()

            # Fill in the results array (class axis moved last)
            for out, (x, y, w, h) in zip(outs, coords):
                pred[x:x + w, y:y + h] += out.transpose((1, 2, 0))
            del outs

        pred = np.argmax(pred, axis=-1)
        all_preds.append(pred)
        all_gts.append(gt_e)

        # Metrics for this tile, then cumulative metrics over all tiles so far
        metrics(pred.ravel(), gt_e.ravel())
        accuracy = metrics(
            np.concatenate([p.ravel() for p in all_preds]),
            np.concatenate([p.ravel() for p in all_gts]))
    if all:
        return accuracy, all_preds, all_gts
    return accuracy
Exemplo n.º 5
0
def pred_and_display(net,
                     test_ids,
                     stride=None,
                     batch_size=None,
                     window_size=None):
    """Predict each test tile with a sliding window and plot the result.

    For each id in ``test_ids``, the image is read from ``cfg.DATA_FOLDER``
    (scaled by 1/255), scanned with a sliding window, and the arg-max of the
    summed network outputs is shown next to the input image.

    Args:
        net: model called on a CUDA tensor of channels-first patches.
        test_ids: sized collection of tile ids (``len`` is taken).
        stride: sliding-window step; defaults to ``cfg.WINDOW_WIDTH``.
        batch_size: windows per forward pass; defaults to ``cfg.BATCH_SIZE``.
        window_size: sliding-window size; defaults to ``cfg.WINDOW_SIZE``.

    Returns:
        list with one arg-max prediction array per tile, in input order.
    """
    # Default params
    if stride is None:
        stride = cfg.WINDOW_WIDTH
    if batch_size is None:
        batch_size = cfg.BATCH_SIZE
    if window_size is None:
        window_size = cfg.WINDOW_SIZE

    # Lazy generator: one tile is read per loop iteration.
    test_images = (
        1 / 255.0 *
        np.asarray(io.imread(cfg.DATA_FOLDER.format(tile_id)), dtype='float32')
        for tile_id in test_ids)

    all_preds = []
    net.eval()

    for img in tqdm(test_images, total=len(test_ids), leave=False):
        # Accumulator for the summed per-class network outputs
        pred = np.zeros(img.shape[:2] + (cfg.N_CLASSES, ))

        n_windows = count_sliding_window(img, step=stride,
                                         window_size=window_size)
        # Round up: the trailing partial batch is still a batch.
        total = (n_windows + batch_size - 1) // batch_size
        for coords in tqdm(grouper(
                batch_size,
                sliding_window(img, step=stride, window_size=window_size)),
                           total=total,
                           leave=False):
            image_patches = np.asarray([
                np.copy(img[x:x + w, y:y + h]).transpose((2, 0, 1))
                for x, y, w, h in coords
            ])
            # NOTE(review): `volatile=True` is the pre-0.4 PyTorch way to
            # disable autograd; on PyTorch >= 0.4 it is ignored (with a
            # warning) and `with torch.no_grad():` should be used instead.
            image_patches = Variable(torch.from_numpy(image_patches).cuda(),
                                     volatile=True)

            # Do the inference
            outs = net(image_patches)
            outs = outs.data.cpu().numpy()

            # Fill in the results array (class axis moved last)
            for out, (x, y, w, h) in zip(outs, coords):
                pred[x:x + w, y:y + h] += out.transpose((1, 2, 0))
            del outs

        pred = np.argmax(pred, axis=-1)

        # Display the input image and the colour-coded prediction
        fig = plt.figure()
        fig.add_subplot(1, 2, 1)
        plt.imshow(np.asarray(255 * img, dtype='uint8'))
        fig.add_subplot(1, 2, 2)
        plt.imshow(convert_to_color(pred))
        plt.show()
        # Release the figure so figures do not pile up across tiles.
        plt.close(fig)

        all_preds.append(pred)
    return all_preds