コード例 #1
0
class RemoteWorkerTest(unittest.TestCase):
    """Integration tests for experiments executed by a remote worker.

    Each PubSub test starts ``studio-start-remote-worker`` as a subprocess
    bound to a freshly named queue, submits an experiment to that queue
    through ``stubtest_worker`` and then collects the worker's output.
    The PubSub tests are skipped unless ``on_gcp()`` is true and
    ``GOOGLE_APPLICATION_CREDENTIALS`` is set in the environment.
    """
    _multiprocess_shared_ = True

    # Skip condition and reason shared by every PubSub test below; they are
    # evaluated once while the class body executes.
    _NO_PUBSUB = (not on_gcp() or
                  'GOOGLE_APPLICATION_CREDENTIALS' not in os.environ)
    _NO_PUBSUB_REASON = (
        'Skipping due to userinput or GCP Not detected, or '
        'GOOGLE_APPLICATION_CREDENTIALS variable not set; '
        "won't be able to use google PubSub")

    # Docker image the remote worker runs in unless a test overrides it.
    _BASE_IMAGE = 'peterzhokhoff/studioml'

    # NOTE: unittest's default module loader only collects module-level
    # TestCase classes, so this nested class is not discovered on its own.
    @unittest.skipIf(not on_gcp(), 'User indicated not on gcp')
    class UserIndicatedOnGCPTest(unittest.TestCase):
        def test_on_enviornment(self):
            """When on GCP, PubSub credentials must be configured."""
            self.assertIn('GOOGLE_APPLICATION_CREDENTIALS', os.environ)

    def _start_remote_worker(self, queue_name, worker_args, image=None,
                             capture_output=True):
        """Start ``studio-start-remote-worker`` listening on *queue_name*.

        Args:
            queue_name: PubSub queue the worker consumes from.
            worker_args: extra CLI flags passed to the worker.
            image: docker image for the worker (defaults to the base image).
            capture_output: merge stdout/stderr into a pipe that the caller
                drains with ``communicate()``; otherwise output goes to the
                console.

        Returns:
            The ``subprocess.Popen`` handle. A cleanup is registered that
            kills the worker if the test finishes before the worker exits.
        """
        cmd = ['studio-start-remote-worker', '--queue=' + queue_name]
        cmd.extend(worker_args)
        cmd.append('--image=' + (image or self._BASE_IMAGE))
        if capture_output:
            proc = subprocess.Popen(cmd,
                                    stdout=subprocess.PIPE,
                                    stderr=subprocess.STDOUT)
        else:
            proc = subprocess.Popen(cmd)
        self.addCleanup(self._stop_process, proc)
        return proc

    @staticmethod
    def _stop_process(proc):
        """Kill and reap *proc* if it is still running."""
        if proc.poll() is None:
            proc.kill()
            proc.wait()
            if proc.stdout is not None:
                proc.stdout.close()

    @staticmethod
    def _remove_if_exists(path):
        """Delete *path* if it exists (temp-file cleanup)."""
        if os.path.exists(path):
            os.remove(path)

    @staticmethod
    def _log_output(log, title, output):
        """Send captured subprocess *output* (bytes or None) to *log*."""
        if output:
            log(title + ' output: \n' + output.decode('utf-8', 'replace'))

    def _remove_docker_image(self, image, logger):
        """Remove the docker *image* built by ``test_baked_image``."""
        rmip = subprocess.Popen(['docker', 'rmi', image],
                                stdout=subprocess.PIPE,
                                stderr=subprocess.STDOUT)
        rmiout, _ = rmip.communicate()
        self._log_output(logger.info, 'docker rmi', rmiout)

    @timeout(590)
    @unittest.skipIf(_NO_PUBSUB, _NO_PUBSUB_REASON)
    def test_remote_worker(self):
        """Run tf_hello_world.py on a remote worker fed through PubSub."""
        experiment_name = 'test_remote_worker_' + str(uuid.uuid4())
        queue_name = experiment_name
        logger = logs.getLogger('test_remote_worker')
        logger.setLevel(10)

        pw = self._start_remote_worker(
            queue_name, ['--single-run', '--no-cache', '--timeout=30'])

        stubtest_worker(self,
                        experiment_name=experiment_name,
                        runner_args=['--queue=' + queue_name, '--force-git'],
                        config_name='test_config_http_client.yaml',
                        test_script='tf_hello_world.py',
                        script_args=['arg0'],
                        expected_output='[ 2.0 6.0 ]',
                        queue=PubsubQueue(queue_name))

        workerout, _ = pw.communicate()
        self._log_output(logger.debug, 'studio-start-remote-worker',
                         workerout)

    @timeout(590)
    @unittest.skipIf(_NO_PUBSUB, _NO_PUBSUB_REASON)
    def test_remote_worker_c(self):
        """Capture a local file as artifact ``f`` and read back its update."""
        logger = logs.getLogger('test_remote_worker_c')
        logger.setLevel(10)

        tmpfile = os.path.join(tempfile.gettempdir(), str(uuid.uuid4()))
        random_str1 = str(uuid.uuid4())
        with open(tmpfile, 'w') as f:
            f.write(random_str1)
        # Removed even if the test fails part-way.
        self.addCleanup(self._remove_if_exists, tmpfile)

        random_str2 = str(uuid.uuid4())
        experiment_name = "test_remote_worker_c_" + str(uuid.uuid4())
        queue_name = experiment_name
        pw = self._start_remote_worker(queue_name,
                                       ['--single-run', '--no-cache'])

        db = stubtest_worker(self,
                             experiment_name=experiment_name,
                             runner_args=[
                                 '--capture=' + tmpfile + ':f',
                                 '--queue=' + queue_name, '--force-git'
                             ],
                             config_name='test_config_http_client.yaml',
                             test_script='art_hello_world.py',
                             script_args=[random_str2],
                             expected_output=random_str1,
                             queue=PubsubQueue(queue_name),
                             delete_when_done=False)

        workerout, _ = pw.communicate()
        self._log_output(logger.debug, 'studio-start-remote-worker',
                         workerout)

        # Download the captured artifact to a fresh local path.
        tmppath = os.path.join(tempfile.gettempdir(), str(uuid.uuid4()))
        self.addCleanup(self._remove_if_exists, tmppath)
        db.get_artifact(db.get_experiment(experiment_name).artifacts['f'],
                        tmppath,
                        only_newer=False)

        with open(tmppath, 'r') as f:
            self.assertEqual(f.read(), random_str2)
        db.delete_experiment(experiment_name)

    @timeout(590)
    @unittest.skipIf(_NO_PUBSUB, _NO_PUBSUB_REASON)
    def test_remote_worker_co(self):
        """Capture a local file once (``--capture-once``) on a remote worker."""
        logger = logs.getLogger('test_remote_worker_co')
        logger.setLevel(10)

        tmpfile = os.path.join(tempfile.gettempdir(), str(uuid.uuid4()))
        random_str = str(uuid.uuid4())
        with open(tmpfile, 'w') as f:
            f.write(random_str)
        self.addCleanup(self._remove_if_exists, tmpfile)

        experiment_name = 'test_remote_worker_co_' + str(uuid.uuid4())
        queue_name = experiment_name
        pw = self._start_remote_worker(queue_name,
                                       ['--single-run', '--no-cache'])

        stubtest_worker(self,
                        experiment_name=experiment_name,
                        runner_args=[
                            '--capture-once=' + tmpfile + ':f',
                            '--queue=' + queue_name, '--force-git'
                        ],
                        config_name='test_config_http_client.yaml',
                        test_script='art_hello_world.py',
                        script_args=[],
                        expected_output=random_str,
                        queue=PubsubQueue(queue_name))

        workerout, _ = pw.communicate()
        self._log_output(logger.debug, 'studio-start-remote-worker',
                         workerout)

    @timeout(590)
    @unittest.skipIf(_NO_PUBSUB, _NO_PUBSUB_REASON)
    def test_baked_image(self):
        """Bake credentials into a docker image and run a worker on it."""
        logger = logs.getLogger('test_baked_image')
        logger.setLevel(logs.DEBUG)

        # Check that docker is usable. A missing binary makes Popen raise
        # OSError rather than return a non-zero exit code.
        try:
            dockertestp = subprocess.Popen(['docker'],
                                           stdout=subprocess.PIPE,
                                           stderr=subprocess.STDOUT)
        except OSError:
            logger.error("docker is not installed (correctly)")
            self.skipTest("docker is not installed")
        dockertestout, _ = dockertestp.communicate()
        self._log_output(logger.info, 'docker test', dockertestout)
        if dockertestp.returncode != 0:
            logger.error("docker is not installed (correctly)")
            self.skipTest("docker is not installed (correctly)")

        image = 'test_image' + str(uuid.uuid4())

        # studio-add-credentials is not piped; its output goes to the console.
        addcredsp = subprocess.Popen([
            'studio-add-credentials', '--tag=' + image,
            '--base-image=' + self._BASE_IMAGE
        ])
        if addcredsp.wait() != 0:
            logger.error("studio-add-credentials failed.")
            self.fail("studio-add-credentials failed.")
        # Cleanups run LIFO, so the worker (registered later) is stopped
        # before its image is removed.
        self.addCleanup(self._remove_docker_image, image, logger)

        experiment_name = 'test_remote_worker_baked' + str(uuid.uuid4())
        queue_name = experiment_name
        pw = self._start_remote_worker(
            queue_name, ['--no-cache', '--single-run', '--timeout=30'],
            image=image, capture_output=False)

        stubtest_worker(self,
                        experiment_name=experiment_name,
                        runner_args=['--queue=' + queue_name, '--force-git'],
                        config_name='test_config_http_client.yaml',
                        test_script='tf_hello_world.py',
                        script_args=['arg0'],
                        expected_output='[ 2.0 6.0 ]',
                        queue=PubsubQueue(queue_name))

        # The worker writes straight to the console; just wait for it.
        pw.wait()
コード例 #2
0
        except IOError:
            exception_raised = True

        self.assertTrue(exception_raised)

    def test_get_qualified_location(self):
        """A store qualifies a bare key by prepending its location prefix."""
        store = self.get_store()
        artifact_key = str(uuid.uuid4())

        actual = store.get_qualified_location(artifact_key)
        prefix = self.get_qualified_location_prefix()

        self.assertEqual(actual, prefix + artifact_key)


@unittest.skipIf(not on_gcp(), 'User indicated not on gcp')
class FirebaseArtifactStoreTest(ArtifactStoreTest, unittest.TestCase):
    # Tests of private methods

    def get_qualified_location_prefix(self):
        """Return the ``gs://`` prefix of the storage bucket used in tests."""
        bucket = "studio-ed756.appspot.com"
        return "gs://" + bucket + "/"

    def test_get_file_url(self):
        remove_all_keys()
        fb = self.get_store('test_config.yaml')

        tmp_filename = os.path.join(tempfile.gettempdir(),
                                    str(uuid.uuid4()) + '.txt')

        random_str = str(uuid.uuid4())
        with open(tmp_filename, 'wt') as f:
コード例 #3
0
ファイル: local_worker_test.py プロジェクト: zuma89/studio
class LocalWorkerTest(unittest.TestCase):
    """End-to-end tests that run experiments with a local studio worker.

    Most tests delegate to ``stubtest_worker`` (presumably comparing the
    script's output with ``expected_output``); the stop/max-duration/
    lifetime tests drive ``studio run`` directly through a subprocess.
    """

    @staticmethod
    def _remove_if_exists(path):
        """Delete *path* if it exists (temp-file cleanup)."""
        if os.path.exists(path):
            os.remove(path)

    @staticmethod
    def _stop_process(proc):
        """Kill and reap *proc* if it is still running."""
        if proc.poll() is None:
            proc.kill()
            proc.wait()
            if proc.stdout is not None:
                proc.stdout.close()

    @staticmethod
    def _start_studio_run(my_path, config_name, key, *extra_args):
        """Launch ``studio run stop_experiment.py`` as experiment *key*.

        *extra_args* are inserted before the script name. stdout and
        stderr are merged into one pipe drained by ``communicate()``.
        """
        cmd = ['studio', 'run',
               '--config=' + config_name,
               '--experiment=' + key,
               '--force-git',
               '--verbose=' + EXPERIMENT_VERBOSE_LEVEL]
        cmd.extend(extra_args)
        cmd.append('stop_experiment.py')
        return subprocess.Popen(cmd,
                                stdout=subprocess.PIPE,
                                stderr=subprocess.STDOUT,
                                cwd=my_path)

    @staticmethod
    def _delete_stale_experiment(db, key):
        """Best-effort removal of a leftover experiment named *key*."""
        try:
            db.delete_experiment(key)
        except Exception:
            # Usually the experiment simply does not exist yet.
            pass

    @staticmethod
    def _log_run_output(logger, pout):
        """Log the merged output of a ``studio run`` subprocess, if any."""
        if pout:
            logger.debug("studio run output: \n" +
                         pout.decode('utf-8', 'replace'))

    @timeout(TEST_TIMEOUT, use_signals=True)
    def test_runner_local(self):
        with stubtest_worker(
                self,
                experiment_name='test_runner_local_' + str(uuid.uuid4()),
                config_name='test_config.yaml',
                test_script='tf_hello_world.py',
                runner_args=[],
                script_args=['arg0'],
                expected_output='[ 2.0 6.0 ]'
        ):
            pass

    @timeout(TEST_TIMEOUT, use_signals=True)
    def test_args_conflict(self):
        with stubtest_worker(
            self,
            experiment_name='test_runner_conflict_' + str(uuid.uuid4()),
            config_name='test_config.yaml',
            runner_args=[],
            test_script='conflicting_args.py',
            script_args=['--experiment', 'aaa'],
            expected_output='Experiment key = aaa'
        ):
            pass

    @timeout(TEST_TIMEOUT, use_signals=True)
    def test_local_hyperparam(self):
        with stubtest_worker(
            self,
            experiment_name='test_local_hyperparam' + str(uuid.uuid4()),
            runner_args=['--verbose=' + EXPERIMENT_VERBOSE_LEVEL],
            config_name='test_config.yaml',
            test_script='hyperparam_hello_world.py',
            expected_output='0.3'
        ):
            pass

        with stubtest_worker(
            self,
            experiment_name='test_local_hyperparam' + str(uuid.uuid4()),
            runner_args=[
                '--verbose=' + EXPERIMENT_VERBOSE_LEVEL,
                '--hyperparam=learning_rate=0.4'
            ],
            config_name='test_config.yaml',
            test_script='hyperparam_hello_world.py',
            expected_output='0.4'
        ):
            pass

    @timeout(TEST_TIMEOUT, use_signals=True)
    def test_local_worker_ce(self):
        """Capture a file as artifact ``f``, then reuse it in a second run."""
        tmpfile = os.path.join(tempfile.gettempdir(),
                               'tmpfile_ce_' +
                               str(uuid.uuid4()) + '.txt')

        random_str1 = str(uuid.uuid4())

        with open(tmpfile, 'w') as f:
            f.write(random_str1)
        self.addCleanup(self._remove_if_exists, tmpfile)

        random_str2 = str(uuid.uuid4())
        experiment_name = 'test_local_worker_c' + str(uuid.uuid4())
        print("random_str1 = " + random_str1)
        print("random_str2 = " + random_str2)
        print("experiment_name = " + experiment_name)
        print("tmpfile = " + tmpfile)

        with stubtest_worker(
            self,
            experiment_name=experiment_name,
            runner_args=['--capture=' + tmpfile + ':f',
                         '--verbose=' + EXPERIMENT_VERBOSE_LEVEL],
            config_name='test_config.yaml',
            test_script='art_hello_world.py',
            script_args=[random_str2],
            expected_output=random_str1,
            delete_when_done=False
        ) as db:
            # Use the provider while its context is still open.
            tmppath = db.get_artifact(
                db.get_experiment(experiment_name).artifacts['f']
            )
            with open(tmppath, 'r') as f:
                self.assertEqual(f.read(), random_str2)
            os.remove(tmppath)

        with stubtest_worker(
            self,
            experiment_name='test_local_worker_e' + str(uuid.uuid4()),
            runner_args=['--reuse={}/f:f'.format(experiment_name)],
            config_name='test_config.yaml',
            test_script='art_hello_world.py',
            script_args=[],
            expected_output=random_str2
        ) as db:
            # The first experiment was kept (delete_when_done=False) so its
            # artifact could be reused; remove it now.
            db.delete_experiment(experiment_name)

    @timeout(TEST_TIMEOUT, use_signals=True)
    def test_local_worker_co(self):
        """Capture a local file once (``--capture-once``)."""
        tmpfile = os.path.join(tempfile.gettempdir(),
                               'tmpfile' +
                               str(uuid.uuid4()) + '.txt')

        random_str = str(uuid.uuid4())
        with open(tmpfile, 'w') as f:
            f.write(random_str)
        self.addCleanup(self._remove_if_exists, tmpfile)

        with stubtest_worker(
            self,
            experiment_name='test_local_worker_co' + str(uuid.uuid4()),
            runner_args=['--capture-once=' + tmpfile + ':f'],
            config_name='test_config.yaml',
            test_script='art_hello_world.py',
            script_args=[],
            expected_output=random_str
        ):
            pass

    @unittest.skipIf(not on_gcp(), "NOT using GCP")
    @timeout(TEST_TIMEOUT, use_signals=True)
    def test_local_worker_co_url(self):
        expected_str = 'Zabil zaryad ya v pushku tugo'
        url = 'https://storage.googleapis.com/studio-ed756.appspot.com/' + \
              'tests/url_artifact.txt'

        with stubtest_worker(
            self,
            experiment_name='test_local_worker_co_url' + str(uuid.uuid4()),
            runner_args=['--capture-once=' + url + ':f'],
            config_name='test_config.yaml',
            test_script='art_hello_world.py',
            script_args=[],
            expected_output=expected_str
        ):
            pass

    # NOTE: unittest's default module loader only collects module-level
    # TestCase classes, so this nested class is not discovered on its own.
    @unittest.skipIf(
        not on_aws(),
        'User indicated not on aws')
    class UserIndicatedOnAWSTest(unittest.TestCase):
        def test_on_enviornment(self):
            self.assertTrue(has_aws_credentials())

    @unittest.skipIf(
        (not on_aws()) or not has_aws_credentials(),
        'Skipping due to userinput or AWS Not detected')
    @timeout(TEST_TIMEOUT, use_signals=True)
    def test_local_worker_co_s3(self):
        expected_str = 'No4 ulica fonar apteka, bessmyslennyj i tusklyj svet'
        s3loc = 's3://studioml-artifacts/tests/download_test/download_test.txt'

        with stubtest_worker(
            self,
            experiment_name='test_local_worker_co_s3' + str(uuid.uuid4()),
            runner_args=['--capture-once=' + s3loc + ':f'],
            config_name='test_config.yaml',
            test_script='art_hello_world.py',
            script_args=[],
            expected_output=expected_str
        ):
            pass

    @unittest.skipIf(keras is None,
                     'keras is required for this test')
    @timeout(TEST_TIMEOUT, use_signals=True)
    def test_save_get_model(self):
        """Save a model from a run, reload it and check its predictions."""
        experiment_name = 'test_save_get_model' + str(uuid.uuid4())
        with stubtest_worker(
            self,
            experiment_name=experiment_name,
            runner_args=[],
            config_name='test_config.yaml',
            test_script='save_model.py',
            script_args=[],
            expected_output='',
            delete_when_done=False,
            test_output=False
        ) as db:

            experiment = db.get_experiment(experiment_name)
            saved_model = self._get_model(db, experiment)

            v = np.random.rand(1, 2)
            prediction = saved_model.predict(v)
            expected = v * 2

            self.assertTrue(np.isclose(prediction, expected).all())

            # Delete by key, as every other call site in this class does.
            db.delete_experiment(experiment_name)

    def _get_model(self, db, experiment):
        """Load the newest Keras checkpoint from the ``modeldir`` artifact.

        Raises:
            NotImplementedError: the experiment's ``info['type']`` is
                'tensorflow' (no loader implemented).
            ValueError: no checkpoint found and the type is unknown.
        """
        modeldir = db.get_artifact(experiment.artifacts['modeldir'])
        hdf5_files = [
            (p, os.path.getmtime(p))
            for p in
            glob.glob(os.path.join(modeldir, '*.hdf*')) +
            glob.glob(os.path.join(modeldir, '*.h5'))]
        if hdf5_files:
            # experiment type - keras
            import keras
            last_checkpoint = max(hdf5_files, key=lambda t: t[1])[0]
            return keras.models.load_model(last_checkpoint)

        # The experiment type, if recorded, lives on the experiment, not on
        # this TestCase (which has no ``info`` attribute).
        info = getattr(experiment, 'info', None)
        if isinstance(info, dict) and info.get('type') == 'tensorflow':
            raise NotImplementedError

        raise ValueError("Experiment type is unknown!")

    @timeout(TEST_TIMEOUT, use_signals=True)
    def test_stop_experiment(self):
        """Stop a running experiment through the DB provider."""
        my_path = os.path.dirname(os.path.realpath(__file__))

        logger = logs.get_logger('test_stop_experiment')
        logger.setLevel(10)

        config_name = os.path.join(my_path, 'test_config.yaml')
        key = 'test_stop_experiment' + str(uuid.uuid4())

        with model.get_db_provider(model.get_config(config_name)) as db:
            self._delete_stale_experiment(db, key)

            p = self._start_studio_run(my_path, config_name, key)
            try:
                # Wait till the experiment spins up; stop polling if the run
                # exits on its own so the loop cannot spin forever.
                experiment = None
                while ((experiment is None or
                        experiment.status == 'waiting') and
                       p.poll() is None):
                    # Keep sleep outside the try so the signal-based timeout
                    # (and KeyboardInterrupt) are not swallowed here.
                    time.sleep(1)
                    try:
                        experiment = db.get_experiment(key)
                    except Exception:
                        # The experiment record may not exist yet.
                        pass

                logger.info('Stopping experiment')
                db.stop_experiment(key)
                pout, _ = p.communicate()
            finally:
                self._stop_process(p)

            self._log_run_output(logger, pout)
            db.delete_experiment(key)

    @timeout(TEST_TIMEOUT, use_signals=True)
    def test_experiment_maxduration(self):
        """Run an experiment with ``--max-duration=10s``."""
        my_path = os.path.dirname(os.path.realpath(__file__))

        logger = logs.get_logger('test_experiment_maxduration')
        logger.setLevel(10)

        config_name = os.path.join(my_path, 'test_config.yaml')
        key = 'test_experiment_maxduration' + str(uuid.uuid4())

        with model.get_db_provider(model.get_config(config_name)) as db:
            self._delete_stale_experiment(db, key)

            p = self._start_studio_run(my_path, config_name, key,
                                       '--max-duration=10s')
            try:
                pout, _ = p.communicate()
            finally:
                self._stop_process(p)

            self._log_run_output(logger, pout)
            db.delete_experiment(key)

    @timeout(TEST_TIMEOUT, use_signals=True)
    def test_experiment_lifetime(self):
        """Run an experiment with an already-expired ``--lifetime``."""
        my_path = os.path.dirname(os.path.realpath(__file__))

        logger = logs.get_logger('test_experiment_lifetime')
        logger.setLevel(10)

        config_name = os.path.join(my_path, 'test_config.yaml')
        key = 'test_experiment_lifetime' + str(uuid.uuid4())

        with model.get_db_provider(model.get_config(config_name)) as db:
            self._delete_stale_experiment(db, key)

            p = self._start_studio_run(my_path, config_name, key,
                                       '--lifetime=-10m')
            try:
                pout, _ = p.communicate()
            finally:
                self._stop_process(p)

            self._log_run_output(logger, pout)
            db.delete_experiment(key)