Example #1
def launch(config, settings, logging, hpctl_logging, datasets, embeddings,
           reporting, unknown, task, num_iters, **kwargs):
    mead_config = get_config(config, reporting, unknown)
    hp_settings, mead_settings = get_settings(settings)
    load_user_modules(mead_config, hp_settings)
    exp_hash = hash_config(mead_config)
    hp_logs, mead_logs = get_logs(hp_settings, logging, hpctl_logging)
    datasets = read_config_file_or_json(datasets)
    embeddings = read_config_file_or_json(embeddings)
    if task is None:
        task = mead_config.get('task', 'classify')
    _, backend_config = get_ends(hp_settings, unknown)

    force_remote_backend(backend_config)

    config_sampler = get_config_sampler(mead_config, None)
    be = get_backend(backend_config)

    for _ in range(num_iters):
        label, config = config_sampler.sample()
        print(label)
        send = {
            'label': label,
            'config': config,
            'mead_logs': mead_logs,
            'hpctl_logs': hp_logs,
            'task_name': task,
            'settings': mead_settings,
            'experiment_config': mead_config,
        }
        be.launch(**send)
Example #2
def serve(settings, hpctl_logging, embeddings, datasets, unknown, **kwargs):
    hp_settings, mead_settings = get_settings(settings)
    load_user_modules({}, hp_settings)
    frontend_config, backend_config = get_ends(hp_settings, unknown)
    hp_logs, _ = get_logs(hp_settings, {}, hpctl_logging)
    xpctl_config = get_xpctl_settings(mead_settings)
    set_root(hp_settings)

    datasets = read_config_file_or_json(datasets)
    embeddings = read_config_file_or_json(embeddings)

    results = get_results({})
    backend = get_backend(backend_config)
    logs = get_log_server(hp_logs)

    xpctl = get_xpctl(xpctl_config)

    frontend_config['type'] = 'flask'
    frontend_config['datasets'] = index_by_label(datasets)
    frontend_config['embeddings'] = index_by_label(embeddings)
    frontend = get_frontend(frontend_config, results, xpctl)
    scheduler = RoundRobinScheduler()

    cache = mead_settings.get('datacache', '~/.bl-data')

    try:
        run_forever(results, backend, scheduler, frontend, logs, cache,
                    xpctl_config, datasets, embeddings)
    except KeyboardInterrupt:
        pass
Example #3
    def __init__(self, **kwargs):
        super(XPCtlReporting, self).__init__(**kwargs)
        # throw exception if the next three can't be read from kwargs
        if 'host' in kwargs:
            self.api_url = kwargs['host']
        elif 'cred' in kwargs:
            self.api_url = read_config_file_or_json(kwargs['cred'])['host']
        else:
            raise ValueError(
                'must provide a url where xpctl server is running')
        self.exp_config = read_config_file_or_json(kwargs['config_file'])
        self.task = kwargs['task']

        self.label = kwargs.get('label', None)
        self.username = kwargs.get('user', getpass.getuser())
        self.hostname = kwargs.get('host', socket.gethostname())
        self.checkpoint_base = None
        self.checkpoint_store = kwargs.get('checkpoint_store',
                                           '/data/model-checkpoints')
        self.save_model = kwargs.get('save_model',
                                     False)  # optionally save the model
        config = Configuration(host=self.api_url)
        api_client = ApiClient(config)
        self.api = XpctlApi(api_client)

        self.log = []
Example #4
def get_logs(hpctl_settings, logging, hpctl_logging):
    mead_logs = read_config_file_or_json(logging)
    hpctl_logs = read_config_file_or_json(hpctl_logging)
    hpctl_logs['host'] = hpctl_settings.get('logging', {}).get('host', 'localhost')
    hpctl_logs['port'] = int(hpctl_settings.get('logging', {}).get('port', 6006))
    return hpctl_logs, mead_logs
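For intuition, a quick check of the defaults (a sketch, relying on empty dicts passing straight through read_config_file_or_json, as the pass-through tests later in this list show):

hp_logs, mead_logs = get_logs({}, {}, {})
# hp_logs -> {'host': 'localhost', 'port': 6006}; mead_logs -> {}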
Example #5
    def __init__(self, **kwargs):
        super(XPCtlReporting, self).__init__(**kwargs)
        # throw exception if the next three can't be read from kwargs
        self.cred = read_config_file_or_json(kwargs['cred'])
        self.label = kwargs.get('label', None)
        self.exp_config = read_config_file_or_json(kwargs['config_file'])
        self.task = kwargs['task']
        self.print_fn = print
        self.username = kwargs.get('user', getpass.getuser())
        self.hostname = kwargs.get('host', socket.gethostname())
        self.checkpoint_base = None
        self.checkpoint_store = kwargs.get('checkpoint_store', '/data/model-checkpoints')
        self.save_model = kwargs.get('save_model', False) # optionally save the model

        self.repo = ExperimentRepo().create_repo(**self.cred)
        self.log = []
Example #6
    def read_config(self, config_params, datasets_index, **kwargs):
        """
        Read the config file and the datasets index

        Between the config file and the dataset index, we have enough information
        to configure the backend and the models.  We can also initialize the data readers

        :param config_params: The config parameters
        :param datasets_index: The index of datasets
        :return:
        """
        datasets_index = read_config_file_or_json(datasets_index, 'datasets')
        datasets_set = index_by_label(datasets_index)
        self.config_params = config_params
        config_file = deepcopy(config_params)
        basedir = self.get_basedir()
        if basedir is not None and not os.path.exists(basedir):
            logger.info('Creating: %s', basedir)
            os.makedirs(basedir)
        self.config_params['train']['basedir'] = basedir
        # Read GPUS from env variables now so that the reader has access
        if self.config_params['model'].get('gpus', -1) == -1:
            self.config_params['model']['gpus'] = len(get_env_gpus())
        self._setup_task(**kwargs)
        self._load_user_modules()
        self.dataset = get_dataset_from_key(self.config_params['dataset'], datasets_set)
        # replace dataset in config file by the latest dataset label, this will be used by some reporting hooks
        config_file['dataset'] = self.dataset['label']
        self._configure_reporting(config_params.get('reporting', {}), config_file=config_file, **kwargs)
        self.reader = self._create_task_specific_reader()
Example #7
    def read_config(self, config_params, datasets_index, **kwargs):
        """
        Read the config file and the datasets index

        Between the config file and the dataset index, we have enough information
        to configure the backend and the models.  We can also initialize the data readers

        :param config_params: The config parameters
        :param datasets_index: The index of datasets
        :return:
        """
        datasets_index = read_config_file_or_json(datasets_index, 'datasets')
        datasets_set = index_by_label(datasets_index)
        self.config_params = config_params
        basedir = self.get_basedir()
        if basedir is not None and not os.path.exists(basedir):
            logger.info('Creating: {}'.format(basedir))
            os.mkdir(basedir)
        self.config_params['train']['basedir'] = basedir
        # Read GPUS from env variables now so that the reader has access
        if self.config_params['model'].get('gpus', -1) == -1:
            self.config_params['model']['gpus'] = len(get_env_gpus())
        self.config_file = kwargs.get('config_file')
        self._setup_task(**kwargs)
        self._load_user_modules()
        self._configure_reporting(config_params.get('reporting', {}), **kwargs)
        self.dataset = datasets_set[self.config_params['dataset']]
        self.reader = self._create_task_specific_reader()
Example #8
    def read_config(self, config_params, datasets_index, **kwargs):
        """
        Read the config file and the datasets index

        Between the config file and the dataset index, we have enough information
        to configure the backend and the models.  We can also initialize the data readers

        :param config_params: The config parameters
        :param datasets_index: The index of datasets
        :return:
        """
        datasets_index = read_config_file_or_json(datasets_index, 'datasets')
        datasets_set = index_by_label(datasets_index)
        self.config_params = config_params
        basedir = self.get_basedir()
        if basedir is not None and not os.path.exists(basedir):
            logger.info('Creating: %s', basedir)
            os.makedirs(basedir)
        self.config_params['train']['basedir'] = basedir
        # Read GPUS from env variables now so that the reader has access
        if self.config_params['model'].get('gpus', -1) == -1:
            self.config_params['model']['gpus'] = len(get_env_gpus())
        self.config_file = kwargs.get('config_file')
        self._setup_task(**kwargs)
        self._load_user_modules()
        self._configure_reporting(config_params.get('reporting', {}), **kwargs)
        self.dataset = get_dataset_from_key(self.config_params['dataset'], datasets_set)
        self.reader = self._create_task_specific_reader()
Example #9
    def initialize(self, embeddings):
        embeddings = read_config_file_or_json(embeddings, 'embeddings')
        embeddings_set = index_by_label(embeddings)
        self.dataset = DataDownloader(self.dataset, self.data_download_cache).download()
        print_dataset_info(self.dataset)
        vocab1, vocab2 = self.reader.build_vocabs(
            [self.dataset['train_file'], self.dataset['valid_file'], self.dataset['test_file']],
            min_f=Task._get_min_f(self.config_params),
            vocab_file=self.dataset.get('vocab_file')
        )

        # To keep the config file simple, share a list between source and destination (tgt)
        features_src = []
        features_tgt = None
        for feature in self.config_params['features']:
            if feature['name'] == 'tgt':
                features_tgt = feature
            else:
                features_src += [feature]

        self.src_embeddings, self.feat2src = self._create_embeddings(embeddings_set, vocab1, features_src)
        # For now, don't allow multiple vocabs of output
        baseline.save_vocabs(self.get_basedir(), self.feat2src)
        self.tgt_embeddings, self.feat2tgt = self._create_embeddings(embeddings_set, {'tgt': vocab2}, [features_tgt])
        baseline.save_vocabs(self.get_basedir(), self.feat2tgt)
        self.tgt_embeddings = self.tgt_embeddings['tgt']
        self.feat2tgt = self.feat2tgt['tgt']
Example #10
    def initialize(self, embeddings):
        embeddings = read_config_file_or_json(embeddings, 'embeddings')
        embeddings_set = index_by_label(embeddings)
        self.dataset = DataDownloader(self.dataset, self.data_download_cache).download()
        print_dataset_info(self.dataset)
        vocab_sources = [self.dataset['train_file'], self.dataset['valid_file']]
        # TODO: make this optional
        if 'test_file' in self.dataset:
            vocab_sources.append(self.dataset['test_file'])
        vocab1, vocab2 = self.reader.build_vocabs(vocab_sources,
                                                  min_f=Task._get_min_f(self.config_params),
                                                  vocab_file=self.dataset.get('vocab_file'))

        # To keep the config file simple, share a list between source and destination (tgt)
        features_src = []
        features_tgt = None
        for feature in self.config_params['features']:
            if feature['name'] == 'tgt':
                features_tgt = feature
            else:
                features_src += [feature]

        self.src_embeddings, self.feat2src = self._create_embeddings(embeddings_set, vocab1, features_src)
        # For now, don't allow multiple vocabs of output
        baseline.save_vocabs(self.get_basedir(), self.feat2src)
        self.tgt_embeddings, self.feat2tgt = self._create_embeddings(embeddings_set, {'tgt': vocab2}, [features_tgt])
        baseline.save_vocabs(self.get_basedir(), self.feat2tgt)
        self.tgt_embeddings = self.tgt_embeddings['tgt']
        self.feat2tgt = self.feat2tgt['tgt']
Example #11
    def initialize(self, embeddings):
        embeddings = read_config_file_or_json(embeddings, 'embeddings')
        embeddings_set = index_by_label(embeddings)
        self.config_params['keep_unused'] = True
        features = self.config_params['features']
        self.embeddings, self.feat2index = self._create_embeddings(embeddings_set, defaultdict(dict), features)
        save_vocabs(self.get_basedir(), self.feat2index)
Example #12
    def initialize(self, embeddings):
        self.dataset = DataDownloader(self.dataset, self.data_download_cache).download()
        print_dataset_info(self.dataset)
        embeddings = read_config_file_or_json(embeddings, 'embeddings')
        embeddings_set = index_by_label(embeddings)
        vocabs = self.reader.build_vocab(
            [self.dataset['train_file'], self.dataset['valid_file'], self.dataset['test_file']],
            min_f=Task._get_min_f(self.config_params),
        )
        self.embeddings, self.feat2index = self._create_embeddings(embeddings_set, vocabs, self.config_params['features'])
        baseline.save_vocabs(self.get_basedir(), self.feat2index)
Example #13
    def _configure_logger(self, logger_config):
        """Use the logger file (logging.json) to configure the log, but overwrite the filename to include the PID

        :param logger_config: The logging configuration JSON or file containing JSON
        :return: A dictionary config derived from the logger_file, with the reporting handler suffixed with PID
        """
        config = read_config_file_or_json(logger_config, 'logger')

        config['handlers']['reporting_file_handler']['filename'] = 'reporting-{}.log'.format(os.getpid())
        config['handlers']['timing_file_handler']['filename'] = 'timing-{}.log'.format(os.getpid())
        logging.config.dictConfig(config)
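A minimal logger config this function could consume (hypothetical values; only the two handler names come from the code above):

logger_config = {
    'version': 1,
    'disable_existing_loggers': False,
    'handlers': {
        'reporting_file_handler': {'class': 'logging.FileHandler', 'filename': 'reporting.log'},
        'timing_file_handler': {'class': 'logging.FileHandler', 'filename': 'timing.log'},
    },
}
# After _configure_logger, each filename gains the PID, e.g. reporting-4242.log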
Example #14
    def initialize(self, embeddings):
        embeddings = read_config_file_or_json(embeddings, 'embeddings')
        embeddings_set = index_by_label(embeddings)
        self.dataset = DataDownloader(self.dataset, self.data_download_cache).download()
        print_dataset_info(self.dataset)
        vocab_sources = [self.dataset['train_file'], self.dataset['valid_file']]
        # TODO: make this optional
        if 'test_file' in self.dataset:
            vocab_sources.append(self.dataset['test_file'])
        vocabs = self.reader.build_vocab(vocab_sources,
                                         min_f=Task._get_min_f(self.config_params),
                                         vocab_file=self.dataset.get('vocab_file'))
        self.embeddings, self.feat2index = self._create_embeddings(embeddings_set, vocabs, self.config_params['features'])
        baseline.save_vocabs(self.get_basedir(), self.feat2index)
Example #15
    def initialize(self, embeddings):
        embeddings = read_config_file_or_json(embeddings, 'embeddings')
        embeddings_set = index_by_label(embeddings)
        self.dataset = DataDownloader(self.dataset, self.data_download_cache).download()
        print("[train file]: {}\n[valid file]: {}\n[test file]: {}".format(
            self.dataset['train_file'], self.dataset['valid_file'], self.dataset['test_file']))
        vocab, self.num_elems = self.reader.build_vocab(
            [self.dataset['train_file'], self.dataset['valid_file'], self.dataset['test_file']])
        self.embeddings, self.feat2index = self._create_embeddings(embeddings_set, vocab)
Example #16
    def read_config(self, config_params, datasets_index, **kwargs):
        """
        Read the config file and the datasets index

        Between the config file and the dataset index, we have enough information
        to configure the backend and the models.  We can also initialize the data readers

        :param config_params: The config parameters
        :param datasets_index: The index of datasets
        :return:
        """
        datasets_index = read_config_file_or_json(datasets_index, 'datasets')
        datasets_set = index_by_label(datasets_index)
        self.config_params = config_params
        self.config_file = kwargs.get('config_file')
        self._setup_task()
        self._configure_reporting(config_params.get('reporting', {}),
                                  self.task_name, **kwargs)
        self.dataset = datasets_set[self.config_params['dataset']]
        self.reader = self._create_task_specific_reader()
Example #17
    def initialize(self, embeddings):
        embeddings = read_config_file_or_json(embeddings, 'embeddings')
        embeddings_set = index_by_label(embeddings)
        self.dataset = DataDownloader(self.dataset, self.data_download_cache, True).download()
        print("[train file]: {}\n[valid file]: {}\n[test file]: {}\n[vocab file]: {}".format(
            self.dataset['train_file'], self.dataset['valid_file'],
            self.dataset['test_file'], self.dataset.get('vocab_file', "None")))
        vocab_file = self.dataset.get('vocab_file')
        if vocab_file is not None:
            vocab1, vocab2 = self.reader.build_vocabs([vocab_file])
        else:
            vocab1, vocab2 = self.reader.build_vocabs(
                [self.dataset['train_file'], self.dataset['valid_file'], self.dataset['test_file']])
        self.embeddings1, self.feat2index1 = self._create_embeddings(embeddings_set, {'word': vocab1})
        self.embeddings2, self.feat2index2 = self._create_embeddings(embeddings_set, {'word': vocab2})
Example #18
def test_read_config_file_or_json_list(gold_data):
    input_ = [gold_data, '12']
    data = read_config_file_or_json(input_)
    assert data == input_
Example #19
def test_read_config_file_or_json_json(gold_data):
    data = read_config_file_or_json(gold_data)
    assert data == gold_data
Example #20
def test_read_config_file_or_json_missing_file():
    with pytest.raises(Exception):
        data = read_config_file_or_json(os.path.join(data_loc, 'not_there.json'))
Example #21
def test_read_config_file_or_json_file():
    file_name = os.path.join(data_loc, 'test_json.json')
    with mock.patch('mead.utils.read_config_file') as read_patch:
        read_config_file_or_json(file_name)
    read_patch.assert_called_once_with(file_name)
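Taken together, these tests pin down the contract of read_config_file_or_json: dicts and lists pass through untouched, a real file path is delegated to read_config_file, and a missing file raises. A minimal sketch consistent with that contract (an illustration, not the library's actual source; read_config_file is assumed to come from mead.utils):

import os

def read_config_file_or_json(config, name=''):
    # Already-parsed JSON (a dict or list) passes through unchanged
    if isinstance(config, (dict, list)):
        return config
    # Otherwise treat it as a path and delegate to the file reader
    config = os.path.expanduser(config)
    if os.path.exists(config):
        return read_config_file(config)
    raise Exception('Expected {} to be a JSON object or a readable file.'.format(name))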
Example #22
def search(
        config, settings,
        logging, hpctl_logging,
        datasets, embeddings,
        reporting, unknown,
        task, num_iters,
        **kwargs
):
    """Search for optimal hyperparameters."""
    mead_config = get_config(config, reporting, unknown)
    hp_settings, mead_settings = get_settings(settings)
    load_user_modules(mead_config, hp_settings)
    exp_hash = hash_config(mead_config)

    hp_logs, mead_logs = get_logs(hp_settings, logging, hpctl_logging)

    datasets = read_config_file_or_json(datasets)
    embeddings = read_config_file_or_json(embeddings)

    if task is None:
        task = mead_config.get('task', 'classify')

    frontend_config, backend_config = get_ends(hp_settings, unknown)

    # Figure out xpctl
    xpctl_config = None
    auto_xpctl = 'xpctl' in mead_config.get('reporting', [])
    if not auto_xpctl:
        # If the jobs aren't set up to use xpctl automatically, create your own
        xpctl_config = get_xpctl_settings(mead_settings)
        if xpctl_config is not None:
            xpctl_extra = parse_extra_args(['xpctl'], unknown)
            xpctl_config['label'] = xpctl_extra.get('xpctl', {}).get('label')
    results_config = {}

    # Set frontend defaults
    frontend_config['experiment_hash'] = exp_hash
    default = mead_config['train'].get('early_stopping_metric', 'avg_loss')
    frontend_config.setdefault('train', 'avg_loss')
    frontend_config.setdefault('dev', default)
    frontend_config.setdefault('test', default)

    # Negotiate remote status
    if backend_config['type'] != 'remote':
        set_root(hp_settings)
    _remote_monkey_patch(backend_config, hp_logs, results_config, xpctl_config)

    xpctl = get_xpctl(xpctl_config)

    results = get_results(results_config)
    results.add_experiment(mead_config)

    backend = get_backend(backend_config)

    config_sampler = get_config_sampler(mead_config, results)

    logs = get_log_server(hp_logs)

    frontend = get_frontend(frontend_config, results, xpctl)

    labels = run(num_iters, results, backend, frontend, config_sampler, logs, mead_logs, hp_logs, mead_settings, datasets, embeddings, task)
    logs.stop()
    frontend.finalize()
    results.save()
    if auto_xpctl:
        for label in labels:
            results.set_xpctl(label, True)
    return labels, results
Example #23
def download_dataset(dataset: str, datasets_index: str,
                     cache: str) -> Dict[str, str]:
    dataset = index_by_label(read_config_file_or_json(datasets_index))[dataset]
    return DataDownloader(dataset, cache).download()
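A hypothetical invocation (the dataset label, index path, and cache directory are made up for illustration):

paths = download_dataset('sst2', 'config/datasets.json', '~/.bl-data')
# paths is the downloaded dataset's file map, e.g. train_file/valid_file/test_file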
Example #24
def main():
    parser = argparse.ArgumentParser(description='Classify text with a model')
    parser.add_argument(
        '--model',
        help=
        'The path to either the .zip file created by training or to the client bundle '
        'created by exporting',
        required=True,
        type=str)
    parser.add_argument('--config', type=str, required=True)
    parser.add_argument('--text1', type=str)
    parser.add_argument('--text2', type=str)
    parser.add_argument('--file', type=str)
    parser.add_argument('--backend',
                        help='backend',
                        choices={'tf', 'pytorch'},
                        default='pytorch')
    parser.add_argument('--device', help='device')
    parser.add_argument('--batchsz',
                        help='batch size when --text is a file',
                        default=100,
                        type=int)
    parser.add_argument('--modules', default=[], nargs='+')
    parser.add_argument('--prefer_eager', action='store_true',
                        help='(TF only) run in eager mode')
    args = parser.parse_args()

    if args.backend == 'tf':
        from eight_mile.tf.layers import set_tf_eager_mode
        set_tf_eager_mode(args.prefer_eager)

    for mod_name in args.modules:
        bl.import_user_module(mod_name)

    if args.file is not None and os.path.isfile(args.file):
        df = pd.read_csv(args.file)
        text_1 = [x.strip().split() for x in df['hypothesis']]
        text_2 = [x.strip().split() for x in df['premise']]
    else:
        text_1 = [args.text1.split()]
        text_2 = [args.text2.split()]

    text_1_batched = [
        text_1[i:i + args.batchsz] for i in range(0, len(text_1), args.batchsz)
    ]
    text_2_batched = [
        text_2[i:i + args.batchsz] for i in range(0, len(text_2), args.batchsz)
    ]

    config = read_config_file_or_json(args.config)
    loader_config = config['loader']
    model_type = config['model']['model_type']
    model = NLIService.load(args.model,
                            model_type=model_type,
                            backend=args.backend,
                            device=args.device,
                            **loader_config)
    for text_1_batch, text_2_batch in zip(text_1_batched, text_2_batched):
        output_batch = model.predict(text_1_batch, text_2_batch)
        for text_1, text_2, output in zip(text_1_batch, text_2_batch,
                                          output_batch):
            print("text1: {}, text2: {}, prediction: {}".format(
                " ".join(text_1), " ".join(text_2), output[0][0]))
        print('=' * 30)
Example #25
def get_xpctl_settings(mead_settings):
    xpctl = mead_settings.get('reporting_hooks', {}).get('xpctl', {})
    if 'cred' not in xpctl:
        return None
    return read_config_file_or_json(xpctl['cred'])
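Since read_config_file_or_json passes dicts through unchanged, the 'cred' value may be either a path to a credentials file or an inline object. A hypothetical settings fragment (the key names inside 'cred' are illustrative):

mead_settings = {'reporting_hooks': {'xpctl': {'cred': {'dbtype': 'mongo', 'dbhost': 'localhost'}}}}
get_xpctl_settings(mead_settings)  # -> the cred dict, unchanged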
Example #26
def get_config(config, reporting, extra_args):
    mead_config = read_config_file_or_json(config)
    if reporting is not None:
        mead_config['reporting'] = parse_extra_args(reporting, extra_args)
    return mead_config