def test_experiment_lifetime(self):
    """Run ``stop_experiment.py`` with an already-negative ``--lifetime``.

    Launches ``studio run`` with ``--lifetime=-10m`` against the HTTP-client
    config, waits for the runner to finish, logs its combined stdout/stderr
    at debug level, and deletes the experiment record afterwards.
    """
    here = os.path.dirname(os.path.realpath(__file__))
    log = logs.getLogger('test_experiment_lifetime')
    log.setLevel(10)
    config_path = os.path.join(here, 'test_config_http_client.yaml')
    experiment_key = 'test_experiment_lifetime' + str(uuid.uuid4())

    with model.get_db_provider(model.get_config(config_path)) as db:
        # Best-effort removal of any stale record that shares this key.
        try:
            db.delete_experiment(experiment_key)
        except Exception:
            pass

        command = [
            'studio', 'run',
            '--config=' + config_path,
            '--experiment=' + experiment_key,
            '--force-git',
            '--verbose=debug',
            '--lifetime=-10m',
            'stop_experiment.py',
        ]
        with get_local_queue_lock():
            runner = subprocess.Popen(
                command,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                cwd=here)
            output = runner.communicate()[0]

        if output:
            log.debug("studio run output: \n" + output.decode())

        db.delete_experiment(experiment_key)
def _test_serving(self, data_in, expected_data_out, wrapper=None):
    """Serve via ``studio run studio::serve_main`` and check one request.

    Starts the server on a random port in [5000, 9000] (stored on
    ``self.port``), waits 60 s, POSTs ``data_in`` as JSON and asserts the
    decoded JSON response equals ``expected_data_out``.

    Cleanup always runs: the serving experiment is stopped and deleted, and
    the ``studio run`` process is reaped. If it has not exited within a grace
    period, it is killed so it cannot outlive the test.

    :param data_in: JSON-serialisable request payload.
    :param expected_data_out: expected decoded JSON response.
    :param wrapper: optional value forwarded as ``--wrapper=<wrapper>``.
    """
    self.port = randint(5000, 9000)
    server_experimentid = 'test_serving_' + str(uuid.uuid4())

    with get_local_queue_lock():
        args = [
            'studio', 'run',
            '--force-git',
            '--verbose=debug',
            '--experiment=' + server_experimentid,
            '--config=' + self.get_config_path(),
            'studio::serve_main',
            '--port=' + str(self.port),
            '--host=localhost',
        ]
        if wrapper:
            args.append('--wrapper=' + wrapper)

        # Keep the handle so the runner can be reaped in the finally below.
        server = subprocess.Popen(args, cwd=os.path.dirname(__file__))
        time.sleep(60)

    try:
        retval = requests.post(
            url='http://localhost:' + str(self.port), json=data_in)
        data_out = retval.json()
        assert data_out == expected_data_out
    finally:
        try:
            with model.get_db_provider(model.get_config(
                    self.get_config_path())) as db:
                db.stop_experiment(server_experimentid)
                time.sleep(20)
                db.delete_experiment(server_experimentid)
        finally:
            # Give the runner a grace period to exit on its own after the
            # stop request, then kill it so no server process is left behind.
            deadline = time.time() + 60
            while server.poll() is None and time.time() < deadline:
                time.sleep(1)
            if server.poll() is None:
                server.kill()
                server.wait()
def test_args_conflict(self):
    """Pass ``--experiment aaa`` as *script* arguments to the worker.

    Runs ``conflicting_args.py`` with ``test_config.yaml`` and expects
    ``'Experiment key = aaa'`` as output, as checked by ``stubtest_worker``.
    """
    worker_kwargs = dict(
        experiment_name='test_runner_conflict_' + str(uuid.uuid4()),
        runner_args=['--verbose=debug'],
        config_name='test_config.yaml',
        test_script='conflicting_args.py',
        script_args=['--experiment', 'aaa'],
        expected_output='Experiment key = aaa',
    )
    with get_local_queue_lock(), stubtest_worker(self, **worker_kwargs):
        pass
def test_runner_local(self):
    """Run ``tf_hello_world.py`` through the local worker.

    Uses the HTTP-client config with one script argument and expects
    ``'[ 2. 6.]'`` as output, as checked by ``stubtest_worker``.
    """
    worker_kwargs = dict(
        experiment_name='test_runner_local_' + str(uuid.uuid4()),
        runner_args=['--verbose=debug'],
        config_name='test_config_http_client.yaml',
        test_script='tf_hello_world.py',
        script_args=['arg0'],
        expected_output='[ 2. 6.]',
    )
    with get_local_queue_lock(), stubtest_worker(self, **worker_kwargs):
        pass
def test_local_worker_co_s3(self):
    """Capture an S3 object once (``--capture-once``) as artifact ``f``.

    ``art_hello_world.py`` is expected to print the known contents of the
    S3 test object.
    """
    source = 's3://studioml-artifacts/tests/download_test/download_test.txt'
    expected = 'No4 ulica fonar apteka, bessmyslennyj i tusklyj svet'
    worker_kwargs = dict(
        experiment_name='test_local_worker_co_s3' + str(uuid.uuid4()),
        runner_args=['--capture-once=' + source + ':f'],
        config_name='test_config_http_client.yaml',
        test_script='art_hello_world.py',
        script_args=[],
        expected_output=expected,
    )
    with get_local_queue_lock(), stubtest_worker(self, **worker_kwargs):
        pass
def test_local_worker_co_url(self):
    """Capture an HTTPS URL once (``--capture-once``) as artifact ``f``.

    ``art_hello_world.py`` is expected to print the known contents of the
    hosted test file.
    """
    source = ('https://storage.googleapis.com/studio-ed756.appspot.com/'
              'tests/url_artifact.txt')
    expected = 'Zabil zaryad ya v pushku tugo'
    worker_kwargs = dict(
        experiment_name='test_local_worker_co_url' + str(uuid.uuid4()),
        runner_args=['--capture-once=' + source + ':f'],
        config_name='test_config_http_client.yaml',
        test_script='art_hello_world.py',
        script_args=[],
        expected_output=expected,
    )
    with get_local_queue_lock(), stubtest_worker(self, **worker_kwargs):
        pass
def test_local_worker_co(self):
    """Capture a local file once (``--capture-once``) as artifact ``f``.

    Writes a random string to a fresh temp file, runs ``art_hello_world.py``
    with ``--capture-once=<tmpfile>:f`` and expects that string as output.
    The temp file is removed afterwards, even if the worker check fails.
    """
    tmpfile = os.path.join(tempfile.gettempdir(),
                           'tmpfile' + str(uuid.uuid4()) + '.txt')
    random_str = str(uuid.uuid4())
    with open(tmpfile, 'w') as f:
        f.write(random_str)

    try:
        with get_local_queue_lock():
            with stubtest_worker(
                    self,
                    experiment_name='test_local_worker_co' + str(uuid.uuid4()),
                    runner_args=['--capture-once=' + tmpfile + ':f'],
                    config_name='test_config_http_client.yaml',
                    test_script='art_hello_world.py',
                    script_args=[],
                    expected_output=random_str):
                pass
    finally:
        # The file only needs to exist while the worker captures it; don't
        # leave one stray file in the system temp dir per test run.
        os.remove(tmpfile)
def test_local_worker_ce(self):
    """Capture a mutable artifact (``--capture``) and reuse it (``--reuse``).

    1. Run ``art_hello_world.py`` with ``--capture=<tmpfile>:f``. The script
       prints the initial file contents and is given ``random_str2`` as its
       argument. The test then downloads artifact ``f`` and expects it to
       contain ``random_str2``.
    2. Run a second experiment with ``--reuse=<first>/f:f`` and expect it to
       print ``random_str2``.

    The first experiment is created with ``delete_when_done=False``, so it is
    deleted explicitly. Every temp file and that experiment are cleaned up
    even when an assertion fails.
    """
    tmpfile = os.path.join(tempfile.gettempdir(),
                           'tmpfile' + str(uuid.uuid4()) + '.txt')
    random_str1 = str(uuid.uuid4())
    with open(tmpfile, 'w') as f:
        f.write(random_str1)

    random_str2 = str(uuid.uuid4())
    experiment_name = 'test_local_worker_c' + str(uuid.uuid4())

    try:
        with get_local_queue_lock():
            with stubtest_worker(
                    self,
                    experiment_name=experiment_name,
                    runner_args=['--capture=' + tmpfile + ':f',
                                 '--verbose=debug'],
                    config_name='test_config_http_client.yaml',
                    test_script='art_hello_world.py',
                    script_args=[random_str2],
                    expected_output=random_str1,
                    delete_when_done=False) as db:
                try:
                    tmppath = os.path.join(tempfile.gettempdir(),
                                           str(uuid.uuid4()))
                    try:
                        db.get_artifact(
                            db.get_experiment(experiment_name).artifacts['f'],
                            tmppath)
                        with open(tmppath, 'r') as f:
                            self.assertEqual(f.read(), random_str2)
                    finally:
                        # get_artifact may have failed before creating it.
                        if os.path.exists(tmppath):
                            os.remove(tmppath)

                    with stubtest_worker(
                            self,
                            experiment_name='test_local_worker_e' +
                            str(uuid.uuid4()),
                            runner_args=[
                                '--reuse={}/f:f'.format(experiment_name)],
                            config_name='test_config_http_client.yaml',
                            test_script='art_hello_world.py',
                            script_args=[],
                            expected_output=random_str2):
                        pass
                finally:
                    # delete_when_done=False: the worker leaves this one for
                    # us, and it must outlive the --reuse run above.
                    db.delete_experiment(experiment_name)
    finally:
        os.remove(tmpfile)
def test_stop_experiment(self):
    """Stop a running experiment through the DB and let the runner exit.

    Launches ``stop_experiment.py`` under ``studio run``. The DB is polled
    once a second until the experiment exists and is no longer
    ``'waiting'``. The test then calls ``db.stop_experiment``, collects the
    runner's combined stdout/stderr (logged at debug level) and deletes the
    experiment record.

    If the ``studio run`` process exits while the experiment is still
    missing or ``'waiting'``, the test fails immediately with the runner
    output instead of polling forever. The runner is killed if anything
    raises before it has exited.
    """
    my_path = os.path.dirname(os.path.realpath(__file__))
    logger = logs.getLogger('test_stop_experiment')
    logger.setLevel(10)
    config_name = os.path.join(my_path, 'test_config_http_client.yaml')
    key = 'test_stop_experiment' + str(uuid.uuid4())

    with model.get_db_provider(model.get_config(config_name)) as db:
        # Best-effort removal of a stale record with the same key.
        try:
            db.delete_experiment(key)
        except Exception:
            pass

        with get_local_queue_lock():
            # NOTE(review): stdout is a pipe that is not drained while
            # polling below; a very chatty runner could block on a full pipe.
            p = subprocess.Popen([
                'studio', 'run',
                '--config=' + config_name,
                '--experiment=' + key,
                '--force-git',
                '--verbose=debug',
                'stop_experiment.py'
            ], stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=my_path)

            try:
                # wait till experiment spins up
                experiment = None
                while experiment is None or experiment.status == 'waiting':
                    # Sample liveness *before* the DB read, so one more read
                    # always happens after the runner is seen to have exited.
                    runner_exited = p.poll() is not None
                    time.sleep(1)
                    try:
                        experiment = db.get_experiment(key)
                    except Exception:
                        # Not registered yet: retry. Catch Exception only so
                        # KeyboardInterrupt / SystemExit still propagate.
                        pass
                    if runner_exited and (experiment is None or
                                          experiment.status == 'waiting'):
                        pout, _ = p.communicate()
                        self.fail(
                            'studio run exited before experiment {} '
                            'started:\n{}'.format(
                                key, pout.decode() if pout else ''))

                logger.info('Stopping experiment')
                db.stop_experiment(key)
                pout, _ = p.communicate()
            finally:
                # Never leave the runner behind (e.g. if stop_experiment
                # raised).
                if p.poll() is None:
                    p.kill()
                    p.communicate()

        if pout:
            logger.debug("studio run output: \n" + pout.decode())

        db.delete_experiment(key)
def test_two_experiments_apiserver(self):
    """Run two experiments against the datacenter config with many artifacts.

    The artifact map contains every regular file in this test directory
    (keyed by file name), plus ``_file_url`` under ``'url'``. When AWS
    credentials are available it also contains ``_file_s3`` under ``'s3'``.
    """
    here = os.path.dirname(os.path.realpath(__file__))
    config_path = os.path.join(
        here, '..', 'tests', 'test_config_datacenter.yaml')

    files = {}
    for entry in os.listdir(here):
        full_path = os.path.join(here, entry)
        if os.path.isfile(full_path):
            files[entry] = full_path

    files['url'] = _file_url
    if has_aws_credentials():
        files['s3'] = _file_s3

    with get_local_queue_lock():
        self._run_test_files(n_experiments=2, files=files, config=config_path)
def test_local_hyperparam(self):
    """Run ``hyperparam_hello_world.py`` with and without ``--hyperparam``.

    Without an override the expected output is ``'0.3'`` (presumably the
    script's default learning rate). With
    ``--hyperparam=learning_rate=0.4`` it is ``'0.4'``. Both runs are made
    sequentially under a single hold of the local queue lock.
    """
    cases = (
        (['--verbose=debug'], '0.3'),
        (['--verbose=debug', '--hyperparam=learning_rate=0.4'], '0.4'),
    )
    with get_local_queue_lock():
        for runner_args, expected in cases:
            worker = stubtest_worker(
                self,
                experiment_name='test_local_hyperparam' + str(uuid.uuid4()),
                runner_args=runner_args,
                config_name='test_config_http_client.yaml',
                test_script='hyperparam_hello_world.py',
                expected_output=expected)
            with worker:
                pass
def test_save_get_model(self):
    """Save a model from a job, then load it back and check a prediction.

    ``save_model.py`` runs with output checking disabled. The saved model is
    retrieved via ``experiment.get_model(db)`` and must map a random 1x2
    input ``v`` to ``v * 2`` (within ``np.allclose`` tolerance).

    The worker is created with ``delete_when_done=False``, so the experiment
    is deleted explicitly. That deletion now runs even if loading or the
    prediction check fails.
    """
    experiment_name = 'test_save_get_model' + str(uuid.uuid4())
    with get_local_queue_lock():
        with stubtest_worker(self,
                             experiment_name=experiment_name,
                             runner_args=[],
                             config_name='test_config_http_client.yaml',
                             test_script='save_model.py',
                             script_args=[],
                             expected_output='',
                             delete_when_done=False,
                             test_output=False) as db:
            try:
                experiment = db.get_experiment(experiment_name)
                saved_model = experiment.get_model(db)

                v = np.random.rand(1, 2)
                prediction = saved_model.predict(v)
                expected = v * 2
                self.assertTrue(np.allclose(prediction, expected))
            finally:
                # Delete by key: `experiment` may be unbound if the lookup
                # above failed.
                db.delete_experiment(experiment_name)
def get_lock(self):
    """Return the object produced by ``get_local_queue_lock()``."""
    lock = get_local_queue_lock()
    return lock