2.2、Softmax Regression算法实践

2023-02-27,,

Softmax Regression算法实践

  有了上篇博客的理论知识,我们可以利用实现好的函数,来构建Softmax Regression分类器,在训练分类器的过程中,我们使用多分类数据作为训练数据:如图

1、利用训练数据对模型进行训练:

完整代码为:

 # -*- coding: UTF- -*-
# date://
# User:WangHong
import numpy as np
def gradientAscent(feature_data,label_data,k,maxCycle,alpha):
'''利用梯度下降法训练Softmax模型
:param feature_data: 特征
:param label_data: 标签
:param k: 类别个数
:param maxCycle: 最大迭代次数
:param alpha: 学习率
:return weights: 权重
'''
m,n = np.shape(feature_data)
weights = np.mat(np.ones((n,k)))#初始化权重
i =
while i<=maxCycle:
err = np.exp(feature_data*weights)
if i % == :
print("\t--------iter:",i,\
",cost:",cost(err,label_data))
rowsum = -err.sum(axis=)
rowsum = rowsum.repeat(k,axis = )
err = err/rowsum
for x in range(m):
err[x,label_data[x,]]+=
weights = weights+(alpha/m)*feature_data.T*err
i+=
return weights def cost(err,label_data):
'''
:param err: exp的值
:param label_data: 标签的值
:return: 损失函数的值
'''
m = np.shape(err)[]
sum_cost = 0.0
for i in range(m):
if err[i,label_data[i,]]/np.sum(err[i,:])>:
sum_cost -=np.log(err[i,label_data[i,]]/np.sum(err[i,:]))
else:
sum_cost -=
return sum_cost / m def load_data(inputfile):
'''导入训练数据
input: inputfile(string)训练样本的位置
output: feature_data(mat)特征
label_data(mat)标签
k(int)类别的个数
'''
f = open(inputfile) # 打开文件
feature_data = []
label_data = []
for line in f.readlines():
feature_tmp = []
feature_tmp.append() # 偏置项
lines = line.strip().split("\t")
for i in range(len(lines) - ):
feature_tmp.append(float(lines[i]))
label_data.append(int(lines[-])) feature_data.append(feature_tmp)
f.close() # 关闭文件
return np.mat(feature_data), np.mat(label_data).T, len(set(label_data)) def save_model(file_name, weights):
'''保存最终的模型
input: file_name(string):保存的文件名
weights(mat):softmax模型
'''
f_w = open(file_name, "w")
m, n = np.shape(weights)
for i in range(m):
w_tmp = []
for j in range(n):
w_tmp.append(str(weights[i, j]))
f_w.write("\t".join(w_tmp) + "\n")
f_w.close() if __name__=="__main__":
inputfile = "SoftInput.txt"
#导入数据
print("--------------1.load data-------------")
feature,label,k = load_data(inputfile)
#训练模型
print("--------------2.traing----------------")
weights = gradientAscent(feature,label,k,,0.2)
#保存模型
print("--------------3.save model------------")
save_model("weights",weights)

训练结果为

weights文件内容

2、用训练好的模型对数据进行预测:

预测的代码:

 # -*- coding: UTF-8 -*-
# date:2018/5/29
# User:WangHong
import numpy as np
import random as rd
def load_weights(weights_path):
'''导入训练好的Softmax模型
input: weights_path(string)权重的存储位置
output: weights(mat)将权重存到矩阵中
m(int)权重的行数
n(int)权重的列数
'''
f = open(weights_path)
w = []
for line in f.readlines():
w_tmp = []
lines = line.strip().split("\t")
for x in lines:
w_tmp.append(float(x))
w.append(w_tmp)
f.close()
weights = np.mat(w)
m, n = np.shape(weights)
return weights, m, n def load_data(num, m):
'''导入测试数据
input: num(int)生成的测试样本的个数
m(int)样本的维数
output: testDataSet(mat)生成测试样本
'''
testDataSet = np.mat(np.ones((num, m)))
for i in range(num):
testDataSet[i, 1] = rd.random() * 6 - 3 # 随机生成[-3,3]之间的随机数
testDataSet[i, 2] = rd.random() * 15 # 随机生成[0,15]之间是的随机数
return testDataSet def predict(test_data, weights):
'''利用训练好的Softmax模型对测试数据进行预测
input: test_data(mat)测试数据的特征
weights(mat)模型的权重
output: h.argmax(axis=1)所属的类别
'''
h = test_data * weights
return h.argmax(axis=1) # 获得所属的类别 def save_result(file_name, result):
'''保存最终的预测结果
input: file_name(string):保存最终结果的文件名
result(mat):最终的预测结果
'''
f_result = open(file_name, "w")
m = np.shape(result)[0]
for i in range(m):
f_result.write(str(result[i, 0]) + "\n")
f_result.close() if __name__ == "__main__":
# 1、导入Softmax模型
print("---------- 1.load model ----------------")
w, m, n = load_weights("weights")
# 2、导入测试数据
print("---------- 2.load data -----------------")
test_data = load_data(4000, m)
# 3、利用训练好的Softmax模型对测试数据进行预测
print("---------- 3.get Prediction ------------")
result = predict(test_data, w)
# 4、保存最终的预测结果
print("---------- 4.save prediction ------------")
save_result("result", result)

预测结果;

会生成一个result文件用于存储预测结果

在本次测试中随机生成4000个样本,最终分类的结果为:

2.2、Softmax Regression算法实践的相关教程结束。

《2.2、Softmax Regression算法实践.doc》

下载本文的Word格式文档,以方便收藏与打印。