def __init__(self, model, config, data_loaders, tb_writer, run_info,
                 logger, checkpoint_dir):
        """
        Creates a new evaluator object for evaluating a model.
        :param model: model to evaluate. Needs to inherit from the BaseModel class.
        :param config: dictionary containing the whole configuration of the experiment
        :param data_loaders: (dictionary) keys are the dataset names and each value is
         a PyTorch data loader providing the validation data
        :param tb_writer: tensorboardX summary writer
        :param run_info: sacred run info for logging training progress
        :param logger: python logger object
        :param checkpoint_dir: directory path for storing checkpoints
        """
        self.run_info = run_info
        self.logger = logger
        self.data_loaders = data_loaders
        self.config = config
        self.engine = Engine(self._step)
        self.model = model
        self.tb_writer = tb_writer
        self.trainer = None

        # Using a custom metric wrapper which retrieves metric values from the output dictionary instead of
        # calculating them separately.
        self.metrics = {k: LossFromDict(k) for k in self.model.metric_names}
        self.non_scalar_metrics = {
            k: LossFromDict(k, reduce=False)
            for k in self.model.non_scalar_metrics_names
        }

        if 'external_metrics' in config['val_data']:
            for idx, name in enumerate(config['val_data']['external_metrics']):
                if 'external_metrics_kw_args' in config['val_data']:
                    self.metrics[name] = get_subclass(name, Metric)(
                        config['devices'][0],
                        **config['val_data']['external_metrics_kw_args'][idx])
                else:
                    self.metrics[name] = get_subclass(name, Metric)()

        self._handle_save_best_checkpoint_handler = \
            ModelCheckpoint(checkpoint_dir, 'best',
                            score_function=lambda engine: -self.model.main_metric(engine.state.metrics),
                            score_name=self.model.name_main_metric,
                            n_saved=1,
                            require_empty=False)

        self.add_handler()
        self.best_loss = None
        self.current_data_loader = None
        self.main_data_loader = config['val_data']['main_dataset']
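
# A minimal sketch of the LossFromDict wrapper used above (hypothetical, the real class
# may differ): instead of computing a metric itself, it reads an already computed value
# from the dictionary returned by the engine's step function and averages it over the
# validation run. `reduce=False` is assumed to collect the raw per-batch values instead.
from ignite.metrics import Metric


class LossFromDict(Metric):

    def __init__(self, metric_name, reduce=True):
        self.metric_name = metric_name
        self.reduce = reduce
        self._values = []
        super(LossFromDict, self).__init__()

    def reset(self):
        self._values = []

    def update(self, output):
        # `output` is assumed to be the metrics dictionary returned by Evaluator._step
        value = output[self.metric_name]
        self._values.append(float(value) if self.reduce else value)

    def compute(self):
        if self.reduce:
            return sum(self._values) / max(len(self._values), 1)
        return self._values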
    def __init__(self,
                 ds_names: tuple,
                 sample_keys: tuple,
                 aug_names=None,
                 n_sub_samples=None,
                 ds_kw_args=None,
                 **kwargs):
        """
        In the initialization of the Collection all datasets that are part of the collection are constructed.

        :param ds_names: (tuple of strings) names of the datasets
        :param sample_keys: (tuple of strings) names of the keys used to return the batch,
        order corresponds to the names tuple
        :param aug_names: (tuple of tuples of strings) sequence of augmentation class names for the respective dataset,
        order corresponds to the names tuple
        :param n_sub_samples: (tuple of ints or int) number of samples to subsample per dataset, e.g. for taking
        a subset during debugging; a single value is extended to all datasets
        :param ds_kw_args: (tuple of dicts or dict) additional keyword arguments passed to the constructor of the
        respective dataset, order corresponds to the names tuple
        # TODO should we add kwargs here and how should we pass them to the child classes?
        :param kwargs: additional keyword arguments (currently not forwarded, see TODO above)
        """
        super(Collection, self).__init__()

        self.ds_names = ds_names
        self.sample_keys = sample_keys

        self.ds = []

        n_ds = len(ds_names)  # number of datasets part of the collection

        if aug_names is None:
            aug_names = n_ds * (None, )

        n_sub_samples = self._check_and_extend(n_sub_samples, n_ds)
        self.n_sub_samples = n_sub_samples

        ds_kw_args = self._check_and_extend(ds_kw_args, n_ds)

        assert n_ds == len(sample_keys) == len(aug_names) == len(n_sub_samples)

        for ds_name, n_subsample, aug_sequence, keyword_args in \
                zip(ds_names, n_sub_samples, aug_names, ds_kw_args):
            # TODO discuss the following options: if the argument is not given and is therefore None by default,
            #  we can either overwrite these arguments for the instantiation of each class inside the collection,
            #  or we can use the default argument from the class and thus not pass it along. The last option would
            #  be a little dirtier code-wise but could make sense.
            ds_class = get_subclass(ds_name, Dataset)
            if isinstance(keyword_args, dict):
                ds_inst = ds_class(subsample=n_subsample,
                                   aug_names=aug_sequence,
                                   **keyword_args)
            else:
                ds_inst = ds_class(subsample=n_subsample,
                                   aug_names=aug_sequence)
            self.ds.append(ds_inst)
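
# Hypothetical usage of the Collection defined above. The dataset names, sample keys and
# keyword arguments are placeholders; they only resolve if matching Dataset subclasses
# are registered in the project.
collection = Collection(ds_names=('FashionTrain', 'FashionTest'),
                        sample_keys=('image_train', 'image_test'),
                        aug_names=(('HorizontalFlip',), None),
                        n_sub_samples=100,
                        ds_kw_args=({'train': True}, {'train': False}))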
Example #3
    def get_dataset_params(cls, name) -> dict:
        """
        Returns configurable parameters of a subclass
        :param name: subclass name
        :return: parameters in a dictionary
        """
        subclass = get_subclass(name, cls)
        params = get_parameter_of_cls(subclass)

        return params
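
# A minimal sketch of what the get_parameter_of_cls helper used in these examples might
# do (assumption, the real helper may differ): walk the constructor signature, optionally
# including ancestor classes, and collect the keyword arguments with their default values.
import inspect


def get_parameter_of_cls(cls, ancestors=True):
    params = {}
    classes = inspect.getmro(cls) if ancestors else (cls,)
    for klass in reversed(classes):
        if klass is object:
            continue
        for p_name, param in inspect.signature(klass.__init__).parameters.items():
            if param.default is not inspect.Parameter.empty:
                params[p_name] = param.default
    return params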
Example #4
def update_factory_class(model_name: str, factory_name: str,
                         new_class: Union[str, type]) -> None:
    """
    Updates one factory of a given model with a new class, which can either be a string matching the
    class name or the class type itself.
    :param model_name: name of the model whose factory is updated
    :param factory_name: name of the factory to update (e.g. 'network', 'loss', 'optim')
    :param new_class: str or class
    :return:
    """
    if factory_name not in factories[model_name].keys():
        return

    _factory = factories[model_name][factory_name]
    _factory['cls'] = get_subclass(new_class, _factory['base_class'])

    params = get_parameter_of_cls(_factory['cls'])
    _factory['params'] = params
    _factory['has_args'] = len(get_args(_factory['cls'])) > 0
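
# Hypothetical usage of update_factory_class: replace the optimizer factory of a model
# named 'pix2pix' by 'SGD' before the model gets built. Both names are placeholders and
# only resolve if a matching model and Optimizer subclass are registered.
update_factory_class('pix2pix', 'optim', 'SGD')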
Example #5
def factory(name, base_class, default_class, **kwargs):
    """
    Decorator factory for a factory decorator which allows injecting configurable objects into a model's init function.
    The injected objects are accessible through the kwargs argument under the defined name. In case the factory
    class requires positional arguments for its instantiation, the decorator will return a factory method instead of
    an object. The factory method has all the keyword parameters already included and only requires the positional
    arguments to build the final object. This is for example needed for optimizers since they require model parameters.

    Example:
        @factory('network', Network, ResFCN256)
        @factory('loss', Loss, 'WeightMaskMSE')
        @factory('optim', Optimizer, 'adam', lr=0.0001, weight_decay=0.0)
        def __init__(self, gpu_ids, is_train, **kwargs):
            network = kwargs['network']
            loss = kwargs['loss']
            optim = kwargs['optim']

            # implement some magic

    :param name: name of the factory. Determines the key inside the kwargs dict and also the name in the Sacred config
    :param base_class: Base class of the factory object. For example Network for all the network definitions.
    :param default_class: The default class to use if not overwritten by the config. Can be string or type
    :param kwargs: Additional default parameters for the factory to use if not overwritten by the config
    :return:
    """
    model = inspect.stack()[1].function

    cls = get_subclass(default_class, base_class)
    params = get_parameter_of_cls(cls)
    params.update(**kwargs)
    has_args = len(get_args(cls)) > 0

    factories[model.lower()][name] = {
        'params': params,
        'cls': cls,
        'has_args': has_args,
        'base_class': base_class
    }

    def _func(caller, *args, **kwargs):
        """
        The actual decorator function which gets returned by the decorator factory.
        :param caller: the function which gets wrapped by the decorator
        :param args:
        :param kwargs:
        :return:
        """
        _factory = factories[model.lower()][name]
        try:
            if _factory['has_args']:

                def _cls(*args):
                    return _factory['cls'](*args, **_factory['params'])

                obj = _cls
            else:
                obj = _factory['cls'](**_factory['params'])
        except TypeError as e:
            msg = str(e)
            msg += '\nValid options for class %s are:' % _factory[
                'cls'].__name__
            msg += ''.join([
                '\n\t* %s (default=%s)' % (k, v)
                for k, v in get_parameter_of_cls(_factory['cls']).items()
            ])
            raise type(e)(msg).with_traceback(sys.exc_info()[2])

        kwargs[name] = obj
        return caller(*args, **kwargs)

    return decorator.decorator(_func)
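
# A minimal sketch of the get_subclass helper used throughout these examples
# (assumption, the real implementation may differ): resolve a class by name by searching
# the subclasses of the given base class recursively, or pass it through if it already
# is a type.
def get_subclass(name_or_cls, base_class):
    if isinstance(name_or_cls, type):
        return name_or_cls

    def _iter_subclasses(cls):
        for sub in cls.__subclasses__():
            yield sub
            yield from _iter_subclasses(sub)

    for subclass in _iter_subclasses(base_class):
        if subclass.__name__ == name_or_cls:
            return subclass
    raise ValueError('No subclass named %s found for base class %s' %
                     (name_or_cls, base_class.__name__))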
    def __call__(self, sample):
        """
        Applies all augmentations listed in self.aug_name_lst to the sample, in order.
        """
        for aug_name in self.aug_name_lst:
            aug = get_subclass(aug_name, Augment)()
            sample = aug(sample)
        return sample
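
# Hypothetical Augment subclass to illustrate the composition above: every augmentation
# listed in aug_name_lst gets looked up by name and applied to the sample in sequence.
# 'HorizontalFlip' and the 'image' key are placeholders for what the project defines.
import torch


class HorizontalFlip(Augment):

    def __call__(self, sample):
        sample['image'] = torch.flip(sample['image'], dims=[-1])
        return sample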
    def get_default_params(cls, name):
        """
        Returns the configurable parameters defined directly on a subclass (ancestors excluded).
        """
        subclass = get_subclass(name, cls)
        params = get_parameter_of_cls(subclass, ancestors=False)

        return params
def train(model_name: str, run, logger):
    """
    Main function for starting the training. Sets up the data loader, trainer and evaluator objects
    and starts the training.
    :param model_name: name of the model definition that is defined in the hibashi/models directory
    :param run: sacred run object containing for example the configuration of the experiment
    :param logger: python logger object
    :return: final validation loss
    """
    config = run.config
    print('Printing out configuration:')
    pprint.pprint(config)

    run_meta_dir = get_meta_dir(run)
    checkpoint_dir = os.path.join(run_meta_dir, 'checkpoints')
    tb_logger_dir = run_meta_dir

    run.info.update({"tensorflow": {"logdirs": [tb_logger_dir]}})
    writer = SummaryWriter(tb_logger_dir, filename_suffix='')

    importlib.import_module(f'hibashi.models.{model_name}.data.datasets',
                            __package__)
    importlib.import_module(f'hibashi.models.{model_name}.data.augmentations',
                            __package__)
    importlib.import_module(f'hibashi.models.{model_name}.losses', __package__)

    train_dataset = get_subclass(config['train_data']['name'],
                                 Dataset)(**config['train_data'])

    if 'sampler_n_per_ds' in config['train_data']:
        train_sampler = get_subclass(config['train_data']['sampler'], Sampler)(
            data_source=train_dataset,
            n_per_ds=config['train_data']['sampler_n_per_ds'])
    else:
        train_sampler = get_subclass(config['train_data']['sampler'],
                                     Sampler)(data_source=train_dataset)

    if 'collate_fn' in config['train_data']:
        collate_fn = getattr(collate, config['train_data']['collate_fn'])
    else:
        collate_fn = default_collate

    train_loader = DataLoader(train_dataset,
                              batch_size=config['train_data']['batch_size'],
                              num_workers=config['train_data']['n_workers'],
                              drop_last=config['train_data']['drop_last'],
                              sampler=train_sampler,
                              collate_fn=collate_fn,
                              pin_memory=True)

    val_loaders = {}
    for name in config['val_data']['names']:
        # TODO I would like to make the following a bit nicer
        if name == 'Collection':
            keyword_args = config['val_data'][name]['ds_kwargs']
            val_dataset = get_subclass(name,
                                       Dataset)(**config['val_data'][name],
                                                ds_kw_args=keyword_args)
        else:
            val_dataset = get_subclass(name,
                                       Dataset)(**config['val_data'][name])

        logger.info('Validation dataset {}: {}'.format(name, len(val_dataset)))

        if 'n_val_samples' in config['val_data'][name]:
            val_sampler = get_subclass(
                config['val_data'][name]['sampler'],
                Sampler)(data_source=val_dataset,
                         num_samples=config['val_data'][name]['n_val_samples'],
                         replacement=True)
        else:
            val_sampler = get_subclass(config['val_data'][name]['sampler'],
                                       Sampler)(data_source=val_dataset)

        if 'collate_fn' in config['val_data'][name]:
            collate_fn = getattr(collate,
                                 config['val_data'][name]['collate_fn'])
        else:
            collate_fn = default_collate

        val_loaders[name] = DataLoader(
            val_dataset,
            batch_size=config['val_data'][name]['batch_size'],
            sampler=val_sampler,
            num_workers=config['val_data'][name]['n_workers'],
            drop_last=config['val_data'][name]['drop_last'],
            collate_fn=collate_fn,
            pin_memory=True)

    logger.info('Train data: {}'.format(len(train_dataset)))

    # Build the model
    model = get_subclass(model_name, Model)(config['devices'],
                                            is_train=True,
                                            **config['model'])
    model.print_networks(verbose=True)

    # Trainer / Evaluator
    evaluator = Evaluator(model, config, val_loaders, writer, run, logger,
                          checkpoint_dir)

    trainer = Trainer(model, config, evaluator, train_loader, writer, run,
                      logger, checkpoint_dir)

    evaluator.set_trainer(trainer)

    trainer.run()
    writer.close()

    return evaluator.best_loss
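
# Hypothetical Sacred entry point wiring the train function above into an experiment.
# The experiment name, model name and logger setup are placeholders; the real project
# most likely configures these elsewhere.
import logging

from sacred import Experiment

ex = Experiment('fashion')


@ex.automain
def main(_run):
    logger = logging.getLogger(__name__)
    return train('pix2pix', _run, logger)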
def create_optimizer(name, params, model):
    """
    Builds the optimizer given by name for the parameters of the given model, passing params as keyword arguments.
    """
    opt_cls = get_subclass(name, Optimizer)

    return opt_cls(model.parameters(), **params)
def get_optimizer_params(name):
    """
    Returns the configurable parameters of the optimizer class given by name.
    """
    opt_cls = get_subclass(name, Optimizer)
    params = get_parameter_of_cls(opt_cls)

    return params
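
# Hypothetical usage of the two optimizer helpers above: inspect the configurable
# parameters of an optimizer by name, override some of them and build the optimizer for
# a model. 'Adam' only resolves if a matching Optimizer subclass is registered.
import torch.nn as nn

model = nn.Linear(10, 2)  # stand-in for a real project model
params = get_optimizer_params('Adam')
params.update(lr=1e-4, weight_decay=0.0)
optimizer = create_optimizer('Adam', params, model)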