1
2
3
4
5
6 """ functionality for drawing trees on sping canvases
7
8 """
9 from rdkit.sping import pid as piddle
10 import math
11
class VisOpts(object):
    """Drawing parameters used by the tree-rendering functions of this module.

    Every option is a class attribute; the drawing code reads them through
    the module-level ``visOpts`` instance created right after this class.
    """
    # -- node circles ------------------------------------------------------
    circRad = 10        # radius of every node unless leaf scaling is on
    minCircRad = 4      # smallest leaf radius when leaves are scaled
    maxCircRad = 16     # largest leaf radius when leaves are scaled
    circColor = piddle.Color(0.6, 0.6, 0.9)   # fill of non-terminal nodes

    # -- terminal (leaf) nodes -----------------------------------------------
    # outline used for leaves that have no examples
    terminalEmptyColor = piddle.Color(.8, .8, .2)
    # leaf fill is interpolated between these two according to the label
    terminalOnColor = piddle.Color(0.8, 0.8, 0.8)
    terminalOffColor = piddle.Color(0.2, 0.2, 0.2)

    # -- outlines and connecting lines ---------------------------------------
    outlineColor = piddle.transparent
    lineColor = piddle.Color(0, 0, 0)
    lineWidth = 2

    # -- layout spacing --------------------------------------------------------
    horizOffset = 10    # extra horizontal room added to circRad per leaf slot
    vertOffset = 50     # vertical distance between a node and its children

    # -- text ------------------------------------------------------------------
    labelFont = piddle.Font(face='helvetica', size=10)

    # -- highlighting (not read by DrawTreeNode) -------------------------------
    highlightColor = piddle.Color(1., 1., .4)
    highlightWidth = 2


# shared options instance consulted by the drawing functions
visOpts = VisOpts()
30
def CalcTreeNodeSizes(node):
    """Recursively annotate *node* and everything beneath it with size data.

    For this node and every descendant, three attributes are set:

      - ``totNChildren``: the number of terminal (childless) nodes in the
        subtree rooted here; a childless node counts as 1.  The layout code
        uses this as the subtree's width in leaf slots.
      - ``nLevelsBelow``: the height of the subtree counting this node, so a
        childless node gets 1.
      - ``nExamples``: ``len(node.GetExamples())``.

    Children are processed before their parent, so the parent's values can
    be derived from the children's.
    """
    children = node.GetChildren()
    if children:
        for child in children:
            CalcTreeNodeSizes(child)
        nHere = sum(child.totNChildren for child in children)
        nBelow = max(child.nLevelsBelow for child in children)
    else:
        # a childless node occupies exactly one slot and has nothing below it
        nHere = 1
        nBelow = 0

    node.nExamples = len(node.GetExamples())
    node.totNChildren = nHere
    node.nLevelsBelow = nBelow + 1
53
55 if node.GetTerminal():
56 cnt = node.nExamples
57 if cnt < min: min = cnt
58 if cnt > max: max = cnt
59 else:
60 for child in node.GetChildren():
61 provMin,provMax = _ExampleCounter(child,min,max)
62 if provMin < min: min = provMin
63 if provMax > max: max = provMax
64 return min,max
65
67 if node.GetTerminal():
68 if max!=min:
69 loc = float(node.nExamples - min)/(max-min)
70 else:
71 loc = .5
72 node._scaleLoc = loc
73 else:
74 for child in node.GetChildren():
75 _ApplyNodeScales(child,min,max)
76
82
83
def DrawTreeNode(node, loc, canvas, nRes=2, scaleLeaves=False, showPurity=False):
    """Recursively draw *node* and all of its descendants on *canvas*.

    Each node is drawn as a filled circle centred on *loc* with its label
    written over it.  Children are placed ``visOpts.vertOffset`` below their
    parent, joined to it by a line, and given horizontal space proportional
    to their ``totNChildren``.  If *node* has no ``totNChildren`` yet (missing
    or ``None``), ``CalcTreeNodeSizes`` is run on it first.

    Arguments:
      - node: the node to draw; must provide ``GetChildren()``,
        ``GetTerminal()``, ``GetLabel()`` and ``GetExamples()``.
      - loc: ``(x, y)`` canvas coordinates of the node's centre.
      - canvas: the sping canvas to draw on.
      - nRes: normalises terminal labels (presumably the number of possible
        result values); a leaf's fill colour is interpolated between
        ``visOpts.terminalOffColor`` and ``visOpts.terminalOnColor`` with the
        factor ``label / (nRes - 1)``.
      - scaleLeaves: if true, terminal nodes get a radius between
        ``visOpts.minCircRad`` and ``visOpts.maxCircRad`` according to their
        ``_scaleLoc`` attribute; 0.5 (a mid-sized circle) is used when that
        attribute has not been set.
      - showPurity: if true, each terminal node gets a white line from its
        centre whose angle (0..pi radians) is the fraction of its examples
        whose last element equals the node's label.

    Side effects: sets ``canvas.defaultFont`` and stores the circle's
    bounding box on the node as ``node._bBox = (x1, y1, x2, y2)``.
    """
    if getattr(node, 'totNChildren', None) is None:
        CalcTreeNodeSizes(node)

    isTerminal = node.GetTerminal()

    if scaleLeaves and isTerminal:
        # _scaleLoc only exists once leaf scales have been computed; fall back
        # to the midpoint instead of raising AttributeError.
        scaleLoc = getattr(node, '_scaleLoc', 0.5)
        rad = visOpts.minCircRad + scaleLoc * (visOpts.maxCircRad - visOpts.minCircRad)
    else:
        rad = visOpts.circRad

    x1, y1 = loc[0] - rad, loc[1] - rad
    x2, y2 = loc[0] + rad, loc[1] + rad

    if showPurity and isTerminal:
        examples = node.GetExamples()
        if examples:
            tgtVal = int(node.GetLabel())
            nMatch = sum(1 for ex in examples if int(ex[-1]) == tgtVal)
            purity = float(nMatch) / len(examples)
        else:
            purity = 1.0
        # purity in [0, 1] maps onto an angle in [0, pi] radians
        angle = purity * math.pi
        pureX = loc[0] + rad * math.sin(angle)
        pureY = loc[1] + rad * math.cos(angle)

    # Lay the children out left to right across this subtree's span; each
    # leaf slot is (horizOffset + circRad) wide.
    slotWidth = visOpts.horizOffset + visOpts.circRad
    childY = loc[1] + visOpts.vertOffset
    childX = loc[0] - (slotWidth * node.totNChildren) / 2
    for child in node.GetChildren():
        halfWidth = (slotWidth * child.totNChildren) / 2
        # centre the child within its share of the span
        childX += halfWidth
        childLoc = [childX, childY]
        canvas.drawLine(loc[0], loc[1], childLoc[0], childLoc[1],
                        visOpts.lineColor, visOpts.lineWidth)
        DrawTreeNode(child, childLoc, canvas, nRes=nRes,
                     scaleLeaves=scaleLeaves, showPurity=showPurity)
        # step past this child's right half to the next child's left edge
        childX += halfWidth

    # The connecting lines and subtrees are drawn first, so this node's
    # circle is painted over the line endpoints at its centre.
    if isTerminal:
        cFac = float(node.GetLabel()) / float(nRes - 1)
        theColor = (1. - cFac) * visOpts.terminalOffColor + cFac * visOpts.terminalOnColor
        if hasattr(node, 'GetExamples') and node.GetExamples():
            outlColor = visOpts.outlineColor
        else:
            # leaves without examples get a distinct outline
            outlColor = visOpts.terminalEmptyColor
        canvas.drawEllipse(x1, y1, x2, y2, outlColor, visOpts.lineWidth, theColor)
        if showPurity:
            canvas.drawLine(loc[0], loc[1], pureX, pureY, piddle.Color(1, 1, 1), 2)
    else:
        canvas.drawEllipse(x1, y1, x2, y2, visOpts.outlineColor, visOpts.lineWidth,
                           visOpts.circColor)

    canvas.defaultFont = visOpts.labelFont
    labelStr = str(node.GetLabel())
    # roughly centre the label on the node
    strLoc = (loc[0] - canvas.stringWidth(labelStr) / 2,
              loc[1] + canvas.fontHeight() / 4)
    canvas.drawString(labelStr, strLoc[0], strLoc[1])
    node._bBox = (x1, y1, x2, y2)
175
183
184 -def DrawTree(tree,canvas,nRes=2,scaleLeaves=False,allowShrink=True,showPurity=False):
202
def ResetTree(tree):
    """Clear cached layout data on *tree* and on every node beneath it.

    Sets ``_scales`` and ``totNChildren`` to ``None`` on each node, so the
    next drawing pass recomputes the subtree sizes.
    """
    pending = [tree]
    while pending:
        current = pending.pop()
        current._scales = None
        current.totNChildren = None
        pending.extend(current.GetChildren())
208
def _simpleTest(canv):
    """Build a tiny hand-made tree and draw it on *canv*.

    The tree has root ``r`` with two children:
      - an inner node ``l1_1``, which has terminal children ``l2_1``
        (label 0) and ``l2_2`` (label 1);
      - a terminal node ``l1_2`` labelled 1.
    """
    # NOTE(review): assumes a module named ``Tree`` is importable by that bare
    # name (e.g. when run from this file's directory) -- confirm.
    from Tree import TreeNode as Node
    root = Node(None, 'r', label='r')
    inner = root.AddChild('l1_1', label='l1_1')
    root.AddChild('l1_2', isTerminal=1, label=1)
    for childName, childLabel in (('l2_1', 0), ('l2_2', 1)):
        inner.AddChild(childName, isTerminal=1, label=childLabel)

    DrawTreeNode(root, (150, visOpts.vertOffset), canv)
218
219
if __name__ == '__main__':
    # Render the sample tree to test.png with the PIL-backed sping canvas.
    from rdkit.sping.PIL.pidPIL import PILCanvas
    canv = PILCanvas(name='test.png', size=(300, 300))
    _simpleTest(canv)
    canv.save()
224 canv.save()
225