第N6周:使用Word2vec实现文本分类

news/2024/5/19 6:02:52/
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms,datasets
import os,PIL,pathlib,warnings
#忽略警告信息
warnings.filterwarnings("ignore")
# win10系统
device = torch.device("cuda"if torch.cuda.is_available()else"cpu")
deviceimport pandas as pd
# 加载自定义中文数据
train_data= pd.read_csv('./data/train2.csv',sep='\t',header=None)
train_data.head()# 构造数据集迭代器
def coustom_data_iter(texts,labels):for x,y in zip(texts,labels):yield x,y
x = train_data[0].values[:]
#多类标签的one-hot展开
y = train_data[1].values[:]from gensim.models.word2vec import Word2Vec
import numpy as np
#训练word2Vec浅层神经网络模型
w2v=Word2Vec(vector_size=100#是指特征向量的维度,默认为100。,min_count=3)#可以对字典做截断。词频少于min_count次数的单词会被丢弃掉,默认为5w2v.build_vocab(x)
w2v.train(x,total_examples=w2v.corpus_count,epochs=20)# 将文本转化为向量
def average_vec(text):vec =np.zeros(100).reshape((1,100))for word in text:try:vec +=w2v.wv[word].reshape((1,100))except KeyError:continuereturn vec
#将词向量保存为Ndarray
x_vec= np.concatenate([average_vec(z)for z in x])
#保存Word2Vec模型及词向量
w2v.save('data/w2v_model.pk1')train_iter= coustom_data_iter(x_vec,y)
len(x),len(x_vec)label_name =list(set(train_data[1].values[:]))
print(label_name)text_pipeline =lambda x:average_vec(x)
label_pipeline =lambda x:label_name.index(x)text_pipeline("你在干嘛")
label_pipeline("Travel-Query")from torch.utils.data import DataLoader
def collate_batch(batch):label_list,text_list=[],[]for(_text,_label)in batch:# 标签列表label_list.append(label_pipeline(_label))# 文本列表processed_text = torch.tensor(text_pipeline(_text),dtype=torch.float32)text_list.append(processed_text)label_list = torch.tensor(label_list,dtype=torch.int64)text_list = torch.cat(text_list)return text_list.to(device),label_list.to(device)
# 数据加载器,调用示例
dataloader = DataLoader(train_iter,batch_size=8,
shuffle =False,
collate_fn=collate_batch)from torch import nn
class TextclassificationModel(nn.Module):def __init__(self,num_class):super(TextclassificationModel,self).__init__()self.fc = nn.Linear(100,num_class)def forward(self,text):return self.fc(text)num_class =len(label_name)
vocab_size =100000
em_size=12
model= TextclassificationModel(num_class).to(device)import time
def train(dataloader):model.train()#切换为训练模式total_acc,train_loss,total_count =0,0,0log_interval=50start_time= time.time()for idx,(text,label)in enumerate(dataloader):predicted_label= model(text)# grad属性归零optimizer.zero_grad()loss=criterion(predicted_label,label)#计算网络输出和真实值之间的差距,labelloss.backward()#反向传播torch.nn.utils.clip_grad_norm(model.parameters(),0.1)#梯度裁剪optimizer.step()#每一步自动更新#记录acc与losstotal_acc+=(predicted_label.argmax(1)==label).sum().item()train_loss += loss.item()total_count += label.size(0)if idx % log_interval==0 and idx>0:elapsed =time.time()-start_timeprint('Iepoch {:1d}I{:4d}/{:4d} batches''|train_acc {:4.3f} train_loss {:4.5f}'.format(epoch,idx,len(dataloader),total_acc/total_count,train_loss/total_count))total_acc,train_loss,total_count =0,0,0start_time = time.time()
def evaluate(dataloader):model.eval()#切换为测试模式total_acc,train_loss,total_count =0,0,0with torch.no_grad():for idx,(text,label)in enumerate(dataloader):predicted_label= model(text)loss = criterion(predicted_label,label)# 计算loss值# 记录测试数据total_acc+=(predicted_label.argmax(1)== label).sum().item()train_loss += loss.item()total_count += label.size(0)return total_acc/total_count,train_loss/total_countfrom torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
# 超参数
EPOCHS=10#epoch
LR=5 #学习率
BATCH_SIZE=64 # batch size for training
criterion = torch.nn.CrossEntropyLoss()
optimizer= torch.optim.SGD(model.parameters(),lr=LR)
scheduler=torch.optim.lr_scheduler.StepLR(optimizer,1.0,gamma=0.1)
total_accu = None
# 构建数据集
train_iter= coustom_data_iter(train_data[0].values[:],train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)split_train_,split_valid_= random_split(train_dataset,[int(len(train_dataset)*0.8),int(len(train_dataset)*0.2)])
train_dataloader =DataLoader(split_train_,batch_size=BATCH_SIZE,
shuffle=True,collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_,batch_size=BATCH_SIZE,
shuffle=True,collate_fn=collate_batch)
for epoch in range(1,EPOCHS+1):epoch_start_time = time.time()train(train_dataloader)val_acc,val_loss = evaluate(valid_dataloader)# 获取当前的学习率lr =optimizer.state_dict()['param_groups'][0]['1r']if total_accu is not None and total_accu>val_acc:scheduler.step()else:total_accu = val_accprint('-'*69)print('|epoch {:1d}|time:{:4.2f}s |''valid_acc {:4.3f} valid_loss {:4.3f}I1r {:4.6f}'.format(epoch,time.time()-epoch_start_time,val_acc,val_loss,lr))print('-'*69)# test_acc,test_loss =evaluate(valid_dataloader)
# print('模型准确率为:{:5.4f}'.format(test_acc))
#
#
# def predict(text,text_pipeline):
#     with torch.no_grad():
#         text = torch.tensor(text_pipeline(text),dtype=torch.float32)
#         print(text.shape)
#         output = model(text)
#         return output.argmax(1).item()
# # ex_text_str="随便播放一首专辑阁楼里的佛里的歌"
# ex_text_str="还有双鸭山到淮阴的汽车票吗13号的"
# model=model.to("cpu")
# print("该文本的类别是:%s"%label_name[predict(ex_text_str,text_pipeline)])

以上是文本识别基本代码

输出:

[[-0.85472693  0.96605204  1.5058695  -0.06065784 -2.10079319 -0.120211511.41170089  2.00004494  0.90861696 -0.62710127 -0.62408304 -3.805954991.02797993 -0.45584389  0.54715634  1.70490362  2.33389823 -1.996075184.34822938 -0.76296186  2.73265275 -1.15046433  0.82106878 -0.32701646-0.50515595 -0.37742117 -2.02331601 -1.365334    1.48786476 -1.63949711.59438308  2.23569647 -0.00500725 -0.65070192  0.07377997  0.01777986-1.35580809  3.82080549 -2.19764423  1.06595343  0.99296588  0.58972518-0.33535255  2.15471306 -0.52244038  1.00874437  1.28869729 -0.72208139-2.81094289  2.2614549   0.20799019 -2.36187895 -0.94019454  0.49448857-0.68613767 -0.79071895  0.47535057 -0.78339124 -0.71336574 -0.279315671.0514895  -1.76352624  1.93158554 -0.85853558 -0.65540617  1.3612217-1.39405773  1.18187538  1.31730198 -0.02322496  0.14652854  0.222498812.01789951 -0.40144247 -0.39880068 -0.16220299 -2.85221207 -0.277228682.48236791 -0.51239379 -1.47679498 -0.28452797 -2.64497767  2.12093259-1.2326943  -1.89571355  2.3295732  -0.53244872 -0.67313893 -0.808146040.86987564 -1.31373079  1.33797717  1.02223087  0.5817025  -0.835356470.97088164  2.09045361 -2.57758138  0.07126901]]
6

输出结果并非为0


http://www.ppmy.cn/news/1405443.html

相关文章

20240328金融读报:国内金融安全网与银行适老化实例

1、国内金融安全网(原则:事前防范金融风险过度积累,事中、事后快速高效处置风险):1)强化金融机构的公司治理和风险管理(如重组与否)2)二加强金融监管(各种存贷…

视频剪辑软件哪个好?2024会声会影怎么样呢?

随着科技的不断发展,视频制作已经不再是专业人士的专属领域,越来越多的人开始使用各种视频制作软件来记录生活、创作内容。其中,会声会影是被广泛使用的一款视频制作软件,其旗舰版更是备受关注。 视频剪辑软件哪个好?…

SnapGene 5 for Mac 分子生物学软件

SnapGene 5 for Mac是一款专为Mac操作系统设计的分子生物学软件,以其强大的功能和用户友好的界面,为科研人员提供了高效、便捷的基因克隆和分子实验设计体验。 软件下载:SnapGene 5 for Mac v5.3.1中文激活版 这款软件支持DNA构建和克隆设计&…

C++的字节对齐

什么是字节对齐 参考什么是字节对齐,为什么要对齐? 现代计算机中,内存空间按照字节划分,理论上可以从任何起始地址访问任意类型的变量。但实际中在访问特定类型变量时经常在特定的内存地址访问,这就需要各种类型数据按照一定的规…

每日一题 --- 有效的括号[力扣][Go]

有效的括号 题目:20. 有效的括号 给定一个只包括 (,),{,},[,] 的字符串 s ,判断字符串是否有效。 有效字符串需满足: 左括号必须用相同类型的右括号闭合。左括号必须以正确的顺序…

【Python基础知识点】Python的浅拷贝和深拷贝

概述 本文主要通过两个简单的代码小例子理解深拷贝和浅拷贝 主体内容 copy 模块提供了浅拷贝和深拷贝的功能。它的主要函数有: copy(x): 返回对象 x 的浅拷贝。 deepcopy(x): 返回对象 x 的深拷贝。 浅拷贝使用 copy(x) 函数,它只复制了最外层的对象,但内层的对象仍然是引用…

EfficientVMamba实战:使用EfficientVMamba实现图像分类任务(一)

文章目录 摘要安装包安装timm 数据增强Cutout和MixupEMA项目结构编译安装Vim环境环境安装过程安装库文件 计算mean和std生成数据集 摘要 论文:https://arxiv.org/pdf/2401.09417v1.pdf 作者研究了轻量级模型设计的新方法,通过引入视觉状态空间模型&…

YOLOv5改进系列:升级版ResNet的新主干网络DenseNet

一、论文理论 论文地址:Densely Connected Convolutional Networks 1.理论思想 DenseNet最大化前后层信息交流,通过建立前面所有层与后面层的密集连接,实现了特征在通道维度上的复用,不但减缓了梯度消失的现象,也使其…

AI2.0时代如何快速落地AI智能应用开发,抓住时代机会

写在前面的话 当我们提到人工智能时也就是AI的时候呢,我们大多数人首先想到的可能就是像chatGPT这样的聊天机器人,这些聊天机器人通过理解,还有生成自然语言可以给我们提供一些信息,这个是AI最终的形态吗或者AI最终的形式吗&…

每日三个JAVA经典面试题(三十四)

1.Mybatis的一级、二级缓存 MyBatis提供了两种缓存机制来提高查询效率:一级缓存和二级缓存。 一级缓存(Session级别) 作用范围:一级缓存是基于SqlSession的。这意味着,如果你在同一个SqlSession中执行两次相同的查询…

数据可视化之折线图plot

import matplotlib.pyplot as plt plt.rcParams[font.family] [SimHei]# 查看matplotlibde文件地址# import matplotlib # print(matplotlib.matplotlib_fname()) # plt.rcParams[font.sans-serif] [SimHei] # 准备数据time [20200401,20200402,20200403,20200404,20200405…

nginx与tomcat的区别?

关于nginx和tomcat的概念 网上有很多关于nginx和tomcat是什么东西的定义,我总结了一下: tomcat是Web服务器、HTTP服务器、应用服务器、Servlet容器、web容器。 Nginx是Web服务器、HTTP服务器、正向/反向代理服务器,。 这里有两个概念是交叉的&#xff…

Springboot自动获取接口实现

ServiceLoader加载接口实现步骤 1.编写接口 public interface CommunicationAdapterFactory {void setKernel(LocalKernel kernel);boolean providesAdapterFor(Vehicle vehicle);BasicCommunicationAdapter getAdapterFor(Vehicle vehicle); }2.编写实现 // 实现类 1 publi…

ElasticSearch的常用数据类型

常见的数据类型 Text类型(文本数据类型) 用于索引全文值的字段,例如电子邮件的正文或产品的描述。这些字段是analyzed,也就是说,它们通过分析器传递,以便 在被索引之前将字符串转换为单个术语的列表。通过…

【算法】字典序超详细解析(让你有一种相见恨晚的感觉!)

目录 一、前言 二、什么是字典序 ? ✨字典序概念 ✨深度理解字典序 ✨字典序排序的重要性和应用场景 三、常考面试题 ✨ 下一个排列 ✨ 字典数排序 ✨ 字典序最小回文串 四、共勉 一、前言 经常刷算法题的朋友,肯定会经常看到题目中提到 字典序 这样…

on-my-zsh 命令自动补全插件 zsh-autosuggestions 安装和配置

首先 Oh My Zsh 是什么? Oh My Zsh 是一款社区驱动的命令行工具,正如它的主页上说的,Oh My Zsh 是一种生活方式。它基于 zsh 命令行,提供了主题配置,插件机制,已经内置的便捷操作。给我们一种全新的方式使用命令行。…

JWFD流程图转换为矩阵数据库的过程说明

在最开始设计流程图的时候,请务必先把开始节点和结束节点画到流程图上面,就是设计器面板的最开始两个按钮,先画开始点和结束点,再画中间的流程,然后保存,这样提交到矩阵数据库就不会出任何问题,…

视频监控/云存储/磁盘阵列/AI智能分析平台EasyCVR集成时调用接口报跨域错误是什么原因?

EasyCVR视频融合平台基于云边端架构,可支持海量视频汇聚管理,能提供视频监控直播、云端录像、云存储、录像检索与回看、智能告警、平台级联、智能分析等视频服务。平台兼容性强,支持多协议、多类型设备接入,包括:国标G…

蓝色wordpress外贸建站模板

蓝色wordpress外贸建站模板 https://www.mymoban.com/wordpress/7.html

Android 手机部署whisper 模型

Whisper 是什么? “Whisper” 是一个由OpenAI开发的开源深度学习模型,专门用于语音识别任务。这个模型能够将语音转换成文本,支持多种语言,并且在处理不同的口音、环境噪音以及跨语言的语音识别方面表现出色。Whisper模型的目标是提供一个高效、准确的工具,以支持自动字幕…