Example #1
def create_generators(args):
    """
    Create generators for training and validation.

    Args
        args: parsed argparse namespace containing the generator configuration.
    """
    common_args = {
        'batch_size': args.batch_size,
        'phi': args.phi,
    }

    # create random transform generator for augmenting training data
    if args.random_transform:
        misc_effect = MiscEffect()
        visual_effect = VisualEffect()
    else:
        misc_effect = None
        visual_effect = None

    if args.dataset_type == 'pascal':
        from generators.pascal import PascalVocGenerator
        train_generator = PascalVocGenerator(args.pascal_path,
                                             'trainval',
                                             skip_difficult=True,
                                             misc_effect=misc_effect,
                                             visual_effect=visual_effect,
                                             **common_args)

        validation_generator = PascalVocGenerator(args.pascal_path,
                                                  'val',
                                                  skip_difficult=True,
                                                  shuffle_groups=False,
                                                  **common_args)
    elif args.dataset_type == 'csv':
        from generators.csv_ import CSVGenerator
        train_generator = CSVGenerator(args.annotations_path,
                                       args.classes_path,
                                       misc_effect=misc_effect,
                                       visual_effect=visual_effect,
                                       **common_args)

        if args.val_annotations_path:
            validation_generator = CSVGenerator(args.val_annotations_path,
                                                args.classes_path,
                                                shuffle_groups=False,
                                                **common_args)
        else:
            validation_generator = None
    else:
        raise ValueError('Invalid data type received: {}'.format(
            args.dataset_type))

    return train_generator, validation_generator
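This variant reads everything from a parsed-args namespace. Below is a minimal usage sketch for the CSV branch, assuming the repository's generators package, MiscEffect and VisualEffect are importable; the paths and values are placeholders, not part of the original snippet.

from argparse import Namespace

# Hypothetical configuration; attribute names match what create_generators() reads.
args = Namespace(
    dataset_type='csv',
    batch_size=4,
    phi=0,
    random_transform=True,
    annotations_path='data/annotations.csv',
    classes_path='data/classes.csv',
    val_annotations_path='data/val_annotations.csv',
)

train_generator, validation_generator = create_generators(args)
print(train_generator.size(), validation_generator.size())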
Example #2
def create_generators(batch_size, phi, is_text_detect, is_detect_quadrangle,
                      rand_transf_augm, train_ann_path, val_ann_path,
                      train_class_path, val_class_path):
    """
    Create generators for training and validation.

    Args
        batch_size: Size of the batches to generate.
        phi: EfficientDet scaling coefficient.
        is_text_detect: Whether the generators are configured for text detection.
        is_detect_quadrangle: Whether quadrangle (rotated box) annotations are used.
        rand_transf_augm: Whether to apply random transform augmentation to training data.
        train_ann_path, val_ann_path: Paths to the training and validation annotation CSV files.
        train_class_path, val_class_path: Paths to the class-mapping CSV files.
    """
    common_args = {
        'batch_size': batch_size,
        'phi': phi,
        'detect_text': is_text_detect,
        'detect_quadrangle': is_detect_quadrangle
    }

    # create random transform generator for augmenting training data
    if rand_transf_augm:
        misc_effect = MiscEffect()
        visual_effect = VisualEffect()
    else:
        misc_effect = None
        visual_effect = None

    from generators.csv_ import CSVGenerator
    train_generator = CSVGenerator(train_ann_path,
                                   train_class_path,
                                   misc_effect=misc_effect,
                                   visual_effect=visual_effect,
                                   **common_args)

    if val_ann_path:
        validation_generator = CSVGenerator(val_ann_path,
                                            val_class_path,
                                            shuffle_groups=False,
                                            **common_args)
    else:
        validation_generator = None

    return train_generator, validation_generator
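Unlike the previous example, this variant takes the configuration as explicit arguments, so a call spells out every value. A sketch with placeholder paths (hypothetical, not from the original snippet):

train_generator, validation_generator = create_generators(
    batch_size=2,
    phi=1,
    is_text_detect=False,
    is_detect_quadrangle=False,
    rand_transf_augm=True,
    train_ann_path='data/train_annotations.csv',
    val_ann_path='data/val_annotations.csv',
    train_class_path='data/classes.csv',
    val_class_path='data/classes.csv',
)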
Example #3
def create_generators(args):
    """
    Create generators for training and validation.

    Args
        args: parsed argparse namespace containing the generator configuration.
    """
    common_args = {
        'batch_size': args.batch_size,
        'phi': args.phi,
        'detect_text': args.detect_text,
        'detect_quadrangle': args.detect_quadrangle
    }

    # create random transform generator for augmenting training data
    if args.random_transform:
        misc_effect = MiscEffect()
        visual_effect = VisualEffect()
    else:
        misc_effect = None
        visual_effect = None

    if args.dataset_type == 'pascal':
        from generators.pascal import PascalVocGenerator
        train_generator = PascalVocGenerator(
            args.pascal_path,
            'trainval',
            skip_difficult=True,
            misc_effect=misc_effect,
            visual_effect=visual_effect,
            image_extension=".png",
            **common_args
        )

        validation_generator = PascalVocGenerator(
            args.pascal_path,
            'val',
            skip_difficult=True,
            shuffle_groups=False,
            image_extension=".png",
            **common_args
        )
    elif args.dataset_type == 'csv':
        from generators.csv_ import CSVGenerator
        train_generator = CSVGenerator(
            args.annotations_path,
            args.classes_path,
            misc_effect=misc_effect,
            visual_effect=visual_effect,
            **common_args
        )

        if args.val_annotations_path:
            validation_generator = CSVGenerator(
                args.val_annotations_path,
                args.classes_path,
                shuffle_groups=False,
                **common_args
            )
        else:
            validation_generator = None
    elif args.dataset_type == 'coco':
        # import here to prevent unnecessary dependency on cocoapi
        from generators.coco import CocoGenerator
        train_generator = CocoGenerator(
            args.coco_path,
            'train2017',
            misc_effect=misc_effect,
            visual_effect=visual_effect,
            group_method='random',
            **common_args
        )

        validation_generator = CocoGenerator(
            args.coco_path,
            'val2017',
            shuffle_groups=False,
            **common_args
        )
    else:
        raise ValueError('Invalid data type received: {}'.format(args.dataset_type))

    return train_generator, validation_generator
Example #4
    ckp_model_dir = "2020-08-04 (B1 MND WBiFPN LLR)"
    ckp_model_file = "csv_298_0.0389_0.6384.h5"

    quad_angle_arg = False
    txt_detect_arg = False

    phi = 1
    weighted_bifpn = True
    common_args = {
        "batch_size": 1,
        "phi": phi,
    }
    test_generator = CSVGenerator(
        path_test_csv,
        path_classes_csv,
        base_dir=paths_base_dir,
        detect_quadrangle=quad_angle_arg,
        detect_text=txt_detect_arg,
        **common_args,  # batch_size and phi from common_args (defined above but otherwise unused)
    )
    # resolve the checkpoint file and query the generator for model parameters
    model_path = ckp_base_path + ckp_model_dir + "/" + ckp_model_file
    input_shape = (test_generator.image_size, test_generator.image_size)
    print(input_shape)
    anchors = test_generator.anchors
    num_classes = test_generator.num_classes()
    # build the EfficientDet model, load the checkpoint weights and evaluate on the test set
    model, prediction_model = efficientdet(phi=phi,
                                           num_classes=num_classes,
                                           weighted_bifpn=weighted_bifpn)
    prediction_model.load_weights(model_path, by_name=True)
    average_precisions = evaluate(test_generator,
                                  prediction_model,
                                  visualize=False)
Example #5
        # image, annotations = rotate(image, annotations, prob=self.rotate_prob, border_value=self.border_value)
        image, annotations = flipx(image, annotations, prob=self.flip_prob)
        image, annotations = crop(image, annotations, prob=self.crop_prob)
        image, annotations = translate(image,
                                       annotations,
                                       prob=self.translate_prob,
                                       border_value=self.border_value)
        return image, annotations


if __name__ == '__main__':
    import cv2
    import numpy as np

    from generators.csv_ import CSVGenerator

    train_generator = CSVGenerator('datasets/ic15/train.csv',
                                   'datasets/ic15/classes.csv',
                                   detect_text=True,
                                   batch_size=1,
                                   phi=5,
                                   shuffle_groups=False)
    misc_effect = MiscEffect()
    for i in range(train_generator.size()):
        image = train_generator.load_image(i)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        annotations = train_generator.load_annotations(i)
        boxes = annotations['bboxes'].astype(np.int32)
        quadrangles = annotations['quadrangles'].astype(np.int32)
        for box in boxes:
            cv2.rectangle(image, (box[0], box[1]), (box[2], box[3]),
                          (0, 0, 255), 1)
        cv2.drawContours(image, quadrangles, -1, (0, 255, 255), 1)
        src_image = image.copy()
        # cv2.namedWindow('src_image', cv2.WINDOW_NORMAL)
Example #6
def create_generators(args):
    """
    Create generators for training and validation.

    Args
        args: parsed argparse namespace containing the generator configuration.
    """
    common_args = {
        "batch_size": args.batch_size,
        "phi": args.phi,
        "detect_text": args.detect_text,
        "detect_quadrangle": args.detect_quadrangle,
    }

    # create random transform generator for augmenting training data
    if args.random_transform:
        misc_effect = MiscEffect()
        visual_effect = VisualEffect()
    else:
        misc_effect = None
        visual_effect = None

    if args.dataset_type == "pascal":
        from generators.pascal import PascalVocGenerator

        train_generator = PascalVocGenerator(
            args.pascal_path,
            "trainval",
            skip_difficult=True,
            misc_effect=misc_effect,
            visual_effect=visual_effect,
            **common_args,
        )

        validation_generator = PascalVocGenerator(
            args.pascal_path,
            "val",
            skip_difficult=True,
            shuffle_groups=False,
            **common_args,
        )
    elif args.dataset_type == "csv":
        from generators.csv_ import CSVGenerator

        train_generator = CSVGenerator(
            args.annotations_path,
            args.classes_path,
            misc_effect=misc_effect,
            visual_effect=visual_effect,
            **common_args,
        )

        if args.val_annotations_path:
            validation_generator = CSVGenerator(
                args.val_annotations_path,
                args.classes_path,
                shuffle_groups=False,
                **common_args,
            )
        else:
            validation_generator = None
    elif args.dataset_type == "coco":
        # import here to prevent unnecessary dependency on cocoapi
        from generators.coco import CocoGenerator

        train_generator = CocoGenerator(
            args.coco_path,
            "train2017",
            misc_effect=misc_effect,
            visual_effect=visual_effect,
            group_method="random",
            **common_args,
        )

        validation_generator = CocoGenerator(args.coco_path,
                                             "val2017",
                                             shuffle_groups=False,
                                             **common_args)
    else:
        raise ValueError("Invalid data type received: {}".format(
            args.dataset_type))

    return train_generator, validation_generator
Example #7
from generators.utils import affine_transform, get_affine_transform
from utils.image import read_image_bgr, preprocess_image, resize_image
from generators.csv_ import CSVGenerator  # used below; import path assumed from the other examples
import os
import os.path as osp

DATA_SUFFIX = '_datamap.png'
RESULT_PATH = "result/"
PROCESS_PATH = "process/"
model_path = 'checkpoints/csv.h5'
colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]
score_threshold = 0.5
flip_test = False

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
generator = CSVGenerator(
    'data/annotations.csv',
    'data/classes.csv',
    'data',
)

num_classes = generator.num_classes()
classes = list(generator.classes.keys())

model, prediction_model, debug_model = centernet(num_classes=num_classes,
                                                 nms=True,
                                                 flip_test=flip_test,
                                                 freeze_bn=False,
                                                 score_threshold=score_threshold)
prediction_model.load_weights(model_path, by_name=True, skip_mismatch=True)


for f in os.listdir(PROCESS_PATH):
Example #8
        image, annotations = flipx(image, annotations, prob=self.flip_prob)
        image, annotations = crop(image, annotations, prob=self.crop_prob)
        image, annotations = translate(image,
                                       annotations,
                                       prob=self.translate_prob,
                                       border_value=self.border_value)
        return image, annotations


if __name__ == '__main__':
    import cv2
    import numpy as np

    from generators.csv_ import CSVGenerator

    train_generator = CSVGenerator('../generators/val.csv',
                                   '../generators/csv_class_file.csv',
                                   '../generators/csv_property_file.csv',
                                   base_dir='/Users/yanyan/data',
                                   detect_text=False,
                                   batch_size=1,
                                   phi=3,
                                   shuffle_groups=False)
    for i in range(train_generator.size()):
        image = train_generator.load_image(i)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        annotations = train_generator.load_annotations(i)

        image, annotations = crop(image, annotations)

        boxes = annotations['bboxes'].astype(np.int32)
        labels = annotations['labels']
        # quadrangles = annotations['quadrangles'].astype(np.int32)
        for box, label in zip(boxes, labels):
            cv2.rectangle(image, (box[0], box[1]), (box[2], box[3]),
Example #9
def load_efficient_det(config, LOCAL_ANNOTATIONS_PATH, LOCAL_ROOT_PATH,
                       LOCAL_CLASSES_PATH, LOCAL_VALIDATIONS_PATH,
                       LOCAL_LOGS_PATH, LOCAL_SNAPSHOTS_PATH):

    common_args = {
        'phi': config['phi'],
        'detect_text': config['detect_text'],
        'detect_quadrangle': config['detect_quadrangle']
    }

    # create random transform generator for augmenting training data
    if config['random_transform']:
        misc_effect = MiscEffect()
        visual_effect = VisualEffect()
    else:
        misc_effect = None
        visual_effect = None

    annotations_df = pd.read_csv(LOCAL_ANNOTATIONS_PATH, header=None)
    # stratified sampling
    N = int(len(annotations_df) * 0.15)
    evaluation_df = annotations_df.groupby(
        5, group_keys=False).apply(lambda x: x.sample(
            int(np.rint(N * len(x) / len(annotations_df))))).sample(frac=1)
    evaluation_path = f'{LOCAL_ROOT_PATH}/evaluation.csv'
    evaluation_df.to_csv(evaluation_path, index=False, header=None)

    config['steps_per_epoch'] = annotations_df.iloc[:, 0].nunique(
    ) / config['batch_size']

    train_generator = CSVGenerator(LOCAL_ANNOTATIONS_PATH,
                                   LOCAL_CLASSES_PATH,
                                   batch_size=config['batch_size'],
                                   misc_effect=misc_effect,
                                   visual_effect=visual_effect,
                                   **common_args)
    if config['train_evaluation']:
        evaluation_generator = CSVGenerator(evaluation_path,
                                            LOCAL_CLASSES_PATH,
                                            batch_size=config['batch_size'],
                                            misc_effect=misc_effect,
                                            visual_effect=visual_effect,
                                            **common_args)
    else:
        evaluation_generator = None
    if config['validation']:
        validation_generator = CSVGenerator(LOCAL_VALIDATIONS_PATH,
                                            LOCAL_CLASSES_PATH,
                                            batch_size=config['batch_size'],
                                            misc_effect=misc_effect,
                                            visual_effect=visual_effect,
                                            **common_args)
    else:
        validation_generator = None
    # the generator supplies the class count and anchor count for the model
    num_classes = train_generator.num_classes()
    num_anchors = train_generator.num_anchors

    model, prediction_model = efficientdet(
        config['phi'],
        num_classes=num_classes,
        num_anchors=num_anchors,
        weighted_bifpn=config['weighted_bifpn'],
        freeze_bn=config['freeze_bn'],
        detect_quadrangle=config['detect_quadrangle'])

    # freeze backbone layers
    if config['freeze_backbone']:
        # 227, 329, 329, 374, 464, 566, 656
        for i in range(1, [227, 329, 329, 374, 464, 566, 656][config['phi']]):
            model.layers[i].trainable = False
    # optionally choose specific GPU
    gpu = config['gpu']
    device = gpu.split(':')[0]
    if gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = device
    if gpu and len(gpu.split(':')) > 1:
        gpus = gpu.split(':')[1]
        model = tf.keras.utils.multi_gpu_model(model,
                                               gpus=list(
                                                   map(int, gpus.split(','))))

    if config['snapshot'] == 'imagenet':
        model_name = 'efficientnet-b{}'.format(config['phi'])
        file_name = '{}_weights_tf_dim_ordering_tf_kernels_autoaugment_notop.h5'.format(
            model_name)
        file_hash = WEIGHTS_HASHES[model_name][1]
        weights_path = tf.keras.utils.get_file(file_name,
                                               BASE_WEIGHTS_PATH + file_name,
                                               cache_subdir='models',
                                               file_hash=file_hash)
        model.load_weights(weights_path, by_name=True)
    elif config['snapshot']:
        print('Loading model, this may take a second...')
        model.load_weights(config['snapshot'], by_name=True)

    return (model, prediction_model, train_generator, evaluation_generator,
            validation_generator, config)
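This function is driven entirely by a config dictionary. The sketch below covers every key the function reads; all values and paths are placeholders, and 'imagenet' triggers the pretrained-backbone download branch shown above.

config = {
    'phi': 0,
    'detect_text': False,
    'detect_quadrangle': False,
    'random_transform': True,
    'batch_size': 8,
    'train_evaluation': True,
    'validation': True,
    'weighted_bifpn': True,
    'freeze_bn': False,
    'freeze_backbone': False,
    'gpu': '0',              # '0:0,1' would also wrap the model with multi_gpu_model on GPUs 0 and 1
    'snapshot': 'imagenet',  # or a path to an .h5 checkpoint, or None
}

model, prediction_model, train_gen, eval_gen, val_gen, config = load_efficient_det(
    config,
    LOCAL_ANNOTATIONS_PATH='data/annotations.csv',
    LOCAL_ROOT_PATH='data',
    LOCAL_CLASSES_PATH='data/classes.csv',
    LOCAL_VALIDATIONS_PATH='data/val_annotations.csv',
    LOCAL_LOGS_PATH='logs',
    LOCAL_SNAPSHOTS_PATH='snapshots',
)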