Exemplo n.º 1
0
    def _launch_masters(self):
        """Start the leader and follower masters and wait until both process.

        The two DataJoinMasterService instances use localhost ports 4061
        (leader) and 4062 (follower), and each is given the other's address.
        Both are built from this fixture's database settings and a mock-etcd
        options message. After starting them (follower first), the method
        creates an internal gRPC stub per master. It then polls
        GetDataSourceStatus every 2 seconds until both data sources report
        ``Processing``. On every poll it asserts that each master reports
        its expected role.
        """
        self._master_addr_l = 'localhost:4061'
        self._master_addr_f = 'localhost:4062'
        master_options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)
        leader_port = int(self._master_addr_l.split(':')[1])
        follower_port = int(self._master_addr_f.split(':')[1])

        self._master_l = data_join_master.DataJoinMasterService(
            leader_port, self._master_addr_f,
            self._data_source_name, self._db_database, self._db_base_dir_l,
            self._db_addr, self._db_username_l, self._db_password_l,
            master_options)
        self._master_f = data_join_master.DataJoinMasterService(
            follower_port, self._master_addr_l,
            self._data_source_name, self._db_database, self._db_base_dir_f,
            self._db_addr, self._db_username_f, self._db_password_f,
            master_options)
        self._master_f.start()
        self._master_l.start()

        self._master_client_l = dj_grpc.DataJoinMasterServiceStub(
            make_insecure_channel(self._master_addr_l, ChannelType.INTERNAL))
        self._master_client_f = dj_grpc.DataJoinMasterServiceStub(
            make_insecure_channel(self._master_addr_f, ChannelType.INTERNAL))

        # The status requests never change between polls, so build them once.
        leader_req = dj_pb.DataSourceRequest(
            data_source_meta=self._data_source_l.data_source_meta)
        follower_req = dj_pb.DataSourceRequest(
            data_source_meta=self._data_source_f.data_source_meta)
        processing = common_pb.DataSourceState.Processing
        while True:
            leader_status = self._master_client_l.GetDataSourceStatus(
                leader_req)
            follower_status = self._master_client_f.GetDataSourceStatus(
                follower_req)
            self.assertEqual(leader_status.role, common_pb.FLRole.Leader)
            self.assertEqual(follower_status.role, common_pb.FLRole.Follower)
            if leader_status.state == processing and \
                    follower_status.state == processing:
                break
            time.sleep(2)
        logging.info("masters turn into Processing state")
Exemplo n.º 2
0
    def setUp(self):
        """Commit fresh leader/follower data sources and start both masters.

        Steps:
        - Wipe any data-source state left under both etcd base dirs.
        - Commit a two-partition data source per role.
        - Start one DataJoinMasterService per role (ports 4061/4062) and block
          until both data sources report ``Processing``.
        - Store the fixtures on ``self`` and create one RawDataPublisher per
          role.
        - Delete leftover output directories and build the shared
          DataJoinWorkerOptions.

        If waiting for ``Processing`` raises, both masters are stopped before
        the error propagates. unittest does not run ``tearDown`` when
        ``setUp`` fails, so the masters would otherwise keep the fixed ports
        bound for every later test.
        """
        etcd_name = 'test_etcd'
        etcd_addrs = 'localhost:2379'
        etcd_base_dir_l = 'byefl_l'
        etcd_base_dir_f = 'byefl_f'
        data_source_name = 'test_data_source'
        etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
        etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
        # Start from a clean slate: drop data-source keys from earlier runs.
        etcd_l.delete_prefix(
            common.data_source_etcd_base_dir(data_source_name))
        etcd_f.delete_prefix(
            common.data_source_etcd_base_dir(data_source_name))
        data_source_l = common_pb.DataSource()
        self.raw_data_pub_dir_l = './raw_data_pub_dir_l'
        data_source_l.raw_data_sub_dir = self.raw_data_pub_dir_l
        data_source_l.role = common_pb.FLRole.Leader
        data_source_l.state = common_pb.DataSourceState.Init
        data_source_l.data_block_dir = "./data_block_l"
        data_source_l.raw_data_dir = "./raw_data_l"
        data_source_l.example_dumped_dir = "./example_dumped_l"
        data_source_f = common_pb.DataSource()
        self.raw_data_pub_dir_f = './raw_data_pub_dir_f'
        data_source_f.role = common_pb.FLRole.Follower
        data_source_f.raw_data_sub_dir = self.raw_data_pub_dir_f
        data_source_f.state = common_pb.DataSourceState.Init
        data_source_f.data_block_dir = "./data_block_f"
        data_source_f.raw_data_dir = "./raw_data_f"
        data_source_f.example_dumped_dir = "./example_dumped_f"
        # Both roles share one meta: name, partition count and time window.
        data_source_meta = common_pb.DataSourceMeta()
        data_source_meta.name = data_source_name
        data_source_meta.partition_num = 2
        data_source_meta.start_time = 0
        data_source_meta.end_time = 100000000
        data_source_l.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(etcd_l, data_source_l)
        data_source_f.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(etcd_f, data_source_f)
        master_options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)

        master_addr_l = 'localhost:4061'
        master_addr_f = 'localhost:4062'
        # Each master gets its own port plus the peer master's address.
        master_l = data_join_master.DataJoinMasterService(
            int(master_addr_l.split(':')[1]),
            master_addr_f,
            data_source_name,
            etcd_name,
            etcd_base_dir_l,
            etcd_addrs,
            master_options,
        )
        master_l.start()
        master_f = data_join_master.DataJoinMasterService(
            int(master_addr_f.split(':')[1]), master_addr_l, data_source_name,
            etcd_name, etcd_base_dir_f, etcd_addrs, master_options)
        master_f.start()
        channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
        master_client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
        channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
        master_client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)

        # The status requests are loop-invariant; build them once.
        req_l = dj_pb.DataSourceRequest(
            data_source_meta=data_source_l.data_source_meta)
        req_f = dj_pb.DataSourceRequest(
            data_source_meta=data_source_f.data_source_meta)
        processing = common_pb.DataSourceState.Processing
        try:
            while True:
                dss_l = master_client_l.GetDataSourceStatus(req_l)
                dss_f = master_client_f.GetDataSourceStatus(req_f)
                self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
                self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
                if dss_l.state == processing and dss_f.state == processing:
                    break
                time.sleep(2)
        except BaseException:
            # tearDown will not run after a failed setUp (this also covers
            # Ctrl-C during the wait), so release the masters here.
            master_l.stop()
            master_f.stop()
            raise

        self.master_client_l = master_client_l
        self.master_client_f = master_client_f
        self.master_addr_l = master_addr_l
        self.master_addr_f = master_addr_f
        self.etcd_l = etcd_l
        self.etcd_f = etcd_f
        self.data_source_l = data_source_l
        self.data_source_f = data_source_f
        self.master_l = master_l
        self.master_f = master_f
        # Store the plain string; a stray trailing comma here would make it a
        # 1-tuple and break every consumer that passes it on as a name.
        self.data_source_name = data_source_name
        self.etcd_name = etcd_name
        self.etcd_addrs = etcd_addrs
        self.etcd_base_dir_l = etcd_base_dir_l
        self.etcd_base_dir_f = etcd_base_dir_f
        self.raw_data_publisher_l = raw_data_publisher.RawDataPublisher(
            self.etcd_l, self.raw_data_pub_dir_l)
        self.raw_data_publisher_f = raw_data_publisher.RawDataPublisher(
            self.etcd_f, self.raw_data_pub_dir_f)
        # Remove output directories left behind by earlier runs.
        for data_source in (data_source_l, data_source_f):
            for stale_dir in (data_source.data_block_dir,
                              data_source.example_dumped_dir,
                              data_source.raw_data_dir):
                if gfile.Exists(stale_dir):
                    gfile.DeleteRecursively(stale_dir)

        self.worker_options = dj_pb.DataJoinWorkerOptions(
            use_mock_etcd=True,
            raw_data_options=dj_pb.RawDataOptions(raw_data_iter='TF_RECORD',
                                                  compressed_type=''),
            example_id_dump_options=dj_pb.ExampleIdDumpOptions(
                example_id_dump_interval=1, example_id_dump_threshold=1024),
            example_joiner_options=dj_pb.ExampleJoinerOptions(
                example_joiner='STREAM_JOINER',
                min_matching_window=64,
                max_matching_window=256,
                data_block_dump_interval=30,
                data_block_dump_threshold=1000),
            batch_processor_options=dj_pb.BatchProcessorOptions(
                batch_size=512, max_flying_item=2048),
            data_block_builder_options=dj_pb.WriterOptions(
                output_writer='TF_RECORD'))

        self.total_index = 1 << 13
Exemplo n.º 3
0
    def _inner_test_round(self, start_index):
        """Run one batch-mode data-join round for a leader/follower pair.

        Steps:
        - Generate raw data for every partition on both sides, starting at
          ``start_index``.
        - Start both masters in batch mode and wait until both data sources
          are ``Processing``.
        - Run one DataJoinWorkerService per side in a thread.
        - Wait until both data sources are ``Ready``, then join the worker
          threads.

        Both masters are stopped in a ``finally`` so that a failed round does
        not leave the fixed ports bound.

        Args:
            start_index: first example index passed to ``generate_raw_data``.
        """
        # Local import: the module-level import block is not visible here.
        import logging

        for i in range(self.data_source_l.data_source_meta.partition_num):
            self.generate_raw_data(
                    start_index, self.etcd_l, self.raw_data_publisher_l,
                    self.data_source_l, self.raw_data_dir_l, i, 2048, 64,
                    'leader_key_partition_{}'.format(i) + ':{}',
                    'leader_value_partition_{}'.format(i) + ':{}'
                )
            self.generate_raw_data(
                    start_index, self.etcd_f, self.raw_data_publisher_f,
                    self.data_source_f, self.raw_data_dir_f, i, 4096, 128,
                    'follower_key_partition_{}'.format(i) + ':{}',
                    'follower_value_partition_{}'.format(i) + ':{}'
                )

        master_addr_l = 'localhost:4061'
        master_addr_f = 'localhost:4062'
        master_options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True,
                                                     batch_mode=True)
        master_l = data_join_master.DataJoinMasterService(
                int(master_addr_l.split(':')[1]), master_addr_f,
                self.data_source_name, self.etcd_name, self.etcd_base_dir_l,
                self.etcd_addrs, master_options,
            )
        master_l.start()
        master_f = data_join_master.DataJoinMasterService(
                int(master_addr_f.split(':')[1]), master_addr_l,
                self.data_source_name, self.etcd_name, self.etcd_base_dir_f,
                self.etcd_addrs, master_options
            )
        master_f.start()
        try:
            channel_l = make_insecure_channel(master_addr_l,
                                              ChannelType.INTERNAL)
            master_client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
            channel_f = make_insecure_channel(master_addr_f,
                                              ChannelType.INTERNAL)
            master_client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)

            def wait_for_state(expected_state):
                """Poll both masters until both report *expected_state*.

                Only the RPC calls are retried: a failing GetDataSourceStatus
                is logged and retried after 2 seconds (presumably a master
                that is still coming up). Role assertions sit outside the
                retry handler, so a wrong role fails the test instead of
                being swallowed and looping forever.
                """
                req_l = dj_pb.DataSourceRequest(
                        data_source_meta=self.data_source_l.data_source_meta
                    )
                req_f = dj_pb.DataSourceRequest(
                        data_source_meta=self.data_source_f.data_source_meta
                    )
                while True:
                    try:
                        dss_l = master_client_l.GetDataSourceStatus(req_l)
                        dss_f = master_client_f.GetDataSourceStatus(req_f)
                    except Exception as e:  # pylint: disable=broad-except
                        logging.warning(
                            "GetDataSourceStatus failed, will retry: %s", e)
                    else:
                        self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
                        self.assertEqual(dss_f.role,
                                         common_pb.FLRole.Follower)
                        if dss_l.state == expected_state and \
                                dss_f.state == expected_state:
                            return
                    time.sleep(2)

            wait_for_state(common_pb.DataSourceState.Processing)

            worker_addr_l = 'localhost:4161'
            worker_addr_f = 'localhost:4162'

            worker_l = data_join_worker.DataJoinWorkerService(
                    int(worker_addr_l.split(':')[1]),
                    worker_addr_f, master_addr_l, 0,
                    self.etcd_name, self.etcd_base_dir_l,
                    self.etcd_addrs, self.worker_options
                )

            worker_f = data_join_worker.DataJoinWorkerService(
                    int(worker_addr_f.split(':')[1]),
                    worker_addr_l, master_addr_f, 0,
                    self.etcd_name, self.etcd_base_dir_f,
                    self.etcd_addrs, self.worker_options
                )

            th_l = threading.Thread(target=worker_l.run, name='worker_l')
            th_f = threading.Thread(target=worker_f.run, name='worker_f')

            th_l.start()
            th_f.start()

            wait_for_state(common_pb.DataSourceState.Ready)

            th_l.join()
            th_f.join()
        finally:
            master_l.stop()
            master_f.stop()
Exemplo n.º 4
0
    def test_api(self):
        """Exercise the DataJoinMaster RPC API for a leader/follower pair.

        Setup: commit a fresh single-partition data source for each role,
        start one master per role, and wait until both report
        ``Processing``.

        API walk-through: RequestJoinPartition, FinishJoinPartition,
        AddRawData, FinishRawData, QueryRawDataManifest and
        GetRawDataLatestTimeStamp. Each mutating call that must succeed is
        issued twice to check idempotency.

        Finally, wait for both data sources to reach ``Finished``. Both
        masters are stopped in a ``finally``, so a failing assertion does not
        leave the fixed ports bound.
        """
        logging.getLogger().setLevel(logging.DEBUG)
        etcd_name = 'test_etcd'
        etcd_addrs = 'localhost:2379'
        etcd_base_dir_l = 'byefl_l'
        etcd_base_dir_f = 'byefl_f'
        data_source_name = 'test_data_source'
        etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
        etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
        # Start from a clean slate: drop data-source keys from earlier runs.
        etcd_l.delete_prefix(
            common.data_source_etcd_base_dir(data_source_name))
        etcd_f.delete_prefix(
            common.data_source_etcd_base_dir(data_source_name))
        data_source_l = common_pb.DataSource()
        data_source_l.role = common_pb.FLRole.Leader
        data_source_l.state = common_pb.DataSourceState.Init
        data_source_l.output_base_dir = "./ds_output_l"
        data_source_f = common_pb.DataSource()
        data_source_f.role = common_pb.FLRole.Follower
        data_source_f.state = common_pb.DataSourceState.Init
        data_source_f.output_base_dir = "./ds_output_f"
        # Both roles share one meta, so their data_source_meta are equal.
        data_source_meta = common_pb.DataSourceMeta()
        data_source_meta.name = data_source_name
        data_source_meta.partition_num = 1
        data_source_meta.start_time = 0
        data_source_meta.end_time = 100000000
        data_source_l.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(etcd_l, data_source_l)
        data_source_f.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(etcd_f, data_source_f)

        master_addr_l = 'localhost:4061'
        master_addr_f = 'localhost:4062'
        options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)
        master_l = data_join_master.DataJoinMasterService(
            int(master_addr_l.split(':')[1]), master_addr_f, data_source_name,
            etcd_name, etcd_base_dir_l, etcd_addrs, options)
        master_l.start()
        master_f = data_join_master.DataJoinMasterService(
            int(master_addr_f.split(':')[1]), master_addr_l, data_source_name,
            etcd_name, etcd_base_dir_f, etcd_addrs, options)
        master_f.start()
        channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
        client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
        channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
        client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)

        def wait_for_state(state):
            """Poll both masters every 2 s until both report *state*."""
            req_l = dj_pb.DataSourceRequest(
                data_source_meta=data_source_l.data_source_meta)
            req_f = dj_pb.DataSourceRequest(
                data_source_meta=data_source_f.data_source_meta)
            while True:
                dss_l = client_l.GetDataSourceStatus(req_l)
                dss_f = client_f.GetDataSourceStatus(req_f)
                self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
                self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
                if dss_l.state == state and dss_f.state == state:
                    return
                time.sleep(2)

        def raw_data_request(data_source, rank_id=0, partition_id=0,
                             **kwargs):
            """Build a RawDataRequest carrying *data_source*'s meta."""
            return dj_pb.RawDataRequest(
                data_source_meta=data_source.data_source_meta,
                rank_id=rank_id,
                partition_id=partition_id,
                **kwargs)

        def add_request(data_source, dedup, metas):
            """Build an AddRawData request from [(file_path, seconds)]."""
            return raw_data_request(
                data_source,
                added_raw_data_metas=dj_pb.AddedRawDataMetas(
                    dedup=dedup,
                    raw_data_metas=[
                        dj_pb.RawDataMeta(
                            file_path=file_path,
                            timestamp=timestamp_pb2.Timestamp(
                                seconds=seconds))
                        for file_path, seconds in metas
                    ]))

        def check_request_partition(client, rdreq, rep_field, state):
            """Request a partition twice (idempotency) and check the reply.

            Expects success, partition 0, and a ``rep_field`` replica whose
            rank matches the request and whose state equals *state*.
            """
            for _ in range(2):
                rdrsp = client.RequestJoinPartition(rdreq)
                self.assertEqual(rdrsp.status.code, 0)
                self.assertTrue(rdrsp.HasField('manifest'))
                rep = getattr(rdrsp.manifest, rep_field)
                self.assertEqual(rep.rank_id, rdreq.rank_id)
                self.assertEqual(rep.state, state)
                self.assertEqual(rdrsp.manifest.partition_id, 0)

        def check_manifest(client, rdreq, finished, next_process_index):
            """Check QueryRawDataManifest's finished flag and next index."""
            manifest = client.QueryRawDataManifest(rdreq)
            self.assertIsNotNone(manifest)
            self.assertEqual(manifest.finished, finished)
            self.assertEqual(manifest.next_process_index, next_process_index)

        def check_latest_timestamp(client, rdreq, seconds):
            """Check GetRawDataLatestTimeStamp returns *seconds* (0 nanos)."""
            rsp = client.GetRawDataLatestTimeStamp(rdreq)
            self.assertEqual(rsp.status.code, 0)
            self.assertTrue(rsp.HasField('timestamp'))
            self.assertEqual(rsp.timestamp.seconds, seconds)
            self.assertEqual(rsp.timestamp.nanos, 0)

        def add_raw_data(client, data_source, dedup, metas,
                         next_process_index):
            """Add raw data successfully and check the resulting manifest."""
            rdreq = add_request(data_source, dedup, metas)
            self.assertEqual(client.AddRawData(rdreq).code, 0)
            check_manifest(client, rdreq, False, next_process_index)
            return rdreq

        def check_succeeds_twice(rpc, rdreq):
            """Invoke *rpc* twice; both calls must return code 0."""
            for _ in range(2):
                self.assertEqual(rpc(rdreq).code, 0)

        try:
            wait_for_state(common_pb.DataSourceState.Processing)

            # Both roles claim partition 0 for example joining (rank 0) and
            # example-id syncing (rank 1). With a single partition, a request
            # for partition -1 must also be answered with partition 0.
            check_request_partition(
                client_f,
                raw_data_request(data_source_f, partition_id=-1,
                                 join_example=empty_pb2.Empty()),
                'join_example_rep', dj_pb.JoinExampleState.Joining)
            check_request_partition(
                client_l,
                raw_data_request(data_source_l,
                                 join_example=empty_pb2.Empty()),
                'join_example_rep', dj_pb.JoinExampleState.Joining)
            check_request_partition(
                client_l,
                raw_data_request(data_source_l, rank_id=1, partition_id=-1,
                                 sync_example_id=empty_pb2.Empty()),
                'sync_example_id_rep', dj_pb.SyncExampleIdState.Syncing)
            check_request_partition(
                client_f,
                raw_data_request(data_source_f, rank_id=1,
                                 sync_example_id=empty_pb2.Empty()),
                'sync_example_id_rep', dj_pb.SyncExampleIdState.Syncing)

            # Finish requests for both phases; both are sent with the
            # leader's meta, to either master.
            finish_sync_req = raw_data_request(
                data_source_l, rank_id=1, sync_example_id=empty_pb2.Empty())
            finish_join_req = raw_data_request(
                data_source_l, join_example=empty_pb2.Empty())
            # Neither phase may finish yet; they succeed only later, after
            # FinishRawData has been called.
            with self.assertRaises(Exception):
                client_l.FinishJoinPartition(finish_sync_req)
            with self.assertRaises(Exception):
                client_l.FinishJoinPartition(finish_join_req)

            # Leader raw data: appends advance the manifest index and the
            # latest timestamp; re-adding known files with dedup=True does
            # not change either.
            plain_req_l = raw_data_request(data_source_l)
            check_manifest(client_l, plain_req_l, False, 0)
            check_latest_timestamp(client_l, plain_req_l, 0)
            add_raw_data(client_l, data_source_l, False, [('a', 3)], 1)
            check_latest_timestamp(client_l, plain_req_l, 3)
            add_raw_data(client_l, data_source_l, False, [('b', 5)], 2)
            check_latest_timestamp(client_l, plain_req_l, 5)
            dedup_req_l = add_raw_data(client_l, data_source_l, True,
                                       [('b', 5), ('a', 3)], 2)
            check_latest_timestamp(client_l, dedup_req_l, 5)

            # Follower raw data, same pattern.
            plain_req_f = raw_data_request(data_source_f)
            check_manifest(client_f, plain_req_f, False, 0)
            check_latest_timestamp(client_f, plain_req_f, 0)
            add_raw_data(client_f, data_source_f, False, [('a', 1)], 1)
            check_latest_timestamp(client_f, plain_req_f, 1)
            add_raw_data(client_f, data_source_f, False, [('b', 2)], 2)
            check_latest_timestamp(client_f, plain_req_f, 2)
            add_raw_data(client_f, data_source_f, True,
                         [('a', 1), ('b', 2)], 2)
            check_latest_timestamp(client_f, plain_req_f, 2)

            # Finishing the leader's raw data is idempotent and freezes it.
            finish_raw_req_l = raw_data_request(
                data_source_l, finish_raw_data=empty_pb2.Empty())
            check_succeeds_twice(client_l.FinishRawData, finish_raw_req_l)
            check_manifest(client_l, finish_raw_req_l, True, 2)
            late_req_l = add_request(data_source_l, False, [('x', 4)])
            with self.assertRaises(Exception):
                client_l.AddRawData(late_req_l)

            with self.assertRaises(Exception):
                client_f.FinishJoinPartition(finish_join_req)
            check_succeeds_twice(client_l.FinishJoinPartition,
                                 finish_sync_req)
            check_succeeds_twice(client_f.FinishJoinPartition,
                                 finish_sync_req)
            with self.assertRaises(Exception):
                client_f.FinishJoinPartition(finish_join_req)

            # Finishing the follower's raw data is idempotent and freezes it.
            finish_raw_req_f = raw_data_request(
                data_source_f, finish_raw_data=empty_pb2.Empty())
            check_succeeds_twice(client_f.FinishRawData, finish_raw_req_f)
            check_manifest(client_f, finish_raw_req_f, True, 2)
            late_req_f = add_request(data_source_f, True, [('x', 3)])
            with self.assertRaises(Exception):
                client_f.AddRawData(late_req_f)

            check_succeeds_twice(client_f.FinishJoinPartition,
                                 finish_join_req)
            check_succeeds_twice(client_l.FinishJoinPartition,
                                 finish_join_req)

            wait_for_state(common_pb.DataSourceState.Finished)
        finally:
            master_l.stop()
            master_f.stop()
Exemplo n.º 5
0
                        default='test_etcd',
                        help='the name of etcd')
    # etcd connection settings.
    parser.add_argument('--etcd_addrs', type=str, default='localhost:2379',
                        help='the addrs of etcd')
    parser.add_argument('--etcd_base_dir', type=str,
                        default='fedlearner_test',
                        help='the namespace of etcd key')
    # Service settings.
    parser.add_argument('--listen_port', '-p', type=int, default=4032,
                        help='Listen port of data join master')
    parser.add_argument('--data_source_name', type=str,
                        default='test_data_source',
                        help='the name of data source')
    parser.add_argument('--use_mock_etcd', action='store_true',
                        help='use to mock etcd for test')

    # Parse the command line, then build and start the master service.
    # The arguments are passed positionally, so their order matters.
    args = parser.parse_args()
    master_options = dj_pb.DataJoinMasterOptions(
        use_mock_etcd=args.use_mock_etcd,
    )
    master_srv = DataJoinMasterService(
        args.listen_port,
        args.peer_addr,
        args.data_source_name,
        args.etcd_name,
        args.etcd_base_dir,
        args.etcd_addrs,
        master_options,
    )
    master_srv.run()
Exemplo n.º 6
0
    # Command-line entry point: parse flags, then build and run a
    # DataJoinMasterService configured from them.
    # Adjacent string literals inside the parentheses concatenate
    # implicitly, so no backslash continuation is needed.
    logging.basicConfig(format="%(asctime)s %(filename)s "
                               "%(lineno)s %(levelname)s - %(message)s")
    parser = argparse.ArgumentParser(description='DataJoinMasterService cmd.')
    parser.add_argument('peer_addr',
                        type=str,
                        help='the addr(uuid) of peer data join master')
    # Selects the kvstore backend. The value is forwarded to the service,
    # and 'mock' additionally enables use_mock_etcd below.
    parser.add_argument('--kvstore_type',
                        type=str,
                        default='etcd',
                        help="the type of kvstore, e.g. 'etcd' (default), "
                             "'mysql' or 'mock'")
    parser.add_argument('--listen_port',
                        '-p',
                        type=int,
                        default=4032,
                        help='Listen port of data join master')
    parser.add_argument('--data_source_name',
                        type=str,
                        default='test_data_source',
                        help='the name of data source')
    parser.add_argument('--batch_mode',
                        action='store_true',
                        help='make the data join run in batch mode')
    args = parser.parse_args()
    master_options = dj_pb.DataJoinMasterOptions(
        # The mock backend is requested through the kvstore type rather
        # than through a separate flag.
        use_mock_etcd=(args.kvstore_type == 'mock'),
        batch_mode=args.batch_mode)
    master_srv = DataJoinMasterService(args.listen_port, args.peer_addr,
                                       args.data_source_name,
                                       args.kvstore_type, master_options)
    master_srv.run()
                        type=str,
                        default='localhost:2379',
                        help='the addrs of etcd')
    # etcd namespace.
    parser.add_argument('--etcd_base_dir', type=str,
                        default='fedlearner_test',
                        help='the namespace of etcd key')
    # Service settings.
    parser.add_argument('--listen_port', '-p', type=int, default=4032,
                        help='Listen port of data join master')
    parser.add_argument('--data_source_name', type=str,
                        default='test_data_source',
                        help='the name of data source')
    # Boolean switches.
    parser.add_argument('--use_mock_etcd', action='store_true',
                        help='use to mock etcd for test')
    parser.add_argument('--batch_mode', action='store_true',
                        help='make the data join run in batch mode')

    # Parse the command line, then build and start the master service.
    # The arguments are passed positionally, so their order matters.
    args = parser.parse_args()
    master_options = dj_pb.DataJoinMasterOptions(
        use_mock_etcd=args.use_mock_etcd,
        batch_mode=args.batch_mode,
    )
    master_srv = DataJoinMasterService(
        args.listen_port,
        args.peer_addr,
        args.data_source_name,
        args.etcd_name,
        args.etcd_base_dir,
        args.etcd_addrs,
        master_options,
    )
    master_srv.run()