(2021|NIPS,VQ-VAE,精度瓶颈松弛,三明治层归一化,CapLoss)CogView:通过转换器掌握文本到图像的生成

news/2024/4/14 21:15:42

CogView: Mastering Text-to-Image Generation via Transformers

公众号:EDPJ(添加 VX:CV_EDPJ 或直接进 Q 交流群:922230617 获取资料)

0. 摘要

通用领域中的文本到图像生成长期以来一直是一个悬而未决的问题,这需要强大的生成模型和跨模态理解。 我们提出了 CogView,一个具有 VQ-VAE 标记器的 40 亿参数 Transformer 来解决这个问题。 我们还演示了各种下游任务(例如,风格学习、超分辨率、文本图像排名和时装设计)的微调策略,以及稳定预训练的方法(例如,消除 NaN 损失)。 CogView 在模糊的 MS COCO 数据集上实现了最先进的 FID,优于之前基于 GAN 的模型和最近的类似工作 DALL-E。

代码和模型位于 https://github.com/THUDM/CogView

最新模型的演示网站:https://wudao.aminer.cn/CogView/index.html(无 post-selection)

1. 简介

“对于画家来说,有两件事:眼睛和心灵……眼睛,我们通过它观察自然;眼睛,我们通过它观察自然; 大脑,我们通过逻辑组织感觉以进行有意义的表达。” (Paul Cézanne [17])

由于对比自监督预训练彻底改变了计算机视觉 (CV) [24, 21, 8, 32],为图像带来高级语义的视觉语言预训练正在成为视觉理解的下一个前沿 [38, 30, 39]。 在各种前置任务(pretext tasks)中,文本到图像生成期望模型(1)从像素中解耦形状、颜色、手势和其他特征,(2)理解输入文本,(3)将对象和特征与相应的单词及其特征对齐 (4)学习复杂的分布来生成不同物体和特征的重叠和复合,这就像绘画一样,超出了基本的视觉功能(与眼睛和大脑中的 V1-V4 有关[22]),需要更高水平的认知能力(与大脑中的角回(angular gyrus)更相关 [3])。 

教授机器文本到图像生成的尝试可以追溯到深度生成模型的早期,当时 Mansimov 等人 [35] 向DRAW [20] 添加了文本信息。然后生成对抗网络(GAN) [19] 开始主导这项任务。

  • Reed 等人 [42] 将文本嵌入作为额外输入提供给生成器和鉴别器。
  • StackGAN [54] 将生成分解为草图细化(sketch-refinement)过程。
  • AttnGAN [51] 使用对单词的注意力来关注相应的子区域。
  • ObjectGAN [29] 在 text → boxes → layouts → image 过程之后生成图像。
  • DM-GAN [55] 和 DF-GAN [45] 引入了新的架构,例如,动态记忆或深度融合块,以获得更好的图像细化。
  • 尽管这些基于 GAN 的模型可以在简单且特定领域的数据集中(例如 Caltech-UCSD Birds 200 (CUB))进行合理的合成,但在复杂和通用领域场景(例如 MS COCO [31])的结果,远不能令人满意。

近年来,自回归生成模型兴起。

  • 生成预训练(Generative Pre-Training,GPT)模型 [37, 4] 利用 Transformers [48] 在大规模语料库中学习语言模型,极大地提升了自然语言生成和小样本语言理解的性能 [33]。
  • 自回归模型在 CV 中并不是新生事物。 PixelCNN、PixelRNN [47] 和 Image Transformer [36] 将图像上的概率密度函数分解到具有不同网络主干的子像素(像素中的颜色通道)上,显示出有希望的结果。
  • 然而,真实图像通常包含数百万个子像素,这表明大型模型的计算量难以承受。 即使是最大的像素级自回归模型 ImageGPT [7],也是在 ImageNet 上以最大分辨率仅为 96*96 进行预训练的。 

矢量量化变分自动编码器(Vector Quantized Variational AutoEncoders,VQ-VAE)[46] 的框架缓解了这个问题。

  • VQ-VAE 在阶段 1 中训练编码器将图像压缩到低维离散潜在空间,并训练解码器从隐藏变量中恢复图像。
  • 然后在阶段 2 中训练自回归模型(例如 PixelCNN) [47])学习拟合隐藏变量的先验。
  • 这种离散压缩比直接下采样损失的保真度更少,同时保持了像素的空间相关性。 因此,VQ-VAE 重振了 CV 中的自回归模型 [41]。
  • 遵循这个框架,Esser 等人 [15] 使用 Transformer 来拟合先验,并进一步从 L2 损失切换到 GAN 损失以进行解码器训练,极大地提高了特定领域无条件生成的性能。

CogView 的想法很自然:对文本和图像(来自 VQ-VAE)标记进行大规模生成联合预训练。

  • 我们收集了 3000 万个高质量(中文)文本图像对,并使用 40 亿个参数预训练 Transformer。
  • 然而,由于数据的异构性,大规模文本到图像的生成预训练可能非常不稳定。 我们系统地分析了原因,并通过提出的精度瓶颈松弛和三明治层归一化(Precision Bottleneck Relaxation and Sandwich Layernorm)解决了这个问题。 因此,CogView 极大地提高了文本到图像生成的质量。 

最近的一篇工作 DALL-E [39] 独立提出了同样的想法,并且发布时间早于 CogView。 与 DALL-E 相比,CogView 在以下四个方面取得了进步:

  • 根据模糊 MS COCO 上的 FID [25],CogView 的性能大幅优于 DALL-E 和之前基于 GAN 的方法,并且是第一个开源大型文本到图像 transformer。
  • 除了零样本生成之外,我们还进一步研究了微调预训练 CogView 的潜力。 CogView 可适用于各种下游任务,例如风格学习(特定领域的文本到图像)、超分辨率(图像到图像)、图像标题(图像到文本),以及文本到图像重排名。 
  • 经过微调的 CogView 可实现用于 post-selection 的自我重新排名,并摆脱 DALL-E 中的附加 CLIP 模型 [38]。 它还提供了一个新的指标 “Caption Loss”,以比 FID 和 Inception Score (IS) [43] 更细的粒度来衡量文本图像生成的质量和准确性。 
  • 我们提出了 精度瓶颈松弛和三明治层归一化(PB-relaxation 和 Sandwich-LN)来稳定大型 Transformer 在复杂数据集上的训练。 这些技术非常简单,可以消除前向过程(表征为 NaN 损失)中的溢出,并使 CogView 几乎能够用 FP16(O2)(这意味着所有计算(包括前向和后向)均采用 FP16,无需任何转换,但优化器状态和主权重为 FP32)进行训练。 它们还可以推广到其他 transformer 的训练。 

2. 方法

2.1 理论 

在本节中,我们将从 VAE [26] 推导出 CogView 的理论(本文中,粗体表示随机变量,常规字体表示具体值。 有关 VAE 的基础知识,请参阅此综合教程 [12]):CogView 优化图像和文本联合似然的证据下界(Evidence Lower BOund,ELBO)。 如果没有文本 t,下面的推导将变成对 VQ-VAE 的清晰的重新解释。 

假设数据集

由 N 个 i.i.d 图像变量 x 及其描述文本变量 t 的样本组成。 我们假设图像 x 可以通过涉及潜在变量 z 的随机过程生成: (1) t_i 首先从先验 p(t; θ) 生成。 (2) 然后从条件分布 p(z | t = t_i; θ) 生成 z_i。 (3) x_i 最终由 p(x | z = z_i; Ψ) 生成。 在接下来的部分中,我们将使用像 p(x_i) 这样的简写形式来指示 p(x = x_i)。 

令 q(z | x_i; Φ) 为变分分布(variational distribution),它是 VAE 编码器的输出。 对数似然和证据下界 (ELBO) 可写为: 

VQ-VAE 的框架与传统 VAE 的不同之处主要在于 KL 项。

  • 传统的 VAE 固定先验 p(z | t = t_i; θ),通常为 N(0,I),并学习编码器 。 然而,它会导致后验塌陷(posterior collapse) [23],这意味着 q(z | x_i; Φ) 有时会向先验塌陷。
  • 而 VQ-VAE 固定 Φ,并使用由 θ 参数化的另一个模型来拟合先验 p(z | t_i; θ)。 该技术消除了后验塌陷,因为现在仅在优化重建损失时更新编码器 Φ。
  • 作为交换,对于不同的 x_i,近似后验 q(z | x_i; Φ) 可能非常不同,因此我们需要一个非常强大的 p(z | t_i; θ) 模型来最小化 KL 项。

目前,最强大的生成模型 Transformer (GPT) 可处理离散码本上的标记序列。 为了使用它,令

其中 |V| 是码本的大小,h*w 是 z 的维数。 序列 z_i 可以从 q(z | x_i; Φ) 中采样,也可以直接

为了简单起见,我们选择后者,以便 q(z | x_i; Φ) 成为 z_i 上的单点分布。 方程(2)可以改写为:

然后学习过程分为两个阶段:

  • (1) 编码器 Φ 和解码器 Ψ 学习最小化重建损失。
  • (2) 单个 GPT 通过连接文本 t_i 和 z_i 作为输入序列来优化两个负对数似然 (negative log-likelihood,NLL) 损失。 

因此,第一阶段退化为纯粹的离散自动编码器,充当图像标记器,将图像转换为标记序列; 第二阶段的 GPT 承担了大部分建模任务。 图 3 展示了 CogView 的框架。

2.2 标记化

在本节中,我们将介绍 CogView 中标记器的详细信息,并比较图像标记器的不同训练策略(VQVAE 第 1 阶段)。 

文本的标记化已经得到了充分的研究,例如 BPE [16] 和 SentencePiece [28]。 在 CogView 中,我们在大型中文语料库上运行 SentencePiece 以提取 50,000 个文本标记。 

图像标记器是一个离散自动编码器,类似于 VQ-VAE [46] 或 d-VAE [39] 的第一阶段。更具体地说,编码器 Φ 将形状为 H*W*3 的图像 x 映射为形状为 h*w*d 的 Enc_Φ(x),然后将每个 d 维向量量化为可学习码本

中的临近嵌入。量化结果可以用 h*w 嵌入索引表示,然后我们得到潜在变量

解码器 Ψ 将量化的矢量映射回(模糊)图像以重建输入。 在我们的 4B 参数 CogView 中,|V| = 8192,d = 256,H = W = 256,h = w = 32。 

由于离散选择的存在,图像标记器的训练并不简单。 这里我们介绍四种训练图像标记器的方法: 

1) 最近邻映射,直通估计器 [2],由原始 VQ-VAE 提出。 这种方法的一个常见问题 [39] 是,当码本很大并且初始化不仔细时,由于维数灾难,只会使用少量嵌入。 我们在实验中没有观察到这种现象。 

2) Gumbel 采样,直通估计器。 如果我们按照原始 VAE 来根据向量之间的距离重新参数化潜在变量 z 的分类分布,即

无偏采样策略为

其中,温度 τ 逐渐降至 0。我们可以进一步使用可微分的 softmax 来近似 argmax 的 one-hot 分布。 DALL-E采用这种方法并结合许多其他技巧来稳定训练。 

3) 最近邻映射,移动平均值,其中码本中的每个嵌入在训练期间定期更新为最近映射到它的向量的平均值 [46]。 

4) 最近邻映射,固定码本,初始化后码本是固定的。

对比。 为了比较这些方法,我们在相同的数据集和随机种子上训练了四个具有相同架构的图像标记器,并展示了图 2 中的损失曲线。我们发现所有方法基本上是均匀匹配的,这意味着嵌入的学习 如果初始化正确的话,码本并不是很重要。 在预训练中,我们使用移动平均法的标记器。

数据的介绍和更多关于标记化的细节参见附录A。

2.3 自回归 Transformer

CogView 的主干是单向 Transformer (GPT)。

  • Transformer 有 48 层,隐藏层大小为 2560 个,有 40 个注意力头,总共 40 亿个参数。
  • 如图 3 所示,四个分隔符标记,[ROI1](图像的参考文本)、[BASE]、[BOI1](图像开始)、[EOI1](图像结束)被添加到每个序列中以指示文本和图像的边界。
  • 所有序列都被剪裁或填充至 1088 的长度。 

预训练的前置任务(pretext task)是从左到右的标记预测,也称为语言建模。

  • 图像和文本标记都受到同等对待。
  • DALL-E [39] 建议降低文本标记的损失权重; 相反,在小规模实验中,我们惊讶地发现文本建模是文本到图像预训练成功的关键。 如果文本标记的损失权重设置为零,模型将无法找到文本和图像之间的联系,并生成与输入文本完全无关的图像。
  • 我们假设文本建模抽象了隐藏层中的知识,这些知识可以在后续的图像建模中得到有效利用。 

训练细节

  • 我们在 512 个 V100 GPU (32GB) 上训练模型,批量大小为 6,144 个序列(每批量 670 万个标记),执行 144,000 个步骤。
  • 由 Adam 更新参数,最大 lr = 3*10^(-4); β1 = 0.9; β2 = 0.95;权重衰减= 4*10^(-2)。
  • 学习率在前 2% 的步骤中升温,并随着余弦退火 [34] 而衰减。
  • 当超参数在适当的范围内时,我们发现训练损失主要取决于训练的标记总数(每批的标记数*步骤数),这意味着,如果使用相同数量的标记训练,将批大小(和学习率)加倍会导致非常相似的损失。 因此,我们使用相对较大的批大小来提高并行性并减少通信时间的百分比。
  • 我们还设计了一个三区域稀疏注意力(three-region sparse attention)来加速训练并节省内存而不损害性能,附录 B 对此进行了介绍。 

2.4 训练的稳定性 

目前,预训练大型模型(>2B 参数)通常依赖于 16 位精度来节省 GPU 内存并加快计算速度。 许多框架,例如 DeepSpeed ZeRO [40],甚至只支持 FP16 参数。 然而,文本到图像的预训练在 16 位精度下非常不稳定。 训练 4B 普通 pre-LN Transformer 将在 1,000 次迭代内快速导致 NaN 损失。 稳定训练是 CogView 中最具挑战性的部分,它与 DALL-E 非常一致。 

我们将 DALL-E 的解决方案总结为容忍训练的数值问题。

  • 由于不同层中的值和梯度在规模上变化很大,因此他们提出了一种新的混合精度框架,每个 resblock 损失缩放,并以 32 位精度和 32 位梯度存储所有增益、偏差、嵌入和非嵌入。
  • 该解决方案很复杂,消耗额外的时间和内存,并且大多数当前的训练框架都不支持。 

相反,CogView 会正则化这些值。 我们发现存在两种不稳定性:溢出(以 NaN 损失为特征)和下溢(以发散损失为特征)。 提出以下技术来解决这些问题。 

精度瓶颈松弛(Precision Bottleneck Relaxation,PB-Relax)。 在分析训练动态后,我们发现溢出总是发生在两个瓶颈操作处,即最终的 LayerNorm 或注意力。 

  • 在深层中,输出的值可能会爆炸至 10^4 ~ 10^5 之大,从而导致 LayerNorm 中的变化溢出。 幸运的是,由于 LayerNorm(x) = LayerNorm(x / max(x)),我们首先可以通过除以最大值来缓解这个瓶颈(我们不能直接将 x 除以一个很大的常数,这会导致训练初期出现下溢。)。 
  • 注意力分数 (Q^T)K / √d 可能明显大于输入元素,并导致溢出。 将计算顺序更改为 (Q^T)(K / √d) 可缓解该问题。 为了消除溢出,我们注意到 softmax((Q^T)K / √d ) = softmax((Q^T)K / √d - 常量),这意味着我们可以将注意力的计算改为等式 4,其中 α 是一个大数,例如 α = 32(最大值必须至少是按注意力头的,因为不同注意力头的值差异很大)。这样,注意力分数的最大值(绝对值)也被除以 α 来防止溢出。 关于 CogView 中注意力的详细分析见附录 C。 

三明治层规范(Sandwich LayerNorm,Sandwich-LN)。 Transformers 中的 LayerNorms [1] 对于稳定训练至关重要。 Pre-LN [50] 被证明比原始的 Post-LN 收敛更快、更稳定,并成为最近作品中 Transformer 层的默认结构。 然而,这对于文本到图像的预训练来说还不够。LayerNorm

的输出基本上与 x 的隐层大小的平方根成正比,在 CogView 中为 √d = √2560 ≈ 50。 如果某些维度的输入值明显大于其他维度(对于 Transformer 来说就是这样),这些维度的输出值也会很大 (10 ~ 100)。 在残差分支中,这些大值被放大并加回主支,从而在下一层加剧这种现象,最终导致深层数值爆炸。 

这个数值爆炸背后的原因激励我们去限制层层加剧。

  • 我们提出了 Sandwich LayerNorm,它还在每个残差分支的末尾添加了一个 LayerNorm。
  • Sandwich-LN 保证了各层输入值的规模在合理范围内,训练 500M 模型的实验表明,其对收敛的影响可以忽略不计。
  • 图 4(a) 说明了 Transformer 中不同的 LayerNorm 结构。

玩具(Toy)实验

  • 图 4(b) 显示了 PB-relax 和 Sandwich-LN 在玩具实验设置下的有效性,因为训练许多大型模型进行验证是不现实的。
  • 我们发现深度 Transformer(64 层,1024 个隐藏层大小)、大学习率(0.1 或 0.01)、小批量大小(4)可以在合理的超参数下模拟训练中的价值爆炸。
  • PB-relax + Sandwich-LN 甚至可以稳定玩具实验。 

缩小嵌入梯度。尽管我们在使用 Sandwich-LN 后没有观察到任何下溢的迹象,但我们发现标记嵌入的梯度比其他参数的梯度大得多,因此只需将其规模缩小 α = 0.1 即可增加动态损失规模,以进一步防止下溢,可以通过 Pytorch 中的 emb = emb*alpha + emb.detach()*(1 - alpha) 来实现。 它似乎减慢了标记嵌入的更新,但实际上并没有损害我们实验中的性能,这也对应于最近的一项工作MoCo v3 [9]。

讨论

  • PB-relax 和 Sandwich-LN 成功稳定了 CogView 和 8.3B 参数 CogView-large 的训练。
  • 它们对于所有 Transformer 预训练也是通用的,并将在未来支持非常深的 Transformer 训练。作为证明,我们使用 PB-relax 成功消除了训练 10B 参数 GLM 时的溢出 [14]。
  • 然而,总的来说,语言预训练中的精度问题并不像文本到图像预训练中那么重要。 我们假设根源是数据的异质性,因为我们观察到文本和图像标记在某些隐藏状态下通过尺度来区分。 另一个可能的原因是 DALL-E 猜测很难找到下溢。 彻底的调查留给未来的工作。 

3. 微调 

CogView 在微调方面比 DALL-E 更进一步。 特别是,我们可以通过微调 CogView 的超分辨率和自重新排名来改进文本到图像的生成。 所有微调任务都可以在一台 DGX-2 上在一天内完成。

3.1 超分辨率 

由于图像标记器在训练之前将 256*256 像素图像压缩为 32*32 标记序列,因此由于有损压缩,生成的图像比真实图像更模糊。

  • 然而,由于注意力操作的复杂度为 O(n^2),增加序列长度将消耗更多的计算和内存。
  • 以前关于超分辨率或图像恢复的作品 [13] 通常处理已经是高分辨率的图像,将模糊的局部纹理映射到清晰的纹理。 它们不能应用于我们的情况,我们需要向生成的低分辨率图像添加有意义的细节。
  • 图 5 (b) 是我们的微调方法的示例,说明了我们期望的超分辨率行为。 

我们超分辨率微调解决方案的动机是相信 CogView 是在一般域中最复杂的分布上进行训练的,并且已经涵盖了不同分辨率的对象。

  • 支持这一信念的证据是,如果我们在文本末尾附加 “特写视图”,模型将生成对象一部分的细节。
  • 因此,针对超分辨率微调 CogView 应该不难 。 

具体来说,

  • 我们首先将 CogView 微调为条件超分辨率模型,从 16*16 个图像标记到 32*32 个标记。
  • 然后,我们通过图 5 (a) 中的中心连续滑动窗口策略将 32*32 个标记的图像逐块放大到 64*64 个标记(512*512 像素)。
  • 在保持中心区域的完整性方面,该顺序比光栅扫描顺序(raster-scan order)表现得更好。

为了准备数据,

  • 我们将大约 200 万张图像裁剪为 256*256 个区域,并将它们下采样到 128*128。
  • 标记化后,我们得到不同分辨率的 32*32 和 16*16 序列对。
  • 微调序列的模式为 “[ROI1] text tokens [BASE] [BOI1] 16*16 image tokens [EOI1] [ROI2] [BASE] [BOI2] 32*32 image tokens [EOI2]”,比最大位置嵌入索引 1087 长 。作为解决方案,我们从 [ROI2] 处的 0 开始重新计算位置索引。 
  • 人们可能会担心位置索引的复用可能会引起混乱,但实际上,模型可以很好地区分两幅图像,可能是基于它们是否能够关注前面的 [ROI2]。 

3.2 图像标题和自我重新排名

微调用于图像标题的 CogView 非常简单:交换输入序列中文本和图像标记的顺序。

  • 由于模型已经学习了文本和图像之间的对应关系,因此逆向生成并不难。
  • 我们没有评估性能,因为(1)没有权威的中文图像标题基准(2)图像标题不是这项工作的重点。
  • 微调这种模型的主要目的是为了自我重新排名。 

我们提出了标题损失(Caption Loss,CapLoss)来评估图像和文本之间的对应关系。 更具体地说,

其中 t 是文本标记序列,x 是图像。 CapLoss(x, t) 是文本标记的交叉熵损失,该方法可以看作是文本到图像生成的逆提示(inverse prompting) [56] 的适应。 最后,选择 CapLosses 最低的图像。

与额外训练另一个用于重新排名的对比自监督模型(例如 CLIP [38])相比,我们的方法消耗更少的计算资源,因为我们只需要微调。 图 9 中的结果显示,通过我们的方法选择的图像在 FID 中比通过 CLIP 选择的图像表现更好。 图 6 显示了重新排名的示例。 

3.3 风格学习 

尽管 CogView 经过预训练以尽可能覆盖多样化的图像,但无法很好地满足生成特定风格或主题图像的愿望。 我们对模型进行了四种风格的微调:国画、油画、素描、卡通。 这些风格的图片是从Google、百度、Bing 等搜索引擎页面自动提取的,关键词为 “An image of {style} style”,其中{style} 是风格名称。 我们分别针对不同风格对模型进行微调,每个模型有 1,000 张图像。 

在微调过程中,图像对应的文本也是 “An image of {style} style”。 生成时,文本为 “A {object} of {style} style”,其中 {object} 是要生成的对象。 通过这种方式,CogView 可以将从预训练中学到的对象形状知识转移到微调的样式中。 图 7 显示了样式的示例。 

3.4 工业时尚设计

当生成针对单个域时,纹理的复杂性大大降低。 在这些场景中,我们可以 (1) 训练 VQGAN [15] 而不是 VQVAE 作为潜在变量,以获得更真实的纹理,(2) 减少参数数量并增加序列长度以获得更高分辨率。 在这种情况下,我们的三区域稀疏注意力(three-region sparse attention)(附录 B)可以加速高分辨率图像的生成。 

我们使用 50*50 个 VQGAN 图像标记在大约 1000 万个时尚-标题对上训练 3B 参数模型,并将它们解码为 800*800 像素。 图 8 显示了用于时装设计的 CogView 示例,该示例已成功部署到阿里巴巴 Rhino 时装生产中。 

4. 实验结果

4.1 机器评估 

目前,通用领域文本到图像生成最权威的机器评估指标是 MS COCO 上的 FID,该指标未包含在我们的训练集中。

  • 为了与 DALL-E 进行比较,我们遵循相同的设置,在对地面实况图像和生成的图像应用不同半径的高斯滤波器后,对从数据集中采样的 30,000 个标题的子集评估 CogView。(我们对 DM-GAN 和 DALL-E 使用相同的评估代码,可在 https://github.com/MinfengZhu/DM-GAN 获取)
  • 使用机器翻译将标题转换为用于 CogView 的中文。
  • 为了公平地与 DALL-E 进行比较,我们不使用超分辨率。
  • 此外,DALL-E 为每个标题生成 512 张图像,并通过 CLIP 选择最好的一张,这需要生成约 150 亿个标记。
  • 为了节省计算资源,我们根据 CapLosses 从 60 张生成的图像中选择最好的一张。 CapLoss 的评估基于 5,000 张图像的子集。我们最终将生成图像的对比度增强了 1.5。
  • 表 1 显示了 CogView 和其他方法的指标。 

标题损失作为衡量标准

  • FID 和 IS 旨在测量相对简单的分布(通常是单个对象)的无条件生成的质量。然而,文本到图像的生成应该逐对评估。
  • 表 1 显示 DM-GAN 实现了最佳的不模糊 FID 和 IS,但在人类偏好方面排名最后(图 10(a))。
  • 标题损失是一个绝对分数(而不是像 CLIP 那样的相对分数),因此可以对样本进行平均。 对于这项任务来说,它应该是一个更好的指标,并且与第 4.2 节中我们人类评估的总体得分更加一致。 

与 CLIP 对比自我重排。 我们评估了由 CLIP 选择的 CogView 生成图像的 FID-0 和 IS,并在 MS COCO 上进行自我重新排名。 图 9 显示了不同候选数量的曲线。 自我重新排名获得更好的 FID,并随着候选数量的增加而稳步细化 FID。 CLIP 在增加 IS 方面表现更好,但如上所述,它不是此任务的合适指标。 

关于 CogView 和 DALL-E 之间性能差异的讨论

  • 由于 DALLE 比 CogView 使用更多的数据和参数进行预训练,为什么 CogView 即使没有超分辨率也能获得更好的 FID?
  • 具体原因很难得知,因为DALL-E不是开源的,但我们猜测原因包括:
  • (1) CogView使用PB-relax和Sandwich-LN进行更稳定的优化。
  • (2) DALL-E使用了大量的卡通和渲染数据,使得生成的图像纹理与 MS COCO 中的照片有很大不同。
  • (3) 自重排序在 FID 中比 CLIP 中更好地选择图像。
  • (4) CogView 的训练时间更长(CogView 中 96B 训练好的标记 vs. DALL-E 中 56B 训练好的标记)。 

4.2 人类评价 

在文本到图像的生成方面,人类评估比机器评估更有说服力。 我们的人类评估包括 AttnGAN、DM-GAN、DF-GAN、CogView 生成的图像和恢复的地面实况(即由我们的图像标记器模糊的地面实况,理论上是 CogView 的上边界)之间的 2,950 组比较。 模型之间的详细信息和基于示例的比较参见附录 E。 

图 10 中的结果表明,CogView 的性能大幅优于基于 GAN 的基线。 CogView 以 37.02% 的概率被选为最好的,与恢复的地面实况 (59.53%) 的性能相媲美。 图 10(b)(c) 还表明我们的超分辨率模型持续提高了图像质量,尤其是清晰度,甚至优于恢复的地面实况。 

5. 结论与讨论 

局限性。 CogView 的缺点是生成速度慢,这对于自回归模型很常见,因为每个图像都是逐标记生成的。 VQVAE带来的模糊也是一个重要的限制。 这些问题将在今后的工作中得到解决。 

道德问题

  • 与 Deepfake 类似,CogView 因其可控且强大的图像生成能力而容易受到恶意利用[49]。 一项调查讨论了缓解此问题的可能方法 [5]。
  • 此外,关于人类的生成模型通常存在公平性问题(https://thegradient.pub/pulse-lessons)。在附录 D 中,我们分析了 CogView 中的公平性情况,并介绍了一种简单的 “单词替换” 方法来解决这个问题。 

我们系统地研究了结合 VQVAE 和 Transformers 进行文本到图像生成的框架。 CogView 展示了可扩展跨模态生成预训练的有希望的结果,并且还揭示并解决了可能源自数据异构性的精度问题。 我们还介绍了针对不同下游任务微调 CogView 的方法。 我们希望 CogView 能够推进可控图像生成和跨模态知识理解的研究和应用,但需要防止它被用来创建错误信息的图像。 

参考

Ding M, Yang Z, Hong W, et al. Cogview: Mastering text-to-image generation via transformers[J]. Advances in Neural Information Processing Systems, 2021, 34: 19822-19835.

附录

A. 数据收集和标记器的详细信息 

我们从多个渠道收集了约 3000 万个文本图像对,并构建了 2.5TB 的新数据集(标记化后,大小变为约 250GB)。

  • 该数据集是 WudaoCorpora 项目(https://wudaoai.cn/data) [52] 的扩展。 大约 50% 的文本是英文,包括 Conceptual Captions [44]。 它们通过机器翻译翻译成中文。
  • 此外,我们没有去除数据集中的水印和白边,尽管它们会影响生成图像的质量,因为我们认为从研究的角度来看这不会影响我们论文的结论。 

数据来源基本分为以下几类:

  • (1) 专业图片网站(中英文)。 网站中的图像通常带有标题。 该渠道的数据占比最高。
  • (2) Conceptual Captions [44] 和 ImageNet [11]。
  • (3) 网上新闻图片及其周围文字。
  • (4) 来自阿里巴巴的一小部分商品-标题对。
  • (5) 图像搜索引擎。 为了覆盖尽可能多的常见实体,我们制作了一个由 1,200 个查询组成的查询列表。 每个查询都是从大规模知识图谱中提取的实体名称。 我们选择七大类:美食、地域、物种、人名、风景、物产和艺术品。 我们根据英文维基百科中的出现次数为每个类别提取前 k 个实体,其中 k 是为每个类别手动选择的。 我们收集了每个主要搜索引擎网站针对每个查询返回的前 100 张图像。 

我们已经在 2.2 节中介绍了标记器,这里有一些细节。

  • 文本标记器直接基于 https://github.com/google/sentencepiece 上的 SentencePiece 包。
  • 图像标记器中的编码器是一个 4 层卷积神经网络 (CNN),具有 512 个隐藏单元,每层都有 ReLU 激活。 前三层的感受野为 4,步长为 2,从而减半图像的宽度和高度,最后一层是 1*1 卷积,将通道数转换为 256,这是词典中嵌入的隐藏大小。
  • 除了将卷积替换为反卷积之外,解码器具有与编码器相同的架构。 字典中的嵌入通过 Xavier 均匀初始化(Xavier uniform initialization) [18] 进行初始化。 

B. 稀疏注意力 

如图 11 所示,我们设计了三区域稀疏注意力(three-region sparse attention),这是一种易于实现的文本到图像生成的稀疏注意力。 每个标记关注所有文本标记、所有枢轴标记(pivots tokens)以及其之前的相邻窗口中的块中的标记。 

枢轴标记(pivots tokens)是随机选择的图像标记,类似于大鸟(big bird)[53]。 每次我们进入新层时都会对它们重新采样。 我们认为他们可以提供有关图像的全局信息。 

分块窗口注意力提供了局部信息,这是最重要的区域。 一维窗口注意力的前向计算可以通过仔细填充和改变张量的步幅来有效地就地实现,因为要关注的位置在内存中已经是连续的。 然而,如果没有定制的 CUDA 内核,我们仍然需要额外的内存来进行向后计算。 我们通过将相邻标记分组到块中来缓解这个问题,其中所有标记都注意相同的标记(在因果掩蔽之前)。 更多详细信息包含在我们发布的代码中。

在我们对 4096 个标记序列的基准测试中,三区域稀疏注意力(768 个文本和枢轴标记、768 个分块窗口标记)比普通注意力快 2.5倍,并节省 40% 的 GPU 内存。 整个训练比普通注意力训练快 1.5倍,并节省 20% 的 GPU 内存。 在相同的超参数、数据和随机种子的情况下,它们的损失曲线几乎相同,这意味着稀疏注意力不会影响收敛。 

然而,我们在训练 40 亿参数 CogView 时没有使用三区域稀疏注意力,因为担心它可能与 3.1 节中的超分辨率微调不兼容。 但它成功地加速了 CogView-fashion 的训练,而且没有副作用。 

C. 注意力分析 

为了探索 CogView 的注意力机制,我们通过绘制热图并标记最受关注的标记来可视化推理过程中的注意力分布。 我们发现我们的模型的注意力头在捕获位置和语义信息方面表现出很强的能力,并且不同层之间的注意力分布有所不同。 关于注意力评分量表的分析在 C.4 节中。 

C.1 位置偏差 

注意力分布与图像的位置结构高度相关。

  • 有很多头非常关注固定位置偏移,尤其是 32 的倍数(这是一行包含的标记数量)(图 12 (a))。
  • 有些头专门关注图像中的前几行(图 12 (b))。
  • 一些头部的热图显示出棋盘格图案(图12(c)),表明边界处的标记与中心处的标记不同。
  • 更深层次也显示出一些广泛的结构偏差。 例如,一些注意力头大量关注图像上/下半部分或中心的标记(图 12 (d)(e))。 

C.2 语义分割 

CogView 中的注意力还表明它也执行隐式语义分割。 有些标题突出了文本中提到的主要对象。 我们用 “桌子上有一个苹果,旁边有一个花瓶,里面有紫色的花” 作为我们实验的输入。 在图 13 中,我们用红点标记了与最受关注的标记相对应的像素,并发现注意力头成功捕获了苹果和紫色花等项目。 

C.3 注意力随深度而变化 

不同层之间的注意力模式有所不同。

  • 前面的层主要关注位置信息,而后面的层更关注内容。
  • 有趣的是,我们观察到注意力在最后几层(第 42 层之后)变得稀疏,很多注意力头只关注少数标记,例如分隔符标记(图 12 (f))。 一种可能的解释是,最后一层倾向于聚焦在当前标记上以确定输出标记,并且对分隔符标记的关注可以用作注意力头的无操作(no-op),这不会显着改变模型的输出,类似于 BERT 中的分析 [10]。因此,最后一层的注意力头忽略了大多数标记,并使注意力层退化为前馈层。 

C.4 注意力数值尺度

作为 2.4 节的补充,我们可视化第 38 层的注意力数值尺度,该层在 CogView 中具有最大的注意力得分尺度 (Q^T)K / √ d 。 不同头部的尺度差异很大,但每个头部的方差很小(这就是为什么即使分数很大,注意力也不会退化)。 我们认为原因是模型希望不同的头部有不同的敏感度,以便它学习乘以不同的常数来得到 Q 和 K。作为副作用,这些值可能有很大的偏差。 注意力的 PB- relax 是为了消除计算过程中的偏差。 

S. 总结

S.1 主要贡献 

提出了 CogView(cognitive ability),一个具有 VQ-VAE 标记器的 40 亿参数 Transformer,来解决通用领域中的文本到图像生成问题。CogView 可适用于各种下游任务,例如风格学习(特定领域的文本到图像)、超分辨率(图像到图像)、图像标题(图像到文本),以及文本到图像重排名

提出精度瓶颈松弛和三明治层归一化(Precision Bottleneck Relaxation and Sandwich Layernorm)来解决可能由数据的异构性引起的大规模文本到图像的生成预训练的不稳定。

提出了标题损失(Caption Loss,CapLoss)来评估图像和文本之间的对应关系,以比 FID 和 Inception Score (IS) 更细的粒度来衡量文本图像生成的质量和准确性。CapLoss 是文本标记的交叉熵损失,该方法可以看作是文本到图像生成的逆提示(inverse prompting)的适应。相比于对比自监督模型(例如 CLIP),CapLoss 消耗更少的计算资源,并且有更好的 FID 性能。

提出三区域稀疏注意力(three-region sparse attention)来加速训练并节省内存而不损害性能。

S.2 架构和方法

CogView 的架构如图 3 所示。 

  • 文本标记:大型中文语料库上运行 SentencePiece 以提取 50,000 个文本标记。
  • 图像标记:使用 VQ-VAE 编码器把图像转换为标记序列。
  • 解码器用于从潜在变量恢复图像。
  • 四个分隔符标记,[ROI1](图像的参考文本)、[BASE]、[BOI1](图像开始)、[EOI1](图像结束)被添加到每个序列中以指示文本和图像的边界。

精度瓶颈松弛(Precision Bottleneck Relaxation,PB-Relax):通过适度的缩放,避免数值爆炸导致的 LayerNorm 中的溢出。

三明治层归一化(Sandwich LayerNorm,Sandwich-LN):Transformer 中不同的 LayerNorm 结构如图 4 所示,Sandwich-LN 保证了各层输入值的规模在合理范围内,从而避免数值的逐层放大而最终导致深层数值爆炸。


http://www.ppmy.cn/news/1141820.html

相关文章

数字IC前端学习笔记:数字乘法器的优化设计(Dadda Tree乘法器)

相关阅读 数字IC前端https://blog.csdn.net/weixin_45791458/category_12173698.html?spm1001.2014.3001.5482 华莱士树仍然是一种比较规则的结构(这使得可以方便地生成树的结构),这导致了它所使用的全加器和半加器个数不是最少的&#xff…

学习笔记(css穿透、vue-cookie、拦截器、vuex、导航守卫、token/Cookie、正则校验)

目录 一、记录 1、CSS穿透 2、输入框是否提示输入 3、插槽 #slot 4、v-deep深入改掉属性值 二、vue-cookie 1、官方文档 2、使用 三、拦截器 1、请求拦截器 2、响应拦截器 四、vuex对信息存取改 五、路由导航守卫 1、登录思路 2、设置白名单 六、Token与Cookie…

Java线程池:并发编程的利器

Java线程池:并发编程的利器 在多任务、高并发的时代,Java并发编程显得尤为重要。其中,Java线程池是一种高效的管理线程的工具,能够提高应用程序的性能和响应速度。本文将深入探讨Java线程池的工作原理、应用场景以及简单示例&…

fastadmin框架如何查询数据表指定时间段内的数据

1.查看今日的数据 $currentDate date( Y-m-d );//获取今日的时间 $data[ today_order ] Db( orders ) ->whereTime( success_time, >, $currentDate.00:00:00 ) ->whereTime( success_time, <, $currentDate.23:59:59 ) ->where(type_store,4) ->where(sh…

山体滑坡监测系统——高效、便捷的新选择

在当今社会&#xff0c;科技的进步为我们的生活和工作带来了诸多便利。而在山体滑坡监测领域&#xff0c;全球导航卫星系统&#xff08;GNSS&#xff09;的引入更是增加了数据监测的高效性和便捷性。 一、山体滑坡监测系统的基本原理 山体滑坡监测系统是由监控平台和GNSS位移…

PHP知识大全

PHP知识大全 1. 变量如何定义&#xff1f;如何检查变量是否定义&#xff1f;如何删除一个变量&#xff1f;怎样检测变量是否设置&#xff1f; $定义 isset()// 检测变量是否设置 defined&#xff08;&#xff09;// 检测常量是否设置unset()//销毁指定的变量 empty()// 检测…

8.2 JUC - 6.CyclicBarrier

目录 一、是什么&#xff1f;二、使用demo三、注意 一、是什么&#xff1f; CyclicBarrier &#xff1a; 循环栅栏&#xff0c;用来进行线程协作&#xff0c;等待线程满足某个计数。构造时设置计数个数&#xff0c;每个线程执行到某个需要“同步”的时刻调用 await() 方法进行…

数据结构基本概念-Java常用算法

数据结构基本概念-Java常用算法 1、数据结构基本概念2、数据逻辑结构3、算法时间复杂度 1、数据结构基本概念 数据&#xff08;Data&#xff09;&#xff1a;数据是信息的载体&#xff0c;其能够被计算机识别、存储和加工处理&#xff0c;是计算机程序加工的“原材料”。数据元…

nrm 安装教程(图文教程)

序&#xff1a; 1、都知道nvm解决的是node版本切换的问题&#xff0c;nrm 解决的是则是npm指向的问题。 2、雪狼的公众号&#xff1a;“程序员野区”&#xff0c;也许未来的莫一天&#xff0c;你会用到&#xff0c;不凡先进来瞅瞅公众号都分享了啥内容 正文&#xff1a; 1、安…

途虎养车上市、京东养车“震虎”,如何突围汽车后市场?

“汽车后市场第一股”终于来了&#xff01; 赶在十一黄金周之前&#xff0c;途虎养车股份有限公司(09690.HK&#xff0c;下称“途虎养车”)于9月26日挂牌港交所&#xff0c;开盘价为28港元/股&#xff0c;与发行价持平&#xff1b;IPO首日报收29.50港元/股&#xff0c;涨幅5.3…

Python实用技术二:数据分析和可视化(2)

目录 一&#xff0c;多维数组库numpy 1&#xff0c;操作函数&#xff1a;​ 2&#xff0c;numpy数组元素增删 1&#xff09;添加数组元素 2&#xff09;numpy删除数组元素 3&#xff09;在numpy数组中查找元素 4&#xff09;numpy数组的数学运算 3&#xff0c;numpy数…

什么是mvvm模式,优点是什么

MVVM&#xff08;Model-View-ViewModel&#xff09;模式是一种设计模式。它是一种开发模式&#xff0c;旨在分离用户界面的开发和业务逻辑的开发。MVVM模式将应用程序分为三个部分&#xff1a; Model&#xff1a;它代表应用程序的数据模型和业务逻辑。 View&#xff1a;它代表…

队列--二叉树层序遍历

/*1/ \2 3/\ /\4 5 6 7利用LinkedListQueue1. 头 [1] 尾12.头 [2 3] 尾1 23.头 [3 4 5] 尾1 24.头 [4 5 6 7] 尾1 2 35.头 [] 尾1 2 3 4 5 6 7*/ 代码&#xff1a; class Solution {public List<List<Integer>> levelOrder(TreeNode root) {List<List&l…

arm-三盏灯流水

.text .global _start _start: 1.设置GPIOE寄存器的时钟使能 RCC_MP_AHB4ENSETR[4]->1 0x50000a28 LDR R0,0x50000A28 LDR R1,[R0] ORR R1,R1,#(0x3<<4) 第四位第五位都设置为1 STR R1,[R0] 写回2.设置PE10管脚为输出模式 GPIOE_MODER[21:20]->01 0x5000…

win11 vscode配置c/c++,使用mingw编译器

文章目录 第一步&#xff1a;装好vscode第二步&#xff1a;下载 mingw创建一个文件夹作为C或者C的项目文件夹&#xff0c;用vscode打开 第一步&#xff1a;装好vscode 之前使用python时装过 第二步&#xff1a;下载 mingw 官网 3.从这个界面一直往下滑 找到&#xff1a; 下…

pandas 笔记:asfreq

1 方法介绍 asfreq 是一个在 Pandas 时间序列数据分析中常用的方法。这个方法主要用于改变时间序列的频率。asfreq 可以帮助我们将一个时间序列从一个频率转换为另一个频率 2 基本用法 DataFrame.asfreq(freq, methodNone, howNone, normalizeFalse, fill_valueNone)3 参数说…

证件照换底色详细教程

说到证件照的底色更改&#xff0c;我想对大部分朋友来说是蛮头疼的事情&#xff0c;由于我们不论是在生活还是学习中&#xff0c;有时候总会要上传一些证件照&#xff0c;而当你手上有证件照准备上传时&#xff0c;发现底色不对&#xff0c;是不是很抓狂&#xff0c;现在&#…

一篇理解网络分层原理

一、网络分层的必要性。 如图是一个数据的传输过程&#xff0c;在这个途中会有很多的原因导致数据丢失&#xff0c;网络分层就要可以很大程度的避免这个现象。 网络分层的必要性体现在以下几个方面&#xff1a; 抽象复杂度&#xff1a;网络分层将网络功能按照不同的层次进行分…

flink双流join结果数据重复问题排查

1.背景 Kafka的两个topic&#xff0c;topic1 为用户下单明细记录&#xff08;包含订单基本信息&#xff09;&#xff0c;topic2为下单渠道记录&#xff08;包含下单来源和渠道内容设备相关的信息&#xff09; &#xff0c;要求实时统计每分钟内所有订单下的渠道来源分布详情。具…

基于SSM+Vue的学习交流论坛的设计与实现

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;采用Vue技术开发 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#x…
最新文章