def load_model(self, save_dir):
    """Restore the parameters of this model from a directory on disk.

    The model is built first via ``self.build()``; afterwards the variables
    returned by ``self.get_param_variables()`` are restored from `save_dir`.

    Parameters
    ----------
    save_dir : str
        Directory holding the previously saved variables.
    """
    self.build()
    checkpoint_dir = os.path.abspath(save_dir)
    param_vars = self.get_param_variables()
    # Explicitly ask the saver NOT to ignore a non-existing checkpoint.
    VariableSaver(param_vars, checkpoint_dir).restore(ignore_non_exist=False)
def _enter(self):
    if self._checkpoint_dir is not None:
        # an explicit checkpoint directory was given: make sure it exists
        makedirs(self._checkpoint_dir, exist_ok=True)
    else:
        # no checkpoint directory: open a temporary one, and keep its
        # context object so that it can be closed later
        temp_dir_ctx = TemporaryDirectory()
        self._temp_dir_ctx = temp_dir_ctx
        self._checkpoint_dir = temp_dir_ctx.__enter__()

    # the saver operates on the parameters within the checkpoint directory
    self._saver = VariableSaver(self._param_vars, self._checkpoint_dir)

    # the object itself serves as the context object
    return self
def save_model(self, save_dir, overwrite=False):
    """Write the parameters of this model into a directory on disk.

    The model is built first via ``self.build()``.  If `save_dir` already
    exists, it is either removed (when `overwrite` is set) or required to
    be an empty directory.

    Parameters
    ----------
    save_dir : str
        Directory where the saved variables should be placed.

    overwrite : bool
        Whether or not an existing `save_dir` should be removed before
        saving?  (default False)

    Raises
    ------
    IOError
        If `overwrite` is not set, and `save_dir` exists as a non-empty
        directory or as something which is not a directory.
    """
    self.build()
    target = os.path.abspath(save_dir)

    if os.path.exists(target):
        target_is_dir = os.path.isdir(target)
        if overwrite:
            # directories are removed recursively, anything else is unlinked
            remove = shutil.rmtree if target_is_dir else os.remove
            remove(target)
        elif not target_is_dir or os.listdir(target):
            raise IOError('%r already exists.' % save_dir)

    VariableSaver(self.get_param_variables(), target).save()
def test_non_exist(self):
    expected_message = 'Checkpoint file does not exist in directory'
    with TemporaryDirectory() as checkpoint_dir:
        variable = tf.get_variable('a', initializer=1, dtype=tf.int32)
        saver = VariableSaver([variable], checkpoint_dir)

        # nothing has been saved yet, thus a plain restore should fail
        with pytest.raises(IOError, match=expected_message):
            saver.restore()

        # ... unless the saver is asked to ignore the missing checkpoint
        saver.restore(ignore_non_exist=True)
kernel_regularizer=K.regularizers.l2(0.001), activation=tf.nn.relu), K.layers.Dense(100, kernel_regularizer=K.regularizers.l2(0.001), activation=tf.nn.relu), ]), x_dims=120, z_dims=5, ) # To train the Donut model, and use a trained model for prediction trainer = DonutTrainer(model=model, model_vs=model_vs) predictor = DonutPredictor(model) with tf.Session().as_default(): #trainer.fit(train_values, train_labels, train_missing, mean, std) #var_dict = get_variables_as_dict(model_vs) #saver = VariableSaver(var_dict, "donut_without_label_2.ckpt") #saver.save() # Restore variables from `save_dir`. saver = VariableSaver(get_variables_as_dict(model_vs), "donut_without_label_2.ckpt") saver.restore() test_score = predictor.get_score(test_values, test_missing) result = np.array([test_labels[119:], test_score]) np.savetxt('result_arti_sin2.csv', result.transpose(), delimiter=',', fmt='%.3f')
marker='^', color='green', label="Non Anomalies") plt.legend(['Non Anomalies']) plt.xlim(df3['timestamp'].min(), df3['timestamp'].max()) plt.ylim(-.0006, .0006) plt.title('Non Anomalies from Fetal Brain Scan') plt.ylabel('# Direct_1') plt.xlabel('Date-Time') plt.savefig('figs/out_brain_nanomaly.png') from tfsnippet.utils import get_variables_as_dict, VariableSaver session = K.backend.get_session() init = tf.global_variables_initializer() session.run(init) with session.as_default(): var_dict = get_variables_as_dict(model_vs) # save variables to `save_dir` saver = VariableSaver(var_dict, save_dir) saver.save() print("Saved the model successfully") with session.as_default(): # Restore the model. saver = VariableSaver(get_variables_as_dict(model_vs), save_dir) saver.restore() print("Restored the model successfully")
class EarlyStopping(DisposableContext):
    """
    Context object that implements early-stopping.

    While the context is open, the object keeps track of the best metric
    reported so far, and saves the parameters on disk every time a better
    metric is reported.  An example of using this context:

    .. code-block:: python

        with EarlyStopping(param_vars) as es:
            ...
            es.update(loss, global_step)
            ...

    Here ``es.update(loss, global_step)`` causes the parameters to be saved
    on disk if `loss` is better than the current best metric.  The current
    best metric is available via ``es.best_metric``.

    Notes:
        If no metric is ever given via ``es.update``, the variables keep
        their latest values when the early-stopping object is closed.
    """

    def __init__(self, param_vars, initial_metric=None, checkpoint_dir=None,
                 smaller_is_better=True, restore_on_error=False, cleanup=True,
                 name=None):
        """
        Construct the :class:`EarlyStopping`.

        Args:
            param_vars (list[tf.Variable] or dict[str, tf.Variable]): List or
                dict of variables to be memorized.  If a dict is specified,
                its keys are used as the serialization keys via
                :class:`VariableSaver`.  A shallow copy is stored.
            initial_metric (float or tf.Tensor or tf.Variable): The initial
                best metric (for recovering from a previous session).
                A tensor or variable is evaluated immediately via ``eval()``.
            checkpoint_dir (str): The directory where to save the checkpoint
                files (converted to an absolute path).  If not specified,
                a temporary directory is used.
            smaller_is_better (bool): Whether or not smaller metric values
                are better?  (default :obj:`True`)
            restore_on_error (bool): Whether or not to restore the memorized
                parameters even on error?  (default :obj:`False`)
            cleanup (bool): Whether or not to remove the checkpoint directory
                on exit?  Ignored if `checkpoint_dir` is :obj:`None`, in
                which case the temporary directory is always closed on exit.
            name (str): Name scope of all TensorFlow operations.  It is
                stored as given, and not used by the methods defined here.

        Raises:
            ValueError: If `param_vars` is empty.
        """
        # validate and normalize the arguments
        if not param_vars:
            raise ValueError('`param_vars` must not be empty')
        if isinstance(initial_metric, (tf.Tensor, tf.Variable)):
            initial_metric = initial_metric.eval()
        if checkpoint_dir is not None:
            checkpoint_dir = os.path.abspath(checkpoint_dir)

        # configuration of the object
        self._param_vars = copy.copy(param_vars)
        self._checkpoint_dir = checkpoint_dir
        self._smaller_is_better = smaller_is_better
        self._restore_on_error = restore_on_error
        self._cleanup = cleanup
        self._name = name

        # mutable states of the object
        self._best_metric = initial_metric
        self._ever_updated = False
        self._temp_dir_ctx = None
        self._saver = None  # type: VariableSaver

    def _enter(self):
        if self._checkpoint_dir is not None:
            # an explicit checkpoint directory: make sure it exists
            makedirs(self._checkpoint_dir, exist_ok=True)
        else:
            # no checkpoint directory: open a temporary one, and keep its
            # context object so that `_exit` can close it
            temp_dir_ctx = TemporaryDirectory()
            self._temp_dir_ctx = temp_dir_ctx
            self._checkpoint_dir = temp_dir_ctx.__enter__()

        # the saver operates on the parameters within the checkpoint dir
        self._saver = VariableSaver(self._param_vars, self._checkpoint_dir)

        # the object itself serves as the context object
        return self

    def _exit(self, exc_type, exc_val, exc_tb):
        try:
            # The memorized parameters are restored on a clean exit, on
            # a keyboard interrupt, or whenever `restore_on_error` is set.
            # A missing checkpoint is ignored by the saver.
            if (exc_type is None or exc_type is KeyboardInterrupt or
                    self._restore_on_error):
                self._saver.restore(ignore_non_exist=True)
        finally:
            # Get rid of the checkpoint directory.  Failures are logged
            # instead of being raised.
            try:
                temp_dir_ctx = self._temp_dir_ctx
                if temp_dir_ctx is not None:
                    temp_dir_ctx.__exit__(exc_type, exc_val, exc_tb)
                elif self._cleanup and os.path.exists(self._checkpoint_dir):
                    shutil.rmtree(self._checkpoint_dir)
            except Exception:  # pragma: no cover
                getLogger(__name__).error(
                    'Failed to cleanup validation save dir %r.',
                    self._checkpoint_dir, exc_info=True
                )

            # Tell the user if `update` has never been called.
            if not self._ever_updated:
                warnings.warn(
                    'Early-stopping metric has never been updated. '
                    'The variables will keep their latest values. '
                    'Did you forget to add corresponding metric?'
                )

    def update(self, metric, global_step=None):
        """
        Report a new metric value, saving the parameters if it is the best.

        Args:
            metric (float): New metric value.
            global_step (int): Optional global step counter, passed on to
                the saver when the parameters are saved.

        Returns:
            bool: Whether or not the best metric has been updated?
        """
        self._require_entered()
        self._ever_updated = True

        best_so_far = self._best_metric
        if best_so_far is None:
            improved = True
        elif self._smaller_is_better:
            improved = metric < best_so_far
        else:
            improved = metric > best_so_far

        if not improved:
            return False
        self._saver.save(global_step)
        self._best_metric = metric
        return True

    @property
    def best_metric(self):
        """Get the best metric recorded so far."""
        return self._best_metric

    @property
    def ever_updated(self):
        """Check whether or not the `update` method has ever been called."""
        return self._ever_updated
def early_stopping(param_vars, initial_metric=None, save_dir=None,
                   smaller_is_better=True, restore_on_error=False,
                   cleanup=True, name=None):
    """Open a context to memorize the values of parameters at best metric.

    This method will open a context with an object to memorize the best
    metric for early-stopping.  An example of using this early-stopping
    context is:

        with early_stopping(param_vars) as es:
            ...
            es.update(loss, global_step)
            ...

    Where ``es.update(loss, global_step)`` should cause the parameters to
    be saved on disk if `loss` is better than the current best metric.
    One may also get the best metric via ``es.best_metric``.

    When the context exits without error, or is interrupted by
    ``KeyboardInterrupt``, or `restore_on_error` is set, the memorized
    parameters are restored.  If no loss is given via ``es.update``, no
    checkpoint exists, and the variables keep their latest values.

    Parameters
    ----------
    param_vars : list[tf.Variable] | dict[str, tf.Variable]
        List or dict of variables to be memorized.

        If a dict is specified, the keys of the dict would be used as the
        serialization keys via `VariableSaver`.

    initial_metric : float | tf.Tensor | tf.Variable
        The initial best loss (usually for recovering from a previous
        session).  A tensor or variable is evaluated via ``eval()``.

    save_dir : str
        The directory where to save the variable values.
        If not specified, will use a temporary directory.

    smaller_is_better : bool
        Whether or not a smaller loss is better?  (default True)

    restore_on_error : bool
        Whether or not to restore the memorized parameters even on error?
        (default False)

    cleanup : bool
        Whether or not to cleanup the saving directory on exit?

        This argument will be ignored if `save_dir` is None, while the
        temporary directory will always be deleted on exit.

    name : str
        Optional name of this scope.

    Yields
    ------
    _EarlyStopping
        The object to receive loss during the early-stopping context.

    Raises
    ------
    ValueError
        If `param_vars` is empty.
    """
    if not param_vars:
        raise ValueError('`param_vars` must not be empty.')

    if save_dir is None:
        # Delegate to the explicit-directory branch, using a temporary
        # directory which is removed by its own context manager (hence
        # `cleanup=False` for the nested call).
        with TemporaryDirectory() as tempdir:
            with early_stopping(param_vars, initial_metric=initial_metric,
                                save_dir=tempdir, cleanup=False,
                                smaller_is_better=smaller_is_better,
                                restore_on_error=restore_on_error,
                                name=name) as es:
                yield es
        return

    if isinstance(initial_metric, (tf.Tensor, tf.Variable)):
        initial_metric = initial_metric.eval()

    # Resolve the directory before constructing the saver, so that the saver
    # and the cleanup below refer to the very same absolute path.
    save_dir = os.path.abspath(save_dir)
    with tf.name_scope(name):
        saver = VariableSaver(param_vars, save_dir)
    makedirs(save_dir, exist_ok=True)

    es = _EarlyStopping(saver, best_metric=initial_metric,
                        smaller_is_better=smaller_is_better)

    try:
        yield es
    except KeyboardInterrupt:
        # `KeyboardInterrupt` derives from `BaseException`, so it must be
        # caught on its own; an `except Exception` clause never sees it.
        saver.restore(ignore_non_exist=True)
        raise
    except Exception:
        if restore_on_error:
            saver.restore(ignore_non_exist=True)
        raise
    else:
        # No checkpoint exists if `es.update` was never called.  In that
        # case the variables keep their latest values, as documented.
        saver.restore(ignore_non_exist=True)
    finally:
        if cleanup:
            # Best-effort removal: failures are logged, not raised.
            try:
                if os.path.exists(save_dir):
                    shutil.rmtree(save_dir)
            except Exception:
                getLogger(__name__).error(
                    'Failed to cleanup validation save dir %r.',
                    save_dir, exc_info=True
                )
        if not es.ever_updated:
            warnings.warn(
                'Early-stopping metric has never been updated. '
                'The variables will keep their latest values. '
                'Did you forget to add corresponding metric?'
            )
def test_save_restore(self):
    # three scalar variables `a`, `b`, `c`, one placeholder for each of
    # them, and a single op which assigns all the placeholders at once
    variables = [
        tf.get_variable(var_name, initializer=init_value, dtype=tf.int32)
        for var_name, init_value in (('a', 1), ('b', 2), ('c', 3))
    ]
    placeholders = [
        tf.placeholder(dtype=tf.int32, shape=(), name=ph_name)
        for ph_name in ('a_ph', 'b_ph', 'c_ph')
    ]
    assign_op = tf.group(*[
        tf.assign(variable, placeholder)
        for variable, placeholder in zip(variables, placeholders)
    ])
    var_a, var_b, var_c = variables

    def get_values(sess):
        return sess.run([var_a, var_b, var_c])

    def set_values(sess, a, b, c):
        sess.run(assign_op, feed_dict=dict(zip(placeholders, (a, b, c))))

    with TemporaryDirectory() as tempdir1, \
            TemporaryDirectory() as tempdir2:
        # `full_saver` covers all variables, `partial_saver` only `a`, `b`
        full_saver = VariableSaver([var_a, var_b, var_c], tempdir1)
        partial_saver = VariableSaver({'aa': var_a, 'bb': var_b}, tempdir2)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            self.assertEqual(get_values(sess), [1, 2, 3])
            set_values(sess, 10, 20, 30)
            self.assertEqual(get_values(sess), [10, 20, 30])

            # save and restore all the variables
            full_saver.save()
            set_values(sess, 100, 200, 300)
            self.assertEqual(get_values(sess), [100, 200, 300])
            full_saver.restore()
            self.assertEqual(get_values(sess), [10, 20, 30])

            # save and restore only `a` and `b`: `c` is left untouched
            partial_saver.save()
            set_values(sess, 100, 200, 300)
            self.assertEqual(get_values(sess), [100, 200, 300])
            partial_saver.restore()
            self.assertEqual(get_values(sess), [10, 20, 300])

            # the full checkpoint can still be restored afterwards
            full_saver.restore()
            self.assertEqual(get_values(sess), [10, 20, 30])

            # a second save of the partial saver takes effect on restore
            set_values(sess, 101, 201, 301)
            partial_saver.save()
            set_values(sess, 100, 200, 300)
            self.assertEqual(get_values(sess), [100, 200, 300])
            partial_saver.restore()
            self.assertEqual(get_values(sess), [101, 201, 300])
def test_errors(self):
    expected_message = 'At least 2 versions should be kept'
    with TemporaryDirectory() as checkpoint_dir:
        variable = tf.get_variable('a', initializer=1, dtype=tf.int32)

        # keeping just one version should be rejected by the constructor
        with pytest.raises(ValueError, match=expected_message):
            _ = VariableSaver([variable], checkpoint_dir, max_versions=1)