コード例 #1
0
def create_generators(configs):
    """Build the (train, validation) Pascal VOC generator pair from a config dict.

    Raises:
        ValueError: if the configured dataset type is not 'pascal'.
    """
    # Augmentation: full random transform unless the config restricts us to
    # horizontal flips only.
    aug_cfg = configs['Data_Augmentation']
    if aug_cfg['only_x_flip']:
        transform_generator = random_transform_generator(flip_x_chance=0.5)
    else:
        transform_generator = random_transform_generator(
            min_rotation=aug_cfg['rotation'][0],
            max_rotation=aug_cfg['rotation'][1],
            min_translation=aug_cfg['min_translation'],
            max_translation=aug_cfg['max_translation'],
            min_shear=aug_cfg['shear'][0],
            max_shear=aug_cfg['shear'][1],
            min_scaling=aug_cfg['min_scaling'],
            max_scaling=aug_cfg['max_scaling'],
            flip_x_chance=0.5,
            gray=aug_cfg['gray'],
            inverse_color=aug_cfg['inverse_color'],
        )

    dataset_type = configs['Dataset']['dataset_type']
    if dataset_type != 'pascal':
        raise ValueError('Invalid data type received: {}'.format(dataset_type))

    # Keyword arguments common to the train and validation generators.
    shared_kwargs = {
        'classes': configs['Dataset']['classes'],
        'batch_size': configs['Train']['batch_size'],
        'image_min_side': configs['Train']['image_min_side'],
        'image_max_side': configs['Train']['image_max_side'],
        'sizes': configs['Anchors']['sizes'],
        'strides': configs['Anchors']['strides'],
        'ratios': configs['Anchors']['ratios'],
        'scales': configs['Anchors']['scales'],
    }

    # Only the training generator applies random transforms.
    train_generator = PascalVocGenerator(
        configs['Dataset']['dataset_path'],
        'trainval',
        transform_generator=transform_generator,
        **shared_kwargs)

    validation_generator = PascalVocGenerator(
        configs['Dataset']['dataset_path'],
        'test',
        **shared_kwargs)

    return train_generator, validation_generator
コード例 #2
0
ファイル: debug.py プロジェクト: maroonray/RetinaNet
def create_generator(args):
    """Build one data generator (debug helper) for the dataset type in args.

    Raises:
        ValueError: for an unrecognised args.dataset_type.
    """
    # Augmentation is limited to horizontal flips; the other geometric
    # transforms stay disabled for debugging.
    transform_generator = random_transform_generator(flip_x_chance=0.5)

    dataset_type = args.dataset_type
    if dataset_type == 'coco':
        # Imported lazily so the cocoapi dependency is only required for COCO.
        from preprocessing.coco import CocoGenerator

        return CocoGenerator(args.coco_path,
                             args.coco_set,
                             transform_generator=transform_generator)
    if dataset_type == 'pascal':
        return PascalVocGenerator(
            args.pascal_path,
            args.pascal_set,
            classes={
                'hathead': 0,
                'nohathead': 1
            },
            transform_generator=transform_generator,
            image_min_side=1080,
            image_max_side=1920,
            anchor_ratios=[0.5, 1, 1.5],
            anchor_scales=[0, 0.33333333, 0.66666667],
            anchor_sizes=[16, 32, 64, 128, 256],
            anchor_strides=[8, 16, 32, 64, 128])
    if dataset_type == 'csv':
        return CSVGenerator(args.annotations,
                            args.classes,
                            transform_generator=transform_generator)
    if dataset_type == 'oid':
        return OpenImagesGenerator(
            args.main_dir,
            subset=args.subset,
            version=args.version,
            labels_filter=args.labels_filter,
            fixed_labels=args.fixed_labels,
            annotation_cache_dir=args.annotation_cache_dir,
            transform_generator=transform_generator)
    if dataset_type == 'kitti':
        return KittiGenerator(args.kitti_path,
                              subset=args.subset,
                              transform_generator=transform_generator)
    raise ValueError('Invalid data type received: {}'.format(dataset_type))
コード例 #3
0
ファイル: debug.py プロジェクト: wang-tf/retinanet
def create_generator(args):
    """Build one data generator for the dataset type selected in args.

    Raises:
        ValueError: for an unrecognised args.dataset_type.
    """
    # Full random geometric augmentation for training data.
    transform_generator = random_transform_generator(
        min_rotation=-0.1,
        max_rotation=0.1,
        min_translation=(-0.1, -0.1),
        max_translation=(0.1, 0.1),
        min_shear=-0.1,
        max_shear=0.1,
        min_scaling=(0.9, 0.9),
        max_scaling=(1.1, 1.1),
        flip_x_chance=0.5,
        flip_y_chance=0.5,
    )

    dataset_type = args.dataset_type
    if dataset_type == 'coco':
        # Imported lazily so the cocoapi dependency is only required for COCO.
        from preprocessing.coco import CocoGenerator

        return CocoGenerator(args.coco_path,
                             args.coco_set,
                             transform_generator=transform_generator)
    if dataset_type == 'pascal':
        return PascalVocGenerator(args.pascal_path,
                                  args.pascal_set,
                                  classes={"rebar": 0},
                                  transform_generator=transform_generator)
    if dataset_type == 'csv':
        return CSVGenerator(args.annotations,
                            args.classes,
                            transform_generator=transform_generator)
    if dataset_type == 'oid':
        return OpenImagesGenerator(
            args.main_dir,
            subset=args.subset,
            version=args.version,
            labels_filter=args.labels_filter,
            fixed_labels=args.fixed_labels,
            annotation_cache_dir=args.annotation_cache_dir,
            transform_generator=transform_generator)
    if dataset_type == 'kitti':
        return KittiGenerator(args.kitti_path,
                              subset=args.subset,
                              transform_generator=transform_generator)
    raise ValueError('Invalid data type received: {}'.format(dataset_type))
コード例 #4
0
def verify_no_negative_regr():
    """Smoke-test the augmentation pipeline on Pascal VOC.

    Pulls up to 20000 batches from a fully-augmented generator; the
    anchor-target computation performed inside the generator is what would
    surface any negative regression targets.
    """
    transform_generator = random_transform_generator(
        min_rotation=-0.1,
        max_rotation=0.1,
        min_translation=(-0.1, -0.1),
        max_translation=(0.1, 0.1),
        min_shear=-0.1,
        max_shear=0.1,
        min_scaling=(0.9, 0.9),
        max_scaling=(1.1, 1.1),
        flip_x_chance=0.5,
        flip_y_chance=0.5,
    )
    visual_effect_generator = random_visual_effect_generator(
        contrast_range=(0.9, 1.1),
        brightness_range=(-.1, .1),
        hue_range=(-0.05, 0.05),
        saturation_range=(0.95, 1.05)
    )
    common_args = {
        'batch_size': 1,
        'image_min_side': 800,
        'image_max_side': 1333,
        'preprocess_image': preprocess_image,
    }
    generator = PascalVocGenerator(
        'datasets/voc_trainval/VOC0712',
        'trainval',
        transform_generator=transform_generator,
        visual_effect_generator=visual_effect_generator,
        skip_difficult=True,
        **common_args
    )
    # Bounded iteration via enumerate (was a hand-rolled `i += 1` counter);
    # the batch contents themselves are intentionally discarded.
    for batch_index, (image_group, targets) in enumerate(generator):
        if batch_index >= 20000:
            break
コード例 #5
0
def create_generators(args, preprocess_image):
    """
    Create generators for training and validation.

    Args
        args: parseargs object containing configuration for generators.
        preprocess_image: Function that preprocesses an image for the network.
    """
    # Keyword arguments shared by every generator built below.
    shared = {
        'batch_size': args.batch_size,
        'config': args.config,
        'image_min_side': args.image_min_side,
        'image_max_side': args.image_max_side,
        'preprocess_image': preprocess_image,
    }

    # Full random augmentation when requested; otherwise only horizontal
    # flips and no visual effects.
    if not args.random_transform:
        transform_generator = random_transform_generator(flip_x_chance=0.5)
        visual_effect_generator = None
    else:
        transform_generator = random_transform_generator(
            min_rotation=-0.1,
            max_rotation=0.1,
            min_translation=(-0.1, -0.1),
            max_translation=(0.1, 0.1),
            min_shear=-0.1,
            max_shear=0.1,
            min_scaling=(0.9, 0.9),
            max_scaling=(1.1, 1.1),
            flip_x_chance=0.5,
            flip_y_chance=0.5,
        )
        visual_effect_generator = random_visual_effect_generator(
            contrast_range=(0.9, 1.1),
            brightness_range=(-.1, .1),
            hue_range=(-0.05, 0.05),
            saturation_range=(0.95, 1.05)
        )

    dataset_type = args.dataset_type
    if dataset_type == 'pascal':
        train_generator = PascalVocGenerator(
            args.pascal_path,
            'trainval',
            transform_generator=transform_generator,
            visual_effect_generator=visual_effect_generator,
            skip_difficult=True,
            **shared)

        validation_generator = PascalVocGenerator(
            args.pascal_path,
            'val',
            shuffle_groups=False,
            skip_difficult=True,
            **shared)
    elif dataset_type == 'csv':
        train_generator = CSVGenerator(
            args.annotations_path,
            args.classes_path,
            transform_generator=transform_generator,
            visual_effect_generator=visual_effect_generator,
            **shared)

        # Validation is optional for CSV datasets.
        validation_generator = None
        if args.val_annotations_path:
            validation_generator = CSVGenerator(
                args.val_annotations_path,
                args.classes_path,
                shuffle_groups=False,
                **shared)
    elif dataset_type == 'coco':
        # Imported here to prevent an unnecessary dependency on cocoapi.
        from generators.coco_generator import CocoGenerator

        train_generator = CocoGenerator(
            args.coco_path,
            'train2017',
            transform_generator=transform_generator,
            visual_effect_generator=visual_effect_generator,
            **shared)

        validation_generator = CocoGenerator(
            args.coco_path,
            'val2017',
            shuffle_groups=False,
            **shared)
    else:
        raise ValueError('Invalid data type received: {}'.format(dataset_type))

    return train_generator, validation_generator
コード例 #6
0
ファイル: train.py プロジェクト: maximek3/mxk_retinanetsss
def create_generators(args, preprocess_image):
    """ Create generators for training and validation.

    Args
        args             : parseargs object containing configuration for generators.
        preprocess_image : Function that preprocesses an image for the network.

    Returns
        (train_generator, validation_generator); validation_generator is None
        when args.val_annotations is not set.

    Raises
        ValueError: for an unsupported args.fpn_layers or args.dataset_type.
    """
    # Map the requested pyramid depth to concrete FPN level indices.
    if args.fpn_layers == 5:
        fpn_layers = [3, 4, 5, 6, 7]
    elif args.fpn_layers == 4:
        fpn_layers = [4, 5, 6, 7]
    elif args.fpn_layers == 3:
        fpn_layers = [5, 6, 7]
    else:
        # Fix: previously fell through with `fpn_layers` unbound, crashing
        # below with a NameError instead of a clear configuration error.
        raise ValueError('Invalid fpn_layers received: {}'.format(args.fpn_layers))

    common_args = {
        'batch_size': args.batch_size,
        'config': args.config,
        'image_min_side': args.image_min_side,
        'image_max_side': args.image_max_side,
        'preprocess_image': preprocess_image,
        'negative_overlap': args.neg_overlap,
        'positive_overlap': args.pos_overlap,
        'train_type': args.train_type,
        'fpn_layers': fpn_layers,
        'augm': args.augm,
    }

    # Full random augmentation when requested, otherwise horizontal flips only.
    if args.random_transform:
        transform_generator = random_transform_generator(
            min_rotation=-0.1,
            max_rotation=0.1,
            min_translation=(-0.1, -0.1),
            max_translation=(0.1, 0.1),
            min_shear=-0.1,
            max_shear=0.1,
            min_scaling=(0.9, 0.9),
            max_scaling=(1.1, 1.1),
            flip_x_chance=0.5,
            flip_y_chance=0.5,
        )
    else:
        transform_generator = random_transform_generator(flip_x_chance=0.5)

    if args.dataset_type == 'ead' or args.dataset_type == 'polyp':
        train_generator = CSVGenerator(
            os.path.join(args.data_dir, args.annotations),
            os.path.join(args.data_dir, args.classes),
            transform_generator=transform_generator,
            base_dir=os.path.join(args.data_dir, args.train_dir),
            **common_args)

        if args.val_annotations:
            validation_generator = CSVGenerator(
                os.path.join(args.data_dir, args.val_annotations),
                os.path.join(args.data_dir, args.classes),
                base_dir=os.path.join(args.data_dir, args.val_dir),
                **common_args)
        else:
            validation_generator = None
    else:
        # Fix: was `dataset_type` (an undefined name), which raised a
        # NameError instead of the intended ValueError.
        raise ValueError('Invalid data type received: {}'.format(args.dataset_type))

    return train_generator, validation_generator
コード例 #7
0
)

# Load the sample image. cv2.imread returns None (no exception) on failure,
# so validate BEFORE any arithmetic -- the original divided by 255.0 first,
# which would raise a TypeError instead of printing the error message.
img = cv2.imread('../../dataset/raccoon/raccoon-1.jpg')
if img is None:
    print("can't open image")
    exit()
img = img / 255.0
print(img.shape)
# Append a 4th channel holding the per-pixel mean of the colour channels.
img2 = np.concatenate(
    [img, np.expand_dims(np.mean(img, axis=-1), axis=-1)], axis=-1)
transform_generator = random_transform_generator(
    min_rotation=-0.1,
    max_rotation=0.1,
    min_translation=(-0.1, -0.1),
    max_translation=(0.1, 0.1),
    min_shear=-0.1,
    max_shear=0.1,
    min_scaling=(0.9, 0.9),
    max_scaling=(1.1, 1.1),
    flip_x_chance=0.5,
    flip_y_chance=0.5,
)
image = img2
transform_parameters = TransformParameters()
# Adapt the (relative) translation component of the transform to this
# image's dimensions before applying it.
transform = adjust_transform_for_image(
    next(transform_generator), image,
    transform_parameters.relative_translation)
img_ = apply_transform(transform, img, transform_parameters)
image_ = apply_transform(transform, image, transform_parameters)

# Only the first three channels are displayable as a colour image.
cv2.imshow('img original', img[:, :, :3])
cv2.waitKey(0)
コード例 #8
0
def create_generators(configs):
    """Build the (train, validation) generator pair described by a config dict.

    Raises:
        ValueError: for an unrecognised dataset type.
    """
    # Augmentation: full random transform unless the config restricts us to
    # horizontal flips only.
    aug_cfg = configs['Data_Augmentation']
    if aug_cfg['only_x_flip']:
        transform_generator = random_transform_generator(flip_x_chance=0.5)
    else:
        transform_generator = random_transform_generator(
            min_rotation=aug_cfg['rotation'][0],
            max_rotation=aug_cfg['rotation'][1],
            min_translation=aug_cfg['min_translation'],
            max_translation=aug_cfg['max_translation'],
            min_shear=aug_cfg['shear'][0],
            max_shear=aug_cfg['shear'][1],
            min_scaling=aug_cfg['min_scaling'],
            max_scaling=aug_cfg['max_scaling'],
            flip_x_chance=0.5,
            gray=aug_cfg['gray'],
            inverse_color=aug_cfg['inverse_color'],
        )

    dataset_cfg = configs['Dataset']
    dataset_type = dataset_cfg['dataset_type']
    # Batch/sizing options common to every generator built below.
    size_kwargs = {
        'batch_size': configs['Train']['batch_size'],
        'image_min_side': configs['Train']['image_min_side'],
        'image_max_side': configs['Train']['image_max_side'],
    }

    if dataset_type == 'coco':
        # Imported here to prevent an unnecessary dependency on cocoapi.
        from preprocessing.coco import CocoGenerator

        train_generator = CocoGenerator(
            dataset_cfg['dataset_path'],
            'train2017',
            transform_generator=transform_generator,
            **size_kwargs)

        validation_generator = CocoGenerator(
            dataset_cfg['dataset_path'],
            'val2017',
            **size_kwargs)
    elif dataset_type == 'pascal':
        # Anchor configuration is only consumed by the Pascal generator.
        anchor_kwargs = {
            'anchor_ratios': configs['Anchors']['ratios'],
            'anchor_scales': configs['Anchors']['scales'],
            'anchor_sizes': configs['Anchors']['sizes'],
            'anchor_strides': configs['Anchors']['strides'],
        }
        train_generator = PascalVocGenerator(
            dataset_cfg['dataset_path'],
            'trainval',
            classes=dataset_cfg['classes'],
            transform_generator=transform_generator,
            **size_kwargs, **anchor_kwargs)

        validation_generator = PascalVocGenerator(
            dataset_cfg['dataset_path'],
            'test',
            classes=dataset_cfg['classes'],
            **size_kwargs, **anchor_kwargs)
    elif dataset_type == 'csv':
        train_generator = CSVGenerator(
            dataset_cfg['csv_data_file'],
            dataset_cfg['csv_classes_file'],
            transform_generator=transform_generator,
            **size_kwargs)

        # Validation is optional for CSV datasets.
        if dataset_cfg['csv_val_annotations']:
            validation_generator = CSVGenerator(
                dataset_cfg['csv_val_annotations'],
                dataset_cfg['csv_classes_file'],
                **size_kwargs)
        else:
            validation_generator = None
    elif dataset_type == 'oid':
        # Open Images options shared by both subsets.
        oid_kwargs = {
            'version': dataset_cfg['version'],
            'labels_filter': dataset_cfg['oid_labels_filter'],
            'annotation_cache_dir': dataset_cfg['oid_annotation_cache_dir'],
            'fixed_labels': dataset_cfg['fixed_labels'],
        }
        train_generator = OpenImagesGenerator(
            dataset_cfg['dataset_path'],
            subset='train',
            transform_generator=transform_generator,
            **oid_kwargs, **size_kwargs)

        validation_generator = OpenImagesGenerator(
            dataset_cfg['dataset_path'],
            subset='validation',
            **oid_kwargs, **size_kwargs)
    elif dataset_type == 'kitti':
        train_generator = KittiGenerator(
            dataset_cfg['dataset_path'],
            subset='train',
            transform_generator=transform_generator,
            **size_kwargs)

        validation_generator = KittiGenerator(
            dataset_cfg['dataset_path'],
            subset='val',
            **size_kwargs)
    else:
        raise ValueError('Invalid data type received: {}'.format(dataset_type))

    return train_generator, validation_generator
コード例 #9
0
ファイル: main.py プロジェクト: deepaksharma84/Retina_net
def main():
    """Train a RetinaNet (ResNet-50 backbone) on CSV data end to end.

    Inlines what create_generators / retinanet_bbox / create_callbacks would
    normally do (see the commented-out calls below), then compiles the model,
    trains it, and writes the Keras training history to ./snapshots/.
    """
    backbone = models.backbone('resnet50')
    # create the generators
    #train_generator, validation_generator = create_generators(args, backbone.preprocess_image)
    # Hard-coded configuration standing in for parsed command-line args.
    random_transform = True
    val_annotations = './data/processed/val.csv'
    annotations = './data/processed/train.csv'
    classes = './data/processed/classes.csv'
    common_args = {
        'batch_size': 8,
        'image_min_side': 224,
        'image_max_side': 1333,
        'preprocess_image': backbone.preprocess_image,
    }
    # create random transform generator for augmenting training data
    if random_transform:
        transform_generator = random_transform_generator(
            min_rotation=-0.05,
            max_rotation=0.05,
            min_translation=(-0.1, -0.1),
            max_translation=(0.1, 0.1),
            #min_shear=-0.1,
            #max_shear=0.1,
            min_scaling=(0.8, 0.8),
            max_scaling=(1.2, 1.2),
            flip_x_chance=0.5,
            #flip_y_chance=0.5,
        )
    else:
        transform_generator = random_transform_generator(flip_x_chance=0.5)
    train_generator = CSVGenerator(annotations,
                                   classes,
                                   transform_generator=transform_generator,
                                   **common_args)

    # Validation generator gets no augmentation.
    if val_annotations:
        validation_generator = CSVGenerator(val_annotations, classes,
                                            **common_args)
    else:
        validation_generator = None

    #train_generator, validation_generator = create_generators(args, backbone.preprocess_image)
    num_classes = 1  # change
    model = backbone.retinanet(num_classes, backbone='resnet50')
    training_model = model

    # prediction_model = retinanet_bbox(model=model)
    nms = True
    class_specific_filter = True
    name = 'retinanet-bbox'
    anchor_params = AnchorParameters.default
    # compute the anchors
    # One feature map per pyramid level P3..P7.
    features = [
        model.get_layer(p_name).output
        for p_name in ['P3', 'P4', 'P5', 'P6', 'P7']
    ]

    # One anchor layer per pyramid level, concatenated along the anchor axis.
    anchor = [
        layers.Anchors(size=anchor_params.sizes[i],
                       stride=anchor_params.strides[i],
                       ratios=anchor_params.ratios,
                       scales=anchor_params.scales,
                       name='anchors_{}'.format(i))(f)
        for i, f in enumerate(features)
    ]
    anchors = keras.layers.Concatenate(axis=1, name='anchors')(anchor)
    # we expect the anchors, regression and classification values as first output
    regression = model.outputs[0]  # check
    classification = model.outputs[1]

    # "other" can be any additional output from custom submodels, by default this will be []
    other = model.outputs[2:]

    # apply predicted regression to anchors
    boxes = layers.RegressBoxes(name='boxes')([anchors, regression])
    boxes = layers.ClipBoxes(name='clipped_boxes')([model.inputs[0], boxes])

    # filter detections (apply NMS / score threshold / select top-k)
    detections = layers.FilterDetections(
        nms=nms,
        class_specific_filter=class_specific_filter,
        name='filtered_detections')([boxes, classification] + other)

    outputs = detections

    # construct the model
    prediction_model = keras.models.Model(inputs=model.inputs,
                                          outputs=outputs,
                                          name=name)

    # end of prediction_model = retinanet_bbox(model=model)

    # compile model
    training_model.compile(loss={
        'regression': losses.smooth_l1(),
        'classification': losses.focal()
    },
                           optimizer=keras.optimizers.SGD(lr=1e-2,
                                                          momentum=0.9,
                                                          decay=.0001,
                                                          nesterov=True,
                                                          clipnorm=1)
                           # , clipnorm=0.001)
                           )
    print(model.summary())
    # start of create_callbacks
    #callbacks = create_callbacks(model,training_model,prediction_model,validation_generator,args,)
    callbacks = []
    tensorboard_callback = None
    tensorboard_callback = keras.callbacks.TensorBoard(
        log_dir='',
        histogram_freq=0,
        batch_size=8,
        write_graph=True,
        write_grads=False,
        write_images=False,
        embeddings_freq=0,
        embeddings_layer_names=None,
        embeddings_metadata=None)
    callbacks.append(tensorboard_callback)
    # mAP evaluation on the validation set; RedirectModel points the callback
    # at prediction_model instead of training_model.
    evaluation = Evaluate(validation_generator,
                          tensorboard=tensorboard_callback,
                          weighted_average=False)
    evaluation = RedirectModel(evaluation, prediction_model)
    callbacks.append(evaluation)
    makedirs('./snapshots/')
    # Save a snapshot every epoch (save_best_only=False), tracking mAP.
    checkpoint = keras.callbacks.ModelCheckpoint(os.path.join(
        './snapshots/', '{backbone}_{dataset_type}_{{epoch:02d}}.h5'.format(
            backbone='resnet50', dataset_type='csv')),
                                                 verbose=1,
                                                 save_best_only=False,
                                                 monitor="mAP",
                                                 mode='max')

    checkpoint = RedirectModel(checkpoint, model)
    callbacks.append(checkpoint)
    callbacks.append(
        keras.callbacks.ReduceLROnPlateau(monitor='loss',
                                          factor=0.9,
                                          patience=4,
                                          verbose=1,
                                          mode='auto',
                                          min_delta=0.0001,
                                          cooldown=0,
                                          min_lr=0))
    steps = 2500
    epochs = 25

    # start training
    # NOTE(review): Model.fit has no `generator` kwarg in current Keras;
    # presumably fit_generator (or fit(train_generator, ...)) was intended --
    # confirm against the keras version pinned by this project.
    history = training_model.fit(
        generator=train_generator,
        steps_per_epoch=steps,
        epochs=epochs,
        verbose=1,
        callbacks=callbacks,
    )

    timestr = time.strftime("%Y-%m-%d-%H%M")

    # Persist the per-epoch metrics for later analysis.
    history_path = os.path.join(
        './snapshots/', '{timestr}_{backbone}.csv'.format(timestr=timestr,
                                                          backbone='resnet50',
                                                          dataset_type='csv'))
    pd.DataFrame(history.history).to_csv(history_path)
コード例 #10
0
def create_generators(args, preprocess_image):
    """!@brief
    Build the (train, validation) generator pair for the configured dataset.

    @param args             : Parseargs object containing configuration for generators.
    @param preprocess_image : Function that preprocesses an image for the network.

    @return
        The generators created.
    """
    # Keyword arguments shared by every generator built below.
    shared = {
        'batch_size': args.batch_size,
        'config': args.config,
        'image_min_side': args.image_min_side,
        'image_max_side': args.image_max_side,
        'preprocess_image': preprocess_image,
        'photometric': args.random_photometric,
        'motion': args.random_motion,
        'deformable': args.random_deformable,
        'alpha': (1, 200),
        'sigma': (4, 7),
    }

    # Random geometric augmentation (a generator yielding transform
    # matrices) -- only built when explicitly requested.
    transform_generator = None
    if args.random_transform:
        transform_generator = random_transform_generator(
            min_scaling=(0.9, 0.9),
            max_scaling=(1.1, 1.1),
            flip_x_chance=0.5,
            flip_y_chance=0.5,
        )

    dataset_type = args.dataset_type
    if dataset_type == 'coco':
        # Imported here to prevent an unnecessary dependency on cocoapi.
        from ..preprocessing.coco import CocoGenerator

        train_generator = CocoGenerator(
            args.coco_path,
            'train2017',
            transform_generator=transform_generator,
            **shared)
        validation_generator = CocoGenerator(
            args.coco_path, 'val2017', **shared)
    elif dataset_type == 'pascal':
        train_generator = PascalVocGenerator(
            args.pascal_path,
            'trainval',
            transform_generator=transform_generator,
            **shared)
        validation_generator = PascalVocGenerator(
            args.pascal_path, 'test', **shared)
    elif dataset_type == 'csv':
        train_generator = CSVGenerator(
            args.annotations,
            args.classes,
            transform_generator=transform_generator,
            **shared)
        # Validation is optional for CSV datasets.
        validation_generator = None
        if args.val_annotations:
            validation_generator = CSVGenerator(
                args.val_annotations, args.classes, **shared)
    elif dataset_type == 'oid':
        # Open Images options shared by both subsets.
        oid_shared = {
            'version': args.version,
            'labels_filter': args.labels_filter,
            'annotation_cache_dir': args.annotation_cache_dir,
            'parent_label': args.parent_label,
        }
        train_generator = OpenImagesGenerator(
            args.main_dir,
            subset='train',
            transform_generator=transform_generator,
            **oid_shared, **shared)
        validation_generator = OpenImagesGenerator(
            args.main_dir,
            subset='validation',
            **oid_shared, **shared)
    elif dataset_type == 'kitti':
        train_generator = KittiGenerator(
            args.kitti_path,
            subset='train',
            transform_generator=transform_generator,
            **shared)
        validation_generator = KittiGenerator(
            args.kitti_path, subset='val', **shared)
    elif dataset_type == 'ivm':
        train_generator = IVMGenerator(
            args.ivm_path,
            'train',
            transform_generator=transform_generator,
            **shared)
        validation_generator = IVMGenerator(
            args.ivm_path, 'val', **shared)
    else:
        raise ValueError('Invalid data type received: {}'.format(dataset_type))

    return train_generator, validation_generator
コード例 #11
0
def create_generator(args):
    """!@brief
    Build one data generator for the dataset type selected in args.

    @param args: parseargs arguments object.
    """
    # Keyword arguments shared by every generator kind below.
    shared = {
        'config': args.config,
        'image_min_side': args.image_min_side,
        'image_max_side': args.image_max_side,
        'photometric': args.random_photometric,
        'motion': args.random_motion,
        'deformable': args.random_deformable,
        'alpha': (1, 200),
        'sigma': (4, 7),
    }

    # Random geometric augmentation (a generator yielding transform
    # matrices) -- only built when explicitly requested.
    transform_generator = None
    if args.random_transform:
        transform_generator = random_transform_generator(
            min_rotation=-0.1,
            max_rotation=0.1,
            min_translation=(-0.1, -0.1),
            max_translation=(0.1, 0.1),
            min_shear=-0.1,
            max_shear=0.1,
            min_scaling=(0.9, 0.9),
            max_scaling=(1.1, 1.1),
            flip_x_chance=0.5,
            flip_y_chance=0.5,
        )

    dataset_type = args.dataset_type
    if dataset_type == 'coco':
        # Imported here to prevent an unnecessary dependency on cocoapi.
        from ..preprocessing.coco import CocoGenerator

        return CocoGenerator(
            args.coco_path,
            args.coco_set,
            transform_generator=transform_generator,
            **shared)
    if dataset_type == 'pascal':
        return PascalVocGenerator(
            args.pascal_path,
            args.pascal_set,
            transform_generator=transform_generator,
            **shared)
    if dataset_type == 'csv':
        return CSVGenerator(
            args.annotations,
            args.classes,
            transform_generator=transform_generator,
            **shared)
    if dataset_type == 'oid':
        return OpenImagesGenerator(
            args.main_dir,
            subset=args.subset,
            version=args.version,
            labels_filter=args.labels_filter,
            parent_label=args.parent_label,
            annotation_cache_dir=args.annotation_cache_dir,
            transform_generator=transform_generator,
            **shared)
    if dataset_type == 'kitti':
        return KittiGenerator(
            args.kitti_path,
            subset=args.subset,
            transform_generator=transform_generator,
            **shared)
    if dataset_type == 'ivm':
        return IVMGenerator(
            args.ivm_path,
            args.ivm_set,
            transform_generator=transform_generator,
            **shared)
    raise ValueError('Invalid data type received: {}'.format(dataset_type))