Example #1
def get_dataflow(path, is_train, img_path=None):
    ds = CocoPose(path, img_path, is_train)  # read data from lmdb
    if is_train:
        ds = MapData(ds, read_image_url)
        ds = MapDataComponent(ds, pose_random_scale)
        ds = MapDataComponent(ds, pose_rotation)
        ds = MapDataComponent(ds, pose_flip)
        ds = MapDataComponent(ds, pose_resize_shortestedge_random)
        ds = MapDataComponent(ds, pose_crop_random)
        ds = MapData(ds, pose_to_img)
        # augs = [
        #     imgaug.RandomApplyAug(imgaug.RandomChooseAug([
        #         imgaug.GaussianBlur(max_size=3)
        #     ]), 0.7)
        # ]
        # ds = AugmentImageComponent(ds, augs)
        ds = PrefetchData(ds, 1000, multiprocessing.cpu_count() * 1)
    else:
        ds = MultiThreadMapData(ds,
                                nr_thread=16,
                                map_func=read_image_url,
                                buffer_size=1000)
        ds = MapDataComponent(ds, pose_resize_shortestedge_fixed)
        ds = MapDataComponent(ds, pose_crop_center)
        ds = MapData(ds, pose_to_img)
        ds = PrefetchData(ds, 100, multiprocessing.cpu_count() // 4)

    return ds
Example #2
def get_dataflow_batch(path, is_train, batchsize, img_path=None):
    logger.info('dataflow img_path=%s' % img_path)
    ds = get_dataflow(path, is_train, img_path=img_path)
    ds = BatchData(ds, batchsize)
    if is_train:
        ds = PrefetchData(ds, 10, 2)
    else:
        ds = PrefetchData(ds, 50, 2)

    return ds
Example #3
    def get_input_flow(self):
        ds_train = CellImageDataManagerTrain()
        # ds_train = MapDataComponent(ds_train, random_affine)  # TODO : no improvement?
        ds_train = MapDataComponent(ds_train, random_color)
        # ds_train = MapDataComponent(ds_train, random_scaling)
        ds_train = MapDataComponent(
            ds_train,
            mask_size_normalize)  # Resize by instance size - normalization
        ds_train = MapDataComponent(
            ds_train, lambda x: resize_shortedge_if_small(x, self.img_size))
        ds_train = MapDataComponent(
            ds_train, lambda x: random_crop(x, self.img_size, self.img_size))
        ds_train = MapDataComponent(ds_train, random_flip_lr)
        ds_train = MapDataComponent(ds_train, random_flip_ud)
        # ds_train = MapDataComponent(ds_train, data_to_elastic_transform_wrapper)
        ds_train = MapDataComponent(ds_train, erosion_mask)
        ds_train = MapData(
            ds_train, lambda x: data_to_segment_input(
                x, is_gray=False, unet_weight=True))
        ds_train = PrefetchData(ds_train, 256, 24)
        ds_train = BatchData(ds_train, self.batchsize)
        ds_train = MapDataComponent(ds_train, data_to_normalize1)

        ds_valid = CellImageDataManagerValid()
        ds_valid = MapDataComponent(
            ds_valid, lambda x: resize_shortedge_if_small(x, self.img_size))
        ds_valid = MapDataComponent(
            ds_valid, lambda x: random_crop(x, self.img_size, self.img_size))
        ds_valid = MapDataComponent(ds_valid, erosion_mask)
        ds_valid = MapData(
            ds_valid, lambda x: data_to_segment_input(
                x, is_gray=False, unet_weight=True))
        ds_valid = PrefetchData(ds_valid, 20, 12)
        ds_valid = BatchData(ds_valid, self.batchsize, remainder=True)
        ds_valid = MapDataComponent(ds_valid, data_to_normalize1)

        ds_valid2 = CellImageDataManagerValid()
        ds_valid2 = MapDataComponent(
            ds_valid2, lambda x: resize_shortedge_if_small(x, self.img_size))
        ds_valid2 = MapDataComponent(
            ds_valid2,
            lambda x: center_crop_if_tcga(x, self.img_size, self.img_size))
        # ds_valid2 = MapDataComponent(ds_valid2, lambda x: resize_shortedge(x, self.img_size))
        ds_valid2 = MapData(ds_valid2,
                            lambda x: data_to_segment_input(x, is_gray=False))
        ds_valid2 = MapDataComponent(ds_valid2, data_to_normalize1)

        ds_test = CellImageDataManagerTest()
        ds_test = MapDataComponent(
            ds_test, lambda x: resize_shortedge_if_small(x, self.img_size))
        # ds_test = MapDataComponent(ds_test, lambda x: resize_shortedge(x, self.img_size))
        ds_test = MapData(ds_test, lambda x: data_to_image(x, is_gray=False))
        ds_test = MapDataComponent(ds_test, data_to_normalize1)

        return ds_train, ds_valid, ds_valid2, ds_test
Example #4
def get_dataflow(path, is_train):
    ds = SynthHands(path, is_train)       # read data from lmdb
    if is_train:
        ds = MapData(ds, read_image_url)
        ds = MapData(ds, pose_to_img)
        ds = PrefetchData(ds, 1000, multiprocessing.cpu_count() * 1)
    else:
        ds = MultiThreadMapData(ds, num_thread=16, map_func=read_image_url, buffer_size=1000)
        ds = MapData(ds, pose_to_img)
        ds = PrefetchData(ds, 100, multiprocessing.cpu_count() // 4)
    return ds
Example #5
def get_dataflow_batch(path, is_train, batchsize, img_path=None):
    logger.info('dataflow img_path=%s' % img_path)
    ds = get_dataflow(path, is_train, img_path=img_path)
    print("ds from get_dataflow", ds)
    ds = BatchData(ds, batchsize)
    print("ds from batchdata", ds)
    if is_train:
        ds = PrefetchData(ds, 10, 3)
        print("ds from preferchdata", ds)
    else:
        ds = PrefetchData(ds, 50, 2)

    return ds
Example #6
    def get_input_flow(self):
        ds_train = CellImageDataManagerTrain()
        # Augmentation :
        ds_train = MapDataComponent(ds_train, random_affine)
        ds_train = MapDataComponent(ds_train, random_color)
        # ds_train = MapDataComponent(ds_train, random_color2)  # not good
        ds_train = MapDataComponent(ds_train, random_scaling)
        ds_train = MapDataComponent(
            ds_train, lambda x: resize_shortedge_if_small(x, 224))
        ds_train = MapDataComponent(ds_train,
                                    lambda x: random_crop(x, 224, 224))
        ds_train = MapDataComponent(ds_train, random_flip_lr)
        # ds_train = MapDataComponent(ds_train, data_to_elastic_transform_wrapper)
        ds_train = MapDataComponent(ds_train, random_flip_ud)
        if self.unet_weight:
            ds_train = MapDataComponent(ds_train, erosion_mask)
        ds_train = PrefetchData(ds_train, 1000, 24)
        ds_train = MapData(
            ds_train, lambda x: data_to_segment_input(x, not self.is_color,
                                                      self.unet_weight))
        ds_train = BatchData(ds_train, self.batchsize)
        ds_train = MapDataComponent(ds_train, data_to_normalize1)
        ds_train = PrefetchData(ds_train, 10, 2)

        ds_valid = CellImageDataManagerValid()
        ds_valid = MapDataComponent(ds_valid,
                                    lambda x: center_crop(x, 224, 224))
        if self.unet_weight:
            ds_valid = MapDataComponent(ds_valid, erosion_mask)
        ds_valid = MapData(
            ds_valid, lambda x: data_to_segment_input(x, not self.is_color,
                                                      self.unet_weight))
        ds_valid = BatchData(ds_valid, self.batchsize, remainder=True)
        ds_valid = MapDataComponent(ds_valid, data_to_normalize1)
        ds_valid = PrefetchData(ds_valid, 20, 24)

        ds_valid2 = CellImageDataManagerValid()
        ds_valid2 = MapDataComponent(
            ds_valid2, lambda x: resize_shortedge_if_small(x, 224))
        ds_valid2 = MapData(
            ds_valid2, lambda x: data_to_segment_input(x, not self.is_color))
        ds_valid2 = MapDataComponent(ds_valid2, data_to_normalize1)

        ds_test = CellImageDataManagerTest()
        ds_test = MapDataComponent(ds_test,
                                   lambda x: resize_shortedge_if_small(x, 224))
        ds_test = MapData(ds_test,
                          lambda x: data_to_image(x, not self.is_color))
        ds_test = MapDataComponent(ds_test, data_to_normalize1)

        return ds_train, ds_valid, ds_valid2, ds_test
Example #7
File: train.py Project: qq456cvb/UKPGAN
def main(cfg):
    print(cfg)
    
    tf.reset_default_graph()
    
    logger.set_logger_dir('tflogs', action='d')

    copyfile(hydra.utils.to_absolute_path('model.py'), 'model.py')
    copyfile(hydra.utils.to_absolute_path('dataflow.py'), 'dataflow.py')
    
    if cfg.cat_name == 'smpl':
        train_df = SMPLDataFlow(cfg, True, 1000)
        val_df = VisSMPLDataFlow(cfg, True, 1000, port=1080)
    else:
        train_df = ShapeNetDataFlow(cfg, cfg.data.train_txt, True)
        val_df = VisDataFlow(cfg, cfg.data.val_txt, False, port=1080)
    
    config = TrainConfig(
        model=Model(cfg),
        dataflow=BatchData(PrefetchData(train_df, cpu_count() // 2, cpu_count() // 2), cfg.batch_size),
        callbacks=[
            ModelSaver(),
            SimpleMovingAverage(['recon_loss', 'GAN/loss_d', 'GAN/loss_g', 'GAN/gp_loss', 'symmetry_loss'], 100),
            PeriodicTrigger(val_df, every_k_steps=30)
        ],
        monitors=tensorpack.train.DEFAULT_MONITORS() + [ScalarPrinter(enable_step=True, enable_epoch=False)],
        max_epoch=10
    )
    launch_train_with_config(config, SimpleTrainer())
Example #8
def get_dataflow(path, is_train):
    ds = CocoPoseLMDB(path, is_train)       # read data from lmdb
    if is_train:
        ds = MapDataComponent(ds, pose_random_scale)
        ds = MapDataComponent(ds, pose_rotation)
        ds = MapDataComponent(ds, pose_flip)
        ds = MapDataComponent(ds, pose_resize_shortestedge_random)
        ds = MapDataComponent(ds, pose_crop_random)
        ds = MapData(ds, pose_to_img)
        augs = [
            imgaug.RandomApplyAug(imgaug.RandomChooseAug([
                imgaug.BrightnessScale((0.6, 1.4), clip=False),
                imgaug.Contrast((0.7, 1.4), clip=False),
                imgaug.GaussianBlur(max_size=3)
            ]), 0.7),
        ]
        ds = AugmentImageComponent(ds, augs)
    else:
        ds = MapDataComponent(ds, pose_resize_shortestedge_fixed)
        ds = MapDataComponent(ds, pose_crop_center)
        ds = MapData(ds, pose_to_img)

    ds = PrefetchData(ds, 1000, multiprocessing.cpu_count())

    return ds
Example #9
def _get_dataflow_onlyread(path, is_train, img_path=None):
    ds = OpenOoseHand(path, is_train)  # read data from lmdb
    ds = MapData(ds, read_image_url)
    ds = MapDataComponent(ds, crop_hand_roi_big)
    ds = MapDataComponent(ds, hand_random_scale)
    ds = MapDataComponent(ds, pose_rotation)
    ds = MapDataComponent(ds, pose_flip)
    ds = MapDataComponent(ds, crop_hand_roi)
    # ds = MapDataComponent(ds, pose_resize_shortestedge_fixed)
    # ds = MapDataComponent(ds, pose_crop_random)
    ds = MapData(ds, pose_to_img)
    ds = PrefetchData(ds, 10, 2)
    return ds
Example #10
def get_dataflow(path, is_train=True, img_path=None, sigma=8.0, output_shape=(1440, 2560),
                 numparts=5, translation=False, scale=False, rotation=True,
                 mins=0.25, maxs=1.2, mina=-np.pi, maxa=np.pi, ilumination=0.0, image_type='RGB'):
    print('Creating images from', path)
    numparts, skeleton = get_skeleton_from_json(path)
    # numparts + 1 because we need an extra channel for the background
    ds = CocoPose(path, img_path, is_train, numparts=numparts + 1, sigma=sigma, skeleton=skeleton,
                  output_shape=output_shape, translation=translation, scale=scale, rotation=rotation,
                  mins=mins, maxs=maxs, mina=mina, maxa=maxa, ilumination=ilumination,
                  image_type=image_type)       # read data from lmdb
    if is_train:
        #ds = MapData(ds, read_image_url)
        ds = MultiThreadMapData(ds, nr_thread=8, map_func=read_image_url, buffer_size=10)
        
        ds = MapDataComponent(ds, get_augmented_image)
        #ds = MapDataComponent(ds, pose_rotation)
        #ds = MapDataComponent(ds, pose_flip)
        #ds = MapDataComponent(ds, pose_resize_shortestedge_random)
        #ds = MapDataComponent(ds, pose_crop_random)
        #logger.info('Out of new augmenter')
        ds = MapData(ds, pose_to_img)
        #logger.info('Out pose to img')
        # augs = [
        #     imgaug.RandomApplyAug(imgaug.RandomChooseAug([
        #         imgaug.GaussianBlur(max_size=3)
        #     ]), 0.7)
        # ]
        # ds = AugmentImageComponent(ds, augs)
        ds = PrefetchData(ds, 10, multiprocessing.cpu_count() * 1)
    else:
        #ds = MultiThreadMapData(ds, nr_thread=4, map_func=read_image_url, buffer_size=10)
        ds = MapData(ds, read_image_url)
        #ds = MapDataComponent(ds, pose_resize_shortestedge_fixed)
        #ds = MapDataComponent(ds, pose_crop_center)
        ds = MapData(ds, pose_to_img)
        ds = PrefetchData(ds, 10, multiprocessing.cpu_count() // 4)

    return ds
Example #11
def get_dataflow_batch(path, is_train=True, batch_size=10, img_path=None, sigma=8.0, output_shape=(1440, 2560),
                       numparts=5, translation=False, scale=False, rotation=True,
                       mins=0.25, maxs=1.2, mina=-np.pi, maxa=np.pi, ilumination=0.0, image_type='RGB'):
    logger.info('dataflow img_path=%s' % img_path)

    ds = get_dataflow(path, is_train, img_path=img_path, sigma=sigma, output_shape=output_shape,
                      translation=translation, scale=scale, rotation=rotation,
                      mins=mins, maxs=maxs, mina=mina, maxa=maxa, ilumination=ilumination, image_type=image_type)
    ds = BatchData(ds, batch_size)
    # if is_train:
    ds = PrefetchData(ds, 10, 2)
    # else:
    #     ds = PrefetchData(ds, 50, 2)

    return ds
Example #12
def get_dataflow(coco_data_paths):
    """
    This function initializes the tensorpack dataflow and serves generator
    for training operation.

    :param coco_data_paths: paths to the coco files: annotation file and folder with images
    :return: dataflow object
    """
    df = CocoDataFlow((368, 368), coco_data_paths)
    df.prepare()
    df = MapData(df, read_img)
    df = MapData(df, gen_mask)
    df = MapData(df, augment)
    df = MapData(df, apply_mask)
    df = MapData(df, build_sample)
    # df = PrefetchDataZMQ(df, nr_proc=4)
    df = PrefetchData(df, 2, 1)

    return df
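
The docstring above describes get_dataflow() as returning a tensorpack dataflow that serves a generator for training. The short sketch below shows one way to consume it; it is not part of the original example, and coco_data_paths is assumed to be whatever CocoDataFlow expects (annotation file plus image folder).

# Hypothetical usage sketch (assumed inputs, not from the example above).
df = get_dataflow(coco_data_paths)
df.reset_state()                  # tensorpack dataflows must be reset before iteration
for datapoint in df.get_data():   # each datapoint is whatever build_sample() packed
    pass                          # hand the sample to the training step here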
Example #13
def get_dataflow(annot_path, img_dir):
    """
    This function initializes the tensorpack dataflow and serves generator
    for training operation.

    :param annot_path: path to the annotation file
    :param img_dir: path to the images
    :return: dataflow object
    """
    df = CocoDataFlow((368, 368), annot_path, img_dir)
    df.prepare()
    df = MapData(df, read_img)
    df = MapData(df, gen_mask)
    df = MapData(df, augment)
    df = MapData(df, apply_mask)
    df = MapData(df, build_sample)
    df = PrefetchData(df, 2, 1)

    return df
Example #14
    def test_speed2(self):
        ds_train = CellImageDataManagerTrain()
        # ds_train = MapDataComponent(ds_train, random_affine)
        # ds_train = MapDataComponent(ds_train, random_color)
        # ds_train = MapDataComponent(ds_train, random_scaling)
        # ds_train = MapDataComponent(ds_train, lambda x: resize_shortedge_if_small(x, 228))
        # ds_train = MapDataComponent(ds_train, lambda x: random_crop(x, 228, 228))
        # ds_train = MapDataComponent(ds_train, random_flip_lr)
        # ds_train = MapDataComponent(ds_train, random_flip_ud)
        # ds_train = MapDataComponent(ds_train, data_to_elastic_transform_wrapper)
        ds_train = MapDataComponent(ds_train, erosion_mask)
        ds_train = PrefetchData(ds_train, 1000, 24)

        for idx, dp in enumerate(ds_train.get_data()):
            if idx > 100:
                break

        t = time.time()
        TestDataSpeed(ds_train, size=100).start()
        dt = time.time() - t
        self.assertLessEqual(dt, 5.0)
Example #15
                        meta.aug_joints,
                        1,
                        stride=8)

    return [meta, mask_paf, mask_heatmap, pafmap, heatmap]


if __name__ == '__main__':
    batch_size = 10
    curr_dir = os.path.dirname(__file__)

    annot_path = os.path.join(
        curr_dir, '../dataset/annotations/pen_keypoints_validation.json')
    img_dir = os.path.abspath(os.path.join(curr_dir, '../dataset/validation/'))
    df = CocoDataFlow(
        (368, 368), COCODataPaths(annot_path, img_dir))  #, select_ids=[1000])
    df.prepare()
    df = MapData(df, read_img)
    df = MapData(df, gen_mask)
    df = MapData(df, augment)
    df = MapData(df, apply_mask)
    df = MapData(df, build_debug_sample)
    df = PrefetchData(df, nr_prefetch=2, nr_proc=1)

    df.reset_state()
    gen = df.get_data()

    for g in gen:
        show_image_mask_center_of_main_person(g)
        #show_image_heatmap_paf(g)
Example #16
def get_dataflow_batch(path, is_train, batchsize):
    ds = get_dataflow(path, is_train)
    ds = BatchData(ds, batchsize)
    ds = PrefetchData(ds, 10, 2)

    return ds
Example #17
    def get_input_flow(self):
        ds_train = CellImageDataManagerTrain()
        # ds_train = MapDataComponent(ds_train, random_affine)  # TODO : no improvement?
        ds_train = MapDataComponent(ds_train, random_color)
        # ds_train = MapDataComponent(ds_train, random_scaling)
        ds_train = MapDataComponent(
            ds_train,
            mask_size_normalize)  # Resize by instance size - normalization
        ds_train = MapDataComponent(
            ds_train, lambda x: resize_shortedge_if_small(x, self.img_size))
        # ds_train = MapDataComponent(ds_train, lambda x: pad_if_small(x, self.img_size)) # preseve cell's size
        ds_train = MapDataComponent(
            ds_train, lambda x: random_crop(
                x, self.img_size, self.img_size, padding=self.pad_size))
        # ds_train = MapDataComponent(ds_train, random_add_thick_area)      # TODO : worth?
        ds_train = MapDataComponent(ds_train, random_flip_lr)
        ds_train = MapDataComponent(ds_train, random_flip_ud)
        # ds_train = MapDataComponent(ds_train, data_to_elastic_transform_wrapper)
        if self.unet_weight:
            ds_train = MapDataComponent(ds_train, erosion_mask)
        ds_train = MapData(
            ds_train, lambda x: data_to_segment_input(x, not self.is_color,
                                                      self.unet_weight))
        ds_train = PrefetchData(ds_train, 256, 24)
        ds_train = BatchData(ds_train, self.batchsize)
        ds_train = MapDataComponent(ds_train, data_to_normalize1)

        ds_valid = CellImageDataManagerValid()
        ds_valid = MapDataComponent(
            ds_valid, lambda x: resize_shortedge_if_small(x, self.img_size))
        ds_valid = MapDataComponent(
            ds_valid, lambda x: random_crop(
                x, self.img_size, self.img_size, padding=self.pad_size))
        if self.unet_weight:
            ds_valid = MapDataComponent(ds_valid, erosion_mask)
        ds_valid = MapData(
            ds_valid, lambda x: data_to_segment_input(x, not self.is_color,
                                                      self.unet_weight))
        ds_valid = PrefetchData(ds_valid, 32, 8)
        ds_valid = BatchData(ds_valid, self.batchsize, remainder=True)
        ds_valid = MapDataComponent(ds_valid, data_to_normalize1)

        ds_valid2 = CellImageDataManagerValid()
        ds_valid2 = MapDataComponent(
            ds_valid2, lambda x: resize_shortedge_if_small(x, self.img_size))
        ds_valid2 = MapDataComponent(
            ds_valid2,
            lambda x: center_crop_if_tcga(x, self.img_size, self.img_size))
        # ds_valid2 = MapDataComponent(ds_valid2, lambda x: resize_shortedge(x, self.img_size))
        ds_valid2 = MapData(
            ds_valid2, lambda x: data_to_segment_input(x, not self.is_color))
        ds_valid2 = MapDataComponent(ds_valid2, data_to_normalize1)

        ds_test = CellImageDataManagerTest()
        ds_test = MapDataComponent(
            ds_test, lambda x: resize_shortedge_if_small(x, self.img_size))
        # ds_test = MapDataComponent(ds_test, lambda x: resize_shortedge(x, self.img_size))
        ds_test = MapData(ds_test,
                          lambda x: data_to_image(x, not self.is_color))
        ds_test = MapDataComponent(ds_test, data_to_normalize1)

        return ds_train, ds_valid, ds_valid2, ds_test
Example #18
def get_dataflow(path, is_train, img_path=None):
    ds = OpenOoseHand(path, is_train)       # read data from lmdb
    # if is_train:
    #     ''' 
    #         ds is a DataFlow Object which must implement get_data() function
            
    #         MapData(ds, func) will apply func to data returned by ds.get_data()
    #         1. create an obj
    #         2. obj.ds = ds
    #         2. obj.get_data():
    #             data = self.ds.get_data()
    #             yield self.func(data)

    #         MapDataComponent(ds, func) will do similar thing
    #             the main different is the target of MapDataComponent(...)
    #             is the returned data of get_data()
    #     '''

    #     ds = MapData(ds, read_image_url)
    #     ds = MapDataComponent(ds, pose_random_scale)
    #     ds = MapDataComponent(ds, pose_rotation)
    #     ds = MapDataComponent(ds, pose_flip)
    #     ds = MapDataComponent(ds, pose_resize_shortestedge_random)
    #     ds = MapDataComponent(ds, pose_crop_random)
    #     # use joint_list to draw two point and vector heatmap
    #     ds = MapData(ds, pose_to_img)
    #     # augs = [
    #     #     imgaug.RandomApplyAug(imgaug.RandomChooseAug([
    #     #         imgaug.GaussianBlur(max_size=3)
    #     #     ]), 0.7)
    #     # ]
    #     # ds = AugmentImageComponent(ds, augs)
    #     # ds = PrefetchData(ds, 10, multiprocessing.cpu_count() * 4)
    #     ds = PrefetchData(ds, 2, 1)
    # else:
    #     ds = MultiThreadMapData(ds, nr_thread=16, map_func=read_image_url, buffer_size=1000)
    #     ds = MapDataComponent(ds, pose_resize_shortestedge_fixed)
    #     ds = MapDataComponent(ds, pose_crop_center)
    #     ds = MapData(ds, pose_to_img)
    #     ds = PrefetchData(ds, 10, multiprocessing.cpu_count() // 4)
    if is_train:
        '''
            ds is a DataFlow object which must implement the get_data() function.

            MapData(ds, func) applies func to every datapoint returned by ds.get_data():
            1. create an obj
            2. obj.ds = ds
            3. obj.get_data():
                data = self.ds.get_data()
                yield self.func(data)

            MapDataComponent(ds, func) does a similar thing; the main difference is that
            it applies func to a single component of each datapoint returned by get_data().
        '''

        ds = MapData(ds, read_image_url)
        ds = MapDataComponent(ds, crop_hand_roi_big)
        ds = MapDataComponent(ds, hand_random_scale)
        ds = MapDataComponent(ds, pose_rotation)
        ds = MapDataComponent(ds, pose_flip)
        ds = MapDataComponent(ds, crop_hand_roi)
        # ds = MapDataComponent(ds, pose_resize_shortestedge_fixed)
        # ds = MapDataComponent(ds, pose_crop_random)
        ds = MapData(ds, pose_to_img)
        ds = PrefetchData(ds, 20, 1)
    else:
        ds = MultiThreadMapData(ds, nr_thread=1, map_func=read_image_url, buffer_size=5)
        ds = MapDataComponent(ds, crop_hand_roi_big)
        ds = MapDataComponent(ds, pose_resize_shortestedge_fixed)
        ds = MapDataComponent(ds, pose_crop_center)
        ds = MapData(ds, pose_to_img)
        ds = PrefetchData(ds, 2, 2)

    return ds
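
The docstring in the example above contrasts MapData and MapDataComponent. The toy sketch below illustrates the difference with a throwaway DataFromList dataflow; it assumes the plain tensorpack API and is not code from any of the listed projects.

# Toy illustration (assumption: standard tensorpack API, not project code).
from tensorpack.dataflow import DataFromList, MapData, MapDataComponent

ds = DataFromList([[1, 'a'], [2, 'b']], shuffle=False)    # each datapoint is [number, letter]
whole = MapData(ds, lambda dp: [dp[0] * 10, dp[1]])       # func receives the whole datapoint
first = MapDataComponent(ds, lambda x: x * 10, index=0)   # func receives only component 0

for df in (whole, first):
    df.reset_state()
    print(list(df.get_data()))    # both print [[10, 'a'], [20, 'b']]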
Example #19
    df = MapData(
        df, lambda x: ([x[0], x[1], x[2], x[5]],
                       [x[3], x[4], x[6], x[3], x[4], x[6], x[3], x[4], x[6]]))
    df.reset_state()
    return df


if __name__ == '__main__':
    """
    Run this script to check speed of generating samples. Tweak the nr_proc
    parameter of PrefetchDataZMQ. Ideally it should reflect the number of cores 
    in your hardware
    """
    batch_size = 10
    curr_dir = os.path.dirname(__file__)
    annot_path = os.path.join(
        curr_dir, '../dataset/annotations/person_keypoints_val2017.json')
    img_dir = os.path.abspath(os.path.join(curr_dir, '../dataset/val2017/'))
    df = CocoDataFlow((368, 368), annot_path, img_dir)  #, select_ids=[1000])
    df.prepare()
    df = MapData(df, read_img)
    df = MapData(df, gen_mask)
    df = MapData(df, augment)
    df = MapData(df, apply_mask)
    df = MapData(df, build_sample)
    df = PrefetchData(df, nr_prefetch=2, nr_proc=4)  # PrefetchData also needs the prefetch queue size
    df = BatchData(df, batch_size, use_list=False)
    df = MapData(df, lambda x: ([x[0], x[1], x[2]], [x[3], x[4], x[3], x[4]]))

    TestDataSpeed(df, size=100).start()
Example #20
    df.reset_state()
    return df


if __name__ == '__main__':
    """
    Run this script to check speed of generating samples. Tweak the nr_proc
    parameter of PrefetchDataZMQ. Ideally it should reflect the number of cores 
    in your hardware
    """
    batch_size = 10
    curr_dir = os.path.dirname(__file__)
    annot_path = os.path.join(curr_dir, '../dataset/annotations/person_keypoints_val2017.json')
    img_dir = os.path.abspath(os.path.join(curr_dir, '../dataset/val2017/'))
    df = CocoDataFlow((368, 368), COCODataPaths(annot_path, img_dir))#, select_ids=[1000])
    df.prepare()
    df = MapData(df, read_img)
    df = MapData(df, gen_mask)
    df = MapData(df, augment)
    df = MapData(df, apply_mask)
    df = MapData(df, build_sample)
    #df = PrefetchDataZMQ(df, nr_proc=4)
    df = PrefetchData(df, 2, 1)
    df = BatchData(df, batch_size, use_list=False)
    df = MapData(df, lambda x: (
        [x[0], x[1], x[2]],
        [x[3], x[4], x[3], x[4], x[3], x[4], x[3], x[4], x[3], x[4], x[3], x[4]])
    )

    TestDataSpeed(df, size=100).start()