Example #1
0
 def test_raw_data_visitor(self):
     """Verify RawDataVisitor sequential iteration and seek behavior.

     Checks that:
       * allocating an unallocated partition for rank 2 yields partition 0
         in the Syncing state;
       * full iteration yields consecutive indexes whose example_id equals
         the stringified index (as bytes);
       * seeking past the end raises StopIteration and leaves the visitor
         positioned at the last valid index (199);
       * seeking to a valid index (50) resumes iteration from there.
     """
     manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.etcd, self.data_source)
     manifest, finished = manifest_manager.alloc_unallocated_partition(2)
     self.assertFalse(finished)
     self.assertEqual(manifest.partition_id, 0)
     self.assertEqual(manifest.state, dj_pb.RawDataState.Syncing)
     self.assertEqual(manifest.allocated_rank_id, 2)
     options = customized_options.CustomizedOptions()
     options.set_raw_data_iter('TF_RECORD')
     rdv = raw_data_visitor.RawDataVisitor(self.etcd, self.data_source, 0,
                                           options)
     expected_index = 0
     for (index, item) in rdv:
         self.assertEqual(index, expected_index)
         expected_index += 1
         self.assertEqual(item.example_id, '{}'.format(index).encode())
     try:
         rdv.seek(200)
     except StopIteration:
         # Seeking beyond the data must stop at the last valid index.
         self.assertEqual(rdv.get_current_index(), 199)
     else:
         # BUG FIX: the original else branch was `self.assertFalse(False)`,
         # which always passes, so a seek(200) that failed to raise
         # StopIteration went undetected. Fail explicitly instead.
         self.fail('rdv.seek(200) should have raised StopIteration')
     index, item = rdv.seek(50)
     self.assertEqual(index, 50)
     self.assertEqual(item.example_id, '{}'.format(index).encode())
     expected_index = index + 1
     for (index, item) in rdv:
         self.assertEqual(index, expected_index)
         expected_index += 1
         self.assertEqual(item.example_id, '{}'.format(index).encode())
Example #2
0
 def test_raw_data_visitor(self):
     """Iterate a GZIP-compressed partition through the TF_DATASET iterator.

     Builds a fresh DataSource pointing at ./test/compressed_raw_data,
     clears any stale etcd state under the data source name, then checks
     the visitor yields items with consecutive indexes starting at 0 and
     produces at least one item.
     """
     self.data_source = common_pb.DataSource()
     self.data_source.data_source_meta.name = 'fclh_test'
     self.data_source.data_source_meta.partition_num = 1
     self.data_source.raw_data_dir = "./test/compressed_raw_data"
     self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                        'fedlearner')
     self.etcd.delete_prefix(self.data_source.data_source_meta.name)
     self.assertEqual(self.data_source.data_source_meta.partition_num, 1)
     partition_dir = os.path.join(self.data_source.raw_data_dir,
                                  'partition_0')
     self.assertTrue(gfile.Exists(partition_dir))
     manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.etcd, self.data_source)
     options = customized_options.CustomizedOptions()
     options.set_raw_data_iter('TF_DATASET')
     options.set_compressed_type('GZIP')
     visitor = raw_data_visitor.RawDataVisitor(self.etcd, self.data_source,
                                               0, options)
     next_expected = 0
     for item_index, item in visitor:
         # Log a progress line every 1024 items.
         if item_index > 0 and item_index % 1024 == 0:
             print("{} {} {}".format(item_index, item.example_id,
                                     item.event_time))
         self.assertEqual(item_index, next_expected)
         next_expected += 1
     self.assertGreater(next_expected, 0)
Example #3
0
 def test_data_block_dumper(self):
     """End-to-end check of DataBlockDumperManager.

     Generates follower data blocks plus leader raw data, appends the
     synced metas, dumps them, and then verifies leader and follower
     agree: dumped counts match, every meta round-trips through its
     on-disk meta file, and every data file's records carry exactly the
     example ids recorded in the meta.
     """
     self.generate_follower_data_block()
     self.generate_leader_raw_data()
     options = customized_options.CustomizedOptions()
     options.set_raw_data_iter('TF_RECORD')
     dbd = data_block_dumper.DataBlockDumperManager(
             self.etcd, self.data_source_l, 0, options
         )
     self.assertEqual(dbd.get_partition_id(), 0)
     self.assertEqual(dbd.get_next_data_block_index(), 0)
     for (idx, meta) in enumerate(self.dumped_metas):
         success, next_index = dbd.append_synced_data_block_meta(meta)
         self.assertTrue(success)
         self.assertEqual(next_index, idx + 1)
     self.assertTrue(dbd.need_dump())
     self.assertEqual(dbd.get_next_data_block_index(), len(self.dumped_metas))
     dbd.dump_data_blocks()
     dbm_f = data_block_manager.DataBlockManager(self.data_source_f, 0)
     dbm_l = data_block_manager.DataBlockManager(self.data_source_l, 0)
     self.assertEqual(dbm_f.get_dumped_data_block_num(), len(self.dumped_metas))
     self.assertEqual(dbm_f.get_dumped_data_block_num(),
                      dbm_l.get_dumped_data_block_num())

     def check_meta_file(data_source, meta):
         # The dumped meta file must deserialize back to the exact meta.
         meta_fpath = os.path.join(data_source.data_block_dir, 'partition_0',
                                   meta.block_id + common.DataBlockMetaSuffix)
         mitr = tf.io.tf_record_iterator(meta_fpath)
         parsed = dj_pb.DataBlockMeta()
         parsed.ParseFromString(next(mitr))
         self.assertEqual(parsed, meta)

     def check_data_file(data_source, meta):
         # Every record's example_id must match the meta, in order, and the
         # record count must equal len(meta.example_ids). Using an explicit
         # counter (instead of the loop variable after the loop) also works
         # for an empty record file, where the original raised NameError.
         data_fpath = os.path.join(data_source.data_block_dir, 'partition_0',
                                   meta.block_id + common.DataBlockSuffix)
         record_count = 0
         for record in tf.io.tf_record_iterator(data_fpath):
             example = tf.train.Example()
             example.ParseFromString(record)
             feat = example.features.feature
             self.assertEqual(feat['example_id'].bytes_list.value[0],
                              meta.example_ids[record_count])
             record_count += 1
         self.assertEqual(len(meta.example_ids), record_count)

     for (idx, meta) in enumerate(self.dumped_metas):
         self.assertEqual(meta.data_block_index, idx)
         self.assertEqual(dbm_l.get_data_block_meta_by_index(idx)[0], meta)
         self.assertEqual(dbm_f.get_data_block_meta_by_index(idx)[0], meta)
         # Leader and follower dumps must be consistent, so run the same
         # checks against both data sources (the original duplicated this
         # code verbatim for _l and _f).
         for data_source in (self.data_source_l, self.data_source_f):
             check_meta_file(data_source, meta)
             check_data_file(data_source, meta)
Example #4
0
    def test_all_assembly(self):
        """Run leader and follower data-join workers until both finish.

        Seeds raw data for every partition on both sides, starts one
        worker service per side on localhost, then polls both masters
        until each reports the Finished state, and finally shuts down
        workers and masters.
        """
        partition_num = self.data_source_l.data_source_meta.partition_num
        for partition_id in range(partition_num):
            self.generate_raw_data(
                self.data_source_l, partition_id, 2048, 64,
                'leader_key_partition_{}'.format(partition_id) + ':{}',
                'leader_value_partition_{}'.format(partition_id) + ':{}')
            self.generate_raw_data(
                self.data_source_f, partition_id, 4096, 128,
                'follower_key_partition_{}'.format(partition_id) + ':{}',
                'follower_value_partition_{}'.format(partition_id) + ':{}')

        leader_addr = 'localhost:4161'
        follower_addr = 'localhost:4162'

        options = customized_options.CustomizedOptions()
        options.set_raw_data_iter('TF_RECORD')
        options.set_example_joiner('STREAM_JOINER')
        worker_l = data_join_worker.DataJoinWorkerService(
            int(leader_addr.split(':')[1]), follower_addr,
            self.master_addr_l, 0, self.etcd_name, self.etcd_base_dir_l,
            self.etcd_addrs, options)

        worker_f = data_join_worker.DataJoinWorkerService(
            int(follower_addr.split(':')[1]), leader_addr,
            self.master_addr_f, 0, self.etcd_name, self.etcd_base_dir_f,
            self.etcd_addrs, options)

        worker_l.start()
        worker_f.start()

        while True:
            rsp_l = self.master_client_l.GetDataSourceState(
                self.data_source_l.data_source_meta)
            rsp_f = self.master_client_f.GetDataSourceState(
                self.data_source_f.data_source_meta)
            # While polling, both sides must keep reporting a healthy
            # status with the expected role and data source type.
            self.assertEqual(rsp_l.status.code, 0)
            self.assertEqual(rsp_l.role, common_pb.FLRole.Leader)
            self.assertEqual(rsp_l.data_source_type,
                             common_pb.DataSourceType.Sequential)
            self.assertEqual(rsp_f.status.code, 0)
            self.assertEqual(rsp_f.role, common_pb.FLRole.Follower)
            self.assertEqual(rsp_f.data_source_type,
                             common_pb.DataSourceType.Sequential)
            both_finished = (
                rsp_l.state == common_pb.DataSourceState.Finished
                and rsp_f.state == common_pb.DataSourceState.Finished)
            if both_finished:
                break
            time.sleep(2)

        worker_l.stop()
        worker_f.stop()
        self.master_l.stop()
        self.master_f.stop()
    def test_example_join(self):
        """Join generated raw data against generated example ids with the
        stream joiner, verify the join finished, and print the join rate
        together with the data source's matching-window bounds.
        """
        self.generate_raw_data()
        self.generate_example_id()
        options = customized_options.CustomizedOptions()
        options.set_example_joiner('STREAM_JOINER')
        # NOTE(review): `options` is passed as both the first AND the last
        # argument here — confirm against create_example_joiner's signature;
        # one of the two occurrences may be a copy-paste mistake.
        sei = joiner_impl.create_example_joiner(options, self.etcd,
                                                self.data_source, 0, options)
        sei.join_example()
        self.assertTrue(sei.join_finished())
        dbm = data_block_manager.DataBlockManager(self.data_source, 0)
        data_block_num = dbm.get_dumped_data_block_num()
        # Sum the joined example count across every dumped data block.
        join_count = 0
        for data_block_index in range(data_block_num):
            meta = dbm.get_data_block_meta_by_index(data_block_index)[0]
            self.assertTrue(meta is not None)
            join_count += len(meta.example_ids)

        # `join_count+.0` forces float division (Python 2 semantics).
        print("join rate {}/{}({}), min_matching_window {}, "\
              "max_matching_window {}".format(
              join_count, self.total_index,
              (join_count+.0)/self.total_index,
              self.data_source.data_source_meta.min_matching_window,
              self.data_source.data_source_meta.max_matching_window))
                        type=str,
                        default='TF_RECORD',
                        help='the type for raw data file')
    parser.add_argument('--example_joiner',
                        type=str,
                        default='STREAM_JOINER',
                        help='the method for example joiner')
    parser.add_argument('--compressed_type',
                        type=str,
                        choices=['', 'ZLIB', 'GZIP'],
                        help='the compressed type for raw data')
    parser.add_argument('--tf_eager_mode',
                        action='store_true',
                        help='use the eager_mode for tf')
    args = parser.parse_args()
    cst_options = customized_options.CustomizedOptions()
    # Apply each CLI flag to the customized options only when it was
    # actually supplied (argparse leaves unset optionals as None).
    for flag_value, apply_option in (
            (args.raw_data_iter, cst_options.set_raw_data_iter),
            (args.example_joiner, cst_options.set_example_joiner),
            (args.compressed_type, cst_options.set_compressed_type)):
        if flag_value is not None:
            apply_option(flag_value)
    if args.tf_eager_mode:
        # Import lazily: tensorflow is only pulled in when eager mode
        # is actually requested on the command line.
        import tensorflow
        tensorflow.compat.v1.enable_eager_execution()
    worker_srv = DataJoinWorkerService(args.listen_port, args.peer_addr,
                                       args.master_addr, args.rank_id,
                                       args.etcd_name, args.etcd_base_dir,
                                       args.etcd_addrs, cst_options)
    worker_srv.run()