Esempio n. 1
0
def _Insert(cursor, table, values):
  """Writes the given rows to a table with a single INSERT IGNORE statement.

  The keys of the first dict determine the columns (in sorted order); every
  other dict has to use exactly the same keys. An empty `values` list results
  in no statement being executed.

  Args:
    cursor: The MySQL cursor used to execute the statement.
    table: Name of the table the rows are inserted into.
    values: A list of dicts, each mapping column names to the values of a row.

  Raises:
    ValueError: If the given dicts do not all have identical keys.
  """
  precondition.AssertIterableType(values, dict)

  if not values:  # There is no row to derive the columns from.
    return

  columns = sorted(values[0])
  expected_keys = set(columns)
  for row in values:
    if set(row) != expected_keys:
      raise ValueError("Given value dictionaries must have identical keys. "
                       "Expecting columns {!r}, but got value {!r}".format(
                           columns, row))

  # The table name is substituted first, the column list and the placeholder
  # groups afterwards.
  template = "INSERT IGNORE INTO %s {cols} VALUES {vals}" % table
  query = template.format(
      cols=mysql_utils.Columns(columns),
      vals=mysql_utils.Placeholders(num=len(columns), values=len(values)))

  # Flatten the rows into one argument list, row by row, in column order.
  args = [row[column] for row in values for column in columns]

  cursor.execute(query, args)
Esempio n. 2
0
    def WriteHuntOutputPluginsStates(self, hunt_id, states, cursor=None):
        """Stores the output plugin states of a hunt, one row per state.

        Args:
          hunt_id: Id of the hunt the states belong to.
          states: An iterable of output plugin states; the position of a
            state in the iterable is stored as its plugin id.
          cursor: The MySQL cursor used to execute the statements.

        Raises:
          db.UnknownHuntError: If an insert fails with an integrity error.
        """
        column_list = ", ".join(_HUNT_OUTPUT_PLUGINS_STATES_COLUMNS)
        value_slots = mysql_utils.Placeholders(
            2 + len(_HUNT_OUTPUT_PLUGINS_STATES_COLUMNS))
        int_hunt_id = db_utils.HuntIDToInt(hunt_id)

        # The statement text is the same for every state, only the arguments
        # differ.
        query = ("INSERT INTO hunt_output_plugins_states "
                 "(hunt_id, plugin_id, {columns}) "
                 "VALUES {placeholders}".format(columns=column_list,
                                                placeholders=value_slots))

        for plugin_id, state in enumerate(states):
            descriptor = state.plugin_descriptor
            plugin_args = descriptor.plugin_args

            args = [
                int_hunt_id,
                plugin_id,
                descriptor.plugin_name,
                None if plugin_args is None else plugin_args.SerializeToBytes(),
                state.plugin_state.SerializeToBytes(),
            ]

            try:
                cursor.execute(query, args)
            except MySQLdb.IntegrityError as e:
                raise db.UnknownHuntError(hunt_id=hunt_id, cause=e)
Esempio n. 3
0
 def CheckBlobsExist(self, blob_ids, cursor=None):
   """Checks if given blobs exist.

   Args:
     blob_ids: A sized collection of blob ids to look up.
     cursor: The MySQL cursor used to execute the query.

   Returns:
     A dict mapping every given blob id to True if the `blobs` table contains
     a row for it and to False otherwise. Empty if no blob ids were given.
   """
   # With no ids the query would end in `IN ()`, which MySQL rejects as a
   # syntax error, so the database is not queried at all.
   if not blob_ids:
     return {}

   exists = {blob_id: False for blob_id in blob_ids}
   query = "SELECT blob_id FROM blobs WHERE blob_id IN {}".format(
       mysql_utils.Placeholders(len(blob_ids)))
   cursor.execute(query, [blob_id.AsBytes() for blob_id in blob_ids])
   for blob_id, in cursor.fetchall():
     exists[rdf_objects.BlobID.FromBytes(blob_id)] = True
   return exists
Esempio n. 4
0
 def ReadHashBlobReferences(self, hashes, cursor):
   """Reads blob references of a given set of hashes.

   Args:
     hashes: A sized collection of hash ids to look up.
     cursor: The MySQL cursor used to execute the query.

   Returns:
     A dict mapping every given hash id to the list of its blob references,
     or to None if the `hash_blob_references` table has no row for it. Empty
     if no hashes were given.
   """
   # With no hashes the query would end in `IN ()`, which MySQL rejects as a
   # syntax error, so the database is not queried at all.
   if not hashes:
     return {}

   query = ("SELECT hash_id, blob_references FROM hash_blob_references WHERE "
            "hash_id IN {}").format(mysql_utils.Placeholders(len(hashes)))
   cursor.execute(query, [hash_id.AsBytes() for hash_id in hashes])
   results = {hash_id: None for hash_id in hashes}
   for hash_id, blob_references in cursor.fetchall():
     sha_hash_id = rdf_objects.SHA256HashID.FromBytes(hash_id)
     refs = rdf_objects.BlobReferences.FromSerializedString(blob_references)
     results[sha_hash_id] = list(refs.items)
   return results
Esempio n. 5
0
 def ReadBlobs(self, blob_ids, cursor=None):
   """Reads given blobs.

   Args:
     blob_ids: A sized collection of blob ids to read.
     cursor: The MySQL cursor used to execute the query.

   Returns:
     A dict mapping every given blob id to its content, or to None if the
     `blobs` table has no row for it. Empty if no blob ids were given.
   """
   # With no ids the query would contain `IN ()`, which MySQL rejects as a
   # syntax error, so the database is not queried at all.
   if not blob_ids:
     return {}

   query = ("SELECT blob_id, blob_chunk FROM blobs WHERE blob_id IN {} "
            "ORDER BY blob_id, chunk_index ASC").format(
                mysql_utils.Placeholders(len(blob_ids)))
   cursor.execute(query, [blob_id.AsBytes() for blob_id in blob_ids])
   results = {blob_id: None for blob_id in blob_ids}
   # Rows arrive ordered by chunk index within each blob, so appending the
   # chunks in iteration order reassembles the blob content.
   for blob_id_bytes, blob in cursor.fetchall():
     blob_id = rdf_objects.BlobID.FromBytes(blob_id_bytes)
     if results[blob_id] is None:
       results[blob_id] = blob
     else:
       results[blob_id] += blob
   return results
Esempio n. 6
0
    def CheckBlobsExist(self, blob_ids, cursor=None):
        """Reports for each of the given blob ids whether the blob is stored.

        Args:
          blob_ids: A sized collection of blob ids to look up.
          cursor: The MySQL cursor used to execute the query.

        Returns:
          A dict mapping every given blob id to True if the `blobs` table
          contains a row for it and to False otherwise. Empty if no blob ids
          were given.
        """
        if not blob_ids:
            return {}

        # Start from "nothing exists" and flip the ids the query returns.
        found = dict.fromkeys(blob_ids, False)

        template = ("SELECT blob_id "
                    "FROM blobs "
                    "FORCE INDEX (PRIMARY) "
                    "WHERE blob_id IN {}")
        query = template.format(mysql_utils.Placeholders(len(blob_ids)))
        args = [blob_id.AsBytes() for blob_id in blob_ids]

        cursor.execute(query, args)
        for (raw_blob_id,) in cursor.fetchall():
            found[rdf_objects.BlobID.FromSerializedBytes(raw_blob_id)] = True
        return found
Esempio n. 7
0
    def UpdateUserNotifications(self,
                                username,
                                timestamps,
                                state=None,
                                cursor=None):
        """Updates existing user notification objects.

        Args:
          username: Name of the user whose notifications are updated.
          timestamps: A sized collection of timestamps identifying the
            notifications to update.
          state: The new notification state. Despite the `None` default, a
            value convertible with `int()` is required whenever there is
            at least one timestamp.
          cursor: The MySQL cursor used to execute the statement.
        """
        # With no timestamps the statement would end in `IN ()`, which MySQL
        # rejects as a syntax error; there is nothing to update anyway.
        if not timestamps:
            return

        query = ("UPDATE user_notification "
                 "SET notification_state = %s "
                 "WHERE username_hash = %s AND timestamp IN {}").format(
                     mysql_utils.Placeholders(len(timestamps)))

        args = [
            int(state),
            mysql_utils.Hash(username),
        ] + [mysql_utils.RDFDatetimeToMysqlString(t) for t in timestamps]
        cursor.execute(query, args)
Esempio n. 8
0
    def _MultiWritePathInfos(self, path_infos, cursor=None):
        """Writes a collection of path info records for specified clients.

        The path infos are first staged in the `client_path_infos` temporary
        table and then copied from there into `client_paths`,
        `client_path_stat_entries` and `client_path_hash_entries`. All
        ancestors of the written paths are upserted into `client_paths` as
        directories.

        Args:
          path_infos: A mapping from a client id to an iterable of path info
            objects to write for that client.
          cursor: The MySQL cursor on which all statements are executed.
        """
        # The `*_values` lists are flat argument lists for multi-row INSERTs;
        # the matching `*_count` is the number of rows they hold.
        path_info_count = 0
        path_info_values = []

        parent_path_info_count = 0
        parent_path_info_values = []

        # Whether any path info carried a stat / hash entry. The statements
        # that copy such entries are only issued if there is something to copy.
        has_stat_entries = False
        has_hash_entries = False

        for client_id, client_path_infos in iteritems(path_infos):
            for path_info in client_path_infos:
                path = mysql_utils.ComponentsToPath(path_info.components)

                # Values are appended in the column order of the INSERT into
                # the temporary table below (9 values per row).
                path_info_values.append(db_utils.ClientIDToInt(client_id))
                path_info_values.append(int(path_info.path_type))
                path_info_values.append(path_info.GetPathID().AsBytes())
                path_info_values.append(path)
                path_info_values.append(bool(path_info.directory))
                path_info_values.append(len(path_info.components))

                if path_info.HasField("stat_entry"):
                    path_info_values.append(
                        path_info.stat_entry.SerializeToString())
                    has_stat_entries = True
                else:
                    path_info_values.append(None)
                # A hash entry fills two columns: the serialized entry and its
                # SHA-256 digest.
                if path_info.HasField("hash_entry"):
                    path_info_values.append(
                        path_info.hash_entry.SerializeToString())
                    path_info_values.append(
                        path_info.hash_entry.sha256.AsBytes())
                    has_hash_entries = True
                else:
                    path_info_values.append(None)
                    path_info_values.append(None)

                path_info_count += 1

                # TODO(hanuszczak): Implement a trie in order to avoid inserting
                # duplicated records.
                # Ancestor rows have 5 values each; their `directory` column
                # is the literal TRUE in the statement below.
                for parent_path_info in path_info.GetAncestors():
                    path = mysql_utils.ComponentsToPath(
                        parent_path_info.components)

                    parent_path_info_values.append(
                        db_utils.ClientIDToInt(client_id))
                    parent_path_info_values.append(
                        int(parent_path_info.path_type))
                    parent_path_info_values.append(
                        parent_path_info.GetPathID().AsBytes())
                    parent_path_info_values.append(path)
                    parent_path_info_values.append(
                        len(parent_path_info.components))

                    parent_path_info_count += 1

        # NOTE(review): `mysql_utils.TemporaryTable` presumably creates the
        # table on entry and drops it on exit - confirm in mysql_utils.
        with mysql_utils.TemporaryTable(
                cursor=cursor,
                name="client_path_infos",
                columns=[
                    ("client_id", "BIGINT UNSIGNED NOT NULL"),
                    ("path_type", "INT UNSIGNED NOT NULL"),
                    ("path_id", "BINARY(32) NOT NULL"),
                    ("path", "TEXT NOT NULL"),
                    ("directory", "BOOLEAN NOT NULL"),
                    ("depth", "INT NOT NULL"),
                    ("stat_entry", "MEDIUMBLOB NULL"),
                    ("hash_entry", "MEDIUMBLOB NULL"),
                    ("sha256", "BINARY(32) NULL"),
                    ("timestamp", "TIMESTAMP(6) NOT NULL DEFAULT now(6)"),
                ]):
            if path_info_count > 0:
                # Stage all path infos with a single multi-row INSERT. The
                # `timestamp` column is not listed, so it takes its default.
                query = """
        INSERT INTO client_path_infos(client_id, path_type, path_id,
                                      path, directory, depth,
                                      stat_entry, hash_entry, sha256)
        VALUES {}
        """.format(mysql_utils.Placeholders(num=9, values=path_info_count))
                cursor.execute(query, path_info_values)

                # Upsert the staged paths: for an existing row the directory
                # flag is OR-ed with the new value and the timestamp is
                # refreshed.
                cursor.execute("""
        INSERT INTO client_paths(client_id, path_type, path_id, path,
                                 directory, depth)
             SELECT client_id, path_type, path_id, path, directory, depth
               FROM client_path_infos
        ON DUPLICATE KEY UPDATE
          client_paths.directory = (client_paths.directory OR
                                    VALUES(client_paths.directory)),
          client_paths.timestamp = now(6)
        """)

            if parent_path_info_count > 0:
                placeholders = ["(%s, %s, %s, %s, TRUE, %s)"
                                ] * parent_path_info_count

                # Upsert every ancestor as a directory.
                # NOTE(review): this statement uses `now()` whereas the other
                # statements use `now(6)` - confirm whether the lower
                # precision is intended.
                cursor.execute(
                    """
        INSERT INTO client_paths(client_id, path_type, path_id, path,
                                 directory, depth)
        VALUES {}
        ON DUPLICATE KEY UPDATE
          directory = TRUE,
          timestamp = now()
        """.format(", ".join(placeholders)), parent_path_info_values)

            if has_stat_entries:
                # Copy the staged stat entries, reusing the staged timestamp.
                cursor.execute("""
        INSERT INTO client_path_stat_entries(client_id, path_type, path_id,
                                             stat_entry, timestamp)
             SELECT client_id, path_type, path_id, stat_entry, timestamp
               FROM client_path_infos
              WHERE stat_entry IS NOT NULL
        """)

                # Record the timestamp of the copied stat entry on the path.
                cursor.execute("""
        UPDATE client_paths, client_path_infos
           SET client_paths.last_stat_entry_timestamp = client_path_infos.timestamp
         WHERE client_paths.client_id = client_path_infos.client_id
           AND client_paths.path_type = client_path_infos.path_type
           AND client_paths.path_id = client_path_infos.path_id
           AND client_path_infos.stat_entry IS NOT NULL
        """)

            if has_hash_entries:
                # Copy the staged hash entries, reusing the staged timestamp.
                cursor.execute("""
        INSERT INTO client_path_hash_entries(client_id, path_type, path_id,
                                             hash_entry, sha256, timestamp)
             SELECT client_id, path_type, path_id, hash_entry, sha256, timestamp
               FROM client_path_infos
              WHERE hash_entry IS NOT NULL
        """)

                # Record the timestamp of the copied hash entry on the path.
                cursor.execute("""
        UPDATE client_paths, client_path_infos
           SET client_paths.last_hash_entry_timestamp = client_path_infos.timestamp
         WHERE client_paths.client_id = client_path_infos.client_id
           AND client_paths.path_type = client_path_infos.path_type
           AND client_paths.path_id = client_path_infos.path_id
           AND client_path_infos.hash_entry IS NOT NULL
        """)
Esempio n. 9
0
 def testManyValues(self):
   """Several value groups are joined with a comma and a space."""
   placeholders = mysql_utils.Placeholders(3, 2)
   expected = "(%s, %s, %s), (%s, %s, %s)"
   self.assertEqual(placeholders, expected)
Esempio n. 10
0
 def testZeroValues(self):
   """Requesting zero value groups yields an empty string."""
   placeholders = mysql_utils.Placeholders(3, 0)
   expected = ""
   self.assertEqual(placeholders, expected)
Esempio n. 11
0
 def testMany(self):
   """A single group holds one placeholder per requested column."""
   placeholders = mysql_utils.Placeholders(4)
   expected = "(%s, %s, %s, %s)"
   self.assertEqual(placeholders, expected)
Esempio n. 12
0
 def testOne(self):
   """A single placeholder is still wrapped in parentheses."""
   placeholders = mysql_utils.Placeholders(1)
   expected = "(%s)"
   self.assertEqual(placeholders, expected)
Esempio n. 13
0
 def testEmpty(self):
   """Zero placeholders yield an empty pair of parentheses."""
   placeholders = mysql_utils.Placeholders(0)
   expected = "()"
   self.assertEqual(placeholders, expected)
Esempio n. 14
0
    def _MultiWritePathInfos(self, path_infos, connection=None):
        """Writes a collection of path info records for specified clients.

        The path infos are first staged in the `client_path_infos` temporary
        table and then copied from there into `client_paths`,
        `client_path_stat_entries` and `client_path_hash_entries`. All
        ancestors of the written paths are upserted into `client_paths` as
        directories. The temporary table is dropped afterwards, also when one
        of the statements fails.

        Args:
          path_infos: A mapping from a client id to an iterable of path info
            objects to write for that client.
          connection: The MySQL connection from which the cursors used to
            execute the statements are obtained.
        """
        # The `*_values` lists are flat argument lists for multi-row INSERTs;
        # the matching `*_count` is the number of rows they hold.
        path_info_count = 0
        path_info_values = []

        parent_path_info_count = 0
        parent_path_info_values = []

        # Whether any path info carried a stat / hash entry. The statements
        # that copy such entries are only issued if there is something to copy.
        has_stat_entries = False
        has_hash_entries = False

        for client_id, client_path_infos in iteritems(path_infos):
            for path_info in client_path_infos:
                path = mysql_utils.ComponentsToPath(path_info.components)

                # Values are appended in the column order of the INSERT into
                # the temporary table below (9 values per row).
                path_info_values.append(db_utils.ClientIDToInt(client_id))
                path_info_values.append(int(path_info.path_type))
                path_info_values.append(path_info.GetPathID().AsBytes())
                path_info_values.append(path)
                path_info_values.append(bool(path_info.directory))
                path_info_values.append(len(path_info.components))

                if path_info.HasField("stat_entry"):
                    path_info_values.append(
                        path_info.stat_entry.SerializeToString())
                    has_stat_entries = True
                else:
                    path_info_values.append(None)
                # A hash entry fills two columns: the serialized entry and its
                # SHA-256 digest.
                if path_info.HasField("hash_entry"):
                    path_info_values.append(
                        path_info.hash_entry.SerializeToString())
                    path_info_values.append(
                        path_info.hash_entry.sha256.AsBytes())
                    has_hash_entries = True
                else:
                    path_info_values.append(None)
                    path_info_values.append(None)

                path_info_count += 1

                # TODO(hanuszczak): Implement a trie in order to avoid inserting
                # duplicated records.
                # Ancestor rows have 5 values each; their `directory` column
                # is the literal TRUE in the statement below.
                for parent_path_info in path_info.GetAncestors():
                    path = mysql_utils.ComponentsToPath(
                        parent_path_info.components)

                    parent_path_info_values.append(
                        db_utils.ClientIDToInt(client_id))
                    parent_path_info_values.append(
                        int(parent_path_info.path_type))
                    parent_path_info_values.append(
                        parent_path_info.GetPathID().AsBytes())
                    parent_path_info_values.append(path)
                    parent_path_info_values.append(
                        len(parent_path_info.components))

                    parent_path_info_count += 1

        try:
            with contextlib.closing(connection.cursor()) as cursor:
                # Staging table for the path infos; it is dropped in the
                # `finally` clause below.
                cursor.execute("""
        CREATE TEMPORARY TABLE client_path_infos(
          client_id BIGINT UNSIGNED NOT NULL,
          path_type INT UNSIGNED NOT NULL,
          path_id BINARY(32) NOT NULL,
          path TEXT NOT NULL,
          directory BOOLEAN NOT NULL,
          depth INT NOT NULL,
          stat_entry MEDIUMBLOB NULL,
          hash_entry MEDIUMBLOB NULL,
          sha256 BINARY(32) NULL,
          timestamp TIMESTAMP(6) NOT NULL DEFAULT now(6)
        )""")

                if path_info_count > 0:
                    # Stage all path infos with a single multi-row INSERT. The
                    # `timestamp` column is not listed, so it takes its
                    # default.
                    cursor.execute(
                        """
          INSERT INTO client_path_infos(client_id, path_type, path_id,
                                        path, directory, depth,
                                        stat_entry, hash_entry, sha256)
          VALUES {}
          """.format(mysql_utils.Placeholders(num=9, values=path_info_count)),
                        path_info_values)

                    # Upsert the staged paths: for an existing row the
                    # directory flag is OR-ed with the new value and the
                    # timestamp is refreshed.
                    cursor.execute("""
          INSERT INTO client_paths(client_id, path_type, path_id, path,
                                   directory, depth)
               SELECT client_id, path_type, path_id, path, directory, depth
                 FROM client_path_infos
          ON DUPLICATE KEY UPDATE
            client_paths.directory = client_paths.directory OR VALUES(client_paths.directory),
            client_paths.timestamp = now(6)
          """)

                if parent_path_info_count > 0:
                    placeholders = ["(%s, %s, %s, %s, TRUE, %s)"
                                    ] * parent_path_info_count

                    # Upsert every ancestor as a directory.
                    # NOTE(review): this statement uses `now()` whereas the
                    # other statements use `now(6)` - confirm whether the lower
                    # precision is intended.
                    cursor.execute(
                        """
          INSERT INTO client_paths(client_id, path_type, path_id, path,
                                   directory, depth)
          VALUES {}
          ON DUPLICATE KEY UPDATE
            directory = TRUE,
            timestamp = now()
          """.format(", ".join(placeholders)), parent_path_info_values)

                if has_stat_entries:
                    # Copy the staged stat entries, reusing the staged
                    # timestamp.
                    cursor.execute("""
          INSERT INTO client_path_stat_entries(client_id, path_type, path_id,
                                               stat_entry, timestamp)
               SELECT client_id, path_type, path_id, stat_entry, timestamp
                 FROM client_path_infos
                WHERE stat_entry IS NOT NULL
          """)

                    # Record the timestamp of the copied stat entry on the
                    # path.
                    cursor.execute("""
          UPDATE client_paths, client_path_infos
             SET client_paths.last_stat_entry_timestamp = client_path_infos.timestamp
           WHERE client_paths.client_id = client_path_infos.client_id
             AND client_paths.path_type = client_path_infos.path_type
             AND client_paths.path_id = client_path_infos.path_id
             AND client_path_infos.stat_entry IS NOT NULL
          """)

                if has_hash_entries:
                    # Copy the staged hash entries, reusing the staged
                    # timestamp.
                    cursor.execute("""
          INSERT INTO client_path_hash_entries(client_id, path_type, path_id,
                                               hash_entry, sha256, timestamp)
               SELECT client_id, path_type, path_id, hash_entry, sha256, timestamp
                 FROM client_path_infos
                WHERE hash_entry IS NOT NULL
          """)

                    # Record the timestamp of the copied hash entry on the
                    # path.
                    cursor.execute("""
          UPDATE client_paths, client_path_infos
             SET client_paths.last_hash_entry_timestamp = client_path_infos.timestamp
           WHERE client_paths.client_id = client_path_infos.client_id
             AND client_paths.path_type = client_path_infos.path_type
             AND client_paths.path_id = client_path_infos.path_id
             AND client_path_infos.hash_entry IS NOT NULL
          """)
        finally:
            # Drop the temporary table in a separate cursor. This ensures that
            # even if the previous cursor.execute fails mid-way leaving the
            # temporary table created (as table creation can't be rolled back), the
            # table would still be correctly dropped.
            #
            # This is important since connections are reused in the MySQL connection
            # pool.
            with contextlib.closing(connection.cursor()) as cursor:
                cursor.execute(
                    "DROP TEMPORARY TABLE IF EXISTS client_path_infos")