def test_all_pipeline(self):
        """Run the full pipeline until both masters report Finished.

        Generates portal input data for the ``_missing_datetime_*`` slots on
        both sides and for ``_dt_l`` on the leader side, marks raw data as
        finished on every partition for both publishers, then polls both
        masters every 2 seconds until each reports
        ``DataSourceState.Finished``.
        """
        one_hour = timedelta(hours=1)

        self._wait_timestamp(self._missing_datetime_l - one_hour,
                             self._missing_datetime_f - one_hour)
        # Leader side first, then follower side, with mirrored arguments.
        self._generate_portal_input_data(
            self._missing_datetime_l, self._event_time_filter_l,
            self._missing_start_index_l, 1 << 13, self._portal_manifest_l)
        self._generate_portal_input_data(
            self._missing_datetime_f, self._event_time_filter_f,
            self._missing_start_index_f, 1 << 13, self._portal_manifest_f)

        self._wait_timestamp(self._dt_l - one_hour, self._dt_f - one_hour)
        self._generate_portal_input_data(
            self._dt_l, self._event_time_filter_l,
            self._item_num_l, 1 << 13, self._portal_manifest_l)
        self._dt_l += one_hour
        # After advancing the leader by one hour both sides must line up.
        self.assertEqual(self._dt_f, self._dt_l)
        self._wait_timestamp(self._dt_l - one_hour, self._dt_f - one_hour)

        leader_source = self._master_client_l.GetDataSource(empty_pb2.Empty())
        # The follower's reply is not used below; the call itself is kept.
        self._master_client_f.GetDataSource(empty_pb2.Empty())
        publisher_l = raw_data_publisher.RawDataPublisher(
            self._etcd_l, self._raw_data_pub_dir_l)
        publisher_f = raw_data_publisher.RawDataPublisher(
            self._etcd_f, self._raw_data_pub_dir_f)
        partition_num = leader_source.data_source_meta.partition_num
        for pid in range(partition_num):
            publisher_f.finish_raw_data(pid)
            publisher_l.finish_raw_data(pid)

        # Poll both masters until each one has reached the Finished state.
        while True:
            request_l = dj_pb.DataSourceRequest(
                data_source_meta=self._data_source_l.data_source_meta)
            request_f = dj_pb.DataSourceRequest(
                data_source_meta=self._data_source_f.data_source_meta)
            status_l = self._master_client_l.GetDataSourceStatus(request_l)
            status_f = self._master_client_f.GetDataSourceStatus(request_f)
            self.assertEqual(status_l.role, common_pb.FLRole.Leader)
            self.assertEqual(status_f.role, common_pb.FLRole.Follower)
            if status_l.state == common_pb.DataSourceState.Finished and \
                    status_f.state == common_pb.DataSourceState.Finished:
                break
            time.sleep(2)
        logging.info("masters turn into Finished state")
Exemplo n.º 2
0
 def _preprocess_rsa_psi_follower(self):
     """Run one follower RSA-PSI preprocessor per partition and wait for all.

     Reads the RSA public key from ``self._rsa_public_key_path``, publishes
     and finishes the follower's raw data file for each partition under
     ``follower_rsa_psi_sub_dir``, starts an ``RsaPsiPreProcessor`` for that
     partition, and finally blocks until every started processor reports
     that it has finished.
     """
     with gfile.GFile(self._rsa_public_key_path, 'rb') as key_file:
         public_key_pem = key_file.read()

     self._follower_rsa_psi_sub_dir = 'follower_rsa_psi_sub_dir'
     publisher = raw_data_publisher.RawDataPublisher(
         self._kvstore_f, self._follower_rsa_psi_sub_dir)

     started = []
     partition_num = self._data_source_f.data_source_meta.partition_num
     for pid in range(partition_num):
         # Publish this partition's single input file, then close it out.
         publisher.publish_raw_data(pid, [self._psi_raw_data_fpaths_f[pid]])
         publisher.finish_raw_data(pid)

         batch_options = dj_pb.BatchProcessorOptions(
             batch_size=1024, max_flying_item=1 << 14)
         raw_data_options = dj_pb.RawDataOptions(
             raw_data_iter='TF_RECORD', read_ahead_size=1 << 20)
         writer_options = dj_pb.WriterOptions(output_writer='CSV_DICT')
         preprocessor_options = dj_pb.RsaPsiPreProcessorOptions(
             preprocessor_name='follower-rsa-psi-processor',
             role=common_pb.FLRole.Follower,
             rsa_key_pem=public_key_pem,
             input_file_subscribe_dir=self._follower_rsa_psi_sub_dir,
             output_file_dir=self._pre_processor_ouput_dir_f,
             raw_data_publish_dir=self._raw_data_pub_dir_f,
             partition_id=pid,
             leader_rsa_psi_signer_addr=self._rsa_psi_signer_addr,
             offload_processor_number=1,
             max_flying_sign_batch=128,
             max_flying_sign_rpc=64,
             sign_rpc_timeout_ms=100000,
             stub_fanout=2,
             slow_sign_threshold=8,
             sort_run_merger_read_ahead_buffer=1 << 20,
             sort_run_merger_read_batch_size=128,
             batch_processor_options=batch_options,
             input_raw_data=raw_data_options,
             writer_options=writer_options)

         # Set right before constructing the processor, once per partition.
         os.environ['ETCD_BASE_DIR'] = self.follower_base_dir
         preprocessor = rsa_psi_preprocessor.RsaPsiPreProcessor(
             preprocessor_options, self.kvstore_type, True)
         preprocessor.start_process()
         started.append(preprocessor)

     # All processors are started before any of them is waited on.
     for preprocessor in started:
         preprocessor.wait_for_finished()
Exemplo n.º 3
0
    def setUp(self):
        """Start a leader/follower master pair and wait for Processing.

        Commits a two-partition data source named ``test_data_source`` for
        each role, starts one ``DataJoinMasterService`` per role on
        localhost ports 4061 (leader) and 4062 (follower), and polls both
        every 2 seconds until each reports ``DataSourceState.Processing``.
        The clients, services, publishers and worker options are then stored
        on ``self`` and leftover output directories of both data sources are
        removed.
        """
        etcd_name = 'test_etcd'
        etcd_addrs = 'localhost:2379'
        etcd_base_dir_l = 'byefl_l'
        etcd_base_dir_f = 'byefl_f'
        data_source_name = 'test_data_source'
        # NOTE(review): the trailing True presumably selects a mock etcd
        # backend, matching use_mock_etcd=True below -- confirm.
        etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
        etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
        # Drop whatever an earlier run stored under this data source.
        etcd_l.delete_prefix(
            common.data_source_etcd_base_dir(data_source_name))
        etcd_f.delete_prefix(
            common.data_source_etcd_base_dir(data_source_name))

        data_source_l = common_pb.DataSource()
        self.raw_data_pub_dir_l = './raw_data_pub_dir_l'
        data_source_l.raw_data_sub_dir = self.raw_data_pub_dir_l
        data_source_l.role = common_pb.FLRole.Leader
        data_source_l.state = common_pb.DataSourceState.Init
        data_source_l.data_block_dir = "./data_block_l"
        data_source_l.raw_data_dir = "./raw_data_l"
        data_source_l.example_dumped_dir = "./example_dumped_l"

        data_source_f = common_pb.DataSource()
        self.raw_data_pub_dir_f = './raw_data_pub_dir_f'
        data_source_f.role = common_pb.FLRole.Follower
        data_source_f.raw_data_sub_dir = self.raw_data_pub_dir_f
        data_source_f.state = common_pb.DataSourceState.Init
        data_source_f.data_block_dir = "./data_block_f"
        data_source_f.raw_data_dir = "./raw_data_f"
        data_source_f.example_dumped_dir = "./example_dumped_f"

        # Both roles share the same meta (name, partitions, time range).
        data_source_meta = common_pb.DataSourceMeta()
        data_source_meta.name = data_source_name
        data_source_meta.partition_num = 2
        data_source_meta.start_time = 0
        data_source_meta.end_time = 100000000
        data_source_l.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(etcd_l, data_source_l)
        data_source_f.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(etcd_f, data_source_f)
        master_options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)

        master_addr_l = 'localhost:4061'
        master_addr_f = 'localhost:4062'
        # Each master listens on its own port and is given the other
        # master's address as its second argument.
        master_l = data_join_master.DataJoinMasterService(
            int(master_addr_l.split(':')[1]),
            master_addr_f,
            data_source_name,
            etcd_name,
            etcd_base_dir_l,
            etcd_addrs,
            master_options,
        )
        master_l.start()
        master_f = data_join_master.DataJoinMasterService(
            int(master_addr_f.split(':')[1]), master_addr_l, data_source_name,
            etcd_name, etcd_base_dir_f, etcd_addrs, master_options)
        master_f.start()
        channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
        master_client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
        channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
        master_client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)

        # Poll until both masters have moved on to the Processing state.
        while True:
            req_l = dj_pb.DataSourceRequest(
                data_source_meta=data_source_l.data_source_meta)
            req_f = dj_pb.DataSourceRequest(
                data_source_meta=data_source_f.data_source_meta)
            dss_l = master_client_l.GetDataSourceStatus(req_l)
            dss_f = master_client_f.GetDataSourceStatus(req_f)
            self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
            self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
            if dss_l.state == common_pb.DataSourceState.Processing and \
                    dss_f.state == common_pb.DataSourceState.Processing:
                break
            time.sleep(2)

        self.master_client_l = master_client_l
        self.master_client_f = master_client_f
        self.master_addr_l = master_addr_l
        self.master_addr_f = master_addr_f
        self.etcd_l = etcd_l
        self.etcd_f = etcd_f
        self.data_source_l = data_source_l
        self.data_source_f = data_source_f
        self.master_l = master_l
        self.master_f = master_f
        # Plain string: a trailing comma here would store a 1-tuple instead.
        self.data_source_name = data_source_name
        self.etcd_name = etcd_name
        self.etcd_addrs = etcd_addrs
        self.etcd_base_dir_l = etcd_base_dir_l
        self.etcd_base_dir_f = etcd_base_dir_f
        self.raw_data_publisher_l = raw_data_publisher.RawDataPublisher(
            self.etcd_l, self.raw_data_pub_dir_l)
        self.raw_data_publisher_f = raw_data_publisher.RawDataPublisher(
            self.etcd_f, self.raw_data_pub_dir_f)

        # Remove output left behind by earlier runs, leader dirs first.
        for stale_dir in (data_source_l.data_block_dir,
                          data_source_l.example_dumped_dir,
                          data_source_l.raw_data_dir,
                          data_source_f.data_block_dir,
                          data_source_f.example_dumped_dir,
                          data_source_f.raw_data_dir):
            if gfile.Exists(stale_dir):
                gfile.DeleteRecursively(stale_dir)

        self.worker_options = dj_pb.DataJoinWorkerOptions(
            use_mock_etcd=True,
            raw_data_options=dj_pb.RawDataOptions(raw_data_iter='TF_RECORD',
                                                  compressed_type=''),
            example_id_dump_options=dj_pb.ExampleIdDumpOptions(
                example_id_dump_interval=1, example_id_dump_threshold=1024),
            example_joiner_options=dj_pb.ExampleJoinerOptions(
                example_joiner='STREAM_JOINER',
                min_matching_window=64,
                max_matching_window=256,
                data_block_dump_interval=30,
                data_block_dump_threshold=1000),
            batch_processor_options=dj_pb.BatchProcessorOptions(
                batch_size=512, max_flying_item=2048),
            data_block_builder_options=dj_pb.WriterOptions(
                output_writer='TF_RECORD'))

        self.total_index = 1 << 13
Exemplo n.º 4
0
    def setUp(self):
        """Commit leader/follower data sources and prepare worker options.

        Clears the etcd prefix of ``test_data_source`` for both clients,
        commits a two-partition data source for each role, stores the
        clients, data sources and raw data publishers on ``self``, removes
        leftover output and raw data directories, and builds the
        ``DataJoinWorkerOptions`` used by the tests.
        """
        etcd_name = 'test_etcd'
        etcd_addrs = 'localhost:2379'
        etcd_base_dir_l = 'byefl_l'
        etcd_base_dir_f = 'byefl_f'
        data_source_name = 'test_data_source'

        etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
        etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
        # Wipe stored state for this data source, leader client first.
        for kv_client in (etcd_l, etcd_f):
            kv_client.delete_prefix(
                common.data_source_etcd_base_dir(data_source_name))

        self.raw_data_pub_dir_l = './raw_data_pub_dir_l'
        self.raw_data_dir_l = "./raw_data_l"
        self.raw_data_pub_dir_f = './raw_data_pub_dir_f'
        self.raw_data_dir_f = "./raw_data_f"

        data_source_l = common_pb.DataSource()
        data_source_f = common_pb.DataSource()
        per_role_settings = (
            (data_source_l, common_pb.FLRole.Leader,
             self.raw_data_pub_dir_l, "./ds_output_l"),
            (data_source_f, common_pb.FLRole.Follower,
             self.raw_data_pub_dir_f, "./ds_output_f"),
        )
        for source, role, sub_dir, output_dir in per_role_settings:
            source.raw_data_sub_dir = sub_dir
            source.role = role
            source.state = common_pb.DataSourceState.Init
            source.output_base_dir = output_dir

        # One meta message is merged into both data sources.
        shared_meta = common_pb.DataSourceMeta()
        shared_meta.name = data_source_name
        shared_meta.partition_num = 2
        shared_meta.start_time = 0
        shared_meta.end_time = 100000000
        for kv_client, source in ((etcd_l, data_source_l),
                                  (etcd_f, data_source_f)):
            source.data_source_meta.MergeFrom(shared_meta)
            common.commit_data_source(kv_client, source)

        self.etcd_l = etcd_l
        self.etcd_f = etcd_f
        self.data_source_l = data_source_l
        self.data_source_f = data_source_f
        self.data_source_name = data_source_name
        self.etcd_name = etcd_name
        self.etcd_addrs = etcd_addrs
        self.etcd_base_dir_l = etcd_base_dir_l
        self.etcd_base_dir_f = etcd_base_dir_f
        self.raw_data_publisher_l = raw_data_publisher.RawDataPublisher(
            self.etcd_l, self.raw_data_pub_dir_l)
        self.raw_data_publisher_f = raw_data_publisher.RawDataPublisher(
            self.etcd_f, self.raw_data_pub_dir_f)

        # Remove directories left behind by earlier runs, leader side first.
        for stale_dir in (data_source_l.output_base_dir, self.raw_data_dir_l,
                          data_source_f.output_base_dir, self.raw_data_dir_f):
            if gfile.Exists(stale_dir):
                gfile.DeleteRecursively(stale_dir)

        raw_data_options = dj_pb.RawDataOptions(
            raw_data_iter='TF_RECORD',
            read_ahead_size=1 << 20,
            read_batch_size=128)
        id_dump_options = dj_pb.ExampleIdDumpOptions(
            example_id_dump_interval=1,
            example_id_dump_threshold=1024)
        joiner_options = dj_pb.ExampleJoinerOptions(
            example_joiner='STREAM_JOINER',
            min_matching_window=64,
            max_matching_window=256,
            data_block_dump_interval=30,
            data_block_dump_threshold=1000)
        batch_options = dj_pb.BatchProcessorOptions(
            batch_size=512,
            max_flying_item=2048)
        block_builder_options = dj_pb.WriterOptions(output_writer='TF_RECORD')
        self.worker_options = dj_pb.DataJoinWorkerOptions(
            use_mock_etcd=True,
            raw_data_options=raw_data_options,
            example_id_dump_options=id_dump_options,
            example_joiner_options=joiner_options,
            batch_processor_options=batch_options,
            data_block_builder_options=block_builder_options)

        self.total_index = 1 << 12