Exemple #1
0
    def test_get_model_id(self, http_hook_class_mock):
        """Check that ``get_model_id`` resolves a description to a model id.

        The patched hook answers the model-listing request with four entries,
        two of which are described as ``'dummy2'`` (ids 6 and 4).  The expected
        result is 6, i.e. the ``'dummy2'`` entry with the larger
        ``register_date``.  The test also verifies how the hook class was
        instantiated and which request was issued through it.
        """
        listing = [
            {"model_id": model_id, "register_date": registered, "description": text}
            for model_id, registered, text in (
                (5, 2.0, "dummy1"),
                (6, 5.0, "dummy2"),
                (3, 1.0, "dummy1"),
                (4, 1.5, "dummy2"),
            )
        ]

        # The hook instance handed out by the patched class returns the
        # listing as a JSON document in the response's ``text`` attribute.
        hook = Mock(base_url='http://rekcurd-dashboard.com')
        hook.run.return_value = Mock(text=json.dumps(listing))
        http_hook_class_mock.return_value = hook

        operator = ModelUploadOperator(
            task_id='rekcurd_api',
            dag=self.dag,
            project_id=1,
            app_id='sample_app',
            model_file_path='test.zip',
        )
        self.assertEqual(operator.get_model_id('dummy2'), 6)

        http_hook_class_mock.assert_called_with('GET', http_conn_id='rekcurd_dashboard')
        # NOTE(review): 'my_token' is presumably supplied by the test fixture's
        # connection setup, which is not visible here -- confirm.
        hook.run.assert_called_with(
            '/api/projects/1/applications/sample_app/models',
            headers={'Authorization': 'Bearer my_token'},
            extra_options={'timeout': 300},
        )
Exemple #2
0
    def test_get_model_data(self):
        """Check that ``get_model_data`` reads the model from ``model_file_path``.

        A temporary file is filled with known bytes and its path is given to
        the operator.  ``get_model_data`` is called with ``None`` as context
        and must return the file content together with the configured
        ``model_description``.
        """
        payload = b'dummy model content'
        description = 'dummy model'

        with NamedTemporaryFile() as model_file:
            model_file.write(payload)
            # Rewinding also flushes the buffered bytes, so the content is
            # available to anyone opening the file by name.
            model_file.seek(0)

            operator = ModelUploadOperator(
                task_id='rekcurd_api',
                dag=self.dag,
                project_id=1,
                app_id='sample_app',
                model_file_path=model_file.name,
                model_description=description,
            )
            model, desc = operator.get_model_data(None)

        self.assertEqual(model, payload)
        self.assertEqual(desc, description)
Exemple #3
0
    def test_get_model_data_from_xcom(self):
        """Check that ``get_model_data`` pulls model and description from XCom.

        The operator is configured with ``model_provide_task_id`` instead of a
        file path, so both values are expected to come from the task
        instance's ``xcom_pull``.  Each XCom key is mapped to a *distinct*
        value: if both pulls returned the same object, the test could not
        tell the model apart from the description and would still pass when
        the keys were swapped or one pull was reused for both results.
        """
        xcom_values = {
            'rekcurd_model_key': 'model binary',
            'rekcurd_model_desc_key': 'description from xcom',
        }

        def fake_xcom_pull(key, task_ids):
            # Answer according to the requested key only; the task id is
            # verified separately through ``assert_has_calls`` below.
            return xcom_values[key]

        task = ModelUploadOperator(task_id='rekcurd_api',
                                   dag=self.dag,
                                   project_id=1, app_id='sample_app',
                                   model_provide_task_id='task_1',
                                   model_description='dummy model')
        ti_mock = Mock()
        ti_mock.xcom_pull.side_effect = fake_xcom_pull
        context = {'ti': ti_mock}

        model, desc = task.get_model_data(context)
        self.assertEqual(model, 'model binary')
        # The XCom-provided description is expected here, not the
        # ``model_description`` passed to the constructor.
        self.assertEqual(desc, 'description from xcom')

        ti_mock.xcom_pull.assert_has_calls([
            call(key='rekcurd_model_key', task_ids='task_1'),
            call(key='rekcurd_model_desc_key', task_ids='task_1')
        ])
Exemple #4
0
    def test_upload(self, http_hook_class_mock, request_class_mock):
        """Check the HTTP request that ``upload`` builds and sends.

        Both the hook class and the request class are patched.  The test
        verifies that ``upload`` creates a POST hook, opens a session with the
        authorization header, builds a multipart request carrying the model
        and its description, prepares it on the session and passes it to
        ``run_and_check`` with a 300 second timeout.
        """
        prepared = Mock()
        session = Mock(**{'prepare_request.return_value': prepared})

        # Hook instance returned by the patched hook class.
        hook = Mock(base_url='http://rekcurd-dashboard.com')
        hook.get_conn.return_value = session
        hook.run_and_check.return_value = Mock(
            text='{"status": true, "message": "success"}')
        http_hook_class_mock.return_value = hook

        # Request object returned by the patched request class.
        request = Mock()
        request_class_mock.return_value = request

        operator = ModelUploadOperator(
            task_id='rekcurd_api',
            dag=self.dag,
            project_id=1,
            app_id='sample_app',
            model_file_path='test.zip',
        )
        model, desc = 'dummy model', 'dummy desc'
        operator.upload(model, desc)

        # NOTE(review): 'my_token' is presumably supplied by the test fixture's
        # connection setup, which is not visible here -- confirm.
        auth_headers = {'Authorization': 'Bearer my_token'}

        http_hook_class_mock.assert_called_with('POST', http_conn_id='rekcurd_dashboard')
        hook.get_conn.assert_called_with(auth_headers)
        session.prepare_request.assert_called_with(request)
        request_class_mock.assert_called_with(
            'POST',
            'http://rekcurd-dashboard.com/api/projects/1/applications/sample_app/models',
            data={'description': desc},
            files={'file': model},
            headers=auth_headers,
        )
        hook.run_and_check.assert_called_with(session, prepared, {'timeout': 300})
Exemple #5
0
    return print_model_id


# Example DAG 'example_model_upload': defines a training task, a task that
# saves the model, two ModelUploadOperator tasks (one per model source), a
# clean-up task and tasks that print the resulting model id.
# It is scheduled with "@once".
with DAG('example_model_upload',
         default_args=default_args,
         schedule_interval="@once") as dag:
    # Python tasks; provide_context=True passes the Airflow context to the
    # callables.
    train = PythonOperator(task_id='train',
                           python_callable=train_func,
                           provide_context=True)
    save = PythonOperator(task_id='save',
                          python_callable=save_model_func,
                          provide_context=True)

    # Upload the saved model file located at MODEL_PATH.
    upload_file = ModelUploadOperator(task_id='upload_file',
                                      app_id=5,
                                      model_file_path=MODEL_PATH,
                                      model_description=MODEL_DESCRIPTION)
    # Upload the trained model data provided by the 'train' task instead of
    # reading it from a file.
    upload_binary = ModelUploadOperator(task_id='upload_binary',
                                        app_id=5,
                                        model_provide_task_id='train',
                                        model_description=MODEL_DESCRIPTION)
    # Remove the model file; trigger_rule='all_done' lets this task run once
    # its upstream tasks have finished, whether they succeeded or failed.
    delete = BashOperator(task_id='delete',
                          bash_command='rm {}'.format(MODEL_PATH),
                          trigger_rule='all_done')

    # get_print_model_id('upload_file') supplies the callable for this task.
    # NOTE(review): presumably it prints the model id produced by the named
    # upload task -- confirm against get_print_model_id.
    print_id_file = PythonOperator(
        task_id='print_id_file',
        python_callable=get_print_model_id('upload_file'),
        provide_context=True)
    print_id_binary = PythonOperator(
Exemple #6
0
        else:
            print(m + ':', ' '.join('{:.5f}'.format(r) for r in result[m]))


# Example DAG 'example_all': trains a model, uploads it, switches a service to
# the uploaded model, waits, and then prepares an evaluation file.
# It is scheduled with "@once".
with DAG('example_all', default_args=default_args,
         schedule_interval="@once") as dag:
    # provide_context=True passes the Airflow context to the callable.
    train = PythonOperator(task_id='train',
                           python_callable=train_func,
                           provide_context=True)

    # Identifiers used by the operators below.
    # NOTE(review): presumably ids of existing dashboard resources -- confirm.
    application_id = 5
    sandbox_service_id = 10
    dev_service_id = 11

    # Upload the model data provided by the 'train' task.
    upload_model = ModelUploadOperator(task_id='upload_model',
                                       app_id=application_id,
                                       model_provide_task_id='train')
    # Switch the sandbox service to the model provided by the 'upload_model'
    # task.
    switch_sandbox_model = ModelSwitchOperator(
        task_id='switch_sandbox_model',
        app_id=application_id,
        service_id=sandbox_service_id,
        model_provide_task_id='upload_model')
    # Wait until the Kubernetes cluster finishes its rolling update; this is a
    # fixed 800-second sleep, not an actual readiness check.
    wait = BashOperator(task_id='wait_updating', bash_command='sleep 800')

    # Run the write_eval_file callable as the 'write_eval_file' task.
    save_eval_file = PythonOperator(task_id='write_eval_file',
                                    python_callable=write_eval_file,
                                    provide_context=True)
    upload_evaluation_file = EvaluationUploadOperator(
        task_id='upload_eval_file',
        app_id=application_id,