def get_dataflow_vgg(annot_path, img_dir, strict, x_size, y_size, include_outputs_masks=False,
                     crop_size=368, num_proc=4, buffer_size=200):
    """
    This function initializes the tensorpack dataflow and serves generator
    for training operation.

    Pipeline: CocoDataFlow -> read_img -> (optional gen_mask) -> augment ->
    sample building in ``num_proc`` worker processes (MultiProcessMapDataZMQ).

    :param annot_path: path to the annotation file
    :param img_dir: path to the images
    :param strict: forwarded as ``strict`` to MultiProcessMapDataZMQ
    :param x_size: side length of the network input; when it differs from
        ``crop_size`` an extra ``ResizeAug(x_size, x_size)`` step is appended
    :param y_size: forwarded as ``y_size`` to the sample-building function
    :param include_outputs_masks: if True, ``gen_mask`` is mapped over the data
        and samples are built with ``build_sample_with_masks``; otherwise
        ``build_sample`` is used
    :param crop_size: side length of the square crop (default 368)
    :param num_proc: number of worker processes for sample building (default 4)
    :param buffer_size: buffer size passed to MultiProcessMapDataZMQ (default 200)
    :return: tuple ``(df, size)`` -- the dataflow object and the size reported
        by the underlying CocoDataFlow after ``prepare()``
    """
    # configure augmentors

    augmentors = [
        ScaleAug(scale_min=0.5,
                 scale_max=1.1,
                 target_dist=0.6,
                 interp=cv2.INTER_CUBIC),

        RotateAug(rotate_max_deg=40,
                  interp=cv2.INTER_CUBIC,
                  border=cv2.BORDER_CONSTANT,
                  border_value=(128, 128, 128), mask_border_val=1),

        CropAug(crop_size, crop_size, center_perterb_max=40, border_value=128,
                mask_border_val=1),

        FlipAug(num_parts=18, prob=0.5)
    ]

    # Only resize when the network input differs from the crop size.
    if x_size != crop_size:
        augmentors.append(ResizeAug(x_size, x_size))

    # prepare augment function

    augment_func = functools.partial(augment, augmentors=augmentors)

    # build the dataflow

    df = CocoDataFlow((crop_size, crop_size), annot_path, img_dir)
    df.prepare()
    # Size is captured from the base dataflow before any mapping is applied.
    size = df.size()
    df = MapData(df, read_img)

    if include_outputs_masks:
        df = MapData(df, gen_mask)
        build_sample_func = functools.partial(build_sample_with_masks,
                                              y_size=y_size)
    else:
        build_sample_func = functools.partial(build_sample,
                                              y_size=y_size)

    df = MapData(df, augment_func)

    df = MultiProcessMapDataZMQ(df, num_proc=num_proc, map_func=build_sample_func,
                                buffer_size=buffer_size, strict=strict)

    return df, size
# Code example #2
# 0
def get_dataflow(annot_path, img_dir, strict, x_size=224, y_size=28,
                 crop_size=224, num_parts=5, num_threads=4, buffer_size=200):
    """
    This function initializes the tensorpack dataflow and serves generator
    for training operation.

    Pipeline: CocoDataFlow -> read_img -> augment -> sample building in
    ``num_threads`` worker threads (MultiThreadMapData).

    :param annot_path: path to the annotation file
    :param img_dir: path to the images
    :param strict: forwarded as ``strict`` to MultiThreadMapData
    :param x_size: side length passed to the final ``ResizeAug(x_size, x_size)``
    :param y_size: forwarded as ``y_size`` to ``build_sample``
    :param crop_size: side length of the square crop (default 224)
    :param num_parts: ``num_parts`` passed to FlipAug (default 5)
    :param num_threads: number of worker threads for sample building (default 4)
    :param buffer_size: buffer size passed to MultiThreadMapData (default 200)
    :return: tuple ``(df, size)`` -- the dataflow object and the size reported
        by the underlying CocoDataFlow after ``prepare()``
    """
    # configure augmentors

    augmentors = [
        ScaleAug(scale_min=0.5,
                 scale_max=1.1,
                 target_dist=0.6,
                 interp=cv2.INTER_CUBIC),
        RotateAug(rotate_max_deg=40,
                  interp=cv2.INTER_CUBIC,
                  border=cv2.BORDER_CONSTANT,
                  border_value=(128, 128, 128),
                  mask_border_val=1),
        CropAug(crop_size,
                crop_size,
                center_perterb_max=40,
                border_value=128,
                mask_border_val=1),
        FlipAug(num_parts=num_parts, prob=0.5),
        ResizeAug(x_size, x_size)
    ]

    # prepare augment function

    augment_func = functools.partial(augment, augmentors=augmentors)

    # prepare building sample function

    build_sample_func = functools.partial(build_sample, y_size=y_size)

    # build the dataflow

    df = CocoDataFlow((crop_size, crop_size), annot_path, img_dir)
    df.prepare()
    # Size is captured from the base dataflow before any mapping is applied.
    size = df.size()

    df = MapData(df, read_img)
    df = MapData(df, augment_func)

    # Threads are used here; MultiProcessMapDataZMQ (as in get_dataflow_vgg)
    # is the process-based alternative. TODO(JZ): evaluate switching to ZMQ.
    df = MultiThreadMapData(df,
                            num_threads,
                            build_sample_func,
                            buffer_size=buffer_size,
                            strict=strict)

    return df, size