Package rdkit :: Package ML :: Package DecTree :: Module Forest
[hide private]
[frames] | [no frames]

Source Code for Module rdkit.ML.DecTree.Forest

  1  # 
  2  #  Copyright (C) 2000-2008  greg Landrum 
  3  # 
  4  """ code for dealing with forests (collections) of decision trees 
  5   
  6  **NOTE** This code should be obsolete now that ML.Composite.Composite is up and running. 
  7   
  8  """ 
  9  from __future__ import print_function 
 10  from rdkit.six.moves import cPickle 
 11  import numpy 
 12  from rdkit.ML.DecTree import CrossValidate,PruneTree 
 13   
class Forest(object):
    """A forest of unique decision trees.

    Adding a tree that is already present just increments its count field
    and accumulates its error.

    Typical usage:

      1) grow the forest with AddTree (or Grow) until happy with it

      2) call AverageErrors to calculate the average error values

      3) call SortTrees to put things in order by either error or count

    The forest keeps three parallel lists (treeList, errList, countList);
    index i in each of them refers to the same tree.
    """

    def __init__(self):
        self.treeList = []   # the unique trees
        self.errList = []    # summed (or, after AverageErrors, averaged) errors
        self.countList = []  # number of times each tree was added
        self.treeVotes = []  # per-tree results of the last vote
        self._nPossible = []  # nPossibleVals from the last call to Grow

    def MakeHistogram(self):
        """Creates a histogram of error/count pairs.

        Consecutive trees whose error lies within 0.001 of the first error
        of the current bin are merged into that bin and their counts summed.
        Bins are formed from consecutive entries, so this is presumably
        meant to be called after SortTrees() has sorted on error.

        **Returns**

          a list of (error, count) 2-tuples, one per bin; an empty list if
          the forest holds no trees
        """
        if not self.treeList:
            return []

        eps = 0.001
        histo = []
        lastErr = self.errList[0]
        countHere = self.countList[0]
        for err, count in zip(self.errList[1:], self.countList[1:]):
            if err - lastErr > eps:
                histo.append((lastErr, countHere))
                lastErr, countHere = err, count
            else:
                countHere += count
        # the bin still being accumulated when the loop ends must be emitted too
        histo.append((lastErr, countHere))
        return histo

    def CollectVotes(self, example):
        """Collects votes across every member of the forest for the given example.

        **Arguments**

          - example: the example to classify; it is handed unchanged to each
            tree's ClassifyExample method

        **Returns**

          a list of the results, in the same order as the trees
        """
        return [tree.ClassifyExample(example) for tree in self.treeList]

    def ClassifyExample(self, example):
        """Classifies the given example using the entire forest.

        This is treated as a voting problem: every tree votes with a weight
        equal to its count, and the confidence measure is the fraction of the
        total weight that went to the winning result.  Ties go to the lowest
        result index.

        The per-tree votes are stored and can be read back with
        GetVoteDetails().

        **Arguments**

          - example: the example to classify

        **Returns**

          a 2-tuple: (winning result as an int, confidence as a float)

        **Raises**

          ValueError if the forest contains no trees
        """
        if not self.treeList:
            raise ValueError('cannot classify an example with an empty forest')

        self.treeVotes = self.CollectVotes(example)
        # _nPossible is only set by Grow (and may be absent on instances
        # unpickled from older versions), so make sure the tally is always
        # big enough to hold every result that was actually returned.
        nSlots = max(len(getattr(self, '_nPossible', ())), max(self.treeVotes) + 1)
        votes = [0] * nSlots
        for vote, count in zip(self.treeVotes, self.countList):
            votes[vote] += count

        totVotes = sum(votes)
        res = int(numpy.argmax(votes))
        return res, float(votes[res]) / float(totVotes)

    def GetVoteDetails(self):
        """Returns the details of the last vote the forest conducted.

        This will be an empty list if no voting has yet been done.
        """
        return self.treeVotes

    def Grow(self, examples, attrs, nPossibleVals, nTries=10, pruneIt=0, lessGreedy=0):
        """Grows the forest by adding trees.

        **Arguments**

          - examples: the examples to be used for training

          - attrs: a list of the attributes to be used in training

          - nPossibleVals: a list with the number of possible values each variable
            (as well as the result) can take on

          - nTries: the number of new trees to add

          - pruneIt: a toggle for whether or not the tree should be pruned

          - lessGreedy: toggles the use of a less greedy construction algorithm where
            each possible tree root is used.  The best tree from each step is actually
            added to the forest.

        Progress ('Cycle: ...') is printed roughly ten times over the run.
        """
        self._nPossible = nPossibleVals
        # never let the reporting interval drop to zero (nTries < 10)
        reportEvery = max(nTries // 10, 1)
        for i in range(nTries):
            tree, frac = CrossValidate.CrossValidationDriver(examples, attrs, nPossibleVals,
                                                             silent=1, calcTotalError=1,
                                                             lessGreedy=lessGreedy)
            if pruneIt:
                tree, frac2 = PruneTree.PruneTree(tree, tree.GetTrainingExamples(),
                                                  tree.GetTestExamples(),
                                                  minimizeTestErrorOnly=0)
                print('prune: ', frac, frac2)
                frac = frac2
            self.AddTree(tree, frac)
            if i % reportEvery == 0:
                print('Cycle: % 4d' % (i))

    def Pickle(self, fileName='foo.pkl'):
        """Writes this forest off to a file so that it can be easily loaded later.

        **Arguments**

          - fileName: the name of the file to be written
        """
        # the context manager closes the file even if dump() raises
        with open(fileName, 'wb+') as pFile:
            cPickle.dump(self, pFile, 1)

    def AddTree(self, tree, error):
        """Adds a tree to the forest.

        If an identical tree is already present, its count is incremented
        and the error is added to its running total.

        **Arguments**

          - tree: the new tree

          - error: its error value

        **NOTE:** the errList is run as an accumulator,
          you probably want to call AverageErrors after finishing the forest
        """
        try:
            idx = self.treeList.index(tree)
        except ValueError:
            # not seen before: start a new entry
            self.treeList.append(tree)
            self.errList.append(error)
            self.countList.append(1)
        else:
            self.errList[idx] += error
            self.countList[idx] += 1

    def AverageErrors(self):
        """Converts summed errors to average errors.

        This does the conversion in place, so it should only be called once
        after the forest has been built.
        """
        self.errList = [err / count for err, count in zip(self.errList, self.countList)]

    def SortTrees(self, sortOnError=1):
        """Sorts the list of trees (ascending).

        **Arguments**

          - sortOnError: toggles sorting on the trees' errors rather than their counts
        """
        keys = self.errList if sortOnError else self.countList
        order = numpy.argsort(keys)

        # the results are kept as plain python lists (not numpy arrays)
        # so that the forest pickles cleanly
        self.treeList = [self.treeList[x] for x in order]
        self.countList = [self.countList[x] for x in order]
        self.errList = [self.errList[x] for x in order]

    def GetTree(self, i):
        """Returns tree number i."""
        return self.treeList[i]

    def SetTree(self, i, val):
        """Replaces tree number i with val."""
        self.treeList[i] = val

    def GetCount(self, i):
        """Returns the count of tree number i."""
        return self.countList[i]

    def SetCount(self, i, val):
        """Sets the count of tree number i to val."""
        self.countList[i] = val

    def GetError(self, i):
        """Returns the error of tree number i."""
        return self.errList[i]

    def SetError(self, i, val):
        """Sets the error of tree number i to val."""
        self.errList[i] = val

    def GetDataTuple(self, i):
        """Returns all relevant data about a particular tree in the forest.

        **Arguments**

          - i: an integer indicating which tree should be returned

        **Returns**

          a 3-tuple consisting of:

            1) the tree

            2) its count

            3) its error
        """
        return (self.treeList[i], self.countList[i], self.errList[i])

    def SetDataTuple(self, i, tup):
        """Sets all relevant data for a particular tree in the forest.

        **Arguments**

          - i: an integer indicating which tree should be set

          - tup: a 3-tuple consisting of:

            1) the tree

            2) its count

            3) its error
        """
        self.treeList[i], self.countList[i], self.errList[i] = tup

    def GetAllData(self):
        """Returns everything we know.

        **Returns**

          a 3-tuple consisting of:

            1) our list of trees

            2) our list of tree counts

            3) our list of tree errors

        The lists themselves (not copies) are returned.
        """
        return (self.treeList, self.countList, self.errList)

    def __len__(self):
        """Allows len(forest) to work; the number of unique trees."""
        return len(self.treeList)

    def __getitem__(self, which):
        """Allows forest[i] to work; returns the data tuple for tree i."""
        return self.GetDataTuple(which)

    def __str__(self):
        """Allows the forest to show itself as a string, one line per tree."""
        pieces = ['Forest\n']
        for i, (count, err) in enumerate(zip(self.countList, self.errList)):
            pieces.append(' Tree % 4d: % 5d occurances %%% 5.2f average error\n' %
                          (i, count, 100. * err))
        return ''.join(pieces)
if __name__ == '__main__':
    # small smoke test: add the same node twice, then average, sort and
    # print the resulting single-tree forest
    from rdkit.ML.DecTree import DecTree
    forest = Forest()
    node = DecTree.DecTreeNode(None, 'foo')
    for _ in range(2):
        forest.AddTree(node, 0.5)
    forest.AverageErrors()
    forest.SortTrees()
    print(forest)