Example #1
def create_generator(args):
    if args.dataset_type == 'coco':
        # import here to prevent unnecessary dependency on cocoapi
        from preprocessing.coco import CocoGenerator

        validation_generator = CocoGenerator(args.coco_path, 'val2017')
    elif args.dataset_type == 'pascal':
        validation_generator = PascalVocGenerator(args.pascal_path,
                                                  'test',
                                                  image_min_side=800,
                                                  image_max_side=1430,
                                                  classes={
                                                      "hathead": 0,
                                                      "nohathead": 1
                                                  })
    elif args.dataset_type == 'csv':
        validation_generator = CSVGenerator(
            args.annotations,
            args.classes,
        )
    else:
        raise ValueError('Invalid data type received: {}'.format(
            args.dataset_type))

    return validation_generator
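
The args object expected above is not part of this snippet. A minimal, hypothetical invocation of the 'pascal' branch (the attribute names mirror the ones the function reads; the path is a placeholder):

from argparse import Namespace

# Hypothetical arguments; only the attributes create_generator reads for
# the 'pascal' branch are filled in.
args = Namespace(
    dataset_type='pascal',
    pascal_path='/data/VOCdevkit/VOC2007',  # placeholder path
)
validation_generator = create_generator(args)
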
Example #2
def create_generator(args):
    """ Create generators for evaluation.
    """
    if args.dataset_type == 'coco':
        # import here to prevent unnecessary dependency on cocoapi
        # from ..preprocessing.coco import CocoGenerator
        from preprocessing.coco import CocoGenerator

        validation_generator = CocoGenerator(
            args.coco_path,
            'val2017',
            image_min_side=args.image_min_side,
            image_max_side=args.image_max_side,
            config=args.config)
    elif args.dataset_type == 'pascal':
        validation_generator = PascalVocGenerator(
            args.pascal_path,
            'train',
            image_min_side=args.image_min_side,
            image_max_side=args.image_max_side,
            config=args.config)
    elif args.dataset_type == 'csv':
        validation_generator = CSVGenerator(args.annotations,
                                            args.classes,
                                            image_min_side=args.image_min_side,
                                            image_max_side=args.image_max_side,
                                            config=args.config)
    else:
        raise ValueError('Invalid data type received: {}'.format(
            args.dataset_type))

    return validation_generator
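
One possible way to supply these attributes from the command line is sketched below. The snippet's real parse_args is not shown, so the flag names and defaults are assumptions; argparse maps --image-min-side to args.image_min_side automatically.

import argparse

parser = argparse.ArgumentParser(description='Evaluate a detection model.')
parser.add_argument('dataset_type', choices=['coco', 'pascal', 'csv'])
parser.add_argument('--coco-path', help='Root directory of the COCO dataset.')
parser.add_argument('--pascal-path', help='Root directory of the Pascal VOC dataset.')
parser.add_argument('--annotations', help='CSV file with validation annotations.')
parser.add_argument('--classes', help='CSV file mapping class names to ids.')
parser.add_argument('--image-min-side', type=int, default=800)
parser.add_argument('--image-max-side', type=int, default=1333)
parser.add_argument('--config', default=None)

args = parser.parse_args(['csv',
                          '--annotations', 'val_annotations.csv',
                          '--classes', 'classes.csv'])
validation_generator = create_generator(args)
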
Example #3
def create_generator(args):
    # create random transform generator for augmenting training data
    transform_generator = random_transform_generator(
        # min_rotation=-0.1,
        # max_rotation=0.1,
        # min_translation=(-0.1, -0.1),
        # max_translation=(0.1, 0.1),
        # min_shear=-0.1,
        # max_shear=0.1,
        # min_scaling=(0.9, 0.9),
        # max_scaling=(1.1, 1.1),
        flip_x_chance=0.5,
        # flip_y_chance=0.5,
    )

    if args.dataset_type == 'coco':
        # import here to prevent unnecessary dependency on cocoapi
        from preprocessing.coco import CocoGenerator

        generator = CocoGenerator(args.coco_path,
                                  args.coco_set,
                                  transform_generator=transform_generator)
    elif args.dataset_type == 'pascal':
        generator = PascalVocGenerator(
            args.pascal_path,
            args.pascal_set,
            # classes={"bluehat": 0, 'whitehat': 1, 'yellohat': 2, 'redhat': 3, 'blackhat': 4, 'no_hat': 5},
            classes={
                'hathead': 0,
                'nohathead': 1
            },
            transform_generator=transform_generator,
            image_min_side=1080,
            image_max_side=1920,
            anchor_ratios=[0.5, 1, 1.5],
            anchor_scales=[0, 0.33333333, 0.66666667],
            anchor_sizes=[16, 32, 64, 128, 256],
            anchor_strides=[8, 16, 32, 64, 128])
    elif args.dataset_type == 'csv':
        generator = CSVGenerator(args.annotations,
                                 args.classes,
                                 transform_generator=transform_generator)
    elif args.dataset_type == 'oid':
        generator = OpenImagesGenerator(
            args.main_dir,
            subset=args.subset,
            version=args.version,
            labels_filter=args.labels_filter,
            fixed_labels=args.fixed_labels,
            annotation_cache_dir=args.annotation_cache_dir,
            transform_generator=transform_generator)
    elif args.dataset_type == 'kitti':
        generator = KittiGenerator(args.kitti_path,
                                   subset=args.subset,
                                   transform_generator=transform_generator)
    else:
        raise ValueError('Invalid data type received: {}'.format(
            args.dataset_type))

    return generator
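
The generator built here is typically passed straight to Keras training. A minimal sketch, assuming a compiled training_model and parsed args already exist (the step and epoch counts are placeholders):

# Hypothetical training call; `training_model` is assumed to be a compiled
# Keras model built elsewhere in the project.
train_generator = create_generator(args)
training_model.fit_generator(
    generator=train_generator,
    steps_per_epoch=10000,  # placeholder
    epochs=50,              # placeholder
    verbose=1,
)
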
Example #4
def create_generator(args):
    # create random transform generator for augmenting training data
    transform_generator = random_transform_generator(
        min_rotation=-0.1,
        max_rotation=0.1,
        min_translation=(-0.1, -0.1),
        max_translation=(0.1, 0.1),
        min_shear=-0.1,
        max_shear=0.1,
        min_scaling=(0.9, 0.9),
        max_scaling=(1.1, 1.1),
        flip_x_chance=0.5,
        flip_y_chance=0.5,
    )

    if args.dataset_type == 'coco':
        # import here to prevent unnecessary dependency on cocoapi
        from preprocessing.coco import CocoGenerator

        generator = CocoGenerator(
            args.coco_path,
            args.coco_set,
            transform_generator=transform_generator
        )
    elif args.dataset_type == 'pascal':
        generator = PascalVocGenerator(
            args.pascal_path,
            args.pascal_set,
            classes={"rebar": 0},
            transform_generator=transform_generator
        )
    elif args.dataset_type == 'csv':
        generator = CSVGenerator(
            args.annotations,
            args.classes,
            transform_generator=transform_generator
        )
    elif args.dataset_type == 'oid':
        generator = OpenImagesGenerator(
            args.main_dir,
            subset=args.subset,
            version=args.version,
            labels_filter=args.labels_filter,
            fixed_labels=args.fixed_labels,
            annotation_cache_dir=args.annotation_cache_dir,
            transform_generator=transform_generator
        )
    elif args.dataset_type == 'kitti':
        generator = KittiGenerator(
            args.kitti_path,
            subset=args.subset,
            transform_generator=transform_generator
        )
    else:
        raise ValueError('Invalid data type received: {}'.format(args.dataset_type))

    return generator
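
The annotations and classes files consumed by the 'csv' branch are plain CSVs. Their exact layout is not shown in this snippet; the sketch below assumes the upstream keras-retinanet convention of one bounding box per line plus a separate class-name-to-id mapping.

# Hypothetical file contents (assumed format):
#   annotations: path,x1,y1,x2,y2,class_name
#   classes:     class_name,id
with open('annotations.csv', 'w') as f:
    f.write('images/img_001.jpg,10,20,110,220,rebar\n')
    f.write('images/img_002.jpg,35,40,90,150,rebar\n')

with open('classes.csv', 'w') as f:
    f.write('rebar,0\n')
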
Example #5
def main(args=None):
    # parse arguments
    if args is None:
        args = sys.argv[1:]
    args = parse_args(args)

    # make sure keras is the minimum required version
    check_keras_version()

    # optionally choose specific GPU
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    keras.backend.tensorflow_backend.set_session(get_session())

    # create the model
    print('Loading model, this may take a second...')
    model = keras.models.load_model(args.model, custom_objects=custom_objects)

    # create a generator for testing data
    test_generator = CocoGenerator(args.coco_path, args.set)

    evaluate_coco(test_generator, model, args.score_threshold)
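
get_session is not defined in this snippet. In TF1-era keras-retinanet forks it usually just opens a TensorFlow session with GPU memory growth enabled; a rough sketch of that assumption:

import tensorflow as tf

def get_session():
    # Allocate GPU memory on demand rather than reserving it all upfront.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    return tf.Session(config=config)
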
Example #6
def create_generators(configs):
    """create a input data generator"""
    # create random transform generator for augmenting training data
    if not configs["Data_Augmentation"]['only_x_flip']:
        transform_generator = random_transform_generator(
            min_rotation=configs['Data_Augmentation']['rotation'][0],
            max_rotation=configs['Data_Augmentation']['rotation'][1],
            min_translation=configs['Data_Augmentation']['min_translation'],
            max_translation=configs['Data_Augmentation']['max_translation'],
            min_shear=configs['Data_Augmentation']['shear'][0],
            max_shear=configs['Data_Augmentation']['shear'][1],
            min_scaling=configs['Data_Augmentation']['min_scaling'],
            max_scaling=configs['Data_Augmentation']['max_scaling'],
            flip_x_chance=0.5,
            gray=configs['Data_Augmentation']['gray'],
            inverse_color=configs['Data_Augmentation']['inverse_color'],
        )
    else:
        transform_generator = random_transform_generator(flip_x_chance=0.5)

    if configs['Dataset']['dataset_type'] == 'coco':
        # import here to prevent unnecessary dependency on cocoapi
        from preprocessing.coco import CocoGenerator

        train_generator = CocoGenerator(
            configs['Dataset']['dataset_path'],
            'train2017',
            transform_generator=transform_generator,
            batch_size=configs['Train']['batch_size'],
            image_min_side=configs['Train']['image_min_side'],
            image_max_side=configs['Train']['image_max_side'])

        validation_generator = CocoGenerator(
            configs['Dataset']['dataset_path'],
            'val2017',
            batch_size=configs['Train']['batch_size'],
            image_min_side=configs['Train']['image_min_side'],
            image_max_side=configs['Train']['image_max_side'])
    elif configs['Dataset']['dataset_type'] == 'pascal':
        train_generator = PascalVocGenerator(
            configs['Dataset']['dataset_path'],
            'trainval',
            classes=configs['Dataset']['classes'],
            transform_generator=transform_generator,
            batch_size=configs['Train']['batch_size'],
            image_min_side=configs['Train']['image_min_side'],
            image_max_side=configs['Train']['image_max_side'],
            anchor_ratios=configs['Anchors']['ratios'],
            anchor_scales=configs['Anchors']['scales'],
            anchor_sizes=configs['Anchors']['sizes'],
            anchor_strides=configs['Anchors']['strides'])

        validation_generator = PascalVocGenerator(
            configs['Dataset']['dataset_path'],
            'test',
            classes=configs['Dataset']['classes'],
            batch_size=configs['Train']['batch_size'],
            image_min_side=configs['Train']['image_min_side'],
            image_max_side=configs['Train']['image_max_side'],
            anchor_ratios=configs['Anchors']['ratios'],
            anchor_scales=configs['Anchors']['scales'],
            anchor_sizes=configs['Anchors']['sizes'],
            anchor_strides=configs['Anchors']['strides'])
    elif configs['Dataset']['dataset_type'] == 'csv':
        train_generator = CSVGenerator(
            configs['Dataset']['csv_data_file'],
            configs['Dataset']['csv_classes_file'],
            transform_generator=transform_generator,
            batch_size=configs['Train']['batch_size'],
            image_min_side=configs['Train']['image_min_side'],
            image_max_side=configs['Train']['image_max_side'])

        if configs['Dataset']['csv_val_annotations']:
            validation_generator = CSVGenerator(
                configs['Dataset']['csv_val_annotations'],
                configs['Dataset']['csv_classes_file'],
                batch_size=configs['Train']['batch_size'],
                image_min_side=configs['Train']['image_min_side'],
                image_max_side=configs['Train']['image_max_side'])
        else:
            validation_generator = None
    elif configs['Dataset']['dataset_type'] == 'oid':
        train_generator = OpenImagesGenerator(
            configs['Dataset']['dataset_path'],
            subset='train',
            version=configs['Dataset']['version'],
            labels_filter=configs['Dataset']['oid_labels_filter'],
            annotation_cache_dir=configs['Dataset']['oid_annotation_cache_dir'],
            fixed_labels=configs['Dataset']['fixed_labels'],
            transform_generator=transform_generator,
            batch_size=configs['Train']['batch_size'],
            image_min_side=configs['Train']['image_min_side'],
            image_max_side=configs['Train']['image_max_side'])

        validation_generator = OpenImagesGenerator(
            configs['Dataset']['dataset_path'],
            subset='validation',
            version=configs['Dataset']['version'],
            labels_filter=configs['Dataset']['oid_labels_filter'],
            annotation_cache_dir=configs['Dataset']['oid_annotation_cache_dir'],
            fixed_labels=configs['Dataset']['fixed_labels'],
            batch_size=configs['Train']['batch_size'],
            image_min_side=configs['Train']['image_min_side'],
            image_max_side=configs['Train']['image_max_side'])
    elif configs['Dataset']['dataset_type'] == 'kitti':
        train_generator = KittiGenerator(
            configs['Dataset']['dataset_path'],
            subset='train',
            transform_generator=transform_generator,
            batch_size=configs['Train']['batch_size'],
            image_min_side=configs['Train']['image_min_side'],
            image_max_side=configs['Train']['image_max_side'])

        validation_generator = KittiGenerator(
            configs['Dataset']['dataset_path'],
            subset='val',
            batch_size=configs['Train']['batch_size'],
            image_min_side=configs['Train']['image_min_side'],
            image_max_side=configs['Train']['image_max_side'])
    else:
        raise ValueError('Invalid data type received: {}'.format(
            configs['Dataset']['dataset_type']))

    return train_generator, validation_generator
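
The configs mapping is indexed by the Dataset, Train, Anchors and Data_Augmentation sections. A minimal, hypothetical configuration covering the keys read above (all values are placeholders; setting only_x_flip avoids the full augmentation block):

# Hypothetical configuration; only keys read by create_generators are shown.
configs = {
    'Dataset': {
        'dataset_type': 'pascal',
        'dataset_path': '/data/VOCdevkit/VOC2007',  # placeholder path
        'classes': {'hathead': 0, 'nohathead': 1},
    },
    'Train': {
        'batch_size': 1,
        'image_min_side': 800,
        'image_max_side': 1333,
    },
    'Anchors': {
        'ratios': [0.5, 1, 2],
        'scales': [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)],
        'sizes': [32, 64, 128, 256, 512],
        'strides': [8, 16, 32, 64, 128],
    },
    'Data_Augmentation': {
        'only_x_flip': True,
    },
}

train_generator, validation_generator = create_generators(configs)
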