Example #1
    def load_rbm_params(self, param_file):
        """
        Loads the parameters for the RBM only from param_file - used if the RBM was pre-trained

        :param param_file: location of rbm parameters
        :type param_file: string

        :return: whether successful
        :rtype: boolean
        """
        param_file = os.path.realpath(param_file)

        # make sure it is a pickle file
        ftype = file_ops.get_file_type(param_file)
        if ftype == file_ops.PKL:
            log.debug("loading model %s parameters from %s",
                      str(type(self)), str(param_file))
            # try to grab the pickled params from the specified param_file path
            with open(param_file, 'rb') as f:
                loaded_params = pickle.load(f)
            #############################################################################
            # set the W, bv, and bh values (make sure same order as saved in RBM class) #
            #############################################################################
            self.W.set_value(loaded_params[0])
            self.bv.set_value(loaded_params[1])
            self.bh.set_value(loaded_params[2])
            return True
        # if get_file_type didn't return pkl or none, it wasn't a pickle file
        elif ftype:
            log.error("Param file %s doesn't have a supported pickle extension!", str(param_file))
            return False
        # if get_file_type returned none, it couldn't find the file
        else:
            log.error("Param file %s couldn't be found!", str(param_file))
            return False
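
The loader above expects the pickle to hold the parameter arrays in the fixed order [W, bv, bh]. A minimal save-side sketch under that assumption (the `rbm` instance and the `rbm_params.pkl` filename are hypothetical, not taken from the example):

import pickle

# Hypothetical: `rbm` has W, bv, and bh shared variables, as the set_value() calls above imply.
params = [rbm.W.get_value(), rbm.bv.get_value(), rbm.bh.get_value()]

# Dump in the same [W, bv, bh] order that load_rbm_params reads back.
with open('rbm_params.pkl', 'wb') as f:
    pickle.dump(params, f)

# Later, restore into a fresh model; the method returns True on success.
assert rbm.load_rbm_params('rbm_params.pkl')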
Example #2
    def load_params(self, param_file):
        """
        This loads the model's parameters from the param_file (pickle file)

        :param param_file: filename of pickled params file
        :type param_file: String

        :return: whether or not successful
        :rtype: Boolean
        """
        param_file = os.path.realpath(param_file)

        # make sure it is a pickle file
        ftype = file_ops.get_file_type(param_file)
        if ftype == file_ops.PKL:
            log.debug("loading model %s parameters from %s",
                      str(type(self)), str(param_file))
            # try to grab the pickled params from the specified param_file path
            with open(param_file, 'rb') as f:
                loaded_params = pickle.load(f)
            self.set_param_values(loaded_params)
            return True
        # if get_file_type didn't return pkl or none, it wasn't a pickle file
        elif ftype:
            log.error("Param file %s doesn't have a supported pickle extension!", str(param_file))
            return False
        # if get_file_type returned none, it couldn't find the file
        else:
            log.error("Param file %s couldn't be found!", str(param_file))
            return False
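
A matching save-side sketch for this loader, assuming the model exposes a get_param_values() counterpart to the set_param_values() call above (that name is an assumption, not shown in the example):

import pickle

# Assumed counterpart to set_param_values(); would return the list of parameter arrays.
values = model.get_param_values()

with open('model_params.pkl', 'wb') as f:
    pickle.dump(values, f)

# load_params() reports success as a boolean, per the docstring.
if not model.load_params('model_params.pkl'):
    raise RuntimeError("Could not reload the model parameters.")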
Example #3
    def load_params(self, param_file):
        """
        This loads the model's parameters from the param_file (pickle file)

        :param param_file: filename of pickled params file
        :type param_file: String

        :return: whether or not successful
        :rtype: Boolean
        """
        param_file = os.path.realpath(param_file)

        # make sure it is a pickle file
        ftype = file_ops.get_file_type(param_file)
        if ftype == file_ops.PKL:
            log.debug("loading model %s parameters from %s...",
                      str(type(self)), str(param_file))
            # try to grab the pickled params from the specified param_file path
            with open(param_file, 'rb') as f:
                loaded_params = cPickle.load(f)
            self.set_param_values(loaded_params)
            return True
        # if get_file_type didn't return pkl or none, it wasn't a pickle file
        elif ftype:
            log.error(
                "Param file %s doesn't have a supported pickle extension!",
                str(param_file))
            return False
        # if get_file_type returned none, it couldn't find the file
        else:
            log.error("Param file %s couldn't be found!", str(param_file))
            return False
Example #4
    def load_rbm_params(self, param_file):
        """
        Loads the parameters for the RBM only from param_file - used if the RBM was pre-trained

        :param param_file: location of rbm parameters
        :type param_file: string

        :return: whether successful
        :rtype: boolean
        """
        param_file = os.path.realpath(param_file)

        # make sure it is a pickle file
        ftype = file_ops.get_file_type(param_file)
        if ftype == file_ops.PKL:
            log.debug("loading model %s parameters from %s",
                      str(type(self)), str(param_file))
            # try to grab the pickled params from the specified param_file path
            with open(param_file, 'rb') as f:
                loaded_params = pickle.load(f)
            #############################################################################
            # set the W, bv, and bh values (make sure same order as saved in RBM class) #
            #############################################################################
            self.W.set_value(loaded_params[0])
            self.bv.set_value(loaded_params[1])
            self.bh.set_value(loaded_params[2])
            return True
        # if get_file_type didn't return pkl or none, it wasn't a pickle file
        elif ftype:
            log.error("Param file %s doesn't have a supported pickle extension!", str(param_file))
            return False
        # if get_file_type returned none, it couldn't find the file
        else:
            log.error("Param file %s couldn't be found!", str(param_file))
            return False
Example #5
    def load_params(self, param_file):
        """
        This loads the model's parameters from the param_file (hdf5 or pickle file)

        Parameters
        ----------
        param_file : str
            Filename of hdf5 or pickled params file (the file holding the model parameters).

        Returns
        -------
        bool
            Whether or not successfully loaded parameters.
        """
        param_file = os.path.realpath(param_file)

        # determine the file type (pickle or HDF5)
        ftype = file_ops.get_file_type(param_file)

        log.debug("loading %s model parameters from %s", self._classname,
                  str(param_file))

        if ftype == file_ops.PKL:
            # try to grab the pickled params from the specified param_file path
            with open(param_file, 'rb') as f:
                loaded_params = pickle.load(f)
            self.set_param_values(loaded_params, borrow=False)
            return True

        elif ftype == file_ops.HDF5:
            if HAS_H5PY:
                f = h5py.File(param_file, 'r')
                try:
                    params = f[hdf5_param_key]
                    self.set_param_values(params)
                except Exception as e:
                    log.exception(
                        "Some issue loading model %s parameters from %s! Exception: %s",
                        self._classname, str(param_file), str(e))
                    return False
                finally:
                    f.close()
                return True
            else:
                log.error(
                    "Please install the h5py package to read HDF5 files!")
                return False
        # if get_file_type didn't return pkl, hdf5, or none
        elif ftype:
            log.error(
                "Param file %s doesn't have a supported pickle or HDF5 extension!",
                str(param_file))
            return False
        # if get_file_type returned none, it couldn't find the file
        else:
            log.error("Param file %s couldn't be found!", str(param_file))
            return False
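
For the HDF5 branch, a sketch of writing parameters in a layout the loader above could read; the actual `hdf5_param_key` value and the on-disk layout used by the library are assumptions here:

import h5py
import numpy as np

hdf5_param_key = 'params'  # assumption; the real key name comes from the library

# Illustrative parameter arrays; a real model would supply its own values.
param_values = [np.zeros((784, 500)), np.zeros(500)]

with h5py.File('model_params.hdf5', 'w') as f:
    group = f.create_group(hdf5_param_key)
    for i, value in enumerate(param_values):
        group.create_dataset(str(i), data=value)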
Example #6
    def load_params(self, param_file):
        """
        This loads the model's parameters from the param_file (hdf5 or pickle file)

        Parameters
        ----------
        param_file : str
            Filename of hdf5 or pickled params file (the file holding the model parameters).

        Returns
        -------
        bool
            Whether or not successfully loaded parameters.
        """
        param_file = os.path.realpath(param_file)

        # determine the file type (pickle or HDF5)
        ftype = file_ops.get_file_type(param_file)

        log.debug("loading %s model parameters from %s",
                  self._classname, str(param_file))

        if ftype == file_ops.PKL:
            # try to grab the pickled params from the specified param_file path
            with open(param_file, 'rb') as f:
                loaded_params = pickle.load(f)
            self.set_param_values(loaded_params, borrow=False)
            return True

        elif ftype == file_ops.HDF5:
            if HAS_H5PY:
                f = h5py.File(param_file, 'r')
                try:
                    params = f[hdf5_param_key]
                    self.set_param_values(params)
                except Exception as e:
                    log.exception("Some issue loading model %s parameters from %s! Exception: %s",
                                  self._classname, str(param_file), str(e))
                    return False
                finally:
                    f.close()
                return True
            else:
                log.error("Please install the h5py package to read HDF5 files!")
                return False
        # if get_file_type didn't return pkl, hdf5, or none
        elif ftype:
            log.error("Param file %s doesn't have a supported pickle or HDF5 extension!", str(param_file))
            return False
        # if get_file_type returned none, it couldn't find the file
        else:
            log.error("Param file %s couldn't be found!", str(param_file))
            return False
Example #7
    def load(config_file, param_file=None):
        """
        Returns a new Model from a configuration file.

        Parameters
        ----------
        config_file : str
            Filename of pickled configuration file.
        param_file : str, optional
            Filename of hdf5 or pickle file holding the model parameters (in a separate file from `config_file`
            if you want to load some starting parameters).

        Returns
        -------
        :class:`Model`
            A `Model` instance from the configuration and optionally loaded parameters.
        """
        config_file = os.path.realpath(config_file)

        ftype = file_ops.get_file_type(config_file)

        # deal with pickle
        if ftype == file_ops.PKL:
            log.debug("loading model from %s", str(config_file))
            with open(config_file, 'rb') as f:
                loaded_config = pickle.load(f)
        # if get_file_type didn't return pkl, or none
        elif ftype:
            log.exception(
                "Config file %s doesn't have a supported pickle extension!",
                str(config_file))
            raise AssertionError(
                "Config file %s doesn't have a supported pickle extension!" %
                str(config_file))
        # if get_file_type returned none, it couldn't find the file
        else:
            log.exception("Config file %s couldn't be found!",
                          str(config_file))
            raise AssertionError("Config file %s couldn't be found!",
                                 str(config_file))

        classname = loaded_config.pop(class_key)
        class_ = getattr(opendeep.models, classname)
        model = class_(**loaded_config)
        if param_file is not None:
            model.load_params(param_file=param_file)

        return model
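
A hedged usage sketch of the factory above; the file paths are placeholders, and the config pickle is assumed to have been produced by the model's own save routine (it must hold the constructor kwargs plus the class name under class_key, as the body of load() shows):

# Hypothetical paths produced earlier by the model's save methods.
model = Model.load('model_config.pkl', param_file='model_params.pkl')

Because the class name is looked up in opendeep.models via getattr, only classes exposed by that package can be reconstructed this way.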
Example #8
    def load(config_file, param_file=None):
        """
        Returns a new Model from a configuration file.

        Parameters
        ----------
        config_file : str
            Filename of pickled configuration file.
        param_file : str, optional
            Filename of hdf5 or pickle file holding the model parameters (in a separate file from `config_file`
            if you want to load some starting parameters).

        Returns
        -------
        :class:`Model`
            A `Model` instance from the configuration and optionally loaded parameters.
        """
        config_file = os.path.realpath(config_file)

        ftype = file_ops.get_file_type(config_file)

        # deal with pickle
        if ftype == file_ops.PKL:
            log.debug("loading model from %s",
                      str(config_file))
            with open(config_file, 'rb') as f:
                loaded_config = pickle.load(f)
        # if get_file_type didn't return pkl, or none
        elif ftype:
            log.exception("Config file %s doesn't have a supported pickle extension!", str(config_file))
            raise AssertionError("Config file %s doesn't have a supported pickle extension!", str(config_file))
        # if get_file_type returned none, it couldn't find the file
        else:
            log.exception("Config file %s couldn't be found!", str(config_file))
            raise AssertionError("Config file %s couldn't be found!", str(config_file))

        classname = loaded_config.pop(class_key)
        class_ = getattr(opendeep.models, classname)
        model = class_(**loaded_config)
        if param_file is not None:
            model.load_params(param_file=param_file)

        return model
Example #9
    def load_gsn_params(self, param_file):
        """
        Loads the parameters for the GSN only from param_file - used if the GSN was pre-trained

        Parameters
        ----------
        param_file : str
            Relative location of GSN parameters.

        Returns
        -------
        bool
            Whether or not successful.
        """
        param_file = os.path.realpath(param_file)

        # make sure it is a pickle file
        ftype = file_ops.get_file_type(param_file)
        if ftype == file_ops.PKL:
            log.debug("loading model %s parameters from %s",
                      str(type(self)), str(param_file))
            # try to grab the pickled params from the specified param_file path
            with open(param_file, 'rb') as f:
                loaded_params = pickle.load(f)
            # set the GSN parameters
            for i, weight in enumerate(self.weights_list):
                weight.set_value(loaded_params[i])
            for i, bias in enumerate(self.bias_list):
                bias.set_value(loaded_params[i+len(self.weights_list)])
            return True
        # if get_file_type didn't return pkl or none, it wasn't a pickle file
        elif ftype:
            log.error("Param file %s doesn't have a supported pickle extension!", str(param_file))
            return False
        # if get_file_type returned none, it couldn't find the file
        else:
            log.error("Param file %s couldn't be found!", str(param_file))
            return False
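
This loader assumes the pickle stores every weight first, then every bias, in the same order as weights_list and bias_list. A matching save sketch under that assumption (the `gsn` instance and the filename are hypothetical):

import pickle

# Order must mirror the loader: all weights, then all biases.
values = [w.get_value() for w in gsn.weights_list] + \
         [b.get_value() for b in gsn.bias_list]

with open('gsn_params.pkl', 'wb') as f:
    pickle.dump(values, f)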
Example #10
    def load_gsn_params(self, param_file):
        """
        Loads the parameters for the GSN only from param_file - used if the GSN was pre-trained

        Parameters
        ----------
        param_file : str
            Relative location of GSN parameters.

        Returns
        -------
        bool
            Whether or not successful.
        """
        param_file = os.path.realpath(param_file)

        # make sure it is a pickle file
        ftype = file_ops.get_file_type(param_file)
        if ftype == file_ops.PKL:
            log.debug("loading model %s parameters from %s",
                      str(type(self)), str(param_file))
            # try to grab the pickled params from the specified param_file path
            with open(param_file, 'rb') as f:
                loaded_params = pickle.load(f)
            # set the GSN parameters
            for i, weight in enumerate(self.weights_list):
                weight.set_value(loaded_params[i])
            for i, bias in enumerate(self.bias_list):
                bias.set_value(loaded_params[i+len(self.weights_list)])
            return True
        # if get_file_type didn't return pkl or none, it wasn't a pickle file
        elif ftype:
            log.error("Param file %s doesn't have a supported pickle extension!", str(param_file))
            return False
        # if get_file_type returned none, it couldn't find the file
        else:
            log.error("Param file %s couldn't be found!", str(param_file))
            return False
Example #11
    def install(self):
        '''
        Method to both download and extract the dataset from the internet (if applicable), or verify that the file already exists in the dataset directory.
        '''
        file_type = None
        if self.filename is not None:
            log.info('Installing dataset %s', str(self.filename))
            # construct the actual path to the dataset
            prevdir = os.getcwd()
            os.chdir(os.path.split(os.path.realpath(__file__))[0])
            dataset_dir = os.path.realpath(self.dataset_dir)
            try:
                mkdir_p(dataset_dir)
                dataset_location = os.path.join(dataset_dir, self.filename)
            except Exception as e:
                log.error("Couldn't make the dataset path with directory %s and filename %s",
                          dataset_dir,
                          str(self.filename))
                log.exception("%s", str(e))
                dataset_location = None
            finally:
                os.chdir(prevdir)

            # check if the dataset is already in the source, otherwise download it.
            # first check if the base filename exists - without all the extensions.
            # then, add each extension on and keep checking until the upper level, when you download from http.
            if dataset_location is not None:
                (dirs, fname) = os.path.split(dataset_location)
                split_fname = fname.split('.')
                accumulated_name = split_fname[0]
                found = False
                # first check if the filename was a directory (like for the midi datasets)
                if os.path.exists(os.path.join(dirs, accumulated_name)):
                    found = True
                    file_type = get_file_type(os.path.join(dirs, accumulated_name))
                    dataset_location = os.path.join(dirs, accumulated_name)
                    log.debug('Found file %s', dataset_location)
                # now go through the file extensions starting with the lowest level and check if the file exists
                if not found and len(split_fname) > 1:
                    for chunk in split_fname[1:]:
                        accumulated_name = '.'.join((accumulated_name, chunk))
                        file_type = get_file_type(os.path.join(dirs, accumulated_name))
                        if file_type is not None:
                            dataset_location = os.path.join(dirs, accumulated_name)
                            log.debug('Found file %s', dataset_location)
                            break

            # if the file wasn't found, download it if a source was provided. Otherwise, raise error.
            download_success = True
            if file_type is None:
                if self.source is not None:
                    download_success = download_file(self.source, dataset_location)
                    file_type = get_file_type(dataset_location)
                else:
                    log.error("Filename %s couldn't be found, and no URL source to download was provided.",
                              str(self.filename))
                    raise RuntimeError("Filename %s couldn't be found, and no URL source to download was provided." %
                                       str(self.filename))

            # if the file type is a zip, unzip it.
            unzip_success = True
            if file_type == files.ZIP:
                (dirs, fname) = os.path.split(dataset_location)
                post_unzip = os.path.join(dirs, '.'.join(fname.split('.')[0:-1]))
                unzip_success = files.unzip(dataset_location, post_unzip)
                # if the unzip was successful
                if unzip_success:
                    # remove the zipfile and update the dataset location and file type
                    log.debug('Removing file %s', dataset_location)
                    os.remove(dataset_location)
                    dataset_location = post_unzip
                    file_type = get_file_type(dataset_location)
            if download_success and unzip_success:
                log.info('Installation complete. Yay!')
            else:
                log.warning('Something went wrong installing dataset. Boo :(')

            return dataset_location, file_type
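
A hedged usage sketch of install(); the dataset subclass and its constructor arguments are placeholders, not values from the example. The method only needs self.filename, self.dataset_dir, and self.source to be set:

# Hypothetical subclass supplying the attributes install() reads.
dataset = MyDataset(filename='mnist.pkl.gz',
                    dataset_dir='../../datasets',
                    source='http://example.com/mnist.pkl.gz')

# Returns the resolved path on disk and the file_ops type constant.
dataset_location, file_type = dataset.install()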
Example #12
    def install(self):
        '''
        Method to both download and extract the dataset from the internet (if applicable) or verify that the file
        exists in the dataset_dir.

        Returns
        -------
        str
            The absolute path to the dataset location on disk.
        int
            The integer representing the file type for the dataset, as defined in the opendeep.utils.file_ops module.
        '''
        file_type = None
        if self.filename is not None:
            log.info('Installing dataset %s', str(self.filename))
            # construct the actual path to the dataset
            prevdir = os.getcwd()
            os.chdir(os.path.split(os.path.realpath(__file__))[0])
            dataset_dir = os.path.realpath(self.dataset_dir)
            try:
                mkdir_p(dataset_dir)
                dataset_location = os.path.join(dataset_dir, self.filename)
            except Exception as e:
                log.error("Couldn't make the dataset path with directory %s and filename %s",
                          dataset_dir,
                          str(self.filename))
                log.exception("%s", str(e))
                dataset_location = None
            finally:
                os.chdir(prevdir)

            # check if the dataset is already in the source, otherwise download it.
            # first check if the base filename exists - without all the extensions.
            # then, add each extension on and keep checking until the upper level, when you download from http.
            if dataset_location is not None:
                (dirs, fname) = os.path.split(dataset_location)
                split_fname = fname.split('.')
                accumulated_name = split_fname[0]
                found = False
                # first check if the filename was a directory (like for the midi datasets)
                if os.path.exists(os.path.join(dirs, accumulated_name)):
                    found = True
                    file_type = get_file_type(os.path.join(dirs, accumulated_name))
                    dataset_location = os.path.join(dirs, accumulated_name)
                    log.debug('Found file %s', dataset_location)
                # now go through the file extensions starting with the lowest level and check if the file exists
                if not found and len(split_fname) > 1:
                    for chunk in split_fname[1:]:
                        accumulated_name = '.'.join((accumulated_name, chunk))
                        file_type = get_file_type(os.path.join(dirs, accumulated_name))
                        if file_type is not None:
                            dataset_location = os.path.join(dirs, accumulated_name)
                            log.debug('Found file %s', dataset_location)
                            break

            # if the file wasn't found, download it if a source was provided. Otherwise, raise error.
            download_success = True
            if file_type is None:
                if self.source is not None:
                    download_success = download_file(self.source, dataset_location)
                    file_type = get_file_type(dataset_location)
                else:
                    log.error("Filename %s couldn't be found, and no URL source to download was provided.",
                              str(self.filename))
                    raise RuntimeError("Filename %s couldn't be found, and no URL source to download was provided." %
                                       str(self.filename))

            # if the file type is a zip, unzip it.
            unzip_success = True
            if file_type == files.ZIP:
                (dirs, fname) = os.path.split(dataset_location)
                post_unzip = os.path.join(dirs, '.'.join(fname.split('.')[0:-1]))
                unzip_success = files.unzip(dataset_location, post_unzip)
                # if the unzip was successful
                if unzip_success:
                    # remove the zipfile and update the dataset location and file type
                    log.debug('Removing file %s', dataset_location)
                    os.remove(dataset_location)
                    dataset_location = post_unzip
                    file_type = get_file_type(dataset_location)
            if download_success and unzip_success:
                log.info('Installation complete. Yay!')
            else:
                log.warning('Something went wrong installing dataset. Boo :(')

            return dataset_location, file_type