1
2
3
4
5
6
7
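"""Building of quantized decision trees (QuantTree.QuantTreeNode).

Continuous variables are split on quantization bounds found by
Quantize.FindVarMultQuantBounds; discrete variables are split on each of
their possible values.
"""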
11 from __future__ import print_function
12 import numpy
13 import random
14 from rdkit.ML.DecTree import QuantTree, ID3
15 from rdkit.ML.InfoTheory import entropy
16 from rdkit.ML.Data import Quantize
17 from rdkit.six.moves import range
18
def FindBest(resCodes, examples, nBoundsPerVar, nPossibleRes,
             nPossibleVals, attrs, exIndices=None, **kwargs):
    """ Finds the variable with the largest information gain

    **Arguments**

      - resCodes: the result codes of the examples selected by _exIndices_,
        in the same order

      - examples: a list of examples; examples[i][var] is the value of
        variable _var_ for example _i_

      - nBoundsPerVar: the number of quantization bounds for each variable.
        >0: the variable is quantized with that many bounds;
        0: the variable is used as-is (its gain comes from ID3.GenVarTable);
        <0: the variable gets a gain of -1e6, so it is effectively never
        chosen

      - nPossibleRes: the number of possible result codes

      - nPossibleVals: the number of possible values of each variable

      - attrs: the indices of the variables that may be chosen

      - exIndices: (optional) indices of the examples to consider;
        defaults to all of them

      - randomDescriptors: (optional keyword) if >0 and smaller than
        len(attrs), only that many randomly chosen variables are considered

    **Returns**

      a 3-tuple: (best variable, its gain, its quantization bounds).
      The variable is -1 and the gain -1e6 if no variable was chosen
      (including when there are no examples).

    """
    bestGain = -1e6
    best = -1
    bestBounds = []

    if exIndices is None:
        exIndices = list(range(len(examples)))
    if not len(exIndices):
        return best, bestGain, bestBounds

    nToTake = kwargs.get('randomDescriptors', 0)
    if 0 < nToTake < len(attrs):
        # Fisher-Yates shuffle driven by random.random(). This is exactly what
        # random.shuffle(ids, random=random.random) did, so seeded runs pick
        # the same descriptor subsets, but it avoids the `random` argument,
        # which was removed from random.shuffle in Python 3.11.
        ids = list(range(len(attrs)))
        for i in reversed(range(1, len(ids))):
            j = int(random.random() * (i + 1))
            ids[i], ids[j] = ids[j], ids[i]
        attrs = [attrs[x] for x in ids[:nToTake]]

    for var in attrs:
        nBounds = nBoundsPerVar[var]
        if nBounds > 0:
            try:
                vTable = [examples[x][var] for x in exIndices]
            except IndexError:
                print('index error retrieving variable: %d' % var)
                raise
            qBounds, gainHere = Quantize.FindVarMultQuantBounds(vTable, nBounds,
                                                                resCodes, nPossibleRes)
        elif nBounds == 0:
            vTable = ID3.GenVarTable((examples[x] for x in exIndices),
                                     nPossibleVals, [var])[0]
            gainHere = entropy.InfoGain(vTable)
            qBounds = []
        else:
            gainHere = -1e6
            qBounds = []

        if gainHere > bestGain:
            bestGain = gainHere
            bestBounds = qBounds
            best = var
        elif gainHere == bestGain and len(qBounds) < len(bestBounds):
            # on a tie, prefer the variable that needs fewer bounds
            best = var
            bestBounds = qBounds

    if best == -1:
        # Diagnostics: no variable could be chosen. (This used to call an
        # undefined `take` and die with a NameError.)
        print('best unaltered')
        print('\tattrs:', attrs)
        print('\tnBounds:', [nBoundsPerVar[x] for x in attrs])
        print('\texamples:')
        for example in (examples[x] for x in exIndices):
            print('\t\t', example)

    return best, bestGain, bestBounds
89
90
def BuildQuantTree(examples, target, attrs, nPossibleVals, nBoundsPerVar,
                   depth=0, maxDepth=-1, exIndices=None, **kwargs):
    """ Recursively builds a quantized decision tree

    **Arguments**

      - examples: a list of lists (nInstances x nVariables+1) of variable
        values + instance values; the last entry of each example is its
        integer result code

      - target: not used by this function; recursive calls pass the
        variable just split on. Kept for interface compatibility.

      - attrs: a list of ints indicating which variables can be used in the tree

      - nPossibleVals: a list containing the number of possible values of
        every variable; its last entry is the number of possible result codes

      - nBoundsPerVar: the number of bounds to include for each variable
        (0 means the variable is split on each of its possible values)

      - depth: (optional) the current depth in the tree

      - maxDepth: (optional) the maximum depth to which the tree
        will be grown (a negative value means no limit)

      - exIndices: (optional) indices into _examples_ of the examples to
        use at this node; defaults to all of them

      - keyword arguments: _recycleVars_ (if true, a variable may be split
        on again further down the tree); all keyword arguments are also
        passed to FindBest (e.g. _randomDescriptors_) and down the recursion

    **Returns**

      a QuantTree.QuantTreeNode with the decision tree

    **NOTE:** This code cannot bootstrap (start from nothing...)
          use _QuantTreeBoot_ (below) for that.
    """
    tree = QuantTree.QuantTreeNode(None, 'node')
    tree.SetData(-666)
    nPossibleRes = nPossibleVals[-1]

    if exIndices is None:
        exIndices = list(range(len(examples)))

    # the last entry of each example is its result code
    resCodes = [int(examples[y][-1]) for y in exIndices]
    counts = [0] * nPossibleRes
    for res in resCodes:
        counts[res] += 1
    nzCounts = numpy.nonzero(counts)[0]

    if len(nzCounts) == 1:
        # every example has the same result code: pure leaf
        res = nzCounts[0]
        tree.SetLabel(res)
        tree.SetName(str(res))
        tree.SetTerminal(1)
        return tree

    best = -1
    if len(attrs) > 0 and not (maxDepth >= 0 and depth > maxDepth):
        best, _, bestBounds = FindBest(resCodes, examples, nBoundsPerVar,
                                       nPossibleRes, nPossibleVals, attrs,
                                       exIndices=exIndices, **kwargs)
    if best == -1:
        # Nothing left to split on, too deep, or FindBest found no usable
        # variable (which previously crashed on nextAttrs.remove(-1), or with
        # recycleVars split on the result column): majority-vote leaf.
        v = numpy.argmax(counts)
        tree.SetLabel(v)
        tree.SetName('%d?' % v)
        tree.SetTerminal(1)
        return tree

    nextAttrs = list(attrs)
    if not kwargs.get('recycleVars', 0):
        nextAttrs.remove(best)

    tree.SetName('Var: %d' % (best))
    tree.SetLabel(best)
    tree.SetQuantBounds(bestBounds)
    tree.SetTerminal(0)

    def _growChild(childIndices):
        # Adds a subtree built on childIndices, or a majority-vote leaf
        # (for this node's examples) when no example reaches the child.
        if len(childIndices):
            tree.AddChildNode(BuildQuantTree(examples, best, nextAttrs, nPossibleVals,
                                             nBoundsPerVar, depth=depth + 1,
                                             maxDepth=maxDepth, exIndices=childIndices,
                                             **kwargs))
        else:
            v = numpy.argmax(counts)
            tree.AddChild('%d' % v, label=v, data=0.0, isTerminal=1)

    if len(bestBounds) > 0:
        # One child per bound holding the (not yet assigned) examples below
        # it, plus a final child for everything at or above the last bound.
        # Linear partition; the original removed from a list inside a loop.
        remaining = list(exIndices)
        for bound in bestBounds:
            below, notBelow = [], []
            for idx in remaining:
                if examples[idx][best] < bound:
                    below.append(idx)
                else:
                    notBelow.append(idx)
            remaining = notBelow
            _growChild(below)
        _growChild(remaining)
    else:
        # discrete variable: one child per possible value
        for val in range(nPossibleVals[best]):
            _growChild([idx for idx in exIndices if examples[idx][best] == val])
    return tree
224
def QuantTreeBoot(examples, attrs, nPossibleVals, nBoundsPerVar, initialVar=None,
                  maxDepth=-1, **kwargs):
    """ Bootstrapping code for the QuantTree

    If _initialVar_ is not set, the algorithm will automatically
    choose the first variable in the tree (the standard greedy
    approach). Otherwise, _initialVar_ will be used as the first
    split.

    **Arguments**

      - examples: a list of examples; the last entry of each is its
        integer result code

      - attrs: the variables that may be used; any variable whose
        _nBoundsPerVar_ entry is -1 is dropped

      - nPossibleVals: the number of possible values of every variable;
        its last entry is the number of possible result codes

      - nBoundsPerVar: the number of quantization bounds for each variable
        (0 means the variable is split on each of its possible values)

      - initialVar: (optional) the variable to split on at the root

      - maxDepth: (optional) the maximum depth of the tree (negative means
        no limit)

      - keyword arguments are passed to FindBest and BuildQuantTree
        (e.g. _recycleVars_, _randomDescriptors_)

    **Returns**

      the root QuantTree.QuantTreeNode

    """
    # variables flagged with -1 bounds may not be used at all
    attrs = list(attrs)
    for i, nBounds in enumerate(nBoundsPerVar):
        if nBounds == -1 and i in attrs:
            attrs.remove(i)

    tree = QuantTree.QuantTreeNode(None, 'node')
    nPossibleRes = nPossibleVals[-1]
    tree._nResultCodes = nPossibleRes

    resCodes = [int(x[-1]) for x in examples]
    counts = [0] * nPossibleRes
    for res in resCodes:
        counts[res] += 1

    if initialVar is None:
        best, gainHere, qBounds = FindBest(resCodes, examples, nBoundsPerVar,
                                           nPossibleRes, nPossibleVals, attrs,
                                           **kwargs)
    else:
        best = initialVar
        if nBoundsPerVar[best] > 0:
            # A real list, as FindBest passes: on Python 3 map() returns a lazy
            # one-shot iterator rather than a sequence.
            vTable = [x[best] for x in examples]
            qBounds, gainHere = Quantize.FindVarMultQuantBounds(vTable, nBoundsPerVar[best],
                                                                resCodes, nPossibleRes)
        elif nBoundsPerVar[best] == 0:
            vTable = ID3.GenVarTable(examples, nPossibleVals, [best])[0]
            gainHere = entropy.InfoGain(vTable)
            qBounds = []
        else:
            gainHere = -1e6
            qBounds = []

    tree.SetName('Var: %d' % (best))
    tree.SetData(gainHere)
    tree.SetLabel(best)
    tree.SetTerminal(0)
    tree.SetQuantBounds(qBounds)
    nextAttrs = list(attrs)
    if not kwargs.get('recycleVars', 0):
        nextAttrs.remove(best)

    def _growChild(childExamples):
        # Adds a depth-1 subtree built on childExamples, or a majority-vote
        # leaf (over all examples) when no example reaches the child.
        if len(childExamples):
            tree.AddChildNode(BuildQuantTree(childExamples, best, nextAttrs, nPossibleVals,
                                             nBoundsPerVar, depth=1, maxDepth=maxDepth,
                                             **kwargs))
        else:
            v = numpy.argmax(counts)
            tree.AddChild('%d??' % (v), label=v, data=0.0, isTerminal=1)

    if len(qBounds) > 0:
        # One child per bound with the (not yet assigned) examples below it,
        # then a final child for everything at or above the last bound.
        remaining = list(examples)
        for bound in qBounds:
            below, notBelow = [], []
            for ex in remaining:
                if ex[best] < bound:
                    below.append(ex)
                else:
                    notBelow.append(ex)
            remaining = notBelow
            _growChild(below)
        _growChild(remaining)
    else:
        # discrete variable: one child per possible value
        for val in range(nPossibleVals[best]):
            _growChild([ex for ex in examples if ex[best] == val])
    return tree
323
324
def TestTree():
    """ testing code for named trees

    Each example starts with a name ('p1'...), followed by three discrete
    variables and the result code. Builds an ID3 tree limited to
    maxDepth=1 and prints it.

    """
    examples1 = [['p1', 0, 1, 0, 0],
                 ['p2', 0, 0, 0, 1],
                 ['p3', 0, 0, 1, 2],
                 ['p4', 0, 1, 1, 2],
                 ['p5', 1, 0, 0, 2],
                 ['p6', 1, 0, 1, 2],
                 ['p7', 1, 1, 0, 2],
                 ['p8', 1, 1, 1, 0]
                 ]
    # every column except the leading name and the trailing result code
    attrs = list(range(1, len(examples1[0]) - 1))
    nPossibleVals = [0, 2, 2, 2, 3]
    t1 = ID3.ID3Boot(examples1, attrs, nPossibleVals, maxDepth=1)
    t1.Print()
342
343
def TestQuantTree():
    """ testing code for named trees with a quantized variable

    Variable 3 is continuous and is quantized with one bound; variables
    1 and 2 are discrete. Builds the tree without and with a depth limit,
    pickling and printing each one.

    """
    examples1 = [['p1', 0, 1, 0.1, 0],
                 ['p2', 0, 0, 0.1, 1],
                 ['p3', 0, 0, 1.1, 2],
                 ['p4', 0, 1, 1.1, 2],
                 ['p5', 1, 0, 0.1, 2],
                 ['p6', 1, 0, 1.1, 2],
                 ['p7', 1, 1, 0.1, 2],
                 ['p8', 1, 1, 1.1, 0]
                 ]
    # every column except the leading name and the trailing result code
    attrs = list(range(1, len(examples1[0]) - 1))
    nPossibleVals = [0, 2, 2, 0, 3]
    boundsPerVar = [0, 0, 0, 1, 0]

    print('base')
    t1 = QuantTreeBoot(examples1, attrs, nPossibleVals, boundsPerVar)
    t1.Pickle('test_data/QuantTree1.pkl')
    t1.Print()

    print('depth limit')
    t1 = QuantTreeBoot(examples1, attrs, nPossibleVals, boundsPerVar, maxDepth=1)
    # NOTE(review): this overwrites the QuantTree1.pkl written just above;
    # confirm which of the two trees consumers of that file expect.
    t1.Pickle('test_data/QuantTree1.pkl')
    t1.Print()
370
def TestQuantTree2():
    """ testing code for named trees with two quantized variables

    Variables 1 and 3 are continuous and are each quantized with one
    bound; variable 2 is discrete. Builds, prints and pickles the tree,
    then prints the classification of every training example.

    """
    examples1 = [['p1', 0.1, 1, 0.1, 0],
                 ['p2', 0.1, 0, 0.1, 1],
                 ['p3', 0.1, 0, 1.1, 2],
                 ['p4', 0.1, 1, 1.1, 2],
                 ['p5', 1.1, 0, 0.1, 2],
                 ['p6', 1.1, 0, 1.1, 2],
                 ['p7', 1.1, 1, 0.1, 2],
                 ['p8', 1.1, 1, 1.1, 0]
                 ]
    # every column except the leading name and the trailing result code
    attrs = list(range(1, len(examples1[0]) - 1))
    nPossibleVals = [0, 0, 2, 0, 3]
    boundsPerVar = [0, 1, 0, 1, 0]

    t1 = QuantTreeBoot(examples1, attrs, nPossibleVals, boundsPerVar)
    t1.Print()
    t1.Pickle('test_data/QuantTree2.pkl')

    for example in examples1:
        print(example, t1.ClassifyExample(example))
394
if __name__ == "__main__":
    # Run the demo builders, in order, when executed as a script.
    for _demo in (TestTree, TestQuantTree):
        _demo()
398
399