1
2
3
4
5
6
7
8
9
10
11 """ contains the Cluster class for representing hierarchical cluster trees
12
13 """
14 from __future__ import print_function
15 import numpy
16
17 from rdkit.six.moves import reduce
18 from rdkit.six import cmp
19
20 CMPTOL=1e-6
21
23 """a class for storing clusters/data
24
25 **General Remarks**
26
27 - It is assumed that the bottom of any cluster hierarchy tree is composed of
28 the individual data points which were clustered.
29
30 - Cluster objects store the following pieces of data, most of which are
31 accessible via standard Setters/Getters:
32
33 - Children: *Not Settable*, the list of children. You can add children
34 with the _AddChild()_ and _AddChildren()_ methods.
35
36 **Note** this can be of arbitrary length,
37 but the current algorithms I have only produce trees with two children
38 per cluster
39
40 - Metric: the metric for this cluster (i.e. how far apart its children are)
41
42 - Index: the order in which this cluster was generated
43
44 - Points: *Not Settable*, the list of original points in this cluster
45 (calculated recursively from the children)
46
47 - PointsPositions: *Not Settable*, the list of positions of the original
48 points in this cluster (calculated recursively from the children)
49
50 - Position: the location of the cluster **Note** for a cluster this
51 probably means the location of the average of all the Points which are
52 its children.
53
54 - Data: a data field. This is used with the original points to store their
55 data value (i.e. the value we're using to classify)
56
57 - Name: the name of this cluster
58
59 """
def __init__(self, metric=0.0, children=None, position=None, index=-1, name=None, data=None):
    """Sets up a new cluster

    **Arguments**

      - metric: stored as the cluster's metric

      - children: list of child clusters; a fresh empty list is used when
        this is None.  **Note** a list passed in is stored as-is (not copied).

      - position: stored as the cluster's position; a fresh empty list is
        used when this is None

      - index: stored as the cluster's index

      - name: stored as the cluster's name

      - data: stored as the cluster's data field

    """
    self.metric = metric
    self.children = [] if children is None else children
    # the length bookkeeping has to run once the children are in place
    self._UpdateLength()
    self.pos = [] if position is None else position
    self.index = index
    self.name = name
    self.data = data
    # the point caches start out empty
    self._points = None
    self._pointsPositions = None
83
88
93
98
100 if self._pointsPositions is not None:
101 return self._pointsPositions
102 else:
103 self._GenPoints()
104 return self._pointsPositions
105
107 if self._points is not None:
108 return self._points
109 else:
110 self._GenPoints()
111 return self._points
112
114 """ finds and returns the subtree with a particular index
115 """
116 res = None
117 if index == self.index:
118 res = self
119 else:
120 for child in self.children:
121 res = child.FindSubtree(index)
122 if res:
123 break
124 return res
125
127 """ Generates the _Points_ and _PointsPositions_ lists
128
129 *intended for internal use*
130
131 """
132 if len(self) == 1:
133 self._points = [self]
134 self._pointsPositions = [self.GetPosition()]
135 return self._points
136 else:
137 res = []
138 children = self.GetChildren()
139
140 children.sort(key=lambda x:len(x), reverse=True)
141 for child in children:
142 res += child.GetPoints()
143 self._points=res
144 self._pointsPositions = [x.GetPosition() for x in res]
145
147 """Adds a child to our list
148
149 **Arguments**
150
151 - child: a Cluster
152
153 """
154 self.children.append(child)
155 self._GenPoints()
156 self._UpdateLength()
169 """Removes a child from our list
170
171 **Arguments**
172
173 - child: a Cluster
174
175 """
176 self.children.remove(child)
177 self._UpdateLength()
178
183
188
192 if self.name is None:
193 return 'Cluster(%d)'%(self.GetIndex())
194 else:
195 return self.name
196
def Print(self, level=0, showData=0, offset='\t'):
    """Prints this cluster and then, recursively, each of its children

    **Arguments**

      - level: depth of this cluster in the printout; the line is prefixed
        with one space per level

      - showData: if true, and this cluster's data is not None, the data
        value is printed in front of the metric

      - offset: text inserted between the cluster name and the values

    """
    indent = ' ' * level
    if showData and self.GetData() is not None:
        line = '%s%s%s Data: %f\t Metric: %f' % (indent, self.GetName(), offset,
                                                  self.GetData(), self.GetMetric())
    else:
        line = '%s%s%s Metric: %f' % (indent, self.GetName(), offset, self.GetMetric())
    print(line)

    # children are printed one level deeper with the same settings
    for node in self.GetChildren():
        node.Print(level=level + 1, showData=showData, offset=offset)
205
def Compare(self, other, ignoreExtras=1):
    """ not as choosy as self==other

    Compares the type and length of the two clusters and then, recursively,
    their children.

    **Arguments**

      - other: the cluster to compare against

      - ignoreExtras: if false, the metric (to within CMPTOL), the name and
        the position of every cluster are compared as well

    **Returns**

      0 if the two trees match, otherwise a nonzero value following the
      _cmp()_ convention (negative if self sorts first, positive otherwise)

    """
    # clusters of different types are ordered by the text of their type
    res = cmp(str(type(self)), str(type(other)))
    if res:
        return res
    res = cmp(len(self), len(other))
    if res:
        return res

    # NOTE(review): metric, name and position are all treated as "extras"
    # here; confirm this nesting against the pristine file.
    if not ignoreExtras:
        m1, m2 = self.GetMetric(), other.GetMetric()
        if abs(m1 - m2) > CMPTOL:
            return cmp(m1, m2)

        res = cmp(self.GetName(), other.GetName())
        if res:
            return res

        sP = self.GetPosition()
        oP = other.GetPosition()
        try:
            res = cmp(len(sP), len(oP))
        except TypeError:
            # positions without a length cannot be compared this way;
            # fall through to the direct comparison below
            pass
        else:
            if res:
                return res

        try:
            res = cmp(sP, oP)
        except (TypeError, ValueError):
            # Positions that cannot be ordered directly (e.g. multi-element
            # numpy arrays) are compared element by element.  The sign of
            # the first differing element decides; summing the differences
            # would call [0,2] and [1,1] equal because the terms cancel.
            delta = numpy.ravel(numpy.asarray(sP) - numpy.asarray(oP))
            differing = numpy.flatnonzero(delta)
            res = 0
            if differing.size:
                res = 1 if delta[differing[0]] > 0 else -1
        if res:
            return res

    c1, c2 = self.GetChildren(), other.GetChildren()
    res = cmp(len(c1), len(c2))
    if res:
        return res
    # the child lists have equal length here, so zip() drops nothing
    for mine, theirs in zip(c1, c2):
        res = mine.Compare(theirs, ignoreExtras=ignoreExtras)
        if res:
            return res

    return 0
253
255 """ updates our length
256
257 *intended for internal use*
258
259 """
260 self._len = reduce(lambda x,y: len(y)+x,self.children,1)
261
264
266 """ allows _len(cluster)_ to work
267
268 """
269 return self._len
270
272 """ allows _cluster1 == cluster2_ to work
273
274 """
275 if cmp(type(self),type(other)):
276 return cmp(type(self),type(other))
277
278 m1,m2=self.GetMetric(),other.GetMetric()
279 if abs(m1-m2)>CMPTOL:
280 return cmp(m1,m2)
281
282 if cmp(self.GetName(),other.GetName()):
283 return cmp(self.GetName(),other.GetName())
284
285 c1,c2=self.GetChildren(),other.GetChildren()
286 return cmp(c1,c2)
287
288
if __name__ == '__main__':
    from rdkit.ML.Cluster import ClusterUtils

    # demo tree: the root holds two branches (10 and 11) whose children
    # are the leaves listed next to them
    layout = (
        (10, (30, 31, 32)),
        (11, (40, 41)),
    )
    root = Cluster(index=1, metric=1000)
    branches = []
    for branchIdx, leafIdxs in layout:
        branch = Cluster(index=branchIdx, metric=100)
        for leafIdx in leafIdxs:
            branch.AddChild(Cluster(index=leafIdx, metric=10))
        branches.append(branch)

    # branches are attached only once they are complete
    for branch in branches:
        root.AddChild(branch)

    nodes = ClusterUtils.GetNodeList(root)
    print('XXX:', [node.GetIndex() for node in nodes])
308