def load(cls, basename, **kwargs):
    K.clear_session()
    model = cls()
    model.impl = keras.models.load_model(basename)
    state = read_json(basename + '.state')
    for prop in ls_props(model):
        if prop in state:
            setattr(model, prop, state[prop])
    inputs = {v.name[:v.name.find(':')]: v for v in model.impl.inputs}
    model.embeddings = dict()
    for key, class_name in state['embeddings'].items():
        md = read_json('{}-{}-md.json'.format(basename, key))
        embed_args = {'vsz': md['vsz'], 'dsz': md['dsz']}
        embed_args[key] = inputs[key]
        Constructor = eval(class_name)
        model.embeddings[key] = Constructor(key, **embed_args)
    # model.lengths_key = state.get('lengths_key')
    with open(basename + '.labels', 'r') as f:
        model.labels = json.load(f)
    return model

def _create_remote_model(directory, backend, remote, name, signature_name, beam, preproc='client'):
    """Reads the necessary information from the remote bundle to instantiate a client for a remote model.

    :directory the location of the exported model bundle
    :remote a url endpoint to hit
    :name the model name, as defined in tf-serving's model.config
    :signature_name the signature to use
    :beam used for s2s and found in the kwargs; we default this and pass it in
    :returns a RemoteModel
    """
    assets = read_json(os.path.join(directory, 'model.assets'))
    model_name = assets['metadata']['exported_model']
    labels = read_json(os.path.join(directory, model_name) + '.labels')
    lengths_key = assets.get('lengths_key', None)
    inputs = assets.get('inputs', [])

    if backend == 'tf':
        remote_models = import_user_module('baseline.remote')
        if remote.startswith('http'):
            RemoteModel = remote_models.RemoteModelTensorFlowREST
        elif preproc == 'server':
            RemoteModel = remote_models.RemoteModelTensorFlowGRPCPreproc
        else:
            RemoteModel = remote_models.RemoteModelTensorFlowGRPC
        model = RemoteModel(remote, name, signature_name, labels=labels,
                            lengths_key=lengths_key, inputs=inputs, beam=beam)
    else:
        raise ValueError("only TensorFlow is currently supported for remote services")
    return model

def _create_model(self, sess, basename):
    labels = read_json(basename + '.labels')
    model_params = self.task.config_params["model"]
    model_params["sess"] = sess
    state = read_json(basename + '.state')
    if state.get('constrain_decode', False):
        constraint = transition_mask(
            labels, self.task.config_params['train']['span_type'], Offsets.GO, Offsets.EOS, Offsets.PAD)
        model_params['constraint'] = constraint
    # Re-create the embeddings sub-graph
    embeddings = dict()
    for key, class_name in state['embeddings'].items():
        md = read_json('{}-{}-md.json'.format(basename, key))
        embed_args = {'vsz': md['vsz'], 'dsz': md['dsz']}
        Constructor = eval(class_name)
        embeddings[key] = Constructor(key, **embed_args)
    model = baseline.model.create_model_for(self.task.task_name(), embeddings, labels, **model_params)
    for prop in ls_props(model):
        if prop in state:
            setattr(model, prop, state[prop])
    model.create_loss()
    softmax_output = tf.nn.softmax(model.probs)
    values, indices = tf.nn.top_k(softmax_output, 1)
    start_np = np.full((1, 1, len(labels)), -1e4, dtype=np.float32)
    start_np[:, 0, Offsets.GO] = 0
    start = tf.constant(start_np)
    start = tf.tile(start, [tf.shape(model.probs)[0], 1, 1])
    model.probs = tf.concat([start, model.probs], 1)
    ones = tf.fill(tf.shape(model.lengths), 1)
    lengths = tf.add(model.lengths, ones)
    if model.crf is True:
        indices, _ = tf.contrib.crf.crf_decode(model.probs, model.A, lengths)
        indices = indices[:, 1:]
    list_of_labels = [''] * len(labels)
    for label, idval in labels.items():
        list_of_labels[idval] = label
    class_tensor = tf.constant(list_of_labels)
    table = tf.contrib.lookup.index_to_string_table_from_tensor(class_tensor)
    classes = table.lookup(tf.to_int64(indices))
    self._restore_checkpoint(sess, basename)
    return model, indices, values

def load(cls, basename, **kwargs):
    """Reload the model from a graph file and a checkpoint

    The model that is loaded is independent of the pooling and stacking layers,
    making this class reusable by sub-classes.

    :param basename: The base directory to load from
    :param kwargs: See below

    :Keyword Arguments:
    * *sess* -- An optional tensorflow session.  If not passed, a new session is created

    :return: A restored model
    """
    _state = read_json("{}.state".format(basename))
    # FIXME: Somehow not writing this anymore
    # if __version__ != _state['version']:
    #     logger.warning("Loaded model is from baseline version %s, running version is %s", _state['version'], __version__)
    if not tf.executing_eagerly():
        _state['sess'] = kwargs.pop('sess', create_session())
        embeddings_info = _state.pop("embeddings")

        with _state['sess'].graph.as_default():
            embeddings = reload_embeddings(embeddings_info, basename)
            for k in embeddings_info:
                if k in kwargs:
                    _state[k] = kwargs[k]
            labels = read_json("{}.labels".format(basename))
            # FIXME: referring to the `constraint_mask` in the base where it's not mentioned isn't really clean
            if _state.get('constraint_mask') is not None:
                # Dummy constraint values that will be filled in by the checkpointing
                _state['constraint_mask'] = [np.zeros((len(labels), len(labels))) for _ in range(2)]
            model = cls.create(embeddings, labels, **_state)
            model._state = _state
            model.create_loss()
            if kwargs.get('init', True):
                model.sess.run(tf.compat.v1.global_variables_initializer())
            model.saver = tf.compat.v1.train.Saver()
            model.saver.restore(model.sess, basename)
    else:
        embeddings_info = _state.pop('embeddings')
        embeddings = reload_embeddings(embeddings_info, basename)
        for k in embeddings_info:
            if k in kwargs:
                _state[k] = kwargs[k]
        # TODO: convert labels into just another vocab and pass number of labels to models.
        labels = read_json("{}.labels".format(basename))
        # FIXME: referring to the `constraint_mask` in the base where it's not mentioned isn't really clean
        if _state.get('constraint_mask') is not None:
            # Dummy constraint values that will be filled in by the checkpointing
            _state['constraint_mask'] = [np.zeros((len(labels), len(labels))) for _ in range(2)]
        model = cls.create(embeddings, labels, **_state)
        model._state = _state
        model.load_weights(f"{basename}.wgt")
    return model

def load(cls, basename: str, **kwargs) -> 'ClassifierModelBase':
    """Reload the model from a graph file and a checkpoint

    The model that is loaded is independent of the pooling and stacking layers,
    making this class reusable by sub-classes.

    :param basename: The base directory to load from
    :param kwargs: See below

    :Keyword Arguments:
    * *sess* -- An optional tensorflow session.  If not passed, a new session is created

    :return: A restored model
    """
    _state = read_json("{}.state".format(basename))
    if __version__ != _state['version']:
        logger.warning("Loaded model is from baseline version %s, running version is %s", _state['version'], __version__)
    if not tf.executing_eagerly():
        _state['sess'] = kwargs.pop('sess', create_session())
        with _state['sess'].graph.as_default():
            embeddings_info = _state.pop('embeddings')
            embeddings = reload_embeddings(embeddings_info, basename)
            # If there is a kwarg with the same name as an embedding object, it
            # is taken to be the input of that layer. This allows for passing in
            # subgraphs like from a tf.split (for data parallel) or preprocessing
            # graphs that convert text to indices
            for k in embeddings_info:
                if k in kwargs:
                    _state[k] = kwargs[k]
            labels = read_json("{}.labels".format(basename))
            model = cls.create(embeddings, labels, **_state)
            model._state = _state
            if kwargs.get('init', True):
                model.sess.run(tf.compat.v1.global_variables_initializer())
            model.saver = tf.compat.v1.train.Saver()
            model.saver.restore(model.sess, basename)
    else:
        embeddings_info = _state.pop('embeddings')
        embeddings = reload_embeddings(embeddings_info, basename)
        # If there is a kwarg with the same name as an embedding object, it
        # is taken to be the input of that layer. This allows for passing in
        # subgraphs like from a tf.split (for data parallel) or preprocessing
        # graphs that convert text to indices
        for k in embeddings_info:
            if k in kwargs:
                _state[k] = kwargs[k]
        # TODO: convert labels into just another vocab and pass number of labels to models.
        labels = read_json("{}.labels".format(basename))
        model = cls.create(embeddings, labels, **_state)
        model._state = _state
        model.load_weights(f"{basename}.wgt")
    return model

def load(cls, basename, **kwargs):
    """Reload the model from a graph file and a checkpoint

    The model that is loaded is independent of the pooling and stacking layers,
    making this class reusable by sub-classes.

    :param basename: The base directory to load from
    :param kwargs: See below

    :Keyword Arguments:
    * *sess* -- An optional tensorflow session.  If not passed, a new session is created

    :return: A restored model
    """
    _state = read_json(basename + '.state')
    if not tf.executing_eagerly():
        _state['sess'] = kwargs.pop('sess', create_session())
        embeddings_info = _state.pop("embeddings")
        with _state['sess'].graph.as_default():
            embeddings = reload_embeddings(embeddings_info, basename)
            for k in embeddings_info:
                if k in kwargs:
                    _state[k] = kwargs[k]
            _state['model_type'] = kwargs.get('model_type', 'default')
            model = cls.create(embeddings, **_state)
            model._state = _state
            do_init = kwargs.get('init', True)
            if do_init:
                init = tf.compat.v1.global_variables_initializer()
                model.sess.run(init)
            model.saver = tf.compat.v1.train.Saver()
            model.saver.restore(model.sess, basename)
    else:
        _state = read_json(basename + '.state')
        _state['model_type'] = kwargs.get('model_type', 'default')
        embeddings = {}
        embeddings_dict = _state.pop("embeddings")
        for key, class_name in embeddings_dict.items():
            md = read_json('{}-{}-md.json'.format(basename, key))
            embed_args = {'vsz': md['vsz'], 'dsz': md['dsz']}
            Constructor = eval(class_name)
            embeddings[key] = Constructor(key, **embed_args)
        model = cls.create(embeddings, **_state)
        model._state = _state
        model.load_weights(f"{basename}.wgt")
    return model

def download(self):
    if is_file_correct(self.embedding_file):
        logger.info("embedding file location: {}".format(self.embedding_file))
        return self.embedding_file
    dcache_path = os.path.join(self.data_download_cache, DATA_CACHE_CONF)
    dcache = read_json(dcache_path)
    if self.embedding_file in dcache and not self.cache_ignore:
        download_loc = dcache[self.embedding_file]
        logger.info("files for {} found in cache".format(self.embedding_file))
        return self._get_embedding_file(download_loc, self.embedding_key)
    else:  # try to download the bundle and unzip
        url = self.embedding_file
        if not validate_url(url):
            raise RuntimeError("cannot download from the given url")
        else:
            cache_dir = self.data_download_cache
            temp_file = web_downloader(url)
            download_loc = extractor(filepath=temp_file, cache_dir=cache_dir,
                                     extractor_func=Downloader.ZIPD.get(mime_type(temp_file), None))
            if self.sha1 is not None:
                if os.path.split(download_loc)[-1] != self.sha1:
                    raise RuntimeError("The sha1 of the downloaded file does not match the provided one")
            dcache.update({url: download_loc})
            write_json(dcache, os.path.join(self.data_download_cache, DATA_CACHE_CONF))
            return self._get_embedding_file(download_loc, self.embedding_key)

def download(self):
    dload_bundle = self.dataset_desc.get("download", None)
    if dload_bundle is not None:
        # download a zip/tar/tar.gz directory, look for train, dev, test files inside it
        dcache_path = os.path.join(self.data_download_cache, DATA_CACHE_CONF)
        dcache = read_json(dcache_path)
        if dload_bundle in dcache and \
                is_dir_correct(dcache[dload_bundle], self.dataset_desc, self.data_download_cache,
                               dload_bundle, self.enc_dec) and not self.cache_ignore:
            download_dir = dcache[dload_bundle]
            logger.info("files for {} found in cache, not downloading".format(dload_bundle))
            return {k: os.path.join(download_dir, self.dataset_desc[k])
                    for k in self.dataset_desc if k.endswith("_file")}
        else:  # try to download the bundle and unzip
            if not validate_url(dload_bundle):
                raise RuntimeError("cannot download from the given url")
            else:
                cache_dir = self.data_download_cache
                temp_file = web_downloader(dload_bundle)
                download_dir = extractor(filepath=temp_file, cache_dir=cache_dir,
                                         extractor_func=Downloader.ZIPD.get(mime_type(temp_file), None))
                if "sha1" in self.dataset_desc:
                    if os.path.split(download_dir)[-1] != self.dataset_desc["sha1"]:
                        raise RuntimeError("The sha1 of the downloaded file does not match the provided one")
                dcache.update({dload_bundle: download_dir})
                write_json(dcache, os.path.join(self.data_download_cache, DATA_CACHE_CONF))
                return {k: os.path.join(download_dir, self.dataset_desc[k])
                        for k in self.dataset_desc if k.endswith("_file")}
    else:  # we have download links to every file or they exist
        if not self.enc_dec:
            return {k: SingleFileDownloader(self.dataset_desc[k], self.data_download_cache).download()
                    for k in self.dataset_desc if k.endswith("_file")}
        else:
            return {k: self.dataset_desc[k] for k in self.dataset_desc if k.endswith("_file")}

def load(cls, basename, **kwargs):
    _state = read_json('{}.state'.format(basename))
    if __version__ != _state['version']:
        logger.warning("Loaded model is from baseline version %s, running version is %s", _state['version'], __version__)
    if 'predict' in kwargs:
        _state['predict'] = kwargs['predict']
    if 'beam' in kwargs:
        _state['beam'] = kwargs['beam']
    _state['sess'] = kwargs.get('sess', create_session())

    with _state['sess'].graph.as_default():
        src_embeddings_info = _state.pop('src_embeddings')
        src_embeddings = reload_embeddings(src_embeddings_info, basename)
        for k in src_embeddings_info:
            if k in kwargs:
                _state[k] = kwargs[k]
        tgt_embedding_info = _state.pop('tgt_embedding')
        tgt_embedding = reload_embeddings(tgt_embedding_info, basename)['tgt']
        model = cls.create(src_embeddings, tgt_embedding, **_state)
        model._state = _state
        if kwargs.get('init', True):
            model.sess.run(tf.compat.v1.global_variables_initializer())
        model.saver = tf.compat.v1.train.Saver()
        # reload_checkpoint(model.sess, basename, ['OptimizeLoss/'])
        model.saver.restore(model.sess, basename)
    return model

def load(cls, basename, **kwargs):
    _state = read_json('{}.state'.format(basename))
    if __version__ != _state['version']:
        logger.warning("Loaded model is from baseline version %s, running version is %s", _state['version'], __version__)
    if 'predict' in kwargs:
        _state['predict'] = kwargs['predict']
    if 'sampling' in kwargs:
        _state['sampling'] = kwargs['sampling']
    if 'sampling_temp' in kwargs:
        _state['sampling_temp'] = kwargs['sampling_temp']
    if 'beam' in kwargs:
        _state['beam'] = kwargs['beam']
    _state['sess'] = kwargs.get('sess', tf.Session())

    with _state['sess'].graph.as_default():
        embeddings_info = _state.pop('embeddings')
        embeddings = reload_embeddings(embeddings_info, basename)
        for k in embeddings_info:
            if k in kwargs:
                _state[k] = kwargs[k]
        model = cls.create(embeddings, **_state)
        if kwargs.get('init', True):
            model.sess.run(tf.global_variables_initializer())
        model.saver = tf.train.Saver()
        model.saver.restore(model.sess, basename)
    return model

def download(self):
    file_loc = self.dataset_file
    if is_file_correct(file_loc):
        return file_loc
    elif validate_url(file_loc):  # is it a web URL? check if exists in cache
        url = file_loc
        dcache_path = os.path.join(self.data_download_cache, DATA_CACHE_CONF)
        dcache = read_json(dcache_path)
        if url in dcache and is_file_correct(dcache[url], self.data_download_cache, url) and not self.cache_ignore:
            print("file for {} found in cache, not downloading".format(url))
            return dcache[url]
        else:  # download the file in the cache, update the json
            cache_dir = self.data_download_cache
            print("using {} as data/embeddings cache".format(cache_dir))
            temp_file = web_downloader(url)
            dload_file = extractor(filepath=temp_file, cache_dir=cache_dir,
                                   extractor_func=Downloader.ZIPD.get(mime_type(temp_file), None))
            dcache.update({url: dload_file})
            write_json(dcache, os.path.join(self.data_download_cache, DATA_CACHE_CONF))
            return dload_file
    raise RuntimeError("the file [{}] is not in cache and cannot be downloaded".format(file_loc))

def load(cls, basename, **kwargs):
    _state = read_json('{}.state'.format(basename))
    if __version__ != _state['version']:
        logger.warning("Loaded model is from baseline version %s, running version is %s", _state['version'], __version__)
    if 'predict' in kwargs:
        _state['predict'] = kwargs['predict']
    if 'sampling' in kwargs:
        _state['sampling'] = kwargs['sampling']
    if 'sampling_temp' in kwargs:
        _state['sampling_temp'] = kwargs['sampling_temp']
    if 'beam' in kwargs:
        _state['beam'] = kwargs['beam']
    _state['sess'] = kwargs.get('sess', tf.Session())
    embeddings_info = _state.pop('embeddings')
    embeddings = reload_embeddings(embeddings_info, basename)
    for k in embeddings_info:
        if k in kwargs:
            _state[k] = kwargs[k]
    model = cls.create(embeddings, **_state)
    if kwargs.get('init', True):
        model.sess.run(tf.global_variables_initializer())
    model.saver = tf.train.Saver()
    model.saver.restore(model.sess, basename)
    return model

def download(self):
    dload_bundle = self.dataset_desc.get("download", None)
    if dload_bundle is not None:
        # download a zip/tar/tar.gz directory, look for train, dev, test files inside it
        dcache_path = os.path.join(self.data_download_cache, DATA_CACHE_CONF)
        dcache = read_json(dcache_path)
        if dload_bundle in dcache and \
                is_dir_correct(dcache[dload_bundle], self.dataset_desc, self.data_download_cache,
                               dload_bundle, self.enc_dec) and not self.cache_ignore:
            download_dir = dcache[dload_bundle]
            logger.info("files for {} found in cache, not downloading".format(dload_bundle))
            return {k: os.path.join(download_dir, self.dataset_desc[k])
                    for k in self.dataset_desc if k.endswith("_file")}
        else:  # try to download the bundle and unzip
            if not validate_url(dload_bundle):
                raise RuntimeError("cannot download from the given url")
            else:
                cache_dir = self.data_download_cache
                temp_file = web_downloader(dload_bundle)
                download_dir = extractor(filepath=temp_file, cache_dir=cache_dir,
                                         extractor_func=Downloader.ZIPD.get(mime_type(temp_file), None))
                if "sha1" in self.dataset_desc:
                    if os.path.split(download_dir)[-1] != self.dataset_desc["sha1"]:
                        raise RuntimeError("The sha1 of the downloaded file does not match the provided one")
                dcache.update({dload_bundle: download_dir})
                write_json(dcache, os.path.join(self.data_download_cache, DATA_CACHE_CONF))
                return {k: os.path.join(download_dir, self.dataset_desc[k])
                        for k in self.dataset_desc if k.endswith("_file")}
    else:  # we have download links to every file or they exist
        if not self.enc_dec:
            return {k: SingleFileDownloader(self.dataset_desc[k], self.data_download_cache).download()
                    for k in self.dataset_desc if k.endswith("_file") and self.dataset_desc[k]}
        else:
            return {k: self.dataset_desc[k] for k in self.dataset_desc if k.endswith("_file")}

def load(cls, bundle, **kwargs):
    """Load a model from a bundle.

    This can be either a local model or a remote, exported model.

    :returns a Service implementation
    """
    import onnxruntime as ort
    # can delegate
    if os.path.isdir(bundle):
        directory = bundle
    # Try and unzip if it's a zip file
    else:
        directory = unzip_files(bundle)

    model_basename = find_model_basename(directory)
    # model_basename = model_basename.replace(".pyt", "")
    model_name = f"{model_basename}.onnx"

    vocabs = load_vocabs(directory)
    vectorizers = load_vectorizers(directory)

    # Currently nothing to do here
    labels = read_json(model_basename + '.labels')

    model = ort.InferenceSession(model_name)
    return cls(vocabs, vectorizers, model, labels)

def _create_remote_model(directory, backend, remote, name, task_name, signature_name, beam, **kwargs):
    """Reads the necessary information from the remote bundle to instantiate a client for a remote model.

    :directory the location of the exported model bundle
    :remote a url endpoint to hit
    :name the model name, as defined in tf-serving's model.config
    :signature_name the signature to use
    :beam used for s2s and found in the kwargs; we default this and pass it in
    :returns a RemoteModel
    """
    from baseline.remote import create_remote
    assets = read_json(os.path.join(directory, 'model.assets'))
    model_name = assets['metadata']['exported_model']
    preproc = assets['metadata'].get('preproc', kwargs.get('preproc', 'client'))
    labels = read_json(os.path.join(directory, model_name) + '.labels')
    lengths_key = assets.get('lengths_key', None)
    inputs = assets.get('inputs', [])
    return_labels = bool(assets['metadata']['return_labels'])
    version = kwargs.get('version')
    if backend not in {'tf', 'onnx'}:
        raise ValueError(f"Unsupported backend {backend} for remote Services")
    import_user_module('baseline.{}.remote'.format(backend))
    exp_type = kwargs.get('remote_type')
    if exp_type is None:
        exp_type = 'http' if remote.startswith('http') else 'grpc'
        exp_type = '{}-preproc'.format(exp_type) if preproc == 'server' else exp_type
        exp_type = f'{exp_type}-{task_name}'
    model = create_remote(
        exp_type,
        remote=remote,
        name=name,
        signature=signature_name,
        labels=labels,
        lengths_key=lengths_key,
        inputs=inputs,
        beam=beam,
        return_labels=return_labels,
        version=version,
    )
    return model, preproc

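# Worked example of the exp_type routing above, under assumed inputs (the values
# here are illustrative, not read from a real bundle): with an HTTP endpoint,
# server-side preproc, and no explicit remote_type kwarg, the registry key
# cascades 'http' -> 'http-preproc' -> 'http-preproc-tagger'.
remote, preproc, task_name = 'http://localhost:8501', 'server', 'tagger'
exp_type = None  # i.e., no 'remote_type' kwarg was passed
if exp_type is None:
    exp_type = 'http' if remote.startswith('http') else 'grpc'
    exp_type = '{}-preproc'.format(exp_type) if preproc == 'server' else exp_type
    exp_type = f'{exp_type}-{task_name}'
assert exp_type == 'http-preproc-tagger'
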
def init_embeddings(self, embeddings_map, basename):
    embeddings = dict()
    for key, class_name in embeddings_map:
        md = read_json('{}-{}-md.json'.format(basename, key))
        embed_args = {'vsz': md['vsz'], 'dsz': md['dsz']}
        Constructor = eval(class_name)
        embeddings[key] = Constructor(key, **embed_args)
    return embeddings

def load(cls, basename, **kwargs):
    """Reload the model from a graph file and a checkpoint

    The model that is loaded is independent of the pooling and stacking layers,
    making this class reusable by sub-classes.

    :param basename: The base directory to load from
    :param kwargs: See below

    :Keyword Arguments:
    * *sess* -- An optional tensorflow session.  If not passed, a new session is created

    :return: A restored model
    """
    _state = read_json("{}.state".format(basename))
    if __version__ != _state['version']:
        logger.warning("Loaded model is from baseline version %s, running version is %s", _state['version'], __version__)
    _state['sess'] = kwargs.pop('sess', create_session())
    embeddings_info = _state.pop("embeddings")

    with _state['sess'].graph.as_default():
        embeddings = reload_embeddings(embeddings_info, basename)
        for k in embeddings_info:
            if k in kwargs:
                _state[k] = kwargs[k]
        labels = read_json("{}.labels".format(basename))
        if _state.get('constraint') is not None:
            # Dummy constraint values that will be filled in by the checkpointing
            _state['constraint'] = [tf.zeros((len(labels), len(labels))) for _ in range(2)]
        if 'lengths' in kwargs:
            _state['lengths'] = kwargs['lengths']
        model = cls.create(embeddings, labels, **_state)
        model._state = _state
        model.create_loss()
        if kwargs.get('init', True):
            model.sess.run(tf.global_variables_initializer())
        model.saver = tf.train.Saver()
        model.saver.restore(model.sess, basename)
    return model

def __init__(self, name, embed_file=None, known_vocab=None, **kwargs):
    super(ELMoEmbeddings, self).__init__(name=name)
    # options file
    self.weight_file = embed_file
    self.dsz = kwargs['dsz']
    elmo_config = embed_file.replace('weights.hdf5', 'options.json')
    self.model = BidirectionalLanguageModel(elmo_config, self.weight_file)
    self.vocab = UnicodeCharsVocabulary(known_vocab)
    elmo_config = read_json(elmo_config)
    assert self.dsz == 2 * int(elmo_config['lstm']['projection_dim'])

def reload_embeddings(embeddings_dict, basename):
    embeddings = {}
    for key, cls in embeddings_dict.items():
        embed_args = read_json('{}-{}-md.json'.format(basename, key))
        module = embed_args.pop('module')
        name = embed_args.pop('name', None)
        assert name is None or name == key
        mod = import_user_module(module)
        Constructor = getattr(mod, cls)
        embeddings[key] = Constructor(key, **embed_args)
    return embeddings

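# Minimal usage sketch for reload_embeddings. The sidecar metadata layout is
# inferred from the read_json/pop calls above rather than from a documented
# schema, and the module/class names are placeholders for whatever was saved:
#
#   {basename}-{key}-md.json  ->  {"module": "<python module to import>",
#                                  "name": "<key>",  # optional, must match key
#                                  ...remaining constructor kwargs, e.g. vsz, dsz...}
#
# embeddings = reload_embeddings({'word': 'LookupTableEmbeddings'}, '/path/to/ckpt')
# would import the module named in '/path/to/ckpt-word-md.json', look up the
# LookupTableEmbeddings class on it, and call
# LookupTableEmbeddings('word', vsz=..., dsz=..., ...).
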
def __init__(self, name, embed_file=None, known_vocab=None, **kwargs):
    super(ELMoEmbeddings, self).__init__(name=name, **kwargs)
    # options file
    self.weight_file = embed_file
    self.dsz = kwargs['dsz']
    elmo_config = embed_file.replace('weights.hdf5', 'options.json')
    self.model = BidirectionalLanguageModel(elmo_config, self.weight_file)
    self.known_vocab = known_vocab
    self.vocab = UnicodeCharsVocabulary(known_vocab)
    elmo_config = read_json(elmo_config)
    assert self.dsz == 2 * int(elmo_config['lstm']['projection_dim'])

def load(cls, basename, **kwargs):
    state = read_json(basename + '.state')
    if 'predict' in kwargs:
        state['predict'] = kwargs['predict']
    if 'beam' in kwargs:
        state['beam'] = kwargs['beam']
    state['sess'] = kwargs.get('sess', tf.Session())
    state['model_type'] = kwargs.get('model_type', 'default')

    with open(basename + '.saver') as fsv:
        saver_def = tf.train.SaverDef()
        text_format.Merge(fsv.read(), saver_def)

    src_embeddings = dict()
    src_embeddings_dict = state.pop('src_embeddings')
    for key, class_name in src_embeddings_dict.items():
        md = read_json('{}-{}-md.json'.format(basename, key))
        embed_args = {'vsz': md['vsz'], 'dsz': md['dsz']}
        Constructor = eval(class_name)
        src_embeddings[key] = Constructor(key, **embed_args)

    tgt_class_name = state.pop('tgt_embedding')
    md = read_json('{}-tgt-md.json'.format(basename))
    embed_args = {'vsz': md['vsz'], 'dsz': md['dsz']}
    Constructor = eval(tgt_class_name)
    tgt_embedding = Constructor('tgt', **embed_args)

    model = cls.create(src_embeddings, tgt_embedding, **state)
    for prop in ls_props(model):
        if prop in state:
            setattr(model, prop, state[prop])

    do_init = kwargs.get('init', True)
    if do_init:
        init = tf.global_variables_initializer()
        model.sess.run(init)

    model.saver = tf.train.Saver()
    model.saver.restore(model.sess, basename)
    return model

def __init__(self, name, **kwargs):
    super(TransformerLMEmbeddings, self).__init__(name)
    self.vocab = read_json(kwargs.get('vocab_file'))
    self.cls_index = self.vocab['[CLS]']
    self.vsz = len(self.vocab)
    layers = int(kwargs.get('layers', 8))
    num_heads = int(kwargs.get('num_heads', 10))
    pdrop = kwargs.get('dropout', 0.1)
    self.d_model = int(kwargs.get('dsz', kwargs.get('d_model', 410)))
    d_ff = int(kwargs.get('d_ff', 2100))
    x_embedding = PositionalLookupTableEmbeddings('pos', vsz=self.vsz, dsz=self.d_model)
    self.init_embed({'x': x_embedding})
    self.transformer = TransformerEncoderStack(num_heads, d_model=self.d_model, pdrop=pdrop,
                                               scale=True, layers=layers, d_ff=d_ff)

def _create_model(self, sess, basename):
    # Read the labels
    labels = read_json(basename + '.labels')
    # Get the parameters from MEAD
    model_params = self.task.config_params["model"]
    model_params["sess"] = sess
    # Read the state file
    state = read_json(basename + '.state')
    # Re-create the embeddings sub-graph
    embeddings = dict()
    for key, class_name in state['embeddings'].items():
        md = read_json('{}-{}-md.json'.format(basename, key))
        embed_args = {'vsz': md['vsz'], 'dsz': md['dsz']}
        Constructor = eval(class_name)
        embeddings[key] = Constructor(key, **embed_args)
    # Instantiate a graph
    model = baseline.model.create_model_for(self.task.task_name(), embeddings, labels, **model_params)
    # Set the properties of the model from the state file
    for prop in ls_props(model):
        if prop in state:
            setattr(model, prop, state[prop])
    # Append to the graph for class output
    values, indices = tf.nn.top_k(model.probs, len(labels))
    class_tensor = tf.constant(model.labels)
    table = tf.contrib.lookup.index_to_string_table_from_tensor(class_tensor)
    classes = table.lookup(tf.to_int64(indices))
    # Restore the checkpoint
    self._restore_checkpoint(sess, basename)
    return model, classes, values

def _create_remote_model(directory, backend, remote, name, signature_name, beam, **kwargs):
    """Reads the necessary information from the remote bundle to instantiate a client for a remote model.

    :directory the location of the exported model bundle
    :remote a url endpoint to hit
    :name the model name, as defined in tf-serving's model.config
    :signature_name the signature to use
    :beam used for s2s and found in the kwargs; we default this and pass it in
    :returns a RemoteModel
    """
    from baseline.remote import create_remote
    assets = read_json(os.path.join(directory, 'model.assets'))
    model_name = assets['metadata']['exported_model']
    preproc = assets['metadata'].get('preproc', kwargs.get('preproc', 'client'))
    labels = read_json(os.path.join(directory, model_name) + '.labels')
    lengths_key = assets.get('lengths_key', None)
    inputs = assets.get('inputs', [])
    return_labels = bool(assets['metadata']['return_labels'])
    version = kwargs.get('version')
    if backend not in {'tf', 'pytorch'}:
        raise ValueError("only TensorFlow and PyTorch are currently supported for remote services")
    import_user_module('baseline.{}.remote'.format(backend))
    exp_type = 'http' if remote.startswith('http') else 'grpc'
    exp_type = '{}-preproc'.format(exp_type) if preproc == 'server' else exp_type
    model = create_remote(
        exp_type,
        remote=remote,
        name=name,
        signature=signature_name,
        labels=labels,
        lengths_key=lengths_key,
        inputs=inputs,
        beam=beam,
        return_labels=return_labels,
        version=version,
    )
    return model, preproc

def __init__(self, name, **kwargs):
    from baseline.pytorch.embeddings import PositionalCharConvEmbeddings
    X_CHAR_EMBEDDINGS = {
        "dsz": 16,
        "wsz": 128,
        "embed_type": "positional-char-conv",
        "keep_unused": True,
        "cfiltsz": [[1, 32], [2, 32], [3, 64], [4, 128], [5, 256], [6, 512], [7, 1024]],
        "gating": "highway",
        "num_gates": 2,
        "projsz": 512
    }
    super(TransformerLMEmbeddings, self).__init__(name)
    self.vocab = read_json(kwargs.get('vocab_file'), strict=True)
    self.cls_index = self.vocab['[CLS]']
    self.vsz = len(self.vocab)
    layers = int(kwargs.get('layers', 18))
    num_heads = int(kwargs.get('num_heads', 10))
    pdrop = kwargs.get('dropout', 0.1)
    self.d_model = int(kwargs.get('dsz', kwargs.get('d_model', 410)))
    d_ff = int(kwargs.get('d_ff', 2100))
    x_embedding = PositionalCharConvEmbeddings('pcc', vsz=self.vsz, **X_CHAR_EMBEDDINGS)
    self.dsz = self.init_embed({'x': x_embedding})
    self.proj_to_dsz = pytorch_linear(self.dsz, self.d_model) if self.dsz != self.d_model else _identity
    self.transformer = TransformerEncoderStack(num_heads, d_model=self.d_model, pdrop=pdrop,
                                               scale=True, layers=layers, d_ff=d_ff)
    self.mlm = kwargs.get('mlm', False)
    pooling = kwargs.get('pooling', 'cls')
    if pooling == 'max':
        self.pooling_op = _max_pool
    elif pooling == 'mean':
        self.pooling_op = _mean_pool
    else:
        self.pooling_op = self._cls_pool

def load(cls, basename, **kwargs):
    """Reload the model from a graph file and a checkpoint

    The model that is loaded is independent of the pooling and stacking layers,
    making this class reusable by sub-classes.

    :param basename: The base directory to load from
    :param kwargs: See below

    :Keyword Arguments:
    * *sess* -- An optional tensorflow session.  If not passed, a new session is created

    :return: A restored model
    """
    _state = read_json("{}.state".format(basename))
    if __version__ != _state['version']:
        logger.warning("Loaded model is from baseline version %s, running version is %s", _state['version'], __version__)
    _state['sess'] = kwargs.pop('sess', tf.Session())

    with _state['sess'].graph.as_default():
        embeddings_info = _state.pop('embeddings')
        embeddings = reload_embeddings(embeddings_info, basename)
        # If there is a kwarg with the same name as an embedding object, it
        # is taken to be the input of that layer. This allows for passing in
        # subgraphs like from a tf.split (for data parallel) or preprocessing
        # graphs that convert text to indices
        for k in embeddings_info:
            if k in kwargs:
                _state[k] = kwargs[k]
        # TODO: convert labels into just another vocab and pass number of labels to models.
        labels = read_json("{}.labels".format(basename))
        model = cls.create(embeddings, labels, **_state)
        model._state = _state
        if kwargs.get('init', True):
            model.sess.run(tf.global_variables_initializer())
        model.saver = tf.train.Saver()
        model.saver.restore(model.sess, basename)
    return model

def load(cls, basename, **kwargs):
    """Reload the model from a graph file and a checkpoint

    The model that is loaded is independent of the pooling and stacking layers,
    making this class reusable by sub-classes.

    :param basename: The base directory to load from
    :param kwargs: See below

    :Keyword Arguments:
    * *sess* -- An optional tensorflow session.  If not passed, a new session is created

    :return: A restored model
    """
    _state = read_json("{}.state".format(basename))
    if __version__ != _state['version']:
        logger.warning("Loaded model is from baseline version %s, running version is %s", _state['version'], __version__)
    _state['sess'] = kwargs.pop('sess', tf.Session())
    embeddings_info = _state.pop("embeddings")

    with _state['sess'].graph.as_default():
        embeddings = reload_embeddings(embeddings_info, basename)
        for k in embeddings_info:
            if k in kwargs:
                _state[k] = kwargs[k]
        labels = read_json("{}.labels".format(basename))
        if _state.get('constraint') is not None:
            # Dummy constraint values that will be filled in by the checkpointing
            _state['constraint'] = [tf.zeros((len(labels), len(labels))) for _ in range(2)]
        model = cls.create(embeddings, labels, **_state)
        model._state = _state
        model.create_loss()
        if kwargs.get('init', True):
            model.sess.run(tf.global_variables_initializer())
        model.saver = tf.train.Saver()
        model.saver.restore(model.sess, basename)
    return model

def load_model_for(activity, filename, **kwargs):
    # Sniff state to see if we need to import things
    state = read_json('{}.state'.format(filename))
    # There won't be a module for pytorch (there is no state file to load).
    if 'module' in state:
        import_user_module(state['module'])
    # Allow user to override model type (for back compat with old api), backoff
    # to the model type in the state file or to default.
    # TODO: Currently in pytorch all models are always reloaded with the load
    # classmethod with a default model class. This is fine given how simple pyt
    # loading is but it could cause problems if a model has a custom load
    model_type = kwargs.get('model_type', state.get('model_type', 'default'))
    creator_fn = BASELINE_LOADERS[activity][model_type]
    logger.info('Calling model %s', creator_fn)
    return creator_fn(filename, **kwargs)

def load_model_for(activity, filename, **kwargs):
    # Sniff state to see if we need to import things
    state = read_json('{}.state'.format(filename))
    # There won't be a module for pytorch (there is no state file to load).
    if 'module' in state:
        import_user_module(state['module'])
    # Allow user to override model type (for back compat with old api), backoff
    # to the model type in the state file or to default.
    # TODO: Currently in pytorch all models are always reloaded with the load
    # classmethod with a default model class. This is fine given how simple pyt
    # loading is but it could cause problems if a model has a custom load
    model_type = kwargs.get('type', kwargs.get('model_type', state.get('type', state.get('model_type', 'default'))))
    creator_fn = BASELINE_LOADERS[activity][model_type]
    logger.info('Calling model %s', creator_fn)
    return creator_fn(filename, **kwargs)

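# Hypothetical call site for load_model_for: the activity string selects the
# BASELINE_LOADERS sub-registry and the 'type' (or legacy 'model_type') kwarg
# picks a loader within it, falling back to the saved state and then 'default'.
# Both the task name and the checkpoint path below are illustrative only:
#
#   model = load_model_for('tagger', '/path/to/tagger-model-1234', type='default')
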
def load(cls, basename: str, **kwargs) -> 'DependencyParserModelBase':
    """Reload the model from a graph file and a checkpoint

    The model that is loaded is independent of the pooling and stacking layers,
    making this class reusable by sub-classes.

    :param basename: The base directory to load from
    :param kwargs: See below

    :Keyword Arguments:
    * *sess* -- An optional tensorflow session.  If not passed, a new session is created

    :return: A restored model
    """
    _state = read_json("{}.state".format(basename))
    if __version__ != _state['version']:
        logger.warning("Loaded model is from baseline version %s, running version is %s", _state['version'], __version__)
    embeddings_info = _state.pop('embeddings')
    embeddings = reload_embeddings(embeddings_info, basename)
    # If there is a kwarg with the same name as an embedding object, it
    # is taken to be the input of that layer. This allows for passing in
    # subgraphs like from a tf.split (for data parallel) or preprocessing
    # graphs that convert text to indices
    for k in embeddings_info:
        if k in kwargs:
            _state[k] = kwargs[k]
    # TODO: convert labels into just another vocab and pass number of labels to models.
    labels = read_json("{}.labels".format(basename))
    model = cls.create(embeddings, labels, **_state)
    model._state = _state
    model.load_weights(f"{basename}.wgt")
    return model

def load(cls, basename, **kwargs):
    _state = read_json("{}.state".format(basename))
    if __version__ != _state['version']:
        bl_logger.warning("Loaded model is from baseline version %s, running version is %s", _state['version'], __version__)
    _state['sess'] = kwargs.pop('sess', create_session())

    with _state['sess'].graph.as_default():
        embeddings_info = _state.pop('embeddings')
        embeddings = reload_embeddings(embeddings_info, basename)
        for k in embeddings_info:
            if k in kwargs:
                _state[k] = kwargs[k]
        model = cls.create(embeddings, init=kwargs.get('init', True), **_state)
        model._state = _state
        model.saver = tf.train.Saver()
        model.saver.restore(model.sess, basename)
    return model

def _create_model(self, sess, basename, **kwargs):
    model = load_tagger_model(basename, sess=sess, **kwargs)
    softmax_output = tf.nn.softmax(model.probs)
    values, _ = tf.nn.top_k(softmax_output, 1)
    indices = model.best
    if self.return_labels:
        labels = read_json(basename + '.labels')
        list_of_labels = [''] * len(labels)
        for label, idval in labels.items():
            list_of_labels[idval] = label
        class_tensor = tf.constant(list_of_labels)
        table = tf.contrib.lookup.index_to_string_table_from_tensor(class_tensor)
        classes = table.lookup(tf.to_int64(indices))
        return model, classes, values
    else:
        return model, indices, values

def __init__(self, name, embed_file=None, known_vocab=None, **kwargs):
    super().__init__(trainable=True, name=name, dtype=tf.float32, **kwargs)
    # options file
    self.weight_file = embed_file
    if 'options' not in kwargs:
        print('Warning: old style configuration')
        elmo_config = embed_file.replace('weights.hdf5', 'options.json')
        elmo_config = read_json(elmo_config)
    else:
        elmo_config = kwargs['options']
    print(elmo_config)
    self.dsz = kwargs.get('dsz', 2 * int(elmo_config['lstm']['projection_dim']))
    self.model = BidirectionalLanguageModel(elmo_config, self.weight_file)
    self.known_vocab = known_vocab
    self.vocab = UnicodeCharsVocabulary(known_vocab)
    assert self.dsz == 2 * int(elmo_config['lstm']['projection_dim'])

def __init__(self, logger_file, mead_config):
    super(Task, self).__init__()
    self.config_params = None
    self.ExporterType = None
    self.mead_config = mead_config
    if os.path.exists(mead_config):
        mead_settings = read_json(mead_config)
    else:
        mead_settings = {}
    if 'datacache' not in mead_settings:
        self.data_download_cache = os.path.expanduser("~/.bl-data")
        mead_settings['datacache'] = self.data_download_cache
        write_json(mead_settings, mead_config)
    else:
        self.data_download_cache = os.path.expanduser(mead_settings['datacache'])
    print("using {} as data/embeddings cache".format(self.data_download_cache))
    self._configure_logger(logger_file)

def read_cred(config_file):
    dbtype = None
    dbhost = None
    dbport = None
    user = None
    passwd = None
    try:
        j = read_json(config_file, strict=True)
        dbtype = j.get('dbtype')
        dbhost = j.get('dbhost')
        dbport = j.get('dbport')
        user = j.get('user')
        passwd = j.get('passwd')
    except IOError:
        pass
    return dbtype, dbhost, dbport, user, passwd

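# Example credential file for read_cred, shaped after the j.get calls above
# (every field is optional; missing keys or a missing file come back as None).
# The filename and values are hypothetical:
#
#   cred.json: {"dbtype": "postgres", "dbhost": "localhost", "dbport": 5432,
#               "user": "reporter", "passwd": "secret"}
#
#   dbtype, dbhost, dbport, user, passwd = read_cred('cred.json')
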
def download(self):
    file_loc = self.dataset_file
    if is_file_correct(file_loc):
        return file_loc
    elif validate_url(file_loc):  # is it a web URL? check if exists in cache
        url = file_loc
        dcache_path = os.path.join(self.data_download_cache, DATA_CACHE_CONF)
        dcache = read_json(dcache_path)
        if url in dcache and is_file_correct(dcache[url], self.data_download_cache, url) and not self.cache_ignore:
            logger.info("file for {} found in cache, not downloading".format(url))
            return dcache[url]
        else:  # download the file in the cache, update the json
            cache_dir = self.data_download_cache
            logger.info("using {} as data/embeddings cache".format(cache_dir))
            temp_file = web_downloader(url)
            dload_file = extractor(filepath=temp_file, cache_dir=cache_dir,
                                   extractor_func=Downloader.ZIPD.get(mime_type(temp_file), None))
            dcache.update({url: dload_file})
            write_json(dcache, os.path.join(self.data_download_cache, DATA_CACHE_CONF))
            return dload_file
    raise RuntimeError("the file [{}] is not in cache and cannot be downloaded".format(file_loc))

def init_decode(self, **kwargs):
    label_vocab = read_json(kwargs["label_vocab"])
    label_trans = np.load(kwargs["label_trans"])
    trans = np.zeros((len(self.labels), len(self.labels)))
    for src, src_idx in self.labels.items():
        if src not in label_vocab:
            continue
        for tgt, tgt_idx in self.labels.items():
            if tgt not in label_vocab:
                continue
            trans[src_idx, tgt_idx] = label_trans[label_vocab[src], label_vocab[tgt]]
    name = kwargs.get("decode_name")
    self.constraint_mask = kwargs.get("constraint_mask").unsqueeze(0)
    return PremadeTransitionsTagger(len(self.labels), self.constraint_mask, transitions=trans)

def load(cls, basename, **kwargs):
    _state = read_json('{}.state'.format(basename))
    if __version__ != _state['version']:
        logger.warning("Loaded model is from baseline version %s, running version is %s", _state['version'], __version__)
    if 'predict' in kwargs:
        _state['predict'] = kwargs['predict']
    if 'beam' in kwargs:
        _state['beam'] = kwargs['beam']
    src_embeddings_info = _state.pop('src_embeddings')
    src_embeddings = reload_embeddings(src_embeddings_info, basename)
    for k in src_embeddings_info:
        if k in kwargs:
            _state[k] = kwargs[k]
    tgt_embedding_info = _state.pop('tgt_embedding')
    tgt_embedding = reload_embeddings(tgt_embedding_info, basename)['tgt']
    model = cls.create(src_embeddings, tgt_embedding, **_state)
    model._state = _state
    model.load_weights(f"{basename}.wgt")
    return model

def peek_lengths_key(model_file, field_map):
    """Check if there is a lengths key to use when finding the length (defaults to tokens)."""
    peek_state = read_json(model_file + ".state")
    lengths = peek_state.get('lengths_key', 'tokens')
    lengths = lengths.replace('_lengths', '')
    return field_map.get(lengths, lengths)

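# Sketch of the resolution above, with assumed state-file contents: a state
# holding {"lengths_key": "word_lengths"} yields 'word' after the suffix strip,
# which is then mapped through field_map, so
#   peek_lengths_key('/path/to/model-1234', {'word': 'text'})  ->  'text'
# while an empty field_map would fall back to 'word' itself.
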
def update_cache(key, data_download_cache):
    dcache = read_json(os.path.join(data_download_cache, DATA_CACHE_CONF))
    if key not in dcache:
        return
    del dcache[key]
    write_json(dcache, os.path.join(data_download_cache, DATA_CACHE_CONF))

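# The cache config read and written above (DATA_CACHE_CONF inside the cache
# directory) maps a source URL or file key to its extracted location on disk,
# e.g. {"http://example.com/glove.zip": "/home/me/.bl-data/<sha1>"}; update_cache
# is an invalidation helper that drops one entry so the next download() call
# re-fetches. The URL below is illustrative:
#
#   update_cache('http://example.com/glove.zip', os.path.expanduser('~/.bl-data'))
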
def test_read_json_strict():
    with pytest.raises(IOError):
        read_json(os.path.join('not_there.json'), strict=True)

def test_read_json_given_default():
    gold_default = 'default'
    data = read_json(os.path.join(data_loc, 'not_there.json'), gold_default)
    assert data == gold_default

def test_read_json_default_value():
    gold_default = {}
    data = read_json(os.path.join(data_loc, 'not_there.json'))
    assert data == gold_default

def test_read_json(gold_data):
    data = read_json(os.path.join(data_loc, 'test_json.json'))
    assert data == gold_data

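# A minimal implementation consistent with the four tests above, offered only as
# a sketch (the real baseline.utils.read_json may differ in details): a missing
# file raises IOError when strict=True; otherwise the given default is returned,
# with an empty dict as the default default. The default may be passed
# positionally, as test_read_json_given_default does.
import json
import os


def read_json_sketch(filepath, default_value=None, strict=False):
    if not os.path.exists(filepath):
        if strict:
            raise IOError('No file {} found'.format(filepath))
        return default_value if default_value is not None else {}
    with open(filepath) as f:
        return json.load(f)
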
import os
import argparse
from baseline.utils import read_json
from mead.utils import index_by_label, convert_path
from mead.downloader import EmbeddingDownloader, DataDownloader

parser = argparse.ArgumentParser(description="Download all data and embeddings.")
parser.add_argument("--cache", default="~/.bl-data", type=os.path.expanduser, help="Location of the data cache")
parser.add_argument('--datasets', help='json library of dataset labels', default='config/datasets.json', type=convert_path)
parser.add_argument('--embeddings', help='json library of embeddings', default='config/embeddings.json', type=convert_path)
args = parser.parse_args()

datasets = read_json(args.datasets)
datasets = index_by_label(datasets)
for name, d in datasets.items():
    print(name)
    try:
        DataDownloader(d, args.cache).download()
    except Exception as e:
        print(e)

emb = read_json(args.embeddings)
emb = index_by_label(emb)
for name, e in emb.items():
    print(name)
    try:
        # The original listing was truncated here; this body mirrors the dataset
        # loop above, and the EmbeddingDownloader argument order (file, dsz,
        # sha1, cache) is an assumption inferred from its use elsewhere.
        EmbeddingDownloader(e['file'], e['dsz'], e.get('sha1'), args.cache).download()
    except Exception as ex:
        print(ex)