Example #1
0
 def test_process_manifest_db_record_race_no_provider(self, mock_get_manifest):
     """Test that an IntegrityError from add() propagates for an unknown provider.

     mock_get_manifest returns None on both lookups and
     ReportManifestDBAccessor.add is patched to raise IntegrityError, so
     _process_manifest_db_record is expected to let that IntegrityError escape.
     """
     # NOTE(review): mock_get_manifest is presumably injected by a
     # @patch.object(ReportManifestDBAccessor, "get_manifest") decorator that
     # is not part of this excerpt.
     mock_get_manifest.side_effect = [None, None]
     with patch.object(ReportManifestDBAccessor, "add", side_effect=IntegrityError):
         downloader = ReportDownloaderBase(provider_uuid=self.unkown_test_provider_uuid, cache_key=self.cache_key)
         with self.assertRaises(IntegrityError):
             downloader._process_manifest_db_record(self.assembly_id, self.billing_start, 2, DateAccessor().today())
Example #2
0
 def setUp(self):
     """Build a downloader, persist one manifest and two unstarted report-status rows."""
     super().setUp()
     self.cache_key = self.fake.word()
     self.downloader = ReportDownloaderBase(provider_uuid=self.aws_provider_uuid, cache_key=self.cache_key)
     today_utc = self.date_accessor.today_with_timezone("UTC")
     self.billing_start = today_utc.replace(day=1)
     self.manifest_dict = dict(
         assembly_id=self.assembly_id,
         billing_period_start_datetime=self.billing_start,
         num_total_files=2,
         provider_uuid=self.aws_provider_uuid,
     )
     with ReportManifestDBAccessor() as accessor:
         self.manifest = accessor.add(**self.manifest_dict)
         self.manifest.save()
         self.manifest_id = self.manifest.id
     # One status row per file counted in num_total_files; neither has started.
     for file_number in (1, 2):
         baker.make(
             CostUsageReportStatus,
             report_name=f"{self.assembly_id}_file_{file_number}.csv.gz",
             last_completed_datetime=None,
             last_started_datetime=None,
             manifest_id=self.manifest_id,
         )
 def test_report_downloader_base(self):
     """Verify that an explicitly supplied download_path is kept verbatim."""
     parts = [self.fake.word().lower() for _ in range(3)]
     dl_path = "/" + "/".join(parts)
     downloader = ReportDownloaderBase(download_path=dl_path)
     self.assertEqual(downloader.download_path, dl_path)
 def setUp(self):
     """Create a downloader for the AWS provider and persist one manifest for it."""
     super().setUp()
     self.downloader = ReportDownloaderBase(provider_id=self.aws_provider_id)
     first_of_month = self.date_accessor.today_with_timezone('UTC').replace(day=1)
     self.manifest_dict = dict(
         assembly_id=self.assembly_id,
         billing_period_start_datetime=first_of_month,
         num_total_files=2,
         provider_id=self.aws_provider_id,
     )
     with ReportManifestDBAccessor() as accessor:
         new_manifest = accessor.add(**self.manifest_dict)
         new_manifest.save()
         self.manifest_id = new_manifest.id
 def setUp(self):
     """Create a task-bound downloader and persist one manifest tagged with a task id."""
     super().setUp()
     task_request = Mock(id=str(self.fake.uuid4()), return_value={})
     self.mock_task = Mock(request=task_request)
     self.downloader = ReportDownloaderBase(task=self.mock_task, provider_uuid=self.aws_provider_uuid)
     first_of_month = self.date_accessor.today_with_timezone('UTC').replace(day=1)
     self.task_id = str(self.fake.uuid4())
     self.manifest_dict = dict(
         assembly_id=self.assembly_id,
         billing_period_start_datetime=first_of_month,
         num_total_files=2,
         provider_uuid=self.aws_provider_uuid,
         task=self.task_id,
     )
     with ReportManifestDBAccessor() as accessor:
         created = accessor.add(**self.manifest_dict)
         created.save()
         self.manifest_id = created.id
Example #6
0
    def test_process_manifest_db_record_race_no_provider(
            self, mock_get_manifest):
        """Test that the _process_manifest_db_record returns the correct manifest during a race for initial entry."""
        mock_get_manifest.side_effect = [None, None]
        side_effect_error = IntegrityError(
            """insert or update on table "reporting_awscostentrybill" violates foreign key constraint "reporting_awscostent_provider_id_a08725b3_fk_api_provi"
DETAIL:  Key (provider_id)=(fbe0593a-1b83-4182-b23e-08cd190ed939) is not present in table "api_provider".
"""

            # noqa
        )  # noqa
        with patch.object(ReportManifestDBAccessor,
                          "add",
                          side_effect=side_effect_error):
            downloader = ReportDownloaderBase(
                provider_uuid=self.unkown_test_provider_uuid,
                cache_key=self.cache_key)
            with self.assertRaises(ReportDownloaderError):
                downloader._process_manifest_db_record(self.assembly_id,
                                                       self.billing_start, 2,
                                                       DateAccessor().today())
Example #7
0
 def test_report_downloader_base_no_path(self):
     """Verify that a downloader built without a path gets an existing default download_path."""
     downloader = ReportDownloaderBase()
     self.assertIsInstance(downloader, ReportDownloaderBase)
     default_path = downloader.download_path
     self.assertIsNotNone(default_path)
     self.assertTrue(os.path.exists(default_path))
Example #8
0
class ReportDownloaderBaseTest(MasuTestCase):
    """Test Cases for ReportDownloaderBase."""

    fake = Faker()
    patch_path = True

    @classmethod
    def setUpClass(cls):
        """Set up the test class."""
        super().setUpClass()
        cls.fake = Faker()
        cls.patch_path = True
        cls.date_accessor = DateAccessor()
        cls.assembly_id = cls.fake.pystr()
        # Same name as the first report-status row created in setUp.
        cls.report_name = f"{cls.assembly_id}_file_1.csv.gz"

    def setUp(self):
        """Set up each test case."""
        super().setUp()
        self.cache_key = self.fake.word()
        self.downloader = ReportDownloaderBase(provider_uuid=self.aws_provider_uuid, cache_key=self.cache_key)
        self.billing_start = self.date_accessor.today_with_timezone("UTC").replace(day=1)
        self.manifest_dict = {
            "assembly_id": self.assembly_id,
            "billing_period_start_datetime": self.billing_start,
            "num_total_files": 2,
            "provider_uuid": self.aws_provider_uuid,
        }
        with ReportManifestDBAccessor() as manifest_accessor:
            self.manifest = manifest_accessor.add(**self.manifest_dict)
            self.manifest.save()
            self.manifest_id = self.manifest.id
        # One status row per file counted in num_total_files; neither has started.
        for i in [1, 2]:
            baker.make(
                CostUsageReportStatus,
                report_name=f"{self.assembly_id}_file_{i}.csv.gz",
                last_completed_datetime=None,
                last_started_datetime=None,
                manifest_id=self.manifest_id,
            )

    def tearDown(self):
        """Tear down each test case."""
        super().tearDown()
        # Delete every row returned by each accessor's base query.
        with ReportStatsDBAccessor(self.report_name, self.manifest_id) as file_accessor:
            files = file_accessor._get_db_obj_query().all()
            for file in files:
                file_accessor.delete(file)

        with ReportManifestDBAccessor() as manifest_accessor:
            manifests = manifest_accessor._get_db_obj_query().all()
            for manifest in manifests:
                manifest_accessor.delete(manifest)

    def test_report_downloader_base_no_path(self):
        """Test report downloader download_path."""
        downloader = ReportDownloaderBase()
        self.assertIsInstance(downloader, ReportDownloaderBase)
        self.assertIsNotNone(downloader.download_path)
        self.assertTrue(os.path.exists(downloader.download_path))

    def test_report_downloader_base(self):
        """Test download path matches expected."""
        dl_path = "/{}/{}/{}".format(self.fake.word().lower(), self.fake.word().lower(), self.fake.word().lower())
        downloader = ReportDownloaderBase(download_path=dl_path)
        self.assertEqual(downloader.download_path, dl_path)

    def test_get_existing_manifest_db_id(self):
        """Test that a manifest ID is returned."""
        manifest_id = self.downloader._get_existing_manifest_db_id(self.assembly_id)
        self.assertEqual(manifest_id, self.manifest_id)

    @patch.object(ReportManifestDBAccessor, "get_manifest")
    def test_process_manifest_db_record_race(self, mock_get_manifest):
        """Test that the _process_manifest_db_record returns the correct manifest during a race for initial entry."""
        # First lookup misses, add() loses the race, second lookup finds the row.
        mock_get_manifest.side_effect = [None, self.manifest]
        with patch.object(ReportManifestDBAccessor, "add", side_effect=IntegrityError):
            manifest_id = self.downloader._process_manifest_db_record(
                self.assembly_id, self.billing_start, 2, DateAccessor().today()
            )
        self.assertEqual(manifest_id, self.manifest.id)

    @patch.object(ReportManifestDBAccessor, "get_manifest")
    def test_process_manifest_db_record_race_no_provider(self, mock_get_manifest):
        """Test that an IntegrityError from add() propagates for an unknown provider.

        get_manifest returns None on both lookups and add() raises
        IntegrityError, so _process_manifest_db_record is expected to let
        that IntegrityError escape.
        """
        mock_get_manifest.side_effect = [None, None]
        with patch.object(ReportManifestDBAccessor, "add", side_effect=IntegrityError):
            downloader = ReportDownloaderBase(provider_uuid=self.unkown_test_provider_uuid, cache_key=self.cache_key)
            with self.assertRaises(IntegrityError):
                downloader._process_manifest_db_record(self.assembly_id, self.billing_start, 2, DateAccessor().today())
Example #9
0
class ReportDownloaderBaseTest(MasuTestCase):
    """Test Cases for ReportDownloaderBase."""

    fake = Faker()
    patch_path = True

    @classmethod
    def setUpClass(cls):
        """Set up the test class."""
        super().setUpClass()
        cls.fake = Faker()
        cls.patch_path = True
        cls.date_accessor = DateAccessor()
        cls.assembly_id = cls.fake.pystr()
        # Same name as the first report-status row created in setUp.
        cls.report_name = f"{cls.assembly_id}_file_1.csv.gz"

    def setUp(self):
        """Set up each test case."""
        super().setUp()
        self.cache_key = self.fake.word()
        self.downloader = ReportDownloaderBase(
            provider_uuid=self.aws_provider_uuid, cache_key=self.cache_key)
        self.billing_start = self.date_accessor.today_with_timezone(
            "UTC").replace(day=1)
        self.manifest_dict = {
            "assembly_id": self.assembly_id,
            "billing_period_start_datetime": self.billing_start,
            "num_total_files": 2,
            "provider_uuid": self.aws_provider_uuid,
        }
        with ReportManifestDBAccessor() as manifest_accessor:
            self.manifest = manifest_accessor.add(**self.manifest_dict)
            self.manifest.save()
            self.manifest_id = self.manifest.id
        # One status row per file counted in num_total_files; neither has started.
        for i in [1, 2]:
            baker.make(
                CostUsageReportStatus,
                report_name=f"{self.assembly_id}_file_{i}.csv.gz",
                last_completed_datetime=None,
                last_started_datetime=None,
                manifest_id=self.manifest_id,
            )

    def tearDown(self):
        """Tear down each test case."""
        super().tearDown()
        # Delete every row returned by each accessor's base query.
        with ReportStatsDBAccessor(self.report_name,
                                   self.manifest_id) as file_accessor:
            files = file_accessor._get_db_obj_query().all()
            for file in files:
                file_accessor.delete(file)

        with ReportManifestDBAccessor() as manifest_accessor:
            manifests = manifest_accessor._get_db_obj_query().all()
            for manifest in manifests:
                manifest_accessor.delete(manifest)

    def test_report_downloader_base_no_path(self):
        """Test report downloader download_path."""
        downloader = ReportDownloaderBase()
        self.assertIsInstance(downloader, ReportDownloaderBase)
        self.assertIsNotNone(downloader.download_path)
        self.assertTrue(os.path.exists(downloader.download_path))

    def test_report_downloader_base(self):
        """Test download path matches expected."""
        dl_path = f"/{self.fake.word().lower()}/{self.fake.word().lower()}/{self.fake.word().lower()}"
        downloader = ReportDownloaderBase(download_path=dl_path)
        self.assertEqual(downloader.download_path, dl_path)

    def test_get_existing_manifest_db_id(self):
        """Test that a manifest ID is returned."""
        manifest_id = self.downloader._get_existing_manifest_db_id(
            self.assembly_id)
        self.assertEqual(manifest_id, self.manifest_id)

    @patch.object(ReportManifestDBAccessor, "get_manifest")
    def test_process_manifest_db_record_race(self, mock_get_manifest):
        """Test that the _process_manifest_db_record returns the correct manifest during a race for initial entry."""
        # First lookup misses, add() loses the race, second lookup finds the row.
        mock_get_manifest.side_effect = [None, self.manifest]
        with patch.object(ReportManifestDBAccessor,
                          "add",
                          side_effect=IntegrityError):
            manifest_id = self.downloader._process_manifest_db_record(
                self.assembly_id, self.billing_start, 2,
                DateAccessor().today())
        self.assertEqual(manifest_id, self.manifest.id)

    @patch.object(ReportManifestDBAccessor, "get_manifest")
    def test_process_manifest_db_record_race_no_provider(
            self, mock_get_manifest):
        """Test that a missing-provider IntegrityError surfaces as ReportDownloaderError.

        get_manifest returns None on both lookups and add() raises an
        IntegrityError describing a foreign-key violation on provider_id, so
        _process_manifest_db_record is expected to raise ReportDownloaderError.
        """
        mock_get_manifest.side_effect = [None, None]
        # The message text mimics a foreign-key violation on provider_id.
        side_effect_error = IntegrityError(
            """insert or update on table "reporting_awscostentrybill" violates foreign key constraint "reporting_awscostent_provider_id_a08725b3_fk_api_provi"
DETAIL:  Key (provider_id)=(fbe0593a-1b83-4182-b23e-08cd190ed939) is not present in table "api_provider".
"""

            # noqa
        )  # noqa
        with patch.object(ReportManifestDBAccessor,
                          "add",
                          side_effect=side_effect_error):
            downloader = ReportDownloaderBase(
                provider_uuid=self.unkown_test_provider_uuid,
                cache_key=self.cache_key)
            with self.assertRaises(ReportDownloaderError):
                downloader._process_manifest_db_record(self.assembly_id,
                                                       self.billing_start, 2,
                                                       DateAccessor().today())

    def test_process_manifest_db_record_file_num_changed(self):
        """Test that num_total_files is synced to the report-status count when the file count changes.

        A third status row is added to the existing manifest, then the record
        is processed with 3 files; the same manifest id must be returned and its
        num_total_files must equal the number of status rows for the manifest.
        """
        CostUsageReportStatus.objects.create(
            report_name="fake_report.csv",
            last_completed_datetime=self.billing_start,
            last_started_datetime=self.billing_start,
            etag="etag",
            manifest=self.manifest,
        )
        manifest_id = self.downloader._process_manifest_db_record(
            self.assembly_id, self.billing_start, 3,
            DateAccessor().today())
        self.assertEqual(manifest_id, self.manifest.id)
        with ReportManifestDBAccessor() as manifest_accessor:
            result_manifest = manifest_accessor.get_manifest_by_id(manifest_id)
        expected_count = CostUsageReportStatus.objects.filter(
            manifest_id=self.manifest_id).count()
        self.assertEqual(result_manifest.num_total_files, expected_count)
 def test_report_downloader_base(self):
     """Verify that the download_path passed to the constructor is exposed unchanged."""
     first, second, third = (self.fake.word().lower() for _ in range(3))
     dl_path = f'/{first}/{second}/{third}'
     downloader = ReportDownloaderBase(download_path=dl_path)
     self.assertEqual(downloader.download_path, dl_path)
class ReportDownloaderBaseTest(MasuTestCase):
    """Test Cases for ReportDownloaderBase."""

    fake = Faker()
    patch_path = True

    @classmethod
    def setUpClass(cls):
        """Set up the test class."""
        super().setUpClass()
        cls.fake = Faker()
        cls.patch_path = True
        cls.date_accessor = DateAccessor()
        cls.assembly_id = cls.fake.pystr()
        cls.report_name = f'{cls.assembly_id}_file_1.csv.gz'

    def setUp(self):
        """Set up each test case."""
        super().setUp()
        self.downloader = ReportDownloaderBase(
            provider_id=self.aws_provider_id)
        billing_start = self.date_accessor.today_with_timezone('UTC').replace(
            day=1)
        self.manifest_dict = {
            'assembly_id': self.assembly_id,
            'billing_period_start_datetime': billing_start,
            'num_total_files': 2,
            'provider_id': self.aws_provider_id
        }
        with ReportManifestDBAccessor() as manifest_accessor:
            manifest = manifest_accessor.add(**self.manifest_dict)
            manifest.save()
            self.manifest_id = manifest.id

    def tearDown(self):
        """Tear down each test case."""
        super().tearDown()
        # Delete every row returned by each accessor's base query.
        with ReportStatsDBAccessor(self.report_name,
                                   self.manifest_id) as file_accessor:
            files = file_accessor._get_db_obj_query().all()
            for file in files:
                file_accessor.delete(file)

        with ReportManifestDBAccessor() as manifest_accessor:
            manifests = manifest_accessor._get_db_obj_query().all()
            for manifest in manifests:
                manifest_accessor.delete(manifest)

    def test_report_downloader_base_no_path(self):
        """Test report downloader download_path."""
        downloader = ReportDownloaderBase()
        self.assertIsInstance(downloader, ReportDownloaderBase)
        self.assertIsNotNone(downloader.download_path)
        self.assertTrue(os.path.exists(downloader.download_path))

    def test_report_downloader_base(self):
        """Test download path matches expected."""
        dl_path = '/{}/{}/{}'.format(self.fake.word().lower(),
                                     self.fake.word().lower(),
                                     self.fake.word().lower())
        downloader = ReportDownloaderBase(download_path=dl_path)
        self.assertEqual(downloader.download_path, dl_path)

    def test_get_existing_manifest_db_id(self):
        """Test that a manifest ID is returned."""
        manifest_id = self.downloader._get_existing_manifest_db_id(
            self.assembly_id)
        self.assertEqual(manifest_id, self.manifest_id)

    def test_check_if_manifest_should_be_downloaded_new_manifest(self):
        """Test that a new manifest should be processed."""
        result = self.downloader.check_if_manifest_should_be_downloaded('1234')
        self.assertTrue(result)

    def test_check_if_manifest_should_be_downloaded_currently_processing_manifest(
            self):
        """Test that a manifest being processed should not be reprocessed."""
        # Half the files processed, with a freshly logged start/completion.
        with ReportManifestDBAccessor() as manifest_accessor:
            manifest = manifest_accessor.get_manifest_by_id(self.manifest_id)
            manifest.num_processed_files = 1
            manifest.num_total_files = 2
            manifest.save()

        with ReportStatsDBAccessor(self.report_name,
                                   self.manifest_id) as file_accessor:
            file_accessor.log_last_started_datetime()
            file_accessor.log_last_completed_datetime()

        result = self.downloader.check_if_manifest_should_be_downloaded(
            self.assembly_id)
        self.assertFalse(result)

    def test_check_if_manifest_should_be_downloaded_error_processing_manifest(
            self):
        """Test that a manifest that did not successfully process should be reprocessed."""
        with ReportManifestDBAccessor() as manifest_accessor:
            manifest = manifest_accessor.get_manifest_by_id(self.manifest_id)
            manifest.num_processed_files = 1
            manifest.num_total_files = 2
            manifest.save()

        with ReportStatsDBAccessor(self.report_name,
                                   self.manifest_id) as file_accessor:
            file_accessor.log_last_started_datetime()
            file_accessor.log_last_completed_datetime()
            # Backdate the completion by an hour so the file looks stale.
            completed_datetime = self.date_accessor.today_with_timezone(
                'UTC') - datetime.timedelta(hours=1)
            file_accessor.update(last_completed_datetime=completed_datetime)
        result = self.downloader.check_if_manifest_should_be_downloaded(
            self.assembly_id)
        self.assertTrue(result)

    def test_check_if_manifest_should_be_downloaded_done_processing_manifest(
            self):
        """Test that a manifest that has finished processing is not reprocessed."""
        with ReportManifestDBAccessor() as manifest_accessor:
            manifest = manifest_accessor.get_manifest_by_id(self.manifest_id)
            manifest.num_processed_files = 2
            manifest.num_total_files = 2
            manifest.save()

        result = self.downloader.check_if_manifest_should_be_downloaded(
            self.assembly_id)
        self.assertFalse(result)
class ReportDownloaderBaseTest(MasuTestCase):
    """Test Cases for ReportDownloaderBase."""

    fake = Faker()
    patch_path = True

    @classmethod
    def setUpClass(cls):
        """Set up the test class."""
        super().setUpClass()
        cls.fake = Faker()
        cls.patch_path = True
        cls.date_accessor = DateAccessor()
        cls.assembly_id = cls.fake.pystr()
        # Same name as the first report-status row created in setUp.
        cls.report_name = f"{cls.assembly_id}_file_1.csv.gz"

    def setUp(self):
        """Set up each test case."""
        super().setUp()
        self.cache_key = self.fake.word()
        self.mock_task = Mock(
            request=Mock(id=str(self.fake.uuid4()), return_value={}))
        self.downloader = ReportDownloaderBase(
            task=self.mock_task,
            provider_uuid=self.aws_provider_uuid,
            cache_key=self.cache_key)
        billing_start = self.date_accessor.today_with_timezone("UTC").replace(
            day=1)
        self.task_id = str(self.fake.uuid4())
        self.manifest_dict = {
            "assembly_id": self.assembly_id,
            "billing_period_start_datetime": billing_start,
            "num_total_files": 2,
            "provider_uuid": self.aws_provider_uuid,
            "task": self.task_id,
        }
        with ReportManifestDBAccessor() as manifest_accessor:
            manifest = manifest_accessor.add(**self.manifest_dict)
            manifest.save()
            self.manifest_id = manifest.id
        # One status row per file counted in num_total_files; neither has started.
        for i in [1, 2]:
            baker.make(
                CostUsageReportStatus,
                report_name=f"{self.assembly_id}_file_{i}.csv.gz",
                last_completed_datetime=None,
                last_started_datetime=None,
                manifest_id=self.manifest_id,
            )

    def tearDown(self):
        """Tear down each test case."""
        super().tearDown()
        # Delete every row returned by each accessor's base query.
        with ReportStatsDBAccessor(self.report_name,
                                   self.manifest_id) as file_accessor:
            files = file_accessor._get_db_obj_query().all()
            for file in files:
                file_accessor.delete(file)

        with ReportManifestDBAccessor() as manifest_accessor:
            manifests = manifest_accessor._get_db_obj_query().all()
            for manifest in manifests:
                manifest_accessor.delete(manifest)

    def test_report_downloader_base_no_path(self):
        """Test report downloader download_path."""
        downloader = ReportDownloaderBase(self.mock_task)
        self.assertIsInstance(downloader, ReportDownloaderBase)
        self.assertIsNotNone(downloader.download_path)
        self.assertTrue(os.path.exists(downloader.download_path))

    def test_report_downloader_base(self):
        """Test download path matches expected."""
        dl_path = "/{}/{}/{}".format(self.fake.word().lower(),
                                     self.fake.word().lower(),
                                     self.fake.word().lower())
        downloader = ReportDownloaderBase(self.mock_task,
                                          download_path=dl_path)
        self.assertEqual(downloader.download_path, dl_path)

    def test_get_existing_manifest_db_id(self):
        """Test that a manifest ID is returned."""
        manifest_id = self.downloader._get_existing_manifest_db_id(
            self.assembly_id)
        self.assertEqual(manifest_id, self.manifest_id)

    def test_check_if_manifest_should_be_downloaded_new_manifest(self):
        """Test that a new manifest should be processed."""
        result = self.downloader.check_if_manifest_should_be_downloaded("1234")
        self.assertTrue(result)

    def test_check_if_manifest_should_be_downloaded_error_processing_manifest(
            self):
        """Test that a manifest that did not successfully process should be reprocessed."""
        reports = CostUsageReportStatus.objects.filter(
            manifest_id=self.manifest_id)
        # First file finished; second file started but has no completion time.
        with ReportStatsDBAccessor(reports[0].report_name,
                                   reports[0].manifest_id) as file_accessor:
            file_accessor.log_last_started_datetime()
            file_accessor.log_last_completed_datetime()
        with ReportStatsDBAccessor(reports[1].report_name,
                                   reports[1].manifest_id) as file_accessor:
            file_accessor.log_last_started_datetime()
            file_accessor.update(last_completed_datetime=None)
        result = self.downloader.check_if_manifest_should_be_downloaded(
            self.assembly_id)
        self.assertTrue(result)

    def test_check_if_manifest_should_be_downloaded_done_processing_manifest(
            self):
        """Test that a manifest that has finished processing is not reprocessed."""
        reports = CostUsageReportStatus.objects.filter(
            manifest_id=self.manifest_id)
        for report in reports:
            with ReportStatsDBAccessor(report.report_name,
                                       report.manifest_id) as file_accessor:
                file_accessor.log_last_started_datetime()
                file_accessor.log_last_completed_datetime()

        result = self.downloader.check_if_manifest_should_be_downloaded(
            self.assembly_id)
        self.assertFalse(result)

    def test_check_if_manifest_should_be_downloaded_error_no_complete_date(
            self):
        """Test that a manifest whose report file was started but never completed is reprocessed."""
        with ReportManifestDBAccessor() as manifest_accessor:
            manifest = manifest_accessor.get_manifest_by_id(self.manifest_id)
            manifest.num_processed_files = 1
            manifest.num_total_files = 2
            manifest.save()

        with ReportStatsDBAccessor(self.report_name,
                                   self.manifest_id) as file_accessor:
            file_accessor.log_last_started_datetime()
        result = self.downloader.check_if_manifest_should_be_downloaded(
            self.assembly_id)
        self.assertTrue(result)

    def test_check_if_manifest_should_be_downloaded_task_currently_running(
            self):
        """Test that a manifest is not downloaded while its cache key is in the worker cache."""
        _cache = WorkerCache()
        _cache.add_task_to_cache(self.cache_key)

        result = self.downloader.check_if_manifest_should_be_downloaded(
            self.assembly_id)
        self.assertFalse(result)
Example #13
0
class ReportDownloaderBaseTest(MasuTestCase):
    """Test Cases for ReportDownloaderBase."""

    fake = Faker()
    patch_path = True

    @classmethod
    def setUpClass(cls):
        """Set up the test class."""
        super().setUpClass()
        cls.fake = Faker()
        cls.patch_path = True
        cls.date_accessor = DateAccessor()
        cls.assembly_id = cls.fake.pystr()
        cls.report_name = f'{cls.assembly_id}_file_1.csv.gz'

    def setUp(self):
        """Set up each test case."""
        super().setUp()
        self.mock_task = Mock(
            request=Mock(id=str(self.fake.uuid4()), return_value={}))
        self.downloader = ReportDownloaderBase(
            task=self.mock_task, provider_uuid=self.aws_provider_uuid)
        billing_start = self.date_accessor.today_with_timezone('UTC').replace(
            day=1)
        self.task_id = str(self.fake.uuid4())
        self.manifest_dict = {
            'assembly_id': self.assembly_id,
            'billing_period_start_datetime': billing_start,
            'num_total_files': 2,
            'provider_uuid': self.aws_provider_uuid,
            'task': self.task_id,
        }
        with ReportManifestDBAccessor() as manifest_accessor:
            manifest = manifest_accessor.add(**self.manifest_dict)
            manifest.save()
            self.manifest_id = manifest.id

    def tearDown(self):
        """Tear down each test case."""
        super().tearDown()
        # Delete every row returned by each accessor's base query.
        with ReportStatsDBAccessor(self.report_name,
                                   self.manifest_id) as file_accessor:
            files = file_accessor._get_db_obj_query().all()
            for file in files:
                file_accessor.delete(file)

        with ReportManifestDBAccessor() as manifest_accessor:
            manifests = manifest_accessor._get_db_obj_query().all()
            for manifest in manifests:
                manifest_accessor.delete(manifest)

    @patch('masu.external.downloader.report_downloader_base.app')
    def test_report_downloader_base_no_path(self, _):
        """Test report downloader download_path."""
        downloader = ReportDownloaderBase(self.mock_task)
        self.assertIsInstance(downloader, ReportDownloaderBase)
        self.assertIsNotNone(downloader.download_path)
        self.assertTrue(os.path.exists(downloader.download_path))

    @patch('masu.external.downloader.report_downloader_base.app')
    def test_report_downloader_base(self, _):
        """Test download path matches expected."""
        dl_path = '/{}/{}/{}'.format(self.fake.word().lower(),
                                     self.fake.word().lower(),
                                     self.fake.word().lower())
        downloader = ReportDownloaderBase(self.mock_task,
                                          download_path=dl_path)
        self.assertEqual(downloader.download_path, dl_path)

    @patch('masu.external.downloader.report_downloader_base.app')
    def test_get_existing_manifest_db_id(self, _):
        """Test that a manifest ID is returned."""
        manifest_id = self.downloader._get_existing_manifest_db_id(
            self.assembly_id)
        self.assertEqual(manifest_id, self.manifest_id)

    @patch('masu.external.downloader.report_downloader_base.app')
    def test_check_if_manifest_should_be_downloaded_new_manifest(self, _):
        """Test that a new manifest should be processed."""
        result = self.downloader.check_if_manifest_should_be_downloaded('1234')
        self.assertTrue(result)

    @patch('masu.external.downloader.report_downloader_base.app')
    def test_check_if_manifest_should_be_downloaded_currently_processing_manifest(
            self, _):
        """Test that a manifest being processed should not be reprocessed."""
        # Half the files processed, with a freshly logged start/completion.
        with ReportManifestDBAccessor() as manifest_accessor:
            manifest = manifest_accessor.get_manifest_by_id(self.manifest_id)
            manifest.num_processed_files = 1
            manifest.num_total_files = 2
            manifest.save()

        with ReportStatsDBAccessor(self.report_name,
                                   self.manifest_id) as file_accessor:
            file_accessor.log_last_started_datetime()
            file_accessor.log_last_completed_datetime()

        result = self.downloader.check_if_manifest_should_be_downloaded(
            self.assembly_id)
        self.assertFalse(result)

    @patch('masu.external.downloader.report_downloader_base.app')
    def test_check_if_manifest_should_be_downloaded_error_processing_manifest(
            self, _):
        """Test that a manifest that did not successfully process should be reprocessed."""
        with ReportManifestDBAccessor() as manifest_accessor:
            manifest = manifest_accessor.get_manifest_by_id(self.manifest_id)
            manifest.num_processed_files = 1
            manifest.num_total_files = 2
            manifest.save()

        with ReportStatsDBAccessor(self.report_name,
                                   self.manifest_id) as file_accessor:
            file_accessor.log_last_started_datetime()
            file_accessor.log_last_completed_datetime()
            # Backdate the completion by an hour so the file looks stale.
            completed_datetime = self.date_accessor.today_with_timezone(
                'UTC') - datetime.timedelta(hours=1)
            file_accessor.update(last_completed_datetime=completed_datetime)
        result = self.downloader.check_if_manifest_should_be_downloaded(
            self.assembly_id)
        self.assertTrue(result)

    @patch('masu.external.downloader.report_downloader_base.app')
    def test_check_if_manifest_should_be_downloaded_done_processing_manifest(
            self, _):
        """Test that a manifest that has finished processing is not reprocessed."""
        with ReportManifestDBAccessor() as manifest_accessor:
            manifest = manifest_accessor.get_manifest_by_id(self.manifest_id)
            manifest.num_processed_files = 2
            manifest.num_total_files = 2
            manifest.save()

        result = self.downloader.check_if_manifest_should_be_downloaded(
            self.assembly_id)
        self.assertFalse(result)

    @patch('masu.external.downloader.report_downloader_base.app')
    def test_check_task_queues_false(self, mock_celery):
        """Test that check_task_queues() returns false when task_id is absent."""
        # Stub app.control.inspect() so active/reserved/scheduled list no tasks.
        mock_celery.control = Mock(inspect=Mock(return_value=Mock(
            active=Mock(return_value={}),
            reserved=Mock(return_value={}),
            scheduled=Mock(return_value={}),
        )))
        # Query the same task id the positive test finds, so the empty queues
        # are the only reason for a False result.
        result = self.downloader.check_task_queues(self.task_id)
        self.assertFalse(result)

    @patch('masu.external.downloader.report_downloader_base.app')
    def test_check_task_queues_true(self, mock_celery):
        """Test that check_task_queues() returns true when task_id is found."""
        # Stub app.control.inspect() so one worker's active list holds task_id.
        active = Mock(return_value={self.fake.word(): [{'id': self.task_id}]})
        mock_celery.control = Mock(inspect=Mock(
            return_value=Mock(active=active,
                              reserved=Mock(return_value={}),
                              scheduled=Mock(return_value={}))))
        result = self.downloader.check_task_queues(self.task_id)
        self.assertTrue(result)

    @patch('masu.external.downloader.report_downloader_base.app')
    def test_check_if_manifest_should_be_downloaded_error_no_complete_date(
            self, _):
        """Test that a manifest whose report file was started but never completed is reprocessed."""
        with ReportManifestDBAccessor() as manifest_accessor:
            manifest = manifest_accessor.get_manifest_by_id(self.manifest_id)
            manifest.num_processed_files = 1
            manifest.num_total_files = 2
            manifest.save()

        with ReportStatsDBAccessor(self.report_name,
                                   self.manifest_id) as file_accessor:
            file_accessor.log_last_started_datetime()
        result = self.downloader.check_if_manifest_should_be_downloaded(
            self.assembly_id)
        self.assertTrue(result)
 def test_report_downloader_base_no_path(self, _):
     """Verify that a downloader built from only a task gets an existing default download_path."""
     downloader = ReportDownloaderBase(self.mock_task)
     self.assertIsInstance(downloader, ReportDownloaderBase)
     default_path = downloader.download_path
     self.assertIsNotNone(default_path)
     self.assertTrue(os.path.exists(default_path))