(论文复现)DeepAnt模型复现及应用

news/2024/4/20 16:42:41/

DeepAnt论文如下,其主要是用于时间序列的无监督粗差探测

 

其提出的模型架构如下:

        该文提出了一个无监督的时间序列粗差探测模型,其主要有预测模块和探测模块组成,其中预测模块的网络结构如下。
       预测结构是将时间序列数据组织成数据集之后经过两次的卷积和最大池化,最后将卷积结果通过一个全连接层转换为一个输出数据(若是单步预测则输出单元个数为1)
       探测模块是将模型的时序预测结果与该时刻的观测数据相比来计算欧氏距离,以此来作为当前时间点距离的异常分数。以此来作为数据粗差探测的标准。

        (本博客主要是分享复现代码,论文中的细节原理可自行下载学习)


 复现代码(数据不便分享):

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader,TensorDataset
from sklearn.preprocessing import MinMaxScaler,StandardScalerdef MSE(arr1,arr2):arr1,arr2 = np.array(arr1).flatten(),np.array(arr2).flatten()assert arr1.shape[0] == arr2.shape[0]return np.sum(np.power(arr1-arr2,2)) / arr1.shape[0]def MAE(arr1,arr2):arr1,arr2 = np.array(arr1).flatten(),np.array(arr2).flatten()assert arr1.shape[0] == arr2.shape[0]return np.sum(np.abs(arr1-arr2)) / arr1.shape[0]class MyData(Dataset):def __init__(self,arr,history_window,predict_len) -> None:self.length = arr.flatten().shape[0]self.history_window = history_windowself.dataset_x,self.dataset_y = self.get_dataset(arr,history_window,predict_len)def get_dataset(self,arr,history_window,predict_len):arr = np.array(arr).flatten()N = history_windowM = predict_lendataset_x = np.zeros((arr.shape[0] - N,N))dataset_y = np.zeros((arr.shape[0] - N,M))for i in range(arr.shape[0] - N):dataset_x[i] = arr[i:i+N]dataset_y[i] = arr[i+N:i+N+M]dataset_x = torch.from_numpy(dataset_x).to(torch.float)dataset_y = torch.from_numpy(dataset_y).to(torch.float)return (dataset_x,dataset_y)def __getitem__(self, index):		# 定义方法 data[i] 的返回值return (self.dataset_x[index,:],self.dataset_y[index,:])def __len__(self):					# 获取数据集样本个数return self.length - self.history_windowclass DeepAnt(nn.Module):def __init__(self,lag,p_w):super().__init__()self.convblock1 = nn.Sequential(nn.Conv1d(in_channels=1, out_channels=32, kernel_size=3, padding='valid'),nn.ReLU(inplace=True),nn.MaxPool1d(kernel_size=2))self.convblock2 = nn.Sequential(nn.Conv1d(in_channels=32, out_channels=32, kernel_size=3, padding='valid'),nn.ReLU(inplace=True),nn.MaxPool1d(kernel_size=2))self.flatten = nn.Flatten()self.denseblock = nn.Sequential(nn.Linear(32, 40), # for lag = 10#nn.Linear(96, 40), # for lag = 20#nn.Linear(192, 40), # for lag = 30nn.ReLU(inplace=True),nn.Dropout(p=0.25),)self.out = nn.Linear(40, p_w)def forward(self, x):x = x.view(-1,1,lag)x = self.convblock1(x)x = self.convblock2(x)x = self.flatten(x)x = self.denseblock(x)x = self.out(x)return xdef Train(model,data_set,EPOCH,task_id):if torch.cuda.is_available():device = torch.device('cuda')print('cuda is used...')else:torch.device('cpu')print('cpu is used...')scale = StandardScaler()loss_fn = nn.MSELoss()model.to(device)loss_fn.to(device)train_x,train_y = data_set.dataset_x,data_set.dataset_ytrain_x = scale.fit_transform(train_x)train_x = torch.from_numpy(train_x).to(torch.float).to(device)train_y = train_y.to(device).to(torch.float)torch_dataset = TensorDataset(train_x,train_y)optimizer = torch.optim.Adam(model.parameters())BATCH_SIZE = 100model = model.train()train_loss = []print('======Start training...=======')print(f'Epoch is {EPOCH}\ntrain_x shape is {train_x.shape}\nBATCH_SIZE is {BATCH_SIZE}')for i in range(EPOCH):loader = DataLoader(dataset=torch_dataset,batch_size=BATCH_SIZE,shuffle=True)temp_1 = []for step,(batch_x,batch_y) in enumerate(loader):out = model(batch_x)optimizer.zero_grad()loss = loss_fn(out,batch_y)temp_1.append(loss.item())loss.backward()optimizer.step()torch.cuda.empty_cache()train_loss.append(np.mean(np.array(temp_1)))if i % 10 == 0:print(f"The {i}/{EPOCH} is end, loss is {np.round(np.mean(np.array(temp_1)),6)}.")print('========Training end...=======')model = model.eval()plt.plot(train_loss)pred = model(train_x).cpu().data.numpy()print(f'pred shape {pred.shape}')plt.figure()y = train_y.cpu().data.numpy().flatten()print(f'y shape {y.shape}')plt.plot(y,c='b',label='True')plt.plot(pred,'r',label='pred')plt.legend()plt.title('Train_result')plt.show()return predif __name__ == "__main__":data_f = pd.read_csv('HF05_processed.csv')data = np.array(pd.DataFrame(data_f)['OT'])lag = 10dataset = MyData(data,lag,1)model = DeepAnt(lag,1)res = Train(model,dataset,200,'1')data = data[lag:].flatten() plt.plot(data)plt.plot(res,c='r')err = data - res.flatten()anomaly_score = np.sqrt(np.power(err,2))plt.figure()plt.plot(anomaly_score)error_list = []threshold = 0.04for i in range(anomaly_score.shape[0]):if anomaly_score[i] > threshold:error_list.append(i)print(len(error_list))plt.figure()plt.plot(data)plt.plot(error_list,[data[i] for i in error_list],ls='',marker='x',c='r',markersize=4)plt.show()

运行结果:

 

才疏学浅,敬请指正!

欢迎交流:

邮箱:rton.xu@qq.com

QQ:2264787072


http://www.ppmy.cn/news/1012680.html

相关文章

C++/Linux项目——日志系统(简介)

一,日志系统的目的 1.⽣产环境的产品为了保证其稳定性及安全性是不允许开发⼈员附加调试器去排查问题, 可以借助⽇志系统来打印⼀些⽇志帮助开发⼈员解决问题 2.上线客⼾端的产品出现bug⽆法复现并解决, 可以借助⽇志系统打印⽇志并上传到服…

JavaScript--WebStorage

目录 WebStorage概述 WebStorage分类 注意: localStorage方法 介绍: 常见方法: 案例演示: sessionStorage方法 介绍: 常见方法: 案例演示: WebStorage概述 WebStorage是HTML5中…

生成2×2 或3*3 混淆矩阵(confusion matrix)的python代码

该代码可以生成22的混淆矩阵。每个矩阵对应的数值可以自行改变。 代码如下: import numpy as np import matplotlib.pyplot as plt# 随机生成值 import numpy as np import matplotlib.pyplot as plt# 创建一个2x2的二分类数据矩阵。这里可以手动改变值 data np…

Cpp学习——string(2)

目录 ​编辑 容器string中的一些函数 1.capacity() 2.reserve() 3.resize() 4.push_back()与append() 5.find系列函数 容器string中的一些函数 1.capacity() capacity是string当中表示容量大小的函数。但是string开空间时是如何开的呢?现在就来看一下。先写…

CSS调色网有哪些

本文章转载于湖南五车教育,仅用于学习和讨论,如有侵权请联系 1、https://webgradients.com/ Wbgradients 是一个在线调整渐变色的网站 ,可以根据你想要的调整效果,同时支持复制 CSS 代码,可以更好的与开发对接。 Wbg…

微信小程序多图片上传实用代码记录

微信小程序多图片上传实用代码记录 由于在小程序中,wx.uploadFile 只能一次上传一张图片,因此在一次需要上传多张图片的应用场景中例如商品图片上传、评论图片上传等场景下,不得不使用for等循环上传每一张图片,多次调用wx.upload…

替换开源LDAP,某科技企业用宁盾目录统一身份,为业务敏捷提供支撑

客户介绍 某高科技企业成立于2015年,是一家深耕于大物流领域的人工智能公司,迄今为止已为全球16个国家和地区,120余家客户打造智能化升级体验,场景覆盖海陆空铁、工厂等货运物流领域。 该公司使用开源LDAP面临的挑战 挑战1 开源…

vim学习笔记(致敬vim作者)

vim cheat sheet 30. vim 删除大法 vim 删除某个字符之后改行的其他的字符?删除某行之后的其他行?删除某个字符之后的其他字符?【1】删除单个字符? 跳到要删除的字符位置 按下d键然后按下shift 4键 【2】删除某行之后的其他行…

AlmaLinux 9 安装 Edge 和 Chrome

AlmaLinux 9 安装 Edge 和 Chrome 1. 安装 Edge2. 安装 Chrome 1. 安装 Edge 更新源, sudo dnf update -y # sudo dnf install dnf-utils -y添加 Edge 源, sudo dnf config-manager --add-repo https://packages.microsoft.com/yumrepos/edge再次更新…

MySQL alter命令修改表详解

目录 ALTER TABLE 语法 ALTER TABLE 实例 添加一列 添加多列 重命名列 修改列定义 修改列名和定义 添加主键 删除列 重命名表 修改表的存储引擎 结论 在使用表的过程中,如果您需要对表进行修改,您可以使用 ALTER TABLE 语句。通过 ALTER TAB…

GitHub中readme.md文件的编辑和使用

GitHub中readme.md文件的编辑和使用 | YuuiChungs BlogGitHub - guodongxiaren/README: README文件语法解读,即Github Flavored Markdown语法介绍

Node.js-npm包管理工具的介绍

一、概念 包,代表一组特定功能的源码集合。 包管理工具,管理包的应用软件,可以下载安装、更新、删除包等操作,在项目开发中大大提高开发效率。 npm全称:Node Package Manager 二、npm使用 如果安装了 node,…

ruoyi若依 组织架构设计--[ 角色管理 ]

ruoyi若依 组织架构设计--[ 角色管理 ] 角色新增后端代码 角色修改后端代码 角色查询角色删除角色分配数据权限后端代码 角色分配用户 角色新增 后端代码 有一点,我认为新增的时候,也需要修改redis中的权限。 角色修改 后端代码 因为修改了role_menu表了…

什么都能画的大模型来了~飞星大模型AI绘画已上线,站在SD巨人的肩膀对标MJ

飞星大模型 原文地址:(原文有所有的效果图片) 什么都能画的大模型来了~飞星大模型AI绘画已上线,站在SD巨人的肩膀对标MJ 专门为飞链云用户准备,拥有上千万版权画作,使用全球最顶尖的GPU服务器,经过数百个日夜训练而成…

算法练习工程1.1

最长公共前缀 题目说明: * 编写一个函数来查找字符串数组中的最长公共前缀。如果不存在公共前缀,返回空字符串 ""。 示例 1: * 输入:strs ["flower","flow","flight"] * …

算法-岛屿数量

给你一个由 1(陆地)和 0(水)组成的的二维网格,请你计算网格中岛屿的数量。 岛屿总是被水包围,并且每座岛屿只能由水平方向和/或竖直方向上相邻的陆地连接形成。 此外,你可以假设该网格的四条边…

冒泡排序 简单选择排序 插入排序 快速排序

bubblesort 两个for循环&#xff0c;从最右端开始一个一个逐渐有序 #include <stdio.h> #include <string.h> #include <stdlib.h>void bubble(int *arr, int len); int main(int argc, char *argv[]) {int arr[] {1, 2, 3, 4, 5, 6, 7};int len sizeof(…

人大金仓数据库Docker部署

docker 搭建 yum -y install yum-utilsyum-config-manager --add-repo http://mirrors.aliyun.com/docker-ce/linux/centos/docker-ce.reposystemctl start docker.servicesystemctl enable docker.servicesystemctl status docker.service 配置Docker cd /etc/docker/ vi da…

HBase Shell 操作

1、基本操作 1.1、进入HBase客户端命令行 前提是先启动hadoop集群和zookeeper集群。 bin/hbase shell 1.2、查看帮助命令 helphelp 查看指定命令的语法规则 查看 list_namespace 的用法&#xff08;‘记得加单引号’&#xff09; help list_namespace 2、namespace 我们…

webpack基础知识九:如何提高webpack的构建速度?

一、背景 随着我们的项目涉及到页面越来越多&#xff0c;功能和业务代码也会随着越多&#xff0c;相应的 webpack 的构建时间也会越来越久 构建时间与我们日常开发效率密切相关&#xff0c;当我们本地开发启动 devServer 或者 build 的时候&#xff0c;如果时间过长&#xff…