Package rdkit ::
Package ML ::
Module GrowComposite
|
|
1
2
3
4
5
6
7
8
9
10
11
12 """ command line utility for growing composite models
13
14 **Usage**
15
16 _GrowComposite [optional args] filename_
17
18 **Command Line Arguments**
19
20 - -n *count*: number of new models to build
21
22 - -C *pickle file name*: name of file containing composite upon which to build.
23
24 - --inNote *note*: note to be used in loading composite models from the database
25 for growing
26
27 - --balTable *table name*: table from which to take the original data set
28 (for balancing)
29
30 - --balWeight *weight*: (between 0 and 1) weighting factor for the new data
31 (for balancing). OR, *weight* can be a list of weights
32
33 - --balCnt *count*: number of individual models in the balanced composite
34 (for balancing)
35
36 - --balH: use only the holdout set from the original data set in the balancing
37 (for balancing)
38
39 - --balT: use only the training set from the original data set in the balancing
40 (for balancing)
41
42 - -S: shuffle the original data set
43 (for balancing)
44
45 - -r: randomize the activities of the original data set
46 (for balancing)
47
48 - -N *note*: note to be attached to the grown composite when it's saved in the
49 database
50
51 - --outNote *note*: equivalent to -N
52
53 - -o *filename*: name of an output file to hold the pickled composite after
54 it has been grown.
55 If multiple balance weights are used, the weights will be added to
56 the filenames.
57
58 - -L *limit*: provide an (integer) limit on individual model complexity
59
60 - -d *database name*: instead of reading the data from a QDAT file,
61 pull it from a database. In this case, the _filename_ argument
62 provides the name of the database table containing the data set.
63
64 - -p *tablename*: store persistence data in the database
65 in table *tablename*
66
67 - -l: locks the random number generator to give consistent sets
68 of training and hold-out data. This is primarily intended
69 for testing purposes.
70
71 - -g: be less greedy when training the models.
72
73 - -G *number*: force trees to be rooted at descriptor *number*.
74
75 - -D: show a detailed breakdown of the composite model performance
76 across the training and, when appropriate, hold-out sets.
77
78 - -t *threshold value*: use high-confidence predictions for the final
79 analysis of the hold-out data.
80
81 - -q *list string*: Add QuantTrees to the composite and use the list
82 specified in *list string* as the number of target quantization
83 bounds for each descriptor. Don't forget to include 0's at the
84 beginning and end of *list string* for the name and value fields.
85 For example, if there are 4 descriptors and you want 2 quant bounds
86 apiece, you would use _-q "[0,2,2,2,2,0]"_.
87 Two special cases:
88 1) If you would like to ignore a descriptor in the model building,
89 use '-1' for its number of quant bounds.
90 2) If you have integer valued data that should not be quantized
91 further, enter 0 for that descriptor.
92
93 - -V: print the version number and exit
94
95 """
96 from __future__ import print_function
97 from rdkit import RDConfig
98 import numpy
99 from rdkit.ML.Data import DataUtils,SplitData
100 from rdkit.ML import ScreenComposite,BuildComposite
101 from rdkit.ML.Composite import AdjustComposite
102 from rdkit.Dbase.DbConnection import DbConnect
103 from rdkit.ML import CompositeRun
104 from rdkit.six.moves import cPickle
105 import sys,time,types
106
# Shared run-details record: the command-line entry point fills it in and
# GrowIt() stores its summary statistics on it.
_runDetails = CompositeRun.CompositeRun()

# Version string reported when the version is printed.
__VERSION_STRING = "0.5.0"

# When nonzero, progress messages are written to stdout.
_verbose = 1
# NOTE(review): the `def` line was missing from this listing; the signature
# was reconstructed from the docstring and the call sites -- confirm.
def message(msg):
    """ emits messages to _sys.stdout_
      override this in modules which import this one to redirect output

      **Arguments**

        - msg: the string to be displayed

      **Notes**

        - nothing is written when the module-level _verbose flag is zero

        - a newline is appended to _msg_

    """
    if _verbose:
        sys.stdout.write('%s\n' % (msg))
122
def GrowIt(details, composite, progressCallback=None,
           saveIt=1, setDescNames=0, data=None):
    """ does the actual work of growing a composite model

      **Arguments**

        - details: a _CompositeRun.CompositeRun_ object containing details
          (options, parameters, etc.) about the run

        - composite: the composite model to grow

        - progressCallback: (optional) handed to the composite's _Grow()_
          method when tree models are built.  Per the original documentation
          it is called with a single argument (the number of models built so
          far) after each model is built.

        - saveIt: (optional) accepted for backwards compatibility; this
          function does not read it and does not pickle anything itself.

        - setDescNames: (optional) if nonzero and tree models are being built,
          the composite's _SetInputOrder()_ method is called using the results
          of the data set's _GetVarNames()_ method.

        - data: (optional) the data set to be used.  If this is not provided,
          the data set described in details will be used.

      **Returns**

        the enlarged composite model

      **Notes**

        - _details.rundate_ is always updated; _details.outName_ and
          _details.tableName_ may be updated when the data set is loaded here.

        - the summary statistics (overall error, confidences, vote table,
          fraction dropped) are stored on the module-level _runDetails
          object, not on the _details_ argument.

    """
    details.rundate = time.asctime()

    if data is None:
        fName = details.tableName.strip()
        if details.outName == '':
            details.outName = fName + '.pkl'
        if details.dbName == '':
            # no database: the name refers to a data file
            data = DataUtils.BuildQuantDataSet(fName)
        elif details.qBounds != []:
            details.tableName = fName
            data = details.GetDataSet()
        else:
            data = DataUtils.DBToQuantData(details.dbName, fName,
                                           quantName=details.qTableName,
                                           user=details.dbUser,
                                           password=details.dbPassword)

    # reuse the seed stored on the composite being grown
    DataUtils.InitRandomNumbers(composite._randomSeed)
    if details.shuffleActivities == 1:
        DataUtils.RandomizeActivities(data, shuffle=1, runDetails=details)
    elif details.randomActivities == 1:
        DataUtils.RandomizeActivities(data, shuffle=0, runDetails=details)

    # the whole data set is used for training; no hold-out split is made here
    namedExamples = data.GetNamedData()
    trainExamples = namedExamples
    nExamples = len(trainExamples)
    message('Training with %d examples' % (nExamples))
    message('\t%d descriptors' % (len(trainExamples[0]) - 2))
    nVars = data.GetNVars()
    nPossibleVals = composite.nPossibleVals
    # a real list, so the value matches what Python 2's range() produced
    attrs = list(range(1, nVars + 1))

    if details.useTrees:
        from rdkit.ML.DecTree import CrossValidate, PruneTree
        if details.qBounds != []:
            from rdkit.ML.DecTree import BuildQuantTree
            builder = BuildQuantTree.QuantTreeBoot
        else:
            from rdkit.ML.DecTree import ID3
            builder = ID3.ID3Boot
        driver = CrossValidate.CrossValidationDriver
        pruner = PruneTree.PruneTree

        if setDescNames:
            composite.SetInputOrder(data.GetVarNames())
        composite.Grow(trainExamples, attrs, [0] + nPossibleVals,
                       buildDriver=driver,
                       pruner=pruner,
                       nTries=details.nModels, pruneIt=details.pruneIt,
                       lessGreedy=details.lessGreedy, needsQuantization=0,
                       treeBuilder=builder, nQuantBounds=details.qBounds,
                       startAt=details.startAt,
                       maxDepth=details.limitDepth,
                       progressCallback=progressCallback,
                       silent=not _verbose)
    else:
        from rdkit.ML.Neural import CrossValidate
        driver = CrossValidate.CrossValidationDriver
        composite.Grow(trainExamples, attrs, [0] + nPossibleVals,
                       nTries=details.nModels,
                       buildDriver=driver, needsQuantization=0)

    composite.AverageErrors()
    composite.SortModels()
    modelList, counts, avgErrs = composite.GetAllData()
    counts = numpy.array(counts)
    avgErrs = numpy.array(avgErrs)
    composite._varNames = data.GetVarNames()

    for model in modelList:
        model.NameModel(composite._varNames)

    # count-weighted mean error and mean absolute deviation from it
    weightedErrs = counts * avgErrs
    averageErr = sum(weightedErrs) / sum(counts)
    devs = numpy.absolute((avgErrs - averageErr) * counts)
    avgDev = sum(devs) / sum(counts)
    if _verbose:
        message('# Overall Average Error: %%% 5.2f, Average Deviation: %%% 6.2f' %
                (100. * averageErr, 100. * avgDev))

    if details.bayesModel:
        composite.Train(trainExamples, verbose=0)

    badExamples = []
    if not details.detailedRes:
        if _verbose:
            message('Testing all examples')
        wrong = BuildComposite.testall(composite, namedExamples, badExamples)
        if _verbose:
            message('%d examples (%% %5.2f) were misclassified' %
                    (len(wrong), 100. * float(len(wrong)) / float(len(namedExamples))))
        _runDetails.overall_error = float(len(wrong)) / len(namedExamples)

    if details.detailedRes:
        if _verbose:
            message('\nEntire data set:')
        resTup = ScreenComposite.ShowVoteResults(list(range(data.GetNPts())), data,
                                                 composite, nPossibleVals[-1],
                                                 details.threshold)
        nGood, nBad, nSkip, avgGood, avgBad, avgSkip, voteTab = resTup
        nPts = len(namedExamples)
        nClass = nGood + nBad
        _runDetails.overall_error = float(nBad) / nClass
        _runDetails.overall_correct_conf = avgGood
        _runDetails.overall_incorrect_conf = avgBad
        _runDetails.overall_result_matrix = repr(voteTab)
        # Points that received no good/bad verdict.  The original computed
        # nClass-nPts, which can never be positive, so the dropped fraction
        # was never recorded.
        nRej = nPts - nClass
        if nRej > 0:
            _runDetails.overall_fraction_dropped = float(nRej) / nPts

    return composite
268
# NOTE(review): the `def` line was missing from this listing; the signature
# was reconstructed from the body and the call site -- confirm.
def GetComposites(details):
    """ loads the composite model(s) which are to be grown

      **Arguments**

        - details: a run-details object.  The attributes read here are
          _persistTblName_, _inNote_, _dbName_ and _composFileName_.

      **Returns**

        a list of unpickled composites:

          - if both _persistTblName_ and _inNote_ are set, one entry per row
            of that table whose note matches _inNote_

          - otherwise, if _composFileName_ is set, the single composite
            pickled in that file

          - otherwise an empty list

    """
    res = []
    if details.persistTblName and details.inNote:
        conn = DbConnect(details.dbName, details.persistTblName)
        # NOTE(review): the note is interpolated straight into the SQL text;
        # a note containing a quote breaks (or alters) the query.
        mdls = conn.GetData(fields='MODEL', where="where note='%s'" % (details.inNote))
        for row in mdls:
            rawD = row[0]
            # NOTE(review): unpickling executes code from the stored blob; only
            # use with trusted databases.  str() looks Python 2 specific --
            # confirm what type the MODEL column comes back as.
            res.append(cPickle.loads(str(rawD)))
    elif details.composFileName:
        # close the file deterministically instead of leaking the handle
        with open(details.composFileName, 'rb') as inF:
            res.append(cPickle.load(inF))
    return res
280
# NOTE(review): the `def` line was missing from this listing; the signature
# was reconstructed from the docstring, the `is None` checks and the call
# sites -- confirm.
def BalanceComposite(details, composite, data1=None, data2=None):
    """ balances the composite using the parameters provided in details

      **Arguments**

        - details: a _CompositeRun.RunDetails_ object

        - composite: the composite model to be balanced

        - data1: (optional) if provided, this should be the
          data set used to construct the original models

        - data2: (optional) if provided, this should be the
          data set used to construct the new individual models

      **Returns**

        - a list with one balanced composite per entry of _details.balWeight_

        - the unmodified _composite_ itself (NOT a list) when no balancing is
          done: _details.balCnt_ is zero or larger than the composite, or one
          of the data sets could not be obtained

      **Notes**

        - _details.splitFrac_ and _details.randomSeed_ are overwritten with
          the values stored on the composite.

    """
    if not details.balCnt or details.balCnt > len(composite):
        return composite
    message("Balancing Composite")

    # ---- first data set: the data behind the original models ----
    if data1 is None:
        message("\tReading First Data Set")
        origTable, origDb = details.tableName, details.dbName
        details.tableName = details.balTable.strip()
        details.dbName = details.balDb
        try:
            data1 = details.GetDataSet()
        finally:
            # always put the caller's table/database names back
            details.tableName = origTable
            details.dbName = origDb
    if data1 is None:
        return composite
    details.splitFrac = composite._splitFrac
    details.randomSeed = composite._randomSeed
    DataUtils.InitRandomNumbers(details.randomSeed)
    if details.shuffleActivities == 1:
        DataUtils.RandomizeActivities(data1, shuffle=1, runDetails=details)
    elif details.randomActivities == 1:
        DataUtils.RandomizeActivities(data1, shuffle=0, runDetails=details)
    namedExamples = data1.GetNamedData()
    if details.balDoHoldout or details.balDoTrain:
        trainIdx, testIdx = SplitData.SplitIndices(len(namedExamples), details.splitFrac,
                                                   silent=1)
        trainExamples = [namedExamples[x] for x in trainIdx]
        testExamples = [namedExamples[x] for x in testIdx]
        if details.filterFrac != 0.0:
            keepIdx, dropIdx = DataUtils.FilterData(trainExamples, details.filterVal,
                                                    details.filterFrac, -1,
                                                    indicesOnly=1)
            kept = [trainExamples[x] for x in keepIdx]
            # points filtered out of the training set join the test set
            testExamples += [trainExamples[x] for x in dropIdx]
            trainExamples = kept
        if details.balDoHoldout:
            # balance against the hold-out points instead of the training points
            testExamples, trainExamples = trainExamples, testExamples
    else:
        trainExamples = namedExamples
    dataSet1 = trainExamples
    cols1 = [x.upper() for x in data1.GetVarNames()]
    data1 = None  # drop the reference to the full data set

    # ---- second data set: the data behind the new models ----
    if data2 is None:
        message("\tReading Second Data Set")
        data2 = details.GetDataSet()
    if data2 is None:
        return composite
    details.splitFrac = composite._splitFrac
    details.randomSeed = composite._randomSeed
    DataUtils.InitRandomNumbers(details.randomSeed)
    if details.shuffleActivities == 1:
        DataUtils.RandomizeActivities(data2, shuffle=1, runDetails=details)
    elif details.randomActivities == 1:
        DataUtils.RandomizeActivities(data2, shuffle=0, runDetails=details)
    dataSet2 = data2.GetNamedData()
    cols2 = [x.upper() for x in data2.GetVarNames()]
    data2 = None  # drop the reference to the full data set

    res = []
    weights = details.balWeight
    # isinstance() instead of types.TupleType/ListType, which are gone in
    # Python 3
    if not isinstance(weights, (tuple, list)):
        weights = (weights, )
    for weight in weights:
        message("\tBalancing with Weight: %.4f" % (weight))
        res.append(AdjustComposite.BalanceComposite(composite, dataSet1, dataSet2,
                                                    weight,
                                                    details.balCnt,
                                                    names1=cols1, names2=cols2))
    return res
376
# NOTE(review): the `def` line was missing from this listing.  The call sites
# show that includeArgs is optional; the default of 0 is an assumption --
# confirm.
def ShowVersion(includeArgs=0):
    """ prints the version number

      **Arguments**

        - includeArgs: (optional) if nonzero, the command line
          (_sys.argv_) is printed as well

    """
    print('This is GrowComposite.py version %s' % (__VERSION_STRING))
    if includeArgs:
        import sys
        print('command line was:')
        print(' '.join(sys.argv))
386
# NOTE(review): the `def` line was missing from this listing; reconstructed
# from the call sites -- confirm.
def Usage():
    """ provides a list of arguments for when this is used from the command line

      Prints the module docstring and then exits the interpreter with
      status -1; this function does not return.

    """
    import sys
    print(__doc__)
    sys.exit(-1)
394
# NOTE(review): the `def` line was missing from this listing; reconstructed
# from the body and the call site -- confirm.
def SetDefaults(runDetails=None):
    """ initializes a details object with default values

      **Arguments**

        - runDetails: (optional) a _CompositeRun.CompositeRun_ object.
          If this is not provided, the global _runDetails will be used.

      **Returns**

        whatever _CompositeRun.SetDefaults()_ returns for that object
        (documented as the initialized _CompositeRun_ object).

    """
    if runDetails is None:
        runDetails = _runDetails
    return CompositeRun.SetDefaults(runDetails)
411
# NOTE(review): the `def` line was missing from this listing; reconstructed
# from the docstring and the call site -- confirm.
def ParseArgs(runDetails):
    """ parses command line arguments and updates _runDetails_

      **Arguments**

        - runDetails: a _CompositeRun.CompositeRun_ object.

      **Notes**

        - the arguments are read from _sys.argv_

        - the balancing-related attributes (_inNote_, _composFileName_,
          _balTable_, _balWeight_, _balCnt_, _balDoHoldout_, _balDoTrain_,
          _balDb_) are reset to their defaults before the options are applied

        - the first non-option argument becomes _runDetails.tableName_; if
          there is none, the usage message is shown and the program exits

        - _-h_ and unhandled options show the usage message and exit; _-V_
          prints the version and exits

        - the values of _--balWeight_, _-q_ and _-Q_ are passed to _eval()_;
          only use them with trusted input

    """
    import getopt
    args, extra = getopt.getopt(sys.argv[1:], 'P:o:n:p:b:sf:F:v:hlgd:rSTt:Q:q:DVG:L:C:N:',
                                ['inNote=', 'outNote=', 'balTable=', 'balWeight=', 'balCnt=',
                                 'balH', 'balT', 'balDb=', ])
    runDetails.inNote = ''
    runDetails.composFileName = ''
    runDetails.balTable = ''
    runDetails.balWeight = (0.5, )
    runDetails.balCnt = 0
    runDetails.balDoHoldout = 0
    runDetails.balDoTrain = 0
    runDetails.balDb = ''
    for arg, val in args:
        if arg == '-n':
            runDetails.nModels = int(val)
        elif arg == '-C':
            runDetails.composFileName = val
        elif arg == '--balTable':
            runDetails.balTable = val
        elif arg == '--balWeight':
            # NOTE(review): eval() of command-line text; trusted input only
            runDetails.balWeight = eval(val)
            # isinstance() instead of types.TupleType/ListType, which do not
            # exist on Python 3
            if not isinstance(runDetails.balWeight, (tuple, list)):
                runDetails.balWeight = (runDetails.balWeight, )
        elif arg == '--balCnt':
            runDetails.balCnt = int(val)
        elif arg == '--balH':
            runDetails.balDoHoldout = 1
        elif arg == '--balT':
            runDetails.balDoTrain = 1
        elif arg == '--balDb':
            runDetails.balDb = val
        elif arg == '--inNote':
            runDetails.inNote = val
        elif arg == '-N' or arg == '--outNote':
            runDetails.note = val
        elif arg == '-o':
            runDetails.outName = val
        elif arg == '-p':
            runDetails.persistTblName = val
        elif arg == '-r':
            runDetails.randomActivities = 1
        elif arg == '-S':
            runDetails.shuffleActivities = 1
        elif arg == '-h':
            Usage()
        elif arg == '-l':
            runDetails.lockRandom = 1
        elif arg == '-g':
            runDetails.lessGreedy = 1
        elif arg == '-G':
            runDetails.startAt = int(val)
        elif arg == '-d':
            runDetails.dbName = val
        elif arg == '-T':
            runDetails.useTrees = 0
        elif arg == '-t':
            runDetails.threshold = float(val)
        elif arg == '-D':
            runDetails.detailedRes = 1
        elif arg == '-L':
            runDetails.limitDepth = int(val)
        elif arg == '-q':
            # NOTE(review): eval() of command-line text; trusted input only
            qBounds = eval(val)
            assert isinstance(qBounds, (tuple, list)), \
                'bad argument type for -q, specify a list as a string'
            # keep both the raw string and the parsed bounds
            runDetails.qBoundCount = val
            runDetails.qBounds = qBounds
        elif arg == '-Q':
            # NOTE(review): eval() of command-line text; trusted input only
            qBounds = eval(val)
            assert isinstance(qBounds, (tuple, list)), \
                'bad argument type for -Q, specify a list as a string'
            runDetails.activityBounds = qBounds
            runDetails.activityBoundsVals = val
        elif arg == '-V':
            ShowVersion()
            sys.exit(0)
        else:
            # accepted by the getopt spec above but not handled here
            print('bad argument:', arg, file=sys.stderr)
            Usage()
    if not extra:
        # options were given but no table/file name: explain rather than
        # fail with an IndexError
        print('no data file or table name provided', file=sys.stderr)
        Usage()
    runDetails.tableName = extra[0]
    if not runDetails.balDb:
        # the balancing data defaults to the main database
        runDetails.balDb = runDetails.dbName
def _GrowAndSave(details, initModel):
    """ grows one composite, optionally balances it, then saves the results

      **Arguments**

        - details: the run-details object driving the run

        - initModel: the composite to be grown

      **Notes**

        - with a single resulting composite it is pickled to
          _details.outName_; with several, each one is pickled to a file whose
          name carries its balance weight (as a percentage).  When several
          initial models are processed in one run they all use the same
          output name(s), so later ones overwrite earlier ones.

        - if both _details.persistTblName_ and _details.dbName_ are set, each
          resulting composite is also stored in that table.

    """
    composite = GrowIt(details, initModel, setDescNames=1)
    if details.balTable and details.balCnt:
        composites = BalanceComposite(details, composite)
        # BalanceComposite() hands back the bare composite (not a list) when
        # it declines to balance
        if composites is composite:
            composites = [composite]
    else:
        composites = [composite]
    for mdl in composites:
        mdl.ClearModelExamples()

    if details.outName:
        # Decide on the number of composites actually produced; indexing by
        # the number of weights fails when no balancing was done.
        if len(composites) == 1:
            composites[0].Pickle(details.outName)
        else:
            baseName = details.outName.split('.pkl')[0]
            for weight, model in zip(details.balWeight, composites):
                model.Pickle('%s.%d.pkl' % (baseName, int(100 * weight)))

    if details.persistTblName and details.dbName:
        message('Updating results table %s:%s' % (details.dbName, details.persistTblName))
        if len(composites) > 1:
            message('WARNING: updating results table with models having different weights')
        for model in composites:
            details.model = cPickle.dumps(model)
            details.Store(db=details.dbName, table=details.persistTblName)


if __name__ == '__main__':
    if len(sys.argv) < 2:
        Usage()

    _runDetails.cmd = ' '.join(sys.argv)
    SetDefaults(_runDetails)
    ParseArgs(_runDetails)

    ShowVersion(includeArgs=1)

    initModels = GetComposites(_runDetails)
    nModels = len(initModels)
    if not nModels:
        message("No models found")
    for idx, initModel in enumerate(initModels):
        if nModels > 1:
            # progress banner, only shown when there are several models
            sys.stderr.write(
                '---------------------------------\n\tDoing %d of %d\n---------------------------------\n' %
                (idx + 1, nModels))
        _GrowAndSave(_runDetails, initModel)
569