Example #1
    def MultiDestroyFlowStates(self, session_ids):
        """Deletes all states in multiple flows and dequeues all client messages."""
        subjects = [session_id.Add("state") for session_id in session_ids]
        to_delete = []

        for subject, values in self.data_store.MultiResolvePrefix(
                subjects,
                self.FLOW_REQUEST_PREFIX,
                token=self.token,
                limit=self.request_limit):
            for _, serialized, _ in values:

                request = rdf_flows.RequestState(serialized)

                # Drop all responses to this request.
                response_subject = self.GetFlowResponseSubject(
                    request.session_id, request.id)
                to_delete.append(response_subject)

                if request.HasField("request"):
                    # Client request dequeueing is cached so we can call it directly.
                    self.DeQueueClientRequest(request.client_id,
                                              request.request.task_id)

            # Mark the flow's whole state subject (all its requests) for deletion.
            to_delete.append(subject)

        # Drop them all at once.
        self.data_store.DeleteSubjects(to_delete, token=self.token)
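
A minimal usage sketch for the batched cleanup above. The import paths and the purge_flows/flow_names names are assumptions used only for illustration; only QueueManager, MultiDestroyFlowStates and rdfvalue.SessionID come from the examples themselves.

    from grr.lib import queue_manager  # assumed import path
    from grr.lib import rdfvalue       # assumed import path

    def purge_flows(token, flow_names):
        # Build one SessionID per flow, then drop all of their request and
        # response state (and dequeue pending client messages) in one batch.
        session_ids = [rdfvalue.SessionID(flow_name=name) for name in flow_names]
        with queue_manager.QueueManager(token=token) as manager:
            manager.MultiDestroyFlowStates(session_ids)
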
Example #2
    def FetchCompletedRequests(self, session_id, timestamp=None):
        """Fetch all the requests with a status message queued for them."""
        subject = session_id.Add("state")
        requests = {}
        status = {}

        if timestamp is None:
            timestamp = (0, self.frozen_timestamp
                         or rdfvalue.RDFDatetime().Now())

        for predicate, serialized, _ in self.data_store.ResolveRegex(
                subject, [self.FLOW_REQUEST_REGEX, self.FLOW_STATUS_REGEX],
                token=self.token,
                limit=self.request_limit,
                timestamp=timestamp):

            parts = predicate.split(":", 3)
            request_id = parts[2]
            if parts[1] == "status":
                status[request_id] = serialized
            else:
                requests[request_id] = serialized

        for request_id, serialized in sorted(requests.items()):
            if request_id in status:
                yield (rdf_flows.RequestState(serialized),
                       rdf_flows.GrrMessage(status[request_id]))
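
A sketch of how a caller might consume the generator above; token and session_id are assumed to be in scope, and the printed report is purely illustrative.

    manager = queue_manager.QueueManager(token=token)
    # Each yielded pair is (RequestState, status GrrMessage) for a request
    # whose status message has already been queued.
    for request, status in manager.FetchCompletedRequests(session_id):
        print("request %s completed with status type %s" % (request.id,
                                                            status.type))
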
Example #3
    def testDeleteRequest(self):
        """Check that we can efficiently destroy a single flow request."""
        session_id = rdfvalue.SessionID(flow_name="test3")

        request = rdf_flows.RequestState(id=1,
                                         client_id=test_lib.TEST_CLIENT_ID,
                                         next_state="TestState",
                                         session_id=session_id)

        with queue_manager.QueueManager(token=self.token) as manager:
            manager.QueueRequest(request)
            manager.QueueResponse(
                rdf_flows.GrrMessage(session_id=session_id,
                                     request_id=1,
                                     response_id=1))

        # Check the request and responses are there.
        all_requests = list(manager.FetchRequestsAndResponses(session_id))
        self.assertEqual(len(all_requests), 1)
        self.assertEqual(all_requests[0][0], request)

        with queue_manager.QueueManager(token=self.token) as manager:
            manager.DeleteRequest(request)

        all_requests = list(manager.FetchRequestsAndResponses(session_id))
        self.assertEqual(len(all_requests), 0)
Example #4
    def FetchRequestsAndResponses(self, session_id):
        """Well known flows do not have real requests.

    This manages retrieving all the responses without requiring corresponding
    requests.

    Args:
      session_id: The session_id to get the requests/responses for.

    Yields:
      A tuple of request (None) and responses.
    """
        subject = session_id.Add("state/request:00000000")

        # Get the responses.
        for _, serialized, _ in sorted(
                self.data_store.ResolvePrefix(
                    subject,
                    self.FLOW_RESPONSE_PREFIX,
                    token=self.token,
                    limit=self.response_limit,
                    timestamp=(0, self.frozen_timestamp
                               or rdfvalue.RDFDatetime.Now()))):

            # The predicate format is flow:response:REQUEST_ID:RESPONSE_ID. For well
            # known flows both request_id and response_id are randomized.
            response = rdf_flows.GrrMessage.FromSerializedString(serialized)

            yield rdf_flows.RequestState(id=0), [response]
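
For this well-known-flow variant every yielded request is a placeholder (id=0) paired with a single-element response list, so the caller's loop has the same shape as for regular flows. A hedged sketch, with token and session_id assumed in scope:

    # manager is assumed to be an instance of the QueueManager subclass that
    # carries the override above.
    for request, responses in manager.FetchRequestsAndResponses(session_id):
        # request.id is always 0 here; each response list holds one GrrMessage.
        for response in responses:
            print(response.session_id, response.response_id)
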
Example #5
    def testCountsActualNumberOfCompletedResponsesWhenApplyingTheLimit(self):
        session_id = rdfvalue.SessionID(flow_name="test")

        # Now queue more requests and responses:
        with queue_manager.QueueManager(token=self.token) as manager:
            # Queue five requests (ids 0 through 4), each with a status message.
            for request_id in range(5):
                request = rdf_flows.RequestState(
                    id=request_id,
                    client_id=test_lib.TEST_CLIENT_ID,
                    next_state="TestState",
                    session_id=session_id)

                manager.QueueRequest(request)

                # Don't queue any actual responses, just a status message with a
                # fake response_id.
                manager.QueueResponse(
                    rdf_flows.GrrMessage(
                        session_id=session_id,
                        request_id=request_id,
                        response_id=1000,
                        type=rdf_flows.GrrMessage.Type.STATUS))

        # Check that even though the status message for every request claims
        # 1000 responses, only the actual response count is used to apply the
        # limit when FetchCompletedResponses is called.
        completed_response = list(
            manager.FetchCompletedResponses(session_id, limit=5))
        self.assertEqual(len(completed_response), 5)
        for i, (request, responses) in enumerate(completed_response):
            self.assertEqual(request.id, i)
            # Responses contain just the status message.
            self.assertEqual(len(responses), 1)
Example #6
  def SendOKStatus(self, response_id, session_id):
    """Send a message to the flow."""
    message = rdf_flows.GrrMessage(
        request_id=1,
        response_id=response_id,
        session_id=session_id,
        type=rdf_flows.GrrMessage.Type.STATUS,
        auth_state=rdf_flows.GrrMessage.AuthorizationState.AUTHENTICATED)

    status = rdf_flows.GrrStatus(status=rdf_flows.GrrStatus.ReturnedStatus.OK)
    message.payload = status

    self.SendMessage(message)

    # Now also set the status on the RequestState.
    request_state, _ = data_store.DB.Resolve(
        message.session_id.Add("state"),
        queue_manager.QueueManager.FLOW_REQUEST_TEMPLATE % message.request_id,
        token=self.token)

    request_state = rdf_flows.RequestState(request_state)
    request_state.status = status

    data_store.DB.Set(
        message.session_id.Add("state"),
        queue_manager.QueueManager.FLOW_REQUEST_TEMPLATE % message.request_id,
        request_state, token=self.token)

    return message
Example #7
    def FetchRequestsAndResponses(self, session_id, timestamp=None):
        """Fetches all outstanding requests and responses for this flow.

    We first cache all requests and responses for this flow in memory to
    prevent round trips.

    Args:
      session_id: The session_id to get the requests/responses for.
      timestamp: Tuple (start, end) with a time range. Fetched requests and
                 responses will have timestamps in this range.

    Yields:
      a tuple (request protobuf, list of response messages) in ascending
      order of request ids.

    Raises:
      MoreDataException: When there is more data available than read by the
                         limited query.
    """
        subject = session_id.Add("state")
        requests = {}

        if timestamp is None:
            timestamp = (0, self.frozen_timestamp
                         or rdfvalue.RDFDatetime().Now())

        # Get some requests.
        for predicate, serialized, _ in self.data_store.ResolveRegex(
                subject,
                self.FLOW_REQUEST_REGEX,
                token=self.token,
                limit=self.request_limit,
                timestamp=timestamp):

            request_id = predicate.split(":", 1)[1]
            requests[str(subject.Add(request_id))] = serialized

        # And the responses for them.
        response_data = dict(
            self.data_store.MultiResolveRegex(requests.keys(),
                                              self.FLOW_RESPONSE_REGEX,
                                              limit=self.response_limit,
                                              token=self.token,
                                              timestamp=timestamp))

        for urn, request_data in sorted(requests.items()):
            request = rdf_flows.RequestState(request_data)
            responses = []
            for _, serialized, _ in response_data.get(urn, []):
                responses.append(rdf_flows.GrrMessage(serialized))

            yield (request, sorted(responses, key=lambda msg: msg.response_id))

        if len(requests) >= self.request_limit:
            raise MoreDataException()
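
Because the query is limited and MoreDataException signals truncation, callers usually wrap the iteration. A minimal sketch with token and session_id assumed in scope; the "come back later" handling is only illustrative.

    manager = queue_manager.QueueManager(token=token)
    try:
        for request, responses in manager.FetchRequestsAndResponses(session_id):
            # Responses arrive sorted by response_id (see the sort above).
            print("request %d has %d responses" % (request.id, len(responses)))
    except queue_manager.MoreDataException:
        # More requests exist than the limited query returned; schedule
        # another pass rather than treating the flow as fully processed.
        pass
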
Example #8
    def testDestroyFlowStates(self):
        """Check that we can efficiently destroy the flow's request queues."""
        session_id = rdfvalue.SessionID(flow_name="test2")

        request = rdf_flows.RequestState(id=1,
                                         client_id=self.client_id,
                                         next_state="TestState",
                                         session_id=session_id)

        with queue_manager.QueueManager(token=self.token) as manager:
            manager.QueueRequest(session_id, request)
            manager.QueueResponse(
                session_id, rdf_flows.GrrMessage(request_id=1, response_id=1))

        # Check the request and responses are there.
        all_requests = list(manager.FetchRequestsAndResponses(session_id))
        self.assertEqual(len(all_requests), 1)
        self.assertEqual(all_requests[0][0], request)

        # Ensure the rows are in the data store:
        self.assertEqual(
            data_store.DB.ResolveRegex(session_id.Add("state"),
                                       ".*",
                                       token=self.token)[0][0],
            "flow:request:00000001")

        self.assertEqual(
            data_store.DB.ResolveRegex(
                session_id.Add("state/request:00000001"),
                ".*",
                token=self.token)[0][0], "flow:response:00000001:00000001")

        with queue_manager.QueueManager(token=self.token) as manager:
            manager.DestroyFlowStates(session_id)

        all_requests = list(manager.FetchRequestsAndResponses(session_id))
        self.assertEqual(len(all_requests), 0)

        # Ensure the rows are gone from the data store.
        self.assertEqual(
            data_store.DB.ResolveRegex(
                session_id.Add("state/request:00000001"),
                ".*",
                token=self.token), [])

        self.assertEqual(
            data_store.DB.ResolveRegex(session_id.Add("state"),
                                       ".*",
                                       token=self.token), [])
Example #9
    def testDestroyFlowStates(self):
        """Check that we can efficiently destroy the flow's request queues."""
        session_id = rdfvalue.SessionID(flow_name="test2")

        request = rdf_flows.RequestState(id=1,
                                         client_id=self.client_id,
                                         next_state="TestState",
                                         session_id=session_id)

        with queue_manager.QueueManager(token=self.token) as manager:
            manager.QueueRequest(request)
            manager.QueueResponse(
                rdf_flows.GrrMessage(request_id=1,
                                     response_id=1,
                                     session_id=session_id))

        # Check the request and responses are there.
        all_requests = list(manager.FetchRequestsAndResponses(session_id))
        self.assertEqual(len(all_requests), 1)
        self.assertEqual(all_requests[0][0], request)

        # Read the response directly.
        responses = data_store.DB.ReadResponsesForRequestId(session_id, 1)
        self.assertEqual(len(responses), 1)
        response = responses[0]
        self.assertEqual(response.request_id, 1)
        self.assertEqual(response.response_id, 1)
        self.assertEqual(response.session_id, session_id)

        with queue_manager.QueueManager(token=self.token) as manager:
            manager.DestroyFlowStates(session_id)

        all_requests = list(manager.FetchRequestsAndResponses(session_id))
        self.assertEqual(len(all_requests), 0)

        # Check that the response is gone.
        responses = data_store.DB.ReadResponsesForRequestId(session_id, 1)
        self.assertEqual(len(responses), 0)

        # Ensure the rows are gone from the data store. Some data stores
        # don't store the queues in that way but there is no harm in
        # checking.
        self.assertEqual(
            data_store.DB.ResolveRow(session_id.Add("state/request:00000001"),
                                     token=self.token), [])

        self.assertEqual(
            data_store.DB.ResolveRow(session_id.Add("state"),
                                     token=self.token), [])
Example #10
    def GetFlowRequests(self, flow_urns, token):
        """Returns all outstanding requests for the flows in flow_urns."""
        flow_requests = {}
        flow_request_urns = [flow_urn.Add("state") for flow_urn in flow_urns]

        for flow_urn, values in data_store.DB.MultiResolvePrefix(
                flow_request_urns, "flow:", token=token):
            for subject, serialized, _ in values:
                try:
                    if "status" in subject:
                        msg = rdf_flows.GrrMessage(serialized)
                    else:
                        msg = rdf_flows.RequestState(serialized)
                except Exception as e:  # pylint: disable=broad-except
                    logging.warn("Error while parsing: %s", e)
                    continue

                flow_requests.setdefault(flow_urn, []).append(msg)
        return flow_requests
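
A hedged sketch of calling the helper above; `inspector` is a hypothetical object exposing GetFlowRequests, and the flow URNs are stand-ins (the method only needs objects supporting .Add("state"), as the session ids elsewhere in these examples do).

    flow_urns = [rdfvalue.SessionID(flow_name="test")]
    for flow_urn, msgs in inspector.GetFlowRequests(flow_urns, token).items():
        for msg in msgs:
            # msg is a GrrMessage for "status" columns, otherwise a RequestState.
            print(flow_urn, msg.__class__.__name__)
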
Example #11
  def ReadResponsesForRequestId(self,
                                session_id,
                                request_id,
                                timestamp=None,
                                token=None):
    """Reads responses for one request.

    Args:
      session_id: The session id to use.
      request_id: The id of the request.
      timestamp: A timestamp as used in the data store.
      token: A data store token.

    Returns:
      The fetched responses for the request.
    """
    request = rdf_flows.RequestState(id=request_id, session_id=session_id)
    for _, responses in self.ReadResponses(
        [request], timestamp=timestamp, token=token):
      return responses
Example #12
    def StartClients(cls, hunt_id, client_ids, token=None):
        """This method is called by the foreman for each client it discovers.

    Note that this function is performance sensitive since it is called by the
    foreman for every client which needs to be scheduled.

    Args:
      hunt_id: The hunt to schedule.
      client_ids: List of clients that should be added to the hunt.
      token: An optional access token to use.
    """
        token = token or access_control.ACLToken(username="******",
                                                 reason="hunting")

        with queue_manager.QueueManager(token=token) as flow_manager:
            for client_id in client_ids:
                # Now we construct a special response which will be sent to the hunt
                # flow. Randomize the request_id so we do not overwrite other messages
                # in the queue.
                state = rdf_flows.RequestState(id=utils.PRNG.GetULong(),
                                               session_id=hunt_id,
                                               client_id=client_id,
                                               next_state="AddClient")

                # Queue the new request.
                flow_manager.QueueRequest(hunt_id, state)

                # Send a response.
                msg = rdf_flows.GrrMessage(
                    session_id=hunt_id,
                    request_id=state.id,
                    response_id=1,
                    auth_state=rdf_flows.GrrMessage.AuthorizationState.
                    AUTHENTICATED,
                    type=rdf_flows.GrrMessage.Type.STATUS,
                    payload=rdf_flows.GrrStatus())

                flow_manager.QueueResponse(hunt_id, msg)

                # And notify the worker about it.
                flow_manager.QueueNotification(session_id=hunt_id)
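
A sketch of the foreman-side call, assuming StartClients is exposed as a classmethod on the hunt implementation (the `cls` parameter suggests it); hunt_cls, hunt_id and token are placeholders, and the client id reuses the test helper seen elsewhere in these examples.

    # Queues one randomized request, one STATUS response and one worker
    # notification per client, all inside a single QueueManager batch.
    hunt_cls.StartClients(hunt_id, [test_lib.TEST_CLIENT_ID], token=token)
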
Example #13
    def testDrainUpdateSessionRequestStates(self):
        """Draining the flow requests and preparing messages."""
        # This flow sends 10 messages on Start()
        flow_obj = self.FlowSetup("SendingTestFlow")
        session_id = flow_obj.session_id

        # There should be 10 messages in the client's task queue
        manager = queue_manager.QueueManager(token=self.token)
        tasks = manager.Query(self.client_id, 100)
        self.assertEqual(len(tasks), 10)

        # Check that the request state objects have the correct task_id set
        # for the corresponding messages in the client queue:
        for task in tasks:
            request_id = task.request_id

            # Retrieve the request state for this request_id
            request_state, _ = data_store.DB.Resolve(
                session_id.Add("state"),
                manager.FLOW_REQUEST_TEMPLATE % request_id,
                token=self.token)

            request_state = rdf_flows.RequestState(request_state)

            # Check that task_id for the client message is correctly set in
            # request_state.
            self.assertEqual(request_state.request.task_id, task.task_id)

        # Now ask the server to drain the outbound messages into the
        # message list.
        response = rdf_flows.MessageList()

        response.job = self.server.DrainTaskSchedulerQueueForClient(
            self.client_id, 5)

        # Check that we received only as many messages as we asked for
        self.assertEqual(len(response.job), 5)

        for i in range(4):
            self.assertEqual(response.job[i].session_id, session_id)
            self.assertEqual(response.job[i].name, "Test")
Example #14
    def DestroyFlowStates(self, session_id):
        """Deletes all states in this flow and dequeue all client messages."""
        subject = session_id.Add("state")

        for _, serialized, _ in self.data_store.ResolveRegex(
                subject,
                self.FLOW_REQUEST_REGEX,
                token=self.token,
                limit=self.request_limit):

            request = rdf_flows.RequestState(serialized)

            # Efficiently drop all responses to this request.
            response_subject = self.GetFlowResponseSubject(
                session_id, request.id)
            self.data_store.DeleteSubject(response_subject, token=self.token)

            if request.HasField("request"):
                self.DeQueueClientRequest(request.client_id,
                                          request.request.task_id)

        # Now drop all the requests at once.
        self.data_store.DeleteSubject(subject, token=self.token)
Example #15
    def CallState(self,
                  messages=None,
                  next_state="",
                  request_data=None,
                  start_time=None):
        """This method is used to schedule a new state on a different worker.

    This is basically the same as CallFlow() except we are calling
    ourselves. The state will be invoked at a later time and receive all the
    messages we send.

    Args:
       messages: A list of rdfvalues to send. If the last one is not a
            GrrStatus, we append an OK Status.

       next_state: The state in this flow to be invoked with the responses.

       request_data: Any dict provided here will be available in the
             RequestState protobuf. The Responses object maintains a reference
             to this protobuf for use in the execution of the state method. (so
             you can access this data by responses.request).

       start_time: Start the flow at this time. This delays notification for
         flow processing into the future. Note that the flow may still be
         processed earlier if there are client responses waiting.

    Raises:
       FlowRunnerError: if the next state is not valid.
    """
        if messages is None:
            messages = []

        # Check if the state is valid.
        if not getattr(self.flow_obj, next_state, None):
            raise FlowRunnerError("Next state %s is invalid." % next_state)

        # Queue the response message to the parent flow
        request_state = rdf_flows.RequestState(
            id=self.GetNextOutboundId(),
            session_id=self.context.session_id,
            client_id=self.runner_args.client_id,
            next_state=next_state)
        if request_data:
            request_state.data = rdf_protodict.Dict().FromDict(request_data)

        self.QueueRequest(request_state, timestamp=start_time)

        # Add the status message if needed.
        if not messages or not isinstance(messages[-1], rdf_flows.GrrStatus):
            messages.append(rdf_flows.GrrStatus())

        # Send all the messages
        for i, payload in enumerate(messages):
            if isinstance(payload, rdfvalue.RDFValue):
                msg = rdf_flows.GrrMessage(
                    session_id=self.session_id,
                    request_id=request_state.id,
                    response_id=1 + i,
                    auth_state=rdf_flows.GrrMessage.AuthorizationState.
                    AUTHENTICATED,
                    payload=payload,
                    type=rdf_flows.GrrMessage.Type.MESSAGE)

                if isinstance(payload, rdf_flows.GrrStatus):
                    msg.type = rdf_flows.GrrMessage.Type.STATUS
            else:
                raise FlowRunnerError("Bad message %s of type %s." %
                                      (payload, type(payload)))

            self.QueueResponse(msg, start_time)

        # Notify the worker about it.
        self.QueueNotification(session_id=self.session_id,
                               timestamp=start_time)
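
A minimal sketch of invoking the runner method above from flow code; `runner` stands for the flow runner instance, "ProcessResults" is a hypothetical state name, and the GrrStatus payload is one of the message types the method explicitly accepts.

    # Schedule ProcessResults on a worker; since the last message is already a
    # GrrStatus, no extra OK status is appended.
    runner.CallState(messages=[rdf_flows.GrrStatus()],
                     next_state="ProcessResults",
                     request_data={"attempt": 1})
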
Example #16
    def testEmbeddedDict(self):
        state = rdf_flows.RequestState(data=rdf_protodict.Dict({"a": 1}))
        serialized = state.SerializeToString()
        deserialized = rdf_flows.RequestState.FromSerializedString(serialized)
        self.assertEqual(deserialized.data, state.data)
Example #17
    def CallState(self,
                  messages=None,
                  next_state="",
                  client_id=None,
                  request_data=None,
                  start_time=None):
        """This method is used to asynchronously schedule a new hunt state.

    The state will be invoked at a later time and receive all the messages
    we send.

    Args:
      messages: A list of rdfvalues to send. If the last one is not a
              GrrStatus, we append an OK Status.

      next_state: The state in this hunt to be invoked with the responses.

      client_id: ClientURN to use in scheduled requests.

      request_data: Any dict provided here will be available in the
                    RequestState protobuf. The Responses object maintains a
                    reference to this protobuf for use in the execution of the
                    state method. (so you can access this data by
                    responses.request).

      start_time: Schedule the state at this time. This delays notification
                  and messages for processing into the future.
    Raises:
      ValueError: on arguments error.
    """

        if messages is None:
            messages = []

        if not next_state:
            raise ValueError("next_state can't be empty.")

        # Now we construct a special response which will be sent to the hunt
        # flow. Randomize the request_id so we do not overwrite other messages in
        # the queue.
        request_state = rdf_flows.RequestState(
            id=utils.PRNG.GetULong(),
            session_id=self.context.session_id,
            client_id=client_id,
            next_state=next_state)

        if request_data:
            request_state.data = rdf_protodict.Dict().FromDict(request_data)

        self.QueueRequest(request_state, timestamp=start_time)

        # Add the status message if needed.
        if not messages or not isinstance(messages[-1], rdf_flows.GrrStatus):
            messages.append(rdf_flows.GrrStatus())

        # Send all the messages
        for i, payload in enumerate(messages):
            if isinstance(payload, rdfvalue.RDFValue):
                msg = rdf_flows.GrrMessage(
                    session_id=self.session_id,
                    request_id=request_state.id,
                    response_id=1 + i,
                    auth_state=rdf_flows.GrrMessage.AuthorizationState.
                    AUTHENTICATED,
                    payload=payload,
                    type=rdf_flows.GrrMessage.Type.MESSAGE)

                if isinstance(payload, rdf_flows.GrrStatus):
                    msg.type = rdf_flows.GrrMessage.Type.STATUS
            else:
                raise flow_runner.FlowRunnerError(
                    "Bad message %s of type %s." % (payload, type(payload)))

            self.QueueResponse(msg, timestamp=start_time)

        # Notify the worker about it.
        self.QueueNotification(session_id=self.session_id,
                               timestamp=start_time)
Example #18
    def CallClient(self,
                   action_cls,
                   request=None,
                   next_state=None,
                   client_id=None,
                   request_data=None,
                   start_time=None,
                   **kwargs):
        """Calls the client asynchronously.

    This sends a message to the client to invoke an Action. The run
    action may send back many responses. These will be queued by the
    framework until a status message is sent by the client. The status
    message will cause the entire transaction to be committed to the
    specified state.

    Args:
       action_cls: The function to call on the client.

       request: The request to send to the client. If not specified (Or None) we
             create a new RDFValue using the kwargs.

       next_state: The state in this flow that responses to this
             message should go to.

       client_id: rdf_client.ClientURN to send the request to.

       request_data: A dict which will be available in the RequestState
             protobuf. The Responses object maintains a reference to this
             protobuf for use in the execution of the state method. (so you can
             access this data by responses.request). Valid values are
             strings, unicode and protobufs.

       start_time: Call the client at this time. This delays the client
         request into the future.

       **kwargs: These args will be used to construct the client action semantic
         protobuf.

    Raises:
       FlowRunnerError: If next_state is not one of the allowed next states.
       RuntimeError: The request passed to the client does not have the correct
                     type.
    """
        if client_id is None:
            client_id = self.runner_args.client_id

        if client_id is None:
            raise FlowRunnerError(
                "CallClient() is used on a flow which was not "
                "started with a client.")

        if not isinstance(client_id, rdf_client.ClientURN):
            # Try turning it into a ClientURN
            client_id = rdf_client.ClientURN(client_id)

        if action_cls.in_rdfvalue is None:
            if request:
                raise RuntimeError("Client action %s does not expect args." %
                                   action_cls.__name__)
        else:
            if request is None:
                # Create a new rdf request.
                request = action_cls.in_rdfvalue(**kwargs)
            else:
                # Verify that the request type matches the client action requirements.
                if not isinstance(request, action_cls.in_rdfvalue):
                    raise RuntimeError("Client action expected %s but got %s" %
                                       (action_cls.in_rdfvalue, type(request)))

        outbound_id = self.GetNextOutboundId()

        # Create a new request state
        state = rdf_flows.RequestState(id=outbound_id,
                                       session_id=self.session_id,
                                       next_state=next_state,
                                       client_id=client_id)

        if request_data is not None:
            state.data = rdf_protodict.Dict(request_data)

        # Send the message with the request state
        msg = rdf_flows.GrrMessage(
            session_id=utils.SmartUnicode(self.session_id),
            name=action_cls.__name__,
            request_id=outbound_id,
            priority=self.runner_args.priority,
            require_fastpoll=self.runner_args.require_fastpoll,
            queue=client_id.Queue(),
            payload=request,
            generate_task_id=True)

        cpu_usage = self.context.client_resources.cpu_usage
        if self.runner_args.cpu_limit:
            msg.cpu_limit = max(
                self.runner_args.cpu_limit - cpu_usage.user_cpu_time -
                cpu_usage.system_cpu_time, 0)

            if msg.cpu_limit == 0:
                raise FlowRunnerError("CPU limit exceeded.")

        if self.runner_args.network_bytes_limit:
            msg.network_bytes_limit = max(
                self.runner_args.network_bytes_limit -
                self.context.network_bytes_sent, 0)
            if msg.network_bytes_limit == 0:
                raise FlowRunnerError("Network limit exceeded.")

        state.request = msg
        self.QueueRequest(state, timestamp=start_time)
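
A hedged sketch of a call site for the method above, again against a runner instance; MyClientAction and "ProcessReply" are hypothetical names, and the extra keyword argument only makes sense if it matches a field of the action's in_rdfvalue.

    # Ask the client to run the action and route its responses (and final
    # status) to the ProcessReply state of this flow.
    runner.CallClient(MyClientAction,
                      next_state="ProcessReply",
                      client_id=test_lib.TEST_CLIENT_ID,
                      some_field="value")  # used to build in_rdfvalue(**kwargs)
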
Example #19
    def CallFlow(self,
                 flow_name=None,
                 next_state=None,
                 sync=True,
                 request_data=None,
                 client_id=None,
                 base_session_id=None,
                 **kwargs):
        """Creates a new flow and send its responses to a state.

    This creates a new flow. The flow may send back many responses which will be
    queued by the framework until the flow terminates. The final status message
    will cause the entire transaction to be committed to the specified state.

    Args:
       flow_name: The name of the flow to invoke.

       next_state: The state in this flow that responses to this
       message should go to.

       sync: If True start the flow inline on the calling thread, else schedule
         a worker to actually start the child flow.

       request_data: Any dict provided here will be available in the
             RequestState protobuf. The Responses object maintains a reference
             to this protobuf for use in the execution of the state method. (so
             you can access this data by responses.request). There is no
             format mandated on this data but it may be a serialized protobuf.

       client_id: If given, the flow is started for this client.

       base_session_id: A URN which will be used to build a URN.

       **kwargs: Arguments for the child flow.

    Raises:
       FlowRunnerError: If next_state is not one of the allowed next states.

    Returns:
       The URN of the child flow which was created.
    """
        client_id = client_id or self.runner_args.client_id

        # This looks very much like CallClient() above - we prepare a request state,
        # and add it to our queue - any responses from the child flow will return to
        # the request state and the stated next_state. Note however, that there is
        # no client_id or actual request message here because we directly invoke the
        # child flow rather than queue anything for it.
        state = rdf_flows.RequestState(id=self.GetNextOutboundId(),
                                       session_id=utils.SmartUnicode(
                                           self.session_id),
                                       client_id=client_id,
                                       next_state=next_state,
                                       response_count=0)

        if request_data:
            state.data = rdf_protodict.Dict().FromDict(request_data)

        # If the urn is passed explicitly (e.g. from the hunt runner) use that,
        # otherwise use the urn from the flow_runner args. If both are None, create
        # a new collection and give the urn to the flow object.
        logs_urn = self._GetLogCollectionURN(
            kwargs.pop("logs_collection_urn", None)
            or self.runner_args.logs_collection_urn)

        # If we were called with write_intermediate_results, propagate down to
        # child flows.  This allows write_intermediate_results to be set to True
        # either at the top level parent, or somewhere in the middle of
        # the call chain.
        write_intermediate = (kwargs.pop("write_intermediate_results", False)
                              or self.runner_args.write_intermediate_results)

        try:
            event_id = self.runner_args.event_id
        except AttributeError:
            event_id = None

        # Create the new child flow but do not notify the user about it.
        child_urn = self.flow_obj.StartFlow(
            client_id=client_id,
            flow_name=flow_name,
            base_session_id=base_session_id or self.session_id,
            event_id=event_id,
            request_state=state,
            token=self.token,
            notify_to_user=False,
            parent_flow=self.flow_obj,
            sync=sync,
            queue=self.runner_args.queue,
            write_intermediate_results=write_intermediate,
            logs_collection_urn=logs_urn,
            **kwargs)

        self.QueueRequest(state)

        return child_urn
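
And a sketch of starting a child flow through the helper above, again on a runner instance; the flow name, state name and extra keyword argument are placeholders passed straight through to the child flow.

    # Launch a child flow asynchronously and route its results back into the
    # ProcessChild state; the returned URN identifies the child flow.
    child_urn = runner.CallFlow(flow_name="SomeChildFlow",
                                next_state="ProcessChild",
                                sync=False,
                                some_child_arg="value")
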
Example #20
    def testQueueing(self):
        """Tests that queueing and fetching of requests and responses work."""
        session_id = rdfvalue.SessionID(flow_name="test")

        request = rdf_flows.RequestState(id=1,
                                         client_id=self.client_id,
                                         next_state="TestState",
                                         session_id=session_id)

        with queue_manager.QueueManager(token=self.token) as manager:
            manager.QueueRequest(session_id, request)

        # We only have one unanswered request on the queue.
        all_requests = list(manager.FetchRequestsAndResponses(session_id))
        self.assertEqual(len(all_requests), 1)
        self.assertEqual(all_requests[0], (request, []))

        # FetchCompletedRequests should return nothing now.
        self.assertEqual(list(manager.FetchCompletedRequests(session_id)), [])

        # Now queue more requests and responses:
        with queue_manager.QueueManager(token=self.token) as manager:
            # Start with request 2 - leave request 1 un-responded to.
            for request_id in range(2, 5):
                request = rdf_flows.RequestState(id=request_id,
                                                 client_id=self.client_id,
                                                 next_state="TestState",
                                                 session_id=session_id)

                manager.QueueRequest(session_id, request)

                response_id = None
                for response_id in range(1, 10):
                    # Normal message.
                    manager.QueueResponse(
                        session_id,
                        rdf_flows.GrrMessage(request_id=request_id,
                                             response_id=response_id))

                # And a status message.
                manager.QueueResponse(
                    session_id,
                    rdf_flows.GrrMessage(
                        request_id=request_id,
                        response_id=response_id + 1,
                        type=rdf_flows.GrrMessage.Type.STATUS))

        completed_requests = list(manager.FetchCompletedRequests(session_id))
        self.assertEqual(len(completed_requests), 3)

        # First completed message is request_id = 2 with 10 responses.
        self.assertEqual(completed_requests[0][0].id, 2)

        # Last message is the status message.
        self.assertEqual(completed_requests[0][-1].type,
                         rdf_flows.GrrMessage.Type.STATUS)
        self.assertEqual(completed_requests[0][-1].response_id, 10)

        # Now fetch all the completed responses, without setting a limit.
        completed_response = list(manager.FetchCompletedResponses(session_id))
        self.assertEqual(len(completed_response), 3)
        for i, (request, responses) in enumerate(completed_response, 2):
            self.assertEqual(request.id, i)
            self.assertEqual(len(responses), 10)

        # Now check if the limit is enforced. The limit refers to the total number
        # of responses to return. We ask for maximum 15 responses, so we should get
        # a single request with 10 responses (since 2 requests will exceed the
        # limit).
        more_data = False
        i = 0
        try:
            partial_response = manager.FetchCompletedResponses(session_id,
                                                               limit=15)
            for i, (request, responses) in enumerate(partial_response, 2):
                self.assertEqual(request.id, i)
                self.assertEqual(len(responses), 10)
        except queue_manager.MoreDataException:
            more_data = True

        # Two requests (ids 2 and 3) were returned before the exception.
        self.assertEqual(i, 3)

        # Make sure the manager told us that more data is available.
        self.assertTrue(more_data)