def test_all_assembly(self):
        """Drive one leader/follower worker pair through a full join.

        Generates raw data for every partition on both sides, starts one
        worker per role, marks all raw data finished on both masters, and
        blocks until both masters report the Finished state.
        """
        partition_count = self.data_source_l.data_source_meta.partition_num
        for partition_id in range(partition_count):
            self.generate_raw_data(
                self.etcd_l, self.raw_data_controller_l, self.data_source_l,
                partition_id, 2048, 64,
                'leader_key_partition_{}'.format(partition_id) + ':{}',
                'leader_value_partition_{}'.format(partition_id) + ':{}')
            self.generate_raw_data(
                self.etcd_f, self.raw_data_controller_f, self.data_source_f,
                partition_id, 4096, 128,
                'follower_key_partition_{}'.format(partition_id) + ':{}',
                'follower_value_partition_{}'.format(partition_id) + ':{}')

        addr_l = 'localhost:4161'
        addr_f = 'localhost:4162'

        def _make_worker(listen_addr, peer_addr, master_addr, etcd_base_dir):
            # All workers run at rank 0 with the shared worker options;
            # only the addresses and etcd base dir differ per role.
            return data_join_worker.DataJoinWorkerService(
                int(listen_addr.split(':')[1]), peer_addr, master_addr, 0,
                self.etcd_name, etcd_base_dir, self.etcd_addrs,
                self.worker_options)

        worker_l = _make_worker(addr_l, addr_f, self.master_addr_l,
                                self.etcd_base_dir_l)
        worker_f = _make_worker(addr_f, addr_l, self.master_addr_f,
                                self.etcd_base_dir_f)
        worker_l.start()
        worker_f.start()

        for partition_id in range(
                self.data_source_f.data_source_meta.partition_num):
            finish_req = dj_pb.RawDataRequest(
                data_source_meta=self.data_source_l.data_source_meta,
                partition_id=partition_id,
                finish_raw_data=empty_pb2.Empty())
            # Both masters must acknowledge the finish for each partition.
            for client in (self.master_client_l, self.master_client_f):
                rsp = client.FinishRawData(finish_req)
                self.assertEqual(rsp.code, 0)

        # Poll every 2s until both sides reach Finished.
        finished = common_pb.DataSourceState.Finished
        while True:
            dss_l = self.master_client_l.GetDataSourceStatus(
                dj_pb.DataSourceRequest(
                    data_source_meta=self.data_source_l.data_source_meta))
            dss_f = self.master_client_f.GetDataSourceStatus(
                dj_pb.DataSourceRequest(
                    data_source_meta=self.data_source_f.data_source_meta))
            self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
            self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
            if dss_l.state == finished and dss_f.state == finished:
                break
            time.sleep(2)

        for service in (worker_l, worker_f, self.master_l, self.master_f):
            service.stop()
# --- Esempio n. 2 (scraper separator; stray vote count "0" in original) ---
    def test_all_assembly(self):
        """Run a single-worker-per-role join to completion.

        Generates raw data for every partition on both sides, starts one
        leader and one follower worker configured for TF_RECORD input and
        the stream joiner, and blocks until both masters report Finished.
        """
        num_partitions = self.data_source_l.data_source_meta.partition_num
        for pid in range(num_partitions):
            self.generate_raw_data(
                self.data_source_l, pid, 2048, 64,
                'leader_key_partition_{}'.format(pid) + ':{}',
                'leader_value_partition_{}'.format(pid) + ':{}')
            self.generate_raw_data(
                self.data_source_f, pid, 4096, 128,
                'follower_key_partition_{}'.format(pid) + ':{}',
                'follower_value_partition_{}'.format(pid) + ':{}')

        addr_l = 'localhost:4161'
        addr_f = 'localhost:4162'

        options = customized_options.CustomizedOptions()
        options.set_raw_data_iter('TF_RECORD')
        options.set_example_joiner('STREAM_JOINER')

        worker_l = data_join_worker.DataJoinWorkerService(
            int(addr_l.split(':')[1]), addr_f, self.master_addr_l, 0,
            self.etcd_name, self.etcd_base_dir_l, self.etcd_addrs, options)
        worker_f = data_join_worker.DataJoinWorkerService(
            int(addr_f.split(':')[1]), addr_l, self.master_addr_f, 0,
            self.etcd_name, self.etcd_base_dir_f, self.etcd_addrs, options)
        worker_l.start()
        worker_f.start()

        # Poll every 2s until both data sources reach Finished.
        finished = common_pb.DataSourceState.Finished
        while True:
            rsp_l = self.master_client_l.GetDataSourceState(
                self.data_source_l.data_source_meta)
            rsp_f = self.master_client_f.GetDataSourceState(
                self.data_source_f.data_source_meta)
            for rsp, role in ((rsp_l, common_pb.FLRole.Leader),
                              (rsp_f, common_pb.FLRole.Follower)):
                self.assertEqual(rsp.status.code, 0)
                self.assertEqual(rsp.role, role)
                self.assertEqual(rsp.data_source_type,
                                 common_pb.DataSourceType.Sequential)
            if rsp_l.state == finished and rsp_f.state == finished:
                break
            time.sleep(2)

        worker_l.stop()
        worker_f.stop()
        self.master_l.stop()
        self.master_f.stop()
# --- Esempio n. 3 (scraper separator; stray vote count "0" in original) ---
 def _launch_workers(self):
     """Launch four leader/follower worker pairs on fixed local ports.

     All workers share one option set: CSV_DICT raw-data iterator, the
     sort-run joiner, and the CSV_DICT data-block builder. Leader and
     follower workers at the same rank are peered with each other.
     """
     worker_options = dj_pb.DataJoinWorkerOptions(
         use_mock_etcd=True,
         raw_data_options=dj_pb.RawDataOptions(raw_data_iter='CSV_DICT',
                                               compressed_type=''),
         example_id_dump_options=dj_pb.ExampleIdDumpOptions(
             example_id_dump_interval=1, example_id_dump_threshold=1024),
         example_joiner_options=dj_pb.ExampleJoinerOptions(
             example_joiner='SORT_RUN_JOINER',
             min_matching_window=64,
             max_matching_window=256,
             data_block_dump_interval=30,
             data_block_dump_threshold=1000),
         batch_processor_options=dj_pb.BatchProcessorOptions(
             batch_size=1024, max_flying_item=4096),
         data_block_builder_options=dj_pb.DataBlockBuilderOptions(
             data_block_builder='CSV_DICT_DATABLOCK_BUILDER'))
     self._worker_addrs_l = [
         'localhost:4161', 'localhost:4162', 'localhost:4163',
         'localhost:4164'
     ]
     self._worker_addrs_f = [
         'localhost:5161', 'localhost:5162', 'localhost:5163',
         'localhost:5164'
     ]
     self._workers_l = []
     self._workers_f = []
     paired_addrs = zip(self._worker_addrs_l, self._worker_addrs_f)
     for rank_id, (addr_l, addr_f) in enumerate(paired_addrs):
         # Each worker listens on its own port and peers with the
         # opposite role's worker at the same rank.
         self._workers_l.append(
             data_join_worker.DataJoinWorkerService(
                 int(addr_l.split(':')[1]), addr_f,
                 self._master_addr_l, rank_id, self._etcd_name,
                 self._etcd_base_dir_l, self._etcd_addrs, worker_options))
         self._workers_f.append(
             data_join_worker.DataJoinWorkerService(
                 int(addr_f.split(':')[1]), addr_l,
                 self._master_addr_f, rank_id, self._etcd_name,
                 self._etcd_base_dir_f, self._etcd_addrs, worker_options))
     # Start all leader workers first, then all follower workers,
     # matching the original startup order.
     for worker in self._workers_l + self._workers_f:
         worker.start()
# --- Esempio n. 4 (scraper separator; stray vote count "0" in original) ---
 def _launch_workers(self):
     """Launch two leader/follower worker pairs on fixed local ports.

     All workers share one option set: TF_RECORD raw-data iterator and
     the stream joiner. Leader and follower workers at the same rank
     are peered with each other.
     """
     worker_options = dj_pb.DataJoinWorkerOptions(
         use_mock_etcd=True,
         raw_data_options=dj_pb.RawDataOptions(raw_data_iter='TF_RECORD',
                                               compressed_type=''),
         example_id_dump_options=dj_pb.ExampleIdDumpOptions(
             example_id_dump_interval=1, example_id_dump_threshold=1024),
         example_joiner_options=dj_pb.ExampleJoinerOptions(
             example_joiner='STREAM_JOINER',
             min_matching_window=64,
             max_matching_window=256,
             data_block_dump_interval=30,
             data_block_dump_threshold=1000),
         example_id_batch_options=dj_pb.ExampleIdBatchOptions(
             example_id_batch_size=1024, max_flying_example_id=4096))
     self._worker_addrs_l = ['localhost:4161', 'localhost:4162']
     self._worker_addrs_f = ['localhost:5161', 'localhost:5162']
     self._workers_l = []
     self._workers_f = []
     paired_addrs = zip(self._worker_addrs_l, self._worker_addrs_f)
     for rank_id, (addr_l, addr_f) in enumerate(paired_addrs):
         # Each worker listens on its own port and peers with the
         # opposite role's worker at the same rank.
         self._workers_l.append(
             data_join_worker.DataJoinWorkerService(
                 int(addr_l.split(':')[1]), addr_f,
                 self._master_addr_l, rank_id, self._etcd_name,
                 self._etcd_base_dir_l, self._etcd_addrs, worker_options))
         self._workers_f.append(
             data_join_worker.DataJoinWorkerService(
                 int(addr_f.split(':')[1]), addr_l,
                 self._master_addr_f, rank_id, self._etcd_name,
                 self._etcd_base_dir_f, self._etcd_addrs, worker_options))
     # Start all leader workers first, then all follower workers,
     # matching the original startup order.
     for worker in self._workers_l + self._workers_f:
         worker.start()
# --- Esempio n. 5 (scraper separator; stray vote count "0" in original) ---
    def _inner_test_round(self, start_index):
        """Run one batch-mode data-join round starting at *start_index*.

        Generates raw data for every partition on both sides, launches a
        leader/follower master pair in batch mode, waits for both to reach
        Processing, runs one worker per role to completion, waits for the
        Ready state, then joins the worker threads and stops the masters.

        The two former copy-pasted polling loops are factored into the
        nested ``_wait_for_state`` helper.
        """
        for i in range(self.data_source_l.data_source_meta.partition_num):
            self.generate_raw_data(
                start_index, self.etcd_l, self.raw_data_publisher_l,
                self.data_source_l, self.raw_data_dir_l, i, 2048, 64,
                'leader_key_partition_{}'.format(i) + ':{}',
                'leader_value_partition_{}'.format(i) + ':{}')
            self.generate_raw_data(
                start_index, self.etcd_f, self.raw_data_publisher_f,
                self.data_source_f, self.raw_data_dir_f, i, 4096, 128,
                'follower_key_partition_{}'.format(i) + ':{}',
                'follower_value_partition_{}'.format(i) + ':{}')

        master_addr_l = 'localhost:4061'
        master_addr_f = 'localhost:4062'
        master_options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True,
                                                     batch_mode=True)
        master_l = data_join_master.DataJoinMasterService(
            int(master_addr_l.split(':')[1]), master_addr_f,
            self.data_source_name, self.etcd_name, self.etcd_base_dir_l,
            self.etcd_addrs, master_options)
        master_l.start()
        master_f = data_join_master.DataJoinMasterService(
            int(master_addr_f.split(':')[1]), master_addr_l,
            self.data_source_name, self.etcd_name, self.etcd_base_dir_f,
            self.etcd_addrs, master_options)
        master_f.start()
        channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
        master_client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
        channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
        master_client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)

        def _wait_for_state(target_state):
            # Poll both masters every 2s until each reports target_state.
            # RPC errors during startup/teardown are expected, so the
            # whole probe is deliberately best-effort: any exception
            # (including a transient assertion on a partial response)
            # just causes a retry, matching the original behavior.
            while True:
                try:
                    dss_l = master_client_l.GetDataSourceStatus(
                        dj_pb.DataSourceRequest(
                            data_source_meta=self.data_source_l
                                .data_source_meta))
                    dss_f = master_client_f.GetDataSourceStatus(
                        dj_pb.DataSourceRequest(
                            data_source_meta=self.data_source_f
                                .data_source_meta))
                    self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
                    self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
                    if dss_l.state == target_state and \
                            dss_f.state == target_state:
                        return
                except Exception:  # best-effort poll; retry on any error
                    pass
                time.sleep(2)

        _wait_for_state(common_pb.DataSourceState.Processing)

        worker_addr_l = 'localhost:4161'
        worker_addr_f = 'localhost:4162'

        worker_l = data_join_worker.DataJoinWorkerService(
            int(worker_addr_l.split(':')[1]),
            worker_addr_f, master_addr_l, 0,
            self.etcd_name, self.etcd_base_dir_l,
            self.etcd_addrs, self.worker_options)
        worker_f = data_join_worker.DataJoinWorkerService(
            int(worker_addr_f.split(':')[1]),
            worker_addr_l, master_addr_f, 0,
            self.etcd_name, self.etcd_base_dir_f,
            self.etcd_addrs, self.worker_options)

        th_l = threading.Thread(target=worker_l.run, name='worker_l')
        th_f = threading.Thread(target=worker_f.run, name='worker_f')
        th_l.start()
        th_f.start()

        # In batch mode the workers exit on their own; Ready signals the
        # round is complete.
        _wait_for_state(common_pb.DataSourceState.Ready)

        th_l.join()
        th_f.join()
        master_l.stop()
        master_f.stop()
# --- Esempio n. 6 (scraper separator; stray vote count "0" in original) ---
    def _launch_workers(self):
        """Launch four worker pairs with cross-format I/O.

        The leader reads TF_RECORD raw data and writes CSV_DICT data
        blocks; the follower reads CSV_DICT and writes TF_RECORD. The
        ETCD_BASE_DIR environment variable is switched to each role's
        base dir just before that role's worker is constructed.
        """

        def _make_options(input_iter, output_writer):
            # Leader/follower options differ only in the raw-data
            # iterator and the data-block output writer.
            return dj_pb.DataJoinWorkerOptions(
                use_mock_etcd=True,
                raw_data_options=dj_pb.RawDataOptions(
                    raw_data_iter=input_iter,
                    read_ahead_size=1 << 20,
                    read_batch_size=128),
                example_id_dump_options=dj_pb.ExampleIdDumpOptions(
                    example_id_dump_interval=1,
                    example_id_dump_threshold=1024),
                example_joiner_options=dj_pb.ExampleJoinerOptions(
                    example_joiner='SORT_RUN_JOINER',
                    min_matching_window=64,
                    max_matching_window=256,
                    data_block_dump_interval=30,
                    data_block_dump_threshold=1000),
                batch_processor_options=dj_pb.BatchProcessorOptions(
                    batch_size=1024, max_flying_item=4096),
                data_block_builder_options=dj_pb.WriterOptions(
                    output_writer=output_writer))

        worker_options_l = _make_options('TF_RECORD', 'CSV_DICT')
        worker_options_f = _make_options('CSV_DICT', 'TF_RECORD')

        self._worker_addrs_l = [
            'localhost:4161', 'localhost:4162', 'localhost:4163',
            'localhost:4164'
        ]
        self._worker_addrs_f = [
            'localhost:5161', 'localhost:5162', 'localhost:5163',
            'localhost:5164'
        ]
        self._workers_l = []
        self._workers_f = []
        for rank_id, (addr_l, addr_f) in enumerate(
                zip(self._worker_addrs_l, self._worker_addrs_f)):
            # Point the env var at the leader's kvstore dir before
            # constructing its worker, then switch for the follower.
            os.environ['ETCD_BASE_DIR'] = self.leader_base_dir
            self._workers_l.append(
                data_join_worker.DataJoinWorkerService(
                    int(addr_l.split(':')[1]), addr_f,
                    self._master_addr_l, rank_id, self.kvstore_type,
                    worker_options_l))
            os.environ['ETCD_BASE_DIR'] = self.follower_base_dir
            self._workers_f.append(
                data_join_worker.DataJoinWorkerService(
                    int(addr_f.split(':')[1]), addr_l,
                    self._master_addr_f, rank_id, self.kvstore_type,
                    worker_options_f))
        # Start all leader workers first, then all follower workers,
        # matching the original startup order.
        for worker in self._workers_l + self._workers_f:
            worker.start()