def launch(
    config, settings, logging, hpctl_logging,
    datasets, embeddings,
    reporting, unknown, task, num_iters,
    **kwargs
):
    """Sample ``num_iters`` configurations and submit each one to a remote backend.

    :param config: The mead config, handed to ``get_config``.
    :param settings: The settings, handed to ``get_settings``.
    :param logging: The mead logging config, handed to ``get_logs``.
    :param hpctl_logging: The hpctl logging config, handed to ``get_logs``.
    :param datasets: The datasets index; it is read here but not forwarded to the backend.
    :param embeddings: The embeddings index; it is read here but not forwarded to the backend.
    :param reporting: The reporting hooks, handed to ``get_config``.
    :param unknown: Extra arguments used to configure the config and the backend.
    :param task: The task name. When ``None`` the ``task`` entry of the config is
        used, falling back to ``'classify'``.
    :param num_iters: The number of configurations to sample and launch.
    """
    mead_config = get_config(config, reporting, unknown)
    hp_settings, mead_settings = get_settings(settings)
    load_user_modules(mead_config, hp_settings)
    exp_hash = hash_config(mead_config)
    hp_logs, mead_logs = get_logs(hp_settings, logging, hpctl_logging)
    datasets = read_config_file_or_json(datasets)
    embeddings = read_config_file_or_json(embeddings)
    task = mead_config.get('task', 'classify') if task is None else task

    # Only the backend half of the settings matters here; it is forced to be remote.
    _, backend_config = get_ends(hp_settings, unknown)
    force_remote_backend(backend_config)
    config_sampler = get_config_sampler(mead_config, None)
    be = get_backend(backend_config)

    # Everything except the label and the sampled config is the same for every job.
    shared = {
        'mead_logs': mead_logs,
        'hpctl_logs': hp_logs,
        'task_name': task,
        'settings': mead_settings,
        'experiment_config': mead_config,
    }
    for _ in range(num_iters):
        label, sampled_config = config_sampler.sample()
        print(label)
        be.launch(label=label, config=sampled_config, **shared)
def serve(settings, hpctl_logging, embeddings, datasets, unknown, **kwargs):
    """Build the frontend, backend, log server and scheduler, then run until interrupted.

    :param settings: The settings, handed to ``get_settings``.
    :param hpctl_logging: The hpctl logging config, handed to ``get_logs``.
    :param embeddings: The embeddings index.
    :param datasets: The datasets index.
    :param unknown: Extra arguments used when building the frontend/backend configs.
    """
    hp_settings, mead_settings = get_settings(settings)
    load_user_modules({}, hp_settings)
    frontend_config, backend_config = get_ends(hp_settings, unknown)
    hp_logs, _ = get_logs(hp_settings, {}, hpctl_logging)
    xpctl_config = get_xpctl_settings(mead_settings)
    set_root(hp_settings)

    datasets = read_config_file_or_json(datasets)
    embeddings = read_config_file_or_json(embeddings)

    results = get_results({})
    backend = get_backend(backend_config)
    log_server = get_log_server(hp_logs)
    xpctl = get_xpctl(xpctl_config)

    # The frontend is always flask here and gets label-indexed views of both indices.
    frontend_config['type'] = 'flask'
    frontend_config['datasets'] = index_by_label(datasets)
    frontend_config['embeddings'] = index_by_label(embeddings)
    frontend = get_frontend(frontend_config, results, xpctl)

    scheduler = RoundRobinScheduler()
    data_cache = mead_settings.get('datacache', '~/.bl-data')

    try:
        run_forever(
            results, backend, scheduler, frontend, log_server,
            data_cache, xpctl_config, datasets, embeddings
        )
    except KeyboardInterrupt:
        # Ctrl-C ends the serve loop without an error.
        pass
def __init__(self, **kwargs):
    """Set up the reporting hook and its xpctl API client.

    The server url is taken from ``kwargs['host']`` when present, otherwise from
    the ``'host'`` entry of the credentials read from ``kwargs['cred']``.

    :param kwargs: Must contain ``config_file`` and ``task`` plus one of ``host``
        or ``cred``. Optional: ``label``, ``user``, ``checkpoint_store``, ``save_model``.
    :raises ValueError: If neither ``host`` nor ``cred`` is given.
    :raises KeyError: If ``config_file`` or ``task`` is missing.
    """
    super(XPCtlReporting, self).__init__(**kwargs)
    if 'host' not in kwargs and 'cred' not in kwargs:
        raise ValueError('must provide a url where xpctl server is running')
    if 'host' in kwargs:
        self.api_url = kwargs['host']
    else:
        self.api_url = read_config_file_or_json(kwargs['cred'])['host']

    # These two are required, a missing key raises.
    self.exp_config = read_config_file_or_json(kwargs['config_file'])
    self.task = kwargs['task']

    self.label = kwargs.get('label', None)
    self.username = kwargs.get('user', getpass.getuser())
    # NOTE(review): 'host' is the same kwarg used for the server url above, so when it
    # is supplied the hostname becomes that url -- confirm this is intended.
    self.hostname = kwargs.get('host', socket.gethostname())
    self.checkpoint_base = None
    self.checkpoint_store = kwargs.get('checkpoint_store', '/data/model-checkpoints')
    # Saving the model is optional and off unless requested.
    self.save_model = kwargs.get('save_model', False)
    self.api = XpctlApi(ApiClient(Configuration(host=self.api_url)))
    self.log = []
def serve(settings, hpctl_logging, embeddings, datasets, unknown, **kwargs):
    """Wire up results, backend, log server and a flask frontend, then serve forever.

    :param settings: The settings, handed to ``get_settings``.
    :param hpctl_logging: The hpctl logging config, handed to ``get_logs``.
    :param embeddings: The embeddings index.
    :param datasets: The datasets index.
    :param unknown: Extra arguments used when building the frontend/backend configs.
    """
    hp_settings, mead_settings = get_settings(settings)
    load_user_modules({}, hp_settings)
    frontend_config, backend_config = get_ends(hp_settings, unknown)
    hp_logs, _ = get_logs(hp_settings, {}, hpctl_logging)
    xpctl_config = get_xpctl_settings(mead_settings)
    set_root(hp_settings)

    datasets = read_config_file_or_json(datasets)
    embeddings = read_config_file_or_json(embeddings)

    results = get_results({})
    backend = get_backend(backend_config)
    log_server = get_log_server(hp_logs)
    xpctl = get_xpctl(xpctl_config)

    # The frontend type is fixed to flask; both indices are exposed keyed by label.
    frontend_config['type'] = 'flask'
    frontend_config['datasets'] = index_by_label(datasets)
    frontend_config['embeddings'] = index_by_label(embeddings)
    frontend = get_frontend(frontend_config, results, xpctl)

    scheduler = RoundRobinScheduler()
    data_cache = mead_settings.get('datacache', '~/.bl-data')

    try:
        run_forever(
            results, backend, scheduler, frontend, log_server,
            data_cache, xpctl_config, datasets, embeddings
        )
    except KeyboardInterrupt:
        # A keyboard interrupt is the expected way to stop serving.
        pass
def get_logs(hpctl_settings, logging, hpctl_logging):
    """Read the mead and hpctl logging configs and add the log server address.

    The host and port come from the ``logging`` section of the hpctl settings,
    defaulting to ``'localhost'`` and ``6006``.

    :param hpctl_settings: The hpctl settings; its optional ``logging`` section
        may provide ``host`` and ``port``.
    :param logging: The mead logging config.
    :param hpctl_logging: The hpctl logging config.
    :return: A tuple of ``(hpctl_logs, mead_logs)``.
    """
    mead_logs = read_config_file_or_json(logging)
    hpctl_logs = read_config_file_or_json(hpctl_logging)
    log_settings = hpctl_settings.get('logging', {})
    hpctl_logs['host'] = log_settings.get('host', 'localhost')
    # The port used to be looked up under the misspelled key 'post', which meant a
    # configured 'port' was ignored. 'post' is still honoured as a fallback so
    # settings written against the old behaviour keep working.
    port = log_settings.get('port', log_settings.get('post', 6006))
    hpctl_logs['port'] = int(port)
    return hpctl_logs, mead_logs
def __init__(self, **kwargs):
    """Set up the reporting hook and the experiment repo it writes to.

    :param kwargs: Must contain ``cred``, ``config_file`` and ``task``.
        Optional: ``label``, ``user``, ``host``, ``checkpoint_store``, ``save_model``.
    :raises KeyError: If one of the required entries is missing.
    """
    super(XPCtlReporting, self).__init__(**kwargs)

    # Required entries, a missing key raises.
    self.cred = read_config_file_or_json(kwargs['cred'])
    self.label = kwargs.get('label', None)
    self.exp_config = read_config_file_or_json(kwargs['config_file'])
    self.task = kwargs['task']

    self.print_fn = print
    self.username = kwargs.get('user', getpass.getuser())
    self.hostname = kwargs.get('host', socket.gethostname())
    self.checkpoint_base = None
    self.checkpoint_store = kwargs.get('checkpoint_store', '/data/model-checkpoints')
    # Saving the model is optional and off unless requested.
    self.save_model = kwargs.get('save_model', False)

    # The credentials are expanded into keyword arguments for the repo factory.
    repo_factory = ExperimentRepo()
    self.repo = repo_factory.create_repo(**self.cred)
    self.log = []
def read_config(self, config_params, datasets_index, **kwargs):
    """Configure the task from the config parameters and the datasets index.

    Stores the config, makes sure the base directory exists, fills in the GPU
    count from the environment when it is ``-1`` or unset, sets up the task,
    user modules, dataset, reporting and finally the reader.

    :param config_params: The config parameters
    :param datasets_index: The index of datasets
    :return:
    """
    datasets_index = read_config_file_or_json(datasets_index, 'datasets')
    datasets_by_label = index_by_label(datasets_index)
    self.config_params = config_params
    # Snapshot for the reporting hooks, taken before basedir/gpus are written below.
    reporting_config = deepcopy(config_params)

    basedir = self.get_basedir()
    needs_creating = basedir is not None and not os.path.exists(basedir)
    if needs_creating:
        logger.info('Creating: %s', basedir)
        os.makedirs(basedir)
    self.config_params['train']['basedir'] = basedir

    # Read GPUS from env variables now so that the reader has access
    model_params = self.config_params['model']
    if model_params.get('gpus', -1) == -1:
        model_params['gpus'] = len(get_env_gpus())

    self._setup_task(**kwargs)
    self._load_user_modules()
    self.dataset = get_dataset_from_key(self.config_params['dataset'], datasets_by_label)
    # Report the resolved dataset label rather than the key given in the config;
    # some reporting hooks use it.
    reporting_config['dataset'] = self.dataset['label']
    self._configure_reporting(
        config_params.get('reporting', {}), config_file=reporting_config, **kwargs
    )
    self.reader = self._create_task_specific_reader()
def read_config(self, config_params, datasets_index, **kwargs):
    """Configure the task from the config parameters and the datasets index.

    Stores the config, makes sure the base directory exists, fills in the GPU
    count from the environment when it is configured as ``-1``, sets up the
    task, user modules and reporting, then picks the dataset and creates the
    reader.

    :param config_params: The config parameters
    :param datasets_index: The index of datasets
    :return:
    :raises KeyError: If the configured dataset is not in the index.
    """
    datasets_index = read_config_file_or_json(datasets_index, 'datasets')
    datasets_set = index_by_label(datasets_index)
    self.config_params = config_params
    basedir = self.get_basedir()
    if basedir is not None and not os.path.exists(basedir):
        logger.info('Creating: %s', basedir)
        # makedirs instead of mkdir so a nested basedir whose parents do not exist
        # yet can be created too.
        os.makedirs(basedir)
    self.config_params['train']['basedir'] = basedir
    # Read GPUS from env variables now so that the reader has access
    if self.config_params['model'].get('gpus', 1) == -1:
        self.config_params['model']['gpus'] = len(get_env_gpus())
    self.config_file = kwargs.get('config_file')
    self._setup_task(**kwargs)
    self._load_user_modules()
    self._configure_reporting(config_params.get('reporting', {}), **kwargs)
    self.dataset = datasets_set[self.config_params['dataset']]
    self.reader = self._create_task_specific_reader()
def read_config(self, config_params, datasets_index, **kwargs):
    """Configure the task from the config parameters and the datasets index.

    Stores the config, makes sure the base directory exists, fills in the GPU
    count from the environment when it is ``-1`` or unset, sets up the task,
    user modules and reporting, then resolves the dataset and creates the reader.

    :param config_params: The config parameters
    :param datasets_index: The index of datasets
    :return:
    """
    datasets_index = read_config_file_or_json(datasets_index, 'datasets')
    datasets_by_label = index_by_label(datasets_index)
    self.config_params = config_params

    basedir = self.get_basedir()
    needs_creating = basedir is not None and not os.path.exists(basedir)
    if needs_creating:
        logger.info('Creating: %s', basedir)
        os.makedirs(basedir)
    self.config_params['train']['basedir'] = basedir

    # Read GPUS from env variables now so that the reader has access
    model_params = self.config_params['model']
    if model_params.get('gpus', -1) == -1:
        model_params['gpus'] = len(get_env_gpus())

    self.config_file = kwargs.get('config_file')
    self._setup_task(**kwargs)
    self._load_user_modules()
    self._configure_reporting(config_params.get('reporting', {}), **kwargs)
    dataset_key = self.config_params['dataset']
    self.dataset = get_dataset_from_key(dataset_key, datasets_by_label)
    self.reader = self._create_task_specific_reader()
def initialize(self, embeddings):
    """Download the dataset, build the vocabularies and create the embeddings.

    Features named ``'tgt'`` describe the target side, every other feature
    belongs to the source side. The vocabularies of both sides are saved in
    the base directory.

    :param embeddings: The embeddings index
    :return:
    """
    embeddings = read_config_file_or_json(embeddings, 'embeddings')
    embeddings_set = index_by_label(embeddings)
    self.dataset = DataDownloader(self.dataset, self.data_download_cache).download()
    print_dataset_info(self.dataset)

    # The test split is optional: only use it for the vocab when the dataset has
    # one instead of failing with a KeyError.
    vocab_sources = [self.dataset['train_file'], self.dataset['valid_file']]
    if 'test_file' in self.dataset:
        vocab_sources.append(self.dataset['test_file'])
    vocab1, vocab2 = self.reader.build_vocabs(
        vocab_sources,
        min_f=Task._get_min_f(self.config_params),
        vocab_file=self.dataset.get('vocab_file')
    )

    # To keep the config file simple, share a list between source and destination (tgt)
    features_src = []
    features_tgt = None
    for feature in self.config_params['features']:
        if feature['name'] == 'tgt':
            features_tgt = feature
        else:
            features_src.append(feature)

    self.src_embeddings, self.feat2src = self._create_embeddings(embeddings_set, vocab1, features_src)
    # For now, dont allow multiple vocabs of output
    baseline.save_vocabs(self.get_basedir(), self.feat2src)
    self.tgt_embeddings, self.feat2tgt = self._create_embeddings(embeddings_set, {'tgt': vocab2}, [features_tgt])
    baseline.save_vocabs(self.get_basedir(), self.feat2tgt)
    # There is a single target feature, so keep its entries directly.
    self.tgt_embeddings = self.tgt_embeddings['tgt']
    self.feat2tgt = self.feat2tgt['tgt']
def initialize(self, embeddings):
    """Download the dataset, build the vocabularies and create the embeddings.

    The test file is only used for the vocabularies when the dataset has one.
    Features named ``'tgt'`` describe the target side, every other feature
    belongs to the source side. The vocabularies of both sides are saved in
    the base directory.

    :param embeddings: The embeddings index
    :return:
    """
    embeddings = read_config_file_or_json(embeddings, 'embeddings')
    embeddings_set = index_by_label(embeddings)
    self.dataset = DataDownloader(self.dataset, self.data_download_cache).download()
    print_dataset_info(self.dataset)

    split_keys = ['train_file', 'valid_file']
    # TODO: make this optional
    if 'test_file' in self.dataset:
        split_keys.append('test_file')
    vocab_sources = [self.dataset[key] for key in split_keys]
    vocab1, vocab2 = self.reader.build_vocabs(
        vocab_sources,
        min_f=Task._get_min_f(self.config_params),
        vocab_file=self.dataset.get('vocab_file')
    )

    # To keep the config file simple, share a list between source and destination (tgt)
    features_src = []
    features_tgt = None
    for feature in self.config_params['features']:
        if feature['name'] != 'tgt':
            features_src.append(feature)
            continue
        features_tgt = feature

    basedir_of = self.get_basedir
    self.src_embeddings, self.feat2src = self._create_embeddings(embeddings_set, vocab1, features_src)
    # For now, dont allow multiple vocabs of output
    baseline.save_vocabs(basedir_of(), self.feat2src)
    tgt_embeddings, feat2tgt = self._create_embeddings(embeddings_set, {'tgt': vocab2}, [features_tgt])
    self.tgt_embeddings, self.feat2tgt = tgt_embeddings, feat2tgt
    baseline.save_vocabs(basedir_of(), self.feat2tgt)
    # There is a single target feature, so keep its entries directly.
    self.tgt_embeddings = self.tgt_embeddings['tgt']
    self.feat2tgt = self.feat2tgt['tgt']
def initialize(self, embeddings):
    """Create the embeddings for the configured features and save their vocabs.

    No vocabulary is built from data here: an empty ``defaultdict(dict)`` is
    passed as the vocab and ``keep_unused`` is switched on in the config.

    :param embeddings: The embeddings index
    :return:
    """
    embeddings_index = read_config_file_or_json(embeddings, 'embeddings')
    embeddings_set = index_by_label(embeddings_index)
    self.config_params['keep_unused'] = True
    features = self.config_params['features']
    created = self._create_embeddings(embeddings_set, defaultdict(dict), features)
    self.embeddings, self.feat2index = created
    save_vocabs(self.get_basedir(), self.feat2index)
def __init__(self, **kwargs):
    """Set up the reporting hook and the experiment repo it writes to.

    :param kwargs: Must contain ``cred``, ``config_file`` and ``task``.
        Optional: ``label``, ``user``, ``host``, ``checkpoint_store``, ``save_model``.
    :raises KeyError: If one of the required entries is missing.
    """
    super(XPCtlReporting, self).__init__(**kwargs)

    # Required entries, a missing key raises.
    self.cred = read_config_file_or_json(kwargs['cred'])
    self.label = kwargs.get('label', None)
    self.exp_config = read_config_file_or_json(kwargs['config_file'])
    self.task = kwargs['task']

    self.print_fn = print
    self.username = kwargs.get('user', getpass.getuser())
    self.hostname = kwargs.get('host', socket.gethostname())
    self.checkpoint_base = None
    self.checkpoint_store = kwargs.get('checkpoint_store', '/data/model-checkpoints')
    # Saving the model is optional and off unless requested.
    self.save_model = kwargs.get('save_model', False)

    # The credentials are expanded into keyword arguments for the repo factory.
    repo_factory = ExperimentRepo()
    self.repo = repo_factory.create_repo(**self.cred)
    self.log = []
def initialize(self, embeddings):
    """Download the dataset, build the vocab, create the embeddings and save the vocabs.

    :param embeddings: The embeddings index
    :return:
    """
    self.dataset = DataDownloader(self.dataset, self.data_download_cache).download()
    print_dataset_info(self.dataset)
    embeddings = read_config_file_or_json(embeddings, 'embeddings')
    embeddings_set = index_by_label(embeddings)

    # The test split is optional: only use it for the vocab when the dataset has
    # one instead of failing with a KeyError.
    vocab_sources = [self.dataset['train_file'], self.dataset['valid_file']]
    if 'test_file' in self.dataset:
        vocab_sources.append(self.dataset['test_file'])
    vocabs = self.reader.build_vocab(
        vocab_sources,
        min_f=Task._get_min_f(self.config_params),
    )

    self.embeddings, self.feat2index = self._create_embeddings(embeddings_set, vocabs, self.config_params['features'])
    baseline.save_vocabs(self.get_basedir(), self.feat2index)
def initialize(self, embeddings):
    """Download the dataset, build the vocab, create the embeddings and save the vocabs.

    :param embeddings: The embeddings index
    :return:
    """
    self.dataset = DataDownloader(self.dataset, self.data_download_cache).download()
    print_dataset_info(self.dataset)
    embeddings = read_config_file_or_json(embeddings, 'embeddings')
    embeddings_set = index_by_label(embeddings)

    # The test split is optional: only use it for the vocab when the dataset has
    # one instead of failing with a KeyError.
    vocab_sources = [self.dataset['train_file'], self.dataset['valid_file']]
    if 'test_file' in self.dataset:
        vocab_sources.append(self.dataset['test_file'])
    vocabs = self.reader.build_vocab(
        vocab_sources,
        min_f=Task._get_min_f(self.config_params),
    )

    self.embeddings, self.feat2index = self._create_embeddings(embeddings_set, vocabs, self.config_params['features'])
    baseline.save_vocabs(self.get_basedir(), self.feat2index)
def _configure_logger(self, logger_config):
    """Configure logging from the logger config, with per-process log file names.

    The file names of the ``reporting_file_handler`` and ``timing_file_handler``
    handlers are overwritten to include the PID before the config is applied
    with ``logging.config.dictConfig``. Nothing is returned.

    :param logger_config: The logging configuration JSON or file containing JSON
    :return:
    """
    config = read_config_file_or_json(logger_config, 'logger')
    pid = os.getpid()
    handlers = config['handlers']
    per_process_names = (
        ('reporting_file_handler', 'reporting-{}.log'),
        ('timing_file_handler', 'timing-{}.log'),
    )
    for handler_name, template in per_process_names:
        handlers[handler_name]['filename'] = template.format(pid)
    logging.config.dictConfig(config)
def initialize(self, embeddings):
    """Download the dataset, build the vocab, create the embeddings and save the vocabs.

    The test file is only used for the vocab when the dataset has one.

    :param embeddings: The embeddings index
    :return:
    """
    embeddings = read_config_file_or_json(embeddings, 'embeddings')
    embeddings_set = index_by_label(embeddings)
    self.dataset = DataDownloader(self.dataset, self.data_download_cache).download()
    print_dataset_info(self.dataset)

    split_keys = ['train_file', 'valid_file']
    # TODO: make this optional
    if 'test_file' in self.dataset:
        split_keys.append('test_file')
    vocab_sources = [self.dataset[key] for key in split_keys]
    vocabs = self.reader.build_vocab(
        vocab_sources,
        min_f=Task._get_min_f(self.config_params),
        vocab_file=self.dataset.get('vocab_file')
    )

    created = self._create_embeddings(embeddings_set, vocabs, self.config_params['features'])
    self.embeddings, self.feat2index = created
    baseline.save_vocabs(self.get_basedir(), self.feat2index)
def initialize(self, embeddings):
    """Download the dataset, print its files, build the vocab and create the embeddings.

    ``build_vocab`` also returns a second value which is stored in ``self.num_elems``.

    :param embeddings: The embeddings index
    :return:
    """
    embeddings = read_config_file_or_json(embeddings, 'embeddings')
    embeddings_set = index_by_label(embeddings)
    self.dataset = DataDownloader(self.dataset, self.data_download_cache).download()

    # Train, valid and test files are all required here.
    data_files = [self.dataset[key] for key in ('train_file', 'valid_file', 'test_file')]
    print("[train file]: {}\n[valid file]: {}\n[test file]: {}".format(*data_files))

    vocab, self.num_elems = self.reader.build_vocab(data_files)
    created = self._create_embeddings(embeddings_set, vocab)
    self.embeddings, self.feat2index = created
def launch(
    config, settings, logging, hpctl_logging,
    datasets, embeddings,
    reporting, unknown, task, num_iters,
    **kwargs
):
    """Sample ``num_iters`` configurations and submit each one to a remote backend.

    :param config: The mead config, handed to ``get_config``.
    :param settings: The settings, handed to ``get_settings``.
    :param logging: The mead logging config, handed to ``get_logs``.
    :param hpctl_logging: The hpctl logging config, handed to ``get_logs``.
    :param datasets: The datasets index; it is read here but not forwarded to the backend.
    :param embeddings: The embeddings index; it is read here but not forwarded to the backend.
    :param reporting: The reporting hooks, handed to ``get_config``.
    :param unknown: Extra arguments used to configure the config and the backend.
    :param task: The task name. When ``None`` the ``task`` entry of the config is
        used, falling back to ``'classify'``.
    :param num_iters: The number of configurations to sample and launch.
    """
    mead_config = get_config(config, reporting, unknown)
    hp_settings, mead_settings = get_settings(settings)
    load_user_modules(mead_config, hp_settings)
    exp_hash = hash_config(mead_config)
    hp_logs, mead_logs = get_logs(hp_settings, logging, hpctl_logging)
    datasets = read_config_file_or_json(datasets)
    embeddings = read_config_file_or_json(embeddings)
    task = mead_config.get('task', 'classify') if task is None else task

    # Only the backend half of the settings matters here; it is forced to be remote.
    _, backend_config = get_ends(hp_settings, unknown)
    force_remote_backend(backend_config)
    config_sampler = get_config_sampler(mead_config, None)
    be = get_backend(backend_config)

    # Everything except the label and the sampled config is the same for every job.
    shared = {
        'mead_logs': mead_logs,
        'hpctl_logs': hp_logs,
        'task_name': task,
        'settings': mead_settings,
        'experiment_config': mead_config,
    }
    for _ in range(num_iters):
        label, sampled_config = config_sampler.sample()
        print(label)
        be.launch(label=label, config=sampled_config, **shared)
def read_config(self, config_params, datasets_index, **kwargs):
    """Configure the task from the config parameters and the datasets index.

    Stores the config, sets up the task and reporting, then picks the dataset
    named in the config and creates the reader for it.

    :param config_params: The config parameters
    :param datasets_index: The index of datasets
    :return:
    :raises KeyError: If the configured dataset is not in the index.
    """
    datasets_index = read_config_file_or_json(datasets_index, 'datasets')
    datasets_by_label = index_by_label(datasets_index)

    self.config_params = config_params
    self.config_file = kwargs.get('config_file')
    self._setup_task()

    reporting_params = config_params.get('reporting', {})
    self._configure_reporting(reporting_params, self.task_name, **kwargs)

    dataset_label = self.config_params['dataset']
    self.dataset = datasets_by_label[dataset_label]
    self.reader = self._create_task_specific_reader()
def initialize(self, embeddings):
    """Download the dataset, print its files, build both vocabs and their embeddings.

    When the dataset has a ``vocab_file`` the vocabs are built from that file
    alone, otherwise from the train, valid and test files.

    :param embeddings: The embeddings index
    :return:
    """
    embeddings = read_config_file_or_json(embeddings, 'embeddings')
    embeddings_set = index_by_label(embeddings)
    self.dataset = DataDownloader(self.dataset, self.data_download_cache, True).download()

    data_files = [self.dataset[key] for key in ('train_file', 'valid_file', 'test_file')]
    shown = data_files + [self.dataset.get('vocab_file', "None")]
    print("[train file]: {}\n[valid file]: {}\n[test file]: {}\n[vocab file]: {}".format(*shown))

    vocab_file = self.dataset.get('vocab_file', None)
    vocab_sources = data_files if vocab_file is None else [vocab_file]
    vocab1, vocab2 = self.reader.build_vocabs(vocab_sources)

    # Both sides are registered under the same 'word' feature name.
    self.embeddings1, self.feat2index1 = self._create_embeddings(embeddings_set, {'word': vocab1})
    self.embeddings2, self.feat2index2 = self._create_embeddings(embeddings_set, {'word': vocab2})
def test_read_config_file_or_json_list(gold_data):
    """A list input comes back equal to what was passed in."""
    given = [gold_data, '12']
    assert read_config_file_or_json(given) == given
def test_read_config_file_or_json_json(gold_data):
    """Already parsed data comes back equal to what was passed in."""
    result = read_config_file_or_json(gold_data)
    assert result == gold_data
def test_read_config_file_or_json_missing_file():
    """Reading a path that does not exist raises."""
    with pytest.raises(Exception):
        missing_path = os.path.join(data_loc, 'not_there.json')
        read_config_file_or_json(missing_path)
def test_read_config_file_or_json_file():
    """A file name is passed on to ``read_config_file`` exactly once."""
    json_path = os.path.join(data_loc, 'test_json.json')
    with mock.patch('mead.utils.read_config_file') as patched_reader:
        read_config_file_or_json(json_path)
    # The mock keeps its call record after the patch is undone.
    patched_reader.assert_called_once_with(json_path)
def get_logs(hpctl_settings, logging, hpctl_logging):
    """Read the mead and hpctl logging configs and add the log server address.

    The host and port come from the ``logging`` section of the hpctl settings,
    defaulting to ``'localhost'`` and ``6006``.

    :param hpctl_settings: The hpctl settings; its optional ``logging`` section
        may provide ``host`` and ``port``.
    :param logging: The mead logging config.
    :param hpctl_logging: The hpctl logging config.
    :return: A tuple of ``(hpctl_logs, mead_logs)``.
    """
    mead_logs = read_config_file_or_json(logging)
    hpctl_logs = read_config_file_or_json(hpctl_logging)
    log_settings = hpctl_settings.get('logging', {})
    hpctl_logs['host'] = log_settings.get('host', 'localhost')
    # The port used to be looked up under the misspelled key 'post', which meant a
    # configured 'port' was ignored. 'post' is still honoured as a fallback so
    # settings written against the old behaviour keep working.
    port = log_settings.get('port', log_settings.get('post', 6006))
    hpctl_logs['port'] = int(port)
    return hpctl_logs, mead_logs
def search(
    config, settings, logging, hpctl_logging,
    datasets, embeddings,
    reporting, unknown, task, num_iters,
    **kwargs
):
    """Search for optimal hyperparameters.

    :param config: The mead config, handed to ``get_config``.
    :param settings: The settings, handed to ``get_settings``.
    :param logging: The mead logging config, handed to ``get_logs``.
    :param hpctl_logging: The hpctl logging config, handed to ``get_logs``.
    :param datasets: The datasets index.
    :param embeddings: The embeddings index.
    :param reporting: The reporting hooks, handed to ``get_config``.
    :param unknown: Extra arguments.
    :param task: The task name. When ``None`` the ``task`` entry of the config is
        used, falling back to ``'classify'``.
    :param num_iters: The number of iterations handed to ``run``.
    :return: A tuple of ``(labels, results)``.
    """
    mead_config = get_config(config, reporting, unknown)
    hp_settings, mead_settings = get_settings(settings)
    load_user_modules(mead_config, hp_settings)
    exp_hash = hash_config(mead_config)
    hp_logs, mead_logs = get_logs(hp_settings, logging, hpctl_logging)
    datasets = read_config_file_or_json(datasets)
    embeddings = read_config_file_or_json(embeddings)
    task = mead_config.get('task', 'classify') if task is None else task
    frontend_config, backend_config = get_ends(hp_settings, unknown)

    # Figure out xpctl
    auto_xpctl = 'xpctl' in mead_config.get('reporting', [])
    # If the jobs aren't setup to use xpctl automatically create your own
    xpctl_config = None if auto_xpctl else get_xpctl_settings(mead_settings)
    if xpctl_config is not None:
        xpctl_extra = parse_extra_args(['xpctl'], unknown)
        xpctl_config['label'] = xpctl_extra.get('xpctl', {}).get('label')
    results_config = {}

    # Set frontend defaults
    frontend_config['experiment_hash'] = exp_hash
    default = mead_config['train'].get('early_stopping_metric', 'avg_loss')
    for phase, metric in (('train', 'avg_loss'), ('dev', default), ('test', default)):
        frontend_config.setdefault(phase, metric)

    # Negotiate remote status
    if backend_config['type'] != 'remote':
        set_root(hp_settings)
    _remote_monkey_patch(backend_config, hp_logs, results_config, xpctl_config)

    xpctl = get_xpctl(xpctl_config)
    results = get_results(results_config)
    results.add_experiment(mead_config)
    backend = get_backend(backend_config)
    config_sampler = get_config_sampler(mead_config, results)
    logs = get_log_server(hp_logs)
    frontend = get_frontend(frontend_config, results, xpctl)

    labels = run(
        num_iters, results, backend, frontend, config_sampler, logs,
        mead_logs, hp_logs, mead_settings, datasets, embeddings, task
    )

    logs.stop()
    frontend.finalize()
    results.save()
    if auto_xpctl:
        # Flag every job of this run as having xpctl set.
        for label in labels:
            results.set_xpctl(label, True)
    return labels, results
def test_read_config_file_or_json_json(gold_data):
    """Already parsed data comes back equal to what was passed in."""
    result = read_config_file_or_json(gold_data)
    assert result == gold_data
def test_read_config_file_or_json_list(gold_data):
    """A list input comes back equal to what was passed in."""
    given = [gold_data, '12']
    assert read_config_file_or_json(given) == given
def test_read_config_file_or_json_missing_file():
    """Reading a path that does not exist raises."""
    with pytest.raises(Exception):
        missing_path = os.path.join(data_loc, 'not_there.json')
        read_config_file_or_json(missing_path)
def download_dataset(dataset: str, datasets_index: str, cache: str) -> Dict[str, str]:
    """Look up a dataset by label in the datasets index and download it.

    :param dataset: The label of the dataset to download.
    :param datasets_index: The datasets index to look the label up in.
    :param cache: The cache location handed to ``DataDownloader``.
    :return: The result of ``DataDownloader.download``.
    :raises KeyError: If the label is not in the index.
    """
    index = read_config_file_or_json(datasets_index)
    entry = index_by_label(index)[dataset]
    downloader = DataDownloader(entry, cache)
    return downloader.download()
def main():
    """Run a trained NLI model over text pairs and print a prediction for each.

    The pairs come either from a CSV file given with ``--file`` (columns
    ``hypothesis`` and ``premise``) or, when no such file exists, from a single
    pair given with ``--text1`` and ``--text2``. Inputs are whitespace
    tokenized and predicted in batches of ``--batchsz``.
    """
    parser = argparse.ArgumentParser(description='Classify text with a model')
    parser.add_argument(
        '--model',
        help='The path to either the .zip file created by training or to the client bundle '
             'created by exporting',
        required=True,
        type=str
    )
    parser.add_argument('--config', type=str, required=True)
    parser.add_argument('--text1', type=str)
    parser.add_argument('--text2', type=str)
    parser.add_argument('--file', type=str)
    parser.add_argument('--backend', help='backend', choices={'tf', 'pytorch'}, default='pytorch')
    parser.add_argument('--device', help='device')
    parser.add_argument('--batchsz', help='batch size when --text is a file', default=100, type=int)
    # nargs is needed so the value is a list of module names; without it a single
    # string was iterated character by character below.
    parser.add_argument('--modules', nargs='+', default=[])
    # This flag is read for the tf backend but was never declared, which made that
    # backend fail with an AttributeError.
    parser.add_argument('--prefer_eager', help='prefer eager mode when running in tf', action='store_true')
    args = parser.parse_args()

    if args.backend == 'tf':
        from eight_mile.tf.layers import set_tf_eager_mode
        set_tf_eager_mode(args.prefer_eager)

    for mod_name in args.modules:
        bl.import_user_module(mod_name)

    # --file is optional, so guard against None before touching the file system.
    if args.file is not None and os.path.isfile(args.file):
        df = pd.read_csv(args.file)
        text_1 = [x.strip().split() for x in df['hypothesis']]
        text_2 = [x.strip().split() for x in df['premise']]
    else:
        if args.text1 is None or args.text2 is None:
            parser.error('provide either an existing --file or both --text1 and --text2')
        text_1 = [args.text1.split()]
        text_2 = [args.text2.split()]

    text_1_batched = [text_1[i:i + args.batchsz] for i in range(0, len(text_1), args.batchsz)]
    text_2_batched = [text_2[i:i + args.batchsz] for i in range(0, len(text_2), args.batchsz)]

    config = read_config_file_or_json(args.config)
    loader_config = config['loader']
    model_type = config['model']['model_type']
    model = NLIService.load(
        args.model, model_type=model_type, backend=args.backend, device=args.device, **loader_config
    )

    for text_1_batch, text_2_batch in zip(text_1_batched, text_2_batched):
        output_batch = model.predict(text_1_batch, text_2_batch)
        for tokens_1, tokens_2, output in zip(text_1_batch, text_2_batch, output_batch):
            print("text1: {}, text2: {}, prediction: {}".format(
                " ".join(tokens_1), " ".join(tokens_2), output[0][0]))
        print('=' * 30)
def get_xpctl_settings(mead_settings):
    """Read the xpctl credentials referenced by the mead settings.

    :param mead_settings: The mead settings; the credentials are looked up under
        ``reporting_hooks`` -> ``xpctl`` -> ``cred``.
    :return: The credentials that were read, or ``None`` when no ``cred`` entry
        is configured.
    """
    xpctl_hook = mead_settings.get('reporting_hooks', {}).get('xpctl', {})
    if 'cred' in xpctl_hook:
        return read_config_file_or_json(xpctl_hook['cred'])
    return None
def search(
    config, settings, logging, hpctl_logging,
    datasets, embeddings,
    reporting, unknown, task, num_iters,
    **kwargs
):
    """Search for optimal hyperparameters.

    :param config: The mead config, handed to ``get_config``.
    :param settings: The settings, handed to ``get_settings``.
    :param logging: The mead logging config, handed to ``get_logs``.
    :param hpctl_logging: The hpctl logging config, handed to ``get_logs``.
    :param datasets: The datasets index.
    :param embeddings: The embeddings index.
    :param reporting: The reporting hooks, handed to ``get_config``.
    :param unknown: Extra arguments.
    :param task: The task name. When ``None`` the ``task`` entry of the config is
        used, falling back to ``'classify'``.
    :param num_iters: The number of iterations handed to ``run``.
    :return: A tuple of ``(labels, results)``.
    """
    mead_config = get_config(config, reporting, unknown)
    hp_settings, mead_settings = get_settings(settings)
    load_user_modules(mead_config, hp_settings)
    exp_hash = hash_config(mead_config)
    hp_logs, mead_logs = get_logs(hp_settings, logging, hpctl_logging)
    datasets = read_config_file_or_json(datasets)
    embeddings = read_config_file_or_json(embeddings)
    task = mead_config.get('task', 'classify') if task is None else task
    frontend_config, backend_config = get_ends(hp_settings, unknown)

    # Figure out xpctl
    auto_xpctl = 'xpctl' in mead_config.get('reporting', [])
    # If the jobs aren't setup to use xpctl automatically create your own
    xpctl_config = None if auto_xpctl else get_xpctl_settings(mead_settings)
    if xpctl_config is not None:
        xpctl_extra = parse_extra_args(['xpctl'], unknown)
        xpctl_config['label'] = xpctl_extra.get('xpctl', {}).get('label')
    results_config = {}

    # Set frontend defaults
    frontend_config['experiment_hash'] = exp_hash
    default = mead_config['train'].get('early_stopping_metric', 'avg_loss')
    for phase, metric in (('train', 'avg_loss'), ('dev', default), ('test', default)):
        frontend_config.setdefault(phase, metric)

    # Negotiate remote status
    if backend_config['type'] != 'remote':
        set_root(hp_settings)
    _remote_monkey_patch(backend_config, hp_logs, results_config, xpctl_config)

    xpctl = get_xpctl(xpctl_config)
    results = get_results(results_config)
    results.add_experiment(mead_config)
    backend = get_backend(backend_config)
    config_sampler = get_config_sampler(mead_config, results)
    logs = get_log_server(hp_logs)
    frontend = get_frontend(frontend_config, results, xpctl)

    labels = run(
        num_iters, results, backend, frontend, config_sampler, logs,
        mead_logs, hp_logs, mead_settings, datasets, embeddings, task
    )

    logs.stop()
    frontend.finalize()
    results.save()
    if auto_xpctl:
        # Flag every job of this run as having xpctl set.
        for label in labels:
            results.set_xpctl(label, True)
    return labels, results
def test_read_config_file_or_json_file():
    """A file name is passed on to ``read_config_file`` exactly once."""
    json_path = os.path.join(data_loc, 'test_json.json')
    with mock.patch('mead.utils.read_config_file') as patched_reader:
        read_config_file_or_json(json_path)
    # The mock keeps its call record after the patch is undone.
    patched_reader.assert_called_once_with(json_path)
def get_config(config, reporting, extra_args):
    """Read the mead config and, when given, override its reporting section.

    :param config: The mead config to read.
    :param reporting: The reporting hooks; when not ``None`` they are parsed
        together with ``extra_args`` and stored under ``'reporting'``.
    :param extra_args: Extra arguments handed to ``parse_extra_args``.
    :return: The mead config.
    """
    mead_config = read_config_file_or_json(config)
    if reporting is None:
        return mead_config
    mead_config['reporting'] = parse_extra_args(reporting, extra_args)
    return mead_config
def get_xpctl_settings(mead_settings):
    """Read the xpctl credentials referenced by the mead settings.

    :param mead_settings: The mead settings; the credentials are looked up under
        ``reporting_hooks`` -> ``xpctl`` -> ``cred``.
    :return: The credentials that were read, or ``None`` when no ``cred`` entry
        is configured.
    """
    xpctl_hook = mead_settings.get('reporting_hooks', {}).get('xpctl', {})
    if 'cred' in xpctl_hook:
        return read_config_file_or_json(xpctl_hook['cred'])
    return None
def get_config(config, reporting, extra_args):
    """Read the mead config and, when given, override its reporting section.

    :param config: The mead config to read.
    :param reporting: The reporting hooks; when not ``None`` they are parsed
        together with ``extra_args`` and stored under ``'reporting'``.
    :param extra_args: Extra arguments handed to ``parse_extra_args``.
    :return: The mead config.
    """
    mead_config = read_config_file_or_json(config)
    if reporting is None:
        return mead_config
    mead_config['reporting'] = parse_extra_args(reporting, extra_args)
    return mead_config