深度学习(1)-简单神经网络示例

news/2025/3/15 23:33:17/

我们来看一个神经网络的具体实例:使用Python的Keras库来学习手写数字分类。在这个例子中,我们要解决的问题是,将手写数字的灰度图像(28像素×28像素)划分到10个类别中(从0到9)​。我们将使用MNIST数据集,图2-1给出了MNIST数据集的一些样本。
在这里插入图片描述
在机器学习中,分类问题中的某个类别叫作类(class)​,数据点叫作样本(sample)​,与某个样本对应的类叫作标签(label)​。你不需要现在就尝试在计算机上运行这个例子。如果你想这么做,那么首先需要建立深度学习工作区(见第3章)​。MNIST数据集已预先加载在Keras库中,其中包含4个NumPy数组,如代码清单2-1所示。

from tensorflow.keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

train_images和train_labels组成了训练集,模型将从这些数据中进行学习。然后,我们在测试集(包括test_images和test_labels)上对模型进行测试。图像被编码为NumPy数组,而标签是一个数字数组,取值范围是0~9。图像和标签一一对应。我们来看一下训练数据:

>>> train_images.shape
(60000, 28, 28)
>>> len(train_labels)
60000
>>> train_labels
array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)

再来看一下测试数据:

>>> test_images.shape
(10000, 28, 28)
>>> len(test_labels)
10000
>>> test_labels
array([7, 2, 1, ..., 4, 5, 6], dtype=uint8)

工作流程如下:首先,将训练数据(train_images和train_labels)输入神经网络;然后,神经网络学习将图像和标签关联在一起;最后,神经网络对test_images进行预测,我们来验证这些预测与test_labels中的标签是否匹配。下面我们来构建神经网络,如代码清单2-2所示。

代码清单2-2 神经网络架构

from tensorflow import keras
from tensorflow.keras import layers
model = keras.Sequential([layers.Dense(512, activation="relu"),layers.Dense(10, activation="softmax")
])

**神经网络的核心组件是层(layer)**​。你可以将层看成数据过滤器:进去一些数据,出来的数据变得更加有用。具体来说,层从输入数据中提取表示——我们期望这种表示有助于解决手头的问题。大多数深度学习工作涉及将简单的层链接起来,从而实现渐进式的数据蒸馏(data distillation)​。深度学习模型就像是处理数据的筛子,包含一系列越来越精细的数据过滤器(也就是层)​。**本例中的模型包含2个Dense层,它们都是密集连接(也叫全连接)的神经层。第2层(也是最后一层)是一个10路softmax分类层,它将返回一个由10个概率值(总和为1)组成的数组。**每个概率值表示当前数字图像属于10个数字类别中某一个的概率。在训练模型之前,我们还需要指定编译(compilation)步骤的3个参数。优化器(optimizer)​:模型基于训练数据来自我更新的机制,其目的是提高模型性能。损失函数(loss function)​:模型如何衡量在训练数据上的性能,从而引导自己朝着正确的方向前进。在训练和测试过程中需要监控的指标(metric)​:本例只关心精度(accuracy)​,即正确分类的图像所占比例。

后面两章会详细介绍损失函数和优化器的确切用途。代码清单2-3展示了编译步骤。

代码清单2-3 编译步骤

model.compile(optimizer="rmsprop",loss="sparse_categorical_crossentropy",metrics=["accuracy"])

在开始训练之前,我们先对数据进行预处理,将其变换为模型要求的形状,并缩放到所有值都在[0, 1]区间。前面提到过,训练图像保存在一个uint8类型的数组中,其形状为(60000, 28, 28),取值区间为[0, 255]。我们将把它变换为一个float32数组,其形状为(60000, 28 * 28),取值范围是[0, 1]。下面准备图像数据,如代码清单2-4所示。

代码清单2-4 准备图像数据

train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype("float32") / 255
test_images = test_images.reshape((10000, 28 * 28))
test_images = test_images.astype("float32") / 255

现在我们准备开始训练模型。在Keras中,这一步是通过调用模型的fit方法来完成的——我们在训练数据上拟合(fit)模型,如代码清单2-5所示。
代码清单2-5 拟合模型

>>> model.fit(train_images, train_labels, epochs=5, batch_size=128)
Epoch 1/5
60000/60000 [===========================] - 5s - loss: 0.2524 - acc: 0.9273
Epoch 2/5
51328/60000 [=====================>.....] - ETA: 1s - loss: 0.1035 - acc: 0.9692

训练过程中显示了两个数字:一个是模型在训练数据上的损失值(loss)​,另一个是模型在训练数据上的精度(acc)​。我们很快就在训练数据上达到了0.989(98.9%)的精度。现在我们得到了一个训练好的模型,可以利用它来预测新数字图像的类别概率(见代码清单2-6)​。这些新数字图像不属于训练数据,比如可以是测试集中的数据。

代码清单2-6 利用模型进行预测

>>> test_digits = test_images[0:10]
>>> predictions = model.predict(test_digits)
>>> predictions[0]
array([1.0726176e-10, 1.6918376e-10, 6.1314843e-08, 8.4106023e-06,2.9967067e-11, 3.0331331e-09, 8.3651971e-14, 9.9999106e-01,2.6657624e-08, 3.8127661e-07], dtype=float32)

这个数组中每个索引为i的数字对应数字图像test_digits[0]属于类别i的概率。第一个测试数字在索引为7时的概率最大(0.99999106,几乎等于1)​,所以根据我们的模型,这个数字一定是7。

>>> predictions[0].argmax()
7
>>> predictions[0][7]
0.99999106

我们可以检查测试标签是否与之一致:

>>> test_labels[0]
7

平均而言,我们的模型对这种前所未见的数字图像进行分类的效果如何?我们来计算在整个测试集上的平均精度,如代码清单2-7所示。

代码清单2-7 在新数据上评估模型

>>> test_loss, test_acc = model.evaluate(test_images, test_labels)
>>> print(f"test_acc: {test_acc}")
test_acc: 0.9785

测试精度约为97.8%,比训练精度(98.9%)低不少。训练精度和测试精度之间的这种差距是过拟合(overfit)造成的。**过拟合是指机器学习模型在新数据上的性能往往比在训练数据上要差,**它是第4章的核心主题。第一个例子到这里就结束了。你刚刚看到了如何用不到15行Python代码构建和训练一个神经网络,对手写数字进行分类。在本章和第3章中,我们会详细了解这个例子中的每一个步骤及其原理。接下来,你将学到张量(输入模型的数据存储对象)​、张量运算(层的组成要素)与梯度下降(可以让模型从训练示例中进行学习)​。

需要记住的名词:
1.类
2.样本
3.标签
4.训练集
5.测试集
6.层(layer)
7.dense
8.softmax
9.损失函数
10.指标
11.过拟合


http://www.ppmy.cn/news/1573090.html

相关文章

IP 路由基础 | 路由条目生成 / 路由表内信息获取

注:本文为 “IP 路由” 相关文章合辑。 未整理去重。 IP 路由基础 秦同学学学已于 2022-04-09 18:44:20 修改 一. IP 路由产生背景 我们都知道 IP 地址可以标识网络中的一个节点,并且每个 IP 地址都有自己的网段,各个网段并不相同&#xf…

NPM如何更换淘宝镜像——Node.js国内镜像配置教程

在国内使用 npm 安装 Node.js 包时,由于网络环境的原因,下载速度可能非常慢。为了解决这个问题,很多开发者会选择使用淘宝镜像(现在由 npmmirror.com 维护)。本文将带你一步一步完成更换 npm 源为淘宝镜像的配置&#…

【架构设计】详解高可用架构

架构设计的愿景就是高可用、高性能、高扩展、高效率。为了实现架构设计四高愿景,需要实现自动化系统目标: 标准化。 流程自助化。 可视化:可观测系统各项指标、包括全链路跟踪。 自动化:ci/cd 自动化部署。 精细化&#xff1a…

Whisper+T5-translate实现python实时语音翻译

1.首先下载模型,加载模型 import torch import numpy as np import webrtcvad import pyaudio import queue import threading from datetime import datetime from faster_whisper import WhisperModel from transformers import AutoTokenizer, AutoModelForSeq2…

github用户名密码登陆失效了

问题: git push突然推代码需要登陆,但是用户名和密码正确输入后,却提示403 git push# Username for https://github.com: **** #Password for https://gyp-programmergithub.com: #remote: Permission to gyp-programmer/my-app.git denie…

【Elasticsearch】`nested`字段

Elasticsearch 的nested字段是一种强大的数据类型,用于处理嵌套对象数组,允许将每个对象独立索引和查询。以下是关于nested字段的详细说明: 1.nested字段的定义 nested字段是object数据类型的特殊版本,允许将对象数组索引为独立…

Linux下学【MySQL】中如何实现:多表查询(配sql+实操图+案例巩固 通俗易懂版~)

每日激励:“不设限和自我肯定的心态:I can do all things。 — Stephen Curry” 绪论​: 本章是MySQL篇中,非常实用性的篇章,相信在实际工作中对于表的查询,很多时候会涉及多表的查询,在多表查询…

七星棋牌全开源修复版源码解析:6端兼容,200种玩法全面支持

本篇文章将详细讲解 七星棋牌修复版源码 的 技术架构、功能实现、二次开发思路、搭建教程 等内容,助您快速掌握该棋牌系统的开发技巧。 1. 七星棋牌源码概述 七星棋牌修复版源码是一款高度自由的 开源棋牌项目,该版本修复了原版中的多个 系统漏洞&#…