Ejemplo n.º 1
0
def save_results_to_disk(p, val_loader, model, crf_postprocess=False):
    """Predict a label map for every validation sample and write it as a PNG.

    For each sample the per-pixel argmax over ``model``'s output is taken
    (optionally after ``dense_crf`` refinement), resized with nearest-neighbour
    interpolation to the size stored in ``meta['im_size']`` and written to
    ``p['save_dir']/<meta['image']>.png``.

    Args:
        p: Mapping that provides ``'save_dir'``.
        val_loader: Loader yielding dicts with ``'image'`` and ``'meta'``;
            ``meta`` provides ``'image_file'``, ``'image'`` and ``'im_size'``.
            ``val_loader.dataset`` must support ``len``.
        model: Network called on the CUDA image batch.
        crf_postprocess: When true, refine each prediction with
            ``utils.crf.dense_crf`` before taking the argmax.
    """
    print('Save results to disk ...')
    model.eval()

    # Imported lazily so the CRF dependency is only required when requested.
    if crf_postprocess:
        from utils.crf import dense_crf

    total = len(val_loader.dataset)
    counter = 0
    # Inference only: do not build an autograd graph for the forward passes.
    with torch.no_grad():
        for batch in val_loader:
            output = model(batch['image'].cuda(non_blocking=True))
            meta = batch['meta']
            for jj in range(output.shape[0]):
                counter += 1
                image_file = meta['image_file'][jj]

                if crf_postprocess:
                    # CRF post-process
                    probs = dense_crf(image_file, output[jj])
                    pred = np.argmax(probs, axis=0).astype(np.uint8)
                else:
                    # Regular argmax over the first dimension of the sample
                    pred = torch.argmax(output[jj], dim=0).cpu().numpy().astype(np.uint8)

                # cv2 expects dsize as plain ints in (width, height) order.
                # NOTE(review): presumably im_size is (height, width) as collated
                # by the loader - confirm against the dataset.
                width = int(meta['im_size'][1][jj])
                height = int(meta['im_size'][0][jj])
                result = cv2.resize(pred, dsize=(width, height),
                                    interpolation=cv2.INTER_NEAREST)
                imageio.imwrite(os.path.join(p['save_dir'], meta['image'][jj] + '.png'), result)

                # Checked per sample so the message appears every 250 samples
                # regardless of the batch size.
                if counter % 250 == 0:
                    print('Saving results: {} of {} objects'.format(counter, total))
Ejemplo n.º 2
0
def predict_img(net, full_img, gpu=False):
    """Predict a boolean mask for ``full_img`` with ``net`` and a dense-CRF pass.

    The image is passed through the file's ``resize`` and ``normalize``
    helpers, fed to ``net``, squashed with a sigmoid, bilinearly upsampled by
    a factor of two and finally refined by ``dense_crf``.

    Args:
        net: Network called on a ``(1, C, H, W)`` float tensor.
        full_img: Input image; must be convertible with ``np.array``.
        gpu: When true, the input tensor is moved to CUDA before inference.

    Returns:
        The result of ``dense_crf(...) > 0.5``.
    """
    img = resize(full_img)
    img = np.array(img)
    img = torch.FloatTensor(img)

    x = img.permute(2, 0, 1).contiguous()  # transform to (C x H x W)

    # Add the batch dimension without hard-coding the spatial size, so any
    # output size of ``resize`` is accepted and memory is never reinterpreted.
    x = x.unsqueeze(0)  # image (N x C x H x W)

    if gpu:
        x = x.cuda()

    # The whole forward pass runs without autograd bookkeeping.
    with torch.no_grad():
        x = normalize(x)  # normalize values to [0, 1]
        x = net(x)  # feed into the net
        x = torch.sigmoid(x)
        # Documented replacement for the deprecated F.upsample_bilinear.
        x = F.interpolate(x, scale_factor=2, mode='bilinear',
                          align_corners=True)

    # First channel of the first (only) batch element, as a NumPy array.
    x = x[0][0].cpu().numpy()  # rescale the image to full size

    yy = dense_crf(np.array(full_img).astype(np.uint8), x)

    return yy > 0.5
Ejemplo n.º 3
0
def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5,
                use_dense_crf=False):
    """Predict a boolean mask for ``full_img`` at the image's own resolution.

    Args:
        net: Network producing logits for a ``(1, C, H, W)`` float tensor.
        full_img: Input image exposing ``.size`` as a 2-tuple and convertible
            with ``np.array`` (assumed to be a PIL image, i.e. ``size`` is
            ``(width, height)`` - confirm against callers).
        device: Device the input tensor is moved to.
        scale_factor: Forwarded to ``BasicDataset`` as ``scale``.
        out_threshold: Values strictly above this are reported as foreground.
        use_dense_crf: When true, refine the mask with ``dense_crf``.

    Returns:
        Boolean array ``full_mask > out_threshold``.
    """
    net.eval()

    # The dataset object is only built to reuse its ``preprocess`` method.
    ds = BasicDataset('data/training/img/',
                      'data/training/full_mask/',
                      scale=scale_factor)
    img = ds.preprocess(full_img)
    img = torch.from_numpy(img)
    img = torch.unsqueeze(img, 0)
    img = img.to(device=device, dtype=torch.float32)
    with torch.no_grad():
        output = net(img)
        probs = torch.sigmoid(output)
        probs = probs.squeeze(0)

        # Resize to an explicit (height, width). A bare int would only fix the
        # shorter edge, which yields a wrongly sized mask for portrait images.
        tf = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((full_img.size[1], full_img.size[0])),
            transforms.ToTensor()
        ])

        probs = tf(probs.cpu())
        full_mask = probs.squeeze().cpu().numpy()

    if use_dense_crf:
        full_mask = dense_crf(np.array(full_img).astype(np.uint8), full_mask)

    return full_mask > out_threshold
Ejemplo n.º 4
0
def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5,
                use_dense_crf=False):
    """Predict a boolean mask for a single-channel image at 320x240.

    The image is resized to 320x240, values of 1200 and above are zeroed, the
    result is divided by its maximum and fed to ``net`` as a ``(1, 1, H, W)``
    float tensor.

    Args:
        net: Network exposing ``n_classes``; softmax is used when it is
            greater than one, sigmoid otherwise.
        full_img: Image exposing ``.resize`` and convertible with ``np.array``.
        device: Device the input tensor is moved to.
        scale_factor: Unused here; kept for interface compatibility.
        out_threshold: Values strictly above this are reported as foreground.
        use_dense_crf: When true, refine the mask with ``dense_crf``.

    Returns:
        Boolean array ``full_mask > out_threshold``.
    """
    net.eval()

    img = full_img.resize((320, 240))
    img = np.array(img)
    # Zero out every value that is not below 1200.
    img = (img < 1200) * img
    # Scale by the maximum; skip the division for an all-zero image, which
    # would otherwise fill the network input with NaNs.
    peak = img.max()
    if peak != 0:
        img = img / peak
    img = np.expand_dims(img, axis=2)
    img = img.transpose((2, 0, 1))
    img = torch.from_numpy(img.astype(np.float32))

    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)

        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)

        probs = probs.squeeze(0)

        tf = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((240, 320)),
            transforms.ToTensor()
        ])

        probs = tf(probs.cpu())
        full_mask = probs.squeeze().cpu().numpy()

    if use_dense_crf:
        # NOTE(review): full_img is passed at its original size while the mask
        # is 240x320 - presumably callers only pass 320x240 images; confirm.
        full_mask = dense_crf(np.array(full_img).astype(np.uint8), full_mask)

    return full_mask > out_threshold
Ejemplo n.º 5
0
def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5,
                use_dense_crf=False):
    """Predict a boolean mask for ``full_img`` at the image's own resolution.

    Args:
        net: Network exposing ``n_classes``; softmax is used when it is
            greater than one, sigmoid otherwise.
        full_img: Input image exposing ``.size`` as a 2-tuple and convertible
            with ``np.array`` (assumed to be a PIL image, i.e. ``size`` is
            ``(width, height)`` - confirm against callers).
        device: Device the input tensor is moved to.
        scale_factor: Forwarded to ``BasicDataset.preprocess``.
        out_threshold: Values strictly above this are reported as foreground.
        use_dense_crf: When true, refine the mask with ``dense_crf``.

    Returns:
        Boolean array ``full_mask > out_threshold``.
    """
    net.eval()

    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))

    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)

        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)

        probs = probs.squeeze(0)

        # Resize to an explicit (height, width). A bare int would only fix the
        # shorter edge, which yields a wrongly sized mask for portrait images.
        tf = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((full_img.size[1], full_img.size[0])),
            transforms.ToTensor()
        ])

        probs = tf(probs.cpu())
        full_mask = probs.squeeze().cpu().numpy()

    if use_dense_crf:
        full_mask = dense_crf(np.array(full_img).astype(np.uint8), full_mask)

    return full_mask > out_threshold
Ejemplo n.º 6
0
def main():
    """Load a W-Net checkpoint, run it on one sample image and plot the CRF labels.

    The encoder output of the network is passed, together with a 64x64 copy
    of the same image, to ``dense_crf``; the per-pixel argmax of the result
    is displayed with matplotlib.
    """
    args = parser.parse_args()

    # Build the network and restore its weights on the CPU.
    model = WNet.WNet(args.squeeze)
    state_dict = torch.load(args.model, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()

    to_small_tensor = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])

    pil_image = Image.open("data2/images/train/1head.png").convert('RGB')
    # Prepend a batch dimension of size one.
    x = to_small_tensor(pil_image).unsqueeze(0)

    enc, _ = model(x)
    show_image(x[0])
    # Only the first three encoder channels are shown.
    show_image(enc[0, :3, :, :].detach())

    # All encoder channels of the single sample go into the CRF.
    segment = enc[0, :, :, :].detach()

    raw_image = imread("data2/images/train/1head.png")
    small_image = resize(raw_image, (64, 64))
    Q = dense_crf(small_image, segment.numpy())

    print(type(Q))
    Q = np.argmax(Q, axis=0)
    print(len(Q))

    print(np.unique(Q))
    plt.imshow(Q)
    plt.show()
Ejemplo n.º 7
0
          (best_iou, (best_iou * 2) / (best_iou + 1)))

else:
    raise ValueError("can't find model")

# Toggle dense-CRF refinement of the softmax output.
crf = True

with torch.no_grad():
    img, mask = img.to(device), mask.to(device)
    output = model(img)  # [1, 9, 256, 256] per the original author's note
    probs = F.softmax(output, dim=1)
    if crf:
        pred_crf = probs.cpu().data[0].numpy()
        # The CRF receives the first image of the batch rescaled by 255.
        img = img.cpu().data[0].numpy()
        pred_crf = dense_crf(img * 255, pred_crf)
        # ``np.int`` was removed from NumPy; the builtin ``int`` is the exact
        # type it used to alias.
        pred_crf = np.asarray(pred_crf, dtype=int)
        # Merge classes
        pred_crf = merge_classes(pred_crf)
    # Plain argmax prediction, independent of the CRF branch above.
    _, pred = torch.max(probs, dim=1)
    pred = pred.cpu().data[0].numpy()
    label = mask.cpu().data[0].numpy()
    pred = np.asarray(pred, dtype=int)
    label = np.asarray(label, dtype=int)
    pred = merge_classes(pred)
    label = merge_classes(label)
cv2.namedWindow("image", 0)
cv2.imshow("image", image)
cv2.namedWindow("mask", 0)
cv2.imshow("mask", encode(label, color_test))
cv2.namedWindow("pred", 0)
Ejemplo n.º 8
0
          (best_iou, (best_iou * 2) / (best_iou + 1)))
else:
    raise ValueError("can't find model")

# Evaluate the model on the validation loader with dense-CRF refinement.
print(">>>Test After Dense CRF: ")
model.eval()
running_metrics.reset()
with torch.no_grad():
    # Wrapping the loader (not the enumerate object) lets tqdm see its length.
    for i, (img, mask) in enumerate(tqdm(val_loader)):
        img = img.to(device)
        output = model(img)  # [-1, 9, 256, 256] per the original author's note
        probs = F.softmax(output, dim=1)
        # Only the first sample of each batch is scored.
        pred = probs.cpu().data[0].numpy()
        label = mask.cpu().data[0].numpy()
        # The CRF receives the first image of the batch rescaled by 255.
        img = img.cpu().data[0].numpy()
        pred = dense_crf(img * 255, pred)
        # ``np.int`` was removed from NumPy; the builtin ``int`` is the exact
        # type it used to alias.
        pred = np.asarray(pred, dtype=int)
        label = np.asarray(label, dtype=int)
        # Merge classes
        pred = merge_classes(pred)
        label = merge_classes(label)
        running_metrics.update(label, pred)

score, class_iou = running_metrics.get_scores()
for k, v in score.items():
    print(k, ':', v)
print(i, class_iou)
Ejemplo n.º 9
0
def save_inference_results_on_disk(loader, network, name):
    """Run ``network`` over ``loader`` on the GPU and save sigmoid outputs in packs.

    After every ``config['pack_volume']`` batches the outputs gathered since
    the previous save are concatenated along dim 0 and written with
    ``torch.save`` to ``<temp_folder>/<name>/all_outputs_<i>``, where ``i`` is
    the 1-based index of the last batch in the pack.

    Args:
        loader: Iterable of ``(images, true_masks)`` batches that also exposes
            ``config`` (a mapping), ``batch_sampler`` and ``len``.
        network: Model called as ``network(images, depths)``.
        name: Sub-folder of ``config['temp_folder']`` the packs are written to.

    Returns:
        ``len(loader) // pack_volume``.

    NOTE(review): batches remaining after the last complete pack are never
    written to disk - confirm this is intended.
    """
    config = loader.config
    pack_volume = config['pack_volume']
    path = os.path.join(config['temp_folder'], name, '')
    print('path ', path)
    network.eval()
    network = network.cuda()
    # Per-batch outputs of the current pack; concatenated once when saved
    # instead of growing one tensor with a copy on every batch.
    pending_outputs = []
    print('Inference is in progress')
    print('loader ', loader.batch_sampler.sampler)
    # Inference only: no autograd graph is needed for any forward pass.
    with torch.no_grad():
        for i, data in enumerate(tqdm(loader), start=1):
            images, _ = data

            images = images.cuda()

            # The first three channels are the image, channel 3 the depth.
            images_themselves = images[:, :3]
            if config['with_depth']:
                depths = images[:, 3]
            else:
                depths = None

            size_101 = config['101']

            if config['resize_128']:
                outputs = network(images_themselves, depths).detach()
            else:
                # Predict four overlapping corner patches and blend them.
                # NOTE(review): depths is passed uncropped for every patch -
                # confirm the network expects that.
                size_patch = config['patch_size']
                size_37 = size_101 - size_patch
                outputs_1 = network(images_themselves[:, :, :size_patch, :size_patch], depths).detach()
                outputs_2 = network(images_themselves[:, :, size_37:, :size_patch], depths).detach()
                outputs_3 = network(images_themselves[:, :, :size_patch, size_37:], depths).detach()
                outputs_4 = network(images_themselves[:, :, size_37:, size_37:], depths).detach()

                # Allocate the float32 canvas directly on the GPU.
                outputs = torch.zeros(
                    (outputs_1.shape[0], outputs_1.shape[1], size_101, size_101),
                    dtype=torch.float32, device=outputs_1.device)

                outputs[:, :, :size_patch, :size_patch] += outputs_1
                outputs[:, :, size_37:, :size_patch] += outputs_2
                outputs[:, :, :size_patch, size_37:] += outputs_3
                outputs[:, :, size_37:, size_37:] += outputs_4

                # Average where two patches overlap (the four edge strips) ...
                outputs[:, :, size_37:size_patch, :size_37] /= 2.0
                outputs[:, :, size_37:size_patch, size_patch:] /= 2.0
                outputs[:, :, :size_37, size_37:size_patch] /= 2.0
                outputs[:, :, size_patch:, size_37:size_patch] /= 2.0
                # ... and where all four overlap (the centre square).
                outputs[:, :, size_37:size_patch, size_37:size_patch] /= 4.0

            outputs = torch.sigmoid(outputs)

            # something like smoothing with conditional random fields
            if config['crf']:
                for j, (output, image) in enumerate(zip(outputs, images_themselves)):
                    output = output.squeeze(dim=0)
                    # NOTE(review): swapping dims 0 and 2 turns (C, H, W) into
                    # (W, H, C), not (H, W, C) - confirm what dense_crf expects.
                    image = torch.transpose(image, dim0=0, dim1=2)
                    output = dense_crf(image.data.cpu().numpy().astype(np.uint8), output.data.cpu().numpy())
                    outputs[j] = torch.from_numpy(output).float()

            if config['resize_128']:
                # Crop each output back to size_101 x size_101 with fixed offsets.
                resized_outputs = np.zeros((outputs.shape[0], outputs.shape[1], size_101, size_101))
                for j, output in enumerate(outputs):
                    output_as_array = output.data.cpu().numpy()
                    resized_outputs[j] = output_as_array[:, 27:, 14:-13]

                outputs = torch.from_numpy(resized_outputs).cuda().float()

            pending_outputs.append(outputs.data)

            if i % pack_volume == 0:
                torch.save(torch.cat(pending_outputs, dim=0), '%sall_outputs_%d' % (path, i))
                pending_outputs = []
                torch.cuda.empty_cache()
    batches_number = len(loader) // pack_volume
    print('batches_number = ', batches_number)
    # Drop the remaining references before releasing cached GPU memory.
    pending_outputs = None
    torch.cuda.empty_cache()
    return batches_number
Ejemplo n.º 10
0
# Load the remaining stored segments (segment1 is loaded earlier in the file).
segment2 = torch.load('segment2.pt')
segment3 = torch.load('segment3.pt')
segment4 = torch.load('segment4.pt')

print(segment1)
# Stack the four segments along a new leading dimension for the CRF.
segment = torch.stack([segment1, segment2, segment3, segment4])

# Run the dense CRF on a 224x224 copy of the source image.
orimg = imread("data2/images/train/8049.jpg")
img = resize(orimg, (224, 224))
Q = dense_crf(img, segment.numpy())

print(Q)

# One heatmap figure per entry of the CRF result, shown one after another.
for channel in range(4):
    sns.heatmap(Q[channel], cmap="cubehelix")
    plt.show()