Package pycv :: Package cs :: Package ml :: Package cla :: Module nb
[hide private]
[frames] | no frames]

Source Code for Module pycv.cs.ml.cla.nb

 1  # PyCV - A Computer Vision Package for Python Incorporating Fast Training of Face Detection 
 2   
 3  # Copyright 2007 Nanyang Technological University, Singapore. 
 4  # Authors: Minh-Tri Pham, Viet-Dung D. Hoang, and Tat-Jen Cham. 
 5   
 6  # This file is part of PyCV. 
 7   
 8  # PyCV is free software: you can redistribute it and/or modify 
 9  # it under the terms of the GNU General Public  
10  # License as published by the Free Software Foundation, either version  
11  # 3 of the License, or (at your option) any later version. 
12   
13  # PyCV is distributed in the hope that it will be useful, 
14  # but WITHOUT ANY WARRANTY; without even the implied warranty of 
15  # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the 
16  # GNU General Public License for more details. 
17   
18  # You should have received a copy of the GNU General Public License 
19  # along with this program.  If not, see <http://www.gnu.org/licenses/>. 
20   
21  # --------------------------------------------------------------------- 
22  #!/usr/bin/env python 
23   
24   
25  __all__ = ['NBClassifier','train_NBClassifier'] 
26   
27  from copy import copy 
28  from math import fabs, log, exp 
29  from numpy import array, where, zeros, ones, dot, prod 
30  from numpy import log as NP_log 
31  from numpy import exp as NP_exp 
32  from numpy import sqrt as NP_sqrt 
33   
34  from cla import Classifier 
35  from dataset import WeightedCDataset 
36   
37  from pycv.cs.ml import PredictPdfInterface, OnlineLearningInterface 
38  from pycv import tprint 
39  from pycv.ext import nb_NBClassifier_predict_pdf 
40   
41  #------------------------------------------------------------------------------- 
42  # Naive Bayes Classifier 
43  #------------------------------------------------------------------------------- 
44   
class NBClassifier(PredictPdfInterface, OnlineLearningInterface, Classifier):
    """Naive Bayes classifier driven by a per-class statistics object.

    The classifier keeps no model parameters of its own: everything it
    knows lives in ``self.stats2e``, which is consulted when predicting
    and updated when learning.
    """

    def __init__(self, nclasses, stats2e):
        """Create the classifier.

        Input:
            nclasses: number of classes, forwarded to Classifier.__init__
            stats2e: statistics object; this class reads its ``A``
                attribute and calls its ``learn`` method
        """
        # Only Classifier's initialiser is invoked explicitly; the two
        # interface bases are not initialised here.
        Classifier.__init__(self, nclasses)
        self.stats2e = stats2e

    def predict_pdf(self, input_point, *args, **kwds):
        """Compute the per-class pdf for a single input point.

        Input:
            input_point: the point to classify; must support len()
        Output:
            an array of length self.nclasses, filled in place by the
            extension routine nb_NBClassifier_predict_pdf

        Extra positional and keyword arguments are accepted and ignored.
        """
        pdf = zeros(self.nclasses)

        # Reference implementation that was left commented out in the
        # original source. NOTE(review): presumably the extension routine
        # below computes the same normalised Gaussian naive-Bayes
        # posterior -- confirm against pycv.ext before relying on it.
        #
        # gstat = self.stats2e.get_cond_stat
        # for j in xrange(self.nclasses):
        #     pdf[j] = NP_log(gstat(j,0).squeeze())
        #     pdf[j] -= 0.5 * (((ip - gstat(j,1)) / gstat(j,2))**2).ravel().sum()
        #     pdf[j] -= NP_log(gstat(j,2)).sum()
        # pdf = NP_exp(pdf)
        # return pdf/pdf.sum()

        # The extension writes its result into ``pdf``; the remaining
        # arguments are the statistics array, the length of the input
        # point and the number of classes.
        nb_NBClassifier_predict_pdf(
            input_point,
            pdf,
            self.stats2e.A,
            len(input_point),
            self.nclasses,
        )
        return pdf

    def learn(self, input_point, j, w=None, *args, **kwds):
        """Update the classifier incrementally with one labelled point.

        Input:
            input_point: the new input point
            j: the class the point belongs to
            w: optional weight of the point; it is handed to
                self.stats2e.learn unchanged (None when omitted).
                NOTE(review): the original documentation says a missing
                weight counts as 1 -- that defaulting is not done here, so
                it presumably happens inside stats2e.learn; confirm.

        Extra positional and keyword arguments are accepted and ignored.
        """
        self.stats2e.learn(input_point, j, w)
79 80
def train_NBClassifier(cd):
    """Build an NBClassifier from a classification dataset.

    Input:
        cd: the dataset to train on (a WeightedCDataset according to the
            original documentation); it must expose the attribute ``J``
            and the method ``compute_Stats2e()``
    Output:
        an NBClassifier constructed from cd.J and the statistics
        returned by cd.compute_Stats2e()
    """
    # Read J before computing the statistics, matching the original
    # argument evaluation order.
    nclasses = cd.J
    stats2e = cd.compute_Stats2e()
    return NBClassifier(nclasses, stats2e)
90