def WriteFlowObject(self, flow_obj, cursor=None):
  """Inserts a flow row, or refreshes part of it if the key already exists.

  When the insert hits a duplicate key, only the `flow`,
  `next_request_to_process` and `last_update` columns are overwritten.

  Args:
    flow_obj: Flow object to store. Its client_id, flow_id, long_flow_id,
      parent_flow_id, next_request_to_process and create_time are read, and
      its serialized form goes into the `flow` column.
    cursor: Cursor the statement is executed on. NOTE(review): defaults to
      None but is used unconditionally; presumably a wrapper that is not
      visible here supplies it -- confirm.

  Raises:
    db.UnknownClientError: If the statement fails with
      MySQLdb.IntegrityError.
  """
  # A flow without a parent stores NULL in parent_flow_id.
  parent_flow_id_int = None
  if flow_obj.parent_flow_id:
    parent_flow_id_int = mysql_utils.FlowIDToInt(flow_obj.parent_flow_id)

  created_str = mysql_utils.RDFDatetimeToMysqlString(flow_obj.create_time)
  updated_str = mysql_utils.RDFDatetimeToMysqlString(
      rdfvalue.RDFDatetime.Now())

  row_values = [
      mysql_utils.ClientIDToInt(flow_obj.client_id),
      mysql_utils.FlowIDToInt(flow_obj.flow_id),
      flow_obj.long_flow_id,
      parent_flow_id_int,
      flow_obj.SerializeToString(),
      flow_obj.next_request_to_process,
      created_str,
      updated_str,
  ]

  upsert = "".join([
      "INSERT INTO flows ",
      "(client_id, flow_id, long_flow_id, parent_flow_id, flow, ",
      "next_request_to_process, timestamp, last_update) VALUES ",
      "(%s, %s, %s, %s, %s, %s, %s, %s) ",
      "ON DUPLICATE KEY UPDATE ",
      "flow=VALUES(flow), ",
      "next_request_to_process=VALUES(next_request_to_process),",
      "last_update=VALUES(last_update)",
  ])

  try:
    cursor.execute(upsert, row_values)
  except MySQLdb.IntegrityError as e:
    raise db.UnknownClientError(flow_obj.client_id, cause=e)
def WriteFlowRequests(self, requests, cursor=None):
  """Writes a list of flow requests to the database.

  Before the requests are inserted, every (client, flow) that has at least
  one request flagged `needs_processing` is looked up in `flows`. If the
  flow's `next_request_to_process` is one of the flagged request ids, a
  FlowProcessingRequest for that flow is written through
  `_WriteFlowProcessingRequests`.

  Args:
    requests: Flow requests to insert. client_id, flow_id, request_id and
      needs_processing are read from each; the serialized request is stored.
      An empty list is a no-op.
    cursor: Cursor the statements are executed on. NOTE(review): defaults to
      None but is used unconditionally; presumably a wrapper that is not
      visible here supplies it -- confirm.

  Raises:
    db.AtLeastOneUnknownFlowError: If the insert into `flow_requests` fails
      with MySQLdb.IntegrityError.
  """
  if not requests:
    # An INSERT with an empty VALUES list is not valid SQL; match the other
    # batch methods and treat "nothing to write" as a no-op.
    return

  args = []
  templates = []
  flow_keys = []
  # Maps (client_id, flow_id) -> request ids that are flagged for processing.
  needs_processing = {}

  now_str = mysql_utils.RDFDatetimeToMysqlString(rdfvalue.RDFDatetime.Now())

  for r in requests:
    flow_key = (r.client_id, r.flow_id)
    if r.needs_processing:
      needs_processing.setdefault(flow_key, []).append(r.request_id)

    flow_keys.append(flow_key)
    templates.append("(%s, %s, %s, %s, %s, %s)")
    args.extend([
        mysql_utils.ClientIDToInt(r.client_id),
        mysql_utils.FlowIDToInt(r.flow_id),
        r.request_id,
        r.needs_processing,
        r.SerializeToString(),
        now_str,
    ])

  if needs_processing:
    flow_processing_requests = []

    nr_conditions = []
    nr_args = []
    for client_id, flow_id in needs_processing:
      nr_conditions.append("(client_id=%s AND flow_id=%s)")
      nr_args.append(mysql_utils.ClientIDToInt(client_id))
      nr_args.append(mysql_utils.FlowIDToInt(flow_id))

    nr_query = ("SELECT client_id, flow_id, next_request_to_process "
                "FROM flows WHERE ")
    nr_query += " OR ".join(nr_conditions)

    cursor.execute(nr_query, nr_args)

    for client_id_int, flow_id_int, next_request_to_process in (
        cursor.fetchall()):
      client_id = mysql_utils.IntToClientID(client_id_int)
      flow_id = mysql_utils.IntToFlowID(flow_id_int)
      # Only notify when the request the flow is waiting on is among the
      # ones being written as ready for processing.
      if next_request_to_process in needs_processing[(client_id, flow_id)]:
        flow_processing_requests.append(
            rdf_flows.FlowProcessingRequest(
                client_id=client_id, flow_id=flow_id))

    if flow_processing_requests:
      self._WriteFlowProcessingRequests(flow_processing_requests, cursor)

  query = ("INSERT INTO flow_requests "
           "(client_id, flow_id, request_id, needs_processing, request, "
           "timestamp) VALUES ")
  query += ", ".join(templates)
  try:
    cursor.execute(query, args)
  except MySQLdb.IntegrityError as e:
    raise db.AtLeastOneUnknownFlowError(flow_keys, cause=e)
def _ReadCurrentFlowInfo(self, responses, currently_available_requests,
                         next_request_by_flow, responses_expected_by_request,
                         current_responses_by_request, cursor):
  """Reads stored data for the flows and requests the responses refer to.

  Nothing is returned; results are added to the containers passed in.

  Args:
    responses: Responses about to be written. client_id, flow_id and
      request_id are read from each. An empty list is a no-op.
    currently_available_requests: Set that receives a
      (client_id, flow_id, request_id) key for every matching row found in
      `flow_requests`.
    next_request_by_flow: Dict that receives
      (client_id, flow_id) -> next_request_to_process from `flows`.
    responses_expected_by_request: Dict that receives
      request key -> responses_expected, only for rows where that column is
      truthy.
    current_responses_by_request: Dict that receives
      request key -> set of response ids already in `flow_responses`.
    cursor: Cursor the three SELECTs are executed on.
  """
  # Many responses usually share a flow / request. Collect each key once so
  # the WHERE clauses scale with distinct keys, not with response count.
  # Dicts are used as sets here; the order of OR-ed conditions does not
  # affect the result.
  flow_keys = {}
  request_keys = {}
  for r in responses:
    flow_keys[(r.client_id, r.flow_id)] = None
    request_keys[(r.client_id, r.flow_id, r.request_id)] = None

  if not request_keys:
    # Without any condition the queries would end in a bare "WHERE ".
    return

  flow_args = []
  for client_id, flow_id in flow_keys:
    flow_args.append(mysql_utils.ClientIDToInt(client_id))
    flow_args.append(mysql_utils.FlowIDToInt(flow_id))

  req_args = []
  for client_id, flow_id, request_id in request_keys:
    req_args.append(mysql_utils.ClientIDToInt(client_id))
    req_args.append(mysql_utils.FlowIDToInt(flow_id))
    req_args.append(request_id)

  flow_where = " OR ".join(
      ["(client_id=%s AND flow_id=%s)"] * len(flow_keys))
  req_where = " OR ".join(
      ["(client_id=%s AND flow_id=%s AND request_id=%s)"] * len(request_keys))

  flow_query = ("SELECT client_id, flow_id, next_request_to_process "
                "FROM flows WHERE ") + flow_where
  req_query = ("SELECT client_id, flow_id, request_id, responses_expected "
               "FROM flow_requests WHERE ") + req_where
  # Requests and responses are selected by the same key triple, so the
  # response query reuses the request conditions and arguments.
  res_query = ("SELECT client_id, flow_id, request_id, response_id "
               "FROM flow_responses WHERE ") + req_where

  cursor.execute(flow_query, flow_args)
  for row in cursor.fetchall():
    client_id_int, flow_id_int, next_request_to_process = row
    client_id = mysql_utils.IntToClientID(client_id_int)
    flow_id = mysql_utils.IntToFlowID(flow_id_int)
    next_request_by_flow[(client_id, flow_id)] = next_request_to_process

  cursor.execute(req_query, req_args)
  for row in cursor.fetchall():
    client_id_int, flow_id_int, request_id, responses_expected = row
    client_id = mysql_utils.IntToClientID(client_id_int)
    flow_id = mysql_utils.IntToFlowID(flow_id_int)
    request_key = (client_id, flow_id, request_id)
    currently_available_requests.add(request_key)
    if responses_expected:
      responses_expected_by_request[request_key] = responses_expected

  cursor.execute(res_query, req_args)
  for row in cursor.fetchall():
    client_id_int, flow_id_int, request_id, response_id = row
    client_id = mysql_utils.IntToClientID(client_id_int)
    flow_id = mysql_utils.IntToFlowID(flow_id_int)
    request_key = (client_id, flow_id, request_id)
    current_responses_by_request.setdefault(request_key,
                                            set()).add(response_id)
def _WriteFlowProcessingRequests(self, requests, cursor):
  """Inserts the given flow processing requests in a single statement.

  Every request is copied and the copy is stamped with one shared "now"
  timestamp before it is serialized; the objects passed in are not modified.
  The same timestamp is stored in the `timestamp` column.

  Args:
    requests: Flow processing requests to insert. client_id, flow_id and
      delivery_time are read from each. An empty list is a no-op.
    cursor: Cursor the insert is executed on.
  """
  timestamp = rdfvalue.RDFDatetime.Now()
  timestamp_str = mysql_utils.RDFDatetimeToMysqlString(timestamp)

  templates = []
  args = []
  for original in requests:
    stamped = original.Copy()
    stamped.timestamp = timestamp

    if stamped.delivery_time:
      delivery_time_str = mysql_utils.RDFDatetimeToMysqlString(
          stamped.delivery_time)
    else:
      # No delivery time set: store NULL.
      delivery_time_str = None

    templates.append("(%s, %s, %s, %s, %s)")
    args.extend([
        mysql_utils.ClientIDToInt(stamped.client_id),
        mysql_utils.FlowIDToInt(stamped.flow_id),
        timestamp_str,
        stamped.SerializeToString(),
        delivery_time_str,
    ])

  if not templates:
    # An INSERT with an empty VALUES list is not valid SQL.
    return

  query = ("INSERT INTO flow_processing_requests "
           "(client_id, flow_id, timestamp, request, delivery_time) VALUES ")
  query += ", ".join(templates)
  cursor.execute(query, args)
def ReadChildFlowObjects(self, client_id, flow_id, cursor=None):
  """Reads the flows whose parent is the given flow.

  Args:
    client_id: Client the parent flow belongs to.
    flow_id: Id of the parent flow; matched against `parent_flow_id`.
    cursor: Cursor the query is executed on. NOTE(review): defaults to None
      but is used unconditionally; presumably a wrapper that is not visible
      here supplies it -- confirm.

  Returns:
    A list with one flow object (built by `_FlowObjectFromRow`) per matching
    row.
  """
  # NOTE(review): no separator is inserted before "FROM", so this relies on
  # FLOW_DB_FIELDS ending in whitespace -- confirm against its definition.
  select_children = "".join([
      "SELECT ",
      self.FLOW_DB_FIELDS,
      "FROM flows WHERE client_id=%s AND parent_flow_id=%s",
  ])
  key_args = [
      mysql_utils.ClientIDToInt(client_id),
      mysql_utils.FlowIDToInt(flow_id),
  ]
  cursor.execute(select_children, key_args)

  children = []
  for row in cursor.fetchall():
    children.append(self._FlowObjectFromRow(row))
  return children
def ReadFlowForProcessing(self, client_id, flow_id, processing_time,
                          cursor=None):
  """Marks a flow as being processed on this worker and returns it.

  Args:
    client_id: Client the flow belongs to.
    flow_id: Id of the flow to claim.
    processing_time: Duration added to the current time to obtain the new
      processing deadline.
    cursor: Cursor the statements are executed on. NOTE(review): defaults to
      None but is used unconditionally; presumably a wrapper that is not
      visible here supplies it -- confirm.

  Returns:
    The flow object, with processing_on, processing_since and
    processing_deadline set to the values just written to the database.

  Raises:
    db.UnknownFlowError: If no row exists for (client_id, flow_id).
    ValueError: If the flow has `processing_on` set and its processing
      deadline is still in the future.
  """
  query = ("SELECT " + self.FLOW_DB_FIELDS +
           "FROM flows WHERE client_id=%s AND flow_id=%s")
  cursor.execute(query, [
      mysql_utils.ClientIDToInt(client_id),
      mysql_utils.FlowIDToInt(flow_id)
  ])
  response = cursor.fetchall()
  if not response:
    raise db.UnknownFlowError(client_id, flow_id)

  # Exactly one row is expected; the unpacking fails otherwise.
  row, = response
  rdf_flow = self._FlowObjectFromRow(row)

  now = rdfvalue.RDFDatetime.Now()
  if rdf_flow.processing_on and rdf_flow.processing_deadline > now:
    # The placeholders are "Flow ... on client ...", so flow_id comes first.
    raise ValueError("Flow %s on client %s is already being processed." %
                     (flow_id, client_id))

  update_query = ("UPDATE flows SET processing_on=%s, processing_since=%s, "
                  "processing_deadline=%s WHERE client_id=%s and flow_id=%s")
  processing_deadline = now + processing_time
  process_id_string = utils.ProcessIdString()

  args = [
      process_id_string,
      mysql_utils.RDFDatetimeToMysqlString(now),
      mysql_utils.RDFDatetimeToMysqlString(processing_deadline),
      mysql_utils.ClientIDToInt(client_id),
      mysql_utils.FlowIDToInt(flow_id)
  ]
  cursor.execute(update_query, args)

  # This needs to happen after we are sure that the write has succeeded.
  rdf_flow.processing_on = process_id_string
  rdf_flow.processing_since = now
  rdf_flow.processing_deadline = processing_deadline
  return rdf_flow
def DeleteAllFlowRequestsAndResponses(self, client_id, flow_id, cursor=None):
  """Removes every stored response and request of one flow.

  Args:
    client_id: Client the flow belongs to.
    flow_id: Id of the flow whose requests and responses are deleted.
    cursor: Cursor the statements are executed on. NOTE(review): defaults to
      None but is used unconditionally; presumably a wrapper that is not
      visible here supplies it -- confirm.
  """
  flow_key_args = [
      mysql_utils.ClientIDToInt(client_id),
      mysql_utils.FlowIDToInt(flow_id),
  ]
  # Responses are deleted before the requests they belong to.
  for table in ("flow_responses", "flow_requests"):
    delete_stmt = "DELETE FROM " + table + " WHERE client_id=%s AND flow_id=%s"
    cursor.execute(delete_stmt, flow_key_args)
def _UpdateExpected(self, requests, value_dict, cursor):
  """Stores the expected response count for each of the given requests.

  Args:
    requests: Iterable of (client_id, flow_id, request_id) triples.
    value_dict: Maps each such triple to the value written to
      `responses_expected`.
    cursor: Cursor the updates are executed on; one UPDATE per request.
  """
  update_stmt = ("UPDATE flow_requests SET responses_expected=%s "
                 "WHERE client_id=%s AND flow_id=%s AND request_id=%s")
  for client_id, flow_id, request_id in requests:
    expected = value_dict[(client_id, flow_id, request_id)]
    cursor.execute(update_stmt, [
        expected,
        mysql_utils.ClientIDToInt(client_id),
        mysql_utils.FlowIDToInt(flow_id),
        request_id,
    ])
def _UpdateNeedsProcessing(self, requests, cursor):
  """Sets needs_processing=TRUE on the given requests in one statement.

  Args:
    requests: Iterable of (client_id, flow_id, request_id) triples. If it
      yields nothing, no statement is executed.
    cursor: Cursor the update is executed on.
  """
  conditions = []
  args = []
  for client_id, flow_id, request_id in requests:
    conditions.append("(client_id=%s AND flow_id=%s AND request_id=%s)")
    args.append(mysql_utils.ClientIDToInt(client_id))
    args.append(mysql_utils.FlowIDToInt(flow_id))
    args.append(request_id)

  if not conditions:
    # An UPDATE ending in a bare WHERE is not valid SQL.
    return

  query = "UPDATE flow_requests SET needs_processing=TRUE WHERE "
  query += " OR ".join(conditions)
  cursor.execute(query, args)
def ReadFlowRequestsReadyForProcessing(self, client_id, flow_id,
                                       next_needed_request, cursor=None):
  """Reads the consecutive run of processable requests of a flow.

  Args:
    client_id: Client the flow belongs to.
    flow_id: Id of the flow.
    next_needed_request: Request id at which the run starts.
    cursor: Cursor the queries are executed on. NOTE(review): defaults to
      None but is used unconditionally; presumably a wrapper that is not
      visible here supplies it -- confirm.

  Returns:
    A dict mapping request_id -> (request, responses), where responses is
    sorted by response_id. It covers next_needed_request and the request ids
    directly following it, stopping at the first id for which no request
    with needs_processing set is stored.
  """
  flow_key_args = [
      mysql_utils.ClientIDToInt(client_id),
      mysql_utils.FlowIDToInt(flow_id),
  ]

  cursor.execute(
      "SELECT request, needs_processing, timestamp FROM flow_requests "
      "WHERE client_id=%s AND flow_id=%s", flow_key_args)

  # Only requests flagged as needing processing are kept.
  ready_requests = {}
  for serialized_request, flagged, stored_ts in cursor.fetchall():
    if flagged:
      request = rdf_flow_objects.FlowRequest.FromSerializedString(
          serialized_request)
      request.needs_processing = flagged
      request.timestamp = mysql_utils.MysqlToRDFDatetime(stored_ts)
      ready_requests[request.request_id] = request

  cursor.execute(
      "SELECT response, status, iterator, timestamp FROM flow_responses "
      "WHERE client_id=%s AND flow_id=%s", flow_key_args)

  responses_by_request = {}
  for serialized_response, status, iterator, stored_ts in cursor.fetchall():
    # A row is decoded from the first non-empty of status, iterator and
    # response, in that order.
    if status:
      response = rdf_flow_objects.FlowStatus.FromSerializedString(status)
    elif iterator:
      response = rdf_flow_objects.FlowIterator.FromSerializedString(iterator)
    else:
      response = rdf_flow_objects.FlowResponse.FromSerializedString(
          serialized_response)
    response.timestamp = mysql_utils.MysqlToRDFDatetime(stored_ts)
    responses_by_request.setdefault(response.request_id, []).append(response)

  result = {}
  request_id = next_needed_request
  while request_id in ready_requests:
    request = ready_requests[request_id]
    ordered = list(responses_by_request.get(request_id, []))
    ordered.sort(key=lambda r: r.response_id)
    result[request.request_id] = (request, ordered)
    request_id += 1

  return result
def _UpdateCombined(self, requests, value_dict, cursor):
  """Sets both responses_expected and needs_processing on the requests.

  Args:
    requests: Iterable of (client_id, flow_id, request_id) triples.
    value_dict: Maps each such triple to the value written to
      `responses_expected`.
    cursor: Cursor the updates are executed on; one UPDATE per request.
  """
  update_stmt = ("UPDATE flow_requests SET responses_expected=%s, "
                 "needs_processing=TRUE "
                 "WHERE client_id=%s AND flow_id=%s AND request_id=%s")
  for client_id, flow_id, request_id in requests:
    expected = value_dict[(client_id, flow_id, request_id)]
    cursor.execute(update_stmt, [
        expected,
        mysql_utils.ClientIDToInt(client_id),
        mysql_utils.FlowIDToInt(flow_id),
        request_id,
    ])
def ReadFlowObject(self, client_id, flow_id, cursor=None):
  """Fetches a single flow object.

  Args:
    client_id: Client the flow belongs to.
    flow_id: Id of the flow to read.
    cursor: Cursor the query is executed on. NOTE(review): defaults to None
      but is used unconditionally; presumably a wrapper that is not visible
      here supplies it -- confirm.

  Returns:
    The flow object built by `_FlowObjectFromRow` from the matching row.

  Raises:
    db.UnknownFlowError: If no row exists for (client_id, flow_id).
  """
  # NOTE(review): no separator is inserted before "FROM", so this relies on
  # FLOW_DB_FIELDS ending in whitespace -- confirm against its definition.
  select_flow = "".join([
      "SELECT ",
      self.FLOW_DB_FIELDS,
      "FROM flows WHERE client_id=%s AND flow_id=%s",
  ])
  key_args = [
      mysql_utils.ClientIDToInt(client_id),
      mysql_utils.FlowIDToInt(flow_id),
  ]
  cursor.execute(select_flow, key_args)

  rows = cursor.fetchall()
  if rows:
    # Exactly one row is expected; the unpacking fails otherwise.
    (only_row,) = rows
    return self._FlowObjectFromRow(only_row)
  raise db.UnknownFlowError(client_id, flow_id)
def UpdateFlow(self,
               client_id,
               flow_id,
               flow_obj=db.Database.unchanged,
               flow_state=db.Database.unchanged,
               client_crash_info=db.Database.unchanged,
               pending_termination=db.Database.unchanged,
               processing_on=db.Database.unchanged,
               processing_since=db.Database.unchanged,
               processing_deadline=db.Database.unchanged,
               cursor=None):
  """Updates selected columns of one flow row.

  Every argument left at `db.Database.unchanged` is skipped. If all of them
  are left unchanged, nothing is executed.

  Args:
    client_id: Client the flow belongs to.
    flow_id: Id of the flow to update.
    flow_obj: New flow object; written serialized to `flow`, and its
      flow_state is written to `flow_state`.
    flow_state: New value for `flow_state` (stored as int). When given
      together with flow_obj, this assignment comes later in the SET list.
    client_crash_info: Stored serialized in `client_crash_info`.
    pending_termination: Stored serialized in `pending_termination`.
    processing_on: Stored as-is in `processing_on`.
    processing_since: Datetime stored in `processing_since`.
    processing_deadline: Datetime stored in `processing_deadline`.
    cursor: Cursor the update is executed on. NOTE(review): defaults to None
      but is used unconditionally; presumably a wrapper that is not visible
      here supplies it -- confirm.

  Raises:
    db.UnknownFlowError: If the UPDATE reports zero affected rows.
  """
  # (column, value) pairs in the order they appear in the SET clause.
  assignments = []

  if flow_obj != db.Database.unchanged:
    assignments.append(("flow", flow_obj.SerializeToString()))
    assignments.append(("flow_state", int(flow_obj.flow_state)))
  if flow_state != db.Database.unchanged:
    assignments.append(("flow_state", int(flow_state)))
  if client_crash_info != db.Database.unchanged:
    assignments.append(
        ("client_crash_info", client_crash_info.SerializeToString()))
  if pending_termination != db.Database.unchanged:
    assignments.append(
        ("pending_termination", pending_termination.SerializeToString()))
  if processing_on != db.Database.unchanged:
    assignments.append(("processing_on", processing_on))
  if processing_since != db.Database.unchanged:
    assignments.append(
        ("processing_since",
         mysql_utils.RDFDatetimeToMysqlString(processing_since)))
  if processing_deadline != db.Database.unchanged:
    assignments.append(
        ("processing_deadline",
         mysql_utils.RDFDatetimeToMysqlString(processing_deadline)))

  if not assignments:
    return

  set_clause = ", ".join([column + "=%s" for column, _ in assignments])
  args = [value for _, value in assignments]
  args.append(mysql_utils.ClientIDToInt(client_id))
  args.append(mysql_utils.FlowIDToInt(flow_id))

  query = "UPDATE flows SET " + set_clause + " WHERE client_id=%s AND flow_id=%s"

  # NOTE(review): this treats the return value of execute() as the number of
  # affected rows. If the connection counts only rows whose values actually
  # changed, re-writing identical values would also hit the error below --
  # confirm the connection flags.
  updated = cursor.execute(query, args)
  if updated == 0:
    raise db.UnknownFlowError(client_id, flow_id)
def ReturnProcessedFlow(self, flow_obj, cursor=None):
  """Hands a flow back to the database after a worker processed it.

  Args:
    flow_obj: The flow that was being processed. On success its
      processing_on, processing_since and processing_deadline attributes are
      reset to None.
    cursor: Cursor the statements are executed on. NOTE(review): defaults to
      None but is used unconditionally; presumably a wrapper that is not
      visible here supplies it -- confirm.

  Returns:
    False, without writing anything, if the stored request with id
    `flow_obj.next_request_to_process` has needs_processing set. True once
    the flow has been written back with its processing columns cleared.
  """
  client_id_int = mysql_utils.ClientIDToInt(flow_obj.client_id)
  flow_id_int = mysql_utils.FlowIDToInt(flow_obj.flow_id)

  cursor.execute(
      "SELECT needs_processing FROM flow_requests "
      "WHERE client_id=%s AND flow_id=%s AND request_id=%s",
      [client_id_int, flow_id_int, flow_obj.next_request_to_process])

  # The next request is already waiting to be processed, so the flow is not
  # returned.
  if any(row[0] for row in cursor.fetchall()):
    return False

  # Serialize a copy with the processing fields cleared, so the caller's
  # object is only touched once the write went through.
  released = flow_obj.Copy()
  released.processing_on = None
  released.processing_since = None
  released.processing_deadline = None

  now_str = mysql_utils.RDFDatetimeToMysqlString(rdfvalue.RDFDatetime.Now())

  write_back = ("UPDATE flows SET flow=%s, processing_on=%s, "
                "processing_since=%s, processing_deadline=%s, "
                "next_request_to_process=%s, last_update=%s "
                "WHERE client_id=%s AND flow_id=%s")
  cursor.execute(write_back, [
      released.SerializeToString(),
      None,
      None,
      None,
      flow_obj.next_request_to_process,
      now_str,
      mysql_utils.ClientIDToInt(flow_obj.client_id),
      mysql_utils.FlowIDToInt(flow_obj.flow_id),
  ])

  # This needs to happen after we are sure that the write has succeeded.
  for field in ("processing_on", "processing_since", "processing_deadline"):
    setattr(flow_obj, field, None)
  return True
def AckFlowProcessingRequests(self, requests, cursor=None):
  """Removes acknowledged flow processing requests from the database.

  Rows are matched on client id, flow id and timestamp.

  Args:
    requests: Flow processing requests to delete. Nothing is executed when
      this is empty.
    cursor: Cursor the delete is executed on. NOTE(review): defaults to None
      but is used unconditionally; presumably a wrapper that is not visible
      here supplies it -- confirm.
  """
  if not requests:
    return

  per_request = "(client_id=%s AND flow_id=%s AND timestamp=%s)"
  args = []
  for r in requests:
    args.extend([
        mysql_utils.ClientIDToInt(r.client_id),
        mysql_utils.FlowIDToInt(r.flow_id),
        mysql_utils.RDFDatetimeToMysqlString(r.timestamp),
    ])

  # Three arguments were added per request, so this yields one condition
  # per request.
  where_clause = " OR ".join([per_request] * (len(args) // 3))
  cursor.execute("DELETE FROM flow_processing_requests WHERE " + where_clause,
                 args)
def DeleteFlowRequests(self, requests, cursor=None):
  """Removes the given flow requests together with their responses.

  Args:
    requests: Flow requests to delete; client_id, flow_id and request_id are
      read from each. Nothing is executed when this is empty.
    cursor: Cursor the deletes are executed on. NOTE(review): defaults to
      None but is used unconditionally; presumably a wrapper that is not
      visible here supplies it -- confirm.
  """
  if not requests:
    return

  per_request = "(client_id=%s AND flow_id=%s AND request_id=%s)"
  args = []
  for r in requests:
    args.extend([
        mysql_utils.ClientIDToInt(r.client_id),
        mysql_utils.FlowIDToInt(r.flow_id),
        r.request_id,
    ])

  # Three arguments were added per request, so this yields one condition
  # per request.
  where_clause = " OR ".join([per_request] * (len(args) // 3))

  # Responses are deleted before the requests they belong to.
  cursor.execute("DELETE FROM flow_responses WHERE " + where_clause, args)
  cursor.execute("DELETE FROM flow_requests WHERE " + where_clause, args)
def ReadAllFlowRequestsAndResponses(self, client_id, flow_id, cursor=None):
  """Reads every request of a flow along with the responses stored for it.

  Args:
    client_id: Client the flow belongs to.
    flow_id: Id of the flow.
    cursor: Cursor the queries are executed on. NOTE(review): defaults to
      None but is used unconditionally; presumably a wrapper that is not
      visible here supplies it -- confirm.

  Returns:
    A list of (request, responses) tuples ordered by request_id, where
    responses is a dict mapping response_id -> response (empty when nothing
    is stored for that request).
  """
  flow_key_args = [
      mysql_utils.ClientIDToInt(client_id),
      mysql_utils.FlowIDToInt(flow_id),
  ]

  cursor.execute(
      "SELECT request, needs_processing, responses_expected, timestamp "
      "FROM flow_requests WHERE client_id=%s AND flow_id=%s", flow_key_args)

  stored_requests = []
  for serialized_request, flagged, expected, stored_ts in cursor.fetchall():
    request = rdf_flow_objects.FlowRequest.FromSerializedString(
        serialized_request)
    # These three values live in their own columns and are copied onto the
    # deserialized object.
    request.needs_processing = flagged
    request.nr_responses_expected = expected
    request.timestamp = mysql_utils.MysqlToRDFDatetime(stored_ts)
    stored_requests.append(request)

  cursor.execute(
      "SELECT response, status, iterator, timestamp "
      "FROM flow_responses WHERE client_id=%s AND flow_id=%s", flow_key_args)

  responses_by_request = {}
  for serialized_response, status, iterator, stored_ts in cursor.fetchall():
    # A row is decoded from the first non-empty of status, iterator and
    # response, in that order.
    if status:
      response = rdf_flow_objects.FlowStatus.FromSerializedString(status)
    elif iterator:
      response = rdf_flow_objects.FlowIterator.FromSerializedString(iterator)
    else:
      response = rdf_flow_objects.FlowResponse.FromSerializedString(
          serialized_response)
    response.timestamp = mysql_utils.MysqlToRDFDatetime(stored_ts)
    by_id = responses_by_request.setdefault(response.request_id, {})
    by_id[response.response_id] = response

  ordered_requests = sorted(stored_requests, key=lambda r: r.request_id)
  return [(request, responses_by_request.get(request.request_id, {}))
          for request in ordered_requests]
def UpdateFlows(self,
                client_id_flow_id_pairs,
                pending_termination=db.Database.unchanged,
                cursor=None):
  """Sets pending_termination on several flows with a single statement.

  Args:
    client_id_flow_id_pairs: Iterable of (client_id, flow_id) pairs naming
      the flows to update. If it yields nothing, no statement is executed.
    pending_termination: Value stored (serialized) in the
      `pending_termination` column. When left at `db.Database.unchanged`
      the call is a no-op.
    cursor: Cursor the update is executed on. NOTE(review): defaults to None
      but is used unconditionally; presumably a wrapper that is not visible
      here supplies it -- confirm.
  """
  if pending_termination == db.Database.unchanged:
    return

  serialized_termination = pending_termination.SerializeToString()

  conditions = []
  args = [serialized_termination]
  for client_id, flow_id in client_id_flow_id_pairs:
    # Parenthesized so the grouping does not depend on AND/OR precedence.
    conditions.append("(client_id=%s AND flow_id=%s)")
    args.append(mysql_utils.ClientIDToInt(client_id))
    args.append(mysql_utils.FlowIDToInt(flow_id))

  if not conditions:
    # An UPDATE ending in a bare WHERE is not valid SQL.
    return

  query = "UPDATE flows SET pending_termination=%s WHERE "
  query += " OR ".join(conditions)
  cursor.execute(query, args)
def _WriteResponses(self, responses, timestamp_str, cursor):
  """Inserts the given responses into `flow_responses` in one statement.

  The statement is INSERT IGNORE. Each response is serialized into exactly
  one of the `response`, `status` or `iterator` columns according to its
  type; the other two receive an empty string.

  Args:
    responses: FlowResponse, FlowStatus or FlowIterator objects. If empty,
      no statement is executed.
    timestamp_str: Value stored in the `timestamp` column of every row.
    cursor: Cursor the insert is executed on.

  Raises:
    ValueError: If a response is none of the three supported types.
  """
  templates = []
  args = []
  for r in responses:
    if isinstance(r, rdf_flow_objects.FlowResponse):
      payload_columns = [r.SerializeToString(), "", ""]
    elif isinstance(r, rdf_flow_objects.FlowStatus):
      payload_columns = ["", r.SerializeToString(), ""]
    elif isinstance(r, rdf_flow_objects.FlowIterator):
      payload_columns = ["", "", r.SerializeToString()]
    else:
      # This can't really happen due to db api type checking.
      raise ValueError("Got unexpected response type: %s %s" % (type(r), r))

    templates.append("(%s, %s, %s, %s, %s, %s, %s, %s)")
    args.append(mysql_utils.ClientIDToInt(r.client_id))
    args.append(mysql_utils.FlowIDToInt(r.flow_id))
    args.append(r.request_id)
    args.append(r.response_id)
    args.extend(payload_columns)
    args.append(timestamp_str)

  if not templates:
    # An INSERT with an empty VALUES list is not valid SQL.
    return

  query = ("INSERT IGNORE INTO flow_responses "
           "(client_id, flow_id, request_id, response_id, "
           "response, status, iterator, timestamp) VALUES ")
  query += ",".join(templates)
  cursor.execute(query, args)