Example #1
    def save_model(self, save_dir, overwrite=False):
        """Save the model parameters onto disk.

        Parameters
        ----------
        save_dir : str
            Directory where to place the saved variables.

        overwrite : bool
            Whether to overwrite an existing file or directory at the path.
        """
        self.build()
        path = os.path.abspath(save_dir)
        if os.path.exists(path):
            if overwrite:
                # remove whatever already exists at `path`, file or directory
                if os.path.isdir(path):
                    shutil.rmtree(path)
                else:
                    os.remove(path)
            elif not os.path.isdir(path) or len(os.listdir(path)) > 0:
                # refuse to touch a file or non-empty directory unless
                # `overwrite` is set
                raise IOError('%r already exists.' % save_dir)
        saver = VariableSaver(self.get_param_variables(), path)
        saver.save()
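
A minimal usage sketch for the method above (`MyModel` and the checkpoint
path are hypothetical; the method only assumes the model exposes `build()`
and `get_param_variables()`, as in the code above):

    model = MyModel()
    model.save_model('./checkpoint')                  # raises IOError if the path holds a file or non-empty directory
    model.save_model('./checkpoint', overwrite=True)  # replaces whatever exists at the path
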
Example #2
    def test_save_restore(self):
        # (a method of a ``tf.test.TestCase`` subclass, hence ``self.test_session()``)
        a = tf.get_variable('a', initializer=1, dtype=tf.int32)
        b = tf.get_variable('b', initializer=2, dtype=tf.int32)
        c = tf.get_variable('c', initializer=3, dtype=tf.int32)
        a_ph = tf.placeholder(dtype=tf.int32, shape=(), name='a_ph')
        b_ph = tf.placeholder(dtype=tf.int32, shape=(), name='b_ph')
        c_ph = tf.placeholder(dtype=tf.int32, shape=(), name='c_ph')
        assign_op = tf.group(tf.assign(a, a_ph), tf.assign(b, b_ph),
                             tf.assign(c, c_ph))

        def get_values(sess):
            return sess.run([a, b, c])

        def set_values(sess, va, vb, vc):
            sess.run(assign_op, feed_dict={a_ph: va, b_ph: vb, c_ph: vc})

        with TemporaryDirectory() as tempdir1, \
                TemporaryDirectory() as tempdir2:
            saver1 = VariableSaver([a, b, c], tempdir1)
            # saver2 memorizes only `a` and `b`, under custom serialization keys
            saver2 = VariableSaver({'aa': a, 'bb': b}, tempdir2)

            with self.test_session() as sess:
                sess.run(tf.global_variables_initializer())
                self.assertEqual(get_values(sess), [1, 2, 3])
                set_values(sess, 10, 20, 30)
                self.assertEqual(get_values(sess), [10, 20, 30])

                saver1.save()
                set_values(sess, 100, 200, 300)
                self.assertEqual(get_values(sess), [100, 200, 300])
                saver1.restore()
                self.assertEqual(get_values(sess), [10, 20, 30])

                saver2.save()
                set_values(sess, 100, 200, 300)
                self.assertEqual(get_values(sess), [100, 200, 300])
                saver2.restore()
                self.assertEqual(get_values(sess), [10, 20, 300])

                saver1.restore()
                self.assertEqual(get_values(sess), [10, 20, 30])

                set_values(sess, 101, 201, 301)
                saver2.save()
                set_values(sess, 100, 200, 300)
                self.assertEqual(get_values(sess), [100, 200, 300])
                saver2.restore()
                self.assertEqual(get_values(sess), [101, 201, 300])
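
Note the partial-restore semantics exercised above: `saver2` memorizes only
`a` and `b`, so `saver2.restore()` leaves `c` at whatever value the session
currently holds. That is why the test expects `[10, 20, 300]` and
`[101, 201, 300]` instead of a full rollback.
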
Example #3
    # reconstructed head of a truncated `plt.scatter` call; the plotted
    # column name is an assumption based on the y-axis label below
    plt.scatter(df3['timestamp'], df3['Direct_1'],
                marker='^',
                color='green',
                label="Non Anomalies")
    plt.legend(['Non Anomalies'])
    plt.xlim(df3['timestamp'].min(), df3['timestamp'].max())
    plt.ylim(-.0006, .0006)
    plt.title('Non Anomalies from Fetal Brain Scan')
    plt.ylabel('# Direct_1')
    plt.xlabel('Date-Time')
    plt.savefig('figs/out_brain_nanomaly.png')

import tensorflow as tf
import keras as K  # assumed: `K` is the Keras package, used as `K.backend` below

from tfsnippet.utils import get_variables_as_dict, VariableSaver

session = K.backend.get_session()
init = tf.global_variables_initializer()
session.run(init)

with session.as_default():

    # `model_vs` (the model's variable scope) and `save_dir` are assumed
    # to be defined earlier in the script
    var_dict = get_variables_as_dict(model_vs)
    # save variables to `save_dir`
    saver = VariableSaver(var_dict, save_dir)
    saver.save()
    print("Saved the model successfully")

with session.as_default():
    # Restore the model.
    saver = VariableSaver(get_variables_as_dict(model_vs), save_dir)
    saver.restore()
    print("Restored the model successfully")
Example #4
class EarlyStopping(DisposableContext):
    """
    Early-stopping context object.

    This class provides an object for memorizing the parameters of the best
    metric within an early-stopping context.  An example of using this context:

    .. code-block:: python

        with EarlyStopping(param_vars) as es:
            ...
            es.update(loss, global_step)
            ...

    Here, ``es.update(loss, global_step)`` saves the parameters to disk
    whenever `loss` is better than the current best metric.  One may also
    get the current best metric via ``es.best_metric``.

    Notes:
        If no metric is ever given via ``es.update``, the variables will
        keep their latest values when the early-stopping context closes.
    """
    def __init__(self,
                 param_vars,
                 initial_metric=None,
                 checkpoint_dir=None,
                 smaller_is_better=True,
                 restore_on_error=False,
                 cleanup=True,
                 name=None):
        """
        Construct the :class:`EarlyStopping`.

        Args:
            param_vars (list[tf.Variable] or dict[str, tf.Variable]): List or
                dict of variables to be memorized.  If a dict is specified,
                its keys will be used as the serialization keys via
                :class:`VariableSaver`.
            initial_metric (float or tf.Tensor or tf.Variable): The initial best
                metric (for recovering from previous session).
            checkpoint_dir (str): The directory where to save the checkpoint
                files.  If not specified, will use a temporary directory.
            smaller_is_better (bool): Whether smaller metric values are
                considered better. (default :obj:`True`)
            restore_on_error (bool): Whether to restore the memorized
                parameters even when an error occurs. (default :obj:`False`)
            cleanup (bool): Whether to clean up the checkpoint directory on
                exit.  This argument is ignored if `checkpoint_dir` is
                :obj:`None`, in which case the temporary directory is always
                deleted on exit.
            name (str): Name scope of all TensorFlow operations. (default
                "early_stopping").
        """
        # validate and normalize the arguments
        if not param_vars:
            raise ValueError('`param_vars` must not be empty')

        if isinstance(initial_metric, (tf.Tensor, tf.Variable)):
            initial_metric = initial_metric.eval()

        if checkpoint_dir is not None:
            checkpoint_dir = os.path.abspath(checkpoint_dir)

        # memorize the parameters
        self._param_vars = copy.copy(param_vars)
        self._checkpoint_dir = checkpoint_dir
        self._smaller_is_better = smaller_is_better
        self._restore_on_error = restore_on_error
        self._cleanup = cleanup
        self._name = name

        # internal states of the object
        self._best_metric = initial_metric
        self._ever_updated = False
        self._temp_dir_ctx = None
        self._saver = None  # type: VariableSaver

    def _enter(self):
        # open a temporary directory if the checkpoint dir is not specified
        if self._checkpoint_dir is None:
            self._temp_dir_ctx = TemporaryDirectory()
            self._checkpoint_dir = self._temp_dir_ctx.__enter__()
        else:
            makedirs(self._checkpoint_dir, exist_ok=True)

        # create the variable saver
        self._saver = VariableSaver(self._param_vars, self._checkpoint_dir)

        # return self as the context object
        return self

    def _exit(self, exc_type, exc_val, exc_tb):
        try:
            # restore the best parameters on success, on KeyboardInterrupt,
            # or on any other error if `restore_on_error` is set
            if exc_type is None or exc_type is KeyboardInterrupt or \
                    self._restore_on_error:
                self._saver.restore(ignore_non_exist=True)

        finally:
            # cleanup the checkpoint directory
            try:
                if self._temp_dir_ctx is not None:
                    self._temp_dir_ctx.__exit__(exc_type, exc_val, exc_tb)
                elif self._cleanup:
                    if os.path.exists(self._checkpoint_dir):
                        shutil.rmtree(self._checkpoint_dir)
            except Exception:  # pragma: no cover
                getLogger(__name__).error(
                    'Failed to cleanup validation save dir %r.',
                    self._checkpoint_dir,
                    exc_info=True)

            # warning if metric never updated
            if not self._ever_updated:
                warnings.warn('Early-stopping metric has never been updated. '
                              'The variables will keep their latest values. '
                              'Did you forget to update the metric?')

    def update(self, metric, global_step=None):
        """
        Update the best metric.

        Args:
            metric (float): New metric value.
            global_step (int): Optional global step counter.

        Returns:
            bool: Whether the best metric was updated by this call.
        """
        self._require_entered()
        self._ever_updated = True
        if self._best_metric is None or \
                (self._smaller_is_better and metric < self._best_metric) or \
                (not self._smaller_is_better and metric > self._best_metric):
            self._saver.save(global_step)
            self._best_metric = metric
            return True
        return False

    @property
    def best_metric(self):
        """Get the current best loss."""
        return self._best_metric

    @property
    def ever_updated(self):
        """Check whether or not `update` method has ever been called."""
        return self._ever_updated
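
A sketch of how this context is typically driven from a training loop. The
`sess`, `train_op` and `loss` names below are assumptions standing in for an
active session and the graph's training op and loss tensor; only
`EarlyStopping` itself comes from the code above:

    param_vars = tf.trainable_variables()
    with EarlyStopping(param_vars, smaller_is_better=True) as es:
        for step in range(1, 1001):
            _, loss_val = sess.run([train_op, loss])  # assumed training step
            if step % 100 == 0:
                # saves a checkpoint and returns True only on improvement
                if es.update(loss_val, global_step=step):
                    print('step %d: new best loss %.6f' % (step, es.best_metric))
    # on leaving the block, the parameters are restored from the best checkpoint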