class RpcClientTest(unittest.TestCase):
    """Tests ``RpcClient`` against an in-process fake gRPC channel.

    ``fedlearner_webconsole.rpc.client._build_channel`` is patched to return a
    ``grpc_testing`` channel. Each test can then intercept the outgoing unary
    call, inspect its invocation metadata and request, and terminate it with a
    canned response.
    """

    _TEST_PROJECT_NAME = 'test-project'
    _TEST_RECEIVER_NAME = 'test-receiver'
    _TEST_URL = 'localhost:123'
    _TEST_AUTHORITY = 'test-authority'
    _X_HOST_HEADER_KEY = 'x-host'
    _TEST_X_HOST = 'default.fedlearner.webconsole'
    _TEST_SELF_DOMAIN_NAME = 'fl-test-self.com'

    # Shared test database, created once when the class body is evaluated.
    _DB = create_test_db()

    @classmethod
    def setUpClass(cls):
        """Builds the project/participant config and persists the project row."""
        super().setUpClass()
        grpc_spec = GrpcSpec(
            authority=cls._TEST_AUTHORITY,
            extra_headers={cls._X_HOST_HEADER_KEY: cls._TEST_X_HOST})
        participant = Participant(name=cls._TEST_RECEIVER_NAME,
                                  domain_name='fl-test.com',
                                  grpc_spec=grpc_spec)
        # EGRESS_URL is the address the test expects RpcClient to pass to
        # _build_channel (asserted in setUp).
        project_config = Project(name=cls._TEST_PROJECT_NAME,
                                 domain_name=cls._TEST_SELF_DOMAIN_NAME,
                                 token='test-auth-token',
                                 participants=[participant],
                                 variables=[{
                                     'name': 'EGRESS_URL',
                                     'value': cls._TEST_URL
                                 }])

        cls._participant = participant
        cls._project_config = project_config
        cls._project = ProjectModel(name=cls._TEST_PROJECT_NAME)
        cls._project.set_config(project_config)

        # Inserts the project entity
        cls._DB.create_all()
        cls._DB.session.add(cls._project)
        cls._DB.session.commit()

    @classmethod
    def tearDownClass(cls):
        """Removes the DB session and drops all tables created in setUpClass."""
        cls._DB.session.remove()
        cls._DB.drop_all()
        super().tearDownClass()

    def setUp(self):
        """Creates a patched RpcClient wired to a grpc_testing fake channel.

        Cleanups are registered with ``addCleanup`` right after each resource
        is acquired, instead of in ``tearDown``. ``tearDown`` is skipped when
        ``setUp`` raises, which would leave ``_build_channel`` patched for
        every later test and leak the thread pool. Cleanups run LIFO, so the
        patcher is stopped before the pool is shut down, matching the original
        teardown order.
        """
        super().setUp()
        self._client_execution_thread_pool = logging_pool.pool(1)
        self.addCleanup(self._client_execution_thread_pool.shutdown,
                        wait=False)

        # Builds a testing channel
        self._fake_channel = grpc_testing.channel(
            DESCRIPTOR.services_by_name.values(),
            grpc_testing.strict_real_time())

        self._build_channel_patcher = patch(
            'fedlearner_webconsole.rpc.client._build_channel')
        self._mock_build_channel = self._build_channel_patcher.start()
        self.addCleanup(self._build_channel_patcher.stop)
        self._mock_build_channel.return_value = self._fake_channel

        self._client = RpcClient(self._project_config, self._participant)

        self._mock_build_channel.assert_called_once_with(
            self._TEST_URL, self._TEST_AUTHORITY)

    def test_check_connection(self):
        """check_connection sends auth info plus the x-host header and returns
        the status carried in the response."""
        # Run the client call on a worker thread; this thread plays the server.
        call = self._client_execution_thread_pool.submit(
            self._client.check_connection)

        invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
            TARGET_SERVICE.methods_by_name['CheckConnection'])

        self.assertIn((self._X_HOST_HEADER_KEY, self._TEST_X_HOST),
                      invocation_metadata)
        self.assertEqual(
            request,
            CheckConnectionRequest(auth_info=ProjAuthInfo(
                project_name=self._project_config.name,
                source_domain=self._TEST_SELF_DOMAIN_NAME,
                target_domain=self._participant.domain_name,
                auth_token=self._project_config.token)))

        expected_status = Status(code=FedLearnerStatusCode.STATUS_SUCCESS,
                                 msg='test')
        rpc.terminate(response=CheckConnectionResponse(status=expected_status),
                      code=StatusCode.OK,
                      trailing_metadata=(),
                      details=None)
        self.assertEqual(call.result(), expected_status)
def setUp(self):
    """Gives each test its own database from ``create_test_db``.

    ``create_all()`` is invoked on it after it has been stored on ``self``.
    """
    db = self._db = create_test_db()
    db.create_all()