Decision Transformer: Reinforcement Learning via Sequence Modeling
References
- Paper: https://arxiv.org/abs/2106.01345
1. Abstract
- Decision Transformer:
  - A framework that abstracts RL as a sequence modeling problem.
  - Casts the problem of RL as conditional sequence modeling.
  - Simply outputs the optimal actions by leveraging a causally masked Transformer.
  - By conditioning an autoregressive model on the desired return, past states, and actions, it generates future actions that achieve the desired return.
2. Introduction
- Replaces traditional RL algorithms: models the joint distribution of the sequence of states, actions, and rewards.
3. Method
Trajectory representation:
Autoregressively model the following sequence: \(\tau = (\hat{R}_1, s_1, a_1, \hat{R}_2, s_2, a_2, \ldots, \hat{R}_T, s_T, a_T)\). Instead of feeding the raw rewards directly, the model conditions on the returns-to-go:
\[\hat{R}_t = \sum_{t'=t}^{T} r_{t'}\]
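As a quick illustration of the definition, returns-to-go are simply suffix sums of the per-step rewards. A minimal Python sketch (the helper name `returns_to_go` is ours, not from the paper):

```python
def returns_to_go(rewards):
    """Suffix sums: rtg[t] = r[t] + r[t+1] + ... + r[T-1]."""
    rtg = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running += rewards[t]
        rtg[t] = running
    return rtg

# e.g. rewards [1, 0, 2] -> returns-to-go [3, 2, 2]
```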
Architecture:
The model is fed the last K timesteps, and each timestep contributes 3 types of token, so the input is a sequence of 3K tokens:
\(\langle \text{return-to-go}, \text{state}, \text{action} \rangle\)
- Embedding layer:
  - One linear layer per modality (return-to-go, state, action) projects the raw inputs into embeddings.
  - Visual inputs: the state is fed into a convolutional encoder instead of a linear layer.
- Normalization layer: layer normalization is applied to the token embeddings.
- Timestep embedding: a learned embedding for each timestep is added to every token of that timestep (this differs from the standard positional embedding, since one timestep corresponds to three tokens).
- Attention layers: the tokens are processed by a causally masked GPT, which predicts actions autoregressively. A minimal sketch of the embedding stage follows this list.
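To make the embedding stage concrete, here is a PyTorch-style sketch under our own naming assumptions (`DTEmbedder`, `embed_dim`, `max_timestep` are illustrative, not from the paper). It shows one linear layer per modality, layer normalization, and the learned per-timestep embedding shared by all three tokens of a timestep:

```python
import torch.nn as nn

class DTEmbedder(nn.Module):
    """Per-modality linear embeddings plus a shared learned timestep embedding."""
    def __init__(self, state_dim, act_dim, embed_dim, max_timestep):
        super().__init__()
        self.embed_R = nn.Linear(1, embed_dim)          # return-to-go is a scalar
        self.embed_s = nn.Linear(state_dim, embed_dim)  # swap for a conv encoder on image states
        self.embed_a = nn.Linear(act_dim, embed_dim)
        self.embed_t = nn.Embedding(max_timestep, embed_dim)  # one learned embedding per timestep
        self.ln = nn.LayerNorm(embed_dim)

    def forward(self, R, s, a, t):
        # R: (B, K, 1), s: (B, K, state_dim), a: (B, K, act_dim), t: (B, K) integer timesteps
        pos = self.embed_t(t)  # (B, K, embed_dim); added to all 3 tokens of a timestep,
                               # unlike a standard per-token positional embedding
        R_e = self.ln(self.embed_R(R) + pos)
        s_e = self.ln(self.embed_s(s) + pos)
        a_e = self.ln(self.embed_a(a) + pos)
        return R_e, s_e, a_e   # later interleaved into a 3K-token sequence for the GPT
```

The three embedded streams are then interleaved into a single 3K-token sequence and fed to the causally masked GPT, as in the pseudocode below.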
Training:
```python
# R, s, a, t: returns-to-go, states, actions, and timesteps
# transformer: transformer with causal masking (GPT)
# embed_s, embed_a, embed_R: linear embedding layers
# embed_t: learned episode positional (timestep) embedding
# pred_a: linear action prediction layer

# main model
def DecisionTransformer(R, s, a, t):
    # compute embeddings for tokens
    pos_embedding = embed_t(t)  # per-timestep (note: not per-token)
    s_embedding = embed_s(s) + pos_embedding  # state emb
    a_embedding = embed_a(a) + pos_embedding  # action emb
    R_embedding = embed_R(R) + pos_embedding  # return-to-go emb

    # interleave tokens as (R_1, s_1, a_1, ..., R_K, s_K, a_K)
    input_embeds = stack(R_embedding, s_embedding, a_embedding)

    # use transformer to get hidden states
    hidden_states = transformer(input_embeds=input_embeds)

    # select hidden states for action prediction tokens
    a_hidden = unstack(hidden_states).actions

    # predict action
    return pred_a(a_hidden)

# training loop
for (R, s, a, t) in dataloader:  # dims: (batch_size, K, dim)
    a_preds = DecisionTransformer(R, s, a, t)
    loss = mean((a_preds - a)**2)  # L2 loss for continuous actions
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# evaluation loop
target_return = 1  # for instance, expert-level return
R, s, a, t, done = [target_return], [env.reset()], [], [1], False
while not done:  # autoregressive generation/sampling
    # sample next action
    action = DecisionTransformer(R, s, a, t)[-1]  # for cts actions
    new_s, r, done, _ = env.step(action)

    # append new tokens to sequence
    R = R + [R[-1] - r]  # decrement returns-to-go with reward
    s, a, t = s + [new_s], a + [action], t + [len(R)]
    R, s, a, t = R[-K:], ...  # only keep context length of K
```
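The `stack` and `unstack` helpers are left abstract in the pseudocode above; one plausible reading (a sketch of our own, not the authors' implementation) treats them as interleaving and de-interleaving the three token types via a reshape:

```python
from collections import namedtuple
import torch

Modalities = namedtuple("Modalities", ["returns", "states", "actions"])

def stack(R_emb, s_emb, a_emb):
    """Interleave three (B, K, D) embeddings into (B, 3K, D): R_1, s_1, a_1, ..., R_K, s_K, a_K."""
    B, K, D = R_emb.shape
    return torch.stack([R_emb, s_emb, a_emb], dim=2).reshape(B, 3 * K, D)

def unstack(hidden_states):
    """Split (B, 3K, D) hidden states back into per-modality tensors of shape (B, K, D)."""
    B, L, D = hidden_states.shape
    x = hidden_states.reshape(B, L // 3, 3, D)
    return Modalities(returns=x[:, :, 0], states=x[:, :, 1], actions=x[:, :, 2])
```

Note that for discrete action spaces the paper trains with a cross-entropy loss in place of the L2 loss shown in the training loop.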
4. Evaluations on Offline RL Benchmarks
