def test_csv_raw_data_visitor(self):
     self.data_source = common_pb.DataSource()
     self.data_source.data_source_meta.name = 'fclh_test'
     self.data_source.data_source_meta.partition_num = 1
     self.raw_data_dir = path.join(
             path.dirname(path.abspath(__file__)), "../csv_raw_data"
         )
     self.kvstore = DBClient("etcd", True)
     self.kvstore.delete_prefix(common.data_source_kvstore_base_dir(
         self.data_source.data_source_meta.name))
     self.assertEqual(self.data_source.data_source_meta.partition_num, 1)
     partition_dir = path.join(self.raw_data_dir, common.partition_repr(0))
     self.assertTrue(gfile.Exists(partition_dir))
     manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.kvstore, self.data_source)
     manifest_manager.add_raw_data(
             0, [dj_pb.RawDataMeta(file_path=path.join(partition_dir, "test_raw_data.csv"),
                                   timestamp=timestamp_pb2.Timestamp(seconds=3))],
             True)
     raw_data_options = dj_pb.RawDataOptions(
             raw_data_iter='CSV_DICT',
             read_ahead_size=1<<20
         )
     rdm = raw_data_visitor.RawDataManager(self.kvstore, self.data_source, 0)
     self.assertTrue(rdm.check_index_meta_by_process_index(0))
     rdv = raw_data_visitor.RawDataVisitor(self.kvstore, self.data_source, 0,
                                           raw_data_options)
     expected_index = 0
     for (index, item) in rdv:
         if index > 0 and index % 1024 == 0:
             print("{} {}".format(index, item.raw_id))
         self.assertEqual(index, expected_index)
         expected_index += 1
     self.assertEqual(expected_index, 4999)
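
Note: the CSV_DICT iterator yields sequential (index, item) pairs, which is what the expected_index bookkeeping above verifies. A minimal standard-library analogue of that iteration (a sketch only, not the fedlearner implementation; the file name and the 'raw_id' column are assumptions carried over from the test data):

# Hedged sketch: emulate the (index, item) iteration of a CSV_DICT reader
# with csv.DictReader. Assumes the file has a 'raw_id' column.
import csv

def iter_csv_dict(fpath):
    with open(fpath, newline='') as f:
        for index, row in enumerate(csv.DictReader(f)):
            yield index, row

expected_index = 0
for index, item in iter_csv_dict("test_raw_data.csv"):
    if index > 0 and index % 1024 == 0:
        print("{} {}".format(index, item['raw_id']))
    assert index == expected_index
    expected_index += 1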
Example 2
    def setUp(self) -> None:
        logging.getLogger().setLevel(logging.DEBUG)
        self._data_portal_name = 'test_data_portal_job_manager'

        self._kvstore = DBClient('etcd', True)
        self._portal_input_base_dir = './portal_input_dir'
        self._portal_output_base_dir = './portal_output_dir'
        self._raw_data_publish_dir = 'raw_data_publish_dir'
        if gfile.Exists(self._portal_input_base_dir):
            gfile.DeleteRecursively(self._portal_input_base_dir)
        gfile.MakeDirs(self._portal_input_base_dir)

        self._data_fnames = ['1001/{}.data'.format(i) for i in range(100)]
        self._data_fnames_without_success = \
            ['1002/{}.data'.format(i) for i in range(100)]
        self._csv_fnames = ['1003/{}.csv'.format(i) for i in range(100)]
        self._unused_fnames = ['{}.xx'.format(100)]
        self._all_fnames = self._data_fnames + \
                           self._data_fnames_without_success + \
                           self._csv_fnames + self._unused_fnames

        all_fnames_with_success = ['1001/_SUCCESS'] + ['1003/_SUCCESS'] +\
                                  self._all_fnames
        for fname in all_fnames_with_success:
            fpath = os.path.join(self._portal_input_base_dir, fname)
            gfile.MakeDirs(os.path.dirname(fpath))
            with gfile.Open(fpath, "w") as f:
                f.write('xxx')
Example 3
 def _setUpMySQL(self):
     self.kvstore_type = 'etcd'
     self.leader_base_dir = 'bytefl_l'
     self.follower_base_dir = 'bytefl_f'
     os.environ['ETCD_BASE_DIR'] = self.leader_base_dir
     self._kvstore_l = DBClient(self.kvstore_type, True)
     os.environ['ETCD_BASE_DIR'] = self.follower_base_dir
     self._kvstore_f = DBClient(self.kvstore_type, True)
Example 4
 def __init__(self, options, kvstore_type, use_mock_etcd=False):
     self._lock = threading.Condition()
     self._options = options
     kvstore = DBClient(kvstore_type, use_mock_etcd)
     pub_dir = self._options.raw_data_publish_dir
     self._publisher = RawDataPublisher(kvstore, pub_dir)
     self._process_pool_executor = \
             concur_futures.ProcessPoolExecutor(
                     options.offload_processor_number
                 )
     self._callback_submitter = None
     # pre-fork the worker processes before launching the grpc client
     self._process_pool_executor.submit(min, 1, 2).result()
     self._id_batch_fetcher = IdBatchFetcher(kvstore, self._options)
     if self._options.role == common_pb.FLRole.Leader:
         private_key = rsa.PrivateKey.load_pkcs1(options.rsa_key_pem)
         self._psi_rsa_signer = LeaderPsiRsaSigner(
             self._id_batch_fetcher,
             options.batch_processor_options.max_flying_item,
             self._options.max_flying_sign_batch,
             self._options.slow_sign_threshold,
             self._process_pool_executor,
             private_key,
         )
         self._repr = 'leader-' + 'rsa_psi_preprocessor'
     else:
         public_key = rsa.PublicKey.load_pkcs1(options.rsa_key_pem)
         self._callback_submitter = concur_futures.ThreadPoolExecutor(1)
         self._psi_rsa_signer = FollowerPsiRsaSigner(
             self._id_batch_fetcher,
             options.batch_processor_options.max_flying_item,
             self._options.max_flying_sign_batch,
             self._options.max_flying_sign_rpc,
             self._options.sign_rpc_timeout_ms,
             self._options.slow_sign_threshold, self._options.stub_fanout,
             self._process_pool_executor, self._callback_submitter,
             public_key, self._options.leader_rsa_psi_signer_addr)
         self._repr = 'follower-' + 'rsa_psi_preprocessor'
     self._sort_run_dumper = SortRunDumper(options)
     self._sort_run_merger = SortRunMerger(
             dj_pb.SortRunMergerOptions(
                 merger_name='sort_run_merger_'+\
                             partition_repr(options.partition_id),
                 reader_options=dj_pb.RawDataOptions(
                     raw_data_iter=options.writer_options.output_writer,
                     compressed_type=options.writer_options.compressed_type,
                     read_ahead_size=\
                         options.sort_run_merger_read_ahead_buffer,
                     read_batch_size=\
                         options.sort_run_merger_read_batch_size
                 ),
                 writer_options=options.writer_options,
                 output_file_dir=options.output_file_dir,
                 partition_id=options.partition_id
             ),
             self._merger_comparator
         )
     self._produce_item_cnt = 0
     self._comsume_item_cnt = 0
     self._started = False
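
The `submit(min, 1, 2).result()` call above is a warm-up: per the comment, the worker processes should be forked before the gRPC client is created. A self-contained sketch of that pattern using only the standard library (names are illustrative):

# Hedged sketch of the pre-fork warm-up used above: run one trivial task
# and wait for it, so the pool's worker processes exist before anything
# fork-sensitive (such as a gRPC channel) is created in the parent.
import concurrent.futures as concur_futures

def make_prewarmed_pool(num_workers):
    executor = concur_futures.ProcessPoolExecutor(num_workers)
    executor.submit(min, 1, 2).result()  # blocks until a worker is up
    return executor

if __name__ == '__main__':
    pool = make_prewarmed_pool(2)
    print(pool.submit(sum, [1, 2, 3]).result())  # prints 6
    pool.shutdown()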
Example 5
 def __init__(self, listen_port, peer_addr, master_addr, rank_id,
              kvstore_type, options):
     master_channel = make_insecure_channel(
         master_addr,
         ChannelType.INTERNAL,
         options=[('grpc.max_send_message_length', 2**31 - 1),
                  ('grpc.max_receive_message_length', 2**31 - 1)])
     self._master_client = dj_grpc.DataJoinMasterServiceStub(master_channel)
     self._rank_id = rank_id
     kvstore = DBClient(kvstore_type, options.use_mock_etcd)
     data_source = self._sync_data_source()
     self._data_source_name = data_source.data_source_meta.name
     self._listen_port = listen_port
     self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
     peer_channel = make_insecure_channel(
         peer_addr,
         ChannelType.REMOTE,
         options=[('grpc.max_send_message_length', 2**31 - 1),
                  ('grpc.max_receive_message_length', 2**31 - 1)])
     peer_client = dj_grpc.DataJoinWorkerServiceStub(peer_channel)
     self._data_join_worker = DataJoinWorker(peer_client,
                                             self._master_client, rank_id,
                                             kvstore, data_source, options)
     dj_grpc.add_DataJoinWorkerServiceServicer_to_server(
         self._data_join_worker, self._server)
     self._role_repr = "leader" if data_source.role == \
             common_pb.FLRole.Leader else "follower"
     self._server.add_insecure_port('[::]:%d' % listen_port)
     self._server_started = False
Example 6
 def __init__(self, peer_client, data_source_name, kvstore_type, options):
     super(DataJoinMaster, self).__init__()
     self._data_source_name = data_source_name
     kvstore = DBClient(kvstore_type, options.use_mock_etcd)
     self._options = options
     self._fsm = MasterFSM(peer_client, data_source_name, kvstore,
                           self._options.batch_mode)
     self._data_source_meta = \
             self._fsm.get_data_source().data_source_meta
Example 7
 def __init__(self, listen_port, portal_name, kvstore_type, portal_options):
     self._portal_name = portal_name
     self._listen_port = listen_port
     self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
     kvstore = DBClient(kvstore_type, portal_options.use_mock_etcd)
     self._data_portal_master = DataPortalMaster(portal_name, kvstore,
                                                 portal_options)
     dp_grpc.add_DataPortalMasterServiceServicer_to_server(
         self._data_portal_master, self._server)
     self._server.add_insecure_port('[::]:%d' % listen_port)
     self._server_started = False
 def __init__(self, options, part_field, kvstore_type, use_mock_etcd=False):
     self._options = options
     self._part_field = part_field
     kvstore = DBClient(kvstore_type, use_mock_etcd)
     self._raw_data_batch_fetcher = RawDataBatchFetcher(kvstore, options)
     self._next_part_index = None
     self._dumped_process_index = None
     self._flying_writers = []
     self._dumped_file_metas = {}
     self._worker_map = {}
     self._started = False
     self._part_finished = False
     self._cond = threading.Condition()
Example 9
 def __init__(self, options):
     self._fiter = None
     self._index_meta = None
     self._item = None
     self._index = None
     self._iter_failed = False
     self._options = options
     # _options will be None for the example id visitor
     if self._options and self._options.raw_data_cache_type == "disk":
         # use leveldb to manage the disk storage by default
         self._cache_type = DBClient("leveldb", False)
     else:
         self._cache_type = None
    def __init__(self, options):
        self._fiter = None
        self._index_meta = None
        self._item = None
        self._index = None
        self._iter_failed = False
        self._options = options

        try:
            self._validator = Validator(options.validation_ratio)
        except AttributeError:
            self._validator = Validator()
        # _options will be None for the example id visitor
        if self._options and self._options.raw_data_cache_type == "disk":
            # use leveldb to manage the disk storage by default
            self._cache_type = DBClient("leveldb", False)
        else:
            self._cache_type = None
Example 11
class DataBlockVisitor(object):
    def __init__(self, data_source_name, kvstore_type, use_mock_etcd=False):
        self._kvstore = DBClient(kvstore_type, use_mock_etcd)
        self._data_source = retrieve_data_source(self._kvstore,
                                                 data_source_name)

    def LoadDataBlockRepByTimeFrame(self, start_time=None, end_time=None):
        partition_num = self._data_source.data_source_meta.partition_num
        data_block_fnames = {}
        for partition_id in range(0, partition_num):
            data_block_fnames[partition_id] = \
                self._list_data_block(partition_id)
        data_block_reps = {}
        for partition_id, fnames in data_block_fnames.items():
            manifest = self._sync_raw_data_manifest(partition_id)
            for idx, fname in enumerate(fnames):
                check_existed = (idx == len(fnames) - 1)
                rep = self._make_data_block_rep(partition_id, fname,
                                                check_existed)
                filtered = True
                reason = ''
                if rep is None:
                    reason = 'failed to create data block rep'
                elif end_time is not None and rep.end_time > end_time:
                    reason = 'beyond time frame'
                elif start_time is not None and rep.end_time <= start_time:
                    reason = 'before time frame'
                elif self._filter_by_visible(rep.data_block_index, manifest):
                    reason = 'data block not visible yet'
                else:
                    data_block_reps[rep.block_id] = rep
                    filtered = False
                if filtered:
                    logging.debug('skip %s since %s', fname, reason)
        return data_block_reps

    def LoadDataBlockReqByIndex(self, partition_id, data_block_index):
        partition_num = self._data_source.data_source_meta.partition_num
        if partition_id < 0 or partition_id >= partition_num:
            raise IndexError("partition {} out of range".format(partition_id))
        dirpath = self._partition_data_block_dir(partition_id)
        meta_fname = encode_data_block_meta_fname(self._data_source_name(),
                                                  partition_id,
                                                  data_block_index)
        meta_fpath = os.path.join(dirpath, meta_fname)
        meta = load_data_block_meta(meta_fpath)
        manifest = self._sync_raw_data_manifest(partition_id)
        if meta is not None and \
                not self._filter_by_visible(meta.data_block_index, manifest):
            fname = encode_data_block_fname(self._data_source_name(), meta)
            return DataBlockRep(self._data_source_name(), fname, partition_id,
                                dirpath)
        return None

    def LoadDataBlockRepByBlockId(self, block_id):
        block_info = decode_block_id(block_id)
        dbr = self.LoadDataBlockReqByIndex(block_info['partition_id'],
                                           block_info['data_block_index'])
        if dbr:
            assert dbr.block_id == block_id, \
                    "Invalid datablock, expected %s, but got %s), please "\
                    "check datasource!"%(block_id, dbr.block_id)
        return dbr

    def _list_data_block(self, partition_id):
        dirpath = self._partition_data_block_dir(partition_id)
        if gfile.Exists(dirpath) and gfile.IsDirectory(dirpath):
            return [
                f for f in gfile.ListDirectory(dirpath)
                if f.endswith(DataBlockSuffix)
            ]
        return []

    def _partition_data_block_dir(self, partition_id):
        return os.path.join(data_source_data_block_dir(self._data_source),
                            partition_repr(partition_id))

    def _make_data_block_rep(self, partition_id, data_block_fname,
                             check_existed):
        try:
            rep = DataBlockRep(self._data_source.data_source_meta.name,
                               data_block_fname, partition_id,
                               self._partition_data_block_dir(partition_id),
                               check_existed)
        except Exception as e:  # pylint: disable=broad-except
            logging.warning("Failed to create data block rep for %s in"\
                            "partition %d reason %s", data_block_fname,
                            partition_id, e)
            return None
        return rep

    def _data_source_name(self):
        return self._data_source.data_source_meta.name

    def _sync_raw_data_manifest(self, partition_id):
        kvstore_key = partition_manifest_kvstore_key(self._data_source_name(),
                                                     partition_id)
        data = self._kvstore.get_data(kvstore_key)
        assert data is not None, "raw data manifest of partition "\
                                 "{} must exist".format(partition_id)
        return text_format.Parse(data, dj_pb.RawDataManifest())

    def _filter_by_visible(self, index, manifest):
        join_state = manifest.join_example_rep.state
        if self._data_source.role == common_pb.FLRole.Follower and \
                join_state != dj_pb.JoinExampleState.Joined:
            return index > manifest.peer_dumped_index
        return False
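
A hedged usage sketch for the DataBlockVisitor defined above; the data source name, kvstore type, and time window are placeholder values, and the import location of the class is assumed:

# Hypothetical usage of DataBlockVisitor as defined above.
visitor = DataBlockVisitor('my_data_source', 'etcd', use_mock_etcd=True)

# Visible data blocks whose end_time falls in (start_time, end_time].
reps = visitor.LoadDataBlockRepByTimeFrame(start_time=0, end_time=1 << 30)
for block_id, rep in reps.items():
    print(block_id, rep.data_block_index)

# Single-block lookup by block id; returns None when the block is absent
# or not yet visible to this role.
if reps:
    some_id = next(iter(reps))
    print(visitor.LoadDataBlockRepByBlockId(some_id))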
Example 12
                        type=int,
                        default=None,
                        help='Max number of files in a job')
    parser.add_argument('--start_date',
                        type=str,
                        default=None,
                        help='Start date of input data, format %Y%m%d')
    parser.add_argument('--end_date',
                        type=str,
                        default=None,
                        help='End date of input data, format %Y%m%d')
    args = parser.parse_args()
    set_logger()

    use_mock_etcd = (args.kvstore_type == 'mock')
    kvstore = DBClient(args.kvstore_type, use_mock_etcd)
    kvstore_key = common.portal_kvstore_base_dir(args.data_portal_name)
    portal_manifest = kvstore.get_data(kvstore_key)
    data_portal_type = dp_pb.DataPortalType.PSI if \
        args.data_portal_type == 'PSI' else dp_pb.DataPortalType.Streaming
    if portal_manifest is None:
        portal_manifest = dp_pb.DataPortalManifest(
            name=args.data_portal_name,
            data_portal_type=data_portal_type,
            output_partition_num=args.output_partition_num,
            input_file_wildcard=args.input_file_wildcard,
            input_base_dir=args.input_base_dir,
            output_base_dir=args.output_base_dir,
            raw_data_publish_dir=args.raw_data_publish_dir,
            processing_job_id=-1)
        kvstore.set_data(kvstore_key,
                         text_format.MessageToString(portal_manifest))
Example 13
class TestDataPortalJobManager(unittest.TestCase):
    def setUp(self) -> None:
        logging.getLogger().setLevel(logging.DEBUG)
        self._data_portal_name = 'test_data_portal_job_manager'

        self._kvstore = DBClient('etcd', True)
        self._portal_input_base_dir = './portal_input_dir'
        self._portal_output_base_dir = './portal_output_dir'
        self._raw_data_publish_dir = 'raw_data_publish_dir'
        if gfile.Exists(self._portal_input_base_dir):
            gfile.DeleteRecursively(self._portal_input_base_dir)
        gfile.MakeDirs(self._portal_input_base_dir)

        self._data_fnames = ['1001/{}.data'.format(i) for i in range(100)]
        self._data_fnames_without_success = \
            ['1002/{}.data'.format(i) for i in range(100)]
        self._csv_fnames = ['1003/{}.csv'.format(i) for i in range(100)]
        self._unused_fnames = ['{}.xx'.format(100)]
        self._all_fnames = self._data_fnames + \
                           self._data_fnames_without_success + \
                           self._csv_fnames + self._unused_fnames

        all_fnames_with_success = ['1001/_SUCCESS'] + ['1003/_SUCCESS'] +\
                                  self._all_fnames
        for fname in all_fnames_with_success:
            fpath = os.path.join(self._portal_input_base_dir, fname)
            gfile.MakeDirs(os.path.dirname(fpath))
            with gfile.Open(fpath, "w") as f:
                f.write('xxx')

    def tearDown(self) -> None:
        gfile.DeleteRecursively(self._portal_input_base_dir)

    def _list_input_dir(self, portal_options, file_wildcard,
                        target_fnames, max_files_per_job=8000):
        portal_manifest = dp_pb.DataPortalManifest(
            name=self._data_portal_name,
            data_portal_type=dp_pb.DataPortalType.Streaming,
            output_partition_num=4,
            input_file_wildcard=file_wildcard,
            input_base_dir=self._portal_input_base_dir,
            output_base_dir=self._portal_output_base_dir,
            raw_data_publish_dir=self._raw_data_publish_dir,
            processing_job_id=-1,
            next_job_id=0
        )
        self._kvstore.set_data(
            common.portal_kvstore_base_dir(self._data_portal_name),
            text_format.MessageToString(portal_manifest))

        with Timer("DataPortalJobManager initialization"):
            data_portal_job_manager = DataPortalJobManager(
                self._kvstore, self._data_portal_name,
                portal_options.long_running,
                portal_options.check_success_tag,
                portal_options.single_subfolder,
                portal_options.files_per_job_limit,
                max_files_per_job
            )
        portal_job = data_portal_job_manager._sync_processing_job()
        target_fnames.sort()
        fpaths = [os.path.join(self._portal_input_base_dir, f)
                  for f in target_fnames]
        self.assertEqual(len(fpaths), len(portal_job.fpaths))
        for index, fpath in enumerate(fpaths):
            self.assertEqual(fpath, portal_job.fpaths[index])

    def test_list_input_dir(self):
        portal_options = dp_pb.DataPotraMasterlOptions(
            use_mock_etcd=True,
            long_running=False,
            check_success_tag=True,
            single_subfolder=False,
            files_per_job_limit=None
        )
        self._list_input_dir(portal_options, "*.data", self._data_fnames)

    def test_list_input_dir_single_folder(self):
        portal_options = dp_pb.DataPotraMasterlOptions(
            use_mock_etcd=True,
            long_running=False,
            check_success_tag=False,
            single_subfolder=True,
            files_per_job_limit=None,
        )
        self._list_input_dir(
            portal_options, "*.data", self._data_fnames)

    def test_list_input_dir_files_limit(self):
        portal_options = dp_pb.DataPotraMasterlOptions(
            use_mock_etcd=True,
            long_running=False,
            check_success_tag=False,
            single_subfolder=False,
            files_per_job_limit=1,
        )
        self._list_input_dir(
            portal_options, "*.data", self._data_fnames)

        portal_options = dp_pb.DataPotraMasterlOptions(
            use_mock_etcd=True,
            long_running=False,
            check_success_tag=False,
            single_subfolder=False,
            files_per_job_limit=150,
        )
        self._list_input_dir(
            portal_options, "*.data", self._data_fnames)

        portal_options = dp_pb.DataPotraMasterlOptions(
            use_mock_etcd=True,
            long_running=False,
            check_success_tag=False,
            single_subfolder=False,
            files_per_job_limit=200,
        )
        self._list_input_dir(
            portal_options, "*.data",
            self._data_fnames + self._data_fnames_without_success)

    def test_list_input_dir_over_limit(self):
        portal_options = dp_pb.DataPotraMasterlOptions(
            use_mock_etcd=True,
            long_running=False,
            check_success_tag=False,
            single_subfolder=False,
        )
        self._list_input_dir(
            portal_options, "*.data", self._data_fnames, max_files_per_job=100)

        self._list_input_dir(
            portal_options, "*.data",
            self._data_fnames + self._data_fnames_without_success,
            max_files_per_job=200)

    def test_list_input_dir_without_success_check(self):
        portal_options = dp_pb.DataPotraMasterlOptions(
            use_mock_etcd=True,
            long_running=False,
            check_success_tag=False,
            single_subfolder=False,
            files_per_job_limit=None
        )
        self._list_input_dir(
            portal_options, "*.data",
            self._data_fnames + self._data_fnames_without_success)

    def test_list_input_dir_without_wildcard(self):
        portal_options = dp_pb.DataPotraMasterlOptions(
            use_mock_etcd=True,
            long_running=False,
            check_success_tag=True,
            single_subfolder=False,
            files_per_job_limit=None
        )
        self._list_input_dir(
            portal_options, None,
            self._data_fnames + self._csv_fnames)

    def test_list_input_dir_without_wildcard_and_success_check(self):
        portal_options = dp_pb.DataPotraMasterlOptions(
            use_mock_etcd=True,
            long_running=False,
            check_success_tag=False,
            single_subfolder=False,
            files_per_job_limit=None
        )
        self._list_input_dir(portal_options, None, self._all_fnames)
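
The fixtures in setUp create a `_SUCCESS` marker only under `1001/` and `1003/`, which is why `check_success_tag=True` drops the `1002/*.data` files in the tests above. A standard-library sketch of that convention (a hypothetical helper, not the DataPortalJobManager listing code):

# Hedged sketch of the _SUCCESS-tag convention exercised above:
# a file is eligible only if its directory also contains a _SUCCESS marker.
import fnmatch
import os

def list_eligible(base_dir, wildcard, check_success_tag):
    eligible = []
    for dirpath, _, fnames in os.walk(base_dir):
        has_tag = '_SUCCESS' in fnames
        for fname in fnames:
            if fname == '_SUCCESS':
                continue
            if wildcard and not fnmatch.fnmatch(fname, wildcard):
                continue
            if check_success_tag and not has_tag:
                continue
            eligible.append(
                os.path.relpath(os.path.join(dirpath, fname), base_dir))
    return sorted(eligible)

# With the setUp above, this keeps 1001/*.data but not 1002/*.data:
# list_eligible('./portal_input_dir', '*.data', check_success_tag=True)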
    parser.add_argument('--raw_data_sub_dir',
                        type=str,
                        required=True,
                        help='the kvstore base dir to subscribe to new raw data')
    args = parser.parse_args()
    data_source = common_pb.DataSource()
    data_source.data_source_meta.name = args.data_source_name
    data_source.data_source_meta.partition_num = args.partition_num
    data_source.data_source_meta.start_time = args.start_time
    data_source.data_source_meta.end_time = args.end_time
    data_source.data_source_meta.negative_sampling_rate = \
            args.negative_sampling_rate
    if args.role.upper() == 'LEADER':
        data_source.role = common_pb.FLRole.Leader
    else:
        assert args.role.upper() == 'FOLLOWER'
        data_source.role = common_pb.FLRole.Follower
    data_source.output_base_dir = args.output_base_dir
    data_source.raw_data_sub_dir = args.raw_data_sub_dir
    data_source.state = common_pb.DataSourceState.Init
    kvstore = DBClient(args.kvstore_type)
    master_kvstore_key = common.data_source_kvstore_base_dir(
        data_source.data_source_meta.name)
    raw_data = kvstore.get_data(master_kvstore_key)
    if raw_data is None:
        logging.info("data source %s is not existed", args.data_source_name)
        common.commit_data_source(kvstore, data_source)
        logging.info("apply new data source %s", args.data_source_name)
    else:
        logging.info("data source %s has been existed", args.data_source_name)
    def test_api(self):
        logging.getLogger().setLevel(logging.DEBUG)
        os.environ['ETCD_BASE_DIR'] = 'bytefl_l'
        data_source_name = 'test_data_source'
        kvstore_l = DBClient('etcd', True)
        os.environ['ETCD_BASE_DIR'] = 'bytefl_f'
        kvstore_f = DBClient('etcd', True)
        kvstore_l.delete_prefix(
            common.data_source_kvstore_base_dir(data_source_name))
        kvstore_f.delete_prefix(
            common.data_source_kvstore_base_dir(data_source_name))
        data_source_l = common_pb.DataSource()
        data_source_l.role = common_pb.FLRole.Leader
        data_source_l.state = common_pb.DataSourceState.Init
        data_source_l.output_base_dir = "./ds_output_l"
        data_source_f = common_pb.DataSource()
        data_source_f.role = common_pb.FLRole.Follower
        data_source_f.state = common_pb.DataSourceState.Init
        data_source_f.output_base_dir = "./ds_output_f"
        data_source_meta = common_pb.DataSourceMeta()
        data_source_meta.name = data_source_name
        data_source_meta.partition_num = 1
        data_source_meta.start_time = 0
        data_source_meta.end_time = 100000000
        data_source_l.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(kvstore_l, data_source_l)
        data_source_f.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(kvstore_f, data_source_f)

        master_addr_l = 'localhost:4061'
        master_addr_f = 'localhost:4062'
        options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)
        os.environ['ETCD_BASE_DIR'] = 'bytefl_l'
        master_l = data_join_master.DataJoinMasterService(
            int(master_addr_l.split(':')[1]), master_addr_f, data_source_name,
            'etcd', options)
        master_l.start()
        os.environ['ETCD_BASE_DIR'] = 'bytefl_f'
        master_f = data_join_master.DataJoinMasterService(
            int(master_addr_f.split(':')[1]), master_addr_l, data_source_name,
            'etcd', options)
        master_f.start()
        channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
        client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
        channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
        client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)

        while True:
            req_l = dj_pb.DataSourceRequest(
                data_source_meta=data_source_l.data_source_meta)
            req_f = dj_pb.DataSourceRequest(
                data_source_meta=data_source_f.data_source_meta)
            dss_l = client_l.GetDataSourceStatus(req_l)
            dss_f = client_f.GetDataSourceStatus(req_f)
            self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
            self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
            if dss_l.state == common_pb.DataSourceState.Processing and \
                    dss_f.state == common_pb.DataSourceState.Processing:
                break
            else:
                time.sleep(2)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=-1,
            join_example=empty_pb2.Empty())

        rdrsp = client_f.RequestJoinPartition(rdreq)
        self.assertTrue(rdrsp.status.code == 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
        self.assertEqual(rdrsp.manifest.join_example_rep.state,
                         dj_pb.JoinExampleState.Joining)
        self.assertEqual(rdrsp.manifest.partition_id, 0)

        #check idempotent
        rdrsp = client_f.RequestJoinPartition(rdreq)
        self.assertTrue(rdrsp.status.code == 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
        self.assertEqual(rdrsp.manifest.join_example_rep.state,
                         dj_pb.JoinExampleState.Joining)
        self.assertEqual(rdrsp.manifest.partition_id, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            join_example=empty_pb2.Empty())
        rdrsp = client_l.RequestJoinPartition(rdreq)
        self.assertTrue(rdrsp.status.code == 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
        self.assertEqual(rdrsp.manifest.join_example_rep.state,
                         dj_pb.JoinExampleState.Joining)
        self.assertEqual(rdrsp.manifest.partition_id, 0)
        #check idempotent
        rdrsp = client_l.RequestJoinPartition(rdreq)
        self.assertTrue(rdrsp.status.code == 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
        self.assertEqual(rdrsp.manifest.join_example_rep.state,
                         dj_pb.JoinExampleState.Joining)
        self.assertEqual(rdrsp.manifest.partition_id, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=1,
            partition_id=-1,
            sync_example_id=empty_pb2.Empty())
        rdrsp = client_l.RequestJoinPartition(rdreq)
        self.assertEqual(rdrsp.status.code, 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
        self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                         dj_pb.SyncExampleIdState.Syncing)
        self.assertEqual(rdrsp.manifest.partition_id, 0)
        #check idempotent
        rdrsp = client_l.RequestJoinPartition(rdreq)
        self.assertEqual(rdrsp.status.code, 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
        self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                         dj_pb.SyncExampleIdState.Syncing)
        self.assertEqual(rdrsp.manifest.partition_id, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=1,
            partition_id=0,
            sync_example_id=empty_pb2.Empty())
        rdrsp = client_f.RequestJoinPartition(rdreq)
        self.assertEqual(rdrsp.status.code, 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
        self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                         dj_pb.SyncExampleIdState.Syncing)
        self.assertEqual(rdrsp.manifest.partition_id, 0)
        #check idempotent
        rdrsp = client_f.RequestJoinPartition(rdreq)
        self.assertEqual(rdrsp.status.code, 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
        self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                         dj_pb.SyncExampleIdState.Syncing)
        self.assertEqual(rdrsp.manifest.partition_id, 0)

        rdreq1 = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=1,
            partition_id=0,
            sync_example_id=empty_pb2.Empty())

        with self.assertRaises(Exception):
            client_l.FinishJoinPartition(rdreq1)

        rdreq2 = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            join_example=empty_pb2.Empty())
        with self.assertRaises(Exception):
            client_l.FinishJoinPartition(rdreq2)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        manifest_l = client_l.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_l is not None)
        self.assertFalse(manifest_l.finished)
        self.assertEqual(manifest_l.next_process_index, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('timestamp'))
        self.assertEqual(rsp.timestamp.seconds, 0)
        self.assertEqual(rsp.timestamp.nanos, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=False,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='a',
                        timestamp=timestamp_pb2.Timestamp(seconds=3))
                ]))
        rsp = client_l.AddRawData(rdreq)
        self.assertEqual(rsp.code, 0)
        manifest_l = client_l.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_l is not None)
        self.assertFalse(manifest_l.finished)
        self.assertEqual(manifest_l.next_process_index, 1)
        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('timestamp'))
        self.assertEqual(rsp.timestamp.seconds, 3)
        self.assertEqual(rsp.timestamp.nanos, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=False,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='b',
                        timestamp=timestamp_pb2.Timestamp(seconds=5))
                ]))
        rsp = client_l.AddRawData(rdreq)
        self.assertEqual(rsp.code, 0)
        manifest_l = client_l.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_l is not None)
        self.assertFalse(manifest_l.finished)
        self.assertEqual(manifest_l.next_process_index, 2)
        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('timestamp'))
        self.assertEqual(rsp.timestamp.seconds, 5)
        self.assertEqual(rsp.timestamp.nanos, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=True,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='b',
                        timestamp=timestamp_pb2.Timestamp(seconds=5)),
                    dj_pb.RawDataMeta(
                        file_path='a',
                        timestamp=timestamp_pb2.Timestamp(seconds=3))
                ]))
        rsp = client_l.AddRawData(rdreq)
        self.assertEqual(rsp.code, 0)
        manifest_l = client_l.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_l is not None)
        self.assertFalse(manifest_l.finished)
        self.assertEqual(manifest_l.next_process_index, 2)
        rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('timestamp'))
        self.assertEqual(rsp.timestamp.seconds, 5)
        self.assertEqual(rsp.timestamp.nanos, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        manifest_f = client_f.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_f is not None)
        self.assertFalse(manifest_f.finished)
        self.assertEqual(manifest_f.next_process_index, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('timestamp'))
        self.assertEqual(rsp.timestamp.seconds, 0)
        self.assertEqual(rsp.timestamp.nanos, 0)
        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=False,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='a',
                        timestamp=timestamp_pb2.Timestamp(seconds=1))
                ]))
        rsp = client_f.AddRawData(rdreq)
        self.assertEqual(rsp.code, 0)
        manifest_f = client_f.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_f is not None)
        self.assertFalse(manifest_f.finished)
        self.assertEqual(manifest_f.next_process_index, 1)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('timestamp'))
        self.assertEqual(rsp.timestamp.seconds, 1)
        self.assertEqual(rsp.timestamp.nanos, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=False,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='b',
                        timestamp=timestamp_pb2.Timestamp(seconds=2))
                ]))
        rsp = client_f.AddRawData(rdreq)
        self.assertEqual(rsp.code, 0)
        manifest_f = client_f.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_f is not None)
        self.assertFalse(manifest_f.finished)
        self.assertEqual(manifest_f.next_process_index, 2)
        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('timestamp'))
        self.assertEqual(rsp.timestamp.seconds, 2)
        self.assertEqual(rsp.timestamp.nanos, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=True,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='a',
                        timestamp=timestamp_pb2.Timestamp(seconds=1)),
                    dj_pb.RawDataMeta(
                        file_path='b',
                        timestamp=timestamp_pb2.Timestamp(seconds=2))
                ]))
        rsp = client_f.AddRawData(rdreq)
        self.assertEqual(rsp.code, 0)
        manifest_f = client_f.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_f is not None)
        self.assertFalse(manifest_f.finished)
        self.assertEqual(manifest_f.next_process_index, 2)
        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('timestamp'))
        self.assertEqual(rsp.timestamp.seconds, 2)
        self.assertEqual(rsp.timestamp.nanos, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            finish_raw_data=empty_pb2.Empty())
        rsp = client_l.FinishRawData(rdreq)
        self.assertEqual(rsp.code, 0)
        #check idempotent
        rsp = client_l.FinishRawData(rdreq)
        self.assertEqual(rsp.code, 0)

        manifest_l = client_l.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_l is not None)
        self.assertTrue(manifest_l.finished)
        self.assertEqual(manifest_l.next_process_index, 2)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=False,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='x',
                        timestamp=timestamp_pb2.Timestamp(seconds=4))
                ]))
        with self.assertRaises(Exception):
            client_l.AddRawData(rdreq)

        with self.assertRaises(Exception):
            client_f.FinishJoinPartition(rdreq2)

        rsp = client_l.FinishJoinPartition(rdreq1)
        self.assertEqual(rsp.code, 0)
        #check idempotent
        rsp = client_l.FinishJoinPartition(rdreq1)
        self.assertEqual(rsp.code, 0)

        rsp = client_f.FinishJoinPartition(rdreq1)
        self.assertEqual(rsp.code, 0)
        #check idempotent
        rsp = client_f.FinishJoinPartition(rdreq1)
        self.assertEqual(rsp.code, 0)

        with self.assertRaises(Exception):
            client_f.FinishJoinPartition(rdreq2)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=0,
            finish_raw_data=empty_pb2.Empty())
        rsp = client_f.FinishRawData(rdreq)
        self.assertEqual(rsp.code, 0)
        #check idempotent
        rsp = client_f.FinishRawData(rdreq)
        self.assertEqual(rsp.code, 0)

        manifest_f = client_f.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_f is not None)
        self.assertTrue(manifest_f.finished)
        self.assertEqual(manifest_f.next_process_index, 2)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=True,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='x',
                        timestamp=timestamp_pb2.Timestamp(seconds=3))
                ]))
        with self.assertRaises(Exception):
            client_f.AddRawData(rdreq)

        rsp = client_f.FinishJoinPartition(rdreq2)
        self.assertEqual(rsp.code, 0)
        #check idempotent
        rsp = client_f.FinishJoinPartition(rdreq2)
        self.assertEqual(rsp.code, 0)

        rsp = client_l.FinishJoinPartition(rdreq2)
        self.assertEqual(rsp.code, 0)
        #check idempotent
        rsp = client_l.FinishJoinPartition(rdreq2)
        self.assertEqual(rsp.code, 0)

        while True:
            req_l = dj_pb.DataSourceRequest(
                data_source_meta=data_source_l.data_source_meta)
            req_f = dj_pb.DataSourceRequest(
                data_source_meta=data_source_f.data_source_meta)
            dss_l = client_l.GetDataSourceStatus(req_l)
            dss_f = client_f.GetDataSourceStatus(req_f)
            self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
            self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
            if dss_l.state == common_pb.DataSourceState.Finished and \
                    dss_f.state == common_pb.DataSourceState.Finished:
                break
            else:
                time.sleep(2)

        master_l.stop()
        master_f.stop()
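
The two polling loops in this test share the same shape; a hedged refactoring sketch that uses only the stubs and messages already built above (the helper name is hypothetical):

# Hedged sketch: poll both masters until their data sources reach the
# target state, as the inline while-loops above do.
import time

def wait_for_state(client_l, client_f, meta_l, meta_f, target_state,
                   interval=2):
    while True:
        dss_l = client_l.GetDataSourceStatus(
            dj_pb.DataSourceRequest(data_source_meta=meta_l))
        dss_f = client_f.GetDataSourceStatus(
            dj_pb.DataSourceRequest(data_source_meta=meta_f))
        if dss_l.state == target_state and dss_f.state == target_state:
            return
        time.sleep(interval)

# e.g. wait_for_state(client_l, client_f,
#                     data_source_l.data_source_meta,
#                     data_source_f.data_source_meta,
#                     common_pb.DataSourceState.Finished)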
Example 16
 def __init__(self, data_source_name, kvstore_type, use_mock_etcd=False):
     self._kvstore = DBClient(kvstore_type, use_mock_etcd)
     self._data_source = retrieve_data_source(self._kvstore,
                                              data_source_name)
Example 17
class RsaPsi(unittest.TestCase):
    def _setUpMySQL(self):
        self.kvstore_type = 'etcd'
        self.leader_base_dir = 'bytefl_l'
        self.follower_base_dir = 'bytefl_f'
        os.environ['ETCD_BASE_DIR'] = self.leader_base_dir
        self._kvstore_l = DBClient(self.kvstore_type, True)
        os.environ['ETCD_BASE_DIR'] = self.follower_base_dir
        self._kvstore_f = DBClient(self.kvstore_type, True)

    def _setUpDataSource(self):
        self._data_source_name = 'test_data_source'
        self._kvstore_l.delete_prefix(
            common.data_source_kvstore_base_dir(self._data_source_name))
        self._kvstore_f.delete_prefix(
            common.data_source_kvstore_base_dir(self._data_source_name))
        self._data_source_l = common_pb.DataSource()
        self._data_source_l.role = common_pb.FLRole.Leader
        self._data_source_l.state = common_pb.DataSourceState.Init
        self._data_source_l.output_base_dir = "./ds_output_l"
        self._raw_data_dir_l = "./raw_data_l"
        self._data_source_l.raw_data_sub_dir = "./raw_data_sub_dir_l"
        self._data_source_f = common_pb.DataSource()
        self._data_source_f.role = common_pb.FLRole.Follower
        self._data_source_f.state = common_pb.DataSourceState.Init
        self._data_source_f.output_base_dir = "./ds_output_f"
        self._raw_data_dir_f = "./raw_data_f"
        self._data_source_f.raw_data_sub_dir = "./raw_data_sub_dir_f"
        data_source_meta = common_pb.DataSourceMeta()
        data_source_meta.name = self._data_source_name
        data_source_meta.partition_num = 4
        data_source_meta.start_time = 0
        data_source_meta.end_time = 100000000
        self._data_source_l.data_source_meta.MergeFrom(data_source_meta)
        self._data_source_f.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(self._kvstore_l, self._data_source_l)
        common.commit_data_source(self._kvstore_f, self._data_source_f)

    def _generate_input_csv(self, cands, base_dir):
        if not gfile.Exists(base_dir):
            gfile.MakeDirs(base_dir)
        fpaths = []
        random.shuffle(cands)
        csv_writers = []
        partition_num = self._data_source_l.data_source_meta.partition_num
        for partition_id in range(partition_num):
            fpath = os.path.join(base_dir,
                                 str(partition_id) + common.RawDataFileSuffix)
            fpaths.append(fpath)
            csv_writers.append(csv_dict_writer.CsvDictWriter(fpath))
        for item in cands:
            partition_id = CityHash32(item) % partition_num
            raw = OrderedDict()
            raw['raw_id'] = item
            raw['feat_0'] = 'leader-' + str((partition_id << 30) + 0) + item
            raw['feat_1'] = 'leader-' + str((partition_id << 30) + 1) + item
            raw['feat_2'] = 'leader-' + str((partition_id << 30) + 2) + item
            csv_writers[partition_id].write(raw)
        for csv_writer in csv_writers:
            csv_writer.close()
        return fpaths

    def _generate_input_tf_record(self, cands, base_dir):
        if not gfile.Exists(base_dir):
            gfile.MakeDirs(base_dir)
        fpaths = []
        random.shuffle(cands)
        tfr_writers = []
        partition_num = self._data_source_l.data_source_meta.partition_num
        for partition_id in range(partition_num):
            fpath = os.path.join(base_dir,
                                 str(partition_id) + common.RawDataFileSuffix)
            fpaths.append(fpath)
            tfr_writers.append(tf.io.TFRecordWriter(fpath))
        for item in cands:
            partition_id = CityHash32(item) % partition_num
            feat = {}
            feat['raw_id'] = tf.train.Feature(bytes_list=tf.train.BytesList(
                value=[item.encode()]))
            f0 = 'follower' + str((partition_id << 30) + 0) + item
            f1 = 'follower' + str((partition_id << 30) + 1) + item
            f2 = 'follower' + str((partition_id << 30) + 2) + item
            feat['feat_0'] = tf.train.Feature(bytes_list=tf.train.BytesList(
                value=[f0.encode()]))
            feat['feat_1'] = tf.train.Feature(bytes_list=tf.train.BytesList(
                value=[f1.encode()]))
            feat['feat_2'] = tf.train.Feature(bytes_list=tf.train.BytesList(
                value=[f2.encode()]))
            example = tf.train.Example(features=tf.train.Features(
                feature=feat))
            tfr_writers[partition_id].write(example.SerializeToString())
        for tfr_writer in tfr_writers:
            tfr_writer.close()
        return fpaths

    def _setUpRsaPsiConf(self):
        self._input_dir_l = './rsa_psi_raw_input_l'
        self._input_dir_f = './rsa_psi_raw_input_f'
        self._pre_processor_ouput_dir_l = './pre_processor_output_dir_l'
        self._pre_processor_ouput_dir_f = './pre_processor_output_dir_f'
        key_dir = path.join(path.dirname(path.abspath(__file__)), '../rsa_key')
        self._rsa_public_key_path = path.join(key_dir, 'rsa_psi.pub')
        self._rsa_private_key_path = path.join(key_dir, 'rsa_psi')
        self._raw_data_pub_dir_l = self._data_source_l.raw_data_sub_dir
        self._raw_data_pub_dir_f = self._data_source_f.raw_data_sub_dir

    def _gen_psi_input_raw_data(self):
        self._intersection_ids = set(
            ['{:09}'.format(i) for i in range(0, 1 << 16) if i % 3 == 0])
        self._rsa_raw_id_l = set([
            '{:09}'.format(i) for i in range(0, 1 << 16) if i % 2 == 0
        ]) | self._intersection_ids
        self._rsa_raw_id_f = set([
            '{:09}'.format(i) for i in range(0, 1 << 16) if i % 2 == 1
        ]) | self._intersection_ids
        self._input_dir_l = './rsa_psi_raw_input_l'
        self._input_dir_f = './rsa_psi_raw_input_f'
        self._psi_raw_data_fpaths_l = self._generate_input_csv(
            list(self._rsa_raw_id_l), self._input_dir_l)
        self._psi_raw_data_fpaths_f = self._generate_input_tf_record(
            list(self._rsa_raw_id_f), self._input_dir_f)

    def _remove_existed_dir(self):
        if gfile.Exists(self._input_dir_l):
            gfile.DeleteRecursively(self._input_dir_l)
        if gfile.Exists(self._input_dir_f):
            gfile.DeleteRecursively(self._input_dir_f)
        if gfile.Exists(self._pre_processor_ouput_dir_l):
            gfile.DeleteRecursively(self._pre_processor_ouput_dir_l)
        if gfile.Exists(self._pre_processor_ouput_dir_f):
            gfile.DeleteRecursively(self._pre_processor_ouput_dir_f)
        if gfile.Exists(self._data_source_l.output_base_dir):
            gfile.DeleteRecursively(self._data_source_l.output_base_dir)
        if gfile.Exists(self._raw_data_dir_l):
            gfile.DeleteRecursively(self._raw_data_dir_l)
        if gfile.Exists(self._data_source_f.output_base_dir):
            gfile.DeleteRecursively(self._data_source_f.output_base_dir)
        if gfile.Exists(self._raw_data_dir_f):
            gfile.DeleteRecursively(self._raw_data_dir_f)

    def _launch_masters(self):
        self._master_addr_l = 'localhost:4061'
        self._master_addr_f = 'localhost:4062'
        master_options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)
        os.environ['ETCD_BASE_DIR'] = self.leader_base_dir
        self._master_l = data_join_master.DataJoinMasterService(
            int(self._master_addr_l.split(':')[1]), self._master_addr_f,
            self._data_source_name, self.kvstore_type, master_options)
        os.environ['ETCD_BASE_DIR'] = self.follower_base_dir
        self._master_f = data_join_master.DataJoinMasterService(
            int(self._master_addr_f.split(':')[1]), self._master_addr_l,
            self._data_source_name, self.kvstore_type, master_options)
        self._master_f.start()
        self._master_l.start()
        channel_l = make_insecure_channel(self._master_addr_l,
                                          ChannelType.INTERNAL)
        self._master_client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
        channel_f = make_insecure_channel(self._master_addr_f,
                                          ChannelType.INTERNAL)
        self._master_client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)

        while True:
            req_l = dj_pb.DataSourceRequest(
                data_source_meta=self._data_source_l.data_source_meta)
            req_f = dj_pb.DataSourceRequest(
                data_source_meta=self._data_source_f.data_source_meta)
            dss_l = self._master_client_l.GetDataSourceStatus(req_l)
            dss_f = self._master_client_f.GetDataSourceStatus(req_f)
            self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
            self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
            if dss_l.state == common_pb.DataSourceState.Processing and \
                    dss_f.state == common_pb.DataSourceState.Processing:
                break
            else:
                time.sleep(2)
        logging.info("masters turn into Processing state")

    def _launch_workers(self):
        worker_options_l = dj_pb.DataJoinWorkerOptions(
            use_mock_etcd=True,
            raw_data_options=dj_pb.RawDataOptions(raw_data_iter='TF_RECORD',
                                                  read_ahead_size=1 << 20,
                                                  read_batch_size=128),
            example_id_dump_options=dj_pb.ExampleIdDumpOptions(
                example_id_dump_interval=1, example_id_dump_threshold=1024),
            example_joiner_options=dj_pb.ExampleJoinerOptions(
                example_joiner='SORT_RUN_JOINER',
                min_matching_window=64,
                max_matching_window=256,
                data_block_dump_interval=30,
                data_block_dump_threshold=1000),
            batch_processor_options=dj_pb.BatchProcessorOptions(
                batch_size=1024, max_flying_item=4096),
            data_block_builder_options=dj_pb.WriterOptions(
                output_writer='CSV_DICT'))
        worker_options_f = dj_pb.DataJoinWorkerOptions(
            use_mock_etcd=True,
            raw_data_options=dj_pb.RawDataOptions(raw_data_iter='CSV_DICT',
                                                  read_ahead_size=1 << 20,
                                                  read_batch_size=128),
            example_id_dump_options=dj_pb.ExampleIdDumpOptions(
                example_id_dump_interval=1, example_id_dump_threshold=1024),
            example_joiner_options=dj_pb.ExampleJoinerOptions(
                example_joiner='SORT_RUN_JOINER',
                min_matching_window=64,
                max_matching_window=256,
                data_block_dump_interval=30,
                data_block_dump_threshold=1000),
            batch_processor_options=dj_pb.BatchProcessorOptions(
                batch_size=1024, max_flying_item=4096),
            data_block_builder_options=dj_pb.WriterOptions(
                output_writer='TF_RECORD'))

        self._worker_addrs_l = [
            'localhost:4161', 'localhost:4162', 'localhost:4163',
            'localhost:4164'
        ]
        self._worker_addrs_f = [
            'localhost:5161', 'localhost:5162', 'localhost:5163',
            'localhost:5164'
        ]
        self._workers_l = []
        self._workers_f = []
        for rank_id in range(4):
            worker_addr_l = self._worker_addrs_l[rank_id]
            worker_addr_f = self._worker_addrs_f[rank_id]
            os.environ['ETCD_BASE_DIR'] = self.leader_base_dir
            self._workers_l.append(
                data_join_worker.DataJoinWorkerService(
                    int(worker_addr_l.split(':')[1]), worker_addr_f,
                    self._master_addr_l, rank_id, self.kvstore_type,
                    worker_options_l))
            os.environ['ETCD_BASE_DIR'] = self.follower_base_dir
            self._workers_f.append(
                data_join_worker.DataJoinWorkerService(
                    int(worker_addr_f.split(':')[1]), worker_addr_l,
                    self._master_addr_f, rank_id, self.kvstore_type,
                    worker_options_f))
        for w in self._workers_l:
            w.start()
        for w in self._workers_f:
            w.start()

    def _launch_rsa_psi_signer(self):
        self._rsa_psi_signer_addr = 'localhost:6171'
        rsa_private_key_pem = None
        with gfile.GFile(self._rsa_private_key_path, 'rb') as f:
            rsa_private_key_pem = f.read()
        rsa_private_key = rsa.PrivateKey.load_pkcs1(rsa_private_key_pem)
        self._rsa_psi_signer = rsa_psi_signer.RsaPsiSigner(
            rsa_private_key, 1, 500)
        self._rsa_psi_signer.start(
            int(self._rsa_psi_signer_addr.split(':')[1]), 512)
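
    # Sketch (not part of the original test): one way the PEM files read by the
    # signer and preprocessors could be produced in _setUpRsaPsiConf, assuming
    # the `rsa` package already imported above; the two path attributes are the
    # ones this class consumes.
    def _gen_rsa_key_pair_sketch(self):
        public_key, private_key = rsa.newkeys(2048)
        with gfile.GFile(self._rsa_public_key_path, 'wb') as f:
            f.write(public_key.save_pkcs1())
        with gfile.GFile(self._rsa_private_key_path, 'wb') as f:
            f.write(private_key.save_pkcs1())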

    def _stop_workers(self):
        for w in self._workers_f:
            w.stop()
        for w in self._workers_l:
            w.stop()

    def _stop_masters(self):
        self._master_f.stop()
        self._master_l.stop()

    def _stop_rsa_psi_signer(self):
        self._rsa_psi_signer.stop()

    def setUp(self):
        self._setUpMySQL()
        self._setUpDataSource()
        self._setUpRsaPsiConf()
        self._remove_existed_dir()
        self._gen_psi_input_raw_data()
        self._launch_masters()
        self._launch_workers()
        self._launch_rsa_psi_signer()

    def _preprocess_rsa_psi_leader(self):
        processors = []
        rsa_key_pem = None
        with gfile.GFile(self._rsa_private_key_path, 'rb') as f:
            rsa_key_pem = f.read()
        for partition_id in range(
                self._data_source_l.data_source_meta.partition_num):
            options = dj_pb.RsaPsiPreProcessorOptions(
                preprocessor_name='leader-rsa-psi-processor',
                role=common_pb.FLRole.Leader,
                rsa_key_pem=rsa_key_pem,
                input_file_paths=[self._psi_raw_data_fpaths_l[partition_id]],
                output_file_dir=self._pre_processor_ouput_dir_l,
                raw_data_publish_dir=self._raw_data_pub_dir_l,
                partition_id=partition_id,
                offload_processor_number=1,
                max_flying_sign_batch=128,
                stub_fanout=2,
                slow_sign_threshold=8,
                sort_run_merger_read_ahead_buffer=1 << 20,
                sort_run_merger_read_batch_size=128,
                batch_processor_options=dj_pb.BatchProcessorOptions(
                    batch_size=1024, max_flying_item=1 << 14),
                input_raw_data=dj_pb.RawDataOptions(raw_data_iter='CSV_DICT',
                                                    read_ahead_size=1 << 20),
                writer_options=dj_pb.WriterOptions(output_writer='TF_RECORD'))
            os.environ['ETCD_BASE_DIR'] = self.leader_base_dir
            processor = rsa_psi_preprocessor.RsaPsiPreProcessor(
                options, self.kvstore_type, True)
            processor.start_process()
            processors.append(processor)
        for processor in processors:
            processor.wait_for_finished()
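
    # Note: unlike the follower below, the leader feeds its input files directly
    # through input_file_paths, so no RawDataPublisher announcement is needed
    # before its preprocessors start.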

    def _preprocess_rsa_psi_follower(self):
        processors = []
        rsa_key_pem = None
        with gfile.GFile(self._rsa_public_key_path, 'rb') as f:
            rsa_key_pem = f.read()
        self._follower_rsa_psi_sub_dir = 'follower_rsa_psi_sub_dir'
        rd_publisher = raw_data_publisher.RawDataPublisher(
            self._kvstore_f, self._follower_rsa_psi_sub_dir)
        for partition_id in range(
                self._data_source_f.data_source_meta.partition_num):
            rd_publisher.publish_raw_data(
                partition_id, [self._psi_raw_data_fpaths_f[partition_id]])
            rd_publisher.finish_raw_data(partition_id)
            options = dj_pb.RsaPsiPreProcessorOptions(
                preprocessor_name='follower-rsa-psi-processor',
                role=common_pb.FLRole.Follower,
                rsa_key_pem=rsa_key_pem,
                input_file_subscribe_dir=self._follower_rsa_psi_sub_dir,
                output_file_dir=self._pre_processor_ouput_dir_f,
                raw_data_publish_dir=self._raw_data_pub_dir_f,
                partition_id=partition_id,
                leader_rsa_psi_signer_addr=self._rsa_psi_signer_addr,
                offload_processor_number=1,
                max_flying_sign_batch=128,
                max_flying_sign_rpc=64,
                sign_rpc_timeout_ms=100000,
                stub_fanout=2,
                slow_sign_threshold=8,
                sort_run_merger_read_ahead_buffer=1 << 20,
                sort_run_merger_read_batch_size=128,
                batch_processor_options=dj_pb.BatchProcessorOptions(
                    batch_size=1024, max_flying_item=1 << 14),
                input_raw_data=dj_pb.RawDataOptions(raw_data_iter='TF_RECORD',
                                                    read_ahead_size=1 << 20),
                writer_options=dj_pb.WriterOptions(output_writer='CSV_DICT'))
            os.environ['ETCD_BASE_DIR'] = self.follower_base_dir
            processor = rsa_psi_preprocessor.RsaPsiPreProcessor(
                options, self.kvstore_type, True)
            processor.start_process()
            processors.append(processor)
        for processor in processors:
            processor.wait_for_finished()

    def test_all_pipeline(self):
        start_tm = time.time()
        self._preprocess_rsa_psi_follower()
        logging.warning("Follower Preprocess cost %d seconds",
                        time.time() - start_tm)
        start_tm = time.time()
        self._preprocess_rsa_psi_leader()
        logging.warning("Leader Preprocess cost %f seconds",
                        time.time() - start_tm)
        while True:
            req_l = dj_pb.DataSourceRequest(
                data_source_meta=self._data_source_l.data_source_meta)
            req_f = dj_pb.DataSourceRequest(
                data_source_meta=self._data_source_f.data_source_meta)
            dss_l = self._master_client_l.GetDataSourceStatus(req_l)
            dss_f = self._master_client_f.GetDataSourceStatus(req_f)
            self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
            self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
            if dss_l.state == common_pb.DataSourceState.Finished and \
                    dss_f.state == common_pb.DataSourceState.Finished:
                break
            else:
                time.sleep(2)
        logging.info("masters turn into Finished state")

    def tearDown(self):
        self._stop_workers()
        self._stop_masters()
        self._stop_rsa_psi_signer()
        self._remove_existed_dir()
Esempio n. 18
0
import etcd3
from fedlearner.common.db_client import DBClient
from fedlearner.common.db_client import get_kvstore_config

MySQL_client = DBClient('mysql')
database, addr, username, password, base_dir = \
    get_kvstore_config('etcd')
(host, port) = addr.split(':')
options = [('grpc.max_send_message_length', 2**31 - 1),
           ('grpc.max_receive_message_length', 2**31 - 1)]
clnt = etcd3.client(host=host, port=port, grpc_options=options)
for (data, key_meta) in clnt.get_prefix('/', sort_order='ascend'):
    # etcd3 yields (value, metadata) pairs; pull the key out of the metadata
    # and decode bytes to str before writing into MySQL.
    key = key_meta.key
    if not isinstance(key, str):
        key = key.decode()
    if not isinstance(data, str):
        data = data.decode()
    MySQL_client.set_data(key, data)
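
# Optional spot check (sketch, not in the original script): read one migrated
# key back through the MySQL-backed client; assumes DBClient exposes a
# get_data() accessor mirroring set_data(), and `sample_key` is hypothetical.
sample_key = base_dir + '/sample_key'
print(sample_key, MySQL_client.get_data(sample_key))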
Esempio n. 19
0
    def test_api(self):
        logging.getLogger().setLevel(logging.DEBUG)
        kvstore_type = 'etcd'
        db_base_dir = 'dp_test'
        os.environ['ETCD_BASE_DIR'] = db_base_dir
        data_portal_name = 'test_data_source'
        kvstore = DBClient(kvstore_type, True)
        kvstore.delete_prefix(db_base_dir)
        portal_input_base_dir = './portal_upload_dir'
        portal_output_base_dir = './portal_output_dir'
        raw_data_publish_dir = 'raw_data_publish_dir'
        portal_manifest = dp_pb.DataPortalManifest(
            name=data_portal_name,
            data_portal_type=dp_pb.DataPortalType.Streaming,
            output_partition_num=4,
            input_file_wildcard="*.done",
            input_base_dir=portal_input_base_dir,
            output_base_dir=portal_output_base_dir,
            raw_data_publish_dir=raw_data_publish_dir,
            processing_job_id=-1,
            next_job_id=0)
        kvstore.set_data(common.portal_kvstore_base_dir(data_portal_name),
                         text_format.MessageToString(portal_manifest))
        if gfile.Exists(portal_input_base_dir):
            gfile.DeleteRecursively(portal_input_base_dir)
        gfile.MakeDirs(portal_input_base_dir)
        all_fnames = ['1001/{}.done'.format(i) for i in range(100)]
        all_fnames.append('{}.xx'.format(100))
        all_fnames.append('1001/_SUCCESS')
        for fname in all_fnames:
            fpath = os.path.join(portal_input_base_dir, fname)
            gfile.MakeDirs(os.path.dirname(fpath))
            with gfile.Open(fpath, "w") as f:
                f.write('xxx')
        portal_master_addr = 'localhost:4061'
        portal_options = dp_pb.DataPotraMasterlOptions(
            use_mock_etcd=True,
            long_running=False,
            check_success_tag=True,
        )
        data_portal_master = DataPortalMasterService(
            int(portal_master_addr.split(':')[1]), data_portal_name,
            kvstore_type, portal_options)
        data_portal_master.start()

        channel = make_insecure_channel(portal_master_addr,
                                        ChannelType.INTERNAL)
        portal_master_cli = dp_grpc.DataPortalMasterServiceStub(channel)
        recv_manifest = portal_master_cli.GetDataPortalManifest(
            empty_pb2.Empty())
        self.assertEqual(recv_manifest.name, portal_manifest.name)
        self.assertEqual(recv_manifest.data_portal_type,
                         portal_manifest.data_portal_type)
        self.assertEqual(recv_manifest.output_partition_num,
                         portal_manifest.output_partition_num)
        self.assertEqual(recv_manifest.input_file_wildcard,
                         portal_manifest.input_file_wildcard)
        self.assertEqual(recv_manifest.input_base_dir,
                         portal_manifest.input_base_dir)
        self.assertEqual(recv_manifest.output_base_dir,
                         portal_manifest.output_base_dir)
        self.assertEqual(recv_manifest.raw_data_publish_dir,
                         portal_manifest.raw_data_publish_dir)
        self.assertEqual(recv_manifest.next_job_id, 1)
        self.assertEqual(recv_manifest.processing_job_id, 0)
        self._check_portal_job(kvstore, all_fnames, portal_manifest, 0)
        mapped_partition = set()
        task_0 = portal_master_cli.RequestNewTask(
            dp_pb.NewTaskRequest(rank_id=0))
        task_0_1 = portal_master_cli.RequestNewTask(
            dp_pb.NewTaskRequest(rank_id=0))
        self.assertEqual(task_0, task_0_1)
        self.assertTrue(task_0.HasField('map_task'))
        mapped_partition.add(task_0.map_task.partition_id)
        self._check_map_task(task_0.map_task, all_fnames,
                             task_0.map_task.partition_id, portal_manifest)
        portal_master_cli.FinishTask(
            dp_pb.FinishTaskRequest(rank_id=0,
                                    partition_id=task_0.map_task.partition_id,
                                    part_state=dp_pb.PartState.kIdMap))
        task_1 = portal_master_cli.RequestNewTask(
            dp_pb.NewTaskRequest(rank_id=0))
        self.assertTrue(task_1.HasField('map_task'))
        mapped_partition.add(task_1.map_task.partition_id)
        self._check_map_task(task_1.map_task, all_fnames,
                             task_1.map_task.partition_id, portal_manifest)

        task_2 = portal_master_cli.RequestNewTask(
            dp_pb.NewTaskRequest(rank_id=1))
        self.assertTrue(task_2.HasField('map_task'))
        mapped_partition.add(task_2.map_task.partition_id)
        self._check_map_task(task_2.map_task, all_fnames,
                             task_2.map_task.partition_id, portal_manifest)

        task_3 = portal_master_cli.RequestNewTask(
            dp_pb.NewTaskRequest(rank_id=2))
        self.assertTrue(task_3.HasField('map_task'))
        mapped_partition.add(task_3.map_task.partition_id)
        self._check_map_task(task_3.map_task, all_fnames,
                             task_3.map_task.partition_id, portal_manifest)

        self.assertEqual(len(mapped_partition),
                         portal_manifest.output_partition_num)

        portal_master_cli.FinishTask(
            dp_pb.FinishTaskRequest(rank_id=0,
                                    partition_id=task_1.map_task.partition_id,
                                    part_state=dp_pb.PartState.kIdMap))

        pending_1 = portal_master_cli.RequestNewTask(
            dp_pb.NewTaskRequest(rank_id=4))
        self.assertTrue(pending_1.HasField('pending'))
        pending_2 = portal_master_cli.RequestNewTask(
            dp_pb.NewTaskRequest(rank_id=3))
        self.assertTrue(pending_2.HasField('pending'))

        portal_master_cli.FinishTask(
            dp_pb.FinishTaskRequest(rank_id=1,
                                    partition_id=task_2.map_task.partition_id,
                                    part_state=dp_pb.PartState.kIdMap))

        portal_master_cli.FinishTask(
            dp_pb.FinishTaskRequest(rank_id=2,
                                    partition_id=task_3.map_task.partition_id,
                                    part_state=dp_pb.PartState.kIdMap))

        reduce_partition = set()
        task_4 = portal_master_cli.RequestNewTask(
            dp_pb.NewTaskRequest(rank_id=0))
        task_4_1 = portal_master_cli.RequestNewTask(
            dp_pb.NewTaskRequest(rank_id=0))
        self.assertEqual(task_4, task_4_1)
        self.assertTrue(task_4.HasField('reduce_task'))
        reduce_partition.add(task_4.reduce_task.partition_id)
        self._check_reduce_task(task_4.reduce_task,
                                task_4.reduce_task.partition_id,
                                portal_manifest)
        task_5 = portal_master_cli.RequestNewTask(
            dp_pb.NewTaskRequest(rank_id=1))
        self.assertTrue(task_5.HasField('reduce_task'))
        reduce_partition.add(task_5.reduce_task.partition_id)
        self._check_reduce_task(task_5.reduce_task,
                                task_5.reduce_task.partition_id,
                                portal_manifest)
        task_6 = portal_master_cli.RequestNewTask(
            dp_pb.NewTaskRequest(rank_id=2))
        self.assertTrue(task_6.HasField('reduce_task'))
        reduce_partition.add(task_6.reduce_task.partition_id)
        self._check_reduce_task(task_6.reduce_task,
                                task_6.reduce_task.partition_id,
                                portal_manifest)
        task_7 = portal_master_cli.RequestNewTask(
            dp_pb.NewTaskRequest(rank_id=3))
        self.assertTrue(task_7.HasField('reduce_task'))
        reduce_partition.add(task_7.reduce_task.partition_id)
        self.assertEqual(len(reduce_partition), 4)
        self._check_reduce_task(task_7.reduce_task,
                                task_7.reduce_task.partition_id,
                                portal_manifest)

        task_8 = portal_master_cli.RequestNewTask(
            dp_pb.NewTaskRequest(rank_id=5))
        self.assertTrue(task_8.HasField('pending'))

        portal_master_cli.FinishTask(
            dp_pb.FinishTaskRequest(
                rank_id=0,
                partition_id=task_4.reduce_task.partition_id,
                part_state=dp_pb.PartState.kEventTimeReduce))
        portal_master_cli.FinishTask(
            dp_pb.FinishTaskRequest(
                rank_id=1,
                partition_id=task_5.reduce_task.partition_id,
                part_state=dp_pb.PartState.kEventTimeReduce))
        portal_master_cli.FinishTask(
            dp_pb.FinishTaskRequest(
                rank_id=2,
                partition_id=task_6.reduce_task.partition_id,
                part_state=dp_pb.PartState.kEventTimeReduce))
        portal_master_cli.FinishTask(
            dp_pb.FinishTaskRequest(
                rank_id=3,
                partition_id=task_7.reduce_task.partition_id,
                part_state=dp_pb.PartState.kEventTimeReduce))

        task_9 = portal_master_cli.RequestNewTask(
            dp_pb.NewTaskRequest(rank_id=5))
        self.assertTrue(task_9.HasField('finished'))

        data_portal_master.stop()
        gfile.DeleteRecursively(portal_input_base_dir)
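        # Sketch (not in the original test): the reduce output directory could
        # be cleaned up the same way as the input directory once the run ends.
        if gfile.Exists(portal_output_base_dir):
            gfile.DeleteRecursively(portal_output_base_dir)
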
    def setUp(self):
        self.kvstore_type = 'etcd'
        self.leader_base_dir = 'bytefl_l'
        self.follower_base_dir = 'bytefl_f'
        data_source_name = 'test_data_source'
        os.environ['ETCD_BASE_DIR'] = self.leader_base_dir
        kvstore_l = DBClient(self.kvstore_type, True)
        os.environ['ETCD_BASE_DIR'] = self.follower_base_dir
        kvstore_f = DBClient(self.kvstore_type, True)
        kvstore_l.delete_prefix(
            common.data_source_kvstore_base_dir(data_source_name))
        kvstore_f.delete_prefix(
            common.data_source_kvstore_base_dir(data_source_name))
        data_source_l = common_pb.DataSource()
        self.raw_data_pub_dir_l = './raw_data_pub_dir_l'
        data_source_l.raw_data_sub_dir = self.raw_data_pub_dir_l
        data_source_l.role = common_pb.FLRole.Leader
        data_source_l.state = common_pb.DataSourceState.Init
        data_source_l.output_base_dir = "./ds_output_l"
        self.raw_data_dir_l = "./raw_data_l"
        data_source_f = common_pb.DataSource()
        self.raw_data_pub_dir_f = './raw_data_pub_dir_f'
        data_source_f.role = common_pb.FLRole.Follower
        data_source_f.raw_data_sub_dir = self.raw_data_pub_dir_f
        data_source_f.state = common_pb.DataSourceState.Init
        data_source_f.output_base_dir = "./ds_output_f"
        self.raw_data_dir_f = "./raw_data_f"
        data_source_meta = common_pb.DataSourceMeta()
        data_source_meta.name = data_source_name
        data_source_meta.partition_num = 2
        data_source_meta.start_time = 0
        data_source_meta.end_time = 100000000
        data_source_l.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(kvstore_l, data_source_l)
        data_source_f.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(kvstore_f, data_source_f)

        self.kvstore_l = kvstore_l
        self.kvstore_f = kvstore_f
        self.data_source_l = data_source_l
        self.data_source_f = data_source_f
        self.data_source_name = data_source_name
        self.raw_data_publisher_l = raw_data_publisher.RawDataPublisher(
            self.kvstore_l, self.raw_data_pub_dir_l)
        self.raw_data_publisher_f = raw_data_publisher.RawDataPublisher(
            self.kvstore_f, self.raw_data_pub_dir_f)
        if gfile.Exists(data_source_l.output_base_dir):
            gfile.DeleteRecursively(data_source_l.output_base_dir)
        if gfile.Exists(self.raw_data_dir_l):
            gfile.DeleteRecursively(self.raw_data_dir_l)
        if gfile.Exists(data_source_f.output_base_dir):
            gfile.DeleteRecursively(data_source_f.output_base_dir)
        if gfile.Exists(self.raw_data_dir_f):
            gfile.DeleteRecursively(self.raw_data_dir_f)

        self.worker_options = dj_pb.DataJoinWorkerOptions(
            use_mock_etcd=True,
            raw_data_options=dj_pb.RawDataOptions(raw_data_iter='TF_RECORD',
                                                  read_ahead_size=1 << 20,
                                                  read_batch_size=128,
                                                  optional_fields=['label']),
            example_id_dump_options=dj_pb.ExampleIdDumpOptions(
                example_id_dump_interval=1, example_id_dump_threshold=1024),
            example_joiner_options=dj_pb.ExampleJoinerOptions(
                example_joiner='STREAM_JOINER',
                min_matching_window=64,
                max_matching_window=256,
                data_block_dump_interval=30,
                data_block_dump_threshold=1000),
            batch_processor_options=dj_pb.BatchProcessorOptions(
                batch_size=512, max_flying_item=2048),
            data_block_builder_options=dj_pb.WriterOptions(
                output_writer='TF_RECORD'))
        self.total_index = 1 << 12