class TestSageMakerEndpointConfigOperator(unittest.TestCase):
    """Unit tests for ``SageMakerEndpointConfigOperator``.

    Every test runs against a fresh operator built in ``setUp`` from the
    module-level ``create_endpoint_config_params`` fixture. ``SageMakerHook``
    methods are patched in the tests that call ``execute``.
    """

    def setUp(self):
        """Build the operator under test from the shared config fixture."""
        self.sagemaker = SageMakerEndpointConfigOperator(
            task_id='test_sagemaker_operator',
            aws_conn_id='sagemaker_test_id',
            config=create_endpoint_config_params,
        )

    def test_parse_config_integers(self):
        """After parsing, each variant's ``InitialInstanceCount`` equals its int form."""
        self.sagemaker.parse_config_integers()
        for variant in self.sagemaker.config['ProductionVariants']:
            instance_count = variant['InitialInstanceCount']
            self.assertEqual(instance_count, int(instance_count))

    # patch.object decorators are applied bottom-up, so the mock for
    # ``create_endpoint_config`` is injected first and ``get_conn`` second.
    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_endpoint_config')
    def test_execute(self, mock_create_endpoint_config, mock_get_conn):
        """``execute`` hands the configured params to the hook on success."""
        mock_create_endpoint_config.return_value = {
            'EndpointConfigArn': 'testarn',
            'ResponseMetadata': {
                'HTTPStatusCode': 200,
            },
        }
        self.sagemaker.execute(None)
        mock_create_endpoint_config.assert_called_once_with(
            create_endpoint_config_params
        )

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_endpoint_config')
    def test_execute_with_failure(self, mock_create_endpoint_config, mock_get_conn):
        """``execute`` raises ``AirflowException`` when the hook reports a failure.

        This test previously patched ``create_model`` and stubbed a 200
        response. ``test_execute`` shows the operator goes through
        ``create_endpoint_config``, so that stub was never consulted and the
        exception could only arise incidentally from the mocked connection.
        Patching the method that is actually called, with a non-200 status,
        makes the failure path explicit.

        NOTE(review): assumes the operator treats any non-200
        ``HTTPStatusCode`` as a failure -- confirm against the operator.
        """
        mock_create_endpoint_config.return_value = {
            'EndpointConfigArn': 'testarn',
            'ResponseMetadata': {
                'HTTPStatusCode': 404,
            },
        }
        with self.assertRaises(AirflowException):
            self.sagemaker.execute(None)
        # The exception must follow from the stubbed response, i.e. the hook
        # really was invoked with the configured params before the failure.
        mock_create_endpoint_config.assert_called_once_with(
            create_endpoint_config_params
        )
def setUp(self):
    """Create the operator fixture and store it on ``self.sagemaker``.

    NOTE(review): this definition sits outside any class and appears to
    duplicate ``TestSageMakerEndpointConfigOperator.setUp`` -- confirm
    whether it is still needed or belongs to another test case.
    """
    operator = SageMakerEndpointConfigOperator(
        config=create_endpoint_config_params,
        aws_conn_id='sagemaker_test_id',
        task_id='test_sagemaker_operator',
    )
    self.sagemaker = operator