Example #1
0
    def load(cls, basename: str, **kwargs) -> 'ClassifierModelBase':
        """Reload the model from a graph file and a checkpoint

        The model that is loaded is independent of the pooling and stacking layers, making this class reusable
        by sub-classes.

        :param basename: The base directory to load from
        :param kwargs: See below

        :Keyword Arguments:
        * *sess* -- An optional tensorflow session.  If not passed, a new session is
            created (graph mode only)
        * *init* -- (graph mode only) Run the global variable initializer before
            restoring the checkpoint (default ``True``)

        :return: A restored model
        """
        _state = read_json("{}.state".format(basename))
        if __version__ != _state['version']:
            logger.warning(
                "Loaded model is from baseline version %s, running version is %s",
                _state['version'], __version__)

        def _rebuild() -> 'ClassifierModelBase':
            """Reconstruct the model object; shared by the graph and eager paths."""
            embeddings_info = _state.pop('embeddings')
            embeddings = reload_embeddings(embeddings_info, basename)
            # If there is a kwarg that is the same name as an embedding object that
            # is taken to be the input of that layer. This allows for passing in
            # subgraphs like from a tf.split (for data parallel) or preprocessing
            # graphs that convert text to indices
            for k in embeddings_info:
                if k in kwargs:
                    _state[k] = kwargs[k]
            # TODO: convert labels into just another vocab and pass number of labels to models.
            labels = read_json("{}.labels".format(basename))
            m = cls.create(embeddings, labels, **_state)
            m._state = _state
            return m

        if not tf.executing_eagerly():
            # Graph mode: rebuild inside the session's graph, then restore via a Saver
            _state['sess'] = kwargs.pop('sess', create_session())
            with _state['sess'].graph.as_default():
                model = _rebuild()
                if kwargs.get('init', True):
                    model.sess.run(tf.compat.v1.global_variables_initializer())
                model.saver = tf.compat.v1.train.Saver()
                model.saver.restore(model.sess, basename)
        else:
            # Eager mode: weights live in a Keras-style weight file
            model = _rebuild()
            model.load_weights(f"{basename}.wgt")
        return model
Example #2
0
    def load(cls, basename: str, **kwargs) -> 'DependencyParserModelBase':
        """Reload the model from a graph file and a checkpoint

        The model that is loaded is independent of the pooling and stacking layers, making this class reusable
        by sub-classes.

        :param basename: The base directory to load from
        :param kwargs: See below

        :Keyword Arguments:
        * *sess* -- An optional tensorflow session.  If not passed, a new session is
            created

        :return: A restored model
        """
        _state = read_json("{}.state".format(basename))
        saved_version = _state['version']
        if saved_version != __version__:
            logger.warning(
                "Loaded model is from baseline version %s, running version is %s",
                saved_version, __version__)

        emb_config = _state.pop('embeddings')
        restored_embeddings = reload_embeddings(emb_config, basename)
        # A kwarg sharing a name with an embedding object is taken as the input of
        # that layer, so callers can wire in subgraphs (e.g. tf.split outputs for
        # data parallelism, or preprocessing graphs that convert text to indices)
        _state.update({name: kwargs[name] for name in emb_config if name in kwargs})
        labels = {"labels": read_json("{}.labels".format(basename))}
        model = cls.create(restored_embeddings, labels, **_state)
        model._state = _state
        model.load_weights(f"{basename}.wgt")
        return model
Example #3
0
 def load(cls, basename, **kwargs):
     """Reload the model from a state file and a TF checkpoint.

     :param basename: The base path of the saved state/checkpoint files
     :param kwargs: Supports ``sess`` (an existing tf session) and ``init``
         (passed through to ``create``; presumably controls initialization --
         TODO confirm against ``create``)
     :return: A restored model
     """
     _state = read_json("{}.state".format(basename))
     if _state['version'] != __version__:
         bl_logger.warning("Loaded model is from baseline version %s, running version is %s", _state['version'], __version__)
     sess = kwargs.pop('sess', create_session())
     _state['sess'] = sess
     with sess.graph.as_default():
         emb_info = _state.pop('embeddings')
         embeddings = reload_embeddings(emb_info, basename)
         # A kwarg named after an embedding object overrides that layer's input
         for name in emb_info:
             if name in kwargs:
                 _state[name] = kwargs[name]
         model = cls.create(embeddings, init=kwargs.get('init', True), **_state)
         model._state = _state
         model.saver = tf.train.Saver()
         model.saver.restore(model.sess, basename)
     return model
Example #4
0
    def load(cls, basename, **kwargs):
        """Reload the model from a graph file and a checkpoint

        The model that is loaded is independent of the pooling and stacking layers, making this class reusable
        by sub-classes.

        :param basename: The base directory to load from
        :param kwargs: See below

        :Keyword Arguments:
        * *sess* -- An optional tensorflow session.  If not passed, a new session is
            created

        :return: A restored model
        """
        _state = read_json("{}.state".format(basename))
        if _state['version'] != __version__:
            logger.warning("Loaded model is from baseline version %s, running version is %s", _state['version'], __version__)
        sess = kwargs.pop('sess', tf.Session())
        _state['sess'] = sess
        with sess.graph.as_default():
            emb_info = _state.pop('embeddings')
            embeddings = reload_embeddings(emb_info, basename)
            # A kwarg sharing a name with an embedding object is taken as the input
            # of that layer, so callers can pass in subgraphs (e.g. from a tf.split
            # for data parallelism, or preprocessing graphs mapping text to indices)
            _state.update({name: kwargs[name] for name in emb_info if name in kwargs})
            # TODO: convert labels into just another vocab and pass number of labels to models.
            labels = read_json("{}.labels".format(basename))
            model = cls.create(embeddings, labels, **_state)
            model._state = _state
            if kwargs.get('init', True):
                model.sess.run(tf.global_variables_initializer())
            model.saver = tf.train.Saver()
            model.saver.restore(model.sess, basename)
            return model