Ejemplo n.º 1
0
    def _WriteRequestAndCompleteResponses(self,
                                          client_id,
                                          flow_id,
                                          request_id,
                                          num_responses,
                                          task_id=None):
        """Writes a request, then its responses and status one at a time.

        Before each individual write the stored request is re-read and checked
        to still not need processing. Once everything is written it must need
        processing and expect exactly as many responses as were written.

        Args:
          client_id: Client the request belongs to.
          flow_id: Flow the request belongs to.
          request_id: Id of the request to write.
          num_responses: Number of responses passed to _ResponsesAndStatus.
          task_id: If truthy, stamped onto every response before writing it.

        Returns:
          The number of flow processing requests in the db afterwards.
        """
        self.db.WriteFlowRequests([
            rdf_flow_objects.FlowRequest(client_id=client_id,
                                         flow_id=flow_id,
                                         request_id=request_id)
        ])

        # <num_responses> responses plus a status, written in a random order.
        pending = self._ResponsesAndStatus(client_id, flow_id, request_id,
                                           num_responses)
        random.shuffle(pending)

        for item in pending:
            stored = self._ReadRequest(client_id, flow_id, request_id)
            self.assertIsNotNone(stored)
            # Must stay False until the very last message has been written.
            self.assertFalse(stored.needs_processing)

            if task_id:
                item.task_id = task_id
            self.db.WriteFlowResponses([item])

        # With every message stored, the request is ready to be processed.
        stored = self._ReadRequest(client_id, flow_id, request_id)
        self.assertTrue(stored.needs_processing)
        self.assertEqual(stored.nr_responses_expected, len(pending))

        # Completing the request may have produced a flow processing request.
        return len(self.db.ReadFlowProcessingRequests())
Ejemplo n.º 2
0
    def CallState(self,
                  next_state: str = "",
                  start_time: Optional[rdfvalue.RDFDatetime] = None) -> None:
        """Schedules a new state of this flow to run on a different worker.

        This is basically the same as CallFlow() except we are calling
        ourselves. The state will be invoked at a later time. A FlowRequest for
        it, created with needs_processing=True, is appended to
        self.flow_requests.

        Args:
          next_state: The state in this flow to be invoked. Must name an
            existing (truthy) attribute of this flow.
          start_time: Start the flow at this time. This delays notification for
            flow processing into the future. Note that the flow may still be
            processed earlier if there are client responses waiting.

        Raises:
          ValueError: The next state specified is empty or does not exist.
        """
        # getattr() needs a default here. Without one, a missing state surfaces
        # as AttributeError rather than the documented ValueError; an empty
        # name (the default) can never refer to a state.
        if not next_state or not getattr(self, next_state, None):
            raise ValueError("Next state %s is invalid." % next_state)

        flow_request = rdf_flow_objects.FlowRequest(
            client_id=self.rdf_flow.client_id,
            flow_id=self.rdf_flow.flow_id,
            request_id=self.GetNextOutboundId(),
            next_state=next_state,
            start_time=start_time,
            needs_processing=True)

        self.flow_requests.append(flow_request)
Ejemplo n.º 3
0
    def testResponseWriting(self):
        """Responses written for a stored request are read back keyed by id."""
        client_id, flow_id = self._SetupClientAndFlow()

        request = rdf_flow_objects.FlowRequest(client_id=client_id,
                                               flow_id=flow_id,
                                               request_id=1,
                                               needs_processing=False)
        self.db.WriteFlowRequests([request])

        self.db.WriteFlowResponses([
            rdf_flow_objects.FlowResponse(client_id=client_id,
                                          flow_id=flow_id,
                                          request_id=1,
                                          response_id=response_id)
            for response_id in range(3)
        ])

        stored = self.db.ReadAllFlowRequestsAndResponses(client_id, flow_id)
        self.assertEqual(len(stored), 1)

        stored_request, stored_responses = stored[0]
        self.assertEqual(stored_request, request)
        # Keys are the response ids, in ascending order.
        self.assertEqual(list(stored_responses), [0, 1, 2])

        for key, value in iteritems(stored_responses):
            self.assertEqual(value.response_id, key)
Ejemplo n.º 4
0
    def testResponsesForUnknownRequest(self):
        """Only responses addressed to an existing request are stored."""
        client_id, flow_id = self._SetupClientAndFlow()

        request = rdf_flow_objects.FlowRequest(client_id=client_id,
                                               flow_id=flow_id,
                                               request_id=1)
        self.db.WriteFlowRequests([request])

        # Write two responses at a time: request #1 exists, request #2 doesn't.
        with test_lib.SuppressLogs():
            self.db.WriteFlowResponses([
                rdf_flow_objects.FlowResponse(client_id=client_id,
                                              flow_id=flow_id,
                                              request_id=1,
                                              response_id=1),
                rdf_flow_objects.FlowResponse(client_id=client_id,
                                              flow_id=flow_id,
                                              request_id=2,
                                              response_id=1)
            ])

        # We should have one request with exactly one response in the db.
        read = self.db.ReadAllFlowRequestsAndResponses(client_id, flow_id)
        self.assertEqual(len(read), 1)
        read_request, read_responses = read[0]
        self.assertEqual(read_request.request_id, 1)
        self.assertEqual(len(read_responses), 1)
        # Both responses share response_id=1, so check request_id as well:
        # the survivor must be the one addressed to the existing request.
        self.assertEqual(list(read_responses), [1])
        self.assertEqual(read_responses[1].request_id, 1)
Ejemplo n.º 5
0
    def testDeleteFlowRequests(self):
        """Deleting requests one at a time removes exactly those requests."""
        client_id, flow_id = self._SetupClientAndFlow()

        request_ids = range(1, 4)
        requests = [
            rdf_flow_objects.FlowRequest(client_id=client_id,
                                         flow_id=flow_id,
                                         request_id=rid) for rid in request_ids
        ]
        responses = [
            rdf_flow_objects.FlowResponse(client_id=client_id,
                                          flow_id=flow_id,
                                          request_id=rid,
                                          response_id=1) for rid in request_ids
        ]

        self.db.WriteFlowRequests(requests)
        self.db.WriteFlowResponses(responses)

        def _AssertStoredIdsMatch(expected):
            """Checks the db holds exactly the request ids of `expected`."""
            stored = self.db.ReadAllFlowRequestsAndResponses(client_id, flow_id)
            self.assertItemsEqual([req.request_id for req, _ in stored],
                                  [req.request_id for req in expected])

        _AssertStoredIdsMatch(requests)

        # Delete in a random order, re-checking after every deletion.
        random.shuffle(requests)
        while requests:
            self.db.DeleteFlowRequests([requests.pop()])
            _AssertStoredIdsMatch(requests)
Ejemplo n.º 6
0
    def testRequestWriting(self):
        """Requests need an existing flow, and are read back per flow."""
        client_ids = [u"C.1234567890123456", u"C.1234567890123457"]
        flow_ids = [u"1234ABCD", u"ABCD1234"]

        # Writing a request for a flow that was never created must fail.
        with self.assertRaises(db.UnknownFlowError):
            self.db.WriteFlowRequests([
                rdf_flow_objects.FlowRequest(client_id=client_ids[0],
                                             flow_id=flow_ids[0])
            ])
        for client_id in client_ids:
            self.db.WriteClientMetadata(client_id, fleetspeak_enabled=False)

        requests = []
        for flow_id in flow_ids:
            for client_id in client_ids:
                self.db.WriteFlowObject(
                    rdf_flow_objects.Flow(client_id=client_id, flow_id=flow_id))
                requests.extend(
                    rdf_flow_objects.FlowRequest(client_id=client_id,
                                                 flow_id=flow_id,
                                                 request_id=request_id)
                    for request_id in range(1, 4))

        self.db.WriteFlowRequests(requests)

        # Reading from a flow that doesn't exist must fail as well.
        with self.assertRaises(db.UnknownFlowError):
            self.db.ReadAllFlowRequestsAndResponses(client_id=client_ids[0],
                                                    flow_id=u"11111111")

        for flow_id in flow_ids:
            for client_id in client_ids:
                read = self.db.ReadAllFlowRequestsAndResponses(
                    client_id=client_id, flow_id=flow_id)

                self.assertEqual(len(read), 3)
                self.assertEqual([req.request_id for req, _ in read],
                                 [1, 2, 3])
                for _, responses in read:
                    self.assertEqual(responses, [])
Ejemplo n.º 7
0
    def testReceiveMessageList(self):
        """A packed MessageList from Fleetspeak ends up as flow responses.

        Also checks that processing refreshes the client's ping time and
        stores the validation info attached to the Fleetspeak message.
        """
        fs_server = fleetspeak_frontend_server.GRRFSServer()
        client_id = "C.1234567890123456"
        flow_id = "12345678"
        data_store.REL_DB.WriteClientMetadata(client_id,
                                              fleetspeak_enabled=True)
        data_store.REL_DB.WriteFlowObject(
            rdf_flow_objects.Flow(client_id=client_id,
                                  flow_id=flow_id,
                                  create_time=rdfvalue.RDFDatetime.Now()))

        flow_request = rdf_flow_objects.FlowRequest(client_id=client_id,
                                                    flow_id=flow_id,
                                                    request_id=1)
        data_store.REL_DB.WriteFlowRequests([flow_request])

        session_id = "%s/%s" % (client_id, flow_id)
        fs_client_id = fleetspeak_utils.GRRIDToFleetspeakID(client_id)

        # Nine responses to request #1 (payloads 1..9), bundled into one list.
        grr_messages = [
            rdf_flows.GrrMessage(request_id=1,
                                 response_id=payload + 1,
                                 session_id=session_id,
                                 payload=rdfvalue.RDFInteger(payload))
            for payload in range(1, 10)
        ]
        packed_messages = rdf_flows.PackedMessageList()
        communicator.Communicator.EncodeMessageList(
            rdf_flows.MessageList(job=grr_messages), packed_messages)

        fs_message = fs_common_pb2.Message(
            message_type="MessageList",
            source=fs_common_pb2.Address(client_id=fs_client_id,
                                         service_name=FS_SERVICE_NAME))
        fs_message.data.Pack(packed_messages.AsPrimitiveProto())
        fs_message.validation_info.tags["foo"] = "bar"

        with test_lib.FakeTime(
                rdfvalue.RDFDatetime.FromSecondsSinceEpoch(123)):
            fs_server.Process(fs_message, None)

        # The last-ping timestamp must now equal the faked time...
        metadata = data_store.REL_DB.MultiReadClientMetadata([client_id])
        self.assertEqual(metadata[client_id].ping,
                         rdfvalue.RDFDatetime.FromSecondsSinceEpoch(123))
        # ...and the message's validation info must have been recorded.
        self.assertEqual(
            metadata[client_id].last_fleetspeak_validation_info.ToStringDict(),
            {"foo": "bar"})

        flow_data = data_store.REL_DB.ReadAllFlowRequestsAndResponses(
            client_id, flow_id)
        self.assertLen(flow_data, 1)
        stored_request, stored_responses = flow_data[0]
        self.assertEqual(stored_request, flow_request)
        self.assertLen(stored_responses, 9)
Ejemplo n.º 8
0
    def _WriteRequestForProcessing(self, client_id, flow_id, request_id):
        """Writes one request flagged needs_processing=True.

        While the request is written, the delegate db's
        WriteFlowProcessingRequests is patched out.

        Returns:
          How many times the patched WriteFlowProcessingRequests was called.
        """
        request = rdf_flow_objects.FlowRequest(flow_id=flow_id,
                                               client_id=client_id,
                                               request_id=request_id,
                                               needs_processing=True)
        with mock.patch.object(self.db.delegate,
                               "WriteFlowProcessingRequests") as patched:
            self.db.WriteFlowRequests([request])
        return patched.call_count
Ejemplo n.º 9
0
  def CallFlow(self,
               flow_name=None,
               next_state=None,
               request_data=None,
               client_id=None,
               base_session_id=None,
               **kwargs):
    """Creates a new flow and sends its responses to a state.

    This creates a new flow. The flow may send back many responses which will be
    queued by the framework until the flow terminates. The final status message
    will cause the entire transaction to be committed to the specified state.

    A FlowRequest for `next_state` is appended to self.flow_requests. The child
    flow is started on this flow's client, with this flow as its parent.

    Args:
       flow_name: The name of the flow to invoke (looked up in the
         FlowRegistry).
       next_state: The state in this flow, that responses to this message should
         go to.
       request_data: Any dict provided here will be available in the
         RequestState protobuf. The Responses object maintains a reference to
         this protobuf for use in the execution of the state method. (so you can
         access this data by responses.request). There is no format mandated on
         this data but it may be a serialized protobuf.
       client_id: Unused. The child flow is always started on this flow's own
         client (self.rdf_flow.client_id). Kept for interface compatibility.
       base_session_id: Unused. Kept for interface compatibility.
       **kwargs: Arguments for the child flow.

    Returns:
       The flow_id of the child flow which was created, as returned by
       flow.StartFlow().

    Raises:
      ValueError: The requested next state is empty or does not exist.
    """
    # getattr() needs a default here. Without one, a missing state raises
    # AttributeError and next_state=None raises TypeError, instead of the
    # documented ValueError.
    if not next_state or not getattr(self, next_state, None):
      raise ValueError("Next state %s is invalid." % next_state)

    flow_request = rdf_flow_objects.FlowRequest(
        client_id=self.rdf_flow.client_id,
        flow_id=self.rdf_flow.flow_id,
        request_id=self.GetNextOutboundId(),
        next_state=next_state)

    if request_data is not None:
      flow_request.request_data = rdf_protodict.Dict().FromDict(request_data)

    self.flow_requests.append(flow_request)

    flow_cls = registry.FlowRegistry.FlowClassByName(flow_name)

    # Hand the child's flow_id back to the caller, as documented above.
    return flow.StartFlow(
        client_id=self.rdf_flow.client_id,
        flow_cls=flow_cls,
        parent_flow_obj=self,
        **kwargs)
Ejemplo n.º 10
0
    def testReturnProcessedFlow(self):
        """ReturnProcessedFlow is False once the next needed request is ready."""
        client_id, flow_id = self._SetupClientAndFlow(
            next_request_to_process=1)

        processing_time = rdfvalue.Duration("60s")

        def _ReadyRequest(request_id):
            """Builds a request for this flow flagged as needing processing."""
            return rdf_flow_objects.FlowRequest(client_id=client_id,
                                                flow_id=flow_id,
                                                request_id=request_id,
                                                needs_processing=True)

        flow_in_progress = self.db.ReadFlowForProcessing(
            client_id, flow_id, processing_time)
        # Pretend request #1 has been processed.
        flow_in_progress.next_request_to_process = 2

        # Requests #1 and #4 are ready, but not #2 (the one needed next).
        self.db.WriteFlowRequests([_ReadyRequest(1), _ReadyRequest(4)])
        self.assertTrue(self.db.ReturnProcessedFlow(flow_in_progress))

        flow_in_progress = self.db.ReadFlowForProcessing(
            client_id, flow_id, processing_time)
        # Pretend request #2 has been processed as well.
        flow_in_progress.next_request_to_process = 3

        # Meanwhile request #3, the one needed next, became ready.
        self.db.WriteFlowRequests([_ReadyRequest(3)])
        self.assertFalse(self.db.ReturnProcessedFlow(flow_in_progress))
Ejemplo n.º 11
0
    def testReceiveMessages_Relational(self):
        """Single GrrMessages from Fleetspeak are stored as flow responses.

        Also checks that the client's ping time is set to the (faked) time of
        processing. Skipped unless the relational db is enabled.
        """
        if not data_store.RelationalDBEnabled():
            self.skipTest("Rel-db-only test.")

        fs_server = fs_frontend_tool.GRRFSServer()
        client_id = "C.1234567890123456"
        flow_id = "12345678"
        data_store.REL_DB.WriteClientMetadata(client_id,
                                              fleetspeak_enabled=True)
        data_store.REL_DB.WriteFlowObject(
            rdf_flow_objects.Flow(client_id=client_id,
                                  flow_id=flow_id,
                                  create_time=rdfvalue.RDFDatetime.Now()))

        flow_request = rdf_flow_objects.FlowRequest(client_id=client_id,
                                                    flow_id=flow_id,
                                                    request_id=1)
        data_store.REL_DB.WriteFlowRequests([flow_request])

        session_id = "%s/%s" % (client_id, flow_id)
        fs_client_id = fleetspeak_utils.GRRIDToFleetspeakID(client_id)

        def _WrapForFleetspeak(grr_message):
            """Packs one GrrMessage into a Fleetspeak message from the client."""
            wrapped = fs_common_pb2.Message(
                message_type="GrrMessage",
                source=fs_common_pb2.Address(client_id=fs_client_id,
                                             service_name=FS_SERVICE_NAME))
            wrapped.data.Pack(grr_message.AsPrimitiveProto())
            return wrapped

        # Nine responses to request #1, each in its own Fleetspeak message.
        fs_messages = [
            _WrapForFleetspeak(
                rdf_flows.GrrMessage(request_id=1,
                                     response_id=payload + 1,
                                     session_id=session_id,
                                     payload=rdfvalue.RDFInteger(payload)))
            for payload in range(1, 10)
        ]

        with test_lib.FakeTime(
                rdfvalue.RDFDatetime.FromSecondsSinceEpoch(123)):
            for fs_message in fs_messages:
                fs_server.Process(fs_message, None)

        # The last-ping timestamp must now equal the faked time.
        metadata = data_store.REL_DB.MultiReadClientMetadata([client_id])
        self.assertEqual(metadata[client_id].ping,
                         rdfvalue.RDFDatetime.FromSecondsSinceEpoch(123))

        flow_data = data_store.REL_DB.ReadAllFlowRequestsAndResponses(
            client_id, flow_id)
        self.assertLen(flow_data, 1)
        stored_request, stored_responses = flow_data[0]
        self.assertEqual(stored_request, flow_request)
        self.assertLen(stored_responses, 9)
Ejemplo n.º 12
0
  def _FlowSetup(self, client_id, flow_id):
    """Stores an OnlineNotification flow plus a single request (id 1).

    Returns:
      A (flow, request) tuple holding the objects that were written.
    """
    flow_obj = rdf_flow_objects.Flow(
        flow_class_name=compatibility.GetName(
            administrative.OnlineNotification),
        client_id=client_id,
        flow_id=flow_id,
        create_time=rdfvalue.RDFDatetime.Now())
    flow_request = rdf_flow_objects.FlowRequest(
        client_id=client_id, flow_id=flow_id, request_id=1)

    data_store.REL_DB.WriteFlowObject(flow_obj)
    data_store.REL_DB.WriteFlowRequests([flow_request])
    return flow_obj, flow_request
Ejemplo n.º 13
0
    def testPathSpecCasingIsCorrected(self):
        """Dumped region paths are replied with the casing from StatEntries."""
        dump_flow = memory.DumpProcessMemory(rdf_flow_objects.Flow())
        dump_flow.SendReply = mock.Mock(spec=dump_flow.SendReply)

        def _DumpResponse(*paths):
            """One dumped process with a 1-byte region per temp path."""
            return rdf_memory.YaraProcessDumpResponse(dumped_processes=[
                rdf_memory.YaraProcessDumpInformation(memory_regions=[
                    rdf_memory.ProcessMemoryRegion(
                        start=1,
                        size=1,
                        file=rdf_paths.PathSpec.Temp(path=path))
                    for path in paths
                ])
            ])

        # The request remembers paths in inconsistent casing...
        request = rdf_flow_objects.FlowRequest(
            request_data={
                "YaraProcessDumpResponse":
                _DumpResponse("/C:/grr/x_1_0_1.tmp", "/C:/GRR/x_1_1_2.tmp")
            })
        # ...while the stat responses report the actual casing.
        stat_paths = ["/C:/Grr/x_1_0_1.tmp", "/C:/Grr/x_1_1_2.tmp"]
        responses = flow_responses.Responses.FromResponses(
            request, [
                rdf_flow_objects.FlowResponse(payload=rdf_client_fs.StatEntry(
                    pathspec=rdf_paths.PathSpec.Temp(path=path)))
                for path in stat_paths
            ])

        dump_flow.ProcessMemoryRegions(responses)
        dump_flow.SendReply.assert_any_call(
            _DumpResponse("/C:/Grr/x_1_0_1.tmp", "/C:/Grr/x_1_1_2.tmp"))
Ejemplo n.º 14
0
    def _WriteRequestAndResponses(self, client_id, flow_id):
        """Stores a flow with requests 1-3, each having responses 1-2.

        Every object is written with its own db call.
        """
        self.db.WriteFlowObject(
            rdf_flow_objects.Flow(client_id=client_id, flow_id=flow_id))

        for req_id in range(1, 4):
            self.db.WriteFlowRequests([
                rdf_flow_objects.FlowRequest(client_id=client_id,
                                             flow_id=flow_id,
                                             request_id=req_id)
            ])
            for resp_id in (1, 2):
                self.db.WriteFlowResponses([
                    rdf_flow_objects.FlowResponse(client_id=client_id,
                                                  flow_id=flow_id,
                                                  request_id=req_id,
                                                  response_id=resp_id)
                ])
Ejemplo n.º 15
0
    def testDrainTaskSchedulerQueue(self):
        """Draining yields one GrrMessage per stored client action request."""
        client_id = u"C.1234567890123456"
        flow_id = flow.RandomFlowId()
        data_store.REL_DB.WriteClientMetadata(client_id,
                                              fleetspeak_enabled=False)
        data_store.REL_DB.WriteFlowObject(
            rdf_flow_objects.Flow(client_id=client_id,
                                  flow_id=flow_id,
                                  create_time=rdfvalue.RDFDatetime.Now()))

        request_ids = range(3)
        for request_id in request_ids:
            data_store.REL_DB.WriteFlowRequests([
                rdf_flow_objects.FlowRequest(client_id=client_id,
                                             flow_id=flow_id,
                                             request_id=request_id)
            ])
        action_requests = [
            rdf_flows.ClientActionRequest(client_id=client_id,
                                          flow_id=flow_id,
                                          request_id=request_id,
                                          action_identifier="WmiQuery")
            for request_id in request_ids
        ]

        data_store.REL_DB.WriteClientActionRequests(action_requests)
        server = TestServer()

        drained = server.DrainTaskSchedulerQueueForClient(
            rdfvalue.RDFURN(client_id))
        expected = [
            rdf_flow_objects.GRRMessageFromClientActionRequest(action_request)
            for action_request in action_requests
        ]
        # Zero out task ids on both sides so only the other fields are compared.
        for message in drained:
            message.task_id = 0
        for message in expected:
            message.task_id = 0

        self.assertItemsEqual(drained, expected)
Ejemplo n.º 16
0
    def testReadFlowRequestsReadyForProcessing(self):
        """Only the contiguous run starting at next_needed_request is read."""
        unset_client_id = u"C.1234567890000000"
        unset_flow_id = u"12344321"

        # Before any client/flow has been set up, nothing is returned.
        self.assertEqual(
            self.db.ReadFlowRequestsReadyForProcessing(
                unset_client_id, unset_flow_id, next_needed_request=1), {})

        client_id, flow_id = self._SetupClientAndFlow(
            next_request_to_process=3)

        for request_id in (1, 3, 4, 5, 7):
            self.db.WriteFlowRequests([
                rdf_flow_objects.FlowRequest(client_id=client_id,
                                             flow_id=flow_id,
                                             request_id=request_id,
                                             needs_processing=True)
            ])

        # Request #4 additionally gets three responses.
        responses = [
            rdf_flow_objects.FlowResponse(client_id=client_id,
                                          flow_id=flow_id,
                                          request_id=4,
                                          response_id=response_id)
            for response_id in range(3)
        ]
        self.db.WriteFlowResponses(responses)

        ready = self.db.ReadFlowRequestsReadyForProcessing(
            client_id, flow_id, next_needed_request=3)

        # #1 is older than next_needed_request and #7 is unreachable because
        # #6 is missing, which leaves exactly #3, #4 and #5.
        self.assertEqual(len(ready), 3)
        self.assertEqual(list(ready), [3, 4, 5])

        for request_id in ready:
            self.assertEqual(request_id, ready[request_id][0].request_id)

        self.assertEqual(ready[4][1], responses)
Ejemplo n.º 17
0
    def testStatusMessagesCanBeWrittenAndRead(self):
        """Responses, iterators and statuses round-trip with their types."""
        client_id, flow_id = self._SetupClientAndFlow()

        self.db.WriteFlowRequests([
            rdf_flow_objects.FlowRequest(client_id=client_id,
                                         flow_id=flow_id,
                                         request_id=1,
                                         needs_processing=False)
        ])

        def _Message(message_cls, response_id):
            """Builds a message of `message_cls` addressed to request #1."""
            return message_cls(client_id=client_id,
                               flow_id=flow_id,
                               request_id=1,
                               response_id=response_id)

        messages = [
            _Message(rdf_flow_objects.FlowResponse, response_id)
            for response_id in range(3)
        ]
        # An iterator and a status are stored alongside the plain responses.
        messages.append(_Message(rdf_flow_objects.FlowIterator, 3))
        messages.append(_Message(rdf_flow_objects.FlowStatus, 4))
        self.db.WriteFlowResponses(messages)

        stored = self.db.ReadAllFlowRequestsAndResponses(client_id, flow_id)
        self.assertEqual(len(stored), 1)

        _, read_responses = stored[0]
        self.assertEqual(list(read_responses), [0, 1, 2, 3, 4])

        expected_types = [rdf_flow_objects.FlowResponse] * 3 + [
            rdf_flow_objects.FlowIterator, rdf_flow_objects.FlowStatus
        ]
        for response_id, expected_type in enumerate(expected_types):
            self.assertIsInstance(read_responses[response_id], expected_type)
Ejemplo n.º 18
0
    def CallClient(self,
                   action_cls: Type[server_stubs.ClientActionStub],
                   request: Optional[rdfvalue.RDFValue] = None,
                   next_state: Optional[str] = None,
                   callback_state: Optional[str] = None,
                   request_data: Optional[Mapping[str, Any]] = None,
                   **kwargs: Any):
        """Calls the client asynchronously.

        This sends a message to the client to invoke an Action. The run action
        may send back many responses that will be queued by the framework until
        a status message is sent by the client. The status message will cause
        the entire transaction to be committed to the specified state.

        Nothing is sent from here directly. A FlowRequest is appended to
        self.flow_requests and a ClientActionRequest to
        self.client_action_requests. The ClientActionRequest carries whatever
        is left of this flow's CPU, network and runtime budgets.

        Args:
          action_cls: The function to call on the client.
          request: The request to send to the client. If not specified, we
            create a new RDFValue using the kwargs.
          next_state: The state in this flow, that responses to this message
            should go to.
          callback_state: (optional) The state to call whenever a new response
            is arriving.
          request_data: A dict which will be available in the RequestState
            protobuf. The Responses object maintains a reference to this
            protobuf for use in the execution of the state method. (so you can
            access this data by responses.request).
          **kwargs: These args will be used to construct the client action
            argument rdfvalue.

        Raises:
          ValueError: action_cls has no entry in
            action_registry.ID_BY_ACTION_STUB, args were given to an action
            that takes none, or the request does not have the correct type.
          flow.FlowResourcesExceededError: The flow's CPU, network or runtime
            budget is already used up.
        """
        try:
            action_identifier = action_registry.ID_BY_ACTION_STUB[action_cls]
        except KeyError:
            # The KeyError adds nothing beyond this message, so don't chain it.
            raise ValueError(
                "Action class %s not known." % action_cls) from None

        if action_cls.in_rdfvalue is None:
            if request:
                raise ValueError("Client action %s does not expect args." %
                                 action_cls)
        else:
            if request is None:
                # Create a new rdf request.
                request = action_cls.in_rdfvalue(**kwargs)
            else:
                # Verify that the request type matches the client action requirements.
                if not isinstance(request, action_cls.in_rdfvalue):
                    raise ValueError("Client action expected %s but got %s" %
                                     (action_cls.in_rdfvalue, type(request)))

        # Note: the outbound id is allocated before the budget checks below, so
        # it is consumed even if one of them raises.
        outbound_id = self.GetNextOutboundId()

        # Create a flow request.
        flow_request = rdf_flow_objects.FlowRequest(
            client_id=self.rdf_flow.client_id,
            flow_id=self.rdf_flow.flow_id,
            request_id=outbound_id,
            next_state=next_state,
            callback_state=callback_state)

        if request_data is not None:
            flow_request.request_data = rdf_protodict.Dict().FromDict(
                request_data)

        cpu_limit_ms = None
        network_bytes_limit = None

        if self.rdf_flow.cpu_limit:
            cpu_usage = self.rdf_flow.cpu_time_used
            # Remaining CPU budget scaled by 1000 into milliseconds
            # (presumably cpu_limit and the usage fields are seconds — confirm).
            cpu_limit_ms = 1000 * max(
                self.rdf_flow.cpu_limit - cpu_usage.user_cpu_time -
                cpu_usage.system_cpu_time, 0)

            if cpu_limit_ms == 0:
                raise flow.FlowResourcesExceededError(
                    "CPU limit exceeded for {} {}.".format(
                        self.rdf_flow.flow_class_name, self.rdf_flow.flow_id))

        if self.rdf_flow.network_bytes_limit:
            network_bytes_limit = max(
                self.rdf_flow.network_bytes_limit -
                self.rdf_flow.network_bytes_sent, 0)
            if network_bytes_limit == 0:
                raise flow.FlowResourcesExceededError(
                    "Network limit exceeded for {} {}.".format(
                        self.rdf_flow.flow_class_name, self.rdf_flow.flow_id))

        runtime_limit_us = self.rdf_flow.runtime_limit_us

        # Pass on only the runtime that is still left, if any was used so far.
        if runtime_limit_us and self.rdf_flow.runtime_us:
            if self.rdf_flow.runtime_us < runtime_limit_us:
                runtime_limit_us -= self.rdf_flow.runtime_us
            else:
                raise flow.FlowResourcesExceededError(
                    "Runtime limit exceeded for {} {}.".format(
                        self.rdf_flow.flow_class_name, self.rdf_flow.flow_id))

        client_action_request = rdf_flows.ClientActionRequest(
            client_id=self.rdf_flow.client_id,
            flow_id=self.rdf_flow.flow_id,
            request_id=outbound_id,
            action_identifier=action_identifier,
            action_args=request,
            cpu_limit_ms=cpu_limit_ms,
            network_bytes_limit=network_bytes_limit,
            runtime_limit_us=runtime_limit_us)

        self.flow_requests.append(flow_request)
        self.client_action_requests.append(client_action_request)
Ejemplo n.º 19
0
  def CallClient(self,
                 action_cls,
                 request=None,
                 next_state=None,
                 request_data=None,
                 **kwargs):
    """Calls the client asynchronously.

    Queues a FlowRequest on self.flow_requests and a GrrMessage, asking the
    client to run `action_cls`, on self.client_messages. The client's
    responses are queued by the framework until its status message arrives;
    the status then commits the whole transaction to `next_state`.

    Args:
       action_cls: The client action to invoke. Its `in_rdfvalue` attribute is
         the expected argument type, or None if it takes no arguments.
       request: The argument to send to the client. If omitted (and the action
         takes arguments), one is built from **kwargs.
       next_state: The state in this flow that responses to this message should
         go to.
       request_data: A dict made available in the RequestState protobuf. The
         Responses object keeps a reference to it for the state method (access
         it via responses.request).
       **kwargs: Used to construct the client action argument rdfvalue.

    Raises:
       ValueError: Args were given to an action that takes none, or the request
         is not an instance of the action's expected argument type.
       flow.FlowError: The flow's CPU or network budget is used up.
    """
    args_cls = action_cls.in_rdfvalue
    if args_cls is None:
      if request:
        raise ValueError(
            "Client action %s does not expect args." % action_cls.__name__)
    elif request is None:
      # No explicit request: build one from the keyword arguments.
      request = args_cls(**kwargs)
    elif not isinstance(request, args_cls):
      raise ValueError("Client action expected %s but got %s" %
                       (args_cls, type(request)))

    outbound_id = self.GetNextOutboundId()
    rdf_flow = self.rdf_flow

    flow_request = rdf_flow_objects.FlowRequest(
        client_id=rdf_flow.client_id,
        flow_id=rdf_flow.flow_id,
        request_id=outbound_id,
        next_state=next_state)
    if request_data is not None:
      flow_request.request_data = rdf_protodict.Dict().FromDict(request_data)

    msg = rdf_flows.GrrMessage(
        session_id=rdf_flow.long_flow_id,
        name=action_cls.__name__,
        request_id=outbound_id,
        queue=rdf_flow.client_id,
        payload=request,
        generate_task_id=True)

    # Hand the client only what is left of each budget; an exhausted budget
    # aborts the call before anything is queued.
    if rdf_flow.cpu_limit:
      used = rdf_flow.cpu_time_used
      msg.cpu_limit = max(
          rdf_flow.cpu_limit - used.user_cpu_time - used.system_cpu_time, 0)
      if msg.cpu_limit == 0:
        raise flow.FlowError("CPU limit exceeded.")

    if rdf_flow.network_bytes_limit:
      msg.network_bytes_limit = max(
          rdf_flow.network_bytes_limit - rdf_flow.network_bytes_sent, 0)
      if msg.network_bytes_limit == 0:
        raise flow.FlowError("Network limit exceeded.")

    self.flow_requests.append(flow_request)
    self.client_messages.append(msg)