Example #1
    def update_data(self,
                    var_names,
                    data,
                    phase='train',
                    index=-1,
                    bind_vars=True,
                    do_compile=True):
        """Update data in storegate with given options.

        Update (replace) data in the storegate. If ``var_names`` does not exist
        for the given ``data_id`` and ``phase``, the data are newly added.
        Otherwise, the selected data are replaced with the given data.

        Args:
            var_names (str or list(str)): see ``add_data()`` method.
            data (list or ndarray): see ``add_data()`` method.
            phase (str or tuple): see ``add_data()`` method.
            index (int or tuple): If ``index`` is -1 (default), all data are
                updated for given options. If ``index`` is int, only the data
                with ``index`` is updated. If index is (x, y), data in the
                range (x, y) are updated.
            bind_vars (bool): see ``add_data()`` method.
            do_compile (bool): if True, run ``compile()`` after updating data.

        Examples:
            >>> # update data of train phase
            >>> storegate.update_data(var_names='var0', data=[1], phase='train', index=1)
        """
        self._check_valid_data_id()

        var_names, data = self._view_to_list(var_names, data, bind_vars)
        self._check_valid_phase(phase)

        for var_name, idata in zip(var_names, data):
            idata = self._convert_to_np(idata)
            indices = self._get_phase_indices(phase, len(idata))

            for iphase, phase_data in zip(const.PHASES,
                                          np.split(idata, indices)):
                metadata = self._db.get_metadata(self._data_id, iphase)

                if len(phase_data) == 0:
                    continue

                if var_name not in metadata.keys():
                    logger.debug(
                        f'Adding {iphase} : {var_name} to {self._data_id}')
                    self.add_data(var_name, phase_data, iphase, False,
                                  bind_vars)

                else:
                    self._db.update_data(self._data_id, var_name, phase_data,
                                         iphase, index)

        if self._data_id in self._metadata:
            self._metadata[self._data_id]['compiled'] = False

        if do_compile:
            self.compile(reset=False)
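A minimal usage sketch (assuming a configured storegate whose data_id is already set and which holds a ``var0`` array in the train phase; the values are illustrative):

    >>> # replace the first ten entries of var0 in the train phase
    >>> storegate.update_data(var_names='var0',
    ...                       data=list(range(10)),
    ...                       phase='train',
    ...                       index=(0, 10))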
Example #2
 def close(self):
     """Close the shelve database.
     """
     if self._state == 'close':
         logger.debug('saver is already closed')
     else:
         self._shelve.close()
         self._state = 'close'
Example #3
    def to_storage(self, data_id, var_name, phase):
        """Move data from the numpy backend (memory) to the zarr backend (storage)."""
        metadata_zarr = self._db['zarr'].get_metadata(data_id, phase)
        metadata_numpy = self._db['numpy'].get_metadata(data_id, phase)

        if var_name in metadata_zarr:
            logger.debug(f'{var_name} is already on storage (zarr)')

        elif var_name in metadata_numpy:
            tmp_data = self.get_data(data_id, var_name, phase, -1)
            self.delete_data(data_id, var_name, phase)
            self.add_data(data_id, var_name, tmp_data, phase, 'zarr')

        else:
            raise ValueError(f'{var_name} does not exist in hybrid database')
Example #4
 def on_epoch_end(self, epoch, logs=None):
     logger.debug('')
     for i, var in enumerate(self.model.alpha_vars):
         logger.debug(
             f"epoch = {epoch}: alpha {var.name} = {self.formatting(var)}")
         self._alpha_history[i].append(var.numpy().reshape(-1))
     logger.debug('')
     logger.debug('')
Example #5
    def to_storage(self, key):
        """Move object from dict to storage.

        Args:
            key (str): the unique identifier to be moved.
        """
        if key in self.keys('shelve'):
            logger.debug(f'{key} already exists in shelve (storage)')

        elif key in self.keys('dict'):
            value = copy.copy(self[key])
            del self[key]
            self.add(key, value, 'shelve')

        else:
            raise ValueError(f'{key} does not exist in dict')
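A minimal usage sketch (assuming a saver-like object with the ``add()`` and ``keys()`` methods used above; the key and value names are illustrative):

    >>> saver.add('model0_weights', weights, 'dict')  # held in memory
    >>> saver.to_storage('model0_weights')            # moved to shelve
    >>> 'model0_weights' in saver.keys('shelve')
    True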
Example #6
    def save_checkpoint(self, val_loss, model):
        """Saves model when validation loss decrease."""
        from copy import deepcopy

        from torch import save
        logger.debug(
            f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). '
            + 'Updating model ...')
        if self.save:
            from os.path import isdir, join
            if isdir(self.path):
                save_path = join(self.path, 'checkpoint.pt')
            else:
                save_path = self.path
            save(model.state_dict(), save_path)
        self.val_loss_min = val_loss
        return deepcopy(model)
Example #7
    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.best_model = self.save_checkpoint(val_loss, model)
        elif score <= self.best_score:
            self.counter += 1
            logger.debug(
                f'EarlyStopping counter: {self.counter} out of {self.patience}'
            )
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.counter = 0
            self.best_model = self.save_checkpoint(val_loss, model)

        return self.early_stop
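A minimal training-loop sketch (assuming an early-stopping class with the attributes used above; ``train_one_epoch`` and ``validate`` are hypothetical helpers that run one epoch and return the epoch's validation loss, respectively):

    >>> early_stopping = EarlyStopping(patience=5)
    >>> for epoch in range(num_epochs):
    ...     train_one_epoch(model)
    ...     val_loss = validate(model)
    ...     if early_stopping(val_loss, model):
    ...         break  # patience exhausted
    >>> best_model = early_stopping.best_model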
Example #8
def plot_regression(storegate,
                    var_pred=None,
                    var_target=None,
                    data_id="",
                    phase="test",
                    save_dir="tmp"):
    # avoid mutable default arguments
    var_pred = var_pred if var_pred is not None else []
    var_target = var_target if var_target is not None else []
    storegate.set_data_id(data_id)
    y_test = storegate.get_data(phase=phase, var_names=var_target)
    y_pred = storegate.get_data(phase=phase, var_names=var_pred)

    from multiml_htautau.task.loss import Tau4vecCalibLoss_np
    mse = Tau4vecCalibLoss_np(pt_scale=1e-2, use_pxyz=True)(y_test, y_pred)
    if phase is None:
        phase = "total"
    from multiml import logger
    logger.debug(f"mse ({phase}) = {mse}")
    outputname = f"{save_dir}/mse.{phase}.pkl"
    with open(outputname, 'wb') as f:
        import pickle
        pickle.dump({"mse": mse}, f)

    plot_regression_pull(y_pred, y_test, var_target, save_dir)
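A minimal call sketch (the variable names, data_id, and output directory are illustrative):

    >>> plot_regression(storegate,
    ...                 var_pred=['energy_pred'],
    ...                 var_target=['energy_true'],
    ...                 data_id='tau4vec',
    ...                 phase='test',
    ...                 save_dir='output')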
Example #9
    def open(self, mode=None):
        """Open shelve database with given mode.

        Args:
            mode (str): 'r': reading only, 'w': reading and writing,
                'c' (default): reading and writing, creating it if it does not
                exist, 'n': always create a new empty database, reading and
                writing.
  
        Examples:
            >>> saver.open('r')
            >>> print(saver['key0'])
            >>> print(saver['key1'])
            >>> saver.close()
        """
        if mode is None:
            mode = self._shelve_mode

        if self._state == 'open':
            logger.debug('saver is already open')
        else:
            self._shelve = shelve.open(f'{self._save_dir}/{self._shelve_name}',
                                       flag=mode)
            self._state = 'open'
Example #10
    def _training_darts(self, x, y):
        result = {}

        n_train = len(x['train'][0])
        n_valid = len(x['valid'][0])
        logger.debug(f"num of training samples = {n_train}")
        logger.debug(f"num of validation samples = {n_valid}")

        ###################################
        # Check consistency of batch size #
        ###################################
        import math
        v_gcd = math.gcd(n_train, n_valid)
        frac_train = n_train // v_gcd
        frac_sum = (n_train + n_valid) // v_gcd

        if n_train < self._batch_size:
            self._batch_size = n_train

        if self._batch_size % frac_train > 0:
            raise ValueError(
                f"batch_size of darts training should be divisible by training/valid ratio. bsize_darts_train = {self._batch_size}, frac_train = {frac_train}"
            )

        batch_size_total = self._batch_size * frac_sum // frac_train
        logger.debug(
            f"total batch size (train + valid) in DARTS training = {batch_size_total}"
        )
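        # Worked example with illustrative numbers: n_train = 8000 and
        # n_valid = 2000 give v_gcd = 2000, frac_train = 4, frac_sum = 5.
        # A batch_size of 100 is divisible by frac_train, so
        # batch_size_total = 100 * 5 // 4 = 125 (100 train + 25 valid).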

        alpha_model_names = [v.name for v in self._model.alpha_vars]
        result['alpha_model_names'] = alpha_model_names

        # Validate
        for var in self._model.weight_vars:
            if 'batch_normalization' in var.name:
                logger.warn('DARTS should not have batch normalization layer.')

        #######################################
        # Merging training/validation samples #
        #######################################
        x_train_valid = []
        y_train_valid = []
        bsize_valid = batch_size_total - self._batch_size
        logger.debug(
            f"validation batch size in DARTS training = {bsize_valid}")
        for v1, v2 in zip(x['train'], x['valid']):
            v1 = v1.reshape((self._batch_size, -1) + v1.shape[1:])
            v2 = v2.reshape((bsize_valid, -1) + v2.shape[1:])
            v = np.concatenate([v1, v2], axis=0)
            v = v.reshape((-1, ) + v.shape[2:])
            x_train_valid.append(v)

        for v1, v2 in zip(y['train'], y['valid']):
            v1 = v1.reshape((self._batch_size, -1) + v1.shape[1:])
            v2 = v2.reshape((bsize_valid, -1) + v2.shape[1:])
            v = np.concatenate([v1, v2], axis=0)
            v = v.reshape((-1, ) + v.shape[2:])
            y_train_valid.append(v)
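        # After merging, each batch of size batch_size_total contains the
        # training samples first, followed by the validation samples;
        # shuffle=False in fit() below preserves this within-batch layout.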

        ##################
        # DARTS training #
        ##################
        self.ml.model._batch_size_train.assign(self._batch_size)

        import tempfile
        chpt_path = f'{tempfile.mkdtemp()}/tf_chpt'

        cbs = []

        from tensorflow.keras.callbacks import EarlyStopping
        es_cb = EarlyStopping(monitor='valid_loss',
                              patience=self._max_patience,
                              verbose=0,
                              mode='min',
                              restore_best_weights=True)
        cbs.append(es_cb)

        from tensorflow.keras.callbacks import ModelCheckpoint
        cp_cb = ModelCheckpoint(filepath=chpt_path,
                                monitor='valid_loss',
                                verbose=0,
                                save_best_only=True,
                                save_weights_only=True,
                                mode='min')
        cbs.append(cp_cb)

        from tensorflow.keras.callbacks import TerminateOnNaN
        nan_cb = TerminateOnNaN()
        cbs.append(nan_cb)

        if self._save_tensorboard:
            from tensorflow.keras.callbacks import TensorBoard
            tb_cb = TensorBoard(log_dir=f'{self._saver.save_dir}/{self._name}',
                                histogram_freq=1,
                                profile_batch=5)
            cbs.append(tb_cb)

        from multiml.agent.keras.callback import (AlphaDumperCallback,
                                                  EachLossDumperCallback)
        alpha_cb = AlphaDumperCallback()
        loss_cb = EachLossDumperCallback()
        cbs.append(alpha_cb)
        cbs.append(loss_cb)

        training_verbose_mode = 0
        if logger.MIN_LEVEL <= logger.DEBUG:
            training_verbose_mode = 1

        history = self.ml.model.fit(x=x_train_valid,
                                    y=y_train_valid,
                                    batch_size=batch_size_total,
                                    epochs=self._num_epochs,
                                    callbacks=cbs,
                                    validation_data=(x['test'], y['test']),
                                    shuffle=False,
                                    verbose=training_verbose_mode)

        history0 = history.history
        result['darts_loss_train'] = history0['train_loss']
        result['darts_loss_valid'] = history0['valid_loss']
        result['darts_loss_test'] = history0['val_test_loss']
        result['darts_alpha_history'] = alpha_cb.get_alpha_history()
        result['darts_loss_history'] = loss_cb.get_loss_history()
        result['darts_lambda_history'] = history0['lambda']
        result['darts_alpha_gradients_sum'] = history0['alpha_gradients_sum']
        result['darts_alpha_gradients_sq_sum'] = history0[
            'alpha_gradients_sq_sum']
        result['darts_alpha_gradients_n'] = history0['alpha_gradients_n']

        # Check nan in alpha parameters
        # self._has_nan_in_alpha = nan_cb._isnan(self._model.alpha_vars)

        ##################
        # Save meta data #
        ##################
        self._index_of_best_submodels = \
            self.ml.model.get_index_of_best_submodels()

        return result
Example #11
 def execute(self):
     for key, value in self.__dict__.items():
         logger.debug(f'{key} : {value}')
Example #12
 def on_train_end(self, logs=None):
     for var in self.model.alpha_vars:
         logger.debug(
             f"DARTS final alpha {var.name} = {self.formatting(var)}")
Example #13
    def on_train_begin(self, logs=None):
        for var in self.model.alpha_vars:
            logger.debug(f"Initial alpha {var.name} = {self.formatting(var)}")

        self._alpha_history = [[] for _ in range(len(self.model.alpha_vars))]
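These hooks fire once the callback is passed to Keras, as in Example #10 (a minimal sketch; ``model`` and the training data are placeholders):

    >>> alpha_cb = AlphaDumperCallback()
    >>> model.fit(x, y, epochs=10, callbacks=[alpha_cb])
    >>> alpha_history = alpha_cb.get_alpha_history()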