[oneAPI] 手写数字识别-BiLSTM

news/2023/11/30 11:37:52

[oneAPI] 手写数字识别-BiLSTM

  • 手写数字识别
    • 参数与包
    • 加载数据
    • 模型
    • 训练过程
    • 结果
  • oneAPI

比赛:https://marketing.csdn.net/p/f3e44fbfe46c465f4d9d6c23e38e0517
Intel® DevCloud for oneAPI:https://devcloud.intel.com/oneapi/get_started/aiAnalyticsToolkitSamples/

手写数字识别

使用了pytorch以及Intel® Optimization for PyTorch,通过优化扩展了 PyTorch,使英特尔硬件的性能进一步提升,让手写数字识别问题更加的快速高效
在这里插入图片描述

使用MNIST数据集,该数据集包含了一系列以黑白图像表示的手写数字,每个图像的大小为28x28像素,数据集组成如下:

  • 训练集:包含60,000个图像和标签,用于训练模型。
  • 测试集:包含10,000个图像和标签,用于测试模型的性能。

每个图像都被标记为0到9之间的一个数字,表示图像中显示的手写数字。这个数据集常常被用来验证图像分类模型的性能,特别是在计算机视觉领域。

参数与包

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transformsimport intel_extension_for_pytorch as ipex# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')# Hyper-parameters
sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 2
learning_rate = 0.003

加载数据

# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='../../data/',train=True,transform=transforms.ToTensor(),download=True)test_dataset = torchvision.datasets.MNIST(root='../../data/',train=False,transform=transforms.ToTensor())# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False)

模型

# Bidirectional recurrent neural network (many-to-one)
class BiRNN(nn.Module):def __init__(self, input_size, hidden_size, num_layers, num_classes):super(BiRNN, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)self.fc = nn.Linear(hidden_size * 2, num_classes)  # 2 for bidirectiondef forward(self, x):# Set initial statesh0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(device)  # 2 for bidirectionc0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(device)# Forward propagate LSTMout, _ = self.lstm(x, (h0, c0))  # out: tensor of shape (batch_size, seq_length, hidden_size*2)# Decode the hidden state of the last time stepout = self.fc(out[:, -1, :])return out

训练过程

model = BiRNN(input_size, hidden_size, num_layers, num_classes).to(device)# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
model, optimizer = ipex.optimize(model, optimizer=optimizer)# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):images = images.reshape(-1, sequence_length, input_size).to(device)labels = labels.to(device)# Forward passoutputs = model(images)loss = criterion(outputs, labels)# Backward and optimizeoptimizer.zero_grad()loss.backward()optimizer.step()if (i + 1) % 100 == 0:print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))# Test the model
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.reshape(-1, sequence_length, input_size).to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

结果

在这里插入图片描述

oneAPI

import intel_extension_for_pytorch as ipex# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')# 模型
model = ConvNet(num_classes).to(device)# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
model, optimizer = ipex.optimize(model, optimizer=optimizer)

http://www.ppmy.cn/news/1037600.html

相关文章

单片机之从C语言基础到专家编程 - 4 C语言基础 - 4.8 运算符

1.算术运算符 运算符名称备注加法运算符双目运算,a b-减法运算符双目运算,a - b*乘法运算符双目运算,a * b/除法运算符双目运算,a / b%求余运算符双目运算, a % b自增运算符单目运算, a–自减运算符单目运算, a– 2.关系运算符…

C# 读取pcd点云文件数据

pcd文件有ascii 和二进制格式,ascii可以直接记事本打开,C#可以一行行读。但二进制格式的打开是乱码,如果尝试程序中读取,对比下看了数据也对不上。 这里可以使用pcl里的函数来读取pcd,无论二进制或ascii都可以正确读取…

每日一题之跳台阶

跳台阶 问题描述: 一只青蛙一次可以跳上 1 级台阶,也可以跳上 2 级。求该青蛙跳上一个 n 级的台阶总共有多少种跳法(先后次序不同算不同的结果)。 要求:时间复杂度: O ( n ) O(n) O(n),空间复杂…

【设计模式】模板方法模式(Template Method Pattern)

23种设计模式之模板方法模式(Template Method Pattern) 基本概念 模板方法模式是一种行为型设计模式,它定义了一个算法骨架,将某些算法步骤的实现延迟到子类中。 这样可以使得算法的框架不被修改,但是具体的实现可以…

神经网络基础-神经网络补充概念-35-为什么正则化可以减少过拟合

概念 正则化可以减少过拟合的原因在于它通过限制模型的复杂性来约束参数的取值范围,从而提高了模型的泛化能力。过拟合是指模型在训练集上表现很好,但在未见过的数据上表现不佳,这通常是因为模型过于复杂,过多地拟合了训练数据中…

Kafka 入门到起飞 - 什么是 HW 和 LEO?何时更新HW和LEO呢?

上文我们已经学到, 一个Topic(主题)会有多个Partition(分区)为了保证高可用,每个分区有多个Replication(副本)副本分为Leader 和 Follower 两个角色,Follower 从Leader同…

程序的机器级表示

程序的机器级表示 程序编码数据格式访问信息的方式 所有的高级语言,都会被计算机翻译为机器代码,然后再根据汇编代码生成可执行的机器代码。二进制的机器代码我们人类肯定是读不懂了,但是汇编代码还是可以简单了解一下的。CPU 的 PC、寄存器、…

Java基础十一(面向对象OOP)

创建一个学生类 编写一个名为 Student 的类, 包含以下属性和方法: 属性:姓名(name)、年龄(age)、学号(studentId)、成绩(score)方法&#xff1a…

Lucene教程_编程入门自学教程_菜鸟教程-免费教程分享

教程简介 Lucene是apache软件基金会 jakarta项目组的一个子项目,是一个开放源代码的全文检索引擎工具包,但它不是一个完整的全文检索引擎,而是一个全文检索引擎的架构,提供了完整的查询引擎和索引引擎,部分文本分析引…

2023-08-17 Untiy进阶 C#知识补充7——C#8主要功能与语法

文章目录 一、Using 声明二、静态本地函数三、Null 合并赋值四、解构函数 Deconstruct五、模式匹配增强功能 ​ 注意:在此仅提及 Unity 开发中会用到的一些功能和特性,对于不适合在 Unity 中使用的内容会忽略。 ​ C# 8 对应 Unity 版本: Un…

kubeasz在线安装K8S集群单master集群(kubeasz安装之二)

一、介绍 Kubeasz 是一个基于 Ansible 自动化工具,用于快速部署和管理 Kubernetes 集群的工具。它支持快速部署高可用的 Kubernetes 集群,支持容器化部署,可以方便地扩展集群规模,支持多租户,提供了强大的监控和日志分…

操作系统的介绍

简介: 服务器操作系统可以实现对计算机硬件与软件的直接控制和管理协调。任何计算机的运行离不开操作系统,服务器也一样。服务器操作系统主要分为四大流派:Windows Server、Netware、Unix、Linux。 分类: Windows Server 重要…

Linux学习之firewallD

systemctl status firewalld.service查看一下firewalld服务的状态,发现状态是inactive (dead)。 systemctl start firewalld.service启动firewalld,systemctl status firewalld.service查看一下firewalld服务的状态,发现状态是active (runni…

vue2中$attrs 和 $listeners的使用

功能: 用于实现多层组件通信(多于父子组件通信层级)用法: v-bind$attrs 和 v-on$listeners vm. a t t r s 包含了父作用域中不作为 p r o p 被识别 ( 且获取 ) 的 a t t r i b u t e 绑定 ( c l a s s 和 s t y l e 除外 ) 。当一…

【数据结构与算法】普里姆算法

普里姆算法 最小生成树 最小生成树,简称MST。 给定一个带权的无向连通图,如何选取一棵生成树,使树上所有边上权的总和为最小,这就叫最小生成树。N 个顶点,一定有 N - 1 条边半酣全部顶点N - 1 条边都在图中举例说明…

在K8s上处理nginx

基本说明 创建一个名为ssl的TLS类型的Secret对象,用于存储证书和密钥信息。 kubectl create secret tls ssl --certserver.crt --keyserver.key配置Nginx的events块,设置worker连接数为1024。 events {worker_connections 1024; }配置Nginx的http块&a…

如何用树莓派Pico针对IoT编程?

目录 一、Raspberry Pi Pico 系列和功能 二、Raspberry Pi Pico 的替代方案 三、对 Raspberry Pi Pico 进行编程 硬件 软件 第 1 步:连接计算机 第 2 步:在 Pico 上安装 MicroPython 第 3 步:为 Thonny 设置解释器 第 4 步&#xff…

Java内存区域(运行时数据区域)和内存模型(JMM)

Java 内存区域和内存模型是不一样的东西,内存区域是指 Jvm 运行时将数据分区域存储,强调对内存空间的划分。 而内存模型(Java Memory Model,简称 JMM )是定义了线程和主内存之间的抽象关系,即 JMM 定义了 …

Spark SQL优化:NOT IN子查询优化解决

背景 有如下的数据查询场景。 SELECT a,b,c,d,e,f FROM xxx.BBBB WHERE dt ${zdt.addDay(0).format(yyyy-MM-dd)} AND predict_type not IN ( SELECT distinct a FROM xxx.AAAAAWHERE dt ${zdt.addDay(0).format(yyyy-MM-dd)} ) 分析 通过查看SQL语句的执行计划基本…

python 条件编译如何写

在Python中,条件编译通常是通过预处理指令来实现的。与其他编程语言不同,Python没有像C或C那样的预处理器,但您可以使用一些技巧来模拟条件编译的效果。以下是一种在Python中模拟条件编译的常见方法: # 定义一个条件变量&#xf…
最新文章