Package rdkit :: Package Chem :: Package Subshape :: Module SubshapeAligner
[hide private]
[frames] | [no frames]

Source Code for Module rdkit.Chem.Subshape.SubshapeAligner

  1  # $Id$ 
  2  # 
  3  # Copyright (C) 2007-2008 by Greg Landrum  
  4  #  All rights reserved 
  5  # 
  6  from __future__ import print_function 
  7   
  8  from rdkit import RDLogger 
  9  logger = RDLogger.logger() 
 10  from rdkit import Chem,Geometry 
 11  import numpy 
 12  from rdkit.Numerics import Alignment 
 13  from rdkit.Chem.Subshape import SubshapeObjects 
 14   
class SubshapeAlignment(object):
    """Record describing one candidate alignment of a query shape onto a target.

    Every field below is a class-level default; the aligner overwrites the
    relevant ones on each instance as the alignment is generated and scored.
    """
    # transform applied to the query coordinates (used as a homogeneous
    # matrix by TransformMol; presumably 4x4 -- TODO confirm)
    transform = None
    # SSD returned by Alignment.GetAlignmentTransform for the matched triangles
    triangleSSD = None
    # index triples of the matched skeleton points in target / query
    targetTri = None
    queryTri = None
    # not assigned anywhere in this module
    alignedConfId = -1
    # accumulated |dot| of shape directions (see _checkMatchDirections)
    dirMatch = 0.0
    # fine-grid shape distance (see _checkMatchShape)
    shapeDist = 0.0
23
24 -def _getAllTriangles(pts,orderedTraversal=False):
25 for i in range(len(pts)): 26 if orderedTraversal: 27 jStart=i+1 28 else: 29 jStart=0 30 for j in range(jStart,len(pts)): 31 if j==i: 32 continue 33 if orderedTraversal: 34 kStart=j+1 35 else: 36 kStart=0 37 for k in range(j+1,len(pts)): 38 if k==i or k==j: 39 continue 40 yield (i,j,k)
41
class SubshapeDistanceMetric(object):
    """Integer codes selecting the measure used by GetShapeShapeDistance."""
    # Geometry.TanimotoDistance between the two grids
    TANIMOTO = 0
    # Geometry.ProtrudeDistance, smaller-occupancy grid passed first
    PROTRUDE = 1
 45 46 # returns the distance between two shapes according to the provided metric
def GetShapeShapeDistance(s1, s2, distMetric):
    """Return the distance between shapes ``s1`` and ``s2``.

    ``distMetric`` is a SubshapeDistanceMetric code.  For PROTRUDE, the
    shape whose grid has the smaller total occupancy is passed first to
    Geometry.ProtrudeDistance; any other value uses
    Geometry.TanimotoDistance.
    """
    if distMetric != SubshapeDistanceMetric.PROTRUDE:
        return Geometry.TanimotoDistance(s1.grid, s2.grid)
    occ1 = s1.grid.GetOccupancyVect().GetTotalVal()
    occ2 = s2.grid.GetOccupancyVect().GetTotalVal()
    # argument order matters to ProtrudeDistance: the smaller shape goes first
    smaller, larger = (s1, s2) if occ1 < occ2 else (s2, s1)
    return Geometry.ProtrudeDistance(smaller.grid, larger.grid)
 58 59 # clusters a set of alignments and returns the centroid alignment of each cluster
def ClusterAlignments(mol, alignments, builder,
                      neighborTol=0.1,
                      distMetric=SubshapeDistanceMetric.PROTRUDE,
                      tempConfId=1001):
    """Cluster alignments by shape similarity and return the cluster centroids.

    Each alignment's transform is applied to ``mol`` in a temporary conformer
    (id ``tempConfId``), and a shape is generated from it with ``builder``.
    The pairwise distances (GetShapeShapeDistance with ``distMetric``) are
    then clustered with the Butina algorithm, using ``neighborTol`` as the
    distance threshold.

    Returns a list holding the first member (the centroid) of each Butina
    cluster.  The temporary conformer is removed from ``mol`` again, even if
    shape generation raises.
    """
    from rdkit.ML.Cluster import Butina
    if not alignments:
        return []

    # Build every transformed shape exactly once.  The previous nested loop
    # regenerated shape j for every i, i.e. O(n^2) shape constructions; this
    # trades that for keeping n shapes in memory at once.
    shapes = []
    for alignment in alignments:
        TransformMol(mol, alignment.transform, newConfId=tempConfId)
        try:
            shapes.append(builder.GenerateSubshapeShape(mol, tempConfId, addSkeleton=False))
        finally:
            mol.RemoveConformer(tempConfId)

    # Butina's isDistData input is the lower triangle, row by row:
    # d(1,0), d(2,0), d(2,1), d(3,0), ...
    dists = []
    for i in range(1, len(shapes)):
        for j in range(i):
            dists.append(GetShapeShapeDistance(shapes[i], shapes[j], distMetric))

    clusts = Butina.ClusterData(dists, len(alignments), neighborTol, isDistData=True)
    return [alignments[clust[0]] for clust in clusts]
79
def TransformMol(mol, tform, confId=-1, newConfId=100):
    """Apply a transformation to a conformer and store the result as a new one.

    Each atom position of conformer ``confId`` is extended to (x, y, z, 1),
    multiplied by ``tform`` (numpy.dot; presumably a 4x4 homogeneous matrix --
    TODO confirm), and the first three components are written to a new
    conformer with id ``newConfId``.  Any existing conformer with that id is
    removed first; all other conformers of ``mol`` are left untouched.
    """
    refConf = mol.GetConformer(confId)
    nAtoms = refConf.GetNumAtoms()
    tform = numpy.asarray(tform)
    # preallocate all positions instead of growing the conformer atom by atom
    newConf = Chem.Conformer(nAtoms)
    for i in range(nAtoms):
        pos = list(refConf.GetAtomPosition(i))
        pos.append(1.0)
        newPos = numpy.dot(tform, numpy.array(pos))
        newConf.SetAtomPosition(i, list(newPos)[:3])
    newConf.SetId(newConfId)
    mol.RemoveConformer(newConfId)
    mol.AddConformer(newConf, assignId=False)
96
class SubshapeAligner(object):
    """Find alignments of a query subshape onto a target subshape.

    Candidate alignments come from superimposing triangles of skeleton points
    (GetTriangleMatches).  They are then pruned by skeleton-point features
    (molFeatures), by agreement of shape directions, and finally by shape
    distance on coarse, medium and fine grids.

    The class attributes are tunable settings; override them on an instance
    to change the behavior.
    """
    # triangle superposition tolerance; the SSD cutoff is triangleRMSTol**2 * 9
    triangleRMSTol = 1.0
    # SubshapeDistanceMetric code used by the shape-based pruning
    distMetric = SubshapeDistanceMetric.PROTRUDE
    # maximum shape distance (scaled by the *GridToleranceMult values below)
    shapeDistTol = 0.2
    # number of triangle vertices whose features must agree
    numFeatThresh = 3
    # minimum accumulated |dot| of the three vertices' shape directions
    dirThresh = 2.6
    # edges are compatible if their squared lengths differ by < edgeTol**2
    edgeTol = 6.0
    # multipliers applied to shapeDistTol on the coarse / medium grids
    coarseGridToleranceMult = 1.0
    medGridToleranceMult = 1.0
    # grid spacings (relative to builder.gridSpacing) of the medium and coarse
    # grids.  The target grids (_addCoarseAndMediumGrids) and the query grids
    # (_checkMatchShape) must be built with the same values, so both read these.
    medGridSpacingMult = 1.5
    coarseGridSpacingMult = 2.0

    @staticmethod
    def _countPrune(pruneStats, key):
        # bump a pruning counter, but only if the caller asked for statistics
        if pruneStats is not None:
            pruneStats[key] = pruneStats.get(key, 0) + 1

    @staticmethod
    def _squaredEdgeLengths(pts):
        """Map each index pair (i, j), i < j, to the squared distance between
        pts[i].location and pts[j].location."""
        lengths = {}
        for i in range(len(pts)):
            for j in range(i + 1, len(pts)):
                lengths[(i, j)] = (pts[i].location - pts[j].location).LengthSq()
        return lengths

    def GetTriangleMatches(self, target, query):
        """ this is a generator function returning the possible triangle
        matches between the two shapes

        Every target skeleton-point triangle (visited once, sorted indices) is
        compared with every ordering of every query skeleton-point triangle.
        A pairing is kept when all three corresponding edges have compatible
        lengths and the superposition from Alignment.GetAlignmentTransform has
        SSD <= triangleRMSTol**2 * 9.  Yields SubshapeAlignment objects with
        transform, triangleSSD, targetTri, queryTri and _seqNo set.
        """
        ssdTol = (self.triangleRMSTol ** 2) * 9
        tgtPts = target.skelPts
        queryPts = query.skelPts
        tgtLs = self._squaredEdgeLengths(tgtPts)
        queryLs = self._squaredEdgeLengths(queryPts)

        # (targetEdge, queryEdge) pairs with compatible lengths.  Target
        # triangles are visited with sorted indices, so target edge keys are
        # always (lo, hi).  Query triangles are visited in every permutation,
        # so both orientations of each query edge must be recorded.  Storing
        # only (lo, hi) made every permuted query triangle fail the lookup,
        # leaving only the identically-ordered correspondence.
        tol2 = self.edgeTol * self.edgeTol
        compatEdges = set()
        for tk, tv in tgtLs.items():
            for (qa, qb), qv in queryLs.items():
                if abs(tv - qv) < tol2:
                    compatEdges.add((tk, (qa, qb)))
                    compatEdges.add((tk, (qb, qa)))

        seqNo = 0
        for tgtTri in _getAllTriangles(tgtPts, orderedTraversal=True):
            t0, t1, t2 = tgtTri
            tgtLocs = [tgtPts[x].location for x in tgtTri]
            for queryTri in _getAllTriangles(queryPts, orderedTraversal=False):
                q0, q1, q2 = queryTri
                if (((t0, t1), (q0, q1)) in compatEdges and
                        ((t0, t2), (q0, q2)) in compatEdges and
                        ((t1, t2), (q1, q2)) in compatEdges):
                    queryLocs = [queryPts[x].location for x in queryTri]
                    ssd, tf = Alignment.GetAlignmentTransform(tgtLocs, queryLocs)
                    if ssd <= ssdTol:
                        alg = SubshapeAlignment()
                        alg.transform = tf
                        alg.triangleSSD = ssd
                        alg.targetTri = tgtTri
                        alg.queryTri = queryTri
                        alg._seqNo = seqNo
                        seqNo += 1
                        yield alg

    def _checkMatchFeatures(self, targetPts, queryPts, alignment):
        """Return True if at least numFeatThresh matched vertices agree on features.

        A vertex pair agrees when neither point has molFeatures, or when the
        two points share at least one feature.
        """
        nMatched = 0
        for tIdx, qIdx in zip(alignment.targetTri, alignment.queryTri):
            tgtFeats = targetPts[tIdx].molFeatures
            qFeats = queryPts[qIdx].molFeatures
            if (not tgtFeats and not qFeats) or any(f in qFeats for f in tgtFeats):
                nMatched += 1
                if nMatched >= self.numFeatThresh:
                    break
        return nMatched >= self.numFeatThresh

    def _pruneInPlace(self, alignments, keep, statKey, pruneStats):
        """Filter ``alignments`` in place, keeping items for which keep(alg) is
        true and counting removals under pruneStats[statKey]."""
        kept = []
        for alg in alignments:
            if keep(alg):
                kept.append(alg)
            else:
                self._countPrune(pruneStats, statKey)
        # slice assignment mutates the caller's list (avoids O(n^2) del-in-loop)
        alignments[:] = kept

    def PruneMatchesUsingFeatures(self, target, query, alignments, pruneStats=None):
        """Remove, in place, alignments that fail _checkMatchFeatures.

        If pruneStats is a dict, pruneStats['features'] counts removals.
        """
        targetPts = target.skelPts
        queryPts = query.skelPts
        self._pruneInPlace(alignments,
                           lambda alg: self._checkMatchFeatures(targetPts, queryPts, alg),
                           'features', pruneStats)

    def _checkMatchDirections(self, targetPts, queryPts, alignment):
        """Score agreement of the first shape direction at each matched vertex.

        The query direction is rotated by the upper-left 3x3 block of
        alignment.transform, and |dot| with the target direction is summed.
        The loop stops as soon as the sum reaches dirThresh, so the value
        stored in alignment.dirMatch may be a partial sum.  Returns True if
        that sum is >= dirThresh.
        """
        tform = alignment.transform
        dot = 0.0
        for tIdx, qIdx in zip(alignment.targetTri, alignment.queryTri):
            qv = queryPts[qIdx].shapeDirs[0]
            tv = targetPts[tIdx].shapeDirs[0]
            rotV = [tform[r, 0] * qv[0] + tform[r, 1] * qv[1] + tform[r, 2] * qv[2]
                    for r in range(3)]
            dot += abs(rotV[0] * tv[0] + rotV[1] * tv[1] + rotV[2] * tv[2])
            if dot >= self.dirThresh:
                # already above the threshold, no need to continue
                break
        alignment.dirMatch = dot
        return dot >= self.dirThresh

    def PruneMatchesUsingDirection(self, target, query, alignments, pruneStats=None):
        """Remove, in place, alignments that fail _checkMatchDirections.

        If pruneStats is a dict, pruneStats['direction'] counts removals.
        """
        tgtPts = target.skelPts
        queryPts = query.skelPts
        self._pruneInPlace(alignments,
                           lambda alg: self._checkMatchDirections(tgtPts, queryPts, alg),
                           'direction', pruneStats)

    def _addCoarseAndMediumGrids(self, mol, tgt, confId, builder):
        """Attach ``medGrid`` and ``coarseGrid`` shapes to ``tgt``.

        With ``mol``, the shapes are regenerated from conformer ``confId`` at
        medGridSpacingMult / coarseGridSpacingMult times builder.gridSpacing.
        Without it, they are resampled from ``tgt`` via builder.SampleSubshape.
        builder.gridSpacing is restored even if shape generation raises.
        """
        oSpace = builder.gridSpacing
        medSpace = oSpace * self.medGridSpacingMult
        coarseSpace = oSpace * self.coarseGridSpacingMult
        if not mol:
            tgt.medGrid = builder.SampleSubshape(tgt, medSpace)
            tgt.coarseGrid = builder.SampleSubshape(tgt, coarseSpace)
            return
        try:
            builder.gridSpacing = medSpace
            tgt.medGrid = builder.GenerateSubshapeShape(mol, confId, addSkeleton=False)
            builder.gridSpacing = coarseSpace
            tgt.coarseGrid = builder.GenerateSubshapeShape(mol, confId, addSkeleton=False)
        finally:
            builder.gridSpacing = oSpace

    def _checkMatchShape(self, targetMol, target, queryMol, query, alignment, builder,
                         targetConf, queryConf, pruneStats=None, tConfId=1001):
        """Check the aligned query's shape against the target on three grids.

        The query is transformed into temporary conformer ``tConfId``.  It is
        rejected as soon as the coarse, medium or fine grid distance exceeds
        its tolerance; the rejecting stage is counted in pruneStats under
        'coarseGrid', 'medGrid' or 'fineGrid'.  When the fine grid is reached,
        its distance is stored in alignment.shapeDist.  The temporary
        conformer and builder.gridSpacing are restored even on error.
        ``targetMol``, ``query`` and ``targetConf`` are unused here (kept for
        interface compatibility).
        """
        TransformMol(queryMol, alignment.transform, confId=queryConf, newConfId=tConfId)
        oSpace = builder.gridSpacing
        try:
            builder.gridSpacing = oSpace * self.coarseGridSpacingMult
            coarseGrid = builder.GenerateSubshapeShape(queryMol, tConfId, addSkeleton=False)
            d = GetShapeShapeDistance(coarseGrid, target.coarseGrid, self.distMetric)
            if d > self.shapeDistTol * self.coarseGridToleranceMult:
                self._countPrune(pruneStats, 'coarseGrid')
                return False

            builder.gridSpacing = oSpace * self.medGridSpacingMult
            medGrid = builder.GenerateSubshapeShape(queryMol, tConfId, addSkeleton=False)
            d = GetShapeShapeDistance(medGrid, target.medGrid, self.distMetric)
            if d > self.shapeDistTol * self.medGridToleranceMult:
                self._countPrune(pruneStats, 'medGrid')
                return False

            builder.gridSpacing = oSpace
            fineGrid = builder.GenerateSubshapeShape(queryMol, tConfId, addSkeleton=False)
            d = GetShapeShapeDistance(fineGrid, target, self.distMetric)
            alignment.shapeDist = d
            if d > self.shapeDistTol:
                self._countPrune(pruneStats, 'fineGrid')
                return False
            return True
        finally:
            queryMol.RemoveConformer(tConfId)
            builder.gridSpacing = oSpace

    def PruneMatchesUsingShape(self, targetMol, target, queryMol, query, builder,
                               alignments, tgtConf=-1, queryConf=-1,
                               pruneStats=None):
        """Remove, in place, alignments that fail _checkMatchShape.

        The target's medium/coarse grids are built on first use.  Progress is
        logged every 100 alignments.
        """
        if not hasattr(target, 'medGrid'):
            self._addCoarseAndMediumGrids(targetMol, target, tgtConf, builder)

        logger.info("Shape-based Pruning")
        nOrig = len(alignments)
        kept = []
        for nDone, alg in enumerate(alignments, 1):
            if not nDone % 100:
                # still in play: survivors so far plus those not yet examined
                nLeft = len(kept) + (nOrig - nDone + 1)
                logger.info(' processed %d of %d. %d alignments remain' % ((nDone,
                                                                           nOrig,
                                                                           nLeft)))
            if self._checkMatchShape(targetMol, target, queryMol, query, alg, builder,
                                     targetConf=tgtConf, queryConf=queryConf,
                                     pruneStats=pruneStats):
                kept.append(alg)
        alignments[:] = kept

    def GetSubshapeAlignments(self, targetMol, target, queryMol, query, builder,
                              tgtConf=-1, queryConf=-1, pruneStats=None):
        """Return the list of alignments surviving all pruning stages.

        Timings are recorded in pruneStats under 'gtm_time', 'feats_time'
        (only when builder.featFactory is set), 'direction_time' and
        'shape_time'.
        """
        import time
        if pruneStats is None:
            pruneStats = {}
        logger.info("Generating triangle matches")
        t1 = time.time()
        res = list(self.GetTriangleMatches(target, query))
        t2 = time.time()
        logger.info("Got %d possible alignments in %.1f seconds" % (len(res), t2 - t1))
        pruneStats['gtm_time'] = t2 - t1
        if builder.featFactory:
            logger.info("Doing feature pruning")
            t1 = time.time()
            self.PruneMatchesUsingFeatures(target, query, res, pruneStats=pruneStats)
            t2 = time.time()
            pruneStats['feats_time'] = t2 - t1
            logger.info("%d possible alignments remain. (%.1f seconds required)" % (len(res), t2 - t1))
        logger.info("Doing direction pruning")
        t1 = time.time()
        self.PruneMatchesUsingDirection(target, query, res, pruneStats=pruneStats)
        t2 = time.time()
        pruneStats['direction_time'] = t2 - t1
        logger.info("%d possible alignments remain. (%.1f seconds required)" % (len(res), t2 - t1))
        t1 = time.time()
        self.PruneMatchesUsingShape(targetMol, target, queryMol, query, builder, res,
                                    tgtConf=tgtConf, queryConf=queryConf,
                                    pruneStats=pruneStats)
        t2 = time.time()
        pruneStats['shape_time'] = t2 - t1
        return res

    def __call__(self, targetMol, target, queryMol, query, builder,
                 tgtConf=-1, queryConf=-1, pruneStats=None):
        """Lazily yield alignments that pass every pruning stage, one at a time."""
        for alignment in self.GetTriangleMatches(target, query):
            if builder.featFactory and \
                    not self._checkMatchFeatures(target.skelPts, query.skelPts, alignment):
                self._countPrune(pruneStats, 'features')
                continue
            if not self._checkMatchDirections(target.skelPts, query.skelPts, alignment):
                self._countPrune(pruneStats, 'direction')
                continue

            if not hasattr(target, 'medGrid'):
                self._addCoarseAndMediumGrids(targetMol, target, tgtConf, builder)

            if not self._checkMatchShape(targetMol, target, queryMol, query, alignment, builder,
                                         targetConf=tgtConf, queryConf=queryConf,
                                         pruneStats=pruneStats):
                continue
            # if we made it this far, it's a good alignment
            yield alignment
if __name__ == '__main__':
    from rdkit.six.moves import cPickle

    # NOTE: unpickling can execute arbitrary code; only load trusted local files.
    # open() replaces the Python-2-only file() builtin, and `with` closes the handles.
    with open('target.pkl', 'rb') as inF:
        tgtMol, tgtShape = cPickle.load(inF)
    with open('query.pkl', 'rb') as inF:
        queryMol, queryShape = cPickle.load(inF)
    with open('builder.pkl', 'rb') as inF:
        builder = cPickle.load(inF)

    aligner = SubshapeAligner()
    algs = aligner.GetSubshapeAlignments(tgtMol, tgtShape, queryMol, queryShape, builder)
    print(len(algs))

    from rdkit.Chem.PyMol import MolViewer
    v = MolViewer()
    v.ShowMol(tgtMol, name='Target', showOnly=True)
    v.ShowMol(queryMol, name='Query', showOnly=False)
    SubshapeObjects.DisplaySubshape(v, tgtShape, 'target_shape', color=(.8, .2, .2))
    SubshapeObjects.DisplaySubshape(v, queryShape, 'query_shape', color=(.2, .2, .8))