コード例 #1
0
    def __init__(self, delimiter=","):
        """Sets up a CSV writer backed by an in-memory buffer.

        Args:
          delimiter: A `Text` (unicode) string separating values within a row.
        """
        precondition.AssertType(delimiter, Text)

        # The buffer type follows the Python version: bytes on Python 2, text
        # on Python 3.
        buffer_cls = io.BytesIO if compatibility.PY2 else io.StringIO
        self._output = buffer_cls()

        # `csv` expects native strings for its dialect parameters.
        native_delimiter = compatibility.NativeStr(delimiter)
        native_terminator = compatibility.NativeStr("\n")
        self._csv = csv.writer(
            self._output,
            delimiter=native_delimiter,
            lineterminator=native_terminator)
コード例 #2
0
ファイル: utils_test.py プロジェクト: rezaduty/grr
 def testReturnValueForFollowingCallsIsCached(self):
   """Every call after the first returns the first call's cached result."""
   expected = object()
   # The mock can produce a value only once; a second real invocation would
   # fail, so getting `expected` twice shows the result was cached.
   wrapped = mock.Mock(side_effect=[expected])
   wrapped.__name__ = compatibility.NativeStr("MockFunction")
   run_once = utils.RunOnce(wrapped)
   for _ in range(2):
     self.assertIs(run_once(), expected)
コード例 #3
0
ファイル: utils_test.py プロジェクト: rezaduty/grr
 def testDecoratedFunctionIsCalledAtLeastOnce(self):
   """Wrapping does not call the function; the first invocation does."""
   wrapped = mock.Mock()
   wrapped.__name__ = compatibility.NativeStr("MockFunction")
   run_once = utils.RunOnce(wrapped)

   # Merely decorating must not trigger a call.
   wrapped.assert_not_called()

   run_once()
   wrapped.assert_called_once()
コード例 #4
0
ファイル: conftest.py プロジェクト: x35029/grr
def pytest_cmdline_main(config):
    """A pytest hook that is called when the main function is executed.

    Replaces `sys.argv` with the program name followed by `test_args`
    (a name defined outside this function).
    """
    del config  # Unused.

    # TODO: `sys.argv` on Python 2 uses `bytes` to represent passed
    # arguments.
    program = compatibility.NativeStr("pytest")
    sys.argv = [program] + test_args
コード例 #5
0
ファイル: utils_test.py プロジェクト: rezaduty/grr
 def testExceptionsArePassedThrough(self):
   """An exception from the wrapped function surfaces on first and repeat calls."""
   wrapped = mock.Mock(side_effect=ValueError())
   wrapped.__name__ = compatibility.NativeStr("MockFunction")
   run_once = utils.RunOnce(wrapped)
   for _ in range(2):
     with self.assertRaises(ValueError):
       run_once()
コード例 #6
0
ファイル: utils_test.py プロジェクト: rezaduty/grr
 def testDecoratedFunctionIsCalledAtMostOnce(self):
   """Repeated calls invoke the wrapped function only the first time."""
   # A second real invocation would raise AssertionError and fail the test.
   wrapped = mock.Mock(side_effect=[None, AssertionError()])
   wrapped.__name__ = compatibility.NativeStr("MockFunction")
   run_once = utils.RunOnce(wrapped)
   for _ in range(3):
     run_once()
   wrapped.assert_called_once()
コード例 #7
0
ファイル: conftest.py プロジェクト: esmat777/grr
def pytest_cmdline_main(config):
    """A pytest hook that is called when the main function is executed.

    Sets `sys.argv` from the main process's argv when running as a
    pytest-xdist worker, and from `test_args` otherwise.
    """
    if "PYTEST_XDIST_WORKER" not in os.environ:
        # TODO: `sys.argv` on Python 2 uses `bytes` to represent passed
        # arguments.
        sys.argv = [compatibility.NativeStr("pytest")] + test_args
        return

    # Running concurrently under pytest-xdist (`-n` cli flag): `mainargv` is the
    # result of pytest_cmdline_main having run in the main process.
    sys.argv = config.workerinput["mainargv"]
コード例 #8
0
    def __iter__(self):
        """Yields each row of the CSV content as a list of unicode strings."""
        # The reader consumes bytes on Python 2 and text on Python 3.
        if compatibility.PY2:
            stream = io.BytesIO(self._content.encode("utf-8"))
        else:
            stream = io.StringIO(self._content)

        native_delimiter = compatibility.NativeStr(self._delimiter)
        native_terminator = compatibility.NativeStr("\n")
        parser = csv.reader(
            stream, delimiter=native_delimiter, lineterminator=native_terminator)

        for fields in parser:
            if compatibility.PY2:
                # TODO(hanuszczak): https://github.com/google/pytype/issues/127
                yield [f.decode("utf-8") for f in fields]  # pytype: disable=attribute-error
            else:
                yield list(fields)
コード例 #9
0
        def Corruptor(url="", data=None, **kwargs):
            """Futz with some of the fields.

            Deserializes the outgoing client communication. Unless the request
            is for `server.pem` or no `self.corruptor_field` is set, it flips one
            byte in the middle of that field and asserts that exactly one byte of
            the serialized message changed. The (possibly corrupted) payload is
            then forwarded to `self.UrlMock`.
            """
            comm_cls = rdf_flows.ClientCommunication
            if data is not None:
                self.client_communication = comm_cls.FromSerializedBytes(data)
            else:
                self.client_communication = comm_cls(None)

            if self.corruptor_field and "server.pem" not in url:
                orig_str_repr = self.client_communication.SerializeToBytes()
                field_data = getattr(self.client_communication,
                                     self.corruptor_field)
                if hasattr(field_data, "SerializeToBytes"):
                    # This converts encryption keys to a string so we can corrupt them.
                    field_data = field_data.SerializeToBytes()

                # TODO: We use `bytes` from the `future` package here to
                # have Python 3 iteration behaviour. This call should be a noop in
                # Python 3 and should be safe to remove on support for Python 2 is
                # dropped.
                field_data = bytes(field_data)

                # TODO: On Python 2.7.6 and lower `array.array` accepts
                # only byte strings as argument so the call below is necessary. Once
                # support for old Python versions is dropped, this call should be
                # removed.
                modified_data = array.array(compatibility.NativeStr("B"),
                                            field_data)
                offset = len(field_data) // 2
                char = field_data[offset]
                # `char % 250 + 1` differs from `char` for every byte value 0-255,
                # so the byte at `offset` is guaranteed to change.
                modified_data[offset] = char % 250 + 1

                # `array.tostring` was removed in Python 3.9 in favour of
                # `tobytes`; Python 2 only provides `tostring`.
                if compatibility.PY2:
                    modified_bytes = modified_data.tostring()
                else:
                    modified_bytes = modified_data.tobytes()
                setattr(self.client_communication, self.corruptor_field,
                        modified_bytes)

                # Make sure we actually changed the data. Compare bytes with
                # bytes: `bytes` never compares equal to an `array.array`, so
                # comparing against the array would make this check vacuous.
                self.assertNotEqual(field_data, modified_bytes)

                mod_str_repr = self.client_communication.SerializeToBytes()
                self.assertLen(orig_str_repr, len(mod_str_repr))
                differences = [
                    True for x, y in zip(orig_str_repr, mod_str_repr) if x != y
                ]
                self.assertLen(differences, 1)

            data = self.client_communication.SerializeToBytes()
            return self.UrlMock(url=url, data=data, **kwargs)
コード例 #10
0
ファイル: filesystem.py プロジェクト: x35029/grr
    def _FetchLinuxFlags(self):
        """Fetches Linux extended file flags.

        Returns:
          The value filled in by the `FS_IOC_GETFLAGS` ioctl. Returns 0 in any of
          these cases: not running on Linux, the path is a symlink, the file
          cannot be opened, or the ioctl fails.
        """
        if platform.system() != "Linux":
            return 0

        # Since we open a file in the next step we do not want to open a symlink.
        # `lsattr` returns an error when trying to check flags of a symlink, so we
        # assume that symlinks cannot have them.
        if self.IsSymlink():
            return 0

        # Some files (e.g. sockets) cannot be opened. For these we do not really
        # care about extended flags (they should have none). `lsattr` does not seem
        # to support such cases anyway. It is also possible that a file has been
        # deleted (because this method is used lazily).
        #
        # `O_NONBLOCK` keeps the open from hanging on special files: a read-only
        # open of a FIFO otherwise blocks until some writer opens it.
        # `O_NOFOLLOW` makes the open fail (with ELOOP) rather than follow a
        # symlink that replaced the file after the `IsSymlink` check above.
        open_flags = os.O_RDONLY | os.O_NONBLOCK | os.O_NOFOLLOW
        try:
            fd = os.open(self._path, open_flags)
        except (IOError, OSError):
            return 0

        try:
            # This import is Linux-specific.
            import fcntl  # pylint: disable=g-import-not-at-top
            # TODO: On Python 2.7.6 `array.array` accepts only byte
            # strings as an argument. On Python 2.7.12 and 2.7.13 unicodes are
            # supported as well. On Python 3, only unicode strings are supported. This
            # is why, as a temporary hack, we wrap the literal with `str` call that
            # will convert it to whatever is the default on given Python version. This
            # should be changed to raw literal once support for Python 2 is dropped.
            buf = array.array(compatibility.NativeStr("l"), [0])
            # TODO(user):pytype: incorrect type spec for fcntl.ioctl
            # pytype: disable=wrong-arg-types
            fcntl.ioctl(fd, self.FS_IOC_GETFLAGS, buf)
            # pytype: enable=wrong-arg-types
            return buf[0]
        except (IOError, OSError):
            # File system does not support extended attributes.
            return 0
        finally:
            os.close(fd)
コード例 #11
0
ファイル: utils_test.py プロジェクト: rezaduty/grr
 def testWrapsFunctionProperly(self):
   """RunOnce preserves the wrapped function's `__name__`."""
   expected_name = compatibility.NativeStr("MockFunction")
   wrapped = mock.Mock()
   wrapped.__name__ = expected_name
   run_once = utils.RunOnce(wrapped)
   self.assertEqual(run_once.__name__, expected_name)
コード例 #12
0
ファイル: utils_test.py プロジェクト: rezaduty/grr
 def testReturnValueIsPassedThrough(self):
   """The wrapped function's return value is handed back to the caller."""
   wrapped = mock.Mock(return_value="bar")
   wrapped.__name__ = compatibility.NativeStr("MockFunction")
   result = utils.RunOnce(wrapped)()
   self.assertEqual("bar", result)
コード例 #13
0
ファイル: utils_test.py プロジェクト: rezaduty/grr
 def testArgumentsArePassedThrough(self):
   """Positional and keyword arguments are forwarded unchanged."""
   positional = (1, 2)
   keywords = dict(foo="bar")
   wrapped = mock.Mock()
   wrapped.__name__ = compatibility.NativeStr("MockFunction")
   run_once = utils.RunOnce(wrapped)
   run_once(*positional, **keywords)
   wrapped.assert_called_once_with(*positional, **keywords)
コード例 #14
0
ファイル: flow_base.py プロジェクト: avmi/grr
    def RunStateMethod(
        self,
        method_name: str,
        request: Optional[rdf_flow_runner.RequestState] = None,
        responses: Optional[Sequence[rdf_flow_objects.FlowMessage]] = None
    ) -> None:
        """Completes the request by calling the state method.

        Exceptions (subclasses of `Exception`) raised during the following steps
        are caught and passed to `self.Error` rather than propagated:
        looking up the state method, running it, or processing its replies.

        Args:
          method_name: The name of the state method to call.
          request: A RequestState protobuf.
          responses: A list of FlowMessages responding to the request.

        Raises:
          FlowError: Processing time for the flow has expired.
        """
        client_id = self.rdf_flow.client_id

        # Refuse to run any state once the processing deadline has passed.
        deadline = self.rdf_flow.processing_deadline
        if deadline and rdfvalue.RDFDatetime.Now() > deadline:
            raise FlowError("Processing time for flow %s on %s expired." %
                            (self.rdf_flow.flow_id, self.rdf_flow.client_id))

        self.rdf_flow.current_state = method_name
        if request and responses:
            logging.debug("Running %s for flow %s on %s, %d responses.",
                          method_name, self.rdf_flow.flow_id, client_id,
                          len(responses))
        else:
            logging.debug("Running %s for flow %s on %s", method_name,
                          self.rdf_flow.flow_id, client_id)

        try:
            try:
                method = getattr(self, method_name)
            except AttributeError:
                # Raised inside the outer `try`, so it is handled by the generic
                # `except Exception` branch below.
                raise ValueError("Flow %s has no state method %s" %
                                 (self.__class__.__name__, method_name))

            # Prepare a responses object for the state method to use:
            responses = flow_responses.Responses.FromResponses(
                request=request, responses=responses)

            if responses.status is not None:
                self.SaveResourceUsage(responses.status)

            GRR_WORKER_STATES_RUN.Increment()

            # `Start` is called without arguments; other states get `responses`.
            if method_name == "Start":
                FLOW_STARTS.Increment(fields=[self.rdf_flow.flow_class_name])
                method()
            else:
                method(responses)

            # Root-level hunt flows (hunt parent, no flow parent) feed hunt
            # output plugins; all other flows feed their own output plugins.
            if self.replies_to_process:
                if self.rdf_flow.parent_hunt_id and not self.rdf_flow.parent_flow_id:
                    self._ProcessRepliesWithHuntOutputPlugins(
                        self.replies_to_process)
                else:
                    self._ProcessRepliesWithFlowOutputPlugins(
                        self.replies_to_process)

                self.replies_to_process = []

        except flow.FlowResourcesExceededError as e:
            FLOW_ERRORS.Increment(fields=[self.rdf_flow.flow_class_name])
            logging.info("Flow %s on %s exceeded resource limits: %s.",
                         self.rdf_flow.flow_id, client_id, str(e))
            self.Error(error_message=str(e))
        # We don't know here what exceptions can be thrown in the flow but we have
        # to continue. Thus, we catch everything.
        except Exception as e:  # pylint: disable=broad-except
            # The signature uses Python 3-only type annotations, so the Python 2
            # handling via `compatibility.NativeStr` is no longer needed.
            msg = str(e)
            FLOW_ERRORS.Increment(fields=[self.rdf_flow.flow_class_name])

            self.Error(error_message=msg, backtrace=traceback.format_exc())
コード例 #15
0
ファイル: flow.py プロジェクト: viszsec/grr
def StartFlow(client_id=None,
              cpu_limit=None,
              creator=None,
              flow_args=None,
              flow_cls=None,
              network_bytes_limit=None,
              original_flow=None,
              output_plugins=None,
              start_at=None,
              parent_flow_obj=None,
              parent_hunt_id=None,
              runtime_limit=None,
              **kwargs):
  """The main factory function for creating and executing a new flow.

  Args:
    client_id: ID of the client this flow should run on.
    cpu_limit: CPU limit in seconds for this flow.
    creator: Username that requested this flow.
    flow_args: An arg protocol buffer which is an instance of the required
      flow's args_type class attribute.
    flow_cls: Class of the flow that should be started. Required.
    network_bytes_limit: Limit on the network traffic this flow can generate.
    original_flow: A FlowReference object in case this flow was copied from
      another flow.
    output_plugins: An OutputPluginDescriptor object indicating what output
      plugins should be used for this flow.
    start_at: If specified, flow will be started not immediately, but at a given
      time.
    parent_flow_obj: A parent flow object. None if this is a top level flow.
    parent_hunt_id: String identifying parent hunt. Can't be passed together
      with parent_flow_obj.
    runtime_limit: Runtime limit as Duration for all ClientActions.
    **kwargs: If args or runner_args are not specified, we construct these
      protobufs from these keywords.

  Returns:
    the flow id of the new flow.

  Raises:
    ValueError: Unknown or invalid parameters were provided (including a
      missing `flow_cls` or `client_id`, or both `parent_flow_obj` and
      `parent_hunt_id`).
    type_info.UnknownArg: Keyword arguments remained after
      FilterArgsFromSemanticProtobuf processed them.
    CanNotStartFlowWithExistingIdError: A flow with the chosen id already
      exists for the client.
  """

  if parent_flow_obj is not None and parent_hunt_id is not None:
    raise ValueError(
        "parent_flow_obj and parent_hunt_id are mutually exclusive.")

  # Without this check a missing flow class surfaces as an opaque
  # AttributeError on `flow_cls.__name__` below.
  if flow_cls is None:
    raise ValueError("flow_cls is needed to start a flow.")

  # Is the required flow a known flow?
  try:
    registry.FlowRegistry.FlowClassByName(flow_cls.__name__)
  except ValueError:
    GRR_FLOW_INVALID_FLOW_COUNT.Increment()
    raise ValueError("Unable to locate flow %s" % flow_cls.__name__)

  if not client_id:
    raise ValueError("Client_id is needed to start a flow.")

  # Now parse the flow args into the new object from the keywords.
  if flow_args is None:
    flow_args = flow_cls.args_type()

  FilterArgsFromSemanticProtobuf(flow_args, kwargs)
  # At this point we should exhaust all the keyword args. If any are left
  # over, we do not know what to do with them so raise.
  if kwargs:
    raise type_info.UnknownArg("Unknown parameters to StartFlow: %s" % kwargs)

  # Check that the flow args are valid.
  flow_args.Validate()

  rdf_flow = rdf_flow_objects.Flow(
      client_id=client_id,
      flow_class_name=flow_cls.__name__,
      args=flow_args,
      create_time=rdfvalue.RDFDatetime.Now(),
      creator=creator,
      output_plugins=output_plugins,
      original_flow=original_flow,
      flow_state="RUNNING")

  # A root-level hunt-induced flow reuses the hunt id as its flow id.
  if parent_hunt_id is not None and parent_flow_obj is None:
    rdf_flow.flow_id = parent_hunt_id
  else:
    rdf_flow.flow_id = RandomFlowId()

  # For better performance, only do conflicting IDs check for top-level flows.
  if not parent_flow_obj:
    try:
      data_store.REL_DB.ReadFlowObject(client_id, rdf_flow.flow_id)
    except db.UnknownFlowError:
      pass  # No flow with this id exists yet, so the id is free to use.
    else:
      raise CanNotStartFlowWithExistingIdError(client_id, rdf_flow.flow_id)

  if parent_flow_obj:  # A flow is a nested flow.
    parent_rdf_flow = parent_flow_obj.rdf_flow
    rdf_flow.long_flow_id = "%s/%s" % (parent_rdf_flow.long_flow_id,
                                       rdf_flow.flow_id)
    rdf_flow.parent_flow_id = parent_rdf_flow.flow_id
    rdf_flow.parent_hunt_id = parent_rdf_flow.parent_hunt_id
    rdf_flow.parent_request_id = parent_flow_obj.GetCurrentOutboundId()
    if parent_rdf_flow.creator:
      rdf_flow.creator = parent_rdf_flow.creator
  elif parent_hunt_id:  # A flow is a root-level hunt-induced flow.
    rdf_flow.long_flow_id = "%s/%s" % (client_id, rdf_flow.flow_id)
    rdf_flow.parent_hunt_id = parent_hunt_id
  else:  # A flow is a root-level non-hunt flow.
    rdf_flow.long_flow_id = "%s/%s" % (client_id, rdf_flow.flow_id)

  if output_plugins:
    rdf_flow.output_plugins_states = GetOutputPluginStates(
        output_plugins,
        rdf_flow.long_flow_id,
        token=access_control.ACLToken(username=rdf_flow.creator))

  if network_bytes_limit is not None:
    rdf_flow.network_bytes_limit = network_bytes_limit
  if cpu_limit is not None:
    rdf_flow.cpu_limit = cpu_limit
  if runtime_limit is not None:
    rdf_flow.runtime_limit_us = runtime_limit

  logging.info(u"Scheduling %s(%s) on %s (%s)", rdf_flow.long_flow_id,
               rdf_flow.flow_class_name, client_id, start_at or "now")

  rdf_flow.current_state = "Start"

  flow_obj = flow_cls(rdf_flow)

  # Prevent a race condition, where a flow is scheduled twice, because one
  # worker inserts the row and another worker silently updates the existing row.
  allow_update = False

  if start_at is None:

    # Store an initial version of the flow straight away. This is needed so the
    # database doesn't raise consistency errors due to missing parent keys when
    # writing logs / errors / results which might happen in Start().
    try:
      data_store.REL_DB.WriteFlowObject(flow_obj.rdf_flow, allow_update=False)
    except db.FlowExistsError:
      raise CanNotStartFlowWithExistingIdError(client_id, rdf_flow.flow_id)

    # The row now exists, so the final write below must be an update.
    allow_update = True

    try:
      # Just run the first state inline. NOTE: Running synchronously means
      # that this runs on the thread that starts the flow. The advantage is
      # that that Start method can raise any errors immediately.
      flow_obj.Start()

      # The flow does not need to actually remain running.
      if not flow_obj.outstanding_requests:
        flow_obj.RunStateMethod("End")
        # Additional check for the correct state in case the End method raised
        # and terminated the flow.
        if flow_obj.IsRunning():
          flow_obj.MarkDone()
    except Exception as e:  # pylint: disable=broad-except
      # We catch all exceptions that happen in Start() and mark the flow as
      # failed.
      msg = compatibility.NativeStr(e)
      if compatibility.PY2:
        msg = msg.decode("utf-8", "replace")

      flow_obj.Error(error_message=msg, backtrace=traceback.format_exc())

  else:
    # Delayed start: schedule the Start state for `start_at` instead of
    # running it inline.
    flow_obj.CallState("Start", start_time=start_at)

  flow_obj.PersistState()

  try:
    data_store.REL_DB.WriteFlowObject(
        flow_obj.rdf_flow, allow_update=allow_update)
  except db.FlowExistsError:
    raise CanNotStartFlowWithExistingIdError(client_id, rdf_flow.flow_id)

  if parent_flow_obj is not None:
    # We can optimize here and not write requests/responses to the database
    # since we have to do this for the parent flow at some point anyways.
    parent_flow_obj.MergeQueuedMessages(flow_obj)
  else:
    flow_obj.FlushQueuedMessages()

  return rdf_flow.flow_id
コード例 #16
0
    def RunStateMethod(self, method_name, request=None, responses=None):
        """Completes the request by calling the state method.

        If the flow has a pending termination, its reason is passed to
        `self.Error` and no state method runs. Exceptions (subclasses of
        `Exception`) raised during the following steps are caught, counted,
        logged and passed to `self.Error` rather than propagated: looking up
        the state method, running it, or processing its replies.

        Args:
          method_name: The name of the state method to call.
          request: A RequestState protobuf.
          responses: A list of FlowMessages responding to the request.

        Raises:
          flow.FlowError: Processing time for the flow has expired.
        """
        # A termination was requested for this flow: report it as an error and
        # skip running the state method.
        if self.rdf_flow.pending_termination:
            self.Error(error_message=self.rdf_flow.pending_termination.reason)
            return

        client_id = self.rdf_flow.client_id

        # Refuse to run any state once the processing deadline has passed.
        deadline = self.rdf_flow.processing_deadline
        if deadline and rdfvalue.RDFDatetime.Now() > deadline:
            raise flow.FlowError(
                "Processing time for flow %s on %s expired." %
                (self.rdf_flow.flow_id, self.rdf_flow.client_id))

        self.rdf_flow.current_state = method_name
        if request and responses:
            logging.debug("Running %s for flow %s on %s, %d responses.",
                          method_name, self.rdf_flow.flow_id, client_id,
                          len(responses))
        else:
            logging.debug("Running %s for flow %s on %s", method_name,
                          self.rdf_flow.flow_id, client_id)

        try:
            try:
                method = getattr(self, method_name)
            except AttributeError:
                # Raised inside the outer `try`, so it is handled by the generic
                # `except Exception` branch below.
                raise ValueError("Flow %s has no state method %s" %
                                 (self.__class__.__name__, method_name))

            # Prepare a responses object for the state method to use:
            responses = flow_responses.Responses.FromResponses(
                request=request, responses=responses)

            if responses.status is not None:
                self.SaveResourceUsage(responses.status)

            stats_collector_instance.Get().IncrementCounter(
                "grr_worker_states_run")

            # `Start` is called without arguments; other states get `responses`.
            if method_name == "Start":
                stats_collector_instance.Get().IncrementCounter(
                    "flow_starts", fields=[self.rdf_flow.flow_class_name])
                method()
            else:
                method(responses)

            # Root-level hunt flows (hunt parent, no flow parent) feed hunt
            # output plugins; all other flows feed their own output plugins.
            if self.replies_to_process:
                if self.rdf_flow.parent_hunt_id and not self.rdf_flow.parent_flow_id:
                    self._ProcessRepliesWithHuntOutputPlugins(
                        self.replies_to_process)
                else:
                    self._ProcessRepliesWithFlowOutputPlugins(
                        self.replies_to_process)

                self.replies_to_process = []

        except flow.FlowResourcesExceededError as e:
            stats_collector_instance.Get().IncrementCounter(
                "flow_errors", fields=[self.rdf_flow.flow_class_name])
            logging.info("Flow %s on %s exceeded resource limits: %s.",
                         self.rdf_flow.flow_id, client_id, str(e))
            self.Error(error_message=str(e), backtrace=traceback.format_exc())
        # We don't know here what exceptions can be thrown in the flow but we have
        # to continue. Thus, we catch everything.
        except Exception as e:  # pylint: disable=broad-except
            # TODO(amoser): We don't know what's in this exception so we have to deal
            # with all eventualities. Replace this code with a simple str(e) once
            # Python 2 support has been dropped.
            msg = compatibility.NativeStr(e)
            # On Python 2 the native string is bytes; decode it to unicode,
            # replacing undecodable bytes rather than raising.
            if compatibility.PY2:
                msg = msg.decode("utf-8", "replace")

            stats_collector_instance.Get().IncrementCounter(
                "flow_errors", fields=[self.rdf_flow.flow_class_name])
            logging.exception("Flow %s on %s raised %s.",
                              self.rdf_flow.flow_id, client_id, msg)

            self.Error(error_message=msg, backtrace=traceback.format_exc())