Ejemplo n.º 1
0
    def test_export(self, key_in_ckpt):
        """Check that the ``ckpt_export`` bundle command writes the output file.

        The network parsed from the inference config is saved with
        ``save_state`` (either directly or wrapped in a dict under
        ``key_in_ckpt``), then ``monai.bundle ckpt_export`` is run in a
        subprocess under ``coverage`` and the target file must exist afterwards.

        Args:
            key_in_ckpt: key under which the network is stored in the
                checkpoint; an empty string means the network is passed to
                ``save_state`` as-is.
        """
        data_dir = os.path.join(os.path.dirname(__file__), "testing_data")
        meta_file = os.path.join(data_dir, "metadata.json")
        config_file = os.path.join(data_dir, "inference.json")

        with tempfile.TemporaryDirectory() as tempdir:
            def_args_file = os.path.join(tempdir, "def_args.json")
            ckpt_file = os.path.join(tempdir, "model.pt")
            ts_file = os.path.join(tempdir, "model.ts")

            parser = ConfigParser()
            # Placeholder args file; per its own text, the `meta_file` entry is
            # meant to be overridden by the `--meta_file` CLI argument below.
            parser.export_config_file(
                config={"meta_file": "will be replaced by `meta_file` arg"},
                filepath=def_args_file,
            )
            parser.read_config(config_file)
            net = parser.get_parsed_content("network_def")

            if key_in_ckpt == "":
                checkpoint_src = net
            else:
                checkpoint_src = {key_in_ckpt: net}
            save_state(src=checkpoint_src, path=ckpt_file)

            cmd = [
                "coverage", "run", "-m", "monai.bundle", "ckpt_export",
                "network_def", "--filepath", ts_file,
                "--meta_file", meta_file,
                "--config_file", config_file,
                "--ckpt_file", ckpt_file,
                "--key_in_ckpt", key_in_ckpt,
                "--args_file", def_args_file,
            ]
            subprocess.check_call(cmd)
            self.assertTrue(os.path.exists(ts_file))
Ejemplo n.º 2
0
    def test_transforms(self, case_id):
        """Run the transform configured under ``case_id`` forward and inverse.

        The transform instance is parsed from ``TEST_CASES`` and applied via a
        ``CacheDataset``/``DataLoader``. The last decollated batch is then
        compared against the expected values stored under
        ``f"{case_id}_answer"``, written out with ``SaveImageD``, inverted with
        ``InvertD`` and checked again. Finally, the saved inverted segmentation
        is compared voxel-wise with the file at ``FILE_PATH_1``.

        Args:
            case_id: id of the transform entry in the test-case config.
        """
        set_determinism(2022)
        config = ConfigParser()
        config.read_config(TEST_CASES)
        config["input_keys"] = keys
        test_case = config.get_parsed_content(id=case_id, instantiate=True)  # transform instance

        dataset = CacheDataset(self.files, transform=test_case)
        loader = DataLoader(dataset, batch_size=3, shuffle=True)
        out = None
        for batch in loader:
            self.assertIsInstance(batch[keys[0]], MetaTensor)
            self.assertIsInstance(batch[keys[1]], MetaTensor)
            out = decollate_batch(batch)  # decollate every batch should work
        # Fail with a clear message instead of a NameError if nothing was loaded;
        # the checks below use the last batch and its decollated items.
        self.assertIsNotNone(out, "the data loader yielded no batches")

        # test forward patches
        loaded = out[0]
        self.assertEqual(len(loaded), len(keys))
        img, seg = loaded[keys[0]], loaded[keys[1]]
        expected = config.get_parsed_content(id=f"{case_id}_answer", instantiate=True)  # expected results
        self.assertEqual(expected["load_shape"], list(batch[keys[0]].shape))
        assert_allclose(expected["affine"], img.affine, type_test=False, atol=TINY_DIFF, rtol=TINY_DIFF)
        assert_allclose(expected["affine"], seg.affine, type_test=False, atol=TINY_DIFF, rtol=TINY_DIFF)
        test_cls = [type(transform).__name__ for transform in test_case.transforms]
        tracked_cls = [op[TraceKeys.CLASS_NAME] for op in img.applied_operations]
        # tracked items should be no more than the compose items.
        self.assertLessEqual(len(tracked_cls), len(test_cls))
        with tempfile.TemporaryDirectory() as tempdir:  # test writer
            SaveImageD(keys, resample=False, output_dir=tempdir, output_postfix=case_id)(loaded)

        # test inverse
        inv = InvertD(keys, orig_keys=keys, transform=test_case, nearest_interp=True)
        out = inv(loaded)
        img, seg = out[keys[0]], out[keys[1]]
        assert_allclose(expected["inv_affine"], img.affine, type_test=False, atol=TINY_DIFF, rtol=TINY_DIFF)
        assert_allclose(expected["inv_affine"], seg.affine, type_test=False, atol=TINY_DIFF, rtol=TINY_DIFF)
        # after inversion no applied operations should remain on either item
        self.assertFalse(img.applied_operations)
        self.assertFalse(seg.applied_operations)
        assert_allclose(expected["inv_shape"], img.shape, type_test=False, atol=TINY_DIFF, rtol=TINY_DIFF)
        assert_allclose(expected["inv_shape"], seg.shape, type_test=False, atol=TINY_DIFF, rtol=TINY_DIFF)
        with tempfile.TemporaryDirectory() as tempdir:  # test writer
            SaveImageD(keys, resample=False, output_dir=tempdir, output_postfix=case_id)(out)
            seg_file = os.path.join(tempdir, key_1, f"{key_1}_{case_id}.nii.gz")
            # read the data while the temporary directory still exists
            segout = nib.load(seg_file).get_fdata()
            segin = nib.load(FILE_PATH_1).get_fdata()
            ndiff = np.sum(np.abs(segout - segin) > 0)
            total = np.prod(segout.shape)
        # fraction of voxels that differ between the inverted output and the input file
        diff_ratio = ndiff / total
        self.assertLess(diff_ratio, 0.4, f"{diff_ratio}")