
Source Code for Module rdkit.ML.DecTree.CrossValidate

#
#  Copyright (C) 2000  greg Landrum
#
""" handles doing cross validation with decision trees

This is, perhaps, a little misleading.  For the purposes of this module,
cross validation == evaluating the accuracy of a tree.

"""
from __future__ import print_function
from rdkit.ML.DecTree import ID3
from rdkit.ML.Data import SplitData
import numpy
from rdkit.six.moves import xrange

def ChooseOptimalRoot(examples, trainExamples, testExamples, attrs, nPossibleVals,
                      treeBuilder, nQuantBounds=[], **kwargs):
  """ loops through all possible tree roots and chooses the one which produces the best tree

  **Arguments**

    - examples: the full set of examples

    - trainExamples: the training examples

    - testExamples: the testing examples

    - attrs: a list of attributes to consider in the tree building

    - nPossibleVals: a list of the number of possible values each variable can adopt

    - treeBuilder: the function to be used to actually build the tree

    - nQuantBounds: an optional list.  If present, it's assumed that the builder
      algorithm takes this argument as well (for building QuantTrees)

  **Returns**

    The best tree found

  **Notes**

    1) Trees are built using _trainExamples_

    2) Testing of each tree (to determine which is best) is done using _CrossValidate_ and
       the entire set of data (i.e. all of _examples_)

    3) _testExamples_ is not used at all, which immediately raises the question of
       why it's even being passed in

  """
  attrs = attrs[:]
  if nQuantBounds:
    for i in range(len(nQuantBounds)):
      if nQuantBounds[i] == -1 and i in attrs:
        attrs.remove(i)
  nAttrs = len(attrs)
  trees = [None] * nAttrs
  errs = [0] * nAttrs
  errs[0] = 1e6

  for i in xrange(1, nAttrs):
    argD = {'initialVar': attrs[i]}
    argD.update(kwargs)
    if nQuantBounds is None or nQuantBounds == []:
      trees[i] = treeBuilder(trainExamples, attrs, nPossibleVals, **argD)
    else:
      trees[i] = treeBuilder(trainExamples, attrs, nPossibleVals, nQuantBounds, **argD)
    if trees[i]:
      errs[i], foo = CrossValidate(trees[i], examples, appendExamples=0)
    else:
      errs[i] = 1e6
  best = numpy.argmin(errs)
  # FIX: this used to say 'trees[i]', could that possibly have been right?
  return trees[best]

def CrossValidate(tree, testExamples, appendExamples=0):
  """ Determines the classification error for the testExamples

  **Arguments**

    - tree: a decision tree (or anything supporting a _ClassifyExample()_ method)

    - testExamples: a list of examples to be used for testing

    - appendExamples: a toggle which is passed along to the tree as it does
      the classification.  The trees can use this to store the examples they
      classify locally.

  **Returns**

    a 2-tuple consisting of:

      1) the error rate of the tree (fraction of testExamples misclassified)

      2) a list of misclassified examples

  """
  nTest = len(testExamples)
  nBad = 0
  badExamples = []
  for i in xrange(nTest):
    testEx = testExamples[i]
    trueRes = testEx[-1]
    res = tree.ClassifyExample(testEx, appendExamples)
    if (trueRes != res).any():
      badExamples.append(testEx)
      nBad += 1

  return float(nBad) / nTest, badExamples

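# --- Illustrative sketch (not part of the original module) ------------------
# ChooseOptimalRoot() only needs a builder callable accepting
# (examples, attrs, nPossibleVals, initialVar=..., **kwargs) and returning an
# object that exposes ClassifyExample(), which is also the only contract
# CrossValidate() relies on.  The stub "tree" and builder below are
# hypothetical and exist purely to show that calling convention; note that,
# as written above, the root search starts with the second entry of attrs.
class _SingleVarTree(object):
  """ toy stand-in for a decision tree: predicts the value of one variable """

  def __init__(self, var):
    self.var = var

  def ClassifyExample(self, example, appendExamples=0):
    return example[self.var]


def _ChooseOptimalRootSketch():
  # examples follow the module's convention: the true label is the last entry.
  # numpy arrays are used so the element-wise comparison in CrossValidate()
  # behaves the same way for scalar and vector results.
  examples = [numpy.array([0, 1, 1]), numpy.array([1, 0, 0]), numpy.array([1, 1, 1])]

  def builder(exs, attrs, nPossibleVals, initialVar=0, **kwargs):
    return _SingleVarTree(initialVar)

  # the labels equal variable 1, so rooting the "tree" there gives zero error
  tree = ChooseOptimalRoot(examples, examples, [], [0, 1], [2, 2, 2], builder)
  print('best root:', tree.var)
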
def CrossValidationDriver(examples, attrs, nPossibleVals, holdOutFrac=.3, silent=0,
                          calcTotalError=0, treeBuilder=ID3.ID3Boot, lessGreedy=0,
                          startAt=None, nQuantBounds=[], maxDepth=-1, **kwargs):
  """ Driver function for building trees and doing cross validation

  **Arguments**

    - examples: the full set of examples

    - attrs: a list of attributes to consider in the tree building

    - nPossibleVals: a list of the number of possible values each variable can adopt

    - holdOutFrac: the fraction of the data which should be reserved for the hold-out set
      (used to calculate the error)

    - silent: a toggle used to control how much visual noise this makes as it goes.

    - calcTotalError: a toggle used to indicate whether the classification error
      of the tree should be calculated using the entire data set (when true) or just
      the hold-out set (when false)

    - treeBuilder: the function to call to build the tree

    - lessGreedy: toggles use of the less greedy tree growth algorithm (see
      _ChooseOptimalRoot_).

    - startAt: forces the tree to be rooted at this descriptor

    - nQuantBounds: an optional list.  If present, it's assumed that the builder
      algorithm takes this argument as well (for building QuantTrees)

    - maxDepth: an optional integer.  If present, it's assumed that the builder
      algorithm takes this argument as well

  **Returns**

    a 2-tuple containing:

      1) the tree

      2) the cross-validation error of the tree

  """
  nTot = len(examples)
  if not kwargs.get('replacementSelection', 0):
    testIndices, trainIndices = SplitData.SplitIndices(nTot, holdOutFrac, silent=1, legacy=1,
                                                       replacement=0)
  else:
    testIndices, trainIndices = SplitData.SplitIndices(nTot, holdOutFrac, silent=1, legacy=0,
                                                       replacement=1)
  trainExamples = [examples[x] for x in trainIndices]
  testExamples = [examples[x] for x in testIndices]

  nTrain = len(trainExamples)
  if not silent:
    print('Training with %d examples' % (nTrain))

  if not lessGreedy:
    if nQuantBounds is None or nQuantBounds == []:
      tree = treeBuilder(trainExamples, attrs, nPossibleVals, initialVar=startAt,
                         maxDepth=maxDepth, **kwargs)
    else:
      tree = treeBuilder(trainExamples, attrs, nPossibleVals, nQuantBounds,
                         initialVar=startAt, maxDepth=maxDepth, **kwargs)
  else:
    tree = ChooseOptimalRoot(examples, trainExamples, testExamples, attrs, nPossibleVals,
                             treeBuilder, nQuantBounds, maxDepth=maxDepth, **kwargs)

  nTest = len(testExamples)
  if not silent:
    print('Testing with %d examples' % nTest)
  if not calcTotalError:
    xValError, badExamples = CrossValidate(tree, testExamples, appendExamples=1)
  else:
    xValError, badExamples = CrossValidate(tree, examples, appendExamples=0)
  if not silent:
    print('Validation error was %%%4.2f' % (100 * xValError))
  tree.SetBadExamples(badExamples)
  tree.SetTrainingExamples(trainExamples)
  tree.SetTestExamples(testExamples)
  tree._trainIndices = trainIndices
  return tree, xValError

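# --- Illustrative sketch (not part of the original module) ------------------
# A hypothetical variation on TestRun() below: the same random data, but using
# the less greedy root selection (lessGreedy=1) and scoring the resulting tree
# on the full data set (calcTotalError=1) instead of just the hold-out set.
def _DriverOptionsSketch():
  from rdkit.ML.DecTree import randomtest
  examples, attrs, nPossibleVals = randomtest.GenRandomExamples(nExamples=200)
  tree, frac = CrossValidationDriver(examples, attrs, nPossibleVals, holdOutFrac=.3,
                                     lessGreedy=1, calcTotalError=1, silent=1)
  print('cross-validation error: %4.2f' % frac)
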
def TestRun():
  """ testing code

  """
  from rdkit.ML.DecTree import randomtest
  examples, attrs, nPossibleVals = randomtest.GenRandomExamples(nExamples=200)
  tree, frac = CrossValidationDriver(examples, attrs, nPossibleVals)

  tree.Pickle('save.pkl')

  import copy
  t2 = copy.deepcopy(tree)
  print('t1 == t2', tree == t2)
  l = [tree]
  print('t2 in [tree]', t2 in l, l.index(t2))


if __name__ == '__main__':
  TestRun()