Package rdkit :: Package ML :: Package DecTree :: Module TreeVis
[hide private]
[frames] | [no frames]

Source Code for Module rdkit.ML.DecTree.TreeVis

  1  # $Id$ 
  2  # 
  3  #  Copyright (C) 2002,2003  Greg Landrum and Rational Discovery LLC 
  4  #    All Rights Reserved 
  5  # 
  6  """ functionality for drawing trees on sping canvases 
  7   
  8  """     
  9  from rdkit.sping import pid as piddle 
 10  import math 
 11   
class VisOpts(object):
    """Sizes, colors and fonts used when drawing trees.

    Everything is a class attribute; the drawing functions read them through
    the module-level ``visOpts`` instance created right after the class.
    """
    # circle radius for nodes, and the radius range used when leaves are
    # scaled by their example counts
    circRad = 10
    minCircRad = 4
    maxCircRad = 16

    # fill for interior nodes
    circColor = piddle.Color(0.6, 0.6, 0.9)
    # outline for leaves that carry no examples
    terminalEmptyColor = piddle.Color(.8, .8, .2)
    # leaf fills are interpolated between these two by label value
    terminalOnColor = piddle.Color(0.8, 0.8, 0.8)
    terminalOffColor = piddle.Color(0.2, 0.2, 0.2)

    # node outlines and parent->child connector lines
    outlineColor = piddle.transparent
    lineColor = piddle.Color(0, 0, 0)
    lineWidth = 2

    # spacing between sibling nodes and between tree levels
    horizOffset = 10
    vertOffset = 50

    labelFont = piddle.Font(face='helvetica', size=10)

    highlightColor = piddle.Color(1., 1., .4)
    highlightWidth = 2


visOpts = VisOpts()
def CalcTreeNodeSizes(node):
    """Recursively annotate ``node`` and every node below it with size data.

    Attributes set on each node of the subtree:
      - ``totNChildren``: number of leaves in the subtree (1 for a leaf)
      - ``nLevelsBelow``: number of levels in the subtree (1 for a leaf)
      - ``nExamples``: ``len(node.GetExamples())``
    """
    kids = node.GetChildren()
    leafCount = 1
    depthBelow = 0
    if len(kids) > 0:
        leafCount = 0
        for kid in kids:
            # children must be annotated before we can read their sizes
            CalcTreeNodeSizes(kid)
            leafCount += kid.totNChildren
            depthBelow = max(depthBelow, kid.nLevelsBelow)

    node.nExamples = len(node.GetExamples())
    node.totNChildren = leafCount
    node.nLevelsBelow = depthBelow + 1
53
54 -def _ExampleCounter(node,min,max):
55 if node.GetTerminal(): 56 cnt = node.nExamples 57 if cnt < min: min = cnt 58 if cnt > max: max = cnt 59 else: 60 for child in node.GetChildren(): 61 provMin,provMax = _ExampleCounter(child,min,max) 62 if provMin < min: min = provMin 63 if provMax > max: max = provMax 64 return min,max
65
66 -def _ApplyNodeScales(node,min,max):
67 if node.GetTerminal(): 68 if max!=min: 69 loc = float(node.nExamples - min)/(max-min) 70 else: 71 loc = .5 72 node._scaleLoc = loc 73 else: 74 for child in node.GetChildren(): 75 _ApplyNodeScales(child,min,max)
76
def SetNodeScales(node):
    """Compute leaf scale factors for the tree rooted at ``node``.

    Finds the smallest and largest leaf example counts and records them as
    ``node._scales = (smallest, largest)``. It then sets ``_scaleLoc`` on
    every leaf via _ApplyNodeScales.
    """
    # start from sentinel bounds that any real count will replace
    lo, hi = _ExampleCounter(node, 1e8, -1e8)
    node._scales = lo, hi
    _ApplyNodeScales(node, lo, hi)
82 83
def DrawTreeNode(node, loc, canvas, nRes=2, scaleLeaves=False, showPurity=False):
    """Recursively displays the given tree node and all its children on the canvas.

    Arguments:
      - node: the tree node to draw; its whole subtree is drawn below it
      - loc: (x, y) canvas coordinates of the node's center
      - canvas: the sping canvas to draw on
      - nRes: number of possible label values. A leaf's fill is interpolated
        from ``visOpts.terminalOffColor`` (label 0) to
        ``visOpts.terminalOnColor`` (label nRes-1).
      - scaleLeaves: if true, leaf radii are interpolated between
        ``visOpts.minCircRad`` and ``visOpts.maxCircRad`` using the leaf's
        ``_scaleLoc`` (set by SetNodeScales). Leaves without one use 0.5.
      - showPurity: if true, each leaf gets a white line from its center. The
        line's angle reflects the fraction of the leaf's examples whose last
        element (as int) equals the leaf label (as int).

    Side effects: sets ``node._bBox = (x1, y1, x2, y2)``, the bounding box of
    the node's circle, on every node drawn. If subtree sizes are missing or
    cleared, computes them with CalcTreeNodeSizes.
    """
    if getattr(node, 'totNChildren', None) is None:
        CalcTreeNodeSizes(node)

    isTerminal = node.GetTerminal()
    if scaleLeaves and isTerminal:
        # fall back to the midpoint if SetNodeScales was not run on this tree
        scaleLoc = getattr(node, '_scaleLoc', 0.5)
        rad = visOpts.minCircRad + scaleLoc * (visOpts.maxCircRad - visOpts.minCircRad)
    else:
        rad = visOpts.circRad

    x1, y1 = loc[0] - rad, loc[1] - rad
    x2, y2 = loc[0] + rad, loc[1] + rad

    if showPurity and isTerminal:
        examples = node.GetExamples()
        if examples:
            tgtVal = int(node.GetLabel())
            nMatch = sum(1 for ex in examples if int(ex[-1]) == tgtVal)
            purity = float(nMatch) / len(examples)
        else:
            purity = 1.0
        # endpoint on the node's circle at angle purity*pi (radians), using
        # (sin, cos) offsets: purity 0 and 1 land on opposite vertical ends
        angle = purity * math.pi
        pureX = loc[0] + rad * math.sin(angle)
        pureY = loc[1] + rad * math.cos(angle)

    # children sit one level down; each is centered in a horizontal span
    # proportional to the number of leaves beneath it
    childY = loc[1] + visOpts.vertOffset
    slotWidth = visOpts.horizOffset + visOpts.circRad
    # left-hand edge of the leftmost child's span
    childX = loc[0] - (slotWidth * node.totNChildren) / 2
    for child in node.GetChildren():
        halfWidth = (slotWidth * child.totNChildren) / 2
        childX = childX + halfWidth  # center of this child's span
        childLoc = [childX, childY]
        canvas.drawLine(loc[0], loc[1], childLoc[0], childLoc[1],
                        visOpts.lineColor, visOpts.lineWidth)
        DrawTreeNode(child, childLoc, canvas, nRes=nRes, scaleLeaves=scaleLeaves,
                     showPurity=showPurity)
        childX = childX + halfWidth  # left-hand edge of the next child's span

    if isTerminal:
        cFac = float(node.GetLabel()) / float(nRes - 1)
        theColor = (1. - cFac) * visOpts.terminalOffColor + cFac * visOpts.terminalOnColor
        if hasattr(node, 'GetExamples') and node.GetExamples():
            outlColor = visOpts.outlineColor
        else:
            # leaves without examples get a distinct outline
            outlColor = visOpts.terminalEmptyColor
        canvas.drawEllipse(x1, y1, x2, y2, outlColor, visOpts.lineWidth, theColor)
        if showPurity:
            canvas.drawLine(loc[0], loc[1], pureX, pureY, piddle.Color(1, 1, 1), 2)
    else:
        canvas.drawEllipse(x1, y1, x2, y2, visOpts.outlineColor, visOpts.lineWidth,
                           visOpts.circColor)

    # this does not need to be done every time
    canvas.defaultFont = visOpts.labelFont

    labelStr = str(node.GetLabel())
    strLoc = (loc[0] - canvas.stringWidth(labelStr) / 2,
              loc[1] + canvas.fontHeight() / 4)
    canvas.drawString(labelStr, strLoc[0], strLoc[1])
    node._bBox = (x1, y1, x2, y2)
175
def CalcTreeWidth(tree):
    """Return the horizontal space, in canvas units, needed to draw ``tree``.

    Each leaf is allotted ``visOpts.circRad + visOpts.horizOffset``. Subtree
    sizes are computed with CalcTreeNodeSizes when they are missing or have
    been cleared (ResetTree sets ``totNChildren`` to None, which previously
    caused a TypeError here).
    """
    if getattr(tree, 'totNChildren', None) is None:
        CalcTreeNodeSizes(tree)
    return tree.totNChildren * (visOpts.circRad + visOpts.horizOffset)
183
def DrawTree(tree, canvas, nRes=2, scaleLeaves=False, allowShrink=True, showPurity=False):
    """Draw ``tree`` on ``canvas`` with the root at (width/2, visOpts.vertOffset).

    Arguments:
      - tree: root node of the tree
      - canvas: sping canvas; ``canvas.size[0]`` is used as the available width
      - nRes, scaleLeaves, showPurity: passed through to DrawTreeNode
        (scaleLeaves also triggers SetNodeScales first)
      - allowShrink: if true, ``visOpts.circRad`` and ``visOpts.horizOffset``
        are halved until the tree fits the canvas width. The original values
        are restored once drawing finishes, so shrinking for one large tree
        no longer shrinks every later drawing.
    """
    dims = canvas.size
    loc = (dims[0] / 2, visOpts.vertOffset)

    if getattr(tree, 'totNChildren', None) is None:
        # SetNodeScales needs the per-node example counts computed here
        CalcTreeNodeSizes(tree)
    if scaleLeaves:
        SetNodeScales(tree)

    savedCircRad, savedHorizOffset = visOpts.circRad, visOpts.horizOffset
    try:
        if allowShrink:
            treeWid = CalcTreeWidth(tree)
            while treeWid > dims[0]:
                visOpts.circRad /= 2
                visOpts.horizOffset /= 2
                treeWid = CalcTreeWidth(tree)
        DrawTreeNode(tree, loc, canvas, nRes, scaleLeaves=scaleLeaves,
                     showPurity=showPurity)
    finally:
        # don't leak the temporary shrink into the shared visOpts
        visOpts.circRad, visOpts.horizOffset = savedCircRad, savedHorizOffset
202
def ResetTree(tree):
    """Clear cached drawing data so the next draw recomputes it.

    Sets ``_scales`` and ``totNChildren`` to None on ``tree`` and on every
    node below it. Nodes are visited in pre-order, left to right.
    """
    pending = [tree]
    while pending:
        current = pending.pop()
        current._scales = None
        current.totNChildren = None
        # push children reversed so the leftmost one is visited next
        pending.extend(reversed(list(current.GetChildren())))
208
def _simpleTest(canv):
    """Draw a small hard-coded five-node tree on ``canv`` (manual smoke test)."""
    # absolute import: the implicit relative "from Tree import ..." only
    # works on Python 2
    from rdkit.ML.DecTree.Tree import TreeNode as Node
    root = Node(None, 'r', label='r')
    c1 = root.AddChild('l1_1', label='l1_1')
    root.AddChild('l1_2', isTerminal=1, label=1)
    c1.AddChild('l2_1', isTerminal=1, label=0)
    c1.AddChild('l2_2', isTerminal=1, label=1)

    DrawTreeNode(root, (150, visOpts.vertOffset), canv)
if __name__ == '__main__':
    # Manual smoke test: draw the demo tree on a 300x300 PIL canvas named
    # 'test.png' and save it.
    from rdkit.sping.PIL.pidPIL import PILCanvas
    demoCanvas = PILCanvas(size=(300, 300), name='test.png')
    _simpleTest(demoCanvas)
    demoCanvas.save()