Example #1
def get_colorize_data(
    sz: int,
    bs: int,
    crappy_path: Path,
    good_path: Path,
    random_seed: int = None,
    keep_pct: float = 1.0,
    num_workers: int = 8,
    stats: tuple = imagenet_stats,
    xtra_tfms=[],
) -> ImageDataBunch:

    src = (ImageImageList.from_folder(crappy_path,
                                      convert_mode="RGB").use_partial_data(
                                          sample_pct=keep_pct,
                                          seed=random_seed).split_by_rand_pct(
                                              0.1, seed=random_seed))

    data = (src.label_from_func(
        lambda x: good_path / x.relative_to(crappy_path)).transform(
            get_transforms(max_zoom=1.2,
                           max_lighting=0.5,
                           max_warp=0.25,
                           xtra_tfms=xtra_tfms),
            size=sz,
            tfm_y=True,
        ).databunch(bs=bs, num_workers=num_workers,
                    no_check=True).normalize(stats, do_y=True))

    data.c = 3
    return data
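A minimal call sketch for the function above (the folder paths, sizes and seed below are hypothetical; `Path` is `pathlib.Path`):

from pathlib import Path

# Hypothetical: paired folders where each degraded image under crappy/
# has a same-named target under good/ (same relative layout).
data = get_colorize_data(
    sz=64,
    bs=32,
    crappy_path=Path("data/crappy"),
    good_path=Path("data/good"),
    random_seed=42,
    keep_pct=0.5,   # train on half the data for a quick iteration
)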
Example #2
def load_data_mtl(path, bs=8, train_size=256, xtra_tfms=None, **db_kwargs):
    """
    Create databunch for multi-task learning (classification+segmentation)

    path: path to the csv linking image paths to run-length encoded masks
    bs: batch size
    train_size: size to which images are to be resized
    xtra_tfms: additional transforms to basic fastai ones

    return: databunch with train and validation datasets
    """
    train_list = (
        MultiTaskList.
        from_csv(path.parent, path.name).
        split_by_rand_pct(valid_pct=0.2).label_from_df(
            cols=[0, 1],
            classes=['bg', 'pneum'],
            label_cls=MultiTaskLabelList, train_path=path.parent).
        transform(
            get_transforms(do_flip=False, xtra_tfms=xtra_tfms),
            size=train_size, tfm_y=True).
        databunch(
            bs=bs, num_workers=0, **db_kwargs).
        normalize(imagenet_stats))
    return train_list
Example #3
def get_DIV2k_data(pLow, pFull, bs: int, sz: int):
    """Given the path of low resolution images with a proper suffix
       returns a databunch
    """
    suffixes = {
        "dataset/DIV2K_train_LR_x8": "x8",
        "dataset/DIV2K_train_LR_difficult": "x4d",
        "dataset/DIV2K_train_LR_mild": "x4m"
    }
    lowResSuffix = suffixes[str(pLow)]
    src = ImageImageList.from_folder(pLow, presort=True).split_by_idxs(
        train_idx=list(range(0, 800)), valid_idx=list(range(800, 900)))

    data = (src.label_from_func(
        lambda x: pFull / (x.name).replace(lowResSuffix, '')).transform(
            get_transforms(max_rotate=30,
                           max_zoom=3.,
                           max_lighting=.4,
                           max_warp=.4,
                           p_affine=.85),
            size=sz,
            tfm_y=True,
        ).databunch(bs=bs, num_workers=8,
                    no_check=True).normalize(imagenet_stats, do_y=True))
    data.c = 3
    return data
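Assuming the DIV2K layout implied by the `suffixes` dict, a call could look like this (the high-resolution folder name is a guess; `pLow` must match one of the dictionary keys above):

from pathlib import Path

low = Path("dataset/DIV2K_train_LR_x8")   # key in `suffixes` -> "x8"
full = Path("dataset/DIV2K_train_HR")     # hypothetical high-res folder
data = get_DIV2k_data(low, full, bs=16, sz=128)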
Example #4
def load_data_kfold_mtl(
        path, nfolds=5, bs=8, train_size=256, xtra_tfms=None, **db_kwargs):
    """
    Create databunches for multi-task learning (classification+segmentation)
    using k-fold cross-validation

    path: path to the csv linking image paths to run-length encoded masks
    nfolds: number of folds for cross-validation
    bs: batch size
    train_size: size to which images are to be resized
    xtra_tfms: additional transforms to basic fastai ones

    yield: nfolds databunches with train and validation datasets
    """
    kf = KFold(n_splits=nfolds, shuffle=True)
    train_list = (MultiTaskList.
                  from_csv(path.parent, path.name))
    for _, valid_idx in kf.split(np.arange(len(train_list))):
        db = (
            train_list.split_by_idx(valid_idx).
            label_from_df(
                cols=[0, 1],
                classes=['bg', 'pneum'],
                label_cls=MultiTaskLabelList, train_path=path.parent).
            transform(
                get_transforms(do_flip=False, xtra_tfms=xtra_tfms),
                size=train_size, tfm_y=True).
            databunch(
                bs=bs, num_workers=0, **db_kwargs).
            normalize(imagenet_stats))
        yield db
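Since the function is a generator, each fold's databunch is consumed lazily; a sketch of a cross-validation loop (the CSV path and the training step are hypothetical):

from pathlib import Path

csv_path = Path("data/train-rle.csv")  # hypothetical CSV of image paths and RLE masks
for fold, db in enumerate(load_data_kfold_mtl(csv_path, nfolds=5, bs=8)):
    print(f"fold {fold}: {len(db.train_ds)} train / {len(db.valid_ds)} valid")
    # hypothetical: build and fit a learner on this fold's databunch here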
Example #5
def get_data_generator(self):
    data = (SSDObjectItemList.from_folder(self.images_path).split_by_files(
        list(self.img2bbox_v.keys())).label_from_func(
            self.get_y_func).transform(
                get_transforms(), tfm_y=True, size=224).databunch(
                    bs=64,
                    collate_fn=bb_pad_collate).normalize(imagenet_stats))
    return data
Example #6
def run_mnist(input_path,
              output_path,
              batch_size,
              epochs,
              learning_rate,
              model=Mnist_NN()):

    path = Path(input_path)

    ## Defining transformation
    ds_tfms = get_transforms(
        do_flip=False,
        flip_vert=False,
        max_rotate=15,
        max_zoom=1.1,
        max_lighting=0.2,
        max_warp=0.2,
    )

    ## Creating Databunch
    data = (ImageItemList.from_folder(path, convert_mode="L").split_by_folder(
        train="training", valid="testing").label_from_folder().transform(
            tfms=ds_tfms, size=28).databunch(bs=batch_size))

    ## Defining the learner
    mlp_learner = Learner(data=data,
                          model=model,
                          loss_func=nn.CrossEntropyLoss(),
                          metrics=accuracy)

    # Training the model
    mlp_learner.fit_one_cycle(epochs, learning_rate)

    val_acc = int(
        np.round(mlp_learner.recorder.metrics[-1][0].numpy().tolist(), 3) *
        1000)

    ## Saving the model
    mlp_learner.save("mlp_mnist_stg_1_" + str(val_acc))

    ## Evaluation
    print("Evaluating Network..")
    interp = ClassificationInterpretation.from_learner(mlp_learner)
    print(classification_report(interp.y_true, interp.pred_class))

    ## Plotting train and validation loss
    mlp_learner.recorder.plot_losses()
    plt.savefig(output_path + "/loss.png")

    mlp_learner.recorder.plot_metrics()
    plt.savefig(output_path + "/metric.png")
Example #7
def get_DIV2k_data_QF(pLow, pFull, bs: int, sz: int):
    """Given the path of low resolution images
       returns a databunch
    """
    src = ImageImageList.from_folder(pLow, presort=True).split_by_idxs(
        train_idx=list(range(0, 800)), valid_idx=list(range(800, 900)))

    data = (src.label_from_func(
        lambda x: pFull / (x.name.replace(".jpg", ".png"))).transform(
            get_transforms(max_zoom=2.), size=sz,
            tfm_y=True).databunch(bs=bs, num_workers=8,
                                  no_check=True).normalize(imagenet_stats,
                                                           do_y=True))
    data.c = 3
    return data
Example #8
def load_image_databunch(input_path, classes):
    """
    Code to define a databunch compatible with model
    """
    tfms = get_transforms(
        do_flip=False,
        flip_vert=False,
        max_rotate=0,
        max_lighting=0,
        max_zoom=1,
        max_warp=0,
    )

    data_bunch = ImageDataBunch.single_from_classes(
        Path(input_path), classes, ds_tfms=tfms, size=224
    )

    return data_bunch
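`ImageDataBunch.single_from_classes` produces an empty databunch intended for one-image-at-a-time inference; a sketch of pairing it with a saved checkpoint (class names, checkpoint name and image path are all hypothetical):

from fastai.vision import cnn_learner, models, open_image

classes = ["negative", "positive"]        # hypothetical class list from training
db = load_image_databunch("model_dir", classes)
learn = cnn_learner(db, models.resnet34)
learn.load("stage-1")                     # hypothetical checkpoint under model_dir/models
pred_class, pred_idx, probs = learn.predict(open_image("sample.jpg"))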
Example #9
def load_catsvsdog(input_path, batch_size):
    """
    Function to load data from cats vs dog Kaggle competition
    """
    path = Path(input_path)
    fnames = get_image_files(path)

    # Creating Databunch
    data = ImageDataBunch.from_name_re(
        path,
        fnames,
        pat=r"([^/]+)\.\d+.jpg$",
        ds_tfms=get_transforms(),
        valid_pct=0.2,
        size=227,
        bs=batch_size,
    ).normalize()

    return data
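The label regex works on the Kaggle layout where files are named `cat.123.jpg` / `dog.42.jpg`: the captured group is everything before the numeric index. A quick check of the pattern:

import re

m = re.search(r"([^/]+)\.\d+.jpg$", "train/cat.123.jpg")
print(m.group(1))  # -> "cat"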
Example #10
def load_data_classif(
        path, bs=8, train_size=256, xtra_tfms=None, **db_kwargs):
    """
    Create databunch for classification task

    path: path to the csv linking image paths to labels
    bs: batch size
    train_size: size to which image are to be resized
    xtra_tfms: additional transforms to basic fastai ones

    return: databunch with train and validation datasets
    """
    train_list = (PneumoClassifList.
                  from_csv(path.parent, path.name).
                  split_by_rand_pct(valid_pct=0.2).
                  label_from_df().
                  transform(get_transforms(do_flip=False, xtra_tfms=xtra_tfms),
                            size=train_size).
                  databunch(bs=bs, num_workers=0, **db_kwargs).
                  normalize(imagenet_stats))
    return train_list
Example #11
def predict(images):
    # loading model
    train_dir = "./pipelines/lymph_node/data/data_bunch"
    base_dir = "./pipelines/lymph_node/data"  #base directory
    l = os.listdir(train_dir)
    #random.shuffle(l)
    tfms = get_transforms(do_flip=True)
    #do_flip: if True, a random flip is applied with probability 0.5 to images
    bs = 64  # also the default batch size
    print("loaddd")
    #ImageDataBunch splits out the imnages (in the train sub-folder) into a training set and validation set (defaulting to an 80/20 percent split)
    data = ImageDataBunch.from_csv(base_dir,
                                   ds_tfms=tfms,
                                   size=224,
                                   suffix=".tiff",
                                   folder="data_bunch",
                                   csv_labels="dummy_labels.csv",
                                   bs=bs)
    print("valid ", data.valid_ds)
    print("train ", data.train_ds)
    print("test ", data.test_ds)
    # transform the image values according to the neural network we are using
    data.normalize(imagenet_stats)

    # cnn_learner loads the model into the learn variable
    learn = cnn_learner(data,
                        models.densenet161,
                        metrics=error_rate,
                        callback_fns=ShowGraph)

    learn = learn.load("./densenet10epochs")
    #predicting labels
    print("prediction ", type(images))
    print(images)
    print("size ", len(images), images[0].shape)
    preds = learn.predict(images[0])
    # preds = learn.pred_batch(np.array(images))  # TODO
    print(type(preds))
    print("prediction ", preds)
    return preds
Example #12
def get_colorize_data_with_samplers(sz: int,
                                    bs: int,
                                    crappy_path: Path,
                                    good_path: Path,
                                    random_seed: int = None,
                                    keep_pct: float = 1.0,
                                    num_workers: int = 4,
                                    samplers=None,
                                    valid_pct=0.2,
                                    xtra_tfms=[]) -> ImageDataBunch:
    src = (ImageImageList.from_folder(crappy_path,
                                      convert_mode='RGB').use_partial_data(
                                          sample_pct=keep_pct,
                                          seed=random_seed).split_by_rand_pct(
                                              valid_pct, seed=random_seed))
    if xtra_tfms is not None:
        data = (src.label_from_func(
            lambda x: good_path / x.relative_to(crappy_path)).transform(
                (xtra_tfms, None), size=sz,
                tfm_y=True).databunch(bs=bs,
                                      num_workers=num_workers,
                                      sampler=samplers,
                                      no_check=True).normalize(imagenet_stats,
                                                               do_y=True))
    else:
        data = (src.label_from_func(
            lambda x: good_path / x.relative_to(crappy_path)).transform(
                get_transforms(max_zoom=1.5,
                               max_lighting=0.4,
                               max_warp=0.25,
                               xtra_tfms=[]),
                size=sz,
                tfm_y=True).databunch(bs=bs,
                                      num_workers=num_workers,
                                      sampler=samplers,
                                      no_check=True).normalize(imagenet_stats,
                                                               do_y=True))
    data.c = 3
    return data
Example #13
def train():
    # load the VGG16 network and initialize the label encoder
    model = VGG16(weights="imagenet", include_top=False)
    tfms = transform.get_transforms(do_flip=True,
                                    flip_vert=True,
                                    max_rotate=30.,
                                    max_zoom=1.05)
    train_x = np.array([])
    train_y = np.array([])
    # Collect all train images
    train_imgs = np.array([])
    for y, label in enumerate(LABELS):
        img_paths = list(paths.list_images(TRAIN_DIR + label))
        for path in img_paths:
            # Augment the original image (x=4) and include the original, all of size (244, 244, 3)
            augmented_imgs = np.array(
                [img_to_array(img) / 255 for img in augment_image(path, x=4)])
            augmented_y = np.ones(augmented_imgs.shape[0]) * y
            if train_imgs.shape == (0, ):
                train_imgs = augmented_imgs
                train_y = augmented_y
                continue
            train_imgs = np.r_[train_imgs, augmented_imgs]
            train_y = np.r_[train_y, augmented_y]
    dataset_size = train_imgs.shape[0]
    for start in range(0, dataset_size, BATCH_SIZE):
        batchImages = train_imgs[start:np.min((start + BATCH_SIZE,
                                               dataset_size))]
        features = model.predict(batchImages,
                                 batch_size=np.min(
                                     (BATCH_SIZE, len(batchImages))))
        features = features.reshape(
            (features.shape[0], np.prod(features.shape[1:])))
        if train_x.shape == (0, ):
            train_x = features
            continue
        train_x = np.r_[train_x, features]
    return train_x, train_y
Example #14
def inference_dict(path_to_model, progress_output=True):
    """
    Makes inference on `test` images from `path_to_model/test` subfolder.
    """
    data = (ImageList.from_folder(path_to_model)
        .split_by_folder()
        .label_from_folder()
        .add_test_folder('test')
        .transform(get_transforms(), size=224)
        .databunch()
        .normalize(imagenet_stats))
    # load model from `export.pkl` file
    learn = load_learner(path_to_model)
    # inference on all test images
    res_dict = dict()
    for idx in range(len(data.test_ds)):
        img = data.test_ds[idx][0]
        start_time = time.time()
        label, _, probs = learn.predict(img)
        elapsed_time = time.time() - start_time
        label = str(label)
        fname = data.test_dl.dataset.items[idx].stem
        # create dictionary value (future dataframe row)
        row = [label]
        row.extend([float(p) for p in probs])
        row.extend([elapsed_time])
        res_dict[fname] = row
        if progress_output:
            print("'{}' --> '{:>17}' class with probabilities [{:04.2f}, {:04.2f}, {:04.2f}] inference time: {:04.3} seconds".
                  format(fname, label, probs[0], probs[1], probs[2], elapsed_time))
    # creating columns names for pretty outputs
    prob_names = data.classes
    prob_names = ["p_" + el for el in prob_names]
    columns = ['label']
    columns.extend(prob_names)
    columns.extend(['time'])
    df = pd.DataFrame.from_dict(res_dict, orient='index', columns=columns)
    return df
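A usage sketch (the model folder is hypothetical; it must contain `train`/`valid`/`test` subfolders plus the `export.pkl` written by `learn.export()`):

from pathlib import Path

df = inference_dict(Path("models/classifier"), progress_output=False)
df.to_csv("predictions.csv")  # one row per test image: label, class probabilities, time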
Example #15
def load_model(inference=False):
    if inference:
        data = ImageDataBunch.load_empty(TRAIN_PATH)
    else:
        np.random.seed(1337)  # give consistency to the validation set
        data = ImageDataBunch.from_folder(TRAIN_PATH,
                                          train=".",
                                          valid_pct=0.1,
                                          ds_tfms=transform.get_transforms(),
                                          size=224,
                                          num_workers=4,
                                          bs=32).normalize(imagenet_stats)

        data.export()  # Save the classes used in training for inference

    learn = learner.cnn_learner(data,
                                models.resnet34,
                                metrics=metrics.error_rate)

    if inference:
        learn.load(MODEL_NAME)

    return learn, data
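A sketch of both modes, assuming `TRAIN_PATH` and `MODEL_NAME` are the module-level constants this snippet relies on:

# First run: train, save a checkpoint, and export the class list.
learn, data = load_model()
learn.fit_one_cycle(4)          # hypothetical epoch count
learn.save(MODEL_NAME)

# Later: rebuild the learner from the empty databunch and the saved weights.
learn, _ = load_model(inference=True)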
Example #16
def prepare_data(path, class_mapping=None, chip_size=224, val_split_pct=0.1, batch_size=64, transforms=None, collate_fn=_bb_pad_collate, seed=42, dataset_type = None):
    """
    Prepares a Fast.ai DataBunch from the exported Pascal VOC image chips
    exported by Export Training Data tool in ArcGIS Pro or Image Server.
    This DataBunch consists of training and validation DataLoaders with the
    specified transformations, chip size, batch size, split percentage.

    =====================   ===========================================
    **Argument**            **Description**
    ---------------------   -------------------------------------------
    path                    Required string. Path to data directory.
    ---------------------   -------------------------------------------
    class_mapping           Optional dictionary. Mapping from id to
                            its string label.
    ---------------------   -------------------------------------------
    chip_size               Optional integer. Size of the image to train the
                            model.
    ---------------------   -------------------------------------------
    val_split_pct           Optional float. Percentage of training data to keep
                            as validation.
    ---------------------   -------------------------------------------
    batch_size              Optional integer. Batch size for mini batch gradient
                            descent (Reduce it if getting CUDA Out of Memory
                            Errors).
    ---------------------   -------------------------------------------
    transforms              Optional tuple. Fast.ai transforms for data
                            augmentation of training and validation datasets
                            respectively (We have set good defaults which work
                            for satellite imagery well).
    ---------------------   -------------------------------------------
    collate_fn              Optional function. Passed to PyTorch to collate data
                            into batches(usually default works).
    ---------------------   -------------------------------------------
    seed                    Optional integer. Random seed for reproducible
                            train-validation split.
    ---------------------   -------------------------------------------
    dataset_type            Optional string. The `prepare_data` function will
                            infer the `dataset_type` on its own if the path
                            contains a map.txt file. If it does not, pass one
                            of 'PASCAL_VOC_rectangles', 'RCNN_Masks' or
                            'Classified_Tiles'.
    =====================   ===========================================

    :returns: fastai DataBunch object
    """

    if not HAS_FASTAI:
        _raise_fastai_import_error()

    if type(path) is str:
        path = Path(path)

    databunch_kwargs = {'num_workers':0} if sys.platform == 'win32' else {}

    json_file = path / 'esri_model_definition.emd'
    with open(json_file) as f:
        emd = json.load(f)

    if class_mapping is None:
        try:
            class_mapping = {i['Value'] : i['Name'] for i in emd['Classes']}
        except KeyError:
            class_mapping = {i['ClassValue'] : i['ClassName'] for i in emd['Classes']}
    
    color_mapping = None
    if color_mapping is None:
        try:
            color_mapping = {i['Value'] : i['Color'] for i in emd['Classes']}
        except KeyError:          
            color_mapping = {i['ClassValue'] : i['Color'] for i in emd['Classes']}                

        # if [-1, -1, -1] in color_mapping.values():
        #     for c_idx, c_color in color_mapping.items():
        #         if c_color[0] == -1:
        #             color_mapping[c_idx] = [random.choice(range(256)) for i in range(3)]

        #color_mapping[0] = [0, 0, 0] 

    if dataset_type is None:

        stats_file = path / 'esri_accumulated_stats.json'
        with open(stats_file) as f:
            stats = json.load(f)
            dataset_type = stats['MetaDataMode']

        # imagefile_types = ['png', 'jpg', 'tif', 'jpeg', 'tiff']
        # bboxfile_types = ['xml', 'json']
        with open(path / 'map.txt') as f:
            line = f.readline()
        # left = line.split()[0].split('.')[-1].lower()
        right = line.split()[1].split('.')[-1].lower()
        
        # if (left in imagefile_types) and (right in imagefile_types):
        #     dataset_type = 'RCNN_Masks'
        # elif (left in imagefile_types) and (right in bboxfile_types):
        #     dataset_type = 'PASCAL_VOC_rectangles'
        # else:
        #     raise NotImplementedError('Cannot infer dataset type. The dataset type is not implemented')            
        
    
    if dataset_type in ['RCNN_Masks', 'Classified_Tiles']:
        
        def get_y_func(x, ext=right):
            return x.parents[1] / 'labels' / (x.stem + '.{}'.format(ext))
        
        src = (ArcGISSegmentationItemList.from_folder(path/'images')
           .random_split_by_pct(val_split_pct, seed=seed)
           .label_from_func(get_y_func, classes=['NoData'] + list(class_mapping.values()), class_mapping=class_mapping, color_mapping=color_mapping))  # TODO: Handle NoData case

        if transforms is None:
            transforms = get_transforms(flip_vert=True,
                                        max_rotate=90.,
                                        max_zoom=3.0,
                                        max_lighting=0.5) #,
    #                                     xtra_tfms=[skew(direction=(1,8),
    #                                     magnitude=(0.2,0.8))]) 


        data = (src
            .transform(transforms, size=chip_size, tfm_y=True)
            .databunch(bs=batch_size, **databunch_kwargs)
            .normalize(imagenet_stats))
        
    elif dataset_type == 'PASCAL_VOC_rectangles': 


        get_y_func = partial(_get_bbox_lbls, class_mapping=class_mapping)

        src = (SSDObjectItemList.from_folder(path/'images')
           .random_split_by_pct(val_split_pct, seed=seed)
           .label_from_func(get_y_func))

        if transforms is None:
            ranges = (0,1)
            train_tfms = [crop(size=chip_size, p=1., row_pct=ranges, col_pct=ranges), dihedral_affine(), brightness(change=(0.4, 0.6)), contrast(scale=(0.75, 1.5)), rand_zoom(scale=(0.75, 1.5))]
            val_tfms = [crop(size=chip_size, p=1., row_pct=0.5, col_pct=0.5)]
            transforms = (train_tfms, val_tfms)

        data = (src
            .transform(transforms, tfm_y=True)
            .databunch(bs=batch_size, collate_fn=collate_fn, **databunch_kwargs)
            .normalize(imagenet_stats))
        
    elif dataset_type == 'Labeled_Tiles':


        get_y_func = partial(_get_lbls, class_mapping=class_mapping)

        src = (ImageItemList.from_folder(path/'images')
           .random_split_by_pct(val_split_pct, seed=42)
           .label_from_func(get_y_func))

        if transforms is None:
            # transforms = get_transforms(flip_vert=True,
            #                             max_warp=0,
            #                             max_rotate=90.,
            #                             max_zoom=1.5,
            #                             max_lighting=0.5)
            ranges = (0, 1)
            train_tfms = [rotate(degrees=30, p=0.5),
                crop(size=chip_size, p=1., row_pct=ranges, col_pct=ranges), 
                dihedral_affine(), brightness(change=(0.4, 0.6)), contrast(scale=(0.75, 1.5)),
                # rand_zoom(scale=(0.75, 1.5))
                ]
            val_tfms = [crop(size=chip_size, p=1.0, row_pct=0.5, col_pct=0.5)]
            transforms = (train_tfms, val_tfms)

        data = (src
            .transform(transforms, size=chip_size)
            .databunch(bs=batch_size, **databunch_kwargs)
            .normalize(imagenet_stats))
        
    else:
        raise NotImplementedError('Unknown dataset_type="{}".'.format(dataset_type))    

    data.chip_size = chip_size
    data.class_mapping = class_mapping
    data.color_mapping = color_mapping
    show_batch_func = data.show_batch
    show_batch_func = partial(show_batch_func, rows=min(int(math.sqrt(batch_size)), 5))
    data.show_batch = show_batch_func
    data.orig_path = path

    return data
Example #17
def get_learner(classes):
    # TODO: Can we make this faster/lighter?
    data = ImageDataBunch.single_from_classes(".", classes, ds_tfms=get_transforms(), size=224).normalize(imagenet_stats)
    learn = create_cnn(data, resnet34, pretrained=False)
    learn.load('makemodel-392')
    return learn
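A follow-up sketch (the class names and the image path are hypothetical; `open_image` is the fastai v1 image loader):

learn = get_learner(["real", "fake"])   # hypothetical class names
pred_class, pred_idx, probs = learn.predict(open_image("sample.jpg"))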
Example #18
'''
File: /Users/origami/Desktop/dl-projects/dl-playground/src/imagesClassify/img_classify.py
Project: /Users/origami/Desktop/dl-projects/dl-playground/src/imagesClassify
Created Date: Monday May 20th 2019
Author: Rick yang tongxue(🍔🍔) ([email protected])
-----
Last Modified: Wednesday May 22nd 2019 9:18:55 am
Modified By: Rick yang tongxue(🍔🍔) ([email protected])
-----
'''
from fastai.vision.transform import get_transforms
from numpy import random
from fastai.vision.data import ImageList
import pandas as pd
import os, sys
path = os.path.abspath('../../data/imageClassify')
df = pd.read_csv(path + '/list_attr_celeba_fixed.csv')
tfms = get_transforms(flip_vert=True,
                      max_lighting=0.1,
                      max_zoom=1.05,
                      max_warp=0.)
src = (ImageList.from_csv(
    path, 'list_attr_celeba.csv',
    folder='img_align_celeba').split_by_rand_pct(0.2).label_from_df(
        label_delim=' '))
Example #19
"""
Training the PyTorch model
"""
import torch
from fastai.vision.transform import get_transforms
from fastai.vision.learner import cnn_learner
from fastai.vision.data import ImageDataBunch, imagenet_stats
from fastai.metrics import accuracy
from fastai.vision import models
from PIL import ImageFile
import dill

#defaults.device = torch.device('cuda')
DATA_PATH = '/valohai/inputs/dataset/dataset/'
MODEL_PATH = '/valohai/outputs/'
# Data augmentation: create a list of flip, rotate, zoom, warp, lighting transforms...
tfms = get_transforms()
# Create databunch from imagenet style dataset in path with
# images resized 224x224 and batch size equal to 64
# and validation set about 30% of the dataset
data = ImageDataBunch.from_folder(DATA_PATH,
                                  ds_tfms=tfms,
                                  size=224,
                                  bs=64,
                                  valid_pct=0.3).normalize(imagenet_stats)
# Get a pretrained model (resnet34) with a custom head that is suitable for our data.
learn = cnn_learner(data, models.resnet34, metrics=[accuracy])
learn.model_dir = MODEL_PATH

ImageFile.LOAD_TRUNCATED_IMAGES = True
# Fit a model following the 1cycle policy with 50 epochs
learn.fit_one_cycle(50)
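After training, the model would typically be persisted; a sketch using standard fastai v1 calls (the checkpoint name is arbitrary):

learn.save("resnet34-50ep")   # weights checkpoint, written under learn.model_dir (MODEL_PATH)
learn.export()                # full inference bundle (export.pkl) for load_learner()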
Example #20
def prepare_data(path,
                 class_mapping=None,
                 chip_size=224,
                 val_split_pct=0.1,
                 batch_size=64,
                 transforms=None,
                 collate_fn=_bb_pad_collate,
                 seed=42,
                 dataset_type=None,
                 resize_to=None,
                 **kwargs):
    """
    Prepares a data object from training samples exported by the
    Export Training Data tool in ArcGIS Pro or Image Server, or training 
    samples in the supported dataset formats. This data object consists of 
    training and validation data sets with the specified transformations, 
    chip size, batch size, split percentage, etc. 
    - For object detection, use Pascal_VOC_rectangles format.
    - For feature categorization, use Labelled Tiles or ImageNet format.
    - For pixel classification, use Classified Tiles format.
    - For entity extraction from text, use IOB, BILUO or ner_json formats.

    =====================   ===========================================
    **Argument**            **Description**
    ---------------------   -------------------------------------------
    path                    Required string. Path to data directory.
    ---------------------   -------------------------------------------
    class_mapping           Optional dictionary. Mapping from id to
                            its string label.
                            For dataset_type=IOB, BILUO or ner_json:
                                Provide address field as class mapping
                                in below format:
                                class_mapping={'address_tag':'address_field'}
    ---------------------   -------------------------------------------
    chip_size               Optional integer. Size of the image to train the
                            model.
    ---------------------   -------------------------------------------
    val_split_pct           Optional float. Percentage of training data to keep
                            as validation.
    ---------------------   -------------------------------------------
    batch_size              Optional integer. Batch size for mini batch gradient
                            descent (Reduce it if getting CUDA Out of Memory
                            Errors).
    ---------------------   -------------------------------------------
    transforms              Optional tuple. Fast.ai transforms for data
                            augmentation of training and validation datasets
                            respectively (We have set good defaults which work
                            for satellite imagery well). If transforms is set
                            to `False` no transformation will take place and 
                            `chip_size` parameter will also not take effect.
    ---------------------   -------------------------------------------
    collate_fn              Optional function. Passed to PyTorch to collate data
                            into batches(usually default works).
    ---------------------   -------------------------------------------
    seed                    Optional integer. Random seed for reproducible
                            train-validation split.
    ---------------------   -------------------------------------------
    dataset_type            Optional string. The `prepare_data` function will
                            infer the `dataset_type` on its own if the path
                            contains a map.txt file. If it does not, pass one
                            of 'PASCAL_VOC_rectangles', 'RCNN_Masks' or
                            'Classified_Tiles'.
    ---------------------   -------------------------------------------
    resize_to               Optional integer. Resize the image to given size.
    =====================   ===========================================

    :returns: data object
    """
    """kwargs documentation
    imagery_type='RGB' # Change to known imagery_type or anything else to trigger multispectral
    bands=None # sepcify bands type for unknow imagery ['r', 'g', 'b', 'nir']
    rgb_bands=[0, 1, 2] # specify rgb bands indices for unknown imagery
    norm_pct=0.3 # sample of images to calculate normalization stats on 
    do_normalize=True # Normalize data 
    """

    height_width = []

    if not HAS_FASTAI:
        _raise_fastai_import_error()

    if isinstance(path, str) and not os.path.exists(path):
        raise Exception("Invalid input path.")

    if type(path) is str:
        path = Path(path)

    databunch_kwargs = {'num_workers': 0} if sys.platform == 'win32' else {}
    databunch_kwargs['bs'] = batch_size

    kwargs_transforms = {}
    if resize_to:
        kwargs_transforms['size'] = resize_to

    has_esri_files = _check_esri_files(path)
    alter_class_mapping = False
    color_mapping = None

    # Multispectral Kwargs init
    _bands = None
    _imagery_type = None
    _is_multispectral = False
    _show_batch_multispectral = None

    if dataset_type is None and not has_esri_files:
        raise Exception("Could not infer dataset type.")

    if dataset_type != "Imagenet" and has_esri_files:
        stats_file = path / 'esri_accumulated_stats.json'
        with open(stats_file) as f:
            stats = json.load(f)
            dataset_type = stats['MetaDataMode']

        with open(path / 'map.txt') as f:
            line = f.readline()

        right = line.split()[1].split('.')[-1].lower()

        json_file = path / 'esri_model_definition.emd'
        with open(json_file) as f:
            emd = json.load(f)

        # Create Class Mapping from EMD if not specified by user
        ## Validate user defined class_mapping keys with emd (issue #3064)
        # Get classmapping from emd file.
        try:
            emd_class_mapping = {i['Value']: i['Name'] for i in emd['Classes']}
        except KeyError:
            emd_class_mapping = {
                i['ClassValue']: i['ClassName']
                for i in emd['Classes']
            }

        ## Change all keys to int.
        if class_mapping is not None:
            class_mapping = {
                int(key): value
                for key, value in class_mapping.items()
            }
        else:
            class_mapping = {}

        ## Map values from user defined classmapping to emd classmapping.
        for key, _ in emd_class_mapping.items():
            if class_mapping.get(key) is not None:
                emd_class_mapping[key] = class_mapping[key]

        class_mapping = emd_class_mapping

        color_mapping = {(i.get('Value', 0) or i.get('ClassValue', 0)):
                         i['Color']
                         for i in emd.get('Classes', [])}

        if color_mapping.get(None):
            del color_mapping[None]

        if class_mapping.get(None):
            del class_mapping[None]

        # Multispectral support from EMD
        # Not Implemented Yet
        if emd.get('bands', None) is not None:
            _bands = emd.get('bands')  # Not Implemented

        if emd.get('imagery_type', None) is not None:
            _imagery_type = emd.get('imagery_type')  # Not Implemented

    elif dataset_type == 'PASCAL_VOC_rectangles' and not has_esri_files:
        if class_mapping is None:
            class_mapping = _get_class_mapping(path / 'labels')
            alter_class_mapping = True

    # Multispectral check
    imagery_type = 'RGB'
    if kwargs.get('imagery_type', None) is not None:
        imagery_type = kwargs.get('imagery_type')
    elif _imagery_type is not None:
        imagery_type = _imagery_type

    bands = None
    if kwargs.get('bands', None) is not None:
        bands = kwargs.get('bands')
        for i, b in enumerate(bands):
            if type(b) == str:
                bands[i] = b.lower()
    elif imagery_type_lib.get(imagery_type, None) is not None:
        bands = imagery_type_lib.get(imagery_type)['bands']
    elif _bands is not None:
        bands = _bands

    rgb_bands = None
    if kwargs.get('rgb_bands', None) is not None:
        rgb_bands = kwargs.get('rgb_bands')
    elif bands is not None:
        rgb_bands = [bands.index(b) for b in ['r', 'g', 'b'] if b in bands]

    if (bands is not None) or (rgb_bands is not None) or (imagery_type != 'RGB'):
        if imagery_type == 'RGB':
            imagery_type = 'multispectral'
        _is_multispectral = True

    if kwargs.get('norm_pct', None) is not None:
        norm_pct = kwargs.get('norm_pct')
        norm_pct = min(max(0, norm_pct), 1)
    else:
        norm_pct = .3

    lighting_transforms = kwargs.get('lighting_transforms', True)

    if dataset_type == 'RCNN_Masks':

        def get_labels(x, label_dirs, ext=right):
            label_path = []
            for lbl in label_dirs:
                if os.path.exists(Path(lbl) / (x.stem + '.{}'.format(ext))):
                    label_path.append(Path(lbl) / (x.stem + '.{}'.format(ext)))
            return label_path

        if class_mapping.get(0):
            del class_mapping[0]

        if color_mapping.get(0):
            del color_mapping[0]

        # Handle Multispectral
        if _is_multispectral:
            src = (ArcGISInstanceSegmentationMSItemList.from_folder(
                path / 'images').split_by_rand_pct(val_split_pct, seed=seed))
            _show_batch_multispectral = show_batch_rcnn_masks
        else:
            src = (ArcGISInstanceSegmentationItemList.from_folder(
                path / 'images').split_by_rand_pct(val_split_pct, seed=seed))

        label_dirs = []
        index_dir = {}  # for handling class values with any number
        for i, k in enumerate(sorted(class_mapping.keys())):
            label_dirs.append(class_mapping[k])
            index_dir[k] = i + 1
        label_dir = [
            os.path.join(path / 'labels', lbl) for lbl in label_dirs
            if os.path.isdir(os.path.join(path / 'labels', lbl))
        ]
        get_y_func = partial(get_labels, label_dirs=label_dir)
        src = src.label_from_func(get_y_func,
                                  chip_size=chip_size,
                                  classes=['NoData'] +
                                  list(class_mapping.values()),
                                  class_mapping=class_mapping,
                                  color_mapping=color_mapping,
                                  index_dir=index_dir)

    elif dataset_type == 'Classified_Tiles':

        def get_y_func(x, ext=right):
            return x.parents[1] / 'labels' / (x.stem + '.{}'.format(ext))

        if class_mapping.get(0):
            del class_mapping[0]

        if color_mapping.get(0):
            del color_mapping[0]

        if is_no_color(color_mapping):
            color_mapping = {
                j: [random.choice(range(256)) for i in range(3)]
                for j in class_mapping.keys()
            }

        # TODO : Handle NoData case

        # Handle Multispectral
        if _is_multispectral:
            data = ArcGISSegmentationMSItemList.from_folder(path/'images')\
                .split_by_rand_pct(val_split_pct, seed=seed)\
                .label_from_func(
                    get_y_func, classes=(['NoData'] + list(class_mapping.values())),
                    class_mapping=class_mapping,
                    color_mapping=color_mapping
                )
            _show_batch_multispectral = _show_batch_unet_multispectral

            def classified_tiles_collate_fn(
                samples
            ):  # The default fastai collate_fn was causing memory leak on tensors
                r = (torch.stack([x[0].data for x in samples]),
                     torch.stack([x[1].data for x in samples]))
                return r

            databunch_kwargs['collate_fn'] = classified_tiles_collate_fn

        else:
            data = ArcGISSegmentationItemList.from_folder(path/'images')\
                .split_by_rand_pct(val_split_pct, seed=seed)\
                .label_from_func(
                    get_y_func, classes=(['NoData'] + list(class_mapping.values())),
                    class_mapping=class_mapping,
                    color_mapping=color_mapping
                )

        if transforms is None:
            transforms = get_transforms(flip_vert=True,
                                        max_rotate=90.,
                                        max_zoom=3.0,
                                        max_lighting=0.5)

        kwargs_transforms['tfm_y'] = True
        kwargs_transforms['size'] = chip_size
    elif dataset_type == 'PASCAL_VOC_rectangles':
        not_label_count = [0]
        get_y_func = partial(_get_bbox_lbls,
                             class_mapping=class_mapping,
                             not_label_count=not_label_count,
                             height_width=height_width)

        if _is_multispectral:
            data = SSDObjectMSItemList.from_folder(path/'images')\
            .split_by_rand_pct(val_split_pct, seed=seed)\
            .label_from_func(get_y_func)
            _show_batch_multispectral = show_batch_pascal_voc_rectangles
        else:
            data = SSDObjectItemList.from_folder(path/'images')\
                .split_by_rand_pct(val_split_pct, seed=seed)\
                .label_from_func(get_y_func)

        if not_label_count[0]:
            logger = logging.getLogger()
            logger.warning("Please check your dataset. " +
                           str(not_label_count[0]) +
                           " images dont have the corresponding label files.")

        if transforms is None:
            ranges = (0, 1)
            train_tfms = [
                crop(size=chip_size, p=1., row_pct=ranges, col_pct=ranges),
                dihedral_affine() if has_esri_files else flip_lr(),
                brightness(change=(0.4, 0.6)),
                contrast(scale=(0.75, 1.5)),
                rand_zoom(scale=(1.0, 1.5))
            ]
            val_tfms = [crop(size=chip_size, p=1., row_pct=0.5, col_pct=0.5)]
            transforms = (train_tfms, val_tfms)

        kwargs_transforms['tfm_y'] = True
        databunch_kwargs['collate_fn'] = collate_fn
    elif dataset_type in ['Labeled_Tiles', 'Imagenet']:
        if dataset_type == 'Labeled_Tiles':
            get_y_func = partial(_get_lbls, class_mapping=class_mapping)
        else:

            def get_y_func(x):
                return x.parent.stem

        if _is_multispectral:
            data = ArcGISMSImageList.from_folder(path/'images')\
                .split_by_rand_pct(val_split_pct, seed=42)\
                .label_from_func(get_y_func)
            _show_batch_multispectral = show_batch_labeled_tiles
        else:
            data = ImageList.from_folder(path/'images')\
                .split_by_rand_pct(val_split_pct, seed=42)\
                .label_from_func(get_y_func)

        if dataset_type == 'Imagenet':
            class_mapping = {}
            index = 1
            for class_name in data.classes:
                class_mapping[index] = class_name
                index = index + 1

        if transforms is None:
            ranges = (0, 1)
            train_tfms = [
                rotate(degrees=30, p=0.5),
                crop(size=chip_size, p=1., row_pct=ranges, col_pct=ranges),
                dihedral_affine(),
                brightness(change=(0.4, 0.6)),
                contrast(scale=(0.75, 1.5))
            ]
            val_tfms = [crop(size=chip_size, p=1.0, row_pct=0.5, col_pct=0.5)]
            transforms = (train_tfms, val_tfms)
    elif dataset_type in ['ner_json', 'BIO', 'IOB', 'LBIOU', 'BILUO']:
        return ner_prepare_data(dataset_type=dataset_type,
                                path=path,
                                class_mapping=class_mapping,
                                val_split_pct=val_split_pct)
    else:
        raise NotImplementedError(
            'Unknown dataset_type="{}".'.format(dataset_type))

    if _is_multispectral:
        if dataset_type == 'RCNN_Masks':
            kwargs['do_normalize'] = False
            if transforms is None:
                data = (src.transform(
                    size=chip_size, tfm_y=True).databunch(**databunch_kwargs))
            else:
                data = (src.transform(
                    transforms, size=chip_size,
                    tfm_y=True).databunch(**databunch_kwargs))
        else:
            data = (data.transform(
                transforms, **kwargs_transforms).databunch(**databunch_kwargs))

        if len(data.x) < 300:
            norm_pct = 1

        # Statistics
        dummy_stats = {
            "batch_stats_for_norm_pct_0": {
                "band_min_values": None,
                "band_max_values": None,
                "band_mean_values": None,
                "band_std_values": None,
                "scaled_min_values": None,
                "scaled_max_values": None,
                "scaled_mean_values": None,
                "scaled_std_values": None
            }
        }
        normstats_json_path = os.path.abspath(data.path / '..' /
                                              'esri_normalization_stats.json')
        if not os.path.exists(normstats_json_path):
            normstats = dummy_stats
            with open(normstats_json_path, 'w', encoding='utf-8') as f:
                json.dump(normstats, f, ensure_ascii=False, indent=4)
        else:
            with open(normstats_json_path) as f:
                normstats = json.load(f)

        norm_pct_search = f"batch_stats_for_norm_pct_{round(norm_pct*100)}"
        if norm_pct_search in normstats:
            batch_stats = normstats[norm_pct_search]
            for s in batch_stats:
                if batch_stats[s] is not None:
                    batch_stats[s] = torch.tensor(batch_stats[s])
        else:
            batch_stats = _get_batch_stats(data.x, norm_pct)
            normstats[norm_pct_search] = dict(batch_stats)
            for s in normstats[norm_pct_search]:
                if normstats[norm_pct_search][s] is not None:
                    normstats[norm_pct_search][s] = normstats[norm_pct_search][
                        s].tolist()
            with open(normstats_json_path, 'w', encoding='utf-8') as f:
                json.dump(normstats, f, ensure_ascii=False, indent=4)

        # batch_stats -> [band_min_values, band_max_values, band_mean_values, band_std_values, scaled_min_values, scaled_max_values, scaled_mean_values, scaled_std_values]
        data._band_min_values = batch_stats['band_min_values']
        data._band_max_values = batch_stats['band_max_values']
        data._band_mean_values = batch_stats['band_mean_values']
        data._band_std_values = batch_stats['band_std_values']
        data._scaled_min_values = batch_stats['scaled_min_values']
        data._scaled_max_values = batch_stats['scaled_max_values']
        data._scaled_mean_values = batch_stats['scaled_mean_values']
        data._scaled_std_values = batch_stats['scaled_std_values']

        # Prevent Divide by zeros
        data._band_max_values[data._band_min_values ==
                              data._band_max_values] += 1
        data._scaled_std_values[data._scaled_std_values == 0] += 1e-02

        # Scaling
        data._min_max_scaler = partial(_tensor_scaler,
                                       min_values=data._band_min_values,
                                       max_values=data._band_max_values,
                                       mode='minmax')
        data._min_max_scaler_tfm = partial(_tensor_scaler_tfm,
                                           min_values=data._band_min_values,
                                           max_values=data._band_max_values,
                                           mode='minmax')

        #data.add_tfm(data._min_max_scaler_tfm)

        # Transforms
        def _scaling_tfm(x):
            ## Scale MS image values into the 0-1 range fastai expects
            return x.__class__(data._min_max_scaler_tfm((x.data, None))[0][0])

        ## Fastai needs tfm, order and resolve attributes on a transform.
        class dummy():
            pass

        _scaling_tfm.tfm = dummy()
        _scaling_tfm.tfm.order = 0
        _scaling_tfm.resolve = dummy

        ## Scale the images before applying any other transform
        if getattr(data.train_ds, 'tfms') is not None:
            data.train_ds.tfms = [_scaling_tfm] + data.train_ds.tfms
        else:
            data.train_ds.tfms = [_scaling_tfm]
        if getattr(data.valid_ds, 'tfms') is not None:
            data.valid_ds.tfms = [_scaling_tfm] + data.valid_ds.tfms
        else:
            data.valid_ds.tfms = [_scaling_tfm]

        # Normalize
        data._do_normalize = True
        if kwargs.get('do_normalize', None) is not None:
            data._do_normalize = kwargs.get('do_normalize', True)
        if data._do_normalize:
            data = data.normalize(stats=(data._scaled_mean_values,
                                         data._scaled_std_values),
                                  do_x=True,
                                  do_y=False)

    elif dataset_type == 'RCNN_Masks':
        if transforms is None:
            data = (src.transform(size=chip_size,
                                  tfm_y=True).databunch(**databunch_kwargs))
        else:
            data = (src.transform(transforms, size=chip_size,
                                  tfm_y=True).databunch(**databunch_kwargs))
        data.show_batch = types.MethodType(show_batch_rcnn_masks, data)
    else:
        #
        data = (data.transform(transforms, **kwargs_transforms).databunch(
            **databunch_kwargs).normalize(imagenet_stats))

    data.chip_size = data.x[0].shape[-1] if transforms is False else chip_size

    if alter_class_mapping:
        new_mapping = {}
        for i, class_name in enumerate(class_mapping.keys()):
            new_mapping[i + 1] = class_name
        class_mapping = new_mapping

    data.class_mapping = class_mapping
    data.color_mapping = color_mapping
    data.show_batch = partial(data.show_batch,
                              rows=min(int(math.sqrt(batch_size)), 5))
    data.orig_path = path
    data.resize_to = kwargs_transforms.get('size', None)
    data.height_width = height_width

    data._is_multispectral = _is_multispectral
    if data._is_multispectral:
        data._imagery_type = imagery_type
        data._bands = bands
        data._norm_pct = norm_pct
        data._rgb_bands = rgb_bands
        data._symbology_rgb_bands = rgb_bands

        # Handle invalid color mapping
        data._multispectral_color_mapping = color_mapping
        if any(-1 in x for x in data._multispectral_color_mapping.values()):
            random_color_list = np.random.randint(
                low=0,
                high=255,
                size=(len(data._multispectral_color_mapping), 3)).tolist()
            for i, c in enumerate(data._multispectral_color_mapping):
                if -1 in data._multispectral_color_mapping[c]:
                    data._multispectral_color_mapping[c] = random_color_list[i]

        # prepare color array
        alpha = kwargs.get('alpha', 0.7)
        color_array = torch.tensor(list(
            data.color_mapping.values())).float() / 255
        alpha_tensor = torch.tensor([alpha] * len(color_array)).view(
            -1, 1).float()
        color_array = torch.cat([color_array, alpha_tensor], dim=-1)
        background_color = torch.tensor([[0, 0, 0, 0]]).float()
        data._multispectral_color_array = torch.cat(
            [background_color, color_array])

        # Prepare unknown bands list if bands data is missing
        if data._bands is None:
            n_bands = data.x[0].data.shape[0]
            if n_bands == 1:  # Handle Pancromatic case
                data._bands = ['p']
                data._symbology_rgb_bands = [0]
            else:
                data._bands = ['u' for i in range(n_bands)]
                if n_bands == 2:  # Handle Data with two channels
                    data._symbology_rgb_bands = [0]

        #
        if data._rgb_bands is None:
            data._rgb_bands = []

        #
        if data._symbology_rgb_bands is None:
            data._symbology_rgb_bands = [0, 1, 2][:min(n_bands, 3)]

        # Complete symbology rgb bands
        if len(data._bands) > 2 and len(data._symbology_rgb_bands) < 3:
            data._symbology_rgb_bands += [
                min(max(data._symbology_rgb_bands) + 1,
                    len(data._bands) - 1)
                for i in range(3 - len(data._symbology_rgb_bands))
            ]

        # Overwrite band values at r g b indexes with 'r' 'g' 'b'
        for i, band_idx in enumerate(data._rgb_bands):
            if band_idx is not None:
                if data._bands[band_idx] == 'u':
                    data._bands[band_idx] = ['r', 'g', 'b'][i]

        # Attach custom show batch
        if _show_batch_multispectral is not None:
            data.show_batch = types.MethodType(_show_batch_multispectral, data)

        # Apply filter band transformation if user has specified extract_bands otherwise add a generic extract_bands
        """
        extract_bands : List containing band indices of the bands from imagery on which the model would be trained. 
                        Useful for benchmarking and applied training, for reference see examples below.
                        
                        4 band naip ['r, 'g', 'b', 'nir'] + extract_bands=[0, 1, 2] -> 3 band naip with bands ['r', 'g', 'b'] 

        """
        data._extract_bands = kwargs.get('extract_bands', None)
        if data._extract_bands is None:
            data._extract_bands = list(range(len(data._bands)))
        else:
            data._extract_bands_tfm = partial(_extract_bands_tfm,
                                              band_indices=data._extract_bands)
            data.add_tfm(data._extract_bands_tfm)

        # Tail Training Override
        _train_tail = True
        if [data._bands[i] for i in data._extract_bands] == ['r', 'g', 'b']:
            _train_tail = False
        data._train_tail = kwargs.get('train_tail', _train_tail)

    if has_esri_files:
        data._image_space_used = emd.get('ImageSpaceUsed', 'MAP_SPACE')
    else:
        data._image_space_used = 'PIXEL_SPACE'

    return data
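A hedged call sketch for this `prepare_data` (the chip directory is hypothetical and would normally be produced by the Export Training Data tool):

data = prepare_data(
    path="exported_chips",   # folder with images/, labels/, map.txt, esri_*.json
    batch_size=16,
    val_split_pct=0.1,
)
data.show_batch()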
Example #21

# +
stats = ([0.0692], [0.2051])
data = (ImageList.from_df(
    df,
    path='.',
    folder=TRAIN,
    suffix='.png',
    cols='image_id',
    convert_mode='L').split_by_idx(
        range(fold * len(df) // nfolds,
              (fold + 1) * len(df) // nfolds)).label_from_df(cols=[
                  'grapheme_root', 'vowel_diacritic', 'consonant_diacritic'
              ]).transform(
                  transform.get_transforms(do_flip=False, max_warp=0.1),
                  size=sz,
                  padding_mode='zeros').databunch(bs=bs)).normalize(stats)

data.show_batch()


# +
class Head(nn.Module):
    def __init__(self, nc, n, ps=0.5):
        super().__init__()
        layers = [AdaptiveConcatPool2d(), Mish(), Flatten()] + \
            bn_drop_lin(nc*2, 512, True, ps, Mish()) + \
            bn_drop_lin(512, n, True, ps)
        self.fc = nn.Sequential(*layers)
        self._init_weight()