Package rdkit :: Package ML :: Package NaiveBayes :: Module CrossValidate
[hide private]
[frames] | [no frames]

Source Code for Module rdkit.ML.NaiveBayes.CrossValidate

 1  # $Id$ 
 2  # 
 3  #  Copyright (C) 2004-2005 Rational Discovery LLC. 
 4  #   All Rights Reserved 
 5  # 
 6  """ handles doing cross validation with naive bayes models 
 7  and evaluation of individual models 
 8   
 9  """ 
10  from __future__ import print_function 
11  from rdkit.ML.NaiveBayes.ClassificationModel import NaiveBayesClassifier 
12  from rdkit.ML.Data import SplitData 
13  try: 
14    from rdkit.ML.FeatureSelect import CMIM 
15  except ImportError: 
16    CMIM=None 
17   
def makeNBClassificationModel(trainExamples, attrs, nPossibleValues, nQuantBounds,
                              mEstimateVal=-1.0,
                              useSigs=False,
                              ensemble=None, useCMIM=0,
                              **kwargs):
  """ constructs and trains a NaiveBayesClassifier

  **Arguments**

    - trainExamples: the training examples; handed to the model's
      SetTrainingExamples() and, when CMIM selection runs, to
      CMIM.SelectFeatures()

    - attrs: the attributes the model uses, unless an ensemble is supplied
      or selected (see below)

    - nPossibleValues, nQuantBounds: passed straight through to the
      NaiveBayesClassifier constructor

    - mEstimateVal: passed to the NaiveBayesClassifier constructor

    - useSigs: passed to the NaiveBayesClassifier constructor; CMIM feature
      selection is only attempted when this is set

    - ensemble: if truthy, used in place of attrs

    - useCMIM: when > 0, the CMIM module imported successfully, useSigs is
      set and no ensemble was given, this value is passed to
      CMIM.SelectFeatures() and the result becomes the ensemble
      (presumably the number of features to select -- confirm in CMIM)

    - kwargs: accepted and ignored, so callers such as CrossValidationDriver
      can forward extra options

  **Returns**

    the trained model

  """
  # CMIM-based feature selection needs the optional module, a positive
  # request, and signature-style examples.
  cmimRequested = CMIM is not None and useCMIM > 0 and useSigs
  if cmimRequested and not ensemble:
    ensemble = CMIM.SelectFeatures(trainExamples, useCMIM, bvCol=1)

  featureSet = ensemble if ensemble else attrs
  nbModel = NaiveBayesClassifier(featureSet, nPossibleValues, nQuantBounds,
                                 mEstimateVal=mEstimateVal, useSigs=useSigs)
  nbModel.SetTrainingExamples(trainExamples)
  nbModel.trainModel()
  return nbModel
34
def CrossValidate(NBmodel, testExamples, appendExamples=0):
  """ measures how often a model misclassifies a set of examples

  **Arguments**

    - NBmodel: the model to evaluate; only its ClassifyExamples() method is
      used

    - testExamples: a non-empty sequence of examples; the last element of
      each example is taken to be its true result

    - appendExamples: passed (positionally) as the second argument of
      NBmodel.ClassifyExamples()

  **Returns**

    a 2-tuple:

      1) the fraction of examples whose prediction differs from the true
         result

      2) a list of the misclassified examples, in input order

  **Raises**

    AssertionError if testExamples is empty, or if the model does not return
    exactly one prediction per example

  """
  nTest = len(testExamples)
  # Explicit checks instead of ``assert`` statements: asserts are stripped
  # under ``python -O``, which would turn an empty input into a
  # ZeroDivisionError and a short/long prediction list into an IndexError or
  # a silently wrong error rate.  AssertionError is kept so existing callers
  # see the same exception type.
  if not nTest:
    raise AssertionError('no test examples: %s' % str(testExamples))
  preds = NBmodel.ClassifyExamples(testExamples, appendExamples)
  if len(preds) != nTest:
    raise AssertionError('expected %d predictions, got %d' % (nTest, len(preds)))

  badExamples = [testEg for testEg, res in zip(testExamples, preds)
                 if testEg[-1] != res]
  return float(len(badExamples)) / nTest, badExamples
53
def CrossValidationDriver(examples, attrs, nPossibleValues, nQuantBounds,
                          mEstimateVal=0.0,
                          holdOutFrac=0.3, modelBuilder=makeNBClassificationModel,
                          silent=0, calcTotalError=0, **kwargs):
  """ splits examples into training and hold-out sets, builds a model on the
  training set and reports its error

  **Arguments**

    - examples: the full list of examples

    - attrs, nPossibleValues, nQuantBounds, mEstimateVal: passed positionally
      to modelBuilder, so a custom builder must accept them in that order

    - holdOutFrac: passed to SplitData.SplitIndices() to size the hold-out set

    - modelBuilder: callable used to construct the model from the training
      examples

    - silent: if false, the error is printed as a percentage

    - calcTotalError: if true, the error is measured over the whole
      `examples` list (with appendExamples=0) instead of only the hold-out
      set (with appendExamples=1)

    - kwargs: forwarded to modelBuilder; the key 'replacementSelection' is
      also read here to choose the SplitIndices() mode

  **Returns**

    a 2-tuple: (the model, the error fraction computed by CrossValidate())

  **Side effects**

    the training indices are stored on the model as `_trainIndices`

  """
  nTot = len(examples)
  # replacementSelection switches SplitIndices from the legacy,
  # replacement=0 mode to the legacy=0, replacement=1 mode.
  if kwargs.get('replacementSelection', 0):
    legacy, replacement = 0, 1
  else:
    legacy, replacement = 1, 0
  testIndices, trainIndices = SplitData.SplitIndices(nTot, holdOutFrac,
                                                     silent=1, legacy=legacy,
                                                     replacement=replacement)

  trainExamples = [examples[x] for x in trainIndices]
  testExamples = [examples[x] for x in testIndices]

  NBmodel = modelBuilder(trainExamples, attrs, nPossibleValues, nQuantBounds,
                         mEstimateVal, **kwargs)

  if calcTotalError:
    xValError, _ = CrossValidate(NBmodel, examples, appendExamples=0)
  else:
    xValError, _ = CrossValidate(NBmodel, testExamples, appendExamples=1)

  if not silent:
    # the original '%%%4.2f' put the percent sign in front of the number
    print('Validation error was %4.2f%%' % (100 * xValError))
  NBmodel._trainIndices = trainIndices
  return NBmodel, xValError
83