应用SVM对MNIST数据集进行分类

2022-07-31,,,

MNIST是机器学习领域十分经典的一个手写数字数据集,共60000张训练图像,10000张测试图像,图像大小为28*28.

MNIST百度网盘下载地址:https://pan.baidu.com/s/1k1Ji6amaUhDG6jfdcl_kNg  提取码:nykv

将下载下来的压缩包解压后放到源代码所在的文件夹下即可。

如运行缺少相关python库,可往https://www.lfd.uci.edu/~gohlke/pythonlibs/下载

SVM分类MNIST的源代码如下:

from sklearn import svm
import numpy as np
from time import time
from sklearn.metrics import accuracy_score
from struct import unpack
from sklearn.model_selection import GridSearchCV

def readimage(path):
    with open(path, 'rb') as f:
        magic, num, rows, cols = unpack('>4I', f.read(16))
        img = np.fromfile(f, dtype=np.uint8).reshape(num, 784)
    return img

def readlabel(path):
    with open(path, 'rb') as f:
        magic, num = unpack('>2I', f.read(8))
        lab = np.fromfile(f, dtype=np.uint8)
    return lab

def main():
    train_data  = readimage("train-images.idx3-ubyte")
    train_label = readlabel("train-labels.idx1-ubyte")
    test_data   = readimage("t10k-images.idx3-ubyte")
    test_label  = readlabel("t10k-labels.idx1-ubyte")
    svc=svm.SVC()
    parameters = {'kernel':['rbf'], 'C':[1]}
    print("Train...")
    clf=GridSearchCV(svc,parameters,n_jobs=-1)
    start = time()
    clf.fit(train_data, train_label)
    end = time()
    t = end - start
    print('Train:%dmin%.3fsec' % (t//60, t - 60 * (t//60)))
    prediction = clf.predict(test_data)
    print("accuracy: ", accuracy_score(prediction, test_label))
    accurate=[0]*10
    sumall=[0]*10
    i=0
    while i<len(test_label):
        sumall[test_label[i]]+=1
        if prediction[i]==test_label[i]:
            accurate[test_label[i]]+=1
        i+=1
    print("分类正确的:",accurate)
    print("总的测试标签:",sumall)

if __name__ == '__main__':
    main()

程序通过readimage和readlabel函数读入数据后创建svm分类器,并用parameter添加相应的参数,这里使用GridSearchCV将参数作为输入优化网络,这里输入的parameter对应分类器唯一,可进行添加以达到优化参数的目的,代码中使用GridSearchCV的主要目的是引入n_jobs让cpu进行多线程处理,n_jobs=-1时程序的并行数将和cpu的核数一致,从而极大的加速程序的运行。在i5-8300H的四核CPU中训练时间为26min。

源代码训练时的正确率如下:

欢迎评论区交流。

友情链接:svm.SVC参数详解:https://blog.csdn.net/weixin_41990278/article/details/93137009

                  GridSearchCV参数详解:https://blog.csdn.net/foneone/article/details/89985045

本文地址:https://blog.csdn.net/qq_43160985/article/details/107675241

《应用SVM对MNIST数据集进行分类.doc》

下载本文的Word格式文档,以方便收藏与打印。