1200字范文,内容丰富有趣,写作的好帮手!
1200字范文 > tensorflow2 搭建LeNet5训练MINST手写数字数据集并用c++ opencv4.5.5 DNN加载模型预测结果

tensorflow2 搭建LeNet5训练MINST手写数字数据集并用c++ opencv4.5.5 DNN加载模型预测结果

时间:2022-08-13 14:04:21

相关推荐

tensorflow2 搭建LeNet5训练MINST手写数字数据集并用c++ opencv4.5.5 DNN加载模型预测结果

目录

一、LeNet5网络介绍二、环境搭建三、网络搭建以及训练3.1、加载数据集3.2、网络搭建3.3、模型训练3.4、模型固化四、c++ opencv加载模型

一、LeNet5网络介绍

LeNet5 这个网络包含了深度学习的基本模块:卷积层,池化层,全链接层。是其他深度学习模型的基础。LeNet-5共有7层,不包含输入,每层都包含可训练参数;每个层有多个Feature Map,每个FeatureMap通过一种卷积滤波器提取输入的一种特征,然后每个FeatureMap有多个神经元。

二、环境搭建

本人环境配置如下:

pycharm

vs

Anaconda3

tensorflow=2.3

opencv=4.5.5

前几个安装相对轻松,直接上官网安装即可,tensorflow使用pip命令安装,c++ opencv相对较为麻烦,可以参考本人以前的安装方法:c++ opencv 学习笔记(一) Visual Studio + OpenCV4.5.5 配置详解

三、网络搭建以及训练

3.1、加载数据集

tensorflow内置了MINST数据集,从tensorflow中导入即可

import tensorflow as tfmnist = tf.keras.datasets.mnisttrain, test = mnist.load_data()

将数据按照batch提供给网络模型

import numpy as npclass MNISTData:def __init__(self, data, need_shuffle, batch_size=128):""":param datas: 数据集,格式为 data,label:param shuffle: 是否随机打乱数据 True or False:param batch_size: 一批数据大小"""self._data = data[0]self._labels = data[1]self.num_examples = self._data.shape[0]self._need_shuffle = need_shuffleself._indicator = 0self._batch_size = batch_sizeif self._need_shuffle:self._shuffle_data()def __iter__(self):return selfdef _shuffle_data(self):p = np.random.permutation(self.num_examples)self._data = self._data[p]self._labels = self._labels[p]def next_batch(self):end_indicator = self._indicator + self._batch_sizeif end_indicator > self.num_examples:if self._need_shuffle:self._shuffle_data()self._indicator = 0end_indicator = self._batch_sizeelse:self._indicator = 0end_indicator = self._batch_sizeif end_indicator > self.num_examples:raise StopIterationbatch_data = self._data[self._indicator: end_indicator] / 255.0 # 归一化batch_labels = self._labels[self._indicator: end_indicator]self._indicator = end_indicatorreturn batch_data, batch_labelsdef __next__(self):return self.next_batch()train_dataset = dataset.MNISTData(train, True)test_dateset = dataset.MNISTData(test, False)

查看数据集

def display(train_images, train_labels):plt.figure(figsize=(10,10))for i in range(25):plt.subplot(5,5,i+1)plt.xticks([])plt.yticks([])plt.grid(False)plt.imshow(train_images[i], cmap=plt.cm.binary)plt.xlabel(train_labels[i])plt.show()for data in train_dataset:display(*data)

3.2、网络搭建

使用tensorflow中的keras搭建网络结构,激活函数使用Mish

import tensorflow as tffrom tensorflow.keras.models import Modelfrom tensorflow.keras.layers import *from tensorflow.keras.utils import get_custom_objectsclass Mish(Activation):def __init__(self, activate, **kwargs):super(Mish, self).__init__(activate, **kwargs)self.__name__ = "Mish"def mish(inputs):return inputs * tf.math.tanh(tf.math.softplus(inputs))def LeNet5(input_shape=[32, 32, 3]):get_custom_objects().update({'Mish': Mish(mish)})#输入层inputs = Input(shape=input_shape)#第一个卷积-池化层conv1 = Conv2D(6, 5, activation="relu", padding='same')(inputs)pool1 = MaxPooling2D((2, 2))(conv1)#第二个卷积-池化层conv2 = Conv2D(16, 5, activation="relu", padding='same')(pool1)pool2 = MaxPooling2D((2, 2))(conv2)#第三个卷积层conv2 = Conv2D(120, 5, activation="relu", padding='same')(pool2)fc = Flatten()(conv2)#全连接层fc1 = Dense(120, activation="relu")(fc)#输出层fc2 = Dense(10, activation="softmax")(fc1)model = Model(inputs, fc2)return modelmodel = LeNet5(input_shape=[28, 28, 1])

定义损失函数以及优化器

pile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),loss=tf.keras.losses.SparseCategoricalCrossentropy(),metrics=['accuracy'])

保存模型

model_filepath = 'model/'checkpoint_filepath = model_filepath + 'tmp/'cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_filepath,save_best_only=True,save_weights_only=True,monitor='accuracy',mode='max')

3.3、模型训练

开始训练(俗称炼丹)

# 是否使用GPUuse_gpu = Truetf.debugging.set_log_device_placement(True)if use_gpu:gpus = tf.config.experimental.list_physical_devices(device_type='GPU')if gpus:for gpu in gpus:tf.config.experimental.set_memory_growth(device=gpu, enable=True)tf.print(gpu)else:os.environ["CUDA_VISIBLE_DEVICE"] = "-1"else:os.environ["CUDA_VISIBLE_DEVICE"] = "-1"# TensorBoard可视化工具log_path = 'logging/'logging = tf.keras.callbacks.TensorBoard(log_dir=log_path)model_filepath = 'model/'checkpoint_filepath = model_filepath + 'tmp/'history = model.fit(train_dataset,epochs=10,steps_per_epoch=train_dataset.num_examples // BATCH_SIZE + 1,validation_data=test_dateset,validation_steps=test_dateset.num_examples // BATCH_SIZE + 1,callbacks=[cp_callback, logging ])model.load_weights(checkpoint_filepath)model.save(model_filepath + 'model')

可视化训练过程

TensorBoard是一个可视化工具,它可以用来展示网络图、张量的指标变化、张量的分布情况等。进入logging文件夹的上一层文件夹,在DOS窗口运行命令:

tensorboard --logdir=./logging

在浏览器输入网址:http://localhost:6006,或者输入上图提示的网址,即可查看生成图。

3.4、模型固化

import tensorflow as tffrom tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2def export_frozen_graph(model, name, input_size) :f = tf.function(lambda x: model(x))f = f.get_concrete_function(x=tf.TensorSpec(shape=[None, input_size[0], input_size[1], input_size[2]], dtype=tf.float32))f2 = convert_variables_to_constants_v2(f)graph_def = f2.graph.as_graph_def()# Export frozen graphwith tf.io.gfile.GFile(name, 'wb') as f:f.write(graph_def.SerializeToString())export_frozen_graph(model, model_filepath + 'frozen_graph.pb', (input_size, input_size, 1))

四、c++ opencv加载模型

#include <opencv2/opencv.hpp>#include <iostream>#include <vector>using namespace std;//多分类问题用这个函数判断类别,二分类的话不用也行std::vector<int> Argmax(cv::Mat x){std::vector<int> res;for (int i = 0; i < x.rows; i++){int maxIdx = 0;float maxNum = 0.0;for (int j = 0; j < x.cols; j++){float tmp = x.at<float>(i, j);if (tmp > maxNum){maxIdx = j; //更新最优值序号maxNum = tmp; //更新最优值}}res.push_back(maxIdx); //最优预测值的序号}return res;}int main(){//cv加载模型cv::dnn::Net net = cv::dnn::readNetFromTensorflow("frozen_graph.pb");//加载图片cv::Mat src = cv::imread("8.jpg", cv::IMREAD_COLOR);cv::Mat img = src;cv::cvtColor(img, img, cv::COLOR_BGR2GRAY);//调整图片大小cv::resize(img, img, cv::Size(28, 28));//归一化 0-1之间img.convertTo(img, CV_32FC1, 1.f / 255.f, -1.f);//格式转化cv::dnn::blobFromImage(img, img, 1.0, cv::Size(), cv::Scalar(), false, false, CV_32F);//将数据喂给网络net.setInput(img);//前向传播,得到传播结果cv::Mat pred = net.forward();//输出结果vector<int> res = Argmax(pred);//输出标签stringstream ss;string str;ss << "label:" << res[0];ss >> str;//放大图片便于观察cv::resize(src, src, cv::Size(280, 280));cv::putText(src, str, cv::Size(0, 40), cv::FONT_HERSHEY_COMPLEX, 1, cv::Scalar(0, 255, 0), 1);cv::imshow("", src);cv::waitKey();}

结果如下:

有需要的可以下载完整项目链接进行测试:

GitHub:/small-guang/LeNet5

CSDN:/download/qq_45723275/77992089

其他项目链接:tensorflow2.3 搭建 vgg16训练cifar10数据集

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。