Goal-Conditioned Predictive Coding for Offline Reinforcement Learning

Brown University
NeurIPS 2023


Recent work has demonstrated the effectiveness of formulating decision making as supervised learning on offline-collected trajectories. Powerful sequence models, such as GPT or BERT, are often employed to encode the trajectories. However, the benefits of performing sequence modeling on trajectory data remain unclear. In this work, we investigate whether sequence modeling has the ability to condense trajectories into useful representations that enhance policy learning. We adopt a two-stage framework that first leverages sequence models to encode trajectory-level representations, and then learns a goal-conditioned policy employing the encoded representations as its input. This formulation allows us to consider many existing supervised offline RL methods as specific instances of our framework. Within this framework, we introduce Goal-Conditioned Predictive Coding (GCPC), a sequence modeling objective that yields powerful trajectory representations and leads to performant policies. Through extensive empirical evaluations on AntMaze, FrankaKitchen and Locomotion environments, we observe that sequence modeling can have a significant impact on challenging decision making tasks. Furthermore, we demonstrate that GCPC learns a goal-conditioned latent representation encoding the future trajectory, which enables competitive performance on all three benchmarks.

Goal-Conditioned Predictive Coding

We introduce a two-stage framework that decouples trajectory representation learning from policy learning. Under this framework, we derive a specific design that performs goal-conditioned future prediction and generates latent representations encoding future behaviors toward the desired goal.
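As a rough sketch of the two-stage idea (not the paper's implementation), the first stage compresses a trajectory and the goal into a latent summary, and the second stage feeds that summary, together with the current observation and goal, to the policy. The dimensions, the linear `trajnet_encode` stand-in for the transformer encoder, and the linear `policy` are all hypothetical placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical dimensions, chosen only for illustration.
OBS_DIM, ACT_DIM, GOAL_DIM, LATENT_DIM, T = 4, 2, 2, 8, 10

# Stage 1: a stand-in "TrajNet" that compresses a trajectory plus goal
# into a latent summary (a single linear map in place of a transformer).
W_enc = rng.normal(size=(T * OBS_DIM + GOAL_DIM, LATENT_DIM))

def trajnet_encode(traj_obs, goal):
    """Encode a (T, OBS_DIM) trajectory and a goal into a latent vector."""
    x = np.concatenate([traj_obs.ravel(), goal])
    return np.tanh(x @ W_enc)

# Stage 2: a goal-conditioned policy that consumes the encoded latent.
W_pi = rng.normal(size=(OBS_DIM + GOAL_DIM + LATENT_DIM, ACT_DIM))

def policy(obs, goal, latent):
    """Map (observation, goal, trajectory latent) to a bounded action."""
    return np.tanh(np.concatenate([obs, goal, latent]) @ W_pi)

traj = rng.normal(size=(T, OBS_DIM))
goal = rng.normal(size=GOAL_DIM)
z = trajnet_encode(traj, goal)      # trajectory-level representation
action = policy(traj[0], goal, z)   # policy acts on the latent summary
```

The point of the decoupling is that the first-stage objective can be swapped out (e.g. different masking patterns) without touching the policy-learning stage.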

Two-Stage GCPC

What Brings Effective Trajectory Representations?

The two-stage framework offers flexibility in the choice of representation learning objectives, and allows us to study the impact of sequence modeling on trajectory representation learning and policy learning independently. We further explore how to properly utilize sequence modeling to generate helpful trajectory representations along the following axes:

Masking Patterns: To study the impact of trajectory representation learning objectives on the resulting policy performance, we implement five different sequence modeling objectives by varying the masking patterns in the first-stage pretraining. We observe that predictive coding objectives yield powerful representations that encode the future trajectory, which enhance policy learning.
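To make the idea of varying masking patterns concrete, here is a minimal sketch of three illustrative patterns over a length-`T` trajectory (the paper's five objectives are not reproduced here; these function names and patterns are assumptions for illustration). A boolean mask marks which timesteps are hidden from the encoder and must be reconstructed.

```python
import numpy as np

T, t = 8, 3  # hypothetical sequence length and current timestep

def mask_random(T, p=0.5, seed=0):
    """BERT-style: mask a random subset of timesteps."""
    return np.random.default_rng(seed).random(T) < p

def mask_causal(T, t):
    """GPT-style: hide everything strictly after the current step."""
    m = np.zeros(T, dtype=bool)
    m[t + 1:] = True
    return m

def mask_future_prediction(T, t):
    """Predictive coding: observe the past, reconstruct the future."""
    m = np.zeros(T, dtype=bool)
    m[t:] = True  # the model must predict from step t onward
    return m

print(mask_future_prediction(T, t).astype(int))  # [0 0 0 1 1 1 1 1]
```

Under this framing, a future-prediction mask forces the encoder's latent to carry information about upcoming states, which is the property the predictive coding objective exploits.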


Goal Conditioning: We investigate whether goal conditioning (i.e., the input goal) in TrajNet is necessary or beneficial for learning trajectory representations. We find that goal conditioning is crucial for predictive coding objectives to properly encode the expected long-term future.

(Figure: ablation on goal conditioning)

Comparison with Explicit Future: In GCPC, the future trajectory information is stored in the goal-conditioned latent representation, which serves as a conditioning variable for the policy. We compare this latent future representation with a decoded explicit future sequence to study how the form of future information impacts policy performance. We observe that the latent representation is a powerful carrier of future information that effectively improves policy performance.
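One practical difference between the two forms of future conditioning is input size: a latent summary has a fixed dimension, while a decoded future sequence grows with the decoding horizon. The sketch below illustrates this with hypothetical dimensions (all names and sizes are assumptions, not the paper's configuration).

```python
import numpy as np

rng = np.random.default_rng(0)
OBS_DIM, LATENT_DIM, H = 4, 8, 16  # hypothetical dims; H = future horizon

obs = rng.normal(size=OBS_DIM)
z_future = rng.normal(size=LATENT_DIM)           # latent future (as in GCPC)
explicit_future = rng.normal(size=(H, OBS_DIM))  # decoded future states

# Latent conditioning: policy input size is independent of the horizon H.
latent_input = np.concatenate([obs, z_future])

# Explicit conditioning: input grows linearly with the decoded horizon.
explicit_input = np.concatenate([obs, explicit_future.ravel()])
```

Beyond compactness, conditioning on a latent also spares the policy from consuming potentially noisy step-by-step future decodings.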

(Figure: ablation on latent vs. explicit future conditioning)

Benchmark Results


We appreciate all anonymous reviewers for their constructive feedback. We would like to thank Calvin Luo and Haotian Fu for their discussions and insights, and Tian Yun for the help on this project. This work is in part supported by Adobe, Honda Research Institute, Meta AI, Samsung Advanced Institute of Technology, and a Richard B. Salomon Faculty Research Award for C.S.


@inproceedings{
  title={Goal-Conditioned Predictive Coding for Offline Reinforcement Learning},
  author={Zeng, Zilai and Zhang, Ce and Wang, Shijie and Sun, Chen},
  booktitle={Advances in Neural Information Processing Systems},
  year={2023}
}