| Home | Trees | Indices | Help |
|---|
|
|
1 # PyCV - A Computer Vision Package for Python Incorporating Fast Training of Face Detection 2 3 # Copyright 2007 Nanyang Technological University, Singapore. 4 # Authors: Minh-Tri Pham, Viet-Dung D. Hoang, and Tat-Jen Cham. 5 6 # This file is part of PyCV. 7 8 # PyCV is free software: you can redistribute it and/or modify 9 # it under the terms of the GNU General Public 10 # License as published by the Free Software Foundation, either version 11 # 3 of the License, or (at your option) any later version. 12 13 # PyCV is distributed in the hope that it will be useful, 14 # but WITHOUT ANY WARRANTY; without even the implied warranty of 15 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 # GNU General Public License for more details. 17 18 # You should have received a copy of the GNU General Public License 19 # along with this program. If not, see <http://www.gnu.org/licenses/>. 20 21 # --------------------------------------------------------------------- 22 #!/usr/bin/env python 23 24 25 __all__ = ['NBClassifier','train_NBClassifier'] 26 27 from copy import copy 28 from math import fabs, log, exp 29 from numpy import array, where, zeros, ones, dot, prod 30 from numpy import log as NP_log 31 from numpy import exp as NP_exp 32 from numpy import sqrt as NP_sqrt 33 34 from cla import Classifier 35 from dataset import WeightedCDataset 36 37 from pycv.cs.ml import PredictPdfInterface, OnlineLearningInterface 38 from pycv import tprint 39 from pycv.ext import nb_NBClassifier_predict_pdf 40 41 #------------------------------------------------------------------------------- 42 # Naive Bayes Classifier 43 #------------------------------------------------------------------------------- 4445 -class NBClassifier(PredictPdfInterface, OnlineLearningInterface, \ 46 Classifier): # Naive Bayes classifier47 5153 ip = input_point 54 pdf = zeros(self.nclasses) 55 56 #gstat = self.stats2e.get_cond_stat 57 # for j in xrange(self.nclasses): 58 # pdf[j] = NP_log(gstat(j,0).squeeze()) 59 # pdf[j] -= 0.5 * (((ip - gstat(j,1)) / gstat(j,2))**2).ravel().sum() 60 # pdf[j] -= NP_log(gstat(j,2)).sum() 61 # pdf = NP_exp(pdf) 62 #return pdf/pdf.sum() 63 64 A = self.stats2e.A 65 d = len(ip) 66 J = self.nclasses 67 nb_NBClassifier_predict_pdf(ip, pdf, A, d, J) 68 return pdf6971 """Learn incrementally with a new input point, its class, and optionally its weight. 72 73 Input: 74 input_point: a new input point 75 j: its corresponding class 76 w: optionally its weight, or 1 if not specified 77 """ 78 self.stats2e.learn(input_point, j, w)79 8082 """Train a NBClassifier using a WeightedCDataset 83 84 Input: 85 classification_dataset: a WeightedCDataset 86 Output: 87 an NBClassifier 88 """ 89 return NBClassifier(cd.J,cd.compute_Stats2e())90
| Home | Trees | Indices | Help |
|---|
| Generated by Epydoc 3.0beta1 on Mon Feb 25 10:24:20 2008 | http://epydoc.sourceforge.net |