机器之心报道
机器之心编辑部
在本文中,来自哈佛大学、Facebook AI 研究院的研究者提出了一种基于残差能量模型的文本生成方法,效果超过 state-of-the-art 的 transformer 语言模型。这是能量模型在大规模文本生成中的首次成功应用,论文已入选 ICLR 2020。
论文链接:https://openreview.net/pdf?id=B1l4SgHKDH
近年来,随着 GPT-2、GPT-3 模型的出现,我们似乎已经可以使用语言模型生成以假乱真的文本。然而事实上,模型生成的文本存在明显的质量问题。
比如我们可以训练一个分类器去区分真实文本和语言模型生成的文本,而且可以达到非常高的准确率 [1,2]。那么,一个自然的问题是,我们能否使用这个分类器去提高文本生成的质量,以期达到更加以假乱真的水平呢?这就是本文的研究问题。
同时,本文还解答了另一个问题:由于传统的文本生成解码器只能使用单向模型,如何使用预训练的双向模型 BERT 改进文本生成解码器?
为了便于讨论,作者定义一段有 T 个词的文本为 x=x_1 x_2…x_T。它有可能是真实文本,也可能是一个语言模型 P_LM (x)生成的文本。他们训练了一个分类器 E_θ (x)去区分 x 是真实的(real)还是生成的:
这里的 σ 是 sigmoid 函数,以确保概率在 0-1 范围内。以下示意图展示了训练的目标:
一个好的分类器 E_θ (x)可以确保当 x 比较接近真实文本时,E_θ (x)比较小;而当 x 比较接近语言模型生成文本时,E_θ (x)比较大。利用 E_θ (x),可以修正语言模型 P_LM (x),从而得到一个新的文本生成模型 P_θ (x):
上式就是本文提出的残差能量模型(residual energy-based model),这里的 Z 是一个全局归一化常数。之所以叫它残差模型,是因为
在修正
,比如当 E_θ (x)≡0 时,
这个残差模型非常直观,当 x 比较「不真实」时,E_θ (x)比较大,因此在残差模型中的概率
会低于未经修正前的
选择这样形式的模型是否有数学上的依据呢?事实上,作者的训练方法是噪声对抗训练(NCE)的一个特殊形式 [3,4]。理论保证详见论文中的定理 1,其结论是当 E_θ (x) 足够强大时(一般意味着足够多参数),目标函数的最优解是
,亦即即使语言模型 P_LM (x)和真实文本有偏差,足够强大的 E_θ (x)和足够好的优化算法都可以使残差模型无限逼近真实文本分布。
虽然本文提出的模型具有很好的理论保证,但引入分类器 / 修正器 E_θ (x)引入了额外的参数。为什么不直接增加语言模型的参数呢?这涉及到了语言模型 P_LM (x)和残差能量模型 P_θ (x)的本质区别:目前的语言模型 P_LM (x)一般是局部归一化(locally normalized)的,而 P_θ (x)是全局归一化的(globally normalized):
也就是说,P_LM (x)的模型在生成每个单词时,只能使用前面已经生成的单词的信息。因此我们只能使用单向的模型作为文本生成模型,而无法使用双向的模型。对比之下,E_θ (x_1 x_2…x_T )是直接取整个文本作为模型的输入,因此可以使用双向的模型,比如预训练的 BERT。由于不需要像 P_LM (x)那样每生成一个单词都归一化,因此全局归一化的 P_θ (x)更灵活。其实,P_LM (x)只是 P_θ (x)的一种特例。
虽然全局归一化的模型更灵活,但与 P_LM (x)不同,P_θ (x)不能从左至右逐词生成,因为 E_θ (x)需要以整个文本作为输入。对此,作者提出了基于 importance sampling 的生成方式:为了生成一个文本,作者
首先从 P_LM (x)中采样 N 个完整文本{x^1,x^2,…,x^N}
然后从这个样本集中进行采样:P(x=x^i)∝exp(-E_θ (x^i ))
上述过程非常类似机器翻译和句法分析中的再排序算法(reranking),然而本文作者提出的算法有两点重要的改进:第一,他们的算法具有理论保证,当样本数 N 足够大,上述过程中采集的样本服从 P_θ (x)的分布;第二,再排序在第二步骤进行的是排序,而他们进行的是采样(初步实验证明排序的效果弱于采样,类似 [5] 中的观察)。
实验
最后简要介绍一下实验结果。本文主要使用的数据集 CC-News 规模非常大,有 160 亿个词 [6]。另外,作者选择的基线(baseline)是 GPT 级别的 state-of-the-art 语言模型。对如此大规模数据下基线模型的提高是非常有意义的。
首先,作为生成模型,作者使用自然语言处理中的常用指标 perplexity(PPL)衡量真实文本在模型下的概率。PPL 可以简化理解为正确生成每个词,模型平均需要猜几次。因此,PPL 越低越好。这里残差能量模型的 PPL 使用采样估计的上界,详见论文。
在上图中,base LM 是语言模型 P_LM (x),其余的(Joint 开头)都是残差能量模型。使用单向的 transformer 作为 E_θ ()(Joint UniT),PPL 略有降低,而使用双向的 transformer(Joint BiT-base),PPL 比单向模型进一步下降(值得一提的是,传统的语言模型是没法使用双向 transformer 的)。最后两列展示了本文所提方法可以使用预训练的双向模型,这里作者使用了 BERT 的变种 Roberta-base(Joint BiT-base)和 Roberta-Large(Joint BiT-Large),效果得到了进一步的提升。
PPL 的降低证明了:从概率模型的角度,本文提出的模型是优于基线模型的。但该模型能否生成更以假乱真的文本呢?下面的表格中,作者做了人工评测的实验,验证了该模型的确可以得到更好的文本:
最后,作者给出了一个具体例子,直观理解残差模型如何修正改进语言模型 P_LM (x)。
前文指出过,此项研究的生成过程是先采样一些样本,然后使用〖-E〗_θ (x)作为分数从这些样本中进行再次采样。以上的 Joint Bit-base Worst 是〖-E〗_θ (x)最低的样本(也就是分类器认为最不像真实文本的)。这个样本中,词组「these grants」重复了两次。重复生成词组是目前语言模型的常见问题 [5],因此分类器会根据这个特点,很容易判断出这句话不是真实文本,由此在再采样过程中,这个分数很低的样本基本不可能被采样到。值得一提的是,本文提出的模型训练时并没有明确要求它不生成重复词组,但分类器自动发现重复词组是一个语言模型生成文本的明显特征,因此残差能量模型生成的重复词组明显减少(详见论文)。
总结来看,残差能量模型是比 state-of-the-art 的 transformer 语言模型效果更好的全局归一化模型。它的训练方式只是训练一个辨别真实文本还是语言模型生成的分类器,因此非常简单稳定,同时还拥有 NCE 带来的理论正确保证。
作者在实验中使用了语言模型作为测试任务,但实际上很容易推广到条件生成,比如机器翻译或者文本摘要。
另外,作者提出的能量模型和 GAN 的思路有很大不同:GAN 使用分类判别器的目的是改进生成器,最后并没有使用分类判别器;而残差能量模型最终使用分类器,而且训练过程中不去试图改变分类器,因此训练过程更加稳定。最后,全局归一化(globally normalized)的能量模型虽然在 Yann Lecun 等人看来是未来的重要方向(https://iclr.cc/virtual_2020/speaker_7.html),但目前还没有得到广泛重视。作者认为这里有很多未来工作的可能方向,比如和隐变量结合等。
引用:
[1]: Bakhtin, Anton, Yuntian Deng, Sam Gross, Myle Ott, Marc'Aurelio Ranzato, and Arthur Szlam."Energy-based Models for Text." arXiv (2020): arXiv-2004.
[2]: Zellers, Rowan, Ari Holtzman, Hannah Rashkin, Yonatan Bisk, Ali Farhadi, Franziska Roesner, and Yejin Choi. "Defending against neural fake news." In Advances in Neural Information Processing Systems, pp. 9051-9062. 2019.
[3]: Gutmann, Michael, and Aapo Hyvrinen. "Noise-contrastive estimation: A new estimation principle for unnormalized statistical models." In Proceedings of the Thirteenth International Conference on Artificial Intelligence and Statistics, pp. 297-304. 2010.
[4]: Ma, Zhuang, and Michael Collins. "Noise contrastive estimation and negative sampling for conditional models: Consistency and statistical efficiency." arXiv preprint arXiv:1809.01812 (2018).
[5]: Holtzman, Ari, Jan Buys, Li Du, Maxwell Forbes, and Yejin Choi. "The curious case of neural text degeneration." arXiv preprint arXiv:1904.09751 (2019).