def test_invokes_digest_validator(self):
    """The traverser passes each fetched digest to the configured validator.

    A single digest is served by a mocked provider. Pulling the first item
    from ``traverse`` must return that digest. The validator mock must have
    been called with the bucket, the key, the matching public key value, the
    digest dict, and the second element of the ``fetch_digest`` tuple.
    """
    window_end = END_DATE
    digest_key = window_end.strftime(DATE_FORMAT) + '.json.gz'
    window_start_stamp = (window_end - timedelta(hours=1)).strftime(
        DATE_FORMAT)
    expected_digest = {
        'digestPublicKeyFingerprint': 'a',
        'digestS3Bucket': '1',
        'digestS3Object': digest_key,
        'previousDigestSignature': '...',
        'digestStartTime': window_start_stamp,
        'digestEndTime': window_end.strftime(DATE_FORMAT),
    }

    provider = Mock()
    provider.load_digest_keys_in_range.return_value = [digest_key]
    provider.fetch_digest.return_value = (expected_digest, digest_key)

    known_keys = {'a': {'Fingerprint': 'a', 'Value': 'a'}}
    key_source = Mock()
    key_source.get_public_keys.return_value = known_keys

    validator = Mock()
    traverser = DigestTraverser(
        digest_provider=provider,
        starting_bucket='1',
        starting_prefix='baz',
        public_key_provider=key_source,
        digest_validator=validator,
    )

    first_digest = next(traverser.traverse(START_DATE, window_end))
    self.assertEqual(expected_digest, first_digest)
    validator.validate.assert_called_with(
        '1', digest_key, known_keys['a']['Value'], expected_digest,
        digest_key)
def test_invokes_cb_and_continues_when_missing(self):
    """A missing digest is reported via ``on_missing`` and traversal goes on.

    The scenario has four entries, one of them ``missing``. Three digests
    must still be collected, and the callback must fire exactly once.
    """
    start_date = START_DATE
    end_date = END_DATE
    key_provider, digest_provider, validator = create_scenario(
        ['gap', 'link', 'missing', 'link'])
    on_missing, missing_calls = collecting_callback()
    traverser = DigestTraverser(
        digest_provider=digest_provider, starting_bucket='1',
        starting_prefix='baz', public_key_provider=key_provider,
        digest_validator=validator, on_missing=on_missing)
    collected = list(traverser.traverse(start_date, end_date))
    self.assertEqual(3, len(collected))
    self.assertEqual(1, len(missing_calls))
    # Ensure it was invoked with all the kwargs we expected.
    self.assertIn('bucket', missing_calls[0])
    self.assertIn('next_end_date', missing_calls[0])
    # Ensure the keys were provided in the correct order.
    self.assertEqual(digest_provider.digests[1], missing_calls[0]['next_key'])
    self.assertEqual(digest_provider.digests[2], missing_calls[0]['last_key'])
    # Ensure the provider was called correctly.
    self.assertEqual(1, key_provider.get_public_keys.call_count)
    self.assertEqual(
        1, len(digest_provider.calls['load_digest_keys_in_range']))
    self.assertEqual(4, len(digest_provider.calls['fetch_digest']))
def test_ensures_public_key_is_found(self):
    """An unknown key fingerprint is reported through ``on_invalid``.

    The digest names fingerprint ``abc``, but the key provider only returns
    an entry for ``a``. The iterator must finish without yielding anything.
    The single callback message must mention the provider's
    ``trail_home_region``.
    """
    home_region = 'us-west-2'
    digest_key = END_DATE.strftime(DATE_FORMAT) + '.json.gz'
    parsed_digest = {
        'digestEndTime': 'foo',
        'digestStartTime': 'foo',
        'awsAccountId': 'account',
        'digestPublicKeyFingerprint': 'abc',
        'digestS3Bucket': '1',
        'digestS3Object': digest_key,
        'previousDigestSignature': 'xyz',
    }

    provider = Mock()
    provider.trail_home_region = home_region
    provider.load_digest_keys_in_range.return_value = [digest_key]
    provider.fetch_digest.return_value = (parsed_digest, 'abc')

    key_source = Mock()
    key_source.get_public_keys.return_value = [{'Fingerprint': 'a'}]

    on_invalid, invalid_calls = collecting_callback()
    pending = DigestTraverser(
        digest_provider=provider,
        starting_bucket='1',
        starting_prefix='baz',
        public_key_provider=key_source,
        on_invalid=on_invalid,
    ).traverse(START_DATE, END_DATE)

    with self.assertRaises(StopIteration):
        next(pending)

    expected_message = ('Digest file\ts3://1/%s\tINVALID: public key not '
                        'found in region %s for fingerprint abc'
                        % (digest_key, home_region))
    self.assertEqual(1, len(invalid_calls))
    self.assertEqual(expected_message, invalid_calls[0]['message'])
def test_invokes_cb_and_continues_when_invalid(self):
    """Invalid digests are reported via ``on_invalid`` and traversal goes on.

    The scenario has two ``invalid`` entries. Three digests must still be
    collected and the callback must fire twice. The first callback refers to
    the higher-indexed scenario digests.
    """
    start_date = START_DATE
    end_date = END_DATE
    key_provider, digest_provider, validator = create_scenario(
        ['gap', 'link', 'invalid', 'link', 'invalid'])
    on_invalid, invalid_calls = collecting_callback()
    traverser = DigestTraverser(
        digest_provider=digest_provider, starting_bucket='1',
        starting_prefix='baz', public_key_provider=key_provider,
        digest_validator=validator, on_invalid=on_invalid)
    collected = list(traverser.traverse(start_date, end_date))
    self.assertEqual(3, len(collected))
    self.assertEqual(2, len(invalid_calls))
    # Ensure it was invoked with all the kwargs we expected.
    self.assertIn('bucket', invalid_calls[0])
    self.assertIn('next_end_date', invalid_calls[0])
    # Ensure the keys were provided in the correct order.
    self.assertEqual(digest_provider.digests[4], invalid_calls[0]['last_key'])
    self.assertEqual(digest_provider.digests[3], invalid_calls[0]['next_key'])
    self.assertEqual(digest_provider.digests[2], invalid_calls[1]['last_key'])
    self.assertEqual(digest_provider.digests[1], invalid_calls[1]['next_key'])
    # Ensure the provider was called correctly: keys are loaded once, the
    # listing happens once, and each of the five scenario entries is fetched.
    self.assertEqual(1, key_provider.get_public_keys.call_count)
    self.assertEqual(
        1, len(digest_provider.calls['load_digest_keys_in_range']))
    self.assertEqual(5, len(digest_provider.calls['fetch_digest']))
def test_does_not_hard_fail_on_invalid_signature(self):
    """A malformed signature is reported via ``on_invalid`` instead of raising.

    The digest carries ``_signature`` ``"123"``. The real
    ``Sha256RSADigestValidator`` is expected to reject it with an
    "Incorrect padding" message delivered through the callback.
    """
    start_date = START_DATE
    end_date = END_DATE
    end_timestamp = end_date.strftime(DATE_FORMAT) + ".json.gz"
    digest = {
        "digestPublicKeyFingerprint": "a",
        "digestS3Bucket": "1",
        "digestS3Object": end_timestamp,
        "previousDigestSignature": "...",
        "digestStartTime": (end_date - timedelta(hours=1)).strftime(DATE_FORMAT),
        # A formatted date, like the other digest fixtures -- not the S3 key.
        "digestEndTime": end_date.strftime(DATE_FORMAT),
        "_signature": "123",
    }
    digest_provider = Mock()
    digest_provider.load_digest_keys_in_range.return_value = [end_timestamp]
    digest_provider.fetch_digest.return_value = (digest, end_timestamp)
    key_provider = Mock()
    public_keys = {"a": {"Fingerprint": "a", "Value": "a"}}
    key_provider.get_public_keys.return_value = public_keys
    digest_validator = Sha256RSADigestValidator()
    on_invalid, calls = collecting_callback()
    traverser = DigestTraverser(
        digest_provider=digest_provider,
        starting_bucket="1",
        starting_prefix="baz",
        public_key_provider=key_provider,
        digest_validator=digest_validator,
        on_invalid=on_invalid,
    )
    digest_iter = traverser.traverse(start_date, end_date)
    # Must not raise; the default swallows an immediately exhausted iterator.
    next(digest_iter, None)
    # Check the callback fired before indexing, for a clear failure message.
    self.assertEqual(1, len(calls))
    self.assertEqual(
        "Digest file\ts3://1/%s\tINVALID: Incorrect padding" % end_timestamp,
        calls[0]["message"],
    )
def test_does_not_hard_fail_on_invalid_signature(self):
    """A malformed signature is reported via ``on_invalid`` instead of raising.

    The digest carries ``_signature`` ``'123'``. The real
    ``Sha256RSADigestValidator`` is expected to reject it with an
    "Incorrect padding" message delivered through the callback.
    """
    start_date = START_DATE
    end_date = END_DATE
    end_timestamp = end_date.strftime(DATE_FORMAT) + '.json.gz'
    digest = {'digestPublicKeyFingerprint': 'a',
              'digestS3Bucket': '1',
              'digestS3Object': end_timestamp,
              'previousDigestSignature': '...',
              'digestStartTime': (end_date - timedelta(hours=1)).strftime(
                  DATE_FORMAT),
              # A formatted date, like the other fixtures -- not the S3 key.
              'digestEndTime': end_date.strftime(DATE_FORMAT),
              '_signature': '123'}
    digest_provider = Mock()
    digest_provider.load_digest_keys_in_range.return_value = [
        end_timestamp]
    digest_provider.fetch_digest.return_value = (digest, end_timestamp)
    key_provider = Mock()
    public_keys = {'a': {'Fingerprint': 'a', 'Value': 'a'}}
    key_provider.get_public_keys.return_value = public_keys
    digest_validator = Sha256RSADigestValidator()
    on_invalid, calls = collecting_callback()
    traverser = DigestTraverser(
        digest_provider=digest_provider, starting_bucket='1',
        starting_prefix='baz', public_key_provider=key_provider,
        digest_validator=digest_validator, on_invalid=on_invalid)
    digest_iter = traverser.traverse(start_date, end_date)
    # Must not raise; the default swallows an immediately exhausted iterator.
    next(digest_iter, None)
    # Check the callback fired before indexing, for a clear failure message.
    self.assertEqual(1, len(calls))
    self.assertEqual(
        'Digest file\ts3://1/%s\tINVALID: Incorrect padding' % end_timestamp,
        calls[0]['message'])
def test_invokes_cb_and_continues_when_gap(self):
    """Gaps in the digest chain are reported via ``on_gap`` without stopping.

    All four scenario digests must be collected, and the gap callback must
    fire twice with the full set of keyword arguments.
    """
    start_date = START_DATE
    end_date = END_DATE
    key_provider, digest_provider, validator = create_scenario(["gap", "link", "gap", "gap"])
    on_gap, gap_calls = collecting_callback()
    traverser = DigestTraverser(
        digest_provider=digest_provider,
        starting_bucket="1",
        starting_prefix="baz",
        public_key_provider=key_provider,
        digest_validator=validator,
        on_gap=on_gap,
    )
    collected = list(traverser.traverse(start_date, end_date))
    self.assertEqual(4, len(collected))
    self.assertEqual(2, len(gap_calls))
    # Ensure it was invoked with all the kwargs we expected.
    self.assertIn("bucket", gap_calls[0])
    self.assertIn("next_key", gap_calls[0])
    self.assertIn("next_end_date", gap_calls[0])
    self.assertIn("last_key", gap_calls[0])
    self.assertIn("last_start_date", gap_calls[0])
    # Ensure the keys were provided in the correct order.
    self.assertEqual(digest_provider.digests[3], gap_calls[0]["last_key"])
    self.assertEqual(digest_provider.digests[2], gap_calls[0]["next_key"])
    self.assertEqual(digest_provider.digests[2], gap_calls[1]["last_key"])
    self.assertEqual(digest_provider.digests[1], gap_calls[1]["next_key"])
    # Ensure the provider was called correctly.
    self.assertEqual(1, key_provider.get_public_keys.call_count)
    self.assertEqual(1, len(digest_provider.calls["load_digest_keys_in_range"]))
    self.assertEqual(4, len(digest_provider.calls["fetch_digest"]))
def test_ensures_digest_from_same_location_as_json_contents(self):
    """A digest whose JSON disagrees with where it was listed is rejected.

    The object is listed under starting bucket ``"1"``, but its JSON names
    ``"not_same"``. Nothing may be yielded, and ``on_invalid`` must receive a
    single "invalid format" message.
    """
    # NOTE(review): the fixture also omits fields other fixtures carry (e.g.
    # digestStartTime); which traverser check emits the message is not
    # visible here.
    on_invalid, invalid_calls = collecting_callback()
    object_key = END_DATE.strftime(DATE_FORMAT) + ".json.gz"
    misplaced_digest = {
        "digestPublicKeyFingerprint": "a",
        "digestS3Bucket": "not_same",
        "digestS3Object": object_key,
        "digestEndTime": END_DATE.strftime(DATE_FORMAT),
    }

    provider = Mock()
    provider.load_digest_keys_in_range.return_value = [object_key]
    provider.fetch_digest.return_value = (misplaced_digest, object_key)

    traverser = DigestTraverser(
        digest_provider=provider,
        starting_bucket="1",
        starting_prefix="baz",
        public_key_provider=Mock(),
        digest_validator=Mock(),
        on_invalid=on_invalid,
    )
    first_item = next(traverser.traverse(START_DATE, END_DATE), None)

    self.assertIsNone(first_item)
    self.assertEqual(1, len(invalid_calls))
    self.assertEqual(
        "Digest file\ts3://1/%s\tINVALID: invalid format" % object_key,
        invalid_calls[0]["message"],
    )
def test_does_not_hard_fail_on_invalid_signature(self):
    """A malformed signature is reported via ``on_invalid`` instead of raising.

    The digest carries ``_signature`` ``'123'``. The real
    ``Sha256RSADigestValidator`` is expected to reject it with an
    "Incorrect padding" message delivered through the callback.
    """
    start_date = START_DATE
    end_date = END_DATE
    digest = {'digestPublicKeyFingerprint': 'a',
              'digestS3Bucket': '1',
              'digestS3Object': 'abc',
              'previousDigestSignature': '...',
              'digestStartTime': (end_date - timedelta(hours=1)).strftime(
                  DATE_FORMAT),
              'digestEndTime': end_date.strftime(DATE_FORMAT),
              '_signature': '123'}
    digest_provider = Mock()
    digest_provider.load_digest_keys_in_range.return_value = ['abc']
    digest_provider.fetch_digest.return_value = (digest, 'abc')
    key_provider = Mock()
    public_keys = {'a': {'Fingerprint': 'a', 'Value': 'a'}}
    key_provider.get_public_keys.return_value = public_keys
    digest_validator = Sha256RSADigestValidator()
    on_invalid, calls = collecting_callback()
    traverser = DigestTraverser(
        digest_provider=digest_provider, starting_bucket='1',
        starting_prefix='baz', public_key_provider=key_provider,
        digest_validator=digest_validator, on_invalid=on_invalid)
    digest_iter = traverser.traverse(start_date, end_date)
    # Must not raise; the default swallows an immediately exhausted iterator.
    next(digest_iter, None)
    # Check the callback fired before indexing, for a clear failure message.
    self.assertEqual(1, len(calls))
    self.assertEqual('Digest file\ts3://1/abc\tINVALID: Incorrect padding',
                     calls[0]['message'])
def test_invokes_digest_validator(self):
    """Each fetched digest is passed to the injected validator.

    The provider serves one digest. The first value produced by ``traverse``
    must be that digest. ``validate`` must have been called with the bucket,
    the key, the public key value for fingerprint ``"a"``, the digest, and
    the second element of the ``fetch_digest`` tuple.
    """
    stop = END_DATE
    object_key = "%s.json.gz" % stop.strftime(DATE_FORMAT)
    served = {
        "digestPublicKeyFingerprint": "a",
        "digestS3Bucket": "1",
        "digestS3Object": object_key,
        "previousDigestSignature": "...",
        "digestStartTime": (stop - timedelta(hours=1)).strftime(DATE_FORMAT),
        "digestEndTime": stop.strftime(DATE_FORMAT),
    }

    keys_by_fingerprint = {"a": {"Fingerprint": "a", "Value": "a"}}
    key_source = Mock()
    key_source.get_public_keys.return_value = keys_by_fingerprint

    provider = Mock()
    provider.fetch_digest.return_value = (served, object_key)
    provider.load_digest_keys_in_range.return_value = [object_key]

    checker = Mock()
    walker = DigestTraverser(
        digest_provider=provider,
        starting_bucket="1",
        starting_prefix="baz",
        public_key_provider=key_source,
        digest_validator=checker,
    )
    produced = walker.traverse(START_DATE, stop)

    self.assertEqual(served, next(produced))
    expected_key_value = keys_by_fingerprint["a"]["Value"]
    checker.validate.assert_called_with(
        "1", object_key, expected_key_value, served, object_key)
def test_ensures_digest_from_same_location_as_json_contents(self):
    """A digest whose JSON names a different bucket is not yielded.

    The key is listed under bucket ``'1'``, but the digest body claims
    ``'not_same'``. Traversal must yield nothing and report exactly one
    "invalid format" message through ``on_invalid``.
    """
    # NOTE(review): the fixture also omits fields other fixtures carry (e.g.
    # digestStartTime); which traverser check emits the message is not
    # visible here.
    report, reports = collecting_callback()
    listed_key = END_DATE.strftime(DATE_FORMAT) + '.json.gz'
    body = dict(
        digestPublicKeyFingerprint='a',
        digestS3Bucket='not_same',
        digestS3Object=listed_key,
        digestEndTime=END_DATE.strftime(DATE_FORMAT),
    )

    source = Mock()
    source.fetch_digest.return_value = (body, listed_key)
    source.load_digest_keys_in_range.return_value = [listed_key]

    walker = DigestTraverser(digest_provider=source,
                             starting_bucket='1',
                             starting_prefix='baz',
                             public_key_provider=Mock(),
                             digest_validator=Mock(),
                             on_invalid=report)

    self.assertIsNone(next(walker.traverse(START_DATE, END_DATE), None))
    self.assertEqual(1, len(reports))
    expected = 'Digest file\ts3://1/%s\tINVALID: invalid format' % listed_key
    self.assertEqual(expected, reports[0]['message'])
def test_invokes_digest_validator(self):
    """The digest fetched for key ``'abc'`` is handed to the validator.

    The first item from ``traverse`` must equal the fixture digest. The
    validator mock must have been called with the bucket, the key, the public
    key value, the digest, and the raw ``'abc'`` payload returned by
    ``fetch_digest``.
    """
    finish = END_DATE
    fixture = {
        'digestPublicKeyFingerprint': 'a',
        'digestS3Bucket': '1',
        'digestS3Object': 'abc',
        'previousDigestSignature': '...',
        'digestStartTime': (finish - timedelta(hours=1)).strftime(
            DATE_FORMAT),
        'digestEndTime': finish.strftime(DATE_FORMAT),
    }
    keys = {'a': {'Fingerprint': 'a', 'Value': 'a'}}

    source = Mock()
    source.load_digest_keys_in_range.return_value = ['abc']
    source.fetch_digest.return_value = (fixture, 'abc')
    key_lookup = Mock()
    key_lookup.get_public_keys.return_value = keys
    validator_spy = Mock()

    items = DigestTraverser(
        digest_provider=source,
        starting_bucket='1',
        starting_prefix='baz',
        public_key_provider=key_lookup,
        digest_validator=validator_spy,
    ).traverse(START_DATE, finish)

    self.assertEqual(fixture, next(items))
    validator_spy.validate.assert_called_with(
        '1', 'abc', keys['a']['Value'], fixture, 'abc')
def test_ensures_public_key_is_found(self):
    """A digest signed with an unknown fingerprint is reported as invalid.

    Only fingerprint ``'a'`` is available, while the digest references
    ``'abc'``. The first ``next`` must raise ``StopIteration``, and exactly
    one ``on_invalid`` message must be recorded.
    """
    body = dict(
        digestEndTime='foo',
        digestStartTime='foo',
        awsAccountId='account',
        digestPublicKeyFingerprint='abc',
        digestS3Bucket='1',
        digestS3Object='abc',
        previousDigestSignature='xyz',
    )
    source = Mock()
    source.fetch_digest.return_value = (body, 'abc')
    source.load_digest_keys_in_range.return_value = ['abc']

    key_lookup = Mock()
    key_lookup.get_public_keys.return_value = [{'Fingerprint': 'a'}]

    report, reports = collecting_callback()
    walker = DigestTraverser(
        digest_provider=source,
        starting_bucket='1',
        starting_prefix='baz',
        public_key_provider=key_lookup,
        on_invalid=report,
    )
    items = walker.traverse(START_DATE, END_DATE)

    with self.assertRaises(StopIteration):
        next(items)
    self.assertEqual(1, len(reports))
    self.assertEqual(
        ('Digest file\ts3://1/abc\tINVALID: public key not '
         'found for fingerprint abc'),
        reports[0]['message'])
def test_ensures_public_keys_are_loaded(self):
    """Traversal fails with ``RuntimeError`` when no public keys are returned.

    The error is expected on the first ``next`` call, not from ``traverse``
    itself. The key provider must have been queried with the requested date
    range.
    """
    key_lookup = Mock()
    key_lookup.get_public_keys.return_value = []
    walker = DigestTraverser(digest_provider=Mock(),
                             starting_bucket='1',
                             starting_prefix='baz',
                             public_key_provider=key_lookup)

    items = walker.traverse(START_DATE, END_DATE)
    with self.assertRaises(RuntimeError):
        next(items)

    key_lookup.get_public_keys.assert_called_with(START_DATE, END_DATE)
def test_ensures_public_keys_are_loaded(self):
    """An empty public-key set makes the first ``next`` raise ``RuntimeError``.

    Also checks that the key provider was asked for keys covering the
    traversal's start and end dates.
    """
    first_day, last_day = START_DATE, END_DATE
    empty_keys = Mock()
    empty_keys.get_public_keys.return_value = []

    pending = DigestTraverser(
        digest_provider=Mock(),
        starting_bucket='1',
        starting_prefix='baz',
        public_key_provider=empty_keys,
    ).traverse(first_day, last_day)

    self.assertRaises(RuntimeError, next, pending)
    empty_keys.get_public_keys.assert_called_with(first_day, last_day)
def test_loads_digests_in_range(self):
    """Every digest of a fully linked chain in the window is collected.

    Public keys and the digest key listing must each be loaded once, and all
    four scenario digests must be fetched and yielded.
    """
    start_date = START_DATE
    end_date = START_DATE + timedelta(hours=5)
    key_provider, digest_provider, validator = create_scenario(
        ['gap', 'link', 'link', 'link'])
    traverser = DigestTraverser(
        digest_provider=digest_provider, starting_bucket='1',
        starting_prefix='baz', public_key_provider=key_provider,
        digest_validator=validator)
    collected = list(traverser.traverse(start_date, end_date))
    self.assertEqual(1, key_provider.get_public_keys.call_count)
    self.assertEqual(
        1, len(digest_provider.calls['load_digest_keys_in_range']))
    self.assertEqual(4, len(digest_provider.calls['fetch_digest']))
    self.assertEqual(4, len(collected))
def test_loads_digests_in_range(self):
    """Every digest of a fully linked chain in the window is collected.

    Public keys and the digest key listing must each be loaded once, and all
    four scenario digests must be fetched and yielded.
    """
    start_date = START_DATE
    end_date = START_DATE + timedelta(hours=5)
    key_provider, digest_provider, validator = create_scenario(["gap", "link", "link", "link"])
    traverser = DigestTraverser(
        digest_provider=digest_provider,
        starting_bucket="1",
        starting_prefix="baz",
        public_key_provider=key_provider,
        digest_validator=validator,
    )
    collected = list(traverser.traverse(start_date, end_date))
    self.assertEqual(1, key_provider.get_public_keys.call_count)
    self.assertEqual(1, len(digest_provider.calls["load_digest_keys_in_range"]))
    self.assertEqual(4, len(digest_provider.calls["fetch_digest"]))
    self.assertEqual(4, len(collected))
def test_reloads_objects_on_bucket_change(self):
    """Digest keys are listed again when the chain moves to another bucket.

    The ``bucket_change`` scenario must yield two digests from bucket ``"1"``
    and two from bucket ``"2"``. That requires two key listings but only one
    public-key load.
    """
    start_date = START_DATE
    end_date = END_DATE
    key_provider, digest_provider, validator = create_scenario(["gap", "link", "bucket_change", "link"])
    traverser = DigestTraverser(
        digest_provider=digest_provider,
        starting_bucket="1",
        starting_prefix="baz",
        public_key_provider=key_provider,
        digest_validator=validator,
    )
    collected = list(traverser.traverse(start_date, end_date))
    self.assertEqual(4, len(collected))
    # Ensure the provider was called correctly.
    self.assertEqual(1, key_provider.get_public_keys.call_count)
    self.assertEqual(2, len(digest_provider.calls["load_digest_keys_in_range"]))
    self.assertEqual(["1", "1", "2", "2"], [c["digestS3Bucket"] for c in collected])
def test_reloads_objects_on_bucket_change(self):
    """Digest keys are listed again when the chain moves to another bucket.

    The ``bucket_change`` scenario must yield two digests from bucket ``'1'``
    and two from bucket ``'2'``. That requires two key listings but only one
    public-key load.
    """
    start_date = START_DATE
    end_date = END_DATE
    key_provider, digest_provider, validator = create_scenario(
        ['gap', 'link', 'bucket_change', 'link'])
    traverser = DigestTraverser(
        digest_provider=digest_provider, starting_bucket='1',
        starting_prefix='baz', public_key_provider=key_provider,
        digest_validator=validator)
    collected = list(traverser.traverse(start_date, end_date))
    self.assertEqual(4, len(collected))
    # Ensure the provider was called correctly.
    self.assertEqual(1, key_provider.get_public_keys.call_count)
    self.assertEqual(
        2, len(digest_provider.calls['load_digest_keys_in_range']))
    self.assertEqual(['1', '1', '2', '2'],
                     [c['digestS3Bucket'] for c in collected])
def test_ensures_digest_from_same_location_as_json_contents(self):
    """A digest listed as ``'abc'`` but claiming another bucket is rejected.

    Traversal must yield nothing. ``on_invalid`` must be called once with the
    "invalid format" message for ``s3://1/abc``.
    """
    # NOTE(review): the fixture also omits fields other fixtures carry (e.g.
    # digestStartTime); which traverser check emits the message is not
    # visible here.
    report, reports = collecting_callback()
    mismatched = dict(
        digestPublicKeyFingerprint='a',
        digestS3Bucket='not_same',
        digestS3Object='abc',
        digestEndTime=END_DATE.strftime(DATE_FORMAT),
    )
    source = Mock()
    source.fetch_digest.return_value = (mismatched, 'abc')
    source.load_digest_keys_in_range.return_value = ['abc']

    items = DigestTraverser(
        digest_provider=source,
        starting_bucket='1',
        starting_prefix='baz',
        public_key_provider=Mock(),
        digest_validator=Mock(),
        on_invalid=report,
    ).traverse(START_DATE, END_DATE)

    self.assertIsNone(next(items, None))
    self.assertEqual(1, len(reports))
    self.assertEqual('Digest file\ts3://1/abc\tINVALID: invalid format',
                     reports[0]['message'])
def test_ensures_public_key_is_found(self):
    """An unknown key fingerprint ends traversal with one invalid report.

    The digest references fingerprint ``"abc"``, which the key provider does
    not return. No digest may be yielded. The recorded message must name the
    listed key and the missing fingerprint.
    """
    listed_key = END_DATE.strftime(DATE_FORMAT) + ".json.gz"
    body = dict(
        digestEndTime="foo",
        digestStartTime="foo",
        awsAccountId="account",
        digestPublicKeyFingerprint="abc",
        digestS3Bucket="1",
        digestS3Object=listed_key,
        previousDigestSignature="xyz",
    )

    source = Mock()
    source.fetch_digest.return_value = (body, "abc")
    source.load_digest_keys_in_range.return_value = [listed_key]

    key_lookup = Mock()
    key_lookup.get_public_keys.return_value = [{"Fingerprint": "a"}]

    report, reports = collecting_callback()
    walker = DigestTraverser(
        digest_provider=source,
        starting_bucket="1",
        starting_prefix="baz",
        public_key_provider=key_lookup,
        on_invalid=report,
    )

    items = walker.traverse(START_DATE, END_DATE)
    with self.assertRaises(StopIteration):
        next(items)

    expected = ("Digest file\ts3://1/%s\tINVALID: public key not "
                "found for fingerprint abc" % listed_key)
    self.assertEqual(1, len(reports))
    self.assertEqual(expected, reports[0]["message"])