Example no. 1
    def __call__(self, options, history, model, losses, outputs, datasets, datasets_infos, callbacks_per_batch, **kwargs):
        logger.info('CallbackExportHistory.__call__ started')
        export_root = options['workflow_options']['current_logging_directory']
        sample_root_dir = os.path.join(export_root, self.export_dirname)
        utilities.create_or_recreate_folder(sample_root_dir)

        for dataset_name, dataset in outputs.items():
            split_name, split_outputs = next(iter(dataset.items()))
            for output_name, output in split_outputs.items():
                metrics_names = extract_metrics_name(
                    history,
                    dataset_name,
                    split_name,
                    output_name,
                    self.dicarded_metrics)

                for metric_name in metrics_names:
                    r = extract_from_history(history, dataset_name, output_name, metric_name)
                    if r is None:
                        continue
                    r = merge_history_values([r])
                    analysis_plots.plot_group_histories(
                        sample_root_dir,
                        r,
                        '{}/{}/{}'.format(dataset_name, output_name, metric_name),
                        xlabel='Epochs',
                        ylabel=metric_name)

        logger.info('CallbackExportHistory.__call__ done!')
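
The helpers used above (`extract_metrics_name`, `extract_from_history`, `merge_history_values`) imply a nested `history` layout with one entry per epoch. A hedged illustration of the assumed shape, inferred from the extraction calls rather than shown in this snippet:

# one dict per epoch; nesting assumed from the extraction calls above:
# dataset name -> split name -> output name -> metric name -> value
history = [
    {'mnist': {'train': {'softmax_output': {'loss': 0.9, 'accuracy': 0.70}}}},
    {'mnist': {'train': {'softmax_output': {'loss': 0.5, 'accuracy': 0.85}}}},
]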
Example no. 2
    def first_time(self, options, datasets, model):
        # here we only want to collect the kernels a single time per epoch, so fix the dataset/split names
        if self.dataset_name is None or self.split_name is None:
            self.dataset_name, self.split_name = utilities.find_default_dataset_and_split_names(
                datasets,
                default_dataset_name=self.dataset_name,
                default_split_name=self.split_name)

        if self.dataset_name is None or self.split_name is None:
            logger.error('can\'t find a dataset name or split name!')
            return

        self.kernel_root_path = os.path.join(
            options['workflow_options']['current_logging_directory'],
            self.dirname)
        utilities.create_or_recreate_folder(self.kernel_root_path)

        # find the requested kernels
        kernels = []
        batch = next(iter(datasets[self.dataset_name][self.split_name]))
        for fn in self.find_convolution_fns:
            result = fn(model, batch)
            if result is not None:
                assert 'matched_module' in result, 'must be a dict with key `matched_module`'
                kernel = result['matched_module'].weight
                kernels.append(kernel)
            else:
                logger.error('can\'t find a convolution kernel!')

        logger.info(f'number of convolution kernels found={len(kernels)}')
        self.kernels = kernels
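
Each entry of `self.find_convolution_fns` is expected to take `(model, batch)` and return either `None` or a dict with key `matched_module` whose module has a `weight`. A minimal sketch of such a matcher, assuming a plain `torch.nn` model (this traversal is illustrative, not a library helper):

import torch.nn as nn

def find_first_conv2d(model, batch):
    # walk the module tree and report the first 2D convolution found
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            return {'matched_module': module}
    return None  # triggers the `can't find a convolution kernel!` branch above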
Example no. 3
    def __del__(self):
        if self.output_path is not None:
            list_of_lr_optimizers = {
                'optimizer ' + name: [l]
                for name, l in self.lr_optimizers.items()
            }

            utilities.create_or_recreate_folder(self.output_path)
            analysis_plots.plot_group_histories(
                self.output_path,
                history_values=list_of_lr_optimizers,
                title='Learning rate by epoch',
                xlabel='Epochs',
                ylabel='Learning rate')
Example no. 4
    def __call__(self, options, history, model, losses, outputs, datasets,
                 datasets_infos, callbacks_per_batch, **kwargs):
        if self.dataset_name is None and outputs is not None:
            self.first_time(datasets, outputs)
        if self.output_name is None or self.split_names is None or self.dataset_name is None:
            return

        self.root = os.path.join(
            options['workflow_options']['current_logging_directory'],
            self.dirname)
        if not os.path.exists(self.root):
            utilities.create_or_recreate_folder(self.root)

        dataset_output = outputs.get(self.dataset_name)
        if dataset_output is None:
            return

        self.current_epoch = len(history) - 1
        for split_name in self.split_names:
            split_output = dataset_output.get(split_name)
            if split_output is not None:
                output = split_output.get(self.output_name)
                if output is not None and 'uid' in output:
                    uids = output['uid']
                    output_losses = trw.utils.to_value(output['losses'])
                    assert len(uids) == len(output_losses)
                    for loss, uid in zip(output_losses, uids):
                        # record the epoch as well: with a resampled dataset,
                        # not every sample is selected at every epoch, so the
                        # epoch is needed to display these histories correctly
                        self.errors_by_split[split_name][uid].append(
                            (loss, self.current_epoch))

        last_epoch = kwargs.get('last_epoch')
        if last_epoch:
            self.export_stats(model=model,
                              losses=losses,
                              datasets=datasets,
                              datasets_infos=datasets_infos,
                              options=options,
                              callbacks_per_batch=callbacks_per_batch)
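
`self.errors_by_split` is indexed as `[split_name][uid]` and appended to without explicit key setup, which suggests a nested `defaultdict`. A hedged sketch of the assumed initialization (the actual constructor is not shown in this snippet):

import collections

# assumed initialization for the callback's constructor:
# split name -> sample uid -> list of (loss, epoch) tuples
errors_by_split = collections.defaultdict(
    lambda: collections.defaultdict(list))

# the append in the callback then needs no explicit key setup:
errors_by_split['train']['uid_0'].append((0.42, 3))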
Example no. 5
    def __call__(self, options, history, model, losses, outputs, datasets, datasets_infos, callbacks_per_batch, **kwargs):
        root = options['workflow_options']['current_logging_directory']
        logger.info('root={}'.format(root))
        logger_tb = callback_tensorboard.CallbackTensorboardBased.create_logger(root)
        if logger_tb is None:
            return

        if self.dataset_name is None:
            self.dataset_name = next(iter(datasets))

        if self.split_name is None:
            self.split_name = next(iter(datasets[self.dataset_name]))

        device = options['workflow_options']['device']

        # ONNX export MUST be in eval mode!
        model.eval()

        batch = next(iter(datasets[self.dataset_name][self.split_name]))
        batch = utilities.transfer_batch_to_device(batch, device=device)
        trw.train.utilities.postprocess_batch(self.dataset_name, self.split_name, batch, callbacks_per_batch)

        class NoDictModel(torch.nn.Module):
            # cache the model input, as ONNX doesn't handle dict-like
            # inputs or outputs
            def __init__(self, model, batch):
                super().__init__()
                self.model = model
                self.batch = batch

            def __call__(self, *args, **kwargs):
                with torch.no_grad():
                    r = self.model(self.batch)
                outputs = [o.output for name, o in r.items()]
                return outputs

        # At the moment, ONNX export support for models returning dictionaries
        # is spotty. There are two options:
        # 1) export to ONNX format and use `logger_tb.add_onnx_graph` to export the graph.
        #    This requires the additional `onnx` dependency (which should be installed via conda)
        #    and only works with the latest PyTorch.
        # 2) export directly using `logger_tb.add_graph(NoDictModel(model, batch), input_to_model=torch.Tensor())`,
        #    but this fails in the dev build.
        # Option 2 is preferable but doesn't work right now.

        try:
            # option 1)
            root_onnx = os.path.join(root, self.onnx_folder)
            utilities.create_or_recreate_folder(root_onnx)
            onnx_filepath = os.path.join(root_onnx, 'model.onnx')
            logger.info('exporting ONNX model to `{}`'.format(onnx_filepath))
            with utilities.CleanAddedHooks(model):  # make sure no extra hooks are kept
                with open(onnx_filepath, 'wb') as f:
                    torch.onnx.export(NoDictModel(model, batch), torch.Tensor(), f)  # fake input: the real input is already bound in the wrapper!
                logger_tb.add_onnx_graph(onnx_filepath)
                # without the hook cleanup, an assert is triggered here; the root cause is unclear

                # option 2)
                #logger_tb.add_graph(NoDictModel(model, batch), input_to_model=torch.Tensor())
            logger.info('successfully exported!')
        except Exception as e:
            logger.error('ONNX export failed! Exception={}'.format(e))
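
Option 1) above, stripped of the library plumbing, reduces to a plain `torch.onnx.export` call followed by `add_onnx_graph`. A minimal self-contained sketch; the export itself only needs PyTorch, while the commented TensorBoard line additionally requires the `onnx` package:

import torch

# a plain model in eval mode (ONNX export must not be in train mode)
model = torch.nn.Linear(4, 2).eval()
dummy_input = torch.randn(1, 4)
torch.onnx.export(model, dummy_input, 'model.onnx')
# logger_tb.add_onnx_graph('model.onnx')  # TensorBoard side, as in the callback above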
Example no. 6
def analyse_hyperparameters(hparams_path_pattern,
                            output_path,
                            hparams_to_visualize=None,
                            params_forest_n_estimators=5000,
                            params_forest_max_features_ratio=0.6,
                            top_k_covariance=5,
                            create_graphs=True,
                            verbose=True,
                            dpi=300):
    """
    Importance hyper-pramaeter estimation using random forest regressors

    From simulation, the ordering of hyper-parameters importance is correct, but the importance value itself may be
    over-estimated (for the best param) and underestimated (for the others).

    The scatter plot for each hparam is useful to understand in what direction the hyper-parameter should be modified

    The covariance plot can be used to understand the relation between most important hyper-parameter

    WARNING:
    [1] With correlated features, strong features can end up with low scores and the method can be biased towards
    variables with many categories. See for more details:
    see http://blog.datadive.net/selecting-good-features-part-iii-random-forests/
    and https://link.springer.com/article/10.1186%2F1471-2105-8-25

    :param params_forest_n_estimators: number of trees used to estimate the loss from the hyperparameters
    :param params_forest_max_features_ratio: the maximum number of features to be used. Note we don't want to
           select all the features to limit the correlation importance decrease effect [1]
    :param hprams_path_pattern: a pattern (globing) to be used to select the hyper parameter files
    :param hparams_to_visualize: a list of hparam names to visualize or `None`. If `None`, display from the most
              important (i.e., causing the most loss variation) to the least
    :param create_graphs: if True, export matplotlib visualizations
    :param top_k_covariance: export the parameter covariance for the most important k hyper-parameters
    :param output_path: where to export the graph
    :param dpi: the resolution of the exported graph
    :param verbose: if True, display additional information
    :return:
    """
    files = glob.glob(hparams_path_pattern)
    files = [file for file in files if os.path.isfile(file)]
    if verbose:
        print('Hyper parameter files matching pattern=\n',
              str(files).replace(',', '\n'))

    if len(files) == 0:
        return

    data = []
    for file in files:
        loss, _, params = params_optimizer_random_search.load_loss_params(file)

        params_current = {}
        if hparams_to_visualize is None:
            for key, value in params.hparams.items():
                params_current[key] = value.current_value
        else:
            for key in hparams_to_visualize:
                value = params.hparams.get(key)
                if value is not None:
                    params_current[key] = value.current_value

        params_current['loss'] = loss
        data.append(params_current)

    f = pd.DataFrame(data)
    param_names = [name for name in f.columns if name != 'loss']

    # hyper-parameter importance: fit an extra-trees regressor to predict the loss
    # value from the hyper-parameter values, then read off the feature importances
    #Estimator = RandomForestRegressor
    Estimator = ExtraTreesRegressor
    params_forest = {
        'n_estimators': params_forest_n_estimators,
        'max_features': max(1, int(params_forest_max_features_ratio * len(param_names))),
        #'bootstrap': True,
    }
    forest = Estimator(**params_forest)

    values = f[param_names].values
    # discretize handles the special case of strings as values
    values, mapping = discretize(values)
    forest.fit(X=values, y=f['loss'].values)
    importances = forest.feature_importances_
    std = np.std([tree.feature_importances_ for tree in forest.estimators_],
                 axis=0)
    indices = np.argsort(importances)[::-1]

    sorted_param_names = np.asarray(param_names)[indices]
    sorted_importances = importances[indices]

    if create_graphs:
        if verbose:
            print('output_path=%s' % output_path)

        utilities.create_or_recreate_folder(output_path)

        fig = _plot_importance(plot_name='hyper-parameter importance',
                               x_names=sorted_param_names,
                               y_values=sorted_importances,
                               y_name='importance',
                               y_errors=std)
        fig.savefig(os.path.join(output_path,
                                 'hyper-parameter importance.png'),
                    dpi=dpi)

        # now we know what parameter is important, but we need to understand in what direction this is important
        # (e.g., beneficial or detrimental values?)
        for h in indices:
            param_name = param_names[h]
            fig = _plot_scatter(plot_name='[%s] variations' % param_name,
                                y_values=f['loss'].values,
                                y_name='loss',
                                x_name=param_name,
                                x_values=values[:, h],
                                x_ticks=mapping.get(h))
            fig.savefig(os.path.join(output_path, param_name + '.png'),
                        dpi=dpi)

        # finally, we want to look at the hyper parameter covariances
        best_param_names = sorted_param_names[:top_k_covariance]
        for y in range(0, len(param_names)):
            for x in range(y + 1, len(param_names)):
                feature_1 = param_names[y]
                feature_2 = param_names[x]
                if feature_1 in best_param_names and feature_2 in best_param_names:
                    fig = _plot_param_covariance(
                        plot_name='[%s, %s] variations' %
                        (feature_1, feature_2),
                        y_values=values[:, y],
                        y_name=feature_1,
                        x_values=values[:, x],
                        x_name=feature_2,
                        xy_values=f['loss'])
                    fig.savefig(os.path.join(
                        output_path,
                        'covariance_' + feature_1 + '_' + feature_2 + '.png'),
                                dpi=dpi)

    return {
        'sorted_param_names': sorted_param_names,
        'sorted_importances': sorted_importances
    }
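
A minimal usage sketch; the glob pattern and output directory below are placeholders, and the input files are assumed to be the pickled results produced by the random-search optimizer mentioned above:

results = analyse_hyperparameters(
    hparams_path_pattern='runs/hparams/*.pkl',  # placeholder pattern
    output_path='runs/hparams_analysis',        # placeholder directory
    top_k_covariance=3)
if results is not None:
    print(results['sorted_param_names'][:3])  # the 3 most influential hparams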
Example no. 7
    def __call__(self, options, history, model, losses, outputs, datasets,
                 datasets_infos, callbacks_per_batch, **kwargs):
        """
        .. note:: The model will be deep copied so that we don't influence the training

        Args:
            **kwargs: required `optimizers_fn`
        """
        logger.info('started CallbackLearningRateFinder.__call__')
        device = options['workflow_options']['device']

        output_path = os.path.join(
            options['workflow_options']['current_logging_directory'],
            self.dirname)
        utilities.create_or_recreate_folder(output_path)

        if self.dataset_name is None:
            self.dataset_name = next(iter(datasets))

        if self.split_name is None:
            self.split_name = options['workflow_options']['train_split']

        logger.info(
            'dataset={}, split={}, nb_samples={}, learning_rate_start={}, learning_rate_stop={}'
            .format(self.dataset_name, self.split_name,
                    self.nb_samples_per_learning_rate,
                    self.learning_rate_start, self.learning_rate_stop))

        callback_stop_epoch = CallbackStopEpoch(
            nb_samples=self.nb_samples_per_learning_rate)
        if callbacks_per_batch is not None:
            callbacks_per_batch = copy.copy(
                callbacks_per_batch)  # make sure these are only local changes!
            callbacks_per_batch.append(callback_stop_epoch)
        else:
            callbacks_per_batch = [callback_stop_epoch]

        optimizers_fn = kwargs.get('optimizers_fn')
        assert optimizers_fn is not None, '`optimizers_fn` can\'t be None!'
        split = datasets[self.dataset_name][self.split_name]

        lr_loss_list = []
        learning_rate = self.learning_rate_start
        model_copy = copy.deepcopy(model)
        while learning_rate < self.learning_rate_stop:
            # we do NOT want to modify the original model or optimizer, so work on a copy.
            # Note the copy is made once, before the loop; uncomment the line below to
            # restart from the original model at every learning rate and isolate its effect better
            #model_copy = copy.deepcopy(model)
            optimizers, _ = optimizers_fn(datasets, model_copy)
            optimizer = optimizers.get(self.dataset_name)
            assert optimizer is not None, 'optimizer can\'t be found for dataset={}'.format(
                self.dataset_name)
            utilities.set_optimizer_learning_rate(optimizer, learning_rate)

            callback_stop_epoch.reset()
            all_loss_terms = trainer.train_loop(
                device,
                self.dataset_name,
                self.split_name,
                split,
                optimizer,
                model_copy,
                losses[self.dataset_name],
                history=None,
                callbacks_per_batch=callbacks_per_batch,
                callbacks_per_batch_loss_terms=None)

            loss = 0.0
            for loss_batch in all_loss_terms:
                current_loss = trw.utils.to_value(
                    loss_batch['overall_loss']['loss'])
                loss += current_loss
            loss /= len(all_loss_terms)

            lr_loss_list.append((learning_rate, loss))
            learning_rate *= self.learning_rate_mul

        lines_x = np.asarray(lr_loss_list)[:, 0]
        lines_y = np.asarray(lr_loss_list)[:, 1]

        # find the relevant section of LR
        logger.debug('Raw (LR, loss) by epoch:\n{}'.format(
            list(zip(lines_x, lines_y))))
        lines_x, lines_y = self.identify_learning_rate_section(
            lines_x,
            lines_y,
            loss_ratio_to_discard=self.param_maximum_loss_ratio)
        logger.debug('Interesting LR section (LR, loss) by epoch:\n{}'.format(
            list(zip(lines_x, lines_y))))

        # select the LR with the smallest loss
        min_index = np.argmin(lines_y)
        best_learning_rate = lines_x[min_index]
        best_loss = lines_y[min_index]

        # finally, export the figure
        y_range = (
            np.min(lines_y) * 0.9, lines_y[0] * 1.1
        )  # there is no point displaying losses worse than no training at all!
        plot_trend(export_path=output_path,
                   lines_x=lines_x,
                   lines_y=lines_y,
                   title='Learning rate finder ({}, {})'.format(
                       self.dataset_name, self.split_name),
                   xlabel='learning rate',
                   ylabel='overall_loss',
                   y_range=y_range,
                   x_scale='log',
                   y_scale='log',
                   name_xy_markers={
                       'LR={}'.format(best_learning_rate):
                       (best_learning_rate, best_loss)
                   })

        logger.info('best_learning_rate={}'.format(best_learning_rate))
        logger.info('best_loss={}'.format(best_loss))
        print('best_learning_rate=', best_learning_rate)
        print('best_loss=', best_loss)

        if self.set_new_learning_rate:
            best_learning_rate *= self.learning_rate_final_multiplier
            optimizers = kwargs.get('optimizers')
            if optimizers is not None:
                for optimizer_name, optimizer in optimizers.items():
                    logger.info(
                        'optimizer={}, changed learning rate={}'.format(
                            optimizer_name, best_learning_rate))
                    utilities.set_optimizer_learning_rate(
                        optimizer, best_learning_rate)
            else:
                logger.warning('No optimizers available in `kwargs`')

        logger.info(
            'successfully finished CallbackLearningRateFinder.__call__')
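
At its core the finder is a geometric sweep: multiply the learning rate by a constant factor at each step and record the resulting loss. A self-contained sketch of just that loop, where `train_one_chunk(lr)` is a hypothetical stand-in for the `train_loop` call above and returns the average loss:

import numpy as np

def sweep_learning_rates(train_one_chunk, lr_start=1e-6, lr_stop=1.0, lr_mul=1.5):
    # geometric sweep: lr_start, lr_start * lr_mul, lr_start * lr_mul**2, ...
    lr_loss = []
    lr = lr_start
    while lr < lr_stop:
        lr_loss.append((lr, train_one_chunk(lr)))
        lr *= lr_mul
    lrs, losses = np.asarray(lr_loss).T
    return lrs[np.argmin(losses)]  # the LR with the smallest observed loss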
Example no. 8
    def __call__(self, options, history, model, losses, outputs, datasets,
                 datasets_infos, callbacks_per_batch, **kwargs):

        logger.info('started CallbackExportSamples.__call__')
        device = options['workflow_options']['device']

        if not self.reporting_config_exported:
            # export how the samples should be displayed by the reporting
            config_path = options['workflow_options']['sql_database_view_path']
            update_json_config(config_path, {
                self.table_name: {
                    'data': {
                        'keep_last_n_rows': self.reporting_config_keep_last_n_rows,
                        'subsampling_factor': self.reporting_config_subsampling_factor,
                    },
                    'default': {
                        'Scatter X Axis': self.reporting_scatter_x,
                        'Scatter Y Axis': self.reporting_scatter_y,
                        'Color by': self.reporting_color_by,
                        'Display with': self.reporting_display_with,
                        'Binning X Axis': self.reporting_binning_x_axis,
                        'Binning selection': self.reporting_binning_selection,
                    }
                }
            })
            self.reporting_config_exported = True

        sql_database = options['workflow_options']['sql_database']
        if self.clear_previously_exported_samples:
            cursor = sql_database.cursor()
            table_truncate(cursor, self.table_name)
            sql_database.commit()

            # also remove the binary/image store
            root = os.path.dirname(
                options['workflow_options']['sql_database_path'])
            create_or_recreate_folder(
                os.path.join(root, 'static', self.table_name))

        sql_table = reporting.TableStream(cursor=sql_database.cursor(),
                                          table_name=self.table_name,
                                          table_role='data_samples')

        logger.info(f'export started..., N={self.max_samples}')
        for dataset_name, dataset in datasets.items():
            root = os.path.join(
                options['workflow_options']['current_logging_directory'],
                'static', self.table_name)
            if not os.path.exists(root):
                utilities.create_or_recreate_folder(root)

            for split_name, split in dataset.items():
                exported_cases = []
                trainer.eval_loop(
                    device,
                    dataset_name,
                    split_name,
                    split,
                    model,
                    losses[dataset_name],
                    history=None,
                    callbacks_per_batch=callbacks_per_batch,
                    callbacks_per_batch_loss_terms=[
                        functools.partial(
                            callbacks_per_loss_term,
                            root=options['workflow_options']
                            ['current_logging_directory'],
                            datasets_infos=datasets_infos,
                            loss_terms_inclusion=self.loss_terms_inclusion,
                            feature_exclusions=self.feature_exclusions,
                            dataset_exclusions=self.dataset_exclusions,
                            split_exclusions=self.split_exclusions,
                            exported_cases=exported_cases,
                            max_samples=self.max_samples,
                            epoch=len(history),
                            sql_table=sql_table,
                            format=self.format,
                            select_fn=self.select_sample_to_export)
                    ])

        sql_database.commit()
        logger.info('successfully completed CallbackExportSamples.__call__!')
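
`update_json_config` is used here as a read-merge-write helper for the reporting view configuration; a plausible minimal version, assuming shallow merge semantics (the library's actual implementation may differ):

import json
import os

def update_json_config_sketch(config_path, updates):
    # read the existing view config if any, merge the updates, write it back
    config = {}
    if os.path.exists(config_path):
        with open(config_path, 'r') as f:
            config = json.load(f)
    config.update(updates)
    with open(config_path, 'w') as f:
        json.dump(config, f, indent=2)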
Example no. 9
    def fit(self,
            options,
            inputs_fn,
            model_fn,
            optimizers_fn,
            losses_fn=default_sum_all_losses,
            loss_creator=create_losses_fn,
            run_prefix='default',
            with_final_evaluation=True,
            eval_every_X_epoch=1):
        """
        Fit the model

        Requirements:

        * enough main memory to store the outputs of all the datasets of a single epoch.
            If this cannot be satisfied, sub-sample the epoch so that it can fit in main memory.
        
        Notes:

        * if a feature value is Callable, its value will be replaced by the result of the call
            (e.g., this can be useful to generate `z` embedding in GANs)

        :param options:
        :param inputs_fn: a functor returning a dictionary of datasets. Alternatively, datasets infos can be specified.
                        `inputs_fn` must return one of:

                        * datasets: dictionary of dataset
                        * (datasets, datasets_infos): dictionary of dataset and additional infos
                        
                        We define:

                        * datasets: a dictionary of dataset; a dataset is a dictionary of splits; a split is a dictionary of batched features
                        * datasets infos: additional infos useful for debugging the dataset (e.g., class mappings, sample UIDs).
                          Datasets infos are typically much smaller than the datasets and should be loadable in memory

        :param model_fn: a functor with parameter `options` and returning a `Module` or a `ModuleDict`
        
        Depending on the type of the model, this is how it will be used:

        * `Module`: optimizer will optimize `model.parameters()`
        * `ModuleDict`: for each dataset name, the optimizer will optimize
            `model[dataset_name].parameters()`. Note that a `forward` method will need to be implemented

        :param losses_fn: the function specifying how the losses of the outputs are combined (defaults to summing all losses)
        :param optimizers_fn: a functor with parameters `datasets, model` returning a tuple `(optimizers, schedulers)` keyed by dataset name
        :param loss_creator: a functor with parameters `datasets, losses_fn` instantiating the loss of each dataset
        :param eval_every_X_epoch: evaluate the model every `X` epochs
        :param run_prefix: the prefix of the output folder
        :param with_final_evaluation: if True, once the model is fitted, evaluate all the data again in eval mode
        :return: a tuple `model, result`
        """
        # set up our log path. This is where all the analysis of the model will be exported
        log_path = os.path.join(
            options['workflow_options']['logging_directory'], run_prefix +
            '_r{}'.format(options['workflow_options']['trainer_run']))
        options['workflow_options']['current_logging_directory'] = log_path

        # now clear our log path to remove previous files if needed
        utilities.create_or_recreate_folder(log_path)

        if len(logging.root.handlers) == 0:
            # there is no logger configured, so add a basic one
            logging.basicConfig(
                filename=os.path.join(
                    options['workflow_options']['logging_directory'],
                    'logging.txt'),
                format='%(asctime)s %(levelname)s %(name)s %(message)s',
                level=logging.DEBUG,
                filemode='w')

        # create the reporting SQL database
        sql_path = os.path.join(
            options['workflow_options']['current_logging_directory'],
            'reporting_sqlite.db')
        sql = sqlite3.connect(sql_path)
        options['workflow_options']['sql_database'] = sql
        options['workflow_options']['sql_database_path'] = sql_path
        options['workflow_options'][
            'sql_database_view_path'] = sql_path.replace('.db', '.json')

        # here we want to have our logging per training run, so add a handler
        handler = logging.FileHandler(os.path.join(log_path, 'trainer.txt'))
        formatter = utilities.RuntimeFormatter(
            '%(asctime)s %(levelname)s %(name)s %(message)s')
        handler.setFormatter(formatter)
        logging.root.addHandler(handler)

        # instantiate the datasets, model, optimizers and losses
        logger.info('started Trainer.fit(). Options={}'.format(options))

        datasets_infos = None
        logger.info('creating datasets...')
        datasets = inputs_fn()
        logger.info('datasets created successfully!')
        assert datasets is not None, '`datasets` is None!'
        if isinstance(datasets, tuple):
            if len(datasets) == 2:
                logger.info('inputs_fn specified `datasets, datasets_infos`')
                datasets, datasets_infos = datasets
            else:
                assert 0, 'expected tuple `datasets` or `datasets, datasets_infos`'

        logger.info('creating model...')
        model = model_fn(options)
        logger.info('model created successfully!')

        if isinstance(model, torch.nn.ModuleDict):
            # if we have sub-models, we MUST define a `forward` method
            # to orchestrate the calls of sub-models
            assert 'forward' in dir(model)

        # migrate the model to the specified device
        device = options['workflow_options']['device']

        model.to(device)
        logger.info('model moved to device={}'.format(device))

        # instantiate the optimizer and scheduler
        logger.info('creating optimizers...')
        if optimizers_fn is not None:
            optimizers, schedulers = optimizers_fn(datasets, model)
            logger.info('optimizers created successfully!')
        else:
            logger.info('optimizer fn is None! No optimizer created.')
            optimizers, schedulers = None, None

        logger.info('creating losses...')
        losses = loss_creator(datasets, losses_fn)
        logger.info('losses created successfully!')

        num_epochs = options['training_parameters']['num_epochs']

        if isinstance(optimizers, tuple):
            assert len(optimizers) == 2, 'expected tuple(optimizer, scheduler)'
            optimizers, schedulers = optimizers

        history = []

        logger.info('creating callbacks...')
        if self.callbacks_per_epoch_fn is not None:
            callbacks_per_epoch = self.callbacks_per_epoch_fn()
        else:
            callbacks_per_epoch = []

        callbacks_per_batch = []
        if self.trainer_callbacks_per_batch is not None:
            callbacks_per_batch.append(self.trainer_callbacks_per_batch)
        if self.callbacks_per_batch_fn is not None:
            callbacks_per_batch += self.callbacks_per_batch_fn()

        callbacks_per_batch_loss_terms = []
        if self.callbacks_per_batch_loss_terms_fn is not None:
            callbacks_per_batch_loss_terms += self.callbacks_per_batch_loss_terms_fn(
            )
        logger.info('callbacks created successfully!')

        # run the callbacks before training
        if self.callbacks_pre_training_fn is not None:
            logger.info('running pre-training callbacks...')
            callbacks = self.callbacks_pre_training_fn()
            for callback in callbacks:
                callback(options,
                         history,
                         model,
                         losses=losses,
                         outputs=None,
                         datasets=datasets,
                         datasets_infos=datasets_infos,
                         callbacks_per_batch=callbacks_per_batch,
                         optimizers_fn=optimizers_fn,
                         optimizers=optimizers)
            logger.info('pre-training callbacks completed!')

        for epoch in range(num_epochs):
            logger.info('started training epoch {}'.format(epoch))
            run_eval = epoch == 0 or (epoch + 1) % eval_every_X_epoch == 0

            outputs_epoch, history_epoch = self.run_epoch_fn(
                options,
                datasets,
                optimizers,
                model,
                losses,
                schedulers,
                history,
                callbacks_per_batch,
                callbacks_per_batch_loss_terms,
                run_eval=run_eval,
                force_eval_mode=False)
            history.append(history_epoch)

            logger.info('finished training epoch {}'.format(epoch))

            last_epoch = epoch + 1 == num_epochs

            logger.info('callbacks started')
            for callback in callbacks_per_epoch:
                callback(options,
                         history,
                         model,
                         losses=losses,
                         outputs=outputs_epoch,
                         datasets=datasets,
                         datasets_infos=datasets_infos,
                         callbacks_per_batch=callbacks_per_batch,
                         optimizers_fn=optimizers_fn,
                         optimizers=optimizers,
                         last_epoch=last_epoch)

            logger.info('callbacks epoch {} finished'.format(epoch))

        # finally run the post-training callbacks
        outputs_epoch = None
        if with_final_evaluation:
            logger.info('started final evaluation...')
            outputs_epoch, history_epoch = self.run_epoch_fn(
                options,
                datasets,
                None,
                model,
                losses,
                None,
                history,
                callbacks_per_batch,
                callbacks_per_batch_loss_terms,
                run_eval=True,
                force_eval_mode=True)
            logger.info('finished final evaluation...')
            history.append(history_epoch)

        if self.callbacks_post_training_fn is not None:
            logger.info('started post training callbacks...')
            callbacks_post_training = self.callbacks_post_training_fn()
            for callback in callbacks_post_training:
                callback(options,
                         history,
                         model,
                         losses=losses,
                         outputs=outputs_epoch,
                         datasets=datasets,
                         datasets_infos=datasets_infos,
                         callbacks_per_batch=callbacks_per_batch,
                         optimizers_fn=optimizers_fn)

            logger.info('finished post training callbacks...')

            del callbacks_post_training
            logger.info('deleted post training callbacks!')

        # increment the number of runs
        options['workflow_options']['trainer_run'] += 1

        logger.info('removing logging handlers...')
        logging.root.removeHandler(handler)

        logger.info('training completed!')

        sql.commit()
        sql.close()

        return model, {
            'history': history,
            'options': options,
            'outputs': outputs_epoch,
            'datasets_infos': datasets_infos
        }
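
A minimal usage sketch with hypothetical factories; `train_split`, `valid_split`, `MyModel`, `options` and the `trainer` instance are placeholders, and only the return contracts visible in this snippet are relied upon:

import torch

def inputs_fn():
    # datasets: dataset name -> split name -> iterable of batches
    return {'mnist': {'train': train_split, 'valid': valid_split}}  # placeholders

def model_fn(options):
    return MyModel()  # placeholder torch.nn.Module

def optimizers_fn(datasets, model):
    # must return (optimizers, schedulers), each keyed by dataset name
    return {'mnist': torch.optim.SGD(model.parameters(), lr=0.1)}, None

model, results = trainer.fit(
    options,
    inputs_fn=inputs_fn,
    model_fn=model_fn,
    optimizers_fn=optimizers_fn,
    run_prefix='baseline')
print(len(results['history']), 'epochs recorded')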