示例#1
0
def ProcessMol(mol,typeConversions,globalProps,nDone,nameProp='_Name',nameCol='compound_id',
               redraw=False,keepHs=False,
               skipProps=False,addComputedProps=False,
               skipSmiles=False,
               uniqNames=None,namesSeen=None):
  """ converts a molecule into a row of values ready for storage

    **Arguments**

      - mol: the molecule to process

      - typeConversions: indexable collection whose entries carry a
        conversion callable at position 1; for each property the entry
        with the highest index that accepts the value is used

      - globalProps: dict mapping property name -> index of the entry in
        _typeConversions_ currently used for it (unseen properties start
        at 2); updated in place

      - nDone: counter used to build the fallback name 'Mol_<nDone>'

      - nameProp: molecule property that holds the compound name

      - nameCol: a property whose name matches this (ignoring case) is
        not added to the property dictionary

      - redraw: if set, 2D coordinates are recomputed for the molecule

      - keepHs: if set, Chem.SanitizeMol() is called on the molecule

      - skipProps: if set, no properties are collected

      - addComputedProps: if set, donor/acceptor/rotatable-bond counts,
        molecular weight and logP are stored on the molecule as properties
        before the properties are collected

      - skipSmiles: if set, no SMILES entry is added to the row

      - uniqNames: if set, a molecule whose name is already in
        _namesSeen_ is skipped

      - namesSeen: set of names already processed; updated in place.  If
        omitted, a fresh set is used for this call only.

    **Returns**

      - a list: [name, SMILES (unless skipSmiles), binary molecule,
        property dict], or None if the molecule was skipped as a duplicate

    **Raises**

      - ValueError if _mol_ is false

  """
  if not mol:
    raise ValueError('no molecule')
  # the default of None used to blow up on the membership test / add() below
  if namesSeen is None:
    namesSeen = set()
  if keepHs:
    Chem.SanitizeMol(mol)
  try:
    nm = mol.GetProp(nameProp)
  except KeyError:
    nm = None
  if not nm:
    nm = 'Mol_%d'%nDone
  if uniqNames and nm in namesSeen:
    logger.error('duplicate compound id (%s) encountered. second instance skipped.'%nm)
    return None
  namesSeen.add(nm)
  row = [nm]
  pD = {}
  if not skipProps:
    if addComputedProps:
      mol.SetProp('DonorCount',str(Lipinski.NumHDonors(mol)))
      mol.SetProp('AcceptorCount',str(Lipinski.NumHAcceptors(mol)))
      mol.SetProp('RotatableBondCount',str(Lipinski.NumRotatableBonds(mol)))
      mol.SetProp('AMW',str(Descriptors.MolWt(mol)))
      mol.SetProp('MolLogP',str(Crippen.MolLogP(mol)))

    for pn in list(mol.GetPropNames()):
      if pn.lower()==nameCol.lower():
        continue
      pv = mol.GetProp(pn).strip()
      if '>' in pv or '<' in pv:
        # values carrying a qualifier are stored untouched
        pD[pn] = pv
        continue
      # walk down from the current type until a conversion accepts the value
      colTyp = globalProps.get(pn,2)
      while colTyp>0:
        try:
          converted = typeConversions[colTyp][1](pv)
        except Exception:
          colTyp -= 1
        else:
          break
      else:
        # nothing above index 0 worked (or we started at <=0): use that entry
        converted = typeConversions[colTyp][1](pv)
      globalProps[pn] = colTyp
      pD[pn] = converted
  if redraw:
    # the original passed an undefined name here instead of the molecule
    AllChem.Compute2DCoords(mol)
  if not skipSmiles:
    row.append(Chem.MolToSmiles(mol,True))
  row.append(DbModule.binaryHolder(mol.ToBinary()))
  row.append(pD)
  return row
示例#2
0
文件: DbInfo.py 项目: abradle/rdkit
def GetDbNames(user='******', password='******', dirName='.', dBase='::template1', cn=None):
  """ returns a list of databases that are available

    **Arguments**

      - user: the username for DB access

      - password: the password to be used for DB access

      - dirName: directory searched when the DB backend is file based
        (i.e. DbModule.fileWildcard is set)

      - dBase: the database used to open the connection; it is left out
        of the returned list

      - cn: an optional, already open connection to use instead of
        opening a new one

    **Returns**

      - a list of db names (strings); empty if the connection cannot
        be opened

  """
  if DbModule.getDbSql:
    if not cn:
      try:
        cn = DbModule.connect(dBase, user, password)
      except Exception:
        print('Problems opening database: %s' % (dBase))
        return []
    c = cn.cursor()
    c.execute(DbModule.getDbSql)
    names = ['::' + str(x[0]) for x in c.fetchall()]
    # only drop the connection database if the server actually listed it;
    # an unconditional remove() raises ValueError otherwise
    if dBase in names:
      names.remove(dBase)
  elif DbModule.fileWildcard:
    import os.path, glob
    names = glob.glob(os.path.join(dirName, DbModule.fileWildcard))
  else:
    names = []
  return names
示例#3
0
文件: DbInfo.py 项目: abradle/rdkit
def GetTableNames(dBase, user='******', password='******', includeViews=0, cn=None):
  """ lists the tables found in a database

    **Arguments**

      - dBase: the name of the DB file to be used

      - user: the username for DB access

      - password: the password to be used for DB access

      - includeViews: when true, views are listed along with the tables

      - cn: an optional, already open connection to use

    **Returns**

      - a list of upper-cased table names (strings); empty if the
        connection cannot be opened

  """
  if not cn:
    try:
      cn = DbModule.connect(dBase, user, password)
    except Exception:
      print('Problems opening database: %s' % (dBase))
      return []
  cursor = cn.cursor()
  query = DbModule.getTablesAndViewsSql if includeViews else DbModule.getTablesSql
  cursor.execute(query)
  tableNames = []
  for record in cursor.fetchall():
    tableNames.append(str(record[0]).upper())
  # this entry is filtered out of the listing on PostgreSQL
  if RDConfig.usePgSQL and 'PG_LOGDIR_LS' in tableNames:
    tableNames.remove('PG_LOGDIR_LS')
  return tableNames
示例#4
0
文件: DbInfo.py 项目: abradle/rdkit
def GetColumnNamesAndTypes(dBase, table, user='******', password='******', join='', what='*',
                           cn=None):
  """ gets a list of columns available in a DB table along with their types

    **Arguments**

      - dBase: the name of the DB file to be used

      - table: the name of the table to query

      - user: the username for DB access

      - password: the password to be used for DB access

      - join: an optional join clause; the verb 'join' may be omitted

      - what: an optional clause indicating what to select

      - cn: an optional, already open connection to use

    **Returns**

      - a list of 2-tuples containing:

          1) column name

          2) column type

  """
  if not cn:
    cn = DbModule.connect(dBase, user, password)
  c = cn.cursor()
  cmd = 'select %s from %s' % (what, table)
  if join:
    # same handling as GetColumnNames(): only add the verb when the caller
    # left it out, so a clause starting with 'join' is not doubled up
    if join.strip().find('join') != 0:
      join = 'join %s' % (join)
    cmd += ' ' + join
  c.execute(cmd)
  return GetColumnInfoFromCursor(c)
示例#5
0
文件: DbInfo.py 项目: mivicms/clusfps
def GetDbNames(user='******',password='******',dirName='.',dBase='::template1',cn=None):
  """ returns a list of databases that are available

    **Arguments**

      - user: the username for DB access

      - password: the password to be used for DB access

      - dirName: directory searched when the DB backend is file based
        (i.e. DbModule.fileWildcard is set)

      - dBase: the database used to open the connection; it is left out
        of the returned list

      - cn: an optional, already open connection to use instead of
        opening a new one

    **Returns**

      - a list of db names (strings); empty if the connection cannot
        be opened

  """
  if DbModule.getDbSql:
    if not cn:
      try:
        cn = DbModule.connect(dBase,user,password)
      except Exception:
        # a bare except here would also swallow KeyboardInterrupt/SystemExit
        print('Problems opening database: %s'%(dBase))
        return []
    c = cn.cursor()
    c.execute(DbModule.getDbSql)
    names = ['::'+str(x[0]) for x in c.fetchall()]
    # only drop the connection database if the server actually listed it;
    # an unconditional remove() raises ValueError otherwise
    if dBase in names:
      names.remove(dBase)
  elif DbModule.fileWildcard:
    import os.path,glob
    names = glob.glob(os.path.join(dirName,DbModule.fileWildcard))
  else:
    names = []
  return names
示例#6
0
文件: DbInfo.py 项目: abradle/rdkit
def GetColumnNames(dBase, table, user='******', password='******', join='', what='*', cn=None):
  """ reports the names of the columns produced by a select on a table

    **Arguments**

      - dBase: the name of the DB file to be used

      - table: the name of the table to query

      - user: the username for DB access

      - password: the password to be used for DB access

      - join: an optional join clause; the verb 'join' is added when the
        clause does not start with it

      - what: an optional clause indicating what to select

      - cn: an optional, already open connection to use

    **Returns**

      -  a list of column names (strings), taken from the cursor description

  """
  connection = cn if cn else DbModule.connect(dBase, user, password)
  cursor = connection.cursor()
  pieces = ['select %s from %s' % (what, table)]
  if join:
    if not join.strip().startswith('join'):
      join = 'join %s' % (join)
    pieces.append(join)
  cursor.execute(' '.join(pieces))
  cursor.fetchone()
  return [str(column[0]) for column in cursor.description]
示例#7
0
文件: DbUtils.py 项目: rwest/rdkit
def GetColumns(dBase,table,fieldString,user='******',password='******',
               join='',cn=None):
  """ fetches the requested fields from every row of a table

    **Arguments**

     - dBase: database name

     - table: table name

     - fieldString: comma delimited string naming the fields to extract

     - user and password: credentials for DB access

     - join: an optional join clause; the verb 'join' is added when the
       clause does not start with it

     - cn: an optional, already open connection to use

    **Returns**

     - the result of fetchall() on the query: a list of rows

  """
  connection = cn if cn else DbModule.connect(dBase,user,password)
  cursor = connection.cursor()
  query = 'select %s from %s'%(fieldString,table)
  if join:
    if not join.strip().startswith('join'):
      join = 'join %s'%(join)
    query = query + ' ' + join
  cursor.execute(query)
  rows = cursor.fetchall()
  return rows
示例#8
0
def GetAtomicData(atomDict, descriptorsDesired, dBase=_atomDbName, table='atomic_data', where='',
                  user='******', password='******', includeElCounts=0):
    """ pulls atomic data from a database

      **Arguments**

        - atomDict: the dictionary to populate; it gains one entry per
          row, keyed by the row's NAME value

        - descriptorsDesired: the descriptors to pull for each atom
          (upper-cased before use; the caller's list is not modified)

        - dBase: the DB to use

        - table: the DB table to use

        - where: the SQL where clause

        - user: the user name to use with the DB

        - password: the password to use with the DB

        - includeElCounts: if nonzero, valence electron count fields are added to
           the _atomDict_

      **Returns**

        - None; results are stored in _atomDict_.  If the query fails a
          message is printed and _atomDict_ is left untouched.

    """
    # these are derived from CONFIG below rather than read from the table
    extraFields = ['NVAL', 'NVAL_NO_FULL_F', 'NVAL_NO_FULL_D', 'NVAL_NO_FULL']
    from rdkit.Dbase import DbModule
    cn = DbModule.connect(dBase, user, password)
    c = cn.cursor()
    descriptorsDesired = [s.upper() for s in descriptorsDesired]
    if 'NAME' not in descriptorsDesired:
        descriptorsDesired.append('NAME')
    if includeElCounts and 'CONFIG' not in descriptorsDesired:
        descriptorsDesired.append('CONFIG')
    # drop every occurrence of a derived field, not just the first one
    descriptorsDesired = [d for d in descriptorsDesired if d not in extraFields]
    toPull = ','.join(descriptorsDesired)
    # honour the _table_ argument; the table name used to be hard-coded
    command = 'select %s from %s %s' % (toPull, table, where)
    try:
        c.execute(command)
    except Exception:
        print('Problems executing command:', command)
        return
    for atom in c.fetchall():
        tDict = dict(zip(descriptorsDesired, atom))
        name = tDict['NAME']
        atomDict[name] = tDict
        if includeElCounts:
            config = tDict['CONFIG']
            tDict['NVAL'] = ConfigToNumElectrons(config)
            tDict['NVAL_NO_FULL_F'] = ConfigToNumElectrons(config, ignoreFullF=1)
            tDict['NVAL_NO_FULL_D'] = ConfigToNumElectrons(config, ignoreFullD=1)
            tDict['NVAL_NO_FULL'] = ConfigToNumElectrons(
                config, ignoreFullF=1, ignoreFullD=1)
示例#9
0
文件: wayne.py 项目: aytsai/ricebowl
  def test9(self):
    " substructure counts "
    # uses assertEqual/assertRaises: the failUnless* spellings are deprecated
    # aliases that newer unittest versions no longer provide
    curs = self.conn.GetCursor()

    # counts with both arguments given as text
    res = curs.execute("SELECT rd_substructcount('O','OCCC(=O)O')").fetchone()
    self.assertEqual(res[0], 3)
    res = curs.execute("SELECT rd_substructcount('N','OCCC(=O)O')").fetchone()
    self.assertEqual(res[0], 0)
    res = curs.execute("SELECT rd_substructcount('[O,S]','SCCC(=O)O')").fetchone()
    self.assertEqual(res[0], 3)
    res = curs.execute("SELECT rd_substructcount('[O,S]','OCCC(=O)O')").fetchone()
    self.assertEqual(res[0], 3)

    # bad input in either position must raise DataError.  A cursor is
    # requested again after each expected failure (presumably the failed
    # statement leaves the previous one unusable -- TODO confirm).
    self.assertRaises(DataError, curs.execute,
                      "SELECT rd_substructcount('QcC','c1ccccc1C');")
    curs = self.conn.GetCursor()
    self.assertRaises(DataError, curs.execute,
                      "SELECT rd_substructcount('QcC','c1ccccc1C');")
    curs = self.conn.GetCursor()
    self.assertRaises(DataError, curs.execute,
                      "SELECT rd_substructcount('cC','Qc1ccccc1C');")
    curs = self.conn.GetCursor()
    self.assertRaises(DataError, curs.execute,
                      "SELECT rd_substructcount('nC','c1ccccn1C');")
    curs = self.conn.GetCursor()

    # the same counts with both arguments passed as binary molecules
    pkl1 = DbModule.binaryHolder(Chem.MolFromSmiles('O').ToBinary())
    pkl2 = DbModule.binaryHolder(Chem.MolFromSmiles('OCCC(=O)O').ToBinary())
    cmd = "SELECT rd_substructcount(cast (%s as bytea),cast (%s as bytea))"
    res = curs.execute(cmd,(pkl1,pkl2)).fetchone()
    self.assertEqual(res[0], 3)
    pkl1 = DbModule.binaryHolder(Chem.MolFromSmiles('N').ToBinary())
    res = curs.execute(cmd,(pkl1,pkl2)).fetchone()
    self.assertEqual(res[0], 0)

    # empty binary arguments must raise DataError as well
    pkl1 = DbModule.binaryHolder(Chem.MolFromSmiles('O').ToBinary())
    curs = self.conn.GetCursor()
    self.assertRaises(DataError, curs.execute, cmd, ('',pkl2))
    curs = self.conn.GetCursor()
    self.assertRaises(DataError, curs.execute, cmd, (pkl1,''))
    curs = self.conn.GetCursor()
    self.assertRaises(DataError, curs.execute, cmd, ('',''))
示例#10
0
def GetAtomicData(atomDict, descriptorsDesired, dBase=_atomDbName, table='atomic_data', where='',
                  user='******', password='******', includeElCounts=0):
  """ pulls atomic data from a database

    **Arguments**

      - atomDict: the dictionary to populate; it gains one entry per
        row, keyed by the row's NAME value

      - descriptorsDesired: the descriptors to pull for each atom
        (upper-cased before use; the caller's list is not modified)

      - dBase: the DB to use

      - table: the DB table to use

      - where: the SQL where clause

      - user: the user name to use with the DB

      - password: the password to use with the DB

      - includeElCounts: if nonzero, valence electron count fields are added to
         the _atomDict_

    **Returns**

      - None; results are stored in _atomDict_.  If the query fails a
        message is printed and _atomDict_ is left untouched.

  """
  # these are derived from CONFIG below rather than read from the table
  extraFields = ['NVAL', 'NVAL_NO_FULL_F', 'NVAL_NO_FULL_D', 'NVAL_NO_FULL']
  from rdkit.Dbase import DbModule
  cn = DbModule.connect(dBase, user, password)
  c = cn.cursor()
  descriptorsDesired = [s.upper() for s in descriptorsDesired]
  if 'NAME' not in descriptorsDesired:
    descriptorsDesired.append('NAME')
  if includeElCounts and 'CONFIG' not in descriptorsDesired:
    descriptorsDesired.append('CONFIG')
  # drop every occurrence of a derived field, not just the first one
  descriptorsDesired = [d for d in descriptorsDesired if d not in extraFields]
  toPull = ','.join(descriptorsDesired)
  # honour the _table_ argument; the table name used to be hard-coded
  command = 'select %s from %s %s' % (toPull, table, where)
  try:
    c.execute(command)
  except Exception:
    print('Problems executing command:', command)
    return
  # pairing names with values via zip() also removes the Python-2-only xrange
  for atom in c.fetchall():
    tDict = dict(zip(descriptorsDesired, atom))
    name = tDict['NAME']
    atomDict[name] = tDict
    if includeElCounts:
      config = tDict['CONFIG']
      tDict['NVAL'] = ConfigToNumElectrons(config)
      tDict['NVAL_NO_FULL_F'] = ConfigToNumElectrons(config, ignoreFullF=1)
      tDict['NVAL_NO_FULL_D'] = ConfigToNumElectrons(config, ignoreFullD=1)
      tDict['NVAL_NO_FULL'] = ConfigToNumElectrons(config, ignoreFullF=1, ignoreFullD=1)
示例#11
0
文件: DbUtils.py 项目: rwest/rdkit
def DatabaseToText(dBase,table,fields='*',join='',where='',
                  user='******',password='******',delim=',',cn=None):
  """ Pulls the contents of a database and makes a delimited text file from them

    **Arguments**
      - dBase: the name of the DB file to be used

      - table: the name of the table to query

      - fields: the fields to select with the SQL query

      - join: the join clause of the SQL query
        (e.g. 'join foo on foo.bar=base.bar'); the verb 'join' is added
        if the clause does not contain it

      - where: the where clause of the SQL query
        (e.g. 'where foo = 2' or 'where bar > 17.6'); the keyword 'where'
        is added if the clause does not start with it

      - user: the username for DB access

      - password: the password to be used for DB access

      - delim: the delimiter placed between values on a line

      - cn: an optional, already open connection to use

    **Returns**

      - the delimited data (as text): a header line with the column names
        followed by one line per row.  Columns whose type is listed in
        DbInfo.sqlBinTypes are left out.

  """
  # The keyword has to lead the clause; looking for it anywhere (as was done
  # before) skipped the prefix for conditions that merely contain 'where',
  # e.g. a sub-select.  Whitespace-only clauses are left alone.
  if where.strip() and not where.strip().lower().startswith('where'):
    where = 'where %s'%(where)
  # 'join' may legitimately follow a qualifier ('left join ...'), so it is
  # searched for anywhere in the clause, ignoring case.
  if join.strip() and 'join' not in join.lower():
    join = 'join %s'%(join)
  sqlCommand = 'select %s from %s %s %s'%(fields,table,join,where)
  if not cn:
    cn = DbModule.connect(dBase,user,password)
  c = cn.cursor()
  c.execute(sqlCommand)
  headers = []
  colsToTake = []
  # the description field of the cursor carries around info about the columns
  #  of the table: item[0] is the column name, item[1] its type
  for i,item in enumerate(c.description):
    if item[1] not in DbInfo.sqlBinTypes:
      colsToTake.append(i)
      headers.append(item[0])

  lines = [delim.join(headers)]

  # grab the data
  for res in c.fetchall():
    # presumably _take picks the entries of res at the listed positions
    d = _take(res,colsToTake)
    lines.append(delim.join(map(str,d)))

  return '\n'.join(lines)
示例#12
0
  def GetCursor(self):
    """ hands back the cursor used for direct manipulation of the DB

      The cursor is created (together with the connection) on first use
      and then reused, so only one cursor is available.

    """
    if self.cursor is None:
      self.cn = DbModule.connect(self.dbName,self.user,self.password)
      self.cursor = self.cn.cursor()
    return self.cursor
示例#13
0
    def GetCursor(self):
        """ hands back the cursor used for direct manipulation of the DB

          The cursor is created (together with the connection) on first use
          and then reused, so only one cursor is available.

        """
        if self.cursor is None:
            self.cn = DbModule.connect(self.dbName, self.user, self.password)
            self.cursor = self.cn.cursor()
        return self.cursor
示例#14
0
文件: DbUtils.py 项目: rwest/rdkit
def _AddDataToDb(dBase,table,user,password,colDefs,colTypes,data,
                 nullMarker=None,blockSize=100,cn=None):
  """ *For Internal Use*

    (drops and) creates a table and then inserts the values

    **Arguments**

      - dBase, user, password: used to open a connection when _cn_ is
        not provided

      - table: name of the table to (re)create

      - colDefs: column definitions, inserted verbatim into the
        'create table' statement

      - colTypes: per-column sequence whose first element is the Python
        type of the column; float and int columns are converted, anything
        else is stored as a string

      - data: sequence of rows to insert

      - nullMarker: if not None, values equal to this are stored as NULL

      - blockSize: number of rows handed to each insert call

      - cn: an optional, already open connection to use

    **Notes**

      - if the table cannot be created the problem is printed and
        nothing is inserted

  """
  if not cn:
    cn = DbModule.connect(dBase,user,password)
  c = cn.cursor()
  try:
    c.execute('drop table %s'%(table))
  except Exception:
    # best effort only (e.g. the table may not exist yet)
    print('cannot drop table %s'%(table))
  # built outside the try so it is always available to the error message
  sqlStr = 'create table %s (%s)'%(table,colDefs)
  try:
    c.execute(sqlStr)
  except Exception:
    print('create table failed: ', sqlStr)
    print('here is the exception:')
    import traceback
    traceback.print_exc()
    return
  cn.commit()
  c = None

  # nothing to insert: the (empty) table has been created and committed
  if len(data) == 0:
    return

  block = []
  entryTxt = [DbModule.placeHolder]*len(data[0])
  dStr = ','.join(entryTxt)
  sqlStr = 'insert into %s values (%s)'%(table,dStr)
  for row in data:
    entries = [None]*len(row)
    for col,val in enumerate(row):
      if val is not None and \
         (nullMarker is None or val != nullMarker):
        # the builtin types are what types.FloatType / types.IntType named
        # on Python 2, and they also exist on Python 3
        if colTypes[col][0] == float:
          entries[col] = float(val)
        elif colTypes[col][0] == int:
          entries[col] = int(val)
        else:
          entries[col] = str(val)
    block.append(tuple(entries))
    if len(block)>=blockSize:
      _insertBlock(cn,sqlStr,block)
      if not hasattr(cn,'autocommit') or not cn.autocommit:
        cn.commit()
      block = []
  if len(block):
    _insertBlock(cn,sqlStr,block)
  if not hasattr(cn,'autocommit') or not cn.autocommit:
    cn.commit()
示例#15
0
def GetColumnNames(dBase,
                   table,
                   user='******',
                   password='******',
                   join='',
                   what='*',
                   cn=None):
    """ reports the names of the columns produced by a select on a table

      **Arguments**

        - dBase: the name of the DB file to be used

        - table: the name of the table to query

        - user: the username for DB access

        - password: the password to be used for DB access

        - join: an optional join clause; the verb 'join' is added when the
          clause does not start with it

        - what: an optional clause indicating what to select

        - cn: an optional, already open connection to use

      **Returns**

        -  a list of column names (strings), taken from the cursor description

    """
    connection = cn if cn else DbModule.connect(dBase, user, password)
    cursor = connection.cursor()
    pieces = ['select %s from %s' % (what, table)]
    if join:
        if not join.strip().startswith('join'):
            join = 'join %s' % (join)
        pieces.append(join)
    cursor.execute(' '.join(pieces))
    cursor.fetchone()
    return [str(column[0]) for column in cursor.description]
示例#16
0
def GetColumnNamesAndTypes(dBase,
                           table,
                           user='******',
                           password='******',
                           join='',
                           what='*',
                           cn=None):
    """ gets a list of columns available in a DB table along with their types

      **Arguments**

        - dBase: the name of the DB file to be used

        - table: the name of the table to query

        - user: the username for DB access

        - password: the password to be used for DB access

        - join: an optional join clause; the verb 'join' may be omitted

        - what: an optional clause indicating what to select

        - cn: an optional, already open connection to use

      **Returns**

        - a list of 2-tuples containing:

            1) column name

            2) column type

    """
    if not cn:
        cn = DbModule.connect(dBase, user, password)
    c = cn.cursor()
    cmd = 'select %s from %s' % (what, table)
    if join:
        # same handling as GetColumnNames(): only add the verb when the
        # caller left it out, so a clause starting with 'join' is not doubled
        if join.strip().find('join') != 0:
            join = 'join %s' % (join)
        cmd += ' ' + join
    c.execute(cmd)
    return GetColumnInfoFromCursor(c)
示例#17
0
def GetTableNames(dBase,
                  user='******',
                  password='******',
                  includeViews=0,
                  cn=None):
    """ lists the tables found in a database

      **Arguments**

        - dBase: the name of the DB file to be used

        - user: the username for DB access

        - password: the password to be used for DB access

        - includeViews: when true, views are listed along with the tables

        - cn: an optional, already open connection to use

      **Returns**

        - a list of upper-cased table names (strings); empty if the
          connection cannot be opened

    """
    if not cn:
        try:
            cn = DbModule.connect(dBase, user, password)
        except Exception:
            print('Problems opening database: %s' % (dBase))
            return []

    cursor = cn.cursor()
    query = DbModule.getTablesAndViewsSql if includeViews else DbModule.getTablesSql
    cursor.execute(query)
    tableNames = []
    for record in cursor.fetchall():
        tableNames.append(str(record[0]).upper())
    # this entry is filtered out of the listing on PostgreSQL
    if RDConfig.usePgSQL and 'PG_LOGDIR_LS' in tableNames:
        tableNames.remove('PG_LOGDIR_LS')
    return tableNames
示例#18
0
def FingerprintsFromDetails(details, reportFreq=10):
    """ generates fingerprints for the molecules described by _details_

      **Arguments**

        - details: an object whose attributes select the input and output.
          The input is picked in this order:

            1) dbName and tableName set: ids and SMILES are read from
               the database table

            2) inFileName and useSmiles set: ids and SMILES are read from
               a text file

            3) inFileName and useSD set: molecules are read from an SD file

          If outFileName is set the fingerprints are pickled to that file;
          if outTableName and a database name (outDbName or dbName) are
          set they are also stored in that table.  details.idName is
          filled in when it is empty.

        - reportFreq: while reading an SD file a message is emitted every
          _reportFreq_ molecules (values <= 0 disable this)

      **Returns**

        - the fingerprints, or None if no input could be read

    """
    data = None
    if details.dbName and details.tableName:
        from rdkit.Dbase.DbConnection import DbConnect
        from rdkit.Dbase import DbInfo
        from rdkit.ML.Data import DataUtils
        # the connection object itself is not needed in this branch; the
        # attempt is only made so that problems are reported early
        try:
            DbConnect(details.dbName, details.tableName)
        except Exception:
            import traceback
            error('Problems establishing connection to database: %s|%s\n' %
                  (details.dbName, details.tableName))
            traceback.print_exc()
        if not details.idName:
            details.idName = DbInfo.GetColumnNames(details.dbName,
                                                   details.tableName)[0]
        dataSet = DataUtils.DBToData(details.dbName,
                                     details.tableName,
                                     what='%s,%s' %
                                     (details.idName, details.smilesName))
        idCol = 0
        smiCol = 1
    elif details.inFileName and details.useSmiles:
        from rdkit.ML.Data import DataUtils
        if not details.idName:
            details.idName = 'ID'
        try:
            dataSet = DataUtils.TextFileToData(
                details.inFileName,
                onlyCols=[details.idName, details.smilesName])
        except IOError:
            import traceback
            error('Problems reading from file %s\n' % (details.inFileName))
            traceback.print_exc()
            # without this the checks further down hit an unbound local
            dataSet = None

        idCol = 0
        smiCol = 1
    elif details.inFileName and details.useSD:
        if not details.idName:
            details.idName = 'ID'
        dataSet = []
        try:
            s = Chem.SDMolSupplier(details.inFileName)
        except Exception:
            import traceback
            error('Problems reading from file %s\n' % (details.inFileName))
            traceback.print_exc()
        else:
            # plain iteration replaces the Python-2-only s.next() calls
            for m in s:
                if not m:
                    continue
                dataSet.append(m)
                if reportFreq > 0 and not len(dataSet) % reportFreq:
                    message('Read %d molecules\n' % (len(dataSet)))
                # checked for every molecule; it used to be nested in the
                # report condition and so only took effect on report steps
                if details.maxMols > 0 and len(dataSet) >= details.maxMols:
                    break

        # pair each molecule with its name: the id property when present,
        # the '_Name' property otherwise
        for i, mol in enumerate(dataSet):
            if mol.HasProp(details.idName):
                nm = mol.GetProp(details.idName)
            else:
                nm = mol.GetProp('_Name')
            dataSet[i] = (nm, mol)
    else:
        dataSet = None

    fps = None
    if dataSet and not details.useSD:
        data = dataSet.GetNamedData()
        if not details.molPklName:
            fps = FingerprintsFromSmiles(data, idCol, smiCol,
                                         **details.__dict__)
        else:
            fps = FingerprintsFromPickles(data, idCol, smiCol,
                                          **details.__dict__)
    elif dataSet and details.useSD:
        fps = FingerprintsFromMols(dataSet, **details.__dict__)

    if fps:
        if details.outFileName:
            # the with block closes the file even if pickling fails
            with open(details.outFileName, 'wb+') as outF:
                for entry in fps:
                    pickle.dump(entry, outF)
        dbName = details.outDbName or details.dbName
        if details.outTableName and dbName:
            from rdkit.Dbase.DbConnection import DbConnect
            from rdkit.Dbase import DbUtils, DbModule
            conn = DbConnect(dbName)
            #
            #  We don't have a db open already, so we'll need to figure out
            #    the types of our columns...
            #
            # NOTE(review): data is only filled in for database / SMILES
            #  input; with SD input it is still None at this point -- confirm
            #  how the id column type should be determined in that case.
            colTypes = DbUtils.TypeFinder(data, len(data), len(data[0]))
            typeStrs = DbUtils.GetTypeStrings(
                [details.idName, details.smilesName],
                colTypes,
                keyCol=details.idName)
            cols = '%s, %s %s' % (typeStrs[0], details.fpColName,
                                  DbModule.binaryTypeName)

            # FIX: we should really check to see if the table
            #  is already there and, if so, add the appropriate
            #  column.

            #
            # create the new table
            #
            if details.replaceTable or \
               details.outTableName.upper() not in [x.upper() for x in conn.GetTableNames()]:
                conn.AddTable(details.outTableName, cols)

            #
            # And add the data
            #
            for ID, fp in fps:
                tpl = ID, DbModule.binaryHolder(fp.ToBinary())
                conn.InsertData(details.outTableName, tpl)
            conn.Commit()
    return fps
示例#19
0
def RunOnData(details, data, progressCallback=None, saveIt=1, setDescNames=0):
    """ builds a composite model from a data set, screens it and stores results

      The whole run is configured through attributes of _details_ (model-type
      flags such as useTrees/useKNN, split and filter settings, output names,
      etc.).

      **Arguments**

        - details: the run-configuration object.  It is also modified: the
          pickled composite is stored in details.model and
          details.internalHoldoutFrac is zeroed when only one model is built.

        - data: the data set; must provide GetNamedData(), GetNVars(),
          GetNPossibleVals(), GetVarNames() and (for detailed results)
          GetNPts()

        - progressCallback: passed through to composite.Grow() by the
          tree-based builders

        - saveIt: if true, the composite is pickled to details.outName

        - setDescNames: if true, details._descNames are used as the
          descriptor names and the data's variable names as the input order

      **Returns**

        the composite that was built

      **Notes**

        - screening statistics are written to the _runDetails object that is
          resolved outside this function, not to _details_

    """
    if details.lockRandom:
        seed = details.randomSeed
    else:
        import random
        # integer bounds: randint() rejects float bounds on Python >= 3.12
        seed = (random.randint(0, 1000000), random.randint(0, 1000000))
    DataUtils.InitRandomNumbers(seed)
    if details.shuffleActivities == 1:
        DataUtils.RandomizeActivities(data, shuffle=1, runDetails=details)
    elif details.randomActivities == 1:
        DataUtils.RandomizeActivities(data, shuffle=0, runDetails=details)

    namedExamples = data.GetNamedData()
    if details.splitRun == 1:
        trainIdx, testIdx = SplitData.SplitIndices(len(namedExamples),
                                                   details.splitFrac,
                                                   silent=not _verbose)

        trainExamples = [namedExamples[x] for x in trainIdx]
        testExamples = [namedExamples[x] for x in testIdx]
    else:
        testExamples = []
        testIdx = []
        trainIdx = list(range(len(namedExamples)))
        trainExamples = namedExamples

    if details.filterFrac != 0.0:
        # if we're doing quantization on the fly, we need to handle that here:
        if hasattr(details, 'activityBounds') and details.activityBounds:
            tExamples = []
            bounds = details.activityBounds
            for pt in trainExamples:
                # work on a copy so the real training examples keep their
                # unquantized activity value
                pt = pt[:]
                act = pt[-1]
                placed = 0
                bound = 0
                while not placed and bound < len(bounds):
                    if act < bounds[bound]:
                        pt[-1] = bound
                        placed = 1
                    else:
                        bound += 1
                if not placed:
                    pt[-1] = bound
                tExamples.append(pt)
        else:
            bounds = None
            tExamples = trainExamples
        trainIdx, temp = DataUtils.FilterData(tExamples,
                                              details.filterVal,
                                              details.filterFrac,
                                              -1,
                                              indicesOnly=1)
        # the examples that were filtered out are moved to the test set
        tmp = [trainExamples[x] for x in trainIdx]
        testExamples += [trainExamples[x] for x in temp]
        trainExamples = tmp

        for label, examples in (('training', trainExamples),
                                ('test', testExamples)):
            counts = DataUtils.CountResults(examples, bounds=bounds)
            message('Result Counts in %s set:' % label)
            # sorted() works on both lists and dict views of keys
            for k in sorted(counts.keys()):
                message(str((k, counts[k])))
    nExamples = len(trainExamples)
    message('Training with %d examples' % (nExamples))

    nVars = data.GetNVars()
    # needs to be a real list: entries are removed below
    attrs = list(range(1, nVars + 1))
    nPossibleVals = data.GetNPossibleVals()
    for i in range(1, len(nPossibleVals)):
        if nPossibleVals[i - 1] == -1:
            attrs.remove(i)

    if details.pickleDataFileName != '':
        with open(details.pickleDataFileName, 'wb+') as pickleDataFile:
            cPickle.dump(trainExamples, pickleDataFile)
            cPickle.dump(testExamples, pickleDataFile)

    if details.bayesModel:
        composite = BayesComposite.BayesComposite()
    else:
        composite = Composite.Composite()

    # record the run parameters on the composite itself
    composite._randomSeed = seed
    composite._splitFrac = details.splitFrac
    composite._shuffleActivities = details.shuffleActivities
    composite._randomizeActivities = details.randomActivities

    if hasattr(details, 'filterFrac'):
        composite._filterFrac = details.filterFrac
    if hasattr(details, 'filterVal'):
        composite._filterVal = details.filterVal

    composite.SetModelFilterData(details.modelFilterFrac,
                                 details.modelFilterVal)

    composite.SetActivityQuantBounds(details.activityBounds)
    nPossibleVals = data.GetNPossibleVals()
    if details.activityBounds:
        nPossibleVals[-1] = len(details.activityBounds) + 1

    if setDescNames:
        composite.SetInputOrder(data.GetVarNames())
        composite.SetDescriptorNames(details._descNames)
    else:
        composite.SetDescriptorNames(data.GetVarNames())
    composite.SetActivityQuantBounds(details.activityBounds)
    if details.nModels == 1:
        details.internalHoldoutFrac = 0.0
    if details.useTrees:
        from rdkit.ML.DecTree import CrossValidate, PruneTree
        if details.qBounds != []:
            from rdkit.ML.DecTree import BuildQuantTree
            builder = BuildQuantTree.QuantTreeBoot
        else:
            from rdkit.ML.DecTree import ID3
            builder = ID3.ID3Boot
        driver = CrossValidate.CrossValidationDriver
        pruner = PruneTree.PruneTree

        composite.SetQuantBounds(details.qBounds)
        nPossibleVals = data.GetNPossibleVals()
        if details.activityBounds:
            nPossibleVals[-1] = len(details.activityBounds) + 1
        composite.Grow(trainExamples,
                       attrs,
                       nPossibleVals=[0] + nPossibleVals,
                       buildDriver=driver,
                       pruner=pruner,
                       nTries=details.nModels,
                       pruneIt=details.pruneIt,
                       lessGreedy=details.lessGreedy,
                       needsQuantization=0,
                       treeBuilder=builder,
                       nQuantBounds=details.qBounds,
                       startAt=details.startAt,
                       maxDepth=details.limitDepth,
                       progressCallback=progressCallback,
                       holdOutFrac=details.internalHoldoutFrac,
                       replacementSelection=details.replacementSelection,
                       recycleVars=details.recycleVars,
                       randomDescriptors=details.randomDescriptors,
                       silent=not _verbose)

    elif details.useSigTrees:
        from rdkit.ML.DecTree import CrossValidate
        from rdkit.ML.DecTree import BuildSigTree
        builder = BuildSigTree.SigTreeBuilder
        driver = CrossValidate.CrossValidationDriver
        nPossibleVals = data.GetNPossibleVals()
        if details.activityBounds:
            nPossibleVals[-1] = len(details.activityBounds) + 1
        # these settings are optional on the details object
        biasList = getattr(details, 'sigTreeBiasList', None)
        useCMIM = getattr(details, 'useCMIM', 0)
        allowCollections = getattr(details, 'allowCollections', False)
        composite.Grow(trainExamples,
                       attrs,
                       nPossibleVals=[0] + nPossibleVals,
                       buildDriver=driver,
                       nTries=details.nModels,
                       needsQuantization=0,
                       treeBuilder=builder,
                       maxDepth=details.limitDepth,
                       progressCallback=progressCallback,
                       holdOutFrac=details.internalHoldoutFrac,
                       replacementSelection=details.replacementSelection,
                       recycleVars=details.recycleVars,
                       randomDescriptors=details.randomDescriptors,
                       biasList=biasList,
                       useCMIM=useCMIM,
                       allowCollection=allowCollections,
                       silent=not _verbose)

    elif details.useKNN:
        from rdkit.ML.KNN import CrossValidate
        from rdkit.ML.KNN import DistFunctions

        driver = CrossValidate.CrossValidationDriver
        dfunc = ''
        if (details.knnDistFunc == "Euclidean"):
            dfunc = DistFunctions.EuclideanDist
        elif (details.knnDistFunc == "Tanimoto"):
            dfunc = DistFunctions.TanimotoDist
        else:
            assert 0, "Bad KNN distance metric value"

        composite.Grow(trainExamples,
                       attrs,
                       nPossibleVals=[0] + nPossibleVals,
                       buildDriver=driver,
                       nTries=details.nModels,
                       needsQuantization=0,
                       numNeigh=details.knnNeighs,
                       holdOutFrac=details.internalHoldoutFrac,
                       distFunc=dfunc)

    elif details.useNaiveBayes or details.useSigBayes:
        from rdkit.ML.NaiveBayes import CrossValidate
        driver = CrossValidate.CrossValidationDriver
        if not (hasattr(details, 'useSigBayes') and details.useSigBayes):
            composite.Grow(trainExamples,
                           attrs,
                           nPossibleVals=[0] + nPossibleVals,
                           buildDriver=driver,
                           nTries=details.nModels,
                           needsQuantization=0,
                           nQuantBounds=details.qBounds,
                           holdOutFrac=details.internalHoldoutFrac,
                           replacementSelection=details.replacementSelection,
                           mEstimateVal=details.mEstimateVal,
                           silent=not _verbose)
        else:
            useCMIM = getattr(details, 'useCMIM', 0)

            composite.Grow(trainExamples,
                           attrs,
                           nPossibleVals=[0] + nPossibleVals,
                           buildDriver=driver,
                           nTries=details.nModels,
                           needsQuantization=0,
                           nQuantBounds=details.qBounds,
                           mEstimateVal=details.mEstimateVal,
                           useSigs=True,
                           useCMIM=useCMIM,
                           holdOutFrac=details.internalHoldoutFrac,
                           replacementSelection=details.replacementSelection,
                           silent=not _verbose)


##   elif details.useSVM:
##     from rdkit.ML.SVM import CrossValidate
##     driver = CrossValidate.CrossValidationDriver
##     composite.Grow(trainExamples, attrs, nPossibleVals=[0]+nPossibleVals,
##                    buildDriver=driver, nTries=details.nModels,
##                    needsQuantization=0,
##                    cost=details.svmCost,gamma=details.svmGamma,
##                    weights=details.svmWeights,degree=details.svmDegree,
##                    type=details.svmType,kernelType=details.svmKernel,
##                    coef0=details.svmCoeff,eps=details.svmEps,nu=details.svmNu,
##                    cache_size=details.svmCache,shrinking=details.svmShrink,
##                    dataType=details.svmDataType,
##                    holdOutFrac=details.internalHoldoutFrac,
##                    replacementSelection=details.replacementSelection,
##                    silent=not _verbose)

    else:
        # fallback when no other model type was selected
        from rdkit.ML.Neural import CrossValidate
        driver = CrossValidate.CrossValidationDriver
        composite.Grow(trainExamples,
                       attrs, [0] + nPossibleVals,
                       nTries=details.nModels,
                       buildDriver=driver,
                       needsQuantization=0)

    composite.AverageErrors()
    composite.SortModels()
    modelList, counts, avgErrs = composite.GetAllData()
    counts = numpy.array(counts)
    avgErrs = numpy.array(avgErrs)
    composite._varNames = data.GetVarNames()

    for model in modelList:
        model.NameModel(composite._varNames)

    # do final statistics
    weightedErrs = counts * avgErrs
    averageErr = sum(weightedErrs) / sum(counts)
    devs = (avgErrs - averageErr)
    devs = devs * counts
    devs = numpy.sqrt(devs * devs)
    avgDev = sum(devs) / sum(counts)
    message('# Overall Average Error: %%% 5.2f, Average Deviation: %%% 6.2f' %
            (100. * averageErr, 100. * avgDev))

    if details.bayesModel:
        composite.Train(trainExamples, verbose=0)

    # blow out the saved examples and then save the composite:
    composite.ClearModelExamples()
    if saveIt:
        composite.Pickle(details.outName)
    details.model = DbModule.binaryHolder(cPickle.dumps(composite))

    badExamples = []
    if not details.detailedRes and (not hasattr(details, 'noScreen')
                                    or not details.noScreen):
        if details.splitRun:
            message('Testing all hold-out examples')
            wrong = testall(composite, testExamples, badExamples)
            message('%d examples (%% %5.2f) were misclassified' %
                    (len(wrong),
                     100. * float(len(wrong)) / float(len(testExamples))))
            _runDetails.holdout_error = float(len(wrong)) / len(testExamples)
        else:
            message('Testing all examples')
            wrong = testall(composite, namedExamples, badExamples)
            message('%d examples (%% %5.2f) were misclassified' %
                    (len(wrong),
                     100. * float(len(wrong)) / float(len(namedExamples))))
            _runDetails.overall_error = float(len(wrong)) / len(namedExamples)

    if details.detailedRes:
        message('\nEntire data set:')
        resTup = ScreenComposite.ShowVoteResults(range(data.GetNPts()), data,
                                                 composite, nPossibleVals[-1],
                                                 details.threshold)
        nGood, nBad, nSkip, avgGood, avgBad, avgSkip, voteTab = resTup
        nPts = len(namedExamples)
        nClass = nGood + nBad
        _runDetails.overall_error = float(nBad) / nClass
        _runDetails.overall_correct_conf = avgGood
        _runDetails.overall_incorrect_conf = avgBad
        _runDetails.overall_result_matrix = repr(voteTab)
        # NOTE(review): nClass - nPts can only be positive if more points
        #  were classified than exist; the operands look swapped - confirm
        #  before relying on overall_fraction_dropped
        nRej = nClass - nPts
        if nRej > 0:
            _runDetails.overall_fraction_dropped = float(nRej) / nPts

        if details.splitRun:
            message('\nHold-out data:')
            resTup = ScreenComposite.ShowVoteResults(range(len(testExamples)),
                                                     testExamples, composite,
                                                     nPossibleVals[-1],
                                                     details.threshold)
            nGood, nBad, nSkip, avgGood, avgBad, avgSkip, voteTab = resTup
            nPts = len(testExamples)
            nClass = nGood + nBad
            _runDetails.holdout_error = float(nBad) / nClass
            _runDetails.holdout_correct_conf = avgGood
            _runDetails.holdout_incorrect_conf = avgBad
            _runDetails.holdout_result_matrix = repr(voteTab)
            # NOTE(review): same suspicious sign as for the full data set
            nRej = nClass - nPts
            if nRej > 0:
                _runDetails.holdout_fraction_dropped = float(nRej) / nPts

    if details.persistTblName and details.dbName:
        message('Updating results table %s:%s' %
                (details.dbName, details.persistTblName))
        details.Store(db=details.dbName, table=details.persistTblName)

    if details.badName != '':
        with open(details.badName, 'w+') as badFile:
            # badExamples is only filled by the testall() calls above, which
            # also define `wrong`; it is empty otherwise
            for i, ex in enumerate(badExamples):
                vote = wrong[i]
                badFile.write('%s\t%s\n' % (ex, vote))

    composite.ClearModelExamples()
    return composite
示例#20
0
文件: DbUtils.py 项目: rwest/rdkit
def GetData(dBase,table,fieldString='*',whereString='',user='******',password='******',
            removeDups=-1,join='',forceList=0,transform=None,randomAccess=1,extras=None,cn=None):
  """ a more flexible method to get a set of data from a table

    **Arguments**

     - dBase: the database to connect to; only used when no connection
          is passed in via _cn_

     - table: the table (or tables) to select from

     - fieldString: a string with the names of the fields to be extracted,
          this should be a comma delimited list

     - whereString: the SQL where clause to be used with the DB query;
          the leading 'where' keyword is added if it is missing

     - user, password: credentials, only used when no connection is
          passed in via _cn_

     - removeDups indicates the column which should be used to screen
        out duplicates.  Only the first appearance of a duplicate will
        be left in the dataset.

     - join: an optional join clause; the leading 'join' keyword is added
          if it is missing

     - forceList: if true, the query is run immediately and a plain list
          of rows is returned instead of a result-set object

     - transform: passed on to the result-set object; not compatible
          with _forceList_

     - randomAccess: selects RandomAccessDbResultSet (true) or
          DbResultSet (false); must be true when _forceList_ is set

     - extras: optional parameters passed to cursor.execute() together
          with the command

     - cn: an already-open connection to use instead of opening one

    **Returns**

      - a list of the data when _forceList_ is set (None if the query
        failed), otherwise a result-set object

    **Raises**

      - ValueError for incompatible forceList/transform/randomAccess
        combinations

    **Notes**

      - the command is assembled by string formatting, so fieldString,
        table, join and whereString must come from trusted sources; use
        _extras_ for values

  """
  if forceList:
    # reject bad argument combinations before touching the database
    if transform is not None:
      raise ValueError('forceList and transform arguments are not compatible')
    if not randomAccess:
      raise ValueError('when forceList is set, randomAccess must also be used')
  if not cn:
    cn = DbModule.connect(dBase,user,password)
  c = cn.cursor()
  cmd = 'select %s from %s'%(fieldString,table)
  if join:
    # SQL keywords are case insensitive, so don't duplicate an upper-case JOIN
    if not join.strip().lower().startswith('join'):
      join = 'join %s'%(join)
    cmd += ' ' + join
  if whereString:
    if not whereString.strip().lower().startswith('where'):
      whereString = 'where %s'%(whereString)
    cmd += ' ' + whereString

  if forceList:
    try:
      if not extras:
        c.execute(cmd)
      else:
        c.execute(cmd,extras)
    except Exception:
      # best effort: report the failure and signal it with None
      sys.stderr.write('the command "%s" generated errors:\n'%(cmd))
      import traceback
      traceback.print_exc()
      return None
    data = c.fetchall()
    # NOTE(review): the '>0' test means column 0 can never be used for
    #  duplicate screening in this branch; confirm whether '>=0' was intended
    if removeDups>0:
      # build a new list: keeps the first appearance of each value in its
      # original position and also works when fetchall() returns a tuple
      seen = set()
      uniq = []
      for entry in data:
        val = entry[removeDups]
        if val in seen:
          continue
        seen.add(val)
        uniq.append(entry)
      data = uniq
  else:
    if randomAccess:
      klass = RandomAccessDbResultSet
    else:
      klass = DbResultSet

    data = klass(c,cn,cmd,removeDups=removeDups,transform=transform,extras=extras)

  return data
示例#21
0
文件: wayne.py 项目: aytsai/ricebowl
  def test11(self):
    """descriptors

    Checks that the rd_mollogp and rd_molamw database functions agree (to 4
    places) with Crippen.MolLogP and Descriptors.MolWt when given a SMILES
    string, a pickled molecule cast to bytea, or the output of rd_molpickle,
    and that invalid input raises DataError.
    """
    from rdkit.Chem import Crippen,Descriptors
    curs = self.conn.GetCursor()

    def checkDescriptor(funcName, ref, smi, pkl):
      # run the database function on each supported input form and compare
      # the result with the reference value
      for sql, arg in (("SELECT %s(%%s)" % funcName, smi),
                       ("SELECT %s(cast (%%s as bytea))" % funcName, pkl),
                       ("SELECT %s(rd_molpickle(%%s))" % funcName, smi)):
        res = curs.execute(sql, (arg,))
        v = res.fetchone()[0]
        # assertAlmostEqual replaces the failUnlessAlmostEqual alias, which
        # is no longer available in recent unittest versions
        self.assertAlmostEqual(ref, v, 4)

    for smi in ("c1ncccc1", "CCOCC(C)(C)C"):
      m = Chem.MolFromSmiles(smi)
      pkl = DbModule.binaryHolder(m.ToBinary())
      checkDescriptor('rd_mollogp', Crippen.MolLogP(m, includeHs=True), smi, pkl)
      checkDescriptor('rd_molamw', Descriptors.MolWt(m), smi, pkl)

    badQueries = (
      ('select rd_mollogp(%s)', ''),
      ('select rd_mollogp(%s)', 'QC'),
      ('select rd_mollogp(cast (%s as bytea))', ''),
      ('select rd_mollogp(cast (%s as bytea))', 'randomtext'),
    )
    for sql, arg in badQueries:
      # use a fresh cursor for every query that is expected to fail
      curs = self.conn.GetCursor()
      self.assertRaises(DataError, curs.execute, sql, (arg,))
示例#22
0
def RunOnData(details, data, progressCallback=None, saveIt=1, setDescNames=0):
  """ builds a composite model from a data set, screens it and stores results

    The whole run is configured through attributes of _details_ (model-type
    flags such as useTrees/useKNN, split and filter settings, output names,
    etc.).

    **Arguments**

      - details: the run-configuration object.  It is also modified: the
        pickled composite is stored in details.model and
        details.internalHoldoutFrac is zeroed when only one model is built.

      - data: the data set; must provide GetNamedData(), GetNVars(),
        GetNPossibleVals(), GetVarNames() and (for detailed results)
        GetNPts()

      - progressCallback: passed through to composite.Grow() by the
        tree-based builders

      - saveIt: if true, the composite is pickled to details.outName

      - setDescNames: if true, details._descNames are used as the
        descriptor names and the data's variable names as the input order

    **Returns**

      the composite that was built

    **Notes**

      - screening statistics are written to the _runDetails object that is
        resolved outside this function, not to _details_

  """
  if details.lockRandom:
    seed = details.randomSeed
  else:
    import random
    # integer bounds: randint() rejects float bounds on Python >= 3.12
    seed = (random.randint(0, 1000000), random.randint(0, 1000000))
  DataUtils.InitRandomNumbers(seed)
  if details.shuffleActivities == 1:
    DataUtils.RandomizeActivities(data, shuffle=1, runDetails=details)
  elif details.randomActivities == 1:
    DataUtils.RandomizeActivities(data, shuffle=0, runDetails=details)

  namedExamples = data.GetNamedData()
  if details.splitRun == 1:
    trainIdx, testIdx = SplitData.SplitIndices(
      len(namedExamples), details.splitFrac, silent=not _verbose)

    trainExamples = [namedExamples[x] for x in trainIdx]
    testExamples = [namedExamples[x] for x in testIdx]
  else:
    testExamples = []
    testIdx = []
    trainIdx = list(range(len(namedExamples)))
    trainExamples = namedExamples

  if details.filterFrac != 0.0:
    # if we're doing quantization on the fly, we need to handle that here:
    if hasattr(details, 'activityBounds') and details.activityBounds:
      tExamples = []
      bounds = details.activityBounds
      for pt in trainExamples:
        # work on a copy so the real training examples keep their
        # unquantized activity value
        pt = pt[:]
        act = pt[-1]
        placed = 0
        bound = 0
        while not placed and bound < len(bounds):
          if act < bounds[bound]:
            pt[-1] = bound
            placed = 1
          else:
            bound += 1
        if not placed:
          pt[-1] = bound
        tExamples.append(pt)
    else:
      bounds = None
      tExamples = trainExamples
    trainIdx, temp = DataUtils.FilterData(tExamples, details.filterVal, details.filterFrac, -1,
                                          indicesOnly=1)
    # the examples that were filtered out are moved to the test set
    tmp = [trainExamples[x] for x in trainIdx]
    testExamples += [trainExamples[x] for x in temp]
    trainExamples = tmp

    for label, examples in (('training', trainExamples), ('test', testExamples)):
      counts = DataUtils.CountResults(examples, bounds=bounds)
      message('Result Counts in %s set:' % label)
      # dict views have no sort() method on Python 3, so use sorted()
      for k in sorted(counts.keys()):
        message(str((k, counts[k])))
  nExamples = len(trainExamples)
  message('Training with %d examples' % (nExamples))

  nVars = data.GetNVars()
  # needs to be a real list: entries are removed below
  attrs = list(range(1, nVars + 1))
  nPossibleVals = data.GetNPossibleVals()
  for i in range(1, len(nPossibleVals)):
    if nPossibleVals[i - 1] == -1:
      attrs.remove(i)

  if details.pickleDataFileName != '':
    with open(details.pickleDataFileName, 'wb+') as pickleDataFile:
      pickle.dump(trainExamples, pickleDataFile)
      pickle.dump(testExamples, pickleDataFile)

  if details.bayesModel:
    composite = BayesComposite.BayesComposite()
  else:
    composite = Composite.Composite()

  # record the run parameters on the composite itself
  composite._randomSeed = seed
  composite._splitFrac = details.splitFrac
  composite._shuffleActivities = details.shuffleActivities
  composite._randomizeActivities = details.randomActivities

  if hasattr(details, 'filterFrac'):
    composite._filterFrac = details.filterFrac
  if hasattr(details, 'filterVal'):
    composite._filterVal = details.filterVal

  composite.SetModelFilterData(details.modelFilterFrac, details.modelFilterVal)

  composite.SetActivityQuantBounds(details.activityBounds)
  nPossibleVals = data.GetNPossibleVals()
  if details.activityBounds:
    nPossibleVals[-1] = len(details.activityBounds) + 1

  if setDescNames:
    composite.SetInputOrder(data.GetVarNames())
    composite.SetDescriptorNames(details._descNames)
  else:
    composite.SetDescriptorNames(data.GetVarNames())
  composite.SetActivityQuantBounds(details.activityBounds)
  if details.nModels == 1:
    details.internalHoldoutFrac = 0.0
  if details.useTrees:
    from rdkit.ML.DecTree import CrossValidate, PruneTree
    if details.qBounds != []:
      from rdkit.ML.DecTree import BuildQuantTree
      builder = BuildQuantTree.QuantTreeBoot
    else:
      from rdkit.ML.DecTree import ID3
      builder = ID3.ID3Boot
    driver = CrossValidate.CrossValidationDriver
    pruner = PruneTree.PruneTree

    composite.SetQuantBounds(details.qBounds)
    nPossibleVals = data.GetNPossibleVals()
    if details.activityBounds:
      nPossibleVals[-1] = len(details.activityBounds) + 1
    composite.Grow(
      trainExamples, attrs, nPossibleVals=[0] + nPossibleVals, buildDriver=driver, pruner=pruner,
      nTries=details.nModels, pruneIt=details.pruneIt, lessGreedy=details.lessGreedy,
      needsQuantization=0, treeBuilder=builder, nQuantBounds=details.qBounds,
      startAt=details.startAt, maxDepth=details.limitDepth, progressCallback=progressCallback,
      holdOutFrac=details.internalHoldoutFrac, replacementSelection=details.replacementSelection,
      recycleVars=details.recycleVars, randomDescriptors=details.randomDescriptors,
      silent=not _verbose)

  elif details.useSigTrees:
    from rdkit.ML.DecTree import CrossValidate
    from rdkit.ML.DecTree import BuildSigTree
    builder = BuildSigTree.SigTreeBuilder
    driver = CrossValidate.CrossValidationDriver
    nPossibleVals = data.GetNPossibleVals()
    if details.activityBounds:
      nPossibleVals[-1] = len(details.activityBounds) + 1
    # these settings are optional on the details object
    biasList = getattr(details, 'sigTreeBiasList', None)
    useCMIM = getattr(details, 'useCMIM', 0)
    allowCollections = getattr(details, 'allowCollections', False)
    composite.Grow(
      trainExamples, attrs, nPossibleVals=[0] + nPossibleVals, buildDriver=driver,
      nTries=details.nModels, needsQuantization=0, treeBuilder=builder, maxDepth=details.limitDepth,
      progressCallback=progressCallback, holdOutFrac=details.internalHoldoutFrac,
      replacementSelection=details.replacementSelection, recycleVars=details.recycleVars,
      randomDescriptors=details.randomDescriptors, biasList=biasList, useCMIM=useCMIM,
      allowCollection=allowCollections, silent=not _verbose)

  elif details.useKNN:
    from rdkit.ML.KNN import CrossValidate
    from rdkit.ML.KNN import DistFunctions

    driver = CrossValidate.CrossValidationDriver
    dfunc = ''
    if (details.knnDistFunc == "Euclidean"):
      dfunc = DistFunctions.EuclideanDist
    elif (details.knnDistFunc == "Tanimoto"):
      dfunc = DistFunctions.TanimotoDist
    else:
      assert 0, "Bad KNN distance metric value"

    composite.Grow(trainExamples, attrs, nPossibleVals=[0] + nPossibleVals, buildDriver=driver,
                   nTries=details.nModels, needsQuantization=0, numNeigh=details.knnNeighs,
                   holdOutFrac=details.internalHoldoutFrac, distFunc=dfunc)

  elif details.useNaiveBayes or details.useSigBayes:
    from rdkit.ML.NaiveBayes import CrossValidate
    driver = CrossValidate.CrossValidationDriver
    if not (hasattr(details, 'useSigBayes') and details.useSigBayes):
      composite.Grow(trainExamples, attrs, nPossibleVals=[0] + nPossibleVals, buildDriver=driver,
                     nTries=details.nModels, needsQuantization=0, nQuantBounds=details.qBounds,
                     holdOutFrac=details.internalHoldoutFrac,
                     replacementSelection=details.replacementSelection,
                     mEstimateVal=details.mEstimateVal, silent=not _verbose)
    else:
      useCMIM = getattr(details, 'useCMIM', 0)

      composite.Grow(trainExamples, attrs, nPossibleVals=[0] + nPossibleVals, buildDriver=driver,
                     nTries=details.nModels, needsQuantization=0, nQuantBounds=details.qBounds,
                     mEstimateVal=details.mEstimateVal, useSigs=True, useCMIM=useCMIM,
                     holdOutFrac=details.internalHoldoutFrac,
                     replacementSelection=details.replacementSelection, silent=not _verbose)

    # #   elif details.useSVM:
    # #     from rdkit.ML.SVM import CrossValidate
    # #     driver = CrossValidate.CrossValidationDriver
    # #     composite.Grow(trainExamples, attrs, nPossibleVals=[0]+nPossibleVals,
    # #                    buildDriver=driver, nTries=details.nModels,
    # #                    needsQuantization=0,
    # #                    cost=details.svmCost,gamma=details.svmGamma,
    # #                    weights=details.svmWeights,degree=details.svmDegree,
    # #                    type=details.svmType,kernelType=details.svmKernel,
    # #                    coef0=details.svmCoeff,eps=details.svmEps,nu=details.svmNu,
    # #                    cache_size=details.svmCache,shrinking=details.svmShrink,
    # #                    dataType=details.svmDataType,
    # #                    holdOutFrac=details.internalHoldoutFrac,
    # #                    replacementSelection=details.replacementSelection,
    # #                    silent=not _verbose)

  else:
    # fallback when no other model type was selected
    from rdkit.ML.Neural import CrossValidate
    driver = CrossValidate.CrossValidationDriver
    composite.Grow(trainExamples, attrs, [0] + nPossibleVals, nTries=details.nModels,
                   buildDriver=driver, needsQuantization=0)

  composite.AverageErrors()
  composite.SortModels()
  modelList, counts, avgErrs = composite.GetAllData()
  counts = numpy.array(counts)
  avgErrs = numpy.array(avgErrs)
  composite._varNames = data.GetVarNames()

  for model in modelList:
    model.NameModel(composite._varNames)

  # do final statistics
  weightedErrs = counts * avgErrs
  averageErr = sum(weightedErrs) / sum(counts)
  devs = (avgErrs - averageErr)
  devs = devs * counts
  devs = numpy.sqrt(devs * devs)
  avgDev = sum(devs) / sum(counts)
  message('# Overall Average Error: %%% 5.2f, Average Deviation: %%% 6.2f' %
          (100. * averageErr, 100. * avgDev))

  if details.bayesModel:
    composite.Train(trainExamples, verbose=0)

  # blow out the saved examples and then save the composite:
  composite.ClearModelExamples()
  if saveIt:
    composite.Pickle(details.outName)
  details.model = DbModule.binaryHolder(pickle.dumps(composite))

  badExamples = []
  if not details.detailedRes and (not hasattr(details, 'noScreen') or not details.noScreen):
    if details.splitRun:
      message('Testing all hold-out examples')
      wrong = testall(composite, testExamples, badExamples)
      message('%d examples (%% %5.2f) were misclassified' % (len(wrong), 100. * float(len(wrong)) /
                                                             float(len(testExamples))))
      _runDetails.holdout_error = float(len(wrong)) / len(testExamples)
    else:
      message('Testing all examples')
      wrong = testall(composite, namedExamples, badExamples)
      message('%d examples (%% %5.2f) were misclassified' % (len(wrong), 100. * float(len(wrong)) /
                                                             float(len(namedExamples))))
      _runDetails.overall_error = float(len(wrong)) / len(namedExamples)

  if details.detailedRes:
    message('\nEntire data set:')
    resTup = ScreenComposite.ShowVoteResults(
      range(data.GetNPts()), data, composite, nPossibleVals[-1], details.threshold)
    nGood, nBad, nSkip, avgGood, avgBad, avgSkip, voteTab = resTup
    nPts = len(namedExamples)
    nClass = nGood + nBad
    _runDetails.overall_error = float(nBad) / nClass
    _runDetails.overall_correct_conf = avgGood
    _runDetails.overall_incorrect_conf = avgBad
    _runDetails.overall_result_matrix = repr(voteTab)
    # NOTE(review): nClass - nPts can only be positive if more points were
    #  classified than exist; the operands look swapped - confirm before
    #  relying on overall_fraction_dropped
    nRej = nClass - nPts
    if nRej > 0:
      _runDetails.overall_fraction_dropped = float(nRej) / nPts

    if details.splitRun:
      message('\nHold-out data:')
      resTup = ScreenComposite.ShowVoteResults(
        range(len(testExamples)), testExamples, composite, nPossibleVals[-1], details.threshold)
      nGood, nBad, nSkip, avgGood, avgBad, avgSkip, voteTab = resTup
      nPts = len(testExamples)
      nClass = nGood + nBad
      _runDetails.holdout_error = float(nBad) / nClass
      _runDetails.holdout_correct_conf = avgGood
      _runDetails.holdout_incorrect_conf = avgBad
      _runDetails.holdout_result_matrix = repr(voteTab)
      # NOTE(review): same suspicious sign as for the full data set
      nRej = nClass - nPts
      if nRej > 0:
        _runDetails.holdout_fraction_dropped = float(nRej) / nPts

  if details.persistTblName and details.dbName:
    message('Updating results table %s:%s' % (details.dbName, details.persistTblName))
    details.Store(db=details.dbName, table=details.persistTblName)

  if details.badName != '':
    with open(details.badName, 'w+') as badFile:
      # badExamples is only filled by the testall() calls above, which also
      # define `wrong`; it is empty otherwise
      for i, ex in enumerate(badExamples):
        vote = wrong[i]
        badFile.write('%s\t%s\n' % (ex, vote))

  composite.ClearModelExamples()
  return composite
示例#23
0
def CreateDb(options,dataFilename='',supplier=None):
  """ loads molecules into a database and builds the requested fingerprint/descriptor tables

    **Arguments**

      - options: an object whose attributes drive the run (outDir, the do*
        switches, database/table names, etc.).  It is modified in place:
        noExtras clears all of the do* switches, and molFormat/delimiter are
        filled in when they have to be guessed from dataFilename.

      - dataFilename: (optional) file to read molecules from; only consulted
        when no supplier is passed in.

      - supplier: (optional) a molecule supplier that is handed to Loader.LoadDb

    **Raises**

      - ValueError if neither dataFilename nor supplier is provided

    **Notes**

      - output tables are dropped (errors from the drop are ignored) and then
        recreated.  When any of the RDKit, layered, Pharm2D or Gobbi2D
        fingerprints is requested, all four of those tables are dropped.

      - rows are buffered and written/committed every 500 molecules read from
        the molecule table, with a final flush at the end.

  """
  if not dataFilename and supplier is None:
    raise ValueError('Please provide either a data filename or a supplier')

  if options.errFilename:
    errFile=open(os.path.join(options.outDir,options.errFilename),'w+')
  else:
    errFile=None

  try:
    if options.noExtras:
      options.doPairs=False
      options.doDescriptors=False
      options.doFingerprints=False
      options.doPharm2D=False
      options.doGobbi2D=False
      options.doLayered=False
      options.doMorganFps=False

    if options.loadMols:
      if supplier is None:
        if not options.molFormat:
          ext = os.path.splitext(dataFilename)[-1].lower()
          if ext=='.sdf':
            options.molFormat='sdf'
          elif ext in ('.smi','.smiles','.txt','.csv'):
            options.molFormat='smiles'
            if not options.delimiter:
              # guess the delimiter from the first 2000 characters of the file
              import csv
              sniffer = csv.Sniffer()
              with open(dataFilename,'r') as sampleF:
                dlct=sniffer.sniff(sampleF.read(2000))
              options.delimiter=dlct.delimiter
              if not options.silent:
                logger.info('Guessing that delimiter is %s. Use --delimiter argument if this is wrong.'%repr(options.delimiter))

          if not options.silent:
            logger.info('Guessing that mol format is %s. Use --molFormat argument if this is wrong.'%repr(options.molFormat))
        if options.molFormat=='smiles':
          # a literal backslash-t from the command line means "tab"
          if options.delimiter=='\\t': options.delimiter='\t'
          supplier=Chem.SmilesMolSupplier(dataFilename,
                                          titleLine=options.titleLine,
                                          delimiter=options.delimiter,
                                          smilesColumn=options.smilesColumn,
                                          nameColumn=options.nameColumn
                                          )
        else:
          supplier = Chem.SDMolSupplier(dataFilename)
      if not options.silent: logger.info('Reading molecules and constructing molecular database.')
      Loader.LoadDb(supplier,os.path.join(options.outDir,options.molDbName),
                    errorsTo=errFile,regName=options.regName,nameCol=options.molIdName,
                    skipProps=options.skipProps,defaultVal=options.missingPropertyVal,
                    addComputedProps=options.addProps,uniqNames=True,
                    skipSmiles=options.skipSmiles,maxRowsCached=int(options.maxRowsCached),
                    silent=options.silent,nameProp=options.nameProp,
                    lazySupplier=int(options.maxRowsCached)>0,
                    startAnew=not options.updateDb
                    )
  finally:
    # the error log is only ever handed to Loader.LoadDb, so it is no longer
    # needed once the loading step is over
    if errFile is not None:
      errFile.close()

  if options.doPairs:
    pairConn = DbConnect(os.path.join(options.outDir,options.pairDbName))
    pairCurs = pairConn.GetCursor()
    try:
      pairCurs.execute('drop table %s'%(options.pairTableName))
    except Exception:
      # best effort: the drop is allowed to fail
      pass
    pairCurs.execute('create table %s (guid integer not null primary key,%s varchar not null unique,atompairfp blob,torsionfp blob)'%(options.pairTableName,
                                                                                                         options.molIdName))

  if options.doFingerprints or options.doPharm2D or options.doGobbi2D or options.doLayered:
    fpConn = DbConnect(os.path.join(options.outDir,options.fpDbName))
    fpCurs=fpConn.GetCursor()
    # all four tables are dropped here, whichever of them is being rebuilt
    try:
      fpCurs.execute('drop table %s'%(options.fpTableName))
    except Exception:
      pass
    try:
      fpCurs.execute('drop table %s'%(options.pharm2DTableName))
    except Exception:
      pass
    try:
      fpCurs.execute('drop table %s'%(options.gobbi2DTableName))
    except Exception:
      pass
    try:
      fpCurs.execute('drop table %s'%(options.layeredTableName))
    except Exception:
      pass

    if options.doFingerprints:
      fpCurs.execute('create table %s (guid integer not null primary key,%s varchar not null unique,rdkfp blob)'%(options.fpTableName,
                                                                                     options.molIdName))
    if options.doLayered:
      layeredQs = ','.join('?'*LayeredOptions.nWords)
      colDefs=','.join(['Col_%d integer'%(x+1) for x in range(LayeredOptions.nWords)])
      fpCurs.execute('create table %s (guid integer not null primary key,%s varchar not null unique,%s)'%(options.layeredTableName,
                                                                             options.molIdName,
                                                                             colDefs))

    if options.doPharm2D:
      fpCurs.execute('create table %s (guid integer not null primary key,%s varchar not null unique,pharm2dfp blob)'%(options.pharm2DTableName,
                                                                                     options.molIdName))
      sigFactory = BuildSigFactory(options)
    if options.doGobbi2D:
      fpCurs.execute('create table %s (guid integer not null primary key,%s varchar not null unique,gobbi2dfp blob)'%(options.gobbi2DTableName,
                                                                                     options.molIdName))
      from rdkit.Chem.Pharm2D import Generate,Gobbi_Pharm2D

  if options.doMorganFps :
    # NOTE: this rebinds fpConn/fpCurs, so from here on every fingerprint
    # insert goes through this connection
    fpConn = DbConnect(os.path.join(options.outDir,options.fpDbName))
    fpCurs=fpConn.GetCursor()
    try:
      fpCurs.execute('drop table %s'%(options.morganFpTableName))
    except Exception:
      pass
    fpCurs.execute('create table %s (guid integer not null primary key,%s varchar not null unique,morganfp blob)'%(options.morganFpTableName,
                                                                                        options.molIdName))

  if options.doDescriptors:
    descrConn=DbConnect(os.path.join(options.outDir,options.descrDbName))
    # the pickled calculator is read as text so that line endings can be normalized
    with open(options.descriptorCalcFilename,'r') as inTF:
      buf = inTF.read().replace('\r\n', '\n').encode('utf-8')
    # NOTE: unpickling can execute arbitrary code; descriptorCalcFilename must
    # come from a trusted source
    calc = cPickle.load(io.BytesIO(buf))
    nms = [x for x in calc.GetDescriptorNames()]
    descrCurs = descrConn.GetCursor()
    descrs = ['guid integer not null primary key','%s varchar not null unique'%options.molIdName]
    descrs.extend(['%s float'%x for x in nms])
    try:
      descrCurs.execute('drop table %s'%(options.descrTableName))
    except Exception:
      pass
    descrCurs.execute('create table %s (%s)'%(options.descrTableName,','.join(descrs)))
    descrQuery=','.join([DbModule.placeHolder]*len(descrs))
  pairRows = []
  fpRows = []
  layeredRows = []
  descrRows = []
  pharm2DRows=[]
  gobbi2DRows=[]
  morganRows = []

  # one (row buffer, cursor, connection, insert statement) entry per table
  # being filled, in the order in which the buffers are written out
  pending = []
  if options.doPairs:
    pending.append((pairRows,pairCurs,pairConn,
                    'insert into %s values (?,?,?,?)'%options.pairTableName))
  if options.doFingerprints:
    pending.append((fpRows,fpCurs,fpConn,
                    'insert into %s values (?,?,?)'%options.fpTableName))
  if options.doLayered:
    pending.append((layeredRows,fpCurs,fpConn,
                    'insert into %s values (?,?,%s)'%(options.layeredTableName,layeredQs)))
  if options.doDescriptors:
    pending.append((descrRows,descrCurs,descrConn,
                    'insert into %s values (%s)'%(options.descrTableName,descrQuery)))
  if options.doPharm2D:
    pending.append((pharm2DRows,fpCurs,fpConn,
                    'insert into %s values (?,?,?)'%options.pharm2DTableName))
  if options.doGobbi2D:
    pending.append((gobbi2DRows,fpCurs,fpConn,
                    'insert into %s values (?,?,?)'%options.gobbi2DTableName))
  if options.doMorganFps:
    pending.append((morganRows,fpCurs,fpConn,
                    'insert into %s values (?,?,?)'%options.morganFpTableName))

  def _flushPending():
    """ writes out and commits every non-empty row buffer, then empties it """
    for rows,curs,conn,stmt in pending:
      if rows:
        curs.executemany(stmt,rows)
        del rows[:]
        conn.Commit()

  if not options.silent: logger.info('Generating fingerprints and descriptors:')
  molConn = DbConnect(os.path.join(options.outDir,options.molDbName))
  molCurs = molConn.GetCursor()
  if not options.skipSmiles:
    molCurs.execute('select guid,%s,smiles,molpkl from %s'%(options.molIdName,options.regName))
  else:
    molCurs.execute('select guid,%s,molpkl from %s'%(options.molIdName,options.regName))
  i=0
  while 1:
    try:
      tpl = molCurs.fetchone()
    except Exception:
      # NOTE(review): a failing fetch has always ended the loop quietly;
      # that best-effort behavior is kept
      break
    if not tpl:
      # no more rows
      break
    molGuid = tpl[0]
    molId = tpl[1]
    # the pickled molecule is the last column whether or not smiles was selected
    pkl = tpl[-1]
    i+=1
    if isinstance(pkl,(bytes,str)):
      mol = Chem.Mol(pkl)
    else:
      mol = Chem.Mol(str(pkl))
    if not mol: continue

    if options.doPairs:
      pairs = FingerprintUtils.BuildAtomPairFP(mol)
      torsions = FingerprintUtils.BuildTorsionsFP(mol)
      pkl1 = DbModule.binaryHolder(pairs.ToBinary())
      pkl2 = DbModule.binaryHolder(torsions.ToBinary())
      row = (molGuid,molId,pkl1,pkl2)
      pairRows.append(row)
    if options.doFingerprints:
      fp2 = FingerprintUtils.BuildRDKitFP(mol)
      pkl = DbModule.binaryHolder(fp2.ToBinary())
      row = (molGuid,molId,pkl)
      fpRows.append(row)
    if options.doLayered:
      words = LayeredOptions.GetWords(mol)
      row = [molGuid,molId]+words
      layeredRows.append(row)
    if options.doDescriptors:
      descrs= calc.CalcDescriptors(mol)
      row = [molGuid,molId]
      row.extend(descrs)
      descrRows.append(row)
    if options.doPharm2D:
      # BuildPharm2DFP picks its factory up from the module attribute
      FingerprintUtils.sigFactory=sigFactory
      fp= FingerprintUtils.BuildPharm2DFP(mol)
      pkl = DbModule.binaryHolder(fp.ToBinary())
      row = (molGuid,molId,pkl)
      pharm2DRows.append(row)
    if options.doGobbi2D:
      FingerprintUtils.sigFactory=Gobbi_Pharm2D.factory
      fp= FingerprintUtils.BuildPharm2DFP(mol)
      pkl = DbModule.binaryHolder(fp.ToBinary())
      row = (molGuid,molId,pkl)
      gobbi2DRows.append(row)
    if options.doMorganFps:
      morgan = FingerprintUtils.BuildMorganFP(mol)
      pkl = DbModule.binaryHolder(morgan.ToBinary())
      row = (molGuid,molId,pkl)
      morganRows.append(row)

    if not i%500:
      _flushPending()
      if not options.silent:
        logger.info('  Done: %d'%(i))

  # write out whatever is left over from the last partial batch
  _flushPending()

  if not options.silent:
    logger.info('Finished.')
示例#24
0
def CreateDb(options, dataFilename="", supplier=None):
    """Load molecules into a database and build the requested fingerprint/descriptor tables.

    **Arguments**

      - options: an object whose attributes drive the run (outDir, the do*
        switches, database/table names, etc.).  It is modified in place:
        noExtras clears all of the do* switches, and molFormat/delimiter are
        filled in when they have to be guessed from dataFilename.

      - dataFilename: (optional) file to read molecules from; only consulted
        when no supplier is passed in.

      - supplier: (optional) a molecule supplier that is handed to Loader.LoadDb

    **Raises**

      - ValueError if neither dataFilename nor supplier is provided

    **Notes**

      - output tables are dropped (errors from the drop are ignored) and then
        recreated.  When any of the RDKit, layered, Pharm2D or Gobbi2D
        fingerprints is requested, all four of those tables are dropped.

      - rows are buffered and written/committed every 500 molecules read from
        the molecule table, with a final flush at the end.

      - only constructs that work on both Python 2 and Python 3 are used
        (call-style raise, open() instead of file()).
    """
    if not dataFilename and supplier is None:
        raise ValueError("Please provide either a data filename or a supplier")

    if options.errFilename:
        errFile = open(os.path.join(options.outDir, options.errFilename), "w+")
    else:
        errFile = None

    try:
        if options.noExtras:
            options.doPairs = False
            options.doDescriptors = False
            options.doFingerprints = False
            options.doPharm2D = False
            options.doGobbi2D = False
            options.doLayered = False
            options.doMorganFps = False

        if options.loadMols:
            if supplier is None:
                if not options.molFormat:
                    ext = os.path.splitext(dataFilename)[-1].lower()
                    if ext == ".sdf":
                        options.molFormat = "sdf"
                    elif ext in (".smi", ".smiles", ".txt", ".csv"):
                        options.molFormat = "smiles"
                        if not options.delimiter:
                            # guess the delimiter from the first 2000 characters of the file
                            import csv

                            sniffer = csv.Sniffer()
                            with open(dataFilename, "r") as sampleF:
                                dlct = sniffer.sniff(sampleF.read(2000))
                            options.delimiter = dlct.delimiter
                            if not options.silent:
                                logger.info(
                                    "Guessing that delimiter is %s. Use --delimiter argument if this is wrong."
                                    % repr(options.delimiter)
                                )

                    if not options.silent:
                        logger.info(
                            "Guessing that mol format is %s. Use --molFormat argument if this is wrong."
                            % repr(options.molFormat)
                        )
                if options.molFormat == "smiles":
                    # a literal backslash-t from the command line means "tab"
                    if options.delimiter == "\\t":
                        options.delimiter = "\t"
                    supplier = Chem.SmilesMolSupplier(
                        dataFilename,
                        titleLine=options.titleLine,
                        delimiter=options.delimiter,
                        smilesColumn=options.smilesColumn,
                        nameColumn=options.nameColumn,
                    )
                else:
                    supplier = Chem.SDMolSupplier(dataFilename)
            if not options.silent:
                logger.info("Reading molecules and constructing molecular database.")
            Loader.LoadDb(
                supplier,
                os.path.join(options.outDir, options.molDbName),
                errorsTo=errFile,
                regName=options.regName,
                nameCol=options.molIdName,
                skipProps=options.skipProps,
                defaultVal=options.missingPropertyVal,
                addComputedProps=options.addProps,
                uniqNames=True,
                skipSmiles=options.skipSmiles,
                maxRowsCached=int(options.maxRowsCached),
                silent=options.silent,
                nameProp=options.nameProp,
                lazySupplier=int(options.maxRowsCached) > 0,
            )
    finally:
        # the error log is only ever handed to Loader.LoadDb, so it is no
        # longer needed once the loading step is over
        if errFile is not None:
            errFile.close()

    if options.doPairs:
        pairConn = DbConnect(os.path.join(options.outDir, options.pairDbName))
        pairCurs = pairConn.GetCursor()
        try:
            pairCurs.execute("drop table %s" % (options.pairTableName))
        except Exception:
            # best effort: the drop is allowed to fail
            pass
        pairCurs.execute(
            "create table %s (guid integer not null primary key,%s varchar not null unique,atompairfp blob,torsionfp blob)"
            % (options.pairTableName, options.molIdName)
        )

    if options.doFingerprints or options.doPharm2D or options.doGobbi2D or options.doLayered:
        fpConn = DbConnect(os.path.join(options.outDir, options.fpDbName))
        fpCurs = fpConn.GetCursor()
        # all four tables are dropped here, whichever of them is being rebuilt
        try:
            fpCurs.execute("drop table %s" % (options.fpTableName))
        except Exception:
            pass
        try:
            fpCurs.execute("drop table %s" % (options.pharm2DTableName))
        except Exception:
            pass
        try:
            fpCurs.execute("drop table %s" % (options.gobbi2DTableName))
        except Exception:
            pass
        try:
            fpCurs.execute("drop table %s" % (options.layeredTableName))
        except Exception:
            pass

        if options.doFingerprints:
            fpCurs.execute(
                "create table %s (guid integer not null primary key,%s varchar not null unique,rdkfp blob)"
                % (options.fpTableName, options.molIdName)
            )
        if options.doLayered:
            layeredQs = ",".join("?" * LayeredOptions.nWords)
            colDefs = ",".join(["Col_%d integer" % (x + 1) for x in range(LayeredOptions.nWords)])
            fpCurs.execute(
                "create table %s (guid integer not null primary key,%s varchar not null unique,%s)"
                % (options.layeredTableName, options.molIdName, colDefs)
            )

        if options.doPharm2D:
            fpCurs.execute(
                "create table %s (guid integer not null primary key,%s varchar not null unique,pharm2dfp blob)"
                % (options.pharm2DTableName, options.molIdName)
            )
            sigFactory = BuildSigFactory(options)
        if options.doGobbi2D:
            fpCurs.execute(
                "create table %s (guid integer not null primary key,%s varchar not null unique,gobbi2dfp blob)"
                % (options.gobbi2DTableName, options.molIdName)
            )
            from rdkit.Chem.Pharm2D import Generate, Gobbi_Pharm2D

    if options.doMorganFps:
        # NOTE: this rebinds fpConn/fpCurs, so from here on every fingerprint
        # insert goes through this connection
        fpConn = DbConnect(os.path.join(options.outDir, options.fpDbName))
        fpCurs = fpConn.GetCursor()
        try:
            fpCurs.execute("drop table %s" % (options.morganFpTableName))
        except Exception:
            pass
        fpCurs.execute(
            "create table %s (guid integer not null primary key,%s varchar not null unique,morganfp blob)"
            % (options.morganFpTableName, options.molIdName)
        )

    if options.doDescriptors:
        descrConn = DbConnect(os.path.join(options.outDir, options.descrDbName))
        # NOTE: unpickling can execute arbitrary code; descriptorCalcFilename
        # must come from a trusted source
        with open(options.descriptorCalcFilename, "rb") as calcF:
            calc = cPickle.load(calcF)
        nms = [x for x in calc.GetDescriptorNames()]
        descrCurs = descrConn.GetCursor()
        descrs = ["guid integer not null primary key", "%s varchar not null unique" % options.molIdName]
        descrs.extend(["%s float" % x for x in nms])
        try:
            descrCurs.execute("drop table %s" % (options.descrTableName))
        except Exception:
            pass
        descrCurs.execute("create table %s (%s)" % (options.descrTableName, ",".join(descrs)))
        descrQuery = ",".join([DbModule.placeHolder] * len(descrs))
    pairRows = []
    fpRows = []
    layeredRows = []
    descrRows = []
    pharm2DRows = []
    gobbi2DRows = []
    morganRows = []

    # one (row buffer, cursor, connection, insert statement) entry per table
    # being filled, in the order in which the buffers are written out
    pending = []
    if options.doPairs:
        pending.append(
            (pairRows, pairCurs, pairConn, "insert into %s values (?,?,?,?)" % options.pairTableName)
        )
    if options.doFingerprints:
        pending.append((fpRows, fpCurs, fpConn, "insert into %s values (?,?,?)" % options.fpTableName))
    if options.doLayered:
        pending.append(
            (
                layeredRows,
                fpCurs,
                fpConn,
                "insert into %s values (?,?,%s)" % (options.layeredTableName, layeredQs),
            )
        )
    if options.doDescriptors:
        pending.append(
            (
                descrRows,
                descrCurs,
                descrConn,
                "insert into %s values (%s)" % (options.descrTableName, descrQuery),
            )
        )
    if options.doPharm2D:
        pending.append(
            (pharm2DRows, fpCurs, fpConn, "insert into %s values (?,?,?)" % options.pharm2DTableName)
        )
    if options.doGobbi2D:
        pending.append(
            (gobbi2DRows, fpCurs, fpConn, "insert into %s values (?,?,?)" % options.gobbi2DTableName)
        )
    if options.doMorganFps:
        pending.append(
            (morganRows, fpCurs, fpConn, "insert into %s values (?,?,?)" % options.morganFpTableName)
        )

    def _flushPending():
        """Write out and commit every non-empty row buffer, then empty it."""
        for rows, curs, conn, stmt in pending:
            if rows:
                curs.executemany(stmt, rows)
                del rows[:]
                conn.Commit()

    if not options.silent:
        logger.info("Generating fingerprints and descriptors:")
    molConn = DbConnect(os.path.join(options.outDir, options.molDbName))
    molCurs = molConn.GetCursor()
    if not options.skipSmiles:
        molCurs.execute("select guid,%s,smiles,molpkl from %s" % (options.molIdName, options.regName))
    else:
        molCurs.execute("select guid,%s,molpkl from %s" % (options.molIdName, options.regName))
    i = 0
    while 1:
        try:
            tpl = molCurs.fetchone()
        except Exception:
            # NOTE(review): a failing fetch has always ended the loop quietly;
            # that best-effort behavior is kept
            break
        if not tpl:
            # no more rows
            break
        molGuid = tpl[0]
        molId = tpl[1]
        # the pickled molecule is the last column whether or not smiles was selected
        pkl = tpl[-1]
        i += 1
        # str() on a bytes object would mangle the pickle on Python 3, so only
        # convert values that are neither bytes nor str
        if isinstance(pkl, (bytes, str)):
            mol = Chem.Mol(pkl)
        else:
            mol = Chem.Mol(str(pkl))
        if not mol:
            continue

        if options.doPairs:
            pairs = FingerprintUtils.BuildAtomPairFP(mol)
            torsions = FingerprintUtils.BuildTorsionsFP(mol)
            pkl1 = DbModule.binaryHolder(pairs.ToBinary())
            pkl2 = DbModule.binaryHolder(torsions.ToBinary())
            row = (molGuid, molId, pkl1, pkl2)
            pairRows.append(row)
        if options.doFingerprints:
            fp2 = FingerprintUtils.BuildRDKitFP(mol)
            pkl = DbModule.binaryHolder(fp2.ToBinary())
            row = (molGuid, molId, pkl)
            fpRows.append(row)
        if options.doLayered:
            words = LayeredOptions.GetWords(mol)
            row = [molGuid, molId] + words
            layeredRows.append(row)
        if options.doDescriptors:
            descrs = calc.CalcDescriptors(mol)
            row = [molGuid, molId]
            row.extend(descrs)
            descrRows.append(row)
        if options.doPharm2D:
            # BuildPharm2DFP picks its factory up from the module attribute
            FingerprintUtils.sigFactory = sigFactory
            fp = FingerprintUtils.BuildPharm2DFP(mol)
            pkl = DbModule.binaryHolder(fp.ToBinary())
            row = (molGuid, molId, pkl)
            pharm2DRows.append(row)
        if options.doGobbi2D:
            FingerprintUtils.sigFactory = Gobbi_Pharm2D.factory
            fp = FingerprintUtils.BuildPharm2DFP(mol)
            pkl = DbModule.binaryHolder(fp.ToBinary())
            row = (molGuid, molId, pkl)
            gobbi2DRows.append(row)
        if options.doMorganFps:
            morgan = FingerprintUtils.BuildMorganFP(mol)
            pkl = DbModule.binaryHolder(morgan.ToBinary())
            row = (molGuid, molId, pkl)
            morganRows.append(row)

        if not i % 500:
            _flushPending()
            if not options.silent:
                logger.info("  Done: %d" % (i))

    # write out whatever is left over from the last partial batch
    _flushPending()

    if not options.silent:
        logger.info("Finished.")
示例#25
0
文件: genfps.py 项目: Kaziaa/rdkit-1
from rdkit import Chem
from rdkit import RDConfig
from rdkit.Dbase import DbModule
from rdkit.Dbase.DbConnection import DbConnect
import pickle

# Use the PostgreSQL test database when RDKit is configured for it,
# otherwise fall back to the local sqlite file.
dbName = "::RDTests" if RDConfig.usePgSQL else "data.sqlt"

molTblName = 'simple_mols1'
fpTblName = 'simple_mols1_fp'

# Connect to the molecule table and create the table that receives the
# fingerprints: an id plus one binary column.
conn = DbConnect(dbName, molTblName)
fpColDefs = 'id varchar(10),autofragmentfp %s' % DbModule.binaryTypeName
conn.AddTable(fpTblName, fpColDefs)

# Each row of the molecule table is unpacked as (smiles, id); the RDKit
# fingerprint of the molecule is pickled and stored under the same id.
d = conn.GetData()
for smi, ID in d:
    print(repr(ID), repr(smi))
    mol = Chem.MolFromSmiles(smi)
    fp = Chem.RDKFingerprint(mol)
    pkl = pickle.dumps(fp)
    fpRow = (ID, DbModule.binaryHolder(pkl))
    conn.InsertData(fpTblName, fpRow)
conn.Commit()
示例#26
0
from __future__ import print_function

from rdkit import Chem
from rdkit import RDConfig
from rdkit.Dbase import DbModule
from rdkit.Dbase.DbConnection import DbConnect
from six.moves import cPickle  # @UnresolvedImport

# Use the PostgreSQL test database when RDKit is configured for it,
# otherwise fall back to the local sqlite file.
dbName = "::RDTests" if RDConfig.usePgSQL else "data.sqlt"

molTblName = 'simple_mols1'
fpTblName = 'simple_mols1_fp'

# Connect to the molecule table and create the table that receives the
# fingerprints: an id plus one binary column.
conn = DbConnect(dbName, molTblName)
fpColDefs = 'id varchar(10),autofragmentfp %s' % DbModule.binaryTypeName
conn.AddTable(fpTblName, fpColDefs)

# Each row of the molecule table is unpacked as (smiles, id); the RDKit
# fingerprint of the molecule is pickled and stored under the same id.
d = conn.GetData()
for smi, ID in d:
    print(repr(ID), repr(smi))
    mol = Chem.MolFromSmiles(smi)
    fp = Chem.RDKFingerprint(mol)
    pkl = cPickle.dumps(fp)
    fpRow = (ID, DbModule.binaryHolder(pkl))
    conn.InsertData(fpTblName, fpRow)
conn.Commit()
示例#27
0
def FingerprintsFromDetails(details, reportFreq=10):
  """ generates fingerprints for the molecules described by a details object

    The input is taken from the first of these that applies:

      1) a database table (details.dbName and details.tableName are set)

      2) a text file (details.inFileName is set and details.useSmiles is true)

      3) an SD file (details.inFileName is set and details.useSD is true)

    **Arguments**

      - details: the object describing the run.  It is modified in place:
        idName is filled in when it is empty.  Its __dict__ is passed as
        keyword arguments to the fingerprinting function.

      - reportFreq: when reading an SD file, a message is emitted every
        reportFreq molecules (no messages if it is not positive)

    **Returns**

      the result of the fingerprinting function, or None if no input data
      were available.

    **Notes**

      - if details.outFileName is set, each fingerprint entry is pickled to it

      - if details.outTableName and a database name (details.outDbName or
        details.dbName) are set, (id, binary fingerprint) rows are inserted
        into that table, which is created if needed or if details.replaceTable
        is set.

  """
  data = None
  if details.dbName and details.tableName:
    from rdkit.Dbase.DbConnection import DbConnect
    from rdkit.Dbase import DbInfo
    from rdkit.ML.Data import DataUtils
    try:
      conn = DbConnect(details.dbName, details.tableName)
    except Exception:
      import traceback
      error('Problems establishing connection to database: %s|%s\n' % (details.dbName,
                                                                       details.tableName))
      traceback.print_exc()
    if not details.idName:
      # default to the first column of the table
      details.idName = DbInfo.GetColumnNames(details.dbName, details.tableName)[0]
    dataSet = DataUtils.DBToData(details.dbName, details.tableName,
                                 what='%s,%s' % (details.idName, details.smilesName))
    idCol = 0
    smiCol = 1
  elif details.inFileName and details.useSmiles:
    from rdkit.ML.Data import DataUtils
    conn = None
    if not details.idName:
      details.idName = 'ID'
    try:
      dataSet = DataUtils.TextFileToData(details.inFileName,
                                         onlyCols=[details.idName, details.smilesName])
    except IOError:
      import traceback
      error('Problems reading from file %s\n' % (details.inFileName))
      traceback.print_exc()
      # without this the checks below fail with an UnboundLocalError
      dataSet = None

    idCol = 0
    smiCol = 1
  elif details.inFileName and details.useSD:
    conn = None
    if not details.idName:
      details.idName = 'ID'
    dataSet = []
    try:
      s = Chem.SDMolSupplier(details.inFileName)
    except Exception:
      import traceback
      error('Problems reading from file %s\n' % (details.inFileName))
      traceback.print_exc()
    else:
      while 1:
        try:
          # the next() builtin works with both the Python 2 and Python 3
          # iterator protocols
          m = next(s)
        except StopIteration:
          break
        if m:
          dataSet.append(m)
          if reportFreq > 0 and not len(dataSet) % reportFreq:
            message('Read %d molecules\n' % (len(dataSet)))
          # the cap is checked for every molecule; it used to be tested only
          # when a progress message was printed
          if details.maxMols > 0 and len(dataSet) >= details.maxMols:
            break

    # replace each molecule with a (name, molecule) pair, falling back to the
    # _Name property when the id property is missing
    for i, mol in enumerate(dataSet):
      if mol.HasProp(details.idName):
        nm = mol.GetProp(details.idName)
      else:
        nm = mol.GetProp('_Name')
      dataSet[i] = (nm, mol)
  else:
    dataSet = None

  fps = None
  if dataSet and not details.useSD:
    data = dataSet.GetNamedData()
    if not details.molPklName:
      fps = FingerprintsFromSmiles(data, idCol, smiCol, **details.__dict__)
    else:
      fps = FingerprintsFromPickles(data, idCol, smiCol, **details.__dict__)
  elif dataSet and details.useSD:
    fps = FingerprintsFromMols(dataSet, **details.__dict__)

  if fps:
    if details.outFileName:
      with open(details.outFileName, 'wb+') as outF:
        for fpEntry in fps:
          cPickle.dump(fpEntry, outF)
    dbName = details.outDbName or details.dbName
    if details.outTableName and dbName:
      from rdkit.Dbase.DbConnection import DbConnect
      from rdkit.Dbase import DbUtils, DbModule
      conn = DbConnect(dbName)
      #
      #  We don't have a db open already, so we'll need to figure out
      #    the types of our columns...
      #
      # NOTE(review): data is only assigned on the database and text-file
      #  paths; on the SD path it is still None here, so the call below would
      #  raise a TypeError -- TODO confirm the intended behavior
      colTypes = DbUtils.TypeFinder(data, len(data), len(data[0]))
      typeStrs = DbUtils.GetTypeStrings([details.idName, details.smilesName], colTypes,
                                        keyCol=details.idName)
      cols = '%s, %s %s' % (typeStrs[0], details.fpColName, DbModule.binaryTypeName)

      # FIX: we should really check to see if the table
      #  is already there and, if so, add the appropriate
      #  column.

      #
      # create the new table
      #
      if details.replaceTable or \
         details.outTableName.upper() not in [x.upper() for x in conn.GetTableNames()]:
        conn.AddTable(details.outTableName, cols)

      #
      # And add the data
      #
      for ID, fp in fps:
        tpl = ID, DbModule.binaryHolder(fp.ToBinary())
        conn.InsertData(details.outTableName, tpl)
      conn.Commit()
  return fps