1
2
3
4
5
6 from __future__ import print_function
7
8 from rdkit import RDLogger
9 logger = RDLogger.logger()
10 from rdkit import Chem,Geometry
11 import numpy
12 from rdkit.Numerics import Alignment
13 from rdkit.Chem.Subshape import SubshapeObjects
14
23
25 for i in range(len(pts)):
26 if orderedTraversal:
27 jStart=i+1
28 else:
29 jStart=0
30 for j in range(jStart,len(pts)):
31 if j==i:
32 continue
33 if orderedTraversal:
34 kStart=j+1
35 else:
36 kStart=0
37 for k in range(j+1,len(pts)):
38 if k==i or k==j:
39 continue
40 yield (i,j,k)
41
45
46
58
59
64 from rdkit.ML.Cluster import Butina
65 dists = []
66 for i in range(len(alignments)):
67 TransformMol(mol,alignments[i].transform,newConfId=tempConfId)
68 shapeI=builder.GenerateSubshapeShape(mol,tempConfId,addSkeleton=False)
69 for j in range(i):
70 TransformMol(mol,alignments[j].transform,newConfId=tempConfId+1)
71 shapeJ=builder.GenerateSubshapeShape(mol,tempConfId+1,addSkeleton=False)
72 d = GetShapeShapeDistance(shapeI,shapeJ,distMetric)
73 dists.append(d)
74 mol.RemoveConformer(tempConfId+1)
75 mol.RemoveConformer(tempConfId)
76 clusts=Butina.ClusterData(dists,len(alignments),neighborTol,isDistData=True)
77 res = [alignments[x[0]] for x in clusts]
78 return res
79
96
98 triangleRMSTol=1.0
99 distMetric=SubshapeDistanceMetric.PROTRUDE
100 shapeDistTol=0.2
101 numFeatThresh=3
102 dirThresh=2.6
103 edgeTol=6.0
104
105
106 coarseGridToleranceMult=1.0
107 medGridToleranceMult=1.0
108
110 """ this is a generator function returning the possible triangle
111 matches between the two shapes
112 """
113 ssdTol = (self.triangleRMSTol**2)*9
114 res = []
115 tgtPts = target.skelPts
116 queryPts = query.skelPts
117 tgtLs = {}
118 for i in range(len(tgtPts)):
119 for j in range(i+1,len(tgtPts)):
120 l2 = (tgtPts[i].location-tgtPts[j].location).LengthSq()
121 tgtLs[(i,j)]=l2
122 queryLs = {}
123 for i in range(len(queryPts)):
124 for j in range(i+1,len(queryPts)):
125 l2 = (queryPts[i].location-queryPts[j].location).LengthSq()
126 queryLs[(i,j)]=l2
127 compatEdges={}
128 tol2 = self.edgeTol*self.edgeTol
129 for tk,tv in tgtLs.items():
130 for qk,qv in queryLs.items():
131 if abs(tv-qv)<tol2:
132 compatEdges[(tk,qk)]=1
133 seqNo=0
134 for tgtTri in _getAllTriangles(tgtPts,orderedTraversal=True):
135 tgtLocs=[tgtPts[x].location for x in tgtTri]
136 for queryTri in _getAllTriangles(queryPts,orderedTraversal=False):
137 if ((tgtTri[0],tgtTri[1]),(queryTri[0],queryTri[1])) in compatEdges and \
138 ((tgtTri[0],tgtTri[2]),(queryTri[0],queryTri[2])) in compatEdges and \
139 ((tgtTri[1],tgtTri[2]),(queryTri[1],queryTri[2])) in compatEdges:
140 queryLocs=[queryPts[x].location for x in queryTri]
141 ssd,tf = Alignment.GetAlignmentTransform(tgtLocs,queryLocs)
142 if ssd<=ssdTol:
143 alg = SubshapeAlignment()
144 alg.transform=tf
145 alg.triangleSSD=ssd
146 alg.targetTri=tgtTri
147 alg.queryTri=queryTri
148 alg._seqNo=seqNo
149 seqNo+=1
150 yield alg
151
153 nMatched=0
154 for i in range(3):
155 tgtFeats = targetPts[alignment.targetTri[i]].molFeatures
156 qFeats = queryPts[alignment.queryTri[i]].molFeatures
157 if not tgtFeats and not qFeats:
158 nMatched+=1
159 else:
160 for j,jFeat in enumerate(tgtFeats):
161 if jFeat in qFeats:
162 nMatched+=1
163 break
164 if nMatched>=self.numFeatThresh:
165 break
166 return nMatched>=self.numFeatThresh
167
169 i = 0
170 targetPts = target.skelPts
171 queryPts = query.skelPts
172 while i<len(alignments):
173 alg = alignments[i]
174 if not self._checkMatchFeatures(targetPts,queryPts,alg):
175 if pruneStats is not None:
176 pruneStats['features']=pruneStats.get('features',0)+1
177 del alignments[i]
178 else:
179 i+=1
180
198
200 i = 0
201 tgtPts = target.skelPts
202 queryPts = query.skelPts
203 while i<len(alignments):
204 if not self._checkMatchDirections(tgtPts,queryPts,alignments[i]):
205 if pruneStats is not None:
206 pruneStats['direction']=pruneStats.get('direction',0)+1
207 del alignments[i]
208 else:
209 i+=1
210
222
def _checkMatchShape(self, targetMol, target, queryMol, query, alignment, builder,
                     targetConf, queryConf, pruneStats=None, tConfId=1001):
    """ checks whether the aligned query shape is close enough to the target shape

    The query conformer queryConf is transformed with alignment.transform
    into a temporary conformer (id tConfId).  A shape is then generated for
    that conformer at up to three grid spacings and compared, using
    GetShapeShapeDistance with self.distMetric, against the target:

      1) 2 x builder.gridSpacing    vs target.coarseGrid
      2) 1.5 x builder.gridSpacing  vs target.medGrid
      3) builder.gridSpacing        vs target itself

    The comparison stops at the first level whose distance exceeds its
    tolerance (self.shapeDistTol, scaled by the level's multiplier for the
    coarse and medium levels).

    Arguments:
      - targetMol, query, targetConf: not used by this method; kept so the
        signature stays compatible with existing callers
      - target: the target shape; its coarseGrid and medGrid attributes are
        read for the first two levels
      - queryMol: the query molecule; the temporary conformer is added to it
        and removed again before returning
      - alignment: the candidate alignment; its shapeDist attribute is set
        to the last distance that was computed
      - builder: the shape builder; its gridSpacing is changed during the
        call and restored before returning
      - pruneStats: optional dict; the 'coarseGrid', 'medGrid' or 'fineGrid'
        counter is incremented when that level rejects the alignment
      - tConfId: the id used for the temporary conformer

    Returns True if the alignment passes every level, False otherwise.

    """
    TransformMol(queryMol, alignment.transform, confId=queryConf, newConfId=tConfId)
    oSpace = builder.gridSpacing
    # (pruneStats key, grid spacing, target attribute, tolerance multiplier);
    # None marks the final level, which compares against the target itself
    # with the unscaled tolerance.
    levels = (
        ('coarseGrid', oSpace * 2, 'coarseGrid', self.coarseGridToleranceMult),
        ('medGrid', oSpace * 1.5, 'medGrid', self.medGridToleranceMult),
        ('fineGrid', oSpace, None, None),
    )
    matchOk = True
    try:
        for statKey, spacing, refAttr, tolMult in levels:
            builder.gridSpacing = spacing
            queryShape = builder.GenerateSubshapeShape(queryMol, tConfId, addSkeleton=False)
            refShape = target if refAttr is None else getattr(target, refAttr)
            d = GetShapeShapeDistance(queryShape, refShape, self.distMetric)
            tol = self.shapeDistTol if tolMult is None else self.shapeDistTol * tolMult
            if d > tol:
                matchOk = False
                if pruneStats is not None:
                    pruneStats[statKey] = pruneStats.get(statKey, 0) + 1
                # cheaper levels come first; skip the finer ones on failure
                break
        alignment.shapeDist = d
    finally:
        # always undo the temporary state, even if shape generation or the
        # distance calculation raised, so that the molecule and the builder
        # are not left modified for the caller
        queryMol.RemoveConformer(tConfId)
        builder.gridSpacing = oSpace
    return matchOk
256
def PruneMatchesUsingShape(self, targetMol, target, queryMol, query, builder,
                           alignments, tgtConf=-1, queryConf=-1,
                           pruneStats=None):
    """ removes, in place, the alignments that fail the shape check

    Each entry of alignments is passed to self._checkMatchShape(); entries
    for which that returns a false value are deleted from the list.  If the
    target has no medGrid attribute, self._addCoarseAndMediumGrids() is
    called on it first.

    A progress message is logged after every 100 alignments examined.
    pruneStats, when provided, is handed through to _checkMatchShape().

    Nothing is returned; the alignments list itself is modified.

    """
    if not hasattr(target, 'medGrid'):
        self._addCoarseAndMediumGrids(targetMol, target, tgtConf, builder)

    logger.info("Shape-based Pruning")
    nStart = len(alignments)
    nSeen = 0
    pos = 0
    while pos < len(alignments):
        candidate = alignments[pos]
        nSeen += 1
        if nSeen % 100 == 0:
            logger.info(' processed %d of %d. %d alignments remain' %
                        (nSeen, nStart, len(alignments)))
        keep = self._checkMatchShape(targetMol, target, queryMol, query, candidate, builder,
                                     targetConf=tgtConf, queryConf=queryConf,
                                     pruneStats=pruneStats)
        if keep:
            pos += 1
        else:
            # deleting shifts the next candidate into position pos, so the
            # index is deliberately left unchanged
            del alignments[pos]
282
def GetSubshapeAlignments(self, targetMol, target, queryMol, query, builder,
                          tgtConf=-1, queryConf=-1, pruneStats=None):
    """ returns the list of alignments that survive all pruning stages

    All triangle matches from self.GetTriangleMatches() are collected and
    then pruned in place by, in order:
      1) PruneMatchesUsingFeatures()  (only if builder.featFactory is set)
      2) PruneMatchesUsingDirection()
      3) PruneMatchesUsingShape()

    pruneStats, if provided, is passed to each stage and additionally
    receives the wall-clock duration (from time.time()) of each stage under
    the keys 'gtm_time', 'feats_time', 'direction_time' and 'shape_time'.
    When pruneStats is None a throwaway dict is used instead.

    """
    import time
    if pruneStats is None:
        pruneStats = {}

    logger.info("Generating triangle matches")
    started = time.time()
    candidates = list(self.GetTriangleMatches(target, query))
    elapsed = time.time() - started
    logger.info("Got %d possible alignments in %.1f seconds" % (len(candidates), elapsed))
    pruneStats['gtm_time'] = elapsed

    if builder.featFactory:
        logger.info("Doing feature pruning")
        started = time.time()
        self.PruneMatchesUsingFeatures(target, query, candidates, pruneStats=pruneStats)
        elapsed = time.time() - started
        pruneStats['feats_time'] = elapsed
        logger.info("%d possible alignments remain. (%.1f seconds required)" %
                    (len(candidates), elapsed))

    logger.info("Doing direction pruning")
    started = time.time()
    self.PruneMatchesUsingDirection(target, query, candidates, pruneStats=pruneStats)
    elapsed = time.time() - started
    pruneStats['direction_time'] = elapsed
    logger.info("%d possible alignments remain. (%.1f seconds required)" %
                (len(candidates), elapsed))

    started = time.time()
    self.PruneMatchesUsingShape(targetMol, target, queryMol, query, builder, candidates,
                                tgtConf=tgtConf, queryConf=queryConf,
                                pruneStats=pruneStats)
    pruneStats['shape_time'] = time.time() - started
    return candidates
314
def __call__(self, targetMol, target, queryMol, query, builder,
             tgtConf=-1, queryConf=-1, pruneStats=None):
    """ generator yielding, one at a time, the alignments that pass every check

    Each triangle match from self.GetTriangleMatches() is tested with
    _checkMatchFeatures() (only when builder.featFactory is set),
    _checkMatchDirections() and _checkMatchShape(), in that order; a match
    is dropped at the first check it fails.

    If pruneStats is provided, its 'features' or 'direction' counter is
    incremented here for matches rejected by those checks; it is also
    passed on to _checkMatchShape().

    """

    def _countPruned(reason):
        # bump the rejection counter for this stage, if stats were requested
        if pruneStats is not None:
            pruneStats[reason] = pruneStats.get(reason, 0) + 1

    for candidate in self.GetTriangleMatches(target, query):
        if builder.featFactory:
            featsOk = self._checkMatchFeatures(target.skelPts, query.skelPts, candidate)
            if not featsOk:
                _countPruned('features')
                continue

        dirsOk = self._checkMatchDirections(target.skelPts, query.skelPts, candidate)
        if not dirsOk:
            _countPruned('direction')
            continue

        # only reached once a candidate gets as far as the shape check
        if not hasattr(target, 'medGrid'):
            self._addCoarseAndMediumGrids(targetMol, target, tgtConf, builder)

        if self._checkMatchShape(targetMol, target, queryMol, query, candidate, builder,
                                 targetConf=tgtConf, queryConf=queryConf,
                                 pruneStats=pruneStats):
            yield candidate
337
338
if __name__ == '__main__':
    # Demo driver: loads a pickled target, query and shape builder from the
    # current directory, prints the number of subshape alignments found and
    # displays both molecules and their shapes in PyMol.
    from rdkit.six.moves import cPickle

    def _loadPickle(fileName):
        """ returns the object pickled in fileName, closing the file afterwards """
        # NOTE: unpickling can execute arbitrary code; only use trusted files.
        # open() replaces the builtin file(), which does not exist in Python 3.
        with open(fileName, 'rb') as inF:
            return cPickle.load(inF)

    tgtMol, tgtShape = _loadPickle('target.pkl')
    queryMol, queryShape = _loadPickle('query.pkl')
    builder = _loadPickle('builder.pkl')
    aligner = SubshapeAligner()
    algs = aligner.GetSubshapeAlignments(tgtMol, tgtShape, queryMol, queryShape, builder)
    print(len(algs))

    from rdkit.Chem.PyMol import MolViewer
    v = MolViewer()
    v.ShowMol(tgtMol, name='Target', showOnly=True)
    v.ShowMol(queryMol, name='Query', showOnly=False)
    SubshapeObjects.DisplaySubshape(v, tgtShape, 'target_shape', color=(.8, .2, .2))
    SubshapeObjects.DisplaySubshape(v, queryShape, 'query_shape', color=(.2, .2, .8))
354