Package rdkit :: Package ML :: Module AnalyzeComposite
[hide private]
[frames] | [no frames]

Source Code for Module rdkit.ML.AnalyzeComposite

  1  # $Id$ 
  2  # 
  3  #  Copyright (C) 2002-2008  greg Landrum and Rational Discovery LLC 
  4  # 
  5  #   @@ All Rights Reserved @@ 
  6  #  This file is part of the RDKit. 
  7  #  The contents are covered by the terms of the BSD license 
  8  #  which is included in the file license.txt, found at the root 
  9  #  of the RDKit source tree. 
 10  # 
 11  """ command line utility to report on the contributions of descriptors to 
 12  tree-based composite models 
 13   
 14  Usage:  AnalyzeComposite [optional args] <models> 
 15   
 16        <models>: file name(s) of pickled composite model(s) 
 17          (this is the name of the db table if using a database) 
 18   
 19      Optional Arguments: 
 20   
 21        -n number: the number of levels of each model to consider 
 22   
 23        -d dbname: the database from which to read the models 
 24   
 25        -N Note: the note string to search for to pull models from the database 
 26   
 27        -v: be verbose whilst screening 
 28  """ 
 29  from __future__ import print_function 
 30  import numpy 
 31  import sys 
 32  from rdkit.six.moves import cPickle 
 33  from rdkit.ML.DecTree import TreeUtils,Tree 
 34  from rdkit.ML.Data import Stats 
 35  from rdkit.Dbase.DbConnection import DbConnect 
 36  from rdkit.ML import ScreenComposite 
 37   
 38  __VERSION_STRING="2.2.0" 
 39   
def _FormatDescriptorRow(name, positions):
    """Return the ``name,p0, p1, ...,total`` report line for one descriptor."""
    strRes = ', '.join(['%4.2f' % x for x in positions])
    return '%s,%s,%5.4f' % (name, strRes, sum(positions))


def ProcessIt(composites, nToConsider=3, verbose=0):
    """Report how descriptors are used in the top levels of tree-based composites.

    For every composite, each model that is a ``Tree.TreeNode`` is examined
    with ``TreeUtils.CollectLabelLevels`` (presumably mapping descriptor id ->
    tree level of its use; TODO confirm) down to ``nToConsider`` levels.  Each
    hit adds ``1/len(composite)`` to that descriptor's slot for that level, so
    the per-composite values are fractions of the composite's models; the
    global values average those fractions over all composites.

    Arguments:
      - composites: non-empty sequence of composite models.  Only the first
        one's ``GetDescriptorNames()`` is checked; if it has two or fewer
        names, no analysis is done.
      - nToConsider: number of tree levels to tally.
      - verbose: > 0 also prints per-composite tables; >= 0 prints the
        averaged table; < 0 prints nothing.

    Returns:
      A list with one entry per descriptor seen:
      ``[name, frac_level_0, ..., frac_level_{nToConsider-1}, total]``.
      Returns an empty list when ``composites`` is empty or the descriptor
      check above fails.
    """
    if not composites:
        return []
    nComposites = len(composites)
    ns = composites[0].GetDescriptorNames()
    if len(ns) <= 2:
        return []

    globalRes = {}
    descNames = {}
    for nDone, composite in enumerate(composites, start=1):
        if verbose > 0:
            print('#------------------------------------')
            print('Doing: ', nDone)
        nModels = len(composite)
        res = {}
        for i in range(nModels):
            model = composite.GetModel(i)
            if not isinstance(model, Tree.TreeNode):
                continue
            levels = TreeUtils.CollectLabelLevels(model, {}, 0, nToConsider)
            TreeUtils.CollectDescriptorNames(model, descNames, 0, nToConsider)
            for descId, level in levels.items():
                # builtin float dtype: the numpy.float alias was removed in NumPy 1.24
                v = res.setdefault(descId, numpy.zeros(nToConsider, float))
                v[level] += 1. / nModels
        for k, v in res.items():
            acc = globalRes.setdefault(k, numpy.zeros(nToConsider, float))
            acc += v / nComposites
        if verbose > 0:
            for k, v in res.items():
                print(_FormatDescriptorRow(descNames[k], v))
            print()

    if verbose >= 0:
        print('# Average Descriptor Positions')
    retVal = []
    for k, v in globalRes.items():
        name = descNames[k]
        if verbose >= 0:
            print(_FormatDescriptorRow(name, v))
        retVal.append([name] + list(v) + [sum(v)])
    if verbose >= 0:
        print()
    return retVal
96 97
def _ParseResultMatrix(txt):
    """Convert a result-matrix string read from the database to a float array.

    Accepts plain nested-list text (``[[1, 2], [3, 4]]``) as well as numpy
    ``repr`` output (``array([[1, 2], [3, 4]])``, optionally with a dtype).
    """
    # NOTE(review): this still eval()s database contents.  The namespace is
    # restricted to a few numpy names and builtins are removed, which blocks
    # casual misuse but is NOT a sandbox -- only use with trusted databases.
    names = {
        'array': numpy.array,
        'nan': numpy.nan,
        'inf': numpy.inf,
        'float32': numpy.float32,
        'float64': numpy.float64,
        'int32': numpy.int32,
        'int64': numpy.int64,
    }
    return numpy.array(eval(txt, {'__builtins__': {}}, names), dtype=float)


def _ErrorSummary(errors):
    """Return (mean, sample standard deviation, index of minimum, minimum)."""
    nPts = len(errors)
    mean = numpy.sum(errors) / nPts
    order = numpy.argsort(errors)
    if nPts > 1:
        dev = numpy.sqrt(numpy.sum((errors - mean) ** 2) / (nPts - 1))
    else:
        # a single run has no spread; avoid dividing by zero
        dev = 0.0
    return mean, dev, order[0], errors[order[0]]


def ErrorStats(conn, where, enrich=1):
    """Summarize the stored error statistics of a set of screening runs.

    Arguments:
      - conn: database connection whose ``GetData(fields=..., where=...)``
        returns one row per run, with the columns listed in ``fields`` below.
      - where: SQL ``where`` clause (may be empty) selecting the runs.
      - enrich: target passed to ``ScreenComposite.CalcEnrichment``; a
        negative value skips the enrichment calculation.

    Returns:
      None if the query raises (the traceback is printed) or finds no rows.
      Otherwise a dict containing:
        - ``oAvg``, ``oDev``, ``oCorrectConf``, ``oIncorrectConf`` and
          ``oBestErr``, all as percentages;
        - ``oBestIdx``, the 0-based index of the best run;
        - ``oResultMat``, the averaged result matrix, or None if no row
          stored one;
        - ``oAvgEnrich`` and ``oDevEnrich`` when ``enrich >= 0``;
        - the matching ``h*`` keys when the last row has a holdout error.
          ``hResultMat`` is present only if some row stored a holdout matrix.
    """
    fields = ('overall_error,holdout_error,overall_result_matrix,'
              'holdout_result_matrix,overall_correct_conf,overall_incorrect_conf,'
              'holdout_correct_conf,holdout_incorrect_conf')
    try:
        data = conn.GetData(fields=fields, where=where)
    except Exception:
        import traceback
        traceback.print_exc()
        return None
    nPts = len(data)
    if not nPts:
        sys.stderr.write('no runs found\n')
        return None

    overall = numpy.zeros(nPts, float)
    overallEnrich = numpy.zeros(nPts, float)
    holdout = numpy.zeros(nPts, float)
    holdoutEnrich = numpy.zeros(nPts, float)
    oCorConf = oInCorConf = hCorConf = hInCorConf = 0.0
    overallMatrix = None
    holdoutMatrix = None
    haveHoldout = False
    for i, row in enumerate(data):
        if row[0] is not None:
            overall[i] = row[0]
            oCorConf += row[4]
            oInCorConf += row[5]
        # as before, whether holdout stats are reported depends on the last row
        haveHoldout = row[1] is not None
        if haveHoldout:
            holdout[i] = row[1]
            hCorConf += row[6]
            hInCorConf += row[7]
        if row[2] is not None:
            tmpOverall = _ParseResultMatrix(row[2])
            if enrich >= 0:
                overallEnrich[i] = ScreenComposite.CalcEnrichment(tmpOverall, tgt=enrich)
            if overallMatrix is None:
                overallMatrix = tmpOverall
            else:
                overallMatrix = overallMatrix + tmpOverall
        if haveHoldout and row[3] is not None:
            tmpHoldout = _ParseResultMatrix(row[3])
            if enrich >= 0:
                holdoutEnrich[i] = ScreenComposite.CalcEnrichment(tmpHoldout, tgt=enrich)
            # accumulate independently so a holdout matrix first seen on a
            # later row does not hit ``None + array``
            if holdoutMatrix is None:
                holdoutMatrix = tmpHoldout
            else:
                holdoutMatrix = holdoutMatrix + tmpHoldout

    avgOverall, devOverall, oBest, oMin = _ErrorSummary(overall)
    res = {
        'oAvg': 100 * avgOverall,
        'oDev': 100 * devOverall,
        'oCorrectConf': 100 * (oCorConf / nPts),
        'oIncorrectConf': 100 * (oInCorConf / nPts),
        'oResultMat': None if overallMatrix is None else overallMatrix / nPts,
        'oBestIdx': oBest,
        'oBestErr': 100 * oMin,
    }
    if enrich >= 0:
        res['oAvgEnrich'], res['oDevEnrich'] = Stats.MeanAndDev(overallEnrich)

    if haveHoldout:
        avgHoldout, devHoldout, hBest, hMin = _ErrorSummary(holdout)
        res['hAvg'] = 100 * avgHoldout
        res['hDev'] = 100 * devHoldout
        res['hCorrectConf'] = 100 * (hCorConf / nPts)
        res['hIncorrectConf'] = 100 * (hInCorConf / nPts)
        if holdoutMatrix is not None:
            res['hResultMat'] = holdoutMatrix / nPts
        res['hBestIdx'] = hBest
        res['hBestErr'] = 100 * hMin
        if enrich >= 0:
            res['hAvgEnrich'], res['hDevEnrich'] = Stats.MeanAndDev(holdoutEnrich)
    return res
193
def _PrintResultMatrix(mat):
    """Print a transposed result matrix with diagonal percentages.

    Each printed row ends with ``100 * diagonal / row total``.  A dashed rule
    follows, then one line of ``100 * diagonal / column total`` values.
    Zero totals are replaced by 1 so empty rows or columns print 0.00.
    """
    tmp = numpy.transpose(mat)
    # explicit axis sums: builtin sum(tmp, 1) would mean "start at 1", not axis 1
    colCounts = numpy.sum(tmp, axis=0)
    rowCounts = numpy.sum(tmp, axis=1)
    colCounts[colCounts == 0] = 1
    rowCounts[rowCounts == 0] = 1
    nRows = len(tmp)
    for i, row in enumerate(tmp):
        cells = ''.join('% 6.2f' % x for x in row)
        print('\t\t' + cells + ('\t| % 4.2f' % (100. * tmp[i, i] / rowCounts[i])))
    print('\t\t' + '------' * nRows)
    print('\t\t' + ''.join('% 6.2f' % (100. * tmp[i, i] / colCounts[i]) for i in range(nRows)))


def ShowStats(statD, enrich=1):
    """Print the error statistics dict produced by ``ErrorStats``.

    Best-run indices are shown 1-based.  The caller's dict is not modified
    because a shallow copy is made first.  Enrichment lines are printed only
    when ``enrich > -1`` and the corresponding ``*AvgEnrich`` key is present.
    """
    statD = statD.copy()
    statD['oBestIdx'] = statD['oBestIdx'] + 1
    txt = """
# Error Statistics:
\tOverall: %(oAvg)6.3f%% (%(oDev)6.3f) %(oCorrectConf)4.1f/%(oIncorrectConf)4.1f
\t\tBest: %(oBestIdx)d %(oBestErr)6.3f%%""" % (statD)
    if 'hAvg' in statD:
        statD['hBestIdx'] = statD['hBestIdx'] + 1
        txt += """
\tHoldout: %(hAvg)6.3f%% (%(hDev)6.3f) %(hCorrectConf)4.1f/%(hIncorrectConf)4.1f
\t\tBest: %(hBestIdx)d %(hBestErr)6.3f%%
""" % (statD)
    print(txt)
    print()
    print('# Results matrices:')
    print('\tOverall:')
    _PrintResultMatrix(statD['oResultMat'])
    if enrich > -1 and 'oAvgEnrich' in statD:
        print('\t\tEnrich(%d): %.3f (%.3f)' % (enrich, statD['oAvgEnrich'], statD['oDevEnrich']))

    if 'hResultMat' in statD:
        print('\tHoldout:')
        _PrintResultMatrix(statD['hResultMat'])
        if enrich > -1 and 'hAvgEnrich' in statD:
            print('\t\tEnrich(%d): %.3f (%.3f)' % (enrich, statD['hAvgEnrich'], statD['hDevEnrich']))
260 261
def Usage():
    """Show the module's usage text and terminate with exit status -1."""
    print(__doc__)
    raise SystemExit(-1)
if __name__ == "__main__":
    import getopt
    try:
        # NOTE: -X is accepted for backwards compatibility but has no effect
        args, extras = getopt.getopt(sys.argv[1:], 'n:d:N:vX', ('skip',
                                                                'enrich=',
                                                                ))
    except getopt.GetoptError:
        Usage()

    count = 3
    db = None
    note = ''
    verbose = 0
    skip = 0
    enrich = 1
    for arg, val in args:
        if arg == '-n':
            count = int(val) + 1
        elif arg == '-d':
            db = val
        elif arg == '-N':
            note = val
        elif arg == '-v':
            verbose = 1
        elif arg == '--skip':
            skip = 1
        elif arg == '--enrich':
            enrich = int(val)

    composites = []
    if db is None:
        for arg in extras:
            # NOTE(review): unpickling executes arbitrary code; only load trusted files
            with open(arg, 'rb') as inF:
                composites.append(cPickle.load(inF))
    else:
        if not extras:
            # a table name is required when reading from a database
            Usage()
        tbl = extras[0]
        conn = DbConnect(db, tbl)
        if note:
            # double embedded quotes so the note cannot terminate the SQL literal
            where = "where note='%s'" % (note.replace("'", "''"))
        else:
            where = ''
        if not skip:
            pkls = conn.GetData(fields='model', where=where)
            for pkl in pkls:
                blob = pkl[0]
                # DB drivers may return buffer/memoryview objects; pickle needs
                # bytes (str() of a blob is its repr on Python 3)
                if not isinstance(blob, bytes):
                    blob = bytes(blob)
                # NOTE(review): unpickling DB contents executes arbitrary code
                composites.append(cPickle.loads(blob))

    if len(composites):
        ProcessIt(composites, count, verbose=verbose)
    elif not skip:
        print('ERROR: no composite models found')
        sys.exit(-1)

    if db:
        res = ErrorStats(conn, where, enrich=enrich)
        if res:
            # pass enrich through so the printed Enrich(...) label matches
            ShowStats(res, enrich=enrich)