(三)朴素贝叶斯
定义 :
朴素贝叶斯是基于贝叶斯定理和特征条件独立假设的分类方法。
首先学习输入/输出的联合概率分布,然后基于此模型,对给定的输入$x$,利用贝叶斯定理求出后验概率最大的输出$y$。
模型:
首先学习先验概率分布:$P(Y=c_k),k=1,2,…,K$ , $c_k$代表某一类,也就是计算该类别的概率(在样本中我们已知)
然后学习条件概率分布:$P(X=x|Y=c_k)=P(X^{1}=x^{1},…,X^{n}=x^{n}|Y=c_k)$,给定一个类别$c_k$,计算该样本各个特征的概率,比如该样本第一个特征为
朴素贝叶斯法对条件概率分布作了条件独立性的假设:$$P(X^{(1)}=x^{(1)}|Y=c_k)P(X^{(2)}=x^{(2)}|Y=c_k)…P(X^{(j)}=x^{(j)}|Y=c_k)$$
上式变成:$$\prod_{j=1}^{n}P(X^{(j)}=x^{(j)}|Y=c_k)$$
在分类时,通过学习到的模型计算后验概率分布,由贝叶斯定理得到:
$$P(Y=c_k|X=x)=\frac{P(X=x|Y=c_k)P(Y=c_k)}{\sum_{k}P(X=x|Y=c_k)P(Y=c_k)}$$
将条件独立性假设得到的等式代入,并且注意到分母都是相同的,所以得到朴素贝叶斯分类器:
$$y=argmax_{c_k}P(Y=c_k)\prod_{j=1}P(X^{(j)}=x^{(j)}|Y=c_k)$$
算法:使用极大似然估计法估计相应的先验概率率:$$P(Y=c_k)=\frac{\sum_{i=1}^{N}I(y_i=c_k)}{N},k=1,2,…,K$$
以及条件概率:
$$P(X^{(j)}=a_{jl}|Y=c_k)=\frac{\sum_{i=1}^{N}I(x_{i}^{(j)}=a_{jl},y_i=c_k)}{\sum_{i=1}^{N}I(y_{i}=c_k)}$$
计算条件独立性假设下的实例各个取值的可能性,选取其中的最大值作为输出。
使用贝叶斯估计虽然保证了所有连乘项的概率都大于0,不会再出现某一项为0结果为0的情况。但若一个样本数据时高维的,比如说100维(100其实并不高),连乘项都是0-1之间的,那100个0-1之间的数相乘,最后的数一定是非常非常小了,可能无限接近于0。对于程序而言过于接近0的数可能会造成下溢出,也就是精度不够表达了。所以我们会给整个连乘项取对数,这样哪怕所有连乘最后结果无限接近0,那取完log以后数也会变得很大(虽然是负的很大),计算机就可以表示了。同样,多项连乘取对数,对数的连乘可以表示成对数的相加,在计算上也简便了。所以在实际运用中,不光需要使用贝叶斯估计(保证概率不为0),同时也要取对数(保证连乘结果不下溢出)。
代码:
参考代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196
|
''' 数据集:Mnist 训练集数量:60000 测试集数量:10000 ------------------------------ 运行结果: 正确率:84.3% 运行时长:103s '''
import numpy as np import time
def loadData(fileName): ''' 加载文件 :param fileName:要加载的文件路径 :return: 数据集和标签集 ''' dataArr = []; labelArr = [] fr = open(fileName) for line in fr.readlines(): curLine = line.strip().split(',') dataArr.append([int(int(num) > 128) for num in curLine[1:]]) labelArr.append(int(curLine[0])) return dataArr, labelArr
def NaiveBayes(Py, Px_y, x): ''' 通过朴素贝叶斯进行概率估计 :param Py: 先验概率分布 :param Px_y: 条件概率分布 :param x: 要估计的样本x :return: 返回所有label的估计概率 ''' featrueNum = 784 classNum = 10 P = [0] * classNum for i in range(classNum): sum = 0 for j in range(featrueNum): sum += Px_y[i][j][x[j]] P[i] = sum + Py[i]
return P.index(max(P))
def test(Py, Px_y, testDataArr, testLabelArr): ''' 对测试集进行测试 :param Py: 先验概率分布 :param Px_y: 条件概率分布 :param testDataArr: 测试集数据 :param testLabelArr: 测试集标记 :return: 准确率 ''' errorCnt = 0 for i in range(len(testDataArr)): presict = NaiveBayes(Py, Px_y, testDataArr[i]) if presict != testLabelArr[i]: errorCnt += 1 return 1 - (errorCnt / len(testDataArr))
def getAllProbability(trainDataArr, trainLabelArr): ''' 通过训练集计算先验概率分布和条件概率分布 :param trainDataArr: 训练数据集 :param trainLabelArr: 训练标记集 :return: 先验概率分布和条件概率分布 ''' featureNum = 784 classNum = 10
Py = np.zeros((classNum, 1)) for i in range(classNum): Py[i] = ((np.sum(np.mat(trainLabelArr) == i)) + 1) / (len(trainLabelArr) + 10) Py = np.log(Py)
Px_y = np.zeros((classNum, featureNum, 2)) for i in range(len(trainLabelArr)): label = trainLabelArr[i] x = trainDataArr[i] for j in range(featureNum): Px_y[label][j][x[j]] += 1
for label in range(classNum): for j in range(featureNum): Px_y0 = Px_y[label][j][0] Px_y1 = Px_y[label][j][1] Px_y[label][j][0] = np.log((Px_y0 + 1) / (Px_y0 + Px_y1 + 2)) Px_y[label][j][1] = np.log((Px_y1 + 1) / (Px_y0 + Px_y1 + 2))
return Py, Px_y
if __name__ == "__main__": start = time.time() print('start read transSet') trainDataArr, trainLabelArr = loadData('../Mnist/mnist_train.csv')
print('start read testSet') testDataArr, testLabelArr = loadData('../Mnist/mnist_test.csv')
print('start to train') Py, Px_y = getAllProbability(trainDataArr, trainLabelArr)
print('start to test') accuracy = test(Py, Px_y, testDataArr, testLabelArr)
print('the accuracy is:', accuracy) print('time span:', time.time() -start)
|