예제 #1
0
    def save_weights(self, filepath, format=None):
        """Input filepath, save model weights into a file of given format.
            Use self.load_weights() to restore.

        Parameters
        ----------
        filepath : str
            Filename to which the model weights will be saved.
        format : str or None
            Saved file format.
            Value should be None, 'hdf5' (alias 'h5'), 'npz', 'npz_dict' or 'ckpt'.
            Other formats are not supported now.
            1) If this is set to None, the extension of filepath (compared case-insensitively)
            decides the saved format. If the extension is not one of 'h5', 'hdf5', 'npz' or
            'ckpt', the file is saved in hdf5 format by default. 'npz_dict' is never inferred
            from the extension and must be requested explicitly.
            2) 'hdf5' will save model weights name in a list and each layer has its weights
            stored in a group of the hdf5 file.
            3) 'npz' will save model weights sequentially into a npz file.
            4) 'npz_dict' will save model weights along with its name as a dict into a npz file.
            5) 'ckpt' is reserved for tensorflow ckpt files but is not implemented yet.

            Default None.

        Returns
        -------
        None
            Nothing is returned. If the model has no weights (``self.all_weights`` is None
            or empty), a warning is logged and no file is written.

        Raises
        ------
        NotImplementedError
            If the resolved format is 'ckpt'.
        ValueError
            If the resolved format is not one of the supported values.

        Examples
        --------
        1) Save model weights in hdf5 format by default.
        >>> net = tl.models.vgg16()
        >>> net.save_weights('./model.h5')
        ...
        >>> net.load_weights('./model.h5')

        2) Save model weights in npz/npz_dict format
        >>> net = tl.models.vgg16()
        >>> net.save_weights('./model.npz')
        >>> net.save_weights('./model.npz', format='npz_dict')

        """
        import os  # local import keeps this method self-contained

        if self.all_weights is None or len(self.all_weights) == 0:
            logging.warning("Model contains no weights or layers haven't been built, nothing will be saved")
            return

        if format is None:
            # splitext only inspects the last path component, so dotted directory names or an
            # extension-less file name are never mistaken for an extension; lower() lets e.g.
            # 'model.NPZ' resolve to 'npz' instead of silently falling back to hdf5.
            postfix = os.path.splitext(filepath)[1][1:].lower()
            format = postfix if postfix in ('h5', 'hdf5', 'npz', 'ckpt') else 'hdf5'

        if format in ('hdf5', 'h5'):
            utils.save_weights_to_hdf5(filepath, self)
        elif format == 'npz':
            utils.save_npz(self.all_weights, filepath)
        elif format == 'npz_dict':
            utils.save_npz_dict(self.all_weights, filepath)
        elif format == 'ckpt':
            # TODO: enable this when tf save ckpt is enabled
            raise NotImplementedError("ckpt load/save is not supported now.")
        else:
            raise ValueError(
                "Save format must be 'hdf5', 'npz', 'npz_dict' or 'ckpt'. "
                "Other format is not supported now. Got {!r}.".format(format)
            )
예제 #2
0
    def save_weights(self, filepath, sess=None, format='hdf5'):
        """Input filepath and the session(optional), save model weights into a file of given format.
            Use self.load_weights() to restore.

        Parameters
        ----------
        filepath : str
            Filename to which the model weights will be saved.
        sess : None or a tensorflow session
            In eager mode, this should be left as None. In graph mode, must specify it with a tensorflow session.
        format : str
            Saved file format. Value should be 'hdf5', 'npz', 'npz_dict' or 'ckpt'.
            Other formats are not supported now. Default 'hdf5'.
            'hdf5' will save model weights name in a list and each layer has its weights stored in a group of
            the hdf5 file.
            'npz' will save model weights sequentially into a npz file.
            'npz_dict' will save model weights along with its name as a dict into a npz file.
            'ckpt' is reserved for tensorflow ckpt files but is not implemented yet.

        Returns
        -------
        None
            Nothing is returned. If the model has no weights (``self.weights`` is None or
            empty), a warning is logged and no file is written.

        Raises
        ------
        NotImplementedError
            If format is 'ckpt'.
        ValueError
            If format is not one of the supported values.

        Examples
        --------
        1) Save model to hdf5 in eager mode
        >>> net = tl.models.vgg.vgg16()
        >>> net.save_weights('./model.h5')

        2) Save model to npz in graph mode
        >>> sess = tf.Session()
        >>> sess.run(tf.global_variables_initializer())
        >>> net.save_weights('./model.npz', sess=sess, format='npz')

        """
        # An empty weight list is treated like "not built": there is nothing to write.
        if self.weights is None or len(self.weights) == 0:
            logging.warning(
                "Model contains no weights or layers haven't been built, nothing will be saved"
            )
            return

        if format == 'hdf5':
            utils.save_weights_to_hdf5(filepath, self.weights, sess)
        elif format == 'npz':
            utils.save_npz(self.weights, filepath, sess)
        elif format == 'npz_dict':
            utils.save_npz_dict(self.weights, filepath, sess)
        elif format == 'ckpt':
            # TODO: enable this when tf save ckpt is enabled
            raise NotImplementedError("ckpt load/save is not supported now.")
        else:
            raise ValueError(
                "Save format must be 'hdf5', 'npz', 'npz_dict' or 'ckpt'. "
                "Other format is not supported now. Got {!r}.".format(format))
예제 #3
0
    def save_weights(self, filepath, format=None):
        """Input filepath, save model weights into a file of given format.
            Use self.load_weights() to restore.

        Parameters
        ----------
        filepath : str
            Filename to which the model weights will be saved.
        format : str or None
            Saved file format.
            Value should be None, 'hdf5' (alias 'h5'), 'npz', 'npz_dict' or 'ckpt'.
            Other formats are not supported now.
            1) If this is set to None, the extension of filepath (compared case-insensitively)
            decides the saved format. If the extension is not one of 'h5', 'hdf5', 'npz' or
            'ckpt', the file is saved in hdf5 format by default. 'npz_dict' is never inferred
            from the extension and must be requested explicitly.
            2) 'hdf5' will save model weights name in a list and each layer has its weights
            stored in a group of the hdf5 file.
            3) 'npz' will save model weights sequentially into a npz file.
            4) 'npz_dict' will save model weights along with its name as a dict into a npz file.
            5) 'ckpt' is reserved for tensorflow ckpt files but is not implemented yet.

            Default None.

        Returns
        -------
        None
            Nothing is returned. If the model has no weights (``self.weights`` is None or
            empty), a warning is logged and no file is written.

        Raises
        ------
        NotImplementedError
            If the resolved format is 'ckpt'.
        ValueError
            If the resolved format is not one of the supported values.

        Examples
        --------
        1) Save model weights in hdf5 format by default.
        >>> net = tl.models.vgg16()
        >>> net.save_weights('./model.h5')
        ...
        >>> net.load_weights('./model.h5')

        2) Save model weights in npz/npz_dict format
        >>> net = tl.models.vgg16()
        >>> net.save_weights('./model.npz')
        >>> net.save_weights('./model.npz', format='npz_dict')

        """
        import os  # local import keeps this method self-contained

        if self.weights is None or len(self.weights) == 0:
            logging.warning("Model contains no weights or layers haven't been built, nothing will be saved")
            return

        if format is None:
            # splitext only inspects the last path component, so dotted directory names or an
            # extension-less file name are never mistaken for an extension; lower() lets e.g.
            # 'model.NPZ' resolve to 'npz' instead of silently falling back to hdf5.
            postfix = os.path.splitext(filepath)[1][1:].lower()
            format = postfix if postfix in ('h5', 'hdf5', 'npz', 'ckpt') else 'hdf5'

        if format in ('hdf5', 'h5'):
            utils.save_weights_to_hdf5(filepath, self)
        elif format == 'npz':
            utils.save_npz(self.weights, filepath)
        elif format == 'npz_dict':
            utils.save_npz_dict(self.weights, filepath)
        elif format == 'ckpt':
            # TODO: enable this when tf save ckpt is enabled
            raise NotImplementedError("ckpt load/save is not supported now.")
        else:
            raise ValueError(
                "Save format must be 'hdf5', 'npz', 'npz_dict' or 'ckpt'. "
                "Other format is not supported now. Got {!r}.".format(format)
            )