Example #1
    def to_memory(self, var_names, phase='train'):
        """Move data from storage to memory.

        This method is valid only for the hybrid backend. Moving data to
        memory can be effective in reducing data I/O overhead.

        Args:
            var_names (str or list): see ``add_data()`` method.
            phase (str): *all*, *train*, *valid*, *test*.
        """
        if self._backend != 'hybrid':
            logger.warn(
                f'to_memory is valid only for the hybrid backend ({self._backend})'
            )
            return

        self._check_valid_data_id()

        if isinstance(var_names, str):
            var_names = [var_names]

        if phase == 'all':
            phases = const.PHASES
        else:
            phases = [phase]

        for iphase in phases:
            for var_name in var_names:
                self._db.to_memory(self._data_id, var_name, iphase)
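
A minimal usage sketch; the object ``store`` and the variable names are hypothetical, assuming a data store created with the hybrid backend and already populated via ``add_data()``:

    # Move one variable of the train phase into memory.
    store.to_memory('features', phase='train')
    # Move several variables for all phases at once.
    store.to_memory(['features', 'labels'], phase='all')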
Example #2
    def to_storage(self, var_names, phase='train'):
        """Move data from storage to memory.

        This method is valid only for the hybrid backend. This is useful when
        data are large and need to be offloaded to storage.

        Args:
            var_names (str or list): see ``add_data()`` method.
            phase (str): *all*, *train*, *valid*, *test*.
        """
        if self._backend != 'hybrid':
            logger.warn(
                f'to_storage is valid only for the hybrid backend ({self._backend})'
            )
            return

        self._check_valid_data_id()

        if isinstance(var_names, str):
            var_names = [var_names]

        if phase == 'all':
            phases = const.PHASES
        else:
            phases = [phase]

        for iphase in phases:
            for var_name in var_names:
                self._db.to_storage(self._data_id, var_name, iphase)
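
A complementary sketch with the same hypothetical ``store``, offloading variables that are no longer needed in memory:

    # Escape large train-phase data back to storage.
    store.to_storage(['features', 'labels'], phase='train')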
Example #3
    def __init__(self,
                 subtasks,
                 use_multi_loss=False,
                 variable_mapping=None,
                 **kwargs):
        """ Constructor of ModelConnectionTask.

        Args:
            subtasks (list): ordered list of instance objects inherited
                from ``MLBaseTask``.
            use_multi_loss (bool): If False, intermediate losses are not
                considered in training steps.
            variable_mapping (list(str, str)): Input variables are renamed
                following this list. Used when the input variables change
                from pre-training to main-training (with model connection).
            **kwargs: Arbitrary keyword arguments passed to ``MLBaseTask``.
        """
        super().__init__(**kwargs)

        self._subtasks = subtasks
        self._use_multi_loss = use_multi_loss
        self._variable_mapping = variable_mapping

        self._input_var_index = None
        self._output_var_index = None

        if self._input_var_names is not None:
            logger.warn(
                'input_var_names is given but will be set automatically')
            self._input_var_names = None

        if self._output_var_names is not None:
            logger.warn(
                'output_var_names is given but will be set automatically')
            self._output_var_names = None
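
A minimal construction sketch; ``task1``, ``task2`` and the variable names are hypothetical, assumed to be instances of classes inherited from ``MLBaseTask``:

    task = ModelConnectionTask(
        subtasks=[task1, task2],                 # executed in this order
        use_multi_loss=True,                     # keep intermediate losses
        variable_mapping=[('x_pre', 'x_main')],  # rename inputs after pre-training
    )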
Example #4
def compile(obj, obj_args, modules):
    """Build an object from ``obj`` depending on its type.

    If ``obj`` is a str, the class of that name is looked up in ``modules``
    and instantiated with ``obj_args``. If ``obj`` is a class, it is
    instantiated directly. Otherwise, ``obj`` is assumed to be an instance
    and a shallow copy of it is returned.
    """
    # str object
    if isinstance(obj, str):
        return getattr(modules, obj)(**obj_args)

    # class object
    elif inspect.isclass(obj):
        return obj(**obj_args)

    # instance object
    else:
        if obj_args != {}:
            logger.warn('an instance object is given, but obj_args is also provided')
        return copy.copy(obj)
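
A sketch of the three dispatch paths; ``MyModel``, ``my_instance`` and ``mymodules`` are hypothetical:

    model_a = compile('MyModel', {'units': 128}, mymodules)  # str: looked up in mymodules
    model_b = compile(MyModel, {'units': 128}, None)         # class: instantiated directly
    model_c = compile(my_instance, {}, None)                 # instance: shallow copy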
Example #5
    def set_mode(self, mode):
        """Set backend mode of hybrid architecture.

        This method is valid only for the *hybrid* backend. If ``mode`` is
        *numpy*, data are written to memory; if ``mode`` is *zarr*, data are
        written to storage.

        Args:
            mode (str): *numpy* or *zarr*.
        """
        if self._backend != 'hybrid':
            logger.warn(
                f'set_mode is valid only for the hybrid backend ({self._backend})'
            )

        else:
            self._db.mode = mode
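
A usage sketch with a hypothetical hybrid-backend ``store``; switching the mode changes where subsequent writes go:

    store.set_mode('numpy')  # subsequent data are written to memory
    store.set_mode('zarr')   # subsequent data are written to storage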
Example #6
    def calculate(self):
        """ Calculate AUC.
        """
        y_true, y_pred = self.get_true_pred_data()

        if np.isnan(y_pred).any():
            logger.warn(
                'NaN found in prediction values for AUC; replacing NaN with zero'
            )
            np.nan_to_num(y_pred, copy=False, nan=0.0)

        from sklearn.metrics import auc, roc_curve
        fpr, tpr, _ = roc_curve(y_true, y_pred)
        roc_auc = auc(fpr, tpr)

        return roc_auc
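
For reference, a self-contained sketch of the same computation on hypothetical toy arrays, including the NaN handling above:

    import numpy as np
    from sklearn.metrics import auc, roc_curve

    y_true = np.array([0, 0, 1, 1])
    y_pred = np.array([0.1, np.nan, 0.35, 0.8])
    np.nan_to_num(y_pred, copy=False, nan=0.0)  # replace NaN with zero in place
    fpr, tpr, _ = roc_curve(y_true, y_pred)
    print(auc(fpr, tpr))  # 1.0 for this toy input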
Example #7
    def _training_darts(self, x, y):
        result = {}

        n_train = len(x['train'][0])
        n_valid = len(x['valid'][0])
        logger.debug(f"num of training samples = {n_train}")
        logger.debug(f"num of validation samples = {n_valid}")

        ###################################
        # Check consistency of batch size #
        ###################################
        import math
        v_gcd = math.gcd(n_train, n_valid)
        frac_train = n_train // v_gcd
        frac_sum = (n_train + n_valid) // v_gcd

        if n_train < self._batch_size:
            self._batch_size = n_train

        if self._batch_size % frac_train > 0:
            raise ValueError(
                'batch_size of DARTS training should be divisible by the '
                'training fraction of the train/valid ratio: '
                f'batch_size = {self._batch_size}, frac_train = {frac_train}')

        batch_size_total = self._batch_size * frac_sum // frac_train
        logger.debug(
            f"total batch size (train + valid) in DARTS training = {batch_size_total}"
        )
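        # Worked example with hypothetical numbers: n_train = 8000 and
        # n_valid = 2000 give v_gcd = 2000, frac_train = 4 and frac_sum = 5.
        # A batch_size of 400 passes the check above (400 % 4 == 0) and
        # yields batch_size_total = 400 * 5 // 4 = 500 samples per merged batch.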

        alpha_model_names = [v.name for v in self._model.alpha_vars]
        result['alpha_model_names'] = alpha_model_names

        # Validate
        for var in self._model.weight_vars:
            if 'batch_normalization' in var.name:
                logger.warn('DARTS should not contain batch normalization layers.')

        #######################################
        # Merging training/validation samples #
        #######################################
        x_train_valid = []
        y_train_valid = []
        bsize_valid = batch_size_total - self._batch_size
        logger.debug(
            f"validation batch size in DARTS training = {bsize_valid}")
        for v1, v2 in zip(x['train'], x['valid']):
            v1 = v1.reshape((self._batch_size, -1) + v1.shape[1:])
            v2 = v2.reshape((bsize_valid, -1) + v2.shape[1:])
            v = np.concatenate([v1, v2], axis=0)
            v = v.reshape((-1, ) + v.shape[2:])
            x_train_valid.append(v)

        for v1, v2 in zip(y['train'], y['valid']):
            v1 = v1.reshape((self._batch_size, -1) + v1.shape[1:])
            v2 = v2.reshape((bsize_valid, -1) + v2.shape[1:])
            v = np.concatenate([v1, v2], axis=0)
            v = v.reshape((-1, ) + v.shape[2:])
            y_train_valid.append(v)
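
        # The merged arrays feed a single fit() call below: DARTS updates the
        # network weights on the training part of each batch and the
        # architecture parameters (alpha) on the validation part, with the
        # split point given by _batch_size_train assigned next.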

        ##################
        # DARTS training #
        ##################
        self.ml.model._batch_size_train.assign(self._batch_size)

        import tempfile
        chpt_path = f'{tempfile.mkdtemp()}/tf_chpt'

        cbs = []

        from tensorflow.keras.callbacks import EarlyStopping
        es_cb = EarlyStopping(monitor='valid_loss',
                              patience=self._max_patience,
                              verbose=0,
                              mode='min',
                              restore_best_weights=True)
        cbs.append(es_cb)

        from tensorflow.keras.callbacks import ModelCheckpoint
        cp_cb = ModelCheckpoint(filepath=chpt_path,
                                monitor='valid_loss',
                                verbose=0,
                                save_best_only=True,
                                save_weights_only=True,
                                mode='min')
        cbs.append(cp_cb)

        from tensorflow.keras.callbacks import TerminateOnNaN
        nan_cb = TerminateOnNaN()
        cbs.append(nan_cb)

        if self._save_tensorboard:
            from tensorflow.keras.callbacks import TensorBoard
            tb_cb = TensorBoard(log_dir=f'{self._saver.save_dir}/{self._name}',
                                histogram_freq=1,
                                profile_batch=5)
            cbs.append(tb_cb)

        from multiml.agent.keras.callback import (AlphaDumperCallback,
                                                  EachLossDumperCallback)
        alpha_cb = AlphaDumperCallback()
        loss_cb = EachLossDumperCallback()
        cbs.append(alpha_cb)
        cbs.append(loss_cb)

        training_verbose_mode = 0
        if logger.MIN_LEVEL <= logger.DEBUG:
            training_verbose_mode = 1

        history = self.ml.model.fit(x=x_train_valid,
                                    y=y_train_valid,
                                    batch_size=batch_size_total,
                                    epochs=self._num_epochs,
                                    callbacks=cbs,
                                    validation_data=(x['test'], y['test']),
                                    shuffle=False,
                                    verbose=training_verbose_mode)

        history0 = history.history
        result['darts_loss_train'] = history0['train_loss']
        result['darts_loss_valid'] = history0['valid_loss']
        result['darts_loss_test'] = history0['val_test_loss']
        result['darts_alpha_history'] = alpha_cb.get_alpha_history()
        result['darts_loss_history'] = loss_cb.get_loss_history()
        result['darts_lambda_history'] = history0['lambda']
        result['darts_alpha_gradients_sum'] = history0['alpha_gradients_sum']
        result['darts_alpha_gradients_sq_sum'] = history0[
            'alpha_gradients_sq_sum']
        result['darts_alpha_gradients_n'] = history0['alpha_gradients_n']

        # Check nan in alpha parameters
        # self._has_nan_in_alpha = nan_cb._isnan(self._model.alpha_vars)

        ##################
        # Save meta data #
        ##################
        self._index_of_best_submodels = (
            self.ml.model.get_index_of_best_submodels())

        return result