Example #1
    def generator(self):
        """batch dict generator"""
        def worker(filter_id, perm):
            """ multiprocess worker"""
            def func_run():
                """ func_run """
                pid = os.getpid()
                np.random.seed(pid + int(time.time()))
                for batch_examples in self.batch_iter(filter_id, perm):
                    batch_dict = self.batch_fn(batch_examples)
                    yield batch_dict

            return func_run

        # consume a seed
        np.random.rand()
        if self.shuffle:
            perm = np.arange(0, len(self))
            np.random.shuffle(perm)
        else:
            perm = None
        if self.num_workers == 1:
            r = paddle.reader.buffered(worker(0, perm), self.buf_size)
        else:
            worker_pool = [
                worker(wid, perm) for wid in range(self.num_workers)
            ]
            worker = mp_reader.multiprocess_reader(worker_pool,
                                                   use_pipe=True,
                                                   queue_size=1000)
            r = paddle.reader.buffered(worker, self.buf_size)

        for batch in r():
            yield batch
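
All of these examples share one pattern: build one generator closure per worker, fan the closures out across processes with mp_reader.multiprocess_reader, and wrap the merged stream in paddle.reader.buffered for prefetching. Below is a minimal self-contained sketch of that pattern; the pgl.utils.mp_reader import path, the toy data, and the Paddle 1.x-style paddle.reader.buffered API are assumptions for illustration, not taken verbatim from the projects above.

import numpy as np
import paddle
from pgl.utils import mp_reader  # assumed import path for PGL's mp_reader

NUM_WORKERS = 4
DATA = np.arange(100)  # toy stand-in for a real dataset


def worker(worker_id):
    """Each worker yields a disjoint stripe of the data."""
    def func_run():
        for i in range(worker_id, len(DATA), NUM_WORKERS):
            yield {"example": int(DATA[i])}

    return func_run


if __name__ == "__main__":
    worker_pool = [worker(wid) for wid in range(NUM_WORKERS)]
    # Fan the workers out across processes and merge their output.
    merged = mp_reader.multiprocess_reader(worker_pool,
                                           use_pipe=True,
                                           queue_size=1000)
    # Prefetch items so the consumer doesn't wait on the pipe.
    r = paddle.reader.buffered(merged, 100)
    for batch_dict in r():
        print(batch_dict)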
Example #2
File: reader.py  Project: zgsxwsdxg/PGL
def reader():
    """reader"""
    batch_info = list(
        node_batch_iter(node_index, node_label, batch_size=batch_size))
    # Split the batch list into one contiguous block per worker.
    block_size = int(len(batch_info) / num_workers + 1)
    reader_pool = []
    for i in range(num_workers):
        reader_pool.append(
            worker(batch_info[block_size * i:block_size * (i + 1)],
                   graph_wrapper, samples))
    multi_process_sample = mp_reader.multiprocess_reader(reader_pool,
                                                         use_pipe=True,
                                                         queue_size=1000)
    r = parse_to_subgraph(multi_process_sample)
    return paddle.reader.buffered(r, 1000)
Example #3
    def generator(self):
        """batch dict generator"""
        def worker(filter_id, perm):
            """ multiprocess worker"""
            def func_run():
                """ func_run """
                pid = os.getpid()
                #                 np.random.seed(pid + int(time.time()))
                for batch_examples in self.batch_iter(filter_id, perm):
                    try:
                        batch_dict = self.batch_fn(batch_examples)
                    except Exception as e:
                        traceback.print_exc()
                        log.info(traceback.format_exc())
                        log.info(str(e))
                        continue

                    if batch_dict is None:
                        continue
                    yield batch_dict

            return func_run

        perm = None

        if self.num_workers == 1:

            def post_fn():
                # worker(0, perm) returns the generator function;
                # call it, then apply the instance-level post_fn.
                for batch in worker(0, perm)():
                    yield self.post_fn(batch)

            r = paddle.reader.buffered(post_fn, self.buf_size)
        else:
            worker_pool = [
                worker(wid, perm) for wid in range(self.num_workers)
            ]
            worker = mp_reader.multiprocess_reader(worker_pool,
                                                   use_pipe=True,
                                                   queue_size=1000)

            def post_fn():
                for batch in worker():
                    yield self.post_fn(batch)

            r = paddle.reader.buffered(post_fn, self.buf_size)

        for batch in r():
            yield batch
Example #4
    def __iter__(self):
        # random seed will be fixed when using multiprocess,
        # so set seed explicitly every time
        np.random.seed()
        if self.num_workers == 1:
            r = paddle.reader.buffered(_DataLoaderIter(self, 0), self.buf_size)
        else:
            worker_pool = [
                _DataLoaderIter(self, wid) for wid in range(self.num_workers)
            ]
            workers = mp_reader.multiprocess_reader(worker_pool,
                                                    use_pipe=True,
                                                    queue_size=1000)
            r = paddle.reader.buffered(workers, self.buf_size)

        for batch in r():
            yield batch
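
The bare np.random.seed() call here is not redundant: worker processes created by fork inherit the parent's NumPy RNG state, so without an explicit reseed every worker would produce the same "random" sequence. Example #1 works around the same issue with np.random.seed(pid + int(time.time())). A standalone demonstration of the problem follows; it is Unix-only (os.fork) and is a hypothetical sketch, not code from these projects.

import os
import numpy as np

np.random.seed(0)  # parent seeds once before forking

pid = os.fork()  # Unix-only
if pid == 0:
    # The child inherits the parent's RNG state: without the next line
    # it would draw exactly the same number the parent draws below.
    np.random.seed()  # reseed from OS entropy, as Example #4 does
    print("child :", np.random.rand())
    os._exit(0)
else:
    print("parent:", np.random.rand())
    os.waitpid(pid, 0)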
Example #5
File: reader.py  Project: Yelrose/PGL
def reader():
    """reader"""
    batch_info = list(batch_iter(data, batch_size=batch_size))
    log.info("The size of batch:%d" % (len(batch_info)))
    block_size = int(len(batch_info) / num_workers + 1)
    reader_pool = []
    for i in range(num_workers):
        reader_pool.append(
            worker(num_layers,
                   batch_info[block_size * i:block_size * (i + 1)],
                   graph_wrappers, samples, feed_name_list, use_pyreader,
                   graph, predict))
    use_pipe = True
    multi_process_sample = mp_reader.multiprocess_reader(
        reader_pool, use_pipe=use_pipe)
    r = parse_to_subgraph(multi_process_sample)
    if use_pipe:
        return paddle.reader.buffered(r, 5 * num_workers)
    else:
        return r
Example #6
def multiprocess_data_generator(config, dataset):
    """Using multiprocess to generate training data.
    """
    num_sample_workers = config['trainer']['args']['num_sample_workers']

    walkpath_files = [[] for _ in range(num_sample_workers)]
    for idx, f in enumerate(glob.glob(dataset.walk_files)):
        walkpath_files[idx % num_sample_workers].append(f)

    gen_data_pool = [
        dataset.pairs_generator(files) for files in walkpath_files
    ]
    if num_sample_workers == 1:
        gen_data_func = gen_data_pool[0]
    else:
        gen_data_func = mp_reader.multiprocess_reader(gen_data_pool,
                                                      use_pipe=True,
                                                      queue_size=100)

    return gen_data_func
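
The round-robin sharding of walk files above is worth isolating: each of the num_sample_workers generators gets every k-th file, so every file is read by exactly one worker. A toy illustration of that loop, with made-up file names:

num_sample_workers = 3
files = ["walk_0.txt", "walk_1.txt", "walk_2.txt", "walk_3.txt", "walk_4.txt"]

walkpath_files = [[] for _ in range(num_sample_workers)]
for idx, f in enumerate(files):
    walkpath_files[idx % num_sample_workers].append(f)

# walkpath_files is now:
#   [['walk_0.txt', 'walk_3.txt'],
#    ['walk_1.txt', 'walk_4.txt'],
#    ['walk_2.txt']]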
Example #7
File: reader.py  Project: Yelrose/PGL
    def reader():
        """reader"""
        batch_info = list(
            node_batch_iter(node_index, node_label, batch_size=batch_size))
        block_size = int(len(batch_info) / num_workers + 1)
        reader_pool = []
        for i in range(num_workers):
            reader_pool.append(
                worker(batch_info[block_size * i:block_size * (i + 1)], graph,
                       graph_wrapper, samples))

        if len(reader_pool) == 1:
            r = parse_to_subgraph(reader_pool[0], repr(graph_wrapper),
                                  graph.node_feat, with_parent_node_index)
        else:
            multi_process_sample = mp_reader.multiprocess_reader(
                reader_pool, use_pipe=True, queue_size=1000)
            r = parse_to_subgraph(multi_process_sample, repr(graph_wrapper),
                                  graph.node_feat, with_parent_node_index)
        return paddle.reader.buffered(r, num_workers)
Example #8
File: dataloader.py  Project: zzs95/PGL
    def __iter__(self):
        """__iter__"""
        def worker(filter_id):
            def func_run():
                for batch_examples in self.batch_iter(filter_id):
                    batch_dict = self.batch_fn(batch_examples)
                    yield batch_dict

            return func_run

        if self.num_workers == 1:
            r = paddle.reader.buffered(worker(0), self.buf_size)
        else:
            worker_pool = [worker(wid) for wid in range(self.num_workers)]
            worker = mp_reader.multiprocess_reader(worker_pool,
                                                   use_pipe=True,
                                                   queue_size=1000)
            r = paddle.reader.buffered(worker, self.buf_size)

        for batch in r():
            yield batch