Example #1
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5):
    net.eval()

    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))

    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)

        if net.num_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)

        probs = probs.squeeze(0)

        tf = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(full_img.size[1]),
            transforms.ToTensor()
        ])

        probs = tf(probs.cpu())
        full_mask = probs.squeeze().cpu().numpy()

    return full_mask > out_threshold
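
A minimal usage sketch for a function like the one above. The UNet import path, constructor arguments, checkpoint path, and file names are illustrative assumptions, not part of the example; the model class only needs to expose the attribute predict_img checks.

import numpy as np
import torch
from PIL import Image
from unet import UNet  # assumed model class and import path

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet(n_channels=3, n_classes=1)  # illustrative channel/class counts
net.load_state_dict(torch.load('checkpoints/model.pth', map_location=device))
net.to(device=device)

full_img = Image.open('input.png')
mask = predict_img(net, full_img, device, scale_factor=0.5, out_threshold=0.5)

# predict_img returns a boolean H x W array; save it as an 8-bit mask image.
Image.fromarray((mask * 255).astype(np.uint8)).save('mask_out.png')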
Example #2
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5):
    net.eval()

    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))

    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)

        print(output.shape)
        #if net.n_classes > 1:
        #    probs = F.softmax(output, dim=1)
        #else:
        #    probs = torch.sigmoid(output)
        probs_0 = torch.sigmoid(output[:, 0, :, :])
        probs_1 = torch.sigmoid(output[:, 1, :, :])

        probs_0 = probs_0.squeeze(0)
        probs_1 = probs_1.squeeze(0)

        tf = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(full_img.width),  #size[1]),
            transforms.ToTensor()
        ])

        probs_0 = tf(probs_0.cpu())
        probs_1 = tf(probs_1.cpu())
        mask_0 = probs_0.squeeze().cpu().numpy()
        mask_1 = probs_1.squeeze().cpu().numpy()
        full_mask = np.array([mask_0, mask_1])  #probs.squeeze().cpu().numpy()

    return full_mask  # > out_threshold
Example #3
def predict_img(net, full_img, device, scale_factor=1, use_dense_crf=False):
    net.eval()
    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))

    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)
        #output,aux = net(img)
        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)

        probs = probs.squeeze(0)

        tf = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(full_img.shape[1]),
            transforms.ToTensor()
        ])

        probs = tf(probs.cpu())
        full_mask = probs.squeeze().cpu().numpy()

    if use_dense_crf:
        full_mask = dense_crf(np.array(full_img).astype(np.uint8), full_mask)

    # return full_mask > out_threshold
    return np.argmax(full_mask, axis=0)
Example #4
def predict_img(net, full_img, device, scale_factor=0.267, out_threshold=0.6):
    net.eval()

    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
    img_pro = (img).squeeze(0).numpy()
    img_show = sitk.GetImageFromArray(img_pro)
    #sitk.WriteImage(img_show, './data/pred/scale0.2_002input.nii')
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)
        max_out = output.cpu().numpy().max()

        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)
            max0 = probs.cpu().numpy().max()

        probs = output.squeeze()

        tf = transforms.Compose([
            transforms.ToPILImage(),
            #transforms.Resize((481,481,481)),
            transforms.ToTensor()
        ])

        #probs = tf(probs.cpu())
        full_mask = probs.squeeze().cpu().numpy()
    return full_mask
Example #5
def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5,
                use_dense_crf=False):
    net.eval()

    ds = BasicDataset('data/training/img/',
                      'data/training/full_mask/',
                      scale=scale_factor)
    img = ds.preprocess(full_img)
    img = torch.from_numpy(img)
    img = torch.unsqueeze(img, 0)
    img = img.to(device=device, dtype=torch.float32)
    with torch.no_grad():
        output = net(img)
        probs = torch.sigmoid(output)
        probs = probs.squeeze(0)

        tf = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(full_img.size[1]),
            transforms.ToTensor()
        ])

        probs = tf(probs.cpu())
        full_mask = probs.squeeze().cpu().numpy()

    if use_dense_crf:
        full_mask = dense_crf(np.array(full_img).astype(np.uint8), full_mask)

    return full_mask > out_threshold
Example #6
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5):
    net.eval()

    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))

    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        logging.info(f"Input: {img.shape}")
        output = net(img)
        logging.info(f"Output: {output.shape}")
        
        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)

        probs = probs.squeeze(0)

        tf = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(full_img.size[1]),
            transforms.ToTensor()]
        )

        probs = tf(probs.cpu())
        full_mask = probs.squeeze().cpu().numpy()
        result = full_mask > out_threshold
        logging.info(f"Result shape: {result.shape}")
        # logging.info("Result: {result}")
        k = result.sum()  # number of foreground pixels
        logging.info(f"Foreground: {k}")
    return result
Example #7
def predict_img(net, full_img, device, resizing=572, out_threshold=0.5):

    net.eval()  # evaluation mode: disables dropout and uses batch-norm running statistics

    img = torch.from_numpy(BasicDataset.preprocess(full_img, resizing))

    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)

        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)

        probs = probs.squeeze(0)

        tf = transforms.Compose(
            [transforms.ToPILImage(),
             transforms.ToTensor()])

        probs = tf(probs.cpu())
        full_mask = probs.squeeze().cpu().numpy()

    return full_mask > out_threshold
Example #8
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5):
    net.eval()

    # Preprocess the image and convert it to a tensor
    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))

    # unsqueeze adds a batch dimension
    img = img.unsqueeze(0)
    # Move to the selected device and cast to float32
    img = img.to(device=device, dtype=torch.float32)

    # PyTorch tracks operations for autograd by default; torch.no_grad() disables gradient tracking for this block
    with torch.no_grad():
        output = net(img)

        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)

        # squeeze removes the batch dimension
        probs = probs.squeeze(0)

        # Transform via a PIL image
        tf = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(full_img.size[1]),
            transforms.ToTensor()
        ])

        probs = tf(probs.cpu())
        full_mask = probs.squeeze().cpu().numpy()

    return full_mask > out_threshold
Example #9
def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5):
    net.eval()

    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))

    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)
        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)
        probs = probs.squeeze(0)
        full_mask = probs.cpu().numpy()
        _, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(1, 5, sharey=True)
        ax1.imshow(full_mask[0,:,:].squeeze())
        ax2.imshow(full_mask[1,:,:].squeeze())
        ax3.imshow(full_mask[2,:,:].squeeze())
        ax4.imshow(full_mask[3,:,:].squeeze())
        ax5.imshow(full_mask[4,:,:].squeeze())
        plt.show()
        full_mask = np.argmax(full_mask, 0)
    return full_mask
Example #10
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5):
    net.eval()

    data_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.6976, 0.7220, 0.7588],
                             std=[0.1394, 0.1779, 0.2141])
    ])

    img = BasicDataset.preprocess(full_img, scale_factor, 'img')
    img = data_transforms(img)

    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)

        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)

        probs = probs.squeeze(0)

        tf = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(full_img.size[1]),
            transforms.ToTensor()
        ])

        probs = tf(probs.cpu())
        full_mask = probs.squeeze().cpu().numpy()

    return full_mask > out_threshold
Example #11
def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5,
                use_dense_crf=False):
    net.eval()

    ds = BasicDataset('', '', scale=scale_factor)
    img = torch.from_numpy(ds.preprocess(full_img))
    print('28 img', img.shape)
    img = img.unsqueeze(0)
    print('30 img', img.shape)
    img = img.to(device=device, dtype=torch.float32)

    # traced_script_module = torch.jit.trace(net, img)
    # save the model
    # traced_script_module.save("torch_script_eval.pth")

    with torch.no_grad():
        print('38 input into net', img.shape)
        output = net(img)
        print('40 output shape', output.shape)

        if net.module.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)

        probs = probs.squeeze(0)
        print('48 probs shape', probs.shape)

        tf = transforms.Compose([
            transforms.ToPILImage(),
            # transforms.Resize(np.array(full_img).shape[1]),
            transforms.ToTensor()
        ])

        probs = tf(probs.cpu())
        full_mask = probs.squeeze().cpu().numpy()

    # if use_dense_crf:
    #     full_mask = dense_crf(np.array(full_img).astype(np.uint8), full_mask)

    for c in range(4):
        full_mask[c][full_mask[c] > out_threshold] = 1.0
        full_mask[c][full_mask[c] <= out_threshold] = 0.0
    return full_mask
Example #12
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5):
    net.eval()

    img = torch.from_numpy(
        BasicDataset.preprocess(full_img, scale_factor, False))

    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)
        output_seg = output.max(dim=1)[1].unsqueeze(1)
        output_seg = output_seg.data.cpu().numpy()
        return output_seg[0, 0, :, :]
Example #13
def plot_imgs_pred():
    """
    Funzione
    :return:
    """
    args = get_plot_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # img = Image.open(args.dir_img)
    img = tiff.imread(args.dir_img)
    img = BasicDataset.preprocess(img, scale=args.scale)
    img = torch.from_numpy(img).type(torch.FloatTensor)

    if args.model_arch == 'unet':
        net = UNet(n_channels=4, n_classes=4, bilinear=True)

    elif args.model_arch == 'icnet':
        net = ICNet(n_channels=4, n_classes=4, pretrained_base=False)

    net.load_state_dict(torch.load(args.checkpoint_net, map_location=device))
    net.to(device=device)
    net.eval()

    img = img.to(device=device, dtype=torch.float32)
    img = img.unsqueeze(0)

    with torch.no_grad():
        if args.model_arch == 'icnet':
            mask_pred, pred_sub4, pred_sub8, pred_sub16 = net(img)
        else:
            mask_pred = net(img)

    plt.imshow(img[0][0].cpu())
    plt.colorbar()
    plt.savefig(args.dir_output + "original_img.png")
    plt.clf()

    for i, c in enumerate(mask_pred):
        n_classes = c.size(0)
        classes = range(n_classes)
        c = torch.sigmoid(c)
        max_index = torch.max(c, 0).indices
        for class_index in classes:
            # Visualize the prediction
            jaccard_input = (max_index == class_index).float()
            plt.imshow(jaccard_input.cpu())
            plt.colorbar()
            plt.savefig(args.dir_output + f"pred_cls_{class_index}.png")
            plt.clf()
Example #14
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5):
    net.eval()

    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))

    # img_input = transforms.ToPILImage()(img.type(torch.float32)).convert('RGB')
    # img_input.save('input.jpg')
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)

        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)

        probs = probs.squeeze(0)

        _, probs = torch.max(probs, dim=0)
        print(f'probs.max():{probs.max()}')
        probs = probs.unsqueeze(0) * 10
        probs = probs.type(torch.float32)

        # tf = transforms.Compose(
        #     [
        #         transforms.ToPILImage(),
        #         transforms.Resize(full_img.size[1]),
        #         transforms.ToTensor()
        #     ]
        # )
        tf = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(full_img.size[1]),
            transforms.ToTensor()
        ])

        probs = tf(probs.cpu())
        full_mask = probs.squeeze().cpu().numpy()

    return full_mask > out_threshold
Example #15
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.0):
    net.eval()

    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))

    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)
        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)
        probs = probs.squeeze(0)
        print(probs)
        full_mask = probs.cpu().numpy()
        print(type(full_mask))
        print("*******************************************************")
        """
        _, (ax1, ax2, ax3, ax4, ax5,ax6,ax7,ax8,ax9,ax10,ax11,ax12,ax13,ax14) = plt.subplots(1, 14, sharey=True)
        ax1.imshow(full_mask[0,:,:].squeeze())
        ax2.imshow(full_mask[1,:,:].squeeze())
        ax3.imshow(full_mask[2,:,:].squeeze())
        ax4.imshow(full_mask[3,:,:].squeeze())
        ax5.imshow(full_mask[4,:,:].squeeze())
        ax6.imshow(full_mask[5, :, :].squeeze())
        ax7.imshow(full_mask[6, :, :].squeeze())
        ax8.imshow(full_mask[7, :, :].squeeze())
        ax9.imshow(full_mask[8, :, :].squeeze())
        ax10.imshow(full_mask[9, :, :].squeeze())
        ax11.imshow(full_mask[10, :, :].squeeze())
        ax12.imshow(full_mask[11, :, :].squeeze())
        ax13.imshow(full_mask[12, :, :].squeeze())
        ax14.imshow(full_mask[13, :, :].squeeze())
        """

        print(full_mask)
        full_mask = np.argmax(full_mask, axis=0)
        print("--***********************************************")
        print(full_mask.shape)
    return full_mask
Example #16
def inference_one(net, image, device):
    net.eval()

    img = torch.from_numpy(BasicDataset.preprocess(image, cfg.scale))

    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)
        if cfg.deepsupervision:
            output = output[-1]

        if cfg.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)

        probs = probs.squeeze(0)        # C x H x W

        tf = transforms.Compose(
                [
                    transforms.ToPILImage(),
                    transforms.Resize((image.size[1], image.size[0])),
                    transforms.ToTensor()
                ]
        )

        if cfg.n_classes == 1:
            probs = tf(probs.cpu())
            mask = probs.squeeze().cpu().numpy()
            return mask > cfg.out_threshold
        else:
            masks = []
            for prob in probs:
                prob = tf(prob.cpu())
                mask = prob.squeeze().cpu().numpy()
                mask = mask > cfg.out_threshold
                masks.append(mask)
            return masks
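
One way to collapse the per-class mask list from the multi-class branch above into a single label map. This is a sketch; it assumes masks is the list of boolean H x W arrays returned by inference_one when cfg.n_classes > 1.

import numpy as np

masks = inference_one(net, image, device)  # list of boolean H x W arrays, one per class
label_map = np.zeros(masks[0].shape, dtype=np.uint8)
for class_idx, class_mask in enumerate(masks):
    # Where the thresholded masks overlap, the higher class index wins.
    label_map[class_mask] = class_idx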
Example #17
def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5):
    net.eval()

    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))

    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        # output = net(img)
        centerlines, points = net(img)

        if net.n_classes > 1:
            probs1 = F.softmax(centerlines, dim=1)
            probs2 = F.softmax(points, dim=1)
        else:
            probs1 = torch.sigmoid(centerlines)
            probs2 = torch.sigmoid(points)

        probs1 = probs1.squeeze(0)
        probs2 = probs2.squeeze(0)

        tf = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.Resize(full_img.size[1]),
                transforms.ToTensor()
            ]
        )

        probs1 = tf(probs1.cpu())
        probs2 = tf(probs2.cpu())
        full_centerlines = probs1.squeeze().cpu().numpy()
        full_points = probs2.squeeze().cpu().numpy()

    return full_centerlines > out_threshold, full_points > out_threshold
Example #18
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.0):
    net.eval()

    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
    print(img.shape)
    img = img.unsqueeze(0)
    print(img.shape)
    img = img.to(device=device, dtype=torch.float32)
    with torch.no_grad():
        output = net(img)
        print(output)
        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
            print(probs.shape)
        else:
            probs = torch.sigmoid(output)

        probs = probs.squeeze(0)

        print("probsss", probs)

        tf = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(full_img.size[1]),
            transforms.ToTensor()
        ])

        probs = tf(probs.cpu())

        sum1 = 0
        full_mask = probs.squeeze().cpu().numpy()
        """
        for i in range(0,375):
            sum1 += (full_mask[0][0][i])
        print(sum1)
        """
        print("*******", full_mask)
        print(out_threshold)
    return np.argmax(full_mask, axis=0)
Example #19
def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5):
    net.eval()

    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))

    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)

        # if net.n_classes > 1:
        #     probs = F.softmax(output, dim=1)
        # else:
        #     probs = torch.sigmoid(output)
        probs = torch.sigmoid(output)
        probs = probs.squeeze(0)

        # tf = transforms.Compose(
        #     [
        #         transforms.ToPILImage(),
        #         transforms.Resize(full_img.size[1]),
        #         transforms.ToTensor()
        #     ]
        # )
        #
        # probs = tf(probs.cpu())
        full_mask = probs.squeeze().cpu().numpy()

        import matplotlib.pyplot as plt
        plt.figure()
        plt.imshow(full_mask[2:, :, :].argmax(0))
        plt.colorbar()
        plt.show()
    return full_mask #> out_threshold
Example #20
def prediction(net, imgs, device):
    net.eval()
    ds = BasicDataset('data/training/img',
                      'data/training/full_mask',
                      scale=0.5)
    img = ds.preprocess(imgs)
    img = torch.from_numpy(img)
    img = torch.unsqueeze(img, 0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)
        probs = torch.sigmoid(output)
        probs = probs.squeeze(0)

        tf = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(imgs.size[1]),
            transforms.ToTensor()
        ])
        probs = tf(probs.cpu())
        full_mask = probs.squeeze().cpu().numpy()
    return full_mask > 0.5
Example #21
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.3):
    net.eval()

    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))

    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)

        probs = output.squeeze(0)

        tf = transforms.Compose([
            transforms.ToPILImage(),
            # note: size and resize are not defined in this function; they must
            # exist at module scope, otherwise use the commented line below instead
            transforms.Resize(size if resize else full_img.size[1]),
            # transforms.Resize(full_img.size[1]),
            transforms.ToTensor()
        ])
        probs = tf(probs.cpu())
        full_mask = probs.squeeze().cpu().numpy()

    return full_mask > out_threshold
Example #22
def predict_img(net,
                full_img,
                device,
                file_name,
                scale_factor=1,
                out_threshold=0.5):
    net.eval()

    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))

    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)['out']

        if net.num_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)

        probs = probs.squeeze(0)
        save_array = probs.cpu().numpy()
        mat_prob = np.reshape(save_array, [300, 300])
        save_fn = ('D:/users/otis/MedicalImage_Project02_Segmentation/MedicalImage_Project02_Segmentation/'
                   'private_data_10/private_data_10/Results' + file_name[:-4] + '_prob.mat')
        sio.savemat(save_fn, {'array': mat_prob})
        tf = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(full_img.size[1]),
            transforms.ToTensor()
        ])

        probs = tf(probs.cpu())
        full_mask = probs.squeeze().cpu().numpy()

    return full_mask > out_threshold
Example #23
def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5,
                outf=None):
    net.eval()
    transform = transforms.Compose([transforms.Resize((128, 128))])
    img = torch.from_numpy(
        BasicDataset.preprocess(full_img, scale_factor, transform))
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)

        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)

        probs = probs.squeeze(0)

        tf = transforms.Compose([
            transforms.ToPILImage(),
            #transforms.Resize(full_img.size[1]),
            transforms.ToTensor()
        ])

        probs = tf(probs.cpu())
        full_mask = probs.squeeze().cpu().numpy()

    outfn = outf.split('/')
    outtrain_fn = "./data/predicted/train128_" + outfn[3]
    save_image(img, outtrain_fn)

    return full_mask > out_threshold
Example #24
        logging.info(f'\tPredicting on image {idx+1} (or {img_idx}) ...')

        image_addr = config.TEST_RGB_DIR_PATH + img_idx + config.RGB_IMG_EXT
        img = Image.open(image_addr)

        depth_addr = config.TEST_DEPTH_DIR_PATH + img_idx + config.DEPTH_IMG_EXT
        depth = Image.open(depth_addr)
        depth = Image.fromarray(skimage.color.gray2rgb(np.array(depth)))

        gt_mask_addr = config.TEST_MASKS_DIR_PATH + img_idx + config.GT_MASK_EXT
        gt = Image.open(gt_mask_addr)

        # preprocess
        img = BasicDataset.crop(img, config.CROP_H, config.CROP_H, is_img=True)
        img = BasicDataset.preprocess(Image.fromarray(img),
                                      args.scale,
                                      is_img=True)
        depth = BasicDataset.crop(depth,
                                  config.CROP_H,
                                  config.CROP_H,
                                  is_img=True)
        depth = BasicDataset.preprocess(Image.fromarray(depth),
                                        args.scale,
                                        is_img=True)
        gt = BasicDataset.crop(gt, config.CROP_H, config.CROP_H, is_img=False)
        gt = Image.fromarray(gt)

        pred_gt_mask_addr = config.TEST_PRED_DIR_PATH + img_idx + config.GT_SAVE_MASK_EXT
        gt.save(pred_gt_mask_addr)

        pred_mask_addr = config.TEST_PRED_DIR_PATH + img_idx + config.PRED_MASK_EXT
            hsv[..., 2] = mag
            int_mask = cv2.cvtColor(hsv, cv2.COLOR_BGR2RGB)
            int_mask = Image.fromarray(np.uint8(int_mask))

            # predict mask using flowUNet
            mask_pred, mask = predict_img(net=net,
                               int_mask=int_mask,
                               org_img=org_img,
                               scale_factor=args.scale,
                               out_threshold=args.mask_threshold,
                               device=device)
            mask_pred = torch.from_numpy(mask_pred).type(torch.FloatTensor)
            mask_pred = mask_pred.to(device=device, dtype=torch.float32)
            pred = torch.sigmoid(mask_pred)
            pred = (pred > 0.5).float()
            true_mask = BasicDataset.preprocess(true_mask, 1)
            true_mask = torch.from_numpy(true_mask).type(torch.FloatTensor)
            true_mask = true_mask.to(device=device, dtype=torch.float32)

            end_time = time.time()
            total_time += end_time - start_time

            print("Time taking for predicting:", str(end_time-start_time))

            if not args.no_eval:
                dc = dice_coeff(pred.unsqueeze(0).unsqueeze(0), true_mask.unsqueeze(0)).item()
                tot += dc
                print("Dice Coefficient: " + str(dc))

            if not args.no_save:
                output_file = args.output + 'output_' + img_files[i]