Example #1
0
    def __init__(
        self,
        device: torch.device,
        val_data_loader,
        network,
        prepare_batch: Callable = default_prepare_batch,
        iteration_update: Optional[Callable] = None,
        inferer=None,
        post_transform=None,
        key_val_metric=None,
        additional_metrics=None,
        val_handlers=None,
    ):
        """Set up a supervised evaluator that runs ``network`` through ``inferer``.

        Every argument except ``network`` and ``inferer`` is forwarded by
        keyword, unchanged, to the parent class constructor.

        Args:
            device: torch device, forwarded to the parent constructor.
            val_data_loader: validation data source, forwarded to the parent.
            network: model, stored as ``self.network``.
            prepare_batch: callable forwarded to the parent
                (defaults to ``default_prepare_batch``).
            iteration_update: optional callable forwarded to the parent.
            inferer: stored as ``self.inferer``. When ``None`` (the default),
                a new ``SimpleInferer()`` is created for this instance.
            post_transform: forwarded to the parent.
            key_val_metric: forwarded to the parent.
            additional_metrics: forwarded to the parent.
            val_handlers: forwarded to the parent.
        """
        super().__init__(
            device=device,
            val_data_loader=val_data_loader,
            prepare_batch=prepare_batch,
            iteration_update=iteration_update,
            post_transform=post_transform,
            key_val_metric=key_val_metric,
            additional_metrics=additional_metrics,
            val_handlers=val_handlers,
        )

        self.network = network
        # Build the default per instance rather than sharing one SimpleInferer
        # object created once when the function was defined.
        self.inferer = SimpleInferer() if inferer is None else inferer
Example #2
0
    def __init__(
        self,
        device: torch.device,
        max_epochs: int,
        train_data_loader,
        network,
        optimizer,
        loss_function,
        prepare_batch: Callable = default_prepare_batch,
        iteration_update: Optional[Callable] = None,
        inferer=None,
        amp: bool = True,
        post_transform=None,
        key_train_metric: Optional[Metric] = None,
        additional_metrics=None,
        train_handlers=None,
    ):
        """Set up a supervised trainer for ``network``.

        Engine-level settings are forwarded to the parent constructor. Note that
        ``train_data_loader``, ``key_train_metric`` and ``train_handlers`` are
        passed as ``data_loader``, ``key_metric`` and ``handlers``. The model,
        optimizer, loss function and inferer are stored on the instance.

        Args:
            device: torch device, forwarded to the parent.
            max_epochs: forwarded to the parent.
            train_data_loader: forwarded to the parent as ``data_loader``.
            network: model, stored as ``self.network``.
            optimizer: stored as ``self.optimizer``.
            loss_function: stored as ``self.loss_function``.
            prepare_batch: callable forwarded to the parent
                (defaults to ``default_prepare_batch``).
            iteration_update: optional callable forwarded to the parent.
            inferer: stored as ``self.inferer``. When ``None`` (the default),
                a new ``SimpleInferer()`` is created for this instance.
            amp: forwarded to the parent (default ``True``).
            post_transform: forwarded to the parent.
            key_train_metric: forwarded to the parent as ``key_metric``.
            additional_metrics: forwarded to the parent.
            train_handlers: forwarded to the parent as ``handlers``.
        """
        # set up Ignite engine and environments
        super().__init__(
            device=device,
            max_epochs=max_epochs,
            amp=amp,
            data_loader=train_data_loader,
            prepare_batch=prepare_batch,
            iteration_update=iteration_update,
            key_metric=key_train_metric,
            additional_metrics=additional_metrics,
            handlers=train_handlers,
            post_transform=post_transform,
        )

        self.network = network
        self.optimizer = optimizer
        self.loss_function = loss_function
        # Build the default per instance rather than sharing one SimpleInferer
        # object created once when the function was defined.
        self.inferer = SimpleInferer() if inferer is None else inferer
Example #3
0
    def __init__(
        self,
        device,
        max_epochs,
        train_data_loader,
        network,
        optimizer,
        loss_function,
        prepare_batch=default_prepare_batch,
        iteration_update=None,
        lr_scheduler=None,
        inferer=None,
        amp=True,
        key_train_metric=None,
        additional_metrics=None,
        train_handlers=None,
    ):
        """Set up a supervised trainer for ``network``.

        Engine-level settings are forwarded to the parent constructor. Note that
        ``train_data_loader``, ``key_train_metric`` and ``train_handlers`` are
        passed as ``data_loader``, ``key_metric`` and ``handlers``. The model,
        optimizer, loss function, learning-rate scheduler and inferer are stored
        on the instance.

        Args:
            device: forwarded to the parent.
            max_epochs: forwarded to the parent.
            train_data_loader: forwarded to the parent as ``data_loader``.
            network: model, stored as ``self.network``.
            optimizer: stored as ``self.optimizer``.
            loss_function: stored as ``self.loss_function``.
            prepare_batch: callable forwarded to the parent
                (defaults to ``default_prepare_batch``).
            iteration_update: optional callable forwarded to the parent.
            lr_scheduler: stored as ``self.lr_scheduler``. This constructor only
                stores it; nothing visible here steps it.
            inferer: stored as ``self.inferer``. When ``None`` (the default),
                a new ``SimpleInferer()`` is created for this instance.
            amp: forwarded to the parent (default ``True``).
            key_train_metric: forwarded to the parent as ``key_metric``.
            additional_metrics: forwarded to the parent.
            train_handlers: forwarded to the parent as ``handlers``.
        """
        # set up Ignite engine and environments
        super().__init__(
            device=device,
            max_epochs=max_epochs,
            amp=amp,
            data_loader=train_data_loader,
            prepare_batch=prepare_batch,
            iteration_update=iteration_update,
            key_metric=key_train_metric,
            additional_metrics=additional_metrics,
            handlers=train_handlers,
        )

        self.network = network
        self.optimizer = optimizer
        self.loss_function = loss_function
        # Previously this argument was accepted and then silently discarded;
        # keep it so callers and handlers can reach the scheduler they passed.
        self.lr_scheduler = lr_scheduler
        # Build the default per instance rather than sharing one SimpleInferer
        # object created once when the function was defined.
        self.inferer = SimpleInferer() if inferer is None else inferer
Example #4
0
    def initialize(self, args):
        """
        `initialize` is called only once when the model is being loaded.
        Implementing `initialize` function is optional. This function allows
        the model to initialize any state associated with this model.

        Steps performed here:
          1. Download and extract the model archive into
             ``/models/mednist_class/1``, verifying it with an md5 hash.
          2. Parse ``args['model_config']`` as JSON into ``self.model_config``.
          3. Select ``self.inference_device`` from ``args['model_instance_kind']``
             (plus ``args['model_instance_device_id']`` for GPU). CPU is used
             when the kind is missing, 'CPU', unrecognized, or when 'GPU' is
             requested but CUDA is unavailable.
          4. Build ``self.pre_transforms``, ``self.post_transforms`` and
             ``self.inferer``.
          5. Load the TorchScript file ``model.pt`` from this source file's
             directory onto ``self.inference_device``.

        Args:
            args: mapping of load-time arguments. Keys read: 'model_config'
                (JSON string, required), 'model_instance_kind' (optional),
                'model_instance_device_id' (optional, default '0').
        """

        # Pull model from google drive
        extract_dir = "/models/mednist_class/1"
        tar_save_path = os.path.join(extract_dir, model_filename)
        download_and_extract(gdrive_url,
                             tar_save_path,
                             output_dir=extract_dir,
                             hash_val=md5_check,
                             hash_type="md5")
        # load model configuration
        self.model_config = json.loads(args['model_config'])

        # create inferer engine and load PyTorch model
        inference_device_kind = args.get('model_instance_kind', None)
        logger.info(f"Inference device: {inference_device_kind}")

        # CPU is the fallback for every case except a GPU request with CUDA
        # available, so only that case needs to override it.
        self.inference_device = torch.device('cpu')
        if inference_device_kind == 'GPU':
            inference_device_id = args.get('model_instance_device_id', '0')
            logger.info(f"Inference device id: {inference_device_id}")

            if torch.cuda.is_available():
                self.inference_device = torch.device(
                    f'cuda:{inference_device_id}')
                cudnn.enabled = True
            else:
                # Report the device actually in use (CPU), not the requested
                # kind, which previously made the log claim "GPU".
                logger.error(
                    f"No CUDA device detected. Using device: {self.inference_device}"
                )

        # create pre-transforms for MedNIST
        self.pre_transforms = Compose([
            LoadImage(reader="PILReader", image_only=True, dtype=np.float32),
            ScaleIntensity(),
            # Two leading singleton dims are added; presumably channel and
            # then batch for the network input -- TODO confirm.
            AddChannel(),
            AddChannel(),
            ToTensor(),
            # self.inference_device is read when the transform runs.
            Lambda(func=lambda x: x.to(device=self.inference_device)),
        ])

        # create post-transforms: move model output back to the CPU
        self.post_transforms = Compose([
            Lambda(func=lambda x: x.to(device="cpu")),
        ])

        self.inferer = SimpleInferer()

        # The TorchScript model lives next to this source file.
        model_path = pathlib.Path(os.path.realpath(__file__)).parent / 'model.pt'
        self.model = torch.jit.load(
            str(model_path),
            map_location=self.inference_device)
Example #5
0
    def __init__(
        self,
        device,
        val_data_loader,
        network,
        prepare_batch=default_prepare_batch,
        iteration_update=None,
        inferer=None,
        key_val_metric=None,
        additional_metrics=None,
        val_handlers=None,
    ):
        """Set up a supervised evaluator that runs ``network`` through ``inferer``.

        Every argument except ``network`` and ``inferer`` is forwarded to the
        parent class constructor by keyword.

        Args:
            device: forwarded to the parent.
            val_data_loader: validation data source, forwarded to the parent.
            network: model, stored as ``self.network``.
            prepare_batch: callable forwarded to the parent
                (defaults to ``default_prepare_batch``).
            iteration_update: optional callable forwarded to the parent.
            inferer: stored as ``self.inferer``. When ``None`` (the default),
                a new ``SimpleInferer()`` is created for this instance.
            key_val_metric: forwarded to the parent.
            additional_metrics: forwarded to the parent.
            val_handlers: forwarded to the parent.
        """
        # Pass by keyword: a positional call silently misbinds arguments if the
        # parent signature gains a parameter (e.g. ``post_transform`` inserted
        # before ``key_val_metric``).
        super().__init__(
            device=device,
            val_data_loader=val_data_loader,
            prepare_batch=prepare_batch,
            iteration_update=iteration_update,
            key_val_metric=key_val_metric,
            additional_metrics=additional_metrics,
            val_handlers=val_handlers,
        )

        self.network = network
        # Build the default per instance rather than sharing one SimpleInferer
        # object created once when the function was defined.
        self.inferer = SimpleInferer() if inferer is None else inferer
Example #6
0
    def initialize(self, args):
        """
        `initialize` is called only once when the model is being loaded.
        Implementing `initialize` function is optional. This function allows
        the model to initialize any state associated with this model.

        Steps performed here:
          1. Download and extract the model archive into
             ``/models/monai_covid/1``, verifying it with an md5 hash.
          2. Parse ``args['model_config']`` as JSON into ``self.model_config``.
          3. Select ``self.inference_device`` from ``args['model_instance_kind']``
             (plus ``args['model_instance_device_id']`` for GPU). CPU is used
             when the kind is missing, 'CPU', unrecognized, or when 'GPU' is
             requested but CUDA is unavailable.
          4. Build ``self.pre_transforms``, ``self.post_transforms`` and
             ``self.inferer``.
          5. Load the TorchScript file ``covid19_model.ts`` from this source
             file's directory onto ``self.inference_device``.

        Args:
            args: mapping of load-time arguments. Keys read: 'model_config'
                (JSON string, required), 'model_instance_kind' (optional),
                'model_instance_device_id' (optional, default '0').
        """

        # Pull model from google drive
        extract_dir = "/models/monai_covid/1"
        tar_save_path = os.path.join(extract_dir, model_filename)
        download_and_extract(gdrive_url,
                             tar_save_path,
                             output_dir=extract_dir,
                             hash_val=md5_check,
                             hash_type="md5")
        # load model configuration
        self.model_config = json.loads(args['model_config'])

        # create inferer engine and load PyTorch model
        inference_device_kind = args.get('model_instance_kind', None)
        logger.info(f"Inference device: {inference_device_kind}")

        # CPU is the fallback for every case except a GPU request with CUDA
        # available, so only that case needs to override it.
        self.inference_device = torch.device('cpu')
        if inference_device_kind == 'GPU':
            inference_device_id = args.get('model_instance_device_id', '0')
            logger.info(f"Inference device id: {inference_device_id}")

            if torch.cuda.is_available():
                self.inference_device = torch.device(
                    f'cuda:{inference_device_id}')
                cudnn.enabled = True
            else:
                # Report the device actually in use (CPU), not the requested
                # kind, which previously made the log claim "GPU".
                logger.error(
                    f"No CUDA device detected. Using device: {self.inference_device}"
                )

        # create pre-transforms
        self.pre_transforms = Compose([
            LoadImage(reader="NibabelReader",
                      image_only=True,
                      dtype=np.float32),
            AddChannel(),
            # Window intensities a_min=-1000..a_max=500 into [0.0, 1.0],
            # clipping values outside the window.
            ScaleIntensityRange(a_min=-1000,
                                a_max=500,
                                b_min=0.0,
                                b_max=1.0,
                                clip=True),
            CropForeground(margin=5),
            Resize([192, 192, 64], mode="area"),
            # Second leading singleton dim; presumably the batch dimension for
            # the network input -- TODO confirm.
            AddChannel(),
            ToTensor(),
            # self.inference_device is read when the transform runs.
            Lambda(func=lambda x: x.to(device=self.inference_device)),
        ])

        # create post-transforms: move to CPU, apply sigmoid, convert to NumPy,
        # then threshold at 0.5
        self.post_transforms = Compose([
            Lambda(func=lambda x: x.to(device="cpu")),
            Activations(sigmoid=True),
            ToNumpy(),
            AsDiscrete(threshold_values=True, logit_thresh=0.5),
        ])

        self.inferer = SimpleInferer()

        # The TorchScript model lives next to this source file.
        model_path = (pathlib.Path(os.path.realpath(__file__)).parent
                      / 'covid19_model.ts')
        self.model = torch.jit.load(
            str(model_path),
            map_location=self.inference_device)