Esempio n. 1
0
class TestCaseDataValidation:
    """Check how request bodies and websocket messages reach handlers.

    ``test_request_data`` posts several body encodings to a route that
    receives ``http.RequestData`` and echoes it back. ``test_websocket_message_data``
    sends messages to a ``WebSocketEndpoint`` configured with a given
    ``encoding`` and checks what the endpoint sends back.
    """

    @pytest.fixture(scope="function")
    def app(self):
        return Flama(schema=None, docs=None)

    @pytest.fixture(scope="function")
    def client(self, app):
        return TestClient(app)

    @pytest.mark.parametrize(
        "request_params,response_status,response_json",
        [
            # JSON
            param(
                {"json": {"abc": 123}},
                200,
                {"data": {"abc": 123}},
                id="valid json body",
            ),
            param({}, 200, {"data": None}, id="empty json body"),
            # Urlencoding
            param(
                {"data": {"abc": 123}},
                200,
                {"data": {"abc": "123"}},
                id="valid urlencoded body",
            ),
            param(
                {"headers": {"content-type": "application/x-www-form-urlencoded"}},
                200,
                {"data": None},
                id="empty urlencoded body",
            ),
            # Multipart
            param(
                {"files": {"a": ("b", "123")}, "data": {"b": "42"}},
                200,
                {"data": {"a": {"filename": "b", "content": "123"}, "b": "42"}},
                id="multipart",
            ),
            # Misc
            param(
                {"data": b"...", "headers": {"content-type": "unknown"}},
                415,
                None,
                id="unknown body type",
            ),
            param(
                {"data": b"...", "headers": {"content-type": "application/json"}},
                400,
                None,
                id="json parse failure",
            ),
        ],
    )
    def test_request_data(self, request_params, response_status, response_json,
                          app, client):
        """POST ``request_params`` and compare the status and echoed body."""
        @app.route("/request_data/", methods=["POST"])
        async def get_request_data(data: http.RequestData):
            # Empty bodies are expected to echo back as ``None``. Only
            # mapping-like payloads are walked, converting uploaded files
            # (objects exposing ``filename``) into JSON-serialisable dicts.
            # Failures while reading a file are not swallowed, so they
            # surface as a test failure.
            if hasattr(data, "items"):
                data = {
                    key: value if not hasattr(value, "filename") else {
                        "filename": value.filename,
                        "content": (await value.read()).decode("utf-8"),
                    }
                    for key, value in data.items()
                }

            return {"data": data}

        response = client.post("/request_data/", **request_params)
        assert response.status_code == response_status, str(response.content)
        if response_json is not None:
            assert response.json() == response_json

    @pytest.mark.parametrize(
        "encoding,send_method,data,expected_result",
        [
            # bytes
            param("bytes", "bytes", b"foo",
                  {"type": "websocket.send", "bytes": b"foo"},
                  id="bytes"),
            param("bytes", "text", b"foo",
                  {"type": "websocket.close", "code": 1003},
                  id="bytes wrong format"),
            # text
            param("text", "text", "foo",
                  {"type": "websocket.send", "text": "foo"},
                  id="text"),
            param("text", "bytes", "foo",
                  {"type": "websocket.close", "code": 1003},
                  id="text wrong format"),
            # json
            param("json", "json", {"foo": "bar"},
                  {"type": "websocket.send", "text": '{"foo": "bar"}'},
                  id="json from json"),
            param("json", "text", '{"foo": "bar"}',
                  {"type": "websocket.send", "text": '{"foo": "bar"}'},
                  id="json from text"),
            param("json", "bytes", b'{"foo": "bar"}',
                  {"type": "websocket.send", "text": '{"foo": "bar"}'},
                  id="json from bytes"),
            param("json", "bytes", b'{"foo":',
                  {"type": "websocket.close", "code": 1003},
                  id="json wrong format"),
        ],
    )
    def test_websocket_message_data(self, encoding, send_method, data,
                                    expected_result, app, client):
        """Send ``data`` with ``send_method`` and check the endpoint's reply."""
        # Alias needed because a class body cannot reference an enclosing
        # function variable of the same name as the attribute it defines.
        encoding_ = encoding

        @app.websocket_route("/websocket/")
        class Endpoint(WebSocketEndpoint):
            encoding = encoding_

            async def on_receive(self, websocket: websockets.WebSocket,
                                 data: websockets.Data):
                # Echo the decoded message back using the sender that matches
                # the configured encoding (send_bytes / send_text / send_json).
                await getattr(websocket, f"send_{encoding_}")(data)

        with client.websocket_connect("/websocket/") as ws:
            getattr(ws, f"send_{send_method}")(data)
            message = ws.receive()

        assert message == expected_result
Esempio n. 2
0
class TestZoteroWrap:
    """Tests for ``nat.ZoteroWrap``.

    Distant (pyzotero) calls are patched out with ``mocker``/``monkeypatch``.
    ``test_load_cache`` relies on the shared cache written by ``test_cache``.
    """

    # __init__

    def test_init(self, zw0_shared, shared_directory, cache_filename):
        """Test the creation of a ZoteroWrap instance."""
        assert zw0_shared.cache_path == os.path.join(shared_directory,
                                                     cache_filename)
        assert zw0_shared.reference_types == []
        assert zw0_shared.reference_templates == {}
        assert zw0_shared._references == []

    # initialize

    def test_initialize_cache(self, mocker, zw0):
        """When there are data cached."""
        mocker.patch("nat.ZoteroWrap.load_cache")
        zw0.initialize()
        # TODO Use assert_called_once() with Python 3.6+.
        nat.ZoteroWrap.load_cache.assert_called_once_with()

    def test_initialize_distant(self, mocker, zw0):
        """When there are no data cached."""
        mocker.patch("nat.ZoteroWrap.load_distant")
        zw0.initialize()
        # TODO Use assert_called_once() with Python 3.6+.
        nat.ZoteroWrap.load_distant.assert_called_once_with()

    # cache

    def test_cache(self, zw0_shared, references, reference_types,
                   reference_templates):
        """Test ZoteroWrap.cache()."""
        zw0_shared._references = references
        zw0_shared.reference_types = reference_types
        zw0_shared.reference_templates = reference_templates
        zw0_shared.cache()
        assert os.path.getsize(zw0_shared.cache_path) > 0

    # load_cache

    # Should be executed after test_cache().
    def test_load_cache(self, zw0_shared):
        """Test ZoteroWrap.load_cache()."""
        zw0_shared.load_cache()
        assert zw0_shared._references == REFERENCES
        assert zw0_shared.reference_types == REFERENCE_TYPES
        assert zw0_shared.reference_templates == REFERENCE_TEMPLATES

    # load_distant

    def test_load_distant(self, monkeypatch, zw0, references, reference_types,
                          reference_templates):
        """Test ZoteroWrap.load_distant()."""
        scope = "nat.ZoteroWrap."
        monkeypatch.setattr(scope + "get_references", lambda _: references)
        monkeypatch.setattr(scope + "get_reference_types",
                            lambda _: reference_types)
        monkeypatch.setattr(scope + "get_reference_templates",
                            lambda _, x: reference_templates)
        zw0.load_distant()
        assert zw0._references == REFERENCES
        assert zw0.reference_types == REFERENCE_TYPES
        assert zw0.reference_templates == REFERENCE_TEMPLATES
        assert os.path.getsize(zw0.cache_path) > 0

    # create_local_reference

    @mark.parametrize("zww", [
        param(lazy_fixture("zw"), id="initialized"),
        param(lazy_fixture("zw0"), id="not_initialized")
    ])
    def test_create_local_reference(self, zww, reference):
        """When a ZoteroWrap instance has several references or none."""
        reference_count = zww.reference_count()
        zww.create_local_reference(reference)
        assert zww.reference_count() == reference_count + 1
        assert zww._references[-1] == reference
        assert os.path.getsize(zww.cache_path) > 0

    # create_distant_reference

    def test_create_distant_reference_successful(self, monkeypatch, zw0,
                                                 reference):
        """When Zotero.create_items() has been successful."""
        creation_status = {
            "failed": {},
            "success": {
                "0": reference["key"]
            },
            "successful": {
                "0": reference
            },
            "unchanged": {}
        }
        monkeypatch.setattr("nat.ZoteroWrap.validate_reference_data",
                            lambda _, x: None)
        monkeypatch.setattr("pyzotero.zotero.Zotero.create_items",
                            lambda _, x: creation_status)
        assert zw0.create_distant_reference(reference["data"]) == reference

    def test_create_distant_reference_unsuccessful(self, monkeypatch, zw0,
                                                   reference):
        """When Zotero.create_items() hasn't been successful."""
        # NB: Creation status expected when unsuccessful hasn't been observed.
        monkeypatch.setattr("nat.ZoteroWrap.validate_reference_data",
                            lambda _, x: None)
        monkeypatch.setattr("pyzotero.zotero.Zotero.create_items",
                            lambda _, x: {})
        with raises(nat.zotero_wrap.CreateZoteroItemError):
            zw0.create_distant_reference(reference["data"])

    def test_create_distant_reference_invalid(self, monkeypatch, mocker, zw0,
                                              reference):
        """When ZoteroWrap.validate_reference_data() hasn't been successful."""

        # NB: Exception raised at the validate_reference_data() level isn't captured.
        def raise_exception(_, x):
            raise pyzotero.zotero_errors.InvalidItemFields

        monkeypatch.setattr("pyzotero.zotero.Zotero.check_items",
                            raise_exception)
        mocker.patch("pyzotero.zotero.Zotero.create_items")
        with raises(nat.zotero_wrap.InvalidZoteroItemError):
            zw0.create_distant_reference(reference["data"])
        # TODO Use assert_not_called() with Python 3.5+.
        assert not pyzotero.zotero.Zotero.create_items.called

    # update_local_reference

    def test_update_local_reference(self, zw, reference):
        """When a ZoteroWrap instance has several references."""
        references = deepcopy(zw._references)
        reference_count = zw.reference_count()
        index = reference_count // 2
        zw.update_local_reference(index, reference)
        changed = [
            new for old, new in zip(references, zw._references) if old != new
        ]
        assert zw.reference_count() == reference_count
        assert zw._references[index] == reference
        assert len(changed) == 1
        assert os.path.getsize(zw.cache_path) > 0

    def test_update_local_reference_empty(self, zw0, reference):
        """When a ZoteroWrap instance has no references."""
        with raises(IndexError):
            zw0.update_local_reference(0, reference)
        assert zw0.reference_count() == 0
        # Nothing must have been cached for a failed update.
        assert len(os.listdir(os.path.dirname(zw0.cache_path))) == 0

    # update_distant_reference

    def test_update_distant_reference(self, monkeypatch, mocker, zw0,
                                      reference):
        """Test ZoteroWrap.update_distant_reference()."""
        monkeypatch.setattr("nat.ZoteroWrap.validate_reference_data",
                            lambda _, x: None)
        mocker.patch("pyzotero.zotero.Zotero.update_item")
        zw0.update_distant_reference(reference)
        # TODO Use assert_called_once() with Python 3.6+.
        pyzotero.zotero.Zotero.update_item.assert_called_once_with(reference)

    def test_update_distant_reference_invalid(self, monkeypatch, mocker, zw0,
                                              reference):
        """When ZoteroWrap.validate_reference_data() hasn't been successful."""

        # NB: Exception raised at the validate_reference_data() level isn't captured.
        def raise_exception(_, x):
            raise pyzotero.zotero_errors.InvalidItemFields

        monkeypatch.setattr("pyzotero.zotero.Zotero.check_items",
                            raise_exception)
        mocker.patch("pyzotero.zotero.Zotero.update_item")
        with raises(nat.zotero_wrap.InvalidZoteroItemError):
            zw0.update_distant_reference(reference)
        # TODO Use assert_not_called() with Python 3.5+.
        assert not pyzotero.zotero.Zotero.update_item.called

    # validate_reference_data

    def test_validate_reference_data(self, monkeypatch, zw0, reference):
        """When Zotero.check_items() isn't successful."""
        def raise_exception(_, x):
            raise pyzotero.zotero_errors.InvalidItemFields

        monkeypatch.setattr("pyzotero.zotero.Zotero.check_items",
                            raise_exception)
        with raises(nat.zotero_wrap.InvalidZoteroItemError):
            zw0.validate_reference_data(reference["data"])

    # get_references

    def test_get_references(self, mocker, zw0):
        """Test ZoteroWrap.get_references()."""
        mocker_top = mocker.patch("pyzotero.zotero.Zotero.top")
        mocker.patch("pyzotero.zotero.Zotero.everything")
        zw0.get_references()
        # TODO Use assert_called_once() with Python 3.6+.
        pyzotero.zotero.Zotero.top.assert_called_once_with()
        # Compare against top()'s return value without calling the mock again.
        pyzotero.zotero.Zotero.everything.assert_called_once_with(
            mocker_top.return_value)

    # get_reference_types

    def test_get_reference_types(self, monkeypatch, zw0, item_types):
        """Test ZoteroWrap.get_reference_types()."""
        monkeypatch.setattr("pyzotero.zotero.Zotero.item_types",
                            lambda _: item_types)
        assert zw0.get_reference_types() == REFERENCE_TYPES

    # get_reference_templates

    def test_get_reference_templates(self, monkeypatch, zw0, reference_types):
        """Test ZoteroWrap.get_reference_templates()."""
        def patch(_, ref_type):
            from tests.zotero.data import GET_REFERENCE_TEMPLATES
            index = reference_types.index(ref_type)
            return GET_REFERENCE_TEMPLATES[index]

        monkeypatch.setattr("nat.ZoteroWrap.get_reference_template", patch)
        assert zw0.get_reference_templates(
            reference_types) == REFERENCE_TEMPLATES

    # get_reference_template

    @mark.parametrize("ref_type, patched, expected", TEMPLATES_TEST_DATA)
    def test_get_reference_template(self, monkeypatch, zw0, ref_type, patched,
                                    expected):
        """Test ZoteroWrap.get_reference_template() for each known cases."""
        monkeypatch.setattr("pyzotero.zotero.Zotero.item_template",
                            lambda _, x: patched)
        assert zw0.get_reference_template(ref_type) == expected

    # get_reference

    def test_get_reference(self, mocker, zw0, reference):
        """Test ZoteroWrap.get_reference()."""
        reference_key = reference["key"]
        mocker.patch("pyzotero.zotero.Zotero.item")
        zw0.get_reference(reference_key)
        pyzotero.zotero.Zotero.item.assert_called_once_with(reference_key)

    # reference_count

    @mark.parametrize("zww, expected", [
        param(lazy_fixture("zw"), len(REFERENCES), id="initialized"),
        param(lazy_fixture("zw0"), 0, id="not_initialized")
    ])
    def test_reference_count(self, zww, expected):
        """When a ZoteroWrap instance has several references or none."""
        assert zww.reference_count() == expected

    # reference_data

    def test_reference_data(self, zw0, reference):
        """Test ZoteroWrap.reference_data()."""
        zw0._references.append(reference)
        assert zw0.reference_data(0) == reference["data"]

    # reference_extra_field

    @mark.parametrize("value, expected", [
        param(PMID_STR, (PMID, ""), id="one"),
        param(PMID_STR + "\nPMCID: PMC1234567", (PMID, "PMC1234567"),
              id="several"),
        param("", ("", ""), id="empty")
    ])
    def test_reference_extra_field(self, zw0, value, expected):
        """When a reference has one, several, or no extra field(s)."""
        init(zw0, "extra", value)
        assert zw0.reference_extra_field("PMID", 0) == expected[0]
        assert zw0.reference_extra_field("PMCID", 0) == expected[1]

    # reference_type

    def test_reference_type(self, zw0):
        """Test ZoteroWrap.reference_type()."""
        value = "journalArticle"
        init(zw0, "itemType", value)
        assert zw0.reference_type(0) == value

    # reference_key

    def test_reference_key(self, zw0):
        """Test ZoteroWrap.reference_key()."""
        value = "01ABC3DE"
        init(zw0, "key", value, is_data_subfield=False)
        assert zw0.reference_key(0) == value

    # reference_doi

    def test_reference_doi(self, zw0):
        """When a reference has a DOI."""
        init(zw0, "DOI", DOI)
        assert zw0.reference_doi(0) == DOI

    def test_reference_doi_extra(self, zw0, reference_book):
        """When a reference has a DOI as an extra field (not a 'journalArticle')."""
        reference_book["data"]["extra"] = DOI_STR
        zw0._references.append(reference_book)
        assert zw0.reference_doi(0) == DOI

    def test_reference_doi_without(self, zw0):
        """When a reference has no DOI (dedicated field or in 'extra')."""
        reference = init(zw0, "DOI", "")
        # 'extra' is a data subfield, like 'DOI' and 'date'.
        reference["data"]["extra"] = ""
        assert zw0.reference_doi(0) == ""

    # reference_pmid

    @mark.parametrize("value, expected", [
        param(PMID_STR, PMID, id="with"),
        param("", "", id="without"),
    ])
    def test_reference_pmid(self, zw0, value, expected):
        """When a reference has a PMID as an extra field or none."""
        init(zw0, "extra", value)
        assert zw0.reference_pmid(0) == expected

    # reference_unpublished_id

    @mark.parametrize("value, expected", [
        param(UPID_STR, UPID, id="with"),
        param("", "", id="without"),
    ])
    def test_reference_unpublished_id(self, zw0, value, expected):
        """When a reference has an UNPUBLISHED ID as an extra field or none."""
        init(zw0, "extra", value)
        assert zw0.reference_unpublished_id(0) == expected

    # reference_id

    def test_reference_id_doi(self, zw0):
        """When a reference has a DOI."""
        init(zw0, "DOI", DOI)
        assert zw0.reference_id(0) == DOI

    def test_reference_id_doi_extra(self, zw0, reference_book):
        """When a reference has a DOI as an extra field (not a 'journalArticle')."""
        reference_book["data"]["extra"] = DOI_STR
        zw0._references.append(reference_book)
        assert zw0.reference_id(0) == DOI

    @mark.parametrize("value, expected", [
        param(PMID_STR, "PMID_" + PMID, id="PMID"),
        param(UPID_STR, "UNPUBLISHED_" + UPID, id="UNPUBLISHED_ID")
    ])
    def test_reference_id_extra(self, zw0, value, expected):
        """When a reference has a PMID or a UNPUBLISHED ID as an extra field."""
        init(zw0, "extra", value)
        assert zw0.reference_id(0) == expected

    def test_reference_id_without(self, zw0):
        """When a reference has no ID (DOI, PMID, UNPUBLISHED ID)."""
        reference = init(zw0, "DOI", "")
        # 'extra' is a data subfield, like 'DOI' and 'date'.
        reference["data"]["extra"] = ""
        assert zw0.reference_id(0) == ""

    # reference_title

    def test_reference_title(self, zw0):
        """When a reference has a title."""
        value = "Journal Article"
        init(zw0, "title", value)
        assert zw0.reference_title(0) == value

    # reference_creator_surnames

    @mark.parametrize("value, expected", [
        param(CREATORS[:1], ["AuthorLastA"], id="one_author"),
        param(CREATORS[:2], ["AuthorLastA", "AuthorLast-B"], id="two_authors"),
        param(CREATORS[3:-1], ["EditorLastA", "EditorLast-B"],
              id="several_not_authors"),
        param(CREATORS[:-1], ["AuthorLastA", "AuthorLast-B", "AuthorLastC"],
              id="several_mixed"),
        param(CREATORS[-1:], [], id="no_last_name"),
    ])
    def test_reference_creator_surnames(self, zw0, value, expected):
        """Test reference_creator_surnames().

        When a reference has one creator of type 'author'.
        When a reference has two creators of type 'author'.
        When a reference has several creators which aren't of type 'author'.
        When a reference has several creators of type 'author' and not.
        When a reference has one creator of type 'author' described with 'name'.
        """
        init(zw0, "creators", value)
        assert zw0.reference_creator_surnames(0) == expected

    # reference_creator_surnames_str

    @mark.parametrize("value, expected", [
        param(CREATORS[:1], "AuthorLastA", id="one_author"),
        param(CREATORS[:2], "AuthorLastA, AuthorLast-B", id="two_authors"),
        param(CREATORS[-1:], "", id="no_last_name"),
    ])
    def test_reference_creator_surnames_str(self, zw0, value, expected):
        """Test reference_creator_surnames_str().

        When a reference has one creator of type 'author'.
        When a reference has two creators of type 'author'.
        When a reference has one creator of type 'author' described with 'name'.
        """
        init(zw0, "creators", value)
        assert zw0.reference_creator_surnames_str(0) == expected

    # reference_date

    @mark.parametrize(
        "value, expected",
        [param(DATE, DATE, id="with"),
         param("", "", id="empty")])
    def test_reference_date(self, zw0, value, expected):
        """When a reference has a date or none."""
        init(zw0, "date", value)
        assert zw0.reference_date(0) == expected

    # reference_year

    @mark.parametrize("value, expected", [
        param(DATE, 2017, id="recognized"),
        param("9 avr. 2017", 2017, id="unrecognized"),
        param("", "", id="without")
    ])
    def test_reference_year(self, zw0, value, expected):
        """Test reference_year().

        When a reference has a year recognized by dateutil.parser.parse().
        When a reference has a year not recognized by dateutil.parser.parse().
        When a reference has no year (because no date).
        """
        init(zw0, "date", value)
        assert zw0.reference_year(0) == expected

    # reference_journal

    def test_reference_journal(self, zw0):
        """When a reference has a journal."""
        value = "The Journal of Journals"
        init(zw0, "publicationTitle", value)
        assert zw0.reference_journal(0) == value

    def test_reference_journal_without(self, zw0, reference_book):
        """When a reference has no journal (not a 'journalArticle')."""
        zw0._references.append(reference_book)
        assert zw0.reference_journal(0) == "(book)"

    # reference_index

    def test_reference_index_one_found(self, zw0):
        """When a reference in the reference list has the searched reference ID."""
        init(zw0, "extra", PMID_STR)
        init(zw0, "DOI", DOI)
        init(zw0, "extra", UPID_STR)
        assert zw0.reference_index(DOI) == 1

    def test_reference_index_several_found(self, zw0):
        """When there are duplicates in the reference list (same reference ID)."""
        init(zw0, "DOI", DOI)
        init(zw0, "extra", PMID_STR)
        init(zw0, "extra", PMID_STR)
        assert zw0.reference_index("PMID_" + PMID) == 1

    def test_reference_index_not_found(self, zw0):
        """When no reference in the reference list has the searched reference ID."""
        init(zw0, "extra", PMID_STR)
        init(zw0, "extra", UPID_STR)
        exception_str = "^ID: {}$".format(DOI)
        with raises(nat.zotero_wrap.ReferenceNotFoundError,
                    match=exception_str):
            zw0.reference_index(DOI)

    # reference_creators_citation

    @mark.parametrize("value, expected", [
        param(CREATORS[:1], "AuthorLastA (2017)", id="one_author"),
        param(CREATORS[:2],
              "AuthorLastA and AuthorLast-B (2017)",
              id="two_authors"),
        param(CREATORS[:3], "AuthorLastA et al. (2017)", id="three_authors"),
        param(CREATORS[-1:], "", id="no_last_name"),
    ])
    def test_reference_creators_citation(self, zw0, value, expected):
        """Test reference_creators_citation().

        When a reference with an ID and a date has one creator.
        When a reference with an ID and a date has two creators.
        When a reference with an ID and a date has three creators.
        When a reference with an ID and a date has one creator described with 'name'.
        """
        reference = init(zw0, "DOI", DOI)
        reference["data"]["date"] = DATE
        reference["data"]["creators"] = value
        assert zw0.reference_creators_citation(DOI) == expected

    @mark.parametrize("value, expected", [
        param(CREATORS[:3], "AuthorLastA et al. ()", id="three_authors"),
        param(CREATORS[-1:], "", id="no_last_name"),
    ])
    def test_reference_creators_citation_year_without(self, zw0, value,
                                                      expected):
        """When a reference has no year (because no date)."""
        reference = init(zw0, "DOI", DOI)
        reference["data"]["date"] = ""
        reference["data"]["creators"] = value
        assert zw0.reference_creators_citation(DOI) == expected
Esempio n. 3
0

@fixture
def request_session():
    """Provide a mock whose attribute set is restricted to ``Session``'s API."""
    session_mock = Mock(spec=Session, name="request_session")
    return session_mock


@mark.parametrize(
    'game, url, response, result',
    [
        param(
            Game.mw_mp,
            (
                'http://fake-host/api'
                '/platform/battle/gamer/test_user%231234/matches'
                '/mp/start/0/end/0/details'
            ),
            MATCH_1_IN,
            MATCH_1_OUT,
            id='Parse MP matches',
        ),
        param(
            Game.mw_wz,
            (
                'http://fake-host/api'
                '/platform/battle/gamer/test_user%231234/matches'
                '/wz/start/0/end/0/details'
            ),
            MATCH_2_IN,
            MATCH_2_OUT,
            id='Parse WZ matches',
Esempio n. 4
0
class TestAnyObject:
    """Tests for the ``AnyObject`` matcher: type checks, attribute checks,
    attribute mocking and string representation."""

    @pytest.mark.parametrize(
        "type_,instance,matches",
        (
            # If you don't specify, we really don't care
            param(None, object(), True, id="No type, object ok"),
            param(None, ValueError(), True, id="No type, exception ok"),
            param(None, 1, True, id="No type, int ok"),
            # Actual classes to test for
            param(ValueError, ValueError(), True, id="Exact class ok"),
            param(BaseException, ValueError(), True, id="Subclass ok"),
            param(ValueError, object(), False, id="Superclass not ok"),
            # An *instance* of an unrelated class; passing the class object
            # itself would only test that a type is not a ValueError instance.
            param(ValueError, IndexError(), False, id="Different class not ok"),
        ),
    )
    def test_it_matches_types_correctly(self, type_, instance, matches):
        if matches:
            assert instance == AnyObject(type_=type_)
        else:
            assert instance != AnyObject(type_=type_)

    @pytest.mark.parametrize(
        "attributes,matches",
        (
            ({"one": "one", "two": "two"}, True),
            ({"one": "one"}, True),
            ({"one": Any()}, True),
            ({}, True),
            ({"one": "one", "not_an_attr": ""}, False),
            ({"one": "BAD", "two": "two"}, False),
        ),
    )
    def test_it_matches_with_attributes_correctly(self, attributes, matches):
        other = ValueObject(one="one", two="two")

        if matches:
            assert other == AnyObject.with_attrs(attributes)
        else:
            assert other != AnyObject.with_attrs(attributes)

    @pytest.mark.parametrize("bad_input", (None, 1, []))
    def test_it_raise_ValueError_if_attributes_does_not_support_items(
            self, bad_input):
        with pytest.raises(ValueError):
            AnyObject.with_attrs(bad_input)

    def test_it_mocks_attributes(self):
        matcher = AnyObject.with_attrs({"a": "A"})

        assert matcher.a == "A"

        with pytest.raises(AttributeError):
            matcher.b  # pylint: disable=pointless-statement

    @pytest.mark.parametrize(
        "magic_method",
        [
            # All the magic methods we rely on post init
            "__eq__",
            "__getattr__",
            "__str__",
        ],
    )
    def test_setting_magic_methods_as_attributes_does_not_set_attributes(
            self, magic_method):
        # There's no sensible reason to do it, but we should still be able to
        # function normally if you do.

        weird_matcher = AnyObject.with_attrs({magic_method: "test"})

        result = getattr(weird_matcher, magic_method)

        assert callable(result)

    def test_type_and_attributes_at_once(self):
        matcher = AnyObject.of_type(ValueObject).with_attrs({"one": "one"})

        assert ValueObject("one", "two") == matcher
        assert NotValueObject("one", "two") != matcher
        assert ValueObject("bad", "two") != matcher

    @pytest.mark.parametrize(
        "type_,attributes,string",
        (
            (None, None, "<Any instance of 'object'>"),
            (ValueObject, None, "<Any instance of 'ValueObject'>"),
            (
                None,
                {"a": "b"},
                "<Any instance of 'object' with attributes {'a': 'b'}>",
            ),
            (
                ValueObject,
                {"a": "b"},
                "<Any instance of 'ValueObject' with attributes {'a': 'b'}>",
            ),
        ),
    )
    def test_stringification(self, type_, attributes, string):
        matcher = AnyObject(type_=type_, attributes=attributes)

        assert str(matcher) == string
Esempio n. 5
0
class TestTokenCalls:
    """Tests for the base client's OAuth2 token requests (``get_token`` and
    ``_get_refreshed_token``) and how their responses are stored."""

    @pytest.mark.parametrize(
        "json_data",
        (
            param(
                {
                    "access_token": "test_access_token",
                    "refresh_token": "test_refresh_token",
                    "expires_in": 3600,
                },
                id="Full token response",
            ),
            param({"access_token": "test_access_token"}, id="Access token only"),
        ),
    )
    def test_get_token(self, base_client, http_session, db_session,
                       pyramid_request, json_data):
        http_session.send.return_value = _make_response(json_data)

        base_client.get_token("authorization_code")

        http_session.send.assert_called_once_with(
            AnyRequest(
                method="POST",
                url=Any.url().with_path("login/oauth2/token").with_query({
                    "grant_type": "authorization_code",
                    "client_id": "developer_key",
                    "client_secret": "developer_secret",
                    "redirect_uri": base_client._redirect_uri,
                    "code": "authorization_code",
                    "replace_tokens": "True",
                }),
            ),
            timeout=Any(),
        )
        self._assert_token_is_saved_in_db(db_session, pyramid_request,
                                          json_data)

    # Add noise with an existing token to make sure we update it
    @pytest.mark.usefixtures("access_token")
    @pytest.mark.parametrize(
        "json_data",
        (
            param(
                {
                    "access_token": "test_access_token",
                    "refresh_token": "test_refresh_token",
                    "expires_in": 3600,
                },
                id="Full token response",
            ),
            param({"access_token": "test_access_token"}, id="Access token only"),
        ),
    )
    def test__get_refreshed_token(self, base_client, http_session, db_session,
                                  pyramid_request, json_data):
        http_session.send.return_value = _make_response(json_data)

        base_client._get_refreshed_token("new_refresh_token")

        http_session.send.assert_called_once_with(
            AnyRequest(
                method="POST",
                url=Any.url().with_path("login/oauth2/token").with_query({
                    "grant_type": "refresh_token",
                    "client_id": "developer_key",
                    "client_secret": "developer_secret",
                    "refresh_token": "new_refresh_token",
                }),
            ),
            timeout=Any(),
        )

        # We use our own refresh token if none was passed back from Canvas.
        # Build the expectation as a new dict: the parametrized ``json_data``
        # is created once at class definition, so mutating it would leak
        # state into any later use of the same object.
        expected = {"refresh_token": "new_refresh_token", **json_data}
        self._assert_token_is_saved_in_db(db_session, pyramid_request,
                                          expected)

    @pytest.mark.parametrize(
        "json_data",
        (
            param({}, id="No access token"),
            param(
                {
                    "expires_in": -1,
                    "access_token": "irrelevant"
                },
                id="Negative expires in",
            ),
        ),
    )
    @pytest.mark.parametrize("method", ("get_token", "_get_refreshed_token"))
    def test_token_methods_raises_CanvasAPIServerError_for_bad_responses(
            self, http_session, base_client, method, json_data):
        http_session.send.return_value = _make_response(json_data)

        method = getattr(base_client, method)

        with pytest.raises(CanvasAPIServerError):
            method("token_value")

    def _assert_token_is_saved_in_db(self, db_session, pyramid_request,
                                     json_data):
        """Assert exactly one ``OAuth2Token`` row exists and matches ``json_data``.

        The row must belong to the request's LTI user. ``refresh_token`` and
        ``expires_in`` are compared with ``None`` when absent from ``json_data``.
        """
        oauth2_token = db_session.query(OAuth2Token).one()

        assert oauth2_token.user_id == pyramid_request.lti_user.user_id
        assert oauth2_token.consumer_key == pyramid_request.lti_user.oauth_consumer_key

        assert oauth2_token.access_token == json_data["access_token"]
        assert oauth2_token.refresh_token == json_data.get("refresh_token")
        assert oauth2_token.expires_in == json_data.get("expires_in")

    @pytest.fixture(autouse=True)
    def ai_getter(self, ai_getter):
        # Override the outer ``ai_getter`` fixture so requests carry known
        # developer credentials.
        ai_getter.developer_key.return_value = "developer_key"
        ai_getter.developer_secret.return_value = "developer_secret"

        return ai_getter