Example #1
    def prepare(self, model_as_str, checkpoint="none", pretrained_key="none"):
        """Prepare the model, e.g. instantiate a pre-trained classifier from torchvision."""
        # e.g. model = config["model"]
        if pretrained_key != "none":
            model = get_obj_from_str(model_as_str).from_pretrained(pretrained_key)
        else:
            model = get_obj_from_str(model_as_str)(self.model_config)
            if checkpoint != "none":
                assert isinstance(checkpoint, str), "please provide a path to a checkpoint"
                state = torch.load(checkpoint)["model"]
                model.load_state_dict(state)
                self.logger.info("Restored model from {}".format(checkpoint))
        return model
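
Every example on this page resolves dotted import paths with ``get_obj_from_str`` from ``edflow.util``. A minimal sketch of how such a helper can be implemented (the actual edflow implementation may differ in details):

import importlib

def get_obj_from_str(string):
    # split "package.module.Name" into module path and attribute name,
    # import the module and return the attribute
    module, name = string.rsplit(".", 1)
    return getattr(importlib.import_module(module), name)

# e.g. get_obj_from_str("collections.OrderedDict") returns the OrderedDict class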
Example #2
def setup_loaders(labels, meta_dict):
    """Creates a map of key -> function pairs, which can be used to postprocess
    label values at each ``__getitem__`` call.

    Loaders defined in :attr:`meta_dict` supersede those defined in the label
    keys.

    Parameters
    ----------
    labels : dict(str, numpy.memmap)
        Labels contain all easy-to-load, dataset-relevant data. If a key
        follows the pattern ``name:loader``, this function will try to find
        the corresponding loader in :attr:`DEFAULT_LOADERS`.
    meta_dict : dict
        A dictionary containing all dataset-relevant information, which is the
        same for all examples. This function will try to find the entry
        ``loaders`` in the dictionary, which must contain another ``dict`` with
        ``name:loader`` pairs. Here ``loader`` must be either an entry in
        :attr:`DEFAULT_LOADERS` or a loadable import path.
        You can additionally define an entry ``loader_kwargs``, which must
        contain ``name:dict`` pairs. The dictionary is passed as keyword
        arguments to the loader corresponding to ``name``.

    Returns
    -------
    loaders : dict
        Name-function pairs used to apply loading logic to the labels with
        the specified names.
    loader_kwargs : dict
        Name, dict pairs. The dicts are passed to the loader functions as
        keyword arguments.
    """

    loaders = {}
    loader_kwargs = {}

    for k in labels.keys():
        k, l = loader_from_key(k)
        if l is not None:
            loaders[k] = l

    meta_loaders = retrieve(meta_dict, "loaders", default={})
    meta_loader_kwargs = retrieve(meta_dict, "loader_kwargs", default={})

    loaders.update(meta_loaders)

    for k, l in loaders.items():
        if l in DEFAULT_LOADERS:
            loaders[k] = DEFAULT_LOADERS[l]
        else:
            loaders[k] = get_obj_from_str(l)

        if k in meta_loader_kwargs:
            loader_kwargs[k] = meta_loader_kwargs[k]
        else:
            loader_kwargs[k] = {}

    return loaders, loader_kwargs
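
As a rough illustration of the inputs ``setup_loaders`` expects (all names and values below are made up):

# hypothetical labels: a key may carry a loader suffix after a colon
labels = {
    "image:image": ...,   # e.g. a numpy.memmap of image file paths
    "class_id": ...,
}
# hypothetical meta dict: loaders and per-loader kwargs keyed by label name
meta_dict = {
    "loaders": {"mask": "image"},              # an entry in DEFAULT_LOADERS or an import path
    "loader_kwargs": {"mask": {"factor": 2}},  # keyword arguments for that loader
}
# loaders, loader_kwargs = setup_loaders(labels, meta_dict)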
Example #3
    def init_first_stage(self, config):
        subconfig = config["subconfig"]
        self.first_stage = get_obj_from_str(config["model"])(subconfig)
        if "checkpoint" in config:
            checkpoint = config["checkpoint"]
            state = torch.load(checkpoint)["model"]
            self.first_stage.load_state_dict(state)
            self.logger.info("Restored first stage from {}".format(checkpoint))
        self.first_stage.to(self.device)
        self.first_stage.eval()
Example #4
def getSeqDataset(config):
    """This allows to not define a dataset class, but use a baseclass and a
    `length` and `step` parameter in the supplied `config` to load and
    sequentialize a dataset.

    A config passed to edflow would the look like this:

    .. code-block:: yaml

        dataset: edflow.data.dataset.getSeqDataset
        model: Some Model
        iterator: Some Iterator

        seqdataset:
                dataset: import.path.to.your.basedataset
                length: 3
                step: 1
                fid_key: fid
                base_step: 1

    ``getSeqDataset`` will import the base ``dataset`` and pass it to
    :class:`SequenceDataset` together with ``length`` and ``step`` to
    build the dataset that is actually used.

    Parameters
    ----------
    config : dict
        An edflow config with at least the key ``seqdataset`` and, nested
        inside it, ``dataset``, ``length`` and ``step``.

    Returns
    -------
    :class:`SequenceDataset`
        A Sequence Dataset based on the basedataset.
    """

    ks = "seqdataset"
    base_dset = get_obj_from_str(config[ks]["dataset"])
    base_dset = base_dset(config=config)

    S = SequenceDataset(
        base_dset,
        config[ks]["length"],
        config[ks]["step"],
        fid_key=config[ks]["fid_key"],
        base_step=config[ks].get("base_step", 1),
    )

    return S
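
For intuition, the sequentialization performed by ``SequenceDataset`` with ``length``, ``step`` and ``base_step`` amounts to index slicing along these lines (a simplified sketch, not the actual implementation):

def make_sequences(n_frames, length=3, step=1, base_step=1):
    # each sequence is `length` frame indices spaced by `base_step`;
    # `step` controls how far the start index moves between sequences
    span = (length - 1) * base_step + 1
    return [
        list(range(start, start + span, base_step))
        for start in range(0, n_frames - span + 1, step)
    ]

# make_sequences(6) -> [[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5]]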
Example #5
    def init_ae(self, config):
        """Initializes autoencoder"""
        if "pretrained_ae_key" in config:
            # load from the 'autoencoders' repo
            ae_key = config["pretrained_ae_key"]
            self.autoencoder = autoencoders.get_model(ae_key)
            self.logger.info("Loaded autoencoder {} from 'autoencoders'".format(ae_key))
        else:
            # in case you want to use a checkpoint different from the one provided.
            subconfig = config["subconfig"]
            self.autoencoder = get_obj_from_str(config["model"])(subconfig)
            if "checkpoint" in config:
                checkpoint = config["checkpoint"]
                state = torch.load(checkpoint)["model"]
                self.autoencoder.load_state_dict(state)
                self.logger.info("Restored autoencoder from {}".format(checkpoint))
        self.autoencoder.to(self.device)
        self.autoencoder.eval()
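
Both this and Example #3 read a config of roughly the following shape; the import path and checkpoint path below are placeholders:

config = {
    "model": "mypackage.models.MyAutoencoder",        # resolved via get_obj_from_str
    "subconfig": {},                                   # passed to the model constructor
    "checkpoint": "logs/run/checkpoints/model.ckpt",   # optional, restored if present
    # "pretrained_ae_key": "some_key",                 # alternative: load from 'autoencoders'
}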
Example #6
    def __init__(self, root):
        super().__init__(root)

        base_import = retrieve(self.meta, "base_dset")
        base_kwargs = retrieve(self.meta, "base_kwargs")
        self.base = get_obj_from_str(base_import)(**base_kwargs)
        self.base.append_labels = False

        views = retrieve(self.meta, "views", default="view")

        def get_label(key):
            return retrieve(self.labels, key)

        self.views = walk(views, get_label)

        if not os.path.exists(os.path.join(root, ".constructed.txt")):

            def constructor(name, view):
                folder_name = name
                savefolder = os.path.join(root, "labels", folder_name)

                os.makedirs(savefolder, exist_ok=True)

                for key, label in tqdm(self.base.labels.items(),
                                       desc=f"Exporting View {name}"):

                    savepath = os.path.join(root, "labels", name)
                    label_view = np.take(label, view, axis=0)
                    store_label_mmap(label_view, savepath, key)

            walk(self.views, constructor, pass_key=True)

            with open(os.path.join(root, ".constructed.txt"), "w+") as cf:
                cf.write("Do not delete, this reduces loading times.\n"
                         "If you need to re-render the view, you can safely "
                         "delete this file.")

            # Re-initialize as we need to load the labels again.
            super().__init__(root)
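
The per-label view construction above boils down to re-indexing each label array with ``np.take``; a minimal illustration with made-up data:

import numpy as np

label = np.array([10, 11, 12, 13, 14])  # a per-example label of the base dataset
view = np.array([4, 0, 0, 2])           # example indices defining the view
label_view = np.take(label, view, axis=0)
# label_view -> array([14, 10, 10, 12]): the view reorders and repeats base examples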
Example #7
    def init_greybox(self, config):
        """Initializes a provided 'Greybox', i.e. a model one wants to interpret/analyze."""
        self.greybox = get_obj_from_str(config["model"])(config)
        self.greybox.to(self.device)
        self.greybox.eval()
Example #8
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # wrap save and restore into a LambdaCheckpointHook
        self.ckpthook = LambdaCheckpointHook(
            root_path=ProjectManager.checkpoints,
            global_step_getter=self.get_global_step,
            global_step_setter=self.set_global_step,
            save=self.save,
            restore=self.restore,
            interval=set_default(self.config, "ckpt_freq", None),
            ckpt_zero=set_default(self.config, "ckpt_zero", False),
        )
        # write checkpoints after epoch or when interrupted during training
        if not self.config.get("test_mode", False):
            self.hooks.append(self.ckpthook)

        ## hooks - disabled unless -t is specified

        # execute train ops
        self._train_ops = set_default(self.config, "train_ops",
                                      ["train/train_op"])
        train_hook = ExpandHook(paths=self._train_ops, interval=1)
        self.hooks.append(train_hook)

        # log train/step_ops/log_ops in increasing intervals
        self._log_ops = set_default(self.config, "log_ops",
                                    ["train/log_op", "validation/log_op"])
        self.loghook = LoggingHook(paths=self._log_ops,
                                   root_path=ProjectManager.train,
                                   interval=1)
        self.ihook = IntervalHook(
            [self.loghook],
            interval=set_default(self.config, "start_log_freq", 1),
            modify_each=1,
            max_interval=set_default(self.config, "log_freq", 1000),
            get_step=self.get_global_step,
        )
        self.hooks.append(self.ihook)

        # setup logging integrations
        if not self.config.get("test_mode", False):
            default_wandb_logging = {
                "active": False,
                "handlers": ["scalars", "images"]
            }
            wandb_logging = set_default(self.config, "integrations/wandb",
                                        default_wandb_logging)
            if wandb_logging["active"]:
                import wandb
                from edflow.hooks.logging_hooks.wandb_handler import (
                    log_wandb,
                    log_wandb_images,
                )

                os.environ["WANDB_RESUME"] = "allow"
                os.environ["WANDB_RUN_ID"] = ProjectManager.root.strip(
                    "/").replace("/", "-")
                wandb_project = set_default(self.config,
                                            "integrations/wandb/project", None)
                wandb_entity = set_default(self.config,
                                           "integrations/wandb/entity", None)
                wandb.init(
                    name=ProjectManager.root,
                    config=self.config,
                    project=wandb_project,
                    entity=wandb_entity,
                )

                handlers = set_default(
                    self.config,
                    "integrations/wandb/handlers",
                    default_wandb_logging["handlers"],
                )
                if "scalars" in handlers:
                    self.loghook.handlers["scalars"].append(log_wandb)
                if "images" in handlers:
                    self.loghook.handlers["images"].append(log_wandb_images)

            default_tensorboard_logging = {
                "active": False,
                "handlers": ["scalars", "images", "figures"],
            }
            tensorboard_logging = set_default(self.config,
                                              "integrations/tensorboard",
                                              default_tensorboard_logging)
            if tensorboard_logging["active"]:
                try:
                    from torch.utils.tensorboard import SummaryWriter
                except ImportError:
                    from tensorboardX import SummaryWriter

                from edflow.hooks.logging_hooks.tensorboard_handler import (
                    log_tensorboard_config,
                    log_tensorboard_scalars,
                    log_tensorboard_images,
                    log_tensorboard_figures,
                )

                self.tensorboard_writer = SummaryWriter(ProjectManager.root)
                log_tensorboard_config(self.tensorboard_writer, self.config,
                                       self.get_global_step())
                handlers = set_default(
                    self.config,
                    "integrations/tensorboard/handlers",
                    default_tensorboard_logging["handlers"],
                )
                if "scalars" in handlers:
                    self.loghook.handlers["scalars"].append(
                        lambda *args, **kwargs: log_tensorboard_scalars(
                            self.tensorboard_writer, *args, **kwargs))
                if "images" in handlers:
                    self.loghook.handlers["images"].append(
                        lambda *args, **kwargs: log_tensorboard_images(
                            self.tensorboard_writer, *args, **kwargs))
                if "figures" in handlers:
                    self.loghook.handlers["figures"].append(
                        lambda *args, **kwargs: log_tensorboard_figures(
                            self.tensorboard_writer, *args, **kwargs))
        ## epoch hooks

        # evaluate validation/step_ops/eval_op after each epoch
        self._eval_op = set_default(self.config, "eval_hook/eval_op",
                                    "validation/eval_op")
        _eval_callbacks = set_default(self.config, "eval_hook/eval_callbacks",
                                      dict())
        if not isinstance(_eval_callbacks, dict):
            _eval_callbacks = {"cb": _eval_callbacks}
        eval_callbacks = dict()
        for k in _eval_callbacks:
            eval_callbacks[k] = _eval_callbacks[k]
        if hasattr(self, "callbacks"):
            iterator_callbacks = retrieve(self.callbacks,
                                          "eval_op",
                                          default=dict())
            for k in iterator_callbacks:
                import_path = get_str_from_obj(iterator_callbacks[k])
                set_value(self.config, "eval_hook/eval_callbacks/{}".format(k),
                          import_path)
                eval_callbacks[k] = import_path
        if hasattr(self.model, "callbacks"):
            model_callbacks = retrieve(self.model.callbacks,
                                       "eval_op",
                                       default=dict())
            for k in model_callbacks:
                import_path = get_str_from_obj(model_callbacks[k])
                set_value(self.config, "eval_hook/eval_callbacks/{}".format(k),
                          import_path)
                eval_callbacks[k] = import_path
        callback_handler = None
        if not self.config.get("test_mode", False):
            callback_handler = lambda results, paths: self.loghook(
                results=results,
                step=self.get_global_step(),
                paths=paths,
            )

        # offer option to run eval functor:
        # overwrite step op to only include the evaluation of the functor and
        # overwrite callbacks to only include the callbacks of the functor
        if self.config.get("test_mode",
                           False) and "eval_functor" in self.config:
            # offer option to use eval functor for evaluation
            eval_functor = get_obj_from_str(
                self.config["eval_functor"])(config=self.config)
            self.step_ops = lambda: {"eval_op": eval_functor}
            eval_callbacks = dict()
            if hasattr(eval_functor, "callbacks"):
                for k in eval_functor.callbacks:
                    eval_callbacks[k] = get_str_from_obj(
                        eval_functor.callbacks[k])
            set_value(self.config, "eval_hook/eval_callbacks", eval_callbacks)
        self.evalhook = TemplateEvalHook(
            datasets=self.datasets,
            step_getter=self.get_global_step,
            keypath=self._eval_op,
            config=self.config,
            callbacks=eval_callbacks,
            callback_handler=callback_handler,
        )
        self.epoch_hooks.append(self.evalhook)
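
Based on the keys queried above, a config fragment enabling both logging integrations could look roughly like this (shown as a Python dict; project and entity are placeholders):

config_fragment = {
    "integrations": {
        "wandb": {
            "active": True,
            "project": "my-project",   # placeholder
            "entity": "my-team",       # placeholder
            "handlers": ["scalars", "images"],
        },
        "tensorboard": {
            "active": True,
            "handlers": ["scalars", "images", "figures"],
        },
    },
    "log_freq": 1000,
    "ckpt_freq": 10000,
}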
Example #9
def train(config, root, checkpoint=None, retrain=False, debug=False):
    """Run training. Loads model, iterator and dataset according to config."""
    from edflow.iterators.batches import make_batches

    # disable integrations in debug mode
    if debug:
        if retrieve(config, "debug/disable_integrations", default=True):
            integrations = retrieve(config, "integrations", default=dict())
            for k in integrations:
                config["integrations"][k]["active"] = False
        max_steps = retrieve(config, "debug/max_steps", default=5 * 2)
        if max_steps > 0:
            config["num_steps"] = max_steps

    # backwards compatibility
    if not "datasets" in config:
        config["datasets"] = {"train": config["dataset"]}
        if "validation_dataset" in config:
            config["datasets"]["validation"] = config["validation_dataset"]

    log.set_log_target("train")
    logger = log.get_logger("train")
    logger.info("Starting Training.")

    model = get_obj_from_str(config["model"])
    iterator = get_obj_from_str(config["iterator"])
    datasets = dict(
        (split, get_obj_from_str(config["datasets"][split]))
        for split in config["datasets"]
    )

    logger.info("Instantiating datasets.")
    for split in datasets:
        datasets[split] = datasets[split](config=config)
        datasets[split].expand = True
        logger.info("{} dataset size: {}".format(split, len(datasets[split])))
        if debug:
            max_examples = retrieve(
                config, "debug/max_examples", default=5 * config["batch_size"]
            )
            if max_examples > 0:
                logger.info(
                    "Monkey patching {} dataset __len__ to {} examples".format(
                        split, max_examples
                    )
                )
                type(datasets[split]).__len__ = lambda self: max_examples

    n_processes = config.get("n_data_processes", min(16, config["batch_size"]))
    n_prefetch = config.get("n_prefetch", 1)
    logger.info("Building batches.")
    batches = dict()
    for split in datasets:
        batches[split] = make_batches(
            datasets[split],
            batch_size=config["batch_size"],
            shuffle=True,
            n_processes=n_processes,
            n_prefetch=n_prefetch,
            error_on_timeout=config.get("error_on_timeout", False),
        )
    main_split = "train"
    try:
        if "num_steps" in config:
            # set number of epochs to perform at least num_steps steps
            steps_per_epoch = len(datasets[main_split]) / config["batch_size"]
            num_epochs = config["num_steps"] / steps_per_epoch
            config["num_epochs"] = math.ceil(num_epochs)
        else:
            steps_per_epoch = len(datasets[main_split]) / config["batch_size"]
            num_steps = config["num_epochs"] * steps_per_epoch
            config["num_steps"] = math.ceil(num_steps)

        logger.info("Instantiating model.")
        model = model(config)
        if not "hook_freq" in config:
            config["hook_freq"] = 1
        compat_kwargs = dict(
            hook_freq=config["hook_freq"], num_epochs=config["num_epochs"]
        )
        logger.info("Instantiating iterator.")
        iterator = iterator(config, root, model, datasets=datasets, **compat_kwargs)

        logger.info("Initializing model.")
        if checkpoint is not None:
            iterator.initialize(checkpoint_path=checkpoint)
        else:
            iterator.initialize()

        if retrain:
            iterator.reset_global_step()

        # save current config
        logger.info("Starting Training with config:\n{}".format(yaml.dump(config)))
        cpath = _save_config(config, prefix="train")
        logger.info("Saved config at {}".format(cpath))

        logger.info("Iterating.")
        iterator.iterate(batches)
    finally:
        for split in batches:
            batches[split].finalize()
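
A minimal config covering the keys ``train`` reads might look like this (shown as a Python dict; the import paths are placeholders for your own classes):

config = {
    "model": "mypackage.models.MyModel",
    "iterator": "mypackage.iterators.MyIterator",
    "datasets": {
        "train": "mypackage.data.MyTrainData",
        "validation": "mypackage.data.MyValData",
    },
    "batch_size": 16,
    "num_epochs": 10,        # or "num_steps"; the other is derived above
    "n_data_processes": 8,
    "n_prefetch": 1,
}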
Example #10
def test(config, root, checkpoint=None, nogpu=False, bar_position=0, debug=False):
    """Run tests. Loads model, iterator and dataset from config."""
    from edflow.iterators.batches import make_batches

    # backwards compatibility
    if not "datasets" in config:
        config["datasets"] = {"train": config["dataset"]}
        if "validation_dataset" in config:
            config["datasets"]["validation"] = config["validation_dataset"]

    log.set_log_target("latest_eval")
    logger = log.get_logger("test")
    logger.info("Starting Evaluation.")

    if "test_batch_size" in config:
        config["batch_size"] = config["test_batch_size"]
    if "test_mode" not in config:
        config["test_mode"] = True

    model = get_obj_from_str(config["model"])
    iterator = get_obj_from_str(config["iterator"])
    datasets = dict(
        (split, get_obj_from_str(config["datasets"][split]))
        for split in config["datasets"]
    )

    logger.info("Instantiating datasets.")
    for split in datasets:
        datasets[split] = datasets[split](config=config)
        datasets[split].expand = True
        logger.info("{} dataset size: {}".format(split, len(datasets[split])))
        if debug:
            max_examples = retrieve(
                config, "debug/max_examples", default=5 * config["batch_size"]
            )
            if max_examples > 0:
                logger.info(
                    "Monkey patching {} dataset __len__ to {} examples".format(
                        split, max_examples
                    )
                )
                type(datasets[split]).__len__ = lambda self: max_examples

    n_processes = config.get("n_data_processes", min(16, config["batch_size"]))
    n_prefetch = config.get("n_prefetch", 1)
    logger.info("Building batches.")
    batches = dict()
    for split in datasets:
        batches[split] = make_batches(
            datasets[split],
            batch_size=config["batch_size"],
            shuffle=False,
            n_processes=n_processes,
            n_prefetch=n_prefetch,
            error_on_timeout=config.get("error_on_timeout", False),
        )
    try:
        logger.info("Initializing model.")
        model = model(config)

        config["hook_freq"] = 1
        config["num_epochs"] = 1
        config["nogpu"] = nogpu
        compat_kwargs = dict(
            hook_freq=config["hook_freq"],
            bar_position=bar_position,
            nogpu=config["nogpu"],
            num_epochs=config["num_epochs"],
        )
        iterator = iterator(config, root, model, datasets=datasets, **compat_kwargs)

        logger.info("Initializing model.")
        if checkpoint is not None:
            iterator.initialize(checkpoint_path=checkpoint)
        else:
            iterator.initialize()

        # save current config
        logger.info("Starting Evaluation with config:\n{}".format(yaml.dump(config)))
        prefix = "eval"
        if bar_position > 0:
            prefix = prefix + str(bar_position)
        cpath = _save_config(config, prefix=prefix)
        logger.info("Saved config at {}".format(cpath))

        logger.info("Iterating")
        while True:
            iterator.iterate(batches)
            if not config.get("eval_forever", False):
                break
    finally:
        for split in batches:
            batches[split].finalize()
Example #11
def standalone_eval_meta_dset(path_to_meta_dir,
                              callbacks,
                              additional_kwargs={},
                              other_config=None):
    """Runs all given callbacks on the data in the :class:`EvalDataFolder`
    constructed from the given csv.abs

    Parameters
    ----------
    path_to_csv : str
        Path to the csv file.
    callbacks : dict(name: str or Callable)
        Import commands used to construct the functions applied to the Data
        extracted from :attr:`path_to_csv`.
    additional_kwargs : dict
        Keypath-value pairs added to the config, which is extracted from
        the ``model_outputs.csv``. These will overwrite parameters in the
        original config extracted from the csv.
    other_config : str
        Path to additional config used to update the existing one as taken from
        the ``model_outputs.csv`` . Cannot overwrite the dataset. Only used for
        callbacks. Parameters in this other config will overwrite the
        parameters in the original config and those of the commandline
        arguments.

    Returns
    -------
    outputs: dict
        The collected outputs of the callbacks.
    """

    from edflow.util import get_obj_from_str
    from edflow.config import update_config
    import yaml

    if other_config is not None:
        with open(other_config, "r") as f:
            other_config = yaml.full_load(f)
    else:
        other_config = {}

    out_data = MetaDataset(path_to_meta_dir)
    out_data.expand = True
    out_data.append_labels = True

    config = out_data.meta

    # backwards compatibility
    if not "datasets" in config:
        config["datasets"] = {"train": config["dataset"]}
        if "validation_dataset" in config:
            config["datasets"]["validation"] = config["validation_dataset"]
    datasets = dict((split, get_obj_from_str(config["datasets"][split]))
                    for split in config["datasets"])
    # TODO fix hardcoded dataset
    in_data = datasets["validation"](config=config)
    in_data.expand = True

    update_config(config, additional_kwargs)
    config.update(other_config)

    config_callbacks, callback_kwargs = config2cbdict(config)
    callbacks.update(config_callbacks)

    callbacks = load_callbacks(callbacks)

    root = os.path.dirname(path_to_meta_dir)

    # TODO handle logging of return values
    outputs = apply_callbacks(callbacks, root, in_data, out_data, config,
                              callback_kwargs)

    return outputs
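
A hypothetical call (the meta directory path and the callback import path are placeholders):

outputs = standalone_eval_meta_dset(
    "logs/my_run/eval/latest",                              # directory holding the model outputs
    callbacks={"fid": "mypackage.callbacks.fid_callback"},  # name -> import path or callable
    additional_kwargs={"some_key": 1},                      # overwrites config entries
    other_config=None,
)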