返回博客

核心想法:如果大语言模型不再逐词生成呢?

你使用过的每一个大语言模型——GPT-4、Claude、Llama——都是逐词生成文本的。这种自回归方法简单但天生缓慢:生成 1000 个词,就需要运行 1000 次前向传播。

Meta 首席 AI 科学家 Yann LeCun 一直在倡导一种完全不同的架构:JEPA(联合嵌入预测架构)。JEPA 不预测下一个词,而是预测下一段语义——在嵌入空间中完成。模型在一次前向传播中生成整个后续序列。

我们基于这个理念构建了一个原型 JEPALM(JEPA 语言模型),并在莎士比亚数据集上进行了训练。以下是我们的发现。

架构设计

JEPALM 包含三个核心组件:

  1. TextEncoder(文本编码器):将源文本转换为嵌入向量(4层 Transformer)
  2. Predictor(预测器):一次性将源嵌入映射为目标嵌入(4层交叉注意力 Transformer)
  3. TextDecoder(文本解码器):将预测的嵌入还原为文本(4层 Transformer)
[源文本] → 编码器 → [源嵌入]
                      |
                预测器(一次性输出!)
                      |
                [目标嵌入]
                      |
                解码器 → [目标文本]

总参数量:1010 万——以现代大语言模型的标准来看非常小,但足以验证概念可行性。

训练配置

参数
数据集 Tiny Shakespeare(111 万字符,97 个唯一字符)
模型规模 1010 万参数
训练规模 10 轮 × 2500 批次 = 25,000 批次
GPU NVIDIA A100 80GB
训练时长 约 8.6 小时
学习率 1e-3,余弦衰减

损失函数包含三个组件:

实验结果:三个意外发现

发现一:模型几乎完美地学会了莎士比亚

交叉熵损失——衡量模型复现文本能力的指标——在前 500 个批次就收敛到接近零。这意味着给定正确的嵌入向量,解码器可以近乎完美地重建莎士比亚的文本。

CE 收敛

CE 损失从 0.029 下降到 ~0,之后在剩余 24,500 个批次中保持平坦。

发现二:嵌入预测有效——但进展缓慢

核心创新——从源嵌入预测目标嵌入——展现了稳定但有限的改进。经过 10 轮训练,嵌入 MSE 下降了 25.3%(从 242 降至 181)。

嵌入趋势

每条线代表一个 epoch。跨 epoch 的清晰下降趋势证实预测器确实在学习。

发现三:长度预测任务存在问题

长度预测损失高度波动,即使在稳定阶段也在 0 到 1500+ 之间剧烈震荡。它贡献了总损失的 42.5%——接近一半——但对生成质量的提升微乎其微。

长度波动

长度 MSE 在整个训练过程中不可预测地飙升。20 批次移动平均线显示轻微上升趋势——与我们期望的相反。

损失构成:模型到底在学什么?

在前几百个批次之后,损失构成稳定为:

这意味着模型在很早就不再优化文本质量,将剩余 99% 的训练时间花在了嵌入和长度预测上。交叉熵损失实质上变得无关紧要。

可行性评估

成功之处

不足之处

关键经验教训

  1. 多任务损失权重至关重要。当一个任务(CE)远快于其他任务收敛时,它会停止贡献梯度。动态权重或课程学习可能有所帮助。

  2. 长度预测可能不是合适的辅助任务。从嵌入向量预测精确的词元数量本质上具有噪声。基于分类的方法或直接移除该组件可能更好。

  3. 预测器需要更大容量。仅有 4 层和 256 维,预测器可能不足以完成从源嵌入到目标嵌入的复杂映射。

下一步计划

JEPALM 的下一版本将聚焦于:

总结

JEPALM 实验证明,嵌入空间序列预测是自回归生成的可行替代方案——验证了 LeCun 的核心假设。架构本身是有效的,但损失函数设计需要大幅改进。

这次实验也揭示了一个更广泛的 AI 研究教训:魔鬼藏在损失函数中。即使架构设计合理,糟糕的损失设计也会主导训练过程,掩盖有效信号。下一版本将解决这些问题,更清晰地回答 JEPA 式语言模型能否与自回归方法竞争这一问题。


本实验在双 A100 80GB 服务器上完成。完整的训练代码和日志可在 JEPALM 仓库中获取。欢迎贡献代码和改进建议。