发布于 2026-01-06 2 阅读
0

基于卷积神经网络的手写数字识别 引言 图像分类 卷积神经网络 手写数字识别 总结 参考文献

基于卷积神经网络的手写数字识别

介绍

图像分类

卷积神经网络

手写数字识别

概括

参考

介绍

在这篇博文中,我想分享一个我开发的用于识别手写数字图像的小应用程序,以及我在开发过程中学到的经验教训。过去,我在机器学习领域主要从事文本数据方面的工作。图像数据的模式识别对我来说是全新的领域,但我认为这是一项非常有用的技能。

本文结构如下。首先,我们将介绍图像分类的概念,以及它与其他问题(例如文本分类)相比的独特之处。第二部分将介绍一种名为卷积神经网络(CNN)的机器学习模型,该模型常用于图像分类。第三部分将展示一个通过网页界面进行手写数字分类的示例应用。最后,我们将总结本文的主要发现和观点。

该应用程序使用 Scala、HTML、CSS 和 JavaScript 编写。不过,其中的概念也可以迁移到其他语言。我也尽量避免涉及过多的数学细节,而是专注于读者理解算法所需的基本信息。如果您有兴趣深入了解该主题,我建议您参考其他教程、研究论文或书籍。

图像分类

机器学习算法期望数据以计算机能够理解的某种数值格式表示。例如,在使用概率模型时,数据必须符合模型所用分布所期望的格式。

例如,考虑多项混合模型[1]。要使用这种模型,需要将数据转换为计数。在文本中,可以通过为每个文档中每个聚类的每个可能词引入一个计数变量来实现。该模型非常简单,适用于许多应用场景。然而,它有一个很大的缺点:它会丢弃大量信息,例如词项共现情况和在文档中的位置。

对于图像数据而言,这个问题更加突出。虽然你仍然可以通过统计字数来判断电子邮件是否为垃圾邮件,但如果只统计特定颜色像素的数量,识别包含猫的图像就困难得多。文本数据是一维的,即一系列词语,而图像至少是二维的,即像素矩阵,并且像素之间的空间关系包含了更多信息。

幸运的是,我们可以使用其他考虑空间信息的模型。卷积神经网络(CNN)就是一种非常常用的模型。虽然这方面的研究已经持续了一段时间[2],但近年来基于GPU的训练方式在模型性能方面取得了重大突破[3]。

我们如何用计算机表示原始图像?计算机图像的最小可寻址元素是像素。每个像素都有一个位置和一个颜色。我们可以用不同的方式表示颜色。彩色图像常用的方案是红蓝绿 (RGB) 方案。如果我们为每个像素分配 24 位,即红、蓝、绿三种颜色各占 8 位,那么我们可以分别编码 256 种不同的红色、蓝色和绿色色调。将它们组合起来,我们就可以表示大约 1600 万种不同的颜色。

为了在代码中访问图像信息,我们可以将像素存储在一个二维数组(即矩阵)中。虽然可以将所有三个颜色通道合并到该矩阵的单个坐标中,但只存储一个数值效率更高。这样,每个通道就对应一个矩阵,从而可以将灰度图像表示为矩阵,将彩色图像表示为三维张量。下图展示了对于一个 3×3 像素的图像,这个过程是如何进行的。需要注意的是,在实际图像中,颜色通常会混合在一起。

图像像素

现在让我们来看看 CNN 的工作原理,以及如何使用这种图像表示作为基于 CNN 的分类器的输入。

卷积神经网络

建筑学

神经网络是一种机器学习模型,它由相互连接的神经元层组成。每个神经元包含一个数值,称为激活值。连接被赋予权,权重描述了传递到所连接神经元的信号强度。

输入数据被送入第一层,每个输入神经元都会被一定程度地激活。基于权重和激活函数,神经网络决定激活下一层中的哪些神经元以及激活的强度。这种所谓的前馈过程会持续进行,直到输出神经元被激活为止。神经网络的架构对其能够处理的数据类型和性能有着巨大的影响。下图展示了一个包含三层的简单神经网络。

简单神经网络

卷积神经网络(CNN)是一种特殊的神经网络。它可以分为两部分:特征学习部分和分类部分。每一部分都由一个或多个层组成。特征学习通常通过结合两种类型的层来实现:卷积层池化层。然后,基于学习到的特征,通过全连接(也称为密集层)进行分类。此外,还有一个输入层,用于存储图像数据;以及一个输出层,用于存储我们试图预测的不同类别。

下图展示了一个包含一个卷积层、一个池化层和一个全连接层的卷积神经网络(CNN)。任务是预测图像中是否包含一只猫。位于输入层和输出层之间的层也称为隐藏层,因为当将模型视为黑盒时,它们的状态是不可见的。

CNN示例

考虑单通道颜色图像时,输入层可以是原始图像矩阵,也可以是预处理后的图像矩阵,例如裁剪、调整大小、颜色值缩放到 0 到 1 之间的图像矩阵等等。输出层表示由最后一个隐藏层分配的每个可能类别的权重。下一小节我们将更详细地了解不同类型的隐藏层。

卷积层

卷积层负责将滤波器与前一层进行卷积运算。如果您不熟悉二维图像滤波,可以参考机器学习专家网站上的“图像滤波”文章。滤波器可以看作是一个比输入图像小的图像(即一个比输入图像小的矩阵),它被应用于输入图像的一部分。如果这部分图像符合滤波器的预期,则输出值会很高。将滤波器与整个输入图像进行卷积运算,会得到另一幅图像,该图像会突出显示输入图像的某些特征。

我们来看一个例子。下图展示了 Sobel-Feldman 算子 [4](也称为 Sobel 边缘检测器滤波器)在蓝猫图像上的应用。准确地说,我们应用了两个滤波器,一个用于检测水平边缘,一个用于检测垂直边缘。然后,我们将两个滤波器的结果结合起来,得到一幅同时显示水平和垂直边缘的图像。滤波器核位于图的中心位置。

猫过滤器

定义卷积层时有多种配置选项。每个卷积层可以包含一个或多个滤波器。卷积层会针对每个滤波器输​​出输入的中间表示。滤波器越多,图像特征的多样性就越高。

除了滤波器内核的数量,我们还可以选择一个内核大小。内核大小决定了滤波器的局部性,即应用滤波器时需要考虑的周围像素数量。其次,我们需要选择一个步长值。步长决定了卷积时每次移动的像素数。步长为 1 时,滤波器会遍历每个像素;而步长为 2 时,滤​​波器会跳过每隔一个像素。

问题是,我们如何选择要使用的滤波器?答案是,我们不需要选择。神经网络的强大之处在于,它们能够基于训练数据自行学习特征。训练过程将在后面的章节中详细讨论。现在,让我们来看第二种特征学习层:池化层。

池化层

池化层用于对输入进行下采样。其目的是降低模型的计算复杂度并避免过拟合。信息损失通常不会造成太大问题,因为特征的确切位置远不如它们之间的关系重要。

池化是通过应用一个特殊的滤波函数来实现的,同时选择合适的核大小和步长值,以确保滤波过程不会重叠。一种常用的技术是最大池化。在最大池化中,我们选择子区域中的最大值作为子采样输出。下图展示了对一个 4×4 输入矩阵应用 2×2 最大池化的结果。

2x2 最大池化

下图展示了对卷积层输出进行两次子采样后的结果。请注意,子采样会减小图像尺寸,但我又将其放大以更直观地展示信息损失。

子样本猫

我们如何利用派生特征来预测类别呢?让我们通过深入了解密集层的工作原理来找到答案。

致密层

全连接层将前一层中的每个神经元连接到下一层。在卷积神经网络(CNN)中,它们构成网络的分类部分。全连接层中的神经元学习每个类别由哪些特征组成。

与卷积层相比,全连接层在参数拟合方面更为复杂。卷积层中一个 3×3 卷积核的滤波器有 9 个参数,与输入神经元的数量无关。一个包含 16 个神经元的全连接层,如果前一层有 28×28 个神经元,则其权重已经达到 28×28×16 = 12,544 个。

现在我们对 CNN 的不同组成部分更加熟悉了,你可能会想知道如何找到所有参数的正确值,即密集层中的滤波器内核和权重。

训练

与所有机器学习算法一样,训练也是基于已知类别标签的示例输入进行的。一个未经训练的卷积神经网络(CNN)会被初始化为随机参数。然后,我们可以将训练示例输入到网络中,并检查输出神经元的激活情况。基于预期的激活情况(即与正确类别相关的神经元完全激活,而其余神经元不激活),我们可以推导出一个成本函数,该函数可以衡量网络预测的误差程度。

然后我们可以开始调整参数以降低成本。这个过程从输出神经元开始,逐层调整参数,直到输入层。这个学习过程被称为反向传播。我们如何知道应该增加哪个参数、减少哪个参数,以及调整多少呢?

这里我不会赘述太多数学细节,但你可能还记得微积分里讲过,有些函数可以求导,导数可以告诉你函数的输出如何随输入变量的变化而变化。导数表示函数图像上切线的斜率。如果我们计算成本函数的导数,就能知道每个参数如何影响结果,使其更接近我们预期的类别标签。

由于我们的成本函数不仅包含一个输入变量,而是可能有成千上万个(想想即使是很小的全连接层也需要很多权重),我们可以利用所谓的梯度。梯度是多变量函数导数的推广。准确地说,我们需要使用负梯度,因为我们的目标是降低成本。负梯度会告诉我们如何调整网络参数才能更好地对训练样本进行分类。这种方法称为梯度下降

大多数情况下,计算所有训练样本的精确负梯度在计算上是不可行的。然而,我们可以采用一个小技巧:将输入数据打乱并分组为小批次。然后,我们仅针对这些小子集计算梯度,并相应地调整网络参数,然后继续处理下一个批次。这种所谓的随机梯度下降法可以给出足够好的精确解。

但请记住,通过沿梯度下降,我们只能在初始随机参数允许的范围内进行改进。如果不使用完全不同的权重,网络可能无法改进,从而陷入所谓的成本函数局部最小值。虽然存在一些避免陷入局部最小值的技术,但它们也各有缺点。

现在我们有了训练好的模型,可以输入没有标签的图像,然后查看输出结果来确定正确的类别。接下来,我们来看一个图像分类的“Hello World”示例,以及我基于此示例开发的小应用程序。

手写数字识别

数据

图像分类的“Hello World”示例看似简单,实则不然,它指的是对手写数字进行分类。修改后的国家标准与技术研究院数据库(简称MNIST数据库)中提供了一个丰富的训练和测试数据集,可在线免费获取。

每个数字都以 28×28 像素的灰度图像形式提供。下图展示了每个数字的一​​些示例图像。

MNIST

应用架构

为了构建一个可供用户使用和体验的程序,我的目标是开发一个 Web 应用程序,允许用户绘制数字并对其进行分类。我使用Deeplearning4j (DL4J) 来构建、训练、验证和应用模型。DL4J 是一个适用于 JVM 的开源深度学习库。下方是一个简单的架构图。

应用架构

该应用程序分为两部分:

  • 培训与验证
  • 预言

训练和验证过程离线进行。它从一个目录结构中读取数据,该结构已将数据分为训练数据和测试数据,并且每个数字都分别存储在各自的目录中。训练成功后,网络会被序列化并持久化到文件系统中model.zip。预测 API 会在启动时加载模型,并使用它来处理来自前端的请求。

在详细介绍各个组件之前,请注意,源代码已上传至 GitHub,并且该应用已上线,您可以通过Heroku进行试用。由于我目前仅使用了免费套餐,因此在应用运行一段时间后首次使用时,您可能需要稍等片刻,因为它会延迟启动服务器。

前端

前端是一个简单的 HTML5 Canvas,外加一些 JavaScript 代码用于将数据发送到后端。它的设计很大程度上受到了William Malone的教程《使用 HTML5 Canvas 和 JavaScript 创建绘图应用》的启发。如果您现在无法访问在线版本,可以查看下面的前端屏幕截图。

前端 4

它包含一个绘图画布、一个用于将画布内容发送到后端的按钮、一个用于清除画布的按钮以及一个用于显示分类结果的输出区域。它index.html并不复杂。以下是所使用的 HTML 元素:

<body>
    <div id="canvasDiv"></div>
    <div id="controls">
        <button id="predictButton" type="button">Predict</button>
        <button id="clearCanvasButton" type="button">Clear</button>
    </div>
    <div id="predictionResult">
    </div>
</body>
Enter fullscreen mode Exit fullscreen mode

然后我们添加一些 CSSapp.css样式,让它看起来更美观。JavaScript 代码app.js是基本的jQuery,没什么特别之处,非常原型化。它首先构建画布并定义绘制函数。预测是通过将画布内容发送到后端来实现的。一旦结果到达后端,我们就将其显示在输出中div

$('#predictButton').mousedown(function(e) {
  canvas.toBlob(function(d) {
  var fd = new FormData();
  fd.append('image', d)
    $.ajax({
      type: "POST",
      url: "predict",
      data: fd,
      contentType: false,
      processData: false
    }).done(function(o) {
      $('#predictionResult').text(o)
    });
  });
});
Enter fullscreen mode Exit fullscreen mode

后端

后端PredictAPI.scala是一个小型Akka HTTP Web 服务器。启动时,我们从磁盘加载模型。由于 DL4J 的默认模型实现不是线程安全的,因此我们必须将访问操作封装在一个同步代码块中。

val model = new SynchronizedClassifier(
  ModelSerializer.restoreMultiLayerNetwork("model.zip")
)
Enter fullscreen mode Exit fullscreen mode

静态文件有一条路由,即index.htmlapp.jsapp.css,还有一条路由用于接收数字图像以进行预测。

val route =
  path("") {
    getFromResource("static/index.html")
  } ~
  pathPrefix("static") {
    getFromResourceDirectory("static")
  } ~
  path("predict") {
    fileUpload("image") {
      case (fileInfo, fileStream) =>
        val in = fileStream.runWith(StreamConverters.asInputStream(3.seconds))
        val img = invert(MnistLoader.fromStream(in))
        complete(model.predict(img).toString)
    }
  }
Enter fullscreen mode Exit fullscreen mode

对于每一幅输入的图像,我们都需要进行一些基本的变换,例如调整大小和缩放,这些变换已在MnistLoad.fromStream方法中实现。此外,由于网络需要识别黑色背景上的白色数字,我们还需要对图像进行反转处理。

模型

该模型采用七层卷积神经网络(CNN),其设计很大程度上借鉴了DL4J的CNN代码示例。隐藏层由两对卷积池化层和一个全连接层组成。模型使用随机梯度下降法进行训练,每次训练处理64张图像。模型的测试准确率达到98%。

训练和验证过程已在[此处]实现TrainMain.scala。您也可以在那里找到具体的模型配置。目前我不想赘述太多细节,但如果您有任何关于模型架构的问题,欢迎留言。

使用 Heroku 进行部署

我选择使用 Heroku 部署应用程序,因为它能够快速地将应用程序公开部署,提供免费套餐,并且与开发工作流程完美集成。我使用的是Heroku CLI

对于使用 SBT 构建的 Scala 项目,Heroku 将执行 `sbatch` 命令sbt stage。这将生成应用程序的二进制文件及其所有库依赖项。`sbatch` 命令Procfile指定了如何启动应用程序。以下是部署到 Heroku 所需的命令。

  • heroku login(登录您的 Heroku 帐户)
  • heroku create(正在初始化heroku远程设备)
  • git push heroku master(推送更改,触发构建)
  • heroku open(在浏览器中打开应用程序网址)

问题

如果您尝试过该应用程序,可能会遇到一些奇怪的输出结果。事实上,即使模型准确率高达 98%,也存在多种问题可能导致您绘制的数字被错误分类。

其中一个原因是图像没有居中。虽然卷积层和池化子采样相结合有所帮助,但我怀疑将所有数字移动并调整大小到画布中心可以提高性能。为了获得最佳效果,请尝试将图像绘制在画布的下三分之二处。

此外,训练数据捕捉到了美国常见的一种特定手写风格。在世界其他地区,数字 1 由多条线组成,而美国人通常将其写成一条线。这可能导致写法不同的 1 被识别为 7。下图对此进行了说明。

美国1号

概括

本文介绍了如何使用卷积神经网络(CNN)对图像数据进行分类。通过结合近似优化技术、子采样和滤波应用,我们能够训练出一个深度网络,该网络可以很好地捕捉输入图像的特征。

只需少量 JavaScript、HTML 和 CSS 代码,即可开发一个用于绘制待分类图像的前端。后端可以使用类似 Akka HTTP 的 HTTP 服务器,并结合类似 DL4J 的深度学习框架来实现。

我们还发现,只有当真实数据与构建模型时使用的训练数据和测试数据完全一致时,实际应用中的分类性能才能与测试准确率相匹配。因此,在模型运行时持续监控其性能至关重要,需要定期调整或重新训练模型以保持较高的准确率。

参考

  • [1] Rigouste, L., Cappé, O. 和 Yvon, F., 2007. 文本聚类多项式混合模型的推断和评估。信息处理与管理,43(5),第 1260-1280 页。
  • [2] LeCun, Y., Bottou, L., Bengio, Y. 和 Haffner, P., 1998. 基于梯度的学习应用于文档识别。IEEE 会刊,86(11),第 2278-2324 页。
  • [3] Ciregan, D.、Meier, U. 和 Schmidhuber, J.,2012 年 6 月。用于图像分类的多列深度神经网络。载于 2012 年 IEEE 计算机视觉与模式识别会议 (CVPR) 论文集(第 3642-3649 页)。IEEE。
  • [4] Sobel, I., Feldman, G., 用于图像处理的 3x3 各向同性梯度算子,于 1968 年在斯坦福人工智能项目 (SAIL) 上发表。

如果你喜欢这篇文章,可以在 ko-fi 上支持我

文章来源:https://dev.to/frosnerd/handwriting-digit-recognition-using-convolutional-neural-networks-11g0