Package rdkit :: Package ML :: Package KNN :: Module CrossValidate
[hide private]
[frames] | no frames]

Source Code for Module rdkit.ML.KNN.CrossValidate

  1  # 
  2  #  Copyright (C) 2000-2008  greg Landrum 
  3  # 
  4  """ handles doing cross validation with k-nearest neighbors model 
  5   
  6  and evaluation of individual models 
  7   
  8  """ 
  9  from __future__ import print_function 
 10  from rdkit.ML.KNN.KNNClassificationModel import KNNClassificationModel 
 11  from rdkit.ML.KNN.KNNRegressionModel import KNNRegressionModel 
 12  from rdkit.ML.KNN import DistFunctions 
 13  from rdkit.ML.Data import SplitData 
 14   
def makeClassificationModel(numNeigh, attrs, distFunc):
  """ factory that builds a k-NN classification model

  **Arguments**

    - numNeigh: the number of neighbors (k) the model should use

    - attrs: forwarded unchanged to the model constructor (presumably the
      attributes used when computing distances -- confirm in KNNClassificationModel)

    - distFunc: the distance function handed to the model

  **Returns**

    a new KNNClassificationModel instance

  """
  model = KNNClassificationModel(numNeigh, attrs, distFunc)
  return model
def makeRegressionModel(numNeigh, attrs, distFunc):
  """ factory that builds a k-NN regression model

  **Arguments**

    - numNeigh: the number of neighbors (k) the model should use

    - attrs: forwarded unchanged to the model constructor (presumably the
      attributes used when computing distances -- confirm in KNNRegressionModel)

    - distFunc: the distance function handed to the model

  **Returns**

    a new KNNRegressionModel instance

  """
  model = KNNRegressionModel(numNeigh, attrs, distFunc)
  return model
19
def CrossValidate(knnMod, testExamples, appendExamples=0):
  """ Determines the error of a k-NN model on a set of test examples

  **Arguments**

    - knnMod: a KNNClassificationModel (anything else supporting
      _ClassifyExample()_ must subclass it) or a KNNRegressionModel
      (which must support _PredictExample()_)

    - testExamples: a non-empty sequence of examples to be used for testing.
      The last element of each example is taken as its true result.

    - appendExamples: a toggle which is passed along to the model's
      _ClassifyExample()_ / _PredictExample()_ call for every example
      (presumably lets the model keep the examples it processes -- confirm
      in the model classes)

  **Returns**

    a 2-tuple consisting of:

      - for classification models: the fraction of test examples that were
        misclassified, and the list of those misclassified examples

      - for regression models: the mean absolute deviation between the true
        and predicted values, and None

  **Raises**

    - ValueError: if _knnMod_ is of neither supported type, or if
      _testExamples_ is empty

  """
  # Decide the model kind first, so an unsupported model is always reported
  # as such, even when there are no test examples.
  if isinstance(knnMod, KNNClassificationModel):
    isClassifier = True
  elif isinstance(knnMod, KNNRegressionModel):
    isClassifier = False
  else:
    raise ValueError("Unrecognized Model Type")

  nTest = len(testExamples)
  if not nTest:
    # both error measures below are averages over nTest; fail with a clear
    # message rather than a bare ZeroDivisionError
    raise ValueError("testExamples must contain at least one example")

  if isClassifier:
    badExamples = []
    # plain loop (not a comprehension): ClassifyExample may have side effects
    # when appendExamples is set
    for testEx in testExamples:
      if testEx[-1] != knnMod.ClassifyExample(testEx, appendExamples):
        badExamples.append(testEx)
    # float() keeps this a true division under Python 2 as well
    return float(len(badExamples)) / nTest, badExamples

  devSum = 0.0
  for testEx in testExamples:
    devSum += abs(testEx[-1] - knnMod.PredictExample(testEx, appendExamples))
  return devSum / nTest, None
60
def CrossValidationDriver(examples, attrs, nPossibleValues, numNeigh,
                          modelBuilder=makeClassificationModel,
                          distFunc=DistFunctions.EuclideanDist,
                          holdOutFrac=0.3,
                          silent=0,
                          calcTotalError=0,
                          **kwargs):
  """ Driver function for building a KNN model of a specified type

  **Arguments**

    - examples: the full set of examples; the last element of each example
      is its true result

    - attrs: forwarded to _modelBuilder_ (and from there to the model
      constructor)

    - nPossibleValues: not used by this function (presumably kept for
      signature parity with the other ML cross-validation drivers)

    - numNeigh: number of neighbors for the KNN model (basically k in k-NN)

    - modelBuilder: a callable taking (numNeigh, attrs, distFunc) and
      returning the model to train, e.g. makeClassificationModel or
      makeRegressionModel

    - distFunc: the distance function passed to _modelBuilder_

    - holdOutFrac: the fraction of the data which should be reserved for the
      hold-out set (used to calculate error)

    - silent: a toggle used to control how much visual noise this makes as it goes

    - calcTotalError: a toggle used to indicate whether the error of the
      model should be calculated using the entire data set (when true) or
      just the hold-out set (when false)

    - replacementSelection (optional keyword, read from **kwargs): when true
      the train/test split is drawn with replacement (non-legacy
      SplitIndices mode); otherwise the legacy split without replacement
      is used

  **Returns**

    a 2-tuple consisting of:

      - the trained model; its _trainIndices_ attribute holds the indices
        (into _examples_) of the training examples

      - the error reported by CrossValidate

  """
  nTot = len(examples)
  # Sampling with replacement uses the non-legacy splitter; the default
  # (no replacement) keeps the legacy behavior.
  withReplacement = 1 if kwargs.get('replacementSelection', 0) else 0
  testIndices, trainIndices = SplitData.SplitIndices(nTot, holdOutFrac,
                                                      silent=1,
                                                      legacy=1 - withReplacement,
                                                      replacement=withReplacement)
  trainExamples = [examples[x] for x in trainIndices]
  testExamples = [examples[x] for x in testIndices]

  if not silent:
    print("Training with %d examples" % (len(trainExamples)))

  knnMod = modelBuilder(numNeigh, attrs, distFunc)
  knnMod.SetTrainingExamples(trainExamples)
  knnMod.SetTestExamples(testExamples)

  # the list of misclassified examples is not needed here
  if calcTotalError:
    xValError, _ = CrossValidate(knnMod, examples, appendExamples=0)
  else:
    xValError, _ = CrossValidate(knnMod, testExamples, appendExamples=1)

  if not silent:
    print('Validation error was %%%4.2f' % (100 * xValError))

  knnMod._trainIndices = trainIndices
  return knnMod, xValError
121