Package rdkit :: Package ML :: Module EnrichPlot
[hide private]
[frames | no frames]

Source Code for Module rdkit.ML.EnrichPlot

  1  # $Id$ 
  2  # 
  3  #  Copyright (C) 2002-2006  Greg Landrum and Rational Discovery LLC 
  4  # 
  5  #   @@ All Rights Reserved @@ 
  6  #  This file is part of the RDKit. 
  7  #  The contents are covered by the terms of the BSD license 
  8  #  which is included in the file license.txt, found at the root 
  9  #  of the RDKit source tree. 
 10  # 
 11   
 12  """Command line tool to construct an enrichment plot from saved composite models 
 13   
 14  Usage:  EnrichPlot [optional args] -d dbname -t tablename <models> 
 15   
 16  Required Arguments: 
 17    -d "dbName": the name of the database for screening 
 18   
 19    -t "tablename": provide the name of the table with the data to be screened 
 20   
 21    <models>: file name(s) of pickled composite model(s). 
 22       If the -p argument is also provided (see below), this argument is ignored. 
 23        
 24  Optional Arguments: 
 25    - -a "list": the list of result codes to be considered active.  This will be 
 26          eval'ed, so be sure that it evaluates as a list or sequence of 
 27          integers. For example, -a "[1,2]" will consider activity values 1 and 2 
 28          to be active 
 29   
 30    - --enrich "list": identical to the -a argument above.       
 31   
 32    - --thresh: sets a threshold for the plot.  If the confidence falls below 
 33            this value, picking will be terminated 
 34   
 35    - -H: screen only the hold out set (works only if a version of  
 36          BuildComposite more recent than 1.2.2 was used). 
 37   
 38    - -T: screen only the training set (works only if a version of  
 39          BuildComposite more recent than 1.2.2 was used). 
 40   
 41    - -S: shuffle activity values before screening 
 42   
 43    - -R: randomize activity values before screening 
 44   
 45    - -F *filter frac*: filters the data before training to change the 
 46       distribution of activity values in the training set.  *filter frac* 
 47       is the fraction of the training set that should have the target value. 
 48       **See note in BuildComposite help about data filtering** 
 49   
 50    - -v *filter value*: filters the data before training to change the 
 51       distribution of activity values in the training set. *filter value* 
 52       is the target value to use in filtering. 
 53       **See note in BuildComposite help about data filtering** 
 54   
 55    - -p "tableName": provides the name of a db table containing the 
 56        models to be screened.  If you use this argument, you should also 
 57        use the -N argument (below) to specify a note value. 
 58         
 59    - -N "note": provides a note to be used to pull models from a db table. 
 60   
 61    - --plotFile "filename": writes the data to an output text file (filename.dat) 
 62      and creates a gnuplot input file (filename.gnu) to plot it 
 63   
 64    - --showPlot: causes the gnuplot plot constructed using --plotFile to be 
 65      displayed in gnuplot. 
 66   
 67  """ 
 68  from __future__ import print_function 
 69  from rdkit import RDConfig 
 70  import numpy 
 71  import copy 
 72  from rdkit.six.moves import cPickle 
 73  #from rdkit.Dbase.DbConnection import DbConnect 
 74  from rdkit.ML.Data import DataUtils,SplitData,Stats 
 75  from rdkit.Dbase.DbConnection import DbConnect 
 76  from rdkit import DataStructs 
 77  from rdkit.ML import CompositeRun 
 78  import sys,os,types 
 79  from rdkit.six import cmp 
 80   
 81  __VERSION_STRING="2.4.0" 
def message(msg, noRet=0, dest=sys.stderr):
    """ writes a status message to a stream (_sys.stderr_ unless told otherwise)

      replace this function in modules which import this one to redirect output

      **Arguments**

        - msg: the string to be displayed

        - noRet: (Optional) if true, the message is followed by a single
          space instead of a newline

        - dest: (Optional) the file-like object that receives the message

    """
    # Pick the template first, then format once; msg is applied with '%'
    # exactly as before so unusual msg values format the same way.
    template = '%s ' if noRet else '%s\n'
    dest.write(template % (msg))
def error(msg, dest=sys.stderr):
    """ emits an error message, prefixed with "ERROR: ", to a stream

      override this in modules which import this one to redirect output

      **Arguments**

        - msg: the string to be displayed

        - dest: (Optional) the file-like object that receives the message;
          defaults to _sys.stderr_

    """
    # The original ignored `dest` and always wrote to sys.stderr, so callers
    # could not redirect error output the way they can with message().
    dest.write('ERROR: %s\n' % (msg))
105
def ScreenModel(mdl, descs, data, picking=[1], indices=[], errorEstimate=0):
    """ collects the results of screening an individual composite model that match
      a particular value

      **Arguments**

        - mdl: the composite model

        - descs: a list of descriptor names corresponding to the data set

        - data: the data set, a list of points to be screened.

        - picking: (Optional) a list of values that are to be collected.
          For example, if you want an enrichment plot for picking the values
          1 and 2, you'd have picking=[1,2].

        - indices: (Optional) the indices of the points in _data_ to screen.
          If empty, every point is screened.

        - errorEstimate: (Optional) if true, each point is classified using
          only the sub-models whose _trainIndices do not contain that point.

      **Returns**

        a 2-tuple:

          1) the number of screened points whose true result is in _picking_

          2) a list of 4-tuples, one for each point whose *predicted* result
             is in _picking_, containing:

             - the id of the point (its first element)

             - the true result (from the data set)

             - the predicted result

             - the confidence value for the prediction

      **Notes**

        - the list defaults for _picking_ and _indices_ are never mutated here.

        - any sequence-valued _trainIndices on a sub-model is replaced, in
          place, by a dict keyed on those indices.

    """
    mdl.SetInputOrder(descs)

    nModels = len(mdl)
    for j in range(nModels):
        tmp = mdl.GetModel(j)
        # isinstance(..., dict) replaces the Python-2-only types.DictType check,
        # which raises AttributeError on Python 3.
        if hasattr(tmp, '_trainIndices') and not isinstance(tmp._trainIndices, dict):
            # a dict gives constant-time membership tests in the loop below
            tmp._trainIndices = dict.fromkeys(tmp._trainIndices, 1)

    needsQuant = 1 if mdl.GetQuantBounds() else 0

    if not indices:
        indices = range(len(data))

    res = []
    nTrueActives = 0
    for i in indices:
        if errorEstimate:
            # only use the sub-models that did not see point i during training
            use = [j for j in range(nModels)
                   if not mdl.GetModel(j)._trainIndices.get(i, 0)]
        else:
            use = None
        pt = data[i]
        pred, conf = mdl.ClassifyExample(pt, onlyModels=use)
        if needsQuant:
            # quantize a copy so the caller's data point is left untouched
            pt = mdl.QuantizeActivity(pt[:])
        trueRes = pt[-1]
        if trueRes in picking:
            nTrueActives += 1
        if pred in picking:
            res.append((pt[0], trueRes, pred, conf))
    return nTrueActives, res
172
def AccumulateCounts(predictions, thresh=0, sortIt=1):
    """ Accumulates the data for the enrichment plot for a single model

      **Arguments**

        - predictions: a list of 4-tuples (id, true result, predicted result,
          confidence), as found in the list returned by _ScreenModel_

        - thresh: a threshold for the confidence level.  Only predictions
          with a confidence strictly greater than this are considered

        - sortIt: toggles sorting on confidence levels (highest first)

      **Returns**

        - a list of 3-tuples, one per prediction that passed the threshold:

          - the id of the point picked here

          - num correct predictions found so far

          - number of picks made so far

      **Notes**

        - when _sortIt_ is true, _predictions_ is sorted in place.

    """
    if sortIt:
        # list.sort() no longer takes a cmp function on Python 3.  A key with
        # reverse=True gives the same descending order and is still stable
        # (equal confidences keep their original relative order).
        predictions.sort(key=lambda entry: entry[3], reverse=True)
    res = []
    nCorrect = 0
    nPts = 0
    for ptId, real, pred, conf in predictions:
        if conf > thresh:
            if pred == real:
                nCorrect += 1
            nPts += 1
            res.append((ptId, nCorrect, nPts))

    return res
211
def MakePlot(details, final, counts, pickVects, nModels, nTrueActs=-1):
    """ writes the enrichment data and a gnuplot script to plot it

      Nothing is done unless _details.plotFile_ is set.  Otherwise two files
      are written: "<plotFile>.dat" (the data) and "<plotFile>.gnu" (the
      gnuplot commands).  If _details.showPlot_ is set, the script is also
      loaded into gnuplot via the Gnuplot python module.

      **Arguments**

        - details: an object whose _plotFile_ and _showPlot_ attributes
          control the output

        - final: a sequence of pairs of accumulated values, indexed by pick
          number; each is divided by the matching entry of _counts_ on output

        - counts: the number of contributions made to each row of _final_.
          Output stops at the first row with a count of zero

        - pickVects: a mapping from pick number to the list of individual
          values for that pick; only used if _nModels_ > 1

        - nModels: the number of models that were screened.  If more than
          one, a 90% confidence interval is added as a fifth data column and
          drawn as error bars

        - nTrueActs: (Optional) if positive, used as the upper limit of the
          y axis

    """
    if not getattr(details, 'plotFile', None):
        return

    dataFileName = '%s.dat' % (details.plotFile)
    i = 0
    # context managers close the files even if a write fails part way through
    with open(dataFileName, 'w+') as outF:
        while i < len(final) and counts[i] != 0:
            if nModels > 1:
                _, sd = Stats.MeanAndDev(pickVects[i])
                confInterval = Stats.GetConfidenceInterval(sd, len(pickVects[i]), level=90)
                outF.write('%d %f %f %d %f\n' % (i + 1, final[i][0] / counts[i],
                                                 final[i][1] / counts[i], counts[i],
                                                 confInterval))
            else:
                outF.write('%d %f %f %d\n' % (i + 1, final[i][0] / counts[i],
                                              final[i][1] / counts[i], counts[i]))
            i += 1

    plotFileName = '%s.gnu' % (details.plotFile)
    gnuHdr = '\n'.join((
        '# Generated by EnrichPlot.py version: %s',
        'set size square 0.7',
        'set xr [0:]',
        'set data styl points',
        "set ylab 'Num Correct Picks'",
        "set xlab 'Num Picks'",
        'set grid',
        'set nokey',
        'set term postscript enh color solid "Helvetica" 16',
        'set term X',
        '',
    )) % (__VERSION_STRING)
    with open(plotFileName, 'w+') as gnuF:
        print(gnuHdr, file=gnuF)
        if nTrueActs > 0:
            print('set yr [0:%d]' % nTrueActs, file=gnuF)
        print('plot x with lines', file=gnuF)
        if nModels > 1:
            # Draw an error bar on roughly every 20th point.  Integer division
            # keeps this an int on Python 3, and the floor of 1 avoids asking
            # gnuplot for "every 0" when there are fewer than 20 points
            # (presumably not a valid increment - confirm against gnuplot docs).
            everyGap = max(1, i // 20)
            print('replot "%s" using 1:2 with lines,' % (dataFileName), end='', file=gnuF)
            print('"%s" every %d using 1:2:5 with yerrorbars' % (dataFileName, everyGap),
                  file=gnuF)
        else:
            print('replot "%s" with points' % (dataFileName), file=gnuF)

    if getattr(details, 'showPlot', None):
        try:
            from Gnuplot import Gnuplot
            p = Gnuplot()
            p('load "%s"' % (plotFileName))
            # raw_input() only exists on Python 2; fall back to input() on 3
            try:
                prompt = raw_input
            except NameError:
                prompt = input
            prompt('press return to continue...\n')
        except Exception:
            # displaying the plot is best-effort: report the problem and carry on
            import traceback
            traceback.print_exc()
267 268 269 270
def Usage():
    """ displays a usage message (the module docstring) on stderr and exits
      with status -1

    """
    # __doc__ is None when docstrings are stripped (e.g. python -OO); writing
    # None would raise TypeError and hide the intended exit.
    sys.stderr.write(__doc__ or '')
    sys.exit(-1)
if __name__ == '__main__':
    # Command-line driver: load the data set and the models, screen the data
    # with each model, accumulate the enrichment counts and then either write
    # gnuplot files or print a tab-separated table.
    import getopt
    try:
        # NOTE(review): the module docstring advertises -R, but it is not in
        # this option string; 'c' is accepted but never handled below.
        args, extras = getopt.getopt(sys.argv[1:], 'd:t:a:N:p:cSTHF:v:',
                                     ('thresh=', 'plotFile=', 'showPlot',
                                      'pickleCol=', 'OOB', 'noSort', 'pickBase=',
                                      'doROC', 'rocThresh=', 'enrich='))
    except Exception:
        import traceback
        traceback.print_exc()
        Usage()

    details = CompositeRun.CompositeRun()
    CompositeRun.SetDefaults(details)

    # defaults specific to this tool
    details.activeTgt = [1]
    details.doTraining = 0
    details.doHoldout = 0
    details.dbTableName = ''
    details.plotFile = ''
    details.showPlot = 0
    details.pickleCol = -1
    details.errorEstimate = 0
    details.sortIt = 1
    details.pickBase = ''
    details.doROC = 0
    details.rocThresh = -1
    for arg, val in args:
        if arg == '-d':
            details.dbName = val
        elif arg == '-t':
            details.dbTableName = val
        elif arg == '-a' or arg == '--enrich':
            # NOTE(review): eval() of a command-line string executes arbitrary
            # code; only acceptable because the value comes from the operator.
            details.activeTgt = eval(val)
            # isinstance replaces types.TupleType/ListType (Python 2 only)
            if not isinstance(details.activeTgt, (tuple, list)):
                details.activeTgt = (details.activeTgt,)
        elif arg == '--thresh':
            details.threshold = float(val)
        elif arg == '-N':
            details.note = val
        elif arg == '-p':
            details.persistTblName = val
        elif arg == '-S':
            details.shuffleActivities = 1
        elif arg == '-H':
            details.doTraining = 0
            details.doHoldout = 1
        elif arg == '-T':
            details.doTraining = 1
            details.doHoldout = 0
        elif arg == '-F':
            details.filterFrac = float(val)
        elif arg == '-v':
            details.filterVal = float(val)
        elif arg == '--plotFile':
            details.plotFile = val
        elif arg == '--showPlot':
            details.showPlot = 1
        elif arg == '--pickleCol':
            # the command line is 1-based, the stored column index is 0-based
            details.pickleCol = int(val) - 1
        elif arg == '--OOB':
            details.errorEstimate = 1
        elif arg == '--noSort':
            details.sortIt = 0
        elif arg == '--doROC':
            details.doROC = 1
        elif arg == '--rocThresh':
            details.rocThresh = int(val)
        elif arg == '--pickBase':
            details.pickBase = val

    if not details.dbName or not details.dbTableName:
        # Usage() exits, so the explanation has to be printed first (it used
        # to come after the call and was never shown).
        print('*******Please provide both the -d and -t arguments')
        Usage()

    message('Building Data set\n')
    dataSet = DataUtils.DBToData(details.dbName, details.dbTableName,
                                 user=RDConfig.defaultDBUser,
                                 password=RDConfig.defaultDBPassword,
                                 pickleCol=details.pickleCol,
                                 pickleClass=DataStructs.ExplicitBitVect)

    descs = dataSet.GetVarNames()
    nPts = dataSet.GetNPts()
    message('npts: %d\n' % (nPts))
    # The builtin float/int are used as dtypes: the numpy.float alias has been
    # removed from NumPy and the abstract numpy.integer is deprecated as a dtype.
    final = numpy.zeros((nPts, 2), float)
    counts = numpy.zeros(nPts, int)
    selPts = [None] * nPts

    models = []
    if details.persistTblName:
        conn = DbConnect(details.dbName, details.persistTblName)
        message('-> Retrieving models from database')
        curs = conn.GetCursor()
        # NOTE(review): the query is built by string interpolation, so the
        # table name and note are not escaped; parameterise it once the
        # driver's paramstyle is known.
        curs.execute("select model from %s where note='%s'" % (details.persistTblName,
                                                               details.note))
        message('-> Reconstructing models')
        try:
            blob = curs.fetchone()
        except Exception:
            blob = None
        while blob:
            message(' Building model %d' % len(models))
            blob = blob[0]
            try:
                # bytes() matches str() on Python 2 and keeps the raw pickle
                # bytes on Python 3, where str(bytes) yields "b'...'" text.
                # NOTE(review): unpickling runs code from the database; only
                # use trusted tables.
                models.append(cPickle.loads(bytes(blob)))
            except Exception:
                import traceback
                traceback.print_exc()
                print('Model failed')
            else:
                message(' <-Done')
            try:
                blob = curs.fetchone()
            except Exception:
                blob = None
        curs = None
    else:
        for modelName in extras:
            try:
                # NOTE(review): unpickling runs code from the file; only load
                # trusted model files.
                with open(modelName, 'rb') as modelFile:
                    model = cPickle.load(modelFile)
            except Exception:
                import traceback
                print('problems with model %s:' % modelName)
                traceback.print_exc()
            else:
                models.append(model)

    nModels = len(models)
    pickVects = {}
    halfwayPts = [1e8] * len(models)
    # keeps the MakePlot() call below from hitting an unbound name when no
    # model could be loaded
    nTrueActives = 0
    for whichModel, model in enumerate(models):
        tmpD = dataSet
        try:
            seed = model._randomSeed
        except AttributeError:
            pass
        else:
            # presumably reproduces the random state used to build the model,
            # so that the split below matches - confirm against BuildComposite
            DataUtils.InitRandomNumbers(seed)
        if details.shuffleActivities:
            DataUtils.RandomizeActivities(tmpD, shuffle=1)
        if hasattr(model, '_splitFrac') and (details.doHoldout or details.doTraining):
            trainIdx, testIdx = SplitData.SplitIndices(tmpD.GetNPts(), model._splitFrac,
                                                       silent=1)
            if details.filterFrac != 0.0:
                trainFilt, temp = DataUtils.FilterData(tmpD, details.filterVal,
                                                       details.filterFrac, -1,
                                                       indicesToUse=trainIdx,
                                                       indicesOnly=1)
                # the second value returned by FilterData is added to the
                # points that get screened
                testIdx += temp
                trainIdx = trainFilt
            if details.doTraining:
                # -T: screen the training set instead of the hold-out set
                testIdx, trainIdx = trainIdx, testIdx
        else:
            testIdx = range(tmpD.GetNPts())

        message('screening %d examples' % (len(testIdx)))
        nTrueActives, screenRes = ScreenModel(model, descs, tmpD, picking=details.activeTgt,
                                              indices=testIdx,
                                              errorEstimate=details.errorEstimate)
        message('accumulating')
        runningCounts = AccumulateCounts(screenRes,
                                         sortIt=details.sortIt,
                                         thresh=details.threshold)
        if details.pickBase:
            pickFile = open('%s.%d.picks' % (details.pickBase, whichModel + 1), 'w+')
        else:
            pickFile = None

        try:
            for i, entry in enumerate(runningCounts):
                selPts[i] = entry[0]
                final[i][0] += entry[1]
                final[i][1] += entry[2]
                pickVects.setdefault(i, []).append(entry[1])
                counts[i] += 1
                if pickFile:
                    pickFile.write('%s\n' % (entry[0]))
                # record the smallest pick count at which the running count
                # reaches half of nTrueActives
                if entry[1] >= nTrueActives / 2 and entry[2] < halfwayPts[whichModel]:
                    halfwayPts[whichModel] = entry[2]
                    message('Halfway point: %d\n' % halfwayPts[whichModel])
        finally:
            # the pick file used to be left open for every model
            if pickFile:
                pickFile.close()

    if details.plotFile:
        MakePlot(details, final, counts, pickVects, nModels, nTrueActs=nTrueActives)
    else:
        if nModels > 1:
            print('#Index\tAvg_num_correct\tConf90Pct\tAvg_num_picked\tNum_picks\tlast_selection')
        else:
            print('#Index\tAvg_num_correct\tAvg_num_picked\tNum_picks\tlast_selection')

        i = 0
        while i < nPts and counts[i] != 0:
            if nModels > 1:
                _, sd = Stats.MeanAndDev(pickVects[i])
                confInterval = Stats.GetConfidenceInterval(sd, len(pickVects[i]), level=90)
                print('%d\t%f\t%f\t%f\t%d\t%s' % (i + 1, final[i][0] / counts[i], confInterval,
                                                  final[i][1] / counts[i],
                                                  counts[i], str(selPts[i])))
            else:
                print('%d\t%f\t%f\t%d\t%s' % (i + 1, final[i][0] / counts[i],
                                              final[i][1] / counts[i],
                                              counts[i], str(selPts[i])))
            i += 1

    mean, sd = Stats.MeanAndDev(halfwayPts)
    print('Halfway point: %.2f(%.2f)' % (mean, sd))