def runTest(self):
        """Check that AWSSigV4Verifier rejects wrongly-typed arguments.

        Each string/bytes positional argument is replaced with None in turn,
        then a header with a non-string value and a header with a non-string
        name are tried; every one of these calls must raise TypeError.
        """
        # Positional order: request_method, uri_path, query_string, headers,
        # body, region, service, key_mapping.  The body must be bytes: a text
        # body is itself rejected with TypeError (see test_body), which on
        # Python 3 would make every check below pass regardless of which
        # argument was actually replaced.
        params = ["GET", "/", "", {}, b"", "", "", {}]

        for i, param in enumerate(params):
            # Only the string/bytes arguments are swapped for None; the dict
            # arguments are exercised by the explicit header checks below.
            if not isinstance(param, (string_types, bytes)):
                continue
            args = params[:i] + [None] + params[i + 1:]
            with self.assertRaises(TypeError):
                sigv4.AWSSigV4Verifier(*args)

        # Header values must be strings.
        with self.assertRaises(TypeError):
            sigv4.AWSSigV4Verifier("GET", "/", "", {"Host": 7}, b"", "",
                                   "", {})

        # Header names must be strings.
        with self.assertRaises(TypeError):
            sigv4.AWSSigV4Verifier("GET", "/", "", {0: "Foo"}, b"", "",
                                   "", {})
Example #2
0
    def test_headers(self):
        """A None headers mapping, a non-string value or a non-string name raise TypeError."""
        for bad_headers in (None, {"Host": 0}, {0: "Foo"}):
            with self.assertRaises(TypeError):
                sigv4.AWSSigV4Verifier(headers=bad_headers)
Example #3
0
    def test_timestamp_mismatch(self):
        """A string timestamp_mismatch is a TypeError; a negative one a ValueError."""
        self.assertRaises(TypeError, sigv4.AWSSigV4Verifier,
                          timestamp_mismatch="Hello")
        self.assertRaises(ValueError, sigv4.AWSSigV4Verifier,
                          timestamp_mismatch=-1)
Example #4
0
 def test_service(self):
     """Passing service=None must raise TypeError."""
     self.assertRaises(TypeError, sigv4.AWSSigV4Verifier, service=None)
Example #5
0
 def test_region(self):
     """Passing region=None must raise TypeError."""
     self.assertRaises(TypeError, sigv4.AWSSigV4Verifier, region=None)
Example #6
0
    def test_body(self):
        """A None body and a text (unicode) body are both rejected with TypeError."""
        for bad_body in (None, u"Hello"):
            with self.assertRaises(TypeError):
                sigv4.AWSSigV4Verifier(body=bad_body)
Example #7
0
 def test_query_string(self):
     """Passing query_string=None must raise TypeError."""
     self.assertRaises(TypeError, sigv4.AWSSigV4Verifier, query_string=None)
Example #8
0
 def test_uri_path(self):
     """Passing uri_path=None must raise TypeError."""
     self.assertRaises(TypeError, sigv4.AWSSigV4Verifier, uri_path=None)
Example #9
0
 def test_request_method(self):
     """Passing request_method=None must raise TypeError."""
     self.assertRaises(TypeError, sigv4.AWSSigV4Verifier,
                       request_method=None)
Example #10
0
    def verify(self,
               method,
               url,
               body,
               timestamp,
               headers,
               signed_headers,
               timestamp_mismatch=60,
               bad=False,
               scope=None,
               quote_chars=False,
               fix_qp=True):
        """Sign a request with query-string SigV4 auth and run the verifier on it.

        Independently builds the expected canonical request, string to sign
        and signature from the module-level ``access_key``, ``secret_key``,
        ``region`` and ``service``.  It then appends ``X-Amz-Signature`` to the
        query parameters, constructs ``sigv4.AWSSigV4Verifier`` from the
        result and calls its ``verify()``.

        :param method: HTTP request method.
        :param url: Request path, optionally followed by ``?`` and a query
            string.
        :param body: Request body.  It is passed straight to ``sha256``, so
            it must be bytes-like.
        :param timestamp: Value for ``X-Amz-Date``.  Its first 8 characters
            are used as the credential-scope date (presumably the
            ``YYYYMMDDTHHMMSSZ`` form -- confirm against callers).
        :param headers: Mapping of header name to an iterable of values.  The
            values of each signed header are joined with ``,``.
        :param signed_headers: Names of the headers to sign.  They are sorted
            for the canonical header block but joined in the given order for
            the SignedHeaders lists, so callers presumably pass them already
            sorted.
        :param timestamp_mismatch: Passed through to the verifier.
        :param bad: When true, skip comparing the verifier's canonical request
            and string to sign with the expected ones.  ``verify()`` is still
            called.
        :param scope: Value for ``X-Amz-Credential``.  Defaults to
            ``access_key + "/" + credential_scope``.
        :param quote_chars: When true, percent-encode every alphabetic
            character of every query parameter (names included) after
            signing.
        :param fix_qp: When true, drop empty query parameters and
            percent-encode value characters not in ``allowed_qp`` when
            building the canonical query string.  When false, ``/`` in the
            scope is replaced with ``%2F`` and the parameters are sorted
            unmodified.
        :raises AssertionError: If the verifier's canonical request or string
            to sign differs from the expected one (only when ``bad`` is
            false).  Any exception raised by ``verify()`` also propagates.
        """
        date = timestamp[:8]
        # Credential scope: <date>/<region>/<service>/aws4_request.
        credential_scope = "/".join([date, region, service, "aws4_request"])

        if scope is None:
            scope = access_key + "/" + credential_scope
        if "?" in url:
            uri, query_string = url.split("?", 1)
        else:
            uri = url
            query_string = ""

        if not fix_qp:
            scope = scope.replace("/", "%2F")

        # The expected canonical URI collapses runs of "/"; the verifier is
        # handed the raw ``uri`` below and must do the same normalization.
        normalized_uri = sub("//+", "/", uri)

        # Query-string auth parameters, followed by any caller-supplied ones.
        query_params = [
            "X-Amz-Algorithm=AWS4-HMAC-SHA256", "X-Amz-Credential=" + scope,
            "X-Amz-Date=" + timestamp,
            "X-Amz-SignedHeaders=" + ";".join(signed_headers)
        ]

        if query_string:
            query_params.extend(query_string.split("&"))

        def fixup_qp(qp):
            # Percent-encode (uppercase hex) every character of the value that
            # is not in allowed_qp; the key is left as-is.  Raises ValueError
            # if ``qp`` contains no "=".
            result = cStringIO()
            key, value = qp.split("=", 1)
            for c in value:
                if c in allowed_qp:
                    result.write(c)
                else:
                    result.write("%%%02X" % ord(c))

            return key + "=" + result.getvalue()

        if fix_qp:
            canonical_query_string = "&".join(
                sorted(map(fixup_qp, [qp for qp in query_params if qp])))
        else:
            canonical_query_string = "&".join(sorted(query_params))

        canonical_headers = "".join([
            (header + ":" + ",".join(headers[header]) + "\n")
            for header in sorted(signed_headers)
        ])

        canonical_req = (method + "\n" + normalized_uri + "\n" +
                         canonical_query_string + "\n" + canonical_headers +
                         "\n" + ";".join(signed_headers) + "\n" +
                         sha256(body).hexdigest())

        string_to_sign = ("AWS4-HMAC-SHA256\n" + timestamp + "\n" +
                          credential_scope + "\n" +
                          sha256(canonical_req.encode("utf-8")).hexdigest())

        def sign(secret, value):
            # One HMAC-SHA256 step of the signing-key derivation chain.
            return hmac.new(secret, value.encode("utf-8"), sha256).digest()

        # Signing key: HMAC chain over date, region, service, "aws4_request",
        # seeded with "AWS4" + secret key.
        k_date = sign(b"AWS4" + secret_key.encode("utf-8"), date)
        k_region = sign(k_date, region)
        k_service = sign(k_region, service)
        k_signing = sign(k_service, "aws4_request")
        signature = hmac.new(k_signing, string_to_sign.encode("utf-8"),
                             sha256).hexdigest()

        # Appended after signing, so the signature itself is not part of the
        # canonical query string computed above.
        query_params.append("X-Amz-Signature=" + signature)

        if quote_chars:
            # Percent-encode every letter of every parameter (keys included);
            # presumably exercises the verifier's percent-decoding.
            bad_qp = []

            for qp in query_params:
                result = cStringIO()

                for c in qp:
                    if c.isalpha():
                        result.write("%%%02X" % ord(c))
                    else:
                        result.write(c)

                bad_qp.append(result.getvalue())
            query_params = bad_qp

        v = sigv4.AWSSigV4Verifier(request_method=method,
                                   uri_path=uri,
                                   query_string="&".join(query_params),
                                   headers=headers,
                                   body=body,
                                   region=region,
                                   service=service,
                                   key_mapping=key_mapping,
                                   timestamp_mismatch=timestamp_mismatch)

        if not bad:
            self.assertEqual(v.canonical_request, canonical_req)
            self.assertEqual(v.string_to_sign, string_to_sign)
        v.verify()
        return
Example #11
0
    def run_sigv4_case(self, filebase, tweak=""):
        """Run one signature test case from the test-suite data files.

        Reads ``<basedir><filebase>.sreq`` (the signed request), ``.creq`` (the
        expected canonical request) and ``.sts`` (the expected string to sign).
        It then feeds the parsed request to ``sigv4.AWSSigV4Verifier`` and
        checks the result.

        :param filebase: Case path relative to ``self.basedir``, without
            extension.
        :param tweak: Optional corruption applied to the authorization or
            date headers (one of the module-level tweak constants).  When
            non-empty, ``verify()`` must raise InvalidSignatureError.
            Otherwise the canonical request and string to sign must match the
            expected files and ``verify()`` must succeed.
        :raises ValueError: If a header line in the ``.sreq`` file has no
            ``:``.
        """
        filebase = self.basedir + filebase

        with open(filebase + ".sreq", "rb") as fd:
            method_line = fd.readline().strip()
            if isinstance(method_line, binary_type):
                method_line = method_line.decode("utf-8")
            headers = {}

            # Lower-cased name of the most recent header; folded continuation
            # lines add further values to it.
            key = None

            while True:
                line = fd.readline()
                # A blank line (LF or CRLF) or EOF ends the header section.
                if line in (b"\n", b"\r\n", b""):
                    break

                line = line.decode("utf-8")
                if line.startswith((" ", "\t")):
                    # Folded continuation line: another value for ``key``.
                    assert key is not None
                    value = line.strip()
                else:
                    try:
                        header, value = line.split(":", 1)
                    except ValueError:
                        raise ValueError("Invalid header line: %s" % line)
                    key = header.lower()
                    value = value.strip()

                if key == "authorization":
                    if tweak == remove_auth:
                        continue
                    elif tweak == wrong_authtype:
                        value = "XX" + value
                    elif tweak == clobber_sig_equals:
                        value = value.replace("Signature=", "Signature")
                    elif tweak == delete_credential:
                        value = value.replace("Credential=", "Foo=")
                    elif tweak == delete_signature:
                        value = value.replace("Signature=", "Foo=")
                    elif tweak == dup_signature:
                        value += ", Signature=foo"
                elif key in ("date", "x-amz-date"):
                    if tweak == delete_date:
                        continue

                headers.setdefault(key, []).append(value)

            body = fd.read()

            # The request line is "METHOD URI VERSION"; splitting at the first
            # and last spaces keeps a URI that contains spaces intact.
            first_space = method_line.find(" ")
            last_space = method_line.rfind(" ")

            method = method_line[:first_space]
            uri_path = method_line[first_space + 1:last_space]

            qpos = uri_path.find("?")
            if qpos == -1:
                query_string = ""
            else:
                query_string = uri_path[qpos + 1:]
                uri_path = uri_path[:qpos]

        # Expected outputs are compared with carriage returns removed.
        with open(filebase + ".creq", "r") as fd:
            canonical_request = fd.read().replace("\r", "")

        with open(filebase + ".sts", "r") as fd:
            string_to_sign = fd.read().replace("\r", "")

        v = sigv4.AWSSigV4Verifier(request_method=method,
                                   uri_path=uri_path,
                                   query_string=query_string,
                                   headers=headers,
                                   body=body,
                                   region=region,
                                   service=service,
                                   key_mapping=key_mapping,
                                   timestamp_mismatch=None)

        if tweak:
            try:
                v.verify()
                self.fail("Expected verify() to throw an InvalidSignature "
                          "error for tweak %s" % tweak)
            except sigv4.InvalidSignatureError:
                pass
        else:
            self.assertEqual(
                v.canonical_request, canonical_request,
                "Canonical request mismatch in %s\nExpected: %r\nReceived: %r"
                % (filebase, canonical_request, v.canonical_request))
            self.assertEqual(
                v.string_to_sign, string_to_sign,
                "String to sign mismatch in %s\nExpected: %r\nReceived: %r" %
                (filebase, string_to_sign, v.string_to_sign))
            v.verify()
    def runTest(self):
        """Run the signature test case whose files are named by ``self.filebase``.

        Parses ``<filebase>.sreq`` (a signed request whose lines must end in
        CRLF) and optionally corrupts its authorization/date headers according
        to ``self.tweaks``.  It then feeds the request to
        ``sigv4.AWSSigV4Verifier``.  With no tweak, the verifier's canonical
        request and string to sign must match ``<filebase>.creq`` and
        ``<filebase>.sts``, and ``verify()`` must succeed.  With a tweak,
        ``verify()`` must raise InvalidSignatureError.

        :raises ValueError: If a header line in the ``.sreq`` file has no
            ``:``.
        """
        with open(self.filebase + ".sreq", "rb") as fd:
            method_line = fd.readline().strip()
            if isinstance(method_line, binary_type):
                method_line = method_line.decode("utf-8")
            headers = {}

            while True:
                line = fd.readline()
                # A blank CRLF line (or EOF) ends the header section.
                if line in (b"\r\n", b""):
                    break

                self.assertTrue(line.endswith(b"\r\n"))
                line = line.decode("utf-8")
                try:
                    header, value = line[:-2].split(":", 1)
                except ValueError:
                    raise ValueError("Invalid header line: %s" % line)
                key = header.lower()
                value = value.strip()

                if key == "authorization":
                    if self.tweaks == remove_auth:
                        continue
                    elif self.tweaks == wrong_authtype:
                        value = "XX" + value
                    elif self.tweaks == clobber_sig_equals:
                        value = value.replace("Signature=", "Signature")
                    elif self.tweaks == delete_credential:
                        value = value.replace("Credential=", "Foo=")
                    elif self.tweaks == delete_signature:
                        value = value.replace("Signature=", "Foo=")
                    elif self.tweaks == dup_signature:
                        value += ", Signature=foo"
                elif key in ("date", "x-amz-date"):
                    # Drop whichever date header the request carries.
                    if self.tweaks == delete_date:
                        continue

                headers.setdefault(key, []).append(value)

            # Repeated headers are merged into one comma-separated value, with
            # the individual values sorted.
            headers = {key: ",".join(sorted(values))
                       for key, values in iteritems(headers)}
            body = fd.read()

            # The request line is "METHOD URI VERSION"; splitting at the first
            # and last spaces keeps a URI that contains spaces intact.
            first_space = method_line.find(" ")
            last_space = method_line.rfind(" ")

            method = method_line[:first_space]
            uri_path = method_line[first_space + 1:last_space]

            qpos = uri_path.find("?")
            if qpos == -1:
                query_string = ""
            else:
                query_string = uri_path[qpos + 1:]
                uri_path = uri_path[:qpos]

        # Expected outputs are compared with carriage returns removed.
        with open(self.filebase + ".creq", "r") as fd:
            canonical_request = fd.read().replace("\r", "")

        with open(self.filebase + ".sts", "r") as fd:
            string_to_sign = fd.read().replace("\r", "")

        # Positional order: request_method, uri_path, query_string, headers,
        # body, region, service, key_mapping, timestamp_mismatch.
        v = sigv4.AWSSigV4Verifier(
            method, uri_path, query_string, headers, body, region, service,
            key_mapping, None)

        if self.tweaks:
            try:
                v.verify()
                self.fail("Expected verify() to throw an InvalidSignature "
                          "error for tweak %s" % self.tweaks)
            except sigv4.InvalidSignatureError:
                pass
        else:
            self.assertEqual(
                v.canonical_request, canonical_request,
                "Canonical request mismatch in %s\nExpected: %r\nReceived: %r" %
                (self.filebase, canonical_request, v.canonical_request))
            self.assertEqual(
                v.string_to_sign, string_to_sign,
                "String to sign mismatch in %s\nExpected: %r\nReceived: %r" %
                (self.filebase, string_to_sign, v.string_to_sign))
            v.verify()