def create_cache_backend(**kwargs):
    name = kwargs.get('backend', 'mongo')
    # Lazily import the backend module the first time this cache type is requested
    if name not in ODIN_CACHE_BACKENDS:
        LOGGER.info("Loading %s backend", name)
        import_user_module(f'odin.{name}')
    creator_fn = ODIN_CACHE_BACKENDS[name]
    return creator_fn(**kwargs)
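# Hedged sketch of the registry convention create_cache_backend relies on: addon
# modules populate ODIN_CACHE_BACKENDS at import time, so a dynamic import alone is
# enough to make a new backend resolvable. The 'redis' names below are hypothetical.
ODIN_CACHE_BACKENDS = {}

def register_cache_backend(name):
    """Decorator that maps a backend name to its creator function."""
    def wrapper(fn):
        ODIN_CACHE_BACKENDS[name] = fn
        return fn
    return wrapper

@register_cache_backend('redis')
def create_redis_cache(**kwargs):
    # Hypothetical addon creator; a real one would build and return the cache client
    ...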
def _configure_reporting(self, reporting, **kwargs):
    """Configure all `reporting_hooks` specified in the mead settings or overridden at the command line

    :param reporting: The reporting configuration, keyed by hook name
    :param kwargs: Unused extra arguments
    :return: None
    """
    default_reporting = self.mead_settings_config.get('reporting_hooks', {})
    # Add default reporting information to the reporting settings.
    for report_type in default_reporting:
        if report_type in reporting:
            for report_arg, report_val in default_reporting[report_type].items():
                if report_arg not in reporting[report_type]:
                    reporting[report_type][report_arg] = report_val
    reporting_hooks = list(reporting.keys())
    for settings in reporting.values():
        try:
            import_user_module(settings.get('module', ''))
        except (ImportError, ValueError):
            pass
    self.reporting = baseline.create_reporting(
        reporting_hooks,
        reporting,
        {'config_file': self.config_file, 'task': self.__class__.task_name(), 'base_dir': self.get_basedir()}
    )
    self.config_params['train']['reporting'] = [x.step for x in self.reporting]
    logging.basicConfig(level=logging.DEBUG)
def main():
    parser = argparse.ArgumentParser(
        description='Convert a finetuned transformer model trained with the PyTorch classifier to a TLM NPZ'
    )
    parser.add_argument('--model', help='The path to the .pyt file created by training', required=True, type=str)
    parser.add_argument('--device', help='device')
    parser.add_argument('--npz', help='Output file name, defaults to the original name with a replaced suffix')
    parser.add_argument('--modules', help='modules to load: local files, remote URLs or mead-ml/hub refs',
                        default=[], nargs='+', required=False)
    parser.add_argument('--no_output_layer', action='store_true', help="If set, we won't store the final layers")
    args = parser.parse_args()

    for module in args.modules:
        import_user_module(module)

    if args.npz is None:
        args.npz = args.model.replace('.pyt', '') + '.npz'

    bl_model = torch.load(args.model, map_location=args.device)
    tpt_embed_dict = bl_model.embeddings
    keys = list(tpt_embed_dict.keys())
    if len(keys) > 1:
        logger.warning(
            "Unsupported model! Multiple embeddings are applied in this model, "
            "but this converter only supports a single embedding"
        )
    tpt_embed = tpt_embed_dict[keys[0]]
    if args.no_output_layer:
        save_tlm_npz(tpt_embed, args.npz, verbose=True)
    else:
        # Monkey patch the embedding to contain an output_layer
        tpt_embed.output_layer = bl_model.output_layer
        save_tlm_output_npz(tpt_embed, args.npz, verbose=True)
def import_backend(backend_type):
    # The docker backend's module is named 'dock'
    if backend_type == "docker":
        backend_type = "dock"
    try:
        import_user_module('hpctl.{}'.format(backend_type))
    except ImportError:
        pass
def _load_user_modules(self):
    # User modules can be downloaded from hub or HTTP automatically if they are defined in the form
    #   http://path/to/module_name.py
    #   hub:v1:addons:module_name
    if 'modules' in self.config_params:
        for addon in self.config_params['modules']:
            import_user_module(addon, self.data_download_cache)
def load(cls, bundle, **kwargs): """Load a model from a bundle. This can be either a local model or a remote, exported model. :returns a Service implementation """ # can delegate if os.path.isdir(bundle): directory = bundle else: directory = unzip_files(bundle) model_basename = find_model_basename(directory) vocabs = load_vocabs(directory) vectorizers = load_vectorizers(directory) be = normalize_backend(kwargs.get('backend', 'tf')) remote = kwargs.get("remote", None) name = kwargs.get("name", None) if remote: beam = kwargs.get('beam', 10) model = Service._create_remote_model(directory, be, remote, name, cls.signature_name(), beam, preproc=kwargs.get('preproc', False)) return cls(vocabs, vectorizers, model) # Currently nothing to do here # labels = read_json(os.path.join(directory, model_basename) + '.labels') import_user_module('baseline.{}.embeddings'.format(be)) import_user_module('baseline.{}.{}'.format(be, cls.task_name())) model = load_model_for(cls.task_name(), model_basename, **kwargs) return cls(vocabs, vectorizers, model)
def main(): """Driver program to execute chores.""" parser = argparse.ArgumentParser(description='Run chores') parser.add_argument('file', help='A chore YAML file.') parser.add_argument('--cred', help='cred file', default="/etc/odind/odin-cred.yml") parser.add_argument('--label', required=True) parser.add_argument('--modules', nargs='+', default=[], help='Addon modules to load') args = parser.parse_args() for addon in args.modules: import_user_module(addon) cred_params = read_config_stream(args.cred) store = create_store_backend(**cred_params['jobs_db']) config = read_config_stream(args.file) previous = store.get_previous(args.label) parent_details = store.get_parent(args.label) results = {prev_job_details['name']: prev_job_details for prev_job_details in previous} results['parent'] = parent_details results = run_chores(config, results) results = {'chore_context': results} LOGGER.info(results) job_details = store.get(args.label) job_details.update(results) store.set(job_details)
def __init__(self, remote, name, signature, labels=None, beam=None, lengths_key=None, inputs=None,
             version=None, return_labels=False):
    """A remote model with gRPC transport

    When using this type of model, there is an external dependency on the `grpc` package,
    as well as the TF serving protobuf stub files. There is also currently a dependency
    on `tensorflow`.

    :param remote: The remote endpoint
    :param name: The name of the model
    :param signature: The model signature
    :param labels: The labels (defaults to None)
    :param beam: The beam width (defaults to None)
    :param lengths_key: Which key is used for the length of the input vector (defaults to None)
    :param inputs: The inputs (defaults to an empty list)
    :param version: The model version (defaults to None)
    :param return_labels: Whether the remote model returns class indices or the class labels
        directly. This depends on the `return_labels` parameter in exporters
    """
    super(RemoteModelGRPC, self).__init__(
        remote, name, signature, labels, beam, lengths_key, inputs, version, return_labels
    )
    self.predictpb = import_user_module('tensorflow_serving.apis.predict_pb2')
    self.servicepb = import_user_module('tensorflow_serving.apis.prediction_service_pb2_grpc')
    self.metadatapb = import_user_module('tensorflow_serving.apis.get_model_metadata_pb2')
    self.grpc = import_user_module('grpc')
    self.channel = self.grpc.insecure_channel(remote)
def create_vectorizer(**kwargs):
    vec_type = kwargs.get('vectorizer_type', kwargs.get('type', 'token1d'))
    # Dynamically load a module if it's needed
    for module in listify(kwargs.get('module', kwargs.get('modules', []))):
        import_user_module(module, kwargs.get('data_download_cache'))
    Constructor = MEAD_VECTORIZERS.get(vec_type)
    return Constructor(**kwargs)
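# Usage sketch for create_vectorizer (hedged; './my_vectorizer.py' and 'my-vec' are
# hypothetical): the 'module'/'modules' key triggers import_user_module before the
# MEAD_VECTORIZERS lookup, so addon vectorizer types resolve on first use.
#
#   vec = create_vectorizer(type='token1d', mxlen=100)
#   vec = create_vectorizer(vectorizer_type='my-vec', module='./my_vectorizer.py', mxlen=100)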
def load(cls, bundle, **kwargs): """Load a model from a bundle. This can be either a local model or a remote, exported model. :returns a Service implementation """ # can delegate basehead = None if os.path.isdir(bundle): directory = bundle elif os.path.isfile(bundle): directory = unzip_files(bundle) else: directory = os.path.dirname(bundle) basehead = os.path.basename(bundle) model_basename = find_model_basename(directory, basehead) suffix = model_basename.split('-')[-1] + ".json" vocabs = load_vocabs(directory, suffix) be = normalize_backend(kwargs.get('backend', 'tf')) remote = kwargs.get("remote", None) name = kwargs.get("name", None) if remote: logging.debug("loading remote model") beam = int(kwargs.get('beam', 30)) model, preproc = Service._create_remote_model( directory, be, remote, name, cls.task_name(), cls.signature_name(), beam, preproc=kwargs.get('preproc', 'client'), version=kwargs.get('version'), remote_type=kwargs.get('remote_type'), ) vectorizers = load_vectorizers(directory) return cls(vocabs, vectorizers, model, preproc) # Currently nothing to do here # labels = read_json(os.path.join(directory, model_basename) + '.labels') import_user_module('baseline.{}.embeddings'.format(be)) try: import_user_module('baseline.{}.{}'.format(be, cls.task_name())) except: pass model = load_model_for(cls.task_name(), model_basename, **kwargs) vectorizers = load_vectorizers(directory) return cls(vocabs, vectorizers, model, 'client')
def _remote_monkey_patch(backend_config, hp_logs, results_config, xpctl_config):
    if backend_config.get('type', 'local') == 'remote':
        import_user_module('hpctl.remote')
        hp_logs['type'] = 'remote'
        results_config['type'] = 'remote'
        results_config['host'] = backend_config['host']
        results_config['port'] = backend_config['port']
        if xpctl_config is not None:
            xpctl_config['type'] = 'remote'
            xpctl_config['host'] = backend_config['host']
            xpctl_config['port'] = backend_config['port']
def merge_reporting_with_settings(reporting, settings):
    default_reporting = settings.get('reporting_hooks', {})
    # Add default reporting information to the reporting settings.
    for report_type in default_reporting:
        if report_type in reporting:
            for report_arg, report_val in default_reporting[report_type].items():
                if report_arg not in reporting[report_type]:
                    reporting[report_type][report_arg] = report_val
    reporting_hooks = list(reporting.keys())
    for hook_settings in reporting.values():
        for module in listify(hook_settings.get('module', hook_settings.get('modules', []))):
            import_user_module(module)
    return reporting_hooks, reporting
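# Worked example (hedged; hook names and values are hypothetical): explicit reporting
# settings win, and missing keys are backfilled from the mead settings defaults.
#
#   settings = {'reporting_hooks': {'visdom': {'name': 'main', 'port': 8097}}}
#   reporting = {'visdom': {'name': 'override'}}
#   hooks, merged = merge_reporting_with_settings(reporting, settings)
#   # hooks  -> ['visdom']
#   # merged -> {'visdom': {'name': 'override', 'port': 8097}}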
def load_model_for(activity, filename, **kwargs):
    # Sniff state to see if we need to import things
    state = read_json('{}.state'.format(filename))
    # There won't be a module for pytorch (there is no state file to load).
    if 'module' in state:
        import_user_module(state['module'])
    # Allow the user to override the model type (for back compat with the old api),
    # backing off to the model type in the state file or to 'default'.
    # TODO: Currently in pytorch all models are always reloaded with the load
    # classmethod with a default model class. This is fine given how simple pyt
    # loading is but it could cause problems if a model has a custom load
    model_type = kwargs.get('model_type', state.get('model_type', 'default'))
    creator_fn = BASELINE_LOADERS[activity][model_type]
    logger.info('Calling model %s', creator_fn)
    return creator_fn(filename, **kwargs)
def load_model_for(activity, filename, **kwargs):
    # Sniff state to see if we need to import things
    state = read_json('{}.state'.format(filename))
    # There won't be a module for pytorch (there is no state file to load).
    if 'module' in state:
        import_user_module(state['module'])
    # Allow the user to override the model type (for back compat with the old api),
    # backing off to the model type in the state file or to 'default'.
    # TODO: Currently in pytorch all models are always reloaded with the load
    # classmethod with a default model class. This is fine given how simple pyt
    # loading is but it could cause problems if a model has a custom load
    model_type = kwargs.get('type', kwargs.get('model_type', state.get('type', state.get('model_type', 'default'))))
    creator_fn = BASELINE_LOADERS[activity][model_type]
    logger.info('Calling model %s', creator_fn)
    return creator_fn(filename, **kwargs)
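# The '{filename}.state' file sniffed above is assumed (a hedged sketch; real keys
# vary by task and backend, and 'my_addon_model' is hypothetical) to be JSON like:
#
#   {"module": "my_addon_model", "type": "default", ...}
#
# The model type lookup order in this variant is: kwargs['type'], then
# kwargs['model_type'], then state['type'], then state['model_type'], then 'default'.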
def _create_remote_model(directory, backend, remote, name, signature_name, beam, preproc='client'):
    """Reads the necessary information from the remote bundle to instantiate
    a client for a remote model.

    :directory the location of the exported model bundle
    :remote a url endpoint to hit
    :name the model name, as defined in tf-serving's model.config
    :signature_name the signature to use.
    :beam used for s2s and found in the kwargs. We default this and pass it in.
    :returns a RemoteModel
    """
    assets = read_json(os.path.join(directory, 'model.assets'))
    model_name = assets['metadata']['exported_model']
    labels = read_json(os.path.join(directory, model_name) + '.labels')
    lengths_key = assets.get('lengths_key', None)
    inputs = assets.get('inputs', [])

    if backend == 'tf':
        remote_models = import_user_module('baseline.remote')
        if remote.startswith('http'):
            RemoteModel = remote_models.RemoteModelTensorFlowREST
        elif preproc == 'server':
            RemoteModel = remote_models.RemoteModelTensorFlowGRPCPreproc
        else:
            RemoteModel = remote_models.RemoteModelTensorFlowGRPC
        model = RemoteModel(remote, name, signature_name, labels=labels, lengths_key=lengths_key,
                            inputs=inputs, beam=beam)
    else:
        raise ValueError("only TensorFlow is currently supported for remote Services")

    return model
def _create_remote_model(directory, backend, remote, name, task_name, signature_name, beam, **kwargs):
    """Reads the necessary information from the remote bundle to instantiate
    a client for a remote model.

    :directory the location of the exported model bundle
    :remote a url endpoint to hit
    :name the model name, as defined in tf-serving's model.config
    :signature_name the signature to use.
    :beam used for s2s and found in the kwargs. We default this and pass it in.
    :returns a RemoteModel
    """
    from baseline.remote import create_remote
    assets = read_json(os.path.join(directory, 'model.assets'))
    model_name = assets['metadata']['exported_model']
    preproc = assets['metadata'].get('preproc', kwargs.get('preproc', 'client'))
    labels = read_json(os.path.join(directory, model_name) + '.labels')
    lengths_key = assets.get('lengths_key', None)
    inputs = assets.get('inputs', [])
    return_labels = bool(assets['metadata']['return_labels'])
    version = kwargs.get('version')

    if backend not in {'tf', 'onnx'}:
        raise ValueError(f"Unsupported backend {backend} for remote Services")
    import_user_module('baseline.{}.remote'.format(backend))

    exp_type = kwargs.get('remote_type')
    if exp_type is None:
        exp_type = 'http' if remote.startswith('http') else 'grpc'
        exp_type = '{}-preproc'.format(exp_type) if preproc == 'server' else exp_type
        exp_type = f'{exp_type}-{task_name}'
    model = create_remote(
        exp_type,
        remote=remote,
        name=name,
        signature=signature_name,
        labels=labels,
        lengths_key=lengths_key,
        inputs=inputs,
        beam=beam,
        return_labels=return_labels,
        version=version,
    )
    return model, preproc
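# Worked example of the exp_type derivation above (hedged; the endpoint is
# hypothetical): with remote='http://localhost:8501', preproc='server',
# task_name='classify' and no explicit remote_type, the key passed to create_remote
# is 'http-preproc-classify'; a non-http endpoint with client-side preproc would
# yield 'grpc-classify'.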
def create_lm_reader(max_word_length, nbptt, word_trans_fn, **kwargs):
    reader_type = kwargs.get('reader_type', 'default')
    if reader_type == 'default':
        reader = PTBSeqReader(max_word_length, nbptt, word_trans_fn)
    else:
        mod = import_user_module("reader", reader_type)
        reader = mod.create_lm_reader(max_word_length, nbptt, word_trans_fn, **kwargs)
    return reader
def create_pred_reader(mxlen, zeropadding, clean_fn, vec_alloc, src_vec_trans, **kwargs):
    reader_type = kwargs.get('reader_type', 'default')
    if reader_type == 'default':
        reader = TSVSeqLabelReader(mxlen, zeropadding, clean_fn, vec_alloc, src_vec_trans)
    else:
        mod = import_user_module("reader", reader_type)
        reader = mod.create_pred_reader(mxlen, zeropadding, clean_fn, vec_alloc, src_vec_trans, **kwargs)
    return reader
def load(self, task_name=None):
    base_pkg_name = 'baseline.{}'.format(self.name)
    mod = import_user_module(base_pkg_name)
    import_user_module('{}.embeddings'.format(base_pkg_name))
    import_user_module('mead.{}.exporters'.format(self.name))
    if task_name is not None:
        import_user_module('{}.{}'.format(base_pkg_name, task_name))
    self.transition_mask = mod.transition_mask
def __init__(self, remote, name, signature, labels=None, beam=None, lengths_key=None, inputs=[],
             return_labels=False):
    """A remote model that lives on TF serving with gRPC transport

    When using this type of model, there is an external dependency on the `grpc` package,
    as well as the TF serving protobuf stub files. There is also currently a dependency
    on `tensorflow`.

    :param remote: The remote endpoint
    :param name: The name of the model
    :param signature: The model signature
    :param labels: The labels (defaults to None)
    :param beam: The beam width (defaults to None)
    :param lengths_key: Which key is used for the length of the input vector (defaults to None)
    :param inputs: The inputs (defaults to empty list)
    :param return_labels: Whether the remote model returns class indices or the class labels
        directly. This depends on the `return_labels` parameter in exporters
    """
    self.predictpb = import_user_module('tensorflow_serving.apis.predict_pb2')
    self.servicepb = import_user_module('tensorflow_serving.apis.prediction_service_pb2_grpc')
    self.metadatapb = import_user_module('tensorflow_serving.apis.get_model_metadata_pb2')
    self.grpc = import_user_module('grpc')
    self.remote = remote
    self.name = name
    self.signature = signature
    self.channel = self.grpc.insecure_channel(remote)
    self.lengths_key = lengths_key
    self.input_keys = set(inputs)
    self.beam = beam
    self.labels = labels
    self.return_labels = return_labels
def reload_embeddings(embeddings_dict, basename):
    embeddings = {}
    for key, cls in embeddings_dict.items():
        embed_args = read_json('{}-{}-md.json'.format(basename, key))
        module = embed_args.pop('module')
        name = embed_args.pop('name', None)
        assert name is None or name == key
        mod = import_user_module(module)
        Constructor = getattr(mod, cls)
        embeddings[key] = Constructor(key, **embed_args)
    return embeddings
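# The '{basename}-{key}-md.json' metadata read above is assumed (hedged sketch; the
# module name and the remaining keys, which are whatever the embedding constructor
# takes, are hypothetical) to look like:
#
#   {"module": "my_embeddings_addon", "name": "word", "dsz": 300, ...}
#
# so reload_embeddings({'word': 'MyEmbeddings'}, 'model-1234') would import
# my_embeddings_addon and call MyEmbeddings('word', dsz=300, ...).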
def load(cls, bundle, **kwargs): """Load a model from a bundle. This can be either a local model or a remote, exported model. :returns a Service implementation """ # can delegate if os.path.isdir(bundle): directory = bundle else: directory = unzip_files(bundle) model_basename = find_model_basename(directory) vocabs = load_vocabs(directory) vectorizers = load_vectorizers(directory) be = normalize_backend(kwargs.get('backend', 'tf')) remote = kwargs.get("remote", None) name = kwargs.get("name", None) if remote: logging.debug("loading remote model") beam = kwargs.get('beam', 30) model, preproc = Service._create_remote_model( directory, be, remote, name, cls.signature_name(), beam, preproc=kwargs.get('preproc', 'client'), version=kwargs.get('version') ) return cls(vocabs, vectorizers, model, preproc) # Currently nothing to do here # labels = read_json(os.path.join(directory, model_basename) + '.labels') import_user_module('baseline.{}.embeddings'.format(be)) try: import_user_module('baseline.{}.{}'.format(be, cls.task_name())) except: pass model = load_model_for(cls.task_name(), model_basename, **kwargs) return cls(vocabs, vectorizers, model, 'client')
def load(self, task_name=None):
    if self.name == 'tf':
        from eight_mile.tf.layers import set_tf_log_level, set_tf_eager_debug
        set_tf_log_level(os.getenv("MEAD_TF_LOG_LEVEL", "ERROR"))
        set_tf_eager_debug(str2bool(os.getenv("MEAD_TF_EAGER_DEBUG", "FALSE")))

    base_pkg_name = 'baseline.{}'.format(self.name)
    # Backends may not be downloaded to the cache, they must exist locally
    mod = import_user_module(base_pkg_name)
    import_user_module('baseline.{}.optz'.format(self.name))
    import_user_module('baseline.{}.embeddings'.format(self.name))
    import_user_module('mead.{}.exporters'.format(self.name))
    if task_name is not None:
        try:
            import_user_module(f'{base_pkg_name}.{task_name}')
        except Exception:
            logger.warning(f"No module found [{base_pkg_name}.{task_name}]")
    self.transition_mask = mod.transition_mask
def create_featurizer(model, zero_alloc=np.zeros, **kwargs):
    mxlen = kwargs.pop('mxlen', model.mxlen if hasattr(model, 'mxlen') else -1)
    maxw = kwargs.pop(
        'maxw',
        model.maxw if hasattr(model, 'maxw') else model.mxwlen if hasattr(model, 'mxwlen') else -1
    )
    kwargs.pop('zero_alloc', None)
    featurizer_type = kwargs.get('featurizer_type', 'default')
    if featurizer_type == 'default':
        return WordCharLength(model, mxlen, maxw, zero_alloc, **kwargs)
    elif featurizer_type == 'multifeature':
        return MultiFeatureFeaturizer(model, mxlen, maxw, zero_alloc, **kwargs)
    else:
        mod = import_user_module("featurizer", featurizer_type)
        return mod.create_featurizer(model, mxlen, maxw, zero_alloc, **kwargs)
def _create_remote_model(directory, backend, remote, name, signature_name, beam, **kwargs):
    """Reads the necessary information from the remote bundle to instantiate
    a client for a remote model.

    :directory the location of the exported model bundle
    :remote a url endpoint to hit
    :name the model name, as defined in tf-serving's model.config
    :signature_name the signature to use.
    :beam used for s2s and found in the kwargs. We default this and pass it in.
    :returns a RemoteModel
    """
    from baseline.remote import create_remote
    assets = read_json(os.path.join(directory, 'model.assets'))
    model_name = assets['metadata']['exported_model']
    preproc = assets['metadata'].get('preproc', kwargs.get('preproc', 'client'))
    labels = read_json(os.path.join(directory, model_name) + '.labels')
    lengths_key = assets.get('lengths_key', None)
    inputs = assets.get('inputs', [])
    return_labels = bool(assets['metadata']['return_labels'])
    version = kwargs.get('version')

    if backend not in {'tf', 'pytorch'}:
        raise ValueError("only TensorFlow and PyTorch are currently supported for remote Services")
    import_user_module('baseline.{}.remote'.format(backend))

    exp_type = 'http' if remote.startswith('http') else 'grpc'
    exp_type = '{}-preproc'.format(exp_type) if preproc == 'server' else exp_type
    model = create_remote(
        exp_type,
        remote=remote,
        name=name,
        signature=signature_name,
        labels=labels,
        lengths_key=lengths_key,
        inputs=inputs,
        beam=beam,
        return_labels=return_labels,
        version=version,
    )
    return model, preproc
def create_seq_pred_reader(mxlen, mxwlen, word_trans_fn, vec_alloc, vec_shape, trim, **kwargs):
    reader_type = kwargs.get('reader_type', 'default')
    if reader_type == 'default':
        print('Reading CONLL sequence file corpus')
        reader = CONLLSeqReader(mxlen, mxwlen, word_trans_fn, vec_alloc, vec_shape, trim,
                                extended_features=kwargs.get('extended_features', {}))
    else:
        mod = import_user_module("reader", reader_type)
        reader = mod.create_seq_pred_reader(mxlen, mxwlen, word_trans_fn, vec_alloc, vec_shape, trim, **kwargs)
    return reader
def create_pred_reader(mxlen, zeropadding, clean_fn, vec_alloc, src_vec_trans, **kwargs):
    reader_type = kwargs.get('reader_type', 'default')
    if reader_type == 'default':
        do_chars = kwargs.get('do_chars', False)
        data_format = kwargs.get('data_format', 'objs')
        trim = kwargs.get('trim', False)
        # splitter = kwargs.get('splitter', '[\t\s]+')
        reader = TSVSeqLabelReader(mxlen, kwargs.get('mxwlen', -1), zeropadding, clean_fn, vec_alloc,
                                   src_vec_trans, do_chars=do_chars, data_format=data_format, trim=trim)
    else:
        mod = import_user_module("reader", reader_type)
        reader = mod.create_pred_reader(mxlen, zeropadding, clean_fn, vec_alloc, src_vec_trans, **kwargs)
    return reader
def __init__(self, store: Store, namespace: str = 'default', modules: List[str] = DEFAULT_MODULES):
    """Create a Pod scheduler

    :param store: A job store
    :param namespace: A k8s namespace, defaults to `default`
    :param modules: A list of ResourceHandler modules to load
    """
    super().__init__()
    self.namespace = namespace
    try:
        config.load_incluster_config()
    except config.config_exception.ConfigException:
        config.load_kube_config()
    self.store = store
    for module in modules:
        import_user_module(module)
    self.handlers = {k: create_resource_handler(k, namespace) for k in RESOURCE_HANDLERS.keys()}
def create_parallel_corpus_reader(mxlen, alloc_fn, trim, src_vec_trans, **kwargs):
    reader_type = kwargs.get('reader_type', 'default')
    if reader_type == 'default':
        print('Reading parallel file corpus')
        pair_suffix = kwargs.get('pair_suffix')
        reader = MultiFileParallelCorpusReader(pair_suffix[0], pair_suffix[1], mxlen,
                                               alloc_fn, src_vec_trans, trim)
    elif reader_type == 'tsv':
        print('Reading tab-separated corpus')
        reader = TSVParallelCorpusReader(mxlen, alloc_fn, src_vec_trans, trim)
    else:
        mod = import_user_module("reader", reader_type)
        return mod.create_parallel_corpus_reader(mxlen, alloc_fn, src_vec_trans, trim, **kwargs)
    return reader
def create_reporting_hook(reporting_hooks, hook_settings, **kwargs):
    reporting = [LoggingReporting()]
    if 'console' in reporting_hooks:
        reporting.append(ConsoleReporting())
        reporting_hooks.remove('console')
    if 'visdom' in reporting_hooks:
        visdom_settings = hook_settings.get('visdom', {})
        reporting.append(VisdomReporting(visdom_settings=visdom_settings))
        reporting_hooks.remove('visdom')
    if 'tensorboard' in reporting_hooks:
        tensorboard_settings = hook_settings.get('tensorboard', {})
        reporting.append(TensorBoardReporting(tensorboard_settings=tensorboard_settings))
        reporting_hooks.remove('tensorboard')
    # Any remaining hooks are addons, resolved with a dynamic import
    for reporting_hook in reporting_hooks:
        mod = import_user_module("reporting", reporting_hook)
        hook_setting = hook_settings.get(reporting_hook, {})
        reporting.append(mod.create_reporting_hook(hook_setting=hook_setting, **kwargs))
    return reporting
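# Usage sketch (hedged; 'my_slack_hook' and its settings are hypothetical): built-in
# hooks are consumed by name, and anything left over is treated as an addon module.
#
#   hooks = ['console', 'my_slack_hook']
#   reporters = create_reporting_hook(hooks, {'my_slack_hook': {'channel': '#runs'}})
#   # -> [LoggingReporting, ConsoleReporting, <addon hook>]; note the hooks list is mutated.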
def load_user_modules(config, settings):
    for module in config.pop('hpctl_modules', settings.get('hpctl_modules', [])):
        import_user_module(module)
def _load_user_modules(self):
    if 'modules' in self.config_params:
        for addon in self.config_params['modules']:
            import_user_module(addon)
def load(cls, bundle, **kwargs):
    import_user_module('create_servable_embeddings')
    return super(EmbeddingsService, cls).load(bundle, **kwargs)
def main():
    parser = argparse.ArgumentParser(description='Evaluate on a dataset')
    parser.add_argument('--model', required=True)
    parser.add_argument('--dataset', required=True)
    parser.add_argument('--settings', default=DEFAULT_SETTINGS_LOC, type=convert_path)
    parser.add_argument('--modules', nargs="+", default=[])
    parser.add_argument('--reporting', nargs="+")
    parser.add_argument('--logging', default=DEFAULT_LOGGING_LOC, type=convert_path)
    parser.add_argument('--task', default='classify', choices={'classify', 'tagger', 'seq2seq', 'lm'})
    parser.add_argument('--backend', default='tf')
    parser.add_argument('--reader', default='default')
    parser.add_argument('--trim', default=True, type=str2bool)
    parser.add_argument('--batchsz', default=50)
    parser.add_argument('--trainer', default='default')
    parser.add_argument('--output', default=None)
    parser.add_argument('--remote')
    parser.add_argument(
        '--features',
        help='(optional) features in the format feature_name:index (column # in conll) or '
             'just feature names (assumed sequential)',
        default=[],
        nargs='+',
    )
    parser.add_argument('--device', default='cpu')
    # our parse_extra_args doesn't handle lists :/
    parser.add_argument('--pair_suffix', nargs='+', default=[])
    args, extra_args = parser.parse_known_args()

    args.batchsz = args.batchsz if args.task != 'lm' else 1

    named_fields = {str(v): k for k, v in feature_index_mapping(args.features).items()}

    reader_options = parse_extra_args(['reader'], extra_args)['reader']
    reader_options = process_reader_options(reader_options)
    verbose_options = parse_extra_args(['verbose'], extra_args)['verbose']
    trainer_options = parse_extra_args(['trainer'], extra_args)['trainer']
    if 'span_type' not in trainer_options:
        trainer_options['span_type'] = 'iobes'
    model_options = parse_extra_args(['model'], extra_args)['model']

    args.logging = read_config_stream(args.logging)
    configure_logger(args.logging)

    try:
        args.settings = read_config_stream(args.settings)
    except Exception:
        logger.warning('Warning: no mead-settings file was found at [{}]'.format(args.settings))
        args.settings = {}

    backend = Backend(args.backend)
    backend.load(args.task)

    for module in args.modules:
        import_user_module(module)

    reporting = parse_extra_args(args.reporting if args.reporting is not None else [], extra_args)
    reporting_hooks, reporting = merge_reporting_with_settings(reporting, args.settings)
    reporting_fns = [x.step for x in create_reporting(reporting_hooks, reporting, {'task': args.task})]

    service = get_service(args.task)
    model = service.load(args.model, backend=args.backend, remote=args.remote,
                         device=args.device, **model_options)
    vectorizers = get_vectorizers(args.task, model)

    reader = create_reader(args.task, vectorizers, args.trim, type=args.reader,
                           named_fields=named_fields, pair_suffix=args.pair_suffix, **reader_options)
    reader = patch_reader(args.task, model, reader)

    data, txts = load_data(args.task, reader, model, args.dataset, args.batchsz)

    if args.task == 'seq2seq':
        trainer_options['tgt_rlut'] = {v: k for k, v in model.tgt_vocab.items()}

    trainer = get_trainer(model, args.trainer, verbose_options, backend.name,
                          gpu=args.device != 'cpu', nogpu=args.device == 'cpu', **trainer_options)
    if args.task == 'classify':
        _ = trainer.test(data, reporting_fns=reporting_fns, phase='Test', verbose=verbose_options,
                         output=args.output, txts=txts, **model_options)
    elif args.task == 'tagger':
        _ = trainer.test(data, reporting_fns=reporting_fns, phase='Test', verbose=verbose_options,
                         conll_output=args.output, txts=txts, **model_options)
    else:
        _ = trainer.test(data, reporting_fns=reporting_fns, phase='Test', verbose=verbose_options,
                         **model_options)
def main():
    parser = argparse.ArgumentParser(description='Train a text classifier')
    parser.add_argument('--config', help='JSON/YML Configuration for an experiment: local file or remote URL',
                        type=convert_path, default="$MEAD_CONFIG")
    parser.add_argument('--settings', help='JSON/YML Configuration for mead',
                        default=DEFAULT_SETTINGS_LOC, type=convert_path)
    parser.add_argument('--task_modules', help='tasks to load, must be local',
                        default=[], nargs='+', required=False)
    parser.add_argument('--datasets', help='index of dataset labels: local file, remote URL or mead-ml/hub ref',
                        type=convert_path)
    parser.add_argument('--modules', help='modules to load: local files, remote URLs or mead-ml/hub refs',
                        default=[], nargs='+', required=False)
    parser.add_argument('--mod_train_file', help='override the training set')
    parser.add_argument('--mod_valid_file', help='override the validation set')
    parser.add_argument('--mod_test_file', help='override the test set')
    parser.add_argument('--fit_func', help='override the fit function')
    parser.add_argument('--embeddings', help='index of embeddings: local file, remote URL or mead-ml/hub ref',
                        type=convert_path)
    parser.add_argument('--vecs', help='index of vectorizers: local file, remote URL or mead-ml/hub ref',
                        type=convert_path)
    parser.add_argument('--logging', help='json file for logging', default=DEFAULT_LOGGING_LOC, type=convert_path)
    parser.add_argument('--task', help='task to run', choices=['classify', 'tagger', 'seq2seq', 'lm'])
    parser.add_argument('--gpus', help='Number of GPUs (defaults to number available)', type=int, default=-1)
    parser.add_argument('--basedir', help='Override the base directory where models are stored', type=str)
    parser.add_argument('--reporting', help='reporting hooks', nargs='+')
    parser.add_argument('--backend', help='The deep learning backend to use')
    parser.add_argument('--checkpoint', help='Restart training from this checkpoint')
    parser.add_argument('--prefer_eager', help="If running in TensorFlow, should we prefer eager mode",
                        type=str2bool)
    args, overrides = parser.parse_known_args()

    config_params = read_config_stream(args.config)
    config_params = parse_and_merge_overrides(config_params, overrides, pre='x')
    if args.basedir is not None:
        config_params['basedir'] = args.basedir

    # task_module overrides are not allowed via hub or HTTP, they must be defined locally
    for task in args.task_modules:
        import_user_module(task)

    task_name = config_params.get('task', 'classify') if args.task is None else args.task

    args.logging = read_config_stream(args.logging)
    configure_logger(args.logging, config_params.get('basedir', './{}'.format(task_name)))

    try:
        args.settings = read_config_stream(args.settings)
    except Exception:
        logger.warning('Warning: no mead-settings file was found at [{}]'.format(args.settings))
        args.settings = {}

    args.datasets = args.settings.get('datasets',
                                      convert_path(DEFAULT_DATASETS_LOC)) if args.datasets is None else args.datasets
    args.datasets = read_config_stream(args.datasets)
    if args.mod_train_file or args.mod_valid_file or args.mod_test_file:
        logging.warning('Warning: overriding the training/valid/test data with user-specified files'
                        ' different from what was specified in the dataset index. Creating a new key for this entry')
        update_datasets(args.datasets, config_params, args.mod_train_file, args.mod_valid_file, args.mod_test_file)

    args.embeddings = args.settings.get(
        'embeddings', convert_path(DEFAULT_EMBEDDINGS_LOC)) if args.embeddings is None else args.embeddings
    args.embeddings = read_config_stream(args.embeddings)

    args.vecs = args.settings.get('vecs', convert_path(DEFAULT_VECTORIZERS_LOC)) if args.vecs is None else args.vecs
    args.vecs = read_config_stream(args.vecs)

    if args.gpus:
        # why does it go to model and not to train?
        config_params['train']['gpus'] = args.gpus
    if args.fit_func:
        config_params['train']['fit_func'] = args.fit_func
    if args.backend:
        config_params['backend'] = normalize_backend(args.backend)

    config_params['modules'] = list(set(chain(config_params.get('modules', []), args.modules)))

    cmd_hooks = args.reporting if args.reporting is not None else []
    config_hooks = config_params.get('reporting') if config_params.get('reporting') is not None else []
    reporting = parse_extra_args(set(chain(cmd_hooks, config_hooks)), overrides)
    config_params['reporting'] = reporting

    logger.info('Task: [{}]'.format(task_name))

    task = mead.Task.get_task_specific(task_name, args.settings)
    task.read_config(config_params, args.datasets, args.vecs, reporting_args=overrides,
                     prefer_eager=args.prefer_eager)
    task.initialize(args.embeddings)
    task.train(args.checkpoint)
def load(cls, bundle, **kwargs):
    import_user_module('hub:v1:addons:create_servable_embeddings')
    return super().load(bundle, **kwargs)