Example #1
    def load_model(self, save_dir):
        """Load the model parameters from disk.

        Parameters
        ----------
        save_dir : str
            Directory where the saved variables are placed.
        """
        self.build()
        path = os.path.abspath(save_dir)
        saver = VariableSaver(self.get_param_variables(), path)
        saver.restore(ignore_non_exist=False)
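For reference, a minimal sketch of exercising this restore path directly through VariableSaver; the variable name, directory, and prior save are illustrative assumptions:

import tensorflow as tf
from tfsnippet.utils import VariableSaver

w = tf.get_variable('w', initializer=0, dtype=tf.int32)
saver = VariableSaver([w], './saved_params')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # assumes a checkpoint was written under './saved_params' by an earlier
    # run; with ignore_non_exist=False (as above), a missing checkpoint
    # raises IOError instead of being silently skipped
    saver.restore(ignore_non_exist=False)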
Example #2
    def _enter(self):
        # open a temporary directory if the checkpoint dir is not specified
        if self._checkpoint_dir is None:
            self._temp_dir_ctx = TemporaryDirectory()
            self._checkpoint_dir = self._temp_dir_ctx.__enter__()
        else:
            makedirs(self._checkpoint_dir, exist_ok=True)

        # create the variable saver
        self._saver = VariableSaver(self._param_vars, self._checkpoint_dir)

        # return self as the context object
        return self
Example #3
    def save_model(self, save_dir, overwrite=False):
        """Save the model parameters onto disk.

        Parameters
        ----------
        save_dir : str
            Directory where to place the saved variables.

        overwrite : bool
            Whether or not to overwrite the existing directory.
        """
        self.build()
        path = os.path.abspath(save_dir)
        if os.path.exists(path):
            if overwrite:
                if os.path.isdir(path):
                    shutil.rmtree(path)
                else:
                    os.remove(path)
            elif not os.path.isdir(path) or len(os.listdir(path)) > 0:
                raise IOError('%r already exists.' % save_dir)
        saver = VariableSaver(self.get_param_variables(), path)
        saver.save()
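A round-trip sketch pairing this save with a restore, using only the VariableSaver calls shown in these examples; TemporaryDirectory is taken from the standard tempfile module here, and the variable is illustrative:

import tensorflow as tf
from tempfile import TemporaryDirectory
from tfsnippet.utils import VariableSaver

w = tf.get_variable('w', initializer=1, dtype=tf.int32)

with TemporaryDirectory() as save_dir, tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = VariableSaver([w], save_dir)
    saver.save()                  # write the current value of `w` to disk
    sess.run(tf.assign(w, 42))    # clobber the in-graph value
    saver.restore()               # `w` is 1 again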
Example #4
    def test_non_exist(self):
        with TemporaryDirectory() as tempdir:
            a = tf.get_variable('a', initializer=1, dtype=tf.int32)
            saver = VariableSaver([a], tempdir)
            with pytest.raises(
                    IOError,
                    match='Checkpoint file does not exist in directory'):
                saver.restore()
            saver.restore(ignore_non_exist=True)
Example #5
                               kernel_regularizer=K.regularizers.l2(0.001),
                               activation=tf.nn.relu),
                K.layers.Dense(100,
                               kernel_regularizer=K.regularizers.l2(0.001),
                               activation=tf.nn.relu),
            ]),
            x_dims=120,
            z_dims=5,
        )

    # To train the Donut model and use a trained model for prediction
    trainer = DonutTrainer(model=model, model_vs=model_vs)
    predictor = DonutPredictor(model)

    with tf.Session().as_default():
        # Training and saving were done in a previous run; uncomment to redo them:
        # trainer.fit(train_values, train_labels, train_missing, mean, std)
        # var_dict = get_variables_as_dict(model_vs)
        # saver = VariableSaver(var_dict, "donut_without_label_2.ckpt")
        # saver.save()

        # Restore variables from the saved checkpoint directory.
        saver = VariableSaver(get_variables_as_dict(model_vs),
                              "donut_without_label_2.ckpt")
        saver.restore()
        test_score = predictor.get_score(test_values, test_missing)
        result = np.array([test_labels[119:], test_score])
        np.savetxt('result_arti_sin2.csv',
                   result.transpose(),
                   delimiter=',',
                   fmt='%.3f')
Example #6
                    marker='^',
                    color='green',
                    label="Non Anomalies")
    plt.legend(['Non Anomalies'])
    plt.xlim(df3['timestamp'].min(), df3['timestamp'].max())
    plt.ylim(-.0006, .0006)
    plt.title('Non Anomalies from Fetal Brain Scan')
    plt.ylabel('# Direct_1')
    plt.xlabel('Date-Time')
    plt.savefig('figs/out_brain_nanomaly.png')

from tfsnippet.utils import get_variables_as_dict, VariableSaver

session = K.backend.get_session()
init = tf.global_variables_initializer()
session.run(init)

with session.as_default():

    var_dict = get_variables_as_dict(model_vs)
    # save variables to `save_dir`
    saver = VariableSaver(var_dict, save_dir)
    saver.save()
    print("Saved the model successfully")

with session.as_default():
    # Restore the model.
    saver = VariableSaver(get_variables_as_dict(model_vs), save_dir)
    saver.restore()
    print("Restored the model successfully")
Example #7
class EarlyStopping(DisposableContext):
    """
    Early-stopping context object.

    This class provides an object for memorizing the parameters at the best
    metric, in an early-stopping context.  An example of using this context:

    .. code-block:: python

        with EarlyStopping(param_vars) as es:
            ...
            es.update(loss, global_step)
            ...

    Where ``es.update(loss, global_step)`` should cause the parameters to
    be saved on disk if `loss` is better than the current best metric.
    One may also get the current best metric via ``es.best_metric``.

    Notes:
        If no metric is ever given via ``es.update``, the variables will
        keep their latest values when the early-stopping object is closed.
    """
    def __init__(self,
                 param_vars,
                 initial_metric=None,
                 checkpoint_dir=None,
                 smaller_is_better=True,
                 restore_on_error=False,
                 cleanup=True,
                 name=None):
        """
        Construct the :class:`EarlyStopping`.

        Args:
            param_vars (list[tf.Variable] or dict[str, tf.Variable]): List or
                dict of variables to be memorized. If a dict is specified, the
                keys of the dict would be used as the serialization keys via
                :class:`VariableSaver`.
            initial_metric (float or tf.Tensor or tf.Variable): The initial best
                metric (for recovering from previous session).
            checkpoint_dir (str): The directory where to save the checkpoint
                files.  If not specified, will use a temporary directory.
            smaller_is_better (bool): Whether or not it is better to have
                smaller metric values. (default :obj:`True`)
            restore_on_error (bool): Whether or not to restore the memorized
                parameters even on error. (default :obj:`False`)
            cleanup (bool): Whether or not to clean up the checkpoint directory
                on exit. This argument is ignored if `checkpoint_dir` is
                :obj:`None`, in which case the temporary directory is always
                deleted on exit.
            name (str): Name scope of all TensorFlow operations. (default
                "early_stopping").
        """
        # regularize the parameters
        if not param_vars:
            raise ValueError('`param_vars` must not be empty')

        if isinstance(initial_metric, (tf.Tensor, tf.Variable)):
            initial_metric = initial_metric.eval()

        if checkpoint_dir is not None:
            checkpoint_dir = os.path.abspath(checkpoint_dir)

        # memorize the parameters
        self._param_vars = copy.copy(param_vars)
        self._checkpoint_dir = checkpoint_dir
        self._smaller_is_better = smaller_is_better
        self._restore_on_error = restore_on_error
        self._cleanup = cleanup
        self._name = name

        # internal states of the object
        self._best_metric = initial_metric
        self._ever_updated = False
        self._temp_dir_ctx = None
        self._saver = None  # type: VariableSaver

    def _enter(self):
        # open a temporary directory if the checkpoint dir is not specified
        if self._checkpoint_dir is None:
            self._temp_dir_ctx = TemporaryDirectory()
            self._checkpoint_dir = self._temp_dir_ctx.__enter__()
        else:
            makedirs(self._checkpoint_dir, exist_ok=True)

        # create the variable saver
        self._saver = VariableSaver(self._param_vars, self._checkpoint_dir)

        # return self as the context object
        return self

    def _exit(self, exc_type, exc_val, exc_tb):
        try:
            # restore the variables
            # exc_info = (exc_type, exc_val, exc_tb)
            if exc_type is None or exc_type is KeyboardInterrupt or \
                    self._restore_on_error:
                self._saver.restore(ignore_non_exist=True)

        finally:
            # cleanup the checkpoint directory
            try:
                if self._temp_dir_ctx is not None:
                    self._temp_dir_ctx.__exit__(exc_type, exc_val, exc_tb)
                elif self._cleanup:
                    if os.path.exists(self._checkpoint_dir):
                        shutil.rmtree(self._checkpoint_dir)
            except Exception:  # pragma: no cover
                getLogger(__name__).error(
                    'Failed to cleanup validation save dir %r.',
                    self._checkpoint_dir,
                    exc_info=True)

            # warning if metric never updated
            if not self._ever_updated:
                warnings.warn('Early-stopping metric has never been updated. '
                              'The variables will keep their latest values. '
                              'Did you forget to add the corresponding metric?')

    def update(self, metric, global_step=None):
        """
        Update the best metric.

        Args:
            metric (float): New metric value.
            global_step (int): Optional global step counter.

        Returns:
            bool: Whether or not the best metric has been updated.
        """
        self._require_entered()
        self._ever_updated = True
        if self._best_metric is None or \
                (self._smaller_is_better and metric < self._best_metric) or \
                (not self._smaller_is_better and metric > self._best_metric):
            self._saver.save(global_step)
            self._best_metric = metric
            return True
        return False

    @property
    def best_metric(self):
        """Get the current best loss."""
        return self._best_metric

    @property
    def ever_updated(self):
        """Check whether or not `update` method has ever been called."""
        return self._ever_updated
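A compact training-loop sketch of how this context is intended to be used; `param_vars`, `n_steps`, and `compute_validation_loss` are hypothetical placeholders:

with EarlyStopping(param_vars) as es:
    for step in range(n_steps):                  # hypothetical step count
        ...                                      # run one training step
        val_loss = compute_validation_loss()     # hypothetical helper
        if es.update(val_loss, global_step=step):
            print('new best metric:', es.best_metric)
# on normal exit, the parameters saved at the best metric are restored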
Example #8
def early_stopping(param_vars, initial_metric=None, save_dir=None,
                   smaller_is_better=True, restore_on_error=False,
                   cleanup=True, name=None):
    """Open a context to memorize the values of parameters at best metric.

    This method will open a context with an object to memorize the best
    metric for early-stopping.  An example of using this early-stopping
    context is::

        with early_stopping(param_vars) as es:
            ...
            es.update(loss, global_step)
            ...

    Where ``es.update(loss, global_step)`` should cause the parameters to
    be saved on disk if `loss` is better than the current best metric.
    One may also get the best metric via ``es.best_metric``.

    Note that if no loss is given via ``es.update``, the variables will
    keep their latest values when exiting the early-stopping context.

    Parameters
    ----------
    param_vars : list[tf.Variable] | dict[str, tf.Variable]
        List or dict of variables to be memorized.

        If a dict is specified, the keys of the dict would be used as the
        serialization keys via `VariableSaver`.

    initial_metric : float | tf.Tensor | tf.Variable
        The initial best loss (usually for recovering from a previous session).

    save_dir : str
        The directory where to save the variable values.
        If not specified, will use a temporary directory.

    smaller_is_better : bool
        Whether or not smaller losses are considered better. (default True)

    restore_on_error : bool
        Whether or not to restore the memorized parameters even on error.
        (default False)

    cleanup : bool
        Whether or not to clean up the saving directory on exit.

        This argument is ignored if `save_dir` is None, in which case
        the temporary directory is always deleted on exit.

    name : str
        Optional name of this scope.

    Yields
    ------
    _EarlyStopping
        The object that receives losses during the early-stopping context.
    """
    if not param_vars:
        raise ValueError('`param_vars` must not be empty.')

    if save_dir is None:
        with TemporaryDirectory() as tempdir:
            with early_stopping(param_vars, initial_metric=initial_metric,
                                save_dir=tempdir, cleanup=False,
                                smaller_is_better=smaller_is_better,
                                restore_on_error=restore_on_error,
                                name=name) as es:
                yield es

    else:
        if isinstance(initial_metric, (tf.Tensor, tf.Variable)):
            initial_metric = initial_metric.eval()

        with tf.name_scope(name):
            saver = VariableSaver(param_vars, save_dir)
            save_dir = os.path.abspath(save_dir)
            makedirs(save_dir, exist_ok=True)

            es = _EarlyStopping(saver,
                                best_metric=initial_metric,
                                smaller_is_better=smaller_is_better)

            try:
                yield es
            except Exception as ex:
                if isinstance(ex, KeyboardInterrupt) or restore_on_error:
                    saver.restore()
                raise
            else:
                saver.restore()
            finally:
                if cleanup:
                    try:
                        if os.path.exists(save_dir):
                            shutil.rmtree(save_dir)
                    except Exception:
                        getLogger(__name__).error(
                            'Failed to cleanup validation save dir %r.',
                            save_dir, exc_info=True
                        )
                if not es.ever_updated:
                    warnings.warn(
                        'Early-stopping metric has never been updated. '
                        'The variables will keep their latest values. '
                        'Did you forget to add the corresponding metric?'
                    )
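Since this function yields, it is presumably decorated with contextlib.contextmanager (the decorator sits above the excerpt); usage then mirrors the class-based context. A sketch with hypothetical helpers:

with early_stopping(param_vars, save_dir='./best_params') as es:
    for step in range(n_steps):          # hypothetical step count
        loss = train_one_step()          # hypothetical helper
        es.update(loss, global_step=step)
# on exit, the restore in the `else` branch above puts the best
# parameters back in place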
Example #9
    def test_save_restore(self):
        a = tf.get_variable('a', initializer=1, dtype=tf.int32)
        b = tf.get_variable('b', initializer=2, dtype=tf.int32)
        c = tf.get_variable('c', initializer=3, dtype=tf.int32)
        a_ph = tf.placeholder(dtype=tf.int32, shape=(), name='a_ph')
        b_ph = tf.placeholder(dtype=tf.int32, shape=(), name='b_ph')
        c_ph = tf.placeholder(dtype=tf.int32, shape=(), name='c_ph')
        assign_op = tf.group(tf.assign(a, a_ph), tf.assign(b, b_ph),
                             tf.assign(c, c_ph))

        def get_values(sess):
            return sess.run([a, b, c])

        def set_values(sess, a, b, c):
            sess.run(assign_op, feed_dict={a_ph: a, b_ph: b, c_ph: c})

        with TemporaryDirectory() as tempdir1, \
                TemporaryDirectory() as tempdir2:
            saver1 = VariableSaver([a, b, c], tempdir1)
            saver2 = VariableSaver({'aa': a, 'bb': b}, tempdir2)

            with self.test_session() as sess:
                sess.run(tf.global_variables_initializer())
                self.assertEqual(get_values(sess), [1, 2, 3])
                set_values(sess, 10, 20, 30)
                self.assertEqual(get_values(sess), [10, 20, 30])

                saver1.save()
                set_values(sess, 100, 200, 300)
                self.assertEqual(get_values(sess), [100, 200, 300])
                saver1.restore()
                self.assertEqual(get_values(sess), [10, 20, 30])

                saver2.save()
                set_values(sess, 100, 200, 300)
                self.assertEqual(get_values(sess), [100, 200, 300])
                saver2.restore()
                # saver2 tracks only `a` and `b`, so `c` keeps its current value
                self.assertEqual(get_values(sess), [10, 20, 300])

                saver1.restore()
                self.assertEqual(get_values(sess), [10, 20, 30])

                set_values(sess, 101, 201, 301)
                saver2.save()
                set_values(sess, 100, 200, 300)
                self.assertEqual(get_values(sess), [100, 200, 300])
                saver2.restore()
                self.assertEqual(get_values(sess), [101, 201, 300])
Example #10
    def test_errors(self):
        with TemporaryDirectory() as tempdir:
            a = tf.get_variable('a', initializer=1, dtype=tf.int32)
            with pytest.raises(ValueError,
                               match='At least 2 versions should be kept'):
                _ = VariableSaver([a], tempdir, max_versions=1)
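The error message above implies that VariableSaver can retain several checkpoint versions; a minimal valid call, with the retention behavior inferred from that message and therefore an assumption:

# keeps at most the two most recent checkpoint versions (inferred from
# the 'At least 2 versions should be kept' error above)
saver = VariableSaver([a], tempdir, max_versions=2)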