class RasterioSourceConfig(RasterSourceConfig):
    """Config for a raster source whose imagery is read via Rasterio/GDAL.

    If more than one URI is given, the images are mosaicked together through
    a VRT (see the ``uris`` field description).
    """
    uris: List[str] = Field(
        ...,
        description=
        ('List of image URIs that comprise imagery for a scene. The format of each file '
         'can be any that can be read by Rasterio/GDAL. If > 1 URI is provided, a VRT '
         'will be created to mosaic together the individual images.'))
    x_shift: float = Field(
        0.0,
        description=
        ('A number of meters to shift along the x-axis. A positive shift moves the '
         '"camera" to the right.'))
    y_shift: float = Field(
        0.0,
        description=
        ('A number of meters to shift along the y-axis. A positive shift moves the '
         '"camera" down.'))

    def build(self, tmp_dir, use_transformers=True):
        """Build a RasterioSource from this config.

        Args:
            tmp_dir: temporary directory handed to the RasterioSource.
            use_transformers: if False, the configured raster transformers are
                not built and an empty transformer list is passed instead.

        Returns:
            RasterioSource
        """
        raster_transformers = ([rt.build() for rt in self.transformers]
                               if use_transformers else [])
        return RasterioSource(
            self.uris,
            raster_transformers,
            tmp_dir,
            channel_order=self.channel_order,
            x_shift=self.x_shift,
            y_shift=self.y_shift)
class SemanticSegmentationLabelSourceConfig(LabelSourceConfig):
    """Config for a read-only label source for semantic segmentation."""
    raster_source: Union[RasterSourceConfig, RasterizedSourceConfig] = Field(
        ..., description='The labels in the form of rasters.')
    rgb_class_config: Optional[ClassConfig] = Field(
        None,
        description=
        ('If set, will infer the class_ids for the labels using the colors field. This '
         'assumes the labels are stored as RGB rasters.'))

    def update(self, pipeline=None, scene=None):
        """Run the parent update, then give rgb_class_config a null class.

        Note: ``pipeline`` and ``scene`` are accepted but not forwarded to the
        parent's ``update``.
        """
        super().update()
        rgb_config = self.rgb_class_config
        if rgb_config is None:
            return
        rgb_config.ensure_null_class()

    def build(self, class_config, crs_transformer, extent, tmp_dir):
        """Build a SemanticSegmentationLabelSource.

        A RasterizedSourceConfig is built from ``class_config``,
        ``crs_transformer`` and ``extent``; any other raster source config is
        built from ``tmp_dir`` alone.

        Returns:
            SemanticSegmentationLabelSource
        """
        source_config = self.raster_source
        if isinstance(source_config, RasterizedSourceConfig):
            label_raster = source_config.build(class_config, crs_transformer,
                                               extent)
        else:
            label_raster = source_config.build(tmp_dir)
        null_class_id = class_config.get_null_class_id()
        return SemanticSegmentationLabelSource(
            label_raster,
            null_class_id,
            rgb_class_config=self.rgb_class_config)
class SemanticSegmentationDataConfig(DataConfig):
    """DataConfig with image/label format and display options for segmentation."""
    data_format: SemanticSegmentationDataFormat = SemanticSegmentationDataFormat.default
    img_channels: PositiveInt = Field(
        3, description='The number of channels of the training images.')
    img_format: Optional[str] = Field(
        None, description='The filetype of the training images.')
    label_format: str = Field(
        'png', description='The filetype of the training labels.')
    channel_display_groups: Optional[Union[dict, list, tuple]] = Field(
        None,
        description='Groups of image channels to display together as a subplot '
        'when plotting the data and predictions. '
        'Can be a list or tuple of groups (e.g. [(0, 1, 2), (3,)]) or a dict '
        'containing title-to-group mappings '
        '(e.g. {"RGB": [0, 1, 2], "IR": [3]}), '
        'where each group is a list or tuple of channel indices and title '
        'is a string that will be used as the title of the subplot '
        'for that group.')

    def update(self, **kwargs):
        """Fill in img_format and channel_display_groups when they are unset.

        Keyword arguments are accepted but ignored; the parent ``update`` is
        called without arguments.
        """
        super().update()
        num_channels = self.img_channels
        if self.img_format is None:
            # PNG for 3-channel images, .npy for any other channel count.
            self.img_format = 'png' if num_channels == 3 else 'npy'
        if self.channel_display_groups is None:
            # Default: a single group titled 'Input' containing every channel.
            self.channel_display_groups = {
                'Input': tuple(range(num_channels))
            }
class SolverConfig(Config):
    """Config related to solver aka optimizer."""
    lr: PositiveFloat = Field(1e-4, description='Learning rate.')
    num_epochs: PositiveInt = Field(
        10,
        description=
        'Number of epochs (ie. sweeps through the whole training set).')
    test_num_epochs: PositiveInt = Field(
        2, description='Number of epochs to use in test mode.')
    test_batch_sz: PositiveInt = Field(
        4, description='Batch size to use in test mode.')
    overfit_num_steps: PositiveInt = Field(
        1, description='Number of optimizer steps to use in overfit mode.')
    sync_interval: PositiveInt = Field(
        1, description='The interval in epochs for each sync to the cloud.')
    batch_sz: PositiveInt = Field(32, description='Batch size.')
    one_cycle: bool = Field(
        True,
        description='If True, use triangular LR scheduler with a single cycle across all '
        'epochs with start and end LR being lr/10 and the peak being lr.')
    multi_stage: List = Field(
        [], description='List of epoch indices at which to divide LR by 10.')
    class_loss_weights: Optional[Union[list, tuple]] = Field(
        None, description='Class weights for weighted loss.')

    def update(self, learner: Optional['LearnerConfig'] = None):
        """Update hook; this config derives nothing from ``learner``."""
        return None
class StatsAnalyzerConfig(AnalyzerConfig):
    """Config for an Analyzer that computes imagery statistics of scenes."""
    output_uri: Optional[str] = Field(
        None,
        description=
        ('URI for output. If None and this is part of an RVPipeline, this is '
         'auto-generated.'))
    sample_prob: Optional[float] = Field(
        0.1,
        description=
        ('The probability of using a random window for computing statistics. '
         'If None, will use a sliding window.'))

    def update(self, pipeline=None):
        """Default output_uri to ``<pipeline.analyze_uri>/stats.json`` if unset."""
        if pipeline is not None and self.output_uri is None:
            self.output_uri = join(pipeline.analyze_uri, 'stats.json')

    def validate_config(self):
        """Check that sample_prob, when given, lies in the interval (0, 1].

        Raises:
            ConfigError: if sample_prob is not None and is outside (0, 1].
        """
        # None is a documented value (sliding-window sampling), so it must be
        # skipped rather than compared against numbers.
        if self.sample_prob is None:
            return
        if not 0 < self.sample_prob <= 1:
            raise ConfigError('sample_prob must be <= 1 and > 0')

    def build(self):
        """Return a StatsAnalyzer built from output_uri and sample_prob."""
        from rastervision.core.analyzer import StatsAnalyzer
        return StatsAnalyzer(self.output_uri, self.sample_prob)

    def get_bundle_filenames(self):
        """Return ['stats.json'], the file name used by the default output_uri."""
        return ['stats.json']
class PyTorchLearnerBackendConfig(BackendConfig):
    """Base config for backends built on a PyTorch learner.

    Subclasses must implement ``get_learner_config`` and ``build``; both raise
    NotImplementedError here.
    """
    model: ModelConfig
    solver: SolverConfig
    log_tensorboard: bool = Field(
        True, description='If True, log events to Tensorboard log files.')
    run_tensorboard: bool = Field(
        False,
        description='If True, run Tensorboard server pointing at log files.')
    augmentors: List[str] = Field(
        default_augmentors,
        description=(
            'Names of albumentations augmentors to use for training batches. '
            'Choices include: ' + str(augmentor_list)))
    test_mode: bool = Field(
        False,
        description=
        ('This field is passed along to the LearnerConfig which is returned by '
         'get_learner_config(). For more info, see the docs for '
         'pytorch_learner.learner_config.LearnerConfig.test_mode.'))

    def get_bundle_filenames(self):
        """Return the bundle file names produced for this backend."""
        return ['model-bundle.zip']

    def get_learner_config(self, pipeline):
        """Return the LearnerConfig for ``pipeline``; must be overridden."""
        raise NotImplementedError()

    def build(self, pipeline, tmp_dir):
        """Build the backend for ``pipeline``; must be overridden."""
        raise NotImplementedError()
class VectorOutputConfig(Config):
    """Config for vectorized semantic segmentation predictions."""
    uri: Optional[str] = Field(
        None,
        description=
        ('URI of vector output. If None, and this Config is part of a SceneConfig and '
         'RVPipeline, this field will be auto-generated.'))
    class_id: int = Field(
        ...,
        description='The prediction class that is to be turned into vectors.')
    denoise: int = Field(
        0,
        description=
        ('Radius of the structural element used to remove high-frequency signals from '
         'the image.'))

    def update(self, pipeline=None, scene=None):
        """Auto-generate uri from the pipeline and scene when it is unset.

        The generated URI is
        ``<pipeline.predict_uri>/<scene.id>/vector_output/<mode>-<class_id>.json``,
        where ``mode`` comes from ``get_mode()``. Nothing is changed unless
        both ``pipeline`` and ``scene`` are truthy.
        """
        if self.uri is not None or not (pipeline and scene):
            return
        filename = f'{self.get_mode()}-{self.class_id}.json'
        self.uri = join(pipeline.predict_uri, scene.id, 'vector_output',
                        filename)

    def get_mode(self):
        """Return the mode name used in the output filename; must be overridden."""
        raise NotImplementedError()
class PyTorchLearnerBackendConfig(BackendConfig):
    """Base config for backends built on a PyTorch learner with a DataConfig.

    Subclasses must implement ``get_learner_config`` and ``build``; both raise
    NotImplementedError here.
    """
    model: ModelConfig
    solver: SolverConfig
    data: DataConfig
    log_tensorboard: bool = Field(
        True, description='If True, log events to Tensorboard log files.')
    run_tensorboard: bool = Field(
        False,
        description='If True, run Tensorboard server pointing at log files.')
    test_mode: bool = Field(
        False,
        description=
        ('This field is passed along to the LearnerConfig which is returned by '
         'get_learner_config(). For more info, see the docs for '
         'pytorch_learner.learner_config.LearnerConfig.test_mode.'))

    def get_bundle_filenames(self):
        """Return the bundle file names produced for this backend."""
        return ['model-bundle.zip']

    def update(self, pipeline: Optional[RVPipeline] = None):
        """Fill in unset data options from the pipeline.

        For an ImageDataConfig with neither ``uri`` nor ``group_uris`` set,
        ``uri`` defaults to ``pipeline.chip_uri``. Empty ``class_names`` and
        ``class_colors`` default to the pipeline dataset's class config.
        Without a pipeline, only the parent update runs.
        """
        super().update(pipeline=pipeline)
        # Every default below is read from the pipeline, so there is nothing
        # to fill in when none is given (this used to raise AttributeError).
        if pipeline is None:
            return
        data = self.data
        if isinstance(data, ImageDataConfig):
            if data.uri is None and data.group_uris is None:
                data.uri = pipeline.chip_uri
        if not data.class_names:
            data.class_names = pipeline.dataset.class_config.names
        if not data.class_colors:
            data.class_colors = pipeline.dataset.class_config.colors

    def get_learner_config(self, pipeline: Optional[RVPipeline]):
        """Return the LearnerConfig for ``pipeline``; must be overridden."""
        raise NotImplementedError()

    def build(self, pipeline: Optional[RVPipeline], tmp_dir: str):
        """Build the backend for ``pipeline``; must be overridden."""
        raise NotImplementedError()
class BuildingVectorOutputConfig(VectorOutputConfig):
    """Config for vectorized semantic segmentation predictions.

    Intended to break up clusters of buildings.
    """
    min_aspect_ratio: float = Field(
        1.618,
        description='Ratio between length and height (or height and length) of anything that can '
        'be considered to be a cluster of buildings. The goal is to distinguish between '
        'rows of buildings and (say) a single building.')
    min_area: float = Field(
        0.0,
        description='Minimum area of anything that can be considered to be a cluster of buildings. '
        'The goal is to distinguish between buildings and artifacts.')
    element_width_factor: float = Field(
        0.5,
        description='Width of the structural element used to break building clusters as a fraction '
        'of the width of the cluster.')
    element_thickness: float = Field(
        0.001,
        description='Thickness of the structural element that is used to break building clusters.')

    def get_mode(self):
        """Return 'buildings', the mode name used in the output filename."""
        mode_name = 'buildings'
        return mode_name
class ModelConfig(Config):
    """Config related to models."""
    backbone: Backbone = Field(
        Backbone.resnet18,
        description='The torchvision.models backbone to use.')
    pretrained: bool = Field(
        True,
        description=
        ('If True, use ImageNet weights. If False, use random initialization.'
         ))
    init_weights: Optional[str] = Field(
        None,
        description=('URI of PyTorch model weights used to initialize model. '
                     'If set, this supersedes the pretrained option.'))
    load_strict: bool = Field(
        True,
        description=(
            'If True, the keys in the state dict referenced by init_weights '
            'must match exactly. Setting this to False can be useful if you '
            'just want to load the backbone of a model.'))
    external_def: Optional[ExternalModuleConfig] = Field(
        None,
        description='If specified, the model will be built from the '
        'definition from this external source, using Torch Hub.')

    def update(self, learner: Optional['LearnerConfig'] = None):
        """Update hook; this config derives nothing from ``learner``."""
        pass

    def get_backbone_str(self):
        """Return the member name of the configured ``backbone`` enum."""
        return self.backbone.name
class RegressionPlotOptions(PlotOptions):
    """PlotOptions with a scatter-plot point cap and a histogram bin count."""
    max_scatter_points: int = Field(
        5000,
        description='Maximum number of datapoints to use in scatter plot. '
        'Useful to avoid running out of memory and cluttering.')
    hist_bins: int = Field(
        30, description='Number of bins to use for histogram.')
class PipelineConfig(Config):
    """Base class for configuring Pipelines.

    This should be subclassed to configure new Pipelines.
    """
    root_uri: str = Field(
        None,
        description='The root URI for output generated by the pipeline')
    rv_config: dict = Field(
        None,
        description='Used to store serialized RVConfig so pipeline can '
        'run in remote environment with the local RVConfig. This should '
        'not be set explicitly by users -- it is only used by the runner '
        'when running a remote pipeline.')
    plugin_versions: Optional[Dict[str, int]] = Field(
        None,
        description='Used to store a mapping of plugin module paths to the latest '
        'version number. This should not be set explicitly by users -- it is set '
        'automatically when serializing and saving the config to disk.')

    def get_config_uri(self) -> str:
        """Return the URI of the serialized form of this PipelineConfig.

        This is ``pipeline-config.json`` directly under ``root_uri``.
        """
        config_name = 'pipeline-config.json'
        return join(self.root_uri, config_name)

    def build(self, tmp_dir: str) -> 'Pipeline':
        """Create the Pipeline described by this configuration.

        Subclasses should override this to return an instance of the
        corresponding Pipeline subclass.

        Args:
            tmp_dir: root of any temporary directory to pass to pipeline

        Returns:
            Pipeline
        """
        # Imported here rather than at module level (presumably to avoid an
        # import cycle -- NOTE(review): confirm).
        from rastervision.pipeline.pipeline import Pipeline  # noqa
        pipeline = Pipeline(self, tmp_dir)
        return pipeline
class SemanticSegmentationLabelStoreConfig(LabelStoreConfig):
    """Config for storage for semantic segmentation predictions.

    Stores class raster as GeoTIFF, and can optionally vectorize predictions
    and store them in GeoJSON files.
    """
    uri: Optional[str] = Field(
        None,
        description=(
            'URI of file with predictions. If None, and this Config is part of '
            'a SceneConfig inside an RVPipelineConfig, this field will be '
            'auto-generated.'))
    vector_output: List[VectorOutputConfig] = []
    rgb: bool = Field(
        False,
        description=
        ('If True, save prediction class_ids in RGB format using the colors in '
         'class_config.'))
    smooth_output: bool = Field(
        False,
        description='If True, expects labels to be continuous values '
        'representing class scores and stores both scores and discrete '
        'labels.')
    smooth_as_uint8: bool = Field(
        False,
        description='If True, stores smooth scores as uint8, resulting in '
        'loss of precision, but reduced file size. Only used if '
        'smooth_output=True.')
    rasterio_block_size: int = Field(
        256,
        description='blockxsize and blockysize params in rasterio.open() will '
        'be set to this.')

    def build(self, class_config, crs_transformer, extent, tmp_dir):
        """Build a SemanticSegmentationLabelStore from this config.

        Note: calls ``class_config.ensure_null_class()`` first, which may add
        a null class to ``class_config`` in place.

        Returns:
            SemanticSegmentationLabelStore
        """
        class_config.ensure_null_class()
        label_store = SemanticSegmentationLabelStore(
            self.uri,
            extent,
            crs_transformer,
            tmp_dir,
            vector_outputs=self.vector_output,
            class_config=class_config,
            save_as_rgb=self.rgb,
            smooth_output=self.smooth_output,
            smooth_as_uint8=self.smooth_as_uint8,
            rasterio_block_size=self.rasterio_block_size)
        return label_store

    def update(self, pipeline=None, scene=None):
        """Auto-generate uri and propagate the update to each vector output.

        When both ``pipeline`` and ``scene`` are given and ``uri`` is unset,
        ``uri`` becomes ``<pipeline.predict_uri>/<scene.id>``.
        """
        if pipeline is not None and scene is not None:
            if self.uri is None:
                self.uri = join(pipeline.predict_uri, f'{scene.id}')

        for vo in self.vector_output:
            vo.update(pipeline, scene)
class DataConfig(Config):
    """Config related to dataset for training and testing."""
    uri: Union[None, str, List[str]] = Field(
        None,
        description='URI of the dataset. This can be a zip file, a list of zip files, or a '
        'directory which contains a set of zip files.')
    train_sz: Optional[int] = Field(
        None,
        description='If set, the number of training images to use. If fewer images exist, '
        'then an exception will be raised.')
    group_uris: Union[None, List[Union[str, List[str]]]] = Field(
        None,
        description='This can be set instead of uri in order to specify groups of chips. Each '
        'element in the list is expected to be an object of the same form accepted by '
        'the uri field. The purpose of separating chips into groups is to be able to '
        'use the group_train_sz field.')
    group_train_sz: Optional[int] = Field(
        None,
        description='If group_uris is set, this can be used to specify the number of chips to use '
        'per group.')
    data_format: Optional[str] = Field(
        None, description='Name of dataset format.')
    class_names: List[str] = Field([], description='Names of classes.')
    class_colors: Union[None, List[str], List[List]] = Field(
        None,
        description='Colors used to display classes. '
        'Can be color 3-tuples in list form.')
    img_sz: PositiveInt = Field(
        256,
        description='Length of a side of each image in pixels. This is the size to transform '
        'it to during training, not the size in the raw dataset.')
    num_workers: int = Field(
        4,
        description='Number of workers to use when DataLoader makes batches.')
    # TODO support setting parameters of augmentors?
    # NOTE: inside the class body, ``augmentors`` on the right-hand side still
    # refers to the module-level list, since the class attribute is assigned
    # only after the expression is evaluated.
    augmentors: List[str] = Field(
        default_augmentors,
        description=(
            'Names of albumentations augmentors to use for training batches. '
            'Choices include: ' + str(augmentors)))

    def update(self, learner: Optional['LearnerConfig'] = None):
        """Generate one color per class name when class_colors is empty/unset."""
        if self.class_colors:
            return
        self.class_colors = [color_to_triple() for _ in self.class_names]

    def validate_augmentors(self):
        """Check every entry of ``augmentors`` against the module-level list."""
        self.validate_list('augmentors', augmentors)

    def validate_config(self):
        """Validate this config (currently only the augmentors field)."""
        self.validate_augmentors()
class ClassConfig(Config):
    """Configures the class names that are being predicted."""
    names: List[str] = Field(..., description='Names of classes.')
    colors: Optional[List[Union[List, str]]] = Field(
        None,
        description=
        ('Colors used to visualize classes. Can be color strings accepted by '
         'matplotlib or RGB tuples. If None, a random color will be auto-generated '
         'for each class.'))
    null_class: Optional[str] = Field(
        None,
        description=
        ('Optional name of class in `names` to use as the null class. This is used in '
         'semantic segmentation to represent the label for imagery pixels that are '
         'NODATA or that are missing a label. If None, and this Config is part of a '
         'SemanticSegmentationConfig, a null class will be added automatically.'
         ))

    def get_class_id(self, name):
        """Return the index of ``name`` in ``names`` (ValueError if absent)."""
        return self.names.index(name)

    def get_name(self, id):
        """Return the class name at index ``id``."""
        return self.names[id]

    def get_null_class_id(self):
        """Return the class id of the null class.

        Raises:
            ValueError: if null_class is not set, or is not in ``names``.
        """
        if self.null_class is None:
            raise ValueError('null_class is not set')
        return self.get_class_id(self.null_class)

    def get_color_to_class_id(self):
        """Return a dict mapping each color to its class id.

        Colors must be hashable (e.g. strings); list-valued colors raise
        TypeError. If a color repeats, the last class id wins.
        """
        return {color: class_id for class_id, color in enumerate(self.colors)}

    def ensure_null_class(self):
        """Add a null class if one isn't set.

        If ``null_class`` is unset it defaults to ``'null'``. If the null class
        is not yet in ``names`` it is appended, and when ``colors`` is set,
        ``'black'`` is appended to keep names and colors aligned.
        """
        if self.null_class is None:
            self.null_class = 'null'
        if self.null_class not in self.names:
            # Append the configured null class name, not a hard-coded 'null'.
            self.names.append(self.null_class)
            # colors may still be None here; update() generates one color per
            # name when colors is unset, so only extend an existing list.
            if self.colors is not None:
                self.colors.append('black')

    def update(self, pipeline=None):
        """Generate one color per class name when colors is empty/unset."""
        if not self.colors:
            self.colors = [color_to_triple() for _ in self.names]

    def validate_config(self):
        """Check that null_class, if set, is one of ``names``.

        Raises:
            ConfigError: if null_class is set but not in ``names``.
        """
        if self.null_class is not None and self.null_class not in self.names:
            raise ConfigError(
                'The null_class: {} must be in list of class names.'.format(
                    self.null_class))

    def __len__(self):
        """Return the number of classes."""
        return len(self.names)
class ObjectDetectionPredictOptions(Config):
    """Thresholds applied to predicted boxes during object detection prediction."""
    merge_thresh: float = Field(
        0.5,
        description='If predicted boxes have an IOA (intersection over area) greater than '
        'merge_thresh, then they are merged into a single box during postprocessing. '
        'This is needed since the sliding window approach results in some false '
        'duplicates.')
    score_thresh: float = Field(
        0.5,
        description='Predicted boxes are only output if their score is above score_thresh.')
class RasterizerConfig(Config):
    """Options for burning vector geometries into a label raster."""
    background_class_id: int = Field(
        ...,
        description='The class_id to use for any background pixels, ie. pixels not covered by a '
        'polygon.')
    all_touched: bool = Field(
        False,
        description='If True, all pixels touched by geometries will be burned in. '
        'If false, only pixels whose center is within the polygon or '
        'that are selected by Bresenham’s line algorithm will be '
        'burned in. (See rasterio.features.rasterize).')
class PyTorchLearnerBackendConfig(BackendConfig):
    """Base config for backends built on a PyTorch learner.

    Subclasses must implement ``get_learner_config`` and ``build``; both raise
    NotImplementedError here.
    """
    model: ModelConfig
    solver: SolverConfig
    log_tensorboard: bool = Field(
        True, description='If True, log events to Tensorboard log files.')
    run_tensorboard: bool = Field(
        False,
        description='If True, run Tensorboard server pointing at log files.')
    augmentors: List[str] = Field(
        default_augmentors,
        description='Names of albumentations augmentors to use for training '
        f'batches. Choices include: {augmentor_list}. Alternatively, a custom '
        'transform can be provided via the aug_transform option.')
    base_transform: Optional[dict] = Field(
        None,
        description='An Albumentations transform serialized as a dict that '
        'will be applied to all datasets: training, validation, and test. '
        'This transformation is in addition to the resizing due to img_sz. '
        'This is useful for, for example, applying the same normalization to '
        'all datasets.')
    aug_transform: Optional[dict] = Field(
        None,
        description='An Albumentations transform serialized as a dict that '
        'will be applied as data augmentation to the training dataset. This '
        'transform is applied before base_transform. If provided, the '
        'augmentors option is ignored.')
    test_mode: bool = Field(
        False,
        description=
        ('This field is passed along to the LearnerConfig which is returned by '
         'get_learner_config(). For more info, see the docs for '
         'pytorch_learner.learner_config.LearnerConfig.test_mode.'))
    plot_options: Optional[PlotOptions] = Field(
        PlotOptions(), description='Options to control plotting.')
    img_sz: Optional[PositiveInt] = Field(
        None,
        description='Length of a side of each image in pixels. This is the '
        'size to transform it to during training, not the size in the raw '
        'dataset. Defaults to train_chip_sz in the pipeline config.')
    num_workers: int = Field(
        4, description='The number of workers to use in PyTorch to read data.')

    # validators: both serialized transforms are checked by the same function.
    _base_tf = validator(
        'base_transform', allow_reuse=True)(validate_albumentation_transform)
    _aug_tf = validator(
        'aug_transform', allow_reuse=True)(validate_albumentation_transform)

    def get_bundle_filenames(self):
        """Return the bundle file names produced for this backend."""
        return ['model-bundle.zip']

    def get_learner_config(self, pipeline):
        """Return the LearnerConfig for ``pipeline``; must be overridden."""
        raise NotImplementedError()

    def build(self, pipeline, tmp_dir):
        """Build the backend for ``pipeline``; must be overridden."""
        raise NotImplementedError()
class VectorSourceConfig(Config):
    """Base config for vector sources; subclasses must implement ``build``."""
    default_class_id: Optional[int] = Field(
        ...,
        description=
        ('The default class_id to use if class cannot be inferred using other '
         'mechanisms. If a feature has an inferred class_id of None, then it '
         'will be deleted.'))
    class_id_to_filter: Optional[Dict] = Field(
        None,
        description=
        ('Map from class_id to JSON filter used to infer missing class_ids. '
         'Each key should be a class id, and its value should be a boolean '
         'expression which is run against the property field for each feature. '
         'This allows matching different features to different class ids based on '
         'its properties. The expression schema is that described by '
         'https://docs.mapbox.com/mapbox-gl-js/style-spec/other/#other-filter'
         ))
    line_bufs: Optional[Dict[int, Union[int, float, None]]] = Field(
        None,
        description=
        ('This is useful, for example, for buffering lines representing roads so that '
         'their width roughly matches the width of roads in the imagery. If None, uses '
         'default buffer value of 1. Otherwise, a map from class_id to '
         'number of pixels to buffer by. If the buffer value is None, then no buffering '
         'will be performed and the LineString or Point won\'t get converted to a '
         'Polygon. Not converting to Polygon is incompatible with the currently '
         'available LabelSources, but may be useful in the future.'))
    point_bufs: Optional[Dict[int, Union[int, float, None]]] = Field(
        None,
        description=
        'Same as above, but used for buffering Points into Polygons.')

    def has_null_class_bufs(self):
        """Return True if any point or line buffer value is None.

        Per the ``line_bufs`` description, a None buffer means the geometry is
        not buffered into a Polygon.
        """
        for bufs in (self.point_bufs, self.line_bufs):
            if bufs is not None and any(v is None for v in bufs.values()):
                return True
        return False

    def build(self, class_config, crs_transformer):
        """Build the vector source; must be overridden."""
        raise NotImplementedError()

    def update(self, pipeline=None, scene=None):
        """Update hook; a no-op in the base config."""
        pass
class SemanticSegmentationPredictOptions(PredictOptions):
    """PredictOptions with a configurable window stride for segmentation."""
    stride: Optional[int] = Field(
        None,
        description='Stride of windows across image. Allows aggregating multiple '
        'predictions for each pixel if less than the chip size and outputting '
        'smooth labels. Defaults to predict_chip_sz.')
class RegressionGeoDataConfig(RegressionDataConfig, GeoDataConfig):
    """GeoDataConfig for regression that builds regression GeoDatasets."""
    plot_options: Optional[RegressionPlotOptions] = Field(
        RegressionPlotOptions(), description='Options to control plotting.')

    def scene_to_dataset(self,
                         scene: Scene,
                         transform: Optional[A.BasicTransform] = None
                         ) -> Dataset:
        """Create a regression GeoDataset for ``scene``.

        If ``window_opts`` is a dict, the options for this scene are looked up
        by ``scene.id``; otherwise the same options apply to every scene.

        Raises:
            NotImplementedError: if the window method is neither sliding nor
                random.
        """
        all_opts = self.window_opts
        opts = all_opts[scene.id] if isinstance(all_opts, dict) else all_opts

        if opts.method == GeoDataWindowMethod.sliding:
            return RegressionSlidingWindowGeoDataset(
                scene,
                size=opts.size,
                stride=opts.stride,
                padding=opts.padding,
                transform=transform)
        if opts.method == GeoDataWindowMethod.random:
            return RegressionRandomWindowGeoDataset(
                scene,
                size_lims=opts.size_lims,
                h_lims=opts.h_lims,
                w_lims=opts.w_lims,
                out_size=opts.size,
                padding=opts.padding,
                max_windows=opts.max_windows,
                max_sample_attempts=opts.max_sample_attempts,
                transform=transform)
        raise NotImplementedError()
class CastTransformerConfig(RasterTransformerConfig):
    """Config for a raster transformer that casts rasters to ``to_dtype``."""
    to_dtype: str = Field(
        ...,
        description='dtype to cast raster to. Must be a valid Numpy dtype '
        'e.g. "uint8", "float32", etc.')

    def build(self):
        """Return a CastTransformer targeting ``to_dtype``."""
        target_dtype = self.to_dtype
        return CastTransformer(to_dtype=target_dtype)
class ExternalModuleConfig(Config):
    """Config describing an object to be loaded via Torch Hub.

    Exactly one of ``uri`` and ``github_repo`` must be set (checked in
    ``validate_config``).
    """
    uri: Optional[NonEmptyStr] = Field(
        None,
        description=('Local uri of a zip file, or local uri of a directory, '
                     'or remote uri of zip file.'))
    github_repo: Optional[constr(
        strip_whitespace=True, regex=r'.+/.+')] = Field(
            None, description='<repo-owner>/<repo-name>[:tag]')
    name: Optional[NonEmptyStr] = Field(
        None,
        description=
        'Name of the folder in which to extract/copy the definition files.')
    entrypoint: NonEmptyStr = Field(
        ...,
        description=('Name of a callable present in hubconf.py. '
                     'See docs for torch.hub for details.'))
    entrypoint_args: list = Field(
        [],
        description='Args to pass to the entrypoint. Must be serializable.')
    entrypoint_kwargs: dict = Field(
        {},
        description=
        'Keyword args to pass to the entrypoint. Must be serializable.')
    force_reload: bool = Field(
        False, description='Force reload of module definition.')

    def validate_config(self):
        """Require exactly one of ``uri`` and ``github_repo``.

        Raises:
            ConfigError: if both or neither of them are set.
        """
        has_uri = self.uri is not None
        has_repo = self.github_repo is not None
        if has_uri == has_repo:
            raise ConfigError('Must specify one of github_repo and uri.')
class ObjectDetectionChipOptions(Config):
    """Options controlling how object detection training chips are made."""
    neg_ratio: float = Field(
        1.0,
        description='The ratio of negative chips (those containing no bounding '
        'boxes) to positive chips. This can be useful if the statistics '
        'of the background is different in positive chips. For example, '
        'in car detection, the positive chips will always contain roads, '
        'but no examples of rooftops since cars tend to not be near rooftops.')
    ioa_thresh: float = Field(
        0.8,
        description='When a box is partially outside of a training chip, it is not clear if (a '
        'clipped version) of the box should be included in the chip. If the IOA '
        '(intersection over area) of the box with the chip is greater than ioa_thresh, '
        'it is included in the chip.')
    window_method: ObjectDetectionWindowMethod = ObjectDetectionWindowMethod.chip
class CastTransformerConfig(RasterTransformerConfig):
    """Config for a raster transformer that casts rasters to ``to_dtype``."""
    # The default must be a NumPy dtype *name*: 'np.uint8' is not a valid
    # dtype string and would fail when used to cast.
    to_dtype: Optional[str] = Field(
        'uint8',
        description=('dtype to cast raster to. Must be a valid Numpy dtype '
                     'string, e.g. "uint8" or "float32".'))

    def update(self, pipeline=None, scene=None):
        """If to_dtype is None and a pipeline is given, copy ``pipeline.to_dtype``.

        NOTE(review): assumes the pipeline config exposes a ``to_dtype``
        attribute -- confirm against the pipeline configs that use this.
        """
        if pipeline is not None and self.to_dtype is None:
            self.to_dtype = pipeline.to_dtype

    def build(self):
        """Return a CastTransformer targeting ``to_dtype``."""
        return CastTransformer(to_dtype=self.to_dtype)
class EvaluatorConfig(Config):
    """Base config for evaluators whose output is a JSON file."""
    output_uri: Optional[str] = Field(
        None,
        description='URI of JSON output by evaluator. If None, and this Config is part of an '
        'RVPipeline, then this field will be auto-generated.')

    def update(self, pipeline=None):
        """Default output_uri to ``<pipeline.eval_uri>/eval.json`` when unset."""
        if pipeline is None or self.output_uri is not None:
            return
        self.output_uri = join(pipeline.eval_uri, 'eval.json')
class SubRasterSourceConfig(Config):
    """Pairs a RasterSourceConfig with the target indices of its channels."""
    raster_source: RasterSourceConfig = Field(
        ...,
        description=
        'A RasterSourceConfig that will provide a subset of the channels.')
    target_channels: Sequence[conint(ge=0)] = Field(
        ...,
        description='Channel indices to send each of the channels in this '
        'raster source to.')

    @validator('target_channels')
    def non_empty_target_channels(cls, v):
        """Reject an empty target_channels and return it as a list."""
        if not len(v):
            raise ConfigError('target_channels should be non-empty.')
        channels = list(v)
        return channels

    def build(self, tmp_dir, use_transformers=True):
        """Build and return the wrapped raster source."""
        return self.raster_source.build(tmp_dir, use_transformers)
class NanTransformerConfig(RasterTransformerConfig):
    """Config for a raster transformer that replaces NaN values."""
    to_value: Optional[float] = Field(
        0.0, description='Turn all NaN values into this value.')

    def update(self, pipeline=None, scene=None):
        """If to_value is None and a pipeline is given, copy ``pipeline.to_value``."""
        if pipeline is None or self.to_value is not None:
            return
        self.to_value = pipeline.to_value

    def build(self):
        """Return a NanTransformer that fills NaNs with ``to_value``."""
        fill_value = self.to_value
        return NanTransformer(to_value=fill_value)
class GeoDataWindowConfig(Config):
    """Options for sampling windows from a scene (sliding or random).

    ``validate_config`` checks that the fields needed by ``method`` are set.
    """
    method: GeoDataWindowMethod = Field(
        GeoDataWindowMethod.sliding, description='')
    size: Union[PosInt, Tuple[PosInt, PosInt]] = Field(
        ...,
        description='If method = sliding, this is the size of sliding window. '
        'If method = random, this is the size that all the windows are '
        'resized to before they are returned.')
    stride: Optional[Union[PosInt, Tuple[PosInt, PosInt]]] = Field(
        None,
        description='Stride of sliding window. Only used if method = sliding.')
    padding: Optional[Union[NonNegInt, Tuple[NonNegInt, NonNegInt]]] = Field(
        None,
        description='How many pixels are windows allowed to overflow '
        'the edges of the raster source.')
    size_lims: Optional[Tuple[PosInt, PosInt]] = Field(
        None,
        description='[min, max) interval from which window sizes will be '
        'uniformly randomly sampled. The upper limit is exclusive. To fix the '
        'size to a constant value, use size_lims = (sz, sz + 1). '
        'Only used if method = random. Must specify either size_lims or '
        'h and w lims, but not both.')
    h_lims: Optional[Tuple[PosInt, PosInt]] = Field(
        None,
        description='[min, max] interval from which window heights will be '
        'uniformly randomly sampled. Only used if method = random.')
    w_lims: Optional[Tuple[PosInt, PosInt]] = Field(
        None,
        description='[min, max] interval from which window widths will be '
        'uniformly randomly sampled. Only used if method = random.')
    max_windows: NonNegInt = Field(
        10_000,
        description='Max allowed reads from a GeoDataset. Only used if '
        'method = random.')
    max_sample_attempts: PosInt = Field(
        100,
        description='Max attempts when trying to find a window within the AOI '
        'of a scene. Only used if method = random and the scene has '
        'aoi_polygons specified.')

    def validate_config(self):
        """Check that the fields required by ``method`` are consistent.

        Raises:
            ConfigError: if method is sliding and stride is unset; or if method
                is random and not exactly one of size_lims / (h_lims, w_lims)
                is given, or only one of h_lims and w_lims is given.
        """
        if self.method == GeoDataWindowMethod.sliding:
            if self.stride is None:
                raise ConfigError('stride must be specified if using '
                                  'GeoDataWindowMethod.sliding')
            return
        if self.method != GeoDataWindowMethod.random:
            return

        uses_size_lims = self.size_lims is not None
        uses_h_lims = self.h_lims is not None
        uses_w_lims = self.w_lims is not None
        # Exactly one of the two ways of specifying window sizes may be used.
        if uses_size_lims == (uses_w_lims or uses_h_lims):
            raise ConfigError('Specify either size_lims or h and w lims.')
        if uses_h_lims != uses_w_lims:
            raise ConfigError('h_lims and w_lims must both be specified')
class RegressionImageDataConfig(RegressionDataConfig, ImageDataConfig):
    """ImageDataConfig for regression chips (CSV data format by default)."""
    data_format: RegressionDataFormat = RegressionDataFormat.csv
    plot_options: Optional[RegressionPlotOptions] = Field(
        RegressionPlotOptions(), description='Options to control plotting.')

    def dir_to_dataset(self, data_dir: str,
                       transform: A.BasicTransform) -> Dataset:
        """Return a RegressionImageDataset over ``data_dir`` using class_names."""
        return RegressionImageDataset(
            data_dir, self.class_names, transform=transform)