Source Code for Module rdkit.ML.BuildComposite

# $Id$
#
#  Copyright (C) 2000-2008  greg Landrum and Rational Discovery LLC
#
#   @@ All Rights Reserved @@
#  This file is part of the RDKit.
#  The contents are covered by the terms of the BSD license
#  which is included in the file license.txt, found at the root
#  of the RDKit source tree.
#
""" command line utility for building composite models

#DOC

**Usage**

  BuildComposite [optional args] filename

Unless indicated otherwise (via command line arguments), _filename_ is
a QDAT file.

**Command Line Arguments**

  - -o *filename*: name of the output file for the pickled composite

  - -n *num*: number of separate models to add to the composite

  - -p *tablename*: store persistence data in the database
     in table *tablename*

  - -N *note*: attach some arbitrary text to the persistence data

  - -b *filename*: name of the text file to hold examples from the
     holdout set which are misclassified

  - -s: split the data into training and hold-out sets before building
     the composite

  - -f *frac*: the fraction of data to use in the training set when the
     data is split

  - -r: randomize the activities (for testing purposes).  This ignores
     the initial distribution of activity values and produces each
     possible activity value with equal likelihood.

  - -S: shuffle the activities (for testing purposes).  This produces
     a permutation of the input activity values.

  - -l: locks the random number generator to give consistent sets
     of training and hold-out data.  This is primarily intended
     for testing purposes.

  - -B: use a so-called Bayesian composite model.

  - -d *database name*: instead of reading the data from a QDAT file,
     pull it from a database.  In this case, the _filename_ argument
     provides the name of the database table containing the data set.

  - -D: show a detailed breakdown of the composite model performance
     across the training and, when appropriate, hold-out sets.

  - -P *pickle file name*: write out the pickled data set to the file

  - -F *filter frac*: filters the data before training to change the
     distribution of activity values in the training set.  *filter
     frac* is the fraction of the training set that should have the
     target value.  **See note below on data filtering.**

  - -v *filter value*: filters the data before training to change the
     distribution of activity values in the training set. *filter
     value* is the target value to use in filtering.  **See note below
     on data filtering.**

  - --modelFiltFrac *model filter frac*: Similar to filter frac above,
     in this case the data is filtered for each model in the composite
     rather than a single overall filter for a composite. *model
     filter frac* is the fraction of the training set for each model
     that should have the target value (*model filter value*).

  - --modelFiltVal *model filter value*: target value to use for
     filtering data before training each model in the composite.

  - -t *threshold value*: use high-confidence predictions for the
     final analysis of the hold-out data.

  - -Q *list string*: the values of quantization bounds for the
     activity value.  See the _-q_ argument for the format of *list
     string*.

  - --nRuns *count*: build *count* composite models

  - --prune: prune any models built

  - -h: print a usage message and exit.

  - -V: print the version number and exit

  *-*-*-*-*-*-*-*- Tree-Related Options -*-*-*-*-*-*-*-*

  - -g: be less greedy when training the models.

  - -G *number*: force trees to be rooted at descriptor *number*.

  - -L *limit*: provide an (integer) limit on individual model
     complexity

  - -q *list string*: Add QuantTrees to the composite and use the list
     specified in *list string* as the number of target quantization
     bounds for each descriptor.  Don't forget to include 0's at the
     beginning and end of *list string* for the name and value fields.
     For example, if there are 4 descriptors and you want 2 quant
     bounds apiece, you would use _-q "[0,2,2,2,2,0]"_.
     Two special cases:
       1) If you would like to ignore a descriptor in the model
          building, use '-1' for its number of quant bounds.
       2) If you have integer valued data that should not be quantized
          further, enter 0 for that descriptor.

  - --recycle: allow descriptors to be used more than once in a tree

  - --randomDescriptors=val: toggles growing random forests with val
      randomly-selected descriptors available at each node.


  *-*-*-*-*-*-*-*- KNN-Related Options -*-*-*-*-*-*-*-*

  - --doKnn: use K-Nearest Neighbors models

  - --knnK=*value*: the value of K to use in the KNN models

  - --knnTanimoto: use the Tanimoto metric in KNN models

  - --knnEuclid: use a Euclidean metric in KNN models

  *-*-*-*-*-*-*-*- Naive Bayes Classifier Options -*-*-*-*-*-*-*-*

  - --doNaiveBayes: use Naive Bayes classifiers

  - --mEstimateVal: the value to be used in the m-estimate formula.
      If this is greater than 0.0, it is used to compute the conditional
      probabilities via the m-estimate.

  *-*-*-*-*-*-*-*- SVM-Related Options -*-*-*-*-*-*-*-*

  **** NOTE: THESE ARE DISABLED ****

##   - --doSVM: use Support-vector machines

##   - --svmKernel=*kernel*: choose the type of kernel to be used for
##     the SVMs.  Options are:
##     The default is:

##   - --svmType=*type*: choose the type of support-vector machine
##     to be used.  Options are:
##     The default is:

##   - --svmGamma=*gamma*: provide the gamma value for the SVMs.  If this
##     is not provided, a grid search will be carried out to determine an
##     optimal *gamma* value for each SVM.

##   - --svmCost=*cost*: provide the cost value for the SVMs.  If this is
##     not provided, a grid search will be carried out to determine an
##     optimal *cost* value for each SVM.

##   - --svmWeights=*weights*: provide the weight values for the
##     activities.  If provided this should be a sequence of (label,
##     weight) 2-tuples *nActs* long.  If not provided, a weight of 1
##     will be used for each activity.

##   - --svmEps=*epsilon*: provide the epsilon value used to determine
##     when the SVM has converged.  Defaults to 0.001

##   - --svmDegree=*degree*: provide the degree of the kernel (when
##     sensible) Defaults to 3

##   - --svmCoeff=*coeff*: provide the coefficient for the kernel (when
##     sensible) Defaults to 0

##   - --svmNu=*nu*: provide the nu value for the kernel (when sensible)
##     Defaults to 0.5

##   - --svmDataType=*type*: if the data contains only 1s and 0s, specify
##     binary.  Defaults to float

##   - --svmCache=*cache*: provide the size of the memory cache (in MB)
##     to be used while building the SVM.  Defaults to 40

**Notes**

  - *Data filtering*: When there is a large disparity between the
    numbers of points with various activity levels present in the
    training set, it is sometimes desirable to train on a more
    homogeneous data set.  This can be accomplished using filtering.
    The filtering process works by selecting a particular target
    fraction and target value.  For example, in a case where 95% of
    the original training set has activity 0 and only 5% has activity 1,
    we could filter (by randomly removing points with activity 0) so
    that 30% of the data set used to build the composite has activity 1.

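  - *Example invocations* (illustrative only; the file name and the option
    values shown here are hypothetical, not taken from this module):

      # build 10 tree models, holding out 30% of the data for testing:
      BuildComposite -n 10 -s -f 0.7 -o composite.pkl data.qdat

      # the same run, but filter the training set so that 30% of it has
      # the target activity value 1:
      BuildComposite -n 10 -s -f 0.7 -F 0.3 -v 1 -o composite.pkl data.qdat
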
 200  """ 
 201  from __future__ import print_function 
 202  import sys,time 
 203  import math 
 204  import numpy 
 205  from rdkit.six.moves import cPickle 
 206  from rdkit import RDConfig 
 207  from rdkit.utils import listutils 
 208  from rdkit.ML.Composite import Composite,BayesComposite 
 209  #from ML.SVM import SVMClassificationModel as SVM 
 210  from rdkit.ML.Data import DataUtils,SplitData 
 211  from rdkit.ML import ScreenComposite 
 212  from rdkit.Dbase import DbModule 
 213  from rdkit.Dbase.DbConnection import DbConnect 
 214  from rdkit.ML import CompositeRun 
 215  from rdkit import DataStructs 
 216   
 217  _runDetails = CompositeRun.CompositeRun() 
 218   
 219  __VERSION_STRING="3.2.3" 
 220   
 221  _verbose = 1 
def message(msg):
  """ emits messages to _sys.stdout_
    override this in modules which import this one to redirect output

  **Arguments**

    - msg: the string to be displayed

  """
  if _verbose: sys.stdout.write('%s\n' % (msg))


def testall(composite, examples, badExamples=[]):
  """ screens a number of examples past a composite

  **Arguments**

    - composite: a composite model

    - examples: a list of examples (with results) to be screened

    - badExamples: a list to which misclassified examples are appended

  **Returns**

    a list of 2-tuples containing:

      1) a vote

      2) a confidence

    these are the votes and confidence levels for **misclassified** examples

  """
  wrong = []
  for example in examples:
    if composite.GetActivityQuantBounds():
      answer = composite.QuantizeActivity(example)[-1]
    else:
      answer = example[-1]
    res, conf = composite.ClassifyExample(example)
    if res != answer:
      wrong.append((res, conf))
      badExamples.append(example)

  return wrong


def GetCommandLine(details):
  """ #DOC

  """
  args = ['BuildComposite']
  args.append('-n %d' % (details.nModels))
  if details.filterFrac != 0.0: args.append('-F %.3f -v %d' % (details.filterFrac, details.filterVal))
  if details.modelFilterFrac != 0.0:
    args.append('--modelFiltFrac=%.3f --modelFiltVal=%d' % (details.modelFilterFrac,
                                                            details.modelFilterVal))
  if details.splitRun: args.append('-s -f %.3f' % (details.splitFrac))
  if details.shuffleActivities: args.append('-S')
  if details.randomActivities: args.append('-r')
  if details.threshold > 0.0: args.append('-t %.3f' % (details.threshold))
  if details.activityBounds: args.append('-Q "%s"' % (details.activityBoundsVals))
  if details.dbName: args.append('-d %s' % (details.dbName))
  if details.detailedRes: args.append('-D')
  if hasattr(details, 'noScreen') and details.noScreen: args.append('--noScreen')
  if details.persistTblName and details.dbName:
    args.append('-p %s' % (details.persistTblName))
  if details.note:
    args.append('-N %s' % (details.note))
  if details.useTrees:
    if details.limitDepth > 0: args.append('-L %d' % (details.limitDepth))
    if details.lessGreedy: args.append('-g')
    if details.qBounds:
      shortBounds = listutils.CompactListRepr(details.qBounds)
      if details.qBounds: args.append('-q "%s"' % (shortBounds))
    else:
      if details.qBounds: args.append('-q "%s"' % (details.qBoundCount))

    if details.pruneIt: args.append('--prune')
    if details.startAt: args.append('-G %d' % details.startAt)
    if details.recycleVars: args.append('--recycle')
    if details.randomDescriptors: args.append('--randomDescriptors=%d' % details.randomDescriptors)
  if details.useSigTrees:
    args.append('--doSigTree')
    if details.limitDepth > 0: args.append('-L %d' % (details.limitDepth))
    if details.randomDescriptors:
      args.append('--randomDescriptors=%d' % details.randomDescriptors)

  if details.useKNN:
    args.append('--doKnn --knnK %d' % (details.knnNeighs))
    if details.knnDistFunc == 'Tanimoto':
      args.append('--knnTanimoto')
    else:
      args.append('--knnEuclid')

  if details.useNaiveBayes:
    args.append('--doNaiveBayes')
    if details.mEstimateVal >= 0.0:
      args.append('--mEstimateVal=%.3f' % details.mEstimateVal)

##   if details.useSVM:
##     args.append('--doSVM')
##     if details.svmKernel:
##       for k in SVM.kernels.keys():
##         if SVM.kernels[k] == details.svmKernel:
##           args.append('--svmKernel=%s' % k)
##           break
##     if details.svmType:
##       for k in SVM.machineTypes.keys():
##         if SVM.machineTypes[k] == details.svmType:
##           args.append('--svmType=%s' % k)
##           break
##     if details.svmGamma:
##       args.append('--svmGamma=%f' % details.svmGamma)
##     if details.svmCost:
##       args.append('--svmCost=%f' % details.svmCost)
##     if details.svmWeights:
##       args.append("--svmWeights='%s'" % str(details.svmWeights))
##     if details.svmDegree:
##       args.append('--svmDegree=%d' % details.svmDegree)
##     if details.svmCoeff:
##       args.append('--svmCoeff=%d' % details.svmCoeff)
##     if details.svmEps:
##       args.append('--svmEps=%f' % details.svmEps)
##     if details.svmNu:
##       args.append('--svmNu=%f' % details.svmNu)
##     if details.svmCache:
##       args.append('--svmCache=%d' % details.svmCache)
##     if detail.svmDataType:
##       args.append('--svmDataType=%s' % details.svmDataType)
##     if not details.svmShrink:
##       args.append('--svmShrink')

  if details.replacementSelection: args.append('--replacementSelection')

  # this should always be last:
  if details.tableName: args.append(details.tableName)

  return ' '.join(args)
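
# Illustrative example (hypothetical values; assumes every other option keeps
# its default): a details object with nModels=10, splitRun=1, splitFrac=0.7 and
# tableName='data.qdat' would give back roughly
#   'BuildComposite -n 10 -s -f 0.700 data.qdat'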

def RunOnData(details, data, progressCallback=None, saveIt=1, setDescNames=0):
  nExamples = data.GetNPts()
  if details.lockRandom:
    seed = details.randomSeed
  else:
    import random
    seed = (random.randint(0, 1e6), random.randint(0, 1e6))
  DataUtils.InitRandomNumbers(seed)
  testExamples = []
  if details.shuffleActivities == 1:
    DataUtils.RandomizeActivities(data, shuffle=1, runDetails=details)
  elif details.randomActivities == 1:
    DataUtils.RandomizeActivities(data, shuffle=0, runDetails=details)

  namedExamples = data.GetNamedData()
  if details.splitRun == 1:
    trainIdx, testIdx = SplitData.SplitIndices(len(namedExamples), details.splitFrac,
                                               silent=not _verbose)

    trainExamples = [namedExamples[x] for x in trainIdx]
    testExamples = [namedExamples[x] for x in testIdx]
  else:
    testExamples = []
    testIdx = []
    trainIdx = range(len(namedExamples))
    trainExamples = namedExamples

  if details.filterFrac != 0.0:
    # if we're doing quantization on the fly, we need to handle that here:
    if hasattr(details, 'activityBounds') and details.activityBounds:
      tExamples = []
      bounds = details.activityBounds
      for pt in trainExamples:
        pt = pt[:]
        act = pt[-1]
        placed = 0
        bound = 0
        while not placed and bound < len(bounds):
          if act < bounds[bound]:
            pt[-1] = bound
            placed = 1
          else:
            bound += 1
        if not placed:
          pt[-1] = bound
        tExamples.append(pt)
    else:
      bounds = None
      tExamples = trainExamples
    trainIdx, temp = DataUtils.FilterData(tExamples, details.filterVal,
                                          details.filterFrac, -1,
                                          indicesOnly=1)
    tmp = [trainExamples[x] for x in trainIdx]
    testExamples += [trainExamples[x] for x in temp]
    trainExamples = tmp

    counts = DataUtils.CountResults(trainExamples, bounds=bounds)
    ks = counts.keys()
    ks.sort()
    message('Result Counts in training set:')
    for k in ks:
      message(str((k, counts[k])))
    counts = DataUtils.CountResults(testExamples, bounds=bounds)
    ks = counts.keys()
    ks.sort()
    message('Result Counts in test set:')
    for k in ks:
      message(str((k, counts[k])))
  nExamples = len(trainExamples)
  message('Training with %d examples' % (nExamples))

  nVars = data.GetNVars()
  attrs = range(1, nVars + 1)
  nPossibleVals = data.GetNPossibleVals()
  for i in range(1, len(nPossibleVals)):
    if nPossibleVals[i - 1] == -1:
      attrs.remove(i)

  if details.pickleDataFileName != '':
    pickleDataFile = open(details.pickleDataFileName, 'wb+')
    cPickle.dump(trainExamples, pickleDataFile)
    cPickle.dump(testExamples, pickleDataFile)
    pickleDataFile.close()

  if details.bayesModel:
    composite = BayesComposite.BayesComposite()
  else:
    composite = Composite.Composite()

  composite._randomSeed = seed
  composite._splitFrac = details.splitFrac
  composite._shuffleActivities = details.shuffleActivities
  composite._randomizeActivities = details.randomActivities

  if hasattr(details, 'filterFrac'):
    composite._filterFrac = details.filterFrac
  if hasattr(details, 'filterVal'):
    composite._filterVal = details.filterVal

  composite.SetModelFilterData(details.modelFilterFrac, details.modelFilterVal)

  composite.SetActivityQuantBounds(details.activityBounds)
  nPossibleVals = data.GetNPossibleVals()
  if details.activityBounds:
    nPossibleVals[-1] = len(details.activityBounds) + 1

  if setDescNames:
    composite.SetInputOrder(data.GetVarNames())
    composite.SetDescriptorNames(details._descNames)
  else:
    composite.SetDescriptorNames(data.GetVarNames())
  composite.SetActivityQuantBounds(details.activityBounds)
  if details.nModels == 1:
    details.internalHoldoutFrac = 0.0
  if details.useTrees:
    from rdkit.ML.DecTree import CrossValidate, PruneTree
    if details.qBounds != []:
      from rdkit.ML.DecTree import BuildQuantTree
      builder = BuildQuantTree.QuantTreeBoot
    else:
      from rdkit.ML.DecTree import ID3
      builder = ID3.ID3Boot
    driver = CrossValidate.CrossValidationDriver
    pruner = PruneTree.PruneTree

    composite.SetQuantBounds(details.qBounds)
    nPossibleVals = data.GetNPossibleVals()
    if details.activityBounds:
      nPossibleVals[-1] = len(details.activityBounds) + 1
    composite.Grow(trainExamples, attrs, nPossibleVals=[0] + nPossibleVals,
                   buildDriver=driver,
                   pruner=pruner,
                   nTries=details.nModels, pruneIt=details.pruneIt,
                   lessGreedy=details.lessGreedy, needsQuantization=0,
                   treeBuilder=builder, nQuantBounds=details.qBounds,
                   startAt=details.startAt,
                   maxDepth=details.limitDepth,
                   progressCallback=progressCallback,
                   holdOutFrac=details.internalHoldoutFrac,
                   replacementSelection=details.replacementSelection,
                   recycleVars=details.recycleVars,
                   randomDescriptors=details.randomDescriptors,
                   silent=not _verbose)

  elif details.useSigTrees:
    from rdkit.ML.DecTree import CrossValidate
    from rdkit.ML.DecTree import BuildSigTree
    builder = BuildSigTree.SigTreeBuilder
    driver = CrossValidate.CrossValidationDriver
    nPossibleVals = data.GetNPossibleVals()
    if details.activityBounds:
      nPossibleVals[-1] = len(details.activityBounds) + 1
    if hasattr(details, 'sigTreeBiasList'):
      biasList = details.sigTreeBiasList
    else:
      biasList = None
    if hasattr(details, 'useCMIM'):
      useCMIM = details.useCMIM
    else:
      useCMIM = 0
    if hasattr(details, 'allowCollections'):
      allowCollections = details.allowCollections
    else:
      allowCollections = False
    composite.Grow(trainExamples, attrs, nPossibleVals=[0] + nPossibleVals,
                   buildDriver=driver,
                   nTries=details.nModels,
                   needsQuantization=0,
                   treeBuilder=builder,
                   maxDepth=details.limitDepth,
                   progressCallback=progressCallback,
                   holdOutFrac=details.internalHoldoutFrac,
                   replacementSelection=details.replacementSelection,
                   recycleVars=details.recycleVars,
                   randomDescriptors=details.randomDescriptors,
                   biasList=biasList,
                   useCMIM=useCMIM,
                   allowCollection=allowCollections,
                   silent=not _verbose)

  elif details.useKNN:
    from rdkit.ML.KNN import CrossValidate
    from rdkit.ML.KNN import DistFunctions

    driver = CrossValidate.CrossValidationDriver
    dfunc = ''
    if (details.knnDistFunc == "Euclidean"):
      dfunc = DistFunctions.EuclideanDist
    elif (details.knnDistFunc == "Tanimoto"):
      dfunc = DistFunctions.TanimotoDist
    else:
      assert 0, "Bad KNN distance metric value"

    composite.Grow(trainExamples, attrs, nPossibleVals=[0] + nPossibleVals,
                   buildDriver=driver, nTries=details.nModels,
                   needsQuantization=0,
                   numNeigh=details.knnNeighs,
                   holdOutFrac=details.internalHoldoutFrac,
                   distFunc=dfunc)

  elif details.useNaiveBayes or details.useSigBayes:
    from rdkit.ML.NaiveBayes import CrossValidate
    driver = CrossValidate.CrossValidationDriver
    if not (hasattr(details, 'useSigBayes') and details.useSigBayes):
      composite.Grow(trainExamples, attrs, nPossibleVals=[0] + nPossibleVals,
                     buildDriver=driver, nTries=details.nModels,
                     needsQuantization=0, nQuantBounds=details.qBounds,
                     holdOutFrac=details.internalHoldoutFrac,
                     replacementSelection=details.replacementSelection,
                     mEstimateVal=details.mEstimateVal,
                     silent=not _verbose)
    else:
      if hasattr(details, 'useCMIM'):
        useCMIM = details.useCMIM
      else:
        useCMIM = 0

      composite.Grow(trainExamples, attrs, nPossibleVals=[0] + nPossibleVals,
                     buildDriver=driver, nTries=details.nModels,
                     needsQuantization=0, nQuantBounds=details.qBounds,
                     mEstimateVal=details.mEstimateVal,
                     useSigs=True, useCMIM=useCMIM,
                     holdOutFrac=details.internalHoldoutFrac,
                     replacementSelection=details.replacementSelection,
                     silent=not _verbose)

##   elif details.useSVM:
##     from rdkit.ML.SVM import CrossValidate
##     driver = CrossValidate.CrossValidationDriver
##     composite.Grow(trainExamples, attrs, nPossibleVals=[0]+nPossibleVals,
##                    buildDriver=driver, nTries=details.nModels,
##                    needsQuantization=0,
##                    cost=details.svmCost, gamma=details.svmGamma,
##                    weights=details.svmWeights, degree=details.svmDegree,
##                    type=details.svmType, kernelType=details.svmKernel,
##                    coef0=details.svmCoeff, eps=details.svmEps, nu=details.svmNu,
##                    cache_size=details.svmCache, shrinking=details.svmShrink,
##                    dataType=details.svmDataType,
##                    holdOutFrac=details.internalHoldoutFrac,
##                    replacementSelection=details.replacementSelection,
##                    silent=not _verbose)

  else:
    from rdkit.ML.Neural import CrossValidate
    driver = CrossValidate.CrossValidationDriver
    composite.Grow(trainExamples, attrs, [0] + nPossibleVals, nTries=details.nModels,
                   buildDriver=driver, needsQuantization=0)

  composite.AverageErrors()
  composite.SortModels()
  modelList, counts, avgErrs = composite.GetAllData()
  counts = numpy.array(counts)
  avgErrs = numpy.array(avgErrs)
  composite._varNames = data.GetVarNames()

  for i in range(len(modelList)):
    modelList[i].NameModel(composite._varNames)

  # do final statistics
  weightedErrs = counts * avgErrs
  averageErr = sum(weightedErrs) / sum(counts)
  devs = (avgErrs - averageErr)
  devs = devs * counts
  devs = numpy.sqrt(devs * devs)
  avgDev = sum(devs) / sum(counts)
  message('# Overall Average Error: %%% 5.2f, Average Deviation: %%% 6.2f' % (100. * averageErr, 100. * avgDev))

  if details.bayesModel:
    composite.Train(trainExamples, verbose=0)

  # blow out the saved examples and then save the composite:
  composite.ClearModelExamples()
  if saveIt:
    composite.Pickle(details.outName)
  details.model = DbModule.binaryHolder(cPickle.dumps(composite))

  badExamples = []
  if not details.detailedRes and (not hasattr(details, 'noScreen') or not details.noScreen):
    if details.splitRun:
      message('Testing all hold-out examples')
      wrong = testall(composite, testExamples, badExamples)
      message('%d examples (%% %5.2f) were misclassified' % (len(wrong),
              100. * float(len(wrong)) / float(len(testExamples))))
      _runDetails.holdout_error = float(len(wrong)) / len(testExamples)
    else:
      message('Testing all examples')
      wrong = testall(composite, namedExamples, badExamples)
      message('%d examples (%% %5.2f) were misclassified' % (len(wrong),
              100. * float(len(wrong)) / float(len(namedExamples))))
      _runDetails.overall_error = float(len(wrong)) / len(namedExamples)

  if details.detailedRes:
    message('\nEntire data set:')
    resTup = ScreenComposite.ShowVoteResults(range(data.GetNPts()), data, composite,
                                             nPossibleVals[-1], details.threshold)
    nGood, nBad, nSkip, avgGood, avgBad, avgSkip, voteTab = resTup
    nPts = len(namedExamples)
    nClass = nGood + nBad
    _runDetails.overall_error = float(nBad) / nClass
    _runDetails.overall_correct_conf = avgGood
    _runDetails.overall_incorrect_conf = avgBad
    _runDetails.overall_result_matrix = repr(voteTab)
    nRej = nClass - nPts
    if nRej > 0:
      _runDetails.overall_fraction_dropped = float(nRej) / nPts

    if details.splitRun:
      message('\nHold-out data:')
      resTup = ScreenComposite.ShowVoteResults(range(len(testExamples)), testExamples,
                                               composite,
                                               nPossibleVals[-1], details.threshold)
      nGood, nBad, nSkip, avgGood, avgBad, avgSkip, voteTab = resTup
      nPts = len(testExamples)
      nClass = nGood + nBad
      _runDetails.holdout_error = float(nBad) / nClass
      _runDetails.holdout_correct_conf = avgGood
      _runDetails.holdout_incorrect_conf = avgBad
      _runDetails.holdout_result_matrix = repr(voteTab)
      nRej = nClass - nPts
      if nRej > 0:
        _runDetails.holdout_fraction_dropped = float(nRej) / nPts

  if details.persistTblName and details.dbName:
    message('Updating results table %s:%s' % (details.dbName, details.persistTblName))
    details.Store(db=details.dbName, table=details.persistTblName)

  if details.badName != '':
    badFile = open(details.badName, 'w+')
    for i in range(len(badExamples)):
      ex = badExamples[i]
      vote = wrong[i]
      outStr = '%s\t%s\n' % (ex, vote)
      badFile.write(outStr)
    badFile.close()

  composite.ClearModelExamples()
  return composite

def RunIt(details, progressCallback=None, saveIt=1, setDescNames=0):
  """ does the actual work of building a composite model

  **Arguments**

    - details: a _CompositeRun.CompositeRun_ object containing details
      (options, parameters, etc.) about the run

    - progressCallback: (optional) a function which is called with a single
      argument (the number of models built so far) after each model is built.

    - saveIt: (optional) if this is nonzero, the resulting model will be pickled
      and dumped to the filename specified in _details.outName_

    - setDescNames: (optional) if nonzero, the composite's _SetInputOrder()_ method
      will be called using the results of the data set's _GetVarNames()_ method;
      it is assumed that the details object has a _descNames attribute which
      is passed to the composite's _SetDescriptorNames()_ method.  Otherwise
      (the default), _SetDescriptorNames()_ gets the results of _GetVarNames()_.

  **Returns**

    the composite model constructed

  """
  details.rundate = time.asctime()

  fName = details.tableName.strip()
  if details.outName == '':
    details.outName = fName + '.pkl'
  if not details.dbName:
    if details.qBounds != []:
      data = DataUtils.TextFileToData(fName)
    else:
      data = DataUtils.BuildQuantDataSet(fName)
  elif details.useSigTrees or details.useSigBayes:
    details.tableName = fName
    data = details.GetDataSet(pickleCol=0, pickleClass=DataStructs.ExplicitBitVect)
  elif details.qBounds != [] or not details.useTrees:
    details.tableName = fName
    data = details.GetDataSet()
  else:
    data = DataUtils.DBToQuantData(details.dbName, fName, quantName=details.qTableName,
                                   user=details.dbUser, password=details.dbPassword)

  composite = RunOnData(details, data, progressCallback=progressCallback,
                        saveIt=saveIt, setDescNames=setDescNames)
  return composite


def ShowVersion(includeArgs=0):
  """ prints the version number

  """
  print('This is BuildComposite.py version %s' % (__VERSION_STRING))
  if includeArgs:
    import sys
    print('command line was:')
    print(' '.join(sys.argv))


def Usage():
  """ provides a list of arguments for when this is used from the command line

  """
  import sys
  print(__doc__)
  sys.exit(-1)

def SetDefaults(runDetails=None):
  """ initializes a details object with default values

  **Arguments**

    - details: (optional) a _CompositeRun.CompositeRun_ object.
      If this is not provided, the global _runDetails will be used.

  **Returns**

    the initialized _CompositeRun_ object.

  """
  if runDetails is None: runDetails = _runDetails
  return CompositeRun.SetDefaults(runDetails)

def ParseArgs(runDetails):
  """ parses command line arguments and updates _runDetails_

  **Arguments**

    - runDetails: a _CompositeRun.CompositeRun_ object.

  """
  import getopt
  args, extra = getopt.getopt(sys.argv[1:], 'P:o:n:p:b:sf:F:v:hlgd:rSTt:BQ:q:DVG:N:L:',
                              ['nRuns=', 'prune', 'profile',
                               'seed=', 'noScreen',

                               'modelFiltFrac=', 'modelFiltVal=',

                               'recycle', 'randomDescriptors=',

                               'doKnn', 'knnK=', 'knnTanimoto', 'knnEuclid',

                               'doSigTree', 'allowCollections',

                               'doNaiveBayes', 'mEstimateVal=',
                               'doSigBayes',

                               ## 'doSVM','svmKernel=','svmType=','svmGamma=',
                               ## 'svmCost=','svmWeights=','svmDegree=',
                               ## 'svmCoeff=','svmEps=','svmNu=','svmCache=',
                               ## 'svmShrink','svmDataType=',

                               'replacementSelection',

                               ])
  runDetails.profileIt = 0
  for arg, val in args:
    if arg == '-n':
      runDetails.nModels = int(val)
    elif arg == '-N':
      runDetails.note = val
    elif arg == '-o':
      runDetails.outName = val
    elif arg == '-Q':
      qBounds = eval(val)
      assert type(qBounds) in [type([]), type(())], 'bad argument type for -Q, specify a list as a string'
      runDetails.activityBounds = qBounds
      runDetails.activityBoundsVals = val
    elif arg == '-p':
      runDetails.persistTblName = val
    elif arg == '-P':
      runDetails.pickleDataFileName = val
    elif arg == '-r':
      runDetails.randomActivities = 1
    elif arg == '-S':
      runDetails.shuffleActivities = 1
    elif arg == '-b':
      runDetails.badName = val
    elif arg == '-B':
      runDetails.bayesModels = 1
    elif arg == '-s':
      runDetails.splitRun = 1
    elif arg == '-f':
      runDetails.splitFrac = float(val)
    elif arg == '-F':
      runDetails.filterFrac = float(val)
    elif arg == '-v':
      runDetails.filterVal = float(val)
    elif arg == '-l':
      runDetails.lockRandom = 1
    elif arg == '-g':
      runDetails.lessGreedy = 1
    elif arg == '-G':
      runDetails.startAt = int(val)
    elif arg == '-d':
      runDetails.dbName = val
    elif arg == '-T':
      runDetails.useTrees = 0
    elif arg == '-t':
      runDetails.threshold = float(val)
    elif arg == '-D':
      runDetails.detailedRes = 1
    elif arg == '-L':
      runDetails.limitDepth = int(val)
    elif arg == '-q':
      qBounds = eval(val)
      assert type(qBounds) in [type([]), type(())], 'bad argument type for -q, specify a list as a string'
      runDetails.qBoundCount = val
      runDetails.qBounds = qBounds
    elif arg == '-V':
      ShowVersion()
      sys.exit(0)
    elif arg == '--nRuns':
      runDetails.nRuns = int(val)
    elif arg == '--modelFiltFrac':
      runDetails.modelFilterFrac = float(val)
    elif arg == '--modelFiltVal':
      runDetails.modelFilterVal = float(val)
    elif arg == '--prune':
      runDetails.pruneIt = 1
    elif arg == '--profile':
      runDetails.profileIt = 1

    elif arg == '--recycle':
      runDetails.recycleVars = 1
    elif arg == '--randomDescriptors':
      runDetails.randomDescriptors = int(val)

    elif arg == '--doKnn':
      runDetails.useKNN = 1
      runDetails.useTrees = 0
      ## runDetails.useSVM = 0
      runDetails.useNaiveBayes = 0
    elif arg == '--knnK':
      runDetails.knnNeighs = int(val)
    elif arg == '--knnTanimoto':
      runDetails.knnDistFunc = "Tanimoto"
    elif arg == '--knnEuclid':
      runDetails.knnDistFunc = "Euclidean"

    elif arg == '--doSigTree':
      ## runDetails.useSVM = 0
      runDetails.useKNN = 0
      runDetails.useTrees = 0
      runDetails.useNaiveBayes = 0
      runDetails.useSigTrees = 1
    elif arg == '--allowCollections':
      runDetails.allowCollections = True

    elif arg == '--doNaiveBayes':
      runDetails.useNaiveBayes = 1
      ## runDetails.useSVM = 0
      runDetails.useKNN = 0
      runDetails.useTrees = 0
      runDetails.useSigBayes = 0
    elif arg == '--doSigBayes':
      runDetails.useSigBayes = 1
      runDetails.useNaiveBayes = 0
      ## runDetails.useSVM = 0
      runDetails.useKNN = 0
      runDetails.useTrees = 0
    elif arg == '--mEstimateVal':
      runDetails.mEstimateVal = float(val)

    ## elif arg == '--doSVM':
    ##   runDetails.useSVM = 1
    ##   runDetails.useKNN = 0
    ##   runDetails.useTrees = 0
    ##   runDetails.useNaiveBayes = 0
    ## elif arg == '--svmKernel':
    ##   if val not in SVM.kernels.keys():
    ##     message('kernel %s not in list of available kernels:\n%s\n' % (val, SVM.kernels.keys()))
    ##     sys.exit(-1)
    ##   else:
    ##     runDetails.svmKernel = SVM.kernels[val]
    ## elif arg == '--svmType':
    ##   if val not in SVM.machineTypes.keys():
    ##     message('type %s not in list of available machines:\n%s\n' % (val, SVM.machineTypes.keys()))
    ##     sys.exit(-1)
    ##   else:
    ##     runDetails.svmType = SVM.machineTypes[val]
    ## elif arg == '--svmGamma':
    ##   runDetails.svmGamma = float(val)
    ## elif arg == '--svmCost':
    ##   runDetails.svmCost = float(val)
    ## elif arg == '--svmWeights':
    ##   # FIX: this is dangerous
    ##   runDetails.svmWeights = eval(val)
    ## elif arg == '--svmDegree':
    ##   runDetails.svmDegree = int(val)
    ## elif arg == '--svmCoeff':
    ##   runDetails.svmCoeff = float(val)
    ## elif arg == '--svmEps':
    ##   runDetails.svmEps = float(val)
    ## elif arg == '--svmNu':
    ##   runDetails.svmNu = float(val)
    ## elif arg == '--svmCache':
    ##   runDetails.svmCache = int(val)
    ## elif arg == '--svmShrink':
    ##   runDetails.svmShrink = 0
    ## elif arg == '--svmDataType':
    ##   runDetails.svmDataType = val

    elif arg == '--seed':
      # FIX: dangerous
      runDetails.randomSeed = eval(val)

    elif arg == '--noScreen':
      runDetails.noScreen = 1

    elif arg == '--replacementSelection':
      runDetails.replacementSelection = 1

    elif arg == '-h':
      Usage()

    else:
      Usage()
  runDetails.tableName = extra[0]

if __name__ == '__main__':
  if len(sys.argv) < 2:
    Usage()

  _runDetails.cmd = ' '.join(sys.argv)
  SetDefaults(_runDetails)
  ParseArgs(_runDetails)

  ShowVersion(includeArgs=1)

  if _runDetails.nRuns > 1:
    for i in range(_runDetails.nRuns):
      sys.stderr.write('---------------------------------\n\tDoing %d of %d\n---------------------------------\n' %
                       (i + 1, _runDetails.nRuns))
      RunIt(_runDetails)
  else:
    if _runDetails.profileIt:
      import hotshot, hotshot.stats
      prof = hotshot.Profile('prof.dat')
      prof.runcall(RunIt, _runDetails)
      stats = hotshot.stats.load('prof.dat')
      stats.strip_dirs()
      stats.sort_stats('time', 'calls')
      stats.print_stats(30)
    else:
      RunIt(_runDetails)
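
The module can also be driven from Python rather than from the command line,
using SetDefaults() to initialize a run-details object and RunIt() to build
(and pickle) the composite.  The sketch below is illustrative only and is not
part of the module source; the file names and option values ('data.qdat',
'composite.pkl', 10 models, a 70/30 split) are hypothetical, and it assumes
the defaults supplied by CompositeRun.SetDefaults for everything it does not
set explicitly.

  from rdkit.ML import BuildComposite
  from rdkit.six.moves import cPickle

  # initialize a CompositeRun details object with the standard defaults
  details = BuildComposite.SetDefaults()
  details.tableName = 'data.qdat'    # hypothetical QDAT input file
  details.nModels = 10               # same as '-n 10'
  details.splitRun = 1               # same as '-s'
  details.splitFrac = 0.7            # same as '-f 0.7'
  details.outName = 'composite.pkl'  # same as '-o composite.pkl'

  # build the composite; RunIt also pickles it to details.outName
  composite = BuildComposite.RunIt(details)

  # the pickled composite can be reloaded later and used for predictions;
  # 'example' would be a data row in the same format used for training:
  with open('composite.pkl', 'rb') as inF:
    composite = cPickle.load(inF)
  # pred, conf = composite.ClassifyExample(example)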