def setUp(self):
        """Start leader/follower masters and prepare the worker fixture.

        Commits a two-partition leader and follower data source to etcd,
        starts one DataJoinMasterService per role, polls until both report
        the Processing state, removes output directories left over from a
        previous run and builds the DataJoinWorkerOptions used by the tests.
        """
        etcd_name = 'test_etcd'
        etcd_addrs = 'localhost:2379'
        etcd_base_dir_l = 'byefl_l'
        etcd_base_dir_f = 'byefl_f'
        data_source_name = 'test_data_source'
        etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
        etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
        # Drop any data source committed by a previous run.
        etcd_l.delete_prefix(
            common.data_source_etcd_base_dir(data_source_name))
        etcd_f.delete_prefix(
            common.data_source_etcd_base_dir(data_source_name))
        data_source_l = common_pb.DataSource()
        self.raw_data_pub_dir_l = './raw_data_pub_dir_l'
        data_source_l.raw_data_sub_dir = self.raw_data_pub_dir_l
        data_source_l.role = common_pb.FLRole.Leader
        data_source_l.state = common_pb.DataSourceState.Init
        data_source_l.data_block_dir = "./data_block_l"
        data_source_l.raw_data_dir = "./raw_data_l"
        data_source_l.example_dumped_dir = "./example_dumped_l"
        data_source_f = common_pb.DataSource()
        self.raw_data_pub_dir_f = './raw_data_pub_dir_f'
        data_source_f.role = common_pb.FLRole.Follower
        data_source_f.raw_data_sub_dir = self.raw_data_pub_dir_f
        data_source_f.state = common_pb.DataSourceState.Init
        data_source_f.data_block_dir = "./data_block_f"
        data_source_f.raw_data_dir = "./raw_data_f"
        data_source_f.example_dumped_dir = "./example_dumped_f"
        data_source_meta = common_pb.DataSourceMeta()
        data_source_meta.name = data_source_name
        data_source_meta.partition_num = 2
        data_source_meta.start_time = 0
        data_source_meta.end_time = 100000000
        data_source_l.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(etcd_l, data_source_l)
        data_source_f.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(etcd_f, data_source_f)
        master_options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)

        master_addr_l = 'localhost:4061'
        master_addr_f = 'localhost:4062'
        master_l = data_join_master.DataJoinMasterService(
            int(master_addr_l.split(':')[1]),
            master_addr_f,
            data_source_name,
            etcd_name,
            etcd_base_dir_l,
            etcd_addrs,
            master_options,
        )
        master_l.start()
        master_f = data_join_master.DataJoinMasterService(
            int(master_addr_f.split(':')[1]), master_addr_l, data_source_name,
            etcd_name, etcd_base_dir_f, etcd_addrs, master_options)
        master_f.start()
        channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
        master_client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
        channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
        master_client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)

        # Poll (every 2s, no timeout) until both masters are Processing;
        # each poll also checks that the masters kept their roles.
        while True:
            req_l = dj_pb.DataSourceRequest(
                data_source_meta=data_source_l.data_source_meta)
            req_f = dj_pb.DataSourceRequest(
                data_source_meta=data_source_f.data_source_meta)
            dss_l = master_client_l.GetDataSourceStatus(req_l)
            dss_f = master_client_f.GetDataSourceStatus(req_f)
            self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
            self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
            if dss_l.state == common_pb.DataSourceState.Processing and \
                    dss_f.state == common_pb.DataSourceState.Processing:
                break
            time.sleep(2)

        self.master_client_l = master_client_l
        self.master_client_f = master_client_f
        self.master_addr_l = master_addr_l
        self.master_addr_f = master_addr_f
        self.etcd_l = etcd_l
        self.etcd_f = etcd_f
        self.data_source_l = data_source_l
        self.data_source_f = data_source_f
        self.master_l = master_l
        self.master_f = master_f
        # Plain string: the previous trailing comma made this a 1-tuple.
        self.data_source_name = data_source_name
        self.etcd_name = etcd_name
        self.etcd_addrs = etcd_addrs
        self.etcd_base_dir_l = etcd_base_dir_l
        self.etcd_base_dir_f = etcd_base_dir_f
        self.raw_data_publisher_l = raw_data_publisher.RawDataPublisher(
            self.etcd_l, self.raw_data_pub_dir_l)
        self.raw_data_publisher_f = raw_data_publisher.RawDataPublisher(
            self.etcd_f, self.raw_data_pub_dir_f)

        # Remove output left over from a previous run (same order as before:
        # leader dirs first, then follower dirs).
        for data_source in (data_source_l, data_source_f):
            for stale_dir in (data_source.data_block_dir,
                              data_source.example_dumped_dir,
                              data_source.raw_data_dir):
                if gfile.Exists(stale_dir):
                    gfile.DeleteRecursively(stale_dir)

        self.worker_options = dj_pb.DataJoinWorkerOptions(
            use_mock_etcd=True,
            raw_data_options=dj_pb.RawDataOptions(raw_data_iter='TF_RECORD',
                                                  compressed_type=''),
            example_id_dump_options=dj_pb.ExampleIdDumpOptions(
                example_id_dump_interval=1, example_id_dump_threshold=1024),
            example_joiner_options=dj_pb.ExampleJoinerOptions(
                example_joiner='STREAM_JOINER',
                min_matching_window=64,
                max_matching_window=256,
                data_block_dump_interval=30,
                data_block_dump_threshold=1000),
            batch_processor_options=dj_pb.BatchProcessorOptions(
                batch_size=512, max_flying_item=2048),
            data_block_builder_options=dj_pb.WriterOptions(
                output_writer='TF_RECORD'))

        self.total_index = 1 << 13
Exemple #2
0
    def setUp(self):
        """Commit leader/follower data sources and build the worker fixture.

        Clears and re-commits a two-partition data source for each role in
        etcd, creates one RawDataPublisher per role, deletes stale output and
        raw-data directories, and prepares the DataJoinWorkerOptions shared by
        the tests.
        """
        etcd_name = 'test_etcd'
        etcd_addrs = 'localhost:2379'
        etcd_base_dir_l, etcd_base_dir_f = 'byefl_l', 'byefl_f'
        data_source_name = 'test_data_source'

        etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
        etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
        for client in (etcd_l, etcd_f):
            client.delete_prefix(
                common.data_source_etcd_base_dir(data_source_name))

        self.raw_data_pub_dir_l = './raw_data_pub_dir_l'
        self.raw_data_pub_dir_f = './raw_data_pub_dir_f'
        self.raw_data_dir_l = "./raw_data_l"
        self.raw_data_dir_f = "./raw_data_f"

        shared_meta = common_pb.DataSourceMeta()
        shared_meta.name = data_source_name
        shared_meta.partition_num = 2
        shared_meta.start_time = 0
        shared_meta.end_time = 100000000

        def build_data_source(role, pub_dir, output_dir):
            # Both roles share everything except role, sub dir and output dir.
            source = common_pb.DataSource()
            source.raw_data_sub_dir = pub_dir
            source.role = role
            source.state = common_pb.DataSourceState.Init
            source.output_base_dir = output_dir
            source.data_source_meta.MergeFrom(shared_meta)
            return source

        data_source_l = build_data_source(common_pb.FLRole.Leader,
                                          self.raw_data_pub_dir_l,
                                          "./ds_output_l")
        data_source_f = build_data_source(common_pb.FLRole.Follower,
                                          self.raw_data_pub_dir_f,
                                          "./ds_output_f")
        common.commit_data_source(etcd_l, data_source_l)
        common.commit_data_source(etcd_f, data_source_f)

        self.etcd_l, self.etcd_f = etcd_l, etcd_f
        self.data_source_l, self.data_source_f = data_source_l, data_source_f
        self.data_source_name = data_source_name
        self.etcd_name = etcd_name
        self.etcd_addrs = etcd_addrs
        self.etcd_base_dir_l = etcd_base_dir_l
        self.etcd_base_dir_f = etcd_base_dir_f
        self.raw_data_publisher_l = raw_data_publisher.RawDataPublisher(
            self.etcd_l, self.raw_data_pub_dir_l)
        self.raw_data_publisher_f = raw_data_publisher.RawDataPublisher(
            self.etcd_f, self.raw_data_pub_dir_f)

        stale_dirs = (data_source_l.output_base_dir, self.raw_data_dir_l,
                      data_source_f.output_base_dir, self.raw_data_dir_f)
        for stale_dir in stale_dirs:
            if gfile.Exists(stale_dir):
                gfile.DeleteRecursively(stale_dir)

        raw_data_options = dj_pb.RawDataOptions(
            raw_data_iter='TF_RECORD',
            read_ahead_size=1 << 20,
            read_batch_size=128)
        example_id_dump_options = dj_pb.ExampleIdDumpOptions(
            example_id_dump_interval=1,
            example_id_dump_threshold=1024)
        example_joiner_options = dj_pb.ExampleJoinerOptions(
            example_joiner='STREAM_JOINER',
            min_matching_window=64,
            max_matching_window=256,
            data_block_dump_interval=30,
            data_block_dump_threshold=1000)
        batch_processor_options = dj_pb.BatchProcessorOptions(
            batch_size=512,
            max_flying_item=2048)
        self.worker_options = dj_pb.DataJoinWorkerOptions(
            use_mock_etcd=True,
            raw_data_options=raw_data_options,
            example_id_dump_options=example_id_dump_options,
            example_joiner_options=example_joiner_options,
            batch_processor_options=batch_processor_options,
            data_block_builder_options=dj_pb.WriterOptions(
                output_writer='TF_RECORD'))

        self.total_index = 1 << 12
Exemple #3
0
    def test_api(self):
        """Exercise the master RPCs for one partition through sync and join.

        Writes a one-partition Sequential data source for a leader and a
        follower into etcd, starts a DataJoinMasterService for each and waits
        until both report Processing. It then walks partition 0 through the
        example-id sync phase and the example-join phase with
        RequestJoinPartition / FinishJoinPartition, and finally waits until
        both masters report Finished. The masters are stopped in a
        ``finally`` so that a failed assertion does not leave them running.
        """
        etcd_name = 'test_etcd'
        etcd_addrs = 'localhost:2379'
        etcd_base_dir_l = 'byefl_l'
        etcd_base_dir_f = 'byefl_f'
        data_source_name = 'test_data_source'
        etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l)
        etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f)
        etcd_l.delete_prefix(data_source_name)
        etcd_f.delete_prefix(data_source_name)
        data_source_l = common_pb.DataSource()
        data_source_l.role = common_pb.FLRole.Leader
        data_source_l.state = common_pb.DataSourceState.Init
        data_source_l.data_block_dir = "./data_block_l"
        data_source_l.raw_data_dir = "./raw_data_l"
        data_source_l.example_dumped_dir = "./example_dumped_l"
        data_source_f = common_pb.DataSource()
        data_source_f.role = common_pb.FLRole.Follower
        data_source_f.state = common_pb.DataSourceState.Init
        data_source_f.data_block_dir = "./data_block_f"
        data_source_f.raw_data_dir = "./raw_data_f"
        data_source_f.example_dumped_dir = "./example_dumped_f"
        data_source_meta = common_pb.DataSourceMeta()
        data_source_meta.name = data_source_name
        data_source_meta.partition_num = 1
        data_source_meta.start_time = 0
        data_source_meta.end_time = 100000000
        data_source_meta.min_matching_window = 32
        data_source_meta.max_matching_window = 1024
        data_source_meta.data_source_type = common_pb.DataSourceType.Sequential
        data_source_meta.max_example_in_data_block = 1000
        data_source_l.data_source_meta.MergeFrom(data_source_meta)
        etcd_l.set_data(os.path.join(data_source_name, 'master'),
                        text_format.MessageToString(data_source_l))
        data_source_f.data_source_meta.MergeFrom(data_source_meta)
        etcd_f.set_data(os.path.join(data_source_name, 'master'),
                        text_format.MessageToString(data_source_f))

        master_addr_l = 'localhost:4061'
        master_addr_f = 'localhost:4062'
        master_l = data_join_master.DataJoinMasterService(
            int(master_addr_l.split(':')[1]), master_addr_f, data_source_name,
            etcd_name, etcd_base_dir_l, etcd_addrs)
        master_l.start()
        master_f = data_join_master.DataJoinMasterService(
            int(master_addr_f.split(':')[1]), master_addr_l, data_source_name,
            etcd_name, etcd_base_dir_f, etcd_addrs)
        master_f.start()
        try:
            channel_l = make_insecure_channel(master_addr_l,
                                              ChannelType.INTERNAL)
            client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
            channel_f = make_insecure_channel(master_addr_f,
                                              ChannelType.INTERNAL)
            client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)

            def wait_for_state(expected_state):
                """Poll both masters every 2s until both reach a state.

                Each poll also checks status, role and data source type.
                There is no timeout.
                """
                while True:
                    rsp_l = client_l.GetDataSourceState(
                        data_source_l.data_source_meta)
                    rsp_f = client_f.GetDataSourceState(
                        data_source_f.data_source_meta)
                    self.assertEqual(rsp_l.status.code, 0)
                    self.assertEqual(rsp_l.role, common_pb.FLRole.Leader)
                    self.assertEqual(rsp_l.data_source_type,
                                     common_pb.DataSourceType.Sequential)
                    self.assertEqual(rsp_f.status.code, 0)
                    self.assertEqual(rsp_f.role, common_pb.FLRole.Follower)
                    self.assertEqual(rsp_f.data_source_type,
                                     common_pb.DataSourceType.Sequential)
                    if (rsp_l.state == expected_state
                            and rsp_f.state == expected_state):
                        return
                    time.sleep(2)

            def assert_manifest(rdrsp, expected_state):
                """Check that partition 0 is allocated to rank 0."""
                self.assertEqual(rdrsp.status.code, 0)
                self.assertTrue(rdrsp.HasField('manifest'))
                self.assertEqual(rdrsp.manifest.state, expected_state)
                self.assertEqual(rdrsp.manifest.allocated_rank_id, 0)
                self.assertEqual(rdrsp.manifest.partition_id, 0)

            wait_for_state(common_pb.DataSourceState.Processing)

            # A join request while the partition is still syncing returns OK
            # with neither a manifest nor a finished flag.
            rdreq = dj_pb.RawDataRequest(
                data_source_meta=data_source_f.data_source_meta,
                rank_id=0,
                join_example=dj_pb.JoinExampleRequest(partition_id=-1))
            rdrsp = client_l.RequestJoinPartition(rdreq)
            self.assertEqual(rdrsp.status.code, 0)
            self.assertFalse(rdrsp.HasField('manifest'))
            self.assertFalse(rdrsp.HasField('finished'))

            # --- example-id sync phase ---
            rdreq = dj_pb.RawDataRequest(
                data_source_meta=data_source_l.data_source_meta,
                rank_id=0,
                sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=-1))
            assert_manifest(client_l.RequestJoinPartition(rdreq),
                            dj_pb.Syncing)
            # Repeating the request yields the same allocation.
            assert_manifest(client_l.RequestJoinPartition(rdreq),
                            dj_pb.Syncing)

            rdreq = dj_pb.RawDataRequest(
                data_source_meta=data_source_f.data_source_meta,
                rank_id=0,
                sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=0))
            assert_manifest(client_f.RequestJoinPartition(rdreq),
                            dj_pb.Syncing)

            frreq = dj_pb.FinishRawDataRequest(
                data_source_meta=data_source_l.data_source_meta,
                rank_id=0,
                sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=0))
            frrsp = client_l.FinishJoinPartition(frreq)
            self.assertEqual(frrsp.code, 0)
            # Finishing the same partition again must also succeed
            # (previously this sent the stale RawDataRequest and re-checked
            # the old response).
            frrsp = client_l.FinishJoinPartition(frreq)
            self.assertEqual(frrsp.code, 0)

            frreq = dj_pb.FinishRawDataRequest(
                data_source_meta=data_source_f.data_source_meta,
                rank_id=0,
                sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=0))
            frrsp = client_f.FinishJoinPartition(frreq)
            self.assertEqual(frrsp.code, 0)

            # --- example join phase ---
            rdreq = dj_pb.RawDataRequest(
                data_source_meta=data_source_l.data_source_meta,
                rank_id=0,
                join_example=dj_pb.JoinExampleRequest(partition_id=-1))
            assert_manifest(client_l.RequestJoinPartition(rdreq),
                            dj_pb.Joining)

            rdreq = dj_pb.RawDataRequest(
                data_source_meta=data_source_f.data_source_meta,
                rank_id=0,
                join_example=dj_pb.JoinExampleRequest(partition_id=0))
            assert_manifest(client_f.RequestJoinPartition(rdreq),
                            dj_pb.Joining)

            frreq = dj_pb.FinishRawDataRequest(
                data_source_meta=data_source_l.data_source_meta,
                rank_id=0,
                join_example=dj_pb.JoinExampleRequest(partition_id=0))
            frrsp = client_l.FinishJoinPartition(frreq)
            self.assertEqual(frrsp.code, 0)
            frrsp = client_l.FinishJoinPartition(frreq)
            self.assertEqual(frrsp.code, 0)

            frreq = dj_pb.FinishRawDataRequest(
                data_source_meta=data_source_f.data_source_meta,
                rank_id=0,
                join_example=dj_pb.JoinExampleRequest(partition_id=0))
            frrsp = client_f.FinishJoinPartition(frreq)
            self.assertEqual(frrsp.code, 0)

            wait_for_state(common_pb.DataSourceState.Finished)
        finally:
            master_l.stop()
            master_f.stop()
Exemple #4
0
class RsaPsi(unittest.TestCase):
    def _setUpEtcd(self):
        """Create the leader and follower etcd clients for the fixture.

        Both clients share the etcd name and address and differ only in their
        base directory. The trailing ``True`` is passed to ``EtcdClient``
        unchanged (presumably the mock-etcd switch -- confirm).
        """
        self._etcd_name, self._etcd_addrs = 'test_etcd', 'localhost:2379'
        self._etcd_base_dir_l, self._etcd_base_dir_f = 'byefl_l', 'byefl_f'
        self._etcd_l, self._etcd_f = (
            EtcdClient(self._etcd_name, self._etcd_addrs, base_dir, True)
            for base_dir in (self._etcd_base_dir_l, self._etcd_base_dir_f))

    def _setUpDataSource(self):
        """Reset and commit the leader and follower data sources.

        Both data sources share one four-partition DataSourceMeta and differ
        only in role and in their local directories. Each is committed through
        ``common.commit_data_source`` with that role's etcd client.
        """
        self._data_source_name = 'test_data_source'
        # Clear what a previous run committed. Use the same prefix helper that
        # the other commit_data_source-based fixtures in this file use; the
        # bare data source name did not match that prefix. (Presumably
        # commit_data_source writes under this prefix -- confirm in common.)
        ds_etcd_dir = common.data_source_etcd_base_dir(self._data_source_name)
        self._etcd_l.delete_prefix(ds_etcd_dir)
        self._etcd_f.delete_prefix(ds_etcd_dir)
        self._data_source_l = common_pb.DataSource()
        self._data_source_l.role = common_pb.FLRole.Leader
        self._data_source_l.state = common_pb.DataSourceState.Init
        self._data_source_l.data_block_dir = "./data_block_l"
        self._data_source_l.raw_data_dir = "./raw_data_l"
        self._data_source_l.example_dumped_dir = "./example_dumped_l"
        self._data_source_l.raw_data_sub_dir = "./raw_data_sub_dir_l"
        self._data_source_f = common_pb.DataSource()
        self._data_source_f.role = common_pb.FLRole.Follower
        self._data_source_f.state = common_pb.DataSourceState.Init
        self._data_source_f.data_block_dir = "./data_block_f"
        self._data_source_f.raw_data_dir = "./raw_data_f"
        self._data_source_f.example_dumped_dir = "./example_dumped_f"
        self._data_source_f.raw_data_sub_dir = "./raw_data_sub_dir_f"
        data_source_meta = common_pb.DataSourceMeta()
        data_source_meta.name = self._data_source_name
        data_source_meta.partition_num = 4
        data_source_meta.start_time = 0
        data_source_meta.end_time = 100000000
        self._data_source_l.data_source_meta.MergeFrom(data_source_meta)
        self._data_source_f.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(self._etcd_l, self._data_source_l)
        common.commit_data_source(self._etcd_f, self._data_source_f)

    def _generate_input_csv(self, cands, base_dir):
        """Write raw ids as CSV files, one file per partition.

        Each id goes to partition ``CityHash32(id) % partition_num``, where
        the partition count comes from the leader data source meta. It is
        written as a row holding the raw id and three derived feature columns.

        Args:
            cands: list of raw-id strings. NOTE: it is shuffled in place.
            base_dir: output directory; created if it does not exist.

        Returns:
            The file paths ``<base_dir>/<partition_id>.rd``, indexed by
            partition id.
        """
        if not gfile.Exists(base_dir):
            gfile.MakeDirs(base_dir)
        random.shuffle(cands)
        partition_num = self._data_source_l.data_source_meta.partition_num
        fpaths = [os.path.join(base_dir, str(partition_id) + '.rd')
                  for partition_id in range(partition_num)]
        csv_writers = []
        try:
            for fpath in fpaths:
                csv_writers.append(csv_dict_writer.CsvDictWriter(fpath))
            for item in cands:
                partition_id = CityHash32(item) % partition_num
                raw = OrderedDict()
                raw['raw_id'] = item
                # Feature values embed the partition id in the high bits so
                # that each row's features are distinct per column.
                raw['feat_0'] = str((partition_id << 30) + 0) + item
                raw['feat_1'] = str((partition_id << 30) + 1) + item
                raw['feat_2'] = str((partition_id << 30) + 2) + item
                csv_writers[partition_id].write(raw)
        finally:
            # Close every writer that was opened, even if a write failed.
            for csv_writer in csv_writers:
                csv_writer.close()
        return fpaths

    def _setUpRsaPsiConf(self):
        self._input_dir_l = './rsa_psi_raw_input_l'
        self._input_dir_f = './rsa_psi_raw_input_f'
        self._pre_processor_ouput_dir_l = './pre_processor_output_dir_l'
        self._pre_processor_ouput_dir_f = './pre_processor_output_dir_f'
        key_dir = path.join(path.dirname(path.abspath(__file__)), '../rsa_key')
        self._rsa_public_key_path = path.join(key_dir, 'rsa_psi.pub')
        self._rsa_private_key_path = path.join(key_dir, 'rsa_psi')
        self._raw_data_pub_dir_l = self._data_source_l.raw_data_sub_dir
        self._raw_data_pub_dir_f = self._data_source_f.raw_data_sub_dir

    def _gen_psi_input_raw_data(self):
        self._intersection_ids = set(
            ['{:09}'.format(i) for i in range(0, 1 << 16) if i % 3 == 0])
        self._rsa_raw_id_l = set([
            '{:09}'.format(i) for i in range(0, 1 << 16) if i % 2 == 0
        ]) | self._intersection_ids
        self._rsa_raw_id_f = set([
            '{:09}'.format(i) for i in range(0, 1 << 16) if i % 2 == 1
        ]) | self._intersection_ids
        self._input_dir_l = './rsa_psi_raw_input_l'
        self._input_dir_f = './rsa_psi_raw_input_f'
        self._psi_raw_data_fpaths_l = self._generate_input_csv(
            list(self._rsa_raw_id_l), self._input_dir_l)
        self._psi_raw_data_fpaths_f = self._generate_input_csv(
            list(self._rsa_raw_id_f), self._input_dir_f)

    def _remove_existed_dir(self):
        """Delete any directory a previous run may have left behind.

        Covers the PSI input and pre-processor output dirs, plus the data
        block, raw data and example-dump dirs of both data sources.
        """
        stale_dirs = [
            self._input_dir_l,
            self._input_dir_f,
            self._pre_processor_ouput_dir_l,
            self._pre_processor_ouput_dir_f,
        ]
        for data_source in (self._data_source_l, self._data_source_f):
            stale_dirs.extend([data_source.data_block_dir,
                               data_source.raw_data_dir,
                               data_source.example_dumped_dir])
        for stale_dir in stale_dirs:
            if gfile.Exists(stale_dir):
                gfile.DeleteRecursively(stale_dir)

    def _launch_masters(self):
        """Start leader and follower masters and wait until both Process.

        Each master listens on the port from its own address and is pointed
        at the other one as its peer. The follower is started first. The
        method then polls both masters every 2s (no timeout) until both
        report Processing, checking their roles on every poll.
        """
        self._master_addr_l = 'localhost:4061'
        self._master_addr_f = 'localhost:4062'
        master_options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)

        def build_master(listen_addr, peer_addr, etcd_base_dir):
            return data_join_master.DataJoinMasterService(
                int(listen_addr.split(':')[1]), peer_addr,
                self._data_source_name, self._etcd_name, etcd_base_dir,
                self._etcd_addrs, master_options)

        self._master_l = build_master(self._master_addr_l,
                                      self._master_addr_f,
                                      self._etcd_base_dir_l)
        self._master_f = build_master(self._master_addr_f,
                                      self._master_addr_l,
                                      self._etcd_base_dir_f)
        self._master_f.start()
        self._master_l.start()
        self._master_client_l = dj_grpc.DataJoinMasterServiceStub(
            make_insecure_channel(self._master_addr_l, ChannelType.INTERNAL))
        self._master_client_f = dj_grpc.DataJoinMasterServiceStub(
            make_insecure_channel(self._master_addr_f, ChannelType.INTERNAL))

        processing = common_pb.DataSourceState.Processing
        req_l = dj_pb.DataSourceRequest(
            data_source_meta=self._data_source_l.data_source_meta)
        req_f = dj_pb.DataSourceRequest(
            data_source_meta=self._data_source_f.data_source_meta)
        while True:
            dss_l = self._master_client_l.GetDataSourceStatus(req_l)
            dss_f = self._master_client_f.GetDataSourceStatus(req_f)
            self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
            self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
            if dss_l.state == processing and dss_f.state == processing:
                break
            time.sleep(2)
        logging.info("masters turn into Processing state")

    def _launch_workers(self):
        """Create and start four leader and four follower data-join workers.

        Worker ``rank_id`` on each side is paired with the same rank on the
        other side and registered with its own side's master. All leader
        workers are started before any follower worker.
        """
        raw_data_options = dj_pb.RawDataOptions(raw_data_iter='CSV_DICT',
                                                compressed_type='')
        example_id_dump_options = dj_pb.ExampleIdDumpOptions(
            example_id_dump_interval=1, example_id_dump_threshold=1024)
        example_joiner_options = dj_pb.ExampleJoinerOptions(
            example_joiner='SORT_RUN_JOINER',
            min_matching_window=64,
            max_matching_window=256,
            data_block_dump_interval=30,
            data_block_dump_threshold=1000)
        worker_options = dj_pb.DataJoinWorkerOptions(
            use_mock_etcd=True,
            raw_data_options=raw_data_options,
            example_id_dump_options=example_id_dump_options,
            example_joiner_options=example_joiner_options,
            batch_processor_options=dj_pb.BatchProcessorOptions(
                batch_size=1024, max_flying_item=4096),
            data_block_builder_options=dj_pb.DataBlockBuilderOptions(
                data_block_builder='CSV_DICT_DATABLOCK_BUILDER'))
        self._worker_addrs_l = ['localhost:4161', 'localhost:4162',
                                'localhost:4163', 'localhost:4164']
        self._worker_addrs_f = ['localhost:5161', 'localhost:5162',
                                'localhost:5163', 'localhost:5164']
        self._workers_l, self._workers_f = [], []
        addr_pairs = zip(self._worker_addrs_l, self._worker_addrs_f)
        for rank_id, (addr_l, addr_f) in enumerate(addr_pairs):
            self._workers_l.append(data_join_worker.DataJoinWorkerService(
                int(addr_l.split(':')[1]), addr_f, self._master_addr_l,
                rank_id, self._etcd_name, self._etcd_base_dir_l,
                self._etcd_addrs, worker_options))
            self._workers_f.append(data_join_worker.DataJoinWorkerService(
                int(addr_f.split(':')[1]), addr_l, self._master_addr_f,
                rank_id, self._etcd_name, self._etcd_base_dir_f,
                self._etcd_addrs, worker_options))
        for group in (self._workers_l, self._workers_f):
            for worker in group:
                worker.start()

    def _launch_rsa_psi_signer(self):
        """Load the RSA private key and start the RSA-PSI signer service.

        The key is read as PKCS#1 from ``self._rsa_private_key_path``; the
        signer listens on the port taken from ``self._rsa_psi_signer_addr``.
        """
        self._rsa_psi_signer_addr = 'localhost:6171'
        with gfile.GFile(self._rsa_private_key_path, 'rb') as key_file:
            private_key_pem = key_file.read()
        private_key = rsa.PrivateKey.load_pkcs1(private_key_pem)
        self._rsa_psi_signer = rsa_psi_signer.RsaPsiSigner(
            private_key, 1, 500)
        signer_port = int(self._rsa_psi_signer_addr.split(':')[1])
        self._rsa_psi_signer.start(signer_port, 512)

    def _stop_workers(self):
        for w in self._workers_f:
            w.stop()
        for w in self._workers_l:
            w.stop()

    def _stop_masters(self):
        self._master_f.stop()
        self._master_l.stop()

    def _stop_rsa_psi_signer(self):
        self._rsa_psi_signer.stop()

    def setUp(self):
        self._setUpEtcd()
        self._setUpDataSource()
        self._setUpRsaPsiConf()
        self._remove_existed_dir()
        self._gen_psi_input_raw_data()
        self._launch_masters()
        self._launch_workers()
        self._launch_rsa_psi_signer()

    def _preprocess_rsa_psi_leader(self):
        """Run one leader RSA-PSI preprocessor per partition and wait.

        All processors use the RSA private key. Each one is started as soon
        as it is built, so partitions run concurrently. The method returns
        once every processor reports it has finished.
        """
        with gfile.GFile(self._rsa_private_key_path, 'rb') as key_file:
            private_key_pem = key_file.read()

        def make_options(partition_id):
            return dj_pb.RsaPsiPreProcessorOptions(
                preprocessor_name='leader-rsa-psi-processor',
                role=common_pb.FLRole.Leader,
                rsa_key_pem=private_key_pem,
                input_file_paths=[self._psi_raw_data_fpaths_l[partition_id]],
                output_file_dir=self._pre_processor_ouput_dir_l,
                raw_data_publish_dir=self._raw_data_pub_dir_l,
                partition_id=partition_id,
                offload_processor_number=1,
                max_flying_sign_batch=128,
                stub_fanout=2,
                slow_sign_threshold=8,
                sort_run_merger_read_ahead_buffer=1 << 20,
                batch_processor_options=dj_pb.BatchProcessorOptions(
                    batch_size=1024, max_flying_item=1 << 14))

        partition_num = self._data_source_l.data_source_meta.partition_num
        running = []
        for partition_id in range(partition_num):
            processor = rsa_psi_preprocessor.RsaPsiPreProcessor(
                make_options(partition_id), self._etcd_name,
                self._etcd_addrs, self._etcd_base_dir_l, True)
            processor.start_process()
            running.append(processor)
        for processor in running:
            processor.wait_for_finished()

    def _preprocess_rsa_psi_follower(self):
        """Run one follower RSA-PSI preprocessor per partition and wait.

        All processors use the RSA public key and send signing RPCs to the
        leader signer at ``self._rsa_psi_signer_addr``. Even partitions use
        synchronous RPC mode and odd partitions asynchronous. Each processor
        is started as soon as it is built; the method returns once all have
        finished.
        """
        with gfile.GFile(self._rsa_public_key_path, 'rb') as key_file:
            public_key_pem = key_file.read()

        def make_options(partition_id):
            return dj_pb.RsaPsiPreProcessorOptions(
                preprocessor_name='follower-rsa-psi-processor',
                role=common_pb.FLRole.Follower,
                rsa_key_pem=public_key_pem,
                input_file_paths=[self._psi_raw_data_fpaths_f[partition_id]],
                output_file_dir=self._pre_processor_ouput_dir_f,
                raw_data_publish_dir=self._raw_data_pub_dir_f,
                partition_id=partition_id,
                leader_rsa_psi_signer_addr=self._rsa_psi_signer_addr,
                offload_processor_number=1,
                max_flying_sign_batch=128,
                max_flying_sign_rpc=64,
                sign_rpc_timeout_ms=100000,
                stub_fanout=2,
                slow_sign_threshold=8,
                sort_run_merger_read_ahead_buffer=1 << 20,
                rpc_sync_mode=(partition_id % 2 == 0),
                rpc_thread_pool_size=16,
                batch_processor_options=dj_pb.BatchProcessorOptions(
                    batch_size=1024, max_flying_item=1 << 14))

        partition_num = self._data_source_f.data_source_meta.partition_num
        running = []
        for partition_id in range(partition_num):
            processor = rsa_psi_preprocessor.RsaPsiPreProcessor(
                make_options(partition_id), self._etcd_name,
                self._etcd_addrs, self._etcd_base_dir_f, True)
            processor.start_process()
            running.append(processor)
        for processor in running:
            processor.wait_for_finished()

    def test_all_pipeline(self):
        """Preprocess both sides with RSA-PSI, then wait for both masters
        to report the Finished state.

        The follower preprocessing runs first, then the leader's; the
        wall-clock duration of each phase is logged. Afterwards both masters
        are polled every 2 seconds until both data sources are Finished, and
        on every poll the reported roles are checked.

        NOTE(review): the poll has no deadline, so a master that never
        reaches Finished makes this test hang rather than fail.
        """
        start_tm = time.time()
        self._preprocess_rsa_psi_follower()
        # Both timing logs use %f so fractional seconds are not truncated
        # and the two phases can be compared directly.
        logging.warning("Follower Preprocess cost %f seconds",
                        time.time() - start_tm)
        start_tm = time.time()
        self._preprocess_rsa_psi_leader()
        logging.warning("Leader Preprocess cost %f seconds",
                        time.time() - start_tm)
        # The status requests do not change between polls; build them once.
        req_l = dj_pb.DataSourceRequest(
            data_source_meta=self._data_source_l.data_source_meta)
        req_f = dj_pb.DataSourceRequest(
            data_source_meta=self._data_source_f.data_source_meta)
        finished = common_pb.DataSourceState.Finished
        while True:
            dss_l = self._master_client_l.GetDataSourceStatus(req_l)
            dss_f = self._master_client_f.GetDataSourceStatus(req_f)
            self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
            self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
            if dss_l.state == finished and dss_f.state == finished:
                break
            time.sleep(2)
        logging.info("masters turn into Finished state")

    def tearDown(self):
        """Tear down the fixture: stop workers, then masters, then the
        RSA-PSI signer, and finally remove the directories the test used.
        """
        teardown_steps = (
            self._stop_workers,
            self._stop_masters,
            self._stop_rsa_psi_signer,
            self._remove_existed_dir,
        )
        for step in teardown_steps:
            step()
Example #5
0
    def test_api(self):
        """Exercise the DataJoinMaster RPC API end to end on one partition.

        The test runs in these stages:

        1. Start a leader master (port 4061) and a follower master
           (port 4062), both with ``use_mock_etcd=True``.
        2. Wait until both data sources are Processing.
        3. Drive partition 0 through RequestJoinPartition,
           QueryRawDataManifest, GetRawDataLatestTimeStamp, AddRawData,
           FinishRawData and FinishJoinPartition on both sides. Repeated
           calls must be idempotent, and calls made too early must raise.
        4. Wait until both data sources are Finished.

        Both masters are stopped even if an assertion fails, so a failure
        here does not leave the ports bound for later tests.

        NOTE(review): the state polls have no deadline; a master that never
        changes state makes the test hang rather than fail.
        """
        logging.getLogger().setLevel(logging.DEBUG)
        etcd_name = 'test_etcd'
        etcd_addrs = 'localhost:2379'
        etcd_base_dir_l = 'byefl_l'
        etcd_base_dir_f = 'byefl_f'
        data_source_name = 'test_data_source'
        etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
        etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
        etcd_l.delete_prefix(
            common.data_source_etcd_base_dir(data_source_name))
        etcd_f.delete_prefix(
            common.data_source_etcd_base_dir(data_source_name))
        data_source_l = common_pb.DataSource()
        data_source_l.role = common_pb.FLRole.Leader
        data_source_l.state = common_pb.DataSourceState.Init
        data_source_l.output_base_dir = "./ds_output_l"
        data_source_f = common_pb.DataSource()
        data_source_f.role = common_pb.FLRole.Follower
        data_source_f.state = common_pb.DataSourceState.Init
        data_source_f.output_base_dir = "./ds_output_f"
        data_source_meta = common_pb.DataSourceMeta()
        data_source_meta.name = data_source_name
        data_source_meta.partition_num = 1
        data_source_meta.start_time = 0
        data_source_meta.end_time = 100000000
        data_source_l.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(etcd_l, data_source_l)
        data_source_f.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(etcd_f, data_source_f)
        meta_l = data_source_l.data_source_meta
        meta_f = data_source_f.data_source_meta

        def raw_data_request(meta, rank_id=0, partition_id=0, **kwargs):
            # Build a RawDataRequest; defaults target rank 0, partition 0.
            return dj_pb.RawDataRequest(data_source_meta=meta,
                                        rank_id=rank_id,
                                        partition_id=partition_id,
                                        **kwargs)

        def wait_for_state(client_l, client_f, state):
            # Poll both masters every 2s until both report `state`, checking
            # the reported roles on every poll.
            req_l = dj_pb.DataSourceRequest(data_source_meta=meta_l)
            req_f = dj_pb.DataSourceRequest(data_source_meta=meta_f)
            while True:
                dss_l = client_l.GetDataSourceStatus(req_l)
                dss_f = client_f.GetDataSourceStatus(req_f)
                self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
                self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
                if dss_l.state == state and dss_f.state == state:
                    return
                time.sleep(2)

        def check_request_join_partition(client, req, rep_field, state):
            # Issue the request twice; the second call checks idempotence.
            for _ in range(2):
                rsp = client.RequestJoinPartition(req)
                self.assertEqual(rsp.status.code, 0)
                self.assertTrue(rsp.HasField('manifest'))
                rep = getattr(rsp.manifest, rep_field)
                self.assertEqual(rep.rank_id, req.rank_id)
                self.assertEqual(rep.state, state)
                self.assertEqual(rsp.manifest.partition_id, 0)

        def check_manifest(client, req, finished, next_process_index):
            manifest = client.QueryRawDataManifest(req)
            self.assertTrue(manifest is not None)
            self.assertEqual(manifest.finished, finished)
            self.assertEqual(manifest.next_process_index, next_process_index)

        def check_latest_timestamp(client, req, seconds):
            rsp = client.GetRawDataLatestTimeStamp(req)
            self.assertEqual(rsp.status.code, 0)
            self.assertTrue(rsp.HasField('timestamp'))
            self.assertEqual(rsp.timestamp.seconds, seconds)
            self.assertEqual(rsp.timestamp.nanos, 0)

        def add_raw_data_request(meta, dedup, files):
            # `files` is a sequence of (file_path, timestamp_seconds) pairs.
            return raw_data_request(
                meta,
                added_raw_data_metas=dj_pb.AddedRawDataMetas(
                    dedup=dedup,
                    raw_data_metas=[
                        dj_pb.RawDataMeta(
                            file_path=path,
                            timestamp=timestamp_pb2.Timestamp(seconds=secs))
                        for path, secs in files
                    ]))

        def add_raw_data(client, meta, dedup, files):
            # AddRawData must succeed. The request is returned so follow-up
            # queries can be sent with the same request object.
            req = add_raw_data_request(meta, dedup, files)
            self.assertEqual(client.AddRawData(req).code, 0)
            return req

        def check_ok_twice(rpc, req):
            # Call twice; the second call checks idempotence.
            for _ in range(2):
                self.assertEqual(rpc(req).code, 0)

        master_addr_l = 'localhost:4061'
        master_addr_f = 'localhost:4062'
        options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)
        started_masters = []
        try:
            master_l = data_join_master.DataJoinMasterService(
                int(master_addr_l.split(':')[1]), master_addr_f,
                data_source_name, etcd_name, etcd_base_dir_l, etcd_addrs,
                options)
            master_l.start()
            started_masters.append(master_l)
            master_f = data_join_master.DataJoinMasterService(
                int(master_addr_f.split(':')[1]), master_addr_l,
                data_source_name, etcd_name, etcd_base_dir_f, etcd_addrs,
                options)
            master_f.start()
            started_masters.append(master_f)
            channel_l = make_insecure_channel(master_addr_l,
                                              ChannelType.INTERNAL)
            client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
            channel_f = make_insecure_channel(master_addr_f,
                                              ChannelType.INTERNAL)
            client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)

            wait_for_state(client_l, client_f,
                           common_pb.DataSourceState.Processing)

            # Both sides claim partition 0 for example join (rank 0) and
            # example-id sync (rank 1). partition_id=-1 appears to let the
            # master choose; with partition_num == 1 the answer must be 0.
            check_request_join_partition(
                client_f,
                raw_data_request(meta_f, rank_id=0, partition_id=-1,
                                 join_example=empty_pb2.Empty()),
                'join_example_rep', dj_pb.JoinExampleState.Joining)
            check_request_join_partition(
                client_l,
                raw_data_request(meta_l, rank_id=0, partition_id=0,
                                 join_example=empty_pb2.Empty()),
                'join_example_rep', dj_pb.JoinExampleState.Joining)
            check_request_join_partition(
                client_l,
                raw_data_request(meta_l, rank_id=1, partition_id=-1,
                                 sync_example_id=empty_pb2.Empty()),
                'sync_example_id_rep', dj_pb.SyncExampleIdState.Syncing)
            check_request_join_partition(
                client_f,
                raw_data_request(meta_f, rank_id=1, partition_id=0,
                                 sync_example_id=empty_pb2.Empty()),
                'sync_example_id_rep', dj_pb.SyncExampleIdState.Syncing)

            # Finish requests reused below against both masters. They are
            # built from the leader meta, which equals the follower meta.
            finish_sync_req = raw_data_request(
                meta_l, rank_id=1, sync_example_id=empty_pb2.Empty())
            finish_join_req = raw_data_request(
                meta_l, rank_id=0, join_example=empty_pb2.Empty())

            # Before any FinishRawData, finishing either job must fail.
            with self.assertRaises(Exception):
                client_l.FinishJoinPartition(finish_sync_req)
            with self.assertRaises(Exception):
                client_l.FinishJoinPartition(finish_join_req)

            # Leader raw data: 'a'@3s, 'b'@5s, then a dedup re-add of both,
            # which must not advance the manifest or the latest timestamp.
            query_l = raw_data_request(meta_l)
            check_manifest(client_l, query_l, False, 0)
            check_latest_timestamp(client_l, query_l, 0)
            req = add_raw_data(client_l, meta_l, False, [('a', 3)])
            check_manifest(client_l, req, False, 1)
            check_latest_timestamp(client_l, query_l, 3)
            req = add_raw_data(client_l, meta_l, False, [('b', 5)])
            check_manifest(client_l, req, False, 2)
            check_latest_timestamp(client_l, query_l, 5)
            req = add_raw_data(client_l, meta_l, True, [('b', 5), ('a', 3)])
            check_manifest(client_l, req, False, 2)
            check_latest_timestamp(client_l, req, 5)

            # Follower raw data: 'a'@1s, 'b'@2s, then a dedup re-add.
            query_f = raw_data_request(meta_f)
            check_manifest(client_f, query_f, False, 0)
            check_latest_timestamp(client_f, query_f, 0)
            req = add_raw_data(client_f, meta_f, False, [('a', 1)])
            check_manifest(client_f, req, False, 1)
            check_latest_timestamp(client_f, query_f, 1)
            req = add_raw_data(client_f, meta_f, False, [('b', 2)])
            check_manifest(client_f, req, False, 2)
            check_latest_timestamp(client_f, query_f, 2)
            req = add_raw_data(client_f, meta_f, True, [('a', 1), ('b', 2)])
            check_manifest(client_f, req, False, 2)
            check_latest_timestamp(client_f, query_f, 2)

            # Finish the leader's raw data; later additions must fail.
            finish_raw_l = raw_data_request(
                meta_l, finish_raw_data=empty_pb2.Empty())
            check_ok_twice(client_l.FinishRawData, finish_raw_l)
            check_manifest(client_l, finish_raw_l, True, 2)
            late_add_l = add_raw_data_request(meta_l, False, [('x', 4)])
            with self.assertRaises(Exception):
                client_l.AddRawData(late_add_l)

            # The sync job can now finish on both sides; the join job still
            # cannot finish on the follower.
            with self.assertRaises(Exception):
                client_f.FinishJoinPartition(finish_join_req)
            check_ok_twice(client_l.FinishJoinPartition, finish_sync_req)
            check_ok_twice(client_f.FinishJoinPartition, finish_sync_req)
            with self.assertRaises(Exception):
                client_f.FinishJoinPartition(finish_join_req)

            # Finish the follower's raw data; later additions must fail.
            finish_raw_f = raw_data_request(
                meta_f, finish_raw_data=empty_pb2.Empty())
            check_ok_twice(client_f.FinishRawData, finish_raw_f)
            check_manifest(client_f, finish_raw_f, True, 2)
            late_add_f = add_raw_data_request(meta_f, True, [('x', 3)])
            with self.assertRaises(Exception):
                client_f.AddRawData(late_add_f)

            # The join job can now finish on both sides.
            check_ok_twice(client_f.FinishJoinPartition, finish_join_req)
            check_ok_twice(client_l.FinishJoinPartition, finish_join_req)

            wait_for_state(client_l, client_f,
                           common_pb.DataSourceState.Finished)
        finally:
            # Stop every master that was started, even after a failed
            # assertion, so later tests can rebind ports 4061/4062.
            for master in started_masters:
                master.stop()
    def test_api(self):
        """Drive a streaming DataPortalMaster through one complete job.

        The test runs in these stages:

        1. Write 100 ``*.done`` input files, plus one ``100.xx`` file that
           does not match the ``*.done`` wildcard.
        2. Start a master on port 4061 with ``use_mock_etcd=True``.
        3. Check the map phase: it hands out one map task per output
           partition, and answers ``pending`` while map tasks are still
           outstanding.
        4. Check the reduce phase: it hands out the reduce tasks, and
           answers ``pending`` while they are outstanding.
        5. Check that it reports ``finished`` once every reduce task is done.

        The master is stopped and the input directory removed even if an
        assertion fails.
        """
        logging.getLogger().setLevel(logging.DEBUG)
        etcd_name = 'test_etcd'
        etcd_addrs = 'localhost:2379'
        etcd_base_dir = 'dp_test'
        data_portal_name = 'test_data_source'
        etcd = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir, True)
        etcd.delete_prefix(etcd_base_dir)
        portal_input_base_dir = './portal_upload_dir'
        portal_output_base_dir = './portal_output_dir'
        raw_data_publish_dir = 'raw_data_publish_dir'
        portal_manifest = dp_pb.DataPortalManifest(
            name=data_portal_name,
            data_portal_type=dp_pb.DataPortalType.Streaming,
            output_partition_num=4,
            input_file_wildcard="*.done",
            input_base_dir=portal_input_base_dir,
            output_base_dir=portal_output_base_dir,
            raw_data_publish_dir=raw_data_publish_dir,
            processing_job_id=-1,
            next_job_id=0
        )
        etcd.set_data(common.portal_etcd_base_dir(data_portal_name),
                      text_format.MessageToString(portal_manifest))
        if gfile.Exists(portal_input_base_dir):
            gfile.DeleteRecursively(portal_input_base_dir)
        gfile.MakeDirs(portal_input_base_dir)
        all_fnames = ['{}.done'.format(i) for i in range(100)]
        all_fnames.append('{}.xx'.format(100))
        for fname in all_fnames:
            fpath = os.path.join(portal_input_base_dir, fname)
            with gfile.Open(fpath, "w") as f:
                f.write('xxx')
        portal_master_addr = 'localhost:4061'
        # NOTE(review): the class name must match the generated proto code;
        # it appears to be misspelled in the .proto itself, so do not rename
        # it here without changing the proto.
        portal_options = dp_pb.DataPotraMasterlOptions(
            use_mock_etcd=True,
            long_running=False
        )
        data_portal_master = DataPortalMasterService(
            int(portal_master_addr.split(':')[1]),
            data_portal_name, etcd_name, etcd_base_dir,
            etcd_addrs, portal_options
        )
        data_portal_master.start()
        try:
            channel = make_insecure_channel(portal_master_addr,
                                            ChannelType.INTERNAL)
            portal_master_cli = dp_grpc.DataPortalMasterServiceStub(channel)
            mapped_partition = set()
            reduce_partition = set()

            def request_task(rank_id):
                return portal_master_cli.RequestNewTask(
                    dp_pb.NewTaskRequest(rank_id=rank_id))

            def finish_task(rank_id, partition_id, part_state):
                portal_master_cli.FinishTask(dp_pb.FinishTaskRequest(
                    rank_id=rank_id, partition_id=partition_id,
                    part_state=part_state))

            def check_map_task(task):
                # The task must be a map task; record and verify its partition.
                self.assertTrue(task.HasField('map_task'))
                partition_id = task.map_task.partition_id
                mapped_partition.add(partition_id)
                self._check_map_task(task.map_task, all_fnames,
                                     partition_id, portal_manifest)

            def check_reduce_task(task):
                # The task must be a reduce task; record and verify it.
                self.assertTrue(task.HasField('reduce_task'))
                partition_id = task.reduce_task.partition_id
                reduce_partition.add(partition_id)
                self._check_reduce_task(task.reduce_task, partition_id,
                                        portal_manifest)

            recv_manifest = portal_master_cli.GetDataPortalManifest(
                empty_pb2.Empty())
            for field in ('name', 'data_portal_type', 'output_partition_num',
                          'input_file_wildcard', 'input_base_dir',
                          'output_base_dir', 'raw_data_publish_dir'):
                self.assertEqual(getattr(recv_manifest, field),
                                 getattr(portal_manifest, field))
            # The master has moved the manifest from (processing=-1,
            # next=0) to processing job 0.
            self.assertEqual(recv_manifest.next_job_id, 1)
            self.assertEqual(recv_manifest.processing_job_id, 0)
            self._check_portal_job(etcd, all_fnames, portal_manifest, 0)

            # Map phase. Asking twice from the same rank returns the same task.
            task_0 = request_task(0)
            self.assertEqual(task_0, request_task(0))
            check_map_task(task_0)
            finish_task(0, task_0.map_task.partition_id,
                        dp_pb.PartState.kIdMap)
            task_1 = request_task(0)
            check_map_task(task_1)
            task_2 = request_task(1)
            check_map_task(task_2)
            task_3 = request_task(2)
            check_map_task(task_3)
            self.assertEqual(len(mapped_partition),
                             portal_manifest.output_partition_num)

            finish_task(0, task_1.map_task.partition_id,
                        dp_pb.PartState.kIdMap)
            # Two map tasks are still outstanding, so idle ranks must wait.
            self.assertTrue(request_task(4).HasField('pending'))
            self.assertTrue(request_task(3).HasField('pending'))
            finish_task(1, task_2.map_task.partition_id,
                        dp_pb.PartState.kIdMap)
            finish_task(2, task_3.map_task.partition_id,
                        dp_pb.PartState.kIdMap)

            # Reduce phase. Ranks 0-3 each get one reduce task.
            task_4 = request_task(0)
            self.assertEqual(task_4, request_task(0))
            check_reduce_task(task_4)
            task_5 = request_task(1)
            check_reduce_task(task_5)
            task_6 = request_task(2)
            check_reduce_task(task_6)
            task_7 = request_task(3)
            check_reduce_task(task_7)
            self.assertEqual(len(reduce_partition),
                             portal_manifest.output_partition_num)

            # All reduce tasks are outstanding, so an extra rank must wait.
            self.assertTrue(request_task(5).HasField('pending'))
            for rank_id, task in enumerate((task_4, task_5, task_6, task_7)):
                finish_task(rank_id, task.reduce_task.partition_id,
                            dp_pb.PartState.kEventTimeReduce)

            self.assertTrue(request_task(5).HasField('finished'))
        finally:
            # Always release the port and the input files, even on failure.
            data_portal_master.stop()
            gfile.DeleteRecursively(portal_input_base_dir)
Example #7
0
    def setUp(self):
        """Start a leader/follower pair of data-join masters for the tests.

        Writes a data source description (2 partitions, Sequential type) for
        each role into its own etcd base dir, launches both masters and polls
        ``GetDataSourceState`` every 2 seconds until both report the
        ``Processing`` state. Clients, masters, etcd handles and data sources
        are stored on ``self``; data block / example dump / raw data
        directories left over from earlier runs are then removed.
        """
        etcd_name = 'test_etcd'
        etcd_addrs = 'localhost:2379'
        etcd_base_dir_l = 'byefl_l'
        etcd_base_dir_f = 'byefl_f'
        data_source_name = 'test_data_source'
        etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
        etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
        # Drop any etcd keys a previous run left under this data source.
        etcd_l.delete_prefix(data_source_name)
        etcd_f.delete_prefix(data_source_name)
        data_source_l = common_pb.DataSource()
        data_source_l.role = common_pb.FLRole.Leader
        data_source_l.state = common_pb.DataSourceState.Init
        data_source_l.data_block_dir = "./data_block_l"
        data_source_l.raw_data_dir = "./raw_data_l"
        data_source_l.example_dumped_dir = "./example_dumped_l"
        data_source_f = common_pb.DataSource()
        data_source_f.role = common_pb.FLRole.Follower
        data_source_f.state = common_pb.DataSourceState.Init
        data_source_f.data_block_dir = "./data_block_f"
        data_source_f.raw_data_dir = "./raw_data_f"
        data_source_f.example_dumped_dir = "./example_dumped_f"
        # Meta shared by both roles.
        data_source_meta = common_pb.DataSourceMeta()
        data_source_meta.name = data_source_name
        data_source_meta.partition_num = 2
        data_source_meta.start_time = 0
        data_source_meta.end_time = 100000000
        data_source_meta.min_matching_window = 64
        data_source_meta.max_matching_window = 128
        data_source_meta.data_source_type = common_pb.DataSourceType.Sequential
        data_source_meta.max_example_in_data_block = 1000
        data_source_l.data_source_meta.MergeFrom(data_source_meta)
        etcd_l.set_data(os.path.join(data_source_name, 'master'),
                        text_format.MessageToString(data_source_l))
        data_source_f.data_source_meta.MergeFrom(data_source_meta)
        etcd_f.set_data(os.path.join(data_source_name, 'master'),
                        text_format.MessageToString(data_source_f))
        customized_options.set_use_mock_etcd()

        master_addr_l = 'localhost:4061'
        master_addr_f = 'localhost:4062'
        # Each master listens on its own port and points at the peer's address.
        master_l = data_join_master.DataJoinMasterService(
            int(master_addr_l.split(':')[1]), master_addr_f, data_source_name,
            etcd_name, etcd_base_dir_l, etcd_addrs)
        master_l.start()
        master_f = data_join_master.DataJoinMasterService(
            int(master_addr_f.split(':')[1]), master_addr_l, data_source_name,
            etcd_name, etcd_base_dir_f, etcd_addrs)
        master_f.start()
        channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
        master_client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
        channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
        master_client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)

        # Poll until both masters have moved from Init to Processing.
        while True:
            rsp_l = master_client_l.GetDataSourceState(
                data_source_l.data_source_meta)
            rsp_f = master_client_f.GetDataSourceState(
                data_source_f.data_source_meta)
            self.assertEqual(rsp_l.status.code, 0)
            self.assertEqual(rsp_l.role, common_pb.FLRole.Leader)
            self.assertEqual(rsp_l.data_source_type,
                             common_pb.DataSourceType.Sequential)
            self.assertEqual(rsp_f.status.code, 0)
            self.assertEqual(rsp_f.role, common_pb.FLRole.Follower)
            self.assertEqual(rsp_f.data_source_type,
                             common_pb.DataSourceType.Sequential)
            if (rsp_l.state == common_pb.DataSourceState.Processing
                    and rsp_f.state == common_pb.DataSourceState.Processing):
                break
            else:
                time.sleep(2)

        self.master_client_l = master_client_l
        self.master_client_f = master_client_f
        self.master_addr_l = master_addr_l
        self.master_addr_f = master_addr_f
        self.etcd_l = etcd_l
        self.etcd_f = etcd_f
        self.data_source_l = data_source_l
        self.data_source_f = data_source_f
        self.master_l = master_l
        self.master_f = master_f
        # Plain string: the original trailing comma stored a 1-tuple here.
        self.data_source_name = data_source_name
        self.etcd_name = etcd_name
        self.etcd_addrs = etcd_addrs
        self.etcd_base_dir_l = etcd_base_dir_l
        self.etcd_base_dir_f = etcd_base_dir_f
        # Remove on-disk output left behind by an earlier run.
        for stale_dir in (data_source_l.data_block_dir,
                          data_source_l.example_dumped_dir,
                          data_source_l.raw_data_dir,
                          data_source_f.data_block_dir,
                          data_source_f.example_dumped_dir,
                          data_source_f.raw_data_dir):
            if gfile.Exists(stale_dir):
                gfile.DeleteRecursively(stale_dir)

        self.total_index = 1 << 13
class DataJoinPortal(unittest.TestCase):
    """End-to-end test of the data-join portal pipeline.

    A leader side and a follower side are built, each with its own etcd base
    dir, data source, portal manifest, master, two workers and a portal.
    Portal input is written as hourly TFRecord files, with one hour withheld
    on each side. ``test_all_pipeline`` backfills the missing hour, adds one
    more hour for the leader, finishes raw data publication and waits until
    both masters report the ``Finished`` state.
    """

    def _setUpEtcd(self):
        """Create the leader and follower etcd clients."""
        self._etcd_name = 'test_etcd'
        self._etcd_addrs = 'localhost:2379'
        self._etcd_base_dir_l = 'byefl_l'
        self._etcd_base_dir_f = 'byefl_f'
        self._etcd_l = EtcdClient(self._etcd_name, self._etcd_addrs,
                                  self._etcd_base_dir_l, True)
        self._etcd_f = EtcdClient(self._etcd_name, self._etcd_addrs,
                                  self._etcd_base_dir_f, True)

    def _setUpDataSource(self):
        """Clear old etcd state and commit fresh leader/follower data sources."""
        self._data_source_name = 'test_data_source'
        self._etcd_l.delete_prefix(self._data_source_name)
        self._etcd_f.delete_prefix(self._data_source_name)
        self._raw_data_pub_dir_l = './raw_data_pub_dir_l'
        self._raw_data_pub_dir_f = './raw_data_pub_dir_f'
        self._data_source_l = common_pb.DataSource()
        self._data_source_l.role = common_pb.FLRole.Leader
        self._data_source_l.state = common_pb.DataSourceState.Init
        self._data_source_l.data_block_dir = "./data_block_l"
        self._data_source_l.raw_data_dir = "./raw_data_l"
        self._data_source_l.example_dumped_dir = "./example_dumped_l"
        self._data_source_l.raw_data_sub_dir = self._raw_data_pub_dir_l
        self._data_source_f = common_pb.DataSource()
        self._data_source_f.role = common_pb.FLRole.Follower
        self._data_source_f.state = common_pb.DataSourceState.Init
        self._data_source_f.data_block_dir = "./data_block_f"
        self._data_source_f.raw_data_dir = "./raw_data_f"
        self._data_source_f.example_dumped_dir = "./example_dumped_f"
        self._data_source_f.raw_data_sub_dir = self._raw_data_pub_dir_f
        # Both roles share the same meta (name, partition count, time range).
        data_source_meta = common_pb.DataSourceMeta()
        data_source_meta.name = self._data_source_name
        data_source_meta.partition_num = 2
        data_source_meta.start_time = 0
        data_source_meta.end_time = 100000000
        self._data_source_l.data_source_meta.MergeFrom(data_source_meta)
        self._data_source_f.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(self._etcd_l, self._data_source_l)
        common.commit_data_source(self._etcd_f, self._data_source_f)

    def _setUpPortalManifest(self):
        """Clear old etcd state and commit leader/follower portal manifests.

        Both manifests get the same begin timestamp, computed once. Calling
        ``datetime.now()`` separately for each side could straddle an hour
        boundary. The two sides would then start at different hours, and the
        ``self._dt_f == self._dt_l`` check in ``test_all_pipeline`` would
        fail intermittently.
        """
        self._portal_name = 'test_portal'
        self._etcd_l.delete_prefix(self._portal_name)
        self._etcd_f.delete_prefix(self._portal_name)
        begin_timestamp = common.trim_timestamp_by_hourly(
            common.convert_datetime_to_timestamp(datetime.now()))
        self._portal_manifest_l = common_pb.DataJoinPortalManifest(
            name=self._portal_name,
            input_partition_num=4,
            output_partition_num=2,
            input_data_base_dir='./portal_input_l',
            output_data_base_dir='./portal_output_l',
            raw_data_publish_dir=self._raw_data_pub_dir_l,
            begin_timestamp=begin_timestamp)
        self._portal_manifest_f = common_pb.DataJoinPortalManifest(
            name=self._portal_name,
            input_partition_num=2,
            output_partition_num=2,
            input_data_base_dir='./portal_input_f',
            output_data_base_dir='./portal_output_f',
            raw_data_publish_dir=self._raw_data_pub_dir_f,
            begin_timestamp=begin_timestamp)
        common.commit_portal_manifest(self._etcd_l, self._portal_manifest_l)
        common.commit_portal_manifest(self._etcd_f, self._portal_manifest_f)

    def _launch_masters(self):
        """Start both masters and poll until each reports Processing."""
        self._master_addr_l = 'localhost:4061'
        self._master_addr_f = 'localhost:4062'
        master_options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)
        # Each master listens on its own port and points at the peer's address.
        self._master_l = data_join_master.DataJoinMasterService(
            int(self._master_addr_l.split(':')[1]), self._master_addr_f,
            self._data_source_name, self._etcd_name, self._etcd_base_dir_l,
            self._etcd_addrs, master_options)
        self._master_f = data_join_master.DataJoinMasterService(
            int(self._master_addr_f.split(':')[1]), self._master_addr_l,
            self._data_source_name, self._etcd_name, self._etcd_base_dir_f,
            self._etcd_addrs, master_options)
        self._master_f.start()
        self._master_l.start()
        channel_l = make_insecure_channel(self._master_addr_l,
                                          ChannelType.INTERNAL)
        self._master_client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
        channel_f = make_insecure_channel(self._master_addr_f,
                                          ChannelType.INTERNAL)
        self._master_client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)

        while True:
            req_l = dj_pb.DataSourceRequest(
                data_source_meta=self._data_source_l.data_source_meta)
            req_f = dj_pb.DataSourceRequest(
                data_source_meta=self._data_source_f.data_source_meta)
            dss_l = self._master_client_l.GetDataSourceStatus(req_l)
            dss_f = self._master_client_f.GetDataSourceStatus(req_f)
            self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
            self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
            if dss_l.state == common_pb.DataSourceState.Processing and \
                    dss_f.state == common_pb.DataSourceState.Processing:
                break
            else:
                time.sleep(2)
        logging.info("masters turn into Processing state")

    def _launch_workers(self):
        """Start two workers per side; rank ``i`` on each side is paired."""
        worker_options = dj_pb.DataJoinWorkerOptions(
            use_mock_etcd=True,
            raw_data_options=dj_pb.RawDataOptions(raw_data_iter='TF_RECORD',
                                                  compressed_type=''),
            example_id_dump_options=dj_pb.ExampleIdDumpOptions(
                example_id_dump_interval=1, example_id_dump_threshold=1024),
            example_joiner_options=dj_pb.ExampleJoinerOptions(
                example_joiner='STREAM_JOINER',
                min_matching_window=64,
                max_matching_window=256,
                data_block_dump_interval=30,
                data_block_dump_threshold=1000),
            batch_processor_options=dj_pb.BatchProcessorOptions(
                batch_size=1024, max_flying_item=4096),
            data_block_builder_options=dj_pb.DataBlockBuilderOptions(
                data_block_builder='TF_RECORD_DATABLOCK_BUILDER'))
        self._worker_addrs_l = ['localhost:4161', 'localhost:4162']
        self._worker_addrs_f = ['localhost:5161', 'localhost:5162']
        self._workers_l = []
        self._workers_f = []
        for rank_id in range(2):
            worker_addr_l = self._worker_addrs_l[rank_id]
            worker_addr_f = self._worker_addrs_f[rank_id]
            self._workers_l.append(
                data_join_worker.DataJoinWorkerService(
                    int(worker_addr_l.split(':')[1]), worker_addr_f,
                    self._master_addr_l, rank_id, self._etcd_name,
                    self._etcd_base_dir_l, self._etcd_addrs, worker_options))
            self._workers_f.append(
                data_join_worker.DataJoinWorkerService(
                    int(worker_addr_f.split(':')[1]), worker_addr_l,
                    self._master_addr_f, rank_id, self._etcd_name,
                    self._etcd_base_dir_f, self._etcd_addrs, worker_options))
        for w in self._workers_l:
            w.start()
        for w in self._workers_f:
            w.start()

    def _launch_portals(self):
        """Start one portal per side with a shared option set."""
        portal_options = dj_pb.DataJoinPotralOptions(
            example_validator=dj_pb.ExampleValidatorOptions(
                example_validator='EXAMPLE_VALIDATOR',
                validate_event_time=True),
            reducer_buffer_size=1024,
            raw_data_options=dj_pb.RawDataOptions(raw_data_iter='TF_RECORD',
                                                  compressed_type=''),
            use_mock_etcd=True)
        self._portal_l = data_join_portal.DataJoinPortal(
            self._portal_name, self._etcd_name, self._etcd_addrs,
            self._etcd_base_dir_l, portal_options)

        self._portal_f = data_join_portal.DataJoinPortal(
            self._portal_name, self._etcd_name, self._etcd_addrs,
            self._etcd_base_dir_f, portal_options)
        self._portal_l.start()
        self._portal_f.start()

    def _remove_existed_dir(self):
        """Delete every portal and data source directory that exists."""
        stale_dirs = (
            self._portal_manifest_l.input_data_base_dir,
            self._portal_manifest_l.output_data_base_dir,
            self._portal_manifest_f.input_data_base_dir,
            self._portal_manifest_f.output_data_base_dir,
            self._data_source_l.data_block_dir,
            self._data_source_l.raw_data_dir,
            self._data_source_l.example_dumped_dir,
            self._data_source_f.data_block_dir,
            self._data_source_f.raw_data_dir,
            self._data_source_f.example_dumped_dir,
        )
        for stale_dir in stale_dirs:
            if gfile.Exists(stale_dir):
                gfile.DeleteRecursively(stale_dir)

    def _generate_portal_input_data(self, date_time, event_time_filter,
                                    start_index, total_item_num,
                                    portal_manifest):
        """Write one hour of portal input, one TFRecord file per partition.

        Local item ``i`` (``0 <= i < total_item_num``) goes to partition
        ``i % input_partition_num`` with example id ``start_index + i``
        (encoded as its decimal string). Within each partition the order is
        perturbed by random swaps of nearby items. No item moves more than 16
        slots from its sorted position, so the data stays roughly sorted.
        Every item gets ``event_time = 150000000 + real_id``, except items
        with a non-zero real id for which ``event_time_filter`` is True. Those
        are written without an event time, which exercises the example
        validator's event-time check. After all partitions are written, the
        hour's finish tag file is created.

        Args:
            date_time: the hour whose files are written.
            event_time_filter: predicate on the real id; True drops the
                ``event_time`` feature for that item.
            start_index: offset added to every local id.
            total_item_num: number of items for this hour; must be divisible
                by ``portal_manifest.input_partition_num``.
            portal_manifest: supplies the input base dir and partition count.
        """
        self.assertEqual(total_item_num % portal_manifest.input_partition_num,
                         0)
        item_step = portal_manifest.input_partition_num
        for partition_id in range(portal_manifest.input_partition_num):
            cands = list(range(partition_id, total_item_num, item_step))
            last = len(cands) - 1
            for i in range(len(cands)):
                # About one position in four attempts a swap.
                if random.randint(1, 4) > 1:
                    continue
                a = random.randint(i - 16, i + 16)
                b = random.randint(i - 16, i + 16)
                a = min(max(a, 0), last)
                b = min(max(b, 0), last)
                # cands[x] // item_step is that item's sorted slot; swap only
                # if neither item would end up more than 16 slots from it.
                if abs(cands[a] // item_step -
                       b) <= 16 and abs(cands[b] // item_step - a) <= 16:
                    cands[a], cands[b] = cands[b], cands[a]
            fpath = common.encode_portal_hourly_fpath(
                portal_manifest.input_data_base_dir, date_time, partition_id)
            if not gfile.Exists(os.path.dirname(fpath)):
                gfile.MakeDirs(os.path.dirname(fpath))
            with tf.io.TFRecordWriter(fpath) as writer:
                for lid in cands:
                    real_id = lid + start_index
                    feat = {}
                    example_id = '{}'.format(real_id).encode()
                    feat['example_id'] = tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[example_id]))
                    # Filtered items (other than id 0) deliberately omit
                    # event_time so the validator sees invalid examples.
                    if real_id == 0 or not event_time_filter(real_id):
                        event_time = 150000000 + real_id
                        feat['event_time'] = tf.train.Feature(
                            int64_list=tf.train.Int64List(value=[event_time]))
                    example = tf.train.Example(features=tf.train.Features(
                        feature=feat))
                    writer.write(example.SerializeToString())
        succ_tag_fpath = common.encode_portal_hourly_finish_tag(
            portal_manifest.input_data_base_dir, date_time)
        with gfile.GFile(succ_tag_fpath, 'w') as fh:
            fh.write('')

    def setUp(self):
        """Prepare etcd state and input data, then launch all services.

        The leader gets 4 hours of input with hour index 1 withheld. The
        follower gets 5 hours with hour index 2 withheld. The withheld hours
        are recorded so ``test_all_pipeline`` can backfill them later.
        """
        self._setUpEtcd()
        self._setUpDataSource()
        self._setUpPortalManifest()
        self._remove_existed_dir()
        self._item_num_l = 0
        self._event_time_filter_l = lambda x: x % 877 == 0
        self._dt_l = common.convert_timestamp_to_datetime(
            self._portal_manifest_l.begin_timestamp)
        for i in range(4):
            if i == 1:
                self._missing_datetime_l = self._dt_l
                self._missing_start_index_l = self._item_num_l
                self._missing_item_cnt_l = 1 << 13
                self._item_num_l += self._missing_item_cnt_l
            else:
                self._generate_portal_input_data(self._dt_l,
                                                 self._event_time_filter_l,
                                                 self._item_num_l, 1 << 13,
                                                 self._portal_manifest_l)
                self._item_num_l += 1 << 13
            self._dt_l += timedelta(hours=1)
        self._item_num_f = 0
        self._event_time_filter_f = lambda x: x % 907 == 0
        self._dt_f = common.convert_timestamp_to_datetime(
            self._portal_manifest_f.begin_timestamp)
        for i in range(5):
            if i == 2:
                self._missing_datetime_f = self._dt_f
                self._missing_start_index_f = self._item_num_f
                self._missing_item_cnt_f = 1 << 13
            else:
                self._generate_portal_input_data(self._dt_f,
                                                 self._event_time_filter_f,
                                                 self._item_num_f, 1 << 13,
                                                 self._portal_manifest_f)
            self._item_num_f += 1 << 13
            self._dt_f += timedelta(hours=1)

        self._launch_masters()
        self._launch_workers()
        self._launch_portals()

    def _stop_workers(self):
        """Stop every worker on both sides."""
        for w in self._workers_f:
            w.stop()
        for w in self._workers_l:
            w.stop()

    def _stop_masters(self):
        """Stop both masters."""
        self._master_f.stop()
        self._master_l.stop()

    def _stop_portals(self):
        """Stop both portals."""
        self._portal_f.stop()
        self._portal_l.stop()

    def _wait_timestamp(self, target_l, target_f):
        """Block until each side's slowest partition reaches its target hour.

        For every partition, each master is asked for its latest raw-data
        timestamp. The minimum over partitions on each side is compared
        against ``target_l`` / ``target_f``. The masters are polled every
        2 seconds with no overall timeout.
        """
        while True:
            min_datetime_l = None
            min_datetime_f = None
            for pid in range(
                    self._data_source_f.data_source_meta.partition_num):
                req_l = dj_pb.RawDataRequest(
                    partition_id=pid,
                    data_source_meta=self._data_source_l.data_source_meta)
                req_f = dj_pb.RawDataRequest(
                    partition_id=pid,
                    data_source_meta=self._data_source_f.data_source_meta)
                rsp_l = self._master_client_l.GetRawDataLatestTimeStamp(req_l)
                rsp_f = self._master_client_f.GetRawDataLatestTimeStamp(req_f)
                datetime_l = common.convert_timestamp_to_datetime(
                    rsp_l.timestamp)
                datetime_f = common.convert_timestamp_to_datetime(
                    rsp_f.timestamp)
                if min_datetime_l is None or min_datetime_l > datetime_l:
                    min_datetime_l = datetime_l
                if min_datetime_f is None or min_datetime_f > datetime_f:
                    min_datetime_f = datetime_f
            if min_datetime_l >= target_l and min_datetime_f >= target_f:
                break
            else:
                time.sleep(2)

    def test_all_pipeline(self):
        """Backfill missing hours, finish raw data, expect Finished masters."""
        # Wait for the hour before the gap, then backfill the withheld hour.
        self._wait_timestamp(self._missing_datetime_l - timedelta(hours=1),
                             self._missing_datetime_f - timedelta(hours=1))
        self._generate_portal_input_data(self._missing_datetime_l,
                                         self._event_time_filter_l,
                                         self._missing_start_index_l,
                                         self._missing_item_cnt_l,
                                         self._portal_manifest_l)
        self._generate_portal_input_data(self._missing_datetime_f,
                                         self._event_time_filter_f,
                                         self._missing_start_index_f,
                                         self._missing_item_cnt_f,
                                         self._portal_manifest_f)
        self._wait_timestamp(self._dt_l - timedelta(hours=1),
                             self._dt_f - timedelta(hours=1))
        # One extra leader hour brings both sides to the same end hour.
        self._generate_portal_input_data(self._dt_l, self._event_time_filter_l,
                                         self._item_num_l, 1 << 13,
                                         self._portal_manifest_l)
        self._dt_l += timedelta(hours=1)
        self.assertEqual(self._dt_f, self._dt_l)
        self._wait_timestamp(self._dt_l - timedelta(hours=1),
                             self._dt_f - timedelta(hours=1))
        data_source_l = self._master_client_l.GetDataSource(empty_pb2.Empty())
        data_source_f = self._master_client_f.GetDataSource(empty_pb2.Empty())
        # Publication below iterates the leader's partition count for both
        # sides, so both must agree on it.
        self.assertEqual(data_source_l.data_source_meta.partition_num,
                         data_source_f.data_source_meta.partition_num)
        rd_puber_l = raw_data_publisher.RawDataPublisher(
            self._etcd_l, self._raw_data_pub_dir_l)
        rd_puber_f = raw_data_publisher.RawDataPublisher(
            self._etcd_f, self._raw_data_pub_dir_f)
        for partition_id in range(
                data_source_l.data_source_meta.partition_num):
            rd_puber_f.finish_raw_data(partition_id)
            rd_puber_l.finish_raw_data(partition_id)

        while True:
            req_l = dj_pb.DataSourceRequest(
                data_source_meta=self._data_source_l.data_source_meta)
            req_f = dj_pb.DataSourceRequest(
                data_source_meta=self._data_source_f.data_source_meta)
            dss_l = self._master_client_l.GetDataSourceStatus(req_l)
            dss_f = self._master_client_f.GetDataSourceStatus(req_f)
            self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
            self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
            if dss_l.state == common_pb.DataSourceState.Finished and \
                    dss_f.state == common_pb.DataSourceState.Finished:
                break
            else:
                time.sleep(2)
        logging.info("masters turn into Finished state")

    def tearDown(self):
        """Stop portals, masters and workers, then delete generated dirs."""
        self._stop_portals()
        self._stop_masters()
        self._stop_workers()
        self._remove_existed_dir()