def predict_masks(args, hps, store, to_predict: List[str], threshold: float,
                  validation: str=None, no_edges: bool=False):
    """Predict a mask for every requested image and save it thresholded.

    The model is restored from ``args.model_path`` when that is set, and from
    the last snapshot found in ``args.logdir`` otherwise. Every predicted mask
    is compared with ``threshold`` and the resulting boolean array is written
    with ``np.save`` into a gzip file located at ``mask_path(store, im.id)``.

    Args:
        args: object exposing ``model_path`` and ``logdir``.
        hps: hyper-parameters; handed to ``Model`` and ``square``, and
            ``n_channels`` is read to trim the image channels.
        store: forwarded to ``mask_path``.
        to_predict: ids of the images to process (processed in sorted order).
        threshold: mask values ``>= threshold`` are stored as ``True``.
        validation: when equal to ``'square'`` the preprocessed image is
            passed through ``square`` before prediction.
        no_edges: forwarded to ``model.predict_image_mask``.
    """
    image_ids = sorted(to_predict)
    logger.info('Predicting {} masks: {}'.format(
        len(to_predict), ', '.join(image_ids)))

    model = Model(hps=hps)
    if not args.model_path:
        model.restore_last_snapshot(args.logdir)
    else:
        model.restore_snapshot(args.model_path)

    def load_im(im_id):
        pixels = model.preprocess_image(utils.load_image(im_id))
        if pixels.shape[0] != hps.n_channels:
            # keep only the leading n_channels channels
            pixels = pixels[:hps.n_channels]
        if validation == 'square':
            pixels = square(pixels, hps)
        return Image(id=im_id, data=pixels)

    def predict_mask(im):
        logger.info(im.id)
        predicted = model.predict_image_mask(im.data, no_edges=no_edges)
        return im, predicted

    # Images are loaded by two worker threads; prediction itself is lazy.
    loaded_images = utils.imap_fixed_output_buffer(
        load_im, image_ids, threads=2)
    im_masks = map(predict_mask, loaded_images)

    def next_prediction(_):
        # The argument only drives the iteration count; each call pulls the
        # next (image, mask) pair from the lazy prediction stream.
        return next(im_masks)

    predictions = utils.imap_fixed_output_buffer(
        next_prediction, to_predict, threads=1)
    for im, mask in predictions:
        assert mask.shape[1:] == im.data.shape[1:]
        target = str(mask_path(store, im.id))
        with gzip.open(target, 'wb') as f:
            # TODO - maybe do (mask * 20).astype(np.uint8)
            np.save(f, mask >= threshold)
# Example #2
# 0
    def predict(arg):
        """Predict one image by averaging model outputs over overlapping patches.

        ``arg`` is an ``(img_meta, img)`` pair. The image is covered with
        square patches of side ``patch_size`` placed every ``patch_size - 64``
        pixels, plus one extra patch flush with the far edge on each axis.
        ``np.exp`` of the model output is accumulated per pixel and divided by
        the number of patches covering that pixel (presumably the model emits
        log-probabilities - confirm against the model definition).

        Returns ``(img_meta, pred_img)`` where ``pred_img`` is a float32 array
        of shape ``(utils.N_CLASSES + 1, h, w)``.
        """
        img_meta, img = arg
        height, width = img.shape[:2]
        side = patch_size
        # step = side // 2 - 32
        stride = side - 64

        def offsets(extent):
            # regular grid, then a final patch aligned with the far border
            starts = list(range(0, extent - side, stride))
            starts.append(extent - side)
            return starts

        all_xy = []
        for x in offsets(width):
            for y in offsets(height):
                all_xy.append((x, y))

        pred_img = np.zeros(
            (utils.N_CLASSES + 1, height, width), dtype=np.float32)
        pred_count = np.zeros((height, width), dtype=np.int32)

        def make_batch(xy_batch_):
            patches = [utils.img_transform(img[y: y + side, x: x + side])
                       for x, y in xy_batch_]
            return xy_batch_, torch.stack(patches)

        xy_batches = list(utils.batches(all_xy, batch_size))
        for xy_batch, inputs in utils.imap_fixed_output_buffer(
                make_batch, tqdm.tqdm(xy_batches), threads=1):
            outputs = model(utils.variable(inputs, volatile=True))
            outputs_data = np.exp(outputs.data.cpu().numpy())
            for (x, y), pred in zip(xy_batch, outputs_data):
                window = (slice(y, y + side), slice(x, x + side))
                pred_img[(slice(None),) + window] += pred
                pred_count[window] += 1

        # pixels never covered by a patch keep a divisor of 1
        pred_img /= np.maximum(pred_count, 1)
        return img_meta, pred_img
def predict(
    model,
    img_paths: List[Path],
    out_path: Path,
    patch_size: int,
    batch_size: int,
    is_test=False,
    test_scale=1.0,
    min_scale=1.0,
    max_scale=1.0,
):
    """Classify detected blobs and write per-class probability features.

    Reads ``blobs.pkl`` and ``features.npz`` from ``out_path``, runs
    ``_predict`` on every image in ``img_paths`` and stores, for each class
    and image, two sums over the blob probabilities in
    ``out_path / 'clf_features.npz'`` under the key ``xs``.

    Args:
        model: network; switched to eval mode and handed to ``_predict``.
        img_paths: images to process.
        out_path: directory holding ``blobs.pkl`` and ``features.npz``; the
            result file is written here as well.
        patch_size: forwarded to ``_predict``.
        batch_size: forwarded to ``_predict``.
        is_test: forwarded to ``_load_image``.
        test_scale: forwarded to ``_load_image``.
        min_scale: forwarded to ``_load_image``.
        max_scale: forwarded to ``_load_image``.
    """
    # NOTE(review): pickle.load can execute arbitrary code - blobs.pkl must
    # come from a trusted earlier pipeline stage.
    with out_path.joinpath('blobs.pkl').open('rb') as f:
        blobs_data = pickle.load(f)
    blobs_by_img_id = defaultdict(list)
    blob_scale_by_img_id = {}
    assert len(blobs_data['blobs']) == utils.N_CLASSES
    for cls_blobs in blobs_data['blobs'][:-1]:  # skip pups
        for img_id, scale, blobs in cls_blobs:
            blob_scale_by_img_id[img_id] = scale  # same for all classes
            assert len(blobs) == len(BLOB_THRESHOLDS)
            blobs_by_img_id[img_id].append(blobs[0])  # take lowest threshold

    model.eval()
    load_image = partial(_load_image,
                         is_test=is_test,
                         test_scale=test_scale,
                         min_scale=min_scale,
                         max_scale=max_scale)
    probs_by_img_id_cls_blob_id = {}
    for arg in utils.imap_fixed_output_buffer(load_image,
                                              tqdm.tqdm(img_paths),
                                              threads=4):
        img_id, prediction = _predict(arg, model, patch_size,
                                      batch_size, blobs_by_img_id,
                                      blob_scale_by_img_id)
        if prediction is None:
            # _predict yields no prediction for images without any blobs;
            # such images simply contribute zero-valued features below.
            continue
        indices, outputs = prediction
        for (cls, blob_id), probs in zip(indices, outputs):
            probs_by_img_id_cls_blob_id[img_id, cls, blob_id] = probs

    # Close the archive as soon as the ids have been read.
    with np.load(str(out_path.joinpath('features.npz'))) as features:
        features_img_ids = features['ids'][0]

    last_cls = utils.N_CLASSES - 1
    clf_features = [[] for _ in range(utils.N_CLASSES)]
    for cls, cls_blob_ids in enumerate(blobs_data['blob_ids']):
        for img_id, blob_ids in zip(features_img_ids, cls_blob_ids):
            prob_sum = prob_sum_05 = 0
            # The last class has no blobs collected above, so its features
            # stay at zero.
            if cls != last_cls:
                for blob_id in blob_ids:
                    for blob_cls in range(last_cls):
                        probs = probs_by_img_id_cls_blob_id.get(
                            (img_id, blob_cls, blob_id))
                        if probs is None:
                            continue
                        prob_sum += probs[cls]
                        if probs[cls] > 0.5:
                            prob_sum_05 += probs[cls]
            clf_features[cls].append([prob_sum, prob_sum_05])

    clf_features = np.array(clf_features)
    with out_path.joinpath('clf_features.npz').open('wb') as f:
        np.savez(f, xs=clf_features)
# Example #4
# 0
    def predict_image_mask(self,
                           im_data: np.ndarray,
                           rotate: bool = False,
                           no_edges: bool = False,
                           average_shifts: bool = True) -> np.ndarray:
        """Predict a per-class mask for a whole image by tiling it with patches.

        The image is padded with ``hps.patch_border`` mirrored pixels on every
        spatial side. The network receives patches of side
        ``patch_inner + 2 * patch_border`` cut from the padded image, and each
        prediction is accumulated into the matching ``patch_inner`` window of
        the result. Overlapping predictions are averaged.

        Args:
            im_data: image array with three axes ``(c, w, h)``.
            rotate: additionally predict on the 90, 180 and 270 degree
                rotations of every patch (via ``utils.rotated``) and rotate
                the predictions back before averaging.
            no_edges: keep all patch windows at least ``patch_border`` pixels
                away from the image border; uncovered pixels stay zero.
            average_shifts: step the patch grid by a third of ``patch_inner``
                so windows overlap, instead of by the full ``patch_inner``.

        Returns:
            float32 array of shape ``(hps.n_classes, w, h)``.
        """
        self.net.eval()
        c, w, h = im_data.shape
        b = self.hps.patch_border
        s = self.hps.patch_inner
        padded = np.zeros([c, w + 2 * b, h + 2 * b], dtype=im_data.dtype)
        # Explicit end indices (rather than ``-b``) keep the slices valid
        # when the border is zero.
        padded[:, b:b + w, b:b + h] = im_data
        # mirror on the edges
        padded[:, :b, b:b + h] = np.flip(im_data[:, :b, :], 1)
        padded[:, b + w:, b:b + h] = np.flip(im_data[:, w - b:, :], 1)
        padded[:, :, :b] = np.flip(padded[:, :, b:2 * b], 2)
        padded[:, :, b + h:] = np.flip(padded[:, :, h:b + h], 2)
        step = s // 3 if average_shifts else s
        margin = b if no_edges else 0
        # Regular grid plus one last window flush with the far (inner) edge.
        xs = list(range(margin, w - s - margin, step)) + [w - s - margin]
        ys = list(range(margin, h - s - margin, step)) + [h - s - margin]
        all_xy = [(x, y) for x in xs for y in ys]
        out_shape = [self.hps.n_classes, w, h]
        pred_mask = np.zeros(out_shape, dtype=np.float32)
        pred_per_pixel = np.zeros(out_shape, dtype=np.int16)
        n_rot = 4 if rotate else 1

        def gen_batch(xy_batch_):
            inputs_ = []
            for x, y in xy_batch_:
                # shifted by -b to account for padding
                patch = padded[:, x:x + s + 2 * b, y:y + s + 2 * b]
                inputs_.append(patch)
                for i in range(1, n_rot):
                    inputs_.append(utils.rotated(patch, i * 90))
            return xy_batch_, np.array(inputs_, dtype=np.float32)

        # Each position expands to n_rot network inputs; never let the chunk
        # size drop to zero for small batch sizes.
        chunk_size = max(1, self.hps.batch_size // (4 * n_rot))
        for xy_batch, inputs in utils.imap_fixed_output_buffer(
                gen_batch,
                tqdm.tqdm(list(utils.chunks(all_xy, chunk_size))),
                threads=2):
            y_pred = self.net(self._var(torch.from_numpy(inputs)))
            for idx, mask in enumerate(y_pred.data.cpu().numpy()):
                # Outputs are ordered position-major, rotation-minor.
                x, y = xy_batch[idx // n_rot]
                i = idx % n_rot
                if i:
                    mask = utils.rotated(mask, -i * 90)
                # mask = (mask >= 0.5) + 0.001
                # Every rotation counts as one full prediction, so the final
                # division yields the mean over both shifts and rotations.
                pred_mask[:, x:x + s, y:y + s] += mask
                pred_per_pixel[:, x:x + s, y:y + s] += 1
        if not no_edges:
            assert pred_per_pixel.min() >= 1
        pred_mask /= np.maximum(pred_per_pixel, 1)
        return pred_mask
def _predict(arg, model, patch_size, batch_size, blobs_by_img_id,
             blob_scale_by_img_id):
    """Classify square patches centred on the blobs of a single image.

    Args:
        arg: ``((path, scale), img)``; ``int(path.stem)`` is the image id.
        model: network applied to each batch of patches.
        patch_size: side of the patch cut around every blob centre.
        batch_size: number of blobs handed to ``utils.batches`` per batch.
        blobs_by_img_id: maps image id to a per-class list of blob lists,
            each blob being an ``(x, y, _)`` triple.
        blob_scale_by_img_id: maps image id to the scale the blob
            coordinates are multiplied by (together with ``scale``).

    Returns:
        ``(img_id, (indices, outputs))`` where ``indices`` holds
        ``(cls, blob_index)`` pairs and ``outputs`` the matching softmax
        rows. Blobs whose patch does not have the full
        ``(patch_size, patch_size)`` shape are skipped. When the image has
        no blobs at all the result is ``((path, scale), None)`` instead.
    """
    (path, scale), img = arg
    img_id = int(path.stem)
    # Values are unused; the unpacking requires at least two image axes.
    _height, _width = img.shape[:2]
    half = patch_size // 2

    cls_blobs = blobs_by_img_id.get(img_id)
    if not (cls_blobs and any(cls_blobs)):
        return (path, scale), None
    blob_scale = blob_scale_by_img_id[img_id]

    # Blob centres converted to pixel coordinates of the loaded image.
    all_xy = []
    for cls, blobs in enumerate(cls_blobs):
        for i, (x, y, _) in enumerate(blobs):
            px = int(round(x * blob_scale * scale))
            py = int(round(y * blob_scale * scale))
            all_xy.append((cls, i, px, py))

    def make_batch(xy_batch_):
        kept, tensors = [], []
        for cls, i, x, y in xy_batch_:
            patch = img[max(0, y - half):y + half, max(0, x - half):x + half]
            if patch.shape[:2] != (patch_size, patch_size):
                # patch was clipped by the image border
                continue
            tensors.append(utils.img_transform(patch))
            kept.append((cls, i))
        if not tensors:
            return kept, None
        return kept, torch.stack(tensors)

    all_indices, all_outputs = [], []
    xy_batches = list(utils.batches(all_xy, batch_size))
    for indices, inputs in utils.imap_fixed_output_buffer(
            make_batch, tqdm.tqdm(xy_batches), threads=1):
        if inputs is None:
            continue
        logits = model(utils.variable(inputs, volatile=True))
        probs = F.softmax(logits).data.cpu().numpy()
        all_indices.extend(indices)
        all_outputs.extend(probs)
    return img_id, (all_indices, all_outputs)
def predict(
    model,
    img_paths: List[Path],
    out_path: Path,
    patch_size: int,
    batch_size: int,
    is_test=False,
    downsampled=False,
    test_scale=1.0,
    min_scale=1.0,
    max_scale=1.0,
):
    """Predict every image as a single 256x256 input and save the result.

    Each image is cropped to its first 1800 columns, resized to 256x256 and
    passed through the model in one piece. ``np.exp`` of the output is saved
    (multiplied by 1000, as uint16) to a gzip-compressed ``.npy`` file in
    ``out_path``; a coloured prediction and the rescaled input are also
    written as PNGs under ``_runs/``.

    Args:
        model: network; switched to eval mode.
        img_paths: images to process.
        out_path: directory receiving ``<stem>-<scale>-pred.npy`` files.
        patch_size: spatial side of the accumulated prediction; must match
            the model output for the 256x256 input.
        batch_size: unused by this variant; kept for call compatibility.
        is_test: unused by this variant; kept for call compatibility.
        downsampled: unused by this variant; kept for call compatibility.
        test_scale: unused by this variant; kept for call compatibility.
        min_scale: unused by this variant; kept for call compatibility.
        max_scale: unused by this variant; kept for call compatibility.
    """
    model.eval()

    def load_image(path):
        # The scale is fixed instead of being derived from
        # is_test / test_scale / min_scale / max_scale.
        scale = 1800 / 256
        img = utils.load_image(path, cache=False)
        img = img[:, :1800]
        img = cv2.resize(img, (256, 256))
        return (path, scale), img

    def predict(arg):
        img_meta, img = arg
        # File-name prefix for the debug images written below.
        prefix = img_meta[0].stem

        inputs = torch.stack([utils.img_transform(img)])
        outputs = model(utils.variable(inputs, volatile=True))
        # presumably the model emits log-probabilities - TODO confirm
        outputs_data = np.exp(outputs.data.cpu().numpy())

        pred_img = np.zeros((2, patch_size, patch_size), dtype=np.float32)
        for pred in outputs_data:
            pred_img += pred

        utils.save_image('_runs/pred-{}.png'.format(prefix),
                         colored_prediction(outputs_data[0]))
        utils.save_image(
            '_runs/{}-input.png'.format(prefix),
            skimage.exposure.rescale_intensity(img, out_range=(0, 1)))
        return img_meta, pred_img

    loaded = utils.imap_fixed_output_buffer(load_image,
                                            tqdm.tqdm(img_paths),
                                            threads=4)

    prediction_results = utils.imap_fixed_output_buffer(predict,
                                                        loaded,
                                                        threads=1)

    def save_prediction(arg):
        (img_path, img_scale), pred_img = arg
        # Scale before the integer cast: values in [0, 1] would otherwise
        # all truncate to 0.
        binarized = (pred_img * 1000).astype(np.uint16)
        with gzip.open(
                str(out_path /
                    '{}-{:.5f}-pred.npy'.format(img_path.stem, img_scale)),
                'wb',
                compresslevel=4) as f:
            np.save(f, binarized)
        return img_path.stem

    for saved_stem in utils.imap_fixed_output_buffer(save_prediction,
                                                     prediction_results,
                                                     threads=4):
        print(saved_stem)
# Example #7
# 0
def predict(model, img_paths: List[Path], out_path: Path,
            patch_size: int, batch_size: int,
            is_test=False, downsampled=False,
            test_scale=1.0, min_scale=1.0, max_scale=1.0,
            ):
    """Run sliding-window prediction on images and save the results.

    Loading, prediction and saving are chained through
    ``utils.imap_fixed_output_buffer``. For every image a gzip-compressed
    ``.npy`` file named ``<stem>-<scale>-pred.npy`` is written to
    ``out_path``; it holds the averaged prediction multiplied by 1000 and
    cast to uint16.

    Args:
        model: network; switched to eval mode.
        img_paths: images to process.
        out_path: output directory.
        patch_size: side of the square patches fed to the model.
        batch_size: number of patches per model call.
        is_test: use ``test_scale`` for every image.
        downsampled: when false, each prediction plane goes through
            ``utils.downsample(p, PRED_SCALE)`` before saving.
        test_scale: image scale used when ``is_test`` is true.
        min_scale: lower bound of the per-image scale (also the fixed scale
            when it equals ``max_scale``).
        max_scale: upper bound of the per-image scale.
    """
    model.eval()

    def load_image(path):
        if is_test:
            scale = test_scale
        elif min_scale == max_scale:
            scale = min_scale
        else:
            # seeding with the file stem makes the scale repeatable per image
            random.seed(path.stem)
            scale = round(random.uniform(min_scale, max_scale), 5)
        img = utils.load_image(path, cache=False)
        src_h, src_w = img.shape[:2]
        if scale != 1:
            new_size = (int(src_w * scale), int(src_h * scale))
            img = cv2.resize(img, new_size)
        return (path, scale), img

    def predict(arg):
        img_meta, img = arg
        height, width = img.shape[:2]
        side = patch_size
        # step = side // 2 - 32
        stride = side - 64

        def offsets(extent):
            # regular grid, then a final patch aligned with the far border
            starts = list(range(0, extent - side, stride))
            starts.append(extent - side)
            return starts

        all_xy = []
        for x in offsets(width):
            for y in offsets(height):
                all_xy.append((x, y))

        pred_img = np.zeros(
            (utils.N_CLASSES + 1, height, width), dtype=np.float32)
        pred_count = np.zeros((height, width), dtype=np.int32)

        def make_batch(xy_batch_):
            patches = [utils.img_transform(img[y: y + side, x: x + side])
                       for x, y in xy_batch_]
            return xy_batch_, torch.stack(patches)

        xy_batches = list(utils.batches(all_xy, batch_size))
        for xy_batch, inputs in utils.imap_fixed_output_buffer(
                make_batch, tqdm.tqdm(xy_batches), threads=1):
            outputs = model(utils.variable(inputs, volatile=True))
            outputs_data = np.exp(outputs.data.cpu().numpy())
            for (x, y), pred in zip(xy_batch, outputs_data):
                window = (slice(y, y + side), slice(x, x + side))
                pred_img[(slice(None),) + window] += pred
                pred_count[window] += 1

        # average overlapping patches; uncovered pixels keep a divisor of 1
        pred_img /= np.maximum(pred_count, 1)
        return img_meta, pred_img

    loaded = utils.imap_fixed_output_buffer(
        load_image, tqdm.tqdm(img_paths), threads=4)
    prediction_results = utils.imap_fixed_output_buffer(
        predict, loaded, threads=1)

    def save_prediction(arg):
        (img_path, img_scale), pred_img = arg
        if not downsampled:
            planes = [utils.downsample(p, PRED_SCALE) for p in pred_img]
            pred_img = np.stack(planes)
        binarized = (pred_img * 1000).astype(np.uint16)
        file_name = '{}-{:.5f}-pred.npy'.format(img_path.stem, img_scale)
        with gzip.open(str(out_path / file_name), 'wb',
                       compresslevel=4) as f:
            np.save(f, binarized)
        return img_path.stem

    # Drain the pipeline; the saved stems themselves are not needed.
    saved = utils.imap_fixed_output_buffer(
        save_prediction, prediction_results, threads=4)
    for _ in saved:
        pass