def update_data(self, var_names, data, phase='train', index=-1, bind_vars=True, do_compile=True):
    """Update data in storegate with given options.

    Update (replace) data in the storegate. If ``var_names`` does not exist
    in given ``data_id`` and ``phase``, data are newly added. Otherwise,
    selected data are replaced with given data.

    Args:
        var_names (str or list(str)): see ``add_data()`` method.
        data (list or ndarray): see ``add_data()`` method.
        phase (str or tuple): see ``add_data()`` method.
        index (int or tuple): If ``index`` is -1 (default), all data are
            updated for given options. If ``index`` is int, only the data
            with ``index`` is updated. If index is (x, y), data in the
            range (x, y) are updated.
        bind_vars (bool): see ``add_data()`` method.
        do_compile (bool): do compile if True after updating data.

    Examples:
        >>> # update data of train phase
        >>> storegate.update_data(var_names='var0', data=[1], phase='train', index=1)
    """
    self._check_valid_data_id()
    # Normalize var_names/data to aligned lists (single name -> 1-element list).
    var_names, data = self._view_to_list(var_names, data, bind_vars)
    self._check_valid_phase(phase)

    for var_name, idata in zip(var_names, data):
        idata = self._convert_to_np(idata)
        # Split points dividing idata into the per-phase slices.
        indices = self._get_phase_indices(phase, len(idata))

        for iphase, phase_data in zip(const.PHASES, np.split(idata, indices)):
            metadata = self._db.get_metadata(self._data_id, iphase)
            if len(phase_data) == 0:
                # No samples fall into this phase; nothing to write.
                continue

            if var_name not in metadata.keys():
                # Variable not registered yet for this phase: add instead of update.
                logger.debug(
                    f'Adding {phase} : {var_name} to {self._data_id}')
                self.add_data(var_name, phase_data, iphase, False, bind_vars)
            else:
                self._db.update_data(self._data_id, var_name, phase_data,
                                     iphase, index)

    # Any write invalidates the compiled state for this data_id.
    if self._data_id in self._metadata:
        self._metadata[self._data_id]['compiled'] = False

    if do_compile:
        self.compile(reset=False)
def close(self):
    """Close the shelve database.

    Calling close on an already-closed saver is a no-op (logged at debug).
    """
    if self._state == 'close':
        # Nothing to do; avoid double-closing the shelve handle.
        logger.debug('saver is already close')
        return
    self._shelve.close()
    self._state = 'close'
def to_storage(self, data_id, var_name, phase):
    """Move a variable from the numpy backend to the zarr (storage) backend.

    Args:
        data_id (str): identifier of the dataset.
        var_name (str): variable to migrate.
        phase (str): phase the variable belongs to.

    Raises:
        ValueError: if ``var_name`` exists in neither backend.
    """
    on_zarr = var_name in self._db['zarr'].get_metadata(data_id, phase)
    on_numpy = var_name in self._db['numpy'].get_metadata(data_id, phase)

    if on_zarr:
        logger.debug(f'{var_name} is already on storage (zarr)')
    elif on_numpy:
        # Read out the full array, drop the numpy copy, re-add under zarr.
        cached = self.get_data(data_id, var_name, phase, -1)
        self.delete_data(data_id, var_name, phase)
        self.add_data(data_id, var_name, cached, phase, 'zarr')
    else:
        raise ValueError(f'{var_name} does not exist in hybrid database')
def on_epoch_end(self, epoch, logs=None):
    """Log alpha values at epoch end and record them in the history buffers.

    Appends the flattened value of each alpha variable to the matching
    per-variable list in ``self._alpha_history``.
    """
    # Blank debug line as a visual separator in the log output.
    logger.debug('')
    for i, var in enumerate(self.model.alpha_vars):
        logger.debug(
            f"epoch = {epoch}: alpha {var.name} = {self.formatting(var)}")
        self._alpha_history[i].append(var.numpy().reshape(-1))
    # NOTE(review): the original (collapsed) source ends with two blank debug
    # calls; their placement relative to the loop is ambiguous — confirm.
    logger.debug('')
    logger.debug('')
def to_storage(self, key):
    """Move object from dict to storage.

    Args:
        key (str): the unique identifier to be moved.

    Raises:
        ValueError: if ``key`` is present in neither dict nor shelve.
    """
    if key in self.keys('shelve'):
        # Already persisted; leave it where it is.
        logger.debug(f'{key} already exist in shelve (storage)')
        return

    if key not in self.keys('dict'):
        raise ValueError(f'{key} does not exist in dict')

    # Shallow-copy before deleting so the value survives removal.
    value = copy.copy(self[key])
    del self[key]
    self.add(key, value, 'shelve')
def save_checkpoint(self, val_loss, model):
    """Saves model when validation loss decrease.

    Optionally writes the state dict to ``self.path`` (or
    ``self.path/checkpoint.pt`` when the path is a directory), records the
    new minimum loss, and returns a deep copy of the model.
    """
    from copy import deepcopy
    from torch import save

    logger.debug(
        f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). updating model ...')

    if self.save:
        from os.path import isdir, join
        # A directory path gets a fixed checkpoint file name; a file path is used as-is.
        target = join(self.path, 'checkpoint.pt') if isdir(self.path) else self.path
        save(model.state_dict(), target)

    self.val_loss_min = val_loss
    return deepcopy(model)
def __call__(self, val_loss, model):
    """Update early-stopping state with a new validation loss.

    Returns:
        bool: True once patience has been exhausted without improvement.
    """
    score = -val_loss

    if self.best_score is None:
        # First evaluation: adopt it as the baseline and snapshot the model.
        self.best_score = score
        self.best_model = self.save_checkpoint(val_loss, model)
    elif score > self.best_score:
        # Improvement: reset the patience counter and snapshot the model.
        self.best_score = score
        self.counter = 0
        self.best_model = self.save_checkpoint(val_loss, model)
    else:
        # No improvement: burn one unit of patience.
        self.counter += 1
        logger.debug(
            f'EarlyStopping counter: {self.counter} out of {self.patience}'
        )
        if self.counter >= self.patience:
            self.early_stop = True

    return self.early_stop
def plot_regression(storegate, var_pred=None, var_target=None, data_id="", phase="test", save_dir="tmp"):
    """Compute the regression MSE for predicted vs target variables and plot pulls.

    Fetches ``var_target`` and ``var_pred`` arrays from the storegate,
    evaluates ``Tau4vecCalibLoss_np``, dumps the MSE to
    ``{save_dir}/mse.{phase}.pkl`` and delegates plotting to
    ``plot_regression_pull``.

    Args:
        storegate: storegate instance holding the data.
        var_pred (list or None): predicted variable names (default: empty list).
        var_target (list or None): target variable names (default: empty list).
        data_id (str): data identifier to select in the storegate.
        phase (str or None): phase to read; None is relabeled 'total' in outputs.
        save_dir (str): output directory for the pickle and plots.
    """
    # Fix: original used mutable default arguments ([]), which are shared
    # across calls; use a None sentinel instead (backward-compatible).
    if var_pred is None:
        var_pred = []
    if var_target is None:
        var_target = []

    storegate.set_data_id(data_id)
    y_test = storegate.get_data(phase=phase, var_names=var_target)
    y_pred = storegate.get_data(phase=phase, var_names=var_pred)

    from multiml_htautau.task.loss import Tau4vecCalibLoss_np
    mse = Tau4vecCalibLoss_np(pt_scale=1e-2, use_pxyz=True)(y_test, y_pred)

    # phase=None means "all phases"; label outputs 'total' in that case.
    if phase is None:
        phase = "total"

    from multiml import logger
    logger.debug(f"mse ({phase}) = {mse}")

    import pickle
    outputname = f"{save_dir}/mse.{phase}.pkl"
    with open(outputname, 'wb') as f:
        pickle.dump({"mse": mse}, f)

    plot_regression_pull(y_pred, y_test, var_target, save_dir)
def open(self, mode=None):
    """Open shelve database with given mode.

    Args:
        mode (str): 'r': reading only, 'w': reading and writing,
            'c' (default): reading and writing, creating it if it does not
            exist, 'n': always create a new empty database, reading and
            writing.

    Examples:
        >>> saver.open('r')
        >>> print(saver['key0'])
        >>> print(saver['key1'])
        >>> saver.close()
    """
    # Fall back to the saver's configured default mode.
    mode = self._shelve_mode if mode is None else mode

    if self._state == 'open':
        # Opening twice would leak the existing handle; bail out early.
        logger.debug('saver is already open')
        return

    self._shelve = shelve.open(f'{self._save_dir}/{self._shelve_name}',
                               flag=mode)
    self._state = 'open'
def _training_darts(self, x, y):
    """Run the DARTS architecture-search training loop.

    Interleaves training and validation samples into combined batches with a
    fixed train/valid proportion, fits the DARTS model with early stopping,
    checkpointing and NaN termination, and collects losses, alpha/lambda
    histories and gradient statistics.

    Args:
        x (dict): input arrays keyed by 'train'/'valid'/'test'.
        y (dict): target arrays keyed by 'train'/'valid'/'test'.

    Returns:
        dict: training results and histories.

    Raises:
        ValueError: if the batch size is not divisible by the reduced
            train/valid ratio.
    """
    result = {}

    n_train = len(x['train'][0])
    n_valid = len(x['valid'][0])
    logger.debug(f"num of training samples = {n_train}")
    logger.debug(f"num of validation samples = {n_valid}")

    ###################################
    # Check consistency of batch size #
    ###################################
    import math
    # Reduce the train:valid ratio to lowest terms so every combined batch
    # can hold the same train/valid proportion.
    v_gcd = math.gcd(n_train, n_valid)
    frac_train = n_train // v_gcd
    frac_sum = (n_train + n_valid) // v_gcd

    # Cap the batch size at the number of available training samples.
    if n_train < self._batch_size:
        self._batch_size = n_train

    if self._batch_size % frac_train > 0:
        raise ValueError(
            f"batch_size of darts training should be divisible by training/valid ratio. bsize_darts_train = {self._batch_size}, frac_train = {frac_train}"
        )
    batch_size_total = self._batch_size * frac_sum // frac_train
    logger.debug(
        f"total batch size (train + valid) in DARTS training = {batch_size_total}"
    )

    alpha_model_names = [v.name for v in self._model.alpha_vars]
    result['alpha_model_names'] = alpha_model_names

    # Validate: warn about batch normalization layers in the weight set.
    for var in self._model.weight_vars:
        if 'batch_normalization' in var.name:
            logger.warn('DARTS should not have batch normalization layer.')

    #######################################
    # Merging training/validation samples #
    #######################################
    # Reshape each array into (batch, n_batches, ...) chunks and concatenate
    # train and valid chunks so each combined batch contains
    # self._batch_size train rows followed by bsize_valid valid rows.
    x_train_valid = []
    y_train_valid = []
    bsize_valid = batch_size_total - self._batch_size
    logger.debug(
        f"validation batch size in DARTS training = {bsize_valid}")
    for v1, v2 in zip(x['train'], x['valid']):
        v1 = v1.reshape((self._batch_size, -1) + v1.shape[1:])
        v2 = v2.reshape((bsize_valid, -1) + v2.shape[1:])
        v = np.concatenate([v1, v2], axis=0)
        v = v.reshape((-1, ) + v.shape[2:])
        x_train_valid.append(v)
    for v1, v2 in zip(y['train'], y['valid']):
        v1 = v1.reshape((self._batch_size, -1) + v1.shape[1:])
        v2 = v2.reshape((bsize_valid, -1) + v2.shape[1:])
        v = np.concatenate([v1, v2], axis=0)
        v = v.reshape((-1, ) + v.shape[2:])
        y_train_valid.append(v)

    ##################
    # DARTS training #
    ##################
    # Tell the model how many rows of each combined batch are training rows.
    self.ml.model._batch_size_train.assign(self._batch_size)

    import tempfile
    chpt_path = f'{tempfile.mkdtemp()}/tf_chpt'

    cbs = []

    from tensorflow.keras.callbacks import EarlyStopping
    es_cb = EarlyStopping(monitor='valid_loss',
                          patience=self._max_patience,
                          verbose=0,
                          mode='min',
                          restore_best_weights=True)
    cbs.append(es_cb)

    from tensorflow.keras.callbacks import ModelCheckpoint
    cp_cb = ModelCheckpoint(filepath=chpt_path,
                            monitor='valid_loss',
                            verbose=0,
                            save_best_only=True,
                            save_weights_only=True,
                            mode='min')
    cbs.append(cp_cb)

    from tensorflow.keras.callbacks import TerminateOnNaN
    nan_cb = TerminateOnNaN()
    cbs.append(nan_cb)

    if self._save_tensorboard:
        from tensorflow.keras.callbacks import TensorBoard
        tb_cb = TensorBoard(log_dir=f'{self._saver.save_dir}/{self._name}',
                            histogram_freq=1,
                            profile_batch=5)
        cbs.append(tb_cb)

    from multiml.agent.keras.callback import (AlphaDumperCallback,
                                              EachLossDumperCallback)
    alpha_cb = AlphaDumperCallback()
    loss_cb = EachLossDumperCallback()
    cbs.append(alpha_cb)
    cbs.append(loss_cb)

    training_verbose_mode = 0
    if logger.MIN_LEVEL <= logger.DEBUG:
        training_verbose_mode = 1

    # shuffle=False keeps the interleaved train/valid row ordering intact
    # within each combined batch.
    history = self.ml.model.fit(x=x_train_valid,
                                y=y_train_valid,
                                batch_size=batch_size_total,
                                epochs=self._num_epochs,
                                callbacks=cbs,
                                validation_data=(x['test'], y['test']),
                                shuffle=False,
                                verbose=training_verbose_mode)

    history0 = history.history
    result['darts_loss_train'] = history0['train_loss']
    result['darts_loss_valid'] = history0['valid_loss']
    result['darts_loss_test'] = history0['val_test_loss']
    result['darts_alpha_history'] = alpha_cb.get_alpha_history()
    result['darts_loss_history'] = loss_cb.get_loss_history()
    result['darts_lambda_history'] = history0['lambda']
    result['darts_alpha_gradients_sum'] = history0['alpha_gradients_sum']
    result['darts_alpha_gradients_sq_sum'] = history0[
        'alpha_gradients_sq_sum']
    result['darts_alpha_gradients_n'] = history0['alpha_gradients_n']

    # Check nan in alpha parameters #
    self._has_nan_in_alpha = nan_cb._isnan(self._model.alpha_vars)

    ##################
    # Save meta data #
    ##################
    self._index_of_best_submodels = self.ml.model.get_index_of_best_submodels(
    )

    return result
def execute(self):
    """Dump every instance attribute and its value at debug level."""
    for attr_name, attr_value in self.__dict__.items():
        logger.debug(f'{attr_name} : {attr_value}')
def on_train_end(self, logs=None):
    """Log the final value of each alpha variable when training finishes."""
    for alpha in self.model.alpha_vars:
        logger.debug(
            f"DARTS final alpha {alpha.name} = {self.formatting(alpha)}")
def on_train_begin(self, logs=None):
    """Log the initial alphas and reset the per-variable history buffers."""
    alphas = self.model.alpha_vars
    for alpha in alphas:
        logger.debug(f"Initial alpha {alpha.name} = {self.formatting(alpha)}")
    # One empty history list per alpha variable, filled at each epoch end.
    self._alpha_history = [[] for _ in alphas]