pytorch千问模型源码分析

news/2024/10/3 19:51:47/

class Qwen2Config(PretrainedConfig):
    model_type = "qwen2"
    # 表明在推理过程中,对于某些操作,模型或库会忽略 past_key_values 的存在。这对于控制序列生成的行为是非常有用的,
    # 特别是在需要初始化生成过程或格式化输出结果时。然而,实际应用中,past_key_values 经常用于加速连续生成过程,特
    # 别是在长时间依赖的场景下
    keys_to_ignore_at_inference = ["past_key_values"]
    def __init__(
        self,
        # 用途:用于初始化嵌入层(embedding layer),以及作为最终全连接层(fully connected layer)的输出维度。
        vocab_size=151936,# 词汇表的大小,即模型可以识别的不同单词或标记的数量。
        hidden_size=4096,# 含义:隐藏层的维度,即每个Transformer编码器或解码器层的输出向量的大小。
        # 决定了模型内部状态的表示能力
        intermediate_size=22016,# 前馈神经网络(feed-forward network, FFN)中间层的维度。
        # FFN通常由两个线性层组成,第一个线性层的输出维度为 intermediate_size,用于提升模型的学习能力。
        num_hidden_layers=32,# Transformer模型中编码器或解码器堆叠的层数。增加模型的深度,以增强其捕捉复杂特征的能力。
        num_attention_heads=32,# 含义:每个Transformer层中多头注意力机制(multi-head attention mechanism)的头数。
        # 允许多个并行的注意力机制运行,从而捕捉不同的特征。
        num_key_value_heads=32,# 每层中用于计算键(Key)和值(Value)的注意力头的数量
        # 优化计算资源,有时候为了节省计算成本,可以设置 num_key_value_heads 小于 num_attention_heads。
        hidden_act="silu",# 隐藏层使用的激活函数。引入非线性,使模型能够学习复杂的映射关系
        max_position_embeddings=32768,# 模型支持的最大位置嵌入的长度。决定了模型能够处理的最大序列长度。
        initializer_range=0.02,# 模型权重初始化的标准差范围。控制模型参数初始化时的随机性。
        rms_norm_eps=1e-6,# RMSNorm 层中使用的数值稳定性项。防止除法运算中的除零错误。
        use_cache=True,# 是否使用缓存机制来存储过去计算的结果。在生成任务中,可以加速推理过程
        tie_word_embeddings=False,# 是否共享输入嵌入层(input embedding)和输出嵌入层(output embedding)的权重。
        # 减少模型参数数量,有时可以提高模型性能。
        rope_theta=10000.0,# 旋转位置嵌入(Rotary Positional Embedding)中的超参数。
        # 帮助模型理解不同位置的相对关系。
        use_sliding_window=False,# 含义:是否使用滑动窗口机制。
        # 用途:在处理长序列时,可以减少内存消耗。
        sliding_window=4096, # 含义:滑动窗口的大小。定义了滑动窗口覆盖的序列长度。
        max_window_layers=28,# 含义:最多可以有多少层使用滑动窗口机制。
        # 用途:限制滑动窗口机制使用的层数,平衡计算效率和模型性能。
        attention_dropout=0.0,# 含义:注意力机制中的Dropout概率。随机丢弃一些注意力权重来防止过拟合。
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window
        self.max_window_layers = max_window_layers

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
# 规范化技术,旨在替代传统的 Layer Normalization(LN)
# 核心思想是对输入张量的每个样本的每个特征进行规范化,使其均值为 0,方差为 1
class Qwen2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6): # 隐藏层的大小
        super().__init__()
        # 一个可学习的权重参数,初始化为全 1 张量。
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # 用于防止除零错误的小常数。
        self.variance_epsilon = eps
    def forward(self, hidden_states):
        # 记录输入张量的数据类型,以便最终转换回原始类型。
        input_dtype = hidden_states.dtype
        # 转换为 torch.float32 类型,以确保数值稳定性。
        hidden_states = hidden_states.to(torch.float32)
        # 计算每个样本的方差
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        # 计算每个样本的 RMS 值,并对每个样本进行规范化
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # 应用可学习的权重,其中 γγ 是一个可学习的参数,用于缩放规范化后的张量。
        return self.weight * hidden_states.to(input_dtype)
# 用于生成旋转位置嵌入。这种嵌入方法在 Transformer 模型中用于捕捉序列中的位置信息,尤其适用于长序列任务。
# 通过旋转的方式将位置信息编码到嵌入向量中。具体步骤如下:
# 生成频率:通过指数函数生成一系列频率值。计算正弦和余弦:利用生成的频率计算正弦和余弦值
# ,旋转嵌入:将输入向量按一定规则旋转,以嵌入位置信息。
class Qwen2RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        # 最大位置嵌入的长度,默认为 2048,base:基数,默认为 10000。。
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # inv_freq:计算频率的逆值。
        # 位置列表先归一化(从绝对位置变成相对位置),之后取指数(1--接近10000),之后取倒数,位置从1--越来越小
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        # register_buffer:将 inv_freq 注册为缓冲区,以便在模型保存和加载时保持不变。
        # register_buffer 方法用于注册一个非训练的缓冲区(buffer),这意味着它不会被梯度更新。当你使用 register_buffer 注册一个缓
        # 冲区时,它会被保存在模型的状态字典(state dict)中,并且在模型保存和加载时也会被序列化。
        # persistent=True:缓冲区会出现在模型的状态字典中,并且会被序列化和加载。
        # persistent=False:缓冲区不会出现在模型的状态字典中,但在实际保存和加载时,仍然会被序列化并加载。
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Build here to make `torch.jit.trace` work.生成正弦和余弦缓存
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )
    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        # t 是一个包含位置索引的张量,形状为 (seq_len,)。
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
        # torch.outer:计算外积,得到一个形状为 (seq_len, dim/2) 的张量
        freqs = torch.outer(t, self.inv_freq) # 计算频率。
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        # 拼接频率。emb 的形状为 (seq_len, dim)。
        # 在旋转位置嵌入(RoPE)中,我们通常将嵌入向量分为两个部分,并分别应用正弦和余弦变换。具体来说:
        # 对于每个位置 tt,计算频率 ff,得到一个形状为 (seq_len, dim/2) 的张量。
        # 将频率张量拼接两次,得到一个形状为 (seq_len, dim) 的张量。
        # 这样做的原因是,我们将嵌入向量分为两部分,每部分对应一个频率值。
        emb = torch.cat((freqs, freqs), dim=-1)
        # cos_cached 和 sin_cached:注册正弦和余弦缓存。
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
    def forward(self, x, seq_len=None): # x:输入张量。
        # x: [bs, num_attention_heads, seq_len, head_size]
        # 如果 seq_len 大于已缓存的最大长度,则重新生成缓存。
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return ( # 返回正弦和余弦缓存的切片。
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
class Qwen2MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size # d
        self.intermediate_size = config.intermediate_size # hd
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) # d-->hd
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)# d-->hd
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) # hd-->d
        self.act_fn = ACT2FN[config.hidden_act]
    def forward(self, hidden_state): # (h,s,d)
        # 门控信号生成:gate_proj(hidden_state) 生成门控信号
        # 特征调整:gate_output 与 up_output 相乘,将门控信号应用于特征表示。
        # 门控机制的作用:通过门控信号动态调整哪些特征应该通过哪些特征应该被抑制。
        # 激活函数的选择:如果 config.hidden_act 是 "sigmoid",那么激活函数将是 sigmoid
        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
class Qwen2Attention(nn.Module):
    def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
        super().__init__() # 调用父类的初始化方法
        self.config = config # 配置类实例
        self.layer_idx = layer_idx # 层索引
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )
        
        self.hidden_size = config.hidden_size # d
        self.num_heads = config.num_attention_heads # q_h
        self.head_dim = self.hidden_size // self.num_heads # dk
        self.num_key_value_heads = config.num_key_value_heads # kv_h
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads # 比例
        self.max_position_embeddings = config.max_position_embeddings # p
        self.rope_theta = config.rope_theta # base
        self.is_causal = True # 是否用因果掩码
        self.attention_dropout = config.attention_dropout # dropout
        # 嵌入维度必须能被整除
        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        # 线性投影
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        #需要注意的是这里的投影维度可能和q的投影维度不同
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        # 最后一个线性转换层
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        # 旋转位置嵌入层
        self.rotary_emb = Qwen2RotaryEmbedding(
            self.head_dim, # dk
            max_position_embeddings=self.max_position_embeddings,# max_position
            base=self.rope_theta, # base
        )
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,# 可选
        position_ids: Optional[torch.LongTensor] = None,# 可选
        past_key_value: Optional[Cache] = None, # 可选参数:缓存
        output_attentions: bool = False,# 是否输出注意力权重
        use_cache: bool = False, # 是否使用缓存
        cache_position: Optional[torch.LongTensor] = None, # 缓存位置
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size() # b,s,d
        # 投影
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        # (b,q_len,q_h,dk)-->(b,q_h,q_len,dk),transpose:换轴(转置)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        # (b,k_h,k_len,dk)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        kv_seq_len = key_states.shape[-2] # k_len
        # 缓存上个时间步的key,value表示
        if past_key_value is not None: # 如果设置了缓存
            if self.layer_idx is None: # 就必须有layer_idx,不然报错
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        # 旋转位置嵌入,传kv_len
        # 键/值序列长度:kv_seq_len 是键和值向量的长度,这是因为键和值向量代表的是相同的序列。
        # 查询序列长度:q_len 是查询向量的长度,这可能不同于键/值向量的长度。
        # 旋转位置嵌入:在计算旋转位置嵌入时,使用键/值序列长度是为了确保位置信息与键和值向量一致。
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        # 返回带位置信息的嵌入表示
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # 如果past_key_value is not None
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            # 更新当前的key,value表示
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
        # repeat k/v heads if n_kv_heads < n_heads
        # 如果键值头数量少于查询头数量,则重复键值头以匹配查询头数量。
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        # (b,q_h,q_len,dk)@(b,k_h,dk,k_len)-->(b,h,q_len,k_len)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attn_weights.size()


http://www.ppmy.cn/news/1533607.html

相关文章

oracle解决关联查询报invalid number问题

出现问题的原因和背景 oracle进行关联查询的时候因为字段存在多个用逗号切割的id&#xff0c;导致查询的过程中报无效数字或非法数字 问题复现 新建表A CREATE TABLE "A" (id NUMBER NOT NULL,name VARCHAR2(255 BYTE) )INSERT INTO "A" VALUES (1, 上海…

职场人情世故

1.人走茶凉本是常态&#xff0c;朋友是流动的&#xff0c;人是会变的&#xff0c;走的人不必挽留。 2.别人不问你的事&#xff0c;千万不要随意去指点别人&#xff0c;因为你不知道某些事的复杂程度。 3.不要太在乎面子&#xff0c;不管别人如何对你指指点点&#xff0c;你都…

足球青训俱乐部管理:Spring Boot技术驱动

摘 要 随着社会经济的快速发展&#xff0c;人们对足球俱乐部的需求日益增加&#xff0c;加快了足球健身俱乐部的发展&#xff0c;足球俱乐部管理工作日益繁忙&#xff0c;传统的管理方式已经无法满足足球俱乐部管理需求&#xff0c;因此&#xff0c;为了提高足球俱乐部管理效率…

基于单片机人体反应速度测试仪系统

** 文章目录 前言概要设计思路 软件设计效果图 程序文章目录 前言 &#x1f497;博主介绍&#xff1a;✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计划导师&#xff0c;一名热衷于单片机技术探索与分享的博主、专注于 精通51/STM32/MSP430/AVR等单片机设计 主要对象是咱们…

Android Webview和ScrollView冲突和WebView使用总结

1.因为Webview和ScrollView都用滑动事件&#xff0c;导致webview很难被滑动&#xff0c;即使被滑动了一点也非常不顺畅2.解决滑动冲突问题后发现&#xff0c;如果webview嵌套的html中含有轮播图等还是有问题。 使用自定义ScrollWebView解决这个问题 public class ScrollWebVi…

初识算法 · 双指针(2)

目录 前言&#xff1a; 盛最多水的容器 题目解析&#xff1a; 算法原理&#xff1a; 算法编写&#xff1a; 有效三角形的个数 题目解析&#xff1a; 算法原理&#xff1a; 算法编写&#xff1a; 前言&#xff1a; 本文介绍两个题目&#xff0c;盛最多水的容器和有效三…

【Spring】深入理解控制反转-IOC

目录 一、Spring_ioc_01项目 1. jdbc.properties 2. 高耦合 3. 中耦合 4. 低耦合 二、Spring_ioc_02项目 1. xxx.properties(键值对存储) 2. 解耦的方式创建对象 3. 调用getBean()方法并传入xxx.properties对应键获取其相应的值 三、Spring_ioc_03项目 1. spring两大…

taro RN 左右滑动切换页面

引入 react-native-pager-view 组件 import React, { Component } from react import Taro from tarojs/taro import { View, PagerView, Button } from tarojs/components import PagerView from react-native-pager-view;export default class MyComponent extends Taro.C…