【机器学习】036_权重衰退

news/2023/12/1 1:22:31

一、范数

· 定义:向量的范数表示一个向量有多大(分量的大小)

L1范数:

        · 即向量元素绝对值之和,用符号 ‖ v ‖ 1 表示。

        · 公式:\left \| x \right \|_1 = \sum_{n}^{i=1}|x_i|

L2范数:

        · 即向量的模,向量各元素绝对值的平方之和再开根号,用符号 ‖ v ‖ 2 表示。

        · 公式:\left \| x \right \|_2=\sqrt{\sum_{n}^{i=1}x_i^2}

Lp范数:

        · 即向量范数的一般形式,各元素绝对值的p次幂之和再开p次根号,用符号 ‖ v ‖ p 表示。

        · 公式:\left \| x \right \|_p = (\sqrt[p]{\sum_{n}^{i=1}|x|^p})

二、权重衰减(L2正则化)

模型(函数)复杂度的度量:

· 一般通过线性函数 f(x) = w^Tx 中的权重向量的某个范数(如 \left \| w \right \|^2)来度量其复杂度

要想避免模型的过拟合,就要控制模型容量,使模型的权重向量尽可能小

· 通过限制参数值的选择范围来控制模型容量

衰减方法:

借助损失函数,将权重范数作为惩罚项添加到最小化损失中;使得损失函数的作用变为“最小化预测损失和惩罚项之和”。

损失函数公式如下:

J(w,b)=L(w,b)+\frac{\lambda }{2}\left \| w \right \|^2

· 其中,L(w,b) 是模型原本的损失函数,\frac{\lambda }{2}\left \| w \right \|^2 是新添加的惩罚项。

· 正则化常数 \lambda 用来描绘这种权衡,其为一个非负超参数。

· \lambda 的值越大,表示对 w 的约束较大;反之 \lambda 的值越小,表示对 w 的约束较小。

※为何选用平方范数而不是标准范数:

        · 便于计算。平方范数可以去掉平方根使得导数更容易计算,利于反向传播过程。

        · 使用L2范数是因为它会对权重向量的大分量施加巨大的惩罚,使各权重均匀分布。

        · L1范数惩罚会导致权重集中在某一小部分特征上,其它权重被清除为0(特征选择)。

使用该损失函数,就可以使梯度下降的优化算法在训练的每一步都衰减权重,避免过拟合发生。

如上图所示,现在模型的损失函数同时受两项影响,一是误差项,二是惩罚项。

        现在在等高线图上,梯度下降最终收敛的位置不再是某一个项所造成的最低点,因为在这时,可能误差项达到最小了,但是惩罚项很大,使得惩罚项拉着损失函数再向另一个方向移动。

        只有当达到了两个项共同作用下的一个平衡点时,损失函数才具有最小值,这个时候的模型往往复杂度也降低了,虽然有可能造成训练损失增大,但是测试损失会减小。

三、代码实现权重衰减

从零实现代码如下:

import matplotlib
import torch
from torch import nn
from d2l import torch as d2l# 训练数据集、测试数据集、输入值、训练批次
n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
# 初始化w和b的真实值
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
# 拿到训练数据
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)# 初始化模型参数w和b
def init_params():w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)b = torch.zeros(1, requires_grad=True)return [w, b]
# 定义L2范数惩罚项
def l2_penalty(w):return torch.sum(w.pow(2)) / 2
# 实现训练代码,读入参数为兰姆达(正则化参数)
def train(lambd):w, b = init_params()net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_lossnum_epochs, lr = 100, 0.003animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:# 增加了L2范数惩罚项,# 广播机制使l2_penalty(w)成为一个长度为batch_size的向量l = loss(net(X), y) + lambd * l2_penalty(w)l.sum().backward()d2l.sgd([w, b], lr, batch_size)if (epoch + 1) % 5 == 0:animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数是:', torch.norm(w).item())
# 使用权重进行训练
train(lambd=3)

简洁实现代码如下:

import torch
from torch import nn
from d2l import torch as d2l# 训练数据集、测试数据集、输入值、训练批次
n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
# 初始化w和b的真实值
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
# 拿到训练数据
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)def train_concise(wd):net = nn.Sequential(nn.Linear(num_inputs, 1))for param in net.parameters():param.data.normal_()loss = nn.MSELoss(reduction='none')num_epochs, lr = 100, 0.003# 偏置参数没有衰减trainer = torch.optim.SGD([{"params":net[0].weight,'weight_decay': wd},{"params":net[0].bias}], lr=lr)animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:trainer.zero_grad()l = loss(net(X), y)l.mean().backward()trainer.step()if (epoch + 1) % 5 == 0:animator.add(epoch + 1,(d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数:', net[0].weight.norm().item())train_concise(3)

http://www.ppmy.cn/news/1231401.html

相关文章

虾皮选品免费工具:如何用知虾进行虾皮市场分析选品

在如今的电商时代,了解市场需求和选择热销产品是成功经营的关键。虾皮作为东南亚地区最大的电商平台之一,提供了一系列的选品工具,帮助卖家在市场竞争中脱颖而出。本文将介绍如何使用虾皮的免费工具——知虾进行虾皮市场分析选品,…

CentOS8部署Skywalking(非容器方式)

一、官网下载安装包 二、安装 #tar -zxf apache-skywalking-apm-9.6.0.tar.gz #mv apache-skywalking-apm-9.6.0 skywalking #cd /opt/skywalking 修改配置文件 #vi /opt/skywalking/config/application.yml #vi vi /opt/skywalking/webapp/application.yml 三、运行 ./bin…

【开源】基于Vue和SpringBoot的服装店库存管理系统

项目编号: S 052 ,文末获取源码。 \color{red}{项目编号:S052,文末获取源码。} 项目编号:S052,文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 数据中心模块2.2 角色管理模块2.3 服…

【Linux】 find命令使用

find find命令是一种通过条件匹配在指定目录下查找对应文件或者目录的工具。匹配的条件可以是文件名称、类型、大小、权限属性、时间戳等。find命令还可以配合相关命令对匹配到的文件作出后续处理。 语法 find [路径...] [表达式] [path...]为需要查找文件所指定的路径。如果…

MQ和redis的内部原理一些总结

首先,先知道内部原理;其次,就是查官方文档实战了。 但是如果不熟悉内部原理,那么仅仅只是安装官方文档,并不能排除跟踪问题和故障、预防风险等策略; 以下总结图解:(mysql 8.0新增的…

安防视频监控管理平台EasyCVR定制首页开发与实现

视频监控平台EasyCVR能在复杂的网络环境中,将分散的各类视频资源进行统一汇聚、整合、集中管理,在视频监控播放上,TSINGSEE青犀视频安防监控汇聚平台可支持1、4、9、16个画面窗口播放,可同时播放多路视频流,也能支持视…

【ARM AMBA AXI 入门 14 -- AXI 窄位传输 | 非对齐传输| 大小端传输】

请阅读【ARM AMBA AXI 总线 文章专栏导读】 文章目录 窄位传输 (Narrow Transfer)非对齐传输 (Unaligned Transfer)大小端传输 (Endianness Transfer)ARM AMBA AXI (Advanced eXtensible Interface) 是一个高性能、高带宽的总线接口,常用于连接高速微处理器核心与其它部件。在…

Python+Qt虹膜检测识别

程序示例精选 PythonQt虹膜检测识别 如需安装运行环境或远程调试,见文章底部个人QQ名片,由专业技术人员远程协助! 前言 这篇博客针对《PythonQt虹膜检测识别》编写代码,代码整洁,规则,易读。 学习与应用推…

网络连接Android设备

参考:https://blog.csdn.net/qq_37858386/article/details/123755700 二、网络adb调试开启步骤 1、把Android平板或者手机WiFi连接到跟PC机子同一个网段的网络,在设置-系统-关于-状态 下面查看设备IP,然后查看PC是否可以ping通手机的设备的IP。 2、先…

centos7 利用nc命令探测某个tcp端口是否在监听

脚本 # 安装nc yum install -y ncnc -vz 192.168.3.128 60001 if [ $? -eq 0 ]; thenecho "tcp succeed" elseecho "tcp failed" fi nc -vz 192.168.3.128 60001 探测192.168.3.128服务器上60001 tcp端口, -vz说明是探测TCP的 端口开启的情况 执行…

spark内置数据类型

在用scala编写spark的时候,假如我现在需要将我spark读的数据源的字段,做一个类型转换,因 为需求中要拼接出sql的create table语句,需要每个字段的sql中的类型,那么就需要去和sparksql 中的内置数据类型去比对。 写s…

Bandzip下载(好用的解压缩工具)

1.下载链接:Bandizip - Download Bandizip 6.x 2.点击 下载Bandzip 进行下载,下载到本地,直接安装即可

Maven-自定义插件

Maven自定义插件 一、背景二、命名规范三、插件开发四、执行插件1.执行插件2.简化命令行2.1 命令格式为 mvn groupId:artifactId:goal2.2 命令格式为 mvn ${prefix}:goal 五、构建周期执行插件总结参考链接: 一、背景 Maven是由一系列用于执行构建任务的插件和一个…

Spring Boot和Spring MVC的区别

1 Spring MVC 是Spring的一个模块,是一个web框架。分为Model,View,Controller(模型层、视图层、控制层)。 2 Spring Boot Spring Boot 自动配置,降低了项目搭建的复杂度。Spring框架需要大量的配置&…

Kubernetes Gateway API 攻略:解锁集群流量服务新维度!

Kubernetes Gateway API 刚刚 GA,旨在改进将集群服务暴露给外部的过程。这其中包括一套更标准、更强大的 API资源,用于管理已暴露的服务。在这篇文章中,我将介绍 Gateway API 资源,并以 Istio 为例来展示这些资源是如何关联的。通…

【开源】基于Vue和SpringBoot的高校宿舍调配管理系统

项目编号: S 051 ,文末获取源码。 \color{red}{项目编号:S051,文末获取源码。} 项目编号:S051,文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能需求2.1 学生端2.2 宿管2.3 老师端 三、系统…

应用软件安全编程--24不要使用硬编码密匙

当程序中使用硬编码加密密匙时,所有项目开发人员都可以查看该密匙,甚至如果攻击者能够获取 程序 class文件,可通过反编译得到密匙,硬编码加密密匙会大大降低系统安全性。 对于避免使用硬编码密匙的情况,示例1给出了不…

Vue3的watch使用介绍及场景

目录 一、watch的使用 1. 监听一个变量 2. 监听一个对象的属性 3. 监听一个函数的返回值 二、watch的使用场景 1. 监听表单的变化 2. 监听路由参数的变化 3. 监听Vuex中的数据变化 三、watch的效果图 四、watch的示例 以上就是Vue3的watch的介绍,watch是…

(付费请教,防遗忘记录)为什么EDEM bondV1 V2添加后还是没有胶结?????

最近用到了EDEM的胶结功能,在添加了bond V1和V2后,我发现无论颗粒半径设置为多少,接触半径设置为多少,胶结键的参数设置为多少,bond generated数量均为0!!!在后处理中bond中的力也通通为0!!我跟我的同门百思不得其解!!!!最终付费请教了一个大牛,解决了这个问题!…

纽扣电池/含纽扣电池产品上架亚马逊各国法规标准要求16 CFR 第 1700.15/20 ANSI C18.3M(瑞西法案认证)

亚马逊纽扣电池认证标准有哪些? 一、美国站(亚马逊纽扣电池/含纽扣电池商品)安全测试标准要求: 16 CFR 第 1700.15 、16 CFR 第 1700.20 ANSI C18.3M、警示标签声明要求(第 117-171 号公众法) 二、澳大…
最新文章