Example No. 1
    def train_and_infer(self, idx=0):
        results = []
        set_determinism(seed=0)
        losses, best_metric, best_metric_epoch = run_training_test(self.data_dir, device=self.device, cachedataset=idx)
        infer_metric = run_inference_test(self.data_dir, device=self.device)

        # check training properties
        print("losses", losses)
        print("best metric", best_metric)
        print("infer metric", infer_metric)
        self.assertTrue(test_integration_value(TASK, key="losses", data=losses, rtol=1e-3))
        self.assertTrue(test_integration_value(TASK, key="best_metric", data=best_metric, rtol=1e-2))
        self.assertTrue(len(glob(os.path.join(self.data_dir, "runs"))) > 0)
        model_file = os.path.join(self.data_dir, "best_metric_model.pth")
        self.assertTrue(os.path.exists(model_file))

        # check inference properties
        self.assertTrue(test_integration_value(TASK, key="infer_metric", data=infer_metric, rtol=1e-2))
        output_files = sorted(glob(os.path.join(self.data_dir, "output", "img*", "*.nii.gz")))
        print([np.mean(nib.load(output).get_fdata()) for output in output_files])
        results.extend(losses)
        results.append(best_metric)
        results.append(infer_metric)
        for output in output_files:
            ave = np.mean(nib.load(output).get_fdata())
            results.append(ave)
        self.assertTrue(test_integration_value(TASK, key="output_sums", data=results[8:], rtol=1e-2))
        return results
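Throughout these examples, TASK and test_integration_value come from MONAI's test utilities and are not shown in the snippets; they compare the observed numbers against stored reference values with the given relative tolerance. As a rough mental model only (an assumed stand-in, not the actual helper), the check can be sketched like this:

# Hypothetical sketch of an expected-value check, assuming the reference
# numbers live in a nested dict keyed by task name and metric key.
# This is NOT MONAI's test_integration_value, only an illustration of the
# rtol-based comparison the examples rely on.
import numpy as np

EXPECTED = {
    "integration_segmentation_3d": {          # assumed task name
        "best_metric": 0.93,                  # placeholder reference value
        "losses": [0.55, 0.48, 0.44, 0.42, 0.41, 0.40],  # placeholder values
    }
}

def check_expected(task, key, data, rtol=1e-2):
    """Return True if `data` matches the stored reference within `rtol`."""
    return np.allclose(np.asarray(data), np.asarray(EXPECTED[task][key]), rtol=rtol)

The next example runs the 2D MedNIST classification workflow through the same kind of checks.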
    def train_and_infer(self, idx=0):
        results = []
        if not os.path.exists(os.path.join(self.data_dir, "MedNIST")):
            # skip test if no MedNIST dataset
            return results

        set_determinism(seed=0)
        losses, best_metric, best_metric_epoch = run_training_test(
            self.data_dir, self.train_x, self.train_y, self.val_x, self.val_y, device=self.device
        )
        infer_metric = run_inference_test(self.data_dir, self.test_x, self.test_y, device=self.device)

        print(f"integration_classification_2d {losses}")
        print("best metric", best_metric)
        print("infer metric", infer_metric)
        # check training properties
        self.assertTrue(test_integration_value(TASK, key="losses", data=losses, rtol=1e-2))
        self.assertTrue(test_integration_value(TASK, key="best_metric", data=best_metric, rtol=1e-4))
        np.testing.assert_allclose(best_metric_epoch, 4)
        model_file = os.path.join(self.data_dir, "best_metric_model.pth")
        self.assertTrue(os.path.exists(model_file))
        # check inference properties
        self.assertTrue(test_integration_value(TASK, key="infer_prop", data=np.asarray(infer_metric), rtol=1))
        results.extend(losses)
        results.append(best_metric)
        results.extend(infer_metric)
        return results
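The early return above means the classification test silently passes when the MedNIST folder is absent. A minimal alternative, shown purely as a sketch (the class name and data_dir value are assumptions), is to mark the case as skipped so the outcome is visible in the test report:

import os
import unittest


class MedNISTIntegrationSketch(unittest.TestCase):  # hypothetical test class
    data_dir = "./testing_data"  # assumed dataset location

    def test_train_and_infer(self):
        if not os.path.exists(os.path.join(self.data_dir, "MedNIST")):
            self.skipTest("MedNIST dataset not available")
        # ... run training / inference as in the example above ...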
Example No. 3
    def test_training(self):
        repeated = []
        for i in range(3):
            set_determinism(seed=0)

            repeated.append([])
            losses, best_metric, best_metric_epoch = run_training_test(
                self.data_dir, device=self.device, cachedataset=(i == 2)
            )

            # check training properties
            print("losses", losses)
            self.assertTrue(test_integration_value(TASK, key="losses", data=losses, rtol=1e-3))
            repeated[i].extend(losses)
            print("best metric", best_metric)
            self.assertTrue(test_integration_value(TASK, key="best_metric", data=best_metric, rtol=1e-2))
            repeated[i].append(best_metric)
            self.assertTrue(len(glob(os.path.join(self.data_dir, "runs"))) > 0)
            model_file = os.path.join(self.data_dir, "best_metric_model.pth")
            self.assertTrue(os.path.exists(model_file))

            infer_metric = run_inference_test(self.data_dir, device=self.device)

            # check inference properties
            print("infer metric", infer_metric)
            self.assertTrue(test_integration_value(TASK, key="infer_metric", data=infer_metric, rtol=1e-2))
            repeated[i].append(infer_metric)
            output_files = sorted(glob(os.path.join(self.data_dir, "output", "img*", "*.nii.gz")))
            print([np.mean(nib.load(output).get_fdata()) for output in output_files])
            for output in output_files:
                ave = np.mean(nib.load(output).get_fdata())
                repeated[i].append(ave)
            self.assertTrue(test_integration_value(TASK, key="output_sums", data=repeated[i][8:], rtol=1e-2))
        np.testing.assert_allclose(repeated[0], repeated[1])
        np.testing.assert_allclose(repeated[0], repeated[2])
    def train_and_infer(self, idx=0):
        results = []
        set_determinism(seed=0)
        best_metric = run_training_test(self.data_dir, device=self.device, amp=(idx == 2))
        model_file = sorted(glob(os.path.join(self.data_dir, "net_key_metric*.pt")))[-1]
        infer_metric = run_inference_test(self.data_dir, model_file, device=self.device, amp=(idx == 2))

        print("best metric", best_metric)
        print("infer metric", infer_metric)
        if idx == 2:
            self.assertTrue(test_integration_value(TASK, key="best_metric_2", data=best_metric, rtol=1e-2))
        else:
            self.assertTrue(test_integration_value(TASK, key="best_metric", data=best_metric, rtol=1e-2))
        # check inference properties
        if idx == 2:
            self.assertTrue(test_integration_value(TASK, key="infer_metric_2", data=infer_metric, rtol=1e-2))
        else:
            self.assertTrue(test_integration_value(TASK, key="infer_metric", data=infer_metric, rtol=1e-2))
        results.append(best_metric)
        results.append(infer_metric)

        def _test_saved_files(postfix):
            output_files = sorted(glob(os.path.join(self.data_dir, "img*", f"*{postfix}.nii.gz")))
            values = []
            for output in output_files:
                ave = np.mean(nib.load(output).get_fdata())
                values.append(ave)
            if idx == 2:
                self.assertTrue(test_integration_value(TASK, key="output_sums_2", data=values, rtol=1e-2))
            else:
                self.assertTrue(test_integration_value(TASK, key="output_sums", data=values, rtol=1e-2))

        _test_saved_files(postfix="seg_handler")
        _test_saved_files(postfix="seg_transform")
        try:
            os.remove(model_file)
        except Exception as e:
            warnings.warn(f"Fail to remove {model_file}: {e}.")
        if torch.cuda.is_available():
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass

        return results
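Several examples summarise each saved NIfTI prediction as a single mean voxel value before comparing against the stored sums. That reduction can be factored into a small helper; the folder layout (img*/*.nii.gz) mirrors the globs above, while the function name and the usage path are made up for this sketch:

import os
from glob import glob

import nibabel as nib
import numpy as np


def summarize_outputs(output_dir, pattern=os.path.join("img*", "*.nii.gz")):
    """Return the mean voxel value of every NIfTI file matching `pattern`."""
    files = sorted(glob(os.path.join(output_dir, pattern)))
    return [float(np.mean(nib.load(f).get_fdata())) for f in files]

# usage: values = summarize_outputs("./testing_data")  # path is an assumption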
Example No. 6
    def test_training(self):
        repeated = []
        test_rounds = 3 if monai.config.get_torch_version_tuple() >= (1, 6) else 2
        for i in range(test_rounds):
            set_determinism(seed=0)

            repeated.append([])
            best_metric = run_training_test(self.data_dir,
                                            device=self.device,
                                            amp=(i == 2))
            print("best metric", best_metric)
            if i == 2:
                self.assertTrue(
                    test_integration_value(TASK,
                                           key="best_metric_2",
                                           data=best_metric,
                                           rtol=1e-2))
            else:
                self.assertTrue(
                    test_integration_value(TASK,
                                           key="best_metric",
                                           data=best_metric,
                                           rtol=1e-2))
            repeated[i].append(best_metric)

            model_file = sorted(
                glob(os.path.join(self.data_dir, "net_key_metric*.pt")))[-1]
            infer_metric = run_inference_test(self.data_dir,
                                              model_file,
                                              device=self.device,
                                              amp=(i == 2))
            print("infer metric", infer_metric)
            # check inference properties
            if i == 2:
                self.assertTrue(
                    test_integration_value(TASK,
                                           key="infer_metric_2",
                                           data=infer_metric,
                                           rtol=1e-2))
            else:
                self.assertTrue(
                    test_integration_value(TASK,
                                           key="infer_metric",
                                           data=infer_metric,
                                           rtol=1e-2))
            repeated[i].append(infer_metric)

            output_files = sorted(
                glob(os.path.join(self.data_dir, "img*", "*.nii.gz")))
            for output in output_files:
                ave = np.mean(nib.load(output).get_fdata())
                repeated[i].append(ave)
            # repeated[i][:2] holds best_metric and infer_metric; the rest are per-output averages
            if i == 2:
                self.assertTrue(
                    test_integration_value(TASK,
                                           key="output_sums_2",
                                           data=repeated[i][2:],
                                           rtol=1e-2))
            else:
                self.assertTrue(
                    test_integration_value(TASK,
                                           key="output_sums",
                                           data=repeated[i][2:],
                                           rtol=1e-2))
        np.testing.assert_allclose(repeated[0], repeated[1])
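Example No. 6 gates the extra AMP round on the PyTorch version via monai.config.get_torch_version_tuple(). The same gating logic without the MONAI helper, using a hypothetical torch_version_tuple parser over torch.__version__, could look like this:

import torch


def torch_version_tuple():
    # "1.13.1+cu117" -> (1, 13); release/local suffixes are ignored
    major, minor = torch.__version__.split(".")[:2]
    return int(major), int("".join(ch for ch in minor if ch.isdigit()) or 0)


test_rounds = 3 if torch_version_tuple() >= (1, 6) else 2
for i in range(test_rounds):
    use_amp = (i == 2)  # only the third round exercises AMP, as in the example
    # ... run training / inference with amp=use_amp ...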