4 """ handles doing cross validation with decision trees
5
6 This is, perhaps, a little misleading. For the purposes of this module,
7 cross validation == evaluating the accuracy of a tree.
8
9
10 """
from __future__ import print_function
from rdkit.ML.DecTree import ID3
from rdkit.ML.Data import SplitData
import numpy
from rdkit.six.moves import xrange


def ChooseOptimalRoot(examples, trainExamples, testExamples, attrs, nPossibleVals,
                      treeBuilder, nQuantBounds=[], **kwargs):
20 """ loops through all possible tree roots and chooses the one which produces the best tree
21
22 **Arguments**
23
24 - examples: the full set of examples
25
26 - trainExamples: the training examples
27
28 - testExamples: the testing examples
29
30 - attrs: a list of attributes to consider in the tree building
31
32 - nPossibleVals: a list of the number of possible values each variable can adopt
33
34 - treeBuilder: the function to be used to actually build the tree
35
36 - nQuantBounds: an optional list. If present, it's assumed that the builder
37 algorithm takes this argument as well (for building QuantTrees)
38
39 **Returns**
40
41 The best tree found
42
43 **Notes**
44
45 1) Trees are built using _trainExamples_
46
47 2) Testing of each tree (to determine which is best) is done using _CrossValidate_ and
48 the entire set of data (i.e. all of _examples_)
49
50 3) _trainExamples_ is not used at all, which immediately raises the question of
51 why it's even being passed in
52
53 """
  attrs = attrs[:]
  if nQuantBounds:
    for i in range(len(nQuantBounds)):
      if nQuantBounds[i] == -1 and i in attrs:
        attrs.remove(i)
  nAttrs = len(attrs)
  trees = [None] * nAttrs
  errs = [0] * nAttrs
  # slot 0 is never filled by the loop below, so seed it with a large error
  # to keep argmin from selecting it
  errs[0] = 1e6

  for i in xrange(1, nAttrs):
    argD = {'initialVar': attrs[i]}
    argD.update(kwargs)
    if nQuantBounds is None or nQuantBounds == []:
      trees[i] = treeBuilder(trainExamples, attrs, nPossibleVals, **argD)
    else:
      trees[i] = treeBuilder(trainExamples, attrs, nPossibleVals, nQuantBounds, **argD)
    if trees[i]:
      errs[i], _ = CrossValidate(trees[i], examples, appendExamples=0)
    else:
      errs[i] = 1e6
  best = numpy.argmin(errs)

  return trees[best]
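
# Illustrative sketch (not part of the original module): one way to call
# ChooseOptimalRoot with the standard ID3 builder.  The helper name and the idea of
# passing an empty list for the unused testExamples argument are assumptions made
# purely for demonstration.
def _chooseRootExample(examples, trainExamples, attrs, nPossibleVals):
  # builds one candidate tree per root variable and returns the one with the
  # lowest CrossValidate() error over the full example set
  return ChooseOptimalRoot(examples, trainExamples, [], attrs, nPossibleVals,
                           ID3.ID3Boot)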

def CrossValidate(tree, testExamples, appendExamples=0):
80 """ Determines the classification error for the testExamples
81
82 **Arguments**
83
84 - tree: a decision tree (or anything supporting a _ClassifyExample()_ method)
85
86 - testExamples: a list of examples to be used for testing
87
88 - appendExamples: a toggle which is passed along to the tree as it does
89 the classification. The trees can use this to store the examples they
90 classify locally.
91
92 **Returns**
93
94 a 2-tuple consisting of:
95
96 1) the percent error of the tree
97
98 2) a list of misclassified examples
99
100 """
  nTest = len(testExamples)
  nBad = 0
  badExamples = []
  for i in xrange(nTest):
    testEx = testExamples[i]
    trueRes = testEx[-1]
    res = tree.ClassifyExample(testEx, appendExamples)
    if (trueRes != res).any():
      badExamples.append(testEx)
      nBad += 1

  return float(nBad) / nTest, badExamples
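
# Illustrative sketch (not part of the original module): a small wrapper showing how
# CrossValidate() is typically consumed.  The helper name and the print format are
# assumptions made for demonstration only.
def _reportError(tree, holdOutExamples):
  # errFrac is the fraction misclassified; badExamples are the offending examples
  errFrac, badExamples = CrossValidate(tree, holdOutExamples, appendExamples=0)
  print('%d of %d examples misclassified (%.2f%%)' % (len(badExamples), len(holdOutExamples),
                                                       100 * errFrac))
  return errFrac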

def CrossValidationDriver(examples, attrs, nPossibleVals, holdOutFrac=.3, silent=0,
                          calcTotalError=0, treeBuilder=ID3.ID3Boot, lessGreedy=0,
                          startAt=None, nQuantBounds=[], maxDepth=-1, **kwargs):
121 """ Driver function for building trees and doing cross validation
122
123 **Arguments**
124
125 - examples: the full set of examples
126
127 - attrs: a list of attributes to consider in the tree building
128
129 - nPossibleVals: a list of the number of possible values each variable can adopt
130
131 - holdOutFrac: the fraction of the data which should be reserved for the hold-out set
132 (used to calculate the error)
133
134 - silent: a toggle used to control how much visual noise this makes as it goes.
135
136 - calcTotalError: a toggle used to indicate whether the classification error
137 of the tree should be calculated using the entire data set (when true) or just
138 the training hold out set (when false)
139
140 - treeBuilder: the function to call to build the tree
141
142 - lessGreedy: toggles use of the less greedy tree growth algorithm (see
143 _ChooseOptimalRoot_).
144
145 - startAt: forces the tree to be rooted at this descriptor
146
147 - nQuantBounds: an optional list. If present, it's assumed that the builder
148 algorithm takes this argument as well (for building QuantTrees)
149
150 - maxDepth: an optional integer. If present, it's assumed that the builder
151 algorithm takes this argument as well
152
153 **Returns**
154
155 a 2-tuple containing:
156
157 1) the tree
158
159 2) the cross-validation error of the tree
160
161 """
  nTot = len(examples)
  if not kwargs.get('replacementSelection', 0):
    testIndices, trainIndices = SplitData.SplitIndices(nTot, holdOutFrac, silent=1, legacy=1,
                                                       replacement=0)
  else:
    testIndices, trainIndices = SplitData.SplitIndices(nTot, holdOutFrac, silent=1, legacy=0,
                                                       replacement=1)
  trainExamples = [examples[x] for x in trainIndices]
  testExamples = [examples[x] for x in testIndices]

  nTrain = len(trainExamples)
  if not silent:
    print('Training with %d examples' % (nTrain))

  if not lessGreedy:
    if nQuantBounds is None or nQuantBounds == []:
      tree = treeBuilder(trainExamples, attrs, nPossibleVals, initialVar=startAt,
                         maxDepth=maxDepth, **kwargs)
    else:
      tree = treeBuilder(trainExamples, attrs, nPossibleVals, nQuantBounds,
                         initialVar=startAt, maxDepth=maxDepth, **kwargs)
  else:
    tree = ChooseOptimalRoot(examples, trainExamples, testExamples, attrs, nPossibleVals,
                             treeBuilder, nQuantBounds, maxDepth=maxDepth, **kwargs)

  nTest = len(testExamples)
  if not silent:
    print('Testing with %d examples' % nTest)
  if not calcTotalError:
    xValError, badExamples = CrossValidate(tree, testExamples, appendExamples=1)
  else:
    xValError, badExamples = CrossValidate(tree, examples, appendExamples=0)
  if not silent:
    print('Validation error was %%%4.2f' % (100 * xValError))
  tree.SetBadExamples(badExamples)
  tree.SetTrainingExamples(trainExamples)
  tree.SetTestExamples(testExamples)
  tree._trainIndices = trainIndices
  return tree, xValError
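
# Illustrative sketch (not part of the original module): the same driver run with the
# less-greedy root search and a smaller hold-out set.  The helper name and the
# particular argument values are assumptions chosen purely for demonstration.
def _lessGreedyExample(examples, attrs, nPossibleVals):
  # lessGreedy=1 routes tree construction through ChooseOptimalRoot() above
  tree, err = CrossValidationDriver(examples, attrs, nPossibleVals, holdOutFrac=.25,
                                    lessGreedy=1, silent=1)
  return tree, err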


def TestRun():
  """ testing code

  """
  from rdkit.ML.DecTree import randomtest
  examples, attrs, nPossibleVals = randomtest.GenRandomExamples(nExamples=200)
  tree, frac = CrossValidationDriver(examples, attrs, nPossibleVals)

  tree.Pickle('save.pkl')

  import copy
  t2 = copy.deepcopy(tree)
  print('t1 == t2', tree == t2)
  l = [tree]
  print('t2 in [tree]', t2 in l, l.index(t2))


if __name__ == '__main__':
  TestRun()