예제 #1
0
def create_sample(num_classes,
                  num_samples,
                  exp_name,
                  poll_interval=10.0,
                  timeout=None):
    """Submit a SLURM job that samples architectures and wait for its output.

    Writes ``logs/<exp_name>/sample.sh``, a batch script that runs
    ``examples/tutorials/multiworker/searcher.py``, submits it with
    ``sbatch`` and polls until ``done(job_id)`` is true and the result file
    ``logs/<exp_name>/sample_info.json`` is valid.

    Args:
        num_classes (int): Forwarded to the searcher as ``--num_classes``.
        num_samples (int): Number of architectures to sample.
        exp_name (str): Experiment name; selects the ``logs/<exp_name>``
            folder and the SLURM job name.
        poll_interval (float): Seconds to sleep between status checks so the
            wait does not spin a CPU core.
        timeout (float or None): Maximum number of seconds to wait. ``None``
            (the default) waits indefinitely.

    Returns:
        dict: The parsed ``sample_info.json``, or an empty dict if
        ``timeout`` expires before the job produces it.
    """
    import time

    # Sets up the logs/<exp_name> folder (delete_if_exists=True); the script,
    # the SLURM output and the result file are all written inside it.
    log.SearchLogger('logs',
                     exp_name,
                     delete_if_exists=True,
                     create_parent_folders=True)
    result_fp = 'logs/%s/sample_info.json' % exp_name
    out_fp = 'logs/%s/sample_slurm.out' % exp_name
    jobname = 'sampling_architecture_for_%s_experiment' % exp_name
    script = [
        '#!/bin/bash', '#SBATCH --nodes=1', '#SBATCH --partition=GPU-shared',
        '#SBATCH --gres=gpu:p100:1', '#SBATCH --ntasks-per-node=1',
        '#SBATCH --time=00:30:00',
        '#SBATCH --job-name=%s' % jobname,
        '#SBATCH --output=%s' % out_fp,
        'module load keras/2.2.0_tf1.7_py3_gpu ', 'source activate',
        'export PYTHONPATH=${PYTHONPATH}:/pylon5/ci4s8dp/maxle18/deep_architect/',
        'python examples/tutorials/multiworker/searcher.py --num_classes %d --num_samples %d --exp_name %s --result_fp %s'
        % (num_classes, num_samples, exp_name, result_fp)
    ]
    script_fp = 'logs/%s/sample.sh' % exp_name
    ut.write_textfile(script_fp, script)
    check_output(["chmod", "+x", script_fp])

    # sbatch prints "Submitted batch job <id>"; the fourth token is the id.
    job_id = check_output(['sbatch', script_fp]).split()[3]

    start_time = time.time()
    while timeout is None or time.time() - start_time < timeout:
        if done(job_id) and valid_filepath(result_fp):
            return ut.read_jsonfile(result_fp)
        time.sleep(poll_interval)
    # Only reached when the timeout expired without a result.
    return dict()
예제 #2
0
def train_best(comm, state, config):
    """Publish re-training jobs for the best architectures of a finished search.

    Reads ``config['results_file']`` (keys ``'validation_accuracies'`` and
    ``'configs'``), keeps the ``config['num_architectures']`` configs with the
    highest validation accuracy, and for each one publishes
    ``config['num_trials']`` copies of a learning-rate x weight-decay grid to
    the architecture topic. Each job gets a sequential ``evaluation_id``; jobs
    whose id already exists on the topic are not published again. Finally
    stores the total job count in ``state['models_sampled']`` and publishes
    ``KILL_SIGNAL``.

    Optional config keys ``'learning_rates'`` and ``'weight_decays'`` override
    the default grids.
    """
    search_results = ut.read_jsonfile(config['results_file'])
    # Rank by validation accuracy alone, highest first. Sorting the raw
    # (accuracy, config) tuples would compare configs on accuracy ties, which
    # raises TypeError when configs are not mutually orderable (e.g. dicts).
    ranked = sorted(zip(search_results['validation_accuracies'],
                        search_results['configs']),
                    key=lambda acc_and_cfg: acc_and_cfg[0],
                    reverse=True)
    best_k = ranked[:config['num_architectures']]

    learning_rates = config.get('learning_rates',
                                [.1, .05, .025, .01, .005, .001])
    weight_decays = config.get('weight_decays', [.0001, .0003, .0005])
    arch_topic = get_topic_name(ARCH_TOPIC, config)

    model_id = 0
    for rank, (_, vs) in enumerate(best_k):
        for trial in range(config['num_trials']):
            for lr in learning_rates:
                for wd in weight_decays:
                    # Skip ids already on the topic (presumably from an
                    # earlier, interrupted run).
                    if not comm.check_data_exists(arch_topic, 'evaluation_id',
                                                  model_id):
                        logger.info(
                            'Publishing trial %d for rank %d architecture',
                            trial, rank)
                        arch = {
                            'vs': vs,
                            'evaluation_id': model_id,
                            'searcher_eval_token': {},
                            'eval_hparams': {
                                'init_lr': lr,
                                'weight_decay': wd,
                                'lr_decay_method': 'cosine',
                                'optimizer_type': 'sgd_mom'
                            }
                        }
                        comm.publish(arch_topic, arch)
                    # Ids advance even for skipped jobs so they stay stable
                    # across runs.
                    model_id += 1
    state['models_sampled'] = model_id
    comm.publish(arch_topic, KILL_SIGNAL)
예제 #3
0
def main():
    """Sample architectures with SLURM, evaluate each in its own job, report the best.

    ``create_sample`` returns a mapping from architecture id (a string) to a
    dict with ``"config_filepath"`` and ``"evaluation_filepath"``. One
    evaluation job is submitted per architecture. The jobs are then polled
    until every one has finished and written its result file, and the best
    validation accuracy is printed.
    """
    import time

    num_classes = 10
    num_samples = 3  # number of architectures to sample
    exp_name = 'mnist_multiworker'
    poll_interval = 10.0  # seconds between job status checks
    best_val_acc, best_architecture = 0., -1

    # create and submit batch job to sample architectures and get information back
    architectures = create_sample(num_classes, num_samples, exp_name)

    # create and submit batch job to evaluate each architecture
    jobs = []
    for i in architectures:
        config_fp = architectures[i]["config_filepath"]
        evaluation_fp = architectures[i]["evaluation_filepath"]
        job_id, result_fp = create_and_submit_job(int(i), exp_name, config_fp,
                                                  evaluation_fp)
        jobs.append((int(i), job_id, result_fp))

    # Collect results; unfinished jobs are carried over to the next round
    # instead of being removed from the list while it is being iterated.
    while jobs:
        pending = []
        for (arch_id, job_id, result_fp) in jobs:
            if done(job_id) and valid_filepath(result_fp):
                results = ut.read_jsonfile(result_fp)
                val_acc = float(results['val_accuracy'])
                print(
                    "Finished evaluating architecture %d, validation accuracy is %f"
                    % (arch_id, val_acc))
                if val_acc > best_val_acc:
                    best_architecture, best_val_acc = arch_id, val_acc
            else:
                pending.append((arch_id, job_id, result_fp))
        jobs = pending
        if jobs:
            time.sleep(poll_interval)

    print("Best validation accuracy is %f with architecture %d" %
          (best_val_acc, best_architecture))
 def load_state(self, folderpath):
     """Restore the evolution searcher from ``evolution_searcher.json``.

     Args:
         folderpath (str): Folder containing ``evolution_searcher.json``.
     """
     state = ut.read_jsonfile(
         ut.join_paths([folderpath, 'evolution_searcher.json']))
     # Scalar/plain fields are copied verbatim, in the original order.
     for attr_name in ("P", "S", "regularized"):
         setattr(self, attr_name, state[attr_name])
     # The stored population list is turned back into a deque.
     self.population = deque(state['population'])
     self.initializing = state['initializing']
예제 #5
0
 def load_state(self, folder):
     """Restore the model from the files written to ``folder``.

     Reads ``hash_model_state.json`` for ``vals_lst`` and ``num_evals``, then
     loads ``0.npz`` ... ``<num_evals - 1>.npz`` with ``sp.load_npz`` into
     ``vecs_lst``. Calls ``self._refit()`` when at least one evaluation was
     restored.

     Args:
         folder (str): Folder containing the saved state files.
     """
     state = ut.read_jsonfile(os.path.join(folder, 'hash_model_state.json'))
     self.vals_lst = state['vals_lst']
     num_evals = state['num_evals']
     # Rebuild vecs_lst instead of appending to it: vals_lst is replaced
     # above, and appending would leave the two lists misaligned whenever
     # this instance already held vectors.
     self.vecs_lst = [
         sp.load_npz(os.path.join(folder,
                                  str(i) + '.npz'))
         for i in range(num_evals)
     ]
     if num_evals > 0:
         self._refit()
예제 #6
0
 def load_state(self, folderpath):
     """Restore the model from the files written to ``folderpath``.

     Reads ``hash_model_state.json`` for ``vals_lst`` and ``num_evals``, then
     loads ``0.npz`` ... ``<num_evals - 1>.npz`` with ``sp.load_npz`` into
     ``vecs_lst``. Calls ``self._refit()`` when at least one evaluation was
     restored.

     Args:
         folderpath (str): Folder containing the saved state files.
     """
     state = ut.read_jsonfile(
         ut.join_paths([folderpath, 'hash_model_state.json']))
     self.vals_lst = state['vals_lst']
     num_evals = state['num_evals']
     # Rebuild vecs_lst instead of appending to it so it stays aligned with
     # the freshly loaded vals_lst even if this instance already held data.
     self.vecs_lst = [
         sp.load_npz(ut.join_paths([folderpath,
                                    str(i) + '.npz']))
         for i in range(num_evals)
     ]
     if num_evals > 0:
         self._refit()
예제 #7
0
    def load_state(self, folder_name):
        """Restore the evolution searcher from ``<folder_name>/evolution_searcher.json``.

        Raises:
            RuntimeError: If the state file does not exist.
        """
        state_path = join_paths([folder_name, 'evolution_searcher.json'])
        if not file_exists(state_path):
            raise RuntimeError("Load file does not exist")

        saved = read_jsonfile(state_path)
        for attr_name in ("P", "S", "regularized"):
            setattr(self, attr_name, saved[attr_name])
        self.population = deque(saved['population'])
        self.initializing = saved['initializing']
예제 #8
0
def main(args):
    """Evaluate one architecture configuration and write its results.

    Args:
        args: Parsed command-line arguments. ``args.config`` is a JSON file
            whose ``"hyperp_value_lst"`` entry holds the hyperparameter values
            to apply; ``args.result_fp`` is where the evaluation results are
            written as JSON.
    """
    num_classes = 10

    # load and normalize data
    (x_train, y_train), (x_test, y_test) = load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0

    # defining evaluator
    evaluator = SimpleClassifierEvaluator((x_train, y_train),
                                          num_classes,
                                          max_num_training_epochs=5)
    inputs, outputs = get_search_space(num_classes)()
    h_values = ut.read_jsonfile(args.config)
    # The search space yields only (inputs, outputs); the previous call also
    # passed an undefined `hs`, which raised NameError. Values are applied to
    # the outputs directly, as the other workers do.
    specify(outputs, h_values["hyperp_value_lst"])
    results = evaluator.evaluate(inputs, outputs)
    ut.write_jsonfile(results, args.result_fp)
예제 #9
0
def start_worker(comm,
                 evaluator,
                 search_space_factory,
                 folderpath,
                 search_name,
                 resume=True,
                 save_every=1):
    """Run the worker loop: evaluate architectures sent by the master.

    Each received ``(vs, evaluation_id, searcher_eval_token)`` is specified
    into a fresh search space, evaluated with ``evaluator.eval``, and the
    results are published back to the master. Every ``save_every``
    evaluations the evaluator state and the step counter
    (``worker<rank>.json``) are saved to the search data folder. The loop
    ends when ``receive_architecture_in_worker`` returns ``None``.

    Args:
        resume (bool): If True and this worker has previously saved state,
            restore the evaluator state and step counter. A fresh search (no
            saved file) starts from scratch instead of crashing.
        save_every (int): Save frequency, in evaluations.
    """
    print('WORKER %d' % comm.get_rank())
    step = 0

    sl.create_search_folderpath(folderpath, search_name)
    search_data_folder = sl.get_search_data_folderpath(folderpath, search_name)
    save_filepath = ut.join_paths(
        (search_data_folder, 'worker' + str(comm.get_rank()) + '.json'))

    # resume defaults to True, so only load when a checkpoint actually exists.
    if resume and ut.file_exists(save_filepath):
        evaluator.load_state(search_data_folder)
        state = ut.read_jsonfile(save_filepath)
        step = state['step']

    while True:
        arch = comm.receive_architecture_in_worker()

        # None is the kill signal from the master.
        if arch is None:
            break

        vs, evaluation_id, searcher_eval_token = arch

        inputs, outputs = search_space_factory.get_search_space()
        se.specify(outputs, vs)
        results = evaluator.eval(inputs, outputs)
        step += 1
        if step % save_every == 0:
            evaluator.save_state(search_data_folder)
            ut.write_jsonfile({'step': step}, save_filepath)
        comm.publish_results_to_master(results, evaluation_id,
                                       searcher_eval_token)
예제 #10
0
def read_evaluation_folder(evaluation_folderpath):
    """Read the standard JSON logs (``config`` and ``results``) of one evaluation.

    See also :func:`deep_architect.search_logging.read_search_folder` for the
    function that reads every evaluation in a search folder.

    Args:
        evaluation_folderpath (str): Folder holding ``config.json`` and
            ``results.json``.

    Returns:
        dict[str,dict[str,object]]:
            Maps each log name (``'config'``, ``'results'``) to the parsed
            contents of the corresponding JSON file.
    """
    assert ut.folder_exists(evaluation_folderpath)

    return {
        log_name: ut.read_jsonfile(
            ut.join_paths([evaluation_folderpath, log_name + '.json']))
        for log_name in ('config', 'results')
    }
예제 #11
0
def main():
    """Evaluate this worker's share of the logged architecture configurations.

    Command-line arguments (parsed with ``ut.CommandLineArgs``):
        config_filepath (str): JSON config providing ``num_samples``,
            ``folderpath``, ``search_name`` and ``max_eval_time_in_minutes``.
        worker_id (int): Index of this worker.
        num_workers (int): Total number of workers.

    Worker ``w`` handles evaluation ids ``w, w + num_workers, ...`` below
    ``num_samples`` and skips ids whose results are already logged.
    """
    cmd = ut.CommandLineArgs()
    for arg_name, arg_type in [('config_filepath', 'str'),
                               ('worker_id', 'int'),
                               ('num_workers', 'int')]:
        cmd.add(arg_name, arg_type)
    args = cmd.parse()
    cfg = ut.read_jsonfile(args['config_filepath'])
    worker_id = args["worker_id"]
    num_workers = args["num_workers"]

    # MNIST splits; the boolean flag is True only for the training set.
    Xtrain, ytrain, Xval, yval, Xtest, ytest = load_mnist('data/mnist')
    train_dataset = InMemoryDataset(Xtrain, ytrain, True)
    val_dataset = InMemoryDataset(Xval, yval, False)
    test_dataset = InMemoryDataset(Xtest, ytest, False)

    evaluator = SimpleClassifierEvaluator(
        train_dataset,
        val_dataset,
        ss.num_classes,
        './temp/worker%d' % worker_id,
        max_eval_time_in_minutes=cfg['max_eval_time_in_minutes'],
        log_output_to_terminal=True,
        test_dataset=test_dataset)

    for evaluation_id in range(worker_id, cfg["num_samples"], num_workers):
        logger = sl.EvaluationLogger(cfg["folderpath"],
                                     cfg["search_name"],
                                     evaluation_id,
                                     abort_if_notexists=True)
        if logger.results_exist():
            continue  # already evaluated
        eval_cfg = logger.read_config()
        inputs, outputs = ss.search_space_fn()
        specify(outputs, eval_cfg["hyperp_value_lst"])
        logger.log_results(evaluator.eval(inputs, outputs))
예제 #12
0
 def read_config(self):
     """Load and return this evaluation's JSON config (asserts the file exists)."""
     config_path = self.config_filepath
     assert ut.file_exists(config_path)
     return ut.read_jsonfile(config_path)
예제 #13
0
 def load_state(self, folder):
     """Rebuild the MCTS tree from ``mcts_searcher_state.json`` in ``folder``."""
     state_path = os.path.join(folder, 'mcts_searcher_state.json')
     serialized_root = ut.read_jsonfile(state_path)['mcts_root_node']
     self.mcts_root_node = MCTSTreeNode.deserialize(serialized_root)
     self.mcts_root_node = MCTSTreeNode.deserialize(state['mcts_root_node'])
예제 #14
0
 def read_results(self):
     """Load and return this evaluation's JSON results (asserts the file exists)."""
     results_path = self.results_filepath
     assert ut.file_exists(results_path)
     return ut.read_jsonfile(results_path)
예제 #15
0
def process_config_and_args():
    """Parse CLI arguments and prepare the communicator, search logger and state.

    Loads the experiment config named by ``--config`` from ``--config-file``,
    appends ``_<repetition>`` to the search name, derives the search/result
    paths and connects a ``MongoCommunicator``. With ``--resume``, the search
    data folder is downloaded from the bucket and the saved ``finished`` /
    ``best_accuracy`` counters are restored. This is best effort: a failure
    is logged and the run starts from the initial state.

    Returns:
        tuple: ``(comm, search_logger, state, config)``.
    """
    import logging

    parser = argparse.ArgumentParser("MPI Job for architecture search")
    parser.add_argument('--config',
                        '-c',
                        action='store',
                        dest='config_name',
                        default='normal')
    parser.add_argument(
        '--config-file',
        action='store',
        dest='config_file',
        default=
        '/deep_architect/kubernetes/mongo_communicator/train_best_config.json')
    parser.add_argument('--bucket',
                        '-b',
                        action='store',
                        dest='bucket',
                        default=BUCKET_NAME)

    # Other arguments
    parser.add_argument('--resume',
                        '-r',
                        action='store_true',
                        dest='resume',
                        default=False)
    parser.add_argument('--mongo-host',
                        '-m',
                        action='store',
                        dest='mongo_host',
                        default='127.0.0.1')
    # type=int: without it a port given on the command line stays a string.
    parser.add_argument('--mongo-port',
                        '-p',
                        action='store',
                        dest='mongo_port',
                        type=int,
                        default=27017)
    parser.add_argument('--repetition', default=0)
    options = parser.parse_args()
    configs = ut.read_jsonfile(options.config_file)
    config = configs[options.config_name]

    config['bucket'] = options.bucket

    comm = MongoCommunicator(host=options.mongo_host,
                             port=options.mongo_port,
                             data_refresher=True)

    # SET UP GOOGLE STORE FOLDER
    config['search_name'] = config['search_name'] + '_' + str(
        options.repetition)
    search_logger = sl.SearchLogger(config['search_folder'],
                                    config['search_name'])
    search_data_folder = search_logger.get_search_data_folderpath()
    config['save_filepath'] = ut.join_paths(
        (search_data_folder, config['searcher_file_name']))
    config['eval_path'] = sl.get_all_evaluations_folderpath(
        config['search_folder'], config['search_name'])
    config['full_search_folder'] = sl.get_search_folderpath(
        config['search_folder'], config['search_name'])
    config['results_file'] = os.path.join(
        config['results_prefix'] + '_' + str(options.repetition),
        config['results_file'])
    state = {'finished': 0, 'best_accuracy': 0.0}
    if options.resume:
        # No searcher exists in this job. The former `searcher.load_state`
        # call raised NameError, which the bare except swallowed, so the
        # saved counters were never restored.
        try:
            download_folder(search_data_folder, config['full_search_folder'],
                            config['bucket'])
            if ut.file_exists(config['save_filepath']):
                old_state = ut.read_jsonfile(config['save_filepath'])
                state['finished'] = old_state['finished']
                state['best_accuracy'] = old_state['best_accuracy']
        except Exception:
            # Resuming stays best effort, but the reason is no longer hidden.
            logging.getLogger(__name__).warning(
                'Could not restore saved state; starting from scratch.',
                exc_info=True)

    return comm, search_logger, state, config
예제 #16
0
 def load_state(self, folderpath):
     """Rebuild the MCTS tree from ``mcts_searcher_state.json`` in ``folderpath``."""
     state_path = ut.join_paths([folderpath, 'mcts_searcher_state.json'])
     serialized_root = ut.read_jsonfile(state_path)['mcts_root_node']
     self.mcts_root_node = MCTSTreeNode.deserialize(serialized_root)
예제 #17
0
def start_searcher(comm,
                   searcher,
                   resume_if_exists,
                   folderpath,
                   search_name,
                   searcher_load_path,
                   num_samples=-1,
                   num_epochs=-1,
                   save_every=1):
    """Run the master loop: sample architectures, dispatch them, collect results.

    Architectures are sampled and sent to workers until ``num_samples``
    models were sampled or ``num_epochs`` epochs were reached (``-1``
    disables a limit; at least one must be set). After that, each ready
    worker receives a kill signal. The loop ends once every sampled model has
    finished and every worker has been killed. Every ``save_every`` finished
    models the searcher state and the counters are saved to
    ``searcher_load_path`` in the search data folder.

    Args:
        resume_if_exists (bool): Restore the searcher and counters if a
            saved checkpoint exists; otherwise start fresh.
    """
    assert num_samples != -1 or num_epochs != -1

    print('SEARCHER')

    sl.create_search_folderpath(folderpath, search_name)
    search_data_folder = sl.get_search_data_folderpath(folderpath, search_name)
    save_filepath = ut.join_paths((search_data_folder, searcher_load_path))

    models_sampled = 0
    epochs = 0
    finished = 0
    killed = 0
    best_accuracy = 0.

    # Load the previous searcher only when a checkpoint was written. The
    # restore uses load_state, the counterpart of save_state below.
    if resume_if_exists and ut.file_exists(save_filepath):
        searcher.load_state(search_data_folder)
        state = ut.read_jsonfile(save_filepath)
        epochs = state['epochs']
        killed = state['killed']
        # Sampling restarts after the last *finished* model; models that were
        # sampled but unfinished are presumably re-sampled (confirm).
        models_sampled = state['models_finished']
        finished = state['models_finished']

    while (finished < models_sampled or killed < comm.num_workers):
        # Search end conditions
        cont = num_samples == -1 or models_sampled < num_samples
        cont = cont and (num_epochs == -1 or epochs < num_epochs)
        if cont:
            # See whether workers are ready to consume architectures
            if comm.is_ready_to_publish_architecture():
                eval_logger = sl.EvaluationLogger(folderpath, search_name,
                                                  models_sampled)
                _, _, vs, searcher_eval_token = searcher.sample()

                eval_logger.log_config(vs, searcher_eval_token)
                comm.publish_architecture_to_worker(vs, models_sampled,
                                                    searcher_eval_token)

                models_sampled += 1
        else:
            if comm.is_ready_to_publish_architecture():
                comm.kill_worker()
                killed += 1

        # See which workers have finished evaluation
        for worker in range(comm.num_workers):
            msg = comm.receive_results_in_master(worker)
            if msg is not None:
                results, model_id, searcher_eval_token = msg
                eval_logger = sl.EvaluationLogger(folderpath, search_name,
                                                  model_id)
                eval_logger.log_results(results)

                if 'epoch' in results:
                    epochs = max(epochs, results['epoch'])

                searcher.update(results['validation_accuracy'],
                                searcher_eval_token)
                best_accuracy = max(best_accuracy,
                                    results['validation_accuracy'])
                finished += 1
                if finished % save_every == 0:
                    # Prints the finished count and the best accuracy seen
                    # since the previous save (best_accuracy is reset below).
                    print('Models sampled: %d Best Accuracy: %f' %
                          (finished, best_accuracy))
                    best_accuracy = 0.

                    searcher.save_state(search_data_folder)
                    state = {
                        'models_finished': finished,
                        'epochs': epochs,
                        'killed': killed
                    }
                    ut.write_jsonfile(state, save_filepath)
예제 #18
0
def main():
    """Entry point for a multi-process architecture search.

    Reads the experiment selected by ``--config`` from
    ``experiment_config.json``. Rank 0 runs ``start_searcher``. Every other
    rank builds the datasets and the configured evaluator and runs
    ``start_worker``.
    """
    configs = ut.read_jsonfile(
        "./examples/tensorflow/full_benchmarks/experiment_config.json")

    parser = argparse.ArgumentParser("MPI Job for architecture search")
    parser.add_argument('--config',
                        '-c',
                        action='store',
                        dest='config_name',
                        default='normal')

    # Other arguments
    parser.add_argument('--display-output',
                        '-o',
                        action='store_true',
                        dest='display_output',
                        default=False)
    parser.add_argument('--resume',
                        '-r',
                        action='store_true',
                        dest='resume',
                        default=False)

    options = parser.parse_args()
    config = configs[options.config_name]

    num_procs = config['num_procs'] if 'num_procs' in config else 0
    comm = get_communicator(config['communicator'], num_procs)
    if len(gpu_utils.get_gpu_information()) != 0:
        # Restrict each process to one GPU, assigned round-robin by rank.
        #https://github.com/tensorflow/tensorflow/issues/1888
        gpu_utils.set_visible_gpus(
            [comm.get_rank() % gpu_utils.get_total_num_gpus()])

    if 'eager' in config and config['eager']:
        import tensorflow as tf
        tf.logging.set_verbosity(tf.logging.ERROR)
        tf.enable_eager_execution()
    datasets = {
        'cifar10': lambda: (load_cifar10('data/cifar10/', one_hot=False), 10),
        'mnist': lambda: (load_mnist('data/mnist/'), 10),
    }

    (Xtrain, ytrain, Xval, yval, Xtest,
     ytest), num_classes = datasets[config['dataset']]()
    search_space_factory = name_to_search_space_factory_fn[
        config['search_space']](num_classes)

    save_every = 1 if 'save_every' not in config else config['save_every']
    if comm.get_rank() == 0:
        searcher = name_to_searcher_fn[config['searcher']](
            search_space_factory.get_search_space)
        num_samples = -1 if 'samples' not in config else config['samples']
        num_epochs = -1 if 'epochs' not in config else config['epochs']
        start_searcher(comm,
                       searcher,
                       options.resume,
                       config['search_folder'],
                       config['search_name'],
                       config['searcher_file_name'],
                       num_samples=num_samples,
                       num_epochs=num_epochs,
                       save_every=save_every)
    else:
        # Previously misspelled `train_d_advataset`, which left
        # `train_dataset` undefined in the evaluator factory below.
        train_dataset = InMemoryDataset(Xtrain, ytrain, True)
        val_dataset = InMemoryDataset(Xval, yval, False)
        test_dataset = InMemoryDataset(Xtest, ytest, False)

        search_path = sl.get_search_folderpath(config['search_folder'],
                                               config['search_name'])
        ut.create_folder(ut.join_paths([search_path, 'scratch_data']),
                         create_parent_folders=True)
        scratch_folder = ut.join_paths(
            [search_path, 'scratch_data', 'eval_' + str(comm.get_rank())])
        ut.create_folder(scratch_folder)

        evaluators = {
            'simple_classification':
            lambda: SimpleClassifierEvaluator(
                train_dataset,
                val_dataset,
                num_classes,
                './temp' + str(comm.get_rank()),
                max_num_training_epochs=config['eval_epochs'],
                log_output_to_terminal=options.display_output,
                test_dataset=test_dataset),
        }

        assert not config['evaluator'].startswith('enas') or hasattr(
            search_space_factory, 'weight_sharer')
        evaluator = evaluators[config['evaluator']]()

        start_worker(comm,
                     evaluator,
                     search_space_factory,
                     config['search_folder'],
                     config['search_name'],
                     resume=options.resume,
                     save_every=save_every)
예제 #19
0
from deep_architect.searchers import common as se
from deep_architect.contrib.misc import gpu_utils
from deep_architect import search_logging as sl
from deep_architect import utils as ut

from search_space_factory import name_to_search_space_factory_fn
from searcher import name_to_searcher_fn

from dev.enas.evaluator.enas_evaluator import ENASEvaluator
from deep_architect.contrib.misc.evaluators.tensorflow.classification import SimpleClassifierEvaluator

from deep_architect.contrib.communicators.communicator import get_communicator
# Module-level setup for the Google-communicator search script: configure
# logging, load the experiment configurations and declare the CLI.
# NOTE(review): this excerpt ends inside the '--bucket' add_argument call
# below; the rest of the script is not visible here.
logging.basicConfig()

# All named experiment configurations; one is selected via --config.
configs = ut.read_jsonfile(
    "/deep_architect/dev/google_communicator/experiment_config.json")

parser = argparse.ArgumentParser("MPI Job for architecture search")
# Name of the entry in `configs` to run.
parser.add_argument('--config',
                    '-c',
                    action='store',
                    dest='config_name',
                    default='normal')
# NOTE(review): the default project id 'normal' looks like a copy of the
# --config default; verify it is intended.
parser.add_argument('--project-id',
                    '-p',
                    action='store',
                    dest='project_id',
                    default='normal')
parser.add_argument('--bucket',
                    '-b',
                    action='store',
예제 #20
0
def main():
    """Worker entry point for the Google Pub/Sub based search.

    Parses the CLI, builds the search-space factory and the TPU evaluator,
    subscribes to the 'architectures-sub' subscription (messages handled by
    ``retrieve_message``) and starts a ``nudge_master`` thread. It then
    repeatedly waits until the module-level flag ``specified`` is set,
    evaluates the architecture in ``arch_data`` and publishes
    ``(results, vs, evaluation_id, searcher_eval_token)`` as UTF-8 JSON to
    the 'results' topic. The loop stops when ``arch_data`` is falsy.

    NOTE(review): ``specified``, ``evaluated`` and ``arch_data`` are module
    globals presumably updated by ``retrieve_message`` / ``nudge_master``
    (not visible here); confirm the handshake there.
    """
    # Shared with the subscriber callback / nudge thread defined elsewhere.
    global specified
    global evaluated
    global results_topic, arch_subscription
    configs = ut.read_jsonfile(
        "/deep_architect/dev/google_communicator/experiment_config.json")

    parser = argparse.ArgumentParser("MPI Job for architecture search")
    parser.add_argument('--config',
                        '-c',
                        action='store',
                        dest='config_name',
                        default='search_evol')

    # Other arguments
    parser.add_argument('--display-output',
                        '-o',
                        action='store_true',
                        dest='display_output',
                        default=False)
    parser.add_argument('--project-id',
                        action='store',
                        dest='project_id',
                        default='deep-architect')
    parser.add_argument('--bucket',
                        '-b',
                        action='store',
                        dest='bucket',
                        default='normal')
    # NOTE(review): --resume is parsed but options.resume is never read in
    # this function.
    parser.add_argument('--resume',
                        '-r',
                        action='store_true',
                        dest='resume',
                        default=False)

    options = parser.parse_args()
    config = configs[options.config_name]

    # These two are plain locals (no `global` declaration), so they do not
    # change any module-level PROJECT_ID / BUCKET_NAME.
    PROJECT_ID = options.project_id
    BUCKET_NAME = options.bucket
    results_topic = publisher.topic_path(PROJECT_ID, 'results')
    arch_subscription = subscriber.subscription_path(PROJECT_ID,
                                                     'architectures-sub')

    # dataset name -> (path inside the bucket, number of classes)
    datasets = {
        'cifar10': ('/data/cifar10/', 10),
    }

    data_dir, num_classes = datasets[config['dataset']]
    search_space_factory = name_to_search_space_factory_fn[
        config['search_space']](num_classes)

    save_every = 1 if 'save_every' not in config else config['save_every']

    evaluators = {
        'tpu_classification':
        lambda: TPUEstimatorEvaluator(
            'gs://' + BUCKET_NAME + data_dir,
            max_num_training_epochs=config['eval_epochs'],
            log_output_to_terminal=options.display_output,
            base_dir='gs://' + BUCKET_NAME + '/scratch_dir'),
    }

    evaluator = evaluators[config['evaluator']]()

    search_data_folder = sl.get_search_data_folderpath(config['search_folder'],
                                                       config['search_name'])
    subscription = subscriber.subscribe(arch_subscription,
                                        callback=retrieve_message)
    thread = threading.Thread(target=nudge_master)
    thread.start()
    step = 0
    while True:
        # Poll every 5 seconds until an architecture has been specified.
        while not specified:
            time.sleep(5)
        if arch_data:
            vs, evaluation_id, searcher_eval_token = arch_data
            inputs, outputs = search_space_factory.get_search_space()
            se.specify(outputs, vs)
            print('Evaluating architecture')
            results = evaluator.eval(inputs, outputs)
            print('Evaluated architecture')
            step += 1
            # Only the evaluator state is saved here; the step counter is not.
            if step % save_every == 0:
                evaluator.save_state(search_data_folder)
            encoded_results = json.dumps((results, vs, evaluation_id,
                                          searcher_eval_token)).encode('utf-8')
            future = publisher.publish(results_topic, encoded_results)
            # Block until the publish has completed.
            future.result()
            evaluated = True
            specified = False
        else:
            # A falsy arch_data ends the worker loop.
            break
    thread.join()
    subscription.cancel()
예제 #21
0
def process_config_and_args():
    """Parse the worker's CLI arguments and build its dependencies.

    Sets the root logger level from ``--log``, loads the experiment config
    named by ``--config`` from ``--config-file``, builds the search-space
    factory for the configured dataset, appends ``_<repetition>`` to
    ``config['search_name']`` and connects a ``MongoCommunicator``.

    Returns:
        tuple: ``(comm, search_space_factory, evaluator_fn, config)``.
        ``evaluator_fn`` builds the configured evaluator and forwards any
        keyword arguments to its constructor.
    """
    parser = argparse.ArgumentParser("Worker Job for architecture search")
    parser.add_argument('--config',
                        '-c',
                        action='store',
                        dest='config_name',
                        default='search_evol')
    parser.add_argument(
        '--config-file',
        action='store',
        dest='config_file',
        default=
        '/deep_architect/examples/contrib/kubernetes/experiment_config.json')
    # Other arguments
    parser.add_argument('--display-output',
                        '-o',
                        action='store_true',
                        dest='display_output',
                        default=False)
    parser.add_argument('--resume',
                        '-r',
                        action='store_true',
                        dest='resume',
                        default=False)
    parser.add_argument('--bucket',
                        '-b',
                        action='store',
                        dest='bucket',
                        default=BUCKET_NAME)
    parser.add_argument('--mongo-host',
                        '-m',
                        action='store',
                        dest='mongo_host',
                        default='127.0.0.1')
    parser.add_argument('--tpu-name',
                        '-t',
                        action='store',
                        dest='tpu_name',
                        default='')
    # type=int: without it a port given on the command line stays a string.
    parser.add_argument('--mongo-port',
                        '-p',
                        action='store',
                        dest='mongo_port',
                        type=int,
                        default=27017)
    parser.add_argument('--repetition', default=0)
    parser.add_argument('--log',
                        choices=['debug', 'info', 'warning', 'error'],
                        default='info')

    options = parser.parse_args()

    numeric_level = getattr(logging, options.log.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError('Invalid log level: %s' % options.log)
    logging.getLogger().setLevel(numeric_level)

    configs = ut.read_jsonfile(options.config_file)
    config = configs[options.config_name]

    config['bucket'] = options.bucket

    # dataset name -> (path inside the bucket, number of classes)
    datasets = {
        'cifar10': ('/data/cifar10/', 10),
    }

    data_dir, num_classes = datasets[config['dataset']]
    search_space_factory = name_to_search_space_factory_fn[
        config['search_space']](num_classes)

    config['save_every'] = 1 if 'save_every' not in config else config[
        'save_every']

    config['search_name'] = config['search_name'] + '_' + str(
        options.repetition)
    evaluators = {
        'tpu_classification':
        lambda **kwargs: TPUEstimatorEvaluator(
            'gs://' + config['bucket'] + data_dir,
            options.tpu_name,
            max_num_training_epochs=config['eval_epochs'],
            log_output_to_terminal=options.display_output,
            base_dir='gs://' + config['bucket'] + '/' + config['search_folder']
            + '/' + config['search_name'] + '/scratch_dir',
            **kwargs),
    }

    comm = MongoCommunicator(options.mongo_host,
                             options.mongo_port,
                             refresh_period=10)
    evaluator_fn = evaluators[config['evaluator']]

    return comm, search_space_factory, evaluator_fn, config
예제 #22
0
def process_config_and_args():
    """Parse the searcher's CLI arguments and build its dependencies.

    Sets the root logger level from ``--log``, loads the experiment config
    named by ``--config`` from ``--config-file``, connects a
    ``MongoCommunicator``, builds the search space and searcher, and derives
    the search/evaluation paths for the requested ``--repetition``.
    ``config['num_epochs']`` and ``config['num_samples']`` default to ``-1``
    (no limit) when ``'epochs'`` / ``'samples'`` are absent. With
    ``--resume``, the search data folder is downloaded from the bucket, the
    searcher state is loaded, and the saved counters are restored. This is
    best effort: a failure is logged and the run starts from the initial
    state.

    Returns:
        tuple: ``(comm, search_logger, searcher, state, config)``.
    """
    parser = argparse.ArgumentParser("MPI Job for architecture search")
    parser.add_argument('--config',
                        '-c',
                        action='store',
                        dest='config_name',
                        default='normal')
    parser.add_argument(
        '--config-file',
        action='store',
        dest='config_file',
        default=
        '/deep_architect/examples/contrib/kubernetes/experiment_config.json')
    parser.add_argument('--bucket',
                        '-b',
                        action='store',
                        dest='bucket',
                        default=BUCKET_NAME)

    # Other arguments
    parser.add_argument('--resume',
                        '-r',
                        action='store_true',
                        dest='resume',
                        default=False)
    parser.add_argument('--mongo-host',
                        '-m',
                        action='store',
                        dest='mongo_host',
                        default='127.0.0.1')
    # type=int: without it a port given on the command line stays a string.
    parser.add_argument('--mongo-port',
                        '-p',
                        action='store',
                        dest='mongo_port',
                        type=int,
                        default=27017)
    parser.add_argument('--log',
                        choices=['debug', 'info', 'warning', 'error'],
                        default='info')
    parser.add_argument('--repetition', default=0)
    options = parser.parse_args()

    numeric_level = getattr(logging, options.log.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError('Invalid log level: %s' % options.log)
    logging.getLogger().setLevel(numeric_level)

    configs = ut.read_jsonfile(options.config_file)
    config = configs[options.config_name]

    config['bucket'] = options.bucket

    comm = MongoCommunicator(host=options.mongo_host,
                             port=options.mongo_port,
                             data_refresher=True,
                             refresh_period=10)

    datasets = {
        'cifar10': ('data/cifar10/', 10),
    }

    _, num_classes = datasets[config['dataset']]
    search_space_factory = name_to_search_space_factory_fn[
        config['search_space']](num_classes)

    config['save_every'] = 1 if 'save_every' not in config else config[
        'save_every']
    searcher = name_to_searcher_fn[config['searcher']](
        search_space_factory.get_search_space)
    config['num_epochs'] = -1 if 'epochs' not in config else config['epochs']
    config[
        'num_samples'] = -1 if 'samples' not in config else config['samples']

    # SET UP GOOGLE STORE FOLDER
    config['search_name'] = config['search_name'] + '_' + str(
        options.repetition)
    search_logger = sl.SearchLogger(config['search_folder'],
                                    config['search_name'])
    search_data_folder = search_logger.get_search_data_folderpath()
    config['save_filepath'] = ut.join_paths(
        (search_data_folder, config['searcher_file_name']))
    config['eval_path'] = sl.get_all_evaluations_folderpath(
        config['search_folder'], config['search_name'])
    config['full_search_folder'] = sl.get_search_folderpath(
        config['search_folder'], config['search_name'])
    config['eval_hparams'] = {} if 'eval_hparams' not in config else config[
        'eval_hparams']

    state = {
        'epochs': 0,
        'models_sampled': 0,
        'finished': 0,
        'best_accuracy': 0.0
    }
    if options.resume:
        try:
            download_folder(search_data_folder, config['full_search_folder'],
                            config['bucket'])
            searcher.load_state(search_data_folder)
            if ut.file_exists(config['save_filepath']):
                old_state = ut.read_jsonfile(config['save_filepath'])
                state['epochs'] = old_state['epochs']
                state['models_sampled'] = old_state['models_sampled']
                state['finished'] = old_state['finished']
                state['best_accuracy'] = old_state['best_accuracy']
        except Exception:
            # Resuming stays best effort, but a bare `except: pass` hid every
            # failure (including programming errors); log it instead.
            logging.getLogger(__name__).warning(
                'Could not restore saved state; starting from scratch.',
                exc_info=True)

    return comm, search_logger, searcher, state, config