Example 1
def load_application_cases():
    account = Account(**AccountFixtureFactory.make_publisher_source())
    account.set_id(account.makeid())

    application = Suggestion(**ApplicationFixtureFactory.make_application_source())
    application.set_id(application.makeid())

    wrong_id = uuid.uuid4()

    return [
        param("a_id_acc_lock", application, application.id, account, True, raises=lock.Locked),
        param("a_id_acc_nolock", application, application.id, account, False),
        param("a_id_noacc_nolock", application, application.id, None, False),
        param("a_noid_noacc_nolock", application, None, None, False, raises=exceptions.ArgumentException),
        param("a_wid_noacc_nolock", application, wrong_id, None, False),
        param("noa_id_noacc_nolock", None, application.id, None, False),
        param("noa_noid_noacc_nolock", None, None, None, False, raises=exceptions.ArgumentException)
    ]
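These loaders pair each named case with positional arguments plus an optional raises= expectation; Examples 1 and 2 follow the same pattern. A minimal, self-contained sketch of how a parameterized test can consume such cases follows; the do_fetch helper and its locking semantics are illustrative assumptions, not part of the original project.

import unittest

from parameterized import param, parameterized


def do_fetch(obj, obj_id, account, lock_held):
    # Hypothetical stand-in for the service under test.
    if obj is None or obj_id is None:
        raise ValueError("object and id are required")
    if lock_held:
        raise RuntimeError("locked by another account")
    return obj


class FetchCaseTest(unittest.TestCase):
    @parameterized.expand([
        param("ok", object(), "id-1", "acct", False),
        param("locked", object(), "id-1", "acct", True, raises=RuntimeError),
        param("no_args", None, None, None, False, raises=ValueError),
    ])
    def test_fetch(self, _, obj, obj_id, account, lock_held, raises=None):
        # Cases carrying a `raises` kwarg assert on the exception;
        # the rest assert on the returned object.
        if raises is not None:
            with self.assertRaises(raises):
                do_fetch(obj, obj_id, account, lock_held)
        else:
            self.assertIs(obj, do_fetch(obj, obj_id, account, lock_held))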
Example 2
def load_journal_cases():
    account = Account(**AccountFixtureFactory.make_publisher_source())
    account.set_id(account.makeid())

    journal = Journal(**JournalFixtureFactory.make_journal_source(in_doaj=True))
    journal.set_id(journal.makeid())

    wrong_id = uuid.uuid4()

    return [
        param("j_id_acc_lock", journal, journal.id, account, True, raises=lock.Locked),
        param("j_id_acc_nolock", journal, journal.id, account, False),
        param("j_id_noacc_nolock", journal, journal.id, None, False),
        param("j_noid_noacc_nolock", journal, None, None, False, raises=exceptions.ArgumentException),
        param("j_wid_noacc_nolock", journal, wrong_id, None, False),
        param("noj_id_noacc_nolock", None, journal.id, None, False),
        param("noj_noid_noacc_nolock", None, None, None, False, raises=exceptions.ArgumentException)
    ]
Example 3
def load_j2a_cases():
    journal = Journal(**JournalFixtureFactory.make_journal_source(in_doaj=True))
    account_source = AccountFixtureFactory.make_publisher_source()

    owner_account = Account(**deepcopy(account_source))
    owner_account.set_id(journal.owner)

    non_owner_publisher = Account(**deepcopy(account_source))

    non_publisher = Account(**deepcopy(account_source))
    non_publisher.remove_role("publisher")

    admin = Account(**deepcopy(account_source))
    admin.add_role("admin")

    return [
        param("no_journal_no_account", None, None, raises=exceptions.ArgumentException),
        param("no_journal_with_account", None, owner_account, raises=exceptions.ArgumentException),
        param("journal_no_account", journal, None, comparator=application_matches),
        param("journal_matching_account", journal, owner_account, comparator=application_matches),
        param("journal_unmatched_account", journal, non_owner_publisher, raises=exceptions.AuthoriseException),
        param("journal_non_publisher_account", journal, non_publisher, raises=exceptions.AuthoriseException),
        param("journal_admin_account", journal, admin, comparator=application_matches)
    ]
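The j2a cases add a second optional keyword, comparator=, for asserting on the converted object. A hedged sketch of a generic runner that honors both keywords; the convert callable is a placeholder, not the project's real journal-to-application service.

def run_case(convert, source, account, raises=None, comparator=None):
    # Expected-failure cases: the named exception must be raised.
    if raises is not None:
        try:
            convert(source, account)
        except raises:
            return None
        raise AssertionError("expected %s to be raised" % raises.__name__)
    # Success cases: optionally verify the result against the source.
    result = convert(source, account)
    if comparator is not None:
        comparator(source, result)
    return result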
Example 4
class ValidateTest(unittest.TestCase):
    @parameterized.expand([
        param("NoFeaturesExpectedNonCompound",
              entry={
                  "tag": "TaG-1",
                  "root": "valid-root",
                  "morphophonemics": "~",
                  "features": "~",
                  "is_compound": "FaLsE"
              }),
        param("NoFeaturesExpectedNonCompoundWithMorphophonemics",
              entry={
                  "tag": "TaG-1",
                  "root": "valid-root",
                  "morphophonemics": "valid-morphophonemics",
                  "features": "~",
                  "is_compound": "FaLsE"
              }),
        param("NoFeaturesExpectedCompoundWithMorphophonemics",
              entry={
                  "tag": "TaG-1",
                  "root": "valid-root",
                  "morphophonemics": "valid-morphophonemics",
                  "features": "~",
                  "is_compound": "TrUe"
              }),
        param("RequiredFeaturesExpectedNonCompoundWithFeatures",
              entry={
                  "tag": "TaG-2",
                  "root": "valid-root",
                  "morphophonemics": "~",
                  "features": "+[Cat1=Val12]+[Cat2=Val21]",
                  "is_compound": "FaLsE"
              }),
        param(
            "RequiredFeaturesExpectedNonCompoundWithFeaturesAndMorphophonemics",
            entry={
                "tag": "TaG-2",
                "root": "valid-root",
                "morphophonemics": "valid-morphophonemics",
                "features": "+[Cat1=Val12]+[Cat2=Val21]",
                "is_compound": "FaLsE"
            }),
        param("RequiredFeaturesExpectedCompoundWithFeaturesAndMorphophonemics",
              entry={
                  "tag": "TaG-2",
                  "root": "valid-root",
                  "morphophonemics": "valid-morphophonemics",
                  "features": "+[Cat1=Val12]+[Cat2=Val21]",
                  "is_compound": "TrUe"
              }),
        param("OptionalFeaturesExpectedNonCompound",
              entry={
                  "tag": "TaG-3",
                  "root": "valid-root",
                  "morphophonemics": "~",
                  "features": "~",
                  "is_compound": "FaLsE"
              }),
        param("OptionalFeaturesExpectedNonCompoundWithMorphophonemics",
              entry={
                  "tag": "TaG-3",
                  "root": "valid-root",
                  "morphophonemics": "valid-morphophonemics",
                  "features": "~",
                  "is_compound": "FaLsE"
              }),
        param("OptionalFeaturesExpectedNonCompoundWithFeatures",
              entry={
                  "tag": "TaG-3",
                  "root": "valid-root",
                  "morphophonemics": "~",
                  "features": "+[Cat1=Val12]",
                  "is_compound": "FaLsE"
              }),
        param(
            "OptionalFeaturesExpectedNonCompoundWithFeaturesAndMorphophonemics",
            entry={
                "tag": "TaG-3",
                "root": "valid-root",
                "morphophonemics": "valid-morphophonemics",
                "features": "+[Cat1=Val12]",
                "is_compound": "FaLsE"
            }),
        param("OptionalFeaturesExpectedCompoundWithMorphophonemics",
              entry={
                  "tag": "TaG-3",
                  "root": "valid-root",
                  "morphophonemics": "valid-morphophonemics",
                  "features": "~",
                  "is_compound": "TrUe"
              }),
        param("OptionalFeaturesExpectedCompoundWithFeaturesAndMorphophonemics",
              entry={
                  "tag": "TaG-3",
                  "root": "valid-root",
                  "morphophonemics": "valid-morphophonemics",
                  "features": "+[Cat1=Val12]",
                  "is_compound": "TrUe"
              }),
    ])
    def test_success(self, _, entry):
        self.assertIsNone(validator.validate(entry))

    @parameterized.expand([
        param("MissingTag",
              error="Entry is missing fields: 'tag'",
              entry={
                  "root": "valid-root",
                  "morphophonemics": "valid-morphophonemics",
                  "features": "~",
                  "is_compound": "TrUe"
              }),
        param("MissingRoot",
              error="Entry is missing fields: 'root'",
              entry={
                  "tag": "TaG-1",
                  "morphophonemics": "valid-morphophonemics",
                  "features": "~",
                  "is_compound": "TrUe"
              }),
        param("MissingMorphophonemics",
              error="Entry is missing fields: 'morphophonemics'",
              entry={
                  "tag": "TaG-1",
                  "root": "valid-root",
                  "features": "~",
                  "is_compound": "TrUe"
              }),
        param("MissingFeatures",
              error="Entry is missing fields: 'features'",
              entry={
                  "tag": "TaG-1",
                  "root": "valid-root",
                  "morphophonemics": "valid-morphophonemics",
                  "is_compound": "TrUe"
              }),
        param("MissingIsCompound",
              error="Entry is missing fields: 'is_compound'",
              entry={
                  "tag": "TaG-1",
                  "root": "valid-root",
                  "morphophonemics": "valid-morphophonemics",
                  "features": "~"
              }),
        param("MissingMultipleRequiredField",
              error=("Entry is missing fields: 'is_compound,"
                     " morphophonemics, root"),
              entry={
                  "tag": "TaG-1",
                  "features": "~"
              }),
        param("EmptyTag",
              error="Entry fields have empty values: 'tag'",
              entry={
                  "tag": "",
                  "root": "valid-root",
                  "morphophonemics": "valid-morphophonemics",
                  "features": "~",
                  "is_compound": "TrUe"
              }),
        param("EmptyRoot",
              error="Entry fields have empty values: 'root'",
              entry={
                  "tag": "TaG-1",
                  "root": "",
                  "morphophonemics": "valid-morphophonemics",
                  "features": "~",
                  "is_compound": "TrUe"
              }),
        param("EmptyMorphophonemics",
              error="Entry fields have empty values: 'morphophonemics'",
              entry={
                  "tag": "TaG-1",
                  "root": "valid-root",
                  "morphophonemics": "",
                  "features": "~",
                  "is_compound": "TrUe"
              }),
        param("EmptyFeatures",
              error="Entry fields have empty values: 'features'",
              entry={
                  "tag": "TaG-1",
                  "root": "valid-root",
                  "morphophonemics": "valid-morphophonemics",
                  "features": "",
                  "is_compound": "TrUe"
              }),
        param("EmptyIsCompound",
              error="Entry fields have empty values: 'is_compound'",
              entry={
                  "tag": "TaG-1",
                  "root": "valid-root",
                  "morphophonemics": "valid-morphophonemics",
                  "features": "~",
                  "is_compound": ""
              }),
        param("MultipleEmptyRequiredField",
              error=("Entry fields have empty values: 'is_compound,"
                     " morphophonemics, root'"),
              entry={
                  "tag": "TaG-1",
                  "root": "",
                  "morphophonemics": "",
                  "features": "~",
                  "is_compound": ""
              }),
        param("TagContainsInfixWhitespace",
              error="Entry field values contain whitespace: 'tag'",
              entry={
                  "tag": "TaG 1",
                  "root": "valid-root",
                  "morphophonemics": "valid-morphophonemics",
                  "features": "~",
                  "is_compound": "TrUe"
              }),
        param("MorphophonemicsContainsInfixWhitespace",
              error="Entry field values contain whitespace: 'morphophonemics'",
              entry={
                  "tag": "TaG-1",
                  "root": "valid-root",
                  "morphophonemics": "valid morphophonemics",
                  "features": "~",
                  "is_compound": "TrUe"
              }),
        param("FeaturesContainsInfixWhitespace",
              error="Entry field values contain whitespace: 'features'",
              entry={
                  "tag": "TaG-3",
                  "root": "valid-root",
                  "morphophonemics": "valid-morphophonemics",
                  "features": "+[Cat1 = Val12]",
                  "is_compound": "TrUe"
              }),
        param("MultipleFieldsContainsInfixWhitespace",
              error="Entry field values contain whitespace: 'features, tag'",
              entry={
                  "tag": "TaG 3",
                  "root": "valid-root",
                  "morphophonemics": "valid-morphophonemics",
                  "features": "+[Cat1 = Val12]",
                  "is_compound": "TrUe"
              }),
        param(
            "InvalidTag",
            error=("Entry 'tag' field has invalid value. It can only be one of"
                   " the valid tags that are defined in"
                   " 'morphotactics_compiler/tags.py'."),
            entry={
                "tag": "Invalid-Tag",
                "root": "valid-root",
                "morphophonemics": "valid-morphophonemics",
                "features": "~",
                "is_compound": "TrUe"
            }),
        param("InvalidIsCompound",
              error=("Entry 'is_compound' field has invalid value. It can only"
                     " have the values 'true' or 'false'."),
              entry={
                  "tag": "TaG-1",
                  "root": "valid-root",
                  "morphophonemics": "valid-morphophonemics",
                  "features": "~",
                  "is_compound": "invalid-is-compound"
              }),
        param("InvalidMorphophonemics",
              error=(
                  "Entry is marked as ending with compounding marker but it is"
                  " missing morphophonemics annotation."),
              entry={
                  "tag": "TaG-1",
                  "root": "valid-root",
                  "morphophonemics": "~",
                  "features": "~",
                  "is_compound": "TrUe"
              }),
        param("InvalidFeaturesInvalidPrefixCharacters",
              error="Entry features annotation is invalid.",
              entry={
                  "tag": "TaG-3",
                  "root": "valid-root",
                  "morphophonemics": "valid-morphophonemics",
                  "features": "foo+[Cat1=Val12]+[Cat3=Val31]",
                  "is_compound": "TrUe"
              }),
        param("InvalidFeaturesInvalidInfixCharacters",
              error="Entry features annotation is invalid.",
              entry={
                  "tag": "TaG-3",
                  "root": "valid-root",
                  "morphophonemics": "valid-morphophonemics",
                  "features": "+[Cat1=Val12]foo+[Cat3=Val31]",
                  "is_compound": "TrUe"
              }),
        param("InvalidFeaturesInvalidSuffixCharacters",
              error="Entry features annotation is invalid.",
              entry={
                  "tag": "TaG-3",
                  "root": "valid-root",
                  "morphophonemics": "valid-morphophonemics",
                  "features": "+[Cat1=Val12]+[Cat3=Val31]foo",
                  "is_compound": "TrUe"
              }),
        param("NoRequiredFeatures",
              error="Entry is missing required features.",
              entry={
                  "tag": "TaG-2",
                  "root": "valid-root",
                  "morphophonemics": "valid-morphophonemics",
                  "features": "~",
                  "is_compound": "TrUe"
              }),
        param("MissingRequiredFeatures",
              error="Entry has invalid required feature category.",
              entry={
                  "tag": "TaG-2",
                  "root": "valid-root",
                  "morphophonemics": "valid-morphophonemics",
                  "features": "+[Cat1=Val12]",
                  "is_compound": "TrUe"
              }),
        param("InvalidRequiredFeatureCategory",
              error="Entry has invalid required feature category.",
              entry={
                  "tag": "TaG-2",
                  "root": "valid-root",
                  "morphophonemics": "valid-morphophonemics",
                  "features": "+[Cat1=Val12]+[Cat3=Val21]",
                  "is_compound": "TrUe"
              }),
        param("InvalidRequiredFeatureValue",
              error="Entry has invalid required feature value.",
              entry={
                  "tag": "TaG-2",
                  "root": "valid-root",
                  "morphophonemics": "valid-morphophonemics",
                  "features": "+[Cat1=Val12]+[Cat2=Val23]",
                  "is_compound": "TrUe"
              }),
        param("InvalidOptionalFeatureCategory",
              error="Entry has invalid optional features.",
              entry={
                  "tag": "TaG-3",
                  "root": "valid-root",
                  "morphophonemics": "valid-morphophonemics",
                  "features": "+[Cat2=Val12]",
                  "is_compound": "TrUe"
              }),
        param("InvalidOptionalFeatureValue",
              error="Entry has invalid optional features.",
              entry={
                  "tag": "TaG-3",
                  "root": "valid-root",
                  "morphophonemics": "valid-morphophonemics",
                  "features": "+[Cat1=Val11]",
                  "is_compound": "TrUe"
              }),
        param("RedundantFeatures",
              error="Entry has features while it is not expected to have any.",
              entry={
                  "tag": "TaG-1",
                  "root": "valid-root",
                  "morphophonemics": "valid-morphophonemics",
                  "features": "+[Cat1=Val12]",
                  "is_compound": "TrUe"
              }),
    ])
    def test_raises_exception(self, _, error, entry):
        with self.assertRaisesRegex(validator.InvalidLexiconEntryError,
                                    error):
            validator.validate(entry)
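Taken together, the cases above pin down the validator's contract: validate(entry) returns None for a well-formed entry and raises InvalidLexiconEntryError with a matching message otherwise. Below is a minimal sketch of just the missing-field and empty-value checks, assuming the field names used in the fixtures; the remaining checks (whitespace, tag, features, compounding) are omitted.

_REQUIRED_FIELDS = ("tag", "root", "morphophonemics", "features", "is_compound")


class InvalidLexiconEntryError(Exception):
    """Raised when a lexicon entry fails validation."""


def validate(entry):
    # Report all missing required fields in one sorted message.
    missing = sorted(f for f in _REQUIRED_FIELDS if f not in entry)
    if missing:
        raise InvalidLexiconEntryError(
            "Entry is missing fields: '%s'" % ", ".join(missing))
    # Report all present-but-empty fields the same way.
    empty = sorted(f for f in _REQUIRED_FIELDS if not entry[f])
    if empty:
        raise InvalidLexiconEntryError(
            "Entry fields have empty values: '%s'" % ", ".join(empty))
    return None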
Example 5
class TestDateParser(BaseTestCase):
    def setUp(self):
        super(TestDateParser, self).setUp()
        self.parser = NotImplemented
        self.result = NotImplemented
        self.date_parser = NotImplemented
        self.date_result = NotImplemented

    @parameterized.expand([
        # English dates
        param('[Sept] 04, 2014.', datetime(2014, 9, 4)),
        param('Tuesday Jul 22, 2014', datetime(2014, 7, 22)),
        param('Tues 9th Aug, 2015', datetime(2015, 8, 9)),
        param('10:04am', datetime(2012, 11, 13, 10, 4)),
        param('Friday', datetime(2012, 11, 9)),
        param('November 19, 2014 at noon', datetime(2014, 11, 19, 12, 0)),
        param('December 13, 2014 at midnight', datetime(2014, 12, 13, 0, 0)),
        param('Nov 25 2014 10:17 pm', datetime(2014, 11, 25, 22, 17)),
        param('Wed Aug 05 12:00:00 2015', datetime(2015, 8, 5, 12, 0)),
        param('April 9, 2013 at 6:11 a.m.', datetime(2013, 4, 9, 6, 11)),
        param('Aug. 9, 2012 at 2:57 p.m.', datetime(2012, 8, 9, 14, 57)),
        param('December 10, 2014, 11:02:21 pm', datetime(2014, 12, 10, 23, 2, 21)),
        param('8:25 a.m. Dec. 12, 2014', datetime(2014, 12, 12, 8, 25)),
        param('2:21 p.m., December 11, 2014', datetime(2014, 12, 11, 14, 21)),
        param('Fri, 12 Dec 2014 10:55:50', datetime(2014, 12, 12, 10, 55, 50)),
        param('20 Mar 2013 10h11', datetime(2013, 3, 20, 10, 11)),
        param('10:06am Dec 11, 2014', datetime(2014, 12, 11, 10, 6)),
        param('19 February 2013 year 09:10', datetime(2013, 2, 19, 9, 10)),
        param('21 January 2012 13:11:23.678', datetime(2012, 1, 21, 13, 11, 23, 678000)),
        param('1/1/16 9:02:43.1', datetime(2016, 1, 1, 9, 2, 43, 100000)),
        # French dates
        param('11 Mai 2014', datetime(2014, 5, 11)),
        param('dimanche, 11 Mai 2014', datetime(2014, 5, 11)),
        param('22 janvier 2015 à 14h40', datetime(2015, 1, 22, 14, 40)),
        param('Dimanche 1er Février à 21:24', datetime(2012, 2, 1, 21, 24)),
        param('vendredi, décembre 5 2014.', datetime(2014, 12, 5, 0, 0)),
        param('le 08 Déc 2014 15:11', datetime(2014, 12, 8, 15, 11)),
        param('Le 11 Décembre 2014 à 09:00', datetime(2014, 12, 11, 9, 0)),
        param('fév 15, 2013', datetime(2013, 2, 15, 0, 0)),
        param('Jeu 15:12', datetime(2012, 11, 8, 15, 12)),
        # Spanish dates
        param('Martes 21 de Octubre de 2014', datetime(2014, 10, 21)),
        param('Miércoles 20 de Noviembre de 2013', datetime(2013, 11, 20)),
        param('12 de junio del 2012', datetime(2012, 6, 12)),
        param('13 Ago, 2014', datetime(2014, 8, 13)),
        param('13 Septiembre, 2014', datetime(2014, 9, 13)),
        param('11 Marzo, 2014', datetime(2014, 3, 11)),
        param('julio 5, 2015 en 1:04 pm', datetime(2015, 7, 5, 13, 4)),
        param('Vi 17:15', datetime(2012, 11, 9, 17, 15)),
        # Dutch dates
        param('11 augustus 2014', datetime(2014, 8, 11)),
        param('14 januari 2014', datetime(2014, 1, 14)),
        param('vr jan 24, 2014 12:49', datetime(2014, 1, 24, 12, 49)),
        # Italian dates
        param('16 giu 2014', datetime(2014, 6, 16)),
        param('26 gennaio 2014', datetime(2014, 1, 26)),
        param('Ven 18:23', datetime(2012, 11, 9, 18, 23)),
        # Portuguese dates
        param('sexta-feira, 10 de junho de 2014 14:52', datetime(2014, 6, 10, 14, 52)),
        param('13 Setembro, 2014', datetime(2014, 9, 13)),
        param('Sab 3:03', datetime(2012, 11, 10, 3, 3)),
        # Russian dates
        param('10 мая', datetime(2012, 5, 10)),  # forum.codenet.ru
        param('26 апреля', datetime(2012, 4, 26)),
        param('20 ноября 2013', datetime(2013, 11, 20)),
        param('28 октября 2014 в 07:54', datetime(2014, 10, 28, 7, 54)),
        param('13 января 2015 г. в 13:34', datetime(2015, 1, 13, 13, 34)),
        param('09 августа 2012', datetime(2012, 8, 9, 0, 0)),
        param('Авг 26, 2015 15:12', datetime(2015, 8, 26, 15, 12)),
        param('2 Декабрь 95 11:15', datetime(1995, 12, 2, 11, 15)),
        param('13 янв. 2005 19:13', datetime(2005, 1, 13, 19, 13)),
        param('13 авг. 2005 19:13', datetime(2005, 8, 13, 19, 13)),
        param('13 авг. 2005г. 19:13', datetime(2005, 8, 13, 19, 13)),
        param('13 авг. 2005 г. 19:13', datetime(2005, 8, 13, 19, 13)),
        # Turkish dates
        param('11 Ağustos, 2014', datetime(2014, 8, 11)),
        param('08.Haziran.2014, 11:07', datetime(2014, 6, 8, 11, 7)),  # forum.andronova.net
        param('17.Şubat.2014, 17:51', datetime(2014, 2, 17, 17, 51)),
        param('14-Aralık-2012, 20:56', datetime(2012, 12, 14, 20, 56)),  # forum.ceviz.net
        # Romanian dates
        param('13 iunie 2013', datetime(2013, 6, 13)),
        param('14 aprilie 2014', datetime(2014, 4, 14)),
        param('18 martie 2012', datetime(2012, 3, 18)),
        param('12-Iun-2013', datetime(2013, 6, 12)),
        # German dates
        param('21. Dezember 2013', datetime(2013, 12, 21)),
        param('19. Februar 2012', datetime(2012, 2, 19)),
        param('26. Juli 2014', datetime(2014, 7, 26)),
        param('18.10.14 um 22:56 Uhr', datetime(2014, 10, 18, 22, 56)),
        param('12-Mär-2014', datetime(2014, 3, 12)),
        param('Mit 13:14', datetime(2012, 11, 7, 13, 14)),
        # Czech dates
        param('pon 16. čer 2014 10:07:43', datetime(2014, 6, 16, 10, 7, 43)),
        param('13 Srpen, 2014', datetime(2014, 8, 13)),
        param('čtv 14. lis 2013 12:38:43', datetime(2013, 11, 14, 12, 38, 43)),
        # Thai dates
        param('ธันวาคม 11, 2014, 08:55:08 PM', datetime(2014, 12, 11, 20, 55, 8)),
        param('22 พฤษภาคม 2012, 22:12', datetime(2012, 5, 22, 22, 12)),
        param('11 กุมภา 2020, 8:13 AM', datetime(2020, 2, 11, 8, 13)),
        param('1 เดือนตุลาคม 2005, 1:00 AM', datetime(2005, 10, 1, 1, 0)),
        param('11 ก.พ. 2020, 1:13 pm', datetime(2020, 2, 11, 13, 13)),
        # Vietnamese dates
        param('Thứ năm', datetime(2012, 11, 8)),  # Thursday
        param('Thứ sáu', datetime(2012, 11, 9)),  # Friday
        param('Tháng Mười Hai 29, 2013, 14:14', datetime(2013, 12, 29, 14, 14)),  # bpsosrcs.wordpress.com  # NOQA
        param('05 Tháng một 2015 - 03:54 AM', datetime(2015, 1, 5, 3, 54)),
        # Belarusian dates
        param('11 траўня', datetime(2012, 5, 11)),
        param('4 мая', datetime(2012, 5, 4)),
        param('Чацвер 06 жніўня 2015', datetime(2015, 8, 6)),
        param('Нд 14 сакавіка 2015 у 7 гадзін 10 хвілін', datetime(2015, 3, 14, 7, 10)),
        param('5 жніўня 2015 года у 13:34', datetime(2015, 8, 5, 13, 34)),
        # Ukrainian dates
        param('2015-кві-12', datetime(2015, 4, 12)),
        param('21 чер 2013 3:13', datetime(2013, 6, 21, 3, 13)),
        param('12 лютого 2012, 13:12:23', datetime(2012, 2, 12, 13, 12, 23)),
        param('вів о 14:04', datetime(2012, 11, 13, 14, 4)),
        # Tagalog dates
        param('12 Hulyo 2003 13:01', datetime(2003, 7, 12, 13, 1)),
        param('1978, 1 Peb, 7:05 PM', datetime(1978, 2, 1, 19, 5)),
        param('2 hun', datetime(2012, 6, 2)),
        param('Lin 16:16', datetime(2012, 11, 11, 16, 16)),
        # Japanese dates
        param('2016年3月20日(日) 21時40分', datetime(2016, 3, 20, 21, 40)),
        param("2016年3月20日 21時40分", datetime(2016, 3, 20, 21, 40)),
        # Numeric dates
        param('06-17-2014', datetime(2014, 6, 17)),
        param('13/03/2014', datetime(2014, 3, 13)),
        param('11. 12. 2014, 08:45:39', datetime(2014, 11, 12, 8, 45, 39)),
        # Miscellaneous dates
        param('1 Ni 2015', datetime(2015, 4, 1, 0, 0)),
        param('1 Mar 2015', datetime(2015, 3, 1, 0, 0)),
        param('1 сер 2015', datetime(2015, 8, 1, 0, 0)),
        param('2016020417:10', datetime(2016, 2, 4, 17, 10)),
        # Chinese dates
        param('2015年04月08日10:05', datetime(2015, 4, 8, 10, 5)),
        param('2012年12月20日10:35', datetime(2012, 12, 20, 10, 35)),
        param('2016年06月30日09时30分', datetime(2016, 6, 30, 9, 30)),
        param('2016年6月2911:30', datetime(2016, 6, 29, 11, 30)),
        param('2016年6月29', datetime(2016, 6, 29, 0, 0)),
        param('2016年 2月 5日', datetime(2016, 2, 5, 0, 0)),
        param('2016年9月14日晚8:00', datetime(2016, 9, 14, 20, 0)),
        # Bulgarian
        param('25 ян 2016', datetime(2016, 1, 25, 0, 0)),
        param('23 декември 2013 15:10:01', datetime(2013, 12, 23, 15, 10, 1)),
        # Bangla dates
        param('[সেপ্টেম্বর] 04, 2014.', datetime(2014, 9, 4)),
        param('মঙ্গলবার জুলাই 22, 2014', datetime(2014, 7, 22)),
        param('শুক্রবার', datetime(2012, 11, 9)),
        param('শুক্র, 12 ডিসেম্বর 2014 10:55:50', datetime(2014, 12, 12, 10, 55, 50)),
        param('1লা জানুয়ারী 2015', datetime(2015, 1, 1)),
        param('25শে মার্চ 1971', datetime(1971, 3, 25)),
        param('8ই মে 2002', datetime(2002, 5, 8)),
        param('10:06am ডিসেম্বর 11, 2014', datetime(2014, 12, 11, 10, 6)),
        param('19 ফেব্রুয়ারী 2013 সাল 09:10', datetime(2013, 2, 19, 9, 10)),
        # Hindi dates
        param('11 जुलाई 1994, 11:12', datetime(1994, 7, 11, 11, 12)),
        param('१७ अक्टूबर २०१८', datetime(2018, 10, 17, 0, 0)),
        param('12 जनवरी  1997 11:08 अपराह्न', datetime(1997, 1, 12, 23, 8)),
        # Georgian dates
        param('2011 წლის 17 მარტი, ოთხშაბათი', datetime(2011, 3, 17, 0, 0)),
        param('2015 წ. 12 ივნ, 15:34', datetime(2015, 6, 12, 15, 34))
    ])
    def test_dates_parsing(self, date_string, expected):
        self.given_parser(settings={'NORMALIZE': False,
                                    'RELATIVE_BASE': datetime(2012, 11, 13)})
        self.when_date_is_parsed(date_string)
        self.then_date_was_parsed_by_date_parser()
        self.then_period_is('day')
        self.then_date_obj_exactly_is(expected)

    def test_stringified_datetime_should_parse_fine(self):
        expected_date = datetime(2012, 11, 13, 10, 15, 5, 330256)
        self.given_parser(settings={'RELATIVE_BASE': expected_date})
        date_string = str(self.parser.get_date_data('today')['date_obj'])
        self.when_date_is_parsed(date_string)
        self.then_date_was_parsed_by_date_parser()
        self.then_period_is('day')
        self.then_date_obj_exactly_is(expected_date)

    @parameterized.expand([
        # English dates
        param('[Sept] 04, 2014.', datetime(2014, 9, 4)),
        param('Tuesday Jul 22, 2014', datetime(2014, 7, 22)),
        param('10:04am', datetime(2012, 11, 13, 10, 4)),
        param('Friday', datetime(2012, 11, 9)),
        param('November 19, 2014 at noon', datetime(2014, 11, 19, 12, 0)),
        param('December 13, 2014 at midnight', datetime(2014, 12, 13, 0, 0)),
        param('Nov 25 2014 10:17 pm', datetime(2014, 11, 25, 22, 17)),
        param('Wed Aug 05 12:00:00 2015', datetime(2015, 8, 5, 12, 0)),
        param('April 9, 2013 at 6:11 a.m.', datetime(2013, 4, 9, 6, 11)),
        param('Aug. 9, 2012 at 2:57 p.m.', datetime(2012, 8, 9, 14, 57)),
        param('December 10, 2014, 11:02:21 pm', datetime(2014, 12, 10, 23, 2, 21)),
        param('8:25 a.m. Dec. 12, 2014', datetime(2014, 12, 12, 8, 25)),
        param('2:21 p.m., December 11, 2014', datetime(2014, 12, 11, 14, 21)),
        param('Fri, 12 Dec 2014 10:55:50', datetime(2014, 12, 12, 10, 55, 50)),
        param('20 Mar 2013 10h11', datetime(2013, 3, 20, 10, 11)),
        param('10:06am Dec 11, 2014', datetime(2014, 12, 11, 10, 6)),
        param('19 February 2013 year 09:10', datetime(2013, 2, 19, 9, 10)),
        # French dates
        param('11 Mai 2014', datetime(2014, 5, 11)),
        param('dimanche, 11 Mai 2014', datetime(2014, 5, 11)),
        param('22 janvier 2015 à 14h40', datetime(2015, 1, 22, 14, 40)),  # wrong
        param('Dimanche 1er Février à 21:24', datetime(2012, 2, 1, 21, 24)),
        param('vendredi, décembre 5 2014.', datetime(2014, 12, 5, 0, 0)),
        param('le 08 Déc 2014 15:11', datetime(2014, 12, 8, 15, 11)),
        param('Le 11 Décembre 2014 à 09:00', datetime(2014, 12, 11, 9, 0)),
        param('fév 15, 2013', datetime(2013, 2, 15, 0, 0)),
        param('Jeu 15:12', datetime(2012, 11, 8, 15, 12)),
        # Spanish dates
        param('Martes 21 de Octubre de 2014', datetime(2014, 10, 21)),
        param('Miércoles 20 de Noviembre de 2013', datetime(2013, 11, 20)),
        param('12 de junio del 2012', datetime(2012, 6, 12)),
        param('13 Ago, 2014', datetime(2014, 8, 13)),
        param('13 Septiembre, 2014', datetime(2014, 9, 13)),
        param('11 Marzo, 2014', datetime(2014, 3, 11)),
        param('julio 5, 2015 en 1:04 pm', datetime(2015, 7, 5, 13, 4)),
        param('Vi 17:15', datetime(2012, 11, 9, 17, 15)),
        # Dutch dates
        param('11 augustus 2014', datetime(2014, 8, 11)),
        param('14 januari 2014', datetime(2014, 1, 14)),
        param('vr jan 24, 2014 12:49', datetime(2014, 1, 24, 12, 49)),
        # Italian dates
        param('16 giu 2014', datetime(2014, 6, 16)),
        param('26 gennaio 2014', datetime(2014, 1, 26)),
        param('Ven 18:23', datetime(2012, 11, 9, 18, 23)),
        # Portuguese dates
        param('sexta-feira, 10 de junho de 2014 14:52', datetime(2014, 6, 10, 14, 52)),
        param('13 Setembro, 2014', datetime(2014, 9, 13)),
        param('Sab 3:03', datetime(2012, 11, 10, 3, 3)),
        # Russian dates
        param('10 мая', datetime(2012, 5, 10)),  # forum.codenet.ru
        param('26 апреля', datetime(2012, 4, 26)),
        param('20 ноября 2013', datetime(2013, 11, 20)),
        param('28 октября 2014 в 07:54', datetime(2014, 10, 28, 7, 54)),
        param('13 января 2015 г. в 13:34', datetime(2015, 1, 13, 13, 34)),
        param('09 августа 2012', datetime(2012, 8, 9, 0, 0)),
        param('Авг 26, 2015 15:12', datetime(2015, 8, 26, 15, 12)),
        param('2 Декабрь 95 11:15', datetime(1995, 12, 2, 11, 15)),
        param('13 янв. 2005 19:13', datetime(2005, 1, 13, 19, 13)),
        param('13 авг. 2005 19:13', datetime(2005, 8, 13, 19, 13)),
        param('13 авг. 2005г. 19:13', datetime(2005, 8, 13, 19, 13)),
        param('13 авг. 2005 г. 19:13', datetime(2005, 8, 13, 19, 13)),
        # Turkish dates
        param('11 Ağustos, 2014', datetime(2014, 8, 11)),
        param('08.Haziran.2014, 11:07', datetime(2014, 6, 8, 11, 7)),  # forum.andronova.net
        param('17.Şubat.2014, 17:51', datetime(2014, 2, 17, 17, 51)),
        param('14-Aralık-2012, 20:56', datetime(2012, 12, 14, 20, 56)),  # forum.ceviz.net
        # Romanian dates
        param('13 iunie 2013', datetime(2013, 6, 13)),
        param('14 aprilie 2014', datetime(2014, 4, 14)),
        param('18 martie 2012', datetime(2012, 3, 18)),
        param('S 14:14', datetime(2012, 11, 10, 14, 14)),
        param('12-Iun-2013', datetime(2013, 6, 12)),
        # German dates
        param('21. Dezember 2013', datetime(2013, 12, 21)),
        param('19. Februar 2012', datetime(2012, 2, 19)),
        param('26. Juli 2014', datetime(2014, 7, 26)),
        param('18.10.14 um 22:56 Uhr', datetime(2014, 10, 18, 22, 56)),
        param('12-Mär-2014', datetime(2014, 3, 12)),
        param('Mit 13:14', datetime(2012, 11, 7, 13, 14)),
        # Czech dates
        param('pon 16. čer 2014 10:07:43', datetime(2014, 6, 16, 10, 7, 43)),
        param('13 Srpen, 2014', datetime(2014, 8, 13)),
        param('čtv 14. lis 2013 12:38:43', datetime(2013, 11, 14, 12, 38, 43)),
        # Thai dates
        param('ธันวาคม 11, 2014, 08:55:08 PM', datetime(2014, 12, 11, 20, 55, 8)),
        param('22 พฤษภาคม 2012, 22:12', datetime(2012, 5, 22, 22, 12)),
        param('11 กุมภา 2020, 8:13 AM', datetime(2020, 2, 11, 8, 13)),
        param('1 เดือนตุลาคม 2005, 1:00 AM', datetime(2005, 10, 1, 1, 0)),
        param('11 ก.พ. 2020, 1:13 pm', datetime(2020, 2, 11, 13, 13)),
        # Vietnamese dates
        param('Thứ năm', datetime(2012, 11, 8)),  # Thursday
        param('Thứ sáu', datetime(2012, 11, 9)),  # Friday
        param('Tháng Mười Hai 29, 2013, 14:14', datetime(2013, 12, 29, 14, 14)),  # bpsosrcs.wordpress.com  # NOQA
        param('05 Tháng một 2015 - 03:54 AM', datetime(2015, 1, 5, 3, 54)),
        # Belarusian dates
        param('11 траўня', datetime(2012, 5, 11)),
        param('4 мая', datetime(2012, 5, 4)),
        param('Чацвер 06 жніўня 2015', datetime(2015, 8, 6)),
        param('Нд 14 сакавіка 2015 у 7 гадзін 10 хвілін', datetime(2015, 3, 14, 7, 10)),
        param('5 жніўня 2015 года у 13:34', datetime(2015, 8, 5, 13, 34)),
        # Ukrainian dates
        param('2015-кві-12', datetime(2015, 4, 12)),
        param('21 чер 2013 3:13', datetime(2013, 6, 21, 3, 13)),
        param('12 лютого 2012, 13:12:23', datetime(2012, 2, 12, 13, 12, 23)),
        param('вів о 14:04', datetime(2012, 11, 13, 14, 4)),
        # Filipino dates
        param('12 Hulyo 2003 13:01', datetime(2003, 7, 12, 13, 1)),
        param('1978, 1 Peb, 7:05 PM', datetime(1978, 2, 1, 19, 5)),
        param('2 hun', datetime(2012, 6, 2)),
        param('Lin 16:16', datetime(2012, 11, 11, 16, 16)),
        # Japanese dates
        param('2016年3月20日(日) 21時40分', datetime(2016, 3, 20, 21, 40)),
        param("2016年3月20日 21時40分", datetime(2016, 3, 20, 21, 40)),
        # Bangla dates
        param('[সেপ্টেম্বর] 04, 2014.', datetime(2014, 9, 4)),
        param('মঙ্গলবার জুলাই 22, 2014', datetime(2014, 7, 22)),
        param('শুক্রবার', datetime(2012, 11, 9)),
        param('শুক্র, 12 ডিসেম্বর 2014 10:55:50', datetime(2014, 12, 12, 10, 55, 50)),
        param('1লা জানুয়ারী 2015', datetime(2015, 1, 1)),
        param('25শে মার্চ 1971', datetime(1971, 3, 25)),
        param('8ই মে 2002', datetime(2002, 5, 8)),
        param('10:06am ডিসেম্বর 11, 2014', datetime(2014, 12, 11, 10, 6)),
        param('19 ফেব্রুয়ারী 2013 সাল 09:10', datetime(2013, 2, 19, 9, 10)),
        # Numeric dates
        param('06-17-2014', datetime(2014, 6, 17)),
        param('13/03/2014', datetime(2014, 3, 13)),
        param('11. 12. 2014, 08:45:39', datetime(2014, 11, 12, 8, 45, 39)),
        # Miscellaneous dates
        param('1 Ni 2015', datetime(2015, 4, 1, 0, 0)),
        param('1 Mar 2015', datetime(2015, 3, 1, 0, 0)),
        param('1 сер 2015', datetime(2015, 8, 1, 0, 0)),
        # Bulgarian
        param('24 ян 2015г.', datetime(2015, 1, 24, 0, 0)),
        # Hindi dates
        param('बुधवार 24 मई 1997 12:09', datetime(1997, 5, 24, 12, 9)),
        param('28 दिसम्बर 2000 , 01:09:08', datetime(2000, 12, 28, 1, 9, 8)),
        param('१६ दिसम्बर १९७१', datetime(1971, 12, 16, 0, 0)),
        param('सन् 1989 11 फ़रवरी 09:43', datetime(1989, 2, 11, 9, 43)),
    ])
    def test_dates_parsing_with_normalization(self, date_string, expected):
        self.given_local_tz_offset(0)
        self.given_parser(settings={'NORMALIZE': True,
                                    'RELATIVE_BASE': datetime(2012, 11, 13)})
        self.when_date_is_parsed(normalize_unicode(date_string))
        self.then_date_was_parsed_by_date_parser()
        self.then_period_is('day')
        self.then_date_obj_exactly_is(expected)

    @parameterized.expand([
        param('Sep 03 2014 | 4:32 pm EDT', datetime(2014, 9, 3, 20, 32)),
        param('17th October, 2034 @ 01:08 am PDT', datetime(2034, 10, 17, 8, 8)),
        param('15 May 2004 23:24 EDT', datetime(2004, 5, 16, 3, 24)),
        param('08/17/14 17:00 (PDT)', datetime(2014, 8, 18, 0, 0)),
    ])
    def test_parsing_with_time_zones_and_converting_to_UTC(self, date_string, expected):
        self.given_parser(settings={'TO_TIMEZONE': 'UTC'})
        self.when_date_is_parsed(date_string)
        self.then_date_was_parsed_by_date_parser()
        self.then_period_is('day')
        self.then_timezone_parsed_is('UTC')
        self.then_date_obj_exactly_is(expected)

    @parameterized.expand([
        param('Sep 03 2014 | 4:32 pm EDT', 'EDT', datetime(2014, 9, 3, 16, 32)),
        param('17th October, 2034 @ 01:08 am PDT', 'PDT', datetime(2034, 10, 17, 1, 8)),
        param('15 May 2004 23:24 EDT', 'EDT', datetime(2004, 5, 15, 23, 24)),
        param('08/17/14 17:00 (PDT)', 'PDT', datetime(2014, 8, 17, 17, 0)),
        param('15 May 2004 16:10 -0400', '-04:00', datetime(2004, 5, 15, 16, 10)),
        param('1999-12-31 19:00:00 -0500', '-05:00', datetime(1999, 12, 31, 19, 0)),
        param('1999-12-31 19:00:00 +0500', '+05:00', datetime(1999, 12, 31, 19, 0)),
        param('Fri, 09 Sep 2005 13:51:39 -0700', '-07:00', datetime(2005, 9, 9, 13, 51, 39)),
        param('Fri, 09 Sep 2005 13:51:39 +0000', '+00:00', datetime(2005, 9, 9, 13, 51, 39)),
    ])
    def test_dateparser_should_return_tzaware_date_when_tz_info_present_in_date_string(
            self, date_string, timezone_str, expected):
        self.given_parser()
        self.when_date_is_parsed(date_string)
        self.then_date_was_parsed_by_date_parser()
        self.then_period_is('day')
        self.then_timezone_parsed_is(timezone_str)
        self.then_date_obj_exactly_is(expected)

    @parameterized.expand([
        param('15 May 2004 16:10 -0400', 'UTC', datetime(2004, 5, 15, 20, 10)),
        param('1999-12-31 19:00:00 -0500', 'UTC', datetime(2000, 1, 1, 0, 0)),
        param('1999-12-31 19:00:00 +0500', 'UTC', datetime(1999, 12, 31, 14, 0)),
        param('Fri, 09 Sep 2005 13:51:39 -0700', 'GMT', datetime(2005, 9, 9, 20, 51, 39)),
        param('Fri, 09 Sep 2005 13:51:39 +0000', 'GMT', datetime(2005, 9, 9, 13, 51, 39)),
    ])
    def test_dateparser_should_return_date_in_setting_timezone_if_timezone_info_present_both_in_datestring_and_given_in_settings(self, date_string, setting_timezone, expected):
        self.given_parser(settings={'TIMEZONE': setting_timezone})
        self.when_date_is_parsed(date_string)
        self.then_date_was_parsed_by_date_parser()
        self.then_period_is('day')
        self.then_timezone_parsed_is(setting_timezone)
        self.then_date_obj_exactly_is(expected)

    @parameterized.expand([
        param('15 May 2004 16:10 -0400', datetime(2004, 5, 15, 20, 10)),
        param('1999-12-31 19:00:00 -0500', datetime(2000, 1, 1, 0, 0)),
        param('1999-12-31 19:00:00 +0500', datetime(1999, 12, 31, 14, 0)),
        param('Fri, 09 Sep 2005 13:51:39 -0700', datetime(2005, 9, 9, 20, 51, 39)),
        param('Fri, 09 Sep 2005 13:51:39 +0000', datetime(2005, 9, 9, 13, 51, 39)),
        param('Fri Sep 23 2016 10:34:51 GMT+0800 (CST)', datetime(2016, 9, 23, 2, 34, 51)),
    ])
    def test_parsing_with_utc_offsets(self, date_string, expected):
        self.given_parser(settings={'TO_TIMEZONE': 'utc'})
        self.when_date_is_parsed(date_string)
        self.then_date_was_parsed_by_date_parser()
        self.then_period_is('day')
        self.then_timezone_parsed_is('UTC')
        self.then_date_obj_exactly_is(expected)

    def test_empty_dates_string_is_not_parsed(self):
        self.when_date_is_parsed_by_date_parser('')
        self.then_error_was_raised(ValueError, ["Empty string"])

    @parameterized.expand([
        param('invalid date string', 'Unable to parse: invalid'),
        param('Aug 7, 2014Aug 7, 2014', 'Unable to parse: Aug'),
        param('24h ago', 'Unable to parse: h'),
        param('2015-03-17t16:37:51+00:002015-03-17t15:24:37+00:00', 'Unable to parse: 00:002015'),
        param('8 enero 2013 martes 7:03 AM EST 8 enero 2013 martes 7:03 AM EST', 'Unable to parse: 8'),
        param('12/09/18567', 'Unable to parse: 18567'),
    ])
    def test_dates_not_parsed(self, date_string, message):
        self.when_date_is_parsed_by_date_parser(date_string)
        self.then_error_was_raised(ValueError, [message])

    @parameterized.expand([
        param('10 December', datetime(2014, 12, 10)),
        param('March', datetime(2014, 3, 15)),
        param('Friday', datetime(2015, 2, 13)),
        param('Monday', datetime(2015, 2, 9)),
        param('Sunday', datetime(2015, 2, 8)),  # current day
        param('10:00PM', datetime(2015, 2, 14, 22, 0)),
        param('16:10', datetime(2015, 2, 14, 16, 10)),
        param('14:05', datetime(2015, 2, 15, 14, 5)),
        param('15 february 15:00', datetime(2015, 2, 15, 15, 0)),
        param('3/3/50', datetime(1950, 3, 3)),
        param('3/3/94', datetime(1994, 3, 3)),
    ])
    def test_preferably_past_dates(self, date_string, expected):
        self.given_parser(settings={'PREFER_DATES_FROM': 'past',
                          'RELATIVE_BASE': datetime(2015, 2, 15, 15, 30)})
        self.when_date_is_parsed(date_string)
        self.then_date_was_parsed_by_date_parser()
        self.then_date_obj_exactly_is(expected)

    @parameterized.expand([
        param('10 December', datetime(2015, 12, 10)),
        param('March', datetime(2015, 3, 15)),
        param('Friday', datetime(2015, 2, 20)),
        param('Sunday', datetime(2015, 2, 22)),  # current day
        param('Monday', datetime(2015, 2, 16)),
        param('10:00PM', datetime(2015, 2, 15, 22, 0)),
        param('16:10', datetime(2015, 2, 15, 16, 10)),
        param('14:05', datetime(2015, 2, 16, 14, 5)),
        param('3/3/50', datetime(2050, 3, 3)),
        param('3/3/94', datetime(2094, 3, 3)),
    ])
    def test_preferably_future_dates(self, date_string, expected):
        self.given_local_tz_offset(0)
        self.given_parser(settings={'PREFER_DATES_FROM': 'future',
                          'RELATIVE_BASE': datetime(2015, 2, 15, 15, 30)})
        self.when_date_is_parsed(date_string)
        self.then_date_was_parsed_by_date_parser()
        self.then_date_obj_exactly_is(expected)

    @parameterized.expand([
        param('10 December', datetime(2015, 12, 10)),
        param('March', datetime(2015, 3, 15)),
        param('Friday', datetime(2015, 2, 13)),
        param('Sunday', datetime(2015, 2, 15)),  # current weekday
        param('10:00PM', datetime(2015, 2, 15, 22, 00)),
        param('16:10', datetime(2015, 2, 15, 16, 10)),
        param('14:05', datetime(2015, 2, 15, 14, 5)),
    ])
    def test_dates_without_preference(self, date_string, expected):
        self.given_local_tz_offset(0)
        self.given_parser(settings={'PREFER_DATES_FROM': 'current_period',
                          'RELATIVE_BASE': datetime(2015, 2, 15, 15, 30)})
        self.when_date_is_parsed(date_string)
        self.then_date_was_parsed_by_date_parser()
        self.then_date_obj_exactly_is(expected)

    @parameterized.expand([
        param('February 2015', today=datetime(2015, 1, 31), expected=datetime(2015, 2, 28)),
        param('February 2012', today=datetime(2015, 1, 31), expected=datetime(2012, 2, 29)),
        param('March 2015', today=datetime(2015, 1, 25), expected=datetime(2015, 3, 25)),
        param('April 2015', today=datetime(2015, 1, 31), expected=datetime(2015, 4, 30)),
        param('April 2015', today=datetime(2015, 2, 28), expected=datetime(2015, 4, 28)),
        param('December 2014', today=datetime(2015, 2, 15), expected=datetime(2014, 12, 15)),
    ])
    def test_dates_with_day_missing_prefering_current_day_of_month(
            self, date_string, today=None, expected=None):
        self.given_parser(settings={'PREFER_DAY_OF_MONTH': 'current', 'RELATIVE_BASE': today})
        self.when_date_is_parsed(date_string)
        self.then_date_was_parsed_by_date_parser()
        self.then_date_obj_exactly_is(expected)

    @parameterized.expand([
        param('February 2015', today=datetime(2015, 1, 1), expected=datetime(2015, 2, 28)),
        param('February 2012', today=datetime(2015, 1, 1), expected=datetime(2012, 2, 29)),
        param('March 2015', today=datetime(2015, 1, 25), expected=datetime(2015, 3, 31)),
        param('April 2015', today=datetime(2015, 1, 15), expected=datetime(2015, 4, 30)),
        param('April 2015', today=datetime(2015, 2, 28), expected=datetime(2015, 4, 30)),
        param('December 2014', today=datetime(2015, 2, 15), expected=datetime(2014, 12, 31)),
    ])
    def test_dates_with_day_missing_prefering_last_day_of_month(
            self, date_string, today=None, expected=None):
        self.given_parser(settings={'PREFER_DAY_OF_MONTH': 'last', 'RELATIVE_BASE': today})
        self.when_date_is_parsed(date_string)
        self.then_date_was_parsed_by_date_parser()
        self.then_date_obj_exactly_is(expected)

    @parameterized.expand([
        param('February 2015', today=datetime(2015, 1, 8), expected=datetime(2015, 2, 1)),
        param('February 2012', today=datetime(2015, 1, 7), expected=datetime(2012, 2, 1)),
        param('March 2015', today=datetime(2015, 1, 25), expected=datetime(2015, 3, 1)),
        param('April 2015', today=datetime(2015, 1, 15), expected=datetime(2015, 4, 1)),
        param('April 2015', today=datetime(2015, 2, 28), expected=datetime(2015, 4, 1)),
        param('December 2014', today=datetime(2015, 2, 15), expected=datetime(2014, 12, 1)),
    ])
    def test_dates_with_day_missing_prefering_first_day_of_month(
            self, date_string, today=None, expected=None):
        self.given_parser(settings={'PREFER_DAY_OF_MONTH': 'first', 'RELATIVE_BASE': today})
        self.when_date_is_parsed(date_string)
        self.then_date_was_parsed_by_date_parser()
        self.then_date_obj_exactly_is(expected)

    @parameterized.expand([
        param(prefer_day_of_month='current'),
        param(prefer_day_of_month='last'),
        param(prefer_day_of_month='first'),
    ])
    def test_that_day_preference_does_not_affect_dates_with_explicit_day(
            self, prefer_day_of_month=None):
        self.given_parser(settings={'PREFER_DAY_OF_MONTH': prefer_day_of_month,
                          'RELATIVE_BASE': datetime(2015, 2, 12)})
        self.when_date_is_parsed('24 April 2012')
        self.then_date_was_parsed_by_date_parser()
        self.then_date_obj_exactly_is(datetime(2012, 4, 24))

    def test_date_is_parsed_when_skip_tokens_are_supplied(self):
        self.given_parser(settings={'SKIP_TOKENS': ['de'], 'RELATIVE_BASE': datetime(2015, 2, 12)})
        self.when_date_is_parsed('24 April 2012 de')
        self.then_date_was_parsed_by_date_parser()
        self.then_date_obj_exactly_is(datetime(2012, 4, 24))

    @parameterized.expand([
        param('29 February 2015', 'day must be in 1..28'),
        param('32 January 2015', 'day must be in 1..31'),
        param('31 April 2015', 'day must be in 1..30'),
        param('31 June 2015', 'day must be in 1..30'),
        param('31 September 2015', 'day must be in 1..30'),
    ])
    def test_error_should_be_raised_for_invalid_dates_with_too_large_day_number(self, date_string, message):
        self.when_date_is_parsed_by_date_parser(date_string)
        self.then_error_was_raised(ValueError, ['day is out of range for month', message])

    @parameterized.expand([
        param('2015-05-02T10:20:19+0000', languages=['fr'],
              expected=datetime(2015, 5, 2, 10, 20, 19)),
        param('2015-05-02T10:20:19+0000', languages=['en'],
              expected=datetime(2015, 5, 2, 10, 20, 19)),
    ])
    def test_iso_datestamp_format_should_always_parse(self, date_string, languages, expected):
        self.given_local_tz_offset(0)
        self.given_parser(languages=languages, settings={'PREFER_LOCALE_DATE_ORDER': False})
        self.when_date_is_parsed(date_string)
        self.then_date_was_parsed_by_date_parser()
        self.result['date_obj'] = self.result['date_obj'].replace(tzinfo=None)
        self.then_date_obj_exactly_is(expected)

    @parameterized.expand([
        # Epoch timestamps.
        param('1484823450', expected=datetime(2017, 1, 19, 10, 57, 30)),
        param('1436745600000', expected=datetime(2015, 7, 13, 0, 0)),
        param('1015673450', expected=datetime(2002, 3, 9, 11, 30, 50)),
        param('2016-09-23T02:54:32.845Z', expected=datetime(2016, 9, 23, 2, 54, 32, 845000,
              tzinfo=StaticTzInfo('Z', timedelta(0))))
    ])
    def test_parse_timestamp(self, date_string, expected):
        self.given_local_tz_offset(0)
        self.given_parser(settings={'TO_TIMEZONE': 'UTC'})
        self.when_date_is_parsed(date_string)
        self.then_date_obj_exactly_is(expected)

    @parameterized.expand([
        param('10 December', expected=datetime(2015, 12, 10), period='day'),
        param('March', expected=datetime(2015, 3, 15), period='month'),
        param('April', expected=datetime(2015, 4, 15), period='month'),
        param('December', expected=datetime(2015, 12, 15), period='month'),
        param('Friday', expected=datetime(2015, 2, 13), period='day'),
        param('Monday', expected=datetime(2015, 2, 9), period='day'),
        param('10:00PM', expected=datetime(2015, 2, 15, 22, 00), period='day'),
        param('16:10', expected=datetime(2015, 2, 15, 16, 10), period='day'),
        param('2014', expected=datetime(2014, 2, 15), period='year'),
        param('2008', expected=datetime(2008, 2, 15), period='year'),
    ])
    def test_extracted_period(self, date_string, expected=None, period=None):
        self.given_local_tz_offset(0)
        self.given_parser(settings={'RELATIVE_BASE': datetime(2015, 2, 15, 15, 30)})
        self.when_date_is_parsed(date_string)
        self.then_date_was_parsed_by_date_parser()
        self.then_date_obj_exactly_is(expected)
        self.then_period_is(period)

    @parameterized.expand([
        param('12th December 2019 19:00', expected=datetime(2019, 12, 12, 19, 0), period='time'),
        param('9 Jan 11 0:00', expected=datetime(2011, 1, 9, 0, 0), period='time'),
    ])
    def test_period_is_time_if_return_time_as_period_setting_applied_and_time_component_present(
        self, date_string, expected=None, period=None
    ):
        self.given_parser(settings={'RETURN_TIME_AS_PERIOD': True})
        self.when_date_is_parsed(date_string)
        self.then_date_was_parsed_by_date_parser()
        self.then_date_obj_exactly_is(expected)
        self.then_period_is(period)

    @parameterized.expand([
        param('16:00', expected=datetime(2018, 12, 13, 16, 0), period='time'),
        param('Monday 7:15 AM', expected=datetime(2018, 12, 10, 7, 15), period='time'),
    ])
    def test_period_is_time_if_return_time_as_period_and_relative_base_settings_applied_and_time_component_present(
        self, date_string, expected=None, period=None
    ):
        self.given_parser(settings={'RETURN_TIME_AS_PERIOD': True,
                                    'RELATIVE_BASE': datetime(2018, 12, 13, 15, 15)})
        self.when_date_is_parsed(date_string)
        self.then_date_was_parsed_by_date_parser()
        self.then_date_obj_exactly_is(expected)
        self.then_period_is(period)

    @parameterized.expand([
        param('12th March 2010', expected=datetime(2010, 3, 12, 0, 0), period='day'),
        param('21-12-19', expected=datetime(2019, 12, 21, 0, 0), period='day'),
    ])
    def test_period_is_day_if_return_time_as_period_setting_applied_and_time_component_is_not_present(
        self, date_string, expected=None, period=None
    ):
        self.given_parser(settings={'RETURN_TIME_AS_PERIOD': True})
        self.when_date_is_parsed(date_string)
        self.then_date_was_parsed_by_date_parser()
        self.then_date_obj_exactly_is(expected)
        self.then_period_is(period)

    @parameterized.expand([
        param('16:00', expected=datetime(2017, 1, 10, 16, 0), period='day'),
        param('Monday 7:15 AM', expected=datetime(2017, 1, 9, 7, 15), period='day'),
    ])
    def test_period_is_day_if_return_time_as_period_setting_not_applied(
        self, date_string, expected=None, period=None
    ):
        self.given_parser(settings={'RETURN_TIME_AS_PERIOD': False,
                                    'RELATIVE_BASE': datetime(2017, 1, 10, 15, 15)})
        self.when_date_is_parsed(date_string)
        self.then_date_was_parsed_by_date_parser()
        self.then_date_obj_exactly_is(expected)
        self.then_period_is(period)

    @parameterized.expand([
        param('15-12-18 06:00', expected=datetime(2015, 12, 18, 6, 0), order='YMD'),
        param('15-18-12 06:00', expected=datetime(2015, 12, 18, 6, 0), order='YDM'),
        param('10-11-12 06:00', expected=datetime(2012, 10, 11, 6, 0), order='MDY'),
        param('10-11-12 06:00', expected=datetime(2011, 10, 12, 6, 0), order='MYD'),
        param('10-11-12 06:00', expected=datetime(2011, 12, 10, 6, 0), order='DYM'),
        param('15-12-18 06:00', expected=datetime(2018, 12, 15, 6, 0), order='DMY'),
        param('12/09/08 04:23:15.567', expected=datetime(2008, 9, 12, 4, 23, 15, 567000),
              order='DMY'),
        param('10/9/1914 03:07:09.788888 pm', expected=datetime(1914, 10, 9, 15, 7, 9, 788888),
              order='MDY'),
        param('1-8-09 07:12:49 AM', expected=datetime(2009, 1, 8, 7, 12, 49), order='MDY'),
        param('201508', expected=datetime(2015, 8, 20, 0, 0), order='DYM'),
        param('201508', expected=datetime(2020, 8, 15, 0, 0), order='YDM'),
        param('201108', expected=datetime(2008, 11, 20, 0, 0), order='DMY'),
        param('2016 july 13.', expected=datetime(2016, 7, 13, 0, 0), order='YMD'),
        param('16 july 13.', expected=datetime(2016, 7, 13, 0, 0), order='YMD'),
        param('Sunday 23 May 1856 12:09:08 AM', expected=datetime(1856, 5, 23, 0, 9, 8),
              order='DMY'),
    ])
    def test_order(self, date_string, expected=None, order=None):
        self.given_parser(settings={'DATE_ORDER': order})
        self.when_date_is_parsed(date_string)
        self.then_date_was_parsed_by_date_parser()
        self.then_date_obj_exactly_is(expected)

    @parameterized.expand([
        param('10.1.2019', expected=datetime(2019, 1, 10, 0, 0), languages=['de'],
              settings={'PREFER_DAY_OF_MONTH': 'first'}),
        param('10.1.2019', expected=datetime(2019, 1, 10, 0, 0), languages=['de']),
        param('10.1.2019', expected=datetime(2019, 10, 1, 0, 0),
              settings={'DATE_ORDER': 'MDY'}),
    ])
    def test_if_settings_provided_date_order_is_retained(
        self, date_string, expected=None, languages=None, settings=None
    ):
        self.given_parser(languages=languages, settings=settings)
        self.when_date_is_parsed(date_string)
        self.then_date_was_parsed_by_date_parser()
        self.then_date_obj_exactly_is(expected)

    @parameterized.expand([
        param('::', None),
        param('..', None),
        param('  ', None),
        param('--', None),
        param('//', None),
        param('++', None),
    ])
    def test_parsing_strings_containing_only_separator_tokens(self, date_string, expected):
        self.given_parser()
        self.when_date_is_parsed(date_string)
        self.then_period_is('day')
        self.then_date_obj_exactly_is(expected)

    def given_local_tz_offset(self, offset):
        self.add_patch(
            patch.object(dateparser.timezone_parser,
                         'local_tz_offset',
                         new=timedelta(seconds=3600 * offset))
        )

    def given_parser(self, *args, **kwds):
        def collecting_get_date_data(parse):
            @wraps(parse)
            def wrapped(*args, **kwargs):
                self.date_result = parse(*args, **kwargs)
                return self.date_result
            return wrapped

        self.add_patch(patch.object(date_parser,
                                    'parse',
                                    collecting_get_date_data(date_parser.parse)))

        self.date_parser = Mock(wraps=date_parser)
        self.add_patch(patch('dateparser.date.date_parser', new=self.date_parser))
        self.parser = DateDataParser(*args, **kwds)

    def when_date_is_parsed(self, date_string):
        self.result = self.parser.get_date_data(date_string)

    def when_date_is_parsed_by_date_parser(self, date_string):
        try:
            self.result = DateParser().parse(date_string)
        except Exception as error:
            self.error = error

    def then_period_is(self, period):
        self.assertEqual(period, self.result['period'])

    def then_date_obj_exactly_is(self, expected):
        self.assertEqual(expected, self.result['date_obj'])

    def then_date_was_parsed_by_date_parser(self):
        self.assertNotEqual(NotImplemented, self.date_result, "Date was not parsed")
        self.assertEqual(self.result['date_obj'], self.date_result[0])

    def then_timezone_parsed_is(self, tzstr):
        self.assertTrue(tzstr in repr(self.result['date_obj'].tzinfo))
        self.result['date_obj'] = self.result['date_obj'].replace(tzinfo=None)
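

# A minimal usage sketch of the API exercised by the tests above (assumes
# the dateparser package; the expected values follow the DATE_ORDER cases
# in the parameter lists):
from dateparser.date import DateDataParser

ddp = DateDataParser(settings={'DATE_ORDER': 'MDY'})
data = ddp.get_date_data('10-11-12 06:00')
print(data['date_obj'])  # datetime(2012, 10, 11, 6, 0)
print(data['period'])    # 'day'
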

class DataflowRunnerTest(unittest.TestCase, ExtraAssertionsMixin):
    def setUp(self):
        self.default_properties = [
            '--dataflow_endpoint=ignored', '--job_name=test-job',
            '--project=test-project', '--staging_location=ignored',
            '--temp_location=/dev/null', '--no_auth', '--dry_run=True',
            '--sdk_location=container'
        ]

    @mock.patch('time.sleep', return_value=None)
    def test_wait_until_finish(self, patched_time_sleep):
        values_enum = dataflow_api.Job.CurrentStateValueValuesEnum

        class MockDataflowRunner(object):
            def __init__(self, states):
                self.dataflow_client = mock.MagicMock()
                self.job = mock.MagicMock()
                self.job.currentState = values_enum.JOB_STATE_UNKNOWN
                self._states = states
                self._next_state_index = 0

                def get_job_side_effect(*args, **kwargs):
                    self.job.currentState = self._states[
                        self._next_state_index]
                    if self._next_state_index < (len(self._states) - 1):
                        self._next_state_index += 1
                    return mock.DEFAULT

                self.dataflow_client.get_job = mock.MagicMock(
                    return_value=self.job, side_effect=get_job_side_effect)
                self.dataflow_client.list_messages = mock.MagicMock(
                    return_value=([], None))

        with self.assertRaisesRegex(DataflowRuntimeException,
                                    'Dataflow pipeline failed. State: FAILED'):
            failed_runner = MockDataflowRunner([values_enum.JOB_STATE_FAILED])
            failed_result = DataflowPipelineResult(failed_runner.job,
                                                   failed_runner)
            failed_result.wait_until_finish()

        succeeded_runner = MockDataflowRunner([values_enum.JOB_STATE_DONE])
        succeeded_result = DataflowPipelineResult(succeeded_runner.job,
                                                  succeeded_runner)
        result = succeeded_result.wait_until_finish()
        self.assertEqual(result, PipelineState.DONE)

        # The time list contains duplicate values because some logging
        # implementations also call time.time().
        with mock.patch('time.time',
                        mock.MagicMock(side_effect=[1, 1, 2, 2, 3])):
            duration_succeeded_runner = MockDataflowRunner(
                [values_enum.JOB_STATE_RUNNING, values_enum.JOB_STATE_DONE])
            duration_succeeded_result = DataflowPipelineResult(
                duration_succeeded_runner.job, duration_succeeded_runner)
            result = duration_succeeded_result.wait_until_finish(5000)
            self.assertEqual(result, PipelineState.DONE)

        with mock.patch('time.time',
                        mock.MagicMock(side_effect=[1, 9, 9, 20, 20])):
            duration_timedout_runner = MockDataflowRunner(
                [values_enum.JOB_STATE_RUNNING])
            duration_timedout_result = DataflowPipelineResult(
                duration_timedout_runner.job, duration_timedout_runner)
            result = duration_timedout_result.wait_until_finish(5000)
            self.assertEqual(result, PipelineState.RUNNING)

        with mock.patch('time.time',
                        mock.MagicMock(side_effect=[1, 1, 2, 2, 3])):
            with self.assertRaisesRegex(
                    DataflowRuntimeException,
                    'Dataflow pipeline failed. State: CANCELLED'):
                duration_failed_runner = MockDataflowRunner(
                    [values_enum.JOB_STATE_CANCELLED])
                duration_failed_result = DataflowPipelineResult(
                    duration_failed_runner.job, duration_failed_runner)
                duration_failed_result.wait_until_finish(5000)
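
    # Note on the time.time patches above: a MagicMock with a side_effect
    # list returns the next value on each call, so elapsed time advances
    # deterministically. In the timed-out case, side_effect=[1, 9, 9, 20, 20]
    # yields roughly 8 and then 19 seconds of elapsed time, which pushes the
    # wait_until_finish(5000) call (a 5000 ms deadline) past its limit.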

    @mock.patch('time.sleep', return_value=None)
    def test_cancel(self, patched_time_sleep):
        values_enum = dataflow_api.Job.CurrentStateValueValuesEnum

        class MockDataflowRunner(object):
            def __init__(self, state, cancel_result):
                self.dataflow_client = mock.MagicMock()
                self.job = mock.MagicMock()
                self.job.currentState = state

                self.dataflow_client.get_job = mock.MagicMock(
                    return_value=self.job)
                self.dataflow_client.modify_job_state = mock.MagicMock(
                    return_value=cancel_result)
                self.dataflow_client.list_messages = mock.MagicMock(
                    return_value=([], None))

        with self.assertRaisesRegex(DataflowRuntimeException,
                                    'Failed to cancel job'):
            failed_runner = MockDataflowRunner(values_enum.JOB_STATE_RUNNING,
                                               False)
            failed_result = DataflowPipelineResult(failed_runner.job,
                                                   failed_runner)
            failed_result.cancel()

        succeeded_runner = MockDataflowRunner(values_enum.JOB_STATE_RUNNING,
                                              True)
        succeeded_result = DataflowPipelineResult(succeeded_runner.job,
                                                  succeeded_runner)
        succeeded_result.cancel()

        terminal_runner = MockDataflowRunner(values_enum.JOB_STATE_DONE, False)
        terminal_result = DataflowPipelineResult(terminal_runner.job,
                                                 terminal_runner)
        terminal_result.cancel()

    def test_create_runner(self):
        self.assertTrue(
            isinstance(create_runner('DataflowRunner'), DataflowRunner))
        self.assertTrue(
            isinstance(create_runner('TestDataflowRunner'),
                       TestDataflowRunner))

    def test_environment_override_translation_legacy_worker_harness_image(
            self):
        self.default_properties.append('--experiments=beam_fn_api')
        self.default_properties.append(
            '--worker_harness_container_image=LEGACY')
        remote_runner = DataflowRunner()
        with Pipeline(remote_runner,
                      options=PipelineOptions(self.default_properties)) as p:
            (  # pylint: disable=expression-not-assigned
                p | ptransform.Create([1, 2, 3])
                | 'Do' >> ptransform.FlatMap(lambda x: [(x, x)])
                | ptransform.GroupByKey())
        self.assertEqual(
            list(
                remote_runner.proto_pipeline.components.environments.values()),
            [
                beam_runner_api_pb2.Environment(
                    urn=common_urns.environments.DOCKER.urn,
                    payload=beam_runner_api_pb2.DockerPayload(
                        container_image='LEGACY').SerializeToString(),
                    capabilities=environments.python_sdk_docker_capabilities())
            ])

    def test_environment_override_translation_sdk_container_image(self):
        self.default_properties.append('--experiments=beam_fn_api')
        self.default_properties.append('--sdk_container_image=FOO')
        remote_runner = DataflowRunner()
        with Pipeline(remote_runner,
                      options=PipelineOptions(self.default_properties)) as p:
            (  # pylint: disable=expression-not-assigned
                p | ptransform.Create([1, 2, 3])
                | 'Do' >> ptransform.FlatMap(lambda x: [(x, x)])
                | ptransform.GroupByKey())
        self.assertEqual(
            list(
                remote_runner.proto_pipeline.components.environments.values()),
            [
                beam_runner_api_pb2.Environment(
                    urn=common_urns.environments.DOCKER.urn,
                    payload=beam_runner_api_pb2.DockerPayload(
                        container_image='FOO').SerializeToString(),
                    capabilities=environments.python_sdk_docker_capabilities())
            ])

    def test_remote_runner_translation(self):
        remote_runner = DataflowRunner()
        with Pipeline(remote_runner,
                      options=PipelineOptions(self.default_properties)) as p:

            (  # pylint: disable=expression-not-assigned
                p | ptransform.Create([1, 2, 3])
                | 'Do' >> ptransform.FlatMap(lambda x: [(x, x)])
                | ptransform.GroupByKey())

    def test_streaming_create_translation(self):
        remote_runner = DataflowRunner()
        self.default_properties.append("--streaming")
        self.default_properties.append("--experiments=disable_runner_v2")
        with Pipeline(remote_runner,
                      PipelineOptions(self.default_properties)) as p:
            p | ptransform.Create([1])  # pylint: disable=expression-not-assigned
        job_dict = json.loads(str(remote_runner.job))
        self.assertEqual(len(job_dict[u'steps']), 3)

        self.assertEqual(job_dict[u'steps'][0][u'kind'], u'ParallelRead')
        self.assertEqual(
            job_dict[u'steps'][0][u'properties'][u'pubsub_subscription'],
            '_starting_signal/')
        self.assertEqual(job_dict[u'steps'][1][u'kind'], u'ParallelDo')
        self.assertEqual(job_dict[u'steps'][2][u'kind'], u'ParallelDo')

    def test_bigquery_read_fn_api_fail(self):
        remote_runner = DataflowRunner()
        for flag in ['beam_fn_api', 'use_unified_worker', 'use_runner_v2']:
            self.default_properties.append("--experiments=%s" % flag)
            with self.assertRaisesRegex(
                    ValueError, 'The Read.BigQuerySource.*is not supported.*'
                    'apache_beam.io.gcp.bigquery.ReadFromBigQuery.*'):
                with Pipeline(remote_runner,
                              PipelineOptions(self.default_properties)) as p:
                    _ = p | beam.io.Read(
                        beam.io.BigQuerySource(
                            'some.table', use_dataflow_native_source=True))

    def test_remote_runner_display_data(self):
        remote_runner = DataflowRunner()
        p = Pipeline(remote_runner,
                     options=PipelineOptions(self.default_properties))

        now = datetime.now()
        # pylint: disable=expression-not-assigned
        (p | ptransform.Create([1, 2, 3, 4, 5])
         | 'Do' >> SpecialParDo(SpecialDoFn(), now))

        # TODO(BEAM-366) Enable runner API on this test.
        p.run(test_runner_api=False)
        job_dict = json.loads(str(remote_runner.job))
        steps = [
            step for step in job_dict['steps']
            if len(step['properties'].get('display_data', [])) > 0
        ]
        step = steps[1]
        disp_data = step['properties']['display_data']
        nspace = SpecialParDo.__module__ + '.'
        expected_data = [{
            'type':
            'TIMESTAMP',
            'namespace':
            nspace + 'SpecialParDo',
            'value':
            DisplayDataItem._format_value(now, 'TIMESTAMP'),
            'key':
            'a_time'
        }, {
            'type': 'STRING',
            'namespace': nspace + 'SpecialParDo',
            'value': nspace + 'SpecialParDo',
            'key': 'a_class',
            'shortValue': 'SpecialParDo'
        }, {
            'type': 'INTEGER',
            'namespace': nspace + 'SpecialDoFn',
            'value': 42,
            'key': 'dofn_value'
        }]
        self.assertUnhashableCountEqual(disp_data, expected_data)

    def test_no_group_by_key_directly_after_bigquery(self):
        remote_runner = DataflowRunner()
        with self.assertRaises(ValueError,
                               msg=('Coder for the GroupByKey operation'
                                    '"GroupByKey" is not a key-value coder: '
                                    'RowAsDictJsonCoder')):
            with beam.Pipeline(runner=remote_runner,
                               options=PipelineOptions(
                                   self.default_properties)) as p:
                # pylint: disable=expression-not-assigned
                p | beam.io.Read(
                    beam.io.BigQuerySource(
                        'dataset.faketable',
                        use_dataflow_native_source=True)) | beam.GroupByKey()

    def test_group_by_key_input_visitor_with_valid_inputs(self):
        p = TestPipeline()
        pcoll1 = PCollection(p)
        pcoll2 = PCollection(p)
        pcoll3 = PCollection(p)

        pcoll1.element_type = None
        pcoll2.element_type = typehints.Any
        pcoll3.element_type = typehints.KV[typehints.Any, typehints.Any]
        for pcoll in [pcoll1, pcoll2, pcoll3]:
            applied = AppliedPTransform(None, beam.GroupByKey(), "label",
                                        {'pcoll': pcoll})
            applied.outputs[None] = PCollection(None)
            common.group_by_key_input_visitor().visit_transform(applied)
            self.assertEqual(pcoll.element_type, typehints.KV[typehints.Any,
                                                              typehints.Any])

    def test_group_by_key_input_visitor_with_invalid_inputs(self):
        p = TestPipeline()
        pcoll1 = PCollection(p)
        pcoll2 = PCollection(p)

        pcoll1.element_type = str
        pcoll2.element_type = typehints.Set
        err_msg = (r"Input to 'label' must be compatible with KV\[Any, Any\]. "
                   "Found .*")
        for pcoll in [pcoll1, pcoll2]:
            with self.assertRaisesRegex(ValueError, err_msg):
                common.group_by_key_input_visitor().visit_transform(
                    AppliedPTransform(None, beam.GroupByKey(), "label",
                                      {'in': pcoll}))

    def test_group_by_key_input_visitor_for_non_gbk_transforms(self):
        p = TestPipeline()
        pcoll = PCollection(p)
        for transform in [beam.Flatten(), beam.Map(lambda x: x)]:
            pcoll.element_type = typehints.Any
            common.group_by_key_input_visitor().visit_transform(
                AppliedPTransform(None, transform, "label", {'in': pcoll}))
            self.assertEqual(pcoll.element_type, typehints.Any)

    def test_flatten_input_with_visitor_with_single_input(self):
        self._test_flatten_input_visitor(typehints.KV[int, int], typehints.Any,
                                         1)

    def test_flatten_input_with_visitor_with_multiple_inputs(self):
        self._test_flatten_input_visitor(typehints.KV[int, typehints.Any],
                                         typehints.Any, 5)

    def _test_flatten_input_visitor(self, input_type, output_type, num_inputs):
        p = TestPipeline()
        inputs = {}
        for ix in range(num_inputs):
            input_pcoll = PCollection(p)
            input_pcoll.element_type = input_type
            inputs[str(ix)] = input_pcoll
        output_pcoll = PCollection(p)
        output_pcoll.element_type = output_type

        flatten = AppliedPTransform(None, beam.Flatten(), "label", inputs)
        flatten.add_output(output_pcoll, None)
        DataflowRunner.flatten_input_visitor().visit_transform(flatten)
        for _ in range(num_inputs):
            self.assertEqual(inputs['0'].element_type, output_type)

    def test_gbk_then_flatten_input_visitor(self):
        p = TestPipeline(runner=DataflowRunner(),
                         options=PipelineOptions(self.default_properties))
        none_str_pc = p | 'c1' >> beam.Create({None: 'a'})
        none_int_pc = p | 'c2' >> beam.Create({None: 3})
        flat = (none_str_pc, none_int_pc) | beam.Flatten()
        _ = flat | beam.GroupByKey()

        # This may change if type inference changes, but we assert it here
        # to make sure the check below is not vacuous.
        self.assertNotIsInstance(flat.element_type, typehints.TupleConstraint)

        p.visit(common.group_by_key_input_visitor())
        p.visit(DataflowRunner.flatten_input_visitor())

        # The dataflow runner requires gbk input to be tuples *and* flatten
        # inputs to be equal to their outputs. Assert both hold.
        self.assertIsInstance(flat.element_type, typehints.TupleConstraint)
        self.assertEqual(flat.element_type, none_str_pc.element_type)
        self.assertEqual(flat.element_type, none_int_pc.element_type)

    def test_serialize_windowing_strategy(self):
        # This just tests the basic path; more complete tests
        # are in window_test.py.
        strategy = Windowing(window.FixedWindows(10))
        self.assertEqual(
            strategy,
            DataflowRunner.deserialize_windowing_strategy(
                DataflowRunner.serialize_windowing_strategy(strategy, None)))

    def test_side_input_visitor(self):
        p = TestPipeline()
        pc = p | beam.Create([])

        transform = beam.Map(lambda x, y, z: (x, y, z),
                             beam.pvalue.AsSingleton(pc),
                             beam.pvalue.AsMultiMap(pc))
        applied_transform = AppliedPTransform(None, transform, "label",
                                              {'pc': pc})
        DataflowRunner.side_input_visitor(
            use_fn_api=True).visit_transform(applied_transform)
        self.assertEqual(2, len(applied_transform.side_inputs))
        for side_input in applied_transform.side_inputs:
            self.assertEqual(common_urns.side_inputs.MULTIMAP.urn,
                             side_input._side_input_data().access_pattern)

    def test_min_cpu_platform_flag_is_propagated_to_experiments(self):
        remote_runner = DataflowRunner()
        self.default_properties.append('--min_cpu_platform=Intel Haswell')

        with Pipeline(remote_runner,
                      PipelineOptions(self.default_properties)) as p:
            p | ptransform.Create([1])  # pylint: disable=expression-not-assigned
        self.assertIn(
            'min_cpu_platform=Intel Haswell',
            remote_runner.job.options.view_as(DebugOptions).experiments)

    def test_streaming_engine_flag_adds_windmill_experiments(self):
        remote_runner = DataflowRunner()
        self.default_properties.append('--streaming')
        self.default_properties.append('--enable_streaming_engine')
        self.default_properties.append('--experiment=some_other_experiment')

        with Pipeline(remote_runner,
                      PipelineOptions(self.default_properties)) as p:
            p | ptransform.Create([1])  # pylint: disable=expression-not-assigned

        experiments_for_job = (
            remote_runner.job.options.view_as(DebugOptions).experiments)
        self.assertIn('enable_streaming_engine', experiments_for_job)
        self.assertIn('enable_windmill_service', experiments_for_job)
        self.assertIn('some_other_experiment', experiments_for_job)

    def test_upload_graph_experiment(self):
        remote_runner = DataflowRunner()
        self.default_properties.append('--experiment=upload_graph')

        with Pipeline(remote_runner,
                      PipelineOptions(self.default_properties)) as p:
            p | ptransform.Create([1])  # pylint: disable=expression-not-assigned

        experiments_for_job = (
            remote_runner.job.options.view_as(DebugOptions).experiments)
        self.assertIn('upload_graph', experiments_for_job)

    def test_dataflow_worker_jar_flag_non_fnapi_noop(self):
        remote_runner = DataflowRunner()
        self.default_properties.append('--experiment=some_other_experiment')
        self.default_properties.append('--dataflow_worker_jar=test.jar')

        with Pipeline(remote_runner,
                      PipelineOptions(self.default_properties)) as p:
            p | ptransform.Create([1])  # pylint: disable=expression-not-assigned

        experiments_for_job = (
            remote_runner.job.options.view_as(DebugOptions).experiments)
        self.assertIn('some_other_experiment', experiments_for_job)
        self.assertNotIn('use_staged_dataflow_worker_jar', experiments_for_job)

    def test_dataflow_worker_jar_flag_adds_use_staged_worker_jar_experiment(
            self):
        remote_runner = DataflowRunner()
        self.default_properties.append('--experiment=beam_fn_api')
        self.default_properties.append('--dataflow_worker_jar=test.jar')

        with Pipeline(remote_runner,
                      PipelineOptions(self.default_properties)) as p:
            p | ptransform.Create([1])  # pylint: disable=expression-not-assigned

        experiments_for_job = (
            remote_runner.job.options.view_as(DebugOptions).experiments)
        self.assertIn('beam_fn_api', experiments_for_job)
        self.assertIn('use_staged_dataflow_worker_jar', experiments_for_job)

    def test_use_fastavro_experiment_is_not_added_when_use_avro_is_present(
            self):
        remote_runner = DataflowRunner()
        self.default_properties.append('--experiment=use_avro')

        with Pipeline(remote_runner,
                      PipelineOptions(self.default_properties)) as p:
            p | ptransform.Create([1])  # pylint: disable=expression-not-assigned

        debug_options = remote_runner.job.options.view_as(DebugOptions)

        self.assertFalse(debug_options.lookup_experiment(
            'use_fastavro', False))

    @mock.patch('os.environ.get', return_value=None)
    @mock.patch('apache_beam.utils.processes.check_output', return_value=b'')
    def test_get_default_gcp_region_no_default_returns_none(
            self, patched_environ, patched_processes):
        runner = DataflowRunner()
        result = runner.get_default_gcp_region()
        self.assertIsNone(result)

    @mock.patch('os.environ.get', return_value='some-region1')
    @mock.patch('apache_beam.utils.processes.check_output', return_value=b'')
    def test_get_default_gcp_region_from_environ(self, patched_environ,
                                                 patched_processes):
        runner = DataflowRunner()
        result = runner.get_default_gcp_region()
        self.assertEqual(result, 'some-region1')

    @mock.patch('os.environ.get', return_value=None)
    @mock.patch('apache_beam.utils.processes.check_output',
                return_value=b'some-region2\n')
    def test_get_default_gcp_region_from_gcloud(self, patched_environ,
                                                patched_processes):
        runner = DataflowRunner()
        result = runner.get_default_gcp_region()
        self.assertEqual(result, 'some-region2')

    @mock.patch('os.environ.get', return_value=None)
    @mock.patch('apache_beam.utils.processes.check_output',
                side_effect=RuntimeError('Executable gcloud not found'))
    def test_get_default_gcp_region_ignores_error(self, patched_environ,
                                                  patched_processes):
        runner = DataflowRunner()
        result = runner.get_default_gcp_region()
        self.assertIsNone(result)

    def test_combine_values_translation(self):
        runner = DataflowRunner()

        with beam.Pipeline(runner=runner,
                           options=PipelineOptions(
                               self.default_properties)) as p:
            (  # pylint: disable=expression-not-assigned
                p
                | beam.Create([('a', [1, 2]), ('b', [3, 4])])
                | beam.CombineValues(lambda v, _: sum(v)))

        job_dict = json.loads(str(runner.job))
        self.assertIn(u'CombineValues',
                      set(step[u'kind'] for step in job_dict[u'steps']))

    def _find_step(self, job, step_name):
        job_dict = json.loads(str(job))
        maybe_step = [
            s for s in job_dict[u'steps']
            if s[u'properties'][u'user_name'] == step_name
        ]
        self.assertTrue(maybe_step, 'Could not find step {}'.format(step_name))
        return maybe_step[0]

    def expect_correct_override(self, job, step_name, step_kind):
        """Expects that a transform was correctly overriden."""

        # If the typing information isn't being forwarded correctly, the component
        # encodings here will be incorrect.
        expected_output_info = [{
            "encoding": {
                "@type":
                "kind:windowed_value",
                "component_encodings": [{
                    "@type": "kind:bytes"
                }, {
                    "@type": "kind:global_window"
                }],
                "is_wrapper":
                True
            },
            "output_name": "out",
            "user_name": step_name + ".out"
        }]

        step = self._find_step(job, step_name)
        self.assertEqual(step[u'kind'], step_kind)

        # The display data here is forwarded because the replace transform is
        # subclassed from iobase.Read.
        self.assertGreater(len(step[u'properties']['display_data']), 0)
        self.assertEqual(step[u'properties']['output_info'],
                         expected_output_info)

    def test_read_create_translation(self):
        runner = DataflowRunner()

        with beam.Pipeline(runner=runner,
                           options=PipelineOptions(
                               self.default_properties)) as p:
            # pylint: disable=expression-not-assigned
            p | beam.Create([b'a', b'b', b'c'])

        self.expect_correct_override(runner.job, u'Create/Read',
                                     u'ParallelRead')

    def test_read_bigquery_translation(self):
        runner = DataflowRunner()

        with beam.Pipeline(runner=runner,
                           options=PipelineOptions(
                               self.default_properties)) as p:
            # pylint: disable=expression-not-assigned
            p | beam.io.Read(
                beam.io.BigQuerySource('some.table',
                                       coder=BytesCoder(),
                                       use_dataflow_native_source=True))

        self.expect_correct_override(runner.job, u'Read', u'ParallelRead')

    def test_read_pubsub_translation(self):
        runner = DataflowRunner()

        self.default_properties.append("--streaming")

        with beam.Pipeline(runner=runner,
                           options=PipelineOptions(
                               self.default_properties)) as p:
            # pylint: disable=expression-not-assigned
            p | beam.io.ReadFromPubSub(topic='projects/project/topics/topic')

        self.expect_correct_override(runner.job, u'ReadFromPubSub/Read',
                                     u'ParallelRead')

    def test_gbk_translation(self):
        runner = DataflowRunner()
        with beam.Pipeline(runner=runner,
                           options=PipelineOptions(
                               self.default_properties)) as p:
            # pylint: disable=expression-not-assigned
            p | beam.Create([(1, 2)]) | beam.GroupByKey()

        expected_output_info = [{
            "encoding": {
                "@type": "kind:windowed_value",
                "component_encodings": [{
                    "@type": "kind:pair",
                    "component_encodings": [{
                        "@type": "kind:varint"
                    },
                    {
                        "@type": "kind:stream",
                        "component_encodings": [{
                            "@type": "kind:varint"
                        }],
                        "is_stream_like": True
                    }],
                    "is_pair_like": True
                }, {
                    "@type": "kind:global_window"
                }],
                "is_wrapper": True
            },
            "output_name": "out",
            "user_name": "GroupByKey.out"
        }]  # yapf: disable

        gbk_step = self._find_step(runner.job, u'GroupByKey')
        self.assertEqual(gbk_step[u'kind'], u'GroupByKey')
        self.assertEqual(gbk_step[u'properties']['output_info'],
                         expected_output_info)

    def test_write_bigquery_translation(self):
        runner = DataflowRunner()

        self.default_properties.append('--experiments=use_legacy_bq_sink')
        with beam.Pipeline(runner=runner,
                           options=PipelineOptions(
                               self.default_properties)) as p:
            # pylint: disable=expression-not-assigned
            p | beam.Create([1]) | beam.io.WriteToBigQuery('some.table')

        job_dict = json.loads(str(runner.job))

        expected_step = {
            "kind": "ParallelWrite",
            "name": "s2",
            "properties": {
                "create_disposition": "CREATE_IF_NEEDED",
                "dataset": "some",
                "display_data": [],
                "encoding": {
                    "@type":
                    "kind:windowed_value",
                    "component_encodings": [{
                        "component_encodings": [],
                        "pipeline_proto_coder_id":
                        "ref_Coder_RowAsDictJsonCoder_4"
                    }, {
                        "@type": "kind:global_window"
                    }],
                    "is_wrapper":
                    True
                },
                "format": "bigquery",
                "parallel_input": {
                    "@type": "OutputReference",
                    "output_name": "out",
                    "step_name": "s1"
                },
                "table": "table",
                "user_name": "WriteToBigQuery/Write/NativeWrite",
                "write_disposition": "WRITE_APPEND"
            }
        }
        job_dict = json.loads(str(runner.job))
        write_step = [
            s for s in job_dict[u'steps']
            if s[u'properties'][u'user_name'].startswith('WriteToBigQuery')
        ][0]

        # Delete the @type field because in this case it is a hash which may change
        # depending on the pickling version.
        step_encoding = write_step[u'properties'][u'encoding']
        del step_encoding[u'component_encodings'][0][u'@type']
        self.assertEqual(expected_step, write_step)

    def test_write_bigquery_failed_translation(self):
        """Tests that WriteToBigQuery cannot have any consumers if replaced."""
        runner = DataflowRunner()

        self.default_properties.append('--experiments=use_legacy_bq_sink')
        with self.assertRaises(Exception):
            with beam.Pipeline(runner=runner,
                               options=PipelineOptions(
                                   self.default_properties)) as p:
                # pylint: disable=expression-not-assigned
                out = p | beam.Create(
                    [1]) | beam.io.WriteToBigQuery('some.table')
                out['destination_file_pairs'] | 'MyTransform' >> beam.Map(
                    lambda _: _)

    @unittest.skip('BEAM-3736: enable once CombineFnVisitor is fixed')
    def test_unsupported_combinefn_detection(self):
        class CombinerWithNonDefaultSetupTeardown(combiners.CountCombineFn):
            def setup(self, *args, **kwargs):
                pass

            def teardown(self, *args, **kwargs):
                pass

        runner = DataflowRunner()
        with self.assertRaisesRegex(
                ValueError, 'CombineFn.setup and CombineFn.'
                'teardown are not supported'):
            with beam.Pipeline(runner=runner,
                               options=PipelineOptions(
                                   self.default_properties)) as p:
                _ = (p | beam.Create([1])
                     | beam.CombineGlobally(
                         CombinerWithNonDefaultSetupTeardown()))

        try:
            with beam.Pipeline(runner=runner,
                               options=PipelineOptions(
                                   self.default_properties)) as p:
                _ = (p | beam.Create([1])
                     | beam.CombineGlobally(
                         combiners.SingleInputTupleCombineFn(
                             combiners.CountCombineFn(),
                             combiners.CountCombineFn())))
        except ValueError:
            self.fail('ValueError raised unexpectedly')

    def _run_group_into_batches_and_get_step_properties(
            self, with_sharded_key, additional_properties):
        self.default_properties.append('--streaming')
        for prop in additional_properties:
            self.default_properties.append(prop)

        runner = DataflowRunner()
        with beam.Pipeline(runner=runner,
                           options=PipelineOptions(
                               self.default_properties)) as p:
            # pylint: disable=expression-not-assigned
            input = p | beam.Create([('a', 1), ('a', 1), ('b', 3), ('b', 4)])
            if with_sharded_key:
                (input | beam.GroupIntoBatches.WithShardedKey(2)
                 | beam.Map(lambda key_values:
                            (key_values[0].key, key_values[1])))
                step_name = (
                    u'WithShardedKey/GroupIntoBatches/ParDo(_GroupIntoBatchesDoFn)'
                )
            else:
                input | beam.GroupIntoBatches(2)
                step_name = u'GroupIntoBatches/ParDo(_GroupIntoBatchesDoFn)'

        return self._find_step(runner.job, step_name)['properties']

    def test_group_into_batches_translation(self):
        properties = self._run_group_into_batches_and_get_step_properties(
            True, ['--enable_streaming_engine', '--experiments=use_runner_v2'])
        self.assertEqual(properties[PropertyNames.USES_KEYED_STATE], u'true')
        self.assertEqual(properties[PropertyNames.ALLOWS_SHARDABLE_STATE],
                         u'true')
        self.assertEqual(properties[PropertyNames.PRESERVES_KEYS], u'true')

    def test_group_into_batches_translation_non_sharded(self):
        properties = self._run_group_into_batches_and_get_step_properties(
            False,
            ['--enable_streaming_engine', '--experiments=use_runner_v2'])
        self.assertEqual(properties[PropertyNames.USES_KEYED_STATE], u'true')
        self.assertNotIn(PropertyNames.ALLOWS_SHARDABLE_STATE, properties)
        self.assertNotIn(PropertyNames.PRESERVES_KEYS, properties)

    def test_group_into_batches_translation_non_se(self):
        with self.assertRaisesRegex(
                ValueError,
                'Runner determined sharding not available in Dataflow for '
                'GroupIntoBatches for non-Streaming-Engine jobs'):
            _ = self._run_group_into_batches_and_get_step_properties(
                True, ['--experiments=use_runner_v2'])

    def test_group_into_batches_translation_non_unified_worker(self):
        # non-portable
        with self.assertRaisesRegex(
                ValueError,
                'Runner determined sharding not available in Dataflow for '
                'GroupIntoBatches for jobs not using Runner V2'):
            _ = self._run_group_into_batches_and_get_step_properties(
                True, [
                    '--enable_streaming_engine',
                    '--experiments=disable_runner_v2'
                ])

        # JRH
        with self.assertRaisesRegex(
                ValueError,
                'Runner determined sharding not available in Dataflow for '
                'GroupIntoBatches for jobs not using Runner V2'):
            _ = self._run_group_into_batches_and_get_step_properties(
                True, [
                    '--enable_streaming_engine', '--experiments=beam_fn_api',
                    '--experiments=disable_runner_v2'
                ])

    def test_pack_combiners(self):
        class PackableCombines(beam.PTransform):
            def annotations(self):
                return {python_urns.APPLY_COMBINER_PACKING: b''}

            def expand(self, pcoll):
                _ = pcoll | 'PackableMin' >> beam.CombineGlobally(min)
                _ = pcoll | 'PackableMax' >> beam.CombineGlobally(max)

        runner = DataflowRunner()
        with beam.Pipeline(runner=runner,
                           options=PipelineOptions(
                               self.default_properties)) as p:
            _ = p | beam.Create([10, 20, 30]) | PackableCombines()

        unpacked_minimum_step_name = (
            'PackableCombines/PackableMin/CombinePerKey/Combine')
        unpacked_maximum_step_name = (
            'PackableCombines/PackableMax/CombinePerKey/Combine')
        packed_step_name = (
            'PackableCombines/Packed[PackableMin_CombinePerKey, '
            'PackableMax_CombinePerKey]/Pack')
        transform_names = set(
            transform.unique_name for transform in
            runner.proto_pipeline.components.transforms.values())
        self.assertNotIn(unpacked_minimum_step_name, transform_names)
        self.assertNotIn(unpacked_maximum_step_name, transform_names)
        self.assertIn(packed_step_name, transform_names)

    @parameterized.expand([
        param(memory_hint='min_ram'),
        param(memory_hint='minRam'),
    ])
    def test_resource_hints_translation(self, memory_hint):
        runner = DataflowRunner()
        self.default_properties.append('--resource_hint=accelerator=some_gpu')
        self.default_properties.append(f'--resource_hint={memory_hint}=20GB')
        with beam.Pipeline(runner=runner,
                           options=PipelineOptions(
                               self.default_properties)) as p:
            # pylint: disable=expression-not-assigned
            (p
             | beam.Create([1])
             | 'MapWithHints' >> beam.Map(lambda x: x + 1).with_resource_hints(
                 min_ram='10GB',
                 accelerator=
                 'type:nvidia-tesla-k80;count:1;install-nvidia-drivers'))

        step = self._find_step(runner.job, 'MapWithHints')
        self.assertEqual(
            step['properties']['resource_hints'],
            {
                'beam:resources:min_ram_bytes:v1': '20000000000',
                'beam:resources:accelerator:v1': \
                    'type%3Anvidia-tesla-k80%3Bcount%3A1%3Binstall-nvidia-drivers'
            })
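
# A hedged sketch of the visitor mechanism several tests above drive by
# hand: Pipeline.visit() calls visit_transform once per AppliedPTransform.
import apache_beam as beam
from apache_beam.pipeline import PipelineVisitor

class CountingVisitor(PipelineVisitor):
    def __init__(self):
        self.labels = []

    def visit_transform(self, applied_ptransform):
        # Record every transform the traversal reaches.
        self.labels.append(applied_ptransform.full_label)

p = beam.Pipeline()
_ = p | beam.Create([1, 2, 3]) | beam.Map(lambda x: x + 1)
visitor = CountingVisitor()
p.visit(visitor)
print(visitor.labels)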
Esempio n. 7
0
class TimeZoneSettingsTest(BaseTestCase):
    def setUp(self):
        super().setUp()
        self.given_ds = NotImplemented
        self.result = NotImplemented
        self.timezone = NotImplemented
        self.confs = NotImplemented

    @parameterized.expand([
        param('12 Feb 2015 10:30 PM +0100', datetime(2015, 2, 12, 22, 30),
              r'UTC\+01:00'),
        param('12 Feb 2015 4:30 PM EST', datetime(2015, 2, 12, 16, 30), 'EST'),
        param('12 Feb 2015 8:30 PM PKT', datetime(2015, 2, 12, 20, 30), 'PKT'),
        param('12 Feb 2015 8:30 PM ACT', datetime(2015, 2, 12, 20, 30), 'ACT'),
    ])
    def test_should_return_and_assert_tz(self, ds, dt, tz):
        self.given(ds)
        self.given_configurations({})
        self.when_date_is_parsed()
        self.then_date_is_tz_aware()
        self.then_date_is(dt)
        self.then_timezone_is(tz)

    @parameterized.expand([
        param('12 Feb 2015 4:30 PM EST', datetime(2015, 2, 12, 16, 30), 'EST'),
        param('12 Feb 2015 8:30 PM PKT', datetime(2015, 2, 12, 20, 30), 'PKT'),
        param('12 Feb 2015 8:30 PM ACT', datetime(2015, 2, 12, 20, 30), 'ACT'),
        param('12 Feb 2015 8:30 PM', datetime(2015, 2, 12, 20, 30), ''),
    ])
    def test_only_return_explicit_timezone(self, ds, dt, tz):
        self.given(ds)
        self.given_configurations({})
        self.when_date_is_parsed()
        self.then_date_is(dt)
        if tz:
            self.then_date_is_tz_aware()
            self.then_timezone_is(tz)
        else:
            self.then_date_is_not_tz_aware()

    @parameterized.expand([
        param(
            '12 Feb 2015 4:30 PM EST',
            datetime(2015, 2, 12, 16, 30),
        ),
        param(
            '12 Feb 2015 8:30 PM PKT',
            datetime(2015, 2, 12, 20, 30),
        ),
        param(
            '12 Feb 2015 8:30 PM ACT',
            datetime(2015, 2, 12, 20, 30),
        ),
        param(
            '12 Feb 2015 8:30 PM +0100',
            datetime(2015, 2, 12, 20, 30),
        ),
    ])
    def test_should_return_naive_if_RETURN_AS_TIMEZONE_AWARE_IS_FALSE(
            self, ds, dt):
        self.given(ds)
        self.given_configurations({'RETURN_AS_TIMEZONE_AWARE': False})
        self.when_date_is_parsed()
        self.then_date_is(dt)
        self.then_date_is_not_tz_aware()

    def then_timezone_is(self, tzname):
        self.assertEqual(self.result.tzinfo.tzname(''), tzname)

    def given(self, ds):
        self.given_ds = ds

    def given_configurations(self, confs):
        if 'TIMEZONE' not in confs:
            confs.update({'TIMEZONE': 'local'})

        self.confs = settings.replace(settings=confs)

    def when_date_is_parsed(self):
        self.result = parse(self.given_ds, settings=(self.confs or {}))

    def then_date_is_tz_aware(self):
        self.assertIsInstance(self.result.tzinfo, tzinfo)

    def then_date_is_not_tz_aware(self):
        self.assertIsNone(self.result.tzinfo)

    def then_date_is(self, date):
        dtc = self.result.replace(tzinfo=None)
        self.assertEqual(dtc, date)
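
# A short sketch of the behaviour pinned down above (assumes dateparser):
# an explicit offset or abbreviation in the input yields an aware datetime
# unless RETURN_AS_TIMEZONE_AWARE is disabled.
import dateparser

aware = dateparser.parse('12 Feb 2015 10:30 PM +0100')
print(aware.tzinfo is not None)  # True

naive = dateparser.parse('12 Feb 2015 10:30 PM +0100',
                         settings={'RETURN_AS_TIMEZONE_AWARE': False})
print(naive.tzinfo)  # None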
Esempio n. 8
0
class TestDateDataParser(BaseTestCase):
    def setUp(self):
        super(TestDateDataParser, self).setUp()
        self.parser = NotImplemented
        self.result = NotImplemented
        self.multiple_results = NotImplemented

    @parameterized.expand([
        param('10:04am EDT'),
    ])
    def test_time_without_date_should_use_today(self, date_string):
        self.given_parser(settings={'RELATIVE_BASE': datetime(2020, 7, 19)})
        self.when_date_string_is_parsed(date_string)
        self.then_date_was_parsed()
        self.then_parsed_date_is(datetime(2020, 7, 19).date())

    @parameterized.expand([
        # Today
        param('today', days_ago=0),
        param('Today', days_ago=0),
        param('TODAY', days_ago=0),
        param('Сегодня', days_ago=0),
        param('Hoje', days_ago=0),
        param('Oggi', days_ago=0),
        # Yesterday
        param('yesterday', days_ago=1),
        param(' Yesterday \n', days_ago=1),
        param('Ontem', days_ago=1),
        param('Ieri', days_ago=1),
        param('вчера', days_ago=1),
        param('снощи', days_ago=1),
        # Day before yesterday
        param('the day before yesterday', days_ago=2),
        param('The DAY before Yesterday', days_ago=2),
        param('Anteontem', days_ago=2),
        param('Avant-hier', days_ago=2)
    ])
    def test_temporal_nouns_are_parsed(self, date_string, days_ago):
        self.given_parser()
        self.when_date_string_is_parsed(date_string)
        self.then_date_was_parsed()
        self.then_date_is_n_days_ago(days=days_ago)

    def test_should_not_assume_language_too_early(self):
        dates_to_parse = OrderedDict([
            ('07/07/2014', datetime(2014, 7, 7).date()),  # any language
            ('07.jul.2014 | 12:52', datetime(2014, 7,
                                             7).date()),  # en, es, pt, nl
            ('07.ago.2014 | 12:52', datetime(2014, 8, 7).date()),  # es, it, pt
            ('07.feb.2014 | 12:52',
             datetime(2014, 2, 7).date()),  # en, de, es, it, nl, ro
            ('07.ene.2014 | 12:52', datetime(2014, 1, 7).date())  # es
        ])

        self.given_parser(restrict_to_languages=[
            'en', 'de', 'fr', 'it', 'pt', 'nl', 'ro', 'es', 'ru'
        ])
        self.when_multiple_dates_are_parsed(dates_to_parse.keys())
        self.then_all_results_were_parsed()
        self.then_parsed_dates_are(list(dates_to_parse.values()))

    @parameterized.expand([
        param(date_string='11 Marzo, 2014', locale='es'),
        param(date_string='13 Septiembre, 2014', locale='es'),
        param(date_string='Сегодня', locale='ru'),
        param(date_string='Avant-hier', locale='fr'),
        param(date_string='Anteontem', locale='pt'),
        param(date_string='ธันวาคม 11, 2014, 08:55:08 PM', locale='th'),
        param(date_string='Anteontem', locale='pt'),
        param(date_string='14 aprilie 2014', locale='ro'),
        param(date_string='11 Ağustos, 2014', locale='tr'),
        param(date_string='pon 16. čer 2014 10:07:43', locale='cs'),
        param(date_string='24 януари 2015г.', locale='bg')
    ])
    def test_returned_detected_locale_should_be(self, date_string, locale):
        self.given_parser()
        self.when_date_string_is_parsed(date_string)
        self.then_detected_locale(locale)

    @parameterized.expand([
        param("2014-10-09T17:57:39+00:00"),
    ])
    def test_get_date_data_should_not_strip_timezone_info(self, date_string):
        self.given_parser()
        self.when_date_string_is_parsed(date_string)
        self.then_date_was_parsed()
        self.then_parsed_date_has_timezone()

    @parameterized.expand([
        param(date_string="14 giu 13",
              date_formats=["%y %B %d"],
              expected_result=datetime(2014, 6, 13)),
        param(date_string="14_luglio_15",
              date_formats=["%y_%B_%d"],
              expected_result=datetime(2014, 7, 15)),
        param(date_string="14_LUGLIO_15",
              date_formats=["%y_%B_%d"],
              expected_result=datetime(2014, 7, 15)),
        param(date_string="10.01.2016, 20:35",
              date_formats=["%d.%m.%Y, %H:%M"],
              expected_result=datetime(2016, 1, 10, 20, 35)),
    ])
    def test_parse_date_using_format(self, date_string, date_formats,
                                     expected_result):
        self.given_local_tz_offset(0)
        self.given_parser()
        self.when_date_string_is_parsed(date_string, date_formats)
        self.then_date_was_parsed()
        self.then_period_is('day')
        self.then_parsed_datetime_is(expected_result)

    @parameterized.expand([
        param(date_string="11/09/2007",
              date_formats={"date_formats": ["%d/%m/%Y"]}),
        param(date_string="16.09.03 11:55", date_formats=111),
        param(date_string="08-01-1998", date_formats=12.56),
    ])
    def test_parsing_date_using_invalid_type_date_format_must_raise_error(
            self, date_string, date_formats):
        self.given_local_tz_offset(0)
        self.given_parser()
        self.when_date_string_is_parsed(date_string, date_formats)
        self.then_error_was_raised(TypeError, [
            "Date formats should be list, tuple or set of strings",
            "'{}' object is not iterable".format(type(date_formats).__name__)
        ])

    def test_parsing_date_using_unknown_parsers_must_raise_error(self):
        self.given_parser(settings={'PARSERS': ['foo']})
        self.when_date_string_is_parsed('2020-02-19')
        self.then_error_was_raised(
            ValueError, ["Unknown parsers found in the PARSERS setting: foo"])

    @parameterized.expand([
        param(date_string={"date": "12/11/1998"}),
        param(date_string=[2017, 12, 1]),
        param(date_string=2018),
        param(date_string=12.2000),
        param(date_string=datetime(year=2009, month=12, day=7)),
    ])
    def test_parsing_date_using_invalid_type_date_string_must_raise_error(
            self, date_string):
        self.given_parser()
        self.when_date_string_is_parsed(date_string)
        self.then_error_was_raised(TypeError,
                                   ["Input type must be str or unicode"])

    @parameterized.expand([
        param(date_string="2014/11/17 14:56 EDT",
              expected_result=datetime(2014, 11, 17, 18, 56)),
    ])
    def test_parse_date_with_timezones_not_using_formats(
            self, date_string, expected_result):
        self.given_parser(settings={'TO_TIMEZONE': 'UTC'})
        self.when_date_string_is_parsed(date_string)
        self.then_date_was_parsed()
        self.then_period_is('day')
        self.result['date_obj'] = self.result['date_obj'].replace(tzinfo=None)
        self.then_parsed_datetime_is(expected_result)

    @parameterized.expand([
        param(date_string="2014/11/17 14:56 EDT",
              date_formats=["%Y/%m/%d %H:%M EDT"],
              expected_result=datetime(2014, 11, 17, 14, 56)),
    ])
    def test_parse_date_with_timezones_using_formats_ignore_timezone(
            self, date_string, date_formats, expected_result):
        self.given_local_tz_offset(0)
        self.given_parser()
        self.when_date_string_is_parsed(date_string, date_formats)
        self.then_date_was_parsed()
        self.then_period_is('day')
        self.then_parsed_datetime_is(expected_result)

    @parameterized.expand([
        param(date_string="08-08-2014\xa018:29",
              expected_result=datetime(2014, 8, 8, 18, 29)),
    ])
    def test_should_parse_with_no_break_space_in_dates(self, date_string,
                                                       expected_result):
        self.given_parser()
        self.when_date_string_is_parsed(date_string)
        self.then_date_was_parsed()
        self.then_period_is('day')
        self.then_parsed_datetime_is(expected_result)

    @parameterized.expand([
        param(date_string="12 jan 1876",
              expected_result=(datetime(1876, 1, 12, 0, 0), 'day', 'en')),
        param(date_string="02/09/16",
              expected_result=(datetime(2016, 2, 9, 0, 0), 'day', 'en')),
        param(date_string="10 giu 2018",
              expected_result=(datetime(2018, 6, 10, 0, 0), 'day', 'it')),
    ])
    def test_get_date_tuple(self, date_string, expected_result):
        self.given_parser()
        self.when_get_date_tuple_is_called(date_string)
        self.then_returned_tuple_is(expected_result)

    def given_now(self, year, month, day, **time):
        datetime_mock = Mock(wraps=datetime)
        datetime_mock.utcnow = Mock(
            return_value=datetime(year, month, day, **time))
        self.add_patch(
            patch('dateparser.date_parser.datetime', new=datetime_mock))

    def given_parser(self, restrict_to_languages=None, **params):
        self.parser = date.DateDataParser(languages=restrict_to_languages,
                                          **params)

    def given_local_tz_offset(self, offset):
        self.add_patch(
            patch.object(dateparser.timezone_parser,
                         'local_tz_offset',
                         new=timedelta(seconds=3600 * offset)))

    def when_date_string_is_parsed(self, date_string, date_formats=None):
        try:
            self.result = self.parser.get_date_data(date_string, date_formats)
        except Exception as error:
            self.error = error

    def when_multiple_dates_are_parsed(self, date_strings):
        self.multiple_results = []
        for date_string in date_strings:
            try:
                result = self.parser.get_date_data(date_string)
            except Exception as error:
                result = error
            finally:
                self.multiple_results.append(result)

    def when_get_date_tuple_is_called(self, date_string):
        self.result = self.parser.get_date_tuple(date_string)

    def then_date_was_parsed(self):
        self.assertIsNotNone(self.result['date_obj'])

    def then_date_locale(self):
        self.assertIsNotNone(self.result['locale'])

    def then_date_is_n_days_ago(self, days):
        today = datetime.now().date()
        expected_date = today - timedelta(days=days)
        self.assertEqual(expected_date, self.result['date_obj'].date())

    def then_all_results_were_parsed(self):
        self.assertNotIn(None, self.multiple_results)

    def then_parsed_dates_are(self, expected_dates):
        self.assertEqual(
            expected_dates,
            [result['date_obj'].date() for result in self.multiple_results])

    def then_detected_locale(self, locale):
        self.assertEqual(locale, self.result['locale'])

    def then_period_is(self, period):
        self.assertEqual(period, self.result['period'])

    def then_parsed_datetime_is(self, expected_datetime):
        self.assertEqual(expected_datetime, self.result['date_obj'])

    def then_parsed_date_is(self, expected_date):
        self.assertEqual(expected_date, self.result['date_obj'].date())

    def then_parsed_date_has_timezone(self):
        self.assertIsNotNone(self.result['date_obj'].tzinfo)

    def then_returned_tuple_is(self, expected_tuple):
        self.assertEqual(expected_tuple, self.result)
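
# The two entry points covered above, side by side (assumes dateparser;
# the values mirror the test_get_date_tuple cases):
from dateparser.date import DateDataParser

ddp = DateDataParser(languages=['en'])
print(ddp.get_date_data('12 jan 1876'))
# date_obj=datetime(1876, 1, 12, 0, 0), period='day', locale='en'
print(ddp.get_date_tuple('12 jan 1876'))
# (datetime(1876, 1, 12, 0, 0), 'day', 'en')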
Esempio n. 9
0
class ParseTest(unittest.TestCase):

  @parameterized.expand([
      param(
          "EmptyLines",
          lines=[],
          expected_pbtxt=""
      ),
      param(
          "SingleLine",
          lines=[["STATE-1", "STATE-2", "+Morpheme[Cat=Val]", "+Morpheme"]],
          expected_pbtxt="""
          rule {
            from_state: 'STATE-1'
            to_state: 'STATE-2'
            input: '+Morpheme[Cat=Val]'
            output: '+Morpheme'
          }
          """
      ),
      param(
          "MultipleLines",
          lines=[
              ["STATE-1", "STATE-2", "+Morpheme1[Cat1=Val1]", "+Morpheme1"],
              ["STATE-3", "STATE-4", "+Morpheme2[Cat2=Val2]", "+Morpheme2"],
          ],
          expected_pbtxt="""
          rule {
            from_state: 'STATE-1'
            to_state: 'STATE-2'
            input: '+Morpheme1[Cat1=Val1]'
            output: '+Morpheme1'
          }
          rule {
            from_state: 'STATE-3'
            to_state: 'STATE-4'
            input: '+Morpheme2[Cat2=Val2]'
            output: '+Morpheme2'
          }
          """
      ),
      param(
          "NormalizesFromStateName",
          lines=[["sTaTe-1", "STATE-2", "+Morpheme[Cat=Val]", "+Morpheme"]],
          expected_pbtxt="""
          rule {
            from_state: 'STATE-1'
            to_state: 'STATE-2'
            input: '+Morpheme[Cat=Val]'
            output: '+Morpheme'
          }
          """
      ),
      param(
          "NormalizesToStateName",
          lines=[["STATE-1", "StAtE-2", "+Morpheme[Cat=Val]", "+Morpheme"]],
          expected_pbtxt="""
          rule {
            from_state: 'STATE-1'
            to_state: 'STATE-2'
            input: '+Morpheme[Cat=Val]'
            output: '+Morpheme'
          }
          """
      ),
      param(
          "NormalizesBracketedOutputToken",
          lines=[["STATE-1", "StAtE-2", "<BrAcKeTeD>", "+Morpheme"]],
          expected_pbtxt="""
          rule {
            from_state: 'STATE-1'
            to_state: 'STATE-2'
            input: '<bracketed>'
            output: '+Morpheme'
          }
          """
      ),
      param(
          "NormalizesBracketedInputToken",
          lines=[["STATE-1", "StAtE-2", "+Morpheme[Cat=Val]", "<BrAcKeTeD>"]],
          expected_pbtxt="""
          rule {
            from_state: 'STATE-1'
            to_state: 'STATE-2'
            input: '+Morpheme[Cat=Val]'
            output: '<bracketed>'
          }
          """
      ),
  ])
  def test_success(self, _, lines, expected_pbtxt):
    actual = parser.parse(lines)
    expected = rule_pb2.RewriteRuleSet()
    text_format.Parse(expected_pbtxt, expected)
    self.assertEqual(expected, actual)
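
# A hedged sketch of the protobuf round-trip behind the assertions above
# (rule_pb2 is the snippet's own generated module, assumed imported as in
# the test; the field names come from the expected_pbtxt literals):
from google.protobuf import text_format

pbtxt = """
rule {
  from_state: 'STATE-1'
  to_state: 'STATE-2'
  input: '+Morpheme[Cat=Val]'
  output: '+Morpheme'
}
"""
ruleset = rule_pb2.RewriteRuleSet()
text_format.Parse(pbtxt, ruleset)
print(text_format.MessageToString(ruleset))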


import os

from parameterized import param, parameterized

# hedged reconstruction: the import block of this snippet is missing in the
# original; check_run_comparison is assumed from rsmtool.test_utils, which
# the snippets below also import from
from rsmtool.test_utils import check_run_comparison

# allow test directory to be set via an environment variable
# which is needed for package testing
TEST_DIR = os.environ.get('TESTDIR', None)
if TEST_DIR:
    rsmtool_test_dir = TEST_DIR
else:
    from rsmtool.test_utils import rsmtool_test_dir


# set this to False to disable auto-updating of all experiment
# tests contained in this file via `update_files.py`.
# TODO: re-enable this once we start saving rsmcompare outputs
_AUTO_UPDATE = False


@parameterized([
    param('lr-self-compare', 'lr_subgroups_vs_lr_subgroups_report'),
    param('lr-different-compare', 'lr_baseline_vs_lr_with_FEATURE8_and_zero_scores_report'),
    param('lr-self-compare-with-h2', 'lr_with_h2_vs_lr_with_h2_report'),
    param('lr-self-compare-with-custom-order', 'lr_subgroups_vs_lr_subgroups_report'),
    param('lr-self-compare-with-chosen-sections', 'lr_subgroups_vs_lr_subgroups_report'),
    param('lr-self-compare-with-custom-sections-and-custom-order', 'lr_subgroups_vs_lr_subgroups_report'),
    param('lr-self-compare-with-thumbnails', 'lr_subgroups_vs_lr_subgroups_report'),
    param('linearsvr-self-compare', 'LinearSVR_vs_LinearSVR_report'),
    param('lr-eval-self-compare', 'lr_eval_with_h2_vs_lr_eval_with_h2_report'),
    param('lr-eval-tool-compare', 'lr_with_h2_vs_lr_eval_with_h2_report'),
    param('lr-self-compare-different-format', 'lr_subgroups_vs_lr_subgroups_report'),
    param('lr-self-compare-with-subgroups-and-h2', 'lr-subgroups-with-h2_vs_lr-subgroups-with-h2_report'),
    param('lr-self-compare-with-subgroups-and-edge-cases', 'lr-subgroups-with-edge-cases_vs_lr-subgroups-with-edge-cases_report')
])
def test_run_experiment_parameterized(*args, **kwargs):
    if TEST_DIR:
        kwargs['given_test_dir'] = TEST_DIR
    check_run_comparison(*args, **kwargs)
Esempio n. 11
0
import os

from parameterized import param, parameterized

from rsmtool.reporter import Reporter
from rsmtool.test_utils import (check_report,
                                check_run_experiment,
                                do_run_experiment)

# allow test directory to be set via an environment variable
# which is needed for package testing
TEST_DIR = os.environ.get('TESTDIR', None)
if TEST_DIR:
    rsmtool_test_dir = TEST_DIR
else:
    from rsmtool.test_utils import rsmtool_test_dir


@parameterized([
    param('lr', 'lr'),
    param('lr-subset-features', 'lr_subset'),
    param('lr-with-feature-subset-file', 'lr_with_feature_subset_file'),
    param('lr-subgroups', 'lr_subgroups', subgroups=['L1']),
    param('lr-with-numeric-subgroup', 'lr_with_numeric_subgroup', subgroups=['ITEM', 'QUESTION']),
    param('lr-with-id-with-leading-zeros', 'lr_with_id_with_leading_zeros', subgroups=['ITEM', 'QUESTION']),
    param('lr-subgroups-with-edge-cases', 'lr_subgroups_with_edge_cases', subgroups=['group_edge_cases']),
    param('lr-missing-values', 'lr_missing_values'),
    param('lr-include-zeros', 'lr_include_zeros'),
    param('lr-with-length', 'lr_with_length'),
    param('lr-subgroups-with-length', 'lr_subgroups_with_length', subgroups=['L1', 'QUESTION']),
    param('lr-with-large-integer-value', 'lr_with_large_integer_value'),
    param('lr-with-missing-length-values', 'lr_with_missing_length_values'),
    param('lr-with-length-zero-sd', 'lr_with_length_zero_sd'),
    param('lr-with-h2', 'lr_with_h2', consistency=True),
    param('lr-subgroups-with-h2', 'lr_subgroups_with_h2', subgroups=['L1', 'QUESTION'], consistency=True)
])
def test_run_experiment_parameterized(*args, **kwargs):
    if TEST_DIR:
        kwargs['given_test_dir'] = TEST_DIR
    check_run_experiment(*args, **kwargs)
Esempio n. 12
0
import os

from parameterized import param, parameterized

from rsmtool.test_utils import (check_generated_output,
                                check_run_prediction,
                                do_run_experiment,
                                do_run_prediction)

# allow test directory to be set via an environment variable
# which is needed for package testing
TEST_DIR = os.environ.get('TESTDIR', None)
if TEST_DIR:
    rsmtool_test_dir = TEST_DIR
else:
    from rsmtool.test_utils import rsmtool_test_dir


@parameterized([
    param('lr-predict'),
    param('lr-predict-with-score'),
    param('lr-predict-missing-values', excluded=True),
    param('lr-predict-with-subgroups'),
    param('lr-predict-with-candidate'),
    param('lr-predict-illegal-transformations', excluded=True),
    param('lr-predict-tsv-input-files'),
    param('lr-predict-xlsx-input-files'),
    param('lr-predict-no-standardization'),
    param('lr-predict-with-tsv-output', file_format='tsv'),
    param('lr-predict-with-xlsx-output', file_format='xlsx'),
    param('logistic-regression-predict'),
    param('logistic-regression-predict-expected-scores'),
    param('svc-predict-expected-scores')
])
def test_run_experiment_parameterized(*args, **kwargs):
    if TEST_DIR:
        kwargs['given_test_dir'] = TEST_DIR
    check_run_prediction(*args, **kwargs)
Esempio n. 13
0
    """Tests that instantiation of PydvdidException raises an exception.
    """

    try:
        PydvdidException("This should not work.")
    except TypeError as expected:
        eq_("PydvdidException may not be directly instantiated.", str(expected))
    except Exception as unexpected: # pylint: disable=locally-disabled, broad-except
        ok_(False, "An unexpected {0} exception was raised.".format(type(unexpected).__name__))
    else:
        ok_(False, "An exception was expected but was not raised.")


@istest
@parameterized([
    param(7000, None, "No bytes are available."),
    param(20, 12, "20 bytes were expected, 12 were read.")
])
@patch("pydvdid.exceptions.PydvdidException.__init__") # pylint: disable=locally-disabled, invalid-name
def filecontentreadexception___init__calls_base___init___with_correct_message(expected_size,
                                                                              actual_size,
                                                                              expected_message,
                                                                              mock_init):
    """Tests that instantiation of FileContentReadException instantiates the base class with a
       formatted message.
    """

    mock_init.return_value = None

    FileContentReadException(expected_size, actual_size)

    mock_init.assert_called_once_with(expected_message)
Esempio n. 14
0
def create_edit_cases():
    application_source = ApplicationFixtureFactory.make_application_source()
    account_source = AccountFixtureFactory.make_publisher_source()

    editable_application = Suggestion(**application_source)
    editable_application.set_application_status(constants.APPLICATION_STATUS_UPDATE_REQUEST)

    non_editable_application = Suggestion(**application_source)
    non_editable_application.set_application_status(constants.APPLICATION_STATUS_READY)

    owner_account = Account(**deepcopy(account_source))
    owner_account.set_id(editable_application.owner)

    non_owner_publisher = Account(**deepcopy(account_source))

    non_publisher = Account(**deepcopy(account_source))
    non_publisher.remove_role("publisher")

    admin = Account(**deepcopy(account_source))
    admin.add_role("admin")

    return [
        param("no_app_no_account", None, None, raises=exceptions.ArgumentException),
        param("no_app_with_account", None, owner_account, raises=exceptions.ArgumentException),
        param("app_no_account", editable_application, None, raises=exceptions.ArgumentException),
        param("editable_app_owning_account", editable_application, owner_account, expected=True),
        param("editable_app_nonowning_account", editable_application, non_owner_publisher, raises=exceptions.AuthoriseException),
        param("editable_app_non_publisher_account", editable_application, non_publisher, raises=exceptions.AuthoriseException),
        param("editable_app_admin_account", editable_application, admin, expected=True),
        param("non_editable_app_owning_account", non_editable_application, owner_account, raises=exceptions.AuthoriseException),
        param("non_editable_app_nonowning_account", non_editable_application, non_owner_publisher, raises=exceptions.AuthoriseException),
        param("non_editable_app_non_publisher_account", non_editable_application, non_publisher, raises=exceptions.AuthoriseException),
        param("non_editable_app_admin_account", non_editable_application, admin, expected=True)
    ]
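# A sketch of how these cases might be consumed; the test class and the
# authorisation helper named below are assumptions, since the consumer is not
# shown in this snippet.
class TestEditAuthorisation(unittest.TestCase):
    @parameterized.expand(create_edit_cases())
    def test_can_edit(self, name, application, account, raises=None, expected=None):
        if raises is not None:
            with self.assertRaises(raises):
                can_edit_update_request(account, application)  # hypothetical helper under test
        else:
            self.assertEqual(expected, can_edit_update_request(account, application))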
Esempio n. 15
0
import os

from parameterized import param, parameterized

from rsmtool.test_utils import (check_run_experiment,
                                do_run_experiment)

# allow test directory to be set via an environment variable
# which is needed for package testing
TEST_DIR = os.environ.get('TESTDIR', None)
if TEST_DIR:
    rsmtool_test_dir = TEST_DIR
else:
    from rsmtool.test_utils import rsmtool_test_dir


@parameterized([
    param('lr-with-defaults-as-extra-columns', 'lr_with_defaults_as_extra_columns',
          consistency=True),
    param('lr-with-truncations', 'lr_with_truncations'),
    param('lr-exclude-listwise', 'lr_exclude_listwise'),
    param('lr-with-custom-order', 'lr_with_custom_order'),
    param('lr-with-custom-sections', 'lr_with_custom_sections'),
    param('lr-with-custom-sections-and-order', 'lr_with_custom_sections_and_order'),
    param('lr-exclude-flags', 'lr_exclude_flags'),
    param('lr-exclude-flags-and-zeros', 'lr_exclude_flags_and_zeros'),
    param('lr-use-all-features', 'lr_use_all_features'),
    param('lr-candidate-same-as-id', 'lr_candidate_same_as_id'),
    param('lr-candidate-same-as-id-candidate', 'lr_candidate_same_as_id_candidate'),
    param('lr-tsv-input-files', 'lr_tsv_input_files'),
    param('lr-tsv-input-and-subset-files', 'lr_tsv_input_and_subset_files'),
    param('lr-xlsx-input-files', 'lr_xlsx_input_files'),
    param('lr-xlsx-input-and-subset-files', 'lr_xlsx_input_and_subset_files')
])
def test_run_experiment_parameterized(*args, **kwargs):
    if TEST_DIR:
        kwargs['given_test_dir'] = TEST_DIR
    check_run_experiment(*args, **kwargs)
Esempio n. 16
0
import tempfile
from subprocess import PIPE, Popen, TimeoutExpired

from parameterized import param, parameterized

# PackageIntegBase and TIMEOUT come from this test suite's shared integration
# test base module, which is not shown in this snippet.


class TestPackageZip(PackageIntegBase):
    def setUp(self):
        super().setUp()

    def tearDown(self):
        super().tearDown()

    @parameterized.expand(["aws-serverless-function.yaml"])
    def test_package_template_flag(self, template_file):
        template_path = self.test_data_path.joinpath(template_file)
        command_list = self.get_command_list(s3_bucket=self.s3_bucket.name,
                                             template=template_path)

        process = Popen(command_list, stdout=PIPE)
        try:
            stdout, _ = process.communicate(timeout=TIMEOUT)
        except TimeoutExpired:
            process.kill()
            raise
        process_stdout = stdout.strip()

        self.assertIn("{bucket_name}".format(bucket_name=self.s3_bucket.name),
                      process_stdout.decode("utf-8"))

    @parameterized.expand([
        "aws-serverless-function.yaml",
        "aws-serverless-api.yaml",
        "aws-serverless-httpapi.yaml",
        "aws-appsync-graphqlschema.yaml",
        "aws-appsync-resolver.yaml",
        "aws-appsync-functionconfiguration.yaml",
        "aws-lambda-function.yaml",
        "aws-apigateway-restapi.yaml",
        "aws-elasticbeanstalk-applicationversion.yaml",
        "aws-cloudformation-stack.yaml",
        "aws-serverless-application.yaml",
        "aws-lambda-layerversion.yaml",
        "aws-serverless-layerversion.yaml",
        "aws-glue-job.yaml",
        "aws-serverlessrepo-application.yaml",
        "aws-serverless-statemachine.yaml",
        "aws-stepfunctions-statemachine.yaml",
    ])
    def test_package_barebones(self, template_file):
        template_path = self.test_data_path.joinpath(template_file)
        command_list = self.get_command_list(s3_bucket=self.s3_bucket.name,
                                             template_file=template_path)

        process = Popen(command_list, stdout=PIPE)
        try:
            stdout, _ = process.communicate(timeout=TIMEOUT)
        except TimeoutExpired:
            process.kill()
            raise
        process_stdout = stdout.strip()

        self.assertIn("{bucket_name}".format(bucket_name=self.s3_bucket.name),
                      process_stdout.decode("utf-8"))

    def test_package_without_required_args(self):
        command_list = self.get_command_list()

        process = Popen(command_list, stdout=PIPE)
        try:
            process.communicate(timeout=TIMEOUT)
        except TimeoutExpired:
            process.kill()
            raise
        self.assertNotEqual(process.returncode, 0)

    @parameterized.expand([
        "aws-serverless-function.yaml",
        "aws-serverless-api.yaml",
        "aws-serverless-httpapi.yaml",
        "aws-appsync-graphqlschema.yaml",
        "aws-appsync-resolver.yaml",
        "aws-appsync-functionconfiguration.yaml",
        "aws-lambda-function.yaml",
        "aws-apigateway-restapi.yaml",
        "aws-elasticbeanstalk-applicationversion.yaml",
        "aws-cloudformation-stack.yaml",
        "aws-serverless-application.yaml",
        "aws-lambda-layerversion.yaml",
        "aws-serverless-layerversion.yaml",
        "aws-glue-job.yaml",
        "aws-serverlessrepo-application.yaml",
        "aws-serverless-statemachine.yaml",
        "aws-stepfunctions-statemachine.yaml",
    ])
    def test_package_with_prefix(self, template_file):
        template_path = self.test_data_path.joinpath(template_file)
        s3_prefix = "integ_test_prefix"
        command_list = self.get_command_list(s3_bucket=self.s3_bucket.name,
                                             template_file=template_path,
                                             s3_prefix=s3_prefix)

        process = Popen(command_list, stdout=PIPE)
        try:
            stdout, _ = process.communicate(timeout=TIMEOUT)
        except TimeoutExpired:
            process.kill()
            raise
        process_stdout = stdout.strip()

        self.assertIn("{bucket_name}".format(bucket_name=self.s3_bucket.name),
                      process_stdout.decode("utf-8"))

        self.assertIn("{s3_prefix}".format(s3_prefix=s3_prefix),
                      process_stdout.decode("utf-8"))

    @parameterized.expand([
        "aws-serverless-function.yaml",
        "aws-serverless-api.yaml",
        "aws-serverless-httpapi.yaml",
        "aws-appsync-graphqlschema.yaml",
        "aws-appsync-resolver.yaml",
        "aws-appsync-functionconfiguration.yaml",
        "aws-lambda-function.yaml",
        "aws-apigateway-restapi.yaml",
        "aws-elasticbeanstalk-applicationversion.yaml",
        "aws-cloudformation-stack.yaml",
        "aws-serverless-application.yaml",
        "aws-lambda-layerversion.yaml",
        "aws-serverless-layerversion.yaml",
        "aws-glue-job.yaml",
        "aws-serverlessrepo-application.yaml",
        "aws-serverless-statemachine.yaml",
        "aws-stepfunctions-statemachine.yaml",
    ])
    def test_package_with_output_template_file(self, template_file):
        template_path = self.test_data_path.joinpath(template_file)
        s3_prefix = "integ_test_prefix"

        with tempfile.NamedTemporaryFile(delete=False) as output_template:

            command_list = self.get_command_list(
                s3_bucket=self.s3_bucket.name,
                template_file=template_path,
                s3_prefix=s3_prefix,
                output_template_file=output_template.name,
            )

            process = Popen(command_list, stdout=PIPE)
            try:
                stdout, _ = process.communicate(timeout=TIMEOUT)
            except TimeoutExpired:
                process.kill()
                raise
            process_stdout = stdout.strip()

            self.assertIn(
                bytes(
                    "Successfully packaged artifacts and wrote output template to file {output_template_file}"
                    .format(output_template_file=str(output_template.name)),
                    encoding="utf-8",
                ),
                process_stdout,
            )

    @parameterized.expand([
        "aws-serverless-function.yaml",
        "aws-serverless-api.yaml",
        "aws-serverless-httpapi.yaml",
        "aws-appsync-graphqlschema.yaml",
        "aws-appsync-resolver.yaml",
        "aws-appsync-functionconfiguration.yaml",
        "aws-lambda-function.yaml",
        "aws-apigateway-restapi.yaml",
        "aws-elasticbeanstalk-applicationversion.yaml",
        "aws-cloudformation-stack.yaml",
        "aws-serverless-application.yaml",
        "aws-lambda-layerversion.yaml",
        "aws-serverless-layerversion.yaml",
        "aws-glue-job.yaml",
        "aws-serverlessrepo-application.yaml",
        "aws-serverless-statemachine.yaml",
        "aws-stepfunctions-statemachine.yaml",
    ])
    def test_package_with_json(self, template_file):
        template_path = self.test_data_path.joinpath(template_file)
        s3_prefix = "integ_test_prefix"

        with tempfile.NamedTemporaryFile(delete=False) as output_template:

            command_list = self.get_command_list(
                s3_bucket=self.s3_bucket.name,
                template_file=template_path,
                s3_prefix=s3_prefix,
                output_template_file=output_template.name,
                use_json=True,
            )

            process = Popen(command_list, stdout=PIPE)
            try:
                stdout, _ = process.communicate(timeout=TIMEOUT)
            except TimeoutExpired:
                process.kill()
                raise
            process_stdout = stdout.strip()

            self.assertIn(
                bytes(
                    "Successfully packaged artifacts and wrote output template to file {output_template_file}"
                    .format(output_template_file=str(output_template.name)),
                    encoding="utf-8",
                ),
                process_stdout,
            )

    @parameterized.expand([
        "aws-serverless-function.yaml",
        "aws-serverless-api.yaml",
        "aws-serverless-httpapi.yaml",
        "aws-appsync-graphqlschema.yaml",
        "aws-appsync-resolver.yaml",
        "aws-appsync-functionconfiguration.yaml",
        "aws-lambda-function.yaml",
        "aws-apigateway-restapi.yaml",
        "aws-elasticbeanstalk-applicationversion.yaml",
        "aws-cloudformation-stack.yaml",
        "aws-serverless-application.yaml",
        "aws-lambda-layerversion.yaml",
        "aws-serverless-layerversion.yaml",
        "aws-glue-job.yaml",
        "aws-serverlessrepo-application.yaml",
        "aws-serverless-statemachine.yaml",
        "aws-stepfunctions-statemachine.yaml",
    ])
    def test_package_with_force_upload(self, template_file):
        template_path = self.test_data_path.joinpath(template_file)
        s3_prefix = "integ_test_prefix"

        with tempfile.NamedTemporaryFile(delete=False) as output_template:
            # Upload twice and see the string to have packaged artifacts both times.
            for _ in range(2):

                command_list = self.get_command_list(
                    s3_bucket=self.s3_bucket.name,
                    template_file=template_path,
                    s3_prefix=s3_prefix,
                    output_template_file=output_template.name,
                    force_upload=True,
                )

                process = Popen(command_list, stdout=PIPE)
                try:
                    stdout, _ = process.communicate(timeout=TIMEOUT)
                except TimeoutExpired:
                    process.kill()
                    raise
                process_stdout = stdout.strip()

                self.assertIn(
                    bytes(
                        "Successfully packaged artifacts and wrote output template to file {output_template_file}"
                        .format(
                            output_template_file=str(output_template.name)),
                        encoding="utf-8",
                    ),
                    process_stdout,
                )

    @parameterized.expand([
        "aws-serverless-function.yaml",
        "aws-serverless-api.yaml",
        "aws-serverless-httpapi.yaml",
        "aws-appsync-graphqlschema.yaml",
        "aws-appsync-resolver.yaml",
        "aws-appsync-functionconfiguration.yaml",
        "aws-lambda-function.yaml",
        "aws-apigateway-restapi.yaml",
        "aws-elasticbeanstalk-applicationversion.yaml",
        "aws-cloudformation-stack.yaml",
        "aws-serverless-application.yaml",
        "aws-lambda-layerversion.yaml",
        "aws-serverless-layerversion.yaml",
        "aws-glue-job.yaml",
        "aws-serverlessrepo-application.yaml",
        "aws-serverless-statemachine.yaml",
        "aws-stepfunctions-statemachine.yaml",
    ])
    def test_package_with_kms_key(self, template_file):
        template_path = self.test_data_path.joinpath(template_file)
        s3_prefix = "integ_test_prefix"

        with tempfile.NamedTemporaryFile(delete=False) as output_template:
            command_list = self.get_command_list(
                s3_bucket=self.s3_bucket.name,
                template_file=template_path,
                s3_prefix=s3_prefix,
                output_template_file=output_template.name,
                force_upload=True,
                kms_key_id=self.kms_key,
            )

            process = Popen(command_list, stdout=PIPE)
            try:
                stdout, _ = process.communicate(timeout=TIMEOUT)
            except TimeoutExpired:
                process.kill()
                raise
            process_stdout = stdout.strip()

            self.assertIn(
                bytes(
                    "Successfully packaged artifacts and wrote output template to file {output_template_file}"
                    .format(output_template_file=str(output_template.name)),
                    encoding="utf-8",
                ),
                process_stdout,
            )

    @parameterized.expand([
        "aws-serverless-function.yaml",
        "aws-serverless-api.yaml",
        "aws-serverless-httpapi.yaml",
        "aws-appsync-graphqlschema.yaml",
        "aws-appsync-resolver.yaml",
        "aws-appsync-functionconfiguration.yaml",
        "aws-lambda-function.yaml",
        "aws-apigateway-restapi.yaml",
        "aws-elasticbeanstalk-applicationversion.yaml",
        "aws-cloudformation-stack.yaml",
        "aws-serverless-application.yaml",
        "aws-lambda-layerversion.yaml",
        "aws-serverless-layerversion.yaml",
        "aws-glue-job.yaml",
        "aws-serverlessrepo-application.yaml",
        "aws-serverless-statemachine.yaml",
        "aws-stepfunctions-statemachine.yaml",
    ])
    def test_package_with_metadata(self, template_file):
        template_path = self.test_data_path.joinpath(template_file)
        s3_prefix = "integ_test_prefix"

        with tempfile.NamedTemporaryFile(delete=False) as output_template:
            command_list = self.get_command_list(
                s3_bucket=self.s3_bucket.name,
                template_file=template_path,
                s3_prefix=s3_prefix,
                output_template_file=output_template.name,
                force_upload=True,
                metadata={"integ": "yes"},
            )

            process = Popen(command_list, stdout=PIPE)
            try:
                stdout, _ = process.communicate(timeout=TIMEOUT)
            except TimeoutExpired:
                process.kill()
                raise
            process_stdout = stdout.strip()

            self.assertIn(
                bytes(
                    "Successfully packaged artifacts and wrote output template to file {output_template_file}"
                    .format(output_template_file=str(output_template.name)),
                    encoding="utf-8",
                ),
                process_stdout,
            )

    @parameterized.expand([
        "aws-serverless-function.yaml",
        "aws-serverless-api.yaml",
        "aws-appsync-graphqlschema.yaml",
        "aws-appsync-resolver.yaml",
        "aws-appsync-functionconfiguration.yaml",
        "aws-lambda-function.yaml",
        "aws-apigateway-restapi.yaml",
        "aws-elasticbeanstalk-applicationversion.yaml",
        "aws-cloudformation-stack.yaml",
        "aws-serverless-application.yaml",
        "aws-lambda-layerversion.yaml",
        "aws-serverless-layerversion.yaml",
        "aws-glue-job.yaml",
        "aws-serverlessrepo-application.yaml",
        "aws-serverless-statemachine.yaml",
        "aws-stepfunctions-statemachine.yaml",
    ])
    def test_package_with_resolve_s3(self, template_file):
        template_path = self.test_data_path.joinpath(template_file)
        s3_prefix = "integ_test_prefix"

        with tempfile.NamedTemporaryFile(delete=False) as output_template:
            command_list = self.get_command_list(
                template_file=template_path,
                s3_prefix=s3_prefix,
                output_template_file=output_template.name,
                force_upload=True,
                resolve_s3=True,
            )

            process = Popen(command_list, stdout=PIPE)
            try:
                stdout, _ = process.communicate(timeout=TIMEOUT)
            except TimeoutExpired:
                process.kill()
                raise
            process_stdout = stdout.strip()

            self.assertIn(
                bytes(
                    "Successfully packaged artifacts and wrote output template to file {output_template_file}"
                    .format(output_template_file=str(output_template.name)),
                    encoding="utf-8",
                ),
                process_stdout,
            )

    @parameterized.expand([(True, ), (False, )])
    def test_package_with_no_progressbar(self, no_progressbar):
        template_path = self.test_data_path.joinpath(
            "aws-serverless-function.yaml")
        s3_prefix = "integ_test_prefix"

        with tempfile.NamedTemporaryFile(delete=False) as output_template:
            command_list = self.get_command_list(
                template_file=template_path,
                s3_prefix=s3_prefix,
                output_template_file=output_template.name,
                force_upload=True,
                no_progressbar=no_progressbar,
                resolve_s3=True,
            )

            process = Popen(command_list, stdout=PIPE, stderr=PIPE)
            try:
                _, stderr = process.communicate(timeout=TIMEOUT)
            except TimeoutExpired:
                process.kill()
                raise
            process_stderr = stderr.strip()

            upload_message = bytes("Uploading to", encoding="utf-8")
            if no_progressbar:
                self.assertNotIn(
                    upload_message,
                    process_stderr,
                )
            else:
                self.assertIn(
                    upload_message,
                    process_stderr,
                )

    @parameterized.expand([
        param("aws-serverless-function-codedeploy-warning.yaml", "CodeDeploy"),
        param("aws-serverless-function-codedeploy-condition-warning.yaml",
              "CodeDeploy DeploymentGroups"),
    ])
    def test_package_with_warning_template(self, template_file,
                                           warning_keyword):
        template_path = self.test_data_path.joinpath(template_file)
        command_list = self.get_command_list(s3_bucket=self.s3_bucket.name,
                                             template=template_path)

        process = Popen(command_list, stdout=PIPE)
        try:
            stdout, _ = process.communicate(timeout=TIMEOUT)
        except TimeoutExpired:
            process.kill()
            raise
        process_stdout = stdout.strip().decode("utf-8")

        # Not comparing with full warning message because of line ending mismatch on
        # windows and non-windows
        self.assertIn(warning_keyword, process_stdout)
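    # The Popen/communicate/kill boilerplate above repeats in every test; a
    # helper along these lines (the name is an assumption) could centralize it:
    def _run_command(self, command_list):
        process = Popen(command_list, stdout=PIPE)
        try:
            stdout, _ = process.communicate(timeout=TIMEOUT)
        except TimeoutExpired:
            process.kill()
            raise
        return stdout.strip()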
Esempio n. 17
0
import sys
import time
import random
import subprocess
import requests
import _thread
from functools import wraps
from panda import Panda
from nose.tools import assert_equal
from parameterized import parameterized, param

SPEED_NORMAL = 500
SPEED_GMLAN = 33.3

test_all_types = parameterized([
    param(panda_type=Panda.HW_TYPE_WHITE_PANDA),
    param(panda_type=Panda.HW_TYPE_GREY_PANDA),
    param(panda_type=Panda.HW_TYPE_BLACK_PANDA)
])
test_all_pandas = parameterized(Panda.list())
test_white_and_grey = parameterized([
    param(panda_type=Panda.HW_TYPE_WHITE_PANDA),
    param(panda_type=Panda.HW_TYPE_GREY_PANDA)
])
test_white = parameterized([param(panda_type=Panda.HW_TYPE_WHITE_PANDA)])
test_grey = parameterized([param(panda_type=Panda.HW_TYPE_GREY_PANDA)])
test_two_panda = parameterized([
    param(panda_type=[Panda.HW_TYPE_GREY_PANDA, Panda.HW_TYPE_WHITE_PANDA]),
    param(panda_type=[Panda.HW_TYPE_WHITE_PANDA, Panda.HW_TYPE_GREY_PANDA]),
    param(panda_type=[Panda.HW_TYPE_BLACK_PANDA, Panda.HW_TYPE_BLACK_PANDA])
])
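# A minimal sketch of how these parameterized decorators are applied; the test
# body is hypothetical, since the real tests in this file drive hardware.
@test_all_types
def test_connect(panda_type=None):
    # parameterized invokes this once per param() above, passing panda_type;
    # a real test would connect to a device of that type and exercise it.
    pass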
Esempio n. 18
0
import os

from parameterized import param, parameterized

from rsmtool.test_utils import (check_scaled_coefficients,
                                check_generated_output,
                                check_run_experiment,
                                do_run_experiment)

# allow test directory to be set via an environment variable
# which is needed for package testing
TEST_DIR = os.environ.get('TESTDIR', None)
if TEST_DIR:
    rsmtool_test_dir = TEST_DIR
else:
    from rsmtool.test_utils import rsmtool_test_dir


@parameterized([
    param('lr-no-standardization', 'lr_no_standardization'),
    param('lr-exclude-test-flags', 'lr_exclude_test_flags'),
    param('lr-exclude-train-and-test-flags', 'lr_exclude_train_and_test_flags'),
    param('lr-with-sas', 'lr_with_sas'),
    param('lr-with-xlsx-output', 'lr_with_xlsx_output', file_format='xlsx'),
    param('lr-with-tsv-output', 'lr_with_tsv_output', file_format='tsv'),
    param('lr-with-thumbnails', 'lr_with_thumbnails'),
    param('lr-with-thumbnails-subgroups', 'lr_with_thumbnails_subgroups', subgroups=['L1']),
    param('lr-with-feature-list', 'lr_with_feature_list'),
    param('lr-with-length-non-numeric', 'lr_with_length_non_numeric'),
    param('lr-with-feature-list-and-transformation', 'lr_with_feature_list_and_transformation')
])
def test_run_experiment_parameterized(*args, **kwargs):
    if TEST_DIR:
        kwargs['given_test_dir'] = TEST_DIR
    check_run_experiment(*args, **kwargs)
Esempio n. 19
0
from datetime import datetime, timedelta

from parameterized import param, parameterized

# BaseTestCase and the `date` module under test come from this test suite's
# shared imports, which are not shown in this snippet.


class TestDateRangeFunction(BaseTestCase):
    def setUp(self):
        super(TestDateRangeFunction, self).setUp()
        self.result = NotImplemented

    @parameterized.expand([
        param(begin=datetime(2014, 6, 15),
              end=datetime(2014, 6, 25),
              expected_length=10)
    ])
    def test_date_range(self, begin, end, expected_length):
        self.when_date_range_generated(begin, end)
        self.then_range_length_is(expected_length)
        self.then_all_dates_in_range_are_present(begin, end)
        self.then_range_is_in_ascending_order()

    @parameterized.expand([
        param(begin=datetime(2014, 4, 15),
              end=datetime(2014, 6, 25),
              expected_months=[(2014, 4), (2014, 5), (2014, 6)]),
        param(begin=datetime(2014, 4, 25),
              end=datetime(2014, 5, 5),
              expected_months=[(2014, 4), (2014, 5)]),
        param(begin=datetime(2014, 4, 5),
              end=datetime(2014, 4, 25),
              expected_months=[(2014, 4)]),
        param(begin=datetime(2014, 4, 25),
              end=datetime(2014, 6, 5),
              expected_months=[(2014, 4), (2014, 5), (2014, 6)]),
    ])
    def test_one_date_for_each_month(self, begin, end, expected_months):
        self.when_date_range_generated(begin, end, months=1)
        self.then_expected_months_are(expected_months)

    @parameterized.expand([
        'year',
        'month',
        'week',
        'day',
        'hour',
        'minute',
        'second',
    ])
    def test_should_reject_easily_mistaken_dateutil_arguments(
            self, invalid_period):
        self.when_date_range_generated(begin=datetime(2014, 6, 15),
                                       end=datetime(2014, 6, 25),
                                       **{invalid_period: 1})
        self.then_period_was_rejected(invalid_period)

    def when_date_range_generated(self, begin, end, **size):
        try:
            self.result = list(date.date_range(begin, end, **size))
        except Exception as error:
            self.error = error

    def then_expected_months_are(self, expected):
        self.assertEqual(expected, [(d.year, d.month) for d in self.result])

    def then_range_length_is(self, expected_length):
        self.assertEqual(expected_length, len(self.result))

    def then_all_dates_in_range_are_present(self, begin, end):
        date_under_test = begin
        while date_under_test < end:
            self.assertIn(date_under_test, self.result)
            date_under_test += timedelta(days=1)

    def then_range_is_in_ascending_order(self):
        for i in range(len(self.result) - 1):
            self.assertLess(self.result[i], self.result[i + 1])

    def then_period_was_rejected(self, period):
        self.then_error_was_raised(ValueError,
                                   ['Invalid argument: {}'.format(period)])
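# A standalone sketch of the function under test, consistent with the cases
# above: date_range yields day-spaced datetimes and excludes the end point.
def _example_date_range():
    return list(date.date_range(datetime(2014, 6, 15), datetime(2014, 6, 18)))
    # -> [datetime(2014, 6, 15), datetime(2014, 6, 16), datetime(2014, 6, 17)]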
Esempio n. 20
0
import os

from parameterized import param, parameterized

from rsmtool.test_utils import (check_file_output,
                                check_report,
                                check_run_evaluation,
                                do_run_evaluation)

# allow test directory to be set via an environment variable
# which is needed for package testing
TEST_DIR = os.environ.get('TESTDIR', None)
if TEST_DIR:
    rsmtool_test_dir = TEST_DIR
else:
    from rsmtool.test_utils import rsmtool_test_dir


@parameterized([
    param('lr-eval', 'lr_evaluation'),
    param('lr-eval-with-scaling', 'lr_evaluation_with_scaling'),
    param('lr-eval-exclude-listwise', 'lr_eval_exclude_listwise', subgroups=['QUESTION', 'L1']),
    param('lr-eval-exclude-flags', 'lr_eval_exclude_flags'),
    param('lr-eval-with-missing-scores', 'lr_eval_with_missing_scores', subgroups=['QUESTION', 'L1']),
    param('lr-eval-with-missing-data', 'lr_eval_with_missing_data', subgroups=['QUESTION', 'L1']),
    param('lr-eval-with-custom-order', 'lr_eval_with_custom_order', consistency=True),
    param('lr-eval-with-custom-sections', 'lr_eval_with_custom_sections'),
    param('lr-eval-with-custom-sections-and-order', 'lr_eval_with_custom_sections_and_order', subgroups=['QUESTION', 'L1']),
    param('lr-eval-tsv-input-files', 'lr_eval_tsv_input_files'),
    param('lr-eval-xlsx-input-files', 'lr_eval_xlsx_input_files'),
    param('lr-eval-with-tsv-output', 'lr_eval_with_tsv_output', file_format='tsv'),
    param('lr-eval-with-xlsx-output', 'lr_eval_with_xlsx_output', file_format='xlsx'),
    param('lr-eval-with-h2', 'lr_eval_with_h2', subgroups=['QUESTION', 'L1'], consistency=True),
    param('lr-eval-with-h2-named-sc1', 'lr_eval_with_h2_named_sc1', consistency=True),
    param('lr-eval-with-scaling-and-h2-keep-zeros', 'lr_eval_with_scaling_and_h2_keep_zeros', consistency=True)
])
def test_run_experiment_parameterized(*args, **kwargs):
    if TEST_DIR:
        kwargs['given_test_dir'] = TEST_DIR
    check_run_evaluation(*args, **kwargs)
Esempio n. 21
0
from datetime import datetime, timedelta

from parameterized import param, parameterized

# BaseTestCase and the `date` module under test come from this test suite's
# shared imports, which are not shown in this snippet.


class TestGetIntersectingPeriodsFunction(BaseTestCase):
    def setUp(self):
        super(TestGetIntersectingPeriodsFunction, self).setUp()
        self.result = NotImplemented

    @parameterized.expand([
        param(low=datetime(2014, 6, 15), high=datetime(2014, 6, 16), length=1)
    ])
    def test_date_arguments_and_date_range_with_default_post_days(
            self, low, high, length):
        self.when_intersecting_period_calculated(low, high, period_size='day')
        self.then_all_dates_in_range_are_present(begin=low, end=high)
        self.then_date_range_length_is(length)

    @parameterized.expand([
        param(low=datetime(2014, 4, 15),
              high=datetime(2014, 6, 25),
              expected_results=[
                  datetime(2014, 4, 1),
                  datetime(2014, 5, 1),
                  datetime(2014, 6, 1)
              ]),
        param(low=datetime(2014, 4, 25),
              high=datetime(2014, 5, 5),
              expected_results=[datetime(2014, 4, 1),
                                datetime(2014, 5, 1)]),
        param(low=datetime(2014, 4, 5),
              high=datetime(2014, 4, 25),
              expected_results=[datetime(2014, 4, 1)]),
        param(low=datetime(2014, 4, 25),
              high=datetime(2014, 6, 5),
              expected_results=[
                  datetime(2014, 4, 1),
                  datetime(2014, 5, 1),
                  datetime(2014, 6, 1)
              ]),
        param(low=datetime(2014, 4, 25),
              high=datetime(2014, 4, 25),
              expected_results=[]),
        param(low=datetime(2014, 12, 31),
              high=datetime(2015, 1, 1),
              expected_results=[datetime(2014, 12, 1)]),
    ])
    def test_dates_in_intersecting_period_should_use_first_day_when_period_is_month(
            self, low, high, expected_results):
        self.when_intersecting_period_calculated(low,
                                                 high,
                                                 period_size='month')
        self.then_results_are(expected_results)

    @parameterized.expand([
        param(low=datetime(2012, 4, 18),
              high=datetime(2014, 9, 22),
              expected_results=[
                  datetime(2012, 1, 1, 0, 0),
                  datetime(2013, 1, 1, 0, 0),
                  datetime(2014, 1, 1, 0, 0)
              ]),
        param(low=datetime(2013, 8, 5),
              high=datetime(2014, 5, 15),
              expected_results=[
                  datetime(2013, 1, 1, 0, 0),
                  datetime(2014, 1, 1, 0, 0)
              ]),
        param(low=datetime(2008, 4, 5),
              high=datetime(2010, 1, 1),
              expected_results=[
                  datetime(2008, 1, 1, 0, 0),
                  datetime(2009, 1, 1, 0, 0)
              ]),
        param(low=datetime(2014, 1, 1),
              high=datetime(2016, 8, 22),
              expected_results=[
                  datetime(2014, 1, 1, 0, 0),
                  datetime(2015, 1, 1, 0, 0),
                  datetime(2016, 1, 1, 0, 0)
              ]),
        param(low=datetime(2001, 7, 11),
              high=datetime(2001, 10, 16),
              expected_results=[datetime(2001, 1, 1, 0, 0)]),
        param(low=datetime(2017, 1, 1),
              high=datetime(2017, 1, 1),
              expected_results=[]),
    ])
    def test_dates_in_intersecting_period_should_use_first_month_and_first_day_when_period_is_year(
            self, low, high, expected_results):
        self.when_intersecting_period_calculated(low, high, period_size='year')
        self.then_results_are(expected_results)

    @parameterized.expand([
        param(low=datetime(2014, 4, 15),
              high=datetime(2014, 5, 15),
              period_size='month',
              expected_results=[datetime(2014, 4, 1),
                                datetime(2014, 5, 1)]),
        param(low=datetime(2014, 10, 30, 4, 30),
              high=datetime(2014, 11, 7, 5, 20),
              period_size='week',
              expected_results=[datetime(2014, 10, 27),
                                datetime(2014, 11, 3)]),
        param(low=datetime(2014, 8, 13, 13, 21),
              high=datetime(2014, 8, 14, 14, 7),
              period_size='day',
              expected_results=[datetime(2014, 8, 13),
                                datetime(2014, 8, 14)]),
        param(low=datetime(2014, 5, 11, 22, 4),
              high=datetime(2014, 5, 12, 0, 5),
              period_size='hour',
              expected_results=[
                  datetime(2014, 5, 11, 22, 0),
                  datetime(2014, 5, 11, 23, 0),
                  datetime(2014, 5, 12, 0, 0)
              ]),
        param(low=datetime(2014, 4, 25, 11, 11, 11),
              high=datetime(2014, 4, 25, 11, 12, 11),
              period_size='minute',
              expected_results=[
                  datetime(2014, 4, 25, 11, 11, 0),
                  datetime(2014, 4, 25, 11, 12, 0)
              ]),
        param(low=datetime(2014, 12, 31, 23, 59, 58, 500),
              high=datetime(2014, 12, 31, 23, 59, 59, 600),
              period_size='second',
              expected_results=[
                  datetime(2014, 12, 31, 23, 59, 58, 0),
                  datetime(2014, 12, 31, 23, 59, 59, 0)
              ]),
    ])
    def test_periods(self, low, high, period_size, expected_results):
        self.when_intersecting_period_calculated(low,
                                                 high,
                                                 period_size=period_size)
        self.then_results_are(expected_results)

    @parameterized.expand([
        param('years'),
        param('months'),
        param('days'),
        param('hours'),
        param('minutes'),
        param('seconds'),
        param('microseconds'),
        param('some_period'),
    ])
    def test_should_reject_easily_mistaken_dateutil_arguments(
            self, period_size):
        self.when_intersecting_period_calculated(low=datetime(2014, 6, 15),
                                                 high=datetime(2014, 6, 25),
                                                 period_size=period_size)
        self.then_error_was_raised(ValueError,
                                   ['Invalid period: ' + str(period_size)])

    @parameterized.expand([
        param(low=datetime(2014, 4, 15),
              high=datetime(2014, 4, 14),
              period_size='month'),
        param(low=datetime(2014, 4, 25),
              high=datetime(2014, 4, 25),
              period_size='month'),
    ])
    def test_empty_period(self, low, high, period_size):
        self.when_intersecting_period_calculated(low, high, period_size)
        self.then_period_is_empty()

    def when_intersecting_period_calculated(self, low, high, period_size):
        try:
            self.result = list(
                date.get_intersecting_periods(low, high, period=period_size))
        except Exception as error:
            self.error = error

    def then_results_are(self, expected_results):
        self.assertEqual(expected_results, self.result)

    def then_date_range_length_is(self, size):
        self.assertEqual(size, len(self.result))

    def then_all_dates_in_range_are_present(self, begin, end):
        date_under_test = begin
        while date_under_test < end:
            self.assertIn(date_under_test, self.result)
            date_under_test += timedelta(days=1)

    def then_period_is_empty(self):
        self.assertEqual([], self.result)
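# A standalone sketch mirroring one of the month-sized cases above: the result
# is the first day of each calendar month that the range intersects.
def _example_get_intersecting_periods():
    return list(date.get_intersecting_periods(
        datetime(2014, 4, 25), datetime(2014, 5, 5), period='month'))
    # -> [datetime(2014, 4, 1), datetime(2014, 5, 1)]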
Esempio n. 22
0
import os

from parameterized import param, parameterized

from rsmtool.test_utils import (check_file_output,
                                check_report,
                                check_run_summary,
                                do_run_summary)

# allow test directory to be set via an environment variable
# which is needed for package testing
TEST_DIR = os.environ.get('TESTDIR', None)
if TEST_DIR:
    rsmtool_test_dir = TEST_DIR
else:
    from rsmtool.test_utils import rsmtool_test_dir


@parameterized([
    param('lr-self-summary'),
    param('linearsvr-self-summary'),
    param('lr-self-eval-summary'),
    param('lr-self-summary-with-custom-sections'),
    param('lr-self-summary-with-tsv-inputs'),
    param('lr-self-summary-with-tsv-output', file_format='tsv'),
    param('lr-self-summary-with-xlsx-output', file_format='xlsx'),
    param('lr-self-summary-no-scaling')
])
def test_run_experiment_parameterized(*args, **kwargs):
    if TEST_DIR:
        kwargs['given_test_dir'] = TEST_DIR
    check_run_summary(*args, **kwargs)


def test_run_experiment_lr_summary_with_object():
Esempio n. 23
0
from datetime import datetime

from parameterized import param, parameterized

# apply_settings, test_function, DateDataParser, SettingValidationError, and
# BaseTestCase come from this test suite's shared imports, not shown here.


class InvalidSettingsTest(BaseTestCase):
    def setUp(self):
        super().setUp()

    def test_error_is_raised_when_none_is_passed_in_settings(self):
        test_func = apply_settings(test_function)
        with self.assertRaisesRegex(TypeError, r'Invalid.*None\}'):
            test_func(settings={'PREFER_DATES_FROM': None})

        with self.assertRaisesRegex(TypeError, r'Invalid.*None\}'):
            test_func(settings={'TIMEZONE': None})

        with self.assertRaisesRegex(TypeError, r'Invalid.*None\}'):
            test_func(settings={'TO_TIMEZONE': None})

    def test_error_is_raised_for_invalid_type_settings(self):
        test_func = apply_settings(test_function)
        try:
            test_func(settings=['current_period', False, 'current'])
        except Exception as error:
            self.error = error
            self.then_error_was_raised(TypeError, [
                "settings can only be either dict or instance of Settings class"
            ])

    def test_check_settings_wrong_setting_name(self):
        with self.assertRaisesRegex(SettingValidationError,
                                    r'.* is not a valid setting'):
            DateDataParser(settings={'AAAAA': 'foo'})

    @parameterized.expand([
        param('DATE_ORDER', 2, 'YYY', 'MDY'),
        param('TIMEZONE', False, '',
              'Europe/Madrid'),  # should we check valid timezones?
        param('TO_TIMEZONE', True, '',
              'Europe/Madrid'),  # should we check valid timezones?
        param('RETURN_AS_TIMEZONE_AWARE', 'false', '', True),
        param('PREFER_DAY_OF_MONTH', False, 'current_period', 'current'),
        param('PREFER_DATES_FROM', True, 'current', 'current_period'),
        param('RELATIVE_BASE', 'yesterday', '', datetime.now()),
        param('SKIP_TOKENS', 'foo', '', ['foo']),
        param('REQUIRE_PARTS', 'day', '', ['month', 'day']),
        param('PARSERS', 'absolute-time', '',
              ['absolute-time', 'no-spaces-time']),
        param('STRICT_PARSING', 'true', '', True),
        param('RETURN_TIME_AS_PERIOD', 'false', '', True),
        param('PREFER_LOCALE_DATE_ORDER', 'true', '', False),
        param('NORMALIZE', 'true', '', True),
        param('FUZZY', 'true', '', False),
        param('PREFER_LOCALE_DATE_ORDER', 'false', '', True),
    ])
    def test_check_settings(self, setting, wrong_type, wrong_value,
                            valid_value):
        with self.assertRaisesRegex(
                SettingValidationError, r'"{}" must be .*, not "{}".'.format(
                    setting,
                    type(wrong_type).__name__)):
            DateDataParser(settings={setting: wrong_type})

        if wrong_value:
            with self.assertRaisesRegex(
                    SettingValidationError,
                    r'"{}" is not a valid value for "{}", it should be: .*'.
                    format(
                        str(wrong_value).replace('[',
                                                 '\\[').replace(']', '\\]'),
                        setting)):
                DateDataParser(settings={setting: wrong_value})

        # check that a valid value doesn't raise an error
        assert DateDataParser(settings={setting: valid_value})

    def test_check_settings_extra_check_require_parts(self):
        with self.assertRaisesRegex(
                SettingValidationError,
                r'"REQUIRE_PARTS" setting contains invalid values: time'):
            DateDataParser(settings={'REQUIRE_PARTS': ['time', 'day']})
        with self.assertRaisesRegex(
                SettingValidationError,
                r'There are repeated values in the "REQUIRE_PARTS" setting'):
            DateDataParser(
                settings={'REQUIRE_PARTS': ['month', 'day', 'month']})

    def test_check_settings_extra_check_parsers(self):
        with self.assertRaisesRegex(
                SettingValidationError,
                r'Found unknown parsers in the "PARSERS" setting: no-spaces'):
            DateDataParser(
                settings={'PARSERS': ['absolute-time', 'no-spaces']})

        with self.assertRaisesRegex(
                SettingValidationError,
                r'There are repeated values in the "PARSERS" setting'):
            DateDataParser(settings={
                'PARSERS': ['absolute-time', 'timestamp', 'absolute-time']
            })
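# A sketch of a settings dict that the checks above accept as valid; the date
# string passed to get_date_data is illustrative.
def _example_valid_settings():
    parser = DateDataParser(settings={'DATE_ORDER': 'MDY',
                                      'PREFER_DAY_OF_MONTH': 'current',
                                      'STRICT_PARSING': True})
    return parser.get_date_data('03/05/2015')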
Esempio n. 24
0
import os

from parameterized import param, parameterized

from rsmtool.test_utils import (check_run_experiment,
                                do_run_experiment)

# allow test directory to be set via an environment variable
# which is needed for package testing
TEST_DIR = os.environ.get('TESTDIR', None)
if TEST_DIR:
    rsmtool_test_dir = TEST_DIR
else:
    from rsmtool.test_utils import rsmtool_test_dir


@parameterized([
    param('lr-with-h2-include-zeros', 'lr_with_h2_include_zeros', consistency=True),
    param('lr-with-h2-and-length', 'lr_with_h2_and_length', consistency=True),
    param('lr-with-h2-named-sc1', 'lr_with_h2_named_sc1', consistency=True),
    param('lars', 'Lars', skll=True),
    param('lars-custom-objective', 'Lars_custom_objective', skll=True),
    param('logistic-regression', 'LogisticRegression', skll=True),
    param('logistic-regression-custom-objective', 'LogisticRegression_custom_objective', skll=True),
    param('logistic-regression-expected-scores', 'LogisticRegression_expected_scores', skll=True),
    param('svc', 'SVC', skll=True),
    param('svc-custom-objective', 'SVC_custom_objective', skll=True),
    param('svc-expected-scores', 'SVC_expected_scores', skll=True),
    param('dummyregressor', 'DummyRegressor', skll=True),
    param('dummyregressor-custom-objective', 'DummyRegressor_custom_objective', skll=True),
    param('ridge', 'Ridge', skll=True),
    param('ridge-custom-objective', 'Ridge_custom_objective', skll=True),
    param('linearsvr', 'LinearSVR', skll=True)
])
def test_run_experiment_parameterized(*args, **kwargs):
    if TEST_DIR:
        kwargs['given_test_dir'] = TEST_DIR
    check_run_experiment(*args, **kwargs)
Esempio n. 25
0
from unittest import TestCase
from unittest.mock import Mock, patch  # the original module may use the external "mock" package

from parameterized import param, parameterized

# invoke_cli, STDIN_FILE_NAME, and the samcli exception classes used below are
# imported from samcli in the full test module; those import lines are not
# shown in this snippet.


class TestCli(TestCase):
    def setUp(self):
        self.function_id = "id"
        self.template = "template"
        self.eventfile = "eventfile"
        self.env_vars = "env-vars"
        self.debug_port = 123
        self.debug_args = "args"
        self.debugger_path = "/test/path"
        self.docker_volume_basedir = "basedir"
        self.docker_network = "network"
        self.log_file = "logfile"
        self.skip_pull_image = True
        self.no_event = False
        self.parameter_overrides = {}
        self.layer_cache_basedir = "/some/layers/path"
        self.force_image_build = True
        self.region_name = "region"

    @patch("samcli.commands.local.invoke.cli.InvokeContext")
    @patch("samcli.commands.local.invoke.cli._get_event")
    def test_cli_must_setup_context_and_invoke(self, get_event_mock,
                                               InvokeContextMock):
        event_data = "data"
        get_event_mock.return_value = event_data

        ctx_mock = Mock()
        ctx_mock.region = self.region_name

        # Mock the __enter__ method to return an object inside a context manager
        context_mock = Mock()
        InvokeContextMock.return_value.__enter__.return_value = context_mock

        invoke_cli(ctx=ctx_mock,
                   function_identifier=self.function_id,
                   template=self.template,
                   event=self.eventfile,
                   no_event=self.no_event,
                   env_vars=self.env_vars,
                   debug_port=self.debug_port,
                   debug_args=self.debug_args,
                   debugger_path=self.debugger_path,
                   docker_volume_basedir=self.docker_volume_basedir,
                   docker_network=self.docker_network,
                   log_file=self.log_file,
                   skip_pull_image=self.skip_pull_image,
                   parameter_overrides=self.parameter_overrides,
                   layer_cache_basedir=self.layer_cache_basedir,
                   force_image_build=self.force_image_build)

        InvokeContextMock.assert_called_with(
            template_file=self.template,
            function_identifier=self.function_id,
            env_vars_file=self.env_vars,
            docker_volume_basedir=self.docker_volume_basedir,
            docker_network=self.docker_network,
            log_file=self.log_file,
            skip_pull_image=self.skip_pull_image,
            debug_port=self.debug_port,
            debug_args=self.debug_args,
            debugger_path=self.debugger_path,
            parameter_overrides=self.parameter_overrides,
            layer_cache_basedir=self.layer_cache_basedir,
            force_image_build=self.force_image_build,
            aws_region=self.region_name)

        context_mock.local_lambda_runner.invoke.assert_called_with(
            context_mock.function_name,
            event=event_data,
            stdout=context_mock.stdout,
            stderr=context_mock.stderr)
        get_event_mock.assert_called_with(self.eventfile)

    @patch("samcli.commands.local.invoke.cli.InvokeContext")
    @patch("samcli.commands.local.invoke.cli._get_event")
    def test_cli_must_invoke_with_no_event(self, get_event_mock,
                                           InvokeContextMock):
        self.no_event = True

        ctx_mock = Mock()
        ctx_mock.region = self.region_name

        # Mock the __enter__ method to return an object inside a context manager
        context_mock = Mock()
        InvokeContextMock.return_value.__enter__.return_value = context_mock
        invoke_cli(ctx=ctx_mock,
                   function_identifier=self.function_id,
                   template=self.template,
                   event=STDIN_FILE_NAME,
                   no_event=self.no_event,
                   env_vars=self.env_vars,
                   debug_port=self.debug_port,
                   debug_args=self.debug_args,
                   debugger_path=self.debugger_path,
                   docker_volume_basedir=self.docker_volume_basedir,
                   docker_network=self.docker_network,
                   log_file=self.log_file,
                   skip_pull_image=self.skip_pull_image,
                   parameter_overrides=self.parameter_overrides,
                   layer_cache_basedir=self.layer_cache_basedir,
                   force_image_build=self.force_image_build)

        InvokeContextMock.assert_called_with(
            template_file=self.template,
            function_identifier=self.function_id,
            env_vars_file=self.env_vars,
            docker_volume_basedir=self.docker_volume_basedir,
            docker_network=self.docker_network,
            log_file=self.log_file,
            skip_pull_image=self.skip_pull_image,
            debug_port=self.debug_port,
            debug_args=self.debug_args,
            debugger_path=self.debugger_path,
            parameter_overrides=self.parameter_overrides,
            layer_cache_basedir=self.layer_cache_basedir,
            force_image_build=self.force_image_build,
            aws_region=self.region_name)

        context_mock.local_lambda_runner.invoke.assert_called_with(
            context_mock.function_name,
            event="{}",
            stdout=context_mock.stdout,
            stderr=context_mock.stderr)
        get_event_mock.assert_not_called()

    @patch("samcli.commands.local.invoke.cli.InvokeContext")
    @patch("samcli.commands.local.invoke.cli._get_event")
    def test_must_raise_user_exception_on_no_event_and_event(
            self, get_event_mock, InvokeContextMock):
        self.no_event = True

        ctx_mock = Mock()
        ctx_mock.region = self.region_name

        with self.assertRaises(UserException) as ex_ctx:

            invoke_cli(ctx=ctx_mock,
                       function_identifier=self.function_id,
                       template=self.template,
                       event=self.eventfile,
                       no_event=self.no_event,
                       env_vars=self.env_vars,
                       debug_port=self.debug_port,
                       debug_args=self.debug_args,
                       debugger_path=self.debugger_path,
                       docker_volume_basedir=self.docker_volume_basedir,
                       docker_network=self.docker_network,
                       log_file=self.log_file,
                       skip_pull_image=self.skip_pull_image,
                       parameter_overrides=self.parameter_overrides,
                       layer_cache_basedir=self.layer_cache_basedir,
                       force_image_build=self.force_image_build)

        msg = str(ex_ctx.exception)
        self.assertEqual(
            msg,
            "no_event and event cannot be used together. Please provide only one."
        )

    @parameterized.expand([
        param(FunctionNotFound("not found"),
              "Function id not found in template"),
        param(DockerImagePullFailedException("Failed to pull image"),
              "Failed to pull image")
    ])
    @patch("samcli.commands.local.invoke.cli.InvokeContext")
    @patch("samcli.commands.local.invoke.cli._get_event")
    def test_must_raise_user_exception_on_function_not_found(
            self, side_effect_exception, expected_exception_message,
            get_event_mock, InvokeContextMock):
        event_data = "data"
        get_event_mock.return_value = event_data

        ctx_mock = Mock()
        ctx_mock.region = self.region_name

        # Mock the __enter__ method to return an object inside a context manager
        context_mock = Mock()
        InvokeContextMock.return_value.__enter__.return_value = context_mock

        context_mock.local_lambda_runner.invoke.side_effect = side_effect_exception

        with self.assertRaises(UserException) as ex_ctx:

            invoke_cli(ctx=ctx_mock,
                       function_identifier=self.function_id,
                       template=self.template,
                       event=self.eventfile,
                       no_event=self.no_event,
                       env_vars=self.env_vars,
                       debug_port=self.debug_port,
                       debug_args=self.debug_args,
                       debugger_path=self.debugger_path,
                       docker_volume_basedir=self.docker_volume_basedir,
                       docker_network=self.docker_network,
                       log_file=self.log_file,
                       skip_pull_image=self.skip_pull_image,
                       parameter_overrides=self.parameter_overrides,
                       layer_cache_basedir=self.layer_cache_basedir,
                       force_image_build=self.force_image_build)

        msg = str(ex_ctx.exception)
        self.assertEqual(msg, expected_exception_message)

    @parameterized.expand([
        (InvalidSamDocumentException("bad template"), "bad template"),
        (InvalidLayerReference(), "Layer References need to be of type "
         "'AWS::Serverless::LayerVersion' or 'AWS::Lambda::LayerVersion'"),
        (DebuggingNotSupported("Debugging not supported"),
         "Debugging not supported")
    ])
    @patch("samcli.commands.local.invoke.cli.InvokeContext")
    @patch("samcli.commands.local.invoke.cli._get_event")
    def test_must_raise_user_exception_on_invalid_sam_template(
            self, exception_to_raise, exception_message, get_event_mock,
            InvokeContextMock):
        event_data = "data"
        get_event_mock.return_value = event_data

        ctx_mock = Mock()
        ctx_mock.region = self.region_name

        InvokeContextMock.side_effect = exception_to_raise

        with self.assertRaises(UserException) as ex_ctx:

            invoke_cli(ctx=ctx_mock,
                       function_identifier=self.function_id,
                       template=self.template,
                       event=self.eventfile,
                       no_event=self.no_event,
                       env_vars=self.env_vars,
                       debug_port=self.debug_port,
                       debug_args=self.debug_args,
                       debugger_path=self.debugger_path,
                       docker_volume_basedir=self.docker_volume_basedir,
                       docker_network=self.docker_network,
                       log_file=self.log_file,
                       skip_pull_image=self.skip_pull_image,
                       parameter_overrides=self.parameter_overrides,
                       layer_cache_basedir=self.layer_cache_basedir,
                       force_image_build=self.force_image_build)

        msg = str(ex_ctx.exception)
        self.assertEqual(msg, exception_message)

    @patch("samcli.commands.local.invoke.cli.InvokeContext")
    @patch("samcli.commands.local.invoke.cli._get_event")
    def test_must_raise_user_exception_on_invalid_env_vars(
            self, get_event_mock, InvokeContextMock):
        event_data = "data"
        get_event_mock.return_value = event_data

        ctx_mock = Mock()
        ctx_mock.region = self.region_name

        InvokeContextMock.side_effect = OverridesNotWellDefinedError(
            "bad env vars")

        with self.assertRaises(UserException) as ex_ctx:

            invoke_cli(ctx=ctx_mock,
                       function_identifier=self.function_id,
                       template=self.template,
                       event=self.eventfile,
                       no_event=self.no_event,
                       env_vars=self.env_vars,
                       debug_port=self.debug_port,
                       debug_args=self.debug_args,
                       debugger_path=self.debugger_path,
                       docker_volume_basedir=self.docker_volume_basedir,
                       docker_network=self.docker_network,
                       log_file=self.log_file,
                       skip_pull_image=self.skip_pull_image,
                       parameter_overrides=self.parameter_overrides,
                       layer_cache_basedir=self.layer_cache_basedir,
                       force_image_build=self.force_image_build)

        msg = str(ex_ctx.exception)
        self.assertEqual(msg, "bad env vars")
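
# A minimal sketch of the context-manager mocking pattern the tests above rely
# on: patch the class, then steer `return_value.__enter__.return_value` so the
# `with` block sees a controlled object. `Runner` and `run_all` are
# hypothetical stand-ins, not samcli APIs.
from unittest.mock import MagicMock, patch


class Runner:
    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return False

    def invoke(self):
        return "real result"


def run_all():
    # code under test: uses Runner as a context manager
    with Runner() as ctx:
        return ctx.invoke()


with patch(f"{__name__}.Runner") as RunnerMock:
    inner = MagicMock()
    inner.invoke.return_value = "mocked result"
    RunnerMock.return_value.__enter__.return_value = inner
    assert run_all() == "mocked result"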
Esempio n. 26
0
class TestBigQueryFileLoads(_TestCaseWithTempDirCleanUp):
    def test_records_traverse_transform_with_mocks(self):
        destination = 'project1:dataset1.table1'

        job_reference = bigquery_api.JobReference()
        job_reference.projectId = 'project1'
        job_reference.jobId = 'job_name1'
        result_job = bigquery_api.Job()
        result_job.jobReference = job_reference

        mock_job = mock.Mock()
        mock_job.status.state = 'DONE'
        mock_job.status.errorResult = None
        mock_job.jobReference = job_reference

        bq_client = mock.Mock()
        bq_client.jobs.Get.return_value = mock_job

        bq_client.jobs.Insert.return_value = result_job

        transform = bqfl.BigQueryBatchFileLoads(
            destination,
            custom_gcs_temp_location=self._new_tempdir(),
            test_client=bq_client,
            validate=False,
            temp_file_format=bigquery_tools.FileFormat.JSON)

        # Need to test this with the DirectRunner to avoid serializing mocks
        with TestPipeline('DirectRunner') as p:
            outputs = p | beam.Create(_ELEMENTS) | transform

            dest_files = outputs[
                bqfl.BigQueryBatchFileLoads.DESTINATION_FILE_PAIRS]
            dest_job = outputs[
                bqfl.BigQueryBatchFileLoads.DESTINATION_JOBID_PAIRS]

            jobs = dest_job | "GetJobs" >> beam.Map(lambda x: x[1])

            files = dest_files | "GetFiles" >> beam.Map(lambda x: x[1][0])
            destinations = (
                dest_files
                | "GetDests" >>
                beam.Map(lambda x:
                         (bigquery_tools.get_hashable_destination(x[0]), x[1]))
                | "GetUniques" >> combiners.Count.PerKey()
                | "GetFinalDests" >> beam.Keys())

            # All files exist
            _ = (files
                 | beam.Map(
                     lambda x: hamcrest_assert(os.path.exists(x), is_(True))))

            # One file per destination
            assert_that(files | combiners.Count.Globally(),
                        equal_to([1]),
                        label='CountFiles')

            assert_that(destinations,
                        equal_to([destination]),
                        label='CheckDestinations')

            assert_that(jobs, equal_to([job_reference]), label='CheckJobs')

    @mock.patch('time.sleep')
    def test_wait_for_job_completion(self, sleep_mock):
        job_references = [
            bigquery_api.JobReference(),
            bigquery_api.JobReference()
        ]
        job_references[0].projectId = 'project1'
        job_references[0].jobId = 'jobId1'
        job_references[1].projectId = 'project1'
        job_references[1].jobId = 'jobId2'

        job_1_waiting = mock.Mock()
        job_1_waiting.status.state = 'RUNNING'
        job_2_done = mock.Mock()
        job_2_done.status.state = 'DONE'
        job_2_done.status.errorResult = None

        job_1_done = mock.Mock()
        job_1_done.status.state = 'DONE'
        job_1_done.status.errorResult = None

        bq_client = mock.Mock()
        bq_client.jobs.Get.side_effect = [
            job_1_waiting, job_2_done, job_1_done, job_2_done
        ]

        waiting_dofn = bqfl.WaitForBQJobs(bq_client)

        dest_list = [(i, job) for i, job in enumerate(job_references)]

        with TestPipeline('DirectRunner') as p:
            references = beam.pvalue.AsList(
                p | 'job_ref' >> beam.Create(dest_list))
            outputs = (p | beam.Create([''])
                       | beam.ParDo(waiting_dofn, references))

            assert_that(outputs, equal_to(dest_list))

        sleep_mock.assert_called_once()

    @mock.patch('time.sleep')
    def test_one_job_failed_after_waiting(self, sleep_mock):
        job_references = [
            bigquery_api.JobReference(),
            bigquery_api.JobReference()
        ]
        job_references[0].projectId = 'project1'
        job_references[0].jobId = 'jobId1'
        job_references[1].projectId = 'project1'
        job_references[1].jobId = 'jobId2'

        job_1_waiting = mock.Mock()
        job_1_waiting.status.state = 'RUNNING'
        job_2_done = mock.Mock()
        job_2_done.status.state = 'DONE'
        job_2_done.status.errorResult = None

        job_1_error = mock.Mock()
        job_1_error.status.state = 'DONE'
        job_1_error.status.errorResult = 'Some problems happened'

        bq_client = mock.Mock()
        bq_client.jobs.Get.side_effect = [
            job_1_waiting, job_2_done, job_1_error, job_2_done
        ]

        waiting_dofn = bqfl.WaitForBQJobs(bq_client)

        dest_list = [(i, job) for i, job in enumerate(job_references)]

        with self.assertRaises(Exception):
            with TestPipeline('DirectRunner') as p:
                references = beam.pvalue.AsList(
                    p | 'job_ref' >> beam.Create(dest_list))
                _ = (p | beam.Create([''])
                     | beam.ParDo(waiting_dofn, references))

        sleep_mock.assert_called_once()

    def test_multiple_partition_files(self):
        destination = 'project1:dataset1.table1'

        job_reference = bigquery_api.JobReference()
        job_reference.projectId = 'project1'
        job_reference.jobId = 'job_name1'
        result_job = mock.Mock()
        result_job.jobReference = job_reference

        mock_job = mock.Mock()
        mock_job.status.state = 'DONE'
        mock_job.status.errorResult = None
        mock_job.jobReference = job_reference

        bq_client = mock.Mock()
        bq_client.jobs.Get.return_value = mock_job

        bq_client.jobs.Insert.return_value = result_job
        bq_client.tables.Delete.return_value = None

        with TestPipeline('DirectRunner') as p:
            outputs = (p
                       | beam.Create(_ELEMENTS, reshuffle=False)
                       | bqfl.BigQueryBatchFileLoads(
                           destination,
                           custom_gcs_temp_location=self._new_tempdir(),
                           test_client=bq_client,
                           validate=False,
                           temp_file_format=bigquery_tools.FileFormat.JSON,
                           max_file_size=45,
                           max_partition_size=80,
                           max_files_per_partition=2))

            dest_files = outputs[
                bqfl.BigQueryBatchFileLoads.DESTINATION_FILE_PAIRS]
            dest_load_jobs = outputs[
                bqfl.BigQueryBatchFileLoads.DESTINATION_JOBID_PAIRS]
            dest_copy_jobs = outputs[
                bqfl.BigQueryBatchFileLoads.DESTINATION_COPY_JOBID_PAIRS]

            load_jobs = dest_load_jobs | "GetLoadJobs" >> beam.Map(
                lambda x: x[1])
            copy_jobs = dest_copy_jobs | "GetCopyJobs" >> beam.Map(
                lambda x: x[1])

            files = dest_files | "GetFiles" >> beam.Map(lambda x: x[1][0])
            destinations = (
                dest_files
                | "GetDests" >>
                beam.Map(lambda x:
                         (bigquery_tools.get_hashable_destination(x[0]), x[1]))
                | "GetUniques" >> combiners.Count.PerKey()
                | "GetFinalDests" >> beam.Keys())

            # All files exist
            _ = (files
                 | beam.Map(
                     lambda x: hamcrest_assert(os.path.exists(x), is_(True))))

            # Multiple files per destination, because of the size limits above
            assert_that(files | "CountFiles" >> combiners.Count.Globally(),
                        equal_to([6]),
                        label='CheckFileCount')

            assert_that(destinations,
                        equal_to([destination]),
                        label='CheckDestinations')

            assert_that(load_jobs
                        | "CountLoadJobs" >> combiners.Count.Globally(),
                        equal_to([6]),
                        label='CheckLoadJobCount')
            assert_that(copy_jobs
                        | "CountCopyJobs" >> combiners.Count.Globally(),
                        equal_to([6]),
                        label='CheckCopyJobCount')

    @parameterized.expand([
        param(is_streaming=False, with_auto_sharding=False),
        param(is_streaming=True, with_auto_sharding=False),
        param(is_streaming=True, with_auto_sharding=True),
    ])
    def test_triggering_frequency(self, is_streaming, with_auto_sharding):
        destination = 'project1:dataset1.table1'

        job_reference = bigquery_api.JobReference()
        job_reference.projectId = 'project1'
        job_reference.jobId = 'job_name1'
        result_job = bigquery_api.Job()
        result_job.jobReference = job_reference

        mock_job = mock.Mock()
        mock_job.status.state = 'DONE'
        mock_job.status.errorResult = None
        mock_job.jobReference = job_reference

        bq_client = mock.Mock()
        bq_client.jobs.Get.return_value = mock_job
        bq_client.jobs.Insert.return_value = result_job

        # Insert a fake clock to work with auto-sharding, which needs a
        # processing-time timer.
        class _FakeClock(object):
            def __init__(self, now=None):
                # Avoid time.time() as a default argument: defaults are
                # evaluated once, at definition time, not per instance.
                self._now = now if now is not None else time.time()

            def __call__(self):
                return self._now

        start_time = timestamp.Timestamp(0)
        bq_client.test_clock = _FakeClock(now=start_time)

        triggering_frequency = 20 if is_streaming else None
        transform = bqfl.BigQueryBatchFileLoads(
            destination,
            custom_gcs_temp_location=self._new_tempdir(),
            test_client=bq_client,
            validate=False,
            temp_file_format=bigquery_tools.FileFormat.JSON,
            is_streaming_pipeline=is_streaming,
            triggering_frequency=triggering_frequency,
            with_auto_sharding=with_auto_sharding)

        # Need to test this with the DirectRunner to avoid serializing mocks
        with TestPipeline(
                runner='BundleBasedDirectRunner',
                options=StandardOptions(streaming=is_streaming)) as p:
            if is_streaming:
                _SIZE = len(_ELEMENTS)
                first_batch = [
                    TimestampedValue(value, start_time + i + 1)
                    for i, value in enumerate(_ELEMENTS[:_SIZE // 2])
                ]
                second_batch = [
                    TimestampedValue(value, start_time + _SIZE // 2 + i + 1)
                    for i, value in enumerate(_ELEMENTS[_SIZE // 2:])
                ]
                # Advance processing time between batches of input elements to fire the
                # user triggers. Intentionally advance the processing time twice for the
                # auto-sharding case since we need to first fire the timer and then
                # fire the trigger.
                test_stream = (
                    TestStream().advance_watermark_to(start_time).add_elements(
                        first_batch).advance_processing_time(30).
                    advance_processing_time(30).add_elements(second_batch).
                    advance_processing_time(30).advance_processing_time(
                        30).advance_watermark_to_infinity())
                input = p | test_stream
            else:
                input = p | beam.Create(_ELEMENTS)
            outputs = input | transform

            dest_files = outputs[
                bqfl.BigQueryBatchFileLoads.DESTINATION_FILE_PAIRS]
            dest_job = outputs[
                bqfl.BigQueryBatchFileLoads.DESTINATION_JOBID_PAIRS]

            files = dest_files | "GetFiles" >> beam.Map(lambda x: x[1][0])
            destinations = (
                dest_files
                | "GetDests" >>
                beam.Map(lambda x:
                         (bigquery_tools.get_hashable_destination(x[0]), x[1]))
                | "GetUniques" >> combiners.Count.PerKey()
                | "GetFinalDests" >> beam.Keys())
            jobs = dest_job | "GetJobs" >> beam.Map(lambda x: x[1])

            # Check that all files exist.
            _ = (files
                 | beam.Map(
                     lambda x: hamcrest_assert(os.path.exists(x), is_(True))))

            # Expect two load jobs to be generated in the streaming case due to
            # the triggering frequency. Grouping is per trigger, so we expect
            # two entries in the output as opposed to one.
            file_count = files | combiners.Count.Globally().without_defaults()
            expected_file_count = [1, 1] if is_streaming else [1]
            expected_destinations = [destination, destination
                                     ] if is_streaming else [destination]
            expected_jobs = [job_reference, job_reference
                             ] if is_streaming else [job_reference]
            assert_that(file_count,
                        equal_to(expected_file_count),
                        label='CountFiles')
            assert_that(destinations,
                        equal_to(expected_destinations),
                        label='CheckDestinations')
            assert_that(jobs, equal_to(expected_jobs), label='CheckJobs')
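
# A minimal, Beam-free sketch of the polling pattern behind the WaitForBQJobs
# tests above: `side_effect` yields a different mock per call, so a poller
# observes RUNNING on early calls and DONE afterwards. `wait_until_done` is a
# hypothetical helper, not the Beam implementation.
from unittest import mock


def wait_until_done(client, job_id, sleep=lambda s: None):
    while True:
        job = client.jobs.Get(job_id)
        if job.status.state == 'DONE':
            return job
        sleep(1)


running = mock.Mock()
running.status.state = 'RUNNING'
done = mock.Mock()
done.status.state = 'DONE'

bq_client = mock.Mock()
bq_client.jobs.Get.side_effect = [running, running, done]

assert wait_until_done(bq_client, 'job_name1').status.state == 'DONE'
assert bq_client.jobs.Get.call_count == 3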
Esempio n. 27
0
class TestJI(unittest.TestCase):
    """Test the JobInfo Module"""
    def setUp(self):
        self.jbi = JobInfo(jobID=123,
                           status="Failed",
                           tID=1234,
                           tType="MCReconstruction")
        self.diracAPI = Mock(name="dilcMock",
                             spec=DIRAC.Interfaces.API.Dirac.Dirac)
        self.jobMon = Mock(name="jobMonMock",
                           spec=DIRAC.WorkloadManagementSystem.Client.
                           JobMonitoringClient.JobMonitoringClient)
        self.jobMon.getInputData = Mock(return_value=S_OK([]))
        self.jobMon.getJobAttribute = Mock(return_value=S_OK("0"))
        self.jobMon.getJobParameter = Mock(return_value=S_OK({}))
        self.diracAPI.getJobJDL = Mock()

        self.jdl2 = {
            "LogTargetPath":
            "/ilc/prod/clic/500gev/yyveyx_o/ILD/REC/00006326/LOG/00006326_015.tar",
            "Executable":
            "dirac-jobexec",
            "TaskID":
            15,
            "SoftwareDistModule":
            "ILCDIRAC.Core.Utilities.CombinedSoftwareInstallation",
            "JobName":
            "00006326_00000015",
            "Priority":
            1,
            "Platform":
            "x86_64-slc5-gcc43-opt",
            "JobRequirements": {
                "OwnerDN":
                "/DC=ch/DC=cern/OU=Organic Units/OU=Users/CN=sailer/CN=683529/CN=Andre Sailer",
                "VirtualOrganization":
                "ilc",
                "Setup":
                "ILC-Production",
                "CPUTime":
                300000,
                "OwnerGroup":
                "ilc_prod",
                "Platforms": [
                    "x86_64-slc6-gcc44-opt",
                    "x86_64-slc5-gcc43-opt",
                    "slc5_amd64_gcc43",
                    "Linux_x86_64_glibc-2.12",
                    "Linux_x86_64_glibc-2.5",
                ],
                "UserPriority":
                1,
                "Sites": [
                    "LCG.LAPP.fr",
                    "LCG.UKI-SOUTHGRID-RALPP.uk",
                ],
                "BannedSites":
                "LCG.KEK.jp",
                "JobTypes":
                "MCReconstruction_Overlay",
            },
            "Arguments":
            "jobDescription.xml -o LogLevel=verbose",
            "SoftwarePackages": [
                "overlayinput.1",
                "marlin.v0111Prod",
            ],
            "DebugLFNs":
            "",
            "Status":
            "Created",
            "InputDataModule":
            "DIRAC.WorkloadManagementSystem.Client.InputDataResolution",
            "BannedSites":
            "LCG.KEK.jp",
            "LogLevel":
            "verbose",
            "InputSandbox": [
                "jobDescription.xml",
                "SB:ProductionSandboxSE2|/SandBox/i/ilc_prod/5d3/92f/5d392f5266a796018ab6774ef84cbd31.tar.bz2",
            ],
            "OwnerName":
            "sailer",
            "StdOutput":
            "std.out",
            "JobType":
            "MCReconstruction_Overlay",
            "GridEnv":
            "/cvmfs/grid.cern.ch/emi-ui-3.7.3-1_sl6v2/etc/profile.d/setup-emi3-ui-example",
            "TransformationID":
            6326,
            "DIRACSetup":
            "ILC-Production",
            "StdError":
            "std.err",
            "IS_PROD":
            "True",
            "OwnerDN":
            "/DC=ch/DC=cern/OU=Organic Units/OU=Users/CN=sailer/CN=683529/CN=Andre Sailer",
            "JobGroup":
            0o0006326,
            "OutputSandbox": [
                "std.err",
                "std.out",
            ],
            "JobID":
            15756436,
            "VirtualOrganization":
            "ilc",
            "ProductionOutputData": [
                "/ilc/prod/clic/500gev/yyveyx_o/ILD/REC/00006326/000/yyveyx_o_rec_6326_15.slcio",
                "/ilc/prod/clic/500gev/yyveyx_o/ILD/DST/00006326/000/yyveyx_o_dst_6326_15.slcio",
            ],
            "Site":
            "ANY",
            "OwnerGroup":
            "ilc_prod",
            "Owner":
            "sailer",
            "LogFilePath":
            "/ilc/prod/clic/500gev/yyveyx_o/ILD/REC/00006326/LOG/000",
            "InputData":
            "/ilc/prod/clic/500gev/yyveyx_o/ILD/SIM/00006325/000/yyveyx_o_sim_6325_17.slcio",
        }

        self.jdlBrokenContent = {
            "LogTargetPath":
            "/ilc/prod/clic/500gev/yyveyx_o/ILD/REC/00006326/LOG/00006326_015.tar",
            "Executable":
            "dirac-jobexec",
            "TaskID":
            "muahahaha",
            "SoftwareDistModule":
            "ILCDIRAC.Core.Utilities.CombinedSoftwareInstallation",
            "JobName":
            "00006326_00000015",
            "Priority":
            1,
            "Platform":
            "x86_64-slc5-gcc43-opt",
            "JobRequirements": {
                "OwnerDN":
                "/DC=ch/DC=cern/OU=Organic Units/OU=Users/CN=sailer/CN=683529/CN=Andre Sailer",
                "VirtualOrganization":
                "ilc",
                "Setup":
                "ILC-Production",
                "CPUTime":
                300000,
                "OwnerGroup":
                "ilc_prod",
                "Platforms": [
                    "x86_64-slc6-gcc44-opt",
                    "x86_64-slc5-gcc43-opt",
                    "slc5_amd64_gcc43",
                    "Linux_x86_64_glibc-2.12",
                    "Linux_x86_64_glibc-2.5",
                ],
                "UserPriority":
                1,
                "Sites": [
                    "LCG.LAPP.fr",
                    "LCG.UKI-SOUTHGRID-RALPP.uk",
                ],
                "BannedSites":
                "LCG.KEK.jp",
                "JobTypes":
                "MCReconstruction_Overlay",
            },
            "Arguments":
            "jobDescription.xml -o LogLevel=verbose",
            "SoftwarePackages": [
                "overlayinput.1",
                "marlin.v0111Prod",
            ],
            "DebugLFNs":
            "",
            "Status":
            "Created",
            "InputDataModule":
            "DIRAC.WorkloadManagementSystem.Client.InputDataResolution",
            "BannedSites":
            "LCG.KEK.jp",
            "LogLevel":
            "verbose",
            "InputSandbox": [
                "jobDescription.xml",
                "SB:ProductionSandboxSE2|/SandBox/i/ilc_prod/5d3/92f/5d392f5266a796018ab6774ef84cbd31.tar.bz2",
            ],
            "OwnerName":
            "sailer",
            "StdOutput":
            "std.out",
            "JobType":
            "MCReconstruction_Overlay",
            "GridEnv":
            "/cvmfs/grid.cern.ch/emi-ui-3.7.3-1_sl6v2/etc/profile.d/setup-emi3-ui-example",
            "TransformationID":
            6326,
            "DIRACSetup":
            "ILC-Production",
            "StdError":
            "std.err",
            "IS_PROD":
            "True",
            "OwnerDN":
            "/DC=ch/DC=cern/OU=Organic Units/OU=Users/CN=sailer/CN=683529/CN=Andre Sailer",
            "JobGroup":
            0o0006326,
            "OutputSandbox": [
                "std.err",
                "std.out",
            ],
            "JobID":
            15756436,
            "VirtualOrganization":
            "ilc",
            "ProductionOutputData": [
                "/ilc/prod/clic/500gev/yyveyx_o/ILD/REC/00006326/000/yyveyx_o_rec_6326_15.slcio",
                "/ilc/prod/clic/500gev/yyveyx_o/ILD/DST/00006326/000/yyveyx_o_dst_6326_15.slcio",
            ],
            "Site":
            "ANY",
            "OwnerGroup":
            "ilc_prod",
            "Owner":
            "sailer",
            "LogFilePath":
            "/ilc/prod/clic/500gev/yyveyx_o/ILD/REC/00006326/LOG/000",
            "InputData":
            "/ilc/prod/clic/500gev/yyveyx_o/ILD/SIM/00006325/000/yyveyx_o_sim_6325_17.slcio",
        }

        # jdl with a single output data file
        self.jdl1 = {
            "LogTargetPath":
            "/ilc/prod/clic/3tev/e1e1_o/SID/SIM/00006301/LOG/00006301_10256.tar",
            "Executable":
            "dirac-jobexec",
            "TaskID":
            10256,
            "SoftwareDistModule":
            "ILCDIRAC.Core.Utilities.CombinedSoftwareInstallation",
            "JobName":
            "00006301_00010256",
            "Priority":
            1,
            "Platform":
            "x86_64-slc5-gcc43-opt",
            "JobRequirements": {
                "OwnerDN":
                "/DC=ch/DC=cern/OU=Organic Units/OU=Users/CN=sailer/CN=683529/CN=Andre Sailer",
                "VirtualOrganization":
                "ilc",
                "Setup":
                "ILC-Production",
                "CPUTime":
                300000,
                "OwnerGroup":
                "ilc_prod",
                "Platforms": [
                    "x86_64-slc6-gcc44-opt",
                    "x86_64-slc5-gcc43-opt",
                    "slc5_amd64_gcc43",
                    "Linux_x86_64_glibc-2.12",
                    "Linux_x86_64_glibc-2.5",
                ],
                "UserPriority":
                1,
                "Sites": [
                    "LCG.LAPP.fr",
                    "LCG.UKI-SOUTHGRID-RALPP.uk",
                ],
                "BannedSites": [
                    "OSG.MIT.us",
                    "OSG.SPRACE.br",
                ],
                "JobTypes":
                "MCSimulation",
            },
            "Arguments":
            "jobDescription.xml -o LogLevel=verbose",
            "SoftwarePackages":
            "slic.v2r9p8",
            "DebugLFNs":
            "",
            "Status":
            "Created",
            "InputDataModule":
            "DIRAC.WorkloadManagementSystem.Client.InputDataResolution",
            "BannedSites": [
                "OSG.MIT.us",
                "OSG.SPRACE.br",
            ],
            "LogLevel":
            "verbose",
            "InputSandbox": [
                "jobDescription.xml",
                "SB:ProductionSandboxSE2|/SandBox/i/ilc_prod/042/d64/042d64cb0fe73720cbd114a73506c582.tar.bz2",
            ],
            "OwnerName":
            "sailer",
            "StdOutput":
            "std.out",
            "JobType":
            "MCSimulation",
            "GridEnv":
            "/cvmfs/grid.cern.ch/emi-ui-3.7.3-1_sl6v2/etc/profile.d/setup-emi3-ui-example",
            "TransformationID":
            6301,
            "DIRACSetup":
            "ILC-Production",
            "StdError":
            "std.err",
            "IS_PROD":
            "True",
            "OwnerDN":
            "/DC=ch/DC=cern/OU=Organic Units/OU=Users/CN=sailer/CN=683529/CN=Andre Sailer",
            "JobGroup":
            "00006301",
            "OutputSandbox": [
                "std.err",
                "std.out",
            ],
            "JobID":
            15756456,
            "VirtualOrganization":
            "ilc",
            "ProductionOutputData":
            "/ilc/prod/clic/3tev/e1e1_o/SID/SIM/00006301/010/e1e1_o_sim_6301_10256.slcio",
            "Site":
            "ANY",
            "OwnerGroup":
            "ilc_prod",
            "Owner":
            "sailer",
            "LogFilePath":
            "/ilc/prod/clic/3tev/e1e1_o/SID/SIM/00006301/LOG/010",
            "InputData":
            "/ilc/prod/clic/3tev/e1e1_o/gen/00006300/004/e1e1_o_gen_6300_4077.stdhep",
        }

        self.jdlNoInput = {
            "LogTargetPath":
            "/ilc/prod/clic/1.4tev/ea_qqqqnu/gen/00006498/LOG/00006498_1307.tar",
            "Executable":
            "dirac-jobexec",
            "TaskID":
            1307,
            "SoftwareDistModule":
            "ILCDIRAC.Core.Utilities.CombinedSoftwareInstallation",
            "JobName":
            "00006498_00001307",
            "Priority":
            1,
            "Platform":
            "x86_64-slc5-gcc43-opt",
            "JobRequirements": {
                "OwnerDN":
                "/DC=ch/DC=cern/OU=Organic Units/OU=Users/CN=sailer/CN=683529/CN=Andre Sailer",
                "VirtualOrganization":
                "ilc",
                "Setup":
                "ILC-Production",
                "CPUTime":
                300000,
                "OwnerGroup":
                "ilc_prod",
                "Platforms": [
                    "x86_64-slc6-gcc44-opt",
                    "x86_64-slc5-gcc43-opt",
                    "slc5_amd64_gcc43",
                    "Linux_x86_64_glibc-2.12",
                    "Linux_x86_64_glibc-2.5",
                ],
                "UserPriority":
                1,
                "BannedSites":
                "LCG.KEK.jp",
                "JobTypes":
                "MCGeneration",
            },
            "Arguments":
            "jobDescription.xml -o LogLevel=verbose",
            "SoftwarePackages":
            "whizard.SM_V57",
            "DebugLFNs":
            "",
            "Status":
            "Created",
            "InputDataModule":
            "DIRAC.WorkloadManagementSystem.Client.InputDataResolution",
            "BannedSites":
            "LCG.KEK.jp",
            "LogLevel":
            "verbose",
            "InputSandbox": [
                "jobDescription.xml",
                "SB:ProductionSandboxSE2|/SandBox/i/ilc_prod/b2a/d98/b2ad98c3e240361a4253c4bb277be478.tar.bz2",
            ],
            "OwnerName":
            "sailer",
            "StdOutput":
            "std.out",
            "JobType":
            "MCGeneration",
            "GridEnv":
            "/cvmfs/grid.cern.ch/emi-ui-3.7.3-1_sl6v2/etc/profile.d/setup-emi3-ui-example",
            "TransformationID":
            6498,
            "DIRACSetup":
            "ILC-Production",
            "StdError":
            "std.err",
            "IS_PROD":
            "True",
            "OwnerDN":
            "/DC=ch/DC=cern/OU=Organic Units/OU=Users/CN=sailer/CN=683529/CN=Andre Sailer",
            "JobGroup":
            "00006498",
            "OutputSandbox": [
                "std.err",
                "std.out",
            ],
            "JobID":
            15762268,
            "VirtualOrganization":
            "ilc",
            "ProductionOutputData":
            "/ilc/prod/clic/1.4tev/ea_qqqqnu/gen/00006498/001/ea_qqqqnu_gen_6498_1307.stdhep",
            "Site":
            "ANY",
            "OwnerGroup":
            "ilc_prod",
            "Owner":
            "sailer",
            "LogFilePath":
            "/ilc/prod/clic/1.4tev/ea_qqqqnu/gen/00006498/LOG/001",
            "InputData":
            "",
        }

    def tearDown(self):
        pass

    def test_Init(self):
        """Transformation.Utilities.JobInfo init ...................................................."""
        assert self.jbi.outputFiles == []
        self.assertFalse(self.jbi.pendingRequest)

    def test_allFilesExist(self):
        """Transformation.Utilities.JobInfo.allFilesExist............................................"""
        self.jbi.outputFileStatus = ["Exists", "Exists"]
        self.assertTrue(self.jbi.allFilesExist())
        self.jbi.outputFileStatus = ["Exists", "Missing"]
        self.assertFalse(self.jbi.allFilesExist())
        self.jbi.outputFileStatus = ["Missing", "Exists"]
        self.assertFalse(self.jbi.allFilesExist())
        self.jbi.outputFileStatus = ["Missing", "Missing"]
        self.assertFalse(self.jbi.allFilesExist())
        self.jbi.outputFileStatus = []
        self.assertFalse(self.jbi.allFilesExist())

    def test_allFilesMissing(self):
        """Transformation.Utilities.JobInfo.allFilesMissing.........................................."""
        self.jbi.outputFileStatus = ["Exists", "Exists"]
        self.assertFalse(self.jbi.allFilesMissing())
        self.jbi.outputFileStatus = ["Exists", "Missing"]
        self.assertFalse(self.jbi.allFilesMissing())
        self.jbi.outputFileStatus = ["Missing", "Exists"]
        self.assertFalse(self.jbi.allFilesMissing())
        self.jbi.outputFileStatus = ["Missing", "Missing"]
        self.assertTrue(self.jbi.allFilesMissing())
        self.jbi.outputFileStatus = []
        self.assertFalse(self.jbi.allFilesMissing())

    @parameterized.expand([
        ("someFilesMissing", "outputFileStatus", ["Exists", "Exists"], False),
        ("someFilesMissing", "outputFileStatus", ["Exists", "Missing"], True),
        ("someFilesMissing", "outputFileStatus", ["Missing", "Exists"], True),
        ("someFilesMissing", "outputFileStatus", ["Missing",
                                                  "Missing"], False),
        ("someFilesMissing", "outputFileStatus", [], False),
        ("allInputFilesExist", "inputFileStatus", ["Exists", "Exists"], True),
        ("allInputFilesExist", "inputFileStatus", ["Exists",
                                                   "Missing"], False),
        ("allInputFilesExist", "inputFileStatus", ["Missing",
                                                   "Missing"], False),
        ("allInputFilesExist", "inputFileStatus", [], False),
        ("allInputFilesMissing", "inputFileStatus", ["Exists",
                                                     "Exists"], False),
        ("allInputFilesMissing", "inputFileStatus", ["Exists",
                                                     "Missing"], False),
        ("allInputFilesMissing", "inputFileStatus", ["Missing",
                                                     "Missing"], True),
        ("allInputFilesMissing", "inputFileStatus", [], False),
        ("someInputFilesMissing", "inputFileStatus", ["Exists",
                                                      "Exists"], False),
        ("someInputFilesMissing", "inputFileStatus", ["Exists",
                                                      "Missing"], True),
        ("someInputFilesMissing", "inputFileStatus", ["Missing",
                                                      "Exists"], True),
        ("someInputFilesMissing", "inputFileStatus", ["Missing",
                                                      "Missing"], False),
        ("someInputFilesMissing", "inputFileStatus", [], False),
        ("allFilesProcessed", "transFileStatus", ["Processed",
                                                  "Processed"], True),
        ("allFilesProcessed", "transFileStatus", ["Processed",
                                                  "Assigned"], False),
        ("allFilesProcessed", "transFileStatus", ["Assigned",
                                                  "Assigned"], False),
        ("allFilesProcessed", "transFileStatus", ["Deleted",
                                                  "Deleted"], False),
        ("allFilesProcessed", "transFileStatus", ["Unused", "Unused"], False),
        ("allFilesProcessed", "transFileStatus", [], False),
        ("allFilesAssigned", "transFileStatus", ["Processed",
                                                 "Processed"], True),
        ("allFilesAssigned", "transFileStatus", ["Processed",
                                                 "Assigned"], True),
        ("allFilesAssigned", "transFileStatus", ["Assigned",
                                                 "Assigned"], True),
        ("allFilesAssigned", "transFileStatus", ["Assigned", "Unused"], False),
        ("allFilesAssigned", "transFileStatus", ["Deleted", "Deleted"], False),
        ("allFilesAssigned", "transFileStatus", ["Unused", "Unused"], False),
        ("allFilesAssigned", "transFileStatus", [], False),
        ("checkErrorCount", "errorCounts", [0, 9], False),
        ("checkErrorCount", "errorCounts", [0, 10], False),
        ("checkErrorCount", "errorCounts", [0, 11], True),
        ("checkErrorCount", "errorCounts", [0, 12], True),
        ("allTransFilesDeleted", "transFileStatus", ["Deleted",
                                                     "Deleted"], True),
        ("allTransFilesDeleted", "transFileStatus", ["Deleted",
                                                     "Assigned"], False),
        ("allTransFilesDeleted", "transFileStatus", ["Assigned",
                                                     "Deleted"], False),
        ("allTransFilesDeleted", "transFileStatus", ["Assigned",
                                                     "Assigned"], False),
    ])
    def test_fileChecker(self, func, attr, value, expected):
        setattr(self.jbi, attr, value)
        gLogger.notice(
            "%s, %s, %s, %s, %s" %
            (getattr(self.jbi, func)(), func, attr, value, expected))
        assert expected == getattr(self.jbi, func)()

    def test_getJDL(self):
        """Transformation.Utilities.JobInfo.getJDL..................................................."""

        self.diracAPI.getJobJDL.return_value = S_OK(self.jdl1)
        jdlList = self.jbi._JobInfo__getJDL(self.diracAPI)
        self.assertIsInstance(jdlList, dict)

        self.diracAPI.getJobJDL.return_value = S_ERROR("no mon")
        with self.assertRaises(RuntimeError) as contextManagedException:
            jdlList = self.jbi._JobInfo__getJDL(self.diracAPI)
        self.assertIn("Failed to get jobJDL",
                      str(contextManagedException.exception))

    def test_getTaskInfo_1(self):
        # task is only one
        wit = ["MCReconstruction"]
        self.jbi.taskID = 1234
        self.jbi.inputFiles = ["lfn"]
        tasksDict = {
            1234:
            [dict(FileID=123456, LFN="lfn", Status="Assigned", ErrorCount=7)]
        }
        lfnTaskDict = {}
        self.jbi.getTaskInfo(tasksDict, lfnTaskDict, wit)
        self.assertEqual(self.jbi.transFileStatus, ["Assigned"])
        self.assertEqual(self.jbi.otherTasks, [])

    def test_getTaskInfo_2(self):
        # there are other tasks
        wit = ["MCReconstruction"]
        self.jbi.taskID = 1234
        self.jbi.inputFiles = ["lfn"]
        tasksDict = {
            12:
            [dict(FileID=123456, LFN="lfn", Status="Processed", ErrorCount=7)]
        }
        lfnTaskDict = {"lfn": 12}
        self.jbi.getTaskInfo(tasksDict, lfnTaskDict, wit)
        self.assertEqual(self.jbi.transFileStatus, ["Processed"])
        self.assertEqual(self.jbi.otherTasks, [12])

    def test_getTaskInfo_3(self):
        # raise
        wit = ["MCReconstruction"]
        self.jbi.taskID = 1234
        self.jbi.inputFiles = ["otherLFN"]
        tasksDict = {
            1234: [
                dict(FileID=123456,
                     LFN="lfn",
                     Status="Processed",
                     ErrorCount=23)
            ]
        }
        lfnTaskDict = {}
        with self.assertRaisesRegex(TaskInfoException,
                                    "InputFiles do not agree"):
            self.jbi.getTaskInfo(tasksDict, lfnTaskDict, wit)

    # def test_getTaskInfo_4(self):
    #   # raise keyError
    #   wit = ['MCReconstruction']
    #   self.jbi.taskID = 1235
    #   self.jbi.inputFiles = []
    #   tasksDict = {1234: dict(FileID=123456, LFN="lfn", Status="Processed")}
    #   lfnTaskDict = {}
    #   with self.assertRaisesRegex(KeyError, ""):
    #     self.jbi.getTaskInfo(tasksDict, lfnTaskDict, wit)

    def test_getTaskInfo_5(self):
        # raise inputFile
        wit = ["MCReconstruction"]
        self.jbi.taskID = 1235
        self.jbi.inputFiles = []
        tasksDict = {1234: dict(FileID=123456, LFN="lfn", Status="Processed")}
        lfnTaskDict = {}
        with self.assertRaisesRegex(TaskInfoException, "InputFiles is empty"):
            self.jbi.getTaskInfo(tasksDict, lfnTaskDict, wit)

    def test_getJobInformation(self):
        """Transformation.Utilities.JobInfo.getJobInformation........................................"""
        self.diracAPI.getJobJDL.return_value = S_OK(self.jdl1)
        self.jbi.getJobInformation(self.diracAPI, self.jobMon)
        self.assertEqual(self.jbi.outputFiles, [
            "/ilc/prod/clic/3tev/e1e1_o/SID/SIM/00006301/010/e1e1_o_sim_6301_10256.slcio"
        ])
        self.assertEqual(10256, self.jbi.taskID)
        self.assertEqual(self.jbi.inputFiles, [
            "/ilc/prod/clic/3tev/e1e1_o/gen/00006300/004/e1e1_o_gen_6300_4077.stdhep"
        ])

        # empty jdl
        self.setUp()
        self.diracAPI.getJobJDL.return_value = S_OK({})
        self.jbi.getJobInformation(self.diracAPI, self.jobMon)
        self.assertEqual(self.jbi.outputFiles, [])
        self.assertIsNone(self.jbi.taskID)
        self.assertEqual(self.jbi.inputFiles, [])

    def test_getOutputFiles(self):
        """Transformation.Utilities.JobInfo.getOutputFiles..........................................."""
        # singleLFN
        self.diracAPI.getJobJDL.return_value = S_OK(self.jdl1)
        jdlList = self.jbi._JobInfo__getJDL(self.diracAPI)
        self.jbi._JobInfo__getOutputFiles(jdlList)
        self.assertEqual(self.jbi.outputFiles, [
            "/ilc/prod/clic/3tev/e1e1_o/SID/SIM/00006301/010/e1e1_o_sim_6301_10256.slcio"
        ])

        # two LFNs
        self.diracAPI.getJobJDL.return_value = S_OK(self.jdl2)
        jdlList = self.jbi._JobInfo__getJDL(self.diracAPI)
        self.jbi._JobInfo__getOutputFiles(jdlList)
        self.assertEqual(
            self.jbi.outputFiles,
            [
                "/ilc/prod/clic/500gev/yyveyx_o/ILD/REC/00006326/000/yyveyx_o_rec_6326_15.slcio",
                "/ilc/prod/clic/500gev/yyveyx_o/ILD/DST/00006326/000/yyveyx_o_dst_6326_15.slcio",
            ],
        )

    def test_getTaskID(self):
        """Transformation.Utilities.JobInfo.getTaskID................................................"""
        # singleLFN
        self.diracAPI.getJobJDL.return_value = S_OK(self.jdl1)
        jdlList = self.jbi._JobInfo__getJDL(self.diracAPI)
        self.jbi._JobInfo__getTaskID(jdlList)
        self.assertEqual(10256, self.jbi.taskID)

        # broken jdl; capture printed output and restore stdout afterwards
        out = StringIO()
        sys.stdout = out
        try:
            self.diracAPI.getJobJDL.return_value = S_OK(self.jdlBrokenContent)
            jdlList = self.jbi._JobInfo__getJDL(self.diracAPI)
            with self.assertRaises(ValueError):
                self.jbi._JobInfo__getTaskID(jdlList)
        finally:
            sys.stdout = sys.__stdout__

    def test_getInputFile(self):
        """Test the extraction of the inputFile from the JDL parameters."""
        # singleLFN
        self.jbi._JobInfo__getInputFile({"InputData": "/single/lfn2"})
        self.assertEqual(self.jbi.inputFiles, ["/single/lfn2"])

        # list with singleLFN
        self.jbi._JobInfo__getInputFile({"InputData": ["/single/lfn1"]})
        self.assertEqual(self.jbi.inputFiles, ["/single/lfn1"])

        # list with two LFN
        self.jbi._JobInfo__getInputFile({"InputData": ["/lfn1", "/lfn2"]})
        self.assertEqual(self.jbi.inputFiles, ["/lfn1", "/lfn2"])

    def test_checkFileExistence(self):
        """Transformation.Utilities.JobInfo.checkFileExistance......................................."""
        # input and output files
        repStatus = {
            "inputFile1": True,
            "inputFile2": False,
            "outputFile1": False,
            "outputFile2": True
        }
        self.jbi.inputFiles = ["inputFile1", "inputFile2", "inputFile3"]
        self.jbi.outputFiles = ["outputFile1", "outputFile2", "unknownFile"]
        self.jbi.checkFileExistence(repStatus)
        self.assertTrue(self.jbi.inputFilesExist[0])
        self.assertFalse(self.jbi.inputFilesExist[1])
        self.assertFalse(self.jbi.inputFilesExist[2])
        self.assertEqual(self.jbi.inputFileStatus,
                         ["Exists", "Missing", "Unknown"])
        self.assertEqual(self.jbi.outputFileStatus,
                         ["Missing", "Exists", "Unknown"])

        # just output files
        self.setUp()
        repStatus = {
            "inputFile": True,
            "outputFile1": False,
            "outputFile2": True
        }
        self.jbi.inputFiles = []
        self.jbi.outputFiles = ["outputFile1", "outputFile2", "unknownFile"]
        self.jbi.checkFileExistence(repStatus)
        self.assertEqual(self.jbi.outputFileStatus,
                         ["Missing", "Exists", "Unknown"])

    @parameterized.expand([
        param(
            [
                "123: Failed MCReconstruction Transformation: 1234 -- 5678 ",
                "inputFile (True, Assigned, Errors 0"
            ],
            [],
        ),
        param([
            "123: Failed MCReconstruction Transformation: 1234 -- 5678  (Last task [7777])"
        ], [],
              otherTasks=[7777]),
        param([], ["MCReconstruction Transformation"], trID=0, taID=0),
        param([], ["(Last task"], otherTasks=[]),
        param(
            ["PENDING REQUEST IGNORE THIS JOB"],
            [],
            pendingRequest=True,
        ),
        param(
            ["No Pending Requests"],
            [],
            pendingRequest=False,
        ),
    ])
    def test__str__(self,
                    asserts,
                    assertNots,
                    trID=1234,
                    taID=5678,
                    otherTasks=False,
                    pendingRequest=False):
        jbi = JobInfo(jobID=123,
                      status="Failed",
                      tID=trID,
                      tType="MCReconstruction")
        jbi.pendingRequest = pendingRequest
        jbi.otherTasks = otherTasks
        gLogger.notice("otherTasks: ", jbi.otherTasks)
        jbi.taskID = taID
        jbi.inputFiles = ["inputFile"]
        jbi.inputFilesExist = [True]
        jbi.transFileStatus = ["Assigned"]
        jbi.outputFiles = ["outputFile"]
        jbi.errorCounts = [0]
        info = str(jbi)
        for assertStr in asserts:
            self.assertIn(assertStr, info)
        for assertStr in assertNots:
            self.assertNotIn(assertStr, info)

    def test_TaskInfoException(self):
        """Transformation.Utilities.JobInfo.TaskInfoException........................................"""
        tie = TaskInfoException("notTasked")
        self.assertIsInstance(tie, Exception)
        self.assertIn("notTasked", str(tie))
Esempio n. 28
0
import json
import os

from parameterized import param, parameterized


def _load_params(path):
    with open(path, 'r') as file:
        return [param(json.loads(line)) for line in file]
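
# _load_params expects a JSON-lines file: one JSON object per line, each
# wrapped into a param(...) case. A short usage sketch with a temporary file
# (the cases themselves are made up for illustration):
import tempfile

_cases = [{"tag": "a", "expected": 1}, {"tag": "b", "expected": 2}]
with tempfile.NamedTemporaryFile('w', suffix='.jsonl', delete=False) as fh:
    for _case in _cases:
        fh.write(json.dumps(_case) + '\n')

assert len(_load_params(fh.name)) == 2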
from rsmtool.configuration_parser import Configuration
from rsmtool.test_utils import (check_file_output, check_report,
                                check_run_evaluation, copy_data_files,
                                do_run_evaluation)

# allow test directory to be set via an environment variable
# which is needed for package testing
TEST_DIR = os.environ.get('TESTDIR', None)
if TEST_DIR:
    rsmtool_test_dir = TEST_DIR
else:
    from rsmtool.test_utils import rsmtool_test_dir


@parameterized([
    param('lr-eval', 'lr_evaluation'),
    param('lr-eval-with-scaling', 'lr_evaluation_with_scaling'),
    param('lr-eval-exclude-listwise',
          'lr_eval_exclude_listwise',
          subgroups=['QUESTION', 'L1']),
    param('lr-eval-exclude-flags', 'lr_eval_exclude_flags'),
    param('lr-eval-with-missing-scores',
          'lr_eval_with_missing_scores',
          subgroups=['QUESTION', 'L1']),
    param('lr-eval-with-missing-data',
          'lr_eval_with_missing_data',
          subgroups=['QUESTION', 'L1']),
    param('lr-eval-with-custom-order',
          'lr_eval_with_custom_order',
          consistency=True),
    param('lr-eval-with-custom-sections', 'lr_eval_with_custom_sections'),
Esempio n. 30
0
class TestAES(TestCase):
    """
    测试aes加解密
    """
    token = "thisisatoken"
    aes_key = "jWmYm7qr5nMoAUwZRjGtBxmz3KA1tkAj3ykkR6q2B2C"
    input = "hello world 哈喽 "
    receive_id = "P00000000023"

    def crypt(self):
        return WXBizMsgCrypt(self.token, self.aes_key, self.receive_id)

    def test_encode(self):
        """
        加密
        :return:
        """
        nonce = "EDH75AiKExevY8L3"
        timestamp = "1592916319"
        ret, msg_data = self.crypt().EncryptMsg(self.input, nonce, timestamp)
        assert ret == WXBizMsgCrypt_OK, "msg_data:" + msg_data
        print(msg_data)

    def test_encode_and_decode(self):
        """
        加解密
        :return:
        """
        nonce = "EDH75AiKExevY8L3"
        timestamp = "1592916319"
        ret, msg_data = self.crypt().EncryptMsg(self.input, nonce, timestamp)
        assert ret == WXBizMsgCrypt_OK, "ret:%d" % ret
        assert msg_data is not None, "msg_data must not be None"
        ret, msg_data = self.crypt().DecryptMsgData(msg_data)
        assert ret == WXBizMsgCrypt_OK, f"ret!={ret}"
        if isinstance(msg_data, (bytes, )):
            msg_data = msg_data.decode()  # decode bytes so the Chinese text displays correctly
        print(msg_data)

    @parameterized.expand([
        param(
            "CZWs4CWRpI4VolQlvn4dlPb2f0uQxokbSIZGwiT1u44MCk4o6Iw6R/zVkcFkIdMCxD99C7nyOfERckg+HnjgMw==",
            "d8a328f4957174415ac09446f690be0f491c0895", "1593325356",
            "EDH75AiKExevY8L3"),
        param(
            "RSZ3zTLyHduRfE+eJ07Mi5HFE/5TXGaXAumHoli36eFMpdblxY2dqR+Vhv7yfKCkTsZLSkIeecEGCVyMQ76u7A==",
            "ff310278354c75a4235c81d2f95862416f67dbee", "1593333518",
            "lvsomfeujrn6bhdg")
    ])
    def test_decode(self, encrypt, signature, timestamp, nonce):
        """
        Decrypt a fixed payload, e.g.:
        {
            "Encrypt": "CZWs4CWRpI4VolQlvn4dlPb2f0uQxokbSIZGwiT1u44MCk4o6Iw6R/zVkcFkIdMCxD99C7nyOfERckg+HnjgMw==",
            "Signature": "d8a328f4957174415ac09446f690be0f491c0895",
             "TimeStamp": "1593325356",
             "Nonce": "EDH75AiKExevY8L3"
         }
        :return:
        """
        # encrypt = "CZWs4CWRpI4VolQlvn4dlPb2f0uQxokbSIZGwiT1u44MCk4o6Iw6R/zVkcFkIdMCxD99C7nyOfERckg+HnjgMw=="
        # signature = "d8a328f4957174415ac09446f690be0f491c0895"
        # timestamp = "1593325356"
        # nonce = "EDH75AiKExevY8L3"
        msg_data = MsgData(encrypt=encrypt,
                           signature=signature,
                           timestamp=timestamp,
                           nonce=nonce)
        ret, msg = self.crypt().DecryptMsgData(msg_data)
        assert ret == WXBizMsgCrypt_OK, f"ret!={ret}"
        print(f"{encrypt}解析:\n", msg.decode())
Esempio n. 31
0
from parameterized import param

from .densenet import DenseNet, MyDenseNet
from .test_resnet import name_func

KWARGS = {
    'growth': (4, 12, 40),
    'bottleneck': (True, False),
    'neck_size': (None, 16),
    'compression': (1, 0.5),
    'dropout': (0.0, 0.2),
}

FRWRD = [
    param(DenseNet,
          n=4,
          growth=4,
          bottleneck=False,
          compression=1,
          dropout=0.5),
    param(DenseNet, n=4, growth=4, bottleneck=True, compression=1,
          dropout=0.5),
    param(DenseNet, n=4, growth=4, bottleneck=True, compression=1,
          neck_size=4),
    param(DenseNet, n=4, growth=12, bottleneck=True, compression=1),
    param(DenseNet,
          n=16,
          growth=12,
          bottleneck=True,
          compression=1,
          neck_size=16),
    param(DenseNet, n=4, growth=12, bottleneck=True, compression=0.5),
    param(MyDenseNet, n=4, growth=4, bottleneck=True, compression=1),
Esempio n. 32
0
class TestAws(TestCase):
    @parameterized.expand([
        param('success', ),
    ])
    def test_rotate_root_iam_credentials(self,
                                         test_label,
                                         mount_point=DEFAULT_MOUNT_POINT):
        expected_status_code = 200
        mock_response = {"data": {"access_key": "AKIA..."}}
        aws = Aws(adapter=Request())
        mock_url = 'http://localhost:8200/v1/{mount_point}/config/rotate-root'.format(
            mount_point=mount_point, )
        logging.debug('Mocking URL: %s' % mock_url)
        with requests_mock.mock() as requests_mocker:
            requests_mocker.register_uri(
                method='POST',
                url=mock_url,
                status_code=expected_status_code,
                json=mock_response,
            )
            rotate_root_response = aws.rotate_root_iam_credentials(
                mount_point=mount_point, )
        logging.debug('rotate_root_response: %s' % rotate_root_response)
        self.assertEqual(
            first=mock_response,
            second=rotate_root_response,
        )

    @parameterized.expand([
        param('success', ),
        param('invalid endpoint',
              endpoint='cats',
              raises=ParamValidationError,
              exception_msg='cats'),
    ])
    def test_generate_credentials(self,
                                  test_label,
                                  role_name='hvac-test-role',
                                  mount_point=DEFAULT_MOUNT_POINT,
                                  endpoint='creds',
                                  raises=None,
                                  exception_msg=''):
        expected_status_code = 200
        mock_response = {
            "data": {
                "access_key": "AKIA...",
                "secret_key": "xlCs...",
                "security_token": None
            }
        }
        mock_url = 'http://localhost:8200/v1/{mount_point}/creds/{role_name}'.format(
            mount_point=mount_point,
            role_name=role_name,
        )
        logging.debug('Mocking URL: %s' % mock_url)
        aws = Aws(adapter=Request())
        with requests_mock.mock() as requests_mocker:
            requests_mocker.register_uri(
                method='POST',
                url=mock_url,
                status_code=expected_status_code,
                json=mock_response,
            )

            if raises:
                with self.assertRaises(raises) as cm:
                    aws.generate_credentials(
                        name=role_name,
                        endpoint=endpoint,
                        mount_point=mount_point,
                    )
                self.assertIn(
                    member=exception_msg,
                    container=str(cm.exception),
                )
            else:
                gen_creds_response = aws.generate_credentials(
                    name=role_name,
                    endpoint=endpoint,
                    mount_point=mount_point,
                )
                logging.debug('gen_creds_response: %s' % gen_creds_response)
                self.assertEqual(
                    first=mock_response,
                    second=gen_creds_response,
                )
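
# A minimal sketch of the requests_mock pattern used above: register a canned
# response for a URL, exercise real HTTP code, then assert on the JSON. The
# URL and payload here are illustrative only.
import requests
import requests_mock

with requests_mock.mock() as m:
    m.register_uri(
        method='POST',
        url='http://localhost:8200/v1/aws/creds/hvac-test-role',
        status_code=200,
        json={'data': {'access_key': 'AKIA...'}},
    )
    resp = requests.post('http://localhost:8200/v1/aws/creds/hvac-test-role')
    assert resp.json() == {'data': {'access_key': 'AKIA...'}}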
Esempio n. 33
0
class TestGcp(utils.HvacIntegrationTestCase, TestCase):
    TEST_MOUNT_POINT = 'gcp-test'

    def setUp(self):
        super(TestGcp, self).setUp()
        if '%s/' % self.TEST_MOUNT_POINT not in self.client.list_auth_backends():
            self.client.enable_auth_backend(
                backend_type='gcp',
                mount_point=self.TEST_MOUNT_POINT,
            )

    def tearDown(self):
        super(TestGcp, self).tearDown()
        self.client.disable_auth_backend(mount_point=self.TEST_MOUNT_POINT, )

    @parameterized.expand([
        param('success', ),
        param(
            'set valid credentials',
            credentials=utils.load_test_data('example.jwt.json'),
        ),
        param(
            'set invalid credentials',
            credentials='some invalid JWT',
            raises=exceptions.InvalidRequest,
            exception_message='error reading google credentials from given JSON'
        ),
    ])
    def test_configure(self,
                       label,
                       credentials='',
                       raises=None,
                       exception_message=''):
        if raises:
            with self.assertRaises(raises) as cm:
                self.client.gcp.auth.configure(
                    credentials=credentials,
                    mount_point=self.TEST_MOUNT_POINT,
                )
            self.assertIn(
                member=exception_message,
                container=str(cm.exception),
            )
        else:
            configure_response = self.client.gcp.auth.configure(
                credentials=credentials,
                mount_point=self.TEST_MOUNT_POINT,
            )
            logging.debug('configure_response: %s' % configure_response)
            self.assertEqual(
                first=configure_response.status_code,
                second=204,
            )

    @parameterized.expand([
        param('success', ),
        param('no config written yet',
              write_config_first=False,
              raises=exceptions.InvalidPath)
    ])
    def test_read_config(self, label, write_config_first=True, raises=None):

        credentials = utils.load_test_data('example.jwt.json')
        if write_config_first:
            self.client.gcp.auth.configure(
                credentials=credentials,
                mount_point=self.TEST_MOUNT_POINT,
            )
        if raises is not None:
            with self.assertRaises(raises):
                self.client.gcp.auth.read_config(
                    mount_point=self.TEST_MOUNT_POINT, )
        else:
            read_config_response = self.client.gcp.auth.read_config(
                mount_point=self.TEST_MOUNT_POINT, )
            logging.debug('read_config_response: %s' % read_config_response)

            creds_dict = json.loads(credentials)
            expected_config = {
                'project_id': creds_dict['project_id'],
                'client_email': creds_dict['client_email'],
                'private_key_id': creds_dict['private_key_id'],
            }
            for k, v in expected_config.items():
                self.assertEqual(
                    first=v,
                    second=read_config_response[k],
                )

    @parameterized.expand([
        # param(
        #     'success',  # TODO: figure out why this is returning a 405
        # ),
        param(
            'no existing config',
            write_config_first=False,
            raises=exceptions.UnexpectedError,
        ),
    ])
    def test_delete_config(self, label, write_config_first=True, raises=None):

        if write_config_first:
            self.client.gcp.auth.configure(mount_point=self.TEST_MOUNT_POINT, )
        if raises is not None:
            with self.assertRaises(raises):
                self.client.gcp.auth.delete_config(
                    mount_point=self.TEST_MOUNT_POINT, )
        else:
            delete_config_response = self.client.gcp.auth.delete_config(
                mount_point=self.TEST_MOUNT_POINT, )
            logging.debug('delete_config_response: %s' %
                          delete_config_response)
            self.assertEqual(
                first=delete_config_response.status_code,
                second=204,
            )

    @parameterized.expand([
        param('success iam',
              role_type='iam',
              extra_params=dict(bound_service_accounts=['*'], )),
        param(
            'iam no bound service account',
            role_type='iam',
            raises=exceptions.InvalidRequest,
            exception_message=
            'IAM role type must have at least one service account',
        ),
        param(
            'success gce',
            role_type='gce',
        ),
        param(
            'invalid role type',
            role_type='hvac',
            raises=exceptions.ParamValidationError,
            exception_message='unsupported role_type argument provided',
        ),
        param(
            'wrong policy arg type',
            role_type='iam',
            policies='cats',
            raises=exceptions.ParamValidationError,
            exception_message='unsupported policies argument provided',
        )
    ])
    def test_create_role(self,
                         label,
                         role_type,
                         policies=None,
                         extra_params=None,
                         raises=None,
                         exception_message=''):
        role_name = 'hvac'
        project_id = 'test-hvac-project-not-a-real-project'
        if extra_params is None:
            extra_params = {}
        if raises:
            with self.assertRaises(raises) as cm:
                self.client.gcp.auth.create_role(
                    name=role_name,
                    role_type=role_type,
                    project_id=project_id,
                    policies=policies,
                    mount_point=self.TEST_MOUNT_POINT,
                    **extra_params)
            self.assertIn(
                member=exception_message,
                container=str(cm.exception),
            )
        else:
            create_role_response = self.client.gcp.auth.create_role(
                name=role_name,
                role_type=role_type,
                project_id=project_id,
                policies=policies,
                mount_point=self.TEST_MOUNT_POINT,
                **extra_params)
            logging.debug('create_role_response: %s' % create_role_response)
            if utils.vault_version_lt('0.10.0'):
                expected_status_code = 204
            else:
                expected_status_code = 200  # TODO => figure out why this isn't a 204?
            self.assertEqual(
                first=create_role_response.status_code,
                second=expected_status_code,
            )
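
    # Hedged sketch (an illustrative helper, not part of the original class):
    # utils.vault_version_lt above is assumed to compare the running Vault
    # server's version against a threshold; plain tuple comparison suffices.
    @staticmethod
    def _version_lt_sketch(current, threshold):
        def as_tuple(version):
            return tuple(int(part) for part in version.split('.'))
        return as_tuple(current) < as_tuple(threshold)  # e.g. ('0.9.6', '0.10.0') -> True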

    @parameterized.expand([
        param(
            'success add',
            add=['test'],
        ),
        param(
            'success remove',
            remove=['test'],
        ),
        param(
            'fail upon no changes',
            raises=exceptions.InvalidRequest,
            exception_message=
            'must provide at least one value to add or remove',
        ),
        # TODO: wrong role type (gce)
    ])
    def test_edit_service_accounts_on_iam_role(self,
                                               label,
                                               add=None,
                                               remove=None,
                                               create_role_first=True,
                                               raises=None,
                                               exception_message=''):
        role_name = 'hvac'
        project_id = 'test-hvac-project-not-a-real-project'
        if create_role_first:
            self.client.gcp.auth.create_role(
                name=role_name,
                role_type='iam',
                project_id=project_id,
                bound_service_accounts=[
                    '*****@*****.**'
                ],
                mount_point=self.TEST_MOUNT_POINT,
            )
        if raises:
            with self.assertRaises(raises) as cm:
                self.client.gcp.auth.edit_service_accounts_on_iam_role(
                    name=role_name,
                    add=add,
                    remove=remove,
                    mount_point=self.TEST_MOUNT_POINT,
                )
            self.assertIn(
                member=exception_message,
                container=str(cm.exception),
            )
        else:
            edit_sa_on_iam_response = self.client.gcp.auth.edit_service_accounts_on_iam_role(
                name=role_name,
                add=add,
                remove=remove,
                mount_point=self.TEST_MOUNT_POINT,
            )
            logging.debug('edit_sa_on_iam_response: %s' % edit_sa_on_iam_response)
            if utils.vault_version_lt('0.10.0'):
                expected_status_code = 204
            else:
                expected_status_code = 200  # TODO => figure out why this isn't a 204?
            self.assertEqual(
                first=edit_sa_on_iam_response.status_code,
                second=expected_status_code,
            )

    @parameterized.expand([
        param(
            'success add',
            add=['test-key:test-value'],
        ),
        param(
            'success remove',
            remove=['test-key:test-value'],
        ),
        param(
            'fail upon no changes',
            raises=exceptions.InvalidRequest,
            exception_message=
            'must provide at least one value to add or remove',
        ),
        # TODO: wrong role type (iam)
    ])
    def test_edit_labels_on_gce_role(self,
                                     label,
                                     add=None,
                                     remove=None,
                                     create_role_first=True,
                                     raises=None,
                                     exception_message=''):
        role_name = 'hvac'
        project_id = 'test-hvac-project-not-a-real-project'
        if create_role_first:
            self.client.gcp.auth.create_role(
                name=role_name,
                role_type='gce',
                project_id=project_id,
                bound_service_accounts=[
                    '*****@*****.**'
                ],
                mount_point=self.TEST_MOUNT_POINT,
            )
        if raises:
            with self.assertRaises(raises) as cm:
                self.client.gcp.auth.edit_labels_on_gce_role(
                    name=role_name,
                    add=add,
                    remove=remove,
                    mount_point=self.TEST_MOUNT_POINT,
                )
            self.assertIn(
                member=exception_message,
                container=str(cm.exception),
            )
        else:
            edit_labels_response = self.client.gcp.auth.edit_labels_on_gce_role(
                name=role_name,
                add=add,
                remove=remove,
                mount_point=self.TEST_MOUNT_POINT,
            )
            logging.debug('edit_labels_response: %s' % edit_labels_response)
            if utils.vault_version_lt('0.10.0'):
                expected_status_code = 204
            else:
                expected_status_code = 200  # TODO => figure out why this isn't a 204?
            self.assertEqual(
                first=edit_labels_response.status_code,
                second=expected_status_code,
            )

    @parameterized.expand([
        param('success', ),
        param(
            'nonexistent role',
            create_role_first=False,
            raises=exceptions.InvalidPath,
        ),
    ])
    def test_read_role(self,
                       label,
                       create_role_first=True,
                       raises=None,
                       exception_message=''):
        role_name = 'hvac'
        project_id = 'test-hvac-project-not-a-real-project'
        if create_role_first:
            self.client.gcp.auth.create_role(
                name=role_name,
                role_type='gce',
                project_id=project_id,
                bound_service_accounts=[
                    '*****@*****.**'
                ],
                mount_point=self.TEST_MOUNT_POINT,
            )
        if raises:
            with self.assertRaises(raises) as cm:
                self.client.gcp.auth.read_role(
                    name=role_name,
                    mount_point=self.TEST_MOUNT_POINT,
                )
            self.assertIn(
                member=exception_message,
                container=str(cm.exception),
            )
        else:
            read_role_response = self.client.gcp.auth.read_role(
                name=role_name,
                mount_point=self.TEST_MOUNT_POINT,
            )
            logging.debug('read_role_response: %s' % read_role_response)
            self.assertEqual(
                first=project_id,
                second=read_role_response['project_id'],
            )

    @parameterized.expand([
        param('success one role', ),
        param(
            'success multiple roles',
            num_roles_to_create=7,
        ),
        param(
            'no roles',
            num_roles_to_create=0,
            raises=exceptions.InvalidPath,
        ),
    ])
    def test_list_roles(self, label, num_roles_to_create=1, raises=None):
        project_id = 'test-hvac-project-not-a-real-project'
        roles_to_create = ['hvac%s' % n for n in range(0, num_roles_to_create)]
        logging.debug('roles_to_create: %s' % roles_to_create)
        for role_to_create in roles_to_create:
            create_role_response = self.client.gcp.auth.create_role(
                name=role_to_create,
                role_type='gce',
                project_id=project_id,
                bound_service_accounts=[
                    '*****@*****.**'
                ],
                mount_point=self.TEST_MOUNT_POINT,
            )
            logging.debug('create_role_response: %s' % create_role_response)

        if raises is not None:
            with self.assertRaises(raises):
                self.client.gcp.auth.list_roles(mount_point=self.TEST_MOUNT_POINT, )
        else:
            list_roles_response = self.client.gcp.auth.list_roles(
                mount_point=self.TEST_MOUNT_POINT, )
            logging.debug('list_roles_response: %s' % list_roles_response)
            self.assertEqual(
                first=list_roles_response['keys'],
                second=roles_to_create,
            )

    @parameterized.expand([
        param('success', ),
        param(
            'nonexistent role name',
            configure_role_first=False,
        ),
    ])
    def test_delete_role(self, label, configure_role_first=True, raises=None):
        role_name = 'hvac'
        project_id = 'test-hvac-project-not-a-real-project'
        if configure_role_first:
            create_role_response = self.client.gcp.auth.create_role(
                name=role_name,
                role_type='gce',
                project_id=project_id,
                bound_service_accounts=[
                    '*****@*****.**'
                ],
                mount_point=self.TEST_MOUNT_POINT,
            )
            logging.debug('create_role_response: %s' % create_role_response)

        if raises is not None:
            with self.assertRaises(raises):
                self.client.gcp.auth.delete_role(
                    role=role_name,
                    mount_point=self.TEST_MOUNT_POINT,
                )
        else:
            delete_role_response = self.client.gcp.auth.delete_role(
                role=role_name,
                mount_point=self.TEST_MOUNT_POINT,
            )
            logging.debug('delete_role_response: %s' % delete_role_response)
            self.assertEqual(
                first=delete_role_response.status_code,
                second=204,
            )
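
# Hedged sketch of the mount lifecycle the integration class above relies on,
# shown outside a test runner. The URL and token are assumptions for a
# dev-mode Vault server; list/enable/disable_auth_backend are the pre-1.0
# hvac client methods used throughout this suite.
import hvac

client = hvac.Client(url='http://localhost:8200', token='root-token')
if 'gcp-test/' not in client.list_auth_backends():
    client.enable_auth_backend(backend_type='gcp', mount_point='gcp-test')
# ... exercise the backend ...
client.disable_auth_backend(mount_point='gcp-test')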
Example n. 34
       lowercase hex string.
    """

    mock_init.return_value = None

    result = Crc64Result(0x88889999)
    result._crc64 = 2246800662182009355 # pylint: disable=locally-disabled, protected-access

    eq_("56789a0b", result.low_bytes)

    mock_init.assert_called_once_with(0x88889999)


@istest
@parameterized([
    param("a == b is True", 1, 1001, 2, 1001, "==", True),
    param("a == b is False", 4, 2001, 4, 4001, "==", False),
    param("a != b is False", 8, 8001, 16, 8001, "!=", False),
    param("a != b is True", 32, 16001, 32, 32001, "!=", True)
])
@patch("pydvdid.crc64result.Crc64Result.__init__") # pylint: disable=locally-disabled, invalid-name, too-many-arguments
def crc64result_equality_and_inequality_comparisons_return_correctly(description, polynomial_one,
                                                                     crc64_one, polynomial_two,
                                                                     crc64_two,
                                                                     comparison_function_name,
                                                                     expected, mock_init):
    """Tests that invocation of == and != equality comparisons return correctly.

       (This is a Nose generator test which receives a set of data provided by parameterized).
    """
Example n. 35
class TestFileSystem(unittest.TestCase):
    def setUp(self):
        self.fs = TestingFileSystem(pipeline_options=None)

    def _flatten_match(self, match_results):
        return [
            file_metadata for match_result in match_results
            for file_metadata in match_result.metadata_list
        ]

    @parameterized.expand([
        ('gs://gcsio-test/**', all),
        # Does not match root-level files
        ('gs://gcsio-test/**/*', lambda n, i: n not in ['cat.png']),
        # Only matches root-level files
        ('gs://gcsio-test/*', [('cat.png', 19)]),
        ('gs://gcsio-test/cow/**', [
            ('cow/cat/fish', 2),
            ('cow/cat/blubber', 3),
            ('cow/dog/blubber', 4),
        ]),
        ('gs://gcsio-test/cow/ca**', [
            ('cow/cat/fish', 2),
            ('cow/cat/blubber', 3),
        ]),
        ('gs://gcsio-test/apple/[df]ish/ca*', [
            ('apple/fish/cat', 10),
            ('apple/fish/cart', 11),
            ('apple/fish/carl', 12),
            ('apple/dish/cat', 14),
            ('apple/dish/carl', 15),
        ]),
        ('gs://gcsio-test/apple/?ish/?a?', [
            ('apple/fish/cat', 10),
            ('apple/dish/bat', 13),
            ('apple/dish/cat', 14),
        ]),
        ('gs://gcsio-test/apple/fish/car?', [
            ('apple/fish/cart', 11),
            ('apple/fish/carl', 12),
        ]),
        ('gs://gcsio-test/apple/fish/b*', [
            ('apple/fish/blubber', 6),
            ('apple/fish/blowfish', 7),
            ('apple/fish/bambi', 8),
            ('apple/fish/balloon', 9),
        ]),
        ('gs://gcsio-test/apple/f*/b*', [
            ('apple/fish/blubber', 6),
            ('apple/fish/blowfish', 7),
            ('apple/fish/bambi', 8),
            ('apple/fish/balloon', 9),
        ]),
        ('gs://gcsio-test/apple/dish/[cb]at', [
            ('apple/dish/bat', 13),
            ('apple/dish/cat', 14),
        ]),
        ('gs://gcsio-test/banana/cyrano.m?', [
            ('banana/cyrano.md', 17),
            ('banana/cyrano.mb', 18),
        ]),
    ])
    def test_match_glob(self, file_pattern, expected_object_names):
        objects = [('cow/cat/fish', 2), ('cow/cat/blubber', 3),
                   ('cow/dog/blubber', 4), ('apple/dog/blubber', 5),
                   ('apple/fish/blubber', 6), ('apple/fish/blowfish', 7),
                   ('apple/fish/bambi', 8), ('apple/fish/balloon', 9),
                   ('apple/fish/cat', 10), ('apple/fish/cart', 11),
                   ('apple/fish/carl', 12), ('apple/dish/bat', 13),
                   ('apple/dish/cat', 14), ('apple/dish/carl', 15),
                   ('banana/cat', 16), ('banana/cyrano.md', 17),
                   ('banana/cyrano.mb', 18), ('cat.png', 19)]
        bucket_name = 'gcsio-test'

        if callable(expected_object_names):
            # A hack around the fact that the parameters do not have access to
            # the "objects" list.

            if expected_object_names is all:
                # It's a placeholder for "all" objects
                expected_object_names = objects
            else:
                # It's a filter function of type (str, int) -> bool
                # that returns true for expected objects
                filter_func = expected_object_names
                expected_object_names = [(short_path, size)
                                         for short_path, size in objects
                                         if filter_func(short_path, size)]

        for object_name, size in objects:
            file_name = 'gs://%s/%s' % (bucket_name, object_name)
            self.fs._insert_random_file(file_name, size)

        expected_file_names = [('gs://%s/%s' % (bucket_name, object_name),
                                size)
                               for object_name, size in expected_object_names]
        actual_file_names = [(file_metadata.path, file_metadata.size_in_bytes)
                             for file_metadata in self._flatten_match(
                                 self.fs.match([file_pattern]))]

        self.assertEqual(set(actual_file_names), set(expected_file_names))

        # Check if limits are followed correctly
        limit = 3
        expected_num_items = min(len(expected_object_names), limit)
        self.assertEqual(
            len(self._flatten_match(self.fs.match([file_pattern], [limit]))),
            expected_num_items)

    @parameterized.expand([
        param(
            os_path=posixpath,
            # re.escape does not escape forward slashes since Python 3.7
            # https://docs.python.org/3/whatsnew/3.7.html ("bpo-29995")
            sep_re='\\/' if sys.version_info < (3, 7, 0) else '/'),
        param(os_path=ntpath, sep_re='\\\\'),
    ])
    def test_translate_pattern(self, os_path, sep_re):
        star = r'[^/\\]*'
        double_star = r'.*'
        join = os_path.join

        sep = os_path.sep
        pattern__expected = [
            (join('a', '*'), sep_re.join(['a', star])),
            (join('b', '*') + sep, sep_re.join(['b', star]) + sep_re),
            (r'*[abc\]', star + r'[abc\\]'),
            (join('d', '**', '*'), sep_re.join(['d', double_star, star])),
        ]
        for pattern, expected in pattern__expected:
            expected = r'(?ms)' + expected + r'\Z'
            result = self.fs.translate_pattern(pattern)
            self.assertEqual(expected, result)
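
# Hedged usage sketch: translate_pattern, as exercised above, yields a
# complete anchored regular expression, so its output can be fed straight to
# re.match. The literal below is the POSIX, Python >= 3.7 translation of
# join('a', '*') from the table above.
import re

translated = r'(?ms)a/[^/\\]*\Z'
assert re.match(translated, 'a/file.txt')
assert not re.match(translated, 'a/b/file.txt')  # '*' cannot cross a separator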
Example n. 36
class TestSwaggerParser_get_apis(TestCase):
    def test_with_one_path_method(self):
        function_name = "myfunction"
        swagger = {
            "paths": {
                "/path1": {
                    "get": {
                        "x-amazon-apigateway-integration": {
                            "type": "aws_proxy",
                            "uri": "someuri"
                        }
                    }
                }
            }
        }

        parser = SwaggerParser(swagger)
        parser._get_integration_function_name = Mock()
        parser._get_integration_function_name.return_value = function_name

        expected = [
            Route(path="/path1", methods=["get"], function_name=function_name)
        ]
        result = parser.get_routes()

        self.assertEqual(expected, result)
        parser._get_integration_function_name.assert_called_with({
            "x-amazon-apigateway-integration": {
                "type": "aws_proxy",
                "uri": "someuri"
            }
        })

    def test_with_combination_of_paths_methods(self):
        function_name = "myfunction"
        swagger = {
            "paths": {
                "/path1": {
                    "get": {
                        "x-amazon-apigateway-integration": {
                            "type": "aws_proxy",
                            "uri": "someuri"
                        }
                    },
                    "delete": {
                        "x-amazon-apigateway-integration": {
                            "type": "aws_proxy",
                            "uri": "someuri"
                        }
                    }
                },
                "/path2": {
                    "post": {
                        "x-amazon-apigateway-integration": {
                            "type": "aws_proxy",
                            "uri": "someuri"
                        }
                    }
                }
            }
        }

        parser = SwaggerParser(swagger)
        parser._get_integration_function_name = Mock()
        parser._get_integration_function_name.return_value = function_name

        expected = {
            Route(path="/path1", methods=["get"], function_name=function_name),
            Route(path="/path1",
                  methods=["delete"],
                  function_name=function_name),
            Route(path="/path2", methods=["post"],
                  function_name=function_name),
        }
        result = parser.get_routes()

        self.assertEqual(expected, set(result))

    def test_with_any_method(self):
        function_name = "myfunction"
        swagger = {
            "paths": {
                "/path1": {
                    "x-amazon-apigateway-any-method": {
                        "x-amazon-apigateway-integration": {
                            "type": "aws_proxy",
                            "uri": "someuri"
                        }
                    }
                }
            }
        }

        parser = SwaggerParser(swagger)
        parser._get_integration_function_name = Mock()
        parser._get_integration_function_name.return_value = function_name

        expected = [
            Route(methods=["ANY"], path="/path1", function_name=function_name)
        ]
        result = parser.get_routes()

        self.assertEqual(expected, result)

    def test_does_not_have_function_name(self):
        swagger = {
            "paths": {
                "/path1": {
                    "post": {
                        "x-amazon-apigateway-integration": {
                            "type": "aws_proxy",
                            "uri": "someuri"
                        }
                    }
                }
            }
        }

        parser = SwaggerParser(swagger)
        parser._get_integration_function_name = Mock()
        parser._get_integration_function_name.return_value = None  # Function Name could not be resolved

        expected = []
        result = parser.get_routes()

        self.assertEqual(expected, result)

    @parameterized.expand([
        param("empty swagger", {}),
        param("'paths' property is absent", {"foo": "bar"}),
        param("no paths", {"paths": {}}),
        param("no methods", {"paths": {
            "/path1": {}
        }}),
        param("no integration", {"paths": {
            "/path1": {
                "get": {}
            }
        }})
    ])
    def test_invalid_swagger(self, test_case_name, swagger):
        parser = SwaggerParser(swagger)
        result = parser.get_routes()

        expected = []
        self.assertEqual(expected, result)
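
# Hedged sketch of the traversal these tests imply (illustrative only, not
# the actual SwaggerParser implementation): walk paths -> methods, resolve
# each integration to a function name, and emit a route per (path, method);
# the x-amazon-apigateway-any-method key maps to the catch-all "ANY" method.
def get_routes_sketch(swagger, resolve_function_name):
    routes = []
    for path, methods in (swagger or {}).get('paths', {}).items():
        for method, method_config in methods.items():
            function_name = resolve_function_name(method_config)
            if function_name is None:
                continue  # integration did not resolve to a Lambda function
            if method == 'x-amazon-apigateway-any-method':
                method = 'ANY'
            routes.append((path, [method], function_name))
    return routes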
Example n. 37
class TestTZPopping(BaseTestCase):
    def setUp(self):
        super(TestTZPopping, self).setUp()
        self.initial_string = self.datetime_string = self.timezone_offset = NotImplemented

    @parameterized.expand([
        param('Sep 03 2014 | 4:32 pm EDT', -4),
        param('17th October, 2034 @ 01:08 am PDT', -7),
        param('17th October, 2034 @ 01:08 am (PDT)', -7),
        param('October 17, 2014 at 7:30 am PST', -8),
        param('20 Oct 2014 13:08 CET', +1),
        param('20 Oct 2014 13:08cet', +1),
        param('Nov 25 2014 | 10:17 pm EST', -5),
        param('Nov 25 2014 | 10:17 pm +0600', +6),
        param('Nov 25 2014 | 10:17 pm -0930', -9.5),
        param('20 Oct 2014 | 05:17 am -1200', -12),
        param('20 Oct 2014 | 05:17 am +0000', 0),
        param('15 May 2004', None),
        param('Wed Aug 05 12:00:00 EDTERR 2015', None),
        param('Wed Aug 05 12:00:00 EDT 2015', -4),
        param('April 10, 2016 at 12:00:00 UTC', 0),
        param('April 10, 2016 at 12:00:00 MEZ', 1),
        param('April 10, 2016 at 12:00:00 MESZ', 2),
        param('April 10, 2016 at 12:00:00 GMT+2', 2),
        param('April 10, 2016 at 12:00:00 UTC+2:00', 2),
        param('April 10, 2016 at 12:00:00 GMT+02:00', 2),
        param('April 10, 2016 at 12:00:00 UTC+5:30', 5.5),
        param('April 10, 2016 at 12:00:00 GMT+05:30', 5.5),
        param('April 10, 2016 at 12:00:00 UTC-2', -2),
        param('April 10, 2016 at 12:00:00 GMT-2:00', -2),
        param('April 10, 2016 at 12:00:00 UTC-02:00', -2),
        param('April 10, 2016 at 12:00:00 GMT-9:30', -9.5),
        param('April 10, 2016 at 12:00:00 UTC-09:30', -9.5),
        param('Thu, 24 Nov 2016 16:03:00 UT', 0),
        param('Fri Sep 23 2016 10:34:51 GMT+0800 (CST)', 8),
    ])
    def test_extracting_valid_offset(self, initial_string, expected_offset):
        self.given_string(initial_string)
        self.when_offset_popped_from_string()
        self.then_offset_is(expected_offset)

    @parameterized.expand([
        param('Sep 03 2014 | 4:32 pm EDT', 'Sep 03 2014 | 4:32 pm '),
        param('17th October, 2034 @ 01:08 am PDT', '17th October, 2034 @ 01:08 am '),
        param('October 17, 2014 at 7:30 am PST', 'October 17, 2014 at 7:30 am '),
        param('20 Oct 2014 13:08 CET', '20 Oct 2014 13:08 '),
        param('20 Oct 2014 13:08cet', '20 Oct 2014 13:08'),
        param('Nov 25 2014 | 10:17 pm EST', 'Nov 25 2014 | 10:17 pm '),
        param('17th October, 2034 @ 01:08 am +0700', '17th October, 2034 @ 01:08 am '),
        param('Sep 03 2014 4:32 pm +0630', 'Sep 03 2014 4:32 pm '),
    ])
    def test_timezone_deleted_from_string(self, initial_string, result_string):
        self.given_string(initial_string)
        self.when_offset_popped_from_string()
        self.then_string_modified_to(result_string)

    def test_string_not_changed_if_no_timezone(self):
        self.given_string('15 May 2004')
        self.when_offset_popped_from_string()
        self.then_string_modified_to('15 May 2004')

    def given_string(self, string_):
        self.initial_string = string_

    def when_offset_popped_from_string(self):
        self.datetime_string, timezone_offset = pop_tz_offset_from_string(self.initial_string)
        if timezone_offset:
            self.timezone_offset = timezone_offset.utcoffset('')
        else:
            self.timezone_offset = timezone_offset

    def then_string_modified_to(self, expected_string):
        self.assertEqual(expected_string, self.datetime_string)

    def then_offset_is(self, expected_offset):
        delta = timedelta(hours=expected_offset) if expected_offset is not None else None
        self.assertEqual(delta, self.timezone_offset)
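
# Hedged usage sketch for the helper under test (a dateparser-internal API;
# the import path is assumed from the test module's context): it returns the
# input with the timezone token stripped plus a tzinfo-like object, or None
# when no timezone was recognised.
from datetime import timedelta

from dateparser.timezone_parser import pop_tz_offset_from_string

remaining, tz = pop_tz_offset_from_string('20 Oct 2014 13:08 CET')
assert remaining == '20 Oct 2014 13:08 '
assert tz.utcoffset('') == timedelta(hours=1)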
Example n. 38
class TestParameterValuesHandling(TestCase):
    """
    Test how user-supplied parameters & default template parameter values from template get merged
    """
    def test_add_default_parameter_values_must_merge(self):
        parameter_values = {"Param1": "value1"}

        sam_template = {
            "Parameters": {
                "Param2": {
                    "Type": "String",
                    "Default": "template default"
                }
            }
        }

        expected = {"Param1": "value1", "Param2": "template default"}

        sam_parser = Parser()
        translator = Translator({}, sam_parser)
        result = translator._add_default_parameter_values(
            sam_template, parameter_values)
        self.assertEqual(expected, result)

    def test_add_default_parameter_values_must_override_user_specified_values(
            self):
        parameter_values = {"Param1": "value1"}

        sam_template = {
            "Parameters": {
                "Param1": {
                    "Type": "String",
                    "Default": "template default"
                }
            }
        }

        expected = {"Param1": "value1"}

        sam_parser = Parser()
        translator = Translator({}, sam_parser)
        result = translator._add_default_parameter_values(
            sam_template, parameter_values)
        self.assertEqual(expected, result)

    def test_add_default_parameter_values_must_skip_params_without_defaults(
            self):
        parameter_values = {"Param1": "value1"}

        sam_template = {
            "Parameters": {
                "Param1": {
                    "Type": "String"
                },
                "Param2": {
                    "Type": "String"
                }
            }
        }

        expected = {"Param1": "value1"}

        sam_parser = Parser()
        translator = Translator({}, sam_parser)
        result = translator._add_default_parameter_values(
            sam_template, parameter_values)
        self.assertEqual(expected, result)

    @parameterized.expand([
        # Array
        param(["1", "2"]),

        # String
        param("something"),

        # Some other non-parameter looking dictionary
        param({"Param1": {
            "Foo": "Bar"
        }}),
        param(None)
    ])
    def test_add_default_parameter_values_must_ignore_invalid_template_parameters(
            self, template_parameters):
        parameter_values = {"Param1": "value1"}

        expected = {"Param1": "value1"}

        sam_template = {"Parameters": template_parameters}

        sam_parser = Parser()
        translator = Translator({}, sam_parser)
        result = translator._add_default_parameter_values(
            sam_template, parameter_values)
        self.assertEqual(expected, result)
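
# Hedged sketch of the merge semantics the tests above pin down (illustrative,
# not the actual Translator._add_default_parameter_values implementation):
# template defaults fill the gaps, user-supplied values always win, parameters
# without defaults are skipped, and malformed 'Parameters' blocks are ignored.
def add_default_parameter_values_sketch(sam_template, parameter_values):
    merged = {}
    template_parameters = sam_template.get('Parameters')
    if isinstance(template_parameters, dict):
        for name, definition in template_parameters.items():
            if isinstance(definition, dict) and 'Default' in definition:
                merged[name] = definition['Default']
    merged.update(parameter_values)  # user-supplied values override defaults
    return merged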
Example n. 39
class TestParseWithFormatsFunction(BaseTestCase):
    def setUp(self):
        super(TestParseWithFormatsFunction, self).setUp()
        self.result = NotImplemented

    @parameterized.expand([
        param(date_string='yesterday', date_formats=['%Y-%m-%d']),
    ])
    def test_date_with_not_matching_format_is_not_parsed(
            self, date_string, date_formats):
        self.when_date_is_parsed_with_formats(date_string, date_formats)
        self.then_date_was_not_parsed()

    @parameterized.expand([
        param(date_string='25-03-14',
              date_formats=['%d-%m-%y'],
              expected_result=datetime(2014, 3, 25)),
    ])
    def test_should_parse_date(self, date_string, date_formats,
                               expected_result):
        self.when_date_is_parsed_with_formats(date_string, date_formats)
        self.then_date_was_parsed()
        self.then_parsed_period_is('day')
        self.then_parsed_date_is(expected_result)

    @parameterized.expand([
        param(date_string='09.16',
              date_formats=['%m.%d'],
              expected_month=9,
              expected_day=16),
    ])
    def test_should_use_current_year_for_dates_without_year(
            self, date_string, date_formats, expected_month, expected_day):
        self.given_now(2015, 2, 4)
        self.when_date_is_parsed_with_formats(date_string, date_formats)
        self.then_date_was_parsed()
        self.then_parsed_period_is('day')
        self.then_parsed_date_is(datetime(2015, expected_month, expected_day))

    @parameterized.expand([
        param(date_string='August 2014',
              date_formats=['%B %Y'],
              expected_year=2014,
              expected_month=8,
              today_day=12,
              prefer_day_of_month='first',
              expected_day=1),
        param(date_string='August 2014',
              date_formats=['%B %Y'],
              expected_year=2014,
              expected_month=8,
              today_day=12,
              prefer_day_of_month='last',
              expected_day=31),
        param(date_string='August 2014',
              date_formats=['%B %Y'],
              expected_year=2014,
              expected_month=8,
              today_day=12,
              prefer_day_of_month='current',
              expected_day=12),
    ])
    def test_should_use_correct_day_from_settings_for_dates_without_day(
            self, date_string, date_formats, expected_year, expected_month,
            today_day, prefer_day_of_month, expected_day):
        self.given_now(2014, 8, today_day)
        settings_mod = copy(settings)
        settings_mod.PREFER_DAY_OF_MONTH = prefer_day_of_month
        self.when_date_is_parsed_with_formats(date_string, date_formats,
                                              settings_mod)
        self.then_date_was_parsed()
        self.then_parsed_period_is('month')
        self.then_parsed_date_is(
            datetime(year=expected_year,
                     month=expected_month,
                     day=expected_day))

    def given_now(self, year, month, day, **time):
        now = datetime(year, month, day, **time)
        datetime_mock = Mock(wraps=datetime)
        datetime_mock.utcnow = Mock(return_value=now)
        datetime_mock.now = Mock(return_value=now)
        datetime_mock.today = Mock(return_value=now)
        self.add_patch(patch('dateparser.date.datetime', new=datetime_mock))
        self.add_patch(patch('dateparser.utils.datetime', new=datetime_mock))

    def when_date_is_parsed_with_formats(self,
                                         date_string,
                                         date_formats,
                                         custom_settings=None):
        self.result = date.parse_with_formats(date_string, date_formats,
                                              custom_settings or settings)

    def then_date_was_not_parsed(self):
        self.assertIsNotNone(self.result)
        self.assertIsNone(self.result['date_obj'])

    def then_date_was_parsed(self):
        self.assertIsNotNone(self.result)
        self.assertIsNotNone(self.result['date_obj'])

    def then_parsed_date_is(self, date_obj):
        self.assertEqual(date_obj.date(), self.result['date_obj'].date())

    def then_parsed_period_is(self, period):
        self.assertEqual(period, self.result['period'])
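
# Hedged sketch of the clock-freezing pattern in given_now above: datetime is
# a C type whose attributes cannot be patched directly, so the tests wrap the
# real class in Mock(wraps=datetime) -- construction still delegates to the
# real type -- and pin now/utcnow/today before patching the module-level
# reference (shown here with unittest.mock rather than the mock backport).
from datetime import datetime
from unittest.mock import Mock

frozen_now = datetime(2015, 2, 4)
datetime_mock = Mock(wraps=datetime)
datetime_mock.utcnow = Mock(return_value=frozen_now)
datetime_mock.now = Mock(return_value=frozen_now)
datetime_mock.today = Mock(return_value=frozen_now)

assert datetime_mock.now() == frozen_now       # the clock is frozen
assert datetime_mock(2001, 1, 1).year == 2001  # construction still works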
Example n. 40
class TestParseWithFormatsFunction(BaseTestCase):
    def setUp(self):
        super(TestParseWithFormatsFunction, self).setUp()
        self.result = NotImplemented

    @parameterized.expand([
        param(date_string='yesterday', date_formats=['%Y-%m-%d']),
    ])
    def test_date_with_not_matching_format_is_not_parsed(self, date_string, date_formats):
        self.when_date_is_parsed_with_formats(date_string, date_formats)
        self.then_date_was_not_parsed()

    @parameterized.expand([
        param(date_string='25-03-14', date_formats=['%d-%m-%y'], expected_result=datetime(2014, 3, 25)),
    ])
    def test_should_parse_date(self, date_string, date_formats, expected_result):
        self.when_date_is_parsed_with_formats(date_string, date_formats)
        self.then_date_was_parsed()
        self.then_parsed_period_is('day')
        self.then_parsed_date_is(expected_result)

    @parameterized.expand([
        param(date_string='09.16', date_formats=['%m.%d'], expected_month=9, expected_day=16),
    ])
    def test_should_use_current_year_for_dates_without_year(
        self, date_string, date_formats, expected_month, expected_day
    ):
        self.given_now(2015, 2, 4)
        self.when_date_is_parsed_with_formats(date_string, date_formats)
        self.then_date_was_parsed()
        self.then_parsed_period_is('day')
        self.then_parsed_date_is(datetime(2015, expected_month, expected_day))

    @parameterized.expand([
        param(date_string='August 2014', date_formats=['%B %Y'],
              expected_year=2014, expected_month=8),
    ])
    def test_should_use_last_day_of_month_for_dates_without_day(
        self, date_string, date_formats, expected_year, expected_month
    ):
        self.given_now(2014, 8, 12)
        self.when_date_is_parsed_with_formats(date_string, date_formats)
        self.then_date_was_parsed()
        self.then_parsed_period_is('month')
        self.then_parsed_date_is(datetime(year=expected_year,
                                          month=expected_month,
                                          day=get_last_day_of_month(expected_year, expected_month)))

    @parameterized.expand([
        param(date_string='25-03-14', date_formats='%d-%m-%y', expected_result=datetime(2014, 3, 25)),
    ])
    def test_should_support_a_string_as_date_formats(self, date_string, date_formats, expected_result):
        self.when_date_is_parsed_with_formats(date_string, date_formats)
        self.then_date_was_parsed()
        self.then_parsed_period_is('day')
        self.then_parsed_date_is(expected_result)

    def given_now(self, year, month, day, **time):
        now = datetime(year, month, day, **time)
        datetime_mock = Mock(wraps=datetime)
        datetime_mock.utcnow = Mock(return_value=now)
        datetime_mock.now = Mock(return_value=now)
        datetime_mock.today = Mock(return_value=now)
        self.add_patch(patch('dateparser.date.datetime', new=datetime_mock))

    def when_date_is_parsed_with_formats(self, date_string, date_formats):
        self.result = date.parse_with_formats(date_string, date_formats, settings)

    def then_date_was_not_parsed(self):
        self.assertIsNotNone(self.result)
        self.assertIsNone(self.result['date_obj'])

    def then_date_was_parsed(self):
        self.assertIsNotNone(self.result)
        self.assertIsNotNone(self.result['date_obj'])

    def then_parsed_date_is(self, date_obj):
        self.assertEqual(date_obj.date(), self.result['date_obj'].date())

    def then_parsed_period_is(self, period):
        self.assertEqual(period, self.result['period'])
Example n. 41
class TestParserInitialization(BaseTestCase):
    def setUp(self):
        super(TestParserInitialization, self).setUp()
        self.parser = NotImplemented

    @parameterized.expand([
        param(languages='en'),
        param(languages={'languages': ['en', 'he', 'it']}),
    ])
    def test_error_raised_for_invalid_languages_argument(self, languages):
        self.when_parser_is_initialized(languages=languages)
        self.then_error_was_raised(
            TypeError,
            ["languages argument must be a list (%r given)" % type(languages)])

    @parameterized.expand([
        param(locales='en-001'),
        param(locales={'locales': ['zh-Hant-HK', 'ha-NE', 'se-SE']}),
    ])
    def test_error_raised_for_invalid_locales_argument(self, locales):
        self.when_parser_is_initialized(locales=locales)
        self.then_error_was_raised(
            TypeError,
            ["locales argument must be a list (%r given)" % type(locales)])

    @parameterized.expand([
        param(region=['AW', 'BE']),
        param(region=150),
    ])
    def test_error_raised_for_invalid_region_argument(self, region):
        self.when_parser_is_initialized(region=region)
        self.then_error_was_raised(TypeError, [
            "region argument must be str or unicode (%r given)" % type(region)
        ])

    @parameterized.expand([
        param(try_previous_locales=['ar-OM', 'pt-PT', 'fr-CG', 'uk']),
        param(try_previous_locales='uk'),
        param(try_previous_locales={'try_previous_locales': True}),
        param(try_previous_locales=0),
    ])
    def test_error_raised_for_invalid_try_previous_locales_argument(
            self, try_previous_locales):
        self.when_parser_is_initialized(
            try_previous_locales=try_previous_locales)
        self.then_error_was_raised(TypeError, [
            "try_previous_locales argument must be a boolean (%r given)" %
            type(try_previous_locales)
        ])

    @parameterized.expand([
        param(use_given_order=['da', 'pt', 'ja', 'sv']),
        param(use_given_order='uk'),
        param(use_given_order={'use_given_order': True}),
        param(use_given_order=1),
    ])
    def test_error_raised_for_invalid_use_given_order_argument(
            self, use_given_order):
        self.when_parser_is_initialized(locales=['en', 'es'],
                                        use_given_order=use_given_order)
        self.then_error_was_raised(TypeError, [
            "use_given_order argument must be a boolean (%r given)" %
            type(use_given_order)
        ])

    def test_error_is_raised_when_use_given_order_is_True_and_locales_is_None(
            self):
        self.when_parser_is_initialized(use_given_order=True)
        self.then_error_was_raised(
            ValueError, ["locales must be given if use_given_order is True"])

    def when_parser_is_initialized(self,
                                   languages=None,
                                   locales=None,
                                   region=None,
                                   try_previous_locales=True,
                                   use_given_order=False):
        try:
            self.parser = date.DateDataParser(
                languages=languages,
                locales=locales,
                region=region,
                try_previous_locales=try_previous_locales,
                use_given_order=use_given_order)
        except Exception as error:
            self.error = error
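
# Hedged sketch of the argument validation these tests expect from
# DateDataParser (illustrative; the messages mirror the assertions above):
def check_parser_arguments_sketch(languages=None, locales=None, region=None):
    if languages is not None and not isinstance(languages, list):
        raise TypeError(
            "languages argument must be a list (%r given)" % type(languages))
    if locales is not None and not isinstance(locales, list):
        raise TypeError(
            "locales argument must be a list (%r given)" % type(locales))
    if region is not None and not isinstance(region, str):
        raise TypeError(
            "region argument must be str or unicode (%r given)" % type(region))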
Example n. 42
class TestAws(HvacIntegrationTestCase, TestCase):
    TEST_MOUNT_POINT = 'aws-test'
    TEST_ROLE_NAME = 'hvac-test-role'
    TEST_POLICY_DOCUMENT = {
        'Statement': [
            {
                'Action': 'ec2:Describe*',
                'Effect': 'Allow',
                'Resource': '*'
            },
        ],
        'Version': '2012-10-17'
    }

    def setUp(self):
        super(TestAws, self).setUp()
        if '%s/' % self.TEST_MOUNT_POINT not in self.client.sys.list_mounted_secrets_engines(
        )['data']:
            self.client.sys.enable_secrets_engine(
                backend_type='aws',
                path=self.TEST_MOUNT_POINT,
            )

    def tearDown(self):
        self.client.sys.disable_secrets_engine(path=self.TEST_MOUNT_POINT, )
        super(TestAws, self).tearDown()

    @parameterized.expand([
        param('success', ),
    ])
    def test_configure_root_iam_credentials(self,
                                            label,
                                            credentials='',
                                            raises=None,
                                            exception_message=''):
        access_key = 'butts'
        secret_key = 'secret-butts'
        if raises:
            with self.assertRaises(raises) as cm:
                self.client.secrets.aws.configure_root_iam_credentials(
                    access_key=access_key,
                    secret_key=secret_key,
                    mount_point=self.TEST_MOUNT_POINT,
                )
            self.assertIn(
                member=exception_message,
                container=str(cm.exception),
            )
        else:
            configure_response = self.client.secrets.aws.configure_root_iam_credentials(
                access_key=access_key,
                secret_key=secret_key,
                iam_endpoint='localhost',
                sts_endpoint='localhost',
                mount_point=self.TEST_MOUNT_POINT,
            )
            logging.debug('configure_response: %s' % configure_response)
            self.assertEqual(
                first=bool(configure_response),
                second=True,
            )

    @parameterized.expand([
        param('success', ),
    ])
    def test_configure_lease(self,
                             label,
                             lease='60s',
                             lease_max='120s',
                             raises=None,
                             exception_message=''):
        if raises:
            with self.assertRaises(raises) as cm:
                self.client.secrets.aws.configure_lease(
                    lease=lease,
                    lease_max=lease_max,
                    mount_point=self.TEST_MOUNT_POINT,
                )
            self.assertIn(
                member=exception_message,
                container=str(cm.exception),
            )
        else:
            configure_response = self.client.secrets.aws.configure_lease(
                lease=lease,
                lease_max=lease_max,
                mount_point=self.TEST_MOUNT_POINT,
            )
            logging.debug('configure_response: %s' % configure_response)
            self.assertEqual(
                first=bool(configure_response),
                second=True,
            )

    @parameterized.expand([
        param('success', ),
    ])
    def test_read_lease(self,
                        label,
                        lease='60s',
                        lease_max='120s',
                        configure_first=True,
                        raises=None,
                        exception_message=''):
        if configure_first:
            configure_response = self.client.secrets.aws.configure_lease(
                lease=lease,
                lease_max=lease_max,
                mount_point=self.TEST_MOUNT_POINT,
            )
            logging.debug('configure_response: %s' % configure_response)

        if raises:
            with self.assertRaises(raises) as cm:
                self.client.secrets.aws.read_lease_config(
                    mount_point=self.TEST_MOUNT_POINT, )
            self.assertIn(
                member=exception_message,
                container=str(cm.exception),
            )
        else:
            read_response = self.client.secrets.aws.read_lease_config(
                mount_point=self.TEST_MOUNT_POINT, )
            logging.debug('read_response: %s' % read_response)
            self.assertEqual(
                first=int(lease_max.replace('s', '')),
                second=self.convert_python_ttl_value_to_expected_vault_response(
                    ttl_value=read_response['data']['lease_max'], ),
            )

    @parameterized.expand([
        param('success',
              policy_document={
                  'Statement': [
                      {
                          'Action': 'ec2:Describe*',
                          'Effect': 'Allow',
                          'Resource': '*'
                      },
                  ],
                  'Version': '2012-10-17'
              }),
        param(
            'with policy_arns',
            policy_arns=['arn:aws:iam::aws:policy/AmazonVPCReadOnlyAccess'],
        ),
        param(
            'assumed_role with policy document',
            policy_document={
                'Statement': [
                    {
                        'Action': 'ec2:Describe*',
                        'Effect': 'Allow',
                        'Resource': '*'
                    },
                ],
                'Version': '2012-10-17'
            },
            credential_type='assumed_role',
        ),
        param(
            'invalid credential type',
            credential_type='cat',
            raises=ParamValidationError,
            exception_message='invalid credential_type argument provided',
        ),
    ])
    def test_create_or_update_role(self,
                                   label,
                                   credential_type='iam_user',
                                   policy_document=None,
                                   default_sts_ttl=None,
                                   max_sts_ttl=None,
                                   role_arns=None,
                                   policy_arns=None,
                                   raises=None,
                                   exception_message=''):
        if raises:
            with self.assertRaises(raises) as cm:
                self.client.secrets.aws.create_or_update_role(
                    name=self.TEST_ROLE_NAME,
                    credential_type=credential_type,
                    policy_document=policy_document,
                    default_sts_ttl=default_sts_ttl,
                    max_sts_ttl=max_sts_ttl,
                    role_arns=role_arns,
                    policy_arns=policy_arns,
                    legacy_params=vault_version_lt('0.11.0'),
                    mount_point=self.TEST_MOUNT_POINT,
                )
            self.assertIn(
                member=exception_message,
                container=str(cm.exception),
            )
        else:
            role_response = self.client.secrets.aws.create_or_update_role(
                name=self.TEST_ROLE_NAME,
                credential_type=credential_type,
                policy_document=policy_document,
                default_sts_ttl=default_sts_ttl,
                max_sts_ttl=max_sts_ttl,
                role_arns=role_arns,
                policy_arns=policy_arns,
                legacy_params=vault_version_lt('0.11.0'),
                mount_point=self.TEST_MOUNT_POINT,
            )
            logging.debug('role_response: %s' % role_response)

            self.assertEqual(
                first=bool(role_response),
                second=True,
            )

    @parameterized.expand([
        param('success', ),
    ])
    def test_read_role(self,
                       label,
                       configure_first=True,
                       raises=None,
                       exception_message=''):
        if configure_first:
            self.client.secrets.aws.create_or_update_role(
                name=self.TEST_ROLE_NAME,
                credential_type='iam_user',
                policy_document=self.TEST_POLICY_DOCUMENT,
                legacy_params=vault_version_lt('0.11.0'),
                mount_point=self.TEST_MOUNT_POINT,
            )
        if raises:
            with self.assertRaises(raises) as cm:
                self.client.secrets.aws.read_role(
                    name=self.TEST_ROLE_NAME,
                    mount_point=self.TEST_MOUNT_POINT,
                )
            self.assertIn(
                member=exception_message,
                container=str(cm.exception),
            )
        else:
            read_role_response = self.client.secrets.aws.read_role(
                name=self.TEST_ROLE_NAME,
                mount_point=self.TEST_MOUNT_POINT,
            )
            logging.debug('read_role_response: %s' % read_role_response)
            if vault_version_lt('0.11.0'):
                self.assertDictEqual(
                    d1=json.loads(read_role_response['data']['policy']),
                    d2=self.TEST_POLICY_DOCUMENT,
                )
            # https://github.com/hashicorp/vault/commit/2dcd0aed2a242f53dae03318b4d68693f7d92b81
            elif vault_version_lt('1.0.2'):
                self.assertEqual(
                    first=read_role_response['data']['credential_types'],
                    second=['iam_user'],
                )
            else:
                self.assertEqual(
                    first=read_role_response['data']['credential_type'],
                    second='iam_user',
                )

    @parameterized.expand([
        param('success', ),
    ])
    def test_list_roles(self,
                        label,
                        configure_first=True,
                        raises=None,
                        exception_message=''):
        if configure_first:
            self.client.secrets.aws.create_or_update_role(
                name=self.TEST_ROLE_NAME,
                credential_type='iam_user',
                policy_document=self.TEST_POLICY_DOCUMENT,
                legacy_params=vault_version_lt('0.11.0'),
                mount_point=self.TEST_MOUNT_POINT,
            )
        if raises:
            with self.assertRaises(raises) as cm:
                self.client.secrets.aws.list_roles(
                    mount_point=self.TEST_MOUNT_POINT, )
            self.assertIn(
                member=exception_message,
                container=str(cm.exception),
            )
        else:
            list_roles_response = self.client.secrets.aws.list_roles(
                mount_point=self.TEST_MOUNT_POINT, )
            logging.debug('list_roles_response: %s' % list_roles_response)
            self.assertEqual(
                first=list_roles_response['data']['keys'],
                second=[self.TEST_ROLE_NAME],
            )

    @parameterized.expand([
        param('success', ),
    ])
    def test_delete_role(self,
                         label,
                         configure_first=True,
                         raises=None,
                         exception_message=''):
        if configure_first:
            self.client.secrets.aws.create_or_update_role(
                name=self.TEST_ROLE_NAME,
                credential_type='iam_user',
                policy_document=self.TEST_POLICY_DOCUMENT,
                legacy_params=vault_version_lt('0.11.0'),
                mount_point=self.TEST_MOUNT_POINT,
            )
        if raises:
            with self.assertRaises(raises) as cm:
                self.client.secrets.aws.delete_role(
                    name=self.TEST_ROLE_NAME,
                    mount_point=self.TEST_MOUNT_POINT,
                )
            self.assertIn(
                member=exception_message,
                container=str(cm.exception),
            )
        else:
            delete_role_response = self.client.secrets.aws.delete_role(
                name=self.TEST_ROLE_NAME,
                mount_point=self.TEST_MOUNT_POINT,
            )
            logging.debug('delete_role_response: %s', delete_role_response)
            self.assertEqual(
                first=bool(delete_role_response),
                second=True,
            )
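
These AWS secrets-engine tests all share one skeleton: optionally create a role, then either assert that the client call raises with an expected message or assert on the successful response. Below is a minimal, self-contained sketch of that skeleton, assuming only the parameterized package; fetch_value and the dict-backed store are hypothetical stand-ins for the hvac client calls:

import unittest

from parameterized import param, parameterized


def fetch_value(store, key):
    # Hypothetical stand-in for a client call such as read_role; raises
    # KeyError for a missing key the way the client raises on a bad path.
    return store[key]


class FetchValueTest(unittest.TestCase):
    @parameterized.expand([
        param('success', key='present'),
        param('missing_key', key='absent',
              raises=KeyError, exception_message='absent'),
    ])
    def test_fetch_value(self, label, key, raises=None, exception_message=''):
        store = {'present': 42}
        if raises:
            with self.assertRaises(raises) as cm:
                fetch_value(store, key)
            self.assertIn(
                member=exception_message,
                container=str(cm.exception),
            )
        else:
            self.assertEqual(
                first=fetch_value(store, key),
                second=42,
            )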
Example n. 43
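The Reply params further down call an AssertEquals helper that this excerpt does not show. A one-line hypothetical stand-in, so the lambdas remain runnable in isolation:

def AssertEquals(actual, expected):
    # Hypothetical stand-in for the module's own helper, which is defined
    # outside this excerpt; a plain assert keeps the lambdas simple.
    assert actual == expected, '%r != %r' % (actual, expected)
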
class IoExpectationTest(unittest.TestCase):
    def setUp(self):
        self._io = expect.ExpectedInputOutput()
        sys.stdin = self._io
        sys.stdout = self._io

    def tearDown(self):
        sys.stdin = self._io._original_stdin
        sys.stdout = self._io._original_stdout

    @parameterized.expand([
        # ==== expect.Equals ====
        param('expect_equals',
              expected_io=expect.Equals('Expected output\n'),
              ios=lambda: print('Expected output'),
              error_message=None),
        param('expect_equals_missing_newline',
              expected_io=expect.Equals('\nExpected output\n'),
              ios=lambda: sys.stdout.write('Expected output'),
              error_message=None),
        param('expect_equals_missing_white_space',
              expected_io=expect.Equals(' Expected output '),
              ios=lambda: print('Expected output'),
              error_message=None),
        param('expect_equals_extra_white_space_and_newlines',
              expected_io=expect.Equals('Expected output'),
              ios=lambda: print(' Expected output '),
              error_message=None),
        param('expect_equals_no_output',
              expected_io=expect.Equals('Expected output'),
              ios=lambda: None,
              error_message=("Pending IO expectation never fulfulled:\n"
                             "Equals('Expected output')")),
        param('expect_equals_mismatch',
              expected_io=expect.Equals('Expected output'),
              ios=lambda: (print('An Expected output and some more'),
                           print('Some more other output')),
              error_message=("Unexpected output:\n"
                             "- Equals('Expected output')\n"
                             "+ 'An Expected output and some more\\n'")),
        param('expect_equals_extra_output',
              expected_io=expect.Equals('Expected output'),
              ios=lambda:
              (print('Expected output'), print('Unexpected output')),
              error_message=
              "No more output expected, but got: 'Unexpected output\n'"),

        # ==== expect.Contains ====
        param('expect_contains',
              expected_io=expect.Contains('out'),
              ios=lambda: print('Some output'),
              error_message=None),
        param('expect_contains_no_output',
              expected_io=expect.Contains('out'),
              ios=lambda: None,
              error_message=("Pending IO expectation never fulfulled:\n"
                             "Contains('out')")),
        param('expect_contains_mismatch',
              expected_io=expect.Contains('out'),
              ios=lambda: print('Something else'),
              error_message=("Unexpected output:\n"
                             "- Contains('out')\n"
                             "+ 'Something else\\n'")),
        param('expect_contains_extra_output',
              expected_io=expect.Contains('out'),
              ios=lambda: (print('Some output'), print('Unexpected output')),
              error_message=
              "No more output expected, but got: 'Unexpected output\n'"),

        # ==== expect.Prefix ====
        param('expect_prefix',
              expected_io=expect.Prefix('Expected'),
              ios=lambda: print('Expected output'),
              error_message=None),
        param('expect_prefix_extra_whitespace',
              expected_io=expect.Prefix('Expected'),
              ios=lambda: print('  Expected output'),
              error_message=None),
        param('expect_prefix_no_output',
              expected_io=expect.Prefix('Expected'),
              ios=lambda: None,
              error_message=("Pending IO expectation never fulfulled:\n"
                             "Prefix('Expected')")),
        param('expect_prefix_mismatch',
              expected_io=expect.Prefix('Expected'),
              ios=lambda: print('Something else'),
              error_message=("Unexpected output:\n"
                             "- Prefix('Expected')\n"
                             "+ 'Something else\\n'")),
        param('expect_prefix_extra_output',
              expected_io=expect.Prefix('Expected'),
              ios=lambda:
              (print('Expected output'), print('Unexpected output')),
              error_message=
              "No more output expected, but got: 'Unexpected output\n'"),

        # ==== expect.Regex ====
        param('expect_regex',
              expected_io=expect.Regex('.xpec.*d.*'),
              ios=lambda: print('Expected output'),
              error_message=None),
        param('expect_regex_no_output',
              expected_io=expect.Regex('.xpec.*d.*'),
              ios=lambda: None,
              error_message=("Pending IO expectation never fulfulled:\n"
                             "Regex('.xpec.*d.*')")),
        param('expect_regex_mismatch',
              expected_io=expect.Regex('Expec.*d'),
              ios=lambda: print('Something else'),
              error_message=("Unexpected output:\n"
                             "- Regex('Expec.*d')\n"
                             "+ 'Something else\\n'")),
        param('expect_regex_extra_output',
              expected_io=expect.Regex('.*xpec.*d.*'),
              ios=lambda:
              (print('Expected output'), print('Unexpected output')),
              error_message=
              "No more output expected, but got: 'Unexpected output\n'"),

        # ==== expect.Anything ====
        param('expect_anything_success',
              expected_io=expect.Anything(),
              ios=lambda: print('Some output'),
              error_message=None),
        param('expect_anything_no_output',
              expected_io=expect.Anything(),
              ios=lambda: None,
              error_message=("Pending IO expectation never fulfulled:\n"
                             "Anything()")),
        param('expect_anything_extra_output',
              expected_io=expect.Anything(),
              ios=lambda: (print('Some output'), print('Some more output')),
              error_message=
              "No more output expected, but got: 'Some more output\n'"),

        # ==== Repeatedly ====
        param('expect_repeatedly_equals_at_min',
              expected_io=expect.Repeatedly(expect.Equals('Expected output'),
                                            2, 4),
              ios=lambda: (print('Expected output'), print('Expected output')),
              error_message=None),
        param('expect_repeatedly_equals_in_range',
              expected_io=expect.Repeatedly(expect.Equals('Expected output'),
                                            2, 4),
              ios=lambda: (print('Expected output'), print('Expected output'),
                           print('Expected output')),
              error_message=None),
        param('expect_repeatedly_equals_at_max',
              expected_io=expect.Repeatedly(expect.Equals('Expected output'),
                                            2, 4),
              ios=lambda: (print('Expected output'), print('Expected output'),
                           print('Expected output'), print('Expected output')),
              error_message=None),
        param('expect_repeatedly_equals_no_input',
              expected_io=expect.Repeatedly(expect.Equals('Expected output'),
                                            2, 4),
              ios=lambda: None,
              error_message=("Pending IO expectation never fulfulled:\n"
                             "Repeatedly(Equals('Expected output'), 2, 4)")),
        param('expect_repeatedly_equals_below_min',
              expected_io=expect.Repeatedly(expect.Equals('Expected output'),
                                            2, 4),
              ios=lambda: print('Expected output'),
              error_message=("Pending IO expectation never fulfulled:\n"
                             "Repeatedly(Equals('Expected output'), 1, 3)")),
        param('expect_repeatedly_equals_above_max',
              expected_io=expect.Repeatedly(expect.Equals('Expected output'),
                                            2, 4),
              ios=lambda: (print('Expected output'), print('Expected output'),
                           print('Expected output'), print('Expected output'),
                           print('Expected output')),
              error_message=
              "No more output expected, but got: 'Expected output\n'"),
        param('expect_repeatedly_equals_mismatch',
              expected_io=expect.Repeatedly(expect.Equals('Expected output'),
                                            2, 4),
              ios=lambda: (print('Expected output'), print('Expected output'),
                           print('Some other output')),
              error_message=("Unexpected output:\n"
                             "- Repeatedly(Equals('Expected output'), 0, 2)\n"
                             "+ 'Some other output\\n'")),
        param('short_syntax_expect_indefinitely_repeating_error',
              expected_io=expect.Repeatedly(['a', 'b']),
              ios=lambda: (print('a'), print('b'), print('b'), print('a')),
              error_message=(
                  "Unexpected output:\n"
                  "- Repeatedly(InOrder(Contains('a'), Contains('b')))\n"
                  "+ 'b\\n'")),

        # ==== InOrder ====
        param('expect_in_order_equals',
              expected_io=expect.InOrder(
                  expect.Equals('First expected output'),
                  expect.Equals('Second expected output'),
                  expect.Equals('Third expected output')),
              ios=lambda:
              (print('First expected output'), print('Second expected output'),
               print('Third expected output')),
              error_message=None),
        param('expect_in_order_equals_mismatch',
              expected_io=expect.InOrder(
                  expect.Equals('First expected output'),
                  expect.Equals('Second expected output'),
                  expect.Equals('Third expected output')),
              ios=lambda:
              (print('First expected output'), print('Third expected output')),
              error_message=("Unexpected output:\n"
                             "- InOrder(Equals('Second expected output'), "
                             "Equals('Third expected output'))\n"
                             "+ 'Third expected output\\n'")),
        param('expect_in_order_equals_extra_output',
              expected_io=expect.InOrder(
                  expect.Equals('First expected output'),
                  expect.Equals('Second expected output')),
              ios=lambda:
              (print('First expected output'), print('Second expected output'),
               print('Unexpected output')),
              error_message=
              "No more output expected, but got: 'Unexpected output\n'"),
        param('expect_in_order_repeatedly',
              expected_io=expect.InOrder(
                  expect.Equals('Repeated output').times(2, 4),
                  expect.Equals('Remaining output')),
              ios=lambda:
              (print('Repeated output'), print('Repeated output'),
               print('Repeated output'), print('Remaining output')),
              error_message=None),
        param('expect_any_order_of_in_orders_1',
              expected_io=(expect.AnyOrder(
                  expect.InOrder('In order 1-a', 'In order 1-b'),
                  expect.InOrder('In order 2-a', 'In order 2-b'))),
              ios=lambda: (print('In order 1-a'), print('In order 1-b'),
                           print('In order 2-a'), print('In order 2-b')),
              error_message=None),
        param('expect_any_order_of_in_orders_2',
              expected_io=(expect.AnyOrder(
                  expect.InOrder('In order 1-a', 'In order 1-b'),
                  expect.InOrder('In order 2-a', 'In order 2-b'))),
              ios=lambda: (print('In order 1-a'), print('In order 2-a'),
                           print('In order 2-b'), print('In order 1-b')),
              error_message=None),
        param('expect_any_order_of_in_orders_error',
              expected_io=(expect.AnyOrder(
                  expect.InOrder('In order 1-a', 'In order 1-b'),
                  expect.InOrder('In order 2-a', 'In order 2-b'))),
              ios=lambda: (print('In order 1-a'), print('In order 2-b'),
                           print('In order 2-a'), print('In order 1-b')),
              error_message=("Unexpected output:\n"
                             "- AnyOrder(Contains('In order 1-b'), "
                             "InOrder(Contains('In order 2-a'), "
                             "Contains('In order 2-b')))\n"
                             "+ 'In order 2-b\\n'")),
        param('expect_in_order_of_contains_and_anything',
              expected_io=(expect.AnyOrder(expect.Contains('foo'),
                                           expect.Contains('bar'),
                                           expect.Anything().repeatedly())),
              ios=lambda:
              (print('Second match is "foo".'), print('First match is "bar".'),
               print('Some more output.')),
              error_message=None),
        param('expect_in_order_of_anything_and_contains',
              expected_io=(expect.InOrder(expect.Anything().repeatedly(),
                                          expect.Contains('foo'),
                                          expect.Contains('bar'),
                                          expect.Anything().repeatedly())),
              ios=lambda:
              (print('Some output'), print('Second match is "foo".'),
               print('First match is "bar".'), print('Some more output.')),
              error_message=None),

        # ==== AnyOrder ====
        param('expect_any_order_equals',
              expected_io=expect.AnyOrder(
                  expect.Equals('First expected output'),
                  expect.Equals('Second expected output'),
                  expect.Equals('Third expected output')),
              ios=lambda:
              (print('Second expected output'), print('First expected output'),
               print('Third expected output')),
              error_message=None),
        param('expect_any_order_or_in_order_repetitions',
              expected_io=(expect.AnyOrder(
                  expect.InOrder(
                      expect.Equals('Repeated output').times(2, 4),
                      expect.Equals('Last in order')),
                  expect.Equals('At any time'))),
              ios=lambda: (print('Repeated output'), print('Repeated output'),
                           print('Repeated output'), print('At any time'),
                           print('Repeated output'), print('Last in order')),
              error_message=None),
        param('expect_any_order_of_contains_and_anything',
              expected_io=(expect.AnyOrder(expect.Contains('foo'),
                                           expect.Contains('bar'),
                                           expect.Anything().repeatedly())),
              ios=lambda:
              (print('Something'), print('First match is "bar".'),
               print('Something else'), print('Second match is "foo".'),
               print('Some more output.')),
              error_message=None),
        param('expect_any_order_of_anything_and_contains',
              expected_io=(expect.AnyOrder(expect.Anything().repeatedly(),
                                           expect.Contains('foo'),
                                           expect.Contains('bar'))),
              ios=lambda:
              (print('Something'), print('First match is "bar".'),
               print('Something else'), print('Second match is "foo".'),
               print('Some more output.')),
              error_message=None),

        # ==== Reply ====
        param('expect_reply',
              expected_io=expect.Reply('yes'),
              ios=lambda: AssertEquals(input(), 'yes'),
              error_message=None),
        param('expect_reply_with_prompt',
              expected_io=(expect.InOrder(expect.Equals('Will it work? '),
                                          expect.Reply('yes'))),
              ios=lambda: AssertEquals(input('Will it work? '), 'yes'),
              error_message=None),

        # ==== Syntactic sugars ====
        param('short_syntax_expect_in_order_equals',
              expected_io=[
                  'First expected output', 'Second expected output',
                  'Third expected output'
              ],
              ios=lambda:
              (print('First expected output'), print('Second expected output'),
               print('Third expected output')),
              error_message=None),
        param('short_syntax_expect_repeatedly_contains',
              expected_io=expect.Repeatedly('out'),
              ios=lambda: (print('Expected output'), print('Expected output'),
                           print('Expected output')),
              error_message=None),
        param('short_syntax_expect_repeatedly_contains_in_order',
              expected_io=expect.Repeatedly(['a', 'b']),
              ios=lambda: (print('a'), print('b'), print('a'), print('b')),
              error_message=None),
    ])
    def test_expectation(self, test_name, expected_io, ios, error_message):
        self._io.set_expected_io(expected_io)
        ios()
        if error_message is None:
            self._io.assert_expectations_fulfilled()
        else:
            with self.assertRaises(AssertionError) as error:
                self._io.assert_expectations_fulfilled()
            if error_message:
                self.assertEqual(error_message, str(error.exception))

    def test_documentation_example(self):
        self._io.set_expected_io(
            expect.InOrder(expect.Contains('initialization'),
                           expect.Anything().repeatedly(),
                           expect.Equals('Proceed?'), expect.Reply('yes'),
                           expect.Prefix('Success')))

        print('Program initialization...')
        print('Loading resources...')
        print('Initialization complete.')
        print('Proceed? ')
        if input() == 'yes':
            print('Success')
        else:
            print('Aborting')

        self._io.assert_expectations_fulfilled()

    def test_output_was(self):
        print('Some output')
        self._io.assert_output_was('Some output')
        print('Some more output')
        self._io.assert_output_was('Some more output')

    def test_set_expected_io_ignore_previous_outputs(self):
        print('Some ignored output')
        self._io.set_expected_io('Some expected output')
        print('Some expected output')
        self._io.assert_expectations_fulfilled()
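
The setUp/tearDown pair above swaps sys.stdin and sys.stdout for the ExpectedInputOutput double and restores the originals afterwards. A minimal sketch of the same fixture using unittest's addCleanup, which restores the real streams even if the test body or a later setUp step fails; the expect attributes used are exactly those exercised above, while the greeter test itself is hypothetical:

import sys
import unittest

import expect


class GreeterTest(unittest.TestCase):
    def setUp(self):
        self._io = expect.ExpectedInputOutput()
        sys.stdin = self._io
        sys.stdout = self._io
        # addCleanup runs even when setUp or the test raises, so the real
        # streams always come back.
        self.addCleanup(setattr, sys, 'stdin', self._io._original_stdin)
        self.addCleanup(setattr, sys, 'stdout', self._io._original_stdout)

    def test_greeting(self):
        self._io.set_expected_io(
            expect.InOrder(expect.Equals('Name? '),
                           expect.Reply('Ada'),
                           expect.Prefix('Hello')))
        name = input('Name? ')
        print('Hello, %s!' % name)
        self._io.assert_expectations_fulfilled()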
Example n. 44

import os

from rsmtool.test_utils import (check_run_experiment,
                                collect_warning_messages_from_report,
                                do_run_experiment)

# allow test directory to be set via an environment variable
# which is needed for package testing
TEST_DIR = os.environ.get('TESTDIR', None)
if TEST_DIR:
    rsmtool_test_dir = TEST_DIR
else:
    from rsmtool.test_utils import rsmtool_test_dir
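
Each param below is consumed by a runner function that forwards its arguments to check_run_experiment; in the real module the runner sits directly under the @parameterized decorator. A hedged sketch of that runner, assuming check_run_experiment takes the source directory name, the experiment id, and keyword flags such as consistency and skll, and that given_test_dir is how the TESTDIR override is threaded through:

def test_run_experiment_parameterized(*args, **kwargs):
    # Forward the TESTDIR override when package testing sets it; the
    # given_test_dir keyword is an assumption about check_run_experiment.
    if TEST_DIR:
        kwargs['given_test_dir'] = TEST_DIR
    check_run_experiment(*args, **kwargs)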


@parameterized([
    param('lr-with-h2-include-zeros',
          'lr_with_h2_include_zeros',
          consistency=True),
    param('lr-with-h2-and-length', 'lr_with_h2_and_length', consistency=True),
    param('lr-with-h2-named-sc1', 'lr_with_h2_named_sc1', consistency=True),
    param('lars', 'Lars', skll=True),
    param('lars-custom-objective', 'Lars_custom_objective', skll=True),
    param('logistic-regression', 'LogisticRegression', skll=True),
    param('logistic-regression-custom-objective',
          'LogisticRegression_custom_objective',
          skll=True),
    param('logistic-regression-custom-objective-and-params',
          'LogisticRegression_custom_objective_and_params',
          skll=True),
    param('logistic-regression-expected-scores',
          'LogisticRegression_expected_scores',
          skll=True),