Decision Transformer - 学习笔记

Decision Transformer 学习笔记

1. 文章摘要

Transformer模型是一种seq2seq的模型,它的独特之处在于,给定一个输入序列,由模型自己决定输出序列的内容与长度。最开始很自然地应用在NLP类似的序列问题的解决上,不过很多问题都可以建模为seq2seq问题,从而这个模型也可以解决许许多多不同的问题,如:语音识别,词性标注,图像物体识别等。该文章将Transformer尝试使用在强化学习领域,提出一种将强化学习建模为序列决策任务的框架,称为Decision Tranformer 。模型的思想很简单:基于期望回报(Return To Go)、历史状态以及历史动作输出下一刻的最优动作。与传统的拟合值函数或者计算策略梯度的强化学习方法不同,Decision Tranformer是一种“离线强化学习(Offline Reinforcement Learning)”模型,更具体的说就是在模型训练的过程中完全不与真实环境进行交互,通过监督学习的方式,从历史样本中学习专家经验。

模型通过利用casually masked Transformer来输出最优动作。通过自回归(autoregressive 意思就是把自己当前和之前的所有输出作为下一次的输入,迭代产生一个输出序列)的方式运行,历史序列+当前的序列不断运行,让本文的Transformer模型可以生成相应的动作去达成期望回报。

论文:https://arxiv.org/abs/2106.01345

代码:https://github.com/kzl/decision-transformer

2. 模型结构

2.1 整体模型结构

Decision Transformer 模型结构

DT的模型结构如图所示,状态、动作、回报被投入各自对应的embedding中,并进行位置编码。这里位置编码采用的是时间戳编码。R,S,A组成的token被送入GPT结构中,以自回归的方式,结合因果掩码(causal mask)预测下一个时刻的动作。

2.2 模型分类

由于dt中没有model(model的作用是用于预测未来的状态),故模型属于model-free模型。
由于dt是采用离线数据进行训练,训练过程中不与真实环境进行交互,故模型属于offline RL模型。

2.3 回报设计

与传统的强化学习不同,文章希望transformer从历史序列中学习到动作的信息,并用于预测未来的动作。然而,对奖励函数进行建模又是不现实的,文章因此使用了reward-to-go(RTG)作为轨迹在训练过程中的reward,而非reward的原始值。

在测试的时候,只需要给定一个期望的奖励,以及初始状态即可。在实际环境中运行,得到实际奖励之后,就将期望奖励减去这个实际奖励,再迭代送入input。

3. 模型伪代码

伪代码

模型伪代码如图所示,基本上和上图结构一致。首先R S A送入各自的embedding,然后进行stack操作(类似concat,可以理解把三张表按照合并列的方式拼接)。随后送入transformer模型中,得到隐状态(hidden_state),并根据隐状态进行动作选择,最终得到预测动作并执行。

在评价回合执行动作后,获取剩余奖励,减去RTG,和S A拼接成token送入模型继续预测下一个动作。

4. 实验部分

由于DT的思路就是学习(s,a,r)的轨迹,很自然的想到,这和模仿学习非常相似,区别就在于dt还多加了一个rtg。

  1. 模型和行为模仿之间的比较
    Decision Transformer跟最好的%BC表现相当,表明在训练了整个数据集之后,它可以在特定的子集上选择更优的行为。这里测评的都是各种游戏,目标是获得更高的奖励分数。对比的是使用多少百分比轨迹训练的BC。
    文章指出DT取得了与使用更多轨迹BC相似的结果

  2. 使用更长上下文的好处
    DT是以序列的方式进行输入,我们很容易想到,序列里只放一个元素也可以进行预测,所以文章还测试了序列中序列上下文长度对预测效果的影响。结论是采用更长的序列进行预测可以获得更好的奖励。

长上下文评测结果

其他实验目的是证明当中的回报有效,就不再仔细展开了。

5. 参考文献

https://zhuanlan.zhihu.com/p/501117104


Decision Transformer - 学习笔记
https://runsstudio.github.io/2025/05/28/Decision Transformer学习笔记/
发布于
2025年5月28日
许可协议