Ep5 线性模型with Pytorch

news/2023/12/2 11:11:46

1、流程

确定数据集、

设计模型(算出预测值)、

构建损失函数(最终为一个标量值,只有标量才能用backward)和优化器、

训练周期(forward算loss,backward算grad,update更新wi)

2、numpy的广播机制

y=wx+b中w的扩充

3、如何使用Mudule类来自定义一个模型

创建模型有两个要素:构建子模块和拼接子模块。在`__init__()` 方法里创建子模块,在`forward()`方法里拼接子模块。

4、torch.nn

包含用来搭建各个层的模块(Modules)、一系列有用的loss函数、常用的激活函数等。

我们在定义自已的网络的时候,需要继承nn.Module类,并重新实现构造函数__init__和forward这两个方法

(1)一般把网络中具有可学习参数的层(如全连接层、卷积层等)放在构造函数__init__()中,当然我也可以吧不具有参数的层也放在里面;

(2)一般把不具有可学习参数的层(如ReLU、dropout、BatchNormanation层)可放在构造函数中,也可不放在构造函数中,如果不放在构造函数__init__里面,则在forward方法里面可以使用nn.functional来代替.
    
(3)forward方法是必须要重写的,它是实现模型的功能,实现各个层之间的连接关系的核心。
总结:

torch.nn是专门为神经网络设计的模块化接口。nn构建于autograd之上,可以用来定义和运行神经网络。

Pytorch学习笔记07----nn.Module类与前向传播函数forward的理解 - 雨后观山色 - 博客园 (cnblogs.com)

5、一些方法

4.1 torch.nn.Linear

4.2 torch.nn.MSELoss

4.3 torch.opti.SGD

6、三个魔法函数

python中所有对象都有一个从创建,被使用,再到消亡的过程,不同的阶段由不同的方法负责执行。

6.1 __new__()

__new__ 方法的返回值就是类的实例对象,这个实例对象会传递给 __init__ 方法中定义的 self 参数,以便实例对象可以被正确地初始化。

6.2 __int__()构造函数

当创建实例时,__init__() 方法被自动调用为创建的实例增加实例属性。

Python入门 类class 基础篇 (zhihu.com)

class Circle(object):  # 创建Circle类def __init__(self, r): # 初始化一个属性r(不要忘记self参数,他是类下面所有方法必须的参数)self.r = r  # 表示给我们将要创建的实例赋予属性r赋值

面试喜欢问的问题:创建类时,类方法中的self是什么?
self 代表类的实例,是通过类创建的实例 

6.3  __call__()

在python中,int/str/func/class都是对象。

在Python中,只要在创建类的时候定义了__call__()方法,这个类型就是可调用的。

在Python 中,凡是可以将 () 直接应用到自身并执行,都称为可调用对象。可调用对象包括自定义的函数、Python 内置函数以及本节所讲的类实例对象。

 

1. xInstance = X(1, 2, 3) 这句代码实例化了xInstance类型的对象X,在实例化时调用了__init__(self, a, b, range)函数

2. xInstance(1,2)代码使用类+参数的语法 直接调用了 __call__(self, a, b)函数,而不需要使用X.__call__(self, a, b)语法,这就是__call__(self, a, b)函数的作用,能够让python中的类能够像方法一样被调用

 #深入探究# PyTorch中的 forward() 方法详解 - 知乎 (zhihu.com)

7、如何定义类?

class people:#定义基本属性name = ''age = 0#定义私有属性,私有属性在类外部无法直接进行访问__weight = 0#定义构造方法def __init__(self,n,a,w):self.name = nself.age = aself.__weight = wdef speak(self):print("%s 说: 我 %d 岁。" %(self.name,self.age))# 实例化类
p = people('runoob',10,30)
p.speak()

(62条消息) python如何定义类?_python定义类_DongHappyyy的博客-CSDN博客

【python类包含方法】

公有方法:在类中和类外都能调用的方法

私有方法:不能被类外部调用,在方法前面加上“__“双下划线就是私有方法

类方法:被classmethod()函数处理过的函数,能被类所调用,也能被对象所调用(是继承的关系)

静态方法:相当于“全局函数”,可以被类直接调用,可以被所有实例化对象共享,通过staticmethod()定义, 静态方法没有“self”参数。

内置函数:__xxx__ 系统定义名字,前后均有一个“双下划线” 代表python里特殊方法专用的标识,如 __init__() 代表类的构造函数。

三种方法的主要区别在于参数,实例方法被绑定到一个实例,只能通过实例进行调用;但是对于静态方法和类方法,可以通过类名和实例两种方式进行调用

self参数 说明:
用于区分函数和类的方法(必须有一个self),self参数表示执行对象本身

【python类包含属性】

  • 私有属性

    函数、方法或者属性的名称以两个下划线开始,则为私有类型;

  • 公有属性

    如果函数、方法或者属性的名称没有以两个下划线开始,则为公有属性;

  • 实例属性

    以self作为前缀的属性;

  • 局部变量

    类的方法中定义的变量没有使用self作为前缀声明,则该变量为局部变量;

详细说明:

xxx 公有变量

_xxx "单下划线 " 开始的成员变量叫做保护变量,意思是只有类对象(即类实例)和子类对象自己能访问到这些变量,需通过类提供的接口进行访问;不能用'from module import *'导入

__xxx 类中的私有变量(Python的函数也是对象,所以成员方法称为成员变量也行得通。)," 双下划线 " 开始的是私有成员,意思是只有类对象自己能访问,连子类对象也不能访问到这个数据但可以通过私有属性,可以通过instance._classname_attribute方式访问

tip:python中有点伪私有属性的意思

Python 类>>>类属性(私有属性、公有属性、实例属性、局部变量)类方法(实例方法、静态方法) - 勿忘-前行 - 博客园 (cnblogs.com)

8、前向和后向

8.1 forward()

模型训练时,不需要调用forward这个函数,只需要在实例化一个对象中传入对应的参数就可以自动调用 forward 函数。

 

8.2 backward()

loss.backward()故名思义,就是将损失loss 向输入侧进行反向传播,同时对于需要进行梯度计算的所有变量 x(requires_grad=True),计算梯度 dloss/dx ,并将其累积到梯度x.grad中备用,即:

 在pytorch中,只有浮点类型的数才有梯度,因此在定义张量时一定要将类型指定为float型。

(62条消息) PyTorch:梯度计算之反向传播函数backward()_backward函数_精致的螺旋线的博客-CSDN博客

requires_grad: 如果需要为张量计算梯度,则为True,否则为False。我们使用pytorch创建tensor时,可以指定requires_grad为True(默认为False),

grad_fn: grad_fn用来记录变量是怎么来的,方便计算梯度,y = x*3,grad_fn记录了y由x计算的过程。

>>y = x + 2
tensor([[3., 3.],[3., 3.]], grad_fn=<AddBackward>)>>print(y.grad_fn)  # <AddBackward object at 0x1100477b8>

grad:当执行完了backward()之后,通过x.grad查看x的梯度值。
(62条消息) requires_grad,grad_fn,grad的含义及使用_dlage的博客-CSDN博客

如果输出是一个向量,计算梯度需要传入参数。例如

x = torch.tensor([0.0, 2.0, 8.0], requires_grad=True)
y = torch.tensor([5.0, 1.0, 7.0], requires_grad=True)
z = x * y
print(z)

如果想求z对x或y的梯度,则需要将一个外部梯度传递给z.backward()函数。这个额外被传入的张量就是grad_tensor。

(62条消息) PyTorch:梯度计算之反向传播函数backward()_backward函数_精致的螺旋线的博客-CSDN博客9. 类与子类

这是两个平等的类

B是A的子类

 10、super().__inti__()

class A:def hi(self):print("A hi")class B(A):def hello(self):print("B hello")b = B()
b.hi()       # B里没有写hi(),这里调用的是继承自A的hi()------------------------------------------------------------------
class A:def hi(self):print("A hi")class B(A):def hi(self):print("B hi")b = B()
b.hi()    # 这里调用的就是B自己的hi()
------------------------------------------------------------------
class A:def hi(self):print("A hi")class B(A):def hi(self):super().hi()         # 通过super调用父类A的hi()print("B hi")b = B()
b.hi()    # 这里调用的就是A里面的hi()

 (62条消息) python中super().__init__()_BeanInJ的博客-CSDN博客

torch.nn.Module 这个类的内部有多达 48 个函数,这个类是 PyTorch 中所有 neural network module 的基类,自己创建的网络模型都是这个类的子类

作业

optimizer = torch.optim.SGD(model.parameters(), lr=0.01) 

 

 

optimizer = torch.optim.Adam(model.parameters(), lr=0.1, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

 

 

optimizer = torch.optim.ASGD(model.parameters(), lr=0.01, lambd=0.0001, alpha=0.75, t0=1000000.0, weight_decay=0)

 

 代码:

import torch
import matplotlib.pyplot as plt# prepare dataset
# x,y是矩阵,3行1列 也就是说总共有3个数据,每个数据只有1个特征
x_data = torch.tensor([[1.0], [2.0], [3.0]])#创建Tensor
y_data = torch.tensor([[2.0], [4.0], [6.0]])# design model using class
"""
our model class should be inherit from nn.Module, which is base class for all neural network modules.
member methods __init__() and forward() have to be implemented
class nn.linear contain two member Tensors: weight and bias
class nn.Linear has implemented the magic method __call__(),which enable the instance of the class can
be called just like a function.Normally the forward() will be called 
"""class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__()##父类# (1,1)是指输入x和输出y的特征维度,这里数据集中的x和y的特征都是1维的# 该线性层需要学习的参数是w和b  获取w/b的方式分别是~linear.weight/linear.biasself.linear = torch.nn.Linear(1, 1, bias=False)def forward(self, x):y_pred = self.linear(x)return y_predmodel = LinearModel()
Loss_list = []
# construct loss and optimizer
# criterion = torch.nn.MSELoss(size_average = False)
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  #params:要训练的参数
#optimizer = torch.optim.Adam(model.parameters(), lr=0.1, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
# optimizer = torch.optim.ASGD(model.parameters(), lr=0.01, lambd=0.0001, alpha=0.75, t0=1000000.0, weight_decay=0)
# training cycle forward, backward, update
for epoch in range(100):y_pred = model(x_data)  # forward:predictloss = criterion(y_pred, y_data)  # forward: lossprint(epoch, loss.item())optimizer.zero_grad()  # the grad computer by .backward() will be accumulated. so before backward, remember set the grad to zeroloss.backward()  # backward: autograd,自动计算梯度optimizer.step()  # update 参数,即更新w和b的值Loss_list.append(loss)Loss_list_final = torch.tensor(Loss_list)
print('w = ', model.linear.weight.item())
#print('b = ', model.linear.bias.item())x_test = torch.tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)plt.figure()
plt.plot(Loss_list_final,'b',label = 'loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.title('with SGD')
plt.show()

 

 


http://www.ppmy.cn/news/131453.html

相关文章

HBase Shell 常用命令练习

HBase Shell 常用命令练习 前言一、HBase Shell是什么&#xff1f;二、HBase Shell使用步骤1.启动HBase2.启用HBase Shell3.键入HBase Shell命令操作HBase 三、常用HBase Shell实例1.常用的HBase Shell命令2.一个运用上述命令的综合实例&#xff1a; 总结 前言 提示&#xff1…

侃侃算法EP5·二叉树及其遍历

1. 前言 这个板块旨在记录一些日常中或是面试中会问到的算法和数据结构相关的内容&#xff0c;更多是给自己总结和需要的人分享。在内容部分可能由于我的阅历和实战经历不足&#xff0c;会有忽视或是写错的点&#xff0c;还望轻喷。 2. 内容 关于什么是树、子树、根节点、叶…

ES6、ES7、ES8、ES9、ES10新特性及其兼容性

强烈推荐阅读一篇文章&#xff0c;也是自己为了做保存把地址贴到自己博客&#xff0c;大家一起学习&#xff1a; ECMAScript 6 入门教程——阮一峰 盘点ES7、ES8、ES9、ES10新特性

es_01

字段&#xff1a;等于一个属性 文档&#xff1a;行数据等于多个字段组成 映射&#xff1a;mapping表结构 索引&#xff1a;index 数据库 存文档 类型&#xff1a;忽略 正排索引&#xff1a; 需要按照key来搜索每个key下的value&#xff0c;要收到全部的数据&#xff0c;就要进…

es(八)

单字符串串多字段查询:Dis Max Query 想在百度搜索一个单字符 should是如何算分过程 查询 should 语句句中的两个查询加和两个查询的评分乘以匹配语句句的总数除以所有语句句的总数

JS高级+ES678

js高级 数据类型 基本(值)类型 Number: 任意数值String: 任意文本Boolean: true/falseundefined: undefinednull: null 对象(引用)类型 Object: 任意对象 主要用来包含无序复杂的数据Array: 特别的对象类型(下标/内部数据有序)Function: 特别的对象类型(可执行)&#xff0c;F…

阿里云ECS部署ES

背景 最近越来越多的公司把业务搬迁到云上&#xff0c;公司也有这个计划&#xff0c;自己抽时间在阿里云和Azure上做了一些小的尝试&#xff0c;现在把阿里云上部署ES和kibana记录下来。为以后做一个参考&#xff0c;也希望对其他人有帮助。 这里以阿里云为例&#xff0c;由于测…

ES-08-ElasticSearch数据分片(shard)

说明 ElasticSearch数据分片&#xff08;shard&#xff09;创建多分片索引、更改多分片索引副本分片数量、路由计算和分片控制官方文档&#xff1a;https://www.elastic.co/cn/ 核心概念 》什么是数据分片&#xff08;shard&#xff09;&#xff1f; 一个分片是一个底层的工…

ES-09-ElasticSearch分词器

说明 ElasticSearch分词器默认分词器&#xff08;标准分词器&#xff09;、ik分词器、ik分词器扩展字典自定义词语关键词&#xff1a;keyword、text、ik_max_word、ik_smart、词条、词典、倒排表官方文档&#xff1a;https://www.elastic.co/cn/ik分词器文档&#xff1a;https…

Elasticsearch8.0

Elastic 中国社区官方博客_CSDN博客-Elastic,Elasticsearch,Kibana领域博主Elastic 中国社区官方博客擅长Elastic,Elasticsearch,Kibana,等方面的知识https://elasticstack.blog.csdn.net/ ✅ 启动 elasticsearch # cd /usr/local/elastic/elasticsearch/ # ./bin/elasticsearc…

ES6/ES7/ES8/ES9/ES10

ES10 ES10 功能完全指南 好犯困啊 我来打打字 string.prototype.matchAll() ‘Hello’.match(‘l’) eg:判断字符串中存在几个某元素 yannanna’.match(/n/g).length 扁平化多维数组&#xff08;想不出啥时候会用到&#xff09; let array [1, 2, 3, 4, 5]; array.map(x &g…

ES

文章目录 1. 什么是ElasticSearch&#xff1f;为什么要使用Elasticsearch?——克服模糊查询的缺点、查询速度快2. ES中的倒排索引是什么&#xff1f;——词→文章3. ES是如何实现master选举的&#xff1f;——各节点分别排序投票4. 如何解决ES集群的脑裂问题——增大最少候选节…

es 客户端

ES客户端&#xff1a;Elasticsearch Clients 语言无关性 Java REST ClientJava APIPython APIGo API.Net APIPHP APIJavaScripts APIRuby APIPerl APIElandRustCommunity Contributed Clients Java API 生命周期&#xff08;生卒年&#xff1a;ES 0.9 - ES 7.x&#xff09;…

ES7,ES8,ES10新特性

ES7 ES7在ES6的基础上增加了三项内容 求幂运算符 ** console.log(3 ** 2 ) // 9 Array.prototype.includes()方法 includes()的作用是查找一个值在不在数组中&#xff0c;接受两个参数&#xff1a;搜索值和搜索的开始索引。如果没有传递参数默认的索引是0 // 下面的这两种方…

ES7+ES8

撰文为何 身为一个前端开发者&#xff0c;ECMAScript(以下简称ES)早已广泛应用在我们的工作当中。了解ECMA机构流程的人应该知道&#xff0c;标准委员会会在每年的6月份正式发布一次规范的修订&#xff0c;而这次的发布也将作为当年的正式版本。以后的改动&#xff0c;都会基于…

elasticsearch系列七:ES Java客户端-Elasticsearch Java client

一、ES Client 简介 1. ES是一个服务&#xff0c;采用C/S结构 2. 回顾 ES的架构 3. ES支持的客户端连接方式 3.1 REST API &#xff0c;端口 9200 这种连接方式对应于架构图中的RESTful style API这一层&#xff0c;这种客户端的连接方式是RESTful风格的&#xff0c;使用http…

【阅读理解】ES7/ES8/ES9/ES10新特性

今天阅读了一篇咨询&#xff0c;有关于ES7-ES10 &#xff08;ES2016-2019&#xff09;&#xff0c;ES6后新出的特性比较频繁。 首先附上思维导图 下面都是我阅读咨询后理解而编写的&#xff1a; ES7&#xff1a; 1.Array.prototype.includes() 这个方法可以判断一个元素…

ES7.8 安装

环境 CentOS7.4 elasticsearch-7.8.0 jdk8 下载Linux版本的elasticsearch安装包 https://www.elastic.co/cn/downloads/past-releases 安装集群在每个节点上的安装步骤基本上都是一样的&#xff0c;我以一个节点为例 下载完成之后通过ftp上传到linux服务器指定目录下&am…

ES7、ES8、ES9、ES10、ES11新特性

一、ES7新特性 1. Array.prototype.includes includes 方法用来检测数组中是否包含某个元素&#xff0c;返回布尔值 2. 指数操作符 指数运算符 ** &#xff0c;用来实现幂运算&#xff0c;功能与 Math.pow 结果相同 二、ES8新特性 1. async 和 await async 和 await 两种…

ES7、ES8、ES9、ES10新特性

ES7新特性 1.Array.prototype.includes()方法 在ES6中我们有String.prototype.includes()可以查询给定字符串是否包含一个字符,而在 ES7 中,我们在数组中也可以用 Array.prototype.includes 方法来判断一个数组是否包含一个指定的值,根据情况,如果包含则返回true,否则返…
最新文章