コード例 #1
0
ファイル: model.py プロジェクト: bcmi220/multilingual_srl
    def load(cls, basename, **kwargs):
        """Restore a Keras-backed model saved under ``basename``.

        Reads the Keras model file at ``basename``, the ``<basename>.state`` JSON
        (values are copied onto matching model properties), one
        ``<basename>-<key>-md.json`` metadata file per embedding, and the
        ``<basename>.labels`` JSON.

        :param basename: Path prefix the model artifacts were saved under.
        :param kwargs: Accepted for interface compatibility; not used here.
        :return: The restored model instance.
        """
        K.clear_session()
        model = cls()
        model.impl = keras.models.load_model(basename)
        state = read_json(basename + '.state')
        for prop in ls_props(model):
            if prop in state:
                setattr(model, prop, state[prop])
        # Tensor names carry an output-index suffix after ':'; strip it.
        # `partition` keeps names without a ':' intact, whereas slicing with
        # `name[:name.find(':')]` silently dropped their last character.
        inputs = {v.name.partition(':')[0]: v for v in model.impl.inputs}

        model.embeddings = dict()
        for key, class_name in state['embeddings'].items():
            md = read_json('{}-{}-md.json'.format(basename, key))
            embed_args = {'vsz': md['vsz'], 'dsz': md['dsz']}
            # Feed the restored Keras input tensor for this embedding key.
            embed_args[key] = inputs[key]
            # SECURITY NOTE(review): `eval` executes the class name stored in the
            # .state file; only load model bundles from trusted sources.
            Constructor = eval(class_name)
            model.embeddings[key] = Constructor(key, **embed_args)

        with open(basename + '.labels', 'r') as f:
            model.labels = json.load(f)

        return model
コード例 #2
0
ファイル: services.py プロジェクト: DevSinghSachan/baseline
    def _create_remote_model(directory, backend, remote, name, signature_name, beam, preproc='client'):
        """Build a client object for a remotely served model (TensorFlow only).

        Reads ``model.assets`` from the exported bundle to find the exported model
        name, its labels, the lengths key and the input names, then selects a
        client class from ``baseline.remote`` based on the transport.

        :param directory: location of the exported model bundle
        :param backend: must be ``'tf'``; anything else raises ``ValueError``
        :param remote: endpoint; a value starting with ``http`` selects the REST client
        :param name: the model name, as defined in tf-serving's model.config
        :param signature_name: the serving signature to use
        :param beam: beam size (used for s2s), forwarded to the client
        :param preproc: ``'server'`` selects the gRPC preprocessing client (non-REST only)
        :returns: a RemoteModel instance
        :raises ValueError: if ``backend`` is not ``'tf'``
        """
        assets = read_json(os.path.join(directory, 'model.assets'))
        exported = assets['metadata']['exported_model']
        labels = read_json(os.path.join(directory, exported) + '.labels')
        lengths_key = assets.get('lengths_key', None)
        inputs = assets.get('inputs', [])

        if backend != 'tf':
            raise ValueError("only Tensorflow is currently supported for remote Services")

        remote_models = import_user_module('baseline.remote')
        if remote.startswith('http'):
            client_cls = remote_models.RemoteModelTensorFlowREST
        elif preproc == 'server':
            client_cls = remote_models.RemoteModelTensorFlowGRPCPreproc
        else:
            client_cls = remote_models.RemoteModelTensorFlowGRPC
        return client_cls(remote, name, signature_name, labels=labels,
                          lengths_key=lengths_key, inputs=inputs, beam=beam)
コード例 #3
0
    def _create_model(self, sess, basename):
        """Rebuild a tagger graph from a saved bundle and attach decoding ops.

        Reads ``<basename>.labels`` and ``<basename>.state``, re-creates the
        embeddings from ``<basename>-<key>-md.json`` files, builds the model with
        ``baseline.model.create_model_for`` and restores the checkpoint into ``sess``.

        :param sess: session handed to the model (as ``model_params["sess"]``) and
            used to restore the checkpoint.
        :param basename: path prefix of the saved model artifacts.
        :return: ``(model, indices, values)`` -- the model, the predicted label
            indices and the top-1 softmax values.
        """
        labels = read_json(basename + '.labels')
        model_params = self.task.config_params["model"]
        # NOTE(review): this mutates the task's config dict in place.
        model_params["sess"] = sess

        state = read_json(basename + '.state')
        if state.get('constrain_decode', False):
            # Build a transition mask from the label set and the configured span type.
            constraint = transition_mask(
                labels, self.task.config_params['train']['span_type'],
                Offsets.GO, Offsets.EOS, Offsets.PAD)
            model_params['constraint'] = constraint

        # Re-create the embeddings sub-graph
        embeddings = dict()
        for key, class_name in state['embeddings'].items():
            md = read_json('{}-{}-md.json'.format(basename, key))
            embed_args = dict({'vsz': md['vsz'], 'dsz': md['dsz']})
            # SECURITY NOTE(review): eval() executes the class name stored in the
            # .state file; only load trusted bundles.
            Constructor = eval(class_name)
            embeddings[key] = Constructor(key, **embed_args)

        model = baseline.model.create_model_for(self.task.task_name(),
                                                embeddings, labels,
                                                **model_params)

        # Copy saved property values from the state file onto the model.
        for prop in ls_props(model):
            if prop in state:
                setattr(model, prop, state[prop])

        model.create_loss()

        # Greedy prediction: top-1 of the softmax over model.probs (computed
        # before the GO step is prepended below).
        softmax_output = tf.nn.softmax(model.probs)
        values, indices = tf.nn.top_k(softmax_output, 1)

        # Prepend one time step scoring 0 for GO and -1e4 for every other label,
        # so decoding over the extended sequence starts from GO.
        start_np = np.full((1, 1, len(labels)), -1e4, dtype=np.float32)
        start_np[:, 0, Offsets.GO] = 0
        start = tf.constant(start_np)
        start = tf.tile(start, [tf.shape(model.probs)[0], 1, 1])
        model.probs = tf.concat([start, model.probs], 1)

        # Each sequence is one step longer because of the prepended GO step.
        ones = tf.fill(tf.shape(model.lengths), 1)
        lengths = tf.add(model.lengths, ones)

        if model.crf is True:
            # Viterbi decode with the model's transition params, then drop the GO step.
            indices, _ = tf.contrib.crf.crf_decode(model.probs, model.A,
                                                   lengths)
            indices = indices[:, 1:]
        # NOTE(review): on the non-CRF path `indices` comes from top_k(..., 1) and
        # keeps a trailing size-1 axis, unlike the CRF path -- confirm callers handle both.

        # Invert the label -> index mapping into an index-ordered list of names.
        list_of_labels = [''] * len(labels)
        for label, idval in labels.items():
            list_of_labels[idval] = label

        class_tensor = tf.constant(list_of_labels)
        table = tf.contrib.lookup.index_to_string_table_from_tensor(
            class_tensor)
        # NOTE(review): `classes` is neither used nor returned; the lookup ops are
        # added to the graph but nothing in this method consumes them.
        classes = table.lookup(tf.to_int64(indices))
        self._restore_checkpoint(sess, basename)

        return model, indices, values
コード例 #4
0
    def load(cls, basename, **kwargs):
        """Reload the model from a graph file and a checkpoint
        The model that is loaded is independent of the pooling and stacking layers, making this class reusable
        by sub-classes.

        In graph mode the model is rebuilt inside the session's graph and restored
        from the checkpoint at ``basename``; in eager mode the weights are read from
        ``<basename>.wgt``.

        :param basename: The base directory to load from
        :param kwargs: See below

        :Keyword Arguments:
        * *sess* -- An optional tensorflow session.  If not passed (or ``None``), a new
            session is created.  Graph mode only.
        * *init* -- Run the global variable initializer before restoring
            (graph mode only, default ``True``)
        * *<embedding key>* -- overrides the input of that embedding layer

        :return: A restored model
        """
        _state = read_json("{}.state".format(basename))
        # FIXME: Somehow not writing this anymore
        #if __version__ != _state['version']:
        #    logger.warning("Loaded model is from baseline version %s, running version is %s", _state['version'], __version__)
        embeddings_info = _state.pop('embeddings')

        def _build_model():
            # Shared by both execution modes: rebuild the embeddings, apply input
            # overrides passed as kwargs, and construct the model from the state.
            embeddings = reload_embeddings(embeddings_info, basename)
            for k in embeddings_info:
                if k in kwargs:
                    _state[k] = kwargs[k]
            # TODO: convert labels into just another vocab and pass number of labels to models.
            labels = read_json("{}.labels".format(basename))
            # FIXME: referring to the `constraint_mask` in the base where its not mentioned isnt really clean
            if _state.get('constraint_mask') is not None:
                # Dummy constraint values that will be filled in by the check pointing
                _state['constraint_mask'] = [np.zeros((len(labels), len(labels))) for _ in range(2)]
            model = cls.create(embeddings, labels, **_state)
            model._state = _state
            return model

        if not tf.executing_eagerly():
            # Only create a session when the caller did not supply one. Passing
            # create_session() as the pop() default built a new session on every
            # call, even when it was immediately discarded.
            sess = kwargs.pop('sess', None)
            _state['sess'] = sess if sess is not None else create_session()
            with _state['sess'].graph.as_default():
                model = _build_model()
                model.create_loss()
                if kwargs.get('init', True):
                    model.sess.run(tf.compat.v1.global_variables_initializer())
                model.saver = tf.compat.v1.train.Saver()
                model.saver.restore(model.sess, basename)
        else:
            model = _build_model()
            model.load_weights(f"{basename}.wgt")
        return model
コード例 #5
0
ファイル: model.py プロジェクト: blester125/baseline
    def load(cls, basename: str, **kwargs) -> 'ClassifierModelBase':
        """Reload the model from a graph file and a checkpoint

        The model that is loaded is independent of the pooling and stacking layers, making this class reusable
        by sub-classes.

        In graph mode the model is rebuilt inside the session's graph and restored
        from the checkpoint at ``basename``; in eager mode the weights are read from
        ``<basename>.wgt``.

        :param basename: The base directory to load from
        :param kwargs: See below

        :Keyword Arguments:
        * *sess* -- An optional tensorflow session.  If not passed (or ``None``), a new
            session is created.  Graph mode only.
        * *init* -- Run the global variable initializer before restoring
            (graph mode only, default ``True``)
        * *<embedding key>* -- taken to be the input of that embedding layer

        :return: A restored model
        """
        _state = read_json("{}.state".format(basename))
        if __version__ != _state['version']:
            logger.warning(
                "Loaded model is from baseline version %s, running version is %s",
                _state['version'], __version__)
        embeddings_info = _state.pop('embeddings')

        def _build_model():
            # Shared by both execution modes.
            embeddings = reload_embeddings(embeddings_info, basename)
            # If there is a kwarg that is the same name as an embedding object that
            # is taken to be the input of that layer. This allows for passing in
            # subgraphs like from a tf.split (for data parallel) or preprocessing
            # graphs that convert text to indices
            for k in embeddings_info:
                if k in kwargs:
                    _state[k] = kwargs[k]
            # TODO: convert labels into just another vocab and pass number of labels to models.
            labels = read_json("{}.labels".format(basename))
            model = cls.create(embeddings, labels, **_state)
            model._state = _state
            return model

        if not tf.executing_eagerly():
            # Only create a session when none was supplied; a create_session()
            # default argument was evaluated (and a session built) on every call.
            sess = kwargs.pop('sess', None)
            _state['sess'] = sess if sess is not None else create_session()
            with _state['sess'].graph.as_default():
                model = _build_model()
                if kwargs.get('init', True):
                    model.sess.run(tf.compat.v1.global_variables_initializer())
                model.saver = tf.compat.v1.train.Saver()
                model.saver.restore(model.sess, basename)
        else:
            model = _build_model()
            model.load_weights(f"{basename}.wgt")
        return model
コード例 #6
0
ファイル: model.py プロジェクト: wenshuoliu/baseline
    def load(cls, basename, **kwargs):
        """Reload the model from a graph file and a checkpoint

        The model that is loaded is independent of the pooling and stacking layers, making this class reusable
        by sub-classes.

        Both execution modes now rebuild the embeddings the same way: with
        ``reload_embeddings``, applying embedding-input overrides from kwargs.
        Previously the eager path re-read the state file, constructed embeddings via
        ``eval`` of the stored class name, and ignored those overrides.

        :param basename: The base directory to load from
        :param kwargs: See below

        :Keyword Arguments:
        * *sess* -- An optional tensorflow session.  If not passed (or ``None``), a new
            session is created.  Graph mode only.
        * *init* -- Run the global variable initializer before restoring
            (graph mode only, default ``True``)
        * *model_type* -- stored in the state as ``model_type`` (default ``'default'``)
        * *<embedding key>* -- overrides the input of that embedding layer

        :return: A restored model
        """
        _state = read_json(basename + '.state')
        embeddings_info = _state.pop("embeddings")

        def _build_model():
            # Shared by both execution modes: rebuild embeddings, apply overrides,
            # record the model type and construct the model.
            embeddings = reload_embeddings(embeddings_info, basename)
            for k in embeddings_info:
                if k in kwargs:
                    _state[k] = kwargs[k]
            _state['model_type'] = kwargs.get('model_type', 'default')
            model = cls.create(embeddings, **_state)
            model._state = _state
            return model

        if not tf.executing_eagerly():
            # Only create a session when the caller did not supply one.
            sess = kwargs.pop('sess', None)
            _state['sess'] = sess if sess is not None else create_session()
            with _state['sess'].graph.as_default():
                model = _build_model()
                if kwargs.get('init', True):
                    model.sess.run(tf.compat.v1.global_variables_initializer())
                model.saver = tf.compat.v1.train.Saver()
                model.saver.restore(model.sess, basename)
        else:
            model = _build_model()
            model.load_weights(f"{basename}.wgt")

        return model
コード例 #7
0
ファイル: downloader.py プロジェクト: dpressel/baseline
 def download(self):
     """Resolve the embedding file, downloading and caching it when needed.

     Returns ``self.embedding_file`` as-is when ``is_file_correct`` accepts it.
     Otherwise the JSON download cache is consulted (unless ``self.cache_ignore``).
     On a miss the URL is downloaded and extracted into ``self.data_download_cache``.
     When ``self.sha1`` is set, it is compared with the last path component of the
     extracted location. The result is then recorded in the cache.

     :return: the embedding file path (via ``self._get_embedding_file`` for cached
         or downloaded bundles)
     :raises RuntimeError: if the URL is invalid or the sha1 check fails
     """
     if is_file_correct(self.embedding_file):
         logger.info("embedding file location: {}".format(self.embedding_file))
         return self.embedding_file

     cache_conf = os.path.join(self.data_download_cache, DATA_CACHE_CONF)
     cache = read_json(cache_conf)
     if self.embedding_file in cache and not self.cache_ignore:
         location = cache[self.embedding_file]
         logger.info("files for {} found in cache".format(self.embedding_file))
         return self._get_embedding_file(location, self.embedding_key)

     # Cache miss: try to download the bundle and unzip it.
     url = self.embedding_file
     if not validate_url(url):
         raise RuntimeError("can not download from the given url")
     archive = web_downloader(url)
     location = extractor(filepath=archive, cache_dir=self.data_download_cache,
                          extractor_func=Downloader.ZIPD.get(mime_type(archive), None))
     # NOTE(review): the sha1 is compared with the extracted location's final path
     # component -- presumably the extractor names its output by hash; confirm.
     if self.sha1 is not None and os.path.split(location)[-1] != self.sha1:
         raise RuntimeError("The sha1 of the downloaded file does not match with the provided one")
     cache.update({url: location})
     write_json(cache, cache_conf)
     return self._get_embedding_file(location, self.embedding_key)
コード例 #8
0
ファイル: downloader.py プロジェクト: dpressel/baseline
    def download(self):
        """Resolve every ``*_file`` entry of the dataset description to a local path.

        If the description has a ``download`` bundle URL, the cached directory is
        reused when ``is_dir_correct`` accepts it and ``self.cache_ignore`` is false.
        Otherwise the bundle is downloaded and extracted. It is optionally checked
        against ``sha1`` and recorded in the JSON cache. File entries are then
        joined onto the bundle directory.

        Without a bundle, each ``*_file`` entry with a non-empty value is fetched with
        ``SingleFileDownloader``. When ``self.enc_dec`` is set, entries are returned as-is.

        :return: dict mapping ``*_file`` keys to local locations
        :raises RuntimeError: if the bundle URL is invalid or the sha1 does not match
        """
        dload_bundle = self.dataset_desc.get("download", None)
        if dload_bundle is not None:  # download a zip/tar/tar.gz directory, look for train, dev test files inside that.
            dcache_path = os.path.join(self.data_download_cache, DATA_CACHE_CONF)
            dcache = read_json(dcache_path)
            if dload_bundle in dcache and \
                    is_dir_correct(dcache[dload_bundle], self.dataset_desc, self.data_download_cache, dload_bundle,
                                   self.enc_dec) and not self.cache_ignore:
                download_dir = dcache[dload_bundle]
                logger.info("files for {} found in cache, not downloading".format(dload_bundle))
                return {k: os.path.join(download_dir, self.dataset_desc[k]) for k in self.dataset_desc
                        if k.endswith("_file")}
            else:  # try to download the bundle and unzip
                if not validate_url(dload_bundle):
                    raise RuntimeError("can not download from the given url")
                else:
                    cache_dir = self.data_download_cache
                    temp_file = web_downloader(dload_bundle)

                    download_dir = extractor(filepath=temp_file, cache_dir=cache_dir,
                                             extractor_func=Downloader.ZIPD.get(mime_type(temp_file), None))
                    if "sha1" in self.dataset_desc:
                        if os.path.split(download_dir)[-1] != self.dataset_desc["sha1"]:
                            raise RuntimeError("The sha1 of the downloaded file does not match with the provided one")
                    dcache.update({dload_bundle: download_dir})
                    write_json(dcache, dcache_path)
                    return {k: os.path.join(download_dir, self.dataset_desc[k]) for k in self.dataset_desc
                            if k.endswith("_file")}
        else:  # we have download links to every file or they exist
            if not self.enc_dec:
                # Skip ``*_file`` keys whose value is empty/None (presumably a split
                # that was not provided) instead of passing them to SingleFileDownloader.
                return {k: SingleFileDownloader(self.dataset_desc[k], self.data_download_cache).download()
                        for k in self.dataset_desc if k.endswith("_file") and self.dataset_desc[k]}
            else:
                return {k: self.dataset_desc[k] for k in self.dataset_desc if k.endswith("_file")}
コード例 #9
0
        def load(cls, basename, **kwargs):
            """Restore a seq2seq model from ``<basename>.state`` and its checkpoint.

            :param basename: path prefix of the saved model artifacts
            :param kwargs: optional ``predict`` / ``beam`` overrides for the saved
                state; ``sess`` (session whose graph the model is built in, created if
                absent or ``None``); ``init`` (run the variable initializer before
                restoring, default ``True``); any source-embedding key overrides that
                embedding's input
            :return: the restored model
            """
            _state = read_json('{}.state'.format(basename))
            if __version__ != _state['version']:
                logger.warning(
                    "Loaded model is from baseline version %s, running version is %s",
                    _state['version'], __version__)
            if 'predict' in kwargs:
                _state['predict'] = kwargs['predict']
            if 'beam' in kwargs:
                _state['beam'] = kwargs['beam']
            # Only create a session when none was supplied; a create_session()
            # default argument to get() was evaluated on every call.
            sess = kwargs.get('sess')
            _state['sess'] = sess if sess is not None else create_session()

            with _state['sess'].graph.as_default():

                src_embeddings_info = _state.pop('src_embeddings')
                src_embeddings = reload_embeddings(src_embeddings_info,
                                                   basename)
                for k in src_embeddings_info:
                    if k in kwargs:
                        _state[k] = kwargs[k]
                tgt_embedding_info = _state.pop('tgt_embedding')
                tgt_embedding = reload_embeddings(tgt_embedding_info,
                                                  basename)['tgt']

                model = cls.create(src_embeddings, tgt_embedding, **_state)
                model._state = _state
                if kwargs.get('init', True):
                    model.sess.run(tf.compat.v1.global_variables_initializer())
                model.saver = tf.compat.v1.train.Saver()
                model.saver.restore(model.sess, basename)
                return model
コード例 #10
0
ファイル: model.py プロジェクト: dpressel/baseline
    def load(cls, basename, **kwargs):
        """Restore a model from ``<basename>.state`` and the checkpoint at ``basename``.

        :param basename: path prefix of the saved model artifacts
        :param kwargs: optional ``predict``, ``sampling``, ``sampling_temp`` and
            ``beam`` overrides for the saved state; ``sess`` (session to build the
            model in, created if absent or ``None``); ``init`` (run the variable
            initializer before restoring, default ``True``); any embedding key
            overrides that embedding's input
        :return: the restored model
        """
        _state = read_json('{}.state'.format(basename))
        if __version__ != _state['version']:
            logger.warning("Loaded model is from baseline version %s, running version is %s", _state['version'], __version__)
        for key in ('predict', 'sampling', 'sampling_temp', 'beam'):
            if key in kwargs:
                _state[key] = kwargs[key]
        # Build a session only if the caller did not pass one:
        # `kwargs.get('sess', tf.Session())` created (and leaked) a session on every call.
        sess = kwargs.get('sess')
        _state['sess'] = sess if sess is not None else tf.Session()

        with _state['sess'].graph.as_default():

            embeddings_info = _state.pop('embeddings')
            embeddings = reload_embeddings(embeddings_info, basename)
            for k in embeddings_info:
                if k in kwargs:
                    _state[k] = kwargs[k]
            model = cls.create(embeddings, **_state)
            if kwargs.get('init', True):
                model.sess.run(tf.global_variables_initializer())
            model.saver = tf.train.Saver()
            model.saver.restore(model.sess, basename)
            return model
コード例 #11
0
 def download(self):
     """Return a local path for ``self.dataset_file``, downloading it if necessary.

     A file accepted by ``is_file_correct`` is returned directly. A URL is first
     looked up in the JSON download cache (honouring ``self.cache_ignore``). On a
     miss it is downloaded and extracted into ``self.data_download_cache``, and the
     cache is updated.

     :return: local path to the (possibly extracted) file
     :raises RuntimeError: if the location is neither a valid file nor a valid URL
     """
     file_loc = self.dataset_file
     if is_file_correct(file_loc):
         return file_loc
     if not validate_url(file_loc):
         raise RuntimeError(
             "the file [{}] is not in cache and can not be downloaded".format(
                 file_loc))

     # It is a web URL: check whether it already exists in the cache.
     url = file_loc
     conf_path = os.path.join(self.data_download_cache, DATA_CACHE_CONF)
     cache = read_json(conf_path)
     cached = (url in cache
               and is_file_correct(cache[url], self.data_download_cache, url)
               and not self.cache_ignore)
     if cached:
         print("file for {} found in cache, not downloading".format(url))
         return cache[url]

     # Download the file into the cache and record it in the cache json.
     cache_dir = self.data_download_cache
     print("using {} as data/embeddings cache".format(cache_dir))
     archive = web_downloader(url)
     local_file = extractor(filepath=archive,
                            cache_dir=cache_dir,
                            extractor_func=Downloader.ZIPD.get(mime_type(archive), None))
     cache.update({url: local_file})
     write_json(cache, conf_path)
     return local_file
コード例 #12
0
ファイル: model.py プロジェクト: kiennguyen94/baseline
    def load(cls, basename, **kwargs):
        """Restore a model from ``<basename>.state`` and the checkpoint at ``basename``.

        :param basename: path prefix of the saved model artifacts
        :param kwargs: optional ``predict``, ``sampling``, ``sampling_temp`` and
            ``beam`` overrides for the saved state; ``sess`` (session to build the
            model in, created if absent or ``None``); ``init`` (run the variable
            initializer before restoring, default ``True``); any embedding key
            overrides that embedding's input
        :return: the restored model
        """
        _state = read_json('{}.state'.format(basename))
        if __version__ != _state['version']:
            logger.warning(
                "Loaded model is from baseline version %s, running version is %s",
                _state['version'], __version__)
        for key in ('predict', 'sampling', 'sampling_temp', 'beam'):
            if key in kwargs:
                _state[key] = kwargs[key]
        # Build a session only if the caller did not pass one:
        # `kwargs.get('sess', tf.Session())` created (and leaked) a session on every call.
        sess = kwargs.get('sess')
        _state['sess'] = sess if sess is not None else tf.Session()

        # Build inside the session's own graph. Otherwise a caller-supplied session
        # on a non-default graph would be paired with a model (and Saver) built in a
        # different graph, and the restore would not match that session.
        with _state['sess'].graph.as_default():
            embeddings_info = _state.pop('embeddings')
            embeddings = reload_embeddings(embeddings_info, basename)
            for k in embeddings_info:
                if k in kwargs:
                    _state[k] = kwargs[k]
            model = cls.create(embeddings, **_state)
            if kwargs.get('init', True):
                model.sess.run(tf.global_variables_initializer())
            model.saver = tf.train.Saver()
            model.saver.restore(model.sess, basename)
            return model
コード例 #13
0
    def download(self):
        """Map every ``*_file`` entry of the dataset description to a local path.

        With a ``download`` bundle URL, the bundle directory comes from the JSON
        cache when ``is_dir_correct`` accepts it and ``self.cache_ignore`` is false.
        Otherwise it is downloaded, extracted, optionally checked against ``sha1``
        and cached. File entries are joined onto that directory. Without a bundle,
        non-empty entries are fetched one by one with ``SingleFileDownloader``, or
        returned untouched when ``self.enc_dec`` is set.

        :return: dict mapping ``*_file`` keys to local locations
        :raises RuntimeError: if the bundle URL is invalid or the sha1 does not match
        """
        desc = self.dataset_desc
        bundle = desc.get("download", None)
        if bundle is None:
            # Download links to every file, or the files already exist.
            if self.enc_dec:
                return {k: desc[k] for k in desc if k.endswith("_file")}
            return {k: SingleFileDownloader(desc[k], self.data_download_cache).download()
                    for k in desc if k.endswith("_file") and desc[k]}

        # A zip/tar/tar.gz bundle: look for the train/dev/test files inside it.
        conf_path = os.path.join(self.data_download_cache, DATA_CACHE_CONF)
        cache = read_json(conf_path)
        cache_ok = (bundle in cache
                    and is_dir_correct(cache[bundle], desc, self.data_download_cache, bundle,
                                       self.enc_dec)
                    and not self.cache_ignore)
        if cache_ok:
            target_dir = cache[bundle]
            logger.info("files for {} found in cache, not downloading".format(bundle))
        else:
            if not validate_url(bundle):
                raise RuntimeError("can not download from the given url")
            archive = web_downloader(bundle)
            target_dir = extractor(filepath=archive, cache_dir=self.data_download_cache,
                                   extractor_func=Downloader.ZIPD.get(mime_type(archive), None))
            if "sha1" in desc and os.path.split(target_dir)[-1] != desc["sha1"]:
                raise RuntimeError("The sha1 of the downloaded file does not match with the provided one")
            cache.update({bundle: target_dir})
            write_json(cache, conf_path)
        return {k: os.path.join(target_dir, desc[k]) for k in desc if k.endswith("_file")}
コード例 #14
0
ファイル: services.py プロジェクト: wenshuoliu/baseline
    def load(cls, bundle, **kwargs):
        """Load an ONNX model service from a bundle directory or archive.

        A directory is used as-is; anything else is passed to ``unzip_files``.
        The vocabs, vectorizers and ``<model_basename>.labels`` are read from the
        bundle, and ``<model_basename>.onnx`` is opened with onnxruntime.

        :param bundle: path to a model directory or an archive of one
        :param kwargs: accepted for interface compatibility; not used here
        :returns: a Service implementation
        """
        import onnxruntime as ort

        directory = bundle if os.path.isdir(bundle) else unzip_files(bundle)

        model_basename = find_model_basename(directory)
        onnx_path = f"{model_basename}.onnx"

        vocabs = load_vocabs(directory)
        vectorizers = load_vectorizers(directory)
        labels = read_json(model_basename + '.labels')

        session = ort.InferenceSession(onnx_path)
        return cls(vocabs, vectorizers, session, labels)
コード例 #15
0
 def download(self):
     """Return a local path to the embedding file, fetching it on demand.

     A file accepted by ``is_file_correct`` is returned directly. A cached bundle
     (unless ``self.cache_ignore``) is resolved through ``self._get_embedding_file``.
     Otherwise the URL is downloaded and extracted into ``self.data_download_cache``.
     It is optionally checked against ``self.sha1`` and recorded in the JSON cache.

     :raises RuntimeError: if the URL is invalid or the sha1 check fails
     """
     target = self.embedding_file
     if is_file_correct(target):
         logger.info("embedding file location: {}".format(target))
         return target

     cache_file = os.path.join(self.data_download_cache, DATA_CACHE_CONF)
     known = read_json(cache_file)
     if target in known and not self.cache_ignore:
         logger.info("files for {} found in cache".format(target))
         return self._get_embedding_file(known[target], self.embedding_key)

     # Not cached (or the cache is ignored): download and unpack the bundle.
     if not validate_url(target):
         raise RuntimeError("can not download from the given url")
     fetched = web_downloader(target)
     unpacker = Downloader.ZIPD.get(mime_type(fetched), None)
     unpacked_at = extractor(filepath=fetched, cache_dir=self.data_download_cache,
                             extractor_func=unpacker)
     expected = self.sha1
     if expected is not None and os.path.split(unpacked_at)[-1] != expected:
         raise RuntimeError("The sha1 of the downloaded file does not match with the provided one")
     known.update({target: unpacked_at})
     write_json(known, cache_file)
     return self._get_embedding_file(unpacked_at, self.embedding_key)
コード例 #16
0
ファイル: services.py プロジェクト: wenshuoliu/baseline
    def _create_remote_model(directory, backend, remote, name, task_name,
                             signature_name, beam, **kwargs):
        """Build a client for a remotely served model from an exported bundle.

        Reads ``model.assets`` for the exported model name, the preprocessing mode,
        labels, lengths key, inputs and the ``return_labels`` flag, then asks
        ``create_remote`` for a client of the selected type.

        :param directory: the location of the exported model bundle
        :param backend: ``'tf'`` or ``'onnx'``; anything else raises ``ValueError``
        :param remote: URL endpoint; ``http...`` selects an http client, otherwise grpc
        :param name: the model name, as defined in tf-serving's model.config
        :param task_name: appended to the derived remote type
        :param signature_name: the signature to use
        :param beam: used for s2s and found in the kwargs; passed through
        :param kwargs: optional ``preproc`` (fallback when the bundle metadata does not
            set it), ``version`` and ``remote_type`` (explicit client type)
        :returns: tuple ``(remote model, preproc mode)``
        """
        from baseline.remote import create_remote
        assets = read_json(os.path.join(directory, 'model.assets'))
        metadata = assets['metadata']
        exported = metadata['exported_model']
        preproc = metadata.get('preproc', kwargs.get('preproc', 'client'))
        labels = read_json(os.path.join(directory, exported) + '.labels')
        lengths_key = assets.get('lengths_key', None)
        inputs = assets.get('inputs', [])
        return_labels = bool(metadata['return_labels'])
        version = kwargs.get('version')

        if backend not in {'tf', 'onnx'}:
            raise ValueError(
                f"Unsupported backend {backend} for remote Services")
        import_user_module('baseline.{}.remote'.format(backend))

        exp_type = kwargs.get('remote_type')
        if exp_type is None:
            transport = 'http' if remote.startswith('http') else 'grpc'
            if preproc == 'server':
                transport = '{}-preproc'.format(transport)
            exp_type = f'{transport}-{task_name}'

        client = create_remote(
            exp_type,
            remote=remote,
            name=name,
            signature=signature_name,
            labels=labels,
            lengths_key=lengths_key,
            inputs=inputs,
            beam=beam,
            return_labels=return_labels,
            version=version,
        )
        return client, preproc
コード例 #17
0
    def init_embeddings(self, embeddings_map, basename):
        """Construct embedding objects from saved ``<basename>-<key>-md.json`` files.

        :param embeddings_map: either a mapping of embedding key -> class name (the
            shape of ``state['embeddings']`` elsewhere) or an iterable of
            ``(key, class_name)`` pairs
        :param basename: path prefix the metadata files were written under
        :return: dict mapping each key to its constructed embedding
        """
        # Iterating a dict directly yields only its keys, which cannot be unpacked
        # into (key, class_name); use .items() for mappings and keep accepting pairs.
        pairs = embeddings_map.items() if hasattr(embeddings_map, 'items') else embeddings_map
        embeddings = dict()
        for key, class_name in pairs:
            md = read_json('{}-{}-md.json'.format(basename, key))
            embed_args = {'vsz': md['vsz'], 'dsz': md['dsz']}
            # SECURITY NOTE(review): eval() executes the stored class name; only
            # use with trusted model bundles.
            Constructor = eval(class_name)
            embeddings[key] = Constructor(key, **embed_args)

        return embeddings
コード例 #18
0
ファイル: model.py プロジェクト: amyhemmeter/baseline
    def load(cls, basename, **kwargs):
        """Reload the model from a graph file and a checkpoint
        The model that is loaded is independent of the pooling and stacking layers, making this class reusable
        by sub-classes.

        :param basename: The base directory to load from
        :param kwargs: See below

        :Keyword Arguments:
        * *sess* -- An optional tensorflow session.  If not passed (or ``None``), a new
            session is created
        * *lengths* -- overrides the ``lengths`` entry of the saved state
        * *init* -- Run the global variable initializer before restoring (default ``True``)
        * *<embedding key>* -- overrides the input of that embedding layer

        :return: A restored model
        """
        _state = read_json("{}.state".format(basename))
        if __version__ != _state['version']:
            logger.warning(
                "Loaded model is from baseline version %s, running version is %s",
                _state['version'], __version__)
        # Only create a session when the caller did not supply one; a
        # create_session() default argument to pop() was evaluated on every call.
        sess = kwargs.pop('sess', None)
        _state['sess'] = sess if sess is not None else create_session()
        embeddings_info = _state.pop("embeddings")

        with _state['sess'].graph.as_default():
            embeddings = reload_embeddings(embeddings_info, basename)
            for k in embeddings_info:
                if k in kwargs:
                    _state[k] = kwargs[k]
            labels = read_json("{}.labels".format(basename))
            if _state.get('constraint') is not None:
                # Dummy constraint values that will be filled in by the check pointing
                _state['constraint'] = [
                    tf.zeros((len(labels), len(labels))) for _ in range(2)
                ]
            if 'lengths' in kwargs:
                _state['lengths'] = kwargs['lengths']
            model = cls.create(embeddings, labels, **_state)
            model._state = _state
            model.create_loss()
            if kwargs.get('init', True):
                model.sess.run(tf.global_variables_initializer())
            model.saver = tf.train.Saver()
            model.saver.restore(model.sess, basename)
            return model
コード例 #19
0
ファイル: embed_elmo.py プロジェクト: DevSinghSachan/baseline
    def __init__(self, name, embed_file=None, known_vocab=None, **kwargs):
        """Wrap a pre-trained bidirectional ELMo language model as an embedding.

        :param name: name passed to the parent embeddings class
        :param embed_file: path to the ELMo weights file; the options file is found
            by replacing ``weights.hdf5`` with ``options.json`` in this path
        :param known_vocab: passed to ``UnicodeCharsVocabulary``
        :param kwargs: must contain ``dsz``, which has to equal twice the
            ``lstm.projection_dim`` value from the options file
        :raises ValueError: if ``embed_file`` is not given
        """
        if embed_file is None:
            # Without this check the None default only surfaces as an
            # AttributeError from ``None.replace`` further down.
            raise ValueError("ELMoEmbeddings requires `embed_file` (path to the weights.hdf5 file)")
        super(ELMoEmbeddings, self).__init__(name=name)

        # options file
        self.weight_file = embed_file
        self.dsz = kwargs['dsz']
        elmo_config = embed_file.replace('weights.hdf5', 'options.json')
        self.model = BidirectionalLanguageModel(elmo_config, self.weight_file)
        self.vocab = UnicodeCharsVocabulary(known_vocab)
        elmo_config = read_json(elmo_config)
        projection_dim = int(elmo_config['lstm']['projection_dim'])
        assert self.dsz == 2 * projection_dim, \
            "dsz ({}) must be 2 * lstm.projection_dim ({})".format(self.dsz, 2 * projection_dim)
コード例 #20
0
def reload_embeddings(embeddings_dict, basename):
    """Re-create the embedding objects saved alongside a model.

    For every ``key -> class name`` entry, the metadata file
    ``{basename}-{key}-md.json`` supplies the constructor keyword arguments.
    Its ``module`` entry is imported and the class is looked up on it. An
    optional ``name`` entry, when present, must equal ``key``.

    :param embeddings_dict: Mapping of embedding key to class name.
    :param basename: Path prefix of the saved model files.
    :return: A dict mapping each key to its reconstructed embedding object.
    """
    reloaded = {}
    for key, class_name in embeddings_dict.items():
        md = read_json('{}-{}-md.json'.format(basename, key))
        module_name = md.pop('module')
        stored_name = md.pop('name', None)
        assert stored_name is None or stored_name == key
        constructor = getattr(import_user_module(module_name), class_name)
        # Whatever is left in the metadata becomes constructor kwargs.
        reloaded[key] = constructor(key, **md)
    return reloaded
コード例 #21
0
ファイル: tfy.py プロジェクト: dpressel/baseline
def reload_embeddings(embeddings_dict, basename):
    """Rebuild each saved embedding from its ``{basename}-{key}-md.json`` file.

    :param embeddings_dict: Mapping of embedding key to the class name to
        instantiate; the class is fetched from the module named by the
        metadata's ``module`` field.
    :param basename: Path prefix of the saved model files.
    :return: Dict of key -> embedding instance, in ``embeddings_dict`` order.
    """
    def _rebuild(key, class_name):
        # Bookkeeping fields are removed; the remainder are constructor kwargs.
        ctor_kwargs = read_json('{}-{}-md.json'.format(basename, key))
        module_name = ctor_kwargs.pop('module')
        saved_name = ctor_kwargs.pop('name', None)
        assert saved_name is None or saved_name == key
        module = import_user_module(module_name)
        return getattr(module, class_name)(key, **ctor_kwargs)

    return {key: _rebuild(key, class_name) for key, class_name in embeddings_dict.items()}
コード例 #22
0
ファイル: embed_elmo.py プロジェクト: dpressel/baseline
    def __init__(self, name, embed_file=None, known_vocab=None, **kwargs):
        """Create ELMo embeddings backed by a pretrained bidirectional LM.

        :param name: Embedding name; passed together with ``kwargs`` to the
            parent class.
        :param embed_file: Path to ``weights.hdf5``; ``options.json`` is found
            by replacing that part of the path.
        :param known_vocab: Kept on ``self.known_vocab`` and used to build the
            character vocabulary.
        :param kwargs: Must include ``dsz``, which is asserted to equal
            ``2 * lstm.projection_dim`` from the options file.
        """
        super(ELMoEmbeddings, self).__init__(name=name, **kwargs)

        self.weight_file = embed_file
        self.dsz = kwargs['dsz']
        options_file = embed_file.replace('weights.hdf5', 'options.json')
        self.model = BidirectionalLanguageModel(options_file, self.weight_file)
        self.known_vocab = known_vocab
        self.vocab = UnicodeCharsVocabulary(known_vocab)
        options = read_json(options_file)
        assert self.dsz == 2 * int(options['lstm']['projection_dim'])
コード例 #23
0
    def load(cls, basename, **kwargs):
        """Reload a TensorFlow seq2seq model from its saved files.

        Reads ``{basename}.state``, ``{basename}.saver``, the per-embedding
        metadata ``{basename}-{key}-md.json`` and ``{basename}-tgt-md.json``,
        builds the graph with ``cls.create`` and restores the checkpoint at
        ``basename``.

        :param basename: Path prefix of the saved model files.
        :param kwargs: Optional overrides:
            * ``predict`` / ``beam`` -- copied into the state before creation.
            * ``sess`` -- session to use; a new ``tf.Session()`` is created only
              when this key is absent.
            * ``model_type`` -- stored in the state, defaults to ``'default'``.
            * ``init`` -- run the global variables initializer before restoring
              (default True).
        :return: The restored model.
        """
        state = read_json(basename + '.state')
        if 'predict' in kwargs:
            state['predict'] = kwargs['predict']

        if 'beam' in kwargs:
            state['beam'] = kwargs['beam']

        # Build a session only when none was supplied. A default passed to
        # ``kwargs.get`` is evaluated eagerly, which would construct a
        # throwaway session on every call.
        state['sess'] = kwargs['sess'] if 'sess' in kwargs else tf.Session()
        state['model_type'] = kwargs.get('model_type', 'default')

        # The parsed SaverDef is not used below; the read is kept so that a
        # missing or malformed ``.saver`` file still fails at this point.
        with open(basename + '.saver') as fsv:
            saver_def = tf.train.SaverDef()
            text_format.Merge(fsv.read(), saver_def)

        # NOTE(review): eval() on class names read from the state file will run
        # arbitrary expressions -- only load model files from trusted sources.
        src_embeddings = dict()
        src_embeddings_dict = state.pop('src_embeddings')
        for key, class_name in src_embeddings_dict.items():
            md = read_json('{}-{}-md.json'.format(basename, key))
            embed_args = dict({'vsz': md['vsz'], 'dsz': md['dsz']})
            Constructor = eval(class_name)
            src_embeddings[key] = Constructor(key, **embed_args)

        tgt_class_name = state.pop('tgt_embedding')
        md = read_json('{}-tgt-md.json'.format(basename))
        embed_args = dict({'vsz': md['vsz'], 'dsz': md['dsz']})
        Constructor = eval(tgt_class_name)
        tgt_embedding = Constructor('tgt', **embed_args)
        model = cls.create(src_embeddings, tgt_embedding, **state)
        # Copy saved properties that the freshly created model exposes.
        for prop in ls_props(model):
            if prop in state:
                setattr(model, prop, state[prop])
        do_init = kwargs.get('init', True)
        if do_init:
            init = tf.global_variables_initializer()
            model.sess.run(init)

        model.saver = tf.train.Saver()
        model.saver.restore(model.sess, basename)
        return model
コード例 #24
0
 def __init__(self, name, **kwargs):
     """Build a Transformer LM embedding over a JSON vocabulary.

     :param name: Name of this embedding (forwarded to the parent class).
     :param kwargs:
         * ``vocab_file`` -- required JSON vocab; must contain ``[CLS]``.
         * ``layers`` (8), ``num_heads`` (10), ``dropout`` (0.1),
           ``dsz``/``d_model`` (410), ``d_ff`` (2100) -- transformer sizes.
     :raises IOError: if ``vocab_file`` does not exist.
     """
     super(TransformerLMEmbeddings, self).__init__(name)
     # strict=True: a missing vocab file raises immediately rather than reading
     # as {} and failing later with an opaque KeyError on '[CLS]'.
     self.vocab = read_json(kwargs.get('vocab_file'), strict=True)
     self.cls_index = self.vocab['[CLS]']
     self.vsz = len(self.vocab)
     layers = int(kwargs.get('layers', 8))
     num_heads = int(kwargs.get('num_heads', 10))
     pdrop = kwargs.get('dropout', 0.1)
     self.d_model = int(kwargs.get('dsz', kwargs.get('d_model', 410)))
     d_ff = int(kwargs.get('d_ff', 2100))
     x_embedding = PositionalLookupTableEmbeddings('pos', vsz=self.vsz, dsz=self.d_model)
     self.init_embed({'x': x_embedding})
     self.transformer = TransformerEncoderStack(num_heads, d_model=self.d_model, pdrop=pdrop, scale=True, layers=layers, d_ff=d_ff)
コード例 #25
0
    def _create_model(self, sess, basename):
        """Rebuild a classifier graph from saved files and add label outputs.

        Labels come from ``{basename}.labels``, model params from the MEAD task
        config (with ``sess`` injected), and embedding sizes from
        ``{basename}-{key}-md.json``. After creation the checkpoint is restored.

        :param sess: Session the graph is built in and restored into.
        :param basename: Path prefix of the saved model files.
        :return: ``(model, classes, values)``: ``values``/``indices`` come from
            ``tf.nn.top_k`` over all labels, and ``classes`` maps those indices
            to strings through a lookup table built from ``model.labels``.
        """
        labels = read_json(basename + '.labels')

        # MEAD model parameters; the session is injected into the shared dict.
        model_params = self.task.config_params["model"]
        model_params["sess"] = sess

        state = read_json(basename + '.state')

        # NOTE(review): eval() on class names from the state file executes
        # arbitrary expressions -- only export models from trusted files.
        embeddings = dict()
        for key, class_name in state['embeddings'].items():
            md = read_json('{}-{}-md.json'.format(basename, key))
            ctor_kwargs = {'vsz': md['vsz'], 'dsz': md['dsz']}
            embeddings[key] = eval(class_name)(key, **ctor_kwargs)

        model = baseline.model.create_model_for(self.task.task_name(),
                                                embeddings, labels,
                                                **model_params)

        # Restore saved properties the new model exposes.
        for prop in filter(state.__contains__, ls_props(model)):
            setattr(model, prop, state[prop])

        values, indices = tf.nn.top_k(model.probs, len(labels))
        lookup = tf.contrib.lookup.index_to_string_table_from_tensor(
            tf.constant(model.labels))
        classes = lookup.lookup(tf.to_int64(indices))

        self._restore_checkpoint(sess, basename)
        return model, classes, values
コード例 #26
0
ファイル: services.py プロジェクト: dpressel/baseline
    def _create_remote_model(directory, backend, remote, name, signature_name, beam, **kwargs):
        """Read an exported model bundle and build a client for the remote model.

        :param directory: Location of the exported model bundle (holds
            ``model.assets`` and ``<exported_model>.labels``).
        :param backend: ``'tf'`` or ``'pytorch'``; anything else raises.
        :param remote: URL endpoint; an ``http`` prefix selects the HTTP client,
            otherwise gRPC is used.
        :param name: Model name as defined in tf-serving's model.config.
        :param signature_name: Signature to use.
        :param beam: Beam width, used by seq2seq; passed through.
        :param kwargs: ``preproc`` (fallback when the bundle metadata has none,
            default ``'client'``) and ``version``.
        :raises ValueError: for an unsupported backend.
        :return: ``(remote_model, preproc)``.
        """
        from baseline.remote import create_remote
        assets = read_json(os.path.join(directory, 'model.assets'))
        metadata = assets['metadata']
        model_name = metadata['exported_model']
        # The bundle's own setting wins over the caller's preference.
        preproc = metadata.get('preproc', kwargs.get('preproc', 'client'))
        labels = read_json(os.path.join(directory, model_name) + '.labels')
        lengths_key = assets.get('lengths_key', None)
        inputs = assets.get('inputs', [])
        return_labels = bool(metadata['return_labels'])
        version = kwargs.get('version')

        if backend not in {'tf', 'pytorch'}:
            raise ValueError("only Tensorflow and Pytorch are currently supported for remote Services")
        import_user_module('baseline.{}.remote'.format(backend))

        transport = 'http' if remote.startswith('http') else 'grpc'
        if preproc == 'server':
            transport = '{}-preproc'.format(transport)
        remote_model = create_remote(
            transport,
            remote=remote,
            name=name,
            signature=signature_name,
            labels=labels,
            lengths_key=lengths_key,
            inputs=inputs,
            beam=beam,
            return_labels=return_labels,
            version=version,
        )
        return remote_model, preproc
コード例 #27
0
    def __init__(self, name, **kwargs):
        """Build a Transformer LM embedding over positional char-conv inputs.

        :param name: Name of this embedding (forwarded to the parent class).
        :param kwargs:
            * ``vocab_file`` -- required JSON vocab (read with ``strict=True``);
              must contain ``[CLS]``.
            * ``layers`` (18), ``num_heads`` (10), ``dropout`` (0.1),
              ``dsz``/``d_model`` (410), ``d_ff`` (2100) -- transformer sizes.
            * ``mlm`` (False) -- stored on ``self.mlm``.
            * ``pooling`` -- ``'max'`` or ``'mean'``; any other value (default
              ``'cls'``) uses ``self._cls_pool``.
        """
        from baseline.pytorch.embeddings import PositionalCharConvEmbeddings
        # Fixed configuration of the character-level input embedding.
        X_CHAR_EMBEDDINGS = {
            "dsz": 16,
            "wsz": 128,
            "embed_type": "positional-char-conv",
            "keep_unused": True,
            "cfiltsz": [[1, 32], [2, 32], [3, 64], [4, 128], [5, 256],
                        [6, 512], [7, 1024]],
            "gating": "highway",
            "num_gates": 2,
            "projsz": 512,
        }
        super(TransformerLMEmbeddings, self).__init__(name)
        self.vocab = read_json(kwargs.get('vocab_file'), strict=True)
        self.cls_index = self.vocab['[CLS]']
        self.vsz = len(self.vocab)
        layers = int(kwargs.get('layers', 18))
        num_heads = int(kwargs.get('num_heads', 10))
        pdrop = kwargs.get('dropout', 0.1)
        self.d_model = int(kwargs.get('dsz', kwargs.get('d_model', 410)))
        d_ff = int(kwargs.get('d_ff', 2100))
        x_embedding = PositionalCharConvEmbeddings('pcc',
                                                   vsz=self.vsz,
                                                   **X_CHAR_EMBEDDINGS)
        # Initialise the embedding stack once; its return value is the width
        # compared against d_model below.
        self.dsz = self.init_embed({'x': x_embedding})
        # Project to the transformer width only when the sizes differ.
        self.proj_to_dsz = pytorch_linear(
            self.dsz, self.d_model) if self.dsz != self.d_model else _identity
        self.transformer = TransformerEncoderStack(num_heads,
                                                   d_model=self.d_model,
                                                   pdrop=pdrop,
                                                   scale=True,
                                                   layers=layers,
                                                   d_ff=d_ff)
        self.mlm = kwargs.get('mlm', False)

        pooling = kwargs.get('pooling', 'cls')
        if pooling == 'max':
            self.pooling_op = _max_pool
        elif pooling == 'mean':
            self.pooling_op = _mean_pool
        else:
            self.pooling_op = self._cls_pool
コード例 #28
0
ファイル: model.py プロジェクト: dpressel/baseline
    def load(cls, basename, **kwargs):
        """Reload the model from a graph file and a checkpoint

        The model that is loaded is independent of the pooling and stacking layers, making this class reusable
        by sub-classes.

        :param basename: The base directory to load from
        :param kwargs: See below

        :Keyword Arguments:
        * *sess* -- An optional tensorflow session.  A new session is created
            only when this is not passed
        * *init* -- Run the global variables initializer before restoring (default True)
        * any key naming a saved embedding is copied into the model state

        :return: A restored model
        """
        _state = read_json("{}.state".format(basename))
        if __version__ != _state['version']:
            logger.warning("Loaded model is from baseline version %s, running version is %s", _state['version'], __version__)
        # Create a session only when the caller did not pass one. A default given
        # to ``pop`` is evaluated eagerly, which would build a throwaway session
        # on every call.
        _state['sess'] = kwargs.pop('sess') if 'sess' in kwargs else tf.Session()
        with _state['sess'].graph.as_default():
            embeddings_info = _state.pop('embeddings')
            embeddings = reload_embeddings(embeddings_info, basename)
            # If there is a kwarg that is the same name as an embedding object that
            # is taken to be the input of that layer. This allows for passing in
            # subgraphs like from a tf.split (for data parallel) or preprocessing
            # graphs that convert text to indices
            for k in embeddings_info:
                if k in kwargs:
                    _state[k] = kwargs[k]
            # TODO: convert labels into just another vocab and pass number of labels to models.
            labels = read_json("{}.labels".format(basename))
            model = cls.create(embeddings, labels, **_state)
            model._state = _state
            if kwargs.get('init', True):
                model.sess.run(tf.global_variables_initializer())
            model.saver = tf.train.Saver()
            model.saver.restore(model.sess, basename)
            return model
コード例 #29
0
ファイル: model.py プロジェクト: dpressel/baseline
    def load(cls, basename, **kwargs):
        """Reload the model from a graph file and a checkpoint
        The model that is loaded is independent of the pooling and stacking layers, making this class reusable
        by sub-classes.

        :param basename: The base directory to load from
        :param kwargs: See below

        :Keyword Arguments:
        * *sess* -- An optional tensorflow session.  A new session is created
            only when this is not passed
        * *init* -- Run the global variables initializer before restoring (default True)
        * any key naming a saved embedding is copied into the model state

        :return: A restored model
        """
        _state = read_json("{}.state".format(basename))
        if __version__ != _state['version']:
            logger.warning("Loaded model is from baseline version %s, running version is %s", _state['version'], __version__)
        # Create a session only when the caller did not pass one. A default given
        # to ``pop`` is evaluated eagerly, which would build a throwaway session
        # on every call.
        _state['sess'] = kwargs.pop('sess') if 'sess' in kwargs else tf.Session()
        embeddings_info = _state.pop("embeddings")

        with _state['sess'].graph.as_default():
            embeddings = reload_embeddings(embeddings_info, basename)
            for k in embeddings_info:
                if k in kwargs:
                    _state[k] = kwargs[k]
            labels = read_json("{}.labels".format(basename))
            if _state.get('constraint') is not None:
                # Dummy constraint values that will be filled in by the check pointing
                _state['constraint'] = [tf.zeros((len(labels), len(labels))) for _ in range(2)]
            model = cls.create(embeddings, labels, **_state)
            model._state = _state
            model.create_loss()
            if kwargs.get('init', True):
                model.sess.run(tf.global_variables_initializer())
            model.saver = tf.train.Saver()
            model.saver.restore(model.sess, basename)
            return model
コード例 #30
0
ファイル: model.py プロジェクト: dpressel/baseline
    def load(cls, basename, **kwargs):
        """Reload a Keras model together with its baseline metadata.

        Clears the Keras session, loads the Keras model stored at ``basename``,
        copies matching properties from ``{basename}.state``, rebuilds each
        embedding wrapper from ``{basename}-{key}-md.json`` bound to the model
        input of the same name, and reads labels from ``{basename}.labels``.

        :param basename: Path of the saved Keras model, also the prefix of its
            side files.
        :return: The reloaded model wrapper.
        """
        K.clear_session()
        model = cls()
        model.impl = keras.models.load_model(basename)
        state = read_json(basename + '.state')
        for prop in ls_props(model):
            if prop in state:
                setattr(model, prop, state[prop])
        # Key each input by its name without any ":<n>" output suffix.
        # ``partition`` keeps the whole name when no ':' is present (slicing up
        # to ``find(':')`` would drop the last character, since find returns
        # -1), and the dict comprehension never needs the tensors to be hashable.
        inputs = {v.name.partition(':')[0]: v for v in model.impl.inputs}

        # NOTE(review): eval() on class names from the state file executes
        # arbitrary expressions -- only load model files from trusted sources.
        model.embeddings = dict()
        for key, class_name in state['embeddings'].items():
            md = read_json('{}-{}-md.json'.format(basename, key))
            embed_args = dict({'vsz': md['vsz'], 'dsz': md['dsz']})
            embed_args[key] = inputs[key]
            Constructor = eval(class_name)
            model.embeddings[key] = Constructor(key, **embed_args)

        with open(basename + '.labels', 'r') as f:
            model.labels = json.load(f)

        return model
コード例 #31
0
ファイル: model.py プロジェクト: kiennguyen94/baseline
def load_model_for(activity, filename, **kwargs):
    """Load a saved model for ``activity`` using the registered loader.

    ``{filename}.state`` is read first. If it names a ``module``, that module
    is imported so any custom loader it registers becomes available.

    :param activity: Task name used to index ``BASELINE_LOADERS``.
    :param filename: Path prefix of the saved model.
    :param kwargs: ``model_type`` overrides the state's ``model_type``
        (default ``'default'``); all kwargs are forwarded to the loader.
    :return: Whatever the selected loader returns.
    """
    state = read_json('{}.state'.format(filename))
    # PyTorch bundles have no module to import.
    if 'module' in state:
        import_user_module(state['module'])
    # TODO: PyTorch models are always reloaded through the default ``load``
    # classmethod; a model with a custom load could misbehave here.
    if 'model_type' in kwargs:
        model_type = kwargs['model_type']
    else:
        model_type = state.get('model_type', 'default')
    loader = BASELINE_LOADERS[activity][model_type]
    logger.info('Calling model %s', loader)
    return loader(filename, **kwargs)
コード例 #32
0
ファイル: model.py プロジェクト: dpressel/baseline
def load_model_for(activity, filename, **kwargs):
    """Dispatch to the loader registered for ``activity`` and the model type.

    The model type is the first key found, in this order: ``kwargs['type']``,
    ``kwargs['model_type']``, ``state['type']``, ``state['model_type']``. If
    none is present, ``'default'`` is used. ``state`` is the contents of
    ``{filename}.state``; a ``module`` entry there is imported first.

    :return: The result of the selected loader called with ``filename`` and
        ``kwargs``.
    """
    state = read_json('{}.state'.format(filename))
    # PyTorch bundles carry no module to import.
    if 'module' in state:
        import_user_module(state['module'])
    # TODO: PyTorch models are always reloaded through the default ``load``
    # classmethod; a model with a custom load could misbehave here.
    precedence = ((kwargs, 'type'), (kwargs, 'model_type'),
                  (state, 'type'), (state, 'model_type'))
    model_type = next((source[key] for source, key in precedence if key in source), 'default')
    loader = BASELINE_LOADERS[activity][model_type]
    logger.info('Calling model %s', loader)
    return loader(filename, **kwargs)
コード例 #33
0
    def load(cls, basename: str, **kwargs) -> 'DependencyParserModelBase':
        """Rebuild the model from its saved state and load the ``.wgt`` weights.

        The reload does not depend on the pooling/stacking layers, so
        sub-classes can reuse it.

        :param basename: Path prefix of the saved model files.
        :param kwargs: Any key that names a saved embedding is copied into the
            state passed to ``cls.create``. Per the original comment, this lets
            callers feed that embedding's input, e.g. a tf.split subgraph or a
            text-to-index preprocessing graph.
        :return: The restored model.
        """
        _state = read_json("{}.state".format(basename))
        saved_version = _state['version']
        if saved_version != __version__:
            logger.warning(
                "Loaded model is from baseline version %s, running version is %s",
                saved_version, __version__)

        embeddings_info = _state.pop('embeddings')
        embeddings = reload_embeddings(embeddings_info, basename)
        _state.update({k: kwargs[k] for k in embeddings_info if k in kwargs})
        # TODO: convert labels into just another vocab and pass number of labels to models.
        labels = read_json("{}.labels".format(basename))
        model = cls.create(embeddings, labels, **_state)
        model._state = _state
        model.load_weights(f"{basename}.wgt")
        return model
コード例 #34
0
 def load(cls, basename, **kwargs):
     """Reload a TensorFlow model from its state file and checkpoint.

     This variant reads no ``.labels`` file.

     :param basename: Path prefix of the saved model files.
     :param kwargs:
         * ``sess`` -- session to use; ``create_session()`` is called only when
           this key is absent.
         * ``init`` -- forwarded to ``cls.create`` (default True).
         * any key naming a saved embedding is copied into the state.
     :return: The restored model.
     """
     _state = read_json("{}.state".format(basename))
     if __version__ != _state['version']:
         bl_logger.warning("Loaded model is from baseline version %s, running version is %s", _state['version'], __version__)
     # Create a session only when none was passed. A default given to ``pop``
     # is evaluated eagerly, which would build a throwaway session on every call.
     _state['sess'] = kwargs.pop('sess') if 'sess' in kwargs else create_session()
     with _state['sess'].graph.as_default():
         embeddings_info = _state.pop('embeddings')
         embeddings = reload_embeddings(embeddings_info, basename)
         for k in embeddings_info:
             if k in kwargs:
                 _state[k] = kwargs[k]
         model = cls.create(embeddings, init=kwargs.get('init', True), **_state)
         model._state = _state
         model.saver = tf.train.Saver()
         model.saver.restore(model.sess, basename)
     return model
コード例 #35
0
ファイル: exporters.py プロジェクト: dpressel/baseline
 def _create_model(self, sess, basename, **kwargs):
     """Load a tagger and expose its best path plus the top softmax score.

     :param sess: Session the tagger is loaded into.
     :param basename: Path prefix of the saved tagger.
     :param kwargs: Forwarded to ``load_tagger_model``.
     :return: ``(model, outputs, values)``. ``values`` is the top-1 softmax
         score of ``model.probs``. ``outputs`` is ``model.best``, mapped to
         label strings when ``self.return_labels`` is set.
     """
     model = load_tagger_model(basename, sess=sess, **kwargs)
     values, _ = tf.nn.top_k(tf.nn.softmax(model.probs), 1)
     best_indices = model.best
     if not self.return_labels:
         return model, best_indices, values

     # Invert the label -> index map from ``.labels`` into an index -> label list.
     label_to_index = read_json(basename + '.labels')
     index_to_label = [''] * len(label_to_index)
     for label, position in label_to_index.items():
         index_to_label[position] = label
     table = tf.contrib.lookup.index_to_string_table_from_tensor(tf.constant(index_to_label))
     return model, table.lookup(tf.to_int64(best_indices)), values
コード例 #36
0
ファイル: exporters.py プロジェクト: kiennguyen94/baseline
 def _create_model(self, sess, basename, **kwargs):
     """Build tagger export outputs: best path (optionally as labels) and score.

     :param sess: Session the tagger is loaded into.
     :param basename: Path prefix of the saved tagger.
     :param kwargs: Forwarded to ``load_tagger_model``.
     :return: ``(model, outputs, values)``. ``outputs`` is ``model.best``,
         converted to strings via ``{basename}.labels`` when
         ``self.return_labels`` is true.
     """
     model = load_tagger_model(basename, sess=sess, **kwargs)
     probabilities = tf.nn.softmax(model.probs)
     values, _ = tf.nn.top_k(probabilities, 1)
     outputs = model.best
     if self.return_labels:
         vocab = read_json(basename + '.labels')
         names = ['' for _ in range(len(vocab))]
         for label_name, slot in vocab.items():
             names[slot] = label_name
         lookup_table = tf.contrib.lookup.index_to_string_table_from_tensor(tf.constant(names))
         outputs = lookup_table.lookup(tf.to_int64(outputs))
     return model, outputs, values
コード例 #37
0
ファイル: embed_elmo_tf.py プロジェクト: mead-ml/hub
    def __init__(self, name, embed_file=None, known_vocab=None, **kwargs):
        """Build an ELMo embedding layer from pretrained weights.

        :param name: Layer name.
        :param embed_file: Path to the ELMo weight file.
        :param known_vocab: Kept on ``self.known_vocab`` and used to build the
            character vocabulary.
        :param kwargs: Forwarded to the parent constructor. ``options`` may
            hold the ELMo options dict. Without it, the options are read from
            the JSON next to ``embed_file`` (``weights.hdf5`` ->
            ``options.json``). ``dsz`` defaults to, and is asserted to equal,
            twice the LSTM ``projection_dim``.
        """
        import logging
        logger = logging.getLogger(__name__)

        super().__init__(trainable=True, name=name, dtype=tf.float32, **kwargs)
        self.weight_file = embed_file
        if 'options' not in kwargs:
            # Legacy layout: the options JSON sits beside the weight file.
            options_file = embed_file.replace('weights.hdf5', 'options.json')
            logger.warning('Old style ELMo configuration: reading options from %s', options_file)
            elmo_config = read_json(options_file)
        else:
            elmo_config = kwargs['options']
            # Debug detail only; printing the whole config on every build is noise.
            logger.debug('ELMo options: %s', elmo_config)
        self.dsz = kwargs.get('dsz', 2*int(elmo_config['lstm']['projection_dim']))
        self.model = BidirectionalLanguageModel(elmo_config, self.weight_file)
        self.known_vocab = known_vocab
        self.vocab = UnicodeCharsVocabulary(known_vocab)

        assert self.dsz == 2*int(elmo_config['lstm']['projection_dim'])
コード例 #38
0
 def __init__(self, logger_file, mead_config):
     """Load MEAD settings, make sure a data cache is configured, set up logging.

     :param logger_file: Passed to ``self._configure_logger``.
     :param mead_config: Path to the MEAD settings JSON. If the file does not
         exist, empty settings are assumed. When the settings have no
         ``datacache`` entry, ``~/.bl-data`` is used and written back to this
         file.
     """
     super(Task, self).__init__()
     self.config_params = None
     self.ExporterType = None
     self.mead_config = mead_config
     mead_settings = read_json(mead_config) if os.path.exists(mead_config) else {}
     if 'datacache' in mead_settings:
         self.data_download_cache = os.path.expanduser(mead_settings['datacache'])
     else:
         self.data_download_cache = os.path.expanduser("~/.bl-data")
         # Persist the default so later runs pick up the same cache location.
         mead_settings['datacache'] = self.data_download_cache
         write_json(mead_settings, mead_config)
     print("using {} as data/embeddings cache".format(self.data_download_cache))
     self._configure_logger(logger_file)
コード例 #39
0
ファイル: cli.py プロジェクト: dpressel/baseline
def read_cred(config_file):
    """Read database credentials from a JSON config file.

    :param config_file: Path to a JSON file that may contain ``dbtype``,
        ``dbhost``, ``dbport``, ``user`` and ``passwd`` keys.
    :return: ``(dbtype, dbhost, dbport, user, passwd)``. A missing key gives
        ``None``; if reading the file raises ``IOError``, every field is
        ``None``.
    """
    fields = ('dbtype', 'dbhost', 'dbport', 'user', 'passwd')
    try:
        j = read_json(config_file, strict=True)
    except IOError:
        # Best effort: no readable config means no credentials.
        return tuple(None for _ in fields)
    return tuple(j.get(field) for field in fields)
コード例 #40
0
def read_cred(config_file):
    """Return ``(dbtype, dbhost, dbport, user, passwd)`` from a JSON file.

    An unreadable or missing file (``IOError`` from the strict read) is
    treated as an empty config, so every field comes back as ``None``.
    """
    try:
        creds = read_json(config_file, strict=True)
    except IOError:
        creds = {}
    return (
        creds.get('dbtype'),
        creds.get('dbhost'),
        creds.get('dbport'),
        creds.get('user'),
        creds.get('passwd'),
    )
コード例 #41
0
ファイル: downloader.py プロジェクト: dpressel/baseline
 def download(self):
     """Return a local path for ``self.dataset_file``, downloading if needed.

     A file that already passes ``is_file_correct`` is returned as-is. A URL
     is served from the JSON cache index in ``self.data_download_cache`` when
     its entry is valid and ``self.cache_ignore`` is false. Otherwise the URL
     is downloaded, extracted into the cache, and recorded in the index.

     :raises RuntimeError: if the file is neither present nor a valid URL.
     """
     file_loc = self.dataset_file
     if is_file_correct(file_loc):
         return file_loc
     if not validate_url(file_loc):
         raise RuntimeError("the file [{}] is not in cache and can not be downloaded".format(file_loc))

     url = file_loc
     cache_dir = self.data_download_cache
     dcache_path = os.path.join(cache_dir, DATA_CACHE_CONF)
     dcache = read_json(dcache_path)
     if url in dcache and is_file_correct(dcache[url], cache_dir, url) and not self.cache_ignore:
         logger.info("file for {} found in cache, not downloading".format(url))
         return dcache[url]

     # Cache miss (or cache ignored): fetch, extract, and record the result.
     logger.info("using {} as data/embeddings cache".format(cache_dir))
     temp_file = web_downloader(url)
     dload_file = extractor(filepath=temp_file, cache_dir=cache_dir,
                            extractor_func=Downloader.ZIPD.get(mime_type(temp_file), None))
     dcache.update({url: dload_file})
     write_json(dcache, dcache_path)
     return dload_file
コード例 #42
0
    def init_decode(self, **kwargs):
        """Build a decoder seeded with pre-computed label transition scores.

        :param kwargs:
            * ``label_vocab`` -- path to a JSON label -> index map that indexes
              ``label_trans``.
            * ``label_trans`` -- path loaded with ``np.load``; indexed as
              ``label_trans[src, tgt]`` using ``label_vocab`` indices.
            * ``constraint_mask`` -- object with ``unsqueeze``; stored on
              ``self.constraint_mask`` with a leading dimension added.
        :return: A ``PremadeTransitionsTagger`` over ``len(self.labels)``
            labels. Transitions between labels missing from ``label_vocab``
            stay 0.
        """
        label_vocab = read_json(kwargs["label_vocab"])
        label_trans = np.load(kwargs["label_trans"])

        # Labels known to both vocabularies, as (our index, external index).
        # Filtering once avoids re-testing membership inside the nested loop.
        shared = [
            (idx, label_vocab[label])
            for label, idx in self.labels.items()
            if label in label_vocab
        ]

        trans = np.zeros((len(self.labels), len(self.labels)))
        for src_idx, src_ext in shared:
            for tgt_idx, tgt_ext in shared:
                trans[src_idx, tgt_idx] = label_trans[src_ext, tgt_ext]

        self.constraint_mask = kwargs.get("constraint_mask").unsqueeze(0)
        return PremadeTransitionsTagger(len(self.labels),
                                        self.constraint_mask,
                                        transitions=trans)
コード例 #43
0
    def load(cls, basename, **kwargs):
        """Reload a seq2seq model from its state, embeddings and ``.wgt`` weights.

        :param basename: Path prefix of the saved model files.
        :param kwargs: ``predict`` and ``beam`` override the saved state. Any
            key naming a source embedding is also copied into the state.
        :return: The restored model.
        """
        _state = read_json('{}.state'.format(basename))
        saved_version = _state['version']
        if saved_version != __version__:
            logger.warning("Loaded model is from baseline version %s, running version is %s", saved_version, __version__)
        for override in ('predict', 'beam'):
            if override in kwargs:
                _state[override] = kwargs[override]

        src_embeddings_info = _state.pop('src_embeddings')
        src_embeddings = reload_embeddings(src_embeddings_info, basename)
        _state.update({k: kwargs[k] for k in src_embeddings_info if k in kwargs})
        # The target side is saved under the single key 'tgt'.
        tgt_embedding = reload_embeddings(_state.pop('tgt_embedding'), basename)['tgt']

        model = cls.create(src_embeddings, tgt_embedding, **_state)
        model._state = _state
        model.load_weights(f"{basename}.wgt")
        return model
コード例 #44
0
def peek_lengths_key(model_file, field_map):
    """Return the field to use for lengths for the model at ``model_file``.

    Reads ``lengths_key`` from ``{model_file}.state``; an absent or ``None``
    value falls back to ``'tokens'``. A trailing ``_lengths`` suffix is
    stripped, and the result is mapped through ``field_map`` when it has an
    entry.

    :param model_file: Path prefix of the saved model.
    :param field_map: Mapping from model field names to data field names.
    :return: The (possibly remapped) field name.
    """
    peek_state = read_json(model_file + ".state")
    lengths = peek_state.get('lengths_key')
    # A stored null would otherwise crash on the string operations below.
    if lengths is None:
        lengths = 'tokens'
    # Strip only the suffix; ``str.replace`` would also remove occurrences
    # in the middle of the key.
    suffix = '_lengths'
    if lengths.endswith(suffix):
        lengths = lengths[:-len(suffix)]
    return field_map.get(lengths, lengths)
コード例 #45
0
ファイル: downloader.py プロジェクト: dpressel/baseline
def update_cache(key, data_download_cache):
    """Drop ``key`` from the download cache index, if it is present.

    The index is the JSON file ``DATA_CACHE_CONF`` inside
    ``data_download_cache``. It is rewritten only when ``key`` was removed.
    """
    cache_conf = os.path.join(data_download_cache, DATA_CACHE_CONF)
    dcache = read_json(cache_conf)
    if key in dcache:
        del dcache[key]
        write_json(dcache, cache_conf)
コード例 #46
0
ファイル: test_read_files.py プロジェクト: dpressel/baseline
def test_read_json_strict():
    """With ``strict=True``, a missing file raises ``IOError``."""
    # Resolve under the test data directory like the sibling tests do, so the
    # result does not depend on the current working directory.
    with pytest.raises(IOError):
        read_json(os.path.join(data_loc, 'not_there.json'), strict=True)
コード例 #47
0
ファイル: test_read_files.py プロジェクト: dpressel/baseline
def test_read_json_given_default():
    """A missing file makes ``read_json`` return the caller-supplied default."""
    expected = 'default'
    assert read_json(os.path.join(data_loc, 'not_there.json'), expected) == expected
コード例 #48
0
ファイル: test_read_files.py プロジェクト: dpressel/baseline
def test_read_json_default_value():
    """Without a default, a missing file reads as an empty dict."""
    missing = os.path.join(data_loc, 'not_there.json')
    assert read_json(missing) == {}
コード例 #49
0
ファイル: test_read_files.py プロジェクト: dpressel/baseline
def test_read_json(gold_data):
    """``read_json`` parses the fixture file into ``gold_data``."""
    fixture = os.path.join(data_loc, 'test_json.json')
    assert read_json(fixture) == gold_data
コード例 #50
0
ファイル: download_all.py プロジェクト: dpressel/baseline
import os
import argparse
from baseline.utils import read_json
from mead.utils import index_by_label, convert_path
from mead.downloader import EmbeddingDownloader, DataDownloader


# Command-line options: the download cache location and the JSON index files
# listing datasets and embeddings to fetch.
parser = argparse.ArgumentParser(description="Download all data and embeddings.")
parser.add_argument("--cache", default="~/.bl-data", type=os.path.expanduser, help="Location of the data cache")
parser.add_argument('--datasets', help='json library of dataset labels', default='config/datasets.json', type=convert_path)
parser.add_argument('--embeddings', help='json library of embeddings', default='config/embeddings.json', type=convert_path)
args = parser.parse_args()


# NOTE(review): index_by_label presumably turns the JSON list into a mapping
# keyed by each entry's label -- confirm in mead.utils.
datasets = read_json(args.datasets)
datasets = index_by_label(datasets)

# Download every dataset. Any exception is printed and skipped so one bad
# entry does not stop the rest.
for name, d in datasets.items():
    print(name)
    try:
        DataDownloader(d, args.cache).download()
    except Exception as e:
        print(e)


emb = read_json(args.embeddings)
emb = index_by_label(emb)

for name, e in emb.items():
    print(name)
    try: