cs
DL·ML

[DL] RLHF (Reinforcement Learning from Human Feedback)

Introduction

language model의 loss를 정의하는 것은 어려운 일이다. 단순한 token의 차이로는 좋은 loss를 얻기 어렵고, BLEU나 ROUGE로 측정하고 있지만 이 경우에도 단순 비교를 통해서 얻기 때문에 정확하지 않은 점이 많다.

 

RLHF(Reinforcement Learning from Human Feedback)은 이런 점을 개선하여 사람의 feedback으로부터 모델이 학습할 수 있도록 한다.  

 

RLHF는 task alignment를 위해 사용하는 방법으로, pretrained LM에 대해 사용한다. 다양한 instruction에 대해서 반응할 수 있는 형태여야 한다. 

 

그 다음, 사람의 preference를 반영한 reward model을 만든다. 

 

Reward Model

reward model은 text sequence를 받아 reward $\in \mathbb R$을 리턴하는 모델이다. 이 reward는 human preference를 의미한다. 

 

reward modeling에는 fine-tuned LM을 사용하거나, preference data로 학습한 LM 모델이 될 수 있다. 다른 형태여도 상관 없다. 

 

RM을 학습하기 위해서는 prompt가 나타난다. 이를 통해서 새로운 text를 만들어내고 human annotator는 이를 rank를 매겨 평가한다. human feedback을 scalar로 사용하지 않는 이유는, 사람마다 점수의 산정 방식이 달라 uncalibrated score를 만들어내기 때문이다. 

 

Fine-tuning with RL

만들어진 reward model을 이용하여 원래의 language model을 fine-tuning한다. 

 

RL을 사용하는 방법은, initial LM의 parameter의 copy를 policy-gradient RL algorithm(Proximal Policy Optimization, PPO)으로 학습시키는 것이다.

 

전체 parameter를 학습시키는 것은 너무 expensive하므로 일부는 freeze되고, 이 크기는 유동적이다. 

 

fine-tuning에서 policy는 prompt를 받아서 seqeunce of text를 만드는 LM이다. 이 policy의 action space는 LM의 vocabulary에 대응하는 모든 token이다. observation space는 가능한 input token sequence의 distribution이다. reward function은 preference model의 combination이다.

 

update는 reward를 maximize하도록 학습된다.

 

 


References

[1] https://huggingface.co/blog/rlhf

<script> MathJax = { tex: {inlineMath: [['$', '$'], ['\\(', '\\)']]} }; </script>
<script src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-chtml.js"></script>

Footnotes

'DL·ML' 카테고리의 다른 글

TPU란 무엇인가?  (0) 2023.11.10
[GNN] GNN Model  (0) 2023.11.07
[GNN] Visual genome dataset  (0) 2023.09.15
[GNN] GCN을 이용한 Image Captioning 구현  (0) 2023.08.24
[DL] Vision Transformer (ViT)  (0) 2023.08.17