Package rdkit :: Package ML :: Package DecTree :: Module QuantTree
[hide private]
[frames] | no frames]

Source Code for Module rdkit.ML.DecTree.QuantTree

 1  # $Id$ 
 2  # 
 3  #  Copyright (C) 2001, 2003  greg Landrum and Rational Discovery LLC 
 4  #   All Rights Reserved 
 5  # 
 6  """ Defines the class _QuantTreeNode_, used to represent decision trees with automatic 
 7   quantization bounds 
 8   
 9    _QuantTreeNode_ is derived from _DecTree.DecTreeNode_ 
10   
11  """ 
12  from rdkit.ML.DecTree import DecTree,Tree 
13  from rdkit.six import cmp 
14   
class QuantTreeNode(DecTree.DecTreeNode):
    """ A decision tree node that can quantize the value it splits on

    _QuantTreeNode_ extends _DecTree.DecTreeNode_ with two attributes:

      - qBounds: the list of quantization bounds used by
        _ClassifyExample_ to turn the value of the example's
        _label_-th entry into a child index

      - nBounds: the number of bounds in _qBounds_; it is set by
        _SetQuantBounds_, and zero means "no quantization"

    """

    def __init__(self, *args, **kwargs):
        DecTree.DecTreeNode.__init__(self, *args, **kwargs)
        self.qBounds = []
        self.nBounds = 0

    def ClassifyExample(self, example, appendExamples=0):
        """ Recursively classify an example by running it through the tree

          **Arguments**

            - example: the example to be classified

            - appendExamples: if this is nonzero then this node, and every
              node visited below it on the way to a terminal node, will
              store the example

          **Returns**

            the classification of _example_ (the _label_ of the terminal
            node that is reached)

          **Quantization**

            At a non-terminal node, _example[self.label]_ is mapped to a
            child index: the position of the first entry of _qBounds_ that
            the value is strictly less than, or _len(qBounds)_ if there is
            no such entry.  If the node has no bounds, the value is simply
            converted with _int()_.

          **NOTE:**
            In the interest of speed, I don't use accessor functions
            here. So if you subclass DecTreeNode for your own trees, you'll
            have to either include ClassifyExample or avoid changing the names
            of the instance variables this needs.

        """
        if appendExamples:
            self.examples.append(example)
        if self.terminalNode:
            return self.label
        else:
            val = example[self.label]
            # Presumably guards nodes that were created without an nBounds
            # attribute (e.g. restored from older pickles) -- NOTE(review): confirm.
            if not hasattr(self, 'nBounds'):
                self.nBounds = len(self.qBounds)
            if self.nBounds:
                for i, bound in enumerate(self.qBounds):
                    if val < bound:
                        val = i
                        break
                else:
                    # the value is not below any bound: use the last bin
                    val = i + 1
            else:
                val = int(val)
            return self.children[val].ClassifyExample(example, appendExamples=appendExamples)

    def SetQuantBounds(self, qBounds):
        """ Store a copy of _qBounds_ and update the cached _nBounds_ """
        self.qBounds = qBounds[:]
        self.nBounds = len(self.qBounds)

    def GetQuantBounds(self):
        """ Return the quantization bounds list

          **Note**

            this is the node's own list, not a copy; modifying it in place
            does not update _nBounds_, so use _SetQuantBounds_ instead

        """
        return self.qBounds

    def __cmp__(self, other):
        """ Python-2 style three-way comparison derived from __lt__ """
        return (self < other) * -1 or (other < self) * 1

    def __lt__(self, other):
        """ Order nodes by type name, then by qBounds, then by the base
        _Tree.TreeNode_ ordering

        The comparison is lexicographic: a later key is consulted only when
        all earlier keys are equal.  This keeps the ordering consistent, so
        that at most one of (a < b) and (b < a) is true.
        """
        selfType = str(type(self))
        otherType = str(type(other))
        if selfType != otherType:
            return selfType < otherType
        if self.qBounds != other.qBounds:
            return self.qBounds < other.qBounds
        return bool(Tree.TreeNode.__lt__(self, other))

    def __eq__(self, other):
        """ Two nodes are equal when neither sorts before the other """
        return not self < other and not other < self

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so keep them in agreement.
        return not self == other

    def __str__(self):
        """ returns a string representation of the tree

          **Note**

            this works recursively

        """
        pieces = ['%s%s %s\n' % (' ' * self.level, self.name, str(self.qBounds))]
        pieces.extend(str(child) for child in self.children)
        return ''.join(pieces)