Example #1
    def setup_data(self):
        """Set the the DataSet and DataLoaders for train, validation, and test sets."""
        cfg = self.cfg
        batch_sz = self.cfg.solver.batch_sz
        num_workers = self.cfg.data.num_workers

        train_ds, valid_ds, test_ds = self.get_datasets()
        if len(train_ds) < batch_sz:
            raise ConfigError(
                'Training dataset has fewer elements than batch size.')
        if len(valid_ds) < batch_sz:
            raise ConfigError(
                'Validation dataset has fewer elements than batch size.')
        if len(test_ds) < batch_sz:
            raise ConfigError(
                'Test dataset has fewer elements than batch size.')

        if cfg.overfit_mode:
            train_ds = Subset(train_ds, range(batch_sz))
            valid_ds = train_ds
            test_ds = train_ds
        elif cfg.test_mode:
            train_ds = Subset(train_ds, range(batch_sz))
            valid_ds = Subset(valid_ds, range(batch_sz))
            test_ds = Subset(test_ds, range(batch_sz))

        if cfg.data.train_sz is not None:
            train_inds = list(range(len(train_ds)))
            random.seed(1234)
            random.shuffle(train_inds)
            train_inds = train_inds[0:cfg.data.train_sz]
            train_ds = Subset(train_ds, train_inds)

        collate_fn = self.get_collate_fn()
        train_dl = DataLoader(train_ds,
                              shuffle=True,
                              batch_size=batch_sz,
                              drop_last=True,
                              num_workers=num_workers,
                              pin_memory=True,
                              collate_fn=collate_fn)
        valid_dl = DataLoader(valid_ds,
                              shuffle=True,
                              batch_size=batch_sz,
                              num_workers=num_workers,
                              pin_memory=True,
                              collate_fn=collate_fn)
        test_dl = DataLoader(test_ds,
                             shuffle=True,
                             batch_size=batch_sz,
                             num_workers=num_workers,
                             pin_memory=True,
                             collate_fn=collate_fn)

        self.train_ds, self.valid_ds, self.test_ds = (train_ds, valid_ds,
                                                      test_ds)
        self.train_dl, self.valid_dl, self.test_dl = (train_dl, valid_dl,
                                                      test_dl)
Example #2
    def validate_config(self):
        ids = [s.id for s in self.train_scenes]
        if len(set(ids)) != len(ids):
            raise ConfigError('All training scene ids must be unique.')

        ids = [s.id for s in self.validation_scenes + self.test_scenes]
        if len(set(ids)) != len(ids):
            raise ConfigError(
                'All validation and test scene ids must be unique.')
Example #3
 def validate_solver_config(cls, v):
     if v.class_loss_weights is not None:
         raise ConfigError(
             'class_loss_weights is currently not supported for '
             'Object Detection.')
     if v.external_loss_def is not None:
         raise ConfigError(
             'external_loss_def is currently not supported for '
             'Object Detection.')
     return v
Example #4
    def validate_config(self):
        has_weights = self.class_loss_weights is not None
        has_external_loss_def = self.external_loss_def is not None

        if self.ignore_last_class and has_external_loss_def:
            raise ConfigError(
                'ignore_last_class is not supported with external_loss_def.')

        if has_weights and has_external_loss_def:
            raise ConfigError(
                'class_loss_weights is not supported with external_loss_def.')
Example #5
 def validate_extent_crop(cls, v):
     if v is None:
         return v
     skip_top, skip_left, skip_bottom, skip_right = v
     if skip_top + skip_bottom >= 1:
         raise ConfigError(
             'Invalid crop. skip_top + skip_bottom must be less than 1.')
     if skip_left + skip_right >= 1:
         raise ConfigError(
             'Invalid crop. skip_left + skip_right must be less than 1.')
     return v
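
The tuple unpacked above gives per-side skip amounts (skip_top, skip_left, skip_bottom, skip_right); the checks imply these are fractions of the extent, since opposite sides must sum to less than 1. A few hypothetical values for illustration:

extent_crop = None                      # no crop: passed through unchanged
extent_crop = (0.0, 0.0, 0.5, 0.0)      # keep only the top half: passes both checks
extent_crop = (0.6, 0.0, 0.5, 0.0)      # skip_top + skip_bottom = 1.1 >= 1: ConfigError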
Example #6
 def validate_config(self):
     if self.method == GeoDataWindowMethod.sliding:
         if self.stride is None:
             raise ConfigError('stride must be specified if using '
                               'GeoDataWindowMethod.sliding')
     elif self.method == GeoDataWindowMethod.random:
         has_size_lims = self.size_lims is not None
         has_h_lims = self.h_lims is not None
         has_w_lims = self.w_lims is not None
         if has_size_lims == (has_w_lims or has_h_lims):
             raise ConfigError('Specify either size_lims or h and w lims.')
         if has_h_lims != has_w_lims:
             raise ConfigError('h_lims and w_lims must both be specified')
Example #7
    def validate_config(self):
        has_weights = self.class_loss_weights is not None
        has_external_loss_def = self.external_loss_def is not None

        if self.ignore_last_class is True and has_external_loss_def:
            raise ConfigError(
                'ignore_last_class=True is not supported with external_loss_def. '
                'Please consider using ignore_last_class=\'force\' and setting '
                'the external loss function to ignore the last index.'
            )

        if has_weights and has_external_loss_def:
            raise ConfigError(
                'class_loss_weights is not supported with external_loss_def.')
Example #8
    def validate_channel_mappings(self, channel_mappings: Sequence[int],
                                  raw_channel_order: Sequence[int]):
        # validate completeness of mappings
        src_inds = set(range(len(channel_mappings)))
        tgt_inds = set(channel_mappings)
        if src_inds != tgt_inds:
            raise ConfigError('Missing mappings for some channels.')

        # check compatibility with channel_order, if given
        if self.channel_order:
            if len(self.channel_order) != len(raw_channel_order):
                raise ConfigError(
                    f'raw_channel_order ({raw_channel_order}) and '
                    f'channel_order ({self.channel_order}) must have the same '
                    'length.')
Example #9
    def build_model(self) -> nn.Module:
        # TODO support FCN option
        pretrained = self.cfg.model.pretrained
        out_classes = len(self.cfg.data.class_names)
        if self.cfg.solver.ignore_last_class:
            out_classes -= 1
        model = models.segmentation.segmentation._segm_model(
            'deeplabv3',
            self.cfg.model.get_backbone_str(),
            out_classes,
            False,
            pretrained_backbone=pretrained)

        input_channels = self.cfg.data.img_channels
        old_conv = model.backbone.conv1

        if input_channels == old_conv.in_channels:
            return model

        # these parameters will be the same for the new conv layer
        old_conv_args = {
            'out_channels': old_conv.out_channels,
            'kernel_size': old_conv.kernel_size,
            'stride': old_conv.stride,
            'padding': old_conv.padding,
            'dilation': old_conv.dilation,
            'groups': old_conv.groups,
            # nn.Conv2d expects a bool here, not the bias tensor itself
            'bias': old_conv.bias is not None
        }

        if not pretrained:
            # simply replace the first conv layer with one with the
            # correct number of input channels
            new_conv = nn.Conv2d(in_channels=input_channels, **old_conv_args)
            model.backbone.conv1 = new_conv
            return model

        if input_channels > old_conv.in_channels:
            # insert a new conv layer parallel to the existing one
            # and sum their outputs
            new_conv_channels = input_channels - old_conv.in_channels
            new_conv = nn.Conv2d(
                in_channels=new_conv_channels, **old_conv_args)
            model.backbone.conv1 = nn.Sequential(
                # split input along channel dim
                SplitTensor((old_conv.in_channels, new_conv_channels), dim=1),
                # each split goes to its respective conv layer
                Parallel(old_conv, new_conv),
                # sum the parallel outputs
                AddTensors())
        elif input_channels < old_conv.in_channels:
            model.backbone.conv1 = nn.Conv2d(
                in_channels=input_channels, **old_conv_args)
            model.backbone.conv1.weight.data[:, :input_channels] = \
                old_conv.weight.data[:, :input_channels]
        else:
            # unreachable: the equal-channels case already returned above
            raise ConfigError(
                f'Unexpected number of input channels: {input_channels}.')

        return model
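
The channel-adaptation branch above depends on three helper modules, SplitTensor, Parallel, and AddTensors, whose definitions are not part of this snippet. The following is only a minimal sketch consistent with how they are used here, not necessarily the library's actual implementation:

# Hypothetical sketch of the helper modules; the real Raster Vision versions may differ.
import torch
import torch.nn as nn


class SplitTensor(nn.Module):
    """Split a tensor into chunks of the given sizes along a dimension."""

    def __init__(self, sizes, dim=1):
        super().__init__()
        self.sizes = sizes
        self.dim = dim

    def forward(self, x):
        return torch.split(x, self.sizes, dim=self.dim)


class Parallel(nn.ModuleList):
    """Apply the i-th module to the i-th input and return all outputs."""

    def __init__(self, *modules):
        super().__init__(modules)

    def forward(self, xs):
        return [m(x) for m, x in zip(self, xs)]


class AddTensors(nn.Module):
    """Element-wise sum of a sequence of tensors."""

    def forward(self, xs):
        return sum(xs)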
Example #10
 def validate_config(self):
     super().validate_config()
     if isinstance(self.window_opts, dict):
         scenes = self.scene_dataset.get_all_scenes()
         for s in scenes:
             if s.id not in self.window_opts:
                 raise ConfigError(
                     f'Window config not found for scene {s.id}')
Example #11
def validate_albumentation_transform(tf: dict):
    """ Validate a serialized albumentation transform. """
    if tf is not None:
        try:
            A.from_dict(tf)
        except Exception:
            raise ConfigError('The given serialization is invalid. Use '
                              'A.to_dict(transform) to serialize.')
    return tf
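
For context, the dict validated here is albumentations' serialized form. A brief, hypothetical usage sketch (the chosen transform is arbitrary):

import albumentations as A

# Serialize a transform to a dict, then validate the serialization.
tf_dict = A.to_dict(A.Compose([A.HorizontalFlip(p=0.5)]))
validate_albumentation_transform(tf_dict)  # returns tf_dict unchanged

# A dict that A.from_dict cannot rebuild would raise ConfigError instead.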
Example #12
    def validate_class_loss_weights(self):
        if self.solver.class_loss_weights is None:
            return

        num_weights = len(self.solver.class_loss_weights)
        num_classes = len(self.data.class_names)
        if num_weights != num_classes:
            raise ConfigError(
                f'class_loss_weights ({num_weights}) must be same length as '
                f'the number of classes ({num_classes}), null class included')
Example #13
    def validate_config(self):
        super().validate_config()

        if self.dataset.img_channels is None:
            return

        if self.img_format == 'png' and self.dataset.img_channels != 3:
            raise ConfigError('img_channels must be 3 if img_format is png.')

        self.validate_channel_display_groups()
Example #14
 def validate_solver_config(cls, v):
     if v.class_loss_weights is not None:
         from rastervision.pytorch_backend import (
             PyTorchSemanticSegmentationConfig,
             PyTorchChipClassificationConfig)
         if cls not in (PyTorchSemanticSegmentationConfig,
                        PyTorchChipClassificationConfig):
             raise ConfigError(
                 'class_loss_weights is currently only supported for '
                 'Semantic Segmentation and Chip Classification.')
     return v
Example #15
    def validate_group_uris(self):
        has_group_train_sz = self.group_train_sz is not None
        has_group_train_sz_rel = self.group_train_sz_rel is not None
        has_group_uris = self.group_uris is not None

        if has_group_train_sz and has_group_train_sz_rel:
            raise ConfigError('Only one of group_train_sz and '
                              'group_train_sz_rel should be specified.')
        if has_group_train_sz and not has_group_uris:
            raise ConfigError('group_train_sz specified without group_uris.')
        if has_group_train_sz_rel and not has_group_uris:
            raise ConfigError(
                'group_train_sz_rel specified without group_uris.')
        if has_group_train_sz and sequence_like(self.group_train_sz):
            if len(self.group_train_sz) != len(self.group_uris):
                raise ConfigError('len(group_train_sz) != len(group_uris).')
        if has_group_train_sz_rel and sequence_like(self.group_train_sz_rel):
            if len(self.group_train_sz_rel) != len(self.group_uris):
                raise ConfigError(
                    'len(group_train_sz_rel) != len(group_uris).')
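
For illustration, one hypothetical combination of these fields that passes every check above (the URIs are made up):

group_uris = ['s3://bucket/chips/group0/', 's3://bucket/chips/group1/']
group_train_sz = [100, 250]   # one absolute size per group; same length as group_uris
# Alternatively set group_train_sz_rel = [0.5, 0.25] instead, but never both at once.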
Example #16
 def update(self, pipeline: Optional[RVPipeline] = None):
     super().update(pipeline=pipeline)
     dcfg = self.data
     pcfg = pipeline
     if isinstance(dcfg, SemanticSegmentationImageDataConfig):
         if (dcfg.img_format
                 is not None) and (dcfg.img_format != pcfg.img_format):
             raise ConfigError(
                 'SemanticSegmentationImageDataConfig.img_format is '
                 'specified and not equal to '
                 'SemanticSegmentationConfig.img_format.')
         if (dcfg.label_format
                 is not None) and (dcfg.label_format != pcfg.label_format):
             raise ConfigError(
                 'SemanticSegmentationImageDataConfig.label_format is '
                 'specified and not equal to '
                 'SemanticSegmentationConfig.label_format.')
         if dcfg.img_format is None:
             dcfg.img_format = pcfg.img_format
         if dcfg.label_format is None:
             dcfg.label_format = pcfg.label_format
Example #17
    def validate_channel_display_groups(self):
        def _are_ints(ints) -> bool:
            return all(isinstance(i, int) for i in ints)

        def _in_range(inds, lt: int) -> bool:
            return all(0 <= i < lt for i in inds)

        img_channels = self.dataset.img_channels
        groups = self.channel_display_groups

        # validate dict form
        if isinstance(groups, dict):
            for k, v in groups.items():
                if not isinstance(k, str):
                    raise ConfigError(
                        'channel_display_groups keys must be strings.')
                if not isinstance(v, (list, tuple)):
                    raise ConfigError(
                        'channel_display_groups values must be lists or tuples.'
                    )
                if not (0 < len(v) <= 3):
                    raise ConfigError(
                        f'channel_display_groups[{k}]: len(group) must be 1, 2, or 3'
                    )
                if not (_are_ints(v) and _in_range(v, lt=img_channels)):
                    raise ConfigError(
                        f'Invalid channel indices in channel_display_groups[{k}].'
                    )
        # validate list/tuple form
        elif isinstance(groups, (list, tuple)):
            for i, grp in enumerate(groups):
                if not (0 < len(grp) <= 3):
                    raise ConfigError(
                        f'channel_display_groups[{i}]: len(group) must be 1, 2, or 3'
                    )
                if not (_are_ints(grp) and _in_range(grp, lt=img_channels)):
                    raise ConfigError(
                        f'Invalid channel index in channel_display_groups[{i}].'
                    )
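
Concretely, for a hypothetical 4-channel input both accepted forms might look like this (group names and indices are only examples):

# dict form: display-group name -> 1 to 3 channel indices, each in [0, img_channels)
channel_display_groups = {'RGB': [0, 1, 2], 'IR': [3]}

# list/tuple form: one group of channel indices per entry
channel_display_groups = [(0, 1, 2), (3, )]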
Example #18
 def validate_raster_sources(cls, v):
     if len(v) == 0:
         raise ConfigError('raster_sources should be non-empty.')
     return v
Example #19
 def validate_config(self):
     if self.run_tensorboard and not self.log_tensorboard:
         raise ConfigError(
             'Cannot run_tensorboard if log_tensorboard is False')
Example #20
 def ensure_same_channel_order(self):
     all_scenes = self.train_scenes + self.validation_scenes + self.test_scenes
     channel_orders = [s.raster_source.channel_order for s in all_scenes]
     if not all_equal(channel_orders):
         raise ConfigError('channel_order must be same for all scenes.')
Example #21
 def validate_config(self):
     if self.vector_source.has_null_class_bufs():
         raise ConfigError(
             'Setting buffer to None for a class in the vector_source is '
             'not allowed for ChipClassificationLabelSourceConfig.')
Example #22
 def validate_config(self):
     if self.sample_prob > 1 or self.sample_prob <= 0:
         raise ConfigError('sample_prob must be <= 1 and > 0')
Example #23
 def validate_config(self):
     has_uri = self.uri is not None
     has_repo = self.github_repo is not None
     if has_uri == has_repo:
         raise ConfigError('Must specify one of github_repo and uri.')
Example #24
 def validate_ignore_last_class(self):
     if self.solver.ignore_last_class:
         raise ConfigError(
             'ignore_last_class is not supported for Chip Classification.')
Example #25
 def validate_config(self):
     if self.null_class is not None and self.null_class not in self.names:
         raise ConfigError(
             'The null_class: {} must be in list of class names.'.format(
                 self.null_class))
Example #26
 def validate_config(self):
     if self.train_chip_sz != self.predict_chip_sz:
         raise ConfigError(
             'train_chip_sz must be equal to predict_chip_sz for chip '
             'classification.')
Example #27
 def validate_model_config(cls, v):
     if v.external_def is not None:
         raise ConfigError('external_def is currently not supported for '
                           'Object Detection.')
     return v
Example #28
 def non_empty_target_channels(cls, v):
     if len(v) == 0:
         raise ConfigError('target_channels should be non-empty.')
     return list(v)
 def validate_config(self):
     if self.vector_source.has_null_class_bufs():
         raise ConfigError(
             'Setting buffer to None for a class in the vector_source is '
             'not allowed for RasterizedSourceConfig.')