Package rdkit :: Package ML :: Package DecTree :: Module PruneTree
[hide private]
[frames] | [no frames]

Source Code for Module rdkit.ML.DecTree.PruneTree

  1  # 
  2  #  Copyright (C) 2000-2008  greg Landrum and Rational Discovery LLC 
  3  # 
  4  """ Contains functionality for doing tree pruning 
  5   
  6  """ 
  7  from __future__ import print_function 
  8  import numpy 
  9  import copy 
 10  from rdkit.ML.DecTree import CrossValidate, DecTree 
 11  from rdkit.six.moves import range 
 12   
 13  _verbose = 0 
 14   
def MaxCount(examples):
    """ given a set of examples, returns the most common result code

      **Arguments**

        examples: a list of examples to be counted.  The result code of each
          example is its last element; codes are expected to be non-negative
          integers (negative codes are ignored).

      **Returns**

        the most common result code.  Ties are broken in favor of the
        smallest code (numpy.argmax returns the first maximum).

      **Raises**

        ValueError if _examples_ is empty

    """
    resList = [x[-1] for x in examples]
    if not resList:
        raise ValueError('MaxCount() requires at least one example')
    maxVal = max(resList)
    # tally all codes in a single pass instead of rescanning resList once
    # for every possible code
    counts = [0] * (maxVal + 1)
    for res in resList:
        # a negative code can never land in one of the 0..maxVal slots; the
        # previous implementation ignored such codes as well
        if res >= 0:
            counts[res] += 1
    return numpy.argmax(counts)
34
def _GetLocalError(node):
    """ counts the stored examples that _node_ misclassifies

      **Arguments**

        - node: the (sub)tree to evaluate.  The examples come from
          node.GetExamples(); the last element of each one is its true
          result code.

      **Returns**

        the number of stored examples whose predicted result code (from
        node.ClassifyExample, called with appendExamples=0) differs from
        the true one

    """
    return sum(1 for ex in node.GetExamples()
               if node.ClassifyExample(ex, appendExamples=0) != ex[-1])
43
def _Pruner(node,level=0):
    """Recursively finds and removes the nodes whose removals improve classification

      **Arguments**

        - node: the tree to be pruned.  The pruning data should already be contained
          within node (i.e. node.GetExamples() should return the pruning data)

        - level: (optional) the level of recursion, used only in _verbose printing


      **Returns**

        the pruned version of node.  All edits are made on deep copies
        (bestTree / workTree), so the tree passed in is presumably not
        restructured itself (assumes ReplaceChildIndex only alters the tree
        it is called on -- TODO confirm).


      **Notes**

        - This uses a greedy algorithm which basically does a DFS traversal of the tree,
          removing nodes whenever possible.

        - If removing a node does not affect the accuracy, it *will be* removed.  We
          favor smaller trees.

        - For every non-terminal child that holds examples, two candidates
          are tried in turn: (1) the child pruned recursively, and (2) the
          child replaced by a terminal node that predicts the child's most
          common result code (MaxCount).  A candidate is kept when the error
          of the tree rooted at _node_ (as measured by _GetLocalError) does
          not exceed the best error seen so far.

        - Terminal children, and children holding no pruning examples, are
          left intact.

    """
    if _verbose: print(' '*level,'<%d> '%level,'>>> Pruner')
    # snapshot of the children; index i below refers to positions in this list
    children = node.GetChildren()[:]

    # best tree found so far; candidates are always built on copies of it
    bestTree = copy.deepcopy(node)
    # sentinel larger than any real error count: the first candidate that is
    # evaluated is therefore accepted unconditionally.
    # NOTE(review): presumably fine because a recursively pruned child is
    # itself only ever accepted when it is not worse -- verify
    bestErr = 1e6
    # NOTE(review): never used below
    emptyChildren=[]
    #
    # Loop over the children of this node, removing them when doing so
    # either improves the local error or leaves it unchanged (we're
    # introducing a bias for simpler trees).
    #
    for i in range(len(children)):
        child = children[i]
        examples = child.GetExamples()
        if _verbose:
            print(' '*level,'<%d> '%level,' Child:',i,child.GetLabel())
            bestTree.Print()
            print()
        if len(examples):
            if _verbose: print(' '*level,'<%d> '%level,' Examples',len(examples))
            if not child.GetTerminal():
                if _verbose: print(' '*level,'<%d> '%level,' Nonterminal')

                # scratch copy of the best tree; edits below never touch bestTree
                workTree = copy.deepcopy(bestTree)
                #
                # First recurse on the child (try removing things below it)
                #
                newNode = _Pruner(child,level=level+1)
                workTree.ReplaceChildIndex(i,newNode)
                tempErr = _GetLocalError(workTree)
                if tempErr<=bestErr:
                    # ties are accepted: this favors the (possibly) smaller tree
                    bestErr = tempErr
                    bestTree = copy.deepcopy(workTree)
                    if _verbose:
                        print(' '*level,'<%d> '%level,'>->->->->->')
                        print(' '*level,'<%d> '%level,'replacing:',i,child.GetLabel())
                        child.Print()
                        print(' '*level,'<%d> '%level,'with:')
                        newNode.Print()
                        print(' '*level,'<%d> '%level,'<-<-<-<-<-<')
                else:
                    # restore the unpruned child before trying the leaf below
                    workTree.ReplaceChildIndex(i,child)
                #
                # Now try replacing the child entirely
                #
                # the replacement leaf predicts the child's majority result code
                bestGuess = MaxCount(child.GetExamples())
                newNode = DecTree.DecTreeNode(workTree,'L:%d'%(bestGuess),
                                              label=bestGuess,isTerminal=1)
                newNode.SetExamples(child.GetExamples())
                workTree.ReplaceChildIndex(i,newNode)
                if _verbose:
                    print(' '*level,'<%d> '%level,'ATTEMPT:')
                    workTree.Print()
                newErr = _GetLocalError(workTree)
                if _verbose: print(' '*level,'<%d> '%level,'---> ',newErr,bestErr)
                if newErr <= bestErr:
                    bestErr = newErr
                    bestTree = copy.deepcopy(workTree)
                    if _verbose:
                        print(' '*level,'<%d> '%level,'PRUNING:')
                        workTree.Print()
                else:
                    if _verbose: print(' '*level,'<%d> '%level,'FAIL')
                    # whoops... put the child back in:
                    # (workTree is re-copied from bestTree on the next
                    # iteration and is not used after the loop, so this
                    # restore does not affect the returned tree)
                    workTree.ReplaceChildIndex(i,child)
            else:
                if _verbose: print(' '*level,'<%d> '%level,' Terminal')
        else:
            if _verbose: print(' '*level,'<%d> '%level,' No Examples',len(examples))
            #
            # FIX: we need to figure out what to do here (nodes that contain
            # no examples in the testing set). I can concoct arguments for
            # leaving them in and for removing them. At the moment they are
            # left intact.
            #
            pass

    if _verbose: print(' '*level,'<%d> '%level,'<<< out')
    return bestTree
148
def PruneTree(tree,trainExamples,testExamples,minimizeTestErrorOnly=1):
    """ implements a reduced-error pruning of decision trees

      This algorithm is described on page 69 of Mitchell's book.

      Pruning can be done using just the set of testExamples (the validation set)
      or both the testExamples and the trainExamples by setting minimizeTestErrorOnly
      to 0.

      **Arguments**

        - tree: the initial tree to be pruned.  Any examples stored on _tree_
          are cleared and replaced by the examples used for evaluation, so
          _tree_ is modified in place.

        - trainExamples: the examples used to train the tree

        - testExamples: the examples held out for testing the tree

        - minimizeTestErrorOnly: if this toggle is zero, all examples (i.e.
          _trainExamples_ + _testExamples_) will be used to evaluate the error.

      **Returns**

        a 2-tuple containing:

          1) the best tree

          2) the best error (the one which corresponds to that tree)

    """
    if minimizeTestErrorOnly:
        testSet = testExamples
    else:
        # build a fresh list so this is always a concatenation: a bare "+"
        # adds numpy arrays element-wise and fails for list + tuple
        testSet = list(trainExamples) + list(testExamples)

    # remove any stored examples the tree may have
    tree.ClearExamples()

    #
    # screen the test data through the tree so that we end up with the
    # appropriate points stored at each node of the tree.  Only this side
    # effect is needed here; the error is recomputed after pruning.
    #
    CrossValidate.CrossValidate(tree, testSet, appendExamples=1)

    #
    # Prune
    #
    newTree = _Pruner(tree)

    #
    # And recalculate the errors
    #
    totErr, badEx = CrossValidate.CrossValidate(newTree, testSet)
    newTree.SetBadExamples(badEx)

    return newTree, totErr
# -------
# testing code
# -------
def _testRandom():
    """ smoke test: builds a tree from randomly generated examples, prunes it,
    and prints the tree and its error before and after pruning.

    Side effect: both trees are pickled to the relative paths 'orig.pkl'
    and 'prune.pkl'.
    """
    from rdkit.ML.DecTree import randomtest
    # a larger variant that was also tried:
    #   randomtest.GenRandomExamples(nVars=20, randScale=0.25, nExamples=200)
    data, attrs, nPossibleVals = randomtest.GenRandomExamples(
        nVars=10, randScale=0.5, nExamples=200)
    origTree, origErr = CrossValidate.CrossValidationDriver(data, attrs, nPossibleVals)
    origTree.Print()
    origTree.Pickle('orig.pkl')
    print('original error is:', origErr)

    print('----Pruning')
    trainSet = origTree.GetTrainingExamples()
    holdOut = origTree.GetTestExamples()
    prunedTree, prunedErr = PruneTree(origTree, trainSet, holdOut)
    prunedTree.Print()
    print('pruned error is:', prunedErr)
    prunedTree.Pickle('prune.pkl')
224 225
def _testSpecific():
    """ prunes an ID3 tree built on a tiny hand-made data set and prints the
    training error, the hold-out error before and after pruning, the
    misclassified hold-out examples, and the sizes of both trees
    """
    from rdkit.ML.DecTree import ID3
    oPts = [
        [0, 0, 1, 0],
        [0, 1, 1, 1],
        [1, 0, 1, 1],
        [1, 1, 0, 0],
        [1, 1, 1, 1],
    ]
    # the hold-out set repeats the training points and adds two copies of a
    # point that conflicts with [0, 1, 1, 1]
    tPts = oPts + [[0, 1, 1, 0], [0, 1, 1, 0]]

    # pass a real list: with six.moves.range this is a lazy range object on
    # Python 3, while the code was written against the Python 2 list
    tree = ID3.ID3Boot(oPts, attrs=list(range(3)), nPossibleVals=[2] * 4)
    tree.Print()
    err, badEx = CrossValidate.CrossValidate(tree, oPts)
    print('original error:', err)

    err, badEx = CrossValidate.CrossValidate(tree, tPts)
    print('original holdout error:', err)
    newTree, _ = PruneTree(tree, oPts, tPts)
    newTree.Print()
    err, badEx = CrossValidate.CrossValidate(newTree, tPts)
    print('pruned holdout error is:', err)
    print(badEx)

    print(len(tree), len(newTree))
252
def _testChain():
    """ prunes an ID3 tree built on a small hand-made data set (four binary
    descriptors plus a result code) and prints the error before and after
    pruning together with the misclassified examples
    """
    from rdkit.ML.DecTree import ID3
    oPts = [
        [1, 0, 0, 0, 1],
        [1, 0, 0, 0, 1],
        [1, 0, 0, 0, 1],
        [1, 0, 0, 0, 1],
        [1, 0, 0, 0, 1],
        [1, 0, 0, 0, 1],
        [1, 0, 0, 0, 1],
        [0, 0, 1, 1, 0],
        [0, 0, 1, 1, 0],
        [0, 0, 1, 1, 1],
        [0, 1, 0, 1, 0],
        [0, 1, 0, 1, 0],
        [0, 1, 0, 0, 1],
    ]
    # the "hold-out" set here is simply the training set itself
    tPts = oPts

    # pass a real list: with six.moves.range this is a lazy range object on
    # Python 3, while the code was written against the Python 2 list
    tree = ID3.ID3Boot(oPts, attrs=list(range(len(oPts[0]) - 1)),
                       nPossibleVals=[2] * len(oPts[0]))
    tree.Print()
    err, badEx = CrossValidate.CrossValidate(tree, oPts)
    print('original error:', err)

    err, badEx = CrossValidate.CrossValidate(tree, tPts)
    print('original holdout error:', err)
    newTree, _ = PruneTree(tree, oPts, tPts)
    newTree.Print()
    err, badEx = CrossValidate.CrossValidate(newTree, tPts)
    print('pruned holdout error is:', err)
    print(badEx)
if __name__ == '__main__':
    # enable the module-level tracing output (read by _Pruner) for the demo
    _verbose = 1
    # _testRandom() is an alternative, larger demo
    _testChain()