Example 1
0
    def test_import_gcs_csv_to_cloud_sql_with_region_code(self) -> None:
        """Importing scoped to one region code keeps other regions' rows.

        Seeds the destination table with two US_PA users, makes the mocked CSV
        import load two US_MO rows, runs the import with region_code="US_MO",
        and checks that the destination table then holds four rows spanning
        both US_MO and US_PA.
        """
        seeded_users = [
            generate_fake_user_restrictions(
                "US_PA",
                "*****@*****.**",
                allowed_supervision_location_ids=location_ids,
            )
            for location_ids in ("1,2", "AB")
        ]
        add_users_to_database_session(self.database_key, seeded_users)

        imported_rows = (
            "('US_MO', '*****@*****.**', '', '{2}', 'level_1_supervision_location', 'level_1_access_role', true, false)",
            "('US_MO', '*****@*****.**', '', '{3}', 'level_1_supervision_location', 'level_1_access_role', true, false)",
        )

        def _load_mocked_csv(**_kwargs: Any) -> str:
            # Stands in for the client's CSV import: load the fixed US_MO rows
            # via the test helper. A fresh list is passed on every call.
            return self._mock_load_data_from_csv(values=list(imported_rows))

        cloud_sql = self.mock_cloud_sql_client
        cloud_sql.import_gcs_csv.side_effect = _load_mocked_csv
        cloud_sql.wait_until_operation_completed.return_value = True

        import_gcs_csv_to_cloud_sql(
            schema_type=SchemaType.CASE_TRIAGE,
            destination_table=self.table_name,
            gcs_uri=self.gcs_uri,
            columns=self.columns,
            region_code="US_MO",
        )

        with SessionFactory.using_database(
            self.database_key, autocommit=False
        ) as session:
            rows = session.query(DashboardUserRestrictions).all()
            self.assertEqual(len(rows), 4)
            self.assertEqual(
                {row.state_code for row in rows}, {"US_MO", "US_PA"}
            )
Example 2
0
    def test_dashboard_user_restrictions_by_email_user_has_restrictions(
            self) -> None:
        """Fetching restrictions by email returns that user's stored restrictions.

        Two users are stored under ``self.region_code``. Querying for the first
        one must return HTTP 200 and a JSON body with its location ids split
        into a list plus its access flags.
        """
        user_1 = generate_fake_user_restrictions(
            self.region_code,
            "*****@*****.**",
            allowed_supervision_location_ids="1,2",
        )
        user_2 = generate_fake_user_restrictions(
            self.region_code,
            "*****@*****.**",
            allowed_supervision_location_ids="AB",
        )
        add_users_to_database_session(self.database_key, [user_1, user_2])

        with self.app.test_request_context():
            expected_restrictions = {
                "allowed_supervision_location_ids": ["1", "2"],
                "allowed_supervision_location_level":
                "level_1_supervision_location",
                "can_access_case_triage": False,
                "can_access_leadership_dashboard": True,
            }
            response = self.client.get(
                self.dashboard_user_restrictions_by_email_url,
                headers=self.headers,
                query_string={
                    # Query the same region the fixtures were created under,
                    # instead of a hard-coded literal that can drift from it.
                    "region_code": self.region_code,
                    "email_address": "*****@*****.**",
                },
            )

            self.assertEqual(HTTPStatus.OK, response.status_code)
            self.assertEqual(
                expected_restrictions,
                json.loads(response.data),
            )
Example 3
0
 def test_import_gcs_csv_to_cloud_sql_session_error(self) -> None:
     """Assert that session errors raise an error and roll back the session.

     Importing into a destination table that does not exist must raise. The
     single pre-seeded row must still be the only DashboardUserRestrictions
     row afterwards.
     """
     with SessionFactory.using_database(self.database_key,
                                        autocommit=False) as session:
         user_1 = generate_fake_user_restrictions(
             "US_PA",
             "*****@*****.**",
             allowed_supervision_location_ids="1,2",
         )
         add_users_to_database_session(self.database_key, [user_1])
         with self.assertRaises(Exception) as e:
             import_gcs_csv_to_cloud_sql(
                 schema_type=SchemaType.CASE_TRIAGE,
                 destination_table="table_does_not_exist",
                 gcs_uri=self.gcs_uri,
                 columns=self.columns,
             )
         destination_table_rows = session.query(
             DashboardUserRestrictions).all()
         self.assertEqual(len(destination_table_rows), 1)
         # Use the unittest assertion rather than a bare ``assert``. It is not
         # stripped under ``python -O`` and it reports the actual message on
         # failure.
         self.assertIn('relation "table_does_not_exist" does not exist',
                       str(e.exception))
Example 4
0
 def test_import_gcs_csv_to_cloud_sql_client_error(self) -> None:
     """A failing CloudSQLClient import propagates and leaves existing rows alone.

     The mocked ``import_gcs_csv`` raises. The same message must surface from
     ``import_gcs_csv_to_cloud_sql``, and the single pre-seeded row must
     remain the only row in the destination table.
     """
     error_message = "Error while importing CSV to temp table"
     with SessionFactory.using_database(self.database_key,
                                        autocommit=False) as session:
         seeded_user = generate_fake_user_restrictions(
             "US_PA",
             "*****@*****.**",
             allowed_supervision_location_ids="1,2",
         )
         add_users_to_database_session(self.database_key, [seeded_user])

         # Make the client's CSV import step fail.
         self.mock_cloud_sql_client.import_gcs_csv.side_effect = Exception(
             error_message)

         with self.assertRaises(Exception) as raised:
             import_gcs_csv_to_cloud_sql(
                 schema_type=SchemaType.CASE_TRIAGE,
                 destination_table=self.table_name,
                 gcs_uri=self.gcs_uri,
                 columns=self.columns,
             )

         remaining_rows = session.query(DashboardUserRestrictions).all()
         self.assertEqual(len(remaining_rows), 1)
         self.assertEqual(str(raised.exception), error_message)
Example 5
0
    def test_update_auth0_user_metadata_with_users_returned(self) -> None:
        """Each stored user's restrictions are pushed to Auth0 as app_metadata.

        Two users are stored for ``self.region_code``, and the mocked Auth0
        client resolves their emails to user ids "0" and "1". The endpoint must:
        look the users up by email, update each one's app_metadata with its
        restrictions (location ids split into a list), and report success.
        """
        with self.app.test_request_context():
            stored_users = [
                generate_fake_user_restrictions(
                    self.region_code,
                    "*****@*****.**",
                    allowed_supervision_location_ids=location_ids,
                )
                for location_ids in ("23", "11, EP, 4E")
            ]
            add_users_to_database_session(self.database_key, stored_users)

            auth0 = self.mock_auth0_client
            auth0.get_all_users_by_email_addresses.return_value = [
                {"email": "*****@*****.**", "user_id": "0"},
                {"email": "*****@*****.**", "user_id": "1"},
            ]

            response = self.client.get(
                self.update_auth0_user_metadata_url,
                headers=self.headers,
                query_string={"region_code": self.region_code},
            )

            auth0.get_all_users_by_email_addresses.assert_called_with(
                ["*****@*****.**", "*****@*****.**"]
            )

            # Both users share every metadata field except their location ids.
            shared_metadata = {
                "allowed_supervision_location_level":
                "level_1_supervision_location",
                "can_access_leadership_dashboard": True,
                "can_access_case_triage": False,
            }
            auth0.update_user_app_metadata.assert_has_calls([
                call(
                    user_id="0",
                    app_metadata={
                        "allowed_supervision_location_ids": ["23"],
                        **shared_metadata,
                    },
                ),
                call(
                    user_id="1",
                    app_metadata={
                        "allowed_supervision_location_ids": ["11", "EP", "4E"],
                        **shared_metadata,
                    },
                ),
            ])

            self.assertEqual(HTTPStatus.OK, response.status_code)
            self.assertEqual(
                b"Finished updating 2 auth0 users with restrictions for region US_MO",
                response.data,
            )
Example 6
0
    def setUp(self) -> None:
        """Build an AuthorizationStore and seed CASE_TRIAGE fixture rows.

        - Patches the authorization module's local-file loader and refreshes a
          fresh AuthorizationStore.
        - Points the CASE_TRIAGE database key at an on-disk Postgres.
        - Persists user-restriction and officer fixtures that cover different
          access-flag combinations and states.
        """
        # Serve local files from the test helper instead of the real loader.
        # NOTE(review): the patcher is started here and presumably stopped in
        # tearDown (not visible in this chunk). Confirm.
        self.get_local_patcher = mock.patch(
            "recidiviz.case_triage.authorization.get_local_file",
            new=_test_get_local_file,
        )
        self.get_local_patcher.start()

        # The store is created and refreshed before the fixture rows below
        # are written.
        self.auth_store = AuthorizationStore()
        self.auth_store.refresh()

        self.database_key = SQLAlchemyDatabaseKey.for_schema(SchemaType.CASE_TRIAGE)
        local_postgres_helpers.use_on_disk_postgresql_database(self.database_key)

        # US_XX user with Case Triage access only.
        self.case_triage_user = generate_fake_user_restrictions(
            "US_XX",
            "*****@*****.**",
            can_access_leadership_dashboard=False,
            can_access_case_triage=True,
        )
        # US_XX user with leadership dashboard access only.
        self.dashboard_user = generate_fake_user_restrictions(
            "US_XX",
            "*****@*****.**",
            can_access_leadership_dashboard=True,
            can_access_case_triage=False,
        )
        # US_XX user with both kinds of access.
        self.both_user = generate_fake_user_restrictions(
            "US_XX",
            "*****@*****.**",
            can_access_leadership_dashboard=True,
            can_access_case_triage=True,
        )

        # Stored as dashboard-only. NOTE(review): the name suggests these DB
        # flags are overridden elsewhere (presumably by the patched local
        # file). Verify.
        self.overridden_user = generate_fake_user_restrictions(
            "US_XX",
            "*****@*****.**",
            can_access_leadership_dashboard=True,
            can_access_case_triage=False,
        )

        # Same access as both_user, but in a different state (US_YY).
        self.both_user_different_state = generate_fake_user_restrictions(
            "US_YY",
            "*****@*****.**",
            can_access_leadership_dashboard=True,
            can_access_case_triage=True,
        )

        self.officer = generate_fake_officer(
            "test", "*****@*****.**", state_code="US_XX"
        )

        with SessionFactory.using_database(self.database_key) as session:
            # Keep the fixture objects' loaded attributes usable on self after
            # the session commits, instead of expiring them.
            session.expire_on_commit = False
            session.add_all(
                [
                    self.case_triage_user,
                    self.dashboard_user,
                    self.both_user,
                    self.overridden_user,
                    self.both_user_different_state,
                    self.officer,
                ]
            )