def setup_data(self):
    """Build the train, validation, and test datasets and DataLoaders.

    The datasets come from ``self.get_datasets()``. Each one must hold at
    least ``cfg.solver.batch_sz`` elements. In overfit mode all three splits
    become the first batch-worth of the training set; in test mode each
    split is truncated to its own first batch-worth. If ``cfg.data.train_sz``
    is set, the training set is further reduced to a seeded random subset of
    at most that many elements.

    The results are stored on ``self`` as ``train_ds``/``valid_ds``/
    ``test_ds`` and ``train_dl``/``valid_dl``/``test_dl``.

    Raises:
        ConfigError: If any dataset has fewer elements than the batch size.
    """
    cfg = self.cfg
    batch_sz = cfg.solver.batch_sz
    num_workers = cfg.data.num_workers

    train_ds, valid_ds, test_ds = self.get_datasets()

    # Every split must be able to fill at least one batch.
    size_checks = (
        (train_ds, 'Training dataset has fewer elements than batch size.'),
        (valid_ds, 'Validation dataset has fewer elements than batch size.'),
        (test_ds, 'Test dataset has fewer elements than batch size.'),
    )
    for dataset, message in size_checks:
        if len(dataset) < batch_sz:
            raise ConfigError(message)

    if cfg.overfit_mode:
        # All three splits share the same single batch of training data.
        train_ds = Subset(train_ds, range(batch_sz))
        valid_ds = test_ds = train_ds
    elif cfg.test_mode:
        train_ds = Subset(train_ds, range(batch_sz))
        valid_ds = Subset(valid_ds, range(batch_sz))
        test_ds = Subset(test_ds, range(batch_sz))

    train_sz = cfg.data.train_sz
    if train_sz is not None:
        # Fixed seed so the same subset is picked on every run.
        # NOTE: this reseeds the global `random` module state.
        shuffled_inds = list(range(len(train_ds)))
        random.seed(1234)
        random.shuffle(shuffled_inds)
        train_ds = Subset(train_ds, shuffled_inds[:train_sz])

    # Options shared by all three loaders; only the training loader
    # additionally drops the last incomplete batch.
    loader_kwargs = dict(
        shuffle=True,
        batch_size=batch_sz,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=self.get_collate_fn())
    train_dl = DataLoader(train_ds, drop_last=True, **loader_kwargs)
    valid_dl = DataLoader(valid_ds, **loader_kwargs)
    test_dl = DataLoader(test_ds, **loader_kwargs)

    self.train_ds = train_ds
    self.valid_ds = valid_ds
    self.test_ds = test_ds
    self.train_dl = train_dl
    self.valid_dl = valid_dl
    self.test_dl = test_dl
def validate_config(self):
    """Check that scene ids are unique within their groups.

    Training scene ids must be unique among themselves. Validation and test
    scene ids are pooled together and must be unique within that pool.

    Raises:
        ConfigError: If either group contains a repeated id.
    """

    def _has_duplicates(scene_ids) -> bool:
        # A set drops repeats, so a shorter set means a repeated id.
        return len(set(scene_ids)) != len(scene_ids)

    train_ids = [scene.id for scene in self.train_scenes]
    if _has_duplicates(train_ids):
        raise ConfigError('All training scene ids must be unique.')

    eval_scenes = self.validation_scenes + self.test_scenes
    eval_ids = [scene.id for scene in eval_scenes]
    if _has_duplicates(eval_ids):
        raise ConfigError(
            'All validation and test scene ids must be unique.')
def validate_solver_config(cls, v):
    """Reject solver options that are not supported for Object Detection.

    Args:
        v: The solver config; its ``class_loss_weights`` and
            ``external_loss_def`` attributes must both be None.

    Returns:
        ``v`` unchanged when it passes validation.

    Raises:
        ConfigError: If either unsupported option is set.
    """
    weights_set = v.class_loss_weights is not None
    if weights_set:
        raise ConfigError(
            'class_loss_weights is currently not supported for '
            'Object Detection.')

    external_loss_set = v.external_loss_def is not None
    if external_loss_set:
        raise ConfigError(
            'external_loss_def is currently not supported for '
            'Object Detection.')

    return v
def validate_config(self):
    """Reject options that cannot be combined with ``external_loss_def``.

    Raises:
        ConfigError: If ``external_loss_def`` is set together with a truthy
            ``ignore_last_class`` or with ``class_loss_weights``.
    """
    weights_given = self.class_loss_weights is not None
    external_loss_given = self.external_loss_def is not None
    ignore_last = self.ignore_last_class

    if ignore_last and external_loss_given:
        raise ConfigError(
            'ignore_last_class is not supported with external_loss_def.')

    if weights_given and external_loss_given:
        raise ConfigError(
            'class_loss_weights is not supported with external_loss_def.')
def validate_extent_crop(cls, v):
    """Validate a crop given as (skip_top, skip_left, skip_bottom, skip_right).

    Args:
        v: None, or a 4-element iterable of skip values. Opposite sides
            must sum to less than 1.

    Returns:
        ``v`` unchanged when it is None or passes validation.

    Raises:
        ConfigError: If top + bottom or left + right is 1 or more.
    """
    if v is None:
        return v

    top, left, bottom, right = v

    vertical_skip = top + bottom
    if vertical_skip >= 1:
        raise ConfigError(
            'Invalid crop. skip_top + skip_bottom must be less than 1.')

    horizontal_skip = left + right
    if horizontal_skip >= 1:
        raise ConfigError(
            'Invalid crop. skip_left + skip_right must be less than 1.')

    return v
def validate_config(self):
    """Check that the options required by the chosen window method are set.

    For the sliding method, ``stride`` must be given. For the random method,
    exactly one of these must be given: ``size_lims`` alone, or both
    ``h_lims`` and ``w_lims``.

    Raises:
        ConfigError: If the options do not match the rules above.
    """
    method = self.method

    if method == GeoDataWindowMethod.sliding:
        if self.stride is None:
            raise ConfigError('stride must be specified if using '
                              'GeoDataWindowMethod.sliding')
        return

    if method == GeoDataWindowMethod.random:
        size_lims_given = self.size_lims is not None
        h_lims_given = self.h_lims is not None
        w_lims_given = self.w_lims is not None

        # size_lims and the (h, w) pair are mutually exclusive, and one of
        # the two forms is required.
        hw_lims_given = w_lims_given or h_lims_given
        if size_lims_given == hw_lims_given:
            raise ConfigError('Specify either size_lims or h and w lims.')

        # h_lims and w_lims only make sense as a pair.
        if h_lims_given != w_lims_given:
            raise ConfigError('h_lims and w_lims must both be specified')
def validate_config(self):
    """Reject options that cannot be combined with ``external_loss_def``.

    Only ``ignore_last_class`` being exactly ``True`` conflicts with an
    external loss definition; other values (such as ``'force'``) pass this
    check.

    Raises:
        ConfigError: If ``external_loss_def`` is set together with
            ``ignore_last_class=True`` or with ``class_loss_weights``.
    """
    weights_given = self.class_loss_weights is not None
    external_loss_given = self.external_loss_def is not None
    ignore_last_is_true = self.ignore_last_class is True

    if ignore_last_is_true and external_loss_given:
        raise ConfigError(
            'ignore_last_class=True is not supported with external_loss_def. '
            'Please carefully considering using ignore_last_class=\'force\' '
            'and setting the external loss function to ignore the last index.'
        )

    if weights_given and external_loss_given:
        raise ConfigError(
            'class_loss_weights is not supported with external_loss_def.')
def validate_channel_mappings(self, channel_mappings: Sequence[int],
                              raw_channel_order: Sequence[int]):
    """Check that the channel mappings are complete and fit channel_order.

    Args:
        channel_mappings: Target index for each source channel. Its values
            must be exactly the set ``0 .. len(channel_mappings) - 1``.
        raw_channel_order: Its length is compared against
            ``self.channel_order`` when that is set and non-empty.

    Raises:
        ConfigError: If some index is missing from the mappings, or if
            ``self.channel_order`` and ``raw_channel_order`` differ in
            length.
    """
    # Completeness: every index in 0..n-1 must appear as a target.
    expected_inds = set(range(len(channel_mappings)))
    mapped_inds = set(channel_mappings)
    if expected_inds != mapped_inds:
        raise ConfigError('Missing mappings for some channels.')

    # Compatibility with channel_order, if one was given.
    if not self.channel_order:
        return
    if len(self.channel_order) != len(raw_channel_order):
        raise ConfigError(
            f'Channel mappings ({raw_channel_order}) and '
            f'channel_order ({self.channel_order}) are incompatible.')
def build_model(self) -> nn.Module:
    """Build a DeepLabV3 model whose first conv layer fits the input channels.

    The number of output classes is ``len(cfg.data.class_names)``, minus one
    if ``cfg.solver.ignore_last_class`` is truthy. If
    ``cfg.data.img_channels`` differs from the backbone's first conv layer's
    ``in_channels``, that layer is adapted:

    - Not pretrained: it is replaced by a fresh conv layer with the right
      number of input channels.
    - Pretrained, more input channels: a new conv layer for the extra
      channels is run in parallel with the existing one and their outputs
      are summed.
    - Pretrained, fewer input channels: a new conv layer is created and the
      existing weights for the first ``img_channels`` channels are copied
      into it.

    Returns:
        The constructed model.
    """
    # TODO support FCN option
    pretrained = self.cfg.model.pretrained
    out_classes = len(self.cfg.data.class_names)
    if self.cfg.solver.ignore_last_class:
        out_classes -= 1

    # NOTE(review): the 4th positional arg (False) is presumably torchvision's
    # aux_loss flag -- confirm against the installed torchvision version.
    model = models.segmentation.segmentation._segm_model(
        'deeplabv3',
        self.cfg.model.get_backbone_str(),
        out_classes,
        False,
        pretrained_backbone=pretrained)

    input_channels = self.cfg.data.img_channels
    old_conv = model.backbone.conv1

    if input_channels == old_conv.in_channels:
        return model

    # these parameters will be the same for the new conv layer
    old_conv_args = {
        'out_channels': old_conv.out_channels,
        'kernel_size': old_conv.kernel_size,
        'stride': old_conv.stride,
        'padding': old_conv.padding,
        'dilation': old_conv.dilation,
        'groups': old_conv.groups,
        # nn.Conv2d expects a bool here. Passing the bias Parameter itself
        # would fail on tensor truthiness when the old layer has a bias.
        'bias': old_conv.bias is not None
    }

    if not pretrained:
        # simply replace the first conv layer with one with the
        # correct number of input channels
        model.backbone.conv1 = nn.Conv2d(
            in_channels=input_channels, **old_conv_args)
        return model

    if input_channels > old_conv.in_channels:
        # insert a new conv layer parallel to the existing one
        # and sum their outputs
        new_conv_channels = input_channels - old_conv.in_channels
        new_conv = nn.Conv2d(in_channels=new_conv_channels, **old_conv_args)
        model.backbone.conv1 = nn.Sequential(
            # split input along channel dim
            SplitTensor((old_conv.in_channels, new_conv_channels), dim=1),
            # each split goes to its respective conv layer
            Parallel(old_conv, new_conv),
            # sum the parallel outputs
            AddTensors())
    else:
        # input_channels < old_conv.in_channels (equality returned above):
        # keep the pretrained weights for the channels that remain
        new_conv = nn.Conv2d(in_channels=input_channels, **old_conv_args)
        new_conv.weight.data[:, :input_channels] = \
            old_conv.weight.data[:, :input_channels]
        model.backbone.conv1 = new_conv

    return model
def validate_config(self, *args, **kwargs):
    """Check that every scene has a window config when window_opts is a dict.

    When ``self.window_opts`` is a dict, it must contain the id of every
    scene returned by ``self.scene_dataset.get_all_scenes()``.

    Raises:
        ConfigError: For the first scene whose id is not a key in
            ``window_opts``.
    """
    # NOTE(review): this calls the parent's update(), not validate_config();
    # presumably intentional -- confirm against the parent class.
    super().update(*args, **kwargs)

    if not isinstance(self.window_opts, dict):
        return

    for scene in self.scene_dataset.get_all_scenes():
        if scene.id in self.window_opts:
            continue
        raise ConfigError(f'Window config not found for scene {scene.id}')
def validate_albumentation_transform(tf: dict):
    """Validate a serialized albumentation transform.

    Args:
        tf: The serialized transform, or None to skip validation.

    Returns:
        ``tf`` unchanged.

    Raises:
        ConfigError: If ``A.from_dict(tf)`` fails for any reason. The
            underlying exception is attached as the cause.
    """
    if tf is None:
        return tf

    try:
        A.from_dict(tf)
    except Exception as e:
        # Chain explicitly so the traceback shows why deserialization failed.
        raise ConfigError('The given serialization is invalid. Use '
                          'A.to_dict(transform) to serialize.') from e
    return tf
def validate_class_loss_weights(self):
    """Check that there is one class loss weight per class, if any are set.

    Does nothing when ``self.solver.class_loss_weights`` is None.

    Raises:
        ConfigError: If the number of weights differs from
            ``len(self.data.class_names)``.
    """
    weights = self.solver.class_loss_weights
    if weights is None:
        return

    num_weights = len(weights)
    num_classes = len(self.data.class_names)
    if num_weights == num_classes:
        return

    raise ConfigError(
        f'class_loss_weights ({num_weights}) must be same length as '
        f'the number of classes ({num_classes}), null class included')
def validate_config(self):
    """Run the parent validation, then validate channel-related options.

    Nothing further is checked when ``self.dataset.img_channels`` is None.
    Otherwise PNG output requires exactly 3 channels, and the channel
    display groups are validated.

    Raises:
        ConfigError: If ``img_format`` is ``'png'`` and ``img_channels``
            is not 3.
    """
    super().validate_config()

    num_channels = self.dataset.img_channels
    if num_channels is None:
        return

    is_png = self.img_format == 'png'
    if is_png and num_channels != 3:
        raise ConfigError('img_channels must be 3 if img_format is png.')

    self.validate_channel_display_groups()
def validate_solver_config(cls, v):
    """Allow ``class_loss_weights`` only on the config classes that support it.

    Args:
        v: The solver config.

    Returns:
        ``v`` unchanged when it passes validation.

    Raises:
        ConfigError: If ``v.class_loss_weights`` is set and ``cls`` is
            neither ``PyTorchSemanticSegmentationConfig`` nor
            ``PyTorchChipClassificationConfig``.
    """
    if v.class_loss_weights is None:
        return v

    # Imported only when needed, as in the original.
    from rastervision.pytorch_backend import (
        PyTorchSemanticSegmentationConfig, PyTorchChipClassificationConfig)

    supported_configs = (PyTorchSemanticSegmentationConfig,
                         PyTorchChipClassificationConfig)
    if cls in supported_configs:
        return v

    raise ConfigError(
        'class_loss_weights is currently only supported for '
        'Semantic Segmentation and Chip Classification.')
def validate_group_uris(self):
    """Check that the group size options are consistent with ``group_uris``.

    At most one of ``group_train_sz`` and ``group_train_sz_rel`` may be set,
    either one requires ``group_uris``, and a sequence-like value must have
    the same length as ``group_uris``.

    Raises:
        ConfigError: If any of the rules above is violated.
    """
    train_sz = self.group_train_sz
    train_sz_rel = self.group_train_sz_rel
    uris = self.group_uris

    sz_given = train_sz is not None
    sz_rel_given = train_sz_rel is not None

    if sz_given and sz_rel_given:
        raise ConfigError('Only one of group_train_sz and '
                          'group_train_sz_rel should be specified.')

    if uris is None:
        # Without group_uris, neither size option is meaningful.
        if sz_given:
            raise ConfigError('group_train_sz specified without group_uris.')
        if sz_rel_given:
            raise ConfigError(
                'group_train_sz_rel specified without group_uris.')
        return

    # Per-group sizes must line up one-to-one with the groups.
    if sz_given and sequence_like(train_sz):
        if len(train_sz) != len(uris):
            raise ConfigError('len(group_train_sz) != len(group_uris).')

    if sz_rel_given and sequence_like(train_sz_rel):
        if len(train_sz_rel) != len(uris):
            raise ConfigError(
                'len(group_train_sz_rel) != len(group_uris).')
def update(self, pipeline: Optional[RVPipeline] = None):
    """Update this config and sync image/label formats with the pipeline.

    After the parent update, if ``self.data`` is a
    ``SemanticSegmentationImageDataConfig``, any ``img_format`` or
    ``label_format`` it specifies must equal the pipeline's. Formats left
    as None are filled in from the pipeline.

    Args:
        pipeline: The pipeline config whose ``img_format`` and
            ``label_format`` are used.

    Raises:
        ConfigError: If a format set on the data config differs from the
            pipeline's.
    """
    super().update(pipeline=pipeline)

    data_cfg = self.data
    if not isinstance(data_cfg, SemanticSegmentationImageDataConfig):
        return

    # NOTE(review): pipeline defaults to None but is dereferenced below;
    # presumably callers always pass it for this data config -- confirm.
    img_format_conflict = (data_cfg.img_format is not None
                           and data_cfg.img_format != pipeline.img_format)
    if img_format_conflict:
        raise ConfigError(
            'SemanticSegmentationImageDataConfig.img_format is '
            'specified and not equal to '
            'SemanticSegmentationConfig.img_format.')

    label_format_conflict = (
        data_cfg.label_format is not None
        and data_cfg.label_format != pipeline.label_format)
    if label_format_conflict:
        raise ConfigError(
            'SemanticSegmentationImageDataConfig.label_format is '
            'specified and not equal to '
            'SemanticSegmentationConfig.label_format.')

    # Inherit whatever the data config left unspecified.
    if data_cfg.img_format is None:
        data_cfg.img_format = pipeline.img_format
    if data_cfg.label_format is None:
        data_cfg.label_format = pipeline.label_format
def validate_channel_display_groups(self):
    """Validate the structure and contents of ``channel_display_groups``.

    The groups may be a dict (string keys, list/tuple values) or a
    list/tuple of groups. Every group must have 1 to 3 entries, each an int
    in ``[0, self.dataset.img_channels)``. Any other type of
    ``channel_display_groups`` is not checked.

    Raises:
        ConfigError: If a key, a group type (dict form only), a group
            length, or a channel index is invalid.
    """
    num_channels = self.dataset.img_channels
    groups = self.channel_display_groups

    def _size_ok(group) -> bool:
        return 0 < len(group) <= 3

    def _indices_ok(group) -> bool:
        # Type check first so the range comparison only sees ints.
        if not all(isinstance(ch, int) for ch in group):
            return False
        return all(0 <= ch < num_channels for ch in group)

    if isinstance(groups, dict):
        for name, group in groups.items():
            if not isinstance(name, str):
                raise ConfigError(
                    'channel_display_groups keys must be strings.')
            if not isinstance(group, (list, tuple)):
                raise ConfigError(
                    'channel_display_groups values must be lists or tuples.'
                )
            if not _size_ok(group):
                raise ConfigError(
                    f'channel_display_groups[{name}]: len(group) must be 1, 2, or 3'
                )
            if not _indices_ok(group):
                raise ConfigError(
                    f'Invalid channel indices in channel_display_groups[{name}].'
                )
    elif isinstance(groups, (list, tuple)):
        for pos, group in enumerate(groups):
            if not _size_ok(group):
                raise ConfigError(
                    f'channel_display_groups[{pos}]: len(group) must be 1, 2, or 3'
                )
            if not _indices_ok(group):
                raise ConfigError(
                    f'Invalid channel index in channel_display_groups[{pos}].'
                )
def validate_raster_sources(cls, v):
    """Require at least one raster source.

    Args:
        v: The collection of raster sources.

    Returns:
        ``v`` unchanged when it is non-empty.

    Raises:
        ConfigError: If ``v`` has length 0.
    """
    num_sources = len(v)
    if num_sources > 0:
        return v
    raise ConfigError('raster_sources should be non-empty.')
def validate_config(self):
    """Check that TensorBoard is only run when TensorBoard logging is on.

    Raises:
        ConfigError: If ``run_tensorboard`` is truthy while
            ``log_tensorboard`` is falsy.
    """
    if not self.run_tensorboard:
        return
    if self.log_tensorboard:
        return
    raise ConfigError('Cannot run_tensorboard if log_tensorboard is False')
def ensure_same_channel_order(self):
    """Check that all scenes share the same raster source channel_order.

    Covers the training, validation, and test scenes together.

    Raises:
        ConfigError: If ``all_equal`` reports that the channel orders
            differ.
    """
    scenes = self.train_scenes + self.validation_scenes + self.test_scenes
    orders = [scene.raster_source.channel_order for scene in scenes]
    if all_equal(orders):
        return
    raise ConfigError('channel_order must be same for all scenes.')
def validate_config(self):
    """Reject vector sources that have a class buffer set to None.

    Raises:
        ConfigError: If ``self.vector_source.has_null_class_bufs()`` is
            truthy.
    """
    null_bufs_present = self.vector_source.has_null_class_bufs()
    if not null_bufs_present:
        return
    raise ConfigError(
        'Setting buffer to None for a class in the vector_source is '
        'not allowed for ChipClassificationLabelSourceConfig.')
def validate_config(self):
    """Check that ``sample_prob`` lies in the interval (0, 1].

    Raises:
        ConfigError: If ``sample_prob`` is <= 0, > 1, or NaN.
    """
    prob = self.sample_prob
    # Written as a positive range check so that NaN, for which every
    # comparison is False, is rejected instead of slipping through.
    if not (0 < prob <= 1):
        raise ConfigError('sample_prob must be <= 1 and > 0')
def validate_config(self):
    """Check that exactly one of ``uri`` and ``github_repo`` is set.

    Raises:
        ConfigError: If both are None or both are set.
    """
    sources = (self.uri, self.github_repo)
    num_given = sum(source is not None for source in sources)
    if num_given != 1:
        raise ConfigError('Must specify one of github_repo and uri.')
def validate_ignore_last_class(self):
    """Reject ``ignore_last_class`` for Chip Classification.

    Raises:
        ConfigError: If ``self.solver.ignore_last_class`` is truthy.
    """
    ignore_last = self.solver.ignore_last_class
    if not ignore_last:
        return
    raise ConfigError(
        'ignore_last_class is not supported for Chip Classification.')
def validate_config(self):
    """Check that ``null_class``, if set, is one of the class names.

    Raises:
        ConfigError: If ``null_class`` is not None and not in ``names``.
    """
    null_class = self.null_class
    if null_class is None or null_class in self.names:
        return
    raise ConfigError(
        f'The null_class: {null_class} must be in list of class names.')
def validate_config(self):
    """Check that training and prediction use the same chip size.

    Raises:
        ConfigError: If ``train_chip_sz`` differs from ``predict_chip_sz``.
    """
    train_sz = self.train_chip_sz
    predict_sz = self.predict_chip_sz
    sizes_differ = train_sz != predict_sz
    if sizes_differ:
        raise ConfigError(
            'train_chip_sz must be equal to predict_chip_sz for chip '
            'classification.')
def validate_model_config(cls, v):
    """Reject ``external_def`` on model configs for Object Detection.

    Args:
        v: The model config.

    Returns:
        ``v`` unchanged when ``v.external_def`` is None.

    Raises:
        ConfigError: If ``v.external_def`` is set.
    """
    if v.external_def is None:
        return v
    raise ConfigError('external_def is currently not supported for '
                      'Object Detection.')
def non_empty_target_channels(cls, v):
    """Require at least one target channel and return them as a list.

    Args:
        v: The collection of target channels.

    Returns:
        A new list containing the elements of ``v``.

    Raises:
        ConfigError: If ``v`` has length 0.
    """
    if len(v) > 0:
        return list(v)
    raise ConfigError('target_channels should be non-empty.')
def validate_config(self):
    """Reject vector sources that have a class buffer set to None.

    Raises:
        ConfigError: If ``self.vector_source.has_null_class_bufs()`` is
            truthy.
    """
    null_bufs_present = self.vector_source.has_null_class_bufs()
    if not null_bufs_present:
        return
    raise ConfigError(
        'Setting buffer to None for a class in the vector_source is '
        'not allowed for RasterizedSourceConfig.')