def __init__(self,
                 mode,
                 batch_size=256,
                 shuffle=False,
                 num_workers=25,
                 cache=50000,
                 collate_fn=default_collate,
                 remainder=False,
                 cuda=False,
                 transform=None):
        # wrap the user-supplied transform as the image augmentor
        # (the standard fbresnet_augmentor call is kept for reference)
        #imagenet_augmentors = fbresnet_augmentor(mode == 'train')
        imagenet_augmentors = [ImgAugTVCompose(transform)]

        # load the lmdb if we can find it
        lmdb_loc = os.path.join(os.environ['IMAGENET'],
                                'ILSVRC-%s.lmdb' % mode)
        ds = td.LMDBData(lmdb_loc, shuffle=False)
        if mode == 'train':
            ds = td.LocallyShuffleData(ds, cache)
        ds = td.PrefetchData(ds, 5000, 1)
        ds = td.LMDBDataPoint(ds)
        #ds = td.MapDataComponent(ds, lambda x: cv2.imdecode(x, cv2.IMREAD_COLOR), 0)
        ds = td.MapDataComponent(
            ds, lambda x: np.asarray(Image.open(io.BytesIO(x)).convert('RGB')),
            0)
        ds = td.AugmentImageComponent(ds, imagenet_augmentors)
        ds = td.PrefetchDataZMQ(ds, num_workers)
        self.ds = td.BatchData(ds, batch_size, remainder=remainder)
        self.ds.reset_state()

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.cuda = cuda
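The snippet ends here without the rest of the wrapper; a minimal sketch of the iteration interface such a loader class usually adds (the method bodies and the conversion to torch tensors are assumptions, assuming torch is imported):

def __len__(self):
    # number of batches per epoch, delegated to td.BatchData (assumed helper)
    return self.ds.size()

def __iter__(self):
    # assumed helper: yield each tensorpack batch as torch tensors
    for images, labels in self.ds.get_data():
        images = torch.from_numpy(images)
        labels = torch.from_numpy(np.asarray(labels))
        if self.cuda:
            images, labels = images.cuda(), labels.cuda()
        yield images, labels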
Example #2
def lmdb_dataflow(lmdb_path, batch_size, input_size, output_size, is_training, test_speed=False):
    df = dataflow.LMDBSerializer.load(lmdb_path, shuffle=False)
    size = df.size()
    if is_training:
        # shuffle within a rolling buffer and prefetch in a background process
        df = dataflow.LocallyShuffleData(df, buffer_size=2000)
        df = dataflow.PrefetchData(df, num_prefetch=500, num_proc=1)
    df = BatchData(df, batch_size, input_size, output_size)  # project-specific batching
    if is_training:
        df = dataflow.PrefetchDataZMQ(df, num_proc=8)
    df = dataflow.RepeatedData(df, -1)  # repeat indefinitely
    if test_speed:
        dataflow.TestDataSpeed(df, size=1000).start()
    df.reset_state()
    return df, size
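A hedged usage sketch for the function above; the path and sizes are placeholders, and the loop is bounded explicitly because RepeatedData(df, -1) makes the generator infinite:

# hypothetical call; all argument values are placeholders
train_df, train_size = lmdb_dataflow('path/to/train.lmdb', batch_size=32,
                                     input_size=2048, output_size=16384,
                                     is_training=True)
steps_per_epoch = train_size // 32
gen = train_df.get_data()
for step in range(steps_per_epoch):
    batch = next(gen)  # layout depends on the project's custom BatchData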
Example #3
def lmdb_dataflow(lmdb_path,
                  batch_size,
                  sample_size,
                  is_training,
                  test_speed=False,
                  train_perturb_list=None,
                  valid_perturb_list=None,
                  so3_perturb=False,
                  use_partial=False):
    df = dataflow.LMDBSerializer.load(lmdb_path, shuffle=False)
    size = df.size()
    if is_training:
        df = dataflow.LocallyShuffleData(df, buffer_size=2000)
    df = dataflow.PrefetchData(df, nr_prefetch=500, nr_proc=1)
    df = PreprocessData(df,
                        sample_size,
                        is_training,
                        train_perturb_list=train_perturb_list,
                        valid_perturb_list=valid_perturb_list,
                        so3_perturb=so3_perturb,
                        use_partial=use_partial)
    if is_training:
        df = dataflow.PrefetchDataZMQ(df, nr_proc=8)
    df = dataflow.BatchData(df, batch_size, use_list=True)
    df = dataflow.RepeatedData(df, -1)
    if test_speed:
        dataflow.TestDataSpeed(df, size=1000).start()
    df.reset_state()
    return df, size
Example #4
    def __init__(self,
                 mode,
                 batch_size=256,
                 shuffle=False,
                 num_workers=25,
                 cache=50000,
                 collate_fn=default_collate,
                 drop_last=False,
                 cuda=False):
        # enumerate standard imagenet augmentors
        imagenet_augmentors = fbresnet_augmentor(mode == 'train')

        # load the lmdb if we can find it
        lmdb_loc = os.path.join(os.environ['IMAGENET'],
                                'ILSVRC-%s.lmdb' % mode)
        ds = td.LMDBData(lmdb_loc, shuffle=False)
        ds = td.LocallyShuffleData(ds, cache)
        ds = td.PrefetchData(ds, 5000, 1)
        ds = td.LMDBDataPoint(ds)
        ds = td.MapDataComponent(ds,
                                 lambda x: cv2.imdecode(x, cv2.IMREAD_COLOR),
                                 0)
        ds = td.AugmentImageComponent(ds, imagenet_augmentors)
        ds = td.PrefetchDataZMQ(ds, num_workers)
        self.ds = td.BatchData(ds, batch_size)
        self.ds.reset_state()

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.cuda = cuda
def data_pipe(fmri_files,
              confound_files,
              label_matrix,
              target_name=None,
              batch_size=32,
              data_type='train',
              train_percent=0.8,
              nr_thread=nr_thread,
              buffer_size=buffer_size):
    assert data_type in ['train', 'val', 'test']
    assert fmri_files is not None

    print('\n\nGenerating dataflow for %s datasets \n' % data_type)

    buffer_size = min(len(fmri_files), buffer_size)
    nr_thread = min(len(fmri_files), nr_thread)

    ds0 = gen_fmri_file(fmri_files,
                        confound_files,
                        label_matrix,
                        data_type=data_type,
                        train_percent=train_percent)
    print('dataflowSize is ' + str(ds0.size()))
    print('Loading data using %d threads with %d buffer_size ... \n' %
          (nr_thread, buffer_size))

    if target_name is None:
        target_name = np.unique(label_matrix)

    #### build and time the data-loading pipeline
    start_time = time.clock()
    ds1 = dataflow.MultiThreadMapData(
        ds0,
        nr_thread=nr_thread,
        map_func=lambda dp: map_load_fmri_image(dp, target_name),
        buffer_size=buffer_size,
        strict=True)

    ds1 = dataflow.PrefetchData(ds1, buffer_size, 1)

    ds1 = split_samples(ds1)
    print('prefetch dataflowSize is ' + str(ds1.size()))

    ds1 = dataflow.LocallyShuffleData(ds1,
                                      buffer_size=ds1.size() * buffer_size)

    ds1 = dataflow.BatchData(ds1, batch_size=batch_size)
    print('Time Usage of loading data in seconds: {} \n'.format(time.clock() -
                                                                start_time))

    ds1 = dataflow.PrefetchDataZMQ(ds1, nr_proc=1)
    ds1._reset_once()
    ##ds1.reset_state()

    #return ds1.get_data()
    for df in ds1.get_data():
        ##print(np.expand_dims(df[0].astype('float32'),axis=3).shape)
        yield (np.expand_dims(df[0].astype('float32'), axis=3),
               to_categorical(df[1].astype('int32'), len(target_name)))
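Because data_pipe is itself a generator yielding (inputs, one-hot labels) pairs, it would typically be handed to a Keras-style training loop; a sketch under that assumption (the model and step counts are placeholders):

# hypothetical consumption of the generator defined above
train_gen = data_pipe(fmri_files, confound_files, label_matrix,
                      batch_size=32, data_type='train')
model.fit_generator(train_gen, steps_per_epoch=200, epochs=20)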
Example #6
    def __init__(
        self,
        corpus_path,
        tokenizer,
        seq_len,
        encoding="utf-8",
        predict_feature=False,
        hard_negative=False,
        batch_size=512,
        shuffle=False,
        num_workers=25,
        cache=50000,
        drop_last=False,
        cuda=False,
        distributed=False,
        visualization=False,
    ):

        if dist.is_available() and distributed:
            num_replicas = dist.get_world_size()
            # assert num_replicas == 8
            rank = dist.get_rank()
            lmdb_file = "/coc/dataset/conceptual_caption/training_feat_part_" + str(rank) + ".lmdb"
            # if not os.path.exists(lmdb_file):
            # lmdb_file = "/srv/share/datasets/conceptual_caption/training_feat_part_" + str(rank) + ".lmdb"
        else:
            # lmdb_file = "/coc/dataset/conceptual_caption/training_feat_all.lmdb"
            # if not os.path.exists(lmdb_file):
            lmdb_file = "/coc/pskynet2/jlu347/multi-modal-bert/data/conceptual_caption/training_feat_all.lmdb"
            
        caption_path = "/coc/pskynet2/jlu347/multi-modal-bert/data/conceptual_caption/caption_train.json"
        print("Loading from %s" % lmdb_file)

        ds = td.LMDBSerializer.load(lmdb_file, shuffle=False)
        self.num_dataset = len(ds)

        preprocess_function = BertPreprocessBatch(
            caption_path,
            tokenizer,
            seq_len,
            36,
            self.num_dataset,
            encoding="utf-8",
            predict_feature=predict_feature,
        )

        ds = td.LocallyShuffleData(ds, cache)
        ds = td.PrefetchData(ds, 5000, 1)
        ds = td.MapData(ds, preprocess_function)
        # self.ds = td.PrefetchData(ds, 1)
        ds = td.PrefetchDataZMQ(ds, num_workers)
        self.ds = td.BatchData(ds, batch_size)
        # self.ds = ds
        self.ds.reset_state()

        self.batch_size = batch_size
        self.num_workers = num_workers
Example #7
    def __init__(self,
                 datafile,
                 batch_size,
                 num_workers=1,
                 nviews=12,
                 reset=True,
                 augment=False,
                 filter_classes=None,
                 filter_views=None,
                 polarmode='cartesian',
                 shuffle=True,
                 filter_ids=None,
                 label_to0idx=False,
                 rgb=False,
                 force_res=0,
                 autocrop=False,
                 keep_aspect_ratio=False):
        self.filter_classes = filter_classes
        self.filter_views = filter_views
        self.filter_ids = filter_ids
        self.polarmode = polarmode
        self.label_to0idx = label_to0idx
        self.rgb = rgb
        self.force_res = force_res
        self.autocrop = autocrop
        self.keep_aspect_ratio = keep_aspect_ratio

        if not isinstance(datafile, list):
            datafile = [datafile]

        ds = []
        for d in datafile:

            ds.append(df.LMDBSerializer.load(d, shuffle=shuffle))

            if shuffle:
                ds[-1] = df.LocallyShuffleData(ds[-1], 100)
            ds[-1] = df.PrefetchData(ds[-1], 20, 1)

            ds[-1] = df.MapData(ds[-1], self.load)
            if augment:
                ds[-1] = df.MapDataComponent(ds[-1], LMDBMultiView._augment, 0)

            if (not filter_classes and not filter_ids and num_workers > 1):
                # warning: skipping this is slower when filtering datasets
                #          but epoch counting will be wrong otherwise
                ds[-1] = df.PrefetchDataZMQ(ds[-1], num_workers)
            ds[-1] = df.BatchData(ds[-1], batch_size)

            if reset:
                ds[-1].reset_state()

        self.ds = ds
    def __init__(
        self,
        annotations_path,
        features_path,
        tokenizer,
        bert_model,
        seq_len,
        batch_size=512,
        num_workers=25,
        cache=10000,
        local_rank=-1,
        objective=0,
        num_locs=5,
        add_global_imgfeat=None,
    ):

        if dist.is_available() and local_rank != -1:
            rank = dist.get_rank()
            lmdb_file = os.path.join(
                features_path, "training_feat_part_" + str(rank) + ".lmdb")
        else:
            lmdb_file = os.path.join(features_path, "training_feat_all.lmdb")

            print("Loading from %s" % lmdb_file)

        ds = td.LMDBSerializer.load(lmdb_file, shuffle=False)
        self.num_dataset = len(ds)
        ds = td.LocallyShuffleData(ds, cache)
        caption_path = os.path.join(annotations_path, "caption_train.json")

        preprocess_function = BertPreprocessBatch(
            caption_path,
            tokenizer,
            bert_model,
            seq_len,
            36,
            self.num_dataset,
            objective=objective,
            num_locs=num_locs,
        )

        ds = td.PrefetchData(ds, 5000, 1)
        ds = td.MapData(ds, preprocess_function)
        ds = td.PrefetchDataZMQ(ds, num_workers)
        self.ds = td.BatchData(ds, batch_size)
        self.ds.reset_state()

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.add_global_imgfeat = add_global_imgfeat
        self.num_locs = num_locs
Example #9
def lmdb_dataflow(lmdb_path, batch_size, input_size, output_size, is_training, test_speed=False):
    """load LMDB files, then generate batches??"""
    df = dataflow.LMDBSerializer.load(lmdb_path, shuffle=False)
    size = df.size()
    if is_training:
        df = dataflow.LocallyShuffleData(df, buffer_size=2000)  # buffer_size
        df = dataflow.PrefetchData(df, nr_prefetch=500, nr_proc=1)  # multiprocess the data
    df = BatchData(df, batch_size, input_size, output_size)
    if is_training:
        df = dataflow.PrefetchDataZMQ(df, nr_proc=8)
    df = dataflow.RepeatedData(df, -1)
    if test_speed:
        dataflow.TestDataSpeed(df, size=1000).start()
    df.reset_state()
    return df, size
Example #10
def lmdb_dataflow(lmdb_path, batch_size, input_size, output_size, is_training, test_speed=False):
    df = dataflow.LMDBData(lmdb_path, shuffle=False)
    size = df.size()
    if is_training:
        df = dataflow.LocallyShuffleData(df, buffer_size=2000)
    df = dataflow.PrefetchData(df, nr_prefetch=500, nr_proc=1)
    df = dataflow.LMDBDataPoint(df)
    df = PreprocessData(df, input_size, output_size)
    if is_training:
        df = dataflow.PrefetchDataZMQ(df, nr_proc=8)
    df = dataflow.BatchData(df, batch_size, use_list=True)
    df = dataflow.RepeatedData(df, -1)
    if test_speed:
        dataflow.TestDataSpeed(df, size=1000).start()
    df.reset_state()
    return df, size
Example #11
    def __init__(self, config, dataset_mode):
        """Set the path for Data."""
        self.data_folder = config.data_folder
        self.num_input_points = config.num_input_points
        self.num_gt_points = config.num_gt_points
        self.dataset_mode = dataset_mode

        print(self.data_folder + self.dataset_mode + '.lmdb')

        self.df = dataflow.LMDBSerializer.load(self.data_folder +
                                               self.dataset_mode + '.lmdb',
                                               shuffle=False)
        if config.mode == "train":
            self.df = dataflow.LocallyShuffleData(self.df, buffer_size=2000)
        self.df = dataflow.PrefetchData(self.df, nr_prefetch=500, nr_proc=1)
        self.df.reset_state()
Example #12
    def __init__(self, split, batch_size, set_size):
        if split == 'train':
            lmdb_path = f'{data_path}/ModelNet40_train_1024_middle.lmdb'
        else:
            lmdb_path = f'{data_path}/ModelNet40_test_1024_middle.lmdb'
        df = dataflow.LMDBSerializer.load(lmdb_path, shuffle=False)
        self.size = df.size()
        self.num_batches = self.size // batch_size
        if split == 'train':
            df = dataflow.LocallyShuffleData(df, buffer_size=2000)
            df = dataflow.PrefetchData(df, num_prefetch=500, num_proc=1)
        df = BatchData(df, batch_size, set_size // 8, set_size - set_size // 8)
        if split == 'train':
            df = dataflow.PrefetchDataZMQ(df, num_proc=8)
        df = dataflow.RepeatedData(df, -1)
        df.reset_state()
        self.generator = df.get_data()
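Since the constructor keeps a live generator, batches would typically be pulled with next(); a small sketch, assuming a class name and sizes the snippet does not show:

# hypothetical usage; the class name is not part of the original snippet
loader = ModelNetLoader(split='train', batch_size=16, set_size=1024)
for step in range(loader.num_batches):
    batch = next(loader.generator)  # layout depends on the custom BatchData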
Example #13
def read_data(files=None,
              batch_size=1,
              window=2,
              random_rotation=False,
              repeat=False,
              shuffle_buffer=None,
              num_workers=1,
              cache_data=False):
    print(files[0:20], '...' if len(files) > 20 else '')

    # caching makes only sense if the data is finite
    if cache_data:
        if repeat:
            raise Exception("repeat must be False if cache_data==True")
        if random_rotation:
            raise Exception(
                "random_rotation must be False if cache_data==True")
        if num_workers != 1:
            raise Exception("num_workers must be 1 if cache_data==True")

    df = PhysicsSimDataFlow(
        files=files,
        random_rotation=random_rotation,
        shuffle=True if shuffle_buffer else False,
        window=window,
    )

    if repeat:
        df = dataflow.RepeatedData(df, -1)

    if shuffle_buffer:
        df = dataflow.LocallyShuffleData(df, shuffle_buffer)

    if num_workers > 1:
        df = dataflow.MultiProcessRunnerZMQ(df, num_proc=num_workers)

    df = dataflow.BatchData(df, batch_size=batch_size, use_list=True)

    if cache_data:
        df = dataflow.CacheData(df)

    df.reset_state()
    return df
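A minimal calling sketch for read_data; the file list is a placeholder and the per-batch layout depends on PhysicsSimDataFlow:

# hypothetical usage; file names are placeholders
df = read_data(files=['sim_0000.msgpack', 'sim_0001.msgpack'],
               batch_size=4, window=2, shuffle_buffer=512, num_workers=2)
for batch in df:
    # each component is a list of length batch_size because use_list=True
    pass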
Example #14
def lmdb_dataflow(lmdb_path,
                  batch_size,
                  num_points,
                  shuffle,
                  task,
                  render=False):
    df = dataflow.LMDBSerializer.load(lmdb_path, shuffle=False)
    size = df.size()
    if render:
        df = VirtualRenderData(df)
    if num_points is not None:
        df = ResampleData(df, num_points, task)
    if shuffle:
        df = dataflow.LocallyShuffleData(df, 1000)
        df = dataflow.PrefetchDataZMQ(df, 8)
    df = dataflow.BatchData(df, batch_size, use_list=True)
    df = dataflow.RepeatedData(df, -1)
    df.reset_state()
    return df, size
Example #15
def lmdb_dataflow(lmdb_path,
                  batch_size,
                  input_size,
                  output_size,
                  is_training,
                  test_speed=False,
                  filter_rate=0):
    df = dataflow.LMDBSerializer.load(lmdb_path, shuffle=False)
    df = dataflow.MapData(df,
                          lambda dp: [item for item in dp] + [random.random()])

    size = df.size()
    print(size)
    if is_training:
        df = dataflow.LocallyShuffleData(df, buffer_size=2000)
        df = dataflow.PrefetchData(df, nr_prefetch=500, nr_proc=1)
    df = BatchData(df, batch_size, input_size, output_size)
    if is_training:
        df = dataflow.PrefetchDataZMQ(df, nr_proc=8)
    df = dataflow.RepeatedData(df, -1)
    if test_speed:
        dataflow.TestDataSpeed(df, size=1000).start()
    df.reset_state()
    return df, size
def data_pipe_3dcnn_block(fmri_files,
                          confound_files,
                          label_matrix,
                          target_name=None,
                          flag_cnn='3d',
                          block_dura=1,
                          hrf_delay=0,
                          batch_size=32,
                          data_type='train',
                          nr_thread=4,
                          buffer_size=10,
                          dataselect_percent=1.0,
                          seed=814,
                          verbose=0):
    assert data_type in ['train', 'val', 'test']
    assert flag_cnn in ['3d', '2d']
    assert fmri_files is not None
    isTrain = data_type == 'train'
    isVal = data_type == 'val'
    isTest = data_type == 'test'

    buffer_size = int(min(len(fmri_files), buffer_size))
    nr_thread = int(min(len(fmri_files), nr_thread))

    ds0 = gen_fmri_file(fmri_files,
                        confound_files,
                        label_matrix,
                        data_type=data_type,
                        seed=seed)

    if target_name is None:
        target_name = np.unique(label_matrix)
    ##Subject_Num, Trial_Num = np.array(label_matrix).shape

    #### build and time the data-loading pipeline
    start_time = time.clock()
    if flag_cnn == '2d':
        ds1 = dataflow.MultiThreadMapData(
            ds0,
            nr_thread=nr_thread,
            map_func=lambda dp: map_load_fmri_image_block(
                dp, target_name, block_dura=block_dura, hrf_delay=hrf_delay),
            buffer_size=buffer_size,
            strict=True)
    elif flag_cnn == '3d':
        ds1 = dataflow.MultiThreadMapData(
            ds0,
            nr_thread=nr_thread,
            map_func=lambda dp: map_load_fmri_image_3d_block(
                dp, target_name, block_dura=block_dura, hrf_delay=hrf_delay),
            buffer_size=buffer_size,
            strict=True)

    ds1 = dataflow.PrefetchData(ds1, buffer_size, 1)  ##1

    ds1 = split_samples(ds1,
                        subject_num=len(fmri_files),
                        batch_size=batch_size,
                        dataselect_percent=dataselect_percent)
    dataflowSize = ds1.size()

    if isTrain:
        if verbose:
            print('%d #Trials/Samples per subject with %d channels in tc' %
                  (ds1.Trial_Num, ds1.Block_dura))
        Trial_Num = ds1.Trial_Num
        ds1 = dataflow.LocallyShuffleData(
            ds1,
            buffer_size=Trial_Num * buffer_size,
            shuffle_interval=Trial_Num * buffer_size // 2)

    ds1 = dataflow.BatchData(ds1, batch_size=batch_size)

    if verbose:
        print('\n\nGenerating dataflow for %s datasets \n' % data_type)
        print('dataflowSize is ' + str(ds0.size()))
        print('Loading data using %d threads with %d buffer_size ... \n' %
              (nr_thread, buffer_size))
        print('prefetch dataflowSize is ' + str(dataflowSize))

        print('Time Usage of loading data in seconds: {} \n'.format(
            time.clock() - start_time))

    if isTrain:
        ds1 = dataflow.PrefetchDataZMQ(ds1, nr_proc=nr_thread)  ##1
    else:
        ds1 = dataflow.PrefetchDataZMQ(ds1, nr_proc=1)  ##1
    ##ds1._reset_once()
    ds1.reset_state()

    for df in ds1.get_data():
        yield (df[0].astype('float32'),
               one_hot(df[1],
                       len(target_name) + 1).astype('uint8'))


###end of tensorpack: multithread
##############################################################
    def __init__(
        self,
        corpus_path,
        tokenizer,
        bert_model,
        seq_len,
        encoding="utf-8",
        visual_target=0,
        hard_negative=False,
        batch_size=512,
        shuffle=False,
        num_workers=25,
        cache=10000,
        drop_last=False,
        cuda=False,
        local_rank=-1,
        objective=0,
        visualization=False,
    ):
        TRAIN_DATASET_SIZE = 3119449

        if dist.is_available() and local_rank != -1:

            num_replicas = dist.get_world_size()
            rank = dist.get_rank()

            lmdb_file = os.path.join(
                corpus_path, "training_feat_part_" + str(rank) + ".lmdb")
        else:
            lmdb_file = os.path.join(corpus_path,
                                     "gqa_resnext152_faster_rcnn_genome.lmdb")
            # lmdb_file = os.path.join(corpus_path, "validation_feat_all.lmdb")

            print("Loading from %s" % lmdb_file)

        ds = td.LMDBSerializer.load(lmdb_file, shuffle=False)
        self.num_dataset = len(ds)
        ds = td.LocallyShuffleData(ds, cache)
        caption_path = os.path.join(corpus_path, "caption_train.json")
        # caption_path = os.path.join(corpus_path, "caption_val.json")

        preprocess_function = BertPreprocessBatch(
            caption_path,
            tokenizer,
            bert_model,
            seq_len,
            36,
            self.num_dataset,
            encoding="utf-8",
            visual_target=visual_target,
            objective=objective,
        )

        ds = td.PrefetchData(ds, 5000, 1)
        ds = td.MapData(ds, preprocess_function)
        # self.ds = td.PrefetchData(ds, 1)
        ds = td.PrefetchDataZMQ(ds, num_workers)
        self.ds = td.BatchData(ds, batch_size)
        # self.ds = ds
        self.ds.reset_state()

        self.batch_size = batch_size
        self.num_workers = num_workers
Example #18
def get_dataflow(files, params, is_training):
    """
    Build a tensorflow Dataset from appropriate tfrecords files.
    :param files: list of file paths corresponding to appropriate tfrecords data
    :param params: parsed arguments
    :param is_training: bool, true for training.
    :return: (nextdata, num_samples).
    nextdata: list of tensorflow ops that produce the next input with the following elements:
    true_states, global_map, init_particles, observations, odometries, is_first_step.
    See House3DTrajData.get_data for definitions.
    num_samples: number of samples that make an epoch
    """

    mapmode = params.mapmode
    obsmode = params.obsmode
    batchsize = params.batchsize
    num_particles = params.num_particles
    trajlen = params.trajlen
    bptt_steps = params.bptt_steps

    # build initial covariance matrix of particles, in pixels and radians
    particle_std = params.init_particles_std.copy()
    particle_std[0] = particle_std[0] / params.map_pixel_in_meters  # convert meters to pixels
    particle_std2 = np.square(particle_std)  # element-wise variance
    init_particles_cov = np.diag(particle_std2[(0, 0, 1), ])  # index is (0, 0, 1)

    df = House3DTrajData(
        files,
        mapmode,
        obsmode,
        trajlen,
        num_particles,
        params.init_particles_distr,
        init_particles_cov,
        seed=(params.seed if params.seed is not None and params.seed > 0 else
              (params.validseed if not is_training else None)))
    # data: true_states, global_map, init_particles, observation, odometry

    # make it a multiple of batchsize
    df = dataflow.FixedSizeData(df,
                                size=(df.size() // batchsize) * batchsize,
                                keep_state=False)

    # shuffle
    if is_training:
        df = dataflow.LocallyShuffleData(
            df, 100 * batchsize)  # buffer_size = 100 * batchsize

    # repeat data for the number of epochs
    df = dataflow.RepeatedData(df, params.epochs)

    # batch
    df = BatchDataWithPad(df, batchsize, padded_indices=(1, ))

    # break trajectory into multiple segments for BPTT training. Augment df with is_first_step indicator
    df = BreakForBPTT(df,
                      timed_indices=(0, 3, 4),
                      trajlen=trajlen,
                      bptt_steps=bptt_steps)
    # data: true_states, global_map, init_particles, observation, odometry, is_first_step

    num_samples = df.size() // params.epochs

    df.reset_state()

    # # test dataflow
    # df = dataflow.TestDataSpeed(dataflow.PrintData(df), 100)
    # df.start()

    obs_ch = {'rgb': 3, 'depth': 1, 'rgb-depth': 4}
    map_ch = {
        'wall': 1,
        'wall-door': 2,
        'wall-roomtype': 10,
        'wall-door-roomtype': 11
    }  # every semantic is a channel
    types = [
        tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.bool
    ]
    sizes = [
        (batchsize, bptt_steps, 3),
        (batchsize, None, None, map_ch[mapmode]),
        (batchsize, num_particles, 3),
        (batchsize, bptt_steps, 56, 56, obs_ch[obsmode]),
        (batchsize, bptt_steps, 3),
        (),
    ]

    # turn it into a tf dataset
    def tuplegen():
        for dp in df.get_data():
            yield tuple(dp)

    dataset = tf.data.Dataset.from_generator(tuplegen, tuple(types),
                                             tuple(sizes))
    iterator = dataset.make_one_shot_iterator()  # only read once
    nextdata = iterator.get_next()

    return nextdata, num_samples
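The docstring above lists the six elements of nextdata; a hedged TF1-style driver loop (the file list and the training step are placeholders) might look like:

# hypothetical training loop around get_dataflow
nextdata, num_samples = get_dataflow(train_files, params, is_training=True)
steps_per_epoch = num_samples // params.batchsize
with tf.Session() as sess:
    for _ in range(steps_per_epoch):
        (true_states, global_map, init_particles,
         observations, odometries, is_first_step) = sess.run(nextdata)
        # run one training step on these numpy arrays here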
def data_pipe_3dcnn_block(fmri_files,
                          confound_files,
                          label_matrix,
                          target_name=None,
                          flag_cnn='3d',
                          block_dura=1,
                          batch_size=32,
                          data_type='train',
                          nr_thread=nr_thread,
                          buffer_size=buffer_size):
    assert data_type in ['train', 'val', 'test']
    assert flag_cnn in ['3d', '2d']
    assert fmri_files is not None
    isTrain = data_type == 'train'
    isVal = data_type == 'val'

    print('\n\nGenerating dataflow for %s datasets \n' % data_type)

    buffer_size = int(min(len(fmri_files), buffer_size))
    nr_thread = int(min(len(fmri_files), nr_thread))

    ds0 = gen_fmri_file(fmri_files,
                        confound_files,
                        label_matrix,
                        data_type=data_type)
    print('dataflowSize is ' + str(ds0.size()))
    print('Loading data using %d threads with %d buffer_size ... \n' %
          (nr_thread, buffer_size))

    if target_name is None:
        target_name = np.unique(label_matrix)
    ##Subject_Num, Trial_Num = np.array(label_matrix).shape

    #### build and time the data-loading pipeline
    start_time = time.clock()
    if flag_cnn == '2d':
        ds1 = dataflow.MultiThreadMapData(
            ds0,
            nr_thread=nr_thread,
            map_func=lambda dp: map_load_fmri_image_block(
                dp, target_name, block_dura=block_dura),
            buffer_size=buffer_size,
            strict=True)
    elif flag_cnn == '3d':
        ds1 = dataflow.MultiThreadMapData(
            ds0,
            nr_thread=nr_thread,
            map_func=lambda dp: map_load_fmri_image_3d_block(
                dp, target_name, block_dura=block_dura),
            buffer_size=buffer_size,
            strict=True)

    ds1 = dataflow.PrefetchData(ds1, buffer_size, 1)

    ds1 = split_samples(ds1)
    print('prefetch dataflowSize is ' + str(ds1.size()))

    if isTrain:
        print('%d #Trials/Samples per subject with %d channels in tc' %
              (ds1.Trial_Num, ds1.Block_dura))
        Trial_Num = ds1.Trial_Num
        #ds1 = dataflow.LocallyShuffleData(ds1, buffer_size=ds1.size() * buffer_size)
        ds1 = dataflow.LocallyShuffleData(
            ds1,
            buffer_size=Trial_Num * buffer_size,
            shuffle_interval=Trial_Num * buffer_size)

    ds1 = dataflow.BatchData(ds1, batch_size=batch_size, remainder=True)
    print('Time Usage of loading data in seconds: {} \n'.format(time.clock() -
                                                                start_time))

    ds1 = dataflow.PrefetchDataZMQ(ds1, nr_proc=1)
    #ds1._reset_once()
    ##ds1.reset_state()
    '''
    for df in ds1.get_data():
        if flag_cnn == '2d':
            yield (df[0].astype('float32'),to_categorical(df[1].astype('int32'), len(target_name)))
        elif flag_cnn == '3d':
            yield (df[0].astype('float32'),to_categorical(df[1].astype('int32'), len(target_name)))
    '''
    return ds1
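Unlike the earlier generator variants, this function returns the dataflow without resetting it, so a caller-side sketch (argument values are placeholders, and the two-component unpacking assumes split_samples yields [data, label] as in the loops above) would be:

# hypothetical caller; reset_state() is intentionally left to the consumer here
ds = data_pipe_3dcnn_block(fmri_files, confound_files, label_matrix,
                           flag_cnn='3d', batch_size=32, data_type='train')
ds.reset_state()
for x, y in ds.get_data():
    # x: fMRI blocks, y: integer labels, both batched by BatchData
    pass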