def from_model(cls, emd_path, data=None):
    """
    Creates a model object from an Esri Model Definition (EMD) file.

    =====================   ===========================================
    **Argument**            **Description**
    ---------------------   -------------------------------------------
    emd_path                Required string or path. Path to the EMD
                            (JSON) file. A relative ``ModelFile`` entry
                            in the EMD is resolved against the folder
                            that contains the EMD file.
    ---------------------   -------------------------------------------
    data                    Optional data object. Passed to ``cls``
                            unchanged when given. When ``None``, a
                            placeholder ``ImageDataBunch`` is built
                            from the class names listed in the EMD.
    =====================   ===========================================

    :returns: the object produced by calling ``cls``
    """
    emd_path = Path(emd_path)
    with open(emd_path) as emd_file:
        definition = json.load(emd_file)

    weights_path = Path(definition['ModelFile'])
    if not weights_path.is_absolute():
        weights_path = emd_path.parent / weights_path

    init_kwargs = definition['ModelParameters']
    chip_size = definition["ImageWidth"]

    # The EMD class entries use either Value/Name or ClassValue/ClassName
    # keys. The colour lookup is not used below, but it is still evaluated
    # so that a malformed EMD fails here exactly as before.
    try:
        class_mapping = dict(
            (entry['Value'], entry['Name']) for entry in definition['Classes'])
        color_mapping = dict(
            (entry['Value'], entry['Color']) for entry in definition['Classes'])
    except KeyError:
        class_mapping = dict(
            (entry['ClassValue'], entry['ClassName'])
            for entry in definition['Classes'])
        color_mapping = dict(
            (entry['ClassValue'], entry['Color'])
            for entry in definition['Classes'])

    # Caller supplied the data: nothing else to prepare.
    if data is not None:
        return cls(data, **init_kwargs, pretrained_path=str(weights_path))

    full_range = (0, 1)
    training_tfms = [
        rotate(degrees=30, p=0.5),
        crop(size=chip_size, p=1., row_pct=full_range, col_pct=full_range),
        dihedral_affine(),
        brightness(change=(0.4, 0.6)),
        contrast(scale=(0.75, 1.5)),
    ]
    validation_tfms = [
        crop(size=chip_size, p=1.0, row_pct=0.5, col_pct=0.5),
    ]
    tfms_pair = (training_tfms, validation_tfms)
    class_names = sorted(class_mapping.values())

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", UserWarning)
        # Only the *name* of a throw-away temporary directory is used as
        # the data bunch path.
        placeholder = ImageDataBunch.single_from_classes(
            tempfile.TemporaryDirectory().name,
            class_names,
            tfms=tfms_pair,
            size=chip_size)
        tempdata = placeholder.normalize(imagenet_stats)
        tempdata.chip_size = chip_size
        return cls(tempdata, **init_kwargs, pretrained_path=str(weights_path))
def prepare_data(path,
                 class_mapping=None,
                 chip_size=224,
                 val_split_pct=0.1,
                 batch_size=64,
                 transforms=None,
                 collate_fn=_bb_pad_collate,
                 seed=42,
                 dataset_type=None,
                 resize_to=None,
                 **kwargs):
    """
    Prepares a data object from training sample exported by the
    Export Training Data tool in ArcGIS Pro or Image Server, or training
    samples in the supported dataset formats. This data object consists of
    training and validation data sets with the specified transformations,
    chip size, batch size, split percentage, etc.

    - For object detection, use Pascal_VOC_rectangles format.
    - For feature categorization use Labelled Tiles or ImageNet format.
    - For pixel classification, use Classified Tiles format.
    - For entity extraction from text, use IOB, BILUO or ner_json formats.

    =====================   ===========================================
    **Argument**            **Description**
    ---------------------   -------------------------------------------
    path                    Required string. Path to data directory.
    ---------------------   -------------------------------------------
    class_mapping           Optional dictionary. Mapping from id to
                            its string label.
                            For dataset_type=IOB, BILUO or ner_json:
                            Provide address field as class mapping
                            in below format:
                            class_mapping={'address_tag':'address_field'}
    ---------------------   -------------------------------------------
    chip_size               Optional integer. Size of the image to train the
                            model.
    ---------------------   -------------------------------------------
    val_split_pct           Optional float. Percentage of training data to keep
                            as validation.
    ---------------------   -------------------------------------------
    batch_size              Optional integer. Batch size for mini batch gradient
                            descent (Reduce it if getting CUDA Out of Memory
                            Errors).
    ---------------------   -------------------------------------------
    transforms              Optional tuple. Fast.ai transforms for data
                            augmentation of training and validation datasets
                            respectively (We have set good defaults which work
                            for satellite imagery well). If transforms is set
                            to `False` no transformation will take place and
                            `chip_size` parameter will also not take effect.
    ---------------------   -------------------------------------------
    collate_fn              Optional function. Passed to PyTorch to collate data
                            into batches(usually default works).
    ---------------------   -------------------------------------------
    seed                    Optional integer. Random seed for reproducible
                            train-validation split.
    ---------------------   -------------------------------------------
    dataset_type            Optional string. `prepare_data` function will infer
                            the `dataset_type` on its own if it contains a
                            map.txt file. If the path does not contain the
                            map.txt file pass either of 'PASCAL_VOC_rectangles',
                            'RCNN_Masks' and 'Classified_Tiles'. Note that
                            'RCNN_Masks' and 'Classified_Tiles' take the label
                            file extension from map.txt and raise an error
                            when the Esri metadata files are missing.
    ---------------------   -------------------------------------------
    resize_to               Optional integer. Resize the image to given size.
    =====================   ===========================================

    :returns: data object
    """
    # Keyword arguments read from ``kwargs`` (not part of the public table):
    #   imagery_type='RGB'   # Change to known imagery_type or anything else
    #                        # to trigger multispectral
    #   bands=None           # specify bands type for unknown imagery,
    #                        # e.g. ['r', 'g', 'b', 'nir']
    #   rgb_bands=[0, 1, 2]  # specify rgb bands indices for unknown imagery
    #   norm_pct=0.3         # sample of images to calculate normalization
    #                        # stats on (clamped to the 0..1 range)
    #   do_normalize=True    # Normalize data
    #   alpha=0.7            # alpha value appended to every class colour
    #   extract_bands=None   # band indices to keep (see below)
    #   train_tail=...       # overrides the computed ``_train_tail`` flag

    height_width = []

    if not HAS_FASTAI:
        _raise_fastai_import_error()

    if isinstance(path, str) and not os.path.exists(path):
        raise Exception("Invalid input path.")

    if isinstance(path, str):
        path = Path(path)

    databunch_kwargs = {'num_workers': 0} if sys.platform == 'win32' else {}
    databunch_kwargs['bs'] = batch_size

    kwargs_transforms = {}
    if resize_to:
        kwargs_transforms['size'] = resize_to

    has_esri_files = _check_esri_files(path)
    alter_class_mapping = False
    color_mapping = None

    # Only populated when the Esri metadata files are actually read below.
    # They start as None so later code can test them instead of hitting an
    # unbound local.
    emd = None
    right = None  # extension of the label files, taken from map.txt

    # Multispectral Kwargs init
    _bands = None
    _imagery_type = None
    _is_multispectral = False
    _show_batch_multispectral = None

    if dataset_type is None and not has_esri_files:
        raise Exception("Could not infer dataset type.")

    if dataset_type != "Imagenet" and has_esri_files:
        stats_file = path / 'esri_accumulated_stats.json'
        with open(stats_file) as f:
            stats = json.load(f)
        # The exported metadata takes precedence over a user supplied value.
        dataset_type = stats['MetaDataMode']

        with open(path / 'map.txt') as f:
            line = f.readline()
        # Second whitespace-separated field of the first line is the label
        # file; keep its lower-cased extension.
        right = line.split()[1].split('.')[-1].lower()

        json_file = path / 'esri_model_definition.emd'
        with open(json_file) as f:
            emd = json.load(f)

        # Create Class Mapping from EMD if not specified by user
        ## Validate user defined class_mapping keys with emd (issue #3064)
        # Get classmapping from emd file.
        try:
            emd_class_mapping = {i['Value']: i['Name'] for i in emd['Classes']}
        except KeyError:
            emd_class_mapping = {
                i['ClassValue']: i['ClassName']
                for i in emd['Classes']
            }

        ## Change all keys to int.
        if class_mapping is not None:
            class_mapping = {
                int(key): value
                for key, value in class_mapping.items()
            }
        else:
            class_mapping = {}

        ## Map values from user defined classmapping to emd classmapping.
        for key in emd_class_mapping:
            if class_mapping.get(key) is not None:
                emd_class_mapping[key] = class_mapping[key]

        class_mapping = emd_class_mapping

        color_mapping = {
            (i.get('Value', 0) or i.get('ClassValue', 0)): i['Color']
            for i in emd.get('Classes', [])
        }

        if color_mapping.get(None):
            del color_mapping[None]

        if class_mapping.get(None):
            del class_mapping[None]

        # Multispectral support from EMD
        # Not Implemented Yet
        # (``emd.get[...]`` subscripted the bound method and raised
        # TypeError; plain key access is what was meant.)
        if emd.get('bands') is not None:
            _bands = emd['bands']  # Not Implemented

        if emd.get('imagery_type') is not None:
            _imagery_type = emd['imagery_type']  # Not Implemented

    elif dataset_type == 'PASCAL_VOC_rectangles' and not has_esri_files:
        if class_mapping is None:
            class_mapping = _get_class_mapping(path / 'labels')
            alter_class_mapping = True

    # Multispectral check
    imagery_type = 'RGB'
    if kwargs.get('imagery_type') is not None:
        imagery_type = kwargs['imagery_type']
    elif _imagery_type is not None:
        imagery_type = _imagery_type

    bands = None
    if kwargs.get('bands') is not None:
        # Work on a copy so the caller's list is not modified.
        bands = [
            b.lower() if isinstance(b, str) else b
            for b in kwargs['bands']
        ]
    elif imagery_type_lib.get(imagery_type) is not None:
        bands = imagery_type_lib.get(imagery_type)['bands']
    elif _bands is not None:
        bands = _bands

    rgb_bands = None
    if kwargs.get('rgb_bands') is not None:
        rgb_bands = kwargs['rgb_bands']
    elif bands is not None:
        rgb_bands = [bands.index(b) for b in ['r', 'g', 'b'] if b in bands]

    if (bands is not None) or (rgb_bands is not None) or (imagery_type != 'RGB'):
        if imagery_type == 'RGB':
            imagery_type = 'multispectral'
        _is_multispectral = True

    if kwargs.get('norm_pct') is not None:
        norm_pct = min(max(0, kwargs['norm_pct']), 1)
    else:
        norm_pct = .3

    # Both segmentation style formats locate label files through the
    # extension read from map.txt and use the colour mapping from the EMD.
    # Without the Esri metadata files neither exists, so fail clearly here.
    if dataset_type in ('RCNN_Masks', 'Classified_Tiles') and right is None:
        raise Exception(
            'dataset_type="{}" requires the Esri metadata files '
            '(map.txt, esri_model_definition.emd, '
            'esri_accumulated_stats.json) in the data directory.'.format(
                dataset_type))

    if dataset_type == 'RCNN_Masks':

        def get_labels(x, label_dirs, ext=right):
            # One candidate mask file per class folder; keep those on disk.
            label_path = []
            for lbl in label_dirs:
                candidate = Path(lbl) / (x.stem + '.{}'.format(ext))
                if os.path.exists(candidate):
                    label_path.append(candidate)
            return label_path

        if class_mapping.get(0):
            del class_mapping[0]

        if color_mapping.get(0):
            del color_mapping[0]

        # Handle Multispectral
        if _is_multispectral:
            src = (ArcGISInstanceSegmentationMSItemList.from_folder(
                path / 'images').split_by_rand_pct(val_split_pct, seed=seed))
            _show_batch_multispectral = show_batch_rcnn_masks
        else:
            src = (ArcGISInstanceSegmentationItemList.from_folder(
                path / 'images').split_by_rand_pct(val_split_pct, seed=seed))

        label_dirs = []
        index_dir = {}  # for handling class value with any number
        for i, k in enumerate(sorted(class_mapping.keys())):
            label_dirs.append(class_mapping[k])
            index_dir[k] = i + 1

        label_dir = [
            os.path.join(path / 'labels', lbl) for lbl in label_dirs
            if os.path.isdir(os.path.join(path / 'labels', lbl))
        ]
        get_y_func = partial(get_labels, label_dirs=label_dir)
        src = src.label_from_func(get_y_func,
                                  chip_size=chip_size,
                                  classes=['NoData'] +
                                  list(class_mapping.values()),
                                  class_mapping=class_mapping,
                                  color_mapping=color_mapping,
                                  index_dir=index_dir)

    elif dataset_type == 'Classified_Tiles':

        def get_y_func(x, ext=right):
            return x.parents[1] / 'labels' / (x.stem + '.{}'.format(ext))

        if class_mapping.get(0):
            del class_mapping[0]

        if color_mapping.get(0):
            del color_mapping[0]

        if is_no_color(color_mapping):
            color_mapping = {
                j: [random.choice(range(256)) for i in range(3)]
                for j in class_mapping.keys()
            }

        # TODO : Handle NoData case

        # Handle Multispectral
        if _is_multispectral:
            data = ArcGISSegmentationMSItemList.from_folder(path / 'images')\
                .split_by_rand_pct(val_split_pct, seed=seed)\
                .label_from_func(
                    get_y_func, classes=(['NoData'] + list(class_mapping.values())),
                    class_mapping=class_mapping,
                    color_mapping=color_mapping
                )
            _show_batch_multispectral = _show_batch_unet_multispectral

            def classified_tiles_collate_fn(samples):
                # The default fastai collate_fn was causing memory leak on tensors
                r = (torch.stack([x[0].data for x in samples]),
                     torch.stack([x[1].data for x in samples]))
                return r

            databunch_kwargs['collate_fn'] = classified_tiles_collate_fn

        else:
            data = ArcGISSegmentationItemList.from_folder(path / 'images')\
                .split_by_rand_pct(val_split_pct, seed=seed)\
                .label_from_func(
                    get_y_func, classes=(['NoData'] + list(class_mapping.values())),
                    class_mapping=class_mapping,
                    color_mapping=color_mapping
                )

        if transforms is None:
            transforms = get_transforms(flip_vert=True,
                                        max_rotate=90.,
                                        max_zoom=3.0,
                                        max_lighting=0.5)

        kwargs_transforms['tfm_y'] = True
        kwargs_transforms['size'] = chip_size

    elif dataset_type == 'PASCAL_VOC_rectangles':
        # One-element list so the label function can report back how many
        # images had no label file.
        not_label_count = [0]
        get_y_func = partial(_get_bbox_lbls,
                             class_mapping=class_mapping,
                             not_label_count=not_label_count,
                             height_width=height_width)

        if _is_multispectral:
            data = SSDObjectMSItemList.from_folder(path / 'images')\
                .split_by_rand_pct(val_split_pct, seed=seed)\
                .label_from_func(get_y_func)
            _show_batch_multispectral = show_batch_pascal_voc_rectangles
        else:
            data = SSDObjectItemList.from_folder(path / 'images')\
                .split_by_rand_pct(val_split_pct, seed=seed)\
                .label_from_func(get_y_func)

        if not_label_count[0]:
            logger = logging.getLogger()
            logger.warning("Please check your dataset. " +
                           str(not_label_count[0]) +
                           " images dont have the corresponding label files.")

        if transforms is None:
            ranges = (0, 1)
            train_tfms = [
                crop(size=chip_size, p=1., row_pct=ranges, col_pct=ranges),
                dihedral_affine() if has_esri_files else flip_lr(),
                brightness(change=(0.4, 0.6)),
                contrast(scale=(0.75, 1.5)),
                rand_zoom(scale=(1.0, 1.5))
            ]
            val_tfms = [crop(size=chip_size, p=1., row_pct=0.5, col_pct=0.5)]
            transforms = (train_tfms, val_tfms)

        kwargs_transforms['tfm_y'] = True
        databunch_kwargs['collate_fn'] = collate_fn

    elif dataset_type in ['Labeled_Tiles', 'Imagenet']:
        if dataset_type == 'Labeled_Tiles':
            get_y_func = partial(_get_lbls, class_mapping=class_mapping)
        else:

            def get_y_func(x):
                # Imagenet layout: the parent folder name is the label.
                return x.parent.stem

        # The split now honours ``seed`` (it was hard-coded to 42, which is
        # also the parameter's default, so default behaviour is unchanged).
        if _is_multispectral:
            data = ArcGISMSImageList.from_folder(path / 'images')\
                .split_by_rand_pct(val_split_pct, seed=seed)\
                .label_from_func(get_y_func)
            _show_batch_multispectral = show_batch_labeled_tiles
        else:
            data = ImageList.from_folder(path / 'images')\
                .split_by_rand_pct(val_split_pct, seed=seed)\
                .label_from_func(get_y_func)

        if dataset_type == 'Imagenet':
            # Class ids are assigned 1..n in the order of ``data.classes``.
            class_mapping = dict(enumerate(data.classes, start=1))

        if transforms is None:
            ranges = (0, 1)
            train_tfms = [
                rotate(degrees=30, p=0.5),
                crop(size=chip_size, p=1., row_pct=ranges, col_pct=ranges),
                dihedral_affine(),
                brightness(change=(0.4, 0.6)),
                contrast(scale=(0.75, 1.5))
            ]
            val_tfms = [crop(size=chip_size, p=1.0, row_pct=0.5, col_pct=0.5)]
            transforms = (train_tfms, val_tfms)

    elif dataset_type in ['ner_json', 'BIO', 'IOB', 'LBIOU', 'BILUO']:
        return ner_prepare_data(dataset_type=dataset_type,
                                path=path,
                                class_mapping=class_mapping,
                                val_split_pct=val_split_pct)
    else:
        raise NotImplementedError(
            'Unknown dataset_type="{}".'.format(dataset_type))

    if _is_multispectral:
        if dataset_type == 'RCNN_Masks':
            kwargs['do_normalize'] = False
            if transforms is None:
                data = (src.transform(
                    size=chip_size,
                    tfm_y=True).databunch(**databunch_kwargs))
            else:
                data = (src.transform(
                    transforms, size=chip_size,
                    tfm_y=True).databunch(**databunch_kwargs))
        else:
            data = (data.transform(
                transforms, **kwargs_transforms).databunch(**databunch_kwargs))

        if len(data.x) < 300:
            norm_pct = 1

        # Statistics
        dummy_stats = {
            "batch_stats_for_norm_pct_0": {
                "band_min_values": None,
                "band_max_values": None,
                "band_mean_values": None,
                "band_std_values": None,
                "scaled_min_values": None,
                "scaled_max_values": None,
                "scaled_mean_values": None,
                "scaled_std_values": None
            }
        }
        # Stats are cached in a JSON file next to the data, keyed by the
        # sampling percentage, so they are computed once per percentage.
        normstats_json_path = os.path.abspath(
            data.path / '..' / 'esri_normalization_stats.json')
        if not os.path.exists(normstats_json_path):
            normstats = dummy_stats
            with open(normstats_json_path, 'w', encoding='utf-8') as f:
                json.dump(normstats, f, ensure_ascii=False, indent=4)
        else:
            # Read with the same encoding the file is written with.
            with open(normstats_json_path, encoding='utf-8') as f:
                normstats = json.load(f)

        norm_pct_search = f"batch_stats_for_norm_pct_{round(norm_pct*100)}"
        if norm_pct_search in normstats:
            batch_stats = normstats[norm_pct_search]
            for s in batch_stats:
                if batch_stats[s] is not None:
                    batch_stats[s] = torch.tensor(batch_stats[s])
        else:
            batch_stats = _get_batch_stats(data.x, norm_pct)
            normstats[norm_pct_search] = dict(batch_stats)
            for s in normstats[norm_pct_search]:
                if normstats[norm_pct_search][s] is not None:
                    normstats[norm_pct_search][s] = normstats[
                        norm_pct_search][s].tolist()
            with open(normstats_json_path, 'w', encoding='utf-8') as f:
                json.dump(normstats, f, ensure_ascii=False, indent=4)

        # batch_stats -> [band_min_values, band_max_values, band_mean_values,
        # band_std_values, scaled_min_values, scaled_max_values,
        # scaled_mean_values, scaled_std_values]
        data._band_min_values = batch_stats['band_min_values']
        data._band_max_values = batch_stats['band_max_values']
        data._band_mean_values = batch_stats['band_mean_values']
        data._band_std_values = batch_stats['band_std_values']
        data._scaled_min_values = batch_stats['scaled_min_values']
        data._scaled_max_values = batch_stats['scaled_max_values']
        data._scaled_mean_values = batch_stats['scaled_mean_values']
        data._scaled_std_values = batch_stats['scaled_std_values']

        # Prevent Divide by zeros
        data._band_max_values[
            data._band_min_values == data._band_max_values] += 1
        data._scaled_std_values[data._scaled_std_values == 0] += 1e-02

        # Scaling
        data._min_max_scaler = partial(_tensor_scaler,
                                       min_values=data._band_min_values,
                                       max_values=data._band_max_values,
                                       mode='minmax')
        data._min_max_scaler_tfm = partial(_tensor_scaler_tfm,
                                           min_values=data._band_min_values,
                                           max_values=data._band_max_values,
                                           mode='minmax')

        # Transforms
        def _scaling_tfm(x):
            ## Scales Fastai Image Scaling | MS Image Values -> 0 - 1 range
            return x.__class__(data._min_max_scaler_tfm((x.data, None))[0][0])

        ## Fastai need tfm, order and resolve.
        class dummy():
            pass

        _scaling_tfm.tfm = dummy()
        _scaling_tfm.tfm.order = 0
        _scaling_tfm.resolve = dummy

        ## Scaling the images before applying any other transform
        if data.train_ds.tfms is not None:
            data.train_ds.tfms = [_scaling_tfm] + data.train_ds.tfms
        else:
            data.train_ds.tfms = [_scaling_tfm]
        if data.valid_ds.tfms is not None:
            data.valid_ds.tfms = [_scaling_tfm] + data.valid_ds.tfms
        else:
            data.valid_ds.tfms = [_scaling_tfm]

        # Normalize
        data._do_normalize = True
        if kwargs.get('do_normalize') is not None:
            data._do_normalize = kwargs.get('do_normalize', True)
        if data._do_normalize:
            data = data.normalize(stats=(data._scaled_mean_values,
                                         data._scaled_std_values),
                                  do_x=True,
                                  do_y=False)

    elif dataset_type == 'RCNN_Masks':
        if transforms is None:
            data = (src.transform(size=chip_size,
                                  tfm_y=True).databunch(**databunch_kwargs))
        else:
            data = (src.transform(transforms, size=chip_size,
                                  tfm_y=True).databunch(**databunch_kwargs))
        data.show_batch = types.MethodType(show_batch_rcnn_masks, data)

    else:
        data = (data.transform(transforms, **kwargs_transforms).databunch(
            **databunch_kwargs).normalize(imagenet_stats))

    data.chip_size = data.x[0].shape[-1] if transforms is False else chip_size

    if alter_class_mapping:
        # Re-key the mapping as 1..n over the existing keys, in order.
        new_mapping = {}
        for i, class_name in enumerate(class_mapping.keys()):
            new_mapping[i + 1] = class_name
        class_mapping = new_mapping

    data.class_mapping = class_mapping
    data.color_mapping = color_mapping
    data.show_batch = partial(data.show_batch,
                              rows=min(int(math.sqrt(batch_size)), 5))
    data.orig_path = path
    data.resize_to = kwargs_transforms.get('size', None)
    data.height_width = height_width

    data._is_multispectral = _is_multispectral
    if data._is_multispectral:
        data._imagery_type = imagery_type
        data._bands = bands
        data._norm_pct = norm_pct
        data._rgb_bands = rgb_bands
        data._symbology_rgb_bands = rgb_bands

        # Handle invalid color mapping
        # NOTE: this is the same dict object as ``data.color_mapping``, so
        # the replacement colours below are visible through both attributes.
        data._multispectral_color_mapping = color_mapping
        if any(-1 in x for x in data._multispectral_color_mapping.values()):
            random_color_list = np.random.randint(
                low=0,
                high=255,
                size=(len(data._multispectral_color_mapping), 3)).tolist()
            for i, c in enumerate(data._multispectral_color_mapping):
                if -1 in data._multispectral_color_mapping[c]:
                    data._multispectral_color_mapping[c] = random_color_list[i]

        # prepare color array
        alpha = kwargs.get('alpha', 0.7)
        color_array = torch.tensor(list(
            data.color_mapping.values())).float() / 255
        alpha_tensor = torch.tensor([alpha] * len(color_array)).view(
            -1, 1).float()
        color_array = torch.cat([color_array, alpha_tensor], dim=-1)
        background_color = torch.tensor([[0, 0, 0, 0]]).float()
        data._multispectral_color_array = torch.cat(
            [background_color, color_array])

        # Prepare unknown bands list if bands data is missing
        if data._bands is None:
            n_bands = data.x[0].data.shape[0]
            if n_bands == 1:  # Handle Panchromatic case
                data._bands = ['p']
                data._symbology_rgb_bands = [0]
            else:
                data._bands = ['u' for i in range(n_bands)]
                if n_bands == 2:  # Handle Data with two channels
                    data._symbology_rgb_bands = [0]

        if data._rgb_bands is None:
            data._rgb_bands = []

        if data._symbology_rgb_bands is None:
            data._symbology_rgb_bands = [0, 1, 2][:min(n_bands, 3)]

        # Complete symbology rgb bands
        if len(data._bands) > 2 and len(data._symbology_rgb_bands) < 3:
            data._symbology_rgb_bands += [
                min(max(data._symbology_rgb_bands) + 1,
                    len(data._bands) - 1)
                for i in range(3 - len(data._symbology_rgb_bands))
            ]

        # Overwrite band values at r g b indexes with 'r' 'g' 'b'
        for i, band_idx in enumerate(data._rgb_bands):
            if band_idx is not None:
                if data._bands[band_idx] == 'u':
                    data._bands[band_idx] = ['r', 'g', 'b'][i]

        # Attach custom show batch
        if _show_batch_multispectral is not None:
            data.show_batch = types.MethodType(_show_batch_multispectral, data)

        # Apply filter band transformation if user has specified
        # extract_bands otherwise add a generic extract_bands
        #
        # extract_bands : List containing band indices of the bands from
        #     imagery on which the model would be trained. Useful for
        #     benchmarking and applied training, for reference see the
        #     example below.
        #
        #     4 band naip ['r', 'g', 'b', 'nir'] + extract_bands=[0, 1, 2]
        #     -> 3 band naip with bands ['r', 'g', 'b']
        data._extract_bands = kwargs.get('extract_bands', None)
        if data._extract_bands is None:
            data._extract_bands = list(range(len(data._bands)))
        else:
            data._extract_bands_tfm = partial(_extract_bands_tfm,
                                              band_indices=data._extract_bands)
            data.add_tfm(data._extract_bands_tfm)

        # Tail Training Override
        _train_tail = True
        if [data._bands[i] for i in data._extract_bands] == ['r', 'g', 'b']:
            _train_tail = False
        data._train_tail = kwargs.get('train_tail', _train_tail)

    # ``emd`` is only loaded when the Esri metadata branch above ran; every
    # other case (including Imagenet data that happens to ship Esri files)
    # falls back to pixel space instead of touching an unbound name.
    if has_esri_files and emd is not None:
        data._image_space_used = emd.get('ImageSpaceUsed', 'MAP_SPACE')
    else:
        data._image_space_used = 'PIXEL_SPACE'

    return data
def prepare_data(path, class_mapping=None, chip_size=224, val_split_pct=0.1, batch_size=64, transforms=None, collate_fn=_bb_pad_collate, seed=42, dataset_type=None):
    """
    Prepares a Fast.ai DataBunch from the exported Pascal VOC image chips
    exported by Export Training Data tool in ArcGIS Pro or Image Server.
    This DataBunch consists of training and validation DataLoaders with the
    specified transformations, chip size, batch size, split percentage.

    =====================   ===========================================
    **Argument**            **Description**
    ---------------------   -------------------------------------------
    path                    Required string. Path to data directory.
    ---------------------   -------------------------------------------
    class_mapping           Optional dictionary. Mapping from id to
                            its string label.
    ---------------------   -------------------------------------------
    chip_size               Optional integer. Size of the image to train the
                            model.
    ---------------------   -------------------------------------------
    val_split_pct           Optional float. Percentage of training data to keep
                            as validation.
    ---------------------   -------------------------------------------
    batch_size              Optional integer. Batch size for mini batch gradient
                            descent (Reduce it if getting CUDA Out of Memory
                            Errors).
    ---------------------   -------------------------------------------
    transforms              Optional tuple. Fast.ai transforms for data
                            augmentation of training and validation datasets
                            respectively (We have set good defaults which work
                            for satellite imagery well).
    ---------------------   -------------------------------------------
    collate_fn              Optional function. Passed to PyTorch to collate data
                            into batches(usually default works).
    ---------------------   -------------------------------------------
    seed                    Optional integer. Random seed for reproducible
                            train-validation split.
    ---------------------   -------------------------------------------
    dataset_type            Optional string. `prepare_data` function will infer
                            the `dataset_type` on its own if it contains a
                            map.txt file. If the path does not contain the
                            map.txt file pass either of 'PASCAL_VOC_rectangles',
                            'RCNN_Masks' and 'Classified_Tiles'
    =====================   ===========================================

    :returns: fastai DataBunch object
    """
    # NOTE(review): this file also contains an earlier ``prepare_data``
    # definition with additional parameters (``resize_to``, ``**kwargs``).
    # Being defined later, this one is what the module-level name resolves
    # to -- confirm that this is intended.

    if not HAS_FASTAI:
        _raise_fastai_import_error()

    if isinstance(path, str):
        path = Path(path)

    databunch_kwargs = {'num_workers': 0} if sys.platform == 'win32' else {}

    json_file = path / 'esri_model_definition.emd'
    with open(json_file) as f:
        emd = json.load(f)

    # The EMD class entries use either Value/Name or ClassValue/ClassName.
    if class_mapping is None:
        try:
            class_mapping = {i['Value']: i['Name'] for i in emd['Classes']}
        except KeyError:
            class_mapping = {i['ClassValue']: i['ClassName'] for i in emd['Classes']}

    try:
        color_mapping = {i['Value']: i['Color'] for i in emd['Classes']}
    except KeyError:
        color_mapping = {i['ClassValue']: i['Color'] for i in emd['Classes']}

    if dataset_type is None:
        stats_file = path / 'esri_accumulated_stats.json'
        with open(stats_file) as f:
            stats = json.load(f)
        dataset_type = stats['MetaDataMode']

    if dataset_type in ['RCNN_Masks', 'Classified_Tiles']:
        # The label file extension is needed only by this branch, so it is
        # read here; that way it is always bound when used, however
        # ``dataset_type`` was obtained. The second whitespace-separated
        # field of the first map.txt line is the label file.
        with open(path / 'map.txt') as f:
            line = f.readline()
        right = line.split()[1].split('.')[-1].lower()

        def get_y_func(x, ext=right):
            return x.parents[1] / 'labels' / (x.stem + '.{}'.format(ext))

        src = (ArcGISSegmentationItemList.from_folder(path/'images')
               .random_split_by_pct(val_split_pct, seed=seed)
               .label_from_func(get_y_func,
                                classes=['NoData'] + list(class_mapping.values()),
                                class_mapping=class_mapping,
                                color_mapping=color_mapping))  # TODO : Handle NoData case

        if transforms is None:
            transforms = get_transforms(flip_vert=True,
                                        max_rotate=90.,
                                        max_zoom=3.0,
                                        max_lighting=0.5)

        data = (src
                .transform(transforms, size=chip_size, tfm_y=True)
                .databunch(bs=batch_size, **databunch_kwargs)
                .normalize(imagenet_stats))

    elif dataset_type == 'PASCAL_VOC_rectangles':
        get_y_func = partial(_get_bbox_lbls, class_mapping=class_mapping)

        src = (SSDObjectItemList.from_folder(path/'images')
               .random_split_by_pct(val_split_pct, seed=seed)
               .label_from_func(get_y_func))

        if transforms is None:
            ranges = (0, 1)
            train_tfms = [crop(size=chip_size, p=1., row_pct=ranges, col_pct=ranges),
                          dihedral_affine(),
                          brightness(change=(0.4, 0.6)),
                          contrast(scale=(0.75, 1.5)),
                          rand_zoom(scale=(0.75, 1.5))]
            val_tfms = [crop(size=chip_size, p=1., row_pct=0.5, col_pct=0.5)]
            transforms = (train_tfms, val_tfms)

        data = (src
                .transform(transforms, tfm_y=True)
                .databunch(bs=batch_size, collate_fn=collate_fn, **databunch_kwargs)
                .normalize(imagenet_stats))

    elif dataset_type == 'Labeled_Tiles':
        get_y_func = partial(_get_lbls, class_mapping=class_mapping)

        # The split now honours ``seed`` (it was hard-coded to 42, which is
        # also the parameter's default, so default behaviour is unchanged).
        src = (ImageItemList.from_folder(path/'images')
               .random_split_by_pct(val_split_pct, seed=seed)
               .label_from_func(get_y_func))

        if transforms is None:
            ranges = (0, 1)
            train_tfms = [rotate(degrees=30, p=0.5),
                          crop(size=chip_size, p=1., row_pct=ranges, col_pct=ranges),
                          dihedral_affine(),
                          brightness(change=(0.4, 0.6)),
                          contrast(scale=(0.75, 1.5))]
            val_tfms = [crop(size=chip_size, p=1.0, row_pct=0.5, col_pct=0.5)]
            transforms = (train_tfms, val_tfms)

        data = (src
                .transform(transforms, size=chip_size)
                .databunch(bs=batch_size, **databunch_kwargs)
                .normalize(imagenet_stats))

    else:
        raise NotImplementedError('Unknown dataset_type="{}".'.format(dataset_type))

    data.chip_size = chip_size
    data.class_mapping = class_mapping
    data.color_mapping = color_mapping
    # Limit the default number of rows shown by show_batch to at most 5.
    data.show_batch = partial(data.show_batch,
                              rows=min(int(math.sqrt(batch_size)), 5))
    data.orig_path = path

    return data