def test_get_dtype(self):
        """A source created from 'small-rgb-tile.tif' reports dtype uint8."""
        tile_path = data_file_path('small-rgb-tile.tif')
        with RVConfig.get_tmp_dir() as tmp_dir:
            source_config = rv.data.RasterioSourceConfig(uris=[tile_path])
            source = source_config.create_source(tmp_dir)

            self.assertEqual(source.get_dtype(), np.uint8)
    def _run_experiment(self, command_dag):
        """Runs all commands on this machine."""

        def execute_all(work_dir):
            # Every config is serialized to its protobuf message and rebuilt
            # through the builder registered for its command type before the
            # command is created and run.
            for sorted_config in command_dag.get_sorted_commands():
                proto = sorted_config.to_proto()
                builder_cls = rv._registry.get_command_config_builder(
                    proto.command_type)
                config = builder_cls().from_proto(proto).build()

                root_uri = config.root_uri
                config_name = 'command-config-{}.json'.format(config.split_id)
                config_uri = os.path.join(root_uri, config_name)
                log.info(
                    'Saving command configuration to {}...'.format(config_uri))
                save_json_config(config.to_proto(), config_uri)

                config.create_command().run(work_dir)

        # Use the runner's own tmp_dir when it has one; otherwise work inside
        # a temporary directory that only lives for the duration of the run.
        if not self.tmp_dir:
            with RVConfig.get_tmp_dir() as tmp_dir:
                execute_all(tmp_dir)
        else:
            execute_all(self.tmp_dir)
    def test_no_epsg(self):
        """Check that a GeoTIFF written with an argument-less CRS loads.

        The test fails if building the Rasterio source config, or creating
        the source from it, raises any exception.
        """
        empty_crs = rasterio.crs.CRS()
        with RVConfig.get_tmp_dir() as tmp_dir:
            tif_path = os.path.join(tmp_dir, 'temp.tif')
            n_rows, n_cols, n_bands = 100, 100, 3
            profile = {
                'driver': 'GTiff',
                'height': n_rows,
                'width': n_cols,
                'count': n_bands,
                'dtype': np.uint8,
                'crs': empty_crs,
            }
            with rasterio.open(tif_path, 'w', **profile) as dataset:
                pixels = np.zeros((n_rows, n_cols, n_bands)).astype(np.uint8)
                for band in range(n_bands):
                    # Array channel `band` is written at index band + 1.
                    dataset.write(pixels[:, :, band], band + 1)

            try:
                source_config = (
                    rv.RasterSourceConfig.builder(rv.RASTERIO_SOURCE)
                    .with_uri(tif_path)
                    .build())
                source_config.create_source(tmp_dir=tmp_dir)
            except Exception:
                self.fail(
                    'Creating RasterioSource with CRS with no EPSG attribute '
                    'raised an exception when it should not have.')
    def test_mask(self):
        """A GeoTIFF whose dataset mask is all False reads back as zeros.

        Random 0/1 pixels are written to a three-channel GeoTIFF together
        with an all-False mask; the image array returned by the source must
        then equal an all-zero array of the same shape.
        """
        with RVConfig.get_tmp_dir() as temp_dir:
            # make geotiff filled with ones and zeros and mask the whole image
            image_path = os.path.join(temp_dir, 'temp.tif')
            height = 100
            width = 100
            nb_channels = 3
            with rasterio.open(image_path,
                               'w',
                               driver='GTiff',
                               height=height,
                               width=width,
                               count=nb_channels,
                               dtype=np.uint8) as image_dataset:
                im = np.random.randint(
                    0, 2, (height, width, nb_channels)).astype(np.uint8)
                for channel in range(nb_channels):
                    image_dataset.write(im[:, :, channel], channel + 1)
                # Use the builtin bool: the np.bool alias was deprecated in
                # NumPy 1.20 and removed in 1.24.
                image_dataset.write_mask(
                    np.zeros(im.shape[0:2]).astype(bool))

            source = rv.RasterSourceConfig.builder(rv.RASTERIO_SOURCE) \
                                          .with_uri(image_path) \
                                          .build() \
                                          .create_source(tmp_dir=temp_dir)
            with source.activate():
                out_chip = source.get_image_array()
                expected_out_chip = np.zeros((height, width, nb_channels))
                np.testing.assert_equal(out_chip, expected_out_chip)
    def test_gets_raw_chip_from_uint16_transformed_proto(self):
        """The raw image array of a proto-built source has three channels.

        The source config carries a two-entry channel order and a stats
        transformer, and is round-tripped through its protobuf message before
        the source is created.
        """
        tile_path = data_file_path('small-uint16-tile.tif')
        channel_order = [0, 1]

        with RVConfig.get_tmp_dir() as temp_dir:
            stats_uri = os.path.join(temp_dir, 'temp.tif')
            stats = RasterStats()
            stats_source = (
                rv.RasterSourceConfig.builder(rv.RASTERIO_SOURCE)
                .with_uri(tile_path)
                .build()
                .create_source(temp_dir))
            stats.compute([stats_source])
            stats.save(stats_uri)

            transformer = (
                rv.RasterTransformerConfig.builder(rv.STATS_TRANSFORMER)
                .with_stats_uri(stats_uri)
                .build())

            source_config = (
                rv.RasterSourceConfig.builder(rv.RASTERIO_SOURCE)
                .with_uri(tile_path)
                .with_channel_order(channel_order)
                .with_transformer(transformer)
                .build())
            proto = source_config.to_proto()

            restored_config = rv.RasterSourceConfig.from_proto(proto)
            source = restored_config.create_source(tmp_dir=None)

            with source.activate():
                raw_chip = source.get_raw_image_array()
                self.assertEqual(raw_chip.shape[2], 3)
Пример #6
0
    def test_respects_utilizes_gpu(self):
        """CHIP/EVAL jobs use a CPU queue; the other jobs use a GPU queue."""
        overrides = self.mock_config()
        queue_settings = (
            ('AWS_BATCH_job_queue', 'GPU_JOB_QUEUE'),
            ('AWS_BATCH_job_definition', 'GPU_JOB_DEF'),
            ('AWS_BATCH_cpu_job_queue', 'CPU_JOB_QUEUE'),
            ('AWS_BATCH_cpu_job_definition', 'CPU_JOB_DEF'),
        )
        for key, value in queue_settings:
            overrides[key] = value

        rv._registry.initialize_config(config_overrides=overrides)

        with RVConfig.get_tmp_dir() as tmp_dir:
            experiment = (
                mk.create_mock_experiment().to_builder()
                .with_root_uri(tmp_dir)
                .clear_command_uris()
                .build())

            runner = MockAwsBatchExperimentRunner()
            runner.run(
                experiment,
                commands_to_run=[rv.CHIP, rv.TRAIN, rv.PREDICT, rv.EVAL])

            submit_calls = runner.mock_client.submit_job.call_args_list
            self.assertEqual(len(submit_calls), 4)

            for call in submit_calls:
                # Index 1 of a recorded call holds its keyword arguments.
                kwargs = call[1]
                job_name, job_queue = kwargs['jobName'], kwargs['jobQueue']

                cpu_only = 'EVAL' in job_name or 'CHIP' in job_name
                expected_tag = 'CPU' if cpu_only else 'GPU'
                self.assertTrue(expected_tag in job_queue)
Пример #7
0
def main(tests, rv_root, verbose):
    """Runs RV end-to-end and checks that evaluation metrics are correct."""
    # tests: names of the tests to run (case-insensitive); an empty
    #     collection selects every entry of all_tests.
    # rv_root: when truthy, used as the working directory instead of the
    #     temporary directory.
    # verbose: when truthy, the registry is initialized with DEBUG verbosity.
    if len(tests) == 0:
        tests = all_tests

    if verbose:
        rv._registry.initialize_config(
            verbosity=rv.cli.verbosity.Verbosity.DEBUG)

    tests = list(map(lambda x: x.upper(), tests))

    with RVConfig.get_tmp_dir() as temp_dir:
        if rv_root:
            temp_dir = rv_root

        errors = []
        for test in tests:
            if test not in all_tests:
                print('{} is not a valid test.'.format(test))
                return

            # Report only the errors produced by this test. Iterating the
            # accumulated list here would print earlier tests' errors again
            # after every subsequent test.
            test_errors = list(run_test(test, temp_dir))
            errors.extend(test_errors)

            for error in test_errors:
                print(error)

        for test in tests:
            if not any(error.test == test for error in errors):
                print('{} test passed!'.format(test))

        # A non-zero exit status signals that at least one test had errors.
        if errors:
            exit(1)
    def test_command_create(self):
        """An ANALYZE config rebuilt from its proto creates an AnalyzeCommand.

        A small four-channel image is saved to a temporary directory and used
        as the raster source of the single scene; the command config is built,
        round-tripped through protobuf and asked to create its command.
        """
        task = rv.TaskConfig.builder(mk.MOCK_TASK).build()
        with RVConfig.get_tmp_dir() as tmp_dir:
            img_path = os.path.join(tmp_dir, 'img.tif')
            chip = np.ones((2, 2, 4)).astype(np.uint8)
            # Channel i of the chip is filled with the value i.
            chip[:, :, :] *= np.array([0, 1, 2, 3]).astype(np.uint8)
            save_img(chip, img_path)

            source = rv.data.RasterioSourceConfig(img_path)

            scenes = [rv.data.SceneConfig('', source)]
            analyzers = [
                rv.analyzer.StatsAnalyzerConfig(stats_uri='dummy_path')
            ]

            cmd_conf = rv.CommandConfig.builder(rv.ANALYZE) \
                                       .with_task(task) \
                                       .with_root_uri(tmp_dir) \
                                       .with_scenes(scenes) \
                                       .with_analyzers(analyzers) \
                                       .build()

            cmd_conf = rv.command.CommandConfig.from_proto(cmd_conf.to_proto())
            cmd = cmd_conf.create_command()

            # assertTrue(cmd, cls) would treat cls as the failure message and
            # only check truthiness; assertIsInstance checks the actual type.
            self.assertIsInstance(cmd, rv.command.AnalyzeCommand)
Пример #9
0
    def test_make_predict_windows_with_aoi(self):
        """Compare train windows of an AOI scene with the filtered labels.

        The scene uses the full label file plus an AOI. The number of windows
        returned by get_train_windows() must equal the number of cells loaded
        from the filtered label file.
        """
        task_config = (
            rv.TaskConfig.builder(rv.CHIP_CLASSIFICATION)
            .with_chip_size(200)
            .with_classes(['car', 'building', 'background'])
            .build())

        backend_config = (
            rv.BackendConfig.builder(rv.KERAS_CLASSIFICATION)
            .with_task(task_config)
            .with_model_defaults(rv.RESNET50_IMAGENET)
            .with_pretrained_model(None)
            .build())

        full_labels_uri = data_file_path('evaluator/cc-label-full.json')
        full_label_source = (
            rv.LabelSourceConfig.builder(rv.CHIP_CLASSIFICATION_GEOJSON)
            .with_uri(full_labels_uri)
            .build())

        filtered_labels_uri = data_file_path(
            'evaluator/cc-label-filtered.json')
        filtered_label_source = (
            rv.LabelSourceConfig.builder(rv.CHIP_CLASSIFICATION_GEOJSON)
            .with_uri(filtered_labels_uri)
            .build())

        image_uri = data_file_path('evaluator/cc-label-img-blank.tif')
        raster_source = (
            rv.RasterSourceConfig.builder(rv.RASTERIO_SOURCE)
            .with_uri(image_uri)
            .build())

        aoi_uri = data_file_path('evaluator/cc-label-aoi.json')
        scene_config = (
            rv.SceneConfig.builder()
            .with_id('test')
            .with_raster_source(raster_source)
            .with_label_source(full_label_source)
            .with_aoi_uri(aoi_uri)
            .build())

        with RVConfig.get_tmp_dir() as tmp_dir:
            scene = scene_config.create_scene(task_config, tmp_dir)
            backend = backend_config.create_backend(task_config)
            task = task_config.create_task(backend)

            with scene.activate():
                windows = task.get_train_windows(scene)

            from rastervision.data import (ChipClassificationLabels,
                                           ChipClassificationGeoJSONStore)
            # Label every returned window with class id 1 and save the result
            # to a GeoJSON store in the temporary directory.
            window_labels = ChipClassificationLabels()
            for window in windows:
                window_labels.set_cell(window, 1)
            store = ChipClassificationGeoJSONStore(
                os.path.join(tmp_dir, 'test.json'),
                scene.raster_source.get_crs_transformer(),
                task_config.class_map)
            store.save(window_labels)

            filtered_labels = filtered_label_source.create_source(
                task_config, scene.raster_source.get_extent(),
                scene.raster_source.get_crs_transformer(), tmp_dir)
            expected_cells = filtered_labels.get_labels().get_cells()

            self.assertEqual(len(windows), len(expected_cells))
    def test_command_run_with_mocks(self):
        """Run a mock bundle command; check its package and analyzer hook."""
        with RVConfig.get_tmp_dir() as tmp_dir:
            package_uri = os.path.join(tmp_dir, 'predict_package.zip')

            task_config = rv.TaskConfig.builder(mk.MOCK_TASK).build()
            backend_config = rv.BackendConfig.builder(mk.MOCK_BACKEND).build()
            scene = mk.create_mock_scene()
            analyzer_config = rv.AnalyzerConfig.builder(
                mk.MOCK_ANALYZER).build()

            built_config = (
                rv.command.BundleCommandConfig.builder()
                .with_task(task_config)
                .with_backend(backend_config)
                .with_scene(scene)
                .with_analyzers([analyzer_config])
                .with_root_uri('.')
                .build())

            # Round-trip through protobuf, then point the task at a package
            # location inside the temporary directory.
            cmd_conf = rv.command.CommandConfig.from_proto(
                built_config.to_proto())
            cmd_conf.task.predict_package_uri = package_uri
            restored_analyzer = cmd_conf.analyzers[0]

            cmd_conf.create_command().run()

            self.assertTrue(os.path.exists(package_uri))
            self.assertTrue(restored_analyzer.mock.save_bundle_files.called)
Пример #11
0
    def setUp(self):
        """Build GeoJSON label fixtures and write them to a temporary file.

        Sets up a two-feature collection (a 'car' MultiPolygon and a 'house'
        Polygon), the matching class map, a minimal task config stand-in,
        two boxes, class ids, an STRtree over the feature geometries, and
        the URI of the JSON file holding the GeoJSON.
        """

        class MockTaskConfig():
            def __init__(self, class_map):
                self.class_map = class_map

        self.crs_transformer = DoubleCRSTransformer()

        car_feature = {
            'type': 'Feature',
            'geometry': {
                'type': 'MultiPolygon',
                'coordinates': [[[[0., 0.], [0., 2.], [2., 2.], [2., 0.],
                                  [0., 0.]]]]
            },
            'properties': {
                'class_name': 'car',
                'class_id': 1,
                'score': 0.0
            }
        }
        house_feature = {
            'type': 'Feature',
            'geometry': {
                'type': 'Polygon',
                'coordinates': [[[2., 2.], [2., 4.], [4., 4.], [4., 2.],
                                 [2., 2.]]]
            },
            'properties': {
                'score': 0.0,
                'class_name': 'house',
                'class_id': 2
            }
        }
        self.geojson = {
            'type': 'FeatureCollection',
            'features': [car_feature, house_feature]
        }

        self.class_map = ClassMap([ClassItem(1, 'car'), ClassItem(2, 'house')])
        self.task_config = MockTaskConfig(self.class_map)

        self.box1 = Box.make_square(0, 0, 4)
        self.box2 = Box.make_square(4, 4, 4)
        self.class_id1 = 1
        self.class_id2 = 2
        self.background_class_id = 3

        # Each geometry is tagged with its feature's class id before being
        # put into the tree.
        tagged_geoms = []
        for feature in self.geojson['features']:
            geom = shape(feature['geometry'])
            geom.class_id = feature['properties']['class_id']
            tagged_geoms.append(geom)
        self.str_tree = STRtree(tagged_geoms)

        self.file_name = 'labels.json'
        self.temp_dir = RVConfig.get_tmp_dir()
        self.uri = os.path.join(self.temp_dir.name, self.file_name)
        json_to_file(self.geojson, self.uri)
Пример #12
0
 def setUp(self):
     """Set the CRS transformer, extent, class ids, buffer and JSON URI."""
     self.class_id = 2
     self.background_class_id = 3
     self.line_buffer = 1
     self.crs_transformer = IdentityCRSTransformer()
     self.extent = Box.make_square(0, 0, 10)
     # The temporary directory object is kept on the instance; the label
     # file URI points inside it.
     self.tmp_dir = RVConfig.get_tmp_dir()
     self.uri = os.path.join(self.tmp_dir.name, 'temp.json')
Пример #13
0
    def test_required_fields(self):
        """Building the mock aux command raises ConfigError naming 'uris'."""
        with RVConfig.get_tmp_dir() as tmp_dir:
            builder = (
                rv.CommandConfig.builder(mk.MOCK_AUX_COMMAND)
                .with_config()
                .with_root_uri(tmp_dir))
            with self.assertRaises(rv.ConfigError) as context:
                builder.build()

            message = str(context.exception)
            self.assertTrue('uris' in message)
 def save_debug_predict_image(self, scene, debug_dir_uri):
     """Draw the debug prediction image for scene and store it as a PNG.

     The image is saved to a local path derived from the target URI inside
     a temporary directory, then passed to upload_or_copy() with that URI.
     The file is named after scene.id and placed under debug_dir_uri.
     """
     image = draw_debug_predict_image(scene, self.config.class_map)
     # Saving to a jpg leads to segfault for unknown reasons.
     target_uri = join(debug_dir_uri, scene.id + '.png')
     with RVConfig.get_tmp_dir() as work_dir:
         local_path = get_local_path(target_uri, work_dir)
         make_dir(local_path, use_dirname=True)
         image.save(local_path)
         upload_or_copy(local_path, target_uri)
Пример #15
0
 def get_tmp_dir(self):
     """Return the path of this object's temporary directory.

     If a truthy _tmp_dir is already set, it is returned directly when it is
     a string, otherwise its .name attribute is returned. If none is set, a
     new directory is obtained from RVConfig.get_tmp_dir(), stored through
     set_tmp_dir(), and its .name is returned.
     """
     current = getattr(self, '_tmp_dir', None)
     if not current:
         created = RVConfig.get_tmp_dir()
         self.set_tmp_dir(created)
         return created.name

     if isinstance(current, str):
         return current
     return current.name
    def create_command(self, tmp_dir=None):
        """Create the BundleCommand for this configuration.

        Args:
            tmp_dir: Optional temporary directory to hand to the command.
                When falsy, a new one is obtained from
                RVConfig.get_tmp_dir().

        Returns:
            A BundleCommand on which set_tmp_dir() has been called.
        """
        # When the directory is created here, the object returned by
        # RVConfig.get_tmp_dir() itself (not its .name) is handed to the
        # command.
        if not tmp_dir:
            tmp_dir = RVConfig.get_tmp_dir()

        retval = BundleCommand(self)
        retval.set_tmp_dir(tmp_dir)
        return retval
Пример #17
0
    def create_command(self, tmp_dir=None):
        """Create a command of type self.command_class from self.config.

        Args:
            tmp_dir: Optional temporary directory to hand to the command.
                When falsy, a new one is obtained from
                RVConfig.get_tmp_dir().

        Returns:
            The new command, on which set_tmp_dir() has been called.
        """
        # When the directory is created here, the object returned by
        # RVConfig.get_tmp_dir() itself (not its .name) is handed to the
        # command.
        if not tmp_dir:
            tmp_dir = RVConfig.get_tmp_dir()

        retval = self.command_class(self.config)
        retval.set_tmp_dir(tmp_dir)

        return retval
Пример #18
0
    def test_bundle_od_command(self):
        """Run an object detection bundle command and inspect the package.

        The predict package is extracted and must contain exactly
        'stats.json', 'model' and 'bundle_config.json'; the bundled config
        must have command type rv.BUNDLE.
        """

        def get_task(tmp_dir):
            package_uri = os.path.join(tmp_dir, 'predict_package.zip')
            return (
                rv.TaskConfig.builder(rv.OBJECT_DETECTION)
                .with_predict_package_uri(package_uri)
                .with_classes(['class1'])
                .build())

        def get_backend(task, tmp_dir):
            model_uri = os.path.join(tmp_dir, 'model')
            template_uri = data_file_path(
                'tf_object_detection/embedded_ssd_mobilenet_v1_coco.config')
            # A placeholder file stands in for the model.
            with open(model_uri, 'w') as model_file:
                model_file.write('DUMMY')
            return (
                rv.BackendConfig.builder(rv.TF_OBJECT_DETECTION)
                .with_task(task)
                .with_template(template_uri)
                .with_model_uri(model_uri)
                .build())

        with RVConfig.get_tmp_dir() as tmp_dir:
            task = get_task(tmp_dir)
            backend = get_backend(task, tmp_dir)
            analyzer = self.get_analyzer(tmp_dir)
            scene = self.get_scene(tmp_dir)
            cmd_config = (
                rv.CommandConfig.builder(rv.BUNDLE)
                .with_task(task)
                .with_root_uri(tmp_dir)
                .with_backend(backend)
                .with_analyzers([analyzer])
                .with_scene(scene)
                .build())
            cmd = cmd_config.create_command()

            cmd.run(tmp_dir)

            package_dir = os.path.join(tmp_dir, 'package')
            make_dir(package_dir)
            with zipfile.ZipFile(task.predict_package_uri, 'r') as archive:
                archive.extractall(path=package_dir)

            config_path = os.path.join(package_dir, 'bundle_config.json')
            bundle_config = load_json_config(config_path, CommandConfigMsg())

            self.assertEqual(bundle_config.command_type, rv.BUNDLE)

            found_files = set(os.listdir(package_dir))
            wanted_files = {'stats.json', 'model', 'bundle_config.json'}

            self.assertEqual(found_files, wanted_files)
Пример #19
0
        def _merge_training_results(results, record_path, split):
            """Merge results into record_path; zip debug chips if enabled."""
            # merge each scene's tfrecord into "split" tf record
            merge_tf_records(record_path, results)

            if not self.config.debug:
                return

            # Save debug chips.
            chips_uri = training_package.get_debug_chips_uri(split)
            debug_zip_path = training_package.get_local_path(chips_uri)
            with RVConfig.get_tmp_dir() as debug_dir:
                make_debug_images(record_path, self.class_map, debug_dir)
                # make_archive takes the archive name without its extension.
                archive_base = os.path.splitext(debug_zip_path)[0]
                shutil.make_archive(archive_base, 'zip', debug_dir)
 def test_no_config_error(self):
     """Building an analyze command config must not raise rv.ConfigError.

     The builder is given a task, a root URI, scenes and analyzers.
     """
     task = rv.task.ChipClassificationConfig({})
     try:
         with RVConfig.get_tmp_dir() as tmp_dir:
             builder = (
                 rv.command.AnalyzeCommandConfig.builder()
                 .with_task(task)
                 .with_root_uri(tmp_dir)
                 .with_scenes([''])
                 .with_analyzers(['']))
             builder.build()
     except rv.ConfigError:
         self.fail('rv.ConfigError raised unexpectedly')
Пример #21
0
    def setUp(self):
        """Start the S3 mock, create the bucket and set up file fixtures."""
        self.bucket_name = 'mock_bucket'
        self.file_name = 'hello.txt'
        self.content_str = 'hello'

        # Setup mock S3 bucket. The mock is started before the client is
        # created.
        self.mock_s3 = mock_s3()
        self.mock_s3.start()
        self.s3 = boto3.client('s3')
        self.s3.create_bucket(Bucket=self.bucket_name)

        # The cache directory path lies inside the temporary directory.
        self.temp_dir = RVConfig.get_tmp_dir()
        self.cache_dir = os.path.join(self.temp_dir.name, 'cache')
Пример #22
0
 def test_no_config_error(self):
     """Building a train command config must not raise rv.ConfigError.

     The builder is given a task, a root URI and a backend.
     """
     task = rv.task.ChipClassificationConfig({})
     backend = rv.backend.KerasClassificationConfig('')
     try:
         with RVConfig.get_tmp_dir() as tmp_dir:
             builder = (
                 rv.command.TrainCommandConfig.builder()
                 .with_task(task)
                 .with_root_uri(tmp_dir)
                 .with_backend(backend))
             builder.build()
     except rv.ConfigError:
         self.fail('rv.ConfigError raised unexpectedly')
Пример #23
0
    def setUp(self):
        """Start the S3 mock, create the bucket and get a temp directory."""
        self.lorem = LOREM
        self.bucket_name = 'mock_bucket'

        # Mock S3 bucket. The mock is started before the client is created.
        self.mock_s3 = mock_s3()
        self.mock_s3.start()
        self.s3 = boto3.client('s3')
        self.s3.create_bucket(Bucket=self.bucket_name)

        # Temporary directory
        self.temp_dir = RVConfig.get_tmp_dir()
Пример #24
0
    def test_bundle_cc_command(self):
        """Run a chip classification bundle command and inspect the package.

        The predict package is extracted and must contain exactly
        'stats.json', 'model' and 'bundle_config.json'; the bundled config
        must have command type rv.BUNDLE.
        """

        def get_task(tmp_dir):
            package_uri = os.path.join(tmp_dir, 'predict_package.zip')
            return (
                rv.TaskConfig.builder(rv.CHIP_CLASSIFICATION)
                .with_predict_package_uri(package_uri)
                .with_classes(['class1'])
                .build())

        def get_backend(task, tmp_dir):
            model_uri = os.path.join(tmp_dir, 'model')
            # A placeholder file stands in for the model.
            with open(model_uri, 'w') as model_file:
                model_file.write('DUMMY')
            return (
                rv.BackendConfig.builder(rv.KERAS_CLASSIFICATION)
                .with_task(task)
                .with_model_defaults(rv.RESNET50_IMAGENET)
                .with_model_uri(model_uri)
                .build())

        with RVConfig.get_tmp_dir() as tmp_dir:
            task = get_task(tmp_dir)
            backend = get_backend(task, tmp_dir)
            analyzer = self.get_analyzer(tmp_dir)
            scene = self.get_scene(tmp_dir)
            cmd_config = (
                rv.CommandConfig.builder(rv.BUNDLE)
                .with_task(task)
                .with_root_uri(tmp_dir)
                .with_backend(backend)
                .with_analyzers([analyzer])
                .with_scene(scene)
                .build())
            cmd = cmd_config.create_command(tmp_dir)

            cmd.run(tmp_dir)

            package_dir = os.path.join(tmp_dir, 'package')
            make_dir(package_dir)
            with zipfile.ZipFile(task.predict_package_uri, 'r') as archive:
                archive.extractall(path=package_dir)

            config_path = os.path.join(package_dir, 'bundle_config.json')
            bundle_config = load_json_config(config_path, CommandConfigMsg())

            self.assertEqual(bundle_config.command_type, rv.BUNDLE)

            found_files = set(os.listdir(package_dir))
            wanted_files = {'stats.json', 'model', 'bundle_config.json'}

            self.assertEqual(found_files, wanted_files)
Пример #25
0
    def create_command(self, tmp_dir=None):
        """Create the ChipCommand for this configuration.

        Args:
            tmp_dir: Optional temporary directory to hand to the command.
                When falsy, a new one is obtained from
                RVConfig.get_tmp_dir().

        Returns:
            A NoOpCommand when both train_scenes and val_scenes are empty;
            otherwise a ChipCommand on which set_tmp_dir() has been called.
        """
        if len(self.train_scenes) == 0 and len(self.val_scenes) == 0:
            return NoOpCommand()

        # When the directory is created here, the object returned by
        # RVConfig.get_tmp_dir() itself (not its .name) is handed to the
        # command.
        if not tmp_dir:
            tmp_dir = RVConfig.get_tmp_dir()

        retval = ChipCommand(self)
        retval.set_tmp_dir(tmp_dir)
        return retval
Пример #26
0
    def create_command(self, tmp_dir=None):
        """Create the EvalCommand for this configuration.

        Args:
            tmp_dir: Optional temporary directory to hand to the command.
                When falsy, a new one is obtained from
                RVConfig.get_tmp_dir().

        Returns:
            A NoOpCommand when scenes or evaluators is empty; otherwise an
            EvalCommand on which set_tmp_dir() has been called.
        """
        if len(self.scenes) == 0 or len(self.evaluators) == 0:
            return NoOpCommand()

        # When the directory is created here, the object returned by
        # RVConfig.get_tmp_dir() itself (not its .name) is handed to the
        # command.
        if not tmp_dir:
            tmp_dir = RVConfig.get_tmp_dir()

        retval = EvalCommand(self)
        retval.set_tmp_dir(tmp_dir)
        return retval
    def test_accounts_for_aoi(self):
        """Every 'overall' evaluation entry for the AOI scene has f1 == 1.0.

        The scene uses the filtered label file as label source, the full
        label file as label store, and an AOI; the evaluator writes its
        output to 'eval.json' in a temporary directory.
        """
        task = (
            rv.TaskConfig.builder(rv.CHIP_CLASSIFICATION)
            .with_classes(['car', 'building', 'background'])
            .build())

        filtered_labels_uri = data_file_path(
            'evaluator/cc-label-filtered.json')
        label_source = (
            rv.LabelSourceConfig.builder(rv.CHIP_CLASSIFICATION_GEOJSON)
            .with_uri(filtered_labels_uri)
            .build())

        full_labels_uri = data_file_path('evaluator/cc-label-full.json')
        label_store = (
            rv.LabelStoreConfig.builder(rv.CHIP_CLASSIFICATION_GEOJSON)
            .with_uri(full_labels_uri)
            .build())

        image_uri = data_file_path('evaluator/cc-label-img-blank.tif')
        raster_source = (
            rv.RasterSourceConfig.builder(rv.GEOTIFF_SOURCE)
            .with_uri(image_uri)
            .build())

        aoi_uri = data_file_path('evaluator/cc-label-aoi.json')
        scene_config = (
            rv.SceneConfig.builder()
            .with_id('test')
            .with_raster_source(raster_source)
            .with_label_source(label_source)
            .with_label_store(label_store)
            .with_aoi_uri(aoi_uri)
            .build())

        with RVConfig.get_tmp_dir() as tmp_dir:
            scene = scene_config.create_scene(task, tmp_dir)

            output_uri = os.path.join(tmp_dir, 'eval.json')

            evaluator_config = (
                rv.EvaluatorConfig.builder(rv.CHIP_CLASSIFICATION_EVALUATOR)
                .with_task(task)
                .with_output_uri(output_uri)
                .build())
            evaluator_config.update_for_command(rv.EVAL,
                                                create_mock_experiment())

            evaluator = evaluator_config.create_evaluator()
            evaluator.process([scene], tmp_dir)

            with open(output_uri) as eval_file:
                overall = json.loads(eval_file.read())['overall']

            for entry in overall:
                self.assertEqual(entry['f1'], 1.0)
    def test_channel_order_error(self):
        """Channel order [3, 1, 0] on a 3-channel image raises an error.

        The expected exception type is ChannelOrderError, raised while
        building the source config or creating the source.
        """
        with RVConfig.get_tmp_dir() as tmp_dir:
            img_path = os.path.join(tmp_dir, 'img.tif')
            pixels = np.ones((2, 2, 3)).astype(np.uint8)
            # Channel i of the image is filled with the value i.
            pixels[:, :, :] *= np.array([0, 1, 2]).astype(np.uint8)
            save_img(pixels, img_path)

            bad_order = [3, 1, 0]
            with self.assertRaises(ChannelOrderError):
                (rv.RasterSourceConfig.builder(rv.RASTERIO_SOURCE)
                 .with_uri(img_path)
                 .with_channel_order(bad_order)
                 .build()
                 .create_source(tmp_dir=tmp_dir))
Пример #29
0
def predict(predict_package, image_uri, output_uri, update_stats,
            channel_order, export_config):
    """Make predictions on the image at IMAGE_URI
    using PREDICT_PACKAGE and store the
    prediction output at OUTPUT_URI.
    """
    # predict_package, update_stats: passed to rv.Predictor unchanged.
    # image_uri, output_uri, export_config: passed to Predictor.predict().
    # channel_order: None, or a string of whitespace-separated integer
    #     channel indices; it is parsed into a list of ints below. A token
    #     that is not an integer makes int() raise ValueError.
    if channel_order is not None:
        # Split on any run of whitespace so that repeated or surrounding
        # spaces do not produce empty tokens that int() would reject.
        channel_order = [
            int(channel_ind) for channel_ind in channel_order.split()
        ]

    # The temporary directory only lives for the duration of the prediction.
    with RVConfig.get_tmp_dir() as tmp_dir:
        predictor = rv.Predictor(predict_package, tmp_dir, update_stats,
                                 channel_order)
        predictor.predict(image_uri, output_uri, export_config)
    def test_command_create(self):
        """A PREDICT config rebuilt from its proto creates a truthy command.

        The config is built with a mock task, a mock backend and an empty
        scene list, then round-tripped through protobuf.
        """
        task = rv.TaskConfig.builder(mk.MOCK_TASK).build()
        backend = rv.BackendConfig.builder(mk.MOCK_BACKEND).build()
        with RVConfig.get_tmp_dir() as tmp_dir:
            built_config = (
                rv.CommandConfig.builder(rv.PREDICT)
                .with_task(task)
                .with_root_uri(tmp_dir)
                .with_scenes([])
                .with_backend(backend)
                .build())

            restored_config = rv.command.CommandConfig.from_proto(
                built_config.to_proto())
            cmd = restored_config.create_command()

            # NOTE(review): assertTrue's second argument is the failure
            # message, so this only checks that cmd is truthy. Confirm which
            # command type is created for an empty scene list before
            # tightening this to a type check.
            self.assertTrue(cmd, rv.command.PredictCommand)