Example #1
0
def get_queue(
        queue_name=None,
        cloud=None,
        config=None,
        logger=None,
        close_after=None,
        verbose=10):
    """Return a queue object chosen by the prefix of ``queue_name``.

    When ``queue_name`` is None a unique name is generated from ``cloud``:
    ``'pubsub_<uuid4>'`` for ``'gcloud'``/``'gcspot'``, ``'sqs_<uuid4>'``
    for ``'ec2'``/``'ec2spot'``, and ``'local_<uuid4>'`` for anything else.

    Prefix dispatch:
      * ``'ec2'`` or ``'sqs'`` -> ``SQSQueue(queue_name, config, logger)``
      * ``'rmq_'``             -> ``get_cached_queue(...)`` with route
        ``'StudioML.' + queue_name``
      * ``'local'``            -> ``LocalQueue(queue_name, logger)``

    :param queue_name: explicit queue name, or None to generate one.
    :param cloud: cloud identifier, consulted only when generating a name.
    :param config: forwarded to the SQS and RabbitMQ queue constructors.
    :param logger: forwarded to every queue constructor.
    :param close_after: forwarded to ``get_cached_queue`` only.
    :param verbose: accepted for interface compatibility; not used.
    :returns: the queue object, or None when no prefix matches.
    """
    # Not used; presumably assigned to silence unused-argument warnings.
    _ = verbose

    if queue_name is None:
        if cloud in ('gcloud', 'gcspot'):
            prefix = 'pubsub_'
        elif cloud in ('ec2', 'ec2spot'):
            prefix = 'sqs_'
        else:
            prefix = 'local_'
        queue_name = prefix + str(uuid.uuid4())

    if queue_name.startswith(('ec2', 'sqs')):
        return SQSQueue(queue_name, config=config, logger=logger)
    if queue_name.startswith('rmq_'):
        return get_cached_queue(
            name=queue_name,
            route='StudioML.' + queue_name,
            config=config,
            close_after=close_after,
            logger=logger)
    if queue_name.startswith('local'):
        return LocalQueue(queue_name, logger=logger)
    # NOTE(review): names generated above for 'gcloud'/'gcspot' start with
    # 'pubsub_', which no branch handles, so those callers receive None.
    return None
Example #2
0
def main(args=None):
    """Run the local Studio worker and exit with its return code.

    Parses the worker options from ``args`` and installs a SIGUSR1 handler
    that drops into ``pdb`` (only where the platform defines SIGUSR1). It
    then runs ``worker_loop`` on ``LocalQueue('local')`` and passes the
    loop's return code to ``sys.exit``.

    :param args: argument vector including the program name, as in
        ``sys.argv``. None (the default) means ``sys.argv`` as it is at
        call time. The former default ``args=sys.argv`` was captured once
        at import time, so later reassignment of ``sys.argv`` was ignored.
    """
    if args is None:
        args = sys.argv

    parser = argparse.ArgumentParser(description='Studio worker. \
                     Usage: studio-local-worker \
                     ')

    parser.add_argument('--config', help='configuration file', default=None)
    parser.add_argument('--guest',
                        help='Guest mode (does not require db credentials)',
                        action='store_true')
    parser.add_argument('--timeout', default=0, type=int)
    parser.add_argument('--verbose', default='error')

    # SIGUSR1 invokes the built-in Python debugger. The signal exists only on
    # Unix, so skip the handler elsewhere instead of raising AttributeError.
    if hasattr(signal, 'SIGUSR1'):
        signal.signal(signal.SIGUSR1, lambda sig, stack: pdb.set_trace())

    # Unrecognised arguments (including args[0], the program name) come back
    # as the second element and are not used here.
    parsed_args, _ = parser.parse_known_args(args)
    # The result is not used below. The call is kept so --verbose still goes
    # through parse_verbosity exactly as before.
    parse_verbosity(parsed_args.verbose)

    queue = LocalQueue('local')
    returncode = worker_loop(queue, parsed_args, timeout=parsed_args.timeout)
    sys.exit(returncode)
Example #3
0
 def get_queue(self):
     """Create the queue for this test: a ``LocalQueue`` named 'test'."""
     test_queue = LocalQueue('test')
     return test_queue
 def get_queue(self):
     """Create a ``LocalQueue`` using its default constructor arguments."""
     default_queue = LocalQueue()
     return default_queue
Example #5
0
def stubtest_worker(
        testclass,
        experiment_name,
        runner_args,
        config_name,
        test_script,
        expected_output,
        script_args=None,
        queue=None,
        wait_for_experiment=True,
        delete_when_done=True,
        test_output=True,
        test_workspace=True):
    """Submit ``test_script`` through ``studio run`` and check the result.

    Steps:
      1. Clean ``queue``, then delete any existing experiment named
         ``experiment_name`` from the database configured by ``config_name``.
      2. Run ``studio run`` (cwd = this file's directory) with
         ``runner_args``, the config, ``EXPERIMENT_VERBOSE_LEVEL``,
         ``--force-git``, ``--experiment=experiment_name``, ``test_script``
         and ``script_args``.
      3. Take the key of the first experiment the command reports as
         submitted.
      4. Optionally poll until that experiment's status is ``'finished'``,
         assert the last line of its ``output`` artifact equals
         ``expected_output``, run ``check_workspace`` on it, and delete it.

    :param testclass: object providing ``assertEqual`` (presumably a
        ``unittest.TestCase``).
    :param config_name: config file path, joined onto this file's directory.
    :param script_args: extra arguments appended after ``test_script``;
        None means none.
    :param queue: queue cleaned before submission; None means a fresh
        ``LocalQueue('test')`` for this call.
    :returns: the database provider used for the checks.
    :raises RuntimeError: if ``studio run`` reports no submitted experiment.
    """
    import time

    if script_args is None:
        script_args = []
    if queue is None:
        # Built per call. A LocalQueue('test') default in the signature would
        # be constructed at import time and shared by every call.
        queue = LocalQueue('test')

    my_path = os.path.dirname(os.path.realpath(__file__))
    config_name = os.path.join(my_path, config_name)
    logger = logs.get_logger('stubtest_worker')
    logger.setLevel(10)

    queue.clean()

    with model.get_db_provider(model.get_config(config_name)) as db:
        try:
            db.delete_experiment(experiment_name)
        except Exception:
            # Best effort: the experiment may not exist yet.
            pass

    os.environ['PYTHONUNBUFFERED'] = 'True'
    p = subprocess.Popen(['studio', 'run'] + runner_args +
                         ['--config=' + config_name,
                          '--verbose='+EXPERIMENT_VERBOSE_LEVEL,
                          '--force-git',
                          '--experiment=' + experiment_name,
                          test_script] + list(script_args),
                         stdout=subprocess.PIPE,
                         stderr=subprocess.STDOUT,
                         close_fds=True,
                         cwd=my_path)

    pout, _ = p.communicate()

    output = ''
    experiments = []
    if pout:
        output = sixdecode(pout)
        logger.debug("studio run output: \n" + output)
        experiments = [line.split(' ')[-1] for line in output.split('\n')
                       if 'studio run: submitted experiment' in line]
        logger.debug("added experiments: {}".format(experiments))

    if not experiments:
        # Previously empty output raised NameError and output without a
        # submission line raised IndexError on experiments[0].
        raise RuntimeError(
            "studio run did not report a submitted experiment; output: \n"
            + output)

    db = model.get_db_provider(model.get_config(config_name))
    experiment_name = experiments[0]

    try:
        experiment = db.get_experiment(experiment_name)
        if wait_for_experiment:
            while not experiment or experiment.status != 'finished':
                # Pause between polls instead of querying the db in a
                # tight loop.
                time.sleep(1)
                experiment = db.get_experiment(experiment_name)

        if test_output:
            with open(db.get_artifact(experiment.artifacts['output']),
                      'r') as f:
                data = f.read()
                split_data = data.strip().split('\n')
                print(data)
                testclass.assertEqual(split_data[-1], expected_output)

        if test_workspace:
            check_workspace(testclass, db, experiment_name)

        if delete_when_done:
            retry(lambda: db.delete_experiment(experiment_name), sleep_time=10)

        return db

    except Exception as e:
        print("Exception {} raised during test".format(e))
        print("worker output: \n {}".format(output))
        print("Exception trace:")
        print(traceback.format_exc())
        # Bare raise re-raises with the original traceback intact.
        raise