Example #1
0
def test_fliplr_image2mask():
    """Flip-LR test-time augmentation around ``NoOp()`` must reproduce its input.

    A random batch is passed through ``tta.fliplr_image2mask`` with a ``NoOp``
    model and the result is compared to the input with ``rtol=atol=1e-6``.
    """
    inputs = torch.rand((4, 3, 224, 224))
    passthrough = NoOp()

    augmented = tta.fliplr_image2mask(passthrough, inputs)

    actual = to_numpy(augmented)
    expected = to_numpy(inputs)
    np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6)
Example #2
0
 def forward(self, x):
     """Apply ``fliplr_image2mask`` test-time augmentation to ``self.model``.

     Args:
         x: the input batch, forwarded unchanged to ``fliplr_image2mask``.

     Returns:
         Whatever ``fliplr_image2mask(self.model, x)`` returns.
     """
     wrapped = self.model
     augmented = fliplr_image2mask(wrapped, x)
     return augmented
Example #3
0
def _prepare_model(model, checkpoint_path):
    """Load checkpoint weights into ``model`` and wrap it for TTA inference.

    The checkpoint's ``"state_dict"`` is loaded into ``model``. The model is then
    followed by a sigmoid and wrapped in ``tta.MultiscaleTTAWrapper`` with scales
    ``[0.5, 2]``. Finally it is put in eval mode, cast to half precision and moved
    to CUDA.

    Args:
        model: the freshly constructed segmentation model.
        checkpoint_path: path passed to ``load_checkpoint(file_path=...)``.

    Returns:
        The wrapped inference-ready model.
    """
    # Empty mapping passed as ``rename_in_layers``; presumably no layer names are
    # rewritten when loading -- confirm against load_checkpoint.
    corrections: Dict[str, str] = {}

    checkpoint = load_checkpoint(file_path=checkpoint_path, rename_in_layers=corrections)  # type: ignore

    model.load_state_dict(checkpoint["state_dict"])
    model = nn.Sequential(model, nn.Sigmoid())

    model = tta.MultiscaleTTAWrapper(model, [0.5, 2])

    model.eval()
    model = model.half()
    model.cuda()  # nn.Module.cuda() moves parameters in place
    return model


def _postprocess_mask(prob, pads, height, width, threshold=0.5, min_area=100):
    """Turn one probability map into a cleaned binary mask at the original image size.

    Args:
        prob: per-pixel probabilities for a single image (a torch tensor).
        pads: padding description forwarded to ``unpad``.
        height: target output height in pixels.
        width: target output width in pixels.
        threshold: probabilities strictly above this become 1, others 0.
        min_area: forwarded to both ``remove_small_connected_binary`` and
            ``fill_small_holes``.

    Returns:
        A ``uint8`` mask with values 0/1, resized to ``(height, width)``.
    """
    mask = (prob > threshold).cpu().numpy().astype(np.uint8)
    mask = unpad(mask, pads)

    mask = remove_small_connected_binary(mask, min_area=min_area)
    mask = fill_small_holes(mask, min_area=min_area)

    # Nearest-neighbour keeps the mask strictly binary while resizing.
    return cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)


def _save_outputs(image_dir, mask_dir, vis_dir, image_id, mask):
    """Write the mask PNG and a side-by-side visualization JPEG for one image.

    ``mask * 255`` is written to ``<mask_dir>/<image_id>.png``. The output
    ``<vis_dir>/<image_id>.jpg`` is the horizontal concatenation of two images:

    - ``mask_overlay(image, mask)``;
    - the 0/255 mask converted to BGR.

    Here ``image`` is read from ``<image_dir>/<image_id>.jpg``.

    Raises:
        FileNotFoundError: if OpenCV cannot read the source image.
    """
    mask_255 = mask * 255
    cv2.imwrite(str(mask_dir / f"{image_id}.png"), mask_255)

    image_file = image_dir / f"{image_id}.jpg"
    image = cv2.imread(str(image_file))
    if image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {image_file}")

    mask_image = mask_overlay(image, mask)

    cv2.imwrite(
        str(vis_dir / f"{image_id}.jpg"),
        np.hstack([mask_image, cv2.cvtColor(mask_255, cv2.COLOR_GRAY2BGR)]),
    )


def main():
    """Run segmentation inference on every ``*.jpg`` under ``args.image_path``.

    Hyperparameters are read from the YAML file at ``args.config_path`` and
    weights from ``args.checkpoint_path``. For each image the function writes:

    - a binary mask to ``<output_path>/masks/<image_id>.png``;
    - a visualization to ``<output_path>/vis/<image_id>.jpg``.

    Predictions use multiscale TTA plus flip-LR TTA in half precision on CUDA.
    """
    args = get_args()

    # argparse may supply plain strings; wrapping in Path makes ``/`` work either
    # way and is a no-op when they are already Path objects.
    image_dir = Path(args.image_path)
    output_dir = Path(args.output_path)

    with open(args.config_path) as f:
        hparams = yaml.load(f, Loader=yaml.SafeLoader)

    model = SegmentDocs(hparams)

    test_file_names = sorted(image_dir.glob("*.jpg"))

    test_mask_path = output_dir / "masks"
    test_vis_path = output_dir / "vis"

    test_mask_path.mkdir(exist_ok=True, parents=True)
    test_vis_path.mkdir(exist_ok=True, parents=True)

    test_aug = from_dict(hparams["test_aug"])

    dataloader = DataLoader(
        SegmentationDatasetTest(test_file_names, test_aug),
        batch_size=hparams["test_parameters"]["batch_size"],
        num_workers=hparams["num_workers"],
        shuffle=False,
        pin_memory=True,
        drop_last=False,
    )

    model = _prepare_model(model, args.checkpoint_path)

    with torch.no_grad():
        for batch in tqdm(dataloader):
            features = batch["features"]
            image_ids = batch["image_id"]

            # Flip-LR TTA on top of the multiscale-wrapped model.
            preds = tta.fliplr_image2mask(model, features.half().cuda())

            for batch_id in range(features.shape[0]):
                image_id = image_ids[batch_id]

                mask = _postprocess_mask(
                    preds[batch_id][0],
                    pads=batch["pads"][batch_id].cpu().numpy(),
                    height=batch["height"][batch_id].item(),
                    width=batch["width"][batch_id].item(),
                )

                _save_outputs(image_dir, test_mask_path, test_vis_path, image_id, mask)