def prepare(self, model_as_str, checkpoint="none", pretrained_key="none"):
    """Prepare the model, e.g. instantiate a pre-trained classifier from torchvision.

    Parameters
    ----------
    model_as_str : str
        Import path of the model class, resolved via ``get_obj_from_str``.
    checkpoint : str
        Path to a checkpoint file, or the sentinel ``"none"`` to skip restoring.
    pretrained_key : str
        Key passed to the model class' ``from_pretrained``; the sentinel
        ``"none"`` means "construct from ``self.model_config`` instead".

    Returns
    -------
    The instantiated (and possibly restored) model.
    """
    # NOTE: original code compared with `is not "none"`, i.e. object identity
    # against a string literal -- an implementation detail of interning, and a
    # SyntaxWarning on modern CPython. Use equality instead.
    if pretrained_key != "none":
        model = get_obj_from_str(model_as_str).from_pretrained(pretrained_key)
    else:
        model = get_obj_from_str(model_as_str)(self.model_config)
    if checkpoint != "none":
        assert isinstance(checkpoint, str), "please provide a path to a checkpoint"
        state = torch.load(checkpoint)["model"]
        model.load_state_dict(state)
        self.logger.info("Restored model from {}".format(checkpoint))
    return model
def setup_loaders(labels, meta_dict):
    """Build the ``name -> loader function`` map used to postprocess label
    values at each ``__getitem__`` call.

    Loaders defined in :attr:`meta_dict` supersede those defined in the label
    keys.

    Parameters
    ----------
    labels : dict(str, numpy.memmap)
        Labels contain all load-easy dataset relevant data. If a key follows
        the pattern ``name:loader``, this function will try to find the
        corresponding loader in :attr:`DEFAULT_LOADERS`.
    meta_dict : dict
        A dictionary containing all dataset relevant information, which is the
        same for all examples. This function looks for the entry ``loaders``,
        which must contain another ``dict`` with ``name:loader`` pairs. Here
        ``loader`` must be either an entry in :attr:`DEFAULT_LOADERS` or a
        loadable import path. An optional entry ``loader_kwargs`` may contain
        ``name:dict`` pairs; each dict is passed as keyword arguments to the
        loader corresponding to ``name``.

    Returns
    -------
    loaders : dict
        Name, function pairs, to apply loading logic based on the labels with
        the specified names.
    loader_kwargs : dict
        Name, dict pairs. The dicts are passed to the loader functions as
        keyword arguments.
    """
    # Loader specs embedded in the label keys (pattern "name:loader").
    loaders = {}
    for raw_key in labels:
        name, loader_spec = loader_from_key(raw_key)
        if loader_spec is not None:
            loaders[name] = loader_spec

    # Specs from the meta dict win over those from the keys.
    loaders.update(retrieve(meta_dict, "loaders", default={}))
    meta_loader_kwargs = retrieve(meta_dict, "loader_kwargs", default={})

    # Resolve each spec to a callable: known name or import path.
    loader_kwargs = {}
    for name, loader_spec in loaders.items():
        if loader_spec in DEFAULT_LOADERS:
            loaders[name] = DEFAULT_LOADERS[loader_spec]
        else:
            loaders[name] = get_obj_from_str(loader_spec)
        loader_kwargs[name] = meta_loader_kwargs.get(name, {})

    return loaders, loader_kwargs
def init_first_stage(self, config):
    """Instantiate the first-stage model from ``config``, optionally restore
    its weights from ``config["checkpoint"]``, move it to ``self.device`` and
    switch it to eval mode."""
    model_cls = get_obj_from_str(config["model"])
    self.first_stage = model_cls(config["subconfig"])
    if "checkpoint" in config:
        ckpt_path = config["checkpoint"]
        weights = torch.load(ckpt_path)["model"]
        self.first_stage.load_state_dict(weights)
        self.logger.info("Restored first stage from {}".format(ckpt_path))
    self.first_stage.to(self.device)
    self.first_stage.eval()
def getSeqDataset(config):
    """This allows to not define a dataset class, but use a baseclass and a
    ``length`` and ``step`` parameter in the supplied ``config`` to load and
    sequentialize a dataset.

    A config passed to edflow would then look like this:

    .. code-block:: yaml

        dataset: edflow.data.dataset.getSeqDataset
        model: Some Model
        iterator: Some Iterator

        seqdataset:
            dataset: import.path.to.your.basedataset
            length: 3
            step: 1
            fid_key: fid
            base_step: 1

    ``getSeqDataset`` will import the base ``dataset`` and pass it to
    :class:`SequenceDataset` together with ``length`` and ``step`` to make the
    actually used dataset.

    Parameters
    ----------
    config : dict
        An edflow config, with at least the key ``seqdataset`` and nested
        inside it ``dataset``, ``length``, ``step`` and ``fid_key``.
        ``base_step`` is optional and defaults to ``1``.

    Returns
    -------
    :class:`SequenceDataset`
        A Sequence Dataset based on the basedataset.
    """
    # NOTE: the docstring previously referenced config keys ``seq_length`` and
    # ``seq_step``, which the code never reads; the actual keys are below.
    ks = "seqdataset"
    base_dset = get_obj_from_str(config[ks]["dataset"])
    base_dset = base_dset(config=config)

    S = SequenceDataset(
        base_dset,
        config[ks]["length"],
        config[ks]["step"],
        fid_key=config[ks]["fid_key"],
        base_step=config[ks].get("base_step", 1),
    )

    return S
def init_ae(self, config):
    """Initializes autoencoder.

    Either loads a pre-trained model from the 'autoencoders' repo (when
    ``pretrained_ae_key`` is present in ``config``) or instantiates the class
    named in ``config["model"]`` and optionally restores it from
    ``config["checkpoint"]``. The model ends up on ``self.device`` in eval
    mode either way.
    """
    if "pretrained_ae_key" in config:
        # load from the 'autoencoders' repo
        ae_key = config["pretrained_ae_key"]
        self.autoencoder = autoencoders.get_model(ae_key)
        self.logger.info("Loaded autoencoder {} from 'autoencoders'".format(ae_key))
    else:
        # in case you want to use a checkpoint different from the one provided.
        ae_cls = get_obj_from_str(config["model"])
        self.autoencoder = ae_cls(config["subconfig"])
        if "checkpoint" in config:
            ckpt_path = config["checkpoint"]
            weights = torch.load(ckpt_path)["model"]
            self.autoencoder.load_state_dict(weights)
            self.logger.info("Restored autoencoder from {}".format(ckpt_path))
    self.autoencoder.to(self.device)
    self.autoencoder.eval()
def __init__(self, root):
    """Build a view dataset on top of a base dataset described in ``root``'s
    meta information.

    On first construction the label views are exported to
    ``root/labels/<view name>`` as memory maps and a ``.constructed.txt``
    marker file is written; later constructions skip the export.

    Parameters
    ----------
    root : str
        Root folder of this dataset; must contain meta entries ``base_dset``
        (import path) and ``base_kwargs``, and may contain ``views``
        (defaults to the label key ``"view"``).
    """
    super().__init__(root)

    # Instantiate the underlying dataset from its import path and kwargs.
    base_import = retrieve(self.meta, "base_dset")
    base_kwargs = retrieve(self.meta, "base_kwargs")
    self.base = get_obj_from_str(base_import)(**base_kwargs)
    self.base.append_labels = False

    # `views` names label keys; resolve each to its label array.
    views = retrieve(self.meta, "views", default="view")

    def get_label(key):
        return retrieve(self.labels, key)

    self.views = walk(views, get_label)

    # Export the views once; the marker file makes later loads cheap.
    if not os.path.exists(os.path.join(root, ".constructed.txt")):

        def constructor(name, view):
            # One subfolder per view; each base label is re-indexed along
            # axis 0 with the view's indices and stored as a memmap.
            folder_name = name
            savefolder = os.path.join(root, "labels", folder_name)
            os.makedirs(savefolder, exist_ok=True)
            for key, label in tqdm(self.base.labels.items(), desc=f"Exporting View {name}"):
                savepath = os.path.join(root, "labels", name)
                label_view = np.take(label, view, axis=0)
                store_label_mmap(label_view, savepath, key)

        walk(self.views, constructor, pass_key=True)

        with open(os.path.join(root, ".constructed.txt"), "w+") as cf:
            cf.write("Do not delete, this reduces loading times.\n"
                     "If you need to re-render the view, you can safely "
                     "delete this file.")

        # Re-initialize as we need to load the labels again.
        super().__init__(root)
def init_greybox(self, config):
    """Initializes a provided 'Greybox', i.e. a model one wants to
    interpret/analyze, on ``self.device`` in eval mode."""
    greybox_cls = get_obj_from_str(config["model"])
    self.greybox = greybox_cls(config)
    self.greybox.to(self.device)
    self.greybox.eval()
def __init__(self, *args, **kwargs):
    """Set up the standard hook pipeline for this iterator.

    Installs, in order: a checkpoint hook (save/restore), the train-op
    expansion hook, logging hooks with growing intervals, optional
    wandb/tensorboard integrations (train mode only), and an epoch-level
    evaluation hook whose callbacks are collected from config, iterator and
    model. Fix over the original: the tensorboard import fallback now catches
    only ``ImportError`` instead of a bare ``except:``, which swallowed every
    exception (including KeyboardInterrupt).
    """
    super().__init__(*args, **kwargs)
    # wrap save and restore into a LambdaCheckpointHook
    self.ckpthook = LambdaCheckpointHook(
        root_path=ProjectManager.checkpoints,
        global_step_getter=self.get_global_step,
        global_step_setter=self.set_global_step,
        save=self.save,
        restore=self.restore,
        interval=set_default(self.config, "ckpt_freq", None),
        ckpt_zero=set_default(self.config, "ckpt_zero", False),
    )
    # write checkpoints after epoch or when interrupted during training
    if not self.config.get("test_mode", False):
        self.hooks.append(self.ckpthook)

    ## hooks - disabled unless -t is specified

    # execute train ops
    self._train_ops = set_default(self.config, "train_ops", ["train/train_op"])
    train_hook = ExpandHook(paths=self._train_ops, interval=1)
    self.hooks.append(train_hook)

    # log train/step_ops/log_ops in increasing intervals
    self._log_ops = set_default(
        self.config, "log_ops", ["train/log_op", "validation/log_op"]
    )
    self.loghook = LoggingHook(
        paths=self._log_ops, root_path=ProjectManager.train, interval=1
    )
    self.ihook = IntervalHook(
        [self.loghook],
        interval=set_default(self.config, "start_log_freq", 1),
        modify_each=1,
        max_interval=set_default(self.config, "log_freq", 1000),
        get_step=self.get_global_step,
    )
    self.hooks.append(self.ihook)

    # setup logging integrations
    if not self.config.get("test_mode", False):
        default_wandb_logging = {"active": False, "handlers": ["scalars", "images"]}
        wandb_logging = set_default(
            self.config, "integrations/wandb", default_wandb_logging
        )
        if wandb_logging["active"]:
            import wandb
            from edflow.hooks.logging_hooks.wandb_handler import (
                log_wandb,
                log_wandb_images,
            )

            # derive a stable run id from the project root so runs resume
            os.environ["WANDB_RESUME"] = "allow"
            os.environ["WANDB_RUN_ID"] = ProjectManager.root.strip("/").replace(
                "/", "-"
            )
            wandb_project = set_default(
                self.config, "integrations/wandb/project", None
            )
            wandb_entity = set_default(
                self.config, "integrations/wandb/entity", None
            )
            wandb.init(
                name=ProjectManager.root,
                config=self.config,
                project=wandb_project,
                entity=wandb_entity,
            )

            handlers = set_default(
                self.config,
                "integrations/wandb/handlers",
                default_wandb_logging["handlers"],
            )
            if "scalars" in handlers:
                self.loghook.handlers["scalars"].append(log_wandb)
            if "images" in handlers:
                self.loghook.handlers["images"].append(log_wandb_images)

        default_tensorboard_logging = {
            "active": False,
            "handlers": ["scalars", "images", "figures"],
        }
        tensorboard_logging = set_default(
            self.config, "integrations/tensorboard", default_tensorboard_logging
        )
        if tensorboard_logging["active"]:
            # prefer torch's bundled tensorboard, fall back to tensorboardX
            try:
                from torch.utils.tensorboard import SummaryWriter
            except ImportError:
                from tensorboardX import SummaryWriter
            from edflow.hooks.logging_hooks.tensorboard_handler import (
                log_tensorboard_config,
                log_tensorboard_scalars,
                log_tensorboard_images,
                log_tensorboard_figures,
            )

            self.tensorboard_writer = SummaryWriter(ProjectManager.root)
            log_tensorboard_config(
                self.tensorboard_writer, self.config, self.get_global_step()
            )
            handlers = set_default(
                self.config,
                "integrations/tensorboard/handlers",
                default_tensorboard_logging["handlers"],
            )
            if "scalars" in handlers:
                self.loghook.handlers["scalars"].append(
                    lambda *args, **kwargs: log_tensorboard_scalars(
                        self.tensorboard_writer, *args, **kwargs
                    )
                )
            if "images" in handlers:
                self.loghook.handlers["images"].append(
                    lambda *args, **kwargs: log_tensorboard_images(
                        self.tensorboard_writer, *args, **kwargs
                    )
                )
            if "figures" in handlers:
                self.loghook.handlers["figures"].append(
                    lambda *args, **kwargs: log_tensorboard_figures(
                        self.tensorboard_writer, *args, **kwargs
                    )
                )

    ## epoch hooks

    # evaluate validation/step_ops/eval_op after each epoch
    self._eval_op = set_default(
        self.config, "eval_hook/eval_op", "validation/eval_op"
    )
    _eval_callbacks = set_default(self.config, "eval_hook/eval_callbacks", dict())
    if not isinstance(_eval_callbacks, dict):
        _eval_callbacks = {"cb": _eval_callbacks}
    eval_callbacks = dict()
    for k in _eval_callbacks:
        eval_callbacks[k] = _eval_callbacks[k]
    # merge in callbacks declared on the iterator and the model, recording
    # their import paths back into the config for reproducibility
    if hasattr(self, "callbacks"):
        iterator_callbacks = retrieve(self.callbacks, "eval_op", default=dict())
        for k in iterator_callbacks:
            import_path = get_str_from_obj(iterator_callbacks[k])
            set_value(
                self.config, "eval_hook/eval_callbacks/{}".format(k), import_path
            )
            eval_callbacks[k] = import_path
    if hasattr(self.model, "callbacks"):
        model_callbacks = retrieve(self.model.callbacks, "eval_op", default=dict())
        for k in model_callbacks:
            import_path = get_str_from_obj(model_callbacks[k])
            set_value(
                self.config, "eval_hook/eval_callbacks/{}".format(k), import_path
            )
            eval_callbacks[k] = import_path
    callback_handler = None
    if not self.config.get("test_mode", False):
        callback_handler = lambda results, paths: self.loghook(
            results=results,
            step=self.get_global_step(),
            paths=paths,
        )

    # offer option to run eval functor:
    # overwrite step op to only include the evaluation of the functor and
    # overwrite callbacks to only include the callbacks of the functor
    if self.config.get("test_mode", False) and "eval_functor" in self.config:
        # offer option to use eval functor for evaluation
        eval_functor = get_obj_from_str(self.config["eval_functor"])(
            config=self.config
        )
        self.step_ops = lambda: {"eval_op": eval_functor}
        eval_callbacks = dict()
        if hasattr(eval_functor, "callbacks"):
            for k in eval_functor.callbacks:
                eval_callbacks[k] = get_str_from_obj(eval_functor.callbacks[k])
        set_value(self.config, "eval_hook/eval_callbacks", eval_callbacks)
    self.evalhook = TemplateEvalHook(
        datasets=self.datasets,
        step_getter=self.get_global_step,
        keypath=self._eval_op,
        config=self.config,
        callbacks=eval_callbacks,
        callback_handler=callback_handler,
    )
    self.epoch_hooks.append(self.evalhook)
def train(config, root, checkpoint=None, retrain=False, debug=False):
    """Run training. Loads model, iterator and dataset according to config.

    Parameters
    ----------
    config : dict
        The edflow config; mutated in place (``num_epochs``/``num_steps`` are
        reconciled, ``hook_freq`` defaulted, debug overrides applied).
    root : str
        Project root handed to the iterator.
    checkpoint : str, optional
        Checkpoint path to initialize the iterator from.
    retrain : bool
        If True, resets the global step after initialization.
    debug : bool
        Shrinks datasets/steps and disables integrations for quick runs.
    """
    from edflow.iterators.batches import make_batches

    # disable integrations in debug mode
    if debug:
        if retrieve(config, "debug/disable_integrations", default=True):
            integrations = retrieve(config, "integrations", default=dict())
            for k in integrations:
                config["integrations"][k]["active"] = False
        max_steps = retrieve(config, "debug/max_steps", default=5 * 2)
        if max_steps > 0:
            config["num_steps"] = max_steps

    # backwards compatibility
    if not "datasets" in config:
        config["datasets"] = {"train": config["dataset"]}
        if "validation_dataset" in config:
            config["datasets"]["validation"] = config["validation_dataset"]

    log.set_log_target("train")
    logger = log.get_logger("train")
    logger.info("Starting Training.")

    model = get_obj_from_str(config["model"])
    iterator = get_obj_from_str(config["iterator"])
    datasets = dict(
        (split, get_obj_from_str(config["datasets"][split]))
        for split in config["datasets"]
    )

    logger.info("Instantiating datasets.")
    for split in datasets:
        datasets[split] = datasets[split](config=config)
        datasets[split].expand = True
        logger.info("{} dataset size: {}".format(split, len(datasets[split])))
        if debug:
            max_examples = retrieve(
                config, "debug/max_examples", default=5 * config["batch_size"]
            )
            if max_examples > 0:
                logger.info(
                    "Monkey patching {} dataset __len__ to {} examples".format(
                        split, max_examples
                    )
                )
                # bind max_examples as a default argument: a plain closure is
                # late-binding and every patched class would see the loop's
                # final value
                type(datasets[split]).__len__ = lambda self, n=max_examples: n

    n_processes = config.get("n_data_processes", min(16, config["batch_size"]))
    n_prefetch = config.get("n_prefetch", 1)

    logger.info("Building batches.")
    batches = dict()
    for split in datasets:
        batches[split] = make_batches(
            datasets[split],
            batch_size=config["batch_size"],
            shuffle=True,
            n_processes=n_processes,
            n_prefetch=n_prefetch,
            error_on_timeout=config.get("error_on_timeout", False),
        )
    main_split = "train"
    try:
        if "num_steps" in config:
            # set number of epochs to perform at least num_steps steps
            steps_per_epoch = len(datasets[main_split]) / config["batch_size"]
            num_epochs = config["num_steps"] / steps_per_epoch
            config["num_epochs"] = math.ceil(num_epochs)
        else:
            steps_per_epoch = len(datasets[main_split]) / config["batch_size"]
            num_steps = config["num_epochs"] * steps_per_epoch
            config["num_steps"] = math.ceil(num_steps)

        logger.info("Instantiating model.")
        model = model(config)
        if not "hook_freq" in config:
            config["hook_freq"] = 1
        compat_kwargs = dict(
            hook_freq=config["hook_freq"], num_epochs=config["num_epochs"]
        )
        logger.info("Instantiating iterator.")
        iterator = iterator(config, root, model, datasets=datasets, **compat_kwargs)

        logger.info("Initializing model.")
        if checkpoint is not None:
            iterator.initialize(checkpoint_path=checkpoint)
        else:
            iterator.initialize()

        if retrain:
            iterator.reset_global_step()

        # save current config
        logger.info("Starting Training with config:\n{}".format(yaml.dump(config)))
        cpath = _save_config(config, prefix="train")
        logger.info("Saved config at {}".format(cpath))

        logger.info("Iterating.")
        iterator.iterate(batches)
    finally:
        # always release the batcher processes, even on error/interrupt
        for split in batches:
            batches[split].finalize()
def test(config, root, checkpoint=None, nogpu=False, bar_position=0, debug=False):
    """Run tests. Loads model, iterator and dataset from config.

    Parameters
    ----------
    config : dict
        The edflow config; mutated in place (``test_mode`` defaulted to True,
        ``batch_size``/``hook_freq``/``num_epochs``/``nogpu`` overridden).
    root : str
        Project root handed to the iterator.
    checkpoint : str, optional
        Checkpoint path to initialize the iterator from.
    nogpu : bool
        Forwarded to the iterator to disable GPU usage.
    bar_position : int
        Progress-bar slot; also suffixes the saved config prefix when > 0.
    debug : bool
        Shrinks datasets for quick runs.
    """
    from edflow.iterators.batches import make_batches

    # backwards compatibility
    if not "datasets" in config:
        config["datasets"] = {"train": config["dataset"]}
        if "validation_dataset" in config:
            config["datasets"]["validation"] = config["validation_dataset"]

    log.set_log_target("latest_eval")
    logger = log.get_logger("test")
    logger.info("Starting Evaluation.")

    if "test_batch_size" in config:
        config["batch_size"] = config["test_batch_size"]
    if "test_mode" not in config:
        config["test_mode"] = True

    model = get_obj_from_str(config["model"])
    iterator = get_obj_from_str(config["iterator"])
    datasets = dict(
        (split, get_obj_from_str(config["datasets"][split]))
        for split in config["datasets"]
    )

    logger.info("Instantiating datasets.")
    for split in datasets:
        datasets[split] = datasets[split](config=config)
        datasets[split].expand = True
        logger.info("{} dataset size: {}".format(split, len(datasets[split])))
        if debug:
            max_examples = retrieve(
                config, "debug/max_examples", default=5 * config["batch_size"]
            )
            if max_examples > 0:
                logger.info(
                    "Monkey patching {} dataset __len__ to {} examples".format(
                        split, max_examples
                    )
                )
                # bind max_examples as a default argument: a plain closure is
                # late-binding and every patched class would see the loop's
                # final value (same fix as in `train`)
                type(datasets[split]).__len__ = lambda self, n=max_examples: n

    n_processes = config.get("n_data_processes", min(16, config["batch_size"]))
    n_prefetch = config.get("n_prefetch", 1)

    logger.info("Building batches.")
    batches = dict()
    for split in datasets:
        batches[split] = make_batches(
            datasets[split],
            batch_size=config["batch_size"],
            shuffle=False,
            n_processes=n_processes,
            n_prefetch=n_prefetch,
            error_on_timeout=config.get("error_on_timeout", False),
        )
    try:
        logger.info("Initializing model.")
        model = model(config)

        config["hook_freq"] = 1
        config["num_epochs"] = 1
        config["nogpu"] = nogpu
        compat_kwargs = dict(
            hook_freq=config["hook_freq"],
            bar_position=bar_position,
            nogpu=config["nogpu"],
            num_epochs=config["num_epochs"],
        )
        iterator = iterator(config, root, model, datasets=datasets, **compat_kwargs)

        logger.info("Initializing model.")
        if checkpoint is not None:
            iterator.initialize(checkpoint_path=checkpoint)
        else:
            iterator.initialize()

        # save current config
        logger.info("Starting Evaluation with config:\n{}".format(yaml.dump(config)))
        prefix = "eval"
        if bar_position > 0:
            prefix = prefix + str(bar_position)
        cpath = _save_config(config, prefix=prefix)
        logger.info("Saved config at {}".format(cpath))

        logger.info("Iterating")
        while True:
            iterator.iterate(batches)
            if not config.get("eval_forever", False):
                break
    finally:
        # always release the batcher processes, even on error/interrupt
        for split in batches:
            batches[split].finalize()
def standalone_eval_meta_dset(
    path_to_meta_dir, callbacks, additional_kwargs=None, other_config=None
):
    """Runs all given callbacks on the data in the :class:`EvalDataFolder`
    constructed from the given meta dataset directory.

    Parameters
    ----------
    path_to_meta_dir : str
        Path to the meta dataset directory holding the model outputs.
    callbacks : dict(name: str or Callable)
        Import commands used to construct the functions applied to the Data
        extracted from :attr:`path_to_meta_dir`.
    additional_kwargs : dict, optional
        Keypath-value pairs added to the config, which is extracted from the
        ``model_outputs.csv``. These will overwrite parameters in the original
        config extracted from the csv.
    other_config : str, optional
        Path to additional config used to update the existing one as taken
        from the ``model_outputs.csv``. Cannot overwrite the dataset. Only
        used for callbacks. Parameters in this other config will overwrite
        the parameters in the original config and those of the commandline
        arguments.

    Returns
    -------
    outputs : dict
        The collected outputs of the callbacks.
    """
    from edflow.util import get_obj_from_str
    from edflow.config import update_config
    import yaml

    # None sentinel instead of a mutable `{}` default: the old default dict
    # was shared across calls and could accumulate state between invocations.
    if additional_kwargs is None:
        additional_kwargs = {}

    if other_config is not None:
        with open(other_config, "r") as f:
            other_config = yaml.full_load(f)
    else:
        other_config = {}

    out_data = MetaDataset(path_to_meta_dir)
    out_data.expand = True
    out_data.append_labels = True
    config = out_data.meta

    # backwards compatibility
    if not "datasets" in config:
        config["datasets"] = {"train": config["dataset"]}
        if "validation_dataset" in config:
            config["datasets"]["validation"] = config["validation_dataset"]
    datasets = dict(
        (split, get_obj_from_str(config["datasets"][split]))
        for split in config["datasets"]
    )
    # TODO fix hardcoded dataset
    in_data = datasets["validation"](config=config)
    in_data.expand = True

    update_config(config, additional_kwargs)
    config.update(other_config)

    config_callbacks, callback_kwargs = config2cbdict(config)
    callbacks.update(config_callbacks)

    callbacks = load_callbacks(callbacks)

    root = os.path.dirname(path_to_meta_dir)
    # TODO handle logging of return values
    outputs = apply_callbacks(
        callbacks, root, in_data, out_data, config, callback_kwargs
    )

    return outputs