示例#1
0
def create_generators(args, preprocess_image):
    """ Build the training and validation generators for the selected dataset.

    ``args.dataset_type`` selects the dataset: 'coco', 'pascal', 'csv', 'oid'
    or 'kitti'. Both generators share the same batch size, image side limits
    and preprocessing function. Only the training generator also receives
    the random transform generator.

    Args
        args             : parseargs object containing configuration for generators.
        preprocess_image : Function that preprocesses an image for the network.

    Returns
        A ``(train_generator, validation_generator)`` tuple. For 'csv',
        ``validation_generator`` is None when ``args.val_annotations`` is falsy.

    Raises
        ValueError : if ``args.dataset_type`` is not a supported dataset type.
    """
    shared = dict(
        batch_size=args.batch_size,
        image_min_side=args.image_min_side,
        image_max_side=args.image_max_side,
        preprocess_image=preprocess_image,
    )

    # Training augmentation: when requested, a mild random affine transform
    # (rotation, translation, scaling) plus horizontal flips; otherwise
    # horizontal flips only.
    if args.random_transform:
        augmentation = dict(
            min_rotation=-0.05,
            max_rotation=0.05,
            min_translation=(-0.1, -0.1),
            max_translation=(0.1, 0.1),
            min_scaling=(0.8, 0.8),
            max_scaling=(1.2, 1.2),
            flip_x_chance=0.5,
        )
    else:
        augmentation = dict(flip_x_chance=0.5)
    transform_generator = random_transform_generator(**augmentation)

    # Training generators get the shared settings plus the augmentation.
    train_kwargs = dict(shared, transform_generator=transform_generator)
    dataset = args.dataset_type

    if dataset == 'coco':
        # import here to prevent unnecessary dependency on cocoapi
        from ..preprocessing.coco import CocoGenerator

        return (
            CocoGenerator(args.coco_path, 'train2017', **train_kwargs),
            CocoGenerator(args.coco_path, 'val2017', **shared),
        )

    if dataset == 'pascal':
        return (
            PascalVocGenerator(args.pascal_path, 'trainval', **train_kwargs),
            PascalVocGenerator(args.pascal_path, 'test', **shared),
        )

    if dataset == 'csv':
        training = CSVGenerator(args.annotations, args.classes, **train_kwargs)
        # A validation split is optional for CSV datasets.
        validation = None
        if args.val_annotations:
            validation = CSVGenerator(args.val_annotations, args.classes, **shared)
        return training, validation

    if dataset == 'oid':
        # Open Images options common to both splits.
        oid_options = dict(
            version=args.version,
            labels_filter=args.labels_filter,
            annotation_cache_dir=args.annotation_cache_dir,
            parent_label=args.parent_label,
        )
        return (
            OpenImagesGenerator(args.main_dir, subset='train',
                                **dict(oid_options, **train_kwargs)),
            OpenImagesGenerator(args.main_dir, subset='validation',
                                **dict(oid_options, **shared)),
        )

    if dataset == 'kitti':
        return (
            KittiGenerator(args.kitti_path, subset='train', **train_kwargs),
            KittiGenerator(args.kitti_path, subset='val', **shared),
        )

    raise ValueError('Invalid data type received: {}'.format(args.dataset_type))
示例#2
0
def create_generators(args):
    """ Build the training and validation generators for the selected dataset.

    ``args.dataset_type`` selects the dataset: 'coco', 'pascal', 'csv', 'oid'
    or 'kitti'. Every generator is given ``args.batch_size``. Only the
    training generator also receives the random transform generator.

    Args
        args : parseargs object containing configuration for generators.

    Returns
        A ``(train_generator, validation_generator)`` tuple. For 'csv',
        ``validation_generator`` is None when ``args.val_annotations`` is falsy.

    Raises
        ValueError : if ``args.dataset_type`` is not a supported dataset type.
    """
    # Training augmentation: when requested, random translation and scaling
    # plus horizontal flips; otherwise horizontal flips only.
    if args.random_transform:
        augmentation = dict(
            min_translation=(-0.3, -0.3),
            max_translation=(0.3, 0.3),
            min_scaling=(0.2, 0.2),
            max_scaling=(2, 2),
            flip_x_chance=0.5,
        )
    else:
        augmentation = dict(flip_x_chance=0.5)
    transform_generator = random_transform_generator(**augmentation)

    def with_batch(**extra):
        # Keyword arguments given to every generator. args.batch_size is read
        # only when a generator is actually being built.
        extra['batch_size'] = args.batch_size
        return extra

    dataset = args.dataset_type

    if dataset == 'coco':
        # import here to prevent unnecessary dependency on cocoapi
        from ..preprocessing.coco import CocoGenerator

        return (
            CocoGenerator(args.coco_path, 'train2017',
                          **with_batch(transform_generator=transform_generator)),
            CocoGenerator(args.coco_path, 'val2017', **with_batch()),
        )

    if dataset == 'pascal':
        return (
            PascalVocGenerator(args.pascal_path, 'trainval',
                               **with_batch(transform_generator=transform_generator)),
            PascalVocGenerator(args.pascal_path, 'test', **with_batch()),
        )

    if dataset == 'csv':
        training = CSVGenerator(args.annotations, args.classes,
                                **with_batch(transform_generator=transform_generator))
        # A validation split is optional for CSV datasets.
        validation = None
        if args.val_annotations:
            validation = CSVGenerator(args.val_annotations, args.classes,
                                      **with_batch())
        return training, validation

    if dataset == 'oid':
        def open_images(subset, **extra):
            # The two Open Images splits differ only in subset and augmentation.
            return OpenImagesGenerator(
                args.main_dir,
                subset=subset,
                version=args.version,
                labels_filter=args.labels_filter,
                annotation_cache_dir=args.annotation_cache_dir,
                fixed_labels=args.fixed_labels,
                **with_batch(**extra))

        return (
            open_images('train', transform_generator=transform_generator),
            open_images('validation'),
        )

    if dataset == 'kitti':
        return (
            KittiGenerator(args.kitti_path, subset='train',
                           **with_batch(transform_generator=transform_generator)),
            KittiGenerator(args.kitti_path, subset='val', **with_batch()),
        )

    raise ValueError('Invalid data type received: {}'.format(
        args.dataset_type))
示例#3
0
def create_generator(args, config):
    """ Create the data generator selected by ``args.dataset_type``.

    For 'coco', 'pascal', 'csv', 'oid' and 'kitti', the generator is built with:
      - a random transform generator (rotation, translation, shear, scaling,
        and horizontal/vertical flips);
      - the ``args.image_min_side`` / ``args.image_max_side`` limits.

    For 'onthefly', an ``OnTheFlyGenerator`` is built from ``args.annotations``,
    ``args.batch_size`` and ``config``.

    Args:
        args: parseargs arguments object.
        config: configuration forwarded to the 'onthefly' generator; the other
            dataset types do not use it.

    Returns:
        The constructed generator.

    Raises:
        ValueError: if ``args.dataset_type`` is not a supported dataset type.
    """
    # create random transform generator for augmenting training data
    transform_generator = random_transform_generator(
        min_rotation=-0.1,
        max_rotation=0.1,
        min_translation=(-0.1, -0.1),
        max_translation=(0.1, 0.1),
        min_shear=-0.1,
        max_shear=0.1,
        min_scaling=(0.9, 0.9),
        max_scaling=(1.1, 1.1),
        flip_x_chance=0.5,
        flip_y_chance=0.5,
    )

    if args.dataset_type == 'coco':
        # import here to prevent unnecessary dependency on cocoapi
        from ..preprocessing.coco import CocoGenerator

        generator = CocoGenerator(args.coco_path,
                                  args.coco_set,
                                  transform_generator=transform_generator,
                                  image_min_side=args.image_min_side,
                                  image_max_side=args.image_max_side)
    elif args.dataset_type == 'pascal':
        generator = PascalVocGenerator(args.pascal_path,
                                       args.pascal_set,
                                       transform_generator=transform_generator,
                                       image_min_side=args.image_min_side,
                                       image_max_side=args.image_max_side)
    elif args.dataset_type == 'csv':
        generator = CSVGenerator(args.annotations,
                                 args.classes,
                                 transform_generator=transform_generator,
                                 image_min_side=args.image_min_side,
                                 image_max_side=args.image_max_side)
    elif args.dataset_type == 'oid':
        generator = OpenImagesGenerator(
            args.main_dir,
            subset=args.subset,
            version=args.version,
            labels_filter=args.labels_filter,
            fixed_labels=args.fixed_labels,
            annotation_cache_dir=args.annotation_cache_dir,
            transform_generator=transform_generator,
            image_min_side=args.image_min_side,
            image_max_side=args.image_max_side)
    elif args.dataset_type == 'kitti':
        generator = KittiGenerator(args.kitti_path,
                                   subset=args.subset,
                                   transform_generator=transform_generator,
                                   image_min_side=args.image_min_side,
                                   image_max_side=args.image_max_side)
    elif args.dataset_type == 'onthefly':
        # Must bind ``generator``, the name returned below. Binding any other
        # name here makes the return raise UnboundLocalError.
        generator = onthefly.OnTheFlyGenerator(
            args.annotations, batch_size=args.batch_size, config=config)
    else:
        raise ValueError('Invalid data type received: {}'.format(
            args.dataset_type))

    return generator
示例#4
0
def create_generators(args, preprocess_image):
    """ Build the training and validation generators for the selected dataset.

    ``args.dataset_type`` selects the dataset: 'coco', 'pascal', 'csv', 'oid'
    or 'kitti'. Both generators share the batch size, config, image side
    limits and preprocessing function.

    Training generators additionally receive the random transform generator.
    Validation generators are built with ``shuffle_groups=False``.

    Args
        args             : parseargs object containing configuration for generators.
        preprocess_image : Function that preprocesses an image for the network.

    Returns
        A ``(train_generator, validation_generator)`` tuple. For 'csv',
        ``validation_generator`` is None when ``args.val_annotations`` is falsy.

    Raises
        ValueError : if ``args.dataset_type`` is not a supported dataset type.
    """
    shared = dict(
        batch_size=args.batch_size,
        config=args.config,
        image_min_side=args.image_min_side,
        image_max_side=args.image_max_side,
        preprocess_image=preprocess_image,
    )

    # Training augmentation: when requested, a random affine transform
    # (rotation, translation, shear, scaling) plus horizontal and vertical
    # flips; otherwise horizontal flips only.
    if args.random_transform:
        augmentation = dict(
            min_rotation=-0.1,
            max_rotation=0.1,
            min_translation=(-0.1, -0.1),
            max_translation=(0.1, 0.1),
            min_shear=-0.1,
            max_shear=0.1,
            min_scaling=(0.9, 0.9),
            max_scaling=(1.1, 1.1),
            flip_x_chance=0.5,
            flip_y_chance=0.5,
        )
    else:
        augmentation = dict(flip_x_chance=0.5)
    transform_generator = random_transform_generator(**augmentation)

    # Training splits are augmented. Validation splits turn off group
    # shuffling, presumably to keep evaluation order stable.
    train_kwargs = dict(shared, transform_generator=transform_generator)
    val_kwargs = dict(shared, shuffle_groups=False)
    dataset = args.dataset_type

    if dataset == "coco":
        # import here to prevent unnecessary dependency on cocoapi
        from keras_retinanet.preprocessing.coco import CocoGenerator

        return (
            CocoGenerator(args.coco_path, "train2017", **train_kwargs),
            CocoGenerator(args.coco_path, "val2017", **val_kwargs),
        )

    if dataset == "pascal":
        return (
            PascalVocGenerator(args.pascal_path, "trainval", **train_kwargs),
            PascalVocGenerator(args.pascal_path, "test", **val_kwargs),
        )

    if dataset == "csv":
        training = CSVGenerator(args.annotations, args.classes, **train_kwargs)
        # A validation split is optional for CSV datasets.
        validation = None
        if args.val_annotations:
            validation = CSVGenerator(args.val_annotations, args.classes,
                                      **val_kwargs)
        return training, validation

    if dataset == "oid":
        # Open Images options common to both splits.
        oid_options = dict(
            version=args.version,
            labels_filter=args.labels_filter,
            annotation_cache_dir=args.annotation_cache_dir,
            parent_label=args.parent_label,
        )
        return (
            OpenImagesGenerator(args.main_dir, subset="train",
                                **dict(oid_options, **train_kwargs)),
            OpenImagesGenerator(args.main_dir, subset="validation",
                                **dict(oid_options, **val_kwargs)),
        )

    if dataset == "kitti":
        return (
            KittiGenerator(args.kitti_path, subset="train", **train_kwargs),
            KittiGenerator(args.kitti_path, subset="val", **val_kwargs),
        )

    raise ValueError("Invalid data type received: {}".format(
        args.dataset_type))