[oneAPI] 手写数字识别-GAN

news/2024/9/8 4:01:39/

[oneAPI] 手写数字识别-GAN

  • 手写数字识别
    • 参数与包
    • 加载数据
    • 模型
    • 训练过程
    • 结果
  • oneAPI

比赛:https://marketing.csdn.net/p/f3e44fbfe46c465f4d9d6c23e38e0517
Intel® DevCloud for oneAPI:https://devcloud.intel.com/oneapi/get_started/aiAnalyticsToolkitSamples/

手写数字识别

使用了pytorch以及Intel® Optimization for PyTorch,通过优化扩展了 PyTorch,使英特尔硬件的性能进一步提升,让手写数字识别问题更加的快速高效
在这里插入图片描述

使用MNIST数据集,该数据集包含了一系列以黑白图像表示的手写数字,每个图像的大小为28x28像素,数据集组成如下:

  • 训练集:包含60,000个图像和标签,用于训练模型。
  • 测试集:包含10,000个图像和标签,用于测试模型的性能。

每个图像都被标记为0到9之间的一个数字,表示图像中显示的手写数字。这个数据集常常被用来验证图像分类模型的性能,特别是在计算机视觉领域。

参数与包

import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_imageimport intel_extension_for_pytorch as ipex# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')# Hyper-parameters
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'samples'

加载数据

# Create a directory if not exists
if not os.path.exists(sample_dir):os.makedirs(sample_dir)# Image processing
# transform = transforms.Compose([
#                 transforms.ToTensor(),
#                 transforms.Normalize(mean=(0.5, 0.5, 0.5),   # 3 for RGB channels
#                                      std=(0.5, 0.5, 0.5))])
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5],  # 1 for greyscale channelsstd=[0.5])])# MNIST dataset
mnist = torchvision.datasets.MNIST(root='./data/',train=True,transform=transform,download=True)# Data loader
data_loader = torch.utils.data.DataLoader(dataset=mnist,batch_size=batch_size,shuffle=True)

模型

# Discriminator
D = nn.Sequential(nn.Linear(image_size, hidden_size),nn.LeakyReLU(0.2),nn.Linear(hidden_size, hidden_size),nn.LeakyReLU(0.2),nn.Linear(hidden_size, 1),nn.Sigmoid())# Generator 
G = nn.Sequential(nn.Linear(latent_size, hidden_size),nn.ReLU(),nn.Linear(hidden_size, hidden_size),nn.ReLU(),nn.Linear(hidden_size, image_size),nn.Tanh())

训练过程

# Device setting
D = D.to(device)
G = G.to(device)# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
D, d_optimizer = ipex.optimize(D, optimizer=d_optimizer)
G, g_optimizer = ipex.optimize(G, optimizer=g_optimizer)def denorm(x):out = (x + 1) / 2return out.clamp(0, 1)def reset_grad():d_optimizer.zero_grad()g_optimizer.zero_grad()# Start training
total_step = len(data_loader)
for epoch in range(num_epochs):for i, (images, _) in enumerate(data_loader):images = images.reshape(batch_size, -1).to(device)# Create the labels which are later used as input for the BCE lossreal_labels = torch.ones(batch_size, 1).to(device)fake_labels = torch.zeros(batch_size, 1).to(device)# ================================================================== ##                      Train the discriminator                       ## ================================================================== ## Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))# Second term of the loss is always zero since real_labels == 1outputs = D(images)d_loss_real = criterion(outputs, real_labels)real_score = outputs# Compute BCELoss using fake images# First term of the loss is always zero since fake_labels == 0z = torch.randn(batch_size, latent_size).to(device)fake_images = G(z)outputs = D(fake_images)d_loss_fake = criterion(outputs, fake_labels)fake_score = outputs# Backprop and optimized_loss = d_loss_real + d_loss_fakereset_grad()d_loss.backward()d_optimizer.step()# ================================================================== ##                        Train the generator                         ## ================================================================== ## Compute loss with fake imagesz = torch.randn(batch_size, latent_size).to(device)fake_images = G(z)outputs = D(fake_images)# We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))# For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdfg_loss = criterion(outputs, real_labels)# Backprop and optimizereset_grad()g_loss.backward()g_optimizer.step()if (i + 1) % 200 == 0:print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'.format(epoch, num_epochs, i + 1, total_step, d_loss.item(), g_loss.item(),real_score.mean().item(), fake_score.mean().item()))# Save real imagesif (epoch + 1) == 1:images = images.reshape(images.size(0), 1, 28, 28)save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))# Save sampled imagesfake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch + 1)))# Save the model checkpoints 
torch.save(G.state_dict(), 'G.ckpt')
torch.save(D.state_dict(), 'D.ckpt')

结果

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

oneAPI

import intel_extension_for_pytorch as ipex# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')# Device setting
D = D.to(device)
G = G.to(device)# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
D, d_optimizer = ipex.optimize(D, optimizer=d_optimizer)
G, g_optimizer = ipex.optimize(G, optimizer=g_optimizer)

http://www.ppmy.cn/news/1048780.html

相关文章

opencv 进阶10-人脸识别原理说明及示例-cv2.CascadeClassifier.detectMultiScale()

人脸识别是指程序对输入的人脸图像进行判断,并识别出其对应的人的过程。人脸识别程 序像我们人类一样,“看到”一张人脸后就能够分辨出这个人是家人、朋友还是明星。 当然,要实现人脸识别,首先要判断当前图像内是否出现了人脸&…

【React学习】React组件生命周期

1. 介绍 在 React 中,组件的生命周期是指组件从被创建到被销毁的整个过程。React框架提供了一系列生命周期方法,在不同的生命周期方法中,开发人员可以执行不同的操作,例如初始化状态、数据加载、渲染、更新等。一个组件的生命周期…

43.227.198.x怎么检查服务器里是否中毒情况?

要检查43.227.198.1服务器是否中毒,可以执行以下步骤: 运行杀毒软件:运行已安装的杀毒软件进行全盘扫描,查看是否有病毒或恶意软件。如果发现病毒或恶意软件,立即将其删除或隔离。 检查系统文件:检查服务器…

跑赢空间智能新时代,就在2023 SuperMap开发者大会

当前,3S技术加速融合,并不断融入IT技术,正在助力千行百业数字化转型,一线开发者们正在用指尖改造世界。尤其是以大模型为代表的新技术涌现,给GIS开发带来新方向的同时,也给开发者们带来了新挑战。开发者们亟…

Web自动化测试-Selenium语法入门到精通

前言 说到自动化测试,就不得不提大名鼎鼎的Selenium。Selenium 是如今最常用的自动化测试工具之一,支持快速开发自动化测试框架,且支持在多种浏览器上执行测试。 Selenium学习难度小,开发周期短。对测试人员来说,如果…

【BASH】回顾与知识点梳理(三十七)

【BASH】回顾与知识点梳理 三十七 三十七. 基础系统设定与备份策略37.1 系统基本设定网络设定 (手动设定与 DHCP 自动取得)手动设定 IP 网络参数(nmcli)自动取得 IP 参数(dhcp)修改主机名(hostnamectl) 37.2 日期与时间设定时区的显示与设定时间的调整用 ntpdate 手动网络校时 …

工业安全生产平台在面粉行业的应用分享

一、背景介绍 面粉行业是一个传统的工业行业,安全生产问题一直备受关注。然而,由于生产过程中存在的各种安全隐患和风险,如粉尘爆炸、机械伤害等,使得面粉行业的安全生产形势依然严峻。为了解决这一问题,工业安全生产…

认识米娜:一个真正健谈的聊天机器人

人工智能驱动的聊天机器人已被寻求简化客户服务、提高生产力和增加收入的企业广泛采用。在电子商务平台上,聊天机器人可以将客户引导至推荐的产品,跟踪订单,解释如何打印退货运输标签等。 然而,这样的聊天机器人在脱靶或无关紧要…

Unity 变量修饰符 之protected ,internal,const , readonly, static

文章目录 protectedinternalconstreadonlystatic protected 当在Unity中使用C#编程时,protected是一种访问修饰符,用于控制类成员(字段、方法、属性等)的可见性和访问权限。protected修饰的成员可以在当前类内部、派生类&#xf…

Mac更新homebrew时卡住的解决办法

Mac更新homebrew时卡住的解决办法 引起问题的原因brew命令安装软件跟这3个仓库地址有关1、brew2、homebrew-core3、homebrew-bottles4、若/bin/zsh,则输入5、若/bin/bash,则输入6、更新brew 引起问题的原因 知其然,还要知其所以然。brew的更…

day4 驱动开发

【ioctl函数的使用】 1.概述 linux有意将对设备的功能选择和设置以及硬件数据的读写分成不同的函数来实现。让read/write函数专注于数据的读写,而硬件功能的设备和选择通过ioctl函数来选择 2.ioctl函数分析 int ioctl(int fd,unsigned long request) 通过&…

有没有免费格式转换工具推荐?PDF转化为PPT的方法

在当今职场生活中,掌握文件格式转换技能变得异常重要。将PDF文档转换为PPT格式可以在演讲、报告等场合更好地展示和传达信息,为我们的专业形象增添亮点,接下来我们可以一起来看一下“有没有免费格式转换工具推荐?PDF转化为PPT的方法”相关的…

ping使用方法

文章目录 1、Ping的基础知识2、Ping命令详解3、怎样使用Ping这命令来测试网络连通?4、如何用Ping命令来判断一条链路好坏?5、对Ping后返回信息的分析1.Request timed out2.Destination host Unreachable 1、Ping的基础知识 ping命令相信大家已经再熟悉不…

Java算法_ BST 中第 k 个最小元素 (LeetCode_Hot100)

题目描述:给定一个二叉搜索树的根节点 ,和一个整数 ,请你设计一个算法查找其中第 个最小元素(从 1 开始计数)。 获得更多?算法思路:代码文档,算法解析的私得。 运行效果 完整代码 /*** 2 * Aut…

数据在内存中的储存·大小端(文字+画图详解)(c语言·超详细入门必看)

前言:Hello,大家好,我是心跳sy😘,本节我们介绍c语言的两种基本的内置数据类型:数值类型和字符类型在内存中的储存方法,并对大小端进行详细介绍(附两种大小端判断方法)&am…

sql性能优化的相关面试专题

1.比如,现在有个面试官说,现在线上有个SQL执行很慢,你怎么优化? 这种时候最好分几步回答,不要一上来就说,该怎么怎么写SQL,面试时要学会,跳出来,看全貌,装进去&#xf…

浅析DIX与DIF(T10 PI)

文章目录 概述DIF与DIX端到端数据保护 DIFDIF保护类型 SCSI设备支持DIFStandard INQUIRY DataExtended INQUIRY Data VPD pageSPT字段GRD_CHK、APP_CHK、REF_CHK字段 READ CAPACITY(16)响应信息 SCSI命令请求读命令请求写命令请求 DIF盘格式化相关参考 概述 DIF与DIX DIF&…

前端框架学习-React(一)

React 应用程序是由组件组成的。 react 程序是用的jsx语法,使用这种语法的代码需要由babel进行解析,解析成js代码。 jsx语法: 只能返回一个根元素 所有的标签都必须闭合(自闭和或使用一对标签的方式闭合) 使用驼峰式…

微信删除的聊天记录怎么恢复?满满干货,建议收藏!

微信的出现逐渐改变了我们的社交方式,它架起了我们与朋友、家人以及同事之间的沟通桥梁,成为我们生活中不可缺失的一部分。 但是总会有那么点意外会发生,比如自己和朋友吵架了,一怒之下将朋友删除,导致所有聊天记录都…

解锁数据潜力:信息抽取、数据增强与UIE的完美融合

解锁数据潜力:信息抽取、数据增强与UIE的完美融合 1.信息抽取(Information Extraction) 1.1 IE简介 信息抽取是 NLP 任务中非常常见的一种任务,其目的在于从一段自然文本中提取出我们想要的关键信息结构。 举例来讲&#xff0…