Exemplo n.º 1
0
    def from_model(cls, emd_path, data=None):
        """
        Creates a model instance from an Esri Model Definition (``.emd``) file.

        =====================   ===========================================
        **Argument**            **Description**
        ---------------------   -------------------------------------------
        emd_path                Required string or path. Path to the ``.emd``
                                JSON file.
        ---------------------   -------------------------------------------
        data                    Optional data object. When ``None``, a
                                placeholder data bunch is built from the class
                                names and ``ImageWidth`` recorded in the EMD.
        =====================   ===========================================

        :returns: ``cls(data, **ModelParameters, pretrained_path=<ModelFile>)``
        """
        emd_path = Path(emd_path)
        with open(emd_path) as f:
            emd = json.load(f)

        model_file = Path(emd['ModelFile'])
        # A relative ModelFile is resolved against the folder holding the EMD.
        if not model_file.is_absolute():
            model_file = emd_path.parent / model_file

        model_params = emd['ModelParameters']

        if data is not None:
            return cls(data, **model_params, pretrained_path=str(model_file))

        # Everything below is needed only to build the placeholder data bunch.
        chip_size = emd["ImageWidth"]

        # Only class names are used here. Colors are deliberately not read, so
        # an EMD class without a 'Color' entry does not break loading. Newer
        # EMDs use Value/Name; older ones use ClassValue/ClassName.
        try:
            class_mapping = {i['Value']: i['Name'] for i in emd['Classes']}
        except KeyError:
            class_mapping = {
                i['ClassValue']: i['ClassName']
                for i in emd['Classes']
            }

        ranges = (0, 1)
        train_tfms = [
            rotate(degrees=30, p=0.5),
            crop(size=chip_size, p=1., row_pct=ranges, col_pct=ranges),
            dihedral_affine(),
            brightness(change=(0.4, 0.6)),
            contrast(scale=(0.75, 1.5)),
        ]
        val_tfms = [crop(size=chip_size, p=1.0, row_pct=0.5, col_pct=0.5)]
        transforms = (train_tfms, val_tfms)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)

            # NOTE(review): only the path string of this TemporaryDirectory is
            # kept. Python removes the directory when the unreferenced object
            # is destroyed, so the path may not exist afterwards.
            tempdata = ImageDataBunch.single_from_classes(
                tempfile.TemporaryDirectory().name,
                sorted(class_mapping.values()),
                tfms=transforms,
                size=chip_size).normalize(imagenet_stats)
            tempdata.chip_size = chip_size
            # Model construction also runs with UserWarnings suppressed.
            return cls(tempdata,
                       **model_params,
                       pretrained_path=str(model_file))
Exemplo n.º 2
0
def prepare_data(path,
                 class_mapping=None,
                 chip_size=224,
                 val_split_pct=0.1,
                 batch_size=64,
                 transforms=None,
                 collate_fn=_bb_pad_collate,
                 seed=42,
                 dataset_type=None,
                 resize_to=None,
                 **kwargs):
    """
    Prepares a data object from training samples exported by the
    Export Training Data tool in ArcGIS Pro or Image Server, or training
    samples in the supported dataset formats. This data object consists of
    training and validation data sets with the specified transformations,
    chip size, batch size, split percentage, etc.
    -For object detection, use Pascal_VOC_rectangles format.
    -For feature categorization use Labelled Tiles or ImageNet format.
    -For pixel classification, use Classified Tiles format.
    -For entity extraction from text, use IOB, BILUO or ner_json formats.

    =====================   ===========================================
    **Argument**            **Description**
    ---------------------   -------------------------------------------
    path                    Required string. Path to data directory.
    ---------------------   -------------------------------------------
    class_mapping           Optional dictionary. Mapping from id to
                            its string label.
                            For dataset_type=IOB, BILUO or ner_json:
                                Provide address field as class mapping
                                in below format:
                                class_mapping={'address_tag':'address_field'}
    ---------------------   -------------------------------------------
    chip_size               Optional integer. Size of the image to train the
                            model.
    ---------------------   -------------------------------------------
    val_split_pct           Optional float. Percentage of training data to keep
                            as validation.
    ---------------------   -------------------------------------------
    batch_size              Optional integer. Batch size for mini batch gradient
                            descent (Reduce it if getting CUDA Out of Memory
                            Errors).
    ---------------------   -------------------------------------------
    transforms              Optional tuple. Fast.ai transforms for data
                            augmentation of training and validation datasets
                            respectively (We have set good defaults which work
                            for satellite imagery well). If transforms is set
                            to `False` no transformation will take place and
                            `chip_size` parameter will also not take effect.
    ---------------------   -------------------------------------------
    collate_fn              Optional function. Passed to PyTorch to collate data
                            into batches(usually default works).
    ---------------------   -------------------------------------------
    seed                    Optional integer. Random seed for reproducible
                            train-validation split.
    ---------------------   -------------------------------------------
    dataset_type            Optional string. `prepare_data` function will infer
                            the `dataset_type` on its own if the path contains
                            the Export Training Data files (e.g. map.txt).
                            Otherwise pass one of 'PASCAL_VOC_rectangles',
                            'RCNN_Masks', 'Classified_Tiles', 'Labeled_Tiles',
                            'Imagenet', 'ner_json', 'BIO', 'IOB', 'LBIOU' or
                            'BILUO'.
    ---------------------   -------------------------------------------
    resize_to               Optional integer. Resize the image to given size.
    =====================   ===========================================

    Additional keyword arguments (mainly for multispectral imagery):

    =====================   ===========================================
    **Keyword Argument**    **Description**
    ---------------------   -------------------------------------------
    imagery_type            Optional string, default 'RGB'. Any other value
                            enables multispectral handling.
    ---------------------   -------------------------------------------
    bands                   Optional list of band names, e.g.
                            ['r', 'g', 'b', 'nir']. String names are
                            lower-cased. Enables multispectral handling.
    ---------------------   -------------------------------------------
    rgb_bands               Optional list of indices of the red, green and
                            blue bands. Derived from `bands` when omitted.
    ---------------------   -------------------------------------------
    norm_pct                Optional float, clamped to [0, 1], default 0.3.
                            Fraction of images used to compute normalization
                            statistics (all images when fewer than 300).
    ---------------------   -------------------------------------------
    do_normalize            Optional bool, default True. Normalize
                            multispectral data with the computed statistics
                            (forced to False for multispectral RCNN_Masks).
    ---------------------   -------------------------------------------
    extract_bands           Optional list of band indices to train on.
    ---------------------   -------------------------------------------
    train_tail              Optional bool. Overrides the automatic choice,
                            which is False only when the extracted bands are
                            exactly ['r', 'g', 'b'].
    ---------------------   -------------------------------------------
    alpha                   Optional float, default 0.7. Alpha channel
                            appended to the multispectral class color array.
    =====================   ===========================================

    :returns: data object
    """
    height_width = []

    if not HAS_FASTAI:
        _raise_fastai_import_error()

    if isinstance(path, str) and not os.path.exists(path):
        raise Exception("Invalid input path.")

    if isinstance(path, str):
        path = Path(path)

    databunch_kwargs = {'num_workers': 0} if sys.platform == 'win32' else {}
    databunch_kwargs['bs'] = batch_size

    kwargs_transforms = {}
    if resize_to:
        kwargs_transforms['size'] = resize_to

    has_esri_files = _check_esri_files(path)
    alter_class_mapping = False
    color_mapping = None
    # Stays None unless the Esri model definition is read below.
    emd = None

    # Multispectral Kwargs init
    _bands = None
    _imagery_type = None
    _is_multispectral = False
    _show_batch_multispectral = None

    if dataset_type is None and not has_esri_files:
        raise Exception("Could not infer dataset type.")

    if dataset_type != "Imagenet" and has_esri_files:
        stats_file = path / 'esri_accumulated_stats.json'
        with open(stats_file) as f:
            stats = json.load(f)
            dataset_type = stats['MetaDataMode']

        with open(path / 'map.txt') as f:
            line = f.readline()

        # Extension of the label file named on the first line of map.txt.
        right = line.split()[1].split('.')[-1].lower()

        json_file = path / 'esri_model_definition.emd'
        with open(json_file) as f:
            emd = json.load(f)

        # The class mapping comes from the EMD. User-supplied names override
        # the EMD names for matching class values (issue #3064). User keys
        # not present in the EMD are dropped.
        try:
            emd_class_mapping = {i['Value']: i['Name'] for i in emd['Classes']}
        except KeyError:
            emd_class_mapping = {
                i['ClassValue']: i['ClassName']
                for i in emd['Classes']
            }

        # Coerce user keys to int so that string keys such as '1' still match.
        if class_mapping is not None:
            class_mapping = {
                int(key): value
                for key, value in class_mapping.items()
            }
        else:
            class_mapping = {}

        for key in emd_class_mapping:
            if class_mapping.get(key) is not None:
                emd_class_mapping[key] = class_mapping[key]

        class_mapping = emd_class_mapping

        color_mapping = {(i.get('Value', 0) or i.get('ClassValue', 0)):
                         i['Color']
                         for i in emd.get('Classes', [])}

        if color_mapping.get(None):
            del color_mapping[None]

        if class_mapping.get(None):
            del class_mapping[None]

        # Optional multispectral hints stored in the EMD.
        if emd.get('bands', None) is not None:
            _bands = emd['bands']

        if emd.get('imagery_type', None) is not None:
            _imagery_type = emd['imagery_type']

    elif dataset_type == 'PASCAL_VOC_rectangles' and not has_esri_files:
        if class_mapping is None:
            class_mapping = _get_class_mapping(path / 'labels')
            alter_class_mapping = True

    # Multispectral check
    imagery_type = 'RGB'
    if kwargs.get('imagery_type', None) is not None:
        imagery_type = kwargs.get('imagery_type')
    elif _imagery_type is not None:
        imagery_type = _imagery_type

    # Band lists are copied because they are edited in place later (filling
    # in 'r'/'g'/'b'). Copying keeps the caller's list and the shared
    # imagery_type_lib entry untouched.
    bands = None
    if kwargs.get('bands', None) is not None:
        bands = [
            b.lower() if isinstance(b, str) else b
            for b in kwargs.get('bands')
        ]
    elif imagery_type_lib.get(imagery_type, None) is not None:
        bands = list(imagery_type_lib.get(imagery_type)['bands'])
    elif _bands is not None:
        bands = _bands

    rgb_bands = None
    if kwargs.get('rgb_bands', None) is not None:
        rgb_bands = list(kwargs.get('rgb_bands'))
    elif bands is not None:
        rgb_bands = [bands.index(b) for b in ['r', 'g', 'b'] if b in bands]

    if (bands is not None) or (rgb_bands is not None) or (imagery_type != 'RGB'):
        if imagery_type == 'RGB':
            imagery_type = 'multispectral'
        _is_multispectral = True

    if kwargs.get('norm_pct', None) is not None:
        norm_pct = kwargs.get('norm_pct')
        norm_pct = min(max(0, norm_pct), 1)
    else:
        norm_pct = .3

    if dataset_type == 'RCNN_Masks':
        # NOTE(review): `right` (label extension) is only defined when the
        # Esri export files were found above. Confirm that RCNN_Masks is never
        # used without them.

        def get_labels(x, label_dirs, ext=right):
            label_path = []
            for lbl in label_dirs:
                if os.path.exists(Path(lbl) / (x.stem + '.{}'.format(ext))):
                    label_path.append(Path(lbl) / (x.stem + '.{}'.format(ext)))
            return label_path

        if class_mapping.get(0):
            del class_mapping[0]

        if color_mapping.get(0):
            del color_mapping[0]

        # Handle Multispectral
        if _is_multispectral:
            src = (ArcGISInstanceSegmentationMSItemList.from_folder(
                path / 'images').split_by_rand_pct(val_split_pct, seed=seed))
            _show_batch_multispectral = show_batch_rcnn_masks
        else:
            src = (ArcGISInstanceSegmentationItemList.from_folder(
                path / 'images').split_by_rand_pct(val_split_pct, seed=seed))

        label_dirs = []
        index_dir = {}  # maps arbitrary class values to contiguous indices 1..n
        for i, k in enumerate(sorted(class_mapping.keys())):
            label_dirs.append(class_mapping[k])
            index_dir[k] = i + 1
        label_dir = [
            os.path.join(path / 'labels', lbl) for lbl in label_dirs
            if os.path.isdir(os.path.join(path / 'labels', lbl))
        ]
        get_y_func = partial(get_labels, label_dirs=label_dir)
        src = src.label_from_func(get_y_func,
                                  chip_size=chip_size,
                                  classes=['NoData'] +
                                  list(class_mapping.values()),
                                  class_mapping=class_mapping,
                                  color_mapping=color_mapping,
                                  index_dir=index_dir)

    elif dataset_type == 'Classified_Tiles':
        # NOTE(review): like RCNN_Masks, this relies on `right` from map.txt.

        def get_y_func(x, ext=right):
            return x.parents[1] / 'labels' / (x.stem + '.{}'.format(ext))

        if class_mapping.get(0):
            del class_mapping[0]

        if color_mapping.get(0):
            del color_mapping[0]

        if is_no_color(color_mapping):
            color_mapping = {
                j: [random.choice(range(256)) for i in range(3)]
                for j in class_mapping.keys()
            }

        # TODO : Handle NoData case

        # Handle Multispectral
        if _is_multispectral:
            data = ArcGISSegmentationMSItemList.from_folder(path/'images')\
                .split_by_rand_pct(val_split_pct, seed=seed)\
                .label_from_func(
                    get_y_func, classes=(['NoData'] + list(class_mapping.values())),
                    class_mapping=class_mapping,
                    color_mapping=color_mapping
                )
            _show_batch_multispectral = _show_batch_unet_multispectral

            def classified_tiles_collate_fn(
                samples
            ):  # The default fastai collate_fn was causing memory leak on tensors
                r = (torch.stack([x[0].data for x in samples]),
                     torch.stack([x[1].data for x in samples]))
                return r

            databunch_kwargs['collate_fn'] = classified_tiles_collate_fn

        else:
            data = ArcGISSegmentationItemList.from_folder(path/'images')\
                .split_by_rand_pct(val_split_pct, seed=seed)\
                .label_from_func(
                    get_y_func, classes=(['NoData'] + list(class_mapping.values())),
                    class_mapping=class_mapping,
                    color_mapping=color_mapping
                )

        if transforms is None:
            transforms = get_transforms(flip_vert=True,
                                        max_rotate=90.,
                                        max_zoom=3.0,
                                        max_lighting=0.5)

        kwargs_transforms['tfm_y'] = True
        kwargs_transforms['size'] = chip_size
    elif dataset_type == 'PASCAL_VOC_rectangles':
        # Single-element list so _get_bbox_lbls can count missing labels.
        not_label_count = [0]
        get_y_func = partial(_get_bbox_lbls,
                             class_mapping=class_mapping,
                             not_label_count=not_label_count,
                             height_width=height_width)

        if _is_multispectral:
            data = SSDObjectMSItemList.from_folder(path/'images')\
            .split_by_rand_pct(val_split_pct, seed=seed)\
            .label_from_func(get_y_func)
            _show_batch_multispectral = show_batch_pascal_voc_rectangles
        else:
            data = SSDObjectItemList.from_folder(path/'images')\
                .split_by_rand_pct(val_split_pct, seed=seed)\
                .label_from_func(get_y_func)

        if not_label_count[0]:
            logger = logging.getLogger()
            logger.warning("Please check your dataset. " +
                           str(not_label_count[0]) +
                           " images dont have the corresponding label files.")

        if transforms is None:
            ranges = (0, 1)
            train_tfms = [
                crop(size=chip_size, p=1., row_pct=ranges, col_pct=ranges),
                dihedral_affine() if has_esri_files else flip_lr(),
                brightness(change=(0.4, 0.6)),
                contrast(scale=(0.75, 1.5)),
                rand_zoom(scale=(1.0, 1.5))
            ]
            val_tfms = [crop(size=chip_size, p=1., row_pct=0.5, col_pct=0.5)]
            transforms = (train_tfms, val_tfms)

        kwargs_transforms['tfm_y'] = True
        databunch_kwargs['collate_fn'] = collate_fn
    elif dataset_type in ['Labeled_Tiles', 'Imagenet']:
        if dataset_type == 'Labeled_Tiles':
            get_y_func = partial(_get_lbls, class_mapping=class_mapping)
        else:

            def get_y_func(x):
                return x.parent.stem

        # Honour the caller's `seed` like the other dataset types do.
        if _is_multispectral:
            data = ArcGISMSImageList.from_folder(path/'images')\
                .split_by_rand_pct(val_split_pct, seed=seed)\
                .label_from_func(get_y_func)
            _show_batch_multispectral = show_batch_labeled_tiles
        else:
            data = ImageList.from_folder(path/'images')\
                .split_by_rand_pct(val_split_pct, seed=seed)\
                .label_from_func(get_y_func)

        if dataset_type == 'Imagenet':
            class_mapping = {}
            index = 1
            for class_name in data.classes:
                class_mapping[index] = class_name
                index = index + 1

        if transforms is None:
            ranges = (0, 1)
            train_tfms = [
                rotate(degrees=30, p=0.5),
                crop(size=chip_size, p=1., row_pct=ranges, col_pct=ranges),
                dihedral_affine(),
                brightness(change=(0.4, 0.6)),
                contrast(scale=(0.75, 1.5))
            ]
            val_tfms = [crop(size=chip_size, p=1.0, row_pct=0.5, col_pct=0.5)]
            transforms = (train_tfms, val_tfms)
    elif dataset_type in ['ner_json', 'BIO', 'IOB', 'LBIOU', 'BILUO']:
        return ner_prepare_data(dataset_type=dataset_type,
                                path=path,
                                class_mapping=class_mapping,
                                val_split_pct=val_split_pct)
    else:
        raise NotImplementedError(
            'Unknown dataset_type="{}".'.format(dataset_type))

    if _is_multispectral:
        if dataset_type == 'RCNN_Masks':
            kwargs['do_normalize'] = False
            if transforms is None:
                data = (src.transform(
                    size=chip_size, tfm_y=True).databunch(**databunch_kwargs))
            else:
                data = (src.transform(
                    transforms, size=chip_size,
                    tfm_y=True).databunch(**databunch_kwargs))
        else:
            data = (data.transform(
                transforms, **kwargs_transforms).databunch(**databunch_kwargs))

        if len(data.x) < 300:
            norm_pct = 1

        # Statistics
        dummy_stats = {
            "batch_stats_for_norm_pct_0": {
                "band_min_values": None,
                "band_max_values": None,
                "band_mean_values": None,
                "band_std_values": None,
                "scaled_min_values": None,
                "scaled_max_values": None,
                "scaled_mean_values": None,
                "scaled_std_values": None
            }
        }
        # Stats are cached per norm_pct in a JSON file next to the images folder.
        normstats_json_path = os.path.abspath(data.path / '..' /
                                              'esri_normalization_stats.json')
        if not os.path.exists(normstats_json_path):
            normstats = dummy_stats
            with open(normstats_json_path, 'w', encoding='utf-8') as f:
                json.dump(normstats, f, ensure_ascii=False, indent=4)
        else:
            with open(normstats_json_path) as f:
                normstats = json.load(f)

        norm_pct_search = f"batch_stats_for_norm_pct_{round(norm_pct*100)}"
        if norm_pct_search in normstats:
            batch_stats = normstats[norm_pct_search]
            for s in batch_stats:
                if batch_stats[s] is not None:
                    batch_stats[s] = torch.tensor(batch_stats[s])
        else:
            batch_stats = _get_batch_stats(data.x, norm_pct)
            # Shallow copy: the cached entry holds lists, batch_stats keeps tensors.
            normstats[norm_pct_search] = dict(batch_stats)
            for s in normstats[norm_pct_search]:
                if normstats[norm_pct_search][s] is not None:
                    normstats[norm_pct_search][s] = normstats[norm_pct_search][
                        s].tolist()
            with open(normstats_json_path, 'w', encoding='utf-8') as f:
                json.dump(normstats, f, ensure_ascii=False, indent=4)

        data._band_min_values = batch_stats['band_min_values']
        data._band_max_values = batch_stats['band_max_values']
        data._band_mean_values = batch_stats['band_mean_values']
        data._band_std_values = batch_stats['band_std_values']
        data._scaled_min_values = batch_stats['scaled_min_values']
        data._scaled_max_values = batch_stats['scaled_max_values']
        data._scaled_mean_values = batch_stats['scaled_mean_values']
        data._scaled_std_values = batch_stats['scaled_std_values']

        # Prevent Divide by zeros
        data._band_max_values[data._band_min_values ==
                              data._band_max_values] += 1
        data._scaled_std_values[data._scaled_std_values == 0] += 1e-02

        # Scaling
        data._min_max_scaler = partial(_tensor_scaler,
                                       min_values=data._band_min_values,
                                       max_values=data._band_max_values,
                                       mode='minmax')
        data._min_max_scaler_tfm = partial(_tensor_scaler_tfm,
                                           min_values=data._band_min_values,
                                           max_values=data._band_max_values,
                                           mode='minmax')

        # Transforms
        def _scaling_tfm(x):
            # Min-max scale multispectral pixel values into the 0-1 range.
            return x.__class__(data._min_max_scaler_tfm((x.data, None))[0][0])

        # Fastai expects a transform to expose `tfm`, `tfm.order` and `resolve`.
        class dummy():
            pass

        _scaling_tfm.tfm = dummy()
        _scaling_tfm.tfm.order = 0
        _scaling_tfm.resolve = dummy

        # Scale the images before any other transform is applied.
        if getattr(data.train_ds, 'tfms') is not None:
            data.train_ds.tfms = [_scaling_tfm] + data.train_ds.tfms
        else:
            data.train_ds.tfms = [_scaling_tfm]
        if getattr(data.valid_ds, 'tfms') is not None:
            data.valid_ds.tfms = [_scaling_tfm] + data.valid_ds.tfms
        else:
            data.valid_ds.tfms = [_scaling_tfm]

        # Normalize
        data._do_normalize = True
        if kwargs.get('do_normalize', None) is not None:
            data._do_normalize = kwargs.get('do_normalize', True)
        if data._do_normalize:
            data = data.normalize(stats=(data._scaled_mean_values,
                                         data._scaled_std_values),
                                  do_x=True,
                                  do_y=False)

    elif dataset_type == 'RCNN_Masks':
        if transforms is None:
            data = (src.transform(size=chip_size,
                                  tfm_y=True).databunch(**databunch_kwargs))
        else:
            data = (src.transform(transforms, size=chip_size,
                                  tfm_y=True).databunch(**databunch_kwargs))
        data.show_batch = types.MethodType(show_batch_rcnn_masks, data)
    else:
        data = (data.transform(transforms, **kwargs_transforms).databunch(
            **databunch_kwargs).normalize(imagenet_stats))

    data.chip_size = data.x[0].shape[-1] if transforms is False else chip_size

    if alter_class_mapping:
        new_mapping = {}
        for i, class_name in enumerate(class_mapping.keys()):
            new_mapping[i + 1] = class_name
        class_mapping = new_mapping

    data.class_mapping = class_mapping
    data.color_mapping = color_mapping
    data.show_batch = partial(data.show_batch,
                              rows=min(int(math.sqrt(batch_size)), 5))
    data.orig_path = path
    data.resize_to = kwargs_transforms.get('size', None)
    data.height_width = height_width

    data._is_multispectral = _is_multispectral
    if data._is_multispectral:
        data._imagery_type = imagery_type
        data._bands = bands
        data._norm_pct = norm_pct
        data._rgb_bands = rgb_bands
        # A separate list: completing the symbology bands below must not also
        # extend the true RGB band indices.
        data._symbology_rgb_bands = (list(rgb_bands)
                                     if rgb_bands is not None else None)

        # Handle invalid color mapping
        # NOTE(review): color_mapping is None when no Esri files were read,
        # which makes the .values() calls below fail. Confirm that
        # multispectral data always comes with an EMD.
        data._multispectral_color_mapping = color_mapping
        if any(-1 in x for x in data._multispectral_color_mapping.values()):
            random_color_list = np.random.randint(
                low=0,
                high=255,
                size=(len(data._multispectral_color_mapping), 3)).tolist()
            for i, c in enumerate(data._multispectral_color_mapping):
                if -1 in data._multispectral_color_mapping[c]:
                    data._multispectral_color_mapping[c] = random_color_list[i]

        # Prepare an RGBA color array with a transparent background entry first.
        alpha = kwargs.get('alpha', 0.7)
        color_array = torch.tensor(list(
            data.color_mapping.values())).float() / 255
        alpha_tensor = torch.tensor([alpha] * len(color_array)).view(
            -1, 1).float()
        color_array = torch.cat([color_array, alpha_tensor], dim=-1)
        background_color = torch.tensor([[0, 0, 0, 0]]).float()
        data._multispectral_color_array = torch.cat(
            [background_color, color_array])

        # Prepare unknown bands list if bands data is missing
        if data._bands is None:
            n_bands = data.x[0].data.shape[0]
            if n_bands == 1:  # Handle Panchromatic case
                data._bands = ['p']
                data._symbology_rgb_bands = [0]
            else:
                data._bands = ['u' for i in range(n_bands)]
                if n_bands == 2:  # Handle Data with two channels
                    data._symbology_rgb_bands = [0]

        if data._rgb_bands is None:
            data._rgb_bands = []

        # Default display bands: the first (up to) three bands.
        if data._symbology_rgb_bands is None:
            data._symbology_rgb_bands = [0, 1, 2][:min(len(data._bands), 3)]

        # Pad the display bands to three entries. With default=-1, an empty
        # list (no r/g/b band present) starts from band 0 instead of raising.
        if len(data._bands) > 2 and len(data._symbology_rgb_bands) < 3:
            data._symbology_rgb_bands += [
                min(max(data._symbology_rgb_bands, default=-1) + 1,
                    len(data._bands) - 1)
                for i in range(3 - len(data._symbology_rgb_bands))
            ]

        # Overwrite band values at r g b indexes with 'r' 'g' 'b'
        for i, band_idx in enumerate(data._rgb_bands):
            if band_idx is not None:
                if data._bands[band_idx] == 'u':
                    data._bands[band_idx] = ['r', 'g', 'b'][i]

        # Attach custom show batch
        if _show_batch_multispectral is not None:
            data.show_batch = types.MethodType(_show_batch_multispectral, data)

        # extract_bands: indices of the bands to train on, e.g. 4-band NAIP
        # ['r', 'g', 'b', 'nir'] with extract_bands=[0, 1, 2] trains on the
        # r, g, b bands only. Without it, all bands are used and no filter
        # transform is added.
        data._extract_bands = kwargs.get('extract_bands', None)
        if data._extract_bands is None:
            data._extract_bands = list(range(len(data._bands)))
        else:
            data._extract_bands_tfm = partial(_extract_bands_tfm,
                                              band_indices=data._extract_bands)
            data.add_tfm(data._extract_bands_tfm)

        # Tail Training Override
        _train_tail = True
        if [data._bands[i] for i in data._extract_bands] == ['r', 'g', 'b']:
            _train_tail = False
        data._train_tail = kwargs.get('train_tail', _train_tail)

    if emd is not None:
        data._image_space_used = emd.get('ImageSpaceUsed', 'MAP_SPACE')
    else:
        # No EMD was read (no Esri files, or an Imagenet-style dataset).
        data._image_space_used = 'PIXEL_SPACE'

    return data
Exemplo n.º 3
0
def _infer_label_extension(path):
    """
    Returns the lower-cased file extension (without the leading dot) of the
    label chips stored under ``path``.

    The extension is taken from the second column of the first line of
    ``map.txt`` when that file exists. Otherwise the suffix of the first file
    (in sorted order) inside the ``labels`` folder is used, so callers that
    pass ``dataset_type`` explicitly do not need a ``map.txt`` file.

    :raises FileNotFoundError: if neither ``map.txt`` nor any labelled file
        in the ``labels`` folder can be found.
    """
    map_file = path / 'map.txt'
    if map_file.exists():
        with open(map_file) as f:
            first_line = f.readline()
        # Each map.txt line is "<image path> <label path>"; use the label's extension.
        return first_line.split()[1].split('.')[-1].lower()

    labels_dir = path / 'labels'
    if labels_dir.is_dir():
        for label_file in sorted(labels_dir.iterdir()):
            if label_file.is_file() and label_file.suffix:
                return label_file.suffix[1:].lower()

    raise FileNotFoundError(
        'Cannot infer the label file extension: "{}" does not exist and no '
        'label files were found in "{}".'.format(map_file, labels_dir)
    )


def prepare_data(path, class_mapping=None, chip_size=224, val_split_pct=0.1, batch_size=64, transforms=None, collate_fn=_bb_pad_collate, seed=42, dataset_type=None):
    """
    Prepares a Fast.ai DataBunch from the exported Pascal VOC image chips
    exported by Export Training Data tool in ArcGIS Pro or Image Server.
    This DataBunch consists of training and validation DataLoaders with the
    specified transformations, chip size, batch size, split percentage.

    =====================   ===========================================
    **Argument**            **Description**
    ---------------------   -------------------------------------------
    path                    Required string. Path to data directory.
    ---------------------   -------------------------------------------
    class_mapping           Optional dictionary. Mapping from id to
                            its string label.
    ---------------------   -------------------------------------------
    chip_size               Optional integer. Size of the image to train the
                            model.
    ---------------------   -------------------------------------------
    val_split_pct           Optional float. Percentage of training data to keep
                            as validation.
    ---------------------   -------------------------------------------
    batch_size              Optional integer. Batch size for mini batch gradient
                            descent (Reduce it if getting CUDA Out of Memory
                            Errors).
    ---------------------   -------------------------------------------
    transforms              Optional tuple. Fast.ai transforms for data
                            augmentation of training and validation datasets
                            respectively (We have set good defaults which work
                            for satellite imagery well).
    ---------------------   -------------------------------------------
    collate_fn              Optional function. Passed to PyTorch to collate data
                            into batches (usually default works).
    ---------------------   -------------------------------------------
    seed                    Optional integer. Random seed for reproducible
                            train-validation split.
    ---------------------   -------------------------------------------
    dataset_type            Optional string. If not given, `prepare_data`
                            reads it from the `MetaDataMode` entry of
                            `esri_accumulated_stats.json`. Otherwise pass one
                            of 'PASCAL_VOC_rectangles', 'RCNN_Masks',
                            'Classified_Tiles' or 'Labeled_Tiles'. For the
                            mask/tile types the label file extension is read
                            from map.txt, or from the `labels` folder when
                            map.txt is absent.
    =====================   ===========================================

    :returns: fastai DataBunch object
    :raises NotImplementedError: for an unknown `dataset_type`.
    """

    if not HAS_FASTAI:
        _raise_fastai_import_error()

    if isinstance(path, str):
        path = Path(path)

    # DataLoader worker processes are disabled on Windows (num_workers=0).
    databunch_kwargs = {'num_workers': 0} if sys.platform == 'win32' else {}

    json_file = path / 'esri_model_definition.emd'
    with open(json_file) as f:
        emd = json.load(f)

    # The .emd 'Classes' entries may use either 'Value'/'Name' or
    # 'ClassValue'/'ClassName' keys; try the former first.
    if class_mapping is None:
        try:
            class_mapping = {i['Value']: i['Name'] for i in emd['Classes']}
        except KeyError:
            class_mapping = {i['ClassValue']: i['ClassName'] for i in emd['Classes']}

    try:
        color_mapping = {i['Value']: i['Color'] for i in emd['Classes']}
    except KeyError:
        color_mapping = {i['ClassValue']: i['Color'] for i in emd['Classes']}

    if dataset_type is None:
        stats_file = path / 'esri_accumulated_stats.json'
        with open(stats_file) as f:
            stats = json.load(f)
        dataset_type = stats['MetaDataMode']

    if dataset_type in ['RCNN_Masks', 'Classified_Tiles']:
        # Resolved independently of how dataset_type was obtained, so an
        # explicit dataset_type no longer leaves the extension unbound.
        label_ext = _infer_label_extension(path)

        def get_y_func(x, ext=label_ext):
            # images/<name>.<img_ext>  ->  labels/<name>.<label_ext>
            return x.parents[1] / 'labels' / (x.stem + '.{}'.format(ext))

        src = (ArcGISSegmentationItemList.from_folder(path/'images')
           .random_split_by_pct(val_split_pct, seed=seed)
           .label_from_func(get_y_func, classes=['NoData'] + list(class_mapping.values()), class_mapping=class_mapping, color_mapping=color_mapping))  # TODO: handle NoData case

        if transforms is None:
            transforms = get_transforms(flip_vert=True,
                                        max_rotate=90.,
                                        max_zoom=3.0,
                                        max_lighting=0.5)

        data = (src
            .transform(transforms, size=chip_size, tfm_y=True)
            .databunch(bs=batch_size, **databunch_kwargs)
            .normalize(imagenet_stats))

    elif dataset_type == 'PASCAL_VOC_rectangles':
        get_y_func = partial(_get_bbox_lbls, class_mapping=class_mapping)

        src = (SSDObjectItemList.from_folder(path/'images')
           .random_split_by_pct(val_split_pct, seed=seed)
           .label_from_func(get_y_func))

        if transforms is None:
            ranges = (0, 1)
            train_tfms = [crop(size=chip_size, p=1., row_pct=ranges, col_pct=ranges), dihedral_affine(), brightness(change=(0.4, 0.6)), contrast(scale=(0.75, 1.5)), rand_zoom(scale=(0.75, 1.5))]
            val_tfms = [crop(size=chip_size, p=1., row_pct=0.5, col_pct=0.5)]
            transforms = (train_tfms, val_tfms)

        # Bounding-box targets vary in count per image, hence the custom collate_fn.
        data = (src
            .transform(transforms, tfm_y=True)
            .databunch(bs=batch_size, collate_fn=collate_fn, **databunch_kwargs)
            .normalize(imagenet_stats))

    elif dataset_type == 'Labeled_Tiles':
        get_y_func = partial(_get_lbls, class_mapping=class_mapping)

        # Honour the caller's seed like the other branches do.
        src = (ImageItemList.from_folder(path/'images')
           .random_split_by_pct(val_split_pct, seed=seed)
           .label_from_func(get_y_func))

        if transforms is None:
            ranges = (0, 1)
            train_tfms = [rotate(degrees=30, p=0.5),
                crop(size=chip_size, p=1., row_pct=ranges, col_pct=ranges),
                dihedral_affine(), brightness(change=(0.4, 0.6)), contrast(scale=(0.75, 1.5)),
                ]
            val_tfms = [crop(size=chip_size, p=1.0, row_pct=0.5, col_pct=0.5)]
            transforms = (train_tfms, val_tfms)

        data = (src
            .transform(transforms, size=chip_size)
            .databunch(bs=batch_size, **databunch_kwargs)
            .normalize(imagenet_stats))

    else:
        raise NotImplementedError('Unknown dataset_type="{}".'.format(dataset_type))

    data.chip_size = chip_size
    data.class_mapping = class_mapping
    data.color_mapping = color_mapping
    # Default show_batch to a square-ish grid of at most 5 rows.
    data.show_batch = partial(data.show_batch, rows=min(int(math.sqrt(batch_size)), 5))
    data.orig_path = path

    return data