def __init__(self, delimiter=","):
  precondition.AssertType(delimiter, Text)

  if compatibility.PY2:
    self._output = io.BytesIO()
  else:
    self._output = io.StringIO()

  self._csv = csv.writer(
      self._output,
      delimiter=compatibility.NativeStr(delimiter),
      lineterminator=compatibility.NativeStr("\n"))

def testReturnValueForFollowingCallsIsCached(self):
  result = object()

  mock_fn = mock.Mock(side_effect=[result])
  mock_fn.__name__ = compatibility.NativeStr("MockFunction")

  fn = utils.RunOnce(mock_fn)
  self.assertIs(fn(), result)
  self.assertIs(fn(), result)

def testDecoratedFunctionIsCalledAtLeastOnce(self):
  mock_fn = mock.Mock()
  mock_fn.__name__ = compatibility.NativeStr("MockFunction")

  fn = utils.RunOnce(mock_fn)
  mock_fn.assert_not_called()

  fn()
  mock_fn.assert_called_once()

def pytest_cmdline_main(config):
  """A pytest hook that is called when the main function is executed."""
  del config  # Unused.

  # TODO: `sys.argv` on Python 2 uses `bytes` to represent passed
  # arguments.
  sys.argv = [compatibility.NativeStr("pytest")] + test_args

def testExceptionsArePassedThrough(self):
  mock_fn = mock.Mock(side_effect=ValueError())
  mock_fn.__name__ = compatibility.NativeStr("MockFunction")

  fn = utils.RunOnce(mock_fn)

  with self.assertRaises(ValueError):
    fn()
  with self.assertRaises(ValueError):
    fn()

def testDecoratedFunctionIsCalledAtMostOnce(self):
  mock_fn = mock.Mock(side_effect=[None, AssertionError()])
  mock_fn.__name__ = compatibility.NativeStr("MockFunction")

  fn = utils.RunOnce(mock_fn)
  fn()
  fn()
  fn()

  mock_fn.assert_called_once()

def pytest_cmdline_main(config):
  """A pytest hook that is called when the main function is executed."""
  if "PYTEST_XDIST_WORKER" in os.environ:
    # If run concurrently using pytest-xdist (the `-n` CLI flag), `mainargv`
    # is the result of the execution of `pytest_cmdline_main` in the main
    # process.
    sys.argv = config.workerinput["mainargv"]
  else:
    # TODO: `sys.argv` on Python 2 uses `bytes` to represent passed
    # arguments.
    sys.argv = [compatibility.NativeStr("pytest")] + test_args

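# For the worker branch above to work, the main process has to place
# `mainargv` into each worker's input dictionary. A minimal sketch of the
# controller-side pytest-xdist hook that would do this -- an assumption
# about the surrounding conftest (which already imports `sys`), not code
# shown in the original:
def pytest_configure_node(node):
  """A pytest-xdist hook called in the main process for each worker node."""
  node.workerinput["mainargv"] = sys.argv
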
def __iter__(self):
  if compatibility.PY2:
    filedesc = io.BytesIO(self._content.encode("utf-8"))
  else:
    filedesc = io.StringIO(self._content)

  reader = csv.reader(
      filedesc,
      delimiter=compatibility.NativeStr(self._delimiter),
      lineterminator=compatibility.NativeStr("\n"))

  for values in reader:
    row = []
    for value in values:
      if compatibility.PY2:
        # TODO(hanuszczak): https://github.com/google/pytype/issues/127
        row.append(value.decode("utf-8"))  # pytype: disable=attribute-error
      else:
        row.append(value)
    yield row

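# Both CSV helpers follow the same pattern: the `csv` module wants byte
# streams and byte-string keyword arguments on Python 2 but text streams
# and text arguments on Python 3, hence the `io.BytesIO`/`io.StringIO`
# split and the `compatibility.NativeStr` wrapping. A self-contained
# Python 3 round trip of that pattern (illustrative sketch, not the
# enclosing classes):
import csv
import io

output = io.StringIO()
csv.writer(output, delimiter=",", lineterminator="\n").writerow(["foo", "1"])
rows = list(csv.reader(io.StringIO(output.getvalue()), delimiter=","))
assert rows == [["foo", "1"]]
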
def Corruptor(url="", data=None, **kwargs):
  """Futz with some of the fields."""
  comm_cls = rdf_flows.ClientCommunication
  if data is not None:
    self.client_communication = comm_cls.FromSerializedBytes(data)
  else:
    self.client_communication = comm_cls(None)

  if self.corruptor_field and "server.pem" not in url:
    orig_str_repr = self.client_communication.SerializeToBytes()

    field_data = getattr(self.client_communication, self.corruptor_field)
    if hasattr(field_data, "SerializeToBytes"):
      # This converts encryption keys to a string so we can corrupt them.
      field_data = field_data.SerializeToBytes()

    # TODO: We use `bytes` from the `future` package here to have
    # Python 3 iteration behaviour. This call should be a noop in Python 3
    # and should be safe to remove once support for Python 2 is dropped.
    field_data = bytes(field_data)

    # TODO: On Python 2.7.6 and lower `array.array` accepts only
    # byte strings as an argument, so the call below is necessary. Once
    # support for old Python versions is dropped, this call should be
    # removed.
    modified_data = array.array(compatibility.NativeStr("B"), field_data)

    offset = len(field_data) // 2
    char = field_data[offset]
    modified_data[offset] = char % 250 + 1
    setattr(self.client_communication, self.corruptor_field,
            modified_data.tostring())

    # Make sure we actually changed the data.
    self.assertNotEqual(field_data, modified_data)

    mod_str_repr = self.client_communication.SerializeToBytes()
    self.assertLen(orig_str_repr, len(mod_str_repr))
    differences = [
        True for x, y in zip(orig_str_repr, mod_str_repr) if x != y
    ]
    self.assertLen(differences, 1)

  data = self.client_communication.SerializeToBytes()
  return self.UrlMock(url=url, data=data, **kwargs)

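# Why `char % 250 + 1` guarantees a change: for any byte value b in
# [0, 255], the expression `b % 250 + 1` lies in [1, 250] and can never
# equal b itself, so exactly one byte of the field is always flipped. A
# quick exhaustive check:
assert all(b != b % 250 + 1 for b in range(256))
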
def _FetchLinuxFlags(self):
  """Fetches Linux extended file flags."""
  if platform.system() != "Linux":
    return 0

  # Since we open a file in the next step we do not want to open a symlink.
  # `lsattr` returns an error when trying to check flags of a symlink, so we
  # assume that symlinks cannot have them.
  if self.IsSymlink():
    return 0

  # Some files (e.g. sockets) cannot be opened. For these we do not really
  # care about extended flags (they should have none). `lsattr` does not seem
  # to support such cases anyway. It is also possible that a file has been
  # deleted (because this method is used lazily).
  try:
    fd = os.open(self._path, os.O_RDONLY)
  except (IOError, OSError):
    return 0

  try:
    # This import is Linux-specific.
    import fcntl  # pylint: disable=g-import-not-at-top

    # TODO: On Python 2.7.6 `array.array` accepts only byte strings
    # as an argument. On Python 2.7.12 and 2.7.13 unicode strings are
    # supported as well. On Python 3, only unicode strings are supported.
    # This is why, as a temporary hack, we wrap the literal in a `str` call
    # that converts it to whatever the default is on the given Python
    # version. This should be changed to a raw literal once support for
    # Python 2 is dropped.
    buf = array.array(compatibility.NativeStr("l"), [0])

    # TODO(user):pytype: incorrect type spec for fcntl.ioctl
    # pytype: disable=wrong-arg-types
    fcntl.ioctl(fd, self.FS_IOC_GETFLAGS, buf)
    # pytype: enable=wrong-arg-types

    return buf[0]
  except (IOError, OSError):
    # File system does not support extended attributes.
    return 0
  finally:
    os.close(fd)

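# The integer returned above is the same bitmask that `lsattr` renders as
# letters. A sketch of how a caller might test individual bits, using two
# flag values from `linux/fs.h` (the constants and helper below are
# illustrative assumptions, not part of the class above):
FS_IMMUTABLE_FL = 0x00000010  # Shown as 'i' by lsattr.
FS_APPEND_FL = 0x00000020  # Shown as 'a' by lsattr.


def IsImmutable(flags):
  """Returns True if the FS_IMMUTABLE_FL bit is set in `flags`."""
  return bool(flags & FS_IMMUTABLE_FL)
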
def testWrapsFunctionProperly(self):
  mock_fn = mock.Mock()
  mock_fn.__name__ = compatibility.NativeStr("MockFunction")

  fn = utils.RunOnce(mock_fn)
  self.assertEqual(fn.__name__, compatibility.NativeStr("MockFunction"))

def testReturnValueIsPassedThrough(self):
  mock_fn = mock.Mock(return_value="bar")
  mock_fn.__name__ = compatibility.NativeStr("MockFunction")

  fn = utils.RunOnce(mock_fn)
  self.assertEqual("bar", fn())

def testArgumentsArePassedThrough(self):
  mock_fn = mock.Mock()
  mock_fn.__name__ = compatibility.NativeStr("MockFunction")

  fn = utils.RunOnce(mock_fn)
  fn(1, 2, foo="bar")

  mock_fn.assert_called_once_with(1, 2, foo="bar")

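# Taken together, the tests above pin down the `utils.RunOnce` contract:
# arguments and the return value pass through, the wrapped function runs at
# most once (on success), later calls reuse the cached result, exceptions
# propagate, and `__name__` is preserved. A minimal sketch satisfying that
# contract -- an illustration of the tested semantics, not GRR's actual
# implementation:
import functools


def RunOnceSketch(func):
  """Caches the first successful result of `func` (illustrative only)."""

  @functools.wraps(func)
  def Wrapper(*args, **kwargs):
    # A call only counts as the "one run" if it returns; if `func` raises,
    # the exception propagates and the next call invokes `func` again.
    if not Wrapper._has_run:
      Wrapper._result = func(*args, **kwargs)
      Wrapper._has_run = True
    return Wrapper._result

  Wrapper._has_run = False
  Wrapper._result = None
  return Wrapper
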
def RunStateMethod(
    self,
    method_name: str,
    request: Optional[rdf_flow_runner.RequestState] = None,
    responses: Optional[Sequence[rdf_flow_objects.FlowMessage]] = None
) -> None:
  """Completes the request by calling the state method.

  Args:
    method_name: The name of the state method to call.
    request: A RequestState protobuf.
    responses: A list of FlowMessages responding to the request.

  Raises:
    FlowError: Processing time for the flow has expired.
  """
  client_id = self.rdf_flow.client_id
  deadline = self.rdf_flow.processing_deadline
  if deadline and rdfvalue.RDFDatetime.Now() > deadline:
    raise FlowError("Processing time for flow %s on %s expired." %
                    (self.rdf_flow.flow_id, self.rdf_flow.client_id))

  self.rdf_flow.current_state = method_name
  if request and responses:
    logging.debug("Running %s for flow %s on %s, %d responses.", method_name,
                  self.rdf_flow.flow_id, client_id, len(responses))
  else:
    logging.debug("Running %s for flow %s on %s", method_name,
                  self.rdf_flow.flow_id, client_id)

  try:
    try:
      method = getattr(self, method_name)
    except AttributeError:
      raise ValueError("Flow %s has no state method %s" %
                       (self.__class__.__name__, method_name))

    # Prepare a responses object for the state method to use:
    responses = flow_responses.Responses.FromResponses(
        request=request, responses=responses)

    if responses.status is not None:
      self.SaveResourceUsage(responses.status)

    GRR_WORKER_STATES_RUN.Increment()

    if method_name == "Start":
      FLOW_STARTS.Increment(fields=[self.rdf_flow.flow_class_name])
      method()
    else:
      method(responses)

    if self.replies_to_process:
      if self.rdf_flow.parent_hunt_id and not self.rdf_flow.parent_flow_id:
        self._ProcessRepliesWithHuntOutputPlugins(self.replies_to_process)
      else:
        self._ProcessRepliesWithFlowOutputPlugins(self.replies_to_process)

      self.replies_to_process = []

  except flow.FlowResourcesExceededError as e:
    FLOW_ERRORS.Increment(fields=[self.rdf_flow.flow_class_name])
    logging.info("Flow %s on %s exceeded resource limits: %s.",
                 self.rdf_flow.flow_id, client_id, str(e))
    self.Error(error_message=str(e))
  # We don't know here what exceptions can be thrown in the flow but we have
  # to continue. Thus, we catch everything.
  except Exception as e:  # pylint: disable=broad-except
    # TODO(amoser): We don't know what's in this exception so we have to deal
    # with all eventualities. Replace this code with a simple str(e) once
    # Python 2 support has been dropped.
    msg = compatibility.NativeStr(e)
    FLOW_ERRORS.Increment(fields=[self.rdf_flow.flow_class_name])
    self.Error(error_message=msg, backtrace=traceback.format_exc())

def StartFlow(client_id=None,
              cpu_limit=None,
              creator=None,
              flow_args=None,
              flow_cls=None,
              network_bytes_limit=None,
              original_flow=None,
              output_plugins=None,
              start_at=None,
              parent_flow_obj=None,
              parent_hunt_id=None,
              runtime_limit=None,
              **kwargs):
  """The main factory function for creating and executing a new flow.

  Args:
    client_id: ID of the client this flow should run on.
    cpu_limit: CPU limit in seconds for this flow.
    creator: Username that requested this flow.
    flow_args: An arg protocol buffer which is an instance of the required
      flow's args_type class attribute.
    flow_cls: Class of the flow that should be started.
    network_bytes_limit: Limit on the network traffic this flow can generate.
    original_flow: A FlowReference object in case this flow was copied from
      another flow.
    output_plugins: An OutputPluginDescriptor object indicating what output
      plugins should be used for this flow.
    start_at: If specified, flow will be started not immediately, but at a
      given time.
    parent_flow_obj: A parent flow object. None if this is a top level flow.
    parent_hunt_id: String identifying parent hunt. Can't be passed together
      with parent_flow_obj.
    runtime_limit: Runtime limit as Duration for all ClientActions.
    **kwargs: If args or runner_args are not specified, we construct these
      protobufs from these keywords.

  Returns:
    the flow id of the new flow.

  Raises:
    ValueError: Unknown or invalid parameters were provided.
  """
  if parent_flow_obj is not None and parent_hunt_id is not None:
    raise ValueError(
        "parent_flow_obj and parent_hunt_id are mutually exclusive.")

  # Is the required flow a known flow?
  try:
    registry.FlowRegistry.FlowClassByName(flow_cls.__name__)
  except ValueError:
    GRR_FLOW_INVALID_FLOW_COUNT.Increment()
    raise ValueError("Unable to locate flow %s" % flow_cls.__name__)

  if not client_id:
    raise ValueError("Client_id is needed to start a flow.")

  # Now parse the flow args into the new object from the keywords.
  if flow_args is None:
    flow_args = flow_cls.args_type()

  FilterArgsFromSemanticProtobuf(flow_args, kwargs)

  # At this point we should exhaust all the keyword args. If any are left
  # over, we do not know what to do with them so raise.
  if kwargs:
    raise type_info.UnknownArg("Unknown parameters to StartFlow: %s" % kwargs)

  # Check that the flow args are valid.
  flow_args.Validate()

  rdf_flow = rdf_flow_objects.Flow(
      client_id=client_id,
      flow_class_name=flow_cls.__name__,
      args=flow_args,
      create_time=rdfvalue.RDFDatetime.Now(),
      creator=creator,
      output_plugins=output_plugins,
      original_flow=original_flow,
      flow_state="RUNNING")

  if parent_hunt_id is not None and parent_flow_obj is None:
    rdf_flow.flow_id = parent_hunt_id
  else:
    rdf_flow.flow_id = RandomFlowId()

  # For better performance, only do conflicting IDs check for top-level flows.
  if not parent_flow_obj:
    try:
      data_store.REL_DB.ReadFlowObject(client_id, rdf_flow.flow_id)
      raise CanNotStartFlowWithExistingIdError(client_id, rdf_flow.flow_id)
    except db.UnknownFlowError:
      pass

  if parent_flow_obj:  # A flow is a nested flow.
    parent_rdf_flow = parent_flow_obj.rdf_flow
    rdf_flow.long_flow_id = "%s/%s" % (parent_rdf_flow.long_flow_id,
                                       rdf_flow.flow_id)
    rdf_flow.parent_flow_id = parent_rdf_flow.flow_id
    rdf_flow.parent_hunt_id = parent_rdf_flow.parent_hunt_id
    rdf_flow.parent_request_id = parent_flow_obj.GetCurrentOutboundId()
    if parent_rdf_flow.creator:
      rdf_flow.creator = parent_rdf_flow.creator
  elif parent_hunt_id:  # A flow is a root-level hunt-induced flow.
    rdf_flow.long_flow_id = "%s/%s" % (client_id, rdf_flow.flow_id)
    rdf_flow.parent_hunt_id = parent_hunt_id
  else:  # A flow is a root-level non-hunt flow.
    rdf_flow.long_flow_id = "%s/%s" % (client_id, rdf_flow.flow_id)

  if output_plugins:
    rdf_flow.output_plugins_states = GetOutputPluginStates(
        output_plugins,
        rdf_flow.long_flow_id,
        token=access_control.ACLToken(username=rdf_flow.creator))

  if network_bytes_limit is not None:
    rdf_flow.network_bytes_limit = network_bytes_limit
  if cpu_limit is not None:
    rdf_flow.cpu_limit = cpu_limit
  if runtime_limit is not None:
    rdf_flow.runtime_limit_us = runtime_limit

  logging.info(u"Scheduling %s(%s) on %s (%s)", rdf_flow.long_flow_id,
               rdf_flow.flow_class_name, client_id, start_at or "now")

  rdf_flow.current_state = "Start"

  flow_obj = flow_cls(rdf_flow)

  # Prevent a race condition, where a flow is scheduled twice, because one
  # worker inserts the row and another worker silently updates the existing
  # row.
  allow_update = False

  if start_at is None:
    # Store an initial version of the flow straight away. This is needed so
    # the database doesn't raise consistency errors due to missing parent
    # keys when writing logs / errors / results which might happen in Start().
    try:
      data_store.REL_DB.WriteFlowObject(flow_obj.rdf_flow, allow_update=False)
    except db.FlowExistsError:
      raise CanNotStartFlowWithExistingIdError(client_id, rdf_flow.flow_id)

    allow_update = True

    try:
      # Just run the first state inline. NOTE: Running synchronously means
      # that this runs on the thread that starts the flow. The advantage is
      # that the Start method can raise any errors immediately.
      flow_obj.Start()

      # The flow does not need to actually remain running.
      if not flow_obj.outstanding_requests:
        flow_obj.RunStateMethod("End")

        # Additional check for the correct state in case the End method
        # raised and terminated the flow.
        if flow_obj.IsRunning():
          flow_obj.MarkDone()
    except Exception as e:  # pylint: disable=broad-except
      # We catch all exceptions that happen in Start() and mark the flow as
      # failed.
      msg = compatibility.NativeStr(e)
      if compatibility.PY2:
        msg = msg.decode("utf-8", "replace")
      flow_obj.Error(error_message=msg, backtrace=traceback.format_exc())
  else:
    flow_obj.CallState("Start", start_time=start_at)

  flow_obj.PersistState()

  try:
    data_store.REL_DB.WriteFlowObject(
        flow_obj.rdf_flow, allow_update=allow_update)
  except db.FlowExistsError:
    raise CanNotStartFlowWithExistingIdError(client_id, rdf_flow.flow_id)

  if parent_flow_obj is not None:
    # We can optimize here and not write requests/responses to the database
    # since we have to do this for the parent flow at some point anyway.
    parent_flow_obj.MergeQueuedMessages(flow_obj)
  else:
    flow_obj.FlushQueuedMessages()

  return rdf_flow.flow_id

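# A hedged usage sketch (the client id, flow class, and `paths` keyword are
# hypothetical; keywords not consumed by StartFlow itself are parsed into
# `flow_args` via FilterArgsFromSemanticProtobuf, as the docstring
# describes):
flow_id = StartFlow(
    client_id="C.0123456789abcdef",
    flow_cls=SomeListingFlow,  # Hypothetical flow class.
    creator="analyst",
    cpu_limit=3600,  # Seconds of client CPU time.
    paths=["/tmp"])  # Hypothetical field on the flow's args_type.
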
def RunStateMethod(self, method_name, request=None, responses=None):
  """Completes the request by calling the state method.

  Args:
    method_name: The name of the state method to call.
    request: A RequestState protobuf.
    responses: A list of FlowMessages responding to the request.

  Raises:
    FlowError: Processing time for the flow has expired.
  """
  if self.rdf_flow.pending_termination:
    self.Error(error_message=self.rdf_flow.pending_termination.reason)
    return

  client_id = self.rdf_flow.client_id
  deadline = self.rdf_flow.processing_deadline
  if deadline and rdfvalue.RDFDatetime.Now() > deadline:
    raise flow.FlowError("Processing time for flow %s on %s expired." %
                         (self.rdf_flow.flow_id, self.rdf_flow.client_id))

  self.rdf_flow.current_state = method_name
  if request and responses:
    logging.debug("Running %s for flow %s on %s, %d responses.", method_name,
                  self.rdf_flow.flow_id, client_id, len(responses))
  else:
    logging.debug("Running %s for flow %s on %s", method_name,
                  self.rdf_flow.flow_id, client_id)

  try:
    try:
      method = getattr(self, method_name)
    except AttributeError:
      raise ValueError("Flow %s has no state method %s" %
                       (self.__class__.__name__, method_name))

    # Prepare a responses object for the state method to use:
    responses = flow_responses.Responses.FromResponses(
        request=request, responses=responses)

    if responses.status is not None:
      self.SaveResourceUsage(responses.status)

    stats_collector_instance.Get().IncrementCounter("grr_worker_states_run")

    if method_name == "Start":
      stats_collector_instance.Get().IncrementCounter(
          "flow_starts", fields=[self.rdf_flow.flow_class_name])
      method()
    else:
      method(responses)

    if self.replies_to_process:
      if self.rdf_flow.parent_hunt_id and not self.rdf_flow.parent_flow_id:
        self._ProcessRepliesWithHuntOutputPlugins(self.replies_to_process)
      else:
        self._ProcessRepliesWithFlowOutputPlugins(self.replies_to_process)

      self.replies_to_process = []

  except flow.FlowResourcesExceededError as e:
    stats_collector_instance.Get().IncrementCounter(
        "flow_errors", fields=[self.rdf_flow.flow_class_name])
    logging.info("Flow %s on %s exceeded resource limits: %s.",
                 self.rdf_flow.flow_id, client_id, str(e))
    self.Error(error_message=str(e), backtrace=traceback.format_exc())
  # We don't know here what exceptions can be thrown in the flow but we have
  # to continue. Thus, we catch everything.
  except Exception as e:  # pylint: disable=broad-except
    # TODO(amoser): We don't know what's in this exception so we have to deal
    # with all eventualities. Replace this code with a simple str(e) once
    # Python 2 support has been dropped.
    msg = compatibility.NativeStr(e)
    if compatibility.PY2:
      msg = msg.decode("utf-8", "replace")

    stats_collector_instance.Get().IncrementCounter(
        "flow_errors", fields=[self.rdf_flow.flow_class_name])
    logging.exception("Flow %s on %s raised %s.", self.rdf_flow.flow_id,
                      client_id, msg)
    self.Error(error_message=msg, backtrace=traceback.format_exc())