Example #1
    def test_handle__fork_child_raise(self):
        e_message = 'e_message'
        context_variables = 'variables'
        converge_gateway_id = 'converge_gateway_id'
        # targets and hydrate_context are not defined in the original snippet;
        # these placeholders assume conditional_parallel.hydrate_data and
        # Status.objects.fail are patched at class/module level, as elsewhere
        # in this test suite
        targets = [MagicMock()]
        hydrate_context = conditional_parallel.hydrate_data.return_value

        cpg = MagicMock()
        cpg.targets_meet_condition = MagicMock(return_value=targets)
        cpg.converge_gateway_id = converge_gateway_id
        status = MockStatus(loop=0)
        process = MockPipelineProcess(top_pipeline_context=MockContext(
            variables=context_variables))

        with patch(PIPELINE_PROCESS_FORK_CHILD,
                   MagicMock(side_effect=PipelineException(e_message))):
            result = handlers.conditional_parallel_handler(
                process, cpg, status)
            self.assertIsNone(result.next_node)
            self.assertTrue(result.should_return)
            self.assertTrue(result.should_sleep)

            conditional_parallel.hydrate_data.assert_called_once_with(
                context_variables)

            cpg.targets_meet_condition.assert_called_once_with(hydrate_context)

            Status.objects.fail.assert_called_once_with(cpg, ex_data=e_message)

            process.join.assert_not_called()
Example #2
    def make(self, key):
        # get movie
        scan, scale = (AstroCaMovie & key).fetch1("corrected_scan",
                                                  "scale_factor")
        scan = scan.astype(np.float32) * scale

        # get regressors
        X = (tune.OriDesign() & key).fetch1("regressors")
        pipe = (fuse.MotionCorrection() & key).module
        nfields_name = ("nfields/nrois"
                        if "nrois" in pipe.ScanInfo.heading else "nfields")
        nfields = int((pipe.ScanInfo & key).proj(n=nfields_name).fetch1("n"))
        X = X[key["field"] - 1::nfields, :]

        if abs(X.shape[0] - scan.shape[2]) > 1:
            raise PipelineException(
                "The sync frames do not match scan frames.")
        else:
            # truncate scan if X is shorter
            if X.shape[0] < scan.shape[2]:
                warn("Scan is longer than design matrix")
                scan = scan[:, :, :X.shape[0]]
            # truncate design matrix if scan is shorter
            if scan.shape[2] < X.shape[0]:
                warn("Scan is shorter than design matrix")
                X = X[:scan.shape[2], :]

        # limit the analysis to times when X is non-zero
        ix = (X**2).sum(axis=1) > 1e-4 * (X**2).sum(axis=1).max()
        X = X[ix, :]
        scan = scan[:, :, ix]

        # normalize regressors
        X = X - X.mean(axis=0)
        X /= np.sqrt((X**2).sum(axis=0, keepdims=True))

        # normalize movie
        scan -= scan.mean(axis=2, keepdims=True)
        key["activity_map"] = np.sqrt((scan**2).sum(axis=2))
        scan /= key["activity_map"][:, :, None] + 1e-6

        # compute response
        key["response_map"] = np.tensordot(scan,
                                           np.linalg.pinv(X),
                                           axes=(2, 1))
        self.insert1(key)
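
For reference, the response-map step above is a per-pixel ordinary least-squares fit against the design matrix. A minimal standalone sketch of the same tensordot step, with made-up shapes (4x4 pixels, 100 frames, 6 regressors):

import numpy as np

scan = np.random.randn(4, 4, 100).astype(np.float32)  # movie: height x width x time
X = np.random.randn(100, 6)                           # design matrix: time x regressors

# np.linalg.pinv(X) has shape (6, 100); contracting the time axes yields
# per-pixel OLS coefficients of shape (4, 4, 6)
response_map = np.tensordot(scan, np.linalg.pinv(X), axes=(2, 1))
assert response_map.shape == (4, 4, 6)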
Example #3
    def test_handle__fork_raise_exception(self):
        process = MockPipelineProcess()
        parallel_gateway = MockParallelGateway()
        e_msg = 'e_msg'

        with patch(PIPELINE_PROCESS_FORK_CHILD, MagicMock(side_effect=PipelineException(e_msg))):
            hdl_result = handlers.parallel_gateway_handler(process, parallel_gateway, MockStatus())

            PipelineProcess.objects.fork_child.assert_called()

            Status.objects.fail.assert_called_once_with(parallel_gateway, e_msg)

            process.join.assert_not_called()

            Status.objects.finish.assert_not_called()

            self.assertIsNone(hdl_result.next_node)
            self.assertTrue(hdl_result.should_return)
            self.assertTrue(hdl_result.should_sleep)
Example #4
File: h5.py Project: zhoupc/ease
import numpy as np
from scipy.interpolate import InterpolatedUnivariateSpline as iu_spline
# (PipelineException is the parent project's exception class; its import is omitted here)

def ts2sec(ts, packet_length=0, samplingrate=1e7):
    """
    Convert 10MHz timestamps from Saumil's patching program (ts) to seconds (s)

    :param ts: timestamps
    :param packet_length: length of timestamped packets
    :returns:
        timestamps converted to seconds
        system time (in seconds) of t=0
        bad camera indices from 2^31:2^32 in camera timestamps prior to 4/10/13
    """
    ts = ts.astype(float)

    # find bad indices in camera timestamps and replace with linear estimate
    bad_idx = ts == 2 ** 31 - 1
    if bad_idx.sum() > 10:
        raise PipelineException('Bad camera ts...')
    if bad_idx.any():
        # interpolate the bad samples from the good ones with a linear spline
        x = np.where(~bad_idx)[0]
        x_bad = np.where(bad_idx)[0]
        f = iu_spline(x, ts[~bad_idx], k=1)
        ts[bad_idx] = f(x_bad)

    # remove wraparound
    wrap_idx = np.where(np.diff(ts) < 0)[0]
    while len(wrap_idx) > 0:
        ts[wrap_idx[0] + 1:] += 2 ** 32
        wrap_idx = np.where(np.diff(ts) < 0)[0]

    s = ts / samplingrate

    # Remove offset, and if not monotonically increasing (i.e. for packeted ts), interpolate
    if np.any(np.diff(s) <= 0):
        # Check to make sure it's packets
        diffs = np.where(np.diff(s) > 0)[0]
        assert packet_length == diffs[0] + 1

        # Interpolate
        not_zero = np.hstack((0, diffs + 1))
        f = iu_spline(not_zero, s[not_zero], k=1)
        s = f(np.arange(len(s)))

    return s, bad_idx
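
A minimal usage sketch: simulate a 10 MHz counter that wraps past 2**32 and check that ts2sec unwraps it into monotonically increasing seconds.

raw = np.arange(0, 5e9, 1e6)   # ideal counter values
ts = raw % 2 ** 32             # what the hardware actually reports (wraps around)
s, bad_idx = ts2sec(ts)
assert np.all(np.diff(s) > 0)  # wraparound removed
assert not bad_idx.any()       # no stuck camera timestamps in this toy input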
Example #5
    def test_handle__next_raise_exception(self):
        hydrate_data_return = 'hydrate_data_return'
        e = PipelineException('ex_data')
        exclusive_gateway = MockExclusiveGateway(next_exception=e)
        context = MockContext()
        process = MockPipelineProcess(top_pipeline_context=context)

        with patch(EXG_HYDRATE_DATA, MagicMock(return_value=hydrate_data_return)):
            hdl_result = handlers.exclusive_gateway_handler(process, exclusive_gateway, MockStatus())

            exg_h.hydrate_data.assert_called_once_with(context.variables)

            exclusive_gateway.next.assert_called_once_with(hydrate_data_return)

            Status.objects.fail.assert_called_once_with(exclusive_gateway, ex_data=e.message)

            Status.objects.finish.assert_not_called()

            self.assertIsNone(hdl_result.next_node)
            self.assertTrue(hdl_result.should_return)
            self.assertTrue(hdl_result.should_sleep)
Example #6
from itertools import groupby

import numpy as np


def find_idx_boundaries(indices, drop_single_idx=False):
    """
    Given a flattened list/array of indices, break it into a list of arrays of
    consecutive indices (each run incrementing by 1)

        Example:
            >>> find_idx_boundaries([1,2,3,4,501,502,503,504])
            [array([1, 2, 3, 4]), array([501, 502, 503, 504])]

        Parameters:
            indices: Flattened list or numpy array of indices to break apart into sublists
            drop_single_idx: If True, silently drop single indices that are not
                             part of any incrementing run; if False, raise a
                             PipelineException when one is detected.

        Returns:
            events: List of numpy arrays of indices incrementing by 1

    """

    events = []

    ## Basic idea: subtracting each value's position from the value gives the same
    ## key for every member of a run of values incrementing by one. For the list
    ## [1, 2, 3, 20, 21, 22], enumerate() yields (position, value) pairs whose keys
    ## position - value are [-1, -1, -1, -17, -17, -17], so groupby splits them into
    ## one group for 1,2,3 (key -1) and another for 20,21,22 (key -17).
    for k, g in groupby(enumerate(indices), lambda x: x[0] - x[1]):

        event = np.array([e[1] for e in g])

        if len(event) == 1:
            if not drop_single_idx:
                raise PipelineException(
                    f"Disconnected index found: {event[0]}")
        else:
            events.append(event)

    return events
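
A short usage sketch of the drop_single_idx flag (toy indices; 42 is an isolated index):

runs = find_idx_boundaries([1, 2, 3, 42, 100, 101], drop_single_idx=True)
# -> [array([1, 2, 3]), array([100, 101])]; the isolated 42 is silently dropped

find_idx_boundaries([1, 2, 3, 42, 100, 101])
# -> raises PipelineException: Disconnected index found: 42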
Example #7
    def data_for_node(self, node):
        # look the node up under a new name so the original id is still
        # available for the error message if the lookup fails
        node_inst = self.spec.objects.get(node.id)
        if not node_inst:
            raise PipelineException('Can not find node %s in this pipeline.' %
                                    node.id)
        return node_inst.data
Example #8
class APITest(TestCase):
    @classmethod
    def setUpClass(cls):
        cls.GET_TEMPLATE_LIST_URL = '/apigw/get_template_list/{biz_cc_id}/'
        cls.GET_TEMPLATE_INFO_URL = '/apigw/get_template_info/{template_id}/{bk_biz_id}/'
        cls.CREATE_TASK_URL = '/apigw/create_task/{template_id}/{bk_biz_id}/'
        cls.START_TASK_URL = '/apigw/start_task/{task_id}/{bk_biz_id}/'
        cls.OPERATE_TASK_URL = '/apigw/operate_task/{task_id}/{bk_biz_id}/'
        cls.GET_TASK_STATUS_URL = '/apigw/get_task_status/{task_id}/{bk_biz_id}/'
        cls.QUERY_TASK_COUNT_URL = '/apigw/query_task_count/{bk_biz_id}/'
        cls.GET_PERIODIC_TASK_LIST_URL = '/apigw/get_periodic_task_list/{bk_biz_id}/'
        cls.GET_PERIODIC_TASK_INFO_URL = '/apigw/get_periodic_task_info/{task_id}/{bk_biz_id}/'
        cls.CREATE_PERIODIC_TASK_URL = '/apigw/create_periodic_task/{template_id}/{bk_biz_id}/'
        cls.SET_PERIODIC_TASK_ENABLED_URL = '/apigw/set_periodic_task_enabled/{task_id}/{bk_biz_id}/'
        cls.MODIFY_PERIODIC_TASK_CRON_URL = '/apigw/modify_cron_for_periodic_task/{task_id}/{bk_biz_id}/'
        cls.MODIFY_PERIODIC_TASK_CONSTANTS_URL = '/apigw/modify_constants_for_periodic_task/{task_id}/{bk_biz_id}/'

        super(APITest, cls).setUpClass()

    def setUp(self):
        self.white_list_patcher = mock.patch(APIGW_DECORATOR_CHECK_WHITE_LIST,
                                             MagicMock(return_value=True))

        self.dummy_user = MagicMock()
        self.dummy_user.username = ''
        self.user_cls = MagicMock()
        self.user_cls.objects = MagicMock()
        self.user_cls.objects.get_or_create = MagicMock(
            return_value=(self.dummy_user, False))

        self.get_user_model_patcher = mock.patch(
            APIGW_DECORATOR_GET_USER_MODEL,
            MagicMock(return_value=self.user_cls))
        self.prepare_user_business_patcher = mock.patch(
            APIGW_DECORATOR_PREPARE_USER_BUSINESS, MagicMock())
        self.business_exist_patcher = mock.patch(
            APIGW_DECORATOR_BUSINESS_EXIST, MagicMock(return_value=True))

        self.white_list_patcher.start()
        self.get_user_model_patcher.start()
        self.prepare_user_business_patcher.start()
        self.business_exist_patcher.start()

        self.client = Client()

    def tearDown(self):
        self.white_list_patcher.stop()
        self.get_user_model_patcher.stop()
        self.prepare_user_business_patcher.stop()
        self.business_exist_patcher.stop()
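
    # Note: the start()/stop() pairs above can equivalently be registered with
    # unittest's addCleanup, which unpatches even if setUp fails partway.
    # A minimal sketch for one of the patchers (hypothetical variant):
    #
    #     def setUp(self):
    #         self.white_list_patcher = mock.patch(APIGW_DECORATOR_CHECK_WHITE_LIST,
    #                                              MagicMock(return_value=True))
    #         self.white_list_patcher.start()
    #         self.addCleanup(self.white_list_patcher.stop)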

    @mock.patch(BUSINESS_GET,
                MagicMock(return_value=MockBusiness(cc_id=TEST_BIZ_CC_ID,
                                                    cc_name=TEST_BIZ_CC_NAME)))
    def test_get_template_list__for_business_template(self):
        pt1 = MockPipelineTemplate(id=1, name='pt1')
        pt2 = MockPipelineTemplate(id=2, name='pt2')

        task_tmpl1 = MockTaskTemplate(id=1, pipeline_template=pt1)
        task_tmpl2 = MockTaskTemplate(id=2, pipeline_template=pt2)

        task_templates = [task_tmpl1, task_tmpl2]

        with mock.patch(
                TASKTEMPLATE_SELECT_RELATE,
                MagicMock(return_value=MockQuerySet(
                    filter_result=task_templates))):
            assert_data = [{
                'id': tmpl.id,
                'name': tmpl.pipeline_template.name,
                'creator': tmpl.pipeline_template.creator,
                'create_time': strftime_with_timezone(tmpl.pipeline_template.create_time),
                'editor': tmpl.pipeline_template.editor,
                'edit_time': strftime_with_timezone(tmpl.pipeline_template.edit_time),
                'category': tmpl.category,
                'bk_biz_id': TEST_BIZ_CC_ID,
                'bk_biz_name': TEST_BIZ_CC_NAME
            } for tmpl in task_templates]

            response = self.client.get(path=self.GET_TEMPLATE_LIST_URL.format(
                biz_cc_id=TEST_BIZ_CC_ID))

            self.assertEqual(response.status_code, 200)

            data = json.loads(response.content)

            self.assertTrue(data['result'])
            self.assertEqual(data['data'], assert_data)

        with mock.patch(
                TASKTEMPLATE_SELECT_RELATE,
                MagicMock(return_value=MockQuerySet(filter_result=[]))):
            assert_data = []

            response = self.client.get(path=self.GET_TEMPLATE_LIST_URL.format(
                biz_cc_id=TEST_BIZ_CC_ID))

            data = json.loads(response.content)

            self.assertTrue(data['result'])
            self.assertEqual(data['data'], assert_data)

    @mock.patch(BUSINESS_GET,
                MagicMock(return_value=MockBusiness(cc_id=TEST_BIZ_CC_ID,
                                                    cc_name=TEST_BIZ_CC_NAME)))
    def test_get_template_list__for_common_template(self):
        pt1 = MockPipelineTemplate(id=1, name='pt1')
        pt2 = MockPipelineTemplate(id=2, name='pt2')

        task_tmpl1 = MockCommonTemplate(id=1, pipeline_template=pt1)
        task_tmpl2 = MockCommonTemplate(id=2, pipeline_template=pt2)

        task_templates = [task_tmpl1, task_tmpl2]

        with mock.patch(
                COMMONTEMPLATE_SELECT_RELATE,
                MagicMock(return_value=MockQuerySet(
                    filter_result=task_templates))):
            assert_data = [{
                'id': tmpl.id,
                'name': tmpl.pipeline_template.name,
                'creator': tmpl.pipeline_template.creator,
                'create_time': strftime_with_timezone(tmpl.pipeline_template.create_time),
                'editor': tmpl.pipeline_template.editor,
                'edit_time': strftime_with_timezone(tmpl.pipeline_template.edit_time),
                'category': tmpl.category,
                'bk_biz_id': TEST_BIZ_CC_ID,
                'bk_biz_name': TEST_BIZ_CC_NAME
            } for tmpl in task_templates]

            response = self.client.get(path=self.GET_TEMPLATE_LIST_URL.format(
                biz_cc_id=TEST_BIZ_CC_ID),
                                       data={'template_source': 'common'})

            self.assertEqual(response.status_code, 200)

            data = json.loads(response.content)

            self.assertTrue(data['result'])
            self.assertEqual(data['data'], assert_data)

        with mock.patch(
                COMMONTEMPLATE_SELECT_RELATE,
                MagicMock(return_value=MockQuerySet(filter_result=[]))):
            assert_data = []

            response = self.client.get(path=self.GET_TEMPLATE_LIST_URL.format(
                biz_cc_id=TEST_BIZ_CC_ID),
                                       data={'template_source': 'common'})

            data = json.loads(response.content)

            self.assertTrue(data['result'])
            self.assertEqual(data['data'], assert_data)

    @mock.patch(BUSINESS_GET,
                MagicMock(return_value=MockBusiness(cc_id=TEST_BIZ_CC_ID,
                                                    cc_name=TEST_BIZ_CC_NAME)))
    def test_get_template_info__for_business_template(self):
        pt1 = MockPipelineTemplate(id=1, name='pt1')

        tmpl = MockTaskTemplate(id=1, pipeline_template=pt1)

        with mock.patch(TASKTEMPLATE_SELECT_RELATE,
                        MagicMock(return_value=MockQuerySet(get_result=tmpl))):
            pipeline_tree = copy.deepcopy(tmpl.pipeline_tree)
            pipeline_tree.pop('line')
            pipeline_tree.pop('location')
            assert_data = {
                'id': tmpl.id,
                'name': tmpl.pipeline_template.name,
                'creator': tmpl.pipeline_template.creator,
                'create_time': strftime_with_timezone(tmpl.pipeline_template.create_time),
                'editor': tmpl.pipeline_template.editor,
                'edit_time': strftime_with_timezone(tmpl.pipeline_template.edit_time),
                'category': tmpl.category,
                'bk_biz_id': TEST_BIZ_CC_ID,
                'bk_biz_name': TEST_BIZ_CC_NAME,
                'pipeline_tree': pipeline_tree
            }

            response = self.client.get(path=self.GET_TEMPLATE_INFO_URL.format(
                template_id=TEST_TEMPLATE_ID, bk_biz_id=TEST_BIZ_CC_ID))

            data = json.loads(response.content)

            self.assertTrue(data['result'])
            self.assertEqual(assert_data, data['data'])

    @mock.patch(TASKTEMPLATE_SELECT_RELATE,
                MagicMock(return_value=MockQuerySet(
                    get_raise=TaskTemplate.DoesNotExist())))
    @mock.patch(BUSINESS_GET,
                MagicMock(return_value=MockBusiness(cc_id=TEST_BIZ_CC_ID,
                                                    cc_name=TEST_BIZ_CC_NAME)))
    def test_get_template_info__for_business_template_does_not_exists(self):
        response = self.client.get(path=self.GET_TEMPLATE_INFO_URL.format(
            template_id=TEST_TEMPLATE_ID, bk_biz_id=TEST_BIZ_CC_ID))

        data = json.loads(response.content)

        self.assertFalse(data['result'])
        self.assertTrue('message' in data)

    @mock.patch(BUSINESS_GET,
                MagicMock(return_value=MockBusiness(cc_id=TEST_BIZ_CC_ID,
                                                    cc_name=TEST_BIZ_CC_NAME)))
    def test_get_template_info__for_common_template(self):
        pt1 = MockPipelineTemplate(id=1, name='pt1')

        tmpl = MockCommonTemplate(id=1, pipeline_template=pt1)

        with mock.patch(COMMONTEMPLATE_SELECT_RELATE,
                        MagicMock(return_value=MockQuerySet(get_result=tmpl))):
            pipeline_tree = copy.deepcopy(tmpl.pipeline_tree)
            pipeline_tree.pop('line')
            pipeline_tree.pop('location')
            assert_data = {
                'id': tmpl.id,
                'name': tmpl.pipeline_template.name,
                'creator': tmpl.pipeline_template.creator,
                'create_time': strftime_with_timezone(tmpl.pipeline_template.create_time),
                'editor': tmpl.pipeline_template.editor,
                'edit_time': strftime_with_timezone(tmpl.pipeline_template.edit_time),
                'category': tmpl.category,
                'bk_biz_id': TEST_BIZ_CC_ID,
                'bk_biz_name': TEST_BIZ_CC_NAME,
                'pipeline_tree': pipeline_tree
            }

            response = self.client.get(path=self.GET_TEMPLATE_INFO_URL.format(
                template_id=TEST_TEMPLATE_ID, bk_biz_id=TEST_BIZ_CC_ID),
                                       data={'template_source': 'common'})

            data = json.loads(response.content)

            self.assertTrue(data['result'])
            self.assertEqual(assert_data, data['data'])

    @mock.patch(COMMONTEMPLATE_SELECT_RELATE,
                MagicMock(return_value=MockQuerySet(
                    get_raise=CommonTemplate.DoesNotExist())))
    @mock.patch(BUSINESS_GET,
                MagicMock(return_value=MockBusiness(cc_id=TEST_BIZ_CC_ID,
                                                    cc_name=TEST_BIZ_CC_NAME)))
    def test_get_template_info__for_common_template_does_not_exists(self):
        response = self.client.get(path=self.GET_TEMPLATE_INFO_URL.format(
            template_id=TEST_TEMPLATE_ID, bk_biz_id=TEST_BIZ_CC_ID),
                                   data={'template_source': 'common'})

        data = json.loads(response.content)

        self.assertFalse(data['result'])
        self.assertTrue('message' in data)

    @mock.patch(TASKINSTANCE_CREATE_PIPELINE,
                MagicMock(return_value=(True, TEST_DATA)))
    @mock.patch(
        TASKINSTANCE_CREATE,
        MagicMock(return_value=MockTaskFlowInstance(id=TEST_TASKFLOW_ID)))
    @mock.patch(APIGW_VIEW_JSON_SCHEMA_VALIDATE, MagicMock())
    def test_create_task__success(self):
        pt1 = MockPipelineTemplate(id=1, name='pt1')

        tmpl = MockTaskTemplate(id=1, pipeline_template=pt1)
        biz = MockBusiness(cc_id=TEST_BIZ_CC_ID, cc_name=TEST_BIZ_CC_NAME)

        with mock.patch(BUSINESS_GET, MagicMock(return_value=biz)):
            with mock.patch(
                    TASKTEMPLATE_SELECT_RELATE,
                    MagicMock(return_value=MockQuerySet(get_result=tmpl))):
                assert_data = {'task_id': TEST_TASKFLOW_ID}
                response = self.client.post(
                    path=self.CREATE_TASK_URL.format(
                        template_id=TEST_TEMPLATE_ID, bk_biz_id=TEST_BIZ_CC_ID),
                    data=json.dumps({
                        'name': 'name',
                        'constants': 'constants',
                        'exclude_task_nodes_id': 'exclude_task_nodes_id',
                        'flow_type': 'common'
                    }),
                    content_type="application/json",
                    HTTP_BK_APP_CODE=TEST_APP_CODE)

                TaskFlowInstance.objects.create_pipeline_instance_exclude_task_nodes.assert_called_once_with(
                    tmpl, {
                        'name': 'name',
                        'creator': ''
                    }, 'constants', 'exclude_task_nodes_id')

                TaskFlowInstance.objects.create.assert_called_once_with(
                    business=biz,
                    category=tmpl.category,
                    pipeline_instance=TEST_DATA,
                    template_id=TEST_TEMPLATE_ID,
                    create_method='api',
                    create_info=TEST_APP_CODE,
                    flow_type='common',
                    current_flow='execute_task')

                data = json.loads(response.content)

                self.assertTrue(data['result'])
                self.assertEqual(data['data'], assert_data)

                TaskFlowInstance.objects.create_pipeline_instance_exclude_task_nodes.reset_mock()
                TaskFlowInstance.objects.create.reset_mock()

            pt1 = MockPipelineTemplate(id=1, name='pt1')

            tmpl = MockCommonTemplate(id=1, pipeline_template=pt1)

            with mock.patch(
                    COMMONTEMPLATE_SELECT_RELATE,
                    MagicMock(return_value=MockQuerySet(get_result=tmpl))):
                assert_data = {'task_id': TEST_TASKFLOW_ID}
                response = self.client.post(
                    path=self.CREATE_TASK_URL.format(
                        template_id=TEST_TEMPLATE_ID, bk_biz_id=TEST_BIZ_CC_ID),
                    data=json.dumps({
                        'name': 'name',
                        'constants': 'constants',
                        'exclude_task_nodes_id': 'exclude_task_nodes_id',
                        'template_source': 'common',
                        'flow_type': 'common'
                    }),
                    content_type="application/json",
                    HTTP_BK_APP_CODE=TEST_APP_CODE)

                TaskFlowInstance.objects.create_pipeline_instance_exclude_task_nodes.assert_called_once_with(
                    tmpl, {
                        'name': 'name',
                        'creator': ''
                    }, 'constants', 'exclude_task_nodes_id')

                TaskFlowInstance.objects.create.assert_called_once_with(
                    business=biz,
                    category=tmpl.category,
                    pipeline_instance=TEST_DATA,
                    template_id=TEST_TEMPLATE_ID,
                    create_method='api',
                    create_info=TEST_APP_CODE,
                    flow_type='common',
                    current_flow='execute_task')

                data = json.loads(response.content)

                self.assertTrue(data['result'])
                self.assertEqual(data['data'], assert_data)

    @mock.patch(BUSINESS_GET,
                MagicMock(return_value=MockBusiness(cc_id=TEST_BIZ_CC_ID,
                                                    cc_name=TEST_BIZ_CC_NAME)))
    @mock.patch(TASKTEMPLATE_SELECT_RELATE,
                MagicMock(return_value=MockQuerySet()))
    @mock.patch(COMMONTEMPLATE_SELECT_RELATE,
                MagicMock(return_value=MockQuerySet()))
    @mock.patch(APIGW_VIEW_JSON_SCHEMA_VALIDATE,
                MagicMock(side_effect=jsonschema.ValidationError('')))
    def test_create_task__validate_fail(self):
        response = self.client.post(
            path=self.CREATE_TASK_URL.format(template_id=TEST_TEMPLATE_ID,
                                             bk_biz_id=TEST_BIZ_CC_ID),
            data=json.dumps({
                'constants': 'constants',
                'exclude_task_node_id': 'exclude_task_node_id'
            }),
            content_type="application/json")

        data = json.loads(response.content)

        self.assertFalse(data['result'])
        self.assertTrue('message' in data)

        response = self.client.post(
            path=self.CREATE_TASK_URL.format(template_id=TEST_TEMPLATE_ID,
                                             bk_biz_id=TEST_BIZ_CC_ID),
            data=json.dumps({
                'constants': 'constants',
                'exclude_task_node_id': 'exclude_task_node_id',
                'template_source': 'common'
            }),
            content_type="application/json")

        data = json.loads(response.content)

        self.assertFalse(data['result'])
        self.assertTrue('message' in data)

    @mock.patch(BUSINESS_GET,
                MagicMock(return_value=MockBusiness(cc_id=TEST_BIZ_CC_ID,
                                                    cc_name=TEST_BIZ_CC_NAME)))
    @mock.patch(TASKTEMPLATE_SELECT_RELATE,
                MagicMock(return_value=MockQuerySet()))
    @mock.patch(COMMONTEMPLATE_SELECT_RELATE,
                MagicMock(return_value=MockQuerySet()))
    @mock.patch(APIGW_VIEW_JSON_SCHEMA_VALIDATE, MagicMock())
    def test_create_task__without_app_code(self):
        response = self.client.post(
            path=self.CREATE_TASK_URL.format(template_id=TEST_TEMPLATE_ID,
                                             bk_biz_id=TEST_BIZ_CC_ID),
            data=json.dumps({
                'constants': 'constants',
                'exclude_task_node_id': 'exclude_task_node_id'
            }),
            content_type="application/json")

        data = json.loads(response.content)

        self.assertFalse(data['result'])
        self.assertTrue('message' in data)

        response = self.client.post(
            path=self.CREATE_TASK_URL.format(template_id=TEST_TEMPLATE_ID,
                                             bk_biz_id=TEST_BIZ_CC_ID),
            data=json.dumps({
                'constants': 'constants',
                'exclude_task_node_id': 'exclude_task_node_id',
                'template_source': 'common'
            }),
            content_type="application/json")

        data = json.loads(response.content)

        self.assertFalse(data['result'])
        self.assertTrue('message' in data)

    @mock.patch(BUSINESS_GET,
                MagicMock(return_value=MockBusiness(cc_id=TEST_BIZ_CC_ID,
                                                    cc_name=TEST_BIZ_CC_NAME)))
    @mock.patch(TASKINSTANCE_CREATE_PIPELINE,
                MagicMock(side_effect=PipelineException()))
    @mock.patch(APIGW_VIEW_JSON_SCHEMA_VALIDATE, MagicMock())
    def test_create_task__create_pipeline_raise(self):
        pt1 = MockPipelineTemplate(id=1, name='pt1')

        tmpl = MockTaskTemplate(id=1, pipeline_template=pt1)

        with mock.patch(TASKTEMPLATE_SELECT_RELATE,
                        MagicMock(return_value=MockQuerySet(get_result=tmpl))):
            response = self.client.post(
                path=self.CREATE_TASK_URL.format(template_id=TEST_TEMPLATE_ID,
                                                 bk_biz_id=TEST_BIZ_CC_ID),
                data=json.dumps({
                    'name': 'name',
                    'constants': 'constants',
                    'exclude_task_node_id': 'exclude_task_node_id'
                }),
                content_type="application/json",
                HTTP_BK_APP_CODE=TEST_APP_CODE)

            data = json.loads(response.content)

            self.assertFalse(data['result'])
            self.assertTrue('message' in data)

        pt1 = MockPipelineTemplate(id=1, name='pt1')

        tmpl = MockCommonTemplate(id=1, pipeline_template=pt1)

        with mock.patch(COMMONTEMPLATE_SELECT_RELATE,
                        MagicMock(return_value=MockQuerySet(get_result=tmpl))):
            response = self.client.post(
                path=self.CREATE_TASK_URL.format(template_id=TEST_TEMPLATE_ID,
                                                 bk_biz_id=TEST_BIZ_CC_ID),
                data=json.dumps({
                    'name': 'name',
                    'constants': 'constants',
                    'exclude_task_node_id': 'exclude_task_node_id',
                    'template_source': 'common'
                }),
                content_type="application/json",
                HTTP_BK_APP_CODE=TEST_APP_CODE)

            data = json.loads(response.content)

            self.assertFalse(data['result'])
            self.assertTrue('message' in data)

    @mock.patch(BUSINESS_GET,
                MagicMock(return_value=MockBusiness(cc_id=TEST_BIZ_CC_ID,
                                                    cc_name=TEST_BIZ_CC_NAME)))
    @mock.patch(TASKINSTANCE_CREATE_PIPELINE,
                MagicMock(return_value=(False, '')))
    @mock.patch(APIGW_VIEW_JSON_SCHEMA_VALIDATE, MagicMock())
    def test_create_task__create_pipeline_fail(self):
        pt1 = MockPipelineTemplate(id=1, name='pt1')

        tmpl = MockTaskTemplate(id=1, pipeline_template=pt1)

        with mock.patch(TASKTEMPLATE_SELECT_RELATE,
                        MagicMock(return_value=MockQuerySet(get_result=tmpl))):
            response = self.client.post(
                path=self.CREATE_TASK_URL.format(template_id=TEST_TEMPLATE_ID,
                                                 bk_biz_id=TEST_BIZ_CC_ID),
                data=json.dumps({
                    'name': 'name',
                    'constants': 'constants',
                    'exclude_task_node_id': 'exclude_task_node_id'
                }),
                content_type="application/json",
                HTTP_BK_APP_CODE=TEST_APP_CODE)

            data = json.loads(response.content)

            self.assertFalse(data['result'])
            self.assertTrue('message' in data)

        pt1 = MockPipelineTemplate(id=1, name='pt1')

        tmpl = MockCommonTemplate(id=1, pipeline_template=pt1)

        with mock.patch(COMMONTEMPLATE_SELECT_RELATE,
                        MagicMock(return_value=MockQuerySet(get_result=tmpl))):
            response = self.client.post(
                path=self.CREATE_TASK_URL.format(template_id=TEST_TEMPLATE_ID,
                                                 bk_biz_id=TEST_BIZ_CC_ID),
                data=json.dumps({
                    'name': 'name',
                    'constants': 'constants',
                    'exclude_task_node_id': 'exclude_task_node_id',
                    'template_source': 'common'
                }),
                content_type="application/json",
                HTTP_BK_APP_CODE=TEST_APP_CODE)

            data = json.loads(response.content)

            self.assertFalse(data['result'])
            self.assertTrue('message' in data)

    def test_start_task(self):
        assert_return = {'result': True}
        task = MockTaskFlowInstance(task_action_return=assert_return)

        with mock.patch(TASKINSTANCE_GET, MagicMock(return_value=task)):
            response = self.client.post(path=self.START_TASK_URL.format(
                task_id=TEST_TASKFLOW_ID, bk_biz_id=TEST_BIZ_CC_ID))

            task.task_action.assert_called_once_with('start', '')

            data = json.loads(response.content)

            self.assertEqual(data, assert_return)

    def test_operate_task(self):
        assert_return = {'result': True}
        assert_action = 'any_action'
        task = MockTaskFlowInstance(task_action_return=assert_return)

        with mock.patch(TASKINSTANCE_GET, MagicMock(return_value=task)):
            response = self.client.post(
                path=self.OPERATE_TASK_URL.format(task_id=TEST_TASKFLOW_ID,
                                                  bk_biz_id=TEST_BIZ_CC_ID),
                data=json.dumps({'action': assert_action}),
                content_type='application/json')

            task.task_action.assert_called_once_with(assert_action, '')

            data = json.loads(response.content)

            self.assertEqual(data, assert_return)

    def test_get_task_status__success(self):
        task = MockTaskFlowInstance(get_status_return=TEST_DATA)

        with mock.patch(TASKINSTANCE_GET, MagicMock(return_value=task)):
            response = self.client.get(path=self.GET_TASK_STATUS_URL.format(
                task_id=TEST_TASKFLOW_ID, bk_biz_id=TEST_BIZ_CC_ID))

            data = json.loads(response.content)
            self.assertTrue(data['result'])
            self.assertEqual(data['data'], TEST_DATA)

    def test_get_task_status__raise(self):
        task = MockTaskFlowInstance(get_status_raise=Exception())

        with mock.patch(TASKINSTANCE_GET, MagicMock(return_value=task)):
            response = self.client.get(path=self.GET_TASK_STATUS_URL.format(
                task_id=TEST_TASKFLOW_ID, bk_biz_id=TEST_BIZ_CC_ID))

            data = json.loads(response.content)
            self.assertFalse(data['result'])
            self.assertTrue('message' in data)

    @mock.patch(TASKINSTANCE_FORMAT_STATUS, MagicMock())
    @mock.patch(APIGW_VIEW_PIPELINE_API_GET_STATUS_TREE,
                MagicMock(return_value=TEST_DATA))
    def test_get_task_status__is_subprocess(self):
        task = MockTaskFlowInstance(
            get_status_raise=TaskFlowInstance.DoesNotExist())

        with mock.patch(TASKINSTANCE_GET, MagicMock(return_value=task)):
            response = self.client.get(path=self.GET_TASK_STATUS_URL.format(
                task_id=TEST_TASKFLOW_ID, bk_biz_id=TEST_BIZ_CC_ID))

            TaskFlowInstance.format_pipeline_status.assert_called_once_with(
                TEST_DATA)

            data = json.loads(response.content)
            self.assertTrue(data['result'])
            self.assertEqual(data['data'], TEST_DATA)

    @mock.patch(APIGW_VIEW_PIPELINE_API_GET_STATUS_TREE,
                MagicMock(return_value=TEST_DATA))
    def test_get_task_status__is_subprocess_raise(self):
        task = MockTaskFlowInstance(
            get_status_raise=TaskFlowInstance.DoesNotExist())

        with mock.patch(TASKINSTANCE_GET, MagicMock(return_value=task)):
            with mock.patch(APIGW_VIEW_PIPELINE_API_GET_STATUS_TREE,
                            MagicMock(side_effect=Exception())):
                response = self.client.get(
                    path=self.GET_TASK_STATUS_URL.format(
                        task_id=TEST_TASKFLOW_ID, bk_biz_id=TEST_BIZ_CC_ID))

                data = json.loads(response.content)
                self.assertFalse(data['result'])
                self.assertTrue('message' in data)

            with mock.patch(TASKINSTANCE_FORMAT_STATUS,
                            MagicMock(side_effect=Exception())):
                response = self.client.get(
                    path=self.GET_TASK_STATUS_URL.format(
                        task_id=TEST_TASKFLOW_ID, bk_biz_id=TEST_BIZ_CC_ID))

                data = json.loads(response.content)
                self.assertFalse(data['result'])
                self.assertTrue('message' in data)

    @mock.patch(TASKINSTANCE_EXTEN_CLASSIFIED_COUNT,
                MagicMock(return_value=(True, TEST_DATA)))
    def test_query_task_count__success(self):
        response = self.client.post(
            path=self.QUERY_TASK_COUNT_URL.format(bk_biz_id=TEST_BIZ_CC_ID),
            data=json.dumps({'group_by': 'category'}),
            content_type='application/json')

        data = json.loads(response.content)
        self.assertTrue(data['result'])
        self.assertEqual(data['data'], TEST_DATA)

    def test_query_task_count__conditions_is_not_dict(self):
        response = self.client.post(
            path=self.QUERY_TASK_COUNT_URL.format(bk_biz_id=TEST_BIZ_CC_ID),
            data=json.dumps({'conditions': []}),
            content_type='application/json')

        data = json.loads(response.content)
        self.assertFalse(data['result'])
        self.assertTrue('message' in data)

    def test_query_task_count__group_by_is_not_valid(self):
        response = self.client.post(
            path=self.QUERY_TASK_COUNT_URL.format(bk_biz_id=TEST_BIZ_CC_ID),
            data=json.dumps({'group_by': 'invalid_value'}),
            content_type='application/json')

        data = json.loads(response.content)
        self.assertFalse(data['result'])
        self.assertTrue('message' in data)

    @mock.patch(TASKINSTANCE_EXTEN_CLASSIFIED_COUNT,
                MagicMock(return_value=(False, '')))
    def test_query_task_count__extend_classified_count_fail(self):
        response = self.client.post(
            path=self.QUERY_TASK_COUNT_URL.format(bk_biz_id=TEST_BIZ_CC_ID),
            data=json.dumps({'group_by': 'category'}),
            content_type='application/json')

        TaskFlowInstance.objects.extend_classified_count.assert_called_once_with(
            'category', {
                'business__cc_id': TEST_BIZ_CC_ID,
                'is_deleted': False
            })

        data = json.loads(response.content)
        self.assertFalse(data['result'])
        self.assertTrue('message' in data)

    def test_get_periodic_task_list(self):
        pt1 = MockPeriodicTask(id='1')
        pt2 = MockPeriodicTask(id='2')
        pt3 = MockPeriodicTask(id='3')

        periodic_tasks = [pt1, pt2, pt3]

        assert_data = [{
            'id': task.id,
            'name': task.name,
            'template_id': task.template_id,
            'creator': task.creator,
            'cron': task.cron,
            'enabled': task.enabled,
            'last_run_at': strftime_with_timezone(task.last_run_at),
            'total_run_count': task.total_run_count,
        } for task in periodic_tasks]

        with mock.patch(PERIODIC_TASK_FILTER,
                        MagicMock(return_value=periodic_tasks)):
            response = self.client.get(
                path=self.GET_PERIODIC_TASK_LIST_URL.format(
                    bk_biz_id=TEST_BIZ_CC_ID))

            data = json.loads(response.content)

            self.assertTrue(data['result'])
            self.assertEqual(data['data'], assert_data)

    def test_get_periodic_task_info__success(self):
        task = MockPeriodicTask()
        assert_data = {
            'id': task.id,
            'name': task.name,
            'template_id': task.template_id,
            'creator': task.creator,
            'cron': task.cron,
            'enabled': task.enabled,
            'last_run_at': strftime_with_timezone(task.last_run_at),
            'total_run_count': task.total_run_count,
            'form': task.form,
            'pipeline_tree': task.pipeline_tree
        }

        with mock.patch(PERIODIC_TASK_GET, MagicMock(return_value=task)):
            response = self.client.get(
                path=self.GET_PERIODIC_TASK_INFO_URL.format(
                    task_id=TEST_PERIODIC_TASK_ID, bk_biz_id=TEST_BIZ_CC_ID))

            data = json.loads(response.content)

            self.assertTrue(data['result'])
            self.assertEqual(data['data'], assert_data)

    @mock.patch(PERIODIC_TASK_GET,
                MagicMock(side_effect=PeriodicTask.DoesNotExist))
    def test_periodic_task_info__task_does_not_exist(self):
        response = self.client.get(path=self.GET_PERIODIC_TASK_INFO_URL.format(
            task_id=TEST_PERIODIC_TASK_ID, bk_biz_id=TEST_BIZ_CC_ID))

        data = json.loads(response.content)

        self.assertFalse(data['result'])
        self.assertTrue('message' in data)

    @mock.patch(TASKINSTANCE_PREVIEW_TREE, MagicMock())
    @mock.patch(APIGW_VIEW_JSON_SCHEMA_VALIDATE, MagicMock())
    def test_create_periodic_task__success(self):
        task = MockPeriodicTask()
        assert_data = {
            'id': task.id,
            'name': task.name,
            'template_id': task.template_id,
            'creator': task.creator,
            'cron': task.cron,
            'enabled': task.enabled,
            'last_run_at': strftime_with_timezone(task.last_run_at),
            'total_run_count': task.total_run_count,
            'form': task.form,
            'pipeline_tree': task.pipeline_tree
        }
        biz = MockBusiness(cc_id=TEST_BIZ_CC_ID, cc_name=TEST_BIZ_CC_NAME)
        template = MockTaskTemplate()

        with mock.patch(TASKTEMPLATE_GET, MagicMock(return_value=template)):
            with mock.patch(BUSINESS_GET, MagicMock(return_value=biz)):
                with mock.patch(PERIODIC_TASK_CREATE,
                                MagicMock(return_value=task)):
                    response = self.client.post(
                        path=self.CREATE_PERIODIC_TASK_URL.format(
                            template_id=TEST_TEMPLATE_ID,
                            bk_biz_id=TEST_BIZ_CC_ID),
                        data=json.dumps({
                            'name': task.name,
                            'cron': task.cron,
                            'exclude_task_nodes_id': 'exclude_task_nodes_id'
                        }),
                        content_type='application/json')

                    TaskFlowInstance.objects.preview_pipeline_tree_exclude_task_nodes.assert_called_with(
                        template.pipeline_tree, 'exclude_task_nodes_id')

                    PeriodicTask.objects.create.assert_called_once_with(
                        business=biz,
                        template=template,
                        name=task.name,
                        cron=task.cron,
                        pipeline_tree=template.pipeline_tree,
                        creator='')

                    data = json.loads(response.content)

                    self.assertTrue(data['result'])
                    self.assertEqual(data['data'], assert_data)

    @mock.patch(TASKTEMPLATE_GET,
                MagicMock(side_effect=TaskTemplate.DoesNotExist()))
    def test_create_periodic_task__template_does_not_exist(self):
        response = self.client.post(path=self.CREATE_PERIODIC_TASK_URL.format(
            template_id=TEST_TEMPLATE_ID, bk_biz_id=TEST_BIZ_CC_ID),
                                    content_type='application/json')

        data = json.loads(response.content)

        self.assertFalse(data['result'])
        self.assertTrue('message' in data)

    @mock.patch(TASKTEMPLATE_GET, MagicMock(return_value=MockTaskTemplate()))
    @mock.patch(APIGW_VIEW_JSON_SCHEMA_VALIDATE,
                MagicMock(side_effect=jsonschema.ValidationError('')))
    def test_create_periodic_task__params_validate_fail(self):
        response = self.client.post(path=self.CREATE_PERIODIC_TASK_URL.format(
            template_id=TEST_TEMPLATE_ID, bk_biz_id=TEST_BIZ_CC_ID),
                                    content_type='application/json')

        data = json.loads(response.content)

        self.assertFalse(data['result'])
        self.assertTrue('message' in data)

    @mock.patch(TASKTEMPLATE_GET, MagicMock(return_value=MockTaskTemplate()))
    @mock.patch(APIGW_VIEW_JSON_SCHEMA_VALIDATE, MagicMock())
    @mock.patch(TASKINSTANCE_PREVIEW_TREE, MagicMock(side_effect=Exception()))
    def test_create_periodic_task__preview_pipeline_fail(self):
        response = self.client.post(path=self.CREATE_PERIODIC_TASK_URL.format(
            template_id=TEST_TEMPLATE_ID, bk_biz_id=TEST_BIZ_CC_ID),
                                    content_type='application/json')

        data = json.loads(response.content)

        self.assertFalse(data['result'])
        self.assertTrue('message' in data)

    @mock.patch(BUSINESS_GET,
                MagicMock(return_value=MockBusiness(cc_id=TEST_BIZ_CC_ID,
                                                    cc_name=TEST_BIZ_CC_NAME)))
    @mock.patch(TASKTEMPLATE_GET, MagicMock(return_value=MockTaskTemplate()))
    @mock.patch(APIGW_VIEW_JSON_SCHEMA_VALIDATE, MagicMock())
    @mock.patch(TASKINSTANCE_PREVIEW_TREE, MagicMock())
    @mock.patch(PERIODIC_TASK_CREATE, MagicMock(side_effect=Exception()))
    def test_create_periodic_task__periodic_task_create_fail(self):
        response = self.client.post(path=self.CREATE_PERIODIC_TASK_URL.format(
            template_id=TEST_TEMPLATE_ID, bk_biz_id=TEST_BIZ_CC_ID),
                                    data=json.dumps({
                                        'name': 'name',
                                        'cron': 'cron'
                                    }),
                                    content_type='application/json')

        data = json.loads(response.content)

        self.assertFalse(data['result'])
        self.assertTrue('message' in data)

    @mock.patch(BUSINESS_GET, MagicMock(return_value=MockBusiness()))
    def test_set_periodic_task_enabled__success(self):
        task = MockPeriodicTask()
        with mock.patch(PERIODIC_TASK_GET, MagicMock(return_value=task)):
            response = self.client.post(
                path=self.SET_PERIODIC_TASK_ENABLED_URL.format(
                    task_id=TEST_PERIODIC_TASK_ID, bk_biz_id=TEST_BIZ_CC_ID),
                data=json.dumps({'enabled': True}),
                content_type='application/json')

            task.set_enabled.assert_called_once_with(True)

            data = json.loads(response.content)

            self.assertTrue(data['result'])
            self.assertEqual(data['data'], {'enabled': task.enabled})

    @mock.patch(PERIODIC_TASK_GET,
                MagicMock(side_effect=PeriodicTask.DoesNotExist))
    def test_set_periodic_task_enabled__task_does_not_exist(self):
        response = self.client.post(
            path=self.SET_PERIODIC_TASK_ENABLED_URL.format(
                task_id=TEST_PERIODIC_TASK_ID, bk_biz_id=TEST_BIZ_CC_ID),
            data=json.dumps({'enabled': True}),
            content_type='application/json')

        data = json.loads(response.content)

        self.assertFalse(data['result'])
        self.assertTrue('message' in data)

    def test_modify_cron_for_periodic_task__success(self):
        biz = MockBusiness()
        task = MockPeriodicTask()
        cron = {'minute': '*/1'}

        with mock.patch(BUSINESS_GET, MagicMock(return_value=biz)):
            with mock.patch(PERIODIC_TASK_GET, MagicMock(return_value=task)):
                response = self.client.post(
                    path=self.MODIFY_PERIODIC_TASK_CRON_URL.format(
                        task_id=TEST_PERIODIC_TASK_ID,
                        bk_biz_id=TEST_BIZ_CC_ID),
                    data=json.dumps({'cron': cron}),
                    content_type='application/json')

                task.modify_cron.assert_called_once_with(cron, biz.time_zone)

                data = json.loads(response.content)

                self.assertTrue(data['result'])
                self.assertEqual(data['data'], {'cron': task.cron})

    @mock.patch(BUSINESS_GET, MagicMock(return_value=MockBusiness()))
    @mock.patch(PERIODIC_TASK_GET,
                MagicMock(side_effect=PeriodicTask.DoesNotExist))
    def test_modify_cron_for_periodic_task__task_does_not_exist(self):
        response = self.client.post(
            path=self.MODIFY_PERIODIC_TASK_CRON_URL.format(
                task_id=TEST_PERIODIC_TASK_ID, bk_biz_id=TEST_BIZ_CC_ID),
            data=json.dumps({'enabled': True}),
            content_type='application/json')

        data = json.loads(response.content)

        self.assertFalse(data['result'])
        self.assertTrue('message' in data)

    def test_modify_cron_for_periodic_task__modify_raise(self):
        biz = MockBusiness()
        task = MockPeriodicTask()
        task.modify_cron = MagicMock(side_effect=Exception())
        cron = {'minute': '*/1'}

        with mock.patch(BUSINESS_GET, MagicMock(return_value=biz)):
            with mock.patch(PERIODIC_TASK_GET, MagicMock(return_value=task)):
                response = self.client.post(
                    path=self.MODIFY_PERIODIC_TASK_CRON_URL.format(
                        task_id=TEST_PERIODIC_TASK_ID,
                        bk_biz_id=TEST_BIZ_CC_ID),
                    data=json.dumps({'cron': cron}),
                    content_type='application/json')

                data = json.loads(response.content)

                self.assertFalse(data['result'])
                self.assertTrue('message' in data)

    def test_modify_constants_for_periodic_task__success(self):
        biz = MockBusiness()
        task = MockPeriodicTask()
        constants = {'k': 'v'}

        with mock.patch(BUSINESS_GET, MagicMock(return_value=biz)):
            with mock.patch(PERIODIC_TASK_GET, MagicMock(return_value=task)):
                response = self.client.post(
                    path=self.MODIFY_PERIODIC_TASK_CONSTANTS_URL.format(
                        task_id=TEST_PERIODIC_TASK_ID,
                        bk_biz_id=TEST_BIZ_CC_ID),
                    data=json.dumps({'constants': constants}),
                    content_type='application/json')

                task.modify_constants.assert_called_once_with(constants)

                data = json.loads(response.content)

                self.assertTrue(data['result'])
                self.assertEqual(data['data'],
                                 task.modify_constants.return_value)

    @mock.patch(PERIODIC_TASK_GET,
                MagicMock(side_effect=PeriodicTask.DoesNotExist))
    def test_modify_constants_for_periodic_task__task_does_not_exist(self):
        response = self.client.post(
            path=self.MODIFY_PERIODIC_TASK_CONSTANTS_URL.format(
                task_id=TEST_PERIODIC_TASK_ID, bk_biz_id=TEST_BIZ_CC_ID),
            content_type='application/json')

        data = json.loads(response.content)

        self.assertFalse(data['result'])
        self.assertTrue('message' in data)

    def test_modify_constants_for_periodic_task__modify_constants_raise(self):
        biz = MockBusiness()
        task = MockPeriodicTask()
        task.modify_constants = MagicMock(side_effect=Exception())

        with mock.patch(BUSINESS_GET, MagicMock(return_value=biz)):
            with mock.patch(PERIODIC_TASK_GET, MagicMock(return_value=task)):
                response = self.client.post(
                    path=self.MODIFY_PERIODIC_TASK_CONSTANTS_URL.format(
                        task_id=TEST_PERIODIC_TASK_ID,
                        bk_biz_id=TEST_BIZ_CC_ID),
                    content_type='application/json')

                data = json.loads(response.content)

                self.assertFalse(data['result'])
                self.assertTrue('message' in data)
Example #9
    def make(self, key):
        """ Read ephys data and insert into table """
        import h5py

        # Read the scan
        print('Reading file...')
        vreso_path, filename_base = (PatchSession * (Recording() & key)).fetch1(
            'recording_path', 'file_name')
        local_path = lab.Paths().get_local_path(vreso_path)
        filename = os.path.join(local_path, filename_base + '_%d.h5')
        with h5py.File(filename, 'r', driver='family', memb_size=0) as f:

            # Load timing info
            ANALOG_PACKET_LEN = f.attrs['waveform Frame Size'][0]

            # Get counter timestamps and convert to seconds
            patch_times = h5.ts2sec(f['waveform'][10, :], is_packeted=True)

            # Detect rising edges in scanimage clock signal (start of each frame)
            binarized_signal = f['waveform'][9, :] > 2.7  # TTL voltage low/high threshold
            rising_edges = np.where(np.diff(binarized_signal.astype(int)) > 0)[0]
            frame_times = patch_times[rising_edges]

            # Correct NaN gaps in timestamps (mistimed or dropped packets during recording)
            if np.any(np.isnan(frame_times)):
                # Raise exception if first or last frame pulse was recorded in mistimed packet
                if np.isnan(frame_times[0]) or np.isnan(frame_times[-1]):
                    msg = (
                        'First or last frame happened during misstamped packets. Pulses '
                        'could have been missed: start/end of scanning is unknown.'
                    )
                    raise PipelineException(msg)

                # Fill each gap of nan values with correct number of timepoints
                frame_period = np.nanmedian(np.diff(frame_times))  # approx
                nan_limits = np.where(np.diff(np.isnan(frame_times)))[0]
                # limits are indices of the last valid point before each nan gap
                # and of the first valid point after it
                nan_limits[1::2] += 1
                correct_fts = []
                for i, (start, stop) in enumerate(zip(nan_limits[::2],
                                                      nan_limits[1::2])):
                    prev_stop = 0 if i == 0 else nan_limits[2 * i - 1]
                    correct_fts.extend(frame_times[prev_stop:start + 1])
                    num_missing_points = int(round(
                        (frame_times[stop] - frame_times[start]) /
                        frame_period - 1))
                    correct_fts.extend(
                        np.linspace(frame_times[start], frame_times[stop],
                                    num_missing_points + 2)[1:-1])
                correct_fts.extend(frame_times[nan_limits[-1]:])
                frame_times = np.array(correct_fts)

                # Record the NaN fix
                num_gaps = int(len(nan_limits) / 2)
                nan_length = sum(nan_limits[1::2] - nan_limits[::2]) * frame_period  # secs

            ####### WARNING: FRAME INTERVALS NOT ERROR CHECKED - TEMP CODE #######
            # Check that frame times occur at the same period
            frame_intervals = np.diff(frame_times)
            frame_period = np.median(frame_intervals)
            #if np.any(abs(frame_intervals - frame_period) > 0.15 * frame_period):
            #    raise PipelineException('Frame time period is irregular')

            # Drop last frame time if scan crashed or was stopped before completion
            valid_times = ~np.isnan(patch_times[rising_edges[0]:rising_edges[-1]])  # restricted to scan period
            binarized_valid = binarized_signal[rising_edges[0]:rising_edges[-1]][valid_times]
            frame_duration = np.mean(binarized_valid) * frame_period
            falling_edges = np.where(np.diff(binarized_signal.astype(int)) < 0)[0]
            last_frame_duration = patch_times[falling_edges[-1]] - frame_times[-1]
            if (np.isnan(last_frame_duration) or last_frame_duration < 0
                    or abs(last_frame_duration - frame_duration) > 0.15 * frame_duration):
                frame_times = frame_times[:-1]

            ####### WARNING: NO CORRECTION APPLIED - TEMP CODE #######
            voltage = np.array(f['waveform'][1, :], dtype='float32')
            current = np.array(f['waveform'][0, :], dtype='float32')
            command = np.array(f['waveform'][5, :], dtype='float32')

            ####### WARNING: DUMMY VARIABLES - TEMP CODE #######
            vgain = 0
            igain = 0
            command_gain = 0

            self.insert1({
                **key,
                'voltage': voltage,
                'current': current,
                'command': command,
                'patch_times': patch_times,
                'frame_times': frame_times,
                'vgain': vgain,
                'igain': igain,
                'command_gain': command_gain
            })
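
The NaN-gap refill above is easier to verify on a toy trace. A minimal standalone sketch of the same idea, using synthetic (hypothetical) timestamps:

import numpy as np

# Hypothetical 10 Hz frame clock with one dropped packet (two NaN samples)
frame_times = np.array([0.0, 0.1, 0.2, np.nan, np.nan, 0.5, 0.6])

frame_period = np.nanmedian(np.diff(frame_times))  # ~0.1 s
start, stop = 2, 5  # last valid index before the gap, first valid index after it
num_missing = int(round((frame_times[stop] - frame_times[start]) / frame_period - 1))

# np.linspace includes both endpoints, so request num_missing + 2 points and trim them
frame_times[start + 1:stop] = np.linspace(frame_times[start], frame_times[stop],
                                          num_missing + 2)[1:-1]
print(frame_times)  # [0.  0.1 0.2 0.3 0.4 0.5 0.6]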
Example #10
def fetch_timing_data(
    scan_key: dict,
    source_type: str,
    target_type: str,
    debug: bool = True,
):
    """
    Fetches timing data for source and target recordings. Adjusts both timings based on any calculable delays. Returns two
    arrays. Converts target recording times on target clock into target recording times on source clock if the two are different.

        Parameters:

                scan_key: A dictionary specifying a single scan and/or field. A single field must be defined if requesting
                          a source or target from ScanImage. If key specifies a single unit, unit delay will be added to
                          all timepoints recorded. Single units can be specified via unique mask_id + field or via unit_id.
                          If only field is specified, average field delay will be added.

                source_type: A string specifying what recording times to fetch for source_times. Both target and source times
                             will be returned on whatever clock is used for source_type. Fluorescence and deconvolution have
                             a dash followed by "behavior" or "stimulus" to refer to which clock you are using.
                             Supported options:
                                 'fluorescence-stimulus', 'deconvolution-stimulus', 'fluorescence-behavior',
                                 'deconvolution-behavior', 'pupil', 'treadmill', 'respiration'

                target_type: A string specifying what recording times to fetch for target_times. Both target and source times
                             will be returned on whatever clock is used for source_type. Fluorescence and deconvolution have
                             a dash followed by "behavior" or "stimulus" to refer to which clock you are using.
                             Supported options:
                                 'fluorescence-stimulus', 'deconvolution-stimulus', 'fluorescence-behavior',
                                 'deconvolution-behavior', 'pupil', 'treadmill', 'respiration'

                debug: If True, print helpful debug text while running


        Returns:

                source_times: Numpy array of times for source recording on source clock

                target_times: Numpy array of times for target recording on source clock
    """

    ## Make settings strings lowercase
    source_type = source_type.lower()
    target_type = target_type.lower()

    ##
    ## Set pipe, error check scan_key, and fetch field offset
    ##

    ## Define the pipe (meso/reso) to use
    if len(fuse.MotionCorrection & scan_key) == 0:
        msg = f"scan_key {scan_key} not found in fuse.MotionCorrection."
        raise PipelineException(msg)
    pipe = (fuse.MotionCorrection & scan_key).module

    ## Set default values for later processing
    field_offset = 0
    slice_num = 1
    ms_delay = 0

    ## Determine if source or target type requires extra scan info
    scan_types = (
        "fluorescence-stimulus",
        "fluorescence-behavior",
        "deconvolution-stimulus",
        "deconvolution-behavior",
    )
    if source_type in scan_types or target_type in scan_types:

        ## Check scan_key defines a unique scan
        if len(pipe.ScanInfo & scan_key) != 1:
            msg = (
                f"scan_key {scan_key} does not define a unique scan. "
                f"Matching scans found: {len(pipe.ScanInfo & scan_key)}"
            )
            raise PipelineException(msg)

        ## Check a single field is defined by scan_key
        if len(pipe.ScanInfo.Field & scan_key) != 1:
            msg = (
                f"scan_key {scan_key} must specify a single field when the source or target type "
                f"comes from ScanImage. Matching fields found: {len(pipe.ScanInfo.Field & scan_key)}"
            )
            raise PipelineException(msg)

        ## Determine field offset to slice times later on and set ms_delay to field average
        scan_restriction = (pipe.ScanInfo & scan_key).fetch("KEY")
        all_z = np.unique((pipe.ScanInfo.Field & scan_restriction).fetch(
            "z", order_by="field ASC"))
        slice_num = len(all_z)
        field_z = (pipe.ScanInfo.Field & scan_key).fetch1("z")
        field_offset = np.where(all_z == field_z)[0][0]
        if debug:
            print(
                f"Field offset found as {field_offset} for depths 0-{len(all_z) - 1}"
            )

        field_delay_im = (pipe.ScanInfo.Field & scan_key).fetch1("delay_image")
        average_field_delay = np.mean(field_delay_im)
        ms_delay = average_field_delay
        if debug:
            print(
                f"Average field delay found to be {round(ms_delay,4)}ms. This will be used unless a unit is specified in the key."
            )

        ## If included, add unit offset
        if "unit_id" in scan_key or "mask_id" in scan_key:
            if len(pipe.ScanSet.Unit & scan_key) > 0:
                unit_key = (pipe.ScanSet.Unit & scan_key).fetch1()
                ms_delay = (pipe.ScanSet.UnitInfo
                            & unit_key).fetch1("ms_delay")
                if debug:
                    print(
                        f"Unit found with delay of {round(ms_delay,4)}ms. Delay added to relevant times."
                    )
            else:
                if debug:
                    print(
                        f"Warning: ScanSet.Unit is not populated for the given key! Using field offset minimum instead."
                    )

    ##
    ## Fetch source and target sync data
    ##

    ## Define a lookup for data sources. Values are (data_table, column_name) tuples.
    data_source_lookup = {
        "fluorescence-stimulus": (stimulus.Sync, "frame_times"),
        "deconvolution-stimulus": (stimulus.Sync, "frame_times"),
        "fluorescence-behavior": (stimulus.BehaviorSync, "frame_times"),
        "deconvolution-behavior": (stimulus.BehaviorSync, "frame_times"),
        "treadmill": (treadmill.Treadmill, "treadmill_time"),
        "pupil": (pupil.Eye, "eye_time"),
        "respiration": (odor.Respiration * odor.MesoMatch, "times"),
    }

    ## Error check inputs
    if source_type not in data_source_lookup or target_type not in data_source_lookup:
        msg = (
            f"Source and target type combination '{source_type}' and '{target_type}' not supported. "
            f"Valid values are 'fluorescence-behavior', 'fluorescence-stimulus', 'deconvolution-behavior', "
            f"'deconvolution-stimulus', treadmill', 'respiration' or 'pupil'.")
        raise PipelineException(msg)

    ## Fetch source and target times using lookup dictionary
    source_table, source_column = data_source_lookup[source_type]
    source_times = (source_table & scan_key).fetch1(source_column).squeeze()

    target_table, target_column = data_source_lookup[target_type]
    target_times = (target_table & scan_key).fetch1(target_column).squeeze()

    ##
    ## Timing corrections
    ##

    ## Slice times if on ScanImage clock and add delay (scan_types defined near top)
    if source_type in scan_types:
        source_times = source_times[field_offset::slice_num] + ms_delay
    if target_type in scan_types:
        target_times = target_times[field_offset::slice_num] + ms_delay

    ##
    ## Interpolate into different clock if necessary
    ##

    clock_type_lookup = {
        "fluorescence-stimulus": "stimulus",
        "deconvolution-stimulus": "stimulus",
        "fluorescence-behavior": "behavior",
        "deconvolution-behavior": "behavior",
        "pupil": "behavior",
        "processed-pupil": "behavior",
        "treadmill": "behavior",
        "processed-treadmill": "behavior",
        "respiration": "odor",
    }

    sync_conversion_lookup = {
        "stimulus": stimulus.Sync,
        "behavior": stimulus.BehaviorSync,
        "odor": odor.OdorSync * odor.MesoMatch,
    }

    source_clock_type = clock_type_lookup[source_type]
    target_clock_type = clock_type_lookup[target_type]

    if source_clock_type != target_clock_type:

        interp_source_table = sync_conversion_lookup[source_clock_type]
        interp_target_table = sync_conversion_lookup[target_clock_type]

        interp_source = (interp_source_table
                         & scan_key).fetch1("frame_times").squeeze()
        interp_target = (interp_target_table
                         & scan_key).fetch1("frame_times").squeeze()

        target2source_interp = interpolate.interp1d(interp_target,
                                                    interp_source,
                                                    fill_value="extrapolate")
        target_times = target2source_interp(target_times)

    return source_times, target_times
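
A usage sketch for fetch_timing_data; the key below is hypothetical (it mirrors the keys used in the convert_clocks examples) and assumes a populated pipeline:

# Hypothetical scan key; a single field must be specified for ScanImage sources/targets
key = dict(animal_id=17797, session=4, scan_idx=7, field=1,
           segmentation_method=6, mask_id=1)

# Fluorescence times on the behavior clock as source; treadmill times mapped onto that clock
source_times, target_times = fetch_timing_data(key,
                                               source_type='fluorescence-behavior',
                                               target_type='treadmill',
                                               debug=False)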
Example #11
def convert_clocks(
    scan_key: dict,
    input_list,
    source_format: str,
    source_type: str,
    target_format: str,
    target_type: str,
    drop_single_idx: bool = True,
    debug: bool = True,
):
    """
    Converts indices or times of interest on the source clock to indices, times, or signals on the target clock. Can
    convert a collection of event-triggered fragments of indices/times or a single flat list. Can also be used as an
    automated times/signal fetching function by setting input_list to None and source_type equal to target_type.

        Parameters:

                scan_key: A dictionary specifying a single scan and/or field. A single field must be defined if
                          requesting a source or target from ScanImage. If key specifies a single unit, unit delay
                          will be added to all timepoints recorded. Single units can be specified via unique
                          mask_id + field or via unit_id. If only field is specified, average field delay will be
                          added.


                input_list: List/array/None. Depending on the source_format, there are many possible structures:

                             - source_format='indices'
                                 1) List/array of indices to convert or a boolean array with True values at indices
                                    of interest. Indices can be discontinuous fragments (something like 20 indices
                                    around several spikes) and the function will return a list of lists, each
                                    containing the corresponding target_idx fragments.
                                 2) None. Set input_list to None to use all indices where source time is not NaN.

                             - source_format='times'
                                 1) List of lists containing [start,stop] boundaries for times of interest on the source
                                    clock. Start/stop times are inclusive (>=,<=). These boundaries are converted to all
                                    indices equal to or between recording times on the target recording.
                                 2) None. Set input_list to None to use all times where source time is not NaN.


                source_format: A string specifying what the input_list variable represents and what structure to expect.
                               See details for input_list variable to learn more.
                               Supported options:
                                   'indices', 'times'


                source_type: A string specifying what indices/times you want to convert from. Fluorescence and
                             deconvolution have a dash followed by "behavior" or "stimulus" to refer to which clock
                             you are using.
                             Supported options:
                                 'fluorescence-stimulus', 'deconvolution-stimulus', 'fluorescence-behavior',
                                 'deconvolution-behavior', 'pupil', 'treadmill', 'respiration'


                target_format: A string specifying what values to return. "Times" has a dash followed by "source"
                               or "target" to specify if the returning times should be on the source clock or on
                               the target clock. If set to "signal", returns interpolated target signal on the
                               corresponding source_type recording times specified by input_list.
                               Supported options:
                                   'indices', 'times-source', 'times-target', 'signal'


                target_type: A string specifying what indices to convert into. Fluorescence and deconvolution have a
                             dash followed by "behavior" or "stimulus" to refer to which clock you are using.
                             Supported options:
                                 'fluorescence-stimulus', 'deconvolution-stimulus', 'fluorescence-behavior',
                                 'deconvolution-behavior', 'pupil', 'treadmill', 'respiration'


                drop_single_idx: Boolean with the following behavior
                                     - True: Drop any signal fragments which would create a signal of length 1.
                                     - False: Raise an error and stop if any list of indices leads to a signal
                                              of length 1.
                                    ex. Source IDX [1,2,...,300] on a 500Hz recording will only correspond to
                                        target IDX [1] if the target is recorded at 1Hz.


                debug: If True, print helpful debug text while running


        Returns:

                requested_array: Numpy array of corresponding indices, times, or interpolated target signal. If
                multiple continuous fragments or time boundaries are in input_list, return value is a list of arrays.


        Warnings:

                * NaN refilling for signal interpolation will only refill values if NaNs stretch for multiple indices
                  on target clock

                * Recording points where the time value is NaN are dropped from analysis/processing


        Examples:

                Fetch fluorescence signal for one unit:

                    >>>key = dict(animal_id=17797, session=4, scan_idx=7, field=1, segmentation_method=6, mask_id=1, tracking_method=2)
                    >>>settings = dict(scan_key=key, input_list=None, source_format='indices', source_type='fluorescence-behavior',
                                       target_format='signal', target_type='fluorescence-behavior', drop_single_idx=True, debug=False)
                    >>>fluorescence_signal = convert_clocks(**settings)


                Fetch recording times (on behavior clock) for one unit:

                    >>>key = dict(animal_id=17797, session=4, scan_idx=7, field=1, segmentation_method=6, mask_id=1, tracking_method=2)
                    >>>settings = dict(scan_key=key, input_list=None, source_format='indices', source_type='fluorescence-behavior',
                                       target_format='times-source', target_type='fluorescence-behavior', drop_single_idx=True, debug=False)
                    >>>fluorescence_times = convert_clocks(**settings)


                Interpolate entire treadmill trace to fluorescence recording times:

                    >>>key = dict(animal_id=17797, session=4, scan_idx=7, field=1, segmentation_method=6, mask_id=1, tracking_method=2)
                    >>>settings = dict(scan_key=key, input_list=None, source_format='indices', source_type='fluorescence-behavior',
                                       target_format='signal', target_type='treadmill', drop_single_idx=True, debug=False)
                    >>>interpolated_treadmill = convert_clocks(**settings)


                Convert discontinuous pupil IDX fragments to treadmill times (on behavior clock):

                    >>>key = dict(animal_id=17797, session=4, scan_idx=7, field=1, segmentation_method=6, mask_id=1, tracking_method=2)
                    >>>input_indices = np.concatenate(((np.arange(1000)), np.arange(1005, 2000)))
                    >>>settings = dict(scan_key=key, input_list=input_indices, source_format='indices', source_type='pupil',
                                       target_format='times-source', target_type='treadmill', drop_single_idx=True, debug=False)
                    >>>treadmill_time_fragments = convert_clocks(**settings)


                Convert fluorescence time boundaries on behavior clock to fluorescence times on stimulus clock:

                    >>>key = dict(animal_id=17797, session=4, scan_idx=7, field=1, segmentation_method=6, mask_id=1, tracking_method=2)
                    >>>time_boundaries = [[400, 500], [501, 601]]
                    >>>settings = dict(scan_key=key, input_list=time_boundaries, source_format='times', source_type='fluorescence-behavior',
                                       target_format='times-target', target_type='fluorescence-stimulus', drop_single_idx=True, debug=False)
                    >>>fluorescence_stimulus_times_in_bounds = convert_clocks(**settings)
    """

    ##
    ## Make settings strings lowercase
    ##

    source_format = source_format.lower()
    source_type = source_type.lower()
    target_format = target_format.lower()
    target_type = target_type.lower()

    ##
    ## Fetch source and target times, along with converting between Stimulus or Behavior clock if needed
    ##

    source_times_source_clock, target_times_source_clock = fetch_timing_data(
        scan_key, source_type, target_type, debug)
    target_times_target_clock, source_times_target_clock = fetch_timing_data(
        scan_key, target_type, source_type, debug)

    ##
    ## Convert indices to a list of numbers if argument equals None or a Boolean mask
    ##

    if source_format == "indices":
        if input_list is None:
            input_list = np.arange(len(source_times_source_clock))
        elif isinstance(input_list[0], (bool, np.bool_)):
            input_list = np.where(input_list)[0]
        elif isinstance(input_list[0], (list, np.ndarray)):
            ## Flatten list of lists into a single index list
            input_list = [item for sublist in input_list for item in sublist]
        else:
            ## Check for duplicates if manually entered
            if len(np.unique(input_list)) != len(input_list):
                msg = (
                    "Duplicate entries found in the provided indices array! "
                    "Fix the duplication or use np.unique() on the indices array."
                )
                raise PipelineException(msg)

    ## Convert times to indices to make None input work smoothly
    if "times" in source_format and input_list is None:
        input_list = np.arange(len(source_times_source_clock))
        source_format = "indices"

    ##
    ## Convert source indices to time boundaries, then convert time boundaries into target indices
    ##

    ## Convert indices into start/end times for each continuous fragment (incrementing by 1)
    if source_format == "indices":
        time_boundaries = find_time_boundaries(input_list,
                                               source_times_source_clock,
                                               drop_single_idx)
    elif "times" in source_format:
        time_boundaries = input_list
    else:
        msg = (f"Source format {source_format} not supported. "
               f"Valid options are 'indices' and 'times'.")
        raise PipelineException(msg)

    target_indices = []
    single_idx_count = 0

    ## Loop through start & end times and create list of indices corresponding to that block of time
    with np.errstate(invalid="ignore"):
        for [start, end] in time_boundaries:
            target_idx = np.where(
                np.logical_and(target_times_source_clock >= start,
                               target_times_source_clock <= end))[0]
            if len(target_idx) < 2:
                if drop_single_idx:
                    single_idx_count += 1
                else:
                    msg = (
                        f"Event of length {len(target_idx)} found. "
                        f"Set drop_single_idx to True to suppress these errors."
                    )
                    raise PipelineException(msg)
            else:
                target_indices.append(target_idx)

    if debug:
        print(
            f"Indices converted. {single_idx_count} events of length 0 or 1 dropped."
        )

    ##
    ## Interpolate related signal if requested, else just return the target_indices found.
    ##

    if target_format == "signal":

        ## Define source_indices if they're not already defined
        if source_format == "indices":
            source_indices = find_idx_boundaries(input_list, drop_single_idx)
        elif "times" in source_format:
            source_indices = convert_clocks(
                scan_key,
                input_list,
                source_format,
                source_type,
                "indices",
                target_type,
                drop_single_idx,
                False,
            )
        else:
            msg = (f"Source format {source_format} not supported. "
                   f"Valid options are 'indices' and 'times'.")
            raise PipelineException(msg)

        ## Create full interpolated signal
        interpolated_signal = interpolate_signal_data(
            scan_key,
            source_type,
            target_type,
            source_times_source_clock,
            target_times_source_clock,
            debug=debug,
        )

        ## Split indices given into fragments based on which ones are continuous (incrementing by 1)
        target_signal_fragments = []
        for idx_fragment in source_indices:
            idx_fragment_mask = ~np.isnan(
                source_times_source_clock[idx_fragment])
            masked_idx_fragment = idx_fragment[idx_fragment_mask]
            target_signal_fragments.append(
                interpolated_signal[masked_idx_fragment])

        ## If full signal is converted, remove wrapping list
        if len(target_signal_fragments) == 1:
            target_signal_fragments = target_signal_fragments[0]

        converted_values = target_signal_fragments

    elif "times" in target_format:

        ## Set type of times to use
        if target_format == "times-source":
            target_times = target_times_source_clock
        elif target_format == "times-target":
            target_times = target_times_target_clock
        else:
            msg = f"'Times' target format must be 'times-source' or 'times-target'. Value was {target_format}."
            raise PipelineException(msg)

        ## Convert indices to times and return
        source_idx_to_target_times = []

        for target_idx_list in target_indices:
            source_idx_to_target_times.append(target_times[target_idx_list])

        if len(source_idx_to_target_times) == 1:
            source_idx_to_target_times = source_idx_to_target_times[0]

        converted_values = source_idx_to_target_times

    elif target_format == "indices":

        if len(target_indices) == 1:
            target_indices = target_indices[0]

        converted_values = target_indices

    else:

        msg = (
            f"Target format {target_format} is not supported. "
            f"Valid options are 'indices', 'times-source', 'times-target', 'signal'."
        )
        raise PipelineException(msg)

    return converted_values
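
The core of the conversion is the boundary-to-index step used in the loop above; a standalone sketch with synthetic (hypothetical) times:

import numpy as np

# Hypothetical target recording times on the source clock (4 Hz)
target_times_source_clock = np.arange(0, 5, 0.25)

# Convert [start, stop] time boundaries into target indices (boundaries inclusive)
time_boundaries = [[0.5, 1.0], [3.0, 3.5]]
target_indices = []
for start, end in time_boundaries:
    idx = np.where((target_times_source_clock >= start)
                   & (target_times_source_clock <= end))[0]
    target_indices.append(idx)

print(target_indices)  # [array([2, 3, 4]), array([12, 13, 14])]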
Example #12
def interpolate_signal_data(
    scan_key: dict,
    source_type: str,
    target_type: str,
    source_times,
    target_times,
    debug: bool = True,
):
    """
    Interpolates the target_type recording onto source_times. If target FPS is higher than source FPS, a lowpass Hamming
    filter at the source rate is run over the target_type recording before interpolating. Automatically slices ScanImage
    times and error checks for length mismatches.

        Parameters:

                scan_key: A dictionary specifying a single scan and/or field. A single field must be defined if requesting
                          a source or target from ScanImage. If key specifies a single unit, unit delay will be added to
                          all timepoints recorded. Single units can be specified via unique mask_id + field or via unit_id.
                          If only field is specified, average field delay will be added.

                source_type: A string specifying what indices you want to convert from. Fluorescence and deconvolution
                             have a dash followed by "behavior" or "stimulus" to refer to which clock you are using.
                             Supported options:
                                 'fluorescence-stimulus', 'deconvolution-stimulus', 'fluorescence-behavior',
                                 'deconvolution-behavior', 'pupil', 'treadmill', 'respiration'

                target_type: A string specifying what indices to convert into. Fluorescence and deconvolution have a
                             dash followed by "behavior" or "stimulus" to refer to which clock you are using.
                             Supported options:
                                 'fluorescence-stimulus', 'deconvolution-stimulus', 'fluorescence-behavior',
                                 'deconvolution-behavior', 'pupil', 'treadmill', 'respiration'

                source_times: Numpy array of times for source recording on source clock. Assumed to be corrected for
                              delays such as average field delay or specific unit delay.

                target_times: Numpy array of times for target recording on source clock. Assumed to be corrected for
                              delays such as average field delay or specific unit delay.

                debug: If True, print helpful debug text while running


        Returns:

                interpolated_signal: Numpy array of the target_type signal interpolated to the recording times of source_type
    """

    ## Make settings strings lowercase
    source_type = source_type.lower()
    target_type = target_type.lower()

    ## Define the pipe (meso/reso) to use
    if len(fuse.MotionCorrection & scan_key) == 0:
        msg = f"scan_key {scan_key} not found in fuse.MotionCorrection."
        raise PipelineException(msg)
    pipe = (fuse.MotionCorrection & scan_key).module

    ## Run helpful error checking
    if source_type == "pupil":
        tracking_method_num = len(
            dj.U("tracking_method") & (pupil.FittedPupil & scan_key))
        if tracking_method_num > 1:
            msg = (
                "More than one pupil tracking method found for entered scan. "
                "Specify tracking_method in scan key (tracking_method=2 for DeepLabCut)."
            )
            raise PipelineException(msg)

    ## Fetch required signal
    ## Note: Pupil requires .fetch() while other signals require .fetch1().
    ##       It is easier to make an if-elif-else structure than a lookup dictionary in this case.
    if target_type in ("fluorescence-stimulus", "fluorescence-behavior"):
        target_signal = (pipe.Fluorescence.Trace & scan_key).fetch1("trace")
    elif target_type in ("deconvolution-stimulus", "deconvolution-behavior"):
        unit_key = (pipe.ScanSet.Unit & scan_key).fetch1()
        target_signal = (pipe.Activity.Trace & unit_key).fetch1("trace")
    elif target_type == "pupil":
        target_signal = (pupil.FittedPupil.Circle & scan_key).fetch("radius")
    elif target_type == "treadmill":
        target_signal = (treadmill.Treadmill
                         & scan_key).fetch1("treadmill_vel")
    elif target_type == "respiration":
        target_signal = ((odor.Respiration * odor.MesoMatch)
                         & scan_key).fetch1("trace")
    else:
        msg = f"Error, target type {target_type} is not supported. Cannot fetch signal data."
        raise PipelineException(msg)

    ## Calculate FPS to determine if lowpass filtering is needed
    source_fps = 1 / np.nanmedian(np.diff(source_times))
    target_fps = 1 / np.nanmedian(np.diff(target_times))

    ## Fill NaNs to prevent interpolation errors, but store NaNs for later to add back in after interpolating
    source_replace_nans = None  # Use this as a switch to refill things later
    if sum(np.isnan(target_signal)) > 0:
        target_nan_indices = np.isnan(target_signal)
        time_nan_indices = np.isnan(target_times)
        target_replace_nans = np.logical_and(target_nan_indices,
                                             ~time_nan_indices)
        if sum(target_replace_nans) > 0:
            source_replace_nans = convert_clocks(
                scan_key,
                np.where(target_replace_nans)[0],
                "indices",
                target_type,
                "indices",
                source_type,
                debug=False,
            )
        nan_filler_func = (shared.FilterMethod & {
            "filter_method": "NaN Filler"
        }).run_filter
        target_signal = nan_filler_func(target_signal)
        if debug:
            biggest_time_gap = np.nanmax(
                np.diff(target_times[np.where(~target_replace_nans)[0]]))
            msg = (
                f"Found NaNs in {sum(target_nan_indices)} locations, which corresponds to "
                f"{round(100*sum(target_nan_indices)/len(target_signal),2)}% of total signal. "
                f"Largest NaN gap found: {round(biggest_time_gap, 2)} seconds."
            )
            print(msg)

    ## Lowpass signal if needed
    if source_fps < target_fps:
        if debug:
            msg = (
                f"Target FPS of {round(target_fps,2)} is greater than source FPS {round(source_fps,2)}. "
                f"Hamming lowpass filtering target signal before interpolation"
            )
            print(msg)
        target_signal = shared.FilterMethod._lowpass_hamming(
            signal=target_signal,
            signal_freq=target_fps,
            lowpass_freq=source_fps)

    ## Timing and recording array lengths can differ slightly if the recording was stopped mid-scan: timings for the
    ## next X depths are still recorded, but fluorescence values are dropped if not all depths completed. The length
    ## difference therefore should not exceed the number of depths in the scan.
    if len(target_times) < len(target_signal):
        msg = (
            f"More recording values than target time values exist! This should not be possible.\n"
            f"Target time length: {len(target_times)}\n"
            f"Target signal length: {len(target_signal)}")
        raise PipelineException(msg)

    elif len(target_times) > len(target_signal):

        scan_res = pipe.ScanInfo.proj() & scan_key  ## To make sure we select all fields
        z_plane_num = len(dj.U("z") & (pipe.ScanInfo.Field & scan_res))
        if (len(target_times) - len(target_signal)) > z_plane_num:
            msg = (
                f"Extra timing values exceeds reasonable error bounds. "
                f"Error length of {len(target_times) - len(target_signal)} with only {z_plane_num} z-planes."
            )
            raise PipelineException(msg)

        else:

            if debug:
                length_diff = np.abs(len(target_times) - len(target_signal))
                msg = (
                    f"Target times and target signal show a length mismatch within acceptable error. "
                    f"Difference of {length_diff} is within the acceptable bound of {z_plane_num} z-planes."
                )
                print(msg)

            # Truncate the *target* arrays so the interpolation inputs match in length
            shorter_length = min(len(target_times), len(target_signal))
            target_times = target_times[:shorter_length]
            target_signal = target_signal[:shorter_length]

    ## Interpolating target signal into source timings
    signal_interp = interpolate.interp1d(target_times,
                                         target_signal,
                                         bounds_error=False)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        interpolated_signal = signal_interp(source_times)
    if source_replace_nans is not None:
        for source_nan_idx in source_replace_nans:
            interpolated_signal[source_nan_idx] = np.nan

    return interpolated_signal
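
A minimal sketch of the final interpolation step on synthetic (hypothetical) signals; the lowpass and NaN-refill stages are omitted:

import numpy as np
from scipy import interpolate

# Hypothetical: a 100 Hz target signal resampled onto 10 Hz source times
target_times = np.arange(0, 10, 0.01)
target_signal = np.sin(target_times)
source_times = np.arange(0, 10, 0.1)

# bounds_error=False leaves out-of-range samples as NaN instead of raising
signal_interp = interpolate.interp1d(target_times, target_signal, bounds_error=False)
interpolated_signal = signal_interp(source_times)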
Example #13
class CreateTaskAPITest(APITest):
    def url(self):
        return "/apigw/create_task/{template_id}/{project_id}/"

    @mock.patch(TASKINSTANCE_CREATE_PIPELINE,
                MagicMock(return_value=TEST_DATA))
    @mock.patch(
        TASKINSTANCE_CREATE,
        MagicMock(return_value=MockTaskFlowInstance(id=TEST_TASKFLOW_ID)),
    )
    @mock.patch(APIGW_CREATE_TASK_JSON_SCHEMA_VALIDATE, MagicMock())
    def test_create_task__success(self):
        pt1 = MockPipelineTemplate(id=1, name="pt1")

        tmpl = MockTaskTemplate(id=1, pipeline_template=pt1)
        proj = MockProject(
            project_id=TEST_PROJECT_ID,
            name=TEST_PROJECT_NAME,
            bk_biz_id=TEST_BIZ_CC_ID,
            from_cmdb=True,
        )

        with mock.patch(PROJECT_GET, MagicMock(return_value=proj)):
            with mock.patch(
                    TASKTEMPLATE_SELECT_RELATE,
                    MagicMock(return_value=MockQuerySet(get_result=tmpl)),
            ):
                assert_data = {
                    "task_id": TEST_TASKFLOW_ID,
                    "task_url": TEST_TASKFLOW_URL,
                    "pipeline_tree": TEST_TASKFLOW_PIPELINE_TREE,
                }
                response = self.client.post(
                    path=self.url().format(template_id=TEST_TEMPLATE_ID,
                                           project_id=TEST_PROJECT_ID),
                    data=json.dumps({
                        "name":
                        "name",
                        "constants": {},
                        "exclude_task_nodes_id":
                        ["ne584c1e69f53d109f0d99eacc3bd670"],
                        "flow_type":
                        "common",
                    }),
                    content_type="application/json",
                    HTTP_BK_APP_CODE=TEST_APP_CODE,
                    HTTP_BK_USERNAME=TEST_USERNAME,
                )

                TaskFlowInstance.objects.create_pipeline_instance_exclude_task_nodes.assert_called_once_with(
                    tmpl,
                    {
                        "name": "name",
                        "creator": "",
                        "description": ""
                    },
                    {},
                    ["ne584c1e69f53d109f0d99eacc3bd670"],
                    [],
                )

                TaskFlowInstance.objects.create.assert_called_once_with(
                    project=proj,
                    category=tmpl.category,
                    pipeline_instance=TEST_DATA,
                    template_id=TEST_TEMPLATE_ID,
                    template_source="project",
                    create_method="api",
                    create_info=TEST_APP_CODE,
                    flow_type="common",
                    current_flow="execute_task",
                    engine_ver=2,
                )

                data = json.loads(response.content)

                self.assertTrue(data["result"], msg=data)
                self.assertEqual(data["data"], assert_data)

                TaskFlowInstance.objects.create_pipeline_instance_exclude_task_nodes.reset_mock(
                )
                TaskFlowInstance.objects.create.reset_mock()

            pt1 = MockPipelineTemplate(id=1, name="pt1")

            tmpl = MockCommonTemplate(id=1, pipeline_template=pt1)

            with mock.patch(
                    COMMONTEMPLATE_SELECT_RELATE,
                    MagicMock(return_value=MockQuerySet(get_result=tmpl)),
            ):
                assert_data = {
                    "task_id": TEST_TASKFLOW_ID,
                    "task_url": TEST_TASKFLOW_URL,
                    "pipeline_tree": TEST_TASKFLOW_PIPELINE_TREE,
                }
                response = self.client.post(
                    path=self.url().format(template_id=TEST_TEMPLATE_ID,
                                           project_id=TEST_PROJECT_ID),
                    data=json.dumps({
                        "name":
                        "name",
                        "constants": {},
                        "exclude_task_nodes_id":
                        ["ne584c1e69f53d109f0d99eacc3bd670"],
                        "template_source":
                        "common",
                        "flow_type":
                        "common",
                    }),
                    content_type="application/json",
                    HTTP_BK_APP_CODE=TEST_APP_CODE,
                    HTTP_BK_USERNAME=TEST_USERNAME,
                )

                TaskFlowInstance.objects.create_pipeline_instance_exclude_task_nodes.assert_called_once_with(
                    tmpl,
                    {
                        "name": "name",
                        "creator": "",
                        "description": ""
                    },
                    {},
                    ["ne584c1e69f53d109f0d99eacc3bd670"],
                    [],
                )

                TaskFlowInstance.objects.create.assert_called_once_with(
                    project=proj,
                    category=tmpl.category,
                    pipeline_instance=TEST_DATA,
                    template_id=TEST_TEMPLATE_ID,
                    template_source="common",
                    create_method="api",
                    create_info=TEST_APP_CODE,
                    flow_type="common",
                    current_flow="execute_task",
                    engine_ver=2,
                )

                data = json.loads(response.content)

                self.assertTrue(data["result"], msg=data)
                self.assertEqual(data["data"], assert_data)

    @mock.patch(TASKINSTANCE_CREATE_PIPELINE,
                MagicMock(return_value=TEST_DATA))
    @mock.patch(
        TASKINSTANCE_CREATE,
        MagicMock(return_value=MockTaskFlowInstance(id=TEST_TASKFLOW_ID)),
    )
    @mock.patch(APIGW_CREATE_TASK_JSON_SCHEMA_VALIDATE, MagicMock())
    @mock.patch(APIGW_CREATE_TASK_NODE_NAME_HANDLE, MagicMock())
    @mock.patch(APIGW_CREATE_TASK_VALIDATE_WEB_PIPELINE_TREE, MagicMock())
    @mock.patch(TASKINSTANCE_CREATE_PIPELINE_INSTANCE,
                MagicMock(return_value=TEST_DATA))
    def test_create_task__success_with_tree(self):
        pt1 = MockPipelineTemplate(id=1, name="pt1")

        tmpl = MockTaskTemplate(id=1, pipeline_template=pt1)
        proj = MockProject(
            project_id=TEST_PROJECT_ID,
            name=TEST_PROJECT_NAME,
            bk_biz_id=TEST_BIZ_CC_ID,
            from_cmdb=True,
        )

        with mock.patch(PROJECT_GET, MagicMock(return_value=proj)):
            with mock.patch(
                    TASKTEMPLATE_SELECT_RELATE,
                    MagicMock(return_value=MockQuerySet(get_result=tmpl)),
            ):
                assert_data = {
                    "task_id": TEST_TASKFLOW_ID,
                    "task_url": TEST_TASKFLOW_URL,
                    "pipeline_tree":
                    copy.deepcopy(TEST_TASKFLOW_PIPELINE_TREE),
                }
                response = self.client.post(
                    path=self.url().format(template_id=TEST_TEMPLATE_ID,
                                           project_id=TEST_PROJECT_ID),
                    data=json.dumps({
                        "name": "name",
                        "constants": {
                            "key1": "value1",
                            "key2": "value2"
                        },
                        "pipeline_tree": TEST_PIPELINE_TREE,
                        "flow_type": "common",
                    }),
                    content_type="application/json",
                    HTTP_BK_APP_CODE=TEST_APP_CODE,
                    HTTP_BK_USERNAME=TEST_USERNAME,
                )

                TaskFlowInstance.objects.create_pipeline_instance.assert_called_once_with(
                    template=tmpl,
                    name="name",
                    creator="",
                    description="",
                    pipeline_tree={"constants": {
                        "key1": {
                            "value": "value1"
                        }
                    }},
                )

                TaskFlowInstance.objects.create.assert_called_once_with(
                    project=proj,
                    category=tmpl.category,
                    pipeline_instance=TEST_DATA,
                    template_id=TEST_TEMPLATE_ID,
                    template_source="project",
                    create_method="api",
                    create_info=TEST_APP_CODE,
                    flow_type="common",
                    current_flow="execute_task",
                    engine_ver=2,
                )

                data = json.loads(response.content)

                self.assertTrue(data["result"], msg=data)
                self.assertEqual(data["data"], assert_data)

    @mock.patch(
        PROJECT_GET,
        MagicMock(return_value=MockProject(
            project_id=TEST_PROJECT_ID,
            name=TEST_PROJECT_NAME,
            bk_biz_id=TEST_BIZ_CC_ID,
            from_cmdb=True,
        )),
    )
    @mock.patch(TASKTEMPLATE_SELECT_RELATE,
                MagicMock(return_value=MockQuerySet()))
    @mock.patch(COMMONTEMPLATE_SELECT_RELATE,
                MagicMock(return_value=MockQuerySet()))
    @mock.patch(
        APIGW_CREATE_TASK_JSON_SCHEMA_VALIDATE,
        MagicMock(side_effect=jsonschema.ValidationError("")),
    )
    def test_create_task__validate_fail(self):
        response = self.client.post(
            path=self.url().format(template_id=TEST_TEMPLATE_ID,
                                   project_id=TEST_PROJECT_ID),
            data=json.dumps({
                "name":
                "name",
                "constants": {},
                "exclude_task_node_id": ["ne584c1e69f53d109f0d99eacc3bd670"],
            }),
            content_type="application/json",
            HTTP_BK_APP_CODE=TEST_APP_CODE,
            HTTP_BK_USERNAME=TEST_USERNAME,
        )

        data = json.loads(response.content)

        self.assertFalse(data["result"])
        self.assertTrue("message" in data)

        response = self.client.post(
            path=self.url().format(template_id=TEST_TEMPLATE_ID,
                                   project_id=TEST_PROJECT_ID),
            data=json.dumps({
                "name":
                "name",
                "constants": {},
                "exclude_task_node_id": ["ne584c1e69f53d109f0d99eacc3bd670"],
                "template_source":
                "common",
            }),
            content_type="application/json",
            HTTP_BK_APP_CODE=TEST_APP_CODE,
            HTTP_BK_USERNAME=TEST_USERNAME,
        )

        data = json.loads(response.content)

        self.assertFalse(data["result"])
        self.assertTrue("message" in data)

    @mock.patch(
        PROJECT_GET,
        MagicMock(return_value=MockProject(
            project_id=TEST_PROJECT_ID,
            name=TEST_PROJECT_NAME,
            bk_biz_id=TEST_BIZ_CC_ID,
            from_cmdb=True,
        )),
    )
    @mock.patch(TASKTEMPLATE_SELECT_RELATE,
                MagicMock(return_value=MockQuerySet()))
    @mock.patch(COMMONTEMPLATE_SELECT_RELATE,
                MagicMock(return_value=MockQuerySet()))
    @mock.patch(TASKINSTANCE_CREATE_PIPELINE, MagicMock(return_value=""))
    @mock.patch(APIGW_CREATE_TASK_JSON_SCHEMA_VALIDATE, MagicMock())
    def test_create_task__without_app_code(self):
        response = self.client.post(
            path=self.url().format(template_id=TEST_TEMPLATE_ID,
                                   project_id=TEST_PROJECT_ID),
            data=json.dumps({
                "constants": {},
                "name":
                "test",
                "exclude_task_node_id": ["ne584c1e69f53d109f0d99eacc3bd670"],
            }),
            content_type="application/json",
            HTTP_BK_APP_CODE=TEST_APP_CODE,
            HTTP_BK_USERNAME=TEST_USERNAME,
        )

        data = json.loads(response.content)

        self.assertFalse(data["result"])
        self.assertTrue("message" in data)

        response = self.client.post(
            path=self.url().format(template_id=TEST_TEMPLATE_ID,
                                   project_id=TEST_PROJECT_ID),
            data=json.dumps({
                "constants": {},
                "name":
                "test",
                "exclude_task_node_id": ["ne584c1e69f53d109f0d99eacc3bd670"],
                "template_source":
                "common",
            }),
            content_type="application/json",
            HTTP_BK_APP_CODE=TEST_APP_CODE,
            HTTP_BK_USERNAME=TEST_USERNAME,
        )

        data = json.loads(response.content)

        self.assertFalse(data["result"])
        self.assertTrue("message" in data)

    @mock.patch(
        PROJECT_GET,
        MagicMock(return_value=MockProject(
            project_id=TEST_PROJECT_ID,
            name=TEST_PROJECT_NAME,
            bk_biz_id=TEST_BIZ_CC_ID,
            from_cmdb=True,
        )),
    )
    @mock.patch(TASKINSTANCE_CREATE_PIPELINE,
                MagicMock(side_effect=PipelineException()))
    @mock.patch(APIGW_CREATE_TASK_JSON_SCHEMA_VALIDATE, MagicMock())
    def test_create_task__create_pipeline_raise(self):
        pt1 = MockPipelineTemplate(id=1, name="pt1")

        tmpl = MockTaskTemplate(id=1, pipeline_template=pt1)

        with mock.patch(
                TASKTEMPLATE_SELECT_RELATE,
                MagicMock(return_value=MockQuerySet(get_result=tmpl)),
        ):
            response = self.client.post(
                path=self.url().format(template_id=TEST_TEMPLATE_ID,
                                       project_id=TEST_PROJECT_ID),
                data=json.dumps({
                    "name":
                    "name",
                    "constants": {},
                    "exclude_task_node_id":
                    ["ne584c1e69f53d109f0d99eacc3bd670"],
                }),
                content_type="application/json",
                HTTP_BK_APP_CODE=TEST_APP_CODE,
                HTTP_BK_USERNAME=TEST_USERNAME,
            )

            data = json.loads(response.content)

            self.assertFalse(data["result"])
            self.assertTrue("message" in data)

        pt1 = MockPipelineTemplate(id=1, name="pt1")

        tmpl = MockCommonTemplate(id=1, pipeline_template=pt1)

        with mock.patch(
                COMMONTEMPLATE_SELECT_RELATE,
                MagicMock(return_value=MockQuerySet(get_result=tmpl)),
        ):
            response = self.client.post(
                path=self.url().format(template_id=TEST_TEMPLATE_ID,
                                       project_id=TEST_PROJECT_ID),
                data=json.dumps({
                    "name":
                    "name",
                    "constants": {},
                    "exclude_task_node_id":
                    ["ne584c1e69f53d109f0d99eacc3bd670"],
                    "template_source":
                    "common",
                }),
                content_type="application/json",
                HTTP_BK_APP_CODE=TEST_APP_CODE,
                HTTP_BK_USERNAME=TEST_USERNAME,
            )

            data = json.loads(response.content)

            self.assertFalse(data["result"])
            self.assertTrue("message" in data)

    @mock.patch(
        PROJECT_GET,
        MagicMock(return_value=MockProject(
            project_id=TEST_PROJECT_ID,
            name=TEST_PROJECT_NAME,
            bk_biz_id=TEST_BIZ_CC_ID,
            from_cmdb=True,
        )),
    )
    @mock.patch(TASKINSTANCE_CREATE_PIPELINE, MagicMock(return_value=""))
    @mock.patch(APIGW_CREATE_TASK_JSON_SCHEMA_VALIDATE, MagicMock())
    def test_create_task__create_pipeline_fail(self):
        pt1 = MockPipelineTemplate(id=1, name="pt1")

        tmpl = MockTaskTemplate(id=1, pipeline_template=pt1)

        with mock.patch(
                TASKTEMPLATE_SELECT_RELATE,
                MagicMock(return_value=MockQuerySet(get_result=tmpl)),
        ):
            response = self.client.post(
                path=self.url().format(template_id=TEST_TEMPLATE_ID,
                                       project_id=TEST_PROJECT_ID),
                data=json.dumps({
                    "name": "name",
                    "constants": {},
                    "exclude_task_node_id": "exclude_task_node_id"
                }),
                content_type="application/json",
                HTTP_BK_APP_CODE=TEST_APP_CODE,
                HTTP_BK_USERNAME=TEST_USERNAME,
            )

            data = json.loads(response.content)

            self.assertFalse(data["result"])
            self.assertTrue("message" in data)

        pt1 = MockPipelineTemplate(id=1, name="pt1")

        tmpl = MockCommonTemplate(id=1, pipeline_template=pt1)

        with mock.patch(
                COMMONTEMPLATE_SELECT_RELATE,
                MagicMock(return_value=MockQuerySet(get_result=tmpl)),
        ):
            response = self.client.post(
                path=self.url().format(template_id=TEST_TEMPLATE_ID,
                                       project_id=TEST_PROJECT_ID),
                data=json.dumps({
                    "name":
                    "name",
                    "constants": {},
                    "exclude_task_node_id":
                    ["ne584c1e69f53d109f0d99eacc3bd670"],
                    "template_source":
                    "common",
                }),
                content_type="application/json",
                HTTP_BK_APP_CODE=TEST_APP_CODE,
                HTTP_BK_USERNAME=TEST_USERNAME,
            )

            data = json.loads(response.content)

            self.assertFalse(data["result"])
            self.assertTrue("message" in data)

    @mock.patch(
        PROJECT_GET,
        MagicMock(return_value=MockProject(
            project_id=TEST_PROJECT_ID,
            name=TEST_PROJECT_NAME,
            bk_biz_id=TEST_BIZ_CC_ID,
            from_cmdb=True,
        )),
    )
    @mock.patch(APIGW_CREATE_TASK_JSON_SCHEMA_VALIDATE, MagicMock())
    @mock.patch(APIGW_CREATE_TASK_NODE_NAME_HANDLE, MagicMock())
    @mock.patch(APIGW_CREATE_TASK_VALIDATE_WEB_PIPELINE_TREE,
                MagicMock(side_effect=Exception()))
    def test_create_task__validate_pipeline_tree_error(self):
        pt1 = MockPipelineTemplate(id=1, name="pt1")

        tmpl = MockTaskTemplate(id=1, pipeline_template=pt1)

        with mock.patch(
                TASKTEMPLATE_SELECT_RELATE,
                MagicMock(return_value=MockQuerySet(get_result=tmpl)),
        ):
            response = self.client.post(
                path=self.url().format(template_id=TEST_TEMPLATE_ID,
                                       project_id=TEST_PROJECT_ID),
                data=json.dumps({
                    "name":
                    "name",
                    "pipeline_tree":
                    TEST_PIPELINE_TREE,
                    "exclude_task_node_id":
                    ["ne584c1e69f53d109f0d99eacc3bd670"],
                }),
                content_type="application/json",
                HTTP_BK_APP_CODE=TEST_APP_CODE,
                HTTP_BK_USERNAME=TEST_USERNAME,
            )

            data = json.loads(response.content)

            self.assertFalse(data["result"])
            self.assertTrue("message" in data)
            self.assertEqual(data["code"], err_code.UNKNOWN_ERROR.code)

            create_task.standardize_pipeline_node_name.assert_called_once_with(
                TEST_PIPELINE_TREE)
            create_task.validate_web_pipeline_tree.assert_called_once_with(
                TEST_PIPELINE_TREE)
            create_task.standardize_pipeline_node_name.reset_mock()
            create_task.validate_web_pipeline_tree.reset_mock()

        pt1 = MockPipelineTemplate(id=1, name="pt1")

        tmpl = MockCommonTemplate(id=1, pipeline_template=pt1)

        with mock.patch(
                COMMONTEMPLATE_SELECT_RELATE,
                MagicMock(return_value=MockQuerySet(get_result=tmpl)),
        ):
            response = self.client.post(
                path=self.url().format(template_id=TEST_TEMPLATE_ID,
                                       project_id=TEST_PROJECT_ID),
                data=json.dumps({
                    "name":
                    "name",
                    "pipeline_tree":
                    TEST_PIPELINE_TREE,
                    "exclude_task_node_id":
                    ["ne584c1e69f53d109f0d99eacc3bd670"],
                    "template_source":
                    "common",
                }),
                content_type="application/json",
                HTTP_BK_APP_CODE=TEST_APP_CODE,
                HTTP_BK_USERNAME=TEST_USERNAME,
            )

            data = json.loads(response.content)

            self.assertFalse(data["result"])
            self.assertTrue("message" in data)
            self.assertEqual(data["code"], err_code.UNKNOWN_ERROR.code)

            create_task.standardize_pipeline_node_name.assert_called_once_with(
                TEST_PIPELINE_TREE)
            create_task.validate_web_pipeline_tree.assert_called_once_with(
                TEST_PIPELINE_TREE)

    @mock.patch(
        PROJECT_GET,
        MagicMock(return_value=MockProject(
            project_id=TEST_PROJECT_ID,
            name=TEST_PROJECT_NAME,
            bk_biz_id=TEST_BIZ_CC_ID,
            from_cmdb=True,
        )),
    )
    @mock.patch(APIGW_CREATE_TASK_JSON_SCHEMA_VALIDATE, MagicMock())
    @mock.patch(APIGW_CREATE_TASK_NODE_NAME_HANDLE, MagicMock())
    @mock.patch(APIGW_CREATE_TASK_VALIDATE_WEB_PIPELINE_TREE, MagicMock())
    @mock.patch(
        TASKINSTANCE_CREATE_PIPELINE_INSTANCE,
        MagicMock(side_effect=PipelineException()),
    )
    def test_create_task__create_pipeline_instance_error(self):
        pt1 = MockPipelineTemplate(id=1, name="pt1")

        tmpl = MockTaskTemplate(id=1, pipeline_template=pt1)

        with mock.patch(
                TASKTEMPLATE_SELECT_RELATE,
                MagicMock(return_value=MockQuerySet(get_result=tmpl)),
        ):
            response = self.client.post(
                path=self.url().format(template_id=TEST_TEMPLATE_ID,
                                       project_id=TEST_PROJECT_ID),
                data=json.dumps({
                    "name": "name",
                    "pipeline_tree": TEST_PIPELINE_TREE,
                    "exclude_task_node_id": ["ne584c1e69f53d109f0d99eacc3bd670"],
                }),
                content_type="application/json",
                HTTP_BK_APP_CODE=TEST_APP_CODE,
                HTTP_BK_USERNAME=TEST_USERNAME,
            )

            data = json.loads(response.content)

            self.assertFalse(data["result"])
            self.assertTrue("message" in data)
            self.assertEqual(data["code"], err_code.UNKNOWN_ERROR.code)

            TaskFlowInstance.objects.create_pipeline_instance.assert_called_once_with(
                template=tmpl,
                name="name",
                creator="",
                description="",
                pipeline_tree=TEST_PIPELINE_TREE,
            )
            TaskFlowInstance.objects.create_pipeline_instance.reset_mock()

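        # Repeat with a common template (template_source == "common").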
        pt1 = MockPipelineTemplate(id=1, name="pt1")

        tmpl = MockCommonTemplate(id=1, pipeline_template=pt1)

        with mock.patch(
                COMMONTEMPLATE_SELECT_RELATE,
                MagicMock(return_value=MockQuerySet(get_result=tmpl)),
        ):
            response = self.client.post(
                path=self.url().format(template_id=TEST_TEMPLATE_ID,
                                       project_id=TEST_PROJECT_ID),
                data=json.dumps({
                    "name": "name",
                    "pipeline_tree": TEST_PIPELINE_TREE,
                    "exclude_task_node_id": ["ne584c1e69f53d109f0d99eacc3bd670"],
                    "template_source": "common",
                }),
                content_type="application/json",
                HTTP_BK_APP_CODE=TEST_APP_CODE,
                HTTP_BK_USERNAME=TEST_USERNAME,
            )

            data = json.loads(response.content)

            self.assertFalse(data["result"])
            self.assertTrue("message" in data)
            self.assertEqual(data["code"], err_code.UNKNOWN_ERROR.code)

            TaskFlowInstance.objects.create_pipeline_instance.assert_called_once_with(
                template=tmpl,
                name="name",
                creator="",
                description="",
                pipeline_tree=TEST_PIPELINE_TREE,
            )

    @mock.patch(
        PROJECT_GET,
        MagicMock(return_value=MockProject(
            project_id=TEST_PROJECT_ID,
            name=TEST_PROJECT_NAME,
            bk_biz_id=TEST_BIZ_CC_ID,
            from_cmdb=True,
        )),
    )
    @mock.patch(APIGW_CREATE_TASK_JSON_SCHEMA_VALIDATE, MagicMock())
    @mock.patch(APIGW_CREATE_TASK_NODE_NAME_HANDLE, MagicMock())
    @mock.patch(APIGW_CREATE_TASK_VALIDATE_WEB_PIPELINE_TREE, MagicMock())
    @mock.patch(
        TASKINSTANCE_CREATE_PIPELINE_INSTANCE,
        MagicMock(side_effect=PipelineException()),
    )
    def test_create_task__create_pipeline_instance_error_with_execute_task_nodes(self):
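        # Even with execute_task_nodes_id in the payload, the patched
        # create_pipeline_instance still raises PipelineException, so the
        # view should report UNKNOWN_ERROR.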
        pt1 = MockPipelineTemplate(id=1, name="pt1")

        tmpl = MockTaskTemplate(id=1, pipeline_template=pt1)

        with mock.patch(
                TASKTEMPLATE_SELECT_RELATE,
                MagicMock(return_value=MockQuerySet(get_result=tmpl)),
        ):
            response = self.client.post(
                path=self.url().format(template_id=TEST_TEMPLATE_ID,
                                       project_id=TEST_PROJECT_ID),
                data=json.dumps({
                    "name": "name",
                    "pipeline_tree": TEST_PIPELINE_TREE,
                    "exclude_task_node_id": ["ne584c1e69f53d109f0d99eacc3bd670"],
                    "execute_task_nodes_id": ["node8c2af510c04b898e1f9ac8296296"],
                }),
                content_type="application/json",
                HTTP_BK_APP_CODE=TEST_APP_CODE,
                HTTP_BK_USERNAME=TEST_USERNAME,
            )

            data = json.loads(response.content)

            self.assertFalse(data["result"])
            self.assertTrue("message" in data)
            self.assertEqual(data["code"], err_code.UNKNOWN_ERROR.code)

            TaskFlowInstance.objects.create_pipeline_instance.assert_called_once_with(
                template=tmpl,
                name="name",
                creator="",
                description="",
                pipeline_tree=TEST_PIPELINE_TREE,
            )
            TaskFlowInstance.objects.create_pipeline_instance.reset_mock()

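        # Same scenario, this time against a common template.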
        pt1 = MockPipelineTemplate(id=1, name="pt1")

        tmpl = MockCommonTemplate(id=1, pipeline_template=pt1)

        with mock.patch(
                COMMONTEMPLATE_SELECT_RELATE,
                MagicMock(return_value=MockQuerySet(get_result=tmpl)),
        ):
            response = self.client.post(
                path=self.url().format(template_id=TEST_TEMPLATE_ID,
                                       project_id=TEST_PROJECT_ID),
                data=json.dumps({
                    "name": "name",
                    "pipeline_tree": TEST_PIPELINE_TREE,
                    "exclude_task_node_id": ["ne584c1e69f53d109f0d99eacc3bd670"],
                    "template_source": "common",
                    "execute_task_nodes_id": ["node8c2af510c04b898e1f9ac8296296"],
                }),
                content_type="application/json",
                HTTP_BK_APP_CODE=TEST_APP_CODE,
                HTTP_BK_USERNAME=TEST_USERNAME,
            )

            data = json.loads(response.content)

            self.assertFalse(data["result"])
            self.assertTrue("message" in data)
            self.assertEqual(data["code"], err_code.UNKNOWN_ERROR.code)

            TaskFlowInstance.objects.create_pipeline_instance.assert_called_once_with(
                template=tmpl,
                name="name",
                creator="",
                description="",
                pipeline_tree=TEST_PIPELINE_TREE,
            )
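
    # The POST-and-assert sequence above repeats for every payload variant.
    # A minimal helper sketch (hypothetical; it only assumes the client,
    # url() and TEST_* names already used in these tests) shows how the
    # duplication could be factored out:
    def assert_create_task_unknown_error(self, extra_payload=None):
        # Build the common payload and merge in the per-case extras.
        payload = {
            "name": "name",
            "pipeline_tree": TEST_PIPELINE_TREE,
            "exclude_task_node_id": ["ne584c1e69f53d109f0d99eacc3bd670"],
        }
        payload.update(extra_payload or {})
        response = self.client.post(
            path=self.url().format(template_id=TEST_TEMPLATE_ID,
                                   project_id=TEST_PROJECT_ID),
            data=json.dumps(payload),
            content_type="application/json",
            HTTP_BK_APP_CODE=TEST_APP_CODE,
            HTTP_BK_USERNAME=TEST_USERNAME,
        )
        data = json.loads(response.content)
        # Every error case above expects the same UNKNOWN_ERROR envelope.
        self.assertFalse(data["result"])
        self.assertIn("message", data)
        self.assertEqual(data["code"], err_code.UNKNOWN_ERROR.code)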