def testDuplicateTimestamps(self):
  collection = self._TestCollection(
      "aff4:/sequential_collection/testDuplicateTimestamps")
  t = rdfvalue.RDFDatetime.Now()
  with data_store.DB.GetMutationPool() as pool:
    for i in range(10):
      ts = collection.Add(
          rdfvalue.RDFInteger(i), timestamp=t, mutation_pool=pool)
      self.assertEqual(ts[0], t)
  i = 0
  for (ts, _) in collection.Scan():
    self.assertEqual(ts, t)
    i += 1
  self.assertEqual(i, 10)

def testAddScan(self):
  collection = self._TestCollection(
      "aff4:/sequential_collection/testAddScan")
  with data_store.DB.GetMutationPool() as pool:
    for i in range(100):
      collection.Add(rdfvalue.RDFInteger(i), mutation_pool=pool)
  i = 0
  last_ts = 0
  for (ts, v) in collection.Scan():
    last_ts = ts
    self.assertEqual(i, v)
    i += 1
  with data_store.DB.GetMutationPool() as pool:
    for j in range(100):
      collection.Add(rdfvalue.RDFInteger(j + 100), mutation_pool=pool)
  for (ts, v) in collection.Scan(after_timestamp=last_ts):
    self.assertEqual(i, v)
    i += 1
  self.assertEqual(i, 200)

def testListing(self):
  test_urn = "aff4:/sequential_collection/testIndexedListing"
  collection = self._TestCollection(test_urn)
  timestamps = []
  with data_store.DB.GetMutationPool() as pool:
    for i in range(100):
      timestamps.append(
          collection.Add(rdfvalue.RDFInteger(i), mutation_pool=pool))
  with test_lib.Instrument(sequential_collection.SequentialCollection,
                           "Scan") as scan:
    self.assertLen(list(collection), 100)
    # Listing should be done using a single scan but there is another one
    # for calculating the length.
    self.assertEqual(scan.call_count, 2)

def testStaticAddGet(self):
  aff4_path = "aff4:/sequential_collection/testStaticAddGet"
  collection = self._TestCollection(aff4_path)
  self.assertEqual(collection.CalculateLength(), 0)
  with data_store.DB.GetMutationPool() as pool:
    for i in range(100):
      TestIndexedSequentialCollection.StaticAdd(
          rdfvalue.RDFURN(aff4_path),
          rdfvalue.RDFInteger(i),
          mutation_pool=pool)
  for i in range(100):
    self.assertEqual(collection[i], i)
  self.assertEqual(collection.CalculateLength(), 100)
  self.assertLen(collection, 100)

def testAnyValueWithoutTypeCallback(self):
  test_pb = AnyValueWithoutTypeFunctionTest()
  for value_to_assign in [
      rdfvalue.RDFString("test"),
      rdfvalue.RDFInteger(1234),
      rdfvalue.RDFBytes(b"abc"),
      rdf_flows.GrrStatus(status="WORKER_STUCK", error_message="stuck")
  ]:
    test_pb.dynamic = value_to_assign
    serialized = test_pb.SerializeToString()
    self.assertEqual(
        AnyValueWithoutTypeFunctionTest.FromSerializedString(serialized),
        test_pb)

def testClaimFiltersByStartTime(self):
  queue_urn = "aff4:/queue_test/testClaimFiltersByStartTime"
  middle = None
  with self.pool:
    with aff4.FACTORY.Create(queue_urn, TestQueue, token=self.token) as queue:
      for i in range(100):
        if i == 50:
          middle = rdfvalue.RDFDatetime.Now()
        queue.Add(rdfvalue.RDFInteger(i), mutation_pool=self.pool)
  with aff4.FACTORY.OpenWithLock(
      queue_urn, lease_time=200, token=self.token) as queue:
    results = queue.ClaimRecords(start_time=middle)
  self.assertLen(results, 50)
  self.assertEqual(50, results[0].value)

def testClaimReturnsRecordsInOrder(self):
  queue_urn = "aff4:/queue_test/testClaimReturnsRecordsInOrder"
  with self.pool:
    with aff4.FACTORY.Create(queue_urn, TestQueue, token=self.token) as queue:
      for i in range(100):
        queue.Add(rdfvalue.RDFInteger(i), mutation_pool=self.pool)
  data_store.DB.Flush()
  with aff4.FACTORY.OpenWithLock(
      queue_urn, lease_time=200, token=self.token) as queue:
    results = queue.ClaimRecords()
  self.assertLen(results, 100)
  self.assertEqual(0, results[0].value)
  self.assertEqual(99, results[99].value)

def testLengthIsReportedCorrectlyForEveryType(self):
  with self.pool:
    for i in range(99):
      self.collection.Add(
          rdf_flows.GrrMessage(payload=rdfvalue.RDFInteger(i)),
          mutation_pool=self.pool)
    for i in range(101):
      self.collection.Add(
          rdf_flows.GrrMessage(payload=rdfvalue.RDFString(unicode(i))),
          mutation_pool=self.pool)
  self.assertEqual(99,
                   self.collection.LengthByType(rdfvalue.RDFInteger.__name__))
  self.assertEqual(101,
                   self.collection.LengthByType(rdfvalue.RDFString.__name__))

def testReceiveMessages(self):
  fs_server = fleetspeak_frontend_server.GRRFSServer()
  client_id = "C.1234567890123456"
  flow_id = "12345678"
  data_store.REL_DB.WriteClientMetadata(client_id, fleetspeak_enabled=True)

  rdf_flow = rdf_flow_objects.Flow(
      client_id=client_id,
      flow_id=flow_id,
      create_time=rdfvalue.RDFDatetime.Now())
  data_store.REL_DB.WriteFlowObject(rdf_flow)

  flow_request = rdf_flow_objects.FlowRequest(
      client_id=client_id, flow_id=flow_id, request_id=1)
  data_store.REL_DB.WriteFlowRequests([flow_request])

  session_id = "%s/%s" % (client_id, flow_id)
  fs_client_id = fleetspeak_utils.GRRIDToFleetspeakID(client_id)
  fs_messages = []
  for i in range(1, 10):
    grr_message = rdf_flows.GrrMessage(
        request_id=1,
        response_id=i + 1,
        session_id=session_id,
        payload=rdfvalue.RDFInteger(i))
    fs_message = fs_common_pb2.Message(
        message_type="GrrMessage",
        source=fs_common_pb2.Address(
            client_id=fs_client_id, service_name=FS_SERVICE_NAME))
    fs_message.data.Pack(grr_message.AsPrimitiveProto())
    fs_messages.append(fs_message)

  with test_lib.FakeTime(rdfvalue.RDFDatetime.FromSecondsSinceEpoch(123)):
    for fs_message in fs_messages:
      fs_server.Process(fs_message, None)

  # Ensure the last-ping timestamp gets updated.
  client_data = data_store.REL_DB.MultiReadClientMetadata([client_id])
  self.assertEqual(client_data[client_id].ping,
                   rdfvalue.RDFDatetime.FromSecondsSinceEpoch(123))

  flow_data = data_store.REL_DB.ReadAllFlowRequestsAndResponses(
      client_id, flow_id)
  self.assertLen(flow_data, 1)
  stored_flow_request, flow_responses = flow_data[0]
  self.assertEqual(stored_flow_request, flow_request)
  self.assertLen(flow_responses, 9)

def testExtractsTypesFromGrrMessage(self):
  with self.pool:
    self.collection.Add(
        rdf_flows.GrrMessage(payload=rdfvalue.RDFInteger(0)),
        mutation_pool=self.pool)
    self.collection.Add(
        rdf_flows.GrrMessage(payload=rdfvalue.RDFString("foo")),
        mutation_pool=self.pool)
    self.collection.Add(
        rdf_flows.GrrMessage(payload=rdfvalue.RDFURN("aff4:/foo/bar")),
        mutation_pool=self.pool)

  self.assertEqual(
      set([
          rdfvalue.RDFInteger.__name__, rdfvalue.RDFString.__name__,
          rdfvalue.RDFURN.__name__
      ]), set(self.collection.ListStoredTypes()))

def testValuesOfMultipleTypesCanBeIteratedPerType(self):
  with self.pool:
    for i in range(100):
      self.collection.Add(
          rdf_flows.GrrMessage(payload=rdfvalue.RDFInteger(i)),
          mutation_pool=self.pool)
      self.collection.Add(
          rdf_flows.GrrMessage(payload=rdfvalue.RDFString(unicode(i))),
          mutation_pool=self.pool)

  for index, (_, v) in enumerate(
      self.collection.ScanByType(rdfvalue.RDFInteger.__name__)):
    self.assertEqual(index, v.payload)

  for index, (_, v) in enumerate(
      self.collection.ScanByType(rdfvalue.RDFString.__name__)):
    self.assertEqual(str(index), v.payload)

def testMultiResolve(self):
  # Use a URN that matches this test's name (the original string looked
  # copy-pasted from testAddScan).
  collection = self._TestCollection(
      "aff4:/sequential_collection/testMultiResolve")
  records = []
  with data_store.DB.GetMutationPool() as pool:
    for i in range(100):
      ts, suffix = collection.Add(rdfvalue.RDFInteger(i), mutation_pool=pool)
      records.append(
          data_store.Record(
              queue_id=collection.collection_id,
              timestamp=ts,
              suffix=suffix,
              subpath="Results",
              value=None))
  even_results = sorted([r for r in collection.MultiResolve(records[::2])])
  self.assertLen(even_results, 50)
  self.assertEqual(even_results[0], 0)
  self.assertEqual(even_results[49], 98)

def testDeletingCollectionDeletesAllSubcollections(self):
  if not isinstance(data_store.DB, fake_data_store.FakeDataStore):
    self.skipTest("Only supported on FakeDataStore.")
  with self.pool:
    self.collection.Add(
        rdf_flows.GrrMessage(payload=rdfvalue.RDFInteger(0)),
        mutation_pool=self.pool)
    self.collection.Add(
        rdf_flows.GrrMessage(payload=rdfvalue.RDFString("foo")),
        mutation_pool=self.pool)
    self.collection.Add(
        rdf_flows.GrrMessage(payload=rdfvalue.RDFURN("aff4:/foo/bar")),
        mutation_pool=self.pool)
  self.collection.Delete()

  for urn in data_store.DB.subjects:
    self.assertFalse(utils.SmartStr(self.collection.collection_id) in urn)

def testClaimFiltersRecordsIfFilterIsSpecified(self):
  queue_urn = "aff4:/queue_test/testClaimFiltersRecordsIfFilterIsSpecified"
  with self.pool:
    with aff4.FACTORY.Create(queue_urn, TestQueue, token=self.token) as queue:
      for i in range(100):
        queue.Add(rdfvalue.RDFInteger(i), mutation_pool=self.pool)

  # Filters out all even records; records for which the filter returns True
  # are skipped by ClaimRecords.
  def EvenFilter(i):
    return int(i) % 2 == 0

  with aff4.FACTORY.OpenWithLock(
      queue_urn, lease_time=200, token=self.token) as queue:
    results = queue.ClaimRecords(record_filter=EvenFilter)

  # We should have claimed all the odd records.
  self.assertLen(results, 50)
  self.assertEqual(1, results[0].value)
  self.assertEqual(99, results[49].value)

def testWriteLastPingForNewClients(self):
  if not data_store.RelationalDBEnabled():
    self.skipTest("Rel-db-only test.")

  fs_server = fs_frontend_tool.GRRFSServer()
  client_id = "C.1234567890123456"
  flow_id = "12345678"
  session_id = "%s/%s" % (client_id, flow_id)
  fs_client_id = fleetspeak_utils.GRRIDToFleetspeakID(client_id)

  grr_message = rdf_flows.GrrMessage(
      request_id=1,
      response_id=1,
      session_id=session_id,
      payload=rdfvalue.RDFInteger(1))
  fs_message = fs_common_pb2.Message(
      message_type="GrrMessage",
      source=fs_common_pb2.Address(
          client_id=fs_client_id, service_name=FS_SERVICE_NAME))
  fs_message.data.Pack(grr_message.AsPrimitiveProto())
  fake_time = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(123)

  with mock.patch.object(
      events.Events, "PublishEvent",
      wraps=events.Events.PublishEvent) as publish_event_fn:
    with mock.patch.object(
        data_store.REL_DB,
        "WriteClientMetadata",
        wraps=data_store.REL_DB.WriteClientMetadata) as write_metadata_fn:
      with test_lib.FakeTime(fake_time):
        fs_server.Process(fs_message, None)
      self.assertEqual(write_metadata_fn.call_count, 1)
      client_data = data_store.REL_DB.MultiReadClientMetadata([client_id])
      self.assertEqual(client_data[client_id].ping, fake_time)

      # TODO(user): publish_event_fn.assert_any_call(
      # "ClientEnrollment", mock.ANY, token=mock.ANY) doesn't work here
      # for some reason.
      triggered_events = []
      for call_args, _ in publish_event_fn.call_args_list:
        if call_args:
          triggered_events.append(call_args[0])
      self.assertIn("ClientEnrollment", triggered_events)

def testClaimCleansSpuriousLocks(self):
  queue_urn = "aff4:/queue_test/testClaimCleansSpuriousLocks"
  with self.pool:
    with aff4.FACTORY.Create(queue_urn, TestQueue, token=self.token) as queue:
      for i in range(100):
        queue.Add(rdfvalue.RDFInteger(i), mutation_pool=self.pool)
  data_store.DB.Flush()

  with aff4.FACTORY.OpenWithLock(
      queue_urn, lease_time=200, token=self.token) as queue:
    results = queue.ClaimRecords()
  self.assertLen(results, 100)

  for record in results:
    subject, _, _ = data_store.DataStore.CollectionMakeURN(
        record.queue_id, record.timestamp, record.suffix, record.subpath)
    data_store.DB.DeleteAttributes(
        subject, [data_store.DataStore.COLLECTION_ATTRIBUTE], sync=True)
  data_store.DB.Flush()

  self.assertEqual(
      100,
      sum(1 for _ in data_store.DB.ScanAttribute(
          str(queue.urn.Add("Records")),
          data_store.DataStore.QUEUE_LOCK_ATTRIBUTE)))

  with aff4.FACTORY.OpenWithLock(
      queue_urn, lease_time=200, token=self.token) as queue:
    queue.ClaimRecords()
  data_store.DB.Flush()

  self.assertEqual(
      0,
      sum(1 for _ in data_store.DB.ScanAttribute(
          str(queue.urn.Add("Records")),
          data_store.DataStore.QUEUE_LOCK_ATTRIBUTE)))

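# What the test above relies on (as demonstrated by its assertions, not a
# spec): deleting a record's COLLECTION_ATTRIBUTE leaves its
# QUEUE_LOCK_ATTRIBUTE behind as a spurious lock. The next ClaimRecords pass
# is expected to notice that the underlying value is gone and drop the stale
# lock entries, which is why the final ScanAttribute count is 0.
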
def testReadAhead(self):
  """Read a chunk, and test that the next few are in cache."""
  urn = aff4.ROOT_URN.Add("temp_sparse_image.dd")
  fd = aff4.FACTORY.Create(
      urn,
      aff4_type=aff4_standard.AFF4SparseImage,
      token=self.token,
      mode="rw")
  fd.Set(fd.Schema._CHUNKSIZE, rdfvalue.RDFInteger(1024))
  fd.chunksize = 1024
  fd.Flush()

  start_chunk = 1000
  blob_hashes = []
  blobs = []
  num_chunks = 5
  for chunk in range(start_chunk, start_chunk + num_chunks):
    # Make sure the blobs have unique content.
    blob_contents = str(chunk % 10) * fd.chunksize
    blobs.append(blob_contents)
    blob_hash = self.AddBlobToBlobStore(blob_contents)
    fd.AddBlob(
        blob_hash=blob_hash, length=len(blob_contents), chunk_number=chunk)
    blob_hashes.append(blob_hash)

  self.assertEqual(fd.size, fd.chunksize * num_chunks)

  # Read the first chunk.
  fd.Seek(start_chunk * fd.chunksize)
  fd.Read(fd.chunksize)

  self.assertEqual(len(fd.chunk_cache._hash), num_chunks)
  fd.Flush()

  # They shouldn't be in cache anymore, so the chunk_cache should be empty.
  self.assertFalse(fd.chunk_cache._hash)

  # Make sure the contents of the file are what we put into it.
  fd.Seek(start_chunk * fd.chunksize)
  self.assertEqual(fd.Read(fd.chunksize * num_chunks), "".join(blobs))

def testIndexedReads(self):
  spacing = 10
  with utils.Stubber(sequential_collection.IndexedSequentialCollection,
                     "INDEX_SPACING", spacing):
    urn = "aff4:/sequential_collection/testIndexedReads"
    collection = self._TestCollection(urn)
    data_size = 4 * spacing
    # TODO(amoser): Without using a mutation pool, this test is really
    # slow on MySQL data store.
    with data_store.DB.GetMutationPool() as pool:
      for i in range(data_size):
        collection.StaticAdd(
            rdfvalue.RDFURN(urn), rdfvalue.RDFInteger(i), mutation_pool=pool)
    with test_lib.FakeTime(rdfvalue.RDFDatetime.Now() +
                           rdfvalue.Duration("10m")):
      for i in range(data_size - 1, data_size - 20, -1):
        self.assertEqual(collection[i], i)
      for i in [spacing - 1, spacing, spacing + 1]:
        self.assertEqual(collection[i], i)
      for i in range(data_size - spacing + 5, data_size - spacing - 5, -1):
        self.assertEqual(collection[i], i)

def testClaimReturnsPreviouslyClaimedRecordsAfterTimeout(self):
  queue_urn = (
      "aff4:/queue_test/testClaimReturnsPreviouslyClaimedRecordsAfterTimeout")
  with self.pool:
    with aff4.FACTORY.Create(queue_urn, TestQueue, token=self.token) as queue:
      for i in range(100):
        queue.Add(rdfvalue.RDFInteger(i), mutation_pool=self.pool)
  data_store.DB.Flush()

  with aff4.FACTORY.OpenWithLock(
      queue_urn, lease_time=200, token=self.token) as queue:
    results_1 = queue.ClaimRecords()
  self.assertLen(results_1, 100)

  with test_lib.FakeTime(rdfvalue.RDFDatetime.Now() +
                         rdfvalue.DurationSeconds("45m")):
    with aff4.FACTORY.OpenWithLock(
        queue_urn, lease_time=200, token=self.token) as queue:
      results_2 = queue.ClaimRecords()
    self.assertLen(results_2, 100)

def testAddChunk(self):
  """Makes sure we can add a chunk and modify it."""
  urn = aff4.ROOT_URN.Add("temp_sparse_image.dd")
  fd = aff4.FACTORY.Create(
      urn,
      aff4_type=aff4_standard.AFF4SparseImage,
      token=self.token,
      mode="rw")
  fd.Set(fd.Schema._CHUNKSIZE, rdfvalue.RDFInteger(1024))
  fd.chunksize = 1024
  fd.Flush()

  chunk_number = 0
  # 1024 characters.
  blob_contents = b"A" * fd.chunksize
  blob_hash = self.AddBlobToBlobStore(blob_contents)

  fd.AddBlob(
      blob_hash=blob_hash,
      length=len(blob_contents),
      chunk_number=chunk_number)
  fd.Flush()

  # Make sure the size is correct.
  self.assertEqual(fd.size, len(blob_contents))
  self.assertChunkEqual(fd, chunk_number, blob_contents)

  # Change the contents of the blob.
  blob_contents = b"B" * fd.chunksize
  blob_hash = self.AddBlobToBlobStore(blob_contents)

  # This time we're updating the blob.
  fd.AddBlob(blob_hash, len(blob_contents), chunk_number=chunk_number)
  # The size shouldn't get any bigger, since we got rid of the old blob.
  self.assertEqual(fd.size, len(blob_contents))
  self.assertChunkEqual(fd, chunk_number, blob_contents)

def testMessageHandlerRequestLeasing(self):
  requests = [
      rdf_objects.MessageHandlerRequest(
          client_id="C.1000000000000000",
          handler_name="Testhandler",
          request_id=i * 100,
          request=rdfvalue.RDFInteger(i)) for i in range(10)
  ]
  lease_time = rdfvalue.Duration("5m")

  with test_lib.FakeTime(rdfvalue.RDFDatetime.FromSecondsSinceEpoch(10000)):
    self.db.WriteMessageHandlerRequests(requests)

  leased = queue.Queue()
  self.db.RegisterMessageHandler(leased.put, lease_time, limit=5)

  got = []
  while len(got) < 10:
    try:
      l = leased.get(True, timeout=6)
    except queue.Empty:
      self.fail(
          "Timed out waiting for messages, expected 10, got %d" % len(got))
    self.assertLessEqual(len(l), 5)
    for m in l:
      self.assertEqual(m.leased_by, utils.ProcessIdString())
      self.assertGreater(m.leased_until, rdfvalue.RDFDatetime.Now())
      self.assertLess(m.timestamp, rdfvalue.RDFDatetime.Now())
      m.leased_by = None
      m.leased_until = None
      m.timestamp = None
    got += l
  self.db.DeleteMessageHandlerRequests(got)

  got.sort(key=lambda req: req.request_id)
  self.assertEqual(requests, got)

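# The leasing contract exercised above, as demonstrated by the assertions
# (not an authoritative spec): RegisterMessageHandler delivers requests in
# batches of at most `limit`; every delivered request carries leased_by set
# to the worker's utils.ProcessIdString() and leased_until set to a point in
# the future (write time + lease_time). Handled requests are removed with
# DeleteMessageHandlerRequests so they are not delivered again.
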
def testReadingAfterLastChunk(self):
  urn = aff4.ROOT_URN.Add("temp_sparse_image.dd")
  fd = aff4.FACTORY.Create(
      urn,
      aff4_type=aff4_standard.AFF4SparseImage,
      token=self.token,
      mode="rw")
  fd.Set(fd.Schema._CHUNKSIZE, rdfvalue.RDFInteger(1024))
  fd.chunksize = 1024
  fd.Flush()

  # We shouldn't be able to get any chunks yet.
  self.assertFalse(fd.Read(10000))

  start_chunk = 1000
  num_chunks = 5
  for chunk in range(start_chunk, start_chunk + num_chunks):
    # Make sure the blobs have unique content.
    blob_contents = str(chunk % 10) * fd.chunksize
    blob_hash = self.AddBlobToBlobStore(blob_contents)
    fd.AddBlob(
        blob_hash=blob_hash, length=len(blob_contents), chunk_number=chunk)

  # Make sure we can read the chunks we just wrote without error.
  fd.Seek(start_chunk * fd.chunksize)
  fd.Read(num_chunks * fd.chunksize)
  # Seek past the end of our chunks.
  fd.Seek((start_chunk + num_chunks) * fd.chunksize)
  # We should get the empty string back.
  self.assertEqual(fd.Read(10000), b"")

  # Seek to before our chunks start.
  fd.Seek((start_chunk - 1) * fd.chunksize)
  # There should be no chunk there and we should raise.
  with self.assertRaises(aff4.ChunkNotFoundError):
    fd.Read(fd.chunksize)

def testClaimReturnsPreviouslyReleasedRecords(self):
  queue_urn = "aff4:/queue_test/testClaimReturnsPreviouslyReleasedRecords"
  with self.pool:
    with aff4.FACTORY.Create(queue_urn, TestQueue, token=self.token) as queue:
      for i in range(100):
        queue.Add(rdfvalue.RDFInteger(i), mutation_pool=self.pool)
  data_store.DB.Flush()

  with aff4.FACTORY.OpenWithLock(
      queue_urn, lease_time=200, token=self.token) as queue:
    results = queue.ClaimRecords()
    odd_ids = [record for record in results if int(record.value) % 2 == 1]
    queue.ReleaseRecord(odd_ids[0], token=self.token)
    queue.ReleaseRecords(odd_ids[1:], token=self.token)

  with aff4.FACTORY.OpenWithLock(
      queue_urn, lease_time=200, token=self.token) as queue:
    odd_results = queue.ClaimRecords()
  self.assertLen(odd_results, 50)
  self.assertEqual(1, odd_results[0].value)
  self.assertEqual(99, odd_results[49].value)

def testReceiveMessages(self):
  """Tests receiving messages."""
  client_id = "C.1234567890123456"
  flow_id = "12345678"
  data_store.REL_DB.WriteClientMetadata(client_id, fleetspeak_enabled=False)
  _, req = self._FlowSetup(client_id, flow_id)

  session_id = "%s/%s" % (client_id, flow_id)
  messages = [
      rdf_flows.GrrMessage(
          request_id=1,
          response_id=i,
          session_id=session_id,
          auth_state="AUTHENTICATED",
          payload=rdfvalue.RDFInteger(i)) for i in range(1, 10)
  ]

  ReceiveMessages(client_id, messages)
  received = data_store.REL_DB.ReadAllFlowRequestsAndResponses(
      client_id, flow_id)
  self.assertLen(received, 1)
  self.assertEqual(received[0][0], req)
  self.assertLen(received[0][1], 9)

def InitFromKeyValue(self, key, value):
  self.key = key

  # Convert primitive types to rdf values so they can be serialized.
  # Note: the bool check must come before the numeric check, since bool is a
  # subclass of int and would otherwise be converted to RDFInteger.
  if isinstance(value, bool):
    value = rdfvalue.RDFBool(value)
  elif isinstance(value, float) and not value.is_integer():
    # TODO(user): Do not convert float values here and mark them invalid
    # later. ATM, we do not have means to properly represent floats. Change
    # this part once we have a RDFFloat implementation.
    pass
  elif rdfvalue.RDFInteger.IsNumeric(value):
    value = rdfvalue.RDFInteger(value)
  elif isinstance(value, basestring):
    value = rdfvalue.RDFString(value)

  if isinstance(value, rdfvalue.RDFValue):
    self.type = value.__class__.__name__
    self.value = value
  else:
    self.invalid = True

  return self

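# A minimal usage sketch of the mapping above (hypothetical host class and
# values, not part of the original code). Assuming a class Item that mixes in
# InitFromKeyValue:
#
#   Item().InitFromKeyValue("pid", 1234).type      # -> "RDFInteger"
#   Item().InitFromKeyValue("name", "bash").type   # -> "RDFString"
#   Item().InitFromKeyValue("ok", True).type       # -> "RDFBool"
#   Item().InitFromKeyValue("load", 0.75).invalid  # -> True (no RDFFloat yet)
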
def testWellKnownFlows(self):
  """Test the well known flows."""
  test_flow = self.FlowSetup(flow_test_lib.WellKnownSessionTest.__name__)

  # Make sure the session ID is well known.
  self.assertEqual(test_flow.session_id,
                   flow_test_lib.WellKnownSessionTest.well_known_session_id)

  # Messages to Well Known flows can be unauthenticated.
  messages = [
      rdf_flows.GrrMessage(payload=rdfvalue.RDFInteger(i)) for i in range(10)
  ]

  for message in messages:
    test_flow.ProcessMessage(message)

  # The messages might be processed in arbitrary order.
  test_flow.messages.sort()

  # Make sure that messages were processed even without a status
  # message to complete the transaction (Well known flows do not
  # have transactions or states - all messages always get to the
  # ProcessMessage method):
  self.assertEqual(test_flow.messages, list(range(10)))

def testMultipliesAndIsMultipliedByPrimitive(self):
  self.assertEqual(rdfvalue.RDFInteger(10) * 10, 100)
  self.assertEqual(10 * rdfvalue.RDFInteger(10), 100)

def testDividesAndIsDividableByPrimitiveInts(self):
  self.assertEqual(rdfvalue.RDFInteger(10) // 5, 2)
  # Reverse direction, as the method name promises (assumes RDFInteger
  # supports reflected floor division).
  self.assertEqual(100 // rdfvalue.RDFInteger(10), 10)

def testComparableToPrimitiveInts(self):
  self.assertEqual(rdfvalue.RDFInteger(10), 10)
  self.assertGreater(rdfvalue.RDFInteger(10), 5)
  self.assertGreater(15, rdfvalue.RDFInteger(10))
  self.assertLess(rdfvalue.RDFInteger(10), 15)
  self.assertLess(5, rdfvalue.RDFInteger(10))

def GenerateSample(self, number=0):
  return rdfvalue.RDFInteger(number)