예제 #1
0
    def UpdateHuntOutputPluginState(self,
                                    hunt_id,
                                    state_index,
                                    update_fn,
                                    cursor=None):
        """Updates hunt output plugin state for a given output plugin.

        Args:
          hunt_id: Id of the hunt whose output plugin state is updated.
          state_index: Index of the plugin state; matched against the
            `plugin_id` column of `hunt_output_plugins_states`.
          update_fn: Callable that receives the current `plugin_state` and
            returns the modified one. The result must provide
            `SerializeToBytes()`.
          cursor: Database cursor used to run the queries.

        Returns:
          The output plugin state read from the database. Its `plugin_state`
          is set to the value returned by `update_fn`, i.e. the value that
          was written back.

        Raises:
          db.UnknownHuntError: If no hunt with `hunt_id` exists.
          db.UnknownHuntOutputPluginStateError: If the hunt has no plugin
            state stored under `state_index`.
        """
        hunt_id_int = db_utils.HuntIDToInt(hunt_id)

        # Check the hunt first so that a missing hunt is reported as such,
        # rather than as a missing plugin state.
        query = "SELECT hunt_id FROM hunts WHERE hunt_id = %s"
        rows_returned = cursor.execute(query, [hunt_id_int])
        if rows_returned == 0:
            raise db.UnknownHuntError(hunt_id)

        columns = ", ".join(_HUNT_OUTPUT_PLUGINS_STATES_COLUMNS)
        query = ("SELECT {columns} FROM hunt_output_plugins_states "
                 "WHERE hunt_id = %s AND plugin_id = %s".format(
                     columns=columns))
        rows_returned = cursor.execute(query, [hunt_id_int, state_index])
        if rows_returned == 0:
            raise db.UnknownHuntOutputPluginStateError(hunt_id, state_index)

        state = self._HuntOutputPluginStateFromRow(cursor.fetchone())
        modified_plugin_state = update_fn(state.plugin_state)

        query = ("UPDATE hunt_output_plugins_states "
                 "SET plugin_state = %s "
                 "WHERE hunt_id = %s AND plugin_id = %s")
        args = [
            modified_plugin_state.SerializeToBytes(), hunt_id_int, state_index
        ]
        cursor.execute(query, args)

        # Keep the returned object in sync with what was persisted. If
        # update_fn returned a new object instead of mutating in place,
        # `state.plugin_state` would otherwise still hold the stale value.
        state.plugin_state = modified_plugin_state
        return state
예제 #2
0
    def UpdateHuntObject(self,
                         hunt_id,
                         duration=None,
                         client_rate=None,
                         client_limit=None,
                         hunt_state=None,
                         hunt_state_comment=None,
                         start_time=None,
                         num_clients_at_start_time=None,
                         cursor=None):
        """Updates the given fields of a hunt row in the `hunts` table.

        Only arguments that are not None are written.
        `last_update_timestamp` is always set to the server's current time
        (`NOW(6)`).

        Args:
          hunt_id: Id of the hunt to update.
          duration: New hunt duration. Its `microseconds` attribute is stored
            in `duration_micros`.
          client_rate: New value of `client_rate`.
          client_limit: New value of `client_limit`.
          hunt_state: New hunt state, stored as `int(hunt_state)`.
          hunt_state_comment: New value of `hunt_state_comment`.
          start_time: Hunt start time. Always written to `last_start_time`.
            Written to `init_start_time` only if that column is still NULL.
          num_clients_at_start_time: New value of `num_clients_at_start_time`.
          cursor: Database cursor used to execute the update.

        Raises:
          db.UnknownHuntError: If the UPDATE modified no rows, i.e. there is
            no hunt with `hunt_id`.
        """
        vals = []
        args = {}

        if duration is not None:
            vals.append("duration_micros = %(duration_micros)s")
            # NOTE(review): assumes `duration.microseconds` is the *total*
            # duration in microseconds. For a `datetime.timedelta` it would
            # only be the sub-second component.
            args["duration_micros"] = duration.microseconds

        if client_rate is not None:
            vals.append("client_rate = %(client_rate)s")
            args["client_rate"] = client_rate

        if client_limit is not None:
            vals.append("client_limit = %(client_limit)s")
            args["client_limit"] = client_limit

        if hunt_state is not None:
            vals.append("hunt_state = %(hunt_state)s")
            args["hunt_state"] = int(hunt_state)

        if hunt_state_comment is not None:
            vals.append("hunt_state_comment = %(hunt_state_comment)s")
            args["hunt_state_comment"] = hunt_state_comment

        if start_time is not None:
            # IFNULL keeps an already-set init_start_time unchanged, while
            # last_start_time is overwritten on every call.
            vals.append("init_start_time = IFNULL(init_start_time, "
                        "FROM_UNIXTIME(%(start_time)s))")
            vals.append("last_start_time = FROM_UNIXTIME(%(start_time)s)")
            args["start_time"] = mysql_utils.RDFDatetimeToTimestamp(start_time)

        if num_clients_at_start_time is not None:
            vals.append(
                "num_clients_at_start_time = %(num_clients_at_start_time)s")
            args["num_clients_at_start_time"] = num_clients_at_start_time

        # Always refreshed. This also means an existing row is always
        # reported as modified, so the row count below detects unknown hunts.
        vals.append("last_update_timestamp = NOW(6)")

        query = ("UPDATE hunts SET {updates} WHERE hunt_id = %(hunt_id)s"
                 .format(updates=", ".join(vals)))
        args["hunt_id"] = db_utils.HuntIDToInt(hunt_id)

        rows_modified = cursor.execute(query, args)
        if rows_modified == 0:
            raise db.UnknownHuntError(hunt_id)
예제 #3
0
    def WriteHuntOutputPluginsStates(self, hunt_id, states, cursor=None):
        """Writes hunt output plugin states for a given hunt.

        Each state is inserted as a new row of `hunt_output_plugins_states`.
        Its position in `states` is used as the `plugin_id`.

        Args:
          hunt_id: Id of the hunt the states belong to.
          states: Iterable of output plugin states. Each one exposes
            `plugin_descriptor` (with `plugin_name` and an optional
            `plugin_args`) and `plugin_state`.
          cursor: Database cursor used to run the inserts.

        Raises:
          db.UnknownHuntError: If an insert fails with a MySQL
            IntegrityError (presumably the hunt row does not exist).

        NOTE(review): if `states` is empty, no statement is executed, so an
        unknown `hunt_id` is not detected here.
        """
        columns = ", ".join(_HUNT_OUTPUT_PLUGINS_STATES_COLUMNS)
        # Two extra placeholders for the (hunt_id, plugin_id) columns.
        placeholders = mysql_utils.Placeholders(
            2 + len(_HUNT_OUTPUT_PLUGINS_STATES_COLUMNS))
        # The statement is identical for every state, so build it only once.
        query = ("INSERT INTO hunt_output_plugins_states "
                 "(hunt_id, plugin_id, {columns}) "
                 "VALUES {placeholders}".format(columns=columns,
                                                placeholders=placeholders))
        hunt_id_int = db_utils.HuntIDToInt(hunt_id)

        for index, state in enumerate(states):
            descriptor = state.plugin_descriptor
            plugin_args = descriptor.plugin_args
            args = [
                hunt_id_int,
                index,
                descriptor.plugin_name,
                None if plugin_args is None else plugin_args.SerializeToBytes(),
                state.plugin_state.SerializeToBytes(),
            ]

            try:
                cursor.execute(query, args)
            except MySQLdb.IntegrityError as e:
                raise db.UnknownHuntError(hunt_id=hunt_id, cause=e)
예제 #4
0
    def WriteHuntOutputPluginsStates(self, hunt_id, states):
        """Stores serialized copies of `states` for an existing hunt.

        Anything previously stored for the hunt is replaced.

        Raises:
          db.UnknownHuntError: If `hunt_id` is not a known hunt.
        """
        if hunt_id not in self.hunts:
            raise db.UnknownHuntError(hunt_id)

        serialized = [plugin_state.SerializeToString() for plugin_state in states]
        # Assignment replaces rather than extends the hunt's stored list.
        self.hunt_output_plugins_states[hunt_id] = serialized
예제 #5
0
    def ReadHuntOutputPluginsStates(self, hunt_id):
        """Returns the deserialized output plugin states stored for a hunt.

        Returns:
          A list of `OutputPluginState` objects in stored order. The list is
          empty if nothing was written for the hunt.

        Raises:
          db.UnknownHuntError: If `hunt_id` is not a known hunt.
        """
        if hunt_id not in self.hunts:
            raise db.UnknownHuntError(hunt_id)

        stored = self.hunt_output_plugins_states.get(hunt_id, [])
        return [
            rdf_flow_runner.OutputPluginState.FromSerializedString(blob)
            for blob in stored
        ]
예제 #6
0
    def ReadHuntObject(self, hunt_id, cursor=None):
        """Fetches one row from `hunts` and converts it into a hunt object.

        Raises:
          db.UnknownHuntError: If no row matches `hunt_id`.
        """
        select_sql = "SELECT {columns} FROM hunts WHERE hunt_id = %s".format(
            columns=_HUNT_COLUMNS_SELECT)
        params = [db_utils.HuntIDToInt(hunt_id)]

        matched = cursor.execute(select_sql, params)
        if matched == 0:
            raise db.UnknownHuntError(hunt_id)

        row = cursor.fetchone()
        return self._HuntObjectFromRow(row)
예제 #7
0
    def DeleteHuntObject(self, hunt_id, cursor=None):
        """Removes a hunt row and then the hunt's output plugin state rows.

        Raises:
          db.UnknownHuntError: If no row in `hunts` matched `hunt_id`.
        """
        numeric_id = db_utils.HuntIDToInt(hunt_id)

        deleted = cursor.execute("DELETE FROM hunts WHERE hunt_id = %s",
                                 [numeric_id])
        if deleted == 0:
            raise db.UnknownHuntError(hunt_id)

        # Clean up the output plugin states belonging to the removed hunt.
        cursor.execute(
            "DELETE FROM hunt_output_plugins_states WHERE hunt_id = %s",
            [numeric_id])
예제 #8
0
    def UpdateHuntOutputPluginState(self, hunt_id, state_index, update_fn):
        """Updates hunt output plugin state for a given output plugin.

        Args:
          hunt_id: Id of the hunt whose plugin state is updated.
          state_index: Position of the state in the list stored for the hunt.
          update_fn: Callable that receives the current `plugin_state` and
            returns the new one.

        Returns:
          The updated plugin state, i.e. the value returned by `update_fn`.

        Raises:
          db.UnknownHuntError: If `hunt_id` is not a known hunt.
          db.UnknownHuntOutputPluginStateError: If the hunt has no stored
            state at `state_index`.
        """
        if hunt_id not in self.hunts:
            raise db.UnknownHuntError(hunt_id)

        serialized_states = self.hunt_output_plugins_states.get(hunt_id, [])
        # Validate the index explicitly. A plain list lookup raises
        # IndexError (not KeyError) when out of range, and silently wraps
        # around for negative indices.
        if not 0 <= state_index < len(serialized_states):
            raise db.UnknownHuntOutputPluginStateError(hunt_id, state_index)

        state = rdf_flow_runner.OutputPluginState.FromSerializedString(
            serialized_states[state_index])
        state.plugin_state = update_fn(state.plugin_state)

        # `serialized_states` is the list stored for the hunt, so this updates
        # the stored entry in place.
        serialized_states[state_index] = state.SerializeToString()

        return state.plugin_state
예제 #9
0
    def ReadHuntOutputPluginsStates(self, hunt_id, cursor=None):
        """Returns every output plugin state stored for a hunt.

        If no state rows exist, the `hunts` table is consulted. An existing
        hunt yields an empty list; a missing one raises.

        Raises:
          db.UnknownHuntError: If the hunt has no state rows and no row in
            `hunts`.
        """
        numeric_id = db_utils.HuntIDToInt(hunt_id)
        states_sql = ("SELECT {columns} FROM hunt_output_plugins_states "
                      "WHERE hunt_id = %s").format(
                          columns=", ".join(_HUNT_OUTPUT_PLUGINS_STATES_COLUMNS))

        if cursor.execute(states_sql, [numeric_id]) > 0:
            return [
                self._HuntOutputPluginStateFromRow(row)
                for row in cursor.fetchall()
            ]

        # No states found: tell "hunt has none" apart from "no such hunt".
        hunt_sql = "SELECT hunt_id FROM hunts WHERE hunt_id = %s"
        if cursor.execute(hunt_sql, [numeric_id]) == 0:
            raise db.UnknownHuntError(hunt_id)

        return []
예제 #10
0
 def ReadHuntObject(self, hunt_id):
     """Returns a deep copy of the stored hunt object.

     The copy is made with `self._DeepCopy`, so callers cannot mutate the
     stored object.

     Raises:
       db.UnknownHuntError: If `hunt_id` is not a known hunt.
     """
     if hunt_id not in self.hunts:
         raise db.UnknownHuntError(hunt_id)
     return self._DeepCopy(self.hunts[hunt_id])
예제 #11
0
 def DeleteHuntObject(self, hunt_id):
     """Deletes a hunt object together with its stored output plugin states.

     Raises:
       db.UnknownHuntError: If `hunt_id` is not a known hunt.
     """
     try:
         del self.hunts[hunt_id]
     except KeyError:
         raise db.UnknownHuntError(hunt_id)

     # Also drop the hunt's output plugin states so they do not outlive the
     # hunt, or resurface if a hunt with the same id is written again. A hunt
     # may never have had states written, hence the default.
     self.hunt_output_plugins_states.pop(hunt_id, None)