Example No. 1
    def restore_from(
        self,
        calling_cls,
        restore_path: str,
        override_config_path: Optional[Union[OmegaConf, str]] = None,
        map_location: Optional[torch.device] = None,
        strict: bool = True,
        return_config: bool = False,
        trainer: Trainer = None,
    ):
        """
        Restores a model instance (weights and configuration) from a .mridc file.

        Parameters
        ----------
        calling_cls: The class of the model to be restored.
        restore_path: path to .mridc file from which model should be instantiated
        override_config_path: path to a yaml config that will override the internal config file or an
        OmegaConf/DictConfig object representing the model config.
        map_location: Optional torch.device() to map the instantiated model to a device. By default (None), it will
        select a GPU if available, falling back to CPU otherwise.
        strict: Passed to load_state_dict. By default, True.
        return_config: If set to true, will return just the underlying config of the restored model as an
        OmegaConf/DictConfig object without instantiating the model.
        trainer: Optional trainer object to be used for restoring the model.

        Returns
        -------
        An instance of type cls or its underlying config (if return_config is set).
        """
        # Get path where the command is executed - the artifacts will be "retrieved" there (original .mridc behavior)
        loaded_params = self.load_config_and_state_dict(
            calling_cls,
            restore_path,
            override_config_path,
            map_location,
            strict,
            return_config,
            trainer,
        )

        if not isinstance(loaded_params, tuple):
            return loaded_params

        _, instance, state_dict = loaded_params
        self.load_instance_with_state_dict(instance, state_dict, strict)
        logging.info(
            f"Model {instance.__class__.__name__} was successfully restored from {restore_path}."
        )
        return instance
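For reference, a minimal usage sketch of the public entry point that delegates to this connector method (mirroring the example shown in load_config_and_state_dict further below; the model class and file name are placeholders):

import mridc

# Restore a model instance (weights + configuration) from a local .mridc archive.
model = mridc.collections.asr.models.EncDecCTCModel.restore_from("asr.mridc")

# Or inspect only the stored configuration without instantiating the model.
config = mridc.collections.asr.models.EncDecCTCModel.restore_from("asr.mridc", return_config=True)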
Example No. 2
    def load_state_dict(self, state_dict):
        """Load the state of the optimizer."""
        # Optimizer.
        optimizer_key = "optimizer"
        if optimizer_key not in state_dict:
            optimizer_key = "optimizer_state_dict"
            logging.info("***WARNING*** loading optimizer from "
                         "an old checkpoint ...")
        self.optimizer.load_state_dict(state_dict[optimizer_key])

        # Copy data for the main params.
        fp32_from_float16_params_key = "fp32_from_fp16_params"
        if fp32_from_float16_params_key not in state_dict:
            fp32_from_float16_params_key = "fp32_from_fp16"
        for current_group, saved_group in zip(
                self.fp32_from_float16_groups,
                state_dict[fp32_from_float16_params_key]):
            for current_param, saved_param in zip(current_group, saved_group):
                current_param.data.copy_(saved_param.data)
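The backward-compatible key lookup above can be illustrated in isolation; a minimal sketch with made-up checkpoint dictionaries (not real optimizer states):

def resolve_optimizer_key(state_dict):
    """Return the key under which the optimizer state was saved."""
    # Newer checkpoints use "optimizer"; older ones used "optimizer_state_dict".
    return "optimizer" if "optimizer" in state_dict else "optimizer_state_dict"

new_style = {"optimizer": {"state": {}, "param_groups": []}, "fp32_from_fp16_params": []}
old_style = {"optimizer_state_dict": {"state": {}, "param_groups": []}, "fp32_from_fp16": []}

assert resolve_optimizer_key(new_style) == "optimizer"
assert resolve_optimizer_key(old_style) == "optimizer_state_dict"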
Example No. 3
def initialize_distributed(args, backend="nccl"):
    """
    Initialize distributed training.

    Parameters
    ----------
    args: The arguments object.
    backend: The backend to use.
        default: "nccl"

    Returns
    -------
    local_rank: The local rank of the process.
    rank: The rank of the process.
    world_size: The number of processes.
    """
    # Get local rank in case it is provided.
    local_rank = args.local_rank

    # Get rank and world size.
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))

    logging.info(
        f"Initializing torch.distributed with local_rank: {local_rank}, rank: {rank}, world_size: {world_size}"
    )

    # Set the device id.
    device = rank % torch.cuda.device_count()
    if local_rank is not None:
        device = local_rank
    torch.cuda.set_device(device)

    # Call the init process.
    init_method = "tcp://"
    master_ip = os.getenv("MASTER_ADDR", "localhost")
    master_port = os.getenv("MASTER_PORT", "6000")
    init_method += f"{master_ip}:{master_port}"
    torch.distributed.init_process_group(backend=backend,
                                         world_size=world_size,
                                         rank=rank,
                                         init_method=init_method)
    return local_rank, rank, world_size
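A sketch of how this helper might be driven under a launcher such as torchrun or torch.distributed.launch, which export RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT; the argparse wiring here is illustrative, not taken from the repository:

import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=None)
args = parser.parse_args()

# Reasonable single-node defaults in case the launcher did not set them.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "6000")

local_rank, rank, world_size = initialize_distributed(args, backend="nccl")
print(f"rank {rank}/{world_size} initialized (local_rank={local_rank})")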
Example No. 4
    def _del_model_without_trainer(self, filepath: str) -> None:
        """
        Delete a model without a trainer.

        Parameters
        ----------
        filepath: The path to the model to delete.
        """
        app_state = AppState()
        if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
            # filepath needs to be updated to include mp_rank
            filepath = mridc.utils.model_utils.inject_model_parallel_rank(
                filepath)  # type: ignore

        # each model parallel rank needs to remove its model
        if is_global_rank_zero() or (app_state.model_parallel_size is not None
                                     and app_state.data_parallel_rank == 0):
            try:
                self._fs.rm(filepath)
                logging.info(f"Removed checkpoint: {filepath}")
            except FileNotFoundError:
                logging.info(
                    f"Tried to remove checkpoint: {filepath} but failed.")
Example No. 5
def main(cfg: DictConfig) -> None:
    """
    Main function for training and running a model

    Parameters
    ----------
    cfg: Configuration (yaml) file.
        DictConfig
    """
    logging.info(f"Config: {OmegaConf.to_yaml(cfg)}")

    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))

    model_name = (cfg.model["model_name"]).upper()

    if model_name == "CASCADENET":
        model = CascadeNet(cfg.model, trainer=trainer)
    elif model_name == "CIRIM":
        model = CIRIM(cfg.model, trainer=trainer)
    elif model_name == "CRNNET":
        model = CRNNet(cfg.model, trainer=trainer)
    elif model_name == "DUNET":
        model = DUNet(cfg.model, trainer=trainer)
    elif model_name in ("E2EVN", "VN"):
        model = VarNet(cfg.model, trainer=trainer)
    elif model_name == "JOINTICNET":
        model = JointICNet(cfg.model, trainer=trainer)
    elif model_name == "KIKINET":
        model = KIKINet(cfg.model, trainer=trainer)
    elif model_name == "LPDNET":
        model = LPDNet(cfg.model, trainer=trainer)
    elif model_name == "MULTIDOMAINNET":
        model = MultiDomainNet(cfg.model, trainer=trainer)
    elif model_name == "PICS":
        model = PICS(cfg.model, trainer=trainer)
    elif model_name == "RVN":
        model = RecurrentVarNet(cfg.model, trainer=trainer)
    elif model_name == "UNET":
        model = UNet(cfg.model, trainer=trainer)
    elif model_name == "VSNET":
        model = VSNet(cfg.model, trainer=trainer)
    elif model_name == "XPDNET":
        model = XPDNet(cfg.model, trainer=trainer)
    elif model_name == "ZF":
        model = ZF(cfg.model, trainer=trainer)
    else:
        raise NotImplementedError(
            f"{model_name} is not implemented in MRIDC. You can use one of the following methods: "
            "CASCADENET, CIRIM, CRNNET, DUNET, E2EVN, JOINTICNET, KIKINET, LPDNET, MULTIDOMAINNET, PICS, RVN, UNET, "
            "VSNET, XPDNET, or Zero-Filled. /n"
            "If you implemented a new model, please consider adding it through a PR on GitHub."
        )

    if cfg.get("pretrained", None):
        checkpoint = cfg.get("checkpoint", None)
        logging.info(f"Loading pretrained model from {checkpoint}")
        model.load_state_dict(torch.load(checkpoint)["state_dict"])

    if cfg.get("mode", None) == "train":
        logging.info("Validating")
        trainer.validate(model)
        logging.info("Training")
        trainer.fit(model)
    else:
        logging.info("Evaluating")
        trainer.test(model)
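This main function is typically invoked through a Hydra entry point that loads the YAML configuration; a minimal sketch, assuming a hypothetical conf/train.yaml next to the script (both names are placeholders):

import hydra
from omegaconf import DictConfig


@hydra.main(config_path="conf", config_name="train")
def launch(cfg: DictConfig) -> None:
    # Delegate to the dispatcher above, which instantiates the requested model
    # (CIRIM, UNET, VN, ...) and runs training or evaluation.
    main(cfg)


if __name__ == "__main__":
    launch()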
Example No. 6
def assert_dataclass_signature_match(
    cls: "class_type",  # type: ignore
    datacls: "dataclass",  # type: ignore
    ignore_args: Optional[List[str]] = None,
    remap_args: Optional[Dict[str, str]] = None,
):
    """
    Analyses the signature of a provided class and its respective data class,
    asserting that the dataclass signature matches the class __init__ signature.
    Note:
        This is not a value based check. This function only checks if all argument
        names exist on both class and dataclass and logs mismatches.

    Parameters
    ----------
    cls: Any class type - but not an instance of a class. Pass type(x) where x is an instance
        if class type is not easily available.
    datacls: A corresponding dataclass for the above class.
    ignore_args: (Optional) A list of string argument names which are forcibly ignored,
        even if mismatched in the signature. Useful when a dataclass is a superset of the
        arguments of a class.
    remap_args: (Optional) A dictionary, mapping an argument name that exists (in either the
        class or its dataclass), to another name. Useful when argument names are mismatched between
        a class and its dataclass due to indirect instantiation via a helper method.

    Returns
    -------
    A tuple containing information about the analysis:
        1) A bool value which is True if the signatures matched exactly / after ignoring values.
            False otherwise.
        2) A set of arguments names that exist in the class, but *do not* exist in the dataclass.
            If exact signature match occurs, this will be None instead.
        3) A set of argument names that exist in the data class, but *do not* exist in the class itself.
            If exact signature match occurs, this will be None instead.
    """
    class_sig = inspect.signature(cls.__init__)

    class_params = dict(**class_sig.parameters)
    class_params.pop("self")

    dataclass_sig = inspect.signature(datacls)

    dataclass_params = dict(**dataclass_sig.parameters)
    dataclass_params.pop("_target_", None)

    class_params = set(class_params.keys())  # type: ignore
    dataclass_params = set(dataclass_params.keys())  # type: ignore

    if remap_args is not None:
        for original_arg, new_arg in remap_args.items():
            if original_arg in class_params:
                class_params.remove(original_arg)  # type: ignore
                class_params.add(new_arg)  # type: ignore
                logging.info(
                    f"Remapped {original_arg} -> {new_arg} in {cls.__name__}")

            if original_arg in dataclass_params:
                dataclass_params.remove(original_arg)  # type: ignore
                dataclass_params.add(new_arg)  # type: ignore
                logging.info(
                    f"Remapped {original_arg} -> {new_arg} in {datacls.__name__}"
                )

    if ignore_args is not None:
        ignore_args = set(ignore_args)  # type: ignore

        class_params = class_params - ignore_args  # type: ignore
        dataclass_params = dataclass_params - ignore_args  # type: ignore
        logging.info(f"Removing ignored arguments - {ignore_args}")

    intersection: Set[str] = set.intersection(
        class_params, dataclass_params)  # type: ignore
    subset_cls = class_params - intersection  # type: ignore
    subset_datacls = dataclass_params - intersection  # type: ignore

    if (len(class_params) != len(dataclass_params)
        ) or len(subset_cls) > 0 or len(subset_datacls) > 0:
        logging.error(f"Class {cls.__name__} arguments do not match "
                      f"Dataclass {datacls.__name__}!")

        if len(subset_cls) > 0:
            logging.error(f"Class {cls.__name__} has additional arguments :\n"
                          f"{subset_cls}")

        if len(subset_datacls):
            logging.error(
                f"Dataclass {datacls.__name__} has additional arguments :\n{subset_datacls}"
            )

        return False, subset_cls, subset_datacls
    return True, None, None
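A small, self-contained usage sketch with a toy class and dataclass (the names below are hypothetical and not part of mridc):

from dataclasses import dataclass


class Conv:
    def __init__(self, channels: int, kernel_size: int, bias: bool = True):
        ...


@dataclass
class ConvConfig:
    channels: int = 64
    kernel_size: int = 3
    bias: bool = True
    dropout: float = 0.0       # extra field not present on Conv.__init__
    _target_: str = "Conv"     # dropped automatically by the check


# The extra `dropout` field is reported as a dataclass-only argument ...
matched, cls_only, datacls_only = assert_dataclass_signature_match(Conv, ConvConfig)
assert not matched and datacls_only == {"dropout"}

# ... unless it is explicitly ignored.
matched, _, _ = assert_dataclass_signature_match(Conv, ConvConfig, ignore_args=["dropout"])
assert matched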
Example No. 7
def verify_runtime(
    output,
    input_list,
    input_dict,
    input_names,
    output_names,
    output_example,
    check_tolerance=0.01,
):
    """
    Verify runtime output with onnxrt.

    Parameters
    ----------
    output: The output of the module.
    input_list: The input list of the module.
    input_dict: The input dict of the module.
    input_names: The input names of the module.
    output_names: The output names of the module.
    output_example: The output example of the module.
    check_tolerance: The tolerance for the check.

    Returns
    -------
    The runtime output.
    """
    # Verify the model can be read, and is valid
    onnx_model = onnx.load(output)
    input_names = [node.name for node in onnx_model.graph.input]
    # skipcq: PYL-W0622
    global ort_available
    if not ort_available:
        logging.warning(
            f"ONNX generated at {output}, not verified - please install onnxruntime_gpu package.\n"
        )
        onnx.checker.check_model(onnx_model, full_check=True)
        return

    onnx_session_opt = onnxruntime.SessionOptions()
    onnx_session_opt.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL

    sess = onnxruntime.InferenceSession(onnx_model.SerializeToString(),
                                        sess_options=onnx_session_opt,
                                        providers=["CUDAExecutionProvider"])
    ort_out = sess.run(output_names,
                       to_onnxrt_input(input_names, input_dict, input_list))
    all_good = True

    for i, out in enumerate(ort_out[0]):
        expected = output_example[i]
        if torch.is_tensor(expected):
            tout = torch.from_numpy(out)
            if not torch.allclose(tout,
                                  expected.cpu(),
                                  rtol=check_tolerance,
                                  atol=100 * check_tolerance):
                all_good = False
                logging.info(
                    f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{tout}"
                )
    status = "SUCCESS" if all_good else "FAIL"
    logging.info(
        f"ONNX generated at {output} verified with onnxruntime : {status}")
    return all_good
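A sketch of how this check might be driven after an ONNX export; the module, file name, and input/output names are placeholders, and onnx/onnxruntime must be installed for the full verification to run:

import torch

model = torch.nn.Linear(16, 4).eval()
dummy = torch.randn(1, 16)

# Export with explicit input/output names so the verifier can bind them.
torch.onnx.export(model, (dummy,), "model.onnx",
                  input_names=["input"], output_names=["output"])

with torch.no_grad():
    expected = model(dummy)

ok = verify_runtime(
    output="model.onnx",
    input_list=[dummy],
    input_dict={},
    input_names=["input"],
    output_names=["output"],
    output_example=(expected,),
    check_tolerance=0.01,
)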
Example No. 8
File: cloud.py Project: wdika/mridc
def maybe_download_from_cloud(url,
                              filename,
                              subfolder=None,
                              cache_dir=None,
                              refresh_cache=False) -> str:
    """
    Download a file from a URL if it does not exist in the cache.

    Parameters
    ----------
    url: URL to download the file from.
        str
    filename: What to download. The request will be issued to url/filename
        str
    subfolder: Subfolder within cache_dir. The file will be stored in cache_dir/subfolder. Subfolder can be empty.
        str
    cache_dir: A cache directory where to download. If not present, this function will attempt to create it.
        str, If None (default), then it will be $HOME/.cache/torch/mridc
    refresh_cache: If True and cached file is present, it will delete it and re-fetch
        bool

    Returns
    -------
    If successful - absolute local path to the downloaded file else empty string.
    """
    if cache_dir is None:
        cache_location = Path.joinpath(Path.home(), ".cache/torch/mridc")
    else:
        cache_location = cache_dir
    if subfolder is not None:
        destination = Path.joinpath(cache_location, subfolder)
    else:
        destination = cache_location

    if not os.path.exists(destination):
        os.makedirs(destination, exist_ok=True)

    destination_file = Path.joinpath(destination, filename)

    if os.path.exists(destination_file):
        logging.info(f"Found existing object {destination_file}.")
        if refresh_cache:
            logging.info("Asked to refresh the cache.")
            logging.info(f"Deleting file: {destination_file}")
            os.remove(destination_file)
        else:
            logging.info(f"Re-using file from: {destination_file}")
            return str(destination_file)
    # download file
    wget_uri = url + filename
    logging.info(f"Downloading from: {wget_uri} to {str(destination_file)}")
    # NGC links do not work every time, so we retry and wait
    i = 0
    max_attempts = 3
    while i < max_attempts:
        i += 1
        try:
            wget.download(wget_uri, str(destination_file))
            if os.path.exists(destination_file):
                return str(destination_file)
            return ""
        except Exception as e:
            logging.info(
                f"Download from cloud failed. Attempt {i} of {max_attempts}")
            logging.info(f"Error: {e}")
            sleep(0.05)
            continue
    raise ValueError("Not able to download url right now, please try again.")
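A usage sketch; the URL and file name are placeholders, and with cache_dir=None the file lands under $HOME/.cache/torch/mridc:

# Download once; subsequent calls re-use the cached copy unless refresh_cache=True.
path = maybe_download_from_cloud(
    url="https://example.com/models/",
    filename="model.mridc",
    subfolder="demo",
    refresh_cache=False,
)
if path:
    print(f"Model available at {path}")
else:
    print("Download finished but the file could not be found on disk.")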
Example No. 9
    def from_pretrained(
        cls,
        model_name: str,
        refresh_cache: bool = False,
        override_config_path: Optional[str] = None,
        map_location: Optional[torch.device] = None,
        strict: bool = True,
        return_config: bool = False,
        trainer: Optional[Trainer] = None,
        save_restore_connector: SaveRestoreConnector = None,
    ):
        """
        Instantiates a model instance from a pretrained cloud checkpoint. Use restore_from() to instantiate from a local .mridc file.

        Parameters
        ----------
        model_name: String key which will be used to find the module.
        refresh_cache: If set to True, then when fetching from cloud, this will re-fetch the file from cloud even if it
         is already found in a cache locally.
        override_config_path: Path to a yaml config that will override the internal config file.
        map_location: Optional torch.device() to map the instantiated model to a device. By default (None), it will
        select a GPU if available, falling back to CPU otherwise.
        strict: Passed to load_state_dict. By default, True.
        return_config: If set to true, will return just the underlying config of the restored model as an
        OmegaConf/DictConfig object without instantiating the model.
        trainer: Optional Trainer objects to use for restoring the model.
        save_restore_connector: Optional SaveRestoreConnector object to use for restoring the model.

        Returns
        -------
        A model instance of a particular model class or its underlying config (if return_config is set).
        """
        if save_restore_connector is None:
            save_restore_connector = SaveRestoreConnector()

        location_in_the_cloud = None
        description = None
        models = cls.list_available_models()
        if models is not None:
            for pretrained_model_info in models:
                found = False
                if pretrained_model_info.pretrained_model_name == model_name:
                    found = True
                elif pretrained_model_info.aliases is not None:
                    for alias in pretrained_model_info.aliases:
                        if alias == model_name:
                            found = True
                            break
                if found:
                    location_in_the_cloud = pretrained_model_info.location
                    description = pretrained_model_info.description
                    class_ = pretrained_model_info.class_
                    break

        if location_in_the_cloud is None:
            raise FileNotFoundError(
                f"Model {model_name} was not found. "
                "Check cls.list_available_models() for the list of all available models."
            )
        filename = location_in_the_cloud.split("/")[-1]
        url = location_in_the_cloud.replace(filename, "")
        cache_dir = Path.joinpath(mridc.utils.model_utils.resolve_cache_dir(), f"{filename[:-5]}")  # type: ignore
        # If either description and location in the cloud changes, this will force re-download
        # of the model.
        cache_subfolder = hashlib.sha512((location_in_the_cloud + description).encode("utf-8")).hexdigest()
        # if file exists on cache_folder/subfolder, it will be re-used, unless refresh_cache is True
        mridc_model_file_in_cache = maybe_download_from_cloud(
            url=url, filename=filename, cache_dir=cache_dir, subfolder=cache_subfolder, refresh_cache=refresh_cache
        )
        logging.info("Instantiating model from pre-trained checkpoint")
        if class_ is None:
            class_ = cls

        return class_.restore_from(
            restore_path=mridc_model_file_in_cache,
            override_config_path=override_config_path,
            map_location=map_location,
            strict=strict,
            return_config=return_config,
            trainer=trainer,
            save_restore_connector=save_restore_connector,
        )
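A usage sketch of this classmethod; the model class and pretrained model name are placeholders, and the real keys come from cls.list_available_models():

import mridc

# List the pretrained checkpoints registered for a model class ...
available = mridc.collections.asr.models.EncDecCTCModel.list_available_models()

# ... then fetch one by name; the .mridc file is cached locally and restored.
model = mridc.collections.asr.models.EncDecCTCModel.from_pretrained(
    model_name="some_pretrained_model",  # placeholder key
    map_location=None,                   # auto-select GPU if available, else CPU
    strict=True,
)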
Example No. 10
def prepare_lr_scheduler(
    optimizer: optim.Optimizer,
    scheduler_config: Union[Dict[str, Any], DictConfig, None],
    train_dataloader: Optional[dataloader.DataLoader] = None,
) -> Optional[Dict[str, Any]]:
    """
    Constructs an LR Scheduler (optionally) for a given optimizer, based on a config with the following schema.

    Parameters
    ----------
    optimizer: The optimizer to use for the scheduler.
        name: <name of optimizer>

        lr: <maximal learning rate>

        # <additional optimizer arguments>

        args:

            name: auto  # special keyword, resolves to correct optimizer config for given optimizer name

            # cls: mridc.core.config.optimizers.NovogradParams  # explicit instantiation by class path

            params:  # optional override parameters for the optimizer config

                betas: [0.8, 0.5]

                weight_decay: 0.001

    scheduler_config: The scheduler config.

        name: <name of scheduler>

        iters_per_batch: null # computed at runtime; mandatory to have

        max_steps: null # computed at runtime or explicitly set here; mandatory to have

        # pytorch lightning args <mandatory>

        monitor: val_loss

        reduce_on_plateau: false

        # <scheduler config override>

        args:

            name: auto  # special keyword, resolves to correct optimizer config for given optimizer name

            # cls: mridc.core.config.schedulers.CosineAnnealingParams  # explicit instantiation by class path

            params:  # optional override parameters for the optimizer config

                warmup_steps: null

                warmup_ratio: null

                min_lr: 0.0

                last_epoch: -1

    train_dataloader: Optional requirement, must be passed if "iters_per_batch" is defined instead of "max_steps". \
    Used to compute effective "max_steps".

    Returns
    -------
    A dictionary containing the LR Scheduler implementation if the config was successfully parsed along with other \
    parameters required by Pytorch Lightning, otherwise None.
    """
    if scheduler_config is not None:
        scheduler_config = maybe_update_config_version(scheduler_config)

    # Build nested dictionary for convenience out of structured objects
    if isinstance(scheduler_config, DictConfig):
        scheduler_config = OmegaConf.to_container(scheduler_config,
                                                  resolve=True)

    elif dataclasses.is_dataclass(scheduler_config):
        # Recursively transform data classes to basic dictionaries
        scheduler_config = OmegaConf.create(scheduler_config)
        scheduler_config = OmegaConf.to_container(scheduler_config,
                                                  resolve=True)

    # Test to see if config follows above schema

    add_max_args_flag = True
    interval = "step"
    if scheduler_config is not None:
        if "args" in scheduler_config:
            scheduler_args = scheduler_config.pop("args")
        else:
            scheduler_args = copy.deepcopy(scheduler_config)

            # Remove extra parameters from scheduler_args nest
            # Assume all other parameters are to be passed into scheduler constructor

            if "name" in scheduler_args and scheduler_args[
                    "name"] == "ReduceLROnPlateau":
                add_max_args_flag = False
                interval = "epoch"

            scheduler_args.pop("name", None)
            scheduler_args.pop("t_max_epochs", None)
            scheduler_args.pop("t_accumulate_grad_batches", None)
            scheduler_args.pop("t_limit_train_batches", None)
            scheduler_args.pop("t_num_workers", None)
            scheduler_args.pop("monitor", None)
            scheduler_args.pop("reduce_on_plateau", None)

    else:
        # Return gracefully in case `sched` was not supplied; inform user
        logging.info(
            "Scheduler not initialized as no `sched` config supplied to setup_optimizer()"
        )
        return None

    # Try instantiation of scheduler params from config class path
    if "_target_" in scheduler_args:
        scheduler_args_cfg = OmegaConf.create(scheduler_args)
        scheduler_conf = hydra.utils.instantiate(scheduler_args_cfg)
        scheduler_args = vars(scheduler_conf)

        # Get name of the scheduler
        scheduler_name = scheduler_conf.__class__.__name__

        if "Params" in scheduler_name:
            scheduler_name = scheduler_name.replace("Params", "")

    else:
        # Class path instantiation failed; try resolving "name" component

        # Get name of the scheduler
        if "name" in scheduler_config:
            scheduler_name = scheduler_config["name"]
        else:
            logging.warning(
                "Could not resolve classpath for Scheduler Config, and `name` "
                "was not provided either. \n"
                "Scheduler cannot be instantiated !")
            return None

        # If class path was not provided, perhaps `name` is provided for resolution
        if "name" in scheduler_args:
            # If `auto` is passed as name for resolution of optimizer name,
            # then lookup optimizer name and resolve its parameter config
            if scheduler_args["name"] == "auto":
                scheduler_params_name = f"{scheduler_name}Params"
            else:
                scheduler_params_name = scheduler_args["name"]

            # Get override arguments provided in the config yaml file / Dict Config
            scheduler_params_override = scheduler_args.get("params", {})

            # If params is itself a dict config object provided explicitly in Dict Config
            # Resolve to dictionary for convenience
            if isinstance(scheduler_params_override, DictConfig):
                scheduler_params_override = OmegaConf.to_container(
                    scheduler_params_override, resolve=True)

            # Get and instantiate the Config dataclass for this scheduler
            scheduler_params_cls = get_scheduler_config(
                scheduler_params_name, **scheduler_params_override)
            scheduler_params = scheduler_params_cls  # instantiate the parameters object
            scheduler_args = vars(
                scheduler_params
            )  # extract just the dictionary from the Config object

    # Extract value to monitor in losses, if provided.
    if "monitor" in scheduler_config:
        monitor = scheduler_config.get("monitor")
    else:
        # Default to train loss
        monitor = "loss"

    # Store exact max_steps if it is provided
    if "max_steps" in scheduler_config and scheduler_config[
            "max_steps"] is not None:
        max_steps = scheduler_config["max_steps"]

    elif "t_max_epochs" in scheduler_config:
        # Compute effective max_steps if t_max_epochs is provided
        if train_dataloader is None:
            logging.warning(
                "As `t_max_epochs` is provided/computed, it is required to pass the train dataloader in order\n"
                "to compute effective maximum number of steps.\n"
                "Scheduler will not be instantiated !")
            return None

        # Warn and return gracefully if `t_max_epochs` was not actually provided
        if scheduler_config.get("t_max_epochs", None) is None:
            logging.warning(
                "`t_max_epochs` cannot be None when `max_steps` is not provided.\n"
                "This can occur when `train dataloader` is not available to correctly "
                "prepare the scheduler.\n"
                "Scheduler will not be instantiated !")
            return None

        # Get iters_per_batch
        max_epochs = scheduler_config.get("t_max_epochs")
        accumulate_grad_batches = scheduler_config.get(
            "t_accumulate_grad_batches")
        limit_train_batches = scheduler_config.get("t_limit_train_batches")
        num_workers = scheduler_config.get("t_num_workers")

        # Compute effective num max_steps
        num_samples = len(train_dataloader.dataset)  # type: ignore

        # we may need to override ModelPT setup_optimization
        if train_dataloader.batch_size is not None:
            batch_size = train_dataloader.batch_size
        elif hasattr(train_dataloader, "batch_sampler"
                     ) and train_dataloader.batch_sampler is not None:
            if train_dataloader.batch_sampler.micro_batch_size is not None:
                batch_size = train_dataloader.batch_sampler.micro_batch_size
            else:
                raise ValueError(
                    f"Could not find batch_size from batch_sampler: {train_dataloader.batch_sampler}"
                )
        else:
            raise ValueError(
                f"Could not find batch_size from train_dataloader: {train_dataloader}"
            )
        drop_last = train_dataloader.drop_last

        max_steps = compute_max_steps(
            max_epochs=max_epochs,
            accumulate_grad_batches=accumulate_grad_batches,
            limit_train_batches=limit_train_batches,
            num_workers=num_workers,
            num_samples=num_samples,
            batch_size=batch_size,
            drop_last=drop_last,
        )

    else:
        logging.warning(
            "Neither `max_steps` nor `iters_per_batch` were provided to `optim.sched`, "
            "cannot compute effective `max_steps` !\n"
            "Scheduler will not be instantiated !")
        return None

    # Inject max_steps (effective or provided) into the scheduler config
    if add_max_args_flag and scheduler_config.get("name",
                                                  "") != "ExponentialLR":
        scheduler_args["max_steps"] = max_steps

    # Get the scheduler class from the config
    scheduler_cls = get_scheduler(scheduler_name, **scheduler_args)

    # Instantiate the LR schedule
    schedule = scheduler_cls(optimizer, **scheduler_args)

    logging.info(
        'Scheduler "%s" \nwill be used during training (effective maximum steps = %d) - \nParameters : \n(%s)',
        str(schedule),
        max_steps,
        OmegaConf.to_yaml(OmegaConf.create(scheduler_args)),
    )

    # Wrap the schedule in PTL arguments to perform stepwise computation
    # Rather than epoch level computation
    reduce_lr_on_plateau = isinstance(schedule,
                                      optim.lr_scheduler.ReduceLROnPlateau)

    return {
        "scheduler": schedule,
        "interval": interval,
        "frequency": 1,
        "monitor": monitor,
        "reduce_on_plateau": reduce_lr_on_plateau,
    }
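For illustration, a scheduler config matching the schema above can be passed in as a plain dict; a minimal sketch (the scheduler name and values are only an example, and max_steps is set explicitly so no train_dataloader is needed):

import torch

optimizer = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)

sched_config = {
    "name": "CosineAnnealing",  # example scheduler name
    "max_steps": 1000,          # explicit, so iters_per_batch/t_max_epochs are not required
    "monitor": "val_loss",
    "min_lr": 0.0,
    "last_epoch": -1,
}

sched_dict = prepare_lr_scheduler(optimizer, sched_config)
# Expected shape of the result: {"scheduler": <LR scheduler>, "interval": "step",
#                                "frequency": 1, "monitor": "val_loss", "reduce_on_plateau": False}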
Example No. 11
    def load_config_and_state_dict(
        self,
        calling_cls,
        restore_path: str,
        override_config_path: Optional[Union[OmegaConf, str]] = None,
        map_location: Optional[torch.device] = None,
        strict: bool = True,
        return_config: bool = False,
        trainer: Trainer = None,
    ):
        """
        Restores a model instance (weights and configuration) from a .mridc file.

        Parameters
        ----------
        calling_cls: Class of the model to be restored.
        restore_path: path to .mridc file from which model should be instantiated
        override_config_path: path to a yaml config that will override the internal config file or an
        OmegaConf/DictConfig object representing the model config.
        map_location: Optional torch.device() to map the instantiated model to a device. By default (None), it will
        select a GPU if available, falling back to CPU otherwise.
        strict: Passed to load_state_dict. By default, True.
        return_config: If set to true, will return just the underlying config of the restored model as an OmegaConf
        DictConfig object without instantiating the model.
        trainer: Optional trainer object to be used for model parallelism.

        Example
        -------
            ```
            model = mridc.collections.asr.models.EncDecCTCModel.restore_from('asr.mridc')
            assert isinstance(model, mridc.collections.asr.models.EncDecCTCModel)
            ```

        Returns
        -------
        An instance of type cls or its underlying config (if return_config is set).
        """
        # Get path where the command is executed - the artifacts will be "retrieved" there
        # (original .mridc behavior)
        cwd = os.getcwd()

        if map_location is None:
            if torch.cuda.is_available():
                map_location = torch.device("cuda")
            else:
                map_location = torch.device("cpu")

        app_state = AppState()
        with tempfile.TemporaryDirectory() as tmpdir:
            try:
                # Check if self.model_extracted_dir is set, and is a valid path
                if self.model_extracted_dir is not None and os.path.isdir(
                        self.model_extracted_dir):
                    # Log that MRIDC will use the provided `model_extracted_dir`
                    logging.info(
                        "Restoration will occur within pre-extracted directory : "
                        f"`{self.model_extracted_dir}`.")
                    # Override `tmpdir` above with the pre-extracted `model_extracted_dir`
                    tmpdir = self.model_extracted_dir
                else:
                    # Extract the .mridc file into the temporary directory
                    self._unpack_mridc_file(path2file=restore_path,
                                            out_folder=tmpdir)

                # Change current working directory to the temporary directory
                os.chdir(tmpdir)
                if override_config_path is None:
                    config_yaml = os.path.join(tmpdir, self.model_config_yaml)
                else:
                    # can be str path or OmegaConf / DictConfig object
                    config_yaml = override_config_path
                if not isinstance(config_yaml, (OmegaConf, DictConfig)):
                    conf = OmegaConf.load(config_yaml)
                else:
                    conf = config_yaml
                    if override_config_path is not None:
                        # Resolve the override config
                        conf = OmegaConf.to_container(conf, resolve=True)
                        conf = OmegaConf.create(conf)
                # If override is top level config, extract just `model` from it
                if "model" in conf:
                    conf = conf.model

                if return_config:
                    instance = conf
                    return instance
                if app_state.model_parallel_rank is not None and app_state.model_parallel_size > 1:
                    model_weights = self._inject_model_parallel_rank_for_ckpt(
                        tmpdir, self.model_weights_ckpt)
                else:
                    model_weights = os.path.join(tmpdir,
                                                 self.model_weights_ckpt)
                OmegaConf.set_struct(conf, True)
                os.chdir(cwd)
                # get the class
                calling_cls._set_model_restore_state(
                    is_being_restored=True, folder=tmpdir)  # type: ignore
                instance = calling_cls.from_config_dict(config=conf,
                                                        trainer=trainer)
                instance = instance.to(map_location)
                # add load_state_dict override
                if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
                    model_weights = self._inject_model_parallel_rank_for_ckpt(
                        tmpdir, self.model_weights_ckpt)
                instance.load_state_dict(self._load_state_dict_from_disk(
                    model_weights, map_location=map_location),
                                         strict=strict)
                logging.info(
                    f"Model {instance.__class__.__name__} was successfully restored from {restore_path}."
                )
                instance._set_model_restore_state(
                    is_being_restored=False)  # type: ignore
            finally:
                os.chdir(cwd)

        return instance
Example No. 12
    def extract_state_dict_from(self,
                                restore_path: str,
                                save_dir: str,
                                split_by_module: bool = False):
        """
        Extract the state dict(s) from a provided .mridc tarfile and save it to a directory.

        Parameters
        ----------
        restore_path: path to .mridc file from which state dict(s) should be extracted
        save_dir: directory in which the saved state dict(s) should be stored
        split_by_module: bool flag, which determines whether the output checkpoint should be for the entire Model, or
        the individual module's that comprise the Model.

        Example
        -------
        To convert the .mridc tarfile into a single Model-level PyTorch checkpoint
        ::
            state_dict = mridc.collections.asr.models.EncDecCTCModel.extract_state_dict_from(
                'asr.mridc', './asr_ckpts')

        To restore a model from a Model-level checkpoint
        ::
            model = mridc.collections.asr.models.EncDecCTCModel(cfg)  # or any other method of restoration
            model.load_state_dict(torch.load("./asr_ckpts/model_weights.ckpt"))

        To convert the .mridc tarfile into multiple Module-level PyTorch checkpoints
        ::
            state_dict = mridc.collections.asr.models.EncDecCTCModel.extract_state_dict_from(
                'asr.mridc', './asr_ckpts', split_by_module=True)

        To restore a module from a Module-level checkpoint
        ::
            model = mridc.collections.asr.models.EncDecCTCModel(cfg)  # or any other method of restoration
            # load the individual components
            model.preprocessor.load_state_dict(torch.load("./asr_ckpts/preprocessor.ckpt"))
            model.encoder.load_state_dict(torch.load("./asr_ckpts/encoder.ckpt"))
            model.decoder.load_state_dict(torch.load("./asr_ckpts/decoder.ckpt"))

        Returns
        -------
        The state dict that was loaded from the original .mridc checkpoint.
        """
        cwd = os.getcwd()

        save_dir = os.path.abspath(save_dir)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir, exist_ok=True)

        with tempfile.TemporaryDirectory() as tmpdir:
            try:
                self._unpack_mridc_file(path2file=restore_path,
                                        out_folder=tmpdir)
                os.chdir(tmpdir)
                model_weights = os.path.join(tmpdir, self.model_weights_ckpt)
                state_dict = self._load_state_dict_from_disk(model_weights)

                if not split_by_module:
                    filepath = os.path.join(save_dir, self.model_weights_ckpt)
                    self._save_state_dict_to_disk(state_dict, filepath)

                else:
                    key_set = {key.split(".")[0] for key in state_dict.keys()}
                    for primary_key in key_set:
                        inner_keys = [
                            key for key in state_dict.keys()
                            if key.split(".")[0] == primary_key
                        ]
                        state_dict_subset = {
                            ".".join(inner_key.split(".")[1:]):
                            state_dict[inner_key]
                            for inner_key in inner_keys
                        }
                        filepath = os.path.join(save_dir,
                                                f"{primary_key}.ckpt")
                        self._save_state_dict_to_disk(state_dict_subset,
                                                      filepath)

                logging.info(
                    f"Checkpoints from {restore_path} were successfully extracted into {save_dir}."
                )
            finally:
                os.chdir(cwd)

        return state_dict
Example No. 13
def configure_loggers(
    trainer: Trainer,
    exp_dir: List[Union[Path, str]],
    name: str,
    version: str,
    create_tensorboard_logger: bool,
    summary_writer_kwargs: dict,
    create_wandb_logger: bool,
    wandb_kwargs: dict,
):
    """
    Creates a TensorBoardLogger and/or WandbLogger and attaches them to the trainer. Raises ValueError if
    summary_writer_kwargs or wandb_kwargs are misconfigured.

    Parameters
    ----------
    trainer: The trainer to attach the loggers to.
    exp_dir: The experiment directory.
    name: The name of the experiment.
    version: The version of the experiment.
    create_tensorboard_logger: Whether to create a TensorboardLogger.
    summary_writer_kwargs: The kwargs to pass to the TensorboardLogger.
    create_wandb_logger: Whether to create a Weights & Biases logger.
    wandb_kwargs: The kwargs to pass to the Weights & Biases logger.

    Returns
    -------
    LoggerList: A list of loggers.
    """
    # Potentially create tensorboard logger and/or WandBLogger
    logger_list = []
    if create_tensorboard_logger:
        if summary_writer_kwargs is None:
            summary_writer_kwargs = {}
        elif "log_dir" in summary_writer_kwargs:
            raise ValueError(
                "You cannot pass `log_dir` as part of `summary_writer_kwargs`. `log_dir` is handled by lightning's "
                "TensorBoardLogger logger.")
        tensorboard_logger = TensorBoardLogger(save_dir=exp_dir[0],
                                               name=name,
                                               version=version,
                                               **summary_writer_kwargs)
        logger_list.append(tensorboard_logger)
        logging.info("TensorboardLogger has been set up")

    if create_wandb_logger:
        if wandb_kwargs is None:
            wandb_kwargs = {}
        if "name" not in wandb_kwargs and "project" not in wandb_kwargs:
            raise ValueError("name and project are required for wandb_logger")
        wandb_logger = WandbLogger(save_dir=exp_dir[0],
                                   version=version,
                                   **wandb_kwargs)

        logger_list.append(wandb_logger)
        logging.info("WandBLogger has been set up")

    logger_list = (LoggerList(
        logger_list, mridc_name=name, mridc_version=version)
                   if len(logger_list) > 1 else logger_list[0])
    trainer._logger_connector.configure_logger(logger_list)
Example No. 14
def check_resume(
    trainer: Trainer,
    log_dir: str,
    resume_past_end: bool = False,
    resume_ignore_no_checkpoint: bool = False,
):
    """
    Checks that resume=True was used correctly with the arguments passed to exp_manager. Sets
    trainer._checkpoint_connector.resume_from_checkpoint_fit_path as necessary.

    Parameters
    ----------
    trainer: The trainer that is being used.
    log_dir: The directory where the logs are being saved.
    resume_past_end: Whether to resume from the end of the experiment.
    resume_ignore_no_checkpoint: Whether to ignore if there is no checkpoint to resume from.

    Raises
    ------
    NotFoundError: If resume is True, resume_ignore_no_checkpoint is False, and no checkpoint could be found.
    ValueError: If resume is True and more than one matching checkpoint was found.
    """
    if not log_dir:
        raise ValueError(
            f"Resuming requires the log_dir {log_dir} to be passed to exp_manager"
        )

    checkpoint_dir = Path(Path(log_dir) / "checkpoints")

    checkpoint = None
    end_checkpoints = list(checkpoint_dir.rglob("*end.ckpt"))
    last_checkpoints = list(checkpoint_dir.rglob("*last.ckpt"))
    if not checkpoint_dir.exists():
        if not resume_ignore_no_checkpoint:
            raise NotFoundError(
                f"There was no checkpoint folder at checkpoint_dir :{checkpoint_dir}. Cannot resume."
            )
        logging.warning(
            f"There was no checkpoint folder at checkpoint_dir :{checkpoint_dir}. Training from scratch."
        )
        return
    if end_checkpoints:
        if not resume_past_end:
            raise ValueError(
                f"Found {end_checkpoints[0]} indicating that the last training run has already completed."
            )
        if len(end_checkpoints) > 1:
            if "mp_rank" in str(end_checkpoints[0]):
                checkpoint = end_checkpoints[0]
            else:
                raise ValueError(
                    f"Multiple checkpoints {end_checkpoints} that matches *end.ckpt."
                )
        logging.info(f"Resuming from {end_checkpoints[0]}")
    elif not last_checkpoints:
        if not resume_ignore_no_checkpoint:
            raise NotFoundError(
                f"There were no checkpoints found in {checkpoint_dir}. Cannot resume."
            )
        logging.warning(
            f"There were no checkpoints found in {checkpoint_dir}. Training from scratch."
        )
        return
    elif len(last_checkpoints) > 1:
        if "mp_rank" not in str(last_checkpoints[0]) and "tp_rank" not in str(
                last_checkpoints[0]):
            raise ValueError(
                f"Multiple checkpoints {last_checkpoints} that matches *last.ckpt."
            )
        checkpoint = last_checkpoints[0]
        checkpoint = mridc.utils.model_utils.uninject_model_parallel_rank(
            checkpoint)  # type: ignore
    else:
        logging.info(f"Resuming from {last_checkpoints[0]}")
        checkpoint = last_checkpoints[0]

    trainer._checkpoint_connector.resume_from_checkpoint_fit_path = str(
        checkpoint)

    if is_global_rank_zero():
        if files_to_move := [
                child for child in Path(log_dir).iterdir() if child.is_file()
        ]:
            # Move old files to a new folder
            other_run_dirs = Path(log_dir).glob("run_*")
            run_count = sum(bool(fold.is_dir()) for fold in other_run_dirs)
            new_run_dir = Path(Path(log_dir) / f"run_{run_count}")
            new_run_dir.mkdir()
            for _file in files_to_move:
                move(str(_file), str(new_run_dir))
Example No. 15
def exp_manager(
        trainer: Trainer,
        cfg: Optional[Union[DictConfig, Dict]] = None) -> Optional[Path]:
    """
    exp_manager is a helper function used to manage folders for experiments. It follows the pytorch lightning \
    paradigm of exp_dir/model_or_experiment_name/version. If the lightning trainer has a logger, exp_manager will \
    get exp_dir, name, and version from the logger. Otherwise, it will use the exp_dir and name arguments to create \
    the logging directory. exp_manager also allows for explicit folder creation via explicit_log_dir.

    The version can be a datetime string or an integer. The datetime version can be disabled if use_datetime_version \
    is set to False. It optionally creates TensorBoardLogger, WandBLogger, ModelCheckpoint objects from pytorch \
    lightning. It copies sys.argv, and git information if available to the logging directory. It creates a log file \
    for each process to log their output into.

    exp_manager additionally has a resume feature (resume_if_exists) which can be used to continue training from \
    the constructed log_dir. When you need to continue training repeatedly (for example on a cluster where you submit \
    multiple consecutive jobs), you need to avoid creating version folders. Therefore, from v1.0.0, when \
    resume_if_exists is set to True, the version folders are not created.

    Parameters
    ----------
    trainer: The lightning trainer object.
    cfg: Can have the following keys:
        - explicit_log_dir: Can be used to override exp_dir/name/version folder creation. Defaults to None, which \
        will use exp_dir, name, and version to construct the logging directory.
        - exp_dir: The base directory to create the logging directory. Defaults to None, which logs to \
         ./mridc_experiments.
        - name: The name of the experiment. Defaults to None which turns into "default" via name = name or "default".
        - version: The version of the experiment. Defaults to None which uses either a datetime string or lightning's \
         TensorboardLogger system of using version_{int}.
        - use_datetime_version: Whether to use a datetime string for version. Defaults to True.
        - resume_if_exists: Whether this experiment is resuming from a previous run. If True, it sets \
        trainer._checkpoint_connector.resume_from_checkpoint_fit_path so that the trainer should auto-resume. \
        exp_manager will move files under log_dir to log_dir/run_{int}. Defaults to False. From v1.0.0, when \
        resume_if_exists is True, we would not create version folders to make it easier to find the log folder for \
        next runs.
        - resume_past_end: exp_manager errors out if resume_if_exists is True and a checkpoint matching \*end.ckpt is \
        found, indicating that a previous training run fully completed. This behaviour can be disabled by setting \
        resume_past_end to True, in which case the \*end.ckpt will be loaded. Defaults to False.
        - resume_ignore_no_checkpoint: exp_manager errors out if resume_if_exists is True and no checkpoint could be \
         found. This behaviour can be disabled, in which case exp_manager will print a message and continue without \
         restoring, by setting resume_ignore_no_checkpoint to True. Defaults to False.
        - create_tensorboard_logger: Whether to create a tensorboard logger and attach it to the pytorch lightning \
        trainer. Defaults to True.
        - summary_writer_kwargs: A dictionary of kwargs that can be passed to lightning's TensorboardLogger class. \
        Note that log_dir is passed by exp_manager and cannot exist in this dict. Defaults to None.
        - create_wandb_logger: Whether to create a Weights and Biases logger and attach it to the pytorch lightning \
        trainer. Defaults to False.
        - wandb_logger_kwargs: A dictionary of kwargs that can be passed to lightning's WandBLogger class. Note that \
         name and project are required parameters if create_wandb_logger is True. Defaults to None.
        - create_checkpoint_callback: Whether to create a ModelCheckpoint callback and attach it to the pytorch \
        lightning trainer. The ModelCheckpoint saves the top 3 models with the best "val_loss", the most recent \
        checkpoint under \*last.ckpt, and the final checkpoint after training completes under \*end.ckpt. \
        Defaults to True.
        - files_to_copy: A list of files to copy to the experiment logging directory. Defaults to None which copies \
        no files.
        - log_local_rank_0_only: Whether to only create log files for local rank 0. Defaults to False. Set this to \
        True if you are using DDP with many GPUs and do not want many log files in your exp dir.
        - log_global_rank_0_only: Whether to only create log files for global rank 0. Defaults to False. Set this to \
        True if you are using DDP with many GPUs and do not want many log files in your exp dir.

    Returns
    -------
    The final logging directory where logging files are saved. Usually the concatenation of exp_dir, name, and version.
    """
    # Add rank information to logger
    # Note: trainer.global_rank and trainer.is_global_zero are not set until trainer.fit, so have to hack around it
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    global_rank = trainer.node_rank * trainer.num_devices + local_rank
    logging.rank = global_rank

    if cfg is None:
        logging.error(
            "exp_manager did not receive a cfg argument. It will be disabled.")
        return None

    if trainer.fast_dev_run:
        logging.info(
            "Trainer was called with fast_dev_run. exp_manager will return without any functionality."
        )
        return None

    # Ensure passed cfg is compliant with ExpManagerConfig
    schema = OmegaConf.structured(ExpManagerConfig)
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg)
    elif not isinstance(cfg, DictConfig):
        raise ValueError(
            f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig"
        )
    cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True))
    cfg = OmegaConf.merge(schema, cfg)

    error_checks(
        trainer, cfg
    )  # Ensures that trainer options are compliant with MRIDC and exp_manager arguments

    log_dir, exp_dir, name, version = get_log_dir(
        trainer=trainer,
        exp_dir=cfg.exp_dir,
        name=cfg.name,
        version=cfg.version,
        explicit_log_dir=cfg.explicit_log_dir,
        use_datetime_version=cfg.use_datetime_version,
        resume_if_exists=cfg.resume_if_exists,
    )

    if cfg.resume_if_exists:
        check_resume(trainer, str(log_dir), cfg.resume_past_end,
                     cfg.resume_ignore_no_checkpoint)

    checkpoint_name = name
    # If name returned from get_log_dir is "", use cfg.name for checkpointing
    if checkpoint_name is None or checkpoint_name == "":
        checkpoint_name = cfg.name or "default"
    cfg.name = name  # Used for configure_loggers so that the log_dir is properly set even if name is ""
    cfg.version = version

    # update app_state with log_dir, exp_dir, etc
    app_state = AppState()
    app_state.log_dir = log_dir
    app_state.exp_dir = exp_dir
    app_state.name = name
    app_state.version = version
    app_state.checkpoint_name = checkpoint_name
    app_state.create_checkpoint_callback = cfg.create_checkpoint_callback
    app_state.checkpoint_callback_params = cfg.checkpoint_callback_params

    # Create the logging directory if it does not exist
    os.makedirs(
        log_dir, exist_ok=True
    )  # Cannot limit creation to global zero as all ranks write to own log file
    logging.info(f"Experiments will be logged at {log_dir}")
    trainer._default_root_dir = log_dir

    if cfg.log_local_rank_0_only is True and cfg.log_global_rank_0_only is True:
        raise ValueError(
            "Cannot set both log_local_rank_0_only and log_global_rank_0_only to True."
            "Please set either one or neither.")

    # This is set if the env var MRIDC_TESTING is set to True.
    mridc_testing = get_envbool(MRIDC_ENV_VARNAME_TESTING, False)

    log_file = log_dir / f"mridc_log_globalrank-{global_rank}_localrank-{local_rank}.txt"

    # Handle logging to file. Logs local rank 0 only
    if local_rank == 0 and cfg.log_local_rank_0_only and not mridc_testing:
        logging.add_file_handler(log_file)
    elif global_rank == 0 and cfg.log_global_rank_0_only and mridc_testing:
        logging.add_file_handler(log_file)
    else:
        logging.add_file_handler(log_file)

    # For some reason, LearningRateLogger requires trainer to have a logger. Safer to create logger on all ranks
    # not just global rank 0.
    if cfg.create_tensorboard_logger or cfg.create_wandb_logger:
        configure_loggers(
            trainer,
            [Path(exp_dir)],
            cfg.name,
            cfg.version,
            cfg.create_tensorboard_logger,
            cfg.summary_writer_kwargs,
            cfg.create_wandb_logger,
            cfg.wandb_logger_kwargs,
        )

    # add loggers timing callbacks
    if cfg.log_step_timing:
        timing_callback = TimingCallback(
            timer_kwargs=cfg.step_timing_kwargs or {})
        trainer.callbacks.insert(0, timing_callback)

    if cfg.create_checkpoint_callback:
        configure_checkpointing(trainer, log_dir, checkpoint_name,
                                cfg.resume_if_exists,
                                cfg.checkpoint_callback_params)

    if is_global_rank_zero():
        # Copy files_to_copy into the log folder and record git information if present
        if cfg.files_to_copy:
            for _file in cfg.files_to_copy:
                copy(Path(_file), log_dir)

        # Create files for cmd args and git info
        with open(log_dir / "cmd-args.log", "w", encoding="utf-8") as _file:
            _file.write(" ".join(sys.argv))

        # Try to get git hash
        git_repo, git_hash = get_git_hash()
        if git_repo:
            with open(log_dir / "git-info.log", "w",
                      encoding="utf-8") as _file:
                _file.write(f"commit hash: {git_hash}\n")
                _file.write(get_git_diff())

        # Add err_file logging to global_rank zero
        logging.add_err_file_handler(log_dir / "mridc_error_log.txt")

        # Add lightning file logging to global_rank zero
        add_filehandlers_to_pl_logger(log_dir / "lightning_logs.txt",
                                      log_dir / "mridc_error_log.txt")

    return log_dir
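A minimal usage sketch of the function above, assumed here to be mridc's exp_manager (the function name and exact signature are not shown in this excerpt, so treat both as assumptions). Only config keys that the body above actually reads are set explicitly; everything else falls back to the ExpManagerConfig schema defaults.

    import pytorch_lightning as pl

    # Hypothetical call; the name "exp_manager" and the (trainer, cfg) argument order are assumptions.
    trainer = pl.Trainer(max_epochs=10)
    exp_cfg = {
        "name": "my_experiment",
        "exp_dir": "./experiments",
        "create_tensorboard_logger": True,
        "create_checkpoint_callback": True,
        "resume_if_exists": False,
    }
    log_dir = exp_manager(trainer, exp_cfg)  # returns the resolved logging directory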
Example No. 16
0
    def export(
        self,
        output: str,
        input_example=None,
        verbose=False,
        export_params=True,
        do_constant_folding=True,
        onnx_opset_version=None,
        try_script: bool = False,
        training=TrainingMode.EVAL,
        check_trace: bool = False,
        use_dynamic_axes: bool = True,
        dynamic_axes=None,
        check_tolerance=0.01,
    ):
        """
        Export the module to a file.

        Parameters
        ----------
        output: The output file path.
        input_example: A dictionary of input names and values.
        verbose: If True, print out the export process.
        export_params: If True, export the parameters of the module.
        do_constant_folding: If True, do constant folding.
        onnx_opset_version: The ONNX opset version to use.
        try_script: If True, try to export as TorchScript.
        training: Training mode for the export.
        check_trace: If True, check the trace of the exported model.
        use_dynamic_axes: If True, use dynamic axes for the export.
        dynamic_axes: A dictionary of input names and dynamic axes.
        check_tolerance: The tolerance for the check_trace.
        """
        my_args = locals().copy()
        my_args.pop("self")
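        # Snapshot of the export arguments (minus self); forwarded later to each Exportable's _prepare_for_export.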

        exportables = []
        for m in self.modules():  # type: ignore
            if isinstance(m, Exportable):
                exportables.append(m)

        qual_name = self.__module__ + "." + self.__class__.__qualname__
        format = get_export_format(output)
        output_descr = f"{qual_name} exported to {format}"

        # PyTorch's default opset is too low and None cannot be passed through, so pick a sensible default here.
        if onnx_opset_version is None:
            onnx_opset_version = 13

        try:
            # Disable typechecks
            typecheck.set_typecheck_enabled(enabled=False)

            # Allow user to completely override forward method to export
            forward_method, old_forward_method = wrap_forward_method(self)

            # Set module mode
            with torch.onnx.select_model_mode_for_export(self, training), \
                    torch.inference_mode(), \
                    torch.jit.optimized_execution(True):

                if input_example is None:
                    input_example = self.input_module.input_example()

                # Remove i/o examples from args we propagate to enclosed Exportables
                my_args.pop("output")
                my_args.pop("input_example")

                # Run (possibly overridden) prepare methods before calling forward()
                for ex in exportables:
                    ex._prepare_for_export(**my_args, noreplace=True)
                self._prepare_for_export(output=output,
                                         input_example=input_example,
                                         **my_args)

                input_list, input_dict = parse_input_example(input_example)
                input_names = self.input_names
                output_names = self.output_names
                output_example = tuple(self.forward(
                    *input_list, **input_dict))  # type: ignore

                jitted_model = None
                if try_script:
                    try:
                        jitted_model = torch.jit.script(self)
                    except Exception as e:
                        logging.error(f"jit.script() failed!\n{e}")

                if format == ExportFormat.TORCHSCRIPT:
                    if jitted_model is None:
                        jitted_model = torch.jit.trace_module(
                            self,
                            {
                                "forward":
                                tuple(input_list) + tuple(input_dict.values())
                            },
                            strict=True,
                            check_trace=check_trace,
                            check_tolerance=check_tolerance,
                        )
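                    # When the module is in eval mode, apply TorchScript inference optimizations (freezing, fusion).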
                    if not self.training:  # type: ignore
                        jitted_model = torch.jit.optimize_for_inference(
                            jitted_model)
                    if verbose:
                        logging.info(f"JIT code:\n{jitted_model.code}")
                    jitted_model.save(output)
                elif format == ExportFormat.ONNX:
                    if jitted_model is None:
                        jitted_model = self

                    # dynamic_axes maps each input/output name to a list of "dynamic" dimension indices
                    if dynamic_axes is None and use_dynamic_axes:
                        dynamic_axes = get_dynamic_axes(
                            self.input_module.input_types, input_names)
                        dynamic_axes.update(
                            get_dynamic_axes(self.output_module.output_types,
                                             output_names))

                    torch.onnx.export(
                        jitted_model,
                        input_example,
                        output,
                        input_names=input_names,
                        output_names=output_names,
                        verbose=verbose,
                        export_params=export_params,
                        do_constant_folding=do_constant_folding,
                        dynamic_axes=dynamic_axes,
                        opset_version=onnx_opset_version,
                    )

                    if check_trace:
                        verify_runtime(output, input_list, input_dict,
                                       input_names, output_names,
                                       output_example)

                else:
                    raise ValueError(
                        f"Encountered unknown export format {format}.")
        finally:
            typecheck.set_typecheck_enabled(enabled=True)
            if forward_method:
                type(self).forward = old_forward_method  # type: ignore
            self._export_teardown()
        return [output], [output_descr]
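A hedged usage sketch for the export() method above. MyExportableModel is a hypothetical class that mixes in Exportable, and the extension-to-format mapping (".onnx" vs ".ts") is assumed to be what get_export_format() implements.

    # Hypothetical model; any instantiated module mixing in Exportable is exported the same way.
    model = MyExportableModel(cfg)
    model.eval()

    # ONNX export, with the format presumably inferred from the ".onnx" extension.
    outputs, descriptions = model.export(
        "model.onnx",
        check_trace=True,        # re-run the exported graph via verify_runtime against output_example
        use_dynamic_axes=True,   # derive dynamic axes from the input/output types
    )

    # TorchScript export; try_script=True attempts torch.jit.script before falling back to tracing.
    outputs, descriptions = model.export("model.ts", try_script=True)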