Ejemplo n.º 1
0
    def generator(coco, path, batch_size, class_count):
        """Yield endless ``(images, labels)`` batches built from a COCO dataset.

        Images are read from a ``data`` directory next to *path*; each file name
        must start with the integer image id (``<image_id>.<ext>``).

        Args:
            coco: COCO API object; ``imgToAnns`` and ``annToMask`` are used.
            path: path whose parent directory contains the ``data`` folder.
            batch_size: number of samples per yielded batch.
            class_count: number of label channels; ``category_id`` is treated
                as 1-based (category 1 -> channel 0).

        Yields:
            ``images`` of shape ``(batch_size, 1024, 1024, 3)`` and ``labels``
            of shape ``(batch_size, 128, 128, class_count)``, both uint8, with
            label values binarised to 0/1.
        """
        image_size = 1024
        label_size = image_size // 8
        images_directory = os.path.join(os.path.dirname(path), "data")
        files = {
            int(f.split(".")[0]): os.path.join(images_directory, f)
            for f in os.listdir(images_directory)
        }

        def _empty_batch():
            # Fresh zeroed buffers so a yielded batch is never overwritten.
            return (
                np.zeros((batch_size, image_size, image_size, 3),
                         dtype=np.uint8),
                np.zeros((batch_size, label_size, label_size, class_count),
                         dtype=np.uint8),
            )

        while True:
            images, labels = _empty_batch()
            count = 0
            for image_id, segmentation in coco.imgToAnns.items():
                with Image.open(files[image_id]) as image:
                    images[count, :, :, :] = resize_and_crop(image, image_size)
                for ann in segmentation:
                    mask = Image.fromarray(coco.annToMask(ann))
                    resized = resize_and_crop(mask, label_size, rgb=False)
                    channel = ann["category_id"] - 1
                    # Union overlapping annotations with max rather than
                    # ``+=``: summing into uint8 can wrap around to 0 and
                    # erase pixels before the binarisation below.
                    labels[count, :, :, channel] = np.maximum(
                        labels[count, :, :, channel], resized)
                labels[count, labels[count, :, :, :] > 0] = 1
                count += 1
                if count == batch_size:
                    yield images, labels
                    images, labels = _empty_batch()
                    count = 0
Ejemplo n.º 2
0
    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair for sample *idx* as float32 arrays.

        The image is read from ``dir_img + <name> + '.nii'``, resized/cropped
        and given a new leading axis. The mask is read from
        ``dir_mask + <name> + '_mask.nii'``, resized/cropped, passed through
        ``channalize_mask`` and transposed with axes ``[2, 0, 1]``
        (presumably HWC -> CHW).
        """
        sample_name = self.data[idx]

        image_path = self.dir_img + sample_name + '.nii'
        volume = np.expand_dims(resize_and_crop(readimage(image_path)), axis=0)

        mask_path = self.dir_mask + sample_name + '_mask.nii'
        channels = channalize_mask(resize_and_crop(readimage(mask_path)))
        target = np.transpose(channels, axes=[2, 0, 1])

        return volume.astype(np.float32), target.astype(np.float32)
Ejemplo n.º 3
0
 def generator(images_dir, labels_file, batch_size, class_count,
               picture_size):
     """Yield endless ``(images, labels)`` batches for image classification.

     Every ``.jpg``/``.png`` file in *images_dir* is resized and cropped to
     ``picture_size`` around a random centre (rounded to one decimal) and its
     class is one-hot encoded into ``labels``.

     Yields:
         ``images`` of shape ``(batch_size, picture_size, picture_size, 3)``
         and ``labels`` of shape ``(batch_size, class_count)``, both uint8.

     NOTE(review): ``labels_file`` is unused; labels are looked up in a
     ``label_dict`` (filename -> class index) that must exist in an
     enclosing scope.
     """
     def _empty_batch():
         # Same dtype for every batch; previously the reset fell back to
         # float64 after the first batch.
         return (
             np.zeros((batch_size, picture_size, picture_size, 3),
                      dtype=np.uint8),
             np.zeros((batch_size, class_count), dtype=np.uint8),
         )

     while True:
         images, labels = _empty_batch()
         count = 0
         for image_filename in os.listdir(images_dir):
             # Drawn before the extension filter so the random stream
             # consumed per directory entry stays the same.
             centering = (round(random.random(), 1),
                          round(random.random(), 1))
             if image_filename[-3:] not in ("jpg", "png"):
                 continue
             image_path = os.path.join(images_dir, image_filename)
             with Image.open(image_path) as image:
                 images[count, :, :, :] = resize_and_crop(
                     image, picture_size, centering=centering)
             labels[count, label_dict[image_filename]] = 1
             count += 1
             if count == batch_size:
                 yield images, labels
                 images, labels = _empty_batch()
                 count = 0
def prepare_image(image_path):
    """Load *image_path* as RGB float32, crop it to ``CROP_SIZE`` and return it
    in channels-first (CHW) order.

    NOTE(review): ``scipy.misc.imread`` was removed in SciPy 1.2; this only
    runs against an older SciPy release.
    """
    pixels = scipy.misc.imread(image_path, mode='RGB').astype(np.float32)
    cropped = utils.resize_and_crop(pixels, CROP_SIZE)
    # HWC -> CHW
    return cropped.transpose(2, 0, 1)
Ejemplo n.º 5
0
def predict_images(paths, model, output_directory):
    """Segment the images in *paths* and save one PNG mask per output channel.

    Each image is resized/cropped to 512x512 and all images are predicted in a
    single batch. Every output channel is scaled to 0-255, binarised with
    Otsu's threshold, pasted centred on a canvas whose shorter side is 512
    (matching the original aspect ratio), resized back to the original image
    size and written to ``<output_directory>/<stem>_<channel>_segmentation.png``.

    Args:
        paths: image file paths. An empty sequence is a no-op.
        model: object exposing ``predict(images, batch_size=...)``.
        output_directory: existing directory that receives the PNG files.
    """
    size = 512
    if not paths:
        # Nothing to do; predicting with batch_size=0 would fail.
        return

    images = np.zeros((len(paths), size, size, 3))
    orig_sizes = []
    for i, path in enumerate(paths):
        with Image.open(path) as img:
            orig_sizes.append(img.size)
            images[i, :, :, :] = resize_and_crop(img, size)

    predictions = model.predict(images, batch_size=len(paths))
    predictions = (predictions * 255).astype(np.uint8)

    for i, (orig_w, orig_h) in enumerate(orig_sizes):
        # Canvas with the original aspect ratio whose shorter side is `size`;
        # independent of the channel, so computed once per image.
        if orig_w > orig_h:
            canvas_size = (int(orig_w * size / orig_h), size)
        else:
            canvas_size = (size, int(orig_h * size / orig_w))
        offset = (int(canvas_size[0] / 2 - size / 2),
                  int(canvas_size[1] / 2 - size / 2))
        stem = os.path.splitext(os.path.basename(paths[i]))[0]

        for j in range(predictions.shape[-1]):
            _, binary = cv2.threshold(
                predictions[i, :size, :size, j],
                0,
                255,
                cv2.THRESH_OTSU,
            )
            canvas = Image.new("RGB", canvas_size)
            canvas.paste(Image.fromarray(binary).resize((size, size)), offset)
            canvas = canvas.resize((orig_w, orig_h))
            canvas.save(os.path.join(output_directory,
                                     f"{stem}_{j}_segmentation.png"))
Ejemplo n.º 6
0
def prepare_img_data(data):
    """Resize/crop *data* to ``new_height`` x ``new_width`` and subtract the mean.

    ``new_height``, ``new_width``, ``utils`` and ``substract_mean`` come from
    module scope. Prints a confirmation message before returning.
    """
    resized = utils.resize_and_crop(new_height, new_width, data)
    centred = substract_mean(resized)
    print("method successful")
    return centred
Ejemplo n.º 7
0
def predict_img(net,
                full_img,
                scale_factor=0.5,
                out_threshold=0.5,
                use_dense_crf=True,
                use_gpu=False):
    """Predict a boolean mask for a PIL image by running *net* on two squares.

    The image is scaled by *scale_factor*, normalised and split into a left
    and a right square. Each square goes through *net* and a sigmoid, is
    resized to the original image height, and the two maps are merged to the
    original width. Optionally refined with dense CRF.

    Returns:
        Boolean array, True where the (refined) probability exceeds
        *out_threshold*.
    """
    # Inference must not use train-mode batch-norm statistics or dropout.
    net.eval()

    img_height = full_img.size[1]
    img_width = full_img.size[0]

    img = normalize(resize_and_crop(full_img, scale=scale_factor))
    left_square, right_square = split_img_into_squares(img)

    X_left = torch.from_numpy(hwc_to_chw(left_square)).unsqueeze(0)
    X_right = torch.from_numpy(hwc_to_chw(right_square)).unsqueeze(0)

    if use_gpu:
        X_left = X_left.cuda()
        X_right = X_right.cuda()

    # ``transforms.Scale`` was the deprecated (now removed) alias of Resize.
    tf = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(img_height),
        transforms.ToTensor()
    ])

    with torch.no_grad():
        output_left = net(X_left)
        output_right = net(X_right)

        # torch.sigmoid replaces the deprecated F.sigmoid.
        left_probs = torch.sigmoid(output_left).squeeze(0)
        right_probs = torch.sigmoid(output_right).squeeze(0)

        left_mask_np = tf(left_probs.cpu()).squeeze().numpy()
        right_mask_np = tf(right_probs.cpu()).squeeze().numpy()

    full_mask = merge_masks(left_mask_np, right_mask_np, img_width)

    if use_dense_crf:
        full_mask = dense_crf(np.array(full_img).astype(np.uint8), full_mask)

    return full_mask > out_threshold
Ejemplo n.º 8
0
 def classification_generator():
     """Endlessly yield ``(image_array, class_index)`` pairs.

     Repeatedly walks ``class_images_dir``, skipping names that do not end in
     ``jpg``/``png``. Each image is resized/cropped to 512 around a random
     centre (rounded to one decimal); its label is ``label_dict[filename]``.
     """
     while True:
         for name in os.listdir(class_images_dir):
             center_x = round(random.random(), 1)
             center_y = round(random.random(), 1)
             if name[-3:] in ("jpg", "png"):
                 picture = Image.open(os.path.join(class_images_dir, name))
                 pixels = resize_and_crop(picture, 512,
                                          centering=(center_x, center_y))
                 yield pixels, label_dict[name]
Ejemplo n.º 9
0
    def segmentation_generator():
        """Endlessly yield ``(image_array, mask_array)`` segmentation pairs.

        Repeatedly walks ``seg_images_dir`` (relative to the current working
        directory), skipping names that do not end in ``jpg``/``png``. The
        mask ``<stem>_segmentation.png`` is read from ``seg_mask_dir``. Image
        and mask share one random crop centre; the image is sized to 512 and
        the mask to 512/8 when ``model_size == "large"``, otherwise 512/16.
        """
        while True:
            for name in os.listdir(seg_images_dir):
                centering = (round(random.random(), 1),
                             round(random.random(), 1))
                if not name.endswith(("jpg", "png")):
                    continue
                stem = name[:-4]
                picture = Image.open(
                    os.path.join(os.getcwd(), seg_images_dir, name))
                mask = Image.open(
                    os.path.join(seg_mask_dir, f"{stem}_segmentation.png"))

                image_array = resize_and_crop(picture, 512,
                                              centering=centering)
                stride = 8 if model_size == "large" else 16
                mask_array = resize_and_crop(mask, 512 // stride,
                                             centering=centering, rgb=False)
                yield image_array, mask_array
Ejemplo n.º 10
0
 def generator(samples_dir, batch_size, class_count):
     """Yield endless ``(images, labels)`` batches from labelme files.

     Each pass shuffles the files in *samples_dir*, parses them with
     ``parse_labelme_file`` into an image array and ``(class_index, mask)``
     pairs, and resizes/crops both around a shared random centre. Masks are
     written at 1/8 of ``picture_size`` (taken from an enclosing scope).

     Yields:
         ``images`` of shape ``(batch_size, picture_size, picture_size, 3)``
         and ``labels`` of shape
         ``(batch_size, picture_size // 8, picture_size // 8, class_count)``,
         both uint8, with labels binarised to 0/1.
     """
     files = os.listdir(samples_dir)

     def _empty_batch():
         label_size = picture_size // 8
         return (
             np.zeros((batch_size, picture_size, picture_size, 3),
                      dtype=np.uint8),
             np.zeros((batch_size, label_size, label_size, class_count),
                      dtype=np.uint8),
         )

     while True:
         images, labels = _empty_batch()
         count = 0
         random.shuffle(files)
         for f in files:
             centering = (round(random.random(), 1),
                          round(random.random(), 1))
             arr, masks = parse_labelme_file(os.path.join(samples_dir, f))
             images[count, :, :, :] = resize_and_crop(arr,
                                                      picture_size,
                                                      centering=centering)
             for label, mask in masks:
                 resized = resize_and_crop(mask,
                                           picture_size // 8,
                                           centering=centering,
                                           rgb=False)
                 # Union overlapping shapes with max: ``+=`` on uint8 wraps
                 # around (e.g. 128 + 128 == 0) and can erase pixels.
                 labels[count, :, :, label] = np.maximum(
                     labels[count, :, :, label], resized)
             count += 1
             if count == batch_size:
                 labels[labels > 0] = 1
                 yield images, labels
                 images, labels = _empty_batch()
                 count = 0
Ejemplo n.º 11
0
def evaluate():
    """Run the pretrained classifier over ``args.input_dir`` and print metrics.

    Labels come from ``args.labels_file``: a CSV with a header row whose first
    column is the image file stem; the class index is the position of the
    ``"1.0"`` entry among the remaining columns.

    Prints, in order: true class indices, predicted class indices, the
    row-normalised confusion matrix, the mean per-class accuracy (sum of
    per-class recall divided by ``args.class_count``) and the per-class F1.
    Classes with no true samples contribute 0 instead of producing NaN.
    """
    args = parse_arguments()

    # Load model
    model = MobileNetV3LiteRASPP(shape=(512, 512, 3),
                                 n_class=args.class_count,
                                 task=args.task)

    label_dict = {}
    with open(args.labels_file, "r") as f:
        csvfile = csv.reader(f)
        # Skip column description
        next(csvfile)
        for row in csvfile:
            label_dict[row[0]] = row[1:].index("1.0")

    if args.model_size == "large":
        model = model.build_large()
    else:
        model = model.build_small()

    model.load_weights(args.save_path, by_name=True)

    # List the directory once so images and labels are guaranteed to align.
    filenames = os.listdir(args.input_dir)
    images = np.zeros((len(filenames), 512, 512, 3))
    labels = np.zeros((len(filenames), args.class_count))
    for i, filename in enumerate(filenames):
        with Image.open(os.path.join(args.input_dir, filename)) as img:
            images[i, :, :, :] = np.array(resize_and_crop(img, 512))
        # splitext handles extensions of any length (".jpeg", ".png", ...).
        labels[i, label_dict[os.path.splitext(filename)[0]]] = 1

    preds = model.predict(images)
    y_true = labels.argmax(axis=1)
    y_pred = preds.argmax(axis=1)
    print(y_true)
    print(y_pred)

    # Fix the label set so the matrix is always class_count x class_count,
    # even when a class never appears in y_true or y_pred.
    matrix = metrics.confusion_matrix(y_true, y_pred,
                                      labels=list(range(args.class_count)))
    f1 = metrics.f1_score(y_true, y_pred, average=None)

    support = matrix.sum(axis=1)
    recall = np.divide(matrix.diagonal(), support,
                       out=np.zeros(len(support)), where=support > 0)
    normalised = np.divide(matrix, support[:, None],
                           out=np.zeros(matrix.shape),
                           where=support[:, None] > 0)
    print(np.around(normalised, decimals=2))
    print(recall.sum() / args.class_count)
    print(f1)
Ejemplo n.º 12
0
def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5,
                use_dense_crf=False):
    """Predict a boolean mask for a PIL image with a single forward pass.

    The image is scaled by *scale_factor*, normalised, converted to CHW and
    run through *net* on *device*. Multi-class nets (``net.n_classes > 1``)
    get a softmax over the class axis, otherwise a sigmoid. The probability
    map is resized back to exactly the original ``(height, width)`` and
    optionally refined with dense CRF.

    Returns:
        Boolean array, True where the probability exceeds *out_threshold*.
    """
    net.eval()
    img_height = full_img.size[1]
    img_width = full_img.size[0]

    img = hwc_to_chw(normalize(resize_and_crop(full_img, scale=scale_factor)))
    X = torch.from_numpy(img).unsqueeze(0).to(device=device)

    with torch.no_grad():
        output = net(X)

        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)

        probs = probs.squeeze(0)

        # Resize to the exact (height, width): ``Resize(img_height)`` scales
        # the *shorter* edge to the height, which gives the wrong size for
        # portrait images (width < height).
        tf = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.Resize((img_height, img_width)),
                transforms.ToTensor()
            ]
        )

        full_mask = tf(probs.cpu()).squeeze().numpy()

    if use_dense_crf:
        full_mask = dense_crf(np.array(full_img).astype(np.uint8), full_mask)

    return full_mask > out_threshold
Ejemplo n.º 13
0
def predict_img(net,
                full_img,
                scale_factor=0.5,
                out_threshold=0.5,
                use_dense_crf=True,
                use_gpu=False):
    """Predict a boolean mask for a PIL image with a single forward pass.

    The image is scaled by *scale_factor*, normalised, converted to CHW and
    run through *net*. The raw output is used as the probability map (no
    activation is applied here; presumably *net* ends in one — TODO confirm).
    The mask keeps the scaled resolution.

    Returns:
        Boolean array, True where the (optionally CRF-refined) value exceeds
        *out_threshold*.
    """
    net.eval()

    img = hwc_to_chw(normalize(resize_and_crop(full_img, scale=scale_factor)))
    X = torch.from_numpy(img).unsqueeze(0)

    if use_gpu:
        X = X.cuda()

    # Debug prints and a blocking plt.show() used to live here; they stalled
    # every call until the figure window was closed.
    with torch.no_grad():
        probs = net(X).squeeze(0)

        tf = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.ToTensor()
            ]
        )

        mask = tf(probs.cpu()).squeeze().numpy()

    if use_dense_crf:
        # The mask is at the scaled resolution, so give dense_crf the image
        # at that same size (presumably it needs matching H x W; TODO confirm).
        mask_h, mask_w = mask.shape[-2:]
        crf_img = full_img
        if full_img.size != (mask_w, mask_h):
            crf_img = full_img.resize((mask_w, mask_h))
        mask = dense_crf(np.array(crf_img).astype(np.uint8), mask)

    return mask > out_threshold
Ejemplo n.º 14
0
def predict_img_batch(net,
                      full_img,
                      scale_factor=0.5,
                      out_threshold=0.5,
                      use_dense_crf=True,
                      use_gpu=True):
    """Return the raw network output for *full_img* as a (C, H, W) array.

    The image is scaled by *scale_factor*, normalised (a greyscale result
    gets a trailing channel axis), split into left/right squares and each
    square is run through *net*. Both outputs are rescaled by
    ``1 / scale_factor`` via ``resize_np`` and merged to the original width.
    Shape information is printed along the way.

    NOTE(review): ``out_threshold`` and ``use_dense_crf`` are accepted but
    unused — the output is neither thresholded nor CRF-refined.
    """
    net.eval()
    width = full_img.size[0]
    height = full_img.size[1]
    print('imgheight', height)
    print('imgwidth', width)

    scaled = normalize(resize_and_crop(full_img, scale=scale_factor))
    if len(scaled.shape) == 2:
        # Greyscale: add a trailing channel axis.
        scaled = scaled[..., np.newaxis]
    print('img.shape', scaled.shape)

    left, right = split_img_into_squares(scaled)
    x_left = torch.from_numpy(hwc_to_chw(left)).unsqueeze(0)
    x_right = torch.from_numpy(hwc_to_chw(right)).unsqueeze(0)

    if use_gpu:
        x_left, x_right = x_left.cuda(), x_right.cuda()

    with torch.no_grad():
        out_left = net(x_left)
        out_right = net(x_right)
        print('output_left.shape', out_left.shape)
        print('output_right.shape', out_right.shape)

        # (1, C, h, w) -> (h, w, C)
        left_np = np.transpose(out_left.squeeze(0).squeeze().cpu().numpy(),
                               axes=[1, 2, 0])
        right_np = np.transpose(out_right.squeeze(0).squeeze().cpu().numpy(),
                                axes=[1, 2, 0])
        if scale_factor != 1:
            right_np = resize_np(right_np, 1 / scale_factor)
            left_np = resize_np(left_np, 1 / scale_factor)
        print('left_mask_np', left_np.shape)
        print('right_mask_np', right_np.shape)

    # (h, w, C) -> (C, h, w) before merging the halves.
    left_np = np.transpose(left_np, axes=[2, 0, 1])
    right_np = np.transpose(right_np, axes=[2, 0, 1])
    return merge_masks(left_np, right_np, width)
Ejemplo n.º 15
0
def prepare_image(img_in, crop_size):
    """Resize/crop *img_in* to *crop_size* and return it as float32 with a
    leading batch axis of length 1."""
    cropped = utils.resize_and_crop(img_in, crop_size)
    return cropped.astype(np.float32)[None, ...]
Ejemplo n.º 16
0
def prepare_image(img_in, crop_size):
    """Resize/crop *img_in* to *crop_size*, cast it to ``INPUT_DATA_TYPE`` and
    return it in channels-first (CHW) order."""
    hwc = utils.resize_and_crop(img_in, crop_size).astype(INPUT_DATA_TYPE)
    return hwc.transpose(2, 0, 1)
Ejemplo n.º 17
0
def to_cropped_imgs(ids, dir, suffix):
    """Lazily yield one square crop per ``(id, pos)`` tuple in *ids*.

    Each image is opened from ``dir + id + suffix`` (plain string
    concatenation, so *dir* must already end with a separator), passed
    through ``resize_and_crop``, and the square at *pos* is taken with
    ``get_square``.
    """
    for img_id, pos in ids:
        resized = resize_and_crop(Image.open(dir + img_id + suffix))
        yield get_square(resized, pos)
Ejemplo n.º 18
0
    img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2GRAY)

    mask = cv2.imread(dir_mask + '/' + id + '.png')

    red_mask = np.array(mask[:, :, 2] == 128)
    green_mask = np.array(mask[:, :, 1] == 128)
    true_mask = np.stack(
        [red_mask.astype(np.float32),
         green_mask.astype(np.float32)])

    img[red_mask] = img[red_mask] * 0.8
    img[green_mask] = img[green_mask] * 1.2
    img[np.logical_and(mask[:, :, 2] != 128, mask[:, :, 1] != 128)] = 0.2 * \
                          img[np.logical_and(mask[:, :, 2] != 128, mask[:, :, 1] != 128)]

    img = resize_and_crop(img, scale=0.5)
    img = normalize(img)[None, None, :, :]
    img = torch.from_numpy(img).float()
    net = UNet(1, 2)
    net.eval()
    net.load_state_dict(
        torch.load(
            '/media/zhuzhu/0C5809B80C5809B8/draft/unet/checkpoint/unet_0.854608765.pth',
            map_location='cpu'))
    predict = net(img).squeeze(0)

    mask_predict = (predict > 0.5).float().numpy()
    mask_blue = np.zeros(mask_predict.shape[1:])[np.newaxis, :]
    mask_predict = np.concatenate([mask_predict, mask_blue], axis=0)

    mask_predict = (mask_predict * 128).astype(np.uint8).transpose([1, 2, 0])
Ejemplo n.º 19
0
    def generator(images_dir, masks_dir, batch_size, class_count):
        """Yield endless ``(images, labels)`` batches for binary segmentation.

        Every ``.jpg``/``.png`` in *images_dir* (relative to the current
        working directory) is paired with ``<stem>_segmentation.png`` in
        *masks_dir*. Image and mask share a random crop centre; the mask goes
        into channel 0 of the label map and is binarised to 0/1. Label maps
        are ``picture_size // 8`` when ``model_size == "large"``, otherwise
        ``picture_size // 16`` (both names come from an enclosing scope).

        Yields:
            uint8 ``images`` ``(batch_size, picture_size, picture_size, 3)``
            and uint8 ``labels`` ``(batch_size, s, s, class_count)``.
        """
        def _label_size():
            return picture_size // (8 if model_size == "large" else 16)

        def _empty_batch():
            label_size = _label_size()
            # uint8 for every batch; previously the reset after a yield fell
            # back to float64.
            return (
                np.zeros((batch_size, picture_size, picture_size, 3),
                         dtype=np.uint8),
                np.zeros((batch_size, label_size, label_size, class_count),
                         dtype=np.uint8),
            )

        while True:
            images, labels = _empty_batch()
            count = 0
            for image_filename in os.listdir(images_dir):
                centering = (round(random.random(), 1),
                             round(random.random(), 1))
                if image_filename[-3:] not in ("jpg", "png"):
                    continue
                image_path = os.path.join(os.getcwd(), images_dir,
                                          image_filename)
                with Image.open(image_path) as image:
                    images[count, :, :, :] = resize_and_crop(
                        image, picture_size, centering=centering)

                mask_path = os.path.join(
                    masks_dir, f"{image_filename[:-4]}_segmentation.png")
                with Image.open(mask_path) as mask:
                    labels[count, :, :, 0] = resize_and_crop(
                        mask, _label_size(), centering=centering, rgb=False)

                labels[count, labels[count, :, :, :] > 0] = 1
                count += 1
                if count == batch_size:
                    yield images, labels
                    images, labels = _empty_batch()
                    count = 0
Ejemplo n.º 20
0
def predict_img(net,
                full_img,
                scale_factor=0.5,
                out_threshold=0.5,
                use_gpu=False):
    """Predict a boolean mask for a PIL image by running *net* on two squares.

    The image is scaled by *scale_factor*, normalised and split into left and
    right squares that are run through *net*. If the outputs' height differs
    from the original image height, they are resized to it, and the halves
    are merged to the original width. Prints the wall-clock time of the two
    forward passes, and "Transform done!" when a resize happened.

    Returns:
        Boolean array, True where the merged value is >= *out_threshold*.
    """
    net.eval()

    img_height = full_img.size[1]
    img_width = full_img.size[0]

    img = normalize(resize_and_crop(full_img, scale=scale_factor))
    left_square, right_square = split_img_into_squares(img)

    X_left = torch.from_numpy(hwc_to_chw(left_square)).unsqueeze(0)
    X_right = torch.from_numpy(hwc_to_chw(right_square)).unsqueeze(0)

    if use_gpu:
        X_left = X_left.cuda()
        X_right = X_right.cuda()

    def _sync():
        # CUDA work is asynchronous; synchronise so the timing covers the
        # whole forward pass. Only done on the GPU path: calling
        # torch.cuda.synchronize() without CUDA raises.
        if use_gpu:
            torch.cuda.synchronize()

    with torch.no_grad():
        _sync()
        start = time.perf_counter()
        output_left = net(X_left)
        output_right = net(X_right)
        _sync()
        elapsed = time.perf_counter() - start
        print(' Unet++ --------------------> running time: %s seconds ' %
              (elapsed))

        left_probs = output_left.squeeze(0)
        right_probs = output_right.squeeze(0)

        if left_probs.shape[1] != img_height:
            tf = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize(img_height),
                transforms.ToTensor()
            ])
            left_probs = tf(left_probs.cpu())
            right_probs = tf(right_probs.cpu())
            print("Transform done!")

        left_mask_np = left_probs.squeeze().cpu().numpy()
        right_mask_np = right_probs.squeeze().cpu().numpy()

    full_mask = merge_masks(left_mask_np, right_mask_np, img_width)

    # Values at or above the threshold are foreground. This is a direct
    # comparison; the previous in-place 0/1 binarisation followed by ``>``
    # gave the same result only for thresholds in [0, 1).
    return full_mask >= out_threshold
Ejemplo n.º 21
0
    def generator(images_dir, masks_dir, batch_size, class_count):
        """Yield endless augmented ``(images, labels)`` segmentation batches.

        Iterates ``samples`` (``(image_path, mask_path, label)`` tuples from an
        enclosing scope), shuffling it in place on every pass. Image and mask
        are rotated by the same random angle (0, 45, 90 or 270 degrees) and
        cropped around a shared random centre. The mask is written into
        channel ``label`` and binarised to 0/1. Label maps are
        ``picture_size // 8`` when ``model_size == "large"``, otherwise
        ``picture_size // 16``. A sample that raises is printed and skipped.

        NOTE(review): ``images_dir`` and ``masks_dir`` are unused; paths come
        from ``samples``.
        """
        def _label_size():
            return picture_size // (8 if model_size == "large" else 16)

        def _empty_batch():
            label_size = _label_size()
            # uint8 for every batch; previously the reset after a yield fell
            # back to float64.
            return (
                np.zeros((batch_size, picture_size, picture_size, 3),
                         dtype=np.uint8),
                np.zeros((batch_size, label_size, label_size, class_count),
                         dtype=np.uint8),
            )

        while True:
            random.shuffle(samples)
            images, labels = _empty_batch()
            count = 0

            for image_path, mask_path, label in samples:
                centering = (round(random.random(), 1),
                             round(random.random(), 1))

                try:
                    angle = random.choice([0, 45, 90, 270])
                    image = Image.open(image_path).rotate(angle)
                    images[count, :, :, :] = resize_and_crop(
                        image, picture_size, centering=centering)

                    mask = Image.open(mask_path).rotate(angle)
                    labels[count, :, :, label] = resize_and_crop(
                        mask, _label_size(), centering=centering, rgb=False)
                    labels[count, labels[count, :, :, :] > 0] = 1

                    count += 1
                except Exception as ex:
                    # Best-effort: report the bad sample and move on.
                    print(ex)

                if count == batch_size:
                    yield images, labels
                    images, labels = _empty_batch()
                    count = 0
Ejemplo n.º 22
0
def predict_img(net,
                full_img,
                scale_factor=0.5,
                out_threshold=0.5,
                use_dense_crf=True,
                use_gpu=False):
    """Predict a boolean mask for a PIL image with a single forward pass.

    The image is scaled by *scale_factor*, normalised, converted to CHW and
    run through *net*. The raw output is used as the probability map (no
    activation is applied here; presumably *net* ends in one — TODO confirm).
    The mask keeps the scaled resolution.

    Returns:
        Boolean array, True where the (optionally CRF-refined) value exceeds
        *out_threshold*.
    """
    net.eval()

    img = hwc_to_chw(normalize(resize_and_crop(full_img, scale=scale_factor)))
    X = torch.from_numpy(img).unsqueeze(0)

    if use_gpu:
        X = X.cuda()

    with torch.no_grad():
        probs = net(X).squeeze(0)

        tf = transforms.Compose([
            transforms.ToPILImage(),
            transforms.ToTensor()
        ])

        mask = tf(probs.cpu()).squeeze().numpy()

    if use_dense_crf:
        # The mask is at the scaled resolution, so give dense_crf the image
        # at that same size (presumably it needs matching H x W; TODO confirm).
        mask_h, mask_w = mask.shape[-2:]
        crf_img = full_img
        if full_img.size != (mask_w, mask_h):
            crf_img = full_img.resize((mask_w, mask_h))
        mask = dense_crf(np.array(crf_img).astype(np.uint8), mask)

    return mask > out_threshold